機械学習、とりわけ深層学習でで大きなデータを扱うときに、はじめにメモリにすべてロードすることができない場合は少なくありません。
そんな大きなデータセットを扱う上で便利なのが PyTorch の iterable-style dataset です。 Iterable-style dataset を使うことで、サンプルをはじめに全てロードすることなく、学習に必要になったときにサンプルを準備して返すことができます。
その際に問題になるのが学習データのシャッフルです。
Map-style dataset では、DataLoader の shuffle
パラメータを True
に設定することで学習データをシャッフルします。
一方で iterable-style dataset でははじめに全データをロードするわけではなく、この方法が使えません。実際、iterable-style dataset で shuffle
パラメータを True
に設定すると例外が発生します。
この問題に対処するために、PyTorch 1.8.1 では BufferedShuffleDataset が提供されていました。
torch.utils.data.BufferedShuffleDataset(dataset, buffer_size)
BufferedShuffleDatasetは buffer_size
で指定したサイズのバッファを内部で作成し、サンプルはまずバッファに格納されます。
そして、バッファが満たされたらそのうちの一つをランダムサンプルして返します。そうするとバッファに一つ空きができますので、次のサンプルをバッファに格納します。
これを続けることで、バッファサイズ分のシャッフルを行いながらサンプルを返していくのです。
今回 PyTorch を 1.8.1 から 1.9.0 に上げてみたところ、
>>> import torch
>>> torch.utils.data.BufferedShuffleDataset
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: module 'torch.utils.data' has no attribute 'BufferedShuffleDataset'
のようなエラーが出てしまいました。
ぶーちゃんもショックを受けているようです。
ぶーちゃん:ぶ〜〜〜〜。。。。(うーなんでエラーが出ちゃうんだろう、しょんぼり。。。)
PyTorchのコミットを調べてみると次のようなものを見つけました。
コミットを読んでいくと、どうやらtorch.utils.data
以下から DataPipe という機能群に移されたようです。
実際にコードを確認すると torch.utils.data.datapipes.iter.combinatorics
で ShuffleIterDataPipe
というクラスで定義され、torch.utils.data.datapipes.iter の名前空間で Shuffle という名前でimport されていることがわかります。
ということは、次のようにすればうまく動くのでは…
ぶーちゃん:ターンッ!(うごけ!)
>>> torch.utils.data.datapipes.iter.Shuffle
<class 'torch.utils.data.datapipes.iter.combinatorics.ShuffleIterDataPipe'>
インポートできているようです。 実際に引数にジェネレータを渡して動くか確かめてみましょう。
>>> shuffle_dataset = torch.utils.data.datapipes.iter.Shuffle(range(10), buffer_size=3)
>>> list(shuffle_dataset)
[2, 0, 3, 5, 1, 6, 7, 8, 4, 9]
期待通り動いているようですね!
ぶーちゃん:ぶおぉぉぉおおおっ!!!!(うまくいった!喜んだときの得意技、耳倒立)
今回は PyTorch 1.8.1 で提供されていた ShuffledBufferDataset が PyTorch 1.9.0 でインポートできてなくなっている原因を調べました。
コミットを調査すると torch.utils.data.datapipes.iter.Shuffle
に移動したようです。
PyTorch 1.9.0 のリリースノート にも書かれていなかったのであまり使われていない機能なのでしょうか。
もともと TensorFlow を使っていて tf.data.Dataset の shuffle
でバッファ付きのシャッフルという機能を知り、その後 PyTorch 1.8.1 で PyTorch に移行してきてから同等の機能を求めてドキュメントを読んでいたら見つけたクラスでした。
PyTorch 1.9.0 のリリースで削除されてしまったかと思いましたが、少なくとも今のところは torch.utils.data.datapipes.iter.Shuffle
を利用すればよさそうです。