
(Prototype) MaskedTensor Advanced Semantics

Before working on this tutorial, please make sure to review our MaskedTensor Overview tutorial (https://tutorials.pytorch.kr/prototype/maskedtensor_overview.html).

The purpose of this tutorial is to help users understand how some of the advanced semantics work and how they came to be. We will focus on two particular ones:

1. Differences between MaskedTensor and NumPy's MaskedArray
2. Reduction semantics

Preparation

import torch
from torch.masked import masked_tensor
import numpy as np
import warnings

# Disable prototype warnings and such
warnings.filterwarnings(action='ignore', category=UserWarning)

MaskedTensor vs NumPy’s MaskedArray

NumPy's MaskedArray differs from MaskedTensor in a few fundamental semantic ways.

1. Their factory functions and basic definitions invert the mask (similar to torch.nn.MultiheadAttention). MaskedTensor uses True to denote "specified" and False to denote "unspecified" (or "valid"/"invalid"), whereas NumPy does the opposite. We believe our mask definition is more intuitive and also aligns better with the existing semantics in PyTorch as a whole.

2. Intersection semantics. In NumPy, if either of two elements is masked out, the resulting element is masked out as well. In practice, NumPy applies the logical_or operator to the masks.

data = torch.arange(5.)
mask = torch.tensor([True, True, False, True, False])
npm0 = np.ma.masked_array(data.numpy(), (~mask).numpy())
npm1 = np.ma.masked_array(data.numpy(), (mask).numpy())

print("npm0:\n", npm0)
print("npm1:\n", npm1)
print("npm0 + npm1:\n", npm0 + npm1)
npm0:
 [0.0 1.0 -- 3.0 --]
npm1:
 [-- -- 2.0 -- 4.0]
npm0 + npm1:
 [-- -- -- -- --]
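The npm0 above is built by inverting a MaskedTensor-style mask with ~mask. The same inversion converts masks in either direction. The sketch below uses only NumPy. The helper names to_valid_mask and from_valid_mask are hypothetical and are not part of either library.

```python
import numpy as np

def to_valid_mask(arr):
    # Hypothetical helper: return a MaskedTensor-style mask (True = specified)
    # for a NumPy masked array, whose own mask uses True = masked out.
    # getmaskarray always returns a full boolean array, even when nothing is masked.
    return ~np.ma.getmaskarray(arr)

def from_valid_mask(data, valid):
    # Hypothetical helper: build a NumPy masked array from a
    # MaskedTensor-style mask by inverting it
    return np.ma.masked_array(data, mask=~valid)

data = np.arange(5.0)
valid = np.array([True, True, False, True, False])  # MaskedTensor convention

npm = from_valid_mask(data, valid)
print(npm)                 # elements 2 and 4 are masked out
print(to_valid_mask(npm))  # round-trips back to `valid`
```
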

MaskedTensor, by contrast, does not support addition or other binary operators on inputs whose masks don't match. To understand why, see the Reduction Semantics section.

mt0 = masked_tensor(data, mask)
mt1 = masked_tensor(data, ~mask)
print("mt0:\n", mt0)
print("mt1:\n", mt1)

try:
    mt0 + mt1
except ValueError as e:
    print("mt0 + mt1 failed. Error: ", e)
mt0:
 MaskedTensor(
  [  0.0000,   1.0000,       --,   3.0000,       --]
)
mt1:
 MaskedTensor(
  [      --,       --,   2.0000,       --,   4.0000]
)
mt0 + mt1 failed. Error:  Input masks must match. If you need support for this, please open an issue on Github.
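Mismatched masks raise an error, so it can help to check them before calling a binary operator. The following is a minimal sketch. It assumes a hypothetical helper named masks_match, which is not part of torch.masked and is built only on get_mask() and torch.equal.

```python
import torch
from torch.masked import masked_tensor

def masks_match(a, b):
    # Hypothetical helper: True when both MaskedTensors specify exactly the same elements
    return torch.equal(a.get_mask(), b.get_mask())

data = torch.arange(5.)
mask = torch.tensor([True, True, False, True, False])
mt0 = masked_tensor(data, mask)
mt1 = masked_tensor(data, ~mask)

print(masks_match(mt0, mt0))  # True: identical masks
print(masks_match(mt0, mt1))  # False: mt0 + mt1 would raise
```
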

However, if NumPy-style behavior is desired, MaskedTensor still makes it possible. It gives access to the data and masks, and to_tensor() conveniently converts a MaskedTensor to a regular Tensor with the masked-out values filled in. For example:

t0 = mt0.to_tensor(0)
t1 = mt1.to_tensor(0)
mt2 = masked_tensor(t0 + t1, mt0.get_mask() & mt1.get_mask())

print("t0:\n", t0)
print("t1:\n", t1)
print("mt2 (t0 + t1):\n", mt2)
t0:
 tensor([0., 1., 0., 3., 0.])
t1:
 tensor([0., 0., 2., 0., 4.])
mt2 (t0 + t1):
 MaskedTensor(
  [      --,       --,       --,       --,       --]
)

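Note that the masks are combined with & (logical_and) rather than NumPy's logical_or. Because MaskedTensor's mask is the inverse of NumPy's, an element of the result is specified only if it is specified in both inputs.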
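By De Morgan's law, the two conventions describe the same result. NumPy's logical_or of "masked out" masks equals the negation of the logical_and of MaskedTensor-style "specified" masks. Here is a quick NumPy-only check of that equivalence, reusing this section's mask.

```python
import numpy as np

valid0 = np.array([True, True, False, True, False])  # MaskedTensor convention
valid1 = ~valid0

npm0 = np.ma.masked_array(np.arange(5.0), mask=~valid0)  # NumPy convention
npm1 = np.ma.masked_array(np.arange(5.0), mask=~valid1)

# NumPy's result mask is the logical_or of the input masks (True = masked out)
numpy_result_mask = np.ma.getmaskarray(npm0 + npm1)

# MaskedTensor-style: an element is specified only where both inputs are specified
valid_result = valid0 & valid1

# De Morgan: ~a | ~b == ~(a & b)
print((numpy_result_mask == ~valid_result).all())
```
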

Reduction Semantics

Recall that in the MaskedTensor Overview tutorial we discussed "Implementing missing torch.nan* ops". Those are examples of reductions: operators that aggregate values along one (or more) dimensions of a Tensor and remove those dimensions from the result. In this section, we will use reduction semantics to motivate the strict requirement above that masks must match.

Fundamentally, MaskedTensors perform the same reduction operations as regular Tensors, but they ignore the masked-out (unspecified) values. For example:

data = torch.arange(12, dtype=torch.float).reshape(3, 4)
mask = torch.randint(2, (3, 4), dtype=torch.bool)  # random mask, so your output may differ from the one below
mt = masked_tensor(data, mask)

print("data:\n", data)
print("mask:\n", mask)
print("mt:\n", mt)
data:
 tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])
mask:
 tensor([[ True,  True,  True,  True],
        [ True,  True,  True, False],
        [ True,  True,  True,  True]])
mt:
 MaskedTensor(
  [
    [  0.0000,   1.0000,   2.0000,   3.0000],
    [  4.0000,   5.0000,   6.0000,       --],
    [  8.0000,   9.0000,  10.0000,  11.0000]
  ]
)

Now, the different reductions (all on dim=1):

print("torch.sum:\n", torch.sum(mt, 1))
print("torch.mean:\n", torch.mean(mt, 1))
print("torch.prod:\n", torch.prod(mt, 1))
print("torch.amin:\n", torch.amin(mt, 1))
print("torch.amax:\n", torch.amax(mt, 1))
torch.sum:
 MaskedTensor(
  [  6.0000,  15.0000,  38.0000]
)
torch.mean:
 MaskedTensor(
  [  1.5000,   5.0000,   9.5000]
)
torch.prod:
 MaskedTensor(
  [  0.0000, 120.0000, 7920.0000]
)
torch.amin:
 MaskedTensor(
  [  0.0000,   4.0000,   8.0000]
)
torch.amax:
 MaskedTensor(
  [  3.0000,   6.0000,  11.0000]
)

Note that the value underlying a masked-out element is not guaranteed to be anything specific, especially if the corresponding row or column is entirely masked out (the same is true for normalizations). For more details on masked semantics, see this RFC.
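One way to picture these semantics is to fill the unspecified elements with each reduction's identity element and then reduce as usual. The identity is 0 for sum, 1 for prod, +inf for amin, and -inf for amax. For mean, divide by the number of specified elements. The sketch below uses only regular Tensors and hard-codes the mask printed above so that it is reproducible. It illustrates the semantics under that assumption and is not how MaskedTensor is implemented internally. It also shows why a fully masked row has no meaningful value: its count of specified elements is zero.

```python
import torch

data = torch.arange(12, dtype=torch.float).reshape(3, 4)
# The mask printed above (True = specified), hard-coded for reproducibility
mask = torch.ones(3, 4, dtype=torch.bool)
mask[1, 3] = False

def fill(value):
    # Replace unspecified elements with the given identity value
    return torch.where(mask, data, torch.full_like(data, value))

count = mask.sum(1)                 # number of specified elements per row
masked_sum = fill(0.0).sum(1)
masked_mean = masked_sum / count    # a fully masked row would give 0 / 0 = nan
masked_prod = fill(1.0).prod(1)
masked_amin = fill(float("inf")).amin(1)
masked_amax = fill(float("-inf")).amax(1)

# Same values as the torch.sum/mean/prod/amin/amax outputs above
print(masked_sum, masked_mean, masked_prod, masked_amin, masked_amax, sep="\n")
```
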

Now, we can revisit the question: why do we enforce the invariant that masks must match for binary operators? In other words, why don’t we use the same semantics as np.ma.masked_array? Consider the following example:

data0 = torch.arange(10.).reshape(2, 5)
data1 = torch.arange(10.).reshape(2, 5) + 10
mask0 = torch.tensor([[True, True, False, False, False], [False, False, False, True, True]])
mask1 = torch.tensor([[False, False, False, True, True], [True, True, False, False, False]])
npm0 = np.ma.masked_array(data0.numpy(), (mask0).numpy())
npm1 = np.ma.masked_array(data1.numpy(), (mask1).numpy())

print("npm0:", npm0)
print("npm1:", npm1)
npm0: [[-- -- 2.0 3.0 4.0]
 [5.0 6.0 7.0 -- --]]
npm1: [[10.0 11.0 12.0 -- --]
 [-- -- 17.0 18.0 19.0]]

Now, let’s try addition:

print("(npm0 + npm1).sum(0):\n", (npm0 + npm1).sum(0))
print("npm0.sum(0) + npm1.sum(0):\n", npm0.sum(0) + npm1.sum(0))
(npm0 + npm1).sum(0):
 [-- -- 38.0 -- --]
npm0.sum(0) + npm1.sum(0):
 [15.0 17.0 38.0 21.0 23.0]

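Addition is associative (and commutative), so adding first and then summing should give the same result as summing first and then adding. With NumPy's semantics, however, the two orders give different results, which can be confusing for the user.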
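To make the comparison concrete, the NumPy-only sketch below checks whether "add, then sum" equals "sum, then add" under two treatments of the masked elements. The first is NumPy's default masked-array semantics. The second explicitly fills masked elements with 0, which is the approach MaskedTensor asks users to opt into.

```python
import numpy as np

data0 = np.arange(10.0).reshape(2, 5)
data1 = np.arange(10.0).reshape(2, 5) + 10
# NumPy convention: True = masked out (same masks as above)
mask0 = np.array([[True, True, False, False, False], [False, False, False, True, True]])
mask1 = np.array([[False, False, False, True, True], [True, True, False, False, False]])
npm0 = np.ma.masked_array(data0, mask0)
npm1 = np.ma.masked_array(data1, mask1)

# NumPy semantics: the two orders disagree
add_then_sum = (npm0 + npm1).sum(0)         # masked everywhere except column 2
sum_then_add = npm0.sum(0) + npm1.sum(0)    # 15, 17, 38, 21, 23

# Explicit zero fill: plain arrays, so both orders agree
f0, f1 = npm0.filled(0), npm1.filled(0)
print(np.array_equal((f0 + f1).sum(0), f0.sum(0) + f1.sum(0)))
```
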

MaskedTensor, on the other hand, simply does not allow adding two MaskedTensors with these masks, since mask0 != mask1. That said, the user can work around this if they wish, for example by filling the MaskedTensor's unspecified elements with 0 using to_tensor(), as shown below. In that case, the user must be explicit about their intentions.

mt0 = masked_tensor(data0, ~mask0)
mt1 = masked_tensor(data1, ~mask1)

(mt0.to_tensor(0) + mt1.to_tensor(0)).sum(0)
tensor([15., 17., 38., 21., 23.])
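If the zero-filled result should itself remain a MaskedTensor, one option is to combine the masks explicitly. The helper below, add_fill_zero, is hypothetical and not part of torch.masked. It treats unspecified elements as 0 and marks a result element as specified when either input specifies it, which is the opposite choice from NumPy's masking. It relies only on masked_tensor, to_tensor() and get_mask(), which appear earlier in this tutorial.

```python
import torch
from torch.masked import masked_tensor

def add_fill_zero(a, b):
    # Hypothetical helper: treat unspecified elements as 0 and keep an element
    # specified if it is specified in at least one input (union of the masks)
    data = a.to_tensor(0) + b.to_tensor(0)
    mask = a.get_mask() | b.get_mask()
    return masked_tensor(data, mask)

data0 = torch.arange(10.).reshape(2, 5)
data1 = torch.arange(10.).reshape(2, 5) + 10
# These equal ~mask0 and ~mask1 from the example above (True = specified)
valid0 = torch.tensor([[False, False, True, True, True], [True, True, True, False, False]])
valid1 = torch.tensor([[True, True, True, False, False], [False, False, True, True, True]])

mt_sum = add_fill_zero(masked_tensor(data0, valid0), masked_tensor(data1, valid1))
print(mt_sum.to_tensor(0).sum(0))  # tensor([15., 17., 38., 21., 23.])
```

The union mask is a deliberate, visible choice here. A caller who prefers NumPy's behavior can swap | for & in one place.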

Conclusion

In this tutorial, we have learned about the design decisions that distinguish MaskedTensor from NumPy's MaskedArray, as well as about reduction semantics. In general, MaskedTensor is designed to avoid ambiguous and confusing semantics (for example, we try to preserve the associative property among binary operations). This can at times require users to be more intentional with their code, but we believe it is the better trade-off. If you have any thoughts on this, please let us know!

© Copyright 2018-2023, PyTorch & PyTorch Korea User Group.