mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 14:24:40 -07:00
update to 'partial' func
This commit is contained in:
parent
768dce12c2
commit
a9d362844e
1 changed files with 12 additions and 39 deletions
|
@ -1,3 +1,4 @@
|
||||||
|
from functools import partial
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from nnlib import nnlib
|
from nnlib import nnlib
|
||||||
|
@ -485,57 +486,29 @@ class SAEModel(ModelBase):
|
||||||
return x
|
return x
|
||||||
SAEModel.ResidualBlock = ResidualBlock
|
SAEModel.ResidualBlock = ResidualBlock
|
||||||
|
|
||||||
def ResidualBlock_pre (**base_kwargs):
|
|
||||||
def func(*args, **kwargs):
|
|
||||||
kwargs.update(base_kwargs)
|
|
||||||
return ResidualBlock(*args, **kwargs)
|
|
||||||
return func
|
|
||||||
SAEModel.ResidualBlock_pre = ResidualBlock_pre
|
|
||||||
|
|
||||||
def downscale (dim, padding='zero', norm='', act='', **kwargs):
|
def downscale (dim, padding='zero', norm='', act='', **kwargs):
|
||||||
def func(x):
|
def func(x):
|
||||||
return Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=2, padding=padding)(x)) )
|
return Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=2, padding=padding)(x)) )
|
||||||
return func
|
return func
|
||||||
SAEModel.downscale = downscale
|
SAEModel.downscale = downscale
|
||||||
|
|
||||||
def downscale_pre (**base_kwargs):
|
|
||||||
def func(*args, **kwargs):
|
|
||||||
kwargs.update(base_kwargs)
|
|
||||||
return downscale(*args, **kwargs)
|
|
||||||
return func
|
|
||||||
SAEModel.downscale_pre = downscale_pre
|
|
||||||
|
|
||||||
def upscale (dim, padding='zero', norm='', act='', **kwargs):
|
def upscale (dim, padding='zero', norm='', act='', **kwargs):
|
||||||
def func(x):
|
def func(x):
|
||||||
return SubpixelUpscaler()(Norm(norm)(Act(act)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x))))
|
return SubpixelUpscaler()(Norm(norm)(Act(act)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x))))
|
||||||
return func
|
return func
|
||||||
SAEModel.upscale = upscale
|
SAEModel.upscale = upscale
|
||||||
|
|
||||||
def upscale_pre (**base_kwargs):
|
|
||||||
def func(*args, **kwargs):
|
|
||||||
kwargs.update(base_kwargs)
|
|
||||||
return upscale(*args, **kwargs)
|
|
||||||
return func
|
|
||||||
SAEModel.upscale_pre = upscale_pre
|
|
||||||
|
|
||||||
def to_bgr (output_nc, padding='zero', **kwargs):
|
def to_bgr (output_nc, padding='zero', **kwargs):
|
||||||
def func(x):
|
def func(x):
|
||||||
return Conv2D(output_nc, kernel_size=5, padding=padding, activation='sigmoid')(x)
|
return Conv2D(output_nc, kernel_size=5, padding=padding, activation='sigmoid')(x)
|
||||||
return func
|
return func
|
||||||
SAEModel.to_bgr = to_bgr
|
SAEModel.to_bgr = to_bgr
|
||||||
|
|
||||||
def to_bgr_pre (**base_kwargs):
|
|
||||||
def func(*args, **kwargs):
|
|
||||||
kwargs.update(base_kwargs)
|
|
||||||
return to_bgr(*args, **kwargs)
|
|
||||||
return func
|
|
||||||
SAEModel.to_bgr_pre = to_bgr_pre
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def LIAEEncFlow(resolution, ch_dims, **kwargs):
|
def LIAEEncFlow(resolution, ch_dims, **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
upscale = SAEModel.upscale_pre(**kwargs)
|
upscale = partial(SAEModel.upscale, **kwargs)
|
||||||
downscale = SAEModel.downscale_pre(**kwargs)
|
downscale = partial(SAEModel.downscale, **kwargs)
|
||||||
|
|
||||||
def func(input):
|
def func(input):
|
||||||
dims = K.int_shape(input)[-1]*ch_dims
|
dims = K.int_shape(input)[-1]*ch_dims
|
||||||
|
@ -553,7 +526,7 @@ class SAEModel(ModelBase):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def LIAEInterFlow(resolution, ae_dims=256, **kwargs):
|
def LIAEInterFlow(resolution, ae_dims=256, **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
upscale = SAEModel.upscale_pre(**kwargs)
|
upscale = partial(SAEModel.upscale, **kwargs)
|
||||||
lowest_dense_res=resolution // 16
|
lowest_dense_res=resolution // 16
|
||||||
|
|
||||||
def func(input):
|
def func(input):
|
||||||
|
@ -568,10 +541,10 @@ class SAEModel(ModelBase):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def LIAEDecFlow(output_nc,ch_dims, multiscale_count=1, add_residual_blocks=False, padding='zero', norm='', **kwargs):
|
def LIAEDecFlow(output_nc,ch_dims, multiscale_count=1, add_residual_blocks=False, padding='zero', norm='', **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
upscale = SAEModel.upscale_pre(**kwargs)
|
upscale = partial(SAEModel.upscale, **kwargs)
|
||||||
to_bgr = SAEModel.to_bgr_pre(**kwargs)
|
to_bgr = partial(SAEModel.to_bgr, **kwargs)
|
||||||
dims = output_nc * ch_dims
|
dims = output_nc * ch_dims
|
||||||
ResidualBlock = SAEModel.ResidualBlock_pre(**kwargs)
|
ResidualBlock = partial(SAEModel.ResidualBlock, **kwargs)
|
||||||
|
|
||||||
def func(input):
|
def func(input):
|
||||||
x = input[0]
|
x = input[0]
|
||||||
|
@ -609,8 +582,8 @@ class SAEModel(ModelBase):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def DFEncFlow(resolution, ae_dims, ch_dims, padding='zero', **kwargs):
|
def DFEncFlow(resolution, ae_dims, ch_dims, padding='zero', **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
upscale = SAEModel.upscale_pre(padding=padding)
|
upscale = partial(SAEModel.upscale, padding=padding)
|
||||||
downscale = SAEModel.downscale_pre(padding=padding)
|
downscale = partial(SAEModel.downscale, padding=padding)
|
||||||
lowest_dense_res = resolution // 16
|
lowest_dense_res = resolution // 16
|
||||||
|
|
||||||
def func(input):
|
def func(input):
|
||||||
|
@ -634,10 +607,10 @@ class SAEModel(ModelBase):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def DFDecFlow(output_nc, ch_dims, multiscale_count=1, add_residual_blocks=False, padding='zero', **kwargs):
|
def DFDecFlow(output_nc, ch_dims, multiscale_count=1, add_residual_blocks=False, padding='zero', **kwargs):
|
||||||
exec (nnlib.import_all(), locals(), globals())
|
exec (nnlib.import_all(), locals(), globals())
|
||||||
upscale = SAEModel.upscale_pre(padding=padding)
|
upscale = partial(SAEModel.upscale, padding=padding)
|
||||||
to_bgr = SAEModel.to_bgr_pre(padding=padding)
|
to_bgr = partial(SAEModel.to_bgr, padding=padding)
|
||||||
dims = output_nc * ch_dims
|
dims = output_nc * ch_dims
|
||||||
ResidualBlock = SAEModel.ResidualBlock_pre(padding=padding)
|
ResidualBlock = partial(SAEModel.ResidualBlock, padding=padding)
|
||||||
|
|
||||||
def func(input):
|
def func(input):
|
||||||
x = input[0]
|
x = input[0]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue