Autodidax: JAX core from scratch (jax.readthedocs.io)
107 points by sva_ on Feb 11, 2023 | 12 comments


I love this. I wish every big library had something like this. It helped me contribute to JAX in the past, and is a great educational resource and source of inspiration for my own tools.

I’ve tried to find something similar for pytorch and numpy in the past and was let down.


But seriously, for open source projects that are actually looking for contributors (as opposed to companies that simply open-source their developer product), there is no better resource.

To understand a framework, you need a mental model of it. Good documentation is helpful, but it seldom walks through why specific design choices were necessary.


> I’ve tried to find something similar for pytorch and numpy in the past and was let down.

Well, I've got a treat for you then:

https://minitorch.github.io/


I don't mean how torch's API could be implemented; I mean how PyTorch itself implements it. Do you know which of the two this is?


I think the closest thing for pytorch is the Karpathy video:

https://www.youtube.com/watch?v=VMj-3S1tku0

There's also an old book on the internals of numpy - not sure how out of date it is though.


Same! I looked into how pytest worked and got bogged down in its plugin system (which is great, but very distracting when trying to read the code). I've always been curious how the assertion rewriting works.


This is great if you want to learn JAX, but it doesn't seem like the most efficient way to learn how to implement symbolic differentiation or the automatic generation of derivative functions on a computer. Is there something that takes a more direct route, maybe with Scheme or something? (I don't care about the language, just whatever presents the least bureaucracy and overhead.)


I'm sure there's a lot of good material around, but here are some links that are conceptually very close to the linked Autodidax. (Disclaimer: I wrote Autodidax and some of these other materials.)

There's Autodidact [0], a predecessor to Autodidax, which was a simplified implementation of the original Autograd [1]. It focuses on reverse-mode autodiff, not building an open-ended transformation system like Autodidax. It's also pretty close to the content in these lecture slides [2] and this talk [3]. But the autodiff in Autodidax is more sophisticated and reflects clearer thinking. In particular, Autodidax shows how to implement forward- and reverse-modes using only one set of linearization rules (like in [4]).

There's an even smaller and more recent variant [5], a single ~100 line file for reverse-mode AD on top of NumPy, which was live-coded during a lecture. There's no explanatory material to go with it though.

[0] https://github.com/mattjj/autodidact

[1] https://github.com/hips/autograd

[2] https://www.cs.toronto.edu/~rgrosse/courses/csc321_2018/slid...

[3] http://videolectures.net/deeplearning2017_johnson_automatic_...

[4] https://arxiv.org/abs/2204.10923

[5] https://gist.github.com/mattjj/52914908ac22d9ad57b76b685d19a...
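To make the "one set of linearization rules" point concrete, here is a minimal sketch written for this thread, not taken from Autodidax, Autodidact, or [4]. Each primitive only knows how to emit its linearized (tangent) equations. Evaluating that linear program forwards gives a JVP, and walking the same equations backwards with transposed rules gives a VJP, so reverse mode needs no separate derivative rules. The names `linearize`, `jvp_from_linear`, `vjp_by_transpose` and the tuple-based IR are hypothetical, and only `+`, `*` and `sin` on traced values (no plain-number constants) are supported.

```python
import math

# Hypothetical linear IR: a list of equations over tangent variable names.
#   ("scale", out, c, x)  means  out = c * x   (c is a float fixed by the primal point)
#   ("add",   out, x, y)  means  out = x + y


class LinearTrace:
    """Collects the linear (tangent) equations emitted while f runs."""
    def __init__(self):
        self.eqns = []
        self.count = 0

    def scale(self, c, x):
        self.count += 1
        out = f"t{self.count}"
        self.eqns.append(("scale", out, c, x))
        return out

    def add(self, x, y):
        self.count += 1
        out = f"t{self.count}"
        self.eqns.append(("add", out, x, y))
        return out


class LinVal:
    """A primal value paired with the name of its tangent variable."""
    def __init__(self, primal, tvar, trace):
        self.primal, self.tvar, self.trace = primal, tvar, trace

    def __add__(self, other):
        # Linearization rule for add: d(x + y) = dx + dy
        return LinVal(self.primal + other.primal,
                      self.trace.add(self.tvar, other.tvar), self.trace)

    def __mul__(self, other):
        # Linearization rule for mul: d(x * y) = y * dx + x * dy
        t = self.trace
        a = t.scale(other.primal, self.tvar)
        b = t.scale(self.primal, other.tvar)
        return LinVal(self.primal * other.primal, t.add(a, b), t)


def sin(x):
    # Linearization rule for sin: d(sin x) = cos(x) * dx
    return LinVal(math.sin(x.primal),
                  x.trace.scale(math.cos(x.primal), x.tvar), x.trace)


def linearize(f, *primals):
    """Run f once; return its primal output and a linear program for its tangent map."""
    trace = LinearTrace()
    invars = [f"in{i}" for i in range(len(primals))]
    out = f(*[LinVal(p, v, trace) for p, v in zip(primals, invars)])
    return out.primal, (invars, trace.eqns, out.tvar)


def jvp_from_linear(program, tangents):
    """Forward mode: evaluate the linear program on input tangents."""
    invars, eqns, outvar = program
    env = dict(zip(invars, tangents))
    for eqn in eqns:
        if eqn[0] == "scale":
            _, out, c, x = eqn
            env[out] = c * env[x]
        else:
            _, out, x, y = eqn
            env[out] = env[x] + env[y]
    return env[outvar]


def vjp_by_transpose(program, cotangent=1.0):
    """Reverse mode: walk the same equations backwards using their transposes."""
    invars, eqns, outvar = program
    ct = {outvar: cotangent}
    for eqn in reversed(eqns):
        if eqn[0] == "scale":   # transpose of out = c*x: ct[x] += c * ct[out]
            _, out, c, x = eqn
            ct[x] = ct.get(x, 0.0) + c * ct.get(out, 0.0)
        else:                   # transpose of out = x+y: ct[out] flows to both x and y
            _, out, x, y = eqn
            g = ct.get(out, 0.0)
            ct[x] = ct.get(x, 0.0) + g
            ct[y] = ct.get(y, 0.0) + g
    return [ct.get(v, 0.0) for v in invars]
```

For `f = lambda x, y: sin(x) * y + x` at `(1.0, 2.0)`, `jvp_from_linear(lin, [1.0, 0.0])` should equal `2*cos(1) + 1` and `vjp_by_transpose(lin)` should equal `[2*cos(1) + 1, sin(1)]`: both modes read the same equations, and only the direction of traversal differs.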


I find the solutions from https://github.com/qobi/AD-Rosetta-Stone/ to be very helpful, particularly for representing forward and backward mode automatic differentiation using a functional approach.

I used this code as inspiration for a functional-only implementation (without references/pointers) in Mercury: https://github.com/mclements/mercury-ad


Automatic differentiation is a huge field going back to at least the 1960s, so you’re right that this is an idiosyncratic way to do it. If you want to learn about it more generally, “automatic differentiation” is the phrase to google. Tinygrad and micrograd are particularly small libraries you might enjoy looking at.


Thanks! It seems like the direct route should be something like "turn the function into its AST, apply derivative rules to transform the AST, turn the result back into a function". This JAX post doesn't really speak in those terms, at least not directly.
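The "AST route" described above can be sketched in a few dozen lines. This is a hypothetical illustration, not code from JAX or Autodidax: Python operator overloading turns a function into an expression tree, `diff` applies the sum, product and chain rules to that tree, and a closure over `evaluate` turns the result back into a callable. The node classes and function names are made up for the example, and only `+`, `*`, `Sin` and `Cos` are supported.

```python
import math
from dataclasses import dataclass


class Expr:
    """Base class whose operators build tree nodes instead of computing numbers."""
    def __add__(self, other):
        return Add(self, lift(other))
    __radd__ = __add__

    def __mul__(self, other):
        return Mul(self, lift(other))
    __rmul__ = __mul__


@dataclass(frozen=True)
class Var(Expr):
    name: str

@dataclass(frozen=True)
class Const(Expr):
    value: float

@dataclass(frozen=True)
class Add(Expr):
    a: Expr
    b: Expr

@dataclass(frozen=True)
class Mul(Expr):
    a: Expr
    b: Expr

@dataclass(frozen=True)
class Sin(Expr):
    a: Expr

@dataclass(frozen=True)
class Cos(Expr):
    a: Expr


def lift(v):
    return v if isinstance(v, Expr) else Const(float(v))


def diff(e, name):
    """Return a new tree for d(e)/d(name) by applying derivative rules recursively."""
    if isinstance(e, Var):
        return Const(1.0 if e.name == name else 0.0)
    if isinstance(e, Const):
        return Const(0.0)
    if isinstance(e, Add):
        return Add(diff(e.a, name), diff(e.b, name))
    if isinstance(e, Mul):   # product rule
        return Add(Mul(diff(e.a, name), e.b), Mul(e.a, diff(e.b, name)))
    if isinstance(e, Sin):   # chain rule: (sin u)' = cos(u) * u'
        return Mul(Cos(e.a), diff(e.a, name))
    if isinstance(e, Cos):   # (cos u)' = -sin(u) * u'
        return Mul(Const(-1.0), Mul(Sin(e.a), diff(e.a, name)))
    raise TypeError(f"unknown node {e!r}")


def evaluate(e, env):
    if isinstance(e, Var):
        return env[e.name]
    if isinstance(e, Const):
        return e.value
    if isinstance(e, Add):
        return evaluate(e.a, env) + evaluate(e.b, env)
    if isinstance(e, Mul):
        return evaluate(e.a, env) * evaluate(e.b, env)
    if isinstance(e, Sin):
        return math.sin(evaluate(e.a, env))
    if isinstance(e, Cos):
        return math.cos(evaluate(e.a, env))
    raise TypeError(f"unknown node {e!r}")


def symbolic_grad(f, names):
    """Trace f into a tree, differentiate per variable, and wrap each result as a function."""
    tree = f(*[Var(n) for n in names])
    derivs = [diff(tree, n) for n in names]
    # d=d binds each derivative tree now, avoiding Python's late-binding closures.
    return [lambda *args, d=d: evaluate(d, dict(zip(names, args))) for d in derivs]
```

For example, `dfdx, dfdy = symbolic_grad(lambda x, y: Sin(x) * y + x, ["x", "y"])` gives `dfdx(1.0, 2.0) == 2*cos(1) + 1` (up to rounding). Note that `diff` never simplifies, so trees full of `Mul(Const(0.0), ...)` terms grow quickly with nesting; that expression swell is one reason AD systems propagate numbers rather than symbolic derivatives.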


> the direct route should be something like "turn the function into its AST, apply derivative rules to transform the AST, turn the result back into a function"

A couple points on that:

1) The most direct route can be even simpler (what's called forward mode differentiation)! You want the derivative of a function in some direction at some point, and you can do that by just passing in the point and the direction. If every function involved knows how to transform the point and transform the direction, then you just evaluate it step by step, no transformations required. This is the "jvp" approach in OP.
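Point 1 can be illustrated with dual numbers: every value carries its point and its direction, and each primitive updates both, so one ordinary evaluation of `f` yields the directional derivative. This is a minimal sketch, not Autodidax's implementation. The `jvp(f, primals, tangents)` helper below only loosely mirrors the call shape of `jax.jvp`, and the `Dual` class and `sin` wrapper are hypothetical.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    primal: float   # the point
    tangent: float  # the direction being pushed forward

    def __add__(self, other):
        other = lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)
    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        # Product rule applied to the direction.
        return Dual(self.primal * other.primal,
                    self.tangent * other.primal + self.primal * other.tangent)
    __rmul__ = __mul__


def lift(v):
    """Treat plain numbers as constants, i.e. with zero tangent."""
    return v if isinstance(v, Dual) else Dual(float(v), 0.0)


def sin(x):
    # Each primitive transforms the point and the direction together.
    x = lift(x)
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)


def jvp(f, primals, tangents):
    """Evaluate f once on (point, direction) pairs; no program transformation needed."""
    out = f(*[Dual(p, t) for p, t in zip(primals, tangents)])
    return out.primal, out.tangent
```

With `f = lambda x, y: sin(x) * y + x`, calling `jvp(f, (1.0, 2.0), (1.0, 0.0))` returns `sin(1)*2 + 1` and `2*cos(1) + 1`. Getting a full gradient this way costs one evaluation per input direction, which is why reverse mode matters for functions with many inputs.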

2) Something that is often misunderstood about JAX is that JAX isn't just about taking derivatives. A large part of it is just about transforming functions. Hence its idiosyncrasies. It turns out one of the transformations you can do is exactly what you said: transform it into its AST (called jaxprs IIRC), then transform that into whatever you want (gradients, inverses, parallel computations, JIT compile it, whatever), then turn that back into a function. And that's exactly how the linked post does reverse mode differentiation a couple pages in (IIRC). That flexibility is both what makes JAX's approach so interesting, and what makes JAX such a PITA to debug.
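The "transform it into an IR, then turn that into whatever you want" idea from point 2 can be sketched with a tiny tracer. This is a hypothetical toy, not JAX's machinery (real jaxprs carry types and shapes, and `jax.make_jaxpr` prints the real thing). Running `f` on tracer objects records a flat, jaxpr-like list of equations; a single interpreter then gives that program meaning, and swapping its rule table turns plain evaluation into a non-derivative transformation, here a toy elementwise batching in the spirit of `vmap`.

```python
import math


class Tracer:
    """Stands in for a value while f runs; operations are recorded, not computed."""
    def __init__(self, builder, name):
        self.builder, self.name = builder, name

    def __add__(self, other):
        return self.builder.emit("add", self, other)

    def __mul__(self, other):
        return self.builder.emit("mul", self, other)


class Builder:
    def __init__(self):
        self.eqns = []   # each equation: (out_name, primitive, [in_names])
        self.count = 0

    def fresh(self):
        self.count += 1
        return Tracer(self, f"v{self.count}")

    def emit(self, prim, *args):
        out = self.fresh()
        self.eqns.append((out.name, prim, [a.name for a in args]))
        return out


def sin(x):
    return x.builder.emit("sin", x)


def make_program(f, n_args):
    """Trace f into a flat, jaxpr-like list of equations."""
    b = Builder()
    invars = [b.fresh() for _ in range(n_args)]
    out = f(*invars)
    return [v.name for v in invars], b.eqns, out.name


def eval_program(program, *args, rules):
    """Turn the program back into a computation, using `rules` to interpret each primitive."""
    invars, eqns, outvar = program
    env = dict(zip(invars, args))
    for out, prim, ins in eqns:
        env[out] = rules[prim](*[env[v] for v in ins])
    return env[outvar]


# Ordinary evaluation rules.
EVAL_RULES = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b, "sin": math.sin}


def batch_rule(fn):
    # Lift a scalar rule to act elementwise on equal-length lists (a toy "vmap").
    return lambda *cols: [fn(*row) for row in zip(*cols)]


BATCH_RULES = {prim: batch_rule(fn) for prim, fn in EVAL_RULES.items()}
```

Tracing `lambda x, y: sin(x) * y + x` produces three equations (`sin`, `mul`, `add`). The same program runs on scalars with `rules=EVAL_RULES` or on lists with `rules=BATCH_RULES`; a JVP, transpose, or compilation pass would be another interpreter over the same equations, which is also why a bug can hide several layers away from the user's code.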



