diff --git a/facelib/XSegNet.py b/facelib/XSegNet.py index 88acab2..f260e21 100644 --- a/facelib/XSegNet.py +++ b/facelib/XSegNet.py @@ -41,7 +41,6 @@ class XSegNet(object): self.model_weights = self.model.get_weights() model_name = f'{name}_{resolution}' - self.model_filename_list = [ [self.model, f'{model_name}.npy'] ] if training: @@ -59,6 +58,7 @@ class XSegNet(object): return nn.tf_sess.run ( [pred], feed_dict={self.input_t :input_np})[0] self.net_run = net_run + self.initialized = True # Loading/initializing all models/optimizers weights for model, filename in self.model_filename_list: do_init = not load_weights @@ -66,12 +66,16 @@ class XSegNet(object): if not do_init: model_file_path = self.weights_file_root / filename do_init = not model.load_weights( model_file_path ) - if do_init and raise_on_no_model_files: - raise Exception(f'{model_file_path} does not exists.') + if do_init: + if raise_on_no_model_files: + raise Exception(f'{model_file_path} does not exists.') + if not training: + self.initialized = False + break if do_init: model.init_weights() - + def get_resolution(self): return self.resolution @@ -86,6 +90,9 @@ class XSegNet(object): model.save_weights( self.weights_file_root / filename ) def extract (self, input_image): + if not self.initialized: + return 0.5*np.ones ( (self.resolution, self.resolution, 1), nn.floatx.as_numpy_dtype ) + input_shape_len = len(input_image.shape) if input_shape_len == 3: input_image = input_image[None,...]