

Jane Xu, Mark Saroufim

**PyTorch Devs** 

# How people deal with OOMs

- Smaller batch size
- Smaller model



# RuntimeError: CUDA out of memory







24 GB of VRAM

40 or 80 GB of VRAM

# Memory crash course

Llama 7B has 7B parameters in fp16

Each parameter is 2 bytes, so the parameters take 14GB

Gradients memory = parameter memory

Adam Optimizer State = 2 \* parameter memory

Total = 14GB + 14GB + 28GB = 56GB
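Here is the same arithmetic as a small helper. This is a sketch that keeps the slide's simplifications: gradients use the same width as the weights, Adam keeps two states per parameter at that same width, activations are ignored, and GB means 10^9 bytes. The function name is ours and does not come from any library.

```python
def training_memory_gb(num_params, bytes_per_param=2):
    """Rough full-finetuning memory (GB = 1e9 bytes), ignoring activations.

    Assumes gradients and both Adam states use the same width as the weights.
    """
    params = num_params * bytes_per_param / 1e9
    grads = params           # one gradient per parameter
    adam_state = 2 * params  # Adam keeps two running moments per parameter
    return params, grads, adam_state, params + grads + adam_state


params, grads, adam_state, total = training_memory_gb(7e9)
print(f"params={params} GB, grads={grads} GB, adam={adam_state} GB, total={total} GB")
```

With `bytes_per_param=0.5`, the first value is 3.5 GB. That is the weights-only int4 figure.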



# Larger batch sizes and context lengths

The bottleneck is almost always activations. That's why Flash Attention is important: it never materializes the full seq_len × seq_len attention matrix.

Paper math is great, but papers don't tell us when we're wrong: https://dev-discuss.pytorch.org/t/how-to-measure-memory-usage-from-your-model-without-running-it/2024



#### Let's optimize the bottleneck!

[1412.6980] Adam: A Method for Stochastic Optimization, by D. P. Kingma, 2014 (arXiv)



OK, let's take a look at the params



14 GB at fp16

3.5 GB at int4: each int4 needs ½ byte\*

#### Hello quantization

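```
import torch

def quantize_tensor(x_fp32):
    # Symmetric absmax quantization: map [-absmax, absmax] onto [-127, 127]
    # (assumes x_fp32 is not all zeros, otherwise c is inf)
    absmax = torch.max(torch.abs(x_fp32))
    c = 127.0 / absmax  # scale factor, kept so we can dequantize later
    x_int8 = torch.round(c * x_fp32).to(torch.int8)
    return x_int8, c

def dequantize_tensor(x_int8, c):
    # Undo the scaling; the rounding error from quantization is not recoverable
    x_fp32 = x_int8.to(torch.float32) / c
    return x_fp32
```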

#### torch.compile

```
import os
os.environ["TORCH_LOGS"] = "output_code"  # print the code Inductor generates
import torch

@torch.compile()
def quantize_tensor(x_fp32):
    absmax = torch.max(torch.abs(x_fp32))
    c = 127.0 / absmax
    x_int8 = torch.round(c * x_fp32).to(torch.int8)
    return x_int8, c

@torch.compile()
def dequantize_tensor(x_int8, c):
    x_fp32 = x_int8.to(torch.float32) / c
    return x_fp32

x_int8, c = quantize_tensor(torch.randn(10, device="cuda"))
x_fp32 = dequantize_tensor(x_int8, c)
```

#### https://github.com/pytorch/ao

Generated Triton kernel for `quantize_tensor`:

    def triton_(in_out_ptr0, in_ptr0, out_ptr0, xnumel, rnumel, XBLOCK : tl.constexpr):
        xnumel = 1
        rnumel = 10
        RBLOCK: tl.constexpr = 16
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rindex = tl.arange(0, RBLOCK)[None, :]
        roffset = 0
        rmask = rindex < rnumel
        r0 = rindex
        tmp0 = tl.load(in_ptr0 + (r0), rmask, other=0.0)
        tmp1 = tl_math.abs(tmp0)
        tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])
        tmp4 = tl.where(rmask, tmp2, float("-inf"))
        tmp5 = triton_helpers.max2(tmp4, 1)[:, None]
        tmp6 = 1 / tmp5
        tmp7 = 127.0
        tmp8 = tmp6 * tmp7
        tmp9 = tmp8 * tmp0
        tmp10 = libdevice.nearbyint(tmp9)
        tmp11 = tmp10.to(tl.int8)
        tl.debug_barrier()
        tl.store(in_out_ptr0 + (tl.full([XBLOCK, 1], 0, tl.int32)), tmp8, None)
        tl.store(out_ptr0 + (tl.broadcast_to(r0, [XBLOCK, RBLOCK])), tmp11, rmask)

# Back to gradients



# Full finetuning vs LoRA



weight += (lora\_B @ lora\_A) \* scaling
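This pure-Python sketch (no PyTorch) shows why LoRA saves so much gradient and optimizer memory. It assumes the usual LoRA shapes: `lora_A` is (rank, d_in) and `lora_B` is (d_out, rank). It also assumes `scaling = alpha / rank`, following the LoRA paper's convention. The function names and the 4096 × 4096 example shape are for illustration only.

```python
def lora_param_counts(d_out, d_in, rank):
    """Trainable parameters for one weight: full finetuning vs. LoRA."""
    full = d_out * d_in                # every entry of the weight is trained
    lora = rank * d_in + d_out * rank  # lora_A (rank x d_in) + lora_B (d_out x rank)
    return full, lora


def merge_lora(weight, lora_A, lora_B, alpha, rank):
    """Return weight + (lora_B @ lora_A) * scaling, using nested lists."""
    scaling = alpha / rank
    d_out, d_in = len(weight), len(weight[0])
    return [
        [
            weight[i][j] + scaling * sum(lora_B[i][k] * lora_A[k][j] for k in range(rank))
            for j in range(d_in)
        ]
        for i in range(d_out)
    ]


full, lora = lora_param_counts(4096, 4096, rank=8)
print(full, lora, full // lora)  # 16777216 65536 256
```

At rank 8, the adapters for a 4096 × 4096 projection are 256× smaller than the weight they adapt. Only the adapters need gradients and Adam state.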

# QLoRA

All winning entries for https://llm-efficiency-challenge.github.io/ used QLoRA



# Implementing QLoRA

4000 lines of CUDA code: https://github.com/TimDettmers/bitsandbytes/blob/main/csrc/kernels.cu



...

master forgive me, but i need to activate "cuda mode" just this once



Jeremy Howard 🤣 @jeremyphoward · Dec 15, 2023

Replying to @jeremyphoward

he says he goes into "cuda mode" to write kernels. No music, lights off, no distractions.

He wrote the 4bit kernel in one night.

# Forgot to mention some details

- Weights aren't stored as int4 but as NF4 (4-bit NormalFloat), whose 16 levels sit at quantiles of a normal distribution
- You can't matrix-multiply NF4 tensors directly; you need to dequantize, then matmul
- Remember how important the max is when quantizing? You can't use the same max for everything, otherwise you're too sensitive to outliers
- Quantization is typically done in blocks, each with its own independent scale
- QLoRA quantizes the scales too: double quantization!
- 🥯
- Let's look at some code

https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py
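This pure-Python sketch makes the blocking and double-quantization points concrete. It uses symmetric absmax int8 levels on purpose, not the real NF4 codebook. The function names, the test data, and the block sizes (64 for values, 256 for scales) are assumptions chosen for illustration.

```python
def quantize_blockwise(xs, block_size=64):
    """Symmetric absmax quantization to int8 levels, one scale per block."""
    codes, absmaxes = [], []
    for start in range(0, len(xs), block_size):
        block = xs[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # avoid dividing by zero
        absmaxes.append(absmax)
        codes.extend(round(v / absmax * 127) for v in block)
    return codes, absmaxes


def dequantize_blockwise(codes, absmaxes, block_size=64):
    return [q * absmaxes[i // block_size] / 127 for i, q in enumerate(codes)]


def mean_abs_error(xs, ys):
    return sum(abs(a - b) for a, b in zip(xs, ys)) / len(xs)


# Small values plus a single outlier in the first block.
xs = [((i % 17) - 8) / 10 for i in range(256)]
xs[5] = 50.0

# One scale for the whole tensor: the outlier crushes every other value.
codes, scales = quantize_blockwise(xs, block_size=len(xs))
global_err = mean_abs_error(xs, dequantize_blockwise(codes, scales, len(xs)))

# Independent scales per block of 64: only the outlier's block suffers.
codes, scales = quantize_blockwise(xs)
block_err = mean_abs_error(xs, dequantize_blockwise(codes, scales))

# "Double quantization": the per-block scales are themselves quantized.
scale_codes, scale_absmax = quantize_blockwise(scales, block_size=256)
approx_scales = dequantize_blockwise(scale_codes, scale_absmax, 256)
dq_err = mean_abs_error(xs, dequantize_blockwise(codes, approx_scales))

print(f"global={global_err:.4f} blockwise={block_err:.4f} double-quant={dq_err:.4f}")
```

Quantizing the scales adds some error on top of plain blockwise quantization. It still stays well below the single-scale error, and the scales now take a quarter of the bytes.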

# Bitpacking

PyTorch supports dtypes down to int8: https://pytorch.org/docs/stable/tensors.html

Elements of a tensor need to be byte (8 bit) addressable

C++ is the same: a bool takes 8 bits of memory
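The smallest addressable element is a byte, so 4-bit values have to be packed two per byte by hand. This is a minimal pure-Python sketch. It assumes unsigned 4-bit codes (0 to 15, like NF4 indices), with the first value in the high nibble. The nibble order and the function names are our own choices, not any specific library's layout.

```python
def pack_uint4(values):
    """Pack unsigned 4-bit values two per byte (first value in the high nibble)."""
    if any(not 0 <= v <= 15 for v in values):
        raise ValueError("uint4 values must be in [0, 15]")
    if len(values) % 2:
        values = list(values) + [0]  # pad odd-length input with a zero nibble
    return bytes((hi << 4) | lo for hi, lo in zip(values[0::2], values[1::2]))


def unpack_uint4(packed, count):
    """Inverse of pack_uint4; `count` drops the padding nibble if there was one."""
    values = []
    for byte in packed:
        values.append(byte >> 4)
        values.append(byte & 0x0F)
    return values[:count]


codes = [1, 2, 15, 7, 3]
packed = pack_uint4(codes)
print(packed.hex(), len(packed))  # 12f730 3
assert unpack_uint4(packed, len(codes)) == codes
```

Five 4-bit codes fit in 3 bytes instead of 5. This is where the ½ byte per int4 figure comes from.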



# But what if we wanted to implement a real Tensor

Tensor subclasses are probably the feature PyTorch devs are most excited about: https://github.com/albanD/subclass_zoo/

We can define what matrix multiplication over NF4 means using Python: https://github.com/pytorch/ao/pull/37 by @drisspg

`return F.linear(input, weight.to(input.dtype))`, i.e., dequantize the NF4 weight to the input's dtype, then call a regular linear.

But we can also define how FSDP should handle an NF4 Tensor (e.g., `aten.split`): https://github.com/pytorch/ao/pull/150 by @weifengpy


One GPU was not enough...




\* DISCLAIMER: the model memory to the left does *not* include literally everything that'll take up memory during training, but is meant to be illustrative of the significant pieces.

#### But what if you had 2?



Let's start with the obvious: parallelize the data (batch size)



Sharding the batch size halves the activations. Everything else is duplicated.




Note that we need to sync/sum the grads before the optim step with an all-reduce! In general, techniques to lower memory require additional compute and management. But what if that wasn't enough?
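This pure-Python simulation shows that sync. It assumes two ranks that computed gradients on different halves of the batch. PyTorch DDP averages: it sums across ranks, then divides by the world size. What matters is that every rank ends up with identical gradients, so identical optimizer steps keep the replicas in sync. `all_reduce_mean` is a hypothetical helper, not a torch.distributed API.

```python
def all_reduce_mean(per_rank_grads):
    """Simulate a mean all-reduce: every rank receives the element-wise average."""
    world_size = len(per_rank_grads)
    summed = [sum(vals) for vals in zip(*per_rank_grads)]  # the all-reduce (sum) part
    mean = [s / world_size for s in summed]
    return [list(mean) for _ in range(world_size)]  # each rank gets its own copy


rank0 = [1.0, 2.0]  # grads from the first half of the batch
rank1 = [3.0, 4.0]  # grads from the second half
print(all_reduce_mean([rank0, rank1]))  # [[2.0, 3.0], [2.0, 3.0]]
```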



What else can we do?

Let's keep parallelizing! Shard the params too.



Sharding the params will in turn reduce gradient and optimizer memory.
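This back-of-the-envelope sketch compares per-GPU memory for the two strategies. It reuses the 7B numbers from the memory crash course: 14 GB of params, gradients equal to params, and Adam state at twice the params. The 20 GB activation figure is made up for illustration. The sketch also ignores the temporarily all-gathered layer that FSDP needs.

```python
def per_gpu_memory_gb(params_gb, activations_gb, world_size, shard_params):
    grads_gb = params_gb       # one gradient per parameter
    optim_gb = 2 * params_gb   # Adam: two states per parameter
    model_state_gb = params_gb + grads_gb + optim_gb
    if shard_params:
        # Each GPU owns 1/N of the params, and therefore of grads and optimizer state
        model_state_gb /= world_size
    # Both strategies split the batch, so activations shrink either way
    return model_state_gb + activations_gb / world_size


replicated = per_gpu_memory_gb(14, 20, world_size=2, shard_params=False)
sharded = per_gpu_memory_gb(14, 20, world_size=2, shard_params=True)
print(replicated, sharded)  # 66.0 38.0
```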

# Congrats, you have discovered FSDP (Fully Sharded Data Parallel)!



FSDP brings in (all-gathers) only one layer's full weights at a time to avoid using too much memory. As a result, we need more collectives to shuffle tensors between GPUs.

A slightly more accurate depiction of memory for a step in FSDP:



Will be freed when the layer is done.

# What constitutes a layer in FSDP?



Every nn.Module is a tree of more nn.Modules.

The user's wrapping policy determines what gets treated as its own "layer".

This depicts a wrapping policy where TransformerDecoderLayer and Linear are specified.
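This toy pure-Python model shows which parameters land in which FSDP unit under a class-based wrapping policy. Each leaf's params belong to its nearest wrapped ancestor (or to itself, if it is wrapped). Anything left unclaimed falls to the root. The `Node` class and the module names are hypothetical stand-ins, not nn.Module or FSDP's actual policy API.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    kind: str                      # stand-in for the module's class name
    name: str                      # fully qualified name, e.g. "layers.0.attn.q_proj"
    children: list = field(default_factory=list)


def assign_units(root, wrap_kinds):
    """Map each leaf module to the FSDP unit that would manage its parameters."""
    owner = {}

    def visit(node, unit):
        if node is not root and node.kind in wrap_kinds:
            unit = node.name       # this module becomes its own unit
        if not node.children:
            owner[node.name] = unit
        for child in node.children:
            visit(child, unit)

    visit(root, root.name)         # the root module is always a unit
    return owner


layer = Node("TransformerDecoderLayer", "layers.0", [
    Node("RMSNorm", "layers.0.sa_norm"),
    Node("Attention", "layers.0.attn", [Node("Linear", "layers.0.attn.q_proj")]),
])
model = Node("Transformer", "model", [Node("Embedding", "tok_embeddings"), layer, Node("Linear", "output")])

for leaf, unit in assign_units(model, {"TransformerDecoderLayer", "Linear"}).items():
    print(f"{leaf:22} -> {unit}")
```

If you drop `"Linear"` from `wrap_kinds`, `q_proj` and `output` fold into their parents' units. The result is fewer, bigger all-gathers.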

# What you decide to wrap influences memory usage (and more)



The more "fine-grained" you wrap, the smaller that dotted memory will be.

Smaller blobs = less memory needs to be all-gathered at a time.


#### But what if after all your tweaking, you still OOM?



What *else* can we do?

#### In comes CPU offloading!




Don't forget about the CPU!

Just keep parameters on the CPU and move them to the GPU when computing forward + backward.

Note that the optimizer update will be done on CPU, so the optim state lives there too.

# None of this is quite new...right?

I mean...okay, yes, FSDP has existed for a while, with all the features mentioned above.

And wonderful people have been using these features, like Answer.AI, who built fsdp_qlora with FSDP x bnb (bitsandbytes) to compose QLoRA and distributed training.

BUT we've recently come out with *per-parameter* FSDP!

# What is per-parameter FSDP?

Let's start with the status quo: **flat-parameter FSDP1**. Say you have these params to shard across our two GPUs:

t1: (2, 3), t2: (3, 3), t3: (2, 2)

Goal: make all-gather efficient in memory

Constraint: NCCL requires each GPU to contribute same-size Tensors





(Figure: t1, t2, and t3 split between Worker 0 and Worker 1, with PAD so both workers hold equal-size chunks.)

# FlatParam FSDP

Each chunk is smooshed into **one** Tensor, which we call a FlatParameter.

This approach has its pros:

- Contiguous memory
- One can use views to retrieve t1, t2, t3 (vs. copies)

but also its cons...




# Another way to shard, dubbed "per parameter"



# Per-param FSDP2

Each chunk is **many** Tensors, each a DTensor (D is for distributed).

This approach has a major pro: each param maintains its own identity (dtype, subclass, metadata).

BUT does require extra copies (\$\$\$ > views) during All-Gather.
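This small pure-Python sketch compares the two layouts for t1, t2, and t3. For FSDP1, it assumes params are concatenated in order and the padded flat buffer is split evenly, ignoring any alignment padding. For FSDP2, it assumes every param is chunked along dim 0, with each rank padded up to ceil(dim0 / world_size) rows. The function names are ours.

```python
import math

SHAPES = [("t1", (2, 3)), ("t2", (3, 3)), ("t3", (2, 2))]


def flat_param_shards(shapes, world_size):
    """FSDP1-style: per rank, a list of (param, elements held) plus a padding count."""
    sizes = [(name, math.prod(shape)) for name, shape in shapes]
    total = sum(numel for _, numel in sizes)
    chunk = math.ceil(total / world_size)  # every rank must contribute the same size
    shards = []
    for rank in range(world_size):
        lo, hi = rank * chunk, (rank + 1) * chunk
        held, offset = [], 0
        for name, numel in sizes:
            start, end = max(lo, offset), min(hi, offset + numel)
            if start < end:
                held.append((name, end - start))
            offset += numel
        shards.append((held, max(0, hi - max(lo, total))))  # padding at the tail
    return shards


def per_param_shards(shapes, world_size):
    """FSDP2-style: per rank, (param, real elements, padding elements) for every param."""
    shards = [[] for _ in range(world_size)]
    for name, shape in shapes:
        rows_per_rank = math.ceil(shape[0] / world_size)
        row_numel = math.prod(shape[1:])
        for rank in range(world_size):
            real_rows = max(0, min(shape[0] - rank * rows_per_rank, rows_per_rank))
            pad_rows = rows_per_rank - real_rows
            shards[rank].append((name, real_rows * row_numel, pad_rows * row_numel))
    return shards


for rank, shard in enumerate(flat_param_shards(SHAPES, 2)):
    print("FSDP1 rank", rank, shard)
for rank, shard in enumerate(per_param_shards(SHAPES, 2)):
    print("FSDP2 rank", rank, shard)
```

In the flat layout, t1 lives entirely on rank 0, t3 lives entirely on rank 1, and t2 is cut mid-row. In the per-param layout, every rank holds a dim-0 slice of every param, at the cost of padding t2 on rank 1.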










# Why do we think the "per-parameter-ness" is worth it?





FlatParameter forces t1, t2, and t3 to share dtype and requires\_grad, since it is one Tensor. In per-param FSDP, t1, t2, and t3 can be themselves! They can each have their own dtype and requires\_grad.




Think **quantization**: what if you wanted t2 to be uint8 and t1 to remain full-sized bf16? Think **parameter freezing/LoRA**: what if t2 is a frozen base weight while t3 is the LoRA adapter?

In FSDP1, you'd have to hack around concepts that you get for free in FSDP2.

# FSDP2 also has other cool pros, like deterministic memory

This is another major implementation change that actually guarantees deterministic memory:

Only 2 layers worth of memory will coexist at a time.

But it'll take another lecture to explain xD. For more details, see https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486

# Why do we think the "per-parameter-ness" is worth it?

Well, it...

1. just makes more sense
   - a. Every parameter is an evenly sliced version of itself in FSDP2
   - b. Whereas in FSDP1, some parameters are entirely on 1 machine while others could be split across machines arbitrarily. Plus, every parameter belonging to a FlatParam must share dtype, subclass, and requires\_grad.
2. widens what could be wrapped by FSDP into a layer
3. unlocks param-wise optimizers, like Adafactor
4. **composes with other distributed parallelisms** (TP, PP) through DTensor, as tensor structure is maintained

### FSDP2 also has other cool pros, like deterministic memory

Due to how FSDP1 implemented its rate limiter on CPU, it couldn't actually guarantee:

Only 2 layers worth of memory will coexist at a time.

For example, using CPU offloading sometimes caused *more* memory usage!

FSDP2 moved the burden of rate limiting from CPU to CUDA events, so now this guarantee can actually be met :D

For more details, see

https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486

# Our implementation overlaps communication with computation



This way, the all-gathers are imperceptible in terms of runtime!

Note that this requires **prefetching** the next layer's parameters so that they are ready by the time its compute starts.

#### But we do it methodically to avoid peaking memory



The FSDP rate limiter forces prefetching to wait until the previous layer is freed.

Desired guarantee: only 2 layers worth of memory will coexist at a time.
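This is a toy, single-threaded pure-Python model of that trade-off. It is not FSDP's actual implementation, which uses CUDA events. The model prefetches layer i+1 during layer i's compute only if fewer than `max_resident` layers are currently unsharded. Any layer that could not be prefetched is gathered after the previous one is freed, so that all-gather is exposed. The function name and the counting scheme are assumptions.

```python
def simulate_forward(num_layers, max_resident):
    """Toy rate-limited prefetching over a forward pass.

    Returns (peak number of layers whose unsharded params coexist,
             number of all-gathers that could not overlap with compute).
    """
    resident = {0}   # layer 0 must be gathered before any compute...
    exposed = 1      # ...and there is no compute yet to hide it behind
    peak = len(resident)
    for i in range(num_layers):
        nxt = i + 1
        # Prefetch layer i+1 while layer i computes, if the limit allows it.
        if nxt < num_layers and len(resident) < max_resident:
            resident.add(nxt)
        peak = max(peak, len(resident))
        # ... layer i's compute would run here ...
        resident.discard(i)  # free layer i's unsharded params
        if nxt < num_layers and nxt not in resident:
            resident.add(nxt)  # gathered only after the free: this AG is exposed
            exposed += 1
    return peak, exposed


for limit in (1, 2):
    peak, exposed = simulate_forward(32, max_resident=limit)
    print(f"max_resident={limit}: peak={peak}, exposed all-gathers={exposed}")
```

With a budget of 2, only the first all-gather is exposed and peak memory stays at two layers. A budget of 1 saves memory but exposes every all-gather.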

## So let's take FSDP2 out for a swim

Answer.AI had already <u>successfully composed FSDP1 with QLoRA</u>, but only after expertly maneuvering through its limitations.

e.g., "FSDP was not copying the quantization information needed for each shard to use the model! That's because FSDP is quite opinionated on the subset of data it will sync between GPUs"

We want to offer cleaner, more general solutions to composing distributed with low precision parameters, so why not start here, with FSDP2 x NF4?

So we did! <u>https://github.com/pytorch/torchtune/pull/909</u>

Cleaner and more composable is always good, but how do we do on perf? Let's find out!

# The plan

1. Get some GPUs
2. Run a benchmark on Answer.AI's train.py
3. Run the same benchmark on Wei's torchtune recipe
4. Wait...were those actually the same benchmark?
5. Make sure what I'm measuring was apples to apples and not apples to oranges
6. Record the gaps
7. Investigate and fill the gaps if possible
# Getting some GPUs

I rented myself a dual-GPU setup on vast.ai

- 2 RTX 3090s, 24 GB VRAM each
- 117 GB RAM
- 12 cores => required `torch.set_num_threads(8)`
- PCIe 3.0 x16, with 9.0 GB/s bandwidth each
- CUDA 12.2

# Running a benchmark on Answer.AI's train.py

llama2-7B, context length 2048

| batch size | peak memory | runtime for a step |
|------------|-------------|--------------------|
| 8          | 15.03 GiB   | 13.9s              |
| 10         | 18.05 GiB   |                    |
| 12         | 21.06 GiB   | 19.9s              |
| > 12       | OOM         | N/A                |

llama2-7B, context length 2048, with CPU offloading

| batch size | peak memory | runtime for a step |
|------------|-------------|--------------------|
| 8          | 12.88 GiB   | 14.0s              |
| 10         | 15.89 GiB   | 17.5s              |
| 12         | 18.91 GiB   | 20.9s              |
| 14         | 21.92 GiB   | 23.6s              |
| > 14       | OOM         | N/A                |

In both tables, batch sizes of 12 and up needed `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`.

Thanks Answer.AI peeps on CUDA MODE for sending me benchmarks to try!

I decided to focus on just one of these (batch size 8 with CPU offloading) to do an apples-to-apples comparison.

    python train.py --model_name meta-llama/Llama-2-7b-hf --batch_size 8 \
        --context_length 2048 --train_type qlora --use_gradient_checkpointing True \
        --reentrant_checkpointing True --dataset dummy --dataset_samples 48

#### Running the same benchmark on Wei's torchtune recipe

    tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 --config recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml

\* with tweaks to align the configs

|           | batch size | peak memory | runtime for a step |
|-----------|------------|-------------|--------------------|
| train.py  | 8          | 12.88 GiB   | 14.0s              |
| torchtune | 8          | 12.60 GiB   | 16.5s              |

Since FSDP2 is stricter about memory and requires extra copies, it would be easy to chalk up the differences above as expected.

But, nah, we have to be diligent! And very quickly, one glance at the trace revealed troubling shenanigans.



Spot the difference!



Why were the optimizer steps so much bigger in the torchtune trace?

#### train.py trace


#### Area Selection Pivot Table

| name                      | SUM(dur)    | SUM(thread_dur) | Count |
|---------------------------|-------------|-----------------|-------|
| Total values:             | 319ms 138us | null            | 1     |
| Optimizer.step#AdamW.step | 319ms 138us | null            | 1     |
| aten::mul_                | 77ms 533us  | null            | 768   |
| aten::div                 | 50ms 360us  | null            | 384   |
| aten::to                  | 49ms 803us  | null            | 1920  |
| aten::_to_copy            | 45ms 904us  | null            | 1920  |
| aten::addcmul_            | 43ms 758us  | null            | 294   |
| aten::lerp_               | 32ms 71us   | null            | 384   |
| aten::empty_strided       | 28ms 551us  | null            | 1920  |
| aten::add_                | 28ms 415us  | null            | 768   |

#### torchtune trace



#### Area Selection Pivot Table

| name                      | SUM(dur)          | SUM(thread_dur) | Count |
|---------------------------|-------------------|-----------------|-------|
| Total values:             | 976ms 556us 612ns | null            | 1     |
| Optimizer.step#AdamW.step | 976ms 556us 612ns | null            | 1     |
| aten::mul                 | 391ms 372us 93ns  | null            | 1792  |
| aten::lerp_               | 242ms 250us 483ns | null            | 896   |
| aten::div                 | 181ms /33us 248ns | null            | 896   |
| aten::addcdiv_            | 138ms 731us 84ns  | null            | 896   |
| aten::sqrt                | 115ms 135us 309ns | null            | 896   |
| aten::add_                | 111ms 392us 168ns | null            | 1344  |
| aten::addcmul_            | 95ms 505us 620ns  | null            | 896   |
| aten::to                  | 63ms 552us 869ns  | null            | 2240  |
| aten::_to_copy            | 58ms 362us 598ns  | null            | 2240  |






Aha! torchtune was training more parameters than Answer.AI's train.py config. 448 - 384 = 64 extra params! Any guesses where they came from?

Answer.AI's train.py:

    # If lora_target_modules is 'all', set sensible defaults for llama + mistral type modules
    # See peft.utils.constants -> TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING for the current defaults
    if lora_target_modules == "all":
        args["lora_target_modules"] = ["k_proj", "q_proj", "v_proj", "up_proj", "down_proj", "gate_proj"]
    elif lora_target_modules.lower() == "default":
        args["lora_target_modules"] = None

torchtune config:

    # Model Arguments
    model:
      _component_: torchtune.models.llama2.qlora_llama2_7b
      lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj']
      apply_lora_to_mlp: True
      apply_lora_to_output: False
      lora_rank: 8
      lora_alpha: 16

#### Spot the difference!

Answer is on the next slide :D


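torchtune LoRA-fied the output\_proj while train.py did not.
LoRA-fying a projection means adding 2 low-rank adapters (lora\_A and lora\_B) to it. Here, that is the output projection that follows q, k, and v in every attention block.
32 TransformerDecoderLayers × 2 more params each = 64 extra params to train.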



I took a pause cuz it wasn't going to be fruitful if the items getting measured weren't sufficiently aligned!

Steps I took:

- Stopped LoRA-fying the output\_proj in my torchtune recipe
- Changed FSDP2 wrapping policy to wrap the same layers
- Replicated the same "dummy" dataset for my benchmark
- Took another pass ensuring max seq len + other hyperparams for model construction were the same



I then ran the benchmark after my changes...and FSDP2 x NF4 still looked mighty slow.

|           | batch size | peak memory | runtime for a step |
|-----------|------------|-------------|--------------------|
| train.py  | 8          | 12.88 GiB   | 14.0s              |
| torchtune | 8          | 10.70 GiB   | 16.6s              |

Even though it may feel like we took a mini step back, we've made a giant leap by unblocking our official first step: understanding the problem (the gaps).

I could finally start a very long game of Spot the Difference.

train.py trace

torchtune trace

# Recording the gaps

I first did a survey of the land, and derived this chart:

|           | FSDP1 & answerai | FSDP2 & tune | gap      |
|-----------|------------------|--------------|----------|
| GNUM      | 92 ms            | 129 ms       |          |
| first AGs | 120 ms           | 103 ms       |          |
| forward   | 4364 ms          | 5209 ms      | +900 ms  |
| backward  | 8975 ms          | 10136 ms     | +1100 ms |
| optimizer | 319 ms           | 894 ms       | +500 ms  |
| total     | 13870 ms         | 16471 ms     |          |
| e2e       | 13964 ms         | 16586 ms     | +2622 ms |

We see that we should focus on the forward and the optimizer step kernels.
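The rounded "+ ms" gaps can be recomputed exactly from the chart. The sketch below uses the chart's forward, backward, optimizer, and total rows; the dict layout is ours. It shows that the backward gap is actually the largest. One plausible reading is that the benchmark runs with gradient checkpointing, so the backward re-runs the forward, and speeding up forward ops should pay off there too.

```python
fsdp1_answerai = {"forward": 4364, "backward": 8975, "optimizer": 319, "total": 13870}
fsdp2_tune = {"forward": 5209, "backward": 10136, "optimizer": 894, "total": 16471}

gaps = {phase: fsdp2_tune[phase] - fsdp1_answerai[phase] for phase in fsdp1_answerai}
for phase in ("forward", "backward", "optimizer"):
    share = gaps[phase] / gaps["total"]
    print(f"{phase:9} +{gaps[phase]:5} ms  ({share:.0%} of the total gap)")
print(f"total     +{gaps['total']:5} ms")
```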

Trace comparison, FSDP1 vs FSDP2: `FSDP::pre_forward` Memcpy HtoD ops and `nccl:_all_gather_base` (`ncclDevKernel_AllGather_RING_LL`) all-gathers, i.e., the exposed AGs/mem H2Ds.

# Investigating and filling the gaps if possible

lesgo

traces:

https://drive.google.com/drive/u/3/folders/1HmGNC4v4L5nXhtdDMVCpUBrme1ELp-2C

# Gap: the optimizer step is still slower





Understanding why:

- DTensor overhead
- parameters are not necessarily contiguous

Solution: use the fused AdamW implementation! (thanks Intel)

# Gap filled: the optimizer step is now faster



Solution: use the fused AdamW implementation! (thanks Intel)

- avoid DTensor overhead by only dispatching 1 fused kernel!
- leverage vectorization
- goes from ~1s -> 120ms, speeding up 8x



## Gap: the 2nd AG was 5ms longer



Check the all-gather input arguments! Realize that 25,300,992 bf16s (= 50,601,984 bytes) != 64,646,208 bytes.

In FSDP1, print out the `_fqns` of a FlatParameter. In FSDP2, print `all_gather_inputs`. Lining up the parameters revealed...

#### Gap understood: the 2nd AG was much larger



Why the heavy load?

- our NF4 all-gathers the NF4 metadata whereas bnb Params4bit does not
- more significantly, after opting out of LoRA, our output\_proj remained frozen but full sized. train.py froze their output\_proj too, but quantized it

## Gap to be filled: detangle the q from qLoRA

Why the heavy load?

- our NF4 all-gathers the NF4 metadata whereas bnb Params4bit does not
  - This is intended! FSDP2 allows NF4Tensor subclass to <u>decide</u> which of its inner tensors are all-gathered
- more significantly, after opting out of LoRA, our output\_proj remained frozen but full sized. train.py froze their output\_proj too, but quantized it.
  - This is not intended!
  - This is a next step for torchtune to allow base weights to be quantized even if they opt out of LoRA



## Gap: additional overhead right before the gemms



Understanding why:

- NF4Tensor overrides the mm in order to dequantize before calling the gemm
- bnb has a CUDA kernel for the dequantization work

### Gap to be filled: fuse LinearNF4 overhead



Solutions:

- A next step is to leverage torch.compile. I did try it, but it does not play well with activation checkpointing at the moment
- Another next step is to package and use the Triton kernels that Driss wrote



# Gap: differing ops before sdpa (costing us 6ms per layer!)

This gap is the most boring of them all: torchtune and the default Llama 2 config simply use different RoPE implementations.

- torchtune uses the original Meta algorithm, with no numerical differences.
- the default LlamaRotaryEmbedding is 2-3x faster (6 ms faster in our trace) but is not the same numerically.

Solution:

- A next step is for torchtune to offer, as an option, more optimized but less faithful RoPE implementations.



# Gap: exposed AGs/mem H2Ds



We wonder: why is the left-side Memcpy hidden in FSDP1, but very exposed in FSDP2? Answer: the stricter memory constraints!

- Memcpy is used to bring offloaded params from CPU to GPU
- FSDP2 guarantees that only 2 layers of params are resident at a time by making the Memcpy wait as well.

#### Gap: exposed AGs/mem H2Ds

Note that the problem here isn't that FSDP2 is too strict. It's that the computation is too small to properly hide the communication!


Solution: wrap less granularly, i.e., have bigger layers.



## We want a new wrapping policy:



This new wrapping policy is only possible with FSDP2, since NF4Tensors and plain Tensors can now coexist in one layer!

#### Side note: this is very easy to do in FSDP2



# Gap filled: hidden AGs/mem H2Ds

Solution: wrap less granularly, i.e., have bigger layers.


Now both the all-gathers and the H2D memcpys are overlapped with compute!

final\_torchtune\_trace.json



#### Positive gap: how come torchtune uses less memory?



[Memory snapshot: a 128.0 MiB (134217728 bytes) allocation whose stack trace runs through `CUDACachingAllocator::...::malloc`, `at::detail::empty_cuda`, and `structured_ufunc_add_CUDA_functional::set_output_raw_strided`, i.e. the output buffer of a CUDA `add` op]

Partly, yes: FSDP2 has better memory guarantees. But here, the difference comes from torchtune freeing the loss early!




Zooming in, the gap is the size of the loss.
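Why does freeing the loss early shrink the peak by the size of the loss? Peak memory is the maximum of running live bytes, so a buffer freed before the peak comes off the peak in full, while one freed after the peak saves nothing. The bookkeeping below is a hypothetical sketch, not the CUDA caching allocator: `peak_bytes` is a made-up helper, the event sizes are invented, and the 128 MiB buffer only borrows its size from the snapshot allocation.

```python
MiB = 2**20
GiB = 2**30


def peak_bytes(events):
    """Return the peak of live bytes over a sequence of (label, delta) events.

    Positive delta = allocation, negative delta = free. Plain bookkeeping,
    not a model of the CUDA caching allocator (no caching, no fragmentation).
    """
    live = peak = 0
    for _label, delta in events:
        live += delta
        peak = max(peak, live)
    return peak


LOSS_BUF = 128 * MiB  # hypothetical loss-sized buffer

# Hypothetical step: backward temporaries are where memory peaks.
late_free = [
    ("activations", 8 * GiB),
    ("loss buffer", LOSS_BUF),
    ("backward temporaries", 2 * GiB),  # peak: 10 GiB + 128 MiB
    ("free backward temporaries", -2 * GiB),
    ("free loss buffer", -LOSS_BUF),    # freed after the peak: saves nothing
]
early_free = [
    ("activations", 8 * GiB),
    ("loss buffer", LOSS_BUF),
    ("free loss buffer", -LOSS_BUF),    # freed before backward allocates
    ("backward temporaries", 2 * GiB),  # peak: 10 GiB
    ("free backward temporaries", -2 * GiB),
]

print((peak_bytes(late_free) - peak_bytes(early_free)) // MiB)  # 128
```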

#### torchtune can get up to bs=16 for llama2-7b, 2048 context len

#### llama2-7B, context length 2048, with CPU offloading

| batch size | train.py peak memory | train.py runtime for a step | torchtune peak memory | torchtune runtime for a step |
|------------|----------------------|-----------------------------|-----------------------|------------------------------|
| 8          | 12.88 GiB            | 14.0s                       | 10.7 GiB              | 14.8s                        |
| 10         | 15.89 GiB            | 17.5s                       | 13.2 GiB              | 18.2s                        |
| 12         | 18.91 GiB †          | 20.9s                       | 15.7 GiB              | 21.7s                        |
| 14         | 21.92 GiB †          | 23.6s                       | 18.3 GiB              | 25.3s                        |
| 16         | OOM                  | N/A                         | 20.8 GiB †            | 28.8s                        |
| > 16       | OOM                  | N/A                         | OOM                   | N/A                          |

† Needs `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`.
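The larger batch sizes above need `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, which lets the CUDA caching allocator create segments that can later be expanded, helping reduce fragmentation. A minimal sketch, following the same `os.environ` pattern used for `TORCH_LOGS` earlier in the talk; the assumption is that the variable must be in place before PyTorch initializes its CUDA allocator, so setting it before `import torch` (or exporting it in the launching shell) is the safe choice.

```python
import os

# Set before PyTorch initializes its CUDA caching allocator;
# doing it before `import torch` is the safe choice.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# import torch  # ...then build the model and train as usual
```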

#### Try this out in torchtune!

https://github.com/pytorch/torchtune

HUGE DISCLAIMER: checkpointing is not working yet.

# The rest of the team

@drisspg: Driss wrote the original NF4 tensor implementation

@awgu: Andrew is the main architect of FSDP2

@weifengpy: Wei showed how to compose new dtypes w/ FSDP2

@rohan-varma/@ebsmothers: wrote the LoRA recipes and merged the code into torchtune

### Thanks!

Implement new dtypes that work with compile and FSDP: https://github.com/pytorch/ao

Compile them: https://pytorch.org/docs/main/torch.compiler

Author them as subclasses so they work like real PyTorch tensors: https://github.com/albanD/subclass_zoo/

Go from 1 GPU to N GPUs with FSDP2: https://github.com/pytorch/pytorch/issues/114299

End-to-end finetuning examples: https://github.com/pytorch/torchtune

End-to-end training examples: https://github.com/pytorch/torchtitan

And remember to profile your memory: https://pytorch.org/blog/understanding-gpu-memory-1/

If you have any questions, reach out to us on Discord.

If you're doing research at the intersection of quantization and distributed, we'd loooove to hear from you!