Geneformer | Gene Classification Prediction

Author: 尘世中一个迷途小书僮 | Source: published 2023-08-21 21:23

Gene Classification



cd Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/dataset.arrow
wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/dataset_info.json

wget https://huggingface.co/datasets/ctheodoris/Genecorpus-30M/resolve/main/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset/state.json

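The authors provide a cardiomyopathy scRNA-seq dataset containing cells from non-failing (NF), hypertrophic (HCM), and dilated (DCM) hearts. They also provide a gene list that labels transcription factors (TFs) as dosage-sensitive or dosage-insensitive. We fine-tune Geneformer on these data and then predict whether a gene is a dosage-sensitive TF.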

Fine-tuning data: scRNA-seq data and gene labels.


Import Modules

import os
GPU_NUMBER = [0]  # indices of the GPUs to use; adjust to your machine
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
# imports
import datetime
import subprocess
import math
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_from_disk
from sklearn import preprocessing
from sklearn.metrics import accuracy_score, auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve
from sklearn.model_selection import StratifiedKFold
import torch
from transformers import BertForTokenClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from tqdm.notebook import tqdm

from geneformer import DataCollatorForGeneClassification
from geneformer.pretrainer import token_dictionary

Load Gene Attribute Information

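Read in the gene information table provided by the authors, which contains the Ensembl ID, gene name, and gene type of each gene. Store these attributes in three dictionaries (gene_id_type_dict, gene_name_id_dict, gene_id_name_dict).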

# table of corresponding Ensembl IDs, gene names, and gene types (e.g. coding, miRNA, etc.)
gene_info = pd.read_csv("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/gene_info_table.csv", index_col=0)

# create dictionaries for corresponding attributes
gene_id_type_dict = dict(zip(gene_info["ensembl_id"],gene_info["gene_type"]))
gene_name_id_dict = dict(zip(gene_info["gene_name"],gene_info["ensembl_id"]))
gene_id_name_dict = {v: k for k,v in gene_name_id_dict.items()}

# first 5 key:value pairs
{k: gene_id_name_dict[k] for k in list(gene_id_name_dict)[:5]}
{'ENSG00000000003': 'TSPAN6',
 'ENSG00000000005': 'TNMD',
 'ENSG00000000419': 'DPM1',
 'ENSG00000000457': 'SCYL3',
 'ENSG00000000460': 'C1orf112'}
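gene_id_name_dict is built by inverting gene_name_id_dict, which is keyed by gene name. If a gene name appears with more than one Ensembl ID, only the last ID survives. The inverted dictionary can therefore be smaller than the table. Here is a minimal sketch with a toy table. The rows are hypothetical and are not taken from gene_info_table.csv:

```python
# Hypothetical toy rows standing in for gene_info_table.csv
rows = [
    ("ENSG_A", "GENE1", "protein_coding"),
    ("ENSG_B", "GENE2", "lncRNA"),
    ("ENSG_C", "GENE2", "protein_coding"),  # same gene name, different Ensembl ID
]
ensembl_ids = [r[0] for r in rows]
gene_names = [r[1] for r in rows]
gene_types = [r[2] for r in rows]

# Same construction as in the notebook
gene_id_type_dict = dict(zip(ensembl_ids, gene_types))
gene_name_id_dict = dict(zip(gene_names, ensembl_ids))  # later rows overwrite earlier ones
gene_id_name_dict = {v: k for k, v in gene_name_id_dict.items()}
print(len(gene_id_type_dict), len(gene_name_id_dict), len(gene_id_name_dict))  # 3 2 2

# Building ID -> name straight from the table keeps every Ensembl ID
gene_id_name_direct = dict(zip(ensembl_ids, gene_names))
print(len(gene_id_name_direct))  # 3
```

If every Ensembl ID needs a name, build the ID-to-name mapping directly from the ensembl_id and gene_name columns.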

Load Training Data and Class Labels

Next, load the data for fine-tuning:

- the cardiomyopathy scRNA-seq dataset ("human_dcm_hcm_nf.dataset")
- the list of dosage-sensitive and dosage-insensitive TFs ("dosage_sens_tf_labels.csv")

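To process dosage_sens_tf_labels, we define the function prep_inputs. It converts the input gene IDs into token IDs and generates labels matching the lengths of genegroup1 and genegroup2 (genes in group1 are labeled 0, genes in group2 are labeled 1). Genes that are not in token_dictionary are dropped. The number of cross-validation splits is set to min(5, size of the smaller group - 1).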

token_dictionary maps Ensembl IDs to token IDs.

# function for preparing targets and labels
def prep_inputs(genegroup1, genegroup2, id_type):
    if id_type == "gene_name":
        targets1 = [gene_name_id_dict[gene] for gene in genegroup1 if gene_name_id_dict.get(gene) in token_dictionary]
        targets2 = [gene_name_id_dict[gene] for gene in genegroup2 if gene_name_id_dict.get(gene) in token_dictionary]
    elif id_type == "ensembl_id":
        targets1 = [gene for gene in genegroup1 if gene in token_dictionary]
        targets2 = [gene for gene in genegroup2 if gene in token_dictionary]
    targets1_id = [token_dictionary[gene] for gene in targets1]
    targets2_id = [token_dictionary[gene] for gene in targets2]
    targets = np.array(targets1_id + targets2_id)
    labels = np.array([0]*len(targets1_id) + [1]*len(targets2_id))
    nsplits = min(5, min(len(targets1_id), len(targets2_id))-1)
    assert nsplits > 2
    print(f"# targets1: {len(targets1_id)}\n# targets2: {len(targets2_id)}\n# splits: {nsplits}")
    return targets, labels, nsplits

# first 5 key:value pairs of token_dictionary
{k: token_dictionary[k] for k in list(token_dictionary)[:5]}
{'<pad>': 0,
 '<mask>': 1,
 'ENSG00000000003': 2,
 'ENSG00000000005': 3,
 'ENSG00000000419': 4}
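To show what prep_inputs returns, here is a simplified sketch of its "ensembl_id" branch on a hypothetical toy vocabulary. The real token_dictionary ships with Geneformer, and the function name prep_inputs_ensembl is made up for this illustration. Genes missing from the vocabulary are dropped, and the split count shrinks when the smaller group is tiny:

```python
import numpy as np

# Hypothetical toy token dictionary (the real one ships with Geneformer)
token_dictionary = {"<pad>": 0, "<mask>": 1}
token_dictionary.update({f"ENSG{i:011d}": i + 2 for i in range(12)})

def prep_inputs_ensembl(genegroup1, genegroup2):
    # keep only genes the model vocabulary knows, then map them to token IDs
    targets1 = [token_dictionary[g] for g in genegroup1 if g in token_dictionary]
    targets2 = [token_dictionary[g] for g in genegroup2 if g in token_dictionary]
    targets = np.array(targets1 + targets2)
    labels = np.array([0] * len(targets1) + [1] * len(targets2))
    # at most 5 folds, fewer if the smaller group is small
    nsplits = min(5, min(len(targets1), len(targets2)) - 1)
    return targets, labels, nsplits

group1 = [f"ENSG{i:011d}" for i in range(4)] + ["ENSG_NOT_IN_VOCAB"]
group2 = [f"ENSG{i:011d}" for i in range(4, 12)]
targets, labels, nsplits = prep_inputs_ensembl(group1, group2)
print(targets.tolist(), labels.tolist(), nsplits)  # nsplits is 3: the smaller group has 4 genes
```

With the real list (122 vs. 368 TFs), the same rule gives min(5, 121) = 5 splits.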

Load the dosage-sensitive TF list provided by the authors. It contains 122 dosage-sensitive TFs (label 0) and 368 dosage-insensitive TFs (label 1). prep_inputs converts the TF gene IDs into tokens and sets up 5 splits for the subsequent 5-fold cross-validation.

from collections import Counter

# preparing targets and labels for dosage sensitive vs insensitive TFs
dosage_tfs = pd.read_csv("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/gene_classification/dosage_sensitive_tfs/dosage_sens_tf_labels.csv", header=0)
sensitive = dosage_tfs["dosage_sensitive"].dropna()
insensitive = dosage_tfs["dosage_insensitive"].dropna()
targets, labels, nsplits = prep_inputs(sensitive, insensitive, "ensembl_id")
print(targets[:5])
print(Counter(labels))
# targets1: 122
# targets2: 368
# splits: 5
[208 223 275 295 487]
Counter({1: 368, 0: 122})


The cells come from three subgroups:

1. NF (non-failing)
2. HCM (hypertrophic cardiomyopathy)
3. DCM (dilated cardiomyopathy)

After shuffling the dataset, we take 50,000 randomly selected cells as the training set.

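# load training dataset
train_dataset = load_from_disk("D:/jupyterNote/Geneformer/Genecorpus-30M/example_input_files/cell_classification/disease_classification/human_dcm_hcm_nf.dataset")
shuffled_train_dataset = train_dataset.shuffle(seed=42)
subsampled_train_dataset = shuffled_train_dataset.select(list(range(50_000)))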

Loading cached shuffled indices for dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\disease_classification\human_dcm_hcm_nf.dataset\cache-54b519f110fa07f1.arrow
print(train_dataset)
print("\nCelltype: ")
print(Counter(train_dataset["cell_type"]))
print("\nSubgroups: ")
print(Counter(train_dataset["disease"]))

Dataset({
    features: ['input_ids', 'length', 'cell_type', 'individual', 'age', 'sex', 'disease', 'lvef'],
    num_rows: 579159
})

Celltype: 
Counter({'Fibroblast1': 141725, 'Cardiomyocyte1': 136167, 'Endothelial1': 78375, 'Pericyte1': 67600, 'Macrophage': 54714, 'Endothelial2': 18394, 'VSMC': 18137, 'Lymphocyte': 16246, 'Endocardial': 6489, 'Cardiomyocyte2': 5445, 'Adipocyte': 5298, 'ActivatedFibroblast': 5210, 'LymphaticEndothelial': 5181, 'Endothelial3': 4538, 'MastCell': 4465, 'Neuronal': 4292, 'Cardiomyocyte3': 3350, 'Pericyte2': 1704, 'ProliferatingMacrophage': 1276, 'Fibroblast2': 284, 'Epicardial': 269})

Subgroups: 
Counter({'hcm': 230652, 'nf': 182317, 'dcm': 166190})

Define Functions for Training and Cross-Validating Classifier

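Geneformer converts each cell's gene expression profile into a rank value encoding. The length of this encoding differs from cell to cell, but the model requires all input tensors in a batch to have the same length. The function preprocess_classifier_batch therefore pads every input with the <pad> token up to a common length. The labels are padded with -100 so that the loss ignores the padded positions, and the attention mask marks which positions hold real tokens.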
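Here is a minimal NumPy sketch of the per-example padding step. It assumes, as in token_dictionary above, that <pad> has ID 0, and it uses -100, the default ignore_index of PyTorch's cross-entropy loss. The helper name pad_example is hypothetical:

```python
import numpy as np

PAD_ID = 0           # token_dictionary["<pad>"] is 0 in the output above
IGNORE_LABEL = -100  # positions with this label are ignored by the loss

def pad_example(input_ids, labels, max_len):
    n_pad = max_len - len(input_ids)
    padded_ids = np.pad(np.asarray(input_ids), (0, n_pad), mode="constant", constant_values=PAD_ID)
    padded_labels = np.pad(np.asarray(labels), (0, n_pad), mode="constant", constant_values=IGNORE_LABEL)
    # 1 for real tokens, 0 for padding
    attention_mask = (padded_ids != PAD_ID).astype(int)
    return padded_ids, padded_labels, attention_mask

# two cells whose rank value encodings have different lengths (toy token IDs)
cells = [([5, 9, 7], [-100, 1, -100]), ([8, 6], [0, -100])]
max_len = max(len(ids) for ids, _ in cells)
for ids, labs in cells:
    print(*(a.tolist() for a in pad_example(ids, labs, max_len)))
```

The second cell becomes [8, 6, 0] with labels [0, -100, -100] and mask [1, 1, 0].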

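classifier_predict splits the input dataset into batches of forward_batch_size. It uses the fine-tuned model to predict whether each labeled gene is dosage-sensitive or dosage-insensitive. It then computes evaluation metrics (e.g., FPR and TPR) by comparing the predicted labels with the true labels.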
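After inference, classifier_predict keeps only positions whose label is not -100. It then turns each logit pair into a hard prediction (vote) and a class-1 probability (py_softmax) for the ROC curve. Here is a pure-Python sketch of that post-processing. The helper name logits_to_predictions and the toy inputs are hypothetical. Two details differ from the notebook. First, ties go to class 0 here, whereas vote returns "tie". Second, the maximum logit is subtracted before exponentiating for numerical stability, which py_softmax does not do.

```python
import math

IGNORE_LABEL = -100

def logits_to_predictions(all_logits, all_labels):
    """Drop unlabeled/padded positions, then derive hard labels and class-1 scores."""
    paired = [(lg, lb) for lg, lb in zip(all_logits, all_labels) if lb != IGNORE_LABEL]
    y_true = [lb for _, lb in paired]
    # hard label: index of the larger logit (first index wins on ties)
    y_pred = [max(range(len(lg)), key=lambda k: lg[k]) for lg, _ in paired]
    # softmax probability of class 1, used as the ROC score
    y_score = []
    for lg, _ in paired:
        m = max(lg)
        exps = [math.exp(x - m) for x in lg]
        y_score.append(exps[1] / sum(exps))
    return y_true, y_pred, y_score

# toy logits for three positions; the middle one is padding
y_true, y_pred, y_score = logits_to_predictions([[4.65, -4.64], [0.1, 0.2], [9.0, 9.0]], [0, -100, 1])
print(y_true, y_pred)  # [0, 1] [0, 0]
```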


def preprocess_classifier_batch(cell_batch, max_len):
    if max_len is None:
        max_len = max([len(i) for i in cell_batch["input_ids"]])
    def pad_label_example(example):
        example["labels"] = np.pad(example["labels"], 
                                   (0, max_len-len(example["input_ids"])), 
                                   mode='constant', constant_values=-100)
        example["input_ids"] = np.pad(example["input_ids"], 
                                      (0, max_len-len(example["input_ids"])), 
                                      mode='constant', constant_values=token_dictionary.get("<pad>"))
        example["attention_mask"] = (example["input_ids"] != token_dictionary.get("<pad>")).astype(int)
        return example
    padded_batch = cell_batch.map(pad_label_example)
    return padded_batch

# forward batch size is batch size for model inference (e.g. 200)
def classifier_predict(model, evalset, forward_batch_size, mean_fpr):
    predict_logits = []
    predict_labels = []
    # ensure there are at least 2 examples in each batch to avoid incorrect tensor dims
    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible
    max_evalset_len = max(evalset.select(list(range(evalset_len)))["length"])
    for i in range(0, evalset_len, forward_batch_size):
        max_range = min(i+forward_batch_size, evalset_len)
        batch_evalset = evalset.select(list(range(i, max_range)))
        padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
        # return torch tensors so that .to("cuda") works below
        padded_batch.set_format(type="torch")
        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch["labels"]
        with torch.no_grad():
            outputs = model(
                input_ids = input_data_batch.to("cuda"), 
                attention_mask = attn_msk_batch.to("cuda"), 
                labels = label_batch.to("cuda"), 
            )
            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
            predict_labels += [torch.squeeze(label_batch.to("cpu"))]
    logits_by_cell = torch.cat(predict_logits)
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    logit_label_paired = [item for item in list(zip(all_logits.tolist(), all_labels.tolist())) if item[1]!=-100]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    # probability of class 1
    y_score = [py_softmax(item)[1] for item in logits_list]
    conf_mat = confusion_matrix(y_true, y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_score)
    # plot roc_curve for this split
    plt.plot(fpr, tpr)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    # interpolate to graph
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0
    return fpr, tpr, interp_tpr, conf_mat 

def vote(logit_pair):
    a, b = logit_pair
    if a > b:
        return 0
    elif b > a:
        return 1
    elif a == b:
        return "tie"
def py_softmax(vector):
    e = np.exp(vector)
    return e / e.sum()
# get cross-validated mean and sd metrics
def get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt):
    wts = [count/sum(all_tpr_wt) for count in all_tpr_wt]
    all_weighted_tpr = [a*b for a,b in zip(all_tpr, wts)]
    mean_tpr = np.sum(all_weighted_tpr, axis=0)
    mean_tpr[-1] = 1.0
    all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]
    roc_auc = np.sum(all_weighted_roc_auc)
    roc_auc_sd = math.sqrt(np.average((all_roc_auc-roc_auc)**2, weights=wts))
    return mean_tpr, roc_auc, roc_auc_sd

# Function to find the largest number smaller
# than or equal to N that is divisible by k
def find_largest_div(N, K):
    rem = N % K
    if rem == 0:
        return N
    return N - rem
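find_largest_div exists so that classifier_predict never ends with a batch of a single cell. On a one-example batch, torch.squeeze would also drop the batch dimension and break the later torch.cat. Here is a pure-Python sketch of the resulting batch boundaries. The helper batch_ranges is hypothetical and mirrors the loop in classifier_predict. Note that in this case the last example is simply skipped during evaluation.

```python
def find_largest_div(N, K):
    """Largest number <= N that is divisible by K."""
    rem = N % K
    return N if rem == 0 else N - rem

def batch_ranges(n_examples, batch_size):
    """(start, stop) pairs as in classifier_predict, dropping a trailing batch of exactly one."""
    usable = n_examples
    max_divisible = find_largest_div(n_examples, batch_size)
    if n_examples - max_divisible == 1:
        usable = max_divisible
    return [(i, min(i + batch_size, usable)) for i in range(0, usable, batch_size)]

print(batch_ranges(41, 20))  # [(0, 20), (20, 40)]: the 41st example is skipped
print(batch_ranges(45, 20))  # [(0, 20), (20, 40), (40, 45)]
```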

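We define the function cross_validate, which wraps data splitting, training, and prediction. In each fold, StratifiedKFold splits the labeled genes 80%/20% into training and evaluation genes. The cells containing evaluation genes are then divided in half:

- an evaluation set passed to the Trainer
- an out-of-sample (hold-out) set used for the final evaluation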
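Here is a pure-Python sketch of the per-fold labeling and the half split, mirroring the filter, generate_*_labels, and select calls inside cross_validate. The helper names and token IDs are hypothetical. The split happens at the gene level, so the same cell can land in both the training and the evaluation set, with different positions labeled.

```python
def label_cell(input_ids, label_dict):
    """Per-position labels: 0/1 for genes in this fold's label_dict, -100 elsewhere."""
    return [label_dict.get(tok, -100) for tok in input_ids]

def contains_any(input_ids, label_dict):
    """Filter predicate: keep a cell if it contains at least one labeled gene."""
    return not set(label_dict).isdisjoint(input_ids)

def split_eval_cells(n_eval_cells, training_size):
    """Index ranges of the two halves of the evaluation cells."""
    eval_size = min(training_size, n_eval_cells)
    half = round(eval_size / 2)
    # first half -> eval_dataset for the Trainer; second half -> out-of-sample hold-out
    return range(0, half), range(half, eval_size)

# Hypothetical fold: two training genes and one evaluation gene
label_dict_train = {101: 1, 102: 0}
label_dict_eval = {103: 1}
cell = [101, 5, 102, 103, 7]
print(label_cell(cell, label_dict_train))  # [1, -100, 0, -100, -100]
print(label_cell(cell, label_dict_eval))   # [-100, -100, -100, 1, -100]
print(contains_any(cell, label_dict_train), contains_any(cell, label_dict_eval))  # True True

train_half, oos_half = split_eval_cells(n_eval_cells=47913, training_size=10_000)
print(len(train_half), len(oos_half))  # 5000 5000
```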

When loading the pretrained model, replace the path with your local Geneformer directory or with the Hugging Face repository name ("ctheodoris/Geneformer").

        # load model
        model = BertForTokenClassification.from_pretrained(
            "D:/jupyterNote/Geneformer", # change to local path to the model
            num_labels=2,
            output_attentions = False,
            output_hidden_states = False
        )

        # add output directory to training args and initiate
        training_args["output_dir"] = ksplit_output_dir
        training_args_init = TrainingArguments(**training_args)
        # create the trainer
        trainer = Trainer(
            model=model,
            args=training_args_init,
            data_collator=DataCollatorForGeneClassification(),
            train_dataset=trainset_labeled,
            eval_dataset=evalset_train_labeled
        )

        # train the gene classifier
        trainer.train()

The following code uses the fine-tuned model to make predictions on the out-of-sample dataset (evalset_oos_labeled) and evaluates them.


        # evaluate model
        fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 20, mean_fpr) # forward_batch_size: 20
        # append to tpr and roc lists
        confusion = confusion + conf_mat
        all_tpr.append(interp_tpr)
        all_roc_auc.append(auc(fpr, tpr))
        # append this split's weight (number of ROC points) for the averaged graphs
        all_tpr_wt.append(len(tpr))

# cross-validate gene classifier
def cross_validate(data, targets, labels, nsplits, subsample_size, training_args, freeze_layers, output_dir, num_proc):
    # check if output directory already written to
    # ensure not overwriting previously saved model
    model_dir_test = os.path.join(output_dir, "ksplit0/models/pytorch_model.bin")
    if os.path.isfile(model_dir_test):
        raise Exception("Model already saved to this directory.")
    # initiate eval metrics to return
    num_classes = len(set(labels))
    mean_fpr = np.linspace(0, 1, 100)
    all_tpr = []
    all_roc_auc = []
    all_tpr_wt = []
    label_dicts = []
    confusion = np.zeros((num_classes,num_classes))
    # set up cross-validation splits
    skf = StratifiedKFold(n_splits=nsplits, random_state=0, shuffle=True)
    # train and evaluate
    iteration_num = 0
    for train_index, eval_index in tqdm(skf.split(targets, labels)):
        if len(labels) > 500:
            print("early stopping activated due to large # of training examples")
            nsplits = 3
            if iteration_num == 3:
                break
        print(f"****** Crossval split: {iteration_num}/{nsplits-1} ******\n")
        # generate cross-validation splits
        targets_train, targets_eval = targets[train_index], targets[eval_index]
        labels_train, labels_eval = labels[train_index], labels[eval_index]
        label_dict_train = dict(zip(targets_train, labels_train))
        label_dict_eval = dict(zip(targets_eval, labels_eval))
        label_dicts += (iteration_num, targets_train, targets_eval, labels_train, labels_eval)
        # functions to filter by whether a cell contains train or eval genes
        def if_contains_train_label(example):
            a = label_dict_train.keys()
            b = example['input_ids']
            return not set(a).isdisjoint(b)

        def if_contains_eval_label(example):
            a = label_dict_eval.keys()
            b = example['input_ids']
            return not set(a).isdisjoint(b)
        # filter dataset for examples containing classes for this split
        print(f"Filtering training data")
        trainset = data.filter(if_contains_train_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(trainset)/len(data))*100)}%; {len(trainset)} remain\n")
        print(f"Filtering evaluation data")
        evalset = data.filter(if_contains_eval_label, num_proc=num_proc)
        print(f"Filtered {round((1-len(evalset)/len(data))*100)}%; {len(evalset)} remain\n")

        # minimize to smaller training sample
        training_size = min(subsample_size, len(trainset))
        trainset_min = trainset.select(list(range(training_size)))
        eval_size = min(training_size, len(evalset))
        half_training_size = round(eval_size/2)
        evalset_train_min = evalset.select(list(range(half_training_size)))
        evalset_oos_min = evalset.select(list(range(half_training_size, eval_size)))
        # label conversion functions
        def generate_train_labels(example):
            example["labels"] = [label_dict_train.get(token_id, -100) for token_id in example["input_ids"]]
            return example

        def generate_eval_labels(example):
            example["labels"] = [label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]]
            return example
        # label datasets 
        print(f"Labeling training data")
        trainset_labeled = trainset_min.map(generate_train_labels)
        print(f"Labeling evaluation data")
        evalset_train_labeled = evalset_train_min.map(generate_eval_labels)
        print(f"Labeling evaluation OOS data")
        evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels)

        # create output directories
        ksplit_output_dir = os.path.join(output_dir, f"ksplit{iteration_num}")
        ksplit_model_dir = os.path.join(ksplit_output_dir, "models/") 
        # ensure not overwriting previously saved model
        model_output_file = os.path.join(ksplit_model_dir, "pytorch_model.bin")
        if os.path.isfile(model_output_file):
            raise Exception("Model already saved to this directory.")

        # make training and model output directories (also creates ksplit_output_dir)
        os.makedirs(ksplit_model_dir, exist_ok=True)
        # load model
        model = BertForTokenClassification.from_pretrained(
            "D:/jupyterNote/Geneformer", # change as the path to the model
            num_labels=2,
            output_attentions = False,
            output_hidden_states = False
        )
        if freeze_layers is not None:
            modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
            for module in modules_to_freeze:
                for param in module.parameters():
                    param.requires_grad = False
        model = model.to("cuda:0")
        # add output directory to training args and initiate
        training_args["output_dir"] = ksplit_output_dir
        training_args_init = TrainingArguments(**training_args)
        # create the trainer
        trainer = Trainer(
            model=model,
            args=training_args_init,
            data_collator=DataCollatorForGeneClassification(),
            train_dataset=trainset_labeled,
            eval_dataset=evalset_train_labeled
        )
        # train the gene classifier
        trainer.train()
        # save model
        trainer.save_model(ksplit_model_dir)
        # evaluate model
        fpr, tpr, interp_tpr, conf_mat = classifier_predict(trainer.model, evalset_oos_labeled, 20, mean_fpr) # forward_batch_size: 20
        # append to tpr and roc lists
        confusion = confusion + conf_mat
        all_tpr.append(interp_tpr)
        all_roc_auc.append(auc(fpr, tpr))
        # append this split's weight (number of ROC points) for the averaged graphs
        all_tpr_wt.append(len(tpr))
        iteration_num = iteration_num + 1
    # get overall metrics for cross-validation
    mean_tpr, roc_auc, roc_auc_sd = get_cross_valid_metrics(all_tpr, all_roc_auc, all_tpr_wt)
    return all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts
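get_cross_valid_metrics summarizes the per-split AUCs as a weighted mean and a weighted standard deviation, using all_tpr_wt as weights. Here is a pure-Python sketch of the same arithmetic. The function name and the AUC values are hypothetical:

```python
import math

def weighted_auc_summary(all_roc_auc, all_tpr_wt):
    """Weighted mean and weighted SD of per-split AUCs (same formulas as get_cross_valid_metrics)."""
    total = sum(all_tpr_wt)
    wts = [w / total for w in all_tpr_wt]
    mean_auc = sum(a * w for a, w in zip(all_roc_auc, wts))
    var = sum(w * (a - mean_auc) ** 2 for a, w in zip(all_roc_auc, wts))
    return mean_auc, math.sqrt(var)

# hypothetical AUCs of three splits, the last one weighted twice as much
mean_auc, sd_auc = weighted_auc_summary([0.80, 0.90, 0.85], [1, 1, 2])
print(round(mean_auc, 4), round(sd_auc, 4))
```

With these numbers the weights are 0.25, 0.25 and 0.5. This gives a mean of 0.85 and a SD of sqrt(0.00125), about 0.0354.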

Define Functions for Plotting Results


# plot ROC curve
def plot_ROC(bundled_data, title):
    lw = 2
    for roc_auc, roc_auc_sd, mean_fpr, mean_tpr, sample, color in bundled_data:
        plt.plot(mean_fpr, mean_tpr, color=color,
                 lw=lw, label=r"{0} (AUC {1:0.2f} $\pm$ {2:0.2f})".format(sample, roc_auc, roc_auc_sd))
    plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc="lower right")
# plot confusion matrix
def plot_confusion_matrix(classes_list, conf_mat, title):
    display_labels = []
    i = 0
    for label in classes_list:
        display_labels += ["{0}\nn={1:.0f}".format(label, sum(conf_mat[:,i]))]
        i = i + 1
    display = ConfusionMatrixDisplay(confusion_matrix=preprocessing.normalize(conf_mat, norm="l1"), 
                                     display_labels=display_labels)
    display.plot()
    plt.title(title)
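plot_confusion_matrix draws a row-normalized matrix. preprocessing.normalize(conf_mat, norm="l1") divides each row (true class) by its total, so the diagonal shows per-class recall. The n= in each display label is the column total, i.e., the number of predictions of that class. Here is a pure-Python sketch of both quantities on a hypothetical matrix. The helper names are made up, and the rows-are-true-labels layout follows sklearn's confusion_matrix.

```python
def row_normalize(conf_mat):
    """Divide each row (true class) by its total, like normalize(norm="l1") on non-negative counts."""
    out = []
    for row in conf_mat:
        total = sum(row)
        out.append([v / total if total else 0.0 for v in row])
    return out

def column_totals(conf_mat):
    """Number of predictions per class: the n= values in the display labels."""
    return [sum(col) for col in zip(*conf_mat)]

conf = [[30, 10],   # true class 0: 30 predicted as 0, 10 predicted as 1
        [20, 40]]   # true class 1: 20 predicted as 0, 40 predicted as 1
print(row_normalize(conf))
print(column_totals(conf))  # [50, 50]
```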

Fine-Tune With Gene Classification Learning Objective and Quantify Predictive Performance

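Define the fine-tuning parameters. As before, adjust num_gpus, num_proc, and geneformer_batch_size to your hardware. The remaining hyperparameters keep their preset values, although in principle they could be tuned further.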

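On choosing freeze_layers, the authors note that the closer the downstream task is to the pretraining objective, the more layers can be frozen. In other words, more of the pretrained weights are kept fixed: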

Generally, in our experience, applications that are more relevant to the pretraining objective benefit from more layers being frozen to prevent overfitting to the limited task-specific data, whereas applications that are more distant from the pretraining objective benefit from fine-tuning of more layers to optimize performance on the new task.

# set model parameters
# max input size
max_input_size = 2 ** 11  # 2048

# set training hyperparameters
# max learning rate
max_lr = 5e-5
# how many pretrained layers to freeze
freeze_layers = 4
# number gpus
num_gpus = 1
# number cpu cores
num_proc = 6
# batch size for training and eval
geneformer_batch_size = 2
# learning schedule
lr_schedule_fn = "linear"
# warmup steps
warmup_steps = 500
# number of epochs
epochs = 1
# optimizer
optimizer = "adamw"
# set training arguments
subsample_size = 10_000
training_args = {
    "learning_rate": max_lr,
    "do_train": True,
    "evaluation_strategy": "no",
    "save_strategy": "epoch",
    "logging_steps": 100,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": False,
    "lr_scheduler_type": lr_schedule_fn,
    "warmup_steps": warmup_steps,
    "weight_decay": 0.001,
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
}

# define output directory path
current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
training_output_dir = f"D:\\jupyterNote\\Geneformer\\examples\\gene_class_test\\{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_n{subsample_size}_F{freeze_layers}\\"

# ensure not overwriting previously saved model
ksplit_model_test = os.path.join(training_output_dir, "ksplit0/models/pytorch_model.bin")
if os.path.isfile(ksplit_model_test) == True:
    raise Exception("Model already saved to this directory.")

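# make output directory
os.makedirs(training_output_dir, exist_ok=True)
# clear GPU memory left over from previous PyTorch training
import torch
torch.cuda.empty_cache()
# a `!set` shell command runs in a subprocess and does not change the kernel's environment,
# which is why this did not work; set the variable before CUDA is initialized instead, e.g.
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"  # limit allocation splits to 512 MB
#!set 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512'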
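The output directory name encodes the hyperparameters, and creating it through subprocess with `mkdir` depends on the shell. Here is a portable sketch that builds the same style of name and creates the directory with os.makedirs. The helper make_run_dir is hypothetical, and its default arguments are the values set above.

```python
import datetime
import os
import tempfile

def make_run_dir(base_dir, max_input_size=2048, batch_size=2, max_lr=5e-5,
                 lr_schedule_fn="linear", warmup_steps=500, epochs=1,
                 optimizer="adamw", subsample_size=10_000, freeze_layers=4):
    datestamp = datetime.date.today().strftime("%y%m%d")  # e.g. 230724
    name = (f"{datestamp}_geneformer_GeneClassifier_dosageTF_L{max_input_size}_B{batch_size}"
            f"_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}"
            f"_n{subsample_size}_F{freeze_layers}")
    run_dir = os.path.join(base_dir, name)
    # refuse to reuse a directory that already holds a trained split-0 model
    if os.path.isfile(os.path.join(run_dir, "ksplit0", "models", "pytorch_model.bin")):
        raise FileExistsError("Model already saved to this directory.")
    os.makedirs(run_dir, exist_ok=True)
    return run_dir

with tempfile.TemporaryDirectory() as tmp:
    print(os.path.basename(make_run_dir(tmp)))
```

With the defaults, the name ends in _L2048_B2_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4. This matches the model path used later in this post.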

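We fine-tune on subsampled_train_dataset, which contains 50,000 cells. In each of the 5 cross-validation splits (nsplits=5), up to 10,000 cells (subsample_size) that contain training genes are used for training. The targets and labels are divided into an 80% training set (n = 392) and a 20% evaluation set (n = 98) with a stratified split. This keeps the ratio of dosage-sensitive to dosage-insensitive TFs the same in every fold. The evaluation genes of different folds do not overlap, while their training genes do.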

The targets and labels of each split are stored in label_dicts. Every five consecutive elements form one group: iteration_num, targets_train, targets_eval, labels_train, labels_eval.
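Because `label_dicts += (...)` extends the list with five separate items per split, split k's entries sit at indices 5k to 5k+4. Here is a sketch that regroups the flat list into one dictionary per split. The helper group_label_dicts and the toy values are hypothetical.

```python
FIELDS = ("iteration_num", "targets_train", "targets_eval", "labels_train", "labels_eval")

def group_label_dicts(label_dicts):
    """Regroup the flat label_dicts list (five entries per split) into one dict per split."""
    n = len(FIELDS)
    if len(label_dicts) % n != 0:
        raise ValueError("label_dicts length is not a multiple of 5")
    return [dict(zip(FIELDS, label_dicts[i:i + n])) for i in range(0, len(label_dicts), n)]

# hypothetical flat list built the same way as in cross_validate
flat = []
flat += (0, [11, 12, 13], [14], [0, 1, 1], [0])
flat += (1, [11, 13, 14], [12], [0, 1, 0], [1])
splits = group_label_dicts(flat)
eval_dict_split1 = dict(zip(splits[1]["targets_eval"], splits[1]["labels_eval"]))
print(len(splits), eval_dict_split1)  # 2 {12: 1}
```

For split 0, `dict(zip(label_dicts[2], label_dicts[4]))` in the later section is equivalent to `dict(zip(splits[0]["targets_eval"], splits[0]["labels_eval"]))`.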

cross_validate prints training information for each split, including the training loss, learning_rate, epoch, and the ROC curve.

# cross-validate gene classifier
all_roc_auc, roc_auc, roc_auc_sd, mean_fpr, mean_tpr, confusion, label_dicts \
    = cross_validate(subsampled_train_dataset, targets, labels, nsplits, subsample_size, training_args, freeze_layers, training_output_dir, 1)
0it [00:00, ?it/s]

Loading cached processed dataset at D:\jupyterNote\Geneformer\Genecorpus-30M\example_input_files\cell_classification\disease_classification\human_dcm_hcm_nf.dataset\cache-509acb05b140c462.arrow

****** Crossval split: 0/4 ******

Filtering training data
Filtered 0%; 49994 remain

Filtering evalation data
Split 0 training info...
****** Crossval split: 1/4 ******

Filtering training data
Filtered 0%; 49992 remain

Filtering evalation data
Filtered 4%; 47913 remain

Labeling training data
Split 1 training info...
****** Crossval split: 2/4 ******

Filtering training data
Filtered 0%; 49993 remain

Filtering evalation data
Filtered 4%; 47886 remain

Labeling training data
Split 2 training info...
****** Crossval split: 3/4 ******

Filtering training data
Filtered 0%; 49991 remain

Filtering evalation data
Filtered 4%; 48025 remain

Labeling training data
Split 3 training info...
****** Crossval split: 4/4 ******

Filtering training data
Filtered 0%; 49977 remain

Filtering evalation data
Filtered 2%; 48951 remain

Labeling training data
Split 4 training info...
[0.25172310458495656, 0.18719408650484468, 0.1628708420737189, 0.2369393666966337, 0.16127260013984618]
# bundle data for plotting
bundled_data = []
bundled_data += [(roc_auc, roc_auc_sd, mean_fpr, mean_tpr, "Geneformer", "red")]
# plot ROC curve
plot_ROC(bundled_data, 'Dosage Sensitive vs Insensitive TFs')
# plot confusion matrix
classes_list = ["Dosage Sensitive", "Dosage Insensitive"]
plot_confusion_matrix(classes_list, confusion, "Geneformer")

These are the 5-fold CV results. Next, we take the model fine-tuned on 10,000 cells in one split and use it for gene classification on that split's out-of-sample evaluation set.

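We first load the fine-tuned model from the first split (ksplit0) and move it to the GPU. Its classifier layer has out_features=2, i.e., it performs binary classification.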

# reload fine-tuned model
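ft_model = BertForTokenClassification.from_pretrained("gene_class_test/230724_geneformer_GeneClassifier_dosageTF_L2048_B2_LR5e-05_LSlinear_WU500_E1_Oadamw_n10000_F4/ksplit0/models/")
# move the model to the GPU; in Jupyter this also displays the model architecture
ft_model.to("cuda")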

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(25426, 256, padding_idx=0)
      (position_embeddings): Embedding(2048, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.02, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.02, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=256, out_features=512, bias=True)
            (intermediate_act_fn): ReLU()
          )
          (output): BertOutput(
            (dense): Linear(in_features=512, out_features=256, bias=True)
            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.02, inplace=False)
          )
        )
      )
    )
  )
  (dropout): Dropout(p=0.02, inplace=False)
  (classifier): Linear(in_features=256, out_features=2, bias=True)
)

We take the evaluation targets and labels of the first split and build the corresponding evaluation set (evalset_oos_labeled).

# out-of-sample evaluation set
# for set 0
label_dict_eval = dict(zip(label_dicts[2], label_dicts[4]))

def if_contains_eval_label(example, label_dict):
    a = label_dict.keys()
    b = example['input_ids']
    return not set(a).isdisjoint(b)

evalset0 = subsampled_train_dataset.filter(if_contains_eval_label, num_proc=2, fn_kwargs={"label_dict": label_dict_eval})
eval_size0 = min(10000, len(evalset0))
half_training_size = round(eval_size0/2)
evalset_oos_min = evalset0.select([i for i in range(half_training_size, eval_size0)])

def generate_eval_labels(example, label_dict):
    example["labels"] = [label_dict.get(token_id, -100) for token_id in example["input_ids"]]
    return example

evalset_oos_labeled = evalset_oos_min.map(generate_eval_labels, fn_kwargs={"label_dict": label_dict_eval})
evalset_oos_labeled

Dataset({
    features: ['input_ids', 'length', 'cell_type', 'individual', 'age', 'sex', 'disease', 'lvef', 'labels'],
    num_rows: 5000
})

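Here we modify the original classifier_predict so that it returns the following:

- the labels predicted by the fine-tuned model (y_pred)
- the true labels (y_true)
- the model's logits (logits_list)
- the cell indices (cell_id)
- the labeled TF tokens of each cell (token_id_dict)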

# return prediction results
def get_classifier_predict(model, evalset, forward_batch_size):
    predict_logits = []
    predict_labels = []
    cell_id = []
    token_id_dict = {}
    # ensure there are at least 2 examples in each batch to avoid incorrect tensor dims
    evalset_len = len(evalset)
    max_divisible = find_largest_div(evalset_len, forward_batch_size)
    if len(evalset) - max_divisible == 1:
        evalset_len = max_divisible
    max_evalset_len = max(evalset.select(list(range(evalset_len)))["length"])
    for i in range(0, evalset_len, forward_batch_size):
        max_range = min(i+forward_batch_size, evalset_len)
        batch_evalset = evalset.select(list(range(i, max_range)))
        padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)
        # return torch tensors so that .to("cuda") works below
        padded_batch.set_format(type="torch")
        # cell id (row index in evalset)
        cell_id += list(range(i, max_range))
        # store the labeled token ids of each cell, in sequence order
        batch_labels = batch_evalset["labels"]
        for j, tokens in enumerate(batch_evalset["input_ids"]):
            cell_idx = i + j
            token_id_dict[cell_idx] = [tki for k, tki in enumerate(tokens) if batch_labels[j][k] > -1]
        input_data_batch = padded_batch["input_ids"]
        attn_msk_batch = padded_batch["attention_mask"]
        label_batch = padded_batch["labels"]
        with torch.no_grad():
            outputs = model(
                input_ids = input_data_batch.to("cuda"), 
                attention_mask = attn_msk_batch.to("cuda"), 
                labels = label_batch.to("cuda"), 
            )
            predict_logits += [torch.squeeze(outputs.logits.to("cpu"))]
            predict_labels += [torch.squeeze(label_batch.to("cpu"))]
    logits_by_cell = torch.cat(predict_logits)
    all_logits = logits_by_cell.reshape(-1, logits_by_cell.shape[2])
    labels_by_cell = torch.cat(predict_labels)
    all_labels = torch.flatten(labels_by_cell)
    logit_label_paired = [item for item in list(zip(all_logits.tolist(), all_labels.tolist())) if item[1]!=-100]
    y_pred = [vote(item[0]) for item in logit_label_paired]
    y_true = [item[1] for item in logit_label_paired]
    logits_list = [item[0] for item in logit_label_paired]
    return y_pred, y_true, logits_list, cell_id, token_id_dict 
eval_pred, eval_label, eval_logits, cell_id, token_id = get_classifier_predict(model=ft_model, evalset=evalset_oos_labeled, forward_batch_size=20)
Model prediction info...

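The model outputs a logit for each of the two classes, and the predicted label of a gene is the class with the larger logit. A prediction is made for every labeled TF in every cell (n = 27,939).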


Counter({1: 16492, 0: 11447})
Counter({0: 14673, 1: 13266})
[[4.6540117263793945, -4.643155574798584], [5.055752277374268, -4.894111156463623], [0.701909065246582, -0.6132677793502808]]
[0, 0, 0]


# # numbers of tfs (genes with 0/1 label) in out-of-sample evaluation set
# tf_num = [len([v for v in i if v >= 0]) for i in evalset_oos_labeled['labels']]
# sum(tf_num)

# frequencies of tokens
token_freq = Counter()

for tks in token_id.values():
    token_freq.update(tks)
token_freq

Counter({1636: 2636,
         9061: 2755,
         6754: 475,
         16718: 204,
         275: 1445,
         15866: 600,
         5084: 805,
         3361: 272,
         2410: 108,
         1757: 550,
         18597: 82,
         10422: 305,
         14481: 197,
         8218: 766,
         16619: 138,
         4071: 434,
         6931: 1052,
         14023: 468,
         7445: 699,
         4445: 157,
         17672: 983,
         3982: 547,
         5944: 552,
         5357: 359,
         20144: 237,
         6257: 137,
         6456: 185,
         16597: 437,
         2774: 216,
         15781: 553,
         20018: 386,
         23967: 427,
         21561: 218,
         12006: 116,
         20989: 339,
         15753: 199,
         487: 387,
         16016: 530,
         998: 496,
         8972: 382,
         6492: 269,
         14410: 180,
         14286: 228,
         12961: 228,
         8725: 26,
         2707: 82,
         17085: 262,
         15375: 72,
         13606: 313,
         10804: 317,
         12959: 527,
         12435: 202,
         16713: 359,
         12674: 184,
         20959: 88,
         16535: 348,
         21035: 131,
         11880: 34,
         23100: 347,
         21079: 114,
         20581: 284,
         15553: 249,
         14677: 63,
         954: 171,
         17147: 47,
         12995: 51,
         20962: 74,
         12165: 46,
         17092: 66,
         15717: 54,
         9024: 118,
         16555: 67,
         7705: 78,
         13722: 44,
         18778: 100,
         9831: 41,
         5789: 40,
         14124: 59,
         13954: 31,
         10534: 50,
         16425: 6,
         20787: 3,
         9367: 44,
         14578: 1,
         15180: 1,
         12243: 4,
         11443: 1,
         13066: 1})

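Here we pick two genes to check whether their predictions match their true labels. In the original analysis, gene 9061 was judged to be predicted correctly and gene 16425 was not.

Note that eval_pred and eval_label are flat lists aligned position by position with token_id_list. The predictions for one gene therefore have to be collected by pairing the lists with zip. The original code instead indexed them with the token ID itself (eval_pred[i] with i == 9061). That returns the single entry stored at list position 9061, repeated once for every occurrence of the gene. The output below was produced that way, so these per-gene results should be regenerated with the corrected code.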

# append all tokens into one list (same order as eval_pred / eval_label)
token_id_list = [tk for tks in token_id.values() for tk in tks]

# successful prediction
# get predictions for gene (token = 9061), pairing the lists by position
target_pred1 = [p for p, tk in zip(eval_pred, token_id_list) if tk == 9061]
print("Predicted label of gene 9061: ")
print(Counter(target_pred1))
target_label1 = [t for t, tk in zip(eval_label, token_id_list) if tk == 9061]
print("True label of gene 9061: ")
print(Counter(target_label1))

# failed prediction
# get predictions for gene (token = 16425), pairing the lists by position
target_pred2 = [p for p, tk in zip(eval_pred, token_id_list) if tk == 16425]
print("Predicted label of gene 16425: ")
print(Counter(target_pred2))
target_label2 = [t for t, tk in zip(eval_label, token_id_list) if tk == 16425]
print("True label of gene 16425: ")
print(Counter(target_label2))
Predicted label of gene 9061: 
Counter({1: 2755})
True label of gene 9061: 
Counter({1: 2755})
Predicted label of gene 16425: 
Counter({0: 6})
True label of gene 16425: 
Counter({1: 6})
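To check every gene at once rather than two by hand, the aligned lists can be aggregated per token. Here is a sketch under the assumption that token_id_list, eval_pred and eval_label are aligned position by position. That holds when token_id_dict is filled in the same cell and position order as the logits. The helper per_gene_summary and the toy inputs are hypothetical.

```python
from collections import Counter, defaultdict

def per_gene_summary(token_id_list, y_pred, y_true):
    """Per-token (gene) prediction counts, label counts and accuracy."""
    if not (len(token_id_list) == len(y_pred) == len(y_true)):
        raise ValueError("token_id_list, y_pred and y_true must be aligned position by position")
    preds = defaultdict(Counter)
    trues = defaultdict(Counter)
    n_correct = Counter()
    for tok, p, t in zip(token_id_list, y_pred, y_true):
        preds[tok][p] += 1
        trues[tok][t] += 1
        n_correct[tok] += int(p == t)
    return {
        tok: {
            "pred": dict(preds[tok]),
            "true": dict(trues[tok]),
            "accuracy": n_correct[tok] / sum(trues[tok].values()),
        }
        for tok in trues
    }

# hypothetical aligned lists: gene 501 occurs in three cells, gene 502 in one
summary = per_gene_summary([501, 502, 501, 501], [1, 0, 0, 1], [1, 1, 1, 1])
print(summary[501])
```

On the real data this would be called as per_gene_summary(token_id_list, eval_pred, eval_label), and the entries for 9061 and 16425 read off directly.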



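In summary, fine-tuning Geneformer for gene classification involves the following steps:

  1. Obtain a fine-tuning dataset together with gene-level labels, for example whether a TF is dosage-sensitive or whether it is a drug target;

  2. Load the pretrained model with BertForTokenClassification and set num_labels to the number of classes;

  3. Train on the fine-tuning dataset with a task-specific output layer (here a linear classifier on top of the BERT encoder), and evaluate the predictive performance of the fine-tuned model;

  4. Apply the fine-tuned model to new datasets for prediction.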

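In addition, the authors have recently uploaded a model fine-tuned on the cardiomyopathy single-cell data (https://huggingface.co/ctheodoris/Geneformer/tree/main/fine_tuned_models/geneformer-6L-30M_CellClassifier_cardiomyopathies_220224). You can also download this model and use it directly.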



