
(prototype) GPU Quantization with TorchAO

Author: HDCharles

In this tutorial, we will walk you through the quantization and optimization of the popular Segment Anything Model (SAM). These steps mimic some of those taken to develop the segment-anything-fast repo. This step-by-step guide demonstrates how you can apply these techniques to speed up your own models, especially those that use transformers. To that end, we will focus on widely applicable techniques, such as torch.compile and quantization, and measure their impact.

Set up Your Environment

First, let’s configure your environment. This guide was written for CUDA 12.1. We have run this tutorial on an A100-PG509-200 power limited to 330.00 W. If you are using different hardware, you might see different performance numbers.

> conda create -n myenv python=3.10
> conda activate myenv
> pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
> pip install git+https://github.com/facebookresearch/segment-anything.git
> pip install git+https://github.com/pytorch-labs/ao.git

Segment Anything Model checkpoint setup:

  1. Go to the segment-anything repo and download the vit_h checkpoint. Alternatively, you can use wget: `wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth --directory-prefix=<path>`

  2. Pass in that directory by editing the code below to say `sam_checkpoint_base_path = "<path>"`.



import torch
from torchao.quantization import change_linear_weights_to_int8_dqtensors
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer

sam_checkpoint_base_path = "data"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"
batchsize = 16
only_one_block = True

@torch.no_grad()  # inference only: don't keep activations around for autograd
def benchmark(f, *args, **kwargs):
    # warm-up runs (for compiled models this also triggers compilation and autotuning)
    for _ in range(3):
        f(*args, **kwargs)
        torch.cuda.synchronize()

    # reset the high-water mark so each call reports its own peak memory
    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
    return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}

def get_sam_model(only_one_block=False, batchsize=1):
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
    model = sam.image_encoder.eval()
    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')

    # code to use just a single block of the model
    if only_one_block:
        model = model.blocks[0]
        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
    return model, image

In this tutorial, we focus on quantizing the image_encoder because its inputs are statically sized, while the prompt encoder and mask decoder have variable sizes, which makes them harder to quantize.

We’ll focus on just a single block at first to make the analysis easier.

Let’s start by measuring the baseline runtime.

try:
    model, image = get_sam_model(only_one_block, batchsize)
    fp32_res = benchmark(model, image)
    print(f"base fp32 runtime of the model is {fp32_res['time']:0.2f}ms and peak memory {fp32_res['memory']:0.2f}GB")
    # base fp32 runtime of the model is 186.16ms and peak memory 6.33GB
except Exception as e:
    print("unable to run fp32 model: ", e)

We can achieve an instant performance boost by converting the model to bfloat16. We opt for bfloat16 over fp16 because its dynamic range is comparable to that of fp32: both bfloat16 and fp32 have 8 exponent bits, whereas fp16 has only 5. This larger dynamic range helps protect us from overflow errors and other issues that can arise when scaling and rescaling tensors during quantization.
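To make the dynamic-range argument concrete, the sketch below uses only the Python standard library. It computes the largest finite value each format can represent from its exponent and mantissa widths, and emulates a float32-to-bfloat16 cast by rounding the float32 bit pattern to its upper 16 bits. The helpers `max_finite`, `to_bf16` and `to_fp16` are hypothetical illustrations (round-to-nearest-even, NaN handling omitted), not how PyTorch implements these casts; in PyTorch itself you would simply inspect `torch.finfo(torch.bfloat16)`.

```python
import struct

# (exponent bits, explicit mantissa bits) for each IEEE-style format
FORMATS = {"fp32": (8, 23), "bf16": (8, 7), "fp16": (5, 10)}

def max_finite(exp_bits: int, man_bits: int) -> float:
    """Largest finite value: (2 - 2**-m) * 2**bias, where bias = 2**(e - 1) - 1."""
    bias = 2 ** (exp_bits - 1) - 1
    return (2 - 2.0 ** -man_bits) * 2.0 ** bias

def to_bf16(x: float) -> float:
    """Emulate a float32 -> bfloat16 cast (round to nearest even, no NaN handling)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits = (bits >> 16) << 16            # keep sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

def to_fp16(x: float) -> float:
    """Round-trip through IEEE half precision; struct raises OverflowError past fp16's range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

for name, (e, m) in FORMATS.items():
    print(f"{name}: max finite value ~ {max_finite(e, m):.4g}")

value = 70000.0  # just above fp16's maximum of 65504
print("bf16:", to_bf16(value))  # still finite, only slightly rounded
try:
    print("fp16:", to_fp16(value))
except OverflowError as err:
    print("fp16 overflow:", err)
```

bf16 gives up mantissa precision (70000 becomes 70144) but keeps fp32's range, while fp16 cannot represent the value at all. That trade-off is usually the right one when quantization scales can push intermediate values far from 1.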

model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
bf16_res = benchmark(model, image)
print(f"bf16 runtime of the block is {bf16_res['time']:0.2f}ms and peak memory {bf16_res['memory']: 0.2f}GB")
# bf16 runtime of the block is 25.43ms and peak memory  3.17GB

Just this quick change improves runtime by a factor of ~7x in the tests we have conducted (186.16ms to 25.43ms).

Next, let’s use torch.compile with our model to see how much the performance improves.

model_c = torch.compile(model, mode='max-autotune')
comp_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the block is {comp_res['time']:0.2f}ms and peak memory {comp_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the block is 19.95ms and peak memory  2.24GB

The first time this is run, you should see a sequence of AUTOTUNE outputs, which appear when Inductor compares the performance of various kernel parameters for a kernel. This happens only once (unless you delete your cache), so if you run the cell again you should see only the benchmark output.

torch.compile yields about another 27% speedup (25.43ms to 19.95ms). This brings the model to a reasonable baseline, where we now have to work a bit harder for improvements.

Next, let’s apply quantization. In torchao, which is written in native PyTorch and Python, quantization for GPUs comes in three main forms:

  • int8 dynamic quantization

  • int8 weight-only quantization

  • int4 weight-only quantization

Different models, or sometimes different layers in a model, can require different techniques. For models that are heavily compute-bound, dynamic quantization tends to work best, since it swaps the normal, expensive floating-point matmul ops for integer versions. Weight-only quantization works better in memory-bound situations, where the benefit comes from loading less weight data rather than doing less computation. The torchao APIs:

change_linear_weights_to_int8_dqtensors, change_linear_weights_to_int8_woqtensors or change_linear_weights_to_int4_woqtensors

can be used to apply the desired quantization technique. Once the model is then compiled with torch.compile using mode='max-autotune', quantization is complete and we can measure the speedup.
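The difference between int8 dynamic and int8 weight-only quantization can be sketched in plain Python. This is a toy illustration with hypothetical helper names, not torchao’s implementation: it assumes symmetric per-row scales for activations (computed at runtime, which is what makes the scheme dynamic) and per-output-channel scales for weights. Dynamic quantization multiplies int8 by int8 with an integer accumulator and rescales at the end; weight-only quantization dequantizes the int8 weight and does the matmul in floating point.

```python
def quantize_rows(rows):
    """Symmetric int8 quantization, one scale per row: q = round(x / s), s = max|x| / 127."""
    qs, scales = [], []
    for row in rows:
        s = max(abs(v) for v in row) / 127 or 1.0  # avoid a zero scale for all-zero rows
        qs.append([max(-127, min(127, round(v / s))) for v in row])
        scales.append(s)
    return qs, scales

def matmul_t(a, w):
    """a: M x K, w: N x K (nn.Linear weight layout); returns a @ w.T."""
    return [[sum(x * y for x, y in zip(row, col)) for col in w] for row in a]

def dynamic_int8_linear(x, w_q, w_scales):
    # activations are quantized on the fly, one scale per row (token)
    x_q, x_scales = quantize_rows(x)
    acc = matmul_t(x_q, w_q)  # pure integer math (an int32 accumulator on a GPU)
    return [[acc[i][j] * x_scales[i] * w_scales[j] for j in range(len(w_q))]
            for i in range(len(x))]

def weight_only_int8_linear(x, w_q, w_scales):
    # dequantize the stored int8 weights, then do an ordinary float matmul
    w = [[q * s for q in row] for row, s in zip(w_q, w_scales)]
    return matmul_t(x, w)

x = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]  # M=2 tokens, K=3 features
w = [[0.1, 0.2, -0.3], [-0.4, 0.5, 0.6]]    # N=2 output channels
w_q, w_scales = quantize_rows(w)             # weights are quantized once, ahead of time

ref = matmul_t(x, w)
dyn = dynamic_int8_linear(x, w_q, w_scales)
woq = weight_only_int8_linear(x, w_q, w_scales)
print("reference:  ", ref)
print("dynamic:    ", dyn)
print("weight-only:", woq)
```

Both variants stay close to the float result. On a GPU the dynamic variant pays off when the integer matmul dominates (compute-bound), while the weight-only variant helps when reading weights from memory dominates.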


You might experience issues with these on older versions of PyTorch. If you run into an issue, you can use apply_dynamic_quant and apply_weight_only_int8_quant instead, as drop-in replacements for the two int8 APIs above (there is no replacement for int4).

The difference between the two APIs is that the change_linear_weights APIs alter the weight tensor of the linear module, so instead of doing a normal linear, the module does a quantized operation. This is helpful when you have non-standard linear ops that do more than one thing. The apply APIs directly swap the linear modules for a quantized module, which works on older versions but doesn’t work with non-standard linear modules.
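Whether Segment Anything is compute-bound can be estimated with a back-of-the-envelope arithmetic-intensity calculation: FLOPs performed per byte moved. The sketch assumes the first block’s qkv projection multiplies a 78400 x 1280 bf16 activation (16 images, each padded into 25 windows of 14 x 14 tokens) by a 1280 x 3840 weight, and compares that with a ridge point derived from rough A100 spec-sheet figures (312 TFLOP/s dense bf16, about 1.5 TB/s of memory bandwidth). All of these numbers are assumptions rather than measurements, and the power-limited GPU used in this tutorial will differ.

```python
def arithmetic_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for (m x k) @ (k x n), reading both inputs and writing the output once."""
    flops = 2 * m * k * n  # one multiply and one add per term
    bytes_moved = (m * k + k * n + m * n) * bytes_per_elem
    return flops / bytes_moved

# Assumed hardware figures (A100-class spec sheet, not measured here)
PEAK_FLOPS = 312e12      # dense bf16 FLOP/s
PEAK_BANDWIDTH = 1.5e12  # bytes/s
RIDGE = PEAK_FLOPS / PEAK_BANDWIDTH  # above this intensity a kernel is compute-bound

K, N = 1280, 3840  # assumed qkv projection: hidden size 1280 -> 3 * 1280
cases = [
    (78400, "batch 16 (assumed rows)"),
    (4900, "batch 1 (assumed rows)"),
    (1, "single row, e.g. decoding"),
]
for m, label in cases:
    ai = arithmetic_intensity(m, K, N)
    bound = "compute-bound" if ai > RIDGE else "memory-bound"
    print(f"{label:>26}: {ai:8.1f} FLOP/byte -> {bound} (ridge ~{RIDGE:.0f})")
```

Under these assumptions the SAM image encoder’s large matmuls sit well above the ridge point, which is why dynamic quantization (cheaper compute) is the better fit here, whereas a single-row workload would be memory-bound and favor weight-only quantization.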

In this case, Segment Anything is compute-bound, so we’ll use dynamic quantization:

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
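# quantize: replace the linear weights with int8 dynamic-quantization tensors
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')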
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the quantized block is 19.04ms and peak memory  3.58GB

With quantization, we have improved performance a bit more (19.95ms to 19.04ms), but memory usage increased significantly (2.24GB to 3.58GB).

This is for two reasons:

  1. Quantization adds overhead to the model since we need to quantize and dequantize the input and output. For small batch sizes this overhead can actually make the model go slower.

  2. Even though we are doing a quantized matmul, such as int8 x int8, the result of the multiplication gets stored in an int32 tensor which is twice the size of the result from the non-quantized model. If we can avoid creating this int32 tensor, our memory usage will improve a lot.

We can fix #2 by fusing the integer matmul with the subsequent rescale operation. Since the final output will be bf16, if we immediately convert the int32 result to bf16 and store that instead, we’ll get better performance in terms of both runtime and memory.
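The memory argument behind this fusion can be illustrated with a toy example. The sketch contrasts an unfused path, which materializes the full integer accumulator matrix before rescaling, with a fused path that rescales each accumulator as soon as it is computed. The helper names are hypothetical and this is not the kernel Inductor generates. It also estimates the size of the intermediate for an assumed 78400 x 3840 output, the shape assumed for the first block’s qkv projection.

```python
def int_mm_then_mul(a_q, b_q, row_scales, col_scales):
    """Unfused: store the whole integer accumulator matrix, then rescale it."""
    acc = [[sum(x * y for x, y in zip(row, col)) for col in zip(*b_q)] for row in a_q]
    return [[v * row_scales[i] * col_scales[j] for j, v in enumerate(r)]
            for i, r in enumerate(acc)]

def fused_int_mm_mul(a_q, b_q, row_scales, col_scales):
    """Fused: each accumulator is rescaled immediately, so no integer matrix is stored."""
    cols = list(zip(*b_q))
    return [[sum(x * y for x, y in zip(row, cols[j])) * row_scales[i] * col_scales[j]
             for j in range(len(cols))] for i, row in enumerate(a_q)]

def intermediate_bytes(m: int, n: int, bytes_per_elem: int) -> int:
    return m * n * bytes_per_elem

a_q = [[12, -7], [3, 127]]  # int8 activations (M=2, K=2)
b_q = [[5, -1], [64, 9]]    # int8 weights (K=2, N=2)
row_scales, col_scales = [0.01, 0.02], [0.5, 0.25]

unfused = int_mm_then_mul(a_q, b_q, row_scales, col_scales)
fused = fused_int_mm_mul(a_q, b_q, row_scales, col_scales)
print("same result:", unfused == fused)  # same math, different memory traffic

M, N = 78400, 3840  # assumed output shape
print(f"int32 accumulator: {intermediate_bytes(M, N, 4) / 1e9:.2f} GB")
print(f"bf16 output:       {intermediate_bytes(M, N, 2) / 1e9:.2f} GB")
```

For that assumed shape, the unfused int32 intermediate alone would take about 1.2 GB, twice the 0.6 GB of the bf16 output that the fused version writes directly.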

To do this, enable the force_fuse_int_mm_with_mul option in the inductor config.

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the fused quantized block is 18.78ms and peak memory  2.37GB

The fusion improves performance by another small bit (about 6% over the compiled bf16 baseline in total, 19.95ms to 18.78ms) and removes almost all of the memory increase. The remaining amount (2.37GB quantized vs. 2.24GB unquantized) is due to quantization overhead, which cannot be avoided.

We’re still not done, though: we can apply a few general-purpose optimizations to get our final best-case performance.

  1. We can sometimes improve performance by disabling epilogue fusion since the autotuning process can be confused by fusions and choose bad kernel parameters.

  2. We can apply coordinate descent tuning in all directions to enlarge the search area for kernel parameters.
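Coordinate descent tuning can be pictured with a toy search. The sketch below is a generic illustration, not Inductor’s tuner: starting from a kernel configuration, it tries halving or doubling one parameter at a time and keeps any change that lowers a made-up cost. When no single-parameter move helps, it can optionally try moving all parameters at once in every combination of directions, which is the intuition behind checking all directions. The parameter names `BLOCK_M` and `BLOCK_N` and the cost function are hypothetical.

```python
import itertools

def neighbours(value: int, lo: int = 16, hi: int = 256):
    """Candidate values one step away: halve or double, clamped to [lo, hi]."""
    return [v for v in (value // 2, value * 2) if lo <= v <= hi]

def coordinate_descent(cost, config: dict, check_all_directions: bool = False) -> dict:
    best, best_cost = dict(config), cost(config)
    improved = True
    while improved:
        improved = False
        # 1) move one parameter at a time
        for name in best:
            for v in neighbours(best[name]):
                cand = {**best, name: v}
                if cost(cand) < best_cost:
                    best, best_cost, improved = cand, cost(cand), True
        # 2) if stuck, try moving every parameter at once (all direction combinations)
        if not improved and check_all_directions:
            choices = [[best[n]] + neighbours(best[n]) for n in best]
            for combo in itertools.product(*choices):
                cand = dict(zip(best, combo))
                if cost(cand) < best_cost:
                    best, best_cost, improved = cand, cost(cand), True
    return best

def toy_cost(cfg):
    """Made-up cost: prefers square tiles with 2**14 elements."""
    m = cfg["BLOCK_M"].bit_length() - 1  # log2 of a power-of-two size
    n = cfg["BLOCK_N"].bit_length() - 1
    return 10 * abs(m - n) + abs(m + n - 14)

start = {"BLOCK_M": 32, "BLOCK_N": 32}
print(coordinate_descent(toy_cost, start))                             # stuck at the start
print(coordinate_descent(toy_cost, start, check_all_directions=True))  # reaches 128 x 128
```

In this toy cost, changing one tile dimension alone always makes the tile non-square and is penalized, so only a joint move escapes. The wider search costs more benchmark runs, which is why it is an opt-in setting.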

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
change_linear_weights_to_int8_dqtensors(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the final quantized block is 18.16ms and peak memory  2.39GB

As you can see, we’ve squeezed another small improvement from the model, taking our total improvement to over 10x compared to our original fp32 block (186.16ms to 18.16ms). To get a final estimate of the impact of quantization, let’s do an apples-to-apples comparison on the full model, since the actual improvement will differ block by block depending on the shapes involved.

try:
    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the compiled full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
    # bf16 compiled runtime of the compiled full model is 729.65ms and peak memory  23.96GB

    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    change_linear_weights_to_int8_dqtensors(model)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
    # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory  24.93GB
except Exception as e:
    print("unable to run full model: ", e)
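The speedups quoted throughout this tutorial can be re-derived from the reference timings in the code comments. The short script below does only that arithmetic; the step labels are descriptive names chosen here, and the numbers are the ones reported above for an A100 power limited to 330 W, so your own measurements will differ.

```python
# Reference timings (ms) copied from the code comments in this tutorial
block_ms = {
    "fp32": 186.16,
    "bf16": 25.43,
    "bf16 + compile": 19.95,
    "int8 dynamic quant": 19.04,
    "+ fused int_mm/mul": 18.78,
    "+ tuning flags": 18.16,
}
full_model_ms = {"compiled bf16": 729.65, "compiled int8 dynamic quant": 677.28}

def speedup(before_ms: float, after_ms: float) -> float:
    """How many times faster 'after' is than 'before'."""
    return before_ms / after_ms

steps = list(block_ms.items())
for (prev_name, prev_ms), (name, ms) in zip(steps, steps[1:]):
    print(f"{prev_name:>20} -> {name:<20} {speedup(prev_ms, ms):5.2f}x")

total = speedup(block_ms["fp32"], block_ms["+ tuning flags"])
print(f"single block, fp32 -> final: {total:.2f}x")
full_gain = (speedup(*full_model_ms.values()) - 1) * 100
print(f"full model, quantization: {full_gain:.1f}% faster")
```

Expressing each step as a ratio of consecutive runtimes makes the per-step gains directly comparable; the product of the per-step ratios equals the overall fp32-to-final ratio.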


In this tutorial, we have learned about quantization and optimization techniques, using the Segment Anything Model as an example.

In the end, we achieved a full-model, apples-to-apples quantization speedup of about 7.7% at batch size 16 (729.65ms to 677.28ms). We can push this a bit further by increasing the batch size and optimizing other parts of the model, for example with some form of flash attention. For more information, visit torchao (https://github.com/pytorch-labs/ao) and try it on your own models.




© Copyright 2018-2024, PyTorch & PyTorch Korea User Group.
