Tensor Stunts II: Conditional Random Fields and the “gather” function

Second part of tensor stunts.

Conditional Random Fields were very fashionable in the 2010s. They are a powerful and versatile way to label text sequences. Of course, the rise of Deep Learning and LLMs has overshadowed them. However, they can still be quite useful for sequence labelling, even in a deep learning setting.

In sequence labelling, we have a sentence \(x_1,x_2,\cdots,x_n\), and we want to associate a label with each word. We might use it to perform part-of-speech tagging, for instance:

he loves to go to new york
PRP VBZ TO VB IN JJ NNP

Or to perform named entity recognition:

he loves to go to new york
O O O O O B-PLACE I-PLACE

In both cases, we have a sentence \(x\), and we want to find the best label sequence for \(x\), i.e. the label sequence \(t'\) such that

\[ t' = \mathrm{argmax}_t P(t|x)\]

More precisely:

\[ (t'_1, t'_2,\cdots t'_n) = \mathrm{argmax}_{(t_1, t_2,\cdots t_n)} P(T_1=t_1, T_2=t_2,\cdots T_n=t_n|X_1=x_1, X_2=x_2,\cdots X_n=x_n)\]

Most deep learning systems used in sequence labelling (n-to-n models) don't model the notion of a label sequence. They work as if the labels in the sequence were independent. However, that is hardly the case.

Suppose you have an input \((x_1,x_2,\cdots x_n)\). You want to produce an output \((t_1,t_2,\cdots t_n)\) where \(t_i\) is the label of word \(x_i\). A typical recurrent network returns an output \((y_1,y_2,\cdots y_n) = f(x_1,x_2,\cdots x_n)\), where each \(y_i\) is a vector of logits of dimension \(k\), \(k\) being the number of possible labels. Usually, \(y_i\), which is a vector of real numbers, is transformed into a probability distribution using a softmax.
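
To make the shapes concrete, here is a small, self-contained illustration (the sizes are arbitrary and the tensors random; this snippet is not part of the CRF code discussed below):

import torch

n, k = 7, 5                        # sentence length, number of possible labels
y = torch.randn(n, k)              # one logit vector of dimension k per word
probs = torch.softmax(y, dim=-1)   # turn each y_i into a probability distribution
print(probs.sum(dim=-1))           # tensor of ones: each row sums to 1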

The naive approach is to take \(t_i = \mathrm{argmax}_j(y_i^j)\), i.e. the position of the largest coordinate of \(y_i\). But this is correct only if

\[ P(T=(t_1,\cdots,t_n)|X=(x_1,\cdots,x_n)) = P(T_1=t_1|X_1=x_1)P(T_2=t_2|X_2=x_2)\cdots P(T_n=t_n|X_n=x_n)\]

that is, if the value of label \(T_i\) doesn't depend on the value of the previous labels.

This assumption does not hold. For instance, in named entity recognition, when an entity spans multiple words, such as New York, the first word is labeled B-PLACE (beginning of the entity), and the subsequent words are labeled I-PLACE (inside the entity). An I- label can't start a named entity. Hence, we have \(P(T_i=\mathrm{I-PLACE}|T_{i-1}=\mathrm{O})=0\).

Using Conditional Random Fields allows a more realistic modeling. The recurrent network is kept, but the final probability takes into account both the label scores predicted by the network and a simple transition score from \(T_i\) to \(T_{i+1}\).
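
In code, this amounts to adding a transition matrix to the model. Here is a minimal sketch of the kind of layer the scoring functions below belong to; the class name and the initialisation are assumptions for illustration, but self.transition and self.codeO are the attributes used in the rest of this post:

import torch
import torch.nn as nn

# the post's code refers to a global "device"
device = "cuda" if torch.cuda.is_available() else "cpu"

class CRFLayer(nn.Module):
    """Hypothetical skeleton holding the CRF parameters used below."""
    def __init__(self, num_labels, codeO):
        super().__init__()
        # transition[i, j]: score of moving from label i to label j (log-space)
        self.transition = nn.Parameter(torch.randn(num_labels, num_labels))
        # index of the "O" label, used as the label "before" the first word
        self.codeO = codeO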

The pytorch tutorial gives an example of a CRF.

Here, we will examine how we can use pytorch to optimise the computations. The function we want to optimise is:

def _naive_score(self, features, expected_labels):
    """
    Computes the score in a naive way
    features : shape (L x B x V) of floats
    expected_labels : shape (L x B) of long

    (returns same values as score.)

    return : scores shape (B)
    """
    (L, B, V) = features.shape
    score = torch.zeros(size=(B,)).to(device=device)
    for i in range(0, L):
        for j in range(0, B):
            current_label = expected_labels[i][j]
            if i > 0:
                previous_label = expected_labels[i-1][j]
            else:
                previous_label = self.codeO
            # emission score for the expected label, plus the transition score
            score[j] += features[i, j, current_label] + self.transition[previous_label][current_label]
    return score

Note that self.codeO is the code used for the O label (marking that a word is not part of a named entity).

It takes as input \((y_1\cdots y_n)\) (here called features), which are the outputs of the recurrent network, and the expected labels \((t_1\cdots t_n)\), and computes the (unnormalised) log-score the model assigns to those expected labels, i.e.

\[ \log\left(\prod_{i=1}^{n} \exp\left(y_i^{(t_i)}\right)\exp\left(A[t_{i-1},t_i]\right)\right)\]

where \(A\) is the transition matrix (self.transition) and \(t_0\) is the O label.

In a simpler way: the score of each label is the sum of the logit for this label and of a transition value from the previous label.

The code is a bit more complex, as it deals with batches. It's also slow, with two layers of Python loops. Can we remove them?

The optimised code

The optimised function is the following. It requires some explanation...

def score(self, features, expected_labels):
    """
    Compute the score of each (input, expected labels) pair in the batch
    features : shape (L x B x V) of floats
    expected_labels : shape (L x B) of long

    return : scores shape (B)
    """
    # The features based values, using gather:
    label1 = expected_labels.unsqueeze(dim=2) # we add an axis for gather
    f1 = features.gather(2, label1).squeeze(dim=2) # shape (L x B)

    # then, we work on the transitions
    # We create a prefix corresponding to "O" labels: (1 x B)
    prefix = torch.full((1,features.shape[1]), fill_value= self.codeO, dtype=torch.long).to(device=device)
    first_indexes = torch.concat((prefix, expected_labels), dim=0)[:-1,:]
    # We get the transition matrix
    transitions = self.transition[first_indexes, expected_labels]
    # result : L x B
    return torch.sum(f1, dim = 0) + torch.sum(transitions, dim=0)
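
As a quick sanity check (a sketch, assuming both methods above have been attached to the hypothetical CRFLayer skeleton shown earlier), the vectorised version can be compared against the naive one on random inputs:

L, B, V = 7, 4, 5                                    # sentence length, batch size, label count
crf = CRFLayer(num_labels=V, codeO=0).to(device)

features = torch.randn(L, B, V, device=device)       # fake outputs of the recurrent network
expected_labels = torch.randint(0, V, (L, B), device=device)

# both functions should return the same (B,)-shaped scores
assert torch.allclose(crf.score(features, expected_labels),
                      crf._naive_score(features, expected_labels))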

The use of gather

If we look at the naive code, we see that, for each expected label \(t_i\), we use it as an index to extract the corresponding value in \(y_i\): features[i,j,expected_labels[i,j]]. Then, we sum those values over all i.

Thus, we would like a way, given a tensor of integer indexes (here, the expected labels), to use it to extract values from another tensor, so that we can sum them later on. That's what gather does.

In the following code:

label1 = expected_labels.unsqueeze(dim=2) # we add an axis to use gather, shape (L,B,1)
f0 = features.gather(2, label1) # (L,B,1)
f1 = f0.squeeze(dim=2) # shape (L x B)    
  1. we start by adding a third axis to expected_labels, because gather needs an index tensor with the same number of axes as features;
  2. we call gather on features. It extracts values from features, selecting them with the indexes taken from label1, along the axis given as first argument (here, axis 2).

f0, the result of gather, has the same shape as label1, i.e. (L,B,1). If we define k = label1[i,b,0], then f0[i,b,0] will be the value features[i,b,k]... which is exactly what we want.

Hence, our f1 variable will contain the logits corresponding to each of the expected labels. We can then use a pytorch sum and avoid using a loop to compute and sum them.
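
To see gather in isolation, here is a tiny example, unrelated to the CRF code:

import torch

t = torch.tensor([[10., 11., 12.],
                  [20., 21., 22.]])     # shape (2, 3)
idx = torch.tensor([[2],
                    [0]])               # shape (2, 1): one index per row
r = t.gather(1, idx)                    # r[i, 0] = t[i, idx[i, 0]]
print(r.squeeze(1))                     # tensor([12., 20.])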

Computing the sum of the transitions

Computing the sum of transitions from one state to another doesn’t require using a sophisticated function like gather, but it still feels a bit like a workout.

We are going to use the advanced indexing possibilities in numpy and pytorch.

In pytorch, if a and b are integer tensors of the same shape, and if t is another tensor of shape (n,m), writing t[a,b] will create a tensor with the same shape as a and b, using the values in a and b as indexes.

In other words, if r is the resulting tensor, we have r[i,j] = t[a[i,j], b[i,j]].
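
A tiny example, again unrelated to the CRF code, makes the rule concrete:

import torch

t = torch.arange(12).reshape(3, 4)      # shape (3, 4)
a = torch.tensor([[0, 2],
                  [1, 1]])              # row indexes
b = torch.tensor([[3, 0],
                  [1, 2]])              # column indexes
r = t[a, b]                             # r[i, j] = t[a[i, j], b[i, j]]
print(r)                                # tensor([[3, 8],
                                        #         [5, 6]])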

We want to build two tensors of indexes: one containing the label index at position \(p\), and the other the label index at position \(p+1\).

We already have this second tensor, expected_labels. What we need is another tensor of the same shape, but shifted by one position.

We have a problem, because the position “before” the actual first label is not a real position, but we will suppose it has the label O. A better option would be to define a specific “start” label.

In the following code, we build this tensor,

  • by creating a fake first row of shape (1 x B), filled with the code for O;
  • concatenating it with expected_labels along dim 0 and dropping the last row (the last time step); the result is first_indexes;
  • getting the values from the transition matrix for all the transitions in the expected label sequence.

Thus, our transitions tensor has shape \((L \times B)\); each column holds the sequence of transition values for the corresponding element of the batch.

prefix = torch.full((1,features.shape[1]), fill_value= self.codeO, dtype=torch.long).to(device=device)
first_indexes = torch.concat((prefix, expected_labels), dim=0)[:-1,:]
transitions = self.transition[first_indexes, expected_labels]
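
On a toy batch (with codeO assumed to be 0 here), the shift is easy to visualise:

import torch

codeO = 0
expected_labels = torch.tensor([[1, 3],
                                [2, 4],
                                [2, 0]])          # shape (L=3, B=2)
prefix = torch.full((1, 2), fill_value=codeO, dtype=torch.long)
first_indexes = torch.concat((prefix, expected_labels), dim=0)[:-1, :]
print(first_indexes)
# tensor([[0, 0],
#         [1, 3],
#         [2, 4]])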

Gather vs advanced indexing

We could have used advanced indexing instead of gather in the first case. We wanted a tensor \(R\) of shape \((L \times B)\), with \(R[i,j] = features[i,j,expected_labels[i,j]]\). To perform this, we need three tensors of indexes:

  • a first one, a, where a[i,j]=i;
  • a second one, b, where b[i,j]=j;
  • and expected_labels.

But building a and b would be slower than using gather.
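
For comparison, here is a sketch of what that alternative could look like (the a and b tensors are built with arange and expand; gather spares us these extra index tensors):

import torch

L, B, V = 7, 4, 5
features = torch.randn(L, B, V)
expected_labels = torch.randint(0, V, (L, B))

# index tensors such that a[i, j] = i and b[i, j] = j
a = torch.arange(L).unsqueeze(1).expand(L, B)
b = torch.arange(B).unsqueeze(0).expand(L, B)

f1_indexing = features[a, b, expected_labels]     # shape (L, B)
f1_gather = features.gather(2, expected_labels.unsqueeze(2)).squeeze(2)
assert torch.equal(f1_indexing, f1_gather)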
