まずは線形回帰からやってみる。
Contents
Pytorchインストール
pipでpytorch他をインストールする。
pip install torch torchvision torchinfo torchviz
必要なライブラリをインポート
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchviz import make_dot
from IPython.display import display
データセット作成
に適当にノイズを乗せてデータにする。
x = np.random.rand(100,1) * 2
y = 5 * x + 3 + np.random.randn(100,1)
plt.scatter(x, y, s=10)
plt.xlabel('x')
plt.ylabel('y')
plt.show()

Tensorに変換。
inputs = torch.tensor(x).float()
labels = torch.tensor(y).float()
モデル定義
class Net(nn.Module):
def __init__(self, n_input, n_output):
super().__init__()
self.l1 = nn.Linear(n_input, n_output)
def forward(self, x):
x1 = self.l1(x)
return x1
損失関数と最適化関数定義
net = Net(1, 1) # モデルインスタンス
criterion = nn.MSELoss() # 損失関数: 平均2乗誤差
lr = 0.01 # 学習率
optimizer = optim.SGD(net.parameters(), lr=lr) # 最適化関数: 勾配降下法
計算グラフ
とりま計算グラフを出してみる。
outputs = net(inputs) # とりま一回計算
loss = criterion(outputs, labels) # とりま一回損失計算
g = make_dot(loss, params=dict(net.named_parameters()))
display(g)

pngで保存したい時は下記のようにする。
make_dot(loss, params=dict(net.named_parameters())).render("figure_name", format="png")
ちなみに、予めgraphvizをインストールしておかないといけない(Mac)。
brew install graphviz
ループを回す
total_epoch = 1000 # epoch数
loss_log = np.zeros((0,2)) # 損失記録用
for epoch in range(total_epoch):
optimizer.zero_grad() # 勾配値初期化
outputs = net(inputs) # 予測計算
loss = criterion(outputs, labels) # 損失計算
loss.backward() # 勾配計算(誤差逆伝播)
optimizer.step() # パラメータ更新
# 10回ごとに途中経過を記録する
if (epoch % 10 == 0):
loss_log = np.vstack((loss_log, np.array([epoch, loss.item()])))
print(f'Epoch {epoch} loss: {loss.item():.5f}')
結果
パラメータ
print(net.state_dict())
OrderedDict([(‘l1.weight’, tensor([[4.3680]])), (‘l1.bias’, tensor([3.6119]))])
傾き:4.4、切片:3.6くらい。
を元にノイズを乗せてデータを作ったので、こんなもんでしょう。
プロット
グラフにプロットしてみる。
line_x = np.array((x.min(), x.max())).reshape(-1,1)
line_xt = torch.tensor(xse).float()
with torch.no_grad():
line_yt = net(line_xt)
plt.scatter(x, y, s=10)
plt.xlabel('x')
plt.ylabel('y')
plt.plot(line_xt, line_yt, c='m')
plt.show()

損失グラフ
plt.plot(loss_log[:,0], loss_log[:,1])
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

いい感じに収束していってる。
