From d129b5dd7f1703c76ff940541dd9d12d58226f3b Mon Sep 17 00:00:00 2001 From: iperov Date: Sun, 25 Aug 2019 07:50:27 +0400 Subject: [PATCH] sort data for CAInitializerMP --- nnlib/nnlib.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index 2d3b5e4..b04ff16 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -681,8 +681,10 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator nnlib.Adam = Adam def CAInitializerMP( conv_weights_list ): - #Convolution Aware Initialization https://arxiv.org/abs/1702.06295 - result = CAInitializerMPSubprocessor ( [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ], K.floatx(), K.image_data_format() ).run() + #Convolution Aware Initialization https://arxiv.org/abs/1702.06295 + data = [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ] + data = sorted(data, key=lambda data: sum(data[1]) ) + result = CAInitializerMPSubprocessor (data, K.floatx(), K.image_data_format() ).run() for idx, weights in result: K.set_value ( conv_weights_list[idx], weights ) nnlib.CAInitializerMP = CAInitializerMP @@ -1090,7 +1092,6 @@ class CAInitializerMPSubprocessor(Subprocessor): #override def __init__(self, idx_shapes_list, floatx, data_format ): - self.idx_shapes_list = idx_shapes_list self.floatx = floatx self.data_format = data_format