mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
Merger: fix load time of xseg if it has no model files
This commit is contained in:
parent
6134e57762
commit
3e7ee22ae3
1 changed files with 11 additions and 4 deletions
|
@ -41,7 +41,6 @@ class XSegNet(object):
|
||||||
self.model_weights = self.model.get_weights()
|
self.model_weights = self.model.get_weights()
|
||||||
|
|
||||||
model_name = f'{name}_{resolution}'
|
model_name = f'{name}_{resolution}'
|
||||||
|
|
||||||
self.model_filename_list = [ [self.model, f'{model_name}.npy'] ]
|
self.model_filename_list = [ [self.model, f'{model_name}.npy'] ]
|
||||||
|
|
||||||
if training:
|
if training:
|
||||||
|
@ -59,6 +58,7 @@ class XSegNet(object):
|
||||||
return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0]
|
return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0]
|
||||||
self.net_run = net_run
|
self.net_run = net_run
|
||||||
|
|
||||||
|
self.initialized = True
|
||||||
# Loading/initializing all models/optimizers weights
|
# Loading/initializing all models/optimizers weights
|
||||||
for model, filename in self.model_filename_list:
|
for model, filename in self.model_filename_list:
|
||||||
do_init = not load_weights
|
do_init = not load_weights
|
||||||
|
@ -66,12 +66,16 @@ class XSegNet(object):
|
||||||
if not do_init:
|
if not do_init:
|
||||||
model_file_path = self.weights_file_root / filename
|
model_file_path = self.weights_file_root / filename
|
||||||
do_init = not model.load_weights( model_file_path )
|
do_init = not model.load_weights( model_file_path )
|
||||||
if do_init and raise_on_no_model_files:
|
if do_init:
|
||||||
raise Exception(f'{model_file_path} does not exists.')
|
if raise_on_no_model_files:
|
||||||
|
raise Exception(f'{model_file_path} does not exists.')
|
||||||
|
if not training:
|
||||||
|
self.initialized = False
|
||||||
|
break
|
||||||
|
|
||||||
if do_init:
|
if do_init:
|
||||||
model.init_weights()
|
model.init_weights()
|
||||||
|
|
||||||
def get_resolution(self):
|
def get_resolution(self):
|
||||||
return self.resolution
|
return self.resolution
|
||||||
|
|
||||||
|
@ -86,6 +90,9 @@ class XSegNet(object):
|
||||||
model.save_weights( self.weights_file_root / filename )
|
model.save_weights( self.weights_file_root / filename )
|
||||||
|
|
||||||
def extract (self, input_image):
|
def extract (self, input_image):
|
||||||
|
if not self.initialized:
|
||||||
|
return 0.5*np.ones ( (self.resolution, self.resolution, 1), nn.floatx.as_numpy_dtype )
|
||||||
|
|
||||||
input_shape_len = len(input_image.shape)
|
input_shape_len = len(input_image.shape)
|
||||||
if input_shape_len == 3:
|
if input_shape_len == 3:
|
||||||
input_image = input_image[None,...]
|
input_image = input_image[None,...]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue