A Journey Through Fastbook (AJTFB) - Chapter 6: Multilabel Classification
It's the "more things you can do with computer vision" chapter of "Deep Learning for Coders with fastai & PyTorch"! We'll go over everything you need to know to get started with multi-label classification tasks, from DataBlocks to training and everything in between. Next post we'll look at regression tasks, in particular the key point regression models that are also covered in chapter 6. Soooo let's go!
- Multiclass vs Multi-label classification (again)...
- Defining your DataBlock
- Train a model
- Summary
- Resources
Multiclass vs Multi-label classification (again)...
Last post we saw that multiclass classification is all about predicting the SINGLE CLASS an object belongs to from a list of two or more classes. It's the go-to task if we're confident that every image our model sees is going to be one of these classes. Cross-entropy loss is our go-to loss function, as it wants to confidently pick one thing.
Multi-label classification involves predicting MULTIPLE CLASSES to which an object belongs; it can belong to one, some, all, or even none of those classes. For example, you may be looking at satellite photos from which you need to predict the different kinds of terrain (your classes) each contains.
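To make the difference concrete, here is a minimal plain-Python sketch (the class names are made up for illustration) of the multi-hot target a multi-label task trains against: one 0/1 flag per possible class, rather than a single class index.

```python
# Hypothetical terrain classes for a satellite-photo task.
classes = ["agriculture", "road", "water", "clear"]

def multi_hot(labels, classes=classes):
    """Return a 0/1 vector with a 1 for every label present on the item."""
    return [1 if c in labels else 0 for c in classes]

print(multi_hot(["road", "water"]))  # → [0, 1, 1, 0]
print(multi_hot([]))                 # → [0, 0, 0, 0]  (none of the classes)
```

Note that an all-zero target is perfectly valid here, which is exactly what multiclass classification can't express.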
Defining your DataBlock
Again, the DataBlock is a blueprint for everything required to turn your raw data (images and labels) into something that can be fed through a neural network (DataLoaders with a numerical representation of both your images and labels). Below is the one presented in this chapter.
from fastai.vision.all import *
path = untar_data(URLs.PASCAL_2007)
Instead of working with the filesystem structure to get our images and define our labels, in this example we use a .csv file that we can explore and manipulate further via a pandas DataFrame.
Here are some of my favorite pandas resources:
- https://chrisalbon.com/
- https://pandas.pydata.org/docs/ (yah, the docs are really good!)
df = pd.read_csv(path/'train.csv')
df.head()
def get_x(r): return path/'train'/r['fname']
def get_y(r): return r['labels'].split(' ')
def splitter(df):
    train = df.index[~df['is_valid']].tolist()
    valid = df.index[df['is_valid']].tolist()
    return train, valid
dblock = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
get_x=get_x,
get_y=get_y,
splitter=splitter, # or could just have used ColSplitter()
item_tfms=RandomResizedCrop(128, min_scale=0.35))
Let's break down our blueprint!
- Define the data types for our inputs and targets via the blocks argument. This is defined as a tuple, where we tell our DataBlock that the inputs are images and our targets are multiple potential categories. Above we can see that these labels are space-delimited in the "labels" column. MultiCategoryBlock essentially returns a one-hot encoded list of the possible labels, with a 0 indicating that the label wasn't found for the item and a 1 indicating otherwise (see the DataBlock.summary results below).
- Define how we're going to get our images via get_x. As we'll be passing in a DataFrame as the raw data source, we don't need to define a get_items to pull the raw data. We do, however, need to instruct the DataBlock how to find the images, which we do via the get_x method. That method receives one row of the DataFrame (r) at a time.
- Define how, from the raw data, we're going to create our labels via get_y. As already mentioned, the classes are in the "labels" column and delimited by a space, and so we return a list of labels by splitting on ' '. Easy peasy.
- Define how we're going to create our validation dataset via splitter. Here we define a custom splitter mostly just to show you how to do it. It has to return a tuple of train and validation indices. We could have just used ColSplitter (see it in the docs here).
- Define things we want to do to each item via item_tfms. item_tfms are transforms, or things we want to do, to each input individually! Above we only have one, which says, "Randomly crop the image to 128x128, capturing at least 35% of the original image each time you grab it." See here for more info on RandomResizedCrop.
- Define things we want to do to each mini-batch of items via batch_tfms. None here, but remember that these are transforms you want applied to a mini-batch of images on the GPU at the same time.
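To see exactly what our custom splitter returns, its logic can be mimicked in plain Python (hypothetical rows, no pandas required): row indices where is_valid is False go to training, and the rest to validation.

```python
# Made-up rows standing in for the DataFrame.
rows = [
    {"fname": "a.jpg", "is_valid": False},
    {"fname": "b.jpg", "is_valid": True},
    {"fname": "c.jpg", "is_valid": False},
]

def splitter(rows):
    """Return (train_indices, valid_indices) based on the is_valid flag."""
    train = [i for i, r in enumerate(rows) if not r["is_valid"]]
    valid = [i for i, r in enumerate(rows) if r["is_valid"]]
    return train, valid

print(splitter(rows))  # → ([0, 2], [1])
```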
Important: Do not use lambda functions for defining your DataBlock methods! They can't be serialized, so you're likely to get errors when you try to save/export your DataLoaders and/or Learner.
Important: Verify that your DataBlock works as expected (or troubleshoot it) by running DataBlock.summary(data)
dblock.summary(df)
Now we can create our DataLoaders and take a look at our x's and y's, our images and their labels (images with multiple labels have them separated by semicolons).
dls = dblock.dataloaders(df)
dls.show_batch()
To get a feel for what our item_tfms (and batch_tfms if we had them) are doing, we can show_batch using a single image as we do below.
dls.show_batch(unique=True)
The combination of what we're doing in the item_tfms and batch_tfms is known as presizing.
"Presizing is a particular way to do image augmentation that is designed to minimize data destruction while maintaining good performance." After resizing all the images to a larger dimension that we will train on, we perform all our core augmentations on the GPU. This results in both faster and less destructive transformations of the data.
Define your loss function
To train a model we need a good loss function that will allow us to optimize the parameters of our model. For multi-label classification tasks, where we may want to predict multiple classes/labels (or none at all), the go-to is binary cross-entropy loss.
Why can't we just use cross-entropy loss?
Because "the softmax function really wants to pick one class," whereas here we want it to be able to pick multiple classes or even none.
"softmax ... requires that all predictions sum to 1, and tends to push one activation to be much larger than the others (because of the use of exp
) ... we may want the sum to be less than 1, if we don't think any of the categories appear in an image."
"nll_loss ... returns the value of just one activation: the single activation corresponding with the single label for an item [which] doesn't make sense when we have multiple labels"
learn = cnn_learner(dls, resnet18)
xb, yb = to_cpu(dls.train.one_batch())
res = learn.model(xb)
xb.shape, yb[0], res.shape, res[0]
So now we need a loss function that will scale those activations to be between 0 and 1 and then compare each activation with the value (0 or 1) in each target column.
def bce(inputs, targets):
    inputs = inputs.sigmoid()
    return -torch.where(targets==1, inputs, 1-inputs).log().mean()
print(bce(res, yb))
So breaking the above down, line by line, for a single input/targets ...
inps = res.sigmoid()
print(f'1. {inps[0]}')
print(f'2. {yb[0]}')
print(f'3. {torch.where(yb==1, inps, 1-inps)[0]}')
print(f'4. {torch.where(yb==1, inps, 1-inps)[0].log()}')
print(f'5. {torch.where(yb==1, inps, 1-inps)[0].log().mean()}')
print(f'6. {-torch.where(yb==1, inps, 1-inps)[0].log().mean()}')
... what is binary cross-entropy loss doing?
Scale all activations to be between 0 and 1 using the sigmoid function (1). The resulting activations tell us, for each potential label, how confident the model is that the value is a "1".
Build a tensor with a value for each target (2); if the target = 1 then use the corresponding scaled value above, and if the target = 0, then use 1 minus this value (3). Notice how confident correct predictions will be close to 1, while confident incorrect predictions will be close to 0. We can think of this value as telling us how right the model was in predicting each label.
Take the log (4), which turns correct, confident predictions (those closer to 1) into values close to zero, and wrong, confident predictions (those closer to 0) into large negative values, since log(1) = 0 whereas log(x) approaches negative infinity as x approaches 0! See the chart below. This is exactly what we want: the better the model, the smaller the loss.
Lastly, because the loss has to be a single value, we take the mean of the per-label losses (5), and then negate it to make it positive (6).
plot_function(torch.log, 'x (prob correct class)', '-log(x)', title='Negative Log-Likelihood', min=0, max=1)
Fortunately, PyTorch has a function and module we can use:
loss = F.binary_cross_entropy_with_logits(res, yb)
print(loss)
# modular form (most commonly used)
loss_func = nn.BCEWithLogitsLoss()
loss = loss_func(res, yb)
print(loss)
# and for shits and giggles
print(bce(res, yb))
F.binary_cross_entropy_with_logits (or nn.BCEWithLogitsLoss) does both the sigmoid and binary cross-entropy in a single function. If the final activations already have the sigmoid applied to them, then you'd use F.binary_cross_entropy (or nn.BCELoss) instead.
learn = cnn_learner(dls, resnet50, metrics=partial(accuracy_multi, thresh=0.2))
learn.fine_tune(3, base_lr=3e-3, freeze_epochs=4)
learn.metrics = partial(accuracy_multi, thresh=0.2)
print(learn.validate())
learn.metrics = partial(accuracy_multi, thresh=0.9)
print(learn.validate())
learn.metrics = partial(accuracy_multi, thresh=0.75)
print(learn.validate())
get_preds applies the output activation function (sigmoid, in this case) for us, so we'll need to tell accuracy_multi to not apply it.
preds, targs = learn.get_preds()
print(accuracy_multi(preds, targs, thresh=0.9, sigmoid=False))
thresholds = torch.linspace(0.05, 0.99, 29)
accs = [accuracy_multi(preds, targs, thresh=th, sigmoid=False) for th in thresholds]
plt.plot(thresholds, accs)
print(accuracy_multi(preds, targs, thresh=0.55, sigmoid=False))
See p.231 for more discussion on this.
Summary
You now know how to train both multi-label and multiclass vision models, when to use one or the other, and what loss function to choose for each. You should consider treating multiclass problems where the predicted class could be "None" as multi-label problems, especially if the model is going to be used in the real world and not just against some prefabbed dataset.
Also, we're using accuracy as the metric to optimize the threshold in the example above, but you can use any metric (or combination of metrics) you want. For example, a common issue with multi-label tasks is imbalanced datasets, where one or more targets are underrepresented. In that case, it may be more productive to use something like F1 or recall.
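As a sketch of what that might look like (the sigmoid outputs and targets below are made up, and fastai also ships ready-made multi-label metrics such as F1ScoreMulti), here is a plain-Python threshold sweep that optimizes micro-averaged F1 instead of accuracy:

```python
def f1_at_threshold(probs, targets, thresh):
    """Micro-averaged F1 over all labels at a given decision threshold."""
    tp = fp = fn = 0
    for p_row, t_row in zip(probs, targets):
        for p, t in zip(p_row, t_row):
            pred = 1 if p >= thresh else 0
            if pred and t:
                tp += 1
            elif pred and not t:
                fp += 1
            elif not pred and t:
                fn += 1
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

probs = [[0.9, 0.2, 0.6], [0.1, 0.8, 0.4]]  # hypothetical sigmoid outputs
targets = [[1, 0, 1], [0, 1, 1]]            # hypothetical multi-hot targets
best_f1, best_thresh = max(
    (f1_at_threshold(probs, targets, t / 20), t / 20) for t in range(1, 20)
)
print(best_f1, best_thresh)
```

The same sweep structure works for any metric you can compute from thresholded predictions, which is why tuning the threshold on validation predictions (as done with accuracy above) generalizes so easily.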
Resources
- https://book.fast.ai - The book's website; it's updated regularly with new content and recommendations on everything from which GPUs to use to how to run things locally and on the cloud, etc.
- To learn more about pandas check the pandas documentation and chrisalbon.com.