主页 > IT业界  > 

中药细粒度图像分类

中药细粒度图像分类

在细粒度图像分类(FGVC)领域,Bilinear CNN(BCNN)模型因其能够捕捉图像中的局部特征交互而受到广泛关注。该模型通过双线性池化操作将两个不同CNN提取的特征进行外积运算,从而获得更加丰富的特征表示,这对于区分外观相似但属于不同子类别的物体尤其有效。然而,BCNN通常计算成本较高,限制了其在移动设备或资源受限环境下的应用。

为了实现轻量化并保持高精度的细粒度分类,可以考虑将MobileNetV2引入到BCNN框架中。MobileNetV2以其深度可分离卷积和倒残差结构著称,能够在减少计算复杂度的同时保证较高的分类性能。此外,MobileNetV2中的线性瓶颈和逐点卷积有助于更有效地处理稀疏数据,进一步提升网络的表达能力。

在此基础上,添加Inception模块是一个值得探索的方向。Inception模块通过并行使用多种尺寸的卷积核,能够同时捕捉不同尺度的特征信息,这对于中药这种形态各异、纹理复杂的对象来说尤为重要。结合Inception模块的多尺度特征提取能力和MobileNetV2的高效架构,可以在不显著增加计算负担的前提下增强模型对细节特征的敏感度。

弱监督学习则允许我们仅依赖图像级别的标签来进行训练,无需精确的边界框或部分注释,这大大降低了标注成本,并使得大规模数据集的应用成为可能。特别是在中药分类这样一个需要大量专业知识才能准确标注的领域,弱监督方法能够显著降低专家标注的工作量。

好的我来讲代码部分

github /HaoMood/bilinear-cnn.git   基于这个大佬的代码改进

github /hackerjackL/xilidu.git   这是我的代码仓库

pip install torch torchvision pillow tqdm #理论上应该是这些 torch >=2.0 #1.0版本不行,有些函数和方法用不了 python 3.8

models.py 

import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as models from torchvision.models import MobileNet_V2_Weights class InceptionModule(nn.Module): def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): super(InceptionModule, self).__init__() self.branch1 = nn.Conv2d(in_channels, ch1x1, kernel_size=1) self.branch2 = nn.Sequential( nn.Conv2d(in_channels, ch3x3red, kernel_size=1), nn.Conv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) ) self.branch3 = nn.Sequential( nn.Conv2d(in_channels, ch5x5red, kernel_size=1), nn.Conv2d(ch5x5red, ch5x5, kernel_size=5, padding=2) ) self.branch4 = nn.Sequential( nn.MaxPool2d(kernel_size=3, stride=1, padding=1), nn.Conv2d(in_channels, pool_proj, kernel_size=1) ) def forward(self, x): branch1 = F.relu(self.branch1(x)) branch2 = F.relu(self.branch2(x)) branch3 = F.relu(self.branch3(x)) branch4 = F.relu(self.branch4(x)) return torch.cat([branch1, branch2, branch3, branch4], 1) class MobileNetV2Classifier(nn.Module): """Bilinear CNN Model using MobileNetV2""" def __init__(self, num_classes): super(MobileNetV2Classifier, self).__init__() model_urls = { 'mobilenet_v2': ' download.pytorch.org/models/mobilenet_v2-b0353104.pth', } mobilenet_v2 = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1) self.features = mobilenet_v2.features # Freeze the features layers for param in self.features.parameters(): param.requires_grad = False # Add a new classifier on top of the features self.classifier = nn.Sequential( nn.Dropout(0.5), nn.Linear(mobilenet_v2.last_channel, num_classes) ) def forward(self, x): x = self.features(x) x = x.mean([2, 3]) # Global average pooling x = self.classifier(x) return x InceptionModule

InceptionModule 类实现了一个经典的 Inception 模块,它可以在 GoogLeNet 等模型中找到。这个模块允许网络在多个尺度上并行处理信息。

构造函数 (__init__ 方法):

接受输入通道数以及各个分支的通道数配置。分支1: 使用一个 1x1 卷积来减少维度。分支2: 先通过一个 1x1 卷积进行降维,然后使用一个 3x3 卷积(带有 padding 来保持尺寸)。分支3: 类似于分支2,但使用的是 5x5 卷积。分支4: 首先进行 3x3 最大池化,然后通过一个 1x1 卷积调整通道数。

前向传播 (forward 方法):

对每个分支应用 ReLU 激活函数,并将它们的结果沿通道维度拼接起来。 MobileNetV2Classifier

MobileNetV2Classifier 类基于 MobileNet V2 模型,用于图像分类任务。它利用了预训练的 MobileNet V2 特征提取器,并在其基础上添加了一个新的分类头。

构造函数 (__init__ 方法):

加载预训练的 MobileNet V2 模型(使用 weights=MobileNet_V2_Weights.IMAGENET1K_V1 参数指定加载 ImageNet 上预训练的权重)。冻结特征提取层的参数(即设置 requires_grad=False),以便只训练新添加的分类层。添加了一个由 Dropout 层和全连接层组成的分类头。全连接层的输入大小是 MobileNet V2 的最后一个特征通道数(mobilenet_v2.last_channel),输出大小是类别数。

前向传播 (forward 方法):

输入数据首先通过 MobileNet V2 的特征提取层。然后对特征图进行全局平均池化(即将每个特征图缩减为单个数值),这一步通常用于将二维特征转换为一维特征向量。最终通过分类器得到最终的分类结果。

logger.py

这里我写了一些前作者没有的 例如显存 打印网络层 预计时间等 不过打印显存信息有误 请还是nvidia --smi 查看吧

import time import torch from tqdm import tqdm import os import csv def print_model_structure(model): print("Model structure:") print("----------------------------------------------------------------") for name, module in model.named_modules(): if name == '': continue indent = name.count('.') print(' ' * (4 * indent) + name + ':', str(module).split('\n')[0]) print("----------------------------------------------------------------") print("Model summary:") def log_training_info(epoch, train_loss, train_acc, val_acc, epoch_time, remaining_time, gpu_mem_used, writer): info = { "GPU Memory Used": f"{gpu_mem_used:.2f}MB", "Epoch": f"{epoch}", "Train Loss": f"{train_loss:.4f}", "Train Accuracy": f"{train_acc:.2f}%", "Validation Accuracy": f"{val_acc:.2f}%", "Epoch Time": f"{epoch_time:.1f}s", "Remaining Time": f"{remaining_time:.1f}s" } print("\t".join(info.values())) writer.writerow(info) class TrainingLogger: def __init__(self, epochs, result_file): self.pbar = tqdm(total=epochs, desc="Training", unit="epoch") self.result_file = result_file self.fieldnames = ["GPU Memory Used", "Epoch", "Train Loss", "Train Accuracy", "Validation Accuracy", "Epoch Time", "Remaining Time"] with open(self.result_file, mode='w', newline='') as file: self.writer = csv.DictWriter(file, fieldnames=self.fieldnames) self.writer.writeheader() def update_progress(self): self.pbar.update(1) def close_progress(self): self.pbar.close() def log_epoch_info(self, epoch, train_loss, train_acc, val_acc, epoch_time, remaining_time, gpu_mem_used): with open(self.result_file, mode='a', newline='') as file: writer = csv.DictWriter(file, fieldnames=self.fieldnames) log_training_info(epoch, train_loss, train_acc, val_acc, epoch_time, remaining_time, gpu_mem_used, writer)

config.py

这里存放了一些定义的超参 运行代码 例如

python train.py --data_dir /root/images --batch_size 32 --epochs 100

import argparse def get_config(): parser = argparse.ArgumentParser(description="Bilinear CNN Training") parser.add_argument("--data_dir", type=str, default="./images", help="Root directory of dataset (contains train/val folders)") parser.add_argument("--model_dir", type=str, default="./models", help="Directory to save trained models") parser.add_argument("--batch_size", type=int, default=32, help="Input batch size for training") parser.add_argument("--epochs", type=int, default=100, help="Number of epochs to train") parser.add_argument("--lr", type=float, default=0.001, help="Learning rate") parser.add_argument("--weight_decay", type=float, default=5e-4, help="Weight decay") parser.add_argument("--workers", type=int, default=4, help="Number of data loading workers") parser.add_argument("--optimizer", type=str, default="adam", choices=["sgd", "adam"], help="Optimizer to use (default: sgd)") parser.add_argument("--scheduler", type=str, default="reduce_on_plateau", choices=["reduce_on_plateau", "cosine_annealing"], help="Learning rate scheduler to use (default: reduce_on_plateau)") parser.add_argument("--patience", type=int, default=3, help="Patience for ReduceLROnPlateau scheduler (default: 3)") return parser.parse_args()

下面是一个trainer.py 这个是串联我们其他py的核心组件

1. 初始化方法 (__init__)

Trainer 类的初始化方法首先设置了一些基本参数,如设备类型(CPU 或 GPU)、训练目录等。它还定义了数据预处理的方式,并加载了训练和验证数据集。这里使用了 torchvision.transforms 来进行数据增强,包括随机水平翻转、随机裁剪等,以提高模型的泛化能力。

数据加载 使用 torchvision.datasets.ImageFolder 来加载数据集,该函数假设数据集按类别组织在不同的子文件夹中。数据预处理步骤包括调整大小、数据增强、转换为张量以及归一化处理。 模型创建 创建了一个 MobileNetV2Classifier 实例,如果存在多个 GPU,则使用 nn.DataParallel 来并行化模型训练。打印模型结构,方便调试和理解模型架构。 损失函数和优化器 定义了交叉熵损失函数 nn.CrossEntropyLoss(),适用于多分类问题。根据传入的参数选择合适的优化器(SGD 或 Adam),并且仅对分类头部分的参数进行优化。 学习率调度器 支持两种调度策略:ReduceLROnPlateau 和 CosineAnnealingLR,它们分别根据验证准确率的变化或按照余弦退火方式调整学习率。 2. 训练过程 (train 方法)

train 方法是整个训练流程的核心,它通过循环执行多次迭代(每个 epoch)来进行模型训练。每一轮迭代都包含以下几个步骤:

训练阶段:模型处于训练模式,对每个批次的数据进行前向传播计算损失,然后通过反向传播更新模型权重。验证阶段:切换到评估模式,不进行梯度计算,仅评估模型性能(准确率)。学习率调整:根据验证结果调整学习率。保存最佳模型:如果当前 epoch 的验证准确率优于历史最高值,则保存当前模型状态。

此外,还记录了每次迭代的训练损失、准确率及验证准确率,并通过 TrainingLogger 对象将其写入 CSV 文件以便后续分析。

3. 验证过程 (validate 方法)

validate 方法用于评估模型在验证集上的表现,它与训练阶段的主要区别在于不进行参数更新,而是单纯地计算模型预测的准确性。

4. 结果可视化 (plot_results 方法)

训练完成后,plot_results 方法会生成两张图表,一张展示训练过程中损失值的变化趋势,另一张则显示验证集上准确率的变化情况。这些图表有助于直观地了解模型的学习进程和性能改进情况。

5. 辅助函数

find_next_train_folder 函数用于自动查找下一个可用的训练目录名称,确保每次运行时都有独立的存储空间存放实验结果。

import os import time import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR from models import MobileNetV2Classifier from logger import print_model_structure, TrainingLogger import matplotlib.pyplot as plt class Trainer: def __init__(self, args): self.args = args self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 创建新的训练目录 self.train_dir = find_next_train_folder(self.args.model_dir) os.makedirs(self.train_dir, exist_ok=True) # 数据预处理 self.train_transform = transforms.Compose([ transforms.Resize(256), transforms.RandomHorizontalFlip(), transforms.RandomCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) self.val_transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) data_dir = args.data_dir # 修改这个基础路径 train_dir = os.path.join(data_dir, "train") val_dir = os.path.join(data_dir, "val") # 检查数据集是否存在 if not os.path.exists(train_dir) or not os.path.exists(val_dir): raise ValueError(f"数据集不存在,请检查{data_dir}文件夹是否正确。" f"\nExpected directories: {train_dir} and {val_dir}") self.train_dataset = torchvision.datasets.ImageFolder( root=train_dir, transform=self.train_transform ) self.val_dataset = torchvision.datasets.ImageFolder( root=val_dir, transform=self.val_transform ) # 打印类别数量 self.num_classes = len(self.train_dataset.classes) print(f"Number of classes: {self.num_classes}") if self.num_classes <= 1: raise ValueError(f"数据集中必须包含至少两个类别,当前只有 {self.num_classes} 个类别。") # 创建模型 self.model = MobileNetV2Classifier(self.num_classes).to(self.device) if torch.cuda.device_count() > 1: self.model = nn.DataParallel(self.model) # 打印模型结构 print_model_structure(self.model) # 定义损失函数和优化器 self.criterion = nn.CrossEntropyLoss() if args.optimizer == "sgd": self.optimizer = optim.SGD( self.model.module.classifier.parameters() if isinstance(self.model, nn.DataParallel) else self.model.classifier.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay ) elif args.optimizer == "adam": self.optimizer = optim.Adam( self.model.module.classifier.parameters() if isinstance(self.model, nn.DataParallel) else self.model.classifier.parameters(), lr=args.lr, weight_decay=args.weight_decay ) else: raise ValueError(f"Unsupported optimizer: {args.optimizer}") # 定义学习率调度器 if args.scheduler == "reduce_on_plateau": self.scheduler = ReduceLROnPlateau( self.optimizer, mode='max', factor=0.1, patience=args.patience, verbose=True ) elif args.scheduler == "cosine_annealing": self.scheduler = CosineAnnealingLR( self.optimizer, T_max=args.epochs, eta_min=0, last_epoch=-1, verbose=True ) else: raise ValueError(f"Unsupported scheduler: {args.scheduler}") # 数据加载器 self.train_loader = DataLoader( self.train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True ) self.val_loader = DataLoader( self.val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True ) # 初始化训练日志记录器 result_csv_path = os.path.join(self.train_dir, "result.csv") self.logger = TrainingLogger(self.args.epochs, result_csv_path) # 记录每个epoch的损失和准确率 self.train_losses = [] self.val_accuracies = [] def train(self): best_acc = 0.0 print(f"Starting training with {self.num_classes} classes...") print(f"GPU Rem\tEpoch\tTrain Loss\tTrain Acc\tVal Acc\tTime\tRemaining") for epoch in range(1, self.args.epochs + 1): # 从1开始计数 start_time = time.time() # 训练阶段 self.model.train() train_loss = 0.0 correct = 0 total = 0 for inputs, labels in self.train_loader: inputs = inputs.to(self.device) labels = labels.to(self.device) self.optimizer.zero_grad() outputs = self.model(inputs) loss = self.criterion(outputs, labels) loss.backward() self.optimizer.step() train_loss += loss.item() _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() train_loss /= len(self.train_loader) train_acc = 100.0 * correct / total # 验证阶段 val_acc = self.validate() # 学习率调整 if isinstance(self.scheduler, ReduceLROnPlateau): self.scheduler.step(val_acc) else: self.scheduler.step() # 保存最佳模型 if val_acc > best_acc: best_acc = val_acc model_path = os.path.join( self.train_dir, f"best_model_{val_acc:.2f}.pth" ) torch.save(self.model.state_dict(), model_path) # 获取GPU内存使用情况 gpu_mem_used = torch.cuda.memory_allocated(self.device) / 1e6 if torch.cuda.is_available() else 0 # 打印统计信息 epoch_time = time.time() - start_time remaining_time = (self.args.epochs - epoch) * epoch_time self.logger.log_epoch_info(epoch, train_loss, train_acc, val_acc, epoch_time, remaining_time, gpu_mem_used) # 更新tqdm进度条 self.logger.update_progress() # 记录损失和准确率 self.train_losses.append(train_loss) self.val_accuracies.append(val_acc) self.logger.close_progress() print(f"Best validation accuracy: {best_acc:.2f}%") # 绘制损失和准确率图 self.plot_results() def validate(self): self.model.eval() correct = 0 total = 0 with torch.no_grad(): for inputs, labels in self.val_loader: inputs = inputs.to(self.device) labels = labels.to(self.device) outputs = self.model(inputs) _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() acc = 100.0 * correct / total self.model.train() return acc def plot_results(self): epochs = list(range(1, self.args.epochs + 1)) plt.figure(figsize=(12, 6)) # 绘制训练损失 plt.subplot(1, 2, 1) plt.plot(epochs, self.train_losses, label='Training Loss') plt.xlabel('Epochs') plt.ylabel('Loss') plt.title('Training Loss over Epochs') plt.legend() # 绘制验证准确率 plt.subplot(1, 2, 2) plt.plot(epochs, self.val_accuracies, label='Validation Accuracy', color='orange') plt.xlabel('Epochs') plt.ylabel('Accuracy (%)') plt.title('Validation Accuracy over Epochs') plt.legend() # 保存图片 plot_path = os.path.join(self.train_dir, "training_results.png") plt.savefig(plot_path) plt.show() def find_next_train_folder(base_dir): i = 1 while True: folder_name = os.path.join(base_dir, f"train{i}") if not os.path.exists(folder_name): return folder_name i += 1

 最后就是运行脚本 主函数 train.py

from config import get_config from trainer import Trainer def main(): args = get_config() trainer = Trainer(args) trainer.train() if __name__ == "__main__": main()

好这就是我们的代码部分

请注意我们的数据集结构 

然后无需标注 只需划分train val即可

以上就是内容部分 如有问题请评论区或私信指正谢谢 !! 本科小白一枚 感谢观看! 

标签:

中药细粒度图像分类由讯客互联IT业界栏目发布,感谢您对讯客互联的认可,以及对我们原创作品以及文章的青睐,非常欢迎各位朋友分享到个人网站或者朋友圈,但转载请说明文章出处“中药细粒度图像分类