本文介绍: ”’数据预处理/数据增强(基于albumentations库)”’# 训练时增强# 最长边限制为imgSize# 参数:随机色调、饱和度、值变化# 随机明亮对比度# 高斯噪声A.OneOf([# 使用随机大小的内核将运动模糊应用于输入图像# 中值滤波# 使用随机大小的内核模糊输入图像], p=0.2),# 较短的边做padding],# 验证时增强# 最长边限制为imgSize# 较短的边做padding],
import模块
import numpy as np
import torch
from functools import partial
from PIL import Image
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import random
import albumentations as A
from pycocotools.coco import COCO
import os
import cv2
import matplotlib.pyplot as plt
基于albumentations
库自定义数据预处理/数据增强
class Transform():
'''数据预处理/数据增强(基于albumentations库)
'''
def __init__(self, imgSize):
maxSize = max(imgSize[0], imgSize[1])
# 训练时增强
self.trainTF = A.Compose([
A.BBoxSafeRandomCrop(p=0.5),
# 最长边限制为imgSize
A.LongestMaxSize(max_size=maxSize),
A.HorizontalFlip(p=0.5),
# 参数:随机色调、饱和度、值变化
A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, always_apply=False, p=0.5),
# 随机明亮对比度
A.RandomBrightnessContrast(p=0.2),
# 高斯噪声
A.GaussNoise(var_limit=(0.05, 0.09), p=0.4),
A.OneOf([
# 使用随机大小的内核将运动模糊应用于输入图像
A.MotionBlur(p=0.2),
# 中值滤波
A.MedianBlur(blur_limit=3, p=0.1),
# 使用随机大小的内核模糊输入图像
A.Blur(blur_limit=3, p=0.1),
], p=0.2),
# 较短的边做padding
A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=cv2.BORDER_CONSTANT, value=[0,0,0]),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
],
bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),
)
# 验证时增强
self.validTF = A.Compose([
# 最长边限制为imgSize
A.LongestMaxSize(max_size=maxSize),
# 较短的边做padding
A.PadIfNeeded(imgSize[0], imgSize[1], border_mode=0, mask_value=[0,0,0]),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
],
bbox_params=A.BboxParams(format='coco', min_area=0, min_visibility=0.1, label_fields=['category_ids']),
)
自定义数据集读取类COCODataset
实现
class COCODataset(Dataset):
def __init__(self, annPath, imgDir, inputShape=[800, 600], trainMode=True):
'''__init__() 为默认构造函数,传入数据集类别(训练或测试),以及数据集路径
Args:
:param annPath: COCO annotation 文件路径
:param imgDir: 图像的根目录
:param inputShape: 网络要求输入的图像尺寸
:param trainMode: 训练集/测试集
Returns:
FRCNNDataset
'''
self.mode = trainMode
self.tf = Transform(imgSize=inputShape)
self.imgDir = imgDir
self.annPath = annPath
self.DataNums = len(os.listdir(imgDir))
# 为实例注释初始化COCO的API
self.coco=COCO(annPath)
# 获取数据集中所有图像对应的imgId
self.imgIds = list(self.coco.imgs.keys())
def __len__(self):
'''重载data.Dataset父类方法, 返回数据集大小
'''
return len(self.imgIds)
def __getitem__(self, index):
'''重载data.Dataset父类方法, 获取数据集中数据内容
这里通过pycocotools来读取图像和标签
'''
# 通过imgId获取图像信息imgInfo: 例:{'id': 12465, 'license': 1, 'height': 375, 'width': 500, 'file_name': '2011_003115.jpg'}
imgId = self.imgIds[index]
imgInfo = self.coco.loadImgs(imgId)[0]
# 载入图像 (通过imgInfo获取图像名,得到图像路径)
image = Image.open(os.path.join(self.imgDir, imgInfo['file_name']))
image = np.array(image.convert('RGB'))
# 得到图像里包含的BBox的所有id
imgAnnIds = self.coco.getAnnIds(imgIds=imgId)
# 通过BBox的id找到对应的BBox信息
anns = self.coco.loadAnns(imgAnnIds)
# 获取BBox的坐标和类别
labels, boxes = [], []
for ann in anns:
labelName = ann['category_id']
labels.append(labelName)
boxes.append(ann['bbox'])
labels = np.array(labels)
boxes = np.array(boxes)
# 训练/验证时的数据增强各不相同
if(self.mode):
# albumentation的图像维度得是[W,H,C]
transformed = self.tf.trainTF(image=image, bboxes=boxes, category_ids=labels)
else:
transformed = self.tf.validTF(image=image, bboxes=boxes, category_ids=labels)
# 这里的box是coco格式(xywh)
image, box, label = transformed['image'], transformed['bboxes'], transformed['category_ids']
return image.transpose(2,0,1), np.array(box), np.array(label)
其他
# DataLoader中collate_fn参数使用
# 由于检测数据集每张图像上的目标数量不一
# 因此需要自定义的如何组织一个batch里输出的内容
def frcnn_dataset_collate(batch):
images = []
bboxes = []
labels = []
for img, box, label in batch:
images.append(img)
bboxes.append(box)
labels.append(label)
images = torch.from_numpy(np.array(images))
return images, bboxes, labels
# 设置Dataloader的种子
# DataLoader中worker_init_fn参数使
# 为每个 worker 设置了一个基于初始种子和 worker ID 的独特的随机种子, 这样每个 worker 将产生不同的随机数序列,从而有助于数据加载过程的随机性和多样性
def worker_init_fn(worker_id, seed):
worker_seed = worker_id + seed
random.seed(worker_seed)
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
# 固定全局随机数种子
def seed_everything(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
batch数据集可视化
def visBatch(dataLoader:DataLoader):
'''可视化训练集一个batch
Args:
dataLoader: torch的data.DataLoader
Retuens:
None
'''
catName = {1:'person', 2:'bicycle', 3:'car', 4:'motorcycle', 5:'airplane', 6:'bus',
7:'train', 8:'truck', 9:'boat', 10:'traffic light', 11:'fire hydrant',
13:'stop sign', 14:'parking meter', 15:'bench', 16:'bird', 17:'cat', 18:'dog',
19:'horse', 20:'sheep', 21:'cow', 22:'elephant', 23:'bear', 24:'zebra', 25:'giraffe',
27:'backpack', 28:'umbrella', 31:'handbag', 32:'tie', 33:'suitcase', 34:'frisbee',
35:'skis', 36:'snowboard', 37:'sports ball', 38:'kite', 39:'baseball bat',
40:'baseball glove', 41:'skateboard', 42:'surfboard', 43:'tennis racket',
44:'bottle', 46:'wine glass', 47:'cup', 48:'fork', 49:'knife', 50:'spoon', 51:'bowl',
52:'banana', 53:'apple', 54:'sandwich', 55:'orange', 56:'broccoli', 57:'carrot',
58:'hot dog', 59:'pizza', 60:'donut', 61:'cake', 62:'chair', 63:'couch',
64:'potted plant', 65:'bed', 67:'dining table', 70:'toilet', 72:'tv', 73:'laptop',
74:'mouse', 75:'remote', 76:'keyboard', 77:'cell phone', 78:'microwave',
79:'oven', 80:'toaster', 81:'sink', 82:'refrigerator', 84:'book', 85:'clock',
86:'vase', 87:'scissors', 88:'teddy bear', 89:'hair drier', 90:'toothbrush'}
for step, batch in enumerate(dataLoader):
images, boxes, labels = batch[0], batch[1], batch[2]
# 只可视化一个batch的图像:
if step > 0: break
# 图像均值
mean = np.array([0.485, 0.456, 0.406])
# 标准差
std = np.array([[0.229, 0.224, 0.225]])
plt.figure(figsize = (8,8))
for idx, imgBoxLabel in enumerate(zip(images, boxes, labels)):
img, box, label = imgBoxLabel
ax = plt.subplot(4,4,idx+1)
img = img.numpy().transpose((1,2,0))
# 由于在数据预处理时我们对数据进行了标准归一化,可视化的时候需要将其还原
img = img * std + mean
for instBox, instLabel in zip(box, label):
x, y, w, h = round(instBox[0]),round(instBox[1]), round(instBox[2]), round(instBox[3])
# 显示框
ax.add_patch(plt.Rectangle((x, y), w, h, color='blue', fill=False, linewidth=2))
# 显示类别
ax.text(x, y, catName[instLabel], bbox={'facecolor':'white', 'alpha':0.5})
plt.imshow(img)
# 在图像上方展示对应的标签
# 取消坐标轴
plt.axis("off")
# 微调行间距
plt.subplots_adjust(left=0.05, bottom=0.05, right=0.95, top=0.95, wspace=0.05, hspace=0.05)
plt.show()
example
# for test only:
if __name__ == "__main__":
# 固定随机种子
seed = 23
seed_everything(seed)
# BatcchSize
BS = 16
# 图像尺寸
imgSize = [800, 800]
trainAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_train2017.json"
testAnnPath = "E:/datasets/Universal/COCO2017/annotations/instances_val2017.json"
imgDir = "E:/datasets/Universal/COCO2017/train2017"
# 自定义数据集读取类
trainDataset = COCODataset(trainAnnPath, imgDir, imgSize, trainMode=True)
trainDataLoader = DataLoader(trainDataset, shuffle=True, batch_size = BS, num_workers=2, pin_memory=True,
collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))
# validDataset = COCODataset(testAnnPath, imgDir, imgSize, trainMode=False)
# validDataLoader = DataLoader(validDataset, shuffle=True, batch_size = BS, num_workers = 1, pin_memory=True,
# collate_fn=frcnn_dataset_collate, worker_init_fn=partial(worker_init_fn, seed=seed))
print(f'训练集大小 : {trainDataset.__len__()}')
visBatch(trainDataLoader)
for step, batch in enumerate(trainDataLoader):
images, boxes, labels = batch[0], batch[1], batch[2]
# torch.Size([bs, 3, 800, 800])
print(f'images.shape : {images.shape}')
# 列表形式,因为每个框里的实例数量不一,所以每个列表里的box数量不一
print(f'len(boxes) : {len(boxes)}')
# 列表形式,因为每个框里的实例数量不一,所以每个列表里的label数量不一
print(f'len(labels) : {len(labels)}')
break
输出
images.shape : torch.Size([16, 3, 800, 800])
len(boxes) : 16
len(labels) : 16
原文地址:https://blog.csdn.net/SESESssss/article/details/135723489
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.7code.cn/show_61443.html
如若内容造成侵权/违法违规/事实不符,请联系代码007邮箱:suwngjj01@126.com进行投诉反馈,一经查实,立即删除!
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。