diff options
author | alan-baker <alanbaker@google.com> | 2020-03-18 08:55:55 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-03-18 08:55:55 -0400 |
commit | 32aa3b89754243dad223508bdb18cc771d927164 (patch) | |
tree | c46a9493d315bb3495300159291f814d1e4056ad | |
parent | 5b8bb777e0ae186dbc96e0b0d91a28f9614e7132 (diff) | |
download | amber-32aa3b89754243dad223508bdb18cc771d927164.tar.gz |
Set the propoer bytes for OpenCL SET commands (#816)
Fixes #814
Convert the SET values into bytes before filling in the buffer
Add a parser check that the type is a scalar
add a test
-rw-r--r-- | src/amberscript/parser.cc | 3 | ||||
-rw-r--r-- | src/amberscript/parser_pipeline_set_test.cc | 17 | ||||
-rw-r--r-- | src/pipeline.cc | 29 |
3 files changed, 48 insertions, 1 deletions
diff --git a/src/amberscript/parser.cc b/src/amberscript/parser.cc index 8d12483..953fc12 100644 --- a/src/amberscript/parser.cc +++ b/src/amberscript/parser.cc @@ -1024,6 +1024,9 @@ Result Parser::ParsePipelineSet(Pipeline* pipeline) { if (!type) return Result("invalid data type '" + token->AsString() + "' provided"); + if (type->IsVec() || type->IsMatrix() || type->IsArray() || type->IsStruct()) + return Result("data type must be a scalar type"); + token = tokenizer_->NextToken(); if (!token->IsInteger() && !token->IsDouble()) return Result("expected data value"); diff --git a/src/amberscript/parser_pipeline_set_test.cc b/src/amberscript/parser_pipeline_set_test.cc index 7d78ca7..46f0dd4 100644 --- a/src/amberscript/parser_pipeline_set_test.cc +++ b/src/amberscript/parser_pipeline_set_test.cc @@ -241,5 +241,22 @@ END EXPECT_EQ("7: SET can only be used with OPENCL-C shaders", r.Error()); } +TEST_F(AmberScriptParserTest, OpenCLSetNonScalarDataType) { + std::string in = R"( +SHADER compute my_shader OPENCL-C +#shader +END +PIPELINE compute my_pipeline + ATTACH my_shader + SET KERNEL ARG_NAME arg_a AS vec4<uint32> 0 0 0 0 +END +)"; + + Parser parser; + auto r = parser.Parse(in); + ASSERT_FALSE(r.IsSuccess()); + EXPECT_EQ("7: data type must be a scalar type", r.Error()); +} + } // namespace amberscript } // namespace amber diff --git a/src/pipeline.cc b/src/pipeline.cc index eca6326..478e6e1 100644 --- a/src/pipeline.cc +++ b/src/pipeline.cc @@ -755,7 +755,34 @@ Result Pipeline::GenerateOpenCLPodBuffers() { return Result(message); } - Result r = buffer->SetDataWithOffset({arg_info.value}, offset); + // Convert the argument value into bytes. Currently, only scalar arguments + // are supported. + const auto arg_byte_size = arg_info.fmt->SizeInBytes(); + std::vector<Value> data_bytes; + for (uint32_t i = 0; i < arg_byte_size; ++i) { + Value v; + if (arg_info.value.IsFloat()) { + if (arg_byte_size == sizeof(double)) { + union { + uint64_t u; + double d; + } u; + u.d = arg_info.value.AsDouble(); + v.SetIntValue((u.u >> (i * 8)) & 0xff); + } else { + union { + uint32_t u; + float f; + } u; + u.f = arg_info.value.AsFloat(); + v.SetIntValue((u.u >> (i * 8)) & 0xff); + } + } else { + v.SetIntValue((arg_info.value.AsUint64() >> (i * 8)) & 0xff); + } + data_bytes.push_back(v); + } + Result r = buffer->SetDataWithOffset(data_bytes, offset); if (!r.IsSuccess()) return r; } |