diff --git a/core/leras/nn.py b/core/leras/nn.py index 5d023af..c1dfc93 100644 --- a/core/leras/nn.py +++ b/core/leras/nn.py @@ -182,23 +182,17 @@ class nn(): nn.conv2d_spatial_axes = [2,3] @staticmethod - def get4Dshape ( w, h, c, data_format=None ): + def get4Dshape ( w, h, c ): """ returns 4D shape based on current data_format """ - if data_format is None: - data_format = nn.data_format - - if data_format == "NHWC": + if nn.data_format == "NHWC": return (None,h,w,c) else: return (None,c,h,w) @staticmethod - def to_data_format( x, to_data_format, from_data_format=None): - if from_data_format is None: - from_data_format = nn.data_format - + def to_data_format( x, to_data_format, from_data_format): if to_data_format == from_data_format: return x diff --git a/core/leras/tensor_ops.py b/core/leras/tensor_ops.py index be0bb12..895071d 100644 --- a/core/leras/tensor_ops.py +++ b/core/leras/tensor_ops.py @@ -35,7 +35,7 @@ def initialize_tensor_ops(nn): gv = [*zip(grads,vars)] for g,v in gv: if g is None: - raise Exception("No gradient for variable {v.name}") + raise Exception(f"No gradient for variable {v.name}") return gv nn.tf_gradients = tf_gradients diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py index e6938ec..a373f79 100644 --- a/models/Model_Quick96/Model.py +++ b/models/Model_Quick96/Model.py @@ -413,18 +413,15 @@ class QModel(ModelBase): #override def onGetPreview(self, samples): - n_samples = min(4, self.get_batch_size() ) - ( (warped_src, target_src, target_srcm), - (warped_dst, target_dst, target_dstm) ) = \ - [ [sample[0:n_samples] for sample in sample_list ] - for sample_list in samples ] + (warped_dst, target_dst, target_dstm) ) = samples S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + n_samples = min(4, self.get_batch_size() ) result = [] st = [] for i in range(n_samples): diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 8ae5d91..aa01172 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -450,9 +450,10 @@ class SAEHDModel(ModelBase): for gpu_id in range(gpu_count): with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ): - batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) + with tf.device(f'/CPU:0'): # slice on CPU, otherwise all batch data will be transfered to GPU first + batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu ) gpu_warped_src = self.warped_src [batch_slice,:,:,:] gpu_warped_dst = self.warped_dst [batch_slice,:,:,:] gpu_target_src = self.target_src [batch_slice,:,:,:] @@ -646,7 +647,6 @@ class SAEHDModel(ModelBase): model.init_weights() # initializing sample generators - if self.is_training: t = SampleProcessor.Types if self.options['face_type'] == 'h': @@ -710,12 +710,8 @@ class SAEHDModel(ModelBase): #override def onGetPreview(self, samples): - n_samples = min(4, self.get_batch_size() ) - ( (warped_src, target_src, target_srcm), - (warped_dst, target_dst, target_dstm) ) = \ - [ [sample[0:n_samples] for sample in sample_list ] - for sample_list in samples ] + (warped_dst, target_dst, target_dstm) ) = samples if self.options['learn_mask']: S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] @@ -725,6 +721,7 @@ class SAEHDModel(ModelBase): target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] + n_samples = min(4, self.get_batch_size() ) result = [] st = [] for i in range(n_samples):