# Copyright 2016 The Gemmlowp Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared generators for quantized multiplication kernels emitted as NEON asm."""

import common


def _ReadParams(emitter, registers, input_address, elements, min_register):
  """Load parameter values from memory into quad registers."""
  registers_count = (elements + 3) // 4
  param_registers = [
      registers.QuadRegister(min_register)
      for unused_i in range(registers_count)
  ]
  emitter.EmitVLoadAE(registers_count * 4, 32, param_registers, input_address,
                      64)
  return param_registers


def _Duplicate(emitter, registers, rows, values):
  """Populate a grid of registers duplicating provided values."""
  duplicated = []
  for i in range(rows):
    if i == rows - 1:
      # For the last row, reuse the source register as the destination.
      duplicated.append(values[0])
    else:
      duplicated.append(registers.QuadRegister())

    emitter.EmitVDup('32', duplicated[i],
                     emitter.Lane(32, values[i // 4], i % 4))

  return duplicated


def _DuplicateGeneralRegister(emitter, registers, value, min_register):
  register = registers.QuadRegister(min_register)
  emitter.EmitVDup('32', register, value)
  return register


class _StaticQuantizationUInt8Transformation(object):
  """Calculate quantized values and cast back to uint8."""

  def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
    """Load parameters and prepare duplicated registers."""
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantization::Prepare')

    lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
    self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
    self.multiplicative_offset = _DuplicateGeneralRegister(
        emitter, registers,
        registers.MapParameter('multiplicative_offset',
                               'params.kernel.multiplicative_offset'), 4)
    self.rounding_offset = _DuplicateGeneralRegister(
        emitter, registers,
        registers.MapParameter('rounding_offset',
                               'params.kernel.rounding_offset'), 4)
    self.shift = _DuplicateGeneralRegister(
        emitter, registers,
        registers.MapParameter('shift', 'params.kernel.shift'), 4)
    self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)

  def Transform(self, emitter, registers, data, unused_kernel_m,
                unused_kernel_n):
    """Quantize the data."""
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantization::Transform')

    for (row, lhs_offset) in zip(data, self.lhs_offsets):
      for row_register in row:
        emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)

    for row in data:
      for (row_register, rhs_offset_register) in zip(row, self.rhs_offsets):
        emitter.EmitVAdd('s32', row_register, row_register,
                         rhs_offset_register)

    for row in data:
      for row_register in row:
        emitter.EmitVMul('i32', row_register, row_register,
                         self.multiplicative_offset)

    for row in data:
      for row_register in row:
        emitter.EmitVAdd('i32', row_register, row_register,
                         self.rounding_offset)

    for row in data:
      for row_register in row:
        emitter.EmitVShl('s32', row_register, row_register, self.shift)

    if len(data[0]) == 1:
      for row in data:
        emitter.EmitVQmovn('s32', row[0], row[0])

      for row in data:
        emitter.EmitVQmovun('s16', row[0], row[0])

      return data
    elif len(data[0]) == 2:
      results = []
      for row in data:
        emitter.EmitVQmovn2('s32', row[0], row[0], row[1])
        registers.FreeRegister(row[1])
        results.append([row[0]])

      for row in results:
        emitter.EmitVQmovun('s16', row[0], row[0])

      return results
    else:
      assert False

  def Type(self):
    return 8


class _StaticQuantizationInt32Transformation(object):
  """Calculate quantized values and output them as int32."""

  def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantizationInt32::Prepare')

    lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
    self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
    self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)

  def Transform(self, emitter, unused_registers, data, unused_kernel_m,
                unused_kernel_n):
    """Quantize data and output as int32."""
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantizationInt32::Transform')

    for (row, lhs_offset) in zip(data, self.lhs_offsets):
      for row_register in row:
        emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)

    for row in data:
      for (row_register, rhs_offsets_register) in zip(row, self.rhs_offsets):
        emitter.EmitVAdd('s32', row_register, row_register,
                         rhs_offsets_register)

    return data

  def Type(self):
    return 32


class _StaticQuantizationFloatTransformation(object):
  """Calculate quantized values and output them as float."""

  def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantizationFloat::Prepare')

    lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
    self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
    self.scale = _DuplicateGeneralRegister(
        emitter, registers,
        registers.MapParameter('scale', 'params.kernel.scale'), 4)
    self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)

  def Transform(self, emitter, unused_registers, data, unused_kernel_m,
                unused_kernel_n):
    """Quantize data and output as float."""
    emitter.EmitNewline()
    emitter.EmitComment('StaticQuantizationFloat::Transform')

    for (row, lhs_offset) in zip(data, self.lhs_offsets):
      for row_register in row:
        emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)

    for row in data:
      for (row_register, rhs_offsets_register) in zip(row, self.rhs_offsets):
        emitter.EmitVAdd('s32', row_register, row_register,
                         rhs_offsets_register)

    for row in data:
      for row_register in row:
        emitter.EmitVCvt('f32', 's32', row_register, row_register)

    for row in data:
      for row_register in row:
        emitter.EmitVMul('f32', row_register, row_register, self.scale)

    return data

  def Type(self):
    return 32


class _RowMajorOutput(object):
  """Output data in row major layout."""

  def Prepare(self, emitter, registers, kernel_m, unused_kernel_n,
              unused_data_type):
    """Prepare strided output addresses."""
    emitter.EmitNewline()
    emitter.EmitComment('RowMajorOutput::Prepare')

    stride = registers.MapParameter('stride', 'params.output_stream.stride')

    self.outputs = []
    self.outputs.append(registers.MapOutputParameter('result'))
    for unused_i in range(kernel_m - 1):
      register = registers.GeneralRegister()
      emitter.EmitAdd(register, self.outputs[-1], stride)
      self.outputs.append(register)

  def Output(self, emitter, unused_registers, data, data_type, unused_kernel_m,
             kernel_n):
    emitter.EmitNewline()
    emitter.EmitComment('RowMajorOutput::Output')

    for (datum, output) in zip(data, self.outputs):
      emitter.EmitVStoreAE(data_type, kernel_n, datum, output, None)


def _GenerateAndClearAggregators(emitter, registers, count):
  """Prepare aggregators and emit aggregator clear code."""
  emitter.EmitNewline()
  emitter.EmitComment('Clear aggregators.')
  aggregators = [registers.QuadRegister() for unused_i in range(count)]
  for i in range(count):
    if i < 3:
      emitter.EmitVMov('i32', aggregators[i], emitter.ImmediateConstant(0))
    else:
      # Copy the zero value from an already-cleared aggregator.
      emitter.EmitVMov('i32', aggregators[i], aggregators[i - 3])
  return aggregators


def _Generate3x3LoadMultiplyAggregate(emitter, registers, aggregators, lhs,
                                      rhs, count):
  """Emit inner loop for 3 rows x 3 cols multiplication."""
  emitter.EmitNewline()
  emitter.EmitComment('3x3 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()

  lhs_load = [registers.DoubleRegister() for unused_i in range(3)]
  rhs_load = [registers.DoubleRegister() for unused_i in range(3)]
  temp = [registers.QuadRegister() for unused_i in range(4)]

  emitter.EmitVLoadA(1, 8, rhs_load, emitter.DereferenceIncrement(rhs, 64))
  emitter.EmitVLoad(1, 8, lhs_load[0], emitter.DereferenceIncrement(lhs, 64))
  emitter.EmitVMull('u8', temp[0], lhs_load[0], rhs_load[0])
  emitter.EmitVLoad(1, 8, lhs_load[1], emitter.DereferenceIncrement(lhs, 64))
  emitter.EmitVMull('u8', temp[1], lhs_load[0], rhs_load[1])
  emitter.EmitVLoad(1, 8, lhs_load[2], emitter.DereferenceIncrement(lhs, 64))
  emitter.EmitVMull('u8', temp[2], lhs_load[0], rhs_load[2])
  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
  emitter.EmitVMull('u8', temp[3], lhs_load[1], rhs_load[0])
  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))

  emitter.EmitVPadal('u16', aggregators[0], temp[0])
  emitter.EmitVPadal('u16', aggregators[1], temp[1])
  emitter.EmitVPadal('u16', aggregators[2], temp[2])
  emitter.EmitVPadal('u16', aggregators[3], temp[3])

  emitter.EmitVMull('u8', temp[0], lhs_load[1], rhs_load[1])
  emitter.EmitVMull('u8', temp[1], lhs_load[1], rhs_load[2])

  registers.FreeRegisters([lhs_load[0], lhs_load[1]])
  temp.append(registers.QuadRegister())

  emitter.EmitVMull('u8', temp[2], lhs_load[2], rhs_load[0])
  emitter.EmitVMull('u8', temp[3], lhs_load[2], rhs_load[1])

  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  emitter.EmitVMull('u8', temp[4], lhs_load[2], rhs_load[2])

  emitter.EmitVPadal('u16', aggregators[4], temp[0])
  emitter.EmitVPadal('u16', aggregators[5], temp[1])
  emitter.EmitVPadal('u16', aggregators[6], temp[2])
  emitter.EmitVPadal('u16', aggregators[7], temp[3])
  emitter.EmitVPadal('u16', aggregators[8], temp[4])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBgtBack(1)

  registers.FreeRegisters(temp + [lhs_load[2]] + rhs_load)


def _Generate2x4LoadMultiplyAggregate(emitter, registers, aggregators, lhs,
                                      rhs, count):
  """Emit inner loop for 2 rows x 4 cols multiplication."""
  emitter.EmitNewline()
  emitter.EmitComment('2x4 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()

  lhs_load = [registers.DoubleRegister() for unused_i in range(2)]
  rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
  temp = [registers.QuadRegister() for unused_i in range(5)]

  emitter.EmitVLoadA(1, 8, rhs_load, emitter.DereferenceIncrement(rhs, 256))
  emitter.EmitVLoad(1, 8, lhs_load[0], emitter.DereferenceIncrement(lhs, 64))
  emitter.EmitVMull('u8', temp[0], lhs_load[0], rhs_load[0])
  emitter.EmitVLoad(1, 8, lhs_load[1], emitter.DereferenceIncrement(lhs, 64))
  emitter.EmitVMull('u8', temp[1], lhs_load[0], rhs_load[1])
  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))
  emitter.EmitVMull('u8', temp[2], lhs_load[0], rhs_load[2])
  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
  emitter.EmitVMull('u8', temp[3], lhs_load[0], rhs_load[3])
  emitter.EmitVMull('u8', temp[4], lhs_load[1], rhs_load[0])

  emitter.EmitVPadal('u16', aggregators[0], temp[0])
  emitter.EmitVPadal('u16', aggregators[1], temp[1])
  emitter.EmitVPadal('u16', aggregators[2], temp[2])
  emitter.EmitVMull('u8', temp[0], lhs_load[1], rhs_load[1])
  emitter.EmitVMull('u8', temp[1], lhs_load[1], rhs_load[2])
  emitter.EmitVMull('u8', temp[2], lhs_load[1], rhs_load[3])

  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  emitter.EmitVPadal('u16', aggregators[3], temp[3])
  emitter.EmitVPadal('u16', aggregators[4], temp[4])
  emitter.EmitVPadal('u16', aggregators[5], temp[0])
  emitter.EmitVPadal('u16', aggregators[6], temp[1])
  emitter.EmitVPadal('u16', aggregators[7], temp[2])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBgtBack(1)

  registers.FreeRegisters(temp + lhs_load + rhs_load)


def _Generate1x8LoadMultiplyAggregate(emitter, registers, aggregators, lhs,
                                      rhs, count):
  """Emit inner loop for 1 row x 8 cols multiplication."""
  emitter.EmitNewline()
  emitter.EmitComment('1x8 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()

  lhs_load = registers.DoubleRegister()
  rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
  temp = [registers.QuadRegister() for unused_i in range(5)]

  emitter.EmitVLoadAE(4 * 8, 8, rhs_load, rhs, 256)
  emitter.EmitVLoadE(8, 8, lhs_load, lhs, 64)
  emitter.EmitVMull('u8', temp[0], lhs_load, rhs_load[0])
  emitter.EmitVMull('u8', temp[1], lhs_load, rhs_load[1])
  emitter.EmitVMull('u8', temp[2], lhs_load, rhs_load[2])
  emitter.EmitVMull('u8', temp[3], lhs_load, rhs_load[3])
  emitter.EmitVLoadAE(4 * 8, 8, rhs_load, rhs, 256)

  emitter.EmitVPadal('u16', aggregators[0], temp[0])
  emitter.EmitVPadal('u16', aggregators[1], temp[1])
  emitter.EmitVPadal('u16', aggregators[2], temp[2])
  emitter.EmitVPadal('u16', aggregators[3], temp[3])

  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(256))

  emitter.EmitVMull('u8', temp[4], lhs_load, rhs_load[0])
  emitter.EmitVMull('u8', temp[0], lhs_load, rhs_load[1])
  emitter.EmitVMull('u8', temp[1], lhs_load, rhs_load[2])
  emitter.EmitVMull('u8', temp[2], lhs_load, rhs_load[3])

  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(32))

  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  emitter.EmitVPadal('u16', aggregators[4], temp[4])
  emitter.EmitVPadal('u16', aggregators[5], temp[0])
  emitter.EmitVPadal('u16', aggregators[6], temp[1])
  emitter.EmitVPadal('u16', aggregators[7], temp[2])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBgtBack(1)

  registers.FreeRegisters(temp + [lhs_load] + rhs_load)


def _GenerateNxMLoadMultiplyAggregate(emitter, registers, kernel_m, kernel_n,
                                      aggregators, lhs, rhs, count):
  """Emit inner loop for N rows x M cols multiplication."""
  emitter.EmitNewline()
  emitter.EmitComment('General NxM lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  lhs_load = [registers.DoubleRegister() for unused_i in range(kernel_m)]
  rhs_load = [registers.DoubleRegister() for unused_i in range(kernel_n)]

  emitter.EmitVLoadAE(8 * kernel_m, 8, lhs_load, lhs, 64)
  emitter.EmitVLoadAE(8 * kernel_n, 8, rhs_load, rhs, 64)

  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))

  results = [
      registers.QuadRegister() for unused_i in range(kernel_m * kernel_n)
  ]

  for row in range(kernel_m):
    for col in range(kernel_n):
      index = row * kernel_n + col
      emitter.EmitVMull('u8', results[index], rhs_load[col], lhs_load[row])
  for i in range(kernel_m * kernel_n):
    emitter.EmitVPadal('u16', aggregators[i], results[i])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBgtBack(1)

  registers.FreeRegisters(lhs_load + rhs_load + results)


def _Generate1xNLoadMultiplyAggregate(emitter, registers, kernel_n,
                                      aggregators, lhs, rhs, count):
  """Emit inner loop for 1 row x M cols multiplication."""
  assert kernel_n in [5, 6, 7, 8]
  emitter.EmitNewline()
  emitter.EmitComment('General 1xM lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  leftover = kernel_n - 4

  rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
  lhs_load = registers.DoubleRegister()

  emitter.EmitVLoadAE(8 * 4, 8, rhs_load, rhs, 64)
  emitter.EmitVLoadE(8, 8, lhs_load, lhs, 64)
  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))

  results = [registers.QuadRegister() for unused_i in range(4)]

  for i in range(4):
    emitter.EmitVMull('u8', results[i], rhs_load[i], lhs_load)

  emitter.EmitVLoadAE(8 * leftover, 8, rhs_load, rhs, 64)
  emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(128))

  for i in range(4):
    emitter.EmitVPadal('u16', aggregators[i], results[i])

  for i in range(leftover):
    emitter.EmitVMull('u8', results[i], rhs_load[i], lhs_load)

  for i in range(leftover):
    emitter.EmitVPadal('u16', aggregators[i + 4], results[i])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBgtBack(1)

  registers.FreeRegisters([lhs_load] + rhs_load + results)


def _GenerateMultiplyKernel(emitter, registers, kernel_m, kernel_n, lhs, rhs):
  """Main multiply loop. Pick the best implementation for the kernel shape."""
  count = registers.MapParameter('count', 'params.kernel.count')

  aggregators = _GenerateAndClearAggregators(emitter, registers,
                                             kernel_m * kernel_n)
  if kernel_m == 3 and kernel_n == 3:
    _Generate3x3LoadMultiplyAggregate(emitter, registers, aggregators, lhs,
                                      rhs, count)
  elif kernel_m == 2 and kernel_n == 4:
    _Generate2x4LoadMultiplyAggregate(emitter, registers, aggregators, lhs,
                                      rhs, count)
  elif kernel_m == 1 and kernel_n == 8:
    _Generate1x8LoadMultiplyAggregate(emitter, registers, aggregators, lhs,
                                      rhs, count)
  elif kernel_m == 1 and kernel_n > 4:
    _Generate1xNLoadMultiplyAggregate(emitter, registers, kernel_n,
                                      aggregators, lhs, rhs, count)
  else:
    _GenerateNxMLoadMultiplyAggregate(emitter, registers, kernel_m, kernel_n,
                                      aggregators, lhs, rhs, count)

  return aggregators


def _ReduceAggregators(emitter, aggregators):
  reduced_count = (len(aggregators) + 3) // 4
  reduced = aggregators[:reduced_count]
  emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)
  return reduced


def _GenerateAggregatorReduce(emitter, aggregators, kernel_m, kernel_n):
  emitter.EmitNewline()
  emitter.EmitComment('Reduce aggregators.')
  row_temps = []
  for i in range(kernel_m):
    row_temps.append(
        _ReduceAggregators(emitter,
                           aggregators[i * kernel_n:(i + 1) * kernel_n]))
  return row_temps


class QuantizedMulKernel(common.MulKernelGenerator):
  """Generates a quantized multiplication kernel with a fused output stage."""

  def __init__(self, cc_emitter, kernel_name, output_stream_name, asm_emitter,
               fused_transformation, output_strategy):
    common.MulKernelGenerator.__init__(self, cc_emitter, kernel_name,
                                       output_stream_name)
    self.asm_emitter = asm_emitter
    self.fused_transformation = fused_transformation
    self.output_strategy = output_strategy

  def EmitMultiply(self, in_type, out_type, kernel_m, kernel_n, pack_size):
    assert in_type == 'uint8_t'
    assert pack_size == 8
    assert kernel_m * kernel_n <= 9
    registers = self.asm_emitter.CreateRegisters()

    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    lhs = registers.MapOutputParameter('lhs')
    rhs = registers.MapOutputParameter('rhs')

    self.asm_emitter.EmitPld(lhs)
    self.asm_emitter.EmitPld(rhs)

    aggregators = _GenerateMultiplyKernel(self.asm_emitter, registers,
                                          kernel_m, kernel_n, lhs, rhs)

    self.fused_transformation.Prepare(self.asm_emitter, registers, kernel_m,
                                      kernel_n, lhs, rhs)
    self.output_strategy.Prepare(self.asm_emitter, registers, kernel_m,
                                 kernel_n, self.fused_transformation.Type())

    reduced = _GenerateAggregatorReduce(self.asm_emitter, aggregators,
                                        kernel_m, kernel_n)

    transformed = self.fused_transformation.Transform(
        self.asm_emitter, registers, reduced, kernel_m, kernel_n)

    self.output_strategy.Output(self.asm_emitter, registers, transformed,
                                self.fused_transformation.Type(), kernel_m,
                                kernel_n)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))


class QuantizedMulStaticRowMajor(QuantizedMulKernel):
  """Quantized multiplication with uint8 output in row major layout."""

  def __init__(self, cc_emitter, asm_emitter):
    QuantizedMulKernel.__init__(self, cc_emitter,
                                'QuantizedStaticPreprocessed', 'RowMajor',
                                asm_emitter,
                                _StaticQuantizationUInt8Transformation(),
                                _RowMajorOutput())


class QuantizedMulStaticAsInt32RowMajor(QuantizedMulKernel):
  """Quantized multiplication with int32 output in row major layout."""

  def __init__(self, cc_emitter, asm_emitter):
    QuantizedMulKernel.__init__(self, cc_emitter,
                                'QuantizedStaticPreprocessedAsInt32',
                                'RowMajor', asm_emitter,
                                _StaticQuantizationInt32Transformation(),
                                _RowMajorOutput())


class QuantizedMulStaticAsFloatRowMajor(QuantizedMulKernel):
  """Quantized multiplication with float output in row major layout."""

  def __init__(self, cc_emitter, asm_emitter):
    QuantizedMulKernel.__init__(self, cc_emitter,
                                'QuantizedStaticPreprocessedAsFloat',
                                'RowMajor', asm_emitter,
                                _StaticQuantizationFloatTransformation(),
                                _RowMajorOutput())


def GenerateKernels(cc_emitter, asm_emitter, shapes):
  """Generate the quantized multiplication kernels for uint8 operands."""
  quantized_mul_static_row_major = QuantizedMulStaticRowMajor(cc_emitter,
                                                              asm_emitter)
  quantized_mul_static_int32_row_major = QuantizedMulStaticAsInt32RowMajor(
      cc_emitter, asm_emitter)
  quantized_mul_static_float_row_major = QuantizedMulStaticAsFloatRowMajor(
      cc_emitter, asm_emitter)

  for shape in shapes:
    quantized_mul_static_row_major.SpecializeMulKernel('uint8_t', 'uint8_t',
                                                       shape[0], shape[1], 8)

  for shape in shapes:
    quantized_mul_static_int32_row_major.SpecializeMulKernel(
        'uint8_t', 'int32_t', shape[0], shape[1], 8)

  for shape in shapes:
    quantized_mul_static_float_row_major.SpecializeMulKernel(
        'uint8_t', 'float', shape[0], shape[1], 8)
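

# Illustrative usage sketch: how a generator script might call
# GenerateKernels for a set of kernel shapes. The emitter module and class
# names (cc_emitter.CCEmitter, neon_emitter.NeonEmitter) and the shape list
# below are assumptions for illustration, not confirmed APIs of this package;
# the actual driver scripts in this directory are the authoritative wiring.
#
#   import cc_emitter
#   import neon_emitter
#
#   # Kernel shapes are (rows, cols) pairs; EmitMultiply asserts
#   # rows * cols <= 9 and a pack size of 8.
#   shapes = [(1, 8), (2, 4), (3, 3)]
#   GenerateKernels(cc_emitter.CCEmitter(), neon_emitter.NeonEmitter(), shapes)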