aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/vulkan/compute_pipeline.cc7
-rw-r--r--src/vulkan/engine_vulkan.cc12
-rw-r--r--src/vulkan/graphics_pipeline.cc4
-rw-r--r--src/vulkan/graphics_pipeline.h2
-rw-r--r--src/vulkan/pipeline.cc13
-rw-r--r--src/vulkan/pipeline.h13
-rw-r--r--tests/cases/compute_ssbo_with_entrypoint_command.amber173
7 files changed, 216 insertions, 8 deletions
diff --git a/src/vulkan/compute_pipeline.cc b/src/vulkan/compute_pipeline.cc
index fed8414..1da7f39 100644
--- a/src/vulkan/compute_pipeline.cc
+++ b/src/vulkan/compute_pipeline.cc
@@ -35,13 +35,18 @@ Result ComputePipeline::Initialize(VkCommandPool pool, VkQueue queue) {
}
Result ComputePipeline::CreateVkComputePipeline() {
- const auto& shader_stage_info = GetShaderStageInfo();
+ auto shader_stage_info = GetShaderStageInfo();
if (shader_stage_info.size() != 1) {
return Result(
"Vulkan::CreateVkComputePipeline number of shaders given to compute "
"pipeline is not 1");
}
+ if (shader_stage_info[0].stage != VK_SHADER_STAGE_COMPUTE_BIT)
+ return Result("Vulkan: Non compute shader for compute pipeline");
+
+ shader_stage_info[0].pName = GetEntryPointName(VK_SHADER_STAGE_COMPUTE_BIT);
+
Result r = CreateVkDescriptorRelatedObjectsIfNeeded();
if (!r.IsSuccess())
return r;
diff --git a/src/vulkan/engine_vulkan.cc b/src/vulkan/engine_vulkan.cc
index 3302f24..893ccda 100644
--- a/src/vulkan/engine_vulkan.cc
+++ b/src/vulkan/engine_vulkan.cc
@@ -164,8 +164,7 @@ EngineVulkan::GetShaderStageInfo() {
VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
stage_info[stage_count].stage = ToVkShaderStage(it.first);
stage_info[stage_count].module = it.second;
- // TODO(jaebaek): Handle entry point command
- stage_info[stage_count].pName = "main";
+ stage_info[stage_count].pName = nullptr;
++stage_count;
}
return stage_info;
@@ -320,8 +319,13 @@ Result EngineVulkan::DoCompute(const ComputeCommand* command) {
command->GetZ());
}
-Result EngineVulkan::DoEntryPoint(const EntryPointCommand*) {
- return Result("Vulkan::DoEntryPoint Not Implemented");
+Result EngineVulkan::DoEntryPoint(const EntryPointCommand* command) {
+ if (!pipeline_)
+ return Result("Vulkan::DoEntryPoint no Pipeline exists");
+
+ pipeline_->SetEntryPointName(ToVkShaderStage(command->GetShaderType()),
+ command->GetEntryPointName());
+ return {};
}
Result EngineVulkan::DoPatchParameterVertices(
diff --git a/src/vulkan/graphics_pipeline.cc b/src/vulkan/graphics_pipeline.cc
index cd54bad..e6c2dac 100644
--- a/src/vulkan/graphics_pipeline.cc
+++ b/src/vulkan/graphics_pipeline.cc
@@ -260,7 +260,9 @@ Result GraphicsPipeline::CreateVkGraphicsPipeline(
viewport_info.scissorCount = 1;
viewport_info.pScissors = &scissor;
- const auto& shader_stage_info = GetShaderStageInfo();
+ auto shader_stage_info = GetShaderStageInfo();
+ for (auto& info : shader_stage_info)
+ info.pName = GetEntryPointName(info.stage);
VkGraphicsPipelineCreateInfo pipeline_info = {};
pipeline_info.sType = VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO;
diff --git a/src/vulkan/graphics_pipeline.h b/src/vulkan/graphics_pipeline.h
index f170e92..907aab8 100644
--- a/src/vulkan/graphics_pipeline.h
+++ b/src/vulkan/graphics_pipeline.h
@@ -109,8 +109,6 @@ class GraphicsPipeline : public Pipeline {
VkFormat color_format_;
VkFormat depth_stencil_format_;
- std::vector<VkPipelineShaderStageCreateInfo> shader_stage_info_;
-
uint32_t frame_width_ = 0;
uint32_t frame_height_ = 0;
diff --git a/src/vulkan/pipeline.cc b/src/vulkan/pipeline.cc
index 3a4be57..0da91de 100644
--- a/src/vulkan/pipeline.cc
+++ b/src/vulkan/pipeline.cc
@@ -28,6 +28,11 @@
namespace amber {
namespace vulkan {
+namespace {
+
+const char* kDefaultEntryPointName = "main";
+
+} // namespace
Pipeline::Pipeline(
PipelineType type,
@@ -401,5 +406,13 @@ Result Pipeline::GetDescriptorInfo(const uint32_t descriptor_set,
", binding: " + std::to_string(binding) + " does not exist");
}
+const char* Pipeline::GetEntryPointName(VkShaderStageFlagBits stage) const {
+ auto it = entry_points_.find(stage);
+ if (it != entry_points_.end())
+ return it->second.c_str();
+
+ return kDefaultEntryPointName;
+}
+
} // namespace vulkan
} // namespace amber
diff --git a/src/vulkan/pipeline.h b/src/vulkan/pipeline.h
index c087c4e..57d8646 100644
--- a/src/vulkan/pipeline.h
+++ b/src/vulkan/pipeline.h
@@ -16,9 +16,12 @@
#define SRC_VULKAN_PIPELINE_H_
#include <memory>
+#include <string>
+#include <unordered_map>
#include <vector>
#include "amber/result.h"
+#include "src/cast_hash.h"
#include "src/engine.h"
#include "src/vulkan/command.h"
#include "src/vulkan/descriptor.h"
@@ -55,6 +58,11 @@ class Pipeline {
const uint32_t binding,
ResourceInfo* info);
+ void SetEntryPointName(VkShaderStageFlagBits stage,
+ const std::string& entry) {
+ entry_points_[stage] = entry;
+ }
+
virtual void Shutdown();
virtual Result ProcessCommands() = 0;
@@ -78,6 +86,7 @@ class Pipeline {
return shader_stage_info_;
}
+ const char* GetEntryPointName(VkShaderStageFlagBits stage) const;
uint32_t GetFenceTimeout() const { return fence_timeout_ms_; }
VkPipeline pipeline_ = VK_NULL_HANDLE;
@@ -107,6 +116,10 @@ class Pipeline {
std::vector<VkPipelineShaderStageCreateInfo> shader_stage_info_;
uint32_t fence_timeout_ms_ = 100;
bool descriptor_related_objects_already_created_ = false;
+ std::unordered_map<VkShaderStageFlagBits,
+ std::string,
+ CastHash<VkShaderStageFlagBits>>
+ entry_points_;
};
} // namespace vulkan
diff --git a/tests/cases/compute_ssbo_with_entrypoint_command.amber b/tests/cases/compute_ssbo_with_entrypoint_command.amber
new file mode 100644
index 0000000..eae3aeb
--- /dev/null
+++ b/tests/cases/compute_ssbo_with_entrypoint_command.amber
@@ -0,0 +1,173 @@
+# Copyright 2018 The Amber Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+[comment]
+Source code in GLSL:
+
+#version 430
+
+layout(set = 0, binding = 0) buffer block0 {
+ float data_set0_binding0[3];
+};
+
+layout(set = 1, binding = 2) buffer block1 {
+ float data_set1_binding2[3];
+};
+
+layout(set = 2, binding = 1) buffer block2 {
+ float data_set2_binding1[3];
+};
+
+layout(set = 2, binding = 3) buffer block3 {
+ float data_set2_binding3[3];
+};
+
+[compute shader spirv]
+; SPIR-V
+; Version: 1.0
+; Generator: Khronos Glslang Reference Front End; 7
+; Bound: 71
+; Schema: 0
+ OpCapability Shader
+ %1 = OpExtInstImport "GLSL.std.450"
+ OpMemoryModel Logical GLSL450
+ OpEntryPoint GLCompute %main "foo" %gl_WorkGroupID
+ OpExecutionMode %main LocalSize 1 1 1
+ OpSource GLSL 430
+ OpName %main "foo"
+ OpName %index "index"
+ OpName %gl_WorkGroupID "gl_WorkGroupID"
+ OpName %block0 "block0"
+ OpMemberName %block0 0 "data_set0_binding0"
+ OpName %_ ""
+ OpName %block1 "block1"
+ OpMemberName %block1 0 "data_set1_binding2"
+ OpName %__0 ""
+ OpName %block2 "block2"
+ OpMemberName %block2 0 "data_set2_binding1"
+ OpName %__1 ""
+ OpName %block3 "block3"
+ OpMemberName %block3 0 "data_set2_binding3"
+ OpName %__2 ""
+ OpDecorate %gl_WorkGroupID BuiltIn WorkgroupId
+ OpDecorate %_arr_float_uint_3 ArrayStride 4
+ OpMemberDecorate %block0 0 Offset 0
+ OpDecorate %block0 BufferBlock
+ OpDecorate %_ DescriptorSet 0
+ OpDecorate %_ Binding 0
+ OpDecorate %_arr_float_uint_3_0 ArrayStride 4
+ OpMemberDecorate %block1 0 Offset 0
+ OpDecorate %block1 BufferBlock
+ OpDecorate %__0 DescriptorSet 1
+ OpDecorate %__0 Binding 2
+ OpDecorate %_arr_float_uint_3_1 ArrayStride 4
+ OpMemberDecorate %block2 0 Offset 0
+ OpDecorate %block2 BufferBlock
+ OpDecorate %__1 DescriptorSet 2
+ OpDecorate %__1 Binding 1
+ OpDecorate %_arr_float_uint_3_2 ArrayStride 4
+ OpMemberDecorate %block3 0 Offset 0
+ OpDecorate %block3 BufferBlock
+ OpDecorate %__2 DescriptorSet 2
+ OpDecorate %__2 Binding 3
+ %void = OpTypeVoid
+ %3 = OpTypeFunction %void
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+ %v3uint = OpTypeVector %uint 3
+%_ptr_Input_v3uint = OpTypePointer Input %v3uint
+%gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
+ %uint_0 = OpConstant %uint 0
+%_ptr_Input_uint = OpTypePointer Input %uint
+ %float = OpTypeFloat 32
+ %uint_3 = OpConstant %uint 3
+%_arr_float_uint_3 = OpTypeArray %float %uint_3
+ %block0 = OpTypeStruct %_arr_float_uint_3
+%_ptr_Uniform_block0 = OpTypePointer Uniform %block0
+ %_ = OpVariable %_ptr_Uniform_block0 Uniform
+ %int = OpTypeInt 32 1
+ %int_0 = OpConstant %int 0
+%_ptr_Uniform_float = OpTypePointer Uniform %float
+ %float_1 = OpConstant %float 1
+%_arr_float_uint_3_0 = OpTypeArray %float %uint_3
+ %block1 = OpTypeStruct %_arr_float_uint_3_0
+%_ptr_Uniform_block1 = OpTypePointer Uniform %block1
+ %__0 = OpVariable %_ptr_Uniform_block1 Uniform
+%_arr_float_uint_3_1 = OpTypeArray %float %uint_3
+ %block2 = OpTypeStruct %_arr_float_uint_3_1
+%_ptr_Uniform_block2 = OpTypePointer Uniform %block2
+ %__1 = OpVariable %_ptr_Uniform_block2 Uniform
+ %float_10 = OpConstant %float 10
+%_arr_float_uint_3_2 = OpTypeArray %float %uint_3
+ %block3 = OpTypeStruct %_arr_float_uint_3_2
+%_ptr_Uniform_block3 = OpTypePointer Uniform %block3
+ %__2 = OpVariable %_ptr_Uniform_block3 Uniform
+ %float_30 = OpConstant %float 30
+ %main = OpFunction %void None %3
+ %5 = OpLabel
+ %index = OpVariable %_ptr_Function_uint Function
+ %14 = OpAccessChain %_ptr_Input_uint %gl_WorkGroupID %uint_0
+ %15 = OpLoad %uint %14
+ OpStore %index %15
+ %24 = OpLoad %uint %index
+ %25 = OpLoad %uint %index
+ %27 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %25
+ %28 = OpLoad %float %27
+ %30 = OpFAdd %float %28 %float_1
+ %31 = OpAccessChain %_ptr_Uniform_float %_ %int_0 %24
+ OpStore %31 %30
+ %36 = OpLoad %uint %index
+ %41 = OpLoad %uint %index
+ %42 = OpAccessChain %_ptr_Uniform_float %__1 %int_0 %41
+ %43 = OpLoad %float %42
+ %44 = OpLoad %uint %index
+ %45 = OpAccessChain %_ptr_Uniform_float %__0 %int_0 %44
+ %46 = OpLoad %float %45
+ %47 = OpFSub %float %43 %46
+ %48 = OpAccessChain %_ptr_Uniform_float %__0 %int_0 %36
+ OpStore %48 %47
+ %49 = OpLoad %uint %index
+ %55 = OpLoad %uint %index
+ %56 = OpAccessChain %_ptr_Uniform_float %__2 %int_0 %55
+ %57 = OpLoad %float %56
+ %58 = OpFMul %float %float_10 %57
+ %59 = OpLoad %uint %index
+ %60 = OpAccessChain %_ptr_Uniform_float %__1 %int_0 %59
+ %61 = OpLoad %float %60
+ %62 = OpFAdd %float %58 %61
+ %63 = OpAccessChain %_ptr_Uniform_float %__1 %int_0 %49
+ OpStore %63 %62
+ %64 = OpLoad %uint %index
+ %66 = OpLoad %uint %index
+ %67 = OpAccessChain %_ptr_Uniform_float %__2 %int_0 %66
+ %68 = OpLoad %float %67
+ %69 = OpFMul %float %float_30 %68
+ %70 = OpAccessChain %_ptr_Uniform_float %__2 %int_0 %64
+ OpStore %70 %69
+ OpReturn
+ OpFunctionEnd
+[test]
+ssbo 0:0 subdata vec3 0 1.0 2.0 3.0
+ssbo 1:2 subdata vec3 0 4.0 5.0 6.0
+ssbo 2:1 subdata vec3 0 21.0 22.0 23.0
+ssbo 2:3 subdata vec3 0 0.7 0.8 0.9
+
+compute entrypoint foo
+
+compute 3 1 1
+
+probe ssbo vec3 0:0 0 ~= 2.0 3.0 4.0
+probe ssbo vec3 1:2 0 ~= 17.0 17.0 17.0
+probe ssbo vec3 2:1 0 ~= 28.0 30.0 32.0
+probe ssbo vec3 2:3 0 ~= 21.0 24.0 27.0