ML Engineer comparison of Pytorch, TensorFlow, JAX, and Flax
Which Deep Learning framework is best?
When you enter the ML world, you might be overwhelmed by the choice of libraries, whose camps are divided almost like political parties or religions (nearly as much as front-end frameworks). Among the many available, a few are the most popular: Pytorch, Tensorflow (+ Keras), Pytorch Lightning, and, more recently, JAX (with its neural network library, Flax). They operate at different levels of abstraction: Pytorch Lightning and Keras are the most high-level, Tensorflow, PyTorch, and Flax allow for more control, and JAX operates at the lowest level.
In this blog post, I aim to provide qualitative examples that make the benefits and shortcomings of each framework easier to compare.
Tensorflow + Keras
TensorFlow was developed by Google Brain and released in 2015. This framework supports many machine learning tasks, from simple linear regression to complex deep learning architectures. It allows for low-level control over the model architecture.
Higher-level control, in contrast, is available via Keras, which became a core part of Tensorflow with TensorFlow 2. Keras 3.0, however, moves away from a Tensorflow-only approach again and supports multiple backends: JAX, Tensorflow, and Pytorch.
Additionally, Tensorflow is known for its wide range of deployment targets: mobile and edge devices (Tensorflow Lite), high-performance model serving in production (Tensorflow Serving), production ML pipelines (TFX - Tensorflow Extended) and the web (Tensorflow.js, which runs models in the browser and in Node.js).
Pytorch
The popularity of frameworks in new Machine Learning papers, source: visio.ai
Pytorch was released in 2016 by Facebook's AI Research lab (FAIR) and quickly gained popularity among researchers and developers, primarily due to its dynamic computation graph and ease of use, which ended Tensorflow's dominance at the time.
In our blog post on the in-depth analysis of the pneumonia model, we used PyTorch to train and evaluate the model.
Its dynamic computation graph (the graph is built on the fly as the code runs) was groundbreaking at the time, making networks easier to debug and modify at runtime. This architecture and API suited Pythonic-style development, unlike other frameworks (e.g. Caffe). Another excellent benefit was its strong support for automatic gradient computation, which greatly simplifies implementing machine learning models. Additionally, PyTorch lets you build code flexibly and modularly, making it extremely popular for R&D.
As its popularity in the developer community grew, the Pytorch ecosystem expanded significantly, with great documentation and specialised tools and libraries, ranging from general domains (audio, computer vision, text/NLP) and deployment (e.g. edge deployment) to more specialised ones: MONAI for medical computer vision, Diffusers for generative diffusion models, and PyTorch Geometric or PyTorch3D for graph and 3D data such as point clouds. Additionally, libraries like fast.ai allow for a quick start in building meaningful and capable solutions, with easy access to the latest inventions from the research world.
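A minimal sketch of PyTorch's define-by-run autograd: the graph is recorded while ordinary Python code (including an `if` branch) executes, and `backward()` walks the recorded graph to fill in gradients. The function `f` and its inputs are purely illustrative.

```python
import torch


def f(x: torch.Tensor) -> torch.Tensor:
    # Plain Python control flow: the branch taken decides which ops get recorded.
    if x.sum() > 0:
        return (x ** 2).sum()
    return (-x).sum()


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = f(x)       # the graph for this particular call is built on the fly
y.backward()   # traverse the recorded graph to compute dy/dx

print(x.grad)  # d/dx sum(x^2) = 2x -> tensor([2., 4., 6.])
```

Because the graph is rebuilt on every call, you can put a breakpoint or a `print` inside `f` and inspect intermediate tensors like any other Python values.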
Pytorch Lightning
Pytorch Lightning is a high-level framework built on top of Pytorch that simplifies developing and training deep learning models. It abstracts away the boilerplate that previously had to be written by hand (e.g., training loops and validation steps), letting researchers focus on the core aspects of model development.
The main benefits of this framework are:
- easy scaling to multiple GPUs and nodes,
- automatic handling of common tasks like checkpointing, logging and early stopping,
- an even more modular structure, making code easier to read and maintain.
JAX + Flax
JAX is an open-source numerical computation library developed by Google, combining Python's flexibility with hardware acceleration. The main benefits of JAX are:
- automatic differentiation of native Python and NumPy code,
- Just-In-Time (JIT) compilation, which compiles Python functions into optimised machine code (via XLA) that runs on CPU, GPU and TPU,
- extensive vectorisation and parallelism, enabling efficient execution across different hardware platforms.
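The three features above map onto three JAX transformations: `jax.grad`, `jax.jit` and `jax.vmap`. A minimal sketch, assuming a toy linear model with a mean squared error loss (the functions and data are illustrative):

```python
import jax
import jax.numpy as jnp


def mse_loss(w, x, y):
    # Mean squared error of a linear model, written as plain NumPy-style code.
    return jnp.mean((x @ w - y) ** 2)


# grad: derivative w.r.t. the first argument; jit: compile the result with XLA.
grad_fn = jax.jit(jax.grad(mse_loss))


def predict_one(w, x_i):
    # Prediction for a single example (1-D feature vector).
    return jnp.dot(x_i, w)


# vmap: map predict_one over the leading axis of x, sharing w, without a Python loop.
predict_batch = jax.vmap(predict_one, in_axes=(None, 0))

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

print(grad_fn(w, x, y))      # gradient of the loss at w = 0
print(predict_batch(w, x))   # one prediction per row of x
```

The transformations compose: `jax.jit(jax.grad(...))` compiles the gradient function once, and later calls with the same shapes reuse the compiled code.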
Flax, on the other hand, is a high-level neural network library built on top of JAX, also developed by Google. Its main aim is a modular design that is easy to customise. Its structured way of managing model parameters, state, and randomness separates model logic from training logic, making code easier to debug.
JAX combined with Flax is the primary weapon of choice at both DeepMind and Google AI.
Analysis
We have reviewed the general information and motivation behind each framework, but we also want concrete comparisons: popularity, user experience (ease of writing code) and performance (execution time). The following sections should give better insight into which framework suits which task.
Popularity
Star-based perspective
Below, we can compare the popularity of chosen frameworks on GitHub.
Popularity on GitHub
We see the dominance of Pytorch (81k stars), closely followed by Tensorflow (74k stars) and its closely associated framework, Keras (61k stars). JAX (29k stars) and Pytorch Lightning (28k stars) are younger libraries, and Flax (6k stars) is the least popular, likely because it is the youngest.
Our open-source library for explainable AI in computer vision, FoXAI, was written for PyTorch models because of PyTorch's widespread popularity.
Research perspective
Additionally, when we analyse which framework new papers use, we see PyTorch dominating the research community (more than 75% of new deep learning papers use it). This suggests that new models are most likely to be available in PyTorch first. However, aggregators like Hugging Face's Transformers library often make it possible to stay framework-agnostic without converting models manually.
Simple computer vision example
Code examples
The code for this example can be found here:
- Pytorch: pytorch_example.py
- Tensorflow + Keras: tensorflow_example.py
- Pytorch lightning: pytorch_lightning_example.py
- JAX: jax_example.py
Experimental setup
In this comparison, we use a simple setup consisting of:
- a neural network with 5 convolutional layers and Batch Normalization (chosen because Batch Normalization behaves differently during training and testing),
- the Imagenette dataset, a subset of 10 easily classified classes from ImageNet, with 9472 examples,
- the Adam optimizer with a fixed learning rate of 0.001.
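A possible PyTorch version of this setup is sketched below. Only the overall recipe (5 convolutional layers with Batch Normalization, 10 classes, Adam with a learning rate of 0.001) comes from the text; the channel widths, 3x3 kernels, max pooling, global average pooling head and 128x128 input size are assumptions, not the configuration used in the linked scripts.

```python
import torch
from torch import nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    # Conv -> BatchNorm -> ReLU, then halve the spatial resolution.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
    )


class SmallConvNet(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        widths = [3, 32, 64, 128, 128, 256]  # assumed channel widths (input RGB first)
        self.features = nn.Sequential(*[conv_block(widths[i], widths[i + 1]) for i in range(5)])
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(widths[-1], num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.features(x))


model = SmallConvNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()  # BatchNorm normalises with batch statistics and updates running averages
logits = model(torch.randn(8, 3, 128, 128))
model.eval()   # BatchNorm now uses the stored running averages instead
```

Forgetting the `train()`/`eval()` switch (or its equivalents, such as the `training` flag in Keras or mutable `batch_stats` in Flax) is exactly the kind of train/test difference this setup is meant to exercise.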
Performance
In this section, we explore two common implementation strategies: one suited to large datasets, which loads data in a streaming fashion, and one ideal for smaller datasets that fit entirely in GPU memory.
Out-of-the-box implementation (stream-based)
| Pytorch | Pytorch Lightning | Tensorflow + Keras | JAX |
|---|---|---|---|
| 232s (180s with native loader) | 180s | 64s | 84s |
In this setup, we load data in a streaming fashion, so each image is loaded by a data loader only when needed. There are plenty of ways to tune each setup further, but here we aim for a simple, out-of-the-box implementation. Our results show that Tensorflow and JAX are faster, while Pytorch and Pytorch Lightning are slower. The main reason is that Pytorch and Pytorch Lightning use PIL-based image loading, while Tensorflow and JAX use TF's native implementation. When experimenting with a custom-defined DataLoader for Pytorch (via torchvision.io.read_image()), we reduced the processing time by 22% (down to 180s). This is still significantly slower than TF processing (which uses the highly optimised tf.io.read_file() under the hood).
All in-memory implementation (memory-based)
In contrast to the previous setup, here we used a subset of 3 of the 10 classes (due to memory limitations) and placed the whole dataset on the GPU, limiting the effect of slow data loaders. In this setup, Pytorch is significantly faster than the other solutions, PyTorch Lightning and JAX have similar times, and Tensorflow is the slowest. This might indicate that Pytorch Lightning and Keras add some overhead due to their additional features.
Both setups (stream-based and memory-based) are simple experiments on the Imagenette dataset. Rather than showing which framework is best, they point to caveats hidden behind each framework's basic API. Some of these caveats (like data loading) were already mentioned above; others are discussed in the next section.
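A minimal PyTorch sketch of the memory-based approach: the data is copied to the device once, and each mini-batch is just an index into tensors already in GPU memory, so no DataLoader workers or per-batch host-to-device copies are involved. The helper `iterate_in_memory`, the random placeholder data (300 images of 64x64 for 3 classes) and the batch size are assumptions for illustration.

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder tensors standing in for the decoded 3-class subset, created directly on the device.
images = torch.rand(300, 3, 64, 64, device=device)
labels = torch.randint(0, 3, (300,), device=device)


def iterate_in_memory(images: torch.Tensor, labels: torch.Tensor, batch_size: int):
    """Hypothetical helper: yield shuffled mini-batches from tensors that already live on a device."""
    perm = torch.randperm(images.shape[0], device=images.device)  # new shuffle every epoch
    for start in range(0, images.shape[0], batch_size):
        idx = perm[start:start + batch_size]
        yield images[idx], labels[idx]


batches = list(iterate_in_memory(images, labels, batch_size=64))
```

The trade-off is memory: the whole (preprocessed) dataset must fit on the GPU next to the model and its activations, which is why only a subset of classes was used here.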
Coding UX and analysis
Each example uses a similar default setup, loading Imagenette, a subset of ImageNet (10 easily classified classes) consisting of 9472 examples. A comparison of lines of code (LOCs), excluding imports and configuration, can be seen below.
Comparison of LOCs for implementations in different frameworks
Pytorch and Tensorflow require the most lines, and JAX is the most concise (although it might require certain environment variables to be set). More telling than the raw numbers, however, is the character of each framework. Pytorch allows for great modularisation of the codebase (with an OOP approach to model creation), and Pytorch Lightning goes further by applying a similar approach to the data plane. This results in much clearer, easier-to-understand code.
Tensorflow started with a more functional approach but later added a subclassing API (inspired by Pytorch's success), which allows the same kind of modularity. Datasets and data loaders in Tensorflow, on the other hand, are still more commonly written in a functional style.
The philosophy of Flax and JAX is to keep model architecture and training characteristics independent. Thus, a TrainState holding the model's parameters, the chosen optimizer, and the learning rate is created separately, without touching the model architecture itself.
It is important to mention that Tensorflow has accumulated technical debt over the years, visible in breaking changes (and, in effect, limited support and compatibility between Keras 3.0 and TF2).
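To make the contrast concrete, here is a minimal sketch of the same tiny network written with the Keras subclassing API (layers in `__init__`, forward pass in `call()`, much like a PyTorch `nn.Module`) and with the functional API. The layer sizes and 64x64 input shape are arbitrary assumptions.

```python
import tensorflow as tf


class SmallConvNet(tf.keras.Model):
    """Subclassed Keras model: layers are attributes, the forward pass lives in call()."""

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(32, 3, padding="same", use_bias=False)
        self.bn = tf.keras.layers.BatchNormalization()
        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.head = tf.keras.layers.Dense(num_classes)

    def call(self, x, training=False):
        # `training` switches BatchNormalization between batch and moving statistics.
        x = tf.nn.relu(self.bn(self.conv(x), training=training))
        return self.head(self.pool(x))


subclassed_model = SmallConvNet()
logits = subclassed_model(tf.random.uniform((4, 64, 64, 3)), training=True)

# The same network in the functional style: layers are called on symbolic inputs.
inputs = tf.keras.Input(shape=(64, 64, 3))
x = tf.keras.layers.Conv2D(32, 3, padding="same", use_bias=False)(inputs)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.ReLU()(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(10)(x)
functional_model = tf.keras.Model(inputs, outputs)
```

Both styles produce a `tf.keras.Model` that can be trained with `compile()`/`fit()`; the subclassed version mainly helps when the forward pass needs custom Python logic.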
Interesting differences between frameworks
Different initialisers
Another interesting difference between the frameworks is their default initialisation of weights:
- Tensorflow uses Xavier Uniform initialisation for the weights,
- Pytorch uses Kaiming Uniform initialisation for the weights,
- Flax uses LeCun Normal initialisation for the weights.
Comparison of starting weight distributions for Xavier, Kaiming and LeCun initialisations
While Kaiming initialisation takes into account the layer's number of incoming units (fan_in; for a convolution layer, the number of input channels multiplied by the kernel size), Xavier initialisation also takes into account the number of outgoing units (fan_out). LeCun initialisation scales by fan_in only, like Kaiming, but without Kaiming's extra factor of 2 (which compensates for ReLU), which makes it better suited to tanh-activated layers.
They might all look similar, but different initialisers can greatly influence how a model starts training and which direction it takes, so paying attention to such details goes a long way when you port code to a new framework.
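The scales behind the three defaults can be compared with plain arithmetic. This sketch assumes the standard formulas: Keras `glorot_uniform` samples from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)); PyTorch's `Conv2d`/`Linear` default calls `kaiming_uniform_` with `a = sqrt(5)`, which reduces to a bound of 1 / sqrt(fan_in); Flax's `lecun_normal` targets a standard deviation of sqrt(1 / fan_in) (using a rescaled truncated normal). The 3x3 convolution from 64 to 128 channels is an arbitrary example layer.

```python
import math


def fans_conv2d(in_channels: int, out_channels: int, kernel_h: int, kernel_w: int):
    # fan_in: inputs feeding one output unit; fan_out: outputs fed by one input unit.
    receptive_field = kernel_h * kernel_w
    return in_channels * receptive_field, out_channels * receptive_field


def xavier_uniform_limit(fan_in: int, fan_out: int) -> float:
    # Keras glorot_uniform: weights ~ U(-limit, limit)
    return math.sqrt(6.0 / (fan_in + fan_out))


def pytorch_default_conv_bound(fan_in: int) -> float:
    # kaiming_uniform_ with a=sqrt(5): gain = sqrt(2 / (1 + a^2)), bound = gain * sqrt(3 / fan_in)
    a = math.sqrt(5)
    gain = math.sqrt(2.0 / (1 + a ** 2))
    return gain * math.sqrt(3.0 / fan_in)


def lecun_normal_std(fan_in: int) -> float:
    # Target standard deviation of LeCun normal initialisation
    return math.sqrt(1.0 / fan_in)


fan_in, fan_out = fans_conv2d(64, 128, 3, 3)
# A uniform distribution U(-b, b) has standard deviation b / sqrt(3).
xavier_std = xavier_uniform_limit(fan_in, fan_out) / math.sqrt(3)
torch_std = pytorch_default_conv_bound(fan_in) / math.sqrt(3)
lecun_std = lecun_normal_std(fan_in)

for name, std in [
    ("Xavier uniform (Tensorflow/Keras)", xavier_std),
    ("Kaiming uniform, a=sqrt(5) (Pytorch)", torch_std),
    ("LeCun normal (Flax)", lecun_std),
]:
    print(f"{name:40s} std = {std:.4f}")
```

For this layer the three defaults differ by up to a factor of sqrt(3) in weight scale, which is one reason a model ported between frameworks can train differently even with identical architecture and hyperparameters.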
Different processing backends
As mentioned in the performance section, the frameworks differ in how images and, by extension, data augmentations are handled. Pytorch uses a PIL-based approach, while Tensorflow and JAX use more native solutions. This is a double-edged sword: PIL-based image augmentations are easier to write, but they might not be as well optimised as native solutions and can result in longer preprocessing times (as in our stream-based example).
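For reference, a native TensorFlow input pipeline can be sketched as follows: reading, decoding and resizing all run as TF ops inside `tf.data`, in parallel, with prefetching to overlap loading and training. `make_dataset`, its parameters and the resize/scale steps are illustrative assumptions, not the pipeline from the linked scripts.

```python
import tensorflow as tf


def make_dataset(file_paths, labels, image_size=128, batch_size=32):
    """Hypothetical stream-based pipeline: decoding and resizing run as TF ops, not PIL."""

    def load(path, label):
        raw = tf.io.read_file(path)  # raw bytes from disk
        # Decode JPEG/PNG into a uint8 (H, W, 3) tensor.
        img = tf.io.decode_image(raw, channels=3, expand_animations=False)
        img = tf.image.resize(img, (image_size, image_size)) / 255.0  # float32 in [0, 1]
        return img, label

    ds = tf.data.Dataset.from_tensor_slices((file_paths, labels))
    ds = ds.shuffle(buffer_size=len(file_paths))
    ds = ds.map(load, num_parallel_calls=tf.data.AUTOTUNE)  # parallel native decoding
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)  # overlap input with training
```

Note the functional style: the pipeline is a chain of `map`/`batch`/`prefetch` transformations rather than a class, and augmentations would be added as further `map` calls using `tf.image` ops.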
Different GPU memory allocation algorithms
Pytorch and Pytorch Lightning allocate GPU memory incrementally, requesting more only when needed. Tensorflow and JAX, on the other hand, allocate greedily (reserving a large share of GPU memory up front), which can cause strange out-of-memory errors when both are used in the same process. This is a common issue, documented on the JAX website, and it can be solved with a few lines of code.
These strategies can influence processing time, and greedy preallocation also reserves memory in advance, which can prevent other components that need the GPU from running.
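One common fix is to switch both frameworks to on-demand allocation through environment variables, as sketched below. `XLA_PYTHON_CLIENT_PREALLOCATE`, `XLA_PYTHON_CLIENT_MEM_FRACTION` and `TF_FORCE_GPU_ALLOW_GROWTH` are documented variables; the 50% cap in the commented-out line is an arbitrary example. The variables must be set before JAX or Tensorflow initialises the GPU (safest: before importing them).

```python
import os

# JAX: allocate on demand instead of preallocating a large fraction of GPU memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternative: keep preallocation but cap it, e.g. to half of the GPU memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

# Tensorflow: grow the allocation as needed instead of mapping nearly all GPU memory.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
```

In Tensorflow the same effect can be achieved in code with `tf.config.experimental.set_memory_growth(gpu, True)` for each device from `tf.config.list_physical_devices("GPU")`. On-demand allocation can be somewhat slower and more prone to fragmentation, so it is a trade-off rather than a free win.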
Support for existing models
The analysis suggests that JAX is a good and readable solution. However, because it is the youngest framework and not yet widely adopted, fewer models with pretrained weights are available for it. This can be mitigated by using Hugging Face and its Transformers library, although this may not be the best solution when trying to limit dependencies.
Conclusions
This blog post and the attached examples highlight some specific caveats and shortcomings of the different frameworks. Some are well established (like Pytorch), some carry technical debt from previous solutions (like Tensorflow), and some have yet to show their full potential as their community grows by the day (e.g., JAX + Flax).
Reviewed by: Kamil Rzechowski