TensorFlow

【TensorFlow/Keras】DCGAN(Deep Convolutional GAN)の実装

【TensorFlowKeras】DCGAN(Deep Convolutional GAN)の実装

TensorFlow/Kerasを用いて生成モデルであるDCGAN(Deep Convolutional GAN)を構築し、MNISTの手書き文字データに適用する実装例について解説します。

DCGAN(Deep Convolutional GAN)

DCGAN(Deep Convolutional GAN)は、生成モデルのGAN(Generative Adversarial Network)のモデルの一種で畳み込みニューラルネットワークにより画像データを扱えるようにしたものです。

本記事では、TensorFlow/Kerasを用いてDCGANを実装する例について紹介します。実装例の紹介の前に、まずは、ベースとなるGANの考え方を説明します。その後、MNISTの手書き文字データに適用する例でのDCGANの構造を説明し、実装例を紹介していきます。

GAN(Generative Adversarial Network)

GAN(Generative Adversarial Network)は、生成モデルの一種です。日本語にすると「敵対的生成ネットワーク」と言われます。

生成モデルとは、データを生成する分布自体をデータから学習する方法です。例えば、人の顔画像の生成モデルが構築できると、その分布から任意の点をサンプルすることで実在しない顔画像を生成することができたりします。

例えば、こちら(https://thispersondoesnotexist.com/)のサイトはGANの派生技術のStyleGAN2というものを利用した顔画像の生成サービスです。

このサービスではリロードするたびに別の人物の顔が生成されます。すごいのはサイトリンクにも「this person does not exist(この人は存在しません)」と示されているように、表示されている人物は生成モデルから生成されている架空の人物の画像で、実際には実在しません。ビックリするぐらい高解像度の人の顔画像が生成できていることに驚くかと思います(私は最初見たときはかなり驚きました)。

さて、上記はGANの技術を使ったサービスの一例ですが、具体的にGANの考え方について概要を紹介します。GANは以下のような構成要素で成り立っています。

GAN(Generative Adversarial Network)の構成

GANは、Generator(生成器)Discriminator(識別器)という2つのネットワークで構成されます。

まず、Generator(生成器)は、ノイズから偽物のデータを生成します。そして、Discriminator(識別器)は、本物の訓練データと偽物の生成データを受け取り、本物か偽物かの判定をします。

その後、判定結果をもとに「GeneratorはDiscriminatorが本物と間違うように学習」「Discriminatorは本物と偽物を識別できる能力を向上させるように学習」します。

GeneratorとDiscriminatorの関係は、例えば「偽札を作る密造者」と「警察」、「芸術品の贋作者」と「鑑定者」のような関係性となっています。それぞれが質を相互に挙げていくことで次第に偽札や贋作は本物と見分けがつかない質になっていくわけです。

GANにAdversarial(敵対的)という言葉が入っているのはこのような敵対関係に由来しています。

生成モデルとしては他にも変分自己符号化器(VAE:VariationalAutoEncoder)といったものがあります。VAEの概要と実装例については「変分自己符号化器(VAE:VariationalAutoEncoder)の実装」でも紹介していますので興味があれば参考にしてください。

GANの目的関数

上記でGANの構成概要を見ましたが具体的にどういった目的関数を元に学習をしていくのかについても整理しておきます。

GANの目的関数としては以下のように表現されます。$G$がGenerator(生成器)、$D$がDiscriminator(識別器)を表します。

\[
\min_{G} \max_{D} V(G, D) = E_{x \sim P_{data}(x)}[\log D(x)] + E_{z \sim P_{z}}[\log(1-D(G(z))]
\]

言葉で書いてみると、目的関数$V(G, D)$を「$G$に対して最小に」「$D$に対して最大に」なるように学習するという意味です。$x \sim P_{data}(x)$の部分は、データの分布に従う$x$ということで、本物の訓練データ$x$と思ってもらえればよいです。一方、$z \sim P_{z}$の部分はガウス分布などからサンプリングしたノイズを表します。

では、目的関数$V(G, D)$各項目について見ていきましょう。

$E_{x \sim P_{data}(x)}[\log D(x)]$

$D(x)$は、$x$を本物と判別する確率を表します。つまり、本物である訓練データを正解と判定するということは$D(x)$は1に近いということになり、その場合はこの項の値が大きくなります。つまり、この項の意味としては「Discriminatorが本物を本物と見抜く指標である」ととらえることができます。

$E_{z \sim P_{z}}[\log(1-D(G(z))]$

$G(z)$はノイズから生成された偽物データを表しており、$D(G(z))$は、偽物データを本物と判断する確率を表します。$1-D(G(z))$ということは、偽物を偽物と判断した時にこの値は大きくなります。つまり、この項の意味としては「DiscriminatorがGeneratorが生成した偽物データを偽物と見抜く指標である」ととらえることができます。

上記を踏まえて考えると、$D$、$G$それぞれの視点で考えてみましょう。

Discriminatorの$D$に関してみたときには、本物を本物として、偽物は偽物として判断しなければならないため、$E_{x \sim P_{data}(x)}[\log D(x)]$と$E_{z \sim P_{z}}[\log(1-D(G(z))]$のそれぞれの値を大きくするように学習しなければなりません。この時、目的関数としては最大化する方向になります。

一方で、Generatorの$G$に関してみたときには、偽物を本物と判断してほしいため、$E_{z \sim P_{z}}[\log(1-D(G(z))]$を小さくするように、つまりは$D(G(z))$が1に近づく(偽物を本物と判断する)ように学習しなければなりません。この時、目的関数としては最小化する方向になります。

上記がGANの目的関数から見たときの説明になります。実際の実装を見たときに、上記の考え方は重要になってきますので考え方を理解してもらえるとよいかと思います。

DCGAN(Deep Convolutional GAN)

DCGAN(Deep Convolutional GAN)は、画像生成に特化したGANの手法であり、Deep Convolutionalというように、畳み込み層を使用したGANになります。

今回以降で紹介する実装例では、以下のような構成のDCGANで28×28のMNISTの手書き文字を生成することを考えてみます。もちろん以下の層設計や紹介する実装例が必ずしも適切というわけではありません。層の数を増やしたりフィルタ数を変えたりなど色々変えることができます。

DCGAN Generator MNIST

DCGANでは、ノイズ$z$から、アップサンプリングしていき画像を生成しており、畳み込みニューラルネットワークであるCNNの考え方を用います。

Generatorでは、上記の図のように100次元のノイズから徐々にアップサンプリングし、28×28×1の画像にマッピングします。

一方、Discriminatorでは、Generatorとは逆のネットワーク構造でCNNを適用していき、その画像が本物か偽物かの判定をします。Discriminatorについては、本物の学習データに対しても判定を行い、上記で紹介したような目的関数を使って、Discriminatorにとっては最大化、Generatorにとっては最小化するように訓練をします。

また、DCGANでは、BatchNormalization(バッチ正規化)やLeaky ReLUを導入して、学習を安定化させるということがされています。Leaky ReLUについては、「ニューラルネットワークの活性化関数まとめ」を見ていただくとどういった関数か含めて理解いただけるかと思います。

では、以降では具体的に以降でDCGANをMNISTの手書き文字データに適用する実装例について見ていきましょう。

DCGANの論文としては、「UNSUPERVISED REPRESENTATION LEARNING WITH DEEP CONVOLUTIONAL GENERATIVE ADVERSARIAL NETWORKS」(https://arxiv.org/pdf/1511.06434.pdf)が参考になります。

TensorFlow/KerasでのDCGAN(Deep Convolutional GAN)の実装

DCGANを使用したTensorFlow/Kerasでの実装例を紹介します。今回は、MNISTの手書き文字データに対して適用する例を考えてみたいと思います。ノイズの$z$の次元としては、100次元として実装してみます。

実装例

DCGANのTensorFlow/Kerasを用いた実装例を以下に紹介します。一旦、すべてのコードを載せますが、以降で部分ごとに解説をしていきます。

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.activations import tanh
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Reshape


def mnist_descriminator():
    """Discriminator(識別器)モデルの構築(MMIST用)

    Returns:
        識別器モデル
    """
    # ===== Discriminatorの構築
    # 入力
    inputs = keras.Input(shape=(28, 28, 1))
    # 畳み込み(14×14×64)
    x = Conv2D(64, kernel_size=5, strides=2, padding="same")(inputs)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    # 畳み込み(7×7×128)
    x = Conv2D(128, kernel_size=5, strides=2, padding="same")(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    # 全結合層への接続
    x = Flatten()(x)
    x = Dropout(0.3)(x)
    # 識別のための層
    outputs = Dense(1, activation="sigmoid")(x)
    # モデルの生成
    model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_discriminator")

    return model


def mnist_generator(latent_dim=100):
    """Generator(生成器)モデルの構築(MNIST用)

    Args:
        latent_dim: 潜在次元数(デフォルト: 100)

    Returns:
        生成器モデル
    """
    # ===== Generatorの構築
    # 入力
    inputs = keras.Input(shape=(latent_dim,))
    # Discriminatorと同じ数にマッピングし、reshapeする
    x = Dense(7 * 7 * 128)(inputs)
    x = Reshape((7, 7, 128))(x)
    # 逆畳み込み(7×7×128)
    x = Conv2DTranspose(128, kernel_size=5, strides=1, padding="same")(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    # 逆畳み込み(14×14×64)
    x = Conv2DTranspose(64, kernel_size=5, strides=2, padding="same")(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    # 逆畳み込み(28×28×1)
    x = Conv2DTranspose(1, kernel_size=5, strides=2, padding="same")(x)
    outputs = tanh(x)
    # モデルの生成
    model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_generator")

    return model


class DCGAN(keras.Model):
    """DCGANモデルクラス"""

    def __init__(self, discriminator, generator, latent_dim):
        """コンストラクタ

        Args:
            discriminator: 識別器
            generator: 生成器
            latent_dim: 潜在次元数
        """
        super(DCGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # オプティマイザ
        self.d_optimizer = None
        self.g_optimizer = None
        # 損失関数
        self.loss_fn = None
        # 損失値の指標追跡用
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

    def compile(self, d_optimizer, g_optimizer, loss_fn, **kwargs):
        """コンパイルメソッド

        Args:
            d_optimizer: Discriminator(識別器)のオプティマイザー
            g_optimizer: Generator(生成器)のオプティマイザー
            loss_fn: 損失関数
        """
        super(DCGAN, self).compile(**kwargs)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    @property
    def metrics(self):
        return [self.d_loss_metric, self.g_loss_metric]

    def discriminator_loss(self, real_output, fake_output):
        """Discriminator(識別器)の損失関数計算

        Args:
            real_output: 正解画像に対する識別結果
            fake_output: 偽物画像に対する識別結果

        Returns:
            Discriminator(識別器)の損失関数の値
        """
        # ===== 計算準備
        # 本物と偽物用のラベル生成(本物1, 偽物0)
        labels = tf.concat(
            [tf.ones_like(real_output), tf.zeros_like(fake_output)], axis=0
        )
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # 本物と偽物による識別結果を結合
        output = tf.concat([real_output, fake_output], axis=0)

        # ===== 損失値の計算
        d_loss = self.loss_fn(labels, output)

        return d_loss

    def generator_loss(self, fake_output):
        """Generator(生成器)の損失関数計算

        Args:
            fake_output: 偽物画像に対する識別結果

        Returns:
            Generator(生成器)の損失関数の値
        """
        # 生成画像に対する損失値
        g_loss = self.loss_fn(tf.ones_like(fake_output), fake_output)

        return g_loss

    def train_step(self, real_imgs):
        """DCGANの訓練ステップ

        Args:
            real_imgs: 正解画像

        Returns:
            損失関数の値リスト([識別器, 生成器])
        """
        # バッチサイズを取得
        batch_size = tf.shape(real_imgs)[0]
        # ノイズを生成
        noise_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Generator(生成器)を使って偽物の画像を生成
        generated_imgs = self.generator(noise_vectors)

        # ===== discriminatorの訓練
        with tf.GradientTape() as d_tape:
            # 正解画像に対する識別
            real_output = self.discriminator(real_imgs, training=True)
            # 偽物画像に対する識別
            fake_output = self.discriminator(generated_imgs, training=True)
            # 識別器の損失関数の値計算
            d_loss = self.discriminator_loss(real_output, fake_output)
        # 勾配計算
        d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_weights)
        # 重みの更新
        self.d_optimizer.apply_gradients(
            zip(d_grads, self.discriminator.trainable_weights)
        )

        # ===== generatorの訓練
        with tf.GradientTape() as g_tape:
            # 偽物画像に対する識別
            fake_output = self.discriminator(
                self.generator(noise_vectors), training=True
            )
            # 生成器の損失関数の値計算
            g_loss = self.generator_loss(fake_output)
        # 勾配計算
        g_grads = g_tape.gradient(g_loss, self.generator.trainable_weights)
        # 重みの更新
        self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_weights))

        # 損失関数の値を更新
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)

        # 損失関数の値の返却
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }


class GanMonitor(keras.callbacks.Callback):
    """GANの結果画像をエポックの終了時に保存するコールバッククラス"""

    def __init__(self, latent_dim=100):
        """コンストラクタ

        Args:
            latent_dim: 潜在次元数(デフォルト: 100)
        """
        # 生成画像数
        self.num_imgs = 25
        # 潜在次元数
        self.latent_dim = latent_dim
        # 生成の元となるノイズの生成
        self.noise_vectors = tf.random.normal(shape=(self.num_imgs, self.latent_dim))

    def on_epoch_end(self, epoch, logs):
        """エポック終了時実行メソッド

        Args:
            epoch: エポック数
            logs: ログ
        """
        # ノイズから偽物画像を生成
        generated_imgs = self.model.generator(self.noise_vectors, training=False)
        # 元のスケールに修正
        generated_imgs = generated_imgs * 127.5 + 127.5

        # 画像を表示
        fig = plt.figure(figsize=(5, 5))
        for i in range(generated_imgs.shape[0]):
            plt.subplot(5, 5, i + 1)
            plt.imshow(generated_imgs[i, :, :, 0], cmap="gray")
            plt.axis("off")
        fig.suptitle(f"Epoch: {epoch + 1:04d}")
        plt.savefig(f"generated_imgs_{epoch + 1:04d}.png")


def main():
    """メイン関数"""
    # ===== MNIST(エムニスト)データの読込
    (train_imgs, _), (test_imgs, _) = mnist.load_data()
    # 全ての画像を使って生成モデルを訓練するため訓練データとテストデータを結合
    train_imgs = np.concatenate([train_imgs, test_imgs], axis=0)
    # reshape
    train_imgs = train_imgs.reshape((70000, 28, 28, 1)).astype("float32")
    # 正規化(-1~1)
    train_imgs = (train_imgs - 127.5) / 127.5

    # ===== データセットの生成
    buffer_size = 32
    batch_size = 32
    train_dataset = tf.data.Dataset.from_tensor_slices(train_imgs)
    # シャフルしてバッチ化
    train_dataset = train_dataset.shuffle(buffer_size).batch(batch_size)

    # ===== DCGANの構築
    # 潜在次元数
    latent_dim = 100
    # Discriminator(識別器)の生成
    discriminator = mnist_descriminator()
    print(discriminator.summary())
    keras.utils.plot_model(discriminator, "discriminator.png", show_shapes=True)
    # Generator(生成器)の生成
    generator = mnist_generator(latent_dim=latent_dim)
    print(generator.summary())
    keras.utils.plot_model(generator, "generator.png", show_shapes=True)

    # ===== DCGANモデルの構築
    dcgan = DCGAN(discriminator, generator, latent_dim)

    # ===== モデルのコンパイル
    dcgan.compile(
        d_optimizer=keras.optimizers.Adam(1e-5),
        g_optimizer=keras.optimizers.Adam(1e-5),
        loss_fn=keras.losses.BinaryCrossentropy(),
        run_eagerly=True,
    )

    # ===== モデルの学習
    n_epochs = 50
    dcgan.fit(
        train_dataset, epochs=n_epochs, callbacks=[GanMonitor(latent_dim=latent_dim)]
    )


if __name__ == "__main__":
    main()
DCGAN MNIST手書き画像

DCGANを用いてMNIST手書き画像を生成した実行結果例。50エポック訓練を実施。最初はノイズ画像で何の文字が分からないが、徐々に数字らしい画像が生成できるようになる。

DCGAN実行結果例 Epoch1
DCGAN実行結果例 Epoch50

実装内容の解説

上記で紹介した実装例の各部分ごとに内容を紹介していきます。少し長くなりますが、順を追って確認してもらえれば理解が深まるかと思います。

Discriminator(識別器)の定義

def mnist_descriminator():
    """Discriminator(識別器)モデルの構築(MMIST用)

    Returns:
        識別器モデル
    """
    # ===== Discriminatorの構築
    # 入力
    inputs = keras.Input(shape=(28, 28, 1))
    # 畳み込み(14×14×64)
    x = Conv2D(64, kernel_size=5, strides=2, padding="same")(inputs)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    # 畳み込み(7×7×128)
    x = Conv2D(128, kernel_size=5, strides=2, padding="same")(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    # 全結合層への接続
    x = Flatten()(x)
    x = Dropout(0.3)(x)
    # 識別のための層
    outputs = Dense(1, activation="sigmoid")(x)
    # モデルの生成
    model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_discriminator")

    return model

GANにおけるDiscriminator(識別器)の定義をし、モデルを返却する関数です。今回は、MNIST画像なので、入力は28×28×1になります。その後、Conv2Dを使って順次畳み込みを行っています。

通常のCNNだと、畳み込み層とプーリング層を繰り返していきますが、今回はストライドが2の畳み込み層のみで構成されているのがポイントです。畳み込み層は局所的な情報を抽出するもので、プーリング層は情報の損失を抑えつつデータを圧縮しているものでした。この特徴からプーリング層を使うと局所的な情報が消えてしまいます。今回は画像を復元することが目的であるため局所的な情報が消えていってしまうことは望ましくありません。このような時には、ストライドを使ってダウンサンプリングしていくテクニックを使用します。

畳み込み層の後には、バッチ正規化(BatchNormalization)と活性化関数(LeakyReLU)を適用しています。そして、Flattenで平坦化した後にDropoutを挟みつつ、識別のためのDenseに接続しています。Dense層ではactivationを’sigmoid’にしていることから、最終的には0~1の確率値にマッピングしていることになります。

最後に、keras.Modelでモデルを構築して返却します。

Note

padding=’same’というのは元の画像とサイズが一致するようにパディングする設定なので、ダウンサンプリングされていることが気になる方もいるかもしれません。(私は最初そうでした)

CNNの画像サイズに関する有名な式として以下のような式があります。画像の幅・高さを$W$、フィルタサイズ$F$、ストライド$S$、パディング$P$としたときのCNNをかけた後の出力画像サイズです。

\[
\left(\frac{W+2P-F}{S} + 1\right) \times \left(\frac{W+2P-F}{S} + 1\right)
\]

padding=’same’の設定は、ストライド$S=1$の時に画像サイズが入力と同じになるように調整されます。つまりは、分子が入力と同じになるように$P$が設定されます。今回の場合は、ストライド$S=2$のため画像が半分ずつになっているわけです。

Generator(生成器)の定義

def mnist_generator(latent_dim=100):
    """Generator(生成器)モデルの構築(MNIST用)

    Args:
        latent_dim: 潜在次元数(デフォルト: 100)

    Returns:
        生成器モデル
    """
    # ===== Generatorの構築
    # 入力
    inputs = keras.Input(shape=(latent_dim,))
    # Discriminatorと同じ数にマッピングし、reshapeする
    x = Dense(7 * 7 * 128)(inputs)
    x = Reshape((7, 7, 128))(x)
    # 逆畳み込み(7×7×128)
    x = Conv2DTranspose(128, kernel_size=5, strides=1, padding="same")(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    # 逆畳み込み(14×14×64)
    x = Conv2DTranspose(64, kernel_size=5, strides=2, padding="same")(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    # 逆畳み込み(28×28×1)
    x = Conv2DTranspose(1, kernel_size=5, strides=2, padding="same")(x)
    outputs = tanh(x)
    # モデルの生成
    model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_generator")

    return model

GANにおけるGenerator(生成器)の定義をし、モデルを返却する関数です。Discriminator(識別器)と逆の処理をしていることが分かるかと思います。

入力は、潜在次元数になります。つまりはノイズの次元数です。ノイズを、7×7×128の全結合層につなげた後に、Reshapeで形状を変えています。その後Conv2DTransposeを使って、逆畳み込みをしていき最終的には28×28×1のMNIST画像と同じサイズの画像へ変換をしています。

Generatorの出力の活性化関数としては、双曲線正接関数(tanh: hypobolic tangent)を使用します。数式としては以下のようなものになっていて、値域としては[-1, 1]となります。

\[
\tanh(x) = \frac{e^{x}-e^{-x}}{e^{x}+e^{-x}}
\]

最後に、keras.Modelでモデルを構築して返却します。

DCGANモデルの定義

class DCGAN(keras.Model):
    """DCGANモデルクラス"""

DCGANのモデルは、keras.Modelを継承してDCGANクラスとして構築しています。このクラスは少し長いクラスなので部分ごとに解説していきます。

コンストラクタ (__init__)

    def __init__(self, discriminator, generator, latent_dim):
        """コンストラクタ

        Args:
            discriminator: 識別器
            generator: 生成器
            latent_dim: 潜在次元数
        """
        super(DCGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # オプティマイザ
        self.d_optimizer = None
        self.g_optimizer = None
        # 損失関数
        self.loss_fn = None
        # 損失値の指標追跡用
        self.d_loss_metric = keras.metrics.Mean(name="d_loss")
        self.g_loss_metric = keras.metrics.Mean(name="g_loss")

コンストラクタでは、DCGANクラスが使用する変数を定義しています。discriminatorとgeneratorは呼び出しもとで準備したものを引数として渡す構成としています。

オプティマイザや損失関数はクラス内のメソッドで使用するのでNoneとして準備しておきます。また、損失関数の値の指標追跡用として、keras.metrics.Meanで準備しています。

コンパイル (compile)

    def compile(self, d_optimizer, g_optimizer, loss_fn, **kwargs):
        """コンパイルメソッド

        Args:
            d_optimizer: Discriminator(識別器)のオプティマイザー
            g_optimizer: Generator(生成器)のオプティマイザー
            loss_fn: 損失関数
        """
        super(DCGAN, self).compile(**kwargs)
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

DCGANのモデルコンパイルようメソッドです。親クラスのcompileメソッドを呼び出すとともに、DCGAN用のオプティマイザと損失関数を設定しています。

損失値の指標の返却 (metrics)

    @property
    def metrics(self):
        return [self.d_loss_metric, self.g_loss_metric]

この@propertyデコレータがついているmetricsメソッドで指標を定義しておくことで、fit()やevaluate()を呼び出すたびにモデルが指標をリセットするように動作するようになります。

Discriminator(識別器)の損失関数の値計算 (discriminator_loss)

    def discriminator_loss(self, real_output, fake_output):
        """Discriminator(識別器)の損失関数計算

        Args:
            real_output: 正解画像に対する識別結果
            fake_output: 偽物画像に対する識別結果

        Returns:
            Discriminator(識別器)の損失関数の値
        """
        # ===== 計算準備
        # 本物と偽物用のラベル生成(本物1, 偽物0)
        labels = tf.concat(
            [tf.ones_like(real_output), tf.zeros_like(fake_output)], axis=0
        )
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # 本物と偽物による識別結果を結合
        output = tf.concat([real_output, fake_output], axis=0)

        # ===== 損失値の計算
        d_loss = self.loss_fn(labels, output)

        return d_loss

Discriminator(識別器)の損失関数の値を計算するためのメソッドです。入力としては、正解画像に対する識別結果(real_output)と偽物画像に対する識別結果(fake_output)を受け取ります。

GANの目的関数としては以下のような式になります。

\[
\min_{G} \max_{D} V(G, D) = E_{x \sim P_{data}(x)}[\log D(x)] + E_{z \sim P_{z}}[\log(1-D(G(z))]
\]

Discriminator(識別器)からすると、関数を最大化する必要があります。つまり、正解画像を正解として、偽物画像を偽物として識別できないといけません。ここでは、本物を1、偽物を0としたラベルと、Discriminator(識別器)の結果を使って、loss_fnの損失関数の値を計算しています。

        labels += 0.05 * tf.random.uniform(tf.shape(labels))

正解ラベルを作っているのですが、上記のように正解ラベルに少しノイズを載せています。こちらはTensorFlow/Kerasの書籍でもある「Pythonによるディープラーニング」に重要なトリックとして紹介されていたので適用してみています。

Generator(生成器)の損失関数の値計算 (generator_loss)

    def generator_loss(self, fake_output):
        """Generator(生成器)の損失関数計算

        Args:
            fake_output: 偽物画像に対する識別結果

        Returns:
            Generator(生成器)の損失関数の値
        """
        # 生成画像に対する損失値
        g_loss = self.loss_fn(tf.ones_like(fake_output), fake_output)

        return g_loss

Generator(生成器)の損失関数の値を計算するためのメソッドです。入力としては、偽物画像に対する識別結果(fake_output)を受け取ります。

GANの目的関数としては以下のような式になります。

\[
\min_{G} \max_{D} V(G, D) = E_{x \sim P_{data}(x)}[\log D(x)] + E_{z \sim P_{z}}[\log(1-D(G(z))]
\]

Generator(生成器)からすると、関数を最小化する必要があります。つまり、偽物画像を正解として誤識別させないといけません。つまり、偽物を本物の1と判断させるようにloss_fnの損失関数の値を計算しています。

DCGANの訓練ステップ (train_step)

    def train_step(self, real_imgs):
        """DCGANの訓練ステップ

        Args:
            real_imgs: 正解画像

        Returns:
            損失関数の値リスト([識別器, 生成器])
        """
        # バッチサイズを取得
        batch_size = tf.shape(real_imgs)[0]
        # ノイズを生成
        noise_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        # Generator(生成器)を使って偽物の画像を生成
        generated_imgs = self.generator(noise_vectors)

        # ===== discriminatorの訓練
        with tf.GradientTape() as d_tape:
            # 正解画像に対する識別
            real_output = self.discriminator(real_imgs, training=True)
            # 偽物画像に対する識別
            fake_output = self.discriminator(generated_imgs, training=True)
            # 識別器の損失関数の値計算
            d_loss = self.discriminator_loss(real_output, fake_output)
        # 勾配計算
        d_grads = d_tape.gradient(d_loss, self.discriminator.trainable_weights)
        # 重みの更新
        self.d_optimizer.apply_gradients(
            zip(d_grads, self.discriminator.trainable_weights)
        )

        # ===== generatorの訓練
        with tf.GradientTape() as g_tape:
            # 偽物画像に対する識別
            fake_output = self.discriminator(
                self.generator(noise_vectors), training=True
            )
            # 生成器の損失関数の値計算
            g_loss = self.generator_loss(fake_output)
        # 勾配計算
        g_grads = g_tape.gradient(g_loss, self.generator.trainable_weights)
        # 重みの更新
        self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_weights))

        # 損失関数の値を更新
        self.d_loss_metric.update_state(d_loss)
        self.g_loss_metric.update_state(g_loss)

        # 損失関数の値の返却
        return {
            "d_loss": self.d_loss_metric.result(),
            "g_loss": self.g_loss_metric.result(),
        }

DCGANの実際の訓練ステップを実装しているメソッドです。以下のような流れで処理が進みます。

  1. latent_dim次元のノイズを生成する。
  2. ノイズからGenerator(生成器)を通して偽物画像を生成する。
  3. Discriminator(識別器)に正解画像と偽物画像をそれぞれ渡して識別する。識別した結果をdiscriminator_lossメソッドに渡して損失関数の値を計算する。計算結果をもとに勾配を計算して、discriminatorモデルの重みを更新する。
  4. Generator(生成器)を通した偽物画像を識別した結果をgenerator_lossメソッドに渡して損失関数の値を計算する。計算結果をもとに勾配を計算して、generatorモデルの重みを更新する。
  5. Discriminator(識別器)とGenerator(生成器)それぞれの損失関数の値を更新して返却する。

③と④については、それぞれで勾配や重みの更新を行います。勾配や重み計算をするには「with tf.GradientTape() as tape:」といった構成のブロックに順伝搬の処理を記載します。その後、「tape.gradient」で勾配を計算し、「optimizer.apply_gradients」といった形で重みを反映します。この流れは、TensorFlowでの勾配計算や重み更新の一般的な構成です。

損失関数の更新は、metricのupdate_stateで更新ができます。最後に損失関数の値を返却します。

GANの結果画像をエポック終了時に保存する(コールバック)

class GanMonitor(keras.callbacks.Callback):
    """GANの結果画像をエポックの終了時に保存するコールバッククラス"""

    def __init__(self, latent_dim=100):
        """コンストラクタ

        Args:
            latent_dim: 潜在次元数(デフォルト: 100)
        """
        # 生成画像数
        self.num_imgs = 25
        # 潜在次元数
        self.latent_dim = latent_dim
        # 生成の元となるノイズの生成
        self.noise_vectors = tf.random.normal(shape=(self.num_imgs, self.latent_dim))

    def on_epoch_end(self, epoch, logs):
        """エポック終了時実行メソッド

        Args:
            epoch: エポック数
            logs: ログ
        """
        # ノイズから偽物画像を生成
        generated_imgs = self.model.generator(self.noise_vectors, training=False)
        # 元のスケールに修正
        generated_imgs = generated_imgs * 127.5 + 127.5

        # 画像を表示
        fig = plt.figure(figsize=(5, 5))
        for i in range(generated_imgs.shape[0]):
            plt.subplot(5, 5, i + 1)
            plt.imshow(generated_imgs[i, :, :, 0], cmap="gray")
            plt.axis("off")
        fig.suptitle(f"Epoch: {epoch + 1:04d}")
        plt.savefig(f"generated_imgs_{epoch + 1:04d}.png")

今回のDCGANは、学習の過程を損失値だけ見ていてもイメージしづらいので、各エポックの終了時にその時点で学習したモデルを使った生成結果画像を保存するようにします。

Kerasでは訓練中の色々なポイントでカスタムコールバックを定義できるようになっています。例えば、「on_epoch_begin(epoch, logs):各エポックの最初に呼び出される」「on_epoch_end(epoch, logs):各エポックの最後に呼び出される」といった具合です。他にもon_batch_begin(batch, logs)、on_batch_end(batch, log)、on_train_begin(logs)、on_train_end(logs)のようにバッチや訓練の前後に呼び出されるものもあります。

各メソッドの呼び出しでは常にlogs引数が渡され、一つ前のエポック、バッチ、訓練に関する情報を含んだ辞書が渡されます。

今回は、keras.callbacks.Callbackクラスを継承したクラスを作成して、on_epoch_end(epoch, logs)のカスタムコールバックで、ノイズからGenerator(生成器)で画像を生成し、保存する処理を実行しています。

コンストラクタでノイズを生成しておいて、そのノイズに対する生成結果がどのように変化しているかを追えるようになっています。

実行部分(main)の内容

上記までで、DCGAN実行のための構成要素が出来上がりました。では、最後に全体の処理を実行しているmain部分について説明していきます。

MNISTデータの準備とデータセットの生成

    # ===== MNIST(エムニスト)データの読込
    (train_imgs, _), (test_imgs, _) = mnist.load_data()
    # 全ての画像を使って生成モデルを訓練するため訓練データとテストデータを結合
    train_imgs = np.concatenate([train_imgs, test_imgs], axis=0)
    # reshape
    train_imgs = train_imgs.reshape((70000, 28, 28, 1)).astype("float32")
    # 正規化(-1~1)
    train_imgs = (train_imgs - 127.5) / 127.5

    # ===== データセットの生成
    buffer_size = 32
    batch_size = 32
    train_dataset = tf.data.Dataset.from_tensor_slices(train_imgs)
    # シャフルしてバッチ化
    train_dataset = train_dataset.shuffle(buffer_size).batch(batch_size)

今回対象とするMNISTデータを読み込んでいる部分です。今回生成モデルを訓練するために訓練データとテストデータは結合してまとめてしまっています。また、データとしてはreshapeして、-1~1の間に入るようにしています。これは、Generatorが生成する画像が$\tanh$を通すことから-1~1になることに関連しています。

Discriminator(識別器)とGenerator(生成器)の生成

    # ===== DCGANの構築
    # 潜在次元数
    latent_dim = 100
    # Discriminator(識別器)の生成
    discriminator = mnist_descriminator()
    print(discriminator.summary())
    keras.utils.plot_model(discriminator, "discriminator.png", show_shapes=True)
    # Generator(生成器)の生成
    generator = mnist_generator(latent_dim=latent_dim)
    print(generator.summary())
    keras.utils.plot_model(generator, "generator.png", show_shapes=True)

Discriminator(識別器)とGenerator(生成器)は、main内で生成してDCGANクラスに渡しています。summaryやplot_modelにより構造を出力すると構造が確認しやすくなります。

DCGANモデルの生成と訓練

    # ===== DCGANモデルの構築
    dcgan = DCGAN(discriminator, generator, latent_dim)

    # ===== モデルのコンパイル
    dcgan.compile(
        d_optimizer=keras.optimizers.Adam(1e-5),
        g_optimizer=keras.optimizers.Adam(1e-5),
        loss_fn=keras.losses.BinaryCrossentropy(),
        run_eagerly=True,
    )

    # ===== モデルの学習
    n_epochs = 50
    dcgan.fit(
        train_dataset, epochs=n_epochs, callbacks=[GanMonitor(latent_dim=latent_dim)]
    )

この部分で、DCGANモデルの生成~訓練(学習)までを実行しています。

DCGANをインスタンス化する際には、discriminatorとgeneratorを引数として渡します。コンパイル時には、d_optimizer(識別器用オプティマイザ)とg_optimizer(生成器用オプティマイザ)にAdamを指定しています。また、損失関数としては、BinaryCrossentropyを設定しています。

また、run_eagerlyをTrueにしておくと、モデル内の処理でprintで値を表示したりすることができますのでデバッグで便利です。これを指定しないとtrain_step()メソッド内等でprintしてもテンソルの値を確認できません。

モデルの学習は、fit()を使用します。この時に、上記で説明したカスタムコールバックのGanMonitorを設定することで、エポックの終了時に生成画像を保存することができます。

GIF画像での結果確認

今回は、DCGANの実行結果画像を保存するコールバックを作成しました。作成した画像をGIF画像で見られると実行結果が分かりやすいため、以下のスクリプトでGIF画像化しています。

import glob
import imageio

gif_filename = "dcgan_mnist.gif"

with imageio.get_writer(gif_filename, mode="I") as writer:
    filenames = glob.glob("generated_imgs_*.png")
    filenames = sorted(filenames)
    for filename in filenames:
        image = imageio.v3.imread(filename)
        writer.append_data(image)

imageioがインストールされておらず実行ができない場合は、pip installでimageioをインストールしてから実行するようにしてください。

pip install imageio

実行結果例(再掲)

DCGAN MNIST手書き画像

DCGANを用いてMNIST手書き画像を生成した実行結果例。50エポック訓練を実施。最初はノイズ画像で何の文字が分からないが、徐々に数字らしい画像が生成できるようになる。

DCGAN実行結果例 Epoch1
DCGAN実行結果例 Epoch50

実行結果を再掲しますが、最初のエポック1ではただのノイズ画像であったものが、訓練を繰り返すうちに徐々に変化していき50エポックの頃にはそれなりに手書き画像のような数字が生成できていることが分かるかと思います。

以上が、DCGANの実装に関する説明でした。

まとめ

TensorFlow/Kerasを用いて生成モデルであるDCGAN(Deep Convolutional GAN)を構築し、MNISTの手書き文字データに適用する実装例について解説しました。

GANは目的関数の意味などを含めて最初はなかなか理解ができないのではないかなと感じます。具体的に手を動かして実装してみるとイメージがつきやすくなりますので、GANの内容を理解するのに本記事が少しでも役に立てばよいなと思います。

Pythonによるディープラーニング」はTensorFlow/Kerasの中~上級者向けの本ですが非常におすすめできる書籍です。CNN, RNN, Transformer, GAN等高度なモデルも扱っており面白く、TensorFlow/Kerasの実装力をつけることができますので是非読んでみてもらえるといいと思います。