CARD: Deconvolution of Spatial Transcriptomics Using Single-Cell Transcriptomics

Author: 小潤澤 | Published: 2023-02-22 16:29

    Reference: "Spatially informed cell-type deconvolution for spatial transcriptomics"

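    How the method works

    To date, most deconvolution methods for estimating cell-type composition are built on regression, and CARD follows the same regression strategy: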

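    where:

    1. B: the G-by-K cell-type-specific expression matrix for the informative genes (the reference basis built from the single-cell data)
    2. X: the G-by-N gene expression matrix for the same set of informative genes measured on N spatial locations (the spatial transcriptomics expression matrix)
    3. V: the N-by-K cell-type composition matrix (the cell-type composition at each spatial location)
    4. E: the G-by-N residual error matrix, whose entries Egi are assumed to be normally distributed with mean 0 and variance σe2 (the sigma_e2 estimated in the C++ code further down)
    5. B, X and V are all non-negative matrices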
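
    Putting the symbols in the list together, the regression can be written as below. This is reconstructed from the listed dimensions and from the residual term that the C++ code computes (normNMF), so the notation is an assumption rather than a quotation from the paper:

```latex
X = B V^{\top} + E, \qquad E_{gi} \sim \mathcal{N}\!\left(0, \sigma_e^{2}\right)
```

    With X of size G-by-N, B of size G-by-K and V of size N-by-K, the product B V^T is G-by-N and therefore matches X entry by entry.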

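    The authors also take into account that two neighboring locations are likely to have similar cell-type compositions, and add a correction for this. The correction borrows the idea of regression (a conditional autoregressive model): it sets up a linear relationship between Vik and Vjk, so that the composition at location i is influenced by the composition at neighboring locations (in other words, the composition at location i is not a free term):

    where:

    1. Vik: represents the proportion of cell type k on the ith location
    2. bk: is the kth cell-type-specific intercept that represents the average cell-type composition across locations. It plays the role of the intercept in a regression, and stacking bk over all cell types gives a column vector of length K
    3. W: an N-by-N non-negative weight matrix with each element Wij specifying the weight used for inferring the cell-type composition on the ith location based on the cell-type composition information on the jth location. It plays a role similar to the regression coefficients
    4. ϕ: is a spatial autocorrelation parameter that determines the strength of the spatial correlation in cell-type composition, i.e. how strongly the neighboring locations pull on Vik
    5. Vjk: the kth cell-type compositions on all other locations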
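
    Written out with the symbols from the list, the spatial model would read roughly as follows. This is a reconstruction from those definitions; in particular, whether W is row-normalized inside the sum is an assumption that cannot be checked from the text here:

```latex
V_{ik} = b_k + \phi \sum_{j \neq i} W_{ij}\,\bigl(V_{jk} - b_k\bigr) + \epsilon_{ik}
```

    Read this as a regression of Vik on its neighbors: bk is the intercept, the Wij are the weights, ϕ scales the whole neighbor term, and εik is the residual.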

    Our goal is to infer the matrix V, with entries Vik, i.e. the cell-type composition at every location.

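    Statistical inference:
    By the lemma:

    [Lemma]
    Assume that εik in the expression below follows a normal distribution. Substituting εik into the lemma yields Equation 1. In essence, Equation 1 looks for a V that makes the residuals εik as small as possible:
    [Equation 1]
    The whole inference task therefore turns into an optimization problem: find the V for which the residuals εik are as small as possible.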

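    After simplifying Equation 1:


    Here the rules of matrix multiplication are used to turn the summation (∑) into a product of matrices, where:



    After factoring out the common term:

    Next, construct the likelihood function and estimate the parameters by maximum likelihood:
    First, the authors define the covariance matrix as follows:


    Second, the likelihood values of all terms are multiplied together to form the likelihood function (k indexes the cell types, i indexes the locations):
    This likelihood has four groups of parameters to estimate, V, λk, σe2 and bk, and it optimizes two regressions jointly: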


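    Combining the two gives the joint likelihood, which is largest when the residuals Egi and εik concentrate around 0. λk is a cell-type-specific scale parameter; in the code it carries its own penalty term (logSigmaL2), so it behaves much like a regularization term.
    [Equation 2]

    The problem is first written as a minimization, and the authors then convert it into an equivalent maximization (minimizing the negative log-likelihood is the same as maximizing the log-likelihood):
    [Equation 3-1]
    [Equation 3-2]

    Finally, maximum likelihood gives the optimal estimate of the matrix V. Here G is the number of genes and N is the number of locations.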
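
    For orientation, the quantity that the C++ function CARDref() (listed further down) actually maximizes can be read off its lines for logX, logV and logSigmaL2. The expression below is transcribed from that code, not from the paper, so its sign convention and the symbols α and β follow the code (where alpha = 1 and beta = N/2):

```latex
\begin{aligned}
\mathrm{obj} ={}& -\frac{GN}{2}\log\sigma_e^{2}
   - \frac{1}{2\sigma_e^{2}}\,\lVert X - BV^{\top}\rVert_F^{2} \\
 & - \sum_{k=1}^{K}\left[\frac{N}{2}\log\lambda_k
   + \frac{1}{2\lambda_k}\,(V_{\cdot k} - b_k\mathbf{1})^{\top} L\,(V_{\cdot k} - b_k\mathbf{1})\right] \\
 & - \sum_{k=1}^{K}\left[(\alpha + 1)\log\lambda_k + \frac{\beta}{\lambda_k}\right],
 \qquad L = D - \phi W
\end{aligned}
```

    Here V_{·k} is column k of V, 1 is the all-ones vector of length N, and D is the diagonal matrix holding the row sums of W.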

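    Code walkthrough

    First, download the example data:

    1. sc_count.RData: the single-cell expression matrix
    2. sc_meta.RData: the single-cell metadata
    3. spatial_count.RData: the spatial transcriptomics expression matrix, spatial_count. Rows are genes, columns are locations, and each entry is the expression of a gene at a location
    4. spatial_location.RData: the spatial coordinates, spatial_location. x and y are two-dimensional coordinates. The spatial transcriptomics slide can be thought of as a plane on which spots are laid out on a rectangular grid, so 10×10 refers to the spot in row 10, column 10 of the plane. Each spot behaves like a small bulk-seq sample and is called a location

    The workflow given for the example data is as follows: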

    # Load the data
    load("C:/Users/lenovo/Downloads/spatial_count.RData")
    load("C:/Users/lenovo/Downloads/spatial_location.RData")
    load("C:/Users/lenovo/Downloads/sc_count.RData")
    load("C:/Users/lenovo/Downloads/sc_meta.RData")
    
    library(CARD)
    
    CARD_obj = createCARDObject(
      sc_count = sc_count,
      sc_meta = sc_meta,
      spatial_count = spatial_count,
      spatial_location = spatial_location,
      ct.varname = "cellType",
      ct.select = unique(sc_meta$cellType),
      sample.varname = "sampleInfo",
      minCountGene = 100,
      minCountSpot = 5) 
    
    CARD_obj = CARD_deconvolution(CARD_object = CARD_obj)
    # Cell-type proportion matrix
    CARD_obj@Proportion_CARD
    

    Cell-type proportion matrix: CARD_obj@Proportion_CARD
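
    Each row of Proportion_CARD is one location and sums to 1, because CARD_deconvolution() (dissected further down) divides every row of the fitted V by its row sum with sweep(OptimalRes$V,1,rowSums(OptimalRes$V),"/"). The standalone C++ sketch below shows that row normalization on a toy matrix. The function name normalize_rows and the handling of all-zero rows are illustrative assumptions, not part of CARD:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // row-major: one row per location

// Divide every row by its own sum so that each row adds up to 1.
// Rows whose sum is 0 are left unchanged (an illustrative choice;
// the R expression would yield NaN for such a row).
Matrix normalize_rows(Matrix v) {
    for (std::vector<double>& row : v) {
        double total = 0.0;
        for (double value : row) {
            assert(value >= 0.0);  // compositions are expected to be non-negative
            total += value;
        }
        if (total == 0.0) continue;
        for (double& value : row) value /= total;
    }
    return v;
}
```

    For example, a row holding 1 and 3 becomes 0.25 and 0.75, which is how the raw entries of V turn into proportions.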

    1. Dissecting the createCARDObject() function

    # Set the inputs
    sc_count = sc_count
    sc_meta = sc_meta
    spatial_count = spatial_count
    spatial_location = spatial_location
    ct.varname = "cellType"
    ct.select = unique(sc_meta$cellType)
    sample.varname = "sampleInfo"
    minCountGene = 100
    minCountSpot = 5
    
    # Step 1: quality control on the single-cell data
    sc_countMat  <- sc_count
    ct.select <- as.character(ct.select[!is.na(ct.select)])
    sc_eset = sc_QC(sc_countMat,sc_meta,ct.varname,ct.select,sample.varname)
    #### Check the spatial count dataset
    #### QC on spatial dataset
    spatial_countMat <- spatial_count
    commonGene = intersect(rownames(spatial_countMat),rownames(assays(sc_eset)$counts))
    
    # Step 2: filter the spatial data
    #### QC on spatial dataset
    spatial_countMat = spatial_countMat[rowSums(spatial_countMat > 0) > minCountSpot,]
    spatial_countMat = spatial_countMat[,(colSums(spatial_countMat) >= minCountGene & colSums(spatial_countMat) <= 1e6)]
    spatial_location = spatial_location[rownames(spatial_location) %in% colnames(spatial_countMat),]
    spatial_location = spatial_location[match(colnames(spatial_countMat),rownames(spatial_location)),]
    
    object <- new(
      Class = "CARD",
     # Single-cell expression data after QC, with cell-type annotations
      sc_eset = sc_eset,
      spatial_countMat = spatial_countMat,
      spatial_location = spatial_location,
      project = "Deconvolution",
      info_parameters = list(ct.varname = ct.varname,ct.select = ct.select,sample.varname = sample.varname)
    )
    return(object)
    
    
    ## sc_QC filters out lowly expressed genes and low-quality cells
    sc_QC <- function(counts_in,metaData,ct.varname,ct.select,sample.varname = NULL, min.cells = 0,min.genes = 0){
    # Filter based on min.features
        coldf = metaData
        counts = counts_in
        if (min.genes >= 0) {
            nfeatures <- colSums(x = counts )
            counts <- counts[, which(x = nfeatures > min.genes)]
            coldf <- coldf[which(x = nfeatures > min.genes),]
        }
        # filter genes on the number of cells expressing
        if (min.cells >= 0) {
            num.cells <- rowSums(x = counts > 0)
            counts <- counts[which(x = num.cells > min.cells), ]
        }
        fdata = as.data.frame(rownames(counts))
        rownames(fdata) = rownames(counts)
        keepCell = as.character(coldf[,ct.varname]) %in% ct.select
        counts = counts[,keepCell]
        coldf = coldf[keepCell,]
        keepGene = rowSums(counts) > 0
        fdata = as.data.frame(fdata[keepGene,])
        counts = counts[keepGene,]
        sce <- SingleCellExperiment(list(counts=counts),
        colData=as.data.frame(coldf),
        rowData=as.data.frame(fdata))
        return(sce)
    }
    

    2. Dissecting the CARD_deconvolution() function

    # Take the object returned by createCARDObject()
    CARD_object = CARD_obj
    
    # ct.select holds the names of the cell types to use
    ct.select = CARD_object@info_parameters$ct.select
    # ct.varname is the string "cellType"
    ct.varname = CARD_object@info_parameters$ct.varname
    sample.varname = CARD_object@info_parameters$sample.varname
    
    # sc_eset holds the single-cell data; counts(sc_eset) shows the count matrix
    sc_eset = CARD_object@sc_eset
    
    # Build the cell-type-specific reference (basis) matrix from the single-cell counts
    Basis_ref = createscRef(sc_eset, ct.select, ct.varname, sample.varname)
    
    Basis = Basis_ref$basis
    Basis = Basis[,colnames(Basis) %in% ct.select]
    Basis = Basis[,match(ct.select,colnames(Basis))]
    # Get the spatial transcriptomics count matrix spatial_count
    spatial_count = CARD_object@spatial_countMat
    commonGene = intersect(rownames(spatial_count),rownames(Basis))
    #### remove mitochondrial genes
    #### (the pattern below only matches gene names containing "mt-")
    commonGene  = commonGene[!(commonGene %in% commonGene[grep("mt-",commonGene)])]
    
    common = selectInfo(Basis,sc_eset,commonGene,ct.select,ct.varname)
    # Spatial expression matrix Xinput
    Xinput = spatial_count
    # Reference basis matrix B
    B = Basis
    
    ##### match the common gene names
    ##### Restrict both matrices to the genes in common for the downstream analysis
    Xinput = Xinput[order(rownames(Xinput)),]
    B = B[order(rownames(B)),]
    B = B[rownames(B) %in% common,]
    Xinput = Xinput[rownames(Xinput) %in% common,]
    
    ##### filter out non expressed genes or cells again
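    ##### Drop genes and locations with no expression from the spatial matrix
    Xinput = Xinput[rowSums(Xinput) > 0,]
    Xinput = Xinput[,colSums(Xinput) > 0]
    
    ##### normalize count data
    ##### Normalize the spatial expression matrix
    colsumvec = colSums(Xinput)
    ### i.e. divide each gene's count by the total count (sequencing depth) of its location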
    Xinput_norm = sweep(Xinput,2,colsumvec,"/")
    B = B[rownames(B) %in% rownames(Xinput_norm),]    
    B = B[match(rownames(Xinput_norm),rownames(B)),]
    
    #### spatial location
    #### Get the spatial coordinates
    spatial_location = CARD_object@spatial_location
    spatial_location = spatial_location[rownames(spatial_location) %in% colnames(Xinput_norm),]
    spatial_location = spatial_location[match(colnames(Xinput_norm),rownames(spatial_location)),]
    
    ##### normalize the coordinates without changing the shape and relative position
    ### Rescale the coordinates into relative positions
    norm_cords = spatial_location[ ,c("x","y")]
    norm_cords$x = norm_cords$x - min(norm_cords$x)
    norm_cords$y = norm_cords$y - min(norm_cords$y)
    scaleFactor = max(norm_cords$x,norm_cords$y)
    norm_cords$x = norm_cords$x / scaleFactor
    norm_cords$y = norm_cords$y / scaleFactor
    
    ##### initialize the proportion matrix
    ### Compute the Euclidean distances between the spatial locations
    ED <- rdist::rdist(as.matrix(norm_cords))##Euclidean distance matrix
    
    set.seed(20200107)
    Vint1 = as.matrix(gtools::rdirichlet(ncol(Xinput_norm), rep(10,ncol(B))))
    colnames(Vint1) = colnames(B)
    rownames(Vint1) = colnames(Xinput_norm)
    b = rep(0,length(ct.select))
    
    ###### parameters that need to be set
    isigma = 0.1 ####construct Gaussian kernel with the default scale /length parameter to be 0.1
    epsilon = 1e-04  #### convergence epsion 
    phi = c(0.01,0.1,0.3,0.5,0.7,0.9,0.99) #### grided values for phi
    
    ## Build the weight matrix W from the distances with a Gaussian kernel
    kernel_mat <- exp(-ED^2 / (2 * isigma^2))
    diag(kernel_mat) <- 0
    
    ###### scale the Xinput_norm and B to speed up the convergence. 
    mean_X = mean(Xinput_norm)
    mean_B = mean(B)
    Xinput_norm = Xinput_norm * 1e-01 / mean_X
    B = B * 1e-01 / mean_B
    ResList = list()
    Obj = c()
    ## Fit the model once for each candidate value of phi
    for(iphi in 1:length(phi)){
      res = CARDref(
        XinputIn = as.matrix(Xinput_norm),
        UIn = as.matrix(B),
        WIn = kernel_mat, 
        phiIn = phi[iphi],
        max_iterIn =1000,
        epsilonIn = epsilon,
        initV = Vint1,
        initb = rep(0,ncol(B)),
        initSigma_e2 = 0.1, 
        initLambda = rep(10,length(ct.select)))
      rownames(res$V) = colnames(Xinput_norm)
      colnames(res$V) = colnames(B)
      ResList[[iphi]] = res
      Obj = c(Obj,res$Obj)
    }
    
    ## Select the fit with the largest objective value
    Optimal = which(Obj == max(Obj))
    Optimal = Optimal[length(Optimal)] #### just in case if there are two equal objective function values
    OptimalPhi = phi[Optimal]
    OptimalRes = ResList[[Optimal]]
    cat(paste0("## Deconvolution Finish! ...\n"))
    CARD_object@info_parameters$phi = OptimalPhi
    
    ### Obtain the cell-type proportion matrix Proportion_CARD (each row of V rescaled to sum to 1)
    CARD_object@Proportion_CARD = sweep(OptimalRes$V,1,rowSums(OptimalRes$V),"/")
    CARD_object@algorithm_matrix = list(B = B * mean_B / 1e-01, Xinput_norm = Xinput_norm * mean_X / 1e-01, Res = OptimalRes)
    CARD_object@spatial_location = spatial_location
    
    
    ################################## The createscRef() function
    # Inputs
    x = sc_eset # the single-cell data object
    ct.select = ct.select # names of the cell types to use
    ct.varname = ct.varname # ct.varname is the string "cellType"
    sample.varname = sample.varname
    
    # Definition of createscRef()
    createscRef <- function(x, ct.select = NULL, ct.varname, sample.varname = NULL){
      library(MuSiC)
      if (is.null(ct.select)) {
        ct.select <- unique(colData(x)[, ct.varname])
      }
      # Drop cell types that are NA
      ct.select <- ct.select[!is.na(ct.select)]
      # countMat <- as.matrix(assays(x)$counts)
      # Extract the single-cell count matrix into countMat
      countMat <- as(SummarizedExperiment::assays(x)$counts,"sparseMatrix")
      # ct.id assigns each cell its cell type
      ct.id <- droplevels(as.factor(SummarizedExperiment::colData(x)[, ct.varname]))
      #if(length(unique(colData(x)[,sample.varname])) > 1){
      if(is.null(sample.varname)){
        SummarizedExperiment::colData(x)$sampleID = "Sample"
        sample.varname = "sampleID"
      }
      sample.id <- as.character(SummarizedExperiment::colData(x)[, sample.varname])
      ct_sample.id <- paste(ct.id, sample.id, sep = "$*$")
      colSums_countMat <- colSums(countMat)
      colSums_countMat_Ct = aggregate(colSums_countMat ~ ct.id + sample.id, FUN = 'sum')
      colSums_countMat_Ct_wide = reshape(colSums_countMat_Ct, idvar = "sample.id", timevar = "ct.id", direction = "wide")
      colnames(colSums_countMat_Ct_wide) = gsub("colSums_countMat.","",colnames(colSums_countMat_Ct_wide))
      rownames(colSums_countMat_Ct_wide) = colSums_countMat_Ct_wide$sample.id
      colSums_countMat_Ct_wide$sample.id <- NULL
      tbl <- table(sample.id,ct.id)
      colSums_countMat_Ct_wide = colSums_countMat_Ct_wide[,match(colnames(tbl),colnames(colSums_countMat_Ct_wide))]
      colSums_countMat_Ct_wide = colSums_countMat_Ct_wide[match(rownames(tbl),rownames(colSums_countMat_Ct_wide)),]
      S_JK <- colSums_countMat_Ct_wide / tbl
      S_JK <- as.matrix(S_JK)
      S_JK[S_JK == 0] = NA
      S_JK[!is.finite(S_JK)] = NA
      S = colMeans(S_JK, na.rm = TRUE)
      S = S[match(unique(ct.id),names(S))]
      library("wrMisc")
      if(nrow(countMat) > 10000 & ncol(countMat) > 50000){ ### to save memory 
        seqID = seq(1,nrow(countMat),by = 10000)
        Theta_S_rowMean = NULL
        for(igs in seqID){
          if(igs != seqID[length(seqID)]){
            Theta_S_rowMean_Tmp <- rowGrpMeans(as.matrix(countMat[(igs:(igs+9999)),]), grp = ct_sample.id, na.rm = TRUE)
          }else{
            Theta_S_rowMean_Tmp <- rowGrpMeans(as.matrix(countMat[igs:nrow(countMat),]), grp = ct_sample.id, na.rm = TRUE)
            
          }
          Theta_S_rowMean <- rbind(Theta_S_rowMean,Theta_S_rowMean_Tmp)
          
        }
      }else{
        Theta_S_rowMean <- rowGrpMeans(as.matrix(countMat), grp = ct_sample.id, na.rm = TRUE)
      }
      tbl_sample = table(ct_sample.id)
      tbl_sample = tbl_sample[match(colnames(Theta_S_rowMean),names(tbl_sample))]
      Theta_S_rowSums <- sweep(Theta_S_rowMean,2,tbl_sample,"*")
      Theta_S <- sweep(Theta_S_rowSums,2,colSums(Theta_S_rowSums),"/")
      grp <- sapply(strsplit(colnames(Theta_S),split="$*$",fixed = TRUE),"[",1)
      Theta = rowGrpMeans(Theta_S, grp = grp, na.rm = TRUE)
      Theta = Theta[,match(unique(ct.id),colnames(Theta))]
      S = S[match(colnames(Theta),names(S))]
      basis = sweep(Theta,2,S,"*")
      colnames(basis) = colnames(Theta)
      rownames(basis) = rownames(Theta)
      return(list(basis = basis))
    }
    
    
    
    ################################## The selectInfo() function
    selectInfo <- function(Basis,sc_eset,commonGene,ct.select,ct.varname){
    #### log fold change versus the mean of the other cell types > 1.25 (natural log, as coded below)
    gene1 = lapply(ct.select,function(ict){
    rest = rowMeans(Basis[,colnames(Basis) != ict])
    FC = log((Basis[,ict] + 1e-06)) - log((rest + 1e-06))
    rownames(Basis)[FC > 1.25 & Basis[,ict] > 0]
    })
    gene1 = unique(unlist(gene1))
    gene1 = intersect(gene1,commonGene)
    counts = assays(sc_eset)$counts
    counts = counts[rownames(counts) %in% gene1,]
    ##### only check the cell type that contains at least 2 cells
    ct.select = names(table(colData(sc_eset)[,ct.varname]))[table(colData(sc_eset)[,ct.varname]) > 1]
    sd_within = sapply(ct.select,function(ict){
      temp = counts[,colData(sc_eset)[,ct.varname] == ict]
      apply(temp,1,var) / apply(temp,1,mean)
      })
    ##### remove the outliers that have high dispersion across cell types
    gene2 = rownames(sd_within)[apply(sd_within,1,mean,na.rm = T) < quantile(apply(sd_within,1,mean,na.rm = T),prob = 0.99,na.rm = T)]
    return(gene2)
    }
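
    The weight matrix kernel_mat in CARD_deconvolution() is fully determined by the coordinates: they are shifted and rescaled (norm_cords), pairwise Euclidean distances are taken, and exp(-ED^2 / (2 * isigma^2)) is applied with the diagonal set to 0. A minimal standalone C++ sketch of those two steps is given below. The names Point, normalize_coords and gaussian_kernel are hypothetical, and the assertions are added guards that the R code does not have:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct Point {
    double x;
    double y;
};

// Shift the coordinates so they start at 0, then divide both axes by the
// same factor (the largest shifted coordinate), as done for norm_cords.
std::vector<Point> normalize_coords(std::vector<Point> pts) {
    assert(!pts.empty());
    double min_x = pts[0].x;
    double min_y = pts[0].y;
    for (const Point& p : pts) {
        min_x = std::min(min_x, p.x);
        min_y = std::min(min_y, p.y);
    }
    double scale = 0.0;
    for (Point& p : pts) {
        p.x -= min_x;
        p.y -= min_y;
        scale = std::max({scale, p.x, p.y});
    }
    assert(scale > 0.0);  // at least two distinct locations are needed
    for (Point& p : pts) {
        p.x /= scale;
        p.y /= scale;
    }
    return pts;
}

// W(i, j) = exp(-d_ij^2 / (2 * sigma^2)) for i != j, and 0 on the diagonal,
// where d_ij is the Euclidean distance between locations i and j.
std::vector<std::vector<double>> gaussian_kernel(const std::vector<Point>& pts,
                                                 double sigma) {
    const std::size_t n = pts.size();
    std::vector<std::vector<double>> w(n, std::vector<double>(n, 0.0));
    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            if (i == j) continue;
            const double dx = pts[i].x - pts[j].x;
            const double dy = pts[i].y - pts[j].y;
            w[i][j] = std::exp(-(dx * dx + dy * dy) / (2.0 * sigma * sigma));
        }
    }
    return w;
}
```

    With sigma = 0.1 on coordinates scaled to the unit square, the weights decay very quickly with distance, so in practice only close neighbors contribute noticeably to the spatial term.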
    

    Variables in CARD_deconvolution()

    1. Basis_ref$basis
    2. spatial_count
    3. The normalized coordinates, norm_cords

    Variables in createscRef()

    1. ct.id
    2. ct.select

    3. The C++ function CARDref()

    #include <iostream>
    #include <fstream>
    #define ARMA_64BIT_WORD 1
    #include <RcppArmadillo.h>
    // [[Rcpp::depends(RcppArmadillo)]]
    
    #include <R.h>
    #include <Rmath.h>
    #include <cmath>
    #include <stdio.h>
    #include <stdlib.h>
    #include <cstring>
    #include <ctime>
    #include <Rcpp.h>
    
    // Enable C++11 via this plugin (Rcpp 0.10.3 or later)
    // [[Rcpp::plugins(cpp11)]]
    
    using namespace std;
    using namespace arma;
    using namespace Rcpp;
    
    #define ARMA_DONT_PRINT_ERRORS
    
    
    //*******************************************************************//
    //              spatially informed deconvolution:CARD                        //
    //*******************************************************************//
    //' SpatialDeconv function based on Conditional Autoregressive model
    //' @param XinputIn The input of normalized spatial data
    //' @param UIn The input of cell type specific basis matrix B
    //' @param WIn The constructed W weight matrix from Gaussian kernel
    //' @param phiIn The phi value
    //' @param max_iterIn Maximum iterations
    //' @param epsilonIn epsilon for convergence 
    //' @param initV Initial matrix of cell type compositions V
    //' @param initb Initial vector of cell type specific intercept
    //' @param initSigma_e2 Initial value of residual variance
    //' @param initLambda Initial vector of cell type sepcific scalar. 
    //'
    //' @return A list
    //'
    //' @export
    // [[Rcpp::export]]
    SEXP CARDref(SEXP XinputIn, SEXP UIn, SEXP WIn, SEXP phiIn, SEXP max_iterIn, SEXP epsilonIn, SEXP initV, SEXP initb, SEXP initSigma_e2, SEXP initLambda)
    {    
        try {
            // read in the data
            arma::mat Xinput = as<mat>(XinputIn);
            arma::mat U = as<mat>(UIn);
            arma::mat W = as<mat>(WIn);
            double phi = as<double>(phiIn);
            int max_iter = Rcpp::as<int>(max_iterIn);
            double epsilon = as<double>(epsilonIn);
            arma::mat V = as<mat>(initV);
            arma::vec b = as<vec>(initb);
            double sigma_e2 = as<double>(initSigma_e2);
            arma::vec lambda = as<vec>(initLambda);
            // initialize some useful items
            int nSample = (int)Xinput.n_cols; // number of spatial sample points
            int mGene = (int)Xinput.n_rows; // number of genes in spatial deconvolution
            int k = (int)U.n_cols; // number of cell type
            arma::mat L = zeros<mat>(nSample,nSample);
            arma::mat D = zeros<mat>(nSample,nSample);
            arma::mat V_old = zeros<mat>(nSample,k);
            arma::mat UtU = zeros<mat>(k,k);
            arma::mat VtV = zeros<mat>(k,k);
            arma::vec colsum_W = zeros<vec>(nSample);
            arma::mat UtX = zeros<mat>(k,nSample);
            arma::mat XtU = zeros<mat>(nSample,k);
            arma::mat UtXV = zeros<mat>(k,k);
            arma::mat temp = zeros<mat>(k,k);
            arma::mat part1 = zeros<mat>(nSample,k);
            arma::mat part2 = zeros<mat>(nSample,k);
            arma::vec updateV_k = zeros<vec>(k);
            arma::vec updateV_den_k = zeros<vec>(k);
            arma::vec vecOne = ones<vec>( nSample);
            arma::vec diag_UtU = zeros<vec>(k);
            bool logicalLogL = FALSE;
            double obj = 0;
            double obj_old = 0;
            double normNMF = 0;
            double logX = 0;
            double logV = 0;
            double alpha = 1.0;
            double beta = nSample / 2.0;
            double logSigmaL2 = 0.0;
            double accu_L = 0.0;
            double trac_xxt = accu(Xinput % Xinput);
            
            // initialize values
            // constant matrix caculations for increasing speed 
            UtX = U.t() * Xinput;
            XtU = UtX.t();
            colsum_W = sum(W,1);
            D =  diagmat(colsum_W);// diagonal matrix whose entries are the row sums of W
            L = D -  phi*W; // graph laplacian
            accu_L = accu(L);
            UtXV = UtX * V;
            VtV = V.t() * V;
            UtU = U.t() * U;
            diag_UtU = UtU.diag();
            // calculate initial objective function 
            normNMF = trac_xxt - 2.0 * trace(UtXV) + trace(UtU * VtV);
            logX = -(double)(mGene * nSample) * 0.5 * log(sigma_e2) - 0.5 * (double)(normNMF / sigma_e2);
            temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t());
            logV = - (double)(nSample) * 0.5 * sum(log(lambda )) - 0.5 * (sum(temp.diag() / lambda )); 
            logSigmaL2 = -(alpha + 1.0) * sum(log(lambda)) - sum(beta / lambda);
            obj_old = logX + logV + logSigmaL2;
            V_old = V;
            // iteration starts
            for(int i = 1; i <= max_iter; ++i) {
                logV = 0.0;  
                b = sum(V.t() * L, 1) / accu_L;
                lambda = (temp.diag() / 2.0 + beta ) / (double(nSample) / 2.0 + alpha + 1.0);  
                part1 = sigma_e2 * (D * V + phi * colsum_W * b.t());
                part2 = sigma_e2 * (phi * W * V + colsum_W * b.t());
                for(int nCT = 0; nCT < k; ++nCT){
                    updateV_den_k = lambda(nCT) * (V.col(nCT) * diag_UtU(nCT) + (V * UtU.col(nCT) - V.col(nCT) * diag_UtU(nCT))) +  part1.col(nCT);
                    updateV_k = (lambda(nCT) * XtU.col(nCT) + part2.col(nCT)) / updateV_den_k;
                    V.col(nCT) %= updateV_k;
                }
                UtXV = UtX * V;
                VtV = V.t() * V;
                normNMF = trac_xxt - 2.0 * trace(UtXV) + trace(UtU * VtV);
                sigma_e2 = normNMF / (double)(mGene * nSample);
                temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t());
                logX = -(double)(nSample * mGene) * 0.5 * log(sigma_e2) - 0.5 * (double)(normNMF / sigma_e2);
                logV = -(double)(nSample) * 0.5 * sum(log(lambda))- 0.5 * (sum(temp.diag() / lambda )); 
                logSigmaL2 = -(alpha + 1.0) * sum(log(lambda)) - sum(beta / lambda);
                obj = logX + logV + logSigmaL2;
                logicalLogL = (obj > obj_old) && (abs(obj - obj_old) * 2.0 / abs(obj + obj_old) < epsilon);
                if(isnan(obj) || (sqrt(accu((V - V_old) % (V - V_old)) / double(nSample * k))  < epsilon) || logicalLogL){
                   if(i > 5){ // run at least 5 iterations 
                       break;
                   }
           }else{
                obj_old = obj;
                V_old = V;
             }
           }
           return List::create(Named("V") = V,
                               Named("sigma_e2") = sigma_e2,
                               Named("lambda") = lambda,
                               Named("b") = b,
                               Named("Obj") = obj);
            }//end try 
            catch (std::exception &ex)
            {
                forward_exception_to_r(ex);
            }
            catch (...)
            {
                ::Rf_error("C++ exception (unknown reason)...");
            }
            return R_NilValue;
    } // end funcs
    

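    Explanation of the variables in the C++ code:

    1. trac_xxt = accu(Xinput % Xinput); is the sum of squared entries of X, i.e. the tr(X^T X) term of Equation 3-1
    2. 2.0 * trace(UtXV)
    // By the rules of matrix multiplication, the trace expresses the summation
    UtX = U.t() * Xinput;
    UtXV = UtX * V;
    trace(UtXV);
    

    UtX = U.t() * Xinput; UtXV = UtX * V; together give B^T X V

    3. trace(UtU * VtV)
    // By the rules of matrix multiplication, the trace expresses the summation
    VtV = V.t() * V;
    UtU = U.t() * U; 
    trace(UtU * VtV);
    

    UtU = U.t() * U; gives B^T B, and VtV = V.t() * V; gives V^T V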
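
    The three traces are exactly the pieces of the squared residual norm. Expanding the Frobenius norm and using the cyclic property of the trace gives the identity below (B in the formulas is U in the code):

```latex
\begin{aligned}
\lVert X - BV^{\top}\rVert_F^{2}
 &= \operatorname{tr}\!\bigl[(X - BV^{\top})^{\top}(X - BV^{\top})\bigr] \\
 &= \operatorname{tr}(X^{\top}X) - 2\operatorname{tr}(VB^{\top}X) + \operatorname{tr}(VB^{\top}BV^{\top}) \\
 &= \underbrace{\operatorname{tr}(X^{\top}X)}_{\texttt{trac\_xxt}}
  - 2\,\underbrace{\operatorname{tr}(B^{\top}XV)}_{\texttt{trace(UtXV)}}
  + \underbrace{\operatorname{tr}(B^{\top}B\,V^{\top}V)}_{\texttt{trace(UtU * VtV)}}
\end{aligned}
```

    Computing it this way avoids forming the G-by-N residual matrix in every iteration: tr(X^T X) is computed once, and the other two terms only involve K-by-K matrices.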

    In detail:

    1. Initialize the matrices:
    UtX = U.t() * Xinput;
    XtU = UtX.t();
    colsum_W = sum(W,1);
    D =  diagmat(colsum_W);// diagonal matrix whose entries are the row sums of W
    L = D -  phi*W; // graph laplacian
    accu_L = accu(L);
    UtXV = UtX * V;
    VtV = V.t() * V;
    UtU = U.t() * U;
    diag_UtU = UtU.diag();
    
    2. Build the likelihood (objective) function:
    normNMF = trac_xxt - 2.0 * trace(UtXV) + trace(UtU * VtV);
    logX = -(double)(mGene * nSample) * 0.5 * log(sigma_e2) - 0.5 * (double)(normNMF / sigma_e2);
    temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t());
    logV = -(double)(nSample) * 0.5 * sum(log(lambda )) - 0.5 * (sum(temp.diag() / lambda )); 
    logSigmaL2 = -(alpha + 1.0) * sum(log(lambda)) - sum(beta / lambda);
    obj_old = logX + logV + logSigmaL2;
    V_old = V;
    
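    1. normNMF = trac_xxt - 2.0 * trace(UtXV) + trace(UtU * VtV); is the residual sum of squares between X and B V^T in Equation 3-1
    2. logX = -(double)(mGene * nSample) * 0.5 * log(sigma_e2) - 0.5 * (double)(normNMF / sigma_e2); is the data log-likelihood term of Equation 3-1. Its sign is the opposite of the formula, presumably because the code maximizes the log-likelihood whereas the formula is written for minimization
    3. temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t()); is the quadratic form of the centered V with the matrix L in Equation 3-1
    4. logV = - (double)(nSample) * 0.5 * sum(log(lambda )) - 0.5 * (sum(temp.diag() / lambda )); is the spatial term for V in Equation 3-1
    5. logSigmaL2 = -(alpha + 1.0) * sum(log(lambda)) - sum(beta / lambda); is the term of Equation 3-1 that involves only λk
    6. obj_old = logX + logV + logSigmaL2; adds the three terms together
    7. Stopping criterion:
    if(isnan(obj) || (sqrt(accu((V - V_old) % (V - V_old)) / double(nSample * k))  < epsilon) || logicalLogL){
         if(i > 5){ // run at least 5 iterations 
             break;
         }
    }
    
    After at least 5 iterations, the loop stops once the root-mean-square change in V, or the relative improvement of the objective, is smaller than epsilon (or the objective becomes NaN)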
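
    The stopping rule can be isolated into a few small functions. The sketch below is plain C++ without Armadillo; the function names and the flattened std::vector representation of V are assumptions made for illustration, while the conditions themselves mirror the if statement and logicalLogL in CARDref():

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Root-mean-square difference between two flattened V matrices.
double rms_change(const std::vector<double>& v_new,
                  const std::vector<double>& v_old) {
    assert(v_new.size() == v_old.size() && !v_new.empty());
    double sum_sq = 0.0;
    for (std::size_t i = 0; i < v_new.size(); ++i) {
        const double diff = v_new[i] - v_old[i];
        sum_sq += diff * diff;
    }
    return std::sqrt(sum_sq / static_cast<double>(v_new.size()));
}

// True when the objective increased, but by less than epsilon in relative terms.
bool objective_converged(double obj, double obj_old, double epsilon) {
    const double rel = std::fabs(obj - obj_old) * 2.0 / std::fabs(obj + obj_old);
    return obj > obj_old && rel < epsilon;
}

// Combined rule: any of the three conditions triggers a stop,
// but only once more than 5 iterations have been run.
bool should_stop(int iter, double obj, double obj_old,
                 const std::vector<double>& v_new,
                 const std::vector<double>& v_old, double epsilon) {
    const bool triggered = std::isnan(obj) ||
                           rms_change(v_new, v_old) < epsilon ||
                           objective_converged(obj, obj_old, epsilon);
    return triggered && iter > 5;
}
```

    One detail of the original loop is worth noting: obj_old and V_old are only refreshed in the else branch, so when a condition fires during the first 5 iterations the previous values are kept for the next comparison.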
    8. Update the V matrix in every iteration:
    for(int nCT = 0; nCT < k; ++nCT){
         // nCT plays the role of cell type k; V.col(nCT) extracts column nCT of V
         updateV_den_k = lambda(nCT) * (V.col(nCT) * diag_UtU(nCT) + (V * UtU.col(nCT) - V.col(nCT) * diag_UtU(nCT))) +  part1.col(nCT);
         // Compute updateV_k, the element-wise ratio for this column
         updateV_k = (lambda(nCT) * XtU.col(nCT) + part2.col(nCT)) / updateV_den_k;
         // Multiply column nCT of V element-wise by updateV_k to update V
         V.col(nCT) %= updateV_k;
    }
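
    If the spatial terms part1 and part2 are left out, the loop reduces to the classic multiplicative rule for non-negative least squares, V <- V .* (X^T B) ./ (V B^T B). The sketch below implements only that simplified rule, as an illustration of why multiplying by a ratio of non-negative quantities keeps V non-negative; it is not CARD's full update, and the name multiplicative_sweep is hypothetical:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;  // row-major

// One sweep of V <- V .* (X^T B) ./ (V B^T B), one column (cell type) at a time.
// V and XtB are N-by-K, BtB is K-by-K. Columns updated earlier in the sweep
// are already used when the denominator of a later column is computed,
// which matches the in-place column loop of the original code.
void multiplicative_sweep(Matrix& v, const Matrix& xtb, const Matrix& btb) {
    const std::size_t n = v.size();
    const std::size_t k = btb.size();
    for (std::size_t c = 0; c < k; ++c) {
        std::vector<double> ratio(n);
        for (std::size_t i = 0; i < n; ++i) {
            double den = 0.0;
            for (std::size_t j = 0; j < k; ++j) den += v[i][j] * btb[j][c];
            assert(den > 0.0);  // holds for a positive V and a non-zero, non-negative B
            ratio[i] = xtb[i][c] / den;
        }
        for (std::size_t i = 0; i < n; ++i) v[i][c] *= ratio[i];
    }
}
```

    In the one-dimensional case with X^T B = 6 and B^T B = 2, starting from V = 1 the first sweep gives V = 3, the least-squares solution, and further sweeps leave it unchanged because the ratio becomes 1.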
    

    Then lambda is updated from the updated V matrix:

    temp = (V.t() - b * vecOne.t()) * L * (V - vecOne * b.t());
    lambda = (temp.diag() / 2.0 + beta ) / (double(nSample) / 2.0 + alpha + 1.0);  
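
    The second line is a closed-form update. Collecting the terms of logV and logSigmaL2 that contain a given λk and setting their derivative with respect to λk to zero yields the expression below; this derivation is worked out from the code, with α and β standing for alpha and beta:

```latex
\frac{\partial}{\partial \lambda_k}\left[
  -\Bigl(\tfrac{N}{2} + \alpha + 1\Bigr)\log\lambda_k
  - \frac{\tfrac{1}{2}Q_k + \beta}{\lambda_k}\right] = 0
\;\Longrightarrow\;
\lambda_k = \frac{\tfrac{1}{2}Q_k + \beta}{\tfrac{N}{2} + \alpha + 1},
\qquad
Q_k = (V_{\cdot k} - b_k\mathbf{1})^{\top} L\,(V_{\cdot k} - b_k\mathbf{1})
```

    Q_k is the kth diagonal entry of temp, which is why the code only needs temp.diag().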
    

    The updated lambda is then used in the next update of the V matrix.

          Article link: https://www.haomeiwen.com/subject/ivxuhdtx.html