(beta) Using the TORCH_LOGS Python API with torch.compile

Author: Michael Lazos  Translator: 장효영

import logging

This tutorial introduces the ``TORCH_LOGS`` environment variable and the corresponding Python API, and demonstrates how to use them to observe the phases of ``torch.compile``.

Note

This tutorial requires PyTorch 2.2.0 or later.
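The version requirement can be checked before running the rest of the script. The sketch below is one way to do it, not part of the original tutorial: `parse_version` and `meets_minimum` are hypothetical helper names, and the code assumes the installed version string starts with `major.minor.patch` (local suffixes such as `+cu121` are ignored).

```python
import re

# Minimum version stated in the note: PyTorch 2.2.0.
MIN_VERSION = (2, 2, 0)


def parse_version(version: str) -> tuple:
    """Extract the leading (major, minor, patch) integers from a version string."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def meets_minimum(version: str, minimum: tuple = MIN_VERSION) -> bool:
    """Return True if `version` is at least `minimum`."""
    return parse_version(version) >= minimum


try:
    import torch

    print("PyTorch version is sufficient:", meets_minimum(torch.__version__))
except ImportError:
    print("PyTorch is not installed.")
```

The comparison relies on Python's tuple ordering, so `(2, 10, 0)` correctly compares greater than `(2, 2, 0)`, which a plain string comparison would get wrong.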

Setup

In this example, we'll set up a simple Python function that performs an elementwise add, and observe the compilation process with the TORCH_LOGS Python API.

Note

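There is also an environment variable, ``TORCH_LOGS``, which can be used to change logging settings at the command line. The equivalent environment variable setting is shown for each example.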
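The pairs of environment variable values and Python API arguments used in this tutorial can be collected in one place. The following is a minimal sketch rather than part of the original recipe: `ENV_TO_KWARGS` and `env_with_torch_logs` are hypothetical names, and the script name in the final comment is a placeholder. It assumes that ``TORCH_LOGS`` must already be set when the target process imports torch, which is why the helper builds an environment for a child process instead of changing the current one.

```python
import logging
import os

# TORCH_LOGS value -> keyword arguments for torch._logging.set_logs,
# as paired in this tutorial's examples.
ENV_TO_KWARGS = {
    "+dynamo": {"dynamo": logging.DEBUG},
    "graph": {"graph": True},
    "fusion": {"fusion": True},
    "output_code": {"output_code": True},
}


def env_with_torch_logs(setting, base_env=None):
    """Return a copy of the environment with TORCH_LOGS set (hypothetical helper)."""
    env = dict(os.environ if base_env is None else base_env)
    env["TORCH_LOGS"] = setting
    return env


env = env_with_torch_logs("+dynamo")
print(env["TORCH_LOGS"])

# To launch a script with this setting (the script name is a placeholder):
# subprocess.run([sys.executable, "my_script.py"], env=env, check=True)
```

Inside a running process, the matching Python call for a given value would be `torch._logging.set_logs(**ENV_TO_KWARGS[setting])`.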

import torch

# Exit cleanly if CUDA is unavailable or the device does not support torch.compile.
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Skipping because torch.compile is not supported on this device.")
else:
    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2


    inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))


    # Print a separator between examples and reset dynamo.
    def separator(name):
        print(f"==================={name}=========================")
        torch._dynamo.reset()


    separator("Dynamo Tracing")
    # View dynamo tracing
    # TORCH_LOGS="+dynamo"
    torch._logging.set_logs(dynamo=logging.DEBUG)
    fn(*inputs)

    separator("Traced Graph")
    # View the traced graph
    # TORCH_LOGS="graph"
    torch._logging.set_logs(graph=True)
    fn(*inputs)

    separator("Fusion Decisions")
    # View fusion decisions
    # TORCH_LOGS="fusion"
    torch._logging.set_logs(fusion=True)
    fn(*inputs)

    separator("Output Code")
    # View the output code generated by inductor
    # TORCH_LOGS="output_code"
    torch._logging.set_logs(output_code=True)
    fn(*inputs)

    separator("")
===================Dynamo Tracing=========================
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0] torchdynamo start compiling fn /workspace/tutorials-kr/recipes_source/torch_logs.py:44, stack (elided 5 frames):
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/bin/sphinx-build", line 8, in <module>
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     sys.exit(main())
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx/cmd/build.py", line 288, in main
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     return make_main(argv)
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx/cmd/build.py", line 193, in make_main
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     return make_mode.run_make_mode(argv[1:])
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     return make.run_generic_build(args[0])
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     return build_main(args + opts)
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx/cmd/build.py", line 272, in build_main
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx/application.py", line 256, in __init__
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     self._init_builder()
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx/application.py", line 314, in _init_builder
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     self.events.emit('builder-inited')
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx/events.py", line 94, in emit
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     results.append(listener.handler(self.app, *args))
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     ) = generate_dir_rst(
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     intro, title, cost = generate_file_rst(
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     output_blocks, time_elapsed = execute_script(script_blocks,
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     output_blocks.append(execute_code_block(
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     is_last_expr, mem_max = _exec_and_get_memory(
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     mem_max, _ = gallery_conf['call_memory'](
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     return 0., func()
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     exec(self.code, self.fake_main.__dict__)
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/workspace/tutorials-kr/recipes_source/torch_logs.py", line 63, in <module>
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     fn(*inputs)
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]   File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 451, in _fn
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]     return fn(*args, **kwargs)
V0803 13:23:12.843000 139636907595584 torch/_dynamo/convert_frame.py:652] [0/0]
I0803 13:23:12.845000 139636907595584 torch/_dynamo/logging.py:55] [0/0] Step 1: torchdynamo start tracing fn /workspace/tutorials-kr/recipes_source/torch_logs.py:44
V0803 13:23:12.845000 139636907595584 torch/fx/experimental/symbolic_shapes.py:1980] [0/0] create_env
V0803 13:23:12.846000 139636907595584 torch/_dynamo/symbolic_convert.py:699] [0/0] [__trace_source] TRACE starts_line /workspace/tutorials-kr/recipes_source/torch_logs.py:44 in fn ()
V0803 13:23:12.846000 139636907595584 torch/_dynamo/symbolic_convert.py:699] [0/0] [__trace_source]         @torch.compile()
V0803 13:23:12.847000 139636907595584 torch/_dynamo/symbolic_convert.py:699] [0/0] [__trace_source] TRACE starts_line /workspace/tutorials-kr/recipes_source/torch_logs.py:46 in fn (fn)
V0803 13:23:12.847000 139636907595584 torch/_dynamo/symbolic_convert.py:699] [0/0] [__trace_source]             z = x + y
V0803 13:23:12.847000 139636907595584 torch/_dynamo/symbolic_convert.py:725] [0/0] TRACE LOAD_FAST x []
V0803 13:23:12.848000 139636907595584 torch/_dynamo/symbolic_convert.py:725] [0/0] TRACE LOAD_FAST y [LazyVariableTracker()]
V0803 13:23:12.848000 139636907595584 torch/_dynamo/symbolic_convert.py:725] [0/0] TRACE BINARY_ADD None [LazyVariableTracker(), LazyVariableTracker()]
V0803 13:23:12.848000 139636907595584 torch/_dynamo/output_graph.py:1959] [0/0] create_graph_input L_x_ L['x']
V0803 13:23:12.848000 139636907595584 torch/_dynamo/variables/builder.py:1873] [0/0] wrap_to_fake L['x'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], constraint_sizes=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='x', cell_or_freevar=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0803 13:23:12.849000 139636907595584 torch/_dynamo/output_graph.py:1959] [0/0] create_graph_input L_y_ L['y']
V0803 13:23:12.849000 139636907595584 torch/_dynamo/variables/builder.py:1873] [0/0] wrap_to_fake L['y'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], constraint_sizes=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='y', cell_or_freevar=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0803 13:23:12.851000 139636907595584 torch/_dynamo/symbolic_convert.py:725] [0/0] TRACE STORE_FAST z [TensorVariable()]
V0803 13:23:12.851000 139636907595584 torch/_dynamo/symbolic_convert.py:699] [0/0] [__trace_source] TRACE starts_line /workspace/tutorials-kr/recipes_source/torch_logs.py:47 in fn (fn)
V0803 13:23:12.851000 139636907595584 torch/_dynamo/symbolic_convert.py:699] [0/0] [__trace_source]             return z + 2
V0803 13:23:12.852000 139636907595584 torch/_dynamo/symbolic_convert.py:725] [0/0] TRACE LOAD_FAST z []
V0803 13:23:12.852000 139636907595584 torch/_dynamo/symbolic_convert.py:725] [0/0] TRACE LOAD_CONST 2 [TensorVariable()]
V0803 13:23:12.852000 139636907595584 torch/_dynamo/symbolic_convert.py:725] [0/0] TRACE BINARY_ADD None [TensorVariable(), ConstantVariable()]
V0803 13:23:12.853000 139636907595584 torch/_dynamo/symbolic_convert.py:725] [0/0] TRACE RETURN_VALUE None [TensorVariable()]
I0803 13:23:12.853000 139636907595584 torch/_dynamo/logging.py:55] [0/0] Step 1: torchdynamo done tracing fn (RETURN_VALUE)
V0803 13:23:12.853000 139636907595584 torch/_dynamo/symbolic_convert.py:2267] [0/0] RETURN_VALUE triggered compile
V0803 13:23:12.853000 139636907595584 torch/_dynamo/output_graph.py:871] [0/0] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /workspace/tutorials-kr/recipes_source/torch_logs.py, line 47 in fn>], graph_break=False)
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code] TRACED GRAPH
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]  ===== __compiled_fn_10 =====
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]  /opt/conda/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]     def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]         l_x_ = L_x_
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]         l_y_ = L_y_
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]         # File: /workspace/tutorials-kr/recipes_source/torch_logs.py:46 in fn, code: z = x + y
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]         z = l_x_ + l_y_;  l_x_ = l_y_ = None
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]         # File: /workspace/tutorials-kr/recipes_source/torch_logs.py:47 in fn, code: return z + 2
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]         add_1 = z + 2;  z = None
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]         return (add_1,)
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]
V0803 13:23:12.854000 139636907595584 torch/_dynamo/output_graph.py:1157] [0/0] [__graph_code]
V0803 13:23:12.855000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] TRACED GRAPH
V0803 13:23:12.855000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph]  __compiled_fn_10 /opt/conda/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py opcode         name    target                   args          kwargs
V0803 13:23:12.855000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] -------------  ------  -----------------------  ------------  --------
V0803 13:23:12.855000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] placeholder    l_x_    L_x_                     ()            {}
V0803 13:23:12.855000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] placeholder    l_y_    L_y_                     ()            {}
V0803 13:23:12.855000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] call_function  z       <built-in function add>  (l_x_, l_y_)  {}
V0803 13:23:12.855000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] call_function  add_1   <built-in function add>  (z, 2)        {}
V0803 13:23:12.855000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] output         output  output                   ((add_1,),)   {}
V0803 13:23:12.855000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph]
V0803 13:23:12.856000 139636907595584 torch/_dynamo/output_graph.py:1164] [0/0] [__graph_sizes] TRACED GRAPH TENSOR SIZES
V0803 13:23:12.856000 139636907595584 torch/_dynamo/output_graph.py:1164] [0/0] [__graph_sizes] ===== __compiled_fn_10 =====
V0803 13:23:12.856000 139636907595584 torch/_dynamo/output_graph.py:1164] [0/0] [__graph_sizes] l_x_: (2, 2)
V0803 13:23:12.856000 139636907595584 torch/_dynamo/output_graph.py:1164] [0/0] [__graph_sizes] l_y_: (2, 2)
V0803 13:23:12.856000 139636907595584 torch/_dynamo/output_graph.py:1164] [0/0] [__graph_sizes] z: (2, 2)
V0803 13:23:12.856000 139636907595584 torch/_dynamo/output_graph.py:1164] [0/0] [__graph_sizes] add_1: (2, 2)
V0803 13:23:12.856000 139636907595584 torch/_dynamo/output_graph.py:1164] [0/0] [__graph_sizes]
I0803 13:23:12.857000 139636907595584 torch/_dynamo/logging.py:55] [0/0] Step 2: calling compiler function inductor
V0803 13:23:12.875000 139636907595584 torch/fx/experimental/symbolic_shapes.py:4119] [0/0] eval True == True [statically known]
I0803 13:23:13.203000 139636907595584 torch/_dynamo/logging.py:55] [0/0] Step 2: done compiler function inductor
I0803 13:23:13.205000 139636907595584 torch/fx/experimental/symbolic_shapes.py:2806] [0/0] produce_guards
V0803 13:23:13.205000 139636907595584 torch/fx/experimental/symbolic_shapes.py:2988] [0/0] track_symint L['x'].size()[0] 2 None
V0803 13:23:13.205000 139636907595584 torch/fx/experimental/symbolic_shapes.py:2988] [0/0] track_symint L['x'].size()[1] 2 None
V0803 13:23:13.205000 139636907595584 torch/fx/experimental/symbolic_shapes.py:2988] [0/0] track_symint L['x'].stride()[0] 2 None
V0803 13:23:13.205000 139636907595584 torch/fx/experimental/symbolic_shapes.py:2988] [0/0] track_symint L['x'].stride()[1] 1 None
V0803 13:23:13.206000 139636907595584 torch/fx/experimental/symbolic_shapes.py:2988] [0/0] track_symint L['x'].storage_offset() 0 None
V0803 13:23:13.206000 139636907595584 torch/fx/experimental/symbolic_shapes.py:2988] [0/0] track_symint L['y'].size()[0] 2 None
V0803 13:23:13.206000 139636907595584 torch/fx/experimental/symbolic_shapes.py:2988] [0/0] track_symint L['y'].size()[1] 2 None
V0803 13:23:13.206000 139636907595584 torch/fx/experimental/symbolic_shapes.py:2988] [0/0] track_symint L['y'].stride()[0] 2 None
V0803 13:23:13.206000 139636907595584 torch/fx/experimental/symbolic_shapes.py:2988] [0/0] track_symint L['y'].stride()[1] 1 None
V0803 13:23:13.206000 139636907595584 torch/fx/experimental/symbolic_shapes.py:2988] [0/0] track_symint L['y'].storage_offset() 0 None
V0803 13:23:13.206000 139636907595584 torch/fx/experimental/symbolic_shapes.py:3138] [0/0] Skipping guard L['x'].size()[0] == 2
V0803 13:23:13.207000 139636907595584 torch/fx/experimental/symbolic_shapes.py:3138] [0/0] Skipping guard L['x'].size()[1] == 2
V0803 13:23:13.207000 139636907595584 torch/fx/experimental/symbolic_shapes.py:3138] [0/0] Skipping guard L['x'].stride()[0] == 2
V0803 13:23:13.207000 139636907595584 torch/fx/experimental/symbolic_shapes.py:3138] [0/0] Skipping guard L['x'].stride()[1] == 1
V0803 13:23:13.207000 139636907595584 torch/fx/experimental/symbolic_shapes.py:3138] [0/0] Skipping guard L['x'].storage_offset() == 0
V0803 13:23:13.207000 139636907595584 torch/fx/experimental/symbolic_shapes.py:3138] [0/0] Skipping guard L['y'].size()[0] == 2
V0803 13:23:13.207000 139636907595584 torch/fx/experimental/symbolic_shapes.py:3138] [0/0] Skipping guard L['y'].size()[1] == 2
V0803 13:23:13.207000 139636907595584 torch/fx/experimental/symbolic_shapes.py:3138] [0/0] Skipping guard L['y'].stride()[0] == 2
V0803 13:23:13.208000 139636907595584 torch/fx/experimental/symbolic_shapes.py:3138] [0/0] Skipping guard L['y'].stride()[1] == 1
V0803 13:23:13.208000 139636907595584 torch/fx/experimental/symbolic_shapes.py:3138] [0/0] Skipping guard L['y'].storage_offset() == 0
V0803 13:23:13.208000 139636907595584 torch/_dynamo/guards.py:1076] [0/0] [__guards] GUARDS:
V0803 13:23:13.208000 139636907595584 torch/_dynamo/guards.py:1085] [0/0] [__guards] hasattr(L['x'], '_dynamo_dynamic_indices') == False           # z = x + y  # orkspace/tutorials-kr/recipes_source/torch_logs.py:46 in fn
V0803 13:23:13.209000 139636907595584 torch/_dynamo/guards.py:1085] [0/0] [__guards] hasattr(L['y'], '_dynamo_dynamic_indices') == False           # z = x + y  # orkspace/tutorials-kr/recipes_source/torch_logs.py:46 in fn
V0803 13:23:13.209000 139636907595584 torch/_dynamo/guards.py:1085] [0/0] [__guards] utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:430 in init_ambient_guards
V0803 13:23:13.210000 139636907595584 torch/_dynamo/guards.py:1085] [0/0] [__guards] ___check_current_backend(139623618138688)                     # _dynamo/output_graph.py:436 in init_ambient_guards
V0803 13:23:13.210000 139636907595584 torch/_dynamo/guards.py:1085] [0/0] [__guards] check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # orkspace/tutorials-kr/recipes_source/torch_logs.py:46 in fn
V0803 13:23:13.211000 139636907595584 torch/_dynamo/guards.py:1085] [0/0] [__guards] check_tensor(L['y'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # orkspace/tutorials-kr/recipes_source/torch_logs.py:46 in fn
===================Traced Graph=========================
V0803 13:23:13.232000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] TRACED GRAPH
V0803 13:23:13.232000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph]  __compiled_fn_11 /opt/conda/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py opcode         name    target                   args          kwargs
V0803 13:23:13.232000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] -------------  ------  -----------------------  ------------  --------
V0803 13:23:13.232000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] placeholder    l_x_    L_x_                     ()            {}
V0803 13:23:13.232000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] placeholder    l_y_    L_y_                     ()            {}
V0803 13:23:13.232000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] call_function  z       <built-in function add>  (l_x_, l_y_)  {}
V0803 13:23:13.232000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] call_function  add_1   <built-in function add>  (z, 2)        {}
V0803 13:23:13.232000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph] output         output  output                   ((add_1,),)   {}
V0803 13:23:13.232000 139636907595584 torch/_dynamo/output_graph.py:1163] [0/0] [__graph]
===================Fusion Decisions=========================
V0803 13:23:13.355000 139636907595584 torch/_inductor/scheduler.py:1683] [0/0] [__fusion] ===== attempting fusion (1/10): 1 nodes =====
V0803 13:23:13.355000 139636907595584 torch/_inductor/scheduler.py:1885] [0/0] [__fusion] found 0 possible fusions
V0803 13:23:13.356000 139636907595584 torch/_inductor/scheduler.py:1688] [0/0] [__fusion] completed fusion round (1/10): fused 1 nodes into 1 nodes
V0803 13:23:13.356000 139636907595584 torch/_inductor/scheduler.py:1688] [0/0] [__fusion]
V0803 13:23:13.356000 139636907595584 torch/_inductor/scheduler.py:1695] [0/0] [__fusion] ===== fusion complete (1 iterations) =====
===================Output Code=========================
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] Output code:
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from ctypes import c_void_p, c_long
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] import torch
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] import math
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] import random
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] import os
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] import tempfile
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from math import inf, nan
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch import device, empty_strided
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor.codecache import AsyncCompile
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] aten = torch.ops.aten
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] reinterpret_tensor = torch.ops.inductor._reinterpret_tensor
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] async_compile = AsyncCompile()
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] # kernel path: /tmp/torchinductor_root/ms/cmsgrbdorrqtkj5nb2og2nucyc4kkdabkby2b7ynwk6ckyr6wj2u.py
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] # Source Nodes: [add_1, z], Original ATen: [aten.add]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] # add_1 => add_1
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] # z => add
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] triton_poi_fused_add_0 = async_compile.triton('triton_', '''
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] import triton
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] import triton.language as tl
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from triton.compiler.compiler import AttrsDescriptor
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor import triton_helpers, triton_heuristics
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor.ir import ReductionHint, TileHint
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor.triton_helpers import libdevice, math as tl_math
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor.triton_heuristics import AutotuneHint
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor.utils import instance_descriptor
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] @triton_heuristics.pointwise(
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     size_hints=[4],
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     filename=__file__,
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     triton_meta={'signature': {0: '*fp32', 1: '*fp32', 2: '*fp32', 3: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=())]},
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': 'ab56b6b9315b80b94abf24aacc1ceb5a37d2037dfc8cf4997f86e7c42fbaa402'},
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     min_elem_per_thread=0
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] )
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] @triton.jit
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     xnumel = 4
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     xoffset = tl.program_id(0) * XBLOCK
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     xmask = xindex < xnumel
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     x0 = xindex
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     tmp1 = tl.load(in_ptr1 + (x0), xmask)
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     tmp2 = tmp0 + tmp1
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     tmp3 = 2.0
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     tmp4 = tmp2 + tmp3
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     tl.store(out_ptr0 + (x0), tmp4, xmask)
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] ''', device_str='cuda')
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] import triton
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] import triton.language as tl
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._inductor.triton_heuristics import grid, split_scan_grid, start_graph, end_graph
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] async_compile.wait(globals())
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] del async_compile
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] def call(args):
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     arg0_1, arg1_1 = args
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     args.clear()
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     assert_size_stride(arg0_1, (2, 2), (2, 1))
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     assert_size_stride(arg1_1, (2, 2), (2, 1))
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]         torch.cuda.set_device(0)
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]         buf0 = empty_strided_cuda((2, 2), (2, 1), torch.float32)
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]         # Source Nodes: [add_1, z], Original ATen: [aten.add]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]         stream0 = get_raw_stream(0)
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]         triton_poi_fused_add_0.run(arg0_1, arg1_1, buf0, 4, grid=grid(4), stream=stream0)
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]         del arg0_1
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]         del arg1_1
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     return (buf0, )
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     arg0_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     arg1_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     fn = lambda: call([arg0_1, arg1_1])
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code] if __name__ == "__main__":
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V0803 13:23:13.965000 139636907595584 torch/_inductor/graph.py:1267] [0/0] [__output_code]
I0803 13:23:13.968000 139636907595584 torch/_inductor/graph.py:1273] [0/0] [__output_code] Output code written to: /tmp/torchinductor_root/bt/cbtxarzr5qsbqr2wamsfhyug537ntdyyju3lna5z4kealqq4qs63.py
============================================

Conclusion

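In this tutorial, we introduced the TORCH_LOGS environment variable and the Python API by experimenting with a few of the available logging options. To see descriptions of all available options, run any Python script that imports torch with TORCH_LOGS set to "help".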
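As a minimal sketch of the step above, the snippet below starts a child Python process that only imports torch, with TORCH_LOGS set to "help" in its environment. The helper name build_help_invocation is hypothetical and exists only for illustration; the sketch assumes torch is installed for the interpreter that runs it.

```python
import os
import subprocess
import sys


def build_help_invocation():
    """Return (command, environment) for a child process that imports torch
    with TORCH_LOGS set to "help". Hypothetical helper, for illustration only."""
    # Copy the current environment and add the TORCH_LOGS setting.
    env = dict(os.environ, TORCH_LOGS="help")
    # The child process does nothing except import torch.
    command = [sys.executable, "-c", "import torch"]
    return command, env


if __name__ == "__main__":
    command, env = build_help_invocation()
    # check=False: do not raise if torch is not installed for this interpreter.
    subprocess.run(command, env=env, check=False)
```

From a shell, the same thing can be done by setting TORCH_LOGS="help" in front of any command that runs a script importing torch.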

Alternatively, see the torch._logging documentation for descriptions of all available logging options.
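The examples in this tutorial pair each Python API call with an environment variable setting: dynamo=logging.DEBUG corresponds to TORCH_LOGS="+dynamo", and graph=True corresponds to TORCH_LOGS="graph". The sketch below writes that pairing out as a small pure-Python helper. The name to_torch_logs is hypothetical and is not part of torch. The comma separator and the mapping of logging.INFO to an unprefixed name are assumptions based on the torch._logging documentation, so check them against the documentation for your PyTorch version.

```python
import logging


def to_torch_logs(**options):
    """Build a TORCH_LOGS-style string from set_logs-style keyword arguments.

    Hypothetical helper for illustration; it only covers the cases shown in
    this tutorial (True for artifacts, DEBUG or INFO for components).
    """
    parts = []
    for name, value in options.items():
        if value is True:
            # Artifacts such as graph, fusion, output_code are switched on by name.
            parts.append(name)
        elif value == logging.DEBUG:
            # A leading "+" requests DEBUG level for a component.
            parts.append("+" + name)
        elif value == logging.INFO:
            # Assumption: an unprefixed component name means INFO level.
            parts.append(name)
        else:
            raise ValueError(f"unsupported setting for {name!r}: {value!r}")
    # Assumption: multiple settings are separated by commas.
    return ",".join(parts)


print(to_torch_logs(dynamo=logging.DEBUG, graph=True))  # prints +dynamo,graph
```

The helper never imports torch. It produces the string you would export as TORCH_LOGS in place of the equivalent torch._logging.set_logs call.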

For more information on torch.compile, see the torch.compile tutorial.

Total running time of the script: ( 0 minutes 1.146 seconds)

© Copyright 2018-2024, PyTorch & PyTorch Korea User Group.
