
Module 5 Lab: Customizing your neural network

import tensorflow as tf
from tensorflow import keras

print(tf.__version__)
print(keras.__version__)

# Splitting
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import pandas as pd
import random
import os
import numpy as np

# Freeze randomness so that experiments are reproducible.
RANDOM_SEED = 42

os.environ['PYTHONHASHSEED'] = str(RANDOM_SEED)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
os.environ['TF_DETERMINISTIC_OPS'] = '1'
tf.random.set_seed(RANDOM_SEED)
2.2.0
2.3.0-tf

from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import ReLU
from tensorflow.keras.layers import Add
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.layers import GlobalAvgPool2D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
!nvidia-smi
Mon May 18 12:33:57 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.82       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

Importing the dataset

(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar100.load_data()

print(X_train.shape, y_train.shape)
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
169009152/169001437 [==============================] - 5s 0us/step
(50000, 32, 32, 3) (50000, 1)

X_train = X_train.reshape(-1, 32, 32, 3).astype('float32')
X_test = X_test.reshape(-1, 32, 32, 3).astype('float32')


X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, random_state=RANDOM_SEED)

X_test = X_test/255
X_train = X_train/255
X_valid = X_valid/255

y_train_oh = tf.keras.utils.to_categorical(y_train, num_classes=100)
y_test_oh = tf.keras.utils.to_categorical(y_test, num_classes=100)
y_valid_oh = tf.keras.utils.to_categorical(y_valid, num_classes=100)
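
As a quick check (a sketch added here, not in the original lab): with train_test_split's default 25% validation split, each one-hot matrix should have one column per CIFAR-100 class.

print(y_train_oh.shape, y_valid_oh.shape, y_test_oh.shape)
# expected: (37500, 100) (12500, 100) (10000, 100)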
AUTOTUNE = tf.data.experimental.AUTOTUNE

def train_preprocess(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)

    #image = tf.image.random_brightness(image, max_delta=32.0 / 255.0)
    #image = tf.image.random_saturation(image, lower=0.5, upper=1.5)

    # Make sure the image is still in [0, 1]
    image = tf.clip_by_value(image, 0.0, 1.0)

    return image, label

def create_train_dataset(features, labels, batch=64, repet=1, prefetch=1):
    dataset = tf.data.Dataset.from_tensor_slices((features,labels))
    dataset = dataset.shuffle(len(features), seed=RANDOM_SEED)
    dataset = dataset.repeat(repet)
    dataset = dataset.map(train_preprocess, num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(batch)
    dataset = dataset.prefetch(prefetch)
    return dataset

def create_test_dataset(features, labels, batch=64, repet=1, prefetch=1):
    dataset = tf.data.Dataset.from_tensor_slices((features,labels))
    dataset = dataset.shuffle(len(features), seed=RANDOM_SEED)
    dataset = dataset.repeat(repet)
    dataset = dataset.batch(batch)
    dataset = dataset.prefetch(prefetch)
    return dataset

ds_train = create_train_dataset(X_train, y_train_oh)
ds_val = create_test_dataset(X_valid, y_valid_oh)
ds_test = create_test_dataset(X_test, y_test_oh)
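
Before moving on, a small sanity check (our own sketch): pull one batch from the training pipeline and confirm the shapes and the [0, 1] value range enforced by train_preprocess.

images, labels = next(iter(ds_train))
print(images.shape, labels.shape)  # (64, 32, 32, 3) (64, 100)
print(float(tf.reduce_min(images)), float(tf.reduce_max(images)))  # both within [0.0, 1.0]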

Modern architectures

In Convolutional Nets, there is no such thing as "fully-connected layers". There are only convolution layers with \(1 \times 1\) convolution kernels and a full connection table.

It's a too-rarely-understood fact that ConvNets don't need to have a fixed-size input. You can train them on inputs that happen to produce a single output vector (with no spatial extent), and then apply them to larger images. Instead of a single output vector, you then get a spatial map of output vectors. Each vector sees input windows at different locations on the input.

In that scenario, the "fully connected layers" really act as \(1 \times 1\) convolutions.

Yann LeCun
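
To make the quote concrete, here is a minimal sketch (our own illustration, not part of the original lab): a Dense layer and a \(1 \times 1\) convolution carrying the same weights agree exactly on a \(1 \times 1\) feature map, and the same convolution then extends to larger inputs, producing one output vector per spatial location.

x = tf.random.normal((1, 1, 1, 64))          # one spatial position, 64 channels
dense = Dense(10)
conv = Conv2D(filters=10, kernel_size=1)

y_dense = dense(tf.reshape(x, (1, 64)))      # builds the Dense layer; shape (1, 10)
conv.build(x.shape)
conv.set_weights([dense.kernel.numpy().reshape(1, 1, 64, 10),  # (in, out) -> (1, 1, in, out)
                  dense.bias.numpy()])

y_conv = conv(x)                             # shape (1, 1, 1, 10)
print(np.allclose(y_dense.numpy(), tf.reshape(y_conv, (1, 10)).numpy()))  # True

print(conv(tf.random.normal((1, 8, 8, 64))).shape)  # (1, 8, 8, 10): a spatial map of output vectors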

ResNet

Idea

  • Driven by the significance of depth, a question arises: Is learning better networks as easy as stacking more layers? An obstacle to answering this question was the notorious problem of vanishing/exploding gradients [...]. This problem, however, has been largely addressed by normalized initialization and intermediate normalization layers. (i.e. BatchNorm and weight initialization)

  • When deeper networks are able to start converging, a degradation problem has been exposed: with the network depth increasing, accuracy gets saturated (...) and then degrades rapidly. Unexpectedly, such degradation is not caused by overfitting, and adding more layers to a suitably deep model leads to higher training error.

  • We show that:

  • Our extremely deep residual nets are easy to optimize, but the counterpart "plain" nets (that simply stack layers) exhibit higher training error when the depth increases.
  • Our deep residual nets can easily enjoy accuracy gains from greatly increased depth, producing results substantially better than previous networks.

  • Our 152-layer residual net is the deepest network ever presented on ImageNet (2015), while still having lower complexity than VGG nets.

  • The degradation problem suggests that the solvers (i.e. weight optimization) might have difficulties in approximating identity mappings by multiple nonlinear layers.
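
This is exactly what residual learning addresses: instead of asking a stack of layers to fit a desired mapping \(H(x)\) directly, a residual block fits \(F(x) := H(x) - x\) and outputs

\[y = F(x, \{W_i\}) + x\]

so that, if the identity mapping were optimal, the solver only has to drive \(F\) towards zero rather than approximate the identity through several nonlinear layers.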

Defining the basic building blocks

  • We adopt batch normalization (BN) right after each convolution and before activation.
def conv_batchnorm_relu(x, filters, kernel_size, strides):
  x = Conv2D(filters=filters,
             kernel_size=kernel_size,
             strides=strides,
             padding='same',
             kernel_initializer="he_normal",
             use_bias=False)(x)
  x = BatchNormalization()(x)
  x = ReLU()(x)
  return x
Identity block
  • The three layers are \(1 \times 1\), \(3 \times 3\), and \(1 \times 1\) convolutions, where the \(1 \times 1\) layers are responsible for reducing and then increasing (restoring) dimensions, leaving the \(3 \times 3\) layer a bottleneck with smaller input/output dimensions.

  • \(50\)-layer ResNet: We replace each \(2\)-layer block in the \(34\)-layer net with this \(3\)-layer bottleneck block, resulting in a \(50\)-layer ResNet (Table 1). We use option B for increasing dimensions (i.e. projection blocks).

def identity_block(tensor, filters):
  x = conv_batchnorm_relu(tensor,
                          filters=filters,
                          kernel_size=1,
                          strides=1)
  #print(x.shape)
  x = conv_batchnorm_relu(x,
                          filters=filters,
                          kernel_size=3,
                          strides=1)
  #print(x.shape)
  x = Conv2D(filters=4*filters,
             kernel_size=1, 
             strides=1,
             kernel_initializer="he_normal")(x)  # notice: filters=4*filters
  #print(x.shape)
  x = BatchNormalization()(x)

  x = Add()([x, tensor])
  x = ReLU()(x)
  return x
Projection block
  • The projection shortcut in Eqn.(2) is used to match dimensions (done by \(1 \times 1\) convolutions). For both options, when the shortcuts go across feature maps of two sizes, they are performed with a stride of 2.

  • [...] Projection shortcuts are used for increasing dimensions, and other shortcuts are identity;

def projection_block(tensor, filters, strides):
  # left stream
  x = conv_batchnorm_relu(tensor, 
                          filters=filters, 
                          kernel_size=1,
                          strides=strides)
  x = conv_batchnorm_relu(x,
                          filters=filters,
                          kernel_size=3,
                          strides=1)
  x = Conv2D(filters=4*filters,
             kernel_size=1,
             strides=1,
             kernel_initializer="he_normal")(x)  # notice: filters=4*filters
  x = BatchNormalization()(x)

  # right stream
  shortcut = Conv2D(filters=4*filters,
                    kernel_size=1,
                    strides=strides,
                    kernel_initializer="he_normal")(tensor)  # notice: filters=4*filters
  shortcut = BatchNormalization()(shortcut)

  x = Add()([x, shortcut])
  x = ReLU()(x)
  return x
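A quick shape check (a sketch added here, not in the original lab): the identity block requires its input to already carry \(4 \times\) filters channels, otherwise the Add fails, while the projection block adapts both streams, halving the spatial size when strides=2 and quadrupling the channel count.

t = Input(shape=(8, 8, 256))
print(identity_block(t, filters=64).shape)                # (None, 8, 8, 256)
print(projection_block(t, filters=128, strides=2).shape)  # (None, 4, 4, 512)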
ResNet block
  • Downsampling is performed by conv3_1, conv4_1, and conv5_1 with a stride of 2.
def resnet_block(x, filters, reps, strides):
  x = projection_block(x, filters=filters, strides=strides)
  for _ in range(reps-1):
    x = identity_block(x, filters=filters)
  return x
input = Input(shape=(32, 32, 3))

x = conv_batchnorm_relu(input, filters=64, kernel_size=7, strides=2)  # [3]: 7x7, 64, strides 2
x = MaxPool2D(pool_size=3, strides=2, padding='same')(x)  # [3]: 3x3 max pool, strides 2

x = resnet_block(x, filters=64, reps=3, strides=1)
x = resnet_block(x, filters=128, reps=4, strides=2)  # strides=2 ([2]: conv3_1)
x = resnet_block(x, filters=256, reps=6, strides=2)  # strides=2 ([2]: conv4_1)
x = resnet_block(x, filters=512, reps=3, strides=2)  # strides=2 ([2]: conv5_1)

x = GlobalAvgPool2D()(x)  # [3]: average pool; no pool size is specified, so we use global average pooling
x = Dense(100)(x)
output = Activation('softmax')(x)  # [3]: 1000-d fc, softmax (here 100-d, for the 100 CIFAR-100 classes)

from tensorflow.keras import Model

model = Model(input, output)
model.summary()
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
conv2d_167 (Conv2D)             (None, 16, 16, 64)   9408        input_4[0][0]                    
__________________________________________________________________________________________________
batch_normalization_166 (BatchN (None, 16, 16, 64)   256         conv2d_167[0][0]                 
__________________________________________________________________________________________________
re_lu_152 (ReLU)                (None, 16, 16, 64)   0           batch_normalization_166[0][0]    
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 8, 8, 64)     0           re_lu_152[0][0]                  
__________________________________________________________________________________________________
conv2d_168 (Conv2D)             (None, 8, 8, 64)     4096        max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
batch_normalization_167 (BatchN (None, 8, 8, 64)     256         conv2d_168[0][0]                 
__________________________________________________________________________________________________
re_lu_153 (ReLU)                (None, 8, 8, 64)     0           batch_normalization_167[0][0]    
__________________________________________________________________________________________________
conv2d_169 (Conv2D)             (None, 8, 8, 64)     36864       re_lu_153[0][0]                  
__________________________________________________________________________________________________
batch_normalization_168 (BatchN (None, 8, 8, 64)     256         conv2d_169[0][0]                 
__________________________________________________________________________________________________
re_lu_154 (ReLU)                (None, 8, 8, 64)     0           batch_normalization_168[0][0]    
__________________________________________________________________________________________________
conv2d_170 (Conv2D)             (None, 8, 8, 256)    16640       re_lu_154[0][0]                  
__________________________________________________________________________________________________
conv2d_171 (Conv2D)             (None, 8, 8, 256)    16640       max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
batch_normalization_169 (BatchN (None, 8, 8, 256)    1024        conv2d_170[0][0]                 
__________________________________________________________________________________________________
batch_normalization_170 (BatchN (None, 8, 8, 256)    1024        conv2d_171[0][0]                 
__________________________________________________________________________________________________
add_50 (Add)                    (None, 8, 8, 256)    0           batch_normalization_169[0][0]    
                                                                 batch_normalization_170[0][0]    
__________________________________________________________________________________________________
re_lu_155 (ReLU)                (None, 8, 8, 256)    0           add_50[0][0]                     
__________________________________________________________________________________________________
conv2d_172 (Conv2D)             (None, 8, 8, 64)     16384       re_lu_155[0][0]                  
__________________________________________________________________________________________________
batch_normalization_171 (BatchN (None, 8, 8, 64)     256         conv2d_172[0][0]                 
__________________________________________________________________________________________________
re_lu_156 (ReLU)                (None, 8, 8, 64)     0           batch_normalization_171[0][0]    
__________________________________________________________________________________________________
conv2d_173 (Conv2D)             (None, 8, 8, 64)     36864       re_lu_156[0][0]                  
__________________________________________________________________________________________________
batch_normalization_172 (BatchN (None, 8, 8, 64)     256         conv2d_173[0][0]                 
__________________________________________________________________________________________________
re_lu_157 (ReLU)                (None, 8, 8, 64)     0           batch_normalization_172[0][0]    
__________________________________________________________________________________________________
conv2d_174 (Conv2D)             (None, 8, 8, 256)    16640       re_lu_157[0][0]                  
__________________________________________________________________________________________________
batch_normalization_173 (BatchN (None, 8, 8, 256)    1024        conv2d_174[0][0]                 
__________________________________________________________________________________________________
add_51 (Add)                    (None, 8, 8, 256)    0           batch_normalization_173[0][0]    
                                                                 re_lu_155[0][0]                  
__________________________________________________________________________________________________
re_lu_158 (ReLU)                (None, 8, 8, 256)    0           add_51[0][0]                     
__________________________________________________________________________________________________
conv2d_175 (Conv2D)             (None, 8, 8, 64)     16384       re_lu_158[0][0]                  
__________________________________________________________________________________________________
batch_normalization_174 (BatchN (None, 8, 8, 64)     256         conv2d_175[0][0]                 
__________________________________________________________________________________________________
re_lu_159 (ReLU)                (None, 8, 8, 64)     0           batch_normalization_174[0][0]    
__________________________________________________________________________________________________
conv2d_176 (Conv2D)             (None, 8, 8, 64)     36864       re_lu_159[0][0]                  
__________________________________________________________________________________________________
batch_normalization_175 (BatchN (None, 8, 8, 64)     256         conv2d_176[0][0]                 
__________________________________________________________________________________________________
re_lu_160 (ReLU)                (None, 8, 8, 64)     0           batch_normalization_175[0][0]    
__________________________________________________________________________________________________
conv2d_177 (Conv2D)             (None, 8, 8, 256)    16640       re_lu_160[0][0]                  
__________________________________________________________________________________________________
batch_normalization_176 (BatchN (None, 8, 8, 256)    1024        conv2d_177[0][0]                 
__________________________________________________________________________________________________
add_52 (Add)                    (None, 8, 8, 256)    0           batch_normalization_176[0][0]    
                                                                 re_lu_158[0][0]                  
__________________________________________________________________________________________________
re_lu_161 (ReLU)                (None, 8, 8, 256)    0           add_52[0][0]                     
__________________________________________________________________________________________________
conv2d_178 (Conv2D)             (None, 4, 4, 128)    32768       re_lu_161[0][0]                  
__________________________________________________________________________________________________
batch_normalization_177 (BatchN (None, 4, 4, 128)    512         conv2d_178[0][0]                 
__________________________________________________________________________________________________
re_lu_162 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_177[0][0]    
__________________________________________________________________________________________________
conv2d_179 (Conv2D)             (None, 4, 4, 128)    147456      re_lu_162[0][0]                  
__________________________________________________________________________________________________
batch_normalization_178 (BatchN (None, 4, 4, 128)    512         conv2d_179[0][0]                 
__________________________________________________________________________________________________
re_lu_163 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_178[0][0]    
__________________________________________________________________________________________________
conv2d_180 (Conv2D)             (None, 4, 4, 512)    66048       re_lu_163[0][0]                  
__________________________________________________________________________________________________
conv2d_181 (Conv2D)             (None, 4, 4, 512)    131584      re_lu_161[0][0]                  
__________________________________________________________________________________________________
batch_normalization_179 (BatchN (None, 4, 4, 512)    2048        conv2d_180[0][0]                 
__________________________________________________________________________________________________
batch_normalization_180 (BatchN (None, 4, 4, 512)    2048        conv2d_181[0][0]                 
__________________________________________________________________________________________________
add_53 (Add)                    (None, 4, 4, 512)    0           batch_normalization_179[0][0]    
                                                                 batch_normalization_180[0][0]    
__________________________________________________________________________________________________
re_lu_164 (ReLU)                (None, 4, 4, 512)    0           add_53[0][0]                     
__________________________________________________________________________________________________
conv2d_182 (Conv2D)             (None, 4, 4, 128)    65536       re_lu_164[0][0]                  
__________________________________________________________________________________________________
batch_normalization_181 (BatchN (None, 4, 4, 128)    512         conv2d_182[0][0]                 
__________________________________________________________________________________________________
re_lu_165 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_181[0][0]    
__________________________________________________________________________________________________
conv2d_183 (Conv2D)             (None, 4, 4, 128)    147456      re_lu_165[0][0]                  
__________________________________________________________________________________________________
batch_normalization_182 (BatchN (None, 4, 4, 128)    512         conv2d_183[0][0]                 
__________________________________________________________________________________________________
re_lu_166 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_182[0][0]    
__________________________________________________________________________________________________
conv2d_184 (Conv2D)             (None, 4, 4, 512)    66048       re_lu_166[0][0]                  
__________________________________________________________________________________________________
batch_normalization_183 (BatchN (None, 4, 4, 512)    2048        conv2d_184[0][0]                 
__________________________________________________________________________________________________
add_54 (Add)                    (None, 4, 4, 512)    0           batch_normalization_183[0][0]    
                                                                 re_lu_164[0][0]                  
__________________________________________________________________________________________________
re_lu_167 (ReLU)                (None, 4, 4, 512)    0           add_54[0][0]                     
__________________________________________________________________________________________________
conv2d_185 (Conv2D)             (None, 4, 4, 128)    65536       re_lu_167[0][0]                  
__________________________________________________________________________________________________
batch_normalization_184 (BatchN (None, 4, 4, 128)    512         conv2d_185[0][0]                 
__________________________________________________________________________________________________
re_lu_168 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_184[0][0]    
__________________________________________________________________________________________________
conv2d_186 (Conv2D)             (None, 4, 4, 128)    147456      re_lu_168[0][0]                  
__________________________________________________________________________________________________
batch_normalization_185 (BatchN (None, 4, 4, 128)    512         conv2d_186[0][0]                 
__________________________________________________________________________________________________
re_lu_169 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_185[0][0]    
__________________________________________________________________________________________________
conv2d_187 (Conv2D)             (None, 4, 4, 512)    66048       re_lu_169[0][0]                  
__________________________________________________________________________________________________
batch_normalization_186 (BatchN (None, 4, 4, 512)    2048        conv2d_187[0][0]                 
__________________________________________________________________________________________________
add_55 (Add)                    (None, 4, 4, 512)    0           batch_normalization_186[0][0]    
                                                                 re_lu_167[0][0]                  
__________________________________________________________________________________________________
re_lu_170 (ReLU)                (None, 4, 4, 512)    0           add_55[0][0]                     
__________________________________________________________________________________________________
conv2d_188 (Conv2D)             (None, 4, 4, 128)    65536       re_lu_170[0][0]                  
__________________________________________________________________________________________________
batch_normalization_187 (BatchN (None, 4, 4, 128)    512         conv2d_188[0][0]                 
__________________________________________________________________________________________________
re_lu_171 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_187[0][0]    
__________________________________________________________________________________________________
conv2d_189 (Conv2D)             (None, 4, 4, 128)    147456      re_lu_171[0][0]                  
__________________________________________________________________________________________________
batch_normalization_188 (BatchN (None, 4, 4, 128)    512         conv2d_189[0][0]                 
__________________________________________________________________________________________________
re_lu_172 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_188[0][0]    
__________________________________________________________________________________________________
conv2d_190 (Conv2D)             (None, 4, 4, 512)    66048       re_lu_172[0][0]                  
__________________________________________________________________________________________________
batch_normalization_189 (BatchN (None, 4, 4, 512)    2048        conv2d_190[0][0]                 
__________________________________________________________________________________________________
add_56 (Add)                    (None, 4, 4, 512)    0           batch_normalization_189[0][0]    
                                                                 re_lu_170[0][0]                  
__________________________________________________________________________________________________
re_lu_173 (ReLU)                (None, 4, 4, 512)    0           add_56[0][0]                     
__________________________________________________________________________________________________
conv2d_191 (Conv2D)             (None, 2, 2, 256)    131072      re_lu_173[0][0]                  
__________________________________________________________________________________________________
batch_normalization_190 (BatchN (None, 2, 2, 256)    1024        conv2d_191[0][0]                 
__________________________________________________________________________________________________
re_lu_174 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_190[0][0]    
__________________________________________________________________________________________________
conv2d_192 (Conv2D)             (None, 2, 2, 256)    589824      re_lu_174[0][0]                  
__________________________________________________________________________________________________
batch_normalization_191 (BatchN (None, 2, 2, 256)    1024        conv2d_192[0][0]                 
__________________________________________________________________________________________________
re_lu_175 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_191[0][0]    
__________________________________________________________________________________________________
conv2d_193 (Conv2D)             (None, 2, 2, 1024)   263168      re_lu_175[0][0]                  
__________________________________________________________________________________________________
conv2d_194 (Conv2D)             (None, 2, 2, 1024)   525312      re_lu_173[0][0]                  
__________________________________________________________________________________________________
batch_normalization_192 (BatchN (None, 2, 2, 1024)   4096        conv2d_193[0][0]                 
__________________________________________________________________________________________________
batch_normalization_193 (BatchN (None, 2, 2, 1024)   4096        conv2d_194[0][0]                 
__________________________________________________________________________________________________
add_57 (Add)                    (None, 2, 2, 1024)   0           batch_normalization_192[0][0]    
                                                                 batch_normalization_193[0][0]    
__________________________________________________________________________________________________
re_lu_176 (ReLU)                (None, 2, 2, 1024)   0           add_57[0][0]                     
__________________________________________________________________________________________________
conv2d_195 (Conv2D)             (None, 2, 2, 256)    262144      re_lu_176[0][0]                  
__________________________________________________________________________________________________
batch_normalization_194 (BatchN (None, 2, 2, 256)    1024        conv2d_195[0][0]                 
__________________________________________________________________________________________________
re_lu_177 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_194[0][0]    
__________________________________________________________________________________________________
conv2d_196 (Conv2D)             (None, 2, 2, 256)    589824      re_lu_177[0][0]                  
__________________________________________________________________________________________________
batch_normalization_195 (BatchN (None, 2, 2, 256)    1024        conv2d_196[0][0]                 
__________________________________________________________________________________________________
re_lu_178 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_195[0][0]    
__________________________________________________________________________________________________
conv2d_197 (Conv2D)             (None, 2, 2, 1024)   263168      re_lu_178[0][0]                  
__________________________________________________________________________________________________
batch_normalization_196 (BatchN (None, 2, 2, 1024)   4096        conv2d_197[0][0]                 
__________________________________________________________________________________________________
add_58 (Add)                    (None, 2, 2, 1024)   0           batch_normalization_196[0][0]    
                                                                 re_lu_176[0][0]                  
__________________________________________________________________________________________________
re_lu_179 (ReLU)                (None, 2, 2, 1024)   0           add_58[0][0]                     
__________________________________________________________________________________________________
conv2d_198 (Conv2D)             (None, 2, 2, 256)    262144      re_lu_179[0][0]                  
__________________________________________________________________________________________________
batch_normalization_197 (BatchN (None, 2, 2, 256)    1024        conv2d_198[0][0]                 
__________________________________________________________________________________________________
re_lu_180 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_197[0][0]    
__________________________________________________________________________________________________
conv2d_199 (Conv2D)             (None, 2, 2, 256)    589824      re_lu_180[0][0]                  
__________________________________________________________________________________________________
batch_normalization_198 (BatchN (None, 2, 2, 256)    1024        conv2d_199[0][0]                 
__________________________________________________________________________________________________
re_lu_181 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_198[0][0]    
__________________________________________________________________________________________________
conv2d_200 (Conv2D)             (None, 2, 2, 1024)   263168      re_lu_181[0][0]                  
__________________________________________________________________________________________________
batch_normalization_199 (BatchN (None, 2, 2, 1024)   4096        conv2d_200[0][0]                 
__________________________________________________________________________________________________
add_59 (Add)                    (None, 2, 2, 1024)   0           batch_normalization_199[0][0]    
                                                                 re_lu_179[0][0]                  
__________________________________________________________________________________________________
re_lu_182 (ReLU)                (None, 2, 2, 1024)   0           add_59[0][0]                     
__________________________________________________________________________________________________
conv2d_201 (Conv2D)             (None, 2, 2, 256)    262144      re_lu_182[0][0]                  
__________________________________________________________________________________________________
batch_normalization_200 (BatchN (None, 2, 2, 256)    1024        conv2d_201[0][0]                 
__________________________________________________________________________________________________
re_lu_183 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_200[0][0]    
__________________________________________________________________________________________________
conv2d_202 (Conv2D)             (None, 2, 2, 256)    589824      re_lu_183[0][0]                  
__________________________________________________________________________________________________
batch_normalization_201 (BatchN (None, 2, 2, 256)    1024        conv2d_202[0][0]                 
__________________________________________________________________________________________________
re_lu_184 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_201[0][0]    
__________________________________________________________________________________________________
conv2d_203 (Conv2D)             (None, 2, 2, 1024)   263168      re_lu_184[0][0]                  
__________________________________________________________________________________________________
batch_normalization_202 (BatchN (None, 2, 2, 1024)   4096        conv2d_203[0][0]                 
__________________________________________________________________________________________________
add_60 (Add)                    (None, 2, 2, 1024)   0           batch_normalization_202[0][0]    
                                                                 re_lu_182[0][0]                  
__________________________________________________________________________________________________
re_lu_185 (ReLU)                (None, 2, 2, 1024)   0           add_60[0][0]                     
__________________________________________________________________________________________________
conv2d_204 (Conv2D)             (None, 2, 2, 256)    262144      re_lu_185[0][0]                  
__________________________________________________________________________________________________
batch_normalization_203 (BatchN (None, 2, 2, 256)    1024        conv2d_204[0][0]                 
__________________________________________________________________________________________________
re_lu_186 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_203[0][0]    
__________________________________________________________________________________________________
conv2d_205 (Conv2D)             (None, 2, 2, 256)    589824      re_lu_186[0][0]                  
__________________________________________________________________________________________________
batch_normalization_204 (BatchN (None, 2, 2, 256)    1024        conv2d_205[0][0]                 
__________________________________________________________________________________________________
re_lu_187 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_204[0][0]    
__________________________________________________________________________________________________
conv2d_206 (Conv2D)             (None, 2, 2, 1024)   263168      re_lu_187[0][0]                  
__________________________________________________________________________________________________
batch_normalization_205 (BatchN (None, 2, 2, 1024)   4096        conv2d_206[0][0]                 
__________________________________________________________________________________________________
add_61 (Add)                    (None, 2, 2, 1024)   0           batch_normalization_205[0][0]    
                                                                 re_lu_185[0][0]                  
__________________________________________________________________________________________________
re_lu_188 (ReLU)                (None, 2, 2, 1024)   0           add_61[0][0]                     
__________________________________________________________________________________________________
conv2d_207 (Conv2D)             (None, 2, 2, 256)    262144      re_lu_188[0][0]                  
__________________________________________________________________________________________________
batch_normalization_206 (BatchN (None, 2, 2, 256)    1024        conv2d_207[0][0]                 
__________________________________________________________________________________________________
re_lu_189 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_206[0][0]    
__________________________________________________________________________________________________
conv2d_208 (Conv2D)             (None, 2, 2, 256)    589824      re_lu_189[0][0]                  
__________________________________________________________________________________________________
batch_normalization_207 (BatchN (None, 2, 2, 256)    1024        conv2d_208[0][0]                 
__________________________________________________________________________________________________
re_lu_190 (ReLU)                (None, 2, 2, 256)    0           batch_normalization_207[0][0]    
__________________________________________________________________________________________________
conv2d_209 (Conv2D)             (None, 2, 2, 1024)   263168      re_lu_190[0][0]                  
__________________________________________________________________________________________________
batch_normalization_208 (BatchN (None, 2, 2, 1024)   4096        conv2d_209[0][0]                 
__________________________________________________________________________________________________
add_62 (Add)                    (None, 2, 2, 1024)   0           batch_normalization_208[0][0]    
                                                                 re_lu_188[0][0]                  
__________________________________________________________________________________________________
re_lu_191 (ReLU)                (None, 2, 2, 1024)   0           add_62[0][0]                     
__________________________________________________________________________________________________
conv2d_210 (Conv2D)             (None, 1, 1, 512)    524288      re_lu_191[0][0]                  
__________________________________________________________________________________________________
batch_normalization_209 (BatchN (None, 1, 1, 512)    2048        conv2d_210[0][0]                 
__________________________________________________________________________________________________
re_lu_192 (ReLU)                (None, 1, 1, 512)    0           batch_normalization_209[0][0]    
__________________________________________________________________________________________________
conv2d_211 (Conv2D)             (None, 1, 1, 512)    2359296     re_lu_192[0][0]                  
__________________________________________________________________________________________________
batch_normalization_210 (BatchN (None, 1, 1, 512)    2048        conv2d_211[0][0]                 
__________________________________________________________________________________________________
re_lu_193 (ReLU)                (None, 1, 1, 512)    0           batch_normalization_210[0][0]    
__________________________________________________________________________________________________
conv2d_212 (Conv2D)             (None, 1, 1, 2048)   1050624     re_lu_193[0][0]                  
__________________________________________________________________________________________________
conv2d_213 (Conv2D)             (None, 1, 1, 2048)   2099200     re_lu_191[0][0]                  
__________________________________________________________________________________________________
batch_normalization_211 (BatchN (None, 1, 1, 2048)   8192        conv2d_212[0][0]                 
__________________________________________________________________________________________________
batch_normalization_212 (BatchN (None, 1, 1, 2048)   8192        conv2d_213[0][0]                 
__________________________________________________________________________________________________
add_63 (Add)                    (None, 1, 1, 2048)   0           batch_normalization_211[0][0]    
                                                                 batch_normalization_212[0][0]    
__________________________________________________________________________________________________
re_lu_194 (ReLU)                (None, 1, 1, 2048)   0           add_63[0][0]                     
__________________________________________________________________________________________________
conv2d_214 (Conv2D)             (None, 1, 1, 512)    1048576     re_lu_194[0][0]                  
__________________________________________________________________________________________________
batch_normalization_213 (BatchN (None, 1, 1, 512)    2048        conv2d_214[0][0]                 
__________________________________________________________________________________________________
re_lu_195 (ReLU)                (None, 1, 1, 512)    0           batch_normalization_213[0][0]    
__________________________________________________________________________________________________
conv2d_215 (Conv2D)             (None, 1, 1, 512)    2359296     re_lu_195[0][0]                  
__________________________________________________________________________________________________
batch_normalization_214 (BatchN (None, 1, 1, 512)    2048        conv2d_215[0][0]                 
__________________________________________________________________________________________________
re_lu_196 (ReLU)                (None, 1, 1, 512)    0           batch_normalization_214[0][0]    
__________________________________________________________________________________________________
conv2d_216 (Conv2D)             (None, 1, 1, 2048)   1050624     re_lu_196[0][0]                  
__________________________________________________________________________________________________
batch_normalization_215 (BatchN (None, 1, 1, 2048)   8192        conv2d_216[0][0]                 
__________________________________________________________________________________________________
add_64 (Add)                    (None, 1, 1, 2048)   0           batch_normalization_215[0][0]    
                                                                 re_lu_194[0][0]                  
__________________________________________________________________________________________________
re_lu_197 (ReLU)                (None, 1, 1, 2048)   0           add_64[0][0]                     
__________________________________________________________________________________________________
conv2d_217 (Conv2D)             (None, 1, 1, 512)    1048576     re_lu_197[0][0]                  
__________________________________________________________________________________________________
batch_normalization_216 (BatchN (None, 1, 1, 512)    2048        conv2d_217[0][0]                 
__________________________________________________________________________________________________
re_lu_198 (ReLU)                (None, 1, 1, 512)    0           batch_normalization_216[0][0]    
__________________________________________________________________________________________________
conv2d_218 (Conv2D)             (None, 1, 1, 512)    2359296     re_lu_198[0][0]                  
__________________________________________________________________________________________________
batch_normalization_217 (BatchN (None, 1, 1, 512)    2048        conv2d_218[0][0]                 
__________________________________________________________________________________________________
re_lu_199 (ReLU)                (None, 1, 1, 512)    0           batch_normalization_217[0][0]    
__________________________________________________________________________________________________
conv2d_219 (Conv2D)             (None, 1, 1, 2048)   1050624     re_lu_199[0][0]                  
__________________________________________________________________________________________________
batch_normalization_218 (BatchN (None, 1, 1, 2048)   8192        conv2d_219[0][0]                 
__________________________________________________________________________________________________
add_65 (Add)                    (None, 1, 1, 2048)   0           batch_normalization_218[0][0]    
                                                                 re_lu_197[0][0]                  
__________________________________________________________________________________________________
re_lu_200 (ReLU)                (None, 1, 1, 2048)   0           add_65[0][0]                     
__________________________________________________________________________________________________
global_average_pooling2d_3 (Glo (None, 2048)         0           re_lu_200[0][0]                  
__________________________________________________________________________________________________
dense (Dense)                   (None, 10)           20490       global_average_pooling2d_3[0][0] 
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 10)           0           dense[0][0]                      
==================================================================================================
Total params: 23,600,586
Trainable params: 23,547,466
Non-trainable params: 53,120
__________________________________________________________________________________________________

model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              metrics=['accuracy'])
import time
start = time.time()
history = model.fit(ds_train,
                    epochs=20,
                    validation_data=ds_val)

print(f"It took {time.time() - start} seconds")
Epoch 1/20
1172/1172 [==============================] - 57s 49ms/step - loss: 0.9570 - accuracy: 0.6647 - val_loss: 1.0660 - val_accuracy: 0.6246
Epoch 2/20
1172/1172 [==============================] - 57s 48ms/step - loss: 0.9468 - accuracy: 0.6715 - val_loss: 0.9648 - val_accuracy: 0.6563
Epoch 3/20
1172/1172 [==============================] - 57s 48ms/step - loss: 0.9169 - accuracy: 0.6779 - val_loss: 0.9009 - val_accuracy: 0.6806
Epoch 4/20
1172/1172 [==============================] - 57s 48ms/step - loss: 0.8700 - accuracy: 0.6936 - val_loss: 0.9810 - val_accuracy: 0.6605
Epoch 5/20
1172/1172 [==============================] - 58s 49ms/step - loss: 0.8400 - accuracy: 0.7034 - val_loss: 1.0241 - val_accuracy: 0.6444
Epoch 6/20
1172/1172 [==============================] - 57s 49ms/step - loss: 0.8027 - accuracy: 0.7189 - val_loss: 0.9003 - val_accuracy: 0.6813
Epoch 7/20
1172/1172 [==============================] - 58s 49ms/step - loss: 0.7728 - accuracy: 0.7295 - val_loss: 1.0182 - val_accuracy: 0.6477
Epoch 8/20
1172/1172 [==============================] - 57s 49ms/step - loss: 0.7569 - accuracy: 0.7333 - val_loss: 1.1326 - val_accuracy: 0.6317
Epoch 9/20
1172/1172 [==============================] - 57s 49ms/step - loss: 0.7486 - accuracy: 0.7394 - val_loss: 0.9695 - val_accuracy: 0.6631
Epoch 10/20
1172/1172 [==============================] - 57s 49ms/step - loss: 0.7180 - accuracy: 0.7495 - val_loss: 0.8269 - val_accuracy: 0.7190
Epoch 11/20
1172/1172 [==============================] - 57s 49ms/step - loss: 0.6806 - accuracy: 0.7623 - val_loss: 0.8275 - val_accuracy: 0.7151
Epoch 12/20
1172/1172 [==============================] - 57s 49ms/step - loss: 0.6671 - accuracy: 0.7681 - val_loss: 0.8477 - val_accuracy: 0.7069
Epoch 13/20
1172/1172 [==============================] - 57s 49ms/step - loss: 0.6447 - accuracy: 0.7735 - val_loss: 0.8543 - val_accuracy: 0.7118
Epoch 14/20
1172/1172 [==============================] - 57s 48ms/step - loss: 0.6236 - accuracy: 0.7800 - val_loss: 1.1812 - val_accuracy: 0.6317
Epoch 15/20
1172/1172 [==============================] - 57s 48ms/step - loss: 0.6177 - accuracy: 0.7875 - val_loss: 1.0423 - val_accuracy: 0.6597
Epoch 16/20
1172/1172 [==============================] - 56s 48ms/step - loss: 0.6007 - accuracy: 0.7906 - val_loss: 0.8287 - val_accuracy: 0.7252
Epoch 17/20
1172/1172 [==============================] - 56s 48ms/step - loss: 0.5753 - accuracy: 0.7982 - val_loss: 0.7585 - val_accuracy: 0.7417
Epoch 18/20
1172/1172 [==============================] - 56s 48ms/step - loss: 0.5758 - accuracy: 0.7979 - val_loss: 0.7436 - val_accuracy: 0.7445
Epoch 19/20
1172/1172 [==============================] - 57s 48ms/step - loss: 0.5455 - accuracy: 0.8107 - val_loss: 0.7485 - val_accuracy: 0.7460
Epoch 20/20
1172/1172 [==============================] - 57s 48ms/step - loss: 0.5747 - accuracy: 0.7996 - val_loss: 0.7514 - val_accuracy: 0.7444
It took 1148.7555301189423 seconds

import pandas as pd
import matplotlib.pyplot as plt

pd.DataFrame(history.history).plot(figsize=(12,8))
plt.grid(True)
plt.gca().set_ylim(0,2)
plt.show()

DenseNet

Idée

  • As CNNs become increasingly deep, a new research problem emerges: as information about the input or gradient passes through many layers, it can vanish and "wash out" by the time it reaches the end (or beginning) of the network.

  • To ensure maximum information flow between layers in the network, we connect all layers (with matching feature-map sizes) directly with each other.

  • Crucially, in contrast to ResNets, we never combine features through summation before they are passed into a layer; instead, we combine features by concatenating them.

  • The final classifier makes a decision based on all feature maps in the network.

Defining the basic building blocks

  • To facilitate down-sampling in our architecture we divide the network into multiple densely connected dense blocks.

  • We refer to layers between blocks as transition layers.

\(\implies\) The model is therefore built around an architecture that alternates between two types of blocks:

  • The so-called "dense" blocks,
  • The so-called "transition" layers.
Composite function & dense blocks

From the paper, we draw the following guidelines.

  • The network comprises \(L\) layers, each of which implements a non-linear transformation \(H_{\ell}(\cdot)\), where \(\ell\) indexes the layer. [...] We denote the output of the \(\ell^{th}\) layer as \(x_{\ell}\).

  • Consequently, the \(\ell^{th}\) layer receives the feature maps of all preceding layers, \(x_{0}, \dots, x_{\ell-1}\), as input:

\[x_{\ell} = H_{\ell}([x_{0}, x_{1}, \dots, x_{\ell-1}])\]
  • We define \(H_{\ell}\) as a composite function of the consecutive operations:
  • BN-ReLU-Conv(\(1 \times 1\))-BN-ReLU-Conv(\(3 \times 3\)),
  • Each Conv(\(3 \times 3\)) produces \(k\) feature maps, [...] we let each Conv(\(1 \times 1\)) produce \(4k\) feature maps.
    • We refer to the hyperparameter \(k\) as the growth rate of the network.
    • The growth rate for all networks is \(k=32\).
  • For convolutional layers with kernel size \(3 \times 3\), each side of the inputs is zero-padded by one pixel to keep the feature-map size fixed.
  • We adopt the weight initialization introduced by [10] (i.e. He initialization).
def bn_relu_conv(tensor, k, kernel_size):
  x = BatchNormalization()(tensor)
  x = ReLU()(x)
  x = Conv2D(filters=k,
             kernel_size=kernel_size,
             strides=(1,1),
             padding='same',
             kernel_initializer='he_normal',
             use_bias=False)(x)
  return x

Dense blocks follow a repeating pattern.

from tensorflow.keras.layers import Concatenate

def dense_block(tensor, k, reps):
  for _ in range(reps):
    x = bn_relu_conv(tensor, 4*k, 1)
    x = bn_relu_conv(x, k, 3)

    tensor = Concatenate()([x, tensor])  # by definition, the input tensor is part of the concatenation

  return tensor
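A quick shape check (our own sketch, not in the original lab): each repetition concatenates \(k\) new feature maps onto the running tensor.

t = Input(shape=(16, 16, 64))
print(dense_block(t, k=32, reps=6).shape)  # (None, 16, 16, 256): 64 + 6 * 32 channels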
Transition function
  • If a dense block contains \(m\) feature maps, we let the following transition layer generate \(\lfloor \theta m \rfloor\) output feature maps, where \(0 < \theta \leq 1\) is referred to as the compression factor.
  • We set \(\theta = 0.5\) in our experiments.
  • We use BN-ReLU-Conv(\(1 \times 1\)) followed by \(2 \times 2\) average pooling as transition layers between two contiguous dense blocks.

To access the number of feature maps, we use tf.keras.backend.int_shape(x)[-1].

from tensorflow.keras.layers import AvgPool2D

def transition_layer(x, theta):
    f = int(tf.keras.backend.int_shape(x)[-1] * theta)
    x = bn_relu_conv(x, f, 1)
    x = AvgPool2D(pool_size=2, strides=2, padding='same')(x)
    return x
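And a matching check (our own sketch): with \(\theta = 0.5\) the transition layer halves both the channel count and the spatial resolution.

t = Input(shape=(16, 16, 256))
print(transition_layer(t, theta=0.5).shape)  # (None, 8, 8, 128)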
  • The initial convolution layer comprises \(2k\) convolutions of size \(7\times7\) with stride \(2\).
  • At the end of the last dense block, a global average pooling is performed and then a softmax classifier is attached.
from tensorflow.keras import Model

def DenseNet(img_shape, k, theta, repetitions, small=True, include_top=True, num_classes=10):

  input = Input(img_shape)

  if small:
    # small-input stem (e.g. 32x32 CIFAR images): a single strided convolution, no max pooling
    x = Conv2D(filters = 16,
               kernel_size = 7,
               strides=2,
               padding='same',
               kernel_initializer='he_normal')(input)
  else:
    x = Conv2D(filters = 2*k,
               kernel_size = 7,
               strides=2,
               padding='same',
               kernel_initializer='he_normal')(input)
    x = MaxPool2D(pool_size = 3,
                  strides = 2,
                  padding='same')(x)

  for reps in repetitions:
      d = dense_block(x, k, reps)
      x = transition_layer(d, theta)

  # The head is attached to `d`, the output of the last dense block, so the
  # transition layer created after the final block is not part of the model.
  if include_top:
    x = GlobalAvgPool2D()(d)
    x = Dense(num_classes)(x)
    output = Activation('softmax')(x)
  else:
    output = GlobalAvgPool2D()(d)

  model = Model(input, output)

  return model
img_shape = 32, 32, 3
k = 32
theta = 0.5
repetitions = 6, 12, 24, 16
num_classes = 100

model = DenseNet(img_shape, k, theta, repetitions, num_classes=num_classes)
model.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
conv2d_123 (Conv2D)             (None, 16, 16, 16)   2368        input_4[0][0]                    
__________________________________________________________________________________________________
batch_normalization_120 (BatchN (None, 16, 16, 16)   64          conv2d_123[0][0]                 
__________________________________________________________________________________________________
re_lu_120 (ReLU)                (None, 16, 16, 16)   0           batch_normalization_120[0][0]    
__________________________________________________________________________________________________
conv2d_124 (Conv2D)             (None, 16, 16, 128)  2048        re_lu_120[0][0]                  
__________________________________________________________________________________________________
batch_normalization_121 (BatchN (None, 16, 16, 128)  512         conv2d_124[0][0]                 
__________________________________________________________________________________________________
re_lu_121 (ReLU)                (None, 16, 16, 128)  0           batch_normalization_121[0][0]    
__________________________________________________________________________________________________
conv2d_125 (Conv2D)             (None, 16, 16, 32)   36864       re_lu_121[0][0]                  
__________________________________________________________________________________________________
concatenate_58 (Concatenate)    (None, 16, 16, 48)   0           conv2d_123[0][0]                 
                                                                 conv2d_125[0][0]                 
__________________________________________________________________________________________________
batch_normalization_122 (BatchN (None, 16, 16, 48)   192         concatenate_58[0][0]             
__________________________________________________________________________________________________
re_lu_122 (ReLU)                (None, 16, 16, 48)   0           batch_normalization_122[0][0]    
__________________________________________________________________________________________________
conv2d_126 (Conv2D)             (None, 16, 16, 128)  6144        re_lu_122[0][0]                  
__________________________________________________________________________________________________
batch_normalization_123 (BatchN (None, 16, 16, 128)  512         conv2d_126[0][0]                 
__________________________________________________________________________________________________
re_lu_123 (ReLU)                (None, 16, 16, 128)  0           batch_normalization_123[0][0]    
__________________________________________________________________________________________________
conv2d_127 (Conv2D)             (None, 16, 16, 32)   36864       re_lu_123[0][0]                  
__________________________________________________________________________________________________
concatenate_59 (Concatenate)    (None, 16, 16, 80)   0           concatenate_58[0][0]             
                                                                 conv2d_127[0][0]                 
__________________________________________________________________________________________________
batch_normalization_124 (BatchN (None, 16, 16, 80)   320         concatenate_59[0][0]             
__________________________________________________________________________________________________
re_lu_124 (ReLU)                (None, 16, 16, 80)   0           batch_normalization_124[0][0]    
__________________________________________________________________________________________________
conv2d_128 (Conv2D)             (None, 16, 16, 128)  10240       re_lu_124[0][0]                  
__________________________________________________________________________________________________
batch_normalization_125 (BatchN (None, 16, 16, 128)  512         conv2d_128[0][0]                 
__________________________________________________________________________________________________
re_lu_125 (ReLU)                (None, 16, 16, 128)  0           batch_normalization_125[0][0]    
__________________________________________________________________________________________________
conv2d_129 (Conv2D)             (None, 16, 16, 32)   36864       re_lu_125[0][0]                  
__________________________________________________________________________________________________
concatenate_60 (Concatenate)    (None, 16, 16, 112)  0           concatenate_59[0][0]             
                                                                 conv2d_129[0][0]                 
__________________________________________________________________________________________________
batch_normalization_126 (BatchN (None, 16, 16, 112)  448         concatenate_60[0][0]             
__________________________________________________________________________________________________
re_lu_126 (ReLU)                (None, 16, 16, 112)  0           batch_normalization_126[0][0]    
__________________________________________________________________________________________________
conv2d_130 (Conv2D)             (None, 16, 16, 128)  14336       re_lu_126[0][0]                  
__________________________________________________________________________________________________
batch_normalization_127 (BatchN (None, 16, 16, 128)  512         conv2d_130[0][0]                 
__________________________________________________________________________________________________
re_lu_127 (ReLU)                (None, 16, 16, 128)  0           batch_normalization_127[0][0]    
__________________________________________________________________________________________________
conv2d_131 (Conv2D)             (None, 16, 16, 32)   36864       re_lu_127[0][0]                  
__________________________________________________________________________________________________
concatenate_61 (Concatenate)    (None, 16, 16, 144)  0           concatenate_60[0][0]             
                                                                 conv2d_131[0][0]                 
__________________________________________________________________________________________________
batch_normalization_128 (BatchN (None, 16, 16, 144)  576         concatenate_61[0][0]             
__________________________________________________________________________________________________
re_lu_128 (ReLU)                (None, 16, 16, 144)  0           batch_normalization_128[0][0]    
__________________________________________________________________________________________________
conv2d_132 (Conv2D)             (None, 16, 16, 128)  18432       re_lu_128[0][0]                  
__________________________________________________________________________________________________
batch_normalization_129 (BatchN (None, 16, 16, 128)  512         conv2d_132[0][0]                 
__________________________________________________________________________________________________
re_lu_129 (ReLU)                (None, 16, 16, 128)  0           batch_normalization_129[0][0]    
__________________________________________________________________________________________________
conv2d_133 (Conv2D)             (None, 16, 16, 32)   36864       re_lu_129[0][0]                  
__________________________________________________________________________________________________
concatenate_62 (Concatenate)    (None, 16, 16, 176)  0           concatenate_61[0][0]             
                                                                 conv2d_133[0][0]                 
__________________________________________________________________________________________________
batch_normalization_130 (BatchN (None, 16, 16, 176)  704         concatenate_62[0][0]             
__________________________________________________________________________________________________
re_lu_130 (ReLU)                (None, 16, 16, 176)  0           batch_normalization_130[0][0]    
__________________________________________________________________________________________________
conv2d_134 (Conv2D)             (None, 16, 16, 128)  22528       re_lu_130[0][0]                  
__________________________________________________________________________________________________
batch_normalization_131 (BatchN (None, 16, 16, 128)  512         conv2d_134[0][0]                 
__________________________________________________________________________________________________
re_lu_131 (ReLU)                (None, 16, 16, 128)  0           batch_normalization_131[0][0]    
__________________________________________________________________________________________________
conv2d_135 (Conv2D)             (None, 16, 16, 32)   36864       re_lu_131[0][0]                  
__________________________________________________________________________________________________
concatenate_63 (Concatenate)    (None, 16, 16, 208)  0           concatenate_62[0][0]             
                                                                 conv2d_135[0][0]                 
__________________________________________________________________________________________________
batch_normalization_132 (BatchN (None, 16, 16, 208)  832         concatenate_63[0][0]             
__________________________________________________________________________________________________
re_lu_132 (ReLU)                (None, 16, 16, 208)  0           batch_normalization_132[0][0]    
__________________________________________________________________________________________________
conv2d_136 (Conv2D)             (None, 16, 16, 104)  21632       re_lu_132[0][0]                  
__________________________________________________________________________________________________
average_pooling2d_4 (AveragePoo (None, 8, 8, 104)    0           conv2d_136[0][0]                 
__________________________________________________________________________________________________
batch_normalization_133 (BatchN (None, 8, 8, 104)    416         average_pooling2d_4[0][0]        
__________________________________________________________________________________________________
re_lu_133 (ReLU)                (None, 8, 8, 104)    0           batch_normalization_133[0][0]    
__________________________________________________________________________________________________
conv2d_137 (Conv2D)             (None, 8, 8, 128)    13312       re_lu_133[0][0]                  
__________________________________________________________________________________________________
batch_normalization_134 (BatchN (None, 8, 8, 128)    512         conv2d_137[0][0]                 
__________________________________________________________________________________________________
re_lu_134 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_134[0][0]    
__________________________________________________________________________________________________
conv2d_138 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_134[0][0]                  
__________________________________________________________________________________________________
concatenate_64 (Concatenate)    (None, 8, 8, 136)    0           average_pooling2d_4[0][0]        
                                                                 conv2d_138[0][0]                 
__________________________________________________________________________________________________
batch_normalization_135 (BatchN (None, 8, 8, 136)    544         concatenate_64[0][0]             
__________________________________________________________________________________________________
re_lu_135 (ReLU)                (None, 8, 8, 136)    0           batch_normalization_135[0][0]    
__________________________________________________________________________________________________
conv2d_139 (Conv2D)             (None, 8, 8, 128)    17408       re_lu_135[0][0]                  
__________________________________________________________________________________________________
batch_normalization_136 (BatchN (None, 8, 8, 128)    512         conv2d_139[0][0]                 
__________________________________________________________________________________________________
re_lu_136 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_136[0][0]    
__________________________________________________________________________________________________
conv2d_140 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_136[0][0]                  
__________________________________________________________________________________________________
concatenate_65 (Concatenate)    (None, 8, 8, 168)    0           concatenate_64[0][0]             
                                                                 conv2d_140[0][0]                 
__________________________________________________________________________________________________
batch_normalization_137 (BatchN (None, 8, 8, 168)    672         concatenate_65[0][0]             
__________________________________________________________________________________________________
re_lu_137 (ReLU)                (None, 8, 8, 168)    0           batch_normalization_137[0][0]    
__________________________________________________________________________________________________
conv2d_141 (Conv2D)             (None, 8, 8, 128)    21504       re_lu_137[0][0]                  
__________________________________________________________________________________________________
batch_normalization_138 (BatchN (None, 8, 8, 128)    512         conv2d_141[0][0]                 
__________________________________________________________________________________________________
re_lu_138 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_138[0][0]    
__________________________________________________________________________________________________
conv2d_142 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_138[0][0]                  
__________________________________________________________________________________________________
concatenate_66 (Concatenate)    (None, 8, 8, 200)    0           concatenate_65[0][0]             
                                                                 conv2d_142[0][0]                 
__________________________________________________________________________________________________
batch_normalization_139 (BatchN (None, 8, 8, 200)    800         concatenate_66[0][0]             
__________________________________________________________________________________________________
re_lu_139 (ReLU)                (None, 8, 8, 200)    0           batch_normalization_139[0][0]    
__________________________________________________________________________________________________
conv2d_143 (Conv2D)             (None, 8, 8, 128)    25600       re_lu_139[0][0]                  
__________________________________________________________________________________________________
batch_normalization_140 (BatchN (None, 8, 8, 128)    512         conv2d_143[0][0]                 
__________________________________________________________________________________________________
re_lu_140 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_140[0][0]    
__________________________________________________________________________________________________
conv2d_144 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_140[0][0]                  
__________________________________________________________________________________________________
concatenate_67 (Concatenate)    (None, 8, 8, 232)    0           concatenate_66[0][0]             
                                                                 conv2d_144[0][0]                 
__________________________________________________________________________________________________
batch_normalization_141 (BatchN (None, 8, 8, 232)    928         concatenate_67[0][0]             
__________________________________________________________________________________________________
re_lu_141 (ReLU)                (None, 8, 8, 232)    0           batch_normalization_141[0][0]    
__________________________________________________________________________________________________
conv2d_145 (Conv2D)             (None, 8, 8, 128)    29696       re_lu_141[0][0]                  
__________________________________________________________________________________________________
batch_normalization_142 (BatchN (None, 8, 8, 128)    512         conv2d_145[0][0]                 
__________________________________________________________________________________________________
re_lu_142 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_142[0][0]    
__________________________________________________________________________________________________
conv2d_146 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_142[0][0]                  
__________________________________________________________________________________________________
concatenate_68 (Concatenate)    (None, 8, 8, 264)    0           concatenate_67[0][0]             
                                                                 conv2d_146[0][0]                 
__________________________________________________________________________________________________
batch_normalization_143 (BatchN (None, 8, 8, 264)    1056        concatenate_68[0][0]             
__________________________________________________________________________________________________
re_lu_143 (ReLU)                (None, 8, 8, 264)    0           batch_normalization_143[0][0]    
__________________________________________________________________________________________________
conv2d_147 (Conv2D)             (None, 8, 8, 128)    33792       re_lu_143[0][0]                  
__________________________________________________________________________________________________
batch_normalization_144 (BatchN (None, 8, 8, 128)    512         conv2d_147[0][0]                 
__________________________________________________________________________________________________
re_lu_144 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_144[0][0]    
__________________________________________________________________________________________________
conv2d_148 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_144[0][0]                  
__________________________________________________________________________________________________
concatenate_69 (Concatenate)    (None, 8, 8, 296)    0           concatenate_68[0][0]             
                                                                 conv2d_148[0][0]                 
__________________________________________________________________________________________________
batch_normalization_145 (BatchN (None, 8, 8, 296)    1184        concatenate_69[0][0]             
__________________________________________________________________________________________________
re_lu_145 (ReLU)                (None, 8, 8, 296)    0           batch_normalization_145[0][0]    
__________________________________________________________________________________________________
conv2d_149 (Conv2D)             (None, 8, 8, 128)    37888       re_lu_145[0][0]                  
__________________________________________________________________________________________________
batch_normalization_146 (BatchN (None, 8, 8, 128)    512         conv2d_149[0][0]                 
__________________________________________________________________________________________________
re_lu_146 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_146[0][0]    
__________________________________________________________________________________________________
conv2d_150 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_146[0][0]                  
__________________________________________________________________________________________________
concatenate_70 (Concatenate)    (None, 8, 8, 328)    0           concatenate_69[0][0]             
                                                                 conv2d_150[0][0]                 
__________________________________________________________________________________________________
batch_normalization_147 (BatchN (None, 8, 8, 328)    1312        concatenate_70[0][0]             
__________________________________________________________________________________________________
re_lu_147 (ReLU)                (None, 8, 8, 328)    0           batch_normalization_147[0][0]    
__________________________________________________________________________________________________
conv2d_151 (Conv2D)             (None, 8, 8, 128)    41984       re_lu_147[0][0]                  
__________________________________________________________________________________________________
batch_normalization_148 (BatchN (None, 8, 8, 128)    512         conv2d_151[0][0]                 
__________________________________________________________________________________________________
re_lu_148 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_148[0][0]    
__________________________________________________________________________________________________
conv2d_152 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_148[0][0]                  
__________________________________________________________________________________________________
concatenate_71 (Concatenate)    (None, 8, 8, 360)    0           concatenate_70[0][0]             
                                                                 conv2d_152[0][0]                 
__________________________________________________________________________________________________
batch_normalization_149 (BatchN (None, 8, 8, 360)    1440        concatenate_71[0][0]             
__________________________________________________________________________________________________
re_lu_149 (ReLU)                (None, 8, 8, 360)    0           batch_normalization_149[0][0]    
__________________________________________________________________________________________________
conv2d_153 (Conv2D)             (None, 8, 8, 128)    46080       re_lu_149[0][0]                  
__________________________________________________________________________________________________
batch_normalization_150 (BatchN (None, 8, 8, 128)    512         conv2d_153[0][0]                 
__________________________________________________________________________________________________
re_lu_150 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_150[0][0]    
__________________________________________________________________________________________________
conv2d_154 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_150[0][0]                  
__________________________________________________________________________________________________
concatenate_72 (Concatenate)    (None, 8, 8, 392)    0           concatenate_71[0][0]             
                                                                 conv2d_154[0][0]                 
__________________________________________________________________________________________________
batch_normalization_151 (BatchN (None, 8, 8, 392)    1568        concatenate_72[0][0]             
__________________________________________________________________________________________________
re_lu_151 (ReLU)                (None, 8, 8, 392)    0           batch_normalization_151[0][0]    
__________________________________________________________________________________________________
conv2d_155 (Conv2D)             (None, 8, 8, 128)    50176       re_lu_151[0][0]                  
__________________________________________________________________________________________________
batch_normalization_152 (BatchN (None, 8, 8, 128)    512         conv2d_155[0][0]                 
__________________________________________________________________________________________________
re_lu_152 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_152[0][0]    
__________________________________________________________________________________________________
conv2d_156 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_152[0][0]                  
__________________________________________________________________________________________________
concatenate_73 (Concatenate)    (None, 8, 8, 424)    0           concatenate_72[0][0]             
                                                                 conv2d_156[0][0]                 
__________________________________________________________________________________________________
batch_normalization_153 (BatchN (None, 8, 8, 424)    1696        concatenate_73[0][0]             
__________________________________________________________________________________________________
re_lu_153 (ReLU)                (None, 8, 8, 424)    0           batch_normalization_153[0][0]    
__________________________________________________________________________________________________
conv2d_157 (Conv2D)             (None, 8, 8, 128)    54272       re_lu_153[0][0]                  
__________________________________________________________________________________________________
batch_normalization_154 (BatchN (None, 8, 8, 128)    512         conv2d_157[0][0]                 
__________________________________________________________________________________________________
re_lu_154 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_154[0][0]    
__________________________________________________________________________________________________
conv2d_158 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_154[0][0]                  
__________________________________________________________________________________________________
concatenate_74 (Concatenate)    (None, 8, 8, 456)    0           concatenate_73[0][0]             
                                                                 conv2d_158[0][0]                 
__________________________________________________________________________________________________
batch_normalization_155 (BatchN (None, 8, 8, 456)    1824        concatenate_74[0][0]             
__________________________________________________________________________________________________
re_lu_155 (ReLU)                (None, 8, 8, 456)    0           batch_normalization_155[0][0]    
__________________________________________________________________________________________________
conv2d_159 (Conv2D)             (None, 8, 8, 128)    58368       re_lu_155[0][0]                  
__________________________________________________________________________________________________
batch_normalization_156 (BatchN (None, 8, 8, 128)    512         conv2d_159[0][0]                 
__________________________________________________________________________________________________
re_lu_156 (ReLU)                (None, 8, 8, 128)    0           batch_normalization_156[0][0]    
__________________________________________________________________________________________________
conv2d_160 (Conv2D)             (None, 8, 8, 32)     36864       re_lu_156[0][0]                  
__________________________________________________________________________________________________
concatenate_75 (Concatenate)    (None, 8, 8, 488)    0           concatenate_74[0][0]             
                                                                 conv2d_160[0][0]                 
__________________________________________________________________________________________________
batch_normalization_157 (BatchN (None, 8, 8, 488)    1952        concatenate_75[0][0]             
__________________________________________________________________________________________________
re_lu_157 (ReLU)                (None, 8, 8, 488)    0           batch_normalization_157[0][0]    
__________________________________________________________________________________________________
conv2d_161 (Conv2D)             (None, 8, 8, 244)    119072      re_lu_157[0][0]                  
__________________________________________________________________________________________________
average_pooling2d_5 (AveragePoo (None, 4, 4, 244)    0           conv2d_161[0][0]                 
__________________________________________________________________________________________________
batch_normalization_158 (BatchN (None, 4, 4, 244)    976         average_pooling2d_5[0][0]        
__________________________________________________________________________________________________
re_lu_158 (ReLU)                (None, 4, 4, 244)    0           batch_normalization_158[0][0]    
__________________________________________________________________________________________________
conv2d_162 (Conv2D)             (None, 4, 4, 128)    31232       re_lu_158[0][0]                  
__________________________________________________________________________________________________
batch_normalization_159 (BatchN (None, 4, 4, 128)    512         conv2d_162[0][0]                 
__________________________________________________________________________________________________
re_lu_159 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_159[0][0]    
__________________________________________________________________________________________________
conv2d_163 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_159[0][0]                  
__________________________________________________________________________________________________
concatenate_76 (Concatenate)    (None, 4, 4, 276)    0           average_pooling2d_5[0][0]        
                                                                 conv2d_163[0][0]                 
__________________________________________________________________________________________________
batch_normalization_160 (BatchN (None, 4, 4, 276)    1104        concatenate_76[0][0]             
__________________________________________________________________________________________________
re_lu_160 (ReLU)                (None, 4, 4, 276)    0           batch_normalization_160[0][0]    
__________________________________________________________________________________________________
conv2d_164 (Conv2D)             (None, 4, 4, 128)    35328       re_lu_160[0][0]                  
__________________________________________________________________________________________________
batch_normalization_161 (BatchN (None, 4, 4, 128)    512         conv2d_164[0][0]                 
__________________________________________________________________________________________________
re_lu_161 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_161[0][0]    
__________________________________________________________________________________________________
conv2d_165 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_161[0][0]                  
__________________________________________________________________________________________________
concatenate_77 (Concatenate)    (None, 4, 4, 308)    0           concatenate_76[0][0]             
                                                                 conv2d_165[0][0]                 
__________________________________________________________________________________________________
batch_normalization_162 (BatchN (None, 4, 4, 308)    1232        concatenate_77[0][0]             
__________________________________________________________________________________________________
re_lu_162 (ReLU)                (None, 4, 4, 308)    0           batch_normalization_162[0][0]    
__________________________________________________________________________________________________
conv2d_166 (Conv2D)             (None, 4, 4, 128)    39424       re_lu_162[0][0]                  
__________________________________________________________________________________________________
batch_normalization_163 (BatchN (None, 4, 4, 128)    512         conv2d_166[0][0]                 
__________________________________________________________________________________________________
re_lu_163 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_163[0][0]    
__________________________________________________________________________________________________
conv2d_167 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_163[0][0]                  
__________________________________________________________________________________________________
concatenate_78 (Concatenate)    (None, 4, 4, 340)    0           concatenate_77[0][0]             
                                                                 conv2d_167[0][0]                 
__________________________________________________________________________________________________
batch_normalization_164 (BatchN (None, 4, 4, 340)    1360        concatenate_78[0][0]             
__________________________________________________________________________________________________
re_lu_164 (ReLU)                (None, 4, 4, 340)    0           batch_normalization_164[0][0]    
__________________________________________________________________________________________________
conv2d_168 (Conv2D)             (None, 4, 4, 128)    43520       re_lu_164[0][0]                  
__________________________________________________________________________________________________
batch_normalization_165 (BatchN (None, 4, 4, 128)    512         conv2d_168[0][0]                 
__________________________________________________________________________________________________
re_lu_165 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_165[0][0]    
__________________________________________________________________________________________________
conv2d_169 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_165[0][0]                  
__________________________________________________________________________________________________
concatenate_79 (Concatenate)    (None, 4, 4, 372)    0           concatenate_78[0][0]             
                                                                 conv2d_169[0][0]                 
__________________________________________________________________________________________________
batch_normalization_166 (BatchN (None, 4, 4, 372)    1488        concatenate_79[0][0]             
__________________________________________________________________________________________________
re_lu_166 (ReLU)                (None, 4, 4, 372)    0           batch_normalization_166[0][0]    
__________________________________________________________________________________________________
conv2d_170 (Conv2D)             (None, 4, 4, 128)    47616       re_lu_166[0][0]                  
__________________________________________________________________________________________________
batch_normalization_167 (BatchN (None, 4, 4, 128)    512         conv2d_170[0][0]                 
__________________________________________________________________________________________________
re_lu_167 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_167[0][0]    
__________________________________________________________________________________________________
conv2d_171 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_167[0][0]                  
__________________________________________________________________________________________________
concatenate_80 (Concatenate)    (None, 4, 4, 404)    0           concatenate_79[0][0]             
                                                                 conv2d_171[0][0]                 
__________________________________________________________________________________________________
batch_normalization_168 (BatchN (None, 4, 4, 404)    1616        concatenate_80[0][0]             
__________________________________________________________________________________________________
re_lu_168 (ReLU)                (None, 4, 4, 404)    0           batch_normalization_168[0][0]    
__________________________________________________________________________________________________
conv2d_172 (Conv2D)             (None, 4, 4, 128)    51712       re_lu_168[0][0]                  
__________________________________________________________________________________________________
batch_normalization_169 (BatchN (None, 4, 4, 128)    512         conv2d_172[0][0]                 
__________________________________________________________________________________________________
re_lu_169 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_169[0][0]    
__________________________________________________________________________________________________
conv2d_173 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_169[0][0]                  
__________________________________________________________________________________________________
concatenate_81 (Concatenate)    (None, 4, 4, 436)    0           concatenate_80[0][0]             
                                                                 conv2d_173[0][0]                 
__________________________________________________________________________________________________
batch_normalization_170 (BatchN (None, 4, 4, 436)    1744        concatenate_81[0][0]             
__________________________________________________________________________________________________
re_lu_170 (ReLU)                (None, 4, 4, 436)    0           batch_normalization_170[0][0]    
__________________________________________________________________________________________________
conv2d_174 (Conv2D)             (None, 4, 4, 128)    55808       re_lu_170[0][0]                  
__________________________________________________________________________________________________
batch_normalization_171 (BatchN (None, 4, 4, 128)    512         conv2d_174[0][0]                 
__________________________________________________________________________________________________
re_lu_171 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_171[0][0]    
__________________________________________________________________________________________________
conv2d_175 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_171[0][0]                  
__________________________________________________________________________________________________
concatenate_82 (Concatenate)    (None, 4, 4, 468)    0           concatenate_81[0][0]             
                                                                 conv2d_175[0][0]                 
__________________________________________________________________________________________________
batch_normalization_172 (BatchN (None, 4, 4, 468)    1872        concatenate_82[0][0]             
__________________________________________________________________________________________________
re_lu_172 (ReLU)                (None, 4, 4, 468)    0           batch_normalization_172[0][0]    
__________________________________________________________________________________________________
conv2d_176 (Conv2D)             (None, 4, 4, 128)    59904       re_lu_172[0][0]                  
__________________________________________________________________________________________________
batch_normalization_173 (BatchN (None, 4, 4, 128)    512         conv2d_176[0][0]                 
__________________________________________________________________________________________________
re_lu_173 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_173[0][0]    
__________________________________________________________________________________________________
conv2d_177 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_173[0][0]                  
__________________________________________________________________________________________________
concatenate_83 (Concatenate)    (None, 4, 4, 500)    0           concatenate_82[0][0]             
                                                                 conv2d_177[0][0]                 
__________________________________________________________________________________________________
batch_normalization_174 (BatchN (None, 4, 4, 500)    2000        concatenate_83[0][0]             
__________________________________________________________________________________________________
re_lu_174 (ReLU)                (None, 4, 4, 500)    0           batch_normalization_174[0][0]    
__________________________________________________________________________________________________
conv2d_178 (Conv2D)             (None, 4, 4, 128)    64000       re_lu_174[0][0]                  
__________________________________________________________________________________________________
batch_normalization_175 (BatchN (None, 4, 4, 128)    512         conv2d_178[0][0]                 
__________________________________________________________________________________________________
re_lu_175 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_175[0][0]    
__________________________________________________________________________________________________
conv2d_179 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_175[0][0]                  
__________________________________________________________________________________________________
concatenate_84 (Concatenate)    (None, 4, 4, 532)    0           concatenate_83[0][0]             
                                                                 conv2d_179[0][0]                 
__________________________________________________________________________________________________
batch_normalization_176 (BatchN (None, 4, 4, 532)    2128        concatenate_84[0][0]             
__________________________________________________________________________________________________
re_lu_176 (ReLU)                (None, 4, 4, 532)    0           batch_normalization_176[0][0]    
__________________________________________________________________________________________________
conv2d_180 (Conv2D)             (None, 4, 4, 128)    68096       re_lu_176[0][0]                  
__________________________________________________________________________________________________
batch_normalization_177 (BatchN (None, 4, 4, 128)    512         conv2d_180[0][0]                 
__________________________________________________________________________________________________
re_lu_177 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_177[0][0]    
__________________________________________________________________________________________________
conv2d_181 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_177[0][0]                  
__________________________________________________________________________________________________
concatenate_85 (Concatenate)    (None, 4, 4, 564)    0           concatenate_84[0][0]             
                                                                 conv2d_181[0][0]                 
__________________________________________________________________________________________________
batch_normalization_178 (BatchN (None, 4, 4, 564)    2256        concatenate_85[0][0]             
__________________________________________________________________________________________________
re_lu_178 (ReLU)                (None, 4, 4, 564)    0           batch_normalization_178[0][0]    
__________________________________________________________________________________________________
conv2d_182 (Conv2D)             (None, 4, 4, 128)    72192       re_lu_178[0][0]                  
__________________________________________________________________________________________________
batch_normalization_179 (BatchN (None, 4, 4, 128)    512         conv2d_182[0][0]                 
__________________________________________________________________________________________________
re_lu_179 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_179[0][0]    
__________________________________________________________________________________________________
conv2d_183 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_179[0][0]                  
__________________________________________________________________________________________________
concatenate_86 (Concatenate)    (None, 4, 4, 596)    0           concatenate_85[0][0]             
                                                                 conv2d_183[0][0]                 
__________________________________________________________________________________________________
batch_normalization_180 (BatchN (None, 4, 4, 596)    2384        concatenate_86[0][0]             
__________________________________________________________________________________________________
re_lu_180 (ReLU)                (None, 4, 4, 596)    0           batch_normalization_180[0][0]    
__________________________________________________________________________________________________
conv2d_184 (Conv2D)             (None, 4, 4, 128)    76288       re_lu_180[0][0]                  
__________________________________________________________________________________________________
batch_normalization_181 (BatchN (None, 4, 4, 128)    512         conv2d_184[0][0]                 
__________________________________________________________________________________________________
re_lu_181 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_181[0][0]    
__________________________________________________________________________________________________
conv2d_185 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_181[0][0]                  
__________________________________________________________________________________________________
concatenate_87 (Concatenate)    (None, 4, 4, 628)    0           concatenate_86[0][0]             
                                                                 conv2d_185[0][0]                 
__________________________________________________________________________________________________
batch_normalization_182 (BatchN (None, 4, 4, 628)    2512        concatenate_87[0][0]             
__________________________________________________________________________________________________
re_lu_182 (ReLU)                (None, 4, 4, 628)    0           batch_normalization_182[0][0]    
__________________________________________________________________________________________________
conv2d_186 (Conv2D)             (None, 4, 4, 128)    80384       re_lu_182[0][0]                  
__________________________________________________________________________________________________
batch_normalization_183 (BatchN (None, 4, 4, 128)    512         conv2d_186[0][0]                 
__________________________________________________________________________________________________
re_lu_183 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_183[0][0]    
__________________________________________________________________________________________________
conv2d_187 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_183[0][0]                  
__________________________________________________________________________________________________
concatenate_88 (Concatenate)    (None, 4, 4, 660)    0           concatenate_87[0][0]             
                                                                 conv2d_187[0][0]                 
__________________________________________________________________________________________________
batch_normalization_184 (BatchN (None, 4, 4, 660)    2640        concatenate_88[0][0]             
__________________________________________________________________________________________________
re_lu_184 (ReLU)                (None, 4, 4, 660)    0           batch_normalization_184[0][0]    
__________________________________________________________________________________________________
conv2d_188 (Conv2D)             (None, 4, 4, 128)    84480       re_lu_184[0][0]                  
__________________________________________________________________________________________________
batch_normalization_185 (BatchN (None, 4, 4, 128)    512         conv2d_188[0][0]                 
__________________________________________________________________________________________________
re_lu_185 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_185[0][0]    
__________________________________________________________________________________________________
conv2d_189 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_185[0][0]                  
__________________________________________________________________________________________________
concatenate_89 (Concatenate)    (None, 4, 4, 692)    0           concatenate_88[0][0]             
                                                                 conv2d_189[0][0]                 
__________________________________________________________________________________________________
batch_normalization_186 (BatchN (None, 4, 4, 692)    2768        concatenate_89[0][0]             
__________________________________________________________________________________________________
re_lu_186 (ReLU)                (None, 4, 4, 692)    0           batch_normalization_186[0][0]    
__________________________________________________________________________________________________
conv2d_190 (Conv2D)             (None, 4, 4, 128)    88576       re_lu_186[0][0]                  
__________________________________________________________________________________________________
batch_normalization_187 (BatchN (None, 4, 4, 128)    512         conv2d_190[0][0]                 
__________________________________________________________________________________________________
re_lu_187 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_187[0][0]    
__________________________________________________________________________________________________
conv2d_191 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_187[0][0]                  
__________________________________________________________________________________________________
concatenate_90 (Concatenate)    (None, 4, 4, 724)    0           concatenate_89[0][0]             
                                                                 conv2d_191[0][0]                 
__________________________________________________________________________________________________
batch_normalization_188 (BatchN (None, 4, 4, 724)    2896        concatenate_90[0][0]             
__________________________________________________________________________________________________
re_lu_188 (ReLU)                (None, 4, 4, 724)    0           batch_normalization_188[0][0]    
__________________________________________________________________________________________________
conv2d_192 (Conv2D)             (None, 4, 4, 128)    92672       re_lu_188[0][0]                  
__________________________________________________________________________________________________
batch_normalization_189 (BatchN (None, 4, 4, 128)    512         conv2d_192[0][0]                 
__________________________________________________________________________________________________
re_lu_189 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_189[0][0]    
__________________________________________________________________________________________________
conv2d_193 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_189[0][0]                  
__________________________________________________________________________________________________
concatenate_91 (Concatenate)    (None, 4, 4, 756)    0           concatenate_90[0][0]             
                                                                 conv2d_193[0][0]                 
__________________________________________________________________________________________________
batch_normalization_190 (BatchN (None, 4, 4, 756)    3024        concatenate_91[0][0]             
__________________________________________________________________________________________________
re_lu_190 (ReLU)                (None, 4, 4, 756)    0           batch_normalization_190[0][0]    
__________________________________________________________________________________________________
conv2d_194 (Conv2D)             (None, 4, 4, 128)    96768       re_lu_190[0][0]                  
__________________________________________________________________________________________________
batch_normalization_191 (BatchN (None, 4, 4, 128)    512         conv2d_194[0][0]                 
__________________________________________________________________________________________________
re_lu_191 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_191[0][0]    
__________________________________________________________________________________________________
conv2d_195 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_191[0][0]                  
__________________________________________________________________________________________________
concatenate_92 (Concatenate)    (None, 4, 4, 788)    0           concatenate_91[0][0]             
                                                                 conv2d_195[0][0]                 
__________________________________________________________________________________________________
batch_normalization_192 (BatchN (None, 4, 4, 788)    3152        concatenate_92[0][0]             
__________________________________________________________________________________________________
re_lu_192 (ReLU)                (None, 4, 4, 788)    0           batch_normalization_192[0][0]    
__________________________________________________________________________________________________
conv2d_196 (Conv2D)             (None, 4, 4, 128)    100864      re_lu_192[0][0]                  
__________________________________________________________________________________________________
batch_normalization_193 (BatchN (None, 4, 4, 128)    512         conv2d_196[0][0]                 
__________________________________________________________________________________________________
re_lu_193 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_193[0][0]    
__________________________________________________________________________________________________
conv2d_197 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_193[0][0]                  
__________________________________________________________________________________________________
concatenate_93 (Concatenate)    (None, 4, 4, 820)    0           concatenate_92[0][0]             
                                                                 conv2d_197[0][0]                 
__________________________________________________________________________________________________
batch_normalization_194 (BatchN (None, 4, 4, 820)    3280        concatenate_93[0][0]             
__________________________________________________________________________________________________
re_lu_194 (ReLU)                (None, 4, 4, 820)    0           batch_normalization_194[0][0]    
__________________________________________________________________________________________________
conv2d_198 (Conv2D)             (None, 4, 4, 128)    104960      re_lu_194[0][0]                  
__________________________________________________________________________________________________
batch_normalization_195 (BatchN (None, 4, 4, 128)    512         conv2d_198[0][0]                 
__________________________________________________________________________________________________
re_lu_195 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_195[0][0]    
__________________________________________________________________________________________________
conv2d_199 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_195[0][0]                  
__________________________________________________________________________________________________
concatenate_94 (Concatenate)    (None, 4, 4, 852)    0           concatenate_93[0][0]             
                                                                 conv2d_199[0][0]                 
__________________________________________________________________________________________________
batch_normalization_196 (BatchN (None, 4, 4, 852)    3408        concatenate_94[0][0]             
__________________________________________________________________________________________________
re_lu_196 (ReLU)                (None, 4, 4, 852)    0           batch_normalization_196[0][0]    
__________________________________________________________________________________________________
conv2d_200 (Conv2D)             (None, 4, 4, 128)    109056      re_lu_196[0][0]                  
__________________________________________________________________________________________________
batch_normalization_197 (BatchN (None, 4, 4, 128)    512         conv2d_200[0][0]                 
__________________________________________________________________________________________________
re_lu_197 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_197[0][0]    
__________________________________________________________________________________________________
conv2d_201 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_197[0][0]                  
__________________________________________________________________________________________________
concatenate_95 (Concatenate)    (None, 4, 4, 884)    0           concatenate_94[0][0]             
                                                                 conv2d_201[0][0]                 
__________________________________________________________________________________________________
batch_normalization_198 (BatchN (None, 4, 4, 884)    3536        concatenate_95[0][0]             
__________________________________________________________________________________________________
re_lu_198 (ReLU)                (None, 4, 4, 884)    0           batch_normalization_198[0][0]    
__________________________________________________________________________________________________
conv2d_202 (Conv2D)             (None, 4, 4, 128)    113152      re_lu_198[0][0]                  
__________________________________________________________________________________________________
batch_normalization_199 (BatchN (None, 4, 4, 128)    512         conv2d_202[0][0]                 
__________________________________________________________________________________________________
re_lu_199 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_199[0][0]    
__________________________________________________________________________________________________
conv2d_203 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_199[0][0]                  
__________________________________________________________________________________________________
concatenate_96 (Concatenate)    (None, 4, 4, 916)    0           concatenate_95[0][0]             
                                                                 conv2d_203[0][0]                 
__________________________________________________________________________________________________
batch_normalization_200 (BatchN (None, 4, 4, 916)    3664        concatenate_96[0][0]             
__________________________________________________________________________________________________
re_lu_200 (ReLU)                (None, 4, 4, 916)    0           batch_normalization_200[0][0]    
__________________________________________________________________________________________________
conv2d_204 (Conv2D)             (None, 4, 4, 128)    117248      re_lu_200[0][0]                  
__________________________________________________________________________________________________
batch_normalization_201 (BatchN (None, 4, 4, 128)    512         conv2d_204[0][0]                 
__________________________________________________________________________________________________
re_lu_201 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_201[0][0]    
__________________________________________________________________________________________________
conv2d_205 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_201[0][0]                  
__________________________________________________________________________________________________
concatenate_97 (Concatenate)    (None, 4, 4, 948)    0           concatenate_96[0][0]             
                                                                 conv2d_205[0][0]                 
__________________________________________________________________________________________________
batch_normalization_202 (BatchN (None, 4, 4, 948)    3792        concatenate_97[0][0]             
__________________________________________________________________________________________________
re_lu_202 (ReLU)                (None, 4, 4, 948)    0           batch_normalization_202[0][0]    
__________________________________________________________________________________________________
conv2d_206 (Conv2D)             (None, 4, 4, 128)    121344      re_lu_202[0][0]                  
__________________________________________________________________________________________________
batch_normalization_203 (BatchN (None, 4, 4, 128)    512         conv2d_206[0][0]                 
__________________________________________________________________________________________________
re_lu_203 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_203[0][0]    
__________________________________________________________________________________________________
conv2d_207 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_203[0][0]                  
__________________________________________________________________________________________________
concatenate_98 (Concatenate)    (None, 4, 4, 980)    0           concatenate_97[0][0]             
                                                                 conv2d_207[0][0]                 
__________________________________________________________________________________________________
batch_normalization_204 (BatchN (None, 4, 4, 980)    3920        concatenate_98[0][0]             
__________________________________________________________________________________________________
re_lu_204 (ReLU)                (None, 4, 4, 980)    0           batch_normalization_204[0][0]    
__________________________________________________________________________________________________
conv2d_208 (Conv2D)             (None, 4, 4, 128)    125440      re_lu_204[0][0]                  
__________________________________________________________________________________________________
batch_normalization_205 (BatchN (None, 4, 4, 128)    512         conv2d_208[0][0]                 
__________________________________________________________________________________________________
re_lu_205 (ReLU)                (None, 4, 4, 128)    0           batch_normalization_205[0][0]    
__________________________________________________________________________________________________
conv2d_209 (Conv2D)             (None, 4, 4, 32)     36864       re_lu_205[0][0]                  
__________________________________________________________________________________________________
concatenate_99 (Concatenate)    (None, 4, 4, 1012)   0           concatenate_98[0][0]             
                                                                 conv2d_209[0][0]                 
__________________________________________________________________________________________________
batch_normalization_206 (BatchN (None, 4, 4, 1012)   4048        concatenate_99[0][0]             
__________________________________________________________________________________________________
re_lu_206 (ReLU)                (None, 4, 4, 1012)   0           batch_normalization_206[0][0]    
__________________________________________________________________________________________________
conv2d_210 (Conv2D)             (None, 4, 4, 506)    512072      re_lu_206[0][0]                  
__________________________________________________________________________________________________
average_pooling2d_6 (AveragePoo (None, 2, 2, 506)    0           conv2d_210[0][0]                 
__________________________________________________________________________________________________
batch_normalization_207 (BatchN (None, 2, 2, 506)    2024        average_pooling2d_6[0][0]        
__________________________________________________________________________________________________
re_lu_207 (ReLU)                (None, 2, 2, 506)    0           batch_normalization_207[0][0]    
__________________________________________________________________________________________________
conv2d_211 (Conv2D)             (None, 2, 2, 128)    64768       re_lu_207[0][0]                  
__________________________________________________________________________________________________
batch_normalization_208 (BatchN (None, 2, 2, 128)    512         conv2d_211[0][0]                 
__________________________________________________________________________________________________
re_lu_208 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_208[0][0]    
__________________________________________________________________________________________________
conv2d_212 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_208[0][0]                  
__________________________________________________________________________________________________
concatenate_100 (Concatenate)   (None, 2, 2, 538)    0           average_pooling2d_6[0][0]        
                                                                 conv2d_212[0][0]                 
__________________________________________________________________________________________________
batch_normalization_209 (BatchN (None, 2, 2, 538)    2152        concatenate_100[0][0]            
__________________________________________________________________________________________________
re_lu_209 (ReLU)                (None, 2, 2, 538)    0           batch_normalization_209[0][0]    
__________________________________________________________________________________________________
conv2d_213 (Conv2D)             (None, 2, 2, 128)    68864       re_lu_209[0][0]                  
__________________________________________________________________________________________________
batch_normalization_210 (BatchN (None, 2, 2, 128)    512         conv2d_213[0][0]                 
__________________________________________________________________________________________________
re_lu_210 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_210[0][0]    
__________________________________________________________________________________________________
conv2d_214 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_210[0][0]                  
__________________________________________________________________________________________________
concatenate_101 (Concatenate)   (None, 2, 2, 570)    0           concatenate_100[0][0]            
                                                                 conv2d_214[0][0]                 
__________________________________________________________________________________________________
batch_normalization_211 (BatchN (None, 2, 2, 570)    2280        concatenate_101[0][0]            
__________________________________________________________________________________________________
re_lu_211 (ReLU)                (None, 2, 2, 570)    0           batch_normalization_211[0][0]    
__________________________________________________________________________________________________
conv2d_215 (Conv2D)             (None, 2, 2, 128)    72960       re_lu_211[0][0]                  
__________________________________________________________________________________________________
batch_normalization_212 (BatchN (None, 2, 2, 128)    512         conv2d_215[0][0]                 
__________________________________________________________________________________________________
re_lu_212 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_212[0][0]    
__________________________________________________________________________________________________
conv2d_216 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_212[0][0]                  
__________________________________________________________________________________________________
concatenate_102 (Concatenate)   (None, 2, 2, 602)    0           concatenate_101[0][0]            
                                                                 conv2d_216[0][0]                 
__________________________________________________________________________________________________
batch_normalization_213 (BatchN (None, 2, 2, 602)    2408        concatenate_102[0][0]            
__________________________________________________________________________________________________
re_lu_213 (ReLU)                (None, 2, 2, 602)    0           batch_normalization_213[0][0]    
__________________________________________________________________________________________________
conv2d_217 (Conv2D)             (None, 2, 2, 128)    77056       re_lu_213[0][0]                  
__________________________________________________________________________________________________
batch_normalization_214 (BatchN (None, 2, 2, 128)    512         conv2d_217[0][0]                 
__________________________________________________________________________________________________
re_lu_214 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_214[0][0]    
__________________________________________________________________________________________________
conv2d_218 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_214[0][0]                  
__________________________________________________________________________________________________
concatenate_103 (Concatenate)   (None, 2, 2, 634)    0           concatenate_102[0][0]            
                                                                 conv2d_218[0][0]                 
__________________________________________________________________________________________________
batch_normalization_215 (BatchN (None, 2, 2, 634)    2536        concatenate_103[0][0]            
__________________________________________________________________________________________________
re_lu_215 (ReLU)                (None, 2, 2, 634)    0           batch_normalization_215[0][0]    
__________________________________________________________________________________________________
conv2d_219 (Conv2D)             (None, 2, 2, 128)    81152       re_lu_215[0][0]                  
__________________________________________________________________________________________________
batch_normalization_216 (BatchN (None, 2, 2, 128)    512         conv2d_219[0][0]                 
__________________________________________________________________________________________________
re_lu_216 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_216[0][0]    
__________________________________________________________________________________________________
conv2d_220 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_216[0][0]                  
__________________________________________________________________________________________________
concatenate_104 (Concatenate)   (None, 2, 2, 666)    0           concatenate_103[0][0]            
                                                                 conv2d_220[0][0]                 
__________________________________________________________________________________________________
batch_normalization_217 (BatchN (None, 2, 2, 666)    2664        concatenate_104[0][0]            
__________________________________________________________________________________________________
re_lu_217 (ReLU)                (None, 2, 2, 666)    0           batch_normalization_217[0][0]    
__________________________________________________________________________________________________
conv2d_221 (Conv2D)             (None, 2, 2, 128)    85248       re_lu_217[0][0]                  
__________________________________________________________________________________________________
batch_normalization_218 (BatchN (None, 2, 2, 128)    512         conv2d_221[0][0]                 
__________________________________________________________________________________________________
re_lu_218 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_218[0][0]    
__________________________________________________________________________________________________
conv2d_222 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_218[0][0]                  
__________________________________________________________________________________________________
concatenate_105 (Concatenate)   (None, 2, 2, 698)    0           concatenate_104[0][0]            
                                                                 conv2d_222[0][0]                 
__________________________________________________________________________________________________
batch_normalization_219 (BatchN (None, 2, 2, 698)    2792        concatenate_105[0][0]            
__________________________________________________________________________________________________
re_lu_219 (ReLU)                (None, 2, 2, 698)    0           batch_normalization_219[0][0]    
__________________________________________________________________________________________________
conv2d_223 (Conv2D)             (None, 2, 2, 128)    89344       re_lu_219[0][0]                  
__________________________________________________________________________________________________
batch_normalization_220 (BatchN (None, 2, 2, 128)    512         conv2d_223[0][0]                 
__________________________________________________________________________________________________
re_lu_220 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_220[0][0]    
__________________________________________________________________________________________________
conv2d_224 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_220[0][0]                  
__________________________________________________________________________________________________
concatenate_106 (Concatenate)   (None, 2, 2, 730)    0           concatenate_105[0][0]            
                                                                 conv2d_224[0][0]                 
__________________________________________________________________________________________________
batch_normalization_221 (BatchN (None, 2, 2, 730)    2920        concatenate_106[0][0]            
__________________________________________________________________________________________________
re_lu_221 (ReLU)                (None, 2, 2, 730)    0           batch_normalization_221[0][0]    
__________________________________________________________________________________________________
conv2d_225 (Conv2D)             (None, 2, 2, 128)    93440       re_lu_221[0][0]                  
__________________________________________________________________________________________________
batch_normalization_222 (BatchN (None, 2, 2, 128)    512         conv2d_225[0][0]                 
__________________________________________________________________________________________________
re_lu_222 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_222[0][0]    
__________________________________________________________________________________________________
conv2d_226 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_222[0][0]                  
__________________________________________________________________________________________________
concatenate_107 (Concatenate)   (None, 2, 2, 762)    0           concatenate_106[0][0]            
                                                                 conv2d_226[0][0]                 
__________________________________________________________________________________________________
batch_normalization_223 (BatchN (None, 2, 2, 762)    3048        concatenate_107[0][0]            
__________________________________________________________________________________________________
re_lu_223 (ReLU)                (None, 2, 2, 762)    0           batch_normalization_223[0][0]    
__________________________________________________________________________________________________
conv2d_227 (Conv2D)             (None, 2, 2, 128)    97536       re_lu_223[0][0]                  
__________________________________________________________________________________________________
batch_normalization_224 (BatchN (None, 2, 2, 128)    512         conv2d_227[0][0]                 
__________________________________________________________________________________________________
re_lu_224 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_224[0][0]    
__________________________________________________________________________________________________
conv2d_228 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_224[0][0]                  
__________________________________________________________________________________________________
concatenate_108 (Concatenate)   (None, 2, 2, 794)    0           concatenate_107[0][0]            
                                                                 conv2d_228[0][0]                 
__________________________________________________________________________________________________
batch_normalization_225 (BatchN (None, 2, 2, 794)    3176        concatenate_108[0][0]            
__________________________________________________________________________________________________
re_lu_225 (ReLU)                (None, 2, 2, 794)    0           batch_normalization_225[0][0]    
__________________________________________________________________________________________________
conv2d_229 (Conv2D)             (None, 2, 2, 128)    101632      re_lu_225[0][0]                  
__________________________________________________________________________________________________
batch_normalization_226 (BatchN (None, 2, 2, 128)    512         conv2d_229[0][0]                 
__________________________________________________________________________________________________
re_lu_226 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_226[0][0]    
__________________________________________________________________________________________________
conv2d_230 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_226[0][0]                  
__________________________________________________________________________________________________
concatenate_109 (Concatenate)   (None, 2, 2, 826)    0           concatenate_108[0][0]            
                                                                 conv2d_230[0][0]                 
__________________________________________________________________________________________________
batch_normalization_227 (BatchN (None, 2, 2, 826)    3304        concatenate_109[0][0]            
__________________________________________________________________________________________________
re_lu_227 (ReLU)                (None, 2, 2, 826)    0           batch_normalization_227[0][0]    
__________________________________________________________________________________________________
conv2d_231 (Conv2D)             (None, 2, 2, 128)    105728      re_lu_227[0][0]                  
__________________________________________________________________________________________________
batch_normalization_228 (BatchN (None, 2, 2, 128)    512         conv2d_231[0][0]                 
__________________________________________________________________________________________________
re_lu_228 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_228[0][0]    
__________________________________________________________________________________________________
conv2d_232 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_228[0][0]                  
__________________________________________________________________________________________________
concatenate_110 (Concatenate)   (None, 2, 2, 858)    0           concatenate_109[0][0]            
                                                                 conv2d_232[0][0]                 
__________________________________________________________________________________________________
batch_normalization_229 (BatchN (None, 2, 2, 858)    3432        concatenate_110[0][0]            
__________________________________________________________________________________________________
re_lu_229 (ReLU)                (None, 2, 2, 858)    0           batch_normalization_229[0][0]    
__________________________________________________________________________________________________
conv2d_233 (Conv2D)             (None, 2, 2, 128)    109824      re_lu_229[0][0]                  
__________________________________________________________________________________________________
batch_normalization_230 (BatchN (None, 2, 2, 128)    512         conv2d_233[0][0]                 
__________________________________________________________________________________________________
re_lu_230 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_230[0][0]    
__________________________________________________________________________________________________
conv2d_234 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_230[0][0]                  
__________________________________________________________________________________________________
concatenate_111 (Concatenate)   (None, 2, 2, 890)    0           concatenate_110[0][0]            
                                                                 conv2d_234[0][0]                 
__________________________________________________________________________________________________
batch_normalization_231 (BatchN (None, 2, 2, 890)    3560        concatenate_111[0][0]            
__________________________________________________________________________________________________
re_lu_231 (ReLU)                (None, 2, 2, 890)    0           batch_normalization_231[0][0]    
__________________________________________________________________________________________________
conv2d_235 (Conv2D)             (None, 2, 2, 128)    113920      re_lu_231[0][0]                  
__________________________________________________________________________________________________
batch_normalization_232 (BatchN (None, 2, 2, 128)    512         conv2d_235[0][0]                 
__________________________________________________________________________________________________
re_lu_232 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_232[0][0]    
__________________________________________________________________________________________________
conv2d_236 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_232[0][0]                  
__________________________________________________________________________________________________
concatenate_112 (Concatenate)   (None, 2, 2, 922)    0           concatenate_111[0][0]            
                                                                 conv2d_236[0][0]                 
__________________________________________________________________________________________________
batch_normalization_233 (BatchN (None, 2, 2, 922)    3688        concatenate_112[0][0]            
__________________________________________________________________________________________________
re_lu_233 (ReLU)                (None, 2, 2, 922)    0           batch_normalization_233[0][0]    
__________________________________________________________________________________________________
conv2d_237 (Conv2D)             (None, 2, 2, 128)    118016      re_lu_233[0][0]                  
__________________________________________________________________________________________________
batch_normalization_234 (BatchN (None, 2, 2, 128)    512         conv2d_237[0][0]                 
__________________________________________________________________________________________________
re_lu_234 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_234[0][0]    
__________________________________________________________________________________________________
conv2d_238 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_234[0][0]                  
__________________________________________________________________________________________________
concatenate_113 (Concatenate)   (None, 2, 2, 954)    0           concatenate_112[0][0]            
                                                                 conv2d_238[0][0]                 
__________________________________________________________________________________________________
batch_normalization_235 (BatchN (None, 2, 2, 954)    3816        concatenate_113[0][0]            
__________________________________________________________________________________________________
re_lu_235 (ReLU)                (None, 2, 2, 954)    0           batch_normalization_235[0][0]    
__________________________________________________________________________________________________
conv2d_239 (Conv2D)             (None, 2, 2, 128)    122112      re_lu_235[0][0]                  
__________________________________________________________________________________________________
batch_normalization_236 (BatchN (None, 2, 2, 128)    512         conv2d_239[0][0]                 
__________________________________________________________________________________________________
re_lu_236 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_236[0][0]    
__________________________________________________________________________________________________
conv2d_240 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_236[0][0]                  
__________________________________________________________________________________________________
concatenate_114 (Concatenate)   (None, 2, 2, 986)    0           concatenate_113[0][0]            
                                                                 conv2d_240[0][0]                 
__________________________________________________________________________________________________
batch_normalization_237 (BatchN (None, 2, 2, 986)    3944        concatenate_114[0][0]            
__________________________________________________________________________________________________
re_lu_237 (ReLU)                (None, 2, 2, 986)    0           batch_normalization_237[0][0]    
__________________________________________________________________________________________________
conv2d_241 (Conv2D)             (None, 2, 2, 128)    126208      re_lu_237[0][0]                  
__________________________________________________________________________________________________
batch_normalization_238 (BatchN (None, 2, 2, 128)    512         conv2d_241[0][0]                 
__________________________________________________________________________________________________
re_lu_238 (ReLU)                (None, 2, 2, 128)    0           batch_normalization_238[0][0]    
__________________________________________________________________________________________________
conv2d_242 (Conv2D)             (None, 2, 2, 32)     36864       re_lu_238[0][0]                  
__________________________________________________________________________________________________
concatenate_115 (Concatenate)   (None, 2, 2, 1018)   0           concatenate_114[0][0]            
                                                                 conv2d_242[0][0]                 
__________________________________________________________________________________________________
global_average_pooling2d_1 (Glo (None, 1018)         0           concatenate_115[0][0]            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 100)          101900      global_average_pooling2d_1[0][0] 
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 100)          0           dense_1[0][0]                    
==================================================================================================
Total params: 6,965,604
Trainable params: 6,886,220
Non-trainable params: 79,384
__________________________________________________________________________________________________

model.compile(loss = 'categorical_crossentropy',
              optimizer=tf.keras.optimizers.Adam(lr=0.001),
              metrics=['accuracy'])
import time
start = time.time()
history = model.fit(ds_train,
                    epochs=20,
                    validation_data = ds_val)

print(f"It took {time.time() - start} seconds")
Epoch 1/20
586/586 [==============================] - 54s 93ms/step - loss: 3.1179 - accuracy: 0.2295 - val_loss: 3.1247 - val_accuracy: 0.2200
Epoch 2/20
586/586 [==============================] - 55s 93ms/step - loss: 2.7529 - accuracy: 0.2975 - val_loss: 9.7801 - val_accuracy: 0.1771
Epoch 3/20
586/586 [==============================] - 54s 93ms/step - loss: 2.4882 - accuracy: 0.3499 - val_loss: 2.8652 - val_accuracy: 0.2881
Epoch 4/20
586/586 [==============================] - 54s 92ms/step - loss: 2.3228 - accuracy: 0.3846 - val_loss: 2.9444 - val_accuracy: 0.2735
Epoch 5/20
586/586 [==============================] - 54s 92ms/step - loss: 2.1534 - accuracy: 0.4218 - val_loss: 5.3858 - val_accuracy: 0.3058
Epoch 6/20
586/586 [==============================] - 54s 92ms/step - loss: 1.9807 - accuracy: 0.4586 - val_loss: 2.4927 - val_accuracy: 0.3708
Epoch 7/20
586/586 [==============================] - 54s 92ms/step - loss: 1.8318 - accuracy: 0.4897 - val_loss: 2.4684 - val_accuracy: 0.3822
Epoch 8/20
586/586 [==============================] - 54s 92ms/step - loss: 1.7204 - accuracy: 0.5164 - val_loss: 2.2386 - val_accuracy: 0.4179
Epoch 9/20
586/586 [==============================] - 53s 91ms/step - loss: 1.5877 - accuracy: 0.5470 - val_loss: 2.6878 - val_accuracy: 0.3539
Epoch 10/20
586/586 [==============================] - 53s 91ms/step - loss: 1.4851 - accuracy: 0.5734 - val_loss: 2.2689 - val_accuracy: 0.4295
Epoch 11/20
586/586 [==============================] - 53s 91ms/step - loss: 1.3668 - accuracy: 0.6050 - val_loss: 2.2262 - val_accuracy: 0.4541
Epoch 12/20
586/586 [==============================] - 53s 91ms/step - loss: 1.2730 - accuracy: 0.6249 - val_loss: 2.0790 - val_accuracy: 0.4722
Epoch 13/20
586/586 [==============================] - 53s 91ms/step - loss: 1.1463 - accuracy: 0.6619 - val_loss: 2.1492 - val_accuracy: 0.4834
Epoch 14/20
586/586 [==============================] - 53s 91ms/step - loss: 1.0477 - accuracy: 0.6844 - val_loss: 2.5350 - val_accuracy: 0.4299
Epoch 15/20
586/586 [==============================] - 53s 91ms/step - loss: 0.9707 - accuracy: 0.7059 - val_loss: 2.0789 - val_accuracy: 0.5034
Epoch 16/20
586/586 [==============================] - 53s 90ms/step - loss: 0.8857 - accuracy: 0.7306 - val_loss: 2.1846 - val_accuracy: 0.4991
Epoch 17/20
586/586 [==============================] - 53s 91ms/step - loss: 0.7917 - accuracy: 0.7553 - val_loss: 4.4854 - val_accuracy: 0.4355
Epoch 18/20
586/586 [==============================] - 53s 91ms/step - loss: 0.7192 - accuracy: 0.7772 - val_loss: 2.3988 - val_accuracy: 0.4859
Epoch 19/20
586/586 [==============================] - 53s 91ms/step - loss: 0.6658 - accuracy: 0.7915 - val_loss: 2.4001 - val_accuracy: 0.4908
Epoch 20/20
586/586 [==============================] - 54s 92ms/step - loss: 0.6068 - accuracy: 0.8085 - val_loss: 3.2709 - val_accuracy: 0.3686
It took 1086.7563781738281 seconds

import pandas as pd
import matplotlib.pyplot as plt

pd.DataFrame(history.history).plot(figsize=(12,8))
plt.grid(True)
plt.gca().set_ylim(0,2)
plt.show()
model.evaluate(ds_test)
157/157 [==============================] - 5s 33ms/step - loss: 3.3091 - accuracy: 0.3688

[3.3090856075286865, 0.36880001425743103]

Customizing what happens in fit()

The hidden steps inside .fit()

class CustomModel(keras.Model):

  def train_step(self, data):
    x, y = data

    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)
      loss = self.compiled_loss(y, y_pred,
                                regularization_losses=self.losses)

    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    self.compiled_metrics.update_state(y, y_pred)
    return {m.name: m.result() for m in self.metrics}
  1. We retrieve the minibatch data:
  def train_step(self, data):
    x, y = data

The kind of data you get here obviously depends on the type of model you are training, and hence on the data you pass to .fit(). For instance, if sample weights were passed, each batch would be a 3-tuple; see the sketch below.
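A minimal sketch of the more general unpacking, assuming sample weights were passed to .fit() (they are not in this notebook; unpack_batch is just an illustrative helper):

def unpack_batch(data):
  # data is (x, y) in this notebook, but (x, y, sample_weight) when sample weights are used
  if len(data) == 3:
    x, y, sample_weight = data
  else:
    (x, y), sample_weight = data, None
  return x, y, sample_weight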

    with tf.GradientTape() as tape:
tf.GradientTape() is TensorFlow's mechanism for differentiating functions, i.e. computing derivatives and partial derivatives. And where there are partial derivatives, there are weight-update steps.

  2. We compute the prediction on the minibatch: $$(\hat{y}_{1}, \dots, \hat{y}_{N}) = (f(x_{1}), \dots, f(x_{N}))$$

          y_pred = self(x, training=True)
    

  3. For each \(\hat{y}_{i}\), we compute the error made, via the loss function \(\mathcal{L}_{\vartheta}(y_{i},\hat{y}_{i})\), and deduce from it the average error over the minibatch.

\[\mathcal{L}_{\vartheta} = \frac{1}{N}\sum_{i=1}^{N}\mathcal{L}_{\vartheta}(y_{i},\hat{y}_{i})\]
      loss = self.compiled_loss(y, y_pred,
                                regularization_losses=self.losses)

Recall that the loss function is defined in .compile().

  4. We then compute the gradient with respect to each parameter in \(\vartheta\), i.e.

\[\nabla \mathcal{L}_{\vartheta}.\]

    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)

  5. We then update the parameters:

$$ w_{i} \leftarrow w_{i} - \eta \frac{\partial \mathcal{L}_{\vartheta}}{\partial w_{i}}(\vartheta), \qquad b_{i} \leftarrow b_{i} - \eta \frac{\partial \mathcal{L}_{\vartheta}}{\partial b_{i}}(\vartheta).$$

    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

  6. Finally, we update the metrics (loss, accuracy, ...) and return a dictionary containing these updated values.
    self.compiled_metrics.update_state(y, y_pred)
    return {m.name: m.result() for m in self.metrics}

Let's take a look at how these different steps behave.

No need to run this on a real model with hundreds of layers and millions of parameters: for now we just want to see how it works. So let's create a completely naive dataset and model.

class CustomModel(keras.Model):

  def train_step(self, data):
    print()
    print(f"----Etape: {self.step_counter}")
    self.step_counter += 1

    x, y = data
    print(f'Start train step {x.shape}, {y.shape}')

    with tf.GradientTape() as tape:
      print(f'Start GradientTape step {x.shape}')
      y_pred = self(x, training=True)
      print(f'Prediction done {y_pred.shape}')
      loss = self.compiled_loss(y, y_pred,
                                regularization_losses=self.losses)
      print(f'loss {loss}')

    trainable_vars = self.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    self.compiled_metrics.update_state(y, y_pred)
    return {m.name: m.result() for m in self.metrics}
# Create a dummy dataset
t_x = tf.random.uniform([30, 4], dtype=tf.float32)
t_y = tf.range(30)

ds_x = tf.data.Dataset.from_tensor_slices(t_x)
ds_y = tf.data.Dataset.from_tensor_slices(t_y)

ds = tf.data.Dataset.zip((ds_x, ds_y))

ds = ds.batch(3)

from tensorflow.keras import Model
# Dummy model
input = Input(shape=(4,))

x = Dense(32)(input)
output = Dense(1)(x)

model = CustomModel(input, x)  # NB: built on x, so the Dense(1) head above is unused (cf. summary below)

model.compile(loss = 'mean_absolute_error',
              optimizer=tf.keras.optimizers.Adam(lr=0.001))
model.step_counter = 0
model.summary()
Model: "custom_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 4)]               0         
_________________________________________________________________
dense_2 (Dense)              (None, 32)                160       
=================================================================
Total params: 160
Trainable params: 160
Non-trainable params: 0
_________________________________________________________________

model.fit(ds,
          epochs=1,
          verbose=0)

----Step: 0
Start train step (None, 4), (None,)
Start GradientTape step (None, 4)
Prediction done (None, 32)
loss Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)

----Step: 1
Start train step (None, 4), (None,)
Start GradientTape step (None, 4)
Prediction done (None, 32)
loss Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)

Not very conclusive, is it? We don't even see all the steps go by. This is because Python is a slow language, so most of TensorFlow's internals are implemented in much faster languages such as C or C++: our train_step is compiled into a static graph, and the Python print calls only run while that graph is being traced (here, twice), not at every training step.

What we write is therefore not always what we actually get. To make tf.keras execute the instructions sequentially, we compile the model with the extra option:

run_eagerly=True
model.compile(loss = 'mean_absolute_error',
              optimizer=tf.keras.optimizers.Adam(lr=0.001),
              run_eagerly=True)
model.step_counter = 0
model.summary()

model.fit(ds,
          epochs=1,
          verbose=0)
Model: "custom_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 4)]               0         
_________________________________________________________________
dense_2 (Dense)              (None, 32)                160       
=================================================================
Total params: 160
Trainable params: 160
Non-trainable params: 0
_________________________________________________________________

----Step: 0
Start train step (3, 4), (3,)
Start GradientTape step (3, 4)
Prediction done (3, 32)
loss 1.0454397201538086

----Step: 1
Start train step (3, 4), (3,)
Start GradientTape step (3, 4)
Prediction done (3, 32)
loss 3.949005126953125

----Step: 2
Start train step (3, 4), (3,)
Start GradientTape step (3, 4)
Prediction done (3, 32)
loss 6.941404342651367

----Step: 3
Start train step (3, 4), (3,)
Start GradientTape step (3, 4)
Prediction done (3, 32)
loss 9.953435897827148

----Step: 4
Start train step (3, 4), (3,)
Start GradientTape step (3, 4)
Prediction done (3, 32)
loss 12.954936027526855

----Step: 5
Start train step (3, 4), (3,)
Start GradientTape step (3, 4)
Prediction done (3, 32)
loss 15.921150207519531

----Step: 6
Start train step (3, 4), (3,)
Start GradientTape step (3, 4)
Prediction done (3, 32)
loss 18.93268394470215

----Step: 7
Start train step (3, 4), (3,)
Start GradientTape step (3, 4)
Prediction done (3, 32)
loss 21.93629264831543

----Step: 8
Start train step (3, 4), (3,)
Start GradientTape step (3, 4)
Prediction done (3, 32)
loss 24.937978744506836

----Step: 9
Start train step (3, 4), (3,)
Start GradientTape step (3, 4)
Prediction done (3, 32)
loss 27.9443302154541

<tensorflow.python.keras.callbacks.History at 0x7f258a26f278>

Getting to grips with GradientTape

GradientTape records the operations performed inside it as a graph (see the module 1 reminder on how a function can be defined as a graph), in order to compute the differential of that function.

Let's take a simple example, the function \(f(x) = x^{2}\), and compute its derivative at 3. The classical formulas of differential calculus give \(f'(x) = 2x\), hence \(f'(3) = 6\).

With GradientTape, it goes like this.

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = x * x

dy_dx = tape.gradient(y, x)
dy_dx
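Here dy_dx evaluates to a tensor holding 6.0, matching \(f'(3) = 6\).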

In the tf.keras case, the tape.watch(x) part, which says with respect to which variable we differentiate, is not necessary: tf.keras knows perfectly well which parameters of the neural network we are training.
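Indeed, trainable tf.Variables are watched automatically; a minimal sketch (w is just an illustrative variable):

w = tf.Variable(3.0)
with tf.GradientTape() as tape:
  y = w * w          # no tape.watch(w) needed: w is a trainable Variable
dw = tape.gradient(y, w)
dw                   # tf.Tensor(6.0)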

We can of course nest tapes to compute second derivatives.

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)
  with tf.GradientTape() as tape2:
    tape2.watch(x)
    y = x * x
  dy_dx = tape2.gradient(y, x)     
d2y_dx2 = tape.gradient(dy_dx, x)
dy_dx
d2y_dx2
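Here dy_dx evaluates to 6.0 and d2y_dx2 to 2.0: the second derivative of \(x^{2}\) is the constant 2.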

We need to go deeper

In later modules we will see how to modify the .fit() method so that it matches the training procedure we want, for instance when training a variational autoencoder.

Note also that we can write the training loop entirely by hand, without going through .fit() at all. This is convenient, and even necessary, if we want to implement certain best practices when training models such as GANs.

epochs = ...
loss_fn = tf.keras.losses.[..]
metric_fn = tf.keras.metrics.[...]
optimizer = tf.keras.optimizers.[...]

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    # prediction on the minibatch
    y_pred = model(x, training=True)
    # average loss over the minibatch
    loss_value = loss_fn(y, y_pred)

  # gradient computation and backpropagation
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))

  # metric update
  metric_fn.update_state(y, y_pred)

  return loss_value

@tf.function
def test_step(x, y):
  y_pred = model(x, training=False)
  loss_value = loss_fn(y, y_pred)
  metric_fn.update_state(y, y_pred)

  return loss_value

for epoch in range(epochs):
  print(f"\nStart of epoch {epoch+1}")
  start_time = time.time()

  # iterate over the minibatches of the dataset
  for step, (x_batch_train, y_batch_train) in enumerate(ds):
    loss_value = train_step(x_batch_train, y_batch_train)

    # Log every 10 batches
    if step % 10 == 0:
      print(f"Loss on batch at step {step}: {float(loss_value):.4f}")

  # Display the metrics at the end of the epoch
  metric = metric_fn.result()
  print(f"Metric for the epoch: {float(metric):.4f} \n")

  # Reset the metric at the end of each epoch
  metric_fn.reset_states()

  # Validation loop at the end of each epoch
  for x_batch_val, y_batch_val in ds_val:
    val_loss = test_step(x_batch_val, y_batch_val)

  val_metric = metric_fn.result()
  metric_fn.reset_states()
  print()
  print(f"Validation loss: {float(val_loss):.4f}")
  print(f"Validation metric: {float(val_metric):.4f}")
  print(f"Epoch duration: {time.time() - start_time:.2f}s")

Let's go through the important parts.

  • We start a for loop to iterate over the epochs.
  • For each epoch, we open another for loop to iterate over the batches of the dataset.
  • For each batch, we open a GradientTape(), in which we compute the feedforward step.
  • Once done, we compute the gradient with respect to the model's weights and update those weights.

Let's look closer. Note that for this loop we already have access to the batches: our dataset is therefore already in tensor form, via the tf.data.Dataset API.

First, we fix the variables: the number of epochs and the various functions we will use.

epochs = ...
loss_fn = tf.keras.losses.[..]
metric_fn = tf.keras.metrics.[...]
optimizer = tf.keras.optimizers.[...]

We then start the main loop over the number of epochs.

for epoch in range(epochs):
  print(f"\nDébut de l'époque {epoch+1},")
  start_time = time.time()

For each batch, we then run the training step (we come back to it later) and print the loss, say, every 10 minibatches.

for step, (x_batch_train, y_batch_train) in enumerate(ds):
    loss_value = train_step(x_batch_train, y_batch_train)

    # Log every 10 batches
    if step % 10 == 0:
      print(f"Loss on batch at step {step}: {float(loss_value):.4f}")

Once all the minibatches have gone through, the epoch is over. We then display the average metric obtained at the end.

  # Display the metrics at the end of the epoch
  metric = metric_fn.result()
  print(f"Metric for the epoch: {float(metric):.4f} \n")

We reset the metric before the new epoch begins.

  # Reset the metric at the end of each epoch
  metric_fn.reset_states()

If we want a validation dataset, this is where it happens. As with train_step, we come back to it shortly.

  # Validation loop at the end of each epoch
  for x_batch_val, y_batch_val in ds_val:
    val_loss = test_step(x_batch_val, y_batch_val)

We display the validation metrics.

  val_metric = metric_fn.result()
  metric_fn.reset_states()
  print()
  print(f"Loss de validation : {float(val_loss):.4f}")
  print(f"Métrique de validation : {float(val_metric):.4f}")
  print(f"Durée de l'époque: {time.time() - start_time:.2fs}")

As explained above, tf.keras actually has two execution modes, and the default one is the so-called eager mode, in which the instructions of a hand-written function, like our def train_step here, are executed one after the other, which is slow.

The @tf.function decorator turns any function whose variables are only tensors into a static graph. There is no need to know more about these graphs; the only thing to remember is that this speeds up the operations performed inside the function.
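A quick way to see tracing in action (a minimal sketch; f is just an illustrative function):

@tf.function
def f(x):
  print('tracing')   # Python side effect: runs at trace time only
  return x * x

f(tf.constant(2.0))  # prints 'tracing' once, while the graph is built
f(tf.constant(3.0))  # nothing printed: the cached graph is reused

This is exactly why our first, non-eager fit() above only showed a couple of print statements.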

Apart from that, the train_step function is exactly the same as the one defined inside .fit(), and likewise for the validation function.

@tf.function
def train_step(x, y):
  with tf.GradientTape() as tape:
    # prediction on the minibatch
    y_pred = model(x, training=True)
    # average loss over the minibatch
    loss_value = loss_fn(y, y_pred)

  # gradient computation and backpropagation
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))

  # metric update
  metric_fn.update_state(y, y_pred)

  return loss_value
@tf.function
def test_step(x, y):
  y_pred = model(x, training=False)
  loss_value = loss_fn(y, y_pred)
  metric_fn.update_state(y, y_pred)

  return loss_value

Callbacks

In addition to the usual arguments such as epochs or validation_data, the .fit() method also accepts a callbacks argument. Callbacks let us give tf.keras a list of instructions to execute at specific moments during training:

  • at the start (or end) of the current epoch,
  • at the start (or end) of the current minibatch step,
  • at the start (or end) of training.

We already met the LearningRateScheduler callback in the previous module, when we discussed Learning Rate Decay. The idea here is to review the most important callbacks, then look at the internal structure of a callback so that we can write one ourselves.

The important callbacks

tf.keras.callbacks.EarlyStopping

EarlyStopping lets tf.keras stop training by itself. We pass it the metric we want to monitor, together with the stopping criteria.

We call it with the following parameters.

keras.callbacks.EarlyStopping(monitor='val_loss',
                              min_delta=1e-2,
                              patience=10,
                              verbose=1)

Let's break it down.

keras.callbacks.EarlyStopping(monitor='val_loss',
Sets the metric to monitor, most often val_loss, but val_acc or the training metrics can be monitored as well. Here, training stops once val_loss no longer improves.

                 min_delta=1e-2,

We make "no longer improving" precise: we say that val_loss no longer improves if

\[\mathrm{val\_loss}(t) - \mathrm{val\_loss}(t+1) \leq 10^{-2}.\]
                 patience=10,
                 verbose=1)

If it fails to improve for 10 epochs in a row, training is stopped.

tf.keras.callbacks.ModelCheckpoint

ModelCheckpoint, for its part, saves the model at regular intervals according to a metric criterion: with save_best_only=True, a new version of the model is written to disk only if the monitored metric has improved.

keras.callbacks.ModelCheckpoint('./weights.{epoch:02d}-{val_loss:.2f}.h5',
                                                 monitor='val_loss',
                                                 verbose=1,
                                                 save_best_only=True,
                                                 save_weights_only=False,
                                                 mode='auto',
                                                 save_freq='epoch')
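To reload a checkpoint saved this way, keras.models.load_model can be used; a sketch (the filename below is hypothetical, it depends on the epoch and val_loss at save time):

restored_model = keras.models.load_model('./weights.10-1.85.h5')  # hypothetical filename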

tf.keras.callbacks.LearningRateScheduler

LearningRateScheduler applies a user-defined schedule to the learning rate at the start of each epoch; here, an exponential decay that divides the learning rate by 10 every step epochs.

def exponential_decay(lr0,step):
    def exponential_decay_fn(epoch):
        return lr0*0.1**(epoch/step)
    return exponential_decay_fn

lr0 = 0.01  # example initial learning rate (not fixed earlier in the notebook)
exponential_decay_fn = exponential_decay(lr0 = lr0, step = 20)

keras.callbacks.LearningRateScheduler(exponential_decay_fn)
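A quick sanity check of the schedule (using the lr0 = 0.01 assumed above): the learning rate is divided by 10 every step epochs.

for epoch in (0, 20, 40):
  print(epoch, exponential_decay_fn(epoch))  # roughly 0.01, 0.001, 0.0001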

tf.keras.callbacks.TensorBoard

The TensorBoard callback writes logs (metrics, histograms, ...) to a directory so that training can be monitored in the TensorBoard interface.

import datetime

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
%tensorboard --logdir logs
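The callback is then passed to .fit() like any other; a sketch reusing the ds_train and ds_val datasets built earlier in this notebook:

model.fit(ds_train,
          epochs=5,
          validation_data=ds_val,
          callbacks=[tensorboard_callback])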
from tensorboard import notebook
notebook.list() # View open TensorBoard instances
#Control TensorBoard display. If no port is provided, 
#the most recently launched TensorBoard is used
notebook.display(port=6006, height=1000)
callbacks_fit = [keras.callbacks.EarlyStopping(
                 # Stop training when `val_loss` is no longer improving
                 monitor='val_loss',
                 # "no longer improving" being defined as "no better than 1e-2 less"
                 min_delta=1e-2,
                 # "no longer improving" being further defined as "for at least 2 epochs"
                 patience=10,
                 verbose=1),
                 keras.callbacks.LearningRateScheduler(exponential_decay_fn),
                 keras.callbacks.ModelCheckpoint('./weights.{epoch:02d}-.hdf5',
                                                 monitor='val_loss',
                                                 verbose=1,
                                                 save_best_only=True,
                                                 save_weights_only=False,
                                                 mode='auto',
                                                 save_freq='epoch')]
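The whole list can then be handed to .fit() in one go; a sketch, again assuming the ds_train and ds_val datasets from earlier and the exponential_decay_fn defined above:

history = model.fit(ds_train,
                    epochs=100,
                    validation_data=ds_val,
                    callbacks=callbacks_fit)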

Internal structure of a callback

Callbacks are all subclasses of the tf.keras.callbacks.Callback class.

They can be passed as a list when calling any of the following 3 methods.

   model.fit()
   model.evaluate()
   model.predict()

A callback attaches a particular action to a particular stage of the process; to define that stage, the following methods are available.

  • Global methods

    on_(train|test|predict)_begin(self, logs=None)
    
    on_(train|test|predict)_end(self, logs=None)
    

  • Batch-level methods

    on_(train|test|predict)_batch_begin(self, batch, logs=None)
    
    on_(train|test|predict)_batch_end(self, batch, logs=None)
    
    Here, logs is a dictionary containing the various metrics.

  • Epoch-level methods (training only)

    on_epoch_begin(self, epoch, logs=None)
    
    on_epoch_end(self, epoch, logs=None)
    

# Create a dummy dataset
t_x = tf.random.uniform([30, 4], dtype=tf.float32)
t_y = tf.range(30)

ds_x = tf.data.Dataset.from_tensor_slices(t_x)
ds_y = tf.data.Dataset.from_tensor_slices(t_y)

ds = tf.data.Dataset.zip((ds_x, ds_y))

ds = ds.batch(3)

from tensorflow.keras import Model
# Dummy model
input = Input(shape=(4,))

x = Dense(32)(input)
output = Dense(1)(x)

model = Model(input, x)  # NB: as before, built on x; the Dense(1) head is unused

model.compile(loss = 'mean_absolute_error',
              optimizer=tf.keras.optimizers.Adam(lr=0.001),
              metrics=['mean_absolute_error'])

model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 4)]               0         
_________________________________________________________________
dense (Dense)                (None, 32)                160       
=================================================================
Total params: 160
Trainable params: 160
Non-trainable params: 0
_________________________________________________________________

class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print(f"Start of training, the log keys are: {keys}")

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print(f"End of training, the log keys are: {keys}")

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print(f"Start of epoch {epoch}, the log keys are: {keys}")

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print(f"End of epoch {epoch}, the log keys are: {keys}")

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print(f"Training: start of batch {batch}, the log keys are: {keys}")

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print(f"Training: end of batch {batch}, the log keys are: {keys}")
history = model.fit(ds,
                    epochs=1,
                    verbose=0,
                    callbacks=[CustomCallback()])
Start of training; log keys: []
Start of epoch 0; log keys: []
Training: start of batch 0; log keys: []
Training: end of batch 0; log keys: ['loss', 'mean_absolute_error']
Training: start of batch 1; log keys: []
Training: end of batch 1; log keys: ['loss', 'mean_absolute_error']
Training: start of batch 2; log keys: []
Training: end of batch 2; log keys: ['loss', 'mean_absolute_error']
Training: start of batch 3; log keys: []
Training: end of batch 3; log keys: ['loss', 'mean_absolute_error']
Training: start of batch 4; log keys: []
Training: end of batch 4; log keys: ['loss', 'mean_absolute_error']
Training: start of batch 5; log keys: []
Training: end of batch 5; log keys: ['loss', 'mean_absolute_error']
Training: start of batch 6; log keys: []
Training: end of batch 6; log keys: ['loss', 'mean_absolute_error']
Training: start of batch 7; log keys: []
Training: end of batch 7; log keys: ['loss', 'mean_absolute_error']
Training: start of batch 8; log keys: []
Training: end of batch 8; log keys: ['loss', 'mean_absolute_error']
Training: start of batch 9; log keys: []
Training: end of batch 9; log keys: ['loss', 'mean_absolute_error']
End of epoch 0; log keys: ['loss', 'mean_absolute_error']
End of training; log keys: []

Let's go into more detail and look, for example, at how tf.keras computes the metrics it reports.

class LossCallback(tf.keras.callbacks.Callback):

    def on_train_batch_end(self, batch, logs):
        print(f'Batch {batch}, the loss is {logs["loss"]:.2f}.\n')

    def on_epoch_end(self, epoch, logs):
        print(f'The mean loss for epoch {epoch} is {logs["loss"]:.2f} \n')

cb = LossCallback()
history = model.fit(ds,
                    epochs=1,
                    verbose=0,
                    callbacks=[cb])
Batch 0, the loss is 1.01.

Batch 1, the loss is 2.46.

Batch 2, the loss is 3.93.

Batch 3, the loss is 5.41.

Batch 4, the loss is 6.90.

Batch 5, the loss is 8.40.

Batch 6, the loss is 9.89.

Batch 7, the loss is 11.39.

Batch 8, the loss is 12.89.

Batch 9, the loss is 14.39.

The mean loss for epoch 0 is 14.39 


tf.keras therefore reports the loss as a running average: at the end of each batch, the logged loss is the mean of the per-batch losses seen so far in the epoch, and the value reported at the end of the epoch is that same average over all of the epoch's batches (which is why it coincides with the last batch's value above).

What if we instead want the mean loss per observation, or the loss value on each individual batch? We can invert the running-average recursion:

\[\mathrm{AvgLoss}_{n+1} := \frac{\mathrm{Loss}_{n+1}+n\cdot \mathrm{AvgLoss}_{n}}{n+1}\]
\[(n+1)\cdot\mathrm{AvgLoss}_{n+1} -n\cdot \mathrm{AvgLoss}_{n} = \mathrm{Loss}_{n+1}\]
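
As a quick sanity check with hypothetical numbers: taking \(n = 1\), \(\mathrm{AvgLoss}_1 = 1.0\) and \(\mathrm{AvgLoss}_2 = 2.5\), the loss on the second batch alone is

\[\mathrm{Loss}_2 = 2\cdot\mathrm{AvgLoss}_2 - 1\cdot\mathrm{AvgLoss}_1 = 2\cdot 2.5 - 1.0 = 4.0\]

The callback below applies exactly this inversion, with \(n+1 = \mathrm{batch}+1\) since Keras numbers batches from 0.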
class LossCallback(tf.keras.callbacks.Callback):

    def __init__(self, L=0):
        super().__init__()
        # Cumulative sum of per-batch losses seen so far: (batch+1) * running average
        self.L = L

    def on_train_batch_end(self, batch, logs):
        if batch == 0:
            print(f'Batch {batch}, loss is {logs["loss"]:.2f}.\n')
            # logs['loss'] gives the running average, not the loss of the current minibatch
        else:
            print(f'Batch {batch}, loss is {(batch+1)*logs["loss"]-self.L:.2f}.\n')
            print(f'Batch {batch}, running loss is {logs["loss"]:.2f}.\n')
        self.L = (batch+1)*logs["loss"]

    def on_epoch_end(self, epoch, logs):
        print(f'Avg loss on {epoch} is {logs["loss"]:.2f} \n')
cb = LossCallback()

model.fit(ds,
          epochs=1,
          verbose=0,
          callbacks=[cb])
Batch 0, loss is 0.95.

Batch 1, loss is 3.78.

Batch 1, running loss is 2.36.

Batch 2, loss is 6.73.

Batch 2, running loss is 3.82.

Batch 3, loss is 9.71.

Batch 3, running loss is 5.29.

Batch 4, loss is 12.72.

Batch 4, running loss is 6.78.

Batch 5, loss is 15.75.

Batch 5, running loss is 8.27.

Batch 6, loss is 18.74.

Batch 6, running loss is 9.77.

Batch 7, loss is 21.78.

Batch 7, running loss is 11.27.

Batch 8, loss is 24.76.

Batch 8, running loss is 12.77.

Batch 9, loss is 27.69.

Batch 9, running loss is 14.26.

Avg loss on 0 is 14.26 


<tensorflow.python.keras.callbacks.History at 0x7f258a3c1da0>
class PrintValTrainRatioCallback(tf.keras.callbacks.Callback):

  def on_epoch_end(self, epoch, logs):
    print(f'\n Validation-Train Ratio : {logs["val_loss"]/logs["loss"]:.2f}')

cb = PrintValTrainRatioCallback()
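
This callback reads logs["val_loss"], which Keras only fills in when validation data is passed to fit(). A minimal usage sketch, assuming the dummy model and dataset defined above (val_ds is a placeholder validation set introduced here purely for illustration):

# Hypothetical usage: 'val_loss' only appears in logs when
# validation data is passed to model.fit()
val_ds = ds.take(2)  # placeholder validation set for the sketch

history = model.fit(ds,
                    validation_data=val_ds,
                    epochs=1,
                    verbose=0,
                    callbacks=[cb])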

Appendix

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Dense

import random

print(tf.__version__)
print(keras.__version__)

RANDOM_SEED = 42


random.seed(RANDOM_SEED)
tf.random.set_seed(RANDOM_SEED)



#dummy dataset
t_x = tf.random.uniform([30, 4], dtype=tf.float32)
t_y = tf.range(30)
ds_x = tf.data.Dataset.from_tensor_slices(t_x)
ds_y = tf.data.Dataset.from_tensor_slices(t_y)
ds = tf.data.Dataset.zip((ds_x, ds_y))
ds = ds.batch(3)

class LossCallback(tf.keras.callbacks.Callback):

    def on_train_batch_end(self, batch, logs):
        print(f'Batch {batch}, loss is {logs["loss"]:.2f}.\n')
        # logs['loss'] gives the running average, not the loss of the current minibatch

    def on_epoch_end(self, epoch, logs):
        print(f'Avg loss on {epoch} is {logs["loss"]:.2f} \n')

cb = LossCallback()

from tensorflow.keras import Model

inputs = Input(shape=(4,))

# Dummy model: a single Dense(2) layer; the model's output is x
x = Dense(2)(inputs)

model = Model(inputs, x)

model.compile(loss = 'mean_absolute_error',
              optimizer=tf.keras.optimizers.SGD())

history = model.fit(ds,
                    epochs=1,
                    verbose=0,
                    callbacks=[cb])
2.2.0
2.3.0-tf
Batch 0, loss is 0.83.

Batch 1, loss is 2.00.

Batch 2, loss is 3.40.

Batch 3, loss is 4.83.

Batch 4, loss is 6.27.

Batch 5, loss is 7.75.

Batch 6, loss is 9.21.

Batch 7, loss is 10.71.

Batch 8, loss is 12.20.

Batch 9, loss is 13.67.

Avg loss on 0 is 13.67