JAX-FEM: A Differentiable Finite Element Solver

  • Xue, Tianju (HKUST)

Differentiable programming enables the automatic computation of gradients across complex computational workflows, leveraging frameworks like JAX to optimize parameters in neural networks, physics simulations, or hybrid models. By encoding the physical process as a differentiable function, gradients of the loss (the mismatch between predictions and observations) can be backpropagated to efficiently infer unknown inputs (e.g., material properties, boundary conditions), avoiding brute-force search. This presentation will introduce JAX-FEM [1], an open-source differentiable finite element framework built on JAX and designed to solve nonlinear inverse problems by integrating automatic differentiation (AD) with the finite element method (FEM). We will review the existing features of JAX-FEM, discuss ongoing efforts, and consider possible future developments. Two prominent applications of differentiable programming to inverse problems will be presented: parameter identification for lithium-ion batteries and calibration of crystal plasticity finite element models. Recent progress in differentiable programming for deep generative AI models will also be discussed.
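To make the idea of backpropagating a loss through a finite element solve concrete, here is a minimal sketch in plain JAX. It does not use the JAX-FEM API; the function names (assemble, forward, loss), the 1D model problem, the synthetic observations, and the step size are all assumptions chosen for illustration. It recovers a single scalar stiffness kappa in -(kappa u')' = 1 on [0, 1] with u(0) = u(1) = 0 from observed nodal displacements by gradient descent.

```python
import jax
import jax.numpy as jnp


def assemble(kappa, n_elem):
    """Linear-element stiffness matrix and load vector for
    -(kappa u')' = 1 on [0, 1] with u(0) = u(1) = 0 (interior nodes only)."""
    h = 1.0 / n_elem
    n = n_elem - 1  # number of interior nodes
    K = (kappa / h) * (2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1))
    F = h * jnp.ones(n)
    return K, F


def forward(log_kappa, n_elem=16):
    """Differentiable forward solve: log-stiffness -> nodal displacements."""
    K, F = assemble(jnp.exp(log_kappa), n_elem)
    return jnp.linalg.solve(K, F)


def loss(log_kappa, u_obs):
    """Relative squared mismatch between prediction and observation."""
    u = forward(log_kappa)
    return jnp.sum((u - u_obs) ** 2) / jnp.sum(u_obs ** 2)


# Synthetic "observations" generated from an assumed true stiffness.
kappa_true = 2.0
u_obs = forward(jnp.log(kappa_true))

# AD differentiates through the linear solve; no hand-derived adjoint needed.
grad_loss = jax.jit(jax.grad(loss))

# Plain gradient descent on log(kappa), which keeps kappa positive.
log_kappa = jnp.array(0.0)  # initial guess: kappa = 1
for _ in range(200):
    log_kappa = log_kappa - 0.1 * grad_loss(log_kappa, u_obs)

kappa_est = float(jnp.exp(log_kappa))
print(kappa_est)  # should approach kappa_true
```

The same pattern (differentiable forward solve, scalar loss, jax.grad) carries over in principle to nonlinear problems and to many unknown parameters, where a gradient-based search scales far better than trying parameter values one by one.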