Quick Intro to Torch Broadcasting
Introduction
Broadcasting is a fundamental feature in PyTorch that enables element-wise operations between tensors of different shapes. When performing these operations, PyTorch automatically expands the dimensions of the smaller tensor to match the larger one, without creating additional memory copies. This makes broadcasting both memory-efficient and convenient, as it eliminates the need for manual tensor reshaping.
Rules of Broadcasting
Two tensors are “broadcastable” if the following rules are met (a quick sketch follows the list):
- Each tensor has at least one dimension.
- When comparing dimension sizes from right to left (starting at the trailing dimension):
  - the dimension sizes are equal, OR
  - one of them is 1, OR
  - one of the tensors does not have the dimension.
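These rules can be checked directly on shapes. As a quick sketch (my illustration, not from the original post), `torch.broadcast_shapes` applies the same semantics without allocating any tensors:

import torch

# (3,) vs (3, 1): trailing 3 vs 1 -> 3; missing dim vs 3 -> 3
print(torch.broadcast_shapes((3,), (3, 1)))       # torch.Size([3, 3])
# (2, 1, 3) vs (3, 1): 3 vs 1 -> 3; 1 vs 3 -> 3; 2 vs missing -> 2
print(torch.broadcast_shapes((2, 1, 3), (3, 1)))  # torch.Size([2, 3, 3])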
Practice Problems
Training neural networks with PyTorch requires understanding tensor operations. A solid grasp of broadcasting semantics provides greater flexibility when manipulating tensors.
Let’s practice with broadcasting problems of varying difficulty, from easy to hard.
Terminology
Before we begin, let’s review the fundamental elements in PyTorch (the sketch below prints their shapes):
- Scalar: a single numerical value, e.g. `1`, `2`, `3`
- Vector: a 1-dimensional tensor, e.g. `[1, 2, 3]`
- Matrix: a 2-dimensional tensor, e.g. `[[1, 2, 3], [4, 5, 6]]`
- Tensor: a multidimensional array (3 or more dimensions), e.g. `[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]`
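To make the distinction concrete, here is a small sketch (my illustration, not from the original post) printing the number of dimensions and shape of each kind of value:

import torch

scalar = torch.tensor(1)                       # 0 dimensions: torch.Size([])
vector = torch.tensor([1, 2, 3])               # 1 dimension:  torch.Size([3])
matrix = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 2 dimensions: torch.Size([2, 3])
tensor3d = torch.ones(2, 2, 3)                 # 3 dimensions: torch.Size([2, 2, 3])
for t in (scalar, vector, matrix, tensor3d):
    print(t.dim(), tuple(t.shape))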
Easy
Addition of Scalar and Vector
1 + torch.tensor([1, 2, 3]) # (1, ) + (3, )
Solution
"""
torch.tensor([1, 1, 1]) + torch.tensor([1, 2, 3])
"""
tensor([2, 3, 4])
Explanation
- The scalar `1` is treated as a tensor of shape `(1,)` and broadcast to match the vector `[1, 2, 3]` of shape `(3,)`, resulting in element-wise addition (see the sketch below).
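The intermediate step can be materialized with `torch.broadcast_tensors`, which returns views expanded to the common shape; a minimal sketch (not part of the original solution):

import torch

a, b = torch.broadcast_tensors(torch.tensor([1]), torch.tensor([1, 2, 3]))
print(a)      # tensor([1, 1, 1])
print(b)      # tensor([1, 2, 3])
print(a + b)  # tensor([2, 3, 4])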
Addition between Matrices with Different Dimensions
x = torch.tensor([[1, 2, 3]]) # (1, 3)
y = torch.tensor([[1], [2], [3]]) # (3, 1)
x + y # (1, 3) + (3, 1)
Solution
"""
1. torch.tensor([[1, 2, 3]]) + torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) ; (1, 3) + (3, 3)
2. torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) + torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) ; (3, 3) + (3, 3)
"""
tensor([[2, 3, 4],
        [3, 4, 5],
        [4, 5, 6]])
Explanation
- Broadcast the trailing dimension of `y` from `1` to `3`, so `y` becomes `(3, 3)` and matches the trailing dimension of `x`.
- Now that `y` has shape `(3, 3)`, broadcast `x` from `(1, 3)` to `(3, 3)` (the expansion is made explicit in the sketch below).
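The same expansion can be written by hand with `expand`, which returns broadcast views without copying data; an illustrative sketch (my addition):

import torch

x = torch.tensor([[1, 2, 3]])      # (1, 3)
y = torch.tensor([[1], [2], [3]])  # (3, 1)
# expand both operands to the common shape (3, 3) explicitly
print(x.expand(3, 3) + y.expand(3, 3))  # same result as x + y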
Multiplication of Vector and Matrix
x = torch.tensor([1, 2, 3]) # (3, )
y = torch.tensor([[1], [2], [3]]) # (3, 1)
x * y # (3, ) * (3, 1)
Solution
"""
1. torch.tensor([[1, 2, 3]]) * torch.tensor([[1], [2], [3]]) ; (1, 3) * (3, 1)
2. torch.tensor([[1, 2, 3]]) * torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) ; (1, 3) * (3, 3)
3. torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) * torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) ; (3, 3) * (3, 3)
"""
tensor([[1, 2, 3],
        [2, 4, 6],
        [3, 6, 9]])
Explanation
- The broadcasting logic is the same as for addition, except the elements are multiplied.
Subtraction of Vector and Matrix
x = torch.tensor([1, 2, 3]) # (3, )
y = torch.tensor([[1, 1, 1], [2, 2, 2]]) # (2, 3)
x - y # (3, ) - (2, 3)
Solution
"""
1. torch.tensor([[1, 2, 3]]) - torch.tensor([[1, 1, 1], [2, 2, 2]]) ; (1, 3) - (2, 3)
2. torch.tensor([[1, 2, 3], [1, 2, 3]]) - torch.tensor([[1, 1, 1], [2, 2, 2]]) ; (2, 3) - (2, 3)
"""
tensor([[ 0,  1,  2],
        [-1,  0,  1]])
Explanation
- Prepend `1` to `x` to give it shape `(1, 3)` (shown explicitly with `unsqueeze` below).
- The trailing dimensions of `x` and `y` (both `3`) match.
- Broadcast `x` to match the shape `(2, 3)`.
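The prepended dimension can be made explicit with `unsqueeze`; a small sketch (assuming the same `x` and `y` as above):

import torch

x = torch.tensor([1, 2, 3])               # (3,)
y = torch.tensor([[1, 1, 1], [2, 2, 2]])  # (2, 3)
print(x.unsqueeze(0).shape)                    # torch.Size([1, 3])
print(torch.equal(x.unsqueeze(0) - y, x - y))  # True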
Multiplication of Scalar and Matrix
x = torch.tensor(3) # 0-dimensional: its shape is torch.Size([])
y = torch.tensor([[1, 2], [3, 4]]) # (2, 2)
x * y # ([]) * (2, 2)
Solution
"""
1. torch.tensor([3]) * torch.tensor([[1, 2], [3, 4]]) ; (1, ) * (2, 2)
2. torch.tensor([[3]]) * torch.tensor([[1, 2], [3, 4]]) ; (1, 1) * (2, 2)
3. torch.tensor([[3, 3]]) * torch.tensor([[1, 2], [3, 4]]) ; (1, 2) * (2, 2)
4. torch.tensor([[3, 3], [3, 3]]) * torch.tensor([[1, 2], [3, 4]]) ; (2, 2) * (2, 2)
"""
tensor([[ 3,  6],
        [ 9, 12]])
Explanation
- Scalar `x` has no dimensions, so it is first treated as a tensor of shape `(1,)`.
- Prepend `1` to `x` to match the number of dimensions of the matrix, giving `(1, 1)`.
- `x` of shape `(1, 1)` is broadcast to `(1, 2)` to match the trailing dimension of `y`.
- The remaining dimension of `x` is broadcast to `2` to match the dimension of `y`, giving the final shape `(2, 2)`.
Medium
Addition of Vector and Matrix
x = torch.tensor([1, 2, 3]) # (3, )
y = torch.tensor([[1], [2], [3]]) # (3, 1)
x + y # (3, ) + (3, 1)
Solution
"""
1. torch.tensor([[1, 2, 3]]) + torch.tensor([[1], [2], [3]]) ; (1, 3) + (3, 1)
2. torch.tensor([[1, 2, 3]]) + torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) ; (1, 3) + (3, 3)
3. torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) + torch.tensor([[1, 1, 1], [2, 2, 2], [3, 3, 3]]) ; (3, 3) + (3, 3)
"""
tensor([[2, 3, 4],
        [3, 4, 5],
        [4, 5, 6]])
Explanation
- Prepend `1` to `x` (the tensor with fewer dimensions) so `x` and `y` have the same number of dimensions.
- Starting at the trailing dimension, `y` of shape `(3, 1)` is broadcast to match the shape of `x` and becomes `(3, 3)`.
- `x` of shape `(1, 3)` is broadcast to match the shape of `y` and becomes `(3, 3)`.
Addition of Vector and 3D Tensor
x = torch.tensor([1, 2, 3]) # (3, )
y = torch.ones(2, 3, 3) # (2, 3, 3)
x + y # (3, ) + (2, 3, 3)
Solution
"""
1. torch.tensor([[[1, 2, 3]]]) + torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]) ; (1, 1, 3) + (2, 3, 3)
2. torch.tensor([[[1, 2, 3], [1, 2, 3], [1, 2, 3]]]) + torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]) ; (1, 3, 3) + (2, 3, 3)
3. torch.tensor([[[1, 2, 3], [1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3], [1, 2, 3]]]) + torch.tensor([[[1, 1, 1], [1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]) ; (2, 3, 3) + (2, 3, 3)
"""
tensor([[[2., 3., 4.],
         [2., 3., 4.],
         [2., 3., 4.]],

        [[2., 3., 4.],
         [2., 3., 4.],
         [2., 3., 4.]]])
Explanation
- Prepend `1` twice to `x` (the tensor with fewer dimensions) so `x` and `y` have the same number of dimensions; `x` becomes `(1, 1, 3)`.
- The trailing dimensions of `x` and `y` (both `3`) match.
- `x` of shape `(1, 1, 3)` is broadcast to `(1, 3, 3)` to match the second-to-last dimension of `y`.
- `x` is finally broadcast to `(2, 3, 3)`, matching `y`’s shape (verified with `expand` in the sketch below).
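These steps can be made explicit with `view` and `expand`; a verification sketch (my addition, not part of the original solution):

import torch

x = torch.tensor([1, 2, 3])
y = torch.ones(2, 3, 3)
x3 = x.view(1, 1, 3).expand(2, 3, 3)  # broadcast views, no data copied
print(torch.equal(x3 + y, x + y))     # True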
Addition of Two Matrices
x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # (2, 3)
y = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) # (3, 4)
x + y # (2, 3) + (3, 4)
Solution
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-8-cd60f97aa77f> in <cell line: 1>()
----> 1 x + y
RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 1
Explanation
- Addition of `x` and `y` fails because it violates the broadcasting semantics.
- One of the following must hold, starting at the trailing dimensions (i.e., `x`: `3` and `y`: `4`):
  - the dimension sizes are equal,
  - one of them is `1`, or
  - one of the tensors does not have the dimension.
- None of these holds for `3` vs. `4`, hence the `RuntimeError` (compatibility can also be pre-checked on the shapes alone, as sketched below).
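A sketch of such a pre-check using `torch.broadcast_shapes` (my addition, not the original post’s code):

import torch

try:
    torch.broadcast_shapes((2, 3), (3, 4))
except RuntimeError as e:
    print(e)  # the trailing dimensions 3 and 4 are incompatible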
Custom Dimension Manipulation
# Try to manipulate the dimension of `x` to make both compatible for addition
# HINT: `torch.Tensor.view`
x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # (2, 3)
y = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]]) # (3, 6)
x + y # (2, 3) + (3, 6)
Solution
"""
1. x.view(-1): torch.tensor([1, 2, 3, 4, 5, 6])
2. torch.tensor([1, 2, 3, 4, 5, 6]) + torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]]) ; (6, ) + (3, 6)
3. torch.tensor([[1, 2, 3, 4, 5, 6]]) + torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]]) ; (1, 6) + (3, 6)
4. torch.tensor([[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6]]) + torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]]) ; (3, 6) + (3, 6)
"""
x = x.view(-1) # (6, )
x + y # (6, ) + (3, 6)
tensor([[ 2,  4,  6,  8, 10, 12],
        [ 8, 10, 12, 14, 16, 18],
        [14, 16, 18, 20, 22, 24]])
Explanation
- `x.view(-1)` transforms the shape of `x` to `(6,)`.
- Broadcasting prepends `1` to `x`, making its shape `(1, 6)`.
- The trailing dimensions of both `x` and `y` (i.e., `6`) match.
- `x`’s leading dimension `1` is broadcast to match `y`’s dimension `3` (see the note on `view` vs. `reshape` below).
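A side note on the hint: `reshape` would also work here. `view` requires the underlying memory to be contiguous, while `reshape` falls back to copying when needed; a quick sketch (my addition):

import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18]])
print(torch.equal(x.view(-1) + y, x.reshape(-1) + y))  # True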
Multiplication of Complex Broadcasting
# Since every element is `1`, try to predict the shape of this operation.
x = torch.ones(2, 1, 3) # (2, 1, 3)
y = torch.ones(1, 3, 1) # (1, 3, 1)
x * y # (2, 1, 3) * (1, 3, 1)
Solution
# Shape: (2, 3, 3)
tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])
Explanation
- Starting with shapes `(2, 1, 3)` and `(1, 3, 1)`, let’s analyze dimension by dimension:
  - First dimension: `2` vs. `1` → `y` is broadcast from `1` to `2`
  - Middle dimension: `1` vs. `3` → `x` is broadcast from `1` to `3`
  - Last dimension: `3` vs. `1` → `y` is broadcast from `1` to `3`
- Final shapes before multiplication:
  - `x`: `(2, 3, 3)` (the original `1` expanded to `3` in the middle)
  - `y`: `(2, 3, 3)` (expanded from `1` to `2` in the first dimension, and from `1` to `3` in the last)
- Since all elements are `1`, the multiplication results in a `(2, 3, 3)` tensor of ones (the result shape can also be computed directly, as sketched below).
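The result shape can be computed straight from the operand shapes; a one-line sketch (my addition):

import torch

print(torch.broadcast_shapes((2, 1, 3), (1, 3, 1)))  # torch.Size([2, 3, 3])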
Hard
Mean Across Specific Axes
# Compute the mean of `x` across the last axis and add it back to the original tensor
g = torch.Generator().manual_seed(617) # guarantees consistent random generation
x = torch.rand(3, 4, 5, generator=g)
Solution
x + x.mean(dim=2, keepdim=True)
# <--- OUTPUT --->
tensor([[[0.5952, 0.6486, 1.1932, 0.4631, 1.1062],
         [0.7727, 0.9813, 0.4742, 0.6173, 1.0093],
         [0.5897, 0.6357, 0.8125, 0.8716, 0.4222],
         [1.1211, 0.4338, 0.6372, 0.4942, 0.7857]],

        [[0.6572, 1.3236, 1.3613, 0.4812, 0.5334],
         [1.5122, 1.1478, 1.1528, 1.4942, 0.7012],
         [1.1803, 1.4021, 1.4673, 1.1307, 0.9080],
         [1.6144, 0.6951, 1.1600, 1.2778, 1.5087]],

        [[1.3744, 0.9883, 0.6701, 1.1222, 1.1225],
         [0.9468, 0.8783, 0.4040, 0.6899, 1.0623],
         [1.2763, 1.0732, 0.7736, 1.5029, 0.5423],
         [1.0240, 1.2803, 1.1804, 1.5658, 0.9504]]])
# Verify
first_row = x[0][0] # torch.tensor([0.1946, 0.2479, 0.7926, 0.0624, 0.7056]) ; (5, )
first_row_mean = x[0][0].mean() # torch.tensor(0.4006) ; ([])
result = first_row + first_row_mean
matches = result == (x + x.mean(dim=2, keepdim=True))[0][0]
print(result)
print(matches)
# <--- OUTPUT --->
tensor([0.5952, 0.6486, 1.1932, 0.4631, 1.1062])
tensor([True, True, True, True, True])
Explanation
- The original tensor `x` has shape `(3, 4, 5)`.
- `x.mean(dim=2, keepdim=True)` does the following:
  - Computes the mean along dimension 2 (the last axis)
  - `keepdim=True` preserves the dimension, resulting in shape `(3, 4, 1)`
  - When adding this back to `x`, broadcasting expands the mean from `(3, 4, 1)` to `(3, 4, 5)`
- Each element in the result is the sum of:
  - The original value from `x`
  - The mean of its corresponding row (broadcast across the last dimension)
- The verification code confirms this by:
  - Taking the first row `x[0][0]` (shape `(5,)`)
  - Adding its mean (a scalar) to each element
  - Comparing with the corresponding row of the full calculation
Broadcasting with Max Reduction
# Find the maximum value of each row and subtract this value from the original tensor.
g = torch.Generator().manual_seed(617)
x = torch.rand(4, 3, 2, generator=g)
Solution
max_rows = x.max(dim=2, keepdim=True).values # (4, 3, 1)
x - max_rows # (4, 3, 2) - (4, 3, 1)
# <--- OUTPUT --->
tensor([[[-0.0533,  0.0000],
         [ 0.0000, -0.7302],
         [ 0.0000, -0.3184]],

        [[ 0.0000, -0.5071],
         [-0.3920,  0.0000],
         [-0.0460,  0.0000]],

        [[-0.0591,  0.0000],
         [-0.6848,  0.0000],
         [-0.2034,  0.0000]],

        [[-0.2914,  0.0000],
         [-0.6664,  0.0000],
         [ 0.0000, -0.8801]]])
Explanation
- The original tensor `x` has shape `(4, 3, 2)`.
- `x.max(dim=2, keepdim=True)` does the following:
  - Takes the maximum value along dimension 2 (the last axis) for each row
  - `keepdim=True` maintains the dimension, resulting in shape `(4, 3, 1)`
  - `.values` is used because `max()` returns both values and indices; we only want the values
- When subtracting `max_rows` from `x`, broadcasting expands `(4, 3, 1)` to `(4, 3, 2)`.
- The result shows each value’s difference from its row’s maximum (an alternative with `amax` is sketched below):
  - Values equal to their row’s maximum become `0`
  - All other values are negative (being less than their row’s maximum)
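Since only the values are needed, `torch.amax` (which skips computing the indices) gives the same result; an alternative sketch (my addition, not the original solution):

import torch

g = torch.Generator().manual_seed(617)
x = torch.rand(4, 3, 2, generator=g)
result = x - x.amax(dim=2, keepdim=True)  # amax returns values only
print(torch.equal(result, x - x.max(dim=2, keepdim=True).values))  # True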
Combining Argmax with Tensor Indexing
"""
Compute argmax of `x` along dim=2, then replace the maximum values along that dimension with their square while keeping other values unchanged.
"""
g = torch.Generator().manual_seed(617)
x = torch.rand(5, 4, 6, generator=g)
Solution
indices = x.argmax(dim=2) # (5, 4)
mask = torch.zeros_like(x, dtype=torch.bool)
batch_indices = torch.arange(x.size(0)).unsqueeze(1) # (5, 1)
row_indices = torch.arange(x.size(1)).unsqueeze(0) # (1, 4)
mask[batch_indices, row_indices, indices] = True
x = torch.where(mask, x ** 2, x)
# <--- OUTPUT --->
tensor([[[1.9459e-01, 2.4792e-01, 6.2822e-01, 6.2441e-02, 7.0561e-01, 3.8723e-01],
         [5.9580e-01, 8.8707e-02, 2.3185e-01, 3.8913e-01, 2.5649e-01, 3.0251e-01],
         [4.7937e-01, 5.3843e-01, 8.9037e-02, 5.9889e-01, 8.6632e-02, 2.9001e-01],
         [1.4702e-01, 4.3846e-01, 2.2150e-01, 8.8794e-01, 8.5682e-01, 4.5569e-02]],

        [[9.7745e-02, 8.3054e-01, 5.4700e-01, 5.5202e-01, 8.9336e-01, 1.0040e-01],
         [5.7145e-01, 7.9327e-01, 8.5845e-01, 5.2186e-01, 2.9914e-01, 9.7775e-01],
         [6.9493e-02, 5.3443e-01, 6.5223e-01, 7.7981e-01, 8.4665e-01, 4.6055e-01],
         [1.4233e-01, 5.9448e-01, 3.5373e-01, 5.4865e-01, 4.8015e-01, 5.8755e-03]],

        [[2.9175e-01, 6.6421e-01, 7.5952e-01, 5.5636e-01, 2.5674e-01, 9.7236e-01],
         [2.5453e-02, 4.2386e-01, 6.8020e-01, 5.8032e-01, 9.3263e-01, 3.5035e-01],
         [1.4200e-01, 1.2870e-01, 3.0852e-01, 1.0691e-01, 2.1062e-01, 1.7581e-02],
         [5.7376e-01, 1.3413e-01, 5.2298e-01, 9.0579e-02, 6.6662e-01, 9.5560e-01]],

        [[3.7499e-01, 7.5693e-01, 7.9677e-01, 5.3703e-01, 6.5874e-01, 6.6973e-01],
         [5.4903e-01, 7.6660e-01, 2.5739e-01, 6.2306e-01, 3.9811e-01, 3.2552e-01],
         [9.6114e-02, 6.8212e-01, 9.9059e-01, 8.4072e-01, 5.6158e-01, 4.1336e-01],
         [2.9504e-01, 1.7929e-01, 3.7160e-01, 1.4645e-01, 1.4089e-01, 3.6650e-01]],

        [[7.9860e-01, 1.6316e-01, 7.0585e-01, 7.4827e-01, 7.2970e-01, 3.8324e-01],
         [3.9089e-01, 5.9064e-01, 8.0657e-01, 1.7587e-01, 3.3804e-01, 1.9550e-01],
         [7.8799e-01, 3.3231e-01, 7.0377e-02, 8.0018e-01, 5.0456e-01, 9.6001e-01],
         [5.4405e-01, 8.8870e-01, 3.7259e-04, 8.8793e-01, 6.1661e-01, 8.4178e-01]]])
Explanation
- First, find the indices of the maximum values along dimension `2`:
  - `x.argmax(dim=2)` returns a tensor of shape `(5, 4)` containing the positions of the max values
- Create a boolean mask of the same shape as `x`, i.e. `(5, 4, 6)`:
  - `torch.zeros_like(x, dtype=torch.bool)` initializes all values to `False`
- Create broadcasting-compatible indices to properly index the 3D tensor:
  - `batch_indices = torch.arange(5).unsqueeze(1)` creates `[[0], [1], [2], [3], [4]]`
  - `row_indices = torch.arange(4).unsqueeze(0)` creates `[[0, 1, 2, 3]]`
  - When used together, these broadcast to cover all positions where max values occur
- Set `True` in the mask at the positions of the maximum values: `mask[batch_indices, row_indices, indices] = True`
  - The indexing operation works through broadcasting:

    batch_indices: [[0],           # shape: (5, 1)
                    [1],
                    [2],
                    [3],
                    [4]]

    row_indices:   [[0, 1, 2, 3]]  # shape: (1, 4)

    indices:       [[2, 3, 3, 4],  # shape: (5, 4)
                    [1, 5, 3, 2],
                    [5, 4, 0, 5],
                    [4, 3, 2, 4],
                    [3, 2, 5, 1]]

  - When combined, these create index tuples for each maximum value:
    - First row: `(0,0,2), (0,1,3), (0,2,3), (0,3,4)`
    - Second row: `(1,0,1), (1,1,5), (1,2,3), (1,3,2)`
    - And so on.
  - Each tuple represents `(batch_idx, row_idx, max_value_position)`
- Finally, `torch.where()` selectively applies the squaring (an alternative way to build the mask is sketched after this list):
  - Where the mask is `True`: square the value (`x ** 2`)
  - Where the mask is `False`: keep the original value (`x`)
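For reference, the same mask can be built without explicit index grids by scattering `True` along the reduced dimension; a sketch under the same setup (my alternative, not the original post’s approach):

import torch

g = torch.Generator().manual_seed(617)
x = torch.rand(5, 4, 6, generator=g)

# scatter_ writes True at the argmax position of each (batch, row) pair
indices = x.argmax(dim=2, keepdim=True)  # (5, 4, 1)
mask = torch.zeros_like(x, dtype=torch.bool).scatter_(2, indices, True)
result = torch.where(mask, x ** 2, x)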
Comparing Reduction Results Across Axes
# Compute max along dim=2 and mean along dim=1 and add those results back to the original tensor
g = torch.Generator().manual_seed(617)
x = torch.rand(3, 4, 5, generator=g)
Solution
max_dim2 = x.max(dim=2, keepdim=True).values # (3, 4, 1)
mean_dim1 = x.mean(dim=1, keepdim=True) # (3, 1, 5)
result = x + max_dim2 + mean_dim1 # (3, 4, 5) + (3, 4, 1) + (3, 1, 5)
# <--- OUTPUT --->
tensor([[[1.3902, 1.3487, 1.9979, 1.1000, 1.9624],
         [1.4141, 1.5278, 1.1252, 1.1006, 1.7118],
         [1.1980, 1.1492, 1.4305, 1.3218, 1.0917],
         [1.9508, 1.1687, 1.4766, 1.1658, 1.6766]],

        [[1.8204, 2.3880, 2.5689, 1.4995, 1.3685],
         [2.4959, 2.0328, 2.1810, 2.3330, 1.3568],
         [2.1032, 2.2261, 2.4345, 1.9086, 1.5027],
         [2.6509, 1.6327, 2.2409, 2.1693, 2.2170]],

        [[2.3380, 1.8515, 1.2353, 2.1506, 1.8501],
         [1.8575, 1.6887, 0.9164, 1.6655, 1.7371],
         [2.3903, 2.0868, 1.4891, 2.6817, 1.4202],
         [2.0343, 2.1902, 1.7924, 2.6410, 1.7248]]])
Explanation
- The original tensor `x` has shape `(3, 4, 5)`.
- Two reduction operations are performed:
  - `max_dim2`: maximum along dimension 2 (the last axis)
    - Shape goes from `(3, 4, 5)` → `(3, 4, 1)` with `keepdim=True`
    - Each value represents the maximum of its corresponding row
  - `mean_dim1`: mean along dimension 1 (the middle axis)
    - Shape goes from `(3, 4, 5)` → `(3, 1, 5)` with `keepdim=True`
    - Each value represents the mean across all rows for that position
- The final addition `x + max_dim2 + mean_dim1` involves two broadcasting operations:
  - `max_dim2` `(3, 4, 1)` is broadcast to `(3, 4, 5)`
  - `mean_dim1` `(3, 1, 5)` is broadcast to `(3, 4, 5)`
Standardizing Tensor with Reduction
# Standardize x along the last axis (dim=2) by subtracting the mean and dividing by the standard deviation.
g = torch.Generator().manual_seed(617)
x = torch.rand(3, 5, 4, generator=g)
Solution
mean = x.mean(dim=2, keepdim=True) # (3, 5, 1)
std = x.std(dim=2, keepdim=True) # (3, 5, 1)
result = (x - mean) / std # ((3, 5, 4) - (3, 5, 1)) / (3, 5, 1)
# <--- OUTPUT --->
tensor([[[-4.0344e-01, -2.3768e-01,  1.4553e+00, -8.1418e-01],
         [ 9.6276e-01, -2.1043e-01,  5.5813e-01, -1.3105e+00],
         [-6.6763e-01,  1.4806e+00, -5.3258e-01, -2.8035e-01],
         [ 3.2340e-02,  2.4018e-01, -1.3413e+00,  1.0688e+00],
         [-9.7956e-01,  3.1494e-01, -5.9517e-01,  1.2598e+00]],

        [[-6.6012e-01,  8.1287e-01,  8.9620e-01, -1.0490e+00],
         [-1.2885e+00,  1.1535e+00,  5.9946e-02,  7.5018e-02],
         [ 8.6095e-01, -1.3867e+00, -5.1503e-02,  5.7724e-01],
         [ 6.0867e-01, -4.6181e-01, -1.1701e+00,  1.0233e+00],
         [-1.3592e+00, -1.0854e-03,  3.4300e-01,  1.0173e+00]],

        [[ 1.1443e+00, -1.7200e-01, -1.2569e+00,  2.8459e-01],
         [ 6.8955e-01,  5.1992e-01,  2.6785e-01, -1.4773e+00],
         [-1.3675e+00,  4.7652e-01,  9.4842e-01, -5.7411e-02],
         [-4.0622e-01,  1.3754e+00, -9.7121e-01,  2.0231e-03],
         [ 1.4138e-01, -2.5028e-01,  1.2610e+00, -1.1521e+00]]])
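Explanation
- `mean` and `std` both reduce along `dim=2` with `keepdim=True`, so each has shape `(3, 5, 1)`.
- In `(x - mean) / std`, both reductions are broadcast from `(3, 5, 1)` back to `(3, 5, 4)`.

As a sanity check (a sketch I added, assuming the `result` computed above), each standardized row should have mean close to 0 and standard deviation close to 1:

print(result.mean(dim=2).abs().max())  # close to 0
print(result.std(dim=2))               # all entries close to 1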
Conclusion
We have learned about broadcasting semantics and how to use them to perform various operations on tensors.
Remember the basic rules of broadcasting:
- Each tensor has at least one dimension.
- When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.
Broadcasting is a powerful feature that lets us perform operations on tensors of different shapes without explicitly reshaping them. It is a key concept in PyTorch and is used extensively throughout the library, and it will come in handy in the next post, where we build character-level bigram models.