451 lines
16 KiB
Python
451 lines
16 KiB
Python
"""Mul primitive used by the GEMM function.
|
|
|
|
The Mul primitive takes 1-3 zipped rows and 1-3 zipped columns and performs
|
|
matrix multiplication on those resulting in a small 1x1 to 3x3 block of results.
|
|
"""
|
|
|
|
import neon_emitter
|
|
|
|
|
|
class Error(Exception):
  """Root of the exception hierarchy raised by this generator module."""
|
|
|
|
|
|
class ConfigurationError(Error):
  """Raised when a requested kernel shape, lane count or type is unsupported."""
|
|
|
|
|
|
class MulLanes(object):
  """A group of lane registers that are all loaded from one input address.

  Attributes:
    input_address: address operand the lanes are loaded from (for example a
      mapped parameter register such as 'lhs' or 'rhs').
    lanes: list of register handles, one per lane, in load order.
  """

  def __init__(self, input_address):
    self.input_address = input_address
    self.lanes = []

  def AddLane(self, lane):
    """Append one more lane register to the group."""
    self.lanes.append(lane)

  def FreeRegisters(self, registers):
    """Return every lane register to `registers`.

    The list keeps its length (each entry becomes None) so that len(lanes)
    can still be used as the row/column count after the registers are freed.
    """
    for position, lane in enumerate(self.lanes):
      registers.FreeRegister(lane)
      self.lanes[position] = None
|
|
|
|
|
|
def GenerateMulLanes(registers, lane_count, address):
  """Build a MulLanes of `lane_count` freshly allocated double registers.

  Args:
    registers: register allocator providing DoubleRegister().
    lane_count: number of lanes to allocate.
    address: address operand the lanes will be loaded from.

  Returns:
    The populated MulLanes.
  """
  result = MulLanes(address)
  for _ in range(lane_count):
    result.AddLane(registers.DoubleRegister())
  return result
|
|
|
|
|
|
def Generate3MulLanes(quad_register, registers, address):
  """Build three lanes: the low and high halves of `quad_register`, then a new
  double register.

  Args:
    quad_register: quad register whose two halves become lanes 0 and 1.
    registers: register allocator providing Low(), High() and
      DoubleRegister().
    address: address operand the lanes will be loaded from.

  Returns:
    The populated MulLanes.
  """
  result = MulLanes(address)
  low_half = registers.Low(quad_register)
  high_half = registers.High(quad_register)
  extra = registers.DoubleRegister()
  for lane in (low_half, high_half, extra):
    result.AddLane(lane)
  return result
|
|
|
|
|
|
def GenerateAndClearAggregators(emitter, registers, aggregator_count):
  """Allocate `aggregator_count` quad registers and emit code zeroing them.

  The first three are cleared from an immediate 0. Each later one is copied
  from the register three positions earlier, which is already zero by then.

  Returns:
    The allocated quad registers, in allocation order.
  """
  emitter.EmitComment('Clear aggregators.')
  allocated = []
  for index in range(aggregator_count):
    target = registers.QuadRegister()
    allocated.append(target)
    source = (emitter.ImmediateConstant(0) if index < 3
              else allocated[index - 3])
    emitter.EmitVMov('i32', target, source)
  emitter.EmitNewline()
  return allocated
|
|
|
|
|
|
def GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count):
  """Emit the generic inner loop for N rows x M columns.

  Each iteration:
    * subtracts 8 from `count` (SUBS, which sets the flags for the branch),
    * loads all left lanes, then all right lanes, post-incrementing each
      input address,
    * prefetches at offset 64 from both input addresses,
    * multiplies every (row, col) lane pair with VMULL.U8 into its own
      temporary quad register,
    * accumulates each product into aggregators[row * cols + col] with
      VPADAL.U16.
  The loop branches back to label 1 while the SUBS result is non-zero.

  Args:
    emitter: code emitter.
    registers: register allocator; temporaries are allocated and freed here.
    left_lanes: MulLanes for the rows.
    right_lanes: MulLanes for the columns.
    aggregators: at least rows * cols quad registers, row-major.
    count: register holding the remaining depth.
  """
  emitter.EmitComment('General NxM lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  # NOTE(review): the 64 passed to DereferenceIncrement is presumably an
  # alignment hint in bits - confirm in neon_emitter.
  for group in (left_lanes, right_lanes):
    emitter.EmitVLoadA('1.8', group.lanes,
                       emitter.DereferenceIncrement(group.input_address, 64))
  for group in (left_lanes, right_lanes):
    emitter.EmitPldOffset(group.input_address, emitter.ImmediateConstant(64))

  cell_count = len(left_lanes.lanes) * len(right_lanes.lanes)
  products = [registers.QuadRegister() for _ in range(cell_count)]

  # Row-major pairing, so products[k] corresponds to aggregators[k].
  lane_pairs = [(row_lane, col_lane)
                for row_lane in left_lanes.lanes
                for col_lane in right_lanes.lanes]
  for product, (row_lane, col_lane) in zip(products, lane_pairs):
    emitter.EmitVMull('u8', product, col_lane, row_lane)
  for position, product in enumerate(products):
    emitter.EmitVPadal('u16', aggregators[position], product)

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBneBack(1)
  emitter.EmitNewline()

  for product in products:
    registers.FreeRegister(product)
|
|
|
|
|
|
def Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count,
                                     backup_register):
  """Emit the inner loop for 3 rows x 3 columns using only 4 temporaries.

  The nine VMULL.U8 products are issued in two batches. The first batch has
  four products, followed by their VPADAL.U16 accumulations, so the four
  temporary quad registers can be reused. The second batch has five
  products; the fifth is written into `backup_register`.

  In GenerateMulNx8Mx8, `backup_register` is the quad register whose two
  halves hold left lanes 0 and 1 (see Generate3MulLanes). The last multiply
  therefore clobbers those lanes. This is safe because both lanes have been
  consumed by then and are reloaded at the top of the next iteration.

  Args:
    emitter: code emitter.
    registers: register allocator; temporaries are allocated and freed here.
    left_lanes: MulLanes with three row lanes.
    right_lanes: MulLanes with three column lanes.
    aggregators: nine quad registers; product (row, col) accumulates into
      aggregators[row * 3 + col].
    count: register holding the remaining depth; decremented by 8 per
      iteration.
    backup_register: extra quad register receiving the last product.
  """
  emitter.EmitComment('3x3 lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
  emitter.EmitNewline()

  for group in (left_lanes, right_lanes):
    emitter.EmitVLoadA('1.8', group.lanes,
                       emitter.DereferenceIncrement(group.input_address, 64))
  for group in (left_lanes, right_lanes):
    emitter.EmitPldOffset(group.input_address, emitter.ImmediateConstant(64))

  temporaries = [registers.QuadRegister() for _ in range(4)]

  # (destinations, (row, col) cells) for each batch; cells are row-major.
  batches = (
      (temporaries, ((0, 0), (0, 1), (0, 2), (1, 0))),
      (temporaries + [backup_register],
       ((1, 1), (1, 2), (2, 0), (2, 1), (2, 2))),
  )
  for destinations, cells in batches:
    for destination, (row, col) in zip(destinations, cells):
      emitter.EmitVMull('u8', destination, left_lanes.lanes[row],
                        right_lanes.lanes[col])
    for destination, (row, col) in zip(destinations, cells):
      emitter.EmitVPadal('u16', aggregators[row * 3 + col], destination)

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBneBack(1)
  emitter.EmitNewline()

  for temporary in temporaries:
    registers.FreeRegister(temporary)
|
|
|
|
|
|
def ReadParams(emitter, registers, input_address, elements, min_reg):
  """Emit a '1.32' VLoad of `elements` values from `input_address`.

  One or two elements go into a double register, three or four into a quad
  register. `min_reg` is given in quad-register numbering. The double
  register request passes min_reg * 2, presumably because D-register indices
  run twice as fast as Q-register indices (verify in neon_emitter).

  Returns:
    The register the values were loaded into.

  Raises:
    ConfigurationError: if `elements` is not 1-4.
  """
  if elements in (1, 2):
    target = registers.DoubleRegister(min_reg * 2)
  elif elements in (3, 4):
    target = registers.QuadRegister(min_reg)
  else:
    raise ConfigurationError('Unsupported elements no: %d' % elements)
  emitter.EmitVLoad('1.32', target, emitter.Dereference(input_address, 64))
  return target
|
|
|
|
|
|
def Duplicate(emitter, registers, rows, cols, min_register, values):
  """Broadcast each of `rows` 32-bit lanes of `values` into its own register.

  The destination registers are doubles for cols 1-2 and quads for cols 3-4.
  For rows 1-2, `values` is indexed directly (lanes 0 and 1). For rows 3-4,
  rows 0-1 come from lanes 0-1 of its low half and rows 2-3 from lanes 0-1 of
  its high half. For any other `rows` value, registers are allocated but no
  VDup is emitted.

  Returns:
    The list of destination registers, one per row.

  Raises:
    ConfigurationError: if `cols` is not 1-4.
  """
  if cols in (1, 2):
    allocate = registers.DoubleRegister
  elif cols in (3, 4):
    allocate = registers.QuadRegister
  else:
    raise ConfigurationError('Unsupported duplicate amount: %d' % cols)
  targets = [allocate(min_register) for _ in range(rows)]

  if rows in (1, 2):
    for row in range(rows):
      emitter.EmitVDup('32', targets[row], emitter.Lane(values, row))
  elif rows in (3, 4):
    for row in range(rows):
      half = registers.Low(values) if row < 2 else registers.High(values)
      emitter.EmitVDup('32', targets[row], emitter.Lane(half, row % 2))

  return targets
|
|
|
|
|
|
def DuplicateGeneralRegister(emitter, registers, cols, general_register,
                             min_register):
  """Broadcast `general_register` into every 32-bit lane of a new register.

  A double register is used for cols 1-2 and a quad register for cols 3-4.

  Returns:
    The destination register.

  Raises:
    ConfigurationError: if `cols` is not 1-4.
  """
  if cols in (3, 4):
    target = registers.QuadRegister(min_register)
  elif cols in (1, 2):
    target = registers.DoubleRegister(min_register)
  else:
    raise ConfigurationError('Unsupported duplicate amount: %d' % cols)
  emitter.EmitVDup('32', target, general_register)
  return target
|
|
|
|
|
|
def ReduceAggregator(emitter, registers, aggregators, row, cols):
  """Emit code collapsing one row of aggregators into a single register.

  Aggregators are laid out row-major: the aggregator for (row, col) is
  aggregators[row * cols + col]. In this file, the caller has already folded
  each aggregator's high half into its low half with VPADD. Here, further
  pairwise adds pack the row's per-column sums into one register:

  * cols 1-2: the low half of the row's first aggregator (a double register).
  * cols 3-4: the row's first aggregator (a quad register). For cols 3, the
    third sum is also written into the unused fourth lane.

  Args:
    emitter: code emitter.
    registers: register helper providing Low() and High().
    aggregators: all aggregators of the block, row-major.
    row: index of the row to reduce.
    cols: number of columns in the block (1-4).

  Returns:
    The register holding the reduced row.

  Raises:
    ConfigurationError: if `cols` is not 1-4.
  """
  if cols not in (1, 2, 3, 4):
    raise ConfigurationError('Unsupported columns no: %d' % cols)

  # The row stride equals the column count for every shape. The 4-column case
  # previously used a stride of 3, so rows > 0 read the wrong aggregators.
  base = row * cols

  if cols == 1:
    register = registers.Low(aggregators[base])
    emitter.EmitVPadd('u32', register, register, register)
    return register

  if cols == 2:
    register = registers.Low(aggregators[base])
    emitter.EmitVPadd('u32', register, register,
                      registers.Low(aggregators[base + 1]))
    return register

  # cols == 3 or cols == 4: the result spans both halves of a quad register.
  register = aggregators[base]
  emitter.EmitVPadd('u32', registers.Low(register), registers.Low(register),
                    registers.Low(aggregators[base + 1]))
  # With 3 columns there is no fourth aggregator, so column 2 is paired with
  # itself.
  last = base + 3 if cols == 4 else base + 2
  emitter.EmitVPadd('u32', registers.High(register),
                    registers.Low(aggregators[base + 2]),
                    registers.Low(aggregators[last]))
  return register
|
|
|
|
|
|
def StoreAggregator(emitter, registers, aggregator, cols, result_address,
                    result_stride):
  """Emit the store of one reduced row of `cols` 32-bit results.

  cols 1, 2 and 4 use a single offset store with `result_stride` (presumably
  post-indexed by the stride; confirm in neon_emitter). cols 3 uses two
  stores:
    1. the low half, with an incrementing dereference;
    2. lane 0 of the high half, with the stride offset.
  GenerateAggregatorReduceStore reduces the stride by 8 in that case to
  compensate.

  Raises:
    ConfigurationError: if `cols` is not 1-4.
  """
  if cols in (1, 2):
    source = emitter.Lane(aggregator, 0) if cols == 1 else aggregator
    emitter.EmitVStoreOffset('1.32', source,
                             emitter.Dereference(result_address, None),
                             result_stride)
    return

  if cols == 3:
    low_half = registers.Low(aggregator)
    emitter.EmitVStore('1.32', low_half,
                       emitter.DereferenceIncrement(result_address, None))
    high_lane = emitter.Lane(registers.High(aggregator), 0)
    emitter.EmitVStoreOffset('1.32', high_lane,
                             emitter.Dereference(result_address, None),
                             result_stride)
    emitter.EmitNewline()
    return

  if cols == 4:
    halves = [registers.Low(aggregator), registers.High(aggregator)]
    emitter.EmitVStoreOffsetA('1.32', halves,
                              emitter.Dereference(result_address, None),
                              result_stride)
    return

  raise ConfigurationError('Unsupported columns no: %d' % cols)
|
|
|
|
|
|
def GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
                                  lhs_add, rhs_add, left_lanes, right_lanes,
                                  results, results_stride):
  """Emit code that reduces the aggregators to final values and stores them.

  Steps, in emission order:
    1. Optionally load lhs offsets (one per row, broadcast per row) and rhs
       offsets (one per column).
    2. For float results, broadcast the 'result_scale' parameter.
    3. For 3 columns, reduce the stride by 8 because each row is stored in
       two ops (see StoreAggregator).
    4. Fold every aggregator's high half into its low half (VPADD.U32).
    5. Reduce each row's aggregators into one register (ReduceAggregator).
    6. Add lhs/rhs offsets, convert to float and scale if requested.
    7. Store each row.

  Args:
    emitter: code emitter.
    registers: register allocator.
    aggregators: quad registers in row-major order (row * cols + col).
    result_type: 'int32' or 'float'.
    lhs_add: if true, add per-row offsets read from the lhs input address.
    rhs_add: if true, add per-column offsets read from the rhs input address.
    left_lanes: MulLanes of the loop. Only len(lanes) and input_address are
      read.
    right_lanes: same as `left_lanes`, for the columns.
    results: register holding the output address.
    results_stride: register holding the output row stride. It is modified
      in place when cols == 3.
  """
  rows = len(left_lanes.lanes)
  cols = len(right_lanes.lanes)
  # Compare strings by value: `is` depends on interning and warns on 3.8+.
  is_float = result_type == 'float'

  # Offsets are read from wherever the loop left the input pointers.
  # Presumably they are packed right after the lhs/rhs data - confirm against
  # the packing code.
  if lhs_add:
    left_offset = ReadParams(emitter, registers, left_lanes.input_address, rows,
                             4)
    left_offsets = Duplicate(emitter, registers, rows, cols, 4, left_offset)
  else:
    left_offsets = None

  if rhs_add:
    right_offset = ReadParams(emitter, registers, right_lanes.input_address,
                              cols, 4)
  else:
    right_offset = None

  if is_float:
    result_scale = DuplicateGeneralRegister(
        emitter, registers, cols, registers.MapParameter('result_scale'), 4)
  else:
    result_scale = None

  if cols == 3:
    emitter.EmitNewline()
    emitter.EmitComment('Change stride because storing in two ops.')
    emitter.EmitSub(results_stride, results_stride,
                    emitter.ImmediateConstant(8))

  emitter.EmitNewline()
  emitter.EmitComment('Horizontal reduce aggregators.')
  for aggregator in aggregators:
    emitter.EmitVPadd('u32', registers.Low(aggregator),
                      registers.Low(aggregator), registers.High(aggregator))

  emitter.EmitNewline()
  emitter.EmitComment('Reduce rows.')
  row_temps = []
  for i in range(0, rows):
    row_temps.append(ReduceAggregator(emitter, registers, aggregators, i, cols))

  if lhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add lhs offsets to aggregated rows.')
    for (row_temp, left_offset) in zip(row_temps, left_offsets):
      emitter.EmitVAdd('s32', row_temp, row_temp, left_offset)

  if rhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add rhs offset to aggregated rows.')
    for row_temp in row_temps:
      emitter.EmitVAdd('s32', row_temp, row_temp, right_offset)

  if is_float:
    emitter.EmitNewline()
    emitter.EmitComment('Convert to float. Multiply by result scale.')
    for row_temp in row_temps:
      emitter.EmitVCvt('f32', 's32', row_temp, row_temp)
    for row_temp in row_temps:
      emitter.EmitVMul('f32', row_temp, row_temp, result_scale)

  emitter.EmitNewline()
  emitter.EmitComment('Store reduced rows.')
  for row_temp in row_temps:
    StoreAggregator(emitter, registers, row_temp, cols, results, results_stride)
|
|
|
|
|
|
def BuildName(result_type, lhs_add, rhs_add, left, right):
  """Return the kernel name, e.g. 'mul_3x8_3x8_int32_lhsadd_rhsadd'.

  Args:
    result_type: result type tag appended after the shape.
    lhs_add: if truthy, append '_lhsadd'.
    rhs_add: if truthy, append '_rhsadd'.
    left: number of lhs rows.
    right: number of rhs columns.
  """
  optional = ((lhs_add, '_lhsadd'), (rhs_add, '_rhsadd'))
  suffix = ''.join(tag for enabled, tag in optional if enabled)
  return 'mul_%dx8_%dx8_%s' % (left, right, result_type) + suffix
|
|
|
|
|
|
def CppResultType(result_type):
  """Return the C++ pointer type of the generated kernel's `result` argument.

  Args:
    result_type: 'int32' or 'float'.

  Returns:
    'std::int32_t*' or 'float*'.

  Raises:
    ConfigurationError: for any other result type.
  """
  # Compare by value. The previous `is` check only worked when the argument
  # happened to be the interned literal, and it warns on Python 3.8+.
  if result_type == 'int32':
    return 'std::int32_t*'
  if result_type == 'float':
    return 'float*'
  raise ConfigurationError('Unsupported result type: %s' % result_type)
|
|
|
|
|
|
def GetParameters(result_type):
  """Return the generated kernel's parameters as [C++ type, name] pairs.

  Every kernel takes lhs, rhs, count, result and result_stride. Float kernels
  also take a trailing 'result_scale'.

  Raises:
    ConfigurationError: propagated from CppResultType for unsupported types.
  """
  params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs'],
            ['std::int32_t', 'count'], [CppResultType(result_type), 'result'],
            ['std::int32_t', 'result_stride']]
  # Compare by value: `is` on strings depends on interning.
  if result_type == 'float':
    params.append(['float', 'result_scale'])
  return params
|
|
|
|
|
|
def GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, left_lanes_count,
                      right_lanes_count):
  """Emit one inline C++ function computing a left x right block of results.

  The function body is NEON inline assembly that:
    1. clears the aggregators;
    2. runs the multiply-accumulate loop over `count`;
    3. reduces and stores the block (see GenerateAggregatorReduceStore).

  Args:
    emitter: code emitter (a neon_emitter.NeonEmitter when run as a script).
    result_type: 'int32' or 'float'.
    lhs_add: add per-row lhs offsets to the results.
    rhs_add: add per-column rhs offsets to the results.
    left_lanes_count: number of lhs rows (1-4).
    right_lanes_count: number of rhs columns (1-4).

  Raises:
    ConfigurationError: if a lane count is outside 1-4, or if the block has
      9 or more cells but is not 3x3. Only 3x3 has a dedicated loop at that
      size.
  """
  if left_lanes_count < 1 or left_lanes_count > 4:
    raise ConfigurationError('Left_lanes should be: 1, 2, 3 or 4.')
  if right_lanes_count < 1 or right_lanes_count > 4:
    raise ConfigurationError('Right_lanes should be: 1, 2, 3 or 4.')

  size = left_lanes_count * right_lanes_count
  use_3x3_loop = left_lanes_count == 3 and right_lanes_count == 3
  # Shapes such as 3x4, 4x3 and 4x4 would otherwise reach the 3x3 loop, which
  # only computes 9 products. Reject them before anything is emitted, so no
  # half-written function is produced.
  if size >= 9 and not use_3x3_loop:
    raise ConfigurationError('Unsupported lanes configuration: %dx%d.' %
                             (left_lanes_count, right_lanes_count))

  emitter.EmitFunctionBeginA(
      BuildName(result_type, lhs_add, rhs_add, left_lanes_count,
                right_lanes_count), GetParameters(result_type), 'inline void')

  # The loop decrements count by 8 before testing it, so the body always runs
  # at least once and must consume whole multiples of 8.
  emitter.EmitAssert('count % 8 == 0')
  emitter.EmitAssert('count >= 8')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  # Parameters are mapped in this order: count, lhs, rhs, then result and
  # result_stride below.
  count = registers.MapParameter('count')
  lhs = registers.MapParameter('lhs')
  rhs = registers.MapParameter('rhs')

  emitter.EmitPld(lhs)
  emitter.EmitPld(rhs)

  aggregators = GenerateAndClearAggregators(emitter, registers, size)

  if use_3x3_loop:
    # Left lanes 0 and 1 are the halves of backup_register. The 3x3 loop
    # reuses it as the destination of its final product.
    backup_register = registers.QuadRegister()
    left_lanes = Generate3MulLanes(backup_register, registers, lhs)
    right_lanes = GenerateMulLanes(registers, right_lanes_count, rhs)

    Generate3x3LoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count,
                                     backup_register)
  else:
    left_lanes = GenerateMulLanes(registers, left_lanes_count, lhs)
    right_lanes = GenerateMulLanes(registers, right_lanes_count, rhs)

    GenerateNxMLoadMultiplyAggregate(emitter, registers, left_lanes,
                                     right_lanes, aggregators, count)

  # Lane registers are no longer needed. The lane lists keep their length,
  # which the reduce step uses as the row/column counts.
  left_lanes.FreeRegisters(registers)
  right_lanes.FreeRegisters(registers)

  GenerateAggregatorReduceStore(emitter, registers, aggregators, result_type,
                                lhs_add, rhs_add, left_lanes, right_lanes,
                                registers.MapParameter('result'),
                                registers.MapParameter('result_stride'))

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()
|
|
|
|
|
|
def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
  """Emit every supported kernel shape.

  Emits each NxM shape for N, M in 1..3 in row-major order, then the 1x4
  shape. Each function is followed by a blank line.
  """
  shapes = [(rows, cols) for rows in range(1, 4) for cols in range(1, 4)]
  shapes.append((1, 4))
  for rows, cols in shapes:
    GenerateMulNx8Mx8(emitter, result_type, lhs_add, rhs_add, rows, cols)
    emitter.EmitNewline()
|
|
|
|
|
|
if __name__ == '__main__':
  # Script entry point: emit the int32 kernels with both lhs and rhs offsets.
  GenerateFunctions(neon_emitter.NeonEmitter(), result_type='int32',
                    lhs_add=True, rhs_add=True)
|