Python 中的广播机制 (Broadcasting)

广播 (Broadcasting) 是 NumPy 和 PyTorch 等科学计算库中的一种机制,允许不同形状的数组进行算术运算,而无需显式复制数据。

Posted by nothin on October 27, 2025

Python 中的广播机制 (Broadcasting)

什么是广播

广播 (Broadcasting) 是 NumPy 和 PyTorch 等科学计算库中的一种机制,允许不同形状的数组进行算术运算,而无需显式复制数据。

核心思想

广播通过虚拟扩展较小数组的形状来匹配较大数组,在计算时重复使用数据,而不实际占用额外内存。

优势:

  • 内存高效:不创建数据副本
  • 代码简洁:避免手写循环
  • 计算快速:利用向量化操作

广播的三大规则

规则 1:维度对齐

如果两个数组维度数不同,在形状较小的数组前面补 1,直到维度数相同。

1
2
3
4
5
6
7
import numpy as np

a = np.array([1, 2, 3, 4])      # shape: (4,)
b = np.array([[10], [20], [30]]) # shape: (3, 1)

# a 自动变为 (1, 4)
# b 保持为 (3, 1)

规则 2:维度兼容性检查

从右向左逐个比较每个维度,满足以下条件之一即为兼容:

  • 两个维度相等
  • 其中一个维度为 1
1
2
3
4
5
6
7
#  兼容示例
(3, 1)  (1, 4)   可以广播
(5, 3, 4)  (3, 4)  可以广播
(8, 1, 6, 1)  (7, 1, 5)  可以广播

#  不兼容示例
(3, 4)  (3, 5)   无法广播 (最后一维 45)

规则 3:形状扩展

将维度为 1 的维度”拉伸”到匹配另一个数组的对应维度。

1
2
(3, 1)  (3, 4)  # 第2维从1扩展到4
(1, 4)  (3, 4)  # 第1维从1扩展到3

基础示例

示例 1:向量与标量

1
2
3
4
5
6
7
8
9
10
11
import numpy as np

# 标量广播到向量
a = np.array([1, 2, 3, 4])
b = 10

result = a + b
print(result)  # [11 12 13 14]

# 等价于:
# b 被广播为 [10, 10, 10, 10]

示例 2:一维数组与二维数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 向量广播到矩阵
matrix = np.array([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])

vector = np.array([10, 20, 30])

result = matrix + vector
print(result)
# [[11 22 33]
#  [14 25 36]
#  [17 28 39]]

# vector 被广播为:
# [[10, 20, 30],
#  [10, 20, 30],
#  [10, 20, 30]]

示例 3:外积运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# 列向量 + 行向量 = 矩阵
x = np.array([1, 2, 3, 4])      # shape: (4,)
y = np.array([10, 20, 30])      # shape: (3,)

# 增加维度
x_col = x[:, np.newaxis]  # shape: (4, 1)
y_row = y[np.newaxis, :]  # shape: (1, 3)

result = x_col + y_row
print(result)
# [[11 21 31]
#  [12 22 32]
#  [13 23 33]
#  [14 24 34]]

示例 4:复杂形状

1
2
3
4
5
6
7
# 三维张量广播
a = np.ones((3, 4, 5))      # shape: (3, 4, 5)
b = np.ones((4, 1))         # shape: (4, 1)

# b 自动变为 (1, 4, 1),然后广播到 (3, 4, 5)
result = a + b
print(result.shape)  # (3, 4, 5)

常见应用场景

1. 数据归一化

1
2
3
4
5
6
7
8
9
# 按列归一化
data = np.random.randn(100, 5)  # 100个样本,5个特征

# 计算每列的均值和标准差
mean = data.mean(axis=0)  # shape: (5,)
std = data.std(axis=0)    # shape: (5,)

# 广播标准化
normalized = (data - mean) / std  # mean 和 std 自动广播到 (100, 5)

2. 图像处理

1
2
3
4
5
6
# RGB 图像每个通道减去均值
image = np.random.randint(0, 255, (224, 224, 3))  # H×W×C
mean_rgb = np.array([123.675, 116.28, 103.53])    # shape: (3,)

# 广播减法
centered_image = image - mean_rgb  # mean_rgb 广播到 (224, 224, 3)

3. 距离矩阵计算

1
2
3
4
5
6
7
8
# 计算所有点对之间的欧式距离
points = np.random.randn(100, 2)  # 100个2D点

# 利用广播计算距离矩阵
diff = points[:, np.newaxis, :] - points[np.newaxis, :, :]  
# shape: (100, 100, 2)

distances = np.sqrt((diff ** 2).sum(axis=2))  # shape: (100, 100)

深度学习中的应用

0. Triton Kernel 中的向量外积

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np

def add_vec_kernel_numpy(x, y):
    z = y[:, None] + x[None, :]
    return z

# 示例
x = np.array([1, 2, 3, 4])       # B0 = 4
y = np.array([10, 20, 30])       # B1 = 3

z = add_vec_kernel_numpy(x, y)
print(z)
# [[11 12 13 14]
#  [21 22 23 24]
#  [31 32 33 34]]

print(f"原始形状: x{x.shape}, y{y.shape}")
y_col = y[:, None]
x_row = x[None, :]
print(f"增加维度: y[:,None]{y_col.shape}, x[None,:]{x_row.shape}")
'''
步骤1 - 原始形状: x(4,), y(3,)
步骤2 - 增加维度: y[:,None](3, 1), x[None,:](1, 4)
步骤3 - 广播扩展:
  y[:,None] 从 (3,1) 广播到 (3,4)
  x[None,:] 从 (1,4) 广播到 (3,4)
步骤4 - 结果形状: (3, 4)
'''

1. Batch Normalization

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# BatchNorm 中的广播操作
def batch_norm(x, gamma, beta, eps=1e-5):
    """
    x: (N, C, H, W) - 批量图像
    gamma, beta: (C,) - 可学习参数
    """
    # 计算每个通道的均值和方差
    mean = x.mean(dim=(0, 2, 3), keepdim=True)  # shape: (1, C, 1, 1)
    var = x.var(dim=(0, 2, 3), keepdim=True)    # shape: (1, C, 1, 1)
    
    # 标准化
    x_norm = (x - mean) / torch.sqrt(var + eps)
    
    # gamma 和 beta 广播
    gamma = gamma.view(1, -1, 1, 1)  # (1, C, 1, 1)
    beta = beta.view(1, -1, 1, 1)    # (1, C, 1, 1)
    
    return gamma * x_norm + beta

# 使用示例
x = torch.randn(32, 64, 28, 28)  # (N, C, H, W)
gamma = torch.ones(64)
beta = torch.zeros(64)

output = batch_norm(x, gamma, beta)
print(output.shape)  # torch.Size([32, 64, 28, 28])

性能对比

广播 vs 显式循环

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import time
import numpy as np
# 准备数据
matrix = np.random.randn(1000, 1000)
vector = np.random.randn(1000)
# 方法1:广播(推荐)
start = time.time()
result1 = matrix + vector
time_broadcast = time.time() - start
# 方法2:显式循环
start = time.time()
result2 = np.zeros_like(matrix)
for i in range(matrix.shape[0]):
    result2[i] = matrix[i] + vector
    
time_loop = time.time() - start

print(f"广播耗时: {time_broadcast:.6f}s")
print(f"循环耗时: {time_loop:.6f}s")
print(f"加速比: {time_loop / time_broadcast:.1f}x")

# 典型输出:
# 广播耗时: 0.000631s
# 循环耗时: 0.001083s
# 加速比: 1.7x

常见陷阱与调试技巧

陷阱 1:意外的广播

1
2
3
4
5
6
7
8
# 错误示例
a = np.array([[1, 2, 3]])     # shape: (1, 3)
b = np.array([[1], [2], [3]]) # shape: (3, 1)

c = a + b  # 意外得到 (3, 3) 的结果!
print(c.shape)  # (3, 3)

assert a.shape == b.shape, "形状不匹配"

调试技巧

1
2
3
4
5
6
7
8
9
10
# 1. 使用 .shape 检查
print(f"a.shape = {a.shape}, b.shape = {b.shape}")

# 2. 使用 np.broadcast_shapes 预测结果
from numpy import broadcast_shapes
result_shape = broadcast_shapes(a.shape, b.shape)
print(f"广播后形状: {result_shape}")

# 3. 使用 keepdim 保持维度
mean = a.mean(axis=1, keepdims=True)  # 保持维度便于广播