642 lines
23 KiB
Python
642 lines
23 KiB
Python
# Copyright 2016 The Gemmlowp Authors. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""."""
|
|
|
|
import common
|
|
|
|
|
|
def _ReadParams(emitter, registers, input_address, elements, min_register):
|
|
registers_count = (elements + 3) / 4
|
|
registers = [
|
|
registers.QuadRegister(min_register)
|
|
for unused_i in range(registers_count)
|
|
]
|
|
emitter.EmitVLoadAE(registers_count * 4, 32, registers, input_address, 64)
|
|
return registers
|
|
|
|
|
|
def _Duplicate(emitter, registers, rows, values):
|
|
"""Populate a grid of registers duplicating provided values."""
|
|
duplicated = []
|
|
for i in range(rows):
|
|
if i is rows - 1:
|
|
duplicated.append(values[0])
|
|
else:
|
|
duplicated.append(registers.QuadRegister())
|
|
|
|
emitter.EmitVDup('32', duplicated[i],
|
|
emitter.Lane(32, values[i / 4], i % 4))
|
|
|
|
return duplicated
|
|
|
|
|
|
def _DuplicateGeneralRegister(emitter, registers, value, min_register):
|
|
register = registers.QuadRegister(min_register)
|
|
emitter.EmitVDup('32', register, value)
|
|
return register
|
|
|
|
|
|
class _StaticQuantizationUInt8Transformation(object):
|
|
"""Calculate quantized values and cast back to uint8."""
|
|
|
|
def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
|
|
"""Load parameters and prepare duplicated registers."""
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('StaticQuantization::Prepare')
|
|
|
|
lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
|
|
self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
|
|
self.multiplicative_offset = _DuplicateGeneralRegister(
|
|
emitter, registers,
|
|
registers.MapParameter('multiplicative_offset',
|
|
'params.kernel.multiplicative_offset'), 4)
|
|
self.rounding_offset = _DuplicateGeneralRegister(
|
|
emitter, registers,
|
|
registers.MapParameter('rounding_offset',
|
|
'params.kernel.rounding_offset'), 4)
|
|
self.shift = _DuplicateGeneralRegister(
|
|
emitter, registers,
|
|
registers.MapParameter('shift', 'params.kernel.shift'), 4)
|
|
self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)
|
|
|
|
def Transform(self, emitter, registers, data, unused_kernel_m,
|
|
unused_kernel_n):
|
|
"""Quantize the data."""
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('StaticQuantization::Transform')
|
|
|
|
for (row, lhs_offset) in zip(data, self.lhs_offsets):
|
|
for row_register in row:
|
|
emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)
|
|
|
|
for row in data:
|
|
for (row_register, rhs_offset_register) in zip(row, self.rhs_offsets):
|
|
emitter.EmitVAdd('s32', row_register, row_register, rhs_offset_register)
|
|
|
|
for row in data:
|
|
for row_register in row:
|
|
emitter.EmitVMul('i32', row_register, row_register,
|
|
self.multiplicative_offset)
|
|
|
|
for row in data:
|
|
for row_register in row:
|
|
emitter.EmitVAdd('i32', row_register, row_register,
|
|
self.rounding_offset)
|
|
|
|
for row in data:
|
|
for row_register in row:
|
|
emitter.EmitVShl('s32', row_register, row_register, self.shift)
|
|
|
|
if len(data[0]) is 1:
|
|
for row in data:
|
|
emitter.EmitVQmovn('s32', row[0], row[0])
|
|
|
|
for row in data:
|
|
emitter.EmitVQmovun('s16', row[0], row[0])
|
|
|
|
return data
|
|
elif len(data[0]) is 2:
|
|
results = []
|
|
for row in data:
|
|
emitter.EmitVQmovn2('s32', row[0], row[0], row[1])
|
|
registers.FreeRegister(row[1])
|
|
results.append([row[0]])
|
|
|
|
for row in results:
|
|
emitter.EmitVQmovun('s16', row[0], row[0])
|
|
|
|
return results
|
|
else:
|
|
assert False
|
|
|
|
def Type(self):
|
|
return 8
|
|
|
|
|
|
class _StaticQuantizationInt32Transformation(object):
|
|
"""."""
|
|
|
|
def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('StaticQuantizationInt32::Prepare')
|
|
|
|
lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
|
|
self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
|
|
self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)
|
|
|
|
def Transform(self, emitter, unused_registers, data, unused_kernel_m,
|
|
unused_kernel_n):
|
|
"""Quantize data and output as int32."""
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('StaticQuantizationInt32::Transform')
|
|
|
|
for (row, lhs_offset) in zip(data, self.lhs_offsets):
|
|
for row_register in row:
|
|
emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)
|
|
|
|
for row in data:
|
|
for (row_register, rhs_offsets_register) in zip(row, self.rhs_offsets):
|
|
emitter.EmitVAdd('s32', row_register, row_register,
|
|
rhs_offsets_register)
|
|
|
|
return data
|
|
|
|
def Type(self):
|
|
return 32
|
|
|
|
|
|
class _StaticQuantizationFloatTransformation(object):
|
|
"""."""
|
|
|
|
def Prepare(self, emitter, registers, kernel_m, kernel_n, lhs, rhs):
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('StaticQuantizationFloat::Prepare')
|
|
|
|
lhs_offset = _ReadParams(emitter, registers, lhs, kernel_m, 4)
|
|
self.rhs_offsets = _ReadParams(emitter, registers, rhs, kernel_n, 4)
|
|
self.scale = _DuplicateGeneralRegister(
|
|
emitter, registers,
|
|
registers.MapParameter('scale', 'params.kernel.scale'), 4)
|
|
self.lhs_offsets = _Duplicate(emitter, registers, kernel_m, lhs_offset)
|
|
|
|
def Transform(self, emitter, unused_registers, data, unused_kernel_m,
|
|
unused_kernel_n):
|
|
"""Quantize data and output as float."""
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('StaticQuantizationFloat::Transform')
|
|
|
|
for (row, lhs_offset) in zip(data, self.lhs_offsets):
|
|
for row_register in row:
|
|
emitter.EmitVAdd('s32', row_register, row_register, lhs_offset)
|
|
|
|
for row in data:
|
|
for (row_register, rhs_offsets_register) in zip(row, self.rhs_offsets):
|
|
emitter.EmitVAdd('s32', row_register, row_register,
|
|
rhs_offsets_register)
|
|
|
|
for row in data:
|
|
for row_register in row:
|
|
emitter.EmitVCvt('f32', 's32', row_register, row_register)
|
|
|
|
for row in data:
|
|
for row_register in row:
|
|
emitter.EmitVMul('f32', row_register, row_register, self.scale)
|
|
|
|
return data
|
|
|
|
def Type(self):
|
|
return 32
|
|
|
|
|
|
class _RowMajorOutput(object):
|
|
"""Output data in row major layout."""
|
|
|
|
def Prepare(self, emitter, registers, kernel_m, unused_kernel_n,
|
|
unused_data_type):
|
|
"""Prepare strided load addresses."""
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('RowMajorOutput::Prepare')
|
|
|
|
stride = registers.MapParameter('stride', 'params.output_stream.stride')
|
|
|
|
self.outputs = []
|
|
self.outputs.append(registers.MapOutputParameter('result'))
|
|
|
|
for unused_i in range(kernel_m - 1):
|
|
register = registers.GeneralRegister()
|
|
emitter.EmitAdd(register, self.outputs[-1], stride)
|
|
self.outputs.append(register)
|
|
|
|
def Output(self, emitter, unused_registers, data, data_type, unused_kernel_m,
|
|
kernel_n):
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('RowMajorOutput::Output')
|
|
|
|
for (datum, output) in zip(data, self.outputs):
|
|
emitter.EmitVStoreAE(data_type, kernel_n, datum, output, None)
|
|
|
|
|
|
def _GenerateAndClearAggregators(emitter, registers, count):
|
|
"""Prepare aggregators and emit aggregator clear code."""
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('Clear aggregators.')
|
|
aggregators = [registers.QuadRegister() for unused_i in range(count)]
|
|
for i in range(count):
|
|
if i < 3:
|
|
emitter.EmitVMov('i32', aggregators[i], emitter.ImmediateConstant(0))
|
|
else:
|
|
emitter.EmitVMov('i32', aggregators[i], aggregators[i - 3])
|
|
return aggregators
|
|
|
|
|
|
def _Generate3x3LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
|
|
count):
|
|
"""Emit inner loop for 3 rows x 3 cols multiplication."""
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('3x3 lanes loop.')
|
|
emitter.EmitNumericalLabel(1)
|
|
emitter.EmitNewline()
|
|
|
|
lhs_load = [registers.DoubleRegister() for unused_i in range(3)]
|
|
rhs_load = [registers.DoubleRegister() for unused_i in range(3)]
|
|
temp = [registers.QuadRegister() for unused_i in range(4)]
|
|
|
|
emitter.EmitVLoadA(1, 8, rhs_load, emitter.DereferenceIncrement(rhs, 64))
|
|
emitter.EmitVLoad(1, 8, lhs_load[0], emitter.DereferenceIncrement(lhs, 64))
|
|
|
|
emitter.EmitVMull('u8', temp[0], lhs_load[0], rhs_load[0])
|
|
emitter.EmitVLoad(1, 8, lhs_load[1], emitter.DereferenceIncrement(lhs, 64))
|
|
|
|
emitter.EmitVMull('u8', temp[1], lhs_load[0], rhs_load[1])
|
|
emitter.EmitVLoad(1, 8, lhs_load[2], emitter.DereferenceIncrement(lhs, 64))
|
|
|
|
emitter.EmitVMull('u8', temp[2], lhs_load[0], rhs_load[2])
|
|
emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
|
|
|
|
emitter.EmitVMull('u8', temp[3], lhs_load[1], rhs_load[0])
|
|
emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))
|
|
|
|
emitter.EmitVPadal('u16', aggregators[0], temp[0])
|
|
emitter.EmitVPadal('u16', aggregators[1], temp[1])
|
|
emitter.EmitVPadal('u16', aggregators[2], temp[2])
|
|
emitter.EmitVPadal('u16', aggregators[3], temp[3])
|
|
|
|
emitter.EmitVMull('u8', temp[0], lhs_load[1], rhs_load[1])
|
|
emitter.EmitVMull('u8', temp[1], lhs_load[1], rhs_load[2])
|
|
|
|
registers.FreeRegisters([lhs_load[0], lhs_load[1]])
|
|
temp.append(registers.QuadRegister())
|
|
|
|
emitter.EmitVMull('u8', temp[2], lhs_load[2], rhs_load[0])
|
|
emitter.EmitVMull('u8', temp[3], lhs_load[2], rhs_load[1])
|
|
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('Subtract counter.')
|
|
emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
|
|
emitter.EmitNewline()
|
|
|
|
emitter.EmitVMull('u8', temp[4], lhs_load[2], rhs_load[2])
|
|
|
|
emitter.EmitVPadal('u16', aggregators[4], temp[0])
|
|
emitter.EmitVPadal('u16', aggregators[5], temp[1])
|
|
emitter.EmitVPadal('u16', aggregators[6], temp[2])
|
|
emitter.EmitVPadal('u16', aggregators[7], temp[3])
|
|
emitter.EmitVPadal('u16', aggregators[8], temp[4])
|
|
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('Loop break.')
|
|
emitter.EmitBgtBack(1)
|
|
|
|
registers.FreeRegisters(temp + [lhs_load[2]] + rhs_load)
|
|
|
|
|
|
def _Generate2x4LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
|
|
count):
|
|
"""Emit inner loop for 2 rows x 4 cols multiplication."""
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('2x4 lanes loop.')
|
|
emitter.EmitNumericalLabel(1)
|
|
emitter.EmitNewline()
|
|
|
|
lhs_load = [registers.DoubleRegister() for unused_i in range(2)]
|
|
rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
|
|
temp = [registers.QuadRegister() for unused_i in range(5)]
|
|
|
|
emitter.EmitVLoadA(1, 8, rhs_load, emitter.DereferenceIncrement(rhs, 256))
|
|
emitter.EmitVLoad(1, 8, lhs_load[0], emitter.DereferenceIncrement(lhs, 64))
|
|
|
|
emitter.EmitVMull('u8', temp[0], lhs_load[0], rhs_load[0])
|
|
emitter.EmitVLoad(1, 8, lhs_load[1], emitter.DereferenceIncrement(lhs, 64))
|
|
|
|
emitter.EmitVMull('u8', temp[1], lhs_load[0], rhs_load[1])
|
|
emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))
|
|
|
|
emitter.EmitVMull('u8', temp[2], lhs_load[0], rhs_load[2])
|
|
emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
|
|
|
|
emitter.EmitVMull('u8', temp[3], lhs_load[0], rhs_load[3])
|
|
emitter.EmitVMull('u8', temp[4], lhs_load[1], rhs_load[0])
|
|
|
|
emitter.EmitVPadal('u16', aggregators[0], temp[0])
|
|
emitter.EmitVPadal('u16', aggregators[1], temp[1])
|
|
emitter.EmitVPadal('u16', aggregators[2], temp[2])
|
|
|
|
emitter.EmitVMull('u8', temp[0], lhs_load[1], rhs_load[1])
|
|
emitter.EmitVMull('u8', temp[1], lhs_load[1], rhs_load[2])
|
|
emitter.EmitVMull('u8', temp[2], lhs_load[1], rhs_load[3])
|
|
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('Subtract counter.')
|
|
emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
|
|
|
|
emitter.EmitNewline()
|
|
emitter.EmitVPadal('u16', aggregators[3], temp[3])
|
|
emitter.EmitVPadal('u16', aggregators[4], temp[4])
|
|
emitter.EmitVPadal('u16', aggregators[5], temp[0])
|
|
emitter.EmitVPadal('u16', aggregators[6], temp[1])
|
|
emitter.EmitVPadal('u16', aggregators[7], temp[2])
|
|
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('Loop break.')
|
|
emitter.EmitBgtBack(1)
|
|
|
|
registers.FreeRegisters(temp + lhs_load + rhs_load)
|
|
|
|
|
|
def _Generate1x8LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
|
|
count):
|
|
"""Emit inner loop for 1 rows x 8 cols multiplication."""
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('1x8 lanes loop.')
|
|
emitter.EmitNumericalLabel(1)
|
|
emitter.EmitNewline()
|
|
|
|
lhs_load = registers.DoubleRegister()
|
|
rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
|
|
temp = [registers.QuadRegister() for unused_i in range(5)]
|
|
|
|
emitter.EmitVLoadAE(4 * 8, 8, rhs_load, rhs, 256)
|
|
emitter.EmitVLoadE(8, 8, lhs_load, lhs, 64)
|
|
|
|
emitter.EmitVMull('u8', temp[0], lhs_load, rhs_load[0])
|
|
emitter.EmitVMull('u8', temp[1], lhs_load, rhs_load[1])
|
|
emitter.EmitVMull('u8', temp[2], lhs_load, rhs_load[2])
|
|
emitter.EmitVMull('u8', temp[3], lhs_load, rhs_load[3])
|
|
|
|
emitter.EmitVLoadAE(4 * 8, 8, rhs_load, rhs, 256)
|
|
|
|
emitter.EmitVPadal('u16', aggregators[0], temp[0])
|
|
emitter.EmitVPadal('u16', aggregators[1], temp[1])
|
|
emitter.EmitVPadal('u16', aggregators[2], temp[2])
|
|
emitter.EmitVPadal('u16', aggregators[3], temp[3])
|
|
|
|
emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(256))
|
|
|
|
emitter.EmitVMull('u8', temp[4], lhs_load, rhs_load[0])
|
|
emitter.EmitVMull('u8', temp[0], lhs_load, rhs_load[1])
|
|
emitter.EmitVMull('u8', temp[1], lhs_load, rhs_load[2])
|
|
emitter.EmitVMull('u8', temp[2], lhs_load, rhs_load[3])
|
|
|
|
emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(32))
|
|
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('Subtract counter.')
|
|
emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
|
|
|
|
emitter.EmitNewline()
|
|
emitter.EmitVPadal('u16', aggregators[4], temp[4])
|
|
emitter.EmitVPadal('u16', aggregators[5], temp[0])
|
|
emitter.EmitVPadal('u16', aggregators[6], temp[1])
|
|
emitter.EmitVPadal('u16', aggregators[7], temp[2])
|
|
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('Loop break.')
|
|
emitter.EmitBgtBack(1)
|
|
|
|
registers.FreeRegisters(temp + [lhs_load] + rhs_load)
|
|
|
|
|
|
def _GenerateNxMLoadMultiplyAggregate(emitter, registers, kernel_m, kernel_n,
|
|
aggregators, lhs, rhs, count):
|
|
"""Emit inner loop for N rows x M cols multiplication."""
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('General NxM lanes loop.')
|
|
emitter.EmitNumericalLabel(1)
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('Subtract counter.')
|
|
emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
|
|
emitter.EmitNewline()
|
|
|
|
lhs_load = [registers.DoubleRegister() for unused_i in range(kernel_m)]
|
|
rhs_load = [registers.DoubleRegister() for unused_i in range(kernel_n)]
|
|
|
|
emitter.EmitVLoadAE(8 * kernel_m, 8, lhs_load, lhs, 64)
|
|
emitter.EmitVLoadAE(8 * kernel_n, 8, rhs_load, rhs, 64)
|
|
|
|
emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
|
|
emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(64))
|
|
|
|
results = [
|
|
registers.QuadRegister() for unused_i in range(kernel_m * kernel_n)
|
|
]
|
|
|
|
for row in range(kernel_m):
|
|
for col in range(kernel_n):
|
|
index = row * kernel_n + col
|
|
emitter.EmitVMull('u8', results[index], rhs_load[col], lhs_load[row])
|
|
|
|
for i in range(kernel_m * kernel_n):
|
|
emitter.EmitVPadal('u16', aggregators[i], results[i])
|
|
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('Loop break.')
|
|
emitter.EmitBgtBack(1)
|
|
|
|
registers.FreeRegisters(lhs_load + rhs_load + results)
|
|
|
|
|
|
def _Generate1xNLoadMultiplyAggregate(emitter, registers, kernel_n, aggregators,
|
|
lhs, rhs, count):
|
|
"""Emit inner loop for 1 row x M cols multiplication."""
|
|
assert kernel_n in [5, 6, 7, 8]
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('General 1xM lanes loop.')
|
|
emitter.EmitNumericalLabel(1)
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('Subtract counter.')
|
|
emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))
|
|
emitter.EmitNewline()
|
|
|
|
leftover = kernel_n - 4
|
|
|
|
rhs_load = [registers.DoubleRegister() for unused_i in range(4)]
|
|
lhs_load = registers.DoubleRegister()
|
|
|
|
emitter.EmitVLoadAE(8 * 4, 8, rhs_load, rhs, 64)
|
|
emitter.EmitVLoadE(8, 8, lhs_load, lhs, 64)
|
|
|
|
emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
|
|
|
|
results = [registers.QuadRegister() for unused_i in range(4)]
|
|
|
|
for i in range(4):
|
|
emitter.EmitVMull('u8', results[i], rhs_load[i], lhs_load)
|
|
|
|
emitter.EmitVLoadAE(8 * leftover, 8, rhs_load, rhs, 64)
|
|
emitter.EmitPldOffset(rhs, emitter.ImmediateConstant(128))
|
|
|
|
for i in range(4):
|
|
emitter.EmitVPadal('u16', aggregators[i], results[i])
|
|
|
|
for i in range(leftover):
|
|
emitter.EmitVMull('u8', results[i], rhs_load[i], lhs_load)
|
|
|
|
for i in range(leftover):
|
|
emitter.EmitVPadal('u16', aggregators[i + 4], results[i])
|
|
|
|
emitter.EmitNewline()
|
|
emitter.EmitComment('Loop break.')
|
|
emitter.EmitBgtBack(1)
|
|
|
|
registers.FreeRegisters([lhs_load] + rhs_load + results)
|
|
|
|
|
|
def _GenerateMultiplyKernel(emitter, registers, kernel_m, kernel_n, lhs, rhs):
  """Main multiply loop. Pick best implementation for given kernel shape.

  Args:
    emitter: assembly emitter used to output the loop.
    registers: register allocator and parameter mapper.
    kernel_m: number of lhs rows.
    kernel_n: number of rhs columns.
    lhs: register holding the lhs read address.
    rhs: register holding the rhs read address.

  Returns:
    The list of kernel_m * kernel_n aggregator registers the emitted loop
    accumulates into.
  """
  count = registers.MapParameter('count', 'params.kernel.count')

  aggregators = _GenerateAndClearAggregators(emitter, registers,
                                             kernel_m * kernel_n)
  # Shapes are compared by value: `is` on integers only happens to work
  # because of the interpreter's small-integer caching.
  if kernel_m == 3 and kernel_n == 3:
    _Generate3x3LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count)
  elif kernel_m == 2 and kernel_n == 4:
    _Generate2x4LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count)
  elif kernel_m == 1 and kernel_n == 8:
    _Generate1x8LoadMultiplyAggregate(emitter, registers, aggregators, lhs, rhs,
                                      count)
  elif kernel_m == 1 and kernel_n > 4:
    _Generate1xNLoadMultiplyAggregate(emitter, registers, kernel_n, aggregators,
                                      lhs, rhs, count)
  else:
    # No specialized loop for this shape; use the generic one.
    _GenerateNxMLoadMultiplyAggregate(emitter, registers, kernel_m, kernel_n,
                                      aggregators, lhs, rhs, count)
  return aggregators
|
|
|
|
|
|
def _ReduceAggregators(emitter, aggregators):
|
|
reduced_count = (len(aggregators) + 3) / 4
|
|
reduced = aggregators[:reduced_count]
|
|
emitter.EmitVSumReduce('u32', len(aggregators), 4, reduced, aggregators)
|
|
return reduced
|
|
|
|
|
|
def _GenerateAggregatorReduce(emitter, aggregators, kernel_m, kernel_n):
  """Emit the reduction of the aggregators, one kernel row at a time.

  Args:
    emitter: assembly emitter used to output the reduction code.
    aggregators: accumulator registers, kernel_n consecutive entries per row.
    kernel_m: number of rows.
    kernel_n: number of aggregators per row.

  Returns:
    A list with one entry per row: the registers returned by
    _ReduceAggregators for that row's aggregators.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Reduce aggregators.')

  reduced_rows = []
  for row_index in range(kernel_m):
    row_start = row_index * kernel_n
    row_aggregators = aggregators[row_start:row_start + kernel_n]
    reduced_rows.append(_ReduceAggregators(emitter, row_aggregators))
  return reduced_rows
|
|
|
|
|
|
class QuantizedMulKernel(common.MulKernelGenerator):
  """Emits a quantized multiplication kernel.

  The kernel is assembled from the multiply loop, a fused transformation that
  post-processes the reduced accumulators, and an output strategy that stores
  the transformed results.
  """

  def __init__(self, cc_emitter, kernel_name, output_stream_name, asm_emitter,
               fused_transformation, output_strategy):
    """Store the collaborators used by EmitMultiply.

    Args:
      cc_emitter: emitter forwarded to the base class constructor.
      kernel_name: kernel name forwarded to the base class constructor.
      output_stream_name: output stream name forwarded to the base class
        constructor.
      asm_emitter: assembly emitter used to output the kernel body.
      fused_transformation: object providing Prepare, Transform and Type.
      output_strategy: object providing Prepare and Output.
    """
    common.MulKernelGenerator.__init__(self, cc_emitter, kernel_name,
                                       output_stream_name)
    self.asm_emitter = asm_emitter
    self.fused_transformation = fused_transformation
    self.output_strategy = output_strategy

  def EmitMultiply(self, in_type, out_type, kernel_m, kernel_n, pack_size):
    """Emit the assembly block of one kernel_m x kernel_n kernel.

    Args:
      in_type: input element type name; must equal 'uint8_t'.
      out_type: not used by this method.
      kernel_m: number of lhs rows.
      kernel_n: number of rhs columns.
      pack_size: must equal 8.

    Raises:
      AssertionError: if in_type or pack_size is unsupported, or if
        kernel_m * kernel_n exceeds 9.
    """
    # Compare by value: `is` on a string or integer depends on interning and
    # can spuriously reject a perfectly valid argument.
    assert in_type == 'uint8_t'
    assert pack_size == 8
    assert kernel_m * kernel_n <= 9

    registers = self.asm_emitter.CreateRegisters()

    # NOTE(review): self.emitter is not assigned in this class; presumably
    # the base class constructor sets it - confirm in common.
    self.asm_emitter.PushIndent(self.emitter.indent)
    self.asm_emitter.EmitAsmBegin()

    lhs = registers.MapOutputParameter('lhs')
    rhs = registers.MapOutputParameter('rhs')
    self.asm_emitter.EmitPld(lhs)
    self.asm_emitter.EmitPld(rhs)

    aggregators = _GenerateMultiplyKernel(self.asm_emitter, registers, kernel_m,
                                          kernel_n, lhs, rhs)

    self.fused_transformation.Prepare(self.asm_emitter, registers, kernel_m,
                                      kernel_n, lhs, rhs)

    self.output_strategy.Prepare(self.asm_emitter, registers, kernel_m,
                                 kernel_n, self.fused_transformation.Type())

    reduced = _GenerateAggregatorReduce(self.asm_emitter, aggregators, kernel_m,
                                        kernel_n)

    transformed = self.fused_transformation.Transform(self.asm_emitter,
                                                      registers, reduced,
                                                      kernel_m, kernel_n)

    self.output_strategy.Output(self.asm_emitter, registers, transformed,
                                self.fused_transformation.Type(), kernel_m,
                                kernel_n)

    self.asm_emitter.EmitAsmEnd(registers)
    self.asm_emitter.PopIndent(len(self.emitter.indent))
|
|
|
|
|
|
class QuantizedMulStaticRowMajor(QuantizedMulKernel):
  """'QuantizedStaticPreprocessed' kernel generator with row major output.

  Pairs the uint8 static quantization transformation with the row major
  output strategy.
  """

  def __init__(self, cc_emitter, asm_emitter):
    """Configure the base kernel generator for uint8 row major output."""
    transformation = _StaticQuantizationUInt8Transformation()
    output_strategy = _RowMajorOutput()
    QuantizedMulKernel.__init__(
        self, cc_emitter, 'QuantizedStaticPreprocessed', 'RowMajor',
        asm_emitter, transformation, output_strategy)
|
|
|
|
|
|
class QuantizedMulStaticAsInt32RowMajor(QuantizedMulKernel):
  """'QuantizedStaticPreprocessedAsInt32' kernel generator, row major output.

  Pairs the int32 static quantization transformation with the row major
  output strategy.
  """

  def __init__(self, cc_emitter, asm_emitter):
    """Configure the base kernel generator for int32 row major output."""
    transformation = _StaticQuantizationInt32Transformation()
    output_strategy = _RowMajorOutput()
    QuantizedMulKernel.__init__(
        self, cc_emitter, 'QuantizedStaticPreprocessedAsInt32', 'RowMajor',
        asm_emitter, transformation, output_strategy)
|
|
|
|
|
|
class QuantizedMulStaticAsFloatRowMajor(QuantizedMulKernel):
  """'QuantizedStaticPreprocessedAsFloat' kernel generator, row major output.

  Pairs the float static quantization transformation with the row major
  output strategy.
  """

  def __init__(self, cc_emitter, asm_emitter):
    """Configure the base kernel generator for float row major output."""
    transformation = _StaticQuantizationFloatTransformation()
    output_strategy = _RowMajorOutput()
    QuantizedMulKernel.__init__(
        self, cc_emitter, 'QuantizedStaticPreprocessedAsFloat', 'RowMajor',
        asm_emitter, transformation, output_strategy)
|
|
|
|
|
|
def GenerateKernels(cc_emitter, asm_emitter, shapes):
  """Generate the quantized multiplication kernels for uint8 operands.

  All shapes are specialized for the uint8 output kernel first, then for the
  int32 output kernel, then for the float output kernel.

  Args:
    cc_emitter: emitter handed to every kernel generator.
    asm_emitter: assembly emitter handed to every kernel generator.
    shapes: indexable (rows, cols) pairs. The collection is iterated three
      times, once per output type.
  """
  uint8_generator = QuantizedMulStaticRowMajor(cc_emitter, asm_emitter)
  int32_generator = QuantizedMulStaticAsInt32RowMajor(cc_emitter, asm_emitter)
  float_generator = QuantizedMulStaticAsFloatRowMajor(cc_emitter, asm_emitter)

  generators_by_output = (
      (uint8_generator, 'uint8_t'),
      (int32_generator, 'int32_t'),
      (float_generator, 'float'),
  )
  for generator, out_type in generators_by_output:
    for shape in shapes:
      generator.SpecializeMulKernel('uint8_t', out_type, shape[0], shape[1],
                                    8)
|