TensorFlow

【TensorFlow】Tensorの形状を変更する方法

【TensorFlow】Tensorの形状を変更する方法

Googleによって開発されている機械学習ライブラリであるTensorFlowで、Tensorの形状を変更する方法を解説します。

Tensorの形状を変更する方法

ディープラーニングに関するプログラミングをしていると色々な場面で、Tensorの形状を変換したくなるケースに直面します。そのような時にTensorの形状を変更する方法を紹介します。

Tensorの形状を変更する tf.reshape

基本的な使い方

Tensorの形状を変更したい場合には、以下のようにtf.reshapeを使って変換します。いくつかのケースで例を示します。

import numpy as np
import tensorflow as tf

n = 20
tensor = tf.constant(np.arange(1, n + 1), dtype=tf.float16)
print(tensor, "\n")

# ===== Tensorの形状を変更する
# 行ベクトル
reshaped_tensor_1 = tf.reshape(tensor, shape=(1, n))
print(reshaped_tensor_1, "\n")

# 列ベクトル
reshaped_tensor_2 = tf.reshape(tensor, shape=(n, 1))
print(reshaped_tensor_2, "\n")

# 2階のTensorへ変更
reshaped_tensor_3 = tf.reshape(tensor, shape=(5, 4))
print(reshaped_tensor_3, "\n")

# 3階のTensorへ変更
reshaped_tensor_4 = tf.reshape(tensor, shape=(5, 2, 2))
print(reshaped_tensor_4)
【実行結果】
tf.Tensor(
[ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
 19. 20.], shape=(20,), dtype=float16) 

tf.Tensor(
[[ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
  19. 20.]], shape=(1, 20), dtype=float16) 

tf.Tensor(
[[ 1.]
 [ 2.]
 [ 3.]
 [ 4.]
 [ 5.]
 [ 6.]
 [ 7.]
 [ 8.]
 [ 9.]
 [10.]
 [11.]
 [12.]
 [13.]
 [14.]
 [15.]
 [16.]
 [17.]
 [18.]
 [19.]
 [20.]], shape=(20, 1), dtype=float16) 

tf.Tensor(
[[ 1.  2.  3.  4.]
 [ 5.  6.  7.  8.]
 [ 9. 10. 11. 12.]
 [13. 14. 15. 16.]
 [17. 18. 19. 20.]], shape=(5, 4), dtype=float16) 

tf.Tensor(
[[[ 1.  2.]
  [ 3.  4.]]

 [[ 5.  6.]
  [ 7.  8.]]

 [[ 9. 10.]
  [11. 12.]]

 [[13. 14.]
  [15. 16.]]

 [[17. 18.]
  [19. 20.]]], shape=(5, 2, 2), dtype=float16)

上記例では、ただの数値列を行ベクトル、列ベクトル、2階のTensor、3階のTensorに形状変更しています。使用方法は簡単で、変換したいTensorと変換後の形状をshape引数に渡すのみです。

サイズが変更前後で合わないとエラー

tf.reshapeを使用した形状変更をする場合、形状変更するサイズが変更前後で合わないと以下のようにエラーとなるので注意しましょう。

import numpy as np
import tensorflow as tf

n = 20
tensor = tf.constant(np.arange(1, n + 1), dtype=tf.float16)
print(tensor, "\n")

# Tensorの形状を変更する(サイズが合わないとエラー)
reshaped_tensor = tf.reshape(tensor, shape=(4, 2, 2))
print(reshaped_tensor)
【実行結果】
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 20 values, but the requested shape has 16 [Op:Reshape]

上記の例では元のTensorのサイズが20なので、変更後もサイズが20でないといけません。しかし、変更後は4×2×2=16で一致していないためエラーとなります。

Note

tf.reshapeの公式ドキュメントはこちらを参照してください。

まとめ

Googleによって開発されている機械学習ライブラリであるTensorFlowで、Tensorの形状を変更する方法を解説しました。

Tensorの形状を変更したい場合には、tf.reshapeを使って変換します。Tensorの形状を変更する場面はよく発生しますので、しっかり使いこなせるようにしておくとよいでしょう。