/// One 3x3 convolution output channel followed by a leaky ReLU (negative slope 0.1).
///
/// Each thread produces one texel of `out`: it sums, over every slice of `in`,
/// the 3x3 neighbourhood around `gid` (red channel only) weighted by that
/// slice's 3x3 kernel, adds `bias`, applies the activation and stores the
/// result in the red channel of `out` (the other channels are written as 0).
///
/// @param in       Input feature maps, one per array slice; only `.r` is read.
/// @param out      Destination texture; receives (p, 0, 0, 0) at `gid`.
/// @param weights  One float3x3 per input slice. `weights[i][0]` is applied to
///                 the row above `gid`, `[1]` to the row of `gid`, `[2]` to the
///                 row below. Assumed to hold at least `in.get_array_size()`
///                 entries -- this is not checked here.
/// @param bias     Scalar added once to the accumulated sum.
/// @param gid      Thread position in the grid, used as the texel coordinate.
///
/// Neighbour coordinates are clamped to the texture edge, so threads on the
/// border reuse the nearest valid texel instead of reading outside `in`.
kernel void waifu2x(texture2d_array<float, access::read> in [[texture(0)]],
                    texture2d<float, access::write> out [[texture(1)]],
                    constant float3x3* weights [[buffer(0)]],
                    constant float& bias [[buffer(1)]],
                    uint2 gid [[thread_position_in_grid]])
{
    const uint width = in.get_width();
    const uint height = in.get_height();

    // The grid may be larger than either texture; skip threads that have no
    // texel to read from `in` or no texel to write in `out`.
    if (gid.x >= width || gid.y >= height ||
        gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }

    // Clamp the neighbour coordinates in unsigned arithmetic. `gid - 1` would
    // wrap around at 0, and `gid + 1` would step past the last column/row.
    // width/height are >= 1 here because gid passed the guard above.
    const uint xLeft = gid.x > 0 ? gid.x - 1 : 0;
    const uint xMid = gid.x;
    const uint xRight = min(gid.x + 1, width - 1);
    const uint yTop = gid.y > 0 ? gid.y - 1 : 0;
    const uint yMid = gid.y;
    const uint yBottom = min(gid.y + 1, height - 1);

    const uint sliceCount = in.get_array_size();

    float partial = bias;
    for (uint i = 0; i < sliceCount; ++i) {
        const float3 top = float3(in.read(uint2(xLeft, yTop), i).r,
                                  in.read(uint2(xMid, yTop), i).r,
                                  in.read(uint2(xRight, yTop), i).r);
        const float3 middle = float3(in.read(uint2(xLeft, yMid), i).r,
                                     in.read(uint2(xMid, yMid), i).r,
                                     in.read(uint2(xRight, yMid), i).r);
        const float3 bottom = float3(in.read(uint2(xLeft, yBottom), i).r,
                                     in.read(uint2(xMid, yBottom), i).r,
                                     in.read(uint2(xRight, yBottom), i).r);

        const float3x3 weight = weights[i];
        partial += dot(top, weight[0])
                 + dot(middle, weight[1])
                 + dot(bottom, weight[2]);
    }

    // Leaky ReLU: positive values pass through, negative values are scaled by 0.1.
    const float p = fmax(partial, 0.0f) + 0.1f * fmin(partial, 0.0f);
    out.write(float4(p, 0.0f, 0.0f, 0.0f), gid);
}