feature: adds GANv2

This commit is contained in:
jh 2021-03-21 23:13:32 -07:00
commit 7bc5bf2b94
2 changed files with 138 additions and 11 deletions

View file

@ -195,3 +195,117 @@ class UNetPatchDiscriminator(nn.ModelBase):
return center_out, self.out_conv(x) return center_out, self.out_conv(x)
nn.UNetPatchDiscriminator = UNetPatchDiscriminator nn.UNetPatchDiscriminator = UNetPatchDiscriminator
class UNetPatchDiscriminatorV2(nn.ModelBase):
"""
Inspired by https://arxiv.org/abs/2002.12655 "A U-Net Based Discriminator for Generative Adversarial Networks"
"""
def calc_receptive_field_size(self, layers):
"""
result the same as https://fomoro.com/research/article/receptive-field-calculatorindex.html
"""
rf = 0
ts = 1
for i, (k, s) in enumerate(layers):
if i == 0:
rf = k
else:
rf += (k-1)*ts
ts *= s
return rf
def find_archi(self, target_patch_size, max_layers=6):
"""
Find the best configuration of layers using only 3x3 convs for target patch size
"""
s = {}
for layers_count in range(1,max_layers+1):
val = 1 << (layers_count-1)
while True:
val -= 1
layers = []
sum_st = 0
for i in range(layers_count-1):
st = 1 + (1 if val & (1 << i) !=0 else 0 )
layers.append ( [3, st ])
sum_st += st
layers.append ( [3, 2])
sum_st += 2
rf = self.calc_receptive_field_size(layers)
s_rf = s.get(rf, None)
if s_rf is None:
s[rf] = (layers_count, sum_st, layers)
else:
if layers_count < s_rf[0] or \
( layers_count == s_rf[0] and sum_st > s_rf[1] ):
s[rf] = (layers_count, sum_st, layers)
if val == 0:
break
x = sorted(list(s.keys()))
q=x[np.abs(np.array(x)-target_patch_size).argmin()]
return s[q][2]
def on_build(self, patch_size, in_ch):
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
def forward(self, inp):
x = self.conv1(inp)
x = tf.nn.leaky_relu(x, 0.2)
x = self.conv2(x)
x = tf.nn.leaky_relu(inp + x, 0.2)
return x
prev_ch = in_ch
self.convs = []
self.res = []
self.upconvs = []
self.upres = []
layers = self.find_archi(patch_size)
base_ch = 16
level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID')
for i, (kernel_size, strides) in enumerate(layers):
self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME') )
self.res.append ( ResidualBlock(level_chs[i]) )
self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME') )
self.upres.insert (0, ResidualBlock(level_chs[i-1]*2) )
self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID')
self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID')
self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID')
def forward(self, x):
x = tf.nn.leaky_relu( self.in_conv(x), 0.1 )
encs = []
for conv, res in zip(self.convs, self.res):
encs.insert(0, x)
x = tf.nn.leaky_relu( conv(x), 0.1 )
x = res(x)
center_out, x = self.center_out(x), self.center_conv(x)
for i, (upconv, enc, upres) in enumerate(zip(self.upconvs, encs, self.upres)):
x = tf.nn.leaky_relu( upconv(x), 0.1 )
x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
x = upres(x)
return center_out, self.out_conv(x)
nn.UNetPatchDiscriminatorV2 = UNetPatchDiscriminatorV2

View file

@ -138,6 +138,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.') self.options['uniform_yaw'] = io.input_bool ("Uniform yaw distribution of samples", default_uniform_yaw, help_message='Helps to fix blurry side faces due to small amount of them in the faceset.')
default_gan_version = self.options['gan_version'] = self.load_or_def_option('gan_version', 2)
default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0) default_gan_power = self.options['gan_power'] = self.load_or_def_option('gan_power', 0.0)
default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8) default_gan_patch_size = self.options['gan_patch_size'] = self.load_or_def_option('gan_patch_size', self.options['resolution'] // 8)
default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16) default_gan_dims = self.options['gan_dims'] = self.load_or_def_option('gan_dims', 16)
@ -154,6 +155,9 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 1.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with lr_dropout(on) and random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 1.0 ) self.options['gan_power'] = np.clip ( io.input_number ("GAN power", default_gan_power, add_info="0.0 .. 1.0", help_message="Forces the neural network to learn small details of the face. Enable it only when the face is trained enough with lr_dropout(on) and random_warp(off), and don't disable. The higher the value, the higher the chances of artifacts. Typical fine value is 0.1"), 0.0, 1.0 )
if self.options['gan_power'] != 0.0: if self.options['gan_power'] != 0.0:
gan_version = np.clip (io.input_int("GAN version", default_gan_version, add_info="2 or 3", help_message="Choose GAN version (v2: 7/16/2020, v3: 1/3/2021):"), 2, 3)
self.options['gan_version'] = gan_version
gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 ) gan_patch_size = np.clip ( io.input_int("GAN patch size", default_gan_patch_size, add_info="3-640", help_message="The higher patch size, the higher the quality, the more VRAM is required. You can get sharper edges even at the lowest setting. Typical fine value is resolution / 8." ), 3, 640 )
self.options['gan_patch_size'] = gan_patch_size self.options['gan_patch_size'] = gan_patch_size
@ -299,6 +303,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if self.is_training: if self.is_training:
if gan_power != 0: if gan_power != 0:
if self.options['gan_version'] == 2:
self.D_src = nn.UNetPatchDiscriminatorV2(patch_size=resolution//16, in_ch=input_ch, name="D_src")
self.model_filename_list += [ [self.D_src, 'D_src_v2.npy'] ]
else:
self.D_src = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="D_src") self.D_src = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="D_src")
self.model_filename_list += [ [self.D_src, 'GAN.npy'] ] self.model_filename_list += [ [self.D_src, 'GAN.npy'] ]
@ -325,6 +333,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ] self.model_filename_list += [ (self.D_code_opt, 'D_code_opt.npy') ]
if gan_power != 0: if gan_power != 0:
if self.options['gan_version'] == 2:
self.D_src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='D_src_dst_opt')
self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights()
self.model_filename_list += [ (self.D_src_dst_opt, 'D_src_v2_opt.npy') ]
else:
self.D_src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt') self.D_src_dst_opt = OptimizerClass(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt')
self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights() self.D_src_dst_opt.initialize_variables ( self.D_src.get_weights(), vars_on_cpu=optimizer_vars_on_cpu, lr_dropout_on_cpu=self.options['lr_dropout']=='cpu')#+self.D_src_x2.get_weights()
self.model_filename_list += [ (self.D_src_dst_opt, 'GAN_opt.npy') ] self.model_filename_list += [ (self.D_src_dst_opt, 'GAN_opt.npy') ]