TensorFlow

【TensorFlow】Tensorの形状を変更する方法

【TensorFlow】Tensorの形状を変更する方法

Googleによって開発されている機械学習ライブラリであるTensorFlowで、Tensorの形状を変更する方法を解説します。

Tensorの形状を変更する方法

ディープラーニングに関するプログラミングをしていると色々な場面で、Tensorの形状を変換したくなるケースに直面します。そのような時にTensorの形状を変更する方法を紹介します。

Tensorの形状を変更する tf.reshape

基本的な使い方

Tensorの形状を変更したい場合には、以下のようにtf.reshapeを使って変換します。いくつかのケースで例を示します。

import numpy as np
import tensorflow as tf

n = 20
tensor = tf.constant(np.arange(1, n + 1), dtype=tf.float16)
print(tensor, "\n")

# ===== Tensorの形状を変更する
# 行ベクトル
reshaped_tensor_1 = tf.reshape(tensor, shape=(1, n))
print(reshaped_tensor_1, "\n")

# 列ベクトル
reshaped_tensor_2 = tf.reshape(tensor, shape=(n, 1))
print(reshaped_tensor_2, "\n")

# 2階のTensorへ変更
reshaped_tensor_3 = tf.reshape(tensor, shape=(5, 4))
print(reshaped_tensor_3, "\n")

# 3階のTensorへ変更
reshaped_tensor_4 = tf.reshape(tensor, shape=(5, 2, 2))
print(reshaped_tensor_4)
【実行結果】
tf.Tensor(
[ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
 19. 20.], shape=(20,), dtype=float16) 

tf.Tensor(
[[ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
  19. 20.]], shape=(1, 20), dtype=float16) 

tf.Tensor(
[[ 1.]
 [ 2.]
 [ 3.]
 [ 4.]
 [ 5.]
 [ 6.]
 [ 7.]
 [ 8.]
 [ 9.]
 [10.]
 [11.]
 [12.]
 [13.]
 [14.]
 [15.]
 [16.]
 [17.]
 [18.]
 [19.]
 [20.]], shape=(20, 1), dtype=float16) 

tf.Tensor(
[[ 1.  2.  3.  4.]
 [ 5.  6.  7.  8.]
 [ 9. 10. 11. 12.]
 [13. 14. 15. 16.]
 [17. 18. 19. 20.]], shape=(5, 4), dtype=float16) 

tf.Tensor(
[[[ 1.  2.]
  [ 3.  4.]]

 [[ 5.  6.]
  [ 7.  8.]]

 [[ 9. 10.]
  [11. 12.]]

 [[13. 14.]
  [15. 16.]]

 [[17. 18.]
  [19. 20.]]], shape=(5, 2, 2), dtype=float16)

上記例では、ただの数値列を行ベクトル、列ベクトル、2階のTensor、3階のTensorに形状変更しています。使用方法は簡単で、変換したいTensorと変換後の形状をshape引数に渡すのみです。

サイズが変更前後で合わないとエラー

tf.reshapeを使用した形状変更をする場合、形状変更するサイズが変更前後で合わないと以下のようにエラーとなるので注意しましょう。

import numpy as np
import tensorflow as tf

n = 20
tensor = tf.constant(np.arange(1, n + 1), dtype=tf.float16)
print(tensor, "\n")

# Tensorの形状を変更する(サイズが合わないとエラー)
reshaped_tensor = tf.reshape(tensor, shape=(4, 2, 2))
print(reshaped_tensor)
【実行結果】
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 20 values, but the requested shape has 16 [Op:Reshape]

上記の例では元のTensorのサイズが20なので、変更後もサイズが20でないといけません。しかし、変更後は4×2×2=16で一致していないためエラーとなります。

Note

tf.reshapeの公式ドキュメントはこちらを参照してください。

まとめ

Googleによって開発されている機械学習ライブラリであるTensorFlowで、Tensorの形状を変更する方法を解説しました。

Tensorの形状を変更したい場合には、tf.reshapeを使って変換します。Tensorの形状を変更する場面はよく発生しますので、しっかり使いこなせるようにしておくとよいでしょう。

Pythonによるディープラーニング」はTensorFlow/Kerasの中~上級者向けの本ですが非常におすすめできる書籍です。CNN, RNN, Transformer, GAN等高度なモデルも扱っており面白く、TensorFlow/Kerasの実装力をつけることができますので是非読んでみてもらえるといいと思います。