lesson4. 构建神经网络

约 2469 字大约 8 分钟...

lesson4. 构建神经网络

import os 
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
/home/lsj/.conda/envs/pt12/lib/python3.8/site-packages/torch/cuda/__init__.py:83: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at  /opt/conda/conda-bld/pytorch_1656352465323/work/c10/cuda/CUDAFunctions.cpp:109.)
  return torch._C._cuda_getDeviceCount() > 0

获取训练设备

  • 查看有无可用GPU
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)
cpu

定义神经网络类

  • super(NeuralNetwork,self).__init__() :通过super找到NeuralNetwork的父类nn.Module,并对self调用父类的__init__()完成模块初始化
  • nn.Flatten(start_dim=1,end_dim=-1) :将输入张量从start_dim到end_dim的维度展平成一维(默认保留第0维),此处将(1,28,28)展平成(1,28*28)
  • nn.Sequential() :序列容器,将神经网络模块按顺序添加到容器中
class NeuralNetwork(nn.Module):
    """Small fully connected network producing 10 logits per sample.

    Each sample is flattened to a 28*28 = 784 element vector and passed
    through three linear layers (784 -> 512 -> 512 -> 10) with a ReLU
    after each of the first two.
    """

    def __init__(self):
        # Initialise nn.Module first so sub-modules assigned below are registered.
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        # Layers are listed in the order they are applied in forward().
        stack = [
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*stack)

    def forward(self, x):
        """Flatten ``x`` and return the unnormalised outputs (logits)."""
        return self.linear_relu_stack(self.flatten(x))
# Build the network, move it to the selected device (GPU if available,
# otherwise CPU) and print its structure.
model = NeuralNetwork()
model = model.to(device)
print(model)
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
# Run inference by calling the model object itself; do NOT call
# model.forward() directly.
X = torch.rand((1, 28, 28), device=device)
logits = model(X)
# Softmax over dim 1 turns the logits into probabilities; argmax over the
# same dim gives the index of the largest probability for each sample.
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = torch.argmax(pred_probab, dim=1)
print(logits, pred_probab, y_pred, sep="\n")
tensor([[ 0.1835,  0.0703,  0.0762, -0.0534, -0.0084,  0.0368,  0.1318, -0.0841,
         -0.0383, -0.0130]], grad_fn=<AddmmBackward0>)
tensor([[0.1162, 0.1037, 0.1044, 0.0917, 0.0959, 0.1003, 0.1103, 0.0889, 0.0931,
         0.0955]], grad_fn=<SoftmaxBackward0>)
tensor([0])

模型层解构

# Random stand-in for a batch of three 28x28 images, used to step
# through the layers one at a time.
input_image = torch.rand((3, 28, 28))
print(input_image.shape)
torch.Size([3, 28, 28])

nn.Flatten

# Flatten every dimension after the first, so each 28x28 image becomes
# a single row of 784 values.
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flat_image = flatten(input_image)
print(flat_image.shape)
torch.Size([3, 784])

nn.Linear

# Linear layer mapping each 784-element row to 20 output features
# using its own weight matrix and bias.
layer1 = nn.Linear(28 * 28, 20)
hidden1 = layer1(flat_image)
print(hidden1)
tensor([[-0.3914,  0.9743,  0.3574,  0.0032, -0.0105, -0.3178, -0.6143,  0.0487,
          0.0233, -0.4386,  0.3226,  0.0912,  0.0098,  0.0723, -0.1843, -0.5586,
         -0.0618, -0.0330,  0.6477,  0.4035],
        [-0.5303,  0.4330,  0.5043,  0.3772, -0.3653, -0.2800, -0.3662,  0.0570,
          0.3869,  0.0945, -0.2175, -0.0924, -0.1414, -0.1828,  0.0621, -0.3528,
         -0.2910, -0.0231,  0.1191,  0.2671],
        [-0.1596,  0.5198,  0.3571,  0.0806, -0.2248, -0.2083, -0.3483,  0.0522,
         -0.0583, -0.0232,  0.0035, -0.3093,  0.0038,  0.0386,  0.2241, -0.2543,
         -0.2830,  0.0570,  0.2809,  0.0586]], grad_fn=<AddmmBackward0>)

nn.ReLU

# Print the activations before and after ReLU; ReLU replaces every
# negative entry with zero and leaves the rest unchanged.
print(f"Before ReLU: {hidden1!s}")
hidden1 = nn.ReLU()(hidden1)
print(f"After ReLU: {hidden1!s}")
Before ReLU: tensor([[-0.3914,  0.9743,  0.3574,  0.0032, -0.0105, -0.3178, -0.6143,  0.0487,
          0.0233, -0.4386,  0.3226,  0.0912,  0.0098,  0.0723, -0.1843, -0.5586,
         -0.0618, -0.0330,  0.6477,  0.4035],
        [-0.5303,  0.4330,  0.5043,  0.3772, -0.3653, -0.2800, -0.3662,  0.0570,
          0.3869,  0.0945, -0.2175, -0.0924, -0.1414, -0.1828,  0.0621, -0.3528,
         -0.2910, -0.0231,  0.1191,  0.2671],
        [-0.1596,  0.5198,  0.3571,  0.0806, -0.2248, -0.2083, -0.3483,  0.0522,
         -0.0583, -0.0232,  0.0035, -0.3093,  0.0038,  0.0386,  0.2241, -0.2543,
         -0.2830,  0.0570,  0.2809,  0.0586]], grad_fn=<AddmmBackward0>)
After ReLU: tensor([[0.0000, 0.9743, 0.3574, 0.0032, 0.0000, 0.0000, 0.0000, 0.0487, 0.0233,
         0.0000, 0.3226, 0.0912, 0.0098, 0.0723, 0.0000, 0.0000, 0.0000, 0.0000,
         0.6477, 0.4035],
        [0.0000, 0.4330, 0.5043, 0.3772, 0.0000, 0.0000, 0.0000, 0.0570, 0.3869,
         0.0945, 0.0000, 0.0000, 0.0000, 0.0000, 0.0621, 0.0000, 0.0000, 0.0000,
         0.1191, 0.2671],
        [0.0000, 0.5198, 0.3571, 0.0806, 0.0000, 0.0000, 0.0000, 0.0522, 0.0000,
         0.0000, 0.0035, 0.0000, 0.0038, 0.0386, 0.2241, 0.0000, 0.0000, 0.0570,
         0.2809, 0.0586]], grad_fn=<ReluBackward0>)

nn.Sequential

# Chain the modules so data flows through them in the listed order:
# flatten -> layer1 -> ReLU -> a final 20-to-10 linear layer.
seq_modules = nn.Sequential(
    flatten, layer1, nn.ReLU(), nn.Linear(in_features=20, out_features=10)
)
# Fresh random batch, pushed through the whole container in one call.
input_image = torch.rand((3, 28, 28))
logits = seq_modules(input_image)
print(logits)
tensor([[ 0.1165, -0.2980, -0.1271,  0.1471,  0.1197, -0.0370, -0.1217, -0.0424,
          0.1851,  0.1187],
        [ 0.2034, -0.2883, -0.2599,  0.1343,  0.0700, -0.1013, -0.1442, -0.0667,
          0.2362,  0.1714],
        [ 0.1659, -0.2946, -0.1774,  0.1805,  0.1837, -0.1381, -0.2138, -0.0489,
          0.1290,  0.1409]], grad_fn=<AddmmBackward0>)

nn.Softmax

# Softmax along dim=1 rescales each row of logits to values in [0, 1]
# that sum to 1, so every row can be read as per-class probabilities.
softmax = nn.Softmax(dim=1)
pred_probab = softmax(logits)
print(pred_probab)
tensor([[0.1105, 0.0730, 0.0866, 0.1139, 0.1108, 0.0948, 0.0871, 0.0943, 0.1183,
         0.1107],
        [0.1211, 0.0741, 0.0762, 0.1130, 0.1060, 0.0893, 0.0855, 0.0924, 0.1251,
         0.1173],
        [0.1171, 0.0739, 0.0831, 0.1188, 0.1192, 0.0864, 0.0801, 0.0945, 0.1128,
         0.1142]], grad_fn=<SoftmaxBackward0>)

模型参数

  • nn.Module会自动跟踪保存模型参数,使用parameters()或named_parameters()获取
# Show the architecture, then every registered parameter: its qualified
# name on one line followed by the parameter tensor itself.
print(model)
for name, param in model.named_parameters():
    print(name, param, sep="\n")
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
linear_relu_stack.0.weight
Parameter containing:
tensor([[ 0.0348,  0.0040,  0.0076,  ..., -0.0357,  0.0206,  0.0259],
        [-0.0055,  0.0342, -0.0239,  ...,  0.0177, -0.0028,  0.0001],
        [-0.0087,  0.0276, -0.0036,  ...,  0.0066, -0.0192,  0.0107],
        ...,
        [ 0.0304,  0.0203, -0.0344,  ..., -0.0248, -0.0153, -0.0235],
        [-0.0091,  0.0115,  0.0318,  ...,  0.0144, -0.0122,  0.0103],
        [ 0.0242,  0.0351,  0.0152,  ...,  0.0321,  0.0250,  0.0157]],
       requires_grad=True)
linear_relu_stack.0.bias
Parameter containing:
tensor([ 0.0086,  0.0261, -0.0205, -0.0040, -0.0326,  0.0316,  0.0330,  0.0151,
         0.0222,  0.0155,  0.0138,  0.0226, -0.0340,  0.0284,  0.0142,  0.0227,
        -0.0184,  0.0282,  0.0177, -0.0084, -0.0119,  0.0096,  0.0266, -0.0295,
         0.0015, -0.0128,  0.0264,  0.0321,  0.0178, -0.0252,  0.0236,  0.0214,
        -0.0174,  0.0289,  0.0141, -0.0312,  0.0098,  0.0023,  0.0334,  0.0266,
        -0.0066, -0.0073, -0.0353, -0.0099,  0.0220,  0.0309, -0.0066,  0.0306,
        -0.0093, -0.0084, -0.0095, -0.0249,  0.0078, -0.0064,  0.0199,  0.0271,
        -0.0092,  0.0084, -0.0321,  0.0004,  0.0080, -0.0073,  0.0245, -0.0002,
        -0.0249, -0.0292, -0.0079,  0.0330, -0.0068,  0.0158,  0.0122, -0.0097,
         0.0253,  0.0310, -0.0032,  0.0106,  0.0018, -0.0012,  0.0039, -0.0119,
        -0.0156, -0.0053,  0.0002,  0.0014,  0.0014,  0.0159, -0.0092,  0.0293,
         0.0052, -0.0214,  0.0215,  0.0116,  0.0255,  0.0205, -0.0012, -0.0256,
         0.0120, -0.0325,  0.0134, -0.0281, -0.0290,  0.0121, -0.0133, -0.0083,
         0.0181,  0.0198,  0.0058,  0.0216, -0.0294, -0.0038,  0.0191,  0.0347,
        -0.0143,  0.0352,  0.0141, -0.0328, -0.0030,  0.0131,  0.0042,  0.0300,
         0.0344,  0.0171,  0.0259,  0.0226, -0.0004,  0.0050, -0.0064, -0.0164,
         0.0203, -0.0159, -0.0110,  0.0344,  0.0011, -0.0209,  0.0086,  0.0233,
        -0.0299,  0.0152, -0.0084, -0.0026, -0.0344, -0.0111,  0.0086,  0.0349,
         0.0332, -0.0176,  0.0083, -0.0140,  0.0187,  0.0048,  0.0086,  0.0030,
         0.0277, -0.0129,  0.0217,  0.0349,  0.0280,  0.0318, -0.0106, -0.0288,
         0.0162, -0.0272, -0.0072, -0.0309, -0.0355, -0.0083, -0.0280,  0.0040,
        -0.0351,  0.0036,  0.0146,  0.0040, -0.0322,  0.0223, -0.0252, -0.0342,
        -0.0173,  0.0159, -0.0172,  0.0283,  0.0034, -0.0343, -0.0092,  0.0318,
        -0.0109, -0.0017,  0.0117, -0.0121,  0.0198,  0.0045,  0.0195,  0.0203,
         0.0027,  0.0307,  0.0342,  0.0058,  0.0137, -0.0250,  0.0323, -0.0203,
        -0.0175, -0.0336,  0.0225,  0.0174,  0.0133, -0.0023,  0.0227,  0.0159,
         0.0103, -0.0025, -0.0069,  0.0192, -0.0229, -0.0128,  0.0086, -0.0022,
         0.0079,  0.0272,  0.0226,  0.0276,  0.0306, -0.0331, -0.0246, -0.0043,
        -0.0060, -0.0044, -0.0346,  0.0211, -0.0071, -0.0253,  0.0276, -0.0204,
        -0.0324,  0.0193, -0.0099, -0.0121,  0.0290, -0.0217,  0.0343, -0.0314,
        -0.0292, -0.0337,  0.0151,  0.0295,  0.0176, -0.0045,  0.0142,  0.0304,
         0.0330, -0.0073,  0.0148,  0.0149,  0.0192, -0.0222,  0.0071,  0.0125,
         0.0131,  0.0269,  0.0344, -0.0205,  0.0218, -0.0177, -0.0269,  0.0308,
        -0.0307, -0.0215,  0.0141, -0.0032, -0.0103, -0.0041, -0.0056,  0.0134,
         0.0209, -0.0150,  0.0095,  0.0140, -0.0091, -0.0120,  0.0293,  0.0271,
        -0.0248,  0.0034, -0.0329, -0.0193,  0.0074,  0.0277,  0.0201,  0.0309,
         0.0164, -0.0086,  0.0351,  0.0066,  0.0134, -0.0169,  0.0097, -0.0147,
        -0.0202,  0.0163, -0.0352, -0.0045,  0.0349,  0.0263, -0.0148, -0.0227,
        -0.0271,  0.0343,  0.0116, -0.0238, -0.0317,  0.0028, -0.0039,  0.0135,
        -0.0292, -0.0170,  0.0183,  0.0149, -0.0118, -0.0347,  0.0133,  0.0243,
        -0.0031, -0.0055, -0.0007,  0.0086,  0.0182,  0.0312,  0.0135,  0.0247,
        -0.0009, -0.0114,  0.0334,  0.0033,  0.0345,  0.0009,  0.0325,  0.0345,
         0.0130, -0.0173, -0.0304,  0.0315, -0.0152,  0.0342,  0.0344,  0.0159,
        -0.0345, -0.0127, -0.0041,  0.0154,  0.0021, -0.0109, -0.0194, -0.0281,
        -0.0313,  0.0304,  0.0296, -0.0010,  0.0145, -0.0013,  0.0225, -0.0129,
        -0.0117,  0.0243,  0.0114,  0.0268,  0.0355,  0.0287,  0.0215, -0.0161,
        -0.0352, -0.0282, -0.0211, -0.0301, -0.0174,  0.0089, -0.0218, -0.0023,
        -0.0317,  0.0042,  0.0058, -0.0156, -0.0101,  0.0149,  0.0078,  0.0137,
        -0.0260, -0.0297,  0.0091,  0.0093, -0.0114,  0.0023, -0.0234, -0.0002,
        -0.0168,  0.0292, -0.0079, -0.0051,  0.0270, -0.0315, -0.0071,  0.0253,
         0.0168,  0.0220,  0.0239, -0.0155,  0.0092,  0.0175, -0.0040, -0.0141,
        -0.0194,  0.0099,  0.0291, -0.0104, -0.0010, -0.0028,  0.0270, -0.0121,
        -0.0240, -0.0177,  0.0315, -0.0061,  0.0183,  0.0273, -0.0118, -0.0030,
         0.0263, -0.0175, -0.0066, -0.0259, -0.0101, -0.0285,  0.0177, -0.0302,
         0.0235, -0.0129, -0.0354, -0.0338, -0.0323,  0.0244,  0.0228, -0.0277,
        -0.0251, -0.0111,  0.0082, -0.0015,  0.0052,  0.0273, -0.0055, -0.0343,
        -0.0202, -0.0139,  0.0105,  0.0304, -0.0068,  0.0223, -0.0314, -0.0344,
         0.0260, -0.0021,  0.0234, -0.0201, -0.0235, -0.0280,  0.0195,  0.0007,
        -0.0124,  0.0133, -0.0023, -0.0111, -0.0275,  0.0120, -0.0128, -0.0184,
        -0.0173, -0.0179, -0.0357, -0.0295,  0.0036, -0.0305,  0.0249,  0.0217,
         0.0213,  0.0177, -0.0226,  0.0214, -0.0117,  0.0243,  0.0051, -0.0346,
        -0.0152, -0.0278,  0.0193, -0.0311, -0.0318,  0.0307,  0.0086,  0.0304,
        -0.0007,  0.0357,  0.0219, -0.0269, -0.0110, -0.0060, -0.0256,  0.0340,
         0.0111, -0.0119, -0.0019,  0.0254,  0.0167, -0.0046,  0.0224,  0.0296],
       requires_grad=True)
linear_relu_stack.2.weight
Parameter containing:
tensor([[ 0.0345,  0.0181,  0.0369,  ..., -0.0131,  0.0136,  0.0314],
        [ 0.0394,  0.0078, -0.0396,  ...,  0.0083, -0.0370,  0.0369],
        [ 0.0008,  0.0354, -0.0103,  ...,  0.0071,  0.0435,  0.0437],
        ...,
        [-0.0400, -0.0052, -0.0206,  ..., -0.0104, -0.0068, -0.0242],
        [ 0.0065, -0.0042,  0.0153,  ...,  0.0032, -0.0207, -0.0188],
        [-0.0385,  0.0161, -0.0351,  ..., -0.0256,  0.0053, -0.0024]],
       requires_grad=True)
linear_relu_stack.2.bias
Parameter containing:
tensor([ 4.1813e-02, -1.6619e-02,  3.9113e-02, -1.5093e-02,  1.5017e-02,
         5.2896e-03, -7.0315e-03,  9.5963e-03, -3.3275e-02, -1.9160e-02,
        -3.4745e-02,  1.0509e-02,  1.0498e-04, -7.1548e-03, -1.5839e-02,
        -3.1533e-02, -1.3287e-02,  2.3534e-02, -1.6398e-02,  4.8202e-03,
        -1.0436e-02,  3.4014e-02,  2.8655e-02,  2.8397e-02,  2.7178e-02,
         3.9391e-02,  3.0124e-02,  3.3509e-03,  6.1992e-03,  2.6582e-02,
        -3.9359e-02, -3.0841e-02,  2.8772e-02,  1.2272e-02,  3.5646e-02,
         2.2063e-02, -1.4506e-02,  8.3983e-03,  3.6239e-02, -3.6312e-02,
         4.0445e-02,  3.2031e-02,  5.8938e-03,  1.1676e-02,  1.2338e-02,
         4.0429e-02,  3.2177e-02, -1.9051e-02, -3.5229e-02, -1.8315e-02,
        -3.6294e-02,  2.0468e-02, -4.3678e-02,  1.2375e-02, -2.1652e-02,
        -2.6055e-03,  1.9871e-02, -2.9539e-02, -4.4110e-03, -4.0434e-02,
         3.8815e-02,  2.5248e-02,  3.2590e-02,  4.4631e-04, -3.0369e-02,
         1.7029e-02, -5.3398e-03,  1.9067e-02, -4.1852e-02,  9.3174e-03,
         3.2190e-02,  6.5696e-03,  3.1825e-03, -6.6836e-04,  1.0921e-02,
         3.2889e-02,  1.3549e-02,  1.1673e-03,  4.2575e-02,  2.5849e-02,
        -3.9895e-02, -3.9511e-02, -2.1672e-02,  3.7632e-02,  1.7327e-02,
         2.3956e-02, -2.5438e-02,  3.1431e-02, -2.6375e-03,  2.0853e-02,
        -2.4985e-02, -1.1729e-02, -1.8595e-02,  1.9006e-02,  1.5368e-03,
        -3.5385e-02, -3.6201e-02,  9.3275e-03,  1.2355e-02,  3.5495e-02,
        -3.3091e-02,  6.2980e-03,  4.5804e-03, -4.3357e-02, -9.3052e-03,
        -1.4889e-02,  3.6015e-02, -6.1881e-03,  5.8576e-03,  3.3089e-02,
         1.5759e-02, -1.9032e-02, -1.1096e-02, -6.8615e-03,  1.7547e-03,
         2.7654e-02,  7.6464e-03, -4.0611e-02, -1.9589e-02,  1.9037e-02,
        -7.7626e-03, -3.3604e-02, -2.5285e-02,  4.0471e-02,  3.6573e-02,
        -3.9312e-02, -3.3128e-02,  1.4771e-02,  1.0701e-02, -3.1122e-02,
         3.6833e-02, -3.6208e-02,  7.8927e-03, -2.9675e-02,  3.9354e-02,
        -3.0588e-02, -3.5297e-02,  3.1088e-02,  1.7613e-02,  3.1319e-02,
         2.7442e-02,  3.8756e-02,  4.4021e-02,  4.3242e-02, -5.1761e-03,
         3.4909e-02,  3.3177e-02, -2.4528e-02,  3.8147e-02, -1.9509e-02,
        -2.1462e-02,  1.5008e-02, -3.2534e-02, -3.9613e-02, -3.7725e-02,
         3.1532e-02,  1.9861e-02,  3.8157e-02,  1.7813e-02,  6.0684e-03,
         3.6414e-03,  1.7636e-02,  5.8332e-03,  3.4099e-02, -3.4436e-02,
         2.4158e-02, -2.9897e-02,  3.6654e-02,  7.4221e-03,  1.5306e-02,
        -8.5132e-03,  8.1645e-03, -2.7132e-02, -1.4036e-02,  2.8793e-02,
         4.2096e-02, -1.4138e-02,  3.3185e-02, -3.6140e-02,  2.7398e-02,
        -1.5582e-02,  3.5993e-02,  3.0235e-02, -1.9122e-02, -3.2258e-02,
         5.2566e-03, -1.7969e-02,  3.2155e-02, -4.3666e-02,  2.1930e-02,
         1.4098e-02,  4.9657e-03, -3.7629e-02,  4.2928e-02,  3.9507e-02,
        -1.5557e-02, -3.8715e-02, -7.4666e-04,  6.8257e-03,  3.9410e-02,
         2.7932e-02, -7.3785e-03,  3.5149e-02, -2.1111e-03,  4.2002e-02,
        -6.6258e-03, -1.2529e-02,  2.5985e-02, -5.3836e-03,  3.4099e-02,
        -1.9472e-02,  1.4900e-02,  5.3838e-03,  1.7148e-02,  3.6593e-02,
        -1.3598e-02,  2.1629e-02,  2.9592e-02, -1.4871e-02,  1.7056e-02,
         2.5576e-02,  2.2679e-02,  9.2657e-03,  2.7061e-02, -1.4918e-04,
        -7.0879e-04,  2.1378e-02,  3.2623e-02, -3.2693e-02, -6.9890e-03,
         1.2475e-02,  2.1180e-02, -2.5963e-02,  1.1538e-02, -3.1687e-02,
        -3.3825e-02,  6.3065e-03, -2.2391e-02, -1.6993e-02,  2.9761e-02,
        -1.7584e-02,  2.6158e-02, -3.8398e-02,  9.5393e-03,  1.0308e-02,
        -2.7005e-02,  2.7423e-02, -3.0228e-02, -7.5275e-04,  2.9244e-03,
        -3.2164e-02,  3.6587e-02,  3.0417e-02, -2.9701e-02,  3.8880e-02,
        -3.0179e-02,  8.1338e-04,  1.9973e-02, -2.1000e-02, -1.2114e-02,
         2.6584e-02,  7.1286e-03,  8.1980e-03,  3.8927e-02, -1.0494e-02,
        -3.5321e-02,  3.1413e-03, -7.3663e-03, -1.4615e-02, -2.9388e-02,
        -1.0254e-02,  3.6683e-02,  1.9666e-02,  1.0081e-02, -3.3764e-02,
        -1.3077e-02, -1.1296e-02,  1.9023e-02,  5.0457e-03,  3.8632e-02,
        -3.8144e-02,  4.1357e-02,  1.7847e-02,  3.6878e-02,  3.9748e-02,
        -6.0793e-03,  2.1098e-02, -3.4776e-03,  3.3519e-02,  1.9769e-02,
        -3.3734e-03, -2.9008e-02, -4.2866e-02, -2.6344e-02,  1.6083e-02,
        -4.1019e-02, -3.9586e-02,  3.0554e-02,  1.0268e-02,  2.8977e-02,
         3.8883e-02, -1.5359e-02, -2.8558e-02,  4.0887e-03,  4.0116e-02,
         3.4093e-02,  3.2030e-03, -2.7915e-02,  1.0666e-02,  3.0899e-02,
        -2.1109e-02, -6.3490e-03, -4.2149e-02, -3.4760e-02, -2.5595e-02,
        -3.3401e-02, -5.2975e-03, -3.4152e-02,  3.5972e-02, -4.2104e-02,
        -2.7873e-02, -1.3702e-02, -5.8416e-03,  2.8854e-02,  2.5872e-02,
        -2.6067e-02, -2.7059e-02,  7.1851e-03,  2.8099e-02, -1.0737e-02,
         1.9886e-02,  2.8195e-02,  3.1350e-02, -2.6669e-02,  2.2479e-02,
        -4.3147e-03, -3.5953e-02, -1.9973e-04, -2.4630e-03,  2.7178e-02,
        -2.7818e-02,  2.6631e-02,  8.9729e-03, -2.2624e-02, -2.4536e-02,
         3.8296e-02,  2.4300e-02,  3.1020e-02,  2.5661e-02,  3.2956e-02,
         2.6426e-02, -3.3200e-02, -2.5431e-02,  3.7043e-02, -2.5536e-02,
        -2.3622e-02,  2.5614e-02,  2.9049e-02,  5.2677e-03,  3.0301e-02,
        -1.4990e-02,  1.3833e-02, -2.7951e-02,  1.0994e-02, -1.7039e-02,
        -3.8425e-02, -3.3476e-02, -3.6594e-02,  3.7877e-02, -2.3660e-02,
        -2.7774e-02, -2.5421e-02,  3.1451e-02,  1.8529e-02, -1.8345e-02,
         3.9190e-02, -3.0978e-02,  5.2639e-03,  2.9981e-02,  9.4218e-03,
        -4.0102e-02, -1.6685e-02, -2.9653e-02,  9.4134e-04, -3.6672e-02,
         2.8482e-02, -1.8619e-02,  2.3092e-02,  1.6687e-02,  3.6474e-03,
         2.7813e-02, -1.0635e-02,  6.3452e-03,  2.8750e-02, -2.9257e-02,
         4.9299e-03, -9.1479e-03,  3.1707e-02, -6.7157e-03,  2.3438e-04,
         1.7786e-02,  1.8846e-02,  7.4835e-03,  2.5446e-02,  1.7958e-02,
        -2.3121e-02,  3.3062e-02,  3.7613e-02, -4.3311e-02, -1.0183e-02,
         2.8330e-02, -4.0246e-02, -3.4520e-02,  3.7923e-02, -1.8100e-02,
         1.2737e-02, -1.2114e-02,  6.5649e-03, -5.9071e-03,  7.7365e-03,
        -7.4515e-03, -2.2001e-02, -2.3253e-02, -2.9996e-02,  1.1063e-03,
         1.4057e-03,  1.3215e-03, -3.7861e-02, -1.5023e-02, -7.1092e-03,
         1.2387e-02, -4.4001e-02,  3.0360e-02, -2.8778e-02,  3.1841e-02,
        -1.2114e-02,  4.1058e-02, -1.2814e-02,  3.9077e-02,  6.9941e-03,
         1.7139e-02,  3.9288e-02,  2.1338e-02,  2.7192e-02,  4.3974e-02,
         1.9053e-02,  9.1687e-04, -1.3872e-02,  2.6999e-02,  2.7631e-02,
        -2.2441e-02, -2.1365e-02, -3.3507e-02,  5.4755e-03,  2.4524e-02,
        -9.1414e-03,  3.3348e-02, -3.6404e-02, -1.5845e-02, -1.7853e-02,
        -2.6763e-02,  2.4977e-02, -2.2549e-02, -2.6055e-03, -1.4956e-02,
        -2.4629e-02, -2.9500e-02, -3.6520e-02, -3.1318e-02,  2.5301e-02,
        -4.1560e-02,  4.0596e-02, -3.5743e-02, -9.9402e-03, -3.8203e-02,
         1.8469e-02, -7.2222e-03,  8.7008e-03, -1.6842e-02,  1.3508e-03,
         2.8061e-02, -3.6841e-02, -4.1263e-02,  1.8805e-02, -4.1805e-02,
        -3.8119e-04, -1.6620e-02, -2.8509e-02, -4.1276e-02, -4.1390e-02,
        -1.6600e-02,  1.1227e-02, -9.8479e-03,  8.0011e-03, -3.1407e-02,
        -3.2109e-02,  3.0424e-02, -3.1924e-02, -3.0520e-05, -1.3426e-02,
        -3.4665e-02, -3.7141e-02, -8.8735e-03,  2.1064e-02,  1.1333e-02,
         3.4191e-02, -8.8482e-03,  1.2196e-02, -2.9521e-02, -7.7659e-03,
         2.9205e-02, -4.3007e-02], requires_grad=True)
linear_relu_stack.4.weight
Parameter containing:
tensor([[-0.0176, -0.0254,  0.0367,  ..., -0.0318, -0.0283,  0.0168],
        [-0.0305,  0.0373,  0.0211,  ..., -0.0173,  0.0382,  0.0340],
        [-0.0214, -0.0138,  0.0270,  ...,  0.0156, -0.0321, -0.0142],
        ...,
        [-0.0366,  0.0046, -0.0345,  ..., -0.0114, -0.0277, -0.0087],
        [-0.0342, -0.0080, -0.0343,  ...,  0.0110, -0.0043,  0.0092],
        [-0.0192,  0.0405, -0.0111,  ..., -0.0192, -0.0165,  0.0208]],
       requires_grad=True)
linear_relu_stack.4.bias
Parameter containing:
tensor([ 0.0355,  0.0054,  0.0438, -0.0149, -0.0191,  0.0248,  0.0221, -0.0388,
         0.0360,  0.0128], requires_grad=True)
上次编辑于:
贡献者: lisenjie757
已到达文章底部,欢迎留言、表情互动~
  • 赞一个
    0
    赞一个
  • 支持下
    0
    支持下
  • 有点酷
    0
    有点酷
  • 啥玩意
    0
    啥玩意
  • 看不懂
    0
    看不懂
评论
  • 按正序
  • 按倒序
  • 按热度
Powered by Waline v2.14.9