This commit is contained in:
iperov 2019-01-03 23:02:38 +04:00
parent 8456122cb2
commit 69d5b1e662

View file

@ -91,11 +91,11 @@ class UFMModel(ModelBase):
pred_src_dst = self.decoder_src(warped_dst_code)
pred_src_dstm = self.decoder_srcm(warped_dst_code)
target_srcm_blurred = tf_gaussian_blur(4.0)(target_srcm)
target_srcm_blurred = tf_gaussian_blur(resolution // 32)(target_srcm)
target_srcm_sigm = target_srcm_blurred / 2.0 + 0.5
target_srcm_anti_sigm = 1.0 - target_srcm_sigm
target_dstm_blurred = tf_gaussian_blur(4.0)(target_dstm)
target_dstm_blurred = tf_gaussian_blur(resolution // 32)(target_dstm)
target_dstm_sigm = target_dstm_blurred / 2.0 + 0.5
target_dstm_anti_sigm = 1.0 - target_dstm_sigm
@ -249,7 +249,7 @@ class UFMModel(ModelBase):
@staticmethod
def EncFlow(ngf=64, num_downs=4, lowest_dense=512):
exec (nnlib.import_all(), locals(), globals())
use_bias = True
def XNormalization(x):
return InstanceNormalization (axis=3, gamma_initializer=RandomNormal(1., 0.02))(x)
@ -264,7 +264,7 @@ class UFMModel(ModelBase):
for i in range(num_downs):
x = LeakyReLU(0.1)(XNormalization(Conv2D( min(ngf* (2**i), ngf*8) , 5, 2, 'same')(x)))
if i == num_downs-1:
if i == 3:
x_shape = K.int_shape(x)[1:]
x = Reshape(x_shape)(Dense( np.prod(x_shape) )(Dense(lowest_dense)(Flatten()(x))))
result += [x]