{ "cells": [ { "cell_type": "markdown", "id": "92746262", "metadata": {}, "source": [ "# Sentiment Analysis" ] }, { "cell_type": "markdown", "id": "7ec44e6a", "metadata": {}, "source": [ "## Task Definition\n", "\n", "We develop a model that does the following for reviews of accommodations.\n", "\n", "* Find as many of the negative opinions in the reviews as possible\n", "* We do not know how many reviews will arrive, so we may not be able to check all of them. Therefore, find negative opinions as efficiently as possible, checking reviews in descending order of confidence\n", "\n", "## Evaluation Method\n", "\n", "Given this task definition, an evaluation centered on recall seems appropriate.\n", "\n", "Because we want to find negative opinions as efficiently as possible in descending order of confidence,\n", "a reasonable plan is to aim for the highest Average Precision (a summary of the area under the precision-recall curve) during model development,\n", "and to state the final report in a form such as\n", "\"precision is ~% at 80% recall\"." ] }, { "cell_type": "markdown", "id": "b8a621dd", "metadata": {}, "source": [ "## Data Preparation\n", "\n", "Here we use the JRTE corpus." ] }, { "cell_type": "markdown", "id": "e0d41bb8", "metadata": {}, "source": [ "## Validating the Annotation Results\n", "\n", "The validation of the annotation results is described in detail in\n", "[Validating the Annotation Results](annotation.ipynb);\n", "please refer to that notebook.\n", "\n", "Here we use only the data on which the judges ultimately agreed." ] }, { "cell_type": "code", "execution_count": 1, "id": "8adb4496", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "data = pd.read_csv(\"input/pn_same_judge.csv\")" ] }, { "cell_type": "markdown", "id": "38288ac5", "metadata": {}, "source": [ "## Data Overview" ] }, { "cell_type": "code", "execution_count": 2, "id": "286e53da", "metadata": {}, "outputs": [ { "data": { "text/html": [
 "<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "<thead>\n",
 "<tr style=\"text-align: right;\"><th></th><th>label</th><th>text</th><th>judges</th></tr>\n",
 "</thead>\n",
 "<tbody>\n",
 "<tr><th>0</th><td>neutral</td><td>出張でお世話になりました。</td><td>{\"0\": 3}</td></tr>\n",
 "<tr><th>1</th><td>neutral</td><td>朝食は普通でした。</td><td>{\"0\": 3}</td></tr>\n",
 "<tr><th>2</th><td>positive</td><td>また是非行きたいです。</td><td>{\"1\": 3}</td></tr>\n",
 "<tr><th>3</th><td>positive</td><td>また利用したいと思えるホテルでした。</td><td>{\"1\": 3}</td></tr>\n",
 "<tr><th>4</th><td>neutral</td><td>新婚旅行で利用しました。</td><td>{\"0\": 3}</td></tr>\n",
 "</tbody>\n",
 "</table>\n",
 "</div>"
 ], "text/plain": [ " label text judges\n", "0 neutral 出張でお世話になりました。 {\"0\": 3}\n", "1 neutral 朝食は普通でした。 {\"0\": 3}\n", "2 positive また是非行きたいです。 {\"1\": 3}\n", "3 positive また利用したいと思えるホテルでした。 {\"1\": 3}\n", "4 neutral 新婚旅行で利用しました。 {\"0\": 3}" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.head()" ] }, { "cell_type": "code", "execution_count": 3, "id": "b3442dd1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(4186, 3)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.shape" ] }, { "cell_type": "code", "execution_count": 4, "id": "efe880e5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 4186 entries, 0 to 4185\n", "Data columns (total 3 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 label 4186 non-null object\n", " 1 text 4186 non-null object\n", " 2 judges 4186 non-null object\n", "dtypes: object(3)\n", "memory usage: 98.2+ KB\n" ] } ], "source": [ "data.info()" ] }, { "cell_type": "code", "execution_count": 5, "id": "164d852a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "label 0\n", "text 0\n", "judges 0\n", "dtype: int64" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Confirm that there are no missing values\n", "data.isnull().sum()" ] }, { "cell_type": "code", "execution_count": 6, "id": "13522c2f", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "<thead>\n",
 "<tr style=\"text-align: right;\"><th></th><th>count</th><th>unique</th><th>top</th><th>freq</th></tr>\n",
 "</thead>\n",
 "<tbody>\n",
 "<tr><th>label</th><td>4186</td><td>3</td><td>positive</td><td>2835</td></tr>\n",
 "<tr><th>text</th><td>4186</td><td>4186</td><td>出張でお世話になりました。</td><td>1</td></tr>\n",
 "<tr><th>judges</th><td>4186</td><td>3</td><td>{\"1\": 3}</td><td>2835</td></tr>\n",
 "</tbody>\n",
 "</table>\n",
 "</div>"
 ], "text/plain": [ " count unique top freq\n", "label 4186 3 positive 2835\n", "text 4186 4186 出張でお世話になりました。 1\n", "judges 4186 3 {\"1\": 3} 2835" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Check the number of unique values and the most frequent value in each column\n", "data.describe(include=[object]).T" ] }, { "cell_type": "code", "execution_count": 7, "id": "898ffc3b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "<thead>\n",
 "<tr style=\"text-align: right;\"><th></th><th>count</th><th>unique</th><th>top</th><th>freq</th></tr>\n",
 "</thead>\n",
 "<tbody>\n",
 "<tr><th>label</th><td>4186</td><td>3</td><td>positive</td><td>2835</td></tr>\n",
 "<tr><th>text</th><td>4186</td><td>4186</td><td>出張でお世話になりました。</td><td>1</td></tr>\n",
 "<tr><th>judges</th><td>4186</td><td>3</td><td>{\"1\": 3}</td><td>2835</td></tr>\n",
 "</tbody>\n",
 "</table>\n",
 "</div>"
 ], "text/plain": [ " count unique top freq\n", "label 4186 3 positive 2835\n", "text 4186 4186 出張でお世話になりました。 1\n", "judges 4186 3 {\"1\": 3} 2835" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.describe().T" ] }, { "cell_type": "code", "execution_count": 8, "id": "a4c2eee2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "positive 2835\n", "neutral 749\n", "negative 602\n", "Name: label, dtype: int64" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data[\"label\"].value_counts()" ] }, { "cell_type": "markdown", "id": "71e869c6", "metadata": {}, "source": [ "Finally, let's also add a numeric label for training.\n", "Since we want to extract negative opinions, we assign 1 to the negative label and 0 to the other labels, positive and neutral." ] }, { "cell_type": "code", "execution_count": 9, "id": "6856d278", "metadata": {}, "outputs": [], "source": [ "data[\"label_num\"] = data[\"label\"].map({\"positive\": 0, \"neutral\": 0, \"negative\": 1})" ] }, { "cell_type": "code", "execution_count": 59, "id": "8a34934e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "<thead>\n",
 "<tr><th></th><th colspan=\"4\" halign=\"left\">label</th><th colspan=\"4\" halign=\"left\">text</th><th colspan=\"4\" halign=\"left\">judges</th></tr>\n",
 "<tr><th></th><th>count</th><th>unique</th><th>top</th><th>freq</th><th>count</th><th>unique</th><th>top</th><th>freq</th><th>count</th><th>unique</th><th>top</th><th>freq</th></tr>\n",
 "<tr><th>label_num</th><th></th><th></th><th></th><th></th><th></th><th></th><th></th><th></th><th></th><th></th><th></th><th></th></tr>\n",
 "</thead>\n",
 "<tbody>\n",
 "<tr><th>0</th><td>3584</td><td>2</td><td>positive</td><td>2835</td><td>3584</td><td>3584</td><td>出張でお世話になりました。</td><td>1</td><td>3584</td><td>2</td><td>{\"1\": 3}</td><td>2835</td></tr>\n",
 "<tr><th>1</th><td>602</td><td>1</td><td>negative</td><td>602</td><td>602</td><td>602</td><td>期待していただけに残念でした。</td><td>1</td><td>602</td><td>1</td><td>{\"-1\": 3}</td><td>602</td></tr>\n",
 "</tbody>\n",
 "</table>\n",
 "</div>"
 ], "text/plain": [ " label text \\\n", " count unique top freq count unique top freq \n", "label_num \n", "0 3584 2 positive 2835 3584 3584 出張でお世話になりました。 1 \n", "1 602 1 negative 602 602 602 期待していただけに残念でした。 1 \n", "\n", " judges \n", " count unique top freq \n", "label_num \n", "0 3584 2 {\"1\": 3} 2835 \n", "1 602 1 {\"-1\": 3} 602 " ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.groupby(\"label_num\").describe()" ] }, { "cell_type": "markdown", "id": "a14e6fc5", "metadata": {}, "source": [ "## Splitting the Data\n", "\n", "It is important not to look at the contents of the test data, so we split it off before doing any exploratory data analysis.\n", "\n", "Positive and neutral are both treated as negative examples (label 0), but we still want to preserve their proportions, so note that the `stratify` argument is given\n", "`data[\"label\"]`, **not** `data[\"label_num\"]`." ] }, { "cell_type": "code", "execution_count": 12, "id": "104b8a83", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "train, test = train_test_split(data, test_size=0.1, random_state=0, stratify=data[\"label\"])" ] }, { "cell_type": "code", "execution_count": 13, "id": "2ca5e42c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((3767, 4), (419, 4))" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.shape, test.shape" ] }, { "cell_type": "markdown", "id": "2a2354a4", "metadata": {}, "source": [ "## Building a Baseline Model\n", "\n", "First, let's build simple models and see what level of performance they reach.\n", "\n", "Here we train the following models and compare their results with cross-validation (CV).\n", "\n", "* Naive Bayes\n", "* Logistic Regression\n", "* SVM\n", "* Random Forest\n", "\n", "For vectorization we use TfidfVectorizer." ] }, { "cell_type": "code", "execution_count": 15, "id": "17c019ff", "metadata": {}, "outputs": [], "source": [ "from sklearn.naive_bayes import MultinomialNB\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.svm import SVC\n", "from sklearn.ensemble import RandomForestClassifier\n", "\n", "from sklearn.pipeline import Pipeline\n", "from sklearn.feature_extraction.text import TfidfVectorizer" ] }, { "cell_type": "markdown", "id": "00e791f3", "metadata": {}, "source": [ "To draw PR curves of the results, we need to keep the predictions from every fold.\n", "So instead of GridSearchCV, we write our own loop over a ParameterGrid." ] }, { "cell_type": "code", "execution_count": 19, "id": "a098b537", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import ParameterGrid\n", "import numpy as np\n", "\n", "\n", "def run_cv(pipe, params, cv, X, y, stratify):\n", "    result = []\n", "    for param in ParameterGrid(params):\n", "        pipe.set_params(**param)\n", "        print(pipe)\n", "        # Out-of-fold scores: each sample is scored by a model that did not see it during training\n", "        pred = np.zeros((len(X), ))\n", "        for fold_id, (train_idx, test_idx) in enumerate(cv.split(X=X, y=stratify)):\n", "            print(\"Fold:\", fold_id)\n", "\n", "            # Folds are split on the 3-class `stratify` labels so that neutral is stratified too,\n", "            # while the model itself is fit on the binary target y\n", "            pipe.fit(X=X.iloc[train_idx], y=y.iloc[train_idx])\n", "            try:\n", "                pred[test_idx] = pipe.predict_proba(X.iloc[test_idx])[:,1]\n", "            except AttributeError:\n", "                # SVC without probability=True has no predict_proba; use decision_function scores instead,\n", "                # which is fine for PR curves because only the ranking of the scores matters\n", "                pred[test_idx] = pipe.decision_function(X.iloc[test_idx])\n", "\n", "        result.append((param, pred))\n", "    return result" ] }, { "cell_type": "markdown", "id": "42cd5644", "metadata": {}, "source": [ "Define the tokenizer. It runs spaCy's Japanese pipeline `ja_core_news_sm` and returns the lemma of each token, skipping any lemma contained in `stopwords`." ] }, { "cell_type": "code", "execution_count": 20, "id": "903a64fb", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-05-09 05:54:09.828683: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", "2022-05-09 05:54:09.828770: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n" ] } ], "source": [ "import spacy\n", "\n", "nlp = spacy.load(\"ja_core_news_sm\")\n", "\n", "\n", "def tokenize(text, stopwords=set()):\n", "    return [t.lemma_ for t in nlp(text) if t.lemma_ not in stopwords]" ] }, { "cell_type": "markdown", "id": "7da53770", "metadata": {}, "source": [ "Set up cross-validation.\n", "\n", "Because the dataset is small, we keep the number of folds small, at 3,\n", "so that each fold still contains a reasonable amount of validation data." ] }, { "cell_type": "code", "execution_count": 23, "id": "5f25d219", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import StratifiedKFold\n", "\n", "cv = StratifiedKFold(n_splits=3, shuffle=True, random_state=0)" ] }, { "cell_type": "code", "execution_count": 24, "id": "84291046", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf', MultinomialNB())])\n", "Fold: 0\n", "Fold: 1\n", "Fold: 2\n", "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf', LogisticRegression(random_state=0))])\n", "Fold: 0\n", "Fold: 1\n", "Fold: 2\n", "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf', SVC(random_state=0))])\n", "Fold: 0\n", "Fold: 1\n", "Fold: 2\n", "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf', RandomForestClassifier(random_state=0))])\n", "Fold: 0\n", "Fold: 1\n", "Fold: 2\n" ] } ], "source": [ "pipe = Pipeline([\n", "    (\"vect\", TfidfVectorizer(tokenizer=tokenize)),\n", "    (\"clf\", MultinomialNB())\n", "])\n", "\n", "params = [\n", "    {\n", "        \"clf\": [\n", "            MultinomialNB(),\n", "            LogisticRegression(random_state=0),\n", "            SVC(random_state=0),\n", "            RandomForestClassifier(random_state=0)\n", "        ],\n", "    },\n", "]\n", "\n", "# Note that, again to preserve the proportions of neutral and positive,\n", "# stratify is given train[\"label\"], not train[\"label_num\"]\n", "result = run_cv(pipe=pipe, params=params, cv=cv, X=train[\"text\"], y=train[\"label_num\"], stratify=train[\"label\"])" ] }, { "cell_type": "markdown", "id": "39fc801c", "metadata": {}, "source": [ "Plot the results." ] }, { "cell_type": "code", "execution_count": 25, "id": "65627e75", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from sklearn.metrics import PrecisionRecallDisplay\n", "\n", "def show_pr(result, y):\n", "    \"\"\"Plot one PR curve per parameter setting in the output of run_cv.\"\"\"\n", "    _, ax = plt.subplots(figsize=(10, 5))\n", "    for model_id, (param, pred) in enumerate(result):\n", "        print(model_id, param)\n", "        PrecisionRecallDisplay.from_predictions(ax=ax, y_true=y, y_pred=pred, name=f\"Model-{model_id}\")" ] }, { "cell_type": "code", "execution_count": 26, "id": "bb876caa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 {'clf': MultinomialNB()}\n", "1 {'clf': LogisticRegression(random_state=0)}\n", "2 {'clf': SVC(random_state=0)}\n", "3 {'clf': RandomForestClassifier(random_state=0)}\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_pr(result=result, y=train[\"label_num\"])" ] }, { "cell_type": "markdown", "id": "e7f7b867", "metadata": {}, "source": [ "SVC gives good results, so we use SVC as the baseline here." ] }, { "cell_type": "markdown", "id": "f785b829", "metadata": {}, "source": [ "## Improving the Model" ] }, { "cell_type": "markdown", "id": "09e4dd4a", "metadata": {}, "source": [ "### Exploratory Data Analysis\n", "\n", "In exploratory data analysis, we observe the data and look for hypotheses that may help classification.\n", "\n", "Bag-of-words (BoW) features tend to be high-dimensional and are therefore prone to the curse of dimensionality.\n", "Here we check whether introducing stopwords, that is, removing words that are unlikely to help classification,\n", "improves performance." ] }, { "cell_type": "markdown", "id": "5907e35a", "metadata": {}, "source": [ "First, let's print the most frequent words for each value of `label_num`." ] }, { "cell_type": "code", "execution_count": 32, "id": "3177f04f", "metadata": {}, "outputs": [], "source": [ "from collections import Counter\n", "\n", "\n", "def count_words(sr, tokenize):\n", "    \"\"\"Count the tokens that `tokenize` produces over all texts in the Series `sr`.\"\"\"\n", "    cnt = Counter()\n", "    for words in sr.apply(tokenize):\n", "        cnt.update(words)\n", "    return cnt" ] }, { "cell_type": "code", "execution_count": 37, "id": "c8362df0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('。', 
488),\n", " ('が', 443),\n", " ('の', 437),\n", " ('た', 408),\n", " ('、', 277),\n", " ('です', 244),\n", " ('に', 226),\n", " ('は', 204),\n", " ('て', 201),\n", " ('ます', 196),\n", " ('だ', 182),\n", " ('と', 161),\n", " ('する', 145),\n", " ('ない', 107),\n", " ('も', 98),\n", " ('残念', 91),\n", " ('を', 82),\n", " ('か', 79),\n", " ('部屋', 76),\n", " ('で', 73),\n", " ('ある', 72),\n", " ('いる', 59),\n", " ('ぬ', 58),\n", " ('お', 56),\n", " ('風呂', 53),\n", " ('.', 47),\n", " ('ただ', 40),\n", " ('思う', 39),\n", " ('なる', 39),\n", " ('少し', 38),\n", " ('狭い', 33),\n", " ('れる', 30),\n", " ('朝食', 30),\n", " ('?', 29),\n", " ('から', 27),\n", " ('・', 26),\n", " ('てる', 25),\n", " ('時', 25),\n", " ('悪い', 25),\n", " ('ちょっと', 24),\n", " ('気', 24),\n", " ('方', 20),\n", " ('こと', 19),\n", " ('な', 19),\n", " ('フロント', 18),\n", " ('人', 17),\n", " ('良い', 17),\n", " ('掃除', 17),\n", " ('入る', 17),\n", " ('(', 16)]" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cnt_1 = count_words(train.query('label_num == 1')[\"text\"], tokenize=tokenize)\n", "cnt_1.most_common(n=50)" ] }, { "cell_type": "code", "execution_count": 38, "id": "3448401e", "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "[('。', 2928),\n", " ('た', 2324),\n", " ('です', 1605),\n", " ('ます', 1299),\n", " ('の', 1261),\n", " ('も', 1108),\n", " ('て', 813),\n", " ('が', 798),\n", " ('だ', 757),\n", " ('する', 722),\n", " ('は', 716),\n", " ('に', 705),\n", " ('、', 701),\n", " ('で', 475),\n", " ('と', 447),\n", " ('良い', 417),\n", " ('お', 384),\n", " ('を', 334),\n", " ('部屋', 297),\n", " ('利用', 290),\n", " ('ある', 268),\n", " ('とても', 228),\n", " ('美味しい', 202),\n", " ('ホテル', 193),\n", " ('また', 191),\n", " ('満足', 183),\n", " ('いる', 182),\n", " ('思う', 179),\n", " ('!', 178),\n", " ('風呂', 173),\n", " ('たい', 169),\n", " ('朝食', 167),\n", " ('から', 141),\n", " ('できる', 126),\n", " ('最高', 123),\n", " ('対応', 108),\n", " ('広い', 108),\n", " ('大', 105),\n", " ('なる', 104),\n", " ('こと', 102),\n", " ('行く', 
101),\n", " ('ない', 101),\n", " ('綺麗', 96),\n", " ('よい', 96),\n", " ('いただく', 91),\n", " ('せる', 89),\n", " ('温泉', 88),\n", " ('方', 80),\n", " ('清潔', 79),\n", " ('か', 78)]" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cnt_0 = count_words(train.query('label_num == 0')[\"text\"], tokenize=tokenize)\n", "cnt_0.most_common(n=50)" ] }, { "cell_type": "markdown", "id": "834eff21", "metadata": {}, "source": [ "Let's define stopwords while looking at these results.\n", "\n", "```{note}\n", "What counts as a stopword depends on the problem.\n", "Rather than blindly using a publicly available stopword list, decide on the stopwords by looking at the data.\n", "```\n", "\n", "We keep words that are characteristic of a label, such as words that appear frequently in only one of the labels, and remove the other frequent words.\n", "Here we define the stopwords as follows." ] }, { "cell_type": "code", "execution_count": 53, "id": "87341af5", "metadata": { "scrolled": false }, "outputs": [], "source": [ "# Punctuation, symbols, auxiliary verbs, particles, and other frequent words common to both labels\n", "stopwords = {\n", "    \"。\", \"、\", \"!\", \"!\", \"?\", \"(\", \")\", \"(\", \")\",\n", "    \"です\", \"ます\", \"する\", \"なる\", \"ある\", \"いる\",\n", "    \"れる\", \"いう\", \"から\", \"てる\", \"せる\",\n", "    \"が\", \"の\", \"た\", \"は\", \"に\", \"だ\", \"て\", \"と\", \"も\", \"を\", \"お\", \"で\", \"か\",\n", "    \"こと\", \"方\", \"的\",\n", "    \"な\", \"ば\", \"ね\", \"や\", \"ず\", \"つ\", \"ぬ\",\n", "    \"よう\", \"ござる\",\n", "    \".\", \"・\",\n", "}" ] }, { "cell_type": "code", "execution_count": 54, "id": "c9294a1a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('ない', 107),\n", " ('残念', 91),\n", " ('部屋', 76),\n", " ('風呂', 53),\n", " ('ただ', 40),\n", " ('思う', 39),\n", " ('少し', 38),\n", " ('狭い', 33),\n", " ('朝食', 30),\n", " ('時', 25),\n", " ('悪い', 25),\n", " ('ちょっと', 24),\n", " ('気', 24),\n", " ('フロント', 18),\n", " ('人', 17),\n", " ('良い', 17),\n", " ('掃除', 17),\n", " ('入る', 17),\n", " ('しまう', 16),\n", " ('時間', 15),\n", " ('もう', 15),\n", " ('だけ', 15),\n", " ('とても', 15),\n", " ('駐車', 14),\n", " ('事', 14),\n", " ('露天', 14),\n", " ('トイレ', 13),\n", " ('髪の毛', 13),\n", " ('場', 13),\n", " ('音', 13),\n", " ('私', 13),\n", " ('チェック', 13),\n", " ('イン', 13),\n", " ('高い', 12),\n", " ('シャワー', 
12),\n", " ('少ない', 12),\n", " ('1', 12),\n", " ('ホテル', 12),\n", " ('食事', 12),\n", " ('古い', 12),\n", " ('無い', 12),\n", " ('朝', 12),\n", " ('行く', 12),\n", " ('欲しい', 12),\n", " ('食べる', 12),\n", " ('大', 11),\n", " ('中', 11),\n", " ('洗面', 11),\n", " ('所', 11),\n", " ('夕食', 11)]" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "count_words(\n", "    train.query('label_num == 1')[\"text\"],\n", "    tokenize=lambda x: tokenize(x, stopwords=stopwords)\n", ").most_common(n=50)" ] }, { "cell_type": "code", "execution_count": 55, "id": "d037a603", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('良い', 417),\n", " ('部屋', 297),\n", " ('利用', 290),\n", " ('とても', 228),\n", " ('美味しい', 202),\n", " ('ホテル', 193),\n", " ('また', 191),\n", " ('満足', 183),\n", " ('思う', 179),\n", " ('風呂', 173),\n", " ('たい', 169),\n", " ('朝食', 167),\n", " ('できる', 126),\n", " ('最高', 123),\n", " ('対応', 108),\n", " ('広い', 108),\n", " ('大', 105),\n", " ('行く', 101),\n", " ('ない', 101),\n", " ('綺麗', 96),\n", " ('よい', 96),\n", " ('いただく', 91),\n", " ('温泉', 88),\n", " ('清潔', 79),\n", " ('接客', 77),\n", " ('気持ち', 75),\n", " ('食事', 70),\n", " ('宿', 69),\n", " ('旅行', 69),\n", " ('フロント', 68),\n", " ('快適', 68),\n", " ('駅', 66),\n", " ('頂く', 65),\n", " ('宿泊', 64),\n", " ('便利', 64),\n", " ('スタッフ', 60),\n", " ('いい', 60),\n", " ('ゆっくり', 58),\n", " ('感', 58),\n", " ('時間', 57),\n", " ('丁寧', 57),\n", " ('夕食', 56),\n", " ('露天', 51),\n", " ('本当', 51),\n", " ('まで', 50),\n", " ('楽しい', 49),\n", " ('近い', 49),\n", " ('入る', 49),\n", " ('出来る', 48),\n", " ('さん', 47)]" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "count_words(\n", "    train.query('label_num == 0')[\"text\"],\n", "    tokenize=lambda x: tokenize(x, stopwords=stopwords)\n", ").most_common(n=50)" ] }, { "cell_type": "markdown", "id": "4d189a60", "metadata": {}, "source": [ "```{note}\n", "Repeat refining the stopwords and checking the results until you find a suitable stopword list.\n", "```\n", "\n", "Once a suitable stopword list has been defined, apply it to the model and look at the results." ] }, { "cell_type": "code", "execution_count": 56, "id": "d219903f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf', SVC(random_state=0))])\n", "Fold: 0\n", "Fold: 1\n", "Fold: 2\n", "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer= at 0x7f226eab0f70>)),\n", " ('clf', SVC(random_state=0))])\n", "Fold: 0\n", "Fold: 1\n", "Fold: 2\n" ] } ], "source": [ "pipe = Pipeline([\n", "    (\"vect\", TfidfVectorizer()),\n", "    (\"clf\", SVC())\n", "])\n", "\n", "# Model-0: tokenizer without stopwords, Model-1: tokenizer with stopwords\n", "params = [\n", "    {\n", "        \"vect__tokenizer\": [tokenize, lambda x: tokenize(x, stopwords=stopwords)],\n", "        \"clf\": [\n", "            SVC(random_state=0),\n", "        ],\n", "    },\n", "]\n", "\n", "result = run_cv(pipe=pipe, params=params, cv=cv, X=train[\"text\"], y=train[\"label_num\"], stratify=train[\"label\"])" ] }, { "cell_type": "code", "execution_count": 58, "id": "b7ab70db", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 {'clf': SVC(random_state=0), 'vect__tokenizer': }\n", "1 {'clf': SVC(random_state=0), 'vect__tokenizer': at 0x7f226eab0f70>}\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_pr(result=result, y=train[\"label_num\"])" ] }, { "cell_type": "markdown", "id": "4b74e436", "metadata": {}, "source": [ "The results suggest that the stopwords do not improve performance, so we will not use them this time." ] }, { "cell_type": "markdown", "id": "63118c1e", "metadata": {}, "source": [ "### Tuning Hyperparameters\n", "\n", "This dataset is imbalanced, with few positive examples.\n", "For SVC, setting `class_weight=\"balanced\"` weights each class inversely to its frequency and may mitigate the class-imbalance problem,\n", "so let's check whether this parameter helps." ] }, { "cell_type": "code", "execution_count": 60, "id": "d35bbb7a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf', SVC(random_state=0))])\n", "Fold: 0\n", "Fold: 1\n", "Fold: 2\n", "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf', SVC(class_weight='balanced', random_state=0))])\n", "Fold: 0\n", "Fold: 1\n", "Fold: 2\n" ] } ], "source": [ "pipe = Pipeline([\n", "    (\"vect\", TfidfVectorizer(tokenizer=tokenize)),\n", "    (\"clf\", SVC(random_state=0))\n", "])\n", "\n", "params = [\n", "    {\n", "        \"clf__class_weight\": [None, \"balanced\"],\n", "    },\n", "]\n", "\n", "result = run_cv(pipe=pipe, params=params, cv=cv, X=train[\"text\"], y=train[\"label_num\"], stratify=train[\"label\"])" ] }, { "cell_type": "code", "execution_count": 61, "id": "32dfa53b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0 {'clf__class_weight': None}\n", "1 {'clf__class_weight': 'balanced'}\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "show_pr(result=result, y=train[\"label_num\"])" ] }, { "cell_type": "markdown", "id": "67935697", "metadata": {}, "source": [ "Using class_weight does not show an improvement either, so we will not use it this time." ] }, { "cell_type": "markdown", "id": "09677aca", "metadata": {}, "source": [ "## Evaluation on the Test Set\n", "\n", "Using the test set, we produce the final offline evaluation of the selected model.\n", "\n", "During cross-validation, validation data was carved out of the training data.\n", "For the test-set evaluation, we use the parameters chosen by cross-validation and train the model on\n", "**all of the training data**." ] }, { "cell_type": "code", "execution_count": 62, "id": "ab25edef", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf', SVC(random_state=0))])" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipe = Pipeline([\n", "    (\"vect\", TfidfVectorizer(tokenizer=tokenize)),\n", "    (\"clf\", SVC(random_state=0)),\n", "])\n", "\n", "pipe.fit(X=train[\"text\"], y=train[\"label_num\"])" ] }, { "cell_type": "code", "execution_count": 63, "id": "8ffba870", "metadata": {}, "outputs": [], "source": [ "test_pred = pipe.decision_function(X=test[\"text\"])" ] }, { "cell_type": "code", "execution_count": 65, "id": "a389d361", "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "PrecisionRecallDisplay.from_predictions(y_true=test[\"label_num\"], y_pred=test_pred, name=\"Tfidf-SVC\")" ] }, { "cell_type": "markdown", "id": "31fd4fd0", "metadata": {}, "source": [ "The result tells us the following.\n", "If we score the reviews with this model and check them in descending order of confidence,\n", "then until we have found 80% of the negative reviews (recall=0.8),\n", "roughly 7 out of every 10 reviews we check (precision=0.7) are actually negative.\n", "\n", "Without scoring, the proportion of negative reviews (estimated here from the training data) is" ] }, { "cell_type": "code", "execution_count": 68, "id": "8e111d13", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.1438810724714627" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(train.query('label_num == 1')) / len(train)" ] }, { "cell_type": "markdown", "id": "1e8b40ba", "metadata": {}, "source": [ "so checking reviews in arbitrary order would find a negative review only about 1 to 2 times in 10.\n", "We can therefore conclude that the model makes finding negative reviews considerably more efficient.\n", "\n", "If there are no problems, we deploy the model to the production environment and run an A/B test for online evaluation." ] }, { "cell_type": "markdown", "id": "0e7a1f9d", "metadata": {}, "source": [ "## Considering Improvements\n", "\n", "Even after a successful deployment, the next step is to aim for higher accuracy.\n", "\n", "Let's check what problems the current model has.\n", "To understand them, it is important to actually inspect the model's predictions,\n", "which show us in which kinds of cases the model tends to make mistakes.\n", "\n", "When inspecting predictions, it is important **not** to look at the predictions on the test set.\n", "If we look at the test-set results, we end up tuning the model to the test set,\n", "and the test-set evaluation can no longer be trusted.\n", "\n", "Instead, we inspect the misclassified cross-validation (out-of-fold) predictions, starting from those that were wrongly given the highest confidence scores." ] }, { "cell_type": "code", "execution_count": 70, "id": "f53100b1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Pipeline(steps=[('vect',\n", " TfidfVectorizer(tokenizer=)),\n", " ('clf', SVC(random_state=0))])\n", "Fold: 0\n", "Fold: 1\n", "Fold: 2\n" ] } ], "source": [ "pipe = Pipeline([\n", "    (\"vect\", TfidfVectorizer(tokenizer=tokenize)),\n", "    (\"clf\", SVC(random_state=0))\n", "])\n", "\n", "# An empty grid: ParameterGrid yields a single empty setting, so the pipeline runs as defined above\n", "params = dict()\n", "\n", "result = run_cv(pipe=pipe, params=params, cv=cv, X=train[\"text\"], y=train[\"label_num\"], stratify=train[\"label\"])" ] }, { "cell_type": "code", "execution_count": 71, "id": "a38ce75d", "metadata": {}, "outputs": [], "source": [ "# result[0] is (param, pred); keep the out-of-fold scores as a column for error analysis\n", "train[\"cv_baseline\"] = result[0][1]" ] }, { "cell_type": "markdown", "id": "ecd93bd0", "metadata": {}, "source": [ "We sort by confidence in descending order and check the examples whose true label is `0`, that is, reviews that are not negative but still received high scores (false positives)." ] }, { "cell_type": "code", "execution_count": 73, "id": "e50c901b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
 "<table border=\"1\" class=\"dataframe\">\n",
 "<thead>\n",
 "<tr style=\"text-align: right;\"><th></th><th>label</th><th>text</th><th>judges</th><th>label_num</th><th>cv_baseline</th></tr>\n",
 "</thead>\n",
 "<tbody>\n",
 "<tr><th>1107</th><td>neutral</td><td>悪くはありません。</td><td>{\"0\": 3}</td><td>0</td><td>0.873437</td></tr>\n",
 "<tr><th>2811</th><td>positive</td><td>臭いもない。</td><td>{\"1\": 3}</td><td>0</td><td>0.732087</td></tr>\n",
 "<tr><th>285</th><td>neutral</td><td>がゴロゴロ。</td><td>{\"0\": 3}</td><td>0</td><td>0.455168</td></tr>\n",
 "<tr><th>3579</th><td>neutral</td><td>タバコの臭いが気にならない方で熱い風呂が好きなら良いのでは。</td><td>{\"0\": 3}</td><td>0</td><td>0.368951</td></tr>\n",
 "<tr><th>3838</th><td>positive</td><td>前日の宿の部屋がちょっと狭かったので、部屋に入るなり広っ!</td><td>{\"1\": 3}</td><td>0</td><td>0.242181</td></tr>\n",
 "<tr><th>3876</th><td>neutral</td><td>1階のレストランは夜の営業をやめてしまったようでモーニングのパニーニも以前ほど種類がない(5...</td><td>{\"0\": 3}</td><td>0</td><td>0.240477</td></tr>\n",
 "<tr><th>1669</th><td>positive</td><td>ちょっと嬉しかったです。</td><td>{\"1\": 3}</td><td>0</td><td>0.218728</td></tr>\n",
 "<tr><th>3334</th><td>positive</td><td>またベッドの上に折り鶴が置いてあったのですが、何となくホッコリしました。</td><td>{\"1\": 3}</td><td>0</td><td>0.182037</td></tr>\n",
 "<tr><th>1254</th><td>neutral</td><td>部屋は狭いですが寝るだけなら十分です。</td><td>{\"0\": 3}</td><td>0</td><td>0.173821</td></tr>\n",
 "<tr><th>3236</th><td>neutral</td><td>温泉がないですが、お風呂があります。</td><td>{\"0\": 3}</td><td>0</td><td>0.168143</td></tr>\n",
 "</tbody>\n",
 "</table>\n",
 "</div>"
 ], "text/plain": [ " label text judges \\\n", "1107 neutral 悪くはありません。 {\"0\": 3} \n", "2811 positive 臭いもない。 {\"1\": 3} \n", "285 neutral がゴロゴロ。 {\"0\": 3} \n", "3579 neutral タバコの臭いが気にならない方で熱い風呂が好きなら良いのでは。 {\"0\": 3} \n", "3838 positive 前日の宿の部屋がちょっと狭かったので、部屋に入るなり広っ! {\"1\": 3} \n", "3876 neutral 1階のレストランは夜の営業をやめてしまったようでモーニングのパニーニも以前ほど種類がない(5... {\"0\": 3} \n", "1669 positive ちょっと嬉しかったです。 {\"1\": 3} \n", "3334 positive またベッドの上に折り鶴が置いてあったのですが、何となくホッコリしました。 {\"1\": 3} \n", "1254 neutral 部屋は狭いですが寝るだけなら十分です。 {\"0\": 3} \n", "3236 neutral 温泉がないですが、お風呂があります。 {\"0\": 3} \n", "\n", " label_num cv_baseline \n", "1107 0 0.873437 \n", "2811 0 0.732087 \n", "285 0 0.455168 \n", "3579 0 0.368951 \n", "3838 0 0.242181 \n", "3876 0 0.240477 \n", "1669 0 0.218728 \n", "3334 0 0.182037 \n", "1254 0 0.173821 \n", "3236 0 0.168143 " ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.sort_values(\"cv_baseline\", ascending=False).query('label_num == 0').head(n=10)" ] }, { "cell_type": "markdown", "id": "b71ab1ff", "metadata": {}, "source": [ "These results show that the model fails to classify reviews such as the following correctly.\n", "\n", "* Reviews containing negation (e.g. 悪くはない, \"not bad\")\n", "* Reviews that mention a problem but then reveal that it is resolved (e.g. 温泉はないがお風呂はある, \"there is no hot spring, but there is a bath\")\n", "\n", "What these cases have in common is that they require capturing context.\n", "This leads to the hypothesis that a model which captures context would improve accuracy.\n", "\n", "So what is a model that captures context? For example, we might consider the following approaches.\n", "\n", "* Use N-grams as features (for example, via the `ngram_range` parameter of TfidfVectorizer). Note, however, that this increases the feature dimensionality.\n", "* Use a neural network built on context-aware sentence vectors. Note, however, the computational cost and model size.\n", "\n", "In this way, we start from the baseline, analyze its results, decide on the next direction, and improve the model step by step." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 5 }