美文网首页推荐系统
Spark使用word2vec训练item2vec实现内容相关推

Spark使用word2vec训练item2vec实现内容相关推

作者: 蚂蚁学Python | 来源:发表于2019-08-28 23:36 被阅读0次

    之前使用spark als训练协同过滤,然后导出itemvectors做相似度计算,后来学到了可以用word2vec实现item2vec的训练效果貌似更好,试了一下果然不错;

    spark版本:2.3.1,开发语言为JAVA

    几大步骤

    读取查看、点击、播放等行为数据,我用的是播放数据;

    数据整理成(userid, itemid, playcnt)的形式,这个数据可能是聚合N天得到的;

    过滤掉playcnt为小于3的数据,我把这些过滤掉,觉得这个数据没有贡献;

    按照userid聚合,得到(userid, list(itemid))的形式;

    训练word2vec;

    导出model.vectors(),里面包括word和对应的向量vector,其中word其实就是itemid

    crossjoin计算两两相似度,取相似度TOP N;

    将结果存入mysql,后续可以加载到REDIS实现实时相似推荐;

    代码实现

    读取播放数据:

    Dataset<Row> playDatas = spark.sql(

            "select user_id, item_id, play_cnt " +

                    "from hive_play_table group by user_id, item_id");

    做数据按userid聚合:

    playDatas = playDatas

            // 删除掉只播放3次以下的数据

            .filter("play_cnt>2")

            // 按userid聚合

            .groupBy("user_id")

            .agg(collect_list("item_id").as("item_ids"))

            // 至少操作过2个元素

            .where(size(col("item_ids")).geq(2));

    训练word2vec:

    Word2Vec word2Vec = new Word2Vec()

            .setInputCol("item_ids")

            .setOutputCol("word2vec_result")

            .setVectorSize(50)

            .setMinCount(0)

            .setMaxIter(50)

            .setSeed(123);

    Word2VecModel word2VecModel = word2Vec.fit(playDatas);

    实现df的cross join:

    Dataset<Row> vectorsA = word2VecModel

            .getVectors()

            .select(

                    col("word").as("itemIdA"),

                    col("vector").as("vectorA"));

    Dataset<Row> vectorsB = word2VecModel

            .getVectors()

            .select(

                    col("word").as("itemIdB"),

                    col("vector").as("vectorB"));

    // self cross join

    Dataset<Row> crossDatas = vectorsA.crossJoin(vectorsB);

    注册余弦相似度计算函数:

    spark.udf().register(

            "vectorCosinSim",

            new UDF2<Vector, Vector, Double>() {

                @Override

                public Double call(Vector vectora, Vector vectorb) throws Exception {

                    return SimilarityUtils.cosineSimilarity(vectora, vectorb);

                }

            },

            DataTypes.DoubleType

    );

    其中调用的余弦相似度计算函数,使用JAVA实现:

    public static double cosineSimilarity(Vector featuresLeft, Vector featuresRight) {

        double[] dataLeft = featuresLeft.toArray();

        List<Float> lista = new ArrayList<>();

        if (dataLeft.length > 0) {

            for (double d : dataLeft) {

                lista.add((float) d);

            }

        }

        double[] dataRight = featuresRight.toArray();

        List<Float> listb = new ArrayList<>();

        if (dataRight.length > 0) {

            for (double d : dataRight) {

                listb.add((float) d);

            }

        }

        return cosineSimilarity(lista, listb);

    }

    实现相似度计算,并过滤掉自身和自身的计算:

    crossDatas = crossDatas

            .withColumn(

                    "cosineSimilarity", callUDF(

                            "vectorCosinSim", col("vectorA"), col("vectorB")))

            .select("itemIdA", "itemIdB", "cosineSimilarity")

            .filter(col("itemIdA").notEqual(col("itemIdB")));

    使用 spark的Window,提取每个group的topn:

    // 按照相似度倒序排列取TOP 300

    WindowSpec windowSpec = Window.partitionBy("itemIdA").orderBy(col("cosineSimilarity").desc());

    crossDatas = crossDatas

            .withColumn("simRank", rank().over(windowSpec))

            .where(col("simRank").leq(200));

    将数据聚合成每个Item的推荐列表的形式:

    crossDatas = crossDatas

            .groupBy("itemIdA")

            .agg(

                    collect_list("cosineSimilarity").as("columnSims"),

                    collect_list("itemIdB").as("itemIds")

            ).select(

                    col("itemIdA").as("item_id").cast(DataTypes.LongType),

                    col("columnSims").as("column_sims").cast(DataTypes.StringType),

                    col("itemIds").as("item_ids").cast(DataTypes.StringType)

            );

    将数据覆盖写入MySQL:

    crossDatas.write().mode(SaveMode.Overwrite).jdbc(

            MysqlConfig.ONLINE_MYSQL_MASTER_URL,

            "item2vec_sims",

            MysqlConfig.getOnlineProperties()

    );

    在数据库中,我们根据item_id,提取到item_ids,可以用于直接的推荐;其中column_sims也记录了对应的相似度权重,如果需要加权的话也可以直接提取;

    相关文章

      网友评论

        本文标题:Spark使用word2vec训练item2vec实现内容相关推

        本文链接:https://www.haomeiwen.com/subject/fanvectx.html