
This is great if you want to learn JAX, but it doesn't seem like the most efficient way to learn how to implement symbolic differentiation or automated generation of derivative functions on a computer. Is there something that takes a more direct route, maybe in Scheme or something? (I don't care about the language, just whatever presents the least bureaucracy and overhead.)



I'm sure there's a lot of good material around, but here are some links that are conceptually very close to the linked Autodidax. (Disclaimer: I wrote Autodidax and some of these other materials.)

There's Autodidact [0], a predecessor to Autodidax, which was a simplified implementation of the original Autograd [1]. It focuses on reverse-mode autodiff, not on building an open-ended transformation system like Autodidax. It's also pretty close to the content in these lecture slides [2] and this talk [3]. But the autodiff in Autodidax is more sophisticated and reflects clearer thinking. In particular, Autodidax shows how to implement forward and reverse modes using only one set of linearization rules (as in [4]).
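To make the "one set of linearization rules" idea concrete, here is a toy sketch in plain Python. It is not Autodidax's code: the program format and the names `LINEARIZE_RULES`, `linearize`, `forward` and `reverse` are made up for illustration. Each primitive's rule returns its output plus the partial derivatives of that output, which for scalar primitives is the whole linear map at that point. Forward mode pushes tangents through those maps in program order; reverse mode applies their transposes in reverse order.

```python
import math

# Hypothetical linearization rules: each returns (primal output, partials),
# where partials[i] is d(output)/d(input i). For scalar primitives this list
# of coefficients fully describes the linear map at that point.
LINEARIZE_RULES = {
    "add": lambda x, y: (x + y, [1.0, 1.0]),
    "mul": lambda x, y: (x * y, [y, x]),
    "sin": lambda x: (math.sin(x), [math.cos(x)]),
}

def linearize(program, inputs):
    """Run a straight-line program once, recording each step's linear map.

    `program` is a list of (out_var, primitive, in_vars) triples."""
    env = dict(inputs)
    linear_steps = []
    for out, prim, args in program:
        value, partials = LINEARIZE_RULES[prim](*(env[a] for a in args))
        env[out] = value
        linear_steps.append((out, args, partials))
    return env, linear_steps

def forward(linear_steps, tangents):
    """Forward mode (JVP): apply each linear map in program order."""
    t = dict(tangents)
    for out, args, partials in linear_steps:
        t[out] = sum(p * t[a] for p, a in zip(partials, args))
    return t

def reverse(linear_steps, out_var):
    """Reverse mode (VJP): apply each transposed linear map in reverse order,
    accumulating cotangents for variables that are used more than once."""
    ct = {out_var: 1.0}
    for out, args, partials in reversed(linear_steps):
        g = ct.pop(out, 0.0)
        for p, a in zip(partials, args):
            ct[a] = ct.get(a, 0.0) + p * g
    return ct

# f(x, y) = sin(x * y) + x, evaluated and linearized once at (2, 3).
program = [("a", "mul", ["x", "y"]), ("b", "sin", ["a"]), ("f", "add", ["b", "x"])]
env, steps = linearize(program, {"x": 2.0, "y": 3.0})
print(forward(steps, {"x": 1.0, "y": 0.0})["f"])  # df/dx = 3*cos(6) + 1
print(reverse(steps, "f"))  # dict with 'x': 3*cos(6) + 1 and 'y': 2*cos(6)
```

Storing each linear map as a coefficient list only works because every value here is a scalar. Roughly speaking, Autodidax and [4] instead keep the linear part as a program and transpose that program, which is what lets the same idea extend to array-valued primitives.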

There's an even smaller and more recent variant [5], a single ~100-line file for reverse-mode AD on top of NumPy, which was live-coded during a lecture. There's no explanatory material to go with it, though.

[0] https://github.com/mattjj/autodidact

[1] https://github.com/hips/autograd

[2] https://www.cs.toronto.edu/~rgrosse/courses/csc321_2018/slid...

[3] http://videolectures.net/deeplearning2017_johnson_automatic_...

[4] https://arxiv.org/abs/2204.10923

[5] https://gist.github.com/mattjj/52914908ac22d9ad57b76b685d19a...


I find the solutions from https://github.com/qobi/AD-Rosetta-Stone/ to be very helpful, particularly for representing forward and backward mode automatic differentiation using a functional approach.

I used this code as inspiration for a functional-only implementation (without references/pointers) in Mercury: https://github.com/mclements/mercury-ad
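One way to write reverse mode in that purely functional style is to pair every value with a backpropagator closure and to combine gradients by building new dicts rather than mutating an accumulator. The sketch below is plain Python and is not a port of the AD-Rosetta-Stone or mercury-ad code; `Val`, `merge`, `var`, `add`, `mul`, `sin` and `grad` are hypothetical names chosen for illustration.

```python
import math
from typing import Callable, Dict, NamedTuple

Grad = Dict[str, float]

class Val(NamedTuple):
    """A primal value plus a backpropagator: a pure function mapping this
    value's cotangent to cotangents of the named input variables."""
    primal: float
    back: Callable[[float], Grad]

def merge(g1: Grad, g2: Grad) -> Grad:
    """Sum two gradient dicts into a new dict; neither argument is mutated."""
    return {k: g1.get(k, 0.0) + g2.get(k, 0.0) for k in g1.keys() | g2.keys()}

def var(name: str, x: float) -> Val:
    # An input variable passes its cotangent straight through under its name.
    return Val(x, lambda ct: {name: ct})

def add(a: Val, b: Val) -> Val:
    return Val(a.primal + b.primal, lambda ct: merge(a.back(ct), b.back(ct)))

def mul(a: Val, b: Val) -> Val:
    # Each factor receives the cotangent scaled by the other factor.
    return Val(a.primal * b.primal,
               lambda ct: merge(a.back(ct * b.primal), b.back(ct * a.primal)))

def sin(a: Val) -> Val:
    return Val(math.sin(a.primal), lambda ct: a.back(ct * math.cos(a.primal)))

def grad(f: Callable[..., Val], **inputs: float) -> Grad:
    """Evaluate f on fresh variables, then pull a cotangent of 1.0 back."""
    out = f(**{k: var(k, v) for k, v in inputs.items()})
    return out.back(1.0)
```

For example, `grad(lambda x, y: add(sin(mul(x, y)), x), x=2.0, y=3.0)` returns a dict whose `'x'` entry is 3*cos(6) + 1 and whose `'y'` entry is 2*cos(6). This naive version re-runs a shared subexpression's backpropagator once per use, so it can take exponential time on heavily shared graphs; more careful implementations avoid that.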


Automatic differentiation is a huge field going back to at least the 1960s, so you’re right that this is an idiosyncratic way to do it. If you want to learn about it more generally, “automatic differentiation” is the phrase to google. Tinygrad and micrograd are particularly small libraries you might enjoy looking at.


Thanks! It seems like the direct route should be something like "turn the function into its AST, apply derivative rules to transform the AST, turn the result back into a function". And this JAX post doesn't really speak in those terms, at least not directly.
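That route can be sketched directly. Below is a minimal symbolic differentiator in plain Python over a made-up tuple-based expression format (not Python's `ast` module and not JAX's jaxprs): `d` applies the sum, product and chain rules to build a new tree, and `to_function` prints the tree back out as Python source and `eval`s it into a function. All three function names are hypothetical.

```python
import math

# Made-up expression AST: a number constant, the variable name "x", or a
# tuple (op, *args) with op in {"+", "*", "sin", "cos"}.

def d(e):
    """Apply derivative rules to e, returning the AST of de/dx."""
    if isinstance(e, (int, float)):
        return 0.0
    if e == "x":
        return 1.0
    op, *args = e
    if op == "+":
        a, b = args
        return ("+", d(a), d(b))
    if op == "*":  # product rule
        a, b = args
        return ("+", ("*", d(a), b), ("*", a, d(b)))
    if op == "sin":  # chain rule
        (a,) = args
        return ("*", ("cos", a), d(a))
    if op == "cos":
        (a,) = args
        return ("*", ("*", -1.0, ("sin", a)), d(a))
    raise ValueError(f"no derivative rule for {op!r}")

def to_source(e):
    """Print an AST back out as a Python expression string."""
    if isinstance(e, (int, float)):
        return repr(e)
    if e == "x":
        return "x"
    op, *args = e
    if op in ("+", "*"):
        a, b = args
        return f"({to_source(a)} {op} {to_source(b)})"
    return f"math.{op}({to_source(args[0])})"

def to_function(e):
    """Turn an AST back into a callable Python function of x."""
    return eval(f"lambda x: {to_source(e)}", {"math": math})
```

For example, `to_function(d(("*", "x", ("sin", "x"))))(2.0)` evaluates sin(2) + 2*cos(2). Because nothing is simplified, the derivative trees grow quickly (the classic expression-swell problem of naive symbolic differentiation), which is one motivation for differentiating step-by-step programs instead of closed-form expressions.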


> the direct route should be something like "turn the function into its AST, apply derivative rules to transform the AST, turn the result back into a function"

A couple points on that:

1) The most direct route can be even simpler (what's called forward mode differentiation)! You want the derivative of a function in some direction at some point, and you can do that by just passing in the point and the direction. If every function involved knows how to transform the point and transform the direction, then you just evaluate it step by step, no transformations required. This is the "jvp" approach in OP.

2) Something that is often misunderstood about JAX is that JAX isn't just about taking derivatives. A large part of it is just about transforming functions. Hence its idiosyncrasies. It turns out one of the transformations you can do is exactly what you said: transform it into its AST (called jaxprs IIRC), then transform that into whatever you want (gradients, inverses, parallel computations, JIT compile it, whatever), then turn that back into a function. And that's exactly how the linked post does reverse mode differentiation a couple pages in (IIRC). That flexibility is both what makes JAX's approach so interesting, and what makes JAX such a PITA to debug.
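Point 1 can be sketched with dual numbers: each value carries a point and a direction, and every primitive knows how to transform both, so one ordinary evaluation returns the directional derivative. This is a hypothetical toy in plain Python; `Dual`, `lift`, `sin` and `jvp` are made up here, although the `(f, primals, tangents)` argument order mirrors `jax.jvp`.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class Dual:
    """A point (primal) paired with a direction (tangent)."""
    primal: float
    tangent: float

    def __add__(self, other):
        other = lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    def __mul__(self, other):
        other = lift(other)
        # Product rule applied to the direction.
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)

    # Addition and multiplication commute, so `3 * d` can reuse the same code.
    __radd__ = __add__
    __rmul__ = __mul__

def lift(v):
    """Treat a plain number as a constant, i.e. with zero tangent."""
    return v if isinstance(v, Dual) else Dual(float(v), 0.0)

def sin(v: Dual) -> Dual:
    # Chain rule: the direction is scaled by cos at the point.
    return Dual(math.sin(v.primal), math.cos(v.primal) * v.tangent)

def jvp(f, primals, tangents):
    """Evaluate f at `primals` and its directional derivative along
    `tangents` in one ordinary forward pass."""
    out = f(*(Dual(p, t) for p, t in zip(primals, tangents)))
    return out.primal, out.tangent

# f(x, y) = sin(x * y) + x at (2, 3) in the direction (1, 0), i.e. df/dx.
value, deriv = jvp(lambda x, y: sin(x * y) + x, (2.0, 3.0), (1.0, 0.0))
```

Here `f` is simply run once on `Dual` inputs; nothing is traced or rewritten, which is the sense in which no transformation is required.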
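Point 2's pipeline (trace the function into an IR, transform the IR, turn it back into a function) can also be sketched in a few dozen lines. Everything here is a made-up toy loosely inspired by the idea of a jaxpr, not JAX's actual data structures: `Tracer`, `make_ir`, `jvp_ir`, `to_function`, `PRIMS`, `JVP_RULES` and the tuple-based IR are hypothetical. Running the function on tracers records each primitive; the JVP transformation is written as an interpreter over the IR, and tracing that interpreter yields a new IR.

```python
import math
import operator
from typing import List, Tuple

# Toy IR (format made up here): input names, a list of equations
# (out_var, primitive, in_vars), and the output variable name.
Eqn = Tuple[str, str, Tuple[str, ...]]

class Tracer:
    """Stands in for a number while f runs, appending each primitive to eqns."""
    def __init__(self, name: str, eqns: List[Eqn]):
        self.name, self.eqns = name, eqns

    def _emit(self, prim: str, *others: "Tracer") -> "Tracer":
        out = f"t{len(self.eqns)}"
        self.eqns.append((out, prim, tuple(t.name for t in (self, *others))))
        return Tracer(out, self.eqns)

    # Only tracer-with-tracer arithmetic is supported in this sketch.
    def __add__(self, other): return self._emit("add", other)
    def __mul__(self, other): return self._emit("mul", other)

# Primitives that record when given a Tracer and compute when given a float.
def sin(v): return v._emit("sin") if isinstance(v, Tracer) else math.sin(v)
def cos(v): return v._emit("cos") if isinstance(v, Tracer) else math.cos(v)
def neg(v): return v._emit("neg") if isinstance(v, Tracer) else -v

PRIMS = {"add": operator.add, "mul": operator.mul, "sin": sin, "cos": cos, "neg": neg}

def make_ir(f, n_args):
    """Step 1: turn a Python function into IR by running it on tracers."""
    eqns: List[Eqn] = []
    ins = [f"in{i}" for i in range(n_args)]
    out = f(*(Tracer(name, eqns) for name in ins))
    return ins, eqns, out.name

def to_function(ir):
    """Step 3: turn IR back into a callable by interpreting it."""
    ins, eqns, out = ir
    def run(*args):
        env = dict(zip(ins, args))
        for o, prim, vs in eqns:
            env[o] = PRIMS[prim](*(env[v] for v in vs))
        return env[out]
    return run

# Forward-mode rule per primitive: (primal inputs, tangent inputs) -> tangent out.
JVP_RULES = {
    "add": lambda xs, ts: ts[0] + ts[1],
    "mul": lambda xs, ts: ts[0] * xs[1] + xs[0] * ts[1],
    "sin": lambda xs, ts: cos(xs[0]) * ts[0],
    "cos": lambda xs, ts: neg(sin(xs[0])) * ts[0],
    "neg": lambda xs, ts: neg(ts[0]),
}

def jvp_ir(ir):
    """Step 2: IR -> new IR taking (primals..., tangents...) and returning
    the output tangent. Written as an interpreter, then traced itself."""
    ins, eqns, out = ir
    n = len(ins)
    def jvp_fn(*args):
        primal = dict(zip(ins, args[:n]))
        tangent = dict(zip(ins, args[n:]))
        for o, prim, vs in eqns:
            xs = [primal[v] for v in vs]
            ts = [tangent[v] for v in vs]
            primal[o] = PRIMS[prim](*xs)
            tangent[o] = JVP_RULES[prim](xs, ts)
        return tangent[out]
    return make_ir(jvp_fn, 2 * n)

ir = make_ir(lambda x, y: sin(x * y) + x, 2)
print(ir[1])  # [('t0', 'mul', ('in0', 'in1')), ('t1', 'sin', ('t0',)), ('t2', 'add', ('t1', 'in0'))]
d_dx = to_function(jvp_ir(ir))  # takes (x, y, dx, dy)
print(d_dx(2.0, 3.0, 1.0, 0.0))  # 3*cos(6) + 1
```

Because `jvp_ir` returns ordinary IR, its output could in principle be handed to another transformation or to a different backend in the same way, which is the kind of composability point 2 describes.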



