ComplexSystems 2 days ago

This article is saying that it can be numerically unstable in certain situations, not that it's theoretically incorrect.

  • adgjlsfhk1 2 days ago

    This is less about numerical instability and more that iterative algorithms use error control to bound their error, but when you run AD on them you are ADing the approximation, and a derivative of an approximation can be arbitrarily different from an approximation of a derivative.

    • ogogmad 2 days ago

      That makes more sense. The title is flat out wrong IMO.

      • adgjlsfhk1 2 days ago

        I think it is correct. Lots of people view AD as a black box that you can throw algorithms at and get derivatives out, and this shows that that isn't true.

        • wakawaka28 2 days ago

          If you wrote code that failed to compile, you wouldn't impulsively call your compiler incorrect. This title sounds like it puts the blame in the wrong place. You can get error accumulation from even a basic calculation in a loop. We could try to solve these problems but it's not the algorithm's fault you don't know what you're doing.

  • omnicognate 2 days ago

    Yeah, perhaps the actual title would be better: "The Numerical Analysis of Differentiable Simulation". (Rather than the subtitle, which is itself a poor rewording of the actual subtitle in the video.)

  • goosedragons 2 days ago

    It can be both. A mistake in AD primitives can lead to theoretically incorrect derivatives. With the system I use I have run into a few scenarios where edge cases aren't totally covered, leading to the wrong result.

    I have also run into numerical instability too.

    • froobius 2 days ago

      > A mistake in AD primitives can lead to theoretically incorrect derivatives

      Ok but that's true of any program. A mistake in the implementation of the program can lead to mistakes in the result of the program...

      • goosedragons 2 days ago

        That's true! But it's also true that any program dealing with floats can run into numerical instability if care isn't taken to avoid it, no?

        It's also not necessarily immediately obvious that the derivatives ARE wrong if the implementation is wrong.

        • srean 2 days ago

          > It's also not necessarily immediately obvious that the derivatives ARE wrong if the implementation is wrong.

          It's neither full proof nor fool proof, but an absolute must is a check that the loss function is decreasing. It quickly detects a common error: the sign coming out wrong in my gradient call. Part of the good practice one learns in grad school.
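
          A minimal sketch of that sanity check (my own toy loss and names, nothing from the article): gradient descent on f(w) = (w - 3)^2, with an assertion that trips immediately if the gradient's sign is flipped.

            def loss(w):
                return (w - 3.0) ** 2

            def grad(w, sign_bug=False):
                g = 2.0 * (w - 3.0)
                return -g if sign_bug else g   # sign_bug simulates the classic mistake

            w, lr = 0.0, 0.1
            prev = loss(w)
            for step in range(20):
                w -= lr * grad(w, sign_bug=True)   # set to False and the check passes
                cur = loss(w)
                assert cur <= prev + 1e-12, f"loss went up at step {step}"
                prev = cur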

        • froobius 2 days ago

          You can pretty concretely and easily check that the AD primitives are correct by comparing them to numerical differentiation.
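
          For instance (a toy sketch with made-up functions, not tied to any particular AD system), comparing a hand-written derivative against a central finite difference:

            import numpy as np

            def f(x):
                return np.sin(x) * np.exp(x)

            def df_claimed(x):                 # stand-in for what your AD system returns
                return (np.cos(x) + np.sin(x)) * np.exp(x)

            def df_numeric(x, h=1e-6):         # central finite difference
                return (f(x + h) - f(x - h)) / (2 * h)

            x = 0.7
            assert abs(df_claimed(x) - df_numeric(x)) < 1e-6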

      • godelski 2 days ago

        I haven't watched the video but the text says they're getting like 60+% error on simple linear ODEs, which is pretty problematic.

        You're right, but the scale of the problem seems to be the issue

deckar01 2 days ago

I have been using sympy while learning electron physics to automatically integrate linear charge densities. It works great symbolically, but often fails silently when the symbols are substituted with floats before integration.

https://github.com/sympy/sympy/issues/27675

Lerc 2 days ago

With floats getting smaller and smaller in ML, it's hard to imagine anyone failing to learn this as one of their early experiences in the field.

The focus should not be on the possibility of error, but on managing the error to be within acceptable limits. There's an hour-long video there, and it's 3am, so I'm not sure how much of this covers management. Anyone familiar with it care to say?

  • Legend2440 2 days ago

    Floats are getting smaller and smaller only in neural networks - which works ok because NNs are designed for nice easy backprop. Normalization and skip connections really help avoid numerical instability.

    The author is talking about ODEs and PDEs, which come from physical systems and don't have such nice properties.

    • storus 2 days ago

      Skip connections are about preserving identity, not about making things more numerically stable.

      • D-Machine 2 days ago

        I'm not sure your claim is correct, nor particularly precise. Skip connections do propagate feature maps from earlier to later/deeper in the network, and one can model this mathematically as using an identity function to do this, but I wouldn't say that this is "about preserving identity". Also, the reason the identity is useful here is because it prevents vanishing gradients (due to the identity function having an identity tensor as the gradient). Arguably, preventing vanishing gradients is exactly about numerical stability, so I would in fact argue that skip connections are there to make gradient descent of certain architectures more numerically stable.

        EDIT: Just to provide some research support: https://arxiv.org/html/2405.01725v1#S5.SS4.5.1.1. I.e. skip connections do a lot of things, but most of these things boil down to making gradients better conditioned. Insofar as one problem of poor conditioning is gradient vanishing, and vanishing is really only a problem because of the limited precision of floating point, it is hard for me not to think skip connections are primarily there for numerical stability, broadly.

        One case supporting your point, though, where one might argue the "main" reason for skip connections is propagating information rather than for stability, is in LONG skip connections, as in U-Nets. There, since feature maps shrink with network depth, long skip connections are needed to bring high-resolution information from earlier layers to deeper ones.

        • godelski 2 days ago

          You're correct, I also added some info in a sibling comment that might interest you. But you might also be interested in reading the DenseNet paper[0]. While it's convs, the ideas still apply. There's actually a lot of research in this topic though they'll be in more "theory" papers and the ideas are slow to propagate into the mainstream works.

          [0] https://arxiv.org/abs/1608.06993

          • D-Machine 2 days ago

            Yup, I am pretty up to date with the literature (mostly) and read about DenseNet a long time ago (and have trained and tuned quite a few of them too). I've also done enough experimentation with tuning custom CNN architectures (on smaller medical datasets, mind you) where the inclusion of residual connections was a hyperparameter to be tuned. It is pretty obvious that with deeper networks, turning off the skip connections just makes your loss curves hit a premature (and bad) plateau far more often than not.

            Residual connections in CNNs clearly don't change the representational capacity, and don't obviously induce any bias that seems beneficial. I.e. if you needed information from earlier feature maps, the optimizer could surely just learn to make intermediate layers preserve that information (by having some conv-layer channels just basically be identities). Actually, if you inspect a lot of feature maps through visualizations, it is pretty clear that information is lost very gradually, over many, many layers, and it just isn't plausible that typical (short) skip connections are really meaningfully helpful because they "preserve information".

            This is in stark contrast to U-Nets, where the skips are long enough, and from layers of differing enough resolutions, that it is clear the purpose is very different.

            EDIT: With DenseNet, it is trickier, because some of the residual connections are quite long, and the paper does provide some evidence to suggest that the length is useful. But there are not any long residual connections in Vision Transformers, and it broadly seems like ViTs and modern conv nets (which don't use DenseNet "long" concats) have basically equivalent performance (https://arxiv.org/abs/2310.16764). So the whole idea of long connections being important just doesn't sit right with me. Short, typical residual connections are useful, and this seems most clearly to be about gradients.

            And practically, when suddenly switching on residual connections leads to better training despite an otherwise identical architecture, it is really hard to see them as anything other than a tool for stabilizing optimization.

            • godelski a day ago

                > there are not any long residual connections in Vision Transformers
              
              That's incorrect. If you look at the architecture a little closer you'll see you have very long residuals. The transformer arch is res(norm + attn) + res(norm + FFN). That allows you to just skip a whole layer. Going backwards, we take the first shortcut, skipping the norm + FFN, then the second shortcut, skipping the norm + attn. So you skip the whole transformer! It's pretty unlikely that this will happen in practice, so you can strengthen it by doing dense transformers, but as far as I've seen that isn't really much help, and it does look like the gradient propagates the whole way.
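
              For concreteness, here's a toy pre-norm block in PyTorch (my own sketch, made-up sizes, not from the papers linked here) where the two residual adds are exactly those shortcuts:

                import torch
                import torch.nn as nn

                class Block(nn.Module):
                    def __init__(self, d=64, heads=4):
                        super().__init__()
                        self.norm1 = nn.LayerNorm(d)
                        self.attn = nn.MultiheadAttention(d, heads, batch_first=True)
                        self.norm2 = nn.LayerNorm(d)
                        self.ffn = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

                    def forward(self, x):
                        h = self.norm1(x)
                        x = x + self.attn(h, h, h, need_weights=False)[0]  # res(norm + attn)
                        x = x + self.ffn(self.norm2(x))                    # res(norm + FFN)
                        # following only the two bare "x +" terms backwards skips the whole block
                        return x

                y = Block()(torch.randn(2, 16, 64))  # (batch, tokens, d)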

              I'd recommend looking at the famous 3 Things paper [0]. You'll see what I said in their diagram. It's got some other good gems.

              [0] https://arxiv.org/abs/2203.09795

              • D-Machine a day ago

                Yeah, I am aware that there are sort of long residuals even in classic ViTs, and that, as you say, you can sort of even skip the whole transformer. Like you said, though, this seems very unlikely in practice, and at least, this is a different kind of long residual than in DenseNets or U-Nets (and yes, Dense Transformers, though I know very, very little about these). I.e. the long residual connections in those seem to be far more "direct" and less "sequential" than the "long residuals" in a classic transformer.

                It is hard for me to say what the different consequences for training and gradients are between these two kinds of long residuals; that sounds more like your expertise. But, practically, if you implement your own e.g. DenseNet in torch, with `forward` calls doing Conv layers and adds (or concats), and then implement your own little ViT with multiple MultiheadAttention layers, these really don't feel like the same things at all, in terms of the values you need to save access to and what you pass in to deeper layers. Just doing a bit of research, it seems like these dense residual transformers are being used for super-resolution tasks. This again seems like the U-Net case, in that the functionality of the direct long residuals is again about more efficient information propagation, and less clearly about gradients, whereas the "sequential" long residuals implicit in transformers feel again more like a gradient thing.

                But, I am definitely NOT an expert here, I just have done a lot of practical twiddling with custom architectures in academic research contexts. I've also often worked with smaller datasets and more unusual data (e.g. 3D or 4D images like MRI, fMRI, or multivariate timeseries like continuous bedside monitoring data), often with a limited training budget, so my focus has been more on practical differences than theoretical claims / arguments. The DenseNet and "direct" long residual architectures (e.g. U-Net) tended to be unworkable or inordinately expensive for larger 3D or 4D image data, because you have to hold so many monstrously large tensors in memory (or manually move them between CPU and GPU to avoid this problem) for the long direct skips. The absence of clear performance (or training efficiency) evidence for these architectures made me skeptical of the hand-wavey "feature reuse" claims made in their support, especially when the shorter, more sequential residuals (as in classic ViTs, or, in my case, HighResNet for 3D images: https://arxiv.org/abs/1707.01992) seemed just obviously better practically in almost every way.

                But of course, we still have much to learn about all this!

                • godelski a day ago

                    > you can sort of even skip the whole transformer
                  
                  I don't mean "sort of", I mean literally.

                    > I am definitely NOT an expert here, I just have done a lot of practical twiddling with custom architectures in academic research contexts
                  
                  If a PhD makes me an expert, then I am. My thesis was on the design of neural architectures.

                  • D-Machine a day ago

                    >> I don't mean "sort of", I mean literally.

                    Well, then we disagree, or are talking past each other, and I think writing out the equations and code shows that, as I said, these are really not exactly the same thing. From a code standpoint, direct, long skips require retaining copies of earlier "x" values, and this is a memory cost that is a problem for certain purposes. Mathematically, this also means that in a long, direct skip you are adding in that exact earlier x value.

                    In the sequential or indirect "long skips" of a transformer, this is not the case. Yes, if you write the equations you can see there is a "path" for identity information to theoretically flow from any layer to any layer unmolested, but in practice this is not how it is implemented, and identity information is not flowing through the layers unchanged.
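
                    To make that concrete with a throwaway PyTorch sketch (my own toy code, not from any of the papers mentioned): the direct long skip has to keep the exact early activation around and add it in much deeper, while the sequential residuals only ever add into the running value.

                      import torch
                      import torch.nn as nn

                      f1, f2, f3 = (nn.Linear(32, 32) for _ in range(3))
                      x = torch.randn(8, 32)

                      # sequential residuals (ResNet / transformer style)
                      h = x
                      for f in (f1, f2, f3):
                          h = h + torch.relu(f(h))

                      # direct long skip (U-Net / DenseNet style): the early activation itself
                      # is retained and added (or concatenated) much deeper in
                      saved = x
                      deep = torch.relu(f2(torch.relu(f1(x))))
                      out = torch.relu(f3(deep)) + saved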

                    If everyone thought these subtle differences were irrelevant, then I am not sure why anyone would bother with making a dense residual transformer over the classic transformer. EDIT: nor would many of the papers incorporating special additional long skip connections into various transformer architectures make much sense. The point I was merely making was that long skips generally serve a very different purpose than shorter / classic residual connections.

      • godelski 2 days ago

        Residual connections create numerical stability with respect to the gradient and network updating. Look up the vanishing gradient problem. We wouldn't be able to build these deep networks without them.

        It's easier to see looking at ReLU because you'll clearly get dead neurons: when the gradient hits 0 you can't recover.

        This can also be mitigated by modern activation functions that are smooth, but the other problem is that your gradient signal degrades as you move further down the network. Residuals help a stronger signal propagate through the whole network. You can see this if you run the numbers by hand for even a not-so-deep network.
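
        A quick way to see it numerically (a toy sketch with made-up depth/width, default init, no normalization): stack plain Linear+ReLU blocks with and without a skip and compare the gradient norm that reaches the first layer. The no-skip version should come out orders of magnitude smaller.

          import torch
          import torch.nn as nn

          depth, width = 32, 128
          blocks = nn.ModuleList(
              [nn.Sequential(nn.Linear(width, width), nn.ReLU()) for _ in range(depth)]
          )
          x = torch.randn(8, width)

          def loss(use_residual):
              h = x
              for b in blocks:
                  h = h + b(h) if use_residual else b(h)   # skip connection on/off
              return h.pow(2).mean()

          for use_residual in (False, True):
              for p in blocks.parameters():
                  p.grad = None
              loss(use_residual).backward()
              g0 = blocks[0][0].weight.grad.norm().item()
              print(f"residual={use_residual}: first-layer grad norm = {g0:.3e}")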

        So just saying, "numerically stable" means lots of things...

        • D-Machine 2 days ago

          Well, you can build deep networks without residual connections, as long as you keep normalizing and initialize very carefully. I'm sure there are other clever ways to keep things well conditioned, like fancy reparameterizations or e.g. spectral norms. Just residual connections are dead simple to implement and reason about, and fast, to boot, so it isn't a huge surprise they feature in most of the most successful models.

          And yes, "numerically stable" is a vagueness here. I had implied above that gradient vanishing wouldn't be such a problem with infinite precision, but, actually, this isn't true, if you think about it a bit. You're still going to get stuck on plateaus even with infinite precision. In this sense, if by "numerically stable" one means "with respect to floating point issues", well, then, GP is correct, because that kind of numerical instability is actually just not really a serious issue for training modern networks (at least, certainly not if training with fp32). Instability in optimization is because gradient noise (from e.g. mini-batch sampling) will swamp the useful gradient signal when gradients get too small, i.e. you start doing a random walk.

          • godelski a day ago

            Normalization doesn't fix the vanishing gradient problem. Like you point out, it doesn't go away with infinite precision because that's not the problem. (I've got a Ph.D in this btw)

            • D-Machine a day ago

              Yup, normalization doesn't fix vanishing gradients. It just sort of generally helps with conditioning (as you likely know, given the PhD, probably in part because it helps keep certain Lipschitz constants from being too large), just like typical residual connections generally help with conditioning. But yeah, in that sense, normalization is more about exploding gradients, and residual connections are more about vanishing. A bit sloppy on my part.

ChrisRackauckas 2 days ago

Interesting to see this here! This example is one of the ones mentioned in the appendix of https://arxiv.org/abs/2406.09699, specifically "4.1.2.4 When AD is algorithmically correct but numerically wrong".

If people want a tl;dr, the main idea is you can construct an ODE where the forward pass is trivial, i.e. the ODE solution going forwards is exact, but its derivative is "hard". An easy way to do this is to make it so you have for example `x' = x - y, y' = y - x`, with initial conditions x(0)=y(0). If you start with things being the same value, then the solution to the ODE is constant since `x' = y' = 0`. But the derivative of the ODE solution with respect to its initial condition is very non-zero: a small change away from equality and boom the solution explodes. You write out the expression for dy(t)/dy(0) and what you get is a non-trivial ODE that has to be solved.

What happens in this case though is that automatic differentiation "runs the same code as the primal case". I.e., automatic differentiation has the property that it walks through your code and differentiates the steps of your code, and so the control flow always matches the control flow of the non-differentiated code, slapping the derivative parts on each step. But here the primal case is trivial, so the ODE solver goes "this is easy, let's make dt as big as possible and step through this easily". But this means that `dt` is not error controlled in the derivative (the derivative, being a non-trivial ODE, needs to have a smaller dt and take small steps in order to get an accurate answer), so the derivative then gets error due to this large dt (or small number of steps). Via this construction you can make automatic differentiation give as much error as you want just by tweaking the parameters around.
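
Here's a self-contained toy version of that failure mode (my own sketch, not the code from the paper or the Julia solvers): a Heun/Euler pair whose step-size control only looks at the primal, with the tangent pushed through the exact same steps the way forward-mode AD would.

  import numpy as np

  def f(u):                      # x' = x - y, y' = y - x; f is linear, so Jf(u) @ v = f(v)
      x, y = u
      return np.array([x - y, y - x])

  def solve_with_tangent(u0, v0, t_end, rtol=1e-6, dt0=0.1):
      t, u, v, dt = 0.0, u0.astype(float), v0.astype(float), dt0
      while t < t_end:
          dt = min(dt, t_end - t)
          k1, dk1 = f(u), f(v)
          k2, dk2 = f(u + dt * k1), f(v + dt * dk1)
          err = np.max(np.abs(dt * (k2 - k1) / 2))   # error estimate sees ONLY the primal
          if err <= rtol * (1 + np.max(np.abs(u))):
              u, v = u + dt * (k1 + k2) / 2, v + dt * (dk1 + dk2) / 2
              t += dt
              dt *= 2.0                              # primal looks trivial -> take huge steps
          else:
              dt *= 0.5
      return u, v

  u0, v0 = np.array([1.0, 1.0]), np.array([1.0, 0.0])   # v0 seeds d/dx(0)
  u_end, v_end = solve_with_tangent(u0, v0, t_end=5.0)
  print("AD-style dx(5)/dx(0):", v_end[0])
  print("exact    dx(5)/dx(0):", 0.5 + 0.5 * np.exp(2 * 5.0))

The primal is exactly constant, so every step is accepted and dt keeps doubling; the tangent, which actually has to track an exponentially growing solution, gets dragged along those huge steps and comes out wildly off. Letting the controller also look at dk1/dk2 is, loosely, the "incorporate the derivative terms into the adaptivity" fix described a couple of paragraphs down.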

Thus by this construction, automatic differentiation has no bound on the error it can give, and no, this isn't floating point error: this is strictly building a function that does not converge to the correct derivative. It just goes to show that automatic differentiation of a function that computes a nice error-controlled answer does not necessarily give a derivative that has any sense of error control; that is a property that has to be proved and is not true in general. This of course is a bit of a contrived example to show that the error is unbounded, but it points to real issues that can show up in user code (in fact, this example was found because a user opened an issue with a related model).

Then one thing that's noted in here too is that the Julia differential equation solvers hook into the AD system to explicitly "not do forward-mode AD correctly", incorporating the derivative terms into the time stepping adaptivity calculation, so that it is error controlled. The property that you get is that for these solvers you get more steps to the ODE when running it in AD mode than outside of AD mode, and that is a requirement if you want to ensure that the user's tolerance is respected. But that's explicitly "not correct" in terms of what forward-mode AD is, so calling forward-mode AD on the solver doesn't quite do forward mode AD of the solver's code "correctly" in order to give a more precise solution. That of course is a choice, you could instead choose to follow standard AD rules in such a code. The trade-off is between accuracy and complexity.

  • dleary 2 days ago

    Thank you for this good description.

  • ogogmad 2 days ago

    I don't think this is a failure of Automatic Differentiation. It's a failure of assuming that the derivative of an approximation is equal to an approximation of the derivative. That's obviously nonsense, as the example f(x) = atan(10000 x)/10000, which is nearly 0, shows you. This was first pointed out here in this comment: https://news.ycombinator.com/item?id=45293066
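
    Spelling that example out (my own quick check): f stays within about 1.6e-4 of the zero function everywhere, so "f ≈ 0" is an excellent approximation, yet f'(0) = 1 while the approximation's derivative is 0 everywhere.

      import numpy as np

      f = lambda x: np.arctan(10000 * x) / 10000
      df = lambda x: 1.0 / (1.0 + (10000 * x) ** 2)   # exact derivative

      xs = np.linspace(-10, 10, 100001)
      print(np.max(np.abs(f(xs))))   # ~1.57e-4: f is "nearly 0"
      print(df(0.0))                 # 1.0: its derivative at 0 is not nearly 0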

    • adgjlsfhk1 2 days ago

      The part that's really weird is that AD normally produces extremely accurate derivatives, so the rare algorithms where it doesn't are very surprising at first.

drnick1 2 days ago

The authors seem to push a Julia implementation, but the issue is that no one uses Julia...

  • eigenspace a day ago

    The context here is that the author is the principal developer of one of the biggest (if not the biggest) and most performant differential equation solver suites in the world[1].

    It's written in Julia, and it is very widely used and known to be state of the art.

    [1] https://github.com/SciML/DifferentialEquations.jl

    • drnick1 13 hours ago

      Would it not make more sense to write such a tool in C? The idea here is that C is universal, and higher level languages like Python, R, Julia and others all have a foreign function interface for C.

      • adgjlsfhk1 3 hours ago

        No. C isn't an expressive enough language to write many solvers in. Many of the best solvers are basically impossible to write without metaprogramming (where you write a generator for the n-stage version that can then generate the code for all the possible permutations), and there are also massive performance gains from symbolic-numeric approaches where the solver introspects the user function to make rearrangements that aren't possible in a black-box framework.

      • eigenspace 3 hours ago

        Good luck being that productive in C, distributing the software adequately, and attracting collaborators.

ismailmaj 2 days ago

The way I explained it to myself in the past is that so many of the CUDA algorithms don't care much about numerical stability because, in deep learning, the error acts as a form of regularization (i.e. less overfitting of the data).

  • Nevermark 2 days ago

    I am not quite sure what that means! :)

    But reasons why deep learning training is very robust to moderate inaccuracy in gradients:

    1. Locally, sigmoid and similar functions are the simplest smoothest possible non-linearity to propagate gradients through.

    2. Globally, outside of deep recurrent networks, there is no recursion, which keeps the total function smooth and well behaved.

    3. While the perfect gradient indicates the ideal direction to adjust parameters for fastest improvement, all that is really needed to reduce error is to move parameters in the direction of the gradient signs, with a small enough step. That is a very low bar.

    It's like telling an archer they just need to shoot an arrow so it lands closer to the target than where the archer is standing, but not worry about hitting it!

    4. Finally, the perfect first order gradient is only meaningful at one point of the optimization surface. Move away from that point, i.e. update the parameters at all, and the gradient changes quickly.

    So we are in gradient heuristic land even with "perfect" first order gradients. The most perfectly calculated gradient isn't actually "accurate" to begin with.

    To actually get an accurate gradient over a parameter step would take fitting the local gradient with a second or third order polynomial. I.e. not just first, but second and third order derivatives. At vastly greater computational and working memory cost.

    --

    The only critical issue for calculating gradients is that there is enough precision that at least directional gradient information makes it from the errors back to the parameters being updated. If precision is too low, then the variable magnitude rounding inherent to floating point arithmetic can completely drop directional information for smaller gradients. Without accurate gradient signs, learning stalls.
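
    A two-line toy illustration (not tied to any real training setup): in fp16, a small update added to a parameter near 1.0 can be rounded away completely, sign and all, which is one reason mixed-precision training keeps higher-precision master weights.

      import numpy as np

      w = np.float16(1.0)
      g = np.float16(1e-4)                  # well below fp16's spacing near 1.0 (~9.8e-4)
      print(w - np.float16(0.1) * g == w)   # True: the whole update disappears in rounding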

    • ismailmaj 2 days ago

      Typically for matrix multiplications there is a wide range of algorithms you could use: on one extreme end you could use numerically stable summation, and on the other extreme you could have a tiled matmul with FP8. The industry trend seems to be moving further away from numerically stable algorithms without much quality drop. My claim is probably unfair since it ignores the scale you gain from the speed/precision tradeoff, so I assumed numerical stability is not that beneficial here compared to something precision-heavy like physics simulation in HPC.
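
      As a small illustration of the "numerically stable summation" end of that spectrum (my own toy comparison, nothing to do with any particular GPU kernel): naive fp32 accumulation versus Kahan (compensated) summation.

        import numpy as np

        xs = np.random.default_rng(0).normal(size=100_000).astype(np.float32) + np.float32(1.0)

        naive = np.float32(0.0)
        for v in xs:
            naive += v                        # rounding error grows with the running sum

        s = c = np.float32(0.0)
        for v in xs:                          # Kahan summation: carry the lost low-order bits
            y = v - c
            t = s + y
            c = (t - s) - y
            s = t

        exact = np.sum(xs.astype(np.float64))
        print(abs(naive - exact), abs(s - exact))   # the compensated sum is far closer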

      • Nevermark 2 days ago

        > I assumed numerical stability is not that beneficial compared to something precision heavy like physics simulation in HPC.

        Yes, exactly.

        For physics, there is a correct result. I.e. you want your simulation to reflect reality with high accuracy, over a long chain of calculations. Extremely tight constraint.

        For deep learning, you don't have any specific constraints on parameters, except that you want to end up with a combination that fits the data well. There are innumerable combinations of parameter values that will do that; you just need to find one good enough combination.

        Wildly different.