ニューラルネットの中間層から出力層までのネットワークを取り出す
kerasでニューラルネットの入力層から中間層までのネットワークを取り出すのはよくあるけど、中間層から最終層(出力層)までのネットワークを取り出してる例はあまり少ないので紹介する。例として、MNISTの数字画像のデータに対するオートエンコーダーのネットワークを考える。
まずは必要なライブラリのインポートとデータの取得&整形。訓練データとテストデータを定義する。
from keras.layers import Input, Dense from keras.models import Model from keras.datasets import mnist import numpy as np (x_train, y_train), (x_test, y_test) = mnist.load_data() image_size = x_train.shape[1] # = 784 original_dim = image_size * image_size x_train = np.reshape(x_train, [-1, original_dim]) x_test = np.reshape(x_test, [-1, original_dim]) x_train = x_train.astype('float32') / 255 x_test = x_test.astype('float32') / 255
次にモデルを構築する。一気に入力層から最終層までのネットワークを実装せずに、入力層から中間層までのネットワーク(=前半のネットワーク)と、中間層から最終層までのネットワーク(=後半のネットワーク)を分けてモデル化するのがミソ。もたらん分け方を変えれば、任意の部分ネットワークを引っ張ってこれる。
# 前半のネットワーク構築 encoding_dim = 32#中間層のノード数。これが小さいほど特徴量が圧縮される。 input_img = Input(shape=(784,)) x1 = Dense(256, activation='relu')(input_img) x2 = Dense(64, activation='relu')(x1) encoded = Dense(encoding_dim, activation='relu')(x2) encoder = Model(input_img, encoded) encoder.summary() # 後半のネットワーク構築 input_hidden=Input(shape=(encoding_dim),) x3 = Dense(64, activation='relu')(input_hidden) x4 = Dense(256, activation='relu')(x3) decoded = Dense(784, activation='sigmoid')(x4) decoder=Model(input_hidden,decoded) decoder.summary() # 全体のネットワーク構築 z_output = encoder(input_img)#前半のネットワークの出力=中間層の出力のこと outputs = decoder(z_output)#前半のネットワークの出力を新たな入力として、後半のネットワークの出力を求める autoencoder = Model(input_img, outputs)#全体のネットワーク。次で最適化方法や損失関数を定義する。 autoencoder.compile(optimizer='Adam', loss='binary_crossentropy')#optimizerには色々あるが、デフォルト値ではadamがベストだった autoencoder.summary()
次に全体のネットワークを学習する。これにより前半のネットワークや後半のネットワークもフィッティングされる。普通の教師ありの場合には、x_train→y_trainやx_test→y_testなどのように教師データに適宜変更してください(今回はオートエンコーダーなので、入力データと教師データが同じになってます)。
autoencoder.fit(x_train, x_train, epochs=50, batch_size=256, shuffle=True, validation_data=(x_test, x_test))
学習が完了したので、あとは「decoder.predict(中間層に入力したいデータ)」などのコマンドにより、後半のネットワークの予測値を引っ張ってこれる。
確認のため、テストデータを前半のネットワーク(encoder)に入力して得られた中間層の出力値を、後半のネットワーク(decoder)に入力してみよう。当然、これは全体のネットワーク(autoencoder)にテストデータを放り込んだ値に一致するはずである。printでそれぞれの値を出力して一致するか確認してみよう。
encoded_x_test=encoder.predict(x_test) print(decoder.predict(encoded_x_test)) print(autoencoder.predict(x_test))