Googleによって開発されている機械学習ライブラリであるTensorFlowで、Tensorの形状を変更する方法を解説します。
Tensorの形状を変更する方法
ディープラーニングに関するプログラミングをしていると色々な場面で、Tensorの形状を変換したくなるケースに直面します。そのような時にTensorの形状を変更する方法を紹介します。
Tensorの形状を変更する tf.reshape
基本的な使い方
Tensorの形状を変更したい場合には、以下のようにtf.reshapeを使って変換します。いくつかのケースで例を示します。
import numpy as np import tensorflow as tf n = 20 tensor = tf.constant(np.arange(1, n + 1), dtype=tf.float16) print(tensor, "\n") # ===== Tensorの形状を変更する # 行ベクトル reshaped_tensor_1 = tf.reshape(tensor, shape=(1, n)) print(reshaped_tensor_1, "\n") # 列ベクトル reshaped_tensor_2 = tf.reshape(tensor, shape=(n, 1)) print(reshaped_tensor_2, "\n") # 2階のTensorへ変更 reshaped_tensor_3 = tf.reshape(tensor, shape=(5, 4)) print(reshaped_tensor_3, "\n") # 3階のTensorへ変更 reshaped_tensor_4 = tf.reshape(tensor, shape=(5, 2, 2)) print(reshaped_tensor_4)
【実行結果】 tf.Tensor( [ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20.], shape=(20,), dtype=float16) tf.Tensor( [[ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20.]], shape=(1, 20), dtype=float16) tf.Tensor( [[ 1.] [ 2.] [ 3.] [ 4.] [ 5.] [ 6.] [ 7.] [ 8.] [ 9.] [10.] [11.] [12.] [13.] [14.] [15.] [16.] [17.] [18.] [19.] [20.]], shape=(20, 1), dtype=float16) tf.Tensor( [[ 1. 2. 3. 4.] [ 5. 6. 7. 8.] [ 9. 10. 11. 12.] [13. 14. 15. 16.] [17. 18. 19. 20.]], shape=(5, 4), dtype=float16) tf.Tensor( [[[ 1. 2.] [ 3. 4.]] [[ 5. 6.] [ 7. 8.]] [[ 9. 10.] [11. 12.]] [[13. 14.] [15. 16.]] [[17. 18.] [19. 20.]]], shape=(5, 2, 2), dtype=float16)
上記例では、ただの数値列を行ベクトル、列ベクトル、2階のTensor、3階のTensorに形状変更しています。使用方法は簡単で、変換したいTensorと変換後の形状をshape引数に渡すのみです。
サイズが変更前後で合わないとエラー
tf.reshapeを使用した形状変更をする場合、形状変更するサイズが変更前後で合わないと以下のようにエラーとなるので注意しましょう。
import numpy as np import tensorflow as tf n = 20 tensor = tf.constant(np.arange(1, n + 1), dtype=tf.float16) print(tensor, "\n") # Tensorの形状を変更する(サイズが合わないとエラー) reshaped_tensor = tf.reshape(tensor, shape=(4, 2, 2)) print(reshaped_tensor)
【実行結果】 tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 20 values, but the requested shape has 16 [Op:Reshape]
上記の例では元のTensorのサイズが20なので、変更後もサイズが20でないといけません。しかし、変更後は4×2×2=16で一致していないためエラーとなります。
Note
tf.reshapeの公式ドキュメントはこちらを参照してください。
まとめ
Googleによって開発されている機械学習ライブラリであるTensorFlowで、Tensorの形状を変更する方法を解説しました。
Tensorの形状を変更したい場合には、tf.reshapeを使って変換します。Tensorの形状を変更する場面はよく発生しますので、しっかり使いこなせるようにしておくとよいでしょう。
ソースコード
上記で紹介しているソースコードについてはgithubにて公開しています。参考にしていただければと思います。