Try in Colab
Setting up PyTorch Lightning and W&B
For this tutorial, we need PyTorch Lightning and W&B.
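If you are following along in a notebook, a minimal setup cell might look like the sketch below. It assumes the pytorch-lightning and wandb packages from PyPI and a W&B account for wandb.login().

```python
# In a Colab/Jupyter cell; the "!" runs a shell command.
!pip install pytorch-lightning wandb --quiet

import wandb
import pytorch_lightning as pl

# Log in to W&B (prompts for an API key the first time).
wandb.login()
```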
DataModule - The Data Pipeline we Deserve
DataModules are a way of decoupling data-related hooks from the LightningModule so you can develop dataset-agnostic models. A DataModule organizes the data pipeline into one shareable and reusable class. It encapsulates the five steps involved in data processing in PyTorch (a sketch follows the list below):
- Download / tokenize / process.
- Clean and (maybe) save to disk.
- Load inside Dataset.
- Apply transforms (rotate, tokenize, etc.).
- Wrap inside a DataLoader.
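As a rough sketch of what this looks like in code, here is a hypothetical DataModule for MNIST from torchvision; the dataset, split sizes, and batch size are illustrative, not necessarily the ones used in this report.

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir="./data", batch_size=64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def prepare_data(self):
        # Download only; called once, on a single process.
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Split and assign datasets; called on every process.
        if stage == "fit" or stage is None:
            full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.train_set, self.val_set = random_split(full, [55000, 5000])
        if stage == "test" or stage is None:
            self.test_set = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)
```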
Callbacks
A callback is a self-contained program that can be reused across projects. PyTorch Lightning comes with a few built-in callbacks which are regularly used. Learn more about callbacks in PyTorch Lightning here.
Built-in Callbacks
In this tutorial, we will use the Early Stopping and Model Checkpoint built-in callbacks. They can be passed to the Trainer.
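A minimal sketch, assuming your LightningModule logs a metric named "val_loss":

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Stop training when the monitored metric stops improving.
early_stop_callback = EarlyStopping(monitor="val_loss", patience=3, mode="min")

# Keep the best checkpoint according to the monitored metric.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

trainer = pl.Trainer(callbacks=[early_stop_callback, checkpoint_callback])
```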
Custom Callbacks
If you are familiar with custom Keras callbacks, the ability to do the same in your PyTorch pipeline is just a cherry on the cake. Since we are performing image classification, the ability to visualize the model's predictions on some sample images can be helpful. Packaging this as a callback can help debug the model at an early stage.
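Here is a rough sketch of such a callback. It assumes a WandbLogger is attached to the Trainer, and that val_samples is a held-out batch of images and labels you pass in when constructing it.

```python
import torch
import wandb
import pytorch_lightning as pl


class ImagePredictionLogger(pl.Callback):
    """Log model predictions on a fixed batch of validation images to W&B."""

    def __init__(self, val_samples, num_samples=8):
        super().__init__()
        self.val_imgs, self.val_labels = val_samples
        self.num_samples = num_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        imgs = self.val_imgs[: self.num_samples].to(pl_module.device)
        labels = self.val_labels[: self.num_samples]
        preds = torch.argmax(pl_module(imgs), dim=-1)
        # Log images with predicted vs. true labels to the W&B run.
        trainer.logger.experiment.log({
            "examples": [
                wandb.Image(img.cpu(), caption=f"pred: {pred.item()}, label: {label.item()}")
                for img, pred, label in zip(imgs, preds, labels)
            ]
        })
```

The callback can then be passed to the Trainer alongside the built-in ones.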
LightningModule - Define the System
The LightningModule defines a system rather than a model. Here a system groups all the research code into a single class to make it self-contained. LightningModule organizes your PyTorch code into five sections (a sketch follows the list):
- Computations (__init__).
- Train loop (training_step)
- Validation loop (validation_step)
- Test loop (test_step)
- Optimizers (configure_optimizers)
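Here is a minimal sketch of such a system for image classification; the architecture and hyperparameters are placeholders rather than the ones used in this report.

```python
import torch
import torch.nn.functional as F
from torch import nn
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    def __init__(self, num_classes=10, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()
        # Computations live in __init__.
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("val_loss", loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
```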
Train and Evaluate
Now that we have organized our data pipeline using a DataModule and the model architecture + training loop using a LightningModule, the PyTorch Lightning Trainer automates everything else for us (see the sketch after the list below).
The Trainer automates:
- Epoch and batch iteration
- Calling optimizer.step(), backward, and zero_grad()
- Calling .eval(), enabling/disabling grads
- Saving and loading weights
- W&B logging
- Multi-GPU training support
- TPU support
- 16-bit training support
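Putting it together, here is a sketch of the train/evaluate calls, reusing the hypothetical MNISTDataModule, LitClassifier, and callbacks from the earlier sketches; the project name is a placeholder.

```python
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# "lightning-mnist" is a hypothetical project name; pick your own.
wandb_logger = WandbLogger(project="lightning-mnist")

# dm, model, and the callbacks are the objects defined in the sketches above.
dm = MNISTDataModule()
model = LitClassifier()

trainer = pl.Trainer(
    max_epochs=5,
    logger=wandb_logger,  # W&B logging
    callbacks=[early_stop_callback, checkpoint_callback],
)

trainer.fit(model, datamodule=dm)    # training + validation loops
trainer.test(model, datamodule=dm)   # evaluation on the test set

wandb.finish()  # close the W&B run
```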
Final Thoughts
I come from the TensorFlow/Keras ecosystem and find PyTorch a bit overwhelming, even though it's an elegant framework. Just my personal experience though. While exploring PyTorch Lightning, I realized that almost all of the reasons that kept me away from PyTorch are taken care of. Here's a quick summary of my excitement:
- Then: Conventional PyTorch model definitions used to be all over the place, with the model in some model.py script and the training loop in the train.py file. It was a lot of looking back and forth to understand the pipeline.
- Now: The LightningModule acts as a system where the model is defined along with the training_step, validation_step, etc. Now it's modular and shareable.
- Then: The best part about TensorFlow/Keras is the input data pipeline. Their dataset catalog is rich and growing. PyTorch's data pipeline used to be the biggest pain point. In normal PyTorch code, the data download/cleaning/preparation is usually scattered across many files.
- Now: The DataModule organizes the data pipeline into one shareable and reusable class. It's simply a collection of a train_dataloader, val_dataloader(s), and test_dataloader(s), along with the matching transforms and data processing/download steps required.
- Then: With Keras, one can call model.fit to train the model and model.predict to run inference. model.evaluate offered a good old simple evaluation on the test data. This is not the case with PyTorch. One will usually find separate train.py and test.py files.
- Now: With the LightningModule in place, the Trainer automates everything. One just needs to call trainer.fit and trainer.test to train and evaluate the model.
- Then: TensorFlow loves TPU, PyTorch…
- Now: With PyTorch Lightning, it's so easy to train the same model with multiple GPUs and even on TPU.
- Then: I am a big fan of Callbacks and prefer writing custom callbacks. Something as trivial as Early Stopping used to be a point of discussion with conventional PyTorch.
- Now: With PyTorch Lightning, using Early Stopping and Model Checkpointing is a piece of cake. I can even write custom callbacks.
Conclusion and Resources
I hope you find this report helpful. I encourage you to play with the code and train an image classifier on a dataset of your choice. Here are some resources to learn more about PyTorch Lightning:
- Step-by-step walk-through: This is one of the official tutorials. Their documentation is really well written and I highly recommend it as a learning resource.
- Use PyTorch Lightning with W&B: This is a quick Colab that you can run through to learn more about how to use W&B with PyTorch Lightning.