Introduction
PyTorch Lightning is a framework for doing deep learning research with PyTorch. It makes things like checkpointing, logging, and distributed training much easier.
In this blog post, I will focus primarily on how to set up PyTorch Lightning to work on a SLURM cluster. In the typical SLURM workflow, you submit a job to the cluster. Oftentimes, if you request a very long time limit, your job will sit in the queue for a long time. The workaround is to submit a short job (i.e., 1-6 hours, or enough time for one training and one validation loop) and checkpoint the training state for the next job. That is the emphasis of this post: how to set up PyTorch Lightning to work with a SLURM cluster and checkpoint the training.
Dependencies:
- torch ≥ 1.8
- pytorch-lightning ≥ 1.5.9
- hydra ≥ 1.1
- hydra-submitit-launcher ≥ 1.1.6
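If you need to install these, something like the following should work (assuming the standard PyPI package names, where Hydra is distributed as hydra-core):

pip install "torch>=1.8" "pytorch-lightning>=1.5.9" "hydra-core>=1.1" "hydra-submitit-launcher>=1.1.6"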
In this post, I will show you two ways of accomplishing this:
- Setting a model checkpoint to store HPC weights,
- Using fault-tolerant training.
SLURM Introduction
If you are not very familiar with SLURM: the SLURM workload manager is a resource management tool used in many supercomputers and computer clusters. If you are already familiar with SLURM, or you are not working in a SLURM environment at all, you may skip this section.
First, let’s take a look at a sample script for submitting a job to SLURM.
#!/bin/bash
# Parameters
#SBATCH --account=realitylab
#SBATCH --constraint=[a40|rtx6k|titan]
#SBATCH --cpus-per-task=5
#SBATCH --error=/mmfs1/gscratch/realitylab/tjenrung/AV-codecs/logs/experiment=audiovideostream_debug/multiruns/2022-01-27_09-26-02/.submitit/%j/%j_0_log.err
#SBATCH --gpus-per-node=1
#SBATCH --job-name=run
#SBATCH --mem=256GB
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --open-mode=append
#SBATCH --output=/mmfs1/gscratch/realitylab/tjenrung/AV-codecs/logs/experiment=audiovideostream_debug/multiruns/2022-01-27_09-26-02/.submitit/%j/%j_0_log.out
#SBATCH --partition=ckpt
#SBATCH --signal=USR1@180
#SBATCH --time=120
#SBATCH --wckey=submitit
# command
export SUBMITIT_EXECUTOR=slurm
srun --output /mmfs1/gscratch/realitylab/tjenrung/AV-codecs/logs/experiment=audiovideostream_debug/multiruns/2022-01-27_09-26-02/.submitit/%j/%j_%t_log.out --error /mmfs1/gscratch/realitylab/tjenrung/AV-codecs/logs/experiment=audiovideostream_debug/multiruns/2022-01-27_09-26-02/.submitit/%j/%j_%t_log.err --unbuffered /mmfs1/gscratch/realitylab/tjenrung/AV-codecs/.env/bin/python -u -m submitit.core._submit /mmfs1/gscratch/realitylab/tjenrung/AV-codecs/logs/experiment=audiovideostream_debug/multiruns/2022-01-27_09-26-02/.submitit/%j
Let’s explain a few details based on the above script. Here, we are submitting the training script to the ckpt partition under the realitylab account, requesting a single node with one GPU:
- Account (--account=realitylab): the account under which you have resources for submitting a job. (Note: this is relevant only if you have multiple accounts under different affiliations in the SLURM cluster.)
- Partition (--partition=ckpt): the partition to which you want to submit the job. Usually, a computing cluster is separated into a CPU partition, a GPU partition, and a checkpoint partition. The checkpoint partition is mostly used for small, short tasks and is shared across the entire SLURM cluster, whereas a normal CPU/GPU partition is specific to a particular account.
- Requested resources (cpus-per-task, gpus-per-node, mem, nodes, ntasks-per-node): the CPU/GPU/memory specifications that you request.
- Time (--time=120): the time in minutes that you want your job to run. In our context, you don’t need a very long time limit, because you want to minimize the wait time between checkpointed jobs.
- End-time signal (--signal=USR1@180): this tells SLURM to send a SIGUSR1 signal around 180 seconds before the end of your allocation. This is very useful for our application, since we want to wrap things up (e.g., save the model’s weights and other important state) before the next job starts. In a SLURM environment, this signal is what gives tools like Submitit and PyTorch Lightning a chance to checkpoint and requeue the job before it is killed, as sketched below.
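To make the signal mechanism concrete, here is a minimal sketch (not PyTorch Lightning's actual handler) of how a training script could react to the end-of-allocation signal; save_checkpoint is a hypothetical placeholder for whatever state-saving you need:

import signal

def save_checkpoint():
    # Hypothetical placeholder: persist model weights, optimizer state, epoch counter, etc.
    print("Allocation ending soon, saving training state...")

def handle_sigusr1(signum, frame):
    # SLURM delivers SIGUSR1 about 180 seconds before the end of the allocation
    # because of the #SBATCH --signal=USR1@180 directive above.
    save_checkpoint()
    # After saving, the job can be requeued so the next allocation resumes from the checkpoint.

signal.signal(signal.SIGUSR1, handle_sigusr1)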
Hydra and Submitit
Hydra is an open-source Python framework that simplifies complex applications with many configuration options. In the context of deep learning experiments, you often encounter situations where you want to change one of the variables in a data module or model. This becomes a problem once you make too many changes and can no longer keep track of what’s in each experiment. Hydra is designed to solve this problem.
For example, this is a sample configuration file that I use for my project.
# @package _global_

# to execute this experiment run:
# python run.py experiment=audiovideostream.yaml

defaults:
  - override /mode: exp.yaml
  - override /trainer: null
  - override /model: null
  - override /datamodule: null
  - override /callbacks: null
  - override /logger: null
  - override /hydra/launcher@_here_: submitit_slurm

# we override default configurations with nulls to prevent them from loading at all
# instead we define all modules and their paths directly in this config,
# so everything is stored in one place

name: "audiovideostream"

seed: 12345

trainer:
  _target_: pytorch_lightning.Trainer
  gpus: 1
  min_epochs: 1
  max_epochs: 200
  gradient_clip_val: 2.0
  accelerator: gpu
  strategy: ddp
  weights_summary: top
  resume_from_checkpoint: null

model:
  _target_: src.models.audiovideostream.AudioVideoStreamLitModel
  lr: 0.0001
  b1: 0.5
  b2: 0.9
  audio_enc_channels: 32
  audio_emb_channels: 80
  video_enc_channels: 5
  video_emb_channels: 10
  loss_type: all
  video_only: True
  audio_only: False

datamodule:
  _target_: src.datamodules.audiovideostream_datamodule.AudioVideoStreamDataModule
  hdf5_path: ${data_dir}/VoxCeleb2/dataset.hdf5
  video_fps: 25
  audio_fps: 16000
  length: 3.0
  preload: True
  train_val_test_split: [156, 4, 4]
  batch_size: 2
  num_workers: 0
  pin_memory: True

callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "val/g_loss"
    save_top_k: -1
    save_last: True
    mode: "min"
    dirpath: "."
    filename: "hpc_ckpt_{epoch:d}"
    auto_insert_metric_name: False

logger:
  csv:
    _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
    save_dir: "."
    name: "csv/${name}"
    version: ${name}
    prefix: ""
  tensorboard:
    _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
    save_dir: "tensorboard/"
    name: null
    version: ${name}
    log_graph: False
    default_hp_metric: True
    prefix: ""

hydra:
  launcher:
    _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher
    submitit_folder: ${hydra.sweep.dir}/.submitit/%j
    account: realitylab
    partition: ckpt
    timeout_min: 120
    cpus_per_task: 5
    gpus_per_node: 1
    tasks_per_node: 1
    mem_gb: 256
    nodes: 1
    name: ${hydra.job.name}
    comment: null
    constraint: "[a40|rtx6k|titan]"
    exclude: null
    cpus_per_gpu: null
    gpus_per_task: null
    mem_per_gpu: null
    mem_per_cpu: null
    signal_delay_s: 600
    max_num_timeout: 20
    additional_parameters: {}
    array_parallelism: 256
    setup: []
Hydra’s key features, as listed on its website, are:
- Hierarchical configuration composable from multiple sources,
- Configuration can be specified or overridden from the command line,
- Dynamic command line tab completion,
- Run your application locally or launch it to run remotely,
- Run multiple jobs with different arguments with a single command.
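To make this concrete, here is a minimal sketch of a Hydra application, independent of my project; the configs directory and config.yaml name are placeholder assumptions:

# run.py -- a minimal Hydra application sketch
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="configs", config_name="config")
def main(cfg: DictConfig) -> None:
    # Hydra composes the final config from the defaults list plus any
    # command-line overrides (e.g., python run.py seed=123) and passes it here.
    print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
    main()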
The Submitit launcher is a plugin that provides a SLURM launcher based on Submitit. For instance, you can embed the following configuration into a Hydra config file (this particular example uses the plugin's local launcher; the SLURM variant appears in the experiment config above).
_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.LocalLauncher
submitit_folder: ${hydra.sweep.dir}/.submitit/%j
timeout_min: 60
cpus_per_task: 1
gpus_per_node: 0
tasks_per_node: 1
mem_gb: 4
nodes: 1
name: ${hydra.job.name}
By combining Hydra and the Submitit plugin, we get a single configuration file that describes the entire experiment, which is very useful for reproducing or understanding it.
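For instance, assuming the experiment config above, submitting through the SLURM launcher is just a matter of running the script in multirun mode (Hydra launchers are only used with --multirun):

python run.py experiment=audiovideostream --multirun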
More details can be found here:
- Hydra: https://hydra.cc/
- Submitit plugin: https://hydra.cc/docs/plugins/submitit_launcher/
PyTorch Lightning and Checkpointing
In this section, I will first describe how to manually restore the training state in PyTorch Lightning. Then, I will describe how PyTorch Lightning is designed to work in high-performance computing scenarios that require fault tolerance (e.g., server shutdowns or job preemption).
Manually Restoring Training State
To load a model along with its weights, biases, and hyperparameters, you can use one of the following methods.
# Load a model together with its weights, biases, and hyperparameters
model = MyLightningModule.load_from_checkpoint(PATH)

# Restore the full training state through a Trainer
model = LitModel()
trainer = Trainer()
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
Note that load_from_checkpoint is a classmethod, so if your code is written the wrong way, the checkpoint will not actually be loaded into your model.

model = LitModel()
model.load_from_checkpoint(PATH)          # incorrect way of loading state: the returned model is discarded
model = model.load_from_checkpoint(PATH)  # correct way of loading state: keep the returned model
Automatically Restoring Training State
If you want PyTorch Lightning to automatically resume training from the previous training state, it helps to understand how this is done under the hood. Note that this post is based on PyTorch Lightning 1.5.9; the details may change in future versions.
From [Link], the logic for automatically restoring the training state is implemented in the following function.
def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
    """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:

    1. from HPC weights if found
    2. from fault-tolerant auto-saved checkpoint if found
    3. from `checkpoint_path` file if provided
    4. don't restore
    """
    self.resume_checkpoint_path = self._hpc_resume_path or self._fault_tolerant_auto_resume_path or checkpoint_path
    checkpoint_path = self.resume_checkpoint_path
    if not checkpoint_path:
        log.detail("`checkpoint_path` not specified. Skipping checkpoint loading.")
        return

    rank_zero_info(f"Restoring states from the checkpoint path at {checkpoint_path}")
    self._loaded_checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
Basically, PyTorch Lightning will load the training state in the following order:
- HPC weight (from /hpc_ckpt_{epoch:d}.ckpt),
- Fault-tolerant auto-saved checkpoint (from /.pl_auto_save.ckpt),
- Checkpoint path (from ckpt_path in the Trainer).
Let’s walk through a training scenario where you would like to fine-tune a model’s weights on some task. You load a checkpoint base.ckpt. You then train the model for some number of epochs until the job stops because of your wall-time limit or some server interruption. At that point there will be a weights file such as hpc_ckpt_10.ckpt (or .pl_auto_save.ckpt if you use fault-tolerant training). In the next job, you run the same training script. PyTorch Lightning will resume from hpc_ckpt_10.ckpt (or .pl_auto_save.ckpt if you use fault-tolerant training). Note that training will not start from base.ckpt even if it is still provided to the training script, because its priority is lower than that of hpc_ckpt_10.ckpt or .pl_auto_save.ckpt.

In the following sections, I will show how to set up the training to create an HPC weight or a fault-tolerant checkpoint.
[Recommended] Automatically Reloading HPC weight
In the previous section, we talked about how PyTorch Lightning can automatically reload weights from /hpc_ckpt_{epoch:d}.ckpt. Now, let’s look at one way of setting this up through a Hydra configuration file.

model_checkpoint:
  _target_: pytorch_lightning.callbacks.ModelCheckpoint
  monitor: "val/g_loss"
  save_top_k: -1 # important
  dirpath: "." # important
  filename: "hpc_ckpt_{epoch:d}" # important
  auto_insert_metric_name: False # important
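If you are not using Hydra, the same callback can be configured directly in Python. This is just a minimal sketch of the equivalent setup; MyLitModel and MyDataModule are placeholders for your own LightningModule and LightningDataModule:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Save a checkpoint named hpc_ckpt_<epoch>.ckpt after every epoch
hpc_checkpoint = ModelCheckpoint(
    monitor="val/g_loss",
    save_top_k=-1,                  # keep a checkpoint for every epoch
    dirpath=".",                    # save into the working directory
    filename="hpc_ckpt_{epoch:d}",  # exact name expected by the automatic reload logic
    auto_insert_metric_name=False,  # keep the filename exactly as written
)

trainer = pl.Trainer(gpus=1, max_epochs=200, callbacks=[hpc_checkpoint])
# trainer.fit(MyLitModel(), datamodule=MyDataModule())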
Essentially, you want to use pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint to save a checkpoint named hpc_ckpt_{epoch:d} at every single epoch. Note that the filename format must be exact, otherwise the automatic weight-loading logic will break. If you want to save weights under another name or format, you can add another checkpoint callback with the desired configuration.

[Experimental] Fault-tolerant Training
For more details on fault-tolerant training, please visit https://pytorch-lightning.readthedocs.io/en/latest/advanced/fault_tolerant_training.html. Note that fault-tolerant training is an experimental feature and still contains some bugs; based on my experience, it sometimes doesn’t work as expected. You can track the progress of the fault-tolerant training implementation [here].
The idea of fault-tolerant training is to give PyTorch Lightning a mechanism to recover from a hardware or software failure. This is particularly relevant on a SLURM cluster, where instances can be shut down at any time. Under the hood, PyTorch Lightning keeps track of the following states during training:
- Samplers indices and random state
- Optimizers, LR schedulers, callbacks, etc.
- Loop progression
- Logging internal states
To enable fault-tolerant training in PyTorch Lightning, you need to set the environment variable PL_FAULT_TOLERANT_TRAINING=1. An example launch command is:

PL_FAULT_TOLERANT_TRAINING=1 python run.py experiment=audiovideostream_debug --multirun
During training, PyTorch Lightning will automatically create a file named .pl_auto_save.ckpt to keep the required states. Note: if you also set up the HPC model checkpoint as in the previous section, the .pl_auto_save.ckpt file will not be loaded, because the HPC weight has higher priority.
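If you prefer, the flag can also be set from inside the launching process before the Trainer is created; here is a small sketch, equivalent to exporting the variable in the shell:

import os

# Must be set before PyTorch Lightning checks the flag (i.e., before training starts)
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"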