mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
sort data for CAInitializerMP
This commit is contained in:
parent
00dce38187
commit
d129b5dd7f
1 changed files with 4 additions and 3 deletions
|
@ -681,8 +681,10 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||||
nnlib.Adam = Adam
|
nnlib.Adam = Adam
|
||||||
|
|
||||||
def CAInitializerMP( conv_weights_list ):
|
def CAInitializerMP( conv_weights_list ):
|
||||||
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
|
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
|
||||||
result = CAInitializerMPSubprocessor ( [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ], K.floatx(), K.image_data_format() ).run()
|
data = [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ]
|
||||||
|
data = sorted(data, key=lambda data: sum(data[1]) )
|
||||||
|
result = CAInitializerMPSubprocessor (data, K.floatx(), K.image_data_format() ).run()
|
||||||
for idx, weights in result:
|
for idx, weights in result:
|
||||||
K.set_value ( conv_weights_list[idx], weights )
|
K.set_value ( conv_weights_list[idx], weights )
|
||||||
nnlib.CAInitializerMP = CAInitializerMP
|
nnlib.CAInitializerMP = CAInitializerMP
|
||||||
|
@ -1090,7 +1092,6 @@ class CAInitializerMPSubprocessor(Subprocessor):
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def __init__(self, idx_shapes_list, floatx, data_format ):
|
def __init__(self, idx_shapes_list, floatx, data_format ):
|
||||||
|
|
||||||
self.idx_shapes_list = idx_shapes_list
|
self.idx_shapes_list = idx_shapes_list
|
||||||
self.floatx = floatx
|
self.floatx = floatx
|
||||||
self.data_format = data_format
|
self.data_format = data_format
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue