Эта функция называется транслированием (broadcasting). batch_t формы (2, 3, 5, 5) умножается на unsqueezed_weights формы (3, 1, 1), в результате чего получается тензор формы (2, 3, 5, 5), в котором затем можно сложить третье измерение с конца (три канала):
# In[5]:
unsqueezed_weights = weights.unsqueeze(-1).unsqueeze_(-1)
img_weights = (img_t * unsqueezed_weights)
batch_weights = (batch_t * unsqueezed_weights)
img_gray_weighted = img_weights.sum(-3)
batch_gray_weighted = batch_weights.sum(-3)
batch_weights.shape, batch_t.shape, unsqueezed_weights.shape
# Out[5]:
(torch.Size([2, 3, 5, 5]), torch.Size([2, 3, 5, 5]), torch.Size([3, 1, 1]))