美文网首页
TabNet-神经网络处理表格数据实战

TabNet-神经网络处理表格数据实战

作者: 雪糕遇上夏天 | 来源:发表于2021-09-23 11:20 被阅读0次

    我们知道神经网络在图片、信号等领域大放异彩。但在表格数据领域,基本还是树模型的主场。今天我们介绍下TabNet的使用方式,这是一个能够很好的处理tabular数据的神经网络模型。
    下面我们介绍下TabNet的使用。

    1. 安装

    根据官方介绍,安装tabnet之前需要Tensorflow 2.0+版本和Tensorflow-dataset(非必须)。确保Tensorflow 2.0+正确安装之后,就可以安装TabNet了。

    pip install tabnet[cpu]
    pip install tabnet[gpu]
    

    就像TensorFlow有cpu版和gpu版一样,TabNet也有cpu版和gpu版,可以按需选择。

    2. 使用

    tabnet包提供了TabNetClassifier和TabNetRegression分别用于处理分类任务和回归任务。以TabNetClassification为例,他是在TabNet模块的基础上加入了处理分类任务的全连接层(即:激活函数为softmax)。

    from tabnet import TabNetClassifier
    

    我们使用iris数据,做个简单的分类任务

    import tensorflow_datasets as tfds
    def transform(ds):
        features = tf.unstack(ds['features'])
        labels = ds['label']
    
        x = dict(zip(col_names, features))
        y = tf.one_hot(labels, 3)
        return x, y
    
    col_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
    ds_full = tfds.load(name="iris", split=tfds.Split.TRAIN)
    ds_full = ds_full.shuffle(150, seed=0)
    
    ds_train = ds_full.take(train_size)
    ds_train = ds_train.map(transform)
    ds_train = ds_train.batch(32)
    
    ds_test = ds_full.skip(train_size)
    ds_test = ds_test.map(transform)
    ds_test = ds_test.batch(32)
    

    需要注意的是要把特征数据转化成map类型,因为模型的第一个参数即为特征的参数名称。
    iris共有150条数据,每个数据有4个特征。所以我们设置如下:

    col_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
    feature_columns = []
    for col_name in col_names:
        feature_columns.append(tf.feature_column.numeric_column(col_name))
    model = TabNetClassifier(feature_columns, num_classes=3, feature_dim=8, output_dim=4)
    

    至此模型就创建好了,下面就是训练的部分:

    lr = tf.keras.optimizers.schedules.ExponentialDecay(0.01, decay_steps=100, decay_rate=0.9)
    optimizer = tf.keras.optimizers.Adam(lr)
    model.compile(optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    model.fit(ds_train, epochs=100, validation_data=ds_test, verbose=2)
    
    

    相关文章

      网友评论

          本文标题:TabNet-神经网络处理表格数据实战

          本文链接:https://www.haomeiwen.com/subject/raingltx.html