This document was originally written by Yizhi Liu
Forward
Set auto_broadcast to True, as shown here. With this flag, TVM maps buffer[i][j][k] to buffer[i][0][k] whenever dimension j has shape 1.
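For illustration, a forward add kernel might be declared roughly as in the sketch below; the compute itself is shape-agnostic, and auto_broadcast handles the buffer mapping described above. The function names and the exact @defop arguments here are assumptions patterned on the backward_vadd example later in this document, not the actual registration.

```python
# A minimal sketch of a forward elementwise-add kernel with auto broadcasting.
# Names and @defop arguments are assumptions based on the backward example below.
def compute_vadd(dtype, ndim):
    A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype)
    B = tvm.placeholder([tvm.var() for _ in range(ndim)], name='B', dtype=dtype)
    C = tvm.compute([tvm.var() for _ in range(ndim)],
                    lambda *index: A[index] + B[index], name='C')
    s = tvm.create_schedule(C.op)
    return s, A, B, C

@defop(name="vadd", target="cpu", auto_broadcast=True, dtype=AllTypes, ndim=[5])
def vadd(dtype, ndim):
    s, A, B, C = compute_vadd(dtype, ndim)
    # Fuse the output loop nest and parallelize it on CPU.
    s[C].parallel(s[C].fuse(*C.op.axis))
    return s, [A, B, C]
```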
Backward
Overview and the ideas behind the implementation
TVM op kernels are typically generated at compile time, when the input shapes are still unknown, so we cannot tell which axes will be broadcast. This matters because, unlike in the forward computation, the broadcast axes must be reduced by summation in the backward pass, and the axes to be reduced (and thus the axes that were broadcast) must be known at compile time.
A natural solution is to enumerate all cases. For example, if A.shape=(m, n), we have 4 cases:
- m != 1 && n != 1
- m == 1 && n != 1
- m != 1 && n == 1
- m == 1 && n == 1
We can therefore generate 4 TVM kernels to cover all the cases.
Generally, suppose we are considering two input operands of n dimensions, and we label each dimension 1 if it needs broadcasting and 0 otherwise. With this labeling we can generate the corresponding bit strings (e.g., in the above example, the bit strings are “00”, “10”, “01” and “11”). The problem is that there may be as many as 2^n bit strings, and thus 2^n TVM kernels, which is prohibitively many.
The optimization is to merge consecutive 1s and 0s within the bit string. It is easy to verify that consecutive broadcasting axes can jointly form a single combined axis, and so can consecutive non-broadcasting axes.
For example, if A.shape=(m, n, k), we originally have 2^3=8 cases, but after merging only 2*3=6 cases (namely 0, 1, 01, 10, 010, and 101) are left:
| original | merged |
| --- | --- |
| 000 | 0 |
| 001 | 01 |
| 010 | 010 |
| 011 | 01 |
| 100 | 10 |
| 101 | 101 |
| 110 | 10 |
| 111 | 1 |
Note that after merging, any two consecutive bits in the bit string must differ, so the bit string is uniquely determined by its leading bit and its length. The number of possible bit strings is therefore reduced from 2^n to 2n.
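As a quick illustration of this counting argument, the short Python snippet below (illustrative only, not part of the implementation) collapses runs of equal bits and reproduces the six merged forms listed in the table above.

```python
from itertools import product

def merge_bits(bits):
    """Collapse runs of consecutive equal bits, e.g. '011' -> '01', '110' -> '10'."""
    merged = bits[0]
    for b in bits[1:]:
        if b != merged[-1]:
            merged += b
    return merged

# All 2^3 = 8 raw bit strings collapse into the 6 merged forms in the table above.
print(sorted({merge_bits(''.join(bits)) for bits in product('01', repeat=3)}))
# ['0', '01', '010', '1', '10', '101']
```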
The following two sections describe how to implement this algorithm for simple operators, whose input gradients depend only on the output gradients (i.e., USE_NONE).
Runtime: Before Invoking TVM Kernels
Firstly, the shapes of the input operands are padded with 1s so that they share the same number of dimensions.
Then, we identify which axes are to be broadcast and which are not, label them with 1 and 0 respectively in a bit string, and merge the consecutive 1s and 0s.
Finally, invoke the TVM kernel with the reshaped output gradient.
For example, suppose two inputs x and y have shapes [2, 2, 1, 2, 2] and [1, 1, 2, 2, 1], and we are calculating the gradient of x. The output gradient has shape [2, 2, 2, 2, 2]. We get the bit string “00100”, and the reshaped output gradient has shape [4, 2, 4] (because we merge the first two dimensions as well as the last two).
Similarly, suppose we are calculating the gradient of y. The bit string is “11001” and the reshaped output gradient has shape [4, 4, 2].
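This preprocessing can be sketched in plain Python as follows; prepare_backward is a hypothetical helper for illustration (the real runtime code may differ in details), and it reproduces the two examples just given.

```python
def prepare_backward(in_shape, out_shape):
    """Hypothetical sketch: pad, label, and merge axes for one input operand."""
    ndim = len(out_shape)
    # Pad the input shape with leading 1s so both shapes have the same rank.
    in_shape = (1,) * (ndim - len(in_shape)) + tuple(in_shape)
    # Label each axis 1 if it is broadcast (input dim 1, output dim > 1), else 0.
    bits = [1 if i == 1 and o != 1 else 0 for i, o in zip(in_shape, out_shape)]
    # Merge runs of equal labels, multiplying the corresponding output dims.
    merged_bits, merged_shape = [bits[0]], [out_shape[0]]
    for b, o in zip(bits[1:], out_shape[1:]):
        if b == merged_bits[-1]:
            merged_shape[-1] *= o
        else:
            merged_bits.append(b)
            merged_shape.append(o)
    return merged_bits, merged_shape

# Gradient of x: bit string 00100 -> 010, output gradient reshaped to [4, 2, 4].
print(prepare_backward([2, 2, 1, 2, 2], [2, 2, 2, 2, 2]))  # ([0, 1, 0], [4, 2, 4])
# Gradient of y: bit string 11001 -> 101, output gradient reshaped to [4, 4, 2].
print(prepare_backward([1, 1, 2, 2, 1], [2, 2, 2, 2, 2]))  # ([1, 0, 1], [4, 4, 2])
```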
Compile-time: Generating TVM Kernels
For a given number of input dimensions, two kernels are generated: one for the case where the first axis is broadcast, and one for the case where it is not. Each kernel simply sums over the broadcasting axes; a helper function reduce_axes is available for this.
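The helper itself is not shown in this document; the sketch below is one way such a helper could be written for the alternating axes pattern used here (an illustrative assumption, so the actual reduce_axes in the codebase may differ).

```python
# Illustrative sketch of a reduce_axes-style helper (the actual helper may differ).
# `axes` is the alternating 0/1 pattern; positions labeled 1 become reduction axes,
# and positions labeled 0 become the output dimensions.
def reduce_axes(X, axes, reducer):
    def get_index(idx, ridx):
        # Interleave output indices (idx) and reduction indices (ridx) per `axes`.
        ret, j, k = [], 0, 0
        for val in axes:
            ret.append(idx[j] if val == 0 else ridx[k])
            j += (val == 0)
            k += (val == 1)
        return tuple(ret)

    ishape = X.shape
    # With an alternating pattern, the number of 0s (output dims) is determined by
    # the leading bit and the length of the pattern.
    odim = (len(ishape) + 1 - axes[0]) // 2
    oshape = [tvm.var() for _ in range(odim)]
    ridx = [tvm.reduce_axis((0, ishape[i])) for i, val in enumerate(axes) if val == 1]
    ret = tvm.compute(oshape,
                      lambda *idx: reducer(X[get_index(idx, ridx)], axis=ridx),
                      name='ret')
    return ret
```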
Take the add operator as an example:
```python
def compute_backward_vadd(dtype, ndim, reduce1st, req):
    axes = ([reduce1st, 1 - reduce1st] * ndim)[:ndim]
    X = tvm.placeholder([tvm.var() for _ in range(ndim)], name='X', dtype=dtype)
    reducer = tvm.comm_reducer(lambda x, y: x + y,
                               lambda t: tvm.const(0, dtype=t), name="sum")
    ret = reduce_axes(X, axes, reducer)
    in_grad_a, in_grad = assign_by_req(ret, req)
    s = tvm.create_schedule(in_grad.op)
    return s, X, in_grad_a, in_grad, [ret, in_grad]

@defop(name="backward_vadd", target="cpu", dtype=AllTypes,
       ndim=[5], reduce1st=[0, 1],
       req=["kWriteTo", "kAddTo"], attrs=["reduce1st", "req"])
def backward_vadd(dtype, ndim, reduce1st, req):
    s, X, in_grad_a, in_grad, c_list = compute_backward_vadd(dtype, ndim, reduce1st, req)
    for t in c_list:
        axes = [axis for axis in t.op.axis]
        fused = s[t].fuse(*axes)
        s[t].parallel(fused)
    return s, [X, in_grad_a, in_grad]
```
Here we are generating kernels for two cases: 01010 and 10101. The ndim variable determines the length of the bit string, and the reduce1st variable determines its leading bit.
The backward computation takes place in compute_backward_vadd. First, axes is assigned either 01010 or 10101, as determined by reduce1st. Then, we set the reducer to summation; in fact, the backward computation of broadcasting reduces axes by summation for every operator. Finally, the reduction is performed by reduce_axes, which sums over the dimensions labeled 1 in axes.
A detailed explanation of req is omitted here; it controls the operation request type (e.g., kWriteTo vs. kAddTo) and is largely unrelated to broadcasting.
Complicated operators
For operators like multiply, the input gradients depend not only on the output gradient but also on the input data. Suppose the inputs to the multiplication are x and y, the output gradient is z, and we want to compute the gradient of x.
First, compute temp = y * z with auto broadcast.
Then, merge the axes of temp as described above and sum over the interleaved broadcast axes.
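As a sketch of step 1 (with hypothetical names; this is not the actual MXNet kernel), an elementwise multiply kernel with auto broadcasting can be declared just like the forward add kernel sketched earlier. Step 2 then reuses the reduce_axes-based summation kernel on the merged, reshaped temp, exactly as for backward_vadd.

```python
# Hypothetical sketch of step 1: temp = y * z at the output-gradient shape, with y
# auto-broadcast against z. Step 2 reuses the reduce_axes-based kernel on temp.
def compute_vmul(dtype, ndim):
    Y = tvm.placeholder([tvm.var() for _ in range(ndim)], name='Y', dtype=dtype)
    Z = tvm.placeholder([tvm.var() for _ in range(ndim)], name='Z', dtype=dtype)
    temp = tvm.compute([tvm.var() for _ in range(ndim)],
                       lambda *idx: Y[idx] * Z[idx], name='temp')
    s = tvm.create_schedule(temp.op)
    return s, Y, Z, temp
```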