PyTorch 的函數(shù)式神經(jīng)網(wǎng)絡(luò)接口 torch.nn.functional,主要集中在 torch.nn.functional 模塊中,通常簡寫為F。它提供了大量可以直接作用于張量的神經(jīng)網(wǎng)絡(luò)函數(shù),例如線性變換、激活函數(shù)、卷積、池化、Dropout、歸一化、Softmax 和損失函數(shù)等。
簡單地說,torch.nn.functional 模塊回答的是:如果不把某個(gè)操作封裝成一個(gè)網(wǎng)絡(luò)層對象,而是直接把它當(dāng)作函數(shù)來調(diào)用,應(yīng)當(dāng)如何完成神經(jīng)網(wǎng)絡(luò)中的常見計(jì)算。
如果說 torch.nn 中的 nn.Linear、nn.ReLU、nn.Conv2d 等模塊更強(qiáng)調(diào)“層”的組織與參數(shù)管理,那么 torch.nn.functional 更強(qiáng)調(diào)“操作”的直接調(diào)用。理解函數(shù)式接口,有助于更清楚地認(rèn)識(shí) PyTorch 中網(wǎng)絡(luò)層背后的實(shí)際計(jì)算過程。
一、認(rèn)識(shí) torch.nn.functional 模塊
torch.nn.functional 是 PyTorch 中用于神經(jīng)網(wǎng)絡(luò)計(jì)算的函數(shù)式接口模塊。它通常這樣導(dǎo)入:
import torch.nn.functional as F![]()
圖 1:函數(shù)式接口在 PyTorch 神經(jīng)網(wǎng)絡(luò)中的位置
在實(shí)際代碼中,F(xiàn) 經(jīng)常出現(xiàn)在模型的 forward() 方法中。例如:
這里的 F.relu(x) 并不是一個(gè)網(wǎng)絡(luò)層對象,而是一個(gè)直接作用于張量的函數(shù)。它接收輸入張量 x,返回經(jīng)過 ReLU 處理后的新張量。
這類接口的特點(diǎn)是:
? 不自動(dòng)保存可學(xué)習(xí)參數(shù)
? 不自動(dòng)注冊為模型子模塊
? 通常直接在 forward() 中調(diào)用
? 更接近底層張量計(jì)算
? 適合無參數(shù)操作或需要靈活控制的操作
因此,函數(shù)式接口不是 torch.nn 模塊的替代品,而是對模塊式接口的重要補(bǔ)充。
二、函數(shù)式接口與模塊式接口的區(qū)別
PyTorch 中很多操作都有兩種寫法:一種是模塊式接口,一種是函數(shù)式接口。
![]()
圖 2:模塊式接口與函數(shù)式接口的區(qū)別
例如 ReLU 可以寫成模塊式:
也可以寫成函數(shù)式:
這兩種寫法在計(jì)算結(jié)果上通常相同,但使用語義不同。
模塊式接口強(qiáng)調(diào)“這是模型中的一個(gè)層”:
self.relu = nn.ReLU()函數(shù)式接口強(qiáng)調(diào)“這里執(zhí)行一次操作”:
x = F.relu(x)對于沒有可學(xué)習(xí)參數(shù)、也不需要保存狀態(tài)的操作,函數(shù)式接口通常很方便。例如:
? F.relu
? F.gelu
? F.softmax
? F.max_pool2d
? F.cross_entropy
對于有可學(xué)習(xí)參數(shù)或內(nèi)部狀態(tài)的結(jié)構(gòu),通常更推薦使用模塊式接口。例如:
? nn.Linear
? nn.Conv2d
? nn.BatchNorm2d
? nn.Embedding
原因是這些層需要保存權(quán)重、偏置、運(yùn)行統(tǒng)計(jì)量或其他狀態(tài)。模塊式接口可以自動(dòng)注冊參數(shù),并讓 model.parameters()、state_dict()、train()、eval() 等機(jī)制正常工作。
三、函數(shù)式接口的基本使用方式
函數(shù)式接口通常直接接收輸入張量,并返回輸出張量。
例如:
輸出結(jié)果為:
tensor([0., 0., 0., 1., 2.])這里沒有創(chuàng)建 nn.ReLU() 對象,而是直接調(diào)用 F.relu(x)。
再例如:
這里的 dim=1 表示對每個(gè)樣本的類別維度做 Softmax,使每一行的數(shù)值轉(zhuǎn)換為概率分布。
函數(shù)式接口的基本形式可以概括為:
輸出張量 = F.函數(shù)名(輸入張量, 其他參數(shù))
它更像普通數(shù)學(xué)函數(shù),不負(fù)責(zé)保存模型結(jié)構(gòu),也不負(fù)責(zé)管理參數(shù)。
四、常用激活函數(shù):ReLU、Sigmoid、Tanh 與 GELU
激活函數(shù)用于引入非線性能力。函數(shù)式接口中常見激活函數(shù)包括:
? F.relu
? F.sigmoid
? F.tanh
? F.leaky_relu
? F.gelu
? F.silu
示例:
其中,ReLU 是最常見的激活函數(shù)之一:
y = F.relu(x)它會(huì)把負(fù)數(shù)變?yōu)?0,正數(shù)保持不變。
GELU 在 Transformer、BERT、GPT 等模型中較常見:
y = F.gelu(x)在模型中,函數(shù)式激活函數(shù)常寫在 forward() 中:
這種寫法簡潔清晰,適合沒有內(nèi)部狀態(tài)的激活函數(shù)。
五、F.linear:函數(shù)式線性變換
F.linear 用于執(zhí)行線性變換。它與 nn.Linear 的核心計(jì)算一致,但使用方式不同。
線性變換可以寫為:
其中:
? X 表示輸入張量
? W 表示權(quán)重矩陣
? b 表示偏置
? Y 表示輸出張量
使用 nn.Linear 時(shí),權(quán)重和偏置由模塊自動(dòng)創(chuàng)建并管理:
使用 F.linear 時(shí),需要自己提供權(quán)重和偏置:
這里要特別注意權(quán)重形狀:
weight.shape == (out_features, in_features)
而不是:
(in_features, out_features)
在普通神經(jīng)網(wǎng)絡(luò)中,通常更推薦使用 nn.Linear,因?yàn)樗鼤?huì)自動(dòng)創(chuàng)建并注冊參數(shù)。F.linear 更適合需要手動(dòng)控制權(quán)重的場景,例如自定義層、元學(xué)習(xí)、函數(shù)式模型調(diào)用或特殊實(shí)驗(yàn)。
六、F.conv2d:函數(shù)式卷積操作
F.conv2d 用于執(zhí)行二維卷積操作。它與 nn.Conv2d 的核心計(jì)算對應(yīng),但同樣需要手動(dòng)提供卷積核權(quán)重。
示例:
這里需要注意:
? 輸入 x 的形狀通常是 (N, C, H, W)
? weight 的形狀是 (out_channels, in_channels, kH, kW)
? bias 的形狀是 (out_channels,)
? padding=1 可以在 kernel_size=3 時(shí)保持高寬不變
與 nn.Conv2d 相比,F(xiàn).conv2d 不會(huì)自動(dòng)創(chuàng)建卷積核參數(shù)。普通卷積網(wǎng)絡(luò)中,通常使用:
self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)而不是手動(dòng)創(chuàng)建 weight 后調(diào)用 F.conv2d。
不過,在實(shí)現(xiàn)特殊卷積、自定義權(quán)重共享、動(dòng)態(tài)卷積或研究型模型時(shí),函數(shù)式卷積接口非常有用。
七、池化函數(shù):max_pool2d 與 avg_pool2d
池化層常用于降低特征圖的空間尺寸,減少計(jì)算量,并增強(qiáng)局部特征的穩(wěn)定性。
函數(shù)式接口中常見池化函數(shù)包括:
? F.max_pool2d
? F.avg_pool2d
示例:
平均池化示例:
池化操作通常沒有可學(xué)習(xí)參數(shù),因此使用函數(shù)式接口很自然。
在模型中可以寫成:
這里的卷積層有參數(shù),因此使用 nn.Conv2d;池化和 ReLU 沒有參數(shù),因此使用 F.max_pool2d 和 F.relu,這是一種常見組合。
八、F.dropout:函數(shù)式 Dropout
Dropout 用于在訓(xùn)練階段隨機(jī)丟棄部分神經(jīng)元輸出,從而緩解過擬合。
模塊式寫法通常是:
self.dropout = nn.Dropout(p=0.5)函數(shù)式寫法是:
x = F.dropout(x, p=0.5, training=self.training)完整示例:
這里最重要的是:
training=self.training原因是 Dropout 在訓(xùn)練模式和推理模式下行為不同:
? 訓(xùn)練模式:隨機(jī)丟棄部分元素
? 推理模式:不再隨機(jī)丟棄元素
如果使用 nn.Dropout,它會(huì)自動(dòng)根據(jù) model.train() 和 model.eval() 切換行為。如果使用 F.dropout,則應(yīng)顯式傳入 training=self.training,否則容易在推理階段仍然執(zhí)行隨機(jī)丟棄,導(dǎo)致結(jié)果不穩(wěn)定。
因此,函數(shù)式 Dropout 雖然靈活,但也更容易寫錯(cuò)。
九、Softmax、LogSoftmax 與分類輸出
Softmax 常用于把分類模型輸出的 logits 轉(zhuǎn)換為概率分布。
示例:
這里的 dim=1 表示在類別維度上做歸一化。對于形狀為 (batch_size, num_classes) 的 logits,通常應(yīng)使用:
F.softmax(logits, dim=1)需要注意的是,訓(xùn)練多分類模型時(shí),如果使用 F.cross_entropy 或 nn.CrossEntropyLoss,通常不需要提前手動(dòng)調(diào)用 Softmax。
常見訓(xùn)練寫法是:
loss = F.cross_entropy(logits, target)而不是:
loss = F.cross_entropy(probs, target)原因是交叉熵?fù)p失函數(shù)期望輸入的是原始 logits,而不是已經(jīng)歸一化后的概率。提前做 Softmax 可能導(dǎo)致數(shù)值穩(wěn)定性和訓(xùn)練效果問題。
在推理階段,如果需要查看每個(gè)類別的概率,可以再使用:
pred = torch.argmax(probs, dim=1)如果只是取預(yù)測類別,也可以直接對 logits 取最大值:
pred = torch.argmax(logits, dim=1)因?yàn)?Softmax 不會(huì)改變各類別分?jǐn)?shù)的相對大小順序。
十、函數(shù)式損失函數(shù)
torch.nn.functional 中也提供了多種損失函數(shù),例如:
? F.cross_entropy
? F.binary_cross_entropy
? F.binary_cross_entropy_with_logits
? F.mse_loss
? F.l1_loss
1、分類任務(wù)中的 F.cross_entropy
多分類任務(wù)中,常用 F.cross_entropy:
這里要特別注意:
? logits 的形狀通常是 (batch_size, num_classes)
? target 的形狀通常是 (batch_size,)
? target 中保存的是類別編號(hào)
? 不需要提前對 logits 做 Softmax
F.cross_entropy 和 nn.CrossEntropyLoss() 的核心計(jì)算含義一致。區(qū)別在于,前者是函數(shù)式接口,直接調(diào)用;后者是模塊式接口,需要先創(chuàng)建損失函數(shù)對象。
模塊式寫法:
loss = criterion(logits, target)函數(shù)式寫法:
loss = F.cross_entropy(logits, target)如果損失函數(shù)沒有需要長期保存的配置,函數(shù)式寫法會(huì)更簡潔。
2、回歸任務(wù)中的 F.mse_loss
回歸任務(wù)中,可以使用 F.mse_loss:
這里預(yù)測值和目標(biāo)值都應(yīng)是浮點(diǎn)張量,并且形狀應(yīng)盡量保持一致。
如果 pred 的形狀是 (8, 1),而 target 的形狀是 (8,),可能觸發(fā)廣播,使損失計(jì)算的含義不符合預(yù)期。因此,回歸任務(wù)中要特別注意形狀。
3、二分類與多標(biāo)簽任務(wù)
二分類或多標(biāo)簽任務(wù)中,常用:
loss = F.binary_cross_entropy_with_logits(logits, target)它通常比先手動(dòng)做 Sigmoid 再計(jì)算二元交叉熵更穩(wěn)妥。
常見寫法:
這里的 target 應(yīng)是浮點(diǎn)張量,取值通常為 0 或 1。
需要注意:
? 多分類任務(wù)常用 F.cross_entropy
? 二分類或多標(biāo)簽任務(wù)常用 F.binary_cross_entropy_with_logits
? 回歸任務(wù)常用 F.mse_loss 或 F.l1_loss
損失函數(shù)要與任務(wù)類型和模型輸出形式匹配,否則即使代碼能運(yùn)行,模型也可能學(xué)不到正確目標(biāo)。
十一、函數(shù)式歸一化與距離計(jì)算
torch.nn.functional 還提供了一些用于特征處理的函數(shù),例如:
? F.normalize
? F.cosine_similarity
? F.pairwise_distance
F.normalize 可以把向量按指定維度歸一化,常用于表示學(xué)習(xí)、檢索、對比學(xué)習(xí)等場景。
示例:
這里的 dim=1 表示對每個(gè)樣本的特征維度做歸一化。
歸一化后的向量長度通常接近 1,這使得向量之間的方向關(guān)系更加突出。
例如,在計(jì)算余弦相似度時(shí),可以寫成:
這類函數(shù)不是傳統(tǒng)意義上的“網(wǎng)絡(luò)層”,但在現(xiàn)代深度學(xué)習(xí)任務(wù)中非常常用,尤其適合特征表示、相似度計(jì)算和對比學(xué)習(xí)。
十二、函數(shù)式接口在自定義模型中的典型用法
函數(shù)式接口常用于 forward() 中,與模塊式層配合使用。
例如,一個(gè)簡單的卷積網(wǎng)絡(luò)可以寫成:
這個(gè)例子體現(xiàn)了常見原則:
? 有參數(shù)的層使用模塊式接口,如 nn.Conv2d、nn.Linear
? 無參數(shù)操作使用函數(shù)式接口,如 F.relu、F.max_pool2d
? 數(shù)據(jù)流動(dòng)邏輯寫在 forward() 中
這種寫法既保持了模型參數(shù)管理的規(guī)范性,也讓前向計(jì)算過程更加簡潔。
十三、函數(shù)式接口與模型狀態(tài)
函數(shù)式接口本身通常不保存狀態(tài)。這一點(diǎn)非常重要。
例如:
x = F.relu(x)這只是對 x 執(zhí)行一次 ReLU 操作,不會(huì)在模型中注冊一個(gè)名為 ReLU 的子模塊。
這意味著:
? 它不會(huì)出現(xiàn)在 model.children() 中
? 它不會(huì)出現(xiàn)在 model.state_dict() 中
? 它沒有自己的參數(shù)
? 它不需要優(yōu)化器更新
對于 ReLU、Softmax、池化等無參數(shù)操作,這通常沒有問題。
但對于需要參數(shù)或狀態(tài)的結(jié)構(gòu),就要謹(jǐn)慎。例如,BatchNorm 不僅涉及當(dāng)前輸入,還涉及訓(xùn)練階段和推理階段的統(tǒng)計(jì)量。普通模型中通常推薦寫成:
self.bn = nn.BatchNorm2d(16)而不是優(yōu)先使用函數(shù)式接口手動(dòng)處理。
類似地,卷積和線性層雖然有函數(shù)式接口 F.conv2d、F.linear,但如果權(quán)重需要作為模型參數(shù)訓(xùn)練,通常應(yīng)使用 nn.Conv2d、nn.Linear 來自動(dòng)管理參數(shù)。
十四、什么時(shí)候使用函數(shù)式接口
函數(shù)式接口適合以下場景。
1、操作沒有可學(xué)習(xí)參數(shù)
例如:
x = F.max_pool2d(x, 2)這類操作不需要保存權(quán)重,也不需要注冊到模型參數(shù)中,使用函數(shù)式接口很自然。
2、只想在 forward() 中表達(dá)一次計(jì)算
例如:
x = F.gelu(x)這種寫法簡潔直接,適合表達(dá)數(shù)據(jù)流中的一次變換。
3、需要手動(dòng)控制權(quán)重
例如:
y = F.linear(x, weight, bias)這適合自定義層、特殊實(shí)驗(yàn)、元學(xué)習(xí)或需要?jiǎng)討B(tài)生成權(quán)重的模型。
4、損失函數(shù)配置簡單
例如:
loss = F.cross_entropy(logits, target)如果沒有復(fù)雜配置,函數(shù)式損失函數(shù)可以讓代碼更簡潔。
5、需要更靈活的底層控制
一些研究型模型需要精細(xì)控制每一步計(jì)算,此時(shí)函數(shù)式接口比模塊式接口更靈活。
十五、什么時(shí)候更適合使用模塊式接口
雖然函數(shù)式接口很方便,但并不是所有操作都適合優(yōu)先使用它。
以下場景通常更適合使用模塊式接口:
1、操作包含可學(xué)習(xí)參數(shù)
例如:
self.conv = nn.Conv2d(3, 16, 3)如果使用 F.linear 或 F.conv2d,就必須自己管理 weight 和 bias,這對初學(xué)者并不友好,也容易導(dǎo)致參數(shù)沒有正確注冊。
2、操作包含訓(xùn)練狀態(tài)
例如:
self.bn = nn.BatchNorm2d(16)這些結(jié)構(gòu)在訓(xùn)練和推理階段行為可能不同。模塊式接口能更自然地響應(yīng) model.train() 和 model.eval()。
3、希望模型結(jié)構(gòu)更清晰
模塊式接口會(huì)顯示在模型結(jié)構(gòu)中。例如:
print(model)如果使用很多函數(shù)式操作,有些操作不會(huì)在模型結(jié)構(gòu)打印結(jié)果中顯示出來。這可能影響模型結(jié)構(gòu)閱讀和調(diào)試。
4、希望保存和加載狀態(tài)更簡單
模塊式接口能夠與 state_dict() 更好地配合。凡是需要保存參數(shù)或統(tǒng)計(jì)狀態(tài)的部分,都應(yīng)優(yōu)先考慮模塊式寫法。
十六、一個(gè)完整示例:函數(shù)式接口參與訓(xùn)練
下面用一個(gè)完整示例,把 torch.nn.functional 放入神經(jīng)網(wǎng)絡(luò)訓(xùn)練流程中。
![]()
圖 3:函數(shù)式接口參與前向傳播與損失計(jì)算
這個(gè)示例中:
? nn.Linear 負(fù)責(zé)保存可學(xué)習(xí)參數(shù)
? F.relu 負(fù)責(zé)執(zhí)行無參數(shù)激活操作
? F.cross_entropy 負(fù)責(zé)計(jì)算分類損失
? optimizer.step() 負(fù)責(zé)更新參數(shù)
這說明,函數(shù)式接口通常不是單獨(dú)使用的,而是與 nn.Module、自動(dòng)求導(dǎo)和優(yōu)化器共同構(gòu)成訓(xùn)練流程。
十七、使用函數(shù)式接口時(shí)應(yīng)注意的問題
1、不要把有參數(shù)的層隨意改成函數(shù)式接口
例如,nn.Linear 會(huì)自動(dòng)管理權(quán)重和偏置,而 F.linear 需要手動(dòng)傳入 weight 和 bias。
普通模型中推薦寫成:
self.fc = nn.Linear(4, 3)而不是:
y = F.linear(x, weight, bias)除非你明確知道如何把這些權(quán)重注冊為可學(xué)習(xí)參數(shù)。
2、使用 F.dropout 時(shí)要傳入 training=self.training
推薦寫法:
x = F.dropout(x, p=0.5, training=self.training)如果遺漏 training=self.training,Dropout 的行為可能與 model.train()、model.eval() 不一致,尤其容易影響驗(yàn)證和推理結(jié)果。
3、使用 Softmax 時(shí)要明確 dim
錯(cuò)誤或不清晰的寫法:
probs = F.softmax(logits)推薦寫法:
probs = F.softmax(logits, dim=1)對于分類任務(wù),dim 通常表示類別所在的維度。維度寫錯(cuò)會(huì)導(dǎo)致概率歸一化方向錯(cuò)誤。
4、訓(xùn)練多分類模型時(shí)不要在 CrossEntropy 前手動(dòng) Softmax
推薦寫法:
loss = F.cross_entropy(logits, target)不推薦寫成:
loss = F.cross_entropy(probs, target)因?yàn)榻徊骒負(fù)p失函數(shù)期望輸入 logits,而不是概率。
5、注意函數(shù)式接口不會(huì)注冊子模塊
例如:
x = F.relu(x)這不會(huì)讓 ReLU 出現(xiàn)在 model.children() 中,也不會(huì)保存到 state_dict() 中。對于無參數(shù)操作,這通常沒有問題;但對于需要參數(shù)或狀態(tài)的結(jié)構(gòu),就要謹(jǐn)慎。
6、注意輸入形狀
不同函數(shù)式接口對輸入形狀有不同要求:
? F.linear 通常關(guān)注輸入最后一維
? F.conv2d 通常要求輸入為 (N, C, H, W)
? F.max_pool2d 通常作用于圖像特征圖
? F.cross_entropy 常用 (N, C) 的 logits 和 (N,) 的類別標(biāo)簽
調(diào)試時(shí)可以多打印:
print(target.shape)很多函數(shù)式接口錯(cuò)誤,本質(zhì)上都是輸入維度不匹配。
7、不要混淆 torch.nn.functional 與 torch.func
torch.nn.functional 主要提供神經(jīng)網(wǎng)絡(luò)中的函數(shù)式操作,例如激活、卷積、池化、損失函數(shù)等。
torch.func 則更偏向函數(shù)變換、函數(shù)式調(diào)用、向量化和高階自動(dòng)微分等高級用途。
初學(xué) PyTorch 神經(jīng)網(wǎng)絡(luò)時(shí),通常先掌握 torch.nn.functional;深入研究函數(shù)變換、元學(xué)習(xí)或高級自動(dòng)微分時(shí),再進(jìn)一步學(xué)習(xí) torch.func。
小結(jié)
torch.nn.functional 提供直接作用于張量的神經(jīng)網(wǎng)絡(luò)函數(shù),是模塊式接口的重要補(bǔ)充。學(xué)習(xí)這一模塊,重點(diǎn)是理解哪些操作適合函數(shù)式寫法,哪些應(yīng)交給 nn.Module 管理。通常,無參數(shù)操作可用 F.relu、F.max_pool2d 等函數(shù)式接口;有參數(shù)或狀態(tài)的層更適合使用 nn.Linear、nn.Conv2d、nn.BatchNorm2d 等模塊式接口。
“點(diǎn)贊有美意,贊賞是鼓勵(lì)”
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺(tái)“網(wǎng)易號(hào)”用戶上傳并發(fā)布,本平臺(tái)僅提供信息存儲(chǔ)服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.