// daedalus-fourier — H.264 8x8 inverse integer transform + add, V3D 7.1.
//
// H.264 spec §8.5.13.2 (High profile 8x8 IT). Pure integer arithmetic
// (adds, subtracts, arithmetic right shifts) — a different butterfly from
// the VP9 IDCT 8x8 (cycle 1, which uses cospi multipliers). Row pass first,
// column pass second; each output is rounded with (x + 32) >> 6, added to
// the existing dst pixel, and clamped to the u8 range [0, 255].
//
// Block layout: COLUMN-MAJOR. block[c*8 + r] = coefficient at
// (row r, column c). Matches FFmpeg `ff_h264_idct8_add_neon`.
//
// Workgroup layout: 64 invocations = 8 lanes/block × 8 blocks/WG.
//   - row pass:    lane k (0..7) reads row k of the block (8 coefficients,
//                  one from each column), runs the butterfly, and writes 8
//                  outputs to one row of tmp_shared.
//   - column pass: lane k reads column k of tmp_shared (8 rows),
//                  runs the butterfly, and writes 8 outputs to dst as
//                  column k at rows 0..7.
//
// shared = 8 × 64 × 4 B = 2 KiB. Well under V3D's 16 KiB limit.
//
// License: BSD-2-Clause.

#version 450
// The 8/16-bit storage extensions and explicit arithmetic types are needed
// for the int16_t coefficient and uint8_t pixel buffer members below.
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types : require

// One workgroup = 8 blocks × 8 invocations per block.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// Binding 0: transform coefficients, 64 int16 values per block in the
// column-major layout described above; block i occupies coeffs[i*64 .. i*64+63].
layout(binding = 0) readonly buffer Coeffs {
    int16_t coeffs[];   // N × 64 column-major
} u_coeffs;

// Binding 1: destination pixels, one byte each, rows pc.dst_stride_u8 bytes
// apart. Updated in place: the shader reads each pixel, adds the residual,
// and writes the clamped result back.
layout(binding = 1) buffer Dst {
    uint8_t dst[];      // H × stride bytes
} u_dst;

// Binding 2: per-block metadata, indexed by block number. Only .x is read by
// this shader: the byte offset in dst of the block's (row 0, column 0) pixel.
// .y/.z/.w are not read here.
layout(binding = 2) readonly buffer Meta {
    uvec4 meta[];       // .x = dst_off
} u_meta;

// n_blocks:      number of valid blocks; invocations whose block index is
//                >= n_blocks do no loads or stores.
// dst_stride_u8: byte distance between consecutive dst rows.
// _pad0/_pad1:   not read; presumably pad the block to 16 bytes — confirm
//                against the host-side push-constant struct.
layout(push_constant) uniform PC {
    uint n_blocks;
    uint dst_stride_u8;
    uint _pad0, _pad1;
} pc;

// 8 blocks/WG × 64 ints/block × 4 B = 2 KiB shared.
// Holds the row-pass output: block slot s, row r, column c lives at
// tmp_shared[s*64 + r*8 + c] (row-major within each 64-int tile).
shared int tmp_shared[8 * 64];

// 1D 8-element butterfly per H.264 §8.5.13.2.
// One 8-point pass of the H.264 8x8 inverse transform (spec §8.5.13.2).
// d0..d7 are one row or one column of the block in natural order; g0..g7
// receive the transformed values in the same order. Only adds, subtracts and
// arithmetic right shifts are used (GLSL `>>` on int sign-extends).
void idct8_1d(int d0, int d1, int d2, int d3, int d4, int d5, int d6, int d7,
              out int g0, out int g1, out int g2, out int g3,
              out int g4, out int g5, out int g6, out int g7)
{
    // Even half: depends only on d0, d2, d4, d6.
    int sum04  = d0 + d4;
    int dif04  = d0 - d4;
    int mix26a = (d2 >> 1) - d6;
    int mix26b = d2 + (d6 >> 1);

    int even0 = sum04 + mix26b;
    int even1 = dif04 + mix26a;
    int even2 = dif04 - mix26a;
    int even3 = sum04 - mix26b;

    // Odd half: depends only on d1, d3, d5, d7.
    int t1 = d5 - d3 - d7 - (d7 >> 1);
    int t3 = d1 - d3 + d7 - (d3 >> 1);
    int t5 = d5 - d1 + d7 + (d5 >> 1);
    int t7 = d1 + d3 + d5 + (d1 >> 1);

    int odd0 = t1 + (t7 >> 2);
    int odd1 = t3 + (t5 >> 2);
    int odd2 = (t3 >> 2) - t5;
    int odd3 = t7 - (t1 >> 2);

    // Mirror-symmetric recombination: g[i] and g[7-i] share one even/odd pair.
    g0 = even0 + odd3;  g7 = even0 - odd3;
    g1 = even1 + odd2;  g6 = even1 - odd2;
    g2 = even2 + odd1;  g5 = even2 - odd1;
    g3 = even3 + odd0;  g4 = even3 - odd0;
}

// Entry point. Each workgroup of 64 invocations processes 8 blocks, with
// 8 invocations ("lanes") per block.
//   Pass 1: lane L transforms row L of its block and stores the 8 results
//           as row L of the block's tile in tmp_shared.
//   Pass 2: lane L transforms column L of that tile, rounds each result with
//           (x + 32) >> 6, adds it to the matching dst byte in column L,
//           and clamps the sum to [0, 255].
void main()
{
    uint inv  = gl_GlobalInvocationID.x;
    uint slot = (inv % 64u) / 8u;           // which of this WG's 8 blocks
    uint lane = inv % 8u;                   // row in pass 1, column in pass 2
    uint blk  = (inv / 64u) * 8u + slot;    // global block index
    uint tile = slot * 64u;                 // this block's 64-int window in tmp_shared
    bool in_range = blk < pc.n_blocks;

    // ---- Pass 1: rows ---------------------------------------------
    if (in_range) {
        // Column-major input: (row lane, column c) is at blk*64 + c*8 + lane.
        uint src = blk * 64u + lane;
        int row[8];
        for (uint c = 0u; c < 8u; ++c) {
            row[c] = int(u_coeffs.coeffs[src + c * 8u]);
        }

        int xf[8];
        idct8_1d(row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7],
                 xf[0], xf[1], xf[2], xf[3], xf[4], xf[5], xf[6], xf[7]);

        for (uint c = 0u; c < 8u; ++c) {
            tmp_shared[tile + lane * 8u + c] = xf[c];
        }
    }

    // Every row of every block must be in tmp_shared before any column is
    // read. Placed outside both branches so all invocations reach it.
    barrier();

    // ---- Pass 2: columns ------------------------------------------
    if (in_range) {
        int col[8];
        for (uint r = 0u; r < 8u; ++r) {
            col[r] = tmp_shared[tile + r * 8u + lane];
        }

        int res[8];
        idct8_1d(col[0], col[1], col[2], col[3], col[4], col[5], col[6], col[7],
                 res[0], res[1], res[2], res[3], res[4], res[5], res[6], res[7]);

        uint stride = pc.dst_stride_u8;
        uint top    = u_meta.meta[blk].x + lane;   // byte address of (row 0, column lane)

        // Gather all 8 destination pixels before storing any of them.
        int pred[8];
        for (uint r = 0u; r < 8u; ++r) {
            pred[r] = int(u_dst.dst[top + r * stride]);
        }
        for (uint r = 0u; r < 8u; ++r) {
            int px = pred[r] + ((res[r] + 32) >> 6);
            u_dst.dst[top + r * stride] = uint8_t(clamp(px, 0, 255));
        }
    }
}