Main Content

plotconfusion

绘制分类混淆矩阵

说明

示例

plotconfusion(targets,outputs) 绘制真实标签 targets 和预测标签 outputs 的混淆矩阵。将标签指定为分类向量,或以一对 N (one-hot) 形式指定。

提示

不推荐将 plotconfusion 用于分类标签。请改用 confusionchart

在混淆矩阵图上,行对应于预测类(输出类),列对应于真实类(目标类)。对角线上的单元格对应于正确分类的观测值。非对角线上的单元格对应于未正确分类的观测值。每个单元格中都显示观测值数目和观测值数目占总观测值数目的百分比。

绘图最右侧的列显示预测属于正确分类和未正确分类的每个类的示例占所有示例的百分比。这些度量通常分别称为精确度(或正预测值)和假发现率。绘图底部的行显示属于正确分类和未正确分类的每个类的示例占所有示例的百分比。这些度量通常分别称为召回率(或真正率)和假负率。绘图右下角的单元格显示整体准确度。

plotconfusion(targets,outputs,name) 绘制混淆矩阵,并在绘图标题的开头添加 name

plotconfusion(targets1,outputs1,name1,targets2,outputs2,name2,...,targetsn,outputsn,namen) 在一个图窗中绘制多个混淆矩阵,并将 name 参量添加到对应绘图标题的开头。

示例

全部折叠

加载由手写数字的合成图像组成的数据。XTrain 是一个 28×28×1×5000 图像数组,YTrain 是一个包含图像标签的分类向量。

load DigitsDataTrain
classNames = categories(labelsTrain);

定义一个卷积神经网络的架构。

layers = [
    imageInputLayer([28 28 1])
    
    convolution2dLayer(3,8,'Padding','same')
    batchNormalizationLayer
    reluLayer    
    convolution2dLayer(3,16,'Padding','same','Stride',2)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,32,'Padding','same','Stride',2)
    batchNormalizationLayer
    reluLayer
    
    fullyConnectedLayer(10)
    softmaxLayer];

指定训练选项并训练网络。

options = trainingOptions('sgdm', ...
    'MaxEpochs',5, ...
    'Verbose',false, ...
    'Plots','training-progress', ...
    'Metrics','accuracy');
net = trainnet(XTrain,labelsTrain,layers,"crossentropy",options);

使用经过训练的网络加载测试数据并对其分类。

load digitsDataTest
scores = minibatchpredict(net,XTest)
scores = 5000×10 single matrix

    0.9979    0.0005    0.0005    0.0001    0.0001    0.0000    0.0004    0.0000    0.0002    0.0002
    0.9297    0.0000    0.0320    0.0003    0.0000    0.0001    0.0017    0.0005    0.0004    0.0353
    0.9996    0.0000    0.0001    0.0000    0.0000    0.0000    0.0001    0.0000    0.0000    0.0002
    0.9195    0.0000    0.0004    0.0001    0.0002    0.0000    0.0147    0.0000    0.0016    0.0635
    0.9429    0.0006    0.0032    0.0003    0.0000    0.0021    0.0040    0.0021    0.0447    0.0002
    0.9906    0.0003    0.0001    0.0000    0.0000    0.0000    0.0003    0.0000    0.0002    0.0085
    0.9938    0.0000    0.0004    0.0000    0.0000    0.0000    0.0011    0.0000    0.0041    0.0006
    0.9998    0.0000    0.0000    0.0000    0.0000    0.0000    0.0001    0.0000    0.0000    0.0000
    0.6722    0.0001    0.0028    0.0002    0.0013    0.0022    0.0007    0.0001    0.0035    0.3168
    0.8069    0.0020    0.1340    0.0041    0.0009    0.0015    0.0173    0.0010    0.0034    0.0288
      ⋮

YTest = scores2label(scores,classNames);

绘制由真实测试标记 YTest 和预测标签 YPredicted 组成的混淆矩阵。

plotconfusion(YTest,YTest)

行对应于预测类(输出类),列对应于真实类(目标类)。对角线上的单元格对应于正确分类的观测值。非对角线上的单元格对应于未正确分类的观测值。每个单元格中都显示观测值数目和观测值数目占总观测值数目的百分比。

绘图最右侧的列显示预测属于正确分类和未正确分类的每个类的示例占所有示例的百分比。这些度量通常分别称为精确度(或正预测值)和假发现率。绘图底部的行显示属于正确分类和未正确分类的每个类的示例占所有示例的百分比。这些度量通常分别称为召回率(或真正率)和假负率。绘图右下角的单元格显示整体准确度。

关闭所有图窗。

close(findall(groot,'Type','figure'))

使用 cancer_dataset 函数加载样本数据。XTrain 是一个 9×699 矩阵,定义 699 个活检的九个属性。YTrain 是一个 2×699 矩阵,其中每列表示对应观测值的正确类别。YTrain 的每列在第一行或第二行都有一个等于 1 的元素,分别对应于良性或恶性肿瘤。有关此数据集的详细信息,请在命令行中键入 help cancer_dataset

rng default
[XTrain,YTrain] = cancer_dataset;
YTrain(:,1:10)
ans = 2×10

     1     1     1     0     1     1     0     0     0     1
     0     0     0     1     0     0     1     1     1     0

创建一个模式识别网络,并使用样本数据对其进行训练。

net = patternnet(10);
net = train(net,XTrain,YTrain);

使用经过训练的网络估计肿瘤状态。矩阵 YPredicted 的每列包含分别属于类 1 和类 2 的每个观测值的预测概率。

YPredicted = net(XTrain);
YPredicted(:,1:10)
ans = 2×10

    0.9980    0.9979    0.9894    0.0578    0.9614    0.9960    0.0026    0.0023    0.0084    0.9944
    0.0020    0.0021    0.0106    0.9422    0.0386    0.0040    0.9974    0.9977    0.9916    0.0056

绘制混淆矩阵。为了创建绘图,plotconfusion 会根据最高类概率为每个观测值加标签。

plotconfusion(YTrain,YPredicted)

在此图窗中,对角线上的前两个单元格显示经过训练的网络的正确分类的数量和百分比。例如,446 个活检被正确分类为良性。占所有 699 个活检的 63.8%。类似的,236 个病例被正确分类为恶性。占所有活检的 33.8%。

5 个恶性活检被不正确地分类为良性,占数据中所有 699 个活检的 0.7%。类似的,12 个良性活检被不正确地分类为恶性,占所有数据的 1.7%。

在 451 个良性预测中,98.9% 是正确的,1.1% 是错误的。在 248 个恶性预测中,95.2% 是正确的,4.8% 是错误的。在 458 个良性病例中,97.4% 被正确预测为良性,2.6% 被预测为恶性。在 241 个恶性病例中,97.9% 被正确分类为恶性,2.1% 被分类为良性。

总的来说,97.6% 的预测是正确的,2.4% 是错误的。

输入参数

全部折叠

真实类标签,指定为以下项之一:

  • 分类向量,其中每个元素是一个观测值的类标签。outputstargets 参量必须具有相同的元素数。如果分类向量定义基础类,则 plotconfusion 显示所有基础类,即使没有某些基础类的观测值也是如此。如果参量是有序分类向量,则它们必须以相同的顺序定义相同的基础类别。

  • N×M 矩阵,其中 N 是类数目,M 是观测值数目。矩阵的每列必须采用一对 N (one-hot) 形式,其中一个元素等于 1,表示真实标签,所有其他元素等于 0。

预测类标签,指定为以下项之一:

  • 分类向量,其中每个元素是一个观测值的类标签。outputstargets 参量必须具有相同的元素数。如果分类向量定义基础类,则 plotconfusion 显示所有基础类,即使没有某些基础类的观测值也是如此。如果参量是有序分类向量,则它们必须以相同的顺序定义相同的基础类别。

  • N×M 矩阵,其中 N 是类数目,M 是观测值数目。矩阵的每列可以采用一对 N (one-hot) 形式,其中等于 1 的单个元素指示预测的标签;也可以采用总和为 1 的概率形式。

混淆矩阵的名称,指定为字符数组。plotconfusion 将指定的 name 添加到绘图标题的开头。

数据类型: char

版本历史记录

在 R2008a 中推出