286 lines
10 KiB
Python
286 lines
10 KiB
Python
"""Multiply primitive optimized for the gemv operation."""
|
|
|
|
import neon_emitter
|
|
|
|
|
|
class Error(Exception):
    """Root of the exception hierarchy raised by this module."""
|
|
|
|
|
|
class ConfigurationError(Error):
    """Raised when a generator is asked for a configuration it cannot emit."""
|
|
|
|
|
|
def GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators,
                                  count, lhs, rhs_1, rhs_2):
    """Emit the inner loop for the 1 row x M cols multiplication.

    One emitted iteration subtracts 8 from ``count``, loads one double
    register from ``lhs``, four from ``rhs_1`` and ``lanes_count`` from
    ``rhs_2``, multiplies each right-hand register with the left-hand one and
    accumulates the products: the ``rhs_1`` products go to ``aggregators[0:4]``
    and the ``rhs_2`` products to ``aggregators[4:4 + lanes_count]``. The loop
    is closed with ``EmitBneBack(1)`` against numerical label 1.

    Args:
      emitter: code emitter that receives the instructions.
      registers: register allocator used for the scratch registers, which are
        released again before returning.
      lanes_count: number of registers loaded from ``rhs_2`` per iteration.
      aggregators: accumulator registers; indices up to ``3 + lanes_count``
        are used.
      count: register holding the loop counter.
      lhs: register holding the left-hand side address.
      rhs_1: register holding the address feeding the first four aggregators.
      rhs_2: register holding the address feeding the remaining aggregators.
    """
    emitter.EmitComment('General 1xM lanes loop.')
    emitter.EmitNumericalLabel(1)
    emitter.EmitNewline()
    emitter.EmitComment('Subtract counter.')
    step = emitter.ImmediateConstant(8)
    emitter.EmitSubs(count, count, step)
    emitter.EmitNewline()

    # Scratch registers: four for the right-hand side, one for the left.
    rhs_regs = []
    for _ in range(4):
        rhs_regs.append(registers.DoubleRegister())
    lhs_reg = registers.DoubleRegister()

    lhs_address = emitter.DereferenceIncrement(lhs, 64)
    emitter.EmitVLoad('1.8', lhs_reg, lhs_address)
    rhs_1_address = emitter.DereferenceIncrement(rhs_1, 64)
    emitter.EmitVLoadA('1.8', rhs_regs, rhs_1_address)

    lhs_prefetch = emitter.ImmediateConstant(64)
    emitter.EmitPldOffset(lhs, lhs_prefetch)
    rhs_1_prefetch = emitter.ImmediateConstant(128)
    emitter.EmitPldOffset(rhs_1, rhs_1_prefetch)

    # The product registers are allocated only after the loads are emitted.
    products = [registers.QuadRegister() for _ in range(4)]

    for rhs_reg, product in zip(rhs_regs, products):
        emitter.EmitVMull('u8', product, rhs_reg, lhs_reg)

    # The right-hand scratch registers are reused for the rhs_2 data.
    rhs_2_address = emitter.DereferenceIncrement(rhs_2, 64)
    emitter.EmitVLoadA('1.8', rhs_regs[:lanes_count], rhs_2_address)
    rhs_2_prefetch = emitter.ImmediateConstant(lanes_count * 32)
    emitter.EmitPldOffset(rhs_2, rhs_2_prefetch)

    for index, product in enumerate(products):
        emitter.EmitVPadal('u16', aggregators[index], product)

    for index in range(lanes_count):
        emitter.EmitVMull('u8', products[index], rhs_regs[index], lhs_reg)

    for index in range(lanes_count):
        emitter.EmitVPadal('u16', aggregators[4 + index], products[index])

    emitter.EmitNewline()
    emitter.EmitComment('Loop break.')
    emitter.EmitBneBack(1)
    emitter.EmitNewline()

    registers.FreeRegister(lhs_reg)
    registers.FreeRegisters(rhs_regs)
    registers.FreeRegisters(products)
|
|
|
|
|
|
def ReadLeft(emitter, registers, lhs):
    """Emit the load of the left-hand side offset and return its register.

    A fresh quad register is allocated and a '1.32' load from ``lhs`` is
    emitted whose destinations are the ``AllLanes`` forms of the register's
    low and high halves.

    Args:
      emitter: code emitter that receives the load.
      registers: register allocator providing the quad register.
      lhs: register holding the address to read from.

    Returns:
      The quad register that receives the loaded value.
    """
    offset = registers.QuadRegister()
    low_lanes = emitter.AllLanes(registers.Low(offset))
    high_lanes = emitter.AllLanes(registers.High(offset))
    source = emitter.Dereference(lhs, None)
    emitter.EmitVLoadA('1.32', [low_lanes, high_lanes], source)
    return offset
|
|
|
|
|
|
def ReadRight(emitter, registers, rhs, count):
    """Emit the load of ``count`` right-hand side offsets.

    Args:
      emitter: code emitter that receives the load.
      registers: register allocator providing the destination register.
      rhs: register holding the address to read from.
      count: number of 32 bit offsets needed; 1 or 2 use a double register,
        3 or 4 use a quad register.

    Returns:
      The register that receives the loaded offsets.

    Raises:
      ConfigurationError: if ``count`` is not 1, 2, 3 or 4.
    """
    if count in (1, 2):
        allocate = registers.DoubleRegister
    elif count in (3, 4):
        allocate = registers.QuadRegister
    else:
        raise ConfigurationError('Unsupported elements no: %d' % count)

    offsets = allocate()
    source = emitter.Dereference(rhs, 64)
    emitter.EmitVLoad('1.32', offsets, source)
    return offsets
|
|
|
|
|
|
def DuplicateGeneralRegister(emitter, registers, general_register,
                             min_register):
    """Emit a '32' VDup of a general register into a new quad register.

    Args:
      emitter: code emitter that receives the instruction.
      registers: register allocator providing the quad register.
      general_register: source register to duplicate.
      min_register: forwarded unchanged to ``registers.QuadRegister``.

    Returns:
      The quad register that receives the duplicated value.
    """
    target = registers.QuadRegister(min_register)
    emitter.EmitVDup('32', target, general_register)
    return target
|
|
|
|
|
|
def GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators,
                                  result_type, lhs_add, rhs_add, lhs, rhs_1,
                                  rhs_2, results):
    """Generates assembly responsible for reducing the 4 way aggregators.

    The emitted code, in order: optionally loads the lhs/rhs offsets,
    pairwise-adds the ``4 + lanes_count`` aggregators into two result
    registers, optionally adds the offsets, optionally converts to float and
    scales, and stores the values through ``results``.

    Args:
      emitter: code emitter that receives the instructions.
      registers: register allocator; also used to map the 'result_scale'
        parameter when ``result_type`` is 'float'.
      lanes_count: number of aggregators beyond the first four. Callers are
        expected to pass 1, 2, 3 or 4; other values emit no store.
      aggregators: accumulator registers, ``4 + lanes_count`` of them.
        ``aggregators[0]`` and ``aggregators[1]`` are reused as outputs.
      result_type: 'float' enables the convert-and-scale step; any other
        value skips it.
      lhs_add: when true, the offset read from ``lhs`` is added.
      rhs_add: when true, the offsets read from ``rhs_1`` and ``rhs_2`` are
        added.
      lhs: register holding the lhs offset address.
      rhs_1: register holding the address of the first four rhs offsets.
      rhs_2: register holding the address of the remaining rhs offsets.
      results: register holding the destination address.

    Raises:
      ConfigurationError: from ``ReadRight`` when ``rhs_add`` is set and
        ``lanes_count`` is not 1, 2, 3 or 4.
    """
    # Compare by value. The previous identity check (`is 'float'`) only
    # matched when the caller's string was the very same interned object.
    is_float = result_type == 'float'

    if lhs_add:
        left_offset = ReadLeft(emitter, registers, lhs)
    else:
        left_offset = None

    if rhs_add:
        right_offset_1 = ReadRight(emitter, registers, rhs_1, 4)
        right_offset_2 = ReadRight(emitter, registers, rhs_2, lanes_count)
    else:
        right_offset_1 = None
        right_offset_2 = None

    if is_float:
        result_scale = DuplicateGeneralRegister(
            emitter, registers, registers.MapParameter('result_scale'), 4)
    else:
        result_scale = None

    emitter.EmitNewline()
    emitter.EmitComment('Horizontal reduce aggregators.')
    # First pass: fold each aggregator's high half into its low half.
    for aggregator in aggregators:
        emitter.EmitVPadd('u32', registers.Low(aggregator),
                          registers.Low(aggregator), registers.High(aggregator))

    # Second pass: combine the low halves of aggregators 0..3 into temp,
    # which reuses aggregators[0] as its destination.
    temp = aggregators[0]
    emitter.EmitVPadd('u32', registers.Low(temp), registers.Low(aggregators[0]),
                      registers.Low(aggregators[1]))
    emitter.EmitVPadd('u32', registers.High(temp), registers.Low(aggregators[2]),
                      registers.Low(aggregators[3]))

    # Aggregators 4.. are combined into temp_2, which reuses aggregators[1];
    # its low half was already consumed above. With an odd lane count the
    # last aggregator is paired with itself.
    if lanes_count == 1:
        temp_2 = registers.Low(aggregators[1])
        emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]),
                          registers.Low(aggregators[4]))
    elif lanes_count == 2:
        temp_2 = registers.Low(aggregators[1])
        emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]),
                          registers.Low(aggregators[5]))
    elif lanes_count == 3:
        temp_2 = aggregators[1]
        emitter.EmitVPadd('u32', registers.Low(temp_2),
                          registers.Low(aggregators[4]),
                          registers.Low(aggregators[5]))
        emitter.EmitVPadd('u32', registers.High(temp_2),
                          registers.Low(aggregators[6]),
                          registers.Low(aggregators[6]))
    elif lanes_count == 4:
        temp_2 = aggregators[1]
        emitter.EmitVPadd('u32', registers.Low(temp_2),
                          registers.Low(aggregators[4]),
                          registers.Low(aggregators[5]))
        emitter.EmitVPadd('u32', registers.High(temp_2),
                          registers.Low(aggregators[6]),
                          registers.Low(aggregators[7]))
    else:
        temp_2 = None

    if lhs_add:
        emitter.EmitNewline()
        emitter.EmitComment('Add lhs offsets to aggregated rows.')
        emitter.EmitVAdd('s32', temp, temp, left_offset)
        # temp_2 is a double register for 1 or 2 lanes, so only the low half
        # of the quad offset register is used there.
        if lanes_count == 1 or lanes_count == 2:
            emitter.EmitVAdd('s32', temp_2, temp_2, registers.Low(left_offset))
        elif lanes_count == 3 or lanes_count == 4:
            emitter.EmitVAdd('s32', temp_2, temp_2, left_offset)

    if rhs_add:
        emitter.EmitNewline()
        emitter.EmitComment('Add rhs offset to aggregated rows.')
        emitter.EmitVAdd('s32', temp, temp, right_offset_1)
        emitter.EmitVAdd('s32', temp_2, temp_2, right_offset_2)

    if is_float:
        emitter.EmitNewline()
        emitter.EmitComment('Convert to float and scale.')
        emitter.EmitVCvt('f32', 's32', temp, temp)
        emitter.EmitVCvt('f32', 's32', temp_2, temp_2)
        emitter.EmitVMul('f32', temp, temp, result_scale)
        if lanes_count == 1 or lanes_count == 2:
            emitter.EmitVMul('f32', temp_2, temp_2, registers.Low(result_scale))
        elif lanes_count == 3 or lanes_count == 4:
            emitter.EmitVMul('f32', temp_2, temp_2, result_scale)

    emitter.EmitNewline()
    emitter.EmitComment('Store results.')
    # Odd lane counts store the full registers first (with an incrementing
    # dereference) and then the single remaining lane separately.
    if lanes_count == 1:
        emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp)],
                            emitter.DereferenceIncrement(results, None))
        emitter.EmitVStore('1.32', emitter.Lane(temp_2, 0),
                           emitter.Dereference(results, None))
    elif lanes_count == 2:
        emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp),
                                     temp_2], emitter.Dereference(results, None))
    elif lanes_count == 3:
        emitter.EmitVStoreA(
            '1.32',
            [registers.Low(temp), registers.High(temp), registers.Low(temp_2)],
            emitter.DereferenceIncrement(results, None))
        emitter.EmitVStore('1.32', emitter.Lane(
            registers.High(temp_2), 0), emitter.Dereference(results, None))
    elif lanes_count == 4:
        emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp),
                                     registers.Low(temp_2), registers.High(temp_2)],
                            emitter.Dereference(results, None))
|
|
|
|
|
|
def BuildName(result_type, lhs_add, rhs_add, lanes):
    """Return the name of the generated function for a configuration.

    Args:
      result_type: result type name, embedded verbatim in the name.
      lhs_add: when true, '_lhsadd' is appended.
      rhs_add: when true, '_rhsadd' is appended (after '_lhsadd', if any).
      lanes: lane count, formatted as an integer into the name.

    Returns:
      The function name as a string.
    """
    pieces = ['mul_1x8_%dx8_%s' % (lanes, result_type)]
    if lhs_add:
        pieces.append('_lhsadd')
    if rhs_add:
        pieces.append('_rhsadd')
    return ''.join(pieces)
|
|
|
|
|
|
def CppResultType(result_type):
    """Return the C++ pointer type used for the result parameter.

    Args:
      result_type: 'int32' or 'float'.

    Returns:
      'std::int32_t*' for 'int32', 'float*' for 'float'.

    Raises:
      ConfigurationError: for any other ``result_type``.
    """
    # Compare by value, not identity: an equal string that is not the same
    # interned object (e.g. one built at runtime) must still be accepted.
    if result_type == 'int32':
        return 'std::int32_t*'
    elif result_type == 'float':
        return 'float*'
    else:
        raise ConfigurationError('Unsupported result type: %s' % result_type)
|
|
|
|
|
|
def GetParameters(result_type):
    """Return the C++ parameter list of the generated function.

    Args:
      result_type: result type name, translated by ``CppResultType``.

    Returns:
      A list of ``[cpp_type, name]`` pairs: lhs, rhs_1, rhs_2, count and
      result, followed by a float 'result_scale' when ``result_type`` is
      'float'.

    Raises:
      ConfigurationError: from ``CppResultType`` for an unsupported type.
    """
    params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs_1'],
              ['const std::uint8_t*', 'rhs_2'], ['std::int32_t', 'count'],
              [CppResultType(result_type), 'result']]
    # Compare by value: with an identity check an equal but distinct string
    # object would silently drop the scale parameter.
    if result_type == 'float':
        params.append(['float', 'result_scale'])
    return params
|
|
|
|
|
|
def GenerateAndClearAggregators(emitter, registers, aggregator_count):
    """Allocate the aggregator registers and emit the code that clears them.

    The first three aggregators are set from an immediate zero; each later
    one is set by a move from the aggregator allocated three positions
    earlier.

    Args:
      emitter: code emitter that receives the instructions.
      registers: register allocator providing the quad registers.
      aggregator_count: number of aggregators to allocate.

    Returns:
      The list of allocated quad registers, in allocation order.
    """
    emitter.EmitNewline()
    emitter.EmitComment('Clear aggregators.')

    cleared = []
    for index in range(aggregator_count):
        fresh = registers.QuadRegister()
        cleared.append(fresh)
        if index >= 3:
            source = cleared[index - 3]
        else:
            source = emitter.ImmediateConstant(0)
        emitter.EmitVMov('i32', fresh, source)

    emitter.EmitNewline()
    return cleared
|
|
|
|
|
|
def GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes_count):
    """Generates the 1xN multiplication primitive.

    Emits a complete 'inline void' function: signature, the asserts
    'count % 8 == 0' and 'count >= 8', and an asm block holding the
    prefetches, the aggregator setup, the multiply loop and the
    reduce-and-store code.

    Args:
      emitter: code emitter that receives the function.
      result_type: result type name, passed on to the name and parameter
        builders and to the reduce-and-store generator.
      lhs_add: whether the lhs offset is added to the results.
      rhs_add: whether the rhs offsets are added to the results.
      lanes_count: number of lanes beyond the fixed four; the function name
        carries ``lanes_count + 4``.

    Raises:
      ConfigurationError: if ``lanes_count`` is below 1 or above 4.
    """
    if lanes_count > 4 or lanes_count < 1:
        raise ConfigurationError('Lanes should be: 1, 2, 3 or 4.')

    total_lanes = lanes_count + 4
    function_name = BuildName(result_type, lhs_add, rhs_add, total_lanes)
    parameters = GetParameters(result_type)
    emitter.EmitFunctionBeginA(function_name, parameters, 'inline void')

    emitter.EmitAssert('count % 8 == 0')
    emitter.EmitAssert('count >= 8')
    emitter.EmitAsmBegin()

    registers = neon_emitter.NeonRegisters()

    # Parameters are mapped in this fixed order: count, lhs, rhs_1, rhs_2.
    count = registers.MapParameter('count')
    lhs = registers.MapParameter('lhs')
    rhs_1 = registers.MapParameter('rhs_1')
    rhs_2 = registers.MapParameter('rhs_2')

    for pointer in (lhs, rhs_1, rhs_2):
        emitter.EmitPld(pointer)

    aggregators = GenerateAndClearAggregators(emitter, registers, total_lanes)

    GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators,
                                  count, lhs, rhs_1, rhs_2)

    # 'result' is mapped only after the multiply loop has been generated.
    results = registers.MapParameter('result')
    GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators,
                                  result_type, lhs_add, rhs_add, lhs, rhs_1,
                                  rhs_2, results)

    mapped = registers.MappedParameters()
    clobbers = registers.Clobbers() + ['cc', 'memory']
    emitter.EmitAsmEnd(mapped, [], clobbers)
    emitter.EmitFunctionEnd()
|
|
|
|
|
|
def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
    """Emit the primitive for every supported lane count (1 through 4).

    Each generated function is followed by an emitted newline.

    Args:
      emitter: code emitter that receives the functions.
      result_type: result type name passed on to ``GenerateMul1x8Mx8``.
      lhs_add: whether the lhs offset is added to the results.
      rhs_add: whether the rhs offsets are added to the results.
    """
    for lanes_count in (1, 2, 3, 4):
        GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes_count)
        emitter.EmitNewline()
|
|
|
|
|
|
if __name__ == '__main__':
    # Run as a script: emit the int32 variants with both offset additions on.
    default_emitter = neon_emitter.NeonEmitter()
    GenerateFunctions(default_emitter, 'int32', True, True)
|