One of the most common problems in machine learning:
How do you deal with imbalanced datasets?
Not only does this happen frequently, but it's also a popular interview question.
Here are seven different techniques to deal with this problem:
What's an imbalanced dataset?
Imagine you have pictures of cats and dogs. Your dataset has 950 cat pictures and only 50 dog pictures.
That's an imbalanced dataset.
There's a significant difference in the number of samples for each class.
Why is an imbalanced dataset a problem?
A model that classifies every picture as a cat will be 95% accurate!
Think about this: A dumb model will get you to 95% accuracy because of your imbalanced classes.
That's a big problem, and here is how you solve it:
1. Accuracy is not a good metric for imbalanced problems.
Instead, look at any of the following metrics:
• A combination of Precision and Recall
• F-Score
• Confusion Matrix
• ROC Curves
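To see why accuracy misleads here, the sketch below scores the "every picture is a cat" model on the hypothetical 950/50 dataset. It uses plain Python so it runs without dependencies. The labels 0 = cat and 1 = dog are an assumption for illustration. In practice, scikit-learn's precision_score, recall_score, f1_score, and confusion_matrix compute the same quantities.

```python
# Labels (assumed for illustration): 1 = dog (minority class), 0 = cat (majority class).
y_true = [0] * 950 + [1] * 50
# A "dumb" model that predicts cat for every picture.
y_pred = [0] * 1000

def confusion_counts(y_true, y_pred, positive=1):
    """Count true/false positives and negatives for the given positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    tn = sum(t != positive and p != positive for t, p in pairs)
    return tp, fp, fn, tn

tp, fp, fn, tn = confusion_counts(y_true, y_pred)
accuracy = (tp + tn) / len(y_true)
# Guard against division by zero when the model never predicts the positive class.
precision = tp / (tp + fp) if (tp + fp) else 0.0
recall = tp / (tp + fn) if (tp + fn) else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0

print(f"accuracy={accuracy:.2f} precision={precision:.2f} recall={recall:.2f} f1={f1:.2f}")
print("confusion matrix [[tn, fp], [fn, tp]]:", [[tn, fp], [fn, tp]])
```

Accuracy comes out at 0.95 while recall and F1 for dogs are 0, which is exactly the failure that accuracy hides. Precision is undefined when the model never predicts a dog; this sketch reports it as 0 by convention.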
2. Collect more data.
If you can find more dog pictures, do that.
Sometimes this isn't possible, but when it is, it's often the simplest solution.
3. Augment the dataset with synthetic data.
If you can create realistic samples, use them to augment and balance the dataset.
For example, Tesla uses synthetic data to train its models on uncommon situations.
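For tabular features, a common way to generate synthetic minority samples is SMOTE, which interpolates between a minority sample and one of its nearest minority neighbors. This is a swapped-in technique, not what the Tesla example describes, and it does not produce realistic images. The sketch below is a minimal NumPy version of the idea; the function name smote_like and the random 2-D "dog" features are hypothetical. The imbalanced-learn library ships a full implementation as imblearn.over_sampling.SMOTE.

```python
import numpy as np

def smote_like(minority, n_new, k=5, seed=0):
    """Hypothetical helper: create n_new synthetic rows by interpolating between
    a minority sample and one of its k nearest minority neighbors (the core idea of SMOTE)."""
    rng = np.random.default_rng(seed)
    minority = np.asarray(minority, dtype=float)
    # Pairwise Euclidean distances between all minority samples.
    dists = np.linalg.norm(minority[:, None, :] - minority[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)  # a sample is not its own neighbor
    k = min(k, len(minority) - 1)
    neighbors = np.argsort(dists, axis=1)[:, :k]

    synthetic = np.empty((n_new, minority.shape[1]))
    for i in range(n_new):
        a = rng.integers(len(minority))    # pick a random minority sample
        b = neighbors[a, rng.integers(k)]  # pick one of its nearest neighbors
        gap = rng.random()                 # random position on the segment a -> b
        synthetic[i] = minority[a] + gap * (minority[b] - minority[a])
    return synthetic

# Hypothetical 2-D feature vectors standing in for the 50 dog samples.
dogs = np.random.default_rng(1).normal(size=(50, 2))
new_dogs = smote_like(dogs, n_new=150)
print(new_dogs.shape)  # (150, 2)
```

Because every synthetic point lies on a segment between two real dog samples, it stays inside the region the minority class already occupies rather than inventing entirely new behavior.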
4. Resample your dataset.
• Oversample the pictures of dogs.
• Undersample the pictures of cats.
You can also combine both.
Here is an example:
You can resample the hypothetical dataset by doing the following:
• Use every dog picture four times.
• Use every other cat picture.
New dataset:
• Dogs: 200 pictures (50 × 4)
• Cats: 475 pictures (950 ÷ 2)
Much better, huh?
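The resampling recipe above can be sketched with NumPy index arrays, where hypothetical integer indices stand in for the pictures. It reproduces the counts from the example. imbalanced-learn's RandomOverSampler and RandomUnderSampler are the library equivalents, although they sample at random rather than following this fixed "four times" and "every other" rule.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical indices standing in for the 950 cat and 50 dog pictures.
cat_idx = np.arange(950)
dog_idx = np.arange(950, 1000)

# Oversample: use every dog picture four times.
dogs_resampled = np.repeat(dog_idx, 4)
# Undersample: keep every other cat picture.
cats_resampled = cat_idx[::2]

# Combine and shuffle so the duplicated dogs aren't clustered together.
resampled = rng.permutation(np.concatenate([cats_resampled, dogs_resampled]))
print(len(dogs_resampled), len(cats_resampled), len(resampled))  # 200 475 675
```

Apply resampling only to the training split; evaluating on a resampled test set would no longer reflect the real class distribution.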
Important note:
Both oversampling and undersampling introduce biases into your dataset. You are changing the data distribution by arbitrarily duplicating or discarding existing samples.
Make sure you keep this in mind and think about the consequences.
Let's continue:
5. Weight each class differently.
There are multiple techniques to weight each class differently and have the model pay more or less attention to those samples.
For example, we can assign a larger weight to dogs to compensate for their lack of samples.
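One common weighting scheme, used by scikit-learn's class_weight="balanced" option, gives each class the weight n_samples / (n_classes × class_count). The sketch below computes those weights for the hypothetical dataset and applies them to a binary cross-entropy loss. weighted_log_loss is an illustrative helper, not a library function.

```python
import numpy as np

# Labels (assumed for illustration): 0 = cat, 1 = dog.
y = np.array([0] * 950 + [1] * 50)

# "Balanced" weighting: n_samples / (n_classes * count_of_class).
classes, counts = np.unique(y, return_counts=True)
weights = len(y) / (len(classes) * counts)
print(dict(zip(classes.tolist(), weights.round(3).tolist())))  # {0: 0.526, 1: 10.0}

def weighted_log_loss(y_true, p_dog, class_weights):
    """Hypothetical helper: binary cross-entropy where each sample is scaled by its class weight."""
    p_dog = np.clip(p_dog, 1e-12, 1 - 1e-12)  # avoid log(0)
    per_sample = -(y_true * np.log(p_dog) + (1 - y_true) * np.log(1 - p_dog))
    # Look up each sample's weight from its label and take the weighted mean.
    return np.average(per_sample, weights=class_weights[y_true])
```

With these weights, the 950 cats (about 0.526 each) and the 50 dogs (10 each) both contribute a total weight of 500, so each class influences the loss equally. In Keras, a dictionary of per-class weights can be passed as the class_weight argument of Model.fit.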
6. Different algorithms handle imbalances differently.
Decision trees often handle imbalanced classes well. Neural networks, not so much.
Make sure you pick an algorithm that suits your problem.
7. Make sure you approach the problem correctly.
Many people have tried to solve anomaly detection problems using multi-class classification.
That's the wrong approach.
Understand what problem you are trying to solve before deciding how to do it.
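One way to frame anomaly detection is to model normal behavior only and flag whatever deviates from it, instead of training a classifier on a handful of labeled anomalies. The sketch below uses a simple z-score rule on hypothetical 1-D sensor readings; the 4-standard-deviation threshold is an arbitrary assumption. scikit-learn's IsolationForest and OneClassSVM are more capable options for this framing.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical 1-D sensor readings: plenty of normal data, anomalies are rare.
normal = rng.normal(loc=20.0, scale=1.0, size=1000)

# Fit only on normal behavior: estimate its mean and standard deviation.
mu, sigma = normal.mean(), normal.std()

def is_anomaly(x, threshold=4.0):
    """Flag readings more than `threshold` standard deviations from the normal mean."""
    return np.abs((np.asarray(x) - mu) / sigma) > threshold

print(is_anomaly([20.3, 19.1, 35.0]).tolist())  # [False, False, True]
```

No anomaly labels are needed to fit this detector, which is why the framing works even when anomalies are too rare to learn as a class.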
Let's recap how you can handle an imbalanced dataset:
1. Pick the appropriate performance metric
2. Collect more data
3. Generate synthetic data
4. Resample the dataset
5. Use different weights
6. Try different algorithms
7. Approach the problem correctly
Every week, I post 1 or 2 threads like this, breaking down machine learning concepts and giving you ideas on applying them in real-life situations.
Follow me @svpino and make sure you don't miss my next thread.