TP Module 7 : Les modèles générateurs
Import Data
VAE
Blocs de base
Encodeur
Décodeur
Custom Class
GAN
Blocs de base
Critique
Générateur
Custom Class
Détails
Init
Ici, rien de particuliers, on définit les variables de notre classe.
Compile
Comme l'on a deux réseaux qui s'entraînent de façon séparée, on a deux optimiseurs et deux fonctions de pertes. On doit donc modifier la méthode .compile()
.
Gradient penalty
La pénalité du gradient est un des points clé faisant que les GAN de Wasserstein s'entraînent de façon correcte sans mode collapse.
Avec \(\tilde{x} = G(z)\), \(x\) une observation du minibatch, et \(\varepsilon \sim \mathcal{N}(0, 1)\).
gradient_penalty
est alors de calculer la moyenne suivante.
Rappelons tout d'abord que nous travaillons sur des minibatchs, ie calculer \(\mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}(-)\) revient donc à calculer l'estimation non biaisée correspondante sur le minibatch \(M\) en cours, c'est à dire calculer la moyenne suivante.
Donc, pour chaque observation \(x\) du minibatch \(M\) : 1. on l'interpole en \(\hat{x}\), 2. on calcule le gradient \(\nabla_{\hat{x}}D(\hat{x})\) par rapport à \(\hat{x}\), 3. On calcule la norme \(L_{2}\), \(||\nabla_{\hat{x}}D(\hat{x})||_{2}\), 4. On fait la somme \(\sum_{x \in M} (||\nabla_{\hat{x}}D(\hat{x})||_{2}-1)^{2}\), 5. On divise.
Rappellons que pour un tenseur \(T= (t_{i,j,k})_{i,j,k}\) sur 3 axes (ici une image), sa norme \(L_{2}\) est définie de façon usuelle par la formule suivante.
Pour calculer un gradient on utilise with tf.GradientTape() as gp_tape:
et on lui dit quel variable surveiller avec gp_tape.watch(interpolated)
.
Fonction entière :
Pour l'instant, cette fonction n'est pas utilisée, elle est utilisée dans la section suivante.
Train step
Dans l'étape d'entraînement, on fait les choses suivantes.
Pour chaque minibatch,
- On entraîne le générateur et on calcule la perte associée.
- On entraîne le critique et on calcule la perte associée.
- On calcule la pénalité du gradient.
- On multiplie cette pénalité par une contante.
- On ajoute cette pénalité à la perte du critique.
- On retourne ces métriques dans un dictionnaire.
Rappelons que les pertes pour le critique \(D\), \(\mathcal{L}_{D}^{WGAN}\), et le générateur \(G\), \(\mathcal{L}_{G}^{WGAN}\), on les formules suivantes.
Un fois la pénalité de gradient ajoutée, on obtient les formules suivantes.
Avec \(\hat{x} := \tilde{x} + \varepsilon(x - \tilde{x})\), avec \(\tilde{x} = G(z)\), et \(\varepsilon \sim \mathcal{N}(0, 1)\).
Détaillons étapes par étapes.
On entraîne d'abord le critique, l'article d'origine suggère d'entraîner le critique plus longtemps que le générateur. Ici, on l'entraînera 3 étapes pour une étape de générateur.
On passe alors maintenant à l'entraînement du générateur, qui se déroule de la même façon.
On renvoit les nouvelles métriques sous la forme d'un dictionaire.
On doit maintenant définir les fonctions de pertes que l'on va calculer. On rappelle que l'on a les formules suivantes.
\(D(G(z))\) étant défini comme fake_images
et \(D(x)\) comme real_images
plus haut dans la fonction train_step
, il reste juste à prendre la moyenne. La pénalité du gradient étant directement réjoutée dans train_step
.