diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 78221c0..2f5f5b4 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -128,7 +128,7 @@ class SAEModel(ModelBase): padding = 'reflect' if self.options['remove_gray_border'] else 'zero' common_flow_kwargs = { 'padding': padding, 'norm': 'bn', - 'act':'prelu' } + 'act':'' } weights_to_load = [] if self.options['archi'] == 'liae': @@ -370,8 +370,8 @@ class SAEModel(ModelBase): dst_samples = generators_samples[1] feed = [src_samples[0], dst_samples[0] ] + \ - src_samples[1:1+self.ms_count*2] + \ - dst_samples[1:1+self.ms_count*2] + src_samples[1:1+self.ms_count*2] + \ + dst_samples[1:1+self.ms_count*2] src_loss, dst_loss, = self.src_dst_train (feed) @@ -467,7 +467,7 @@ class SAEModel(ModelBase): return LeakyReLU(alpha=lrelu_alpha) class ResidualBlock(object): - def __init__(self, filters, kernel_size=3, padding='zero', use_reflection_padding=False, norm='', act='', **kwargs): + def __init__(self, filters, kernel_size=3, padding='zero', norm='', act='', **kwargs): self.filters = filters self.kernel_size = kernel_size self.padding = padding @@ -486,9 +486,9 @@ class SAEModel(ModelBase): return x SAEModel.ResidualBlock = ResidualBlock - def downscale (dim, padding='zero', norm='', act='', **kwargs): + def downscale (dim, padding='zero', norm='', act='', kernel_regularizer=None, **kwargs): def func(x): - return Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=2, padding=padding)(x)) ) + return Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=2, padding=padding, kernel_regularizer=kernel_regularizer)(x)) ) return func SAEModel.downscale = downscale @@ -508,8 +508,8 @@ class SAEModel(ModelBase): def LIAEEncFlow(resolution, ch_dims, **kwargs): exec (nnlib.import_all(), locals(), globals()) upscale = partial(SAEModel.upscale, **kwargs) - downscale = partial(SAEModel.downscale, **kwargs) - + downscale = partial(SAEModel.downscale, kernel_regularizer=keras.regularizers.l2(0.0), **kwargs) + def func(input): dims = K.int_shape(input)[-1]*ch_dims @@ -531,8 +531,12 @@ class SAEModel(ModelBase): def func(input): x = input[0] - x = Dense(ae_dims)(x) - x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x) + + #https://arxiv.org/abs/1807.01442 https://github.com/aditya-grover/uae + x = Dense(ae_dims, use_bias=False)(x) + x = Lambda ( lambda x: x + 0.1*K.random_normal(K.shape(x), 0, 1) , output_shape=(None,ae_dims) ) (x) + x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2, use_bias=False)(x) + x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x) x = upscale(ae_dims*2)(x) return x @@ -580,37 +584,38 @@ class SAEModel(ModelBase): return func @staticmethod - def DFEncFlow(resolution, ae_dims, ch_dims, padding='zero', **kwargs): + def DFEncFlow(resolution, ae_dims, ch_dims, **kwargs): exec (nnlib.import_all(), locals(), globals()) - upscale = partial(SAEModel.upscale, padding=padding) - downscale = partial(SAEModel.downscale, padding=padding) + upscale = partial(SAEModel.upscale, **kwargs) + downscale = partial(SAEModel.downscale, kernel_regularizer=keras.regularizers.l2(0.0), **kwargs) lowest_dense_res = resolution // 16 def func(input): x = input dims = K.int_shape(input)[-1]*ch_dims - x = downscale(dims)(x) x = downscale(dims*2)(x) x = downscale(dims*4)(x) x = downscale(dims*8)(x) - - x = Dense(ae_dims)(Flatten()(x)) - x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x) + + #https://arxiv.org/abs/1807.01442 https://github.com/aditya-grover/uae + x = Dense(ae_dims, use_bias=False)(Flatten()(x)) + x = Lambda ( lambda x: x + 0.1*K.random_normal(K.shape(x), 0, 1) , output_shape=(None,ae_dims) ) (x) + x = Dense(lowest_dense_res * lowest_dense_res * ae_dims, use_bias=False)(x) + x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x) x = upscale(ae_dims)(x) - return x return func @staticmethod - def DFDecFlow(output_nc, ch_dims, multiscale_count=1, add_residual_blocks=False, padding='zero', **kwargs): + def DFDecFlow(output_nc, ch_dims, multiscale_count=1, add_residual_blocks=False, **kwargs): exec (nnlib.import_all(), locals(), globals()) - upscale = partial(SAEModel.upscale, padding=padding) - to_bgr = partial(SAEModel.to_bgr, padding=padding) + upscale = partial(SAEModel.upscale, **kwargs) + to_bgr = partial(SAEModel.to_bgr, **kwargs) dims = output_nc * ch_dims - ResidualBlock = partial(SAEModel.ResidualBlock, padding=padding) + ResidualBlock = partial(SAEModel.ResidualBlock, **kwargs) def func(input): x = input[0]