# Quick Intro to Torch Broadcasting

## Introduction
Broadcasting is a fundamental feature in PyTorch that enables element-wise operations between tensors of different shapes. When performing these operations, PyTorch automatically expands the dimensions of the smaller tensor to match the larger one, without creating additional memory copies. This makes broadcasting both memory-efficient and convenient, as it eliminates the need for manual tensor reshaping.
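To see the "no extra copies" claim concretely, here is a minimal sketch (shapes chosen for illustration) using `expand`, which produces the same kind of zero-copy view that broadcasting creates internally:

```python
import torch

x = torch.tensor([1, 2, 3])                  # shape (3,)
expanded = x.expand(4, 3)                    # shape (4, 3), still a view

print(expanded.stride())                     # (0, 1): the new dim has stride 0
print(x.data_ptr() == expanded.data_ptr())   # True: no new memory was allocated
```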
## Rules of Broadcasting

Two tensors are "broadcastable" if the following rules hold (a quick check of these rules in code appears below):

- Each tensor has at least one dimension.
- When comparing dimension sizes from right to left (starting at the trailing dimension), for each pair of dimensions:
  - the sizes are equal, OR
  - one of the sizes is 1, OR
  - one of the tensors does not have the dimension.
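A convenient way to apply these rules without creating any data is `torch.broadcast_shapes` (available in recent PyTorch versions); a small sketch:

```python
import torch

# Compatible: at each right-aligned position, sizes are equal, 1, or missing.
print(torch.broadcast_shapes((3,), (3, 1)))          # torch.Size([3, 3])
print(torch.broadcast_shapes((2, 1, 3), (1, 3, 1)))  # torch.Size([2, 3, 3])

# Incompatible: trailing sizes 3 vs 4 are unequal and neither is 1.
try:
    torch.broadcast_shapes((2, 3), (3, 4))
except RuntimeError as e:
    print(e)
```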
## Practice Problems

Training neural networks with PyTorch requires understanding tensor operations, and a solid grasp of broadcasting semantics gives you far more flexibility when manipulating tensors.

Let's work through broadcasting problems of increasing difficulty, from easy to hard.
### Terminology

Before we begin, let's review the fundamental elements in PyTorch (a quick shape check follows the list):

- **Scalar**: a single numerical value, e.g. `1`, `2`, `3`
- **Vector**: a 1-dimensional tensor, e.g. `[1, 2, 3]`
- **Matrix**: a 2-dimensional tensor, e.g. `[[1, 2, 3], [4, 5, 6]]`
- **Tensor**: a multidimensional array (3 or more dimensions), e.g. `[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]`
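You can verify each of these with `dim()` and `shape`; a quick sketch:

```python
import torch

scalar = torch.tensor(1)                        # 0 dimensions
vector = torch.tensor([1, 2, 3])                # 1 dimension
matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])   # 2 dimensions
tensor = torch.ones(2, 2, 3)                    # 3 dimensions

for t in (scalar, vector, matrix, tensor):
    print(t.dim(), tuple(t.shape))
# 0 ()
# 1 (3,)
# 2 (2, 3)
# 3 (2, 2, 3)
```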
## Easy

### Addition of Scalar and Vector

```python
1 + torch.tensor([1, 2, 3])  # (1,) + (3,)
```

**Solution**

```python
"""
torch.tensor([1, 1, 1]) + torch.tensor([1, 2, 3])
"""
tensor([2, 3, 4])
```
**Explanation**

- The scalar `1` (treated as shape `(1,)`) is broadcast to match the vector `[1, 2, 3]` of shape `(3,)`, resulting in element-wise addition.
### Addition between Matrices with Different Dimensions

```python
x = torch.tensor([[1, 2, 3]])      # (1, 3)
y = torch.tensor([[1], [2], [3]])  # (3, 1)
x + y                              # (1, 3) + (3, 1)
```
**Solution**

```python
"""
1. torch.tensor([[1, 2, 3]]) + torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])  ; (1, 3) + (3, 3)
2. torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) + torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])  ; (3, 3) + (3, 3)
"""
tensor([[2, 3, 4],
        [3, 4, 5],
        [4, 5, 6]])
```
**Explanation**

- Broadcast the trailing dimension of `y` to `(3, 3)` to match the trailing dimension of `x`.
- Now that `y` has shape `(3, 3)`, broadcast `x` to `(3, 3)` as well.
### Multiplication of Vector and Matrix

```python
x = torch.tensor([1, 2, 3])        # (3,)
y = torch.tensor([[1], [2], [3]])  # (3, 1)
x * y                              # (3,) * (3, 1)
```
**Solution**

```python
"""
1. torch.tensor([[1, 2, 3]]) * torch.tensor([[1], [2], [3]])  ; (1, 3) * (3, 1)
2. torch.tensor([[1, 2, 3]]) * torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])  ; (1, 3) * (3, 3)
3. torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) * torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])  ; (3, 3) * (3, 3)
"""
tensor([[1, 2, 3],
        [2, 4, 6],
        [3, 6, 9]])
```
**Explanation**

- The broadcasting logic is the same as for addition, except the elements are multiplied.
### Subtraction of Vector and Matrix

```python
x = torch.tensor([1, 2, 3])               # (3,)
y = torch.tensor([[1, 1, 1], [2, 2, 2]])  # (2, 3)
x - y                                     # (3,) - (2, 3)
```
**Solution**

```python
"""
1. torch.tensor([[1, 2, 3]]) - torch.tensor([[1, 1, 1], [2, 2, 2]])  ; (1, 3) - (2, 3)
2. torch.tensor([[1, 2, 3], [1, 2, 3]]) - torch.tensor([[1, 1, 1], [2, 2, 2]])  ; (2, 3) - (2, 3)
"""
tensor([[ 0,  1,  2],
        [-1,  0,  1]])
```
**Explanation**

- Prepend a `1` to `x` to make its shape `(1, 3)`.
- The trailing dimensions of `x` and `y` (both `3`) match.
- Broadcast `x` to the shape `(2, 3)`.
### Multiplication of Scalar and Matrix

```python
x = torch.tensor(3)                 # scalar: no dimensions, shape ([])
y = torch.tensor([[1, 2], [3, 4]])  # (2, 2)
x * y                               # ([]) * (2, 2)
```
**Solution**

```python
"""
1. torch.tensor([3]) * torch.tensor([[1, 2], [3, 4]])               ; (1,) * (2, 2)
2. torch.tensor([[3]]) * torch.tensor([[1, 2], [3, 4]])             ; (1, 1) * (2, 2)
3. torch.tensor([[3, 3]]) * torch.tensor([[1, 2], [3, 4]])          ; (1, 2) * (2, 2)
4. torch.tensor([[3, 3], [3, 3]]) * torch.tensor([[1, 2], [3, 4]])  ; (2, 2) * (2, 2)
"""
tensor([[ 3,  6],
        [ 9, 12]])
```
**Explanation**

- The scalar `x` has no dimensions, so it is first broadcast to shape `(1,)`.
- Prepend a `1` to `x` to match the number of dimensions of the matrix, giving `(1, 1)`.
- `x` of shape `(1, 1)` is then broadcast to `(1, 2)` to match the trailing dimension of `y`.
- The remaining dimension of `x` is broadcast to `2` to match the dimension of `y`, giving the final shape `(2, 2)`.
## Medium

### Addition of Vector and Matrix

```python
x = torch.tensor([1, 2, 3])        # (3,)
y = torch.tensor([[1], [2], [3]])  # (3, 1)
x + y                              # (3,) + (3, 1)
```
**Solution**

```python
"""
1. torch.tensor([[1, 2, 3]]) + torch.tensor([[1], [2], [3]])  ; (1, 3) + (3, 1)
2. torch.tensor([[1, 2, 3]]) + torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])  ; (1, 3) + (3, 3)
3. torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) + torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]])  ; (3, 3) + (3, 3)
"""
tensor([[2, 3, 4],
        [3, 4, 5],
        [4, 5, 6]])
```
**Explanation**

- Prepend a `1` to `x` (the tensor with fewer dimensions) so that `x` and `y` have the same number of dimensions.
- Starting at the trailing dimension, `y` of shape `(3, 1)` is broadcast to `(3, 3)` to match the shape of `x`.
- `x` of shape `(1, 3)` is broadcast to `(3, 3)` to match the shape of `y`.
### Addition of Vector and 3D Tensor

```python
x = torch.tensor([1, 2, 3])  # (3,)
y = torch.ones(2, 3, 3)      # (2, 3, 3)
x + y                        # (3,) + (2, 3, 3)
```
**Solution**

```python
"""
1. torch.tensor([[[1, 2, 3]]]) + torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])  ; (1, 1, 3) + (2, 3, 3)
2. torch.tensor([[[1, 2, 3], [1, 2, 3], [1, 2, 3]]]) + torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])  ; (1, 3, 3) + (2, 3, 3)
3. torch.tensor([[[1, 2, 3], [1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3], [1, 2, 3]]]) + torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])  ; (2, 3, 3) + (2, 3, 3)
"""
tensor([[[2., 3., 4.],
         [2., 3., 4.],
         [2., 3., 4.]],

        [[2., 3., 4.],
         [2., 3., 4.],
         [2., 3., 4.]]])
```
**Explanation**

- Prepend `1` twice to `x` (the tensor with fewer dimensions) so that `x` and `y` have the same number of dimensions (`x` becomes `(1, 1, 3)`).
- The trailing dimensions of `x` and `y` match.
- `x` of shape `(1, 1, 3)` is broadcast to `(1, 3, 3)` to match the second trailing dimension of `y`.
- `x` is finally broadcast to `(2, 3, 3)` and matches `y`'s shape. (An explicit `expand` equivalent is shown below.)
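The prepend-then-expand steps above can be replayed explicitly with `view` and `expand`; a sketch verifying the equivalence:

```python
# Explicitly replay the broadcasting steps: (3,) -> (1, 1, 3) -> (2, 3, 3).
xe = x.view(1, 1, 3).expand(2, 3, 3)
print(torch.equal(x + y, xe + y))  # True
```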
### Addition of Two Matrices

```python
x = torch.tensor([[1, 2, 3], [4, 5, 6]])                         # (2, 3)
y = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])  # (3, 4)
x + y                                                            # (2, 3) + (3, 4)
```
**Solution**

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-cd60f97aa77f> in <cell line: 1>()
----> 1 x + y

RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1
```
**Explanation**

- Adding `x` and `y` fails because it violates the broadcasting semantics.
- Starting at the trailing dimensions (`x`: `3`, `y`: `4`), one of the following must hold:
  - the dimension sizes are equal,
  - one of them is 1, or
  - one of them does not exist.
### Custom Dimension Manipulation

```python
# Try to manipulate the dimensions of `x` to make both tensors compatible for addition
# HINT: `torch.Tensor.view`
x = torch.tensor([[1, 2, 3], [4, 5, 6]])  # (2, 3)
y = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])  # (3, 6)
x + y  # (2, 3) + (3, 6)
```
**Solution**

```python
"""
1. x.view(-1): torch.tensor([1, 2, 3, 4, 5, 6])
2. torch.tensor([1, 2, 3, 4, 5, 6]) + torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])  ; (6,) + (3, 6)
3. torch.tensor([[1, 2, 3, 4, 5, 6]]) + torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])  ; (1, 6) + (3, 6)
4. torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]]) + torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])  ; (3, 6) + (3, 6)
"""
x = x.view(-1)  # (6,)
x + y           # (6,) + (3, 6)

tensor([[ 2,  4,  6,  8, 10, 12],
        [ 8, 10, 12, 14, 16, 18],
        [14, 16, 18, 20, 22, 24]])
```
**Explanation**

- `x.view(-1)` flattens `x` to shape `(6,)`.
- Broadcasting prepends a `1` to `x`, making its shape `(1, 6)`.
- The trailing dimensions of both `x` and `y` (i.e., `6`) match.
- `x`'s leading `1` is broadcast to match `y`'s dimension `3`.
### Multiplication with Complex Broadcasting

```python
# Since both tensors contain only ones, try to predict the shape of this operation.
x = torch.ones(2, 1, 3)  # (2, 1, 3)
y = torch.ones(1, 3, 1)  # (1, 3, 1)
x * y                    # (2, 1, 3) * (1, 3, 1)
```
**Solution**

```python
# Shape: (2, 3, 3)
tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])
```
**Explanation**

- Starting with shapes `(2, 1, 3)` and `(1, 3, 1)`, analyze dimension by dimension:
  - First dimension: `2` vs `1` → `y` is broadcast from `1` to `2`
  - Middle dimension: `1` vs `3` → `x` is broadcast from `1` to `3`
  - Last dimension: `3` vs `1` → `y` is broadcast from `1` to `3`
- Final shapes before multiplication:
  - `x`: `(2, 3, 3)` (original `1` expanded to `3` in the middle)
  - `y`: `(2, 3, 3)` (expanded from `1` to `2` in the first, and `1` to `3` in the last)
- Since all elements are `1`, the multiplication results in a `(2, 3, 3)` tensor of ones. (See the explicit `expand` sketch below.)
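As in the earlier problem, the same result can be reached by expanding both operands explicitly; a sketch:

```python
# Expand both operands to the common shape before multiplying.
xe = x.expand(2, 3, 3)  # (2, 1, 3) -> (2, 3, 3)
ye = y.expand(2, 3, 3)  # (1, 3, 1) -> (2, 3, 3)
print(torch.equal(x * y, xe * ye))  # True
```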
## Hard

### Mean Across Specific Axes

```python
# Compute the mean of `x` across the last axis, and add it back to the original tensor
g = torch.Generator().manual_seed(617)  # guarantees reproducible random generation
x = torch.rand(3, 4, 5, generator=g)
```
**Solution**

```python
x + x.mean(dim=2, keepdim=True)
# <--- OUTPUT --->
tensor([[[0.5952, 0.6486, 1.1932, 0.4631, 1.1062],
         [0.7727, 0.9813, 0.4742, 0.6173, 1.0093],
         [0.5897, 0.6357, 0.8125, 0.8716, 0.4222],
         [1.1211, 0.4338, 0.6372, 0.4942, 0.7857]],

        [[0.6572, 1.3236, 1.3613, 0.4812, 0.5334],
         [1.5122, 1.1478, 1.1528, 1.4942, 0.7012],
         [1.1803, 1.4021, 1.4673, 1.1307, 0.9080],
         [1.6144, 0.6951, 1.1600, 1.2778, 1.5087]],

        [[1.3744, 0.9883, 0.6701, 1.1222, 1.1225],
         [0.9468, 0.8783, 0.4040, 0.6899, 1.0623],
         [1.2763, 1.0732, 0.7736, 1.5029, 0.5423],
         [1.0240, 1.2803, 1.1804, 1.5658, 0.9504]]])
```
```python
# Verify
first_row = x[0][0]  # torch.tensor([0.1946, 0.2479, 0.7926, 0.0624, 0.7056]) ; (5,)
first_row_mean = x[0][0].mean()  # torch.tensor(0.4006) ; ([])
result = first_row + first_row_mean
matches = result == (x + x.mean(dim=2, keepdim=True))[0][0]
print(result)
print(matches)
# <--- OUTPUT --->
tensor([0.5952, 0.6486, 1.1932, 0.4631, 1.1062])
tensor([True, True, True, True, True])
```
**Explanation**

- The original tensor `x` has shape `(3, 4, 5)`.
- `x.mean(dim=2, keepdim=True)` does the following:
  - computes the mean along dimension 2 (the last axis);
  - `keepdim=True` preserves the dimension, resulting in shape `(3, 4, 1)`.
- When adding this back to `x`, broadcasting expands the mean from `(3, 4, 1)` to `(3, 4, 5)`.
- Each element in the result is the sum of:
  - the original value from `x`, and
  - the mean of its corresponding row (broadcast across the last dimension).
- The verification code confirms this by:
  - taking the first row `x[0][0]` (shape `(5,)`),
  - adding its mean (a scalar) to each element, and
  - comparing with the corresponding row of the full calculation.
### Broadcasting with Max Reduction

```python
# Find the maximum value of each row and subtract it from the original tensor.
g = torch.Generator().manual_seed(617)
x = torch.rand(4, 3, 2, generator=g)
```
**Solution**

```python
max_rows = x.max(dim=2, keepdim=True).values  # (4, 3, 1)
x - max_rows  # (4, 3, 2) - (4, 3, 1)
# <--- OUTPUT --->
tensor([[[-0.0533,  0.0000],
         [ 0.0000, -0.7302],
         [ 0.0000, -0.3184]],

        [[ 0.0000, -0.5071],
         [-0.3920,  0.0000],
         [-0.0460,  0.0000]],

        [[-0.0591,  0.0000],
         [-0.6848,  0.0000],
         [-0.2034,  0.0000]],

        [[-0.2914,  0.0000],
         [-0.6664,  0.0000],
         [ 0.0000, -0.8801]]])
```
**Explanation**

- The original tensor `x` has shape `(4, 3, 2)`.
- `x.max(dim=2, keepdim=True).values` does the following:
  - takes the maximum value along dimension 2 (the last axis) for each row;
  - `keepdim=True` maintains the dimension, resulting in shape `(4, 3, 1)`;
  - `.values` (equivalently `[0]`) is needed because `max()` returns both values and indices, and we only want the values (demonstrated below).
- When subtracting `max_rows` from `x`, broadcasting expands `(4, 3, 1)` to `(4, 3, 2)`.
- The result shows each value's difference from its row's maximum:
  - values equal to their row's maximum become 0;
  - all other values are negative (being less than their row's maximum).
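For reference, `max(dim=...)` returns a named tuple, so `.values`, `.indices`, and positional indexing are interchangeable; a toy sketch:

```python
import torch

out = torch.tensor([[1.0, 5.0], [7.0, 2.0]]).max(dim=1, keepdim=True)
print(out.values)                       # tensor([[5.], [7.]]) -- the maxima
print(out.indices)                      # tensor([[1], [0]]) -- their positions
print(torch.equal(out[0], out.values))  # True: [0] and .values agree
```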
### Combining Argmax with Tensor Indexing

```python
"""
Compute the argmax of `x` along dim=2, then replace the maximum values along
that dimension with their squares while keeping all other values unchanged.
"""
g = torch.Generator().manual_seed(617)
x = torch.rand(5, 4, 6, generator=g)
```
**Solution**

```python
indices = x.argmax(dim=2)  # (5, 4)
mask = torch.zeros_like(x, dtype=torch.bool)
batch_indices = torch.arange(x.size(0)).unsqueeze(1)  # (5, 1)
row_indices = torch.arange(x.size(1)).unsqueeze(0)    # (1, 4)
mask[batch_indices, row_indices, indices] = True
x = torch.where(mask, x ** 2, x)
# <--- OUTPUT --->
tensor([[[1.9459e-01, 2.4792e-01, 6.2822e-01, 6.2441e-02, 7.0561e-01,
          3.8723e-01],
         [5.9580e-01, 8.8707e-02, 2.3185e-01, 3.8913e-01, 2.5649e-01,
          3.0251e-01],
         [4.7937e-01, 5.3843e-01, 8.9037e-02, 5.9889e-01, 8.6632e-02,
          2.9001e-01],
         [1.4702e-01, 4.3846e-01, 2.2150e-01, 8.8794e-01, 8.5682e-01,
          4.5569e-02]],

        [[9.7745e-02, 8.3054e-01, 5.4700e-01, 5.5202e-01, 8.9336e-01,
          1.0040e-01],
         [5.7145e-01, 7.9327e-01, 8.5845e-01, 5.2186e-01, 2.9914e-01,
          9.7775e-01],
         [6.9493e-02, 5.3443e-01, 6.5223e-01, 7.7981e-01, 8.4665e-01,
          4.6055e-01],
         [1.4233e-01, 5.9448e-01, 3.5373e-01, 5.4865e-01, 4.8015e-01,
          5.8755e-03]],

        [[2.9175e-01, 6.6421e-01, 7.5952e-01, 5.5636e-01, 2.5674e-01,
          9.7236e-01],
         [2.5453e-02, 4.2386e-01, 6.8020e-01, 5.8032e-01, 9.3263e-01,
          3.5035e-01],
         [1.4200e-01, 1.2870e-01, 3.0852e-01, 1.0691e-01, 2.1062e-01,
          1.7581e-02],
         [5.7376e-01, 1.3413e-01, 5.2298e-01, 9.0579e-02, 6.6662e-01,
          9.5560e-01]],

        [[3.7499e-01, 7.5693e-01, 7.9677e-01, 5.3703e-01, 6.5874e-01,
          6.6973e-01],
         [5.4903e-01, 7.6660e-01, 2.5739e-01, 6.2306e-01, 3.9811e-01,
          3.2552e-01],
         [9.6114e-02, 6.8212e-01, 9.9059e-01, 8.4072e-01, 5.6158e-01,
          4.1336e-01],
         [2.9504e-01, 1.7929e-01, 3.7160e-01, 1.4645e-01, 1.4089e-01,
          3.6650e-01]],

        [[7.9860e-01, 1.6316e-01, 7.0585e-01, 7.4827e-01, 7.2970e-01,
          3.8324e-01],
         [3.9089e-01, 5.9064e-01, 8.0657e-01, 1.7587e-01, 3.3804e-01,
          1.9550e-01],
         [7.8799e-01, 3.3231e-01, 7.0377e-02, 8.0018e-01, 5.0456e-01,
          9.6001e-01],
         [5.4405e-01, 8.8870e-01, 3.7259e-04, 8.8793e-01, 6.1661e-01,
          8.4178e-01]]])
```
**Explanation**

- First, find the indices of the maximum values along dimension `2`:
  - `x.argmax(dim=2)` returns a tensor of shape `(5, 4)` containing the positions of the max values.
- Create a boolean mask of the same shape as `x` (`(5, 4, 6)`):
  - `torch.zeros_like(x, dtype=torch.bool)` initializes all values to `False`.
- Create broadcasting-compatible indices to properly index the 3D tensor:
  - `batch_indices = torch.arange(5).unsqueeze(1)` creates `[[0], [1], [2], [3], [4]]`;
  - `row_indices = torch.arange(4).unsqueeze(0)` creates `[[0, 1, 2, 3]]`;
  - when used together, these broadcast to cover all positions where max values occur.
- Set `True` in the mask at the positions of the maximum values: `mask[batch_indices, row_indices, indices] = True`.
- The indexing operation `mask[batch_indices, row_indices, indices]` works through broadcasting:

  ```
  batch_indices: [[0],            # Shape: (5, 1)
                  [1],
                  [2],
                  [3],
                  [4]]

  row_indices:   [[0, 1, 2, 3]]   # Shape: (1, 4)

  indices:       [[2, 3, 3, 4],   # Shape: (5, 4)
                  [1, 5, 3, 2],
                  [5, 4, 0, 5],
                  [4, 3, 2, 4],
                  [3, 2, 5, 1]]
  ```

- When combined, these create index tuples for each maximum value:
  - first row: `(0,0,2), (0,1,3), (0,2,3), (0,3,4)`;
  - second row: `(1,0,1), (1,1,5), (1,2,3), (1,3,2)`;
  - and so on.
- Each tuple represents `(batch_idx, row_idx, max_value_position)`.
- Finally, `torch.where()` selectively applies the squaring:
  - where the mask is `True`: square the value (`x ** 2`);
  - where the mask is `False`: keep the original value (`x`).

A smaller, hand-traceable version of this pattern is shown below.
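Here is the same mask-and-square recipe on a tiny hypothetical `(2, 3)` tensor, so the fancy indexing is easy to trace by hand:

```python
import torch

x_small = torch.tensor([[1.0, 3.0, 2.0],
                        [4.0, 0.0, 5.0]])
idx = x_small.argmax(dim=1)                         # tensor([1, 2])
mask = torch.zeros_like(x_small, dtype=torch.bool)
rows = torch.arange(x_small.size(0))                # tensor([0, 1])
mask[rows, idx] = True                              # True at (0, 1) and (1, 2)
print(torch.where(mask, x_small ** 2, x_small))
# tensor([[ 1.,  9.,  2.],
#         [ 4.,  0., 25.]])
```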
### Comparing Reduction Results Across Axes

```python
# Compute the max along dim=2 and the mean along dim=1, then add both results back to the original tensor
g = torch.Generator().manual_seed(617)
x = torch.rand(3, 4, 5, generator=g)
```
**Solution**

```python
max_dim2 = x.max(dim=2, keepdim=True).values  # (3, 4, 1)
mean_dim1 = x.mean(dim=1, keepdim=True)       # (3, 1, 5)
result = x + max_dim2 + mean_dim1             # (3, 4, 5) + (3, 4, 1) + (3, 1, 5)
# <--- OUTPUT --->
tensor([[[1.3902, 1.3487, 1.9979, 1.1000, 1.9624],
         [1.4141, 1.5278, 1.1252, 1.1006, 1.7118],
         [1.1980, 1.1492, 1.4305, 1.3218, 1.0917],
         [1.9508, 1.1687, 1.4766, 1.1658, 1.6766]],

        [[1.8204, 2.3880, 2.5689, 1.4995, 1.3685],
         [2.4959, 2.0328, 2.1810, 2.3330, 1.3568],
         [2.1032, 2.2261, 2.4345, 1.9086, 1.5027],
         [2.6509, 1.6327, 2.2409, 2.1693, 2.2170]],

        [[2.3380, 1.8515, 1.2353, 2.1506, 1.8501],
         [1.8575, 1.6887, 0.9164, 1.6655, 1.7371],
         [2.3903, 2.0868, 1.4891, 2.6817, 1.4202],
         [2.0343, 2.1902, 1.7924, 2.6410, 1.7248]]])
```
**Explanation**

- The original tensor `x` has shape `(3, 4, 5)`.
- Two reduction operations are performed:
  - `max_dim2`: maximum along dimension 2 (the last axis);
    - shape goes from `(3, 4, 5)` → `(3, 4, 1)` with `keepdim=True`;
    - each value is the maximum of its corresponding row.
  - `mean_dim1`: mean along dimension 1 (the middle axis);
    - shape goes from `(3, 4, 5)` → `(3, 1, 5)` with `keepdim=True`;
    - each value is the mean across all rows for that position.
- The final addition `x + max_dim2 + mean_dim1` involves two broadcasting operations:
  - `max_dim2` `(3, 4, 1)` is broadcast to `(3, 4, 5)`;
  - `mean_dim1` `(3, 1, 5)` is broadcast to `(3, 4, 5)`.
### Standardizing a Tensor with Reduction

```python
# Standardize x along the last axis (dim=2) by subtracting the mean and dividing by the standard deviation.
g = torch.Generator().manual_seed(617)
x = torch.rand(3, 5, 4, generator=g)
```
**Solution**

```python
mean = x.mean(dim=2, keepdim=True)  # (3, 5, 1)
std = x.std(dim=2, keepdim=True)    # (3, 5, 1)
result = (x - mean) / std           # ((3, 5, 4) - (3, 5, 1)) / (3, 5, 1)
# <--- OUTPUT --->
tensor([[[-4.0344e-01, -2.3768e-01,  1.4553e+00, -8.1418e-01],
         [ 9.6276e-01, -2.1043e-01,  5.5813e-01, -1.3105e+00],
         [-6.6763e-01,  1.4806e+00, -5.3258e-01, -2.8035e-01],
         [ 3.2340e-02,  2.4018e-01, -1.3413e+00,  1.0688e+00],
         [-9.7956e-01,  3.1494e-01, -5.9517e-01,  1.2598e+00]],

        [[-6.6012e-01,  8.1287e-01,  8.9620e-01, -1.0490e+00],
         [-1.2885e+00,  1.1535e+00,  5.9946e-02,  7.5018e-02],
         [ 8.6095e-01, -1.3867e+00, -5.1503e-02,  5.7724e-01],
         [ 6.0867e-01, -4.6181e-01, -1.1701e+00,  1.0233e+00],
         [-1.3592e+00, -1.0854e-03,  3.4300e-01,  1.0173e+00]],

        [[ 1.1443e+00, -1.7200e-01, -1.2569e+00,  2.8459e-01],
         [ 6.8955e-01,  5.1992e-01,  2.6785e-01, -1.4773e+00],
         [-1.3675e+00,  4.7652e-01,  9.4842e-01, -5.7411e-02],
         [-4.0622e-01,  1.3754e+00, -9.7121e-01,  2.0231e-03],
         [ 1.4138e-01, -2.5028e-01,  1.2610e+00, -1.1521e+00]]])
```
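As a quick sanity check (a sketch reusing `result` from above): after standardizing along `dim=2`, each length-4 row should have mean close to 0 and sample standard deviation close to 1.

```python
# Standardized rows have mean ~0 and (sample) std ~1 along dim=2.
print(result.mean(dim=2).abs().max())  # near 0 (up to floating-point error)
print(result.std(dim=2))               # entries near 1
```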
## Conclusion

We have covered broadcasting semantics and how to use them to perform various operations on tensors.

Remember the basic rules of broadcasting:

- Each tensor has at least one dimension.
- When iterating over the dimension sizes, starting at the trailing dimension, the sizes must either be equal, one of them must be 1, or one of them must not exist.

Broadcasting is a powerful feature that lets us operate on tensors of different shapes without explicitly reshaping them. It is a key concept in PyTorch and is used extensively throughout the library, and it will serve us well in the next post, where we build character-level bigram models.