TensorFlow

【TensorFlow】Tensorのシャッフル方法

【TensorFlow】Tensorのシャッフル方法

Googleによって開発されている機械学習ライブラリであるTensorFlowで、Tensorをシャッフルする方法を解説します。

Tensorのシャッフル方法

ディープラーニングに関する処理の実装では、データをシャッフルする場合が頻繁に発生します。

モデルの訓練の際には、データセットをいくつかのバッチ単位で学習させるといったことをしますが、その際に同じ特徴のデータが固まっている場合、そのままの順番のまま区切るようなことをすると一つのバッチに含まれるデータの種類が偏ってしまいます。

このような際には、データセットのTensorをランダムにシャッフルしてから処理することが大切です。以降では、Tensorをランダムにシャッフルする方法について説明していきます。

Tensorのシャッフル方法 tf.random.shuffle

Tensorの要素をランダムにシャッフルするには、以下のようにtf.random.shuffleを使用します。

実行結果を固定するためにtf.random.set_seedでglobal seedを設定し、shuffleメソッドでは、seed引数によってoperation seedを設定しています。2種類のシードの違いについては後ほど説明します。

結果が毎回ランダムになるように動作させ場合は、シード設定はしないようにすればよいです。

import numpy as np
import tensorflow as tf

tensor = tf.constant(np.arange(1, 51), shape=(10, 5))
print(tensor, "\n")

# グローバルな乱数シードの設定
tf.random.set_seed(0)

# Tensorをシャッフルする (オペレーションの乱数シードの設定)
shuffled_tensor = tf.random.shuffle(tensor, seed=1)
print(shuffled_tensor)
【実行結果】
tf.Tensor(
[[ 1  2  3  4  5]
 [ 6  7  8  9 10]
 [11 12 13 14 15]
 [16 17 18 19 20]
 [21 22 23 24 25]
 [26 27 28 29 30]
 [31 32 33 34 35]
 [36 37 38 39 40]
 [41 42 43 44 45]
 [46 47 48 49 50]], shape=(10, 5), dtype=int32) 

tf.Tensor(
[[36 37 38 39 40]
 [11 12 13 14 15]
 [26 27 28 29 30]
 [41 42 43 44 45]
 [16 17 18 19 20]
 [21 22 23 24 25]
 [31 32 33 34 35]
 [ 1  2  3  4  5]
 [46 47 48 49 50]
 [ 6  7  8  9 10]], shape=(10, 5), dtype=int32)

結果を見ると分かるように、tf.random.shuffleは0番目の軸に対してデータをシャッフルします。上記の例では、行がデータ番号、列が属性値を表していると考えると、データの順番がうまくシャッフルできることが分かるかと思います。

また、例えば画像データで(データ数, チャンネル, 縦, 横)のようなTensorを用意してある場合を考えてみましょう。この場合では、shuffleで0番目の軸でシャッフルされることで、画像データの順番がうまくシャッフルされることになります。

Note

tf.random.shuffleの公式ドキュメントはこちらを参照してください。

乱数シードの種類と挙動

TensorFlowでの乱数シードには2つの種類があります。tf.random.set_seedで設定したglobal seedとtf.random.shuffleのseed引数で指定したoperation seedという2つです。これらの組み合わせで乱数の生成シーケンスが決まってきます。

乱数シードに関する説明については、公式ドキュメントのこちらに記載があります。ここでも簡単に整理しておきます。

  1. global seedとoperationseedがどちらも設定されていない場合:ランダムなシードが選択されて使用される。
  2. global seedが設定されていて、operation seedが設定されていない場合:同じバージョンのTensorFlowとユーザーコード内では、決まったランダムシーケンスが生成されるが、異なるバージョンのTensorFlowの場合、順序が異なる可能性がある。
  3. operation seedが設定されていて、global seedが設定されていない場合:デフォルトのglobal seedと指定したoperation seedでランダムシーケンスが生成される。
  4. global seedとoperation seedが両方設定されている場合:両方の指定したシードを使用してランダムシーケンスが生成される。

上記のように、2つの乱数シードによって乱数の生成シーケンスが変わってきます。バージョン等含めて同じような結果になるようにしたい場合には、global seedとoperation seedの両方を指定するとよいようです。

まとめ

Googleによって開発されている機械学習ライブラリであるTensorFlowで、Tensorをシャッフルする方法を解説しました。具体的には、tf.random.shuffleの使い方について紹介しています。

ディープラーニングでは、モデルの訓練時等でデータセットをシャッフルするということをよく行いますので、使い方をしっかり覚えておきましょう。

また、乱数シードには、global seedとoperation seedがあり、それらによって乱数のシーケンスが決まることについても説明しました。ディープラーニングの実装をしているときの動作確認で、毎回乱数が変わってしまうと精度の変化が確認できず困ってしまいますので、乱数のシードは固定して試すことが多くなります。乱数シードの設定方法と挙動についてよく理解しておくとよいでしょう。

Pythonによるディープラーニング」はTensorFlow/Kerasの中~上級者向けの本ですが非常におすすめできる書籍です。CNN, RNN, Transformer, GAN等高度なモデルも扱っており面白く、TensorFlow/Kerasの実装力をつけることができますので是非読んでみてもらえるといいと思います。