diff options
-rw-r--r-- | src/vulkan/compute_pipeline.cc | 7 | ||||
-rw-r--r-- | src/vulkan/engine_vulkan.cc | 12 | ||||
-rw-r--r-- | src/vulkan/graphics_pipeline.cc | 4 | ||||
-rw-r--r-- | src/vulkan/graphics_pipeline.h | 2 | ||||
-rw-r--r-- | src/vulkan/pipeline.cc | 13 | ||||
-rw-r--r-- | src/vulkan/pipeline.h | 13 | ||||
-rw-r--r-- | tests/cases/compute_ssbo_with_entrypoint_command.amber | 173 |
7 files changed, 216 insertions, 8 deletions
diff --git a/src/vulkan/compute_pipeline.cc b/src/vulkan/compute_pipeline.cc index fed8414..1da7f39 100644 --- a/src/vulkan/compute_pipeline.cc +++ b/src/vulkan/compute_pipeline.cc @@ -35,13 +35,18 @@ Result ComputePipeline::Initialize(VkCommandPool pool, VkQueue queue) { } Result ComputePipeline::CreateVkComputePipeline() { - const auto& shader_stage_info = GetShaderStageInfo(); + auto shader_stage_info = GetShaderStageInfo(); if (shader_stage_info.size() != 1) { return Result( "Vulkan::CreateVkComputePipeline number of shaders given to compute " "pipeline is not 1"); } + if (shader_stage_info[0].stage != VK_SHADER_STAGE_COMPUTE_BIT) + return Result("Vulkan: Non compute shader for compute pipeline"); + + shader_stage_info[0].pName = GetEntryPointName(VK_SHADER_STAGE_COMPUTE_BIT); + Result r = CreateVkDescriptorRelatedObjectsIfNeeded(); if (!r.IsSuccess()) return r; diff --git a/src/vulkan/engine_vulkan.cc b/src/vulkan/engine_vulkan.cc index 3302f24..893ccda 100644 --- a/src/vulkan/engine_vulkan.cc +++ b/src/vulkan/engine_vulkan.cc @@ -164,8 +164,7 @@ EngineVulkan::GetShaderStageInfo() { VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; stage_info[stage_count].stage = ToVkShaderStage(it.first); stage_info[stage_count].module = it.second; - // TODO(jaebaek): Handle entry point command - stage_info[stage_count].pName = "main"; + stage_info[stage_count].pName = nullptr; ++stage_count; } return stage_info; @@ -320,8 +319,13 @@ Result EngineVulkan::DoCompute(const ComputeCommand* command) { command->GetZ()); } -Result EngineVulkan::DoEntryPoint(const EntryPointCommand*) { - return Result("Vulkan::DoEntryPoint Not Implemented"); +Result EngineVulkan::DoEntryPoint(const EntryPointCommand* command) { + if (!pipeline_) + return Result("Vulkan::DoEntryPoint no Pipeline exists"); + + pipeline_->SetEntryPointName(ToVkShaderStage(command->GetShaderType()), + command->GetEntryPointName()); + return {}; } Result EngineVulkan::DoPatchParameterVertices( diff --git a/src/vulkan/graphics_pipeline.cc b/src/vulkan/graphics_pipeline.cc index cd54bad..e6c2dac 100644 --- a/src/vulkan/graphics_pipeline.cc +++ b/src/vulkan/graphics_pipeline.cc @@ -260,7 +260,9 @@ Result GraphicsPipeline::CreateVkGraphicsPipeline( viewport_info.scissorCount = 1; viewport_info.pScissors = &scissor; - const auto& shader_stage_info = GetShaderStageInfo(); + auto shader_stage_info = GetShaderStageInfo(); + for (auto& info : shader_stage_info) + info.pName = GetEntryPointName(info.stage); VkGraphicsPipelineCreateInfo pipeline_info = {}; pipeline_info.sType = VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO; diff --git a/src/vulkan/graphics_pipeline.h b/src/vulkan/graphics_pipeline.h index f170e92..907aab8 100644 --- a/src/vulkan/graphics_pipeline.h +++ b/src/vulkan/graphics_pipeline.h @@ -109,8 +109,6 @@ class GraphicsPipeline : public Pipeline { VkFormat color_format_; VkFormat depth_stencil_format_; - std::vector<VkPipelineShaderStageCreateInfo> shader_stage_info_; - uint32_t frame_width_ = 0; uint32_t frame_height_ = 0; diff --git a/src/vulkan/pipeline.cc b/src/vulkan/pipeline.cc index 3a4be57..0da91de 100644 --- a/src/vulkan/pipeline.cc +++ b/src/vulkan/pipeline.cc @@ -28,6 +28,11 @@ namespace amber { namespace vulkan { +namespace { + +const char* kDefaultEntryPointName = "main"; + +} // namespace Pipeline::Pipeline( PipelineType type, @@ -401,5 +406,13 @@ Result Pipeline::GetDescriptorInfo(const uint32_t descriptor_set, ", binding: " + std::to_string(binding) + " does not exist"); } +const char* Pipeline::GetEntryPointName(VkShaderStageFlagBits stage) const { + auto it = entry_points_.find(stage); + if (it != entry_points_.end()) + return it->second.c_str(); + + return kDefaultEntryPointName; +} + } // namespace vulkan } // namespace amber diff --git a/src/vulkan/pipeline.h b/src/vulkan/pipeline.h index c087c4e..57d8646 100644 --- a/src/vulkan/pipeline.h +++ b/src/vulkan/pipeline.h @@ -16,9 +16,12 @@ #define SRC_VULKAN_PIPELINE_H_ #include <memory> +#include <string> +#include <unordered_map> #include <vector> #include "amber/result.h" +#include "src/cast_hash.h" #include "src/engine.h" #include "src/vulkan/command.h" #include "src/vulkan/descriptor.h" @@ -55,6 +58,11 @@ class Pipeline { const uint32_t binding, ResourceInfo* info); + void SetEntryPointName(VkShaderStageFlagBits stage, + const std::string& entry) { + entry_points_[stage] = entry; + } + virtual void Shutdown(); virtual Result ProcessCommands() = 0; @@ -78,6 +86,7 @@ class Pipeline { return shader_stage_info_; } + const char* GetEntryPointName(VkShaderStageFlagBits stage) const; uint32_t GetFenceTimeout() const { return fence_timeout_ms_; } VkPipeline pipeline_ = VK_NULL_HANDLE; @@ -107,6 +116,10 @@ class Pipeline { std::vector<VkPipelineShaderStageCreateInfo> shader_stage_info_; uint32_t fence_timeout_ms_ = 100; bool descriptor_related_objects_already_created_ = false; + std::unordered_map<VkShaderStageFlagBits, + std::string, + CastHash<VkShaderStageFlagBits>> + entry_points_; }; } // namespace vulkan diff --git a/tests/cases/compute_ssbo_with_entrypoint_command.amber b/tests/cases/compute_ssbo_with_entrypoint_command.amber new file mode 100644 index 0000000..eae3aeb --- /dev/null +++ b/tests/cases/compute_ssbo_with_entrypoint_command.amber @@ -0,0 +1,173 @@ +# Copyright 2018 The Amber Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +[comment] +Source code in GLSL: + +#version 430 + +layout(set = 0, binding = 0) buffer block0 { + float data_set0_binding0[3]; +}; + +layout(set = 1, binding = 2) buffer block1 { + float data_set1_binding2[3]; +}; + +layout(set = 2, binding = 1) buffer block2 { + float data_set2_binding1[3]; +}; + +layout(set = 2, binding = 3) buffer block3 { + float data_set2_binding3[3]; +}; + +[compute shader spirv] +; SPIR-V +; Version: 1.0 +; Generator: Khronos Glslang Reference Front End; 7 +; Bound: 71 +; Schema: 0 + OpCapability Shader + %1 = OpExtInstImport "GLSL.std.450" + OpMemoryModel Logical GLSL450 + OpEntryPoint GLCompute %main "foo" %gl_WorkGroupID + OpExecutionMode %main LocalSize 1 1 1 + OpSource GLSL 430 + OpName %main "foo" + OpName %index "index" + OpName %gl_WorkGroupID "gl_WorkGroupID" + OpName %block0 "block0" + OpMemberName %block0 0 "data_set0_binding0" + OpName %_ "" + OpName %block1 "block1" + OpMemberName %block1 0 "data_set1_binding2" + OpName %__0 "" + OpName %block2 "block2" + OpMemberName %block2 0 "data_set2_binding1" + OpName %__1 "" + OpName %block3 "block3" + OpMemberName %block3 0 "data_set2_binding3" + OpName %__2 "" + OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId + OpDecorate %_arr_float_uint_3 ArrayStride 4 + OpMemberDecorate %block0 0 Offset 0 + OpDecorate %block0 BufferBlock + OpDecorate %_ DescriptorSet 0 + OpDecorate %_ Binding 0 + OpDecorate %_arr_float_uint_3_0 ArrayStride 4 + OpMemberDecorate %block1 0 Offset 0 + OpDecorate %block1 BufferBlock + OpDecorate %__0 DescriptorSet 1 + OpDecorate %__0 Binding 2 + OpDecorate %_arr_float_uint_3_1 ArrayStride 4 + OpMemberDecorate %block2 0 Offset 0 + OpDecorate %block2 BufferBlock + OpDecorate %__1 DescriptorSet 2 + OpDecorate %__1 Binding 1 + OpDecorate %_arr_float_uint_3_2 ArrayStride 4 + OpMemberDecorate %block3 0 Offset 0 + OpDecorate %block3 BufferBlock + OpDecorate %__2 DescriptorSet 2 + OpDecorate %__2 Binding 3 + %void = OpTypeVoid + %3 = OpTypeFunction %void + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint + %v3uint = OpTypeVector %uint 3 +%_ptr_Input_v3uint = OpTypePointer Input %v3uint +%gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input + %uint_0 = OpConstant %uint 0 +%_ptr_Input_uint = OpTypePointer Input %uint + %float = OpTypeFloat 32 + %uint_3 = OpConstant %uint 3 +%_arr_float_uint_3 = OpTypeArray %float %uint_3 + %block0 = OpTypeStruct %_arr_float_uint_3 +%_ptr_Uniform_block0 = OpTypePointer Uniform %block0 + %_ = OpVariable %_ptr_Uniform_block0 Uniform + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 +%_ptr_Uniform_float = OpTypePointer Uniform %float + %float_1 = OpConstant %float 1 +%_arr_float_uint_3_0 = OpTypeArray %float %uint_3 + %block1 = OpTypeStruct %_arr_float_uint_3_0 +%_ptr_Uniform_block1 = OpTypePointer Uniform %block1 + %__0 = OpVariable %_ptr_Uniform_block1 Uniform +%_arr_float_uint_3_1 = OpTypeArray %float %uint_3 + %block2 = OpTypeStruct %_arr_float_uint_3_1 +%_ptr_Uniform_block2 = OpTypePointer Uniform %block2 + %__1 = OpVariable %_ptr_Uniform_block2 Uniform + %float_10 = OpConstant %float 10 +%_arr_float_uint_3_2 = OpTypeArray %float %uint_3 + %block3 = OpTypeStruct %_arr_float_uint_3_2 +%_ptr_Uniform_block3 = OpTypePointer Uniform %block3 + %__2 = OpVariable %_ptr_Uniform_block3 Uniform + %float_30 = OpConstant %float 30 + %main = OpFunction %void None %3 + %5 = OpLabel + %index = OpVariable %_ptr_Function_uint Function + %14 = OpAccessChain %_ptr_Input_uint %gl_WorkGroupID %uint_0 + %15 = OpLoad %uint %14 + OpStore %index %15 + %24 = OpLoad %uint %index + %25 = OpLoad %uint %index + %27 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %25 + %28 = OpLoad %float %27 + %30 = OpFAdd %float %28 %float_1 + %31 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %24 + OpStore %31 %30 + %36 = OpLoad %uint %index + %41 = OpLoad %uint %index + %42 = OpAccessChain %_ptr_Uniform_float %__1 %int_0 %41 + %43 = OpLoad %float %42 + %44 = OpLoad %uint %index + %45 = OpAccessChain %_ptr_Uniform_float %__0 %int_0 %44 + %46 = OpLoad %float %45 + %47 = OpFSub %float %43 %46 + %48 = OpAccessChain %_ptr_Uniform_float %__0 %int_0 %36 + OpStore %48 %47 + %49 = OpLoad %uint %index + %55 = OpLoad %uint %index + %56 = OpAccessChain %_ptr_Uniform_float %__2 %int_0 %55 + %57 = OpLoad %float %56 + %58 = OpFMul %float %float_10 %57 + %59 = OpLoad %uint %index + %60 = OpAccessChain %_ptr_Uniform_float %__1 %int_0 %59 + %61 = OpLoad %float %60 + %62 = OpFAdd %float %58 %61 + %63 = OpAccessChain %_ptr_Uniform_float %__1 %int_0 %49 + OpStore %63 %62 + %64 = OpLoad %uint %index + %66 = OpLoad %uint %index + %67 = OpAccessChain %_ptr_Uniform_float %__2 %int_0 %66 + %68 = OpLoad %float %67 + %69 = OpFMul %float %float_30 %68 + %70 = OpAccessChain %_ptr_Uniform_float %__2 %int_0 %64 + OpStore %70 %69 + OpReturn + OpFunctionEnd +[test] +ssbo 0:0 subdata vec3 0 1.0 2.0 3.0 +ssbo 1:2 subdata vec3 0 4.0 5.0 6.0 +ssbo 2:1 subdata vec3 0 21.0 22.0 23.0 +ssbo 2:3 subdata vec3 0 0.7 0.8 0.9 + +compute entrypoint foo + +compute 3 1 1 + +probe ssbo vec3 0:0 0 ~= 2.0 3.0 4.0 +probe ssbo vec3 1:2 0 ~= 17.0 17.0 17.0 +probe ssbo vec3 2:1 0 ~= 28.0 30.0 32.0 +probe ssbo vec3 2:3 0 ~= 21.0 24.0 27.0 |