{ "cells": [ { "cell_type": "markdown", "id": "0c97d6cd", "metadata": {}, "source": [ "# Random Forest\n", "\n", "For background on random forests, see the following.\n", "[https://scikit-learn.org/stable/modules/ensemble.html#forest](https://scikit-learn.org/stable/modules/ensemble.html#forest)" ] }, { "cell_type": "markdown", "id": "beaeafb3", "metadata": {}, "source": [ "**Loading the data and modules**\n", "\n", "Load the data used for training." ] }, { "cell_type": "code", "execution_count": 1, "id": "aa65ec88", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from sklearn import model_selection\n", "\n", "data = pd.read_csv(\"input/pn_same_judge_preprocessed.csv\")\n", "train, test = model_selection.train_test_split(data, test_size=0.1, random_state=0)" ] }, { "cell_type": "code", "execution_count": 2, "id": "a8500c05", "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import Pipeline\n", "from sklearn.feature_extraction.text import TfidfVectorizer\n", "from sklearn.metrics import PrecisionRecallDisplay" ] }, { "cell_type": "markdown", "id": "0c69e19a", "metadata": {}, "source": [ "## Decision Tree\n", "\n", "A random forest is an ensemble learning method that applies\n", "[bagging](https://ja.wikipedia.org/wiki/バギング)\n", "to decision trees,\n", "so let's start with a single decision tree.\n", "\n", "To train a decision tree, use\n", "[sklearn.tree.DecisionTreeClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html).\n", "Here we set the `max_depth` and `min_samples_leaf` parameters for regularization." ] }, { "cell_type": "code", "execution_count": 3, "id": "d0497415", "metadata": {}, "outputs": [], "source": [ "from sklearn.tree import DecisionTreeClassifier" ] }, { "cell_type": "code", "execution_count": 4, "id": "817bb85b", "metadata": {}, "outputs": [], "source": [ "pipe_dt = Pipeline([\n", "    (\"vect\", TfidfVectorizer(tokenizer=str.split)),\n", "    (\"clf\", DecisionTreeClassifier(max_depth=2, min_samples_leaf=10, random_state=0)),\n", "])" ] }, { "cell_type": "code", "execution_count": 5, "id": "77ec26e9", "metadata": {}, "outputs": [ { "data": { 
"text/html": [ "
Pipeline(steps=[('vect',\n",
       "                 TfidfVectorizer(tokenizer=<method 'split' of 'str' objects>)),\n",
       "                ('clf',\n",
       "                 DecisionTreeClassifier(max_depth=2, min_samples_leaf=10,\n",
       "                                        random_state=0))])
" ], "text/plain": [ "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf',\n", " DecisionTreeClassifier(max_depth=2, min_samples_leaf=10,\n", " random_state=0))])" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipe_dt.fit(train[\"tokens\"], train[\"label_num\"])" ] }, { "cell_type": "code", "execution_count": 6, "id": "6fd0041d", "metadata": {}, "outputs": [], "source": [ "score_dt = pipe_dt.predict_proba(test[\"tokens\"])[:,1]" ] }, { "cell_type": "code", "execution_count": 7, "id": "940d6cec", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "PrecisionRecallDisplay.from_predictions(\n", " y_true=test[\"label_num\"],\n", " y_pred=score_dt,\n", " name=\"Decision Tree\",\n", ")" ] }, { "cell_type": "markdown", "id": "45f8e117", "metadata": {}, "source": [ "木を表示してみましょう。\n", "\n", "```{note}\n", "sklearn 0.21 から\n", "[plot_tree](https://scikit-learn.org/stable/modules/generated/sklearn.tree.plot_tree.html#sklearn.tree.plot_tree)\n", "で決定木を表示できるようになりました。\n", "\n", "以下のドキュメントが参考になります。\n", "\n", "* [https://scikit-learn.org/stable/auto_examples/tree/plot_iris_dtc.html](https://scikit-learn.org/stable/auto_examples/tree/plot_iris_dtc.html)\n", "* [https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html](https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html)\n", "```" ] }, { "cell_type": "code", "execution_count": 8, "id": "d7fb67cd", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "from sklearn.tree import plot_tree\n", "\n", "plot_tree(pipe_dt[\"clf\"], filled=True)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 9, "id": "33dd961a", "metadata": {}, "outputs": [], "source": [ "words = pipe_dt[\"vect\"].get_feature_names_out()" ] }, { "cell_type": "code", "execution_count": 10, "id": "0638c6a1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "('残念', 'が', 'ます')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "words[2159], words[198], words[518]" ] }, { "cell_type": "markdown", "id": "ff5bd725", "metadata": {}, "source": [ "## ランダムフォレスト\n", "\n", "ランダムフォレストは決定木をバギングしたモデルです。\n", "`sklearn.ensemble.BaggingClassifier` を使うと次のように実装できます。" ] }, { "cell_type": "code", "execution_count": 11, "id": "49aa38a4", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import BaggingClassifier" ] }, { "cell_type": "code", "execution_count": 12, "id": "1f3055cc", "metadata": {}, "outputs": [], "source": [ "bagging = BaggingClassifier(\n", " DecisionTreeClassifier(splitter=\"random\"), # splitterはrandomに設定して、特徴量をランダムに探索する\n", " n_estimators=1000,\n", " random_state=0,\n", " n_jobs=-1, # 全てのCPUを使う\n", ")\n", "pipe_bagging = Pipeline([\n", " (\"vect\", TfidfVectorizer(tokenizer=str.split)),\n", " (\"clf\", bagging),\n", "])" ] }, { "cell_type": "code", "execution_count": 13, "id": "4101479e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('vect',\n",
       "                 TfidfVectorizer(tokenizer=<method 'split' of 'str' objects>)),\n",
       "                ('clf',\n",
       "                 BaggingClassifier(base_estimator=DecisionTreeClassifier(splitter='random'),\n",
       "                                   n_estimators=1000, n_jobs=-1,\n",
       "                                   random_state=0))])
" ], "text/plain": [ "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf',\n", " BaggingClassifier(base_estimator=DecisionTreeClassifier(splitter='random'),\n", " n_estimators=1000, n_jobs=-1,\n", " random_state=0))])" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipe_bagging.fit(train[\"tokens\"], train[\"label_num\"])" ] }, { "cell_type": "code", "execution_count": 14, "id": "4bde2e58", "metadata": {}, "outputs": [], "source": [ "score_bagging = pipe_bagging.predict_proba(test[\"tokens\"])[:,1]" ] }, { "cell_type": "code", "execution_count": 15, "id": "7022429d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "PrecisionRecallDisplay.from_predictions(\n", " y_true=test[\"label_num\"],\n", " y_pred=score_bagging,\n", " name=\"RandomForest (Bagging)\",\n", ")" ] }, { "cell_type": "markdown", "id": "e6fc52c3", "metadata": {}, "source": [ "BuggingClassifier を使って実装しなくても、scikit-learn は\n", "[sklearn.ensembleRandomForestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html)\n", "を提供しています。\n", "ランダムフォレストを使う場合は、こちらを使う方がいいでしょう。" ] }, { "cell_type": "code", "execution_count": 16, "id": "4838d9c6", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier" ] }, { "cell_type": "code", "execution_count": 17, "id": "84cde814", "metadata": {}, "outputs": [], "source": [ "random_forest = RandomForestClassifier(\n", " n_estimators=1000,\n", " random_state=0,\n", " n_jobs=-1,\n", ")\n", "pipe_rf = Pipeline([\n", " (\"vect\", TfidfVectorizer(tokenizer=str.split)),\n", " (\"clf\", random_forest),\n", "])" ] }, { "cell_type": "code", "execution_count": 18, "id": "2257ae6b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('vect',\n",
       "                 TfidfVectorizer(tokenizer=<method 'split' of 'str' objects>)),\n",
       "                ('clf',\n",
       "                 RandomForestClassifier(n_estimators=1000, n_jobs=-1,\n",
       "                                        random_state=0))])
" ], "text/plain": [ "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf',\n", " RandomForestClassifier(n_estimators=1000, n_jobs=-1,\n", " random_state=0))])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipe_rf.fit(train[\"tokens\"], train[\"label_num\"])" ] }, { "cell_type": "code", "execution_count": 19, "id": "05596c23", "metadata": {}, "outputs": [], "source": [ "score_rf = pipe_rf.predict_proba(test[\"tokens\"])[:,1]" ] }, { "cell_type": "code", "execution_count": 20, "id": "842b14d2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "PrecisionRecallDisplay.from_predictions(\n", " y_true=test[\"label_num\"],\n", " y_pred=score_rf,\n", " name=\"RandomForest\",\n", ")" ] }, { "cell_type": "markdown", "id": "cac39e81", "metadata": {}, "source": [ "## 特徴量の重要度\n", "\n", "ランダムフォレストでは、 `feature_importances_` 属性を見ることで、 素性の重要度を知ることができます。\n", "モデルの精度を改善していくステップで重要な情報となります。" ] }, { "cell_type": "code", "execution_count": 21, "id": "59d47fef", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1.62063358e-03, 3.31198460e-06, 1.33262805e-04, ...,\n", " 8.58042819e-05, 3.38860088e-05, 1.90558795e-04])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipe_rf[\"clf\"].feature_importances_" ] }, { "cell_type": "code", "execution_count": 22, "id": "1a17970c", "metadata": {}, "outputs": [], "source": [ "importances = pipe_rf[\"clf\"].feature_importances_" ] }, { "cell_type": "code", "execution_count": 23, "id": "0b8a82e2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[(0.06262130422316252, '残念'),\n", " (0.03232310022880119, 'が'),\n", " (0.029163227525310448, '狭い'),\n", " (0.02537735713464425, 'ない'),\n", " (0.024742874383347428, '。'),\n", " (0.020948336676578652, 'ぬ'),\n", " (0.020610069080122854, '少し'),\n", " (0.0188537174145827, 'た'),\n", " (0.016496863123408128, 'です'),\n", " (0.015541715069967977, '悪い')]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(sorted(zip(importances, words), reverse=True))[:10]" ] }, { "cell_type": "markdown", "id": "f364cc6b", "metadata": {}, "source": [ "この結果を見ると、「が」「です」といった分類に効果があるとは思えない単語が重要な素性として選ばれてしまっていることが分かります。\n", "そこで、例えば次の手としてストップワードを定義して素性から取り除く方針が思いつきます。\n", "\n", "このように、素性を選択するためにランダムフォレストをまずは適用してみるという方針も可能です。" ] }, { "cell_type": "markdown", "id": "97cfea08", "metadata": {}, "source": [ "## 勾配ブースティング\n", "\n", "最後に、勾配ブースティングの手法を見ておきます。\n", "\n", 
"勾配ブースティングを分類問題に適用するには\n", "[sklearn.ensemble.GradientBoostingClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.GradientBoostingClassifier.html)\n", "を使います。\n", "\n", "並列化できず、したがって `n_jobs` パラメータは指定できないことに注意してください。\n", "\n" ] }, { "cell_type": "code", "execution_count": 24, "id": "59d68f3a", "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import GradientBoostingClassifier" ] }, { "cell_type": "code", "execution_count": 25, "id": "5853be20", "metadata": {}, "outputs": [], "source": [ "gb = GradientBoostingClassifier(\n", " n_estimators=1000,\n", " random_state=0,\n", " learning_rate=0.1,\n", ")\n", "\n", "pipe_gb = Pipeline([\n", " (\"vect\", TfidfVectorizer(tokenizer=str.split)),\n", " (\"clf\", gb),\n", "])" ] }, { "cell_type": "code", "execution_count": 26, "id": "4336c198", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('vect',\n",
       "                 TfidfVectorizer(tokenizer=<method 'split' of 'str' objects>)),\n",
       "                ('clf',\n",
       "                 GradientBoostingClassifier(n_estimators=1000,\n",
       "                                            random_state=0))])
" ], "text/plain": [ "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf',\n", " GradientBoostingClassifier(n_estimators=1000,\n", " random_state=0))])" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipe_gb.fit(train[\"tokens\"], train[\"label_num\"])" ] }, { "cell_type": "code", "execution_count": 27, "id": "2ed8ce08", "metadata": {}, "outputs": [], "source": [ "score_gb = pipe_gb.predict_proba(test[\"tokens\"])[:,1]" ] }, { "cell_type": "code", "execution_count": 28, "id": "73680993", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "PrecisionRecallDisplay.from_predictions(\n", " y_true=test[\"label_num\"],\n", " y_pred=score_gb,\n", " name=\"RandomForest\",\n", ")" ] }, { "cell_type": "markdown", "id": "e8cd036d", "metadata": {}, "source": [ "```{note}\n", "正則化としてearly stoppingを使う場合には `n_iter_no_change` と `validation_fraction` を設定します。\n", "\n", " gb = GradientBoostingClassifier(\n", " n_estimators=1000,\n", " random_state=0,\n", " learning_rate=0.1,\n", " # early stoppingのための設定\n", " validation_fraction=0.1,\n", " n_iter_no_change=3,\n", " )\n", "\n", "```" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 5 }