optimizations of nnlib and SampleGeneratorFace,

refactorings
This commit is contained in:
iperov 2019-01-22 11:52:04 +04:00
parent 2de45083a4
commit b6c4171ea1
9 changed files with 175 additions and 79 deletions

View file

@ -337,12 +337,28 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
gauss_kernel = gauss_kernel[:, :, tf.newaxis, tf.newaxis]
def func(input):
return tf.nn.conv2d(input, gauss_kernel, strides=[1, 1, 1, 1], padding="SAME")
input_nc = input.get_shape().as_list()[-1]
inputs = tf.split(input, input_nc, -1)
outputs = []
for i in range(len(inputs)):
outputs += [ tf.nn.conv2d( inputs[i] , gauss_kernel, strides=[1, 1, 1, 1], padding="SAME") ]
return tf.concat (outputs, axis=-1)
return func
nnlib.tf_gaussian_blur = tf_gaussian_blur
#any channel count style diff
#outputs 0.0 .. 1.0 style difference*loss_weight , 0.0 - no diff
def tf_style_loss(gaussian_blur_radius=0.0, loss_weight=1.0, batch_normalize=False, epsilon=1e-5):
def sl(content, style):
gblur = tf_gaussian_blur(gaussian_blur_radius)
def sd(content, style):
content_nc = content.get_shape().as_list()[-1]
style_nc = style.get_shape().as_list()[-1]
if content_nc != style_nc:
raise Exception("tf_style_loss() content_nc != style_nc")
axes = [1,2]
c_mean, c_var = tf.nn.moments(content, axes=axes, keep_dims=True)
s_mean, s_var = tf.nn.moments(style, axes=axes, keep_dims=True)
@ -360,23 +376,10 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
return (mean_loss + std_loss) * loss_weight
def func(target, style):
target_nc = target.get_shape().as_list()[-1]
style_nc = style.get_shape().as_list()[-1]
if target_nc != style_nc:
raise Exception("target_nc != style_nc")
targets = tf.split(target, target_nc, -1)
styles = tf.split(style, style_nc, -1)
style_loss = []
for i in range(len(targets)):
if gaussian_blur_radius > 0.0:
style_loss += [ sl( tf_gaussian_blur(gaussian_blur_radius)(targets[i]),
tf_gaussian_blur(gaussian_blur_radius)(styles[i])) ]
else:
style_loss += [ sl( targets[i],
styles[i]) ]
return np.sum ( style_loss )
if gaussian_blur_radius > 0.0:
return sd( gblur(target), gblur(style))
else:
return sd( target, style )
return func
nnlib.tf_style_loss = tf_style_loss
@ -727,8 +730,8 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
unet_block = UNetSkipConnection(ngf * 8, ngf * 8, sub_model=None, innermost=True)
#for i in range(num_downs - 5):
# unet_block = UNetSkipConnection(ngf * 8, ngf * 8, sub_model=unet_block, use_dropout=use_dropout)
for i in range(num_downs - 5):
unet_block = UNetSkipConnection(ngf * 8, ngf * 8, sub_model=unet_block, use_dropout=use_dropout)
unet_block = UNetSkipConnection(ngf * 4 , ngf * 8, sub_model=unet_block)
unet_block = UNetSkipConnection(ngf * 2 , ngf * 4, sub_model=unet_block)