diff --git a/core/joblib/SubprocessorBase.py b/core/joblib/SubprocessorBase.py
index 993d5cb..e37e5a6 100644
--- a/core/joblib/SubprocessorBase.py
+++ b/core/joblib/SubprocessorBase.py
@@ -12,9 +12,11 @@ class Subprocessor(object):
 
     class Cli(object):
         def __init__ ( self, client_dict ):
-            self.s2c = multiprocessing.Queue()
-            self.c2s = multiprocessing.Queue()
-            self.p = multiprocessing.Process(target=self._subprocess_run, args=(client_dict,) )
+            s2c = multiprocessing.Queue()
+            c2s = multiprocessing.Queue()
+            self.p = multiprocessing.Process(target=self._subprocess_run, args=(client_dict,s2c,c2s) )
+            self.s2c = s2c
+            self.c2s = c2s
             self.p.daemon = True
             self.p.start()
 
@@ -52,9 +54,8 @@ class Subprocessor(object):
         def log_err(self, msg): self.c2s.put ( {'op': 'log_err' , 'msg':msg } )
         def progress_bar_inc(self, c): self.c2s.put ( {'op': 'progress_bar_inc' , 'c':c } )
 
-        def _subprocess_run(self, client_dict):
+        def _subprocess_run(self, client_dict, s2c, c2s):
             data = None
-            s2c, c2s = self.s2c, self.c2s
             try:
                 self.on_initialize(client_dict)
 
@@ -85,7 +86,13 @@ class Subprocessor(object):
                 print ('Exception: %s' % (traceback.format_exc()) )
                 c2s.put ( {'op': 'error', 'data' : data} )
 
-
+
+        # disable pickling
+        def __getstate__(self):
+            return dict()
+        def __setstate__(self, d):
+            self.__dict__.update(d)
+
     #overridable
     def __init__(self, name, SubprocessorCli_class, no_response_time_sec = 0, io_loop_sleep_time=0.005, initialize_subprocesses_in_serial=False):
         if not issubclass(SubprocessorCli_class, Subprocessor.Cli):
@@ -179,7 +186,7 @@ class Subprocessor(object):
                         break
                     io.process_messages(0.005)
             except:
-                raise Exception ("Unable to start subprocess %s" % (name))
+                raise Exception (f"Unable to start subprocess {name}. Error: {traceback.format_exc()}")
 
         if len(self.clis) == 0:
             raise Exception ("Unable to start Subprocessor '%s' " % (self.name))
diff --git a/core/leras/initializers.py b/core/leras/initializers.py
index a3294cb..d935454 100644
--- a/core/leras/initializers.py
+++ b/core/leras/initializers.py
@@ -1,52 +1,104 @@
+import multiprocessing
+
 import numpy as np
+from core.joblib import Subprocessor
+
+
 def initialize_initializers(nn):
     tf = nn.tf
     from tensorflow.python.ops import init_ops
-    
+
     class initializers():
         class ca (init_ops.Initializer):
             def __init__(self, dtype=None):
                 pass
-            
+
             def __call__(self, shape, dtype=None, partition_info=None):
                 return tf.zeros( shape, name="_cai_")
 
             @staticmethod
-            def generate(shape, eps_std=0.05, dtype=np.float32):
-                """
-                Super fast implementation of Convolution Aware Initialization for 4D shapes
-                Convolution Aware Initialization https://arxiv.org/abs/1702.06295
-                """
-                if len(shape) != 4:
-                    raise ValueError("only shape with rank 4 supported.")
+            def generate_batch( data_list, eps_std=0.05 ):
+                # list of (shape, np.dtype)
+                return CAInitializerSubprocessor (data_list).run()
+
+    nn.initializers = initializers
 
-                row, column, stack_size, filters_size = shape
+class CAInitializerSubprocessor(Subprocessor):
+    @staticmethod
+    def generate(shape, dtype=np.float32, eps_std=0.05):
+        """
+        Super fast implementation of Convolution Aware Initialization for 4D shapes
+        Convolution Aware Initialization https://arxiv.org/abs/1702.06295
+        """
+        if len(shape) != 4:
+            raise ValueError("only shape with rank 4 supported.")
 
-                fan_in = stack_size * (row * column)
+        row, column, stack_size, filters_size = shape
 
-                kernel_shape = (row, column)
+        fan_in = stack_size * (row * column)
 
-                kernel_fft_shape = np.fft.rfft2(np.zeros(kernel_shape)).shape
+        kernel_shape = (row, column)
 
-                basis_size = np.prod(kernel_fft_shape)
-                if basis_size == 1:
-                    x = np.random.normal( 0.0, eps_std, (filters_size, stack_size, basis_size) )
-                else:
-                    nbb = stack_size // basis_size + 1
-                    x = np.random.normal(0.0, 1.0, (filters_size, nbb, basis_size, basis_size))
-                    x = x + np.transpose(x, (0,1,3,2) ) * (1-np.eye(basis_size))
-                    u, _, v = np.linalg.svd(x)
-                    x = np.transpose(u, (0,1,3,2) )
-                    x = np.reshape(x, (filters_size, -1, basis_size) )
-                    x = x[:,:stack_size,:]
+        kernel_fft_shape = np.fft.rfft2(np.zeros(kernel_shape)).shape
 
-                x = np.reshape(x, ( (filters_size,stack_size,) + kernel_fft_shape ) )
+        basis_size = np.prod(kernel_fft_shape)
+        if basis_size == 1:
+            x = np.random.normal( 0.0, eps_std, (filters_size, stack_size, basis_size) )
+        else:
+            nbb = stack_size // basis_size + 1
+            x = np.random.normal(0.0, 1.0, (filters_size, nbb, basis_size, basis_size))
+            x = x + np.transpose(x, (0,1,3,2) ) * (1-np.eye(basis_size))
+            u, _, v = np.linalg.svd(x)
+            x = np.transpose(u, (0,1,3,2) )
+            x = np.reshape(x, (filters_size, -1, basis_size) )
+            x = x[:,:stack_size,:]
 
-                x = np.fft.irfft2( x, kernel_shape ) \
-                    + np.random.normal(0, eps_std, (filters_size,stack_size,)+kernel_shape)
+        x = np.reshape(x, ( (filters_size,stack_size,) + kernel_fft_shape ) )
 
-                x = x * np.sqrt( (2/fan_in) / np.var(x) )
-                x = np.transpose( x, (2, 3, 1, 0) )
-                return x.astype(dtype)
-    nn.initializers = initializers
\ No newline at end of file
+        x = np.fft.irfft2( x, kernel_shape ) \
+            + np.random.normal(0, eps_std, (filters_size,stack_size,)+kernel_shape)
+
+        x = x * np.sqrt( (2/fan_in) / np.var(x) )
+        x = np.transpose( x, (2, 3, 1, 0) )
+        return x.astype(dtype)
+
+    class Cli(Subprocessor.Cli):
+        #override
+        def process_data(self, data):
+            idx, shape, dtype = data
+            weights = CAInitializerSubprocessor.generate (shape, dtype)
+            return idx, weights
+
+    #override
+    def __init__(self, data_list):
+        self.data_list = data_list
+        self.data_list_idxs = [*range(len(data_list))]
+        self.result = [None]*len(data_list)
+        super().__init__('CAInitializerSubprocessor', CAInitializerSubprocessor.Cli)
+
+    #override
+    def process_info_generator(self):
+        for i in range( min(multiprocessing.cpu_count(), len(self.data_list)) ):
+            yield 'CPU%d' % (i), {}, {}
+
+    #override
+    def get_data(self, host_dict):
+        if len (self.data_list_idxs) > 0:
+            idx = self.data_list_idxs.pop(0)
+            shape, dtype = self.data_list[idx]
+            return idx, shape, dtype
+        return None
+
+    #override
+    def on_data_return (self, host_dict, data):
+        self.data_list_idxs.insert(0, data)
+
+    #override
+    def on_result (self, host_dict, data, result):
+        idx, weights = result
+        self.result[idx] = weights
+
+    #override
+    def get_result(self):
+        return self.result
\ No newline at end of file
diff --git a/core/leras/layers.py b/core/leras/layers.py
index 7597ccf..81654a5 100644
--- a/core/leras/layers.py
+++ b/core/leras/layers.py
@@ -77,18 +77,25 @@ def initialize_layers(nn):
 
         def init_weights(self):
            ops = []
-            tuples = []
+
+            ca_tuples_w = []
+            ca_tuples = []
             for w in self.get_weights():
                 initializer = w.initializer
                 for input in initializer.inputs:
                     if "_cai_" in input.name:
-                        tuples.append ( (w, nn.initializers.ca.generate(w.shape.as_list(), dtype= w.dtype.as_numpy_dtype) ) )
+                        ca_tuples_w.append (w)
+                        ca_tuples.append ( (w.shape.as_list(), w.dtype.as_numpy_dtype) )
                         break
                 else:
                     ops.append (initializer)
-            nn.tf_sess.run (ops)
-            nn.tf_batch_set_value(tuples)
+            if len(ops) != 0:
+                nn.tf_sess.run (ops)
+
+            if len(ca_tuples) != 0:
+                nn.tf_batch_set_value( [*zip(ca_tuples_w, nn.initializers.ca.generate_batch (ca_tuples))] )
+
     nn.Saveable = Saveable
 
     class LayerBase():
diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py
index 4d18710..a071f1b 100644
--- a/models/Model_Quick96/Model.py
+++ b/models/Model_Quick96/Model.py
@@ -231,7 +231,7 @@ class QModel(ModelBase):
                                           [self.decoder_dst, 'decoder_dst.npy'] ]
 
             if self.is_training:
-                self.src_dst_trainable_weights = self.encoder.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
+                self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
 
                 # Initialize optimizers
                 self.src_dst_opt = nn.TFRMSpropOptimizer(lr=2e-4, lr_dropout=0.3, name='src_dst_opt')
diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py
index 67a19f3..96e553d 100644
--- a/models/Model_SAEHD/Model.py
+++ b/models/Model_SAEHD/Model.py
@@ -278,8 +278,7 @@ class SAEHDModel(ModelBase):
                 z = inp
 
                 if self.is_hd:
-                    x, upx = self.res0(z)
-
+                    x, upx = self.res0(z)
                     x = self.upscale0(x)
                     x = tf.nn.leaky_relu(x + upx, 0.2)
                     x, upx = self.res1(x)
@@ -410,8 +409,8 @@ class SAEHDModel(ModelBase):
             self.src_dst_opt = nn.TFRMSpropOptimizer(lr=lr, lr_dropout=lr_dropout, clipnorm=clipnorm, name='src_dst_opt')
             self.model_filename_list += [ (self.src_dst_opt, 'src_dst_opt.npy') ]
             if 'df' in archi:
-                self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
-                self.src_dst_trainable_weights = self.encoder.get_weights() + self.decoder_src.get_weights_ex(learn_mask) + self.decoder_dst.get_weights_ex(learn_mask)
+                self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights() + self.decoder_dst.get_weights()
+                self.src_dst_trainable_weights = self.encoder.get_weights() + self.inter.get_weights() + self.decoder_src.get_weights_ex(learn_mask) + self.decoder_dst.get_weights_ex(learn_mask)
             elif 'liae' in archi:
                 self.src_dst_all_trainable_weights = self.encoder.get_weights() + self.inter_AB.get_weights() + self.inter_B.get_weights() + self.decoder.get_weights()