From a9d362844e4e2fc01809140642c4bd9249b177f4 Mon Sep 17 00:00:00 2001 From: iperov Date: Tue, 16 Apr 2019 14:24:03 +0400 Subject: [PATCH] update to 'partial' func --- models/Model_SAE/Model.py | 51 +++++++++------------------------------ 1 file changed, 12 insertions(+), 39 deletions(-) diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index 12dfffd..478fad4 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -1,3 +1,4 @@ +from functools import partial import numpy as np from nnlib import nnlib @@ -485,57 +486,29 @@ class SAEModel(ModelBase): return x SAEModel.ResidualBlock = ResidualBlock - def ResidualBlock_pre (**base_kwargs): - def func(*args, **kwargs): - kwargs.update(base_kwargs) - return ResidualBlock(*args, **kwargs) - return func - SAEModel.ResidualBlock_pre = ResidualBlock_pre - def downscale (dim, padding='zero', norm='', act='', **kwargs): def func(x): return Norm(norm)( Act(act) (Conv2D(dim, kernel_size=5, strides=2, padding=padding)(x)) ) return func SAEModel.downscale = downscale - def downscale_pre (**base_kwargs): - def func(*args, **kwargs): - kwargs.update(base_kwargs) - return downscale(*args, **kwargs) - return func - SAEModel.downscale_pre = downscale_pre - def upscale (dim, padding='zero', norm='', act='', **kwargs): def func(x): return SubpixelUpscaler()(Norm(norm)(Act(act)(Conv2D(dim * 4, kernel_size=3, strides=1, padding=padding)(x)))) return func SAEModel.upscale = upscale - def upscale_pre (**base_kwargs): - def func(*args, **kwargs): - kwargs.update(base_kwargs) - return upscale(*args, **kwargs) - return func - SAEModel.upscale_pre = upscale_pre - def to_bgr (output_nc, padding='zero', **kwargs): def func(x): return Conv2D(output_nc, kernel_size=5, padding=padding, activation='sigmoid')(x) return func SAEModel.to_bgr = to_bgr - def to_bgr_pre (**base_kwargs): - def func(*args, **kwargs): - kwargs.update(base_kwargs) - return to_bgr(*args, **kwargs) - return func - SAEModel.to_bgr_pre = to_bgr_pre - @staticmethod def LIAEEncFlow(resolution, ch_dims, **kwargs): exec (nnlib.import_all(), locals(), globals()) - upscale = SAEModel.upscale_pre(**kwargs) - downscale = SAEModel.downscale_pre(**kwargs) + upscale = partial(SAEModel.upscale, **kwargs) + downscale = partial(SAEModel.downscale, **kwargs) def func(input): dims = K.int_shape(input)[-1]*ch_dims @@ -553,7 +526,7 @@ class SAEModel(ModelBase): @staticmethod def LIAEInterFlow(resolution, ae_dims=256, **kwargs): exec (nnlib.import_all(), locals(), globals()) - upscale = SAEModel.upscale_pre(**kwargs) + upscale = partial(SAEModel.upscale, **kwargs) lowest_dense_res=resolution // 16 def func(input): @@ -568,10 +541,10 @@ class SAEModel(ModelBase): @staticmethod def LIAEDecFlow(output_nc,ch_dims, multiscale_count=1, add_residual_blocks=False, padding='zero', norm='', **kwargs): exec (nnlib.import_all(), locals(), globals()) - upscale = SAEModel.upscale_pre(**kwargs) - to_bgr = SAEModel.to_bgr_pre(**kwargs) + upscale = partial(SAEModel.upscale, **kwargs) + to_bgr = partial(SAEModel.to_bgr, **kwargs) dims = output_nc * ch_dims - ResidualBlock = SAEModel.ResidualBlock_pre(**kwargs) + ResidualBlock = partial(SAEModel.ResidualBlock, **kwargs) def func(input): x = input[0] @@ -609,8 +582,8 @@ class SAEModel(ModelBase): @staticmethod def DFEncFlow(resolution, ae_dims, ch_dims, padding='zero', **kwargs): exec (nnlib.import_all(), locals(), globals()) - upscale = SAEModel.upscale_pre(padding=padding) - downscale = SAEModel.downscale_pre(padding=padding) + upscale = partial(SAEModel.upscale, padding=padding) + downscale = partial(SAEModel.downscale, padding=padding) lowest_dense_res = resolution // 16 def func(input): @@ -634,10 +607,10 @@ class SAEModel(ModelBase): @staticmethod def DFDecFlow(output_nc, ch_dims, multiscale_count=1, add_residual_blocks=False, padding='zero', **kwargs): exec (nnlib.import_all(), locals(), globals()) - upscale = SAEModel.upscale_pre(padding=padding) - to_bgr = SAEModel.to_bgr_pre(padding=padding) + upscale = partial(SAEModel.upscale, padding=padding) + to_bgr = partial(SAEModel.to_bgr, padding=padding) dims = output_nc * ch_dims - ResidualBlock = SAEModel.ResidualBlock_pre(padding=padding) + ResidualBlock = partial(SAEModel.ResidualBlock, padding=padding) def func(input): x = input[0]