Parallelizing Non-Associative Scans
In episode 241 of the excellent ADSP podcast, Bryce (one of the co-hosts) mentioned a technique he had explored with ChatGPT for parallelizing a recurrence relation that, at first glance, seems inherently sequential. Because of the limitations of the audio format, I couldn't quite visualize the code or the transformation he was describing, but it sounded like a useful pattern.
I did my own little investigation with Gemini, and I think I managed to piece together the technique. It’s a useful trick for converting certain types of sequential problems into a form that can be massively parallelized.
How Parallel Scans Work
First, let's talk about a standard "scan," also known as a prefix sum. A prefix sum takes a list of numbers and an operator (like +) and produces a new list where each element is the cumulative result of the operator up to that point.
Input: [3, 1, 4, 1, 5, 9 ]
Inclusive Scan (+): [3, 4, 8, 9, 14, 23]
A naive, sequential implementation is simple:
output[0] = input[0]
for i from 1 to n-1:
output[i] = output[i-1] + input[i]
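For concreteness, here is the same loop as a small C++ function; the standard library's std::inclusive_scan computes the same result.

#include <numeric>
#include <vector>

// Sequential inclusive scan: out[i] = in[0] + in[1] + ... + in[i].
std::vector<int> inclusive_scan_seq(const std::vector<int>& in) {
    std::vector<int> out(in.size());
    if (in.empty()) return out;
    out[0] = in[0];
    for (std::size_t i = 1; i < in.size(); ++i)
        out[i] = out[i - 1] + in[i];
    return out;
}

// Equivalent call from <numeric>:
// std::inclusive_scan(in.begin(), in.end(), out.begin());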
This implementation is decently fast on a single CPU core, but it's fundamentally sequential: the calculation for output[i] depends directly on the final result of output[i-1]. How can you parallelize this?
There are several approaches, but a simple one is the two-level algorithm. Imagine you have a modern GPU. You can give each of its independent processing units (like a Streaming Multiprocessor on an NVIDIA GPU or a Compute Unit on an AMD GPU) a chunk of the input array. You then perform the following operations:
Chunk-wise Scan: Each processing unit computes the prefix sum for its own small chunk, completely in parallel. At this point, the results in all chunks (except the first) are wrong, because they don't account for the sums of the preceding chunks.
[ [3, 1, 4], [1, 5, 9 ], [2, 6, 5 ], ... ] ->
[ [3, 4, 8], [1, 6, 15], [2, 8, 13], ... ] (local scans)
Collect Last Elements: We create a new array containing just the last element from each chunk's partial scan.
[8, 15, 13, ...]
Scan the Lasts: We perform a prefix sum on this new, much smaller array. This gives us the cumulative sum at the boundary of each chunk.
[8, 15, 13, ...] -> [8, 23, 36, ...]
Update the Chunks: Finally, we broadcast the results from step 3 back to the processing units. Each unit (except the first) adds the cumulative value from the previous chunk to every element in its local result.
Chunk 2 ([1, 6, 15]) gets 8 (from chunk 1's end) added to it: [9, 14, 23]
Chunk 3 ([2, 8, 13]) gets 23 (from chunk 2's cumulative end) added to it: [25, 31, 36]
After this update step (assuming we've allocated the chunks in contiguous memory), we have the full prefix sum array. We did a bit of extra computation compared to the single-threaded version (the final update of the chunks), but adding a scalar to a vector is very fast on a GPU, and we get to use dozens of processing units instead of a single core. (For a more advanced implementation of parallel scans, I highly recommend digesting this great paper out of NVIDIA.)
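Here is a minimal CPU sketch of those steps (my own code, not from the podcast). The loops are written sequentially for clarity, but each chunk in the local-scan and update phases is independent, so on a GPU each one would go to its own processing unit.

#include <algorithm>
#include <numeric>
#include <vector>

// Two-level inclusive scan. Each outer-loop iteration below is independent,
// so phases 1 and 3 can run with one chunk per processing unit.
std::vector<int> two_level_scan(const std::vector<int>& in, std::size_t chunk = 3) {
    std::vector<int> out(in.begin(), in.end());
    std::size_t n = in.size();
    std::size_t num_chunks = (n + chunk - 1) / chunk;

    // Phase 1: local inclusive scan inside each chunk (parallel across chunks).
    for (std::size_t c = 0; c < num_chunks; ++c) {
        std::size_t lo = c * chunk, hi = std::min(lo + chunk, n);
        for (std::size_t i = lo + 1; i < hi; ++i)
            out[i] += out[i - 1];
    }

    // Phase 2: collect and scan the last element of each chunk (a much smaller problem).
    std::vector<int> lasts(num_chunks);
    for (std::size_t c = 0; c < num_chunks; ++c)
        lasts[c] = out[std::min((c + 1) * chunk, n) - 1];
    std::inclusive_scan(lasts.begin(), lasts.end(), lasts.begin());

    // Phase 3: add the previous chunk's cumulative total to every element of
    // the current chunk (parallel across chunks, and across elements).
    for (std::size_t c = 1; c < num_chunks; ++c) {
        std::size_t lo = c * chunk, hi = std::min(lo + chunk, n);
        for (std::size_t i = lo; i < hi; ++i)
            out[i] += lasts[c - 1];
    }
    return out;
}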
Associativity
This parallel algorithm works because the addition operator (+) is associative: (a+b)+c = a+(b+c). Because the grouping of operations doesn't matter, we can fix up all the chunks in parallel at the end, even though the sequential version always computes the elements in order.
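For example, take the last element of chunk 2. The sequential scan computes it strictly left to right, but associativity lets us regroup it into chunk 1's total (computed when we scan the lasts) plus chunk 2's local scan result (computed in the chunk-wise scan), which is exactly what the final update adds together:
\( ((((x_0+x_1)+x_2)+x_3)+x_4)+x_5 \;=\; (x_0+x_1+x_2) + ((x_3+x_4)+x_5) \)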
But what happens when the operation isn't associative?
Bryce's Example
In the podcast episode Bryce talked about computing the output of a proportional error control system:
\( y_i = y_{i-1} + k\,(x_i - y_{i-1}) \)
Here, y is the output array, x is the input array, and k is some constant factor. In sequential pseudocode:
output[0] = input[0]
for i from 1 to n-1:
output[i] = output[i-1] + k*(input[i]-output[i-1])
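As a C++ reference (my own direct translation of that pseudocode, which any parallel version has to reproduce):

#include <vector>

// Sequential proportional-error controller: y[i] = y[i-1] + k*(x[i] - y[i-1]).
std::vector<double> controller_seq(const std::vector<double>& x, double k) {
    std::vector<double> y(x.size());
    if (x.empty()) return y;
    y[0] = x[0];
    for (std::size_t i = 1; i < x.size(); ++i)
        y[i] = y[i - 1] + k * (x[i] - y[i - 1]);
    return y;
}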
Let's try to use our parallel chunking algorithm here.
Each processing unit gets a chunk of the input x and an initial y value (say, 0).
It computes the recurrence for its chunk.
But now what? Unlike the prefix sum case, a chunk's correct outputs are not just its local outputs shifted by a constant, so there is no simple transformation we can make to the chunks to turn them into the correct final answer.
This is where the problem seems intractably sequential.
The Solution: Lifting to a Higher Dimension
The solution is to change what we are computing. Instead of computing the value \(y_i\) at each step, we will compute the function that produces \(y_i\) from \(y_{i-1}\).
We can rewrite our recurrence relation as
\( y_i = (1-k)\,y_{i-1} + k\,x_i \)
And then further rewrite it in 2x2 matrix form:
\( \begin{pmatrix} y_i \\ 1 \end{pmatrix} = \begin{pmatrix} 1-k & k\,x_i \\ 0 & 1 \end{pmatrix} \begin{pmatrix} y_{i-1} \\ 1 \end{pmatrix} = M_i \begin{pmatrix} y_{i-1} \\ 1 \end{pmatrix} \)
We can unroll the recurrence to
\( \begin{pmatrix} y_i \\ 1 \end{pmatrix} = M_i\,M_{i-1} \begin{pmatrix} y_{i-2} \\ 1 \end{pmatrix} = M_i\,M_{i-1}\,M_{i-2} \begin{pmatrix} y_{i-3} \\ 1 \end{pmatrix} \)
and so on. Composing transformations is now just matrix multiplication, which is associative!
The parallel algorithm is now clear:
Lift: Convert the input array \([x_0, x_1, \ldots]\) into an array of matrices \([M_0, M_1, \ldots]\).
Scan: Perform a parallel scan on the array of matrices using matrix multiplication as the operator. This gives you a new array of matrices \([P_0, P_1, \ldots]\), where \(P_i = M_i M_{i-1} \cdots M_0\).
Project: For each matrix \(P_i\) in the scanned array, compute the final value by multiplying it by the initial condition vector:
\(\begin{pmatrix} y_i \\ 1 \end{pmatrix} = P_i \begin{pmatrix} y_{0} \\ 1 \end{pmatrix}\)
With this, we can parallelize the original recurrence.
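Here is a small CPU sketch of the lift/scan/project recipe for this recurrence (my own code; a sequential std::inclusive_scan stands in for whatever parallel scan you would actually run on a GPU, which is legitimate precisely because the operator is associative).

#include <array>
#include <numeric>
#include <vector>

// 2x2 matrix stored row-major: {a, b, c, d} means [[a, b], [c, d]].
using Mat2 = std::array<double, 4>;

Mat2 matmul(const Mat2& A, const Mat2& B) {
    return { A[0]*B[0] + A[1]*B[2], A[0]*B[1] + A[1]*B[3],
             A[2]*B[0] + A[3]*B[2], A[2]*B[1] + A[3]*B[3] };
}

// y0 is the state before the first input; passing y0 = x[0] reproduces
// output[0] = input[0] from the sequential pseudocode.
std::vector<double> scan_controller(const std::vector<double>& x, double k, double y0) {
    // Lift: each input x[i] becomes the matrix M_i = [[1-k, k*x[i]], [0, 1]].
    std::vector<Mat2> m(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        m[i] = { 1.0 - k, k * x[i], 0.0, 1.0 };

    // Scan: P_i = M_i * M_{i-1} * ... * M_0. The operator composes the newer
    // matrix on the left; it is associative, so a parallel scan is valid.
    std::vector<Mat2> p(x.size());
    std::inclusive_scan(m.begin(), m.end(), p.begin(),
                        [](const Mat2& acc, const Mat2& cur) { return matmul(cur, acc); });

    // Project: y_i is the first component of P_i * (y0, 1)^T.
    std::vector<double> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        y[i] = p[i][0] * y0 + p[i][1];
    return y;
}

Note that the bottom row of every \(M_i\) is (0, 1), so in practice you only need to carry the top row of each matrix; Bryce's affine struct below exploits exactly that.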
Generalizing the Technique
This technique has pretty broad applications. It works for any problem that can be expressed as a linear recurrence relation. The key requirement is that the state at step i can be derived from the state at step i−1 by a matrix multiplication.
The Fibonacci sequence is another example:
\( F_i = F_{i-1} + F_{i-2} \)
The state we need to carry forward is not one number, but two: \((F_{i-1}, F_{i-2})\). The transformation is:
\( \begin{pmatrix} F_i \\ F_{i-1} \end{pmatrix} = \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix} \begin{pmatrix} F_{i-1} \\ F_{i-2} \end{pmatrix} \)
Once again, we have a matrix operator. We can perform a parallel scan with this 2x2 matrix in order to calculate the Fibonacci sequence.
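And a quick sketch of the same recipe for Fibonacci (again my own code, with a sequential std::inclusive_scan standing in for a parallel scan):

#include <array>
#include <cstdint>
#include <numeric>
#include <vector>

using Mat2 = std::array<std::uint64_t, 4>;  // row-major [[a, b], [c, d]]

Mat2 matmul(const Mat2& A, const Mat2& B) {
    return { A[0]*B[0] + A[1]*B[2], A[0]*B[1] + A[1]*B[3],
             A[2]*B[0] + A[3]*B[2], A[2]*B[1] + A[3]*B[3] };
}

// Returns {F_1, F_2, ..., F_n} via a scan over copies of the step matrix
// [[1, 1], [1, 0]]. (Unsigned 64-bit overflows past F_93.)
std::vector<std::uint64_t> fibonacci(std::size_t n) {
    // Lift: every step uses the same matrix.
    std::vector<Mat2> m(n, Mat2{1, 1, 1, 0});

    // Scan: P_i is the (i+1)-th power of the step matrix; matrix
    // multiplication is associative, so this could be a parallel scan.
    std::vector<Mat2> p(n);
    std::inclusive_scan(m.begin(), m.end(), p.begin(),
                        [](const Mat2& acc, const Mat2& cur) { return matmul(cur, acc); });

    // Project: P_i * (F_1, F_0)^T = P_i * (1, 0)^T = (F_{i+2}, F_{i+1})^T,
    // so the bottom component of that product is fib[i] = F_{i+1}.
    std::vector<std::uint64_t> fib(n);
    for (std::size_t i = 0; i < n; ++i)
        fib[i] = p[i][2];
    return fib;
}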
Other applications include infinite impulse response (IIR) filters (e.g. exponential moving averages) and, as Bryce mentions, polynomial evaluation, running minimum, and max-plus convolution.
Update: Code Example From Bryce
Bryce very kindly shared this Python notebook that contains an implementation in Thrust. Here is my attempt to connect the theory above to this implementation. In Bryce's notation \(\alpha = 1-k\) and \(\beta_i = k\,x_i\), so we have
\( y_i = \alpha\,y_{i-1} + \beta_i \)
In order to implement this in Thrust we want to define the recurrence relation in terms of a per-step transformation
\( \begin{pmatrix} y_i \\ 1 \end{pmatrix} = M_i \begin{pmatrix} y_{i-1} \\ 1 \end{pmatrix} \)
We can write \(M_i\) in the form
\( M_i = \begin{pmatrix} \alpha & \beta_i \\ 0 & 1 \end{pmatrix} \)
so that composing with an accumulated transformation gives
\( M_i \begin{pmatrix} \alpha_{\text{acc}} & \beta_{\text{acc}} \\ 0 & 1 \end{pmatrix} = \begin{pmatrix} \alpha\,\alpha_{\text{acc}} & \alpha\,\beta_{\text{acc}} + \beta_i \\ 0 & 1 \end{pmatrix} \)
or in pseudocode notation:
αacc ← α*αacc
βacc ← α*βacc + βi
Bryce's code also allows users to compute multiple trajectories in one call by assigning each trajectory to a group/key. When doing so, it's easiest to keep track of the group, y0, αacc, and βacc all in one struct, which he calls affine. The composition operation that updates the accumulator is implemented in this code block:
struct compose_affine {
  __host__ __device__
  affine operator()(affine acc, affine cur) const
  {
    if (acc.group == cur.group)
      return { /*group = */ cur.group,
               /*y0    = */ acc.y0,
               /*alpha = */ cur.alpha * acc.alpha,
               /*beta  = */ cur.alpha * acc.beta + cur.beta };
    else // We've got a new key.
      return { /*group = */ cur.group,
               /*y0    = */ cur.y0,
               /*alpha = */ cur.alpha,
               /*beta  = */ cur.beta };
  }
};
Other important parts of the code are a transform iterator that creates the affine structs corresponding to each \(M_i\) using this function:
struct make_affine {
  double K;
  __host__ __device__ affine operator()(thrust::tuple<int, double> group_x) const {
    auto group = thrust::get<0>(group_x);
    auto x     = thrust::get<1>(group_x);
    return { group, /*y0 = */ x, /*alpha = */ 1.0 - K, /*beta = */ K * x };
  }
};
and an output transform that uses the scanned affine structs to multiply each \(P_i\) with the initial condition vector and produce the final results using this function:
struct apply_affine {
  __host__ __device__ double operator()(const affine& f) const {
    return f.alpha * f.y0 + f.beta;
  }
};
By using Thrust iterators this implementation is able to avoid storing and loading all the affine structs in GPU memory, which would have incurred a large performance penalty.
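To make the data flow concrete, here is a rough sketch of how these pieces can be wired together. This is my reconstruction, not Bryce's exact notebook code: the affine struct (with group, y0, alpha, and beta members) is assumed from the functors above, and the names run_controller, groups, xs, ys, and K are mine.

#include <thrust/device_vector.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/scan.h>
#include <thrust/tuple.h>

void run_controller(const thrust::device_vector<int>&    groups,  // key per sample
                    const thrust::device_vector<double>& xs,      // inputs
                    thrust::device_vector<double>&       ys,      // outputs
                    double K)
{
    // Lift on the fly: each (group, x) pair becomes an affine struct.
    auto lifted = thrust::make_transform_iterator(
        thrust::make_zip_iterator(thrust::make_tuple(groups.begin(), xs.begin())),
        make_affine{K});

    // Project on the fly: each scanned affine struct is turned into y_i as it
    // is written out, so the composed structs are never stored in memory.
    auto out = thrust::make_transform_output_iterator(ys.begin(), apply_affine{});

    // Scan with the (associative) affine composition as the operator.
    thrust::inclusive_scan(lifted, lifted + xs.size(), out, compose_affine{});
}

Because lifted and out are iterators rather than containers, the only arrays that ever live in global memory are groups, xs, and ys.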
Having never used Thrust before, I learned a lot from studying this example, so thanks again to Bryce for sharing!