Automatic differentiation is a huge field going back to at least the 1960s, so you’re right that this is an idiosyncratic way to do it. If you want to learn about it more generally, “automatic differentiation” is the phrase to google. Tinygrad and micrograd are particularly small libraries you might enjoy looking at.
Thanks! It seems like the direct route should be something like "turn the function into its AST, apply derivative rules to transform the AST, turn the result back into a function". And this JAX post doesn't really speak in those terms, at least not directly.
> the direct route should be something like "turn the function into its AST, apply derivative rules to transform the AST, turn the result back into a function"
A couple points on that:
1) The most direct route can be even simpler (what's called forward mode differentiation)! You want the derivative of a function in some direction at some point, and you can get it by just passing in both the point and the direction. If every function involved knows how to transform the point and transform the direction, then you just evaluate step by step, and no whole-program transformation is required. This is the "jvp" approach in OP (there's a small sketch of the idea after this list).
2) Something that is often misunderstood about JAX is that it isn't just about taking derivatives. A large part of it is just about transforming functions, hence its idiosyncrasies. It turns out one of the transformations you can do is exactly what you said: transform the function into its AST (called jaxprs IIRC), then transform that into whatever you want (gradients, inverses, parallel computations, JIT compilation, whatever), then turn that back into a function (there's a short jaxpr example below). And that's exactly how the linked post does reverse mode differentiation a couple pages in (IIRC). That flexibility is both what makes JAX's approach so interesting, and what makes JAX such a PITA to debug.
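To make (1) concrete, here's a minimal dual-number sketch of forward mode in plain Python. This is my own illustration of the idea, not JAX's actual internals: every value carries a point (primal) and a direction (tangent), and every operation transforms both at once.

```python
# Forward-mode differentiation via dual numbers (illustrative sketch,
# not JAX internals): each value carries the point and the direction.
class Dual:
    def __init__(self, primal, tangent):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        # sum rule: derivatives add
        return Dual(self.primal + other.primal,
                    self.tangent + other.tangent)

    def __mul__(self, other):
        # product rule, applied locally at this one operation
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

def f(x):
    return x * x + x  # f(x) = x^2 + x, so f'(x) = 2x + 1

out = f(Dual(3.0, 1.0))         # evaluate at x = 3 along direction 1
print(out.primal, out.tangent)  # 12.0 7.0
```

In JAX itself, jax.jvp(f, (3.0,), (1.0,)) returns the same (value, directional derivative) pair, with JAX's tracing machinery playing the role of the Dual class.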
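And for (2), you can actually look at the traced AST yourself. Assuming JAX is installed, something like this (make_jaxpr, grad, and jit are all real JAX APIs):

```python
import jax

def f(x):
    return x * x + x

# The traced, AST-like form of f: one equation per primitive op
print(jax.make_jaxpr(f)(3.0))

# The same tracing machinery feeds the other transformations
print(jax.grad(f)(3.0))  # 7.0
print(jax.jit(f)(3.0))   # 12.0
```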