upd fan segmentator

This commit is contained in:
iperov 2019-03-19 19:44:14 +04:00
parent 1f569117c8
commit 034ad3cce5
3 changed files with 27 additions and 37 deletions

View file

@ -19,8 +19,10 @@ You can implement your own model. Check examples.
'''
class ModelBase(object):
#DONT OVERRIDE
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, device_args = None):
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, device_args = None,
ask_write_preview_history=True, ask_target_iter=True, ask_batch_size=True, ask_sort_by_yaw=True,
ask_random_flip=True, ask_src_scale_mod=True):
device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
device_args['cpu_only'] = device_args.get('cpu_only',False)
@ -58,6 +60,8 @@ class ModelBase(object):
self.options = {}
self.loss_history = []
self.sample_for_preview = None
model_data = {}
if self.model_data_path.exists():
model_data = pickle.loads ( self.model_data_path.read_bytes() )
self.iter = max( model_data.get('iter',0), model_data.get('epoch',0) )
@ -75,36 +79,36 @@ class ModelBase(object):
if self.iter == 0:
io.log_info ("\nModel first run. Enter model options as default for each run.")
if self.iter == 0 or ask_override:
if ask_write_preview_history and (self.iter == 0 or ask_override):
default_write_preview_history = False if self.iter == 0 else self.options.get('write_preview_history',False)
self.options['write_preview_history'] = io.input_bool("Write preview history? (y/n ?:help skip:%s) : " % (yn_str[default_write_preview_history]) , default_write_preview_history, help_message="Preview history will be writed to <ModelName>_history folder.")
else:
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
if self.iter == 0 or ask_override:
if ask_target_iter and (self.iter == 0 or ask_override):
self.options['target_iter'] = max(0, io.input_int("Target iteration (skip:unlimited/default) : ", 0))
else:
self.options['target_iter'] = max(model_data.get('target_iter',0), self.options.get('target_epoch',0))
if 'target_epoch' in self.options:
self.options.pop('target_epoch')
if self.iter == 0 or ask_override:
if ask_batch_size and (self.iter == 0 or ask_override):
default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0)
self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % (default_batch_size), default_batch_size, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
else:
self.options['batch_size'] = self.options.get('batch_size', 0)
if self.iter == 0:
if ask_sort_by_yaw and (self.iter == 0):
self.options['sort_by_yaw'] = io.input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
else:
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
if self.iter == 0:
if ask_random_flip and (self.iter == 0):
self.options['random_flip'] = io.input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
else:
self.options['random_flip'] = self.options.get('random_flip', True)
if self.iter == 0:
if ask_src_scale_mod and (self.iter == 0):
self.options['src_scale_mod'] = np.clip( io.input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
else:
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)