Bringing PDEs to JAX with forward and reverse modes automatic differentiation

Research output: Contribution to conference › Paper › Scientific


Partial differential equations (PDEs) are used to describe a variety of physical phenomena. Often these equations do not have analytical solutions, and numerical approximations are used instead. One of the common methods for solving PDEs is the finite element method. Computing derivative information of the solution with respect to the input parameters is important in many tasks in scientific computing. We extend the JAX automatic differentiation library with an interface to the Firedrake finite element library. The high-level symbolic representation of PDEs allows us to avoid differentiating through the low-level, and possibly many, iterations of the underlying nonlinear solvers. Differentiating through Firedrake solvers is done using tangent-linear and adjoint equations.
This enables the efficient composition of finite element solvers with arbitrary differentiable programs.
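The idea in the abstract can be illustrated with a minimal sketch. This is not the authors' Firedrake interface. A toy 1D finite-difference problem (-u'' + u^3 = p with zero boundary values) stands in for a Firedrake finite element solve, and every name below (`residual`, `newton_solve`, `solve_tlm`, `solve_adj`, `source`, `loss`) is hypothetical. The sketch uses `jax.custom_jvp` and `jax.custom_vjp` to attach a tangent-linear equation and an adjoint equation to a black-box Newton solver. Because of these rules, JAX never differentiates through the solver iterations.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)

# Hypothetical toy problem standing in for a Firedrake finite element solve.
N = 32                      # number of interior grid points
H = 1.0 / (N + 1)           # uniform grid spacing on (0, 1)

def residual(u, p):
    """Discrete residual F(u, p) of -u'' + u**3 = p with u(0) = u(1) = 0."""
    u_pad = jnp.pad(u, 1)   # zero Dirichlet boundary values
    lap = (u_pad[:-2] - 2.0 * u_pad[1:-1] + u_pad[2:]) / H**2
    return -lap + u**3 - p

def newton_solve(p, tol=1e-10, max_iter=50):
    """Black-box Newton solver; lax.while_loop is not reverse-differentiable."""
    def cond(state):
        u, k = state
        return (jnp.linalg.norm(residual(u, p)) > tol) & (k < max_iter)
    def body(state):
        u, k = state
        J = jax.jacfwd(residual)(u, p)              # dF/du (dense, toy size only)
        return u - jnp.linalg.solve(J, residual(u, p)), k + 1
    u, _ = lax.while_loop(cond, body, (jnp.zeros_like(p), 0))
    return u

# Forward mode: tangent-linear equation  (dF/du) u_dot = -(dF/dp) p_dot
@jax.custom_jvp
def solve_tlm(p):
    return newton_solve(p)

@solve_tlm.defjvp
def solve_tlm_jvp(primals, tangents):
    (p,), (p_dot,) = primals, tangents
    u = newton_solve(p)
    J = jax.jacfwd(residual)(u, p)                  # dF/du at the solution
    _, rhs = jax.jvp(lambda q: residual(u, q), (p,), (p_dot,))  # (dF/dp) p_dot
    return u, jnp.linalg.solve(J, -rhs)

# Reverse mode: adjoint equation  (dF/du)^T lam = u_bar,  p_bar = -(dF/dp)^T lam
@jax.custom_vjp
def solve_adj(p):
    return newton_solve(p)

def solve_adj_fwd(p):
    u = newton_solve(p)
    return u, (u, p)                                # keep only the converged state

def solve_adj_bwd(res, u_bar):
    u, p = res
    J = jax.jacfwd(residual)(u, p)
    lam = jnp.linalg.solve(J.T, u_bar)              # adjoint solve
    _, vjp_p = jax.vjp(lambda q: residual(u, q), p)
    (p_bar,) = vjp_p(-lam)
    return (p_bar,)

solve_adj.defvjp(solve_adj_fwd, solve_adj_bwd)

# Composition with an ordinary differentiable program (hypothetical example).
x = jnp.linspace(H, 1.0 - H, N)                     # interior grid nodes

def source(theta):
    return theta[0] * jnp.sin(jnp.pi * x) + theta[1]

def loss(theta, solve):
    u = solve(source(theta))
    return jnp.sum(u**2) * H

theta = jnp.array([10.0, 1.0])
g_rev = jax.grad(lambda t: loss(t, solve_adj))(theta)           # full gradient
_, g_fwd = jax.jvp(lambda t: loss(t, solve_tlm), (theta,),
                   (jnp.array([1.0, 0.0]),))                     # d loss / d theta[0]
```

Both rules need only the converged state `u` and the Jacobian of the residual at that state. For this reason, the cost of a derivative does not depend on how many Newton iterations the solver took. The dense `jax.jacfwd` Jacobian is only reasonable at toy sizes. A finite element library would instead assemble the sparse tangent-linear and adjoint operators from its symbolic form.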
Original language: English
Publication status: Unpublished - 2020
MoE publication type: Not Eligible
Event: International Conference on Learning Representations - Addis Ababa, Ethiopia
Duration: 26 Apr 2020 → 30 Apr 2020
Conference number: 8


Conference: International Conference on Learning Representations
Abbreviated title: ICLR
City: Addis Ababa


  • jax
  • adjoint equation
  • automatic differentiation

