SKIL中在多GPU上训练
训练神经网络模型可能是一项计算代价很高的任务。如果你的机器上安装了多个GPU,为了加快训练过程,你可以选择并行训练你的模型。
SKIL可以使用 skil parallelwrapper
命令利用机器中安装的GPU。在本指南中,你将看到如何在多个GPU上的MNIST数据集上训练DL4J网络。
在继续本指南之前,请确保已将SKIL配置为GPU模式。
先决条件
你需要遵循以下步骤:
- SKIL
- 一个或多个GPU 在你的机器中
使用“skil parallelwrapper”进行分布式训练的组件
1. DataSetIteratorProviderFactory接口实现
为了为你的网络提供数据你需要实现org.deeplearning4j.parallelism.main.DataSetIteratorProviderFactory接口。这个接口定义如下:
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
public interface DataSetIteratorProviderFactory {
DataSetIterator create();
}
image.gif
SKIL提供此接口的默认实现(io.skymind.skil.parallelwrapper.MnistDataSetIteratorProviderFactory
),用于提供MNIST数据。在Java中的实现如下:
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.parallelism.main.DataSetIteratorProviderFactory;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.IOException;
public class MnistDataSetIteratorProviderFactory implements DataSetIteratorProviderFactory {
@Override
public DataSetIterator create() {
try {
return new MnistDataSetIterator(100, 1000);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
image.gif
2. 神经网络配置
最后,要训练的神经网络配置。示例配置(在scala中)如下(对于MultiLayerNetwork
)
import org.deeplearning4j.nn.api.Model
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer
import org.deeplearning4j.nn.conf.layers.DenseLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster
import org.deeplearning4j.util.ModelSerializer
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.learning.config.Nesterovs
import org.nd4j.linalg.lossfunctions.LossFunctions
import java.io.File;
//如上所述的训练迭代
var builder = new NeuralNetConfiguration.Builder().seed(230)
.l2(0.0005)
.weightInit(WeightInit.XAVIER)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(Nesterovs.builder().learningRate(0.01).momentum(0.9).build())
.list()
.layer(0, new ConvolutionLayer.Builder(5, 5)
//nIn和nOut指定深度。这里是nChannels,nOut是要应用的过滤器的数量。
.nIn(1).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2).build())
.layer(2, new ConvolutionLayer.Builder(5, 5)
//注意,不需要在后面的层中指定nIn
.stride(1, 1).nOut(50).activation(Activation.IDENTITY).build())
.layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2)
.stride(2, 2).build())
.layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10)
.activation(Activation.SOFTMAX).build())
//见下面的备注
.setInputType(InputType.convolutionalFlat(28, 28, 1))
.backprop(true).pretrain(false);
var network = new MultiLayerNetwork(builder.build());
network.init();
//把模型写入文件
var neuralNet = new File("/tmp", "neuralnet.bin")
neuralNet.createNewFile()
ModelSerializer.writeModel(network, neuralNet, true);
image.gif
使用“skil parallelwrapper”命令
要使用skil parellelwrapper命令,你需要使用 ModelSerializer#writeModel
函数将模型配置写入文件(同样,如上述网络代码所示)。之后,你将能够使用skil parallelwrapper命令,如下所示:
$SKIL_HOME/sbin/skil login --userId admin --password admin #你可能有不同的用户名和密码,请相应地替换它们。.
$SKIL_HOME/sbin/skil parallelwrapper --modelPath "/tmp/neuralnet.bin" --dataSetIteratorFactoryClazz "io.skymind.skil.parallelwrapper.MnistDataSetIteratorProviderFactory" --modelOutputPath "/tmp/neuralnet_out.bin" --uiUrl localhost:9002 --evalDataSetProviderClass "io.skymind.skil.parallelwrapper.MnistDataSetIteratorProviderFactory" --evalType "evaluation"
image.gif
上面的命令将从--modelPath获取模型配置,通过--dataSetIteratorFactoryClazz
指定的工厂类使用数据在可用的GPU上对其进行训练,并将输出保存在--modelOutputPath
。它还将使用与--evalDataSetProviderClass
指定的同一工厂类评估模型,并在--uiUrl
指定的URL上显示流程。
网友评论