Rinne's Blog


Data Augmentation

torchvision.transforms makes data augmentation straightforward, and augmentation helps reduce overfitting. Example:

```python
from torchvision import transforms

# Training pipeline: random augmentations, then tensor conversion and normalization
transforms_train = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # horizontal flip
    transforms.RandomVerticalFlip(p=0.5),    # vertical flip
    transforms.RandomRotation(30),           # random rotation in [-30, 30] degrees
    transforms.RandomPerspective(distortion_scale=0.2, p=0.3),  # perspective transform
    # Color jitter; hue must lie in [0, 0.5], so 0.5 is the strongest hue shift
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1), ratio=(3/4, 4/3)),  # crop and resize
    transforms.ToTensor(),
    # ImageNet channel mean/std
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Test pipeline: deterministic resize only, no random augmentation
transforms_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
```
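transforms_train contains random operations, so it should only be applied to training images, while validation images go through the deterministic transforms_test. One common way to do this is to split the raw dataset once and wrap each split with its own transform. The sketch below makes some assumptions: TransformedSubset and train_val_split are hypothetical helper names, and the base dataset is assumed to return (PIL image, label) pairs, for example a torchvision.datasets.ImageFolder created without a transform.

```python
import random

from torch.utils.data import Dataset


class TransformedSubset(Dataset):
    """Hypothetical helper: a view of `base` restricted to `indices`
    that applies `transform` to every image it returns.

    Assumes base[i] returns an (image, label) pair.
    """

    def __init__(self, base, indices, transform):
        self.base = base
        self.indices = list(indices)
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        img, label = self.base[self.indices[i]]
        return self.transform(img), label


def train_val_split(base, val_ratio, train_tf, val_tf, seed=0):
    """Hypothetical helper: shuffle indices once with a fixed seed, give the
    training part the random augmentations and the validation part the
    deterministic transform."""
    idx = list(range(len(base)))
    random.Random(seed).shuffle(idx)
    n_val = int(len(idx) * val_ratio)
    train = TransformedSubset(base, idx[n_val:], train_tf)
    val = TransformedSubset(base, idx[:n_val], val_tf)
    return train, val
```

With the pipelines above this would be called as train_set, val_set = train_val_split(raw_dataset, 0.2, transforms_train, transforms_test), where raw_dataset stands for your own untransformed dataset. Splitting indices rather than the dataset object keeps the two splits from sharing a single transform attribute.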

Improving GPU and GPU Memory Utilization

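  1. When GPU memory usage is low, increase batch_size to raise utilization.
  2. When data loading is slower than model training (the GPU sits idle waiting for the next batch), tune the DataLoader parameters, in particular num_workers and pin_memory. The relevant parameters and their defaults:
```python
from torch.utils.data import DataLoader

# DataLoader parameters that affect loading throughput, shown with their defaults:
# DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
#            pin_memory=False, drop_last=False)
# num_workers > 0 prepares batches in parallel worker processes;
# pin_memory=True uses page-locked host memory for faster copies to the GPU.
```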
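As a concrete starting point, here is a sketch of a loader tuned for throughput. make_loader is a hypothetical helper, and the cap of 8 workers is an assumed starting value to tune while watching GPU utilization, not a figure from the PyTorch docs. persistent_workers and prefetch_factor are real DataLoader options, but they are only valid when num_workers > 0, so they are added conditionally.

```python
import os

import torch
from torch.utils.data import DataLoader


def make_loader(dataset, batch_size=64, train=True, num_workers=None):
    """Hypothetical helper: build a DataLoader tuned for throughput.

    num_workers defaults to min(8, CPU count); the cap of 8 is an assumption.
    """
    if num_workers is None:
        num_workers = min(8, os.cpu_count() or 1)
    extra = {}
    if num_workers > 0:
        # Keep worker processes alive between epochs and let each worker
        # prepare 2 batches ahead (both options require num_workers > 0).
        extra = {"persistent_workers": True, "prefetch_factor": 2}
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=train,    # shuffle only the training set
        drop_last=train,  # drop the ragged last batch only during training
        num_workers=num_workers,
        # Pinned host memory only helps when batches are copied to a CUDA GPU
        pin_memory=torch.cuda.is_available(),
        **extra,
    )
```

When pin_memory is on, moving batches with x.to(device, non_blocking=True) lets the host-to-GPU copy overlap with computation.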

The timm Library

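The timm (PyTorch Image Models) library implements nearly every influential recent vision model. Besides pretrained weights, it also ships a solid code framework for distributed training and evaluation, which makes it a convenient base to build on.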

The torchinfo Library

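The core of the torchinfo library is its summary function, which quickly reports a model's detailed structure and statistics: the layer hierarchy, input/output shapes, parameter counts, the number of multiply-add operations (mult-adds), and more. For example, to inspect a timm ResNeSt-50 whose classifier is replaced with a 176-class head: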

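```python
import timm
import torch.nn as nn
from torchinfo import summary

model_batch_size = 32
res_model = timm.create_model('resnest50d', pretrained=True)
res_model.fc = nn.Linear(2048, 176)  # new classification head for 176 classes
summary(res_model, input_size=(model_batch_size, 3, 224, 224))
```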
The output (with model_batch_size = 32) is:
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   --                        --
├─Sequential: 1-1                        [32, 64, 112, 112]        --
│    └─Conv2d: 2-1                       [32, 32, 112, 112]        864
│    └─BatchNorm2d: 2-2                  [32, 32, 112, 112]        64
│    └─ReLU: 2-3                         [32, 32, 112, 112]        --
│    └─Conv2d: 2-4                       [32, 32, 112, 112]        9,216
│    └─BatchNorm2d: 2-5                  [32, 32, 112, 112]        64
│    └─ReLU: 2-6                         [32, 32, 112, 112]        --
│    └─Conv2d: 2-7                       [32, 64, 112, 112]        18,432
├─BatchNorm2d: 1-2                       [32, 64, 112, 112]        128
├─ReLU: 1-3                              [32, 64, 112, 112]        --
├─MaxPool2d: 1-4                         [32, 64, 56, 56]          --
├─Sequential: 1-5                        [32, 256, 56, 56]         --
│    └─ResNestBottleneck: 2-8            [32, 256, 56, 56]         --
│    │    └─Conv2d: 3-1                  [32, 64, 56, 56]          4,096
│    │    └─BatchNorm2d: 3-2             [32, 64, 56, 56]          128
│    │    └─ReLU: 3-3                    [32, 64, 56, 56]          --
│    │    └─SplitAttnConv2d: 3-4         [32, 64, 56, 56]          43,488
│    │    └─Conv2d: 3-5                  [32, 256, 56, 56]         16,384
│    │    └─BatchNorm2d: 3-6             [32, 256, 56, 56]         512
│    │    └─Sequential: 3-7              [32, 256, 56, 56]         16,896
│    │    └─ReLU: 3-8                    [32, 256, 56, 56]         --
│    └─ResNestBottleneck: 2-9            [32, 256, 56, 56]         --
│    │    └─Conv2d: 3-9                  [32, 64, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-10            [32, 64, 56, 56]          128
│    │    └─ReLU: 3-11                   [32, 64, 56, 56]          --
│    │    └─SplitAttnConv2d: 3-12        [32, 64, 56, 56]          43,488
│    │    └─Conv2d: 3-13                 [32, 256, 56, 56]         16,384
│    │    └─BatchNorm2d: 3-14            [32, 256, 56, 56]         512
│    │    └─ReLU: 3-15                   [32, 256, 56, 56]         --
│    └─ResNestBottleneck: 2-10           [32, 256, 56, 56]         --
│    │    └─Conv2d: 3-16                 [32, 64, 56, 56]          16,384
│    │    └─BatchNorm2d: 3-17            [32, 64, 56, 56]          128
│    │    └─ReLU: 3-18                   [32, 64, 56, 56]          --
│    │    └─SplitAttnConv2d: 3-19        [32, 64, 56, 56]          43,488
│    │    └─Conv2d: 3-20                 [32, 256, 56, 56]         16,384
│    │    └─BatchNorm2d: 3-21            [32, 256, 56, 56]         512
│    │    └─ReLU: 3-22                   [32, 256, 56, 56]         --
├─Sequential: 1-6                        [32, 512, 28, 28]         --
│    └─ResNestBottleneck: 2-11           [32, 512, 28, 28]         --
│    │    └─Conv2d: 3-23                 [32, 128, 56, 56]         32,768
│    │    └─BatchNorm2d: 3-24            [32, 128, 56, 56]         256
│    │    └─ReLU: 3-25                   [32, 128, 56, 56]         --
│    │    └─SplitAttnConv2d: 3-26        [32, 128, 56, 56]         172,992
│    │    └─AvgPool2d: 3-27              [32, 128, 28, 28]         --
│    │    └─Conv2d: 3-28                 [32, 512, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-29            [32, 512, 28, 28]         1,024
│    │    └─Sequential: 3-30             [32, 512, 28, 28]         132,096
│    │    └─ReLU: 3-31                   [32, 512, 28, 28]         --
│    └─ResNestBottleneck: 2-12           [32, 512, 28, 28]         --
│    │    └─Conv2d: 3-32                 [32, 128, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-33            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-34                   [32, 128, 28, 28]         --
│    │    └─SplitAttnConv2d: 3-35        [32, 128, 28, 28]         172,992
│    │    └─Conv2d: 3-36                 [32, 512, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-37            [32, 512, 28, 28]         1,024
│    │    └─ReLU: 3-38                   [32, 512, 28, 28]         --
│    └─ResNestBottleneck: 2-13           [32, 512, 28, 28]         --
│    │    └─Conv2d: 3-39                 [32, 128, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-40            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-41                   [32, 128, 28, 28]         --
│    │    └─SplitAttnConv2d: 3-42        [32, 128, 28, 28]         172,992
│    │    └─Conv2d: 3-43                 [32, 512, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-44            [32, 512, 28, 28]         1,024
│    │    └─ReLU: 3-45                   [32, 512, 28, 28]         --
│    └─ResNestBottleneck: 2-14           [32, 512, 28, 28]         --
│    │    └─Conv2d: 3-46                 [32, 128, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-47            [32, 128, 28, 28]         256
│    │    └─ReLU: 3-48                   [32, 128, 28, 28]         --
│    │    └─SplitAttnConv2d: 3-49        [32, 128, 28, 28]         172,992
│    │    └─Conv2d: 3-50                 [32, 512, 28, 28]         65,536
│    │    └─BatchNorm2d: 3-51            [32, 512, 28, 28]         1,024
│    │    └─ReLU: 3-52                   [32, 512, 28, 28]         --
├─Sequential: 1-7                        [32, 1024, 14, 14]        --
│    └─ResNestBottleneck: 2-15           [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-53                 [32, 256, 28, 28]         131,072
│    │    └─BatchNorm2d: 3-54            [32, 256, 28, 28]         512
│    │    └─ReLU: 3-55                   [32, 256, 28, 28]         --
│    │    └─SplitAttnConv2d: 3-56        [32, 256, 28, 28]         690,048
│    │    └─AvgPool2d: 3-57              [32, 256, 14, 14]         --
│    │    └─Conv2d: 3-58                 [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-59            [32, 1024, 14, 14]        2,048
│    │    └─Sequential: 3-60             [32, 1024, 14, 14]        526,336
│    │    └─ReLU: 3-61                   [32, 1024, 14, 14]        --
│    └─ResNestBottleneck: 2-16           [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-62                 [32, 256, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-63            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-64                   [32, 256, 14, 14]         --
│    │    └─SplitAttnConv2d: 3-65        [32, 256, 14, 14]         690,048
│    │    └─Conv2d: 3-66                 [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-67            [32, 1024, 14, 14]        2,048
│    │    └─ReLU: 3-68                   [32, 1024, 14, 14]        --
│    └─ResNestBottleneck: 2-17           [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-69                 [32, 256, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-70            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-71                   [32, 256, 14, 14]         --
│    │    └─SplitAttnConv2d: 3-72        [32, 256, 14, 14]         690,048
│    │    └─Conv2d: 3-73                 [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-74            [32, 1024, 14, 14]        2,048
│    │    └─ReLU: 3-75                   [32, 1024, 14, 14]        --
│    └─ResNestBottleneck: 2-18           [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-76                 [32, 256, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-77            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-78                   [32, 256, 14, 14]         --
│    │    └─SplitAttnConv2d: 3-79        [32, 256, 14, 14]         690,048
│    │    └─Conv2d: 3-80                 [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-81            [32, 1024, 14, 14]        2,048
│    │    └─ReLU: 3-82                   [32, 1024, 14, 14]        --
│    └─ResNestBottleneck: 2-19           [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-83                 [32, 256, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-84            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-85                   [32, 256, 14, 14]         --
│    │    └─SplitAttnConv2d: 3-86        [32, 256, 14, 14]         690,048
│    │    └─Conv2d: 3-87                 [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-88            [32, 1024, 14, 14]        2,048
│    │    └─ReLU: 3-89                   [32, 1024, 14, 14]        --
│    └─ResNestBottleneck: 2-20           [32, 1024, 14, 14]        --
│    │    └─Conv2d: 3-90                 [32, 256, 14, 14]         262,144
│    │    └─BatchNorm2d: 3-91            [32, 256, 14, 14]         512
│    │    └─ReLU: 3-92                   [32, 256, 14, 14]         --
│    │    └─SplitAttnConv2d: 3-93        [32, 256, 14, 14]         690,048
│    │    └─Conv2d: 3-94                 [32, 1024, 14, 14]        262,144
│    │    └─BatchNorm2d: 3-95            [32, 1024, 14, 14]        2,048
│    │    └─ReLU: 3-96                   [32, 1024, 14, 14]        --
├─Sequential: 1-8                        [32, 2048, 7, 7]          --
│    └─ResNestBottleneck: 2-21           [32, 2048, 7, 7]          --
│    │    └─Conv2d: 3-97                 [32, 512, 14, 14]         524,288
│    │    └─BatchNorm2d: 3-98            [32, 512, 14, 14]         1,024
│    │    └─ReLU: 3-99                   [32, 512, 14, 14]         --
│    │    └─SplitAttnConv2d: 3-100       [32, 512, 14, 14]         2,756,352
│    │    └─AvgPool2d: 3-101             [32, 512, 7, 7]           --
│    │    └─Conv2d: 3-102                [32, 2048, 7, 7]          1,048,576
│    │    └─BatchNorm2d: 3-103           [32, 2048, 7, 7]          4,096
│    │    └─Sequential: 3-104            [32, 2048, 7, 7]          2,101,248
│    │    └─ReLU: 3-105                  [32, 2048, 7, 7]          --
│    └─ResNestBottleneck: 2-22           [32, 2048, 7, 7]          --
│    │    └─Conv2d: 3-106                [32, 512, 7, 7]           1,048,576
│    │    └─BatchNorm2d: 3-107           [32, 512, 7, 7]           1,024
│    │    └─ReLU: 3-108                  [32, 512, 7, 7]           --
│    │    └─SplitAttnConv2d: 3-109       [32, 512, 7, 7]           2,756,352
│    │    └─Conv2d: 3-110                [32, 2048, 7, 7]          1,048,576
│    │    └─BatchNorm2d: 3-111           [32, 2048, 7, 7]          4,096
│    │    └─ReLU: 3-112                  [32, 2048, 7, 7]          --
│    └─ResNestBottleneck: 2-23           [32, 2048, 7, 7]          --
│    │    └─Conv2d: 3-113                [32, 512, 7, 7]           1,048,576
│    │    └─BatchNorm2d: 3-114           [32, 512, 7, 7]           1,024
│    │    └─ReLU: 3-115                  [32, 512, 7, 7]           --
│    │    └─SplitAttnConv2d: 3-116       [32, 512, 7, 7]           2,756,352
│    │    └─Conv2d: 3-117                [32, 2048, 7, 7]          1,048,576
│    │    └─BatchNorm2d: 3-118           [32, 2048, 7, 7]          4,096
│    │    └─ReLU: 3-119                  [32, 2048, 7, 7]          --
├─SelectAdaptivePool2d: 1-9              [32, 2048]                --
│    └─AdaptiveAvgPool2d: 2-24           [32, 2048, 1, 1]          --
├─Linear: 1-10                           [32, 176]                 360,624
==========================================================================================
Total params: 25,794,864
Trainable params: 25,794,864
Non-trainable params: 0
Total mult-adds (G): 171.83
==========================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 7350.32
Params size (MB): 103.18
Estimated Total Size (MB): 7472.76
==========================================================================================

References

  1. Leaf classification competition notes
  2. PyTorch Image Models
©
Tips for Training Models in Kaggle Projects
https://astro-pure.js.org/blog/learning/deep-learning/model-train-tips
Author: Rinne
Published: September 15, 2025