"""Zip primitive used by the GEMM function. Takes 1 to 3 rows of data and interleaves them in 8 byte chunks. Pads to multiply of 8 length with zeros. Calculates row sums and appends those at the end. """ import neon_emitter class Error(Exception): """Module level error.""" class ConfigurationError(Error): """Unsupported configuration.""" class ZipLane(object): def __init__(self, input_address, load, aggregator): self.input_address = input_address self.load = load self.aggregator = aggregator def GenerateZipLanes(emitter, registers, zip_lanes, input_address, stride): """Prepares read lanes for the zip operation. Args: emitter: ARM/NEON emitter. registers: ARM/NEON registers state. zip_lanes: number of lanes to prepare. input_address: register that contains the input address for the first lane. stride: memory stride for lane inputs. Returns: Array of ZipLane objects. """ lanes = [] last_address_register = input_address for i in range(0, zip_lanes): if not i: lanes.append(ZipLane(input_address, registers.DoubleRegister(), registers.QuadRegister(2))) else: address_register = registers.GeneralRegister() lanes.append(ZipLane(address_register, registers.DoubleRegister(), registers.QuadRegister(2))) emitter.EmitAdd(address_register, last_address_register, stride) last_address_register = address_register return lanes def BuildName(zip_lanes, leftovers, aligned): name = 'zip_%dx8' % zip_lanes if leftovers: name += '_%d' % leftovers if aligned: name += '_aligned' return name def GenerateClearAggregators(emitter, lanes): for lane in lanes: emitter.EmitVMov('i16', lane.aggregator, emitter.ImmediateConstant(0)) def GenerateLoadAggregateStore(emitter, lanes, output_address, alignment): """Emit inner loop code for reading N lanes and interweaving them.""" emitter.EmitNewline() emitter.EmitComment('Load Aggregate Store.') for lane in lanes: emitter.EmitVLoad( '1.8', lane.load, emitter.DereferenceIncrement(lane.input_address, alignment)) store_registers = [] for lane in lanes: emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load) store_registers.append(lane.load) emitter.EmitVStoreA('1.8', store_registers, emitter.DereferenceIncrement(output_address, 64)) def GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes, output_address): """Handle leftovers when count is not a multiply of 8.""" emitter.EmitNewline() emitter.EmitComment('Leftover Load Aggregate Store.') # Clear load registers. for lane in lanes: emitter.EmitVMov('i8', lane.load, emitter.ImmediateConstant(0)) if leftovers == 1: # Load 8 bits. for lane in lanes: emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 0), emitter.Dereference(lane.input_address, None)) elif leftovers == 2: # Load 16 bits. for lane in lanes: emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 0), emitter.Dereference(lane.input_address, None)) elif leftovers == 3: # Load 16 bits. for lane in lanes: emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 0), emitter.DereferenceIncrement(lane.input_address, None)) # Load 8 bits. for lane in lanes: emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 2), emitter.Dereference(lane.input_address, None)) elif leftovers == 4: # Load 32 bits. for lane in lanes: emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0), emitter.Dereference(lane.input_address, None)) elif leftovers == 5: # Load 32 bits.. for lane in lanes: emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0), emitter.DereferenceIncrement(lane.input_address, None)) # Load 8 bits. for lane in lanes: emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 4), emitter.Dereference(lane.input_address, None)) elif leftovers == 6: # Load 32 bits.. for lane in lanes: emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0), emitter.DereferenceIncrement(lane.input_address, None)) # Load 16 bits. for lane in lanes: emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 2), emitter.Dereference(lane.input_address, None)) elif leftovers == 7: # Load 32 bits.. for lane in lanes: emitter.EmitVLoad('1.32', emitter.Lane(lane.load, 0), emitter.DereferenceIncrement(lane.input_address, None)) # Load 16 bits. for lane in lanes: emitter.EmitVLoad('1.16', emitter.Lane(lane.load, 2), emitter.DereferenceIncrement(lane.input_address, None)) # Load 8 bits. for lane in lanes: emitter.EmitVLoad('1.8', emitter.Lane(lane.load, 6), emitter.Dereference(lane.input_address, None)) else: raise ConfigurationError('Unsupported leftover num: %d' % leftovers) # Aggregate. store_registers = [] for lane in lanes: emitter.EmitVAddw('u8', lane.aggregator, lane.aggregator, lane.load) store_registers.append(lane.load) # Store. emitter.EmitVStoreA('1.8', store_registers, emitter.DereferenceIncrement(output_address, 64)) def GenerateAggregatorReduction(emitter, registers, lanes, output_address, multiplicative_offset, additive_offset): """Reduce 4 lane sum aggregators to 1 value and store the sums.""" emitter.EmitNewline() emitter.EmitComment('Aggregator Reduction.') multiplier = registers.DoubleRegister() emitter.EmitVMov('32', emitter.Lane(multiplier, 0), multiplicative_offset) offset = registers.QuadRegister() emitter.EmitVDup('32', offset, additive_offset) lane_temps = [] for lane in lanes: emitter.EmitVPaddl('u16', lane.aggregator, lane.aggregator) for lane in lanes: lane_temp = registers.DoubleRegister() lane_temps.append(lane_temp) emitter.EmitVPadd('u32', lane_temp, registers.Low(lane.aggregator), registers.High(lane.aggregator)) temp = registers.QuadRegister() low = registers.Low(temp) high = registers.High(temp) if len(lanes) == 1: emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[0]) elif len(lanes) == 2: emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1]) elif len(lanes) == 3: emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1]) emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[2]) elif len(lanes) == 4: emitter.EmitVPadd('u32', low, lane_temps[0], lane_temps[1]) emitter.EmitVPadd('u32', high, lane_temps[2], lane_temps[3]) else: raise ConfigurationError('Unexpected number of aggregators to reduce: %d' % len(lanes)) emitter.EmitVMul('i32', temp, temp, emitter.Lane(multiplier, 0)) emitter.EmitVAdd('i32', temp, temp, offset) if len(lanes) == 1: emitter.EmitVStore('1.32', emitter.Lane(low, 0), emitter.Dereference(output_address, None)) elif len(lanes) == 2: emitter.EmitVStore('1.32', low, emitter.Dereference(output_address, 64)) elif len(lanes) == 3: emitter.EmitVStore('1.32', low, emitter.DereferenceIncrement(output_address, 64)) emitter.EmitVStore('1.32', emitter.Lane(high, 0), emitter.Dereference(output_address, None)) elif len(lanes) == 4: emitter.EmitVStoreA('1.32', [low, high], emitter.DereferenceIncrement(output_address, 64)) def GenerateZipNx8(emitter, zip_lanes, leftovers, aligned): """Emit the zip function for a given number of rows and row size leftovers.""" if leftovers < 0 or leftovers > 7: raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.') if zip_lanes < 1 or zip_lanes > 4: raise ConfigurationError('Zip_lanes should should be 1, 2, 3 or 4.') name = BuildName(zip_lanes, leftovers, aligned) emitter.EmitFunctionBeginA( name, [['const std::uint8_t*', 'source'], ['std::int32_t', 'count'], ['std::int32_t', 'stride'], ['std::uint8_t*', 'destination'], ['std::int32_t', 'multiplicative_offset'], ['std::int32_t', 'additive_offset']], 'void') emitter.EmitAssert('count %% 8 == %d' % leftovers) emitter.EmitAssert('count <= 2048') emitter.EmitAssert('count >= 8') emitter.EmitAssert('reinterpret_cast(destination) % 8 == 0') if aligned: emitter.EmitAssert('reinterpret_cast(source) % 8 == 0') if zip_lanes > 1: emitter.EmitAssert('stride % 8 == 0') emitter.EmitAsmBegin() registers = neon_emitter.NeonRegisters() count = registers.MapParameter('count') output_address = registers.MapParameter('destination') lanes = GenerateZipLanes(emitter, registers, zip_lanes, registers.MapParameter('source'), registers.MapParameter('stride')) if leftovers: emitter.EmitSub(count, count, emitter.ImmediateConstant(leftovers)) GenerateClearAggregators(emitter, lanes) emitter.EmitNewline() emitter.EmitNumericalLabel(1) emitter.EmitSubs(count, count, emitter.ImmediateConstant(8)) GenerateLoadAggregateStore(emitter, lanes, output_address, 64 if aligned else None) emitter.EmitNewline() emitter.EmitBneBack(1) if leftovers: GenerateLeftoverLoadAggregateStore(emitter, leftovers, lanes, output_address) GenerateAggregatorReduction(emitter, registers, lanes, output_address, registers.MapParameter('multiplicative_offset'), registers.MapParameter('additive_offset')) emitter.EmitAsmEnd(registers.MappedParameters(), [], registers.Clobbers() + ['cc', 'memory']) emitter.EmitFunctionEnd() def GenerateFunctions(emitter): for aligned in [True, False]: for lanes in range(1, 5): for leftovers in range(0, 8): GenerateZipNx8(emitter, lanes, leftovers, aligned) emitter.EmitNewline()