Backprop is just an implementation detail of automatic differentiation: essentially, it's how you set up the chain rule application for your problem.
JAX is able to differentiate arbitrary Python code (so long as it uses JAX for the numeric stuff) automatically, so backprop is abstracted away.
If you have the forward model written, to train it all you have to do is wrap it in whatever loss function you want, then use JAX's `grad` with respect to the model parameters, and you can use that to find the optimum with your favorite gradient optimization algorithm.
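A minimal sketch of that workflow, assuming a made-up linear forward model and mean-squared-error loss (the function and variable names here are just illustrative):

```python
import jax
import jax.numpy as jnp

# Hypothetical forward model: a simple linear map from inputs to predictions.
def forward(params, x):
    w, b = params
    return x @ w + b

# Wrap the forward model in a loss function (MSE here).
def loss(params, x, y):
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

# grad differentiates with respect to the first argument (the params pytree).
grad_loss = jax.grad(loss)

# Toy data and initial parameters.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (8, 3))
y = jnp.ones((8,))
params = (jnp.zeros((3,)), 0.0)

# One step of plain gradient descent; any gradient-based optimizer works here.
lr = 0.01
g_w, g_b = grad_loss(params, x, y)
params = (params[0] - lr * g_w, params[1] - lr * g_b)
```

`grad_loss` returns gradients with the same tuple structure as `params`, so the update step mirrors the parameter layout.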
This is why JAX is so awesome. Differentiable programming means you only have to think about problems in terms of the forward pass and then you can trivially get the derivative of that function without having to worry about the implementation details.
I hadn't heard about JAX before, but I've been tinkering in PyTorch. Would I also be able to switch the np arrays here to torch tensors, call `.backward()`, and get roughly the same benefits as JAX, or how does it differ in this regard?
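For comparison, the pattern described in the question would look roughly like this in PyTorch (names and the toy problem are hypothetical; PyTorch tracks gradients on tensors rather than transforming a function like JAX's `grad`):

```python
import torch

# Same toy setup: tensors with requires_grad=True become differentiable leaves.
x = torch.randn(8, 3)
y = torch.ones(8)

w = torch.zeros(3, requires_grad=True)
b = torch.zeros(1, requires_grad=True)

# Forward pass and scalar loss (MSE).
pred = x @ w + b
loss = torch.mean((pred - y) ** 2)

# backward() accumulates d(loss)/d(param) into each leaf tensor's .grad.
loss.backward()

# One manual gradient-descent step (what torch.optim does for you).
with torch.no_grad():
    w -= 0.01 * w.grad
    b -= 0.01 * b.grad
```

The main difference is style: PyTorch records operations on tensors and `.backward()` walks that tape, while JAX differentiates the function itself, returning a new function you can call, `jit`, or compose.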