機械学習モデルの標準フォーマット
機械学習モデルにも標準規格を
Python の(Python に限らずだが)機械学習フレームワークは複数存在する。
それぞれのフレームワークで実装され、そして生成される機械学習モデル(以下、 ML モデル)は、当然ながらそのフレームワーク・ライブラリに特化したものになる。 したがって、例えばライブラリAで実装・学習した ML モデルは、ライブラリA上のみでしか推論・再学習などを行えない。 ライブラリAで作ったモデルはライブラリB上で推論を実行することはできない。 ML モデルはライブラリ間で互換性がないのが基本である。
ところがこうした ML モデルについても、標準規格を設けて普及させる流れがあるらしい。 究極的には ML モデルの実装に使用するライブラリ・フレームワークに制限されることなくモデルをデプロイ、推論に使えるような世の中になる(本当か?)
ONNX とは
ONNX(Open Neural Network eXchange, 「オニキス」と呼ぶらしい1)は、まさにその ML モデル表現の標準フォーマットとして策定されたとのこと。
公式によると
ONNX is an open format built to represent machine learning models. ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers.
とあるので、やはり前述したような制限をなくすことを目標としていると見て良さそうだろう。
とはいえ、まだ策定されて日が浅い(初版は2017年9月)2ので、すべてのフレームワーク・ライブラリが対応しているわけではなく、また対応しているものでも一部機能に制限があるなどは十分考えられる。
今後どれほど普及していくか注目したいところだ。
ちょっとやってみた
PyTorch の学習済みモデルを ONNX モデルに変換するチュートリアルが用意されていたので、これを参考に勉強してみた。
基本的にチュートリアルに示されていることなので、コードの詳細は省略するが、基本的に PyTorch のモデルを変換する場合 torch.onnx.export
関数を使って学習済みモデルを ONNX ファイルを出力するということを抑えておきたい。
なお、エクスポートする前に推論モードにしておくことを忘れない。
model.eval()
チュートリアルでは、ここから更に ONNX をサポートするランタイムを持ってきて、それを使って ONNX モデルをロード・推論するパートが続く。 ONNX モデルを動かすランタイムは、クラウド・エッジを問わず色々なアーキテクチャで動くものが用意されている模様[^3]。
普通の x64 アーキテクチャなので、以下のようにランタイムをインストール。
pip install onnx
pip install onnxruntime
あとはこのランタイムを使って ONNX モデルを動かす。
>>> import onnxruntime
>>>
>>> # AlexNet というモデルを ONNX に書き出したものを予め保存
>>> ort_session = onnxruntime.InferenceSession('alexnet.onnx')
>>>
>>> outputs = ort_session.run(None, {'actual_input_1': np.random.randn(10, 3, 224, 224).astype(np.float32)})
>>> print(outputs[0])
[[-0.07325774 -1.4891567 -1.7034553 ... -1.09847 -1.1589897
1.3210275 ]
[-0.112299 -1.5629382 -1.5682833 ... -1.2115804 -0.6649598
0.9185266 ]
[-0.17365408 -0.8043759 -1.5489596 ... -1.323276 -0.7312381
0.92378426]
...
[-0.35993284 -1.3951778 -1.1727893 ... -1.1726025 -1.1295764
1.3196399 ]
[-0.40342498 -1.4207238 -1.6986057 ... -1.1848352 -0.64225703
0.9559887 ]
[ 0.3027296 -1.101376 -1.243785 ... -1.0959435 -0.9556453
0.9666304 ]]
確かに推論が動作している。
[^3]; https://onnxruntime.ai/