Rust製Deep Learningフレームワーク「Burn」を使ってTitanicデータを予測してみた

Rust ロゴ

最近、Rust製のDeep Learningフレームワーク Burn を知りました。

普段は競馬予測AIをPyTorchで作っていますが、

「PyTorchで訓練したモデルをONNX経由でBurnから推論できる」

という情報を見かけて興味を持ちました。

今回の構成

「PyTorch → ONNX → Burn」で推論できることを確認する

ことを目的に、Titanicデータセットを使って試してみました。なので学習精度は二の次になります。

今回試した構成は次の通りです。

Python(PyTorch)
↓
モデル学習
↓
ONNX出力
↓
Rust(Burn)
↓
推論

将来的には競馬予測AIでも、

Python
↓
学習

Rust
↓
推論

という構成になれれば良いな、と思ってます。

Python側でモデルを学習

データセットはKaggleのTitanicデータセットを使用しました。

  • train.csv
  • test.csv
  • gender_submission.csv

の3ファイルがありますが、今回は学習用の train.csv のみ使用しています。

使用した特徴量は以下の5項目です。

Pclass
Sex
Age
Fare
Embarked

欠損値補完とカテゴリ変換を行い、StandardScalerで標準化しています。

モデルはシンプルなMLPです。

class TitanicNet(nn.Module):
    def __init__(self, input_dim):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(input_dim, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1),
        )

    def forward(self, x):
        return self.net(x)

学習結果は次のようになりました。

Epoch 100 | train_loss: 0.5236 | valid_loss: 0.5497 | accuracy: 0.7709

今回は精度向上が目的ではないので、この程度でOKとします。

ONNX形式で出力

学習後のモデルをONNX形式で保存します。

(※) PyTorchモデルを一度「ONNX(オニキス)」という共通フォーマットで出力し、Burnのビルドスクリプト(build.rs)を使ってRustのコードへ自動変換する方法、という事らしいです。

dummy_input = torch.randn(1, len(feature_names), dtype=torch.float32)

torch.onnx.export(
    model,
    dummy_input,
    ONNX_PATH,
    input_names=["input"],
    output_names=["output"],
    opset_version=16,
)

実行すると次のファイルが生成されました。

models/
├── titanic.onnx
└── titanic.onnx.data

最初は

「.onnx.dataって何だろう?」

と思いましたが、最近のPyTorchではモデル構造と重みデータが分離される場合があるようで、

そのため、この2ファイルはセットで扱います。

Rustプロジェクト作成

続いてBurn用のRustプロジェクトを作成します。

cargo new predict_burn

Cargo.tomlにはBurn関連のライブラリを追加しました。

[dependencies]
burn = { version = "0.21", features = ["ndarray"] }
burn-store = "0.21"

[build-dependencies]
burn-onnx = "0.21"

ONNXからBurnコードを生成

build.rsを作成します。

use burn_onnx::ModelGen;

fn main() {
    ModelGen::new()
        .input("../train_titanic/src/models/titanic.onnx")
        .out_dir("model")
        .run_from_script();
}

その後、

cargo build

を実行すると、ONNXモデルからBurn用のRustコードが生成されます。

Burnで推論

生成されたモデルを読み込み、適当な入力データで推論を実行してみます。

let output = model.forward(input);

println!("{:?}", output);

結果は次のようになりました。

Tensor { primitive: Float(F32(Owned([[-0.103424996]], shape=[1, 1], ... ))) }

最初は

「なんだこの値は?」

と思いましたが、これは確率ではなく logit でした。

学習時に BCEWithLogitsLoss() を使用しているためです。

シグモイドで確率に変換

Burnでは次のようにシグモイド関数を適用できます。

use burn::tensor::activation;

let prob = activation::sigmoid(output);

println!("{:?}", prob);

実行結果は次のようになりました。

Tensor { primitive: Float(F32(Owned([[0.47416678]], shape=[1, 1], ... ))) }

つまり、

生存確率 47.4%

という意味になります。

やってみた感想

今回の一番の収穫は、

PyTorch
↓
ONNX
↓
Burn

の流れが実際に動くことを確認できたことです。

正直なところ、

「PyTorchで作ったモデルから、Rustで推論が(自分に)できるのだろうか?」

という半信半疑の状態で始めましたが、チャッピーの力を借りて、思ったよりスムーズに動作しました。

競馬予測AIでは学習はまだまだPyTorchを使うと思いますが、

学習: Python(PyTorch)

推論: Rust(Burn)

という構成も十分現実的だと感じています。

まだBurnを触り始めたばかりですが、今後は実際の競馬予測モデルでも試してみたいと思います。

以上になります。またお会いしましょう

鹿児島県の出水市という所に住んでいまして、インターネット周辺で色々活動して行きたいと思ってるところです。 Webサイト作ったり、サーバ設定したり、プログラムしたりしている、釣りと木工好きなMacユーザです。 今はデータサイエンスに興味を持って競馬AI予想を頑張ってます。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です


reCaptcha の認証期間が終了しました。ページを再読み込みしてください。

コメントする

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください