Rust製Deep Learningフレームワーク「Burn」を使ってTitanicデータを予測してみた
最近、Rust製のDeep Learningフレームワーク Burn を知りました。
普段は競馬予測AIをPyTorchで作っていますが、
「PyTorchで訓練したモデルをONNX経由でBurnから推論できる」
という情報を見かけて興味を持ちました。
今回の構成
「PyTorch → ONNX → Burn」で推論できることを確認する
ことを目的に、Titanicデータセットを使って試してみました。なので学習精度は二の次になります。
今回試した構成は次の通りです。
Python(PyTorch)
↓
モデル学習
↓
ONNX出力
↓
Rust(Burn)
↓
推論
将来的には競馬予測AIでも、
Python
↓
学習
Rust
↓
推論
という構成になれれば良いな、と思ってます。
Python側でモデルを学習
データセットはKaggleのTitanicデータセットを使用しました。
- train.csv
- test.csv
- gender_submission.csv
の3ファイルがありますが、今回は学習用の train.csv のみ使用しています。
使用した特徴量は以下の5項目です。
Pclass
Sex
Age
Fare
Embarked
欠損値補完とカテゴリ変換を行い、StandardScalerで標準化しています。
モデルはシンプルなMLPです。
class TitanicNet(nn.Module):
def __init__(self, input_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, 16),
nn.ReLU(),
nn.Linear(16, 8),
nn.ReLU(),
nn.Linear(8, 1),
)
def forward(self, x):
return self.net(x)
学習結果は次のようになりました。
Epoch 100 | train_loss: 0.5236 | valid_loss: 0.5497 | accuracy: 0.7709
今回は精度向上が目的ではないので、この程度でOKとします。
ONNX形式で出力
学習後のモデルをONNX形式で保存します。
(※) PyTorchモデルを一度「ONNX(オニキス)」という共通フォーマットで出力し、Burnのビルドスクリプト(build.rs)を使ってRustのコードへ自動変換する方法、という事らしいです。
dummy_input = torch.randn(1, len(feature_names), dtype=torch.float32)
torch.onnx.export(
model,
dummy_input,
ONNX_PATH,
input_names=["input"],
output_names=["output"],
opset_version=16,
)
実行すると次のファイルが生成されました。
models/
├── titanic.onnx
└── titanic.onnx.data
最初は
「.onnx.dataって何だろう?」
と思いましたが、最近のPyTorchではモデル構造と重みデータが分離される場合があるようで、
そのため、この2ファイルはセットで扱います。
Rustプロジェクト作成
続いてBurn用のRustプロジェクトを作成します。
cargo new predict_burn
Cargo.tomlにはBurn関連のライブラリを追加しました。
[dependencies]
burn = { version = "0.21", features = ["ndarray"] }
burn-store = "0.21"
[build-dependencies]
burn-onnx = "0.21"
ONNXからBurnコードを生成
build.rsを作成します。
use burn_onnx::ModelGen;
fn main() {
ModelGen::new()
.input("../train_titanic/src/models/titanic.onnx")
.out_dir("model")
.run_from_script();
}
その後、
cargo build
を実行すると、ONNXモデルからBurn用のRustコードが生成されます。
Burnで推論
生成されたモデルを読み込み、適当な入力データで推論を実行してみます。
let output = model.forward(input);
println!("{:?}", output);
結果は次のようになりました。
Tensor { primitive: Float(F32(Owned([[-0.103424996]], shape=[1, 1], ... ))) }
最初は
「なんだこの値は?」
と思いましたが、これは確率ではなく logit でした。
学習時に BCEWithLogitsLoss() を使用しているためです。
シグモイドで確率に変換
Burnでは次のようにシグモイド関数を適用できます。
use burn::tensor::activation;
let prob = activation::sigmoid(output);
println!("{:?}", prob);
実行結果は次のようになりました。
Tensor { primitive: Float(F32(Owned([[0.47416678]], shape=[1, 1], ... ))) }
つまり、
生存確率 47.4%
という意味になります。
やってみた感想
今回の一番の収穫は、
PyTorch
↓
ONNX
↓
Burn
の流れが実際に動くことを確認できたことです。
正直なところ、
「PyTorchで作ったモデルから、Rustで推論が(自分に)できるのだろうか?」
という半信半疑の状態で始めましたが、チャッピーの力を借りて、思ったよりスムーズに動作しました。
競馬予測AIでは学習はまだまだPyTorchを使うと思いますが、
学習: Python(PyTorch)
推論: Rust(Burn)
という構成も十分現実的だと感じています。
まだBurnを触り始めたばかりですが、今後は実際の競馬予測モデルでも試してみたいと思います。
以上になります。またお会いしましょう





