Slaying OOMs with PyTorch FSDP and torchao

fine-tuning
llm-conf-2024
Published

June 11, 2024

Abstract

Have you ever hit an OOM (and wished you had more VRAM)? If you’ve done much fine-tuning, then you have. And if you are just starting, then you will. Hop on the bus with us and feel the road become smoother as we talk about stacking together techniques like FSDP2 + QLoRA + CPU Offloading + Fused Adam (thanks Intel) + more in PyTorch native.

Chapters

00:00 Introduction

Mark introduces the session on addressing Out of Memory (OOM) errors in PyTorch, discussing tools and techniques to handle these issues more effectively.

00:30 Traditional Solutions to OOMs

Mark describes conventional methods of dealing with OOMs, such as reducing batch size or model size, and the limitations of these approaches.

00:48 VRAM Constraints

Mark explains VRAM constraints on different GPUs and how it impacts model training, emphasizing the perpetual challenge of being VRAM-starved.

01:24 Estimating VRAM Requirements for Your Model

Mark outlines the components involved in estimating a model’s memory usage, including parameters, gradients, and optimizer states.

06:06 Quantization Techniques

Mark introduces quantization techniques, such as 4-bit quantization, to reduce model size and memory requirements. He also demonstrates using torch.compile to generate efficient quantization kernels, avoiding the complexity of writing custom CUDA kernels.

09:27 LoRA

Mark introduces the LoRA technique for updating a subset of parameters to save memory.

09:56 QLoRA Algorithm

Mark details the QLoRA algorithm, combining quantized parameters with selective parameter updates to enable efficient fine-tuning.

10:51 Implementing QLoRA with PyTorch

Mark discusses implementing QLoRA with PyTorch, highlighting the complexity of writing efficient kernels and the benefits of using torch.compile.

14:38 Introducing Jane’s Section on Model Parallelism

Mark hands over to Jane to discuss parallelism techniques and how to manage memory across multiple devices.

15:20 Understanding Memory Allocation During Training

Jane illustrates memory allocation during training, showing the impact of activations, gradients, and optimizer states. Jane also explains data parallelism and model sharding as techniques to distribute memory load across multiple GPUs.

17:45 Fully Sharded Data Parallel (FSDP)

Jane introduces Fully Sharded Data Parallel (FSDP) and its mechanism to manage memory efficiently by sharding model parameters.

21:49 CPU Offloading

Jane discusses CPU offloading as a method to handle memory constraints by temporarily storing parameters on the CPU during training.

23:05 Challenges and Improvements in FSDP

Jane outlines the challenges with FSDP1 and introduces FSDP2, which offers more flexibility and efficiency in managing memory and data types.

29:50 Identifying and Addressing Performance Gaps

Jane discusses the process of identifying performance gaps in FSDP2 and the steps taken to optimize and match the performance of FSDP1. Jane discusses benchmarking and profiling techniques that are helpful in debugging performance.

37:06 Overcoming Debugging Challenges

Jane shares insights from debugging and optimizing the performance of FSDP2, highlighting the importance of detailed trace analysis. She also explains the impact of wrapping policy on memory usage.

47:38 How You Can Get Started

Jane encourages students to try this process themselves in torchtune.

Full Transcript


[0:00] Mark: Hi, everyone. We’re here to talk about slaying OOMs. OOMs are maybe the most notorious and one of the most annoying bugs to deal with in PyTorch code. Both Jane and I are developers on PyTorch Core at Meta. We wanted to talk to you about a lot of the tools we’ve been building to make this process a bit easier to deal with. So traditionally, the way I’ve seen people deal with OOMs is this way, which is basically people see an OOM, they’re like, OK, crap, what do I do?
[0:30] Mark: And then the sort of two nuclear options are you either reduce your batch size or you reduce the size of your model. You go, oh, this thing’s too big. Let me just have something smaller. But this is a very coarse tool. And there’s certainly a lot more fine-grained things you could do with a bit more knowledge. So the first constraint you have is essentially like how much VRAM you have on your chip. So for example, for 3090s and 4090s, which are very popular consumer cards, you have about 24 gigs.
[0:59] Mark: And then for the A100s, you have like either 40 or 80. I think the H100 is like about 100, if I recall correctly. But my point is, is that like you’re always going to be VRAM starved. And specifically for consumer cards, you know, if I were to speculate, like I would speculate that the 5090 probably also has like around 24 gigs of VRAM. And so we’re always going to be in the VRAM-constrained environment.
[1:24] Mark: But again, as we’re thinking about memory for a model, instead of just thinking about, oh, it’s like something is blah gigabytes, let’s be a bit more concrete about what’s involved in estimating the size of a model. So you have like three core buckets. So for example, let’s say you say, oh, I’m downloading Llama 7B. 7B is the number of parameters. And if each of those parameters is in FP16, then you need two bytes per parameter, which means that the total size of the model is about like 14 gigs.
[1:55] Mark: And indeed, if you download like Llama 7B to your disk, like that’s roughly how big it’s going to be. But that’s not enough. Like, if that was enough, you know, you could just like run Llama 7B on a 3090, but you can’t. And how come, right? So the reason why is like, well, like if you’re doing fine tuning or training, you also need to basically per parameter, you have gradients and your gradients will also be in FP16. And so you basically end up with another 14 gigs.
[2:22] Mark: And then finally, like a detail everyone seems to forget all the time is also the size of your optimizer. So the single most popular optimizer used in the world is Adam. And the amount of memory that Adam takes is twice the amount of parameters. So basically, if your parameters are 14 gigs, Adam takes 28. So if you sum all of these up, you get like 14 plus 14 plus 28, which is 56 gigs, which is bigger than 40 gigs, so bigger than most GPUs that people have.
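
The three buckets Mark adds up can be checked with a few lines of arithmetic. This is only a rough sketch: the function name is made up for illustration, gigabytes are treated as 10^9 bytes to match the round numbers in the talk, and activations are left out entirely.

```python
def full_finetune_memory_gb(num_params, bytes_per_param=2):
    """Rough VRAM estimate for a full fine-tune with Adam, ignoring activations."""
    gb = 1e9  # decimal gigabytes, matching the round numbers in the talk
    params = num_params * bytes_per_param / gb
    grads = params          # one gradient per parameter, same dtype
    optimizer = 2 * params  # Adam keeps two state values per parameter
    return {
        "params": params,
        "grads": grads,
        "optimizer": optimizer,
        "total": params + grads + optimizer,
    }


estimate = full_finetune_memory_gb(7e9)  # a 7B-parameter model in FP16
print(estimate)
```

For 7 billion FP16 parameters this gives 14 + 14 + 28 = 56 GB before any activations are counted.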
[2:46] Mark: So this is sort of the traditional, what people would refer to as a full fine tune. I also sort of neglected to mention activations. So activations are basically the inter… Let’s say you’re running… For example, you have your weights. You multiply them by a certain input, so Wx. The output of that is your activations. Activations tend to dominate the VRAM of your model at larger batch sizes and context lengths. That’s why optimizations like Flash Attention are really important. I’m not going to talk too much about activations in the beginning.
[3:21] Mark: But they’re a bit harder to estimate. They don’t have as clean of a formula as the gradients and optimizers do. But there are tools that can help you estimate it, which are a bit better than just doing math that I found can be a bit finicky and error-prone. Anyway, so the first thing you might think, like, looking at this picture, you’re like, okay, why the heck does Adam take this much memory? Like I’m going to instead use another optimizer, maybe like SGD, which has no memory overhead. Sorry, can I ask a question real quick?
[3:49] Mark: Is there a rule of thumb about how much, like, you know, for the activations that you kind of try to give yourself headroom for? Um, I usually always estimate it. Like, Jane, have you found a good heuristic for it?
[4:07] Jane: Well, for activations, it usually corresponds to your batch size. Estimating it is doable for things like Llama or transformers, where you could literally sit down, do some math, and figure out the sizes of everything. But otherwise, for other models, there’s no straight-up formula like, oh, Adam is 2x params, that type of stuff.
[4:28] Mark: Yeah, so for what it’s worth, though, the reason I showed this picture is if you sort of, as you slowly, as the batch size gets to about 1,000 or the context length gets to about 3,000 or so, 99% of the memory overhead of your model is going to be activations. So I just think of it more mentally as if I go to the large batch size and context length, this is the bottleneck. Otherwise, it’s not.
[4:52] Mark: So again, like back to Adam, like you might think like, hey, like there’s always a new optimizer that comes out and people say, oh, like there’s this new fancy optimizer. It’s more memory efficient. The problem with a lot of that work is that like, Adam is basically so used and so popular because it works. And there’s like close to a decade worth of like papers and people showing anecdotal results that it’s like a great optimizer.
[5:19] Mark: And that hasn’t been true for a lot of like newer optimizers that have been introduced. So conceptually, you might think this is the bottleneck, but replacing Adam ends up being a very poor and very, very challenging first step. OK, so let’s instead take a look at the parameters. So like we said, basically, with Llama 7B, we have the parameters, and we have 14 gigs at FP16. And you might have heard of 4-bit quantization.
[5:44] Mark: And so what we’re going to do is we’re going to basically take every one of those weights and turn them into int4. For an int4, it’s actually a sub-byte dtype. So basically, every int4 is half a byte. And so roughly, you get a model that’s like about 3.5 gigs this way. So yeah, great. So we found basically a good approach to dealing with the parameters. Conceptually, the way this works, it’s not like a .to() call.
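
To make "every int4 is half a byte" concrete, here is a minimal sketch of packing two 4-bit values into one byte. It is plain Python for illustration only; the helper names are hypothetical and real libraries do this on tensors with their own layouts.

```python
def pack_int4(values):
    """Pack unsigned 4-bit values (0..15) two per byte, low nibble first."""
    values = list(values)
    if len(values) % 2:
        values.append(0)  # pad to an even count so every byte is full
    packed = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        assert 0 <= lo < 16 and 0 <= hi < 16, "values must fit in 4 bits"
        packed.append(lo | (hi << 4))
    return bytes(packed)


def unpack_int4(packed, count):
    """Recover the first `count` 4-bit values from packed bytes."""
    values = []
    for byte in packed:
        values.append(byte & 0x0F)  # low nibble
        values.append(byte >> 4)    # high nibble
    return values[:count]


weights = [3, 15, 0, 7, 9, 1]
packed = pack_int4(weights)
restored = unpack_int4(packed, len(weights))
print(len(weights), "values stored in", len(packed), "bytes")
```

Six values fit in three bytes, which is the same ratio that takes 7 billion parameters down to roughly 3.5 GB.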
[6:12] Mark: The way a lot of quantization kernels look like, basically, if you wanted to, for example, cast FP32 to an int8, generally the formulas look very similar to this, which is you basically go over all of the elements, every element of your vector or your tensor, you find what the maximum value is, and then you use that to figure out how to correctly scale the values in the new int8 domain. So the formulas for this end up looking very much like this.
[6:40] Mark: And basically, quantizing is going from FP32 to int8, and then dequantizing is going from int8 back to FP32. So this terminology is very common in the community. Great. So you might think now, well, how do we make sure that this process is fast? And historically, you would rely on people writing custom CUDA kernels for this. But those are hard to write.
[7:02] Mark: And so as a good workaround, like what I’ve been using a lot has been torch.compile in this case, where I just made a couple of simple changes to this code, where I decorated the functions with torch.compile. And then I also added this environment. Okay, I see a question by Roy. You have to unmute yourself. All right. Maybe I’ll keep going then. So what we’ve done here is we’ve just decorated these functions with torch.compile. And there’s also this very handy environment variable, TORCH_LOGS=output_code, that I use all the time.
[7:43] Mark: And if you enable this, you can actually code generate an efficient Triton kernel that basically, so in this case, this is how the quantization code would work. So the point is that you don’t really need to learn how to write CUDA or Triton kernels in this case. So you can use it as a starting point and get efficient kernels out of it. So this doesn’t require a lot of domain expertise. See, Andres also has a question. All right, people, can they unmute? Maybe that’s the problem.
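
The slide with the formulas is not part of the transcript, so the following is a sketch of one common scheme that fits the description (find the maximum absolute value, use it to scale into the int8 range), not necessarily the exact code shown in the talk. It uses plain Python lists; the function names are illustrative.

```python
def quantize_absmax(values, bits=8):
    """Symmetric absmax quantization: map floats to signed integers."""
    qmax = 2 ** (bits - 1) - 1  # 127 for int8
    absmax = max(abs(v) for v in values)
    scale = absmax / qmax if absmax > 0 else 1.0
    quantized = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Map the integers back to approximate floats."""
    return [q * scale for q in quantized]


x = [0.1, -0.5, 2.0, -1.27]
q, scale = quantize_absmax(x)
x_hat = dequantize(q, scale)
print(q, scale, x_hat)
```

The largest element maps to 127 and every other element is rounded to the nearest multiple of the scale, so the round-trip error per element is at most half the scale.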
[8:16] Jane: There’s also questions in the Q&A. I can read some of them if you want to address them now, Mark.
[8:21] Mark: Yeah, let’s do that. Thank you.
[8:22] Jane: Okay, the first one says, how do these calculations for memory also apply to machines with multiple GPUs to get to 24 gigabytes, like two 12-gigabyte cards?
[8:31] Mark: Two 12s. I see. So regarding for like if you hypothetically, you know, if NVIDIA released like a hypothetical GPU with like 200 gigs of VRAM, like the same sort of calculations would help. But like for multiple GPUs, I think like Jane’s going to be talking a lot more about how this works with like sharding, model sharding. So I’m not going to be covering that right now. Oh, this is going to come in a few minutes. Cool. Cool. So let’s go back to gradients. So remember we had the parameters, then we have the gradients.
[9:01] Mark: With the gradients, it’s again, like you need a gradient per model parameter. So let’s say we quantize the gradients to four bits. Like, would this work? And the answer is, anecdotally, it simply does not. Like, your model will not converge because there’s just, like, not enough information in the backwards pass. So this is just, like, scientifically, if you were to just run convergence studies, you wouldn’t get anywhere doing this. So that’s how we know this is not, like, very fruitful. Okay. But there’s other tricks.
[9:29] Mark: Like, basically, the main trick to make this work is LoRA. And what LoRA is, is basically, I think a lot of people might have already talked about LoRA here. But the core idea is that instead of updating the weights for all of your parameters, you pick a subset of the parameters for which you update the weights. And we call these like adapters. And so the idea now is that instead of basically quantizing the gradients directly, we make it so that it’s only a very small percentage of the parameters that actually need gradients.
[10:04] Mark: And so this lets us shave off, like, let’s say in the QLoRA paper, only like 2% of the total model parameters are trainable and will thus have like gradients and optimizer states associated with them. So basically doing that, plus quantizing your parameters to 4-bit gets you QLoRA. So this is exactly what the QLoRA algorithm is at a high level. So great. So we got this working. Anecdotally also, because a lot of this stuff is scientific in the sense that we know that QLoRA works because everyone uses it and says it works.
[10:38] Mark: And last year I helped host a NeurIPS competition about fine tuning. There was no entry that did not use QLoRA. It was by far 99% of the meta and all of the winners used QLoRA. However, implementing QLoRA is kind of tricky. Basically QLoRA was mostly implemented by Tim Dettmers. If you look at the file, it’s about 4,000 lines of CUDA code. Some people here may know me by the Discord group CUDA MODE.
[11:06] Mark: The way this term came about was Tim Dettmers was describing his process for writing CUDA kernels and it was basically he sits down in a room, no music, no lights, nothing, and he just wrote the kernels in a single night. I personally do not have the ability to write such kernels in a single night. So for me, what has been much more accessible is basically writing those kernels in pure PyTorch and compiling them.
[11:29] Mark: And one of the reasons why this file is very long, by the way, is I gave you this nice picture of QLoRA, but the algorithm has a lot more details. Basically the weights aren’t in int4, they’re in a format called NF4, which basically mimics the normal distribution of the weights better. You also can’t just matrix multiply NF4 tensors, so you need to upcast them to BF16 and then do a matmul. Remember when I said that it’s very important to figure out what the max is in the original tensor?
[11:55] Mark: Well, this makes you very sensitive to outliers, and that’s why people do it in blocks. Then QLoRA also has basically scales per block, but then it also quantizes the scales in what they call double quantization. And so it’s just like a lot. Like basically it’s just a lot of math that you need to write to be productive. And the alternative we have now at this point is like basically this kernel was just written by Driss Guessous, also on the PyTorch team.
[12:21] Mark: And what he did was in basically about 900 lines of Python code, got the NF4 tensor from QLoRA working for FSDP without writing any custom code. So this is all pure Python. So you can see, for example, here where it says double quantization, the full algorithm is, okay, well, we have these NF4 tensors. Then we’re doing the double quantization, and we’re doing a normalization, and then we return. And this is all Pythonic code, so you can add breakpoints, you can read it, you can understand it.
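
A back-of-the-envelope version of why this combination fits in so much less memory is sketched below. The assumptions are loud ones: the 2% of trainable parameters is treated as extra adapter weights stored in 16-bit, the frozen base is counted at exactly 4 bits per parameter, and quantization scales and activations are ignored. The function name is hypothetical.

```python
def qlora_memory_gb(num_params, trainable_fraction=0.02):
    """Rough QLoRA estimate: 4-bit frozen base plus 16-bit trainable adapters."""
    gb = 1e9
    base = num_params * 0.5 / gb                  # frozen weights, half a byte each
    trainable = num_params * trainable_fraction   # assumed adapter parameter count
    adapters = trainable * 2 / gb                 # adapter weights in 16-bit
    grads = adapters                              # gradients only for the adapters
    optimizer = 2 * adapters                      # Adam state only for the adapters
    return {
        "base": base,
        "adapters": adapters,
        "grads": grads,
        "optimizer": optimizer,
        "total": base + adapters + grads + optimizer,
    }


qlora = qlora_memory_gb(7e9)
print(qlora)
```

Under these assumptions a 7B model needs roughly 4.6 GB for weights, gradients, and optimizer state, compared with 56 GB for the full fine-tune.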
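
The outlier sensitivity that motivates block-wise scaling is easy to see on a toy example. The sketch below uses a plain symmetric 4-bit integer grid rather than NF4, and it skips double quantization; the helper names and the sample values are made up for illustration.

```python
def quantize_block(block, qmax=7):
    """Quantize one block to signed 4-bit codes with its own absmax scale."""
    absmax = max(abs(v) for v in block) or 1.0
    scale = absmax / qmax
    return [round(v / scale) for v in block], scale


def blockwise_roundtrip(values, block_size):
    """Quantize and dequantize `values` in independent blocks."""
    out = []
    for start in range(0, len(values), block_size):
        codes, scale = quantize_block(values[start:start + block_size])
        out.extend(code * scale for code in codes)
    return out


def max_error(a, b):
    return max(abs(x - y) for x, y in zip(a, b))


values = [0.1, -0.2, 0.3, -0.4, 0.2, -0.1, 0.4, 100.0]  # one large outlier
whole = blockwise_roundtrip(values, block_size=len(values))  # single scale
blocked = blockwise_roundtrip(values, block_size=4)          # scale per block

print(max_error(values[:4], whole[:4]), max_error(values[:4], blocked[:4]))
```

With a single scale the outlier flattens every small value to zero; with blocks of four, the first block keeps its own scale and its values survive the round trip with a much smaller error.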
[12:51] Mark: So again, this is not a great intro quantization algorithm to understand. The core idea is really covered here. But if you understand this, you’re well on your way to understanding more complex kernels. So make sure to just go poke around at that when you have some free time. So the other thing, though, is that within PyTorch itself, PyTorch does not have a concept of an NF4 tensor. PyTorch goes down to int8, it has FP16, and it recently has FP8, and that’s it. But we’ve seen a lot of people experiment with lower dtype kernels.
[13:27] Mark: They’re just not doing it with PyTorch. Today there’s actually a way of creating like native data types. And so this is using some modern machinery called tensor subclasses, which is probably the feature that PyTorch devs are the most excited about internally. And what this does is you can basically override PyTorch ops such that, for example, with NF4, the way you do a matrix multiplication over an input is you cast the weight from basically NF4 or int4 to basically the dtype of the input, which is FP16.
[13:58] Mark: And then you do like a matrix multiplication. You can customize all of this logic using like operations using tensor subclasses in this way. And most notably, you can also define like what the semantics should be for how this data type should be distributed over multiple devices. So if you’re getting QLoRA not composing with any sort of PyTorch subsystem, generally subclasses is one way of modifying the behavior of PyTorch purely in Python without being a developer on the PyTorch team.
[14:28] Mark: So yeah, I know this is a lot, but yeah, I guess now I’m going to hand it off to Jane, who’s going to talk a lot more about how we got this working over multiple devices. So let me, I’ll stop sharing my screen.
[14:38] Jane: Cool. There’s also a question in the Q&A you can answer in the meantime about fine-tuning full parameter Llama 3 8B on 24 gigabytes. So the question is, so no way to do that, basically?
[14:51] Mark: So there’s no way to fine-tune, like, yeah, floating point Llama 3 8B on 24. Yes, that’s going to be very challenging. Fundamentally, you would need a smaller model, and it’s kind of why QLoRA has become such a dominant algorithm to work with this size.
[15:05] Jane: Hi, I’m sharing my screen. Well, I’m trying to share my screen. There it is. There it is. Okay, can people see it?
[15:16] Mark: Not yet.
[15:17] Jane: How about now?
[15:18] Mark: There we go. There we go.
[15:20] Jane: Okay. So following up with your questions, like, oh, dang, if we only have one GPU, it might not be enough to fit all the things we want to do. So on the left here, I’ve done a little illustration of what memory looks like when you’re training. So it goes left to right. So in your forward, as you’re training, you gain activations that you’re saving for the backward, and in the backward, you start using them. So they go down. But then in the backward, you’re building up your gradients for your parameters. So the grads also go up.
[15:48] Jane: And at one point, the activations are all gone and the grads need to be there for the optimizer step. So I note that optimizer step, the state is an AdamW state and it is about 2X bigger than the params. I like measured it. So that’s the left side. And just huge disclaimer, there’s more in your model. When you’re doing math, you sometimes need scratch space. So there are intermediate tensors that are not in this illustration, but they will not matter as much. Okay.
[16:19] Jane: If people have questions on that already, please ask now because you will be seeing this little boy a lot. No questions? Great. Okay. So let’s say it’s too tall. At the peak here, you see it’s just taller than your GPU. And GPU is sad, but it’s okay. GPU has a twin. So now the question is to you. What happens? What would you do if you had two GPUs? How would you fit it within that? Like, it’s clearly dividable. So what’s the way to do it? Do people have answers? Has anyone commented on the Discord?
[16:53] Jane: I can’t see the Discord right now.
[16:57] Mark: No answer in the Discord yet.
[16:59] Jane: Wow. Okay. I will answer my own question. So, oh, I guess it wasn’t so obvious after all, but we’ll start with just parallelizing your data. So as mentioned earlier, parallelizing your data will cut down on the batch size, which contributes to this activation memory. So like your params are the same, everything else is the same because they relate to the params like Mark said. But when you slash your activations in half, you get that peak to be lower on each device. Note that everything else is replicated. But let’s make this harder.
[17:29] Jane: Let’s say you have less memory than that. And even after doing that data parallelism, it was still OOMing. It’s still out of memory. What else can you do? And here you can go do what Mark was mentioning, but not with quantization. You can shard your parameters in half. You can be like, I want half of my params to live on the first one and half to live on the other. And since the grads and AdamW state correspond with the params, each of them also becomes halved.
[17:59] Jane: So now you’re like, wow, I’m doing great. This is awesome. I now can go on with my life. But note that this is not the entire picture. Because there’s some complications when you do sharding across anything. Because when you shard anything, you kind of, at some point, you’re going to need to bring them back. So imagine you’re doing your model training at this current moment. You’re running on GPU zero. You’re doing your first linear and you’re like, oh crap, I only have half of my parameters.
[18:27] Jane: The other half of me is in that GPU over there. What am I going to do? Well, what you’re going to do is you’re going to be like, yo. Other GPU, we got to talk, we got to exchange some parameters and you will do that. And so what that really looks like is for every step, every layer you run through, you’re going to need a little more memory that’s just representing the layer that’s currently getting processed. And FSDP will do this in a way. So, yeah, this technique, it’s called fully sharded data parallel.
[18:58] Jane: You don’t need to remember that. It’s okay. We’re talking about FSDP the whole time anyway. But, like, it will bring in the memory you need to do, like, a linear, like a matmul. And once it’s done, it’ll be like, oh, I don’t want this anymore. And then it will put that back and drop it. And it will keep doing that to make sure that you don’t peak too much.
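
The gather, compute, drop cycle Jane describes can be mimicked in a few lines. This is a toy simulation on Python lists, not real FSDP: the two "GPUs" are just list slices, the all-gather is a concatenation, and the helper names are invented for the sketch. It also assumes the number of rows divides evenly by the number of ranks.

```python
def shard(weight_rows, world_size):
    """Split the rows of a weight matrix evenly across ranks."""
    per_rank = len(weight_rows) // world_size
    return [weight_rows[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    return [row for shard_rows in shards for row in shard_rows]


def linear(weight, x):
    """Plain matrix-vector product, one output per weight row."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


weight = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 1.0]]
shards = shard(weight, world_size=2)  # each "GPU" permanently stores 2 of 4 rows

full = all_gather(shards)    # temporary full copy, only for this layer
y = linear(full, [3.0, 4.0])
del full                     # dropped after use; only the local shard stays

print(y)
```

The extra memory is only the size of the one layer being processed, and it is released before the next layer is gathered.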
[19:17] Mark: All right.
[19:19] Jane: But you’re like, how do we know that it’s small? Well, you don’t. Well, you do. Because you’re the user. And you get to determine what gets considered as a layer in FSDP. So this tree thing might look a little scary, but this is Llama 2. Llama 2 is a transformer decoder that kind of branches into a bunch of things, including 32 transformer decoder layers. And then those branch into a bunch of things, including attention.
[19:47] Jane: And then if you do LoRA, then your LoRA has linears and it just keeps going, dot, dot, dot, dot, dot, dot. But it’s a big tree. And how FSDP wraps things determines what gets brought in when you need something. So if that’s a little confusing, it’s okay. But you can think of each of these blobs, like this green blob is one layer, this big…
[20:08] Jane: blue blob is another layer. FSDP would wrap, in this case, if you’re specifying linear, then it will be like, okay, I found a linear, that’s my one layer, and then it’ll find another linear, and that’s one layer, and it will kind of iterate from bottom up to do that. So in this specific wrapping policy, the linears are wrapped and the transformer decoder layers are wrapped. And then everything else gets wrapped. So each blob is its own layer. And if you’re like, Jane, why are you telling me about this? This is so confusing.
[20:40] Jane: I am now lost. Don’t worry. So the big key point here is that the blobs correspond to how big this little orange dotted box is going to be. So the bigger the blobs, the more memory you’re going to need to bring in at a time. So the bigger that box is going to be. So the way you wrap can really influence the memory you use. Okay, pausing here. Questions, comments? Okay. Next.
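
The link between blob size and the temporary memory box can be put into numbers. This is a rough sketch under a simplifying assumption: peak parameter memory per GPU is the resident shard plus the largest unit that is all-gathered at once. The function name and the 400 MB layer size are hypothetical, and prefetching of the next unit, gradients, and activations are ignored.

```python
def peak_param_memory_mb(unit_sizes_mb, world_size):
    """Resident sharded parameters plus the largest unit gathered at one time."""
    resident = sum(unit_sizes_mb) / world_size  # every rank keeps its shard
    gathered = max(unit_sizes_mb)               # the biggest blob, fully materialized
    return resident + gathered


layers_mb = [400] * 32  # 32 decoder layers of a made-up size

per_layer_wrap = peak_param_memory_mb(layers_mb, world_size=2)
whole_model_wrap = peak_param_memory_mb([sum(layers_mb)], world_size=2)

print(per_layer_wrap, whole_model_wrap)
```

Wrapping each decoder layer separately peaks at 6,800 MB in this toy setup, while wrapping the whole model as one blob peaks at 19,200 MB, which is more than not sharding at all.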
[21:11] Mark: We’re actually getting a question around MOEs, but yeah, it’s like, how does model tensor parallelism work for MOEs?
[21:19] Jane: You are getting very ahead. But yeah, so you’re right, there’s a lot more ways to parallelize, but today we’re going to focus on FSDP. And the cool thing about FSDP2, which we’re going to introduce, is that it will handle that tensor parallelism more easily than today’s FSDP. We’ll get into that soon. Okay. So let’s say after you do all of that, you still, like, what can you do now? What else is there? What is left? And that’s where CPU offloading comes in.
[21:51] Jane: And it’s nice because CPU is like your little brother on the side who’s, like, not doing anything as you’re, like, trying to beef up stuff. But it can hold things. So you can make it hold your parameters as you are iterating. So in this case, you’re just like, let’s just put the parameters in CPU, and when we need it, we will move what we need to GPU, do the all gather, do that sharing of knowledge, sharing of data and parameters, and then move on with our merry lives.
[22:17] Jane: And with that, with CPU offloading, you get your memory to be much, much smaller because the biggest chunks here are now living in CPU. So note that we really want the GPU. Like GPUs are accelerators. They’re meant to do like beefy work. And your forward and backward are the beefy work in a model usually. So for the optimizer step, people are like, it’s fine. We don’t need the GPU for that.
[22:43] Jane: And in this case, we’ll only bring in the parameters for forward backward and be like, OK, we can put that on CPU now, which means the optimizer step runs on CPU. And you also save a lot of space on your GPU if you host your whole Adam state there. Ah, so there’s that. Okay, so you might be wondering, Jane, FSDP has been published for like a long time now. Why are you explaining this to us? What is the point? People use it. In fact, that’s true.
[23:13] Jane: People like Answer.AI, who are wonderful, they already built out an integration for FSDP and bitsandbytes Params4bit to make QLoRA happen. But it’s kind of annoying to work with FSDP1. They had to do a lot. And we came out with per-parameter FSDP, which I will also call FSDP2 for later. And what is that? So let’s start with the status quo. Like what is it today? Let’s say you, just for our toy example, you have three tensors that you need to distribute across your two GPUs. And they are these shapes.
[23:49] Jane: So the goal is that you want to make that exchange of, you know, when you’re talking to the other GPU, that thing efficient. And NCCL, which is the software and driver stuff that does it, requires that each GPU will give the same tensor size. So that’s the constraint. What does FSDP1 do today? Okay, what it does is it flattens all your tensors. And this is what they look like in memory. So flattening is actually pretty chill. And then it’s going to line them up in a line.
[24:20] Jane: And then it’s just going to slice it in the middle. And if it’s not even, it will add a padding at the end. And then, because now it’s just arbitrarily sliced in the middle, it will be like, alright, this half goes to 1, this half goes to 0, oh, I guess 0 and 1. And so you’ll end up with something like this, where tensor 1 and I guess a little more than a third of T2 will live on GPU 0, and then the other half of this, including the padding, will live on GPU 1.
[24:51] Jane: And this is nice, but note that the way this is implemented today is that T1 and half of T2 is going to get smooshed into one tensor, which is a big con. We’ll get into that later. And same thing with T2, T3, and the padding. That gets moved into one tensor. And we call that tensor a flat parameter because it’s so flat.
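
The flatten, line up, pad, and slice procedure can be traced on a toy example. The slide with the actual shapes is not in the transcript, so the shapes below are hypothetical (only the 3 by 3 shape of T2 is mentioned in the talk), and the function is an illustration of the idea rather than FSDP1’s implementation.

```python
from collections import Counter
from math import prod


def flat_param_shards(shapes, world_size):
    """Flatten all tensors into one buffer, pad it, and cut it into equal chunks."""
    flat = []
    for name, shape in shapes.items():
        flat.extend([name] * prod(shape))  # label each element with its source tensor
    padding = (-len(flat)) % world_size    # pad so the buffer divides evenly
    flat.extend(["pad"] * padding)
    chunk = len(flat) // world_size
    return [Counter(flat[r * chunk:(r + 1) * chunk]) for r in range(world_size)]


shapes = {"T1": (2, 2), "T2": (3, 3), "T3": (2, 1)}  # hypothetical shapes
flat_shards = flat_param_shards(shapes, world_size=2)

for rank, contents in enumerate(flat_shards):
    print(rank, dict(contents))
```

With these shapes, rank 0 holds all of T1 plus four of T2’s nine elements, and rank 1 holds the rest of T2, all of T3, and one element of padding. Each rank’s chunk is a single buffer, so the tensor boundaries fall wherever the cut happens to land.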
[25:15] Jane: And some reasons why you might already be thinking, hmm, this might not be a good idea after all, is the fact that this forces T1 and T2 and T3 to all have the same D type, to have the same requires gradness, and all the other metadata you might want your tensors to have. Okay. Well, what if we tried down a different path of dividing things in two? So what we’re going to do is we’re going to slice every tensor first. We’re going to cut T1 in half, T2 in half, and T3 in half. Great.
[25:46] Jane: Except we notice that T2 needs some padding because it’s 3 by 3. You can’t really cut that in half. We’ll do it. It’s fine. We’ll shard. And that way, what this will look like is every tensor gets its own representation on this GPU. And this is great. The main pro is that they keep their identity. Like if T1 was UN8 and T2 were BF16, totally fine. They can stay that way. But in the previous case, in FSDP1, you wouldn’t even be able to put them together.
[26:18] Jane: That’s just like, or you’d have to hack around it a lot. And this gets into the QLORA stuff soon. Very soon as well. Okay. There is a con to this. Because of all the slicing and rearranging you have to do, there are extra copies to FSDB2. So that’s a con, but the pros it gives are so much more worth it. And just a recap of before, in more clear terms, like a flat parameter would force all three of these tensors to share all the metadata they have that a tensor can have.
[26:51] Jane: And in FSDB2, because they are individual separate tensors, we call them detensors because they’re not like, you know, the full tensor, they’re a distributed tensor, they’re smaller. They can be themselves, they get to keep their identities, they can have their own D type, their own requires grad. And so you’re like, okay, but why? So if you think of quantization, which Mark talked about earlier, what if you wanted T1 to be UN8, T2 to be full size, like FP32, any other size? In the first regime, unacceptable. The second regime, totally fine.
[27:27] Jane: No one is going to bother you. FSTP will do that. Another thing that is very popular nowadays that LoRa is, is you don’t really want to train the big stuff because that will require big grads, big optimizer step. So what if T2 is frozen, you don’t want it to actually train, and T3 is the LoRa adapter that you do want to train?
[27:50] Jane: In that case, In your first world, you still can’t have that because a tensor can only have one requiresGrad, and the flat parameter here will force you to either make it requiresGrad or not requiresGrad, or to force you to do a lot of fancy hacking to make it work. All of these things, all of these concepts that are like, oh, I wish I had this. In FSDP 1 today, it would be difficult. But in FSDP 2, it’s for free.
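
Per-parameter sharding can be sketched the same way: each tensor is cut on its own, so its metadata travels with it. The sketch below assumes sharding along the first dimension with padding rows where the split is uneven; the shapes, dtypes, and helper name are hypothetical, and plain dictionaries stand in for real tensors.

```python
from math import ceil


def shard_dim0(shape, world_size):
    """Rows (and padding rows) each rank holds when one tensor is split on dim 0."""
    rows = shape[0]
    rows_per_rank = ceil(rows / world_size)
    layout = []
    for rank in range(world_size):
        real = max(0, min(rows_per_rank, rows - rank * rows_per_rank))
        layout.append({"rows": real, "padding_rows": rows_per_rank - real})
    return layout


# Hypothetical parameters: a quantized weight, a frozen weight, a trainable adapter.
params = {
    "T1": {"shape": (2, 2), "dtype": "uint8", "requires_grad": False},
    "T2": {"shape": (3, 3), "dtype": "bfloat16", "requires_grad": False},
    "T3": {"shape": (2, 1), "dtype": "bfloat16", "requires_grad": True},
}

sharded = {
    name: {**meta, "layout": shard_dim0(meta["shape"], world_size=2)}
    for name, meta in params.items()
}

for name, info in sharded.items():
    print(name, info["dtype"], info["requires_grad"], info["layout"])
```

The 3 by 3 tensor ends up as two rows on rank 0 and one row plus one padding row on rank 1, while each tensor keeps its own dtype and requires_grad setting.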
[28:16] Jane: And another thing that’s really cool about Fsdp2 that is its own other lecture entirely is memory determinism. So one of the major implementation changes is that now Fsdp2 actually guarantees that you will have only that small little sliver of memory before, like this little orange thing. Whereas Fsdp1 actually didn’t do it well enough and could cause memory spikes that are not deterministic. But yeah, for this one, you should read the blog links here if you want more details. Okay.
[28:53] Jane: So now that we have FSDP2 and we’re like, this should be easier to use, let’s do it. Let’s use it. And Wei did do that. We did do that. So Wei, who’s another dude on the team, wrote this PR here that puts together FSDP2 and NF4. And it works. It works. It’s great. We know, okay, FSDP2 is cleaner, it’s more composable. But the last question remains: can this actually replace FSDP1?
[29:22] Jane: Like we would love to use it, but can you tell us that it is good on perf, that we won’t be slower than before. And so that’s what the next few slides are going to be. All right, pausing here to see if people have questions, thoughts. If not, we’re going to go with the plan. All right. So here’s the plan. The plan is I’m going to go get some GPUs. We’re going to run some benchmarks. And then we’re going to make sure those are the same benchmark.
[29:50] Jane: And then once they are, we’re going to record some gaps and figure out what the gaps are and if we could make them faster. All right. So getting some GPUs, this is the easiest part of the journey. You just go to Vast.ai and you ask for it. Well, first you need to have money, and then you go and you’re like, give me two 3090s or 4090s. And I got two; they have 24 gigabytes of VRAM each. There are some other details here if you care, but they’re not super relevant this time.
[30:19] Jane: Just know that I have two, I have consumer hardware, and they are 24 gigabytes each. So I ran a bunch of benchmarks on Answer.AI’s train.py, which is our baseline, like FSDP1 and bitsandbytes (BNB). That’s our baseline. And I’m using batch size 8 as a baseline, just so that it works. If you’re curious, if you wanted to run this for yourself, the command is right here. Feel free to copy paste that in the future, but you could just pay attention now. I ran the same thing in the TorchTune recipe.
[30:56] Jane: One difference between TorchTune and train.py is that it uses a YAML for the config versus the command line. It’s just different. And I did have to tweak the YAML quite a bit to make sure that I was running the same config. And then these were my results. So peak memory-wise, we were doing better. For runtime, though, we were like 19% slower.
[31:21] Jane: So someone here might be like, FSDP2, we know it’s stricter about memory, we know it requires extra copies, so it makes sense that we’re better at memory and worse at runtime, right? But no, no, no, we’ve got to be diligent. And very quickly, if you look at the traces, they reveal some troubling shenanigans. So, here are the two traces. On the top is the baseline. On the bottom is our new trace. Can you spot the difference? Okay. There are a lot of differences. So, I’ll just go with my favorite one.
[31:54] Jane: I work on optimizers and those little light blue things are optimizer steps. And immediately I was like, dude, the optimizer is taking so much longer. What could that be? And so, this is where I get into the actual tracing stuff. I wonder if I can actually show you the traces. That would be pretty cool. Okay. I’m going to stop sharing to reshare and then we can do that. Let’s just share my entire screen. Okay. Do people see traces? Yep. Okay.
[32:33] Mark: Awesome. So,
[32:34] Jane: I’m going to go ahead and share my screen. So, on the left is our baseline, on the right is the slow boy. So in our baseline, we’re going to go to the second step because the first step is always full of warm-up and initialization stuff. So we’re just going to ignore that and we’re going to go here because every other step after this is much more normalized. And something you can do, I’m using Perfetto. I don’t know if people are familiar with Perfetto already, but it’s been pretty helpful. Yeah. Okay.
[33:12] Jane: And something that is super useful and nice in Perfetto is you can highlight a region and it will give you profile results (wow, I cannot talk today) on here. So here you can see that it tells you what thing takes the longest. It’s like the muls take 77 milliseconds and there are 768 of them. And on this side, when we do that, we’re going to notice some very different numbers. So here, the mul also takes the longest, but there are 1,700 of them compared to 700.
[33:48] Jane: And you might be like, what is going on here? But let’s go look at the smallest number. In the optimizer, there’s only one divide per parameter. You can just take my word for that. So here we know that there are 384 divs, which means that there are 384 parameters. Here, we see that there are 896, which is more than double. And so let’s go find a div. Like, where are they? And here you can just like do do do. But you can already notice that everything kind of gets doubled here. Whereas in here,
[34:25] Jane: they are just called immediately. So this aten::mul goes directly to A10-2. This aten::mul, though, goes to aten::mul again. And you’re like, wait, what is going on? Why is that? And this is where you learn the painful things of tensor subclass dispatch. So since we’re using DTensors, it means that it goes into this mul as a DTensor. And then DTensor is like, all right, let me do my metadata unwrapping. And now you get to go to aten::mul as a plain tensor. So there are double the number of calls.
[34:58] Jane: And just to spare you some math, it turns out that if we divide this mul by two, or the div by two, or any of the things that run once, it shows that we actually are running training on more parameters than we thought. So on the left side, we’re only running it for 384, and on the right side, we’re running 64 more parameters. Can people guess where this came from? I will show you what my war story of debugging has shown me.
[35:37] Jane: In the end, I realized that this was a config difference, where if you are fast at reading, you might have already spotted the difference here. The answer is that, in train.py, they do not LoRA-fy the output projection, whereas in our world, we do do that. And since LoRA-fying means you add two adapters and there are 32 layers, 32 times 2 is 64 extra parameters to train. So yeah, that was bad. And that was a great lesson, because I was not measuring apples to apples.
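The arithmetic behind that discovery can be checked directly. The sketch below uses only numbers quoted in the talk (384 divs in the baseline trace, 896 in the slow trace, 32 layers, two adapters per LoRA-fied linear); the variable names are illustrative, and the factor of two reflects the doubled op entries Jane attributes to DTensor dispatch.

```python
num_layers = 32
adapters_per_lora_linear = 2      # LoRA adds two adapter matrices per wrapped linear
baseline_trainable = 384          # div count in the train.py trace (one div per parameter)

# Extra trainable tensors from LoRA-fying the output projection in every layer.
extra_trainable = num_layers * adapters_per_lora_linear
torchtune_trainable = baseline_trainable + extra_trainable

# Under DTensor dispatch each op shows up twice in the trace:
# once as a DTensor call and once as the plain-tensor call underneath.
divs_seen_in_trace = 2 * torchtune_trainable

print(extra_trainable, torchtune_trainable, divs_seen_in_trace)
```

Halving the 896 divs from the slow trace gives 448, which is exactly 384 plus the 64 extra adapters.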
[36:12] Jane: And I needed to do some other things to make sure that we were. So the first thing is like making sure the data set was the same, making sure that every parameter was the same, changing the wrapping policy to be the same. And after I did all of that, I ran another benchmark. And here, I mean, that was like also weeks of work, by the way, like, like, it was not easy to figure out every little difference and why they were different.
[36:36] Jane: But after all of that, I was like, let me run it again, maybe it will be better. But no, it was still slow. It was actually slower than before. But the peak memory was a lot better. However, I was still happy, because even though it may feel like we took a step back, we actually made a giant leap forward on unblocking the first real step, which is that now that we have apples to apples, there are things that will match. And then we should just look at things that are different.
[37:02] Jane: And so that’s where I could start playing my very long game of spot the difference. The first thing I did, though, was like measure their runtime. And here you can see that the forward, backward and optimizer were like the culprits. And that’s how I knew what to focus on. OK, so this is a big slide. I, okay. If at this point you have not looked at the traces yet, but you would like to, I sent a link to the Google Drive in the Discord, and there are the traces that you want.
[37:38] Jane: There are four of them, and the two you care about looking at now are the ones that don’t start with “bad” or “final”, the other two: the Answer.AI one and the TorchTune one. I already found these gaps for you, but if you find it fun, if you want to learn more about traces, it’d be fun if you could find each of these yourself, like where they are located, and do the same exploration I did.
[38:02] Jane: But because this is a presentation and I do want to save time, we’re going to dive right into what these gaps were and how we ended up fighting them. So the very first gap is, yeah, so we’re going to go. And this is the link, by the way. Okay, the very first gap is that the optimizer step was still slower. But remember how earlier I hinted that there was all this overhead here for DTensor? And the way it works is, because of C++ dispatch, it just takes a long time to do that kernel call.
[38:34] Jane: And because there are so many calls, all of this overhead adds up to make the optimizer step three times slower. And also another minor detail, if you look at this, this is 64 microseconds and this is 81. The reason for that is that the parameter is not contiguous, but that’s more minor. So the solution here is actually really chill. We worked with Intel, we were like, hey, what if we just had a fused Adam? Like your optimizer step, but all happening in one kernel, so you can dispatch just once and get results.
[39:08] Jane: So this avoids all that overhead because there’s just one kernel now, versus 384 times however many ops there are. And it also leverages vectorization. So we go from about one second to 120 milliseconds, which is like an 8x speedup. So that’s one gap that is now gone. All right. Pausing here to see if people have questions.
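To make the dispatch-count argument concrete, here is a minimal sketch in plain Python. `adam_step` is the textbook Adam update on a single scalar, written so that each line roughly corresponds to one or more elementwise ops that an unfused optimizer would dispatch separately for every parameter. The value of `OPS_PER_PARAM` is a hypothetical placeholder, not a number measured from the trace; the 384 parameters and the 1 second to 120 millisecond timings are the figures quoted above.

```python
import math


def adam_step(param, grad, exp_avg, exp_avg_sq, step,
              lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Textbook Adam update for one scalar parameter."""
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad              # mul, add
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad  # mul, addcmul
    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    denom = math.sqrt(exp_avg_sq / bias2) + eps                  # div, sqrt, add
    param = param - lr * (exp_avg / bias1) / denom               # div, addcdiv
    return param, exp_avg, exp_avg_sq


NUM_PARAMS = 384
OPS_PER_PARAM = 10  # hypothetical: the real count depends on the implementation

unfused_dispatches = NUM_PARAMS * OPS_PER_PARAM  # one dispatch per op per parameter
fused_dispatches = 1                             # a single kernel does the whole step

speedup = 1000.0 / 120.0  # about one second down to 120 ms, as quoted in the talk

new_param, _, _ = adam_step(param=0.0, grad=1.0, exp_avg=0.0, exp_avg_sq=0.0, step=1)
print(unfused_dispatches, fused_dispatches, round(speedup, 1), new_param)
```

The math per parameter is tiny, so when every op pays a fixed dispatch cost, that cost dominates; collapsing the step into one kernel removes almost all of it.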
[39:38] Mark: I’ve been speed answering everyone on Discord.
[39:41] Jane: Okay, nice. Okay. I was like, are they lost? But no. Okay.
[39:46] Mark: They’re very happy, actually. People are saying they love this kind of debugging. And yeah, people love it.
[39:51] Jane: Okay. Well, let’s keep going. So the next three are a little boring, but we’re gonna go through. And there was a lot of pain in this one, actually. This second gap was crazy. So I went and the way I literally did this was the most brute force way you can imagine. Like you open your trace, you find the first forward, you find the first GPU kernel, and you’re just like, do they match in size, shape, input, everything? And I would do that for you.
[40:18] Jane: But we are a little short on time, so I’m going to just show you what the difference is here. And the first difference that was major was that the second all-gather in every layer was somehow five milliseconds longer. And that was when I needed to click on them and figure out how big they were. On the left side, for train.py, there were just fewer things getting all-gathered. Like it was just not the same thing. And I was like, why is it not the same thing?
[40:51] Jane: So I did a lot of printing and I hacked around FSDP2. And what that yielded was me writing down the sizes of every tensor that got packed up in the all-gather and realizing that the difference was in our NF4 metadata: Answer.AI did not pack their scalers and quantization factor. For bitsandbytes, when they did their four-bit thing, they used a dictionary. They have their own method.
[41:19] Jane: And the reason they couldn’t pack it is that FSDP1 had restrictions, by the way, like it just wouldn’t do it for them. So they needed to work around that. So that was one piece of the puzzle, where we just packed everything in one go. But that was just like 1 million bytes, whatever. It doesn’t matter. The other, bigger piece of the puzzle, the big, big difference, was so many more. It was like 12 million bytes.
[41:45] Jane: And that was when we realized that when we opted out of LoRA, the output projection did not get quantized in our world. So it was like four times the size of Answer.AI’s version. And I was like, why don’t we do that? And then I talked to the TorchTune team and they’re like, oh yeah, we should do that. And so we should do that. That’s the conclusion. We should do that. The first one is intended behavior. So we don’t really need to worry about that. But the second one, we should do it.
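The “four times the size” figure follows from bit widths alone: BF16 stores 16 bits per element and NF4 stores 4. A small sketch, assuming a hypothetical 4096 by 4096 projection (the shape is not given in the talk) and ignoring the small amount of per-block scale metadata that NF4 also carries:

```python
def weight_bytes(num_elements, bits_per_element):
    """Bytes needed to store the raw elements at the given bit width."""
    return num_elements * bits_per_element // 8


num_elements = 4096 * 4096  # hypothetical projection shape, for illustration only

bf16_bytes = weight_bytes(num_elements, 16)  # unquantized output projection
nf4_bytes = weight_bytes(num_elements, 4)    # same weight stored as 4-bit codes

print(bf16_bytes, nf4_bytes, bf16_bytes // nf4_bytes)
```

Every byte of that difference has to travel through the all-gather, which is why an unquantized projection shows up as a longer collective in the trace.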
[42:13] Jane: And we will hit you up when that happens. So this gap I would mark as yellow. Like we know what to do. It’s solved. Okay, the third gap is realizing that there was just more overhead. And remember what Mark was saying, how when you have NF4, you can’t just put that through your GEMM kernel. CUDA is going to complain. You need to dequantize, get it to BF16, and then put it through the CUDA kernel.
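The dequantize-then-matmul step can be illustrated with a toy block-wise 4-bit scheme in plain Python. This is a sketch of the general idea only, not the torchao or bitsandbytes implementation: the codebook below is a hypothetical set of 16 evenly spaced levels, whereas real NF4 uses 16 levels chosen for normally distributed weights, and the function names are made up for this example.

```python
# Hypothetical codebook: 16 evenly spaced levels in [-1, 1] (not the real NF4 table).
CODEBOOK = [-1.0 + 2.0 * i / 15 for i in range(16)]


def quantize_block(values):
    """Quantize an even-length block to packed 4-bit codes plus one scale."""
    scale = max(abs(v) for v in values) or 1.0
    codes = [
        min(range(16), key=lambda i: abs(CODEBOOK[i] - v / scale))
        for v in values
    ]
    # Pack two 4-bit codes into each byte: high nibble first, then low nibble.
    packed = bytes((codes[i] << 4) | codes[i + 1] for i in range(0, len(codes), 2))
    return packed, scale


def dequantize_block(packed, scale):
    """Unpack the nibbles and look up full-precision values for the matmul."""
    out = []
    for byte in packed:
        out.append(CODEBOOK[byte >> 4] * scale)
        out.append(CODEBOOK[byte & 0x0F] * scale)
    return out


weights = [0.5, -0.25, 0.1, -1.0]
packed, scale = quantize_block(weights)
restored = dequantize_block(packed, scale)
print(len(packed), restored)
```

Each of these unpack, lookup and multiply steps is a separate op if written naively, which is the overhead a hand-written or compiled kernel removes by doing them all in one pass.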
[42:39] Jane: And it turns out that Tim Dettmers, you know, brilliant guy, already wrote a really fast version of that, whereas we just get stuck with the little raggedy version that tries every op. So that’s also where we are slower, just because of additional overhead. But again, this is not a major problem. Solutions: we could use torch.compile. I tried it. It was not trivial, so I was like, I will look into this later when I have a month of time, or when I don’t have a presentation on Monday.
[43:05] Jane: And then the second option is to just use Triton kernels. So Driss, our coworker, already wrote them. I didn’t want to copy paste them for the sake of this presentation. But if you want to copy paste them, go for it. No one’s judging. They’re good. They work. And so we’re like, okay, we also know how to fix this one. The third one was really stupid. This one is definitely the worst gap, where basically there were just very different ops happening before the SDPA. And this was just because we used two different RoPE algorithms.
[43:41] Jane: TorchTune was like, we are traditional. We are going to use the original Meta algorithm. We work there. So there will be no numerical differences. And everybody else in the world is like, oh, but we want it to be faster. So it’s fine. And then the real solution here is just for TorchTune to offer more of these. And that’s also in progress. So, yeah, but okay. Let’s talk about the most fun one. The most fun gap I noticed is, I realize this is a little hard to read.
[44:07] Jane: But on the left side, note that there are two things happening. The first, this green thing here, is the CPU offload, where you go from CPU to GPU. And the second one is, after you’ve moved it to the GPUs, you have them talk to each other. And here in train.py, in FSDP1, we’re like, wait, how come this is overlapped? Like, you see how there’s no gap here. But this one is so badly overlapped. Look at that. Look at that exposition. It’s bad.
[44:36] Jane: And this was really frustrating because this was costing us like 10 milliseconds each step. And I was like, I wonder why this is. But actually, this is very much expected behavior. And this is part of the bigger memory constraints that FSDP2 promises you. So, FSDP2 is promising that, hey, we’re not only looking at all-gathers, we’re also going to make sure that before we bring things from CPU to GPU, you have the space you need. So it is guaranteeing the constraint that only two layers of parameters will be allowed at a time.
[45:12] Jane: And that is why on the left, because of how FSDP1 was implemented, it didn’t do that very well. So you’d get random memory spikes. And with FSDP2, you’re promised to never get that. But someone looking at this trace will be like, but what if I’m okay with non-deterministic memory? Like it’s not happening now. Like maybe I can just go on with my life. But no, no, no, we have an answer for you. And the answer is the reason it’s so exposed is not because FSDP is too strict. That’s not the problem.
[45:41] Jane: The problem is that the overlap of computation and communication was too different. The computation here is super duper tiny because it corresponds to this tiny linear here. And then here when you’re CPU offloading, you’re actually trying to bring this big boy back in. So it’s like the mismatch in the layer sizes was really causing this problem. So what could we do? Well, it’s fine. We’ll just change the wrapping policy, have bigger layers. And it’s really just, hey, don’t wrap those linears by themselves.
[46:16] Jane: Just group them in so we can just have every transformer decoder layer be its own layer. And note that this is only possible with FSDP2. You can’t have the right-hand side in FSDP1. Why? Because the LoRA linears have the base weights. The base weights are quantized. They’re NF4 tensors. They’re going to be small. They’re uint8. Whereas your linear tensors, those are BF16 or whatever training thing you want. And in FSDP1, they can’t be brought together under the same family because they’re going to be forced into one big flat parameter. But in FSDP2, they can coexist.
[46:55] Jane: And the change is actually really easy to do. The policy is just a matter of commenting out these two lines here. And once we do that, the result is that they’re both overlapped. It’s kind of crazy. Look, there’s no more exposition at all, where even in the first case, even before here in the train.py one, this all-gather was exposed, also due to the same reason. That’s not even true at all here. And this wrapping policy, this new one, is only possible because of FSDP2, which is a great point here. All right.
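Why regrouping helps can be seen with a very rough overlap model. The sketch below assumes that the transfer for the next wrapped unit is prefetched while the current unit computes, so any transfer time beyond that compute time is exposed. The function name and every timing are hypothetical and chosen only to show the shape of the effect; they are not taken from the traces.

```python
def exposed_ms(units):
    """Total transfer time not hidden behind compute.

    units: list of (compute_ms, transfer_ms) per wrapped unit, in execution
    order. The transfer for unit i+1 overlaps with the compute of unit i.
    """
    total = 0.0
    for (compute, _), (_, next_transfer) in zip(units, units[1:]):
        total += max(0.0, next_transfer - compute)
    return total


# Hypothetical timings: a tiny linear wrapped on its own, followed by the
# much larger remainder of the decoder layer, repeated for two layers.
separate = [(1.0, 1.0), (10.0, 11.0)] * 2

# Same work with each decoder layer wrapped as a single unit.
grouped = [(11.0, 12.0)] * 2

print(exposed_ms(separate), exposed_ms(grouped))
```

In this toy setup the tiny linear gives the big transfer almost nothing to hide behind, while the grouped layers carry enough compute to cover nearly all of it.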
[47:32] Jane: So now we’ve fought all our gaps and things work. So it’s your turn. You should try them out. It doesn’t have to be NF4. If you have another favorite lower-precision thing, you can try it. If you want to try out more Llama stuff with it, you can. There are now scripts ready to do that. So yeah. One disclaimer, though, we are working on this. This is not a forever thing. Checkpointing of FSDP and QLoRA does not work yet. So that’s just FYI, we’re working on it.
[48:06] Jane: Sorry, you can’t do that. But you can try, you can just try and play with it, among other things. So yeah, here I would love to, I mean, Mark and I are speaking here, but this was really made possible by all these amazing people. So Driss wrote the original NF4 stuff. Andrew is the main dude who designed FSDP2. Wei ended up taking Andrew’s work and making it compose with Driss’s work, so he amalgamated both of them. And then the TorchTune people, so like Rohan and Evan, they were super helpful.
[48:40] Jane: They wrote the LoRA recipes and they’re the foundation on which we built and showcased our work. And of course, Mark, who’s amazing. So thanks, Mark.
[48:52] Mark: Yeah, so I really hope this gives people some more context around what we’re thinking here. We did want to showcase a useful recipe with QLoRA and FSDP composition. But really, our call to action here would be: if you’re doing interesting research at the intersection of quantization and distributed, we’d really, really love to hear from you. So if you’re working with more exotic dtypes or more exotic forms of parallelism, a lot of this work should really, really be helpful.
[49:21] Mark: And we have all these links here that can give you some more context. I guess we’ll pause here if people have any questions. Otherwise, we’ll probably be hanging out on Discord for a couple more hours if people have any questions.