Based on what works elsewhere in deep learning, I see no reason why you couldn't train once with a randomized number of experts, then set that number during inference based on your desired compute-accuracy tradeoff. I would expect that this has been done in the literature already.
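A rough sketch of what I mean (my own illustration, not a reference to any specific paper): a top-k mixture-of-experts layer where k is drawn at random during training and then pinned at inference to whatever fits your compute budget.

```python
# Sketch only: randomize the number of active experts per training step,
# then choose it freely at inference for the compute/accuracy trade-off.
import jax
import jax.numpy as jnp


def moe_layer(params, x, k):
    """Route each token to its top-k experts and mix their outputs.

    params["router"]:  (d_model, n_experts) gating weights
    params["experts"]: (n_experts, d_model, d_model) per-expert weights
    x: (n_tokens, d_model); k: number of active experts per token.
    """
    logits = x @ params["router"]                 # (n_tokens, n_experts)
    top_vals, top_idx = jax.lax.top_k(logits, k)  # k must be static under jit
    gates = jax.nn.softmax(top_vals, axis=-1)     # renormalize over chosen experts
    expert_w = params["experts"][top_idx]         # (n_tokens, k, d_model, d_model)
    expert_out = jnp.einsum("td,tkdh->tkh", x, expert_w)
    return jnp.einsum("tk,tkh->th", gates, expert_out)


d_model, n_experts, n_tokens = 64, 8, 16
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "router": 0.02 * jax.random.normal(k1, (d_model, n_experts)),
    "experts": 0.02 * jax.random.normal(k2, (n_experts, d_model, d_model)),
}
x = jax.random.normal(k3, (n_tokens, d_model))

# Training: draw k at random each step so the model never relies on a fixed k.
k_train = int(jax.random.randint(k4, (), 1, n_experts + 1))
y_train = moe_layer(params, x, k_train)

# Inference: pick k once, based on how much compute you want to spend.
y_cheap = moe_layer(params, x, k=1)
y_accurate = moe_layer(params, x, k=4)
```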
Right, a lot of the promise of AI can be (and has been) achieved with better tool design. If we get the AI to start writing assembly or machine code, as some people want it to, we're going to have the same problems with AI writing in those languages as we did when humans had to use them raw. We invented new languages because we didn't find the old ones expressive enough, so I don't understand the idea that LLMs will have a better time expressing themselves in those languages. The AI forgetting to free memory in C and having to go back and correct itself is a perfect example of this. We invented new tools so we wouldn't have to do that anymore, and they work. Now we're going backwards, building giant AI datacenters that suck up all the RAM in the world just to make up for lost ground? Weak.
> We invented new languages because we didn't find those old ones expressive enough
Not quite. It's not about being expressive enough to define algorithms; it's about simplification, organization, and avoiding repetition. We invented new languages to automate a lot of the work that programmers had to do in a lower-level language.
C abstracts away handling memory addresses and setting up stack frames like you would in assembly.
Rust makes handling memory more restrictive so you don't run into memory-safety issues.
Java abstracts away memory management entirely, freeing you up to design algorithms without worrying about memory leaks (although apparently you do have to worry about whether your log statements can execute arbitrary code).
JavaScript and Python abstract away type declarations through dynamic typing.
Likewise, OOP, static typing, functional programming, and other paradigms were introduced for better organization.
LLMs are right in line with this. There is no difference between you using a compiler to compile a program, vs a sufficiently advanced LLM writing said compiler and using it to compile your program, vs an LLM compiling the program directly with agentic loops for accuracy.
Once we get past the hype of big LLMs, the next chapter is gonna be much smaller, specialized LLMs with architectures that are more deterministic than probabilistic, and they're gonna replace a lot of tools. The future of programming will be you defining code in a high-level language like Python; the LLM will then be able to infer a lot of the information just from the code (for example, the task of finding how variables relate to each other is right in line with what transformers do) and do things like auto-infer types, write template code, and adapt it to the specific needs.
In fact, CPUs already do this to a certain extent: some modern branch predictors are basically miniature neural networks (perceptrons trained online on branch outcomes).
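For anyone curious, the canonical example is the perceptron branch predictor (Jiménez & Lin, 2001). Here's a toy sketch of the idea, not any particular CPU's implementation:

```python
# Toy perceptron branch predictor, in the spirit of Jimenez & Lin (2001):
# one tiny linear "neuron", trained online on observed branch outcomes.

class PerceptronPredictor:
    def __init__(self, history_len=8, threshold=16):
        self.threshold = threshold               # keep training while confidence is low
        self.weights = [0] * (history_len + 1)   # index 0 is the bias weight
        self.history = [1] * history_len         # +1 = taken, -1 = not taken

    def predict(self):
        # Dot product of weights with (bias, history): a one-neuron network.
        y = self.weights[0] + sum(w * h for w, h in zip(self.weights[1:], self.history))
        return y, y >= 0                         # predict "taken" if non-negative

    def update(self, taken):
        y, predicted_taken = self.predict()
        t = 1 if taken else -1
        # Perceptron rule: train on mispredictions and low-confidence predictions.
        if predicted_taken != taken or abs(y) <= self.threshold:
            self.weights[0] += t
            for i, h in enumerate(self.history):
                self.weights[i + 1] += t * h
        self.history = [t] + self.history[:-1]   # shift the outcome into the history


# A branch that alternates taken / not-taken is learned within a few updates:
p = PerceptronPredictor()
outcomes = [True, False] * 50
hits = 0
for taken in outcomes:
    hits += p.predict()[1] == taken
    p.update(taken)
print(f"accuracy on an alternating branch: {hits / len(outcomes):.2f}")
```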
Yes, this is part of the French prépa/CPGE system, which is the "standard" way for students to enter elite engineering schools. You do your first 2-3 years of undergrad in prépa.
I have personally never seen a case where a cited source on Wikipedia failed to back up the claim and that hadn't already been noticed and called out by someone else. With LLMs it is a frequent occurrence.
JAX code usually ends up being way faster than equivalent torch code for me, even with torch.compile. There are common performance killers, though. Notably, using Python control flow (if statements, loops) instead of jax.lax primitives (where, cond, scan, etc).
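To illustrate that last point, a minimal sketch (function names are mine): a Python `if` on a traced value breaks under jit, while `jnp.where` and `jax.lax.scan` keep everything inside one compiled graph.

```python
import jax
import jax.numpy as jnp

# A Python `if` on a traced value fails inside jit (the comparison yields an
# abstract tracer, not a concrete bool); jnp.where expresses the same logic
# as data flow and stays in one compiled graph.

@jax.jit
def clip_python_if(x, lo, hi):
    if x < lo:            # raises a tracer/concretization error when called
        return lo
    if x > hi:
        return hi
    return x

@jax.jit
def clip_where(x, lo, hi):
    # Same logic as data flow: jit-friendly and works on whole arrays.
    return jnp.where(x < lo, lo, jnp.where(x > hi, hi, x))

x = jnp.linspace(-2.0, 2.0, 5)
print(clip_where(x, -1.0, 1.0))   # [-1. -1.  0.  1.  1.]

# Likewise, a long Python for-loop gets unrolled into a huge graph at trace
# time; jax.lax.scan compiles it as an actual loop.
def step(carry, _):
    return carry * 0.5 + 1.0, None

final, _ = jax.lax.scan(step, jnp.float32(0.0), None, length=1000)
print(final)                      # converges toward 2.0
```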
Interesting. Thanks for your input. I already tried to adhere to the JAX paradigm as laid out in the documentation, so I have a fully static graph.
I would test how much of the total FLOP capability of the hardware you are actually using. Take the first-order terms of your model and estimate how many FLOPs you need per data point (a good guide is 6 * params for training if you mostly have large matrix multiplies and nonlinearity/norm layers), then compare the real-time performance for a given input size against the expected theoretical max performance of the given GPU (e.g. ~1e15 FLOP/s in bfloat16 per H100 or H200 GPU). If you are already over 50%, it is unlikely you can get big gains without very considerable effort, and most likely plain JAX or PyTorch is not sufficient at that point. If you are in the 2–20% range, there is probably some low-hanging fruit left, and the closer you are to using only 1%, the easier it is to see dramatic gains.
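A back-of-the-envelope version of that calculation (every number below is a placeholder; substitute your own parameter count, batch size, and measured step time):

```python
# Back-of-the-envelope MFU (model FLOPs utilization) check. All numbers are
# illustrative placeholders, not measurements.

PARAMS = 125e6                  # trainable parameters
TOKENS_PER_STEP = 512 * 1024    # data points (e.g. tokens) processed per step
STEP_TIME_S = 2.0               # measured wall-clock seconds per training step
PEAK_FLOPS = 1e15               # ~bf16 peak of one H100/H200, as above

# The "6 * params" rule of thumb: ~2 FLOPs per parameter for the forward pass
# plus ~4 for the backward pass, per data point.
flops_per_token = 6 * PARAMS
achieved = flops_per_token * TOKENS_PER_STEP / STEP_TIME_S
utilization = achieved / PEAK_FLOPS

print(f"achieved {achieved:.2e} FLOP/s -> {utilization:.1%} of peak")
# >50%  : big further gains unlikely without considerable effort
# 2-20% : probably some low-hanging fruit left
# ~1%   : dramatic gains are usually available
```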
In other words, the required amount of data scales with the square root of the compute. The square root of 2 ~= 1.414. If you double the compute, you need roughly 1.414 times more data.
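Spelled out for a few multipliers, assuming that square-root relation holds:

```python
# data_needed scales as compute ** 0.5, per the relation above.
for compute_mult in (2, 4, 10, 100):
    print(f"{compute_mult:>4}x compute -> {compute_mult ** 0.5:.2f}x data")
# 2x -> 1.41x, 4x -> 2.00x, 10x -> 3.16x, 100x -> 10.00x
```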