What loss function to use for imbalanced classes (using PyTorch)?

29

9

I have a dataset with 3 classes with the following items:

  • Class 1: 900 elements
  • Class 2: 15000 elements
  • Class 3: 800 elements

I need to predict class 1 and class 3, which signal important deviations from the norm. Class 2 is the default “normal” case which I don’t care about.

What kind of loss function would I use here? I was thinking of using CrossEntropyLoss, but since there is a class imbalance, this would need to be weighted I suppose? How does that work in practice? Like this (using PyTorch)?

import torch
import torch.nn as nn

summed = 900 + 15000 + 800
weight = torch.tensor([900.0, 15000.0, 800.0]) / summed
crit = nn.CrossEntropyLoss(weight=weight)

Or should the weight be inverted? i.e. 1 / weight?

Is this the right approach to begin with or are there other / better methods I could use?

Thanks

Muppet

Posted 2019-04-01T19:00:04.877

Reputation: 545

When you say: "You can also use the smallest class as the numerator, which gives 0.889, 0.053, and 1.0 respectively. This is only a re-scaling; the relative weights are the same." But this seems to contradict the first rule you gave, so how does it work? – Georges Matar – 2019-10-05T15:41:40.180

Answers

19

What kind of loss function would I use here?

Cross-entropy is the go-to loss function for classification tasks, whether balanced or imbalanced. It is the default choice when there is no domain knowledge suggesting a different preference.

This would need to be weighted I suppose? How does that work in practice?

Yes. The weight of class $c$ is the size of the largest class divided by the size of class $c$, i.e. $w_c = n_{\text{largest}} / n_c$, where $n_c$ is the number of samples in class $c$.

For example, if class 1 has 900, class 2 has 15000, and class 3 has 800 samples, then their weights would be 16.67, 1.0, and 18.75 respectively.

You can also use the smallest class as the numerator, which gives 0.889, 0.053, and 1.0 respectively. This is only a re-scaling; the relative weights are the same.
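As a minimal sketch (using the class counts from the question; the variable names are my own), the weights could be computed and passed to the loss like this:

import torch
import torch.nn as nn

class_counts = torch.tensor([900.0, 15000.0, 800.0])

# Weight of class c = size of largest class / size of class c
weights = class_counts.max() / class_counts    # tensor([16.67,  1.00, 18.75])

# Equivalent up to a constant factor: smallest class as numerator
# weights = class_counts.min() / class_counts  # tensor([0.889, 0.053, 1.000])

criterion = nn.CrossEntropyLoss(weight=weights)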

Is this the right approach to begin with or are there other / better methods I could use?

Yes, this is the right approach.

EDIT:

Thanks to @Muppet, we can also use class over-sampling, which is equivalent to using class weights. This is accomplished by WeightedRandomSampler in PyTorch, giving each sample the weight of its class as computed above.
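A minimal sketch of the over-sampling variant, assuming labels is a 1-D tensor of class indices for the training set and train_dataset is the corresponding Dataset (both names are placeholders, not from the original post):

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# labels: 1-D LongTensor of class indices (0, 1, 2), one entry per training sample (assumed)
class_counts = torch.bincount(labels)
class_weights = 1.0 / class_counts.float()

# WeightedRandomSampler expects one weight per sample, not per class
sample_weights = class_weights[labels]

sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)
loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)  # train_dataset assumed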

Esmailian

Posted 2019-04-01T19:00:04.877

Reputation: 7 434

I just wanted to add that using WeightedRandomSampler from PyTorch also helped, in case someone else is looking at this. – Muppet – 2019-04-02T17:40:59.770

When the labels are imbalanced, say 11 labels where one takes 17% and the others take 6-9% each, cross-entropy learns slowly: in the early stages, the loss focuses on learning the label with the largest proportion. – GoingMyWay – 2020-06-04T15:33:14.537