#!/usr/bin/env bash
# Build conda env
#
# Creates a Python 3.7 conda environment named "jax" containing the TensorFlow
# GPU build plus wheel/scipy. It then makes the env's lib/ directory visible to
# the dynamic loader, both in this shell and in future shells via ~/.bashrc.

# `conda activate` only works in a non-interactive script once the shell hook
# has been loaded into the current shell.
eval "$(conda shell.bash hook)"

# -y: do not block on conda's interactive confirmation prompt.
conda create -y -n jax python=3.7
# Abort rather than silently installing into whatever env is currently active.
conda activate jax || { echo "failed to activate conda env 'jax'" >&2; exit 1; }
conda install -y tensorflow-gpu=2.4.1
conda install -y wheel scipy

# After a successful activation CONDA_PREFIX is the env's directory. The
# fallback is the location the original notes assumed.
jax_env_prefix=${CONDA_PREFIX:-$HOME/.conda/envs/jax}

# Only the env prefix is resolved now. The ${LD_LIBRARY_PATH...} part is
# written literally, so it expands when .bashrc runs instead of freezing this
# shell's value. ${VAR:+:$VAR} avoids leaving a trailing ':' (an empty entry,
# which the loader treats as the current directory) when the variable is empty.
# shellcheck disable=SC2016
printf 'export LD_LIBRARY_PATH=%s/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}\n' \
  "$jax_env_prefix" >> "$HOME/.bashrc"
export LD_LIBRARY_PATH="$jax_env_prefix/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
# Build openmpi and jax
#
# Downloads and builds Open MPI 4.1.1 into $HOME/openmpi and registers its bin/
# and lib/ directories. It then checks out the forked TensorFlow (branch "jax")
# and jax (branch "department_jax") sources under $HOME and runs the jax fork's
# build.sh. The working directory is left at $HOME/jax.
openmpi_version=4.1.1
openmpi_prefix=$HOME/openmpi

# The tarball is fetched into the current directory, as in the original notes.
wget "https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-${openmpi_version}.tar.gz"
tar -zvxf "openmpi-${openmpi_version}.tar.gz"
cd "openmpi-${openmpi_version}" || exit 1
./configure --prefix="$openmpi_prefix" --without-verbs
# Use every available core instead of a hard-coded -j20.
make -j"$(nproc)"
make install

# Single quotes keep $PATH / $LD_LIBRARY_PATH / $HOME unexpanded in .bashrc.
# They are then read at login time, rather than this shell's full PATH being
# copied into the file.
# shellcheck disable=SC2016
echo 'export PATH="$PATH:$HOME/openmpi/bin"' >> "$HOME/.bashrc"
# shellcheck disable=SC2016
echo 'export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+$LD_LIBRARY_PATH:}$HOME/openmpi/lib"' >> "$HOME/.bashrc"
source "$HOME/.bashrc"
# A .bashrc commonly returns early when the shell is not interactive. In that
# case the source above does nothing, so export the Open MPI paths directly as
# well. A duplicated PATH entry is harmless.
export PATH="$PATH:$openmpi_prefix/bin"
export LD_LIBRARY_PATH="${LD_LIBRARY_PATH:+$LD_LIBRARY_PATH:}$openmpi_prefix/lib"

# Forked TensorFlow sources.
# NOTE(review): presumably consumed by the jax fork's build.sh; confirm.
cd "$HOME" || exit 1
git clone https://github.com/yxd886/tensorflow.git
git -C tensorflow checkout jax

# Forked jax sources, built in place. This leaves the shell in $HOME/jax.
git clone https://github.com/yxd886/jax.git
cd jax || exit 1
git checkout department_jax
sh build.sh
# Download trax
#
# Clones the trax fork, installs the Python packages it needs, and launches
# run1_rnnlm.sh on host net-g3 under Open MPI with the search/fusion tuning
# variables exported.
# NOTE(review): this runs in the current directory, which after the jax build
# step is $HOME/jax, so trax is cloned inside the jax checkout. Likewise,
# run1_rnnlm.sh is looked up in that directory. Confirm both are intended.
git clone https://github.com/yxd886/trax.git

# Use the mirror's HTTPS endpoint, so packages are not fetched over plain HTTP
# and --trusted-host is no longer needed.
pypi_mirror=https://mirrors.aliyun.com/pypi/simple/
pip install t5==0.7.1 --index-url "$pypi_mirror" --disable-pip-version-check
pip install gin-config --index-url "$pypi_mirror" --disable-pip-version-check
# Must be its own command. Glued onto the pip line above, pip parses it as an
# (invalid) option and the conda install never runs.
# NOTE(review): conda package names are usually hyphenated
# (tensorflow-datasets); confirm this name resolves in your channels.
conda install -y tensorflow_datasets
pip install psutil
pip install matplotlib

# Each exported variable needs its own -x. mpirun treats the first non-option
# argument as the program to launch, so a bare CORE_MODULE_ID=57 would end
# option parsing there.
# NOTE(review): BUFFEQR_SIZE looks like a typo of BUFFER_SIZE. Kept as-is;
# confirm which name the launched program actually reads.
mpirun -np 1 --host "net-g3" \
  -x ENABLE_SEARCH=1 \
  -x CORE_MODULE_ID=57 \
  -x TOTAL_SAMPLE_TIME=1000 \
  -x OP_FUSION=3 \
  -x TENSOR_FUSION_THRESHOLD=40 \
  -x BUFFEQR_SIZE=100 \
  sh run1_rnnlm.sh
# (Scraped page footer "网友评论" = "user comments"; not part of the setup steps.)