Tensor Stunts I: batch-aware attention computing

First part of tensor stunts.

Numpy and pytorch (and, more generally, languages and libraries used for mathematical computation) promote replacing good old for loops with creative use of the tensor libraries. As Python is inherently slow, this can dramatically improve the efficiency of your code, especially if you use a GPU. In some of the examples below, the optimised code ran forty (40) times faster than the original.

However, those optimisations are often unintuitive; I have sometimes used an LLM (as a kind of improved Stack Overflow) to find the right python functions to use, and I usually needed to edit the result.

Here are a few examples of non-trivial uses of pytorch libraries for complex computations.

The Process

When I start optimising code this way, I often keep a straightforward, unoptimised version of the code, written in plain Python, which I use to validate the computation. It takes a bit of time, but it lets you be confident in the optimised code. Unit tests are also very valuable.
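For instance, validation can be as simple as running both versions on the same random inputs and comparing the results. A minimal sketch (slow_version, fast_version and make_random_inputs are placeholder names, not from any library):

import torch

def check_same(slow_version, fast_version, make_random_inputs, n_trials=20):
    # Run both implementations on identical random inputs and compare the outputs.
    for _ in range(n_trials):
        inputs = make_random_inputs()
        assert torch.allclose(slow_version(*inputs), fast_version(*inputs), atol=1e-6)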

Attention computing and masking

I wanted to compute an attention along the lines of Luong et al. in their paper “Effective Approaches to Attention-Based Neural Machine Translation” (DOI). It's not very hard per se, as it's only a scalar product. But if you add a batch dimension, it becomes much more complex.

Attention was originally used for text translation in the encoder/decoder architecture.

  • The encoder network produces a sequence of vectors, \(s_i\), \(i \in [1..n]\), representing the input sentence;
  • the decoder produces vectors \(t_j\) for word position \(j\), \(j \in [1..m]\) in the translation;
  • The attention mechanism determines which parts of the input \(s_i\) are most relevant when producing word \(j\) in the translation. The final attention, from a position \(j\) in the output toward an input sentence \(s\) of length \(n\), is a vector \(a\) of size \(n\), where: 1) \(\forall i, 0 \le a_i \le 1\) and 2) \(\sum_{i=1}^{n} a_i = 1\)

Attention is computed in two steps:

  • first, we compute a score with a simple scalar product between each \(s_i\) and \(t_j\):

\[ \mathrm{score}_{i} = s_i \cdot t_j\]

  • then, we transform the score into an actual attention by performing a softmax:

\[ a_{i} = \frac{\exp(\mathrm{score}_{i})}{\sum_{k=1}^{n} \exp(\mathrm{score}_{k})}\]
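For a single sentence and a single output position, this fits in a couple of lines of pytorch. A minimal sketch, assuming the encoder states are stacked in a tensor s of shape (n, H) and t_j is a single decoder state of shape (H,):

import torch
import torch.nn.functional as F

n, H = 5, 8
s = torch.randn(n, H)         # encoder states s_1 .. s_n
t_j = torch.randn(H)          # decoder state for output position j

score = s @ t_j               # shape (n,): score_i = s_i . t_j
a = F.softmax(score, dim=0)   # shape (n,): values in [0, 1], summing to 1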

Padding and attention

The basic mechanism of attention is rather simple. However, when training a system, we normally use batches, working on more than one sentence at a time. As the sentences in the training corpus can have any length, a batch is typically of shape (B,M), with M being the length of the longest sentence in this particular batch.

For shorter sentences in the batch, we usually use padding to mark that certain positions are not really used. In the original input, sentences are encoded as sequences of indexes of words in the vocabulary, and “1” is often used as a special index for padding.

If we have a batch with sentences:

  • the nice cat eats
  • hello you
  • the mouse runs

We have B = 3 and M = 4; the padded batch would be:

  • the nice cat eats
  • hello you <pad> <pad>
  • the mouse runs <pad>

And the encoding of the batch as an integer tensor would be:

\[ \mathrm{s} = \begin{pmatrix} 10 & 15 & 8 & 7\\ 25 & 90 & 1 & 1\\ 10 & 105 & 22 & 1\\ \end{pmatrix}\]
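In pytorch, such a padded batch can be built with torch.nn.utils.rnn.pad_sequence; a small sketch, reusing the (made-up) word indexes above:

import torch
from torch.nn.utils.rnn import pad_sequence

sentences = [
    torch.tensor([10, 15, 8, 7]),   # the nice cat eats
    torch.tensor([25, 90]),         # hello you
    torch.tensor([10, 105, 22]),    # the mouse runs
]
s = pad_sequence(sentences, batch_first=True, padding_value=1)
# tensor([[ 10,  15,   8,   7],
#         [ 25,  90,   1,   1],
#         [ 10, 105,  22,   1]])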

Attention and padding interact! When we have padding at a given word position, the attention toward this position should be 0, because this word is simply not there in reality! The system might learn this on its own, but that would be a waste of computational power; we can make sure the attention is right from the start.

The idea is to change the values after the computation of the score. If, for positions which correspond to padding, we set the score value to \(-\infty\), then the exponential will be 0, and so will the attention.
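A quick check of the trick on a toy tensor:

import torch
import torch.nn.functional as F

score = torch.tensor([2.0, 1.0, float('-inf')])
print(F.softmax(score, dim=0))
# tensor([0.7311, 0.2689, 0.0000])   the -inf position gets exactly 0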

Let's see how we can code this.

Notations

We use the following conventions:

  • B is the batch dimension;
  • N is the max length of input sentences for this batch;
  • M is the max length of output sentences for this batch;
  • hs is the input tensor, of shape \((N,B,H)\) (batch in second position, see below);
  • ht is the output tensor, of shape \((M,B,H)\);
  • attention will be of shape \((B,M,N)\), where attention[b,i,j] is the attention from position i in the output to position j in the input, for batch entry b;
  • hs_lengths is a tensor of shape \((B,)\) containing the actual sentence lengths in hs.
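For the examples below, you can picture dummy tensors with these shapes; the sizes here are arbitrary, chosen only for illustration:

import torch

B, N, M, H = 3, 4, 5, 16
hs = torch.randn(N, B, H)             # input (encoder) states, batch in second position
ht = torch.randn(M, B, H)             # output (decoder) states
hs_lengths = torch.tensor([4, 2, 3])  # actual sentence lengths, one per batch entry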

Naive implementation

We use the standard pytorch convention for recurrent networks, with the batch in second position (sequence first). Apparently, this layout is no longer really useful, but anyway...

A plain version of the attention would be:

import torch
import torch.nn.functional as F

def attention(hs, ht, hs_lengths):
    """
    Input:
        hs : tensor N x B x H
        ht : tensor M x B x H
        hs_lengths : tensor of shape (B,), the lengths of the sentences s in the batch.
    Output:
        tensor B x M x N
    """
    N, B, H = hs.shape
    M, B1, H1 = ht.shape
    assert B == B1, "same batch sizes"
    assert H == H1, "same hidden sizes"
    # scalar product (we might also use torch.bmm here)
    score = ht.permute(1, 0, 2) @ hs.permute(1, 2, 0)  # dimensions B x M x N
    assert score.shape == (B, M, N)

    # Now, the complex part:
    # In the final result, we want
    #
    #               attention[b,i,j] = 0 if j >= hs_lengths[b]
    #
    # if j >= hs_lengths[b], there is not really anything meaningful in hs[j,b,:];
    #    it's a vector for padding only, and we are not going to attend to padding.
    #
    # `score` currently contains logits, not attention values. It will be transformed into
    # attention values by softmax.
    #
    # Hence the following simple trick: set the values in `score` to minus infinity when
    # they are to be ignored. The softmax applied to -inf will be 0.

    for b in range(B):
        score[b, :, hs_lengths[b]:] = float('-inf')

    return F.softmax(score, dim=2)
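A quick sanity check of this version, with arbitrary sizes (the lengths 4, 2, 3 match the padded batch shown earlier): every attention row should sum to 1, and padded positions should get an attention of exactly 0.

import torch

B, N, M, H = 3, 4, 5, 16
hs = torch.randn(N, B, H)
ht = torch.randn(M, B, H)
hs_lengths = torch.tensor([4, 2, 3])

att = attention(hs, ht, hs_lengths)
assert att.shape == (B, M, N)
assert torch.allclose(att.sum(dim=2), torch.ones(B, M))  # each row sums to 1
assert torch.all(att[1, :, 2:] == 0)  # sentence 2 has length 2: positions 2 and 3 get no attention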

The code is not too long, and is reasonably simple. However, it's not very efficient. Let's see if we can remove the loop.

Pytorch optimisation

We are going to build a boolean mask of shape \((B,M,N)\) which will tell us which values should be set to \(-\infty\).

In the naive implementation, we wrote:

# Remember, the shape of score is (B,M,N)
for b in range(B):
    score[b, :, hs_lengths[b]:] = float('-inf')

It shows that nothing peculiar happens in the M dimension. Basically, we want mask[b,i,j] to be the same for all values of i. Thus, we can work first without the second axis (M), create a kind of mask0[b,j], and then duplicate its entries to create mask[b,i,j] = mask0[b,j] for all i.

mask0 will contain an entry for each batch position; mask0[b,:] will show in fact which part of the tensor is used for actual words, and which part is used for padding.

More precisely, if j < hs_lengths[b], then the word in position j in batch entry b is an actual word; if j >= hs_lengths[b], it's padding. mask0[b,j] will be true for padding positions, and false for “normal” positions.

For instance, if we have B = 3 and N = 4, we could have

\[ \mathrm{mask0} = \begin{pmatrix} \mathrm{False} & \mathrm{False} & \mathrm{False} & \mathrm{False}\\ \mathrm{False} & \mathrm{False} & \mathrm{True} & \mathrm{True}\\ \mathrm{False} & \mathrm{False} & \mathrm{False} & \mathrm{True}\\ \end{pmatrix}\]

Meaning that the first batch entry is actually a sentence of length 4 (all values set to False); the second sentence is of length 2, and the third of length 3.

Building mask0

Instead of looping, we use torch.arange to build a sequence which corresponds to each possible position in the input tensor hs:

t1 = torch.arange(N, device=hs.device)

We then add an axis for the batch dimension:

t1 = t1.unsqueeze(0)

This gives us a tensor of shape \((1,N)\), with t1[0,k] == k for all k.

If we compared it (with >=) to a tensor v of shape (1,N) filled with a constant sentence length vsize, we would get a boolean tensor m0 of shape (1,N), with m0[0,i] == True \(\:\Leftrightarrow\) i >= vsize.

Now, we want to build mask0, of shape (B,N), with mask0[b,j] true iff j >= hs_lengths[b]. To build this mask, we can combine our previous idea with a numpy/pytorch feature: broadcasting.

Broadcasting means that, for some operations, when a tensor of shape (a,b) is needed and a tensor t of shape (1,b) is passed, pytorch will (virtually) duplicate the values along the axis of length 1.

Basically, if we have a tensor \(z\) of shape (1,4) :

\[ z = \begin{pmatrix} 5 & 9 & 3 & 7 \\ \end{pmatrix}\]

and we need to combine it with a tensor of shape (3,4), pytorch will consider \(z\) to be

\[ z = \begin{pmatrix} 5 & 9 & 3 & 7 \\ 5 & 9 & 3 & 7 \\ 5 & 9 & 3 & 7 \\ \end{pmatrix}\]

Broadcasting operates on axes of size 1, expanding them to larger sizes when needed. Hence the previous use of unsqueeze.
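A tiny illustration with arbitrary values:

import torch

z = torch.tensor([[5, 9, 3, 7]])   # shape (1, 4)
w = torch.zeros(3, 4)              # shape (3, 4)
print((z + w).shape)               # torch.Size([3, 4]): z was broadcast along the first axis
print(z.expand(3, 4))              # the three identical rows pytorch "sees" during the operation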

Multiple broadcasts can be performed. If we need two tensors t'1 and t'2 of the same shape (a,b), and we actually have t1.shape == (1,b) and t2.shape == (a,1), pytorch will broadcast both tensors to shape (a,b).

We have:

t1 = torch.arange(N, device=hs.device).unsqueeze(0)
t2 = hs_lengths.unsqueeze(1)

  • t1, of shape (1,N), contains the sequence of indexes;
  • t2, of shape (B,1), contains the length of each sentence in the batch.

For instance: \( t1 = \begin{pmatrix} 0 & 1 & 2 & 3 \\ \end{pmatrix} \) and \( t2 = \begin{pmatrix} 4\\ 3\\ 2\\ \end{pmatrix} \)

If we combine them, broadcasting will consider them as being respectively :

\[ t'1 = \begin{pmatrix} 0 & 1 & 2 & 3 \\ 0 & 1 & 2 & 3 \\ 0 & 1 & 2 & 3 \\ \end{pmatrix}\]

and

\[ t'2 = \begin{pmatrix} 4& 4 & 4 & 4\\ 3 & 3 & 3 & 3\\ 2 & 2 & 2 & 2\\ \end{pmatrix}\]

Computing t1 >= t2 in pytorch will ensure t1 and t2 are broadcast, and the result will be a boolean tensor, mask0:

\[ \mathrm{mask0} = \begin{pmatrix} \mathrm{False} & \mathrm{False} & \mathrm{False} & \mathrm{False}\\ \mathrm{False} & \mathrm{False} & \mathrm{False} & \mathrm{True}\\ \mathrm{False} & \mathrm{False} & \mathrm{True} & \mathrm{True}\\ \end{pmatrix}\]
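Putting it together on this example (N = 4, sentence lengths 4, 3, 2):

import torch

N = 4
hs_lengths = torch.tensor([4, 3, 2])
t1 = torch.arange(N).unsqueeze(0)   # shape (1, 4)
t2 = hs_lengths.unsqueeze(1)        # shape (3, 1)
mask0 = t1 >= t2                    # shape (3, 4): both operands are broadcast
# tensor([[False, False, False, False],
#         [False, False, False,  True],
#         [False, False,  True,  True]])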

From mask0 to mask

The actual mask we want is of shape (B,M,N). Ours is of shape (B,N). We want to create mask such that mask[b,i,j] = mask0[b,j] for all i. Basically, we want to copy each entry in mask0[b,j] M times, along a new second axis. There are many ways to perform this in pytorch, but the most efficient is

mask = mask0.unsqueeze(1).expand(-1, M, -1)

  • we go from shape (B,N) to shape (B,1,N) by introducing a second axis with unsqueeze;
  • we use expand to duplicate all entries M times along this new axis;
  • done!
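Continuing the example above with M = 2:

import torch

mask0 = torch.tensor([[False, False, False, False],
                      [False, False, False, True],
                      [False, False, True,  True]])  # shape (B, N) = (3, 4)
M = 2
mask = mask0.unsqueeze(1).expand(-1, M, -1)          # shape (B, M, N) = (3, 2, 4)
print(mask[2])
# tensor([[False, False,  True,  True],
#         [False, False,  True,  True]])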

Using the mask

Once we have the mask, we can use it. Values which correspond to padding will be set to \(-\infty\), and applying softmax will give a result of 0, which is what we want: padding positions will get no attention.

We can use this mask as a boolean index into the score tensor:

score[mask] = float('-inf')

Or we can use the dedicated, in-place masked_fill_ function:

score.masked_fill_(mask, float('-inf'))
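Both approaches give the same result; a quick check on random values (masked_fill, without the trailing underscore, is the out-of-place variant):

import torch

score = torch.randn(3, 2, 4)
mask = torch.zeros(3, 2, 4, dtype=torch.bool)
mask[:, :, -1] = True                        # pretend the last position is padding

a = score.clone()
a[mask] = float('-inf')                      # boolean indexing
b = score.masked_fill(mask, float('-inf'))   # out-of-place masked_fill
assert torch.equal(a, b)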

Resulting code

import torch
import torch.nn.functional as F

def attention(hs, ht, hs_lengths):

    N, B, H = hs.shape
    M, B1, H1 = ht.shape
    assert B == B1, "same batch sizes"
    assert H == H1, "same hidden sizes"
    score = torch.bmm(ht.transpose(0, 1), hs.permute(1, 2, 0))
    assert score.shape == (B, M, N)
    # Compute the sequence corresponding to max sentence length, and add the
    # dimension corresponding to the batch as a 1-sized dimension
    t1 = torch.arange(N, device=hs.device).unsqueeze(0)
    assert t1.shape == (1, N), f"t1 shape is {t1.shape}, expected (1, N)"
    # Create a vertical tensor corresponding to sentence length values
    t2 = hs_lengths.unsqueeze(1)
    assert t2.shape == (B, 1), f"t2 shape is {t2.shape}, expected (B, 1)"
    # Now compare the two.
    # Broadcasting will create a tensor of shape (B,N)
    # where mask0[b,j] = False means that the entry at position j should be kept
    #                            for batch entry b,
    #                    True  means it should be replaced by -inf.
    mask0 = t1 >= t2
    assert mask0.shape == (B, N), f"mask0 shape is {mask0.shape}, expected (B, N)"
    #
    # Now, we have only one vector per batch entry in the mask. We should have M of them,
    # as the mask should be of shape (B,M,N).
    # We want mask[b,i,j] = mask0[b,j] for all i.
    #
    # We could use repeat here, but expand is more efficient.
    mask = mask0.unsqueeze(1).expand(-1, M, -1)

    # Our mask has the right shape - and the right content.
    # We apply it to set the score to -inf for padding.
    score.masked_fill_(mask, float('-inf'))

    # We call softmax on the N axis (dim=2),
    # and it will transform those -inf into 0:
    # no attention will be focused on padding elements.
    return F.softmax(score, dim=2)
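As mentioned at the beginning, the naive version can now serve as a reference to validate (and time) this one. A minimal sketch, assuming the loop-based version has been kept under the placeholder name attention_naive:

import time
import torch

B, N, M, H = 64, 50, 40, 256
hs = torch.randn(N, B, H)
ht = torch.randn(M, B, H)
hs_lengths = torch.randint(1, N + 1, (B,))

# attention_naive is the loop-based reference implementation shown earlier.
assert torch.allclose(attention_naive(hs, ht, hs_lengths),
                      attention(hs, ht, hs_lengths))

start = time.perf_counter()
for _ in range(100):
    attention(hs, ht, hs_lengths)
print(f"100 calls in {time.perf_counter() - start:.3f}s")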
