Tensor Stunts I: batch-aware attention computing
First part of tensor stunts.
Numpy and pytorch (and, more generally, languages used for mathematical computation) promote the replacement of good old for loops by creative use of the tensor libraries. As Python is inherently slow, this can dramatically improve the efficiency of your code, especially if you use a GPU. In some of the examples below, the optimised code ran forty (40) times faster than the original.
However, those optimisations are often unintuitive; I have sometimes used LLMs (as a kind of improved Stack Overflow) to find the right Python functions to use, and I usually needed to edit the result.
Here are a few examples of non-trivial uses of the pytorch library for complex computations.
The Process
When I start trying to optimise code that way, I often keep a straightforward, unoptimised version of the code, written in plain Python, which I use to validate the computation. It takes a bit of time, but it allows one to be confident in the optimised code. Unit tests are also very valuable.
Attention computing and masking
I wanted to compute attention along the lines of Luong et al. in their paper “Effective Approaches to Attention-Based Neural Machine Translation” (DOI). It's not very hard per se, as it's only a scalar product; but if you add a batch dimension, it becomes much more complex.
Attention was originally used for text translation in the encoder/decoder architecture.
- The encoder network produces a sequence of vectors, \(s_i\), \(i \in [1..n]\), representing the input sentence;
- the decoder produces vectors \(t_j\) for each word position \(j\), \(j \in [1..m]\), in the translation;
- the attention mechanism determines which parts of the input (\(s_i\)) are most relevant when producing word \(j\) in the translation. The final attention, from a position \(j\) in the output toward an input sentence \(s\) of length \(n\), is a vector \(a\) of size \(n\), where: 1) \(\forall i,\ 0 \le a_i \le 1\) and 2) \(\sum_{i=1}^{n} a_i = 1\).
Attention is computed in two steps:
- first, we compute a score with a simple scalar product between each \(s_i\) and \(t_j\):
\[ \mathrm{score}_{i} = s_i \cdot t_j\]
- then, we transform the score into an actual attention by performing a softmax:
\[ a_{i} = \frac{\exp(\mathrm{score}_{i})}{\sum_{k=1}^{n} \exp(\mathrm{score}_{k})}\]
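For a single output position \(j\), this computation fits in a couple of lines of PyTorch. A minimal sketch (the tensors s and t_j are made-up random values, and n and H are arbitrary):
import torch
import torch.nn.functional as F

n, H = 5, 8
s = torch.randn(n, H)    # encoder states s_1 .. s_n
t_j = torch.randn(H)     # one decoder state t_j

score = s @ t_j                # shape (n,): score_i = s_i . t_j
a = F.softmax(score, dim=0)    # attention weights, shape (n,)
assert torch.isclose(a.sum(), torch.tensor(1.0))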
Padding and attention
The basic mechanism of attention is rather simple. However, when training a system, we normally use batches, working on more than one sentence at a time. As the sentences in the training corpus can have any length, a batch is typically stored as a tensor of shape (B,M), with B the number of sentences and M the length of the longest sentence in this particular batch.
For shorter sentences in the batch, we usually use padding to mark that certain positions are not really used. In the original input, sentences are encoded as sequences of word indexes in the vocabulary, and “1” is often used as a special padding index.
If we have a batch with sentences:
- the nice cat eats
- hello you
- the mouse runs
We have B = 3 and M = 4; the padded batch would be:
- the nice cat eats
- hello you <pad> <pad>
- the mouse runs <pad>
And the encoding of the batch as an integer tensor would be:
\[ \mathrm{s} = \begin{pmatrix} 10 & 15 & 8 & 7\\ 25 & 90 & 1 & 1\\ 10 & 105 & 22 & 1\\ \end{pmatrix}\]
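For reference, such a padded tensor can be built with torch.nn.utils.rnn.pad_sequence. A minimal sketch; the word indexes are the made-up values from the matrix above, and 1 is the padding index:
import torch
from torch.nn.utils.rnn import pad_sequence

sentences = [torch.tensor([10, 15, 8, 7]),    # the nice cat eats
             torch.tensor([25, 90]),          # hello you
             torch.tensor([10, 105, 22])]     # the mouse runs
s = pad_sequence(sentences, batch_first=True, padding_value=1)
# tensor([[ 10,  15,   8,   7],
#         [ 25,  90,   1,   1],
#         [ 10, 105,  22,   1]])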
Attention and batching interact! When a given word position is padding, the attention toward this word should be 0, because this word is simply not there in reality. The system might learn this on its own, but that's a waste of computational power; we can make sure the attention is right by construction.
The idea is to change the values after the computation of the score. If, for positions which correspond to padding, we set the score value to \(-\infty\), then the exponential will be 0, and so will the attention.
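A quick sanity check of this trick on a toy score vector:
import torch
import torch.nn.functional as F

score = torch.tensor([2.0, 1.0, float('-inf'), float('-inf')])
print(F.softmax(score, dim=0))
# tensor([0.7311, 0.2689, 0.0000, 0.0000]): the -inf positions get zero attention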
Let's see how we can code this.
Notations
We use the following conventions:
- B is the batch dimension (number of sentences in the batch);
- N is the max length of input sentences for this batch;
- M is the max length of output sentences for this batch;
- H is the hidden size;
- hs is the input tensor, of shape \((N,B,H)\);
- ht is the output tensor, of shape \((M,B,H)\);
- attention will be of shape \((B,M,N)\), where att[b,i,j] is the attention from position i in the output to position j in the input for batch entry b;
- hs_lengths is a tensor of shape \((B,)\) containing the actual length of each sentence in hs.
Naive implementation
We use the standard pytorch convention for recurrent networks, where the batch is on the second axis. (Apparently this convention is no longer really useful, but anyway...)
A plain version of the attention would be:
import torch
import torch.nn.functional as F

def attention(hs, ht, hs_lengths):
    """
    Input:
        hs : tensor N x B x H
        ht : tensor M x B x H
        hs_lengths : tensor of the B sentence lengths for s in the batch.
    Output:
        tensor B x M x N
    """
    N, B, H = hs.shape
    M, B1, H1 = ht.shape
    assert B == B1, "same batch sizes"
    assert H == H1, "same hidden sizes"
    # scalar product (we might also use torch.bmm here)
    score = ht.permute(1, 0, 2) @ hs.permute(1, 2, 0)  # dimensions B x M x N
    assert score.shape == (B, M, N)
    # Now, the complex part:
    # In the final result, we want
    #
    #   attention[b,i,j] = 0 if j >= hs_lengths[b]
    #
    # If j >= hs_lengths[b], there is not really anything meaningful in hs[j,b,:]:
    # it's a vector for padding only, and we are not going to attend to padding.
    #
    # `score` currently contains logits, not attention values. It will be transformed
    # into attention values by softmax.
    #
    # Hence the following simple trick: set the values in `score` to minus infinity
    # when they are to be ignored. The softmax applied to -inf will be 0.
    for b in range(B):
        score[b, :, hs_lengths[b]:] = float('-inf')
    return F.softmax(score, dim=2)
The code is not too long, and is reasonably simple. However, it's not very efficient. Let's see if we can remove the loop.
Pytorch optimisation
We are going to build a boolean mask of shape \((B,M,N)\) which will tell us which values should be set to \(-\infty\).
In the naive implementation, we wrote:
# Remember, the shape of score is (B,M,N)
for b in range(B):
    score[b, :, hs_lengths[b]:] = float('-inf')
It shows that nothing peculiar happens in the M dimension. Basically, we want mask[b,i,j] to be the same for all values of i. Thus, we can work first without the second axis (M), create a kind of mask0[b,j], and then duplicate its entries to create mask[b,i,j] = mask0[b,j] for all i.
mask0 will contain an entry for each batch position; mask0[b,:] shows which part of the tensor corresponds to actual words, and which part to padding. More precisely, if j < hs_lengths[b], then the word in position j of batch entry b is an actual word; if j >= hs_lengths[b], it's padding. mask0[b,j] will be True for padding positions, and False for “normal” positions.
For instance, if we have B = 3 and N = 4, we could have
\[ \mathrm{mask0} = \begin{pmatrix} \mathrm{False} & \mathrm{False} & \mathrm{False} & \mathrm{False}\\ \mathrm{False} & \mathrm{False} & \mathrm{True} & \mathrm{True}\\ \mathrm{False} & \mathrm{False} & \mathrm{False} & \mathrm{True}\\ \end{pmatrix}\]
Meaning that the first batch entry is actually a sentence of length 4 (all values set to False); the second sentence is of length 2, and the third of length 3.
Building mask0
Instead of looping, we use torch.arange to build a sequence which corresponds to each possible position in the input tensor hs:
t1 = torch.arange(N, device=hs.device)
We then add an axis for the batch dimension:
t1 = t1.unsqueeze(0)
This gives us a tensor of shape \((1,N)\), with t1[0,k] == k for all k.
If we compared it (with >=) to a tensor v of shape (1,N) containing a constant sentence length vsize, we would get a boolean tensor m0 of shape (1,N), with m0[0,i] == True \(\:\Leftrightarrow\) i >= vsize, i.e. True exactly on the padding positions.
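A minimal check of this idea, with N = 5 and a constant length vsize = 3:
import torch

t1 = torch.arange(5).unsqueeze(0)   # tensor([[0, 1, 2, 3, 4]])
v = torch.full((1, 5), 3)           # constant sentence length vsize = 3
m0 = t1 >= v
# tensor([[False, False, False,  True,  True]]): True on the two padding positions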
Now, we want to build mask0, of shape (B,N), with mask0[b,j] True iff j >= hs_lengths[b]. To build this mask, we can combine our previous idea with a numpy/pytorch feature: broadcasting.
Broadcasting means that, for some operations, when a tensor of shape (a,b) is needed and a tensor t of shape (1,b) is passed, pytorch will duplicate the values along the axis of length 1.
Basically, if we have a tensor \(z\) of shape (1,4):
\[ z = \begin{pmatrix} 5 & 9 & 3 & 7 \\ \end{pmatrix}\]
and we need to combine it with a tensor of shape (3,4), pytorch will consider \(z\) to be
\[ z = \begin{pmatrix} 5 & 9 & 3 & 7 \\ 5 & 9 & 3 & 7 \\ 5 & 9 & 3 & 7 \\ \end{pmatrix}\]
Broadcasting operates on axes of size 1, expanding them to larger dimensions when needed. Hence the previous use of unsqueeze.
Multiple broadcasts can be performed. If we need two tensors t'1 and t'2 of the same shape (a,b), and we actually have t1.shape == (1,b) and t2.shape == (a,1), pytorch will broadcast both tensors to shape (a,b).
We have:
t1 = torch.arange(N, device=hs.device).unsqueeze(0)
t2 = hs_lengths.unsqueeze(1)
t1, of shape (1,N), contains the sequence of indexes; t2, of shape (B,1), contains the length of each sentence in the batch.
For instance: \( t1 = \begin{pmatrix} 0 & 1 & 2 & 3 \\ \end{pmatrix} \) and \( t2 = \begin{pmatrix} 4\\ 3\\ 2\\ \end{pmatrix} \)
If we combine them, broadcasting will consider them as being respectively:
\[ t'1 = \begin{pmatrix} 0 & 1 & 2 & 3 \\ 0 & 1 & 2 & 3 \\ 0 & 1 & 2 & 3 \\ \end{pmatrix}\]
and
\[ t'2 = \begin{pmatrix} 4& 4 & 4 & 4\\ 3 & 3 & 3 & 3\\ 2 & 2 & 2 & 2\\ \end{pmatrix}\]
Computing t1 >= t2 in pytorch will ensure t1 and t2 are broadcast, and the result will be a boolean tensor, mask0:
\[ \mathrm{mask0} = \begin{pmatrix} \mathrm{False} & \mathrm{False} & \mathrm{False} & \mathrm{False}\\ \mathrm{False} & \mathrm{False} & \mathrm{False} & \mathrm{True}\\ \mathrm{False} & \mathrm{False} & \mathrm{True} & \mathrm{True}\\ \end{pmatrix}\]
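Put together, the whole mask0 computation for this small example is:
import torch

hs_lengths = torch.tensor([4, 3, 2])
N = 4
t1 = torch.arange(N).unsqueeze(0)   # shape (1, 4)
t2 = hs_lengths.unsqueeze(1)        # shape (3, 1)
mask0 = t1 >= t2                    # shape (3, 4), True on padding positions
# tensor([[False, False, False, False],
#         [False, False, False,  True],
#         [False, False,  True,  True]])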
From mask0 to mask
The actual mask we want is of shape (B,M,N); ours is of shape (B,N). We want to create mask such that mask[b,i,j] = mask0[b,j] for all i. Basically, we want to copy each entry mask0[b,j] M times, along a new second axis.
There are many ways to perform this in pytorch, but the most efficient is:
mask = mask0.unsqueeze(1).expand(-1, M, -1)
- we go from shape (B,N) to shape (B,1,N) by introducing a second axis with unsqueeze;
- we use expand to duplicate all entries M times along this new axis;
- done!
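Checked on the example mask0 above (with M = 2 for brevity). Note that expand returns a view, without copying the underlying data, which is why it is preferred to repeat here:
import torch

mask0 = torch.tensor([[False, False, False, False],
                      [False, False, False,  True],
                      [False, False,  True,  True]])
M = 2
mask = mask0.unsqueeze(1).expand(-1, M, -1)
print(mask.shape)   # torch.Size([3, 2, 4])
print(mask[1])
# tensor([[False, False, False,  True],
#         [False, False, False,  True]])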
Using the mask
Once we have the mask, we can use it. Values which correspond to padding will be set to \(-\infty\), and applying softmax will then give 0 at those positions, which is what we want: padding positions get no attention.
We can use this mask as an index in the original tensor:
score[mask] = float('-inf')
Or we can use the specific (in-place) masked_fill_ function:
score.masked_fill_(mask, float('-inf'))
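Both forms give the same result; here is a quick check on random data:
import torch

score = torch.randn(3, 2, 4)
mask = torch.rand(3, 2, 4) > 0.5

s1 = score.clone()
s1[mask] = float('-inf')

s2 = score.clone()
s2.masked_fill_(mask, float('-inf'))

assert torch.equal(s1, s2)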
Resulting code
import torch
import torch.nn.functional as F

def attention(hs, ht, hs_lengths):
    N, B, H = hs.shape
    M, B1, H1 = ht.shape
    score = torch.bmm(ht.transpose(0, 1), hs.permute(1, 2, 0))
    assert score.shape == (B, M, N)
    # Compute the sequence of positions up to the max sentence length, and add
    # the dimension corresponding to the batch as a size-1 axis.
    t1 = torch.arange(N, device=hs.device).unsqueeze(0)
    assert t1.shape == (1, N), f"t1 shape is {t1.shape}, expected (1, N)"
    # Create a vertical tensor containing the sentence lengths.
    t2 = hs_lengths.unsqueeze(1)
    assert t2.shape == (B, 1), f"t2 shape is {t2.shape}, expected (B, 1)"
    # Now compare the two.
    # Broadcasting will create a tensor of shape (B,N) where
    # mask0[b,j] == False means that position j should be kept for batch entry b,
    # and True means it should be replaced by -inf.
    mask0 = t1 >= t2
    assert mask0.shape == (B, N), f"mask0 shape is {mask0.shape}, expected (B, N)"
    #
    # Now, we have only one vector per batch entry in the mask. We need M of them,
    # as the mask should be of shape (B,M,N).
    # We want mask[b,:,j] = mask0[b,j].
    #
    # We could use repeat here, but expand is more efficient.
    mask = mask0.unsqueeze(1).expand(-1, M, -1)
    # Our mask has the right shape - and the right content -
    # so we apply it to set the score to -inf for padding.
    score.masked_fill_(mask, float('-inf'))
    # We call softmax on the N axis (dim=2); it will transform those -inf into 0:
    # no attention will be focused on padding elements.
    return F.softmax(score, dim=2)
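Following the process described at the beginning, the optimised version can be checked against the plain one on random inputs. A minimal sketch, assuming the naive version above has been kept around under the (hypothetical) name attention_naive:
import torch

torch.manual_seed(0)
N, M, B, H = 7, 5, 3, 16
hs = torch.randn(N, B, H)
ht = torch.randn(M, B, H)
hs_lengths = torch.tensor([7, 4, 2])

assert torch.allclose(attention(hs, ht, hs_lengths),
                      attention_naive(hs, ht, hs_lengths))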