24- 工业缺陷检测 (Pytorch系列) (项目二十四)

该文介绍了一个基于PyTorch的工业缺陷图像分类项目。首先,进行了数据预处理,包括图像读取、转换和resize,然后定义了数据集类。接着,导入ResNet模型并调整为6类缺陷分类。最后,展示了模型训练过程和测试代码,用于预测新的缺陷图像类别。

项目要点

工业缺陷类别的分类:

  • # CR 裂纹: crackle
    # In 夹杂 inclusion
    # SC 划痕 scratch
    # PS 压入氧化皮 press in oxide scale
    # RS 麻点
    # PA 斑点
    


一 数据预处理

import os

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2 as cv2

# 创建数据集,集成pytorch的dataset类, 必须实现__len__ and __getitem__

# CR 裂纹: crackle
# In 夹杂 inclusion
# SC 划痕 scratch
# PS 压入氧化皮 press in oxide scale
# RS 麻点
# PA 斑点
defect_labels = ['In', 'Sc', 'Cr', 'PS', 'RS', 'Pa']

class SurfaceDefectDataset(Dataset):
    def __init__(self, root_dir):
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            transforms.Resize((200,200))
        ])
        img_files = os.listdir(root_dir)
        self.defect_types = []
        self.images = []

        for file_name in img_files:
            # 以下划线分隔文件名
            defect_class = file_name.split('_')[0]
            defect_index = defect_labels.index(defect_class)
            self.images.append(os.path.join(root_dir, file_name))
            self.defect_types.append(defect_index)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值