removing fp16 for amp

This commit is contained in:
iperov 2021-07-15 13:23:33 +04:00
parent c9e0ba1779
commit da8f33ee85

View file

@ -31,7 +31,7 @@ class AMPModel(ModelBase):
min_res = 64
max_res = 640
default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
#default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224)
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
@ -65,7 +65,7 @@ class AMPModel(ModelBase):
self.ask_random_src_flip()
self.ask_random_dst_flip()
self.ask_batch_size(suggest_batch_size)
self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters. Lr-dropout should be enabled.')
#self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters. Lr-dropout should be enabled.')
if self.is_first_run():
resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .")
@ -147,19 +147,21 @@ class AMPModel(ModelBase):
d_dims = self.options['d_dims']
d_mask_dims = self.options['d_mask_dims']
inter_res = self.inter_res = resolution // 32
use_fp16 = self.options['use_fp16']
conv_dtype = tf.float16 if use_fp16 else tf.float32
use_fp16 = False# self.options['use_fp16']
ae_use_fp16 = use_fp16
ae_conv_dtype = tf.float16 if use_fp16 else tf.float32
class Downscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=5 ):
self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype)
self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=ae_conv_dtype)
def forward(self, x):
return tf.nn.leaky_relu(self.conv1(x), 0.1)
class Upscale(nn.ModelBase):
def on_build(self, in_ch, out_ch, kernel_size=3 ):
self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
def forward(self, x):
x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2)
@ -167,8 +169,8 @@ class AMPModel(ModelBase):
class ResidualBlock(nn.ModelBase):
def on_build(self, ch, kernel_size=3 ):
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
def forward(self, inp):
x = self.conv1(inp)
@ -189,7 +191,7 @@ class AMPModel(ModelBase):
self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims )
def forward(self, x):
if use_fp16:
if ae_use_fp16:
x = tf.cast(x, tf.float16)
x = self.down1(x)
x = self.res1(x)
@ -198,7 +200,7 @@ class AMPModel(ModelBase):
x = self.down4(x)
x = self.down5(x)
x = self.res5(x)
if use_fp16:
if ae_use_fp16:
x = tf.cast(x, tf.float32)
x = nn.pixel_norm(nn.flatten(x), axes=-1)
x = self.dense1(x)
@ -233,15 +235,15 @@ class AMPModel(ModelBase):
self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3)
self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3)
self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3)
self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=ae_conv_dtype)
self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=ae_conv_dtype)
self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
def forward(self, z):
if use_fp16:
if ae_use_fp16:
z = tf.cast(z, tf.float16)
x = self.upscale0(z)
@ -265,7 +267,7 @@ class AMPModel(ModelBase):
m = self.upscalem4(m)
m = tf.nn.sigmoid(self.out_convm(m))
if use_fp16:
if ae_use_fp16:
x = tf.cast(x, tf.float32)
m = tf.cast(m, tf.float32)
return x, m
@ -345,7 +347,7 @@ class AMPModel(ModelBase):
if self.is_training:
if gan_power != 0:
self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], use_fp16=use_fp16, name="GAN")
self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
self.model_filename_list += [ [self.GAN, 'GAN.npy'] ]
# Initialize optimizers