Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-06 04:52:13 -07:00)
removing fp16 for amp
commit da8f33ee85
parent c9e0ba1779
1 changed file with 20 additions and 18 deletions
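The core of the change, as a minimal standalone sketch (illustrative only, not the file's literal surroundings; TensorFlow is assumed, as in the rest of the model): the AMP model's user-facing use_fp16 option is commented out, the flag is hard-coded off, and the convolution dtype flag is renamed from conv_dtype to ae_conv_dtype.

import tensorflow as tf

# Before this commit: dtype followed the saved 'use_fp16' option.
# use_fp16 = self.options['use_fp16']
# conv_dtype = tf.float16 if use_fp16 else tf.float32

# After this commit: fp16 is forced off for AMP; the flag is renamed for the autoencoder layers.
use_fp16 = False  # self.options['use_fp16']
ae_use_fp16 = use_fp16
ae_conv_dtype = tf.float16 if use_fp16 else tf.float32  # always tf.float32 while use_fp16 is False
print(ae_conv_dtype)  # <dtype: 'float32'>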
@@ -31,7 +31,7 @@ class AMPModel(ModelBase):
         min_res = 64
         max_res = 640
 
-        default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
+        #default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
         default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 224)
         default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'wf')
         default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
@@ -65,7 +65,7 @@ class AMPModel(ModelBase):
         self.ask_random_src_flip()
         self.ask_random_dst_flip()
         self.ask_batch_size(suggest_batch_size)
-        self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters. Lr-dropout should be enabled.')
+        #self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters. Lr-dropout should be enabled.')
 
         if self.is_first_run():
             resolution = io.input_int("Resolution", default_resolution, add_info="64-640", help_message="More resolution requires more VRAM and time to train. Value will be adjusted to multiple of 32 .")
@@ -147,19 +147,21 @@ class AMPModel(ModelBase):
         d_dims = self.options['d_dims']
         d_mask_dims = self.options['d_mask_dims']
         inter_res = self.inter_res = resolution // 32
-        use_fp16 = self.options['use_fp16']
-        conv_dtype = tf.float16 if use_fp16 else tf.float32
+        use_fp16 = False# self.options['use_fp16']
+
+        ae_use_fp16 = use_fp16
+        ae_conv_dtype = tf.float16 if use_fp16 else tf.float32
 
         class Downscale(nn.ModelBase):
             def on_build(self, in_ch, out_ch, kernel_size=5 ):
-                self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=conv_dtype)
+                self.conv1 = nn.Conv2D( in_ch, out_ch, kernel_size=kernel_size, strides=2, padding='SAME', dtype=ae_conv_dtype)
 
             def forward(self, x):
                 return tf.nn.leaky_relu(self.conv1(x), 0.1)
 
         class Upscale(nn.ModelBase):
             def on_build(self, in_ch, out_ch, kernel_size=3 ):
-                self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
+                self.conv1 = nn.Conv2D(in_ch, out_ch*4, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
 
             def forward(self, x):
                 x = nn.depth_to_space(tf.nn.leaky_relu(self.conv1(x), 0.1), 2)
@@ -167,8 +169,8 @@ class AMPModel(ModelBase):
 
         class ResidualBlock(nn.ModelBase):
             def on_build(self, ch, kernel_size=3 ):
-                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
-                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=conv_dtype)
+                self.conv1 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
+                self.conv2 = nn.Conv2D( ch, ch, kernel_size=kernel_size, padding='SAME', dtype=ae_conv_dtype)
 
             def forward(self, inp):
                 x = self.conv1(inp)
@@ -189,7 +191,7 @@ class AMPModel(ModelBase):
                 self.dense1 = nn.Dense( (( resolution//(2**5) )**2) * e_dims*8, ae_dims )
 
             def forward(self, x):
-                if use_fp16:
+                if ae_use_fp16:
                     x = tf.cast(x, tf.float16)
                 x = self.down1(x)
                 x = self.res1(x)
@@ -198,7 +200,7 @@ class AMPModel(ModelBase):
                 x = self.down4(x)
                 x = self.down5(x)
                 x = self.res5(x)
-                if use_fp16:
+                if ae_use_fp16:
                     x = tf.cast(x, tf.float32)
                 x = nn.pixel_norm(nn.flatten(x), axes=-1)
                 x = self.dense1(x)
@@ -233,15 +235,15 @@ class AMPModel(ModelBase):
                 self.upscalem2 = Upscale(d_mask_dims*8, d_mask_dims*4, kernel_size=3)
                 self.upscalem3 = Upscale(d_mask_dims*4, d_mask_dims*2, kernel_size=3)
                 self.upscalem4 = Upscale(d_mask_dims*2, d_mask_dims*1, kernel_size=3)
-                self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=conv_dtype)
+                self.out_convm = nn.Conv2D( d_mask_dims*1, 1, kernel_size=1, padding='SAME', dtype=ae_conv_dtype)
 
-                self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=conv_dtype)
-                self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
-                self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
-                self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=conv_dtype)
+                self.out_conv = nn.Conv2D( d_dims*2, 3, kernel_size=1, padding='SAME', dtype=ae_conv_dtype)
+                self.out_conv1 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
+                self.out_conv2 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
+                self.out_conv3 = nn.Conv2D( d_dims*2, 3, kernel_size=3, padding='SAME', dtype=ae_conv_dtype)
 
             def forward(self, z):
-                if use_fp16:
+                if ae_use_fp16:
                     z = tf.cast(z, tf.float16)
 
                 x = self.upscale0(z)
@@ -265,7 +267,7 @@ class AMPModel(ModelBase):
                 m = self.upscalem4(m)
                 m = tf.nn.sigmoid(self.out_convm(m))
 
-                if use_fp16:
+                if ae_use_fp16:
                     x = tf.cast(x, tf.float32)
                     m = tf.cast(m, tf.float32)
                 return x, m
@@ -345,7 +347,7 @@ class AMPModel(ModelBase):
 
         if self.is_training:
             if gan_power != 0:
-                self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], use_fp16=use_fp16, name="GAN")
+                self.GAN = nn.UNetPatchDiscriminator(patch_size=self.options['gan_patch_size'], in_ch=input_ch, base_ch=self.options['gan_dims'], name="GAN")
                 self.model_filename_list += [ [self.GAN, 'GAN.npy'] ]
 
             # Initialize optimizers
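For context on the `if ae_use_fp16:` branches in the encoder/decoder hunks above, this is the usual cast-in/cast-out pattern around half-precision layers: cast inputs down before the fp16 layers and back up afterwards so downstream losses stay in full precision. A minimal sketch in plain TensorFlow (names and shapes here are illustrative assumptions, not the model's own code):

import tensorflow as tf

ae_use_fp16 = False                                         # hard-coded off by this commit
ae_conv_dtype = tf.float16 if ae_use_fp16 else tf.float32   # layer weights follow this dtype

def forward(x):
    # Cast inputs down only when fp16 is active, run the (possibly fp16) layer,
    # then cast back to float32 for the rest of the graph.
    if ae_use_fp16:
        x = tf.cast(x, tf.float16)
    w = tf.ones([3, 3, 3, 8], dtype=ae_conv_dtype)          # stand-in for a conv kernel
    x = tf.nn.leaky_relu(tf.nn.conv2d(x, w, strides=2, padding='SAME'), 0.1)
    if ae_use_fp16:
        x = tf.cast(x, tf.float32)
    return x

print(forward(tf.zeros([1, 32, 32, 3])).dtype)  # float32 either way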