今天的任务是依照这篇介绍的方法,使用GoogleNet和AlexNet迁移学习ECG。
Signal Classification with Wavelet Analysis and Convolutional Neural Networks
整个实现流程包括以下几步:
- 下载三个ECG Dataset;
- 整理数据集,包括降采样、截断、标签,存储到一个ECG_Data的structure里;
- Plot原始数据;
- 使用CWT,得到scalogram,作为该样本的输入特征图;
- 划分训练集和测试集;
- 使用GoogleNet进行训练;
- 使用AlexNet进行训练;
下载ECG Dataset
实验选用的数据集为以下三个:
The MIT-BIH Atrial Fibrillation Database
The MIT-BIH ST Change Database
The MIT-BIH Supraventricular Arrhythmia Database
由于网站提供的下载整个数据包的链接失效,所以使用wget工具来下载数据库,指令如下:
wget -r -np http://physionet.org/physiobank/database/afdb/
wget -r -np http://physionet.org/physiobank/database/stdb/
wget -r -np http://physionet.org/physiobank/database/svdb/
下载完成后,共得到23+28+78=129条记录,每个记录都包含双通道的心电信号。
整理数据集
- 使用WFDB工具箱中的rdsamp函数,新建一个id_list,数组内存储数据库的样本编号,按数组依次读取信号;
- 使用resample函数,将各个样本的采样率都降为128Hz;
- 直接截取信号的前65536个点(即512s信号);
- 分离信号的两个通道,存储在ECGData.data中,同时在ECGData.label中存储对应的病症标签,建立Data Structure,大小为(248, 65536)和(248, 1)。
stdb的第313-317,319-323文件只有一列数据。
经过整理后的ECGData Structure组成如下:
1-46:AF记录,共46条;
47-92:ST记录,共46条;
93-248,SV记录,共157条。
%% Load AFDB
start=1
flag={'AF'};
id_list=[04015,04043,04048,04126,04746,04908,04936,05091,05121,05261,06426,06453,06995,07162,07859,07879,07910,08215,08219,08378,08405,08434,08455];
for id = 1:23
signal=rdsamp(['/database/afdb/', num2str(id_list(id), '%05d')]);
resamp_signal=resample(signal, 128, 250);
cutoff_signal=resamp_signal(1:65536, :);
ECGData.Data(:,start)=cutoff_signal(1:65536,1);
ECGData.Labels(start)=flag;
ECGData.Data(:,start+1)=cutoff_signal(1:65536,2);
ECGData.Labels(start+1)=flag;
start=start+2;
id
end
%% Load STDB
start=47
flag={'ST'};
% 313~317 319~323
id_list=[300,301,302,303,304,305,306,307,308,309,310,311,312,318,324,325,326,327];
for id = 1:18
signal=rdsamp(['/database/stdb/', num2str(id_list(id))]);
resamp_signal=resample(signal, 128, 360);
cutoff_signal=resamp_signal(1:65536, :);
ECGData.Data(:,start)=cutoff_signal(1:65536,1);
ECGData.Labels(start)=flag;
ECGData.Data(:,start+1)=cutoff_signal(1:65536,2);
ECGData.Labels(start+1)=flag;
start=start+2;
id
end
id_list=[313,314,315,316,317,319,320,321,322,323];
for id = 1:10
signal=rdsamp(['/database/stdb/', num2str(id_list(id))]);
resamp_signal=resample(signal, 128, 360);
cutoff_signal=resamp_signal(1:65536, :);
ECGData.Data(:,start)=cutoff_signal(1:65536,1);
ECGData.Labels(start)=flag;
start=start+1;
id
end
%% Load SVDB
start=93
flag={'SV'};
id_list=[800,801,802,803,804,805,806,807,808,809,810,811,812,820,821,822,823,824,825,826,827,828,829,840,841,842,843,844,845,846,847,848,849,850,851,852,853,854,855,856,857,858,859,860,861,862,863,864,865,866,867,868,869,870,871,872,873,874,875,876,877,878,879,880,881,882,883,884,885,886,887,888,889,890,891,892,893,894];
for id = 1:78
signal=rdsamp(['/database/svdb/', num2str(id_list(id))]);
resamp_signal=resample(signal, 128, 128);
cutoff_signal=resamp_signal(1:65536, :);
ECGData.Data(:,start)=cutoff_signal(1:65536,1);
ECGData.Labels(start)=flag;
ECGData.Data(:,start+1)=cutoff_signal(1:65536,2);
ECGData.Labels(start+1)=flag;
start=start+2;
id
end
%% Rebuild
ECGData.Data=ECGData.Data';
ECGData.Labels=ECGData.Labels';
Plot原始数据
调用例程中的helperPlotReps()函数,看原始数据。很奇怪的是第三类问题好像采样率有些奇怪,但是找不到问题的原因。
原始信号1000点
特征提取与数据集划分
首先使用cwtfilterbank函数对原始信号进行CWT变换,得到的结果如图所示。
CWT_Result
然后使用helpCreateRGBfromTF()对整个数据集进行变换,并使用splitEachLabel进行训练集和测试集的分割,分割得到大小为199的训练集和大小为49的测试集,存储在ImageDatastore里。
helperCreateRGBfromTF(ECGData,parentDir,dataDir)
allImages = imageDatastore(fullfile(parentDir,dataDir),...
'IncludeSubfolders',true,...
'LabelSource','foldernames');
rng default
[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,'randomized');
disp(['Number of training images: ',num2str(numel(imgsTrain.Files))]);
disp(['Number of validation images: ',num2str(numel(imgsValidation.Files))]);
使用GoogleNet进行训练
GoogleNet是使用ImageNet训练的对于1000分类的深层CNN网络,其结构如图所示,为了进行迁移学习,我们将最后四层修改为针对三分类问题的输出。
lgraph = removeLayers(lgraph,{'pool5-drop_7x7_s1','loss3-classifier','prob','output'});
numClasses = numel(categories(imgsTrain.Labels));
newLayers = [
dropoutLayer(0.6,'Name','newDropout')
fullyConnectedLayer(numClasses,'Name','fc','WeightLearnRateFactor',5,'BiasLearnRateFactor',5)
softmaxLayer('Name','softmax')
classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph,newLayers);
lgraph = connectLayers(lgraph,'pool5-7x7_s1','newDropout');
inputSize = net.Layers(1).InputSize;
googlenet
同时,设置GoogleNet训练的一些参数,开始训练,训练结果如图所示。
options = trainingOptions('sgdm',...
'MiniBatchSize',15,...
'MaxEpochs',20,...
'InitialLearnRate',1e-4,...
'ValidationData',imgsValidation,...
'ValidationFrequency',10,...
'ValidationPatience',Inf,...
'Verbose',1,...
'ExecutionEnvironment','cpu',...
'Plots','training-progress');
rng default
trainedGN = trainNetwork(imgsTrain,lgraph,options);
trainedGN.Layers(end-2:end)
cNames = trainedGN.Layers(end).ClassNames
GoogleNet训练结果
result
使用GoogleNet训练的最终正确率为:71.429%
同时我们还观测了GoogleNet的激活函数、对于AF病症的激活函数以及最强的AF通道,如下图所示。
第一层Activations AF Activations strongest AF Channel
使用AlexNet进行训练
AlexNet共有25层,如下图所示。
AlexNet
针对本问题,我们修改了AlexNet的最后三层,同时改变图像的形状以匹配AlexNet的输入。
%% Load
alex=alexnet;
layers = alex.Layers
%% Modify AlexNet Network Parameters
layers(23) = fullyConnectedLayer(3);
layers(25) = classificationLayer;
%% Prepare RGB Data for AlexNet
inputSize = alex.Layers(1).InputSize;
augimgsTrain = augmentedImageDatastore(inputSize(1:2),imgsTrain);
augimgsValidation = augmentedImageDatastore(inputSize(1:2),imgsValidation);
同时,设置AlexNet训练的一些参数,开始训练,训练结果如图所示。
AlexNet训练结果
result
使用AlexNet训练的最终正确率为75.51%
网友评论