// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <stdbool.h>
#include <stddef.h>
#include <string.h>

// NOTE(review): SOURCE shows four #include directives whose header names are
// missing. The three standard headers above cover bool, size_t/NULL and memcpy,
// which is everything this file uses. If a project header declares
// xnn_normalize_transpose_permutation, re-add it here - TODO confirm.

// Returns true when dimension `dim` can be dropped without breaking the stride
// layout.
//
// A stride array that is NULL is not checked. Otherwise the stride of the
// preceding dimension must equal the stride of this dimension multiplied by
// shape[dim]. The input check indexes input_stride with `dim`; the output check
// indexes output_stride with perm[dim]. Dimension 0 mapped to position 0 is
// always removable.
static bool can_dimension_be_removed(
    const size_t* input_stride,
    const size_t* output_stride,
    const size_t* shape,
    const size_t* perm,
    size_t dim)
{
  if (dim == 0 && perm[dim] == 0) {
    return true;
  }
  if (input_stride != NULL && dim > 0) {
    if (input_stride[dim - 1] != input_stride[dim] * shape[dim]) {
      return false;
    }
  }
  if (output_stride != NULL && perm[dim] > 0) {
    if (output_stride[perm[dim] - 1] != output_stride[perm[dim]] * shape[dim]) {
      return false;
    }
  }
  return true;
}

// Removes the dimension perm[dim] from shape, perm, and the input & output
// strides (stride arrays may be NULL, in which case they are left alone).
//
// `dim` is a position in perm; `num_dims` is the current number of dimensions.
// All arrays are compacted in place; the caller is responsible for decrementing
// its own dimension count afterwards.
static void remove_dimension(
    size_t* shape,
    size_t* perm,
    size_t* input_stride,
    size_t* output_stride,
    size_t num_dims,
    size_t dim)
{
  const size_t removed = perm[dim];

  for (size_t j = removed; j + 1 < num_dims; ++j) {
    shape[j] = shape[j + 1];
  }
  if (input_stride != NULL) {
    // Drop the stride of the preceding entry (clamped at index 0) so that the
    // surviving entry keeps the inner stride.
    for (size_t j = (removed > 0 ? removed - 1 : 0); j + 1 < num_dims; ++j) {
      input_stride[j] = input_stride[j + 1];
    }
  }
  if (output_stride != NULL) {
    for (size_t j = (dim > 0 ? dim - 1 : 0); j + 1 < num_dims; ++j) {
      output_stride[j] = output_stride[j + 1];
    }
  }
  // Renumber the dimensions that came after the removed one.
  for (size_t j = 0; j < num_dims; ++j) {
    if (perm[j] > removed) {
      perm[j] -= 1;
    }
  }
  for (size_t j = dim; j + 1 < num_dims; ++j) {
    perm[j] = perm[j + 1];
  }
}

// Simplifies a transpose description by removing dimensions of size 1, folding
// dimensions that are adjacent in both input and output, and absorbing a
// trailing dimension that is not moved into the element size.
//
// Inputs:
//   num_dims       - number of entries in perm, shape and (if given) strides.
//   element_size   - size of one element.
//   perm           - perm[i] is the input dimension placed at output position i.
//   shape          - input shape, indexed by input dimension.
//   input_stride   - optional (may be NULL) input strides, in elements.
//   output_stride  - optional (may be NULL) output strides, in elements.
//
// Outputs (each array must have room for max(num_dims, 1) entries):
//   normalized_num_dims         - number of dimensions after normalization.
//   normalized_element_size_out - element_size, multiplied by the size of an
//                                 absorbed trailing dimension if there is one.
//   normalized_perm             - normalized permutation.
//   normalized_shape            - normalized input shape.
//   normalized_input_stride     - provided strides scaled by element_size, or
//                                 dense strides derived from the normalized
//                                 shape and element size when input_stride is
//                                 NULL.
//   normalized_output_stride    - same, for the output.
void xnn_normalize_transpose_permutation(
    const size_t num_dims,
    const size_t element_size,
    const size_t* perm,
    const size_t* shape,
    const size_t* input_stride,
    const size_t* output_stride,
    size_t* normalized_num_dims,
    size_t* normalized_element_size_out,
    size_t* normalized_perm,
    size_t* normalized_shape,
    size_t* normalized_input_stride,
    size_t* normalized_output_stride)
{
  size_t output_dims = num_dims;
  memcpy(normalized_perm, perm, num_dims * sizeof(size_t));
  memcpy(normalized_shape, shape, num_dims * sizeof(size_t));

  // The *_ptr variables are non-NULL only when the caller supplied strides;
  // they tell the helpers (and the code below) whether the stride buffers hold
  // meaningful data yet.
  size_t* normalized_input_stride_ptr = NULL;
  size_t* normalized_output_stride_ptr = NULL;
  if (input_stride != NULL) {
    memcpy(normalized_input_stride, input_stride, num_dims * sizeof(size_t));
    normalized_input_stride_ptr = normalized_input_stride;
  }
  if (output_stride != NULL) {
    memcpy(normalized_output_stride, output_stride, num_dims * sizeof(size_t));
    normalized_output_stride_ptr = normalized_output_stride;
  }

  // Remove dimensions of size 1 and fold dimensions which are adjacent in both
  // input and output tensors.
  size_t output_pos = 0;
  while (output_pos < output_dims) {
    const size_t input_dim = normalized_perm[output_pos];
    const bool removable = can_dimension_be_removed(
        normalized_input_stride_ptr, normalized_output_stride_ptr,
        normalized_shape, normalized_perm, input_dim);
    const bool is_unit = normalized_shape[input_dim] == 1;
    const bool follows_previous =
        output_pos > 0 && input_dim == normalized_perm[output_pos - 1] + 1;

    if (removable && (is_unit || follows_previous)) {
      if (output_pos > 0) {
        normalized_shape[normalized_perm[output_pos - 1]] *= normalized_shape[input_dim];
      }
      remove_dimension(normalized_shape, normalized_perm,
                       normalized_input_stride_ptr, normalized_output_stride_ptr,
                       output_dims, output_pos);
      output_dims -= 1;
      // When a dimension has been removed, new folds may be possible so check
      // it again.
      if (output_pos > 0) {
        output_pos -= 1;
      }
    } else {
      output_pos += 1;
    }
  }

  // All dimensions are size 1.
  if (output_pos == 0) {
    *normalized_num_dims = 1;
    *normalized_element_size_out = element_size;
    normalized_perm[0] = 0;
    normalized_shape[0] = 1;
    normalized_input_stride[0] = element_size;
    normalized_output_stride[0] = element_size;
    return;
  }

  // If the last input and output dimensions are the same, treat it as one large
  // element.
  size_t normalized_element_size = element_size;
  const size_t last = output_dims - 1;
  if (normalized_perm[last] == last) {
    normalized_element_size = element_size * normalized_shape[last];
    if (output_dims > 1 &&
        can_dimension_be_removed(normalized_input_stride_ptr, normalized_output_stride_ptr,
                                 normalized_shape, normalized_perm, last)) {
      output_dims -= 1;
    } else {
      // Keep the dimension with extent 1. Only caller-provided strides are
      // rescaled: when strides were not supplied the buffers are still
      // uninitialized here and are filled in from scratch below.
      if (normalized_input_stride_ptr != NULL) {
        normalized_input_stride_ptr[last] *= normalized_shape[last];
      }
      if (normalized_output_stride_ptr != NULL) {
        normalized_output_stride_ptr[normalized_perm[last]] *= normalized_shape[last];
      }
      normalized_shape[last] = 1;
    }
  }

  // If input_stride is not provided, calculate it using normalized_shape and
  // normalized_element_size.
  if (input_stride == NULL) {
    normalized_input_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_input_stride[i - 1] = normalized_input_stride[i] * normalized_shape[i];
    }
  } else {
    // Scale input_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_input_stride[i] *= element_size;
    }
  }

  // If output_stride is not provided, calculate it using normalized_shape and
  // normalized_element_size.
  if (output_stride == NULL) {
    normalized_output_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_output_stride[i - 1] =
          normalized_output_stride[i] * normalized_shape[normalized_perm[i]];
    }
  } else {
    // Scale output_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_output_stride[i] *= element_size;
    }
  }

  *normalized_element_size_out = normalized_element_size;
  *normalized_num_dims = output_dims;
}