XGBRankerを使って競馬AIのランク学習をやってみた

競馬AI
CatBoostや、LightGBMとのアンサンブル学習の続きでXGBoostのXGBRankerも試してみました。
XGBRankerもパラメータの指定だけでGPUを使って学習できるということで、計算はかなり早いです。

参考

LightGBM, CatBoostと違ったところ

  • XGBoostはint, float, boolのみなので、category変数をintへ変換する
  • 1つのブロック(1レース)単位の情報の作り方【*1】

学習していく

まだ私自身が学習始めたばっかりで間違った箇所があるかもしれませんが、そのときはご容赦くださればと思います。
import xgboost as xgb
1つのブロック(1レース)単位でランキングを出していきたいので、1レースのデータ数(出走頭数)を計算で求める。
(除外馬もいるかもしれないので特徴量の「出走頭数」からは取得しないでおく)
こんな感じで作って、XGBRankerに渡せば良いみたい。【*1】
train_groups = X_train.groupby('race_id').size().to_frame('size')['size'].to_numpy()
valid_groups = X_valid.groupby('race_id').size().to_frame('size')['size'].to_numpy()
groups = (train_groups, valid_groups)
パラメータの設定と学習
model = xgb.XGBRanker(
    tree_method      = 'gpu_hist', # GPUを使う
    booster          = 'gbtree',
    objective        = 'rank:pairwise',
    random_state     = 0,
    learning_rate    = 0.99, # 学習率
    colsample_bytree = 0.9,
    eta              = 0.05,
    max_depth        = 7,
    n_estimators     = 500,
    subsample        = 0.75, # 0 ~ 1
    min_child_weight = 5,   # 一番重要 決定木の下限
)

valids = (X_valid, y_valid)
model.fit(X_train, y_train, group=train_groups, eval_group=[valid_groups], eval_set=[valids], verbose='Verbose', early_stopping_rounds=15)

予測してみる

実際は芝レースとダートレースの2つのモデル(models)を作りました。
def predict(models, df):
    # XGBoostはint, float, boolのみなので、category変数をintへ変換する
    for col in df.columns:
        if df[col].dtypes == 'category':
            df[col] = df[col].astype(int)

    baba = df['surface']
    if baba.values[0] == 0:
        model = models[0]
    else:
        model = models[1]
    
    return model.predict(df.loc[:, ~df.columns.isin(['race_id'])])

predictions = (df_test_now.groupby('race_id').apply(lambda x: predict(models, x)))
上手く学習できているかはまだ分かってないけど、2016年〜2020年までのデータで学習・テストし、2021年のデータで検証した結果、1位の予想だけは回収率95%程だったので、とりあえずはこれで良しということにします。。。
以上になります

鹿児島県の出水市という所に住んでいまして、インターネット周辺で色々活動して行きたいと思ってるところです。 Webサイト作ったり、サーバ設定したり、プログラムしたりしている、釣りと木工好きなMacユーザです。 今はデータサイエンスに興味を持って競馬AI予想を頑張ってます。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です


The reCAPTCHA verification period has expired. Please reload the page.

コメントする

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください