added '?' help for model options. Added 'Src face scale modifier' to model options.

This commit is contained in:
iperov 2019-01-09 09:22:22 +04:00
parent e8620919a7
commit 1a2555e160
8 changed files with 254 additions and 104 deletions

View file

@ -92,7 +92,9 @@ DF - good for side faces, but results in a lower resolution and details. Covers
LIAE - can partially fix dissimilar face shapes, but results in a less recognizable face.
SAE - no matter how similar faces, src face will be morphed onto dst face, which can make face absolutely unrecognizable. Model can collapse on some scenes. Easy to overlay final face because dst background is also predicted.
also quality of src faceset significantly affects the final face.
Quality of src faceset significantly affects the final face.
Narrow src face is better fakeable than wide. This is why Cage is so popular in deepfakes.
### **Sort tool**:

View file

@ -58,11 +58,12 @@ class ModelBase(object):
if self.epoch == 0:
print ("\nModel first run. Enter model options as default for each run.")
self.options['write_preview_history'] = input_bool("Write preview history? (y/n skip:n) : ", False)
self.options['write_preview_history'] = input_bool("Write preview history? (y/n ?:help skip:n) : ", False, help_message="Preview history will be writed to <ModelName>_history folder.")
self.options['target_epoch'] = max(0, input_int("Target epoch (skip:unlimited) : ", 0))
self.options['batch_size'] = max(0, input_int("Batch_size (skip:model choice) : ", 0))
self.options['sort_by_yaw'] = input_bool("Feed faces to network sorted by yaw? (y/n skip:n) : ", False)
self.options['random_flip'] = input_bool("Flip faces randomly? (y/n skip:y) : ", True)
self.options['batch_size'] = max(0, input_int("Batch_size (?:help skip:model choice) : ", 0, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
self.options['sort_by_yaw'] = input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
self.options['random_flip'] = input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
self.options['src_scale_mod'] = np.clip( input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
#self.options['use_fp16'] = use_fp16 = input_bool("Use float16? (y/n skip:n) : ", False)
else:
self.options['write_preview_history'] = self.options.get('write_preview_history', False)
@ -70,6 +71,7 @@ class ModelBase(object):
self.options['batch_size'] = self.options.get('batch_size', 0)
self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
self.options['random_flip'] = self.options.get('random_flip', True)
self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
#self.options['use_fp16'] = use_fp16 = self.options['use_fp16'] if 'use_fp16' in self.options.keys() else False
use_fp16 = False #currently models fails with fp16
@ -105,6 +107,10 @@ class ModelBase(object):
self.random_flip = self.options['random_flip']
if self.random_flip:
self.options.pop('random_flip')
self.src_scale_mod = self.options['src_scale_mod']
if self.src_scale_mod == 0:
self.options.pop('src_scale_mod')
self.write_preview_history = session_write_preview_history
self.target_epoch = session_target_epoch

View file

@ -40,7 +40,7 @@ class Model(ModelBase):
self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),

View file

@ -46,7 +46,7 @@ class Model(ModelBase):
self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128],
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 128],
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),

View file

@ -47,7 +47,7 @@ class Model(ModelBase):
self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_BGR, 64],
[f.TRANSFORMED | f.FACE_ALIGN_HALF | f.MODE_M | f.FACE_MASK_FULL, 64] ] ),

View file

@ -47,7 +47,7 @@ class Model(ModelBase):
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
output_sample_types=[ [f.WARPED_TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_BGR, 128],
[f.TRANSFORMED | f.FACE_ALIGN_FULL | f.MODE_M | f.FACE_MASK_FULL, 128] ] ),

View file

@ -15,21 +15,32 @@ class SAEModel(ModelBase):
decoderH5 = 'decoder.h5'
decodermH5 = 'decoderm.h5'
decoder_srcH5 = 'decoder_src.h5'
decoder_srcmH5 = 'decoder_srcm.h5'
decoder_dstH5 = 'decoder_dst.h5'
decoder_dstmH5 = 'decoder_dstm.h5'
#override
def onInitializeOptions(self, is_first_run, ask_for_session_options):
default_resolution = 128
default_ae_dims = 256
default_archi = 'liae'
default_style_power = 100
default_face_type = 'f'
if is_first_run:
#first run
self.options['resolution'] = input_int("Resolution (valid: 64,128, skip:128) : ", default_resolution, [64,128])
self.options['ae_dims'] = input_int("AutoEncoder dims (valid: 128,256,512 skip:256) : ", default_ae_dims, [128,256,512])
self.options['face_type'] = input_str ("Half or Full face? (h/f, skip:f) : ", default_face_type, ['h','f'])
self.options['resolution'] = input_int("Resolution (64,128, ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
self.options['style_power'] = np.clip ( input_int("Style power (1..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst style during generalization of src and dst faces."), 1, 100 )
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
self.options['ae_dims'] = input_int("AutoEncoder dims (128,256,512 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, [128,256,512], help_message="More dims are better, but requires more VRAM." )
self.options['face_type'] = input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
else:
#not first run
self.options['resolution'] = self.options.get('resolution', default_resolution)
self.options['archi'] = self.options.get('archi', default_archi)
self.options['style_power'] = self.options.get('style_power', default_style_power)
default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
self.options['face_type'] = self.options.get('face_type', default_face_type)
@ -44,52 +55,84 @@ class SAEModel(ModelBase):
bgr_shape = (resolution, resolution, 3)
mask_shape = (resolution, resolution, 1)
self.encoder = modelify(SAEModel.EncFlow() ) (Input(bgr_shape))
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
self.inter_B = modelify(SAEModel.InterFlow(dims=ae_dims,lowest_dense_res=resolution // 16)) (enc_output_Inputs)
self.inter_AB = modelify(SAEModel.InterFlow(dims=ae_dims,lowest_dense_res=resolution // 16)) (enc_output_Inputs)
inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
self.decoder = modelify(SAEModel.DecFlow (bgr_shape[2],dims=ae_dims*2)) (inter_output_Inputs)
self.decoderm = modelify(SAEModel.DecFlow (mask_shape[2],dims=ae_dims)) (inter_output_Inputs)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5))
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5))
warped_src = Input(bgr_shape)
target_src = Input(bgr_shape)
target_srcm = Input(mask_shape)
warped_src_code = self.encoder (warped_src)
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
pred_src_src = self.decoder(warped_src_inter_code)
pred_src_srcm = self.decoderm(warped_src_inter_code)
warped_dst = Input(bgr_shape)
target_dst = Input(bgr_shape)
target_dstm = Input(mask_shape)
if self.options['archi'] == 'liae':
self.encoder = modelify(SAEModel.EncFlow() ) (Input(bgr_shape))
enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
self.inter_B = modelify(SAEModel.InterFlow(dims=ae_dims,lowest_dense_res=resolution // 16)) (enc_output_Inputs)
self.inter_AB = modelify(SAEModel.InterFlow(dims=ae_dims,lowest_dense_res=resolution // 16)) (enc_output_Inputs)
inter_output_Inputs = [ Input( np.array(K.int_shape(x)[1:])*(1,1,2) ) for x in self.inter_B.outputs ]
self.decoder = modelify(SAEModel.DecFlow (bgr_shape[2],dims=ae_dims*2)) (inter_output_Inputs)
self.decoderm = modelify(SAEModel.DecFlow (mask_shape[2],dims=ae_dims)) (inter_output_Inputs)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5))
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5))
warped_src_code = self.encoder (warped_src)
warped_src_inter_AB_code = self.inter_AB (warped_src_code)
warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
pred_src_src = self.decoder(warped_src_inter_code)
pred_src_srcm = self.decoderm(warped_src_inter_code)
warped_dst_code = self.encoder (warped_dst)
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
pred_dst_dst = self.decoder(warped_dst_inter_code)
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
pred_src_dst = self.decoder(warped_src_dst_inter_code)
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
else:
self.encoder = modelify(SAEModel.DFEncFlow(dims=ae_dims,lowest_dense_res=resolution // 16) ) (Input(bgr_shape))
dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],dims=ae_dims)) (dec_Inputs)
self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],dims=ae_dims)) (dec_Inputs)
self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],dims=ae_dims//2)) (dec_Inputs)
self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],dims=ae_dims//2)) (dec_Inputs)
warped_dst_code = self.encoder (warped_dst)
warped_dst_inter_B_code = self.inter_B (warped_dst_code)
warped_dst_inter_AB_code = self.inter_AB (warped_dst_code)
warped_dst_inter_code = Concatenate()([warped_dst_inter_B_code,warped_dst_inter_AB_code])
pred_dst_dst = self.decoder(warped_dst_inter_code)
pred_dst_dstm = self.decoderm(warped_dst_inter_code)
warped_src_dst_inter_code = Concatenate()([warped_dst_inter_AB_code,warped_dst_inter_AB_code])
pred_src_dst = self.decoder(warped_src_dst_inter_code)
pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
warped_src_code = self.encoder (warped_src)
pred_src_src = self.decoder_src(warped_src_code)
pred_src_srcm = self.decoder_srcm(warped_src_code)
warped_dst_code = self.encoder (warped_dst)
pred_dst_dst = self.decoder_dst(warped_dst_code)
pred_dst_dstm = self.decoder_dstm(warped_dst_code)
pred_src_dst = self.decoder_src(warped_dst_code)
pred_src_dstm = self.decoder_srcm(warped_dst_code)
target_srcm_blurred = tf_gaussian_blur(resolution // 32)(target_srcm)
target_srcm_sigm = target_srcm_blurred / 2.0 + 0.5
target_srcm_anti_sigm = 1.0 - target_srcm_sigm
@ -105,48 +148,61 @@ class SAEModel(ModelBase):
pred_dst_dst_sigm = pred_dst_dst+1
pred_src_dst_sigm = pred_src_dst+1
pred_src_dstm_blurred = tf_gaussian_blur(resolution // 32)(pred_src_dstm)
pred_src_dstm_sigm = pred_src_dstm_blurred / 2.0 + 0.5
pred_src_dstm_anti_sigm = 1.0 - pred_src_dstm_sigm
target_src_masked = target_src_sigm*target_srcm_sigm
target_dst_masked = target_dst_sigm * target_dstm_sigm
target_dst_anti_masked = target_dst_sigm * target_dstm_anti_sigm
target_dst_psd_masked = target_dst_sigm * pred_src_dstm_sigm
target_dst_psd_anti_masked = target_dst_sigm * pred_src_dstm_anti_sigm
pred_src_src_masked = pred_src_src_sigm * target_srcm_sigm
pred_dst_dst_masked = pred_dst_dst_sigm * target_dstm_sigm
psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm
psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm
psd_psd_masked = pred_src_dst_sigm * pred_src_dstm_sigm
psd_psd_anti_masked = pred_src_dst_sigm * pred_src_dstm_anti_sigm
style_power = self.options['style_power'] / 100.0
src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2)(psd_target_dst_masked, target_dst_masked)
src_loss += K.mean( 100*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
#src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2)(psd_psd_masked, target_dst_psd_masked)
#src_loss += K.mean( 100*K.square(tf_dssim(2.0)( psd_psd_anti_masked, target_dst_psd_anti_masked )))
#src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked)
src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
if self.options['archi'] == 'liae':
src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
else:
src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss],
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights) )
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, src_train_weights) )
dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )
if self.options['archi'] == 'liae':
dst_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
else:
dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss],
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights) )
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, dst_train_weights) )
src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))
if self.options['archi'] == 'liae':
src_mask_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
else:
src_mask_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights
self.src_mask_train = K.function ([warped_src, target_srcm],[src_mask_loss],
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_mask_loss, self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights) )
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_mask_loss, src_mask_train_weights ) )
dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))
dst_mask_loss = K.mean(K.square(target_dstm-pred_dst_dstm))
if self.options['archi'] == 'liae':
dst_mask_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights
else:
dst_mask_train_weights = self.encoder.trainable_weights + self.decoder_dstm.trainable_weights
self.dst_mask_train = K.function ([warped_dst, target_dstm],[dst_mask_loss],
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_mask_loss, self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoderm.trainable_weights) )
Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_mask_loss, dst_mask_train_weights) )
self.AE_view = K.function ([warped_src, warped_dst],[pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm])
self.AE_convert = K.function ([warped_dst],[pred_src_dst, pred_src_dstm])
@ -155,11 +211,11 @@ class SAEModel(ModelBase):
f = SampleProcessor.TypeFlags
face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF
self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True),
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
[f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
[f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution],
@ -174,14 +230,22 @@ class SAEModel(ModelBase):
[f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution] ] )
])
#override
def onSave(self):
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)],
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)],
[self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)],
] )
def onSave(self):
if self.options['archi'] == 'liae':
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
[self.inter_B, self.get_strpath_storage_for_file(self.inter_BH5)],
[self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)],
[self.decoder, self.get_strpath_storage_for_file(self.decoderH5)],
[self.decoderm, self.get_strpath_storage_for_file(self.decodermH5)],
] )
else:
self.save_weights_safe( [[self.encoder, self.get_strpath_storage_for_file(self.encoderH5)],
[self.decoder_src, self.get_strpath_storage_for_file(self.decoder_srcH5)],
[self.decoder_srcm, self.get_strpath_storage_for_file(self.decoder_srcmH5)],
[self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)],
[self.decoder_dstm, self.get_strpath_storage_for_file(self.decoder_dstmH5)],
] )
#override
def onTrainOneEpoch(self, sample):
warped_src, target_src, target_src_mask = sample[0]
@ -216,7 +280,7 @@ class SAEModel(ModelBase):
st.append ( np.concatenate ( (
S[i], SS[i], #SM[i],
D[i], DD[i], #DM[i],
SD[i], #SDM[i]
SD[i], SDM[i]
), axis=1) )
return [ ('SAE', np.concatenate ( st, axis=0 ) ) ]
@ -306,4 +370,54 @@ class SAEModel(ModelBase):
return func
@staticmethod
def DFEncFlow(dims=512, lowest_dense_res=8):
exec (nnlib.import_all(), locals(), globals())
def downscale (dim):
def func(x):
return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
return func
def upscale (dim):
def func(x):
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
return func
def func(input):
x = input
x = downscale(128)(x)
x = downscale(256)(x)
x = downscale(512)(x)
x = downscale(1024)(x)
x = Dense(dims)(Flatten()(x))
x = Dense(lowest_dense_res * lowest_dense_res * dims)(x)
x = Reshape((lowest_dense_res, lowest_dense_res, dims))(x)
x = upscale(dims)(x)
return x
return func
@staticmethod
def DFDecFlow(output_nc,dims,activation='tanh'):
exec (nnlib.import_all(), locals(), globals())
def upscale (dim):
def func(x):
return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
return func
def func(input):
x = input[0]
x = upscale(dims)(x)
x = upscale(dims//2)(x)
x = upscale(dims//4)(x)
x = Conv2D(output_nc, kernel_size=5, padding='same', activation=activation)(x)
return x
return func
Model = SAEModel

View file

@ -1,26 +1,54 @@
def input_int(s, default_value, valid_list=None):
try:
inp = input(s)
i = int(inp)
if (valid_list is not None) and (i not in valid_list):
def input_int(s, default_value, valid_list=None, help_message=None):
while True:
try:
inp = input(s)
if len(inp) == 0:
raise ValueError("")
if help_message is not None and inp == '?':
print (help_message)
continue
i = int(inp)
if (valid_list is not None) and (i not in valid_list):
return default_value
return i
except:
print (default_value)
return default_value
return i
except:
return default_value
def input_bool(s, default_value):
try:
return bool ( {"y":True,"n":False,"1":True,"0":False}.get(input(s).lower(), default_value) )
except:
return default_value
def input_str(s, default_value, valid_list=None):
try:
inp = input(s)
if (valid_list is not None) and (inp.lower() not in valid_list):
def input_bool(s, default_value, help_message=None):
while True:
try:
inp = input(s)
if len(inp) == 0:
raise ValueError("")
if help_message is not None and inp == '?':
print (help_message)
continue
return bool ( {"y":True,"n":False,"1":True,"0":False}.get(inp.lower(), default_value) )
except:
print ( "y" if default_value else "n" )
return default_value
return inp
except:
return default_value
def input_str(s, default_value, valid_list=None, help_message=None):
while True:
try:
inp = input(s)
if len(inp) == 0:
raise ValueError("")
if help_message is not None and inp == '?':
print (help_message)
continue
if (valid_list is not None) and (inp.lower() not in valid_list):
return default_value
return inp
except:
print (default_value)
return default_value