aboutsummaryrefslogtreecommitdiff
path: root/source/opt/interp_fixup_pass.cpp
blob: ad29e6a7207901ba1429e56949c643887739b524 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
// Copyright (c) 2021 The Khronos Group Inc.
// Copyright (c) 2021 Valve Corporation
// Copyright (c) 2021 LunarG Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "source/opt/interp_fixup_pass.h"

#include <set>
#include <string>

#include "ir_builder.h"
#include "source/opt/ir_context.h"
#include "type_manager.h"

namespace spvtools {
namespace opt {

namespace {

// Input Operand Indices
static const int kSpvVariableStorageClassInIdx = 0;

// Avoid unused variable warning/error on Linux
#ifndef NDEBUG
#define USE_ASSERT(x) assert(x)
#else
#define USE_ASSERT(x) ((void)(x))
#endif

// Folding rule function which attempts to replace |op(OpLoad(a),...)|
// by |op(a,...)|, where |op| is one of the GLSLstd450 InterpolateAt*
// instructions. Returns true if replaced, false otherwise.
bool ReplaceInternalInterpolate(IRContext* ctx, Instruction* inst,
                                const std::vector<const analysis::Constant*>&) {
  uint32_t glsl450_ext_inst_id =
      ctx->get_feature_mgr()->GetExtInstImportId_GLSLstd450();
  assert(glsl450_ext_inst_id != 0);

  uint32_t ext_opcode = inst->GetSingleWordInOperand(1);

  uint32_t op1_id = inst->GetSingleWordInOperand(2);

  Instruction* load_inst = ctx->get_def_use_mgr()->GetDef(op1_id);
  if (load_inst->opcode() != SpvOpLoad) return false;

  Instruction* base_inst = load_inst->GetBaseAddress();
  USE_ASSERT(base_inst->opcode() == SpvOpVariable &&
             base_inst->GetSingleWordInOperand(kSpvVariableStorageClassInIdx) ==
                 SpvStorageClassInput &&
             "unexpected interpolant in InterpolateAt*");

  uint32_t ptr_id = load_inst->GetSingleWordInOperand(0);
  uint32_t op2_id = (ext_opcode != GLSLstd450InterpolateAtCentroid)
                        ? inst->GetSingleWordInOperand(3)
                        : 0;

  Instruction::OperandList new_operands;
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {glsl450_ext_inst_id}});
  new_operands.push_back(
      {SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER, {ext_opcode}});
  new_operands.push_back({SPV_OPERAND_TYPE_ID, {ptr_id}});
  if (op2_id != 0) new_operands.push_back({SPV_OPERAND_TYPE_ID, {op2_id}});

  inst->SetInOperands(std::move(new_operands));
  ctx->UpdateDefUse(inst);
  return true;
}

class InterpFoldingRules : public FoldingRules {
 public:
  explicit InterpFoldingRules(IRContext* ctx) : FoldingRules(ctx) {}

 protected:
  virtual void AddFoldingRules() override {
    uint32_t extension_id =
        context()->get_feature_mgr()->GetExtInstImportId_GLSLstd450();

    if (extension_id != 0) {
      ext_rules_[{extension_id, GLSLstd450InterpolateAtCentroid}].push_back(
          ReplaceInternalInterpolate);
      ext_rules_[{extension_id, GLSLstd450InterpolateAtSample}].push_back(
          ReplaceInternalInterpolate);
      ext_rules_[{extension_id, GLSLstd450InterpolateAtOffset}].push_back(
          ReplaceInternalInterpolate);
    }
  }
};

class InterpConstFoldingRules : public ConstantFoldingRules {
 public:
  InterpConstFoldingRules(IRContext* ctx) : ConstantFoldingRules(ctx) {}

 protected:
  virtual void AddFoldingRules() override {}
};

}  // namespace

Pass::Status InterpFixupPass::Process() {
  bool changed = false;

  // Traverse the body of the functions to replace instructions that require
  // the extensions.
  InstructionFolder folder(
      context(),
      std::unique_ptr<InterpFoldingRules>(new InterpFoldingRules(context())),
      MakeUnique<InterpConstFoldingRules>(context()));
  for (Function& func : *get_module()) {
    func.ForEachInst([&changed, &folder](Instruction* inst) {
      if (folder.FoldInstruction(inst)) {
        changed = true;
      }
    });
  }

  return changed ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}

}  // namespace opt
}  // namespace spvtools