fchaubard's comments

Exactly.


Layman Abstract: Transformers keep around all previous tokens (as the KV cache) for each generated token, so they take up ENORMOUS GPU memory and cost during inference. Humans don't: we page information in and out of a small, fixed-size "working memory", keeping around only the important parts of the past.
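To make the memory contrast concrete, here is a back-of-the-envelope sketch. It assumes a plain multi-head attention decoder that caches one key and one value vector of width d_model per layer per token (no grouped-query sharing), versus an RNN that keeps one fixed-size state vector per layer. The model dimensions are hypothetical and not taken from any specific model.

```python
def kv_cache_bytes(num_layers: int, d_model: int, seq_len: int, bytes_per_value: int = 2) -> int:
    """Bytes needed to cache keys and values for every previous token of one sequence.

    Assumes K and V are each seq_len x d_model per layer (plain multi-head attention).
    """
    return 2 * num_layers * seq_len * d_model * bytes_per_value  # 2 = one K plus one V


def rnn_state_bytes(num_layers: int, d_state: int, bytes_per_value: int = 2) -> int:
    """Bytes for a fixed-size recurrent state; independent of sequence length."""
    return num_layers * d_state * bytes_per_value


if __name__ == "__main__":
    # Hypothetical model: 32 layers, width 4096, 16-bit values (2 bytes each).
    for seq_len in (1_024, 32_768, 131_072):
        kv_gib = kv_cache_bytes(32, 4096, seq_len) / 2**30
        rnn_mib = rnn_state_bytes(32, 4096) / 2**20
        print(f"{seq_len:>7} tokens: KV cache {kv_gib:6.1f} GiB, RNN state {rnn_mib:.2f} MiB")
```

Under these assumptions the per-sequence KV cache grows linearly from 0.5 GiB at 1,024 tokens to 64 GiB at 131,072 tokens, while the recurrent state stays at 0.25 MiB regardless of length.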

RNNs are more like us: they compress all previous tokens into a small, fixed-size memory. However, we can't train them with legacy backpropagation through time (BPTT), because it doesn't scale (its memory grows with sequence length) and it suffers from exploding/vanishing gradients.
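A minimal illustration of the exploding/vanishing problem, assuming the simplest possible recurrence, a scalar linear RNN h_t = w * h_{t-1} + x_t. Backprop through time multiplies one Jacobian per time step (here just the scalar w), so the gradient of h_T with respect to h_0 is w raised to the T. Real RNNs have matrix Jacobians and nonlinearities, but the same repeated product drives the behavior.

```python
def bptt_gradient(w: float, num_steps: int) -> float:
    """d h_T / d h_0 for the scalar linear RNN h_t = w * h_{t-1} + x_t.

    BPTT chains one Jacobian (the scalar w) per time step, so the result is
    the product of num_steps copies of w.
    """
    grad = 1.0
    for _ in range(num_steps):
        grad *= w  # chain rule through one time step
    return grad


if __name__ == "__main__":
    for w in (0.9, 1.0, 1.1):
        print(w, [bptt_gradient(w, steps) for steps in (10, 100, 1000)])
```

Anything slightly below 1 vanishes and anything slightly above 1 explodes over a long sequence. BPTT also has to store every intermediate state for the backward pass, which is where its memory cost grows with sequence length.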

So we revisited a 1992 zeroth-order algorithm to replace BPTT, and not only does it scale amazingly well, in some cases it trains 19x faster than BPTT! So maybe, with this, RNNs can replace transformers?
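For readers new to zeroth-order optimization, here is a minimal sketch of the general idea: estimate a gradient from loss evaluations alone, with no backward pass. It uses a central-difference estimate along random plus/minus-one directions, in the spirit of Spall's 1992 SPSA. I'm assuming this is the family meant, since the comment does not name the algorithm; the paper's exact estimator and hyperparameters may differ. The function names and the toy quadratic loss are hypothetical.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]


def zo_gradient(loss: Callable[[Vector], float], theta: Vector, eps: float = 1e-3,
                num_dirs: int = 1, rng: Optional[random.Random] = None) -> Vector:
    """Central-difference, random-direction gradient estimate (SPSA-style).

    For each random direction d with entries in {-1, +1}, the directional slope
    (loss(theta + eps*d) - loss(theta - eps*d)) / (2*eps) is multiplied by d,
    and the results are averaged over num_dirs directions. Only forward
    evaluations of `loss` are used; there is no backpropagation.
    """
    if rng is None:
        rng = random.Random()
    grad = [0.0] * len(theta)
    for _ in range(num_dirs):
        d = [rng.choice((-1.0, 1.0)) for _ in theta]
        plus = [t + eps * di for t, di in zip(theta, d)]
        minus = [t - eps * di for t, di in zip(theta, d)]
        slope = (loss(plus) - loss(minus)) / (2 * eps)  # directional derivative estimate
        for i, di in enumerate(d):
            grad[i] += slope * di / num_dirs
    return grad


def quadratic(theta: Vector) -> float:
    """Hypothetical toy loss with its minimum at (1, -2, 3)."""
    target = (1.0, -2.0, 3.0)
    return sum((t - c) ** 2 for t, c in zip(theta, target))


if __name__ == "__main__":
    rng = random.Random(0)
    theta = [0.0, 0.0, 0.0]
    for _ in range(500):
        g = zo_gradient(quadratic, theta, num_dirs=4, rng=rng)
        theta = [t - 0.05 * gi for t, gi in zip(theta, g)]  # plain gradient step
    print(theta, quadratic(theta))
```

Each estimate costs two forward evaluations per direction. For an RNN, a forward pass only needs the current hidden state, so nothing like the full unrolled sequence has to be kept for a backward pass; the price is a noisier gradient.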


Yes. It's more of a class-spanning thing. I wanted the batch composition of the two microbatches to be the same. If you have classes 1, 2, 3 in microbatch one and classes 4, 5, 6 in microbatch two, I would fully expect their gradients to be orthogonal or worse (cosine distance of 1 or more), and it could still be a good update. But if you have classes 1, 2, 3 in both microbatches, I would fully expect the gradients to be positively correlated, and if they aren't, you should skip. So you could bring this down to a microbatch size of 5, for example, but just make sure both microbatches have the same composition.

Honestly, this poses a big challenge in LLM training, because technically the number of classes is the vocabulary size. So I would need one "a", one "b", etc., which is silly. This is why microgradients in LLMs can hit a cosine distance of 2. So when you are sampling, you kind of need to ensure the microbatches are at least from the same task.
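A minimal sketch of what "same batch composition" could look like at sampling time, assuming a classification dataset with a flat list of integer labels: pick a set of classes, then draw two distinct examples of each, one per microbatch, so any disagreement between the two microgradients can't be blamed on different class coverage. The function name and the one-example-per-class-per-microbatch rule are my own illustration, not the paper's sampler.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def matched_microbatches(labels: Sequence[int], num_classes: int,
                         rng: random.Random) -> Tuple[List[int], List[int]]:
    """Return two disjoint index lists whose label multisets are identical.

    Chooses `num_classes` classes that have at least two examples, then puts
    one example of each chosen class into each microbatch.
    """
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= 2]
    chosen = rng.sample(eligible, num_classes)
    mb_a: List[int] = []
    mb_b: List[int] = []
    for c in chosen:
        i, j = rng.sample(by_class[c], 2)  # two distinct examples of class c
        mb_a.append(i)
        mb_b.append(j)
    return mb_a, mb_b


if __name__ == "__main__":
    labels = [0, 1, 2, 0, 1, 2, 3, 3, 4, 4, 0, 1]
    print(matched_microbatches(labels, num_classes=3, rng=random.Random(0)))
```

A microbatch size of 5, as in the example above, would correspond to num_classes=5.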


I think about it as designing your ideal solver. What do you want from a solver? I want my solver to squeeze all the juice out of the training set that it possibly can, and no more. If your problem is complete noise, I don't want my solver getting 100% train accuracy, as every SGD method I am aware of does, and as the soup method likely would as well. I am not averaging memorization thetas; I want my solver to score 0% train accuracy, as GAF does. There may be other ways of getting there as well.


Yes! I think this is a great area of research. If you think of the gradient values as a blame score for why you got the answer wrong, you can have a lot of fun exploring which weights light up for different problems. One note: in ring all-reduce, nodes never actually share the FULL gradient, only blocks of it. So to put this into practice, you'd have to show that you can do the thresholding on a block of the gradient rather than on the full gradient, which you may never be able to fit in VRAM. Will the results still hold? I don't know. I believe they would, but that's for the next paper.
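To make the open question concrete, here is a sketch that assumes the flattened gradient is split into equal contiguous blocks, roughly the way ring all-reduce chunks it, and computes the cosine distance per block alongside the full-vector cosine distance. Whether per-block gating preserves GAF's behavior is exactly what the comment leaves open; this only shows that the two quantities can disagree. The helper names are hypothetical.

```python
import math
from typing import List, Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity: 0 = same direction, 2 = opposite."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))
    return 1.0 - dot / norm if norm > 0 else 1.0  # zero block treated as orthogonal (assumption)


def blockwise_cosine_distances(a: Sequence[float], b: Sequence[float],
                               num_blocks: int) -> List[float]:
    """Cosine distance computed separately on each contiguous block."""
    size = math.ceil(len(a) / num_blocks)
    return [cosine_distance(a[k:k + size], b[k:k + size]) for k in range(0, len(a), size)]


if __name__ == "__main__":
    g1 = [1.0, 1.0, 1.0, -1.0]
    g2 = [1.0, 1.0, -1.0, 1.0]
    print(cosine_distance(g1, g2), blockwise_cosine_distances(g1, g2, 2))
```

Here the full vectors are orthogonal (distance 1.0), so a gate at tau around 0.95 would skip them, while per-block gating would accept the first block (distance 0.0) and reject the second (distance 2.0).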


Very cool! Glad to hear my intuition is on the right track… I'm very much on the applied-ML-for-engineering-design side rather than the bleeding-edge research side, so for multi-node training I haven't done much more than spin up a few GPUs and let PyTorch Lightning handle the parallelism, but it's cool to try to keep up with this stuff.

Thanks for the response and good luck with this!


I’ll do my best to answer here.

> Do you expect instability between successive macrobatch gradients? That is, why are you comparing microgradients within a single batch, adding a whole bunch of serialization headaches, rather than comparing with the macrogradient of the previous step?

>> I do. If you take a sufficiently large step, the path of steepest descent will surely change sometimes; if it doesn't, you should just double or triple your step size. So when comparing against the previous step, you just don't know why the cosine distance is high: a change in the curvature of your loss surface, or gradient variance. Most large runs already split gradients across nodes, so if you are already doing so, just do GAF instead of averaging.
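A minimal sketch of the filter as described in this thread ("cosine distance < tau, average, else skip"), applied to microgradients that would otherwise just be averaged across nodes. I'm assuming each new microgradient is compared against the running mean of those accepted so far; the paper may order or compare them differently. Function names are hypothetical, plain Python lists stand in for flattened gradient tensors, and tau defaults to 0.97, inside the 0.92 to 0.999 range the author recommends.

```python
import math
from typing import List, Sequence, Tuple


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity: 0 = same direction, 1 = orthogonal, 2 = opposite."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))
    return 1.0 - dot / norm if norm > 0 else 1.0  # zero vector treated as orthogonal (assumption)


def gaf_aggregate(micrograds: Sequence[Sequence[float]],
                  tau: float = 0.97) -> Tuple[List[float], int]:
    """Filter-then-average a list of flattened microgradients.

    Starts from the first microgradient. Each later one is folded into the
    running mean of accepted microgradients only if its cosine distance to
    that mean is < tau; otherwise it is skipped. Returns the aggregate and
    how many microgradients were accepted.
    """
    agg = list(micrograds[0])
    count = 1
    for g in micrograds[1:]:
        if cosine_distance(agg, g) < tau:
            count += 1
            agg = [a + (x - a) / count for a, x in zip(agg, g)]  # incremental mean
        # else: reject this microgradient and keep the current aggregate
    return agg, count


if __name__ == "__main__":
    micrograds = [[1.0, 0.0], [0.9, 0.1], [-1.0, 0.0]]  # the third one disagrees
    print(gaf_aggregate(micrograds, tau=0.97))
```

With only two microbatches, an accepted count of 1 means they disagreed, and per the thread you would skip the optimizer step. Any tau above 2 accepts every microgradient, because cosine distance never exceeds 2, which recovers ordinary averaging.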

> Given your test setup of noisy labels, isn't the sequential accumulation of microgradients dangerous? Suppose we take this to the limit of a minibatch size of 1, and the first microgradient happens to come from an example with an incorrect label. If I understand this correctly, gradient filtering would seek to add all the gradients that are consistent with the bad example, rejecting the gradients that belong to good examples.

>> Yes, but being "consistent with the bad example" is nearly impossible. Even without noisy labels, gradient directions in late-stage training are already orthogonal or worse. If you flip the labels of any two samples and use a microbatch of size 1, their gradients will be negatively correlated with each other, so with GAF you will practically always skip. In standard SGD, by contrast, you will ALWAYS average them in until you've completely memorized the noisy samples.
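The intuition that a flipped label pushes the opposite way can be checked exactly in the simplest case, binary logistic regression. There, the per-example gradient of the cross-entropy loss with respect to the weights is (p - y) * x with p = sigmoid(w . x). Since 0 < p < 1, the factor is p - 1 < 0 for y = 1 and p > 0 for y = 0, so the same input with a flipped label gives a gradient pointing exactly the opposite way: cosine distance 2. Deep networks are not this clean; this is only a sanity check of the intuition, and the weights and input below are arbitrary.

```python
import math
from typing import List, Sequence


def logistic_grad(w: Sequence[float], x: Sequence[float], y: int) -> List[float]:
    """Gradient of binary cross-entropy w.r.t. w for one example (x, y), y in {0, 1}."""
    z = sum(wi * xi for wi, xi in zip(w, x))
    p = 1.0 / (1.0 + math.exp(-z))  # predicted probability of class 1
    return [(p - y) * xi for xi in x]


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity: 0 = same direction, 2 = opposite."""
    dot = sum(u * v for u, v in zip(a, b))
    norm = math.sqrt(sum(u * u for u in a) * sum(v * v for v in b))
    return 1.0 - dot / norm


if __name__ == "__main__":
    w = [0.5, -0.25, 1.0]
    x = [1.0, 2.0, -0.5]
    clean = logistic_grad(w, x, y=1)
    flipped = logistic_grad(w, x, y=0)
    print(cosine_distance(clean, flipped))  # opposite directions: 2 up to rounding
```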

> The filtered gradients are used via SGD with momentum (although equation (6) looks like momentum-free SGD). Have you seen / do you expect different results when supplying the filtered gradients to Adam/AdamW or other, more sophisticated optimizers?

>> No, it won't change the results. I have been using GAF to train LLMs for my next paper and it holds up. In fact, in much more expressive and larger models like LLMs, the gradients sometimes hit a cosine distance of 2! So GAF really helps in LLM training. Think of SGD training like a tracker. The tracker receives noisy distance/velocity measurements at some frequency, and it is only as good as the signal coming into it. If a bird flies in front of the radar, you will get a huge OOD signal in the estimate unless you do some sort of validation gating (e.g., reject the measurement if ||s_t - x_hat_t|| > thresh). Think of GAF as the validation gating.
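To unpack the tracker analogy, here is a toy 1-D tracker with validation gating. It assumes a simple fixed-gain (alpha) filter rather than a full Kalman filter: a measurement s_t is folded into the estimate only when the innovation |s_t - x_hat_t| is within the gate, so a single wild reading (the bird) is ignored. In the analogy, the measurement plays the role of a microgradient and the gate plays the role of GAF's cosine-distance test. Names and numbers are illustrative.

```python
from typing import Iterable, List


def gated_tracker(measurements: Iterable[float], x0: float, gain: float = 0.5,
                  gate: float = 3.0) -> List[float]:
    """Fixed-gain tracker that rejects measurements failing a validation gate."""
    x_hat = x0
    history: List[float] = []
    for s in measurements:
        innovation = s - x_hat
        if abs(innovation) <= gate:
            x_hat += gain * innovation  # accept: move part way toward the measurement
        # else: reject as out-of-distribution and keep the current estimate
        history.append(x_hat)
    return history


if __name__ == "__main__":
    readings = [10.0, 10.2, 9.9, 55.0, 10.1]  # 55.0 is the bird
    print(gated_tracker(readings, x0=10.0))
    print(gated_tracker(readings, x0=10.0, gate=float("inf")))  # no gating
```

Without the gate, the single outlier drags the estimate from about 10 to about 32.5 in one step; with it, the estimate never moves.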

> Your thresholding is binary, either accepting or rejecting an entire microgradient. Have you tested soft thresholding? Is there an information-theoretic way to explain this effect?

>> Great question. I tried to prove formally that if the gradients of two randomly selected batches are negatively correlated (cosine distance above 1), then averaging them will result in overfitting, but I couldn't get the proof to a satisfactory spot. I do conjecture it, though. So no, I wouldn't expect that taking any part of a memorization direction is a good idea.

> In figure 7, why does GAF with a large threshold result in a lower validation accuracy than the baseline? In GAF-terms, the baseline accepts every microgradient, so I'd expect GAF to converge to the baseline result as the threshold increases. What does the figure-7-but-0%-error curve look like?

>> Good callout. Yes, that wasn't intuitive to me either. You are correct that when tau hits 2 it converges to the baseline, as expected. But at 1.05 it actually does worse than the baseline in the presence of 5% noise. So as you increase tau above 1, which I never recommend doing, it starts to underperform the baseline in the presence of noise, and by 2 it matches again. But for practical reasons, I have found tau in [0.92, 0.999] to be the sweet spot. I wouldn't go outside that range.
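Since cosine distance is one minus cosine similarity, the acceptance test can be restated in terms of the angle theta between the two microgradients g_1 and g_2 being compared; this is just algebra on the definition.

```latex
d(g_1, g_2) = 1 - \frac{g_1^\top g_2}{\lVert g_1 \rVert \, \lVert g_2 \rVert} = 1 - \cos\theta \in [0, 2],
\qquad
d(g_1, g_2) < \tau \iff \cos\theta > 1 - \tau .
```

So tau = 0.92 accepts pairs with cos theta > 0.08, tau = 0.999 accepts any pair with cos theta > 0.001, tau = 1 is exactly the orthogonality boundary, tau above 1 starts admitting negatively correlated pairs, and at tau = 2 essentially every pair passes, which is the baseline average.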


> I do. If you take a sufficiently large step, the path of steepest descent will surely change sometimes; if it doesn't, you should just double or triple your step size. So when comparing against the previous step, you just don't know why the cosine distance is high: a change in the curvature of your loss surface, or gradient variance. Most large runs already split gradients across nodes, so if you are already doing so, just do GAF instead of averaging.

I agree for plain SGD, but optimizers with momentum depend on some consistency of gradients between optimizer steps. On the other hand, optimizers with momentum try to maximize the effective learning rate subject to that momentum, so it could go the other way as well.
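A quick way to see the momentum point, assuming the plain heavy-ball form v = beta * v + g (the form PyTorch's SGD uses with dampening = 0; Adam-style optimizers differ in detail): a gradient that keeps the same sign drives the buffer toward g / (1 - beta), while one that flips sign every step keeps its magnitude near g / (1 + beta).

```python
from typing import Sequence


def momentum_buffer(grads: Sequence[float], beta: float = 0.9) -> float:
    """Final heavy-ball momentum buffer after applying v = beta * v + g for each g."""
    v = 0.0
    for g in grads:
        v = beta * v + g
    return v


if __name__ == "__main__":
    steps = 200
    consistent = [1.0] * steps
    alternating = [1.0 if t % 2 == 0 else -1.0 for t in range(steps)]
    print(momentum_buffer(consistent))   # approaches 1 / (1 - 0.9) = 10
    print(momentum_buffer(alternating))  # magnitude approaches 1 / (1 + 0.9), about 0.53
```

With beta = 0.9 that is roughly a 19x difference in effective step size, so inconsistent gradients between steps are already damped by the momentum buffer itself.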

> No, it won't change the results. I have been using GAF to train LLMs for my next paper and it holds up. In fact, in much more expressive and larger models like LLMs, the gradients sometimes hit a cosine distance of 2! So GAF really helps in LLM training. Think of SGD training like a tracker. The tracker receives noisy distance/velocity measurements at some frequency, and it is only as good as the signal coming into it. If a bird flies in front of the radar, you will get a huge OOD signal in the estimate unless you do some sort of validation gating (e.g., reject the measurement if ||s_t - x_hat_t|| > thresh). Think of GAF as the validation gating.

> Great question. I tried to prove formally that if the gradients of two randomly selected batches are negatively correlated (cosine distance above 1), then averaging them will result in overfitting, but I couldn't get the proof to a satisfactory spot. I do conjecture it, though. So no, I wouldn't expect that taking any part of a memorization direction is a good idea.

Maybe the answer lies in asking what's optimized by averaging.

For learning, we're interested in the intractable problem of computing the gradient of the loss with respect to the parameters over the whole data distribution. In practice, we only compute that gradient on samples drawn from the data distribution, leading to stochastic gradient descent. SGD-with-momentum makes the additional assumption that the steepest descent path has relatively low curvature, so the mean gradient of previous batches is still informative.

Overall, this approach is still optimal if you imagine that the computed sample gradients are corrupted with mean-zero Gaussian noise: averaging over many samples is the best way to eliminate that noise.

Your work identifies and rejects outlier gradients. In a very hand-wavy way, this is kind of like a median filter, and a median filter is great at rejecting shot noise. I speculate that this is why your technique is particularly good for your examples with corrupted labels, since that corruption process replaces single samples with something completely uninformative.
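A toy illustration of the averaging-versus-median intuition in the paragraphs above, assuming scalar "gradients" equal to a true value of 1.0 plus Gaussian noise. With Gaussian noise alone the mean is a good estimator, but after a few samples are replaced with wild values (shot noise, like a corrupted label) the median barely moves while the mean is dragged away. This is only an analogy for GAF, which filters by direction rather than taking a median.

```python
import random
import statistics
from typing import Tuple


def estimate_errors(n: int = 101, outliers: int = 0, seed: int = 0) -> Tuple[float, float]:
    """Return (|mean - truth|, |median - truth|) for noisy samples of truth = 1.0."""
    rng = random.Random(seed)
    truth = 1.0
    samples = [truth + rng.gauss(0.0, 0.1) for _ in range(n)]  # mean-zero Gaussian noise
    for i in range(outliers):
        samples[i] = 50.0  # shot noise: sample replaced by something uninformative
    return abs(statistics.fmean(samples) - truth), abs(statistics.median(samples) - truth)


if __name__ == "__main__":
    print("Gaussian only:", estimate_errors(outliers=0))
    print("5 outliers:   ", estimate_errors(outliers=5))
```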

This is why I also wonder about soft thresholding. A soft threshold could be interpreted as an intermediate rather than binary belief about whether a sample is informative, or it could be interpreted as belief that a sample has more than zero but less than the typical amount of information.
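For comparison with the hard gate, here is one hypothetical soft version: instead of accepting or rejecting a microgradient outright, weight it by how far its cosine distance falls below tau, w = max(0, 1 - d / tau), and take a weighted average against a reference microgradient. Nothing in the thread reports testing this exact variant, and the linear weighting function is an arbitrary choice for illustration.

```python
import math
from typing import List, Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity: 0 = same direction, 2 = opposite."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a) * sum(y * y for y in b))
    return 1.0 - dot / norm if norm > 0 else 1.0  # zero vector treated as orthogonal (assumption)


def soft_gaf(reference: Sequence[float], others: Sequence[Sequence[float]],
             tau: float = 0.97) -> List[float]:
    """Weighted average of `reference` (weight 1) and `others`.

    Each other microgradient gets weight max(0, 1 - d / tau), where d is its
    cosine distance to `reference`. Hard gating would use 1 if d < tau else 0.
    """
    total = list(reference)
    weight_sum = 1.0
    for g in others:
        w = max(0.0, 1.0 - cosine_distance(reference, g) / tau)
        total = [t + w * x for t, x in zip(total, g)]
        weight_sum += w
    return [t / weight_sum for t in total]


if __name__ == "__main__":
    print(soft_gaf([1.0, 0.0], [[1.0, 1.0], [0.0, 1.0], [-1.0, 0.0]]))
```

Partially aligned microgradients contribute partially, while orthogonal or opposing ones (distance at or above tau) get zero weight, exactly as under the hard gate.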

> But for practical reasons, I have found tau in [0.92, 0.999] to be the sweet spot. I wouldn't go outside that range.

If it would be easy to add (that is, if you still have the data on hand), might I suggest adding a subpanel to figure 7 noting the fraction of minibatches that are accepted/rejected with each threshold? If you're hypothetically rejecting 80% of minibatches at the optimum threshold, it'd hint that your method is finding the golden kernel of most representative data to learn from; in contrast if you're hypothetically rejecting just a couple of percent then it'd hint that your method is more narrowly finding the corrupted samples. Either range (or anything in between) would be interesting.
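If the per-step cosine distances were logged, the suggested subpanel is cheap to compute: for each tau, the fraction of comparisons that pass the gate. A sketch, assuming a hypothetical flat list of logged cosine distances; the sample values below are made up purely to exercise the function and are not results from the paper.

```python
from typing import Dict, Iterable, Sequence


def acceptance_fractions(distances: Sequence[float], taus: Iterable[float]) -> Dict[float, float]:
    """Fraction of logged cosine distances strictly below each threshold tau."""
    n = len(distances)
    return {tau: sum(d < tau for d in distances) / n for tau in taus}


if __name__ == "__main__":
    logged = [0.85, 0.93, 0.97, 0.99, 1.02, 1.10]  # made-up example values
    for tau, frac in acceptance_fractions(logged, [0.92, 0.97, 0.999, 1.05, 2.0]).items():
        print(f"tau={tau:<6} accepted {frac:.0%}")
```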


Hey, thanks! Yeah, we tried similar strategies and could not beat "if cosine distance < tau, average, else skip." It was too much to put in the paper, and we may put it in the arXiv version, but we tried sign AND-gating (zeroing out coordinates where the signs don't agree), L2 < tau, etc., and nothing beat cosine distance.
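To make the comparison concrete, here are the kinds of pairwise filters mentioned, written side by side for two microgradients g1 and g2. These are my reading of the one-line descriptions, not the exact variants that were tested; in particular, what the L2 rule thresholds (here, the Euclidean distance between the two microgradients) is an assumption, and the function names are hypothetical.

```python
import math
from typing import List, Optional, Sequence

Vec = Sequence[float]


def cosine_gate(g1: Vec, g2: Vec, tau: float) -> Optional[List[float]]:
    """Average the pair if cosine distance < tau, else skip (None)."""
    dot = sum(a * b for a, b in zip(g1, g2))
    norm = math.sqrt(sum(a * a for a in g1) * sum(b * b for b in g2))
    dist = 1.0 - dot / norm if norm > 0 else 1.0
    return [(a + b) / 2 for a, b in zip(g1, g2)] if dist < tau else None


def sign_and_gate(g1: Vec, g2: Vec) -> List[float]:
    """Average coordinate-wise, zeroing coordinates whose signs disagree (or are zero)."""
    return [(a + b) / 2 if a * b > 0 else 0.0 for a, b in zip(g1, g2)]


def l2_gate(g1: Vec, g2: Vec, tau: float) -> Optional[List[float]]:
    """Average the pair if their Euclidean distance < tau, else skip (None)."""
    dist = math.sqrt(sum((a - b) ** 2 for a, b in zip(g1, g2)))
    return [(a + b) / 2 for a, b in zip(g1, g2)] if dist < tau else None


if __name__ == "__main__":
    g1, g2 = [1.0, -2.0, 3.0], [2.0, 2.0, 1.0]
    print(cosine_gate(g1, g2, 0.97), sign_and_gate(g1, g2), l2_gate(g1, g2, 5.0))
```

One design difference worth noting: cosine distance ignores gradient magnitude, while an L2 threshold depends on it, so an L2 tau would likely need retuning as gradient norms shrink over training.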


Yes, it will allow stable training at much smaller batch sizes. Test it out and let us know if it works for your use case!


Here too!

