aboutsummaryrefslogtreecommitdiff
path: root/tensorflow_lite_support/codegen/code_generator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow_lite_support/codegen/code_generator.cc')
-rw-r--r--tensorflow_lite_support/codegen/code_generator.cc179
1 files changed, 179 insertions, 0 deletions
diff --git a/tensorflow_lite_support/codegen/code_generator.cc b/tensorflow_lite_support/codegen/code_generator.cc
new file mode 100644
index 00000000..1337708d
--- /dev/null
+++ b/tensorflow_lite_support/codegen/code_generator.cc
@@ -0,0 +1,179 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow_lite_support/codegen/code_generator.h"
+
+#include <cctype>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "tensorflow_lite_support/codegen/utils.h"
+#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
+
+namespace tflite {
+namespace support {
+namespace codegen {
+
+namespace {
+
+void ResolveConflictedNamesByAddingIndex(std::vector<std::string>* names_ptr) {
+ auto& names = *names_ptr;
+ std::unordered_map<std::string, int> indexes;
+ std::unordered_map<std::string, int> first_appearance;
+ for (int i = 0; i < names.size(); i++) {
+ if (indexes.find(names[i]) == indexes.end()) {
+ indexes[names[i]] = 1;
+ first_appearance[names[i]] = i;
+ } else {
+ indexes[names[i]] += 1;
+ names[i].append(std::to_string(indexes[names[i]]));
+ }
+ }
+ for (const auto& it : first_appearance) {
+ const auto& name = it.first;
+ const auto i = it.second;
+ if (indexes[name] > 1) {
+ names[i].append("1");
+ }
+ }
+}
+
+} // namespace
+
+CodeGenerator::CodeGenerator() {}
+
+bool CodeGenerator::VerifyMetadata(const ModelMetadata* metadata,
+ ErrorReporter* err) {
+ if (metadata == nullptr) {
+ err->Error("Loading nullptr is not allowed");
+ return false;
+ }
+ if (metadata->subgraph_metadata()->size() != 1) {
+ err->Error("Only exact 1 subgraph is supported");
+ return false;
+ }
+ return true;
+}
+
+std::pair<std::vector<std::string>, std::vector<std::string>>
+CodeGenerator::NameInputsAndOutputs(const TensorMetadataList* inputs,
+ const TensorMetadataList* outputs) {
+ std::vector<std::string> input_names;
+ std::vector<std::string> output_names;
+ if (inputs != nullptr) {
+ input_names.reserve(inputs->size());
+ for (const auto* tensor : *inputs) {
+ input_names.push_back(NameTensor(*tensor, "input"));
+ }
+ }
+ if (outputs != nullptr) {
+ output_names.reserve(outputs->size());
+ for (const auto* tensor : *outputs) {
+ output_names.push_back(NameTensor(*tensor, "output"));
+ }
+ }
+ // Solve conflict
+ ResolveConflictedInputAndOutputNames(&input_names, &output_names);
+ return std::make_pair(input_names, output_names);
+}
+
+std::string CodeGenerator::ConvertToValidName(const std::string& name) {
+ // lowercase all
+ std::string result = name;
+ for (int i = 0; i < result.size(); i++) {
+ result[i] = std::tolower(result[i]);
+ }
+ // replace all non-alpha or non-numeric with underscores, except underscore
+ // itself
+ for (int i = 0; i < result.size(); i++) {
+ if (result[i] != '_' && !std::isalnum(result[i])) {
+ result[i] = '_';
+ }
+ }
+ // remove leading underscores
+ int leading_underscores = 0;
+ while (leading_underscores < result.size() &&
+ result[leading_underscores] == '_') {
+ leading_underscores++;
+ }
+ result.erase(0, leading_underscores);
+ if (result.empty()) {
+ return "";
+ }
+ // first char should be alpha
+ if (std::isalpha(result[0])) {
+ return result;
+ }
+ return "tensor_" + result;
+}
+
+std::string CodeGenerator::NameTensor(const TensorMetadata& tensor,
+ const std::string& default_name) {
+ if (tensor.name() != nullptr && tensor.name()->size() > 0) {
+ // TODO(b/141225157) Validate tensor name. It should be in lower case.
+ auto suggested_name = ConvertToValidName(tensor.name()->str());
+ if (!suggested_name.empty()) {
+ return suggested_name;
+ }
+ }
+ auto* content = tensor.content();
+ if (content == nullptr || content->content_properties() == nullptr) {
+ return default_name;
+ }
+ switch (content->content_properties_type()) {
+ case ContentProperties_ImageProperties:
+ return "image";
+ case ContentProperties_FeatureProperties:
+ return "feature";
+ default:
+ return default_name;
+ }
+}
+
+void CodeGenerator::ResolveConflictedInputAndOutputNames(
+ std::vector<std::string>* inputs, std::vector<std::string>* outputs) {
+ std::unordered_set<std::string> io_conflict;
+ auto& input_names = *inputs;
+ auto& output_names = *outputs;
+ for (const auto& input : input_names) {
+ if (io_conflict.find(input) != io_conflict.end()) {
+ continue;
+ }
+ for (const auto& output : output_names) {
+ if (input == output) {
+ io_conflict.insert(input);
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < input_names.size(); i++) {
+ if (io_conflict.find(input_names[i]) != io_conflict.end()) {
+ input_names[i] = "input_" + input_names[i];
+ }
+ }
+ for (int i = 0; i < output_names.size(); i++) {
+ if (io_conflict.find(output_names[i]) != io_conflict.end()) {
+ output_names[i] = "output_" + output_names[i];
+ }
+ }
+ // 2. Second, add index if input[i] == input[j]
+ ResolveConflictedNamesByAddingIndex(&input_names);
+ ResolveConflictedNamesByAddingIndex(&output_names);
+}
+
+} // namespace codegen
+} // namespace support
+} // namespace tflite