diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index 66afd71..780cc48 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -43,7 +43,10 @@ def trainerThread (s2c, c2s, e, if not saved_models_path.exists(): saved_models_path.mkdir(exist_ok=True, parents=True) - + + if dump_ckpt: + cpu_only=True + model = models.import_model(model_class_name)( is_training=not dump_ckpt, saved_models_path=saved_models_path,