diff --git a/models/ModelBase.py b/models/ModelBase.py index 69b73a4..7fc166f 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -277,6 +277,11 @@ class ModelBase(object): } self.model_data_path.write_bytes( pickle.dumps(model_data) ) + def load_weights_safe(self, model_filename_list): + for model, filename in model_filename_list: + if Path(filename).exists(): + model.load_weights(filename) + def save_weights_safe(self, model_filename_list): for model, filename in model_filename_list: model.save_weights( filename + '.tmp' )