Shortcuts

wav2vec2을 이용한 강제 정렬

저자: Moto Hira 번역: 김규진

이번 튜토리얼에서는 CTC-Segmentation of Large Corpora for German End-to-end Speech Recognition 에서 설명한 CTC 분할 알고리즘을 이용하여 torchaudio를 가지고 정답 스크립트를 음성에 맞추는 방법에 대해 설명합니다.

참고

이 튜토리얼은 원래 Wav2Vec2의 사용 사례를 설명하기 위해 작성되었습니다.

TorchAudio에는 강제 정렬을 위해 설계된 API가 있습니다. CTC forced alignment API tutorial 은 핵심 API인 torchaudio.functional.forced_align() 의 사용법에 대해 보여주고 있습니다.

만약 본인만의 코퍼스에 대해 강제 정렬하려는 경우, torchaudio.pipelines.Wav2Vec2FABundle 를 사용하는 것을 추천합니다. 이는 강제 정렬을 위해 특별히 훈련된 사전 훈련 모델과 함께 forced_align() 및 여러 함수를 결합하여 사용할 수 있게 합니다. 사용법에 대한 자세한 내용은 다국어 데이터를 위한 강제 정렬을 설명하는 Forced alignment for multilingual data 를 참조하세요.

import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
2.0.1+cu117
2.0.2+cu117
cuda

개요

정렬 과정은 다음과 같습니다.

  1. 오디오 파형으로부터 프레임별 라벨 확률을 추정한다.

  2. 각 시간 별로 정렬된 라벨의 확률을 나타내는 trellis 행렬을 생성한다.

  3. trellis 행렬로부터 가장 가능성이 높은 경로를 찾는다.

이번 예시에는 음성 특징 추출을 위해 torchaudio의 wav2vec2 모델을 사용합니다.

준비

먼저 필요한 패키지를 임포트하고, 작업할 데이터를 불러옵니다.

from dataclasses import dataclass

import IPython
import matplotlib.pyplot as plt

torch.random.manual_seed(0)

SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
  0%|          | 0.00/106k [00:00<?, ?B/s]
 23%|##2       | 24.0k/106k [00:00<00:00, 128kB/s]
 53%|#####2    | 56.0k/106k [00:00<00:00, 153kB/s]
100%|##########| 106k/106k [00:00<00:00, 282kB/s]

프레임 별 라벨 확률 생성

첫번째 과정은 각 오디오 프레임 별 라벨 클래스 확률을 생성하는 것입니다. ASR(음성 인식)용으로 학습된 wav2vec2 모델을 사용할 수 있습니다. 여기서는 torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H() 를 사용합니다.

``torchaudio``는 연관된 라벨과 함께 미리 학습된 모델에 쉽게 접근할 수 있게 합니다.

참고

여기서는 수치적인 불안정성을 피하고자 로그 도메인에서 확률을 계산할 것입니다. 이렇게 하기 위해 torch.log_softmax() 를 사용하여 출력 확률 을 정규화합니다.

bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
model = bundle.get_model().to(device)
labels = bundle.get_labels()
with torch.inference_mode():
    waveform, _ = torchaudio.load(SPEECH_FILE)
    emissions, _ = model(waveform.to(device))
    emissions = torch.log_softmax(emissions, dim=-1)

emission = emissions[0].cpu().detach()

print(labels)
Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /workspace/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth

  0%|          | 0.00/360M [00:00<?, ?B/s]
  2%|1         | 7.02M/360M [00:00<00:05, 73.5MB/s]
  4%|4         | 14.6M/360M [00:00<00:04, 76.8MB/s]
  6%|6         | 21.9M/360M [00:00<00:04, 72.3MB/s]
  8%|8         | 29.7M/360M [00:00<00:04, 75.8MB/s]
 10%|#         | 37.6M/360M [00:00<00:04, 78.2MB/s]
 13%|#2        | 45.8M/360M [00:00<00:04, 80.9MB/s]
 15%|#4        | 53.6M/360M [00:00<00:03, 81.0MB/s]
 17%|#7        | 61.3M/360M [00:00<00:03, 80.6MB/s]
 19%|#9        | 69.0M/360M [00:00<00:03, 79.2MB/s]
 21%|##1       | 76.7M/360M [00:01<00:03, 79.6MB/s]
 24%|##3       | 84.6M/360M [00:01<00:03, 80.7MB/s]
 26%|##5       | 92.3M/360M [00:01<00:03, 77.8MB/s]
 28%|##7       | 99.8M/360M [00:01<00:03, 74.2MB/s]
 30%|##9       | 108M/360M [00:01<00:03, 76.3MB/s]
 32%|###1      | 115M/360M [00:01<00:03, 76.5MB/s]
 34%|###3      | 122M/360M [00:01<00:03, 75.7MB/s]
 36%|###6      | 130M/360M [00:01<00:03, 76.6MB/s]
 38%|###8      | 137M/360M [00:01<00:03, 75.8MB/s]
 40%|####      | 145M/360M [00:01<00:02, 77.2MB/s]
 42%|####2     | 152M/360M [00:02<00:02, 76.3MB/s]
 44%|####4     | 160M/360M [00:02<00:02, 72.7MB/s]
 46%|####6     | 167M/360M [00:02<00:02, 71.8MB/s]
 48%|####8     | 173M/360M [00:02<00:02, 71.5MB/s]
 50%|#####     | 180M/360M [00:02<00:02, 70.2MB/s]
 52%|#####2    | 188M/360M [00:02<00:02, 71.8MB/s]
 54%|#####3    | 194M/360M [00:02<00:02, 70.8MB/s]
 56%|#####6    | 202M/360M [00:02<00:02, 72.7MB/s]
 58%|#####7    | 209M/360M [00:02<00:02, 70.8MB/s]
 60%|######    | 216M/360M [00:03<00:02, 73.7MB/s]
 62%|######2   | 224M/360M [00:03<00:01, 75.7MB/s]
 64%|######4   | 231M/360M [00:03<00:01, 76.0MB/s]
 66%|######6   | 239M/360M [00:03<00:01, 72.6MB/s]
 68%|######8   | 246M/360M [00:03<00:01, 69.8MB/s]
 70%|#######   | 253M/360M [00:03<00:01, 72.4MB/s]
 72%|#######2  | 260M/360M [00:03<00:01, 68.3MB/s]
 74%|#######4  | 268M/360M [00:03<00:01, 72.1MB/s]
 76%|#######6  | 275M/360M [00:03<00:01, 71.9MB/s]
 78%|#######8  | 282M/360M [00:03<00:01, 72.8MB/s]
 80%|########  | 289M/360M [00:04<00:01, 71.9MB/s]
 82%|########2 | 297M/360M [00:04<00:00, 74.3MB/s]
 84%|########4 | 304M/360M [00:04<00:00, 74.3MB/s]
 86%|########6 | 311M/360M [00:04<00:00, 74.0MB/s]
 88%|########8 | 318M/360M [00:04<00:00, 72.4MB/s]
 90%|######### | 325M/360M [00:04<00:00, 73.2MB/s]
 92%|#########2| 332M/360M [00:04<00:00, 73.4MB/s]
 94%|#########4| 340M/360M [00:04<00:00, 74.8MB/s]
 96%|#########6| 347M/360M [00:04<00:00, 76.4MB/s]
 98%|#########8| 355M/360M [00:04<00:00, 76.6MB/s]
100%|##########| 360M/360M [00:05<00:00, 74.7MB/s]
('-', '|', 'E', 'T', 'A', 'O', 'N', 'I', 'H', 'S', 'R', 'D', 'L', 'U', 'M', 'W', 'C', 'F', 'G', 'Y', 'P', 'B', 'V', 'K', "'", 'X', 'J', 'Q', 'Z')

시각화

def plot():
    fig, ax = plt.subplots()
    img = ax.imshow(emission.T)
    ax.set_title("Frame-wise class probability")
    ax.set_xlabel("Time")
    ax.set_ylabel("Labels")
    fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
    fig.tight_layout()


plot()
Frame-wise class probability

정렬 확률 생성 (trellis)

다음, 출력 행렬로부터 각 프레임에서 정답 스크립트의 라벨들이 등장할 수 있는 확률을 나타내는 trellis를 생성합니다. Trellis는 시간 축과 라벨 축을 가지고 있는 2D 행렬입니다. 라벨 축은 정렬하려는 정답 스크립트를 나타냅니다. \(t\) 를 시간 축에서의 인덱스로 나타내는 데 사용하고, \(j\) 를 라벨 축에서의 인덱스로 나타내는 데 사용합니다. \(c_j\) 는 라벨 인덱스 \(j\) 에서의 라벨을 나타냅니다.

\(t+1\) 시점에서의 확률을 생성하기 위해, \(t\) 시점에서의 trellis와 \(t+1\) 시점에서의 출력을 봅니다. \(t+1\) 시점에서 \(c_{j+1}\) 라벨에 도달할 수 있는 2개의 경로가 있습니다. 첫번째는 \(t\) 시점에서 라벨은 \(c_{j+1}\) 였으며 \(t\) 에서 \({t+1}\) 으로 바뀔 때 라벨 변화는 없는 경우입니다. 다른 경우는 \(t\) 시점에서 라벨은 \(c_j\) 였으며 \(t+1\) 시점에서는 다음 라벨 \(c_{j+1}\) 로 전이된 경우입니다.

아래 그림은 2가지 전이를 나타내고 있습니다.

https://download.pytorch.org/torchaudio/tutorial-assets/ctc-forward.png

가장 가능성 있는 전이를 찾기 위해, \(k_{(t+1, j+1)}\) 의 값의 가장 가능성 있는 경로를 택합니다. 이는 아래에 나와 있는 식으로 나타낼 수 있습니다.

\(k_{(t+1, j+1)} = max( k_{(t, j)} p(t+1, c_{j+1}), k_{(t, j+1)} p(t+1, repeat) )\)

\(k\) 는 trellis 행렬을 나타내며, \(p(t, c_j)\)\(t\) 시점에서의 라벨 \(c_j\) 의 확률을 나타냅니다. repeat는 CTC 식에서의 블랭크 토큰을 나타냅니다. (CTC 알고리즘에 대한 자세한 설명은 〈Sequence Modeling with CTC’를 참고하세요.) [distill.pub])

# SOS와 EOS를 나타내는 space 토큰을 가지고 정답 스크립트를 둘러쌈.
transcript = "|I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|"
dictionary = {c: i for i, c in enumerate(labels)}

tokens = [dictionary[c] for c in transcript]
print(list(zip(transcript, tokens)))


def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    trellis = torch.zeros((num_frame, num_tokens))
    trellis[1:, 0] = torch.cumsum(emission[1:, blank_id], 0)
    trellis[0, 1:] = -float("inf")
    trellis[-num_tokens + 1 :, 0] = float("inf")

    for t in range(num_frame - 1):
        trellis[t + 1, 1:] = torch.maximum(
            # 같은 토큰에 머무르고 있는 경우의 점수
            trellis[t, 1:] + emission[t, blank_id],
            # 다음 토큰으로 바뀌는 경우에 대한 점수
            trellis[t, :-1] + emission[t, tokens[1:]],
        )
    return trellis


trellis = get_trellis(emission, tokens)
[('|', 1), ('I', 7), ('|', 1), ('H', 8), ('A', 4), ('D', 11), ('|', 1), ('T', 3), ('H', 8), ('A', 4), ('T', 3), ('|', 1), ('C', 16), ('U', 13), ('R', 10), ('I', 7), ('O', 5), ('S', 9), ('I', 7), ('T', 3), ('Y', 19), ('|', 1), ('B', 21), ('E', 2), ('S', 9), ('I', 7), ('D', 11), ('E', 2), ('|', 1), ('M', 14), ('E', 2), ('|', 1), ('A', 4), ('T', 3), ('|', 1), ('T', 3), ('H', 8), ('I', 7), ('S', 9), ('|', 1), ('M', 14), ('O', 5), ('M', 14), ('E', 2), ('N', 6), ('T', 3), ('|', 1)]

시각화

def plot():
    fig, ax = plt.subplots()
    img = ax.imshow(trellis.T, origin="lower")
    ax.annotate("- Inf", (trellis.size(1) / 5, trellis.size(1) / 1.5))
    ax.annotate("+ Inf", (trellis.size(0) - trellis.size(1) / 5, trellis.size(1) / 3))
    fig.colorbar(img, ax=ax, shrink=0.6, location="bottom")
    fig.tight_layout()


plot()
forced alignment with torchaudio tutorial

위 시각화된 그림에서, 행렬을 대각선으로 가로지르는 높은 확률의 추적(trace)를 볼 수 있습니다.

가장 가능성 높은 경로 찾기 (백트래킹)

trellis가 한번 생성되면, 높은 확률을 가지는 요소를 따라 이를 탐색할 수 있습니다.

가장 높은 확률을 가지는 시간 단계에서 마지막 라벨 인덱스로부터 시작합니다. 그 후에, 이전 전이 확률 \(k_{t, j} p(t+1, c_{j+1})\) 또는 \(k_{t, j+1} p(t+1, repeat)\)) 또는 전이할지 (\(c_j \rightarrow c_{j+1}\))를 시간 역순으로 진행합니다.

라벨이 한번 시작 부분에 도달하게 되면, 전이가 수행됩니다.

trellis 행렬은 경로를 찾기 위해 사용되지만, 각 분할의 최종 확률에 대해서는 출력 행렬에서의 프레임별 확률을 사용합니다.

@dataclass
class Point:
    token_index: int
    time_index: int
    score: float


def backtrack(trellis, emission, tokens, blank_id=0):
    t, j = trellis.size(0) - 1, trellis.size(1) - 1

    path = [Point(j, t, emission[t, blank_id].exp().item())]
    while j > 0:
        # 발생하지 않는 경우지만, 혹시 몰라 예외 처리함.
        assert t > 0

        # 1. 현재 위치가 stay인지 또는 change인지를 판단함.
        # stay 대 change의 프레임 별 점수 계산
        p_stay = emission[t - 1, blank_id]
        p_change = emission[t - 1, tokens[j]]

        # stay 대 change의 문맥을 고려한 점수 계산
        stayed = trellis[t - 1, j] + p_stay
        changed = trellis[t - 1, j - 1] + p_change

        # 위치 갱신
        t -= 1
        if changed > stayed:
            j -= 1

        # 프레임별 확률을 이용하여 경로를 저장함.
        prob = (p_change if changed > stayed else p_stay).exp().item()
        path.append(Point(j, t, prob))

    # 지금 j == 0이라면 이는, SoS를 도달했다는 것을 의미함.
    # 시각화를 위해 나머지 부분을 채움.
    while t > 0:
        prob = emission[t - 1, blank_id].exp().item()
        path.append(Point(j, t - 1, prob))
        t -= 1

    return path[::-1]


path = backtrack(trellis, emission, tokens)
for p in path:
    print(p)
Point(token_index=0, time_index=0, score=0.9999996423721313)
Point(token_index=0, time_index=1, score=0.9999996423721313)
Point(token_index=0, time_index=2, score=0.9999996423721313)
Point(token_index=0, time_index=3, score=0.9999996423721313)
Point(token_index=0, time_index=4, score=0.9999996423721313)
Point(token_index=0, time_index=5, score=0.9999996423721313)
Point(token_index=0, time_index=6, score=0.9999996423721313)
Point(token_index=0, time_index=7, score=0.9999996423721313)
Point(token_index=0, time_index=8, score=0.9999998807907104)
Point(token_index=0, time_index=9, score=0.9999996423721313)
Point(token_index=0, time_index=10, score=0.9999996423721313)
Point(token_index=0, time_index=11, score=0.9999998807907104)
Point(token_index=0, time_index=12, score=0.9999996423721313)
Point(token_index=0, time_index=13, score=0.9999996423721313)
Point(token_index=0, time_index=14, score=0.9999996423721313)
Point(token_index=0, time_index=15, score=0.9999996423721313)
Point(token_index=0, time_index=16, score=0.9999996423721313)
Point(token_index=0, time_index=17, score=0.9999996423721313)
Point(token_index=0, time_index=18, score=0.9999998807907104)
Point(token_index=0, time_index=19, score=0.9999996423721313)
Point(token_index=0, time_index=20, score=0.9999996423721313)
Point(token_index=0, time_index=21, score=0.9999996423721313)
Point(token_index=0, time_index=22, score=0.9999996423721313)
Point(token_index=0, time_index=23, score=0.9999997615814209)
Point(token_index=0, time_index=24, score=0.9999998807907104)
Point(token_index=0, time_index=25, score=0.9999998807907104)
Point(token_index=0, time_index=26, score=0.9999998807907104)
Point(token_index=0, time_index=27, score=0.9999998807907104)
Point(token_index=0, time_index=28, score=0.9999985694885254)
Point(token_index=0, time_index=29, score=0.9999943971633911)
Point(token_index=0, time_index=30, score=0.9999842643737793)
Point(token_index=1, time_index=31, score=0.9846184849739075)
Point(token_index=1, time_index=32, score=0.999970555305481)
Point(token_index=1, time_index=33, score=0.15382443368434906)
Point(token_index=1, time_index=34, score=0.9999172687530518)
Point(token_index=2, time_index=35, score=0.6086152791976929)
Point(token_index=2, time_index=36, score=0.9997722506523132)
Point(token_index=3, time_index=37, score=0.999713122844696)
Point(token_index=3, time_index=38, score=0.9999358654022217)
Point(token_index=4, time_index=39, score=0.9861729741096497)
Point(token_index=4, time_index=40, score=0.9242315888404846)
Point(token_index=5, time_index=41, score=0.926000714302063)
Point(token_index=5, time_index=42, score=0.015565798617899418)
Point(token_index=5, time_index=43, score=0.9998375177383423)
Point(token_index=6, time_index=44, score=0.9988489151000977)
Point(token_index=7, time_index=45, score=0.10210870206356049)
Point(token_index=7, time_index=46, score=0.9999427795410156)
Point(token_index=8, time_index=47, score=0.9999943971633911)
Point(token_index=8, time_index=48, score=0.9979603290557861)
Point(token_index=9, time_index=49, score=0.03601734712719917)
Point(token_index=9, time_index=50, score=0.06166908144950867)
Point(token_index=9, time_index=51, score=4.3360221752664074e-05)
Point(token_index=10, time_index=52, score=0.9999799728393555)
Point(token_index=11, time_index=53, score=0.9967053532600403)
Point(token_index=11, time_index=54, score=0.9999257326126099)
Point(token_index=11, time_index=55, score=0.9999982118606567)
Point(token_index=12, time_index=56, score=0.9990670084953308)
Point(token_index=12, time_index=57, score=0.9999996423721313)
Point(token_index=12, time_index=58, score=0.9999996423721313)
Point(token_index=12, time_index=59, score=0.8453131914138794)
Point(token_index=12, time_index=60, score=0.9999996423721313)
Point(token_index=13, time_index=61, score=0.9996007084846497)
Point(token_index=13, time_index=62, score=0.999998927116394)
Point(token_index=14, time_index=63, score=0.003528432222083211)
Point(token_index=14, time_index=64, score=1.0)
Point(token_index=14, time_index=65, score=1.0)
Point(token_index=14, time_index=66, score=0.9999915361404419)
Point(token_index=15, time_index=67, score=0.9971502423286438)
Point(token_index=15, time_index=68, score=0.9999990463256836)
Point(token_index=15, time_index=69, score=0.9999992847442627)
Point(token_index=15, time_index=70, score=0.9999997615814209)
Point(token_index=15, time_index=71, score=0.9999998807907104)
Point(token_index=15, time_index=72, score=0.9999880790710449)
Point(token_index=15, time_index=73, score=0.011414232663810253)
Point(token_index=15, time_index=74, score=0.9999977350234985)
Point(token_index=16, time_index=75, score=0.9996126294136047)
Point(token_index=16, time_index=76, score=0.999998927116394)
Point(token_index=16, time_index=77, score=0.9729113578796387)
Point(token_index=16, time_index=78, score=0.999998927116394)
Point(token_index=17, time_index=79, score=0.9949334263801575)
Point(token_index=17, time_index=80, score=0.999998927116394)
Point(token_index=17, time_index=81, score=0.9999123811721802)
Point(token_index=17, time_index=82, score=0.9999774694442749)
Point(token_index=18, time_index=83, score=0.6568478941917419)
Point(token_index=18, time_index=84, score=0.9984306693077087)
Point(token_index=18, time_index=85, score=0.9999876022338867)
Point(token_index=19, time_index=86, score=0.9993751645088196)
Point(token_index=19, time_index=87, score=0.9999988079071045)
Point(token_index=19, time_index=88, score=0.10450363159179688)
Point(token_index=19, time_index=89, score=0.9999969005584717)
Point(token_index=20, time_index=90, score=0.39716723561286926)
Point(token_index=20, time_index=91, score=0.9999932050704956)
Point(token_index=21, time_index=92, score=1.6974917116385768e-06)
Point(token_index=21, time_index=93, score=0.9861233234405518)
Point(token_index=21, time_index=94, score=0.9999960660934448)
Point(token_index=22, time_index=95, score=0.999272882938385)
Point(token_index=22, time_index=96, score=0.9993417859077454)
Point(token_index=22, time_index=97, score=0.9999983310699463)
Point(token_index=23, time_index=98, score=0.9999971389770508)
Point(token_index=23, time_index=99, score=0.9999998807907104)
Point(token_index=23, time_index=100, score=0.9999995231628418)
Point(token_index=23, time_index=101, score=0.9999732971191406)
Point(token_index=24, time_index=102, score=0.9983206391334534)
Point(token_index=24, time_index=103, score=0.9999991655349731)
Point(token_index=24, time_index=104, score=0.9999996423721313)
Point(token_index=24, time_index=105, score=0.9999998807907104)
Point(token_index=24, time_index=106, score=1.0)
Point(token_index=24, time_index=107, score=0.9998623132705688)
Point(token_index=24, time_index=108, score=0.9999980926513672)
Point(token_index=25, time_index=109, score=0.9988549947738647)
Point(token_index=25, time_index=110, score=0.9999798536300659)
Point(token_index=26, time_index=111, score=0.8575040698051453)
Point(token_index=26, time_index=112, score=0.9999847412109375)
Point(token_index=27, time_index=113, score=0.987016499042511)
Point(token_index=27, time_index=114, score=1.8980852473760024e-05)
Point(token_index=27, time_index=115, score=0.9999794960021973)
Point(token_index=28, time_index=116, score=0.9998254179954529)
Point(token_index=28, time_index=117, score=0.9999990463256836)
Point(token_index=29, time_index=118, score=0.9999732971191406)
Point(token_index=29, time_index=119, score=0.0009207362891174853)
Point(token_index=29, time_index=120, score=0.9993653893470764)
Point(token_index=30, time_index=121, score=0.9975402355194092)
Point(token_index=30, time_index=122, score=0.00030417481320910156)
Point(token_index=30, time_index=123, score=0.9999344348907471)
Point(token_index=31, time_index=124, score=6.0895758906553965e-06)
Point(token_index=31, time_index=125, score=0.9833247661590576)
Point(token_index=32, time_index=126, score=0.9974588751792908)
Point(token_index=33, time_index=127, score=0.0008252769475802779)
Point(token_index=33, time_index=128, score=0.9965146780014038)
Point(token_index=34, time_index=129, score=0.017434753477573395)
Point(token_index=34, time_index=130, score=0.9989168643951416)
Point(token_index=35, time_index=131, score=0.9999697208404541)
Point(token_index=36, time_index=132, score=0.9999842643737793)
Point(token_index=36, time_index=133, score=0.9997639060020447)
Point(token_index=37, time_index=134, score=0.5112806558609009)
Point(token_index=37, time_index=135, score=0.9998302459716797)
Point(token_index=38, time_index=136, score=0.08521778136491776)
Point(token_index=38, time_index=137, score=0.004070617258548737)
Point(token_index=38, time_index=138, score=0.9999815225601196)
Point(token_index=39, time_index=139, score=0.012022347189486027)
Point(token_index=39, time_index=140, score=0.9999980926513672)
Point(token_index=39, time_index=141, score=0.0005828227149322629)
Point(token_index=39, time_index=142, score=0.999907374382019)
Point(token_index=40, time_index=143, score=0.9999960660934448)
Point(token_index=40, time_index=144, score=0.9999980926513672)
Point(token_index=40, time_index=145, score=0.9999916553497314)
Point(token_index=41, time_index=146, score=0.9971166849136353)
Point(token_index=41, time_index=147, score=0.998178243637085)
Point(token_index=41, time_index=148, score=0.9999310970306396)
Point(token_index=42, time_index=149, score=0.9879370331764221)
Point(token_index=42, time_index=150, score=0.9997633099555969)
Point(token_index=42, time_index=151, score=0.9999535083770752)
Point(token_index=43, time_index=152, score=0.9999715089797974)
Point(token_index=44, time_index=153, score=0.3181014060974121)
Point(token_index=44, time_index=154, score=0.9997820258140564)
Point(token_index=45, time_index=155, score=0.01603410765528679)
Point(token_index=45, time_index=156, score=0.999901294708252)
Point(token_index=46, time_index=157, score=0.46643632650375366)
Point(token_index=46, time_index=158, score=0.9999994039535522)
Point(token_index=46, time_index=159, score=0.9999996423721313)
Point(token_index=46, time_index=160, score=0.9999995231628418)
Point(token_index=46, time_index=161, score=0.9999996423721313)
Point(token_index=46, time_index=162, score=0.9999996423721313)
Point(token_index=46, time_index=163, score=0.9999996423721313)
Point(token_index=46, time_index=164, score=0.9999995231628418)
Point(token_index=46, time_index=165, score=0.9999995231628418)
Point(token_index=46, time_index=166, score=0.9999996423721313)
Point(token_index=46, time_index=167, score=0.9999996423721313)
Point(token_index=46, time_index=168, score=0.9999995231628418)

시각화

def plot_trellis_with_path(trellis, path):
    # 경로와 함께 trellis를 그리기 위해, 'nan' 값을 이용합니다.
    trellis_with_path = trellis.clone()
    for _, p in enumerate(path):
        trellis_with_path[p.time_index, p.token_index] = float("nan")
    plt.imshow(trellis_with_path.T, origin="lower")
    plt.title("The path found by backtracking")
    plt.tight_layout()


plot_trellis_with_path(trellis, path)
The path found by backtracking

좋습니다.

경로 분할

지금 이 경로는 같은 라벨의 반복이 포함되어 있기 때문에 이를 병합하여 원본 정답 스크립트와 가깝게 만들어봅시다.

다수의 경로 지점들을 병합할 때 단순하게, 병합된 분할의 평균 확률을 취합니다.

# 라벨을 병합함
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start


def merge_repeats(path):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments


segments = merge_repeats(path)
for seg in segments:
    print(seg)
|       (1.00): [    0,    31)
I       (0.78): [   31,    35)
|       (0.80): [   35,    37)
H       (1.00): [   37,    39)
A       (0.96): [   39,    41)
D       (0.65): [   41,    44)
|       (1.00): [   44,    45)
T       (0.55): [   45,    47)
H       (1.00): [   47,    49)
A       (0.03): [   49,    52)
T       (1.00): [   52,    53)
|       (1.00): [   53,    56)
C       (0.97): [   56,    61)
U       (1.00): [   61,    63)
R       (0.75): [   63,    67)
I       (0.88): [   67,    75)
O       (0.99): [   75,    79)
S       (1.00): [   79,    83)
I       (0.89): [   83,    86)
T       (0.78): [   86,    90)
Y       (0.70): [   90,    92)
|       (0.66): [   92,    95)
B       (1.00): [   95,    98)
E       (1.00): [   98,   102)
S       (1.00): [  102,   109)
I       (1.00): [  109,   111)
D       (0.93): [  111,   113)
E       (0.66): [  113,   116)
|       (1.00): [  116,   118)
M       (0.67): [  118,   121)
E       (0.67): [  121,   124)
|       (0.49): [  124,   126)
A       (1.00): [  126,   127)
T       (0.50): [  127,   129)
|       (0.51): [  129,   131)
T       (1.00): [  131,   132)
H       (1.00): [  132,   134)
I       (0.76): [  134,   136)
S       (0.36): [  136,   139)
|       (0.50): [  139,   143)
M       (1.00): [  143,   146)
O       (1.00): [  146,   149)
M       (1.00): [  149,   152)
E       (1.00): [  152,   153)
N       (0.66): [  153,   155)
T       (0.51): [  155,   157)
|       (0.96): [  157,   169)

시각화

def plot_trellis_with_segments(trellis, segments, transcript):
    # 경로와 함께 trellis를 그리기 위해, 'nan' 값을 이용합니다.
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
    ax1.set_title("Path, label and probability for each label")
    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")

    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    ax2.set_title("Label probability with and without repetation")
    xs, hs, ws = [], [], []
    for seg in segments:
        if seg.label != "|":
            xs.append((seg.end + seg.start) / 2 + 0.4)
            hs.append(seg.score)
            ws.append(seg.end - seg.start)
            ax2.annotate(seg.label, (seg.start + 0.8, -0.07))
    ax2.bar(xs, hs, width=ws, color="gray", alpha=0.5, edgecolor="black")

    xs, hs = [], []
    for p in path:
        label = transcript[p.token_index]
        if label != "|":
            xs.append(p.time_index + 1)
            hs.append(p.score)

    ax2.bar(xs, hs, width=0.5, alpha=0.5)
    ax2.axhline(0, color="black")
    ax2.grid(True, axis="y")
    ax2.set_ylim(-0.1, 1.1)
    fig.tight_layout()


plot_trellis_with_segments(trellis, segments, transcript)
Path, label and probability for each label, Label probability with and without repetation

좋습니다.

여러 분할을 단어로 병합

지금 단어로 병합해봅시다. wav2vec2 모델은 '|' 을 단어 경계로 사용합니다. 그래서 '|' 이 등장할 때마다 분할을 병합합니다.

그러고 나서 최종적으로 원본 오디오를 분할된 오디오로 분할하고 이를 들어 분할이 옳게 되었는지 확인합니다.

# 단어 병합
def merge_words(segments, separator="|"):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words


word_segments = merge_words(segments)
for word in word_segments:
    print(word)
I       (0.78): [   31,    35)
HAD     (0.84): [   37,    44)
THAT    (0.52): [   45,    53)
CURIOSITY       (0.89): [   56,    92)
BESIDE  (0.94): [   95,   116)
ME      (0.67): [  118,   124)
AT      (0.66): [  126,   129)
THIS    (0.70): [  131,   139)
MOMENT  (0.88): [  143,   157)

시각화

def plot_alignments(trellis, segments, word_segments, waveform, sample_rate=bundle.sample_rate):
    trellis_with_path = trellis.clone()
    for i, seg in enumerate(segments):
        if seg.label != "|":
            trellis_with_path[seg.start : seg.end, i] = float("nan")

    fig, [ax1, ax2] = plt.subplots(2, 1)

    ax1.imshow(trellis_with_path.T, origin="lower", aspect="auto")
    ax1.set_facecolor("lightgray")
    ax1.set_xticks([])
    ax1.set_yticks([])

    for word in word_segments:
        ax1.axvspan(word.start - 0.5, word.end - 0.5, edgecolor="white", facecolor="none")

    for i, seg in enumerate(segments):
        if seg.label != "|":
            ax1.annotate(seg.label, (seg.start, i - 0.7), size="small")
            ax1.annotate(f"{seg.score:.2f}", (seg.start, i + 3), size="small")

    # 원본 waveform
    ratio = waveform.size(0) / sample_rate / trellis.size(0)
    ax2.specgram(waveform, Fs=sample_rate)
    for word in word_segments:
        x0 = ratio * word.start
        x1 = ratio * word.end
        ax2.axvspan(x0, x1, facecolor="none", edgecolor="white", hatch="/")
        ax2.annotate(f"{word.score:.2f}", (x0, sample_rate * 0.51), annotation_clip=False)

    for seg in segments:
        if seg.label != "|":
            ax2.annotate(seg.label, (seg.start * ratio, sample_rate * 0.55), annotation_clip=False)
    ax2.set_xlabel("time [second]")
    ax2.set_yticks([])
    fig.tight_layout()


plot_alignments(
    trellis,
    segments,
    word_segments,
    waveform[0],
)
forced alignment with torchaudio tutorial

오디오 샘플

def display_segment(i):
    ratio = waveform.size(1) / trellis.size(0)
    word = word_segments[i]
    x0 = int(ratio * word.start)
    x1 = int(ratio * word.end)
    print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec")
    segment = waveform[:, x0:x1]
    return IPython.display.Audio(segment.numpy(), rate=bundle.sample_rate)
# 각 분할에 해당하는 오디오 생성
print(transcript)
IPython.display.Audio(SPEECH_FILE)
|I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT|


display_segment(0)
I (0.78): 0.624 - 0.704 sec


display_segment(1)
HAD (0.84): 0.744 - 0.885 sec


display_segment(2)
THAT (0.52): 0.905 - 1.066 sec


display_segment(3)
CURIOSITY (0.89): 1.127 - 1.851 sec


display_segment(4)
BESIDE (0.94): 1.911 - 2.334 sec


display_segment(5)
ME (0.67): 2.374 - 2.495 sec


display_segment(6)
AT (0.66): 2.535 - 2.595 sec


display_segment(7)
THIS (0.70): 2.635 - 2.796 sec


display_segment(8)
MOMENT (0.88): 2.877 - 3.159 sec


결론

이번 튜토리얼에서, torchaudio의 wav2vec2 모델을 사용하여 강제 정렬을 위한 CTC 분할을 수행하는 방법을 살펴보았습니다.

Total running time of the script: ( 0 minutes 9.820 seconds)

Gallery generated by Sphinx-Gallery


더 궁금하시거나 개선할 내용이 있으신가요? 커뮤니티에 참여해보세요!


이 튜토리얼이 어떠셨나요? 평가해주시면 이후 개선에 참고하겠습니다! :)

© Copyright 2018-2024, PyTorch & 파이토치 한국 사용자 모임(PyTorch Korea User Group).

Built with Sphinx using a theme provided by Read the Docs.

PyTorchKorea @ GitHub

파이토치 한국 사용자 모임을 GitHub에서 만나보세요.

GitHub로 이동

한국어 튜토리얼

한국어로 번역 중인 PyTorch 튜토리얼입니다.

튜토리얼로 이동

커뮤니티

다른 사용자들과 의견을 나누고, 도와주세요!

커뮤니티로 이동