基于pyqt的深度学习的森林火灾智能检测系统

一、课题背景与意义

1.1 课题背景

森林火灾是全球范围内影响生态环境、公共安全以及经济发展的重大灾害之一。每年,森林火灾在多个国家和地区造成了极大的人员伤亡、财产损失以及生态破坏。以美国加利福尼亚州、澳大利亚和中国的森林火灾为例,这些地区经历了极其严重的森林火灾,火灾发生后往往带来的是资源浪费、环境污染和生态环境破坏等一系列后果。森林火灾的发生往往受到气候条件、森林管理不善、突发性火源等多方面因素的影响,因此尽早准确地检测火灾的发生,对于减少损失、提高应急响应能力具有至关重要的意义。

传统的火灾监测手段主要依赖人工巡逻和遥感卫星数据的采集,但这些方法往往存在响应滞后和人力成本高的问题。近年来,随着人工智能、计算机视觉及深度学习技术的不断进步,基于图像识别的火灾检测系统逐渐崭露头角,能够实现实时监测、自动识别,并提供有效的火灾预警,已成为火灾监测和防控的一个新趋势。

图像识别技术,尤其是基于卷积神经网络(CNN)等深度学习算法,已经在许多领域展现出了强大的能力。在火灾监测领域,借助深度学习的图像分类技术,可以从森林区域的图像数据中快速提取火灾的相关特征,从而实现自动化的火灾检测,极大提高了监测的效率和准确度。

1.2 课题意义

本课题旨在设计并实现一个基于深度学习的森林火灾智能检测系统,该系统通过图像识别技术,结合卷积神经网络(CNN)等深度学习模型,能够对上传的森林图像进行火灾检测和分类。系统通过图像分析判断当前图像中是否存在火灾,并根据分类结果返回火灾与非火灾的预测概率,帮助应急响应部门及时获取火灾信息并采取有效措施。

本课题的意义主要体现在以下几个方面:

  1. 提高火灾识别的准确性与效率:传统的火灾检测方法依赖人工巡逻和遥感卫星数据,存在响应时间滞后和数据更新不及时的问题。而基于深度学习的图像识别技术能够快速、准确地识别森林图像中的火灾信息,实时性和准确性得到显著提升。
  2. 减少人工干预,提高自动化水平:该系统通过自动化图像识别,减少了人工巡查的成本和人力资源浪费,能够在较短的时间内完成大面积区域的火灾检测。
  3. 帮助生态环境保护与应急响应:实时、准确的火灾检测能够为林业部门提供决策支持,及时发现火灾热点区域,快速响应火灾发生,减轻生态环境损害,减少人员伤亡和财产损失。
  4. 推动人工智能技术在自然灾害防控中的应用:本课题基于深度学习的森林火灾检测系统,展示了人工智能技术在自然灾害防控中的潜力,具有较高的研究价值和应用前景。

二、国内外研究现状

2.1 国内研究现状

在国内,森林火灾监测和预警的研究起步较早。传统的火灾检测方法主要依靠人工巡逻、卫星遥感等方式。随着遥感技术的发展,利用卫星图像进行火灾检测成为一种常见的手段。然而,传统遥感技术受限于数据获取的时效性和精度,且无法实现实时检测。为了提高火灾检测的准确性和效率,近年来,越来越多的国内研究开始关注深度学习和计算机视觉技术的应用。

例如,华南农业大学的研究团队提出了一种基于卷积神经网络(CNN)和无人机图像的森林火灾检测方法,通过无人机获取森林火灾的高分辨率图像,并使用深度学习模型进行分析,取得了较好的火灾识别效果。此外,北京大学和中科院的部分研究团队也在火灾检测的领域做出了贡献,提出了一些基于图像识别的火灾检测方法。

然而,现有的研究还存在一些问题。首先,现有的火灾图像数据集相对较小,模型训练难度较大,且缺乏多样化的图像样本。其次,火灾识别的准确性和实时性仍然存在一定的提升空间,尤其是在不同光照条件和复杂背景下的火灾图像识别任务。

2.2 国外研究现状

在国外,火灾检测技术的研究和应用相对较早且成熟,尤其是美国、欧洲和日本等发达国家,已形成了一定规模的森林火灾监测体系。美国的NASA、NOAA等机构,利用遥感卫星和航空监测手段,通过大范围数据采集和分析,实现了大范围森林火灾的监测。然而,这些方法通常存在图像分辨率较低、无法实时监控等问题。

随着深度学习技术的飞速发展,越来越多的研究开始将深度学习模型应用于火灾检测。例如,加州大学的研究团队开发了一种基于卷积神经网络(CNN)的火灾检测系统,能够通过高分辨率的卫星图像对火灾进行自动检测。该系统通过训练卷积神经网络模型,能够识别图像中的火灾区域,检测精度和速度上都有了较大提升。

另外,国外还开展了多项关于火灾数据集的建设工作。许多公共数据集(如Firesense、FireNet等)提供了大量标注的火灾图像数据,为火灾检测的研究提供了基础支持。通过这些数据集,研究人员能够训练和优化深度学习模型,提高火灾检测的准确性和泛化能力。

2.3 发展趋势与不足

随着深度学习技术的不断发展,火灾图像识别技术也呈现出良好的应用前景。基于卷积神经网络(CNN)的火灾检测方法,尤其是在处理复杂环境和图像质量差的情况时,具有较高的鲁棒性。近年来,国内外的研究逐渐从传统的图像处理方法转向基于深度学习的自动化检测。

然而,当前的研究还存在一些不足:

  1. 数据集的稀缺性:虽然已经有一些公开的火灾数据集,但这些数据集普遍较小且多为简单场景,缺乏复杂环境下的火灾图像,影响了模型的泛化能力。
  2. 火灾检测的实时性和精度:虽然基于深度学习的火灾检测方法在精度上取得了显著进展,但在实际应用中,模型的推理速度和实时性仍然是一个挑战,尤其是在大规模区域和高分辨率图像的处理上。
  3. 环境变化的适应性:深度学习模型对于环境变化和光照变化的适应性仍然存在一定的挑战,模型训练时需要考虑更多样化的图像样本和复杂的背景信息。

本课题将在现有研究基础上,使用深度学习方法,针对森林火灾图像进行更高效、准确的检测。

三、研究目标与任务

3.1 研究目标

本课题的主要目标是设计并实现一个基于深度学习的森林火灾智能检测系统,能够通过图像识别技术对森林火灾进行准确分类,并为用户提供实时火灾预警。具体目标如下:

  1. 火灾图像数据集构建:收集和整理森林火灾图像数据集,进行数据预处理与增强,确保训练数据的多样性和有效性。
  2. 深度学习模型设计与训练:设计基于卷积神经网络(CNN)的深度学习模型,利用迁移学习技术,提升模型的训练效率和准确性。
  3. 火灾检测与分类功能:实现火灾图像的自动分类,能够识别火灾和非火灾场景,并根据预测结果输出相关信息。
  4. 系统实现与优化:开发一个用户友好的图形界面,通过PyQt5实现火灾图像的上传、分类结果展示等功能,优化系统性能,保证实时性和准确性。

3.2 研究任务

  1. 数据集收集与预处理:构建火灾图像数据集,进行图像的标注、裁剪、旋转、亮度调整等增强操作,提升训练数据的多样性。
  2. 深度学习模型训练与调优:基于ResNet50或EfficientNet等预训练模型,进行迁移学习,优化模型性能。
  3. 图形用户界面开发:使用PyQt5框架开发用户界面,支持用户上传图像、查看分类结果及置信度。
  4. 系统集成与性能优化:将深度学习模型与图形界面系统进行集成,优化图像处理与分类速度,确保系统在实际环境下的稳定运行。

四、研究内容与技术路线

4.1 研究内容

  1. 数据集的构建与预处理
  • 收集火灾图像,确保数据集的多样性。
  • 进行数据增强,提高训练样本的丰富性。
  1. 深度学习模型的设计与训练
  • 使用卷积神经网络(CNN)进行图像分类任务。
  • 采用迁移学习,使用预训练模型(如ResNet50)进行微调,提升训练效率。
  1. 系统开发与实现
  • 使用PyQt5开发图形界面,提供火灾图像上传、结果展示等功能。
  1. 性能优化与测试
  • 优化模型推理速度,确保系统实时性。
  • 进行系统测试,确保在不同环境下的稳定性和准确性。

4.2 技术路线

  1. 数据收集与预处理:
  • 收集公开数据集或自采集火灾图像,进行标注和增强。
  1. 深度学习模型训练:
  • 基于CNN模型进行火灾图像的训练和分类。
  1. 用户界面设计与实现:
  • 使用PyQt5实现系统的图形用户界面。
  1. 系统集成与优化:
  • 将深度学习模型与PyQt5界面进行整合,确保系统高效运行。

五、预期成果与创新点

5.1 预期成果

  1. 设计并实现森林火灾智能检测系统:基于深度学习的火灾图像分类模型,能够准确检测火灾场景。
  2. 用户友好的界面:开发支持火灾图像上传和结果展示的图形用户界面。
  3. 提升火灾检测的准确性与效率:通过深度学习技术,提高火灾检测的准确性,并优化系统性能,确保实时性。

5.2 创新点

  1. 基于深度学习的火灾图像识别:通过深度学习技术进行火灾图像分类,提升火灾检测的准确性。
  2. 迁移学习与数据增强:通过迁移学习和数据增强方法,提升训练数据的多样性,减少数据不足的问题。
  3. 用户交互界面与系统集成:设计并实现图形用户界面,使得普通用户也能方便地进行火灾图像识别。

六、研究计划与进度安排

6.1

  • 第1周至第2周:数据集收集与预处理,进行数据增强与标准化处理。
  • 第3周至第5周:设计并训练深度学习模型,进行模型调试与测试。
  • 第6周至第7周:开发图形用户界面,完成上传、显示结果等功能。
  • 第8周至第9周:系统集成与性能优化,确保系统稳定运行。
  • 第10周:完成毕业论文撰写,并整理研究成果。

七、参考文献

  1. 李明,张强,《基于卷积神经网络的火灾图像识别》,《计算机应用》2019年第3期,pp. 45-51。
  2. 刘辉,陈杰,《深度学习在灾害检测中的应用研究综述》,《人工智能》2020年第4期,pp. 102-110。
  3. 张超,王涛,《PyQt5开发桌面应用实战》,机械工业出版社,2018年。

核心设计部分(仅供学习和参考)

以下是一个基于 PyQt5深度学习森林火灾智能检测系统 的核心程序实现。该程序结合了深度学习模型和图形用户界面(GUI)来进行火灾图像检测。具体包括以下几个模块:

  1. 用户登录与注册功能
  2. 图像上传功能
  3. 基于深度学习的火灾检测功能
  4. 结果展示功能

以下是整个系统的部分详细代码实现。

1. 项目结构

/forest_fire_detection_system
    ├── main.py                # 主程序
    ├── auth_manager.py        # 用户认证模块
    ├── model_loader.py        # 深度学习模型模块
    ├── upload_widget.py       # 图像上传模块
    ├── login_dialog.py        # 登录界面
    ├── register_dialog.py     # 注册界面
    ├── fire_model.pth         # 训练好的火灾检测模型
    ├── class_names.txt        # 火灾与非火灾分类标签
    └── users.db               # 用户数据库

2. 主程序 (main.py)

import sys
from PyQt5.QtWidgets import QApplication, QMainWindow
from auth_manager import AuthManager
from model_loader import FireDetector
from upload_widget import ImageUploadWidget

class FireDetectionApp(QMainWindow):
    def __init__(self):
        super().__init__()
        self.auth = AuthManager()
        self.model = FireDetector()
        self.init_ui()
        self.check_login()

    def init_ui(self):
        self.setWindowTitle('森林火灾智能检测系统')
        self.resize(800, 600)
        
        self.upload_widget = ImageUploadWidget(self.model)
        self.setCentralWidget(self.upload_widget)

    def check_login(self):
        if not self.auth.is_logged_in:
            from login_dialog import LoginDialog
            dlg = LoginDialog(self.auth)
            if not dlg.exec_():
                sys.exit()

if __name__ == '__main__':
    app = QApplication(sys.argv)
    window = FireDetectionApp()
    window.show()
    sys.exit(app.exec_())

3. 用户认证模块 (auth_manager.py)

import sqlite3
from hashlib import sha256

class AuthManager:
    def __init__(self):
        self.conn = sqlite3.connect('users.db')
        self._create_table()
        self.current_user = None

    def _create_table(self):
        self.conn.execute('''CREATE TABLE IF NOT EXISTS users
             (id INTEGER PRIMARY KEY AUTOINCREMENT,
              username TEXT UNIQUE NOT NULL,
              password TEXT NOT NULL)''')

    def register(self, username, password):
        try:
            hashed = sha256(password.encode()).hexdigest()
            self.conn.execute("INSERT INTO users (username, password) VALUES (?, ?)",
                            (username, hashed))
            self.conn.commit()
            return True
        except sqlite3.IntegrityError:
            return False

    def login(self, username, password):
        hashed = sha256(password.encode()).hexdigest()
        cursor = self.conn.execute("SELECT * FROM users WHERE username=? AND password=?",
                                 (username, hashed))
        if cursor.fetchone():
            self.current_user = username
            return True
        return False

    @property
    def is_logged_in(self):
        return self.current_user is not None

4. 深度学习模型模块 (model_loader.py)

import torch
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torchvision.models as models

class FireDetector:
    def __init__(self, model_path='fire_model.pth'):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = torch.load(model_path, map_location=self.device)
        self.model.eval()

        self.transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225])
        ])
        
        with open('class_names.txt', 'r') as f:
            self.classes = [line.strip() for line in f]

    def predict(self, image_path):
        img = Image.open(image_path).convert('RGB')
        img_t = self.transform(img).unsqueeze(0).to(self.device)
        
        with torch.no_grad():
            outputs = self.model(img_t)
            _, preds = torch.max(outputs, 1)
            probs = torch.nn.functional.softmax(outputs, dim=1)[0] * 100
            
        return {
            'class': self.classes[preds[0]],
            'confidence': float(probs[preds[0]]),
            'all_probs': {c: float(p) for c, p in zip(self.classes, probs)}
        }

5. 图像上传与分类模块 (upload_widget.py)

from PyQt5.QtWidgets import QWidget, QVBoxLayout, QLabel, QPushButton, QFileDialog, QScrollArea
from PyQt5.QtGui import QPixmap

class ImageUploadWidget(QWidget):
    def __init__(self, classifier):
        super().__init__()
        self.classifier = classifier
        self.init_ui()

    def init_ui(self):
        self.layout = QVBoxLayout()
        
        self.upload_btn = QPushButton('上传森林火灾图片')
        self.upload_btn.clicked.connect(self.open_file)
        
        self.image_label = QLabel()
        self.image_label.setFixedSize(400, 300)
        
        self.result_area = QScrollArea()
        self.result_widget = QWidget()
        self.result_layout = QVBoxLayout(self.result_widget)
        
        self.layout.addWidget(self.upload_btn)
        self.layout.addWidget(self.image_label)
        self.layout.addWidget(self.result_area)
        self.setLayout(self.layout)

    def open_file(self):
        fname, _ = QFileDialog.getOpenFileName(self, '选择森林火灾图片', 
                                              '', 'Images (*.png *.jpg *.jpeg)')
        if fname:
            pixmap = QPixmap(fname)
            self.image_label.setPixmap(pixmap.scaled(
                self.image_label.size(), 
                aspectRatioMode=True))
            
            result = self.classifier.predict(fname)
            self.display_result(result)

    def display_result(self, result):
        # 清空之前的结果
        while self.result_layout.count():
            child = self.result_layout.takeAt(0)
            if child.widget():
                child.widget().deleteLater()
        
        # 显示新结果
        self.result_layout.addWidget(QLabel(f"预测种类: {result['class']}"))
        self.result_layout.addWidget(QLabel(f"置信度: {result['confidence']:.2f}%"))
        
        if result['confidence'] < 70:
            warning = QLabel("⚠️ 低置信度结果,请谨慎参考!")
            warning.setStyleSheet("color: red; font-weight: bold;")
            self.result_layout.addWidget(warning)
        
        self.result_area.setWidget(self.result_widget)

6. 登录对话框 (login_dialog.py)

from PyQt5.QtWidgets import QDialog, QVBoxLayout, QLabel, QLineEdit, QPushButton, QMessageBox

class LoginDialog(QDialog):
    def __init__(self, auth_manager):
        super().__init__()
        self.auth = auth_manager
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle('用户登录')
        layout = QVBoxLayout()
        
        self.username = QLineEdit()
        self.username.setPlaceholderText('用户名')
        
        self.password = QLineEdit()
        self.password.setPlaceholderText('密码')
        self.password.setEchoMode(QLineEdit.Password)
        
        login_btn = QPushButton('登录')
        login_btn.clicked.connect(self.do_login)
        
        register_btn = QPushButton('注册新用户')
        register_btn.clicked.connect(self.show_register)
        
        layout.addWidget(QLabel('森林火灾智能检测系统登录'))
        layout.addWidget(self.username)
        layout.addWidget(self.password)
        layout.addWidget(login_btn)
        layout.addWidget(register_btn)
        
        self.setLayout(layout)

    def do_login(self):
        if self.auth.login(self.username.text(), self.password.text()):
            self.accept()
        else:
            QMessageBox.warning(self, '错误', '用户名或密码错误!')

    def show_register(self):
        from register_dialog import RegisterDialog
        dlg = RegisterDialog(self.auth)
        dlg.exec_()

7. 注册对话框 (register_dialog.py)

from PyQt5.QtWidgets import QDialog, QVBoxLayout, QLabel, QLineEdit, QPushButton, QMessageBox

class RegisterDialog(QDialog):
    def __init__(self, auth_manager):
        super().__init__()
        self.auth = auth_manager
        self.init_ui()

    def init_ui(self):
        self.setWindowTitle('用户注册')
        layout = QVBoxLayout()
        
        self.username = QLineEdit()
        self.username.setPlaceholderText('用户名')
        
        self.password = QLineEdit()
        self.password.setPlaceholderText('密码')
        self.password.setEchoMode(QLineEdit.Password)
        
        register_btn = QPushButton('注册')
        register_btn.clicked.connect(self.do_register)
        
        layout.addWidget(QLabel('请输入注册信息'))
        layout.addWidget(self.username)
        layout.addWidget(self.password)
        layout.addWidget(register_btn)
        
        self.setLayout(layout)

    def do_register(self):
        if self.auth.register(self.username.text(), self.password.text()):
            QMessageBox.information(self, '成功', '注册成功!')
            self.accept()
        else:
            QMessageBox.warning(self, '错误', '用户名已存在!')

8. 模型训练

import torch
import torchvision.models as models
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 数据预处理
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])

# 数据集加载
train_data = datasets.ImageFolder(root='fire_train', transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

# 加载预训练模型
model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # 2类:火灾与非火灾

# 使用交叉熵损失与Adam优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
model.train()
for epoch in range(10):  # 训练10个epoch
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    print(f"Epoch [{epoch+1}/10], Loss: {loss.item():.4f}")

# 保存模型
torch.save(model.state_dict(), 'fire_model.pth')

9. 运行与部署

  1. 依赖安装
pip install pyqt5 torch torchvision pillow
  1. 运行系统
python main.py

算法核心:

此代码包含以下几部分:

  1. 数据预处理与增强
  2. 模型设计与训练
  3. 推理过程
  4. 性能评估

我们将使用 PyTorch 来构建深度学习模型,使用 ResNet50 作为基础模型,并进行迁移学习,以适应森林火灾图像分类任务。

1. 数据预处理与增强

为了提高模型的鲁棒性,我们需要进行数据增强。数据增强的目的是生成更多的训练样本,以帮助模型更好地泛化。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 数据增强与预处理
transform = transforms.Compose([
    transforms.Resize(256),                # 将图片统一大小
    transforms.CenterCrop(224),            # 中心裁剪,去除边缘无关区域
    transforms.RandomHorizontalFlip(),     # 随机水平翻转,数据增强
    transforms.RandomRotation(30),         # 随机旋转,数据增强
    transforms.ToTensor(),                 # 转化为Tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])  # 标准化(ImageNet标准)
])

# 加载训练数据集和验证数据集
train_data = datasets.ImageFolder(root='fire_train', transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

val_data = datasets.ImageFolder(root='fire_val', transform=transform)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)

2. 模型设计与训练

使用 ResNet50 预训练模型,并根据火灾识别任务对模型进行微调。我们将使用迁移学习来利用 ResNet50 的预训练权重,然后训练新的分类层。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

# 定义分类模型
class FireDetector(nn.Module):
    def __init__(self, num_classes=2):
        super(FireDetector, self).__init__()
        
        # 使用ResNet50预训练模型
        self.resnet = models.resnet50(pretrained=True)
        
        # 修改输出层,适应我们的分类任务
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, num_classes)  # 假设有2类:火灾与非火灾
    
    def forward(self, x):
        return self.resnet(x)

# 初始化模型
model = FireDetector(num_classes=2)  # 2类:火灾与非火灾

# 使用GPU进行训练(如果可用)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# 交叉熵损失与Adam优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
def train_model(model, train_loader, val_loader, num_epochs=10):
    best_acc = 0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()  # 清零梯度
            outputs = model(inputs)  # 前向传播
            loss = criterion(outputs, labels)  # 计算损失
            loss.backward()  # 反向传播
            optimizer.step()  # 更新权重

            running_loss += loss.item()

            # 计算准确率
            _, preds = torch.max(outputs, 1)
            correct_predictions += torch.sum(preds == labels).item()
            total_predictions += labels.size(0)
        
        # 每个epoch的训练损失和准确率
        train_loss = running_loss / len(train_loader)
        train_acc = correct_predictions / total_predictions

        # 在验证集上评估
        model.eval()  # 设置为评估模式
        correct_predictions = 0
        total_predictions = 0
        with torch.no_grad():  # 不需要计算梯度
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                correct_predictions += torch.sum(preds == labels).item()
                total_predictions += labels.size(0)
        
        val_acc = correct_predictions / total_predictions

        # 输出每个epoch的训练信息
        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}, "
              f"Validation Accuracy: {val_acc:.4f}")
        
        # 保存模型,若验证集准确率有所提升
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_fire_model.pth')

# 训练模型
train_model(model, train_loader, val_loader, num_epochs=10)

3. 推理过程

训练完成后,我们可以使用训练好的模型进行火灾图像的预测。该过程将加载图像,进行预处理,并使用模型进行推理。

from PIL import Image
import torch
from torchvision import transforms

# 模型推理
def predict_image(image_path, model):
    # 加载并预处理图像
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                             std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # 转为batch维度,并移到GPU(如果可用)

    # 模型推理
    model.eval()
    with torch.no_grad():
        outputs = model(image)
        _, preds = torch.max(outputs, 1)  # 获取预测类别
        probs = torch.nn.functional.softmax(outputs, dim=1)[0] * 100  # 计算每类的概率

    # 获取分类标签
    class_names = ['Non-Fire', 'Fire']  # 火灾与非火灾的分类标签
    predicted_class = class_names[preds.item()]
    confidence = probs[preds.item()].item()

    return predicted_class, confidence

# 预测新图像
image_path = 'path_to_fire_image.jpg'
predicted_class, confidence = predict_image(image_path, model)
print(f"Predicted Class: {predicted_class}, Confidence: {confidence:.2f}%")

4. 性能评估

为了评估模型的性能,我们可以使用准确率、混淆矩阵等指标来评估模型的表现。这里我们计算并可视化混淆矩阵。

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

def evaluate_model(model, val_loader):
    model.eval()
    all_preds = []
    all_labels = []

    # 计算所有预测与标签
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    # Flatten the lists of predictions and labels
    all_preds = np.concatenate(all_preds)
    all_labels = np.concatenate(all_labels)

    # 计算混淆矩阵
    cm = confusion_matrix(all_labels, all_preds)

    # 绘制混淆矩阵热图
    plt.figure(figsize=(10, 7))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=train_data.classes, yticklabels=train_data.classes)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()

# 评估模型
evaluate_model(model, val_loader)

总结

  • 数据预处理与增强:我们使用了 transforms 进行图像的标准化、随机旋转、随机翻转等增强操作,以提高模型的泛化能力。
  • 模型设计与训练:基于 ResNet50 模型进行迁移学习,并微调了输出层来适应火灾与非火灾的二分类任务。
  • 推理与可用性:使用训练好的模型进行图像预测,输出火灾与非火灾的分类标签及置信度。
  • 评估与优化:通过计算混淆矩阵和准确率等指标评估模型的性能,并提供热图进行可视化。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

源码空间站TH

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值