diff --git a/main.py b/main.py index 9122cc5..7ca8fd7 100644 --- a/main.py +++ b/main.py @@ -68,7 +68,7 @@ if __name__ == "__main__": arguments.target_epoch = int ( os.environ['DFL_TARGET_EPOCH'] ) if 'DFL_BATCH_SIZE' in os.environ.keys(): - arguments.batch_size = int ( os.environ['DFL_TARGET_EPOCH'] ) + arguments.batch_size = int ( os.environ['DFL_BATCH_SIZE'] ) from mainscripts import Trainer Trainer.main (