Can you teach a convolutional neural network new tricks?
If you’re familiar with our platform, Boon AI, then you know that it was built to help media technologists accelerate machine learning (ML) integrations by taking advantage of pre-trained models. The idea is that there is no need to build a solution from the ground up if the problem you’re solving has already been addressed by solutions on the market.
In the same spirit, Boon enables customers to accelerate ML implementation by taking what a convolutional neural network learned while being trained for a different purpose and applying it to their own business use case. This process is known as transfer learning, and it has been touted as a driver of ML’s commercial success. Let’s take a look at how it works.
Convolutional Neural Networks
Convolutional Neural Networks, or CNNs, are very good at classifying images. Because of the way these networks are trained, they are also good at being repurposed for different classification tasks and at determining when images are semantically similar. In this post we will build some intuition for how all this works, and see how these ideas can be used to work with images in real life. We will do this with little or no math, relying only on a few simple concepts from geometry.
Image classification with CNNs can be very useful in a production environment. For example, a video editor might want to search a large library of footage for shots that are interior or exterior, or that show mountains, lakes, or the beach. A trained CNN can find these images (and, just as well, frames within video) without a person having to sift through the repository.
One problem is that a CNN takes a very long time to train from scratch, and training requires extremely large sets of pre-labeled images: think a million images and a week or two of processing. This is of course not very practical, and in most cases not even possible. Pre-trained CNNs exist and will generate relevant keywords for the images you show them. But these networks are trained to recognize a fixed set of categories, typically a broad, general-purpose group of concepts, so using a pre-trained CNN out of the box is not practical in production either. If you are looking for mountains or lakes, keywords like “cat” and “dog” in your results would just be noise.
Luckily for us, it is possible to take a pre-trained CNN and “teach” it new concepts, and that process doesn’t take too long. In the next few paragraphs, we will learn how and why that works.
The “Convolutional” in Convolutional Neural Network means that many of the layers inside the network perform convolutions on their input (the image itself for the first layer, and the output of the previous layer after that). Think of a convolution as an image filter that enhances some feature of the image, like vertical or horizontal lines, edges, or certain patterns. Training a CNN produces the weights for each layer, which determine what that layer does to its inputs.
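To make the idea of a convolution as an image filter concrete, here is a small sketch in Python with NumPy. It slides a 3x3 kernel over a tiny synthetic image that contains a single vertical edge. The kernel values are a hand-picked, Sobel-style edge detector chosen for illustration; in a real CNN, values like these are exactly the weights that training produces. The function name convolve2d and the sample image are our own and not part of any particular library.

```python
import numpy as np


def convolve2d(image: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """'Valid' 2-D convolution (no padding, stride 1).

    Like most deep-learning frameworks, this slides the kernel over the image
    without flipping it, which is strictly speaking a cross-correlation.
    """
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Weighted sum of the image patch currently under the kernel.
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


# Tiny 5x5 grayscale image: dark on the left, bright on the right,
# so it contains a single vertical edge.
image = np.array([[0, 0, 1, 1, 1]] * 5, dtype=float)

# Hand-picked Sobel-style kernels; a trained CNN learns its kernel values instead.
vertical_edge = np.array([[-1, 0, 1],
                          [-2, 0, 2],
                          [-1, 0, 1]], dtype=float)
horizontal_edge = vertical_edge.T

v_response = convolve2d(image, vertical_edge)
h_response = convolve2d(image, horizontal_edge)
print(v_response)   # 4.0 in the two columns where the kernel overlaps the edge, 0.0 elsewhere
print(h_response)   # all zeros: the image has no horizontal edges
```

The vertical-edge kernel responds only where it overlaps the edge, while the horizontal-edge kernel stays silent on this image. A CNN layer applies many such kernels at once, each producing its own filtered output.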
Successive layers of convolutions detect progressively higher-level features. The first layer might find lines or simple patterns; subsequent layers find combinations of the patterns already found. Near the end of the network, the results (or “activations”) are encoded into a high-dimensional vector, typically with 1024, 2048, or more dimensions. This layer is in turn connected to an output layer that has one dimension per class the network is trained to recognize. The class predicted by the CNN is the one corresponding to the position with the highest value in this output layer.
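The last step, from the embedding to a predicted class, can be sketched as a single fully connected layer followed by an argmax. This is a minimal illustration assuming a 2048-D embedding and three hypothetical class names; the weights are random stand-ins for trained values, so the predicted label itself is meaningless here.

```python
import numpy as np

rng = np.random.default_rng(0)

EMBED_DIM = 2048                  # size of the second-to-last layer (the embedding)
CLASSES = ["cat", "dog", "car"]   # hypothetical class names, for illustration only

# Random stand-ins for the trained weights and biases of the output layer:
# one row of weights (and one bias) per class.
W = rng.normal(scale=0.01, size=(len(CLASSES), EMBED_DIM))
b = np.zeros(len(CLASSES))


def classify(embedding: np.ndarray) -> tuple[str, np.ndarray]:
    logits = W @ embedding + b            # one score per class
    # Softmax turns the scores into probabilities without changing which one is largest.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return CLASSES[int(np.argmax(logits))], probs


# Random stand-in for the embedding of one image (post-ReLU activations are >= 0).
embedding = rng.random(EMBED_DIM)
label, probs = classify(embedding)
print(label, probs.round(3))
```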
This second-to-last layer in a CNN is extremely useful, and in the rest of the post we are going to see how. We will refer to it as an “embedding” or a “feature vector,” or, when it is quantized into a form that we can write into our database and use for fast comparisons, a “similarity hash.” All three terms refer to the same thing.
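The post does not describe how the embedding is quantized into a similarity hash, so the scheme below is purely hypothetical and is not Boon AI’s actual format. Each dimension is clipped to an assumed range (0 to MAX_ACTIVATION) and mapped to one of 16 levels, giving one hex character per dimension. It only illustrates the general idea of turning a float vector into a compact string that can be stored in a database and compared later; all names here are our own.

```python
import numpy as np

# Hypothetical quantization scheme, for illustration only:
# clip each dimension to an assumed range and map it to one of 16 levels,
# so every dimension becomes a single hex character.
LEVELS = 16
MAX_ACTIVATION = 4.0   # assumed upper bound for the (non-negative) activations


def to_similarity_hash(embedding: np.ndarray) -> str:
    clipped = np.clip(embedding, 0.0, MAX_ACTIVATION)
    quantized = np.round(clipped / MAX_ACTIVATION * (LEVELS - 1)).astype(int)
    return "".join(format(int(v), "x") for v in quantized)


def from_similarity_hash(h: str) -> np.ndarray:
    return np.array([int(c, 16) for c in h], dtype=float)


def hash_distance(h1: str, h2: str) -> float:
    # Euclidean distance between the decoded (quantized) vectors.
    return float(np.linalg.norm(from_similarity_hash(h1) - from_similarity_hash(h2)))


h = to_similarity_hash(np.array([0.0, 1.0, 4.0, 10.0, -2.0]))
print(h)   # 04ff0
```

Coarser quantization gives shorter hashes and faster comparisons at the cost of some precision; a real system would tune that trade-off.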
Transfer learning means teaching a previously trained CNN new concepts. As we saw, the early layers of a CNN detect general patterns. All it takes to teach a CNN a new set of concepts is to leave all of its original weights untouched up to the second-to-last layer, and train only the weights connecting that layer to a new output layer.
“Leaving the early layers untouched” is equivalent to using the embedding, the vector we described above, as the input to the new network we are training. Now we can explore why this works, and in doing so we arrive at the concept of semantic image similarity.
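Here is a minimal sketch of transfer learning under that view: the CNN’s weights stay frozen, so each image is represented only by its embedding, and we train just a new output layer (softmax regression) on those embeddings. To keep it self-contained, the embeddings are synthetic 64-D stand-ins for two hypothetical new concepts, “lake” and “mountain”; in practice they would come from running the images through the pre-trained network. The function names train_new_head and predict are our own.

```python
import numpy as np


def train_new_head(X: np.ndarray, y: np.ndarray, num_classes: int,
                   lr: float = 0.1, epochs: int = 200) -> tuple[np.ndarray, np.ndarray]:
    """Train only a new output layer (softmax regression) on frozen embeddings.

    X: (n_images, embed_dim) embeddings produced by the frozen CNN.
    y: (n_images,) integer labels for the new concepts.
    """
    n, d = X.shape
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    Y = np.eye(num_classes)[y]                       # one-hot targets
    for _ in range(epochs):
        logits = X @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        P = np.exp(logits)
        P /= P.sum(axis=1, keepdims=True)            # softmax probabilities
        grad = (P - Y) / n                           # d(mean cross-entropy)/d(logits)
        W -= lr * (X.T @ grad)                       # gradient descent step
        b -= lr * grad.sum(axis=0)
    return W, b


def predict(X: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    return np.argmax(X @ W + b, axis=1)


# Synthetic stand-ins for embeddings of two new concepts (0 = "lake", 1 = "mountain"),
# scattered around two different points in a 64-D space.
rng = np.random.default_rng(42)
centers = rng.normal(size=(2, 64))
y = np.repeat([0, 1], 50)
X = centers[y] + 0.3 * rng.normal(size=(100, 64))

W, b = train_new_head(X, y, num_classes=2)
accuracy = (predict(X, W, b) == y).mean()
print(f"training accuracy: {accuracy:.2f}")
```

Because only this one small layer is trained, the process is far cheaper than training a full CNN from scratch. In a deep-learning framework such as PyTorch, the same idea is usually expressed by freezing the backbone’s parameters (setting requires_grad to False) and replacing its final layer.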
We can think of that 2048-dimensional vector near the end of our CNN as a point in space. This is, of course, a 2048-D space, so trying to visualize it directly is not recommended. But it is useful to think about it in three, or even two, dimensions, and in fact dimensionality-reduction techniques exist that “fold” all those dimensions into two or three so that we can visualize the points.
Now let’s imagine we have many images, and for each one we compute this 2048-D vector. The resulting set of points will be distributed in space in some way, and it turns out that this distribution is extremely meaningful. For a CNN trained for image classification, points that are near each other in space tend to represent images that are similar to each other. If we plot those points (again, after reducing them to three dimensions, for example), we can often see clusters emerging, where each group of points represents images that may belong to some consistent category.
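Principal component analysis (PCA) is one of the simplest ways to fold many dimensions into two or three; nonlinear techniques such as t-SNE or UMAP are also popular for this kind of plot. The sketch below applies PCA, computed with NumPy’s SVD, to synthetic stand-ins for 2048-D embeddings drawn around three hypothetical group centers, then checks that the groups still form separate clusters after projection to 2-D. The helper name pca_project is our own.

```python
import numpy as np


def pca_project(X: np.ndarray, n_components: int = 2) -> np.ndarray:
    """Project the rows of X onto their top principal components (PCA via SVD)."""
    Xc = X - X.mean(axis=0)                       # center each dimension
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:n_components].T                # shape: (n_points, n_components)


# Synthetic stand-ins for the 2048-D embeddings of three groups of similar images.
rng = np.random.default_rng(7)
centers = rng.normal(size=(3, 2048))
labels = np.repeat([0, 1, 2], 30)
X = centers[labels] + 0.5 * rng.normal(size=(90, 2048))

points_2d = pca_project(X, n_components=2)

# In 2-D, each group should still form its own tight cluster.
centroids = np.array([points_2d[labels == k].mean(axis=0) for k in range(3)])
spread = max(np.linalg.norm(points_2d[labels == k] - centroids[k], axis=1).mean()
             for k in range(3))
gap = min(np.linalg.norm(centroids[i] - centroids[j])
          for i in range(3) for j in range(i + 1, 3))
print(f"largest within-cluster spread: {spread:.2f}, smallest gap between clusters: {gap:.2f}")
```

The resulting points_2d can be drawn as a scatter plot with any plotting library, colored by group, to see the clusters.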
Now imagine we saved those vectors along with each image in our database, and that we had a very fast way to find, for any given vector, the closest ones in the whole set. If we then sorted all the images by how close their vectors are to that of our original image, we would get them sorted by similarity.
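A brute-force version of that search is easy to sketch: normalize the vectors, compute the cosine similarity between the query and every stored vector, and sort. The toy 3-D vectors below are made-up stand-ins for real 2048-D embeddings, and rank_by_similarity is a hypothetical helper; a production system would typically use a dedicated nearest-neighbor index to make the search fast at scale rather than scanning every vector.

```python
import numpy as np


def rank_by_similarity(query: np.ndarray, embeddings: np.ndarray):
    """Return indices of `embeddings` sorted from most to least similar to `query`,
    plus the cosine similarities themselves (brute-force scan)."""
    q = query / np.linalg.norm(query)
    E = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    similarity = E @ q                      # cosine similarity to the query
    return np.argsort(-similarity), similarity


# Toy 3-D "embeddings" standing in for the 2048-D vectors of a photo library.
library = np.array([
    [0.9, 0.1, 0.0],   # 0: dog photo
    [0.0, 1.0, 0.1],   # 1: beach photo
    [0.8, 0.2, 0.1],   # 2: another dog photo
    [0.1, 0.0, 1.0],   # 3: mountain photo
])
order, sim = rank_by_similarity(library[0], library)
print(order)   # [0 2 1 3]: the query itself first, then the other dog photo
```

For unit-length vectors, ranking by cosine similarity gives the same order as ranking by Euclidean distance, since the squared distance between two unit vectors equals 2 minus twice their cosine similarity.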
Here I have selected one image of my dog from a dataset of thousands of my phone photos. The image I started the search with is the top-left one; all the others are sorted by proximity in that 2048-D vector space.
We can use those vectors as input for a new model, and in this way do transfer learning. Or we can use them directly, to explore the space of features found by the CNN and get things like semantic image similarity.