说明
生存时间:死亡时间-出生时间,aft采用log函数描述生存时间
用途
顾名思义,根据已知的存活时间、存活状态(死亡的,活着或失联的)、特征属性,预测生存时间
from pyspark.ml.regression import AFTSurvivalRegression
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
spark = SparkSession \
.builder \
.appName("AFTSurvivalRegressionExample") \
.getOrCreate()
training = spark.createDataFrame([
(1.218, 1.0, Vectors.dense(1.560, -0.605)),# 稠密向量:当成一维数组就行
(2.949, 0.0, Vectors.dense(0.346, 2.158)),
(3.627, 0.0, Vectors.dense(1.380, 0.231)),
(0.273, 1.0, Vectors.dense(0.520, 1.151)),
(4.199, 0.0, Vectors.dense(0.795, -0.226))], ["label", "censor", "features"])
quantileProbabilities = [0.3, 0.6]# 分位数概率
# 前4个参数使用默认值
# 此模型需要3列数据,predictionCol和quantilesCol为输出列
# label:表示存活时间
# censor:表示生存状态。1为死亡,0为存续或失访
# features:表示特征属性值(比如患病的各个特征值)
aft = AFTSurvivalRegression(featuresCol='features',labelCol='label',censorCol='censor',predictionCol='prediction',quantileProbabilities=quantileProbabilities,
quantilesCol="quantiles")
model = aft.fit(training)
# Print the coefficients, intercept and scale parameter for AFT survival regression
print("Coefficients: " + str(model.coefficients))
print("Intercept: " + str(model.intercept))
print("Scale: " + str(model.scale))
# prediction列:表示生存时间的预测值
# quantiles列:表示给定分位数值对应的生存时间
# 预测值好像要经过log变换一下
model.transform(training).show(truncate=False)# 这个样本数据特不准
网友评论