JAX is so complex under the hood, though! Autograd-style tracing that runs your Python functions with tracer values to emit JAX's own intermediate representation (the jaxpr), the XLA compiler that lowers HLO down to at least three different backends (CPU, GPU via CUDA, TPU, maybe more...), and then there's the JIT that JAX also brings to the table. All of that to make something that seems simple on the surface.
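A minimal sketch of how those pieces line up in practice (the function and values here are just illustrative, not anything from the JAX internals):

    import jax
    import jax.numpy as jnp

    # A plain Python function written against jax.numpy.
    def f(x):
        return jnp.sum(jnp.tanh(x) ** 2)

    x = jnp.arange(4.0)

    # Tracing: JAX runs f with abstract tracer values and records the
    # operations into a jaxpr, its functional intermediate representation.
    print(jax.make_jaxpr(f)(x))

    # Autodiff is a transformation of that same traced program.
    print(jax.grad(f)(x))

    # jit hands the traced program to XLA, which compiles it for the
    # active backend (CPU, GPU, or TPU).
    print(jax.jit(f)(x))

The surface API is just three calls, but each one is riding on the tracer, the jaxpr, and the XLA pipeline described above.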
You could say that JAX is trying to be NumPy, Theano/SymPy, PyPy/Numba, and PyCUDA all at the same time.
Both systems are trying to do so much. Perhaps the difference is JAX's focus on a narrower developer interface.