mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
converter:
fixed crashes removed useless 'ebs' color transfer changed keys for color degrade added image degrade via denoise - same as denoise extracted data_dst.bat , but you can control this option directly in the interactive converter added image degrade via bicubic downscale and upscale SAEHD: default ae_dims for df now 256.
This commit is contained in:
parent
374d8c2388
commit
770c70d778
8 changed files with 274 additions and 57 deletions
128
nnlib/nnlib.py
128
nnlib/nnlib.py
|
@ -95,6 +95,7 @@ gaussian_blur = nnlib.gaussian_blur
|
|||
style_loss = nnlib.style_loss
|
||||
dssim = nnlib.dssim
|
||||
|
||||
DenseMaxout = nnlib.DenseMaxout
|
||||
PixelShuffler = nnlib.PixelShuffler
|
||||
SubpixelUpscaler = nnlib.SubpixelUpscaler
|
||||
SubpixelDownscaler = nnlib.SubpixelDownscaler
|
||||
|
@ -911,7 +912,134 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
base_config = super(Adam, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
nnlib.Adam = Adam
|
||||
|
||||
class DenseMaxout(keras.layers.Layer):
|
||||
"""A dense maxout layer.
|
||||
A `MaxoutDense` layer takes the element-wise maximum of
|
||||
`nb_feature` `Dense(input_dim, output_dim)` linear layers.
|
||||
This allows the layer to learn a convex,
|
||||
piecewise linear activation function over the inputs.
|
||||
Note that this is a *linear* layer;
|
||||
if you wish to apply activation function
|
||||
(you shouldn't need to --they are universal function approximators),
|
||||
an `Activation` layer must be added after.
|
||||
# Arguments
|
||||
output_dim: int > 0.
|
||||
nb_feature: number of Dense layers to use internally.
|
||||
init: name of initialization function for the weights of the layer
|
||||
(see [initializations](../initializations.md)),
|
||||
or alternatively, Theano function to use for weights
|
||||
initialization. This parameter is only relevant
|
||||
if you don't pass a `weights` argument.
|
||||
weights: list of Numpy arrays to set as initial weights.
|
||||
The list should have 2 elements, of shape `(input_dim, output_dim)`
|
||||
and (output_dim,) for weights and biases respectively.
|
||||
W_regularizer: instance of [WeightRegularizer](../regularizers.md)
|
||||
(eg. L1 or L2 regularization), applied to the main weights matrix.
|
||||
b_regularizer: instance of [WeightRegularizer](../regularizers.md),
|
||||
applied to the bias.
|
||||
activity_regularizer: instance of [ActivityRegularizer](../regularizers.md),
|
||||
applied to the network output.
|
||||
W_constraint: instance of the [constraints](../constraints.md) module
|
||||
(eg. maxnorm, nonneg), applied to the main weights matrix.
|
||||
b_constraint: instance of the [constraints](../constraints.md) module,
|
||||
applied to the bias.
|
||||
bias: whether to include a bias
|
||||
(i.e. make the layer affine rather than linear).
|
||||
input_dim: dimensionality of the input (integer). This argument
|
||||
(or alternatively, the keyword argument `input_shape`)
|
||||
is required when using this layer as the first layer in a model.
|
||||
# Input shape
|
||||
2D tensor with shape: `(nb_samples, input_dim)`.
|
||||
# Output shape
|
||||
2D tensor with shape: `(nb_samples, output_dim)`.
|
||||
# References
|
||||
- [Maxout Networks](http://arxiv.org/abs/1302.4389)
|
||||
"""
|
||||
|
||||
def __init__(self, output_dim,
|
||||
nb_feature=4,
|
||||
kernel_initializer='glorot_uniform',
|
||||
weights=None,
|
||||
W_regularizer=None,
|
||||
b_regularizer=None,
|
||||
activity_regularizer=None,
|
||||
W_constraint=None,
|
||||
b_constraint=None,
|
||||
bias=True,
|
||||
input_dim=None,
|
||||
**kwargs):
|
||||
self.output_dim = output_dim
|
||||
self.nb_feature = nb_feature
|
||||
self.kernel_initializer = keras.initializers.get(kernel_initializer)
|
||||
|
||||
self.W_regularizer = keras.regularizers.get(W_regularizer)
|
||||
self.b_regularizer = keras.regularizers.get(b_regularizer)
|
||||
self.activity_regularizer = keras.regularizers.get(activity_regularizer)
|
||||
|
||||
self.W_constraint = keras.constraints.get(W_constraint)
|
||||
self.b_constraint = keras.constraints.get(b_constraint)
|
||||
|
||||
self.bias = bias
|
||||
self.initial_weights = weights
|
||||
self.input_spec = keras.layers.InputSpec(ndim=2)
|
||||
|
||||
self.input_dim = input_dim
|
||||
if self.input_dim:
|
||||
kwargs['input_shape'] = (self.input_dim,)
|
||||
super(DenseMaxout, self).__init__(**kwargs)
|
||||
|
||||
def build(self, input_shape):
|
||||
input_dim = input_shape[1]
|
||||
self.input_spec = keras.layers.InputSpec(dtype=K.floatx(),
|
||||
shape=(None, input_dim))
|
||||
|
||||
self.W = self.add_weight(shape=(self.nb_feature, input_dim, self.output_dim),
|
||||
initializer=self.kernel_initializer,
|
||||
name='W',
|
||||
regularizer=self.W_regularizer,
|
||||
constraint=self.W_constraint)
|
||||
if self.bias:
|
||||
self.b = self.add_weight(shape=(self.nb_feature, self.output_dim,),
|
||||
initializer='zero',
|
||||
name='b',
|
||||
regularizer=self.b_regularizer,
|
||||
constraint=self.b_constraint)
|
||||
else:
|
||||
self.b = None
|
||||
|
||||
if self.initial_weights is not None:
|
||||
self.set_weights(self.initial_weights)
|
||||
del self.initial_weights
|
||||
self.built = True
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
assert input_shape and len(input_shape) == 2
|
||||
return (input_shape[0], self.output_dim)
|
||||
|
||||
def call(self, x):
|
||||
# no activation, this layer is only linear.
|
||||
output = K.dot(x, self.W)
|
||||
if self.bias:
|
||||
output += self.b
|
||||
output = K.max(output, axis=1)
|
||||
return output
|
||||
|
||||
def get_config(self):
|
||||
config = {'output_dim': self.output_dim,
|
||||
'kernel_initializer': initializers.serialize(self.kernel_initializer),
|
||||
'nb_feature': self.nb_feature,
|
||||
'W_regularizer': regularizers.serialize(self.W_regularizer),
|
||||
'b_regularizer': regularizers.serialize(self.b_regularizer),
|
||||
'activity_regularizer': regularizers.serialize(self.activity_regularizer),
|
||||
'W_constraint': constraints.serialize(self.W_constraint),
|
||||
'b_constraint': constraints.serialize(self.b_constraint),
|
||||
'bias': self.bias,
|
||||
'input_dim': self.input_dim}
|
||||
base_config = super(DenseMaxout, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
nnlib.DenseMaxout = DenseMaxout
|
||||
|
||||
def CAInitializerMP( conv_weights_list ):
|
||||
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
|
||||
data = [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue