中野智文のブログ

データ・マエショリストのメモ

scikit learn の GridSearchCV で検証事例に class_weight(sample_weight) をつける

背景

GridSearchCV で検証事例に sample_weight をつけるような引数はまだ存在しない。

github.com

でも使いたい。(使わないとうまく行かねー)

metrics 関数自体は sample_weight に対応しているんだよね〜。

対応

対応した metrics 関数をつくり、make_scorer で呼び出そう。

以前の記事

nakano-tomofumi.hatenablog.com

にちなんだ例で作る。

from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, log_loss, make_scorer
from sklearn.utils import class_weight 


def balanced_log_loss(y_true, y_pred, cw, eps=1e-15, normalize=True, labels=None):
    sample_weight = [cw[y] for y in y_true]
    return log_loss(y_true=y_true,
                              y_pred=y_pred,
                              eps=eps,
                              sample_weight=sample_weight,
                              labels=labels)



...
    labels = [0, 1] # np.unique(y)
    cw_array = class_weight.compute_class_weight(y=y,
                                                                                    class_weight='balanced',
                                                                                    classes=labels)
    cw = dict(zip(labels, cw_array))
    clf = GridSearchCV(
        LogisticRegression(penalty='l1',
                           class_weight='balanced',
                           fit_intercept=False),
        {'C': [0.01, 0.1, 1, 10, 100]},
        cv=KFold(n_splits=20, shuffle=True),
        scoring={'neg_log_loss': make_scorer(balanced_log_loss, 
                                                                        labels=labels,
                                                                        cw=cw, 
                                                                        greater_is_better=False), 
                 'accuracy': 'accuracy'},
        n_jobs=-1,
        refit='neg_log_loss'
    )
    clf.fit(X, y)
...

ちょっとだけ説明すると、balanced_log_loss という独自の metrics 関数を作っている。 その引数に、dict 型の cw というものを用意した。class_weight の略だ。 そして、評価される真の値(y_true)に基づいて、重みを与えている。 もし、予想の値に対して重みを与えたければ、y_pred[cw[y] for y in y_true]の中のy_trueと書き換えれば良い。

さて、この cw だが、学習オプションのclass_weight='balanced'を設定すると内部で呼ばれる compute_class_weight を使っている。 今回は訓練事例全体で cw を求めたが、cross validation 中の訓練事例のみで求めることは難しいようである。 検証事例のみで求めたければ、先の自作関数の中で y_truey_pred から作れば良いのだが、それはナンセンスだろう。

このようなclass weight と、先に作った自作関数を、make_scorer で呼び出す。 今回も neg_log_loss に変換するために、greater_is_better=False オプションを付与して呼び出した。

まとめ

  • 学習時だけでなく、cross validation の評価時にも、sample_weight を使いたい。
  • GridSearchCV にはそのようなオプションはない。
  • そこで、make_scorer から呼び出すことができる自作の metrics 関数にそのようなオプションをつけて対応する。