> the direct route should be something like "turn the function into its AST, apply derivative rules to transform the AST, turn the result back into a function"
A couple points on that:
1) The most direct route can be even simpler: what's called forward-mode differentiation! You want the derivative of a function at some point in some direction, and you can get it by just passing in the point and the direction together. If every primitive involved knows how to map the point to its output, and the direction to the matching output direction (its derivative at that point applied to the incoming direction), then you just evaluate the function step by step. No rewriting of the function is required. This is the "jvp" (Jacobian-vector product) approach in the OP.
2) Something that is often misunderstood about JAX is that it isn't just about taking derivatives. A large part of it is about transforming functions, hence its idiosyncrasies. It turns out one of the transformations you can do is exactly what you described: turn the function into an AST-like intermediate representation (called jaxprs IIRC), transform that into whatever you want (gradients, inverses, parallelized computations, JIT-compiled code, whatever), then turn the result back into a function. And that's exactly how the linked post does reverse-mode differentiation a couple of pages in (IIRC). That flexibility is both what makes JAX's approach so interesting and what makes JAX such a PITA to debug.
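The two routes above can be sketched side by side in plain Python, without JAX. This is a toy under loud assumptions: scalar inputs only, just three primitives (`+`, `*`, `sin`), and no Python control flow that depends on values. The names `Dual`, `Tracer`, `jvp`, `grad`, and `sin` are hypothetical stand-ins, not JAX's API (JAX's real entry points are `jax.jvp`, `jax.grad`, and `jax.make_jaxpr`, and a real jaxpr is far richer than the flat op list here). `jvp` pushes a point and a direction through each op with no rewriting; `grad` first traces the function into a list of equations, walks that list backwards applying derivative rules, and hands back an ordinary function.

```python
import math

# ---- Point 1: forward mode. Push (point, direction) through every op. ----
class Dual:
    """A point (primal) paired with a direction (tangent) pushed forward with it."""
    def __init__(self, primal, tangent):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        o = other if isinstance(other, Dual) else Dual(other, 0.0)  # constants: zero tangent
        return Dual(self.primal + o.primal, self.tangent + o.tangent)

    def __mul__(self, other):
        o = other if isinstance(other, Dual) else Dual(other, 0.0)
        # Product rule applied to the tangents.
        return Dual(self.primal * o.primal, self.tangent * o.primal + self.primal * o.tangent)

    def sin(self):
        return Dual(math.sin(self.primal), math.cos(self.primal) * self.tangent)

    __radd__, __rmul__ = __add__, __mul__

def jvp(f, x, v):
    """f(x) and the derivative of f at x in direction v, from one evaluation."""
    out = f(Dual(x, v))
    return out.primal, out.tangent

# ---- Point 2: trace into a flat program, transform it, return a function. ----
class Tracer:
    """Stands in for a value while tracing; records each op into `program`."""
    def __init__(self, program, idx):
        self.program, self.idx = program, idx

    def _emit(self, op, *args):
        ins = []
        for a in args:
            if not isinstance(a, Tracer):  # record plain numbers as const nodes
                self.program.append(("const", a))
                a = Tracer(self.program, len(self.program) - 1)
            ins.append(a.idx)
        self.program.append((op, *ins))
        return Tracer(self.program, len(self.program) - 1)

    def __add__(self, other):
        return self._emit("add", self, other)

    def __mul__(self, other):
        return self._emit("mul", self, other)

    def sin(self):
        return self._emit("sin", self)

    __radd__, __rmul__ = __add__, __mul__

def sin(x):
    """Dispatch so the same user code runs on floats, Duals, and Tracers."""
    return x.sin() if hasattr(x, "sin") else math.sin(x)

def grad(f):
    program = [("input",)]                    # node 0 is the argument
    out_idx = f(Tracer(program, 0)).idx       # trace: function -> flat program

    def grad_f(x):                            # the result is an ordinary function again
        vals = []
        for op, *ins in program:              # forward pass: evaluate each node
            if op == "input":
                vals.append(x)
            elif op == "const":
                vals.append(ins[0])
            elif op == "add":
                vals.append(vals[ins[0]] + vals[ins[1]])
            elif op == "mul":
                vals.append(vals[ins[0]] * vals[ins[1]])
            elif op == "sin":
                vals.append(math.sin(vals[ins[0]]))
        adj = [0.0] * len(program)            # reverse pass: chain rule, last node first
        adj[out_idx] = 1.0
        for i in reversed(range(len(program))):
            op, *ins = program[i]
            if op == "add":
                adj[ins[0]] += adj[i]
                adj[ins[1]] += adj[i]
            elif op == "mul":
                adj[ins[0]] += adj[i] * vals[ins[1]]
                adj[ins[1]] += adj[i] * vals[ins[0]]
            elif op == "sin":
                adj[ins[0]] += adj[i] * math.cos(vals[ins[0]])
        return adj[0]
    return grad_f

def f(x):
    return x * x + 3 * sin(x)

print(jvp(f, 2.0, 1.0))  # forward mode: (f(2), f'(2) * 1)
print(grad(f)(2.0))      # reverse mode: f'(2), via the traced program
```

The split mirrors the two points: `jvp` never inspects `f`'s structure, while `grad` only works because `f` was first turned into data. It also hints at the debugging pain: while tracing, `x` is a `Tracer` rather than a number, so printing it or branching on its value inside `f` does not behave like normal Python, which is roughly the kind of surprise JAX's own tracers produce.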