Aller au contenu

Fusion Conv-BN et RepVGG

1
2
3
4
5
6
7
8
9
import tensorflow as tf
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import GlobalAvgPool2D, Flatten, ReLU, Softmax, Dense
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Add
import numpy as np
from tensorflow.keras.models import Model
from tensorflow.keras import backend
img_shape = 32, 32, 3

Etude des poids des Conv \(3 \times 3\) et des BatchNorm

Regardons comment s'articulent les poids dans les couches de convolutions et de batchnormalisation.

Initialisation d'un modèle

1
2
3
4
5
input = Input(img_shape)
x= Conv2D(filters = 16, kernel_size=3, padding='same', use_bias=True, kernel_initializer='he_uniform', name='testing_conv_init')(input)
x= BatchNormalization(name=f'testing_bn_init')(x)
model = Model(input, x)
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
testing_conv_init (Conv2D)   (None, 32, 32, 16)        448       
_________________________________________________________________
testing_bn_init (BatchNormal (None, 32, 32, 16)        64        
=================================================================
Total params: 512
Trainable params: 480
Non-trainable params: 32
_________________________________________________________________

Etude de la couche convolutive

weights_conv = model.get_layer("testing_conv_init").get_weights()
weights_conv
[array([[[[-0.09736294,  0.01110744,  0.38817808,  0.02365676,
            0.37265095,  0.22067323,  0.44893858, -0.4457162 ,
            0.2723634 ,  0.21101996, -0.42767352,  0.39105812,
           -0.38641602, -0.39619464, -0.12856498,  0.00230291],
          [-0.15622476,  0.08576027, -0.39533868,  0.14336786,
            0.09569708,  0.05594608,  0.05045763, -0.15595007,
           -0.05612397, -0.19001147,  0.2724487 ,  0.3459774 ,
            0.01586419,  0.08192965,  0.32559904,  0.04557905],
          [-0.37503266,  0.05977681, -0.05365878,  0.34279034,
           -0.22699383, -0.20862746,  0.13931164, -0.20776296,
            0.12117836,  0.06501201, -0.28448707,  0.2668244 ,
            0.2704514 , -0.34608564, -0.35193035, -0.3525829 ]],

         [[ 0.11737838,  0.14824674, -0.00563487,  0.3061553 ,
            0.01097953,  0.23561516, -0.4535907 , -0.4175086 ,
            0.4607807 , -0.37212858, -0.30806714,  0.14160755,
            0.24837866,  0.12601265,  0.2622976 ,  0.3263745 ],
          [-0.34101892,  0.31189194, -0.11391068,  0.14759234,
           -0.30657652, -0.13771534,  0.45230624,  0.22417751,
            0.0407432 ,  0.07712099, -0.2991236 , -0.1449846 ,
           -0.00576615,  0.26184592, -0.2169587 , -0.2828556 ],
          [-0.20167324, -0.1357314 , -0.29285318,  0.33294716,
            0.23413381, -0.00896761, -0.31519616,  0.14762929,
           -0.18309683, -0.32602823,  0.10732868, -0.15018934,
           -0.27001262,  0.0079852 , -0.20946455, -0.10388163]],

         [[ 0.2783232 ,  0.05878171, -0.35913152,  0.40182415,
            0.092141  , -0.13399044,  0.03964153,  0.06443217,
            0.46447167, -0.41376618,  0.10037312,  0.42862818,
            0.17965165, -0.13864604,  0.04373595, -0.13044247],
          [-0.45564812, -0.33004287,  0.3405182 ,  0.39920047,
           -0.25907567,  0.15346971, -0.02812734, -0.19346526,
           -0.12104025, -0.4278583 ,  0.23774037,  0.3617756 ,
           -0.0419741 ,  0.09765604, -0.26489764, -0.43987206],
          [ 0.07685599,  0.04646841, -0.06539023,  0.02657837,
            0.44154963,  0.21664849,  0.4601402 ,  0.21062568,
            0.34529766,  0.30855897,  0.40019926, -0.26326853,
           -0.04224968, -0.46108606, -0.37318596, -0.42175823]]],


        [[[-0.3455502 ,  0.46076384, -0.15374777,  0.29009756,
           -0.1667389 ,  0.31607136,  0.26514402,  0.3037739 ,
           -0.0812335 ,  0.23732015,  0.07654181, -0.0679315 ,
           -0.31959674, -0.28597334, -0.10885993,  0.408551  ],
          [ 0.14297041, -0.09827566,  0.40382537, -0.4178524 ,
           -0.01861295, -0.35064167,  0.45108077, -0.12225789,
           -0.29017037,  0.23471412,  0.14587566, -0.36702543,
           -0.3979674 ,  0.29743996, -0.39642572, -0.0245271 ],
          [-0.0504581 ,  0.46684763,  0.37567046,  0.20170256,
           -0.36068565, -0.17110959,  0.30896595,  0.09373525,
           -0.21380827, -0.395913  ,  0.24522027,  0.36945108,
           -0.06558827, -0.4528262 ,  0.08612862,  0.3769413 ]],

         [[-0.06688267,  0.11741361, -0.14751217, -0.01143798,
            0.30352488,  0.23946908,  0.15358987, -0.11050937,
           -0.05894232, -0.22810066,  0.3403059 ,  0.23961261,
           -0.16434652,  0.11093548,  0.00398877, -0.11645409],
          [-0.37574512,  0.3296124 , -0.05067682, -0.09595928,
           -0.2214837 ,  0.35080925, -0.02972579,  0.1532239 ,
            0.18751535, -0.02314931, -0.18395177, -0.03616032,
            0.27149966, -0.4180094 ,  0.28113106,  0.4416559 ],
          [-0.2833048 ,  0.03007612,  0.34248617, -0.24044934,
            0.16455218,  0.15701613,  0.20851544, -0.25038487,
           -0.21328565, -0.23982422,  0.37252668,  0.15518335,
            0.28911796,  0.44327244,  0.14157644,  0.26580074]],

         [[ 0.28428563, -0.18229985, -0.20275667,  0.1955097 ,
           -0.3543805 , -0.14616191,  0.28929475,  0.14749238,
            0.2464734 , -0.46400836, -0.02009585,  0.3317866 ,
           -0.20362425, -0.42040807, -0.17440939, -0.01986095],
          [ 0.44240263,  0.1606206 , -0.4160647 ,  0.27185783,
            0.06790957, -0.32414603, -0.23126818, -0.29003426,
            0.12488225,  0.16395304, -0.3259102 , -0.38798535,
           -0.43922648, -0.3790248 ,  0.12847778,  0.13634267],
          [-0.02119684,  0.28930375, -0.4681574 ,  0.28300878,
            0.40201846,  0.1442084 ,  0.19728747, -0.02722415,
            0.42741737, -0.21063939,  0.42403385,  0.3592575 ,
           -0.20373034, -0.4468028 , -0.08656737, -0.20087367]]],


        [[[ 0.02782702,  0.01967528,  0.07292607,  0.389165  ,
            0.36987802,  0.18188533,  0.04504755,  0.41210625,
            0.01592982, -0.11140019, -0.11948159,  0.10680953,
           -0.11324027,  0.39144352, -0.35009858,  0.26030532],
          [ 0.09090939, -0.09534436,  0.03001484, -0.1512312 ,
           -0.3052565 ,  0.3308485 , -0.24254534, -0.10194659,
            0.00120181,  0.38111654,  0.21856287,  0.32130632,
           -0.33506405,  0.44324157,  0.32223514, -0.2681678 ],
          [ 0.4114515 , -0.40250486,  0.3429939 , -0.2685476 ,
            0.37336466,  0.18623039, -0.23775014,  0.18628749,
           -0.07814869,  0.15818772, -0.21979165, -0.44143543,
           -0.09982735, -0.17906868, -0.11419335, -0.10601136]],

         [[-0.04968157,  0.33098873, -0.35580167,  0.19270661,
            0.33699194,  0.41285995,  0.23392567,  0.4191365 ,
           -0.04616675,  0.22136435, -0.29591143,  0.20710972,
           -0.3502412 , -0.10524371, -0.32441103,  0.35304728],
          [-0.45383817, -0.21678528,  0.13484439, -0.23379093,
           -0.038737  ,  0.02832112, -0.36996582, -0.19088644,
            0.32281145,  0.14254984,  0.24848273,  0.3904991 ,
           -0.41010976, -0.4470079 , -0.13003865, -0.41824162],
          [ 0.14215276,  0.29101995, -0.40439036, -0.43802816,
            0.33089224, -0.34787324,  0.32901898, -0.11648187,
           -0.07495171,  0.4408585 ,  0.24113259, -0.08504415,
            0.0623028 ,  0.23243883, -0.19044554,  0.3553057 ]],

         [[-0.09105429,  0.30605808,  0.41014054,  0.33675548,
           -0.05541334,  0.1684458 , -0.12973395, -0.1533834 ,
           -0.00280687, -0.07284549,  0.35922107,  0.09879085,
            0.43200687,  0.02035648,  0.36917636, -0.4091605 ],
          [-0.22962008,  0.13721845, -0.12232229, -0.2520702 ,
            0.07350466,  0.042968  ,  0.31936017,  0.36257067,
           -0.10858962, -0.18460804, -0.1590521 , -0.47047156,
           -0.1503287 , -0.04001006, -0.11763006, -0.24719194],
          [-0.32062727,  0.13543853,  0.0306963 , -0.16442779,
           -0.37330386,  0.13838956,  0.03568128, -0.2321063 ,
            0.07501194, -0.426333  ,  0.00162551,  0.04697415,
           -0.20748636, -0.30767232,  0.43708238, -0.16146705]]]],
       dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       dtype=float32)]

Les poids forment une liste de deux élements : les poids des noyaux de convolutions et les biais. La méthode d'initalisation utilisée ici est he_uniform, développée dans l'article Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification

type(weights_conv)
list
len(weights_conv)
2

Les poids dans une couche convolutive sont une liste de deux éléments : - weights[0] correspond aux poids des noyaux de convolution, - weights[1] correspond aux biais.

type(weights_conv[0])
numpy.ndarray
weights_conv[0].shape 
(3, 3, 3, 16)

Les axes du tenseur de poids suivent les dimensions suivantes :

  • kernel_size1 : hauteur du kernel,
  • kernel_size2 : largeur du kernel,
  • channels_in : nombre des feature maps en entrée,
  • channels_out : nombres de features maps (filters) en sortie.

channels_out est définie dans la couche convolutive via le paramètre filters, alors que la valeur channels_in est elle directement déterminée par le tenseur en entrée. C'est une différence de TensorFlow par rapport à Pytorch où channels_in et channels_out sont tous les deux des paramètres des couches convolutives.

Ainsi, si l'on veut voir les poids du noyau de convolution par rapport au canal \(0\) en la feature map de sortie \(5\), on les obtient en regardant :

weights_conv[0][:,:,0,5]
array([[ 0.22067323,  0.23561516, -0.13399044],
       [ 0.31607136,  0.23946908, -0.14616191],
       [ 0.18188533,  0.41285995,  0.1684458 ]], dtype=float32)

Par défaut, les biais des couches de convolutions sont tous initialisés à zéro.

weights_conv[1]
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
      dtype=float32)

Etude de la batchnorm

weights_bn = model.get_layer('testing_bn_init').get_weights()
weights_bn
[array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       dtype=float32),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       dtype=float32),
 array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       dtype=float32)]
type(weights_bn)
list
len(weights_bn)
4

Dans une couche de Batchnormalization, on a 4 types de poids.

  • Les deux paramètres de scaling \(\gamma\) et de biais \(\beta\).
  • Les deux paramètres correspondant à la moyenne \(\mu\) et la variance \(\sigma\).

Ces paramètres ne sont pas tous entraînables : seuls \(\gamma\) et \(\beta\) le sont, la moyenne \(\mu\) et la variance \(\sigma\) étant mises à jour par moyennes mobiles et non par rétropropagation (d'où les 32 paramètres non entraînables du résumé du modèle), comme on peut le voir dans la liste suivante.

[(var.name, var.trainable) for var in model.get_layer('testing_bn_init').variables]
backend.shape(model.get_layer('testing_bn_init').get_weights())
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([ 4, 16], dtype=int32)>

Les \(4\) paramètres sont tous des vecteurs de dimension \(16\), ce qui correspond au nombre de feature maps en sortie de la couche convolutive.

Fusion d'une Convolution et d'une batchnorm

La fusion d'une couche de convolution avec une couche de batchnorm ressort les poids et biais d'une nouvelle couche de convolution avec les noyaux de convolutions de même dimension.

Etant donné le tenseur \(W\) de poids des noyaux de convolution d'une couche convolutive et le tenseur de \(4\) paramètres \(B=(\gamma, \beta, \mu, \sigma)\) d'une couche de batchnormalization, on obtient les nouveaux poids et biais de la nouvelle couche convolutive via les formules suivantes.

\[ \widehat{W}_{:,:,:,j} := \frac{\gamma_{j} \cdot W_{:,:,:,j}}{\sqrt{\sigma_{j} + \epsilon}} \]
\[ b_{j} = \beta_{j} - \frac{\mu_{j}\cdot\gamma_{j}}{\sqrt{\sigma_{j} + \epsilon}} \]

On remarque ici que le biais de la nouvelle couche de convolution ne dépend que des paramètres de la couche de batchnorm. Ce qui est cohérent avec la pratique de ne jamais mettre de biais dans une couche de convolution lorsqu'elle est suivie par une couche de batchnorm.

Remarque : le \(\epsilon\) présent ici est pour s'assurer que l'on ne divise jamais par zéro, dans la pratique il est fixé à \(0,001\).

Ce qui nous donne, dans la pratique la fonction suivante.

# https://scortex.io/batch-norm-folding-an-easy-way-to-improve-your-network-speed/
# https://github.com/DingXiaoH/RepVGG/blob/4da799e33c890c624bfb484b2c35abafd327ba40/repvgg.py#L68

def fuse_bn_conv(weights_conv, weights_bn, eps=0.001):
    """Fold a BatchNormalization layer into the preceding convolution.

    Parameters
    ----------
    weights_conv : list
        Conv layer weights as returned by ``get_weights()``; only element 0
        (the kernel tensor of shape (kh, kw, c_in, c_out)) is used.
    weights_bn : list
        BN weights in Keras order: [gamma, beta, moving_mean, moving_variance],
        each a vector of length c_out.
    eps : float
        Numerical-stability constant added to the variance (Keras default 1e-3).

    Returns
    -------
    (new_weights, new_bias) : tuple of np.ndarray
        Kernel of the fused convolution (same shape as the input kernel) and
        its bias vector of length c_out.
    """
    gamma, beta, mean, variance = weights_bn
    out_channels = gamma.shape[0]
    # Broadcast the per-channel BN vectors against the last kernel axis.
    bcast = (1, 1, 1, out_channels)

    fused_kernel = (weights_conv[0] * gamma.reshape(bcast)) / np.sqrt(variance.reshape(bcast) + eps)
    fused_bias = beta - mean * gamma.reshape(bcast) / np.sqrt(variance.reshape(bcast) + eps)

    # Flatten the bias back to a plain length-c_out vector for set_weights().
    return fused_kernel, np.reshape(fused_bias, out_channels)

# In the code above, the reshaping is necessary to prevent a mistake if the dimension of the output O was the same as the dimension of the input I. 

# def get_equivalent_kernel_bias(self):
#    kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
#    kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
#    kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
#    return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

Détaillons la fonction ci dessus.

Nouveau tenseur de poids

Discutons premièrement de la formulation du nouveau tenseur de poids, et voyons pourquoi on modifie la forme de vecteurs \(\gamma\) et \(\sigma\).

\(W_{:,:,:,j}\) correspond dans la formule au noyau de convolution complet de la \(j\)-ième feature map de sortie.

weights_conv = model.get_layer("testing_conv_init").get_weights()
weights_conv[0].shape
(3, 3, 3, 16)

On a \(16\) noyaux de convolution, chacun de dimensions \((3,3,3)\). Par exemple, pour \(j=1\).

weights_conv[0][:,:,:,1].shape
(3, 3, 3)

Les vecteur \(\gamma\) et \(\sigma\) étant des vecteurs de dimension \(16\), on va les "transformer en tenseur" de dimensions \((1,1,1,16)\) pour bien faire correspondre le produit suivant chaque axe.

variance = np.reshape(weights_bn[3], (1,1,1,weights_bn[3].shape[0]))
variance.shape
(1, 1, 1, 16)
gamma = np.reshape(weights_bn[0], (1,1,1,weights_bn[0].shape[0]))
gamma.shape
(1, 1, 1, 16)

screen

Au final, la formule

new_weights = (weights_conv[0]*gamma) / np.sqrt(variance + eps)

résume tout cela : tous les tenseurs ayant le même nombre d'axes, les opérations sont vectorisées et se font axe par axe.

Nouveau tenseur de biais

Les opérations de reshape n'ont pas ajouté de nouveaux scalaires, juste des axes ; le calcul du biais se fait alors élément par élément pour tout \(j\).

Vérification via les développements limités

Créons un tenseur de poids \(W\) repéresentatif du noyau d'une convolution et un tenseur de poids \(B=(\gamma, \beta, \mu, \sigma)\) représentatif des coefficients d'une batchnormalization.

Pour vérifier si tout marche bien, fixons volontairement le tenseur poids comme un tenseur de dimensions \((3,3,4,5)\), la dimension du noyau est toujours fixé à \((3,3)\) dans RepVGG, seules les dimensions channels_in et channels_out peuvent changer.

Tous les coefficients du tenseur de poids seront fixés à \(1\).

conv_weights = np.ones(3*3*4*5).reshape((3,3,4,5))
conv_weights.shape
(3, 3, 4, 5)

La dimension channels_out ayant été fixée à \(5\), les vecteurs de la batchnormalization seront tous des vecteurs de dimension \(5\). Fixons les coefficients suivants.

1
2
3
4
5
6
7
def batchnorm_variables(gamma_coef: float, beta_coef: float, mu_coef: float, sigma_coef: float, channels: int):
    """Build constant BatchNorm weight vectors [gamma, beta, mu, sigma].

    Each returned vector has length ``channels`` and every entry equals the
    corresponding ``*_coef`` scalar — handy for constructing synthetic BN
    weights in tests.
    """
    coefficients = (gamma_coef, beta_coef, mu_coef, sigma_coef)
    return [coef * np.ones(channels) for coef in coefficients]
conv, bn = fuse_bn_conv([conv_weights], batchnorm_variables(1,2,1,4,5))

Par définition, le nouveau tenseur de poids \(\widehat{W}\) de la convolution résultant de la fusion de l'ancienne convolution et de la batchnorm est donné par formule suivante.

\[ \widehat{W}_{:,:,:,j} := \frac{\gamma_{j} \cdot W_{:,:,:,j}}{\sqrt{\sigma_{j} + \epsilon}} \]

De façon générale, pour \(\gamma_{j}, \sigma_{j}\), on a le développement limité suivant.

\[ \widehat{W}_{:,:,:,j} := \frac{\gamma_{j} \cdot W_{:,:,:,j}}{\sqrt{\sigma_{j} + \epsilon}} = \frac{\gamma_{j}}{\sqrt{\sigma_{j}}}\left[1- \frac{1}{2\sigma_{j}}\epsilon + o(\epsilon^{2})\right]W_{:,:,:,j} \]

Dans notre cas, \(\forall j, \gamma_{j} = 1, \sigma_{j} = 4\) d'où

\[ \widehat{W}_{:,:,:,j} := \frac{W_{:,:,:,j}}{\sqrt{4 + \epsilon}} = \left[\frac{1}{2}- \frac{1}{16}\epsilon + o(\epsilon^{2})\right]W_{:,:,:,j} \simeq \left[\frac{1}{2}- \frac{1}{16}\epsilon\right]W_{:,:,:,j} \]
def compute_scaling_weight_factor(gamma, sigma):
    """First-order (in eps = 0.001) approximation of gamma / sqrt(sigma + eps).

    Comes from the Taylor expansion 1/sqrt(sigma + eps) ~
    (1/sqrt(sigma)) * (1 - eps / (2*sigma)); used to sanity-check the
    fused-convolution weights.
    """
    correction = 1 - 0.001 / (2 * sigma)
    return gamma / np.sqrt(sigma) * correction
scale = compute_scaling_weight_factor(1,4)
scale
0.4999375
conv[:,:,:,4]
array([[[0.49993751, 0.49993751, 0.49993751, 0.49993751],
        [0.49993751, 0.49993751, 0.49993751, 0.49993751],
        [0.49993751, 0.49993751, 0.49993751, 0.49993751]],

       [[0.49993751, 0.49993751, 0.49993751, 0.49993751],
        [0.49993751, 0.49993751, 0.49993751, 0.49993751],
        [0.49993751, 0.49993751, 0.49993751, 0.49993751]],

       [[0.49993751, 0.49993751, 0.49993751, 0.49993751],
        [0.49993751, 0.49993751, 0.49993751, 0.49993751],
        [0.49993751, 0.49993751, 0.49993751, 0.49993751]]])

Ce qui correspond bien à l'approximation obtenue par développement limité. On peut par exemple vérifier si \(\widehat{W}\) est approximativement égal à conv à \(10^{-3}\) avec la commande np.isclose.

conv_weights_real = scale*np.ones(3*3*4*5).reshape((3,3,4,5))
conv_weights_real[:,:,:,0]
array([[[0.4999375, 0.4999375, 0.4999375, 0.4999375],
        [0.4999375, 0.4999375, 0.4999375, 0.4999375],
        [0.4999375, 0.4999375, 0.4999375, 0.4999375]],

       [[0.4999375, 0.4999375, 0.4999375, 0.4999375],
        [0.4999375, 0.4999375, 0.4999375, 0.4999375],
        [0.4999375, 0.4999375, 0.4999375, 0.4999375]],

       [[0.4999375, 0.4999375, 0.4999375, 0.4999375],
        [0.4999375, 0.4999375, 0.4999375, 0.4999375],
        [0.4999375, 0.4999375, 0.4999375, 0.4999375]]])

Si np.mean(...) \(< 1\) alors le calcul est faux.

np.mean(np.isclose(conv, conv_weights_real, rtol=1e-3))
1.0

Pour le biais, on a la formule suivante.

\[ b_{j} = \beta_{j} - \frac{\mu_{j}\cdot\gamma_{j}}{\sqrt{\sigma_{j} + \epsilon}} = \beta_{j} - \frac{\mu_{j}\cdot\gamma_{j}}{\sqrt{\sigma_{j}}}\left[1- \frac{1}{2\sigma_{j}}\epsilon + o(\epsilon^{2})\right] \]

dans notre cas, on a :

  • \(\beta_{j} = 2\),
  • \(\gamma_{j} = 1\),
  • \(\mu_{j} = 1\),
  • \(\sigma_{j} = 4\).
\[ b_{j} = 2 - \frac{1}{2}\left[1- \frac{1}{8}\epsilon + o(\epsilon^{2})\right] \simeq 2 - \frac{1}{2} - \frac{1}{16}\epsilon \]
1
2
3
4
5
def compute_scaling_bias_factor(gamma, beta, mu, sigma):
    """First-order (in eps = 0.001) approximation of the fused bias
    beta - mu * gamma / sqrt(sigma + eps).

    Uses the same Taylor expansion of 1/sqrt(sigma + eps) as
    compute_scaling_weight_factor; used to sanity-check the fused bias.
    """
    leading_term = (mu * gamma) / np.sqrt(sigma)
    correction = 1 - 0.001 / (2 * sigma)

    return beta - leading_term * correction
bias_scale = compute_scaling_bias_factor(1,2,1,4)
bias_scale
bn
array([1.50006249, 1.50006249, 1.50006249, 1.50006249, 1.50006249])
bn_real = bias_scale*np.ones(5)
np.mean(np.isclose(bn_real, bn, rtol=1e-3))
1.0

RepVGG

screen

screen

Les couches convolutives dans RepVGG n'ayant que des noyaux \(3\times3\) ou \(1\times1\), on ne se préoccupe que de cela dans la suite.

Fusion d'une Conv \(3\times3\) avec une batchnorm puis transfert de poids

Créons un modèle simple : une couche convolutive suivie d'une couche de batchnormalisation ; pour simplifier on ne considère aucune couche d'activation (qui de toute façon ne rentre pas en jeu). Nous allons :

  1. Fusionner les deux couches pour créer un nouveau tenseur (poids, biais)
  2. Transférer ce nouveau tensor dans un modèle plus simple model_after_fusion.

Remarque : la convolution dans model_after_fusion utilise elle bien un biais (use_bias = True).

1
2
3
4
5
input = Input(img_shape)
x= Conv2D(filters = 16, kernel_size=3, padding='same', use_bias=False, kernel_initializer='he_uniform', name='conv')(input)
x= BatchNormalization(name='bn')(x)
model_before_fusion = Model(input, x)
model_before_fusion.summary()
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv (Conv2D)                (None, 32, 32, 16)        432       
_________________________________________________________________
bn (BatchNormalization)      (None, 32, 32, 16)        64        
=================================================================
Total params: 496
Trainable params: 464
Non-trainable params: 32
_________________________________________________________________

1
2
3
4
input = Input(img_shape)
x= Conv2D(filters = 16, kernel_size=3, padding='same', use_bias=True, kernel_initializer='he_normal', name='conv')(input)
model_after_fusion = Model(input, x)
model_after_fusion.summary()
Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_3 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv (Conv2D)                (None, 32, 32, 16)        448       
=================================================================
Total params: 448
Trainable params: 448
Non-trainable params: 0
_________________________________________________________________

weights_1 = model_before_fusion.get_layer('conv').get_weights()[0]
weights_2 = model_after_fusion.get_layer('conv').get_weights()[0]
np.mean(weights_1-weights_2)
0.0023266133
1
2
3
4
5
conv = model_before_fusion.get_layer("conv")
bn = model_before_fusion.get_layer("bn")

conv_weights, conv_biases = fuse_bn_conv(conv.get_weights(), bn.get_weights())
model_after_fusion.get_layer(f"conv").set_weights([conv_weights, conv_biases])

Vérifions que la mise en place des nouveaux poids s'est bien passée, c'est-à-dire que l'opération set_weights() n'a rien ajouté de supplémentaire. Si tout se passe bien, np.mean ne devrait renvoyer que des 1.0.

1
2
3
w0, b0 = fuse_bn_conv(model_before_fusion.get_layer("conv").get_weights(), model_before_fusion.get_layer("bn").get_weights())

w1, b1 = model_after_fusion.get_layer("conv").get_weights()
np.mean(w0 == w1)
1.0
np.mean(b0 == b1)
1.0

Donc tout s'est bien passé. Reste maintenant à généraliser cette transformation.

L'idée de RepVGG est d'utiliser une architecture à la ResNet pour l'entraînement, avec des skip connections, puis lors du déploiement du modèle de reparamétrer les skip connections via des fusions Conv-BN afin de n'avoir plus qu'une architecture linéaire à la VGG, beaucoup plus rapide en inférence qu'une architecture à la ResNet.

En plus de fusionner des \(\mathrm{Conv} 3 \times 3\) avec des \(\mathrm{BN}\), il est aussi nécessaire de savoir faire les opérations suivantes.

  1. Convertir une \(\mathrm{Conv} 1 \times 1\) en \(\mathrm{Conv} 3 \times 3\) puis la fusionner avec la \(\mathrm{BN}\) correspondante.
  2. Convertir une \(\mathrm{id}\) en \(\mathrm{Conv} 3 \times 3\) puis la fusionner avec la \(\mathrm{BN}\) correspondante.

Conversion d'une Conv \(1 \times 1\) en \(3 \times 3\) puis fusion avec la batchnorm.

Pour convertir une conv 1x1 en conv 3x3, les nombres de canaux en entrée et en sortie importent peu : ce qu'il faut, c'est modifier la dimension des noyaux de convolution pour passer d'une dimension 1x1 à 3x3, et pour cela on utilise un padding.

1
2
3
4
5
input = Input(img_shape)
x= Conv2D(filters = 16, kernel_size=1, padding='same', use_bias=False, kernel_initializer='he_uniform', name='conv')(input)
x= BatchNormalization(name='bn')(x)
model_before_fusion_conv1 = Model(input, x)
model_before_fusion_conv1.summary()
Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv (Conv2D)                (None, 32, 32, 16)        48        
_________________________________________________________________
bn (BatchNormalization)      (None, 32, 32, 16)        64        
=================================================================
Total params: 112
Trainable params: 80
Non-trainable params: 32
_________________________________________________________________

1
2
3
4
input = Input(img_shape)
x= Conv2D(filters = 16, kernel_size=3, padding='same', use_bias=True, kernel_initializer='he_normal', name='conv')(input)
model_after_fusion_conv1 = Model(input, x)
model_after_fusion_conv1.summary()
Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv (Conv2D)                (None, 32, 32, 16)        448       
=================================================================
Total params: 448
Trainable params: 448
Non-trainable params: 0
_________________________________________________________________

weights_conv1 = model_before_fusion_conv1.get_layer('conv')
weights_bn1 = model_before_fusion_conv1.get_layer('bn')
weights_conv1.get_weights()[0].shape
(1, 1, 3, 16)

La première chose à faire, c'est de transformer les noyaux de convolution \(1\times1\) en des noyaux de convolution \(3\times3\). Pour faire cela, on utilise la notion de "padding", déjà utilisée dans le cas des convolutions.

weights_conv1.get_weights()[0][:,:,1,1]
array([[-1.227452]], dtype=float32)

On a deux fonctions possibles pour faire ça. On peut utiliser soit la fonction de tensorflow.

padded_conv1 = tf.pad(weights_conv1.get_weights()[0], [[1,1], [1, 1], [0,0], [0,0]], "CONSTANT")

Soit la fonction de numpy.

padded_conv1 = np.pad(weights_conv1.get_weights()[0], pad_width=[[1,1], [1, 1], [0,0], [0,0]], mode='constant', constant_values=0)

Dans les deux cas, on a un paramètre donnant la taille du padding : [[1,1], [1, 1], [0,0], [0,0]], c'est une liste de longueur le nombre d'axes du tenseur que l'on souhaite modifier, chaque élément de la liste nous dit de combien on doit agrandir au début et à la fin.

[[1,1], [1, 1], [0,0], [0,0]] = [[pad_avant_axe1, pad_arrière_axe1], [pad_avant_axe2, pad_arrière_axe2], [pad_avant_axe3, pad_arrière_axe3], [pad_avant_axe4, pad_arrière_axe4]]

Le dernier paramètre nous dit quoi rajouter aux endroits où l'on a agrandi, ici des constantes : la valeur \(0\).

padded_conv1_tf = tf.pad(weights_conv1.get_weights()[0], [[1,1], [1, 1], [0,0], [0,0]], "CONSTANT")
padded_conv1_np = np.pad(weights_conv1.get_weights()[0], pad_width=[[1,1], [1, 1], [0,0], [0,0]], mode='constant', constant_values=0)

Les deux fonctions donnent le même résultat.

np.mean(padded_conv1_tf.numpy()==padded_conv1_np)==1
True

Comme la fonction set_weights() demande d'utiliser des np.array, on va utiliser la fonction de numpy.

def pad_size_one_kernel(conv_weights):
    """Zero-pad 1x1 convolution kernels into 3x3 kernels.

    Pads one row/column of zeros on each side of the two spatial axes and
    leaves the channel axes untouched, turning a (1, 1, c_in, c_out) kernel
    tensor into a (3, 3, c_in, c_out) one with the original value at the
    spatial centre. ``conv_weights`` is a ``get_weights()``-style list whose
    element 0 is the kernel tensor.
    """
    spatial_only = ((1, 1), (1, 1), (0, 0), (0, 0))
    return np.pad(conv_weights[0], pad_width=spatial_only, mode='constant', constant_values=0)
padded_weights_conv1 = pad_size_one_kernel(weights_conv1.get_weights())
padded_weights_conv1.shape
(3, 3, 3, 16)

Vérification

On a transformé tous les noyaux de convolutions \(1\times1\) en noyaux \(3\times3\), chacun des padded_weights_conv1[:,:,i,j] pour \(0 \leq i \leq 2\) et \(0 \leq j \leq 15\) doit être une matrice \(3\times3\) où tous les éléments sont nuls sauf possiblement celui du milieu.

def test_padded_kernel_conv(padded_kernel):
    for i in range(3):
        for j in range(16):
            print(f'Matrix of size 3x3 : {padded_kernel[:,:,i,j].shape == (3,3)}')
            squared_sum = 0
            for k in range(3):
                for l in range(3):
                    if k != 1 and l != 1:
                        squared_sum += padded_kernel[:,:,i,j][k,l]**2
            print(f'Squared sum is : {squared_sum}')
test_padded_kernel_conv(padded_weights_conv1)
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0

dummy_conv = np.ones(1*1*3*16).reshape((1,1,3,16))
padded_dummy_conv=pad_size_one_kernel([dummy_conv])
test_padded_kernel_conv(padded_dummy_conv)
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0
Matrix of size 3x3 : True
Squared sum is : 0.0

Comme précédemment, on vérifie via les développements limités que ça fonctionne.

conv, bn = fuse_bn_conv([padded_dummy_conv], batchnorm_variables(1,2,1,4,16))
scale = compute_scaling_weight_factor(1,4)
scale
0.4999375
np.mean(np.isclose(conv, scale*padded_dummy_conv, rtol=1e-3))
1.0
bias_scale = compute_scaling_bias_factor(1,2,1,4)
bias_scale
1.5000625
bn_real = bias_scale*np.ones(16)
np.mean(np.isclose(bn_real, bn, rtol=1e-3))
1.0
1
2
3
4
5
6
7
weights_conv1 = model_before_fusion_conv1.get_layer('conv')
weights_bn1 = model_before_fusion_conv1.get_layer('bn')

padded_weights_conv1 = pad_size_one_kernel(weights_conv1.get_weights())
conv_weights, conv_bias = fuse_bn_conv([padded_weights_conv1], weights_bn1.get_weights())

model_after_fusion_conv1.get_layer("conv").set_weights([conv_weights, conv_bias])

Conversion d'une \(\mathrm{id}\) en \(\mathrm{Conv} 3 \times 3\) puis fusion avec la batchnorm.

Les branches id ne sont utilisées dans l'architecture de RepVGG que lorsque la condition channels_in = channels_out est vérifiée, c'est-à-dire à l'intérieur de chaque stage entre 2 blocs convolutifs avec un stride de 2.

# Fixons le nombre de channels, peut importe le nombre.
channels = 4

An identity mapping can be viewed as a \(1\times1\) conv with an identity matrix as the kernel.

1
2
3
4
5
6
7
def size_three_kernel_from_id(channels):
    """Build a 3x3 conv kernel equivalent to the identity mapping.

    The identity over `channels` feature maps equals a 1x1 convolution whose
    kernel is the identity matrix; zero-padding it spatially to 3x3 leaves the
    mapping unchanged while matching the main branch's kernel size.

    Returns a float array of shape (3, 3, channels, channels) whose centre
    spatial position holds the identity matrix and is zero elsewhere.
    """
    kernel = np.zeros((3, 3, channels, channels))
    for c in range(channels):
        # Centre tap maps input channel c straight to output channel c.
        kernel[1, 1, c, c] = 1.0
    return kernel
conv_from_id = size_three_kernel_from_id(4)
conv_from_id.shape
(3, 3, 4, 4)
def test_padded_kernel_from_id(padded_kernel):
    for i in range(3):
        for j in range(16):
            print(f'Matrix of size 3x3 : {padded_kernel[:,:,i,j].shape == (3,3)}')
            squared_sum = 0
            for k in range(3):
                for l in range(3):
                    if (k,l) != (1,1):
                        squared_sum += padded_kernel[:,:,i,j][k,l]**2
                    else:
                        print(f'Middle element is 1 : {padded_kernel[:,:,i,j][k,l]==1}')
            print(f'Squared sum is : {squared_sum}')
test_padded_kernel_from_id(conv)
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0
Matrix of size 3x3 : True
Middle element is 1 : False
Squared sum is : 0.0

conv, bn = fuse_bn_conv([conv_from_id], batchnorm_variables(1,2,1,4,channels))
scale = compute_scaling_weight_factor(1,4)
scale
0.4999375
np.mean(np.isclose(conv, scale*conv_from_id, rtol=1e-3))
1.0
bias_scale = compute_scaling_bias_factor(1,2,1,4)
bias_scale
1.5000625
bn_real = bias_scale*np.ones(channels)
np.mean(np.isclose(bn_real, bn, rtol=1e-3))
1.0

Vérification

Test grandeur réelle

def repvgg_block(tensor, filters, num_layer):
    """Downsampling RepVGG block: parallel 3x3 and 1x1 Conv+BN branches
    (both stride 2, no bias) whose outputs are summed elementwise."""

    def conv_bn_branch(kernel_size, suffix):
        # Shared Conv -> BatchNorm construction for both parallel streams.
        out = Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=(2, 2),
            padding="same",
            kernel_initializer="he_normal",
            use_bias=False,
            name=f'block_{num_layer}_conv_{suffix}',
        )(tensor)
        return BatchNormalization(name=f'block_{num_layer}_bn_{suffix}')(out)

    main_stream = conv_bn_branch((3, 3), 'main')
    alt_stream = conv_bn_branch((1, 1), 'alt')
    return Add()([main_stream, alt_stream])
def repvgg_block_with_id(tensor, filters, num_layer):
    """RepVGG block with identity branch: 3x3 Conv+BN, 1x1 Conv+BN, and a
    BN-only identity stream (all stride 1), summed elementwise. Only valid
    when the input already has `filters` channels."""

    def conv_bn_branch(kernel_size, suffix):
        # Shared Conv -> BatchNorm construction for the two conv streams.
        out = Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            strides=(1, 1),
            padding="same",
            kernel_initializer="he_normal",
            use_bias=False,
            name=f"block_{num_layer}_conv_{suffix}",
        )(tensor)
        return BatchNormalization(name=f"block_{num_layer}_bn_{suffix}")(out)

    main_stream = conv_bn_branch((3, 3), "main")
    alt_stream = conv_bn_branch((1, 1), "alt")

    # The identity branch is just a BatchNorm applied to the block input.
    id_stream = BatchNormalization(name=f"block_{num_layer}_bn_id")(tensor)

    return Add()([main_stream, alt_stream, id_stream])
def get_model(img_shape):
    """Training-time RepVGG: two downsampling multi-branch blocks, one
    identity-capable block, then Flatten -> Dense(10) -> Softmax."""
    inputs = Input(img_shape)

    x = inputs
    # Blocks 0 and 1 halve the spatial resolution (stride 2, no id branch).
    for layer_idx in (0, 1):
        x = repvgg_block(x, filters=64, num_layer=layer_idx)
        x = ReLU()(x)
    # Block 2 keeps resolution and channel count, so it carries an id branch.
    x = repvgg_block_with_id(x, filters=64, num_layer=2)
    x = ReLU()(x)

    x = Flatten()(x)
    x = Dense(10, name='dense')(x)
    x = Softmax()(x)
    return Model(inputs, x)

def get_inference_model(img_shape):
    """Plain-VGG inference counterpart: one 3x3 Conv2D (with bias) per
    training block — strides matching the training model — then the same
    Flatten -> Dense(10) -> Softmax head."""
    inputs = Input(img_shape)

    x = inputs
    # conv_0 and conv_1 downsample (stride 2); conv_2 keeps resolution.
    for idx, stride in enumerate((2, 2, 1)):
        x = Conv2D(filters=64, kernel_size=(3, 3), strides=(stride, stride),
                   padding='same', name=f'conv_{idx}')(x)
        x = ReLU()(x)

    x = Flatten()(x)
    x = Dense(10, name='dense')(x)
    x = Softmax()(x)
    return Model(inputs, x)
training_model = get_model([32,32,3])
training_model.summary()
Model: "model_11"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_14 (InputLayer)           [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
block_0_conv_main (Conv2D)      (None, 16, 16, 64)   1728        input_14[0][0]                   
__________________________________________________________________________________________________
block_0_conv_alt (Conv2D)       (None, 16, 16, 64)   192         input_14[0][0]                   
__________________________________________________________________________________________________
block_0_bn_main (BatchNormaliza (None, 16, 16, 64)   256         block_0_conv_main[0][0]          
__________________________________________________________________________________________________
block_0_bn_alt (BatchNormalizat (None, 16, 16, 64)   256         block_0_conv_alt[0][0]           
__________________________________________________________________________________________________
add_13 (Add)                    (None, 16, 16, 64)   0           block_0_bn_main[0][0]            
                                                                 block_0_bn_alt[0][0]             
__________________________________________________________________________________________________
re_lu_20 (ReLU)                 (None, 16, 16, 64)   0           add_13[0][0]                     
__________________________________________________________________________________________________
block_1_conv_main (Conv2D)      (None, 8, 8, 64)     36864       re_lu_20[0][0]                   
__________________________________________________________________________________________________
block_1_conv_alt (Conv2D)       (None, 8, 8, 64)     4096        re_lu_20[0][0]                   
__________________________________________________________________________________________________
block_1_bn_main (BatchNormaliza (None, 8, 8, 64)     256         block_1_conv_main[0][0]          
__________________________________________________________________________________________________
block_1_bn_alt (BatchNormalizat (None, 8, 8, 64)     256         block_1_conv_alt[0][0]           
__________________________________________________________________________________________________
add_14 (Add)                    (None, 8, 8, 64)     0           block_1_bn_main[0][0]            
                                                                 block_1_bn_alt[0][0]             
__________________________________________________________________________________________________
re_lu_21 (ReLU)                 (None, 8, 8, 64)     0           add_14[0][0]                     
__________________________________________________________________________________________________
block_2_conv_main (Conv2D)      (None, 8, 8, 64)     36864       re_lu_21[0][0]                   
__________________________________________________________________________________________________
block_2_conv_alt (Conv2D)       (None, 8, 8, 64)     4096        re_lu_21[0][0]                   
__________________________________________________________________________________________________
block_2_bn_main (BatchNormaliza (None, 8, 8, 64)     256         block_2_conv_main[0][0]          
__________________________________________________________________________________________________
block_2_bn_alt (BatchNormalizat (None, 8, 8, 64)     256         block_2_conv_alt[0][0]           
__________________________________________________________________________________________________
block_2_bn_id (BatchNormalizati (None, 8, 8, 64)     256         re_lu_21[0][0]                   
__________________________________________________________________________________________________
add_15 (Add)                    (None, 8, 8, 64)     0           block_2_bn_main[0][0]            
                                                                 block_2_bn_alt[0][0]             
                                                                 block_2_bn_id[0][0]              
__________________________________________________________________________________________________
re_lu_22 (ReLU)                 (None, 8, 8, 64)     0           add_15[0][0]                     
__________________________________________________________________________________________________
flatten_6 (Flatten)             (None, 4096)         0           re_lu_22[0][0]                   
__________________________________________________________________________________________________
dense (Dense)                   (None, 10)           40970       flatten_6[0][0]                  
__________________________________________________________________________________________________
softmax_6 (Softmax)             (None, 10)           0           dense[0][0]                      
==================================================================================================
Total params: 126,602
Trainable params: 125,706
Non-trainable params: 896
__________________________________________________________________________________________________

inference_model = get_inference_model([32,32,3])
inference_model.summary()
Model: "model_12"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_15 (InputLayer)        [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv_0 (Conv2D)              (None, 16, 16, 64)        1792      
_________________________________________________________________
re_lu_23 (ReLU)              (None, 16, 16, 64)        0         
_________________________________________________________________
conv_1 (Conv2D)              (None, 8, 8, 64)          36928     
_________________________________________________________________
re_lu_24 (ReLU)              (None, 8, 8, 64)          0         
_________________________________________________________________
conv_2 (Conv2D)              (None, 8, 8, 64)          36928     
_________________________________________________________________
re_lu_25 (ReLU)              (None, 8, 8, 64)          0         
_________________________________________________________________
flatten_7 (Flatten)          (None, 4096)              0         
_________________________________________________________________
dense (Dense)                (None, 10)                40970     
_________________________________________________________________
softmax_7 (Softmax)          (None, 10)                0         
=================================================================
Total params: 116,618
Trainable params: 116,618
Non-trainable params: 0
_________________________________________________________________

def from_repvgg_to_vgg(training_model, inference_model, depth):
    """Fold each multi-branch RepVGG block of `training_model` into the single
    3x3 conv of `inference_model`.

    For every block i in [0, depth): fuse Conv3x3+BN (main branch), pad the
    1x1 kernel to 3x3 and fuse it with its BN (alt branch), and — when the
    block has an identity branch (a `block_{i}_bn_id` layer) — express the
    identity as a 3x3 kernel and fuse it with its BN. The fused kernels and
    biases are summed and written into `conv_{i}` of the inference model;
    the dense head is copied unchanged.

    Returns the updated inference model.
    """
    model = training_model

    for i in range(depth):
        print(f"Fusion Conv-BN from main branch at depth {i}")
        conv_main = model.get_layer(f"block_{i}_conv_main")
        bn_main = model.get_layer(f"block_{i}_bn_main")

        conv_weights_main, conv_biases_main = fuse_bn_conv(
            conv_main.get_weights(), bn_main.get_weights()
        )

        print(f"Fusion Conv-BN from alt branch at depth {i}")
        conv_alt_one_by_one = model.get_layer(f"block_{i}_conv_alt")
        bn_alt = model.get_layer(f"block_{i}_bn_alt")

        conv_alt = pad_size_one_kernel(conv_alt_one_by_one.get_weights())
        conv_weights_alt, conv_biases_alt = fuse_bn_conv([conv_alt], bn_alt.get_weights())

        conv_weights = conv_weights_main + conv_weights_alt
        conv_biases = conv_biases_main + conv_biases_alt

        # The identity branch only exists on blocks where
        # channels_in == channels_out. Detect it by layer name instead of
        # hard-coding an index: the previous `if i == 3` could never fire for
        # depth == 3 (blocks are 0..2), so the id branch of the last block was
        # silently dropped from the fused weights.
        try:
            bn_id = model.get_layer(f"block_{i}_bn_id")
        except ValueError:
            # Keras raises ValueError when no layer has that name.
            bn_id = None

        if bn_id is not None:
            print(f"Fusion Conv-BN from id branch at depth {i}")
            # BN weights are numpy arrays, so read the channel count from the
            # array shape (backend.int_shape is meant for tensors).
            channels = bn_id.get_weights()[0].shape[-1]

            conv_id = size_three_kernel_from_id(channels)
            conv_weights_id, conv_biases_id = fuse_bn_conv([conv_id], bn_id.get_weights())

            conv_weights = conv_weights + conv_weights_id
            conv_biases = conv_biases + conv_biases_id

        print(f"Setting weights on inference model at depth {i}")
        inference_model.get_layer(f"conv_{i}").set_weights([conv_weights, conv_biases])

    dense_weights = model.get_layer("dense").get_weights()
    inference_model.get_layer("dense").set_weights(dense_weights)

    return inference_model
from tensorflow.keras import datasets
from sklearn.model_selection import train_test_split

# Load CIFAR-10: 50k train / 10k test RGB images of shape (32, 32, 3).
(X_train,y_train), (X_test,y_test)  = tf.keras.datasets.cifar10.load_data()

# Normalize pixel values to be between 0 and 1
X_train, X_test = X_train / 255.0, X_test / 255.0

# Ensure shape (N, 32, 32, 3) and float32 dtype for the model input.
X_train = X_train.reshape(-1, 32, 32, 3).astype('float32')
X_test = X_test.reshape(-1, 32, 32, 3).astype('float32')


# Carve a validation split out of the training set (default 25%, fixed seed).
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, random_state=42)

# One-hot encode the integer labels for categorical cross-entropy.
y_train_oh = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test_oh = tf.keras.utils.to_categorical(y_test, num_classes=10)
y_valid_oh = tf.keras.utils.to_categorical(y_valid, num_classes=10)
1
2
3
4
5
6
7
8
# Compile and train the multi-branch model. Note: `learning_rate` replaces the
# deprecated `lr` keyword of tf.keras.optimizers.Adam (same default behavior).
training_model.compile(loss='categorical_crossentropy',
                       optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                       metrics=['accuracy'])

training_model.fit(X_train, y_train_oh,
                   epochs=200,
                   batch_size=128,
                   validation_data=(X_valid, y_valid_oh))
Epoch 1/200
293/293 [==============================] - 2s 4ms/step - loss: 1.9390 - accuracy: 0.3690 - val_loss: 1.4326 - val_accuracy: 0.4882
Epoch 2/200
293/293 [==============================] - 1s 3ms/step - loss: 1.1761 - accuracy: 0.5805 - val_loss: 1.2776 - val_accuracy: 0.5559
Epoch 3/200
293/293 [==============================] - 1s 3ms/step - loss: 0.9771 - accuracy: 0.6608 - val_loss: 1.2784 - val_accuracy: 0.5692
Epoch 4/200
293/293 [==============================] - 1s 3ms/step - loss: 0.8457 - accuracy: 0.7047 - val_loss: 1.1538 - val_accuracy: 0.6070
Epoch 5/200
293/293 [==============================] - 1s 3ms/step - loss: 0.7411 - accuracy: 0.7415 - val_loss: 1.1523 - val_accuracy: 0.6082
Epoch 6/200
293/293 [==============================] - 1s 3ms/step - loss: 0.6627 - accuracy: 0.7721 - val_loss: 1.2050 - val_accuracy: 0.6056
Epoch 7/200
293/293 [==============================] - 1s 3ms/step - loss: 0.5730 - accuracy: 0.8017 - val_loss: 1.3360 - val_accuracy: 0.5908
Epoch 8/200
293/293 [==============================] - 1s 4ms/step - loss: 0.5037 - accuracy: 0.8267 - val_loss: 1.2697 - val_accuracy: 0.6086
Epoch 9/200
293/293 [==============================] - 1s 4ms/step - loss: 0.4315 - accuracy: 0.8549 - val_loss: 1.6375 - val_accuracy: 0.5710
Epoch 10/200
293/293 [==============================] - 1s 3ms/step - loss: 0.3744 - accuracy: 0.8732 - val_loss: 1.4501 - val_accuracy: 0.6078
Epoch 11/200
293/293 [==============================] - 1s 4ms/step - loss: 0.3140 - accuracy: 0.8947 - val_loss: 1.3964 - val_accuracy: 0.6189
Epoch 12/200
293/293 [==============================] - 1s 4ms/step - loss: 0.2657 - accuracy: 0.9140 - val_loss: 1.4346 - val_accuracy: 0.6252
Epoch 13/200
293/293 [==============================] - 1s 3ms/step - loss: 0.2328 - accuracy: 0.9247 - val_loss: 1.7435 - val_accuracy: 0.6030
Epoch 14/200
293/293 [==============================] - 1s 3ms/step - loss: 0.1914 - accuracy: 0.9402 - val_loss: 1.6220 - val_accuracy: 0.6197
Epoch 15/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1661 - accuracy: 0.9486 - val_loss: 2.0838 - val_accuracy: 0.5826
Epoch 16/200
293/293 [==============================] - 1s 3ms/step - loss: 0.1485 - accuracy: 0.9539 - val_loss: 1.6913 - val_accuracy: 0.6221
Epoch 17/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1283 - accuracy: 0.9603 - val_loss: 1.7873 - val_accuracy: 0.6273
Epoch 18/200
293/293 [==============================] - 1s 3ms/step - loss: 0.1113 - accuracy: 0.9667 - val_loss: 1.9872 - val_accuracy: 0.6094
Epoch 19/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1026 - accuracy: 0.9691 - val_loss: 2.0423 - val_accuracy: 0.6062
Epoch 20/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1021 - accuracy: 0.9666 - val_loss: 2.0206 - val_accuracy: 0.6166
Epoch 21/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0920 - accuracy: 0.9714 - val_loss: 2.0964 - val_accuracy: 0.6199
Epoch 22/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0794 - accuracy: 0.9750 - val_loss: 2.4125 - val_accuracy: 0.5951
Epoch 23/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0647 - accuracy: 0.9814 - val_loss: 2.2896 - val_accuracy: 0.6095
Epoch 24/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0907 - accuracy: 0.9686 - val_loss: 2.3337 - val_accuracy: 0.6104
Epoch 25/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0774 - accuracy: 0.9755 - val_loss: 2.7223 - val_accuracy: 0.5977
Epoch 26/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0615 - accuracy: 0.9804 - val_loss: 2.3225 - val_accuracy: 0.6226
Epoch 27/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0616 - accuracy: 0.9815 - val_loss: 2.3531 - val_accuracy: 0.6187
Epoch 28/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0510 - accuracy: 0.9862 - val_loss: 2.5024 - val_accuracy: 0.6199
Epoch 29/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0528 - accuracy: 0.9838 - val_loss: 2.9272 - val_accuracy: 0.5823
Epoch 30/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0715 - accuracy: 0.9745 - val_loss: 2.4670 - val_accuracy: 0.6142
Epoch 31/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0780 - accuracy: 0.9733 - val_loss: 2.6590 - val_accuracy: 0.6194
Epoch 32/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0636 - accuracy: 0.9776 - val_loss: 2.5846 - val_accuracy: 0.6173
Epoch 33/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0364 - accuracy: 0.9898 - val_loss: 2.5448 - val_accuracy: 0.6227
Epoch 34/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0308 - accuracy: 0.9915 - val_loss: 2.6284 - val_accuracy: 0.6238
Epoch 35/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0307 - accuracy: 0.9914 - val_loss: 2.8331 - val_accuracy: 0.6170
Epoch 36/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0561 - accuracy: 0.9809 - val_loss: 3.3461 - val_accuracy: 0.5866
Epoch 37/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0740 - accuracy: 0.9740 - val_loss: 2.9504 - val_accuracy: 0.6025
Epoch 38/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0706 - accuracy: 0.9749 - val_loss: 2.7457 - val_accuracy: 0.6227
Epoch 39/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0456 - accuracy: 0.9849 - val_loss: 2.7979 - val_accuracy: 0.6226
Epoch 40/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0269 - accuracy: 0.9929 - val_loss: 2.8238 - val_accuracy: 0.6207
Epoch 41/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0178 - accuracy: 0.9957 - val_loss: 2.8020 - val_accuracy: 0.6298
Epoch 42/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0161 - accuracy: 0.9959 - val_loss: 2.8443 - val_accuracy: 0.6261
Epoch 43/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0377 - accuracy: 0.9870 - val_loss: 4.3773 - val_accuracy: 0.5440
Epoch 44/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1039 - accuracy: 0.9646 - val_loss: 3.0375 - val_accuracy: 0.6030
Epoch 45/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0617 - accuracy: 0.9799 - val_loss: 2.8326 - val_accuracy: 0.6314
Epoch 46/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0328 - accuracy: 0.9899 - val_loss: 2.8987 - val_accuracy: 0.6274
Epoch 47/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0190 - accuracy: 0.9945 - val_loss: 2.9066 - val_accuracy: 0.6258
Epoch 48/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0134 - accuracy: 0.9962 - val_loss: 3.0005 - val_accuracy: 0.6267
Epoch 49/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0098 - accuracy: 0.9980 - val_loss: 2.9672 - val_accuracy: 0.6337
Epoch 50/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0150 - accuracy: 0.9957 - val_loss: 3.4715 - val_accuracy: 0.5854
Epoch 51/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1082 - accuracy: 0.9634 - val_loss: 3.1194 - val_accuracy: 0.6069
Epoch 52/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0703 - accuracy: 0.9746 - val_loss: 3.4825 - val_accuracy: 0.6022
Epoch 53/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0372 - accuracy: 0.9875 - val_loss: 3.1542 - val_accuracy: 0.6277
Epoch 54/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0176 - accuracy: 0.9953 - val_loss: 3.0385 - val_accuracy: 0.6305
Epoch 55/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0099 - accuracy: 0.9978 - val_loss: 3.0361 - val_accuracy: 0.6351
Epoch 56/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0067 - accuracy: 0.9987 - val_loss: 3.0368 - val_accuracy: 0.6361
Epoch 57/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0042 - accuracy: 0.9995 - val_loss: 3.1196 - val_accuracy: 0.6257
Epoch 58/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0070 - accuracy: 0.9984 - val_loss: 3.2946 - val_accuracy: 0.6184
Epoch 59/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0844 - accuracy: 0.9726 - val_loss: 3.6981 - val_accuracy: 0.5838
Epoch 60/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1068 - accuracy: 0.9615 - val_loss: 3.3662 - val_accuracy: 0.6150
Epoch 61/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0331 - accuracy: 0.9884 - val_loss: 3.1085 - val_accuracy: 0.6268
Epoch 62/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0139 - accuracy: 0.9960 - val_loss: 3.2636 - val_accuracy: 0.6290
Epoch 63/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0095 - accuracy: 0.9980 - val_loss: 3.1850 - val_accuracy: 0.6310
Epoch 64/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0063 - accuracy: 0.9987 - val_loss: 3.2820 - val_accuracy: 0.6345
Epoch 65/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0031 - accuracy: 0.9997 - val_loss: 3.1574 - val_accuracy: 0.6397
Epoch 66/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0017 - accuracy: 1.0000 - val_loss: 3.1854 - val_accuracy: 0.6414
Epoch 67/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 3.1196 - val_accuracy: 0.6458
Epoch 68/200
293/293 [==============================] - 1s 4ms/step - loss: 9.6865e-04 - accuracy: 1.0000 - val_loss: 3.1958 - val_accuracy: 0.6469
Epoch 69/200
293/293 [==============================] - 1s 4ms/step - loss: 8.6404e-04 - accuracy: 1.0000 - val_loss: 3.1414 - val_accuracy: 0.6472
Epoch 70/200
293/293 [==============================] - 1s 4ms/step - loss: 5.8577e-04 - accuracy: 1.0000 - val_loss: 3.1648 - val_accuracy: 0.6446
Epoch 71/200
293/293 [==============================] - 1s 4ms/step - loss: 5.2701e-04 - accuracy: 1.0000 - val_loss: 3.2118 - val_accuracy: 0.6478
Epoch 72/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1478 - accuracy: 0.9627 - val_loss: 3.3232 - val_accuracy: 0.5995
Epoch 73/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1219 - accuracy: 0.9585 - val_loss: 3.0879 - val_accuracy: 0.6179
Epoch 74/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0325 - accuracy: 0.9888 - val_loss: 3.1610 - val_accuracy: 0.6290
Epoch 75/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0124 - accuracy: 0.9973 - val_loss: 3.0288 - val_accuracy: 0.6348
Epoch 76/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0045 - accuracy: 0.9997 - val_loss: 3.0254 - val_accuracy: 0.6465
Epoch 77/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0027 - accuracy: 0.9999 - val_loss: 3.0743 - val_accuracy: 0.6470
Epoch 78/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0019 - accuracy: 0.9999 - val_loss: 3.1198 - val_accuracy: 0.6430
Epoch 79/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0016 - accuracy: 1.0000 - val_loss: 3.0889 - val_accuracy: 0.6466
Epoch 80/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0013 - accuracy: 1.0000 - val_loss: 3.1780 - val_accuracy: 0.6393
Epoch 81/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 3.1530 - val_accuracy: 0.6427
Epoch 82/200
293/293 [==============================] - 1s 4ms/step - loss: 9.9041e-04 - accuracy: 1.0000 - val_loss: 3.1667 - val_accuracy: 0.6454
Epoch 83/200
293/293 [==============================] - 1s 4ms/step - loss: 9.6558e-04 - accuracy: 1.0000 - val_loss: 3.1835 - val_accuracy: 0.6430
Epoch 84/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1062 - accuracy: 0.9710 - val_loss: 4.0625 - val_accuracy: 0.5772
Epoch 85/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1073 - accuracy: 0.9622 - val_loss: 3.1034 - val_accuracy: 0.6308
Epoch 86/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0338 - accuracy: 0.9883 - val_loss: 3.1997 - val_accuracy: 0.6258
Epoch 87/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0141 - accuracy: 0.9965 - val_loss: 3.2102 - val_accuracy: 0.6319
Epoch 88/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0064 - accuracy: 0.9991 - val_loss: 3.1613 - val_accuracy: 0.6379
Epoch 89/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0039 - accuracy: 0.9996 - val_loss: 3.2062 - val_accuracy: 0.6387
Epoch 90/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0022 - accuracy: 0.9999 - val_loss: 3.1607 - val_accuracy: 0.6414
Epoch 91/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0014 - accuracy: 1.0000 - val_loss: 3.1905 - val_accuracy: 0.6446
Epoch 92/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0012 - accuracy: 0.9999 - val_loss: 3.2038 - val_accuracy: 0.6442
Epoch 93/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0011 - accuracy: 0.9999 - val_loss: 3.2732 - val_accuracy: 0.6409
Epoch 94/200
293/293 [==============================] - 1s 4ms/step - loss: 7.5633e-04 - accuracy: 1.0000 - val_loss: 3.2499 - val_accuracy: 0.6442
Epoch 95/200
293/293 [==============================] - 1s 4ms/step - loss: 7.7323e-04 - accuracy: 1.0000 - val_loss: 3.2885 - val_accuracy: 0.6450
Epoch 96/200
293/293 [==============================] - 1s 4ms/step - loss: 9.9023e-04 - accuracy: 1.0000 - val_loss: 3.4087 - val_accuracy: 0.6306
Epoch 97/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0088 - accuracy: 0.9979 - val_loss: 5.3897 - val_accuracy: 0.5452
Epoch 98/200
293/293 [==============================] - 1s 4ms/step - loss: 0.3004 - accuracy: 0.9175 - val_loss: 3.8454 - val_accuracy: 0.5955
Epoch 99/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0424 - accuracy: 0.9852 - val_loss: 3.4719 - val_accuracy: 0.6225
Epoch 100/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0128 - accuracy: 0.9972 - val_loss: 3.4143 - val_accuracy: 0.6262
Epoch 101/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0070 - accuracy: 0.9984 - val_loss: 3.2489 - val_accuracy: 0.6374
Epoch 102/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0036 - accuracy: 0.9996 - val_loss: 3.2455 - val_accuracy: 0.6399
Epoch 103/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0019 - accuracy: 1.0000 - val_loss: 3.2159 - val_accuracy: 0.6434
Epoch 104/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 3.2219 - val_accuracy: 0.6451
Epoch 105/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0011 - accuracy: 1.0000 - val_loss: 3.2600 - val_accuracy: 0.6425
Epoch 106/200
293/293 [==============================] - 1s 4ms/step - loss: 8.9068e-04 - accuracy: 1.0000 - val_loss: 3.2839 - val_accuracy: 0.6430
Epoch 107/200
293/293 [==============================] - 1s 3ms/step - loss: 8.8232e-04 - accuracy: 1.0000 - val_loss: 3.2854 - val_accuracy: 0.6457
Epoch 108/200
293/293 [==============================] - 1s 3ms/step - loss: 6.3356e-04 - accuracy: 1.0000 - val_loss: 3.3021 - val_accuracy: 0.6451
Epoch 109/200
293/293 [==============================] - 1s 3ms/step - loss: 6.9747e-04 - accuracy: 1.0000 - val_loss: 3.3147 - val_accuracy: 0.6457
Epoch 110/200
293/293 [==============================] - 1s 3ms/step - loss: 5.3339e-04 - accuracy: 1.0000 - val_loss: 3.3380 - val_accuracy: 0.6454
Epoch 111/200
293/293 [==============================] - 1s 4ms/step - loss: 4.3812e-04 - accuracy: 1.0000 - val_loss: 3.3765 - val_accuracy: 0.6447
Epoch 112/200
293/293 [==============================] - 1s 3ms/step - loss: 4.1380e-04 - accuracy: 1.0000 - val_loss: 3.3854 - val_accuracy: 0.6460
Epoch 113/200
293/293 [==============================] - 1s 3ms/step - loss: 4.5256e-04 - accuracy: 1.0000 - val_loss: 3.4450 - val_accuracy: 0.6454
Epoch 114/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0543 - accuracy: 0.9867 - val_loss: 3.9023 - val_accuracy: 0.5873
Epoch 115/200
293/293 [==============================] - 1s 3ms/step - loss: 0.1773 - accuracy: 0.9468 - val_loss: 3.3988 - val_accuracy: 0.6180
Epoch 116/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0298 - accuracy: 0.9901 - val_loss: 3.3156 - val_accuracy: 0.6344
Epoch 117/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0111 - accuracy: 0.9966 - val_loss: 3.2564 - val_accuracy: 0.6354
Epoch 118/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0056 - accuracy: 0.9991 - val_loss: 3.3034 - val_accuracy: 0.6414
Epoch 119/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0030 - accuracy: 0.9996 - val_loss: 3.3015 - val_accuracy: 0.6431
Epoch 120/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0021 - accuracy: 0.9999 - val_loss: 3.3338 - val_accuracy: 0.6422
Epoch 121/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 3.3000 - val_accuracy: 0.6458
Epoch 122/200
293/293 [==============================] - 1s 3ms/step - loss: 9.2647e-04 - accuracy: 1.0000 - val_loss: 3.3178 - val_accuracy: 0.6457
Epoch 123/200
293/293 [==============================] - 1s 3ms/step - loss: 8.6534e-04 - accuracy: 1.0000 - val_loss: 3.3308 - val_accuracy: 0.6482
Epoch 124/200
293/293 [==============================] - 1s 3ms/step - loss: 6.3539e-04 - accuracy: 1.0000 - val_loss: 3.3759 - val_accuracy: 0.6479
Epoch 125/200
293/293 [==============================] - 1s 3ms/step - loss: 6.1540e-04 - accuracy: 1.0000 - val_loss: 3.3861 - val_accuracy: 0.6474
Epoch 126/200
293/293 [==============================] - 1s 3ms/step - loss: 5.1956e-04 - accuracy: 1.0000 - val_loss: 3.3955 - val_accuracy: 0.6460
Epoch 127/200
293/293 [==============================] - 1s 3ms/step - loss: 4.9951e-04 - accuracy: 1.0000 - val_loss: 3.5187 - val_accuracy: 0.6467
Epoch 128/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0084 - accuracy: 0.9978 - val_loss: 4.4000 - val_accuracy: 0.5709
Epoch 129/200
293/293 [==============================] - 1s 3ms/step - loss: 0.2555 - accuracy: 0.9294 - val_loss: 3.7346 - val_accuracy: 0.6086
Epoch 130/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0358 - accuracy: 0.9878 - val_loss: 3.4413 - val_accuracy: 0.6241
Epoch 131/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0122 - accuracy: 0.9971 - val_loss: 3.4731 - val_accuracy: 0.6366
Epoch 132/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0064 - accuracy: 0.9985 - val_loss: 3.3815 - val_accuracy: 0.6391
Epoch 133/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0040 - accuracy: 0.9993 - val_loss: 3.4015 - val_accuracy: 0.6413
Epoch 134/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0022 - accuracy: 0.9998 - val_loss: 3.4036 - val_accuracy: 0.6428
Epoch 135/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0015 - accuracy: 1.0000 - val_loss: 3.4494 - val_accuracy: 0.6402
Epoch 136/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0016 - accuracy: 0.9999 - val_loss: 3.4177 - val_accuracy: 0.6470
Epoch 137/200
293/293 [==============================] - 1s 4ms/step - loss: 8.6482e-04 - accuracy: 1.0000 - val_loss: 3.4614 - val_accuracy: 0.6452
Epoch 138/200
293/293 [==============================] - 1s 4ms/step - loss: 6.2557e-04 - accuracy: 1.0000 - val_loss: 3.4649 - val_accuracy: 0.6465
Epoch 139/200
293/293 [==============================] - 1s 4ms/step - loss: 6.4032e-04 - accuracy: 1.0000 - val_loss: 3.4694 - val_accuracy: 0.6444
Epoch 140/200
293/293 [==============================] - 1s 4ms/step - loss: 5.6854e-04 - accuracy: 1.0000 - val_loss: 3.5629 - val_accuracy: 0.6450
Epoch 141/200
293/293 [==============================] - 1s 4ms/step - loss: 5.7389e-04 - accuracy: 1.0000 - val_loss: 3.5976 - val_accuracy: 0.6391
Epoch 142/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0730 - accuracy: 0.9803 - val_loss: 3.6825 - val_accuracy: 0.6078
Epoch 143/200
293/293 [==============================] - 1s 4ms/step - loss: 0.1055 - accuracy: 0.9660 - val_loss: 3.4045 - val_accuracy: 0.6313
Epoch 144/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0229 - accuracy: 0.9920 - val_loss: 3.5080 - val_accuracy: 0.6366
Epoch 145/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0088 - accuracy: 0.9976 - val_loss: 3.4096 - val_accuracy: 0.6414
Epoch 146/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0040 - accuracy: 0.9994 - val_loss: 3.4737 - val_accuracy: 0.6436
Epoch 147/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0024 - accuracy: 0.9997 - val_loss: 3.5510 - val_accuracy: 0.6415
Epoch 148/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0017 - accuracy: 0.9998 - val_loss: 3.5113 - val_accuracy: 0.6454
Epoch 149/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 3.4949 - val_accuracy: 0.6440
Epoch 150/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0013 - accuracy: 0.9999 - val_loss: 3.5232 - val_accuracy: 0.6479
Epoch 151/200
293/293 [==============================] - 1s 4ms/step - loss: 9.0308e-04 - accuracy: 1.0000 - val_loss: 3.5569 - val_accuracy: 0.6500
Epoch 152/200
293/293 [==============================] - 1s 4ms/step - loss: 6.7638e-04 - accuracy: 1.0000 - val_loss: 3.5668 - val_accuracy: 0.6485
Epoch 153/200
293/293 [==============================] - 1s 3ms/step - loss: 6.3473e-04 - accuracy: 1.0000 - val_loss: 3.5986 - val_accuracy: 0.6502
Epoch 154/200
293/293 [==============================] - 1s 3ms/step - loss: 4.5041e-04 - accuracy: 1.0000 - val_loss: 3.5845 - val_accuracy: 0.6482
Epoch 155/200
293/293 [==============================] - 1s 3ms/step - loss: 4.0400e-04 - accuracy: 1.0000 - val_loss: 3.6154 - val_accuracy: 0.6446
Epoch 156/200
293/293 [==============================] - 1s 3ms/step - loss: 6.3062e-04 - accuracy: 1.0000 - val_loss: 3.7590 - val_accuracy: 0.6394
Epoch 157/200
293/293 [==============================] - 1s 3ms/step - loss: 0.1702 - accuracy: 0.9545 - val_loss: 3.6780 - val_accuracy: 0.6218
Epoch 158/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0543 - accuracy: 0.9818 - val_loss: 3.5992 - val_accuracy: 0.6316
Epoch 159/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0184 - accuracy: 0.9940 - val_loss: 3.7416 - val_accuracy: 0.6269
Epoch 160/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0061 - accuracy: 0.9988 - val_loss: 3.5675 - val_accuracy: 0.6382
Epoch 161/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0031 - accuracy: 0.9994 - val_loss: 3.6239 - val_accuracy: 0.6394
Epoch 162/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0021 - accuracy: 0.9997 - val_loss: 3.5901 - val_accuracy: 0.6389
Epoch 163/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0018 - accuracy: 0.9998 - val_loss: 3.6193 - val_accuracy: 0.6441
Epoch 164/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0012 - accuracy: 0.9999 - val_loss: 3.6285 - val_accuracy: 0.6438
Epoch 165/200
293/293 [==============================] - 1s 3ms/step - loss: 7.3676e-04 - accuracy: 1.0000 - val_loss: 3.6237 - val_accuracy: 0.6454
Epoch 166/200
293/293 [==============================] - 1s 3ms/step - loss: 5.4252e-04 - accuracy: 1.0000 - val_loss: 3.6783 - val_accuracy: 0.6441
Epoch 167/200
293/293 [==============================] - 1s 4ms/step - loss: 5.7206e-04 - accuracy: 1.0000 - val_loss: 3.6587 - val_accuracy: 0.6457
Epoch 168/200
293/293 [==============================] - 1s 3ms/step - loss: 5.1652e-04 - accuracy: 1.0000 - val_loss: 3.6578 - val_accuracy: 0.6467
Epoch 169/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0366 - accuracy: 0.9905 - val_loss: 4.2009 - val_accuracy: 0.5858
Epoch 170/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0951 - accuracy: 0.9697 - val_loss: 3.5935 - val_accuracy: 0.6283
Epoch 171/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0229 - accuracy: 0.9921 - val_loss: 3.6487 - val_accuracy: 0.6332
Epoch 172/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0066 - accuracy: 0.9983 - val_loss: 3.8405 - val_accuracy: 0.6329
Epoch 173/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0045 - accuracy: 0.9989 - val_loss: 3.6509 - val_accuracy: 0.6421
Epoch 174/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0023 - accuracy: 0.9997 - val_loss: 3.6520 - val_accuracy: 0.6394
Epoch 175/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0012 - accuracy: 1.0000 - val_loss: 3.6577 - val_accuracy: 0.6406
Epoch 176/200
293/293 [==============================] - 1s 4ms/step - loss: 6.8113e-04 - accuracy: 1.0000 - val_loss: 3.6778 - val_accuracy: 0.6417
Epoch 177/200
293/293 [==============================] - 1s 4ms/step - loss: 8.0387e-04 - accuracy: 1.0000 - val_loss: 3.6704 - val_accuracy: 0.6414
Epoch 178/200
293/293 [==============================] - 1s 3ms/step - loss: 6.4948e-04 - accuracy: 1.0000 - val_loss: 3.7070 - val_accuracy: 0.6417
Epoch 179/200
293/293 [==============================] - 1s 4ms/step - loss: 4.2156e-04 - accuracy: 1.0000 - val_loss: 3.7131 - val_accuracy: 0.6404
Epoch 180/200
293/293 [==============================] - 1s 4ms/step - loss: 5.7091e-04 - accuracy: 1.0000 - val_loss: 3.8784 - val_accuracy: 0.6346
Epoch 181/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0651 - accuracy: 0.9804 - val_loss: 4.3902 - val_accuracy: 0.5985
Epoch 182/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0633 - accuracy: 0.9788 - val_loss: 4.0795 - val_accuracy: 0.6162
Epoch 183/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0194 - accuracy: 0.9931 - val_loss: 4.0404 - val_accuracy: 0.6267
Epoch 184/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0076 - accuracy: 0.9979 - val_loss: 3.7312 - val_accuracy: 0.6357
Epoch 185/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0032 - accuracy: 0.9993 - val_loss: 3.7239 - val_accuracy: 0.6436
Epoch 186/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0016 - accuracy: 0.9997 - val_loss: 3.7600 - val_accuracy: 0.6344
Epoch 187/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0017 - accuracy: 0.9998 - val_loss: 3.7408 - val_accuracy: 0.6391
Epoch 188/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0012 - accuracy: 0.9999 - val_loss: 3.8037 - val_accuracy: 0.6379
Epoch 189/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0016 - accuracy: 0.9998 - val_loss: 3.8553 - val_accuracy: 0.6378
Epoch 190/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0050 - accuracy: 0.9984 - val_loss: 4.3736 - val_accuracy: 0.6046
Epoch 191/200
293/293 [==============================] - 1s 3ms/step - loss: 0.0759 - accuracy: 0.9757 - val_loss: 4.1114 - val_accuracy: 0.6099
Epoch 192/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0357 - accuracy: 0.9879 - val_loss: 4.0255 - val_accuracy: 0.6281
Epoch 193/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0121 - accuracy: 0.9959 - val_loss: 3.9506 - val_accuracy: 0.6345
Epoch 194/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0050 - accuracy: 0.9989 - val_loss: 3.9339 - val_accuracy: 0.6302
Epoch 195/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0030 - accuracy: 0.9994 - val_loss: 3.8998 - val_accuracy: 0.6373
Epoch 196/200
293/293 [==============================] - 1s 4ms/step - loss: 0.0015 - accuracy: 0.9997 - val_loss: 3.8751 - val_accuracy: 0.6438
Epoch 197/200
293/293 [==============================] - 1s 4ms/step - loss: 9.7056e-04 - accuracy: 1.0000 - val_loss: 3.8705 - val_accuracy: 0.6424
Epoch 198/200
293/293 [==============================] - 1s 3ms/step - loss: 5.4642e-04 - accuracy: 1.0000 - val_loss: 3.8781 - val_accuracy: 0.6441
Epoch 199/200
293/293 [==============================] - 1s 4ms/step - loss: 5.2847e-04 - accuracy: 1.0000 - val_loss: 3.8988 - val_accuracy: 0.6422
Epoch 200/200
293/293 [==============================] - 1s 3ms/step - loss: 5.0220e-04 - accuracy: 1.0000 - val_loss: 3.9146 - val_accuracy: 0.6449

<tensorflow.python.keras.callbacks.History at 0x7fc892648ca0>
# Baseline: accuracy of the trained RepVGG (multi-branch) model before fusion.
training_model.evaluate(X_test, y_test_oh)
313/313 [==============================] - 0s 1ms/step - loss: 4.0375 - accuracy: 0.6408

[4.037524223327637, 0.6407999992370605]
# Fold the trained RepVGG branches (3x3 conv + 1x1 conv + identity, each with BN)
# into single Conv2D layers of the plain VGG-style inference model.
# The third argument is the number of fused conv stages — presumably the network
# depth used when both models were built; TODO confirm against the builder code.
model = from_repvgg_to_vgg(training_model, inference_model, 3)
model.summary()
Fusion Conv-BN from main branch at depth 0
Fusion Conv-BN from alt branch at depth 0
Setting weights on inference model at depth 0
Fusion Conv-BN from main branch at depth 1
Fusion Conv-BN from alt branch at depth 1
Setting weights on inference model at depth 1
Fusion Conv-BN from main branch at depth 2
Fusion Conv-BN from alt branch at depth 2
Setting weights on inference model at depth 2
Model: "model_12"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_15 (InputLayer)        [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv_0 (Conv2D)              (None, 16, 16, 64)        1792      
_________________________________________________________________
re_lu_23 (ReLU)              (None, 16, 16, 64)        0         
_________________________________________________________________
conv_1 (Conv2D)              (None, 8, 8, 64)          36928     
_________________________________________________________________
re_lu_24 (ReLU)              (None, 8, 8, 64)          0         
_________________________________________________________________
conv_2 (Conv2D)              (None, 8, 8, 64)          36928     
_________________________________________________________________
re_lu_25 (ReLU)              (None, 8, 8, 64)          0         
_________________________________________________________________
flatten_7 (Flatten)          (None, 4096)              0         
_________________________________________________________________
dense (Dense)                (None, 10)                40970     
_________________________________________________________________
softmax_7 (Softmax)          (None, 10)                0         
=================================================================
Total params: 116,618
Trainable params: 116,618
Non-trainable params: 0
_________________________________________________________________

1
2
3
# Compile the fused inference model so it can be evaluated.
# The optimizer is irrelevant here (no further training); a near-zero
# learning rate keeps the weights effectively frozen.
# NOTE: `lr` is the deprecated alias of `learning_rate` in
# tf.keras.optimizers and was removed in newer TF releases — use the
# canonical keyword.
model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.SGD(learning_rate=2e-9),
              metrics=['accuracy'])
# Accuracy of the fused single-branch model on the test set.
model.evaluate(X_test, y_test_oh)
313/313 [==============================] - 0s 1ms/step - loss: 7.8139 - accuracy: 0.3986

[7.867744445800781, 0.39640000462532043]