您好,登錄后才能下訂單哦!
這篇文章主要介紹“pytorch實踐線性模型3d源碼分析”的相關知識,小編通過實際案例向大家展示操作過程,操作方法簡單快捷,實用性強,希望這篇“pytorch實踐線性模型3d源碼分析”文章能幫助大家解決問題。
y = wx +b
通過meshgrid 得到兩個二維矩陣
關鍵理解:
plot_surface需要的xyz是二維np數組
這里提前準備meshgrid來生產x和y需要的參數
下圖的W和I即plot_surface需要xy
Z即我們需要的權重損失
計算方式要和W,I. I的每行中內容是一樣的就是y=wx+b的b是一樣的
fig = plt.figure() ax = fig.add_axes(Axes3D(fig)) ax.plot_surface(W, I, Z=MSE_data)
總的實驗代碼
import matplotlib.pyplot as plt import numpy as np from mpl_toolkits.mplot3d import Axes3D class LinearModel: @staticmethod def forward(w, x): return w * x @staticmethod def forward_with_intercept(w, x, b): return w * x + b @staticmethod def get_loss(w, x, y_origin, exp=2, b=None): if b: y = LinearModel.forward_with_intercept(w, x, b) else: y = LinearModel.forward(w, x) return pow(y_origin - y, exp) def test_2d(): x_data = [1.0, 2.0, 3.0] y_data = [2.0, 4.0, 6.0] weight_data = [] MSE_data = [] # 設定實驗的權重范圍 for w in np.arange(0.0, 4.1, 0.1): weight_data.append(w) loss_total = 0 # 計算每個權重在數據集上的MSE平均平方方差 for x_val, y_val in zip(x_data, y_data): loss_total += LinearModel.get_loss(w, x_val, y_val) MSE_data.append(loss_total / len(x_data)) # 繪圖 plt.xlabel("weight") plt.ylabel("MSE") plt.plot(weight_data, MSE_data) plt.show() def test_3d(): x_data = [1.0, 2.0, 3.0] y_data = [5.0, 8.0, 11.0] weight_data = np.arange(0.0, 4.1, 0.1) intercept_data = np.arange(0.0, 4.1, 0.1) W, I = np.meshgrid(weight_data, intercept_data) MSE_data = [] # 設定實驗的權重范圍 循環要先寫截距的 meshgrid 的返回第二個是相當于41*41 同一行值相同 ,要在第二層循環去遍歷權重 for intercept in intercept_data: MSE_data_tmp = [] for w in weight_data: loss_total = 0 # 計算每個權重在數據集上的MSE平均平方方差 for x_val, y_val in zip(x_data, y_data): loss_total += LinearModel.get_loss(w, x_val, y_val, b=intercept) MSE_data_tmp.append(loss_total / len(x_data)) MSE_data.append(MSE_data_tmp) MSE_data = np.array(MSE_data) fig = plt.figure() ax = fig.add_axes(Axes3D(fig)) ax.plot_surface(W, I, Z=MSE_data) plt.xlabel("weight") plt.ylabel("intercept") plt.show() if __name__ == '__main__': test_2d() test_3d()
關于“pytorch實踐線性模型3d源碼分析”的內容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業相關的知識,可以關注億速云行業資訊頻道,小編每天都會為大家更新不同的知識點。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。