aboutsummaryrefslogtreecommitdiff
path: root/source/reduce/remove_struct_member_reduction_opportunity.cpp
blob: da096e1ebe8e8bed232874796beed139bbb15b4b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
// Copyright (c) 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/reduce/remove_struct_member_reduction_opportunity.h"

#include "source/opt/ir_context.h"

namespace spvtools {
namespace reduce {

// The opportunity is applicable only while the struct type's current number of
// input operands (one per member) equals the recorded
// |original_number_of_members_|.
bool RemoveStructMemberReductionOpportunity::PreconditionHolds() {
  const auto current_member_count = struct_type_->NumInOperands();
  return current_member_count == original_number_of_members_;
}

void RemoveStructMemberReductionOpportunity::Apply() {
  std::set<opt::Instruction*> decorations_to_kill;

  // We need to remove decorations that target the removed struct member, and
  // adapt decorations that target later struct members by decrementing the
  // member identifier.  We also need to adapt composite construction
  // instructions so that no id is provided for the member being removed.
  //
  // To do this, we consider every use of the struct type.
  struct_type_->context()->get_def_use_mgr()->ForEachUse(
      struct_type_, [this, &decorations_to_kill](opt::Instruction* user,
                                                 uint32_t /*operand_index*/) {
        switch (user->opcode()) {
          case SpvOpCompositeConstruct:
          case SpvOpConstantComposite:
            // This use is constructing a composite of the struct type, so we
            // must remove the id that was provided for the member we are
            // removing.
            user->RemoveInOperand(member_index_);
            break;
          case SpvOpMemberDecorate:
            // This use is decorating a member of the struct.
            if (user->GetSingleWordInOperand(1) == member_index_) {
              // The member we are removing is being decorated, so we record
              // that we need to get rid of the decoration.
              decorations_to_kill.insert(user);
            } else if (user->GetSingleWordInOperand(1) > member_index_) {
              // A member beyond the one we are removing is being decorated, so
              // we adjust the index that identifies the member.
              user->SetInOperand(1, {user->GetSingleWordInOperand(1) - 1});
            }
            break;
          default:
            break;
        }
      });

  // Get rid of all the decorations that were found to target the member being
  // removed.
  for (auto decoration_to_kill : decorations_to_kill) {
    decoration_to_kill->context()->KillInst(decoration_to_kill);
  }

  // We now look through all instructions that access composites via sequences
  // of indices. Every time we find an index into the struct whose member is
  // being removed, and if the member being accessed comes after the member
  // being removed, we need to adjust the index accordingly.
  //
  // We go through every relevant instruction in every block of every function,
  // and invoke a helper to adjust it.
  auto context = struct_type_->context();
  for (auto& function : *context->module()) {
    for (auto& block : function) {
      for (auto& inst : block) {
        switch (inst.opcode()) {
          case SpvOpAccessChain:
          case SpvOpInBoundsAccessChain: {
            // These access chain instructions take sequences of ids for
            // indexing, starting from input operand 1.
            auto composite_type_id =
                context->get_def_use_mgr()
                    ->GetDef(context->get_def_use_mgr()
                                 ->GetDef(inst.GetSingleWordInOperand(0))
                                 ->type_id())
                    ->GetSingleWordInOperand(1);
            AdjustAccessedIndices(composite_type_id, 1, false, context, &inst);
          } break;
          case SpvOpPtrAccessChain:
          case SpvOpInBoundsPtrAccessChain: {
            // These access chain instructions take sequences of ids for
            // indexing, starting from input operand 2.
            auto composite_type_id =
                context->get_def_use_mgr()
                    ->GetDef(context->get_def_use_mgr()
                                 ->GetDef(inst.GetSingleWordInOperand(1))
                                 ->type_id())
                    ->GetSingleWordInOperand(1);
            AdjustAccessedIndices(composite_type_id, 2, false, context, &inst);
          } break;
          case SpvOpCompositeExtract: {
            // OpCompositeExtract uses literals for indexing, starting at input
            // operand 1.
            auto composite_type_id =
                context->get_def_use_mgr()
                    ->GetDef(inst.GetSingleWordInOperand(0))
                    ->type_id();
            AdjustAccessedIndices(composite_type_id, 1, true, context, &inst);
          } break;
          case SpvOpCompositeInsert: {
            // OpCompositeInsert uses literals for indexing, starting at input
            // operand 2.
            auto composite_type_id =
                context->get_def_use_mgr()
                    ->GetDef(inst.GetSingleWordInOperand(1))
                    ->type_id();
            AdjustAccessedIndices(composite_type_id, 2, true, context, &inst);
          } break;
          default:
            break;
        }
      }
    }
  }

  // Remove the member from the struct type.
  struct_type_->RemoveInOperand(member_index_);

  context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
}

// Walks the indices of |composite_access_instruction|, starting at input
// operand |first_index_input_operand| and with |composite_type_id| as the type
// being indexed into, and decrements every index that selects a member of
// |struct_type_| located after |member_index_|.
//
// |literal_indices| is true if the instruction's indices are literals, and
// false if they are ids of constant instructions.  In the latter case a
// decremented index is expressed by switching the operand to the id of a
// constant with the decremented value, obtained from |context|'s constant
// manager.
void RemoveStructMemberReductionOpportunity::AdjustAccessedIndices(
    uint32_t composite_type_id, uint32_t first_index_input_operand,
    bool literal_indices, opt::IRContext* context,
    opt::Instruction* composite_access_instruction) const {
  // Walk the series of types that are encountered by following the
  // instruction's sequence of indices. For all types except structs, this is
  // routine: the type of the composite dictates what the next type will be
  // regardless of the specific index value.
  uint32_t next_type = composite_type_id;
  for (uint32_t i = first_index_input_operand;
       i < composite_access_instruction->NumInOperands(); i++) {
    auto type_inst = context->get_def_use_mgr()->GetDef(next_type);
    switch (type_inst->opcode()) {
      case SpvOpTypeArray:
      case SpvOpTypeMatrix:
      case SpvOpTypeRuntimeArray:
      case SpvOpTypeVector:
        // The element/column/component type is input operand 0.
        next_type = type_inst->GetSingleWordInOperand(0);
        break;
      case SpvOpTypeStruct: {
        // Struct types are special because (a) we may need to adjust the index
        // being used, if the struct type is the one from which we are removing
        // a member, and (b) the type encountered by following the current index
        // is dependent on the value of the index.

        // Work out the member being accessed.  If literal indexing is used this
        // is simple; otherwise we need to look up the id of the constant
        // instruction being used as an index and get the value of the constant.
        uint32_t index_operand =
            composite_access_instruction->GetSingleWordInOperand(i);
        uint32_t member = literal_indices ? index_operand
                                          : context->get_def_use_mgr()
                                                ->GetDef(index_operand)
                                                ->GetSingleWordInOperand(0);

        // The next type we will consider is obtained by looking up the struct
        // type at |member|.
        next_type = type_inst->GetSingleWordInOperand(member);

        if (type_inst == struct_type_ && member > member_index_) {
          // The struct type is the struct from which we are removing a member,
          // and the member being accessed is beyond the member we are removing.
          // We thus need to decrement the index by 1.
          uint32_t new_in_operand;
          if (literal_indices) {
            // With literal indexing this is straightforward.
            new_in_operand = member - 1;
          } else {
            // With id-based indexing this is more tricky: we need to find or
            // create a constant instruction whose value is one less than
            // |member|, and use the id of this constant as the replacement
            // input operand.  The new constant is given the same integer type
            // as the constant currently used as the index.
            auto constant_inst =
                context->get_def_use_mgr()->GetDef(index_operand);
            auto int_type = context->get_type_mgr()
                                ->GetType(constant_inst->type_id())
                                ->AsInteger();
            // Obtain the constant from the constant manager rather than
            // handing it the address of a local object: the pointer passed to
            // GetDefiningInstruction must not refer to an object that goes out
            // of scope at the end of this block.
            auto new_index_constant =
                context->get_constant_mgr()->GetConstant(int_type,
                                                         {member - 1});
            new_in_operand = context->get_constant_mgr()
                                 ->GetDefiningInstruction(new_index_constant)
                                 ->result_id();
          }
          composite_access_instruction->SetInOperand(i, {new_in_operand});
        }
      } break;
      default:
        assert(0 && "Unknown composite type.");
        break;
    }
  }
}

}  // namespace reduce
}  // namespace spvtools