Learning when to skim and when to read
The rise of Machine Learning, Deep Learning, and Artificial Intelligence more generally has been undeniable, and it has already had a massive impact on the field of computer science. By now, you might have heard how deep learning has achieved super-human performance on a number of tasks, ranging from image recognition to the game of Go.
The deep learning community is now eyeing natural language processing (NLP) as the next frontier of research and application.
One beauty of deep learning is that advances tend to be very generic: techniques that make deep learning work in one domain can often be transferred to other domains with little to no modification. More specifically, the approach of building massive, computationally expensive deep learning models for image and speech recognition has spilled into NLP. One can see this in the most recent state-of-the-art translation system, which outperformed all previous results but required an exorbitant amount of computation. Such demanding systems can capture the very complex patterns occasionally found in real-world data, but this has led many to apply these massive models to all tasks. This raises the question:
Do all tasks always have the complexity that requires such models?
Let's look at the innards of a two-layer MLP trained on bag-of-words embeddings for sentiment analysis.

Deep learning with text
Most deep learning methods require floating point numbers as input and, unless you have been working with text before, you might wonder:
How do I go from a piece of text to deep learning?
A core issue with text is how to represent an arbitrarily large amount of information, given the length of the material. A popular solution has been tokenizing text into words, sub-words, or even characters. Each word is then transformed into a floating point vector using well-studied methods such as word2vec or GloVe, which provide meaningful representations of a word through the implicit relationships between different words.

By using tokenization and methods such as word2vec, we can turn a piece of text into a sequence of floating point representations of its words.
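As a rough illustration (the tokenizer, vocabulary, and embedding values below are placeholders rather than the exact pipeline used here), turning a sentence into a sequence of word vectors might look like this:

```python
import numpy as np

# Hypothetical pretrained embeddings, e.g. loaded from a GloVe file:
# each word maps to a fixed-size float vector.
embeddings = {
    "i": np.random.randn(50).astype(np.float32),
    "loved": np.random.randn(50).astype(np.float32),
    "this": np.random.randn(50).astype(np.float32),
    "movie": np.random.randn(50).astype(np.float32),
}
unk = np.zeros(50, dtype=np.float32)  # fallback for out-of-vocabulary words

def text_to_vectors(text):
    # naive whitespace tokenization; real pipelines use proper tokenizers
    tokens = text.lower().split()
    return np.stack([embeddings.get(tok, unk) for tok in tokens])

seq = text_to_vectors("I loved this movie")
print(seq.shape)  # (4, 50): one 50-dimensional vector per token
```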
Now, what can we use a sequence of word representations for?
Bag-of-words
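A minimal sketch of a bag-of-words classifier in PyTorch, assuming the bag-of-words is simply the mean of the word embeddings fed into a two-layer MLP (the layer sizes and names are illustrative, not our exact architecture):

```python
import torch
import torch.nn as nn

class BoWClassifier(nn.Module):
    """Order-agnostic model: average the word embeddings, then apply a two-layer MLP."""
    def __init__(self, vocab_size, embed_dim=300, hidden_dim=100, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, token_ids):            # token_ids: (batch, seq_len)
        vectors = self.embed(token_ids)      # (batch, seq_len, embed_dim)
        bow = vectors.mean(dim=1)            # word order is discarded here
        return self.mlp(bow)                 # (batch, num_classes) logits
```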

Recurrent Neural Networks
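And a minimal sketch of the "reading" counterpart: an LSTM that consumes the sentence left to right, keeping a memory at each time step, and classifies from its final hidden state (again, sizes and names are illustrative):

```python
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    """Sequential model: read tokens left to right, keeping a memory at each step."""
    def __init__(self, vocab_size, embed_dim=300, hidden_dim=300, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, num_classes)

    def forward(self, token_ids):                # token_ids: (batch, seq_len)
        vectors = self.embed(token_ids)          # (batch, seq_len, embed_dim)
        _, (h_n, _) = self.lstm(vectors)         # h_n: (1, batch, hidden_dim)
        return self.out(h_n[-1])                 # logits from the last hidden state
```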


Example: Sentiment Analysis
Sentiment analysis is a type of document classification for quantifying polarity in subjective passages. Given a sentence, the model evaluates whether it is positive, negative or neutral.
Want to find livid customers on twitter before they start trending? Well, Sentiment Analysis might be just what you’re looking for!
A great public dataset for this purpose (which we will use) is the Stanford Sentiment Treebank (SST). We have provided a publicly available data loader in PyTorch. The SST provides not only the class (positive, negative) for a sentence, but also for each of its grammatical subphrases; our systems, however, do not use any of the tree information. The original SST has five classes: very positive, positive, neutral, negative and very negative. We consider the simpler task of binary classification, where very positive is combined with positive, very negative is combined with negative, and all neutrals are removed.
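Concretely, collapsing the five SST classes into the binary task can be sketched as follows (the label names follow SST; the actual data loader is omitted):

```python
# Collapse SST's five classes into the binary task:
# {very negative, negative} -> 0, {very positive, positive} -> 1, drop neutral.
LABEL_MAP = {
    "very negative": 0, "negative": 0,
    "positive": 1, "very positive": 1,
    # "neutral" is intentionally absent: those sentences are removed
}

def binarize(examples):
    """examples: iterable of (sentence, fine_grained_label) pairs."""
    for sentence, label in examples:
        if label in LABEL_MAP:
            yield sentence, LABEL_MAP[label]
```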
We have provided a brief and technical description of our model architecture. The important point is not exactly how it is structured, but the fact that the cheap model reaches 82% validation accuracy and takes 10 ms for a batch of size 64, while the expensive LSTM achieves a significantly higher 88% validation accuracy but costs 87 ms for a batch of size 64 (top models are in the 88-90% accuracy ballpark).
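The timings above are for inference on batches of 64 sentences; a simple way to take such a measurement (the model and batch here are placeholders, not our exact benchmark setup) is:

```python
import time
import torch

def time_batch(model, token_ids, n_runs=100):
    """Average wall-clock inference time per batch, in milliseconds."""
    model.eval()
    with torch.no_grad():
        model(token_ids)                      # warm-up run
        start = time.perf_counter()
        for _ in range(n_runs):
            model(token_ids)
        return (time.perf_counter() - start) / n_runs * 1e3

# Hypothetical usage with the sketches above and a batch of 64 padded sentences:
# batch = torch.randint(0, 10000, (64, 20))
# print(time_batch(BoWClassifier(10000), batch), "ms")
# print(time_batch(LSTMClassifier(10000), batch), "ms")
```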


The cheap skim reader
On some tasks, algorithms can perform at near human-level accuracy, but obtaining this performance might burn a hole in the server budget. We also know that an LSTM powerhouse is not always necessary; on real-world data, we might be just fine with the cheaper bag-of-words. But what happens when we get a sentence like this:
"Horrible cast, complete lack of reality, …, but I loved it 9/10”
The order-agnostic bag-of-words will surely misclassify it, given the overwhelming number of negative words. Switching completely to a crummy bag-of-words would drop our overall performance, which doesn't sound that compelling. So the question becomes:
Can we learn to separate ‘easy’ and ‘hard’ sentences?
And can we do so with a cheap model to save time?
Exploring the innards


t-SNE plots are vulnerable to over-interpretation, but a few trends might strike you.
Interpretations of t-SNE
- The sentences fall into clusters, and the clusters constitute different semantic types.
- Some clusters lie along a simple manifold with high confidence and accuracy.
- Other clusters are more scattered with low accuracy and low confidence.
- Sentences with positive and negative constituents are difficult.
Let's now look at a similar plot for the LSTM.


Many of these observations hold true for the LSTM as well. However, the LSTM has relatively few examples with low confidence, and the co-occurrence of positive and negative constituents in a sentence does not look to be as challenging for the LSTM as it is for the bag-of-words.
It seems the bag-of-words has been able to cluster sentences and use its probability to identify whether or not it is likely to give a correct prediction for the sentences in a given cluster. From these observations, a reasonable hypothesis could be:
Confident answers are more correct.
To investigate this hypothesis, we can look at probability thresholds.
Probability thresholding
The bag-of-words and LSTM are trained to give us probabilities for each class, which we can use as a measure of certainty. What do we mean by this? If the bag-of-words returns a 1, it is very confident in its prediction.
Often when predicting, we would take the class with the highest likelihood provided by our model. In the case of binary classification (e.g. positive or negative), the likelihood of the predicted class has to be over 0.5 (or else we would be predicting the opposite class!). But a low likelihood for the predicted class might indicate that the model was in doubt: if the model predicts 0.49 for negative and 0.51 for positive, it is not very convincing that the sentence actually is positive.
When we say that we threshold, we mean that we compare the predicted probability to a value and decide whether or not to use the prediction. For example, we could decide to use all sentences with a probability above 0.7, or we could look at the interval 0.5-0.55 to see how accurate predictions with this confidence are, which is exactly what we investigate in the next plot.
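In code, binning predictions by confidence and measuring accuracy per bin could look like the sketch below, where probs is assumed to be the model's predicted probability of the positive class:

```python
import numpy as np

def accuracy_per_confidence_bin(probs, labels, bin_edges=np.arange(0.5, 1.01, 0.05)):
    """Group predictions by the confidence of the predicted class and
    report the accuracy and amount of data in each bin."""
    probs, labels = np.asarray(probs), np.asarray(labels)
    preds = (probs > 0.5).astype(int)
    confidence = np.maximum(probs, 1 - probs)   # probability of the predicted class
    for i, (lo, hi) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        # include the upper edge in the last bin so confidence == 1.0 is counted
        upper = confidence <= hi if i == len(bin_edges) - 2 else confidence < hi
        mask = (confidence >= lo) & upper
        if mask.sum() > 0:
            acc = (preds[mask] == labels[mask]).mean()
            print(f"[{lo:.2f}, {hi:.2f}): n={mask.sum():4d}  acc={acc:.3f}")
```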


In the data amount plot, the height of each bar corresponds to the amount of data residing between two thresholds, and the line shows the accumulated data across the threshold bins.
From the bag-of-words plots it might occur to you that increasing the probability threshold increases the performance. From the LSTM plot it is not so obvious, which is common, as the LSTM overfits the training set and only provides confident answers.
Use the BoW for easy examples, and the pristine LSTM for difficult ones.
Thus, simply using the output probability could give us an indication of when a sentence is easy and when it is in need of guidance from a stronger system, like the powerful LSTM.
Using the probability threshold, we create a strategy which we refer to as the "probability strategy": we threshold the probability of the bag-of-words system and use the LSTM on all data points not reaching the threshold. This gives us the amount of data handled by the bag-of-words (sentences above the threshold), and a full set of predictions where each sentence goes to either the BoW (above the threshold) or the LSTM (below the threshold), from which we can compute an accuracy and a cost of computing. As the threshold varies, the ratio of BoW to LSTM usage increases from 0.0 (only using the LSTM) to 1.0 (only using the BoW), which lets us trace out accuracy against time to compute.
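A sketch of the probability strategy, using the per-batch timings above as rough unit costs (the array names are illustrative): run the BoW on everything and only fall back to the LSTM where the BoW's confidence misses the threshold.

```python
import numpy as np

def probability_strategy(bow_probs, lstm_preds, labels, threshold,
                         bow_ms=10.0, lstm_ms=87.0):
    """Use the BoW prediction when its confidence clears the threshold,
    otherwise defer to the LSTM. Returns accuracy, a rough cost estimate,
    and the fraction of sentences handled by the BoW."""
    bow_probs = np.asarray(bow_probs)
    confidence = np.maximum(bow_probs, 1 - bow_probs)
    use_bow = confidence >= threshold

    preds = np.where(use_bow, (bow_probs > 0.5).astype(int), lstm_preds)
    accuracy = (preds == np.asarray(labels)).mean()

    # Everything pays the BoW cost; only deferred sentences also pay the LSTM cost.
    cost = bow_ms + (1 - use_bow.mean()) * lstm_ms
    return accuracy, cost, use_bow.mean()
```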
Baseline


The interesting discovery is that using the bag-of-words threshold significantly outperforms not having a guided strategy.
We then measure the average value on the curve, which we refer to as the Speed Under the Curve (SUC), as shown in the table below.
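As a sketch, the SUC can be approximated by averaging the accuracy of the probability strategy (reusing the probability_strategy sketch above) over a sweep of thresholds, a discrete stand-in for the area under the accuracy-versus-speed curve:

```python
import numpy as np

def speed_under_the_curve(bow_probs, lstm_preds, labels,
                          thresholds=np.linspace(0.5, 1.0, 21)):
    """Average accuracy of the probability strategy across threshold settings,
    a discrete approximation of the area under the accuracy/speed curve."""
    accuracies = [probability_strategy(bow_probs, lstm_preds, labels, t)[0]
                  for t in thresholds]
    return float(np.mean(accuracies))
```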

Learning when to skim and when to read
Knowing when to switch between two different models is not enough. We want to build a more general system that learns when to switch between the models. Such a system would help us answer the more complicated question:
Can we learn when reading is strictly better than skimming in a supervised way?
Where "reading" is using the LSTM which goes from left to right and stores a memory at each time step and "skimming" is using the BoW model. When operating on the probability from the bag-of-words model we make our decision based on the invariant that the more powerful LSTM will do a better job when the bag-of-word system is in doubt, but is that always the case?



From the comparison plot, we find that it is easy to tell when the BoW is correct and when it is in doubt. However, there is no clear relationship between when the LSTM might be right or wrong.
Can we learn this relationship?
Further, the probability strategy is quite restrictive, as it relies on an inherent binary decision and requires probabilities. Instead, we propose a trainable decision network, itself a neural network. If we look at the confusion matrix, we can use that information to generate labels for a supervised decision network: we would then only use the LSTM in the cases where the LSTM is correct and the BoW is wrong.
To generate the dataset, we need a set of sentences together with the true, underlying predictions of our bag-of-words and the LSTM. However, during training the LSTM will often reach upwards of 99% training accuracy, significantly overfitting the training set. To avoid this, we split our training set into a model training set (80% of the training data) and a decision training set (the remaining 20%) that the models have not yet seen. Afterwards, we fine-tune our models on the remaining 20%, hoping that the decision network will still generalize to this new, unseen, but very related and slightly better system.
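A sketch of how the decision labels can be generated from the two models' predictions on the held-out decision set; class 1 ("read" with the LSTM) is used only where the LSTM is correct and the BoW is wrong:

```python
import numpy as np

def decision_labels(bow_preds, lstm_preds, true_labels):
    """Label 1 ("read" with the LSTM) only where the LSTM is right and the
    BoW is wrong; everything else is labeled 0 ("skim" with the BoW)."""
    bow_preds, lstm_preds, true_labels = map(np.asarray,
                                             (bow_preds, lstm_preds, true_labels))
    lstm_correct = lstm_preds == true_labels
    bow_wrong = bow_preds != true_labels
    return (lstm_correct & bow_wrong).astype(int)
```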


The classes chosen on the validation set by the decision network, based on the models trained on the model training set, are then applied to the full, but very related, models trained on the full training set. The reason we apply it to the models trained on the full training set is that the models trained on only the model training set will often be inferior and thus result in lower accuracy. The decision network is trained with early stopping, based on maximizing the SUC on the validation set.
How does our decision network perform?
Let us start by looking at the predictions of the decision network.




Discussion
We now know that large, powerful LSTMs can achieve near human-level performance on text, that not all real-world data needs near human-level performance, that we can train a bag-of-words model to understand when a sentence is easy, and that using the bag-of-words for easy sentences allows us to save a significant amount of computation time with only a minor drop in performance (depending on how aggressively we threshold the bag-of-words).
This approach is related to the mean averaging usually performed in model ensembling, where the model with high confidence will often dominate. However, by having an adjustable confidence threshold for the bag-of-words, and by not needing to run the LSTM at all for easy sentences, we can decide how much computation time vs. accuracy we are willing to trade. We believe this method will be useful for deep learning engineers looking to save computational resources without having to sacrifice performance.
Citation credit
Learning when to skim and when to read (arXiv paper coming soon)