ChunPom’s diary

数学、物理、機械学習に関する話題。あと院試、資格、大学入試まで。

ニューラルネットの中間層から出力層までのネットワークを取り出す

kerasでニューラルネットの入力層から中間層までのネットワークを取り出すのはよくあるけど、中間層から最終層(出力層)までのネットワークを取り出してる例はあまり少ないので紹介する。例として、MNISTの数字画像のデータに対するオートエンコーダーのネットワークを考える。

まずは必要なライブラリのインポートとデータの取得&整形。訓練データとテストデータを定義する。

from keras.layers import Input, Dense
from keras.models import Model
from keras.datasets import mnist
import numpy as np

(x_train, y_train), (x_test, y_test) = mnist.load_data()
image_size = x_train.shape[1] # = 784
original_dim = image_size * image_size
x_train = np.reshape(x_train, [-1, original_dim])
x_test = np.reshape(x_test, [-1, original_dim])
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255


次にモデルを構築する。一気に入力層から最終層までのネットワークを実装せずに、入力層から中間層までのネットワーク(=前半のネットワーク)と、中間層から最終層までのネットワーク(=後半のネットワーク)を分けてモデル化するのがミソ。もたらん分け方を変えれば、任意の部分ネットワークを引っ張ってこれる。

# 前半のネットワーク構築
encoding_dim = 32#中間層のノード数。これが小さいほど特徴量が圧縮される。
input_img = Input(shape=(784,))
x1 = Dense(256, activation='relu')(input_img)  
x2 = Dense(64, activation='relu')(x1)  
encoded = Dense(encoding_dim, activation='relu')(x2) 
encoder = Model(input_img, encoded)
encoder.summary()

# 後半のネットワーク構築
input_hidden=Input(shape=(encoding_dim),)
x3 = Dense(64, activation='relu')(input_hidden)
x4 = Dense(256, activation='relu')(x3)  
decoded = Dense(784, activation='sigmoid')(x4) 
decoder=Model(input_hidden,decoded)
decoder.summary()

# 全体のネットワーク構築
z_output = encoder(input_img)#前半のネットワークの出力=中間層の出力のこと
outputs = decoder(z_output)#前半のネットワークの出力を新たな入力として、後半のネットワークの出力を求める

autoencoder = Model(input_img, outputs)#全体のネットワーク。次で最適化方法や損失関数を定義する。
autoencoder.compile(optimizer='Adam', loss='binary_crossentropy')#optimizerには色々あるが、デフォルト値ではadamがベストだった
autoencoder.summary()  


次に全体のネットワークを学習する。これにより前半のネットワークや後半のネットワークもフィッティングされる。普通の教師ありの場合には、x_train→y_trainやx_test→y_testなどのように教師データに適宜変更してください(今回はオートエンコーダーなので、入力データと教師データが同じになってます)。

autoencoder.fit(x_train, x_train,
                epochs=50,    
                batch_size=256,
                shuffle=True,
                validation_data=(x_test, x_test))


学習が完了したので、あとは「decoder.predict(中間層に入力したいデータ)」などのコマンドにより、後半のネットワークの予測値を引っ張ってこれる。
確認のため、テストデータを前半のネットワーク(encoder)に入力して得られた中間層の出力値を、後半のネットワーク(decoder)に入力してみよう。当然、これは全体のネットワーク(autoencoder)にテストデータを放り込んだ値に一致するはずである。printでそれぞれの値を出力して一致するか確認してみよう。

encoded_x_test=encoder.predict(x_test)
print(decoder.predict(encoded_x_test))

print(autoencoder.predict(x_test))