diff --git a/src/weathergen/run_train.py b/src/weathergen/run_train.py index 5c1ded4f1..e2d789767 100644 --- a/src/weathergen/run_train.py +++ b/src/weathergen/run_train.py @@ -12,6 +12,7 @@ """ import logging +import os import pdb import sys import time @@ -190,15 +191,20 @@ def train_with_args(argl: list[str], stream_dir: str | None): if __name__ == "__main__": + try: + stage = os.environ.get("WEATHERGEN_STAGE") + except KeyError as e: + msg = f"missing environment variable 'WEATHERGEN_STAGE'" + raise ValueError(msg) from e - if any("train" in arg for arg in sys.argv): + if stage == "train": # Entry point for slurm script. # Check whether --from_run_id passed as argument. if any("--from_run_id" in arg for arg in sys.argv): train_continue() else: train() - elif any("inference" in arg for arg in sys.argv): + elif stage == "inference": inference() else: logger.error("No stage was found.")