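Reference: "Spatially informed cell-type deconvolution for spatial transcriptomics"
How the method works
To date, most deconvolution methods that estimate cell-type composition are built on regression, and CARD follows the same strategy. The expression measured at each spatial location is modeled as a mixture of cell-type-specific expression profiles:
$$X = BV^{T} + E$$
where:
- B: the G-by-K cell-type-specific expression matrix for the informative genes (built from the single-cell expression data)
- X: the G-by-N gene expression matrix for the same set of informative genes measured on N spatial locations (the spatial transcriptomics expression matrix)
- V: the N-by-K cell-type composition matrix (the cell-type composition at each spatial location)
- E: the G-by-N residual error matrix; each element E_gi is assumed to follow a normal distribution N(0, σ_e²), which matches the sigma_e2 parameter in the code further down
- B, X and V are all non-negative matrices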
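As a minimal sketch of the regression idea only (not CARD itself: there is no non-negativity constraint and no spatial prior here), the toy example below simulates X = BVᵀ with made-up sizes and recovers V by ordinary least squares. All object names and sizes are illustrative assumptions.

```r
# Toy sizes: G genes, K cell types, N locations (all hypothetical).
set.seed(1)
G <- 50; K <- 3; N <- 4
B <- matrix(rexp(G * K), nrow = G, ncol = K)        # G-by-K basis matrix
V_true <- matrix(runif(N * K), nrow = N, ncol = K)
V_true <- V_true / rowSums(V_true)                  # each row sums to 1
X <- B %*% t(V_true)                                # noiseless G-by-N mixture (E = 0)

# Regress every column of X on B; the K-by-N coefficient matrix is t(V).
V_hat <- t(qr.solve(B, X))
max_abs_err <- max(abs(V_hat - V_true))
```

With noiseless data the least-squares solution coincides with the simulated V; with real data the constraints and the spatial prior described next become important.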
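The authors also take into account that two neighboring locations are likely to have similar cell-type compositions, and add a spatial correction. The correction borrows the idea of regression: it sets up a linear relationship between V_ik and V_jk, so that the composition at location i is influenced by the compositions at the other locations (V_ik is therefore not a free term but depends on its neighbors):
$$V_{ik} = b_k + \phi \sum_{j \neq i} W_{ij}\,(V_{jk} - b_k) + \epsilon_{ik}$$
where:
- V_ik: represents the proportion of cell type k on the ith location
- b_k: is the kth cell-type-specific intercept that represents the average cell-type composition across locations. It plays the role of the intercept in a regression; it is one scalar per cell type, and the b_k of all K cell types form a vector of length K
- W: an N-by-N non-negative weight matrix with each element W_ij specifying the weight used for inferring the cell-type composition on the ith location based on the cell-type composition information on the jth location. It plays the role of the regression coefficients
- ϕ: is a spatial autocorrelation parameter that determines the strength of the spatial correlation in cell-type composition, i.e. how strongly the neighboring compositions pull V_ik away from b_k
- V_jk: the proportions of cell type k on all other locations j ≠ i
- ε_ik: the residual error, assumed below to follow a normal distribution
The goal is to infer the matrix V, that is, the cell-type composition at every location.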
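The spatial term can be illustrated with a small sketch that computes the value expected at each location from its neighbors, b_k + ϕ Σ W_ij (V_jk − b_k). The coordinates, the proportions, and the choice to rescale each row of W to sum to 1 are assumptions made for this illustration only.

```r
# Hypothetical toy example: 4 locations on a line, one cell type k.
coords <- cbind(x = c(0, 0.1, 0.2, 1), y = c(0, 0, 0, 0))
ED <- as.matrix(dist(coords))                 # pairwise Euclidean distances
W <- exp(-ED^2 / (2 * 0.1^2))                 # Gaussian kernel weights
diag(W) <- 0                                  # a location is not its own neighbor
W_norm <- W / rowSums(W)                      # assumption: rows rescaled to sum to 1

V_k <- c(0.8, 0.7, 0.6, 0.1)                  # proportions of cell type k
b_k <- mean(V_k)                              # intercept: average proportion
phi <- 0.9                                    # spatial autocorrelation strength

# Expected proportion at each location given all the other locations.
V_k_pred <- as.vector(b_k + phi * W_norm %*% (V_k - b_k))
```

Location 2 sits between two close neighbors with high proportions, so its expected value is pulled above b_k, while the far-away location 4 has almost no influence on it.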
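Statistical inference
The derivation relies on a lemma.
Assume that ε_ik in the spatial model above follows a normal distribution. Substituting ε_ik into the lemma gives Formula 1, which describes the search for a matrix V that makes the errors ε_ik as small as possible:
Formula 1
The statistical inference thus becomes an optimization problem: find a matrix V such that the error ε_ik between V_ik and the value predicted from the other locations, b_k + ϕ Σ_{j≠i} W_ij (V_jk − b_k), is as small as possible.
Formula 1 is then simplified. The simplification uses the rules of matrix multiplication to rewrite the sums (∑) as products of matrices, and then factors out the common term.
Next, the likelihood function is constructed and the parameters are solved by maximum likelihood:
First, the authors define the covariance matrix.
Second, the likelihood values of all terms are multiplied together to form the likelihood function (k indexes the K cell types, i indexes the N locations).
This likelihood has four groups of parameters to estimate, V, λ_k, σ_e² and b_k, and it couples the two regressions introduced above: X = BV^T + E and the spatial model for V_ik.
Combining them gives the joint likelihood (Formula 2). It favors parameter values for which the errors E_gi and ε_ik are close to 0; λ_k appears to act like a regularization term:
Formula 2
Minimizing the negative log-likelihood is equivalent to maximizing the log-likelihood, and the authors work with the maximization form:
Formula 3-1
Formula 3-2
Finally, maximum likelihood gives the optimal solution for the matrix V. Here G is the number of genes and N is the number of locations.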
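For orientation, the objective that the CARDref() C++ source (reproduced later in this post) actually evaluates can be read off the code. The expression below is a transcription of the three code terms logX, logV and logSigmaL2, not a copy of the paper's formulas; α = 1 and β = N/2 are the constants set in the code, D is the diagonal matrix holding the row sums of W, and 1_N is a vector of N ones.

```latex
\begin{aligned}
\mathrm{obj} ={}& \underbrace{-\frac{GN}{2}\log\sigma_e^{2}
      - \frac{\lVert X - BV^{T}\rVert_F^{2}}{2\sigma_e^{2}}}_{\texttt{logX}} \\
 &\underbrace{{}-\frac{N}{2}\sum_{k=1}^{K}\log\lambda_k
      - \frac{1}{2}\sum_{k=1}^{K}
        \frac{(V_k - b_k 1_N)^{T} L\,(V_k - b_k 1_N)}{\lambda_k}}_{\texttt{logV}} \\
 &\underbrace{{}-(\alpha+1)\sum_{k=1}^{K}\log\lambda_k
      - \sum_{k=1}^{K}\frac{\beta}{\lambda_k}}_{\texttt{logSigmaL2}},
 \qquad L = D - \phi W .
\end{aligned}
```

The first term is the Gaussian log-likelihood of the regression X = BVᵀ + E, the second is the spatial term on V, and the third has the form of an inverse-gamma log-density on λ_k (up to a constant).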
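Code walkthrough
First download the example data:
- sc_count.RData: the single-cell expression matrix
- sc_meta.RData: the single-cell metadata
- spatial_count.RData: the spatial transcriptomics expression matrix spatial_count; rows are genes, columns are locations, and each entry is the expression of a gene at a location
- spatial_location.RData: the spatial location table spatial_location; x and y are two-dimensional coordinates. The spatial assay can be pictured as a rectangular grid of pixels on a plane, where 10×10 denotes the pixel in row 10, column 10. Each pixel (each small dot on the slide) behaves like a small bulk-seq sample and is called a location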
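To make the expected shapes concrete, here is a small sketch that builds toy versions of the two spatial inputs. The gene names, counts, and the "row x column" naming of the locations are assumptions for illustration; the point is that the row names of the location table match the column names of the count matrix.

```r
# Hypothetical 3 genes x 4 locations count matrix; location names encode row x column.
loc_names <- c("10x10", "10x13", "11x10", "11x13")
spatial_count <- matrix(c(3, 0, 5, 1,
                          0, 2, 0, 4,
                          7, 1, 0, 0),
                        nrow = 3, byrow = TRUE,
                        dimnames = list(c("geneA", "geneB", "geneC"), loc_names))

# One row per location, with x and y parsed from the location name.
xy <- do.call(rbind, strsplit(loc_names, "x", fixed = TRUE))
spatial_location <- data.frame(x = as.numeric(xy[, 1]),
                               y = as.numeric(xy[, 2]),
                               row.names = loc_names)
```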
The workflow given with the example data is as follows:
# Load the data
load("C:/Users/lenovo/Downloads/spatial_count.RData")
load("C:/Users/lenovo/Downloads/spatial_location.RData")
load("C:/Users/lenovo/Downloads/sc_count.RData")
load("C:/Users/lenovo/Downloads/sc_meta.RData")
library(CARD)
CARD_obj = createCARDObject(
sc_count = sc_count,
sc_meta = sc_meta,
spatial_count = spatial_count,
spatial_location = spatial_location,
ct.varname = "cellType",
ct.select = unique(sc_meta$cellType),
sample.varname = "sampleInfo",
minCountGene = 100,
minCountSpot = 5)
CARD_obj = CARD_deconvolution(CARD_object = CARD_obj)
# Cell-type composition matrix (locations x cell types)
CARD_obj@Proportion_CARD
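A short sketch of how such a composition matrix can be inspected. The matrix below is a hand-made stand-in for CARD_obj@Proportion_CARD (the location and cell-type names are invented), so the snippet runs without the package.

```r
# Hypothetical stand-in for CARD_obj@Proportion_CARD (3 locations, 3 cell types).
prop <- matrix(c(0.7, 0.2, 0.1,
                 0.1, 0.6, 0.3,
                 0.2, 0.3, 0.5),
               nrow = 3, byrow = TRUE,
               dimnames = list(c("10x10", "10x13", "10x14"),
                               c("Acinar", "Ductal", "Cancer")))

# Each row is a composition, so it should sum to 1.
row_totals <- rowSums(prop)
# Most abundant cell type at each location.
dominant <- colnames(prop)[max.col(prop, ties.method = "first")]
names(dominant) <- rownames(prop)
```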
1. The createCARDObject() function
# Inputs (the same arguments as in the call above)
sc_count = sc_count
sc_meta = sc_meta
spatial_count = spatial_count
spatial_location = spatial_location
ct.varname = "cellType"
ct.select = unique(sc_meta$cellType)
sample.varname = "sampleInfo"
minCountGene = 100
minCountSpot = 5
# Step 1: quality control on the single-cell data
sc_countMat <- sc_count
ct.select <- as.character(ct.select[!is.na(ct.select)])
sc_eset = sc_QC(sc_countMat,sc_meta,ct.varname,ct.select,sample.varname)
#### Check the spatial count dataset
#### QC on spatial dataset
spatial_countMat <- spatial_count
commonGene = intersect(rownames(spatial_countMat),rownames(assays(sc_eset)$counts))
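# Step 2: filter the spatial data
#### keep genes detected in more than minCountSpot locations, then keep locations whose total count lies in [minCountGene, 1e6]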
spatial_countMat = spatial_countMat[rowSums(spatial_countMat > 0) > minCountSpot,]
spatial_countMat = spatial_countMat[,(colSums(spatial_countMat) >= minCountGene & colSums(spatial_countMat) <= 1e6)]
spatial_location = spatial_location[rownames(spatial_location) %in% colnames(spatial_countMat),]
spatial_location = spatial_location[match(colnames(spatial_countMat),rownames(spatial_location)),]
object <- new(
Class = "CARD",
# quality-controlled single-cell data (a SingleCellExperiment) restricted to the selected cell types
sc_eset = sc_eset,
spatial_countMat = spatial_countMat,
spatial_location = spatial_location,
project = "Deconvolution",
info_parameters = list(ct.varname = ct.varname,ct.select = ct.select,sample.varname = sample.varname)
)
return(object)
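The two spatial filters above can be tried on a toy matrix. This is a minimal sketch with invented counts and with thresholds much smaller than the 5 and 100 used in the example call, so that the effect is visible on a 4-by-5 matrix.

```r
# Hypothetical toy count matrix: 4 genes x 5 locations.
spatial_countMat <- matrix(c(5, 3, 0, 2, 0,
                             0, 0, 0, 1, 0,
                             9, 8, 0, 7, 1,
                             4, 6, 0, 5, 0),
                           nrow = 4, byrow = TRUE,
                           dimnames = list(paste0("gene", 1:4), paste0("loc", 1:5)))
minCountSpot <- 2   # toy thresholds (assumed values)
minCountGene <- 3

# Keep genes detected in more than minCountSpot locations.
keep_gene <- rowSums(spatial_countMat > 0) > minCountSpot
filtered <- spatial_countMat[keep_gene, , drop = FALSE]
# Keep locations whose total count is within [minCountGene, 1e6].
keep_loc <- colSums(filtered) >= minCountGene & colSums(filtered) <= 1e6
filtered <- filtered[, keep_loc, drop = FALSE]
```

Here gene2 is dropped because it is detected in only one location, and loc3 and loc5 are dropped because their total counts are too low.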
## sc_QC() filters out lowly expressed genes and low-quality cells, and keeps only cells of the selected cell types
sc_QC <- function(counts_in,metaData,ct.varname,ct.select,sample.varname = NULL, min.cells = 0,min.genes = 0){
# Filter based on min.features
coldf = metaData
counts = counts_in
if (min.genes >= 0) {
nfeatures <- colSums(x = counts )
counts <- counts[, which(x = nfeatures > min.genes)]
coldf <- coldf[which(x = nfeatures > min.genes),]
}
# filter genes on the number of cells expressing
if (min.cells >= 0) {
num.cells <- rowSums(x = counts > 0)
counts <- counts[which(x = num.cells > min.cells), ]
}
fdata = as.data.frame(rownames(counts))
rownames(fdata) = rownames(counts)
keepCell = as.character(coldf[,ct.varname]) %in% ct.select
counts = counts[,keepCell]
coldf = coldf[keepCell,]
keepGene = rowSums(counts) > 0
fdata = as.data.frame(fdata[keepGene,])
counts = counts[keepGene,]
sce <- SingleCellExperiment(list(counts=counts),
colData=as.data.frame(coldf),
rowData=as.data.frame(fdata))
return(sce)
}
2. The CARD_deconvolution() function
# Take the object returned by createCARDObject()
CARD_object = CARD_obj
# ct.select holds the names of the selected cell types
ct.select = CARD_object@info_parameters$ct.select
# ct.varname is the string "cellType" (the metadata column with the cell-type labels)
ct.varname = CARD_object@info_parameters$ct.varname
sample.varname = CARD_object@info_parameters$sample.varname
# sc_eset is the single-cell SingleCellExperiment object; counts(sc_eset) returns its count matrix
sc_eset = CARD_object@sc_eset
# Build the cell-type-specific basis matrix from the single-cell data
Basis_ref = createscRef(sc_eset, ct.select, ct.varname, sample.varname)
Basis = Basis_ref$basis
Basis = Basis[,colnames(Basis) %in% ct.select]
Basis = Basis[,match(ct.select,colnames(Basis))]
# Get the spatial transcriptomics count matrix spatial_count
spatial_count = CARD_object@spatial_countMat
commonGene = intersect(rownames(spatial_count),rownames(Basis))
#### remove mitochondrial and ribosomal genes
#### (the pattern below only matches gene names containing "mt-", i.e. mitochondrial genes)
commonGene = commonGene[!(commonGene %in% commonGene[grep("mt-",commonGene)])]
common = selectInfo(Basis,sc_eset,commonGene,ct.select,ct.varname)
# Spatial expression matrix Xinput
Xinput = spatial_count
# Cell-type basis matrix B (built from the single-cell data)
B = Basis
##### match the common gene names
##### keep only the informative genes in common for the downstream analysis
Xinput = Xinput[order(rownames(Xinput)),]
B = B[order(rownames(B)),]
B = B[rownames(B) %in% common,]
Xinput = Xinput[rownames(Xinput) %in% common,]
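##### filter out non expressed genes or cells again
##### drop genes and locations with no expression from the spatial matrix
Xinput = Xinput[rowSums(Xinput) > 0,]
Xinput = Xinput[,colSums(Xinput) > 0]
##### normalize count data
##### normalize the spatial expression matrix
colsumvec = colSums(Xinput)
### i.e. divide each gene's count by the total count (sequencing depth) of its location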
Xinput_norm = sweep(Xinput,2,colsumvec,"/")
B = B[rownames(B) %in% rownames(Xinput_norm),]
B = B[match(rownames(Xinput_norm),rownames(B)),]
#### spatial location
#### get the spatial location table
spatial_location = CARD_object@spatial_location
spatial_location = spatial_location[rownames(spatial_location) %in% colnames(Xinput_norm),]
spatial_location = spatial_location[match(colnames(Xinput_norm),rownames(spatial_location)),]
##### normalize the coordinates without changing the shape and relative position
### rescale the coordinates to relative positions in [0, 1]
norm_cords = spatial_location[ ,c("x","y")]
norm_cords$x = norm_cords$x - min(norm_cords$x)
norm_cords$y = norm_cords$y - min(norm_cords$y)
scaleFactor = max(norm_cords$x,norm_cords$y)
norm_cords$x = norm_cords$x / scaleFactor
norm_cords$y = norm_cords$y / scaleFactor
##### initialize the proportion matrix
### compute the pairwise Euclidean distances between the spatial locations
ED <- rdist::rdist(as.matrix(norm_cords))##Euclidean distance matrix
set.seed(20200107)
Vint1 = as.matrix(gtools::rdirichlet(ncol(Xinput_norm), rep(10,ncol(B))))
colnames(Vint1) = colnames(B)
rownames(Vint1) = colnames(Xinput_norm)
b = rep(0,length(ct.select))
###### parameters that need to be set
isigma = 0.1 #### construct Gaussian kernel with the default scale/length parameter to be 0.1
epsilon = 1e-04 #### convergence epsilon
phi = c(0.01,0.1,0.3,0.5,0.7,0.9,0.99) #### grided values for phi
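## build the weight matrix W from a Gaussian kernel on the distances (it is deterministic, not random); the diagonal is set to 0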
kernel_mat <- exp(-ED^2 / (2 * isigma^2))
diag(kernel_mat) <- 0
###### scale the Xinput_norm and B to speed up the convergence.
mean_X = mean(Xinput_norm)
mean_B = mean(B)
Xinput_norm = Xinput_norm * 1e-01 / mean_X
B = B * 1e-01 / mean_B
ResList = list()
Obj = c()
## fit the model once for every phi on the grid
for(iphi in 1:length(phi)){
res = CARDref(
XinputIn = as.matrix(Xinput_norm),
UIn = as.matrix(B),
WIn = kernel_mat,
phiIn = phi[iphi],
max_iterIn =1000,
epsilonIn = epsilon,
initV = Vint1,
initb = rep(0,ncol(B)),
initSigma_e2 = 0.1,
initLambda = rep(10,length(ct.select)))
rownames(res$V) = colnames(Xinput_norm)
colnames(res$V) = colnames(B)
ResList[[iphi]] = res
Obj = c(Obj,res$Obj)
}
## keep the model whose objective value is the largest
Optimal = which(Obj == max(Obj))
Optimal = Optimal[length(Optimal)] #### just in case if there are two equal objective function values
OptimalPhi = phi[Optimal]
OptimalRes = ResList[[Optimal]]
cat(paste0("## Deconvolution Finish! ...\n"))
CARD_object@info_parameters$phi = OptimalPhi
### get the cell-type composition matrix Proportion_CARD (each row of V rescaled to sum to 1)
CARD_object@Proportion_CARD = sweep(OptimalRes$V,1,rowSums(OptimalRes$V),"/")
CARD_object@algorithm_matrix = list(B = B * mean_B / 1e-01, Xinput_norm = Xinput_norm * mean_X / 1e-01, Res = OptimalRes)
CARD_object@spatial_location = spatial_location
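The coordinate rescaling and the Gaussian kernel used above can be reproduced on a few invented locations. This sketch uses base R's dist() as a stand-in for rdist::rdist(); the coordinates are assumptions.

```r
# Hypothetical coordinates for 4 locations.
spatial_location <- data.frame(x = c(10, 10, 11, 20), y = c(10, 13, 10, 30),
                               row.names = c("10x10", "10x13", "11x10", "20x30"))
norm_cords <- spatial_location[, c("x", "y")]
norm_cords$x <- norm_cords$x - min(norm_cords$x)
norm_cords$y <- norm_cords$y - min(norm_cords$y)
scaleFactor <- max(norm_cords$x, norm_cords$y)   # one factor for both axes keeps the shape
norm_cords <- norm_cords / scaleFactor

ED <- as.matrix(dist(norm_cords))                # base-R stand-in for rdist::rdist()
isigma <- 0.1
kernel_mat <- exp(-ED^2 / (2 * isigma^2))        # weights decay with squared distance
diag(kernel_mat) <- 0                            # no self-weight
```

Because both axes are divided by the same scaleFactor, relative positions are preserved, and with isigma = 0.1 only locations that are close on the rescaled grid receive a noticeable weight.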
################################## The createscRef() function
# Inputs
x = sc_eset # the single-cell SingleCellExperiment object
ct.select = ct.select # the names of the selected cell types
ct.varname = ct.varname # the string "cellType"
sample.varname = sample.varname
# Definition of createscRef()
createscRef <- function(x, ct.select = NULL, ct.varname, sample.varname = NULL){
library(MuSiC)
if (is.null(ct.select)) {
ct.select <- unique(colData(x)[, ct.varname])
}
# drop NA entries from the cell-type names
ct.select <- ct.select[!is.na(ct.select)]
# countMat <- as.matrix(assays(x)$counts)
# extract the single-cell count matrix as a sparse matrix
countMat <- as(SummarizedExperiment::assays(x)$counts,"sparseMatrix")
# ct.id assigns each cell its cell type
ct.id <- droplevels(as.factor(SummarizedExperiment::colData(x)[, ct.varname]))
#if(length(unique(colData(x)[,sample.varname])) > 1){
if(is.null(sample.varname)){
SummarizedExperiment::colData(x)$sampleID = "Sample"
sample.varname = "sampleID"
}
sample.id <- as.character(SummarizedExperiment::colData(x)[, sample.varname])
ct_sample.id <- paste(ct.id, sample.id, sep = "$*$")
colSums_countMat <- colSums(countMat)
colSums_countMat_Ct = aggregate(colSums_countMat ~ ct.id + sample.id, FUN = 'sum')
colSums_countMat_Ct_wide = reshape(colSums_countMat_Ct, idvar = "sample.id", timevar = "ct.id", direction = "wide")
colnames(colSums_countMat_Ct_wide) = gsub("colSums_countMat.","",colnames(colSums_countMat_Ct_wide))
rownames(colSums_countMat_Ct_wide) = colSums_countMat_Ct_wide$sample.id
colSums_countMat_Ct_wide$sample.id <- NULL
tbl <- table(sample.id,ct.id)
colSums_countMat_Ct_wide = colSums_countMat_Ct_wide[,match(colnames(tbl),colnames(colSums_countMat_Ct_wide))]
colSums_countMat_Ct_wide = colSums_countMat_Ct_wide[match(rownames(tbl),rownames(colSums_countMat_Ct_wide)),]
S_JK <- colSums_countMat_Ct_wide / tbl
S_JK <- as.matrix(S_JK)
S_JK[S_JK == 0] = NA
S_JK[!is.finite(S_JK)] = NA
S = colMeans(S_JK, na.rm = TRUE)
S = S[match(unique(ct.id),names(S))]
library("wrMisc")
if(nrow(countMat) > 10000 & ncol(countMat) > 50000){ ### to save memory
seqID = seq(1,nrow(countMat),by = 10000)
Theta_S_rowMean = NULL
for(igs in seqID){
if(igs != seqID[length(seqID)]){
Theta_S_rowMean_Tmp <- rowGrpMeans(as.matrix(countMat[(igs:(igs+9999)),]), grp = ct_sample.id, na.rm = TRUE)
}else{
Theta_S_rowMean_Tmp <- rowGrpMeans(as.matrix(countMat[igs:nrow(countMat),]), grp = ct_sample.id, na.rm = TRUE)
}
Theta_S_rowMean <- rbind(Theta_S_rowMean,Theta_S_rowMean_Tmp)
}
}else{
Theta_S_rowMean <- rowGrpMeans(as.matrix(countMat), grp = ct_sample.id, na.rm = TRUE)
}
tbl_sample = table(ct_sample.id)
tbl_sample = tbl_sample[match(colnames(Theta_S_rowMean),names(tbl_sample))]
Theta_S_rowSums <- sweep(Theta_S_rowMean,2,tbl_sample,"*")
Theta_S <- sweep(Theta_S_rowSums,2,colSums(Theta_S_rowSums),"/")
grp <- sapply(strsplit(colnames(Theta_S),split="$*$",fixed = TRUE),"[",1)
Theta = rowGrpMeans(Theta_S, grp = grp, na.rm = TRUE)
Theta = Theta[,match(unique(ct.id),colnames(Theta))]
S = S[match(colnames(Theta),names(S))]
basis = sweep(Theta,2,S,"*")
colnames(basis) = colnames(Theta)
rownames(basis) = rownames(Theta)
return(list(basis = basis))
}
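The core of createscRef() can be condensed into a small base-R sketch for the special case of a single sample: the basis is the relative expression of each cell type (Theta, columns summing to 1) multiplied by the average library size of that cell type (S). The toy counts and labels are invented, and the per-sample averaging of the real function is left out.

```r
# Hypothetical toy counts: 3 genes x 4 cells, two cell types, one sample.
countMat <- matrix(c(10, 20, 0, 2,
                      0, 10, 6, 4,
                     10, 10, 4, 4),
                   nrow = 3, byrow = TRUE,
                   dimnames = list(paste0("gene", 1:3), paste0("cell", 1:4)))
ct.id <- c("A", "A", "B", "B")

# Theta: total counts per gene within each cell type, rescaled so each column sums to 1.
ct_sums <- t(rowsum(t(countMat), group = ct.id))
Theta <- sweep(ct_sums, 2, colSums(ct_sums), "/")
# S: average library size (total counts per cell) of each cell type.
S <- tapply(colSums(countMat), ct.id, mean)
# basis: relative expression scaled back by the average library size.
basis <- sweep(Theta, 2, as.numeric(S[colnames(Theta)]), "*")
```

With a single sample this reduces to the mean count of each gene per cell type, for example (10 + 20) / 2 = 15 for gene1 in type A.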
################################## The selectInfo() function
selectInfo <- function(Basis,sc_eset,commonGene,ct.select,ct.varname){
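#### keep genes whose natural-log fold change (one cell type vs. the mean of the others) exceeds 1.25; the original comment reads "log2 mean fold change >0.5"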
gene1 = lapply(ct.select,function(ict){
rest = rowMeans(Basis[,colnames(Basis) != ict])
FC = log((Basis[,ict] + 1e-06)) - log((rest + 1e-06))
rownames(Basis)[FC > 1.25 & Basis[,ict] > 0]
})
gene1 = unique(unlist(gene1))
gene1 = intersect(gene1,commonGene)
counts = assays(sc_eset)$counts
counts = counts[rownames(counts) %in% gene1,]
##### only check the cell type that contains at least 2 cells
ct.select = names(table(colData(sc_eset)[,ct.varname]))[table(colData(sc_eset)[,ct.varname]) > 1]
sd_within = sapply(ct.select,function(ict){
temp = counts[,colData(sc_eset)[,ct.varname] == ict]
apply(temp,1,var) / apply(temp,1,mean)
})
##### remove the outliers that have high dispersion across cell types
gene2 = rownames(sd_within)[apply(sd_within,1,mean,na.rm = T) < quantile(apply(sd_within,1,mean,na.rm = T),prob = 0.99,na.rm = T)]
return(gene2)
}
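The first step of selectInfo(), the fold-change filter, can be tried on a toy basis matrix. This is a sketch of that step only (the dispersion filter is omitted), and the matrix values are invented.

```r
# Hypothetical basis matrix: 4 genes x 3 cell types.
Basis <- matrix(c(50,  1,  1,
                   5,  5,  5,
                   1, 40,  2,
                   0,  0, 30),
                nrow = 4, byrow = TRUE,
                dimnames = list(paste0("gene", 1:4), c("A", "B", "C")))

marker_list <- lapply(colnames(Basis), function(ict) {
  # Mean expression over all other cell types.
  rest <- rowMeans(Basis[, colnames(Basis) != ict, drop = FALSE])
  # Natural-log fold change with a small pseudo-count, as in selectInfo().
  FC <- log(Basis[, ict] + 1e-06) - log(rest + 1e-06)
  rownames(Basis)[FC > 1.25 & Basis[, ict] > 0]
})
informative <- unique(unlist(marker_list))
```

gene2 is expressed equally everywhere, so it carries no information about cell type and is not selected.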
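Variables in CARD_deconvolution()
- Basis_ref$basis: the genes-by-cell-types basis matrix returned by createscRef()
- spatial_count: the genes-by-locations spatial count matrix
- norm_cords: the normalized location coordinates
Variables in createscRef()
- ct.id: the cell type of each cell
- ct.select: the names of the selected cell types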
3. The C++ function CARDref()
#include <iostream>
#include <fstream>
#define ARMA_64BIT_WORD 1
#include <RcppArmadillo.h>
// [[Rcpp::depends(RcppArmadillo)]]
#include <R.h>
#include <Rmath.h>
#include <cmath>
#include <stdio.h>
#include <stdlib.h>
#include <cstring>
#include <ctime>
#include <Rcpp.h>
// Enable C++11 via this plugin (Rcpp 0.10.3 or later)
// [[Rcpp::plugins(cpp11)]]
using namespace std;
using namespace arma;
using namespace Rcpp;
#define ARMA_DONT_PRINT_ERRORS
//*******************************************************************//
// spatially informed deconvolution:CARD //
//*******************************************************************//
//' SpatialDeconv function based on Conditional Autoregressive model
//' @param XinputIn The input of normalized spatial data
//' @param UIn The input of cell type specific basis matrix B
//' @param WIn The constructed W weight matrix from Gaussian kernel
//' @param phiIn The phi value
//' @param max_iterIn Maximum iterations
//' @param epsilonIn epsilon for convergence
//' @param initV Initial matrix of cell type compositions V
//' @param initb Initial vector of cell type specific intercept
//' @param initSigma_e2 Initial value of residual variance
//' @param initLambda Initial vector of cell type specific scalar.
//'
//' @return A list
//'
//' @export
// [[Rcpp::export]]
SEXP CARDref(SEXP XinputIn, SEXP UIn, SEXP WIn, SEXP phiIn, SEXP max_iterIn, SEXP epsilonIn, SEXP initV, SEXP initb, SEXP initSigma_e2, SEXP initLambda)
{
try {
// read in the data
arma::mat Xinput = as<mat>(XinputIn);
arma::mat U = as<mat>(UIn);
arma::mat W = as<mat>(WIn);
double phi = as<double>(phiIn);
int max_iter = Rcpp::as<int>(max_iterIn);
double epsilon = as<double>(epsilonIn);
arma::mat V = as<mat>(initV);
arma::vec b = as<vec>(initb);
double sigma_e2 = as<double>(initSigma_e2);
arma::vec lambda = as<vec>(initLambda);
// initialize some useful items
int nSample = (int)Xinput.n_cols; // number of spatial sample points
int mGene = (int)Xinput.n_rows; // number of genes in spatial deconvolution
int k = (int)U.n_cols; // number of cell type
arma::mat L = zeros<mat>(nSample,nSample);
arma::mat D = zeros<mat>(nSample,nSample);
arma::mat V_old = zeros<mat>(nSample,k);
arma::mat UtU = zeros<mat>(k,k);
arma::mat VtV = zeros<mat>(k,k);
arma::vec colsum_W = zeros<vec>(nSample);
arma::mat UtX = zeros<mat>(k,nSample);
arma::mat XtU = zeros<mat>(nSample,k);
arma::mat UtXV = zeros<mat>(k,k);
arma::mat temp = zeros<mat>(k,k);
arma::mat part1 = zeros<mat>(nSample,k);
arma::mat part2 = zeros<mat>(nSample,k);
arma::vec updateV_k = zeros<vec>(k);
arma::vec updateV_den_k = zeros<vec>(k);
arma::vec vecOne = ones<vec>( nSample);
arma::vec diag_UtU = zeros<vec>(k);
bool logicalLogL = FALSE;
double obj = 0;
double obj_old = 0;
double normNMF = 0;
double logX = 0;
double logV = 0;
double alpha = 1.0;
double beta = nSample / 2.0;
double logSigmaL2 = 0.0;
double accu_L = 0.0;
double trac_xxt = accu(Xinput % Xinput);
// initialize values
// constant matrix calculations for increasing speed
UtX = U.t() * Xinput;
XtU = UtX.t();
colsum_W = sum(W,1);
D = diagmat(colsum_W);// diagonal matrix whose entries are the row sums of W (sum(W,1) sums each row; W is symmetric, so these equal the column sums)
L = D - phi*W; // graph laplacian
accu_L = accu(L);
UtXV = UtX * V;
VtV = V.t() * V;
UtU = U.t() * U;
diag_UtU = UtU.diag();
// calculate initial objective function
normNMF = trac_xxt - 2.0 * trace(UtXV) + trace(UtU * VtV);
logX = -(double)(mGene * nSample) * 0.5 * log(sigma_e2) - 0.5 * (double)(normNMF / sigma_e2);
temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t());
logV = - (double)(nSample) * 0.5 * sum(log(lambda )) - 0.5 * (sum(temp.diag() / lambda ));
logSigmaL2 = -(alpha + 1.0) * sum(log(lambda)) - sum(beta / lambda);
obj_old = logX + logV + logSigmaL2;
V_old = V;
// iteration starts
for(int i = 1; i <= max_iter; ++i) {
logV = 0.0;
b = sum(V.t() * L, 1) / accu_L;
lambda = (temp.diag() / 2.0 + beta ) / (double(nSample) / 2.0 + alpha + 1.0);
part1 = sigma_e2 * (D * V + phi * colsum_W * b.t());
part2 = sigma_e2 * (phi * W * V + colsum_W * b.t());
for(int nCT = 0; nCT < k; ++nCT){
updateV_den_k = lambda(nCT) * (V.col(nCT) * diag_UtU(nCT) + (V * UtU.col(nCT) - V.col(nCT) * diag_UtU(nCT))) + part1.col(nCT);
updateV_k = (lambda(nCT) * XtU.col(nCT) + part2.col(nCT)) / updateV_den_k;
V.col(nCT) %= updateV_k;
}
UtXV = UtX * V;
VtV = V.t() * V;
normNMF = trac_xxt - 2.0 * trace(UtXV) + trace(UtU * VtV);
sigma_e2 = normNMF / (double)(mGene * nSample);
temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t());
logX = -(double)(nSample * mGene) * 0.5 * log(sigma_e2) - 0.5 * (double)(normNMF / sigma_e2);
logV = -(double)(nSample) * 0.5 * sum(log(lambda))- 0.5 * (sum(temp.diag() / lambda ));
logSigmaL2 = -(alpha + 1.0) * sum(log(lambda)) - sum(beta / lambda);
obj = logX + logV + logSigmaL2;
logicalLogL = (obj > obj_old) && (abs(obj - obj_old) * 2.0 / abs(obj + obj_old) < epsilon);
if(isnan(obj) || (sqrt(accu((V - V_old) % (V - V_old)) / double(nSample * k)) < epsilon) || logicalLogL){
if(i > 5){ // run at least 5 iterations
break;
}
}else{
obj_old = obj;
V_old = V;
}
}
return List::create(Named("V") = V,
Named("sigma_e2") = sigma_e2,
Named("lambda") = lambda,
Named("b") = b,
Named("Obj") = obj);
}//end try
catch (std::exception &ex)
{
forward_exception_to_r(ex);
}
catch (...)
{
::Rf_error("C++ exception (unknown reason)...");
}
return R_NilValue;
} // end funcs
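Variables in the C++ code, and how they relate to Formula 3-1:
- trac_xxt = accu(Xinput % Xinput); is the sum of the squared entries of X, that is tr(X^T X).
- 2.0 * trace(UtXV) is the cross term 2 tr(B^T X V). By the rules of matrix multiplication, the trace of the product expresses the double sum over genes and locations; UtX = U.t() * Xinput; is B^T X and UtXV = UtX * V; is B^T X V.
- trace(UtU * VtV) is tr(B^T B V^T V), again a sum written in matrix form; UtU = U.t() * U; is B^T B and VtV = V.t() * V; is V^T V.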
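The three trace terms add up to the squared residual of the regression, tr(XᵀX) − 2 tr(BᵀXV) + tr(BᵀB VᵀV) = ‖X − BVᵀ‖². A quick numerical check on random toy matrices (sizes chosen arbitrarily) makes this concrete:

```r
set.seed(2)
G <- 6; N <- 5; K <- 3                      # toy sizes
X <- matrix(runif(G * N), G, N)             # G-by-N spatial matrix
B <- matrix(runif(G * K), G, K)             # G-by-K basis (U in the C++ code)
V <- matrix(runif(N * K), N, K)             # N-by-K proportions

trac_xxt <- sum(X * X)                      # accu(Xinput % Xinput)
UtXV <- t(B) %*% X %*% V
UtU <- t(B) %*% B
VtV <- t(V) %*% V
normNMF <- trac_xxt - 2 * sum(diag(UtXV)) + sum(diag(UtU %*% VtV))
direct <- sum((X - B %*% t(V))^2)           # squared Frobenius norm of the residual
```

Working with the K-by-K matrices UtU and VtV avoids forming the G-by-N residual in every iteration, which is presumably why the C++ code uses the trace form.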
In more detail:
- Initializing the matrices:
UtX = U.t() * Xinput; XtU = UtX.t(); colsum_W = sum(W,1); D = diagmat(colsum_W); L = D - phi*W; accu_L = accu(L); UtXV = UtX * V; VtV = V.t() * V; UtU = U.t() * U; diag_UtU = UtU.diag();
- Constructing the likelihood; each statement corresponds to one part of Formula 3-1:
normNMF = trac_xxt - 2.0 * trace(UtXV) + trace(UtU * VtV); is the squared residual of the regression X = BV^T + E.
logX = -(double)(mGene * nSample) * 0.5 * log(sigma_e2) - 0.5 * (double)(normNMF / sigma_e2); is the Gaussian log-likelihood of X given V and sigma_e2.
temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t()); is a K-by-K matrix whose kth diagonal entry is the quadratic form (V_k - b_k 1)^T L (V_k - b_k 1). The sign of the difference can look reversed relative to Formula 3-1, but because this is a quadratic form, using b_k 1 - V_k instead of V_k - b_k 1 gives the same value.
logV = - (double)(nSample) * 0.5 * sum(log(lambda )) - 0.5 * (sum(temp.diag() / lambda )); is the log-density of the spatial term on V.
logSigmaL2 = -(alpha + 1.0) * sum(log(lambda)) - sum(beta / lambda); has the form of an inverse-gamma log-density on lambda with parameters alpha and beta.
obj_old = logX + logV + logSigmaL2; adds the three parts together, and V_old = V; stores the current V.
- Stopping rule: after at least 5 iterations, the loop stops as soon as obj is NaN, or the root-mean-square change of V is smaller than epsilon, or logicalLogL is true (obj increased and its relative change is smaller than epsilon):
if(isnan(obj) || (sqrt(accu((V - V_old) % (V - V_old)) / double(nSample * k)) < epsilon) || logicalLogL){ if(i > 5){ // run at least 5 iterations break; }
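- Updating V in each iteration:
for(int nCT = 0; nCT < k; ++nCT){
// nCT indexes cell type k; V.col(nCT) extracts column nCT of V
updateV_den_k = lambda(nCT) * (V.col(nCT) * diag_UtU(nCT) + (V * UtU.col(nCT) - V.col(nCT) * diag_UtU(nCT))) + part1.col(nCT);
// element-wise update ratio for this column
updateV_k = (lambda(nCT) * XtU.col(nCT) + part2.col(nCT)) / updateV_den_k;
// multiply column nCT of V element-wise by updateV_k (%= is element-wise multiplication in Armadillo, not division)
V.col(nCT) %= updateV_k;
}
- Updating lambda: temp is recomputed from the updated V at the end of an iteration, and lambda is refreshed from it at the start of the next one:
temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t()); lambda = (temp.diag() / 2.0 + beta ) / (double(nSample) / 2.0 + alpha + 1.0);
V is then updated again with the new lambda, and the cycle repeats until the stopping rule is met.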
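To see the update cycle as a whole, here is a rough R transcription of one pass through the C++ loop (update b, update lambda, then update V column by column). It is a sketch for reading purposes under toy inputs, not a replacement for CARDref(); the function name card_iteration and all toy data are assumptions.

```r
card_iteration <- function(V, b_prev, X, U, W, phi, sigma_e2, alpha, beta) {
  N <- nrow(V)
  d <- rowSums(W)
  L <- diag(d) - phi * W
  UtU <- crossprod(U)                       # t(U) %*% U
  XtU <- crossprod(X, U)                    # t(X) %*% U
  # lambda uses the quadratic forms of the current V around the previous b
  Vc <- sweep(V, 2, b_prev, "-")
  quad <- diag(t(Vc) %*% L %*% Vc)
  lambda <- (quad / 2 + beta) / (N / 2 + alpha + 1)
  b <- as.vector(t(V) %*% L %*% rep(1, N)) / sum(L)
  part1 <- sigma_e2 * (diag(d) %*% V + phi * d %o% b)
  part2 <- sigma_e2 * (phi * W %*% V + d %o% b)
  for (k in seq_len(ncol(V))) {             # columns are updated one after another
    den <- lambda[k] * (V %*% UtU[, k]) + part1[, k]
    num <- lambda[k] * XtU[, k] + part2[, k]
    V[, k] <- V[, k] * as.vector(num / den) # multiplicative, element-wise update
  }
  list(V = V, b = b, lambda = lambda)
}

set.seed(3)
G <- 20; N <- 6; K <- 3                                  # toy sizes (hypothetical)
U <- matrix(runif(G * K), G, K)
V0 <- matrix(runif(N * K), N, K); V0 <- V0 / rowSums(V0)
X <- U %*% t(V0) + matrix(runif(G * N, 0, 0.01), G, N)
W <- exp(-as.matrix(dist(matrix(runif(N * 2), N, 2)))^2 / (2 * 0.1^2)); diag(W) <- 0
res <- card_iteration(V0, b_prev = rep(0, K), X, U, W,
                      phi = 0.5, sigma_e2 = 0.1, alpha = 1, beta = N / 2)
rms_change <- sqrt(sum((res$V - V0)^2) / (N * K))        # quantity compared with epsilon
```

Because every factor in the update ratio is non-negative, V stays non-negative without any explicit projection, which is the usual motivation for multiplicative updates of this kind.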