// Building blocks for a bucketed exclusive prefix sum (scan) over uint data.
//   ScanInBucketExclusive : exclusive scan inside each thread group ("bucket");
//                           also stores every bucket's total in _BlockSum.
//   ScanAddBucketResult   : adds one per-bucket value, read from _Input[group id],
//                           to every element of that bucket in _Result.
// NOTE(review): presumably the host scans _BlockSum and binds the outcome as
// _Input for ScanAddBucketResult -- confirm against the host script.
#pragma kernel ScanInBucketExclusive
#pragma kernel ScanAddBucketResult

// Ensure that this equals the 'threadsPerGroup' const in the host script.
// Must be a power of 2: the sweeps below double/halve their stride and rely on
// it landing exactly on THREADS_PER_GROUP.
// NOTE(review): the original note asked for an *odd* power of 2 (512 = 2^9).
// Nothing in this file depends on the exponent being odd, so that part
// presumably comes from the host side -- confirm before changing the value.
#define THREADS_PER_GROUP 512

// These must be a multiple of THREADS_PER_GROUP.
// NOTE(review): presumably this refers to the element counts of the buffers
// below -- confirm against the host script.
StructuredBuffer<uint>   _Input;     // Values to scan (kernel 1) / per-bucket values to add (kernel 2).
RWStructuredBuffer<uint> _Result;    // Scan output, one entry per input element.
RWStructuredBuffer<uint> _BlockSum;  // Total of each bucket, indexed by group id.

// Number of valid elements; threads with DTid >= count neither read _Input[DTid]
// nor write _Result[DTid].
uint count;

// Per-group scratch space: one slot per thread.
groupshared uint bucket[THREADS_PER_GROUP];

// Scan in each bucket.
//
// Every thread loads one element (0 when out of range) into group-shared memory,
// the group performs an up-sweep followed by a down-sweep on that array, and each
// in-range thread writes its exclusive prefix sum to _Result. The bucket's total
// is written to _BlockSum[Gid].
//
// DTid : global thread index, used to address _Input / _Result.
// Gid  : thread-group index, used to address _BlockSum.
// GI   : index of the thread inside its group, used to address 'bucket'.
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanInBucketExclusive(uint DTid : SV_DispatchThreadID, uint Gid : SV_GroupID, uint GI : SV_GroupIndex)
{
    // Pad with zeros so that out-of-range threads do not affect the sums.
    if (DTid < count)
    {
        bucket[GI] = _Input[DTid];
    }
    else
    {
        bucket[GI] = 0;
    }
    // No separate barrier is needed here: the first up-sweep iteration starts
    // with one, which makes the loads visible to the whole group.

    uint stride;

    // up-sweep: the last slot of every 'stride'-sized segment accumulates the
    // last slot of the segment's first half.
    [unroll]
    for (stride = 2; stride <= THREADS_PER_GROUP; stride <<= 1)
    {
        GroupMemoryBarrierWithGroupSync();

        if (((GI + 1) % stride) == 0)
        {
            const uint half_stride = (stride >> 1);
            bucket[GI] += bucket[GI - half_stride];
        }
    }

    // Synchronize the group before the last element is saved and reset to 0
    // (the original note warns that the reset may otherwise not be properly
    // propagated across the entire thread group).
    GroupMemoryBarrierWithGroupSync();

    if (GI == THREADS_PER_GROUP - 1)
    {
        // After the up-sweep the last slot holds the sum of the whole bucket:
        // store it as this bucket's total, then clear the last element to seed
        // the down-sweep.
        _BlockSum[Gid] = bucket[GI];
        bucket[GI] = 0;
    }

    // down-sweep: within every 'stride'-sized segment, the slot at the end of
    // the first half receives the segment's last slot, and the last slot
    // receives the sum of both.
    [unroll]
    for (stride = THREADS_PER_GROUP; stride > 1; stride >>= 1)
    {
        GroupMemoryBarrierWithGroupSync();

        if (((GI + 1) % stride) == 0)
        {
            const uint half_stride = (stride >> 1);
            const uint prev_idx = GI - half_stride;
            // Keep the temporary unsigned, matching the element type of 'bucket'.
            const uint tmp = bucket[prev_idx];
            bucket[prev_idx] = bucket[GI];
            bucket[GI] += tmp;
        }
    }

    // Wait for the final down-sweep step before reading the results back.
    GroupMemoryBarrierWithGroupSync();

    if (DTid < count)
    {
        _Result[DTid] = bucket[GI];
    }
}

// Add the bucket scanned result to each bucket to get the final result.
//
// Every in-range thread adds _Input[Gid] (one value per thread group) to its own
// element of _Result.
//
// DTid : global thread index, used to address _Result.
// Gid  : thread-group index, used to address _Input.
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanAddBucketResult(uint DTid : SV_DispatchThreadID, uint Gid : SV_GroupID)
{
    if (DTid < count)
    {
        _Result[DTid] += _Input[Gid];
    }
}