{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Embedding + MLP\n", "\n", "```{note}\n", "Embedding + MLP is the most classic deep learning architecture for recommendation, and many later models build on it.\n", "```" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Architecture\n", "\n", "![Embedding + MLP architecture](../images/mlp1.jpeg)\n", "\n", "Feature layer: categorical features feed into the Embedding layer, while numerical features connect directly to the Stacking layer.\n", "\n", "Embedding layer: converts each categorical feature into a dense vector.\n", "\n", "Stacking layer: concatenates the embedding vectors and the numerical features into a single vector.\n", "\n", "MLP layer: a multilayer neural network. The figure uses residual units, but a plain MLP works as well.\n", "\n", "Scoring layer: the output layer. For CTR prediction it uses a sigmoid activation." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Data preprocessing" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow import keras\n", "import rec\n", "\n", "# Load the MovieLens dataset via the rec helper module\n", "train_dataset, test_dataset = rec.load_movielens()" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " movieId userId rating timestamp label releaseYear movieGenre1 \\\n", "0 1 15555 3.0 900953740 0 1995 Adventure \n", "1 1 25912 3.5 1111631768 1 1995 Adventure \n", "2 1 29912 3.0 866820360 0 1995 Adventure \n", "3 10 17686 0.5 1195555011 0 1995 Action \n", "4 104 20158 4.0 1155357691 1 1996 Comedy \n", "... ... ... ... ... ... ... ... \n", "88822 968 26865 3.0 854092232 0 1968 Horror \n", "88823 968 8507 2.0 974709061 0 1968 Horror \n", "88824 969 16689 5.0 857854044 1 1951 Adventure \n", "88825 969 26460 2.0 1250279576 0 1951 Adventure \n", "88826 970 3033 2.0 1272394603 0 1953 Adventure \n", "\n", " movieGenre2 movieGenre3 movieRatingCount ... userRatingCount \\\n", "0 Animation Children 10759 ... 92 \n", "1 Animation Children 10759 ... 21 \n", "2 Animation Children 10759 ... 4 \n", "3 Adventure Thriller 6330 ... 35 \n", "4 NaN NaN 3954 ... 81 \n", "... ... ... ... ... ... \n", "88822 Sci-Fi Thriller 1824 ... 94 \n", "88823 Sci-Fi Thriller 1824 ... 5 \n", "88824 Comedy Romance 2380 ... 97 \n", "88825 Comedy Romance 2380 ... 55 \n", "88826 Comedy Crime 98 ... 100 \n", "\n", " userAvgReleaseYear userReleaseYearStddev userAvgRating \\\n", "0 1992 8.98 3.86 \n", "1 1988 14.09 3.48 \n", "2 1995 0.50 3.00 \n", "3 1992 8.35 2.97 \n", "4 1991 8.70 3.60 \n", "... ... ... ... \n", "88822 1991 12.23 3.35 \n", "88823 1994 0.89 2.00 \n", "88824 1992 9.95 3.53 \n", "88825 1990 11.78 2.73 \n", "88826 1985 17.64 3.67 \n", "\n", " userRatingStddev userGenre1 userGenre2 userGenre3 userGenre4 \\\n", "0 0.74 Drama Comedy Thriller Action \n", "1 1.28 Action Comedy Romance Adventure \n", "2 0.00 NaN NaN NaN NaN \n", "3 1.48 Comedy Drama Adventure Action \n", "4 0.72 Thriller Drama Action Crime \n", "... ... ... ... ... ... \n", "88822 0.85 Drama Thriller Comedy Crime \n", "88823 1.00 NaN NaN NaN NaN \n", "88824 0.82 Drama Comedy Crime Romance \n", "88825 1.42 Thriller Crime Drama Comedy \n", "88826 0.89 Drama Romance Comedy Thriller \n", "\n", " userGenre5 \n", "0 Crime \n", "1 Thriller \n", "2 NaN \n", "3 Thriller \n", "4 Adventure \n", "... ... \n", "88822 Romance \n", "88823 NaN \n", "88824 Thriller \n", "88825 Sci-Fi \n", "88826 Crime \n", "\n", "[88827 rows x 27 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rec.get_movielens_df()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Categorical features\n", "\n", "tf.feature_column.categorical_column_with_vocabulary_list: given a vocabulary, maps each value to its index in that vocabulary (a one-hot encoding).\n", "\n", "tf.feature_column.embedding_column: converts the one-hot column into a trainable dense embedding." ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# Movie genres\n", "genre_vocab = ['Film-Noir', 'Action', 'Adventure', 'Horror', 'Romance', 'War',\n", "               'Comedy', 'Western', 'Documentary', 'Sci-Fi', 'Drama', 'Thriller',\n", "               'Crime', 'Fantasy', 'Animation', 'IMAX', 'Mystery', 'Children', 'Musical']\n", "# Genre feature columns and the vocabulary each one uses\n", "GENRE_FEATURES = {\n", "    'userGenre1': genre_vocab,\n", "    'userGenre2': genre_vocab,\n", "    'userGenre3': genre_vocab,\n", "    'userGenre4': genre_vocab,\n", "    'userGenre5': genre_vocab,\n", "    'movieGenre1': genre_vocab,\n", "    'movieGenre2': genre_vocab,\n", "    'movieGenre3': genre_vocab\n", "}\n", "\n", "categorical_columns = []\n", "for feature, vocab in GENRE_FEATURES.items():\n", "    # First map the genre string to a one-hot categorical column\n", "    cat_col = tf.feature_column.categorical_column_with_vocabulary_list(\n", "        key=feature, vocabulary_list=vocab)\n", "    # Then wrap it in a 10-dimensional embedding\n", "    emb_col = tf.feature_column.embedding_column(cat_col, 10)\n", "    categorical_columns.append(emb_col)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "tf.feature_column.categorical_column_with_identity: uses the integer ID itself as the category index and one-hot encodes it. num_buckets sets the range of valid IDs, [0, num_buckets), so it must be larger than the largest ID." ] },
{ "cell_type": "code", 
"execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# movie id embedding feature\n", "# movieId must lie in [0, num_buckets)\n", "movie_col = tf.feature_column.categorical_column_with_identity(key='movieId', num_buckets=1001)\n", "movie_emb_col = tf.feature_column.embedding_column(movie_col, 10)\n", "categorical_columns.append(movie_emb_col)\n", "\n", "# user id embedding feature (userId must lie in [0, 30001))\n", "user_col = tf.feature_column.categorical_column_with_identity(key='userId', num_buckets=30001)\n", "user_emb_col = tf.feature_column.embedding_column(user_col, 10)\n", "categorical_columns.append(user_emb_col)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Numerical features\n", "\n", "Numerical features only need tf.feature_column.numeric_column." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# all numerical features\n", "numerical_columns = [tf.feature_column.numeric_column('releaseYear'),\n", "                     tf.feature_column.numeric_column('movieRatingCount'),\n", "                     tf.feature_column.numeric_column('movieAvgRating'),\n", "                     tf.feature_column.numeric_column('movieRatingStddev'),\n", "                     tf.feature_column.numeric_column('userRatingCount'),\n", "                     tf.feature_column.numeric_column('userAvgRating'),\n", "                     tf.feature_column.numeric_column('userRatingStddev')]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Model definition" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# embedding + MLP model architecture\n", "model = tf.keras.Sequential([\n", "    # Feature, Embedding and Stacking layers: DenseFeatures takes the list of\n", "    # feature columns, looks up the embeddings and concatenates them with the\n", "    # numerical features into one dense input tensor\n", "    tf.keras.layers.DenseFeatures(numerical_columns + categorical_columns),\n", "    # MLP layer: a plain two-layer MLP (no residual units)\n", "    tf.keras.layers.Dense(128, activation='relu'),\n", "    tf.keras.layers.Dense(128, activation='relu'),\n", "    # Scoring layer: sigmoid output for binary (CTR-style) prediction\n", "    tf.keras.layers.Dense(1, activation='sigmoid'),\n", "])" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# compile the model, set loss function, optimizer and evaluation metrics\n", "# the two AUC metrics are logged as auc (ROC curve) and auc_1 (PR curve)\n", "model.compile(\n", "    loss='binary_crossentropy',\n", "    optimizer='adam',\n", "    metrics=['accuracy', tf.keras.metrics.AUC(curve='ROC'), tf.keras.metrics.AUC(curve='PR')])" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/5\n", "7403/7403 [==============================] - 20s 3ms/step - loss: 0.4824 - accuracy: 0.7690 - auc: 0.8447 - auc_1: 0.8684\n", "Epoch 2/5\n", "7403/7403 [==============================] - 23s 3ms/step - loss: 0.4708 - accuracy: 0.7735 - auc: 0.8526 - auc_1: 0.8767\n", "Epoch 3/5\n", "7403/7403 [==============================] - 29s 4ms/step - loss: 0.4628 - accuracy: 0.7766 - auc: 0.8582 - auc_1: 0.8824\n", "Epoch 4/5\n", "7403/7403 [==============================] - 28s 4ms/step - loss: 0.4577 - accuracy: 0.7793 - auc: 0.8615 - auc_1: 0.8863\n", "Epoch 5/5\n", "7403/7403 [==============================] - 22s 3ms/step - loss: 0.4525 - accuracy: 0.7809 - auc: 0.8648 - auc_1: 0.8904\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train the model\n", "model.fit(train_dataset, epochs=5)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation and prediction" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1870/1870 [==============================] - 3s 1ms/step - loss: 0.6413 - accuracy: 0.6877 - auc: 0.7420 - auc_1: 0.7672\n", "Test Loss 0.641334, Test Accuracy 0.687656\n", "Test ROC AUC 0.742031, Test PR AUC 0.767180\n" ] } ], "source": [ "# evaluate the model\n", "test_loss, test_accuracy, test_roc_auc, test_pr_auc = model.evaluate(test_dataset)\n", "print('Test Loss {:.6f}, Test Accuracy {:.6f}'.format(test_loss, test_accuracy))\n", "print('Test ROC AUC {:.6f}, Test PR AUC {:.6f}'.format(test_roc_auc, test_pr_auc))" ] },
{ 
"cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "prediction: 0.92 label: 0\n", "prediction: 0.01 label: 0\n", "prediction: 0.90 label: 1\n", "prediction: 0.15 label: 1\n", "prediction: 0.44 label: 0\n", "prediction: 0.53 label: 1\n", "prediction: 0.54 label: 0\n", "prediction: 0.28 label: 0\n", "prediction: 0.31 label: 1\n" ] } ], "source": [ "# print some prediction results\n", "predictions = model.predict(test_dataset)\n", "# Compare the first 9 predictions with the labels of the first test batch\n", "# (assumes test_dataset yields (features, label) batches in a fixed order)\n", "first_batch_labels = next(iter(test_dataset))[1]\n", "for prediction, label in zip(predictions[:9], first_batch_labels[:9]):\n", "    print(\"prediction: {:.2f}\".format(prediction[0]), \"label: {}\".format(label))" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.3" } }, "nbformat": 4, "nbformat_minor": 4 }