JAX-Fluids: A fully-differentiable high-order computational fluid dynamics solver for compressible two-phase flows
Physical systems are governed by partial differential equations (PDEs). The Navier-Stokes equations describe fluid flows and are representative of nonlinear physical systems with complex spatio-temporal interactions. Fluid flows are omnipresent in nature and engineering applications, and their accurate simulation is essential for providing insights into these processes. While PDEs are typically solved with numerical methods, the recent success of machine learning (ML) has shown that ML methods can provide novel avenues for finding solutions to PDEs. ML is becoming increasingly prevalent in computational fluid dynamics (CFD). However, to date there is no general-purpose ML-CFD package that provides 1) powerful state-of-the-art numerical methods, 2) seamless hybridization of ML with CFD, and 3) automatic differentiation (AD) capabilities. AD in particular is essential to ML-CFD research, as it provides gradient information and enables optimization of preexisting and novel CFD models. In this work, we propose JAX-Fluids: a comprehensive fully-differentiable CFD Python solver for compressible two-phase flows. JAX-Fluids is intended for ML-supported CFD research. The framework allows the simulation of complex fluid dynamics with phenomena such as three-dimensional turbulence, compressibility effects, and two-phase flows. Because JAX-Fluids is written entirely in JAX, it is straightforward to include existing ML models in the proposed framework. Furthermore, JAX-Fluids enables end-to-end optimization: ML models can be optimized with gradients that are backpropagated through the entire CFD algorithm and therefore contain information not only about the underlying PDE but also about the applied numerical methods. We believe that a Python package like JAX-Fluids is crucial to facilitate research at the intersection of ML and CFD and may pave the way for an era of differentiable fluid dynamics.
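To make the idea of end-to-end optimization concrete, the following is a minimal sketch in plain JAX. It is not the JAX-Fluids API: the toy problem (1D periodic diffusion with explicit Euler time stepping), the function names `step`, `rollout`, and `loss`, and all constants are assumptions chosen for illustration. A single scalar coefficient `nu` stands in for the parameters of an ML model, and its gradient is obtained by backpropagating through every time step of the solver.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy setup (NOT the JAX-Fluids API): 1D periodic diffusion,
# u_t = nu * u_xx, with central differences and explicit Euler stepping.
N = 64          # number of grid cells (assumed)
DX = 1.0 / N    # cell size on the unit interval
DT = 1e-4       # time step, small enough for explicit stability here
STEPS = 200     # number of time steps in one rollout

def step(u, nu):
    # Second-order central difference for u_xx on a periodic grid.
    u_xx = (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / DX**2
    return u + DT * nu * u_xx

def rollout(nu, u0):
    # Advance the state STEPS times; lax.scan keeps the loop differentiable.
    def body(u, _):
        return step(u, nu), None
    u_final, _ = jax.lax.scan(body, u0, xs=None, length=STEPS)
    return u_final

def loss(nu, u0, u_target):
    # Mean squared error between the simulated and the reference final state.
    return jnp.mean((rollout(nu, u0) - u_target) ** 2)

x = jnp.arange(N) * DX
u0 = jnp.sin(2.0 * jnp.pi * x)
nu_true = 0.05                      # coefficient used to generate the reference
u_target = rollout(nu_true, u0)

# Gradient of the loss w.r.t. nu, backpropagated through the whole rollout.
grad_fn = jax.jit(jax.grad(loss))

nu = 0.01                           # deliberately wrong initial guess
initial_grad = grad_fn(nu, u0, u_target)
learning_rate = 1.0                 # assumed step size for plain gradient descent
for _ in range(100):
    nu = nu - learning_rate * grad_fn(nu, u0, u_target)
```

Because the gradient passes through the discrete update in `step`, it reflects the chosen discretization (stencil and time integrator) as well as the underlying PDE, which is the property the text attributes to backpropagating through the entire CFD algorithm. In an actual ML-CFD setting, `nu` would be replaced by the weights of a neural network embedded in the solver.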