diff --git a/core/leras/layers/__init__.py b/core/leras/layers/__init__.py
index 5c9d43b..b2f81f6 100644
--- a/core/leras/layers/__init__.py
+++ b/core/leras/layers/__init__.py
@@ -16,3 +16,4 @@ from .ScaleAdd import *
 from .DenseNorm import *
 from .AdaIN import *
 from .MsSsim import *
+from .TanhPolar import *
diff --git a/core/leras/models/PatchDiscriminator.py b/core/leras/models/PatchDiscriminator.py
index dcaf941..90f1354 100644
--- a/core/leras/models/PatchDiscriminator.py
+++ b/core/leras/models/PatchDiscriminator.py
@@ -130,12 +130,14 @@ class UNetPatchDiscriminator(nn.ModelBase):
             q=x[np.abs(np.array(x)-target_patch_size).argmin()]
             return s[q][2]
 
-    def on_build(self, patch_size, in_ch, base_ch = 16):
+    def on_build(self, patch_size, in_ch, base_ch = 16, use_fp16 = False):
+        self.use_fp16 = use_fp16
+        conv_dtype = tf.float16 if use_fp16 else tf.float32
 
         class ResidualBlock(nn.ModelBase):
             def on_build(self, ch, kernel_size=3 ):
-                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
-                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
+                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
+                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
 
             def forward(self, inp):
                 x = self.conv1(inp)
@@ -155,26 +157,29 @@ class UNetPatchDiscriminator(nn.ModelBase):
 
         level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
 
-        self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID')
+        self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
 
         for i, (kernel_size, strides) in enumerate(layers):
-            self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME') )
+            self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
 
             self.res1.append ( ResidualBlock(level_chs[i]) )
             self.res2.append ( ResidualBlock(level_chs[i]) )
 
-            self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME') )
+            self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
 
             self.upres1.insert (0, ResidualBlock(level_chs[i-1]*2) )
             self.upres2.insert (0, ResidualBlock(level_chs[i-1]*2) )
 
-        self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID')
+        self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
 
-        self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID')
-        self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID')
+        self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
+        self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
 
     def forward(self, x):
+        if self.use_fp16:
+            x = tf.cast(x, tf.float16)
+
         x = tf.nn.leaky_relu( self.in_conv(x), 0.2 )
 
         encs = []
 
@@ -192,7 +197,13 @@ class UNetPatchDiscriminator(nn.ModelBase):
 
             x = upres1(x)
             x = upres2(x)
 
-        return center_out, self.out_conv(x)
+        x = self.out_conv(x)
+
+        if self.use_fp16:
+            center_out = tf.cast(center_out, tf.float32)
+            x = tf.cast(x, tf.float32)
+
+        return center_out, x
 
 nn.UNetPatchDiscriminator = UNetPatchDiscriminator
@@ -250,11 +261,14 @@ class UNetPatchDiscriminatorV2(nn.ModelBase):
             q=x[np.abs(np.array(x)-target_patch_size).argmin()]
             return s[q][2]
 
-    def on_build(self, patch_size, in_ch):
+    def on_build(self, patch_size, in_ch, use_fp16 = False):
+        self.use_fp16 = use_fp16
+        conv_dtype = tf.float16 if use_fp16 else tf.float32
+
         class ResidualBlock(nn.ModelBase):
             def on_build(self, ch, kernel_size=3 ):
-                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
-                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME')
+                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
+                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
 
             def forward(self, inp):
                 x = self.conv1(inp)
@@ -273,24 +287,27 @@ class UNetPatchDiscriminatorV2(nn.ModelBase):
 
         level_chs = { i-1:v for i,v in enumerate([ min( base_ch * (2**i), 512 ) for i in range(len(layers)+1)]) }
 
-        self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID')
+        self.in_conv = nn.Conv2D( in_ch, level_chs[-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
 
         for i, (kernel_size, strides) in enumerate(layers):
-            self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME') )
+            self.convs.append ( nn.Conv2D( level_chs[i-1], level_chs[i], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
 
             self.res.append ( ResidualBlock(level_chs[i]) )
 
-            self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME') )
+            self.upconvs.insert (0, nn.Conv2DTranspose( level_chs[i]*(2 if i != len(layers)-1 else 1), level_chs[i-1], kernel_size=kernel_size, strides=strides, padding='SAME', dtype=conv_dtype) )
 
             self.upres.insert (0, ResidualBlock(level_chs[i-1]*2) )
 
-        self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID')
+        self.out_conv = nn.Conv2D( level_chs[-1]*2, 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
 
-        self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID')
-        self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID')
+        self.center_out = nn.Conv2D( level_chs[len(layers)-1], 1, kernel_size=1, padding='VALID', dtype=conv_dtype)
+        self.center_conv = nn.Conv2D( level_chs[len(layers)-1], level_chs[len(layers)-1], kernel_size=1, padding='VALID', dtype=conv_dtype)
 
     def forward(self, x):
+        if self.use_fp16:
+            x = tf.cast(x, tf.float16)
+
         x = tf.nn.leaky_relu( self.in_conv(x), 0.1 )
 
         encs = []
 
@@ -306,6 +323,12 @@ class UNetPatchDiscriminatorV2(nn.ModelBase):
 
             x = tf.concat( [enc, x], axis=nn.conv2d_ch_axis)
             x = upres(x)
 
-        return center_out, self.out_conv(x)
+        x = self.out_conv(x)
+
+        if self.use_fp16:
+            center_out = tf.cast(center_out, tf.float32)
+            x = tf.cast(x, tf.float32)
+        return center_out, x
+
 nn.UNetPatchDiscriminatorV2 = UNetPatchDiscriminatorV2
diff --git a/core/leras/nn.py b/core/leras/nn.py
index 7c28874..07cbdf6 100644
--- a/core/leras/nn.py
+++ b/core/leras/nn.py
@@ -107,7 +107,7 @@ class nn():
             else:
                 nn.tf_default_device_name = f'/{device_config.devices[0].tf_dev_type}:0'
 
-            config = tf.ConfigProto(allow_soft_placement=True)
+            config = tf.ConfigProto()
             config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices])
             config.gpu_options.force_gpu_compatible = True
 
diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py
index 2e95504..cff687e 100644
--- a/models/Model_AMP/Model.py
+++ b/models/Model_AMP/Model.py
@@ -334,7 +334,7 @@ class AMPModel(ModelBase):
             self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
 
         if gan_power != 0:
-            self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
+            self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], use_fp16=use_fp16, name="GAN")
             self.GAN_opt = nn.AdaBelief(lr=5e-5, lr_dropout=lr_dropout, clipnorm=clipnorm, name='GAN_opt')
             self.GAN_opt.initialize_variables ( self.GAN.get_weights(), vars_on_cpu=optimizer_vars_on_cpu)
             self.model_filename_list += [ [self.GAN, 'GAN.npy'],
@@ -676,7 +676,7 @@ class AMPModel(ModelBase):
                                   name='AMP',
                                   input_names=['in_face:0','morph_value:0'],
                                   output_names=['out_face_mask:0','out_celeb_face:0','out_celeb_face_mask:0'],
-                                  opset=9,
+                                  opset=11,
                                   output_path=output_path)
 
     #override
diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py
index b2703af..ac16fe1 100644
--- a/models/Model_SAEHD/Model.py
+++ b/models/Model_SAEHD/Model.py
@@ -343,10 +343,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
         if self.is_training:
             if gan_power != 0:
                 if self.options['gan_version'] == 2:
-                    self.D_src = nn.UNetPatchDiscriminatorV2(patch_size=resolution//16, in_ch=input_ch, name="D_src")
+                    self.D_src = nn.UNetPatchDiscriminatorV2(patch_size=resolution//16, in_ch=input_ch, name="D_src", use_fp16=self.options['use_fp16'])
                     self.model_filename_list += [ [self.D_src, 'D_src_v2.npy'] ]
                 else:
-                    self.D_src = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="D_src")
+                    self.D_src = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], use_fp16=self.options['use_fp16'], name="D_src")
                     self.model_filename_list += [ [self.D_src, 'GAN.npy'] ]
 
         # Initialize optimizers
@@ -928,7 +928,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
 
                 target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
                 target_srcm = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
-                target_srcm_em = np.stack( [ x[2] for x in src_samples_loss[:bs] ] ) 
+                target_srcm_em = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
 
                 target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
                 target_dstm = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )