diff --git a/models/Model_SAEv2/Model.py b/models/Model_SAEv2/Model.py index 1fc5e78..78d311c 100644 --- a/models/Model_SAEv2/Model.py +++ b/models/Model_SAEv2/Model.py @@ -11,7 +11,7 @@ from samplelib import * #SAE - Styled AutoEncoder -class SAEModel(ModelBase): +class SAEv2Model(ModelBase): #override def onInitializeOptions(self, is_first_run, ask_override): @@ -141,9 +141,9 @@ class SAEModel(ModelBase): return SubpixelDownscaler()(LeakyReLU(0.1)(Conv2D(dim // 4, kernel_size=kernel_size, strides=1, dilation_rate=dilation_rate, padding='same')(x))) return func - def upscale (dim): + def upscale (dim, size=(2,2)): def func(x): - return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same')(x))) + return SubpixelUpscaler(size=size)(LeakyReLU(0.1)(Conv2D(dim * np.prod(size) , kernel_size=3, strides=1, padding='same')(x))) return func def enc_flow(e_dims, ae_dims, lowest_dense_res): @@ -188,8 +188,36 @@ class SAEModel(ModelBase): x = LeakyReLU(0.2)(x) return x return func - + """ def func(x): + x0 = upscale(dims, size=(8,8) )(x) + + x = upscale(dims*4)(x) + + if add_residual_blocks: + x = ResidualBlock(dims*4)(x) + + x1 = upscale(dims, size=(4,4) )(x) + + x = upscale(dims*2)(x) + + if add_residual_blocks: + x = ResidualBlock(dims*2)(x) + + x2 = upscale(dims, size=(2,2) )(x) + + x = upscale(dims)(x) + + if add_residual_blocks: + x = ResidualBlock(dims)(x) + + x = Add()([x0,x1,x2,x]) + + + return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x) + """ + def func(x): + x = upscale(dims*8)(x) if add_residual_blocks: @@ -206,6 +234,7 @@ class SAEModel(ModelBase): x = ResidualBlock(dims*2)(x) return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x) + return func self.encoder = modelify(enc_flow(e_dims, ae_dims, lowest_dense_res)) ( Input(bgr_shape) ) @@ -470,7 +499,7 @@ class SAEModel(ModelBase): if self.is_training_mode: self.src_dst_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1) self.src_dst_mask_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1) - self.D_opt = RMSprop(lr=1e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1) + self.D_opt = RMSprop(lr=5e-5, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1) if not self.options['pixel_loss']: src_loss = K.mean ( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_opt, pred_src_src_masked_opt) ) @@ -503,7 +532,7 @@ class SAEModel(ModelBase): src_code_d_zeros = K.zeros_like(src_code_d) dst_code_d = self.dis( self.model.dst_code ) dst_code_d_ones = K.ones_like(dst_code_d) - opt_D_loss = [ 0.2*DLoss(src_code_d_ones, src_code_d) ] + opt_D_loss = [ DLoss(src_code_d_ones, src_code_d) ] loss_D = (DLoss(dst_code_d_ones , dst_code_d) + \ DLoss(src_code_d_zeros, src_code_d) ) * 0.5 @@ -655,4 +684,4 @@ class SAEModel(ModelBase): clip_hborder_mask_per=0.0625 if (face_type == FaceType.FULL) else 0, ) -Model = SAEModel +Model = SAEv2Model