diff --git a/models/Model_UFM/Model.py b/models/Model_UFM/Model.py index 323d675..030ff0d 100644 --- a/models/Model_UFM/Model.py +++ b/models/Model_UFM/Model.py @@ -91,11 +91,11 @@ class UFMModel(ModelBase): pred_src_dst = self.decoder_src(warped_dst_code) pred_src_dstm = self.decoder_srcm(warped_dst_code) - target_srcm_blurred = tf_gaussian_blur(4.0)(target_srcm) + target_srcm_blurred = tf_gaussian_blur(resolution // 32)(target_srcm) target_srcm_sigm = target_srcm_blurred / 2.0 + 0.5 target_srcm_anti_sigm = 1.0 - target_srcm_sigm - target_dstm_blurred = tf_gaussian_blur(4.0)(target_dstm) + target_dstm_blurred = tf_gaussian_blur(resolution // 32)(target_dstm) target_dstm_sigm = target_dstm_blurred / 2.0 + 0.5 target_dstm_anti_sigm = 1.0 - target_dstm_sigm @@ -249,7 +249,7 @@ class UFMModel(ModelBase): @staticmethod def EncFlow(ngf=64, num_downs=4, lowest_dense=512): exec (nnlib.import_all(), locals(), globals()) - + use_bias = True def XNormalization(x): return InstanceNormalization (axis=3, gamma_initializer=RandomNormal(1., 0.02))(x) @@ -264,7 +264,7 @@ class UFMModel(ModelBase): for i in range(num_downs): x = LeakyReLU(0.1)(XNormalization(Conv2D( min(ngf* (2**i), ngf*8) , 5, 2, 'same')(x))) - if i == num_downs-1: + if i == 3: x_shape = K.int_shape(x)[1:] x = Reshape(x_shape)(Dense( np.prod(x_shape) )(Dense(lowest_dense)(Flatten()(x)))) result += [x]