Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-07 05:22:06 -07:00)
Trainer: added option for all models:

    Enable autobackup? (y/n ?:help skip:%s) :
    Autobackup model files with preview every hour for last 15 hours.
    Latest backup located in model/<>_autobackups/01

SAE: added option only for CUDA builds:

    Enable gradient clipping? (y/n, ?:help skip:%s) :
    Gradient clipping reduces chance of model collapse, sacrificing speed of training.
parent ea1d59f620
commit 8484060e01
14 changed files with 210 additions and 80 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
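Both new options follow the trainer's usual console convention, visible in the hunks below: a y/n answer, '?' prints the help text, and an empty answer keeps the skip default shown in the prompt. The real helper is io.input_bool from the interact module; purely as an illustration of that convention, a minimal hypothetical equivalent:

    # Illustrative only: a minimal y/n prompt in the spirit of io.input_bool.
    # '?' prints the help message; empty input ("skip") keeps the default.
    def input_bool(prompt, default_value, help_message=None):
        while True:
            s = input(prompt).strip().lower()
            if s == '?' and help_message is not None:
                print(help_message)
                continue
            if s == '':                 # skip -> keep default
                return default_value
            if s in ('y', 'n'):
                return s == 'y'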
@@ -1,19 +1,22 @@
-import os
-import json
-import time
-import inspect
-import pickle
 import colorsys
-import imagelib
+import inspect
+import json
+import os
+import pickle
+import shutil
+import time
 from pathlib import Path
-from utils import Path_utils
-from utils import std_utils
-from utils.cv2_utils import *
-import numpy as np
 
 import cv2
-from samplelib import SampleGeneratorBase
-from nnlib import nnlib
+import numpy as np
+
+import imagelib
+from interact import interact as io
+from nnlib import nnlib
+from samplelib import SampleGeneratorBase
+from utils import Path_utils, std_utils
+from utils.cv2_utils import *
+
 '''
 You can implement your own model. Check examples.
 '''
@@ -21,6 +24,7 @@ class ModelBase(object):
 
 
    def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, pretraining_data_path=None, debug = False, device_args = None,
+                        ask_enable_autobackup=True,
                         ask_write_preview_history=True,
                         ask_target_iter=True,
                         ask_batch_size=True,
@@ -84,6 +88,12 @@
        if self.iter == 0:
            io.log_info ("\nModel first run. Enter model options as default for each run.")
 
+        if ask_enable_autobackup and (self.iter == 0 or ask_override):
+            default_autobackup = False if self.iter == 0 else self.options.get('autobackup',False)
+            self.options['autobackup'] = io.input_bool("Enable autobackup? (y/n ?:help skip:%s) : " % (yn_str[default_autobackup]) , default_autobackup, help_message="Autobackup model files with preview every hour for last 15 hours. Latest backup located in model/<>_autobackups/01")
+        else:
+            self.options['autobackup'] = self.options.get('autobackup', False)
+
        if ask_write_preview_history and (self.iter == 0 or ask_override):
            default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',False)
            self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
@@ -127,6 +137,10 @@
            self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
        else:
            self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
 
+        self.autobackup = self.options.get('autobackup', False)
+        if not self.autobackup and 'autobackup' in self.options:
+            self.options.pop('autobackup')
+
        self.write_preview_history = self.options.get('write_preview_history', False)
        if not self.write_preview_history and 'write_preview_history' in self.options:
@@ -160,8 +174,16 @@
        if self.is_training_mode:
            if self.device_args['force_gpu_idx'] == -1:
                self.preview_history_path = self.model_path / ( '%s_history' % (self.get_model_name()) )
+                self.autobackups_path = self.model_path / ( '%s_autobackups' % (self.get_model_name()) )
            else:
                self.preview_history_path = self.model_path / ( '%d_%s_history' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
+                self.autobackups_path = self.model_path / ( '%d_%s_autobackups' % (self.device_args['force_gpu_idx'], self.get_model_name()) )
 
+            if self.autobackup:
+                self.autobackup_current_hour = time.localtime().tm_hour
+
+                if not self.autobackups_path.exists():
+                    self.autobackups_path.mkdir(exist_ok=True)
+
            if self.write_preview_history or io.is_colab():
                if not self.preview_history_path.exists():
@@ -205,8 +227,8 @@
 
                    io.destroy_window(wnd_name)
            else:
-                self.sample_for_preview = self.generate_next_sample()
-
+                self.sample_for_preview = self.generate_next_sample()
+            self.last_sample = self.sample_for_preview
            model_summary_text = []
 
            model_summary_text += ["===== Model summary ====="]
@@ -277,6 +299,10 @@
    def get_model_name(self):
        return Path(inspect.getmodule(self).__file__).parent.name.rsplit("_", 1)[1]
 
+    #overridable , return [ [model, filename],... ] list
+    def get_model_filename_list(self):
+        return []
+
    #overridable
    def get_converter(self):
        raise NotImplementedError
@@ -314,7 +340,8 @@
        return self.onGetPreview (self.sample_for_preview)[0][1] #first preview, and bgr
 
    def save(self):
-        Path( self.get_strpath_storage_for_file('summary.txt') ).write_text(self.model_summary_text)
+        summary_path = self.get_strpath_storage_for_file('summary.txt')
+        Path( summary_path ).write_text(self.model_summary_text)
        self.onSave()
 
        model_data = {
@@ -325,6 +352,44 @@
        }
        self.model_data_path.write_bytes( pickle.dumps(model_data) )
 
+        bckp_filename_list = [ self.get_strpath_storage_for_file(filename) for _, filename in self.get_model_filename_list() ]
+        bckp_filename_list += [ str(summary_path), str(self.model_data_path) ]
+
+        if self.autobackup:
+            current_hour = time.localtime().tm_hour
+            if self.autobackup_current_hour != current_hour:
+                self.autobackup_current_hour = current_hour
+
+                for i in range(15,0,-1):
+                    idx_str = '%.2d' % i
+                    next_idx_str = '%.2d' % (i+1)
+
+                    idx_backup_path = self.autobackups_path / idx_str
+                    next_idx_packup_path = self.autobackups_path / next_idx_str
+
+                    if idx_backup_path.exists():
+                        if i == 15:
+                            Path_utils.delete_all_files(idx_backup_path)
+                        else:
+                            next_idx_packup_path.mkdir(exist_ok=True)
+                            Path_utils.move_all_files (idx_backup_path, next_idx_packup_path)
+
+                    if i == 1:
+                        idx_backup_path.mkdir(exist_ok=True)
+                        for filename in bckp_filename_list:
+                            shutil.copy ( str(filename), str(idx_backup_path / Path(filename).name) )
+
+                        previews = self.get_previews()
+                        plist = []
+                        for i in range(len(previews)):
+                            name, bgr = previews[i]
+                            plist += [ (bgr, idx_backup_path / ( ('preview_%s.jpg') % (name)) ) ]
+
+                        for preview, filepath in plist:
+                            preview_lh = ModelBase.get_loss_history_preview(self.loss_history, self.iter, preview.shape[1], preview.shape[2])
+                            img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
+                            cv2_imwrite (filepath, img )
+
    def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
        for model, filename in model_filename_list:
            filename = self.get_strpath_storage_for_file(filename)
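The block above is the heart of the new autobackup: once per wall-clock hour it rotates fifteen numbered slots, purging slot 15, shifting every other slot up by one, and writing the current model files plus a loss-history preview into slot 01 (hence "last 15 hours" and "latest backup located in 01" in the help text). A simplified standalone sketch of the same rotation, using directory renames instead of the per-file moves in the diff (hypothetical paths, stdlib only):

    import shutil
    from pathlib import Path

    # Simplified sketch of the 15-slot rotation in ModelBase.save():
    # slot 01 holds the newest backup, slot 15 the oldest, which is purged.
    def rotate_backups(backups_root, files_to_copy, depth=15):
        backups_root = Path(backups_root)
        for i in range(depth, 0, -1):
            slot = backups_root / ('%.2d' % i)
            if slot.exists():
                if i == depth:
                    shutil.rmtree(str(slot))                        # drop the oldest
                else:
                    slot.rename(backups_root / ('%.2d' % (i + 1)))  # shift up by one
            if i == 1:
                slot.mkdir(parents=True, exist_ok=True)
                for f in files_to_copy:                             # fresh copy into 01
                    shutil.copy(str(f), str(slot / Path(f).name))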
@@ -349,12 +414,16 @@
                print ("Unable to load ", opt_filename)
 
 
-    def save_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
+    def save_weights_safe(self, model_filename_list):
        for model, filename in model_filename_list:
            filename = self.get_strpath_storage_for_file(filename)
            model.save_weights( filename + '.tmp' )
 
        rename_list = model_filename_list
+
+        """
+        #unused
+        , optimizer_filename_list=[]
        if len(optimizer_filename_list) != 0:
            opt_filename = self.get_strpath_storage_for_file('opt.h5')
 
@@ -374,7 +443,8 @@
                    rename_list += [('', 'opt.h5')]
            except Exception as e:
                print ("Unable to save ", opt_filename)
-
+        """
+
        for _, filename in rename_list:
            filename = self.get_strpath_storage_for_file(filename)
            source_filename = Path(filename+'.tmp')
@@ -383,8 +453,7 @@
            if target_filename.exists():
                target_filename.unlink()
            source_filename.rename ( str(target_filename) )
 
-
 
    def debug_one_iter(self):
        images = []
        for generator in self.generator_list:
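Note the save path finished in the two hunks above: save_weights_safe writes each weights file to a `.tmp` sibling first, and only after a successful write is the old file unlinked and the `.tmp` renamed over it, so a crash mid-save cannot destroy the last good weights. The pattern in isolation (hypothetical helper, stdlib only):

    from pathlib import Path

    # Write-then-rename: the destination is only replaced after the new data
    # is fully on disk, so an interrupted save leaves the old file intact.
    def safe_write(path, data):
        path = Path(path)
        tmp = Path(str(path) + '.tmp')
        tmp.write_bytes(data)
        if path.exists():
            path.unlink()       # same unlink+rename sequence as in the diff
        tmp.rename(path)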
@@ -490,45 +559,47 @@
 
        lh_height = 100
        lh_img = np.ones ( (lh_height,w,c) ) * 0.1
-        loss_count = len(loss_history[0])
-        lh_len = len(loss_history)
-
-        l_per_col = lh_len / w
-        plist_max = [ [ max (0.0, loss_history[int(col*l_per_col)][p],
-                              *[ loss_history[i_ab][p]
-                                  for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
-                                ]
-                            )
-                        for p in range(loss_count)
-                      ]
-                      for col in range(w)
-                    ]
-
-        plist_min = [ [ min (plist_max[col][p], loss_history[int(col*l_per_col)][p],
-                              *[ loss_history[i_ab][p]
-                                  for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
-                                ]
-                            )
-                        for p in range(loss_count)
-                      ]
-                      for col in range(w)
-                    ]
-
-        plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2
-
-        for col in range(0, w):
-            for p in range(0,loss_count):
-                point_color = [1.0]*c
-                point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 )
-
-                ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) )
-                ph_max = np.clip( ph_max, 0, lh_height-1 )
-
-                ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) )
-                ph_min = np.clip( ph_min, 0, lh_height-1 )
-
-                for ph in range(ph_min, ph_max+1):
-                    lh_img[ (lh_height-ph-1), col ] = point_color
+        if len(loss_history) != 0:
+            loss_count = len(loss_history[0])
+            lh_len = len(loss_history)
+
+            l_per_col = lh_len / w
+            plist_max = [ [ max (0.0, loss_history[int(col*l_per_col)][p],
+                                  *[ loss_history[i_ab][p]
+                                      for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
+                                    ]
+                                )
+                            for p in range(loss_count)
+                          ]
+                          for col in range(w)
+                        ]
+
+            plist_min = [ [ min (plist_max[col][p], loss_history[int(col*l_per_col)][p],
+                                  *[ loss_history[i_ab][p]
+                                      for i_ab in range( int(col*l_per_col), int((col+1)*l_per_col) )
+                                    ]
+                                )
+                            for p in range(loss_count)
+                          ]
+                          for col in range(w)
+                        ]
+
+            plist_abs_max = np.mean(loss_history[ len(loss_history) // 5 : ]) * 2
+
+            for col in range(0, w):
+                for p in range(0,loss_count):
+                    point_color = [1.0]*c
+                    point_color[0:3] = colorsys.hsv_to_rgb ( p * (1.0/loss_count), 1.0, 1.0 )
+
+                    ph_max = int ( (plist_max[col][p] / plist_abs_max) * (lh_height-1) )
+                    ph_max = np.clip( ph_max, 0, lh_height-1 )
+
+                    ph_min = int ( (plist_min[col][p] / plist_abs_max) * (lh_height-1) )
+                    ph_min = np.clip( ph_min, 0, lh_height-1 )
+
+                    for ph in range(ph_min, ph_max+1):
+                        lh_img[ (lh_height-ph-1), col ] = point_color
 
        lh_lines = 5
        lh_line_height = (lh_height-1)/lh_lines
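The only functional change in the loss-history hunk above is the new `if len(loss_history) != 0:` guard; the body is unchanged, just re-indented. Without it, `loss_history[0]` raises an IndexError when an autobackup preview is rendered before any iteration has been logged. The comprehensions it protects reduce the history to one min/max envelope value per preview column; an illustrative numpy equivalent (not the project's code, and assuming at least w logged iterations):

    import numpy as np

    # Illustrative reduction: split the (iterations x losses) history into w
    # column buckets and keep each bucket's min and max, as the diff does.
    def column_envelope(loss_history, w):
        lh = np.asarray(loss_history, dtype=np.float32)
        buckets = np.array_split(lh, w)
        col_max = np.stack([b.max(axis=0) for b in buckets])
        col_min = np.stack([b.min(axis=0) for b in buckets])
        return col_min, col_max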
@@ -11,6 +11,7 @@ class Model(ModelBase):
 
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs,
+                        ask_enable_autobackup=False,
                         ask_write_preview_history=False,
                         ask_target_iter=False,
                         ask_sort_by_yaw=False,
@@ -12,6 +12,7 @@ class Model(ModelBase):
 
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs,
+                        ask_enable_autobackup=False,
                         ask_write_preview_history=False,
                         ask_target_iter=False,
                         ask_sort_by_yaw=False,
@@ -59,11 +59,16 @@ class Model(ModelBase):
                sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
                output_sample_types=output_sample_types)
            ])
 
+    #override
+    def get_model_filename_list(self):
+        return [[self.encoder, 'encoder.h5'],
+                [self.decoder_src, 'decoder_src.h5'],
+                [self.decoder_dst, 'decoder_dst.h5']]
+
    #override
    def onSave(self):
-        self.save_weights_safe( [[self.encoder, 'encoder.h5'],
-                                 [self.decoder_src, 'decoder_src.h5'],
-                                 [self.decoder_dst, 'decoder_dst.h5']] )
+        self.save_weights_safe( self.get_model_filename_list() )
+
    #override
    def onTrainOneIter(self, sample, generators_list):
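Taken together, the model-file hunks all apply the same refactor: each model now declares its (model, filename) pairs once in get_model_filename_list(), which onSave() reuses and which the new autobackup in ModelBase consumes to know what to copy. A hypothetical new model following the pattern only needs (sketch, assuming the surrounding DeepFaceLab imports):

    # Sketch of a hypothetical model following the new pattern; ModelBase's
    # save() and autobackup both consume get_model_filename_list().
    class Model(ModelBase):
        #override
        def get_model_filename_list(self):
            return [[self.encoder, 'encoder.h5'],
                    [self.decoder, 'decoder.h5']]

        #override
        def onSave(self):
            self.save_weights_safe( self.get_model_filename_list() )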
@@ -70,11 +70,15 @@ class Model(ModelBase):
                output_sample_types=output_sample_types )
            ])
 
+    #override
+    def get_model_filename_list(self):
+        return [[self.encoder, 'encoder.h5'],
+                [self.decoder_src, 'decoder_src.h5'],
+                [self.decoder_dst, 'decoder_dst.h5']]
+
    #override
    def onSave(self):
-        self.save_weights_safe( [[self.encoder, 'encoder.h5'],
-                                 [self.decoder_src, 'decoder_src.h5'],
-                                 [self.decoder_dst, 'decoder_dst.h5']] )
+        self.save_weights_safe( self.get_model_filename_list() )
 
    #override
    def onTrainOneIter(self, sample, generators_list):
@@ -71,11 +71,15 @@ class Model(ModelBase):
                output_sample_types=output_sample_types)
            ])
 
+    #override
+    def get_model_filename_list(self):
+        return [[self.encoder, 'encoder.h5'],
+                [self.decoder_src, 'decoder_src.h5'],
+                [self.decoder_dst, 'decoder_dst.h5']]
+
    #override
    def onSave(self):
-        self.save_weights_safe( [[self.encoder, 'encoder.h5'],
-                                 [self.decoder_src, 'decoder_src.h5'],
-                                 [self.decoder_dst, 'decoder_dst.h5']] )
+        self.save_weights_safe( self.get_model_filename_list() )
 
    #override
    def onTrainOneIter(self, sample, generators_list):
@@ -65,12 +65,16 @@ class Model(ModelBase):
                output_sample_types=output_sample_types)
            ])
 
+    #override
+    def get_model_filename_list(self):
+        return [[self.encoder, 'encoder.h5'],
+                [self.decoder, 'decoder.h5'],
+                [self.inter_B, 'inter_B.h5'],
+                [self.inter_AB, 'inter_AB.h5']]
+
    #override
    def onSave(self):
-        self.save_weights_safe( [[self.encoder, 'encoder.h5'],
-                                 [self.decoder, 'decoder.h5'],
-                                 [self.inter_B, 'inter_B.h5'],
-                                 [self.inter_AB, 'inter_AB.h5']] )
+        self.save_weights_safe( self.get_model_filename_list() )
 
    #override
    def onTrainOneIter(self, sample, generators_list):
@@ -201,14 +201,18 @@ class RecycleGANModel(ModelBase):
        else:
            self.G_convert = K.function([real_B0],[fake_A0])
 
+    #override
+    def get_model_filename_list(self):
+        return [ [self.GA, 'GA.h5'],
+                 [self.GB, 'GB.h5'],
+                 [self.DA, 'DA.h5'],
+                 [self.DB, 'DB.h5'],
+                 [self.PA, 'PA.h5'],
+                 [self.PB, 'PB.h5'] ]
+
    #override
    def onSave(self):
-        self.save_weights_safe( [[self.GA, 'GA.h5'],
-                                 [self.GB, 'GB.h5'],
-                                 [self.DA, 'DA.h5'],
-                                 [self.DB, 'DB.h5'],
-                                 [self.PA, 'PA.h5'],
-                                 [self.PB, 'PB.h5'] ])
+        self.save_weights_safe( self.get_model_filename_list() )
 
    #override
    def onTrainOneIter(self, generators_samples, generators_list):
@@ -24,7 +24,7 @@ class SAEModel(ModelBase):
    #override
    def onInitializeOptions(self, is_first_run, ask_override):
        yn_str = {True:'y',False:'n'}
-
+
        default_resolution = 128
        default_archi = 'df'
        default_face_type = 'f'
@@ -90,12 +90,20 @@
 
            default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False)
            self.options['apply_random_ct'] = io.input_bool ("Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % (yn_str[default_apply_random_ct]), default_apply_random_ct, help_message="Increase variativity of src samples by apply LCT color transfer from random dst samples. It is like 'face_style' learning, but more precise color transfer and without risk of model collapse, also it does not require additional GPU resources, but the training time may be longer, due to the src faceset is becoming more diverse.")
-
+            if nnlib.device.backend != 'plaidML': # todo https://github.com/plaidml/plaidml/issues/301
+                default_clipgrad = False if is_first_run else self.options.get('clipgrad', False)
+                self.options['clipgrad'] = io.input_bool ("Enable gradient clipping? (y/n, ?:help skip:%s) : " % (yn_str[default_clipgrad]), default_clipgrad, help_message="Gradient clipping reduces chance of model collapse, sacrificing speed of training.")
+            else:
+                self.options['clipgrad'] = False
+
        else:
            self.options['pixel_loss'] = self.options.get('pixel_loss', False)
            self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
            self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
            self.options['apply_random_ct'] = self.options.get('apply_random_ct', False)
+
+            self.options['clipgrad'] = self.options.get('clipgrad', False)
+
        if is_first_run:
            self.options['pretrain'] = io.input_bool ("Pretrain the model? (y/n, ?:help skip:n) : ", False, help_message="Pretrain the model with large amount of various faces. This technique may help to train the fake with overly different face shapes and light conditions of src/dst data. Face will be look more like a morphed. To reduce the morph effect, some model files will be initialized but not be updated after pretrain: LIAE: inter_AB.h5 DF: encoder.h5. The longer you pretrain the model the more morphed face will look. After that, save and run the training again.")
        else:
@@ -271,8 +279,8 @@
        psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
 
        if self.is_training_mode:
-            self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
-            self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
+            self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
+            self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, clipnorm=1.0 if self.options['clipgrad'] else 0.0, tf_cpu_mode=self.options['optimizer_mode']-1)
 
        if 'liae' in self.options['archi']:
            src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
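Here `clipnorm=1.0` makes the optimizer rescale any gradient tensor whose L2 norm exceeds 1.0 back down to norm 1.0 before the Adam update, which bounds the step size; that is what the help text means by reducing the chance of collapse at some cost in training speed (this assumes nnlib's Adam follows the standard Keras clipnorm semantics). The rule itself is tiny:

    import numpy as np

    # Gradient clipping by norm: g is untouched while ||g|| <= c, otherwise
    # rescaled so its L2 norm is exactly c (the Keras clipnorm behaviour).
    def clip_by_norm(g, c=1.0):
        norm = np.linalg.norm(g)
        return g if norm <= c else g * (c / norm)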
@@ -375,12 +383,9 @@
                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i)} for i in range(ms_count)] + \
                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution // (2**i) } for i in range(ms_count)])
                   ])
 
-
    #override
-    def onSave(self):
-        opt_ar = [ [self.src_dst_opt, 'src_dst_opt'],
-                   [self.src_dst_mask_opt, 'src_dst_mask_opt']
-                  ]
+    def get_model_filename_list(self):
+        ar = []
        if 'liae' in self.options['archi']:
            ar += [[self.encoder, 'encoder.h5'],
@@ -407,9 +412,11 @@
        if self.options['learn_mask']:
            ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
                    [self.decoder_dstm, 'decoder_dstm.h5'] ]
-
-        self.save_weights_safe(ar)
-
+        return ar
+
+    #override
+    def onSave(self):
+        self.save_weights_safe( self.get_model_filename_list() )
 
    #override
    def onTrainOneIter(self, generators_samples, generators_list):
@@ -27,7 +27,24 @@ def get_image_unique_filestem_paths(dir_path, verbose_print_func=None):
            result_dup.add(f_stem)
 
    return result
 
+def get_file_paths(dir_path):
+    dir_path = Path (dir_path)
+
+    result = []
+    if dir_path.exists():
+        return [ x.path for x in list(scandir(str(dir_path))) if x.is_file() ]
+    return result
+
+def get_all_dir_names (dir_path):
+    dir_path = Path (dir_path)
+
+    result = []
+    if dir_path.exists():
+        return [ x.name for x in list(scandir(str(dir_path))) if x.is_dir() ]
+
+    return result
+
def get_all_dir_names_startswith (dir_path, startswith):
    dir_path = Path (dir_path)
    startswith = startswith.lower()
@@ -52,3 +69,15 @@ def get_first_file_by_stem (dir_path, stem, exts=None):
                return xp
 
    return None
+
+def move_all_files (src_dir_path, dst_dir_path):
+    paths = get_file_paths(src_dir_path)
+    for p in paths:
+        p = Path(p)
+        p.rename ( Path(dst_dir_path) / p.name )
+
+def delete_all_files (dir_path):
+    paths = get_file_paths(dir_path)
+    for p in paths:
+        p = Path(p)
+        p.unlink()
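These two Path_utils helpers exist to service the backup rotation in ModelBase.save(): delete_all_files purges the files in the oldest slot (the directory itself is kept), and move_all_files shifts a slot's files into its neighbour. For example (hypothetical slot paths, following the model/<name>_autobackups/NN layout):

    from utils import Path_utils

    # Hypothetical usage, mirroring one rotation step in ModelBase.save():
    Path_utils.delete_all_files('model/SAE_autobackups/15')   # purge the oldest slot
    Path_utils.move_all_files ('model/SAE_autobackups/14',    # shift its neighbour up
                               'model/SAE_autobackups/15')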