aboutsummaryrefslogtreecommitdiff
path: root/source/opt/scalar_analysis_nodes.h
blob: 91ce446f3d1c3f582d39d5c2c62553559166fcb3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_
#define SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
#include <ostream>
#include <string>
#include <vector>

#include "source/opt/tree_iterator.h"

namespace spvtools {
namespace opt {

// Forward declarations. The loop and analysis types are only referenced
// through pointers in this header. The concrete SENode subclasses must be
// declared up front because SENode declares an As<Subclass>() casting method
// for each of them.
class Loop;
class ScalarEvolutionAnalysis;

class SEAddNode;
class SECantCompute;
class SEConstantNode;
class SEMultiplyNode;
class SENegative;
class SERecurrentNode;
class SEValueUnknown;

// Abstract class representing a node in the scalar evolution DAG. Each node
// contains a vector of pointers to its children and each subclass of SENode
// implements GetType and an As method to allow casting. SENodes can be hashed
// using the SENodeHash functor. The vector of children is kept sorted as
// children are added. This is important as it allows the hash of X+Y to be the
// same as Y+X.
//
// A node never deletes its children: |children_| holds non-owning pointers.
// The nodes themselves are presumably owned by the parent analysis' cache (see
// GetParentAnalysis). TODO confirm against ScalarEvolutionAnalysis.
class SENode {
 public:
  enum SENodeType {
    Constant,
    RecurrentAddExpr,
    Add,
    Multiply,
    Negative,
    ValueUnknown,
    CanNotCompute
  };

  using ChildContainerType = std::vector<SENode*>;

  // Creates a node belonging to |parent_analysis| and gives it the next id
  // from the static NumberOfNodes counter. AddChild uses that id to order
  // children deterministically.
  // NOTE(review): the counter is a plain, non-atomic static, so concurrent
  // node creation would race.
  explicit SENode(ScalarEvolutionAnalysis* parent_analysis)
      : parent_analysis_(parent_analysis), unique_id_(++NumberOfNodes) {}

  // Returns the concrete kind of this node.
  virtual SENodeType GetType() const = 0;

  virtual ~SENode() = default;

  // Adds |child| to this node's children. Children are kept ordered by
  // descending unique id, so the hash and equality operator give the same
  // result for a node regardless of the order in which its children were
  // added (X+Y is the same as Y+X). A child whose id equals an existing
  // child's id is placed after that child.
  virtual inline void AddChild(SENode* child) {
    assert(!AsSEConstantNode() && "Trying to add a child node to a constant!");

    // Insert before the first existing child whose id is strictly smaller.
    // This is a linear scan rather than a binary search because GetChildren()
    // exposes a mutable reference, so sortedness cannot be assumed as a
    // precondition.
    auto position = std::find_if(children_.begin(), children_.end(),
                                 [child](const SENode* node) {
                                   return node->unique_id_ < child->unique_id_;
                                 });
    children_.insert(position, child);
  }

  // Get the type as an std::string. This is used to represent the node in the
  // dot output and is used to hash the type as well.
  std::string AsString() const;

  // Dump the SENode and its immediate children, if |recurse| is true then it
  // will recurse through all children to print the DAG starting from this node
  // as a root.
  void DumpDot(std::ostream& out, bool recurse = false) const;

  // Checks if two nodes are the same by hashing them.
  bool operator==(const SENode& other) const;

  // Checks if two nodes are not the same by comparing the hashes.
  bool operator!=(const SENode& other) const;

  // Return the child node at |index|. |index| is not bounds-checked.
  inline SENode* GetChild(size_t index) { return children_[index]; }
  inline const SENode* GetChild(size_t index) const { return children_[index]; }

  // Iterator to iterate over the child nodes.
  using iterator = ChildContainerType::iterator;
  using const_iterator = ChildContainerType::const_iterator;

  // Iterate over immediate child nodes.
  iterator begin() { return children_.begin(); }
  iterator end() { return children_.end(); }

  // Constant overloads for iterating over immediate child nodes. cbegin/cend
  // are const so they can also be used on const nodes.
  const_iterator begin() const { return children_.cbegin(); }
  const_iterator end() const { return children_.cend(); }
  const_iterator cbegin() const { return children_.cbegin(); }
  const_iterator cend() const { return children_.cend(); }

  // Collect all the recurrent nodes in the DAG rooted at this SENode, including
  // this node itself, in depth-first pre-order. A node reachable through
  // several parents is reported once per path.
  std::vector<SERecurrentNode*> CollectRecurrentNodes() {
    return CollectNodesMatching<SERecurrentNode>(
        [](SENode* node) { return node->AsSERecurrentNode(); });
  }

  // Collect all the value unknown nodes in the DAG rooted at this SENode,
  // including this node itself, in depth-first pre-order. A node reachable
  // through several parents is reported once per path.
  std::vector<SEValueUnknown*> CollectValueUnknownNodes() {
    return CollectNodesMatching<SEValueUnknown>(
        [](SENode* node) { return node->AsSEValueUnknown(); });
  }

  // Iterator to iterate over the entire DAG. Even though we are using the tree
  // iterator it should still be safe to iterate over. However, nodes with
  // multiple parents will be visited multiple times, unlike in a tree.
  using dag_iterator = TreeDFIterator<SENode>;
  using const_dag_iterator = TreeDFIterator<const SENode>;

  // Iterate over all child nodes in the graph.
  dag_iterator graph_begin() { return dag_iterator(this); }
  dag_iterator graph_end() { return dag_iterator(); }
  const_dag_iterator graph_begin() const { return graph_cbegin(); }
  const_dag_iterator graph_end() const { return graph_cend(); }
  const_dag_iterator graph_cbegin() const { return const_dag_iterator(this); }
  const_dag_iterator graph_cend() const { return const_dag_iterator(); }

  // Return the vector of immediate children.
  const ChildContainerType& GetChildren() const { return children_; }
  ChildContainerType& GetChildren() { return children_; }

  // Return true if this node is a can't compute node.
  bool IsCantCompute() const { return GetType() == CanNotCompute; }

// Implements a casting method for each type. The base versions return nullptr;
// each subclass overrides its own pair to return |this|.
// clang-format off
#define DeclareCastMethod(target)                  \
  virtual target* As##target() { return nullptr; } \
  virtual const target* As##target() const { return nullptr; }
  DeclareCastMethod(SEConstantNode)
  DeclareCastMethod(SERecurrentNode)
  DeclareCastMethod(SEAddNode)
  DeclareCastMethod(SEMultiplyNode)
  DeclareCastMethod(SENegative)
  DeclareCastMethod(SEValueUnknown)
  DeclareCastMethod(SECantCompute)
#undef DeclareCastMethod

  // Get the analysis which has this node in its cache.
  inline ScalarEvolutionAnalysis* GetParentAnalysis() const {
    return parent_analysis_;
  }

 protected:
  ChildContainerType children_;

  ScalarEvolutionAnalysis* parent_analysis_;

  // The unique id of this node, assigned on creation by incrementing the static
  // node count.
  uint32_t unique_id_;

  // The number of nodes created.
  static uint32_t NumberOfNodes;

 private:
  // Walks the DAG rooted at this node and returns every non-null result of
  // |cast|. The walk visits this node first, then each child's sub-DAG in
  // children_ order. This is the same pre-order a recursive walk produces.
  // It uses an explicit work list instead, so no temporary vector is built
  // and copied at each level and deep expressions do not use call stack.
  template <typename NodeTy, typename CastFn>
  std::vector<NodeTy*> CollectNodesMatching(CastFn cast) {
    std::vector<NodeTy*> matches;
    std::vector<SENode*> work_list{this};
    while (!work_list.empty()) {
      SENode* node = work_list.back();
      work_list.pop_back();
      if (NodeTy* match = cast(node)) {
        matches.push_back(match);
      }
      // Push children in reverse so they are popped in their stored order.
      work_list.insert(work_list.end(), node->children_.rbegin(),
                       node->children_.rend());
    }
    return matches;
  }
};
// clang-format on

// Hash functor for SENodes. It accepts either a raw pointer or an owning
// std::unique_ptr. The hash combines the node's type (as a string), the
// literal value of any constant, and the child pointers, which are assumed to
// be unique.
struct SENodeHash {
  size_t operator()(const SENode* node) const;
  size_t operator()(const std::unique_ptr<SENode>& node) const;
};

// A node representing a constant integer.
class SEConstantNode : public SENode {
 public:
  SEConstantNode(ScalarEvolutionAnalysis* parent_analysis, int64_t value)
      : SENode(parent_analysis), literal_value_(value) {}

  SENodeType GetType() const final { return Constant; }

  int64_t FoldToSingleValue() const { return literal_value_; }

  SEConstantNode* AsSEConstantNode() override { return this; }
  const SEConstantNode* AsSEConstantNode() const override { return this; }

  inline void AddChild(SENode*) final {
    assert(false && "Attempting to add a child to a constant node!");
  }

 protected:
  int64_t literal_value_;
};

// A node representing a recurrent expression in the code. A recurrent
// expression is an expression whose value can be expressed as a linear
// expression of the loop iterations, such as an induction variable. The actual
// value of a recurrent expression is coefficient_ * iteration + offset_, hence
// an induction variable i=0, i++ becomes a recurrent expression with an offset
// of zero and a coefficient of one.
class SERecurrentNode : public SENode {
 public:
  SERecurrentNode(ScalarEvolutionAnalysis* parent_analysis, const Loop* loop)
      : SENode(parent_analysis), loop_(loop) {}

  SENodeType GetType() const final { return RecurrentAddExpr; }

  // Records |child| as the coefficient and also adds it to the sorted child
  // list. NOTE(review): calling this twice replaces coefficient_ but leaves the
  // earlier node in the child list.
  inline void AddCoefficient(SENode* child) {
    coefficient_ = child;
    SENode::AddChild(child);
  }

  // Records |child| as the offset and also adds it to the sorted child list.
  // NOTE(review): calling this twice replaces offset_ but leaves the earlier
  // node in the child list.
  inline void AddOffset(SENode* child) {
    offset_ = child;
    SENode::AddChild(child);
  }

  // Return the coefficient, or nullptr if AddCoefficient has not been called.
  inline const SENode* GetCoefficient() const { return coefficient_; }
  inline SENode* GetCoefficient() { return coefficient_; }

  // Return the offset, or nullptr if AddOffset has not been called.
  inline const SENode* GetOffset() const { return offset_; }
  inline SENode* GetOffset() { return offset_; }

  // Return the loop which this recurrent expression is recurring within.
  const Loop* GetLoop() const { return loop_; }

  SERecurrentNode* AsSERecurrentNode() override { return this; }
  const SERecurrentNode* AsSERecurrentNode() const override { return this; }

 private:
  // Both start as nullptr so the getters return a well-defined value before
  // the corresponding Add* call.
  SENode* coefficient_ = nullptr;
  SENode* offset_ = nullptr;
  const Loop* loop_;
};

// A node representing an addition operation between child nodes.
class SEAddNode : public SENode {
 public:
  explicit SEAddNode(ScalarEvolutionAnalysis* parent_analysis)
      : SENode(parent_analysis) {}

  SENodeType GetType() const final { return Add; }

  SEAddNode* AsSEAddNode() override { return this; }
  const SEAddNode* AsSEAddNode() const override { return this; }
};

// A node representing a multiply operation between child nodes.
class SEMultiplyNode : public SENode {
 public:
  explicit SEMultiplyNode(ScalarEvolutionAnalysis* parent_analysis)
      : SENode(parent_analysis) {}

  SENodeType GetType() const final { return Multiply; }

  SEMultiplyNode* AsSEMultiplyNode() override { return this; }
  const SEMultiplyNode* AsSEMultiplyNode() const override { return this; }
};

// A node representing a unary negative operation.
class SENegative : public SENode {
 public:
  explicit SENegative(ScalarEvolutionAnalysis* parent_analysis)
      : SENode(parent_analysis) {}

  SENodeType GetType() const final { return Negative; }

  SENegative* AsSENegative() override { return this; }
  const SENegative* AsSENegative() const override { return this; }
};

// A node representing a value we cannot reason about further, such as the
// result of a load instruction.
class SEValueUnknown : public SENode {
 public:
  // SEValueUnknowns must come from an instruction; |result_id| is the result id
  // of that instruction. This is so we can compare value unknowns and have a
  // unique value unknown for each instruction.
  SEValueUnknown(ScalarEvolutionAnalysis* parent_analysis, uint32_t result_id)
      : SENode(parent_analysis), result_id_(result_id) {}

  SENodeType GetType() const final { return ValueUnknown; }

  SEValueUnknown* AsSEValueUnknown() override { return this; }
  const SEValueUnknown* AsSEValueUnknown() const override { return this; }

  // Returns the result id of the instruction this node was created for.
  inline uint32_t ResultId() const { return result_id_; }

 private:
  uint32_t result_id_;
};

// A node which we cannot reason about at all.
class SECantCompute : public SENode {
 public:
  explicit SECantCompute(ScalarEvolutionAnalysis* parent_analysis)
      : SENode(parent_analysis) {}

  SENodeType GetType() const final { return CanNotCompute; }

  SECantCompute* AsSECantCompute() override { return this; }
  const SECantCompute* AsSECantCompute() const override { return this; }
};

}  // namespace opt
}  // namespace spvtools
#endif  // SOURCE_OPT_SCALAR_ANALYSIS_NODES_H_