mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
fix for plaidML
This commit is contained in:
parent
3114ae9d7b
commit
c9da4cd89b
2 changed files with 6 additions and 4 deletions
|
@ -161,7 +161,6 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
return nnlib.code_import_keras
|
||||
|
||||
nnlib.backend = device_config.backend
|
||||
|
||||
if "tensorflow" in nnlib.backend:
|
||||
nnlib._import_tf(device_config)
|
||||
elif nnlib.backend == "plaidML":
|
||||
|
@ -174,6 +173,9 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
import keras as keras_
|
||||
nnlib.keras = keras_
|
||||
|
||||
if 'KERAS_BACKEND' in os.environ:
|
||||
os.environ.pop('KERAS_BACKEND')
|
||||
|
||||
if nnlib.backend == "plaidML":
|
||||
import plaidml
|
||||
import plaidml.tile
|
||||
|
@ -591,7 +593,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
nnlib.Adam = Adam
|
||||
|
||||
def CAInitializerMP( conv_weights_list ):
|
||||
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
|
||||
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
|
||||
result = CAInitializerMPSubprocessor ( [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ], K.floatx(), K.image_data_format() ).run()
|
||||
for idx, weights in result:
|
||||
K.set_value ( conv_weights_list[idx], weights )
|
||||
|
@ -696,7 +698,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
|||
x = ReflectionPadding2D( self.pad ) (x)
|
||||
return self.func(x)
|
||||
nnlib.Conv2D = Conv2D
|
||||
|
||||
|
||||
class Conv2DTranspose():
|
||||
def __init__ (self, *args, **kwargs):
|
||||
self.reflect_pad = False
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue