您当前的位置:首页 > IT编程 > TensorFlow
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 | Transformers |

自学教程:Tensorflow2.0之绘图:分类问题的准确率、精确率、召回率、ROC曲线以及曲线下面积

51自学网 2023-10-26 22:12:28
  TensorFlow
这篇教程Tensorflow2.0之绘图:分类问题的准确率、精确率、召回率、ROC曲线以及曲线下面积写得很实用,希望能帮到您。
文章目录原文链接:https://blog.csdn.net/qq_36758914/article/details/104775509

    准确率、精确率、召回率、ROC曲线的定义
    用Tensorflow2.0绘制相关曲线
        建模时设置METRICS
        定义损失曲线、AUC曲线、精确率曲线以及召回率曲线函数
        定义ROC曲线函数
        预测训练集和测试集
        ROC曲线
        定义混淆矩阵函数
        绘制混淆矩阵

准确率、精确率、召回率、ROC曲线的定义

阳性与阴性
准确率
精确率与召回率
ROC 和曲线下面积
用Tensorflow2.0绘制相关曲线
建模时设置METRICS

METRICS = [
      keras.metrics.TruePositives(name='tp'),
      keras.metrics.FalsePositives(name='fp'),
      keras.metrics.TrueNegatives(name='tn'),
      keras.metrics.FalseNegatives(name='fn'),
      keras.metrics.BinaryAccuracy(name='accuracy'),
      keras.metrics.Precision(name='precision'),
      keras.metrics.Recall(name='recall'),
      keras.metrics.AUC(name='auc'),
]

model = keras.Sequential([
  keras.layers.Dense(16, activation='relu', input_shape=(train_features.shape[-1],)),
  keras.layers.Dropout(0.5),
  keras.layers.Dense(1, activation='sigmoid')
])

model.compile(
  optimizer=keras.optimizers.Adam(lr=1e-3),
  loss=keras.losses.BinaryCrossentropy(),
  metrics=METRICS)
 
history = model.fit(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(val_features, val_labels))

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28

定义损失曲线、AUC曲线、精确率曲线以及召回率曲线函数

def plot_metrics(history):
    metrics =  ['loss', 'auc', 'precision', 'recall']
    for n, metric in enumerate(metrics):
        name = metric
        plt.subplot(2,2,n+1)
        plt.plot(history.epoch,  history.history[metric], color=colors[0], label='Train')
        plt.plot(history.epoch, history.history['val_'+metric],
                 color=colors[0], linestyle="--", label='Val')
        plt.xlabel('Epoch')
        plt.ylabel(name)
        if metric == 'loss':
            plt.ylim([0, plt.ylim()[1]])
        elif metric == 'auc':
            plt.ylim([0.8,1])
        else:
            plt.ylim([0,1])

        plt.legend()
plot_metrics(history)

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19

在这里插入图片描述
定义ROC曲线函数

def plot_roc(name, labels, predictions, **kwargs):
    fp, tp, _ = sklearn.metrics.roc_curve(labels, predictions)

    plt.plot(100*fp, 100*tp, label=name, linewidth=2, **kwargs)
    plt.xlabel('False positives [%]')
    plt.ylabel('True positives [%]')
    plt.xlim([-0.5,20])
    plt.ylim([80,100.5])
    plt.grid(True)
    ax = plt.gca()
    ax.set_aspect('equal')

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11

预测训练集和测试集

train_predictions = model.predict(train_features, batch_size=BATCH_SIZE)
test_predictions = model.predict(test_features, batch_size=BATCH_SIZE)

    1
    2

ROC曲线

plot_roc("Train Baseline", train_labels, train_predictions, color=colors[0])
plot_roc("Test Baseline", test_labels, test_predictions, color=colors[0], linestyle='--')
plt.legend(loc='lower right')

    1
    2
    3

在这里插入图片描述
定义混淆矩阵函数

def plot_cm(labels, predictions, p=0.5):
    cm = confusion_matrix(labels, predictions > p)
    plt.figure(figsize=(5,5))
    sns.heatmap(cm, annot=True, fmt="d")
    plt.title('Confusion matrix @{:.2f}'.format(p))
    plt.ylabel('Actual label')
    plt.xlabel('Predicted label')

    1
    2
    3
    4
    5
    6
    7

绘制混淆矩阵

plot_cm(test_labels, test_predictions)

    1

在这里插入图片描述

————————————————
版权声明:本文为CSDN博主「cofisher」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_36758914/article/details/104775509
返回列表
出现Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR错误的解决办法dnn_rnn_ops.cc:1510 : Unknown: Fa
51自学网自学EXCEL、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。
京ICP备13026421号-1