SAE: revert back df model from prev commit. LIAE should be restarted.

This commit is contained in:
iperov 2019-04-21 08:31:05 +04:00
parent 2cdf2745a2
commit 63beb3afd2

View file

@ -127,7 +127,7 @@ class SAEModel(ModelBase):
padding = 'reflect' if self.options['remove_gray_border'] else 'zero'
common_flow_kwargs = { 'padding': padding,
'norm': 'bn',
'norm': '',
'act':'' }
weights_to_load = []
@ -486,9 +486,9 @@ class SAEModel(ModelBase):
return x
SAEModel.ResidualBlock = ResidualBlock
def downscale (dim, padding='zero', norm='', act='', kernel_regularizer=None, **kwargs):
def downscale (dim, padding='zero', norm='', act='', **kwargs):
def func(x):
return Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=2, padding=padding, kernel_regularizer=kernel_regularizer)(x)) )
return Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=2, padding=padding)(x)) )
return func
SAEModel.downscale = downscale
@ -508,7 +508,7 @@ class SAEModel(ModelBase):
def LIAEEncFlow(resolution, ch_dims, **kwargs):
exec (nnlib.import_all(), locals(), globals())
upscale = partial(SAEModel.upscale, **kwargs)
downscale = partial(SAEModel.downscale, kernel_regularizer=keras.regularizers.l2(0.0), **kwargs)
downscale = partial(SAEModel.downscale, **kwargs)
def func(input):
dims = K.int_shape(input)[-1]*ch_dims
@ -531,12 +531,8 @@ class SAEModel(ModelBase):
def func(input):
x = input[0]
#https://arxiv.org/abs/1807.01442 https://github.com/aditya-grover/uae
x = Dense(ae_dims, use_bias=False)(x)
x = Lambda ( lambda x: x + 0.1*K.random_normal(K.shape(x), 0, 1) , output_shape=(None,ae_dims) ) (x)
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2, use_bias=False)(x)
x = Dense(ae_dims)(x)
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
x = upscale(ae_dims*2)(x)
return x
@ -587,7 +583,7 @@ class SAEModel(ModelBase):
def DFEncFlow(resolution, ae_dims, ch_dims, **kwargs):
exec (nnlib.import_all(), locals(), globals())
upscale = partial(SAEModel.upscale, **kwargs)
downscale = partial(SAEModel.downscale, kernel_regularizer=keras.regularizers.l2(0.0), **kwargs)
downscale = partial(SAEModel.downscale, **kwargs)#, kernel_regularizer=keras.regularizers.l2(0.0),
lowest_dense_res = resolution // 16
def func(input):
@ -599,11 +595,8 @@ class SAEModel(ModelBase):
x = downscale(dims*4)(x)
x = downscale(dims*8)(x)
#https://arxiv.org/abs/1807.01442 https://github.com/aditya-grover/uae
x = Dense(ae_dims, use_bias=False)(Flatten()(x))
x = Lambda ( lambda x: x + 0.1*K.random_normal(K.shape(x), 0, 1) , output_shape=(None,ae_dims) ) (x)
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims, use_bias=False)(x)
x = Dense(ae_dims)(Flatten()(x))
x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
x = upscale(ae_dims)(x)
return x