{ "cells": [ { "cell_type": "markdown", "id": "07463a5f", "metadata": {}, "source": [ "# ロジスティック回帰\n", "\n", "ロジスティック回帰全般については以下を参照してください。\n", "[https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression](https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression)" ] }, { "cell_type": "markdown", "id": "22ae72e6", "metadata": {}, "source": [ "**データとモジュールのロード**" ] }, { "cell_type": "code", "execution_count": 1, "id": "4f42c570", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from sklearn import model_selection\n", "\n", "# Load the dataset from the preprocessed CSV file.\n", "data = pd.read_csv(\"input/pn_same_judge_preprocessed.csv\")\n", "\n", "# Hold out 10% of the rows as the test set; random_state=0 fixes the shuffle\n", "# so the same split is produced on every run.\n", "train, test = model_selection.train_test_split(\n", "    data,\n", "    test_size=0.1,\n", "    random_state=0,\n", ")" ] }, { "cell_type": "code", "execution_count": 2, "id": "55fdf55a", "metadata": {}, "outputs": [], "source": [ "from sklearn.feature_extraction.text import TfidfVectorizer\n", "from sklearn.metrics import ConfusionMatrixDisplay, PrecisionRecallDisplay\n", "from sklearn.pipeline import Pipeline" ] }, { "cell_type": "markdown", "id": "310d666d", "metadata": {}, "source": [ "## LogisticRegression" ] }, { "cell_type": "markdown", "id": "ea30df74", "metadata": {}, "source": [ "[sklearn.linear_model.LogisticRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html)\n", "を使います。" ] }, { "cell_type": "code", "execution_count": 3, "id": "79cc90a0", "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression" ] }, { "cell_type": "code", "execution_count": 4, "id": "bb3dbb21", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('vect',\n",
       "                 TfidfVectorizer(tokenizer=<method 'split' of 'str' objects>)),\n",
       "                ('clf', LogisticRegression())])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf', LogisticRegression())])" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# str.split tokenizes each document on whitespace.\n", "# token_pattern=None: the default regex is never used once a custom tokenizer is\n", "# supplied, and recent scikit-learn versions emit a UserWarning if it is left set.\n", "pipe = Pipeline([\n", "    (\"vect\", TfidfVectorizer(tokenizer=str.split, token_pattern=None)),\n", "    (\"clf\", LogisticRegression()),\n", "])\n", "\n", "pipe.fit(train[\"tokens\"], train[\"label_num\"])" ] }, { "cell_type": "code", "execution_count": 5, "id": "441429d6", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "pred = pipe.predict(test[\"tokens\"])\n", "ConfusionMatrixDisplay.from_predictions(y_true=test[\"label_num\"], y_pred=pred)" ] }, { "cell_type": "code", "execution_count": 6, "id": "dd63ab0c", "metadata": {}, "outputs": [], "source": [ "score = pipe.predict_proba(test[\"tokens\"])[:,1]" ] }, { "cell_type": "code", "execution_count": 7, "id": "2b58a9c2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAApQElEQVR4nO3de5xVdb3/8dc7QEFEMIEykEuoGQKSjhIHOV5KIzXJK4Ke1FTUvB1LT/Q7hqKdn1pG5pHzA1O8VN7STFQSPYmXyguDjSgYSUA6QEcOgty8gZ/fH2sNbebCXjPMnj2z9/v5eOzH7HX/rGHYn72+V0UEZmZWvj5R7ADMzKy4nAjMzMqcE4GZWZlzIjAzK3NOBGZmZa59sQNorO7du0e/fv2KHYaZWZsyd+7c/42IHvVta3OJoF+/flRWVhY7DDOzNkXS3xra5qIhM7My50RgZlbmnAjMzMqcE4GZWZlzIjAzK3MFSwSSpkt6W9JrDWyXpJskLZI0T9L+hYrFzMwaVsgngjuAUdvY/lVgr/Q1Hvh/BYzFzMwaULB+BBHxrKR+29hlNHBXJONgvyCpm6TdI2JFIeKZ9Mh8FixfW4hTF8Xoob0YN6xPscMwsxJQzDqCXsBbOcvV6bo6JI2XVCmpcuXKlS0SXGu2YMVaHq5aVuwwzKxEtImexRFxC3ALQEVFRZNm0rnya/s2a0zFNGba88UOwcxKSDGfCJYBe+Qs907XmZlZCypmIpgBfCNtPfRF4N1C1Q+YmVnDClY0JOke4FCgu6Rq4EqgA0BETAVmAkcBi4CNwJmFisXMzBpWyFZDY/NsD+CCQl3fzMyycc9iM7My50RgZlbmnAjMzMqcE4GZWZlzIjAzK3NOBGZmZc6JwMyszLWJsYasuO5+8c06g9x59FOz0uFEYFup70P/xSXvADCs/yeBZPRTwInArEQ4EZSxLB/6Ne9znwA8+qlZaXEiKCO1P/izfOibWelzIihRWb7t+0PfzMCJoGT4276ZNZUTQYl4uGoZC1asZeDuuwD+0Dez7JwI2qCab/u5lbY1SeC+c4cXK6wmcdNUs+JzIigRA3ffhdFDexU7jG1y01Sz1smJoA06eM/uAPzi7GFFjmTbmlJv4aapZi3PiaANao0JwK2UzNqugiYCSaOAnwLtgFsj4rpa2/sC04EewDvAaRFRXciYrDBqV1aDP/jN2opCTl7fDpgCHAFUA3MkzYiIBTm73QDcFRF3SjocuBb4l0LFZM2jlCqrzaywo48eBCyKiMUR8SFwLzC61j4DgafS97Pr2W5tRFuorDaz+hWyaKgX8FbOcjVQu3D7FeB4kuKj44AuknaLiFW5O0kaD4wH6NPHxQzF1lYqq80sm2JXFl8G3Czp
DOBZYBmwufZOEXELcAtARUVFtGSAVpcTgFlpKWQiWAbskbPcO123RUQsJ3kiQNLOwAkRsaaAMZmZWS2FTARzgL0k9SdJAKcA43J3kNQdeCciPga+R9KCyKzF1NfstT5u/WSlrGCJICI2SboQmEXSfHR6RMyXdDVQGREzgEOBayUFSdHQBYWKxyzr/Au1ubezlbqC1hFExExgZq11E3PePwA8UMgYrHw114isTe3t3NDThp8urLUpdmWxWcG05IisWZ82/HRhrVGmRCCpJzAC+AzwHvAaSfHOxwWMzSyz+j6IC9nJralPGx5LyVqjbSYCSYcBE4BPAn8C3gY6Al8HBkh6APhxRKwtcJxmW8nyQdxcndzq60ldyHGUXKRkLS3fE8FRwDkR8WbtDZLaA8eQDCHxYAFiMwNa54B2zXU9FylZa7DNRBARl29j2ybgN80dkJW3LN++a9631Ad/c/akrn1/LlKy1qDJlcWSzoyI25szGLP6FHsU00L2pC72vZnB9rUamgQ4EVhBlOoops31dFHIegTXUZSffJXF8xraBHyq+cOxclfzQVmqmpoAshQpNbUeIUvFu+soSlu+J4JPAV8BVtdaL+CPBYnIypoHtMumqfUITa14dx1FacuXCB4Fdo6IqtobJD1diIDMrK6mFCllbZHkegpTRNsa1bmioiIqKyuLHYZZq9NvwmPAPz7kGxpHqSkf+rXP3dTzWPFImhsRFfVt8xATZiWqkN/0XWdQWpwIzEpEIWeOq31u1xmUFhcNmVmj1VdUBC4uas22VTRUyMnrzayMLFixNtMkP9b6ZC4aknRLRIxvaNnMykd9xVAuLmq7GlNHMC3PspmVifrqIeobJwpcXNQWZC4aioi521quj6RRkhZKWiRpQj3b+0iaLelPkuZJOiprPGbW+rm4qG3IN8TEI0CDtckRcew2jm0HTCEZproamCNpRkQsyNntCuD+iPh/kgaSTGvZL3v4ZtZauLio7cpXNHTDdpz7IGBRRCwGkHQvMBrITQQB7JK+7wos347rmVkReXiQtivffATP1LyX1AnoExELM567F/BWznI1UPsv5SrgCUkXAZ2BL9d3IknjgfEAffq4rNHMrDllqiOQ9DWgCng8XR4qaUYzXH8scEdE9CaZDe3nkurEFBG3RERFRFT06NGjGS5rZmY1srYauoqkqOdpgIioktQ/zzHLgD1ylnun63KdBYxKz/m8pI5Ad5K5kc2sjauvJZFbEbU+WVsNfRQR79Zal69L8hxgL0n9Je0AnALUfop4E/gSgKTPAx2BlRljMrM2xq2IWqesTwTzJY0D2knaC7iYPPMRRMQmSRcCs4B2wPSImC/paqAyImYA3wF+JulSksRyRrS1MS/MrEG1WxL1m/AYLy55x30NWpmsieAi4N+BD4B7SD7cr8l3UETMJGkSmrtuYs77BcCIrMGaWduSpSWRRzItvkyJICI2Av8u6fpkMdYVNiwzK0Xua9A6ZUoEkg4EpgNd0uV3gW9m6V1sZlYj69AULipqWVkri28DvhUR/SKiH3ABcHvBojKzsuUK5ZaXtY5gc0Q8V7MQEb+XtKlAMZlZGfGkN8WXb6yh/dO3z0iaRlJRHMAY0j4FZmbbw0NTFF++J4If11q+Mue9m3mamZWAfGMNHdZSgZiZWXE0Zoayo4F9SXr/AhARVxciKDMzazlZB52bSlIvcBEg4CSgbwHjMjOzFpL1ieCfImKIpHkRMUnSj4HfFjIwMytPnvKy5WXtR/Be+nOjpM8AHwG7FyYkM7OtuW9BYWV9InhUUjfgR8DLJC2Gbi1UUGZWvjwMRcvLOtZQzQBzD0p6FOhYz7DUZmbbzf0KWl6+DmXHb2MbEfHr5g/JzMxaUr4ngq9tY1sATgRmZm1cvg5lZ7ZUIGZmVhxZWw2ZmVmJciIwMytzmYeYaApJo4CfksxZfGtEXFdr+0+AmvGMdgJ6RkS3QsZkZm2PO5kVVtYZynYimWi+T0Sck05g
/7mIeHQbx7QDpgBHANXAHEkz0nmKAYiIS3P2vwj4QtNuw8zKjec6bj5ZnwhuB+YCw9PlZcCvgAYTAXAQsCgiFgNIuhcYDSxoYP+xbD3MtZkZ4E5mhZY1EQyIiDGSxkIymb0k5TmmF/BWznI1UG9PEUl9gf7AUw1sHw+MB+jTx9nfrNy4k1lhZa0s/lBSJ9LJaCQNAD5oxjhOAR6IiM31bYyIWyKiIiIqevTo0YyXNTOzrE8EVwGPA3tI+iUwAjgjzzHLgD1ylnun6+pzCnBBxljMzKwZZR1r6AlJc4EvksxHcElE/G+ew+YAe0nqT5IATgHG1d5J0j7AroAL/MzMiiBrq6FHgLuBGRGxIcsxEbFJ0oXALJLmo9MjYr6kq4HKiJiR7noKcG9EeA5kM7MiyFo0dAPJDGXXSZoD3As8GhHvb+ugiJgJzKy1bmKt5asyR2tmZs0ua9HQM8Azad+Aw4FzgOnALgWMzczMWkBjJq/vRDIa6Rhgf+DOQgVlZmYtJ2sdwf0kHcQeB24GnomIjwsZmJmZtYysTwS3AWMbaudvZmZtV74Zyg6PiKeAzsDo2p2JPUOZmVnbl++J4BCSYR/qm6nMM5SZmZWAfDOU1QwCd3VELMndlnYUMzOzNi7rWEMP1rPugeYMxMzMiiNfHcE+wL5AV0nH52zaBehYyMDMzKxl5Ksj+BxwDNCNresJ1pF0KjMzaxXufvFNHq6qO66lZzHLL18dwcPAw5KGR4QHhTOzVqP29JU1y8P6f3LLPp7FLJt8RUP/FhE/BMbVTEqTKyIuLlhkZmaNMKz/J+t8+/csZtnkKxp6Pf1ZWehAzMwao77pK7Oorwip3IuP8hUNPZL+3DKukKRPADtHxNoCx2Zm1qCmTl/5cNUyFqxYy8DdkzEzXXyUfayhu4HzgM0kE87sIumnEfGjQgZnZrY9atcjAFuSwH3nDq+zrVxl7UcwMH0C+DrwW5KJ5v+lUEGZmRXKwN13YfTQXsUOo1XJOuhcB0kdSBLBzRHxkSTPKGZmrVpT6xHKTdYngmnAUpLB556V1BfIW0cgaZSkhZIWSZrQwD4nS1ogaX5aBGVm1ix+cfYwJ4EMss5QdhNwU86qv0k6bFvHpLOZTQGOAKqBOZJmRMSCnH32Ar4HjIiI1ZJ6NvYGzMxs+2R6IpDUVdJkSZXp68ckTwfbchCwKCIWR8SHJPMcj661zznAlIhYDRARbzcyfjMz205Zi4amkwwrcXL6WgvcnueYXsBbOcvV6bpcewN7S/qDpBckjcoYj5mZNZOslcUDIuKEnOVJkqqa6fp7AYcCvUnqHwZHxJrcnSSNB8YD9OlTvm19zcwKIesTwXuSDq5ZkDQCeC/PMcuAPXKWe6frclUDMyLio3S+g7+QJIatRMQtEVERERU9evTIGLKZmWWR9YngPOAuSV3T5dXA6XmOmQPslU5gsww4BRhXa5/fAGOB2yV1JykqWpwxJjOz7VZfpzMor2En8iYCSUOBPUk+yJcBZBleIiI2SboQmAW0A6ZHxHxJVwOVETEj3XakpAUkvZYvj4hVTb0ZM7PmUG7DTuQbfXQicBowF/ghcG1E/CzrySNiJjCz1rqJOe8D+Hb6MjNrcfV1Oiu3YSfyPRGMAYZGxEZJuwGPA5kTgZlZa+cOZ/kriz+IiI0AaZFN1splMzNrI/I9EXxW0oz0vYABOctExLEFi8zMzFpEvkRQuyfwDYUKxMzMiiPfxDTPtFQgZmZWHNss85f0iKSvpUNQ1972WUlXS/pm4cIzM7NCy1c0dA5J084bJb0DrAQ6Av2Av5LMTfBwQSM0M2th9XUyK+UOZvmKhv4O/Bvwb5L6AbuTDC3xl5rWRGZmpa7UO5hlHWKCiFhKMjmNmVlZKJd5jTMnAjOzclHT27hcOBGYmdVSbr2N3VPYzKzMZXoiSOcfuAromx4jkjHjPlu40MzMrCVkLRq6DbiUZBTSzYULx8ys9Sn1OQuyJoJ3I+K3BY3EzKwN
KaUmpVkTwWxJPwJ+DXxQszIiXi5IVGZmrUipz1mQNRHU3H1FzroADm/ecMzMWp9Sb0WUKRFExGGFDsTMzIojU/NRSV0lTZZUmb5+nDOR/baOGyVpoaRFkibUs/0MSSslVaWvs5tyE2Zm1nRZ+xFMB9YBJ6evtcDt2zpAUjtgCvBVYCAwVtLAena9LyKGpq9bM0duZmbNImsdwYCIOCFneZKkqjzHHAQsiojFAJLuJZnoZkGjozQzs4LJ+kTwnqSDaxbSDmbv5TmmF/BWznJ1uq62EyTNk/SApD3qO5Gk8TXFUitXrswYspmZZZH1ieB84M60XkDAO8AZzXD9R4B7IuIDSecCd1JPS6SIuAW4BaCioiKa4bpmZtullDqZZW01VAXsJ2mXdHlthsOWAbnf8Hun63LPuypn8Vbgh1niMTNrjdpqJ7NtJgJJp0XELyR9u9Z6ACJi8jYOnwPsJak/SQI4BRhX6zy7R8SKdPFY4PXGhW9mVhyl1Mks3xNB5/Rnl8aeOCI2SboQmAW0A6ZHxHxJVwOVETEDuFjSscAmmq+4ycys4Eqpk1m+qSqnpT8nNeXkETETmFlr3cSc998DvteUc5uZWfPI2qHsh5J2kdRB0u/STmCnFTo4MzMrvKytho6MiH+TdBzJvMXHA88CvyhUYGZmbU3tlkQ1y8P6f3Kr/Vpby6Ks/QhqEsbRwK8i4t0CxWNmVtIWrFjLw1XL8u/YgrI+ETwq6c8kncjOl9QDeL9wYZmZtT21WxKdduuLWy1D62xZlLUfwQRJPySZoGazpA0kw0WYmVmqdkuittKyKF8/gsMj4ilJx+esy93l14UKzMzMWka+J4JDgKeAr9WzLXAiMDNr8/L1I7gy/Xlmy4RjZmYtLWs/gv8rqVvO8q6SflCwqMzMrMVkbT761YhYU7MQEauBowoSkZmZtaisiaCdpB1rFiR1Anbcxv5mZtZGZO1H8Evgd5Jqpqc8k2TuADMza+Oy9iO4XtIrwJfTVddExKzChWVmZi0l6xMBJHMFbIqI/5a0k6QuEbGuUIGZmVnLyNpq6BzgAWBauqoX8JsCxWRmZi0oa2XxBcAIYC1ARLwB9CxUUGZm1nKyJoIPIuLDmgVJ7Ul6FpuZWRuXNRE8I+n/AJ0kHQH8Cngk30GSRklaKGmRpAnb2O8ESSGpImM8ZmbWTLImgu8CK4FXgXNJpp+8YlsHSGoHTAG+CgwExkoaWM9+XYBLgBezh21mZs0lb6uh9AN9fkTsA/ysEec+CFgUEYvT89xLMnT1glr7XQNcD1zeiHObmVkzyftEEBGbgYWSGjuvWi/grZzl6nTdFpL2B/aIiMcaeW4zM2smWfsR7ArMl/QSsKFmZUQc29QLS/oEMBk4I8O+44HxAH36tJ55Ps3MSkHWRPD9Jpx7GbBHznLvdF2NLsAg4Ol0sptPAzMkHRsRlbkniohbgFsAKioq3FrJzKwZ5ZuhrCNwHrAnSUXxbRGxKeO55wB7SepPkgBOAcbVbIyId4HuOdd6GrisdhIwM7PCyldHcCdQQZIEvgr8OOuJ04RxITCLZHiK+yNivqSrJTW5SMnMzJpXvqKhgRExGEDSbcBLjTl5RMwkaWqau25iA/se2phzm5lZ88j3RPBRzZtGFAmZmVkbku+JYD9Ja9P3IulZvDZ9HxGxS0GjMzOzgss3eX27lgrEzMyKI+sQE2ZmVqKcCMzMypwTgZlZmXMiMDMrc04EZmZlzonAzKzMORGYmZU5JwIzszLnRGBmVuacCMzMypwTgZlZmXMiMDMrc04EZmZlzonAzKzMORGYmZW5giYCSaMkLZS0SNKEerafJ+lVSVWSfi9pYCHjMTOzuvLNUNZkktoBU4AjgGpgjqQZEbEgZ7e7I2Jquv+xwGRgVKFiMjMrtheXvAPAmGnPb1k3emgvxg3rU6yQCpcIgIOARRGxGEDSvcBoYEsiiIi1Oft3BqIpF/roo4+orq7m/fff345wzdqGjh070rt3
bzp06FDsUKwZLFiRfAyWaiLoBbyVs1wNDKu9k6QLgG8DOwCH13ciSeOB8QB9+tT9ZVVXV9OlSxf69euHpO2P3KyVighWrVpFdXU1/fv3L3Y41gQH79kdgF+cnXwc5j4ZFEvRK4sjYkpEDAC+C1zRwD63RERFRFT06NGjzvb333+f3XbbzUnASp4kdtttNz/9tmG/OHvYliTQWhQyESwD9shZ7p2ua8i9wNebejEnASsX/lu35lbIRDAH2EtSf0k7AKcAM3J3kLRXzuLRwBsFjMfMzOpRsEQQEZuAC4FZwOvA/RExX9LVaQshgAslzZdURVJPcHqh4im0nXfeebvPUVlZycUXX9zg9qVLl3L33Xdn3h+gX79+DB48mCFDhnDIIYfwt7/9bbvjbC5Tp07lrrvuapZzrVixgmOOOWardf/6r/9Kr169+Pjjj7esu+OOO+jRowdDhw5l4MCB/OxnP9vuay9ZsoRhw4ax5557MmbMGD788MM6+3z00UecfvrpDB48mM9//vNce+21W7Z985vfpGfPngwaNGirYy677DKeeuqp7Y7PLK+IaFOvAw44IGpbsGBBnXUtrXPnzgW/xuzZs+Poo49u1DF9+/aNlStXRkTExIkT4+yzz97uOD7++OPYvHnzdp+nOV122WXxm9/8Zsvy5s2bo0+fPjFs2LB46qmntqy//fbb44ILLoiIiP/5n/+J7t27x9///vftuvZJJ50U99xzT0REnHvuufFf//Vfdfb55S9/GWPGjImIiA0bNkTfvn1jyZIlERHxzDPPxNy5c2Pffffd6pilS5fGEUccUe81W8PfvDWPk6f+MU6e+seCXweojAY+VwvZaqgoJj0ynwXL1+bfsREGfmYXrvzavo0+rqqqivPOO4+NGzcyYMAApk+fzq677sqcOXM466yz+MQnPsERRxzBb3/7W1577TWefvppbrjhBh599FGeeeYZLrnkEiApE3722WeZMGECr7/+OkOHDuX000/nC1/4wpb9169fz0UXXURlZSWSuPLKKznhhBO2imf48OHcdNNNAKxcuZLzzjuPN998E4Abb7yRESNGsHLlSsaNG8fy5csZPnw4Tz75JHPnzmX9+vV85StfYdiwYcydO5eZM2dy//33c//99/PBBx9w3HHHMWnSJDZs2MDJJ59MdXU1mzdv5vvf/z5jxoxhwoQJzJgxg/bt23PkkUdyww03cNVVV7Hzzjtz2WWXNfi7OvTQQxk2bBizZ89mzZo13HbbbYwcObLO7/rBBx/kBz/4wZblp59+mn333ZcxY8Zwzz33cNhhh9U5pmfPngwYMIC//e1vfOpTn2r0vy8kX6SeeuqpLU9qp59+OldddRXnn3/+VvtJYsOGDWzatIn33nuPHXbYgV122QWAf/7nf2bp0qV1zt23b19WrVrF3//+dz796U83KT6zLIreaqiUfeMb3+D6669n3rx5DB48mEmTJgFw5plnMm3aNKqqqmjXrl29x95www1MmTKFqqoqnnvuOTp16sR1113HyJEjqaqq4tJLL91q/2uuuYauXbvy6quvMm/ePA4/vG5L3Mcff5yvf/3rAFxyySVceumlzJkzhwcffJCzzz4bgEmTJnH44Yczf/58TjzxxC2JAuCNN97gW9/6FvPnz2fhwoW88cYbvPTSS1RVVTF37lyeffZZHn/8cT7zmc/wyiuv8NprrzFq1ChWrVrFQw89xPz585k3bx5XXFG3cVhDvyuATZs28dJLL3HjjTdutb7GkiVL2HXXXdlxxx23rLvnnnsYO3Ysxx13HI899hgfffRRneMWL17M4sWL2XPPPbdav3DhQoYOHVrva82aNVvtu2rVKrp160b79sl3qt69e7NsWd02ESeeeCKdO3dm9913p0+fPlx22WV88pOfrLNfbfvvvz9/+MMf8u5ntj1K7omgKd/cC+Hdd99lzZo1HHLIIUDyTfGkk05izZo1rFu3juHDhwMwbtw4Hn300TrHjxgx
gm9/+9uceuqpHH/88fTu3Xub1/vv//5v7r333i3Lu+6665b3hx12GO+88w4777wz11xzzZb9Fyz4RyfvtWvXsn79en7/+9/z0EMPATBq1KitztO3b1+++MUvAvDEE0/wxBNP8IUvfAGA9evX88YbbzBy5Ei+853v8N3vfpdjjjmGkSNHsmnTJjp27MhZZ53FMcccU6csv6HfVY3jjz8egAMOOKDeb84rVqwgt1nxhx9+yMyZM5k8eTJdunRh2LBhzJo1a8t177vvPn7/+9+z4447Mm3atDofyJ/73Oeoqqra1q+70V566SXatWvH8uXLWb16NSNHjuTLX/4yn/3sZ7d5XM+ePVm+fHmzxmJWW8klglIxYcIEjj76aGbOnMmIESOYNWtWk881e/ZsunXrxqmnnsqVV17J5MmT+fjjj3nhhRfo2LFj5vN07tx5y/uI4Hvf+x7nnntunf1efvllZs6cyRVXXMGXvvQlJk6cyEsvvcTvfvc7HnjgAW6++eZGVYLWfNNv164dmzZtqrO9U6dOW7WrnzVrFmvWrGHw4MEAbNy4kU6dOm1JBGPGjOHmm29u8HoLFy5kzJgx9W57+umn6dat25bl3XbbjTVr1rBp0ybat29PdXU1vXr1qnPc3XffzahRo+jQoQM9e/ZkxIgRVFZW5k0E77//Pp06ddrmPmbby0VDBdK1a1d23XVXnnvuOQB+/vOfc8ghh9CtWze6dOnCiy++CLDVt/hcf/3rXxk8eDDf/e53OfDAA/nzn/9Mly5dWLduXb37H3HEEUyZMmXL8urVq7fa3r59e2688Ubuuusu3nnnHY488kj+8z//c8v2mm/AI0aM4P777weSb/21z1PjK1/5CtOnT2f9+vUALFu2jLfffpvly5ez0047cdppp3H55Zfz8ssvs379et59912OOuoofvKTn/DKK69k+l1ltffee2/1pHDPPfdw6623snTpUpYuXcqSJUt48skn2bhxY6bz1TwR1PfKTQKQlP0fdthhPPDAAwDceeedjB49us45+/TpsyX5bdiwgRdeeIF99tknbyx/+ctf6rQmstKzYMVaxkx7fsur34TH6Dfhsa3WjZn2PJMemV+Q6zsRNJONGzfSu3fvLa/Jkydz5513cvnllzNkyBCqqqqYOHEiALfddhvnnHMOQ4cOZcOGDXTt2rXO+W688UYGDRrEkCFD6NChA1/96lcZMmQI7dq1Y7/99uMnP/nJVvtfccUVrF69mkGDBrHffvsxe/bsOufcfffdGTt2LFOmTOGmm26isrKSIUOGMHDgQKZOnQrAlVdeyRNPPMGgQYP41a9+xac//Wm6dOlS51xHHnkk48aNY/jw4QwePJgTTzyRdevW8eqrr3LQQQcxdOhQJk2axBVXXMG6des45phjGDJkCAcffDCTJ0+uc76GfldZdO7cmQEDBrBo0SI2btzI448/ztFHH73V9oMPPphHHnkk8zkb4/rrr2fy5MnsueeerFq1irPOOguAGTNmbLmPCy64gPXr17Pvvvty4IEHcuaZZzJkyBAAxo4dy/Dhw1m4cCG9e/fmtttuA5Imp4sWLaKioqIgcVvrMHpoLwbuvktRY1DSqqjtqKioiMrKyq3Wvf7663z+858vUkSNt379+i39Dq677jpWrFjBT3/60yJHlfjggw9o164d7du35/nnn+f8889v9vLyQnjooYeYO3fuVi2H2rqHHnqIl19+eUu9Tq629jdvjXParUmJQXMORSFpbkTU+63CdQRF8Nhjj3HttdeyadMm+vbtyx133FHskLZ48803Ofnkk/n444/ZYYcdmqXDVUs47rjjWLVqVbHDaFabNm3iO9/5TrHDsCJo6bGI/ERg1gb5b94aa1tPBCVTR9DWEppZU/lv3ZpbSSSCjh07smrVKv8HsZIX6XwEjWn2a5ZPSdQR9O7dm+rqalauXFnsUMwKrmaGMrPmUhKJoEOHDp6tycysiUqiaMjMzJrOicDMrMw5
EZiZlbk2149A0kqgqdNsdQf+txnDaQt8z+XB91wetuee+0ZEj/o2tLlEsD0kVTbUoaJU+Z7Lg++5PBTqnl00ZGZW5pwIzMzKXLklgluKHUAR+J7Lg++5PBTknsuqjsDMzOoqtycCMzOrxYnAzKzMlWQikDRK0kJJiyRNqGf7jpLuS7e/KKlfEcJsVhnu+duSFkiaJ+l3kvoWI87mlO+ec/Y7QVJIavNNDbPcs6ST03/r+ZLubukYm1uGv+0+kmZL+lP6931UMeJsLpKmS3pb0msNbJekm9LfxzxJ+2/3RSOipF5AO+CvwGeBHYBXgIG19vkWMDV9fwpwX7HjboF7PgzYKX1/fjncc7pfF+BZ4AWgothxt8C/817An4Bd0+WexY67Be75FuD89P1AYGmx497Oe/5nYH/gtQa2HwX8FhDwReDF7b1mKT4RHAQsiojFEfEhcC8wutY+o4E70/cPAF+SpBaMsbnlveeImB0RG9PFF4C2Po5xln9ngGuA64H3WzK4Aslyz+cAUyJiNUBEvN3CMTa3LPccQM3s712B5S0YX7OLiGeBd7axy2jgrki8AHSTtPv2XLMUE0Ev4K2c5ep0Xb37RMQm4F1gtxaJrjCy3HOus0i+UbRlee85fWTeIyIea8nACijLv/PewN6S/iDpBUmjWiy6wshyz1cBp0mqBmYCF7VMaEXT2P/veZXEfASWnaTTgArgkGLHUkiSPgFMBs4ocigtrT1J8dChJE99z0oaHBFrihlUgY0F7oiIH0saDvxc0qCI+LjYgbUVpfhEsAzYI2e5d7qu3n0ktSd5nFzVItEVRpZ7RtKXgX8Hjo2ID1ootkLJd89dgEHA05KWkpSlzmjjFcZZ/p2rgRkR8VFELAH+QpIY2qos93wWcD9ARDwPdCQZnK1UZfr/3hilmAjmAHtJ6i9pB5LK4Bm19pkBnJ6+PxF4KtJamDYq7z1L+gIwjSQJtPVyY8hzzxHxbkR0j4h+EdGPpF7k2IioLE64zSLL3/ZvSJ4GkNSdpKhocQvG2Nyy3PObwJcAJH2eJBGU8ry1M4BvpK2Hvgi8GxErtueEJVc0FBGbJF0IzCJpcTA9IuZLuhqojIgZwG0kj4+LSCplTilexNsv4z3/CNgZ+FVaL/5mRBxbtKC3U8Z7LikZ73kWcKSkBcBm4PKIaLNPuxnv+TvAzyRdSlJxfEZb/mIn6R6SZN49rfe4EugAEBFTSepBjgIWARuBM7f7mm3492VmZs2gFIuGzMysEZwIzMzKnBOBmVmZcyIwMytzTgRmZmXOicAKTtJmSVWSXpP0iKRuzXz+pWmbeSStb2CfTpKekdROUj9J76UxLZA0Ne2J3JhrVki6KX1/qKR/ytl2nqRvbM89pee5StJlefa5Q9KJjThnv4ZGtay1339Ieqv271PShZK+mfV61jY4EVhLeC8ihkbEIJJ+GxcUIYZvAr+OiM3p8l8jYigwhGTEyq835mQRURkRF6eLhwL/lLNtakTctb0BF9kjJAO+1Tad0h/Lp+w4EVhLe550gCxJAyQ9LmmupOck7ZOu/5SkhyS9kr7+KV3/m3Tf+ZLGN/K6pwIP116ZDjr4R2DP9NvyU/rHnA190uuelD7NvCLp2XTdoZIeVTKXxXnApekTxsiab/KS9pH0Us210vO/mr4/IH1CmStplvKMHinpHElz0hgelLRTzuYvS6qU9BdJx6T7t5P0o/SYeZLObcwvKyJeqK+3ajqC7VJJ9SUJa6OcCKzFSGpHMhRATa/fW4CLIuIA4DLgv9L1NwHPRMR+JOOyz0/XfzPdtwK4WFKmEWPToQk+GxFL69m2UxrTq8B/AndGxBDgl2kcABOBr6TxbNUbOz3nVOAn6VPPcznb/gzsIKl/umoMcJ+kDum1TkzvZzrwH3lu49cRcWAaw+sk4+vU6Efy7f1oYKqkjun2dyPiQOBA4JycOGru/TOSZua5bn0qgZFNOM5aqZIbYsJapU6SqkieBF4HnpS0
M0lxSs2QFwA7pj8PB74BkBblvJuuv1jScen7PUgGU8syfEJ3YE2tdQPSmAJ4OCJ+K+nnwPHp9p8DP0zf/wG4Q9L9wK8zXC/X/SQJ4Lr05xjgcyQD4j2Z3ns7IN9YMYMk/QDoRjJUyKzca6Qjbb4haTGwD3AkMCSn/qArye/rLzUHRcRykqEKGuvt9BpWIpwIrCW8FxFD02/fs0jqCO4A1qTl9HlJOhT4MjA8IjZKeppkcLFM169n379mvXZEnCdpGMk37rmSDsh4XYD7SJLdr5NTxRuSBgPzI2J4I85zB/D1iHhF0hmkA8vVhFg7ZJLZqy6KiNyEgZpnWtaOJL9TKxEuGrIWk5YvX0wySNhGYImkk2DLPKz7pbv+jmQ6zZqy7q4k32hXp0lgH5JhpbNedzXQLi0y2ZY/8o8BCE8FnktjGBARL0bERJJRLfeoddw6kmGv67v2X0kGf/s+SVIAWAj0UDJ2PpI6SNo3T2xdgBVpsdKptbadJOkTkgaQTOm4kCThnp/uj6S9JXXOc42s9gbytjyytsOJwFpURPwJmEcymcipwFmSXiGpB6iZgvAS4LC0YnUuSauex4H2kl4nKWZ5oZGXfgI4OM8+FwFnSpoH/EsaB8CPJL2aNrv8I8m8ubkeAY6rqSyu57z3AafxjzHzPyQZ/vz69N6ryGl11IDvAy+SFFP9uda2N4GXSGadOy8i3gduBRYAL6dxT6NWCcC26ggk/VDJyJc7SaqWdFXO5hHAk3nitTbEo49aWVAybeWlEfEvxY6lLVMyr8W3/XssLX4isLIQES8Ds9OWS9Z03UmeTqyE+InAzKzM+YnAzKzMORGYmZU5JwIzszLnRGBmVuacCMzMytz/B64PGVsnHRwqAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "PrecisionRecallDisplay.from_predictions(\n", " y_true=test[\"label_num\"],\n", " y_pred=score,\n", " name=\"LogisticRegression\",\n", ")" ] }, { "cell_type": "markdown", "id": "f527f53b", "metadata": {}, "source": [ "## 不均衡データに対応する\n", "\n", "`class_weight` パラメータで不均衡データに対応できます。" ] }, { "cell_type": "code", "execution_count": 8, "id": "2531e8d1", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Pipeline(steps=[('vect',\n",
       "                 TfidfVectorizer(tokenizer=<method 'split' of 'str' objects>)),\n",
       "                ('clf', LogisticRegression(class_weight='balanced'))])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
" ], "text/plain": [ "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf', LogisticRegression(class_weight='balanced'))])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# class_weight=\"balanced\" weights each class inversely to its frequency in the training labels.\n", "# token_pattern=None: the default regex is never used once a custom tokenizer is\n", "# supplied, and recent scikit-learn versions emit a UserWarning if it is left set.\n", "pipe_weight = Pipeline([\n", "    (\"vect\", TfidfVectorizer(tokenizer=str.split, token_pattern=None)),\n", "    (\"clf\", LogisticRegression(class_weight=\"balanced\")),\n", "])\n", "\n", "pipe_weight.fit(train[\"tokens\"], train[\"label_num\"])" ] }, { "cell_type": "code", "execution_count": 9, "id": "ea8517c4", "metadata": {}, "outputs": [], "source": [ "# Keep column 1 of predict_proba as the score (the second entry of the classifier's\n", "# classes_; presumably the positive label - TODO confirm against label_num).\n", "score_weight = pipe_weight.predict_proba(test[\"tokens\"])[:, 1]" ] }, { "cell_type": "markdown", "id": "196f6141", "metadata": {}, "source": [ "class_weightオプションを付けないモデルと比較します。" ] }, { "cell_type": "code", "execution_count": 10, "id": "5938278d", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "# Draw both precision-recall curves on a shared Axes so they can be compared directly.\n", "# The loop variable is `y_score`, not `pred`, so the class predictions stored in `pred`\n", "# by the confusion-matrix cell are not overwritten with probability scores.\n", "_, ax = plt.subplots()\n", "for name, y_score in [\n", "    (\"LogisticRegression\", score),\n", "    (\"LogisticRegression+weight\", score_weight),\n", "]:\n", "    PrecisionRecallDisplay.from_predictions(ax=ax, y_true=test[\"label_num\"], y_pred=y_score, name=name)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 5 }