diff --git a/main.py b/main.py
index 2b7a0c2..acdaf6a 100644
--- a/main.py
+++ b/main.py
@@ -81,16 +81,10 @@ if __name__ == "__main__":
                         training_data_dst_dir=arguments.training_data_dst_dir,
                         model_path=arguments.model_dir,
                         model_name=arguments.model_name,
-                        ask_for_session_options = arguments.ask_for_session_options,
                         debug = arguments.debug,
                         #**options
-                        session_write_preview_history = arguments.session_write_preview_history,
-                        session_target_epoch = arguments.session_target_epoch,
-                        session_batch_size = arguments.session_batch_size,
-                        save_interval_min = arguments.save_interval_min,
                         choose_worst_gpu = arguments.choose_worst_gpu,
                         force_best_gpu_idx = arguments.force_best_gpu_idx,
-                        multi_gpu = arguments.multi_gpu,
                         force_gpu_idxs = arguments.force_gpu_idxs,
                         cpu_only = arguments.cpu_only
                         )
@@ -101,14 +95,8 @@ if __name__ == "__main__":
     train_parser.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Model dir.")
     train_parser.add_argument('--model', required=True, dest="model_name", choices=Path_utils.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Type of model")
     train_parser.add_argument('--debug', action="store_true", dest="debug", default=False, help="Debug samples.")
-    train_parser.add_argument('--ask-for-session-options', action="store_true", dest="ask_for_session_options", default=False, help="Ask to override session options.")
-    train_parser.add_argument('--session-write-preview-history', action="store_true", dest="session_write_preview_history", default=None, help="Enable write preview history for this session.")
-    train_parser.add_argument('--session-target-epoch', type=int, dest="session_target_epoch", default=0, help="Train until target epoch for this session. Default - unlimited. Environment variable to override: DFL_TARGET_EPOCH.")
-    train_parser.add_argument('--session-batch-size', type=int, dest="session_batch_size", default=0, help="Model batch size for this session. Default - auto. Environment variable to override: DFL_BATCH_SIZE.")
-    train_parser.add_argument('--save-interval-min', type=int, dest="save_interval_min", default=10, help="Save interval in minutes. Default 10.")
     train_parser.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
     train_parser.add_argument('--force-gpu-idxs', type=str, dest="force_gpu_idxs", default=None, help="Override final GPU idxs. Example: 0,1,2.")
-    train_parser.add_argument('--multi-gpu', action="store_true", dest="multi_gpu", default=False, help="MultiGPU option (if model supports it). It will select only same best(worst) GPU models.")
     train_parser.add_argument('--choose-worst-gpu', action="store_true", dest="choose_worst_gpu", default=False, help="Choose worst GPU instead of best. Environment variable to force True: DFL_WORST_GPU")
     train_parser.add_argument('--force-best-gpu-idx', type=int, dest="force_best_gpu_idx", default=-1, help="Force to choose this GPU idx as best(worst).")
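The removed --session-* flags are not simply dropped: ModelBase below replaces them with interactive prompts whose answers live in the model's options dict and persist next to the weights in data.dat. A minimal sketch of the load-or-ask pattern those hunks repeat per option; pickle as the serializer and the helper signature are assumptions, only the data.dat name comes from the diff:

    import pickle
    from pathlib import Path

    def load_or_ask(model_data_path, is_first_run, ask_override, input_bool):
        # Hypothetical condensation of the ModelBase flow: load persisted
        # options, prompt only on first run or on an explicit override.
        options = {}
        path = Path(model_data_path)
        if path.exists():
            options = pickle.loads(path.read_bytes()).get('options', {})
        if is_first_run or ask_override:
            options['write_preview_history'] = input_bool(
                "Write preview history? (y/n ?:help skip:n) : ", False)
        else:
            options.setdefault('write_preview_history', False)
        return options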
diff --git a/models/ModelBase.py b/models/ModelBase.py
index 4913b54..00bf8ca 100644
--- a/models/ModelBase.py
+++ b/models/ModelBase.py
@@ -18,14 +18,7 @@ You can implement your own model. Check examples.
 class ModelBase(object):

     #DONT OVERRIDE
-    def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None,
-                    ask_for_session_options=False,
-                    session_write_preview_history = None,
-                    session_target_epoch=0,
-                    session_batch_size=0,
-
-                    debug = False, **in_options
-                ):
+    def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, **in_options):
         print ("Loading model...")
         self.model_path = model_path
         self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') )
@@ -56,50 +49,52 @@ class ModelBase(object):
             self.loss_history = model_data['loss_history'] if 'loss_history' in model_data.keys() else []
             self.sample_for_preview = model_data['sample_for_preview'] if 'sample_for_preview' in model_data.keys() else None

+        ask_override = self.epoch != 0 and input_in_time ("Press enter within 2 seconds to override model settings.", 2)
+
         if self.epoch == 0:
             print ("\nModel first run. Enter model options as default for each run.")
+
+        if self.epoch == 0 or ask_override:
             self.options['write_preview_history'] = input_bool("Write preview history? (y/n ?:help skip:n) : ", False, help_message="Preview history will be writed to _history folder.")
-            self.options['target_epoch'] = max(0, input_int("Target epoch (skip:unlimited) : ", 0))
-            self.options['batch_size'] = max(0, input_int("Batch_size (?:help skip:model choice) : ", 0, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
-            self.options['sort_by_yaw'] = input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
-            self.options['random_flip'] = input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
-            self.options['src_scale_mod'] = np.clip( input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
-            #self.options['use_fp16'] = use_fp16 = input_bool("Use float16? (y/n skip:n) : ", False)
-        else:
-            self.options['write_preview_history'] = self.options.get('write_preview_history', False)
-            self.options['target_epoch'] = self.options.get('target_epoch', 0)
-            self.options['batch_size'] = self.options.get('batch_size', 0)
-            self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
-            self.options['random_flip'] = self.options.get('random_flip', True)
-            self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
-            #self.options['use_fp16'] = use_fp16 = self.options['use_fp16'] if 'use_fp16' in self.options.keys() else False
-
-        use_fp16 = False #currently models fails with fp16
-
-        if ask_for_session_options:
-            print ("Override options for current session:")
-            session_write_preview_history = input_bool("Write preview history? (y/n skip:default) : ", None )
-            session_target_epoch = input_int("Target epoch (skip:default) : ", 0)
-            session_batch_size = input_int("Batch_size (skip:default) : ", 0)
-
-        if self.options['write_preview_history']:
-            if session_write_preview_history is None:
-                session_write_preview_history = self.options['write_preview_history']
         else:
+            self.options['write_preview_history'] = self.options.get('write_preview_history', False)
+
+        if self.epoch == 0 or ask_override:
+            self.options['target_epoch'] = max(0, input_int("Target epoch (skip:unlimited) : ", 0))
+        else:
+            self.options['target_epoch'] = self.options.get('target_epoch', 0)
+
+        if self.epoch == 0 or ask_override:
+            default_batch_size = 0 if self.epoch == 0 else self.options['batch_size']
+            self.options['batch_size'] = max(0, input_int("Batch_size (?:help skip:default) : ", default_batch_size, help_message="Larger batch size is always better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
+        else:
+            self.options['batch_size'] = self.options.get('batch_size', 0)
+
+        if self.epoch == 0:
+            self.options['sort_by_yaw'] = input_bool("Feed faces to network sorted by yaw? (y/n ?:help skip:n) : ", False, help_message="NN will not learn src face directions that don't match dst face directions." )
+        else:
+            self.options['sort_by_yaw'] = self.options.get('sort_by_yaw', False)
+
+        if self.epoch == 0:
+            self.options['random_flip'] = input_bool("Flip faces randomly? (y/n ?:help skip:y) : ", True, help_message="Predicted face will look more naturally without this option, but src faceset should cover all face directions as dst faceset.")
+        else:
+            self.options['random_flip'] = self.options.get('random_flip', True)
+
+        if self.epoch == 0:
+            self.options['src_scale_mod'] = np.clip( input_int("Src face scale modifier % ( -30...30, ?:help skip:0) : ", 0, help_message="If src face shape is wider than dst, try to decrease this value to get a better result."), -30, 30)
+        else:
+            self.options['src_scale_mod'] = self.options.get('src_scale_mod', 0)
+
+        self.write_preview_history = self.options['write_preview_history']
+        if not self.options['write_preview_history']:
             self.options.pop('write_preview_history')

-        if self.options['target_epoch'] != 0:
-            if session_target_epoch == 0:
-                session_target_epoch = self.options['target_epoch']
-        else:
+        self.target_epoch = self.options['target_epoch']
+        if self.options['target_epoch'] == 0:
             self.options.pop('target_epoch')

-        if self.options['batch_size'] != 0:
-            if session_batch_size == 0:
-                session_batch_size = self.options['batch_size']
-        else:
-            self.options.pop('batch_size')
-
+        self.batch_size = self.options['batch_size']
+
         self.sort_by_yaw = self.options['sort_by_yaw']
         if not self.sort_by_yaw:
             self.options.pop('sort_by_yaw')
@@ -112,18 +107,17 @@ class ModelBase(object):
         if self.src_scale_mod == 0:
             self.options.pop('src_scale_mod')

-        self.write_preview_history = session_write_preview_history
-        self.target_epoch = session_target_epoch
-        self.batch_size = session_batch_size
-        self.onInitializeOptions(self.epoch == 0, ask_for_session_options)
+        self.onInitializeOptions(self.epoch == 0, ask_override)

-        nnlib.import_all ( nnlib.DeviceConfig(allow_growth=False, use_fp16=use_fp16, **in_options) )
+        nnlib.import_all ( nnlib.DeviceConfig(allow_growth=False, **in_options) )
         self.device_config = nnlib.active_DeviceConfig

         self.created_vram_gb = self.options['created_vram_gb'] if 'created_vram_gb' in self.options.keys() else self.device_config.gpu_total_vram_gb

         self.onInitialize(**in_options)

+        self.options['batch_size'] = self.batch_size
+
         if self.debug or self.batch_size == 0:
             self.batch_size = 1
@@ -155,16 +149,10 @@ class ModelBase(object):
         print ("==")
         print ("== Model options:")
         for key in self.options.keys():
-            print ("== |== %s : %s" % (key, self.options[key]) )
-        print ("== Session options:")
-        if self.write_preview_history:
-            print ("== |== write_preview_history : True ")
-        if self.target_epoch != 0:
-            print ("== |== target_epoch : %s " % (self.target_epoch) )
-        print ("== |== batch_size : %s " % (self.batch_size) )
+            print ("== |== %s : %s" % (key, self.options[key]) )
+
         if self.device_config.multi_gpu:
-            print ("== |== multi_gpu : True ")
-
+            print ("== |== multi_gpu : True ")
         print ("== Running on:")
         if self.device_config.cpu_only:
@@ -183,7 +171,7 @@ class ModelBase(object):
         print ("=========================")

     #overridable
-    def onInitializeOptions(self, is_first_run, ask_for_session_options):
+    def onInitializeOptions(self, is_first_run, ask_override):
         pass

     #overridable
@@ -231,23 +219,24 @@ class ModelBase(object):
     def is_reached_epoch_goal(self):
         return self.target_epoch != 0 and self.epoch >= self.target_epoch

-    def to_multi_gpu_model_if_possible (self, models_list):
-        if len(self.device_config.gpu_idxs) > 1:
-            #make batch_size to divide on GPU count without remainder
-            self.batch_size = int( self.batch_size / len(self.device_config.gpu_idxs) )
-            if self.batch_size == 0:
-                self.batch_size = 1
-            self.batch_size *= len(self.device_config.gpu_idxs)
-
-            result = []
-            for model in models_list:
-                for i in range( len(model.output_names) ):
-                    model.output_names = 'output_%d' % (i)
-                result += [ nnlib.keras.utils.multi_gpu_model( model, self.device_config.gpu_idxs ) ]
-
-            return result
-        else:
-            return models_list
+    #multi-GPU in keras is effectively fake and doesn't work for training: https://github.com/keras-team/keras/issues/11976
+    #def to_multi_gpu_model_if_possible (self, models_list):
+    #    if len(self.device_config.gpu_idxs) > 1:
+    #        #make batch_size divide by the GPU count without remainder
+    #        self.batch_size = int( self.batch_size / len(self.device_config.gpu_idxs) )
+    #        if self.batch_size == 0:
+    #            self.batch_size = 1
+    #        self.batch_size *= len(self.device_config.gpu_idxs)
+    #
+    #        result = []
+    #        for model in models_list:
+    #            for i in range( len(model.output_names) ):
+    #                model.output_names = 'output_%d' % (i)
+    #            result += [ nnlib.keras.utils.multi_gpu_model( model, self.device_config.gpu_idxs ) ]
+    #
+    #        return result
+    #    else:
+    #        return models_list

     def get_previews(self):
         return self.onGetPreview ( self.last_sample )
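Every option in the ModelBase hunk above follows the same shape: the live value is copied onto the instance, and keys still equal to their default are popped, so the "== Model options:" summary lists only settings the user actually changed. Distilled into a hypothetical helper (not part of the patch):

    def adopt_option(options, key, default):
        # Return the live value; drop the key when it still equals its
        # default so the persisted options dict stays minimal.
        value = options.get(key, default)
        if value == default:
            options.pop(key, None)
        return value

    # e.g. self.sort_by_yaw = adopt_option(self.options, 'sort_by_yaw', False)

batch_size is the one exception: it is written back into options after onInitialize, so the value the model settles on persists into later runs.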
diff --git a/models/Model_DF/Model.py b/models/Model_DF/Model.py
index 5a5308e..71f2d50 100644
--- a/models/Model_DF/Model.py
+++ b/models/Model_DF/Model.py
@@ -29,9 +29,6 @@ class Model(ModelBase):
         self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder_src(self.encoder(ae_input_layer)))
         self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder_dst(self.encoder(ae_input_layer)))

-        if self.is_training_mode:
-            self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] )
-
         self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
         self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
diff --git a/models/Model_H128/Model.py b/models/Model_H128/Model.py
index 6e12a90..af98a73 100644
--- a/models/Model_H128/Model.py
+++ b/models/Model_H128/Model.py
@@ -32,9 +32,6 @@ class Model(ModelBase):
         self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )

-        if self.is_training_mode:
-            self.ae, = self.to_multi_gpu_model_if_possible ( [self.ae,] )
-
         self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
                         loss=[ DSSIMMaskLoss([input_src_mask]), 'mae', DSSIMMaskLoss([input_dst_mask]), 'mae' ] )
diff --git a/models/Model_H64/Model.py b/models/Model_H64/Model.py
index 9705127..3f5d5e3 100644
--- a/models/Model_H64/Model.py
+++ b/models/Model_H64/Model.py
@@ -33,9 +33,6 @@ class Model(ModelBase):
         self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )

-        if self.is_training_mode:
-            self.ae, = self.to_multi_gpu_model_if_possible ( [self.ae,] )
-
         self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
                         loss=[ DSSIMMaskLoss([input_src_mask]), 'mae', DSSIMMaskLoss([input_dst_mask]), 'mae' ] )
diff --git a/models/Model_LIAEF128/Model.py b/models/Model_LIAEF128/Model.py
index a28da8d..6d8ad17 100644
--- a/models/Model_LIAEF128/Model.py
+++ b/models/Model_LIAEF128/Model.py
@@ -34,9 +34,6 @@ class Model(ModelBase):
         self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([AB, AB])) )
         self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([B, AB])) )

-        if self.is_training_mode:
-            self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] )
-
         self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
         self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
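The Model_SAE diff below splits the old style_power into independent face and background powers (0 now disables a term, where 1 used to be the minimum) and adds a lighter_encoder option that swaps the three deepest Conv2D downscales for SeparableConv2D. A depthwise-separable 5x5 convolution factors the dense kernel into a per-channel spatial filter plus a 1x1 pointwise mix, which is where the quoted 35% speedup comes from. A standalone sketch of the parameter savings at the encoder's 128-to-256 stage; tf.keras is used here only to keep the snippet self-contained, while nnlib binds its own keras:

    from tensorflow.keras import Input, Model, layers

    # Same shape as the second downscale of the encoder below.
    inp = Input((64, 64, 128))
    dense = Model(inp, layers.Conv2D(256, 5, strides=2, padding='same')(inp))
    sep = Model(inp, layers.SeparableConv2D(256, 5, strides=2, padding='same')(inp))
    print(dense.count_params())  # 5*5*128*256 + 256 = 819,456
    print(sep.count_params())    # 5*5*128 + 128*256 + 256 = 36,224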
diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py
index 5c72063..f94a9f1 100644
--- a/models/Model_SAE/Model.py
+++ b/models/Model_SAE/Model.py
@@ -21,32 +21,40 @@ class SAEModel(ModelBase):
     decoder_dstmH5 = 'decoder_dstm.h5'

     #override
-    def onInitializeOptions(self, is_first_run, ask_for_session_options):
+    def onInitializeOptions(self, is_first_run, ask_override):
         default_resolution = 128
         default_archi = 'liae'
         default_style_power = 100
         default_face_type = 'f'

-        if is_first_run:
-            #first run
+        if is_first_run:
             self.options['resolution'] = input_int("Resolution (64,128, ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
             self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:liae) : ", default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
-            self.options['learn_face_style'] = input_bool("Learn face style? (y/n skip:y) : ", True)
-            self.options['learn_bg_style'] = input_bool("Learn background style? (y/n skip:y) : ", True)
-            self.options['style_power'] = np.clip ( input_int("Style power (1..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst style during generalization of src and dst faces."), 1, 100 )
-            default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
+            self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, but has not been tested on various scenes.")
+        else:
+            self.options['resolution'] = self.options.get('resolution', default_resolution)
+            self.options['archi'] = self.options.get('archi', default_archi)
+            self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
+
+        if is_first_run or ask_override:
+            self.options['face_style_power'] = np.clip ( input_int("Face style power (0..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst face style during generalization of src and dst faces."), 0, 100 )
+        else:
+            self.options['face_style_power'] = self.options.get('face_style_power', default_style_power)
+
+        if is_first_run or ask_override:
+            self.options['bg_style_power'] = np.clip ( input_int("Background style power (0..100 ?:help skip:100) : ", default_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces."), 0, 100 )
+        else:
+            self.options['bg_style_power'] = self.options.get('bg_style_power', default_style_power)
+
+        default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
+
+        if is_first_run:
             self.options['ae_dims'] = input_int("AutoEncoder dims (128,256,512 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, [128,256,512], help_message="More dims are better, but requires more VRAM." )
             self.options['face_type'] = input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
-        else:
-            #not first run
-            self.options['resolution'] = self.options.get('resolution', default_resolution)
-            self.options['learn_face_style'] = self.options.get('learn_face_style', True)
-            self.options['learn_bg_style'] = self.options.get('learn_bg_style', True)
-            self.options['archi'] = self.options.get('archi', default_archi)
-            self.options['style_power'] = self.options.get('style_power', default_style_power)
-            default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
+        else:
             self.options['ae_dims'] = self.options.get('ae_dims', default_ae_dims)
             self.options['face_type'] = self.options.get('face_type', default_face_type)
+
     #override
     def onInitialize(self, **in_options):
@@ -68,7 +76,7 @@ class SAEModel(ModelBase):
         target_dstm = Input(mask_shape)

         if self.options['archi'] == 'liae':
-            self.encoder = modelify(SAEModel.EncFlow() ) (Input(bgr_shape))
+            self.encoder = modelify(SAEModel.EncFlow(self.options['lighter_encoder']) ) (Input(bgr_shape))

             enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
@@ -107,7 +115,7 @@ class SAEModel(ModelBase):
             pred_src_dst = self.decoder(warped_src_dst_inter_code)
             pred_src_dstm = self.decoderm(warped_src_dst_inter_code)
         else:
-            self.encoder = modelify(SAEModel.DFEncFlow(dims=ae_dims,lowest_dense_res=resolution // 16) ) (Input(bgr_shape))
+            self.encoder = modelify(SAEModel.DFEncFlow(self.options['lighter_encoder'], dims=ae_dims,lowest_dense_res=resolution // 16) ) (Input(bgr_shape))

             dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
@@ -162,17 +170,16 @@ class SAEModel(ModelBase):
         psd_target_dst_masked = pred_src_dst_sigm * target_dstm_sigm
         psd_target_dst_anti_masked = pred_src_dst_sigm * target_dstm_anti_sigm
-
-
-        style_power = self.options['style_power'] / 100.0
-
+
         src_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked, pred_src_src_masked )) )
-        if self.options['learn_face_style']:
-            src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*style_power)(psd_target_dst_masked, target_dst_masked)
+        if self.options['face_style_power'] != 0:
+            face_style_power = self.options['face_style_power'] / 100.0
+            src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)(psd_target_dst_masked, target_dst_masked)

-        if self.options['learn_bg_style']:
-            src_loss += K.mean( (100*style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))
+        if self.options['bg_style_power'] != 0:
+            bg_style_power = self.options['bg_style_power'] / 100.0
+            src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked, target_dst_anti_masked )))

         if self.options['archi'] == 'liae':
             src_train_weights = self.encoder.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
@@ -180,8 +187,7 @@ class SAEModel(ModelBase):
             src_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights
         self.src_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss], Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(src_loss, src_train_weights) )
-
-
+
         dst_loss = K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked, pred_dst_dst_masked )) )

         if self.options['archi'] == 'liae':
@@ -190,8 +196,7 @@ class SAEModel(ModelBase):
             dst_train_weights = self.encoder.trainable_weights + self.decoder_dst.trainable_weights
         self.dst_train = K.function ([warped_dst, target_dst, target_dstm],[dst_loss], Adam(lr=5e-5, beta_1=0.5, beta_2=0.999).get_updates(dst_loss, dst_train_weights) )
-
-
+
         src_mask_loss = K.mean(K.square(target_srcm-pred_src_srcm))

         if self.options['archi'] == 'liae':
@@ -317,13 +322,18 @@ class SAEModel(ModelBase):
                              **in_options)

     @staticmethod
-    def EncFlow():
+    def EncFlow(light_enc):
         exec (nnlib.import_all(), locals(), globals())

         def downscale (dim):
             def func(x):
                 return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
             return func
+
+        def downscale_sep (dim):
+            def func(x):
+                return LeakyReLU(0.1)(SeparableConv2D(dim, 5, strides=2, padding='same')(x))
+            return func

         def upscale (dim):
             def func(x):
@@ -334,9 +344,15 @@ class SAEModel(ModelBase):
             x = input
             x = downscale(128)(x)
-            x = downscale(256)(x)
-            x = downscale(512)(x)
-            x = downscale(1024)(x)
+            if not light_enc:
+                x = downscale(256)(x)
+                x = downscale(512)(x)
+                x = downscale(1024)(x)
+            else:
+                x = downscale_sep(256)(x)
+                x = downscale_sep(512)(x)
+                x = downscale_sep(1024)(x)
+
             x = Flatten()(x)
             return x
         return func
@@ -380,7 +396,7 @@ class SAEModel(ModelBase):

     @staticmethod
-    def DFEncFlow(dims=512, lowest_dense_res=8):
+    def DFEncFlow(light_enc, dims=512, lowest_dense_res=8):
         exec (nnlib.import_all(), locals(), globals())

         def downscale (dim):
@@ -388,6 +404,11 @@ class SAEModel(ModelBase):
                 return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
             return func

+        def downscale_sep (dim):
+            def func(x):
+                return LeakyReLU(0.1)(SeparableConv2D(dim, 5, strides=2, padding='same')(x))
+            return func
+
         def upscale (dim):
             def func(x):
                 return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
@@ -397,9 +418,14 @@ class SAEModel(ModelBase):
             x = input
             x = downscale(128)(x)
-            x = downscale(256)(x)
-            x = downscale(512)(x)
-            x = downscale(1024)(x)
+            if not light_enc:
+                x = downscale(256)(x)
+                x = downscale(512)(x)
+                x = downscale(1024)(x)
+            else:
+                x = downscale_sep(256)(x)
+                x = downscale_sep(512)(x)
+                x = downscale_sep(1024)(x)

             x = Dense(dims)(Flatten()(x))
             x = Dense(lowest_dense_res * lowest_dense_res * dims)(x)
diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py
index 2ebf116..ec11380 100644
--- a/nnlib/nnlib.py
+++ b/nnlib/nnlib.py
@@ -72,6 +72,7 @@ Input = keras.layers.Input
 Dense = keras.layers.Dense
 Conv2D = keras.layers.Conv2D
 Conv2DTranspose = keras.layers.Conv2DTranspose
+SeparableConv2D = keras.layers.SeparableConv2D
 MaxPooling2D = keras.layers.MaxPooling2D
 BatchNormalization = keras.layers.BatchNormalization
diff --git a/utils/console_utils.py b/utils/console_utils.py
index 1af6002..1b2638a 100644
--- a/utils/console_utils.py
+++ b/utils/console_utils.py
@@ -1,4 +1,7 @@
-
+import os
+import sys
+import time
+import multiprocessing

 def input_int(s, default_value, valid_list=None, help_message=None):
     while True:
@@ -51,4 +54,28 @@ def input_str(s, default_value, valid_list=None, help_message=None):
                 return inp
         except:
             print (default_value)
-            return default_value
\ No newline at end of file
+            return default_value
+
+def input_process(stdin_fd, sq, prompt):
+    sys.stdin = os.fdopen(stdin_fd)
+    try:
+        inp = input (prompt)
+        sq.put (True)
+    except:
+        sq.put (False)
+
+def input_in_time (prompt, max_time_sec):
+    sq = multiprocessing.Queue()
+    p = multiprocessing.Process(target=input_process, args=( sys.stdin.fileno(), sq, prompt))
+    p.start()
+    t = time.time()
+    inp = False
+    while True:
+        if not sq.empty():
+            inp = sq.get()
+            break
+        if time.time() - t > max_time_sec:
+            break
+    p.terminate()
+    sys.stdin = os.fdopen( sys.stdin.fileno() )
+    return inp
\ No newline at end of file
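The new input_in_time helper exists because a blocking input() cannot be given a timeout portably: the prompt runs in a child process that reports through a queue, the parent polls with a deadline, and sys.stdin is reopened afterwards since terminate() can leave the inherited handle unusable. One caveat: the polling loop never sleeps, so it pegs a core for up to max_time_sec; a short time.sleep(0.05) inside the loop would avoid that. A usage sketch mirroring the ModelBase call site; the __main__ guard matters because multiprocessing re-imports the main module when it spawns on Windows:

    from utils.console_utils import input_in_time

    if __name__ == '__main__':
        # Returns True if the user pressed enter before the deadline.
        if input_in_time("Press enter within 2 seconds to override model settings.", 2):
            print("overriding model settings...")
        else:
            print("keeping saved settings.")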