在 pytorch 中,BCEWithLogitsLoss
可以用来计算多标签多分类问题的损失函数。定义如下:
loss_fn = torch.nn.BCEWithLogitsLoss(weight=None, reduction='mean', pos_weight=None)
loss = loss_fn(logits, y_true)
需要注意的是,由于该损失函数包含了 Cross Entropy 以及 Sigmoid 函数的计算过程,只需要将神经网络生成的 logits 值(未经 sigmoid 计算)及真实标签输入即可。logits 和 y_true 的维度应保持一致。
损失函数的参数包含两个权重,分别是weight
和pos_weight
,Tensor 维度均为 [num_classes],可以用来缓解训练样本不均衡带来的预测标签不准确的问题。
weight
表示不同类别标签的权重,用来解决不同类别标签不均衡的问题。
pos_weights
表示每个类别正样本的权重,用来解决标签内部正负样本不均衡的问题。
num_classes = 5
label_counts = [1, 2, 0, 4, 5]
total_samples = 20
total_labels = sum(label_counts)
class_weight = [0 for _ in range(num_classes)]
pos_weight = [0 for _ in range(num_classes)]
for label_idx, count in enumerate(label_counts):
if count != 0:
class_weight[label_idx] = 1 - count / total_labels
pos_weight[label_idx] = total_samples / count - 1
print('class_weight:', class_weight)
print('pos_weight:', pos_weight)
输出:
class_weight: [0.9166666666666666, 0.8333333333333334, 0, 0.6666666666666667, 0.5833333333333333]
pos_weight: [19.0, 9.0, 0, 4.0, 3.0]