美文网首页
使用GoogleNet和AlexNet迁移学习ECG

使用GoogleNet和AlexNet迁移学习ECG

作者: Betrayer丶 | 来源:发表于2018-07-04 09:52 被阅读0次

今天的任务是依照这篇介绍的方法,使用GoogleNet和AlexNet迁移学习ECG。
Signal Classification with Wavelet Analysis and Convolutional Neural Networks


整个实现流程包括以下几步:

  1. 下载三个ECG Dataset;
  2. 整理数据集,包括降采样、截断、标签,存储到一个ECG_Data的structure里;
  3. Plot原始数据;
  4. 使用CWT,得到scalogram,作为该样本的输入特征图;
  5. 划分训练集和测试集;
  6. 使用GoogleNet进行训练;
  7. 使用AlexNet进行训练;

下载ECG Dataset

实验选用的数据集为以下三个:
The MIT-BIH Atrial Fibrillation Database
The MIT-BIH ST Change Database
The MIT-BIH Supraventricular Arrhythmia Database
由于网站提供的下载整个数据包的链接失效,所以使用wget工具来下载数据库,指令如下:

wget -r -np http://physionet.org/physiobank/database/afdb/
wget -r -np http://physionet.org/physiobank/database/stdb/
wget -r -np http://physionet.org/physiobank/database/svdb/

下载完成后,共得到23+28+78=129条记录,每个记录都包含双通道的心电信号。

整理数据集

  1. 使用WFDB工具箱中的rdsamp函数,新建一个id_list,数组内存储数据库的样本编号,按数组依次读取信号;
  2. 使用resample函数,将各个样本的采样率都降为128Hz;
  3. 直接截取信号的前65536个点(即512s信号);
  4. 分离信号的两个通道,存储在ECGData.data中,同时在ECGData.label中存储对应的病症标签,建立Data Structure,大小为(248, 65536)和(248, 1)。

stdb的第313-317,319-323文件只有一列数据。

经过整理后的ECGData Structure组成如下:
1-46:AF记录,共46条;
47-92:ST记录,共46条;
93-248,SV记录,共157条。

%% Load AFDB
start=1
flag={'AF'};
id_list=[04015,04043,04048,04126,04746,04908,04936,05091,05121,05261,06426,06453,06995,07162,07859,07879,07910,08215,08219,08378,08405,08434,08455];
for id = 1:23
    signal=rdsamp(['/database/afdb/', num2str(id_list(id), '%05d')]);
    resamp_signal=resample(signal, 128, 250);
    cutoff_signal=resamp_signal(1:65536, :);
    ECGData.Data(:,start)=cutoff_signal(1:65536,1);
    ECGData.Labels(start)=flag;
    ECGData.Data(:,start+1)=cutoff_signal(1:65536,2);
    ECGData.Labels(start+1)=flag;
    start=start+2;
    id
end

%% Load STDB
start=47
flag={'ST'};
% 313~317 319~323
id_list=[300,301,302,303,304,305,306,307,308,309,310,311,312,318,324,325,326,327];
for id = 1:18
    signal=rdsamp(['/database/stdb/', num2str(id_list(id))]);
    resamp_signal=resample(signal, 128, 360);
    cutoff_signal=resamp_signal(1:65536, :);
    ECGData.Data(:,start)=cutoff_signal(1:65536,1);
    ECGData.Labels(start)=flag;
    ECGData.Data(:,start+1)=cutoff_signal(1:65536,2);
    ECGData.Labels(start+1)=flag;
    start=start+2;
    id
end

id_list=[313,314,315,316,317,319,320,321,322,323];
for id = 1:10
    signal=rdsamp(['/database/stdb/', num2str(id_list(id))]);
    resamp_signal=resample(signal, 128, 360);
    cutoff_signal=resamp_signal(1:65536, :);
    ECGData.Data(:,start)=cutoff_signal(1:65536,1);
    ECGData.Labels(start)=flag;
    start=start+1;
    id
end

%% Load SVDB
start=93
flag={'SV'};
id_list=[800,801,802,803,804,805,806,807,808,809,810,811,812,820,821,822,823,824,825,826,827,828,829,840,841,842,843,844,845,846,847,848,849,850,851,852,853,854,855,856,857,858,859,860,861,862,863,864,865,866,867,868,869,870,871,872,873,874,875,876,877,878,879,880,881,882,883,884,885,886,887,888,889,890,891,892,893,894];
for id = 1:78
    signal=rdsamp(['/database/svdb/', num2str(id_list(id))]);
    resamp_signal=resample(signal, 128, 128);
    cutoff_signal=resamp_signal(1:65536, :);
    ECGData.Data(:,start)=cutoff_signal(1:65536,1);
    ECGData.Labels(start)=flag;
    ECGData.Data(:,start+1)=cutoff_signal(1:65536,2);
    ECGData.Labels(start+1)=flag;
    start=start+2;
    id
end

%% Rebuild
ECGData.Data=ECGData.Data';
ECGData.Labels=ECGData.Labels';

Plot原始数据

调用例程中的helperPlotReps()函数,看原始数据。很奇怪的是第三类问题好像采样率有些奇怪,但是找不到问题的原因。


原始信号1000点

特征提取与数据集划分

首先使用cwtfilterbank函数对原始信号进行CWT变换,得到的结果如图所示。


CWT_Result

然后使用helpCreateRGBfromTF()对整个数据集进行变换,并使用splitEachLabel进行训练集和测试集的分割,分割得到大小为199的训练集和大小为49的测试集,存储在ImageDatastore里。

helperCreateRGBfromTF(ECGData,parentDir,dataDir)
allImages = imageDatastore(fullfile(parentDir,dataDir),...
    'IncludeSubfolders',true,...
    'LabelSource','foldernames');

rng default
[imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,'randomized');
disp(['Number of training images: ',num2str(numel(imgsTrain.Files))]);
disp(['Number of validation images: ',num2str(numel(imgsValidation.Files))]);

使用GoogleNet进行训练

GoogleNet是使用ImageNet训练的对于1000分类的深层CNN网络,其结构如图所示,为了进行迁移学习,我们将最后四层修改为针对三分类问题的输出。

lgraph = removeLayers(lgraph,{'pool5-drop_7x7_s1','loss3-classifier','prob','output'});

numClasses = numel(categories(imgsTrain.Labels));
newLayers = [
    dropoutLayer(0.6,'Name','newDropout')
    fullyConnectedLayer(numClasses,'Name','fc','WeightLearnRateFactor',5,'BiasLearnRateFactor',5)
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph,newLayers);

lgraph = connectLayers(lgraph,'pool5-7x7_s1','newDropout');
inputSize = net.Layers(1).InputSize;
googlenet

同时,设置GoogleNet训练的一些参数,开始训练,训练结果如图所示。

options = trainingOptions('sgdm',...
    'MiniBatchSize',15,...
    'MaxEpochs',20,...
    'InitialLearnRate',1e-4,...
    'ValidationData',imgsValidation,...
    'ValidationFrequency',10,...
    'ValidationPatience',Inf,...
    'Verbose',1,...
    'ExecutionEnvironment','cpu',...
    'Plots','training-progress');

rng default
trainedGN = trainNetwork(imgsTrain,lgraph,options);

trainedGN.Layers(end-2:end)
cNames = trainedGN.Layers(end).ClassNames
GoogleNet训练结果
result

使用GoogleNet训练的最终正确率为:71.429%
同时我们还观测了GoogleNet的激活函数、对于AF病症的激活函数以及最强的AF通道,如下图所示。


第一层Activations AF Activations strongest AF Channel

使用AlexNet进行训练

AlexNet共有25层,如下图所示。


AlexNet

针对本问题,我们修改了AlexNet的最后三层,同时改变图像的形状以匹配AlexNet的输入。

%% Load
alex=alexnet;
layers = alex.Layers

%% Modify AlexNet Network Parameters
layers(23) = fullyConnectedLayer(3);
layers(25) = classificationLayer;

%% Prepare RGB Data for AlexNet
inputSize = alex.Layers(1).InputSize;
augimgsTrain = augmentedImageDatastore(inputSize(1:2),imgsTrain);
augimgsValidation = augmentedImageDatastore(inputSize(1:2),imgsValidation);

同时,设置AlexNet训练的一些参数,开始训练,训练结果如图所示。


AlexNet训练结果
result

使用AlexNet训练的最终正确率为75.51%

相关文章

网友评论

      本文标题:使用GoogleNet和AlexNet迁移学习ECG

      本文链接:https://www.haomeiwen.com/subject/mtfluftx.html