diff --git a/models/Model_DEV_FUNIT/Model.py b/models/Model_DEV_FUNIT/Model.py index 0382d18..395c56b 100644 --- a/models/Model_DEV_FUNIT/Model.py +++ b/models/Model_DEV_FUNIT/Model.py @@ -57,7 +57,6 @@ class FUNITModel(ModelBase): class_downs=4, class_nf=64, class_latent=64, - mlp_nf=256, mlp_blks=2, dis_nf=64, dis_res_blks=10, diff --git a/models/Model_TrueFace/Model.py b/models/Model_TrueFace/Model.py index 3a82010..66678dc 100644 --- a/models/Model_TrueFace/Model.py +++ b/models/Model_TrueFace/Model.py @@ -54,7 +54,6 @@ class TrueFaceModel(ModelBase): class_downs=4, class_nf=64, class_latent=64, - mlp_nf=256, mlp_blks=2, dis_nf=64, dis_res_blks=10, diff --git a/nnlib/FUNIT.py b/nnlib/FUNIT.py index fe769ed..9df50c6 100644 --- a/nnlib/FUNIT.py +++ b/nnlib/FUNIT.py @@ -19,7 +19,6 @@ class FUNIT(object): class_downs=4, class_nf=64, class_latent=64, - mlp_nf=256, mlp_blks=2, dis_nf=64, dis_res_blks=10, @@ -41,7 +40,7 @@ class FUNIT(object): self.enc_content = modelify ( FUNIT.ContentEncoderFlow(downs=encoder_downs, nf=encoder_nf, n_res_blks=encoder_res_blk) ) ( Input(bgr_shape) ) self.enc_class_model = modelify ( FUNIT.ClassModelEncoderFlow(downs=class_downs, nf=class_nf, latent_dim=class_latent) ) ( Input(bgr_shape) ) - self.decoder = modelify ( FUNIT.DecoderFlow(ups=encoder_downs, n_res_blks=encoder_res_blk, mlp_nf=mlp_nf, mlp_blks=mlp_blks, subpixel_decoder=subpixel_decoder ) ) \ + self.decoder = modelify ( FUNIT.DecoderFlow(ups=encoder_downs, n_res_blks=encoder_res_blk, mlp_blks=mlp_blks, subpixel_decoder=subpixel_decoder ) ) \ ( [ Input(K.int_shape(self.enc_content.outputs[0])[1:], name="decoder_input_1"), Input(K.int_shape(self.enc_class_model.outputs[0])[1:], name="decoder_input_2") ] ) @@ -248,7 +247,7 @@ class FUNIT(object): return func @staticmethod - def DecoderFlow(ups, n_res_blks=2, mlp_nf=256, mlp_blks=2, subpixel_decoder=False ): + def DecoderFlow(ups, n_res_blks=2, mlp_blks=2, subpixel_decoder=False ): exec (nnlib.import_all(), locals(), globals()) @@ -273,7 +272,7 @@ class FUNIT(object): ### MLP block inside decoder mlp = class_code for i in range(mlp_blks): - mlp = Dense(mlp_nf, activation='relu')(mlp) + mlp = Dense(nf, activation='relu')(mlp) for i in range(n_res_blks): x = ResBlock(nf)( [x,mlp] )