Data Augmentation
torchvision.transforms makes data augmentation straightforward, which helps reduce overfitting.
Example code:
import torchvision
from torchvision import transforms

transforms_train = torchvision.transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # horizontal flip
    transforms.RandomVerticalFlip(p=0.5),  # vertical flip
    transforms.RandomRotation(30),  # rotate by a random angle in [-30, 30] degrees
    transforms.RandomPerspective(distortion_scale=0.2, p=0.3),  # perspective transform
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),  # color jitter
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1), ratio=(3/4, 4/3)),  # random crop, then resize
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
transforms_test = torchvision.transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

Improving GPU Utilization and Memory Usage
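- When GPU memory usage is low, increasing batch_size raises it.
- When data loading is slower than model training, adjust the DataLoader parameters (such as num_workers and pin_memory) to remove the bottleneck. The relevant part of the signature is: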
class torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False, drop_last=False
)

timm Library
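The timm (PyTorch Image Models) library implements nearly all of the influential recent vision models. Besides pretrained weights, it also provides a solid code framework for distributed training and evaluation, which makes follow-up development easier.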
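torchinfo Library
The core of the torchinfo library is the summary function, which quickly reports a model's detailed structure and statistics: the layer hierarchy, input/output shapes, parameter counts, multiply-add operations, and other key figures. Example code: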
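import timm
from torch import nn
from torchinfo import summary

model_batch_size = 32  # batch size used for the output below
res_model = timm.create_model('resnest50d', pretrained=True)
res_model.fc = nn.Linear(2048, 176)
summary(res_model, input_size=(model_batch_size, 3, 224, 224))

The output is: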
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
ResNet -- --
├─Sequential: 1-1 [32, 64, 112, 112] --
│ └─Conv2d: 2-1 [32, 32, 112, 112] 864
│ └─BatchNorm2d: 2-2 [32, 32, 112, 112] 64
│ └─ReLU: 2-3 [32, 32, 112, 112] --
│ └─Conv2d: 2-4 [32, 32, 112, 112] 9,216
│ └─BatchNorm2d: 2-5 [32, 32, 112, 112] 64
│ └─ReLU: 2-6 [32, 32, 112, 112] --
│ └─Conv2d: 2-7 [32, 64, 112, 112] 18,432
├─BatchNorm2d: 1-2 [32, 64, 112, 112] 128
├─ReLU: 1-3 [32, 64, 112, 112] --
├─MaxPool2d: 1-4 [32, 64, 56, 56] --
├─Sequential: 1-5 [32, 256, 56, 56] --
│ └─ResNestBottleneck: 2-8 [32, 256, 56, 56] --
│ │ └─Conv2d: 3-1 [32, 64, 56, 56] 4,096
│ │ └─BatchNorm2d: 3-2 [32, 64, 56, 56] 128
│ │ └─ReLU: 3-3 [32, 64, 56, 56] --
│ │ └─SplitAttnConv2d: 3-4 [32, 64, 56, 56] 43,488
│ │ └─Conv2d: 3-5 [32, 256, 56, 56] 16,384
│ │ └─BatchNorm2d: 3-6 [32, 256, 56, 56] 512
│ │ └─Sequential: 3-7 [32, 256, 56, 56] 16,896
│ │ └─ReLU: 3-8 [32, 256, 56, 56] --
│ └─ResNestBottleneck: 2-9 [32, 256, 56, 56] --
│ │ └─Conv2d: 3-9 [32, 64, 56, 56] 16,384
│ │ └─BatchNorm2d: 3-10 [32, 64, 56, 56] 128
│ │ └─ReLU: 3-11 [32, 64, 56, 56] --
│ │ └─SplitAttnConv2d: 3-12 [32, 64, 56, 56] 43,488
│ │ └─Conv2d: 3-13 [32, 256, 56, 56] 16,384
│ │ └─BatchNorm2d: 3-14 [32, 256, 56, 56] 512
│ │ └─ReLU: 3-15 [32, 256, 56, 56] --
│ └─ResNestBottleneck: 2-10 [32, 256, 56, 56] --
│ │ └─Conv2d: 3-16 [32, 64, 56, 56] 16,384
│ │ └─BatchNorm2d: 3-17 [32, 64, 56, 56] 128
│ │ └─ReLU: 3-18 [32, 64, 56, 56] --
│ │ └─SplitAttnConv2d: 3-19 [32, 64, 56, 56] 43,488
│ │ └─Conv2d: 3-20 [32, 256, 56, 56] 16,384
│ │ └─BatchNorm2d: 3-21 [32, 256, 56, 56] 512
│ │ └─ReLU: 3-22 [32, 256, 56, 56] --
├─Sequential: 1-6 [32, 512, 28, 28] --
│ └─ResNestBottleneck: 2-11 [32, 512, 28, 28] --
│ │ └─Conv2d: 3-23 [32, 128, 56, 56] 32,768
│ │ └─BatchNorm2d: 3-24 [32, 128, 56, 56] 256
│ │ └─ReLU: 3-25 [32, 128, 56, 56] --
│ │ └─SplitAttnConv2d: 3-26 [32, 128, 56, 56] 172,992
│ │ └─AvgPool2d: 3-27 [32, 128, 28, 28] --
│ │ └─Conv2d: 3-28 [32, 512, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-29 [32, 512, 28, 28] 1,024
│ │ └─Sequential: 3-30 [32, 512, 28, 28] 132,096
│ │ └─ReLU: 3-31 [32, 512, 28, 28] --
│ └─ResNestBottleneck: 2-12 [32, 512, 28, 28] --
│ │ └─Conv2d: 3-32 [32, 128, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-33 [32, 128, 28, 28] 256
│ │ └─ReLU: 3-34 [32, 128, 28, 28] --
│ │ └─SplitAttnConv2d: 3-35 [32, 128, 28, 28] 172,992
│ │ └─Conv2d: 3-36 [32, 512, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-37 [32, 512, 28, 28] 1,024
│ │ └─ReLU: 3-38 [32, 512, 28, 28] --
│ └─ResNestBottleneck: 2-13 [32, 512, 28, 28] --
│ │ └─Conv2d: 3-39 [32, 128, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-40 [32, 128, 28, 28] 256
│ │ └─ReLU: 3-41 [32, 128, 28, 28] --
│ │ └─SplitAttnConv2d: 3-42 [32, 128, 28, 28] 172,992
│ │ └─Conv2d: 3-43 [32, 512, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-44 [32, 512, 28, 28] 1,024
│ │ └─ReLU: 3-45 [32, 512, 28, 28] --
│ └─ResNestBottleneck: 2-14 [32, 512, 28, 28] --
│ │ └─Conv2d: 3-46 [32, 128, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-47 [32, 128, 28, 28] 256
│ │ └─ReLU: 3-48 [32, 128, 28, 28] --
│ │ └─SplitAttnConv2d: 3-49 [32, 128, 28, 28] 172,992
│ │ └─Conv2d: 3-50 [32, 512, 28, 28] 65,536
│ │ └─BatchNorm2d: 3-51 [32, 512, 28, 28] 1,024
│ │ └─ReLU: 3-52 [32, 512, 28, 28] --
├─Sequential: 1-7 [32, 1024, 14, 14] --
│ └─ResNestBottleneck: 2-15 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-53 [32, 256, 28, 28] 131,072
│ │ └─BatchNorm2d: 3-54 [32, 256, 28, 28] 512
│ │ └─ReLU: 3-55 [32, 256, 28, 28] --
│ │ └─SplitAttnConv2d: 3-56 [32, 256, 28, 28] 690,048
│ │ └─AvgPool2d: 3-57 [32, 256, 14, 14] --
│ │ └─Conv2d: 3-58 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-59 [32, 1024, 14, 14] 2,048
│ │ └─Sequential: 3-60 [32, 1024, 14, 14] 526,336
│ │ └─ReLU: 3-61 [32, 1024, 14, 14] --
│ └─ResNestBottleneck: 2-16 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-62 [32, 256, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-63 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-64 [32, 256, 14, 14] --
│ │ └─SplitAttnConv2d: 3-65 [32, 256, 14, 14] 690,048
│ │ └─Conv2d: 3-66 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-67 [32, 1024, 14, 14] 2,048
│ │ └─ReLU: 3-68 [32, 1024, 14, 14] --
│ └─ResNestBottleneck: 2-17 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-69 [32, 256, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-70 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-71 [32, 256, 14, 14] --
│ │ └─SplitAttnConv2d: 3-72 [32, 256, 14, 14] 690,048
│ │ └─Conv2d: 3-73 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-74 [32, 1024, 14, 14] 2,048
│ │ └─ReLU: 3-75 [32, 1024, 14, 14] --
│ └─ResNestBottleneck: 2-18 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-76 [32, 256, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-77 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-78 [32, 256, 14, 14] --
│ │ └─SplitAttnConv2d: 3-79 [32, 256, 14, 14] 690,048
│ │ └─Conv2d: 3-80 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-81 [32, 1024, 14, 14] 2,048
│ │ └─ReLU: 3-82 [32, 1024, 14, 14] --
│ └─ResNestBottleneck: 2-19 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-83 [32, 256, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-84 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-85 [32, 256, 14, 14] --
│ │ └─SplitAttnConv2d: 3-86 [32, 256, 14, 14] 690,048
│ │ └─Conv2d: 3-87 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-88 [32, 1024, 14, 14] 2,048
│ │ └─ReLU: 3-89 [32, 1024, 14, 14] --
│ └─ResNestBottleneck: 2-20 [32, 1024, 14, 14] --
│ │ └─Conv2d: 3-90 [32, 256, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-91 [32, 256, 14, 14] 512
│ │ └─ReLU: 3-92 [32, 256, 14, 14] --
│ │ └─SplitAttnConv2d: 3-93 [32, 256, 14, 14] 690,048
│ │ └─Conv2d: 3-94 [32, 1024, 14, 14] 262,144
│ │ └─BatchNorm2d: 3-95 [32, 1024, 14, 14] 2,048
│ │ └─ReLU: 3-96 [32, 1024, 14, 14] --
├─Sequential: 1-8 [32, 2048, 7, 7] --
│ └─ResNestBottleneck: 2-21 [32, 2048, 7, 7] --
│ │ └─Conv2d: 3-97 [32, 512, 14, 14] 524,288
│ │ └─BatchNorm2d: 3-98 [32, 512, 14, 14] 1,024
│ │ └─ReLU: 3-99 [32, 512, 14, 14] --
│ │ └─SplitAttnConv2d: 3-100 [32, 512, 14, 14] 2,756,352
│ │ └─AvgPool2d: 3-101 [32, 512, 7, 7] --
│ │ └─Conv2d: 3-102 [32, 2048, 7, 7] 1,048,576
│ │ └─BatchNorm2d: 3-103 [32, 2048, 7, 7] 4,096
│ │ └─Sequential: 3-104 [32, 2048, 7, 7] 2,101,248
│ │ └─ReLU: 3-105 [32, 2048, 7, 7] --
│ └─ResNestBottleneck: 2-22 [32, 2048, 7, 7] --
│ │ └─Conv2d: 3-106 [32, 512, 7, 7] 1,048,576
│ │ └─BatchNorm2d: 3-107 [32, 512, 7, 7] 1,024
│ │ └─ReLU: 3-108 [32, 512, 7, 7] --
│ │ └─SplitAttnConv2d: 3-109 [32, 512, 7, 7] 2,756,352
│ │ └─Conv2d: 3-110 [32, 2048, 7, 7] 1,048,576
│ │ └─BatchNorm2d: 3-111 [32, 2048, 7, 7] 4,096
│ │ └─ReLU: 3-112 [32, 2048, 7, 7] --
│ └─ResNestBottleneck: 2-23 [32, 2048, 7, 7] --
│ │ └─Conv2d: 3-113 [32, 512, 7, 7] 1,048,576
│ │ └─BatchNorm2d: 3-114 [32, 512, 7, 7] 1,024
│ │ └─ReLU: 3-115 [32, 512, 7, 7] --
│ │ └─SplitAttnConv2d: 3-116 [32, 512, 7, 7] 2,756,352
│ │ └─Conv2d: 3-117 [32, 2048, 7, 7] 1,048,576
│ │ └─BatchNorm2d: 3-118 [32, 2048, 7, 7] 4,096
│ │ └─ReLU: 3-119 [32, 2048, 7, 7] --
├─SelectAdaptivePool2d: 1-9 [32, 2048] --
│ └─AdaptiveAvgPool2d: 2-24 [32, 2048, 1, 1] --
├─Linear: 1-10 [32, 176] 360,624
==========================================================================================
Total params: 25,794,864
Trainable params: 25,794,864
Non-trainable params: 0
Total mult-adds (G): 171.83
==========================================================================================
Input size (MB): 19.27
Forward/backward pass size (MB): 7350.32
Params size (MB): 103.18
Estimated Total Size (MB): 7472.76
==========================================================================================
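Several figures in this summary can be re-derived by hand, which is a useful sanity check when reading such output. The sketch below recomputes a few of them in plain Python. It assumes float32 tensors (4 bytes per element), sizes reported in units of 10^6 bytes, and a 3x3 bias-free kernel for the first stem convolution; these assumptions are inferred from the numbers rather than stated in the output.

```python
# Values copied from the summary output above.
batch_size = 32
total_params = 25_794_864

BYTES_PER_FLOAT32 = 4  # assumption: model and input use float32
MB = 1e6               # assumption: sizes are reported in units of 10^6 bytes

# Final Linear layer (2048 -> 176): one weight per input/output pair
# plus one bias per output.
fc_params = 2048 * 176 + 176

# First stem Conv2d (3 -> 32 channels); assumes a 3x3 kernel and no bias.
stem_conv_params = 3 * 32 * 3 * 3

# BatchNorm2d over 32 channels: one scale and one shift per channel.
stem_bn_params = 2 * 32

# Input tensor of shape (32, 3, 224, 224).
input_mb = batch_size * 3 * 224 * 224 * BYTES_PER_FLOAT32 / MB

# All parameters stored as float32.
params_mb = total_params * BYTES_PER_FLOAT32 / MB

print(f"fc params:        {fc_params:,}")         # 360,624
print(f"stem conv params: {stem_conv_params:,}")  # 864
print(f"stem bn params:   {stem_bn_params:,}")    # 64
print(f"input size (MB):  {input_mb:.2f}")        # 19.27
print(f"params size (MB): {params_mb:.2f}")       # 103.18
```

Each value agrees with the corresponding line of the summary, which also confirms that the run used a batch size of 32.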