项目要点
工业缺陷类别的分类:
-
# CR 裂纹: crackle # In 夹杂 inclusion # SC 划痕 scratch # PS 压入氧化皮 press in oxide scale # RS 麻点 # PA 斑点
一 数据预处理
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2 as cv2
# 创建数据集,集成pytorch的dataset类, 必须实现__len__ and __getitem__
# CR 裂纹: crackle
# In 夹杂 inclusion
# SC 划痕 scratch
# PS 压入氧化皮 press in oxide scale
# RS 麻点
# PA 斑点
defect_labels = ['In', 'Sc', 'Cr', 'PS', 'RS', 'Pa']
class SurfaceDefectDataset(Dataset):
def __init__(self, root_dir):
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
transforms.Resize((200,200))
])
img_files = os.listdir(root_dir)
self.defect_types = []
self.images = []
for file_name in img_files:
# 以下划线分隔文件名
defect_class = file_name.split('_')[0]
defect_index = defect_labels.index(defect_class)
self.images.append(os.path.join(root_dir, file_name))
self.defect_types.append(defect_index)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image_path = self.images

该文介绍了一个基于PyTorch的工业缺陷图像分类项目。首先,进行了数据预处理,包括图像读取、转换和resize,然后定义了数据集类。接着,导入ResNet模型并调整为6类缺陷分类。最后,展示了模型训练过程和测试代码,用于预测新的缺陷图像类别。

被折叠的 条评论
为什么被折叠?



