sort data for CAInitializerMP

This commit is contained in:
iperov 2019-08-25 07:50:27 +04:00
parent 00dce38187
commit d129b5dd7f

View file

@ -682,7 +682,9 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
def CAInitializerMP( conv_weights_list ): def CAInitializerMP( conv_weights_list ):
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295 #Convolution Aware Initialization https://arxiv.org/abs/1702.06295
result = CAInitializerMPSubprocessor ( [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ], K.floatx(), K.image_data_format() ).run() data = [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ]
data = sorted(data, key=lambda data: sum(data[1]) )
result = CAInitializerMPSubprocessor (data, K.floatx(), K.image_data_format() ).run()
for idx, weights in result: for idx, weights in result:
K.set_value ( conv_weights_list[idx], weights ) K.set_value ( conv_weights_list[idx], weights )
nnlib.CAInitializerMP = CAInitializerMP nnlib.CAInitializerMP = CAInitializerMP
@ -1090,7 +1092,6 @@ class CAInitializerMPSubprocessor(Subprocessor):
#override #override
def __init__(self, idx_shapes_list, floatx, data_format ): def __init__(self, idx_shapes_list, floatx, data_format ):
self.idx_shapes_list = idx_shapes_list self.idx_shapes_list = idx_shapes_list
self.floatx = floatx self.floatx = floatx
self.data_format = data_format self.data_format = data_format