Advanced Intelligent Computing Technology and Applications: 19th International Conference, ICIC 2023, Zhengzhou, China, August 10–13, 2023, ... II (Lecture Notes in Computer Science, 14087) [1st ed. 2023] 9819947413, 9789819947416

This three-volume set of LNCS 14086, LNCS 14087 and LNCS 14088 constitutes - in conjunction with the double-volume set L

English Pages 834 [827] Year 2023

Table of contents
Preface
Organization
Contents – Part II
Neural Networks
Adversarial Ensemble Training by Jointly Learning Label Dependencies and Member Models
1 Introduction
2 Preliminary
2.1 Notations
2.2 Definitions
2.3 Adversarial Attacks
2.4 On Techniques to Generate Soft Labels
3 CLDL Assisted Ensemble Training
3.1 Diversity Promoting Loss Design
3.2 CLDL Based Ensemble Model Training
4 Experiments
4.1 Datasets and Competitor Methods
4.2 Optimizer Used for Training
4.3 White-Box Attacks
4.4 Black-Box Attacks
4.5 Experimental Results for Black-box Adversarial Attacks
5 Conclusion
References
PFGE: Parsimonious Fast Geometric Ensembling of DNNs
1 Introduction
2 Related Works
3 The Proposed PFGE Algorithm
4 Experiments
4.1 Experimental Setting
4.2 CIFAR Datasets
4.3 IMAGENET
4.4 Performance of Separate Models in The Ensemble
4.5 On Training Efficiency and Test-time Cost
4.6 Mode Connectivity Test
5 Conclusions
References
Research on Indoor Positioning Algorithm Based on Multimodal and Attention Mechanism
1 Introduction
2 Related Work
3 Multimodal Indoor Location Network
3.1 Embedding
3.2 Location Activate Unit
3.3 SoftReLU Activation Function
4 Experiment
4.1 Data Collection
4.2 Evaluation Indicators
4.3 Experimental Comparison Scheme
5 Conclusion
References
Undetectable Attack to Deep Neural Networks Without Using Model Parameters
1 Introduction
2 Preliminaries
2.1 Convolutional Neural Networks
2.2 Image Perturbation
2.3 Undirected Attack
2.4 Directed Attack
3 Algorithm
3.1 Beetle Antennae Search (BAS)
3.2 Optimization Algorithm
4 Experimental Studies
4.1 Evaluation Methodology
4.2 Results
5 Conclusions
References
A Traffic Flow Prediction Framework Based on Clustering and Heterogeneous Graph Neural Networks
1 Introduction
2 Related Work
3 Methodology
3.1 Problem Definition
3.2 Framework Overview
3.3 Clustering Module
3.4 Graph Transformer Module
3.5 Spatio-Temporal Feature Learning Module
4 Experiments and Results
4.1 Dataset
4.2 Experiment Setting
4.3 Analysis of Experimental Results
5 Conclusion
References
Effective Audio Classification Network Based on Paired Inverse Pyramid Structure and Dense MLP Block
1 Introduction
2 Methods
2.1 Paired Inverse Pyramid Structure
2.2 Positional Modelling for Time Domain and Depth Domain
2.3 Temporal MLP Based on Feed-Forward Structure
2.4 Depth Domain Block and Linear Skip Connection
2.5 Dense MLP
3 Experiments
3.1 Overview on the UrbanSound8K Dataset
3.2 Overview on the GTZAN Dataset
3.3 Training setup and Preprocessing for the Datasets
3.4 UrbanSound8K Experiments Results
3.5 GTZAN Experiments Results
3.6 Ablation Study
4 Discussion
5 Conclusion
References
Dynamic Attention Filter Capsule Network for Medical Images Segmentation
1 Introduction
2 Related Work
2.1 Capsule Network
2.2 Medical Image Segmentation
3 Proposed Method
3.1 DAF-CapsNet
3.2 DAF-CapsUNet Architecture
3.3 Object Function
4 Experiments
4.1 Datasets and Training Details
4.2 Comparison with Other CapsNets
4.3 Ablation Study on DAF
4.4 Complexity Analysis of CapsNet
4.5 Robustness to Adversarial Examples
4.6 Experiment Results on Synapse Dataset
4.7 Experiment Results on ACDC Dataset
5 Conclusion
References
Cross-Scale Dynamic Alignment Network for Reference-Based Super-Resolution
1 Introduction
2 Related Work
2.1 Single Image Super-Resolution
2.2 Reference-Based Super-Resolution
3 Proposed Method
3.1 Network Architecture
3.2 Improved Texture Matching Module
3.3 Cross-Scale Dynamic Correction Module
4 Experiments
4.1 Experimental Setup
4.2 Evaluation
4.3 Ablation Study
5 Conclusion
References
Solving Large-Scale Open Shop Scheduling Problem via Link Prediction Based on Graph Convolution Network
1 Introduction
2 Related Work
3 Preliminaries
3.1 Open Shop Scheduling Problem
3.2 Link Prediction and Disjunctive Graph
4 The Proposed Method
4.1 Overall Process of Solving OSSP Based on GCN-LP
4.2 Node Feature Design
4.3 GCN-Based Open Shop Scheduling Model
4.4 Open Shop Scheduling Algorithm Based on GCN-LP
5 Experiment
5.1 Experimental Settings
5.2 Compared with Meta-heuristic Algorithms
5.3 Compared with Other GNN Models
5.4 Analysis of the Solution Efficiency
6 Conclusions
References
Solving Class Imbalance Problem in Target Detection with a Squared Cross Entropy Based Method
1 Introduction
2 Focal Loss
3 Squared Cross Entropy
4 Experiments
4.1 Comparison Between CE and SCE Losses
4.2 SCE in Target Detection
5 Conclusion
References
Modeling Working Memory Using Convolutional Neural Networks for Knowledge Tracing
1 Introduction
2 Related Works
2.1 Knowledge Tracing
2.2 Working Memory Model
2.3 Convolutional Neural Networks
3 Model
3.1 Problem Definition
3.2 Exercise and Response Representation
3.3 Simulating Working Memory Model
3.4 Long-Term Memory Attention Network
3.5 Knowledge State and Prediction
4 Experiments
4.1 Experimental Settings
4.2 Overall Performance Prediction
4.3 Ablation Study
4.4 Effectiveness of Modeling WM with CNNs
4.5 Comparison in Terms of Training Details
4.6 Visualization of Knowledge States
5 Conclusions
References
A Multi-granularity Decision Fusion Method Based on Category Hierarchy
1 Introduction
2 Related Work
3 Method
3.1 Category Hierarchy
3.2 Decision Fusion
4 Experiment
4.1 Datasets
4.2 Category Hierarchy
4.3 Network Architecture
5 Results
5.1 Fusion Effects
5.2 Compared with the Original Network
5.3 Comparison with Other Methods
6 Conclusion
References
Make Active Attention More Active: Using Lipschitz Regularity to Improve Long Sequence Time-Series Forecasting
1 Introduction
2 Related Work
3 Methodology
3.1 How Informer Works?
3.2 Borrowing Ideas from GNNs
3.3 Implementation Details and Theoretical Analysis
4 Experiments
4.1 Dataset
4.2 Experimental Setup
4.3 Performance Comparison and Analysis
4.4 Does Our Method Work Only on Sparse Attention?
5 Conclusion
References
Attributed Multi-relational Graph Embedding Based on GCN
1 Introduction
2 Related Work
3 Problem Statement
4 The Proposed Model
4.1 Intra-relation Aggregation
4.2 Inter-relation Aggregation
4.3 Training
5 Experiments
5.1 Dataset
5.2 Baselines
5.3 Performance Analysis
5.4 Parameters Effects
5.5 Ablation Study
6 Conclusion
References
CharCaps: Character-Level Text Classification Using Capsule Networks
1 Introduction
2 Related Work
3 Proposed Method
3.1 Formalization
3.2 Feature Extraction Layer
3.3 Capsule Layer
3.4 Selection Mechanism of Special Characters
4 Experiment
4.1 Datasets and Baselines
4.2 Experimental Results
4.3 Ablation Experiments
4.4 Different Network Layers
5 Conclusions
References
Multi-student Collaborative Self-supervised Distillation
1 Introduction
2 Related Work
3 Rethinking Knowledge Distillation
3.1 Motivation and Overview
3.2 Loss Function
3.3 Multi-student Self-supervised Distillation
3.4 Multi-student Adaptive Inference
4 Experiment
4.1 Datasets and Metrics
4.2 Distillation with Isomerism and Heterogeneous
4.3 Comparison of Other Methods
4.4 Single Student and Multi-student Distillation Comparison
4.5 Distillation in Text Detection
4.6 Ablation
5 Discussion and Conclusion
References
Speech Emotion Recognition Using Global-Aware Cross-Modal Feature Fusion Network
1 Introduction
2 Proposed Methodology
2.1 Problem Statement
2.2 Feature Encoder
2.3 Residual Cross-Modal Fusion Attention Module
2.4 Global-Aware Fusion
2.5 CTC Layer
3 Experimental Evaluation
3.1 Dataset
3.2 Experimental Setup
3.3 Ablation Studies
3.4 Comparative Analysis
4 Conclusion
References
MT-1DCG: A Novel Model for Multivariate Time Series Classification
1 Introduction
2 Deep Learning Methods
3 Data Preprocessing and Augmentation
3.1 Dataset
3.2 Preprocessing
3.3 Data Augmentation
4 Proposed Models
4.1 A-BiGRU
4.2 ST-1DCG
4.3 MT-1DCG
5 Experiments
5.1 Comparison of Proposed Models
5.2 Comparison Between MT-1DCG and Other External Models
6 Conclusion
References
Surrogate Modeling for Soliton Wave of Nonlinear Partial Differential Equations via the Improved Physics-Informed Deep Learning
1 Introduction
2 Method
2.1 Physics-Informed Neural Networks
2.2 The Gradient-Enhanced Physics-Informed Neural Networks with Self-adaptive Loss Function
3 Experiments and Results
3.1 The Single Soliton Solution of CDGSK Equation
3.2 The Multi-soliton Solution of CDGSK Equation
4 Conclusion
References
TAHAR: A Transferable Attention-Based Adversarial Network for Human Activity Recognition with RFID
1 Introduction
2 Related Work
2.1 RFID-Based Activity Recognition
2.2 Self-attention Method
2.3 Domain Adversarial Adaptation
3 Detailed Design of TAHAR
3.1 Signal Preprocessing
3.2 Feature Extractor
3.3 Self-attention Module and Batch Spectral Penalization
3.4 Activity Predictor
3.5 Domain Discriminator
3.6 Optimization Objective
4 Experiments
4.1 General Performance
4.2 Ablation Experiments
4.3 Comparative Experiments
4.4 Feature Visualization
5 Conclusion
References
Pattern Recognition
Improved Blind Image Denoising with DnCNN
1 Introduction
2 Proposed Method
3 Experiments
4 Conclusions
References
Seizure Prediction Based on Hybrid Deep Learning Model Using Scalp Electroencephalogram
1 Introduction
2 Related Work
3 Proposed Methods
3.1 Method Overview
3.2 CHB-MIT EEG Dataset
3.3 Preprocessing
3.4 Deep Learning Models
3.5 Postprocessing
4 Result
5 Discussion
6 Conclusion
References
Data Augmentation for Environmental Sound Classification Using Diffusion Probabilistic Model with Top-K Selection Discriminator
1 Introduction
2 Method
2.1 DPMs
2.2 DPM-Solver++ and DPM-Solver
2.3 Data Augmentation
2.4 Top-k Selection Pretrained Discriminator
2.5 DL Models for ESC
3 Experiments
3.1 Experiments Pipeline and Dataset
3.2 Hyperparameters Setting
3.3 Experiments Results
4 Conclusion
References
Improved DetNet Algorithm Based on GRU for Massive MIMO Systems
1 Introduction
2 System Model
2.1 Real-valued Models for MIMO Systems
2.2 DetNet
3 Improved DetNet Model
3.1 GRU-DetNet
3.2 Hybrid-DetNet
4 Simulation Results Analysis
4.1 Simulation Settings
4.2 Simulation Results
5 Conclusions
References
Epileptic Seizure Detection Based on Feature Extraction and CNN-BiGRU Network with Attention Mechanism
1 Introduction
2 EEG Database
3 Methods
3.1 Pre-processing
3.2 Feature Extraction
3.3 Classification Model (CNN-BiGRU with Attention Mechanism Model)
3.4 Post-processing
4 Results
5 Discussion
6 Conclusion
References
Information Potential Based Rolling Bearing Defect Classification and Diagnosis
1 Introduction
2 Information Potential for Bearing Fault Detection
2.1 The Definition of Information Potential
2.2 Feature Extraction and Classification for Rolling Bearing
2.3 Main Steps of Information Potential Based Bearing Fault Classification
3 Experimental Verification of Bearing Fault Detection
4 Conclusions
References
Electrocardiogram Signal Noise Reduction Application Employing Different Adaptive Filtering Algorithms
1 Introduction
2 Noise Cancellation
2.1 Least-Mean-Square (LMS)
2.2 Normalized Least-Mean-Square (NLMS)
2.3 Recursive-Least-Squares (RLS)
2.4 Wiener Filter
3 Experimental Results
3.1 Input Signal Corrupted with Noise
3.2 Audio Signal Corrupted with Echo
3.3 Effect of Input Parameters
3.4 Electrocardiogram (ECG) Signal Noise Reduction
4 Conclusion
References
Improving the Accuracy of Deep Learning Modelling Based on Statistical Calculation of Mathematical Equations
1 Introduction
2 Related Works
2.1 Fuzzy Neural Network
2.2 Deep Convolution Neural Network
3 Statistical Calculation of Mathematical Equations
4 Metrics for Model Evaluation
5 Results
6 Conclusion
References
Deep Learning for Cardiotocography Analysis: Challenges and Promising Advances
1 Introduction
2 Taxonomy
2.1 Methodology
2.2 Machine Learning and Deep Learning
3 Deep Learning Driven Research
3.1 Artificial Neural Network
3.2 Convolutional Neural Network
3.3 Recurrent Neural Network
3.4 Hybrid Network
3.5 Others
4 Results
5 Recommendation for Future Research
5.1 Explainable AI
5.2 Transfer Learning
5.3 Generative Deep Learning
6 Conclusion
References
Exploiting Active-IRS by Maximizing Throughput in Wireless Powered Communication Networks
1 Introduction
2 System Model and Problem Formulation
2.1 Energy Transmission
2.2 Information Transmission
2.3 Optimization Problem
3 Proposed Solution
3.1 Find d* by Fixing 0, and u
3.2 Find u* by Fixing 0, and d
4 Performance Evaluation
5 Conclusion
References
Palmprint Recognition Utilizing Modified LNMF Method
1 Introduction
2 The LNMF Algorithm
3 The Modified LNMF Algorithm
3.1 The Cost Function
3.2 The Updating Rules
4 RBPNN Classifier
5 Experimental Results and Analysis
5.1 Test Data Preprocessing
5.2 Learning of Feature Bases
5.3 Recognition Results
6 Conclusions
References
PECA-Net: Pyramidal Attention Convolution Residual Network for Architectural Heritage Images Classification
1 Introduction
2 Related Work
2.1 Heritage Buildings Classification
2.2 Multi-scale Convolution
2.3 Attention Mechanism
3 PECA-Net
3.1 The Overall Framework of PECA-Net
3.2 Pyramidal Convolution
3.3 Dual-Pooling ECA Attention Mechanism Module
4 Experiments and Results
4.1 Datasets
4.2 Data Augmentation
4.3 Experimental Setup and Evaluation Metrics
4.4 Impact of Insertion Position on DP-ECA Module Performance
4.5 Comparison with Other Models
4.6 Ablation Experiment
5 Conclusion
References
Convolutional Self-attention Guided Graph Neural Network for Few-Shot Action Recognition
1 Introduction
2 Related Work
3 Method
3.1 Problem Formulation
3.2 Feature Embedding
3.3 Sequence Distance Learning Module
3.4 Training and Inference
4 Experiment
4.1 Datasets
4.2 Implementation Details
4.3 Experimental Results
4.4 Ablation Study
5 Conclusion
References
DisGait: A Prior Work of Gait Recognition Concerning Disguised Appearance and Pose
1 Introduction
2 Related Work
3 The DisGait Dataset
3.1 Data Collection Protocol
3.2 Data Statistics
3.3 Gait Sequence Extraction
4 Baselines on DisGait
4.1 SOTA Methods
4.2 Evaluation Protocol
5 Experiments
5.1 Experiment 1: Baseline Results
5.2 Experiment 2: Different Frames
5.3 Experiment 3: Different Body Parts
6 Conclusion
References
Image Processing
An Experimental Study on MRI Denoising with Existing Image Denoising Methods
1 Introduction
2 An Experimental Study
3 Conclusions and Future Research
4 Data Availability Statement
References
Multi-scale and Self-mutual Feature Distillation
1 Introduction
2 The Proposed Approach
2.1 Self-mutual Feature Distillation
2.2 Multi-scale Feature Distillation
2.3 Overall Optimisation Objective
3 Experiment
3.1 CIFAR100 Classification
3.2 ImageNet Classification
3.3 Semantic Segmentation
3.4 Ablation Study
4 Conclusion
References
A Novel Multi-task Architecture for Vanishing Point Assisted Road Segmentation and Guidance in Off-Road Environments
1 Introduction
2 Related Work
2.1 Semantic Segmentation
2.2 Multi-task Learning
2.3 Learning Task Inter-dependencies
3 Method
3.1 The Proposed VPrs-Net
4 Experiments
4.1 Datasets
4.2 Implementation Details
4.3 Metrics
4.4 Loss
4.5 The Results of Road Segmentation
4.6 The Results of VP Detection
4.7 Ablation Study
5 Conclusions
References
DCNet: Glass-Like Object Detection via Detail-Guided and Cross-Level Fusion
1 Introduction
2 Method
2.1 Label Decoupling
2.2 Multi-scale Detail Interaction Module
2.3 Body-Induced Cross-Level Fusion Module
2.4 Attention-Induced Aggregation Module
2.5 Loss Function
3 Experiments
3.1 Datasets and Evaluation Metrics
3.2 Comparisons with State-of-the-Art Methods
3.3 Ablation Studies
4 Conclusion
References
A Method for Detecting and Correcting Specular Highlights in Capsule Endoscope Images Based on Independent Cluster Distribution
1 Introduction
2 Related Work
2.1 Highlight Detection of Capsule Endoscope Images
2.2 Highlight Correction of Capsule Endoscope Images
3 Proposed Method
3.1 Method Overview
3.2 Threshold Adaptive Highlight Detection for Color Channel Differences
3.3 Highlight Correction for Region Clustering
4 Experiments
4.1 Highlight Detection
4.2 Highlight Detection
4.3 Highlight Detection
4.4 Highlight Correction
5 Conclusion
References
Siamese Adaptive Template Update Network for Visual Tracking
1 Introduction
2 Related Works
3 Proposed Methods
3.1 Siamese Backbone Network
3.2 Adaptive Template Update Module
3.3 Enhanced Multi-attention Module
3.4 Training Details and Structure
4 Experiments
4.1 Implementation Details
4.2 Results on GOT-10K
4.3 Results on Other Datasets
4.4 Ablation Experiment
5 Conclusion
References
Collaborative Encoder for Accurate Inversion of Real Face Image
1 Introduction
2 Method
2.1 Motivations
2.2 Encoder Architecture
2.3 Losses
3 Experiment and Results
3.1 Dataset
3.2 Metrics
3.3 Quantitative Evaluation
3.4 Qualitative Evaluation
4 Application
5 Conclusion
References
Text-Guided Generative Adversarial Network for Image Emotion Transfer
1 Introduction
2 Related Works
2.1 Image Emotion Transfer
2.2 Text-Guided Image Generation and Manipulation
3 Text-Drived Emotional Generative Adversarial Network
3.1 Disentangled Representations of NHC and ELF
3.2 Loss Functions
3.3 Implementation Details
4 Experiments
4.1 Image Dataset
4.2 Qualitative Results
4.3 Quantitative Evaluation
5 Conclusion
References
What Constitute an Effective Edge Detection Algorithm?
1 Introduction
1.1 Canny Edge Detection
1.2 Kovalevsky Edge Detection
2 Machine Learning Techniques
2.1 Edge Detection Method Based on Multi-scale Feature Fusion
3 Conclusion
References
SporeDet: A Real-Time Detection of Wheat Scab Spores
1 Introduction
2 Related Work
2.1 Spore Detection
2.2 Object Detection
3 Proposed Method
3.1 The Base Framework
3.2 Task-Decomposition Channel Attention Head
3.3 Feature Reconstruction Loss
3.4 Overall Optimization
4 Experiments
4.1 Dataset and Evaluation Metrics
4.2 Implementation Details
4.3 Comparisons with Existing Methods
4.4 Ablation Studies
5 Conclusions and Future Work
References
Micro-expression Recognition Based on Dual-Branch Swin Transformer Network
1 Introduction
2 Related Works
3 Proposed Method
3.1 Motion Feature Extraction Using Optical Flow
3.2 Spatial Feature Representation Based on Swin Transformer
4 Experiments and Evaluation
4.1 Databases
4.2 Experimental Results and Analysis
5 Conclusion
References
CC-DBNet: A Scene Text Detector Combining Collaborative Learning and Cascaded Feature Fusion
1 Introduction
2 Methodology
2.1 Overall Architecture
2.2 Intra-instance Collaborative Learning
2.3 Cascaded Feature Fusion Module
2.4 Differentiable Binarization
2.5 Label Generation
2.6 Optimization
3 Experiments
3.1 Datasets
3.2 Implementation Details
3.3 Evaluation Metrics
3.4 Ablation Study
3.5 Comparisons with Previous Methods
4 Conclusion
References
A Radar Video Compression and Display Method Based on FPGA
1 Background
2 System Composition
3 Compression in a Single Frame
3.1 Brightness Data Conversion
3.2 Process of Point Selection
4 Compression in Interframe
5 Display of Radar Video
6 A Morphological Method for Eliminating Isolated Points
7 Conclusion
References
SA-GAN: Chinese Character Style Transfer Based on Skeleton and Attention Model
1 Introduction
2 Related Work
2.1 Image-To-Image Translation
2.2 Font Generation
3 Method
3.1 Attention Module
3.2 SCSK Module
3.3 Loss Function
4 Experiment
4.1 Comparison with State-of-Art Methods
5 Ablation Experiment
6 Conclusion
References
IL-YOLOv5: A Ship Detection Method Based on Incremental Learning
1 Introduction
2 Related Works
2.1 Object Detection
2.2 Incremental Learning Method
3 Methodology
3.1 Image Feature Information Enhancement
3.2 IL-YOLOv5 Based on Incremental Learning
4 Experiments
4.1 Implementation Details
4.2 Introduction to the Dataset
4.3 Ablation Studies
4.4 Comparison to Other Model
4.5 Results on Ship
5 Conclusions
References
Efficient and Precise Detection of Surface Defects on PCBs: A YOLO Based Approach
1 Introduction
2 Related Works
2.1 The Public PCB Dataset
2.2 Single Stage Object Detection Networks
2.3 Attention Mechanisms in Convolutional Neural Networks
2.4 C3 Module that Contains Coordinate Attention Mechanism
2.5 WIoUv1 Bounding Box Regression Loss Function for the Optimization of PCB Fault Detection Models
3 Experiments
3.1 Data Enhancement
3.2 Experiments on Original Models
3.3 Comparison Between Different Attention Mechanisms
3.4 Visualization and Discussion
4 Conclusion
References
Corneal Ulcer Automatic Classification Network Based on Improved Mobile ViT
1 Introduction
2 Related Work
3 Method
3.1 MV2-SE Block
3.2 A New Mobile ViT Block
3.3 Improved Mobile ViT Architecture
4 Experiment
4.1 Dataset
4.2 Experiment Result
5 Conclusion
References
Multiple Classification Network of Concrete Defects Based on Improved EfficientNetV2
1 Introduction
2 Related Work
2.1 EfficientNetV2
2.2 Concrete Structural Defect Classification
3 Dataset and Methods
3.1 Dataset
3.2 Methods
3.3 Multi-class Multi-target Evaluation
4 Results and Discussion
4.1 Experimental Equipment and Data Set Splitting
4.2 Model Building
4.3 Training Process
4.4 Results
5 Conclusion
References
A Lightweight Hyperspectral Image Super-Resolution Method Based on Multiple Attention Mechanisms
1 Introduction
2 Methodology
2.1 Network Architecture
2.2 Large Kernel Pixel Attention Network
2.3 Deep Transformer Self-attention Residual Module
2.4 Contextual Incremental Fusion Module
2.5 Loss Function
3 Experimental Result
3.1 Dataset and Experimental Details
3.2 Evaluating Indicator
3.3 Results and Analysis
4 Conclusion
References
Graph Disentangled Representation Based Semi-supervised Single Image Dehazing Network
1 Introduction
2 Proposed GDSDN
2.1 Main Backbone
2.2 Encoder Network
2.3 Graph-Disentangled Representation Network
2.4 Reconstruction Network
2.5 Object Function
3 Experiments
3.1 Experimental Setup
3.2 Main Results
4 Conclusion
References
Surface Target Saliency Detection in Complex Environments
1 Introduction
2 Related Work
3 Proposed Method
3.1 General Architecture
3.2 Perceptual Field Enhancement Module
3.3 Adjacent Feature Fusion Module
3.4 Supervision
4 Experimental Results
4.1 Implementation Details
4.2 Datasets and Evaluation Metrics
4.3 Comparison Experiments
4.4 Ablation Studies
5 Conclusion
References
Traffic Sign Recognition Based on Improved VGG-16 Model
1 Introduction
2 Traffic Sign Recognition Method
2.1 Weighted-Hybrid Loss Function
2.2 Hybrid Dilated Convolution and VGG
3 Data Set and Model Parameter Settings
3.1 Data Set Preparation
3.2 Network Parameter Setting
4 Experimental Results
4.1 Training Process
4.2 Recognition Results
5 Conclusion
References
A Weakly Supervised Semantic Segmentation Method on Lung Adenocarcinoma Histopathology Images
1 Introduction
2 Methods
2.1 Feature Extraction
2.2 Enlargement of Class Activation Maps
2.3 Class Re-Activation Mapping
2.4 Implement Details
2.5 Datasets
3 Results
3.1 Ablation Experiments
3.2 Comparison Experiments
4 Discussion
5 Conclusions
References
Improved Lane Line Detection Algorithms Based on Incomplete Line Fitting
1 Introduction
2 Algorithm Design
3 Image Preprocessing
3.1 Camera Calibration
3.2 Inverse Perspective Transformation
4 Lane Lines Detection
4.1 Edge Detection Based on Sobel Operator
4.2 Locate the Left and Right Base Points of the Lane Lines
4.3 Sliding Window Search
4.4 Quadratic Polynomial Fitting
5 Experiment
5.1 Actual Lane Detection Results
5.2 Analysis and Comparison with Other Lane Detection Methods
6 Conclusion
References
A Deep Transfer Fusion Model for Recognition of Acute Lymphoblastic Leukemia with Few Samples
1 Introduction
1.1 Related Work
2 Proposed Network Framework
3 Experiments and Results
3.1 Dataset Description
3.2 Evaluation Metrics
3.3 Model Evaluation and Comparison Result
3.4 Generalization Performance
3.5 Image Feature Visualization
4 Conclusion
References
IntrNet: Weakly Supervised Segmentation of Thyroid Nodules Based on Intra-image and Inter-image Semantic Information
1 Introduction
2 Methods
2.1 EAM
2.2 SCSM-FB
2.3 Loss
3 Experiments
3.1 Comparison Experiments
3.2 Ablation Study
4 Conclusion
References
Computational Intelligence and Its Application
Novel Ensemble Method Based on Improved k-nearest Neighbor and Gaussian Naive Bayes for Intrusion Detection System
1 Introduction
2 Methodology
2.1 Feature Reduction
2.2 Ensemble Method
3 Experiment and Discussion
3.1 Benchmark Datasets
3.2 Dataset Preprocessing
3.3 Experimental Procedure
3.4 Results and Discussion
4 Conclusions
References
A Hybrid Queueing Search and Gradient-Based Algorithm for Optimal Experimental Design
1 Introduction
2 Backgrounds
3 Queueing Search Algorithm and Its Improvement
3.1 Queueing Search Algorithm
3.2 Multiplicative Algorithm
3.3 Proposed Algorithm
4 Numerical Examples Settings and Analyses
4.1 Numerical Examples Settings
4.2 Results and Analyses
5 Conclusion
References
A Review of Client Selection Mechanisms in Heterogeneous Federated Learning
1 Introduction
2 Background
2.1 Research on Statistical Heterogeneity
2.2 Research on System Heterogeneity
3 Client selection methods in Federated Learning
3.1 Client Selection Based on Trust Level
3.2 Client Selection Based on Time Threshold
3.3 Client Selection Based On Reinforcement Learning
3.4 Client Selection Based on Probability
3.5 Discussion
4 Future Research Directions
5 Conclusion
References
ARFG: Attach-Free RFID Finger-Tracking with Few Samples Based on GAN
1 Introduction
2 Related Work
2.1 Attach-Based Recognition
2.2 Attach-Free Recognition
2.3 Few-Shot Learning in RFID
3 Preliminaries
3.1 Definition of RFID Signal
4 System Design
4.1 Problem Definition
4.2 System Overview
5 Performance Evaluation
5.1 Experimental Setup
5.2 Experimental Results
5.3 Ablation Study
6 Conclusion
References
GM(1,1) Model Based on Parallel Quantum Whale Algorithm and Its Application
1 Introduction
2 GM(1,1) Prediction Model
2.1 Data Testing and Processing
2.2 Testing and Processing of Data
2.3 Improved GM(1,1) Model
3 Whale Optimization Algorithm and Quantum Whale Optimization Algorithm
3.1 Whale Optimization Algorithm
3.2 Parallel Quantum Whale Algorithm
4 QWOA-EFGM(1,1) Model Construction
5 Prediction and Analysis
5.1 Inspection and Processing of Data
5.2 Analysis of Experimental Results
6 Conclusion
References
3D Path Planning Based on Improved Teaching and Learning Optimization Algorithm
1 Introduction
2 Establishment and Analysis of Models
2.1 Environmental Model
2.2 Cubic B-spline Curve Path
2.3 Fitness Function
3 Teaching and Learning Algorithms
3.1 Teaching Stage
3.2 Learning Stage
4 Strategies for Improving Teaching and Learning Algorithms
4.1 Group Teaching
4.2 Autonomous Learning
4.3 Steps for Improving Algorithm Implementation
5 Experimental Results and Analysis
5.1 Simulation Environment and Experimental Data
5.2 Simulation Results and Analysis
6 Conclusion
References
Author Index

LNCS 14087

De-Shuang Huang · Prashan Premaratne · Baohua Jin · Boyang Qu · Kang-Hyun Jo · Abir Hussain (Eds.)

Advanced Intelligent Computing Technology and Applications 19th International Conference, ICIC 2023 Zhengzhou, China, August 10–13, 2023 Proceedings, Part II

Lecture Notes in Computer Science

Founding Editors
Gerhard Goos
Juris Hartmanis

Editorial Board Members
Elisa Bertino, Purdue University, West Lafayette, IN, USA
Wen Gao, Peking University, Beijing, China
Bernhard Steffen, TU Dortmund University, Dortmund, Germany
Moti Yung, Columbia University, New York, NY, USA

14087

The series Lecture Notes in Computer Science (LNCS), including its subseries Lecture Notes in Artificial Intelligence (LNAI) and Lecture Notes in Bioinformatics (LNBI), has established itself as a medium for the publication of new developments in computer science and information technology research, teaching, and education. LNCS enjoys close cooperation with the computer science R & D community; the series counts many renowned academics among its volume editors and paper authors and collaborates with prestigious societies. Its mission is to serve this international community by providing an invaluable service, mainly focused on the publication of conference and workshop proceedings and postproceedings. LNCS commenced publication in 1973.

Editors
De-Shuang Huang, Department of Computer Science, Eastern Institute of Technology, Zhejiang, China
Prashan Premaratne, University of Wollongong, North Wollongong, NSW, Australia
Baohua Jin, Zhengzhou University of Light Industry, Zhengzhou, China
Boyang Qu, Zhong Yuan University of Technology, Zhengzhou, China
Kang-Hyun Jo, University of Ulsan, Ulsan, Korea (Republic of)
Abir Hussain, Department of Computer Science, Liverpool John Moores University, Liverpool, UK

ISSN 0302-9743 ISSN 1611-3349 (electronic)
Lecture Notes in Computer Science
ISBN 978-981-99-4741-6 ISBN 978-981-99-4742-3 (eBook)
https://doi.org/10.1007/978-981-99-4742-3

© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023

This work is subject to copyright. All rights are reserved by the Publisher, whether the whole or part of the material is concerned, specifically the rights of translation, reprinting, reuse of illustrations, recitation, broadcasting, reproduction on microfilms or in any other physical way, and transmission or information storage and retrieval, electronic adaptation, computer software, or by similar or dissimilar methodology now known or hereafter developed.

The use of general descriptive names, registered names, trademarks, service marks, etc. in this publication does not imply, even in the absence of a specific statement, that such names are exempt from the relevant protective laws and regulations and therefore free for general use.

The publisher, the authors, and the editors are safe to assume that the advice and information in this book are believed to be true and accurate at the date of publication. Neither the publisher nor the authors or the editors give a warranty, expressed or implied, with respect to the material contained herein or for any errors or omissions that may have been made. The publisher remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.

This Springer imprint is published by the registered company Springer Nature Singapore Pte Ltd. The registered company address is: 152 Beach Road, #21-01/04 Gateway East, Singapore 189721, Singapore

Preface

The International Conference on Intelligent Computing (ICIC) was started to provide an annual forum dedicated to emerging and challenging topics in artificial intelligence, machine learning, pattern recognition, bioinformatics, and computational biology. It aims to bring together researchers and practitioners from both academia and industry to share ideas, problems, and solutions related to the multifaceted aspects of intelligent computing. ICIC 2023, held in Zhengzhou, China, August 10–13, 2023, constituted the 19th International Conference on Intelligent Computing. It built upon the success of ICIC 2022 (Xi’an, China), ICIC 2021 (Shenzhen, China), ICIC 2020 (Bari, Italy), ICIC 2019 (Nanchang, China), ICIC 2018 (Wuhan, China), ICIC 2017 (Liverpool, UK), ICIC 2016 (Lanzhou, China), ICIC 2015 (Fuzhou, China), ICIC 2014 (Taiyuan, China), ICIC 2013 (Nanning, China), ICIC 2012 (Huangshan, China), ICIC 2011 (Zhengzhou, China), ICIC 2010 (Changsha, China), ICIC 2009 (Ulsan, South Korea), ICIC 2008 (Shanghai, China), ICIC 2007 (Qingdao, China), ICIC 2006 (Kunming, China), and ICIC 2005 (Hefei, China). This year, the conference concentrated mainly on theories and methodologies as well as emerging applications of intelligent computing. Its aim was to unify the picture of contemporary intelligent computing techniques as an integral concept that highlights the trends in advanced computational intelligence and bridges theoretical research with applications. Therefore, the theme for this conference was “Advanced Intelligent Computing Technology and Applications”. Papers that focused on this theme were solicited, addressing theories, methodologies, and applications in science and technology. ICIC 2023 received 828 submissions from 12 countries and regions. All papers went through a rigorous peer-review procedure and each paper received at least three review reports. 
Based on the review reports, the Program Committee finally selected 337 high-quality papers for presentation at ICIC 2023 and for inclusion in five volumes of proceedings published by Springer: three volumes of Lecture Notes in Computer Science (LNCS) and two volumes of Lecture Notes in Artificial Intelligence (LNAI). This volume, LNCS 14087, includes 66 papers. The organizers of ICIC 2023, including the Eastern Institute of Technology, China, Zhongyuan University of Technology, China, and Zhengzhou University of Light Industry, China, made an enormous effort to ensure the success of the conference. We would like to thank the members of the Program Committee and the referees for their collective effort in reviewing and soliciting the papers. In particular, we would like to thank all the authors for contributing their papers. Without the high-quality submissions from the authors, the success of the conference would not have been possible. Finally, we are especially grateful to the International Neural Network Society and the National Science Foundation of China for their sponsorship.

June 2023

De-Shuang Huang
Prashan Premaratne
Boyang Qu
Baohua Jin
Kang-Hyun Jo
Abir Hussain

Organization

General Co-chairs
De-Shuang Huang, Eastern Institute of Technology, China
Shizhong Wei, Zhengzhou University of Light Industry, China

Program Committee Co-chairs
Prashan Premaratne, University of Wollongong, Australia
Baohua Jin, Zhengzhou University of Light Industry, China
Kang-Hyun Jo, University of Ulsan, Republic of Korea
Abir Hussain, Liverpool John Moores University, UK

Organizing Committee Co-chair
Hui Jing, Zhengzhou University of Light Industry, China

Organizing Committee Members
Fubao Zhu, Zhengzhou University of Light Industry, China
Qiuwen Zhang, Zhengzhou University of Light Industry, China
Haodong Zhu, Zhengzhou University of Light Industry, China
Wei Huang, Zhengzhou University of Light Industry, China
Hongwei Tao, Zhengzhou University of Light Industry, China
Weiwei Zhang, Zhengzhou University of Light Industry, China

Award Committee Co-chairs
Michal Choras, Bydgoszcz University of Science and Technology, Poland
Hong-Hee Lee, University of Ulsan, Republic of Korea


Tutorial Co-chairs
Yoshinori Kuno, Saitama University, Japan
Phalguni Gupta, Indian Institute of Technology Kanpur, India

Publication Co-chairs
Valeriya Gribova, Far Eastern Branch of Russian Academy of Sciences, Russia
M. Michael Gromiha, Indian Institute of Technology Madras, India
Boyang Qu, Zhengzhou University, China

Special Session Co-chairs
Jair Cervantes Canales, Autonomous University of Mexico State, Mexico
Chenxi Huang, Xiamen University, China
Dhiya Al-Jumeily, Liverpool John Moores University, UK

Special Issue Co-chairs
Kyungsook Han, Inha University, Republic of Korea
Laurent Heutte, Université de Rouen Normandie, France

International Liaison Co-chair
Prashan Premaratne, University of Wollongong, Australia

Workshop Co-chairs
Yu-Dong Zhang, University of Leicester, UK
Hee-Jun Kang, University of Ulsan, Republic of Korea


Publicity Co-chairs
Chun-Hou Zheng, Anhui University, China
Dhiya Al-Jumeily, Liverpool John Moores University, UK
Jair Cervantes Canales, Autonomous University of Mexico State, Mexico

Exhibition Contact Co-chair
Fubao Zhu, Zhengzhou University of Light Industry, China

Program Committee Members
Abir Hussain, Liverpool John Moores University, UK
Antonio Brunetti, Polytechnic University of Bari, Italy
Antonino Staiano, Università di Napoli Parthenope, Italy
Bin Liu, Beijing Institute of Technology, China
Bin Qian, Kunming University of Science and Technology, China
Bin Yang, Zaozhuang University, China
Bing Wang, Anhui University of Technology, China
Binhua Tang, Hohai University, China
Bingqiang Liu, Shandong University, China
Bo Li, Wuhan University of Science and Technology, China
Changqing Shen, Soochow University, China
Chao Song, Harbin Medical University, China
Chenxi Huang, Xiamen University, China
Chin-Chih Chang, Chung Hua University, Taiwan
Chunhou Zheng, Anhui University, China
Chunmei Liu, Howard University, USA
Chunquan Li, University of South China, China
Dahjing Jwo, National Taiwan Ocean University, Taiwan
Dakshina Ranjan Kisku, National Institute of Technology Durgapur, India
Dan Feng, Huazhong University of Science and Technology, China
Daowen Qiu, Sun Yat-sen University, China
Dharmalingam Muthusamy, Bharathiar University, India
Dhiya Al-Jumeily, Liverpool John Moores University, UK
Dong Wang, University of Jinan, China


Dunwei Gong, China University of Mining and Technology, China
Eros Gian Pasero, Politecnico di Torino, Italy
Evi Sjukur, Monash University, Australia
Fa Zhang, Beijing Institute of Technology, China
Fengfeng Zhou, Jilin University, China
Fei Guo, Central South University, China
Gaoxiang Ouyang, Beijing Normal University, China
Giovanni Dimauro, University of Bari, Italy
Guoliang Li, Huazhong Agricultural University, China
Han Zhang, Nankai University, China
Haibin Liu, Beijing University of Technology, China
Hao Lin, University of Electronic Science and Technology of China, China
Haodi Feng, Shandong University, China
Hongjie Wu, Suzhou University of Science and Technology, China
Hongmin Cai, South China University of Technology, China
Jair Cervantes, Autonomous University of Mexico State, Mexico
Jixiang Du, Huaqiao University, China
Jing Hu, Wuhan University of Science and Technology, China
Jiawei Luo, Hunan University, China
Jian Huang, University of Electronic Science and Technology of China, China
Jian Wang, China University of Petroleum, China
Jiangning Song, Monash University, Australia
Jinwen Ma, Peking University, China
Jingyan Wang, Abu Dhabi Department of Community Development, UAE
Jinxing Liu, Qufu Normal University, China
Joaquin Torres-Sospedra, Universidade do Minho, Portugal
Juan Liu, Wuhan University, China
Jun Zhang, Anhui University, China
Junfeng Xia, Anhui University, China
Jungang Lou, Huzhou University, China
Kachun Wong, City University of Hong Kong, China
Kanghyun Jo, University of Ulsan, Republic of Korea
Khalid Aamir, University of Sargodha, Pakistan
Kyungsook Han, Inha University, Republic of Korea
L. Gong, Nanjing University of Posts and Telecommunications, China
Laurent Heutte, Université de Rouen Normandie, France


Le Zhang, Sichuan University, China
Lejun Gong, Nanjing University of Posts and Telecommunications, China
Liang Gao, Huazhong Univ. of Sci. & Tech., China
Lida Zhu, Huazhong Agriculture University, China
Marzio Pennisi, University of Eastern Piedmont, Italy
Michal Choras, Bydgoszcz University of Science and Technology, Poland
Michael Gromiha, Indian Institute of Technology Madras, India
Ming Li, Nanjing University, China
Minzhu Xie, Hunan Normal University, China
Mohd Helmy Abd Wahab, Universiti Tun Hussein Onn Malaysia, Malaysia
Nicola Altini, Polytechnic University of Bari, Italy
Peng Chen, Anhui University, China
Pengjiang Qian, Jiangnan University, China
Phalguni Gupta, GLA University, India
Prashan Premaratne, University of Wollongong, Australia
Pufeng Du, Tianjin University, China
Qi Zhao, University of Science and Technology Liaoning, China
Qingfeng Chen, Guangxi University, China
Qinghua Jiang, Harbin Institute of Technology, China
Quan Zou, University of Electronic Science and Technology of China, China
Rui Wang, National University of Defense Technology, China
Saiful Islam, Aligarh Muslim University, India
Seeja K. R., Indira Gandhi Delhi Technical University for Women, India
Shanfeng Zhu, Fudan University, China
Shikui Tu, Shanghai Jiao Tong University, China
Shitong Wang, Jiangnan University, China
Shixiong Zhang, Xidian University, China
Sungshin Kim, Pusan National University, Republic of Korea
Surya Prakash, IIT Indore, India
Tatsuya Akutsu, Kyoto University, Japan
Tao Zeng, Guangzhou Laboratory, China
Tieshan Li, University of Electronic Science and Technology of China, China
Valeriya Gribova, Institute of Automation and Control Processes, Far Eastern Branch of Russian Academy of Sciences, Russia
Vincenzo Randazzo, Politecnico di Torino, Italy


Waqas Haider, Kohsar University Murree, Pakistan
Wen Zhang, Huazhong Agricultural University, China
Wenbin Liu, Guangzhou University, China
Wensheng Chen, Shenzhen University, China
Wei Chen, Chengdu University of Traditional Chinese Medicine, China
Wei Peng, Kunming University of Science and Technology, China
Weichiang Hong, Asia Eastern University of Science and Technology, Taiwan
Weidong Chen, Shanghai Jiao Tong University, China
Weiwei Kong, Xi’an University of Posts and Telecommunications, China
Weixiang Liu, Shenzhen University, China
Xiaodi Li, Shandong Normal University, China
Xiaoli Lin, Wuhan University of Science and Technology, China
Xiaofeng Wang, Hefei University, China
Xiao-Hua Yu, California Polytechnic State University, USA
Xiaoke Ma, Xidian University, China
Xiaolei Zhu, Anhui Agricultural University, China
Xiangtao Li, Jilin University, China
Xin Zhang, Jiangnan University, China
Xinguo Lu, Hunan University, China
Xingwei Wang, Northeastern University, China
Xinzheng Xu, China University of Mining and Technology, China
Xiwei Liu, Tongji University, China
Xiyuan Chen, Southeast Univ., China
Xuequn Shang, Northwestern Polytechnical University, China
Xuesong Wang, China University of Mining and Technology, China
Yansen Su, Anhui University, China
Yi Xiong, Shanghai Jiao Tong University, China
Yu Xue, Huazhong University of Science and Technology, China
Yizhang Jiang, Jiangnan University, China
Yonggang Lu, Lanzhou University, China
Yongquan Zhou, Guangxi University for Nationalities, China
Yudong Zhang, University of Leicester, UK
Yunhai Wang, Shandong University, China
Yupei Zhang, Northwestern Polytechnical University, China
Yushan Qiu, Shenzhen University, China


Yunxia Liu, Zhengzhou Normal University, China
Zhanli Sun, Anhui University, China
Zhenran Jiang, East China Normal University, China
Zhengtao Yu, Kunming University of Science and Technology, China
Zhenyu Xuan, University of Texas at Dallas, USA
Zhihong Guan, Huazhong University of Science and Technology, China
Zhihua Cui, Taiyuan University of Science and Technology, China
Zhiping Liu, Shandong University, China
Zhiqiang Geng, Beijing University of Chemical Technology, China
Zhongqiu Zhao, Hefei University of Technology, China
Zhuhong You, Northwestern Polytechnical University, China

Contents – Part II

Neural Networks

Adversarial Ensemble Training by Jointly Learning Label Dependencies and Member Models
Lele Wang and Bin Liu

PFGE: Parsimonious Fast Geometric Ensembling of DNNs
Hao Guo, Jiyong Jin, and Bin Liu

Research on Indoor Positioning Algorithm Based on Multimodal and Attention Mechanism
Chenxi Shi, Lvqing Yang, Lanliang Lin, Yongrong Wu, Shuangyuan Yang, Sien Chen, and Bo Yu

Undetectable Attack to Deep Neural Networks Without Using Model Parameters
Chen Yang, Yinyan Zhang, and Ameer Hamza Khan

A Traffic Flow Prediction Framework Based on Clustering and Heterogeneous Graph Neural Networks
Lei Luo, Shiyuan Han, Zhongtao Li, Jun Yang, and Xixin Yang

Effective Audio Classification Network Based on Paired Inverse Pyramid Structure and Dense MLP Block
Yunhao Chen, Yunjie Zhu, Zihui Yan, Zhen Ren, Yifan Huang, Jianlu Shen, and Lifang Chen

Dynamic Attention Filter Capsule Network for Medical Images Segmentation
Ran Chen, Kai Hu, and Zhong-Qiu Zhao

Cross-Scale Dynamic Alignment Network for Reference-Based Super-Resolution
Kai Hu, Ran Chen, and Zhong-Qiu Zhao

Solving Large-Scale Open Shop Scheduling Problem via Link Prediction Based on Graph Convolution Network
Lanjun Wan, Haoxin Zhao, Xueyan Cui, Changyun Li, and Xiaojun Deng


Solving Class Imbalance Problem in Target Detection with a Squared Cross Entropy Based Method
Guanyu Chen, Quanyu Wang, Qi Li, Jun Hu, and Jingyi Liu

Modeling Working Memory Using Convolutional Neural Networks for Knowledge Tracing
Huali Yang, Bin Chen, Junjie Hu, Tao Huang, Jing Geng, and Linxia Tang

A Multi-granularity Decision Fusion Method Based on Category Hierarchy
Jian-Xun Mi, Ke-Yang Huang, and Nuo Li

Make Active Attention More Active: Using Lipschitz Regularity to Improve Long Sequence Time-Series Forecasting
Xiangxu Meng, Wei Li, Wenqi Zheng, Zheng Zhao, Guangsheng Feng, and Huiqiang Wang

Attributed Multi-relational Graph Embedding Based on GCN
Zhuo Xie, Mengqi Wu, Guoping Zhao, Lijuan Zhou, Zhaohui Gong, and Zhihong Zhang

CharCaps: Character-Level Text Classification Using Capsule Networks
Yujia Wu, Xin Guo, and Kangning Zhan

Multi-student Collaborative Self-supervised Distillation
Yinan Yang, Li Chen, Shaohui Wu, and Zhuang Sun

Speech Emotion Recognition Using Global-Aware Cross-Modal Feature Fusion Network
Feng Li and Jiusong Luo

MT-1DCG: A Novel Model for Multivariate Time Series Classification
Yu Lu, Huanwen Liang, Zichang Yu, and Xianghua Fu

Surrogate Modeling for Soliton Wave of Nonlinear Partial Differential Equations via the Improved Physics-Informed Deep Learning
Yanan Guo, Xiaoqun Cao, Kecheng Peng, Wenlong Tian, and Mengge Zhou

TAHAR: A Transferable Attention-Based Adversarial Network for Human Activity Recognition with RFID
Dinghao Chen, Lvqing Yang, Hua Cao, Qingkai Wang, Wensheng Dong, and Bo Yu


Pattern Recognition

Improved Blind Image Denoising with DnCNN
Guang Yi Chen, Wenfang Xie, and Adam Krzyzak

Seizure Prediction Based on Hybrid Deep Learning Model Using Scalp Electroencephalogram
Kuiting Yan, Junliang Shang, Juan Wang, Jie Xu, and Shasha Yuan

Data Augmentation for Environmental Sound Classification Using Diffusion Probabilistic Model with Top-K Selection Discriminator
Yunhao Chen, Zihui Yan, Yunjie Zhu, Zhen Ren, Jianlu Shen, and Yifan Huang

Improved DetNet Algorithm Based on GRU for Massive MIMO Systems
Hanqing Ding, Bingwei Li, and Jin Xu

Epileptic Seizure Detection Based on Feature Extraction and CNN-BiGRU Network with Attention Mechanism
Jie Xu, Juan Wang, Jin-Xing Liu, Junliang Shang, Lingyun Dai, Kuiting Yan, and Shasha Yuan

Information Potential Based Rolling Bearing Defect Classification and Diagnosis
Hui Li, Ruijuan Wang, and Yonghui Xie

Electrocardiogram Signal Noise Reduction Application Employing Different Adaptive Filtering Algorithms
Amine Essa, Abdullah Zaidan, Suhaib Ziad, Mohamed Elmeligy, Sam Ansari, Haya Alaskar, Soliman Mahmoud, Ayad Turky, Wasiq Khan, Dhiya Al-Jumeily OBE, and Abir Hussain

Improving the Accuracy of Deep Learning Modelling Based on Statistical Calculation of Mathematical Equations
Feng Li and Yujun Hu

Deep Learning for Cardiotocography Analysis: Challenges and Promising Advances
Cang Chen, Weifang Xie, Zhiqi Cai, and Yu Lu

Exploiting Active-IRS by Maximizing Throughput in Wireless Powered Communication Networks
Iqra Hameed and Insoo Koo


Palmprint Recognition Utilizing Modified LNMF Method
Li Shang, Yuze Zhang, and Bo Huang

PECA-Net: Pyramidal Attention Convolution Residual Network for Architectural Heritage Images Classification
Shijie Li, Yifei Yang, and Mingyang Zhong

Convolutional Self-attention Guided Graph Neural Network for Few-Shot Action Recognition
Fei Pan, Jie Guo, and Yanwen Guo

DisGait: A Prior Work of Gait Recognition Concerning Disguised Appearance and Pose
Shouwang Huang, Ruiqi Fan, and Shichao Wu

Image Processing

An Experimental Study on MRI Denoising with Existing Image Denoising Methods
Guang Yi Chen, Wenfang Xie, and Adam Krzyzak

Multi-scale and Self-mutual Feature Distillation
Nianzu Qiao, Jia Sun, and Lu Dong

A Novel Multi-task Architecture for Vanishing Point Assisted Road Segmentation and Guidance in Off-Road Environments
Yu Liu, Xue Fan, Shiyuan Han, and Weiwei Yu

DCNet: Glass-Like Object Detection via Detail-Guided and Cross-Level Fusion
Jianhao Zhang, Gang Yang, and Chang Liu

A Method for Detecting and Correcting Specular Highlights in Capsule Endoscope Images Based on Independent Cluster Distribution
Jiarui Ma and Yangqing Hou

Siamese Adaptive Template Update Network for Visual Tracking
Jia Wen, Kejun Ren, Yang Xiang, and Dandan Tang

Collaborative Encoder for Accurate Inversion of Real Face Image
YaTe Liu, ChunHou Zheng, Jun Zhang, Bing Wang, and Peng Chen

Text-Guided Generative Adversarial Network for Image Emotion Transfer
Siqi Zhu, Chunmei Qing, and Xiangmin Xu


What Constitute an Effective Edge Detection Algorithm?
Prashan Premaratne and Peter Vial

SporeDet: A Real-Time Detection of Wheat Scab Spores
Jin Yuan, Zhangjin Huang, Dongyan Zhang, Xue Yang, and Chunyan Gu

Micro-expression Recognition Based on Dual-Branch Swin Transformer Network
Zhihua Xie and Chuwei Zhao

CC-DBNet: A Scene Text Detector Combining Collaborative Learning and Cascaded Feature Fusion
Wenheng Jiang, Yuehui Chen, Yi Cao, and Yaou Zhao

A Radar Video Compression and Display Method Based on FPGA
Daiwei Xie, Jin-Wu Wang, and Zhenmin Dai

SA-GAN: Chinese Character Style Transfer Based on Skeleton and Attention Model
Jian Shu, Yuehui Chen, Yi Cao, and Yaou Zhao

IL-YOLOv5: A Ship Detection Method Based on Incremental Learning
Wenzheng Liu and Yaojie Chen

Efficient and Precise Detection of Surface Defects on PCBs: A YOLO Based Approach
Lejun Pan, Wenyan Wang, Kun Lu, Jun Zhang, Peng Chen, Jiawei Ni, Chenlin Zhu, and Bing Wang

Corneal Ulcer Automatic Classification Network Based on Improved Mobile ViT
Chenlin Zhu, Wenyan Wang, Kun Lu, Jun Zhang, Peng Chen, Lejun Pan, Jiawei Ni, and Bing Wang

Multiple Classification Network of Concrete Defects Based on Improved EfficientNetV2
Jiawei Ni, Bing Wang, Kun Lu, Jun Zhang, Peng Chen, Lejun Pan, Chenlin Zhu, Bing Wang, and Wenyan Wang

A Lightweight Hyperspectral Image Super-Resolution Method Based on Multiple Attention Mechanisms
Lijing Bu, Dong Dai, Zhengpeng Zhang, Xinyu Xie, and Mingjun Deng


Graph Disentangled Representation Based Semi-supervised Single Image Dehazing Network
Tongyao Jia, Jiafeng Li, and Li Zhuo

Surface Target Saliency Detection in Complex Environments
Benxin Yang and Yaojie Chen

Traffic Sign Recognition Based on Improved VGG-16 Model
Tang Shuyuan, Li Jintao, and Liu Chang

A Weakly Supervised Semantic Segmentation Method on Lung Adenocarcinoma Histopathology Images
Xiaobin Lan, Jiaming Mei, Ruohan Lin, Jiahao Chen, and Yanju Zhang

Improved Lane Line Detection Algorithms Based on Incomplete Line Fitting
QingYu Ren, BingRui Zhao, TingShuo Jiang, and WeiZe Gao

A Deep Transfer Fusion Model for Recognition of Acute Lymphoblastic Leukemia with Few Samples
Zhihua Du, Xin Xia, Min Fang, Li Yu, and Jianqiang Li

IntrNet: Weakly Supervised Segmentation of Thyroid Nodules Based on Intra-image and Inter-image Semantic Information
Jie Gao, Shaoqi Yan, Xuzhou Fu, Zhiqiang Liu, Ruiguo Yu, and Mei Yu

Computational Intelligence and Its Application

Novel Ensemble Method Based on Improved k-nearest Neighbor and Gaussian Naive Bayes for Intrusion Detection System
Lina Ge, Hao Zhang, and Haiao Li

A Hybrid Queueing Search and Gradient-Based Algorithm for Optimal Experimental Design
Yue Zhang, Yi Zhai, Zhenyang Xia, and Xinlong Wang

A Review of Client Selection Mechanisms in Heterogeneous Federated Learning
Xiao Wang, Lina Ge, and Guifeng Zhang

ARFG: Attach-Free RFID Finger-Tracking with Few Samples Based on GAN
Sijie Li, Lvqing Yang, Sien Chen, Jianwen Ding, Wensheng Dong, Bo Yu, Qingkai Wang, and Menghao Wang


GM(1,1) Model Based on Parallel Quantum Whale Algorithm and Its Application
Huajuan Huang, Shixian Huang, Xiuxi Wei, and Yongquan Zhou

3D Path Planning Based on Improved Teaching and Learning Optimization Algorithm
Xiuxi Wei, Haixuan He, Huajuan Huang, and Yongquan Zhou

Author Index

Neural Networks

Adversarial Ensemble Training by Jointly Learning Label Dependencies and Member Models

Lele Wang1 and Bin Liu2(✉)

1 Research Center for Data Mining and Knowledge Discovery, Zhejiang Lab, Hangzhou 311121, China
2 Research Center for Applied Mathematics and Machine Intelligence, Zhejiang Lab, Hangzhou 311121, China

Abstract. Training an ensemble of diverse sub-models has been empirically demonstrated to be an effective strategy for improving the adversarial robustness of deep neural networks. However, current ensemble training methods for image recognition typically encode image labels as one-hot vectors, which overlook dependency relationships between the labels. In this paper, we propose a novel adversarial ensemble training approach that jointly learns the label dependencies and the member models. Our approach adaptively exploits the learned label dependencies to promote diversity among the member models. We evaluate our approach on widely used datasets, including MNIST, FashionMNIST, and CIFAR-10, and show that it achieves superior robustness against black-box attacks compared to state-of-the-art methods. Our code is available at https://github.com/ZJLABAMMI/LSD.

Keywords: deep learning · model ensemble · adversarial attack · label dependency

L. Wang: Work done while he was with the Research Center for Applied Mathematics and Machine Intelligence, Zhejiang Lab.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 3–20, 2023.
https://doi.org/10.1007/978-981-99-4742-3_1

1 Introduction

Deep neural networks (DNNs), the models at the core of deep learning, have achieved remarkable success across many tasks in computer vision [1–4], speech recognition [5, 6], and natural language processing [7, 8]. However, numerous works have shown that modern DNNs are vulnerable to adversarial attacks [9–14]: slight perturbations to input images, imperceptible to humans, can fool a high-performing DNN into making incorrect predictions. Moreover, adversarial examples crafted for one model may also deceive other models, a phenomenon known as adversarial transferability [15–18]. These vulnerabilities pose significant challenges for real-life applications of DNNs, particularly in safety-critical settings such as self-driving cars [19, 20]. As a result, promoting the robustness of DNNs against adversarial attacks has received increasing attention.

Ensembling models has been shown to be a highly effective approach for improving the adversarial robustness of deep learning systems. The basic idea behind model ensembling is illustrated by the Venn diagram in Fig. 1 [21]. The rectangle represents the space spanned by all possible orthogonal perturbations to an input instance, while each circle represents the subspace spanned by perturbations that are adversarial to the associated model. The shaded region represents the perturbation subspace that is adversarial to the model ensemble. For a single model (Fig. 1(a)), any perturbation within the circle causes the input image to be misclassified. For an ensemble of two or more models (Fig. 1(b) and Fig. 1(c)), however, a successful adversarial perturbation must lie within the shaded region; that is, the attack must fool all individual models in the ensemble. Promoting diversity among the individual models is therefore an intuitive strategy for improving the adversarial robustness of a model ensemble: the more diverse the individual models, the less their adversarial subspaces overlap, and the amount of overlap determines the dimensionality of the adversarial subspace of the ensemble [21]. Throughout this paper, we use the terms individual model, sub-model, and member model interchangeably.
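The Venn-diagram argument can be made concrete with a toy simulation. The sketch below rests on simplifying assumptions of our own, not the paper's: each member model is linearized so that a perturbation δ fools it whenever its loss-gradient direction g_i satisfies g_i · δ > t for a hypothetical threshold t; perturbations are drawn uniformly from an L2 ball of radius ε; and, as in Fig. 1, the ensemble is fooled only when every member is. The names `fooling_fraction`, `threshold`, and `unit` are illustrative. The estimated fraction of fooling perturbations (the shaded region) shrinks as the members' gradient directions become less aligned.

```python
import numpy as np


def fooling_fraction(normals, threshold=0.5, eps=1.0, n_samples=200_000, seed=0):
    """Estimate the fraction of perturbations with ||delta||_2 <= eps that fool
    every member model.

    Toy assumption (hypothetical): member i is fooled when its linearized
    loss increase, normals[i] @ delta, exceeds `threshold`.
    """
    rng = np.random.default_rng(seed)
    d = normals.shape[1]
    # Sample uniformly inside the L2 ball: random direction, radius eps * U**(1/d).
    directions = rng.normal(size=(n_samples, d))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    radii = eps * rng.random(n_samples) ** (1.0 / d)
    deltas = directions * radii[:, None]
    fooled = deltas @ normals.T > threshold  # shape: (n_samples, n_members)
    # The ensemble is fooled only inside the intersection (the shaded region).
    return float(fooled.all(axis=1).mean())


def unit(angle):
    """Unit gradient direction at the given angle (radians) in a 2-D toy space."""
    return np.array([np.cos(angle), np.sin(angle)])


single = unit(0.0)[None, :]
aligned = np.stack([unit(0.0), unit(0.2)])  # nearly identical gradient directions
diverse = np.stack([unit(0.0), unit(2.0)])  # strongly misaligned directions

for name, normals in [("single", single), ("aligned", aligned), ("diverse", diverse)]:
    print(f"{name:8s} fooling fraction: {fooling_fraction(normals):.3f}")
```

Because the same seed yields the same perturbation samples, the aligned pair's fooling set is by construction a subset of the single model's, and the diverse pair's intersection is much smaller still. The exact numbers depend on the arbitrary threshold and are not meant to reproduce any result in the paper.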

Fig. 1. A conceptual illustration of how model ensembling promotes adversarial robustness: (a) a single model; (b) an ensemble of two models; (c) an ensemble of two more diverse models. The shaded region denotes the subspace of perturbations that fool the (ensemble) model into making a wrong prediction. This figure is adapted from Fig. 1 of [21].

One challenge in applying model ensembling is how to promote diversity among member models while maintaining their prediction accuracy, particularly on non-adversarial inputs. In other words, how can ensemble training balance the trade-off between diversity and prediction quality? Several advanced methods have been proposed in the literature to address this issue [21–25]. For example, Pang et al. propose a diversity-regularized cost function that encourages diversity in the non-maximal class predictions given by the last softmax layers of the member models [22]. Kariyappa & Qureshi train an ensemble of gradient-misaligned models by minimizing their pairwise gradient similarities [21]. Yang et al. find that merely encouraging misalignment between pairwise gradients is insufficient to reduce adversarial transferability, and thus propose promoting both gradient misalignment and model smoothness [25]. Sen et al. propose training an ensemble of models with different numerical precisions, where models with higher numerical precision aim for prediction accuracy, while those with lower numerical precision target adversarial robustness [23].

As these examples show, applying a state-of-the-art (SOTA) ensemble training method requires selecting a diversity metric to measure the diversity of the member models. This metric may be the difference in the non-maximal predictions given by the last softmax layers of the member models [22], or the difference in the gradient vectors associated with the member models [21]. Notably, all these training methods use one-hot vectors to encode image classes, meaning that each image in the training set is assigned a hard label. Such hard labels cannot reflect dependency relationships among image classes, even though such relationships are likely to exist. For example, in a handwritten digit image dataset, the digit “0” may look more similar to “9” than to “4”, and “3” may look more similar to “8” than to “7”. Conditional on a noisy input image whose ground-truth label is “0”, there should therefore exist a dependency relationship between labels “0” and “9”, which hard labels ignore.

Motivated by this observation, we propose a novel ensemble training approach called Conditional Label Dependency Learning (CLDL) assisted ensemble training. Our approach jointly learns the member models and the conditional dependencies among image classes. Compared to existing methods, our approach adopts a different diversity metric that considers the differences in both the pairwise gradient vectors and the soft labels predicted by the member models. The learned soft labels encode dependency relationships among the original hard labels. We find that our approach is more robust against black-box attacks than several state-of-the-art methods.
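The difference between hard and soft labels can be illustrated numerically. The sketch below is hypothetical: the class-similarity scores are hand-set rather than learned, and the mixing rule softmax(α · one-hot + similarity) is a simple choice we made for illustration, only loosely inspired by label confusion learning [26]; it is not the CLDL procedure of the paper. It shows that under a hard label, a prediction that leaks probability to the visually similar class “9” and one that leaks the same amount to the dissimilar class “4” receive identical cross-entropy losses, whereas a soft label penalizes the second one more.

```python
import numpy as np

NUM_CLASSES = 10


def one_hot(label, num_classes=NUM_CLASSES):
    y = np.zeros(num_classes)
    y[label] = 1.0
    return y


def softmax(z):
    z = z - np.max(z)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()


def soft_label(label, similarity, alpha=4.0):
    """Hypothetical soft label: one-hot target mixed with class-similarity scores.

    A larger alpha keeps more mass on the true class. In CLDL the label
    dependencies are learned; here `similarity` is hand-set for illustration.
    """
    return softmax(alpha * one_hot(label) + similarity)


def cross_entropy(target, pred):
    return float(-np.sum(target * np.log(pred)))


# Hand-set scores for a noisy "0": it resembles "9" more than "4".
similarity_to_zero = np.zeros(NUM_CLASSES)
similarity_to_zero[9] = 2.0
similarity_to_zero[4] = -1.0

hard = one_hot(0)
soft = soft_label(0, similarity_to_zero)

# Two predictions that keep 0.7 on "0" and leak 0.22 to a different class.
leaks_to_9 = np.full(NUM_CLASSES, 0.01)
leaks_to_9[0], leaks_to_9[9] = 0.7, 0.22
leaks_to_4 = np.full(NUM_CLASSES, 0.01)
leaks_to_4[0], leaks_to_4[4] = 0.7, 0.22

print("soft label:", np.round(soft, 3))
print("hard CE:", cross_entropy(hard, leaks_to_9), cross_entropy(hard, leaks_to_4))
print("soft CE:", cross_entropy(soft, leaks_to_9), cross_entropy(soft, leaks_to_4))
```

The hard-label losses coincide because only the probability on the true class enters them; the soft-label losses differ because the target itself encodes that “0” and “9” are related.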
In summary, the main contributions of this work are:
– We propose a novel diversity metric for adversarial ensemble training that incorporates information from both the gradient vectors associated with the member models and the soft labels predicted by the member models.
– We adapt the label confusion model (LCM) of [26], originally developed to enhance text classification, to generate soft labels for images in the context of adversarial ensemble training.
– We propose a CLDL-assisted ensemble training algorithm and demonstrate that it complements existing ensemble training methods. In particular, we show that our algorithm is more robust against black-box attacks than several state-of-the-art methods.
The remainder of this paper is organized as follows. In Sect. 2, we present preliminary information. In Sect. 3, we describe our CLDL-based ensemble training method in detail. In Sect. 4, we present the experimental results. We conclude the paper in Sect. 5.


L. Wang and B. Liu

2 Preliminary
In this section, we present the preliminary knowledge related to our work.

2.1 Notations
We consider a DNN-based image recognition task involving C classes. Following [25], let X denote the input space of the DNN model and Y = {1, 2, ..., C} the class space. A DNN model is trained to yield a mapping function F : X → Y. Let x denote a clean image and x′ an adversarial counterpart of x. Let ε be a pre-set attack radius that defines the maximal magnitude of an adversarial perturbation; that is, the Lp norm of any adversarial perturbation must be no greater than ε. Let l_F(x, y) denote the cost function used for training the model F(x, θ), where θ denotes the parameters of the model.

2.2 Definitions
Definition 1. Adversarial Attack [25]. Given an input x ∈ X with true label y ∈ Y, F(x) = y. (1) An untargeted attack crafts A_U(x) = x + δ to maximize l_F(x + δ, y), where ‖δ‖_p ≤ ε. (2) A targeted attack with target label y_t ∈ Y crafts A_T(x) = x + δ to minimize l_F(x + δ, y_t), where ‖δ‖_p ≤ ε and ε is a pre-defined attack radius that limits the power of the attacker.

Definition 2. Alignment of Loss Gradients [21, 25]. The alignment of loss gradients between two differentiable loss functions l_F and l_G is defined as

$$\mathrm{CS}(\nabla_x l_F, \nabla_x l_G) = \frac{\nabla_x l_F(x, y) \cdot \nabla_x l_G(x, y)}{\|\nabla_x l_F(x, y)\|_2 \, \|\nabla_x l_G(x, y)\|_2} \qquad (1)$$

which is the cosine similarity between the gradients of the two loss functions for an input x drawn from X with any label y ∈ Y. If the cosine similarity of two gradients is −1, we say that they are completely misaligned.
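Eq. (1) is an ordinary cosine similarity between two flattened input gradients. The following is a minimal sketch, assuming the gradients have already been computed and flattened into plain Python lists; the function name `cosine_similarity` is ours, not from the paper.

```python
import math


def cosine_similarity(grad_f, grad_g):
    """CS(grad_f, grad_g) from Eq. (1): dot product over the product of L2 norms."""
    dot = sum(a * b for a, b in zip(grad_f, grad_g))
    norm_f = math.sqrt(sum(a * a for a in grad_f))
    norm_g = math.sqrt(sum(b * b for b in grad_g))
    if norm_f == 0.0 or norm_g == 0.0:
        raise ValueError("cosine similarity is undefined for a zero gradient")
    return dot / (norm_f * norm_g)


# Gradients pointing in opposite directions are completely misaligned (CS = -1);
# orthogonal gradients give CS = 0.
print(cosine_similarity([1.0, 0.0], [-3.0, 0.0]))  # -1.0
print(cosine_similarity([1.0, 0.0], [0.0, 2.0]))   # 0.0
```

In practice the two gradients would come from autograd on two member models for the same input x.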

2.3 Adversarial Attacks
Adversarial attacks aim to create human-imperceptible adversarial inputs that fool a high-performing DNN into making incorrect predictions. These attacks are typically divided into two basic classes. White-box attacks assume the adversary has full knowledge of the model's structure and parameters. Black-box attacks assume the adversary has no access to, or knowledge of, any information about the model. Here we briefly introduce the four typical white-box attacks used in our experiments. For more information on adversarial attacks, we refer readers to the review papers [27–29] and the references therein.

Fast Gradient Sign Method (FGSM). FGSM is a typical white-box method for finding adversarial examples. It performs a one-step update along the direction of the sign of the gradient of the adversarial loss. Specifically, it generates an adversarial example x′ by approximately solving the loss-maximization problem in a single step, as follows [9]:

$$x' = x + \varepsilon \cdot \mathrm{sign}(\nabla_x l(x, y)) \qquad (2)$$

where ε denotes the magnitude of the perturbation, x the original benign image sample, y the true label of x, and ∇x l(x, y) the gradient of the loss l(x, y) with respect to x.

Basic Iterative Method (BIM). BIM is an extension of FGSM that applies FGSM iteratively to generate an adversarial example, as follows [30]:



$$x'_i = \mathrm{clip}_{x,\varepsilon}\left(x'_{i-1} + \frac{\varepsilon}{r} \cdot \mathrm{sign}(g_{i-1})\right) \qquad (3)$$

where x′_0 ≜ x, r is the number of iterations, clip_{x,ε}(A) is a clipping function that projects A onto the ε-neighbourhood of x, and g_i ≜ ∇x l(x′_i, y).

Projected Gradient Descent (PGD). PGD [12] is almost the same as BIM. The only difference is that PGD initializes x′_0 as a random point in the ε-neighbourhood of x.

Momentum Iterative Method (MIM). MIM is an extension of BIM that updates the gradient g_i with momentum μ, as follows [31]:



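$$x'_i = \mathrm{clip}_{x,\varepsilon}\left(x'_{i-1} + \frac{\varepsilon}{r} \cdot \mathrm{sign}(g_i)\right), \quad \text{where} \quad g_i = \mu g_{i-1} + \frac{\nabla_x l(x'_{i-1}, y)}{\|\nabla_x l(x'_{i-1}, y)\|_1} \qquad (4)$$

and ‖·‖_1 denotes the L1 norm.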
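The four update rules in Eqs. (2)–(4) can be compared side by side on a model whose input gradient is known in closed form. The sketch below uses a hypothetical binary logistic-regression "model" with fixed weights `W` and bias `B`, chosen only to avoid autograd. It implements the L∞ ball projection clip_{x,ε} but omits clipping to a valid pixel range. It also assumes g_0 = 0 for MIM, which Eq. (4) does not state.

```python
import math
import random

# Hypothetical toy target: fixed binary logistic regression, chosen only because
# its input gradient has a closed form. Real attacks use a DNN and autograd.
W = [2.0, -1.0, 0.5]
B = 0.1


def _prob(x):
    return 1.0 / (1.0 + math.exp(-(sum(w * v for w, v in zip(W, x)) + B)))


def loss(x, y):
    """Binary cross-entropy l(x, y) of the toy model for a label y in {0, 1}."""
    p = _prob(x)
    return -math.log(p) if y == 1 else -math.log(1.0 - p)


def loss_grad(x, y):
    """Closed-form gradient of the loss w.r.t. x: (sigmoid(w.x + b) - y) * w."""
    p = _prob(x)
    return [(p - y) * w for w in W]


def sign(v):
    return [(g > 0) - (g < 0) for g in v]


def clip(x_adv, x, eps):
    """clip_{x,eps}: project x_adv onto the L-infinity eps-ball around x."""
    return [min(max(a, c - eps), c + eps) for a, c in zip(x_adv, x)]


def fgsm(x, y, eps):
    """Eq. (2): a single signed-gradient step of size eps."""
    return [c + eps * s for c, s in zip(x, sign(loss_grad(x, y)))]


def bim(x, y, eps, r=10, x0=None):
    """Eq. (3): r signed steps of size eps / r, each followed by clipping."""
    x_adv = list(x) if x0 is None else list(x0)
    for _ in range(r):
        g = loss_grad(x_adv, y)
        x_adv = clip([a + (eps / r) * s for a, s in zip(x_adv, sign(g))], x, eps)
    return x_adv


def pgd(x, y, eps, r=10, seed=0):
    """PGD: BIM started from a uniformly random point in the eps-neighbourhood."""
    rng = random.Random(seed)
    return bim(x, y, eps, r, x0=[c + rng.uniform(-eps, eps) for c in x])


def mim(x, y, eps, r=10, mu=1.0):
    """Eq. (4): BIM driven by an L1-normalized momentum term (g_0 = 0 assumed)."""
    x_adv, g = list(x), [0.0] * len(x)
    for _ in range(r):
        grad = loss_grad(x_adv, y)
        l1 = sum(abs(v) for v in grad) or 1.0
        g = [mu * gi + v / l1 for gi, v in zip(g, grad)]
        x_adv = clip([a + (eps / r) * s for a, s in zip(x_adv, sign(g))], x, eps)
    return x_adv


x, y, eps = [0.3, 0.2, -0.1], 1, 0.1
for name, attack in [("FGSM", fgsm), ("BIM", bim), ("PGD", pgd), ("MIM", mim)]:
    x_adv = attack(x, y, eps)
    print(name, round(loss(x, y), 4), "->", round(loss(x_adv, y), 4))
```

Every attack stays inside the ε-ball and increases the loss on the true label. The sketch follows the step size ε/r of Eqs. (3) and (4). The experiments in Sect. 4 instead used AdverTorch [43] with 10 iterations and a step size of ε/5.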

2.4 On Techniques to Generate Soft Labels
Label smoothing is perhaps the most commonly used technique for generating soft labels [4, 32–34]. Despite its simplicity, it has proven effective at improving the accuracy of deep learning predictions. For example, Szegedy et al. propose generating soft labels by taking a weighted average of the one-hot vector and a uniform distribution over labels [4]. Fu et al. study the effect of label smoothing on adversarial training and find that adversarial training with the aid of label smoothing can enhance model robustness against gradient-based attacks [35]. Wang et al. propose an adaptive label smoothing approach capable of adaptively estimating a target label distribution [36]. Guo et al. propose a label confusion model (LCM) to improve text classification performance [26]. We adapt it here to learn conditional label dependencies and to generate soft labels for adversarial ensemble training.
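For contrast with the learned soft labels introduced in Sect. 3, here is a minimal sketch of uniform label smoothing in the style of [4]. It assumes integer class labels and a smoothing weight `eps`, which is a name of our own and unrelated to the attack radius ε. Because the extra mass is spread uniformly, this kind of soft label cannot express that “0” is closer to “9” than to “4”.

```python
def smooth_labels(label, num_classes, eps=0.1):
    """Label smoothing in the style of [4]: (1 - eps) * one-hot + eps * uniform."""
    uniform = eps / num_classes
    return [(1.0 - eps) + uniform if c == label else uniform for c in range(num_classes)]


soft = smooth_labels(label=2, num_classes=4, eps=0.2)
print(soft)       # the true class keeps most of the mass; every other class gets eps / C
print(sum(soft))  # sums to 1 up to floating-point rounding
```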

3 CLDL Assisted Ensemble Training
Here we describe our CLDL-based ensemble training approach in detail. We present a pseudo-code implementation of our approach in Algorithm 1 and a conceptual illustration in Fig. 2. For ease of presentation, we use an ensemble consisting of N = 2 member models as an example.


Our model is mainly composed of two parts: an ensemble of N sub-models {F_i(x, θ_i)}_{i∈[N]}, where [N] = {1, 2, ..., N}, and a label confusion model (LCM) adapted from [26]. Each sub-model in the ensemble consists of an input convolutional neural network (CNN) encoder followed by a fully connected classifier, and can be any mainstream DNN-based image classifier. As shown in Fig. 2, an image instance x is fed into the input encoder, which generates an image representation v_i, where i is the sub-model index. Then v_i is fed into the fully connected classifier to predict the label distribution of this image. These operations can be formulated as follows:

$$v_i = F_i^{\mathrm{encoder}}(x), \qquad p_i = \mathrm{softmax}(W v_i + b) \qquad (5)$$

where F_i^encoder(·) is the input encoder of F_i, which transforms x into v_i, and W and b are the weights and bias of the fully connected layer that transforms v_i into the predicted label distribution (PLD) p_i. The LCM consists of two parts: a label encoder and a simulated label distribution (SLD) computation block. The label encoder is a deep neural network used to generate the label representation matrix [26]. The SLD computation block comprises a similarity layer and an SLD computation layer. The similarity layer takes the label representation matrix and the current instance's representation as inputs. It computes the dot-product similarities between the image instance and each image class label, and then feeds the similarity values into a softmax activation layer that outputs the label confusion vector (LCV). The LCV captures the conditional dependencies among the image classes through the computed similarities between the instance representation and the label representations. The LCV is instance-dependent, meaning that it considers the dependency relationships among all image class labels conditional on a specific image instance. In the subsequent SLD computation layer, the one-hot hard label y_i, weighted by a controlling parameter γ, is added to the LCV, and the sum is normalized by a softmax function to produce the SLD. The controlling parameter γ decides how much the LCV changes the one-hot vector: the larger γ is, the closer the SLD stays to the hard label. These operations can be formulated as follows:

$$\mathrm{Vec}^{(l)} = f^{L}(l) = f^{L}([l_1, l_2, \ldots, l_C]) = [\mathrm{Vec}^{(l)}_1, \mathrm{Vec}^{(l)}_2, \ldots, \mathrm{Vec}^{(l)}_C]$$
$$c_i = \mathrm{softmax}(v_i \mathrm{Vec}^{(l)} W + b)$$
$$s_i = \mathrm{softmax}(\gamma y_i + c_i) \qquad (6)$$

where f^L is the label encoder function that maps the labels l = [l_1, l_2, ..., l_C] to the label representation matrix Vec^(l). It is implemented by an embedding lookup layer followed by a DNN. C is the number of image classes, c_i is the LCV, and s_i is the SLD. The SLD is then used as a soft label that replaces the original hard label for model training.
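The SLD computation in Eq. (6) can be traced with plain arithmetic. The following is a minimal sketch under stated assumptions. The label representations `label_vecs` are given directly, standing in for the learned label encoder f^L (embedding lookup plus DNN). The optional `W` and `b` default to an identity map and a zero bias. All names and numbers are hypothetical.

```python
import math


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]


def simulated_label_distribution(v, label_vecs, y, gamma, W=None, b=None):
    """Sketch of Eq. (6); returns (LCV c_i, SLD s_i).

    v          -- instance representation v_i (length d)
    label_vecs -- C label representations Vec_1..Vec_C (each length d), assumed given
    y          -- index of the true class, turned into a one-hot vector here
    W, b       -- optional C x C weights and length-C bias (identity / zero if omitted)
    """
    C = len(label_vecs)
    sims = [sum(a * l for a, l in zip(v, vec)) for vec in label_vecs]  # v_i . Vec_c
    if W is not None:
        sims = [sum(sims[j] * W[j][k] for j in range(C)) for k in range(C)]
    if b is not None:
        sims = [s + bk for s, bk in zip(sims, b)]
    lcv = softmax(sims)                                            # label confusion vector c_i
    one_hot = [1.0 if c == y else 0.0 for c in range(C)]
    sld = softmax([gamma * o + c for o, c in zip(one_hot, lcv)])   # simulated label distribution s_i
    return lcv, sld


# Hypothetical 3-class example in which class 1's representation is close to class 0's.
v = [1.0, 0.0]
label_vecs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
lcv, sld = simulated_label_distribution(v, label_vecs, y=0, gamma=4.0)
print([round(p, 3) for p in sld])
```

The true class dominates the SLD, but the class whose representation is similar to the instance (class 1) receives more mass than the dissimilar one (class 2). Increasing `gamma` pushes the SLD back toward the one-hot label.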

Algorithm 1 CLDL assisted ensemble training for an ensemble of sub-models

Note that the SLD s_i and the PLD p_i are both probability distributions over the C classes. We use the Kullback-Leibler (KL) divergence [37] to measure their difference:

$$l_i(x) = \mathrm{KL}(s_i, p_i) = \sum_{c=1}^{C} s_i^c \log \frac{s_i^c}{p_i^c} \qquad (7)$$


The LCM is trained by minimizing the above KL divergence, whose value depends on both the semantic representation v_i of the image instance (through p_i) and the soft label s_i given by the LCM.
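As a concrete reading of Eq. (7), the sketch below computes the per-member KL loss between a soft label s_i and a prediction p_i, both given as plain lists. The floor `tiny` on p_c is our own numerical safeguard and does not appear in the paper. Both example distributions are made up.

```python
import math


def kl_divergence(s, p, tiny=1e-12):
    """Eq. (7): KL(s, p) = sum_c s_c * log(s_c / p_c).

    Terms with s_c = 0 contribute 0; p_c is floored at `tiny` (our assumption)
    to avoid division by zero.
    """
    return sum(sc * math.log(sc / max(pc, tiny)) for sc, pc in zip(s, p) if sc > 0)


sld = [0.9, 0.07, 0.03]   # hypothetical soft label s_i
pld = [0.6, 0.3, 0.1]     # hypothetical prediction p_i
print(kl_divergence(sld, pld))  # positive, since p_i differs from s_i
print(kl_divergence(sld, sld))  # 0.0
```

In a deep learning framework the same quantity is usually computed on batched tensors, with `p` supplied as log-probabilities for numerical stability.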

Fig. 2. The proposed CLDL assisted ensemble training method. For ease of illustration, we take an ensemble consisting of two member models as an example. Given an image instance x, the label confusion model (LCM) in the middle, which is adapted from [26], generates a soft label s_i for the i-th sub-model. Two diversity regularizers, one based on the soft label distributions and the other on the gradients, are combined into the final ensemble diversity regularizer. See Sect. 3 for more details. The LCM module is drawn after Fig. 1 of [26]. Note that the model of [26] is used for text classification, whereas here it is adapted to an ensemble model for image classification under adversarial attacks.


3.1 Diversity Promoting Loss Design
Here we present our design of the diversity promoting loss used in our approach.

Soft Label Diversity: For an input (x, y) in the training dataset, we define the soft label diversity based on the non-maximal values of the SLD of each sub-model. Specifically, let s_i^{\y} be a (C − 1) × 1 vector constructed by excluding the maximal value from the SLD corresponding to model F_i(x, θ_i). We then use the Jensen-Shannon divergence (JSD) [38] to measure the difference between a pair of member models, say the i-th and the j-th, in terms of their predicted soft labels, as follows:

$$\mathrm{JSD}(s_i^{\setminus y} \,\|\, s_j^{\setminus y}) = \frac{1}{2}\left(\mathrm{KL}\left(s_i^{\setminus y}, \frac{s_i^{\setminus y} + s_j^{\setminus y}}{2}\right) + \mathrm{KL}\left(s_j^{\setminus y}, \frac{s_i^{\setminus y} + s_j^{\setminus y}}{2}\right)\right) \qquad (8)$$

From Eqs. (6) and (8), we see that JSD(s_i^{\y} ‖ s_j^{\y}) monotonically increases with JSD(c_i^{\y} ‖ c_j^{\y}). A large JSD indicates a misalignment between the SLDs of the two involved sub-models given the image instance x. Given x and an ensemble of N models, we define a loss item as follows:

$$l_{ld}(x) = \log\left(\frac{2}{N(N-1)} \sum_{i=1}^{N} \sum_{j=i+1}^{N} \exp\left(\mathrm{JSD}(s_i^{\setminus y} \,\|\, s_j^{\setminus y})\right)\right) \qquad (9)$$

which will be included in the final loss function, Eq. (12), used for training the model ensemble. It plays a vital role in promoting the diversity of the member models' predicted soft labels for any input instance. Note that, following [22], we only consider the non-maximal values of the SLDs in Eq. (8). In this way, promoting diversity among the sub-models does not directly interfere with the ensemble's prediction accuracy on benign inputs, while it can still lower the transferability of attacks among the sub-models.

Gradient Diversity: Following [21, 25], we consider the sub-models' diversity in terms of their associated gradients. Given an image instance x and an ensemble of N models, we define the gradient diversity loss item as follows:

$$l_{gd}(x) = \frac{2}{N(N-1)} \sum_{i=1}^{N} \sum_{j=i+1}^{N} \mathrm{CS}(\nabla_x l_i, \nabla_x l_j) = \frac{2}{N(N-1)} \sum_{i=1}^{N} \sum_{j=i+1}^{N} \frac{(\nabla_x l_i)^{\mathsf{T}} (\nabla_x l_j)}{\|\nabla_x l_i\|_2 \, \|\nabla_x l_j\|_2} \qquad (10)$$

which will also be included in the final loss function, Eq. (12).

The Combined Diversity Promoting Loss: Combining the above soft label and gradient diversity loss items, we propose our CLDL-based ensemble diversity loss function. For a pair of member models F and G, given an input instance x, this diversity promoting loss function is

$$l_{F,G,x} = -\alpha\, l_{ld}(x) + \beta\, l_{gd}(x) \qquad (11)$$

where α, β ≥ 0 are hyper-parameters that balance the effects of the soft label based and the gradient based diversity loss items. Because of the minus sign, minimizing Eq. (11) increases the soft label diversity l_ld while decreasing the average gradient alignment l_gd.
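Eqs. (8)–(11) can be evaluated on small hand-made inputs to see how the two diversity terms behave. This is a minimal sketch, not the authors' implementation. In particular, the text does not say whether s_i^{\y} is renormalized after the maximal entry is removed. Here we assume it is, so that the JSD compares two probability vectors. The SLDs and gradients below are hypothetical.

```python
import math
from itertools import combinations


def kl(p, q):
    return sum(a * math.log(a / b) for a, b in zip(p, q) if a > 0)


def drop_max(s):
    """Drop the maximal SLD entry (Sect. 3.1); renormalizing the rest is our assumption."""
    k = max(range(len(s)), key=s.__getitem__)
    rest = [v for c, v in enumerate(s) if c != k]
    total = sum(rest)
    return [v / total for v in rest]


def jsd(p, q):
    """Eq. (8): Jensen-Shannon divergence via KL to the midpoint distribution."""
    m = [(a + b) / 2 for a, b in zip(p, q)]
    return 0.5 * (kl(p, m) + kl(q, m))


def cos_sim(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def soft_label_diversity(slds):
    """Eq. (9): log of the mean of exp(JSD) over all N(N-1)/2 member pairs."""
    pairs = list(combinations([drop_max(s) for s in slds], 2))
    return math.log(sum(math.exp(jsd(a, b)) for a, b in pairs) / len(pairs))


def gradient_diversity(grads):
    """Eq. (10): mean pairwise cosine similarity of the members' input gradients."""
    pairs = list(combinations(grads, 2))
    return sum(cos_sim(a, b) for a, b in pairs) / len(pairs)


def diversity_loss(slds, grads, alpha, beta):
    """Eq. (11): -alpha * l_ld + beta * l_gd (lower means more diverse)."""
    return -alpha * soft_label_diversity(slds) + beta * gradient_diversity(grads)


# Hypothetical 3-member ensemble, 4 classes, 2-dimensional input gradients.
slds = [[0.70, 0.20, 0.05, 0.05], [0.70, 0.05, 0.20, 0.05], [0.70, 0.05, 0.05, 0.20]]
grads = [[1.0, 0.0], [-0.5, 0.8], [-0.5, -0.8]]
print(soft_label_diversity(slds), gradient_diversity(grads))
print(diversity_loss(slds, grads, alpha=2.0, beta=4.0))
```

Identical members give l_ld = 0 and l_gd = 1, the least diverse case. The example members disagree on their non-maximal classes and have misaligned gradients, so the combined loss is negative.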


3.2 CLDL Based Ensemble Model Training
We train our ensemble model by minimizing the training loss function defined as follows:

$$\mathrm{Loss}(x) = l_E(x) + l_{F,G,x} = \frac{1}{N}\sum_{i=1}^{N} l_i - \alpha\, l_{ld}(x) + \beta\, l_{gd}(x) \qquad (12)$$

where l_E(x) = (1/N) Σ_{i=1}^{N} l_i is the average of the KL-divergence losses of the member models, as defined in Eq. (7) (see also Fig. 2). By minimizing this loss function, we simultaneously do three things: learn the soft labels for each sub-model, promote diversity among the sub-models in terms of their predicted soft labels and their gradients, and minimize the KL-divergence loss of each sub-model.
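Eq. (12) is a simple combination of per-member and ensemble-level terms. The sketch below shows only the arithmetic, on hypothetical scalar values. In actual training, all terms are differentiable tensors, and the objective is minimized jointly over the member models and the LCM.

```python
def cldl_objective(member_kl_losses, l_ld, l_gd, alpha, beta):
    """Eq. (12): l_E - alpha * l_ld + beta * l_gd, where l_E is the mean member KL loss (Eq. (7))."""
    l_e = sum(member_kl_losses) / len(member_kl_losses)
    return l_e - alpha * l_ld + beta * l_gd


# Hypothetical values for a 3-member ensemble with alpha = 2 and beta = 4 (as in CLDL4,2,4).
base = cldl_objective([0.30, 0.50, 0.40], l_ld=0.05, l_gd=0.20, alpha=2.0, beta=4.0)
more_diverse = cldl_objective([0.30, 0.50, 0.40], l_ld=0.10, l_gd=-0.20, alpha=2.0, beta=4.0)
print(base, more_diverse)  # the more diverse ensemble attains the lower objective
```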

4 Experiments
4.1 Datasets and Competitor Methods
We conducted our experiments on the widely used image datasets MNIST [39], Fashion-MNIST (F-MNIST) [40], and CIFAR-10 [41]. For each dataset, we used its training set for ensemble training. We set the hyper-parameters of our algorithm based on 1,000 images randomly sampled from the test set and used the remaining test data for performance evaluation. We compared our algorithm with several competitor methods. These include a baseline method that trains the model ensemble without any defense mechanism and four popular ensemble training methods: the adaptive diversity promoting (ADP) algorithm [22], the gradient alignment loss (GAL) method [21], the diversifying vulnerabilities for enhanced robust generation of ensembles (DVERGE) method [44], and the transferability reduced smooth (TRS) method [25]. We used ResNet-20 [1] as the sub-model architecture for CIFAR-10 and LeNet-5 [45] for MNIST and F-MNIST (see Tables 4–7). We averaged the output probabilities given by the softmax layers of the member models to yield the final prediction.

4.2 Optimizer Used for Training
We used Adam [42] as the optimizer for ensemble training, with an initial learning rate of 10^−3 and a weight decay of 10^−4. For our CLDL-based approach, we trained the ensemble for 200 epochs and multiplied the learning rate by 0.1 at the 100th and 150th epochs. We set the batch size to 128 and used normalization, random cropping, and flipping-based data augmentation for CIFAR-10. We considered two ensemble sizes, 3 and 5, in our experiments. For a fair comparison, we trained ADP, GAL, DVERGE, and TRS under a similar training setup. We used the AdverTorch [43] library to simulate adversarial attacks.
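The learning-rate schedule of Sect. 4.2 is a standard step decay. Below is a minimal sketch, assuming 0-based epoch indices, so the first decay takes effect from index 100 onward. The text does not say how epochs are counted.

```python
def learning_rate(epoch, base_lr=1e-3, milestones=(100, 150), factor=0.1):
    """Step schedule of Sect. 4.2: start at 1e-3 and multiply by 0.1 at epochs 100 and 150.

    `epoch` is 0-based here; whether the paper counts from 0 or 1 is our assumption.
    """
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= factor
    return lr


for e in (0, 99, 100, 149, 150, 199):
    print(e, learning_rate(e))
```

In PyTorch, one plausible equivalent is `torch.optim.Adam(params, lr=1e-3, weight_decay=1e-4)` combined with `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)`, with the scheduler stepped once per epoch. Here `gamma` is the decay factor, not the γ of Eq. (6). This is our reading of the text, not the authors' released code.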


4.3 White-Box Attacks
We used four basic white-box adversarial attacks, namely FGSM, BIM, MIM, and PGD, as the building blocks of the black-box attacks simulated in our experiments. For each attack type, we considered four perturbation scales (ε), ranging from 0.01 to 0.04 for CIFAR-10 (Tables 4–7 use larger scales for MNIST and F-MNIST). We set the number of attack iterations to 10 and the step size to ε/5 for BIM, MIM, and PGD. Each experiment was run five times independently, and the results were averaged for performance comparison. We simulated the white-box attacks by treating the whole ensemble, rather than an individual sub-model, as the target model to be attacked.

4.4 Black-Box Attacks
We considered black-box attacks, in which the attacker has no knowledge of the target model, including its architecture and parameters, and designs adversarial examples based on surrogate models. We simulated black-box attacks on our ensemble model by generating white-box adversarial attacks against a surrogate ensemble model. The surrogate has the same architecture as the true target ensemble and is trained on the same dataset with the same training routine. We trained the surrogate ensemble, consisting of 3 or 5 sub-models, by minimizing a standard ensemble cross-entropy loss function. For each attack type mentioned above, we evaluated the robustness of the involved training methods under black-box attacks with four perturbation scales (ε): 0.01, 0.02, 0.03, and 0.04 for CIFAR-10. We set the number of attack iterations to 10 and the step size to ε/5 for BIM-, MIM-, and PGD-based attacks. Following [44], we generated adversarial examples using the cross-entropy loss and the CW loss [11].

4.5 Experimental Results for Black-box Adversarial Attacks
In our experiments, we used classification accuracy as the performance metric. It is defined as the ratio of the number of correctly predicted adversarial instances to the total number of adversarial instances.
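The evaluation metric combines two steps described above: the ensemble prediction is the arg-max of the averaged member softmax outputs (Sect. 4.1), and accuracy is the fraction of adversarial inputs predicted correctly. The sketch below assumes the member probabilities are already available as nested lists. The numbers are made up.

```python
def ensemble_predict(member_probs):
    """Average the members' softmax outputs (Sect. 4.1) and return the arg-max class."""
    n = len(member_probs)
    avg = [sum(col) / n for col in zip(*member_probs)]
    return max(range(len(avg)), key=avg.__getitem__)


def classification_accuracy(batch, labels):
    """Sect. 4.5 metric: correctly predicted adversarial inputs / all adversarial inputs."""
    correct = sum(ensemble_predict(probs) == y for probs, y in zip(batch, labels))
    return correct / len(labels)


# Hypothetical softmax outputs of a 3-member ensemble on two adversarial inputs.
batch = [
    [[0.6, 0.3, 0.1], [0.2, 0.7, 0.1], [0.7, 0.2, 0.1]],  # averaged: class 0 wins
    [[0.1, 0.2, 0.7], [0.3, 0.6, 0.1], [0.2, 0.6, 0.2]],  # averaged: class 1 wins
]
labels = [0, 2]
print(classification_accuracy(batch, labels))  # 0.5
```

Averaging probabilities lets a confident majority outvote a single misled member, as in the first input, where the second member alone would have predicted class 1.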
We conducted random re-trainings of the model in our experiments, and the reported values are averages of more than three independent runs. Our code is open-sourced to support the reproducibility of these results.

CIFAR-10. We present our experimental results on CIFAR-10 in Tables 1 and 2. In all tables below, CLDLa,b,c denotes our CLDL-based algorithm with the hyper-parameters γ in Eq. (6), and α and β in Eq. (12), set to a, b, and c, respectively. ε denotes the perturbation scale of the attack. As shown, our CLDL-based algorithm performs best for almost all attacks considered, especially when the perturbation scale is large.


We also investigate the effects of the soft label diversity based loss and the gradient diversity based loss on the performance of our algorithm; see Table 3. As shown, CLDL4,2,4 gives the best results in nearly all settings. Comparing CLDL4,2,4 with CLDL4,0.5,4 reveals a performance gain from the soft label diversity based loss. Comparing CLDL4,2,4 with CLDL4,2,0 confirms the contribution of the gradient diversity based loss.

MNIST. Tables 4 and 5 show the classification accuracy (%) for ensembles of 3 and 5 LeNet-5 member models [45] on the MNIST dataset. Again, our algorithm outperforms its competitors in almost all settings, often by a wide margin at larger perturbation scales.

F-MNIST. Tables 6 and 7 present the results on the F-MNIST dataset. For the 3-member ensemble (Table 6), our CLDL algorithm ranks first in 10 of the 16 attack settings, while TRS and GAL rank first in 5 and 1 settings, respectively. For the 5-member ensemble (Table 7), CLDL ranks first in 14 of the 16 settings and TRS in the remaining 2.

Table 1. Classification accuracy (%) on the CIFAR-10 dataset for four types of black-box attacks. The ensemble consists of 3 ResNet-20 member models. ε refers to the perturbation scale of the attack.



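| Attack | ε | ADP | GAL | DVERGE | TRS | CLDL4,2,4 | CLDL4,0.5,4 |
|---|---|---|---|---|---|---|---|
| BIM | 0.01 | 45.43 | 92.61 | 89.66 | 87.86 | 92.77 | 92.77 |
| BIM | 0.02 | 9.74 | 83.59 | 75.41 | 71.17 | 85.21 | 83.66 |
| BIM | 0.03 | 2.01 | 73.02 | 59.24 | 53.74 | 75.76 | 72.97 |
| BIM | 0.04 | 0.47 | 62.32 | 42.27 | 37.87 | 66.47 | 62.38 |
| FGSM | 0.01 | 67.03 | 93.47 | 91.20 | 90.37 | 93.49 | 93.47 |
| FGSM | 0.02 | 40.58 | 84.62 | 78.68 | 77.59 | 86.21 | 85.48 |
| FGSM | 0.03 | 26.41 | 75.93 | 65.37 | 64.12 | 77.94 | 76.04 |
| FGSM | 0.04 | 18.25 | 66.15 | 51.40 | 51.04 | 69.93 | 66.95 |
| MIM | 0.01 | 41.65 | 91.11 | 87.52 | 85.39 | 91.29 | 91.08 |
| MIM | 0.02 | 7.37 | 76.92 | 65.85 | 61.80 | 78.98 | 77.01 |
| MIM | 0.03 | 1.13 | 59.75 | 40.35 | 38.01 | 63.84 | 59.89 |
| MIM | 0.04 | 0.30 | 43.48 | 19.77 | 20.24 | 48.27 | 43.06 |
| PGD | 0.01 | 46.20 | 92.41 | 89.36 | 88.06 | 92.18 | 92.25 |
| PGD | 0.02 | 9.23 | 83.46 | 76.08 | 72.07 | 84.81 | 83.68 |
| PGD | 0.03 | 1.55 | 74.67 | 61.28 | 55.12 | 77.34 | 75.04 |
| PGD | 0.04 | 0.35 | 65.46 | 44.72 | 39.23 | 70.64 | 67.09 |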


Table 2. Classification accuracy (%) on the CIFAR-10 dataset for four types of black-box attacks. The ensemble consists of 5 ResNet-20 member models.

| Attack | ε | ADP | GAL | DVERGE | TRS | CLDL4,2,4 | CLDL4,0.5,4 |
|---|---|---|---|---|---|---|---|
| BIM | 0.01 | 45.28 | 92.22 | 93.32 | 91.17 | 93.73 | 92.89 |
| BIM | 0.02 | 9.50 | 81.69 | 85.27 | 80.15 | 85.71 | 84.58 |
| BIM | 0.03 | 1.79 | 69.68 | 75.72 | 67.93 | 75.93 | 75.56 |
| BIM | 0.04 | 0.44 | 57.21 | 64.59 | 55.95 | 65.39 | 65.8 |
| FGSM | 0.01 | 68.49 | 93.33 | 94.22 | 92.65 | 94.47 | 93.54 |
| FGSM | 0.02 | 43.02 | 84.13 | 86.79 | 83.30 | 87.06 | 85.92 |
| FGSM | 0.03 | 27.76 | 74.15 | 77.56 | 72.24 | 78.79 | 78.3 |
| FGSM | 0.04 | 19.36 | 64.32 | 67.10 | 61.08 | 69.93 | 69.68 |
| MIM | 0.01 | 41.77 | 90.50 | 91.92 | 89.53 | 92.17 | 91.4 |
| MIM | 0.02 | 7.42 | 73.95 | 79.69 | 73.33 | 79.33 | 78.67 |
| MIM | 0.03 | 1.16 | 55.09 | 62.87 | 54.63 | 63.24 | 62.93 |
| MIM | 0.04 | 0.25 | 37.87 | 44.19 | 38.76 | 47.16 | 46.62 |
| PGD | 0.01 | 46.50 | 92.22 | 93.34 | 91.29 | 92.92 | 92.02 |
| PGD | 0.02 | 9.43 | 81.85 | 85.60 | 80.24 | 85.92 | 84.45 |
| PGD | 0.03 | 1.54 | 71.89 | 77.59 | 68.25 | 78.07 | 76.87 |
| PGD | 0.04 | 0.29 | 61.60 | 68.53 | 55.45 | 70.05 | 69.78 |

Table 3. Classification accuracy (%) given by an ensemble model consisting of 3 ResNet-20 member models trained with our CLDL based algorithm under different hyper-parameter settings against black-box attacks on the CIFAR-10 dataset.

| Attack | ε | CLDL4,0,0 | CLDL4,1,0 | CLDL4,2,0 | CLDL4,4,0 | CLDL4,1,2 | CLDL4,2,2 | CLDL4,0.5,4 | CLDL4,2,4 |
|---|---|---|---|---|---|---|---|---|---|
| BIM | 0.01 | 89.97 | 89.66 | 89.82 | 89.78 | 92.76 | 92.62 | 92.77 | 92.77 |
| BIM | 0.02 | 74.01 | 74.18 | 74.2 | 74.42 | 83.43 | 82.97 | 83.66 | 85.21 |
| BIM | 0.03 | 57.1 | 56.62 | 56.78 | 56.92 | 73.02 | 72.64 | 72.97 | 75.76 |
| BIM | 0.04 | 40.11 | 39.63 | 40.63 | 40.53 | 61.8 | 61.69 | 62.38 | 66.47 |
| FGSM | 0.01 | 91.53 | 91.07 | 91.01 | 91.12 | 93.5 | 93.2 | 93.47 | 93.49 |
| FGSM | 0.02 | 78.64 | 78.63 | 78.78 | 78.51 | 84.79 | 84.62 | 85.48 | 86.21 |
| FGSM | 0.03 | 64.83 | 64.37 | 65.22 | 64.37 | 76.05 | 76.24 | 76.04 | 77.94 |
| FGSM | 0.04 | 52.36 | 51.87 | 53.11 | 52.16 | 66.55 | 67.42 | 66.95 | 69.93 |
| MIM | 0.01 | 87.64 | 87.18 | 87.39 | 87.26 | 91.22 | 90.73 | 91.08 | 91.29 |
| MIM | 0.02 | 63.63 | 63.24 | 63.79 | 63.86 | 76.16 | 76.3 | 77.01 | 78.98 |
| MIM | 0.03 | 39.03 | 38.62 | 39.29 | 38.82 | 58.9 | 59.25 | 59.89 | 63.84 |
| MIM | 0.04 | 20.84 | 20.29 | 21.5 | 21.42 | 41.79 | 43.07 | 43.06 | 48.27 |
| PGD | 0.01 | 90.24 | 89.74 | 89.76 | 89.75 | 92.26 | 92.09 | 92.25 | 92.18 |
| PGD | 0.02 | 75.54 | 75.66 | 75.48 | 75.69 | 83.25 | 82.81 | 83.68 | 84.81 |
| PGD | 0.03 | 60.63 | 60.03 | 60.55 | 60.13 | 74.53 | 74.45 | 75.04 | 77.34 |
| PGD | 0.04 | 46.97 | 46.2 | 46.68 | 46.41 | 66.03 | 65.95 | 67.09 | 70.64 |


Table 4. Classification accuracy (%) of ensembles of 3 LeNet-5 member models, trained with each method, against black-box attacks on the MNIST dataset.

| Attack | ε | ADP | GAL | DVERGE | TRS | CLDL3,4,4 | CLDL3,2,1 |
|---|---|---|---|---|---|---|---|
| BIM | 0.1 | 90.18 | 87.34 | 90.24 | 92.5 | 94.48 | 94.22 |
| BIM | 0.15 | 60.38 | 55.61 | 61.21 | 76.63 | 85.64 | 81.42 |
| BIM | 0.2 | 23.23 | 28.5 | 21.17 | 46.16 | 65.11 | 51.59 |
| BIM | 0.25 | 5.32 | 11.06 | 2.53 | 17.42 | 32.81 | 22.95 |
| FGSM | 0.1 | 93.29 | 90.77 | 93.51 | 94.55 | 95.43 | 95.5 |
| FGSM | 0.15 | 79.58 | 69.94 | 80.75 | 86.21 | 89.82 | 88.56 |
| FGSM | 0.2 | 52.99 | 47.35 | 55.98 | 70.64 | 78.39 | 72.73 |
| FGSM | 0.25 | 27.38 | 30.13 | 27.78 | 48.1 | 57.51 | 46.77 |
| MIM | 0.1 | 90.21 | 85.82 | 90.52 | 92.31 | 94.25 | 94.07 |
| MIM | 0.15 | 63.05 | 53.72 | 63.67 | 76.81 | 85.31 | 81.56 |
| MIM | 0.2 | 24.69 | 27.49 | 23.88 | 46.83 | 65.1 | 51.33 |
| MIM | 0.25 | 5.58 | 10.64 | 3.16 | 16.58 | 30.44 | 21.34 |
| PGD | 0.1 | 89.66 | 84.87 | 89.82 | 91.91 | 93.75 | 93.84 |
| PGD | 0.15 | 56.83 | 47.69 | 57.86 | 73.01 | 83.5 | 78.92 |
| PGD | 0.2 | 19.42 | 21.69 | 17.32 | 39.19 | 58.94 | 46.91 |
| PGD | 0.25 | 3.06 | 6.44 | 1.05 | 11.71 | 24.81 | 16.65 |

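Table 5. Classification accuracy (%) of an ensemble of 5 LeNet-5 models against black-box attacks on the MNIST dataset.

| Attack | ε | ADP | GAL | DVERGE | TRS | CLDL3,4,4 | CLDL3,2,1 |
|---|---|---|---|---|---|---|---|
| BIM | 0.1 | 88.43 | 90.26 | 89.49 | 94.01 | 95.01 | 93.98 |
| BIM | 0.15 | 53.68 | 66.19 | 60.83 | 82.77 | 87.18 | 78.76 |
| BIM | 0.2 | 18.98 | 34.1 | 23.04 | 55.12 | 68.76 | 51.46 |
| BIM | 0.25 | 2.21 | 12.02 | 4.01 | 22.2 | 40.58 | 25.3 |
| FGSM | 0.1 | 92.23 | 93.03 | 92.88 | 95.41 | 95.9 | 95.62 |
| FGSM | 0.15 | 75.07 | 79.74 | 80.23 | 88.85 | 89.86 | 88.43 |
| FGSM | 0.2 | 46.26 | 58.21 | 54.71 | 76.11 | 76.95 | 71.07 |
| FGSM | 0.25 | 22.41 | 35.67 | 29.5 | 54.6 | 52.6 | 44.93 |
| MIM | 0.1 | 89.06 | 89.62 | 89.84 | 93.96 | 94.68 | 93.96 |
| MIM | 0.15 | 56.8 | 66.79 | 63.58 | 82.61 | 85.96 | 78.68 |
| MIM | 0.2 | 21.1 | 35.18 | 26.19 | 56.63 | 65.32 | 50.11 |
| MIM | 0.25 | 3.02 | 12.37 | 5.33 | 22.55 | 34.31 | 22.36 |
| PGD | 0.1 | 87.83 | 88.69 | 89.11 | 93.53 | 94.4 | 93.49 |
| PGD | 0.15 | 50.51 | 60.26 | 57.6 | 81.19 | 85.08 | 74.77 |
| PGD | 0.2 | 15.33 | 27.47 | 19.45 | 50.15 | 63.35 | 45.58 |
| PGD | 0.25 | 0.97 | 6.64 | 2.19 | 17.06 | 31.88 | 18.38 |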


Table 6. Classification accuracy (%) of an ensemble of 3 LeNet-5 models against black-box attacks on the F-MNIST dataset.

| Attack | ε | ADP | GAL | DVERGE | TRS | CLDL3,2,4 |
|---|---|---|---|---|---|---|
| BIM | 0.08 | 38.38 | 54.43 | 39.42 | 54.71 | 54.19 |
| BIM | 0.1 | 28.39 | 44.37 | 28.16 | 44.61 | 44.99 |
| BIM | 0.15 | 10.55 | 22.63 | 10.58 | 25.06 | 27.84 |
| BIM | 0.2 | 1.78 | 7.10 | 2.69 | 10.95 | 13.98 |
| FGSM | 0.08 | 48.25 | 62.77 | 52.39 | 62.83 | 61.99 |
| FGSM | 0.1 | 40.09 | 54.22 | 42.80 | 53.64 | 53.87 |
| FGSM | 0.15 | 23.24 | 36.09 | 25.41 | 37.63 | 38.49 |
| FGSM | 0.2 | 9.82 | 15.85 | 13.33 | 23.45 | 25.67 |
| MIM | 0.08 | 38.24 | 52.83 | 39.58 | 53.46 | 52.88 |
| MIM | 0.1 | 28.41 | 42.44 | 28.19 | 43.10 | 43.41 |
| MIM | 0.15 | 9.17 | 19.61 | 9.98 | 23.01 | 25.20 |
| MIM | 0.2 | 1.01 | 3.44 | 2.33 | 7.84 | 9.75 |
| PGD | 0.08 | 37.74 | 52.34 | 39.17 | 52.86 | 51.65 |
| PGD | 0.1 | 27.75 | 41.71 | 28.15 | 42.38 | 42.23 |
| PGD | 0.15 | 8.69 | 18.78 | 9.67 | 22.44 | 24.75 |
| PGD | 0.2 | 1.08 | 3.52 | 2.24 | 8.25 | 10.91 |

Table 7. Classification accuracy (%) of an ensemble of 5 LeNet-5 models against black-box attacks on the F-MNIST dataset.

| Attack | ε | ADP | GAL | DVERGE | TRS | CLDL3,2,4 |
|---|---|---|---|---|---|---|
| BIM | 0.08 | 38.01 | 52.62 | 41.36 | 60.30 | 61.37 |
| BIM | 0.1 | 28.57 | 42.47 | 29.67 | 51.10 | 51.72 |
| BIM | 0.15 | 10.83 | 23.04 | 11.15 | 33.67 | 35.46 |
| BIM | 0.2 | 2.01 | 8.86 | 2.63 | 20.48 | 24.58 |
| FGSM | 0.08 | 48.36 | 62.53 | 54.13 | 65.65 | 67.90 |
| FGSM | 0.1 | 39.94 | 54.14 | 44.80 | 58.09 | 59.78 |
| FGSM | 0.15 | 23.00 | 36.65 | 27.28 | 43.66 | 45.69 |
| FGSM | 0.2 | 9.98 | 19.49 | 14.44 | 31.91 | 32.71 |
| MIM | 0.08 | 38.22 | 52.16 | 41.69 | 58.93 | 59.18 |
| MIM | 0.1 | 28.65 | 42.23 | 30.20 | 49.41 | 49.48 |
| MIM | 0.15 | 9.75 | 20.97 | 11.27 | 31.79 | 33.15 |
| MIM | 0.2 | 1.26 | 4.94 | 2.12 | 17.11 | 18.71 |
| PGD | 0.08 | 37.76 | 51.56 | 41.11 | 59.06 | 58.28 |
| PGD | 0.1 | 27.86 | 41.03 | 29.50 | 49.18 | 47.99 |
| PGD | 0.15 | 9.39 | 20.45 | 10.14 | 31.27 | 32.97 |
| PGD | 0.2 | 1.39 | 5.99 | 1.97 | 18.00 | 20.38 |
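The first-place counts for F-MNIST can be re-derived directly from Tables 6 and 7. The short tally below copies the accuracy rows and, for each (attack, ε) row, records which method attains the highest value. It is a checking aid only, not part of the method. The list names are ours.

```python
from collections import Counter

METHODS = ["ADP", "GAL", "DVERGE", "TRS", "CLDL3,2,4"]

# Accuracy (%) rows copied from Table 6 (3 members) and Table 7 (5 members).
# Row order: BIM, FGSM, MIM, PGD, each at eps = 0.08, 0.1, 0.15, 0.2.
TABLE6 = [
    [38.38, 54.43, 39.42, 54.71, 54.19], [28.39, 44.37, 28.16, 44.61, 44.99],
    [10.55, 22.63, 10.58, 25.06, 27.84], [1.78, 7.10, 2.69, 10.95, 13.98],
    [48.25, 62.77, 52.39, 62.83, 61.99], [40.09, 54.22, 42.80, 53.64, 53.87],
    [23.24, 36.09, 25.41, 37.63, 38.49], [9.82, 15.85, 13.33, 23.45, 25.67],
    [38.24, 52.83, 39.58, 53.46, 52.88], [28.41, 42.44, 28.19, 43.10, 43.41],
    [9.17, 19.61, 9.98, 23.01, 25.20], [1.01, 3.44, 2.33, 7.84, 9.75],
    [37.74, 52.34, 39.17, 52.86, 51.65], [27.75, 41.71, 28.15, 42.38, 42.23],
    [8.69, 18.78, 9.67, 22.44, 24.75], [1.08, 3.52, 2.24, 8.25, 10.91],
]
TABLE7 = [
    [38.01, 52.62, 41.36, 60.30, 61.37], [28.57, 42.47, 29.67, 51.10, 51.72],
    [10.83, 23.04, 11.15, 33.67, 35.46], [2.01, 8.86, 2.63, 20.48, 24.58],
    [48.36, 62.53, 54.13, 65.65, 67.90], [39.94, 54.14, 44.80, 58.09, 59.78],
    [23.00, 36.65, 27.28, 43.66, 45.69], [9.98, 19.49, 14.44, 31.91, 32.71],
    [38.22, 52.16, 41.69, 58.93, 59.18], [28.65, 42.23, 30.20, 49.41, 49.48],
    [9.75, 20.97, 11.27, 31.79, 33.15], [1.26, 4.94, 2.12, 17.11, 18.71],
    [37.76, 51.56, 41.11, 59.06, 58.28], [27.86, 41.03, 29.50, 49.18, 47.99],
    [9.39, 20.45, 10.14, 31.27, 32.97], [1.39, 5.99, 1.97, 18.00, 20.38],
]


def first_place_counts(table):
    """Count how often each method attains the highest accuracy in a row."""
    return Counter(METHODS[max(range(len(row)), key=row.__getitem__)] for row in table)


print(first_place_counts(TABLE6))  # CLDL3,2,4: 10, TRS: 5, GAL: 1
print(first_place_counts(TABLE7))  # CLDL3,2,4: 14, TRS: 2
```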


5 Conclusion
In this paper, we proposed a novel adversarial ensemble training approach that leverages conditional label dependency learning. In contrast to existing methods that encode image classes as one-hot vectors, our algorithm learns and exploits the conditional dependencies between labels during member model training. Experimental results on MNIST, F-MNIST, and CIFAR-10 show that our approach is more robust against black-box adversarial attacks than several state-of-the-art methods.

Acknowledgment. This work was supported by the Research Initiation Project (No. 2021KB0PI01) and the Exploratory Research Project (No. 2022RC0AN02) of Zhejiang Lab.

References 1. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016) 2. Krizhevsky, A., Sutskever, I., Hinton, G.E.: Imagenet classification with deep convolutional neural networks. In: Advances in Neural Information Processing Systems, vol. 25 (2012) 3. Russakovsky, O., et al.: Imagenet large scale visual recognition challenge. Int. J. Comput. Vision 115(3), 211–252 (2015) 4. Szegedy, C., Vanhoucke, V., Ioffffe, S., Shlens, J., Wojna, Z.: Rethinking the inception architecture for computer vision. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2818–2826 (2016) 5. Graves, A., Jaitly, N.: Towards end-to-end speech recognition with recurrent neural networks. In: International Conference on Machine Learning, pp. 1764–1772. PMLR (2014) 6. Hannun, A., et al.: Deep speech: Scaling up end-to-end speech recognition, arXiv preprintarXiv:1412.5567 (2014) 7. Sutskever, I., Vinyals, O., Le, Q.V.: Sequence to sequence learning with neural networks. In: Advances in Neural Information Processing Systems, vol. 27 (2014) 8. Young, T., Hazarika, D., Poria, S., Cambria, E.: Recent trends in deep learning based natural language processing. IEEE Comput. Intell. Mag. 13(3), 55–75 (2018) 9. Goodfellow, I.J., Shlens, J., Szegedy, C.: Explaining and harnessing adversarial examples. In: International Conference on Learning Representations (2015) 10. Papernot, N., McDaniel, P., Jha, S., Fredrikson, M., Celik, Z.B., Swami, A.: The limitations of deep learning in adversarial settings. In: IEEE European Symposium on Security and Privacy, pp. 372–387. IEEE (2016) 11. Carlini, N., Wagner, D.: Towards evaluating the robustness of neural networks. In: IEEE Symposium on Security and Privacy, pp. 39–57. IEEE (2017) 12. 
Madry, A., Makelov, A., Schmidt, L., Tsipras, D., Vladu, A.: Towards deep learning models resistant to adversarial attacks, arXiv preprintarXiv:1706.06083 (2017) 13. Xiao, C., Li, B., Zhu, J., He, W., Liu, M., Song, D.: Generating adversarial examples with adversarial networks, arXiv preprintarXiv:1801.02610 (2018) 14. Xiao, C., Zhu, J., Li, B., He, W., Liu, M., Song, D.: Spatially transformed adversarial examples, arXiv preprint arXiv:1801.02612 (2018) 15. Papernot, N., McDaniel, P., Goodfellow, I.: Transferability in machine learning: from phenomena to black-box attacks using adversarial samples, arXiv preprintarXiv:1605.07277 (2016)

Adversarial Ensemble Training by Jointly Learning Label Dependencies

19

16. Liu, Y., Chen, X., Liu, C., Song, D.: Delving into transferable adversarial examples and black-box attacks, arXiv preprintarXiv:1611.02770 (2016) 17. Inkawhich, N., Liang, K.J., Carin, L., Chen, Y.: Transferable perturbations of deep feature distributions, arXiv preprintarXiv:2004.12519 (2020) 18. Ilyas, A., Santurkar, S., Tsipras, D., Engstrom, L., Tran, B., Madry, A.: Adversarial examples are not bugs, they are features. In: Advances in Neural Information Processing Systems, vol. 32 (2019) 19. Maqueda, A.I., Loquercio, A., Gallego, G., García, N., Scaramuzza, D.: Event-based vision meets deep learning on steering prediction for self-driving cars. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5419–5427 (2018) 20. Bojarski, M., et al.: End to end learning for self-driving cars, arXiv preprintarXiv:1604.07316 (2016) 21. Kariyappa, S., Qureshi, M.K.: Improving adversarial robustness of ensembles with diversity training, arXiv preprintarXiv:1901.09981 (2019) 22. Pang, T., Xu, K., Du, C., Chen, N., Zhu, J.: Improving adversarial robustness via promoting ensemble diversity. In: International Conference on Machine Learning, pp. 4970–4979. PMLR (2019) 23. Sen, S., Ravindran, B., Raghunathan, A.: EMPIR: ensembles of mixed precision deep networks for increased robustness against adversarial attacks. In: International Conference on Learning Representation, pp. 1–12 (2020) 24. Zhang, S., Liu, M., Yan, J.: The diversified ensemble neural network. In: Advances in Neural Information Processing Systems, vol. 33, pp. 16 001–16 011 (2020) 25. Yang, Z., et al.: TRS: transferability reduced ensemble via promoting gradient diversity and model smoothness. In: Advances in Neural Information Processing Systems, vol. 34 (2021) 26. Guo, B., Han, S., Han, X., Huang, H., Lu, T.: Label confusion learning to enhance text classification models. In: Proceedings of the AAAI Conference on Artificial Intelligence (2020) 27. 
Ren, K., Zheng, T., Qin, Z., Liu, X.: Adversarial attacks and defenses in deep learning. Engineering 6(3), 346–360 (2020) 28. Yuan, X., He, P., Zhu, Q., Li, X.: Adversarial examples: attacks and defenses for deep learning. IEEE Trans. Neural Netw. Learn. Syst. 30(9), 2805–2824 (2019) 29. Xu, H., et al.: Adversarial attacks and defenses in images, graphs and text: A review. Int. J. Autom. Comput. 17(2), 151–178 (2020) 30. Kurakin, A., Goodfellow, I., Bengio, S., et al.: Adversarial examples in the physical world (2016) 31. Dong, Y., et al.: Boosting adversarial attacks with momentum. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 9185–9193 (2018) 32. Akata, Z., Perronnin, F., Harchaoui, Z., Schmid, C.: Label-embedding for image classification. IEEE Trans. Pattern Anal. Mach. Intell. 38(7), 1425–1438 (2016) 33. Muller, R., Kornblith, S., Hinton, G.E.: When does label smoothing help, arXiv: Learning (2019) 34. Zhang, X., Zhang, Q.-W., Yan, Z., Liu, R., Cao, Y.: Enhancing label correlation feedback in multi-label text classification via multi-task learning, arXiv preprintarXiv:2106.03103 (2021) 35. Fu, C., Chen, H., Ruan, N., Jia, W.: Label smoothing and adversarial robustness, arXiv preprintarXiv:2009.08233 (2020) 36. Wang, Y., Zheng, Y., Jiang, Y., Huang, M.: Diversifying dialog generation via adaptive label smoothing, arXiv preprintarXiv:2105.14556 (2021) 37. Kullback, S., Leibler, R.A.: On information and sufficiency. Ann. Math. Stat. 22(1), 79–86 (1951) 38. Menéndez, M., Pardo, J.A., Pardo, L., Pardo, M.C.: The Jensen-Shannon divergence. J. Franklin Inst. 334(2), 307–318 (1997)

20

L. Wang and B. Liu

39. Deng, L.: The MNIST database of handwritten digit images for machine learning research. IEEE Signal Process. Mag. 29(6), 141–142 (2012)
40. Xiao, H., Rasul, K., Vollgraf, R.: Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms, arXiv preprint arXiv:1708.07747 (2017)
41. Krizhevsky, A., Hinton, G.: Learning multiple layers of features from tiny images. Technical report, Department of Computer Science, University of Toronto (2009)
42. Kingma, D.P., Ba, J.: Adam: a method for stochastic optimization, arXiv preprint arXiv:1412.6980 (2014)
43. Ding, G.W., Wang, L., Jin, X.: AdverTorch v0.1: an adversarial robustness toolbox based on pytorch, arXiv preprint arXiv:1902.07623 (2019)
44. Yang, H., et al.: DVERGE: diversifying vulnerabilities for enhanced robust generation of ensembles. Adv. Neural. Inf. Process. Syst. 33, 5505–5515 (2020)
45. LeCun, Y., et al.: Backpropagation applied to handwritten zip code recognition. Neural Comput. 1(4), 541–551 (1989)

PFGE: Parsimonious Fast Geometric Ensembling of DNNs
Hao Guo, Jiyong Jin, and Bin Liu (✉)
Research Center for Applied Mathematics and Machine Intelligence, Zhejiang Lab, Hangzhou 311121, China
{guoh,jinjy,liubin}@zhejianglab.com

Abstract. Ensemble methods are commonly used to enhance the generalization performance of machine learning models. However, they are challenging to apply in deep learning systems because of the high computational overhead of training an ensemble of deep neural networks (DNNs). Recent advances such as fast geometric ensembling (FGE) and snapshot ensembles have addressed this issue by training model ensembles in the same time as a single model. Nonetheless, these techniques still require more memory for test-time inference than single-model methods. In this paper, we propose a new method called parsimonious FGE (PFGE), which employs a lightweight ensemble of higher-performing DNNs generated through successive stochastic weight averaging procedures. Our experimental results on CIFAR-{10,100} and ImageNet across various modern DNN architectures demonstrate that PFGE achieves 5x memory efficiency compared to previous methods, without compromising generalization performance. Our code is available at https://github.com/ZJLAB-AMMI/PFGE.
Keywords: deep learning · ensemble method · generalization · geometric ensembling

1 Introduction
Ensemble methods are a popular way to enhance the generalization performance of machine learning models [2, 4, 5, 30]. However, applying them to modern deep neural networks (DNNs) is challenging. Because DNNs have millions or even billions of parameters, directly ensembling k DNNs multiplies the computational overhead by k, in terms of both training time and memory for test-time inference. Recently, fast geometric ensembling (FGE) and snapshot ensemble (SNE) methods have been proposed to overcome the training-time hurdle by training DNN ensembles in the same time as a single model [8, 14]. Nevertheless, these techniques still require more memory for test-time inference than single-model approaches. To reduce the test-time cost of ensembles, model compression and knowledge distillation methods have been proposed [1, 13]; they train one single model that embodies the “knowledge” of the ensemble. However, they do not account for the computational overhead of training the ensemble in the first place.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 21–32, 2023.
https://doi.org/10.1007/978-981-99-4742-3_2

22

H. Guo et al.

This paper addresses the challenge of reducing both the training time and the test-time memory budget of deep neural network (DNN) ensembling. This issue is especially critical for DNN applications with limited memory resources, such as edge computing on devices with restricted memory [20, 26, 29, 31]. We introduce a novel algorithm called PFGE that achieves 5x memory efficiency for test-time inference compared to prior ensemble methods, without sacrificing generalization performance or training efficiency. In other words, PFGE reduces both the training-time and test-time computational budgets of DNN ensembling while maintaining high generalization performance. The remainder of this paper is structured as follows. Section 2 presents related works. Section 3 introduces the proposed PFGE algorithm. Section 4 presents experimental results. Finally, we conclude the paper in Sect. 5.

2 Related Works
Ensemble methods have traditionally posed significant computational challenges for learning with modern DNNs, due to the high overhead of both ensemble training and test-time inference. Nonetheless, researchers have made notable strides in adapting ensemble methods to DNNs, including FGE [8], SNE [14], SWAG [22], Monte-Carlo dropout [7], and deep ensembles [6, 19]. Of these, FGE, SNE, and SWAG are most relevant to this work, as they all employ a cyclical learning rate and train a DNN ensemble in the same time as one DNN model. Both FGE and SNE construct DNN ensembles by sampling network weights from an SGD trajectory that corresponds to a cyclical learning rate [25]. Running SGD with a cyclical learning rate is equivalent in principle to running SGD with periodic warm restarts [21]. The authors of [8, 14] have shown that a cyclical learning rate provides an efficient way to collect high-quality DNN weights that define the models for ensembling. Compared to SNE, FGE has a distinctive feature: a geometric explanation for its ensemble generation approach. Specifically, FGE is based on the geometric insight that the DNN loss landscape contains simple curves connecting local optima, along which both training and test accuracy remain approximately constant. FGE leverages this insight to efficiently discover these high-accuracy pathways between local optima. Building upon the success of FGE, researchers proposed stochastic weight averaging (SWA), which averages high-performing network weights obtained by FGE and uses the average for test-time inference [16]. The underlying geometric insight of SWA is that averaging weights sampled from an SGD trajectory corresponding to a cyclical or constant learning rate can lead to wider optima in the DNN loss landscape, which results in better generalization [17].
SWAG, on the other hand, uses the SWA solution as the center of a Gaussian distribution that approximates the posterior of the network weights [22]. Model ensembles are then generated by sampling from this Gaussian.
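The SWAG construction described above can be illustrated with a minimal sketch of its diagonal variant. This is an assumption-laden simplification: full SWAG [22] also adds a low-rank covariance term, network weights are treated here as flat NumPy vectors, and the function names `swag_diag_fit` and `swag_diag_sample` are hypothetical.

```python
import numpy as np


def swag_diag_fit(weight_snapshots):
    """Fit a diagonal Gaussian to flattened weight snapshots from an SGD trajectory.

    The mean is the SWA solution (the average of the snapshots); the per-parameter
    variance is the second moment minus the squared mean.
    """
    W = np.stack(weight_snapshots)                        # (num_snapshots, num_params)
    mean = W.mean(axis=0)                                 # SWA solution = center of the Gaussian
    second_moment = (W ** 2).mean(axis=0)
    var = np.clip(second_moment - mean ** 2, 0.0, None)   # clip tiny negative round-off
    return mean, var


def swag_diag_sample(mean, var, num_samples, rng):
    """Draw ensemble members w ~ N(mean, diag(var))."""
    std = np.sqrt(var)
    return [mean + std * rng.standard_normal(mean.shape) for _ in range(num_samples)]


# Toy usage with two 2-parameter snapshots (hypothetical values).
rng = np.random.default_rng(0)
snapshots = [np.array([1.0, 2.0]), np.array([3.0, 2.0])]
mean, var = swag_diag_fit(snapshots)   # mean = [2.0, 2.0], var = [1.0, 0.0]
members = swag_diag_sample(mean, var, num_samples=20, rng=rng)
```

A parameter that never moved along the trajectory (the second coordinate here) gets zero variance, so every sampled member keeps its SWA value.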


3 The Proposed PFGE Algorithm
The design of PFGE is inspired by two observations: running a single stochastic weight averaging (SWA) procedure can lead to a higher-performing local optimum [16], and running a series of SWA procedures successively may reveal a set of weights that perform better than those obtained with SGD [10]. While FGE employs an ensemble of models found by SGD, we conjecture that by using an ensemble of higher-performing models discovered by SWA, PFGE can achieve comparable generalization performance with far fewer models. As detailed in Sect. 4, our experimental results support this conjecture. We provide pseudocode for SWA, FGE, and PFGE in Algorithms 1, 2, and 3, respectively. Each algorithm iteratively performs stochastic gradient-based weight updates, starting from a local optimum w0 found by a preceding SGD run. The iterative weight updates use a cyclical learning rate that lets the weight trajectory escape from the current optimum, discover new local optima, and converge to them. Figure 1 illustrates the cyclical learning rate, where α1 and α2 are the minimum and maximum learning rate values, respectively, c is the cycle length, n is the total number of allowable iterations, which depends on the training time budget, and P is the period for collecting member models for PFGE.
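A cyclical learning rate with the parameters α1, α2, and c can be sketched as a symmetric triangle per cycle. This parameterization is an assumption for illustration only: the exact schedule is defined in the paper's Algorithms 1–3 and Fig. 1, which are not reproduced in this text, and the function name `cyclical_lr` is hypothetical.

```python
def cyclical_lr(i, alpha1, alpha2, c):
    """Triangular cyclical learning rate at iteration i (1-indexed).

    Assumed parameterization: alpha1 is the minimum and alpha2 the maximum rate.
    Within each cycle of c iterations the rate rises linearly from alpha1 to alpha2
    at mid-cycle and falls back to alpha1 at the end of the cycle, which is where
    weights would be recorded.
    """
    t = ((i - 1) % c + 1) / c          # position inside the current cycle, in (0, 1]
    if t <= 0.5:
        return (1 - 2 * t) * alpha1 + 2 * t * alpha2
    return (2 - 2 * t) * alpha2 + (2 * t - 1) * alpha1


rates = [cyclical_lr(i, alpha1=0.01, alpha2=0.05, c=4) for i in range(1, 9)]
# approximately [0.03, 0.05, 0.03, 0.01, 0.03, 0.05, 0.03, 0.01]
```

Ending each cycle at the minimum rate means the recorded weights sit near a local optimum, while the high mid-cycle rate pushes the trajectory toward a new one.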



Fig. 1. A conceptual diagram of the cyclical learning rate used by SWA, FGE, and PFGE. The circles mark the time instances at which the local optima discovered along the SGD trajectory are recorded. In general, P is an integer multiple of c, and n is an integer multiple of P. Only one example case is plotted here, in which P = 2c and n = 2P. See the text for detailed explanations of all parameters involved.
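The relationship between c, P, and n determines how many member models each method records. The following sketch counts them using c = 2 epochs, P = 10 epochs, and n = 40 epochs from Sect. 4.1. These values reproduce the counts K = 20 and K = 4 reported in Sect. 4.5. The iterations-per-epoch value and the helper name `collection_iterations` are hypothetical.

```python
def collection_iterations(n, period):
    """1-indexed iterations at which a member model is recorded: one per period."""
    if n % period != 0:
        raise ValueError("n must be an integer multiple of the collection period")
    return list(range(period, n + 1, period))


iters_per_epoch = 100  # hypothetical; depends on dataset size and the batch size of 128
c, P, n = 2 * iters_per_epoch, 10 * iters_per_epoch, 40 * iters_per_epoch
assert P % c == 0  # P must be an integer multiple of c

fge_points = collection_iterations(n, c)    # FGE/SWAG: one model per cycle (SWA averages these n/c weights)
pfge_points = collection_iterations(n, P)   # PFGE: one model per SWA period
K_fge, K_pfge = len(fge_points), len(pfge_points)
print(K_fge, K_pfge, K_pfge / K_fge)        # prints: 20 4 0.2
```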

As shown in Algorithm 1, SWA maintains a running average of the network weights collected every c iterations and ultimately outputs a single model with weights wSWA for test-time inference. In practice, wSWA is the average of n/c weights traversed along the SGD trajectory. PFGE differs from FGE in how it generates ensemble models. Whereas FGE generates its ensemble models purely with SGD (see operations 4 and 6 in Algorithm 2), PFGE generates them with multiple SWA procedures performed successively (see operations 10 and 13 in Algorithm 3). In these successive SWA procedures, the output of each SWA procedure is used as the initialization of the next one (see operation 14 in Algorithm 3). The code to implement PFGE is available at https://github.com/ZJLAB-AMMI/PFGE.
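The difference between FGE and PFGE described above can be sketched on flat weight vectors. This is a simplified illustration, not the authors' Algorithms 2 and 3. The functions `fge` and `pfge`, the toy quadratic gradient, and the constant learning rate are all hypothetical. Practical DNN implementations typically also recompute batch-normalization statistics for averaged weights, which is omitted here.

```python
import numpy as np


def fge(w0, grad_fn, n, c, lr_fn):
    """FGE-style collection: keep the raw SGD iterate at the end of every cycle."""
    members, w = [], w0.copy()
    for i in range(1, n + 1):
        w = w - lr_fn(i) * grad_fn(w)
        if i % c == 0:
            members.append(w.copy())
    return members


def pfge(w0, grad_fn, n, c, P, lr_fn):
    """PFGE-style collection: successive SWA runs, each started from the previous SWA output.

    Every c iterations the current iterate is folded into a running average; every
    P iterations that average is stored as a member and becomes the starting point
    of the next SWA run. Returns K = n / P members.
    """
    assert P % c == 0 and n % P == 0
    members, w = [], w0.copy()
    w_swa, n_avg = np.zeros_like(w0), 0
    for i in range(1, n + 1):
        w = w - lr_fn(i) * grad_fn(w)
        if i % c == 0:                      # end of a cycle: update the running average
            n_avg += 1
            w_swa = w_swa + (w - w_swa) / n_avg
        if i % P == 0:                      # end of an SWA run: store it and restart from it
            members.append(w_swa.copy())
            w = w_swa.copy()
            w_swa, n_avg = np.zeros_like(w0), 0
    return members


def grad(w):
    """Deterministic gradient of the toy loss 0.5 * ||w||^2 (stands in for a stochastic gradient)."""
    return w


w0 = np.array([5.0, 5.0])
fge_members = fge(w0, grad, n=16, c=2, lr_fn=lambda i: 0.1)
pfge_members = pfge(w0, grad, n=16, c=2, P=4, lr_fn=lambda i: 0.1)
print(len(fge_members), len(pfge_members))   # prints: 8 4
```

With the same iteration budget, PFGE stores n/P members instead of n/c. Each PFGE member is an average rather than a single SGD iterate.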

4 Experiments
We compare PFGE against prior state-of-the-art methods, including FGE [8], SWA [16], and SWAG [22], on widely used image datasets, namely CIFAR-10 and CIFAR-100 [18] and ImageNet ILSVRC-2012 [3, 23], to evaluate its generalization performance for image classification. A more comprehensive experimental analysis of PFGE, including its performance in uncertainty calibration and mode connectivity, can be found in the extended version of this paper on arXiv [9].
4.1 Experimental Setting
As shown in Algorithms 1–3, SWA, FGE, and PFGE are all initialized with a local optimum w0 and a learning rate schedule. For all architectures and datasets considered, we initialize all algorithms under comparison with the same w0 and identical learning rate settings. Following [8], we employ a triangle learning rate schedule, as depicted in Fig. 1. Specifically, we set c to the number of iterations corresponding to 2 or 4 epochs (following [8]), P to 10 epochs, and n to 40 or 20 epochs. For α1 and α2, we adopt the same values as in [8]. The mini-batch size for model training is fixed at 128. For CIFAR-{10,100}, we obtain w0 by running standard SGD with momentum, regulated by the same type of decaying learning rate schedule as in [16], until convergence, to minimize an L2-regularized cross-entropy loss. We set the SGD hyperparameters, such as the weight decay parameter and momentum factor, in the same manner as in [16]. For ImageNet, we initialize w0 with the pre-trained ResNet-50 [11], ResNet-152 [11], and DenseNet-161 [15] models available in PyTorch, and set n to the number of iterations corresponding to 40 epochs. In our experiments, PFGE always employs four model components and averages their softmax outputs for test-time prediction. In contrast, FGE and SWAG use the entire ensemble of 20 models for test-time inference. To ensure that all algorithms under comparison have equal test-time inference overhead, we compare PFGE against FGE* and SWAG*, which are lightweight versions of FGE and SWAG, respectively. The sole difference between FGE* (resp. SWAG*) and FGE (resp. SWAG) is that the former generates test-time predictions from the last four model components added to the ensemble set S, while the latter uses all model components in S. Our performance metric of interest is test accuracy.
4.2 CIFAR Datasets
We evaluate PFGE and the competitor methods on several network architectures, namely VGG16 [24], Pre-activation ResNet-164 (PreResNet-164) [12], and WideResNet-28-10 [28], using CIFAR-{10,100}. We run each algorithm independently at least three times and report the average test accuracy with its standard error in Tables 1 and 2. On CIFAR-10 with the VGG16 architecture, PFGE achieves the highest test accuracy (93.41%).

Table 1. Test accuracy on CIFAR-10. We compare PFGE with FGE*, SWA, and SWAG*. Best results for each architecture are bolded. Results for FGE and SWAG, which use 5x memory resources for test-time inference compared with PFGE, are also listed for reference.

Test Accuracy (%)
Algorithm   VGG16          PreResNet-164   WideResNet-28-10
PFGE        93.41 ± 0.08   95.70 ± 0.05    96.37 ± 0.03
FGE*        93.03 ± 0.18   95.52 ± 0.08    96.14 ± 0.07
SWA         93.33 ± 0.02   95.78 ± 0.07    96.47 ± 0.04
SWAG*       93.24 ± 0.06   95.45 ± 0.14    96.36 ± 0.04
FGE         93.40 ± 0.08   95.57 ± 0.05    96.27 ± 0.02
SWAG        93.37 ± 0.07   95.61 ± 0.11    96.45 ± 0.07

For the PreResNet-164 and WideResNet-28-10 architectures, PFGE outperforms FGE* and SWAG* but performs slightly worse than SWA, which achieves the highest test accuracy (95.78% and 96.47%, respectively). On CIFAR-100, PFGE delivers the best test accuracy for all network architectures. Furthermore, as Tables 1 and 2 show, even compared with FGE and SWAG, which use the full ensemble of model components for test-time inference, PFGE attains comparable or superior test accuracy while requiring only 20% of their test-time memory overhead.

Table 2. Test accuracy on CIFAR-100. We compare PFGE with FGE*, SWA, and SWAG*. Best results for each architecture are bolded. Results for FGE and SWAG, which use 5x memory resources for test-time inference compared with PFGE, are also listed for reference.

Test Accuracy (%)
Algorithm   VGG16          PreResNet-164   WideResNet-28-10
PFGE        74.17 ± 0.04   80.06 ± 0.13    81.96 ± 0.01
FGE*        73.49 ± 0.24   79.76 ± 0.06    81.09 ± 0.25
SWA         73.83 ± 0.20   79.97 ± 0.06    81.92 ± 0.02
SWAG*       73.77 ± 0.18   79.24 ± 0.04    81.55 ± 0.06
FGE         74.34 ± 0.05   80.17 ± 0.09    81.62 ± 0.16
SWAG        74.15 ± 0.17   80.00 ± 0.03    81.83 ± 0.12

4.3 IMAGENET
We experiment with the ResNet-50 [11], ResNet-152 [11], and DenseNet-161 [15] architectures on ImageNet ILSVRC-2012 [3, 23]. As in the previous experiments, we run each algorithm independently three times. The results are summarized in Table 3, which shows that PFGE surpasses FGE*, SWA, and SWAG* in test accuracy and achieves performance comparable to FGE and SWAG.

Table 3. Test accuracy on ImageNet. We compare PFGE with FGE*, SWA, and SWAG*. Best results for each architecture are bolded. Results for FGE and SWAG, which use 5x memory resources for test-time inference compared with PFGE, are also listed for reference.

Test Accuracy (%)
Algorithm   ResNet-50      ResNet-152     DenseNet-161
PFGE        77.06 ± 0.19   79.07 ± 0.04   78.72 ± 0.08
FGE*        76.85 ± 0.07   78.73 ± 0.02   78.53 ± 0.08
SWA         76.70 ± 0.38   78.82 ± 0.02   78.41 ± 0.29
SWAG*       76.19 ± 0.29   78.72 ± 0.04   77.19 ± 0.83
FGE         77.17 ± 0.08   79.13 ± 0.06   79.91 ± 0.06
SWAG        76.70 ± 0.35   79.10 ± 0.06   77.94 ± 0.62


4.4 Performance of Separate Models in The Ensemble
We evaluate the performance of each individual model in PFGE and FGE. Each separate model can be viewed as a “snapshot” of the SGD trajectory.

Fig. 2. Ensemble performance of PFGE and FGE as a function of the training iteration index i. PFGE uses only 4 model components, yet achieves performance on par with or even better than that of FGE, which uses 20 model components. Crosses represent the performance of separate “snapshot” models, and diamonds show the performance of the ensembles composed of all models available at the given iteration. Left column: CIFAR-10. Right column: CIFAR-100. Top row: VGG16. Middle row: PreResNet-164. Bottom row: WideResNet-28-10.


To evaluate the generalization performance of these “snapshot” models, we collect them at various training checkpoints and measure their test accuracy. The results are presented in Fig. 2, from which we observe that:
• the individual models of PFGE outperform those of FGE in test accuracy for all network architectures and both datasets considered;
• for both PFGE and FGE, the test accuracy of the model ensemble increases as more member models are used.
Our experimental findings suggest that, owing to the superior quality of its separate model components, PFGE outperforms FGE and SWAG when they are equipped with an equal number of model components (Fig. 3).

Fig. 3. Ensemble performance of PFGE and FGE on ImageNet as a function of the training iteration index i. PFGE uses only 4 model components, yet achieves performance on par with that of FGE, which uses 20 model components. Top left: ResNet-50. Top right: ResNet-152. Bottom: DenseNet-161. Crosses represent the performance of separate “snapshot” models, and diamonds show the performance of the ensembles composed of all models available at the given iteration.
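The ensemble evaluation behind Figs. 2 and 3 can be sketched in a few lines, together with the prediction rule from Sect. 4.1. That rule averages the members' softmax outputs, and FGE*/SWAG* keep only the last four members. All function names and the toy logits are hypothetical, and member outputs are assumed to be precomputed logit arrays.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract the row max for stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def ensemble_predict(member_logits):
    """Average the members' softmax outputs and return the predicted class per sample."""
    probs = np.mean([softmax(logits) for logits in member_logits], axis=0)
    return probs.argmax(axis=-1)


def accuracy(pred, labels):
    return float((pred == labels).mean())


def running_ensemble_accuracy(member_logits, labels):
    """Accuracy of the ensemble of the first k members, for k = 1, ..., K (the diamonds)."""
    return [accuracy(ensemble_predict(member_logits[:k]), labels)
            for k in range(1, len(member_logits) + 1)]


def last_k_members(member_logits, k=4):
    """Lightweight FGE*/SWAG*-style subset: only the last k members added."""
    return member_logits[-k:]


# Toy usage: 3 samples, 2 classes, 2 members (hypothetical logits).
labels = np.array([0, 1, 1])
m1 = np.array([[3.0, 0.0], [0.0, 1.0], [1.0, 0.0]])   # alone: 2 of 3 correct
m2 = np.array([[0.0, 1.0], [0.0, 3.0], [0.0, 3.0]])   # alone: 2 of 3 correct
curve = running_ensemble_accuracy([m1, m2], labels)    # about [0.667, 1.0]
```

In this toy case each member errs on a different sample, so averaging their probabilities corrects both mistakes. This mirrors the rising diamond curves in the figures.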

4.5 On Training Efficiency and Test-time Cost
As noted in Sect. 4.1, all methods employ the same learning rate setting and execute the same number of SGD iterations, so they have identical training efficiency (in terms of training time). Their test-time cost, primarily memory overhead, is proportional to the number of member models used, denoted here by K. For PFGE, we have K = n/P = 4. In contrast, for FGE [8] and SWAG [22], we have K = n/c = 20 under the experimental setting in Sect. 4.1. Hence PFGE needs only 4/20 = 20% of the test-time memory overhead of FGE [8] and SWAG [22]. In terms of test-time efficiency, SWA [16] is the preferred option, since it always produces a single model (K = 1). However, compared with ensemble methods, SWA is weaker at uncertainty calibration [9, 22], a crucial capability closely linked to detecting out-of-distribution samples [9, 22].
4.6 Mode Connectivity Test
In a related study [27], researchers show that ensembles of trained models tend to converge to locally smooth regions of the loss landscape, where they achieve the highest test accuracy. This indicates that the mode connectivity of model components closely influences the ensemble's generalization performance. We conducted an experiment to assess the mode connectivity of PFGE and FGE model components and found that the mode connectivity of PFGE's model components is superior to that of FGE's. Further details are available in the extended version of this paper [9].
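One common way to probe the connectivity of two weight vectors is to evaluate the loss along the straight segment between them and measure the barrier above the endpoints. The sketch below shows this generic linear-path probe. It is not necessarily the test used in [9], whose details are not given here. The function names and the 1-D toy losses are hypothetical stand-ins for a network's test loss.

```python
import numpy as np


def linear_path_losses(w_a, w_b, loss_fn, num_points=11):
    """Evaluate loss_fn at evenly spaced points on the segment from w_a to w_b."""
    ts = np.linspace(0.0, 1.0, num_points)
    losses = np.array([loss_fn((1.0 - t) * w_a + t * w_b) for t in ts])
    return ts, losses


def loss_barrier(losses):
    """Peak loss along the path minus the larger of the two endpoint losses."""
    return float(losses.max() - max(losses[0], losses[-1]))


def double_well(w):
    return float(((w ** 2 - 1.0) ** 2).sum())   # two separated minima at w = -1 and w = +1


def bowl(w):
    return float((w ** 2).sum())                 # a single convex basin


w_a, w_b = np.array([-1.0]), np.array([1.0])
_, dw_losses = linear_path_losses(w_a, w_b, double_well)
_, bowl_losses = linear_path_losses(w_a, w_b, bowl)
print(loss_barrier(dw_losses), loss_barrier(bowl_losses))   # prints: 1.0 0.0
```

A barrier near zero indicates that the two solutions are linearly connected through a low-loss region. A positive barrier indicates that they sit in separate basins.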

5 Conclusions
Ensemble methods are widely applied to enhance the generalization performance of machine learning models. However, their use with modern DNNs is challenging because of the extensive computational overhead of training a DNN ensemble. Recent advances such as FGE and SWAG have addressed this issue by training model ensembles in the time required for a single model. Nevertheless, these techniques still require extra memory for test-time inference compared with single-model approaches. This paper introduces a new method called parsimonious FGE (PFGE), which builds on FGE but uses a lightweight ensemble of higher-performing DNNs produced via successive SWA procedures. Our experimental findings on the CIFAR-10, CIFAR-100, and ImageNet datasets across various state-of-the-art DNN architectures demonstrate that PFGE can achieve up to 5x memory efficiency compared with prior methods such as FGE and SWAG without sacrificing generalization performance.
Acknowledgment. This work was supported by the Research Initiation Project (No. 2021KB0PI01) and the Exploratory Research Project (No. 2022RC0AN02) of Zhejiang Lab.

References
1. Bucilua, C., Caruana, R., Niculescu-Mizil, A.: Model compression. In: Proceedings of the 12th ACM SIGKDD, pp. 535–541 (2006)
2. Caruana, R., Niculescu-Mizil, A., Crew, G., Ksikes, A.: Ensemble selection from libraries of models. In: ICML, p. 18 (2004)
3. Deng, J., Dong, W., Socher, R., Li, L., Li, K., Fei-Fei, L.: Imagenet: a large-scale hierarchical image database. In: CVPR, pp. 248–255. IEEE (2009)
4. Dietterich, T.G.: Ensemble methods in machine learning. In: Kittler, J., Roli, F. (eds.) MCS 2000. LNCS, vol. 1857, pp. 1–15. Springer, Heidelberg (2000). https://doi.org/10.1007/3-540-45014-9_1
5. Džeroski, S., Ženko, B.: Is combining classifiers with stacking better than selecting the best one? Mach. Learn. 54(3), 255–273 (2004)
6. Fort, S., Hu, H., Lakshminarayanan, B.: Deep ensembles: a loss landscape perspective. arXiv preprint arXiv:1912.02757 (2019)
7. Gal, Y., Ghahramani, Z.: Dropout as a Bayesian approximation: representing model uncertainty in deep learning. In: ICML, pp. 1050–1059. PMLR (2016)
8. Garipov, T., Izmailov, P., Podoprikhin, D., Vetrov, D., Wilson, A.G.: Loss surfaces, mode connectivity, and fast ensembling of DNNs. In: Advances in Neural Information Processing Systems, pp. 8803–8812 (2018)
9. Guo, H., Jin, J., Liu, B.: PFGE: parsimonious fast geometric ensembling of DNNs. arXiv preprint arXiv:2202.06658 (2022)
10. Guo, H., Jin, J., Liu, B.: Stochastic weight averaging revisited. Appl. Sci. 13(5), 1–17 (2023)
11. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR, pp. 770–778 (2016)
12. He, K., Zhang, X., Ren, S., Sun, J.: Identity mappings in deep residual networks. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9908, pp. 630–645. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46493-0_38
13. Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531 (2015)
14. Huang, G., Li, Y., Pleiss, G., Liu, Z., Hopcroft, J.E., Weinberger, K.Q.: Snapshot ensembles: train 1, get M for free. In: International Conference on Learning Representations (2017)
15. Huang, G., Liu, Z., Van Der Maaten, L., Weinberger, K.Q.: Densely connected convolutional networks. In: CVPR, pp. 4700–4708 (2017)
16. Izmailov, P., Podoprikhin, D., Garipov, T., Vetrov, D., Wilson, A.G.: Averaging weights leads to wider optima and better generalization. In: Proceedings of Conference on Uncertainty in Artificial Intelligence (UAI), pp. 1–10 (2018)
17. Keskar, N.S., Mudigere, D., Nocedal, J., Smelyanskiy, M., Tang, P.T.P.: On large batch training for deep learning: Generalization gap and sharp minima. In: International Conference on Learning Representations (2017)
18. Krizhevsky, A., Hinton, G.: Learning multiple layers of features from tiny images. Technical report, University of Toronto (2009)
19. Lakshminarayanan, B., Pritzel, A., Blundell, C.: Simple and scalable predictive uncertainty estimation using deep ensembles. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
20. Li, P., Seferoglu, H., Dasari, V.R., Koyuncu, E.: Model-distributed DNN training for memory-constrained edge computing devices. In: IEEE International Symposium on Local and Metropolitan Area Networks (LANMAN), pp. 1–6. IEEE (2021)
21. Loshchilov, I., Hutter, F.: SGDR: stochastic gradient descent with warm restarts. In: International Conference on Learning Representations (2017)
22. Maddox, W.J., Izmailov, P., Garipov, T., Vetrov, D.P., Wilson, A.G.: A simple baseline for Bayesian uncertainty in deep learning. In: Advances in Neural Information Processing Systems 32, pp. 13153–13164 (2019)
23. Russakovsky, O., et al.: Imagenet large scale visual recognition challenge. Int. J. Comput. Vision 115(3), 211–252 (2015)
24. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. In: ICLR (2015)
25. Smith, L.N.: Cyclical learning rates for training neural networks. In: IEEE Winter Conference on Applications of Computer Vision (WACV), pp. 464–472. IEEE (2017)
26. Varghese, B., Wang, N., Barbhuiya, S., Kilpatrick, P., Nikolopoulos, D.S.: Challenges and opportunities in edge computing. In: IEEE International Conference on Smart Cloud (SmartCloud), pp. 20–26. IEEE (2016)
27. Yang, Y., et al.: Taxonomizing local versus global structure in neural network loss landscapes. In: Advances in Neural Information Processing Systems, vol. 34 (2021)
28. Zagoruyko, S., Komodakis, N.: Wide residual networks. In: Proceedings of the British Machine Vision Conference (BMVC) (2016)
29. Zhou, A., et al.: Brief industry paper: optimizing memory efficiency of graph neural networks on edge computing platforms. In: IEEE 27th Real-Time and Embedded Technology and Applications Symposium (RTAS), pp. 445–448. IEEE (2021)
30. Zhou, Z.: Ensemble Methods: Foundations and Algorithms. CRC Press, Boca Raton (2012)
31. Zhang, P., Liu, B.: Commonsense Knowledge Assisted Deep Learning with Application to Size-Related Fine-Grained Object Detection. arXiv preprint arXiv:2303.09026 (2023)

Research on Indoor Positioning Algorithm Based on Multimodal and Attention Mechanism
Chenxi Shi¹, Lvqing Yang¹ (✉), Lanliang Lin¹, Yongrong Wu¹, Shuangyuan Yang¹, Sien Chen²,³, and Bo Yu¹,⁴
¹ School of Informatics, Xiamen University, Xiamen 361005, China


² School of Navigation, Jimei University, Xiamen 361021, China
³ Tech Vally (Xiamen) Information Technology Co., Ltd., Xiamen 361005, China
⁴ Zijin Zhixin (Xiamen) Technology Co., Ltd., Xiamen 361005, China

Abstract. RFID (Radio Frequency Identification) is an automatic identification technology that has received widespread attention from indoor positioning researchers due to its high stability and low power consumption. We propose a multimodal indoor positioning model based on RFID and WiFi data, named Multimodal Indoor Location Network (MMILN). Two common deep learning paradigms, embedding and pooling, are used to process and pre-train the different data modalities in order to extract richer features. To overcome the limitations of received signal strength, WiFi hotspot names are introduced as an additional data modality to compensate for the instability of a single signal. To make better use of the different modalities, we design a location activation unit, based on the idea of the attention mechanism, that computes a weighted sum of the collected signals. In addition, given the special characteristics of the indoor positioning task and its data, we design an adaptive activation function, SoftReLU, to better support model training and prediction. Experimental results show that after introducing the location activation unit and the adaptive activation function, the mean absolute error of indoor positioning decreases to 0.178 m, a 40.86% reduction compared with the baseline model, significantly improving positioning accuracy.
Keywords: Indoor positioning · Radio Frequency Identification · multimodal · attention mechanism

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 33–45, 2023.
https://doi.org/10.1007/978-981-99-4742-3_3
1 Introduction
With the widespread use of wireless devices, indoor positioning is becoming a key technology in many Internet of Things applications. Because global satellite navigation systems such as China’s Beidou Navigation System [1], the US GPS System [2], and the Russian GLONASS Navigation System [3] are limited indoors, indoor positioning has received increasing attention from academia and industry in recent years. Indoor positioning typically focuses on inferring the user’s location

34

C. Shi et al.

within limited indoor environments such as shopping malls, hospitals, and office buildings. These environments are typically small and contain multiple rooms, facilities, and different levels of signal obstruction. As users move around indoors, they receive various signals, including WiFi, RFID, infrared, visible light, Bluetooth, and ultrasound [4]. These signals are often subject to varying degrees of interference and attenuation, which makes indoor positioning with electromagnetic signals very challenging. RFID has been recognized as an emerging technology for “ubiquitous positioning” [5], and it is currently one of the most popular methods for indoor positioning in particular. A traditional RFID indoor positioning system is a computer-based system that integrates RFID data acquisition, RFID data processing and transmission, GIS spatial analysis and querying, and other technologies. Its general idea is to take parameters such as the Received Signal Strength Indicator (RSSI) and phase of the signals received by RFID readers, use positioning algorithms to calculate distance and direction, and thereby infer the location of the mobile user [6]. Traditional positioning algorithms include hyperbolic positioning [7] and Angle of Arrival (AOA) estimation based on received signals [8]. Later work localized RFID tags by simulating virtual antenna arrays [9, 10], using the concept of Synthetic Aperture Radar (SAR) or Inverse Synthetic Aperture Radar (ISAR); such methods can achieve centimeter- or even millimeter-level accuracy in some cases. However, in real RFID application scenarios, indoor building facilities, especially metal objects, inevitably cause multipath effects, which corrupt signal measurements and pose serious challenges to the accuracy of RFID positioning systems.
In recent years, with the rise of artificial intelligence, indoor positioning algorithms based on machine learning and deep learning have received increasing attention from researchers. Because the propagation of RFID signals may be severely affected by complex indoor environments, each location in the target area has its own unique signal pattern; the more complex the environment, the more distinguishable these patterns become. Machine learning is therefore a practical and feasible approach to indoor positioning. However, current RFID indoor positioning algorithms still have the following deficiencies:
• RFID readers or tags deployed throughout the indoor environment are strongly influenced by location and environment and inevitably suffer from multipath effects. Collisions during signal transmission cause signal fluctuations or errors, resulting in a loss of positioning accuracy.
• Both traditional geometric positioning methods and emerging machine learning methods have limited ability to discriminate sensitive data. Their slow computation and weak predictive capability on high-dimensional data in complex scenarios also cannot be ignored.
• Although existing RFID indoor positioning algorithms can achieve centimeter-level accuracy in some scenarios, this accuracy depends heavily on the environment. In more realistic scenarios, the indoor environment contains various electromagnetic signals, and how to exploit these data for better positioning is a research focus in this field.

Research on Indoor Positioning Algorithm Based on Multimodal


To address these limitations, this paper proposes using multimodal data to compensate for the information lost through signal collisions when a single RFID signal is transmitted. The Embedding and Pooling paradigm is employed to address the sparsity of high-dimensional data. A customized activation function and regularization techniques are used to facilitate model training and prediction. Moreover, a location activate unit (LAU) algorithm is proposed that weights the data of different modalities, so that scene features are better exploited for the task at hand. The aim of this paper is to use deep learning to make more effective use of the various data available in a scene, improve the accuracy of indoor positioning, and apply the approach to a variety of RFID-based identification scenarios.

2 Related Work
RFID is one of the supporting technologies of the Internet of Things. It is a contactless identification technology that uses radio frequency signals to read and transmit information stored in electronic tags. RFID supports non-line-of-sight transmission and fast identification, which makes it well suited to indoor positioning scenarios. In this section, we introduce representative traditional RFID indoor positioning methods and emerging indoor positioning methods based on machine learning and deep learning. Because of complex indoor environments and high precision requirements, indoor positioning systems are difficult to use in practice. The challenges span signal strength collection, database construction, and accurate prediction by the positioning algorithm. Current indoor positioning services rely on measurements such as Received Signal Strength Indicator (RSSI), Angle of Arrival (AOA), Time of Arrival (TOA), and Time Difference of Arrival (TDOA). Based on these features, researchers have proposed many classical geometric positioning methods. The LANDMARC algorithm proposed by Liu et al. [11] is a classic indoor positioning algorithm based on active RFID. Its core idea is RSSI-based weighted centroid estimation. The algorithm selects the K nearest neighbor tags using the Euclidean distance between their RSSI readings and computes a weighted average of their positions. This cancels the common interference that environmental factors impose on the signal propagation of tags at nearby locations, and it improves positioning accuracy. The SpotON system proposed by Hightower et al. [12] uses three or more readers as base stations, records the RSSI at each reader, and calculates the tag position through triangulation.
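The LANDMARC idea above can be sketched in a few lines. This is a minimal sketch, not the implementation of [11]. It assumes that each tag is described by a vector of RSSI readings (one per reader). It uses the inverse-squared-distance weights of the original LANDMARC formulation. The function name `landmarc_estimate` and the toy readings are hypothetical.

```python
import numpy as np

def landmarc_estimate(target_rssi, ref_rssi, ref_pos, k=4):
    """Estimate a tag position from the k nearest reference tags in RSSI space.

    target_rssi: (n_readers,) RSSI of the tracked tag at each reader
    ref_rssi:    (n_refs, n_readers) RSSI of reference tags at known positions
    ref_pos:     (n_refs, 2) known (x, y) positions of the reference tags
    """
    # Euclidean distance in signal space between the target and every reference tag
    e = np.linalg.norm(ref_rssi - target_rssi, axis=1)
    nearest = np.argsort(e)[:k]
    # Inverse-squared-distance weights; the tiny constant avoids division by zero
    w = 1.0 / (e[nearest] ** 2 + 1e-12)
    w /= w.sum()
    # Weighted centroid of the k nearest reference positions
    return w @ ref_pos[nearest]

# Hypothetical toy data: three reference tags observed by two readers
ref_pos = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]])
ref_rssi = np.array([[-40.0, -60.0], [-60.0, -40.0], [-70.0, -70.0]])
print(landmarc_estimate(np.array([-50.0, -50.0]), ref_rssi, ref_pos, k=2))  # [1. 0.]
```

In the usage example, the target's readings are equidistant in signal space from the first two reference tags. Their weights are therefore equal, and the estimate is their midpoint.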
Some works [13, 14] estimate the position of RFID tags from the phase distribution of the received signals. They use directional antennas or antenna arrays to infer the likely direction of active tag signals from the angle of reception. Later fine-grained work [9, 15] applied Synthetic Aperture Radar (SAR) to positioning. By simulating multiple virtual antennas through the relative motion between RFID antennas and tags, these methods extract more spatial information about RFID tags.
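The following minimal sketch shows how antenna motion creates a virtual antenna array, using holographic (SAR-style) localization in the spirit of [9]. It rests on several simplifying assumptions that are not taken from the cited works. The backscatter phase is modeled as θ = (4πd/λ) mod 2π, with no hardware phase offset, noise, or multipath. The carrier is 915 MHz, the antenna moves along the x-axis, and the search grid is 1 cm. The function names are hypothetical.

```python
import numpy as np

C = 299_792_458.0          # speed of light (m/s)
WAVELENGTH = C / 915e6     # assumed UHF carrier of 915 MHz, about 0.33 m

def backscatter_phase(antenna_xy, tag_xy):
    """Idealized round-trip phase (mod 2*pi) measured at each antenna position."""
    d = np.linalg.norm(antenna_xy - tag_xy, axis=1)
    return np.mod(4 * np.pi * d / WAVELENGTH, 2 * np.pi)

def holographic_locate(antenna_xy, phases, xs, ys):
    """Return the grid point whose predicted phases best match the measurements."""
    gx, gy = np.meshgrid(xs, ys, indexing="ij")
    cand = np.stack([gx.ravel(), gy.ravel()], axis=1)                 # (G, 2)
    # Distance from every candidate to every virtual antenna position: (G, N)
    d = np.linalg.norm(cand[:, None, :] - antenna_xy[None, :, :], axis=2)
    # Coherent sum of phase residuals; it reaches N only when all residuals vanish
    residual = phases[None, :] - 4 * np.pi * d / WAVELENGTH
    power = np.abs(np.exp(1j * residual).sum(axis=1))
    return cand[np.argmax(power)]

# Virtual array: 50 antenna positions along the x-axis from 0 m to 1 m
antenna = np.stack([np.linspace(0.0, 1.0, 50), np.zeros(50)], axis=1)
true_tag = np.array([0.6, 0.8])
phases = backscatter_phase(antenna, true_tag)
xs = np.linspace(0.0, 1.5, 151)    # 1 cm grid
ys = np.linspace(0.01, 1.5, 150)   # y > 0 excludes the mirror image below the track
print(holographic_locate(antenna, phases, xs, ys))
```

With a single straight track, a tag and its mirror image across the track produce identical phases. The sketch therefore restricts the search to one side. Moving the antenna along two different lines is another way to break this ambiguity.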


C. Shi et al.

In recent years, with the wave of AI, indoor positioning has become an application area of machine learning and deep learning and has achieved many good results. In 2018, Kim et al. [16] proposed combining stacked autoencoders with artificial neural networks to estimate the specific location of users on a floor. They used stacked autoencoders to reduce the dimensionality of, and filter noise from, the input received signal strength. They then used artificial neural networks for classification, achieving an accuracy of 97.2% on the classification task. However, treating positioning as classification over grid partitions makes the model prone to overfitting and reduces its generalization ability. For time-series data, Bai et al. [17] combined three sparse autoencoders with a recurrent neural network for indoor positioning, reducing the MSE of mobile user positioning to 1.6 m. However, the recurrent neural network used in that work has high computational complexity and suffers from vanishing gradients on long-tail data, which leaves room for improvement. For multimodal data, Gan et al. [18] fused WiFi, geomagnetic, and Bluetooth signals using deep neural networks and smoothed the data with Kalman filtering, achieving an RMSE of less than 1 m. However, that work did not clearly show the role of each type of data in the positioning process. How each modality can play its own role during fusion is a key point that needs further exploration.

3 Multimodal Indoor Location Network
3.1 Embedding
The indoor positioning task with multimodal data is a multi-group classification task, in which each group of data represents one feature category. An example is [RFID RSSI = {−32, −96, ..., −64}, WiFi RSSI = {−28, −54, ..., −100}]. These signals are encoded into high-dimensional sparse binary features. The i-th feature group of the encoding vector is denoted $t_i \in \mathbb{R}^{K_i}$, where $K_i$ is the dimension of the i-th feature group; that is, feature group i contains $K_i$ different IDs. $t_i[j]$ is the j-th element of $t_i$, with $t_i[j] \in \{0, 1\}$ and $\sum_{j=1}^{K_i} t_i[j] = k$. Here k = 1 means one-hot encoding, and k > 1 means multi-hot encoding. A multimodal sample can thus be represented as $x = [t_1^{\top}, t_2^{\top}, \ldots, t_M^{\top}]^{\top}$, where M is the number of feature groups and $\sum_{i=1}^{M} K_i = K$, with K the dimension of the entire feature space. Fig. 1 shows an example with two feature groups encoded in this way.
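The section does not say how raw RSSI values are mapped to IDs, so the following is only a minimal sketch under a hypothetical choice. Each integer dBm value in an assumed range of [−100, −20] is treated as its own ID. A feature group with several distinct readings therefore becomes a multi-hot vector $t_i$, and the groups are concatenated into $x$. The range and the function names are illustrative assumptions.

```python
import numpy as np

RSSI_MIN, RSSI_MAX = -100, -20        # assumed integer dBm range, one ID per value

def multi_hot(values, lo=RSSI_MIN, hi=RSSI_MAX):
    """Encode a group of RSSI readings as t_i in {0,1}^{K_i}, with K_i = hi - lo + 1."""
    t = np.zeros(hi - lo + 1, dtype=np.int8)
    for v in values:
        # Clip out-of-range readings to the ends of the vocabulary
        t[int(np.clip(round(v), lo, hi)) - lo] = 1
    return t

def encode_sample(groups):
    """Concatenate the per-group vectors: x = [t_1^T, ..., t_M^T]^T."""
    return np.concatenate([multi_hot(g) for g in groups])

rfid = [-32, -96, -64]
wifi = [-28, -54, -100]
x = encode_sample([rfid, wifi])
print(x.shape, x.sum())  # (162,) 6
```

Here each group has $K_i = 81$, so $K = 162$. Each group contains three distinct readings (k = 3), which gives six ones in total.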

Fig. 1. RFID and WiFi feature binary example diagrams.

Representing WiFi data only by RSSI, so that different hotspots are indistinguishable, is too simplistic. To highlight the characteristics of multimodal data, we introduce an embedding of WiFi hotspots in addition to the embedding of WiFi RSSI. Embedding is performed separately on the 28 WiFi hotspots. The WiFi hotspot vector of the i-th data sample is defined as $a_i \in \mathbb{R}^{K}$, where K is the dimension of the WiFi hotspot vector (K = 28 here). $a_i[j]$ denotes the j-th hotspot of $a_i$, with $a_i[j] \in \{0, 1\}$ and $\sum_{j=1}^{K} a_i[j] = k$. When k = 1, the point receives only one unique WiFi hotspot; when k > 1, it receives multiple WiFi hotspots. For the embedding of the WiFi hotspot vector $a_i$, let $P^i = [p_1^i, \ldots, p_j^i, \ldots, p_K^i] \in \mathbb{R}^{H \times K}$, where $p_j^i \in \mathbb{R}^{H}$ is an H-dimensional embedding vector. If $a_i[j] = 1$ for $j \in \{i_1, i_2, \ldots, i_k\}$, the embedding representation of $a_i$ is the list $\{e_{i_1}, e_{i_2}, \ldots, e_{i_k}\} = \{p_{i_1}^i, p_{i_2}^i, \ldots, p_{i_k}^i\}$.
After the vectors are converted into high-dimensional binary form, embedding transforms them into low-dimensional dense representations. For the i-th feature group $t_i$, let $W^i = [w_1^i, \ldots, w_j^i, \ldots, w_{K_i}^i] \in \mathbb{R}^{D \times K_i}$ denote the i-th embedding lookup table, where $w_j^i \in \mathbb{R}^{D}$ is a D-dimensional embedding vector. The lookup table works as follows:
• If $t_i$ is a one-hot vector whose j-th element $t_i[j] = 1$, the embedding representation of $t_i$ is a single vector $e_i = w_j^i$.
• If $t_i$ is a multi-hot vector with $t_i[j] = 1$ for $j \in \{i_1, i_2, \ldots, i_k\}$, the embedding representation of $t_i$ is the list $\{e_{i_1}, e_{i_2}, \ldots, e_{i_k}\} = \{w_{i_1}^i, w_{i_2}^i, \ldots, w_{i_k}^i\}$.
Fig. 2 shows the data before and after embedding. After embedding, the originally irregular RFID and WiFi data become distinguishable. “WiFi Join” in the figure denotes the result of concatenating the WiFi RSSI and WiFi hotspot representations.
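The lookup-table mechanism, followed by a pooling step, can be sketched as below. This is a minimal sketch under stated assumptions. The table is randomly initialized here, whereas in the model it is learned. Sum pooling is assumed because the paper mentions the Embedding and Pooling paradigm without naming the pooling operator. All sizes and names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
K_i, D = 81, 8                        # group size and embedding dimension (illustrative)
W_i = rng.normal(size=(D, K_i))       # lookup table W^i in R^{D x K_i} (random stand-in)

def lookup(t_i, W):
    """Return the embedding list {w_j : t_i[j] = 1} as an array of shape (k, D)."""
    idx = np.flatnonzero(t_i)
    return W[:, idx].T

def sum_pool(embeddings):
    """Pool a variable-length embedding list into one fixed-size D-dim vector (assumed)."""
    return embeddings.sum(axis=0)

t = np.zeros(K_i, dtype=np.int8)
t[[5, 40, 68]] = 1                    # multi-hot vector with k = 3
E = lookup(t, W_i)
print(E.shape, sum_pool(E).shape)     # (3, 8) (8,)
```

A one-hot vector yields a list with a single row, which corresponds to $e_i = w_j^i$. The hotspot table $P^i$ works the same way, with H in place of D.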

Fig. 2. Comparison results of RFID and WiFi data before and after embedding.

3.2 Location Activate Unit
How can the importance of data from different modalities be expressed compactly in a vector of limited dimension? Inspired by the key-value attention mechanism, we propose a novel attention computation, named the Location Activate Unit (LAU), for the multimodal indoor positioning task. The first goal is to let the model adaptively compute the degree of correlation between features when it learns from different modalities for a unified task. The second goal is to fuse data from different modalities better through the LAU computation, mapping low-dimensional information into a high-dimensional space. The principle of LAU is described below.


In LAU, the RFID RSSI data and the WiFi data output by the embedding layers are denoted R_Vec and W_Vec, respectively, where W_Vec contains both WiFi RSSI and WiFi hotspot information. The LAU algorithm proceeds as follows. R_Vec is linearly transformed to obtain the query (Q), and W_Vec is linearly transformed to obtain the key (K). The length of K is recorded at the same time as Key_length. The purpose of this design is to use WiFi data to help the model learn the RFID signals, and their corresponding locations, that are most useful for the task. When the attention distribution is computed, a dot-product model first combines Q and K into a new query vector $Q_{\mathrm{new}}$, as shown in formula (1). In that formula, D is the dimension of the vectors after embedding, and $q_i$, $k_i$ are the components of Q and K at the i-th index position. Subtraction and element-wise multiplication (Hadamard product) with the original K are then applied separately to obtain two sets of weight expressions, $QK_1$ and $QK_2$, as shown in formulas (2) and (3). This step requires Q and K to have the same dimension.

$$Q_{\mathrm{new}} = \sum_{i=1}^{D} q_i^{\top} k_i \tag{1}$$

$$QK_1 = \sum_{i=1}^{D} \lvert q_i - k_i \rvert \tag{2}$$

$$QK_2 = Q_{\mathrm{new}} \ast K \tag{3}$$

Finally, the scoring function $s(K, Q_{\mathrm{new}})$ concatenates the inputs with $QK_1$ and $QK_2$, as shown in formula (4), to obtain the final attention weight matrix:

$$s(K, Q_{\mathrm{new}}) = \frac{\mathrm{concat}(Q_{\mathrm{new}}, K, QK_1, QK_2)}{\sqrt{K_{\mathrm{dim}}}} \tag{4}$$

Here $K_{\mathrm{dim}}$ is the dimension of K. Dividing by $\sqrt{K_{\mathrm{dim}}}$ scales the variance back to that of the original distribution, which helps keep gradients stable during training. In the attention weight matrix produced by s, larger values indicate higher relevance and therefore highlight more important information. Unlike conventional attention methods, LAU does not normalize with the softmax function, because softmax lacks fine-grained resolution for scene-sensitive features. For example, in some specific scenarios, softmax may treat slight fluctuations in WiFi signals as noise. Instead, as a relatively independent structure, LAU passes the attention weight matrix through the SoftReLU activation function and a fully connected layer, which map it to a specified dimension. The output serves as the activation weight matrix (Fig. 3).
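The following minimal NumPy sketch gives one literal reading of formulas (1) to (4). It assumes that Q and K are per-sample D-dimensional vectors. Under that reading, $Q_{\mathrm{new}}$ and $QK_1$ are scalars per sample, and $QK_2 = Q_{\mathrm{new}} \ast K$ broadcasts over K. The projection matrices, the output size, and all names are hypothetical random stand-ins, and the trained model may arrange these dimensions differently.

```python
import numpy as np

rng = np.random.default_rng(0)
B, D_R, D_W, D, OUT = 4, 16, 24, 8, 5      # batch, input dims, projected dim, output dim

W_q = rng.normal(size=(D_R, D))            # linear map R_Vec -> Query (hypothetical)
W_k = rng.normal(size=(D_W, D))            # linear map W_Vec -> Key (hypothetical)
W_out = rng.normal(size=(2 * D + 2, OUT))  # final fully connected layer (hypothetical)

def soft_relu(v, alpha=0.25, eps=1e-8):
    # Per-feature batch statistics, following formulas (5) and (6)
    p = 1.0 / (1.0 + np.exp(-(v - v.mean(axis=0)) / np.sqrt(v.var(axis=0) + eps)))
    return p * v + (1.0 - p) * alpha * v

def lau(r_vec, w_vec):
    q = r_vec @ W_q                                        # Q: (B, D)
    k = w_vec @ W_k                                        # K: (B, D)
    q_new = np.sum(q * k, axis=1, keepdims=True)           # (1): sum_i q_i k_i -> (B, 1)
    qk1 = np.sum(np.abs(q - k), axis=1, keepdims=True)     # (2): sum_i |q_i - k_i| -> (B, 1)
    qk2 = q_new * k                                        # (3): Q_new * K -> (B, D)
    s = np.concatenate([q_new, k, qk1, qk2], axis=1) / np.sqrt(k.shape[1])  # (4)
    # No softmax: SoftReLU and a fully connected layer produce the activation weights
    return soft_relu(s) @ W_out

weights = lau(rng.normal(size=(B, D_R)), rng.normal(size=(B, D_W)))
print(weights.shape)  # (4, 5)
```

Replacing softmax with SoftReLU keeps the raw scale of the scores. This is the property the text argues preserves small but meaningful WiFi fluctuations.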


Fig. 3. Network Architecture. The left part is our proposed MMILN model, where the input is composed of RFID RSSI, WiFi RSSI, and WiFi hotspot information. The right part illustrates the structure of LAU, where the data after embedding is processed by LAU to calculate location weights.

3.3 SoftReLU Activation Function
PReLU has a hard rectification point at zero input, which is not well suited to RSSI-type data. To smooth the abrupt change at the boundary between the intervals of the activation function, we propose a new adaptive activation function, SoftReLU. SoftReLU introduces an indicator function p(v) to control the shape of the function, and p(v) is computed as in formula (6). During training, E[v] and Var[v] are the mean and variance of each batch of input data. At test time, they are the mean and variance of the entire data set. ε is an empirical constant that prevents the denominator from being zero, and it is set to $10^{-8}$. Recomputing the mean and variance for each batch lets the model better learn the distributions of the different data modalities, which accelerates training and convergence. SoftReLU is computed as in formula (5), where α is the slope of the negative part, as in PReLU. When E[v] = 0 and Var[v] = 0 (with ε → 0), p(v) becomes the step function 1(v > 0), and SoftReLU degenerates into PReLU.

$$\phi(v) = p(v)\, v + \bigl(1 - p(v)\bigr)\, \alpha v \tag{5}$$

$$p(v) = \frac{1}{1 + e^{-\frac{v - E[v]}{\sqrt{\mathrm{Var}[v] + \epsilon}}}} \tag{6}$$
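A minimal NumPy sketch of formulas (5) and (6) follows. It uses batch statistics during training and accepts data-set statistics at test time, as described above. The default α = 0.25 and the convention of passing test-time statistics as arguments are assumptions. In a real model, α would be learned, as in PReLU.

```python
import numpy as np

def soft_relu(v, alpha=0.25, eps=1e-8, mean=None, var=None):
    """SoftReLU: phi(v) = p(v) * v + (1 - p(v)) * alpha * v.

    With mean/var left as None (training), per-feature batch statistics are used;
    otherwise the supplied data-set statistics are used (test time).
    """
    v = np.asarray(v, dtype=float)
    mean = v.mean(axis=0) if mean is None else mean
    var = v.var(axis=0) if var is None else var
    p = 1.0 / (1.0 + np.exp(-(v - mean) / np.sqrt(var + eps)))   # formula (6)
    return p * v + (1.0 - p) * alpha * v                         # formula (5)

out = soft_relu(np.array([-1.0, 0.0, 1.0]))
print(out)
```

Unlike PReLU, the switch between the slopes 1 and α is smooth and centered on the batch mean rather than on zero. For the input above, φ(1) ≈ 0.83 lies between α and 1 instead of being exactly 1.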


4 Experiment
4.1 Data Collection
The RFID reader used in this paper is the Impinj R700 ultra-high-frequency (UHF) reader, equipped with the UHF2599 reader antenna, as shown in Fig. 4. The UHF2599 antenna transmits the reader's electromagnetic waves and receives the signals returned by the tag, achieving high efficiency and performance in all frequency bands. The RFID tag is a lightweight chip terminal that stores information and responds to requests issued by the reader at the same frequency. We use the Smartrac Dogbone U8 copper paper tag, with an operating frequency of 860 MHz to 960 MHz. This tag meets general indoor positioning requirements and is easy to use for locating people and objects. We also developed an Android-based WiFi signal visualization application that obtains real-time WiFi RSSI and WiFi name data by calling the phone's built-in signal receiver. The built-in antenna of a Samsung S8 smartphone served as the WiFi signal receiver for collecting WiFi data, which meets the experimental requirements of this paper.

Fig. 4. Impinj R700 RFID Reader (left) and UHF2599 Reader Antenna (right)

The experimental scenario in this paper is an indoor environment measuring 5 m × 5 m. To avoid the collision problem of RFID systems with multiple readers and tags, we use a single RFID reader and a single tag. The signal returned by the indoor RFID tag is collected by moving the reader from different angles, which simulates the movement of an indoor object. The different WiFi signals and their strength changes are recorded simultaneously during the movement. The collected data are processed in Python. The hardware deployment and data collection process are shown in Fig. 5. The localization target is placed at a random position (x, y) inside the room. The RFID reader antenna moves in a straight line from (5, 0) and from (0, 5) toward (5, 5), uniformly collecting 50 RFID RSSI values returned by the target during the movement. At the same time, the WiFi signals come from different access points at different locations. Based on an assessment of indoor WiFi signal strength, 28 WiFi hotspots are selected to provide WiFi RSSI and WiFi name data. Each RFID record is thus a list of length 52, [rfid_rssi_1, rfid_rssi_2, ..., rfid_rssi_50, x, y]. Each WiFi record is a list of length 30, [wifi_1, wifi_2, ..., wifi_28, x, y], where wifi_i represents the WiFi RSSI and name of the i-th WiFi hotspot and (x, y) is the coordinate recorded during data collection. The RFID and WiFi data are aggregated according to the coordinate (x, y) of the localization target. Each final record is a list of length 80: [rfid_rssi_1, ..., rfid_rssi_50, wifi_1, ..., wifi_28, x, y].
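The aggregation step can be sketched as a join on the target coordinate. This minimal sketch assumes record layouts exactly as described above: 50 RSSI values plus (x, y) for RFID, and 28 values plus (x, y) for WiFi. The dictionary-based join and the function name are illustrative. Matching floating-point coordinates by exact equality is a simplification, and real logs may need tolerance-based matching.

```python
from typing import Dict, List, Tuple

N_RFID, N_WIFI = 50, 28

def aggregate(rfid_rows: List[List[float]], wifi_rows: List[List[float]]) -> List[List[float]]:
    """Join RFID (length 52) and WiFi (length 30) records on the coordinate (x, y)."""
    wifi_by_xy: Dict[Tuple[float, float], List[float]] = {
        (row[-2], row[-1]): row[:N_WIFI] for row in wifi_rows
    }
    merged = []
    for row in rfid_rows:
        xy = (row[-2], row[-1])
        if xy in wifi_by_xy:          # keep only coordinates present in both sources
            merged.append(row[:N_RFID] + wifi_by_xy[xy] + list(xy))
    return merged

rfid = [[-60.0] * N_RFID + [2.53, 1.42]]
wifi = [[-70.0] * N_WIFI + [2.53, 1.42]]
rows = aggregate(rfid, wifi)
print(len(rows), len(rows[0]))   # 1 80
```

The merged length, 50 + 28 + 2 = 80, matches the record length stated above.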

Fig. 5. Data collection schematic. The RFID reader antenna collects the signal values returned by the tag along a predetermined trajectory from two different directions, while the WiFi data are recorded simultaneously.

4.2 Evaluation Indicators
Indoor localization can be treated either as a classification task using grid partitioning or as a regression task that predicts coordinates. With grid partitioning, only the target grid cell needs to be predicted. This approach is simple and can achieve good accuracy when the grid is fine. However, as the grid becomes finer, the number of target classes grows, which makes the model harder to train and to predict with. In the regression formulation, the target coordinates are predicted as continuous values, and the error can be minimized directly. This formulation is also more conducive to model learning and prediction. To improve localization accuracy, this paper therefore treats indoor localization as a regression task. The regression metrics used in this paper are introduced below. In both formulas, $f(x_i)$ is the prediction for sample i, $y_i$ is its ground truth, and n is the number of samples.
(1) Mean Absolute Error (MAE) is the main accuracy metric in this paper. The lower its value, the smaller the error.

$$\mathrm{MAE}(x_i, y_i) = \frac{1}{n} \sum_{i=1}^{n} \lvert f(x_i) - y_i \rvert \tag{7}$$

(2) Root Mean Square Error (RMSE). The lower its value, the smaller the error.

$$\mathrm{RMSE}(x_i, y_i) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} \bigl(f(x_i) - y_i\bigr)^2} \tag{8}$$
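Formulas (7) and (8) translate directly into NumPy. The paper does not state how the errors in x and y are combined for 2D coordinates. This sketch applies the formulas element-wise to whatever arrays are passed in, which is an assumption. Another common choice is to first compute each point's Euclidean error, for example with `np.linalg.norm(pred_xy - true_xy, axis=1)`, and compare that against zero.

```python
import numpy as np

def mae(pred, true):
    """Formula (7): mean of |f(x_i) - y_i|."""
    return float(np.mean(np.abs(np.asarray(pred) - np.asarray(true))))

def rmse(pred, true):
    """Formula (8): square root of the mean of (f(x_i) - y_i)^2."""
    return float(np.sqrt(np.mean((np.asarray(pred) - np.asarray(true)) ** 2)))

pred, true = [1.0, 2.0, 3.0], [1.0, 1.0, 1.0]   # errors 0, 1, 2
print(mae(pred, true), rmse(pred, true))        # 1.0 and sqrt(5/3) ≈ 1.291
```

RMSE is never smaller than MAE, and the gap grows with the spread of the errors. Small gaps between the two metrics therefore indicate fairly uniform errors across test points.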


4.3 Experimental Comparison Scheme
This section uses ablation experiments to analyze the contribution of each component of the proposed MMILN model during training. Fig. 6 shows how the training-set loss changes over 250 epochs under the same experimental environment. The configurations compared are single-modal data, multimodal data, the embedding layers, and the SoftReLU activation function. Although the embedding layer extracts more features and helps the model learn better, its convergence is relatively slow due to other factors. After SoftReLU is added, the different layers of the model are re-adapted, and the data distribution of each modality is computed separately. As a result, the loss of the purple curve decreases more rapidly and steadily. These comparisons show whether each component contributes to the MMILN model. Taking the blue curve as the baseline, the yellow curve verifies the role of the embedding layer. Taking the yellow curve as the baseline, the green curve verifies the effect of multimodal data. Comparing the red curve with the yellow curve, and the purple curve with the green curve, verifies the effect of the adaptive activation function SoftReLU (Table 1).

Fig. 6. Training-set loss of MMILN for different data and methods.

Table 1 shows the localization results of the ablation experiments on the MMILN model. MAE is used as the evaluation index. The number of parameters and the average prediction time of each model are also reported, because indoor positioning is a real-time task. Because the models use different data and components, their parameter counts and prediction times differ slightly. These differences come mainly from the embedding layer. Compared with RFID RSSI data alone, the multimodal RFID and WiFi data require an additional embedding layer to encode the WiFi data. The results also show the effectiveness of the proposed adaptive activation function SoftReLU. For both single-modal RFID data and combined RFID + WiFi multimodal data, SoftReLU greatly improves prediction accuracy at very low cost. MAE drops from 0.743 m to 0.596 m for RFID data and from 0.565 m to 0.301 m for RFID + WiFi data, with the same number of parameters and essentially unchanged prediction time. This indicates that learning the data distribution is crucial for deep learning models. For data of different modalities in particular, learning each modality's distribution is very helpful for prediction performance.

Table 1. Comparison of positioning results of MMILN model ablation experiments. The first column indicates the changes made to the MMILN model, with “-” marking removed parts.

Compared to MMILN (±)    Data          MAE (m)   Number of parameters   Prediction time
- Embedding - SoftReLU   RFID          1.549     524288                 71 ms
- SoftReLU               RFID          0.743     579432                 75 ms
- SoftReLU               RFID, WiFi    0.565     639274                 103 ms
None                     RFID          0.596     579432                 75 ms
None                     RFID, WiFi    0.301     639274                 102 ms

Table 2. Comparison of experimental results before and after adding LAU.

Model        MAE (m)   RMSE (m)   Prediction time
LSTM         1.985     2.149      34 ms
LSTM-LAU     1.912     1.993      106 ms
CNN          0.776     0.784      27 ms
CNN-LAU      0.694     0.702      89 ms
MMILN        0.301     0.304      102 ms
MMILN-LAU    0.172     0.180      153 ms

To visualize the effect of LAU on indoor positioning, five samples were randomly selected from the test set, with coordinates (3, 3.75), (2.53, 1.42), (0.89, 2.37), (4.89, 2.57), and (1.69, 4.78). Their coordinates were predicted by models with and without LAU. To test whether LAU genuinely improves the use of multimodal data and generalizes to other architectures, LSTM and CNN models are included for comparison. Fig. 7 shows that the predictions of each model after adding LAU are closer to the original red coordinate points, indicating improved prediction accuracy. The MMILN model is clearly better than the other two models. The quantitative results are given in Table 2. They show that LAU is not robust on the CNN and LSTM models. Positioning accuracy improves slightly overall, but the effect is not ideal in some scenarios.
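The observation that LAU helps MMILN much more than LSTM or CNN can be quantified directly from Table 2. The snippet below only performs arithmetic on the reported numbers. It does not reproduce the experiments, and the helper name is hypothetical.

```python
# MAE (m) and prediction time (ms) copied from Table 2
results = {
    "LSTM":  {"base": (1.985, 34),  "lau": (1.912, 106)},
    "CNN":   {"base": (0.776, 27),  "lau": (0.694, 89)},
    "MMILN": {"base": (0.301, 102), "lau": (0.172, 153)},
}

def lau_effect(base, lau):
    """Relative MAE reduction (%) and prediction-time ratio when LAU is added."""
    (mae0, t0), (mae1, t1) = base, lau
    return 100.0 * (mae0 - mae1) / mae0, t1 / t0

for model, r in results.items():
    reduction, slowdown = lau_effect(r["base"], r["lau"])
    print(f"{model}: MAE -{reduction:.1f}%, time x{slowdown:.2f}")
```

With these figures, LAU lowers MAE by about 3.7% (LSTM), 10.6% (CNN), and 42.9% (MMILN). Prediction time grows by factors of about 3.12, 3.30, and 1.50, respectively. LAU is therefore both most effective and relatively cheapest on MMILN.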


Fig. 7. Indoor positioning results. Red circles are the original coordinate points. Orange prisms denote the LSTM and LSTM-LAU models, blue triangles the CNN and CNN-LAU models, and green quadrilaterals the MMILN and MMILN-LAU models. The left part shows the results without LAU, and the right part shows the results with LAU.

5 Conclusion
In this paper, we present a multimodal indoor positioning method based on RFID and WiFi data. On this basis, we propose an LAU algorithm built on the key-value attention mechanism to improve positioning accuracy. We also introduce the SoftReLU activation function to address problems in the training process. Ablation and comparison experiments verify the effectiveness of the proposed methods. We hope that our method can serve as a solid baseline for indoor positioning and motivate further research.
Acknowledgment. This paper is supported by the 2021 Fujian Foreign Cooperation Project (No. 2021I0001): Research on Human Behavior Recognition Based on RFID and Deep Learning; 2021 Project of Xiamen University (No. 20213160A0474): Zijin International Digital Operation Platform Research and Consulting; State Key Laboratory of Process Automation in Mining & Metallurgy, Beijing Key Laboratory of Process Automation in Mining & Metallurgy (No. BGRIMM-KZSKL-2022-14): Research and application of mine operator positioning based on RFID and deep learning; National Key R&D Program of China-Sub-project of Major Natural Disaster Monitoring, Early Warning and Prevention (No. 2020YFC1522604): Research on key technologies of comprehensive information application platform for cultural relic safety based on big data technology.

References
1. Niu, J., Yang, L., Song, X.: A working and verification framework of marine BDS high precision positioning system. In: 2020 IEEE 9th Data Driven Control and Learning Systems Conference (DDCLS), pp. 314–318. IEEE (2020)
2. Sujatmik, B., Fitria, N., Fathania, D., et al.: Comparison of accuracy between smartphone-GPS and professional-GPS for mapping tuberculosis patients in Bandung city (a preliminary study). In: 2017 5th International Conference on Instrumentation, Communications, Information Technology, and Biomedical Engineering (ICICI-BME), pp. 299–303. IEEE (2017)
3. Kong, S., Peng, J., Liu, W., et al.: GNSS system time offset real-time monitoring with GLONASS ICBs estimated. In: 2018 IEEE 3rd International Conference on Image, Vision and Computing (ICIVC), pp. 837–840. IEEE (2018)
4. Yassin, A., Nasser, Y., Awad, M., Al-Dubai, A., Liu, R., Yuen, C., et al.: Recent advances in indoor localization: a survey on theoretical approaches and applications. IEEE Commun. Surveys Tutor. 19(2), 1327–1346 (2016)
5. Bai, Y.B., Wu, S., Wu, H.R., et al.: Overview of RFID-based indoor positioning technology. GSR (2012)
6. Yang, B., Guo, L., Guo, R., et al.: A novel trilateration algorithm for RSSI-based indoor localization. IEEE Sens. J. 20(14), 8164–8172 (2020)
7. Liu, T., Yang, L., Lin, Q., Guo, Y., Liu, Y.: Anchor-free backscatter positioning for RFID tags with high accuracy. In: 2014 Proceedings of IEEE International Conference on Computer Communications, pp. 379–387. IEEE (2014)
8. Wang, J., Adib, F., Knepper, R., Katabi, D., Rus, D.: RF-compass: robot object manipulation using RFIDs. In: International Conference on Mobile Computing & Networking, pp. 3–14 (2013)
9. Miesen, R., Kirsch, F., Vossiek, M.: Holographic localization of passive UHF RFID transponders. In: IEEE International Conference on RFID, pp. 32–37 (2011)
10. Wang, J., Vasisht, D., Katabi, D.: RF-IDraw: virtual touch screen in the air using RF signals. In: ACM Conference on SIGCOMM (2014)
11. Liu, X., Wen, M., Qin, G., et al.: LANDMARC with improved k-nearest algorithm for RFID location system. In: 2016 2nd IEEE International Conference on Computer and Communications (ICCC), pp. 2569–2572. IEEE (2016)
12. Hightower, J., Vakili, C., Borriello, G., et al.: Design and calibration of the SpotON ad-hoc location sensing system (2001, unpublished)
13. Nikitin, P.V., Martinez, R., Ramamurthy, S., Leland, H.: Phase based spatial identification of UHF RFID tags. In: IEEE International Conference on RFID, pp. 102–109 (2010)
14. Wang, J., Katabi, D.: Dude, where’s my card?: RFID positioning that works with multipath and non-line of sight. In: ACM SIGCOMM Computer Communication Review, vol. 43, pp. 51–62. ACM (2013)
15. Shangguan, L., Jamieson, K.: The design and implementation of a mobile RFID tag sorting robot. In: International Conference on Mobile Systems (2016)
16. Kim, K.S., et al.: Large-scale location-aware services in access: hierarchical building/floor classification and location estimation using Wi-Fi fingerprinting based on deep neural networks. Fiber Integr. Opt. 37(5), 277–289 (2018)
17. Bai, J., Sun, Y., Meng, W., Li, C.: Wi-Fi fingerprint-based indoor mobile user localization using deep learning. Wirel. Commun. Mob. Comput. 2021(1), 1–12 (2021)
18. Xingli, G., Yaning, L., Ruihui, Z.: Indoor positioning technology based on deep neural networks. In: 2018 Ubiquitous Positioning, Indoor Navigation and Location-based Services (UPINLBS), pp. 1–6. IEEE (2018)

Undetectable Attack to Deep Neural Networks Without Using Model Parameters
Chen Yang1, Yinyan Zhang1(B), and Ameer Hamza Khan2
1 The College of Cyber Security, Jinan University, Guangzhou 510632, China


2 Department of Land Surveying and Geo-Informatics, The Hong Kong Polytechnic University, Kowloon 999077, Hong Kong, China

Abstract. Deep neural networks (DNNs) have been widely deployed in a diverse array of tasks, such as image classification. However, recent research has revealed that intentionally adding perturbations to the input samples of a DNN can cause the model to misclassify them. Such adversarial samples can fool highly proficient convolutional neural network (CNN) classifiers, and this vulnerability may have severe implications for the security of the targeted applications. In this work, we show that attacks on CNNs can be carried out successfully without knowing the model parameters of the target network. We use the beetle antennae search algorithm to realize the attack in a way that human eyes cannot detect. Compared with other adversarial attack algorithms, the adversarial samples produced by our algorithm are not significantly different from the pre-attack images, which makes the attack undetectable. The CIFAR-10 dataset is used to show the efficacy and advantages of the algorithm on the LeNet-5 and ResNet architectures. Our findings indicate that the proposed algorithm produces images with no significant difference from the original images while achieving a high attack success rate.
Keywords: Deep neural networks · Nature-inspired algorithm · Undetectable attack

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 46–57, 2023. https://doi.org/10.1007/978-981-99-4742-3_4
1 Introduction
Deep neural networks (DNNs) have been widely investigated, with particular focus on convolutional neural networks (CNNs), which are now frequently used for image classification [1]. Despite their widespread application, recent research has revealed security vulnerabilities in DNNs: they can be attacked by adding perturbations to the original inputs [2, 3]. In practical applications, this vulnerability to adversarial samples may lead to serious security problems. For example, an attacker can mislead an autonomous driving system into disobeying traffic rules [4]. This paper focuses on the intriguing phenomenon in DNNs when employed for

Undetectable Attack to Deep Neural Networks


image classification tasks. Specifically, intentionally adding perturbations to an input image may cause the classifier to misclassify it with high confidence. This is called an adversarial attack, and the images generated by an algorithm for this purpose are called “adversarial images”. There are broadly two types of adversarial attacks. Attacks in which the attacker needs all the information and parameters of the target neural network are called white-box attacks. Attacks deployed without information about the parameters and structure of the target network are called black-box attacks, and they rely only on the inputs and outputs of the model [5]. The most common way to generate adversarial images is to add well-selected perturbations to a correctly classified natural image so that the classifier misclassifies it. Szegedy et al. [2] first observed this intriguing phenomenon in DNNs used for image classification, revealing their vulnerability to adversarial attacks. Perturbations can be computed by several gradient-based algorithms that use back-propagation to obtain gradient information. This problem has attracted considerable attention, and various methods have been proposed to attack image classification models. Goodfellow et al. [6] demonstrated such an attack by adding a perturbation generated in the direction of gradient ascent. Papernot et al. [7] used the Jacobian matrix to build an “adversarial saliency map” that identifies the pixel positions contributing most to the network's misclassification. Moosavi-Dezfooli et al. [8] proposed DeepFool, which obtains the perturbation by assuming that DNN decision boundaries are linear. Carlini et al. [9] proposed the C&W attacks, based on the L0, L2, and L∞ distance metrics respectively, each of which can successfully break defensive distillation.
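The gradient-based family can be illustrated with a minimal sketch of the fast gradient sign method (FGSM) of Goodfellow et al. [6]. The sketch uses a toy logistic-regression classifier, for which the input gradient of the cross-entropy loss has the closed form (σ(w·x + b) − y)w. The weights, input, and ε are hypothetical. The attack reads the model parameters directly, which is exactly the white-box requirement that the black-box method of this paper avoids.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def loss(x, y, w, b):
    """Binary cross-entropy of a logistic-regression classifier."""
    p = sigmoid(w @ x + b)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def fgsm(x, y, w, b, eps):
    """x_adv = x + eps * sign(grad_x loss); needs the model parameters (white-box)."""
    grad = (sigmoid(w @ x + b) - y) * w
    return x + eps * np.sign(grad)

w, b = np.array([2.0, -1.0, 0.5]), 0.1         # hypothetical model parameters
x, y = np.array([0.5, 0.2, 0.4]), 1            # classified as class 1 (p = 0.75)
x_adv = fgsm(x, y, w, b, eps=0.5)
print(sigmoid(w @ x + b), sigmoid(w @ x_adv + b))
```

A single signed-gradient step of size 0.5 moves every feature and pushes the prediction below 0.5, flipping the label. Perturbations spread over all pixels in this way are often visible at large ε, which motivates bounding the perturbation intensity.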
Su et al. [10] used differential evolution to efficiently find the best pixel to perturb, deceiving the model by modifying only a single pixel. However, the above strategies either require estimating the gradient of the target function or generate adversarial images that human eyes can readily detect (see Fig. 1), with notable differences between the original and post-attack images: the attacked image contains pixels that clearly do not belong to it and differ markedly in color. To solve this problem, we adopt the beetle antennae search (BAS) algorithm [12] to develop a black-box attack that can effectively fool a well-designed DNN. We formulate the adversarial attack as an optimization problem whose decision variables are the pixel positions and their RGB values. The optimizer minimizes the neural network’s confidence in the true class, causing the classifier to output a wrong class. We then propose an optimization algorithm that searches for pixel perturbations that fool the CNN. In each iteration, the number of modified pixels is not restricted, but the modification intensity of each pixel is bounded, so the attacked image is not noticeably different from the original and one cannot tell whether it has been attacked. The vulnerability of DNNs can lead to security issues, and more and more researchers are addressing this challenge. Besides being simple and practical, the proposed attack highlights the lack of robustness of deep learning models. Our main contributions are as follows.


C. Yang et al.

Fig. 1. The images are displayed in two rows: the first row shows the original images, and the second row shows the images perturbed by the method proposed in [11] to attack the ResNet model. The misclassifications are: (a) an airplane misclassified as a bird, (b) a ship misclassified as an airplane, (c) an automobile misclassified as an airplane, and (d) a dog misclassified as a deer.

1. The proposed method is simple and imperceptible: it not only fools the neural network but also remains undetected by human eyes.
2. We show that adversarial perturbations are easy to generate without information about the neural network architecture or its internal parameters, so the algorithm can be applied to different classifiers to assess their robustness to adversarial perturbations.
3. We demonstrate the algorithm’s effectiveness through pixel modifications and show that it outperforms the algorithm proposed in [11].

2 Preliminaries

Here, we introduce CNNs, formulate the optimization problem mathematically, and discuss undirected and directed attacks.

2.1 Convolutional Neural Networks

The first popular CNN model, AlexNet, was proposed by Krizhevsky et al. in 2012 [13]. Inspired by AlexNet, researchers have proposed further CNN models such as VGGNet [14], GoogLeNet [15], ResNet [16], and DenseNet [17]. All CNNs aim to automatically learn important features from training data. In a simple CNN, the image passes through multiple convolutional and downsampling layers, then fully connected layers, and finally a softmax output layer that provides the recognition result. In an image classification task, a CNN can be regarded as a non-linear function $f_{\text{cnn}}$ whose input is the image $X_{\text{clean}}$ and whose output is the probability of belonging to each class:

$$p = f_{\text{cnn}}(X_{\text{clean}}), \tag{1}$$

where $p$ is the vector of class probabilities, each element of which corresponds to one output class. Let the vector $C$ collect all classes:

$$C = [C_1, C_2, \ldots, C_N]^{T} \in \mathbb{R}^{N}, \tag{2}$$


where $C_1, C_2, \ldots, C_N$ are the class labels and $N$ is the total number of output classes. The neural network outputs the probability of each class and takes the class with the greatest probability as its prediction:

$$i^{*} = \arg\max_{i}\, p[i], \qquad C^{*} = C_{i^{*}}, \tag{3}$$

where $C^{*}$ is the predicted class label and $\arg\max_i$ returns the index of the largest element of $p$. As (1) and (3) show, the predicted class $C^{*}$ is a function of $X_{\text{clean}}$, which can be written compactly as

$$C^{*} = F(X_{\text{clean}}). \tag{4}$$
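As a minimal sketch of Eqs. (3) and (4), assuming the classifier output $p$ is a NumPy probability vector, the prediction is simply an argmax followed by a label lookup, so $F$ is this function composed with $f_{\text{cnn}}$. The helper name `predict_class`, the CIFAR-10-style label list, and the example vector are illustrative assumptions, not code from the paper:

```python
import numpy as np

# Illustrative CIFAR-10-style class list C (assumption; any ordered label list works).
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

def predict_class(p: np.ndarray) -> str:
    """Eq. (3): i* = argmax_i p[i], and the prediction is C* = C_{i*}."""
    i_star = int(np.argmax(p))
    return CLASSES[i_star]

# A hypothetical probability vector p = f_cnn(X_clean) for one image.
p = np.array([0.05, 0.02, 0.60, 0.03, 0.10, 0.05, 0.05, 0.04, 0.03, 0.03])
print(predict_class(p))   # bird
```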

To make the formulas in subsequent sections more concise, we define the function

$$P(p, C_i) = p[i], \tag{5}$$

which takes the probability vector $p$ and a class label $C_i$ as inputs and returns the probability of that class.

2.2 Image Perturbation

Assume that the original image $X_{\text{clean}}$ has $m \times n$ pixels. We need to define a perturbation function $P$. Besides the image, $P$ takes three parameters: the rows $r \in \{1, 2, \ldots, m\}^{k \times 1}$, the columns $c \in \{1, 2, \ldots, n\}^{k \times 1}$, and the RGB values $RGB \in [0, 255]^{k \times 3}$, where $k$ is the number of pixels allowed to be modified. $P$ returns the perturbed image $X_{\text{per}}$, obtained by overwriting the RGB values of the pixels at the given rows and columns:

$$P:\; X_{\text{clean}}(r[i], c[i]) := RGB[i], \quad i \in \{1, 2, \ldots, k\}, \tag{6}$$

where $:=$ is the assignment operator. The perturbed image $X_{\text{per}}$ is thus

$$X_{\text{per}} = P(X_{\text{clean}}, r, c, RGB). \tag{7}$$
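The perturbation function of Eqs. (6) and (7) amounts to copying the image and overwriting $k$ pixels. The sketch below assumes a NumPy `uint8` image of shape `(m, n, 3)` and uses Python’s 0-based indices rather than the 1-based indices in the text; the name `perturb` is hypothetical:

```python
import numpy as np

def perturb(x_clean: np.ndarray, r: np.ndarray, c: np.ndarray,
            rgb: np.ndarray) -> np.ndarray:
    """Eqs. (6)-(7): return X_per, a copy of X_clean in which pixel (r[i], c[i])
    is set to rgb[i] for every i (0-based indices here)."""
    x_per = x_clean.copy()                       # X_clean itself is left untouched
    # Round and clip to the valid [0, 255] range before storing in the image dtype.
    x_per[r, c] = np.clip(np.rint(rgb), 0, 255).astype(x_clean.dtype)
    return x_per

# Usage: overwrite two pixels of a black 4 x 4 image.
x = np.zeros((4, 4, 3), dtype=np.uint8)
x_per = perturb(x, np.array([0, 3]), np.array([1, 2]),
                np.array([[10, 20, 30], [255, 0, 5]]))
print(x_per[0, 1].tolist(), x_per[3, 2].tolist())   # [10, 20, 30] [255, 0, 5]
```

Copying first keeps $X_{\text{clean}}$ available for the $\delta$ constraint introduced next, which compares the new RGB values with the original ones.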

2.3 Undirected Attack

For an undirected attack, if the original image $X_{\text{clean}}$ belongs to class $C_{\text{clean}}$, we only need to minimize the probability that the perturbed image $X_{\text{per}}$ belongs to the true class $C_{\text{clean}}$. The attack succeeds when some other class receives a higher probability than the true class. Based on Eqs. (1) and (5), the objective is

$$X^{*}_{\text{per}} = \arg\min_{X_{\text{per}}} P\big(f_{\text{cnn}}(X_{\text{per}}), C_{\text{clean}}\big), \tag{8}$$


where $X^{*}_{\text{per}}$ is the perturbed image with the minimum true-class confidence. If the modification is too strong, i.e., the post-attack RGB values differ greatly from the original ones, the attack can easily be detected. To keep the attack imperceptible to human eyes, we limit the range over which the RGB values may change. Based on (7), we formulate the undirected attack as follows:

$$
\begin{aligned}
r^{*}, c^{*}, RGB^{*} = \arg\min_{r, c, RGB}\;& P\big(f_{\text{cnn}}(P(X_{\text{clean}}, r, c, RGB)), C_{\text{clean}}\big)\\
\text{subject to}\;\;& 0 \le r[i, 1] \le m, \quad \forall i \in \{1, 2, \ldots, k\},\\
& 0 \le c[i, 1] \le n, \quad \forall i \in \{1, 2, \ldots, k\},\\
& 0 \le RGB[i, j],\ RGB^{*}[i, j] \le 255, \quad \forall i \in \{1, 2, \ldots, k\},\ j \in \{1, 2, 3\},\\
& \big|RGB^{*}[i, j] - RGB[i, j]\big| \le \delta, \quad \forall i \in \{1, 2, \ldots, k\},\ j \in \{1, 2, 3\},
\end{aligned}
\tag{9}
$$

in which $r$, $c$, and $RGB$ denote the rows, columns, and RGB values of the original image, and $r^{*}$, $c^{*}$, and $RGB^{*}$ denote those of the perturbed image. The parameter $\delta$ bounds the allowed change in each RGB value; it therefore controls the maximum difference between the original and post-attack images and affects the success rate of the attack.

2.4 Directed Attack

In a directed attack, the class of the perturbed image is specified in advance and called $C_{\text{target}}$. The objective therefore differs from that of an undirected attack: the classifier must output the target class $C_{\text{target}} \ne C_{\text{clean}}$. Based on Eqs. (1) and (5), the objective is the maximization

$$X^{*}_{\text{per}} = \arg\max_{X_{\text{per}}} P\big(f_{\text{cnn}}(X_{\text{per}}), C_{\text{target}}\big). \tag{10}$$

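Similarly, (7) can be used to reformulate the optimization problem:

$$
\begin{aligned}
r^{*}, c^{*}, RGB^{*} = \arg\max_{r, c, RGB}\;& P\big(f_{\text{cnn}}(P(X_{\text{clean}}, r, c, RGB)), C_{\text{target}}\big)\\
\text{subject to}\;\;& 0 \le r[i, 1] \le m, \quad \forall i \in \{1, 2, \ldots, k\},\\
& 0 \le c[i, 1] \le n, \quad \forall i \in \{1, 2, \ldots, k\},\\
& 0 \le RGB[i, j],\ RGB^{*}[i, j] \le 255, \quad \forall i \in \{1, 2, \ldots, k\},\ j \in \{1, 2, 3\},\\
& \big|RGB^{*}[i, j] - RGB[i, j]\big| \le \delta, \quad \forall i \in \{1, 2, \ldots, k\},\ j \in \{1, 2, 3\}.
\end{aligned}
\tag{11}
$$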

To reduce the complexity of the algorithm, we transform the above maximization problem into the following minimization problem, subject to the same constraints as (11):

$$r^{*}, c^{*}, RGB^{*} = \arg\min_{r, c, RGB}\; \Big[1 - P\big(f_{\text{cnn}}(P(X_{\text{clean}}, r, c, RGB)), C_{\text{target}}\big)\Big], \tag{12}$$

where $r^{*}$, $c^{*}$, and $RGB^{*}$ denote the rows, columns, and RGB values that make the model output the target class.


3 Algorithm

In this section, we present a solution to the optimization problems formulated above.

3.1 Beetle Antennae Search (BAS)

Here, we introduce the BAS algorithm [12]. The beetle searches using two antenna endpoints, $x_l$ and $x_r$. At moment $t-1$, the beetle is at position $x_{t-1}$. If the fitness function $f(x)$ represents the strength of the smell at location $x$, the smell intensity at the current position is $f(x_{t-1})$. The beetle seeks the location of maximum smell intensity within a box constraint:

$$\max_{x}\, f(x) \quad \text{subject to: } x_{\min} \le x \le x_{\max}. \tag{13}$$

The positions of the two antenna endpoints are

$$x_l = x_{t-1} + d\,b, \qquad x_r = x_{t-1} - d\,b, \tag{14}$$

where $d \in \mathbb{R}$, $d > 0$, is the antenna length, which determines the beetle’s search range, and $b$ is a random direction vector. Because $b$ is random, the antenna endpoints may violate constraint (13), so we define the feasible set and a projection onto it:

$$\psi = \{x \mid x_{\min} \le x \le x_{\max}\}, \qquad P_{\psi}(x) = \max\{x_{\min}, \min\{x, x_{\max}\}\}, \tag{15}$$

where $\psi$ is the constraint set and $P_{\psi}(\cdot)$ is the element-wise projection onto it, which is simple and efficient to compute. We then project the antenna endpoints onto $\psi$:

$$x_l^{\psi} = P_{\psi}(x_l), \qquad x_r^{\psi} = P_{\psi}(x_r), \tag{16}$$

where the superscript $\psi$ indicates that the vector has been projected onto $\psi$. The smell intensities at the two antennae are $f(x_l^{\psi})$ and $f(x_r^{\psi})$. The beetle’s position at the next moment is

$$x_t = x_{t-1} + \xi\, b\, \operatorname{sign}\big(f(x_l^{\psi}) - f(x_r^{\psi})\big), \tag{17}$$

where $\xi$ is the beetle’s step size. On reaching the new position, the beetle re-evaluates the smell intensity. If the intensity has not decreased, it stays there; otherwise, it returns to $x_k$, the previously visited location with the highest smell intensity:

$$x_t = \begin{cases} x_t, & \text{if } f(x_t) \ge f(x_{t-1}),\\ x_k, & \text{if } f(x_t) < f(x_{t-1}). \end{cases} \tag{18}$$
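To make one iteration of Eqs. (14) to (18) concrete, here is a minimal sketch on a hypothetical two-dimensional toy objective. Two details are our own assumptions rather than the paper’s: $b$ is drawn as a normalized Gaussian vector, and $x_t$ is also projected onto $\psi$. Because Eq. (18) only accepts non-worsening moves, $x_{t-1}$ is always the best location the beetle has stood at, so returning to $x_k$ reduces to staying at $x_{t-1}$:

```python
import numpy as np

def project(x, x_min, x_max):
    """Eq. (15): element-wise projection onto psi = {x | x_min <= x <= x_max}."""
    return np.maximum(x_min, np.minimum(x, x_max))

def bas_step_max(f, x_prev, d, xi, x_min, x_max, rng):
    """One BAS move that tries to increase f, following Eqs. (14)-(18)."""
    b = rng.standard_normal(x_prev.shape)
    b /= np.linalg.norm(b)                       # random unit direction (assumed)
    x_l = project(x_prev + d * b, x_min, x_max)  # left antenna, Eqs. (14), (16)
    x_r = project(x_prev - d * b, x_min, x_max)  # right antenna, Eqs. (14), (16)
    # Eq. (17): move one step of size xi towards the stronger-smelling antenna.
    x_t = x_prev + xi * b * np.sign(f(x_l) - f(x_r))
    x_t = project(x_t, x_min, x_max)             # keep x_t feasible (assumption)
    # Eq. (18): accept only non-worsening moves; otherwise stay at the best point.
    return x_t if f(x_t) >= f(x_prev) else x_prev

# Toy objective (hypothetical): maximum at (0.3, 0.3) inside the box [0, 1]^2.
def smell(x):
    return -float(np.sum((x - 0.3) ** 2))

rng = np.random.default_rng(0)
x0 = np.array([0.9, 0.1])
x = x0
for _ in range(200):
    x = bas_step_max(smell, x, d=0.05, xi=0.02, x_min=0.0, x_max=1.0, rng=rng)
print(np.round(x, 3), smell(x))
```

With a fixed step size the beetle settles within roughly one step of the optimum; the acceptance rule guarantees the smell intensity never decreases along the trajectory.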


When the beetle reaches the new location $x_t$, we regenerate the random vector $b$ and repeat the procedure above until the attack succeeds. The algorithm above solves maximization problems; for minimization problems, Eq. (17) becomes

$$x_t = x_{t-1} - \xi\, b\, \operatorname{sign}\big(f(x_l^{\psi}) - f(x_r^{\psi})\big). \tag{19}$$

3.2 Optimization Algorithm

The optimization algorithm can be summarized in the following steps:

1. The beetle starts from a randomly generated location $x_0$.
2. Randomly generate a direction vector $b$ at the beetle’s current location.
3. Calculate the antenna endpoint positions $x_l$ and $x_r$ using (14).
4. Compute the updated position following Eqs. (17) (or (19) for minimization) and (18). If the attack succeeds, stop; otherwise, return to step 2.
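The four steps above can be sketched as a complete loop for a minimization problem, using Eq. (19) for the move. Since the text states Eq. (18) only for maximization, mirroring its acceptance test (keep $x_t$ if $f(x_t) \le f(x_{t-1})$) is our assumption, as are the names `bas_minimize` and `success` and the toy quadratic cost used to exercise it:

```python
import numpy as np
from typing import Callable, Tuple

def bas_minimize(f: Callable[[np.ndarray], float], dim: int, x_min: float, x_max: float,
                 d: float, xi: float, max_iter: int,
                 success: Callable[[np.ndarray], bool],
                 seed: int = 0) -> Tuple[np.ndarray, float, int]:
    """Steps 1-4 of Sect. 3.2 for minimization; returns (x, f(x), iterations used)."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(x_min, x_max, size=dim)            # step 1: random start x_0
    fx = f(x)
    for it in range(1, max_iter + 1):
        b = rng.standard_normal(dim)                   # step 2: random direction b
        b /= np.linalg.norm(b)
        x_l = np.clip(x + d * b, x_min, x_max)         # step 3: Eq. (14), projected as in (16)
        x_r = np.clip(x - d * b, x_min, x_max)
        # Step 4: Eq. (19) moves away from the antenna with the larger cost.
        x_t = np.clip(x - xi * b * np.sign(f(x_l) - f(x_r)), x_min, x_max)
        f_t = f(x_t)
        if f_t <= fx:                                  # Eq. (18) mirrored (assumption)
            x, fx = x_t, f_t
        if success(x):                                 # stop as soon as the goal is met
            return x, fx, it
    return x, fx, max_iter

# Usage with a toy cost (hypothetical): minimum 0 at the origin of [-1, 1]^4.
def cost(v: np.ndarray) -> float:
    return float(np.sum(v ** 2))

x, fx, used = bas_minimize(cost, dim=4, x_min=-1.0, x_max=1.0, d=0.1, xi=0.05,
                           max_iter=400, success=lambda v: cost(v) < 1e-2)
print(used, round(fx, 4))
```

In the attack, `success` would check whether the classifier’s output has changed as required, and `max_iter` plays the role of maxiter in Sect. 4.1.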

To implement undirected and directed attacks with this optimization algorithm, we define a decision matrix $\mathbf{X}$. By the definitions in Sect. 2.2, $k$ is the number of modifiable pixels, so $\mathbf{X}$ has dimension $k \times 5$:

$$\mathbf{X} = [\, r \;\; c \;\; RGB \,]. \tag{20}$$

Then, the cost function of the undirected attack is

$$f(\mathbf{X}) = P\Big(f_{\text{cnn}}\big(P(X_{\text{clean}}, \mathbf{X}[:, 1], \mathbf{X}[:, 2], \mathbf{X}[:, \{3, 4, 5\}])\big), C_{\text{clean}}\Big). \tag{21}$$

For directed attacks, the cost function is

$$f(\mathbf{X}) = 1 - P\Big(f_{\text{cnn}}\big(P(X_{\text{clean}}, \mathbf{X}[:, 1], \mathbf{X}[:, 2], \mathbf{X}[:, \{3, 4, 5\}])\big), C_{\text{target}}\Big). \tag{22}$$

To launch the attacks successfully with minimal modification strength, the following constraint must be satisfied:

$$\big|\mathbf{X}[:, \{3, 4, 5\}] - X_{\text{clean}}\big[\mathbf{X}[:, 1]\big]\big[\mathbf{X}[:, 2]\big]\big| \le \delta. \tag{23}$$

Since the first two columns of $\mathbf{X}$ in (20) hold the row and column indices $r$ and $c$, $X_{\text{clean}}[\cdot][\cdot]$ denotes the RGB value of the pixel at the corresponding row and column. Our attack algorithm is shown in Algorithm 1.
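A sketch of how the cost functions (21) and (22) and constraint (23) could be evaluated for a decision matrix $\mathbf{X}$. The classifier `toy_fcnn` is a hypothetical linear-softmax stand-in for $f_{\text{cnn}}$ (the paper attacks LeNet-5 and ResNet), the columns are 0-based, and enforcing (23) by clipping each RGB entry into $[\text{orig} - \delta, \text{orig} + \delta] \cap [0, 255]$ is our choice of projection; all helper names are illustrative:

```python
import numpy as np

def toy_fcnn(img: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for f_cnn: softmax over 3 classes of fixed linear scores."""
    w = np.linspace(-1.0, 1.0, 3 * img.size).reshape(3, img.size)
    z = w @ (img.astype(float).ravel() / 255.0)
    e = np.exp(z - z.max())
    return e / e.sum()

def _rows_cols(x_clean, X):
    m, n, _ = x_clean.shape
    r = np.clip(np.rint(X[:, 0]), 0, m - 1).astype(int)
    c = np.clip(np.rint(X[:, 1]), 0, n - 1).astype(int)
    return r, c

def apply_matrix(x_clean: np.ndarray, X: np.ndarray) -> np.ndarray:
    """P(X_clean, X[:,1], X[:,2], X[:,{3,4,5}]) with 0-based columns 0, 1 and 2:5."""
    r, c = _rows_cols(x_clean, X)
    x_per = x_clean.copy()
    x_per[r, c] = np.clip(np.rint(X[:, 2:5]), 0, 255).astype(x_clean.dtype)
    return x_per

def enforce_delta(x_clean: np.ndarray, X: np.ndarray, delta: float) -> np.ndarray:
    """Constraint (23) plus the [0, 255] bounds, enforced by clipping the RGB columns."""
    r, c = _rows_cols(x_clean, X)
    orig = x_clean[r, c].astype(float)               # X_clean[X[:,1]][X[:,2]]
    out = X.astype(float)                            # astype returns a copy
    out[:, 2:5] = np.clip(out[:, 2:5], np.maximum(orig - delta, 0.0),
                          np.minimum(orig + delta, 255.0))
    return out

def cost_undirected(x_clean, X, fcnn, true_idx):
    """Eq. (21): probability of the true class C_clean (to be minimized)."""
    return float(fcnn(apply_matrix(x_clean, X))[true_idx])

def cost_directed(x_clean, X, fcnn, target_idx):
    """Eq. (22): 1 - probability of the target class C_target (to be minimized)."""
    return 1.0 - float(fcnn(apply_matrix(x_clean, X))[target_idx])

# Usage on a random 4 x 4 "image" where every pixel may change (k = m * n).
rng = np.random.default_rng(1)
x_clean = rng.integers(0, 256, size=(4, 4, 3)).astype(np.uint8)
k = 16
rows, cols = np.divmod(np.arange(k), 4)
X = np.column_stack([rows, cols, x_clean.reshape(k, 3) + rng.normal(0, 30, (k, 3))])
X = enforce_delta(x_clean, X, delta=10)
true_idx = int(np.argmax(toy_fcnn(x_clean)))
print(cost_undirected(x_clean, X, toy_fcnn, true_idx),
      cost_directed(x_clean, X, toy_fcnn, (true_idx + 1) % 3))
```

Flattening $\mathbf{X}$ into a vector would let either cost be passed directly to a BAS minimizer as $f$.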


4 Experimental Studies

In this section, we evaluate the proposed attack algorithms experimentally.

4.1 Evaluation Methodology

We evaluate the efficacy of our attack algorithm on deep convolutional neural network architectures applied to the CIFAR-10 image classification dataset. Two distinct architectures, LeNet-5 [19] and ResNet [16], are tested. CIFAR-10 comprises 32 × 32 RGB images from 10 classes. We implemented LeNet-5 and ResNet in TensorFlow and trained them with the cross-entropy loss. Table 1 shows the accuracy of both models. The attack algorithm uses the following parameter settings. To allow every pixel of the image to be modified, we set $k = 1024$. By trial and error, we set the antenna length $d = 0.5$, the beetle’s step size $\xi = 0.5$, and the RGB modification range $\delta = 10$. The maximum number of iterations is $maxiter = 400$.

4.2 Results

To obtain meaningful results, we attacked only the test images that the trained model classified correctly. We report the undirected and directed attacks on LeNet-5 and ResNet separately. According to Table 1, this leaves 7581 test images for LeNet-5 and 9520 for ResNet.
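The settings and image counts above follow from simple arithmetic: $k = 1024$ is the number of pixels in a $32 \times 32$ image, and attacking only correctly classified test images leaves (testing accuracy) × 10000 candidates per model. A quick check, using only values stated in the text and Table 1:

```python
# Quantities from Sect. 4.1 and Table 1.
height, width = 32, 32                       # CIFAR-10 image size
k = height * width                           # every pixel may be modified
test_images = 10_000
test_accuracy = {"LeNet-5": 0.7581, "ResNet": 0.9520}
# Only correctly classified test images are attacked.
attackable = {name: round(acc * test_images) for name, acc in test_accuracy.items()}
print(k, attackable)   # 1024 {'LeNet-5': 7581, 'ResNet': 9520}
```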


Fig. 2. Samples of successful single-pixel untargeted attacks. The original images are displayed in the top row, and the perturbed images in the bottom row. (a) Samples for LeNet-5. (b) Samples for ResNet.

Figure 2 (undirected attack) and Fig. 3 (directed attack) compare the adversarial images with the original images. Figures 2(a) and 2(b) show the results for LeNet-5 and ResNet, respectively, including the pre-attack class, the CNN’s confidence, the post-attack (false) class, the new confidence, and the number of iterations required for a successful attack. Figures 3(a) and 3(b) show the directed attack images for LeNet-5 and ResNet, respectively; in addition to the information in Fig. 2, Fig. 3 includes the target class. The target class of each image is chosen at random, subject to differing from the image’s true class. For statistical results, we randomly selected 1000 images from the CIFAR-10 test set, attacked ResNet with an undirected attack, and applied the same attack using the method proposed in [11]. As Fig. 1 shows, images perturbed by the method in [11] contain pixels that visibly do not belong to the image; we use the presence of such pixels as the criterion for whether an attack is detectable by human eyes. Table 2 compares the two attack methods in terms of the number of attacked samples, the percentage of successful attacks, the average number of iterations, and the percentage of successful attacks detectable by human eyes. The proposed method outperforms the method in [11] on every measure: a higher success rate (47.60% vs. 41.70%), fewer iterations on average (207.79 vs. 254.19), and no detectable attacks (0.00% vs. 96.05%). Images successfully attacked by the proposed method show no obvious evidence of having been attacked.


Fig. 3. Samples of successful single-pixel targeted attacks. The original images are displayed in the top row, and the perturbed images in the bottom row. (a) Samples for LeNet-5. (b) Samples for ResNet.

Table 1. Accuracy of the CNN architectures on the CIFAR-10 dataset after training.

Architecture | Training samples | Training accuracy | Testing samples | Testing accuracy
LeNet-5      | 50000            | 74.88%            | 10000           | 75.81%
ResNet       | 10000            | 99.99%            | 10000           | 95.20%

Table 2. Comparison results.

Method               | Attack samples | Successful attacks | Average iterations | Detection percentage
Method in this paper | 1000           | 47.60%             | 207.79             | 0.00%
Method in [11]       | 1000           | 41.70%             | 254.19             | 96.05%


5 Conclusions

In this work, we have proposed an attack algorithm for CNNs that requires no knowledge of the network’s parameters. The experimental results demonstrate that the algorithm makes the classifier misclassify images while the adversarial and original images remain perceptually almost identical, rendering the attack undetectable to human eyes. Furthermore, we believe that our attack method is more broadly applicable and can be used to attack other DNNs, providing a useful tool for measuring their robustness.

Acknowledgement. This work is supported in part by the National Natural Science Foundation of China under Grant 62206109, the Guangdong Basic and Applied Basic Research Foundation under Grant 2022A1515010976, and the Science and Technology Program of Guangzhou under Grant 202201010457.

References

1. Goodfellow, I., Bengio, Y., Courville, A.: Deep Learning. MIT Press, Cambridge (2016)
2. Szegedy, C., et al.: Intriguing properties of neural networks. arXiv preprint arXiv:1312.6199 (2013)
3. Papernot, N., McDaniel, P., Goodfellow, I., Jha, S., Celik, Z.B., Swami, A.: Practical black-box attacks against machine learning. In: Proceedings of the 2017 ACM on Asia Conference on Computer and Communications Security, pp. 506–519 (2017)
4. Tang, K., Shen, J., Chen, Q.A.: Fooling perception via location: a case of region-of-interest attacks on traffic light detection in autonomous driving. In: NDSS Workshop on Automotive and Autonomous Vehicle Security (AutoSec) (2021)
5. Zhang, C., Benz, P., Karjauv, A., Cho, J.W., Zhang, K., Kweon, I.S.: Investigating top-k white-box and transferable black-box attack. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 15085–15094 (2022)
6. Goodfellow, I.J., Shlens, J., Szegedy, C.: Explaining and harnessing adversarial examples. arXiv preprint arXiv:1412.6572 (2014)
7. Papernot, N., McDaniel, P., Jha, S., Fredrikson, M., Celik, Z.B., Swami, A.: The limitations of deep learning in adversarial settings. In: 2016 IEEE European Symposium on Security and Privacy (EuroS&P), pp. 372–387. IEEE (2016)
8. Moosavi-Dezfooli, S.M., Fawzi, A., Frossard, P.: DeepFool: a simple and accurate method to fool deep neural networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2574–2582 (2016)
9. Carlini, N., Wagner, D.: Towards evaluating the robustness of neural networks. In: 2017 IEEE Symposium on Security and Privacy, pp. 39–57. IEEE (2017)
10. Su, J., Vargas, D.V., Sakurai, K.: One pixel attack for fooling deep neural networks. IEEE Trans. Evol. Comput. 23(5), 828–841 (2019)
11. Khan, A.H., Cao, X., Xu, B., Li, S.: Beetle antennae search: using biomimetic foraging behaviour of beetles to fool a well-trained neuro-intelligent system. Biomimetics 7(3), 84 (2022)
12. Wang, J., Chen, H.: BSAS: beetle swarm antennae search algorithm for optimization problems. arXiv preprint arXiv:1807.10470 (2018)
13. Krizhevsky, A., Sutskever, I., Hinton, G.E.: ImageNet classification with deep convolutional neural networks. Commun. ACM 60(6), 84–90 (2017)


14. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556 (2014)
15. Szegedy, C., et al.: Going deeper with convolutions. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1–9 (2015)
16. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
17. Huang, G., Liu, Z., Van Der Maaten, L., Weinberger, K.Q.: Densely connected convolutional networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4700–4708 (2017)
18. Wang, Y., Tan, Y.A., Zhang, W., Zhao, Y., Kuang, X.: An adversarial attack on DNN-based black-box object detectors. J. Netw. Comput. Appl. 161, 102634 (2020)
19. LeCun, Y., Bottou, L., Bengio, Y., Haffner, P.: Gradient-based learning applied to document recognition. Proc. IEEE 86(11), 2278–2324 (1998)
20. Cao, H., Si, C., Sun, Q., Liu, Y., Li, S., Gope, P.: ABCAttack: a gradient-free optimization black-box attack for fooling deep image classifiers. Entropy 24(3), 412 (2022)
21. Giulivi, L., et al.: Adversarial scratches: deployable attacks to CNN classifiers. Pattern Recogn. 133, 108985 (2023)
22. Ali, Y.M.B.: Adversarial attacks on deep learning networks in image classification based on smell bees optimization algorithm. Futur. Gener. Comput. Syst. 140, 185–195 (2023)
23. Cai, Z., et al.: Context-aware transfer attacks for object detection. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 36, pp. 149–157 (2022)
24. Puttagunta, M.K., Ravi, S., Nelson Kennedy Babu, C.: Adversarial examples: attacks and defences on medical deep learning systems. Multimedia Tools Appl. 1–37 (2023)
25. Ye, J., Wang, Y., Zhang, X., Xu, L., Ni, R.: Adversarial attack algorithm for object detection based on improved differential evolution. In: 6th International Workshop on Advanced Algorithms and Control Engineering (IWAACE 2022), vol. 12350, pp. 669–678. SPIE (2022)

A Traffic Flow Prediction Framework Based on Clustering and Heterogeneous Graph Neural Networks

Lei Luo¹, Shiyuan Han¹ (corresponding author), Zhongtao Li¹, Jun Yang², and Xixin Yang³

¹ Shandong Provincial Key Laboratory of Network Based Intelligent Computing, University of Jinan, Jinan 250022, China
² School of Automotive Engineering, Shandong Jiaotong University, Jinan 250023, China
³ College of Computer Science and Technology, Qingdao University, Qingdao 266071, China

Abstract. Traffic flow forecasting is crucial for traffic management, but the complex spatio-temporal correlation and heterogeneity among traffic nodes make the problem challenging. Although many deep spatio-temporal models have been proposed and applied to traffic flow prediction, they mostly focus on capturing the spatio-temporal correlation among traffic nodes and ignore the influence of the functional characteristics of the areas to which the nodes belong. A method is therefore needed to help models capture this influence. This paper presents a novel framework that enhances existing deep spatio-temporal models by combining clustering with heterogeneous graph neural networks. The framework’s clustering module measures the similarity of traffic patterns between nodes using Dynamic Time Warping and the Wasserstein distance, and then applies spectral clustering to divide the nodes into clusters according to their traffic patterns. The graph transformer module adaptively constructs a new graph for nodes in the same cluster, and the spatio-temporal feature learning module captures the spatio-temporal correlation among nodes based on this new graph. Extensive experiments on two real-world datasets demonstrate that the proposed framework effectively improves the performance of several representative deep spatio-temporal models.

Keywords: Intelligent transportation · Traffic flow prediction · Heterogeneous graph neural network · Traffic flow · Clustering

1 Introduction

Today, the widespread deployment of traffic sensors allows researchers to obtain large amounts of traffic data and to solve traffic flow prediction problems with data-driven methods [1]. Since traditional convolutional neural networks are not suitable for processing graph-structured data, many traffic flow prediction models based on graph neural networks (GNNs) have been proposed [2]. However, most current models focus on modeling the spatio-temporal correlation between traffic nodes while ignoring the influence of the heterogeneity of traffic data on prediction results.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 58–69, 2023.
https://doi.org/10.1007/978-981-99-4742-3_5

A Traffic Flow Prediction Framework


The heterogeneity of traffic flow data manifests in different ways, including the varying functional characteristics of the areas where traffic nodes are located. As a result, the traffic patterns observed by these nodes also differ, which is reflected specifically in the changing trends and numerical distributions of traffic data [3]. Figure 1 illustrates this concept using three areas in Lixia District of Jinan, A, B, and C, which are residential, park, and commercial areas, respectively. Figure 1(b) shows the different trends of traffic flow in these areas. For instance, on weekdays, because people usually go to work during the day and go home at night, the traffic flow into area C gradually increases in the morning and peaks soon after, while the traffic flow into area A peaks in the evening. In addition, traffic in area B is much lower than in areas A and C because fewer people go to the park on weekdays. Figure 1(c) shows the distribution of traffic flow values in these three typical areas. From another perspective, this also indicates that nodes far apart in space may still exhibit similar traffic patterns because of similar functional characteristics [4]. Therefore, models that ignore this heterogeneity may fail to capture the correlation between nodes with similar traffic patterns, which limits their predictive ability. To achieve more accurate forecasts, a general framework is needed to help existing traffic flow forecasting models address heterogeneity modeling. However, this task presents two challenges: (1) how to design a forecasting method that considers the correlation between nodes with similar traffic patterns while capturing the complex spatio-temporal correlation in traffic flow data, and (2) how to make the proposed method transplantable to various existing prediction models that handle spatio-temporal data.
The main contribution of this paper is the introduction of a traffic flow prediction framework that uses clustering and heterogeneous graph neural networks to capture the shared traffic evolution rules among nodes with similar traffic patterns. We further demonstrate through experiments on real datasets that our framework can be easily applied to a range of existing traffic flow prediction models based on GNNs and improve their performance to some extent.

Fig. 1. An example of differences in traffic patterns

2 Related Work

Some traffic nodes may exhibit similar patterns because of similar functional characteristics, even if they are far apart. Clustering can therefore be used to group nodes with similar traffic patterns and further capture the correlation between them. Many


L. Luo et al.

researchers have applied clustering to tasks such as analyzing traffic status. However, traditional clustering methods such as K-means work well only when the samples form convex, spherical clusters, and they tend to fall into local optima when the sample space is non-convex [5]. In recent years, spectral clustering has received widespread attention from researchers [6]. Spectral clustering builds on matrix spectral theory: it derives new features from the eigenvectors of a matrix constructed from pairwise similarities and then clusters these features. Compared with other clustering algorithms, spectral clustering is simple, less prone to falling into local optima, and able to handle non-convex distributions. In many real-life scenarios, data comprise multiple types that cannot be represented simply by homogeneous graphs with a single node type and edge type; instead, they require heterogeneous graphs containing multiple types of nodes and edges. Unfortunately, most GNNs are designed for homogeneous graphs. A common way to handle heterogeneous graphs is therefore to design meta-paths manually. A meta-path is a path connecting different types of nodes. The heterogeneous graph is converted into a homogeneous graph defined by the meta-path, which is then processed by a general GNN. The Heterogeneous Graph Attention Network (HAN), proposed by Wang et al., converts heterogeneous graphs into homogeneous graphs and uses a hierarchical attention mechanism for aggregation to obtain node feature representations [7]. However, meta-paths are generally selected manually by domain experts, and the chosen meta-paths can significantly affect the final model’s performance. The Graph Transformer Networks (GTN) proposed by Yun et al. can adaptively generate meta-paths for the current data and task using matrix multiplication and learnable parameters [8]. Moreover, GTN can be combined with different types of GNNs.
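The spectral clustering described above can be sketched in a few lines of NumPy. This is a minimal illustration of the normalized (Ng–Jordan–Weiss-style) variant, assuming a symmetric, non-negative node-similarity matrix `S`; the Gaussian similarity, the toy one-dimensional node positions, and the helper names are our own assumptions, not the framework’s actual settings:

```python
import numpy as np

def kmeans(points: np.ndarray, k: int, iters: int = 50) -> np.ndarray:
    """Plain k-means with deterministic farthest-point initialization."""
    centers = [points[0]]
    for _ in range(1, k):
        dist = np.min([np.sum((points - c) ** 2, axis=1) for c in centers], axis=0)
        centers.append(points[np.argmax(dist)])
    centers = np.array(centers)
    for _ in range(iters):
        labels = np.argmin(((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2), axis=1)
        centers = np.array([points[labels == j].mean(axis=0) if np.any(labels == j)
                            else centers[j] for j in range(k)])
    return labels

def spectral_clustering(S: np.ndarray, n_clusters: int) -> np.ndarray:
    """Normalized spectral clustering of nodes from a symmetric similarity matrix S."""
    d_inv_sqrt = 1.0 / np.sqrt(S.sum(axis=1))
    L_sym = np.eye(len(S)) - d_inv_sqrt[:, None] * S * d_inv_sqrt[None, :]
    _, vecs = np.linalg.eigh(L_sym)                  # eigenvalues in ascending order
    U = vecs[:, :n_clusters]                          # spectral embedding (new features)
    U = U / np.linalg.norm(U, axis=1, keepdims=True)  # row-normalize each node's features
    return kmeans(U, n_clusters)

# Toy example: six nodes whose pattern distances form two well-separated groups.
pos = np.array([0.0, 0.1, 0.2, 5.0, 5.1, 5.2])
S = np.exp(-(pos[:, None] - pos[None, :]) ** 2 / 2.0)   # Gaussian similarity
labels = spectral_clustering(S, n_clusters=2)
print(labels)
```

K-means is applied to the eigenvector embedding rather than to the raw data, which is what lets the method separate non-convex groups.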

3 Methodology

3.1 Problem Definition

Definition 1 (Road network G). A weighted directed graph $G = (V, E)$ describes the topology of the traffic road network. Each traffic sensor in the road network is regarded as a node, and the road segments connecting adjacent sensors are regarded as edges. $V = \{v_1, v_2, \ldots, v_N\}$ is the node set, where $N$ is the number of nodes. $E$ is the edge set, and the adjacency matrix $A \in \mathbb{R}^{N \times N}$ represents the adjacency relationships between nodes. In this paper, we use the reciprocal of the road length as the edge weight. If two nodes are not adjacent, the corresponding element of the adjacency matrix is 0.

Definition 2 (Feature matrix X). We regard the data collected by the traffic sensors as node attributes, denoted $X = (X^{(1)}, X^{(2)}, \ldots, X^{(\tau)}) \in \mathbb{R}^{N \times \tau}$, where $\tau$ is the length of the historical time series. Here, $X^{(t)} \in \mathbb{R}^{N \times 1}$ is the data collected by all sensors at time $t$, and $X_n \in \mathbb{R}^{1 \times \tau}$ is the historical data of length $\tau$ collected by the $n$-th sensor. The attributes can include flow, speed, and other traffic-related information.

Based on these definitions, the traffic flow forecasting problem can be regarded as learning a mapping function $f$, given the road network topology $G$ and the feature matrix $X$. For a reference time $t$, this function predicts the traffic data of the $T$ moments after $t$ from the observations of the previous $n$ moments:

$$[X^{(t+1)}, \ldots, X^{(t+T-1)}, X^{(t+T)}] = f\big(G; [X^{(t-n+1)}, \ldots, X^{(t-1)}, X^{(t)}]\big). \tag{1}$$
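Definition 1 and Eq. (1) translate directly into array code. The sketch below, with hypothetical helper names, builds the weighted adjacency matrix from a list of directed road segments and cuts the feature matrix into (history, horizon) training pairs; it writes the history length as `n_hist` to avoid clashing with the sensor index $n$, and uses 0-based time indices:

```python
import numpy as np

def build_adjacency(num_nodes, edges):
    """Definition 1: A[i, j] = 1 / road length for a segment i -> j, else 0."""
    A = np.zeros((num_nodes, num_nodes))
    for i, j, length in edges:
        A[i, j] = 1.0 / length
    return A

def make_windows(X, n_hist, T):
    """Eq. (1): pair each input window [X(t-n+1), ..., X(t)] with the target
    window [X(t+1), ..., X(t+T)]; X has shape (N, tau)."""
    _, tau = X.shape
    inputs, targets = [], []
    for t in range(n_hist - 1, tau - T):   # both windows must fit inside [0, tau)
        inputs.append(X[:, t - n_hist + 1 : t + 1])
        targets.append(X[:, t + 1 : t + T + 1])
    return np.stack(inputs), np.stack(targets)

# Usage: three sensors on a chain, six time steps of a toy feature matrix.
A = build_adjacency(3, [(0, 1, 2.0), (1, 2, 4.0)])   # lengths in arbitrary units
X = np.arange(18, dtype=float).reshape(3, 6)
inp, tgt = make_windows(X, n_hist=3, T=2)
print(A[0, 1], A[1, 2], inp.shape, tgt.shape)   # 0.5 0.25 (2, 3, 3) (2, 3, 2)
```

A series of length $\tau$ yields $\tau - T - n + 1$ training pairs, here $6 - 2 - 3 + 1 = 2$.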

3.2 Framework Overview

The proposed traffic flow prediction framework based on clustering and heterogeneous graph neural networks consists of three modules (a clustering module, a graph transformer module, and a spatio-temporal feature learning module), as shown in Fig. 2. A brief overview of the framework follows:

1. Clustering module. This module takes the feature matrix $X$ as input, measures the similarity of traffic patterns between traffic nodes, and uses it to divide the nodes into clusters. Its output is the cluster assignment of each node.
2. Graph transformer module. Based on the cluster assignments, this module assigns a type attribute to each node of the road network $G$, turning it into a heterogeneous graph. It then employs GTN to adaptively obtain a set of meta-paths and their corresponding homogeneous graphs, which are passed to the spatio-temporal feature learning module.
3. Spatio-temporal feature learning module. This module takes the original road network $G$, the feature matrix $X$, and the meta-path graphs from the previous module as input. It captures the correlation of each node in both the temporal and spatial dimensions and further strengthens it through the latent connections between similar nodes. The module outputs the predicted traffic data.
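The graph transformer module relies on GTN [8]. As we understand the GTN design, one layer softly selects among candidate adjacency matrices (including the identity) with softmax-normalized weights and multiplies two such selections to obtain a length-2 meta-path graph. The NumPy sketch below illustrates that idea only. Typing edges by the clusters of their endpoints is a hypothetical choice for illustration (this excerpt does not specify how edge types are defined), the logits are fixed here rather than learned, and the row normalization is an assumed choice:

```python
import numpy as np

def edge_type_adjacencies(A, labels, n_types):
    """Split A by (source cluster, target cluster) and prepend the identity,
    giving a stack of shape (1 + n_types**2, N, N)."""
    mats = [np.eye(len(A))]
    for a in range(n_types):
        for b in range(n_types):
            mats.append(A * np.outer(labels == a, labels == b))
    return np.stack(mats)

def softmax(w):
    e = np.exp(w - w.max())
    return e / e.sum()

def gt_layer(adjs, w1, w2):
    """Soft-select two edge types, multiply them into a length-2 meta-path graph,
    then row-normalize by out-degree (normalization choice assumed)."""
    Q1 = np.tensordot(softmax(w1), adjs, axes=1)   # sum_k alpha_k * A_k
    Q2 = np.tensordot(softmax(w2), adjs, axes=1)
    M = Q1 @ Q2
    deg = M.sum(axis=1, keepdims=True)
    return np.divide(M, deg, out=np.zeros_like(M), where=deg > 0)

# Chain 0 -> 1 -> 2 -> 3; nodes 0, 1 in cluster 0 and nodes 2, 3 in cluster 1.
A = np.diag(np.ones(3), k=1)
labels = np.array([0, 0, 1, 1])
adjs = edge_type_adjacencies(A, labels, n_types=2)
# Fixed logits (learned in GTN): first pick type (0, 0), then type (0, 1).
meta = gt_layer(adjs, w1=np.array([0, 10, 0, 0, 0.0]), w2=np.array([0, 0, 10, 0, 0.0]))
print(np.argmax(meta[0]))   # 2: the meta-path 0 -> 1 -> 2 links node 0 to node 2
```

Including the identity among the candidates lets a length-2 product also express a single original edge type when the identity is selected.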

Fig. 2. The overall structure of the proposed framework

3.3 Clustering Module

The differences in traffic patterns among traffic nodes are primarily manifested in the trends and numerical distributions of their traffic data. To measure the difference in the data change trend of each node, this paper employs Dynamic Time Warping (DTW), which determines the similarity between time series by comparing their change trends [9]. In addition, the Wasserstein distance measures the difference between two probability distributions while taking the geometry of the distributions into account, so we


L. Luo et al.

choose the Wasserstein distance to measure the difference in the numerical distribution of each node's data [10]. We denote the DTW distance between nodes $n$ and $m$ by $\mathrm{dist}_1(X_n, X_m)$ and the Wasserstein distance by $\mathrm{dist}_2(X_n, X_m)$. The standard DTW uses dynamic programming to compute the similarity between two time series of length $n$, with a time complexity of $O(n^2)$. For long time series, the execution time of the standard DTW therefore becomes prohibitively high. To address this issue, we adopt fast-DTW, which restricts the search length to $H$. This reduces the time complexity to $O(H \times n)$, allowing efficient processing of large-scale spatio-temporal data.
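The sketch below illustrates how restricting the search to H lowers the cost. It implements DTW with a Sakoe-Chiba-style band of half-width H, which is one common way to limit the warping search; whether it matches the exact fast-DTW variant used by the authors is an assumption, and the function name `dtw_window` is hypothetical.

```python
import numpy as np

def dtw_window(a, b, H):
    """DTW distance between 1-D series a and b, restricted to the band |i - j| <= H.

    Only about (2H + 1) cells per row are evaluated, so the time cost is
    O(H * n) rather than O(n^2); the table is still allocated in full for brevity.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    n, m = len(a), len(b)
    H = max(H, abs(n - m))  # the band must be wide enough to reach cell (n, m)
    D = np.full((n + 1, m + 1), np.inf)
    D[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(max(1, i - H), min(m, i + H) + 1):
            cost = abs(a[i - 1] - b[j - 1])
            # extend the cheapest of the three admissible predecessor paths
            D[i, j] = cost + min(D[i - 1, j], D[i, j - 1], D[i - 1, j - 1])
    return float(D[n, m])

print(dtw_window([0, 0, 1], [0, 1, 1], H=1))  # 0.0: the shift is absorbed by warping
print(dtw_window([0, 0, 0], [1, 1, 1], H=0))  # 3.0: H = 0 forces the diagonal path
```

Setting H to the series length recovers the standard, unconstrained DTW, so H directly trades accuracy of the alignment for speed.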

Before using the Wasserstein distance to measure the similarity of the traffic nodes' numerical distributions, the feature matrix X is preprocessed. First, we find the maximum and minimum of all elements in X and compute their difference. Then, we choose the number of groups $Q$ and divide this value range into $Q$ groups to obtain the dividing points of each group. Next, for the $n$-th sensor, we count the frequency with which its collected values fall into each group, denoted as $X_n^{(f)} = (f_n^{(1)}, f_n^{(2)}, \ldots, f_n^{(q)}, \ldots, f_n^{(Q)})$, $f_n^{(q)} \in [0, 1]$, where $f_n^{(q)}$ is the frequency with which the data collected by the $n$-th sensor fall into the $q$-th group, and $\sum_q f_n^{(q)} = 1$. In this way, the traffic data collected by the $n$-th sensor are transformed into a probability distribution $P_n$. In addition, we need to set the transport cost between probability masses. For example, the cost between the probability mass $f_n^{(i)}$ of group $i$ of the $n$-th node and the probability mass $f_m^{(j)}$ of group $j$ of the $m$-th node can be set as the absolute value of their difference,


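then the Wasserstein distance between traffic nodes $n$ and $m$ is:

$$\mathrm{dist}_2(X_n, X_m) = \inf_{\gamma \in [P_n, P_m]} \int_x \int_y \gamma(x, y)\, \big|f_n^{(x)} - f_m^{(y)}\big|\, dx\, dy, \quad \text{s.t.} \int_y \gamma(x, y)\, dy = f_n^{(x)}, \ \int_x \gamma(x, y)\, dx = f_m^{(y)} \quad (2)$$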

In the formula, $[P_n, P_m]$ is the set of joint probability distributions of $P_n$ and $P_m$, and $\gamma$ is a joint distribution in this set whose marginal distributions are exactly $P_n$ and $P_m$. To avoid the influence of scale inconsistency on subsequent calculations, all $\mathrm{dist}_1$ values are normalized after calculation, and so are all $\mathrm{dist}_2$ values. $\mathrm{dist}_1$ and $\mathrm{dist}_2$ are then combined into a similarity matrix $S \in \mathbb{R}^{N \times N}$ that measures how alike the nodes' traffic patterns are. Its elements are given by:

$$s_{n,m} = \mathrm{dist}_1(X_n, X_m) + \mathrm{dist}_2(X_n, X_m) \quad (3)$$

Because both terms are distances, a smaller $s_{n,m}$ indicates more similar traffic patterns.
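The preprocessing and Eqs. (2) and (3) can be sketched in NumPy as follows. Several choices here are our assumptions, not the authors': the Q groups are equal-width bins over the global value range, and the transport cost between groups i and j is the bin distance |i − j| times the bin width, substituted for the mass-difference cost stated for Eq. (2). With this common bin-distance cost the one-dimensional Wasserstein distance has a closed form (the L1 distance between cumulative distributions), so no optimization is needed. Min-max scaling stands in for the unspecified normalization, and all function names are hypothetical.

```python
import numpy as np

def frequency_vectors(X, Q):
    """Row n of the result is (f_n^(1), ..., f_n^(Q)) for sensor n of X (N x tau).

    The global [min, max] of X is split into Q equal-width groups shared by all sensors.
    """
    edges = np.linspace(X.min(), X.max(), Q + 1)
    F = np.empty((X.shape[0], Q))
    for n, row in enumerate(X):
        counts, _ = np.histogram(row, bins=edges)  # last bin includes the maximum
        F[n] = counts / counts.sum()
    return F, edges[1] - edges[0]

def wasserstein_hist(p, q, bin_width):
    """1-D Wasserstein distance between histograms on the same bins (cost |i - j| * width)."""
    return float(np.abs(np.cumsum(p) - np.cumsum(q)).sum() * bin_width)

def minmax(M):
    """Rescale all entries of a matrix to [0, 1]."""
    lo, hi = M.min(), M.max()
    return (M - lo) / (hi - lo) if hi > lo else np.zeros_like(M)

def combined_matrix(D1, D2):
    """Eq. (3): s_{n,m} = dist1 + dist2 after normalizing each distance matrix."""
    return minmax(D1) + minmax(D2)

X = np.array([[0.0, 1.0, 2.0, 3.0],
              [3.0, 3.0, 3.0, 0.0]])
F, width = frequency_vectors(X, Q=3)
D2 = np.array([[wasserstein_hist(F[i], F[j], width) for j in range(2)] for i in range(2)])
```

In this toy example the two sensors share a value range but differ in where their mass sits, and D2 records that difference; in the full method D1 would hold the pairwise DTW distances.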

After obtaining the similarity matrix, the clustering module applies spectral clustering, using the commonly used graph-cut criterion Ncut and the widely adopted clustering method K-means.
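A compact NumPy sketch of this step follows, in the style of Ng, Jordan and Weiss [6]: the normalized affinity matrix (the spectral relaxation of Ncut) is eigendecomposed, and K-means is run on the row-normalized leading eigenvectors. Because Eq. (3) yields distances, we assume a Gaussian kernel to turn S into affinities; the kernel width `sigma`, the farthest-point initialization of K-means, and the name `spectral_clusters` are our assumptions.

```python
import numpy as np

def spectral_clusters(S, k, sigma=1.0, iters=100):
    """Cluster N nodes from a distance-like matrix S (N x N) into k groups."""
    W = np.exp(-(S ** 2) / (2.0 * sigma ** 2))  # distances -> affinities (assumed kernel)
    np.fill_diagonal(W, 0.0)
    d_inv_sqrt = 1.0 / np.sqrt(W.sum(axis=1))
    M = d_inv_sqrt[:, None] * W * d_inv_sqrt[None, :]  # D^{-1/2} W D^{-1/2}
    _, vecs = np.linalg.eigh(M)  # eigenvalues in ascending order
    U = vecs[:, -k:]  # k leading eigenvectors
    U = U / np.linalg.norm(U, axis=1, keepdims=True)  # row-normalize

    # K-means on the rows of U, initialized with farthest-point seeding
    centers = [U[0]]
    for _ in range(1, k):
        dist = np.min([((U - c) ** 2).sum(axis=1) for c in centers], axis=0)
        centers.append(U[np.argmax(dist)])
    centers = np.array(centers)
    for _ in range(iters):
        labels = np.argmin(((U[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2), axis=1)
        new_centers = np.array([
            U[labels == c].mean(axis=0) if np.any(labels == c) else centers[c]
            for c in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels
```

The returned labels play the role of the cluster assignment that the graph transformer module turns into node types.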

3.4 Graph Transformer Module

After the clustering module divides the nodes into clusters, the traffic nodes obtain type attributes, and the original road network G is transformed from a homogeneous graph into a heterogeneous graph. The graph transformer module can therefore process it with the GTN [8], a heterogeneous graph neural network. The purpose is to obtain homogeneous graphs whose nodes share the same type, allowing the next module to capture the correlations between nodes of the same type.


Fig. 3. Graph Transformer Module

Let $T^{(v)}$ and $T^{(e)}$ denote the sets of node types and edge types, respectively. In this paper, each node type is the type represented by a cluster obtained by the clustering module, that is, $T^{(v)} = \{t_1^{(v)}, t_2^{(v)}, \ldots, t_{K^{(v)}}^{(v)}\}$. The edge type is defined by the node types at both ends of the edge. Since this paper deals with directed graphs, $t_i^{(v)} \to t_j^{(v)}$ and $t_j^{(v)} \to t_i^{(v)}$ represent two different edge types. If the number of edge types is $K^{(e)}$, then $T^{(e)} = \{t_1^{(e)}, t_2^{(e)}, \ldots, t_{K^{(e)}}^{(e)}\}$. The road network G transformed into a heterogeneous graph can be represented by a set of adjacency matrices $\{A_{t_k^{(e)}}\}_{k=1}^{K^{(e)}}$, where $A_{t_k^{(e)}} \in \mathbb{R}^{N \times N}$ is an adjacency matrix that contains only edges of type $t_k^{(e)}$. Specifically, when there is an edge of type $t_k^{(e)}$ from the $i$-th node to the $j$-th node, the element in row $i$ and column $j$ of $A_{t_k^{(e)}}$ is non-zero. In a heterogeneous graph, given an edge type sequence $(t_1^{(e)}, t_2^{(e)}, \ldots, t_l^{(e)})$, the adjacency matrix $A_P$ of the meta-path $P$ can be obtained by multiplying the adjacency matrices of the corresponding edge types:

$$A_P = A_{t_1^{(e)}} A_{t_2^{(e)}} \cdots A_{t_l^{(e)}} \quad (4)$$
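The typed adjacency matrices and the meta-path product of Eq. (4) can be illustrated with the NumPy sketch below. We assume that every ordered pair of clusters (source type, target type) defines its own edge type, giving K^(v) × K^(v) candidate types with id i·K^(v) + j; this encoding and the function names are hypothetical.

```python
import numpy as np

def typed_adjacency(A, labels, K):
    """Split A (N x N) into K*K matrices, one per directed edge type t_i -> t_j.

    labels[n] is the cluster (node type) of node n, in range(K). Edge (u, v) gets
    type id labels[u] * K + labels[v], and matrix t keeps only edges of type t.
    """
    labels = np.asarray(labels)
    types = labels[:, None] * K + labels[None, :]  # N x N matrix of edge-type ids
    return np.stack([np.where(types == t, A, 0.0) for t in range(K * K)])

def meta_path_adjacency(mats, type_seq):
    """Eq. (4): A_P = A_{t_1} A_{t_2} ... A_{t_l} for the edge-type sequence type_seq."""
    out = mats[type_seq[0]]
    for t in type_seq[1:]:
        out = out @ mats[t]
    return out

# Chain 0 -> 1 -> 2; nodes 0 and 1 are type 0, node 2 is type 1 (K = 2)
A = np.array([[0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0],
              [0.0, 0.0, 0.0]])
mats = typed_adjacency(A, [0, 0, 1], K=2)
A_P = meta_path_adjacency(mats, [0, 1])  # type (t0 -> t0) followed by (t0 -> t1)
```

Here the only two-hop path of the chosen type sequence is 0 → 1 → 2, so A_P has a single non-zero entry at row 0, column 2.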

Compared with many heterogeneous graph neural networks that require manually designed meta-paths, the GTN can automatically learn meta-paths for given data and tasks. The adjacency matrix $A_{\hat{P}}$ of a meta-path $\hat{P}$ of arbitrary length can be obtained adaptively with the following formula:

$$A_{\hat{P}} = \Big(\sum_{t_1^{(e)} \in T^{(e)}} \alpha^{(1)}_{t_1^{(e)}} A_{t_1^{(e)}}\Big)\Big(\sum_{t_2^{(e)} \in T^{(e)}} \alpha^{(2)}_{t_2^{(e)}} A_{t_2^{(e)}}\Big) \cdots \Big(\sum_{t_l^{(e)} \in T^{(e)}} \alpha^{(l)}_{t_l^{(e)}} A_{t_l^{(e)}}\Big) \quad (5)$$

In this formula, $A_{\hat{P}}$ is the adjacency matrix of the adaptively obtained meta-path, $T^{(e)}$ is the set of edge types, and $\alpha^{(i)}_{t_i^{(e)}}$ is the learnable weight of edge type $t_i^{(e)}$ at the $i$-th step of the meta-path. In addition, this module can use multiple channels to account for the diversity of meta-path types, as shown in Fig. 3.

3.5 Spatio-Temporal Feature Learning Module

The spatio-temporal feature learning module takes as input the feature matrix X and a multi-channel tensor $\hat{A} \in \mathbb{R}^{(B+1) \times N \times N}$ composed of the original adjacency matrix of


the road network G and the adjacency matrices of the meta-path graphs, where $B$ is the number of meta-path graphs generated by the graph transformer module. The spatio-temporal feature learning module serves two main purposes. First, it extracts the basic spatio-temporal correlations between nodes from the original traffic data. Second, it supplements and strengthens these basic correlations by extracting the correlations between similar nodes from the meta-path graphs. Many researchers have proposed spatio-temporal prediction models that can handle graph-structured data, such as T-GCN, DCRNN, ASTGCN, and MOAGE. We can therefore reuse the relevant components of existing models to learn spatio-temporal features and to process the adjacency matrix of each channel, so that the features of nodes of the same type are aggregated. Finally, a fully connected layer fuses the information from all channels.
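A simplified NumPy sketch of Eq. (5) and of assembling the multi-channel tensor Â follows. As in GTN [8], each step's weights α^(i) come from a softmax over learnable scores; here the scores are simply given arrays, nothing is trained, and the degree normalization applied in the original GTN is omitted. The function names are hypothetical.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - np.max(z))  # subtract the max for numerical stability
    return e / e.sum()

def soft_meta_path(mats, scores):
    """Eq. (5): product over l steps of a convex combination of typed adjacencies.

    mats:   (K_e, N, N) adjacency matrices A_t, one per edge type
    scores: (l, K_e) unnormalized scores; softmax of row i gives alpha^(i)
    """
    out = None
    for step_scores in scores:
        alpha = softmax(step_scores)
        A_step = np.tensordot(alpha, mats, axes=1)  # sum_t alpha_t * A_t
        out = A_step if out is None else out @ A_step
    return out

def channel_tensor(A, meta_adjs):
    """Stack the original adjacency and B meta-path adjacencies: shape (B + 1, N, N)."""
    return np.stack([A] + list(meta_adjs))
```

Because the weights are a softmax rather than a hard choice, the meta-path is differentiable in the scores; very peaked scores recover a hard choice of edge types, and the result then approaches the product in Eq. (4).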

4 Experiments and Results

4.1 Dataset

1. SZ-taxi. This dataset contains taxi driving speeds on 156 main roads in Luohu District, Shenzhen, from January 1 to January 31, 2015. It consists of two main parts: an adjacency matrix, in which each row represents a road and each value represents the connectivity between roads, and a feature matrix that records the changes in traffic speed on each road. Each row of the feature matrix represents a road, and each column represents the average traffic speed on that road during a given time period. The data are sampled at 15-min intervals.
2. Los-loop. This dataset was collected in real time by highway loop detectors in Los Angeles County. It includes the traffic speeds detected by 207 sensors from March 1 to March 7, 2012, aggregated every 5 min. It also includes an adjacency matrix computed from the distances between sensors in the traffic road network, as well as a feature matrix.

In the experiments, we normalized the input data to the interval [0, 1]. We split the data into a training set (80%) and a test set (20%), and predicted the traffic data 15, 30, and 60 min into the future.

4.2 Experiment Setting

This paper uses three evaluation metrics commonly used in regression tasks, Root Mean Square Error (RMSE), Mean Absolute Error (MAE), and Accuracy, to evaluate the difference between the ground truth and the predicted values. We compare the proposed traffic flow prediction framework with the following models: (1) Historical Average model (HA), which regards the evolution of traffic data as a daily periodic process and takes the average of all historical data at a given time step of each day as the prediction for that time step [11]. (2) Autoregressive Integrated Moving Average model (ARIMA), a widely used time series forecasting model that combines autoregression and moving averages [12]. (3) Gated Recurrent Unit (GRU), an effective RNN structure that uses a gating mechanism


to retain long-term information in time series modelling [13]. (4) Diffusion Convolutional Recurrent Neural Network (DCRNN), which uses bidirectional random-walk-based diffusion convolution and an encoder-decoder structure to handle spatial and temporal correlations, respectively [14]. (5) Temporal Graph Convolutional Network (T-GCN), which combines GCN and GRU, using GCN to capture spatial correlations and GRU to learn how traffic data change over time [15]. (6) Attention-Based Spatial-Temporal Graph Convolutional Network (ASTGCN), which introduces attention mechanisms in both the spatial and temporal dimensions [16]. (7) Model Combining Outlook Attention and Graph Embedding (MOAGE), which uses the outlook attention mechanism to model spatio-temporal dependencies and node2vec to learn node representations of the road network [17]. In this experiment, we directly use the relevant components of T-GCN, DCRNN, ASTGCN, and MOAGE as the spatio-temporal feature learning module of the proposed framework. The aim is to verify that the proposed framework strengthens and assists such deep spatio-temporal prediction models. All experiments were trained and tested with the PyTorch framework on a Linux server with an Intel(R) Xeon(R) Gold 6226R CPU @ 2.90GHz and an NVIDIA GeForce RTX 3090 GPU.

4.3 Analysis of Experimental Results

In the experiment, we used the proposed framework to enhance the deep spatio-temporal models among the baselines, and named each enhanced model "base model+". The comparison results between the enhanced models and the baseline models are presented in Table 1 and Table 2. Tables 1 and 2 show that non-deep models such as HA and ARIMA have clear limitations on complex, nonlinear traffic data because of their linear and stationarity assumptions.
Deep learning-based models usually outperform non-deep models, and models that consider both temporal and spatial correlations outperform deep models such as GRU that consider only temporal correlation. Compared with their base models, the enhanced models achieve average improvements of 5.7% in RMSE, 5.7% in MAE, and 2.0% in accuracy on the SZ-taxi dataset, and of 13.1% in RMSE, 11.5% in MAE, and 1.5% in accuracy on the Los-loop dataset. The proposed framework can significantly improve GNN-based spatio-temporal prediction models for the following reason. The SZ-taxi and Los-loop datasets contain many traffic nodes, and these nodes do not all follow a single traffic pattern. The base models cannot effectively distinguish the traffic patterns of nodes and thus cannot make good use of the common evolution rules and characteristics shared by similar nodes. In contrast, the proposed framework measures the similarity of nodes' traffic patterns, clusters the nodes, and uses the heterogeneous graph neural network to aggregate features among similar nodes, which helps the base model capture more correlations and improve accuracy. In addition, since the graph transformer network itself can be combined with different GNNs,


the proposed framework should, in principle, be easy to transfer to any GNN-based spatio-temporal prediction model.

Table 1. Performance comparison of different models on SZ-taxi

Table 2. Performance comparison of different models on Los-loop
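The metrics reported in Tables 1 and 2, together with the preprocessing in Sect. 4.1, can be sketched in NumPy as below. The section does not define Accuracy; we assume the definition used by T-GCN [15], 1 − ‖Y − Ŷ‖_F / ‖Y‖_F. We also assume the 80%/20% split is chronological (the first 80% of time steps for training). The function names are hypothetical.

```python
import numpy as np

def minmax_scale(X):
    """Scale all values of X into [0, 1] using the global min and max."""
    lo, hi = X.min(), X.max()
    return (X - lo) / (hi - lo)

def chrono_split(X, train_frac=0.8):
    """Split an N x tau matrix along time: first 80% for training, the rest for testing."""
    cut = int(X.shape[1] * train_frac)
    return X[:, :cut], X[:, cut:]

def rmse(y, y_hat):
    return float(np.sqrt(np.mean((y - y_hat) ** 2)))

def mae(y, y_hat):
    return float(np.mean(np.abs(y - y_hat)))

def accuracy(y, y_hat):
    """Assumed T-GCN-style accuracy: 1 - ||Y - Y_hat||_F / ||Y||_F."""
    return float(1.0 - np.linalg.norm(y - y_hat) / np.linalg.norm(y))
```

RMSE and MAE are lower-is-better, while this Accuracy is higher-is-better, which is why the reported improvements move in opposite directions.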

The selection of hyperparameters also plays a significant role in performance; in particular, the number of clusters strongly affects how much the framework improves the base model. To determine the best number of clusters, we tried values from 2 to 7 and compared the results. For example, when using T-GCN+ to predict traffic 15 min ahead, we varied the number of clusters while keeping the remaining hyperparameters unchanged. The results, shown in Fig. 4, reveal that the best results were achieved with 5 clusters on the SZ-taxi dataset and with 6 on Los-loop. This is probably because Los-loop has more nodes (207 versus 156), leading to more types of traffic patterns.


Furthermore, Fig. 4 shows that as the number of clusters increases, the prediction error first decreases and then increases. This can be attributed to two factors. First, more clusters increase model complexity. Second, more clusters may leave fewer nodes in each cluster, making it harder for the model to capture the latent correlations between nodes in the same cluster. Overall, the results show that the hyperparameters must be selected carefully to achieve optimal performance.

Fig. 4. RMSE value under different cluster numbers

5 Conclusion

In this paper, we propose a traffic flow prediction framework based on clustering and heterogeneous graph neural networks. The similarity of traffic patterns between nodes can help spatio-temporal prediction models improve prediction accuracy, but most models struggle to make good use of it. The proposed framework measures the similarity of traffic patterns between nodes in terms of flow trend and numerical distribution, using DTW and the Wasserstein distance, respectively. It then clusters the traffic nodes based on this similarity and uses a heterogeneous graph neural network to aggregate the characteristics of similar nodes. This novel approach introduces the GTN into traffic flow prediction. The advantage of the GTN is that it aggregates the characteristics of similar nodes by constructing meta-paths from the original spatial adjacency, and it generates meta-path graphs automatically from data, without requiring domain knowledge or manual design. The results show that the proposed framework can be combined with several representative GNN-based spatio-temporal prediction models and improves their predictions to a certain extent. In future work, we plan to make the framework more lightweight and more portable; specifically, we aim to avoid a significant increase in the number of parameters and in inference time when the framework is combined with other GNN-based spatio-temporal prediction models.

Acknowledgements. This project is supported by the Natural Science Foundation of Shandong Province for Key Project (No. ZR2020KF006), the National Natural Science Foundation of China (No. 62273164), and A Project of Shandong Province Higher Educational Science and Technology Program (No. J16LB06, No. J17KA055).


References

1. Luo, Q.: Research on intelligent transportation system technologies and applications. In: 2008 Workshop on Power Electronics and Intelligent Transportation System, Piscataway, pp. 529–531. IEEE (2008)
2. Jiang, W., Luo, J.: Graph neural network for traffic forecasting: a survey. Expert Syst. Appl. 117921 (2022)
3. Pan, Z., Wang, Z., Wang, W., Yu, Y., Zhang, J., Zheng, Y.: Matrix factorization for spatiotemporal neural networks with applications to urban flow prediction. In: Proceedings of the 28th ACM International Conference on Information and Knowledge Management, pp. 2683–2691. ACM, New York (2019)
4. Li, M., Zhu, Z.: Spatial-temporal fusion graph neural networks for traffic flow forecasting. In: Proceedings of the AAAI Conference on Artificial Intelligence, Menlo Park, vol. 35, pp. 4189–4196. AAAI (2021)
5. Cai, X., Dai, G., Yang, L.: Survey on spectral clustering algorithms. Comput. Sci. 35(7), 14–18 (2008)
6. Ng, A., Jordan, M., Weiss, Y.: On spectral clustering: analysis and an algorithm. In: Advances in Neural Information Processing Systems, vol. 14 (2001)
7. Wang, X., et al.: Heterogeneous graph attention network. In: The World Wide Web Conference, pp. 2022–2032. ACM, New York (2019)
8. Yun, S., Jeong, M., Kim, R., Kang, J., Kim, H.J.: Graph transformer networks. In: Advances in Neural Information Processing Systems, vol. 32 (2019)
9. Berndt, D.J., Clifford, J.: Using dynamic time warping to find patterns in time series. In: KDD Workshop, Menlo Park, vol. 15, pp. 359–370. AAAI (1994)
10. Panaretos, V.M., Zemel, Y.: Statistical aspects of Wasserstein distances. Annu. Rev. Stat. Appl. 6, 405–431 (2019)
11. Liu, J., Guan, W.: A summary of traffic flow forecasting methods. J. Highway Transp. Res. Dev. 21(3), 82–85 (2004)
12. Box, G.E.P., Pierce, D.A.: Distribution of residual autocorrelations in autoregressive-integrated moving average time series models. J. Am. Stat. Assoc. 65(332), 1509–1526 (1970)
13. Cho, K., van Merriënboer, B., Bahdanau, D., Bengio, Y.: On the properties of neural machine translation: encoder–decoder approaches. In: Syntax, Semantics and Structure in Statistical Translation, vol. 103 (2014)
14. Li, Y., Yu, R., Shahabi, C., Liu, Y.: Diffusion convolutional recurrent neural network: data-driven traffic forecasting. In: International Conference on Learning Representations (2018)
15. Zhao, L., et al.: T-GCN: a temporal graph convolutional network for traffic prediction. IEEE Trans. Intell. Transp. Syst. 21(9), 3848–3858 (2019)
16. Guo, S., Lin, Y., Feng, N., Song, C., Wan, H.: Attention based spatial-temporal graph convolutional networks for traffic flow forecasting. In: Proceedings of the AAAI Conference on Artificial Intelligence, Menlo Park, vol. 33, pp. 922–929. AAAI (2019)
17. Zhang, J., Liu, Y., Gui, Y., Ruan, C.: An improved model combining outlook attention and graph embedding for traffic forecasting. Symmetry 15(2), 312 (2023)

Effective Audio Classification Network Based on Paired Inverse Pyramid Structure and Dense MLP Block

Yunhao Chen1 (corresponding author), Yunjie Zhu2, Zihui Yan1, Zhen Ren1, Yifan Huang1, Jianlu Shen1, and Lifang Chen1

1 Jiangnan University, Wuxi 214000, China
2 University of Leeds, Leeds LS2 9JT, UK

Abstract. Recently, massive architectures based on Convolutional Neural Networks (CNN) and self-attention mechanisms have become the norm for audio classification. Although these techniques are state-of-the-art, their effectiveness can only be guaranteed with huge computational costs and parameter counts, large amounts of data augmentation, transfer from large datasets, and other tricks. Exploiting the lightweight nature of audio, we propose an efficient network structure called the Paired Inverse Pyramid Structure (PIP) and a network called the Paired Inverse Pyramid Structure MLP Network (PIPMN) to overcome these problems. PIPMN reaches 95.5% Environmental Sound Classification (ESC) accuracy on the UrbanSound8K dataset and 93.2% Music Genre Classification (MGC) accuracy on the GTZAN dataset, with only 1 million parameters. Both results are achieved without data augmentation or transfer learning. Under this setting, PIPMN matches or even exceeds other state-of-the-art models with far fewer parameters. The code is available at https://github.com/JNAIC/PIPMN.

Keywords: Audio Classification · Multi-stage Structure · Skip Connection · Multi-layer Perceptron (MLP)

1 Introduction

Audio classification aims to categorize sounds into predefined groups, as in Environmental Sound Classification (ESC) [1] or Music Genre Classification (MGC) [2]. The demand for accurate audio classification systems has grown in recent years, with applications in fields such as hearing aids [3], urban planning [4], and biology [5]. For instance, MGC can be used in music recommendation systems or emotional tests [6, 7]. Meanwhile, deep learning techniques have revolutionized audio classification with outstanding accuracy. Specifically, Convolutional Neural Networks (CNN) and self-attention mechanisms have been highly effective. These models typically process 2D spectrograms extracted from audio as if they were images [8–10, 37, 41, 44]. Additionally, techniques such as transfer learning [11], knowledge distillation [12], and cross-modal cooperative learning [13, 37, 38] have been employed to enhance model robustness and accuracy. However, although previous works [14–16] have improved ESC performance, these networks primarily focus on spatial information in the time and frequency domains and often neglect adequate extraction of depth-domain information, which results in lower accuracy.

On a more optimistic note, we observe that the humble Multi-Layer Perceptron (MLP), a simple and computationally efficient module, has proven competitive in image classification [17] on ImageNet [18]. This performance lends credibility to the MLP's ability to process complex information, and it inspired us to build our network around the MLP. However, the multi-stage structure [19] commonly used in CNNs does not suit MLP-based networks because of their propensity to overfit. To tackle this, we introduce a long-range skip connection with layer scale [20] on the bottleneck as a countermeasure against overfitting.

In this paper, we aim to establish a novel approach to depth-domain and time-domain information extraction for audio classification that uses minimal parameters. Our motivation has three aspects. First, we aim to process audio in a lightweight manner rather than treating it as an image. Second, we seek to replace CNNs and transformers with MLPs to reduce computational cost and parameters, which led to the Dense MLP (DM) block for audio classification. Lastly, to minimize overfitting, we propose the Paired Inverse Pyramid Structure (PIP). The main contributions of this paper are as follows:

1. We propose the Temporal MLP, the Depth Domain Block, and the Dense MLP block to extract the spatial and temporal domain information of audio more effectively.
2. We propose a Paired Inverse Pyramid Structure to reduce overfitting in multi-stage MLP networks for audio classification.
3. The full PIPMN achieves outstanding accuracy on both the UrbanSound8K dataset for the ESC task and the GTZAN dataset for the MGC task without data augmentation or transfer learning.

Y. Chen, Y. Zhu and Z. Yan contributed equally to the paper.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 70–84, 2023. https://doi.org/10.1007/978-981-99-4742-3_6

Effective Audio Classification Network

2 Methods

2.1 Paired Inverse Pyramid Structure

The overall structure of the model is depicted in Fig. 1. A key feature of this structure is a long-range skip connection with layer scale on the bottleneck, which is instrumental in reducing overfitting. The rationale is that, as a conventional multi-stage network becomes deeper, it represents a more complex function because of its larger number of parameters, which increases its propensity to overfit. By exploiting the training dynamics [20] offered by layer scale, the network can instead represent a less complex function through the skip connection, making it less susceptible to overfitting.


Y. Chen et al.

Fig. 1. Paired Inverse Pyramid Structure.

Moreover, previous studies on deep networks with stochastic depth [21] show that what a layer learns depends on information from both the preceding layer and non-adjacent layers. The long-range skip connection can therefore preserve the information extracted by one layer and transmit it to another layer that needs it, without losing it in multiple intermediate layers. The whole structure can be described as follows:

$$\Phi_n = [\kappa_1, \kappa_2, \ldots, \kappa_{n-1}, \kappa_n, \kappa_{n-1}, \ldots, \kappa_1] \quad (1)$$

where $n$ is the number of pairs of layers and the hyperparameter $\kappa_i$ is the expansion rate of the depth domain's dimension. Fig. 1 shows an example of this structure with $n = 3$, $\kappa_1 = 2$, $\kappa_2 = 4$, $\kappa_3 = 8$, and hence $\Phi_3 = [2, 4, 8, 4, 2]$. The label (Batch, time_length, $\Phi_n[i]$ × in_dim) gives the size of the tensor after processing by the $i$-th layer, where Batch is the batch size of the input tensor, time_length is the hyperparameter set for Adaptive Pooling, and in_dim is the initial dimension of the depth domain.

2.2 Positional Modelling for Time Domain and Depth Domain

The spectrogram fed into the positional modelling can be represented as $\chi \in \mathbb{R}^{B \times D \times L}$, where $B$ is the batch size, $L$ is the length of the temporal domain, and $D$ is the dimension of the depth domain. The positional modelling of the depth domain along the time domain can be described as follows:

$$\phi(\chi) = \chi * \gamma_1^3 \quad (2)$$


where $\phi(\chi)$ is the positional modelling and $\gamma_1^3$ is a convolution layer with a kernel size of 3 and a stride of 1. To keep the output length equal to $L$, the layer uses a padding of 1. To reduce parameters, the number of groups equals $D$, so each depth channel is convolved with its own kernel.

2.3 Temporal MLP Based on Feed-Forward Structure

The structure is illustrated in Fig. 2. The feed-forward structure includes one skip connection [22], one MLP block, and a positional modelling block. The MLP block includes one LayerNorm [23] layer, two Linear layers, and a GELU [24] activation function.
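Eq. (2) with kernel size 3, stride 1, padding 1 and D groups is a depthwise 1-D convolution along the time axis. The NumPy forward pass below illustrates it; it is our sketch with hypothetical names, and in PyTorch the same layer would typically be written `nn.Conv1d(D, D, kernel_size=3, stride=1, padding=1, groups=D)`.

```python
import numpy as np

def positional_modelling(x, w, b=None):
    """Depthwise 1-D convolution over time (kernel 3, stride 1, padding 1, groups = D).

    x: (B, D, L) input; w: (D, 3), one kernel per depth channel; b: optional (D,) bias.
    Each channel sees only its own kernel, so the layer has 3 * D weights instead of
    3 * D * D, and the output keeps the shape (B, D, L).
    """
    L = x.shape[2]
    xp = np.pad(x, ((0, 0), (0, 0), (1, 1)))  # zero-pad one step at each end of time
    out = np.zeros(x.shape, dtype=float)
    for k in range(3):
        # cross-correlation, as computed by deep-learning convolution layers
        out += w[None, :, k, None] * xp[:, :, k:k + L]
    if b is not None:
        out += b[None, :, None]
    return out
```

With the kernel [1, 1, 1], each output step is the sum of a value and its two neighbours, which is how the layer injects local positional context along the time axis.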

Fig. 2. The temporal MLP based on feed-forward structure.

The MLP block operates much like the feed-forward layer of a transformer [25]. In the MLP block, the time domain is first expanded and then condensed back to its original length, which helps exploit the information within the time domain. The degree of expansion is controlled by a hyperparameter $\alpha$. The MLP block can be described as:

$$\varphi(\chi) = W_{L \times \alpha L}^{T}\, \mathrm{GELU}\big(W_{\alpha L \times L}^{T}\, \eta(\chi) + b_1\big) + b_2 \quad (3)$$

where $\chi \in \mathbb{R}^{B \times D \times L}$, $\eta(\chi)$ is the layer normalization of $\chi$, $W_{a \times b}$ is a learnable weight matrix of size $a \times b$, $b_1$ and $b_2$ are the corresponding biases, GELU is the activation function, and $L$ = time_length. The whole feed-forward structure can be described as follows:

$$\gamma(\chi) = W_{3\alpha L \times L}^{T}\big([\phi(\chi), \chi, \varphi(\chi)]\big) \quad (4)$$


where $\chi \in \mathbb{R}^{B \times D \times L}$, $\varphi(\chi)$ is the MLP block operation, $[\cdot, \cdot, \cdot]$ denotes concatenation of tensors, and $W_{3\alpha L \times L}$ maps the concatenated tensor back to its original size. This feed-forward structure aims to extract information within the temporal domain as well as the relations between the temporal and depth domains.

2.4 Depth Domain Block and Linear Skip Connection

Because PIP already has a complicated structure, the depth domain block of each Dense MLP block is kept simple, which reduces overfitting without compromising overall accuracy. The structure is depicted in Fig. 3 and can be described as follows:

$$\delta(\chi') = \mathrm{GELU}\big(W_{D_{in} \times D_{out}}^{T}\, \eta(\chi')\big) + \omega_{D_{in} \times D_{out}}^{T}(\chi') \quad (5)$$

where $W_{D_{in} \times D_{out}}$ is a learnable weight matrix of size $D_{in} \times D_{out}$ and $\chi' \in \mathbb{R}^{B \times L \times D}$ (its dimensions are permuted relative to the tensor $\chi$ above). The special skip connection in this layer deserves attention. Because the output tensor size of this layer differs from the input size, a simple Linear layer $\omega_{D_{in} \times D_{out}}$ is applied to the input so that the skip connection matches the output. This skip connection is effective because it provides a simple way to map the original data onto the processed data. Without it, the information propagated through multiple layers would be lost because of repeated transformations.
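Eq. (5) can be sketched as a NumPy forward pass. This is an illustration under stated simplifications, not the released implementation: the LayerNorm has no learnable affine parameters, GELU uses its common tanh approximation, and biases are omitted because Eq. (5) shows none. `W` and `Omega` are hypothetical names standing for W_{D_in×D_out} and the Linear Skip Connection ω_{D_in×D_out}, both applied to the last (depth) axis of χ′.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learnable scale/shift (simplification)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def depth_block(x, W, Omega):
    """Eq. (5): GELU(LN(x) @ W) + x @ Omega, with x of shape (B, L, D_in).

    W and Omega both have shape (D_in, D_out); Omega is the Linear Skip Connection,
    which lets the shortcut change width from D_in to D_out where an identity cannot.
    """
    return gelu(layer_norm(x) @ W) + x @ Omega
```

If `W` is all zeros the main branch vanishes and the block reduces to the linear map `x @ Omega`, which is the "simple way to map the original data onto the processed data" described above.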

Fig. 3. The depth domain block of the dense MLP block.


Fig. 4. The structure of the proposed Dense MLP.

Moreover, compared with the inherently complex structure of a deep neural network, a single Linear layer is a simple enough transformation to serve as a skip connection. A Linear layer can therefore play the role of the identity shortcut in ResNet. We call this module the Linear Skip Connection.

2.5 Dense MLP

The whole structure is depicted in Fig. 4, and its processing can be described as follows:

$$\theta(\chi) = \epsilon_2\, \delta\big(\epsilon_1\, \lambda(\gamma(\lambda(\chi))) + \chi\big) \quad (6)$$

where $\lambda(\cdot)$ permutes the depth and temporal domains, transforming $\chi \in \mathbb{R}^{B \times L \times D}$ into $\mathbb{R}^{B \times D \times L}$ or vice versa, and $\epsilon_1$, $\epsilon_2$ are learnable scalars used as layer scale values [20], which reduced overfitting in our experiments. For example, when $n = 2$, $\Phi_2 = [2, 4, 2]$, time_length = 10 and in_dim = 100, $D_{in}$ and $D_{out}$ are in_dim and 2 × in_dim for the first Dense MLP, 2 × in_dim and 4 × in_dim for the second, and 4 × in_dim and 2 × in_dim for the third.
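The expansion sequence of Eq. (1) and the resulting (D_in, D_out) of each Dense MLP can be computed with a few lines of Python. The helper names are hypothetical, and the sketch only reproduces the dimension bookkeeping of the example above, not the network itself.

```python
def pip_rates(kappas):
    """Phi_n = [k_1, ..., k_{n-1}, k_n, k_{n-1}, ..., k_1] from [k_1, ..., k_n]."""
    kappas = list(kappas)
    return kappas + kappas[-2::-1]

def dense_mlp_dims(kappas, in_dim):
    """(D_in, D_out) for each Dense MLP along the paired inverse pyramid."""
    dims, d_in = [], in_dim
    for rate in pip_rates(kappas):
        d_out = rate * in_dim
        dims.append((d_in, d_out))
        d_in = d_out
    return dims

print(pip_rates([2, 4, 8]))         # [2, 4, 8, 4, 2]
print(dense_mlp_dims([2, 4], 100))  # [(100, 200), (200, 400), (400, 200)]
```

The sequence always has 2n − 1 entries, widening to κ_n at the bottleneck and narrowing symmetrically back to κ_1.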


3 Experiments

3.1 Overview of the UrbanSound8K Dataset

The UrbanSound8K dataset, created by Salamon et al. in 2014, is an extensive collection of urban noise recordings. It contains 8732 sound samples, each labelled with one of ten classes drawn from a broad urban sound taxonomy. These classes include common urban sounds such as car horns, sirens, and children playing.

Fig. 5. Different types of acoustic features are extracted from the audio samples and concatenated into a single tensor.

The samples, gathered from various field recordings, are at most 4 s long and are sampled at rates ranging from 16 kHz to 48 kHz. The dataset also provides metadata such as the geographical location where each sample was recorded and the type of recording device used.


This dataset has become an essential benchmark for testing and validating audio classification models, particularly those focused on urban sounds.

3.2 Overview of the GTZAN Dataset

The GTZAN dataset, created by Tzanetakis and Cook, is a comprehensive dataset for music genre classification. It comprises 1000 audio tracks, each 30 s long, giving a balanced representation of ten genres, including blues, classical, country, disco, hip-hop, and jazz. All tracks are in WAV format, sampled at 22.05 kHz with 16-bit resolution. This varied collection of music has been widely used as a benchmark for music genre classification models. It also finds applications in other audio classification tasks, making it a valuable asset for researchers in music information retrieval (MIR) and audio signal processing.

3.3 Training Setup and Preprocessing for the Datasets

We evaluated our approach on two datasets: UrbanSound8K for ESC and GTZAN for MGC. We used similar feature extraction techniques for both datasets, namely NGCC, MFCC, GFCC, LFCC, and BFCC, each with a feature size of (399, 20). For GTZAN, we divided each 30-s sample into seven segments, so that both GTZAN and UrbanSound8K yield roughly four-second audio samples. With a training batch size of 128 and the five 20-dimensional features concatenated, the input size is (128, 399, 100) for both datasets. We used the AdamW optimizer with a learning rate of 0.001, a weight decay of 0.05, and otherwise default parameters. The loss function was cross-entropy loss with a label smoothing of 0.1. Training lasted up to 3500 epochs and was halted once the training accuracy reached 100% and the loss stopped decreasing. No data augmentation, transfer learning, EMA, pretraining, or other techniques were used, in order to assess the efficiency of the network itself.
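The shape bookkeeping of this setup can be sketched as follows. We assume the seven GTZAN segments are equal and non-overlapping (the text does not say), and the feature extractors themselves are not shown: each feature is represented only by a placeholder (399, 20) array. All names are hypothetical.

```python
import numpy as np

SR = 22050          # GTZAN sampling rate in Hz
TRACK_SECONDS = 30  # length of each GTZAN track
N_SEGMENTS = 7      # segments per track

def split_track(y, n_segments=N_SEGMENTS):
    """Cut a waveform into n_segments equal, non-overlapping pieces (remainder dropped)."""
    seg_len = len(y) // n_segments
    return [y[i * seg_len:(i + 1) * seg_len] for i in range(n_segments)]

def stack_features(feature_mats):
    """Concatenate five (399, 20) feature matrices (NGCC, MFCC, GFCC, LFCC, BFCC) to (399, 100)."""
    return np.concatenate(feature_mats, axis=-1)

segments = split_track(np.zeros(SR * TRACK_SECONDS))
seconds = len(segments[0]) / SR  # about 4.29 s per segment
sample = stack_features([np.zeros((399, 20)) for _ in range(5)])
batch = np.stack([sample] * 128)  # input tensor of shape (128, 399, 100)
```

The 100 in the input size is thus simply five features times 20 coefficients, and the segment length of 30/7 ≈ 4.29 s is what makes GTZAN clips comparable to the at most 4-s UrbanSound8K clips.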
The datasets were randomly split into three parts: 80% for training, 10% for validation, and 10% for testing. The hyperparameters of PIPMN are set as follows: time_length = 5, in_dim = 100, n = 2, ϕ2 = [4, 8, 4], and α = 3. The comparison models are implemented with timm [40] and use its default hyperparameters.

3.4 UrbanSound8K Experimental Results
Table 1 compares our model with other high-performing models on the UrbanSound8K dataset. Our model achieves the highest accuracy, macro precision, and macro F1. It also uses far fewer parameters (1.4M) than any of the compared models. This suggests that our model extracts useful information without severe overfitting.
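Two ingredients of the training setup in Sect. 3.3 can be sketched in plain Python: the random 80/10/10 split and the cross-entropy loss with label smoothing of 0.1. This is an illustrative sketch, not the authors' code. The seed and rounding rule of the split are assumptions. The smoothing convention shown (target class 1 − ε + ε/K, every class ε/K) is the one used by PyTorch's `CrossEntropyLoss(label_smoothing=...)`.

```python
import math
import random


def random_split(n_items, val_frac=0.1, test_frac=0.1, seed=0):
    """Shuffle indices 0..n_items-1 and split them into train/val/test lists."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)          # hypothetical fixed seed
    n_test = round(n_items * test_frac)
    n_val = round(n_items * val_frac)
    test = indices[:n_test]
    val = indices[n_test:n_test + n_val]
    train = indices[n_test + n_val:]              # remaining ~80% for training
    return train, val, test


def label_smoothed_ce(logits, target, eps=0.1):
    """Cross-entropy of softmax(logits) against a label-smoothed one-hot target.

    The target class receives probability 1 - eps + eps/K and every other
    class eps/K, where K is the number of classes.
    """
    k = len(logits)
    m = max(logits)                                            # stabilise log-sum-exp
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    loss = 0.0
    for c, z in enumerate(logits):
        q = eps / k + (1.0 - eps if c == target else 0.0)      # smoothed target probability
        loss -= q * (z - log_z)                               # z - log_z is log-softmax
    return loss


if __name__ == "__main__":
    train, val, test = random_split(8732)         # UrbanSound8K contains 8732 clips
    print(len(train), len(val), len(test))
    print(label_smoothed_ce([2.0, 0.5, -1.0], target=0))
```

The first line printed is `6986 873 873`. With uniform logits, the smoothed loss equals log K for any ε, which is a quick sanity check.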


Y. Chen et al.

Table 1. PREVIOUS MODELS VS. PROPOSED MODEL ON URBANSOUND8K. (All metrics except Params are in %. MaP stands for Macro Precision, MaF1 for Macro F1, and MiF1 for Micro F1; LSTM here stands for “Replace DenseMLP with LSTM”.)

Framework              Accuracy  MaP   MaF1  MiF1  Params
ViT-Small [30]         93.7      93.5  93.1  91.3  21.5M
CoaT-lite-mini [31]    94.4      94.3  94.6  92.6  10.5M
ConViT-Small [32]      95.0      94.5  94.5  93.9  27.1M
DeiT III [33]          95.1      95.3  95.2  95.0  21.5M
ConvNeXt [19]          94.7      95.4  95.1  93.6  87.6M
MLP-Mixer [17]         94.5      94.6  94.9  94.3  17.5M
LSTM                   91.6      92.2  91.6  90.6  11.0M
ConvNeXt V2 [42]       94.9      95.7  95.4  94.6  49.5M
AST [43]               94.7      95.5  95.1  95.3  85.5M
Proposed model         95.5      95.8  95.9  95.1  1.4M

Figures 6 and 7 show the loss and accuracy curves, respectively, for the training and validation sets. The curves are smooth, without sudden jumps, which indicates that the model is easy to train.

Fig. 6. The loss curve diagram of the training set and validation set.

Fig. 7. The accuracy curve diagram of the training set and validation set.


3.5 GTZAN Experimental Results
To further demonstrate the efficiency of our model, we evaluate it on the GTZAN dataset for the MGC task. As Table 2 shows, our model is neither the most accurate nor the most lightweight. It does, however, strike a good balance between the number of parameters and accuracy. Moreover, the compared models are specialized for the MGC task, whereas PIPMN is also applied to ESC.

Table 2. PREVIOUS MODELS VS. THE PROPOSED MODEL ON GTZAN (LSTM here stands for “Replace DenseMLP with LSTM”.)

Framework                   Accuracy  MaP   MaF1  MiF1  Params
VGGish [34]                 92.2      93.2  92.8  93.0  72.1M
VGGish + CoTrans-b2 [34]    95        95.6  93.2  94.7  72.6M
Improved-BBNN [35, 36]      91        91.2  91.4  91.5  0.18M
Proposed model              93.3      93.0  93.4  93.7  1.4M

3.6 Ablation Study
Table 3 presents an ablation study on the UrbanSound8K dataset that examines how different components affect the classification results. Two experiments replace the input: one uses 50 MFCCs and the other a 100-band Mel-spectrogram. With 50 MFCCs, performance changes only slightly (95.2% vs. 95.5% accuracy) while the parameter count drops to 0.3M, which indicates robustness to the choice of cepstral input. With the Mel-spectrogram, accuracy falls to 92.1%.

Table 3. ABLATION STUDY ON URBANSOUND8K DATASET (MaP stands for Macro Precision, MaF1 for Macro F1, MiF1 for Micro F1, DM for DenseMLP, PIPS for Paired Inverse Pyramid Structure, and OMS for Original Multi-Stage.)

Framework               Accuracy  MaP   MaF1  MiF1  Params
Without ➀               95.1      95.2  95.2  94.6  1.4M
Without ➁               95.2      95.6  95.5  95.0  1.4M
Without ➂               93.6      93.7  94.0  93.2  1.4M
MFCC                    95.2      95.4  95.5  95.1  0.3M
Mel-Spectrogram         92.1      92.5  92.7  91.9  1.4M
Replace PIPS with OMS   93.9      93.7  94.1  94.2  2.4M
Replace DM with LSTM    91.6      92.2  91.6  90.6  11.0M
Proposed model          95.5      95.8  95.9  95.1  1.4M

➀ is the Long Range Skip Connection, ➁ is Positional Modelling, and ➂ is the Linear Skip Connection. These components are marked in the architecture figures.


Fig. 8. Multi-stage structure used for comparison with the PIP structure.

In the “Replace DM with LSTM” experiment, the Dense MLP (DM) block was replaced with an LSTM unit. An LSTM is effective for learning from time-series data but is computationally demanding and has more parameters. This substitution lowered every performance metric (e.g., accuracy fell from 95.5% to 91.6%) and increased the parameter count from 1.4M to 11.0M. Furthermore, to assess the effectiveness of the Paired Inverse Pyramid Structure (PIPS), we replaced the PIPS in our model with a conventional multi-stage structure (Fig. 8). The original multi-stage structure has more parameters than PIPS (2.4M vs. 1.4M), yet Table 3 shows that it performs worse (93.9% vs. 95.5% accuracy). Finally, we evaluated the distinct modules of the Dense Multi-Layer Perceptron (MLP) by removing components ➀, ➁, and ➂ one at a time. The results in Table 3 confirm that all three components contribute to PIPMN. The Linear Skip Connection ➂ has the largest effect: accuracy drops to 93.6% without it.
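The parameter gap between a dense layer and an LSTM of the same width follows from their standard parameter-count formulas. The sketch below uses PyTorch's LSTM layout, in which each layer has four gates, each with an input-to-hidden matrix, a hidden-to-hidden matrix, and two bias vectors. The width of 100 is a hypothetical choice that mirrors in_dim = 100, not the paper's actual block configuration. These numbers therefore do not reproduce the 1.4M and 11.0M totals in Table 3. They only show why an LSTM of equal width carries about eight times the weights of a single linear layer.

```python
def linear_params(in_features, out_features, bias=True):
    """Weights plus optional bias of a fully connected layer."""
    return in_features * out_features + (out_features if bias else 0)


def lstm_params(input_size, hidden_size, num_layers=1):
    """Parameters of a unidirectional LSTM using PyTorch's layout.

    Each layer has four gates; each gate has an input-to-hidden matrix,
    a hidden-to-hidden matrix, and two bias vectors (b_ih and b_hh).
    """
    total = 0
    for layer in range(num_layers):
        layer_in = input_size if layer == 0 else hidden_size
        total += 4 * (hidden_size * layer_in + hidden_size * hidden_size + 2 * hidden_size)
    return total


if __name__ == "__main__":
    d = 100  # hypothetical width, chosen to match in_dim = 100
    print(linear_params(d, d), lstm_params(d, d))
```

This prints `10100 80800`. The roughly eightfold ratio comes from four gates, each with both an input and a recurrent weight matrix.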


4 Discussion
The experiments confirm that, when data augmentation and transfer learning are not used, PIPMN outperforms high-performing classification models on audio spectrogram processing and shows superior generalization ability. The PIP Structure, with a long-range skip connection and layer scale on the bottleneck, reduces overfitting and improves performance. The skip connection lets the network fit a less complex function, which lowers the risk of overfitting. The layer scale controls the training dynamics and the information flow between layers, which enhances learning ability. Our experiments on different datasets support these claims and demonstrate the effectiveness and robustness of this structure.

The Linear Skip Connection is also effective. It uses a single Linear layer to transform the input so that it matches the size of the output tensor, which keeps information flowing smoothly, preserves learning ability, and improves performance. Compared with more complex deep neural network structures, a Linear layer is sufficient as a skip connection. Our experiments confirm the efficiency and reliability of the Linear Skip Connection, which could be carried over to other networks to improve accuracy while keeping them simple.

The MLP-based structure of PIPMN reduces overall model complexity compared with CNNs or Transformers. Preprocessed cepstral-coefficient inputs allow simpler MLP structures with fewer layers to perform well at lower computational cost and memory requirements. Mel-spectrogram inputs, by contrast, lower the accuracy of PIPMN, which underlines that MLP-based structures suit preprocessed audio data. CNNs excel on data with strong spatial patterns, such as images, but may not be optimal for audio transformed into cepstral coefficients. Because preprocessed audio lacks strong spatial patterns, CNNs offer fewer benefits here, and MLP-based structures are more appropriate and efficient in this work.
Moreover, current SOTA models treat their input as 3D data. Applied directly to audio spectrograms, which are inherently two-dimensional, this adds complexity, computation time, and overfitting risk. Our experimental results are consistent with this observation.
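The Linear Skip Connection and layer scale discussed above can be illustrated with a minimal residual unit written in plain Python, so that the data flow is explicit. The sketch assumes the form y = W_skip·x + γ ⊙ f(x), where γ is a per-channel learnable vector as in LayerScale [20] and W_skip is a learnable projection that lets the skip path change width. The exact placement of layer scale and the branch f inside PIPMN are not reproduced here. The branch below (a linear map followed by ReLU) and all weights are hypothetical.

```python
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Plain matrix-vector product w @ x."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def linear_skip_block(x: Vector, branch: Callable[[Vector], Vector],
                      w_skip: Matrix, gamma: Vector) -> Vector:
    """y = W_skip x + gamma * branch(x), with gamma applied element-wise (layer scale).

    W_skip maps the input to the output width, so the skip path still works
    when branch() changes the dimension.
    """
    fx = branch(x)
    skip = matvec(w_skip, x)
    if not (len(fx) == len(skip) == len(gamma)):
        raise ValueError("branch output, skip projection and gamma must have the same length")
    return [s + g * f for s, g, f in zip(skip, gamma, fx)]


W_BRANCH = [[0.5, -1.0], [1.0, 1.0], [2.0, 0.0]]   # hypothetical 2 -> 3 branch weights


def relu_branch(v: Vector) -> Vector:
    """Hypothetical branch f: a linear map followed by ReLU."""
    return [max(0.0, z) for z in matvec(W_BRANCH, v)]


if __name__ == "__main__":
    w_skip = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # hypothetical 2 -> 3 skip projection
    gamma = [1e-2, 1e-2, 1e-2]                      # small layer-scale initialisation
    print(linear_skip_block([1.0, 2.0], relu_branch, w_skip, gamma))
```

When γ starts small, the unit initially behaves almost like the linear skip projection alone. The branch's contribution is then learned gradually, which is one common rationale for layer scale.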

5 Conclusion
In this paper, we propose a new network called PIPMN, whose design is motivated by the lightweight nature of audio data. The results show that processing the audio spectrogram without adding a new dimension achieves similar or even higher performance than high-performing classification models with far fewer parameters. This is made possible by the proposed Dense MLP block and Paired Inverse Pyramid Structure. However, PIPMN has so far been evaluated only on ESC and MGC. In future work, we will extend our model to more challenging tasks such as sound detection.


References
1. Salamon, J., Jacoby, C., Bello, J.P.: A dataset and taxonomy for urban sound research. In: Proceedings ACM International Conference on Multimedia, pp. 1041–1044 (2014)
2. Tzanetakis, G., Cook, P.: Musical genre classification of audio signals. IEEE Trans. Speech Audio Process. 10(5), 293–302 (2002). https://doi.org/10.1109/TSA.2002.800560
3. Alexandre, E., et al.: Feature selection for sound classification in hearing aids through restricted search driven by genetic algorithms. IEEE Trans. Audio Speech Lang. Process. 15(8), 2249–2256 (2007). https://doi.org/10.1109/TASL.2007.905139
4. Barchiesi, D., Giannoulis, D.D., Stowell, D., Plumbley, M.D.: Acoustic scene classification: classifying environments from the sounds they produce. IEEE Signal Process. Mag. 32(3), 16–34 (2015). https://doi.org/10.1109/MSP.2014.2326181
5. González-Hernández, F.R., et al.: Marine mammal sound classification based on a parallel recognition model and octave analysis. Appl. Acoust. 119, 17–28 (2017). https://doi.org/10.1016/J.APACOUST.2016.11.016
6. Lampropoulos, A.S., Lampropoulou, P.S., Tsihrintzis, G.A.: A cascade-hybrid music recommender system for mobile services based on musical genre classification and personality diagnosis. Multimedia Tools Appl. 59, 241–258 (2012)
7. Silverman, M.J.: Music-based affect regulation and unhealthy music use explain coping strategies in adults with mental health conditions. Community Ment. Health J. 56(5), 939–946 (2020). https://doi.org/10.1007/s10597-020-00560-4
8. Salamon, J., Bello, J.P.: Deep convolutional neural networks and data augmentation for environmental sound classification. IEEE Signal Process. Lett. 24(3), 279–283 (2017)
9. Huang, J., et al.: Acoustic scene classification using deep learning-based ensemble averaging. In: Proceedings of Detection Classification Acoustic Scenes Events Workshop (2019)
10. Tak, R.N., Agrawal, D.M., Patil, H.A.: Novel phase encoded mel filterbank energies for environmental sound classification. In: Shankar, B.U., Ghosh, K., Mandal, D.P., Ray, S.S., Zhang, D., Pal, S.K. (eds.) PReMI 2017. LNCS, vol. 10597, pp. 317–325. Springer, Cham (2017). https://doi.org/10.1007/978-3-319-69900-4_40
11. Kumar, A., Khadkevich, M., Fügen, C.: Knowledge transfer from weakly labeled audio using convolutional neural network for sound events and scenes. In: Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing, pp. 326–330 (2018)
12. Kumar, A., Ithapu, V.: A sequential self teaching approach for improving generalization in sound event recognition. In: Proceedings of 37th International Conference on Machine Learning, pp. 5447–5457 (2020)
13. Aytar, Y., Vondrick, C., Torralba, A.: SoundNet: learning sound representations from unlabeled video. In: Proceedings of 30th International Conference on Neural Information Processing Systems, pp. 892–900 (2016)
14. Zhang, L., Shi, Z., Han, J.: Pyramidal temporal pooling with discriminative mapping for audio classification. IEEE/ACM Trans. Audio Speech Lang. Process. 28, 770–784 (2020)
15. Zhang, L., Han, J., Shi, Z.: Learning temporal relations from semantic neighbors for acoustic scene classification. IEEE Signal Process. Lett. 27, 950–954 (2020)
16. Zhang, L., Han, J., Shi, Z.: ATReSN-Net: capturing attentive temporal relations in semantic neighborhood for acoustic scene classification. In: Proceedings of Annual Conference of the International Speech Communication Association, pp. 1181–1185 (2020)
17. Tolstikhin, I., et al.: MLP-mixer: an all-MLP architecture for vision. In: Neural Information Processing Systems, pp. 24261–24272 (2021)
18. Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., Fei-Fei, L.: ImageNet: a large-scale hierarchical image database. In: CVPR 2009 (2009)


19. Liu, Z., Mao, H., Wu, C.Y., Feichtenhofer, C., Darrell, T., Xie, S.: A ConvNet for the 2020s. In: Computer Vision and Pattern Recognition, pp. 11966–11976 (2022)
20. Touvron, H., Cord, M., Sablayrolles, A., Synnaeve, G., Jégou, H.: Going deeper with image transformers. In: International Conference on Computer Vision, pp. 32–42 (2021)
21. Huang, G., Sun, Y., Liu, Z., Sedra, D., Weinberger, K.Q.: Deep networks with stochastic depth. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9908, pp. 646–661. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46493-0_39
22. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Computer Vision and Pattern Recognition, pp. 770–778 (2016)
23. Ba, J.L., Kiros, J.R., Hinton, G.E.: Layer normalization. Computing Research Repository, abs/1607.06450 (2016)
24. Hendrycks, D., Gimpel, K.: Gaussian Error Linear Units (GELUs). arXiv.org (2022). https://arxiv.org/abs/1606.08415. Accessed 15 Sept 2022
25. Vaswani, A., et al.: Attention is all you need. In: Proceedings of the 31st International Conference on Neural Information Processing Systems (NIPS 2017), pp. 6000–6010. Curran Associates Inc., Red Hook (2017)
26. Zouhir, Y., Ouni, K.: Feature extraction method for improving speech recognition in noisy environments. J. Comput. Sci. 12, 56–61 (2016). https://doi.org/10.3844/jcssp.2016.56.61
27. Valero, X., Alias, F.: Gammatone cepstral coefficients: biologically inspired features for non-speech audio classification. IEEE Trans. Multimedia 14(6), 1684–1689 (2012). https://doi.org/10.1109/TMM.2012.2199972
28. Zhou, X., et al.: Linear versus mel frequency cepstral coefficients for speaker recognition. In: 2011 IEEE Workshop on Automatic Speech Recognition & Understanding, pp. 559–564 (2011). https://doi.org/10.1109/ASRU.2011.6163888
29. Kumar, C., et al.: Analysis of MFCC and BFCC in a speaker identification system. In: 2018 International Conference on Computing, Mathematics and Engineering Technologies (2018)
30. Dosovitskiy, A., Beyer, L., et al.: An image is worth 16x16 words: transformers for image recognition at scale. In: International Conference on Learning Representations (2021)
31. Xu, W., Xu, Y., Chang, T., Tu, Z.: Co-scale conv-attentional image transformers. In: International Conference on Computer Vision, pp. 9961–9970 (2021)
32. d'Ascoli, S., Touvron, H., et al.: ConViT: improving vision transformers with soft convolutional inductive biases. In: International Conference on Machine Learning, vol. 139, pp. 2286–2296 (2021)
33. Touvron, H., Cord, M., Jégou, H.: DeiT III: revenge of the ViT. In: Avidan, S., Brostow, G., Cissé, M., Farinella, G.M., Hassner, T. (eds.) ECCV 2022, pp. 516–533. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-20053-3_30
34. Hedegaard, L., Bakhtiarnia, A., Iosifidis, A.: Continual Transformers: Redundancy-Free Attention for Online Inference. arXiv.org (2022). https://arxiv.org/abs/2201.06268
35. Liu, C., Feng, L., Liu, G., Wang, H., Liu, S.: Bottom-up Broadcast Neural Network for Music Genre Classification. arXiv.org (2022). https://arxiv.org/abs/1901.08928
36. Heakl, A., Abdelgawad, A., Parque, V.: A study on broadcast networks for music genre classification. In: IEEE International Joint Conference on Neural Network, pp. 1–8 (2022)
37. Bahmei, B., et al.: CNN-RNN and data augmentation using deep convolutional generative adversarial network for environmental sound classification. IEEE Signal Process. Lett. 29, 682–686 (2022)
38. Song, H., Deng, S., Han, J.: Exploring inter-node relations in CNNs for environmental sound classification. IEEE Signal Process. Lett. 29, 154–158 (2022)
39. Chen, Y., Zhu, Y., Yan, Z., Chen, L.: Effective Audio Classification Network Based on Paired Inverse Pyramid Structure and Dense MLP Block (2022)
40. Wightman, R.: PyTorch Image Models (2019). https://github.com/rwightman/pytorch-image-models


41. Fonseca, E., et al.: Audio tagging with noisy labels and minimal supervision. In: Proceedings of DCASE2019 Workshop, NYC, US (2019)
42. Woo, S., et al.: ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders. arXiv abs/2301.00808 (2023)
43. Gong, Y., Chung, Y.-A., Glass, J.R.: AST: Audio spectrogram transformer. In: Interspeech (2021)
44. Chen, Y., et al.: Data Augmentation for Environmental Sound Classification Using Diffusion Probabilistic Model with Top-K Selection Discriminator. arXiv:2023.15161 (2023)

Dynamic Attention Filter Capsule Network for Medical Images Segmentation

Ran Chen1, Kai Hu1, and Zhong-Qiu Zhao1,2,3,4(B)

1 College of Computer and Information, Hefei University of Technology, Hefei, China


2 Intelligent Manufacturing Institute of HFUT, Hefei, China
3 Intelligent Interconnected Systems Laboratory of Anhui Province, Hefei University of Technology, Hefei, China
4 Guangxi Academy of Sciences, Guangxi, China

Abstract. Existing work has shown Capsule Networks (CapsNets) to be promising alternatives to Convolutional Neural Networks (CNNs). However, CapsNets perform poorly on complex datasets with RGB backgrounds and cannot handle images with large input sizes. We propose a Dynamic Attention Filter (DAF) method to improve the performance of CapsNets. DAF is a filter unit between the low-level capsule layers and the voting layers. It effectively filters out invalid background capsules and improves the classification performance of CapsNets. In addition, we propose DAF-CapsUNet, which combines the advantages of the U-shaped encoder-decoder structure and DAF-CapsNet to handle large medical images. It consists of three fundamental components. An encoder extracts shallow feature information. A capsule module captures the detailed feature information that the pooling layers of CNNs lose. A decoder fuses the feature information extracted in these two stages. Extensive experiments demonstrate that DAF improves the performance of CapsNets on complex datasets while reducing their number of parameters, GPU memory cost, and running time. Benefiting from DAF-CapsNet, our model captures more useful information for precise localization. Experiments on medical datasets show that DAF-CapsUNet achieves state-of-the-art (SOTA) performance compared with other segmentation models.
Keywords: CapsNets · DAF · encoder-decoder · DAF-CapsUNet

1 Introduction
In recent years, CNNs have achieved great success in computer vision tasks. Despite this success, capturing detailed features is usually difficult because of the pooling layers in CNNs. For example, a max-pooling layer keeps only representative local features and loses the spatial relationships between features. Hence, CNNs must be trained with large amounts of data to accurately identify images from different viewpoints. However, annotation is often expensive and requires considerable expertise and time. Annotating medical images requires not only specialized knowledge and experienced doctors but also slicing 3D images into 2D images and annotating them one by one.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 85–97, 2023.
https://doi.org/10.1007/978-981-99-4742-3_7

R. Chen et al.

Fig. 1. Comparison of Original CapsNet and DAF-CapsNet.

CapsNets [1] were developed to overcome the drawbacks of CNNs mentioned above. CapsNets use capsules instead of individual neurons to represent image instances. A capsule is a group of neurons that represents an object or a part of an object. It is a vector that carries much richer information, such as texture, direction, and rotation. In addition, the routing-by-agreement of CapsNets can find similarities between capsules, which allows part-to-whole relationships to be modeled. CapsNets perform well in image classification [1–4] and have shown outstanding performance in various segmentation tasks [5–10]. However, CapsNets classify poorly on complex datasets with RGB backgrounds and cannot handle images with large input sizes. As illustrated in Fig. 1, the original routing treats background information as valid capsules. These capsules take part in the routing process and ultimately produce a wrong result. Such redundant capsules interfere with the final recognition result and increase the computation cost and time complexity.

To improve the classification performance of CapsNets on complex datasets, we propose a filtering method, named Dynamic Attention Filter (DAF), that removes redundant capsules between capsule layers. As Fig. 1 shows, DAF improves the classification performance of CapsNets by removing the invalid background voting capsules. Previous work [6] addresses this issue with a mean-voting procedure in the capsule layers. However, that method loses some useful information because it uses a single mean capsule to represent all capsules in the receptive field. We instead use two dynamic attention units to remove the redundant capsules (see Fig. 2).

To enable CapsNets to process medical images with large input sizes, we combine the U-shaped encoder-decoder structure [11] with DAF-CapsNet and propose a capsule U-shaped network, namely DAF-CapsUNet. As shown in Fig. 3, it contains three critical components:
• an encoder that extracts low-dimensional feature information from the input image and reduces the high-resolution image to a size suitable for the capsule module;
• a capsule module that extracts more detailed features, including texture and direction;
• a decoder that upsamples the outputs of the capsule layers and combines them with the high-resolution CNN features through skip-connections.

In summary, DAF removes redundant capsules, which allows CapsNets to process images with larger input sizes, and the encoder further reduces the size of the input images. Therefore, DAF-CapsUNet can handle medical images with higher resolution.

This paper makes the following contributions:
• We propose DAF, a method that removes invalid voting capsules in capsule layers and reduces time complexity and GPU memory cost.
• We present DAF-CapsUNet, a model for medical image segmentation that uses the capsule module to further improve feature representation, leading to more precise segmentation results.
• We demonstrate the effectiveness of DAF, and extensive experiments show the classification and segmentation performance of our model.

2 Related Work
2.1 Capsule Network
CapsNets were proposed to overcome weaknesses of CNNs and RNNs [12]. They learn part-whole relationships via transformation matrices, and a dynamic routing procedure [1] was proposed to capture these relationships. The probability that an entity exists is represented either by the length of the capsule [1] or by an activation probability [2, 3]. Compared with a traditional neuron, a capsule not only contains different properties of the same entity but also encodes the spatial relationship between low-level capsules (e.g., eyes, mouth) and high-level capsules (e.g., face). In the routing procedure, each low-level capsule (part) predicts the pose of each high-level capsule (whole) through a voting procedure: the poses of the low-level capsules are multiplied by transformation matrices to generate the voting matrices. Over several iterations, the routing process passes most of the voting capsules to the next layer. In EM routing [2], the EM algorithm is performed between two adjacent capsule layers to update the voting matrices, which act as the weights of the parts with respect to the whole. Self-Routing [13] replaces routing-by-agreement with a self-routing mechanism and uses a Mixture-of-Experts (MoE)-like method that improves robustness to adversarial examples. IDAR [14] proposes a novel routing algorithm that departs from routing-by-agreement. It makes the routing probability depend directly on the agreement between the parent's pose (from the previous iteration step) and the child's vote for the parent's pose (in the current iteration step).

2.2 Medical Image Segmentation
CNNs have become the most effective method for medical image segmentation because of their powerful image-learning ability.
Since the Fully Convolutional Network (FCN) [15] showed remarkable performance on semantic segmentation, U-Net [11] built on FCN and proposed an encoder-decoder architecture with skip-connections for medical image segmentation. Many models are inspired by the encoder-decoder architecture, such as V-Net [16], M-Net [17], H-DenseUNet [18], SegCaps [5], nnU-Net [19], MPUnet [20], MSU-Net [21], HPN [22], TransUNet [23], and Swin-Unet [24]. U-Net-style models can learn more contextual information by taking the whole image as input.

88

R. Chen et al.

Additionally, the skip-connections fuse low-level and high-level features to obtain global information. Other medical image segmentation methods, such as Hd-Net [25], Phiseg [26], Ce-Net [27], and SS-Net [28], often add a feature pyramid network (FPN) [29] at the end of the backbone network. In general, these models use a classic pretrained model, such as ResNet-101 [30] or VGG [31], as the backbone and then use skip connections and different FPNs to complete the segmentation tasks. In addition, most models use data augmentation during training because medical image datasets are usually small.

Fig. 2. The overall architecture of DAF. $C_L$ and $C_{L+1}$ are the numbers of capsules. The dark and light blue rectangles represent valid and redundant capsules, respectively.

3 Proposed Method
3.1 DAF-CapsNet
To reduce the computation cost of CapsNets, we propose a filter method, namely DAF (see Fig. 2). We denote the input to capsule layer L as $u^l \in \mathbb{R}^{b \times C_L \times d \times L_w \times L_h}$, where $b$, $C_L$, $d$, $L_w$, $L_h$, and $l$ index the batch size, the number of capsules, the capsule dimension axis, the spatial width axis, the spatial height axis, and the l-th capsule layer, respectively. First, we apply a convolutional transform to each capsule channel to generate various poses for the low-level capsules. These capsule poses include redundant capsules, also called background capsules. As shown in Fig. 1, the redundant capsules correspond to background features in the input image, such as older adults, windows, and cabinets. The redundant capsules in layer L participate in the voting process and interfere with the recognition result. Second, we use two attention units to filter the transformed capsules. During training, the two attention units set the weights of the background capsules to 0, which prevents them from interfering with the routing process. The filtering process is as follows:

$$p_l^{l+1} = \mathrm{Norm}\left(\mathrm{Att}\left(p^l, A_p^l\right)\right) \tag{1}$$

$$a_l^{l+1} = \mathrm{Norm}\left(\mathrm{Att}\left(a^l, A_a^l\right)\right) \tag{2}$$


where $p^l$ and $a^l$ are the pose matrices and activations of all capsules, respectively. Att uses a weight matrix to dynamically remove invalid voting capsules, and $A_p^l$ and $A_a^l$ are input-independent learnable weight matrices. Norm normalizes the poses of capsules in the same capsule channel. Different capsules represent different object parts and learn different feature distributions, so each capsule channel is processed separately. The filtered capsules are the voting capsules. Finally, we use the routing process to activate the capsules in the next layer L + 1. The human visual system offers an analogy: when observing a target object, people move from the parts to the whole. Similarly, the routing process constructs hierarchical part-to-whole spatial relationships between the voting capsules and the high-level capsules. The routing process is as follows:

$$u^{l+1} = \sigma\left(\mathrm{Linear}\left(\mathrm{GMM}\left(\mathrm{Cat}\left(p_l^{l+1}, a_l^{l+1}\right)\right), W_b\right)\right) \tag{3}$$

where Cat concatenates $p_l^{l+1}$ and $a_l^{l+1}$. GMM is a Gaussian Mixture Model, which ensures that most of the voting capsules agree and activate the same high-level capsule in the next layer. Linear is a fully connected layer, and σ is the Softmax function. $W_b$ is a weight matrix that is initialized to 0 and continuously updated during training to select the valid voting capsules. $u^{l+1}$ is the output capsule, and its activation value represents the probability that the capsule exists, i.e., the probability of belonging to the corresponding category. In the original capsule layers, each capsule in layer L within the receptive field ($K_X \times K_Y$) receives feedback from the corresponding capsule in layer L + 1. Each capsule in layer L therefore casts $C_{L+1}$ votes, where $C_{L+1}$ is the number of capsule types in layer L + 1. This gives $K_X \times K_Y \times C_{L+1} \times C_L$ votes for the routing procedure of the capsules in layer L + 1. Taking $K_X = K_Y = K$, DAF reduces the number of votes by a factor of K, to $K \times C_{L+1} \times C_L$, far fewer than the original. By reducing the number of voting capsules, DAF therefore reduces the number of parameters, running time, and GPU memory cost of CapsNets.
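The vote arithmetic above can be made concrete with a small plain-Python sketch. The text does not state how Att decides which capsules to drop. The top-K rule over the K × K receptive-field positions used here is therefore a hypothetical stand-in, not the authors' filter. It was chosen only because keeping K of the K² positions reproduces the stated K-fold reduction, from $K_X \times K_Y \times C_{L+1} \times C_L$ (with $K_X = K_Y = K$) to $K \times C_{L+1} \times C_L$ votes. The layer sizes and scores are illustrative.

```python
def original_votes(k, c_in, c_out):
    """Votes with a full K x K receptive field: K * K * C_L * C_{L+1}."""
    return k * k * c_in * c_out


def daf_votes(k, c_in, c_out):
    """Votes after DAF as stated in the text: K * C_L * C_{L+1}."""
    return k * c_in * c_out


def topk_position_mask(scores, k):
    """Hypothetical stand-in for Att: keep the k highest-scoring receptive-field
    positions (mask 1.0) and zero out the rest (mask 0.0)."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(ranked[:k])
    return [1.0 if i in keep else 0.0 for i in range(len(scores))]


def votes_after_mask(mask, c_in, c_out):
    """Each surviving position holds C_L capsule types, each voting for C_{L+1} parents."""
    return int(sum(mask)) * c_in * c_out


if __name__ == "__main__":
    K, C_L, C_NEXT = 3, 32, 32                                # hypothetical layer sizes
    scores = [0.9, 0.1, 0.4, 0.8, 0.05, 0.3, 0.7, 0.2, 0.6]   # one score per K*K position
    mask = topk_position_mask(scores, K)
    print(original_votes(K, C_L, C_NEXT),
          votes_after_mask(mask, C_L, C_NEXT),
          daf_votes(K, C_L, C_NEXT))
```

This prints `9216 3072 3072`. The masked count matches the closed-form DAF count, and the ratio to the original is K = 3.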

3.2 DAF-CapsUNet Architecture
The overall architecture of the proposed DAF-CapsUNet is shown in Fig. 3. It consists of an encoder, a capsule module, and a decoder. Most common methods train a CNN as an encoder to obtain feature information. Unlike existing models, our model introduces CapsNets into the encoder to obtain finer high-dimensional feature information. Inspired by TransUNet [23], we add a capsule module at the tail of the encoder, so feature extraction is divided into a coarse-grained and a fine-grained stage. First, the encoder extracts coarse-grained features, and the capsule module extracts fine-grained features via DAF-CapsNet. The fine-grained features include the texture, rotation, and flip of the object. Then, the decoder combines the coarse-grained and fine-grained features through upsampling and skip connections. Finally, the segmentation head turns the fused features into segmentation maps.


Fig. 3. The overview of DAF-CapsUNet architecture.

3.3 Objective Function
Class imbalance often occurs in medical images. Training is then dominated by the classes that cover more pixels, and it is difficult for smaller objects to learn their features. Hence, we combine the cross-entropy loss and the generalized Dice loss (GDL) [32] to compute the loss of our model. The cross-entropy loss is

$$L_{ce} = -\frac{1}{N}\sum_{i=1}^{N}\sum_{j=1}^{m} y_{ij}\log \hat{y}_{ij} \tag{4}$$

where N is the number of pixels and m is the number of categories in the medical image dataset. $L_{ce}$ examines each pixel and compares the class prediction $\hat{y}_{ij}$ with the target value $y_{ij}$. Here $y_{ij}$ and $\hat{y}_{ij}$ are the ground truth and the prediction for class j at the i-th pixel of the feature map. The weight of class j is

$$w_j = \frac{1}{\left(\sum_{i=1}^{N} y_{ij}\right)^2} \tag{5}$$

Subsequently, we obtain:

$$L_{dice} = 1 - 2\,\frac{\sum_{j=1}^{m} w_j \sum_{i=1}^{N} y_{ij}\hat{y}_{ij}}{\sum_{j=1}^{m} w_j \sum_{i=1}^{N}\left(y_{ij} + \hat{y}_{ij}\right)} \tag{6}$$

where m is the number of classes. $L_{dice}$ measures the similarity between the predicted values $\hat{y}_{ij}$ and the ground truth $y_{ij}$, treating them as two sample collections. The final segmentation loss is a weighted sum of $L_{ce}$ and $L_{dice}$:

$$L_{se} = \lambda L_{ce} + (1 - \lambda) L_{dice} \tag{7}$$

where λ is initialized to 0. The network then gradually learns to assign a higher weight to $L_{ce}$.
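Eqs. (4)–(7) translate directly into code. The sketch below uses plain Python lists of per-pixel class probabilities so that each sum is visible, and it assumes one-hot ground truth. It adds a small eps for numerical safety and gives zero weight to classes absent from the ground truth. The text does not specify that case, so this handling is an assumption. The text also does not describe how λ is learned, so here it is simply passed in as a number. A practical implementation would vectorize the same computation over tensors.

```python
import math


def cross_entropy(y, y_hat, eps=1e-12):
    """Eq. (4): y and y_hat are N x m lists (pixels x classes); y is one-hot."""
    n = len(y)
    total = 0.0
    for y_i, p_i in zip(y, y_hat):
        for y_ij, p_ij in zip(y_i, p_i):
            if y_ij:                                   # 0 * log(p) contributes nothing
                total += y_ij * math.log(max(p_ij, eps))
    return -total / n


def generalized_dice_loss(y, y_hat, eps=1e-12):
    """Eqs. (5)-(6) with w_j = 1 / (sum_i y_ij)^2.

    Classes absent from the ground truth get weight 0 (an assumption).
    """
    m = len(y[0])
    num = den = 0.0
    for j in range(m):
        ref = sum(y_i[j] for y_i in y)
        w = 1.0 / ref ** 2 if ref > 0 else 0.0
        num += w * sum(y_i[j] * p_i[j] for y_i, p_i in zip(y, y_hat))
        den += w * sum(y_i[j] + p_i[j] for y_i, p_i in zip(y, y_hat))
    return 1.0 - 2.0 * num / (den + eps)


def segmentation_loss(y, y_hat, lam):
    """Eq. (7): L_se = lam * L_ce + (1 - lam) * L_dice."""
    return lam * cross_entropy(y, y_hat) + (1.0 - lam) * generalized_dice_loss(y, y_hat)


if __name__ == "__main__":
    y = [[1, 0], [0, 1]]                 # two pixels, two classes (one-hot)
    y_hat = [[0.5, 0.5], [0.5, 0.5]]     # uninformative prediction
    print(cross_entropy(y, y_hat), generalized_dice_loss(y, y_hat),
          segmentation_loss(y, y_hat, 0.0))
```

For the uninformative prediction, the cross-entropy is log 2 and the generalized Dice loss is about 0.5. A perfect prediction drives both terms to (numerically) zero.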


4 Experiments

4.1 Datasets and Training Details
We evaluate our method on four classification datasets: CIFAR-10, SVHN, SmallNORB, and Fashion-MNIST (F-mnist). We augment these datasets with random affine transformations from the torchvision transforms package. For each dataset, we split the training set into a training subset (90%) and a validation subset (10%). We further evaluate the proposed model on two medical segmentation datasets: Synapse multi-organ CT (Synapse) and the Automated Cardiac Diagnosis Challenge (ACDC). Synapse contains 30 abdominal CT scans, which we divide into 18 training and 12 testing samples [23, 24]. ACDC was collected from 100 patients using two MRI scanners, and we divide it into 70 training, 10 validation, and 20 testing samples.

Our best model is trained for 300 epochs with an SGD optimizer (momentum 0.9, weight decay 1e-4). The initial learning rate is 6e-3, with a decay rate of 2e-7. Besides the number of parameters (P) and the GPU memory cost (GPU), we report the running time (R) to measure computation cost. For a fair comparison, all three are measured on the same GPU, and every experiment runs on a single GeForce RTX 2080Ti with 11 GB memory. For the medical experiments, we evaluate Synapse with the average Dice similarity coefficient (DSC) and the average Hausdorff distance (HD). We evaluate ACDC with DSC only.

Table 1. Performance comparison with other CapsNets. P(M): number of parameters (millions); R(S): running time (seconds); GPU(G): GPU memory cost (GB); dataset columns: classification accuracy (%).

Model              P(M)    R(S)    GPU(G)  CIFAR-10  SVHN   SmallNORB  F-mnist
DCNet [33]         11.8    361.5   3.29    82.63     95.58  94.43      94.64
SR-Routing [13]    3.2     96.8    11.0    92.14     96.88  -          97.03
VB-Routing [34]    0.1     309.7   3.80    87.60     95.82  98.46      94.80
DeepCaps [35]      13.43   620.7   3.72    92.74     97.56  -          94.23
EM (baseline) [2]  14.36   26.8    7.46    87.47     86.74  90.30      89.76
ME [6]             0.19    5.4     2.16    88.10     84.24  82.04      88.74
DAF-CapsNet        0.23    10.5    3.03    95.27     96.01  92.67      94.74
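Section 4.1 evaluates segmentation with the Dice similarity coefficient (DSC) and the Hausdorff distance (HD). The sketch below computes both metrics under simplifying assumptions. It uses one binary 2-D mask per organ, stores each mask as a set of pixel coordinates, and measures Euclidean distance in pixel units. The source does not state which HD variant is used. Some benchmarks report a 95th-percentile HD rather than the maximum computed here. The helper names are hypothetical.

```python
import math

def mask_to_points(mask):
    """Convert a binary 2-D mask (list of lists) into a set of (row, col) foreground pixels."""
    return {(r, c) for r, row in enumerate(mask) for c, v in enumerate(row) if v}

def dice_coefficient(pred, gt):
    """DSC = 2|A ∩ B| / (|A| + |B|); defined as 1.0 when both masks are empty."""
    if not pred and not gt:
        return 1.0
    return 2.0 * len(pred & gt) / (len(pred) + len(gt))

def hausdorff_distance(pred, gt):
    """Symmetric (maximum) Hausdorff distance between two non-empty point sets, in pixels."""
    def directed(a, b):
        # Largest distance from a point in a to its nearest neighbour in b.
        return max(min(math.dist(p, q) for q in b) for p in a)
    return max(directed(pred, gt), directed(gt, pred))

pred = mask_to_points([[0, 1, 1]])
gt = mask_to_points([[1, 1, 0]])
print(dice_coefficient(pred, gt), hausdorff_distance(pred, gt))  # 0.5 1.0
```

DSC rewards overlap and is insensitive to where the errors lie. HD is governed by the single worst boundary point, so the two metrics complement each other.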

4.2 Comparison with Other CapsNets
We first compare our method with other CapsNets on four datasets: CIFAR-10, SVHN, SmallNORB, and Fashion-MNIST (F-mnist). As Table 1 shows, DAF-CapsNet achieves higher accuracy than the baseline EM model on all four datasets while using fewer parameters and less GPU memory. Although our running time and GPU memory cost are higher than those of Mean-Routing (ME), our model achieves better


R. Chen et al.

classification performance. This result suggests that ME loses some useful features and that redundant capsules degrade classification performance. Compared with the remaining CapsNets, DAF-CapsNet achieves competitive classification performance with fewer parameters, lower GPU memory cost, and shorter running time. Overall, Table 1 shows that DAF can improve the classification performance of CapsNets while reducing their number of parameters, GPU memory cost, and running time.

Fig. 4. Comparison with different CapsNets. Blue lines and gray rectangles indicate models with DAF added.

4.3 Ablation Study on DAF
To demonstrate the effectiveness of DAF, we apply it to other CapsNets, such as EM, DR-2 [1, 13], EM-2 [2, 13], SR-2 [13], and IDAR [14]. We then evaluate their running time, classification performance, and GPU memory cost. As shown in Fig. 4, DAF reduces the running time and GPU memory cost of these CapsNets. Moreover, DAF improves the classification performance of the baseline EM models on the CIFAR-10 and SVHN datasets (see Fig. 4). For the remaining CapsNets, DAF mainly improves performance on CIFAR-10, with only a slight improvement on SVHN. These results suggest that DAF improves model performance by removing redundant capsules.

4.4 Complexity Analysis of CapsNet
Routing-by-agreement in CapsNets involves a large number of matrix operations, which increase GPU memory cost and time complexity. For example, the baseline EM routing model relies on matrix multiplication. Between adjacent capsule layers, EM routing multiplies the pose ($p$) by the weight parameter ($w$) to generate the voting matrix ($v$). Because EM routing is fully connected, its time complexity is $O(b s C_L K^2 C_{L+1} d)$. SR-Routing adopts a mixture-of-experts (MoE) method, and its time complexity is $O(b s C_L K^2 d)$, where $s = h_{L+1} w_{L+1}$. Our model reduces the number of capsules using DAF, so the time complexity of DAF-CapsNet is $O(b s C_L K C_{L+1} d)$. The more complex the matrix operations, the longer


the running time, so we use the running time as a practical measure of a model's time complexity. Compared with the baseline EM model, DAF-CapsNet requires less running time and GPU memory (see Table 1). Although our model has the same time complexity as SR-Routing, its running time, number of parameters, and GPU memory cost are much lower than those of SR-Routing. Figure 4 also shows that DAF reduces the running time and GPU memory cost of other CapsNets.
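To make the asymptotic bounds in Sect. 4.4 concrete, the snippet below evaluates the three operation counts for one set of layer sizes. The meaning of the symbols is an assumption, since this excerpt does not define them:

- $b$ is the batch size.
- $s = h_{L+1} w_{L+1}$ is the number of output positions.
- $C_L$ and $C_{L+1}$ are the numbers of capsule types in layers $L$ and $L+1$.
- $K$ is the kernel size.
- $d$ is the pose size.

The example sizes are arbitrary, and the counts are not measured running times. Function names are hypothetical.

```python
def em_cost(b, s, c_l, k, c_next, d):
    """O(b s C_L K^2 C_{L+1} d): fully connected EM routing between capsule layers."""
    return b * s * c_l * k ** 2 * c_next * d

def sr_cost(b, s, c_l, k, d):
    """O(b s C_L K^2 d): SR-Routing."""
    return b * s * c_l * k ** 2 * d

def daf_cost(b, s, c_l, k, c_next, d):
    """O(b s C_L K C_{L+1} d): DAF-CapsNet after filtering capsules."""
    return b * s * c_l * k * c_next * d

# Arbitrary example sizes (assumed): batch 8, 8x8 output grid,
# 32 -> 32 capsule types, 3x3 kernel, 4x4 poses (d = 16).
sizes = dict(b=8, s=8 * 8, c_l=32, k=3, c_next=32, d=16)
em = em_cost(**sizes)
daf = daf_cost(**sizes)
sr = sr_cost(sizes["b"], sizes["s"], sizes["c_l"], sizes["k"], sizes["d"])
print(em // daf, em // sr)  # 3 32
```

Under these bounds, the EM-to-DAF ratio is always $K$ and the EM-to-SR ratio is always $C_{L+1}$, independent of $b$, $s$, $C_L$, and $d$.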

Fig. 5. Success rates (%) with FGSM attacks against different routing methods of CapsNets.

4.5 Robustness to Adversarial Examples
Neural network models cannot handle adversarial examples effectively, so researchers aim to improve their robustness to such inputs. An adversarial example is an input with a small, deliberately crafted perturbation (rather than random noise) added to trick the network into a wrong classification. In the EM routing experiments [2], CapsNets proved more resistant to these attacks than CNNs. We therefore use FGSM [36] to evaluate the robustness of our model. As shown in Fig. 5, DAF achieves the best performance under both untargeted and targeted attacks. This result suggests that redundant capsules may weaken the robustness of EM routing and that ME loses some finer features.

4.6 Experiment Results on Synapse Dataset
We compare our model with previous models on Synapse. As Table 2 shows, our method achieves the best performance, with an average DSC of 88.43%. Figure 6 visualizes the segmentation results of our method and five other models: nnFormer [39], Swin-Unet, TransUNet, SegCaps, and EM-CapsUNet. The first row shows that our model predicts more true positives. The second row shows that our method segments organs well in an image containing a single organ. The third row shows that our model suppresses over-segmentation of organs more effectively. Unlike SegCaps, which uses capsules to replace convolutions, we use the capsule as a feature-extraction module. As Table 2 shows, SegCaps performs poorly in segmenting multiple organs.

4.7 Experiment Results on ACDC Dataset
We use the ACDC dataset to further evaluate segmentation performance. The results are summarized in Table 3. Our model achieves the best performance, with an average DSC of 94.23%, which indicates good generalization ability.
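FGSM [36] perturbs an input $x$ by a step of size $\epsilon$ along the sign of the loss gradient, $x' = x + \epsilon \cdot \mathrm{sign}(\nabla_x J(\theta, x, y))$. A targeted variant steps against the gradient of the loss for a chosen target label. The sketch below applies this to a toy logistic-regression classifier. That model is used only because its input gradient has a closed form, $(\sigma(w \cdot x + b) - y)\,w$. It is not the CapsNet setup of Fig. 5, and it omits clipping to the valid pixel range. Function names are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def sign(v):
    return (v > 0) - (v < 0)

def input_gradient(x, y, w, b):
    """Gradient of binary cross-entropy w.r.t. x for a logistic model: (sigmoid(w.x + b) - y) * w."""
    p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
    return [(p - y) * wi for wi in w]

def fgsm(x, y, w, b, eps, targeted=False):
    """Untargeted: increase the loss of the true label y.
    Targeted: decrease the loss of the target label y."""
    g = input_gradient(x, y, w, b)
    step = -eps if targeted else eps
    return [xi + step * sign(gi) for xi, gi in zip(x, g)]

def score(v):
    """Probability of class 1 under the toy model defined below."""
    return sigmoid(sum(wi * vi for wi, vi in zip(w, v)) + b)

w, b = [1.0, -2.0], 0.0
x, y = [0.5, 0.1], 1
x_adv = fgsm(x, y, w, b, eps=0.1)
print(x_adv, score(x), score(x_adv))  # the true-class probability drops
```

Because only the sign of the gradient is used, every input coordinate moves by exactly $\epsilon$. That makes the perturbation size easy to control.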


Table 2. Segmentation performance comparison with different networks on the Synapse dataset.

Model              DSC    HD     Aorta  Gallbladder  Kidney(L)  Kidney(R)  Liver  Pancreas  Spleen  Stomach
V-Net [16]         68.81  -      75.34  51.87        80.75      80.75      87.84  40.05     80.56   56.98
DARR [37]          69.97  -      74.74  53.77        73.24      73.24      94.08  54.18     89.90   45.96
R50 U-Net [11]     74.66  36.87  87.74  63.66        80.60      78.19      93.74  56.90     85.87   74.16
U-Net [11]         76.85  39.70  89.07  69.72        77.77      68.60      93.43  53.98     86.67   75.58
R50 Att-UNet [38]  75.57  36.97  55.92  63.91        79.20      72.71      93.56  49.37     87.19   74.95
Att-UNet [38]      77.77  36.02  89.56  68.88        77.98      71.11      93.57  58.04     87.30   75.75
TransUNet [23]     77.48  31.69  87.83  63.13        81.87      77.02      94.08  55.86     85.08   75.62
Swin-Unet [24]     79.13  21.55  85.47  66.53        83.28      79.61      94.29  56.58     90.66   76.60
nnFormer [39]      86.57  10.53  92.04  70.17        86.57      86.25      96.84  83.35     90.51   86.63
SegCaps [5]        22.90  138.2  32.12  26.35        31.52      33.50      61.94  4.02      9.72    4.00
EM-CapsUNet [2]    62.59  80.15  81.74  49.31        68.14      64.15      89.20  28.58     74.45   45.18
DAF-CapsUNet       88.43  10.5   95.87  70.09        88.07      88.23      97.84  85.27     94.74   89.43

Fig. 6. Comparison with different approaches by visualization.

Table 3. Segmentation accuracy (%) of different methods on the ACDC dataset.

Model              DSC    RV     Myo    LV
R50 U-Net [11]     87.55  87.10  80.63  94.92
R50 Att-UNet [38]  86.75  87.58  79.20  93.47
TransUNet [23]     89.71  88.86  84.53  95.73
Swin-Unet [24]     90.00  88.55  85.62  95.83
nnFormer [39]      92.06  90.94  89.58  95.65
DAF-CapsUNet       94.23  92.47  91.87  98.62


5 Conclusion
In this paper, we propose the dynamic attention filter (DAF) to improve the performance of CapsNets. DAF also reduces the number of parameters and the GPU memory cost of CapsNets. Extensive experimental results show that DAF increases classification performance while reducing the time complexity and computation cost of CapsNets. We further combine DAF-CapsNet with a U-shaped encoder-decoder structure to obtain DAF-CapsUNet. Evaluations on the Synapse and ACDC datasets show that DAF-CapsUNet outperforms other state-of-the-art methods. Compared with CNNs, CapsNets can achieve competitive performance with fewer parameters. In the future, we will further reduce the number of parameters and the computation cost of CapsNets. We also plan to scale CapsNets up to much larger datasets such as ImageNet and to apply them to other vision tasks.

Acknowledgment. This work was supported in part by the National Natural Science Foundation of China under Grant 61976079, in part by Guangxi Key Research and Development Program under Grant AB22035022, and in part by Anhui Key Research and Development Program under Grant 202004a05020039.

References
1. Sabour, S., Frosst, N., Hinton, G.E.: Dynamic routing between capsules. In: Advances in NeurIPS (2017)
2. Hinton, G.E., Sabour, S., Frosst, N.: Matrix capsules with EM routing. In: ICLR (2018)
3. Kosiorek, A., Sabour, S., Teh, Y.W., Hinton, G.E.: Stacked capsule autoencoders. In: Advances in NeurIPS (2019)
4. Mazzia, V., Salvetti, F., Chiaberge, M.: Efficient-capsnet: capsule network with self-attention routing. Sci. Rep. 1–13 (2021)
5. LaLonde, R., Bagci, U.: Capsules for object segmentation. arXiv preprint arXiv:1804.04241 (2018)
6. Duarte, K., Rawat, Y., Shah, M.: Videocapsulenet: a simplified network for action detection. In: Advances in NeurIPS (2018)
7. Mobiny, A., Yuan, P., Cicalese, P.A., Van Nguyen, H.: DECAPS: detail-oriented capsule networks. In: Martel, A.L. (ed.) MICCAI 2020. LNCS, vol. 12261, pp. 148–158. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59710-8_15
8. McIntosh, B., Duarte, K., Rawat, Y.S., Shah, M.: Visual-textual capsule routing for text-based video segmentation. In: CVPR, pp. 9942–9951 (2020)
9. Duarte, K., Rawat, Y.S., Shah, M.: Capsulevos: semi-supervised video object segmentation using capsule routing. In: ICCV, pp. 8480–8489 (2019)
10. Afshar, P., Naderkhani, F., Oikonomou, A., Rafiee, M.J., Mohammadi, A., Plataniotis, K.N.: Mixcaps: a capsule network-based mixture of experts for lung nodule malignancy prediction. Pattern Recognit. 116, 107942 (2021)
11. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28


12. Zaremba, W., Sutskever, I., Vinyals, O.: Recurrent neural network regularization. arXiv preprint arXiv:1409.2329 (2014)
13. Hahn, T., Pyeon, M., Kim, G.: Self-routing capsule networks. In: Advances in NeurIPS, vol. 32 (2019)
14. Tsai, Y.H.H., Srivastava, N., Goh, H., Salakhutdinov, R.: Capsules with inverted dot-product attention routing. In: ICLR (2020)
15. Long, J., Shelhamer, E., Darrell, T.: Fully convolutional networks for semantic segmentation. In: CVPR, pp. 3431–3440 (2015)
16. Milletari, F., Navab, N., Ahmadi, S.A.: V-net: fully convolutional neural networks for volumetric medical image segmentation. In: 3DV, pp. 565–571. IEEE (2016)
17. Mehta, R., Sivaswamy, J.: M-net: a convolutional neural network for deep brain structure segmentation. In: ISBI, pp. 437–440. IEEE (2017)
18. Li, X., Chen, H., Qi, X., Dou, Q., Fu, C.W., Heng, P.A.: H-DenseUNet: hybrid densely connected UNet for liver and tumor segmentation from CT volumes. IEEE Trans. Med. Imaging 37, 2663–2674 (2018)
19. Isensee, F., Petersen, J., Klein, A., Zimmerer, D., et al.: nnu-net: Self-adapting framework for u-net-based medical image segmentation. arXiv preprint arXiv:1809.10486 (2018)
20. Perslev, M., Dam, E.B., Pai, A., Igel, C.: One network to segment them all: a general, lightweight system for accurate 3D medical image segmentation. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 30–38. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8_4
21. Wang, T., et al.: MSU-Net: multiscale statistical U-Net for real-time 3D cardiac MRI video segmentation. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 614–622. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8_68
22. Zhou, Y., et al.: Hyper-pairing network for multi-phase pancreatic ductal adenocarcinoma segmentation. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 155–163. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8_18
23. Chen, J., Lu, Y., Yu, Q., Luo, X., Adeli, E., et al.: Transunet: transformers make strong encoders for medical image segmentation. arXiv preprint arXiv:2102.04306 (2021)
24. Cao, H., Wang, Y., Chen, J., Jiang, D., Zhang, X., Tian, Q., et al.: Swin-unet: Unet-like pure transformer for medical image segmentation. arXiv preprint arXiv:2105.05537 (2021)
25. Jia, H., Song, Y., Huang, H., Cai, W., Xia, Y.: HD-Net: hybrid discriminative network for prostate segmentation in MR images. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 110–118. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8_13
26. Baumgartner, C.F., et al.: PHiSeg: capturing uncertainty in medical image segmentation. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 119–127. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8_14
27. Gu, Z., et al.: Ce-net: context encoder network for 2d medical image segmentation. IEEE Trans. Med. Imaging 38, 2281–2292 (2019)
28. Huo, Y., Xu, Z., Bao, S., et al.: Splenomegaly segmentation on multi-modal MRI using deep convolutional networks. IEEE Trans. Med. Imaging 38, 1185–1196 (2018)
29. Lin, T.Y., Dollár, P., Girshick, R., He, K., Hariharan, B., Belongie, S.: Feature pyramid networks for object detection. In: CVPR, pp. 2117–2125 (2017)
30. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR, pp. 770–778 (2016)
31. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. In: ICLR 2015 (2015)
32. Sudre, C.H., Li, W., Vercauteren, T., Ourselin, S., Jorge Cardoso, M.: Generalised dice overlap as a deep learning loss function for highly unbalanced segmentations. In: Cardoso, M.J. (ed.) DLMIA/ML-CDS 2017. LNCS, vol. 10553, pp. 240–248. Springer, Cham (2017). https://doi.org/10.1007/978-3-319-67558-9_28


33. Phaye, S.S.R., Sikka, A., Dhall, A., Bathula, D.: Dense and diverse capsule networks: Making the capsules learn better. arXiv preprint arXiv:1805.04001 (2018)
34. Ribeiro, F.D.S., Leontidis, G., Kollias, S.: Capsule routing via variational bayes. In: AAAI, vol. 34, pp. 3749–3756 (2020)
35. Rajasegaran, J., Jayasundara, V., Jayasekara, S., Jayasekara, H., Seneviratne, S., Rodrigo, R.: Deepcaps: going deeper with capsule networks. In: CVPR, pp. 10725–10733 (2019)
36. Goodfellow, I.J., Shlens, J., Szegedy, C.: Explaining and harnessing adversarial examples. In: ICLR 2015 (2015)
37. Fu, S., et al.: Domain adaptive relational reasoning for 3D multi-organ segmentation. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12261, pp. 656–666. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59710-8_64
38. Oktay, O., Schlemper, J., Folgoc, L.L., Lee, M., Heinrich, M., Misawa, K., et al.: Attention u-net: learning where to look for the pancreas. arXiv preprint arXiv:1804.03999 (2018)
39. Zhou, H.Y., Guo, J., Zhang, Y., Yu, L., Wang, L., Yu, Y.: nnformer: interleaved transformer for volumetric segmentation. arXiv preprint arXiv:2109.03201 (2021)

Cross-Scale Dynamic Alignment Network for Reference-Based Super-Resolution

Kai Hu1,2, Ran Chen1,2, and Zhong-Qiu Zhao1,2(B)

1 College of Computer and Information, Hefei University of Technology, Hefei, Anhui, China
2 Intelligent Manufacturing Institute of HFUT, Hefei, China

Abstract. Image super-resolution aims to recover high-resolution (HR) images from the corresponding low-resolution (LR) images, but significant details are easily lost during reconstruction. Reference-based image super-resolution can produce realistic textures using an external reference (Ref) image and can thus reconstruct visually pleasing images. Despite remarkable advances, reference-based image super-resolution faces two critical challenges. First, it is difficult to match correspondences between LR and Ref images when they differ significantly. Second, it is difficult to transfer the details of the Ref image accurately to the LR image. To solve these issues, we propose an improved feature extraction and matching method that finds the correspondence between LR and Ref images more accurately. We also propose a cross-scale dynamic correction module that uses related textures at multiple scales to compensate for missing information. Extensive experiments on multiple datasets demonstrate that our method outperforms the baseline model in both quantitative and qualitative evaluations.

Keywords: Reference-based Image Super-Resolution · Cross-Scale · Transformer

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 98–108, 2023. https://doi.org/10.1007/978-981-99-4742-3_8

1 Introduction
Image super-resolution (SR) is an active research topic in computer vision and image processing. It can improve the perceptual quality of images and thereby also benefit other computer vision tasks [1, 2, 6]. In recent years, a variety of classic SR methods have been proposed. Single image super-resolution (SISR) is a highly ill-posed problem, since a single low-resolution (LR) image corresponds to multiple possible high-resolution (HR) images. To mitigate this problem, reference-based super-resolution (RefSR) reconstructs a realistic HR image from the LR image under the guidance of an additional HR reference (Ref) image. Existing RefSR methods [8, 14, 16, 30] use internal information or external high-frequency information to enhance textures. Although this alleviates the ill-posedness, it raises two new challenges: how to match correspondences and how to transfer textures. Existing RefSR works match correspondences by estimating the pixel or patch similarity of texture features between LR and Ref images. They then transfer Ref textures into the LR image to assist reconstruction. Most works compute the correspondence between two patches by a normalized inner product [14, 20, 22]. This finds the most similar features in most cases. However, when the appearances of the patches vary due to scale and rotation transformations, correspondences computed purely by inner product become inaccurate, leading to unsatisfactory texture transfer. In addition, although some works [8, 14, 30] have noticed the effect of the resolution gap, the unsatisfactory texture transfer it causes has not been well resolved.

To address these challenges, we propose a cross-scale dynamic alignment network (CDAN). Inspired by the excellent performance of the attention mechanism [3, 26, 27] in SR, we use learned texture extractors whose parameters are updated during end-to-end training. To handle the resolution gap between LR and Ref images, we propose a new feature matching module, named the improved texture matching module. More specifically, the features of the LR image serve as the query (Q) of the Transformer, and those of the Ref image serve as the key (K) and value (V). From these, we obtain the correspondence position map and the correspondence relevance map. Inspired by Shim's work [16], we propose a progressive convolutional module that adjusts the correspondence position map for further alignment, yielding more accurate correspondences. This design enables more accurate matching and alignment, and it produces more visually pleasing results than the baseline model. The main contributions of this article are as follows:
(1) We propose an improved feature extraction and matching method that stably computes accurate matches whether or not there is a resolution gap, thus providing more useful high-frequency information for reconstruction.
(2) We propose a progressive image recovery module that uses information from the Ref features, the LR features, and the SR result of the previous upsampling layer, and then corrects the result with dynamic convolution.
(3) We evaluate the proposed method on benchmark datasets, where it obtains excellent SR results in both subjective and objective measurements. Experiments show that the proposed matching method and progressive network substantially improve RefSR performance.

2 Related Work

2.1 Single Image Super-Resolution
In recent years, learning-based SISR methods have clearly outperformed traditional methods. Deep learning methods treat SISR as a dense image regression task and learn an end-to-end mapping function from LR to HR images. SRCNN [4] trains a three-layer fully convolutional network with the Mean Square Error (MSE) between the SR image and the HR image, showing that deep learning can achieve state-of-the-art performance on the SR task. FSRCNN [4] further accelerates SR by taking the original LR image, rather than an interpolated one, as input and by using a deconvolution layer to enlarge the feature map at the last layer. Soon after, VDSR [10] and DRCN [11] adopted deeper networks with residual learning. ResNet [7] introduced the residual block, which was later improved in EDSR [13]. With the help of residual blocks, SR networks subsequently became


K. Hu et al.

deeper or wider [5, 10, 11, 13, 25, 27, 31–34]. Recently, Transformers [22–24] have also performed well on super-resolution tasks. However, the above methods use the MSE or Mean Absolute Error (MAE) loss, which usually produces overly smooth results from the perspective of human perception. Subsequent works therefore design new loss functions to improve perceptual quality. Johnson, Alahi, and Fei-Fei use VGG [17] to introduce a perceptual loss [9] and obtain visually pleasing results. SRGAN [12] pushes SR results toward the distribution of natural images by using a generative adversarial network (GAN) as part of the loss. In addition, some SR methods [12, 19] adopt GANs to further optimize perceptual quality.

2.2 Reference-Based Super-Resolution
Different from SISR, RefSR uses an HR reference image and super-resolves the LR image by transferring high-frequency details from it, since more detailed texture information can be extracted from the reference. CrossNet [30] proposes an end-to-end neural network that aligns features with optical-flow warping. SRNTT [29] adopts patch matching on multi-scale features to swap in relevant texture features. Similarly, SSEN [16] uses dynamic convolution to align relevant texture features. To improve performance, TTSR [22] puts forward hard attention and soft attention to improve the transfer and synthesis of high-frequency details. E2ENT2 [21] transfers texture features by using SR tasks. To improve matching efficiency, MASA [14] adopts a hierarchical enumerated correspondence matching scheme and maps the distribution of the Ref features to that of the LR features. Recently, the powerful RefSR method C2-Matching [8] proposed a contrastive correspondence network to learn correspondences.
It uses teacher-student correlation distillation to improve the LR-HR matching and finally uses residual features to synthesize HR images. However, this method requires training multiple network models.

3 Proposed Method

3.1 Network Architecture
In reference-based super-resolution (RefSR), transferring irrelevant textures to the output can degrade SR performance. Accurately matching correspondences between the Ref and LR images is therefore a challenging problem, because it directly affects the quality of the super-resolved results. To solve this problem, we propose an improved feature extractor that extracts robust features from the LR and Ref images. $LR^{\uparrow}$ denotes the LR image upsampled by ×4. $Ref^{\downarrow\uparrow}$ denotes the Ref image downsampled by bicubic interpolation with a scaling factor of 4 and then upsampled. This makes it domain-consistent with $LR^{\uparrow}$ and alleviates inaccurate matches caused by the resolution gap. Taking Q, K, and V as the three inputs of the Transformer (Fig. 1), the texture extraction process can be expressed as:

$$Q = \mathrm{ITE}(LR^{\uparrow}) \qquad (1)$$

[Figure: CDAN block diagram; visible labels: LR, Ref, ITE, Q, K, V, ITM, H, S, Backbone, CDCN, Output image]
Fig. 1. The architecture of our CDAN network.

$$K = \mathrm{ITE}(Ref^{\downarrow\uparrow}) \qquad (2)$$

$$V = \mathrm{ITE}(Ref) \qquad (3)$$
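Eqs. (1)–(3) feed three differently prepared images to ITE. The snippet below sketches only that preparation, on a single-channel image stored as a list of lists. ITE itself is a learned network and is not reproduced. For simplicity, the sketch swaps in average pooling for downsampling and nearest-neighbour upsampling in place of the bicubic interpolation described in Sect. 3.1. It therefore shows the shape bookkeeping rather than the exact resampling: $Ref^{\downarrow\uparrow}$ has the same size as Ref, but its high-frequency detail is removed. Function names are hypothetical.

```python
def downsample(img, f):
    """Average-pool an H x W image by an integer factor f (H and W divisible by f)."""
    h, w = len(img) // f, len(img[0]) // f
    return [[sum(img[r * f + i][c * f + j] for i in range(f) for j in range(f)) / (f * f)
             for c in range(w)]
            for r in range(h)]

def upsample(img, f):
    """Nearest-neighbour upsampling by an integer factor f (stand-in for bicubic)."""
    return [[img[r // f][c // f] for c in range(len(img[0]) * f)]
            for r in range(len(img) * f)]

def prepare_inputs(lr, ref, f=4):
    """Return the images fed to ITE for Q, K and V: (LR up, Ref down-up, Ref)."""
    q_in = upsample(lr, f)
    k_in = upsample(downsample(ref, f), f)
    v_in = ref
    return q_in, k_in, v_in

ref = [[r * 4 + c for c in range(4)] for r in range(4)]  # 4 x 4 ramp, values 0..15
lr = downsample(ref, 4)                                  # 1 x 1 low-resolution input (toy)
q_in, k_in, v_in = prepare_inputs(lr, ref)
print(k_in)  # every pixel equals the mean 7.5: same size as Ref, detail removed
```

Because Q and K both pass through the same down-and-up degradation, they are compared in a matched "LR-like" domain. V keeps the original Ref detail that is later transferred.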

Here, ITE extracts the corresponding Q, K, and V features at different scales. The outputs of different layers of the extraction network are used as multi-scale features, since the information at each scale is complementary. The module uses the scales ×1, ×2↓, and ×4↓ to alleviate inaccurate matching caused by different resolutions.

3.2 Improved Texture Matching Module
Texture matching computes the similarity between LR and Ref images in preparation for subsequent texture transfer. To compute correspondences, the features Q and K are unfolded into patches $q_i$ and $k_j$. Most previous methods estimate similarity by the inner product of $q_i$ and $k_j$. This only considers the similarity between pixel values and ignores brightness differences between the input image and the reference image. As a result, these methods may match features with similar pixel values instead of truly similar features. We therefore additionally include a pixel-gradient similarity term, which forces texture matching to consider structural similarity. For each patch $q_i$ in Q and $k_j$ in K, we compute the relevance $r_{i,j}$ as follows (Fig. 2):

$$r_{i,j} = \alpha I(q_i, k_j) + \beta G(q_i, k_j) \qquad (4)$$

where $\alpha + \beta = 1$. From the relevance, we obtain the index map $h_i$ and the confidence map $s_i$:

$$h_i = \arg\max_j r_{i,j} \qquad (5)$$


[Figure: stacked Deform Conv blocks; visible labels: F, V1, V2, V4, Deform Conv, Fout]
Fig. 2. The architecture of our CDCN module. At each layer, our model consists of a deformable convolution module and a residual connection.

$$s_i = \max_j r_{i,j} \qquad (6)$$
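Eqs. (4)–(6) can be sketched as follows for single-channel patches. The excerpt does not define $I$ and $G$ precisely, so this sketch makes three assumptions. First, $I$ is the cosine similarity (normalized inner product) of the flattened patches. Second, $G$ is the cosine similarity of their horizontal and vertical finite-difference gradients. Third, $\alpha = 0.5$, which is an arbitrary choice. Function names are hypothetical.

```python
import math

def cosine(a, b, eps=1e-8):
    """Normalized inner product of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / (norm + eps)

def flatten(patch):
    return [v for row in patch for v in row]

def gradients(patch):
    """Horizontal and vertical forward differences of a 2-D patch, concatenated."""
    gx = [row[c + 1] - row[c] for row in patch for c in range(len(row) - 1)]
    gy = [patch[r + 1][c] - patch[r][c]
          for r in range(len(patch) - 1) for c in range(len(patch[0]))]
    return gx + gy

def relevance(q, k, alpha=0.5):
    """r = alpha * I(q, k) + beta * G(q, k) with beta = 1 - alpha (Eq. 4)."""
    beta = 1.0 - alpha
    return alpha * cosine(flatten(q), flatten(k)) + beta * cosine(gradients(q), gradients(k))

def match(queries, keys, alpha=0.5):
    """Index map h_i = argmax_j r_ij (Eq. 5) and confidence s_i = max_j r_ij (Eq. 6)."""
    h, s = [], []
    for q in queries:
        scores = [relevance(q, k, alpha) for k in keys]
        best = max(range(len(keys)), key=lambda j: scores[j])
        h.append(best)
        s.append(scores[best])
    return h, s

q = [[1, 2], [3, 4]]
k_swap = [[2, 1], [4, 3]]           # columns swapped: similar intensities, different structure
k_shift = [[101, 102], [103, 104]]  # same structure as q, much brighter
print(match([q], [k_swap, k_shift], alpha=1.0))  # inner product only: picks index 0
print(match([q], [k_swap, k_shift]))             # with gradient term: picks index 1
```

In this toy case, the plain normalized inner product prefers the structurally different patch. The gradient term is unaffected by the brightness offset and selects the patch with matching structure instead.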

We use the position map $h_i$ as the position reference for the subsequent dynamic correction network. The relevance map serves as soft attention that promotes the fusion of the transferred texture feature T with the features extracted from the input image. In summary, the improved texture extraction and matching module can effectively match relevant HR texture features from the Ref image to the LR features.

3.3 Cross-Scale Dynamic Correction Module
We propose a cross-scale dynamic correction module (CDCN) that fully utilizes detailed information at each layer. In this module, the progressive structure alleviates the resolution gap between LR and Ref images, and an improved DCN performs alignment correction at each scale. This multi-level network transfers more accurate high-frequency details at each layer, which makes texture transfer robust across scales. Given $h_i$ and $s_i$ from the module above, for each pixel $p$ in the input image feature $L_i$, we aggregate textures around the corresponding position $h_{i,p}$ in the Ref image. We use a modified DCN to transfer the relevant texture around $h_{i,p}$, and the process can be expressed as:

$$Y(p) = \sum_{j=1}^{K} w_j \, X(p + p_j + \Delta p_j) \, m_j \qquad (7)$$

where $X$ is the input of each stage and $Y$ is its output. The positions $p_j \in \{(-1,-1), (-1,0), \ldots, (1,1)\}$ enumerate the $K$ sampling locations of a 3 × 3 kernel, and $w_j$ is the kernel weight at location $j$. $\Delta p_j$ is the $j$-th dynamic offset, and $m_j$ is the corresponding modulation weight (as in modulated deformable convolution). Because the receptive field of the improved DCN is limited, we adopt a progressive network that performs further alignment layer by layer. Each level feeds its aligned features to the next for multi-scale correction, which addresses the resolution gap.
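Eq. (7) evaluated at a single output position $p$ can be sketched as below for a single-channel feature map. The offsets $\Delta p_j$ are fractional, so $X$ is sampled with bilinear interpolation, and positions outside the map count as zero. This follows the common modulated deformable convolution (DCNv2) convention. It is an assumption, since the excerpt does not detail the paper's modified DCN. In the network, $w_j$, $\Delta p_j$, and $m_j$ are learned or predicted. Here they are passed in directly, and all names are hypothetical.

```python
import math

# p_j for a 3 x 3 kernel: (-1, -1), (-1, 0), ..., (1, 1); K = 9 sampling locations.
GRID = [(i, j) for i in (-1, 0, 1) for j in (-1, 0, 1)]

def bilinear(x, r, c):
    """Sample a 2-D list x at fractional position (r, c); out-of-range pixels count as 0."""
    h, w = len(x), len(x[0])
    r0, c0 = math.floor(r), math.floor(c)
    dr, dc = r - r0, c - c0

    def at(i, j):
        return x[i][j] if 0 <= i < h and 0 <= j < w else 0.0

    return ((1 - dr) * (1 - dc) * at(r0, c0) + (1 - dr) * dc * at(r0, c0 + 1)
            + dr * (1 - dc) * at(r0 + 1, c0) + dr * dc * at(r0 + 1, c0 + 1))

def deform_conv_at(x, p, weights, offsets, masks):
    """Y(p) = sum_j w_j * X(p + p_j + dp_j) * m_j for one output position p."""
    y = 0.0
    for (gr, gc), w_j, (dr, dc), m_j in zip(GRID, weights, offsets, masks):
        y += w_j * bilinear(x, p[0] + gr + dr, p[1] + gc + dc) * m_j
    return y

x = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
center_only = [0.0] * 4 + [1.0] + [0.0] * 4   # kernel that copies the centre tap
ones = [1.0] * 9
no_offsets = [(0.0, 0.0)] * 9
shifted = no_offsets[:4] + [(0.0, 0.5)] + no_offsets[5:]  # centre tap moved half a pixel right
print(deform_conv_at(x, (1, 1), center_only, no_offsets, ones))  # 4.0
print(deform_conv_at(x, (1, 1), center_only, shifted, ones))     # 4.5
```

With zero offsets and unit masks, this reduces to an ordinary 3 × 3 convolution. The learned offsets let each tap slide toward the relevant texture, and the masks can suppress unreliable taps.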


4 Experiments

4.1 Experimental Setup
Datasets. We train and test our network on the CUFED5 dataset. Its training set contains 11,871 image pairs, each consisting of an input image and a 160 × 160 reference image. The test set consists of 126 input images, each accompanied by 4 reference images of different similarity levels. For comparison, all models are trained and tested on CUFED5. To evaluate generalization ability, we also test on Sun80, Urban100, and Manga109. Sun80 contains 80 natural images, each paired with multiple reference images. Urban100 has no reference images, but its architectural images have strong self-similarity, so the LR image is used as the Ref image. Manga109 also lacks reference images and consists of lines, curves, and flat colored regions, so we randomly sample HR images from this dataset as reference images.

Implementation Details. All models are implemented in PyTorch 1.11.0 and are trained and tested on an NVIDIA RTX2080 GPU. For training data augmentation, we perform random horizontal and vertical flips, followed by random rotations of 90°, 180°, and 270°. During training, the LR image is generated by ×4 downsampling of the 160 × 160 HR image. The Ref image is passed through a pre-trained VGG19 to obtain multi-scale features from different levels, which are then fed into the model. The batch size is set to 9, and the model is trained with $L_{rec}$, $L_{adv}$, and $L_{per}$, weighted by 1, 1e-3, and 1e-2, respectively. We use the Adam optimizer with a learning rate of 1e-4. We first warm up the network with $L_{rec}$ for 3 epochs and then train it with the total loss for another 100 epochs.
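The loss schedule above can be written as a small helper: a 3-epoch warm-up with $L_{rec}$ only, followed by 100 epochs with the weighted total loss. This is an illustrative sketch with hypothetical names. It assumes that the total loss is the weighted sum $1 \cdot L_{rec} + 10^{-3} L_{adv} + 10^{-2} L_{per}$ and that epochs are counted from 0.

```python
WARMUP_EPOCHS = 3
MAIN_EPOCHS = 100

def loss_weights(epoch, warmup_epochs=WARMUP_EPOCHS):
    """Weights (w_rec, w_adv, w_per) used at a 0-based epoch."""
    if epoch < warmup_epochs:
        return (1.0, 0.0, 0.0)      # warm-up: reconstruction loss only
    return (1.0, 1e-3, 1e-2)        # afterwards: full weighted loss

def total_loss(l_rec, l_adv, l_per, epoch):
    """Weighted sum of the three loss terms for the given epoch."""
    w_rec, w_adv, w_per = loss_weights(epoch)
    return w_rec * l_rec + w_adv * l_adv + w_per * l_per

for epoch in (0, 2, 3, WARMUP_EPOCHS + MAIN_EPOCHS - 1):
    print(epoch, loss_weights(epoch))
```

Warming up with the reconstruction loss alone gives the generator a reasonable starting point before the adversarial and perceptual terms begin to pull it toward sharper, more realistic textures.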

4.2 Evaluation
Quantitative Evaluation. To demonstrate the effectiveness of our approach, we evaluate it on the CUFED5 test set and on the Sun80, Urban100, and Manga109 datasets. We compare it with SISR and RefSR methods. The SISR methods are SRCNN [4], MDSR [13], RDN [29], RCAN [27], SRGAN [12], ENet [15], ESRGAN [19], and RSRGAN [18]. The RefSR methods are CrossNet [30], SRNTT [28], and TTSR [22]. The quantitative results are listed in Table 1, where our method performs well on both PSNR and SSIM. Training with the perceptual loss improves visual quality but lowers PSNR and SSIM. For a fair comparison on these metrics, we therefore also train a version of our model optimized only for the reconstruction loss, named CDAN-rec. As Table 1 shows, CDAN transfers more accurate HR textures from the reference images and generates good results.

Qualitative Evaluation. The qualitative results are shown in Fig. 3. Our method achieves the best visual quality, with many realistic details close to the respective HR ground truths. Specifically, as shown in the first, fourth, and sixth examples, CDAN recovers architectural details more successfully than the other methods.
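PSNR, the first metric in Table 1, is $10 \log_{10}(\mathrm{MAX}^2 / \mathrm{MSE})$ in dB. The sketch below computes it on flattened pixel lists. It assumes 8-bit intensities (MAX = 255) and ignores details such as color-space conversion or border cropping, which the excerpt does not specify. The helper names are hypothetical.

```python
import math

def psnr(img, ref, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two equally sized flat pixel lists."""
    mse = sum((a - b) ** 2 for a, b in zip(img, ref)) / len(ref)
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(max_val ** 2 / mse)

def mse_ratio(psnr_before, psnr_after):
    """Factor by which MSE shrinks when PSNR rises from psnr_before to psnr_after."""
    return 10 ** (-(psnr_after - psnr_before) / 10.0)

print(round(psnr([10, 10], [0, 0]), 2))
print(round(mse_ratio(27.09, 27.57), 3))  # TTSR-rec -> CDAN-rec on CUFED5 (Table 1)
```

Because PSNR is logarithmic, the 0.48 dB gain of CDAN-rec over TTSR-rec on CUFED5 corresponds to roughly 10.5% lower mean squared error.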


Table 1. PSNR/SSIM comparison among different SR methods on four different datasets. Methods are grouped into SISR methods (top) and RefSR methods (bottom).

Method     CUFED5      Sun80       Urban100    Manga109
SRCNN      25.33/.745  28.26/.781  24.41/.738  27.12/.850
MDSR       25.93/.777  28.52/.792  25.51/.783  28.93/.891
RDN        25.95/.769  28.63/.806  25.38/.768  29.24/.894
RCAN       26.06/.769  29.86/.810  25.42/.768  23.38/.895
SRGAN      24.40/.702  26.76/.725  24.07/.729  25.12/.802
ENet       24.24/.695  26.24/.702  23.63/.711  25.25/.802
ESRGAN     21.90/.633  24.18/.651  20.91/.620  23.53/.797
RSRGAN     22.31/.635  25.60/.667  21.47/.624  25.04/.803
--------------------------------------------------------
CrossNet   25.48/.764  28.52/.763  25.11/.764  23.36/.741
SRNTT      25.61/.764  28.54/.793  25.50/.783  27.54/.862
SRNTT-rec  26.24/.784  28.54/.793  25.50/.783  28.96/.885
TTSR       25.53/.765  28.59/.774  24.62/.747  27.70/.886
TTSR-rec   27.09/.804  30.02/.814  25.87/.784  30.09/.907
CDAN       25.66/.772  29.03/.794  25.61/.789  27.91/.885
CDAN-rec   27.57/.817  30.01/.821  25.94/.790  30.04/.909

Furthermore, as shown in the third example, CDAN achieves significant improvements in restoring the textures of text. Even when the reference image is not globally relevant to the input image, CDAN can still extract finer textures from local regions and transfer effective textures into the predicted SR result, as shown in the fifth example in Fig. 3.

4.3 Ablation Study
In this section, we evaluate the effectiveness of the proposed modules, namely the improved texture extraction and matching module and the cross-scale dynamic correction network. The improved texture extraction and matching module (ITEM) consists of two parts: a multi-scale texture extractor and an improved texture matching module. We reimplement TTSR [22] as our baseline model and gradually add the ITEM and CDCN modules to it. Models without ITEM use a pretrained VGG19 for feature extraction. As shown in Table 2, adding ITEM improves the PSNR from 27.09 dB to 27.21 dB, which verifies the effectiveness of ITEM for detailed texture extraction and accurate texture matching. Adding the cross-scale dynamic correction network (CDCN) makes better use of the relevant texture features and further increases the PSNR to 27.57 dB.

Cross-Scale Dynamic Alignment Network


Fig. 3. Visual comparison among different SR methods on CUFED5 (top three examples), Urban100 (the bottom example, whose reference image is the LR input), Manga109 (the fifth example), and Sun80 (the sixth example).

Table 2. Ablation study on CUFED5 to study the effectiveness of the proposed ITEM and CDCN.

Method               ITEM   CDCN   PSNR/SSIM
Base                               27.09/.804
Base + ITEM          ✓             27.21/.807
Base + ITEM + CDCN   ✓      ✓      27.57/.817


5 Conclusion
In this paper, we propose an improved feature extraction and matching method for matching corresponding textures between the Ref image and the LR image. This module additionally compares gradient information to pay more attention to structural information, which facilitates texture transfer. By extracting joint features and feeding them into the proposed cross-scale dynamic correction network, we learn more powerful feature representations and reconstruct images with more detailed information. Extensive experiments demonstrate that our method outperforms the baseline model in both quantitative and qualitative evaluations. In the future, we will extend the proposed improved texture extraction and matching network to general image tasks.
Acknowledgement. This work was supported in part by the National Natural Science Foundation of China under Grant 61976079, in part by the Guangxi Key Research and Development Program under Grant AB22035022, and in part by the Anhui Key Research and Development Program under Grant 202004a05020039.

References
1. Bai, Y., Zhang, Y., Ding, M., Ghanem, B.: SOD-MTGAN: small object detection via multi-task generative adversarial network. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11217, pp. 210–226. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01261-8_13
2. Dai, D., Wang, Y., Chen, Y., Van Gool, L.: Is image super-resolution helpful for other vision tasks? In: 2016 IEEE Winter Conference on Applications of Computer Vision (WACV), pp. 1–9. IEEE (2016)
3. Dai, T., Cai, J., Zhang, Y., Xia, S.T., Zhang, L.: Second-order attention network for single image super-resolution. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11065–11074 (2019)
4. Dong, C., Loy, C.C., He, K., Tang, X.: Image super-resolution using deep convolutional networks. IEEE Trans. Pattern Anal. Mach. Intell. 38(2), 295–307 (2015)
5. Ghifary, M., Kleijn, W.B., Zhang, M., Balduzzi, D., Li, W.: Deep reconstruction-classification networks for unsupervised domain adaptation. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9908, pp. 597–613. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46493-0_36
6. Haris, M., Shakhnarovich, G., Ukita, N.: Task-driven super resolution: object detection in low-resolution images. In: Mantoro, T., Lee, M., Ayu, M.A., Wong, K.W., Hidayanto, A.N. (eds.) ICONIP 2021. CCIS, vol. 1516, pp. 387–395. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-92307-5_45
7. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
8. Jiang, Y., Chan, K.C., Wang, X., Loy, C.C., Liu, Z.: Robust reference-based super-resolution via c2-matching. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2103–2112 (2021)
9. Johnson, J., Alahi, A., Fei-Fei, L.: Perceptual losses for real-time style transfer and super-resolution. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9906, pp. 694–711. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46475-6_43
10. Kim, J., Lee, J.K., Lee, K.M.: Accurate image super-resolution using very deep convolutional networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1646–1654 (2016)
11. Kim, J., Lee, J.K., Lee, K.M.: Deeply-recursive convolutional network for image super-resolution. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1637–1645 (2016)
12. Ledig, C., et al.: Photo-realistic single image super-resolution using a generative adversarial network. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4681–4690 (2017)
13. Lim, B., Son, S., Kim, H., Nah, S., Mu Lee, K.: Enhanced deep residual networks for single image super-resolution. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops, pp. 136–144 (2017)
14. Lu, L., Li, W., Tao, X., Lu, J., Jia, J.: MASA-SR: matching acceleration and spatial adaptation for reference-based image super-resolution. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 6368–6377 (2021)
15. Sajjadi, M.S., Scholkopf, B., Hirsch, M.: EnhanceNet: single image super-resolution through automated texture synthesis. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 4491–4500 (2017)
16. Shim, G., Park, J., Kweon, I.S.: Robust reference-based super-resolution with similarity-aware deformable convolution. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8425–8434 (2020)
17. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556 (2014)
18. Wang, X., Xie, L., Dong, C., Shan, Y.: Real-ESRGAN: training real-world blind super-resolution with pure synthetic data. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 1905–1914 (2021)
19. Wang, X., et al.: ESRGAN: enhanced super-resolution generative adversarial networks. In: Leal-Taixé, L., Roth, S. (eds.) ECCV 2018. LNCS, vol. 11133, pp. 63–79. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-11021-5_5
20. Xia, B., Tian, Y., Hang, Y., Yang, W., Liao, Q., Zhou, J.: Coarse-to-fine embedded PatchMatch and multi-scale dynamic aggregation for reference-based super-resolution. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 36, pp. 2768–2776 (2022)
21. Xie, Y., Xiao, J., Sun, M., Yao, C., Huang, K.: Feature representation matters: end-to-end learning for reference-based image super-resolution. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12349, pp. 230–245. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58548-8_14
22. Yang, F., Yang, H., Fu, J., Lu, H., Guo, B.: Learning texture transformer network for image super-resolution. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 5791–5800 (2020)
23. Yu, R., et al.: Cascade transformers for end-to-end person search. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 7267–7276 (2022)
24. Zamir, S.W., Arora, A., Khan, S., Hayat, M., Khan, F.S., Yang, M.H.: Restormer: efficient transformer for high-resolution image restoration. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 5728–5739 (2022)
25. Zhang, K., Liang, J., Van Gool, L., Timofte, R.: Designing a practical degradation model for deep blind image super-resolution. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4791–4800 (2021)
26. Zhang, Y., Li, K., Li, K., Fu, Y.: MR image super-resolution with squeeze and excitation reasoning attention network. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13425–13434 (2021)
27. Zhang, Y., Li, K., Li, K., Wang, L., Zhong, B., Fu, Y.: Image super-resolution using very deep residual channel attention networks. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11211, pp. 294–310. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01234-2_18
28. Zhang, Z., Wang, Z., Lin, Z., Qi, H.: Image super-resolution by neural texture transfer. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 7982–7991 (2019)
29. Zhang, Y., Tian, Y., Kong, Y., Zhong, B., Fu, Y.: Residual dense network for image super-resolution. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2472–2481 (2018)
30. Zheng, H., Ji, M., Wang, H., Liu, Y., Fang, L.: CrossNet: an end-to-end reference-based super resolution network using cross-scale warping. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11210, pp. 87–104. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01231-1_6
31. Li, J., Zhao, Z.Q.: Training super-resolution network with difficulty-based adaptive sampling. In: 2022 IEEE International Conference on Multimedia and Expo (ICME), pp. 1–6. IEEE (2022)
32. Shen, H., Zhao, Z.Q.: Mid-weight image super-resolution with bypass connection attention network. In: ECAI 2020, pp. 2760–2767. IOS Press (2020)
33. Shen, H., Zhao, Z.Q., Liao, W., Tian, W., Huang, D.S.: Joint operation and attention block search for lightweight image restoration. Pattern Recogn. 132, 108909 (2022)
34. Shen, H., Zhao, Z.Q., Zhang, W.: Adaptive dynamic filtering network for image denoising. arXiv preprint arXiv:2211.12051 (2022)

Solving Large-Scale Open Shop Scheduling Problem via Link Prediction Based on Graph Convolution Network

Lanjun Wan(B), Haoxin Zhao, Xueyan Cui, Changyun Li, and Xiaojun Deng

School of Computer Science, Hunan University of Technology, Zhuzhou 412007, China {wanlanjun,dengxiaojun}@hut.edu.cn

Abstract. The open shop scheduling problem (OSSP) is a classical production scheduling problem that usually has complex constraints and a huge solution space. Because traditional meta-heuristic algorithms have difficulty solving large-scale OSSP efficiently, we propose a method for solving large-scale OSSP via graph convolution network-based link prediction (GCN-LP). First, the state of the OSSP is represented by a disjunctive graph, and the features of the operation nodes are designed. Second, a GCN-based open shop scheduling model is designed by embedding the operation node features of the OSSP. Finally, an open shop scheduling algorithm based on link prediction is designed on top of the GCN-based open shop scheduling model, which improves the efficiency and quality of solving large-scale OSSP. Experimental results show that the solution quality of the proposed GCN-LP method is comparable to that of meta-heuristic algorithms on the OSSP benchmark instances. On the large-scale OSSP random instances, however, both the solution quality and the solution efficiency of GCN-LP are significantly better than those of the meta-heuristic algorithms. Compared with other graph neural network (GNN) models, the proposed GCN-based link prediction method obtains better and more stable scheduling results on the large-scale OSSP random instances.
Keywords: Open shop scheduling problem · Graph convolution network · Link prediction · Disjunctive graph · Graph neural network

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 109–123, 2023. https://doi.org/10.1007/978-981-99-4742-3_9

1 Introduction
OSSP is one of the classical production scheduling problems [1]. Given m machines and n jobs, each job must be processed on every machine. Each job therefore has m operations, and each operation has a fixed processing time. In OSSP, the operations of a job can be processed in any order, but different processing sequences lead to different overall completion times of the jobs. The objective of solving OSSP is therefore to find an optimal production schedule that minimizes the maximum completion time (makespan) [2]. Approximation algorithms for OSSP include heuristic algorithms [3, 4] and meta-heuristic algorithms. Heuristic algorithms have the advantages of being simple and computationally efficient. In recent years, meta-heuristic algorithms, including genetic algorithms (GA) [5, 6], ant colony optimization (ACO) [7], and particle swarm optimization (PSO) [8], have been the main methods used to solve OSSP. Meta-heuristic algorithms have several advantages in solving OSSP. In particular, they are flexible and can be adapted and modified to suit different problem scenarios. However, they require many trials to find a high-quality solution, which results in significant computational overhead when solving large-scale OSSP. In addition, meta-heuristic algorithms are relatively difficult to design and implement, and they require professional knowledge and experience.

As the manufacturing industry grows, OSSP instances become larger and more complex, which increases uncertainty. Traditional methods struggle with the resulting large search spaces, which lowers solution quality and efficiency. To improve the solution quality and efficiency for large-scale OSSP, we propose a method for solving large-scale OSSP by link prediction based on a graph convolution network. The main contributions of this article are summarized as follows.
• Disjunctive graphs are used to represent the relationships between operations and machines in OSSP, so that OSSP can be described more clearly from a graph perspective. The features of the operation nodes are designed so that the node feature information can be used efficiently to solve OSSP.
• To fully utilize the relationship information between operation nodes in the graph structure, a GCN-based open shop scheduling model is designed by embedding the operation node features of OSSP. In this model, the encoder obtains the embedding of each operation node through a multi-layer convolution process, and the decoder reconstructs the OSSP information from the node embeddings to estimate the likelihood that an edge exists.
• To improve the efficiency and quality of solving large-scale OSSP, an open shop scheduling algorithm based on GCN-LP is designed. It efficiently guides the generation of scheduling results and optimizes them through link prediction, thereby minimizing the makespan of all jobs.
• Extensive experiments are carried out on OSSP benchmark instances and on random instances of different scales to verify the effectiveness of the proposed method. The experimental results show that the method effectively improves the solution quality and solution speed for large-scale OSSP instances.

The rest of this paper is structured as follows. Section 2 introduces related work. Section 3 gives the preliminaries. Section 4 presents the proposed GCN-LP method for solving large-scale OSSP. Section 5 gives the experimental results and analysis. Section 6 gives the conclusions and future work.

2 Related Work
This section mainly reviews the research status of meta-heuristic algorithms [9, 10] for solving OSSP and of GNNs for solving shop scheduling problems. GA [11] is an optimization algorithm inspired by the theory of biological evolution in nature, which gradually evolves better populations through survival of the fittest. Liaw [5] proposed a hybrid GA that combines tabu search with a basic genetic algorithm and can obtain optimal solutions on OSSP benchmark instances. Rahmani et al. [6] proposed an improved GA for solving OSSP that considers the impact of the crossover and mutation operators. ACO [12] is a probabilistic algorithm for finding optimal paths, characterized by distributed computing, positive information feedback, and heuristic search. Blum [7] proposed a hybrid ACO algorithm that combines the construction mechanism of ACO with beam search, which enhances the stability of ACO and improves the quality and efficiency of solving OSSP. PSO [13] is a population-based optimization algorithm in which each particle represents an individual and the swarm is composed of multiple particles. The relationship between particles and the swarm in PSO is similar to that between chromosomes and the population in GA. Sha and Hsu [8] proposed a PSO algorithm for solving OSSP that improves adaptivity by modifying the representations of particle position, particle movement, and particle velocity, and introduces a new decoding operator to obtain better results. However, as problem size and complexity increase, these algorithms face limitations such as high time complexity and a tendency to become trapped in local optima.

Owing to their strong ability to learn from graph-structured data, GNNs can help address the high time complexity and local-optimum trapping encountered when solving large-scale problems [14]. Therefore, some researchers have started to use GNNs to solve shop scheduling problems [15]. Hameed et al. [16] proposed a GNN-based distributed reinforcement learning framework to solve the job shop scheduling problem. Li et al. [17] first introduced discount memory into a graph attention network to solve OSSP, then converted OSSP into a sequence-to-sequence problem by constructing an incremental graph solution, and finally trained the model using reinforcement learning. Zhang et al. [18] encoded the state disjunctive graph using a GNN and combined it with reinforcement learning to learn high-quality priority scheduling rules for the job shop scheduling problem. Park et al. [19] also combined GNNs and reinforcement learning to solve the job shop scheduling problem. Their method uses richer node features and considers different relationships between nodes in the GNN message-passing process.

3 Preliminaries
3.1 Open Shop Scheduling Problem
OSSP is a typical NP-hard problem defined on a shop with m machines and n jobs. Each job contains m operations, and each operation must be processed on a predetermined machine. The related symbols are defined in Table 1. The OSSP constraints are that a machine can process only one job at a time, that an operation of a job can be processed on only one machine at a time, and that machines are allowed to be idle. The specific constraints are as follows. Equation (1) defines Wijk, which equals 1 when Oij is processed on machine Mk. Constraint (2) requires that each operation of a job be processed on exactly one machine. Preemptive scheduling is not considered in this paper. Constraint (3) therefore requires that once an operation starts, it runs without interruption until it finishes and the job leaves the machine.

Table 1. Definition of related symbols.

Symbol   Definition
m        the number of machines
n        the number of jobs
Mi       the i-th machine, where 1 ≤ i ≤ m
Ji       the i-th job, where 1 ≤ i ≤ n
Oij      the j-th operation of job Ji, where 1 ≤ i ≤ n, 1 ≤ j ≤ m
Pij      the processing time of operation Oij
Pijk     the processing time of operation Oij on machine Mk
Sij      the start processing time of operation Oij
Eij      the end processing time of operation Oij
Ck       the completion time of machine Mk, where 1 ≤ k ≤ m
Cmax     the makespan of the scheduling instance

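The makespan given by Eq. (5) is the maximum of all job completion times. The objective function of this paper is given by Eq. (6), which aims to minimize the makespan by scheduling the jobs without specifying their processing order in advance [20].

$$W_{ijk} = \begin{cases} 1, & \text{if } O_{ij} \text{ is processed on } M_k; \\ 0, & \text{otherwise.} \end{cases} \tag{1}$$

$$\sum_{k=1}^{m} W_{ijk} = 1, \quad \forall i, j. \tag{2}$$

$$S_{ij} + W_{ijk} \times P_{ij} = E_{ij}, \quad \forall i, j, k. \tag{3}$$

$$Y_{ijk} = \begin{cases} 1, & \text{if } J_i \text{ is processed on } M_k \text{ before } J_j; \\ 0, & \text{otherwise.} \end{cases} \tag{4}$$

$$C_{\max} = \max\{C_1, C_2, \ldots, C_m\}. \tag{5}$$

$$\text{Minimize } C_{\max}. \tag{6}$$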
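To make Eqs. (3), (5) and (6) concrete, the sketch below places operations one at a time in a given priority order. Each operation starts as soon as both its job and its machine are free, and the function then reads off the machine completion times Ck and Cmax. Each operation is appended after the last operation already placed on its machine, so idle gaps are never back-filled. The sketch assumes 0-based indices, a predetermined machine machine_of[i][j] for each operation Oij, and a schedule encoded as an operation order. These names and this encoding are illustrative assumptions and are not necessarily how GCN-LP represents its solutions.

```python
from typing import Dict, List, Sequence, Tuple

Op = Tuple[int, int]  # (job index i, operation index j), both 0-based

def decode_schedule(
    proc_time: Sequence[Sequence[int]],
    machine_of: Sequence[Sequence[int]],
    order: Sequence[Op],
) -> Tuple[Dict[Op, Tuple[int, int]], List[int], int]:
    """Place operations in `order`, each as early as its job and its machine allow.

    Returns ({O_ij: (S_ij, E_ij)}, [C_k for each machine], C_max).
    """
    n_jobs, n_machines = len(proc_time), len(proc_time[0])
    every_op = [(i, j) for i in range(n_jobs) for j in range(n_machines)]
    if sorted(order) != every_op:
        raise ValueError("order must list every operation exactly once")
    job_ready = [0] * n_jobs          # end time of the job's latest placed operation
    machine_ready = [0] * n_machines  # running completion time C_k of each machine
    times: Dict[Op, Tuple[int, int]] = {}
    for i, j in order:
        k = machine_of[i][j]
        # A job and a machine each handle one operation at a time, so wait for both.
        start = max(job_ready[i], machine_ready[k])
        end = start + proc_time[i][j]  # Eq. (3): S_ij + P_ij = E_ij, no preemption
        times[(i, j)] = (start, end)
        job_ready[i] = end
        machine_ready[k] = end
    return times, machine_ready, max(machine_ready)  # Eq. (5)

if __name__ == "__main__":
    proc_time = [[3, 2], [2, 4]]   # P_ij
    machine_of = [[0, 1], [1, 0]]  # O_00 and O_11 on M_0; O_01 and O_10 on M_1
    _, c, cmax = decode_schedule(proc_time, machine_of, [(0, 0), (1, 0), (0, 1), (1, 1)])
    print(c, cmax)  # [7, 5] 7
    _, _, cmax = decode_schedule(proc_time, machine_of, [(0, 1), (1, 0), (0, 0), (1, 1)])
    print(cmax)     # 9
```

The two orders give makespans of 7 and 9 for the same instance, which is why the objective in Eq. (6) searches over processing orders. Because machine M0 carries 3 + 4 = 7 time units of work, 7 is also a lower bound, so the first order is optimal for this toy instance.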

3.2 Link Prediction and Disjunctive Graph
Link prediction can be viewed as a typical binary classification problem. Suppose x and y are nodes in the graph and L(x,y) is the label of the node-pair instance (x, y). In link prediction, each pair of unconnected nodes corresponds to an instance consisting of a class label and the features describing the pair. The core problem of link prediction is to predict connections between nodes by measuring the similarity between them [21].


In the implementation, it is necessary to consider how to construct the network graph, how to extract features, and how to train the model, as well as which evaluation metrics to use to assess the performance of the algorithm. The initial state of an OSSP instance can be described as a disjunctive graph Gs = (V, C ∪ D). The node set V contains the specifics of each operation in the scheduling instance, such as the machine on which the operation is processed and its processing time. The set C contains conjunctive (connected) edges, each of which represents the sequencing constraint between two operations of the same job. In OSSP, each job starts from the start node Start, and scheduling is complete once all operations of every job have been processed. The set D contains disjunctive edges, each of which represents the machine-sharing constraint between two nodes, that is, two operations. When two operations must be processed on the same machine, their operation nodes are connected by a disjunctive edge. The solution of each OSSP instance can be expressed as a directed graph Ge.
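The disjunctive graph described above can be built directly from an instance. The sketch below numbers operation Oij as i·m + j, the 0-based form of the operation indexing used in Sect. 4.2. It puts an edge in C between every pair of operations of the same job and an edge in D between every pair of operations that share a machine. Linking all pairs within a job (rather than only consecutive operations), omitting the Start node, and the name machine_of are assumptions made for illustration.

```python
from itertools import combinations
from typing import List, Sequence, Set, Tuple

Edge = Tuple[int, int]  # undirected edge stored as (smaller id, larger id)

def op_index(i: int, j: int, m: int) -> int:
    """0-based node id of O_ij; (i - 1) * m + (j - 1) in 1-based notation."""
    return i * m + j

def disjunctive_graph(machine_of: Sequence[Sequence[int]]) -> Tuple[Set[Edge], Set[Edge]]:
    """Return (C, D) for an instance whose O_ij runs on machine machine_of[i][j]."""
    n, m = len(machine_of), len(machine_of[0])
    conj: Set[Edge] = set()
    disj: Set[Edge] = set()
    # C: operations that belong to the same job (every pair is linked here).
    for i in range(n):
        for a, b in combinations(range(m), 2):
            conj.add((op_index(i, a, m), op_index(i, b, m)))
    # D: operations that must share a machine.
    by_machine: List[List[int]] = [[] for _ in range(m)]
    for i in range(n):
        for j in range(m):
            by_machine[machine_of[i][j]].append(op_index(i, j, m))
    for ops in by_machine:
        for u, v in combinations(sorted(ops), 2):
            disj.add((u, v))
    return conj, disj

if __name__ == "__main__":
    c, d = disjunctive_graph([[0, 1], [1, 0]])
    print(sorted(c), sorted(d))  # [(0, 1), (2, 3)] [(0, 3), (1, 2)]
```

When every job visits every machine once, each job contributes m(m − 1)/2 edges to C and each machine contributes n(n − 1)/2 edges to D. The graph therefore grows quadratically in m and n, which matters for large-scale instances.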

4 The Proposed Method
4.1 Overall Process of Solving OSSP Based on GCN-LP
This paper proposes a method for solving large-scale OSSP based on GCN-LP, which consists of three phases: the data preprocessing phase; the training, validation, and testing phase of the model; and the phase of solving new OSSP instances using GCN-LP. Figure 1 shows the overall flowchart of solving OSSP based on GCN-LP.

Fig. 1. Overall flowchart of solving OSSP based on GCN-LP.

During the training phase, the OSSP instances are converted into graph structures, and the node features are designed. The dataset is then divided into training, validation, and test sets, and multi-layer convolution is performed on the edges and operation node features to obtain the node embeddings. In the testing phase, the model's accuracy is evaluated and its parameters are updated. The GCN-based open shop scheduling model is then used to solve a new OSSP instance with node features, and the GCN-LP-based algorithm generates a Gantt chart and a parse chart of the scheduling results.

4.2 Node Feature Design
Figure 2 shows the design of the operation node features, which are detailed as follows.
Operation sorting: All operations of the scheduling instance are numbered, where the first operation of the first job is 0 and the last operation of the last job is m × n − 1. The index of operation Oij is (i − 1) × m + (j − 1).

Fig. 2. Diagram of the operation node features design.

Machine corresponding to the operation: After the index of each operation is obtained, the machine that processes each operation node is one-hot encoded based on the data of the scheduling instance. Since there are m × n operations and each is described by m machine columns plus one processing-time column, the feature matrix has dimension (m × n) × (m + 1). Among the first m columns of each operation node, exactly one column is 1, indicating the machine on which the operation is processed, and the remaining columns are 0.
Processing time: The last column of the feature matrix holds the normalized processing time, computed by

$$X_{\mathrm{Norm}} = \frac{X - X_{\min}}{X_{\max} - X_{\min}}, \tag{7}$$

where Xmin and Xmax are the minimum and maximum operation processing times. Normalization removes the effect of scale differences between features, so that they can be compared on the same scale.

4.3 GCN-Based Open Shop Scheduling Model
4.3.1 Network Structure Design
Figure 3 presents the network structure of the proposed GCN-based open shop scheduling model. The model scales linearly with the number of edges in the graph and learns hidden-layer representations that encode both the local graph structure and the node features. Forward propagation through the multi-layer GCN follows the layer-wise propagation rule

$$H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right), \tag{8}$$

where the self-loop adjacency matrix of the graph is

$$\tilde{A} = A + I_N, \tag{9}$$

and the degree matrix of Ã is

$$\tilde{D}_{ii} = \sum_{j} \tilde{A}_{ij}. \tag{10}$$

Here, I_N is the identity matrix, W^(l) is the trainable weight matrix of layer l, σ(·) is the activation function, H^(l+1) is the activation matrix of layer l + 1, and H^(0) = X is the input matrix of the first layer.

[Fig. 3 panels: Input; Graph Conv. Layer 1 (19 channels); Graph Conv. Layer 2 (17 channels); …; Graph Conv. Layer 7 (3 channels); Sigmoid, σ(x) = 1/(1 + e^(−x)); Output]
Fig. 3. Network structure of the GCN-based open shop scheduling model.
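Before the network processes them, the operation node features of Sect. 4.2 can be assembled as in the sketch below. It produces one row per operation in index order, m one-hot machine columns, and a final column holding the processing time min–max normalized with Eq. (7). The sketch takes Xmin and Xmax over a single instance and maps a constant processing time to 0 to avoid division by zero. Both choices are assumptions, because the text does not say over which range the normalization is computed. The names proc_time and machine_of are hypothetical.

```python
from typing import List, Sequence

def node_features(proc_time: Sequence[Sequence[float]],
                  machine_of: Sequence[Sequence[int]]) -> List[List[float]]:
    """Row i*m + j = one-hot machine of O_ij (m columns) + normalized time (Eq. (7))."""
    n, m = len(proc_time), len(proc_time[0])
    flat = [t for row in proc_time for t in row]
    t_min, t_max = min(flat), max(flat)
    span = t_max - t_min
    features: List[List[float]] = []
    for i in range(n):
        for j in range(m):  # rows follow the 0-based operation index i*m + j
            row = [0.0] * (m + 1)
            row[machine_of[i][j]] = 1.0
            # Eq. (7); a constant-time instance would divide by zero, so use 0.0 instead.
            row[m] = (proc_time[i][j] - t_min) / span if span else 0.0
            features.append(row)
    return features

if __name__ == "__main__":
    for row in node_features([[3, 2], [2, 4]], [[0, 1], [1, 0]]):
        print(row)
    # [1.0, 0.0, 0.5]
    # [0.0, 1.0, 0.0]
    # [0.0, 1.0, 0.0]
    # [1.0, 0.0, 1.0]
```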

The GCN model is updated iteratively, mainly based on the neighbor information of each node in the graph, using a 7-layer graph convolution network and a 1-layer sigmoid function. Based on the feature information of each operation node and of its neighbors, every operation node in the large-scale OSSP is embedded to obtain its feature information after multi-layer embedding. The multi-layer GCN forward propagation model can be represented by

$$\hat{A} = \tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}}, \tag{11}$$

and

$$Z = f(X, A) = \hat{A}\left(\hat{A}\left(\hat{A}\left(\hat{A}\left(\hat{A}\left(\hat{A}\,\mathrm{sigmoid}\left(\hat{A} X W^{(0)}\right) W^{(1)}\right) W^{(2)}\right) W^{(3)}\right) W^{(4)}\right) W^{(5)}\right) W^{(6)}, \tag{12}$$

where X is the feature matrix of the operation nodes, A is the adjacency matrix, and Â is the normalized adjacency matrix of A.
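Eqs. (9)–(12) can be traced with a small dense implementation. The sketch below builds Â from an adjacency matrix and applies the nested propagation of Eq. (12). It places the sigmoid after the first layer, as in that equation, and applies no other nonlinearity. The dense pure-Python matrices, the random weight initialization, and the toy channel sizes are illustrative assumptions. A practical implementation would use sparse operations in a deep learning library.

```python
import math
import random
from typing import List, Sequence

Matrix = List[List[float]]

def matmul(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]) -> Matrix:
    """Dense matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def normalized_adjacency(adj: Sequence[Sequence[float]]) -> Matrix:
    """Eqs. (9)-(11): A_hat = D~^(-1/2) (A + I_N) D~^(-1/2)."""
    n = len(adj)
    a_tilde = [[adj[r][c] + (1.0 if r == c else 0.0) for c in range(n)] for r in range(n)]
    # Self-loops guarantee every degree is at least 1, so the square root is safe.
    inv_sqrt = [1.0 / math.sqrt(sum(row)) for row in a_tilde]
    return [[inv_sqrt[r] * a_tilde[r][c] * inv_sqrt[c] for c in range(n)] for r in range(n)]

def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def gcn_forward(adj: Sequence[Sequence[float]], features: Sequence[Sequence[float]],
                weights: Sequence[Sequence[Sequence[float]]]) -> Matrix:
    """Eq. (12): Z = A_hat(... A_hat sigmoid(A_hat X W0) W1 ...) W_last."""
    a_hat = normalized_adjacency(adj)
    h = matmul(matmul(a_hat, features), weights[0])
    h = [[sigmoid(v) for v in row] for row in h]  # the only nonlinearity in Eq. (12)
    for w in weights[1:]:
        h = matmul(matmul(a_hat, h), w)
    return h

def random_weights(dims: Sequence[int], seed: int = 0) -> List[Matrix]:
    """One (dims[k] x dims[k+1]) matrix per layer, uniform in [-0.5, 0.5)."""
    rng = random.Random(seed)
    return [[[rng.random() - 0.5 for _ in range(d_out)] for _ in range(d_in)]
            for d_in, d_out in zip(dims, dims[1:])]

if __name__ == "__main__":
    adj = [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]]  # toy 4-node graph
    x = [[1.0, 0.0, 0.5], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [1.0, 0.0, 1.0]]
    z = gcn_forward(adj, x, random_weights([3, 8, 8, 8, 8, 8, 8, 3]))  # 7 layers
    print(len(z), len(z[0]))  # 4 3
```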


4.3.2 Model Training
Algorithm 1 describes the training process of the GCN-based open shop scheduling model. The specific steps are as follows.
Step 1: Input the node feature matrix X and the scheduling dataset D, and convert D into graph data Gs.
Step 2: Divide Gs into a training set Strain, a validation set Sval, and a test set Stest. One part of the existing edges is used for training and another part for testing.
Step 3: Add negative samples Yn to the training and test sets so that the model better learns the features of the positive samples Yp. The negative samples are generated by negative sampling.
Step 4: Evaluate the AUC of the trained model on the validation and test sets, and continuously update the parameters to improve the accuracy of the model.
Step 5: Output the GCN-based open shop scheduling model.
The loss function is calculated by

$$\mathrm{Loss} = -\left[y \log\left(\mathrm{sigmoid}(x)\right) + (1 - y)\log\left(1 - \mathrm{sigmoid}(x)\right)\right], \tag{13}$$

where x is the value predicted by the model and y is the actual label.
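Eq. (13) is the binary cross-entropy between a predicted edge score x (a logit) and its label y, and Step 4 monitors the AUC. The sketch below evaluates Eq. (13) in a numerically stable form, using log(sigmoid(x)) = −softplus(−x) and log(1 − sigmoid(x)) = −softplus(x). It computes the AUC as the fraction of positive–negative pairs that are ranked correctly, with ties counted as one half. The function names are illustrative, and the section does not say which library the authors used to compute the AUC.

```python
import math
from typing import Sequence

def softplus(z: float) -> float:
    """log(1 + e^z), computed without overflow."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def bce_with_logits(logit: float, label: float) -> float:
    """Eq. (13) for one scored edge; label 1 = positive edge, 0 = negative sample."""
    return label * softplus(-logit) + (1.0 - label) * softplus(logit)

def mean_bce(logits: Sequence[float], labels: Sequence[float]) -> float:
    return sum(bce_with_logits(x, y) for x, y in zip(logits, labels)) / len(logits)

def roc_auc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Probability that a random positive edge outscores a random negative one."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative edge")
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg)
    return wins / (len(pos) * len(neg))

if __name__ == "__main__":
    scores = [2.0, 0.5, -0.3, -1.5]  # edge logits
    labels = [1, 0, 1, 0]            # 1 = positive edge, 0 = negative sample
    print(mean_bce(scores, labels))
    print(roc_auc(scores, labels))   # 0.75
```

The pairwise AUC costs O(P·N) comparisons, which is fine for a sketch. A rank-based computation scales better to the many edges of large instances.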

Algorithm 1: Training of the open shop scheduling model based on GCN
Input: the scheduling instance data D, the node feature matrix X, the number of epochs E
Output: the parameters of the model; the accuracy of the model on the validation and test sets, Uval and Utest
1: Gs ← Disjunctive(D);
2: Sval, Strain, Stest ← RandomLinkSplit(Gs);
3: for epoch = 1 to E do
4:     Yp ← 1, Yn ← 0;
5:     X ← GCN(Gs, X);
6:     ZT ← encode(XT, Yptr);
7:     Lg ← decode(ZT, Yptr);
8:     Calculate the loss function LT(Lg, La) by Eq. (13);
9:     Update parameters;
10:    Zv, Zt ← encode(XV, Xt, Ypte, Ypva);
11:    Ll ← decode(Zv, Zt, Yptr);
12:    Lp ← sigmoid(Ll);
13:    Calculate the AUC value and add elements to Uval and Utest;
14: end for
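Algorithm 1 and Sect. 4.3.3 rely on two small ingredients. The first is a set of negative edges (label Yn) sampled in equal number to the positive edges. The second is a dot-product decoder that scores a node pair by the inner product of its embeddings, which a sigmoid then turns into a link probability. The sketch below samples uniformly over non-adjacent node pairs. The sampling scheme and the function names are assumptions. The name RandomLinkSplit in Algorithm 1 matches a transform in PyTorch Geometric, but the section does not state which library the authors used.

```python
import random
from typing import List, Sequence, Set, Tuple

Edge = Tuple[int, int]  # undirected edge stored as (u, v) with u < v

def sample_negative_edges(num_nodes: int, positive: Set[Edge], k: int, seed: int = 0) -> List[Edge]:
    """Draw k distinct node pairs (u < v) that are not positive edges, uniformly at random."""
    rng = random.Random(seed)
    max_pairs = num_nodes * (num_nodes - 1) // 2
    # Assumes every positive edge is a distinct (u, v) pair with u < v.
    if k > max_pairs - len(positive):
        raise ValueError("not enough non-edges to sample from")
    chosen: Set[Edge] = set()
    while len(chosen) < k:
        u, v = sorted(rng.sample(range(num_nodes), 2))
        if (u, v) not in positive:
            chosen.add((u, v))
    return sorted(chosen)

def dot_product_decode(z: Sequence[Sequence[float]], edges: Sequence[Edge]) -> List[float]:
    """Edge logit = <z_u, z_v>; sigmoid(logit) is the predicted link probability."""
    return [sum(a * b for a, b in zip(z[u], z[v])) for u, v in edges]

if __name__ == "__main__":
    positive = {(0, 1), (2, 3), (0, 3)}
    negative = sample_negative_edges(4, positive, k=len(positive))
    print(negative)  # [(0, 2), (1, 2), (1, 3)]
    z = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0], [0.5, 0.5]]  # toy embeddings
    print(dot_product_decode(z, sorted(positive) + negative))
```

Sampling as many negatives as positives keeps the two classes balanced, which matches the equal-number rule in Sect. 4.3.3.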


4.3.3 Node Encoding and Decoding
The initial information of an OSSP instance can be described as a graph, and the scheduling result as a directed acyclic graph. However, each operation node by itself only represents the main features of its operation and carries no information about the graph structure around it. To obtain updated node embeddings, negative link edges equal in number to the positive link edges are randomly added to the original graph. This turns the model's task into binary classification between the original positive link edges and the newly added negative link edges. The feature information of the target node and its connected edges is then input to the GCN, and multi-layer graph convolution outputs the updated feature information of the target node, called the node embedding. The decoder predicts links between operation nodes from their embeddings to reconstruct the graph, and the encoder is optimized to minimize the reconstruction loss. Different decoders can be used to reconstruct the graph structure, the attribute values, or both. This study uses the dot-product similarity of node embeddings to score how strongly each pair of nodes is connected.

4.4 Open Shop Scheduling Algorithm Based on GCN-LP
An open shop scheduling algorithm based on GCN-LP is designed to solve large-scale OSSP efficiently. Algorithm 2 describes how to solve the open shop scheduling problem based on GCN-LP, which mainly includes the following steps.
Step 1: Traverse X from the first row to find the operations processed on the same machine, obtaining the filtered feature matrix Filter(X). Assign the initial value −1 to the end time Tte of all operations.
Step 2: Construct the first operation node of each machine.
Step 3: For the job to which the currently selected operation node belongs, find the maximum end time Tme among its other operation nodes whose end times have already been determined. Obtain the end time Tce of the previous operation on the current machine.
Step 4: Twait is calculated by  0, if Tme 0

If pk ≠ pi, that is, the predicted result is inconsistent with the actual result, we can compare norm(pi^2) with pi as follows.

=

  norm pi2 pi = cpi p2 j=1 j 1 c p pi + j=1,j=i ( pj ) i < 1 pk pi + p i < pi 1+1

(10)

1 for multi-relational graph G, A = {A(1), A(2), ..., A(|R|)} is a set of adjacency matrices, where A(r) ∈ {0, 1}^(N×N) is the adjacency matrix of the graph G^r.
Definition 2. (Attributed Multi-relational Graph Embedding). Given an attributed multi-relational graph G, the aim of attributed multi-relational graph embedding is to learn a d-dimensional vector representation zi ∈ R^d for each node vi ∈ V, where zi is the i-th row of Z ∈ R^(N×d).

4 The Proposed Model
This section presents an unsupervised method for embedding an attributed multi-relational graph. We first describe how each relation is modeled independently, and then explain how the relations are jointly integrated to obtain the node embedding.

Attributed Multi-relational Graph Embedding Based on GCN


4.1 Intra-relation Aggregation
For relation $r \in R$, we introduce a relation-specific node encoder that produces relation-specific node embeddings. The encoder is a single-layer GCN:

$$F^r = g_r\left(X, A^{(r)} \mid W^{(r)}\right) = \sigma\left(\tilde{D}_r^{-\frac{1}{2}} \tilde{A}^{(r)} \tilde{D}_r^{-\frac{1}{2}} X W^{(r)}\right) \tag{1}$$

where $\tilde{A}^{(r)} = A^{(r)} + \omega I_n$, $(\tilde{D}_r)_{ii} = \sum_j \tilde{A}^{(r)}_{ij}$, $W^{(r)}$ is a trainable weight matrix of the relation-specific encoder $g_r$, and $\sigma$ is the ReLU nonlinearity. Unlike conventional GCNs, we control the weight of the self-connections by introducing a weight $\omega$. A larger $\omega$ means that the node itself plays a more important role in generating its embedding, which diminishes the importance of its neighboring nodes.

To address redundant edges in the graph, we use a difference operation to emphasize the attribute differences between nodes. These differences indicate how similar two nodes are: the differences between nodes of the same class are expected to be smaller than those between nodes of different classes. The difference vector is computed as

$$\delta_i^r = \mathrm{AGG}\left(f_i^r \ominus f_j^r,\ v_j \in N_r(i)\right) \tag{2}$$

where $\delta_i^r$ is the difference result of node $i$ and $\ominus$ denotes the difference operation. Taking node $v_1$ in Fig. 2 as an example, the difference between the attribute vector $x_1$ of this node and the attribute vectors $x_2$, $x_4$ and $x_5$ of its neighboring nodes is shown in Eq. (3), where $N(v_1)$ denotes the set of first-order neighbors of $v_1$:

$$\mathrm{AGG}\left(x_1 \ominus x_j,\ v_j \in N(v_1)\right) = (x_1 - x_2) + (x_1 - x_3) + (x_1 - x_4) \tag{3}$$

Lastly, $f$ and $\delta$ are aggregated to obtain the node embedding $h_i^r$ for relation $r$:

$$h_i^r = \lambda_1 \cdot f_i^r + \lambda_2 \cdot \delta_i^r \tag{4}$$

where $\lambda_1$ and $\lambda_2$ are two trade-off parameters that balance $f$ and $\delta$.

Fig. 2. Schematic diagram of differential operation
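As a rough illustration of Eqs. (1)–(4), the NumPy sketch below computes one relation-specific embedding. It assumes that AGG in Eq. (2) is a plain sum (as Eq. (3) suggests), calls the self-connection weight of Eq. (1) `omega`, and uses the values reported in Sect. 5.2 (3, 0.4 and 0.6) only as defaults; the function names are hypothetical.

```python
import numpy as np


def relation_encoder(X, A, W, omega=3.0):
    """Eq. (1): one GCN layer with self-connection weight omega and symmetric normalisation."""
    A_tilde = A + omega * np.eye(A.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=1))
    A_norm = d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]
    return np.maximum(A_norm @ X @ W, 0.0)          # sigma = ReLU


def difference_vectors(F, A):
    """Eqs. (2)-(3) with AGG taken as a sum: delta_i = sum over neighbours j of (f_i - f_j)."""
    degree = A.sum(axis=1, keepdims=True)
    return degree * F - A @ F


def relation_embedding(X, A, W, lam1=0.4, lam2=0.6, omega=3.0):
    """Eq. (4): h_i^r = lam1 * f_i^r + lam2 * delta_i^r."""
    F = relation_encoder(X, A, W, omega)
    return lam1 * F + lam2 * difference_vectors(F, A)


# Toy relation: the path v0 - v1 - v2 with one-dimensional features.
A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
F = np.array([[1.0], [2.0], [4.0]])
print(difference_vectors(F, A))                    # deltas -1, -1 and 2 for v0, v1, v2

rng = np.random.default_rng(0)
H_r = relation_embedding(rng.normal(size=(3, 5)), A, rng.normal(size=(5, 2)))
```

Writing the summed difference as `degree * F - A @ F` avoids an explicit loop over neighbours, which is exact for a binary adjacency matrix.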


Z. Xie et al.

4.2 Inter-relation Aggregation
Because the relations of a multi-relational graph carry complementary information, this subsection designs an inter-relation aggregation module based on the attention mechanism, as shown in Fig. 3. The module takes the relation-specific embeddings produced by the intra-relation aggregation module as input and outputs the final aggregated embeddings. To determine the importance of each relation, we apply a nonlinear transformation to the relation-specific embeddings and measure their importance by their similarity to a relation-level attention vector $q$. The importance of relation $r$, denoted $s_r$, is the average importance over all relation-specific node embeddings:

$$s_r = \frac{1}{|V|} \sum_{v \in V} q^{\mathsf{T}} \cdot \tanh\left(W \cdot h_v^r + b\right) \tag{5}$$

where $W$ is the weight matrix, $b$ is the bias vector and $q$ is the attention vector. The importance scores are normalized with the softmax function to obtain the weight of each relation:

$$\beta_r = \frac{\exp(s_r)}{\sum_{i=1}^{|R|} \exp(s_i)} \tag{6}$$

The larger $\beta_r$ is, the more important relation $r$ is. Using the learned weights $\beta$ as coefficients, we aggregate the relation-specific node embeddings into the final node embedding $h$:

$$h = \sum_{r=1}^{|R|} \beta_r \cdot h^r \tag{7}$$

Fig. 3. Schematic diagram of inter-relation aggregation module
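Eqs. (5)–(7) amount to a small attention layer over relations. The NumPy sketch below assumes a single $W$, $b$ and $q$ shared by all relations and uses randomly initialized parameters in place of trained ones; the function names are hypothetical.

```python
import numpy as np


def relation_importance(H_list, W, b, q):
    """Eq. (5): s_r = mean over nodes v of q^T tanh(W h_v^r + b)."""
    return np.array([np.mean(np.tanh(H @ W.T + b) @ q) for H in H_list])


def relation_weights(scores):
    """Eq. (6): softmax over relations (the maximum is subtracted for numerical stability)."""
    e = np.exp(scores - scores.max())
    return e / e.sum()


def aggregate(H_list, beta):
    """Eq. (7): h = sum_r beta_r * h^r."""
    return sum(b_r * H for b_r, H in zip(beta, H_list))


rng = np.random.default_rng(0)
N, d, d_att = 5, 4, 3
H_list = [rng.normal(size=(N, d)) for _ in range(2)]   # two relation-specific embeddings
W, b, q = rng.normal(size=(d_att, d)), np.zeros(d_att), rng.normal(size=d_att)
beta = relation_weights(relation_importance(H_list, W, b, q))
h = aggregate(H_list, beta)
```

Replacing `relation_weights` with a uniform vector gives the average-pooling variant that the ablation study calls AMGGCN-a.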


4.3 Training
According to [16], co-occurrence information in the graph can be extracted with random walks. Let $v_{con}$ and $v_{cen}$ be the starting node of a random walk and one of the other nodes of the resulting sequence, respectively. The probability of encountering $v_{con}$ given $v_{cen}$ is modeled as a softmax over embedding inner products:

$$p(v_{con} \mid v_{cen}) = \frac{\exp\left(h_{con}^{\mathsf{T}} h_{cen}\right)}{\sum_{z \in Z} \exp\left(h_z^{\mathsf{T}} h_{cen}\right)} \tag{8}$$

To enhance reconstruction, the parameters $W, \alpha_{11}, \ldots, \alpha_{r1}, \ldots, \alpha_{1k}, \ldots, \alpha_{rk}$ are learned by minimizing the following objective function:

$$L = -\sum_{r=1}^{|R|} \sum_{(v_{con}, v_{cen}) \in \mathrm{set}(\zeta_r)} C(v_{con}, v_{cen}) \cdot \log p(v_{con} \mid v_{cen}) \tag{9}$$

where $\zeta_r$ denotes all tuples generated by random walks on relation $r$, $\mathrm{set}(\zeta)$ is the set of distinct tuples in $\zeta$, and $C(v_{con}, v_{cen})$ is the number of occurrences of $(v_{con}, v_{cen})$. The pseudo-code of the AMGGCN algorithm is given in Algorithm 1.

Algorithm 1 Attributed Multi-relational Graph embedding based on GCN (AMGGCN)
Input: Graph $G$; number of iterations $T$
1. for t = 1 to T do
2.     for r ∈ R do
3.         Get the neighborhood aggregation vector $F^r$ for relation r by Eq. (1);
4.         Get the difference vector $\delta^r$ for relation r by Eq. (2);
5.         Get the node embedding vector $h^r$ for relation r by Eq. (4);
6.     end for
7.     Update the weight $\beta_r$ of each relation by Eqs. (5) and (6);
8.     Update the node representations $h$ by Eq. (7);
9.     Calculate and update the loss by Eq. (9);
10. end for
Output: The node embedding matrix $Z$.
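A minimal sketch of the training objective in Eqs. (8) and (9) follows. As described above, it pairs each walk's starting node ($v_{con}$) with every other node of the walk ($v_{cen}$), counts the pairs to obtain $C(v_{con}, v_{cen})$, and sums the count-weighted negative log-softmax. The walk length, the number of walks per node and the use of the full softmax over all nodes (no negative sampling) are assumptions chosen for clarity rather than efficiency, and the gradient step is omitted.

```python
import numpy as np
from collections import Counter


def random_walk(neighbors, start, length, rng):
    """Uniform random walk with at most `length` nodes on a single relation."""
    walk = [start]
    while len(walk) < length and neighbors[walk[-1]]:
        nbrs = neighbors[walk[-1]]
        walk.append(nbrs[rng.integers(len(nbrs))])
    return walk


def cooccurrence_counts(neighbors, walks_per_node, length, rng):
    """C(v_con, v_cen): pair each walk's start node with every later node of the walk."""
    counts = Counter()
    for start in range(len(neighbors)):
        for _ in range(walks_per_node):
            walk = random_walk(neighbors, start, length, rng)
            counts.update((walk[0], v) for v in walk[1:])
    return counts


def walk_loss(H, counts_per_relation):
    """Eq. (9) with p(v_con | v_cen) given by the full softmax of Eq. (8)."""
    loss = 0.0
    for counts in counts_per_relation:             # outer sum over relations r
        for (con, cen), c in counts.items():       # distinct tuples in set(zeta_r)
            logits = H @ H[cen]                    # h_z^T h_cen for every node z
            m = logits.max()
            log_p = logits[con] - (m + np.log(np.exp(logits - m).sum()))
            loss -= c * log_p
    return float(loss)


rng = np.random.default_rng(0)
path = [[1], [0, 2], [1]]                          # one toy relation: v0 - v1 - v2
counts = cooccurrence_counts(path, walks_per_node=2, length=3, rng=rng)
H = rng.normal(size=(3, 4))                        # stand-in for the embeddings h of Eq. (7)
print(walk_loss(H, [counts]))
```

Because repeated tuples are merged into one term weighted by their count, the loss iterates over distinct pairs only, mirroring $\mathrm{set}(\zeta_r)$ in Eq. (9).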

5 Experiments
In this section, we conduct experiments on three public graph datasets and a constructed futures dataset to demonstrate the efficacy of the model for multi-relational graph embedding.


5.1 Dataset
This subsection describes the public datasets and the private dataset used in the experiments. Their statistics are listed in Table 1. The datasets and the construction of the attributed multi-relational graphs are described below.

DBLP. This dataset contains 14328 papers (P), 4057 authors (A), 20 conferences (C) and 8789 terms (T). The papers are used as the nodes of the graph, which contains three types of relations: PAP, PCP and PATAP. The task is to classify the papers into four categories by paper type, including data mining, machine learning and information retrieval.

IMDB. This dataset contains 4780 movies (M), 5841 actors (A) and 2269 directors (D). The movies are used as the nodes of the graph, which contains two types of relations: MAM and MDM. The task is to classify the movies into three genres: action, comedy and drama.

ACM. This dataset contains 3025 papers (P), 5835 authors (A) and 56 subjects (S). The papers are used as the nodes of the graph, which contains two types of relations: PAP and PSP. The task is to classify the papers into three categories: database, wireless communication and data mining.

FG. This dataset is derived from the transaction data and terminal information of accounts (A) that traded a futures contract on the Zhengzhou Commodity Exchange from September 1, 2020 to September 15, 2020. The accounts are the nodes of the attributed multi-relational graph, which contains four types of relations: shared Windows-type devices, shared Linux-type devices, shared iOS-type devices and shared Android-type devices. The node labels used to evaluate the results were manually annotated by experts. The task is to identify accounts that belong to collusive trading groups. Collusive accounts are groups of users who trade in collusion and speculate maliciously for personal gain; such accounts may exhibit similar trading behavior or share devices.
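The FG relations link accounts through shared devices. One plausible way to derive such relation-specific edge sets from raw records is sketched below; the (account, device, device type) record format and the rule of linking every pair of accounts seen on the same device are assumptions, since the construction procedure is not detailed here.

```python
from collections import defaultdict
from itertools import combinations


def build_device_relations(records):
    """Map each device type to the set of account pairs that used a common device of that type.

    records: iterable of (account_id, device_id, device_type) tuples (hypothetical schema).
    """
    accounts_per_device = defaultdict(set)
    for account, device, device_type in records:
        accounts_per_device[(device_type, device)].add(account)

    relations = defaultdict(set)
    for (device_type, _), accounts in accounts_per_device.items():
        if len(accounts) < 2:
            continue  # a device used by a single account creates no edge
        # Every pair of accounts sharing this device gets an undirected edge in that relation.
        relations[device_type].update(combinations(sorted(accounts), 2))
    return dict(relations)


records = [
    ("acc1", "dev-A", "Windows"),
    ("acc2", "dev-A", "Windows"),
    ("acc3", "dev-B", "Android"),
    ("acc1", "dev-B", "Android"),
    ("acc2", "dev-C", "iOS"),
]
relations = build_device_relations(records)
```

Each entry of `relations` would then become one adjacency matrix $A^{(r)}$ of the attributed multi-relational graph.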

5.2 Baselines
We compare against the following state-of-the-art methods.
1) Embedding methods for single-relational graphs. LINE [5]: preserves both first-order and second-order proximity. DeepWalk [6]: learns node embeddings with random walks and skip-gram. GCN [11]: a neighborhood aggregation network that captures the local topology of the graph and the attribute information of the nodes during embedding. GAT [12]: aggregates the attribute information of neighboring nodes with trainable attention weights. DGI [7]: maximizes the mutual information between the node embeddings and a global representation vector.


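Table 1. Dataset information statistics

| Datasets | Nodes | Attributes | Relation types | Edges |
|---|---|---|---|---|
| DBLP | 4,057 | 334 | PAP | 6,772,278 |
|  |  |  | PCP | 5,000,495 |
|  |  |  | PATAP | 11,113 |
| IMDB | 4,780 | 1,232 | MAM | 98,010 |
|  |  |  | MDM | 21,018 |
| ACM | 3,025 | 1,830 | PAP | 29,281 |
|  |  |  | PSP | 2,210,761 |
| FG | 8,028 | 44 | AWA | 10,914 |
|  |  |  | ALA | 6,017 |
|  |  |  | AIA | 1,893 |
|  |  |  | ANA | 52,430 |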

2) Embedding methods for multi-relational graphs. DMGI [14]: extends DGI to multi-relational graphs and uses consensus regularization to aggregate node embeddings from different relations. HDMI [15]: uses a joint supervision signal containing high-order mutual information and introduces a high-order deep infomax (HDI) to optimize it. HAN [17]: a graph embedding algorithm for heterogeneous graphs that combines GCNs and GATs; it treats the multi-relational graph as a heterogeneous graph.
3) K-Means [18] and Spectral [19]: K-Means relies solely on node attributes and ignores the topology, while Spectral relies solely on the graph topology and ignores node attributes.

In all experiments, the node embedding dimension is set to 64 for a fair comparison. For AMGGCN, the hyperparameters are tuned by grid search, with $\omega$ set to 3 and $\lambda_1$ and $\lambda_2$ set to 0.4 and 0.6, respectively. Two tasks, classification and clustering, are used for evaluation. Classification is measured by micro-F1 and macro-F1, and clustering by Acc, ARI and NMI; for all metrics, higher values indicate better performance.

5.3 Performance Analysis
Experimental results of node classification and node clustering on the multi-relational graphs are shown in Tables 2 and 3, respectively. For node classification, AMGGCN is trained with the cross-entropy loss. Since the variance on graph-structured data can be high, we repeat the process 10 times and report the average results in Table 2. AMGGCN achieves the best classification performance among all methods. First, the experimental results of shallow models (e.g., LINE, DeepWalk) are

Table 2. Overall performance on the supervised task: node classification (%)

| Models | DBLP MaF1 | DBLP MiF1 | IMDB MaF1 | IMDB MiF1 | ACM MaF1 | ACM MiF1 |
|---|---|---|---|---|---|---|
| LINE | 18.55 | 15.20 | 23.83 | 23.90 | 27.53 | 26.89 |
| DeepWalk | 17.29 | 12.88 | 24.61 | 23.93 | 35.43 | 30.04 |
| GCN | 80.79 | 81.71 | 41.13 | 36.69 | 76.81 | 76.77 |
| GAT | 80.87 | 81.96 | 29.78 | 34.17 | 76.23 | 76.01 |
| DGI | 64.33 | 63.91 | 28.26 | 23.98 | 91.92 | 92.09 |
| DMGI | 72.58 | 69.49 | 36.33 | 33.98 | 80.57 | 79.75 |
| HDMI | 80.12 | 81.26 | 42.12 | 41.08 | 90.12 | 90.19 |
| HAN | 81.17 | 82.07 | 39.78 | 34.17 | 88.15 | 87.99 |
| AMGGCN | 85.11 | 84.33 | 45.73 | 49.78 | 92.26 | 93.53 |

weaker than those of the graph neural network-based models. Second, models that use only the topology (e.g., LINE, DeepWalk) perform worse than those that combine topology with node attributes. This indicates that deep neural network-based models have an advantage over shallow models in graph embedding, and that node attributes can significantly improve graph embedding. AMGGCN clearly outperforms GCN, DMGI and HAN, which suggests that it effectively reduces the effect of topological redundancy on neighborhood aggregation.

For node clustering, K-Means is applied to the learned embeddings, with the number of clusters set to the number of classes. Since the performance of K-Means depends on the initial centroids, we repeat the process 10 times and report the average results in Table 3. In general, embedding methods based only on graph topology, such as LINE and DeepWalk, perform poorly, indicating that topology alone is not sufficient to learn satisfactory node embeddings. GCN, GAT and DGI integrate the topological structure and the attribute information of the graph; however, they still perform worse than the multi-relational models, because they are designed for single-relational graphs and cannot handle multi-relational data well. Compared with the other multi-relational models, such as DMGI, HDMI and HAN, AMGGCN achieves the best performance on most metrics.

Additionally, to evaluate AMGGCN in a practical application scenario, we applied it to the FG futures trading dataset to detect hidden groups of collusive accounts.
This dataset contains a large amount of noise (i.e., some accounts do not belong to any collusion group). Although the data were filtered with existing trading rules before the experiment, some noise cannot be removed by rules alone. In addition, because the number of samples differs greatly across categories, it is difficult to obtain ideal experimental results in


Table 3. Overall performance on the unsupervised task: node clustering (%)

| Models | DBLP Acc | DBLP NMI | DBLP ARI | IMDB Acc | IMDB NMI | IMDB ARI | ACM Acc | ACM NMI | ACM ARI |
|---|---|---|---|---|---|---|---|---|---|
| K-Means | 24.82 | 11.63 | 2.49 | 30.48 | 1.45 | 3.05 | 30.09 | 38.43 | 38.11 |
| Spectral | 24.17 | 0.18 | 0.00 | 33.86 | 0.10 | 0.01 | 33.49 | 0.13 | 0.00 |
| LINE | 24.83 | 54.76 | 59.88 | 32.75 | 0.10 | -0.90 | 33.66 | 39.40 | 34.33 |
| DeepWalk | 25.44 | 1.76 | 1.33 | 33.97 | 1.45 | 2.15 | 31.40 | 41.61 | 35.10 |
| GCN | 23.17 | 62.34 | 64.21 | 20.58 | 5.45 | 4.40 | 39.73 | 51.40 | 53.01 |
| GAT | 22.62 | 61.50 | 67.26 | 21.40 | 8.45 | 7.46 | 37.65 | 57.29 | 60.43 |
| DGI | 21.15 | 60.38 | 63.27 | 32.20 | 0.31 | -0.23 | 34.53 | 48.97 | 50.15 |
| DMGI | 27.24 | 70.34 | 68.61 | 35.83 | 5.69 | 4.49 | 32.96 | 67.18 | 73.05 |
| HDMI | 24.58 | 68.70 | 73.54 | 30.31 | 9.87 | 11.72 | 39.21 | 64.86 | 66.98 |
| HAN | 27.97 | 69.12 | 70.76 | 29.88 | 9.62 | 10.01 | 40.17 | 61.56 | 64.39 |
| AMGGCN | 30.07 | 71.71 | 72.79 | 33.65 | 10.67 | 10.12 | 57.60 | 69.51 | 75.09 |

the classification task. Therefore, only the clustering task was evaluated in this experiment, and DBSCAN [20] was used instead of the traditional K-Means to cluster the embedding vectors. Moreover, owing to the special nature of the financial services industry, collusive accounts should be identified as cautiously as possible to avoid unnecessarily disturbing users. Therefore, Acc is used as the evaluation metric in this experiment, so that the identified results are as accurate as possible. The embedding dimension is set to 32, and the other parameters are the same as in the experiments on the public datasets. As shown in Fig. 4, AMGGCN has a clear advantage in Acc over the other models, generally improving on the baselines by 2% to 20%. This suggests that AMGGCN is effective in real-world scenarios.

5.4 Parameter Effects
The AMGGCN model has two weight parameters, $\lambda_1$ and $\lambda_2$, which represent the proportions of the GCN neighborhood aggregation result and the differential operation result in the embedding vectors, respectively. Taking clustering on ACM as an example, we examine the impact of $\lambda_1$ and $\lambda_2$ on the results. As shown in Fig. 5 (left), AMGGCN performs best when $\lambda_1$ is around 0.4; a smaller $\lambda_1$ gives too little weight to the GCN neighborhood aggregation result in the embedding vector and thus leads to poor performance. Similarly, $\lambda_2$ has an optimal value of around 0.6. As $\lambda_2$ increases further, the performance of the model declines significantly, possibly because the weighting of the


Fig. 4. The results of AMGGCN on the FG

differential operation result becomes too large, which likely reduces the similarity of embedding vectors among nodes of the same category.

Fig. 5. Results on ACM (left) and results on DBLP (right).

We also examined the impact of the number of iterations. As illustrated in Fig. 5 (right), in the node classification experiment the model reaches its peak MaF1 on both the DBLP and IMDB datasets after 800–900 iterations, and the improvement beyond 1000 iterations is not significant. Based on these findings, we set the number of iterations to 900.

5.5 Ablation Study
In this section, experiments were conducted on the clustering task (evaluated by NMI) and the classification task (evaluated by MaF1). To evaluate the effectiveness of the inter-relation aggregation module, we replaced it with average pooling (AMGGCN-a). As shown in Fig. 6, AMGGCN outperforms AMGGCN-a on all metrics and datasets, which verifies the effectiveness of the inter-relation aggregation module in aggregating information across relations.


To evaluate the effectiveness of the differential operation, we compared AMGGCN with a variant without the differential operation (AMGGCN-d). As shown in Fig. 6, AMGGCN outperforms AMGGCN-d on all metrics and datasets, indicating that the differential operation improves the model's performance and effectively reduces the impact of topological redundancy.

Fig. 6. The results of AMGGCN ablation study

6 Conclusion
In this paper, we propose a GCN-based method (AMGGCN) to learn node embeddings on attributed multi-relational graphs. AMGGCN first amplifies the differences between neighboring nodes through a difference operation, which alleviates the noise introduced into the convolution by redundant topology. Furthermore, an attention-based inter-relation aggregation module learns the importance of each relation for the final node embedding, and the information from different relations is aggregated with these weights, which helps produce good low-dimensional embedding vectors. Finally, experiments on three public datasets and one futures dataset demonstrate that AMGGCN significantly improves over the baselines.

Acknowledgement. This paper was supported by the National Natural Science Foundation of China (Grant No. 62006211).

References
1. Su, X., Xue, S., Liu, F.Z., et al.: A comprehensive survey on community detection with deep learning. IEEE Transactions on Neural Networks and Learning Systems (2022)
2. van der Maaten, L., Hinton, G.: Visualizing data using t-SNE. J. Mach. Learn. Res. 9, 2579–2605 (2008)
3. Muhammad, F.N., Xian, Y.Q., Akata, Z.: Learning graph embeddings for compositional zero-shot learning. In: 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 953–962 (2021)
4. Wang, T., Yang, G., Wu, J., et al.: Dual contrastive attributed graph clustering network. arXiv preprint arXiv:2206.07897 (2022)
5. Tang, J., Qu, M., Wang, M., et al.: LINE: large-scale information network embedding. In: Proceedings of the 24th International Conference on World Wide Web, pp. 1067–1077 (2015)
6. Perozzi, B., Al-Rfou, R., Skiena, S.: DeepWalk: online learning of social representations. In: Proceedings of the 20th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 701–710 (2014)
7. Veličković, P., Fedus, W., Hamilton, W.L., et al.: Deep graph infomax. In: 7th International Conference on Learning Representations (2019)
8. Gao, H., Huang, H.: Deep attributed network embedding. In: Proceedings of the 27th International Joint Conference on Artificial Intelligence, pp. 3364–3370 (2018)
9. Grover, A., Leskovec, J.: node2vec: scalable feature learning for networks. In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 855–864 (2016)
10. Yang, C., Liu, Z., Zhao, D., et al.: Network representation learning with rich text information. In: Twenty-Fourth International Joint Conference on Artificial Intelligence (2015)
11. Kipf, T.N., Welling, M.: Semi-supervised classification with graph convolutional networks. In: 5th International Conference on Learning Representations (2017)
12. Veličković, P., Cucurull, G., Casanova, A., et al.: Graph attention networks. In: 6th International Conference on Learning Representations (2018)
13. Ma, Y., Wang, S., Aggarwal, C.C., Yin, D., Tang, J.: Multi-dimensional graph convolutional networks. In: Berger-Wolf, T., Chawla, N. (eds.) Proceedings of the 2019 SIAM International Conference on Data Mining, pp. 657–665. Society for Industrial and Applied Mathematics, Philadelphia, PA (2019). https://doi.org/10.1137/1.9781611975673.74
14. Park, C., Kim, D., Han, J., Yu, H.: Unsupervised attributed multiplex network embedding. Proc. AAAI Conf. Artif. Intell. 34(04), 5371–5378 (2020). https://doi.org/10.1609/aaai.v34i04.5985
15. Jing, B., Park, C., Tong, H.: HDMI: high-order deep multiplex infomax. In: Proceedings of the Web Conference 2021, pp. 2414–2424 (2021)
16. Ma, Y., Ren, Z.C., Jiang, Z.H., et al.: Multi-dimensional network embedding with hierarchical structure. In: Proceedings of the Eleventh ACM International Conference on Web Search and Data Mining, pp. 387–395 (2018)
17. Wang, X., Ji, H., Shi, C., et al.: Heterogeneous graph attention network. In: The World Wide Web Conference, pp. 2022–2032 (2019)
18. Hartigan, J.A., Wong, M.A.: Algorithm AS 136: a k-means clustering algorithm. Appl. Statist. 28(1), 100 (1979). https://doi.org/10.2307/2346830
19. Ng, A., Jordan, M., Weiss, Y.: On spectral clustering: analysis and an algorithm. In: Advances in Neural Information Processing Systems, vol. 14 (2001)
20. Ester, M., Kriegel, H.P., Sander, J., et al.: A density-based algorithm for discovering clusters in large spatial databases with noise. In: Proceedings of the Second International Conference on Knowledge Discovery and Data Mining, pp. 226–231 (1996)

CharCaps: Character-Level Text Classification Using Capsule Networks
Yujia Wu, Xin Guo, and Kangning Zhan
School of Information Science and Technology, Sanda University, Shanghai 201209, China

Abstract. Text classification is a hot topic in natural language processing and has achieved great success. Existing character-level text classification methods mainly use convolutional neural networks to extract character-level local features, which makes them ineffective at modeling the hierarchical spatial relationship information among character-level features and reduces classification performance. To solve this problem, this paper proposes CharCaps, a new character-level text classification framework based on the capsule network. CharCaps first extracts character-level text features with seven convolutional layers and then reconstructs them into capsule vector representations, which effectively captures the hierarchical spatial relationship information between character-level features and achieves strong classification performance without pre-trained models. Experimental results on five challenging benchmark datasets demonstrate that the proposed method outperforms state-of-the-art character-level text classification models, especially convolutional neural network-based models.
Keywords: Neural Networks · Natural Language Processing · Text Classification · Capsule Networks · Character-level

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 187–198, 2023.
https://doi.org/10.1007/978-981-99-4742-3_15

1 Introduction
In the last decade, deep neural networks have achieved great success in several fields [1, 2], and many of them have been used to solve text classification problems [3, 4]. Kim proposed a text classification model based on convolutional neural networks (CNNs) that achieved good experimental results [5]. Another line of deep learning text classification uses recurrent neural networks (RNNs) [6]. Lai et al. [7] proposed a more complex method based on recurrent convolutional neural networks (RCNNs), which combines the advantages of CNNs and RNNs. These works generally use Word2Vec [8] or GloVe [9] as the pre-trained model, but since 2019 the BERT pre-trained model proposed by Google [10] has gradually become mainstream. Mekala et al. [11] used BERT to create a contextualized corpus and generate pseudo-labels for weakly supervised text classification. Croce et al. [12] used BERT in a GAN-based semi-supervised framework for text classification. Qin et al. [13] proposed a BERT-based feature projection text classification model. Chen et al. [14] visually displayed different combinations


Y. Wu et al.

of words and phrases in a hierarchical structure when detecting feature influences, improving model interpretability. The deep learning-based text classification methods above mainly take words as the smallest language unit fed into the neural network and rely on pre-trained models to supplement external knowledge. Character-level text classification models, in contrast, do not rely on pre-trained models. Zhang et al. [15] proposed CNNs for character-level text classification. Kim et al. [16] combined convolutional neural networks and long short-term memory (LSTM) to extract character-level text features for multilingual text classification. Liu et al. [17] combined convolutional and gated recurrent neural networks for character-level text classification. Londt et al. [18] proposed a character-level text classification model that uses an evolutionary deep learning technique to improve performance. Although existing character-level text classification methods have achieved competitive results, convolution alone cannot provide the hierarchical spatial relationship information of the original character-level text features, which may hurt classification. This information may include character order or positional information that is beneficial for classification.

Table 1. Examples for sentiment classification. $T_i$ denotes the i-th training sample; $S_j$ denotes the j-th test sample. [class] is the label. Bold characters indicate the category label of a sample.

| Partition | No. | Label | Sample |
|---|---|---|---|
| Training | T1 | [positive] | That book really good! |
|  | T2 | [negative] | Is this book really worth reading? |
| Test | S1 | [positive] | This is a fine book, really! |
|  | S2 | [negative] | This is a fine book, really? |

Table 1 illustrates the difficulty of classifying text when the hierarchical spatial relationship information between characters is absent. The test samples S1 and S2 contain the same six words and differ only in the special characters "!" and "?". Because it is difficult to classify them from word features without richer context, a classifier is likely to assign S1 and S2 to the same class, even though S1 is positive and S2 is negative. We therefore learn features that fuse different combinations of special characters and words: the combination of "really" and "!" may indicate a positive sample, whereas "really" and "?" may indicate a negative one. This hierarchical spatial relationship information can help the model classify S1 and S2 correctly. We therefore propose CharCaps, a character-level text classification framework based on Capsule Networks (CapsNet) [19, 20]. CapsNet has been used for word-level but not character-level text classification [21]. The extracted character-level text features are reconstructed into capsule vectors, which replace scalar neurons as the basic unit of the network. We use seven

CharCaps: Character-Level Text Classification


convolutional layers to extract character-level text features and obtain the hierarchical spatial relationship information of the local features, and we use a routing algorithm to transfer information between the capsule layers, without the need for a pre-trained model. The experimental results show that this improves the performance of character-level text classification. The main contributions of this study are summarized as follows.
(1) To the best of our knowledge, this study is the first to use character-level capsule networks for text classification.
(2) Unlike existing character-level text classification algorithms, the proposed special character selection mechanism efficiently selects the optimal characters.
(3) Experimental results show that the overall performance of the proposed model exceeds that of state-of-the-art character-level text classification models on five large datasets.
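The ambiguity illustrated by Table 1 can be reproduced in a few lines of Python. Assuming a word-level tokenizer that lowercases text and discards punctuation (a common preprocessing choice, not necessarily the one used by the word-level baselines), S1 and S2 become identical word sequences, whereas their character sequences still differ in the final symbol.

```python
import re


def word_tokens(text):
    """Word-level view: lowercase alphabetic tokens only, punctuation discarded."""
    return re.findall(r"[a-z]+", text.lower())


def char_tokens(text):
    """Character-level view: every character, including punctuation, is kept."""
    return list(text.lower())


s1 = "This is a fine book, really!"  # S1, positive
s2 = "This is a fine book, really?"  # S2, negative

print(word_tokens(s1) == word_tokens(s2))           # True: the six words are identical
print(char_tokens(s1)[-1], char_tokens(s2)[-1])     # ! ?
```

Only the character-level view retains the "!" versus "?" signal that separates the two labels.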

2 Related Work
Text classification can be divided into two main approaches according to the smallest language unit processed: word-level and character-level text classification models [5–7]. The former usually requires pre-trained models such as Word2Vec [8], GloVe [9] and BERT [10] to supplement external knowledge. Pre-trained language models have been used in text classification tasks with limited samples [22]. Croce et al. [12] proposed a semi-supervised generative adversarial network text classification method that fine-tunes the BERT architecture with unlabeled data in a generative adversarial setting. Wang et al. [23] proposed a unified fine-tuning framework that explicitly captures prompt semantics from non-target NLP datasets to enhance few-shot text classification in BERT-style models. Shnarch et al. [24] proposed adding an intermediate unsupervised classification task between the pre-training and fine-tuning phases to improve model performance. The latter approach uses characters as the smallest language unit in neural network models [15–18]. However, these methods do not consider the hierarchical spatial relationship information between text features. In contrast, this study introduces capsule networks to model the hierarchical spatial relationship information of local text features, thereby enhancing character-level text classification.

Capsule networks were proposed by the Google Brain team led by Geoffrey Hinton in 2017 to overcome the limitations of CNNs in capturing the hierarchical spatial relationship information of local features [19]. Yao et al. [25] proposed a capsule network routing algorithm based on a reverse attention mechanism, which enables the capsule network to learn local text features by exploiting their hierarchical spatial relationship information and thus improves classification accuracy. Gong et al. [26] demonstrated that the information aggregation mechanism in capsule networks has notable advantages over traditional methods such as max pooling. Wang et al. [27] proposed a method that combines RNNs and capsule networks for sentiment classification. Yang et al. [28] studied capsule networks with dynamic routing for text classification. Zhao et al. [29] proposed a scalable and reliable capsule network with dynamic routing that can be used for tasks such as multi-label text classification and intelligent question


answering. Chen et al. [30] proposed a transfer capsule network model for transferring document-level knowledge and solving aspect-level sentiment classification. Capsule neurons are essentially vectors or matrices that can hold more hierarchical spatial relationship information. However, capsule networks have not yet been used to solve character-level text classification tasks.

3 Proposed Method
3.1 Formalization
Given an original character-level document D, the proposed character-level capsule network model aims to automatically classify documents into user-defined categories. D is composed of n sentences $T_d$ ($1 \le d \le n$), each with L characters, that is, $D = \{T_1, T_2, \ldots, T_n\}$. Each character $i_l \in \mathbb{R}^M$ ($1 \le l \le L$) is represented by an M-dimensional character vector. We use $W \in \mathbb{R}^{H \times M}$ to denote the convolutional kernel. Because each character is indivisible, the width of W equals the width M of a character vector, and the kernel height H is the number of consecutive characters the kernel covers.

Fig. 1. The CharCaps model comprises three modules: character-level feature extraction layer, hierarchical spatial relationship information extraction layer, and routing mechanism. The model takes characters as input, and the features of words, root words, affixes, word fragments, and special characters are extracted by the five convolutional layers. Two additional convolutional layers extract higher-level text features, and the features of adjacent units are reconstructed to obtain a capsule vector representation. The routing mechanism then learns more discriminative and high-level semantic features to enhance the character-level text classification model.

The proposed CharCaps model, which takes character-level text as input and outputs the corresponding text category, is depicted in Fig. 1. Five convolutional layers extract local features from the character-level text; these features correspond to character combinations and serve as high-level local text features. Two further convolutional layers then learn the hierarchical spatial relationship information of these local features, and capsule vectors are reconstructed from the resulting high-level character-level features. Finally, the length of each DigitCaps vector determines the classification result.


3.2 Feature Extraction Layer
In this character-level text classification model, L is set to 1024 characters, which is considered sufficient to capture most of the texts of interest. Samples shorter than 1024 characters are padded with zeros, and longer samples are truncated. The input alphabet consists of the 128 highest-weighted characters selected with Eq. (8) (Sect. 3.4). Each text sentence $T_d \in \mathbb{R}^{L \times M}$ is thus a sequence of 1024 M-dimensional character embeddings drawn from an alphabet of at most 128 characters, and each sample is converted into a matrix $T_d \in \mathbb{R}^{1024 \times M}$ that serves as the input of the CharCaps model. As shown in Fig. 1, five convolutional layers extract character-level text features. For each convolution kernel $W \in \mathbb{R}^{H \times M}$, a convolution over the text sequence $T_d = \{i_1, i_2, \ldots, i_L\}$ yields a local text feature $y_n$:

$$y_n = f(W \cdot T + b) \tag{1}$$

where b is the bias term and $f(\cdot)$ is the rectified linear unit (ReLU) activation function. Multiple convolution kernels of width M and height H ($H \in \{2, 3, 4, 5\}$) learn multiple features, which form a set of feature maps Y:

$$Y = [y_1, y_2, \ldots, y_n] \tag{2}$$
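The input pipeline of Eqs. (1) and (2) can be sketched in plain Python as follows: a text is lower-cased, truncated or zero-padded to 1024 characters, one-hot encoded, and scanned by full-width kernels of heights 2 to 5. This is a minimal illustration, not the authors' implementation. The alphabet is a hypothetical stand-in (26 letters, 10 digits, and an illustrative set of special characters instead of the 92 selected by Eq. (8)), M is simply the alphabet size, and the kernel weights are random placeholders rather than trained parameters.

```python
import random

# Hypothetical alphabet standing in for the 128 characters chosen in Sect. 3.4:
# 26 lowercase letters, 10 digits and an illustrative set of special characters.
LETTERS = "abcdefghijklmnopqrstuvwxyz"
DIGITS = "0123456789"
SPECIALS = "-,;.!?:'\"/\\|_@#$%^&*~`+=<>()[]{}"
ALPHABET = list(LETTERS + DIGITS + SPECIALS)
INDEX = {ch: k for k, ch in enumerate(ALPHABET)}
SEQ_LEN = 1024  # L in the paper


def quantize(text, length=SEQ_LEN):
    """One-hot encode `text` as a length x M matrix (a list of rows).

    Letters are lower-cased (case-insensitive), the text is truncated to
    `length` characters and zero-padded; characters outside the alphabet
    become all-zero rows.
    """
    m = len(ALPHABET)
    rows = []
    for ch in text.lower()[:length]:
        row = [0.0] * m
        k = INDEX.get(ch)
        if k is not None:
            row[k] = 1.0
        rows.append(row)
    rows.extend([0.0] * m for _ in range(length - len(rows)))
    return rows


def conv_feature_map(T, W, b):
    """Apply Eq. (1) at every position: y_n = ReLU(<W, T[n:n+H]> + b).

    W is H x M (full width), so it slides only along the character axis.
    """
    H = len(W)
    out = []
    for n in range(len(T) - H + 1):
        acc = b
        for h in range(H):
            acc += sum(w * t for w, t in zip(W[h], T[n + h]))
        out.append(max(0.0, acc))  # ReLU
    return out


# Eq. (2): one feature map per kernel, for kernel heights H in {2, 3, 4, 5}.
rng = random.Random(0)
T = quantize("CharCaps reads raw characters, typos and all!")
M = len(ALPHABET)
Y = []
for H in (2, 3, 4, 5):
    # Random placeholder weights, not trained parameters.
    W = [[rng.uniform(-0.1, 0.1) for _ in range(M)] for _ in range(H)]
    Y.append(conv_feature_map(T, W, b=0.0))
print([len(y) for y in Y])  # [1023, 1022, 1021, 1020]
```

Because each kernel spans the full width M, every window of H characters collapses to one activation, so a 1024-character input yields 1025 − H values per kernel; in CharCaps these maps would feed the next convolutional layer rather than a pooling layer.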

The proposed CharCaps model is trained end to end, and its stacked convolutional layers extract fixed character combinations, so the feature extraction layer captures information from combinations of characters, including special characters, that is useful for classification. The five convolutional layers in CharCaps can extract words, root words, affixes, word fragments, and special characters. To combat overfitting, dropout is applied between the convolutional layers. The model is optimized with gradients obtained by backpropagation, and the learning rate controls the size of each update. The feature extraction layer can cope with spelling errors and extracts word fragments and special characters. CharCaps performs better on large datasets, where it extracts more adequate features and information, including complete words, word fragments, root words, affixes, and special characters. In addition, character-level text features reduce information loss and are more resilient to word abbreviations and spelling errors.

3.3 Capsule Layer
The CharCaps model employs five convolutional layers to extract local features from character-level text, which are then reconstructed as capsule vectors. This approach does not rely on pre-trained models and can process multiple languages and special characters, thereby enhancing the performance of character-level text classification. In the PrimaryCaps layer, the 256 local features are reconstructed into 16 capsule vectors of 16 dimensions each. The number of DigitCaps capsules is set to the number of output categories. Following Gong et al. [26], we use a dynamic routing algorithm and train the model without fully connected or pooling layers, which reduces information loss and the number of network parameters. The dynamic routing algorithm connects the PrimaryCaps and DigitCaps layers, replacing the fully connected layer, and adjusts the coupling between them over several routing iterations. Its goal is to find a mapping between the l-th (PrimaryCaps) layer and the (l + 1)-th (DigitCaps) layer that yields the reconstructed capsule vectors. The length (L2 norm) of each DigitCaps vector represents the probability that the corresponding category is present. The coupling coefficients $A_{ij}$ weight the contributions of the previous-layer capsules $u_i$; initially they are all equal to 1/k, where k is the number of DigitCaps capsules. The routing algorithm then refines these coefficients with Eq. (3), where $A_{ij}$ is obtained from the routing logit $B_{ij}$, whose initial value is zero:

$$A_{ij} = \frac{\exp(B_{ij})}{\sum_k \exp(B_{ik})} \tag{3}$$

We then obtain an intermediate variable $s_j$ as follows:

$$s_j = \sum_i A_{ij} \cdot W_{ij} \cdot u_i \tag{4}$$

We then calculate the capsule $v_j$ in the DigitCaps layer as follows:

$$v_j = \frac{\lVert s_j \rVert^2}{1 + \lVert s_j \rVert^2} \cdot \frac{s_j}{\lVert s_j \rVert} \tag{5}$$
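The squashing nonlinearity in Eq. (5) keeps the direction of $s_j$ and maps its length into [0, 1), which is what allows a DigitCaps length to be read as a probability. The minimal Python sketch below (plain lists, no framework) illustrates this; the helper names `squash` and `length` are ours, not the paper's.

```python
import math


def length(v):
    """L2 norm of a capsule vector."""
    return math.sqrt(sum(x * x for x in v))


def squash(s):
    """Eq. (5): v = (|s|^2 / (1 + |s|^2)) * (s / |s|)."""
    sq = sum(x * x for x in s)
    if sq == 0.0:
        return [0.0] * len(s)  # the zero vector stays at zero length
    scale = sq / (1.0 + sq) / math.sqrt(sq)
    return [scale * x for x in s]


print(round(length(squash([0.3, 0.4])), 4))  # short input: 0.2
print(round(length(squash([3.0, 4.0])), 4))  # long input: 0.9615
```

Short vectors are pushed toward zero (an input of length 0.5 comes out with length 0.2), while long vectors saturate just below one (length 5 becomes about 0.96).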

Finally, we update the coefficient $B_{ij}$ with Eq. (6), which completes one routing iteration:

$$B_{ij} \leftarrow B_{ij} + W_{ij} \cdot u_i \cdot v_j \tag{6}$$

where $W_{ij}$ is a shared weight matrix learned by optimizing the loss function. In most cases, the best-performing character-level capsule network is obtained after about three or four routing iterations. The output length of each DigitCaps capsule indicates the probability of the corresponding category. The coefficients $A_{ij}$ are updated by the routing algorithm, whereas the shared weight matrices $W_{ij}$ and all other network parameters are updated by minimizing the loss function in Eq. (7):

$$\mathrm{Loss}_k = E_k \cdot \max(0, m^+ - \lVert V_k \rVert)^2 + \lambda (1 - E_k) \cdot \max(0, \lVert V_k \rVert - m^-)^2 \tag{7}$$

where $m^+ = 0.9$ and $m^- = 0.1$ are the upper and lower margins, respectively. The probability of category k for an input sample is given by the length $\lVert V_k \rVert$ of the k-th DigitCaps vector, where $\lVert \cdot \rVert$ denotes the L2 norm. $E_k = 1$ if category k is present and $E_k = 0$ otherwise. For absent categories, the loss is down-weighted by λ so that early training does not shrink the lengths of the activity vectors of all DigitCaps. We use λ = 0.5, and the total loss is the sum of the losses of all DigitCaps.
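The routing procedure of Eqs. (3)–(6) and the margin loss of Eq. (7) can be sketched in plain Python as follows. This is an illustrative sketch, not the CharCaps code: the prediction vectors $W_{ij} u_i$ are taken as given inputs (`u_hat`), since in the model they come from the learned matrices $W_{ij}$. The 16 PrimaryCaps follow Sect. 3.3, whereas the 4 classes (AG News has four), the 16-dimensional DigitCaps, and the random inputs are assumptions made for illustration.

```python
import math
import random


def squash(s):
    """Eq. (5)."""
    sq = sum(x * x for x in s)
    if sq == 0.0:
        return [0.0] * len(s)
    scale = sq / (1.0 + sq) / math.sqrt(sq)
    return [scale * x for x in s]


def softmax(row):
    """Eq. (3) for one row of routing logits B_i."""
    top = max(row)
    e = [math.exp(x - top) for x in row]
    z = sum(e)
    return [x / z for x in e]


def dynamic_routing(u_hat, iterations=3):
    """Routing by agreement; u_hat[i][j] is the prediction W_ij u_i.

    Returns the DigitCaps vectors v_j and the coupling coefficients A_ij
    of the last iteration.
    """
    n_in, n_out, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    B = [[0.0] * n_out for _ in range(n_in)]  # B_ij starts at zero
    for _ in range(iterations):
        A = [softmax(B[i]) for i in range(n_in)]  # all 1/n_out on the first pass
        s = [[sum(A[i][j] * u_hat[i][j][d] for i in range(n_in)) for d in range(dim)]
             for j in range(n_out)]  # Eq. (4)
        v = [squash(s_j) for s_j in s]  # Eq. (5)
        for i in range(n_in):  # Eq. (6): B_ij grows with the agreement u_hat_ij . v_j
            for j in range(n_out):
                B[i][j] += sum(a * b for a, b in zip(u_hat[i][j], v[j]))
    return v, A


def margin_loss(v, present, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Eq. (7) summed over all DigitCaps; `present` holds the k with E_k = 1."""
    total = 0.0
    for k, v_k in enumerate(v):
        norm = math.sqrt(sum(x * x for x in v_k))
        if k in present:
            total += max(0.0, m_plus - norm) ** 2
        else:
            total += lam * max(0.0, norm - m_minus) ** 2
    return total


# Illustrative sizes: 16 PrimaryCaps (Sect. 3.3); 4 classes and 16-dim DigitCaps are assumptions.
rng = random.Random(0)
n_primary, n_classes, dim = 16, 4, 16
u_hat = [[[rng.gauss(0.0, 0.2) for _ in range(dim)] for _ in range(n_classes)]
         for _ in range(n_primary)]
v, A = dynamic_routing(u_hat, iterations=3)
probs = [math.sqrt(sum(x * x for x in v_j)) for v_j in v]  # capsule lengths
pred = max(range(n_classes), key=lambda k: probs[k])
print(pred, margin_loss(v, present={pred}))
```

Only `B` and `A` change inside the routing loop; in training, the matrices behind `u_hat` and the rest of the network would be updated by backpropagating the loss of Eq. (7), as described above.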


3.4 Selection Mechanism of Special Characters
The CharCaps model takes an input alphabet of 128 characters: 26 letters, 10 digits, and 92 special characters. Letters are case-insensitive. While the 26 letters and 10 digits are the most common and important characters, most of the more than 1000 special characters are infrequent and do not aid classification. We therefore developed a character selection mechanism that ranks special characters by weight:

$$w_k = -m \cdot \log\left(\frac{M}{Q + 1}\right) \tag{8}$$

The selection mechanism comprises three steps. First, we calculate the frequency of each character. Second, we calculate the inverse document frequency of the characters. Finally, we multiply the frequency by the inverse document frequency to obtain the weights of the special characters and retain those with high weights. In Eq. (8), $w_k$ denotes the weight of the k-th character, m the number of times the special character appears in a document, M the total number of document categories in the corpus, and Q the number of documents containing the special character. We add 1 to the denominator to avoid division by zero. The 92 special characters with the highest weights are retained and combined with the 26 English letters and 10 digits, giving a total of 128 input characters for the model.
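A minimal sketch of the selection mechanism is given below. It reproduces Eq. (8) exactly as printed, $w_k = -m \cdot \log(M/(Q+1))$, and keeps the highest-weighted special characters alongside the 26 letters and 10 digits. Two interpretation choices are our assumptions, not details stated in the paper: m is counted over the whole corpus rather than per document, and every character that is not a lowercase ASCII letter or digit (including whitespace) is treated as a special character. The function names are hypothetical.

```python
import math
import string
from collections import Counter

BASE = string.ascii_lowercase + string.digits  # 26 letters + 10 digits


def special_char_weights(documents, num_categories):
    """Score special characters with Eq. (8) as printed: w = -m * log(M / (Q + 1)).

    m: occurrences of the character in the corpus (assumption: corpus-wide count)
    Q: number of documents containing the character
    M: number of document categories
    """
    counts, doc_freq = Counter(), Counter()
    for doc in documents:
        # Assumption: anything outside BASE, whitespace included, is "special".
        specials = [c for c in doc.lower() if c not in BASE]
        counts.update(specials)
        doc_freq.update(set(specials))
    return {c: -m * math.log(num_categories / (doc_freq[c] + 1))
            for c, m in counts.items()}


def select_alphabet(documents, num_categories, n_special=92):
    """Letters and digits plus the n_special highest-weighted special characters."""
    w = special_char_weights(documents, num_categories)
    ranked = sorted(w, key=lambda c: (-w[c], c))  # ties broken by code point
    return list(BASE) + ranked[:n_special]


docs = ["Price: $5!!", "Wow!! 100%", "ok."]
print(select_alphabet(docs, num_categories=2, n_special=2)[-2:])  # ['!', ' ']
```

As printed, $-\log(M/(Q+1)) = \log((Q+1)/M)$, so the weight grows with both the count m and the document frequency Q and frequent special characters are favored; on a real corpus one would pass the full training set with `n_special=92`.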

4 Experiment
4.1 Datasets and Baselines
To assess the effectiveness of the CharCaps model, we adopted accuracy as the evaluation metric, as is common in character-level text classification studies. We conducted experiments on several large datasets that have been widely used in previous studies: AG News [15], Yelp R. [15], DBpedia [18], Yahoo A. [21], and Amazon R. [15]. We compared the proposed model with several state-of-the-art baselines: CharCNN [15], Word-based CNN [15], Word-based LSTM [15], Word-based Capsule [21], Hierarchical [33], MixText [34], Hierarchical NA [35], and GP-Dense [18].
4.2 Experimental Results
For a fair comparison, we report results both for word-level models that use pre-trained models and for character-level classification models that do not. Table 2 lists the accuracies of the proposed CharCaps model and the comparison models. The first six rows are word-level models that use a pre-trained model, through which external knowledge can be introduced to improve performance. CharCaps achieved the best result on AG News and the second-best result on the other four datasets. CharCaps scored lowest on Amazon Review, where word-based LSTM achieved the best result. This may be attributed to the large size of the dataset, which helps the LSTM extract contextual text features.

Table 2. Accuracy (%) of the compared methods. The first six rows are word-level models that use pre-trained models, through which external knowledge can be introduced to improve performance. Our CharCaps model, which does not use a pre-trained model, achieves the best accuracy on AG News and the second-best accuracy on the other four datasets (Yahoo Answers, DBpedia, Yelp Review, and Amazon Review). "/" marks results that were not reported.

| Model | Pre-train | AG News | Yahoo A | DBpedia | Yelp R | Amazon R |
|---|---|---|---|---|---|---|
| Word-based CNN | YES | 90.08 | 68.03 | 98.58 | 59.84 | 55.6 |
| Word-based LSTM | YES | 86.06 | 70.84 | 98.55 | 58.17 | 59.43 |
| Word-based Capsule | YES | 90.2 | 68.22 | 96.9 | 58.39 | 54.11 |
| Hierarchical NA | YES | / | / | 93.72 | / | / |
| Hierarchical | YES | / | / | 95.3 | / | / |
| MixText | YES | 91.5 | 74.1 | 99.2 | / | / |
| CharCNN | NO | 87.18 | 70.45 | 98.27 | 60.38 | 58.69 |
| GP-Dense | NO | 89.58 | / | / | 61.05 | / |
| CharCaps (ours) | NO | 93.1 | 71.5 | 98.9 | 60.9 | 58.9 |

CharCaps outperforms the word-based capsule model on all five datasets. Among the character-level models without pre-training, it outperforms CharCNN on all five datasets and GP-Dense on AG News, trailing GP-Dense slightly on Yelp Review. These results indicate that the proposed method is competitive on most datasets. By using capsule vectors instead of scalar neurons to capture hierarchical spatial relationship information, our model extracts more information from raw character text, which improves classification accuracy.
4.3 Ablation Experiments

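Table 3. Results of ablation experiments on five datasets (accuracy, %). CharCaps(64) and CharCaps(full) use 64 characters and all characters, respectively; CharCaps(128) uses the 128 characters chosen by the selection mechanism.

| Model | AG News | Yahoo A | DBpedia | Yelp R | Amazon R |
|---|---|---|---|---|---|
| CharCaps(64) | 84.2 | 64.3 | 90.3 | 56.2 | 54.3 |
| CharCaps(full) | 92.3 | 71.1 | 98.3 | 59.7 | 57.3 |
| CharCaps(128) | 93.1 | 71.5 | 98.9 | 60.9 | 58.9 |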
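The improvements quoted in the discussion of Table 3 are averages over the five datasets, expressed in percentage points. The short Python check below recomputes them using only the values reported in Table 3.

```python
# Accuracies (%) from Table 3, in the order AG News, Yahoo A, DBpedia, Yelp R, Amazon R.
table3 = {
    "CharCaps(64)": [84.2, 64.3, 90.3, 56.2, 54.3],
    "CharCaps(full)": [92.3, 71.1, 98.3, 59.7, 57.3],
    "CharCaps(128)": [93.1, 71.5, 98.9, 60.9, 58.9],
}
means = {name: sum(accs) / len(accs) for name, accs in table3.items()}
gain_vs_full = means["CharCaps(128)"] - means["CharCaps(full)"]  # percentage points
gain_vs_64 = means["CharCaps(128)"] - means["CharCaps(64)"]
print(f"{gain_vs_full:.2f} {gain_vs_64:.2f}")  # 0.92 6.80
```

The 128-character variant is also better than the full-alphabet variant on every individual dataset, so the average gain is not driven by a single dataset.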


There are more than 1000 common special characters, but many of them occur very infrequently. To address this issue, we developed a character selection mechanism that ranks special characters by the weights computed with Eq. (8). We evaluated this mechanism on five datasets; the results are shown in Table 3. The best performance was achieved with the 128 characters chosen by the selection mechanism: averaged over the five datasets, accuracy improved by 0.92 percentage points over using all characters and by 6.8 percentage points over using 64 characters. We therefore chose 128 as the alphabet size for our character selection mechanism. These results suggest that the proposed character selection mechanism could be integrated into other character-level text classification models to improve accuracy in challenging scenarios.
4.4 Different Network Layers

Fig. 2. Performance of the proposed model with different numbers of network layers. On most datasets, five convolutional layers give the best performance.

Unlike word-level models that use pre-trained models, the proposed CharCaps model uses characters as the smallest unit of language in the neural network. It employs up to five convolutional layers as character-level feature extraction layers, and these features are used as input to the capsule network. Two additional convolutional layers extract higher-level text features, which are then reconstructed to obtain capsule vectors. To analyze how the number of layers affects the extraction of character-level text features, we conducted experiments with one to seven convolutional layers. As Fig. 2 shows, five convolutional layers extracted character-level text features best on all five datasets except DBpedia, where six layers gave the best result; even there, five layers reached an accuracy of 98.1%, close to the best. Both too many and too few layers fail to yield optimal results. Too many convolutional layers make the features overly complex, leading to overfitting and longer training time, whereas too few fail to extract good character-level text features. Unlike conventional convolutional neural networks, CharCaps uses no pooling layers, and it uses ReLU as the activation function. Whereas word-level models extract features resembling words, word fragments, root words, or affixes, our model extracts features using characters as the smallest unit of language. The extracted features are then reconstructed into capsule vectors that carry hierarchical spatial relationship information.

5 Conclusions
This work proposes a character-level text classification model based on capsule networks that learns hierarchical spatial relationships. The model employs multiple convolutional layers to learn character-level text features, reducing information loss, and reconstructs the local features of adjacent units into capsule vector representations. A routing mechanism captures the hierarchical spatial relationship information of local features. Experimental results demonstrate the potential of capsule networks for character-level text classification. CharCaps outperforms the existing CNN-based character-level approach on large datasets, which we attribute to character-level text features providing sufficient information when datasets are large. Capsule vectors model the hierarchical spatial relationships of text features more effectively, and the proposed special-character selection mechanism further improves text classification accuracy.
Acknowledgments. This work was sponsored by the Natural Science Foundation of Shanghai (No. 22ZR1445000) and the Research Foundation of Shanghai Sanda University (No. 2020BSZX005, No. 2021BSZX006).

References
1. Wan, J., Li, J., Lai, Z., Du, B., Zhang, L.: Robust face alignment by cascaded regression and de-occlusion. Neural Netw. 123, 261–272 (2020)
2. Wan, J., et al.: Robust facial landmark detection by cross-order cross-semantic deep network. Neural Netw. 136, 233–243 (2021)
3. Wu, Y., Li, J., Song, C., Chang, J.: Words in pairs neural networks for text classification. Chin. J. Electron. 29, 491–500 (2020)
4. Sergio, G.C., Lee, M.: Stacked DeBERT: all attention in incomplete data for text classification. Neural Netw. 136, 87–96 (2021)
5. Kim, Y.: Convolutional Neural Networks for Sentence Classification. In: Conference on Empirical Methods in Natural Language Processing, pp. 1746–1751. ACL, Doha, Qatar (2014)
6. Liu, P., Qiu, X., Huang, X.: Recurrent Neural Network for Text Classification with Multi-Task Learning. In: 25th International Joint Conference on Artificial Intelligence, pp. 2873–2879. IJCAI/AAAI Press, New York, NY, USA (2016)
7. Lai, S., Xu, L., Liu, K., Zhao, J.: Recurrent Convolutional Neural Networks for Text Classification. In: 29th AAAI Conference on Artificial Intelligence, pp. 2267–2273. AAAI Press, Austin, Texas, USA (2015)
8. Le, Q.V., Mikolov, T.: Distributed representations of sentences and documents. In: 31st International Conference on Machine Learning, pp. 1188–1196. JMLR, Beijing, China (2014)
9. Pennington, J., Socher, R., Manning, C.D.: GloVe: Global Vectors for Word Representation. In: Conference on Empirical Methods in Natural Language Processing, pp. 1532–1543. ACL, Doha, Qatar (2014)
10. Devlin, J., Chang, M.-W., Lee, K., Toutanova, K.: BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In: Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pp. 4171–4186. ACL, Minneapolis, MN, USA (2019)
11. Mekala, D., Shang, J.: Contextualized Weak Supervision for Text Classification. In: 58th Annual Meeting of the Association for Computational Linguistics, pp. 323–333. ACL, Online (2020)
12. Croce, D., Castellucci, G., Basili, R.: GAN-BERT: generative adversarial learning for robust text classification with a bunch of labeled examples. In: 58th Annual Meeting of the Association for Computational Linguistics, pp. 2114–2119. ACL, Online (2020)
13. Qin, Q., Hu, W., Liu, B.: Feature projection for improved text classification. In: 58th Annual Meeting of the Association for Computational Linguistics, pp. 8161–8171. ACL, Online (2020)
14. Chen, H., Zheng, G., Ji, Y.: Generating hierarchical explanations on text classification via feature interaction detection. In: 58th Annual Meeting of the Association for Computational Linguistics, pp. 5578–5593. ACL, Online (2020)
15. Zhang, X., Zhao, J., LeCun, Y.: Character-level Convolutional Networks for Text Classification. In: Advances in Neural Information Processing Systems 28, pp. 649–657. Montreal, Quebec, Canada (2015)
16. Kim, Y., Jernite, Y., Sontag, D., Rush, A.M.: Character-aware neural language models. In: 30th AAAI Conference on Artificial Intelligence, pp. 2741–2749. AAAI Press, Phoenix, Arizona, USA (2016)
17. Liu, B., Zhou, Y., Sun, W.: Character-level text classification via convolutional neural network and gated recurrent unit. Int. J. Mach. Learn. Cybern. 11(8), 1939–1949 (2020). https://doi.org/10.1007/s13042-020-01084-9
18. Londt, T., Gao, X., Andreae, P.: Evolving character-level DenseNet architectures using genetic programming. In: Castillo, P.A., Jiménez Laredo, J.L. (eds.) Applications of Evolutionary Computation. LNCS, vol. 12694, pp. 665–680. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-72699-7_42
19. Sabour, S., Frosst, N., Hinton, G.E.: Dynamic routing between capsules. In: Advances in Neural Information Processing Systems 30, pp. 3856–3866. Long Beach, CA, USA (2017)
20. Wu, Y., Li, J., Chen, V., Chang, J., Ding, Z., Wang, Z.: Text classification using triplet capsule networks. In: International Joint Conference on Neural Networks, pp. 1–7. IEEE, Glasgow, United Kingdom (2020)
21. Wu, Y., Li, J., Wu, J., Chang, J.: Siamese capsule networks with global and local features for text classification. Neurocomputing 390, 88–98 (2020)
22. Hong, S.K., Jang, T.: LEA: meta knowledge-driven self-attentive document embedding for few-shot text classification. In: Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, pp. 99–106. ACL, Seattle, WA, United States (2022)
23. Wang, J., et al.: Towards Unified Prompt Tuning for Few-shot Text Classification. In: Findings of the Association for Computational Linguistics: EMNLP 2022, pp. 524–536. ACL, Abu Dhabi, United Arab Emirates (2022)
24. Shnarch, E., et al.: Cluster & tune: boost cold start performance in text classification. In: 60th Annual Meeting of the Association for Computational Linguistics, pp. 7639–7653. ACL, Dublin, Ireland (2022)
25. Tsai, Y.H., Srivastava, N., Goh, H., Salakhutdinov, R.: Capsules with inverted dot-product attention routing. In: 8th International Conference on Learning Representations, ICLR, Addis Ababa, Ethiopia (2020)
26. Gong, J., Qiu, X., Wang, S., Huang, X.: Information aggregation via dynamic routing for sequence encoding. In: 27th International Conference on Computational Linguistics, pp. 2742–2752. COLING, Santa Fe, New Mexico, USA (2018)
27. Wang, Y., Sun, A., Han, J., Liu, Y., Zhu, X.: Sentiment analysis by capsules. In: Conference on World Wide Web, pp. 1165–1174. ACM, Lyon, France (2018)
28. Yang, M., Zhao, W., Chen, L., Qu, Q., Zhao, Z., Shen, Y.: Investigating the transferring capability of capsule networks for text classification. Neural Netw. 118, 247–261 (2019)
29. Zhao, W., Peng, H., Eger, S., Cambria, E., Yang, M.: Towards scalable and reliable capsule networks for challenging NLP applications. In: 57th Annual Meeting of the Association for Computational Linguistics, pp. 1549–1559. ACL, Florence, Italy (2019)
30. Chen, Z., Qian, T.: Transfer capsule network for aspect level sentiment classification. In: 57th Annual Meeting of the Association for Computational Linguistics, pp. 547–556. ACL, Florence, Italy (2019)
31. Lehmann, J., et al.: DBpedia - a large-scale, multilingual knowledge base extracted from Wikipedia. Semant. Web 6(2), 167–195 (2015)
32. McAuley, J., Leskovec, J.: Hidden factors and hidden topics: understanding rating dimensions with review text. In: 7th ACM Conference on Recommender Systems, pp. 165–172. ACM, Hong Kong, China (2013)
33. Rojas, K.R., Bustamante, G., Cabezudo, M.A.S., Oncevay, A.: Efficient Strategies for Hierarchical Text Classification: External Knowledge and Auxiliary Tasks. In: 58th Annual Meeting of the Association for Computational Linguistics, pp. 2252–2257. ACL, Online (2020)
34. Chen, J., Yang, Z., Yang, D.: MixText: linguistically-informed interpolation of hidden space for semi-supervised text classification. In: 58th Annual Meeting of the Association for Computational Linguistics, pp. 2147–2157. ACL, Online (2020)
35. Sinha, K., Dong, Y., Cheung, J.C.K., Ruths, D.: A hierarchical neural attention-based text classifier. In: Conference on Empirical Methods in Natural Language Processing, pp. 817–823. ACL, Brussels, Belgium (2018)

Multi-student Collaborative Self-supervised Distillation
Yinan Yang1,2, Li Chen1,2 (corresponding author), Shaohui Wu1,2, and Zhuang Sun1,2
1 School of Computer Science and Technology, Wuhan University of Science and Technology, Wuhan 430065, Hubei, China
2 Hubei Province Key Laboratory of Intelligent Information Processing and Real-Time Industrial System, Wuhan, China

Abstract. Knowledge distillation is a widely used technique for transferring knowledge from a teacher network to a student network. We leverage the spatial distribution properties of the output values of different models and the similarities of homologous targets, explore the decay and update mechanisms of multi-domain weights, and propose a multi-student collaborative self-supervised distillation approach that realizes a self-improvement strategy for multi-domain collaborative models. The distillation process comprises two parts: distillation from the teacher to the students, where the distillation loss is computed separately on target and non-target logits, and mutual learning between students, where the Kullback–Leibler (KL) divergence serves as the mutual learning loss to explore additional convergence space. We conduct image classification and text detection experiments on CIFAR100 and ICDAR2015. The results demonstrate that our method effectively reduces the number of parameters and the computation required by the distilled model with only a slight decrease in accuracy. Moreover, to reach higher accuracy, multi-student inference at the same scale requires fewer resources than single-student inference. In distillation experiments on classification models with ResNet and VGG architectures, our method achieved an average improvement of 0.4%, and multi-student collaborative inference led to an average improvement of 1.4%. In the text detection distillation experiments, our method outperformed the DKD distillation method by 5.12% in F1 score.
Keywords: Knowledge Distillation · Self-supervised Learning · Collaborative Learning

1 Introduction
Deep learning has garnered significant attention and developed rapidly in recent years owing to its remarkable robustness to changes in target diversity. However, as deep learning models achieve superior performance, their computation and storage requirements increase, which hinders their deployment on low-resource devices such as Internet of Things and mobile Internet devices. As a result, researchers have begun investigating efficient deep learning models that meet the power consumption and real-time requirements of low-resource devices without substantially compromising performance. Knowledge distillation is considered an effective method to address this issue. Nevertheless, the effectiveness of current knowledge distillation is constrained by the following limitations:
(1) Information loss: some information is lost when knowledge is transferred from a large model to a small model. Although this information is usually of minor importance, it may significantly affect performance in some tasks.
(2) Hard-label limitation: knowledge distillation commonly uses the hard labels of large models to train small models, which can make the outputs of the small models rigid and unable to express the complexity and diversity of the large models.
(3) Dataset labeling constraints: besides the distillation loss from the large model to the small model, knowledge distillation often requires labeled datasets for supervised training, which ties distillation to the availability of dataset labels.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 199–210, 2023. https://doi.org/10.1007/978-981-99-4742-3_16

2 Related Work
Knowledge distillation (KD) was first proposed by Hinton et al. [1]. KD defines a learning scheme in which a larger teacher network guides the training of a smaller student network. Knowledge distillation methods fall into two main approaches: distillation from the output logits of the network [5–8] and distillation from intermediate features of the network [9–19]. In KD, a teacher network with more parameters is first trained, and a lighter student network is then trained with two losses: a loss between the student's outputs and the ground-truth labels, and a distillation loss between the student's and the teacher's outputs. Through iterative training, the student approaches the performance of the teacher. Deep Mutual Learning (DML) [2] trains multiple neural networks simultaneously; during training, each network learns not only from the ground-truth labels but also from its peer networks, continuously drawing on each other's experience, which improves the performance of every network. Decoupled Knowledge Distillation (DKD) [3] is a knowledge distillation method that decouples the KL divergence. Unlike most methods that distill intermediate features, DKD distills the output logits of the network. It decouples the classical KL divergence into a target-class distillation loss and a non-target-class distillation loss and then combines the two flexibly to form the final distillation loss, achieving better results while maintaining high efficiency.
We leverage the spatial distribution characteristics of the output values of different models and the similarities of homologous targets. We explore the decay and update mechanism of multi-domain weights and propose a multi-student collaborative self-supervised distillation approach that realizes a self-boosting strategy for multi-domain collaborative models. This approach improves the accuracy of the distilled model and is better suited to diverse application scenarios. Our approach transfers knowledge from a teacher network to a group of untrained student models, which also learn from one another. During distillation, each student is trained with two losses: the distillation loss from the teacher model and the mutual learning loss between that student and the other students. Because training combines only these two losses, it requires no supervised loss on a labeled dataset, which removes the dependence on dataset labels and achieves better results.
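For reference, the classic KD objective described above (a label loss plus a distillation loss toward the teacher's outputs) can be sketched as follows. The temperature `tau`, the $\tau^2$ scaling, and the mixing weight `alpha` follow the usual formulation of Hinton et al. [1]; the specific values here are placeholders, and the function names are hypothetical.

```python
import math


def softmax(z, tau=1.0):
    """Temperature-scaled softmax."""
    top = max(z)
    e = [math.exp((x - top) / tau) for x in z]
    total = sum(e)
    return [x / total for x in e]


def kl(p, q):
    """KL(p || q) = sum_i p_i log(p_i / q_i); terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def kd_loss(student_logits, teacher_logits, label, alpha=0.5, tau=4.0):
    """(1 - alpha) * CE(student, label) + alpha * tau^2 * KL(teacher_tau || student_tau).

    alpha and tau are placeholder values, not settings from this paper.
    """
    ce = -math.log(softmax(student_logits)[label])
    soft = kl(softmax(teacher_logits, tau), softmax(student_logits, tau))
    return (1.0 - alpha) * ce + alpha * tau ** 2 * soft


teacher = [3.0, 1.0, -2.0]
student = [1.0, 1.5, -0.5]
print(round(kd_loss(student, teacher, label=0), 4))
```

DML replaces the teacher term with KL terms between peer students, which is the role the mutual learning loss plays in the proposed method.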

3 Rethinking Knowledge Distillation
In this work, we explore the multi-student distillation mechanism in depth. We first introduce the structure of multi-student collaborative self-supervised distillation, then study the advantages and disadvantages of various loss functions as the distillation loss or the mutual learning loss, and finally describe the multi-student self-supervision process and the role of collaborative learning in the inference stage.
3.1 Motivation and Overview
In a real classroom, students acquire new knowledge in two main ways: teachers impart knowledge to students, and students communicate and learn from one another. This analogy suggests introducing a multi-student mutual learning mechanism into the knowledge distillation framework to improve the performance of the distilled model.

Fig. 1. Illustration of multi-student learning using an example


In the knowledge distillation task for image classification, we add multiple student models to the experiments. As shown in Fig. 1, the students do not output the same prediction values: differences in initial weights and architecture cause different students to converge along divergent paths during training. Leveraging this property can enable the students to explore a wider convergence domain, which motivates multi-student distillation. We propose the multi-student collaborative self-supervised structure shown in Fig. 2. The students collaboratively explore the convergence space with each other, and the teacher model guides their learning direction. The learning process involves two kinds of losses: (1) the mutual loss among students and (2) the distillation loss between the teacher and the students.

Fig. 2. Multi-Student Collaborative Self-supervised Distillation schematic

3.2 Loss Function

KL-Divergence is used as the mutual learning loss between students. The distillation loss between the teacher and each student is DKD [28], which decouples the KL-Divergence into a target-class component and a non-target-class component and then recombines them as the distillation loss. For image classification, let C be the number of classes. The logit that the network outputs for the i-th class is z_i, and p_i is the probability of the i-th class:

$$p_i = \frac{\exp(z_i)}{\sum_{j=1}^{C} \exp(z_j)} \qquad (1)$$

For student models S_1, S_2, …, S_K, the mutual learning loss of the student networks and the KL-Divergence can be expressed as:

$$\text{Mutual Loss} = \sum_{i=1}^{K} \sum_{j \neq i} \mathrm{KL}(S_i \,\|\, S_j) \qquad (2)$$

Multi-student Collaborative Self-supervised Distillation

$$\mathrm{KL}(S_1 \,\|\, S_2) = \sum_{i=1}^{C} p_i^{S_1} \log\left(\frac{p_i^{S_1}}{p_i^{S_2}}\right) \qquad (3)$$
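Equations (1) to (3) can be sketched in plain Python for a single sample. This is a minimal illustration under assumptions the section does not state: no softmax temperature is applied, logits are plain Python lists, and KL(S_i ‖ S_j) is summed over every ordered pair of distinct students. The function names are hypothetical, not the authors' code.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Eq. (1): p_i = exp(z_i) / sum_j exp(z_j), shifted by max(z) for numerical stability."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """Eq. (3): KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i = 0 contribute 0."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def mutual_loss(student_logits: Sequence[Sequence[float]]) -> float:
    """Eq. (2): sum of KL(S_i || S_j) over all ordered pairs of distinct students."""
    probs = [softmax(z) for z in student_logits]
    k = len(probs)
    return sum(
        kl_divergence(probs[i], probs[j])
        for i in range(k)
        for j in range(k)
        if j != i
    )


# Three hypothetical students, C = 4 classes, logits for a single sample.
students = [
    [2.0, 0.5, -1.0, 0.1],
    [1.5, 1.0, -0.5, 0.0],
    [2.2, 0.2, -1.2, 0.3],
]
print(f"mutual loss = {mutual_loss(students):.4f}")
```

Because KL-Divergence is asymmetric, summing over ordered pairs counts both KL(S_i ‖ S_j) and KL(S_j ‖ S_i); the loss is zero only when all students agree.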

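For student model S_i and teacher model T, $\mathrm{KL}(b^{T} \| b^{S_1})$ denotes the target-class loss between teacher and student, and $\mathrm{KL}(\hat{p}^{T} \| \hat{p}^{S_1})$ denotes the non-target-class loss between teacher and student, where b is the binary probability vector of the target class versus all non-target classes, and $\hat{p}$ is the probability distribution over the non-target classes only. The DKD formula and the distillation loss can be expressed as:

$$\text{Distillation Loss} = \sum_{i=1}^{K} \mathrm{DKD}(T \,\|\, S_i) \qquad (4)$$

$$\mathrm{DKD}(T \,\|\, S_1) = \alpha \times \mathrm{KL}(b^{T} \,\|\, b^{S_1}) + \beta \times \mathrm{KL}(\hat{p}^{T} \,\|\, \hat{p}^{S_1}) \qquad (5)$$

The final training loss is:

$$\text{Loss} = \sum_{i=1}^{K} \Big[ \sum_{j \neq i} \mathrm{KL}(S_i \,\|\, S_j) + \mathrm{DKD}(T \,\|\, S_i) \Big] \qquad (6)$$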
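A minimal single-sample sketch of the DKD term in Eq. (5) and the total loss in Eq. (6) follows. Assumptions: b is the pair [p_t, 1 − p_t], $\hat{p}$ is the softmax over the non-target logits, no temperature is used, and the α and β defaults are illustrative only because the section does not report their values. DKD also needs a target class index; the section does not say where it comes from in the label-free setting, so the sketch takes it as an argument (using the teacher's argmax would be one hypothetical choice). Function names are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p: Sequence[float], q: Sequence[float]) -> float:
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def dkd(teacher: Sequence[float], student: Sequence[float], target: int,
        alpha: float = 1.0, beta: float = 8.0) -> float:
    """Eq. (5): alpha * KL(b^T || b^S) + beta * KL(p_hat^T || p_hat^S)."""
    pt, ps = softmax(teacher), softmax(student)
    # b: binary distribution [p(target), p(all non-target classes)].
    b_t = [pt[target], 1.0 - pt[target]]
    b_s = [ps[target], 1.0 - ps[target]]
    # p_hat: distribution over the non-target classes only (target logit removed).
    nt_t = softmax([z for c, z in enumerate(teacher) if c != target])
    nt_s = softmax([z for c, z in enumerate(student) if c != target])
    return alpha * kl(b_t, b_s) + beta * kl(nt_t, nt_s)


def total_loss(teacher: Sequence[float], students: Sequence[Sequence[float]],
               target: int) -> float:
    """Eq. (6): sum over students i of [sum_{j != i} KL(S_i || S_j) + DKD(T || S_i)]."""
    probs = [softmax(z) for z in students]
    loss = 0.0
    for i, z_i in enumerate(students):
        mutual = sum(kl(probs[i], probs[j]) for j in range(len(probs)) if j != i)
        loss += mutual + dkd(teacher, z_i, target)
    return loss


teacher = [3.0, 1.0, 0.2, -0.5]
students = [[2.0, 1.5, 0.0, 0.1], [1.0, 2.0, 0.3, -0.2]]
print(f"total loss = {total_loss(teacher, students, target=0):.4f}")

# Decoupling check: with alpha = 1 and beta = 1 - p_t^T, DKD equals plain KL(T || S).
pt = softmax(teacher)
print(abs(kl(pt, softmax(students[0])) - dkd(teacher, students[0], 0, 1.0, 1.0 - pt[0])) < 1e-9)
```

The final check reflects the identity behind DKD: the classical KL loss equals the target-class term plus the non-target term weighted by (1 − p_t^T), so choosing β independently is what lets DKD emphasize non-target information.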

To explore the effects of KL-Divergence and DKD as the distillation loss and the mutual learning loss, ResNet8 × 4 [5] is selected as the student, with ResNet32 × 4 as the teacher, and controlled experiments are conducted on the CIFAR-100 dataset. The cross-entropy loss on the labels is kept fixed during training and combined with either a mutual learning loss or a distillation loss, with Top-1 accuracy as the metric (see Table 1).

Table 1. KL-Divergence and DKD loss comparison results (teacher: ResNet32 × 4; student: ResNet8 × 4)

Distillation Loss | Mutual Loss | Top-1 (%)
KL | - | 64.79
DKD | - | 75.94
- | KL | 74.88
- | DKD | 74.51

These experiments show that DKD works better as the distillation loss between teacher and student, while KL-Divergence works better as the mutual learning loss between students. To analyze this from the spatial distribution of model features, the output logits are reduced in dimension with t-SNE, as shown in Fig. 3. The logit feature spaces of different students show both similarities and differences, and the non-target logits have a more dispersed feature-space distribution that carries more information. Because the DKD loss weights the non-target term separately, it makes better use of this non-target information, so the student iterates more effectively toward the convergence domain.


Fig. 3. Features of the teacher and student models in the target and non-target spaces. Different colors and shapes denote different categories of the dataset.

3.3 Multi-student Self-supervised Distillation

Multi-student distillation is a self-supervised process that uses no supervised loss on dataset labels. It combines two forms of learning. (1) Distillation learning uses DKD to compute losses on the target and non-target logits and update the model. As shown in Fig. 4, the target features iteratively move the student model toward the convergence domain, while the non-target features push it outward from a local convergence domain. (2) Mutual learning relies on the similarity of predictions for the same target: the students use KL-Divergence to pull their predictions toward each other and explore more convergence domains. The process exploits both the similarity among students and the differences between teacher and student features, realizing a self-boosting strategy for multi-domain collaborative models through self-supervised distillation.

3.4 Multi-student Adaptive Inference

Inspired by ensemble learning, the logits output by the multiple models obtained through multi-student distillation are fed into an adapter for inference (see Fig. 5). The adapter is obtained by training with Bagging. Multi-student inference is more effective than single-model inference at the same scale: as the experimental results in Table 3 show, multi-student inference improves on single-model results by an average of 3.14%.
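The section says the adapter is trained with Bagging but does not specify its form. The sketch below substitutes the simplest stand-in, a weighted average of per-student logits followed by argmax: equal weights give a plain ensemble, and any other weights are hypothetical adapter outputs, not the authors' trained adapter.

```python
from typing import List, Optional, Sequence


def fuse_logits(student_logits: Sequence[Sequence[float]],
                weights: Optional[Sequence[float]] = None) -> List[float]:
    """Weighted average of per-student logits for one sample.

    Stand-in for the learned adapter: equal weights give a plain ensemble average;
    other weights are hypothetical adapter outputs.
    """
    k = len(student_logits)
    if weights is None:
        weights = [1.0 / k] * k
    if len(weights) != k:
        raise ValueError("need exactly one weight per student")
    num_classes = len(student_logits[0])
    return [sum(w * z[c] for w, z in zip(weights, student_logits))
            for c in range(num_classes)]


def predict(student_logits: Sequence[Sequence[float]],
            weights: Optional[Sequence[float]] = None) -> int:
    """Class index with the highest fused logit."""
    fused = fuse_logits(student_logits, weights)
    return max(range(len(fused)), key=fused.__getitem__)


s1 = [2.0, 1.9, -1.0]  # student 1 narrowly prefers class 0
s2 = [0.5, 2.5, 0.0]   # student 2 clearly prefers class 1
print(predict([s1]))      # 0
print(predict([s1, s2]))  # fused logits [1.25, 2.2, -0.5] -> 1
```

Fusing at the logit level lets a confident student outvote a hesitant one, which is one plausible reason a combination of students can beat each individual student.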


Fig. 4. Self-supervised Distillation schematic. The arrows represent the corresponding loss functions in the learning process.

Fig. 5. Multi-student Adaptive Inference

4 Experiment

4.1 Datasets and Metrics

We focus on two fields, image classification and text detection, using the following datasets. CIFAR100 [13]: an image classification dataset with 100 categories, 50,000 training images, and 10,000 test images. ICDAR2015 [11]: a competition dataset from the International Conference on Document Analysis and Recognition (ICDAR), consisting mainly of text in street scenes, with 1,000 training images and 500 test images. Image classification uses Top-1 accuracy as the metric. Text detection counts a detection as a positive example when its IoU with a ground-truth region is at least 0.5 (IoU50), and uses Precision (P), Recall (R), and F1-Score (F1) as evaluation metrics.
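The IoU ≥ 0.5 criterion and the P/R/F1 metrics can be illustrated with the sketch below. It assumes axis-aligned boxes (x1, y1, x2, y2) and greedy one-to-one matching; the official ICDAR2015 protocol uses quadrilaterals and its own matching rules, so this is only an approximation for intuition.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2), axis-aligned


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def precision_recall_f1(preds: List[Box], gts: List[Box], thr: float = 0.5):
    """Greedily match each prediction to its best unmatched ground truth (IoU >= thr)."""
    matched = set()
    tp = 0
    for p in preds:
        best_j, best_iou = -1, 0.0
        for j, g in enumerate(gts):
            if j in matched:
                continue
            v = iou(p, g)
            if v > best_iou:
                best_j, best_iou = j, v
        if best_j >= 0 and best_iou >= thr:
            matched.add(best_j)
            tp += 1
    precision = tp / len(preds) if preds else 0.0
    recall = tp / len(gts) if gts else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return precision, recall, f1


gts = [(0, 0, 10, 10), (20, 0, 30, 10)]
preds = [(0, 0, 10, 10), (21, 0, 31, 10), (50, 50, 60, 60)]  # last one is a false positive
p, r, f = precision_recall_f1(preds, gts)
print(f"precision={p:.3f} recall={r:.3f} f1={f:.3f}")  # precision=0.667 recall=1.000 f1=0.800
```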


4.2 Distillation with Homogeneous and Heterogeneous Students

To investigate how homogeneous and heterogeneous students affect distillation, comparative experiments were run with various model architectures. As shown in Table 2, the greater the structural difference between the student models, the higher the Top-1 accuracy of the individual student. With homogeneous students, the individual results were not the best, but combining the logits of the two models gave the highest Top-1 accuracy.

Table 2. Comparison of different model architectures on CIFAR100 (Top-1, %). Mix is the result of combining the two students by ensemble learning.

Teacher | Student 1 | Top-1 | Student 2 | Top-1 | Mix
ResNet32 × 4 | ResNet8 × 4 | 76.60 | ResNet8 × 4 | 76.58 | 78.15
ResNet32 × 4 | ResNet8 × 4 | 76.51 | ResNet32 | 73.52 | 77.48
ResNet32 × 4 | ResNet8 × 4 | 76.63 | WRN-16-2 | 75.62 | 77.99
ResNet32 × 4 | ResNet8 × 4 | 76.82 | WRN-40-1 | 74.65 | 78.02

4.3 Comparison with Other Methods

To verify the effectiveness of the proposed method, comparison experiments were conducted on CIFAR100 against FitNet [19], RKD [17], CRD [21], OFD [6], ReviewKD [1], KD [8], and DKD [28], using ResNet, WRN, and VGG networks. The teacher models were pre-trained, the two students were homogeneous networks, and the models were initialized with Kaiming initialization [4]. As shown in Table 3, in the single-student comparison the proposed method achieves the best accuracy when distilling ResNet and VGG architectures, with an average improvement of 0.4 percentage points. In the experiment with WRN-16-2 as the student, it trails the best method by 0.48 percentage points. After collaborative inference of the student models, it achieves the highest accuracy in all of the ResNet, WRN, and VGG control experiments, with an average improvement of 1.4 percentage points.


Table 3. Comparative results of different methods on CIFAR100 (Top-1, %). Column headers give teacher → student.

Distill manner | Method | ResNet56 → ResNet20 | ResNet110 → ResNet32 | ResNet32 × 4 → ResNet8 × 4 | WRN40-2 → WRN16-2 | WRN40-2 → WRN40-1 | VGG13 → VGG8
features | FitNet | 69.21 | 71.06 | 73.50 | 73.58 | 72.24 | 71.02
features | RKD | 69.61 | 71.82 | 71.90 | 73.35 | 72.22 | 71.48
features | CRD | 71.16 | 73.48 | 75.51 | 75.48 | 74.14 | 73.94
features | OFD | 70.98 | 73.23 | 74.95 | 75.24 | 74.33 | 73.95
features | ReviewKD | 71.98 | 73.89 | 75.63 | 76.12 | 75.09 | 74.84
logits | KD | 70.66 | 73.08 | 73.33 | 74.92 | 73.54 | 72.98
logits | DKD | 71.97 | 74.11 | 76.32 | 76.24 | 74.81 | 74.68
logits | OURS | 72.21 | 74.77 | 76.60 | 75.76 | 75.49 | 75.05
logits | OURS MIX | 73.68 | 75.91 | 78.15 | 77.10 | 76.69 | 76.40

4.4 Single-Student and Multi-Student Distillation Comparison

Comparative experiments were conducted on CIFAR100 with ResNet and WRN architectures. The selected models are listed in Table 4; single-model and dual-model configurations are paired so that their parameter counts are close. The results are shown in Fig. 6. Among models with similar parameter counts, collaborative inference with the dual model is on average 3.14 percentage points higher in Top-1 accuracy than the single model.

Table 4. Number of parameters of the networks

Single Student | ResNet20 | ResNet32 | ResNet8 × 4 | WRN16_1 | WRN40_1 | WRN16_2
Parameters | 0.56M | 0.94M | 2.46M | 0.6M | 1.14M | 1.8M
Multi Students | ResNet44 | ResNet56 | ResNet152 | WRN34_1 | WRN50_1 | WRN40_2
Parameters | 0.67M | 1.02M | 2.56M | 0.81M | 1.26M | 2.26M

Fig. 6. Comparison curves of the single model and the dual model on CIFAR100 (panels: ResNet, WRN). stu_one denotes the single-student result and stu_mix the multi-student collaboration result; the higher the curve toward the left, the better the result.

4.5 Distillation in Text Detection

To explore whether the method is useful beyond image classification, comparative experiments were conducted in text detection on ICDAR2015, using DBNet [14] as the detection model, with ResNet18 as the backbone of the teacher model and MobileNetV3 [9] as the backbone of the student model. The results are shown in Table 5.

Table 5. Comparison of backbones on ICDAR2015

Method | Teacher Backbone | Student Backbone | P | R | F1
DKD | ResNet18 | MobileNetV3 | 50.13 | 26.85 | 34.97
OURS | ResNet18 | MobileNetV3 | 53.23 | 32.16 | 40.09

As shown in Table 5, compared with DKD, multi-student distillation improves text detection by 3.1 percentage points in P, 5.31 percentage points in R, and 5.12 percentage points in the combined F1 index.

4.6 Ablation

The ablation results show that two student models that only learn from each other, without a teacher model as a guide, have difficulty converging, and that without mutual learning between the students the models cannot explore the best convergence domain. Combining student mutual learning with teacher distillation improves accuracy by 1.57 percentage points over teacher distillation alone, which demonstrates the effectiveness of the method (Table 6).

Table 6. Results of ablation (teacher: ResNet32 × 4; student: ResNet8 × 4)

Distillation Loss | Mutual Loss | Top-1 (%)
√ | × | 75.03
× | √ | 5.15
√ | √ | 76.60

5 Discussion and Conclusion

Collaborative multi-student distillation has several advantages over single teacher-student distillation. First, the teacher distills knowledge to the students with decoupled KL-Divergence as the distillation loss, which makes use of the more informative non-target features for iterative updates. Second, by merging the knowledge of the student models in the inference stage, the accuracy of the model can be improved without increasing the computational cost; collaborative inference does not increase model complexity because it only integrates multiple small student models rather than adding more layers or neurons. In knowledge distillation research, determining the best knowledge to distill and the best student-teacher structure, and measuring the proximity between students and teachers, remain challenges to be addressed as the field moves toward applications. Furthermore, collaborative multi-student self-supervised distillation is not the only way to reduce model parameters. How multi-student distillation can stand out among competing approaches for reducing model parameters, or be combined with other approaches to form more effective solutions, remains a question well worth investigating.

Acknowledgement. This work was supported by the National Natural Science Foundation of China (62271359).

References

1. Chen, P., Liu, S., Zhao, H., Jia, J.: Distilling knowledge via knowledge review. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 5008–5017 (2021)
2. Cho, J.H., Hariharan, B.: On the efficacy of knowledge distillation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4794–4802 (2019)
3. Furlanello, T., Lipton, Z., Tschannen, M., Itti, L., Anandkumar, A.: Born again neural networks. In: International Conference on Machine Learning, pp. 1607–1616. PMLR (2018)
4. He, K., Zhang, X., Ren, S., Sun, J.: Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 1026–1034 (2015)
5. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
6. Heo, B., Kim, J., Yun, S., Park, H., Kwak, N., Choi, J.Y.: A comprehensive overhaul of feature distillation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 1921–1930 (2019)
7. Heo, B., Lee, M., Yun, S., Choi, J.Y.: Knowledge transfer via distillation of activation boundaries formed by hidden neurons. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, pp. 3779–3787 (2019)
8. Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531 (2015)
9. Howard, A., et al.: Searching for MobileNetV3. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 1314–1324 (2019)
10. Huang, Z., Wang, N.: Like what you like: Knowledge distill via neuron selectivity transfer. arXiv preprint arXiv:1707.01219 (2017)
11. Karatzas, D., et al.: ICDAR 2015 competition on robust reading. In: 2015 13th International Conference on Document Analysis and Recognition (ICDAR), pp. 1156–1160. IEEE (2015)
12. Kim, J., Park, S., Kwak, N.: Paraphrasing complex network: network compression via factor transfer. In: Advances in Neural Information Processing Systems, vol. 31 (2018)
13. Krizhevsky, A., Hinton, G., et al.: Learning multiple layers of features from tiny images (2009)
14. Liao, M., Wan, Z., Yao, C., Chen, K., Bai, X.: Real-time scene text detection with differentiable binarization. Proc. AAAI Conf. Artif. Intell. 34(07), 11474–11481 (2020). https://doi.org/10.1609/aaai.v34i07.6812
15. Loshchilov, I., Hutter, F.: Decoupled weight decay regularization. arXiv preprint arXiv:1711.05101 (2017)
16. Mirzadeh, S.I., Farajtabar, M., Li, A., Levine, N., Matsukawa, A., Ghasemzadeh, H.: Improved knowledge distillation via teacher assistant. Proc. AAAI Conf. Artif. Intell. 34(04), 5191–5198 (2020). https://doi.org/10.1609/aaai.v34i04.5963
17. Park, W., Kim, D., Lu, Y., Cho, M.: Relational knowledge distillation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3967–3976 (2019)
18. Peng, B., et al.: Correlation congruence for knowledge distillation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 5007–5016 (2019)
19. Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C., Bengio, Y.: FitNets: Hints for thin deep nets. arXiv preprint arXiv:1412.6550 (2014)
20. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556 (2014)
21. Tian, Y., Krishnan, D., Isola, P.: Contrastive representation distillation. arXiv preprint arXiv:1910.10699 (2019)
22. Tung, F., Mori, G.: Similarity-preserving knowledge distillation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 1365–1374 (2019)
23. Yang, C., Xie, L., Su, C., Yuille, A.L.: Snapshot distillation: teacher-student optimization in one generation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2859–2868 (2019)
24. Yim, J., Joo, D., Bae, J., Kim, J.: A gift from knowledge distillation: Fast optimization, network minimization and transfer learning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4133–4141 (2017)
25. Zagoruyko, S., Komodakis, N.: Paying more attention to attention: improving the performance of convolutional neural networks via attention transfer. arXiv preprint arXiv:1612.03928 (2016)
26. Zagoruyko, S., Komodakis, N.: Wide residual networks. arXiv preprint arXiv:1605.07146 (2016)
27. Zhang, Y., Xiang, T., Hospedales, T.M., Lu, H.: Deep mutual learning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4320–4328 (2018)
28. Zhao, B., Cui, Q., Song, R., Qiu, Y., Liang, J.: Decoupled knowledge distillation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11953–11962 (2022)

Speech Emotion Recognition Using Global-Aware Cross-Modal Feature Fusion Network

Feng Li and Jiusong Luo

Department of Computer Science and Technology, Anhui University of Finance and Economics, Anhui, China

Abstract. Speech emotion recognition (SER) facilitates better interpersonal communication. In conversation, emotion is usually conveyed in several forms, such as speech and text. However, existing emotion recognition systems use features of only a single modality, ignoring the interaction of multimodal information. We therefore propose a global-aware cross-modal feature fusion network for recognizing emotions in conversations. We introduce a residual cross-modal fusion attention module (ResCMFA) and a global-aware block to fuse information from multiple modalities and capture global information. Specifically, we first use transfer learning to extract wav2vec 2.0 features and text features, which are fused by the ResCMFA module. The multimodal features are then fed into the global-aware block to capture the most important emotional information on a global scale. Finally, extensive experiments on the IEMOCAP dataset show that the proposed algorithm has significant advantages over state-of-the-art methods.

Keywords: speech emotion recognition · global-aware · attention · wav2vec 2.0

1 Introduction

Speech, as the primary attribute of language, plays a decisive supporting role in it. Speech carries not only linguistic information but also the emotional message intended by the speaker [1]. The same text can be expressed very differently depending on the emotion behind it. Because emotion matters so much in communication, speech emotion recognition (SER) has attracted growing interest [2–4]. As more people interact with voice assistants such as Alexa, Siri, Google Assistant, and Cortana, these systems must infer the user's emotions and respond appropriately to improve the user experience [5]. However, humans express emotions not only through voice but also through writing, body gestures, and facial expressions [6, 7]. Understanding the emotion conveyed in an utterance therefore requires a thorough understanding of multiple modalities. Speech emotion recognition models human perception by using audio inputs to infer emotion categories, and speech is one of the most researched modalities in emotion recognition (ER). In the early stages, researchers proposed machine learning models based on Support Vector Machines (SVM) [8, 9], Gaussian mixture models [9], and Hidden Markov models [10]. These models were built on engineered features such as Mel Frequency Cepstral Coefficients (MFCC), energy, and pitch [11–13], so their performance depends heavily on the features used. Although these established emotion classification methods proved quite effective, researchers continue to look for new features and algorithms that fit different emotional expressions.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 211–221, 2023. https://doi.org/10.1007/978-981-99-4742-3_17

As deep learning techniques have advanced, many striking deep neural network results have appeared in SER [14, 15]. Hand-crafted features have been replaced by high-level features learned from the raw waveform [16–18]. Han et al. [16] pioneered deep learning in SER by extracting high-level features from raw audio with deep neural networks (DNNs). To capture the most significant emotional information across multiple scales, Zhu et al. [19] used a global-aware fusion block. Transfer learning is also widely used in SER: many techniques use pre-trained self-supervised models to address downstream speech-processing tasks, including phoneme recognition and automatic speech recognition (ASR) [20]. Audio and text information come from different modalities. Although the two are correlated, fusing them is difficult, so integrating information across modalities is the key issue [21, 22]. As attention mechanisms have proliferated, many researchers have tried to fuse information from different modalities with different attention mechanisms, such as self-attention and cross-modal attention [23–25]. These two kinds of mechanisms target different inputs.
For unimodal emotion recognition, self-attention mechanisms are commonly used, whereas cross-modal attention mechanisms are used for multimodal emotion recognition. In this paper, we therefore present a global-aware cross-modal feature fusion network (Fig. 1) that extracts features from different modalities and combines a cross-modal attention mechanism with a global-aware fusion block. First, we use pre-trained models as encoders for the different modalities: the audio encoder is the wav2vec 2.0 model, and the text encoder is the RoBERTa-base model. Second, we build the residual cross-modal fusion attention module (ResCMFA) from cross-modal attention (CMA) to combine wav2vec 2.0 and text features. To build emotionally informative representations, this module maps the information of each modality into the latent feature space of the other. Third, we present a global-aware fusion block that extracts the important emotional information from the multimodal fusion features. Finally, to reduce text contextual bias, we use ASR as an auxiliary task, which better accounts for the natural monotonic alignment between audio and text.

2 Proposed Methodology

2.1 Problem Statement

The dataset D contains k utterances u_i with corresponding labels l_i. Each utterance consists of a speech segment a_i and a text transcript t_i, so that u_i = (a_i, t_i), where t_i is either an ASR transcript or a human-annotated transcript. The proposed network takes u_i as input and assigns the correct emotion to the utterance:

$$\{U, L\} = \{\, (u_i = (a_i, t_i),\ l_i) \mid i \in [1, k] \,\} \qquad (1)$$

Fig. 1. The architecture of the proposed method. (Figure components: raw waveform → pretrained wav2vec 2.0; text sentences → RoBERTa-base; ResCMFA module; global-aware block; FC → predicted label, compared with the emotion label; FC → characters, compared with the transcription.)

2.2 Feature Encoder

Speech Representation. As the audio encoder, we use a pre-trained wav2vec 2.0 model to encode speech into a high-level contextual representation of the time-domain information. The wav2vec 2.0 model is built on a Transformer structure that models speech audio sequences with a set of ASR modeling units shorter than phonemes. Comparing the two versions of the wav2vec 2.0 model, we selected the wav2vec2-base architecture, with a feature dimension of 768, as the feature extractor. Each utterance is encoded and then fed into the proposed model. Specifically, the audio data a_i of the i-th utterance is input into the pre-trained wav2vec 2.0 model to obtain the contextual embedding $e_i^a \in \mathbb{R}^{j \times D_W}$, where D_W is the size of the audio feature embedding:

$$e_i^a = F_{\text{wav2vec 2.0}}(a_i), \quad i \in [1, k], \quad e_i^a \in \mathbb{R}^{j \times D_W} \qquad (2)$$

where $F_{\text{wav2vec 2.0}}$ denotes the pre-trained wav2vec 2.0 model acting as the audio feature extractor. In the wav2vec 2.0 framework, j is determined by the length of the raw audio and by the CNN feature extraction layer, which extracts frames from the raw audio with a stride of 20 ms and a receptive field of 25 ms.

Contextualized Word Representation. Rich text information facilitates emotion recognition. To obtain rich contextual token representations, we input the text data t_i into the RoBERTa-base model for encoding. Before feature extraction, the input text is tokenized and special start and end tokens are added. In addition, we fine-tuned the labeling on the tokenized text data and the corresponding utterances. The extracted contextual embedding $e_i^t$ can be expressed as follows:

$$e_i^t = F_{\text{RoBERTa-base}}(t_i), \quad i \in [1, k], \quad e_i^t \in \mathbb{R}^{m \times D_T} \qquad (3)$$

where $F_{\text{RoBERTa-base}}$ denotes the text feature extraction module, m depends on the number of tokens in the text, and D_T is the dimension of the text feature embedding.

2.3 Residual Cross-Modal Fusion Attention Module

The proposed module consists of two parallel fusion blocks that target different modalities, labeled A-T ResCMFA block and T-A ResCMFA block in Fig. 2. It realizes multimodal information interaction through a multi-head attention mechanism. The audio information (a_i) and text information (t_i) are passed through the corresponding pre-trained feature extractors to produce audio features and text features. The two ResCMFA blocks differ in which features serve as the Query, Key, and Value of the multi-head attention. The A-T ResCMFA block uses the audio features ($e_i^a$) as the Query and the text features ($e_i^t$) as the Key and Value, whereas the T-A ResCMFA block uses the text features ($e_i^t$) as the Query and the audio features ($e_i^a$) as the Key and Value. The audio and text features first interact through multi-head attention; the result then passes through a feed-forward layer; finally, it is added to the block's original audio or text features through a residual connection:

$$F^{i}_{\text{fusion}_1} = \Phi_1(e_i^a, e_i^t) \qquad (4)$$

where $\Phi_1$ denotes the learning function of the first A-T or T-A ResCMFA block. The output of the first ResCMFA block, together with the original audio or text features, is then passed into the second ResCMFA block. Stacking m ResCMFA blocks in this way produces the multimodal fusion features $F^{i}_{A\text{-}T_m}$ and $F^{i}_{T\text{-}A_m}$:

$$\text{A-T}: \quad F^{i}_{A\text{-}T_m} = \Phi_m(\ldots(\Phi_2(F^{i}_{\text{fusion}_1}, e_i^t))) \qquad (5)$$

$$\text{T-A}: \quad F^{i}_{T\text{-}A_m} = \Phi_m(\ldots(\Phi_2(F^{i}_{\text{fusion}_1}, e_i^a))) \qquad (6)$$

To improve the integration of multimodal information, the Key and Value of every ResCMFA block are always the module's original audio and text features. The final multimodal fusion feature is obtained by concatenating the outputs of the last two ResCMFA blocks:

$$F^{\text{multimodal}}_{\text{fusion}} = \mathrm{Concat}(F^{i}_{A\text{-}T_m}, F^{i}_{T\text{-}A_m}) \qquad (7)$$

Fig. 2. The architecture of the ResCMFA module. (Figure components: parallel stacks of A-T ResCMFA Blocks 1 to N and T-A ResCMFA Blocks 1 to N; in the A-T stack the audio features (A) provide the Query (Q) and the text features (T) provide the Key and Value (K,V), and in the T-A stack the roles are reversed; the outputs of the two stacks are concatenated.)
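A single ResCMFA block and a stack of such blocks (Eqs. (4) to (6)) can be sketched as below. Simplifying assumptions: one attention head instead of multi-head attention, no learned Query/Key/Value projections, both modalities already share the same feature dimension (wav2vec2-base and RoBERTa-base both produce 768-dimensional features), and the feed-forward layer is a hypothetical callable. The section uses m both for the number of text tokens and for the number of blocks, so the code calls the latter num_blocks. The concatenation in Eq. (7) is omitted because the section does not state along which axis the two outputs, of lengths j and m, are joined.

```python
import math
from typing import Callable, List, Sequence

Matrix = List[List[float]]  # rows: time steps or tokens; columns: feature dims
FeedForward = Callable[[List[float]], List[float]]


def softmax(row: Sequence[float]) -> List[float]:
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(query: Matrix, key_value: Matrix) -> Matrix:
    """Single-head scaled dot-product attention: Query from one modality,
    Key and Value from the other (no learned projections in this sketch)."""
    d = len(query[0])
    out = []
    for q in query:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in key_value]
        weights = softmax(scores)
        out.append([sum(w * v[c] for w, v in zip(weights, key_value)) for c in range(d)])
    return out


def rescmfa_block(query: Matrix, key_value: Matrix, ffn: FeedForward) -> Matrix:
    """Attention -> feed-forward -> residual addition of the block's Query input."""
    attended = cross_attention(query, key_value)
    return [[x + y for x, y in zip(q_row, ffn(a_row))]
            for q_row, a_row in zip(query, attended)]


def rescmfa_stack(query: Matrix, key_value: Matrix, num_blocks: int,
                  ffn: FeedForward) -> Matrix:
    """Each block's output becomes the next Query, while Key/Value stay
    the original features of the other modality."""
    h = query
    for _ in range(num_blocks):
        h = rescmfa_block(h, key_value, ffn)
    return h


def toy_ffn(v: List[float]) -> List[float]:
    """Hypothetical feed-forward layer: ReLU followed by a fixed 0.5 scale."""
    return [0.5 * max(0.0, x) for x in v]


audio = [[0.1, 0.3], [0.4, -0.2], [0.0, 0.5]]  # j = 3 frames, toy dimension 2
text = [[0.2, 0.1], [-0.3, 0.6]]               # 2 tokens, same dimension
a_t = rescmfa_stack(audio, text, num_blocks=4, ffn=toy_ffn)  # A-T branch: 3 x 2
t_a = rescmfa_stack(text, audio, num_blocks=4, ffn=toy_ffn)  # T-A branch: 2 x 2
print(len(a_t), len(a_t[0]), len(t_a), len(t_a[0]))  # 3 2 2 2
```

Each branch keeps the sequence length of its Query modality, which is why the A-T output follows the audio frames and the T-A output follows the text tokens.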

2.4 Global-Aware Fusion

The global-aware block consists of two fully connected layers, a convolutional layer, two normalization layers, a GeLU activation function, and a multiplication operation. After the projection and GeLU activation, the output of dimension 2D_f is split along the feature dimension, and the multiplication operation enhances feature mixing across dimensions. Finally, the output of the global-aware block is used for classification. The corresponding equations are:

$$F_{\text{global-aware}} = \Phi_{\text{global-aware}}(F^{\text{multimodal}}_{\text{fusion}}) \qquad (8)$$

$$y_i = \mathrm{FC}(F_{\text{global-aware}}), \quad y_i \in \mathbb{R}^{C} \qquad (9)$$

where $\Phi_{\text{global-aware}}$ denotes the mapping of the multimodal fusion features through the global-aware block, and C is the number of emotion categories.

2.5 CTC Layer

In this study, a CTC loss is used so that gradients are back-propagated effectively. We compute the CTC loss from the wav2vec 2.0 features $e_i^a$ and the text transcription t_i:

$$y_i^a = \mathrm{FC}(e_i^a) \qquad (10)$$

where $y_i^a \in \mathbb{R}^{j \times V}$ and V = 32 is the size of our vocabulary, consisting of the 26 letters of the alphabet and a few punctuation marks:

$$L_{\text{CTC}} = \mathrm{CTC}(y_i^a, t_i) \qquad (11)$$
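Eq. (10) produces one V-dimensional output per wav2vec 2.0 frame, so CTC can only align a transcript when there are enough frames. The sketch below computes j for a raw waveform and checks the standard CTC length condition (at least one frame per label, plus one extra blank frame between identical adjacent labels). It assumes 16 kHz audio and the published wav2vec 2.0 feature-encoder configuration of seven 1-D convolutions with kernel sizes 10, 3, 3, 3, 3, 2, 2 and strides 5, 2, 2, 2, 2, 2, 2, which gives the 25 ms receptive field and 20 ms stride; the helper names are hypothetical.

```python
from typing import Sequence, Tuple

# (kernel, stride) of the seven 1-D convolutions in the wav2vec 2.0 feature encoder,
# as in the published base configuration (assumed unchanged in this paper).
CONV_LAYERS: Tuple[Tuple[int, int], ...] = ((10, 5),) + ((3, 2),) * 4 + ((2, 2),) * 2


def num_frames(num_samples: int,
               layers: Sequence[Tuple[int, int]] = CONV_LAYERS) -> int:
    """Number of output frames j for a waveform of num_samples (no padding)."""
    n = num_samples
    for kernel, stride in layers:
        n = (n - kernel) // stride + 1
        if n <= 0:
            return 0
    return n


def min_ctc_frames(label: Sequence) -> int:
    """Fewest frames CTC needs: one per label, plus a blank between equal neighbours."""
    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)
    return len(label) + repeats


def ctc_feasible(num_samples: int, label: Sequence) -> bool:
    return num_frames(num_samples) >= min_ctc_frames(label)


print(num_frames(16000))        # 1 s at 16 kHz -> 49 frames
print(min_ctc_frames("hello"))  # 5 characters + 1 repeated pair ("ll") -> 6
```

The product of the strides is 320 samples (20 ms at 16 kHz), and the smallest input that yields one frame is 400 samples (25 ms), matching the frame parameters given above.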

In addition, the cross-entropy loss is computed from the output y_i of the global-aware block and the true emotion label l_i:

$$L_{\text{CrossEntropy}} = \mathrm{CrossEntropy}(y_i, l_i) \qquad (12)$$

Finally, a hyperparameter α combines the two losses into a single loss and controls the relative importance of the CTC loss:

$$L = L_{\text{CrossEntropy}} + \alpha L_{\text{CTC}}, \quad \alpha \in [0, 1] \qquad (13)$$
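Eqs. (11) to (13) can be sketched with a plain-Python CTC forward algorithm. This is an illustrative implementation, not the authors' training code: it assumes per-frame log-probabilities (for example a log-softmax over y_i^a), blank index 0, integer-encoded transcripts, and a default α = 0.1 chosen only as an example value.

```python
import math
from typing import Sequence

NEG_INF = float("-inf")


def logsumexp(*xs: float) -> float:
    m = max(xs)
    if m == NEG_INF:
        return NEG_INF
    return m + math.log(sum(math.exp(x - m) for x in xs))


def ctc_loss(log_probs: Sequence[Sequence[float]], label: Sequence[int],
             blank: int = 0) -> float:
    """Eq. (11): -log p(label | frames) via the CTC forward algorithm in log space.

    log_probs: T x V per-frame log-probabilities (e.g. log-softmax of y_i^a).
    """
    ext = [blank]
    for c in label:
        ext += [c, blank]  # blank-interleaved label: _ c1 _ c2 _ ...
    S = len(ext)
    fwd = [NEG_INF] * S
    fwd[0] = log_probs[0][blank]
    if S > 1:
        fwd[1] = log_probs[0][ext[1]]
    for frame in log_probs[1:]:
        new = [NEG_INF] * S
        for s in range(S):
            cands = [fwd[s]]
            if s >= 1:
                cands.append(fwd[s - 1])
            # Skipping a blank is allowed only between two different labels.
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                cands.append(fwd[s - 2])
            new[s] = logsumexp(*cands) + frame[ext[s]]
        fwd = new
    total = logsumexp(fwd[-1], fwd[-2]) if S > 1 else fwd[-1]
    return -total


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Eq. (12) for one utterance: log-sum-exp(logits) - logits[target]."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]


def joint_loss(emotion_logits: Sequence[float], emotion_label: int,
               frame_log_probs: Sequence[Sequence[float]], transcript: Sequence[int],
               alpha: float = 0.1) -> float:
    """Eq. (13): L = L_CrossEntropy + alpha * L_CTC."""
    return (cross_entropy(emotion_logits, emotion_label)
            + alpha * ctc_loss(frame_log_probs, transcript))


# Toy check: 3 frames, uniform over {blank=0, 1, 2}. Six of the 27 paths collapse
# to [1] (1__, _1_, __1, 11_, _11, 111), so p = 6/27.
uniform = [[math.log(1 / 3)] * 3 for _ in range(3)]
print(math.isclose(ctc_loss(uniform, [1]), -math.log(6 / 27)))  # True
```

In practice a framework's built-in implementation, such as PyTorch's `torch.nn.CTCLoss`, would be used; the loop above only makes the forward recursion visible.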

3 Experimental Evaluation

3.1 Dataset

We trained and evaluated the proposed method on the IEMOCAP dataset, a standard benchmark for ER. To evaluate our model, we use five emotions: happy, angry, neutral, sad, and excited [21]. Because happy and excited are similar, the excited samples are relabeled as happy, and all experiments use the resulting emotion categories. Following the five-fold cross-validation principle, the dataset is divided into a training set (80%) and a test set (20%).


3.2 Experimental Setup

To investigate the benefits of multimodality, we build two unimodal baselines, one for speech and one for text. The text baseline uses RoBERTa-base as the contextualized text encoder, followed by a single linear layer with a softmax activation for classification. The speech baseline uses the same configuration, except that the text encoder is replaced with a pre-trained wav2vec 2.0 model. Table 1 lists the key hyperparameter settings. Because the class distribution is unbalanced, we use weighted accuracy (WA) and unweighted accuracy (UA) as evaluation metrics: WA is the overall accuracy, and UA is the average of the per-class accuracies. They are computed as in Eq. 14:

$$\mathrm{WA} = \frac{\sum_{i=1}^{k} n_i}{\sum_{i=1}^{k} N_i}, \qquad \mathrm{UA} = \frac{1}{k} \sum_{i=1}^{k} \frac{n_i}{N_i} \qquad (14)$$

Here, N_i is the number of utterances in the i-th class, n_i is the number of correctly recognized utterances in the i-th class, and k is the number of classes.

Table 1. The hyper-parameter settings

Parameter | Value
Batch size | 2
Gradient accumulation steps | 4
Epochs | 100
α | 0; 0.001; 0.01; 0.1; 1
Optimizer | Adam
Learning rate | 1 × 10^-5
Loss function | CrossEntropy, CTC
Evaluation metrics | WA and UA
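Eq. (14) in code, with hypothetical class counts that are not taken from the paper, shows how WA and UA diverge on unbalanced data: the largest class dominates WA, while UA weights every class equally.

```python
from typing import Sequence


def weighted_accuracy(correct: Sequence[int], totals: Sequence[int]) -> float:
    """WA: overall accuracy, sum_i n_i / sum_i N_i."""
    return sum(correct) / sum(totals)


def unweighted_accuracy(correct: Sequence[int], totals: Sequence[int]) -> float:
    """UA: mean per-class accuracy, (1/k) * sum_i n_i / N_i."""
    return sum(n / N for n, N in zip(correct, totals)) / len(totals)


# Hypothetical per-class counts for k = 4 classes (not from the paper).
totals = [100, 50, 30, 20]   # N_i
correct = [90, 40, 15, 10]   # n_i
print(round(weighted_accuracy(correct, totals), 3))    # 0.775
print(round(unweighted_accuracy(correct, totals), 3))  # 0.675
```

Here the two small classes are recognized only half the time, which barely moves WA but pulls UA down by ten points.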

3.3 Ablation Studies To better understand the role of each component of the proposed model, we conducted several ablation studies on the IEMOCAP dataset, using WA and UA as the evaluation metrics. We first train the network with only audio features or only text features as input, without any modality fusion. Table 2 shows that combining the two modalities draws on the strengths of both and considerably improves the emotion recognition rate compared with either modality alone. Audio features alone perform better than text features alone. We attribute this to fine-tuning: we fine-tuned the pre-trained wav2vec 2.0 model but not the pre-trained RoBERTa-base model.

F. Li and J. Luo

Furthermore, we study the effect of the global-aware block on the proposed model. Table 3 shows that inserting the global-aware block improves WA and UA by 1.09 and 1.07 percentage points, respectively. This suggests that the global-aware block gathers more relevant emotional information and thereby improves the model's performance.

Table 2. Comparison of results in different modalities
Models | WA | UA
Only RoBERTa-base (baseline) | 69.89% | 69.27%
Only wav2vec 2.0 (baseline) | 78.66% | 79.76%
RoBERTa-base + wav2vec 2.0 | 82.01% | 82.80%

Table 3. Comparison of results with and without the global-aware block
Models | WA | UA
w/o global-aware block | 80.92% | 81.73%
Ours | 82.01% | 82.80%

In addition, we conducted ablation experiments on the ResCMFA module. The model contains two kinds of ResCMFA blocks placed in parallel, so we test how the number of ResCMFA block layers m affects the proposed model. Table 4 shows that the model performs best with four layers of ResCMFA blocks (m = 4). At m = 5, accuracy declines. We therefore choose m = 4.

Table 4. Comparison of results for different numbers of ResCMFA modules
m | WA | UA
1 | 79.29% | 79.92%
2 | 79.56% | 80.79%
3 | 81.10% | 82.16%
4 | 82.01% | 82.80%
5 | 80.47% | 80.95%

Finally, the hyperparameter α controls the strength of the CTC loss. We therefore vary α from 0 to 1 (0, 0.001, 0.01, 0.1, and 1) to examine its effect.

Speech Emotion Recognition Using Global-Aware


Table 5 shows the influence of different values of α on our best model. The benefit of the CTC loss is greatest at α = 0.1, which gives 82.01% WA and 82.80% UA. At α = 1, however, the auxiliary task lowers the model's recognition rate below the rate obtained without the CTC loss (α = 0).

Table 5. Comparison of results for different values of α
α | WA | UA
0 | 81.10% | 81.56%
0.001 | 81.65% | 81.88%
0.01 | 80.47% | 81.19%
0.1 | 82.01% | 82.80%
1 | 76.22% | 77.05%
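Table 5 varies the weight α of the auxiliary CTC term. The sketch below assumes the training objective has the common form L = L_CE + α·L_CTC. The exact formulation is defined earlier in the paper and is not reproduced here. In plain Python, the sketch computes both terms:

- the cross-entropy of the emotion logits;
- the CTC negative log-likelihood of an ASR transcript, using the standard forward recursion over the blank-extended label sequence.

The function names, the blank index 0 and the toy inputs are illustrative assumptions. A real implementation would use a framework routine such as PyTorch's CTCLoss on batched tensors.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) that tolerates -inf entries."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def cross_entropy(logits, label):
    """Emotion-classification loss for one utterance: -log softmax(logits)[label]."""
    return logsumexp(logits) - logits[label]


def ctc_loss(log_probs, target, blank=0):
    """CTC negative log-likelihood of `target` given T x V per-frame log-probabilities."""
    ext = [blank]
    for label in target:
        ext += [label, blank]  # blank-extended sequence of length 2U + 1
    S = len(ext)
    fwd = [-math.inf] * S      # forward log-probabilities after the current frame
    fwd[0] = log_probs[0][blank]
    if S > 1:
        fwd[1] = log_probs[0][ext[1]]
    for frame in log_probs[1:]:
        new = [-math.inf] * S
        for s in range(S):
            cands = [fwd[s]]                  # stay on the same symbol
            if s >= 1:
                cands.append(fwd[s - 1])      # advance by one symbol
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                cands.append(fwd[s - 2])      # skip a blank between different labels
            new[s] = logsumexp(cands) + frame[ext[s]]
        fwd = new
    # Valid paths end on the last label or on the trailing blank.
    return -logsumexp(fwd[-2:])


def total_loss(emotion_logits, emotion_label, asr_log_probs, transcript, alpha=0.1):
    """Assumed combined objective: L = L_CE + alpha * L_CTC."""
    return (cross_entropy(emotion_logits, emotion_label)
            + alpha * ctc_loss(asr_log_probs, transcript))


# Toy example: 2 frames, vocabulary {blank, 'a'}, uniform frame posteriors, target "a".
uniform = [[math.log(0.5), math.log(0.5)] for _ in range(2)]
print(ctc_loss(uniform, [1]))  # equals -log(0.75): 3 of the 4 frame paths collapse to "a"
```

With `alpha=0.0` the objective reduces to plain emotion classification, which corresponds to the α = 0 row of Table 5.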

3.4 Comparative Analysis Table 6 compares multimodal emotion recognition models in terms of WA and UA, all using speech and text data. Notably, we use the same audio and text features as [29]. By additionally incorporating the ResCMFA module and the global-aware block, our model outperforms the MMER model [29] in WA by 4.37 percentage points. This indicates that the ResCMFA module and the global-aware block integrate emotional information from multiple modalities more efficiently. Our model achieves the highest WA and UA among the compared methods on IEMOCAP, which further supports the effectiveness of the proposed approach.

Table 6. Quantitative comparison with multimodal methods on the IEMOCAP dataset
Method | WA | UA | Year
Xu et al. [21] | 70.40% | 69.50% | 2019
Liu et al. [26] | 72.40% | 70.10% | 2020
Makiuchi et al. [27] | 73.50% | 73.00% | 2021
Cai et al. [28] | 78.15% | - | 2021
Srivastava et al. [29] | 77.64% | - | 2022
Morais et al. [30] | 77.36% | 77.76% | 2022
Ours | 82.01% | 82.80% | 2023


4 Conclusion In this paper, a new approach is proposed for speech emotion recognition. The global-aware block is inserted after the ResCMFA module to extract emotion-rich features globally. In addition, we introduce ASR as an auxiliary task, trained with a CTC loss. Experiments on the IEMOCAP dataset show that the ResCMFA module, the global-aware block, and the auxiliary ASR task with CTC loss all improve the model's performance. In the future, we will investigate more efficient optimization strategies to build more accurate SER architectures.

References
1. Sreeshakthy, M., Preethi, J.: Classification of human emotion from DEAP EEG signal using hybrid improved neural networks with cuckoo search. BRAIN. Broad Res. Artif. Intell. Neurosci. 6(3–4), 60–73 (2016)
2. Fan, W., Xu, X., Cai, B., Xing, X.: ISNet: Individual standardization network for speech emotion recognition. IEEE/ACM Trans. Audio Speech Lang. Process. 30, 1803–1814 (2022)
3. Schuller, B.W.: Speech emotion recognition: two decades in a nutshell, benchmarks, and ongoing trends. Commun. ACM 61(5), 90–99 (2018)
4. Zhang, H., Gou, R., Shang, J., Shen, F., Wu, Y., Dai, G.: Pre-trained deep convolution neural network model with attention for speech emotion recognition. Front. Physiol. 12, 643202 (2021). https://doi.org/10.3389/fphys.2021.643202
5. Dissanayake, V., Tang, V., Elvitigala, D.S., et al.: Troi: towards understanding users perspectives to mobile automatic emotion recognition system in their natural setting. Proc. ACM Human Comput. Interact. 6, 1–22 (2022)
6. Zhang, M., Chen, Y., Lin, Y., Ding, H., Zhang, Y.: Multichannel perception of emotion in speech, voice, facial expression, and gesture in individuals with autism: a scoping review. J. Speech Lang. Hear. Res. 65(4), 1435–1449 (2022)
7. Ko, B.: A brief review of facial emotion recognition based on visual information. Sensors 18(2), 401 (2018). https://doi.org/10.3390/s18020401
8. Jain, M., Narayan, S., Balaji, P., et al.: Speech emotion recognition using support vector machine. arXiv preprint arXiv:2002.07590 (2020)
9. Kandali, A.B., Routray, A., Basu, T.K.: Emotion recognition from Assamese speeches using MFCC features and GMM classifier. In: TENCON 2008–2008 IEEE Region 10 Conference, pp. 1–5. IEEE, Hyderabad (2008)
10. Nwe, T.L., Foo, S.W., De Silva, L.C.: Speech emotion recognition using hidden Markov models. Speech Commun. 41(4), 603–623 (2003)
11. Kwon, O.W., Chan, K., Hao, J., Lee, T.W.: Emotion recognition by speech signals. In: Eighth European Conference on Speech Communication and Technology, pp. 1562–1567. IEEE, Penang (2017)
12. Ververidis, D., Kotropoulos, C.: Emotional speech recognition: resources, features, and methods. Speech Commun. 48(9), 1162–1181 (2006)
13. Kishore, K.K., Satish, P.K.: Emotion recognition in speech using MFCC and wavelet features. In: 2013 3rd IEEE International Advance Computing Conference (IACC), pp. 842–847. IEEE, Ghaziabad (2013)
14. Chan, W., Jaitly, N., Le, Q., Vinyals, O.: Listen, attend and spell: a neural network for large vocabulary conversational speech recognition. In: 2016 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 4960–4964. IEEE, Shanghai (2016)


15. Ramet, G., Garner, P.N., Baeriswyl, M., Lazaridis, A.: Context-aware attention mechanism for speech emotion recognition. In: 2018 IEEE Spoken Language Technology Workshop (SLT), pp. 126–131. IEEE, Athens (2018)
16. Han, K., Yu, D., Tashev, I.: Speech emotion recognition using deep neural network and extreme learning machine. In: Interspeech, pp. 223–227. Singapore (2014)
17. Gao, M., Dong, J., Zhou, D., Zhang, Q., Yang, D.: End-to-end speech emotion recognition based on one-dimensional convolutional neural network. In: Proceedings of the 2019 3rd International Conference on Innovation in Artificial Intelligence, pp. 78–82. ACM, New York (2019)
18. Yang, Z., Hirschberg, J.: Predicting arousal and valence from waveforms and spectrograms using deep neural networks. In: Interspeech, pp. 3092–3096. Hyderabad (2018)
19. Zhu, W., Li, X.: Speech emotion recognition with global-aware fusion on multi-scale feature representation. In: ICASSP 2022–2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 6437–6441. IEEE, Singapore (2022)
20. Yang, S.W., et al.: SUPERB: speech processing universal performance benchmark. In: Interspeech, pp. 1194–1198. Brno (2021)
21. Xu, H., Zhang, H., Han, K., Wang, Y., Peng, Y., Li, X.: Learning alignment for multimodal emotion recognition from speech. In: Interspeech, pp. 3569–3573. Graz (2019)
22. Cambria, E., Hazarika, D., Poria, S., Hussain, A., Subramanyam, R.B.V.: Benchmarking multimodal sentiment analysis. In: Gelbukh, A. (ed.) Computational Linguistics and Intelligent Text Processing. LNCS, vol. 10762, pp. 166–179. Springer, Cham (2018). https://doi.org/10.1007/978-3-319-77116-8_13
23. Cao, Q., Hou, M., Chen, B., Zhang, Z., Lu, G.: Hierarchical network based on the fusion of static and dynamic features for speech emotion recognition. In: ICASSP 2021–2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 6334–6338. IEEE, Toronto (2021)
24. Wu, W., Zhang, C., Woodland, P.C.: Emotion recognition by fusing time synchronous and time asynchronous representations. In: ICASSP 2021–2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 6269–6273. IEEE, Toronto (2021)
25. Sun, L., Liu, B., Tao, J., Lian, Z.: Multimodal cross- and self-attention network for speech emotion recognition. In: ICASSP 2021–2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 4275–4279. IEEE, Toronto (2021)
26. Liu, P., Li, K., Meng, H.: Group gated fusion on attention-based bidirectional alignment for multimodal emotion recognition. In: Interspeech, pp. 379–383. Shanghai (2020)
27. Makiuchi, M.R., Uto, K., Shinoda, K.: Multimodal emotion recognition with high-level speech and text features. In: 2021 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU), pp. 350–357. IEEE, Colombia (2021)
28. Cai, X., Yuan, J., Zheng, R., Huang, L., Church, K.: Speech emotion recognition with multi-task learning. In: Interspeech, pp. 4508–4512. Brno (2021)
29. Srivastava, H., Ghosh, S., Umesh, S.: MMER: multimodal multi-task learning for emotion recognition in spoken utterances. arXiv preprint arXiv:2203.16794 (2022)
30. Morais, E., Hoory, R., Zhu, W., Gat, I., Damasceno, M., Aronowitz, H.: Speech emotion recognition using self-supervised features. In: ICASSP 2022–2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 6922–6926. IEEE, Singapore (2022)

MT-1DCG: A Novel Model for Multivariate Time Series Classification

Yu Lu1(B), Huanwen Liang1,2, Zichang Yu1,2, and Xianghua Fu1

1 College of Big Data and Internet, Shenzhen Technology University, Shenzhen, China
2 College of Applied Sciences, Shenzhen University, Shenzhen, China

Abstract. Cardiotocography (CTG) is a critical component of prenatal fetal monitoring. It provides multivariate time series data that enable healthcare professionals to assess fetal growth and to intervene promptly when abnormal conditions arise, ensuring optimal fetal well-being. However, conventional CTG interpretation is susceptible to differences in individual clinical experience and to inconsistencies among assessment guidelines. To address these limitations, this study investigates artificial intelligence algorithms for developing an objective fetal assessment method based on multivariate time-series signals of fetal heart rate and uterine contractions. We preprocess data from an open-source fetal heart rate and uterine contraction database, handling missing values and reducing noise, and augment the dataset for reliable experimentation. We also propose three multivariate time-series signal models: MT-1DCG, A-BiGRU, and ST-1DCG. Multiple experiments show that the MT-1DCG model outperforms the A-BiGRU and ST-1DCG models. Model performance is assessed with standard evaluation metrics, including accuracy, sensitivity, specificity, and the ROC curve. On the test set, the proposed MT-1DCG model achieves an accuracy of 95.15%, a sensitivity of 96.20%, and a specificity of 94.09%. These findings indicate that our method effectively evaluates fetal health status and can support obstetricians in clinical decision-making.

Keywords: Cardiotocography · Multivariate Time Series · Deep Learning

1 Introduction Multivariate time series analysis has become an essential tool for studying relationships between two or more time-dependent variables, providing valuable insights into underlying patterns and interactions. In fetal health monitoring, multivariate time series analysis of the fetal heart rate (FHR) and uterine contraction (UC) signals recorded by cardiotocography (CTG) plays a pivotal role in understanding and predicting fetal well-being. This paper examines multivariate time series analysis of FHR and UC signals. It emphasizes that capturing their complex interactions and dependencies is essential for accurately identifying potential complications and enabling timely medical interventions. CTG is a widely used technique for monitoring FHR and UC during pregnancy and labor. It provides crucial information about fetal well-being and helps detect potential complications such as fetal hypoxia, acidemia, or distress. However, CTG signal interpretation is subjective, complex, and error-prone. Automatic and intelligent methods for CTG image classification are therefore needed to help clinicians make accurate and timely decisions.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 222–234, 2023. https://doi.org/10.1007/978-981-99-4742-3_18

Deep learning techniques have demonstrated remarkable performance in various image classification tasks, including face recognition, object detection, and medical diagnosis. These models can learn high-level features from raw data without relying on hand-crafted features or domain knowledge, and they can handle large-scale, complex datasets efficiently and robustly. As a result, deep learning techniques have been applied to CTG image classification with promising outcomes.

This study aims to develop a generalized algorithm for the automatic classification of multivariate CTG time-series signals. For data processing, we removed outliers from an open-source dataset, applied the Hermite interpolation algorithm to fill missing values, and used a window averaging method for denoising. Because of patient privacy concerns, far fewer abnormal signals than normal signals are publicly available. We therefore employed a sliding window shifting algorithm to expand the dataset and balance the categories. We proposed three temporal models, MT-1DCG, A-BiGRU, and ST-1DCG, based on the morphological characteristics of fetal heart signals, and conducted multiple reproducible experiments to validate their performance. We also compared them experimentally with other models. The proposed multivariate temporal model, MT-1DCG, combines a one-dimensional convolutional neural network (1D-CNN) with a gated recurrent unit (GRU) network. This removes the need for manual feature extraction and enables automatic classification.

Fig. 1. Flow chart of the approach.

Fig. 1 shows the flow chart of our research. The process consists of three steps:
(1) Data preprocessing: we use open-source data with complete FHR and UC time series. We preprocess these data to handle the many abnormal values caused by external factors during acquisition.
(2) Model construction and training: following fetal assessment rules, we take the relative temporal relationship between FHR and UC into account. On this basis we propose a multi-temporal input model built on multi-temporal bioelectrical signals, and we build and train the model.
(3) Performance evaluation: we evaluate the model on our test set using accuracy, specificity, sensitivity, and other metrics, and compare it with alternative models.

Y. Lu et al.
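The metrics named in step (3) follow from the binary confusion counts. The sketch below assumes that the pathological class (pH < 7.15, see Sect. 3.1) is the positive label 1. The function name and toy labels are hypothetical, and the ROC curve is not computed here.

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Accuracy, sensitivity and specificity for a two-class problem."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    tn = sum(1 for t, p in pairs if t != positive and p != positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    return {
        "accuracy": (tp + tn) / len(pairs),
        # Sensitivity: share of pathological recordings that are detected.
        "sensitivity": tp / (tp + fn) if tp + fn else float("nan"),
        # Specificity: share of normal recordings that are correctly cleared.
        "specificity": tn / (tn + fp) if tn + fp else float("nan"),
    }


# Hypothetical toy labels: 4 pathological (1) and 6 normal (0) recordings.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 0, 0, 1, 0]
print(binary_metrics(y_true, y_pred))
```

Sensitivity and specificity are reported separately because accuracy alone can hide poor detection of the minority pathological class.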

2 Deep Learning Methods Deep learning techniques, a subset of machine learning methods, employ multiple layers of nonlinear transformations to extract high-level features from raw data. These models can automatically learn hierarchical representations without relying on hand-crafted features or domain knowledge, and they handle large-scale, complex datasets efficiently and robustly. The three main families of deep learning models used for CTG image classification are convolutional neural networks (CNNs), recurrent neural networks (RNNs), and multimodal approaches.

CNNs are widely used in CTG image classification because they capture the spatial patterns and structures of CTG signals and learn high-level features from raw CTG images. For instance, Li et al. [11] proposed a CNN model that classifies CTG images as normal, suspicious, or abnormal, using a sliding window to segment CTG images into smaller patches. Zhao et al. [20] proposed an 8-layer deep CNN framework for predicting fetal acidemia, which takes as input 2D images obtained by the continuous wavelet transform (CWT). They also proposed a system combining recurrence plots (RP) and a CNN to diagnose fetal hypoxia [8]. Baghel et al. [17] proposed a one-dimensional CNN-based method for diagnosing fetal acidosis from fetal heart rate signals. Ogasawara et al. [2] introduced a deep neural network model (CTG-net) to detect impaired fetal state. Fergus et al. [15] suggested a method based on a one-dimensional CNN, LSTM, and MLP to predict fetal health from FHR and UC signals. Liang and Li [12] developed a CNN model with a weighted voting mechanism to classify FHR as normal or pathological. Fasihi et al. [7] proposed a one-dimensional CNN to assess fetal state from FHR and UC signals.

RNNs have also been used for CTG classification because they exploit the temporal information and context of CTG signals and learn high-level features from raw CTG signals. For instance, Liu et al. [13] combined an attention-based CNN with a bidirectional long short-term memory network (BiLSTM) to focus on essential input features. They also used the discrete wavelet transform (DWT) to obtain transform-coefficient features of FHR signals, which reduces overfitting. The two sets of features are then fused for fetal acidosis classification. Bondet et al. [3] suggested a method based on LSTM and GRU that augments data with random transformations. It uses fetal heart rate signals and other sensor signals to detect maternal heart rate and false signals, improving cardiotocogram quality. Xiao et al. [18] introduced a deep learning model that combines a multiscale CNN with a BiLSTM. The CNN extracts time-frequency features of CTG signals and the RNN extracts temporal features. A feature fusion layer then combines the two, forming a deep feature fusion network (DFFN).

Multimodal deep learning models have also been used for CTG classification because they can combine spatial and temporal information from CTG images and signals, learning high-level features from both modalities. For example, Fei et al. [8] proposed a method to intelligently interpret and classify CTG signals using an MBiGRU.


The method handles both numerical and image data of CTG signals. Spairani et al. [17] proposed a mixed-data approach that leverages an MLP and a CNN to classify FHR signals, handling both numerical and image data of FHR signals. Asfaw et al. [1] proposed a method that handles both numerical and image data of FHR signals, as well as numerical data of contraction signals.

3 Data Preprocessing and Augmentation 3.1 Dataset In this study, we use the open-source CTU-CHB database. It contains 552 high-quality CTG recordings from the Czech Technical University (CTU) in Prague and the University Hospital in Brno (UHB), selected from a total of 9164 CTG recordings. Fig. 2 shows a sample CTG recording after data preprocessing. Besides the FHR and UC signals, the dataset includes other parameters for assessing fetal health, such as umbilical cord blood gas pH and Apgar scores, which are widely accepted assessment criteria. Fetal hypoxia in utero may result in acidemia, and fetal heart monitoring parameters are correlated with the fetus's acid-base balance and hypoxia. Research indicates that cord blood gas pH analysis accurately identifies intrauterine hypoxia in fetuses [19]. Consequently, this study uses pH as the gold standard for the assessment algorithm.

Fig. 2. Multivariate time series of CTG data.

We set the pH threshold following several international studies [6, 9, 10] that used a pH value of 7.15 as the decision threshold. With this threshold, we classified the data into two categories: normal (negative) for pH values greater than or equal to 7.15 and pathological (positive) for values below 7.15. The resulting dataset consists of 447 normal cases and 105 pathological cases, a ratio of approximately 4:1.
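The labeling rule is a single threshold on umbilical cord blood pH. The sketch below makes the boundary explicit (a pH of exactly 7.15 counts as normal) and checks the reported class ratio. `label_from_ph` is a hypothetical helper name.

```python
PH_THRESHOLD = 7.15  # decision threshold adopted from [6, 9, 10]


def label_from_ph(ph):
    """Return 1 for pathological (positive, pH < 7.15) and 0 for normal (negative)."""
    return 1 if ph < PH_THRESHOLD else 0


normal, pathological = 447, 105          # counts reported for the 552 recordings
print(normal + pathological)             # 552
print(round(normal / pathological, 2))   # 4.26, i.e. roughly 4:1
```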


3.2 Preprocessing Clinical fetal heart rate signals are typically acquired in one of two ways. One is externally, by placing a Doppler ultrasound (US) probe on the pregnant woman's abdomen. The other is internally, by measuring the fetal electrocardiogram (FECG) through an electrode attached to the fetal scalp [14]. The collected data may be noisy because of interference such as maternal movement or instrument error, and the quality of the FHR data affects assessment performance. Data preprocessing is therefore crucial for this study. We performed the following main preprocessing steps on the CTG signals:
(1) Identify intervals in which the fetal heart rate is 0. If such an interval is longer than 15 s, it is removed.
(2) If such a zero-valued interval is shorter than 15 s, the values within it are interpolated.
(3) Remove anomalies with fetal heart rate values below 50 bpm or above 200 bpm and fill them using Hermite interpolation. This method fits a continuous function through a discrete set of data points using Hermite polynomials.
3.3 Data Augmentation The dataset is small (552 recordings) and its class distribution is imbalanced, which may cause the model to overfit and reduce the algorithm's overall reliability. To address this issue, we expand the data using the approach proposed by Cui [5]. A time series of length n is written as in Eq. (1):

$$T = \{t_1, t_2, t_3, \ldots, t_n\} \qquad (1)$$
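Before any augmentation, each FHR trace passes through the three cleaning rules of Sect. 3.2. The plain-Python sketch below applies them in order. Several details are assumptions not stated in the text:

- the FHR sampling rate is 4 Hz, so 15 s corresponds to 60 samples;
- a zero run of exactly 15 s is treated as short;
- Hermite interpolation is also used for the short zero runs;
- gaps at the very start or end of a trace are filled by holding the nearest valid value;
- the Hermite end-point slopes are one-sided differences, with flat tangents when no valid neighbor exists.

The window-averaging denoising step is not shown, and all function names are hypothetical.

```python
FS_HZ = 4            # assumed FHR sampling rate (samples per second)
MAX_GAP_S = 15       # threshold of rules (1) and (2)
LOW, HIGH = 50, 200  # plausible FHR range of rule (3), in bpm


def runs(mask):
    """Return (start, end) index pairs, end exclusive, of consecutive True values."""
    out, i, n = [], 0, len(mask)
    while i < n:
        if mask[i]:
            j = i
            while j < n and mask[j]:
                j += 1
            out.append((i, j))
            i = j
        else:
            i += 1
    return out


def hermite_fill(x, bad, s, e):
    """Fill x[s:e] with a cubic Hermite curve between the valid samples x[s-1] and x[e]."""
    a, b = s - 1, e
    y0, y1 = x[a], x[b]
    # One-sided per-sample slopes from valid outer neighbors, else flat tangents.
    m0 = x[a] - x[a - 1] if a >= 1 and not bad[a - 1] else 0.0
    m1 = x[b + 1] - x[b] if b + 1 < len(x) and not bad[b + 1] else 0.0
    span = b - a
    for k in range(s, e):
        t = (k - a) / span
        h00 = 2 * t**3 - 3 * t**2 + 1  # cubic Hermite basis functions
        h10 = t**3 - 2 * t**2 + t
        h01 = -2 * t**3 + 3 * t**2
        h11 = t**3 - t**2
        x[k] = h00 * y0 + h10 * span * m0 + h01 * y1 + h11 * span * m1


def preprocess_fhr(fhr):
    """Apply rules (1)-(3) of Sect. 3.2 to one FHR trace (list of bpm values)."""
    max_len = MAX_GAP_S * FS_HZ
    # Rule (1): drop zero-valued runs longer than 15 s.
    drop = set()
    for s, e in runs([v == 0 for v in fhr]):
        if e - s > max_len:
            drop.update(range(s, e))
    x = [float(v) for i, v in enumerate(fhr) if i not in drop]
    # Rules (2) and (3): remaining short zero runs fall below 50 bpm,
    # so a single out-of-range mask covers both.
    bad = [v < LOW or v > HIGH for v in x]
    for s, e in runs(bad):
        if s > 0 and e < len(x):
            hermite_fill(x, bad, s, e)
        elif e < len(x):   # gap at the start: hold the first valid value
            x[s:e] = [x[e]] * (e - s)
        elif s > 0:        # gap at the end: hold the last valid value
            x[s:e] = [x[s - 1]] * (e - s)
    return x


print(preprocess_fhr([100, 110, 120, 0, 0, 150, 160]))  # approximately 100, 110, ..., 160
```

Cubic Hermite interpolation with these slope estimates reproduces a linear trend exactly, so the short gap in the example is filled with about 130 and 140 bpm.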

A segment within that time series is defined as in Eq. (2): S_{i:j} = {s_i, s_{i+1}, s_{i+2}, ..., s_{j−1}, s_j}, 1

If σn > 40 and σn ≤ 80, then we choose the second network model. If σn > 80, then we choose the third network model. The only difference between our proposed method and the standard DnCNN lies in how the networks are trained. Our method trains three models on three noise ranges ([0, 40], [40, 80], and [80, 120]), whereas the standard DnCNN is trained on a single noise range, [0, 55]. Unlike non-blind denoising, training blind DnCNN models over a range of noise levels requires more epochs and therefore a longer training time. We train our models with the MATLAB Deep Learning Toolbox rather than with TensorFlow or PyTorch in Python. The MATLAB code for training our new models is as follows:

    % Train one blind DnCNN model per noise range (noise sigma on the 0-255 scale).
    SIGMA = [1 40 80 120];
    % Training images shipped with the Image Processing Toolbox.
    setDir = fullfile(toolboxdir('images'), 'imdata');
    imds = imageDatastore(setDir, 'FileExtensions', {'.jpg'});
    for i = 1:length(SIGMA)-1
        % Noise range for this model, normalized to [0, 1] image intensities.
        SigR = [SIGMA(i) SIGMA(i+1)] / 255;
        % Noisy input patches and noise (residual) targets, with sigma drawn from SigR.
        dnds = denoisingImageDatastore(imds, ...
            'PatchesPerImage', 512, ...
            'PatchSize', 50, ...
            'GaussianNoiseLevel', SigR, ...
            'ChannelFormat', 'grayscale');
        layers = dnCNNLayers;
        options = trainingOptions('adam', ...
            'MaxEpochs', 500, ...
            'InitialLearnRate', 1e-4, ...
            'Verbose', false, ...
            'Plots', 'training-progress');
        net = trainNetwork(dnds, layers, options);
        % Saves net40.mat, net80.mat and net120.mat.
        save(['net' num2str(SIGMA(i+1)) '.mat'], 'net');
    end


G. Y. Chen et al.

In this paper, the standard DnCNN is the pretrained model provided by MATLAB, which can be loaded with:

    net = denoisingNetwork('DnCNN');

Denoising is then performed with the following MATLAB command:

    denoisedImage = denoiseImage(noisyImage, net);

Training each of our three models takes about 16.5 h on a GPU at Concordia University. The experimental results in the next section demonstrate that our proposed method outperforms the standard DnCNN in all but two of the 54 test cases (six images at nine noise levels).
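The model-selection step can be sketched as follows. The conclusions state that the noise level is estimated by wavelet shrinkage. The sketch assumes the common robust estimator from that line of work [2], sigma_hat = median(|HH1|) / 0.6745, computed with a one-level orthonormal Haar transform. The wavelet actually used is not stated in this section. The returned file names follow the save pattern of the MATLAB training script. In practice, the selected network would be loaded in MATLAB and passed to denoiseImage. The helper names are hypothetical.

```python
import random
import statistics


def estimate_sigma(img):
    """Robust noise estimate sigma_hat = median(|HH1|) / 0.6745.

    HH1 holds the diagonal detail coefficients of a one-level orthonormal
    Haar transform; img is a list of rows (an odd last row/column is ignored).
    """
    hh = []
    for i in range(0, len(img) - 1, 2):
        for j in range(0, len(img[0]) - 1, 2):
            a, b = img[i][j], img[i][j + 1]
            c, d = img[i + 1][j], img[i + 1][j + 1]
            hh.append((a - b - c + d) / 2.0)
    return statistics.median(abs(v) for v in hh) / 0.6745


def select_model(sigma):
    """Map an estimated noise level (0-255 scale) to one of the three trained models."""
    if sigma <= 40:
        return "net40.mat"   # trained on [0, 40]
    if sigma <= 80:
        return "net80.mat"   # trained on [40, 80]
    return "net120.mat"      # trained on [80, 120], also used above 120


# Synthetic check: a flat image plus Gaussian noise with sigma = 25.
random.seed(0)
noisy = [[100.0 + random.gauss(0, 25) for _ in range(256)] for _ in range(256)]
sigma_hat = estimate_sigma(noisy)
print(round(sigma_hat, 1), select_model(sigma_hat))
```

The median-based estimator is used because the finest diagonal detail band of a natural image is dominated by noise, and the median is insensitive to the few large coefficients contributed by edges.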

3 Experiments We test the proposed method on six grayscale images: Barbara, Boat, Fingerprint, House, Lena, and Peppers. Fig. 2 shows these six noise-free images, which are frequently used in image processing. The noisy image B is generated from the noise-free image A as B = A + σn Z, where σn is the noise standard deviation and Z follows the standard normal distribution N(0, 1). The peak signal-to-noise ratio (PSNR) is used to measure the denoising capability of both DnCNN and our proposed method. The PSNR is defined as

$$\mathrm{PSNR}(A, B) = 10\log_{10}\left(\frac{M \times N \times 255^2}{\sum_{i,j}\left(B(i,j) - A(i,j)\right)^2}\right)$$

where M × N is the number of pixels in the image, and A and B are the noise-free image and the noisy or denoised image, respectively. PSNR is one of the most frequently adopted metrics for measuring the quality of image denoising results. Table 1 shows the PSNR of both DnCNN and the proposed method for different noise levels (σn = 10, 30, 50, 70, 90, 110, 130, 150, and 170) and six different images. Our proposed method outperforms DnCNN in almost all test cases. There are only two exceptions, where the standard DnCNN is better than our method: σn = 50 for image Barbara (23.87 dB vs. 23.69 dB) and σn = 30 for image Boat (28.50 dB vs. 28.46 dB).
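The noise model B = A + σn Z and the PSNR definition above translate directly into code. The following is a minimal plain-Python sketch for 8-bit grayscale images stored as lists of rows. It assumes no clipping to [0, 255]; clipping would change the PSNR slightly at high σn. The helper names are hypothetical.

```python
import math
import random


def add_awgn(img, sigma, rng):
    """B = A + sigma * Z with Z ~ N(0, 1) drawn independently per pixel (no clipping)."""
    return [[a + sigma * rng.gauss(0.0, 1.0) for a in row] for row in img]


def psnr(a, b):
    """PSNR(A, B) = 10 log10(M * N * 255^2 / sum_ij (B(i,j) - A(i,j))^2), in dB."""
    m, n = len(a), len(a[0])
    sse = sum((vb - va) ** 2 for ra, rb in zip(a, b) for va, vb in zip(ra, rb))
    if sse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(m * n * 255.0**2 / sse)


rng = random.Random(0)
clean = [[128.0] * 64 for _ in range(64)]
noisy = add_awgn(clean, 10, rng)
print(round(psnr(clean, noisy), 2))  # close to 20 * log10(255 / 10), i.e. about 28.1 dB
```

For AWGN with standard deviation σn on the 0–255 scale, the mean squared error of the noisy image is about σn². The PSNR of the noisy input is therefore roughly 20·log10(255/σn), which gives a useful baseline when reading Table 1.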

Improved Blind Image Denoising with DnCNN


Figures 3–8 show the noisy image, the image denoised with DnCNN, and the image denoised with the proposed method for Barbara, Boat, Fingerprint, House, Lena, and Peppers, respectively. The proposed method is almost always better than the standard DnCNN in terms of both PSNR and visual quality. This demonstrates that the proposed method is effective at denoising images corrupted by additive white Gaussian noise (AWGN).

Fig. 2. The images used in the experiments for noise reduction.

Fig. 3. The noisy image Barbara, the denoised image with DnCNN and the denoised image with the proposed method.


Fig. 4. The noisy image Boat, the denoised image with DnCNN and the denoised image with the proposed method.

Fig. 5. The noisy image Fingerprint, the denoised image with DnCNN and the denoised image with the proposed method.

More experiments will be conducted to further test the effectiveness of the proposed method. For example, we can apply it to a larger set of test images and report the mean and standard deviation of the resulting PSNR.


Fig. 6. The noisy image House, the denoised image with DnCNN and the denoised image with the proposed method.

Fig. 7. The noisy image Lena, the denoised image with DnCNN and the denoised image with the proposed method.


Fig. 8. The noisy image Peppers, the denoised image with DnCNN and the denoised image with the proposed method.

Table 1. The PSNR (dB) of both the DnCNN and the proposed method for different noise levels and different images. The proposed method is better than the standard DnCNN almost always.
Image | Method | σn = 10 | 30 | 50 | 70 | 90 | 110 | 130 | 150 | 170
Barbara | DnCNN | 31.14 | 25.99 | 23.87 | 22.24 | 20.63 | 19.17 | 17.98 | 17.04 | 16.29
Barbara | Proposed | 32.09 | 26.34 | 23.69 | 22.50 | 21.30 | 20.18 | 19.19 | 18.37 | 17.70
Boat | DnCNN | 33.24 | 28.50 | 25.97 | 23.74 | 21.70 | 20.05 | 18.77 | 17.79 | 17.03
Boat | Proposed | 33.73 | 28.46 | 26.30 | 24.49 | 22.95 | 21.66 | 20.59 | 19.74 | 19.06
Fingerprint | DnCNN | 30.66 | 25.67 | 22.99 | 21.05 | 19.45 | 18.17 | 17.19 | 16.42 | 15.81
Fingerprint | Proposed | 32.07 | 26.03 | 23.79 | 21.80 | 20.04 | 18.68 | 17.65 | 16.90 | 16.36
House | DnCNN | 35.19 | 30.77 | 27.87 | 24.99 | 22.44 | 20.48 | 19.04 | 17.98 | 17.16
House | Proposed | 35.84 | 30.97 | 29.09 | 26.77 | 24.87 | 23.04 | 21.58 | 20.47 | 19.61
Lena | DnCNN | 34.98 | 30.30 | 27.57 | 24.94 | 22.50 | 20.57 | 19.12 | 18.03 | 17.20
Lena | Proposed | 35.62 | 30.31 | 28.55 | 26.29 | 24.32 | 22.65 | 21.32 | 20.28 | 19.47
Peppers | DnCNN | 33.94 | 28.48 | 25.57 | 23.18 | 21.09 | 19.41 | 18.14 | 17.17 | 16.41
Peppers | Proposed | 34.44 | 28.65 | 25.92 | 23.85 | 21.88 | 20.52 | 19.40 | 18.50 | 17.79


4 Conclusions Digital images are often corrupted by noise during acquisition, compression, and transmission, which causes distortion and loss of information. Noise also degrades downstream image processing tasks such as image classification, video processing, and object tracking. Therefore, image denoising plays a vital role in today's image processing systems. In this paper, we have proposed an improved DnCNN for image denoising that trains three network models on the noise ranges [0, 40], [40, 80], and [80, 120], respectively. These three models can be trained in parallel on modern GPUs. We estimate the noise level of each noisy image by means of wavelet shrinkage and, based on the estimated noise level, select the appropriate network model for denoising. Experimental results demonstrate that our proposed method outperforms the standard DnCNN in almost all test cases. In future work, we would like to design novel CNN architectures for image denoising that outperform the existing DnCNN. We would also like to develop new patch-based methods for reducing noise in both grayscale and color images, and to improve our previously published image denoising algorithms [4–7].

References
1. Zhang, K., Zuo, W., Chen, Y., Meng, D., Zhang, L.: Beyond a Gaussian denoiser: residual learning of deep CNN for image denoising. IEEE Trans. Image Process. 26(7), 3142–3155 (2017)
2. Donoho, D.L., Johnstone, I.M.: Ideal spatial adaptation by wavelet shrinkage. Biometrika 81(3), 425–455 (1994)
3. Kingsbury, N.G.: Complex wavelets for shift invariant analysis and filtering of signals. J. Appl. Comput. Harmon. Anal. 10(3), 234–253 (2001)
4. Cho, D., Bui, T.D., Chen, G.Y.: Image denoising based on wavelet shrinkage using neighbour and level dependency. Int. J. Wavelets Multiresolut. Inf. Process. 7(3), 299–311 (2009)
5. Chen, G.Y., Kegl, B.: Image denoising with complex ridgelets. Pattern Recogn. 40(2), 578–585 (2007)
6. Chen, G.Y., Bui, T.D., Krzyzak, A.: Image denoising using neighbouring wavelet coefficients. Integr. Comput.-Aid. Eng. 12(1), 99–107 (2005)
7. Chen, G.Y., Bui, T.D., Krzyzak, A.: Image denoising with neighbour dependency and customized wavelet and threshold. Pattern Recogn. 38(1), 115–124 (2005)

Seizure Prediction Based on Hybrid Deep Learning Model Using Scalp Electroencephalogram

Kuiting Yan, Junliang Shang, Juan Wang, Jie Xu, and Shasha Yuan(B)

School of Computer Science, Qufu Normal University, Rizhao 276826, China

Abstract. Epilepsy is a neurological disorder that affects the brain and causes recurring seizures. Seizure prediction based on scalp electroencephalography (EEG) is essential to improve patients' daily lives. To achieve more accurate and reliable seizure prediction, this study introduces a hybrid model that combines the Dense Convolutional Network (DenseNet) and the bidirectional LSTM (BiLSTM). The densely connected structure of DenseNet can learn richer feature information in the early layers. The BiLSTM models correlations across the time series and better captures the dynamic features of the signal. The raw EEG data are first converted into a time-frequency matrix by the short-time Fourier transform (STFT). The resulting STFT images are then fed into the DenseNet-BiLSTM hybrid model for end-to-end feature extraction and classification. With leave-one-out cross-validation (LOOCV) on the CHB-MIT EEG dataset, our model achieved:
- an average accuracy of 92.45%;
- an average sensitivity of 92.66%;
- an F1-score of 0.923;
- an average false prediction rate (FPR) of 0.066 per hour;
- an area under the curve (AUC) of 0.936.

Our model performs better than state-of-the-art methods, in particular with a lower false prediction rate, which gives it great potential for clinical application.

Keywords: Scalp EEG · Seizure prediction · STFT · DenseNet · BiLSTM · Hybrid model

1 Introduction Epilepsy is a non-contagious, chronic brain disease. The symptoms of epileptic seizures vary with the location and spread of abnormal neuronal discharges in the brain. They can manifest as disturbances of consciousness, mental state, or motor, sensory, or autonomic function. The number of people with epilepsy worldwide grows every year, and its incidence is markedly higher in developing countries than in developed ones [1]. If epilepsy is not treated promptly and correctly, seizures can recur over the course of the disease. The EEG of individuals with epilepsy can typically be divided into four phases:
- the seizure (ictal) phase, during the seizure event;
- the pre-seizure (preictal) phase, before the seizure;
- the inter-seizure (interictal) phase, the interval between seizures;
- the post-seizure (postictal) phase, after the seizure has ended.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 272–282, 2023. https://doi.org/10.1007/978-981-99-4742-3_22

Seizure Prediction Based on Hybrid Deep Learning

Untreated seizures can be extremely damaging to the patient's life, and predicting seizures is an effective way to avoid this danger. Experienced clinicians can get a general idea of a patient's condition from the patient's EEG. However, they may be overwhelmed by large amounts of data and overlook small fluctuations in the EEG signal. Machine-based seizure prediction can process large amounts of data and thus avoid the problems of manual review. Analyzing EEG patterns during these phases can provide valuable insights into the pathophysiology of epilepsy and lead to better seizure prediction methods.

This study presents a seizure prediction method based on a DenseNet-BiLSTM model. The primary aim of combining these two deep learning techniques is to improve the accuracy and reliability of seizure prediction. The approach comprises two main steps. First, the raw EEG signal is transformed into a time-frequency matrix using the short-time Fourier transform, and the converted data are used as input to the DenseNet model. Converting the EEG data to an image format helps the model capture the temporal and spectral characteristics of the signal, which improves prediction accuracy. Second, the model is trained on the converted EEG data to distinguish preictal from interictal periods and thereby predict seizures. The hybrid model combines DenseNet's strength in feature extraction with the BiLSTM's ability to handle time-series data. As a result, it captures key features in EEG signals more effectively and improves seizure prediction accuracy.

2 Related Work
Seizure prediction research has advanced considerably in recent years. Early seizure studies relied mainly on machine learning techniques. Savadkoohi et al. [2] extracted four statistical features (mean, variance, skewness, and kurtosis) and applied K-Nearest Neighbors (KNN) and Support Vector Machine (SVM) classifiers to the Bonn dataset, obtaining accuracies of up to 99.5% and 100%, respectively. Usman et al. [3] used empirical mode decomposition (EMD) for preprocessing, combined spectral and statistical-moment features into a feature vector, and applied KNN, naive Bayes, and SVM classifiers to the CHB-MIT dataset, obtaining a sensitivity of 92.23% and a specificity of 93.38%. Williamson et al. [4] computed eigenspectra of space-delay correlation and covariance matrices from 15-s EEG data blocks at multiple delay scales and fed them into an SVM classifier, validating the algorithm on the Freiburg dataset. As deep learning has evolved, CNNs have become ubiquitous in a wide range of applications. Liang et al. [5] constructed an 18-layer long-term recurrent convolutional network (LRCN) to automatically identify and localize epileptogenic regions in scalp EEG; tested on the CHB-MIT dataset, it obtained an accuracy of 99%, a sensitivity of 84%, and a specificity of 99%. Daoud et al. [6] developed a seizure prediction method that uses CNNs within a deep learning framework to extract important spatial features from different scalp locations; they also introduced a semi-supervised method to improve the optimal solution and proposed an automatic channel selection algorithm to pick the most significant EEG channels. Khan et al. [7] processed EEG data with the continuous wavelet transform and fed it into a CNN. They further used the Kullback-Leibler (KL) divergence to evaluate how well the CNN-extracted features capture the transition from the interictal to the preictal state. Zhang et al. [8] developed a shallow CNN to differentiate between preictal and interictal states, achieving a sensitivity of 92.2% and an FPR of 0.12/h.

K. Yan et al.

3 Proposed Methods
3.1 Method Overview
First, the EEG signal is processed by the short-time Fourier transform to obtain data suitable as model input (a two-dimensional time-frequency matrix). Next, DenseNet extracts features from these 2D inputs, and the resulting feature maps are fed into the BiLSTM. Finally, the softmax function computes the probabilities of the interictal and preictal states. The proposed model is depicted in Fig. 1.

Fig. 1. Summary of the proposed model.
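The STFT step that turns each EEG window into a time-frequency matrix can be illustrated with a minimal NumPy sketch. The 256 Hz sampling rate, the 4-s windows, and the removal of the 0 Hz, 57–63 Hz, and 117–123 Hz components come from Sects. 3.2 and 3.3. The 1-s Hann frame with 50% overlap, the absence of boundary padding, and the use of the magnitude spectrum are assumptions, chosen because they yield 7 time frames × 114 frequency bins per channel, consistent with the 7 × 114 × channels input size reported in Sect. 3.4. The function names are hypothetical.

```python
import numpy as np

FS = 256            # CHB-MIT sampling rate (Hz)
SEGMENT_SEC = 4     # window length of one input sample (s)
NPERSEG = 256       # assumption: 1-s STFT frame, giving 1 Hz frequency bins
HOP = 128           # assumption: 50% frame overlap

def magnitude_stft(x, nperseg=NPERSEG, hop=HOP):
    """Magnitude STFT of a 1-D signal, shape (n_frames, nperseg // 2 + 1)."""
    window = np.hanning(nperseg)
    n_frames = 1 + (len(x) - nperseg) // hop          # no boundary padding (assumption)
    frames = np.stack([x[i * hop:i * hop + nperseg] * window for i in range(n_frames)])
    return np.abs(np.fft.rfft(frames, axis=-1))

def kept_bins(fs=FS, nperseg=NPERSEG):
    """Boolean mask that drops the DC bin and the 57-63 Hz and 117-123 Hz bands."""
    freqs = np.fft.rfftfreq(nperseg, d=1.0 / fs)
    drop = (freqs == 0) | ((freqs >= 57) & (freqs <= 63)) | ((freqs >= 117) & (freqs <= 123))
    return ~drop

def segment_to_tf_image(segment):
    """segment: (n_channels, FS * SEGMENT_SEC) -> (n_frames, n_freqs, n_channels)."""
    mask = kept_bins()
    per_channel = [magnitude_stft(ch)[:, mask] for ch in segment]
    return np.stack(per_channel, axis=-1)

rng = np.random.default_rng(0)
eeg_segment = rng.standard_normal((22, FS * SEGMENT_SEC))   # toy data, 22 channels
print(segment_to_tf_image(eeg_segment).shape)                # (7, 114, 22)
```

With 1 Hz bins, the rfft yields 129 frequencies (0–128 Hz); removing the DC bin and the two 7-bin notch bands leaves 114.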

3.2 CHB-MIT EEG Dataset
The EEG recordings used in this research were taken from the publicly available CHB-MIT scalp EEG dataset. The dataset comprises recordings of 23 patients with intractable epilepsy, collected at Boston Children's Hospital; notably, chb01 and chb21 were acquired from the same female subject. The recordings are organized into 24 cases. Electrodes were placed and named according to the international 10–20 system, and the signals were sampled at 256 Hz with 16-bit resolution. The dataset contains approximately 983 h of recordings in total, with 198 seizures annotated in the EEG signals [9].
3.3 Preprocessing
Several methods are commonly used for time-frequency analysis of EEG data, including the wavelet transform, the STFT, and the Stockwell transform [10]. The STFT supports time-frequency analysis by providing both temporal and frequency information. This is particularly important for non-stationary signals such as EEG, whose frequency components vary over time. Compared with alternative preprocessing techniques, the STFT offers straightforward implementation, an adjustable window size, and low computational complexity. Although other methods may have advantages in certain areas, the STFT has repeatedly proven effective for EEG signal analysis. In this research, we used the STFT to convert the EEG data into a 2D matrix whose axes are time and frequency. Each channel was divided into 4-s windows, and the 0 Hz (DC) component and the 57–63 Hz and 117–123 Hz bands (power-line noise at 60 Hz and its harmonic) were excluded from the analysis.
3.4 Deep Learning Models
DenseNet
Traditional CNNs suffer from several issues, such as difficulty in exploiting deep-level features, vulnerability to vanishing gradients, and overfitting. Researchers have proposed a series of improved CNN architectures to address these problems. Among them, DenseNet [11] is a classical convolutional neural network that can have more layers yet fewer parameters than ResNet [12]. Its architecture promotes feature reuse, simplifies network training, and effectively prevents vanishing gradients and model degradation. DenseNet connects each layer to every other layer in a feed-forward fashion: whereas a conventional CNN with L layers has L connections, a DenseNet has L(L + 1)/2 connections. Specifically, each convolutional layer in DenseNet takes as input the concatenated outputs of all preceding layers, forming a densely connected structure. This architecture is illustrated in Fig. 2.

Fig. 2. The connected structure of DenseNet. Each square represents a multi-channel feature map, and c denotes the channel-wise concatenation operation.
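The dense connectivity pattern can be illustrated with a toy NumPy sketch. Each "layer" here is a hypothetical stand-in (a random 1 × 1 channel projection followed by ReLU) for the BN-ReLU-convolution unit of a real DenseNet; the sketch only shows the channel-wise concatenation and confirms that L layers create L(L + 1)/2 direct connections. The growth rate of 32 matches the value used in this paper, while the input shape is arbitrary.

```python
import numpy as np

def toy_layer(x, growth_rate, rng):
    # Hypothetical stand-in for BN-ReLU-Conv: random 1x1 channel projection + ReLU.
    w = rng.standard_normal((x.shape[-1], growth_rate)) / np.sqrt(x.shape[-1])
    return np.maximum(x @ w, 0.0)

def dense_block(x, num_layers, growth_rate, seed=0):
    """Each layer sees the concatenation of the block input and all earlier layer outputs.

    Returns the concatenated block output and the number of direct connections."""
    rng = np.random.default_rng(seed)
    features = [x]
    connections = 0
    for _ in range(num_layers):
        connections += len(features)                 # one link from every earlier feature map
        inputs = np.concatenate(features, axis=-1)   # channel-wise concatenation ("c" in Fig. 2)
        features.append(toy_layer(inputs, growth_rate, rng))
    return np.concatenate(features, axis=-1), connections

x = np.ones((4, 57, 64))                             # toy feature map: height, width, channels
out, n_links = dense_block(x, num_layers=6, growth_rate=32)
print(out.shape, n_links)                            # (4, 57, 256) 21
```

The output has 64 + 6 × 32 = 256 channels, and layer l receives l inputs, so the connection count is 1 + 2 + ... + 6 = 21 = 6 · 7/2.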

BiLSTM
LSTM is a kind of RNN specifically designed to solve the vanishing and exploding gradient problems that arise when training on long sequences, and it outperforms conventional RNNs on longer sequences. The LSTM architecture consists of several components at time t: the input gate it, the memory cell Ct, the temporary cell state C̃t, the hidden state ht, the forget gate ft, and the output gate Ot. In a standard (unidirectional) LSTM network, information flows only in the forward direction. The BiLSTM module, by contrast, processes the input sequence separately in the forward and backward directions using two separate LSTM networks, one per direction, and concatenates the two sets of outputs to produce the final output [13]. Figure 3 illustrates the BiLSTM architecture.

Fig. 3. The architecture diagram of BiLSTM.
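The gate computations and the bidirectional concatenation can be sketched in NumPy. The gate symbols follow the text (it, ft, Ot, Ct, C̃t, ht), the gate equations are the standard LSTM formulation, and the random initialization and dimensions are illustrative assumptions; this is not the trained network used in the paper.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def init_lstm(input_dim, hidden_dim, rng):
    """Random (W, b) per gate; every W acts on the concatenation [h_{t-1}, x_t]."""
    return {gate: (0.1 * rng.standard_normal((hidden_dim, hidden_dim + input_dim)),
                   np.zeros(hidden_dim))
            for gate in ("f", "i", "c", "o")}

def lstm(seq, params):
    """Run one unidirectional LSTM over seq (T, input_dim); returns h_t for every t."""
    hidden_dim = params["f"][1].shape[0]
    h, c = np.zeros(hidden_dim), np.zeros(hidden_dim)
    outputs = []
    for x_t in seq:
        z = np.concatenate([h, x_t])
        f_t = sigmoid(params["f"][0] @ z + params["f"][1])        # forget gate
        i_t = sigmoid(params["i"][0] @ z + params["i"][1])        # input gate
        c_tilde = np.tanh(params["c"][0] @ z + params["c"][1])    # temporary cell state
        o_t = sigmoid(params["o"][0] @ z + params["o"][1])        # output gate
        c = f_t * c + i_t * c_tilde                               # memory cell C_t
        h = o_t * np.tanh(c)                                      # hidden state h_t
        outputs.append(h)
    return np.stack(outputs)

def bilstm(seq, fwd_params, bwd_params):
    """Concatenate a forward pass and a backward pass aligned to the same time steps."""
    h_fwd = lstm(seq, fwd_params)
    h_bwd = lstm(seq[::-1], bwd_params)[::-1]   # process reversed input, then re-align
    return np.concatenate([h_fwd, h_bwd], axis=-1)

rng = np.random.default_rng(0)
seq = rng.standard_normal((7, 16))              # toy sequence: 7 steps of 16 features
fwd, bwd = init_lstm(16, 8, rng), init_lstm(16, 8, rng)
print(bilstm(seq, fwd, bwd).shape)              # (7, 16)
```

Reversing the backward outputs once more ensures that position t holds information from both the past (forward pass) and the future (backward pass) of step t.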

Hybrid Model
The proposed hybrid DenseNet-BiLSTM method is illustrated in Fig. 4. The EEG time-frequency matrices obtained with the short-time Fourier transform are fed into the DenseNet network, which produces feature maps. These feature maps are then input into the BiLSTM, which exploits the sequential structure of the features to extract further discriminative information. Finally, the softmax function classifies each input as preictal or interictal.

Fig. 4. Architecture of the DenseNet-BiLSTM hybrid model.
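The channel bookkeeping of the DenseNet part can be traced in a few lines of Python. The 64-filter 7 × 7 stem with stride 2, the growth rate of 32, the reduction rate of 0.5 (Sect. 4), and the use of 4 dense blocks with 3 transition layers (Sect. 5) are taken from the paper. The number of layers per block (6, 12, 24, 16, as in DenseNet-121) and "same" padding in the stem are assumptions, because the paper does not state them.

```python
import math

def stem_output_size(size, stride=2):
    # Assumes 'same' padding in the 7x7 stem convolution: output = ceil(input / stride).
    return math.ceil(size / stride)

def densenet_channel_trace(stem_filters=64, block_layers=(6, 12, 24, 16),
                           growth_rate=32, compression=0.5):
    """Channel count after the stem, each dense block and each transition layer."""
    channels = stem_filters
    trace = [("stem", channels)]
    for i, n_layers in enumerate(block_layers):
        channels += n_layers * growth_rate              # every layer adds growth_rate maps
        trace.append((f"dense_block_{i + 1}", channels))
        if i < len(block_layers) - 1:                   # 4 blocks -> 3 transition layers
            channels = int(channels * compression)      # transition halves the channels
            trace.append((f"transition_{i + 1}", channels))
    return trace

print(stem_output_size(7), stem_output_size(114))       # 4 57
for name, channels in densenet_channel_trace():
    print(name, channels)
```

Under these assumptions, global average pooling would produce a 1024-dimensional vector; how that vector is reshaped into a sequence for the BiLSTM is not specified in the text.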


The model's input layer takes images of size 7 × 114 × channels. The image data are first passed to a convolution layer with 64 filters, each of size 7 × 7, with a stride of 2. The dense blocks use a growth rate of 32 and apply 1-pixel zero-padding so that the feature-map size is preserved. Transition layers between the dense blocks reduce the number of feature maps and, through average pooling, their spatial dimensions. After feature extraction, global average pooling produces a one-dimensional feature vector, which is reshaped and fed into the BiLSTM for classification.
3.5 Postprocessing
Sporadic false positives during the interictal period can cause false alarms. To address this, we applied the k-of-n technique [14] to the classification results as a postprocessing step. This method reduces the false alarm rate by raising an alarm only when at least k of the last n consecutive predictions are positive. In this experiment we set n = 10 and k = 8, so at least 8 of 10 consecutive EEG segments had to be predicted as positive (preictal) before an alarm was triggered.
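The k-of-n rule can be sketched as a sliding window over segment-level predictions. This minimal illustration uses n = 10 and k = 8 as in the paper; the function name is hypothetical, and the sketch reports every segment at which the condition holds, because the text does not say how repeated alarms are handled.

```python
from collections import deque

def k_of_n_alarms(predictions, k=8, n=10):
    """Indices of segments at which an alarm is raised.

    predictions: sequence of 0/1 segment-level outputs (1 = preictal).
    An alarm fires when at least k of the n most recent predictions are 1."""
    window = deque(maxlen=n)          # keeps only the n most recent predictions
    alarms = []
    for i, p in enumerate(predictions):
        window.append(int(p))
        if len(window) == n and sum(window) >= k:
            alarms.append(i)
    return alarms

# An isolated false positive inside interictal data raises no alarm ...
print(k_of_n_alarms([0] * 9 + [1] + [0] * 10))      # []
# ... whereas a sustained run of preictal predictions does.
print(k_of_n_alarms([1, 1, 1, 1, 0, 1, 1, 1, 1, 0]))  # [9]
```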

4 Results
The deep learning model was developed on a workstation with an i9-10900K CPU and 64 GB of RAM and trained on a GeForce RTX 2060 GPU. Anaconda 3 (Python 3.7) served as the development environment, with TensorFlow 1.15. We trained the model with the Adam optimizer at a learning rate of 0.001. ReLU was used as the activation function, the growth rate was set to 32, and the reduction rate to 0.5.

In this study, a seizure occurrence period (SOP) of 30 min and a seizure prediction horizon (SPH) of 5 min were adopted, following the definitions proposed by Maiwald et al. [15]. The SOP is the period during which a seizure is expected to occur after an alarm, and the SPH is the interval between the alarm and the start of the SOP. If a seizure occurs within the SOP of an alarm, the prediction is counted as a true positive; if an alarm is raised but no seizure occurs within its SOP, it is counted as a false alarm.

The performance of the model is evaluated with several metrics: accuracy, F1-score, FPR, sensitivity, and AUC. Accuracy is the percentage of all samples that are correctly classified, and the F1-score is the harmonic mean of precision and sensitivity. FPR is the number of false alarms per hour, that is, how often interictal periods are incorrectly judged as preictal. AUC is the area under the Receiver Operating Characteristic (ROC) curve; it equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative sample.

Cross-validation is widely used to assess the performance of statistical methods by testing how well they generalize to independent data sets. In this study, we applied LOOCV to every case. Specifically, if a case had N seizures, we used N − 1 seizures to train the model and the remaining seizure for validation. Likewise, the interictal recordings were randomly partitioned into N segments, with N − 1 segments used for training and the remaining segment for validation.

Classification performance was evaluated on all 24 cases, with an average accuracy of 92.45%. Except for CHB05, CHB14, CHB16, CHB21, CHB22, and CHB24, all cases had a classification accuracy above 90%. Table 1 shows the performance for patients CHB01–CHB24.

Table 1. Per-patient performance of the proposed method (SOP = 30 min, SPH = 5 min)

Case | Accuracy (%) | Sensitivity (%) | FPR (/h) | F1-score | AUC
CHB01 | 99.99 | 99.96 | 0 | 0.99 | 0.99
CHB02 | 98.29 | 98.21 | 0.009 | 0.99 | 0.98
CHB03 | 92.81 | 93.3 | 0.066 | 0.93 | 0.92
CHB04 | 94.32 | 92.82 | 0.065 | 0.96 | 0.96
CHB05 | 79.7 | 81.71 | 0.211 | 0.81 | 0.80
CHB06 | 92.55 | 92.31 | 0.061 | 0.95 | 0.91
CHB07 | 93.51 | 94.11 | 0.043 | 0.96 | 0.94
CHB08 | 95.54 | 94.32 | 0.002 | 0.98 | 0.96
CHB09 | 94.66 | 95.21 | 0.031 | 0.96 | 0.94
CHB10 | 90.2 | 92.61 | 0.046 | 0.90 | 0.89
CHB11 | 98.82 | 98.36 | 0.008 | 0.98 | 0.97
CHB12 | 93.01 | 92.56 | 0.011 | 0.97 | 0.92
CHB13 | 90.32 | 91.67 | 0.126 | 0.95 | 0.89
CHB14 | 89.52 | 89.03 | 0.151 | 0.90 | 0.89
CHB15 | 91.22 | 91.93 | 0.06 | 0.92 | 0.90
CHB16 | 85.27 | 88.09 | 0.14 | 0.89 | 0.88
CHB17 | 93.77 | 93.81 | 0.007 | 0.95 | 0.94
CHB18 | 90.83 | 91.05 | 0.097 | 0.90 | 0.90
CHB19 | 99.96 | 99.99 | 0.006 | 0.99 | 0.99
CHB20 | 95.08 | 96.28 | 0.034 | 0.96 | 0.97
CHB21 | 89.67 | 89.96 | 0.139 | 0.90 | 0.89
CHB22 | 86.25 | 85.38 | 0.156 | 0.81 | 0.85
CHB23 | 97.43 | 97.66 | 0.027 | 0.99 | 0.97
CHB24 | 86.11 | 85.58 | 0.084 | 0.89 | 0.89
Average | 92.45 | 92.66 | 0.066 | 0.923 | 0.936
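The SOP/SPH convention can be made concrete with an event-based sketch. Under the definitions of Maiwald et al. [15], an alarm raised at time a is correct if a seizure starts in the interval [a + SPH, a + SPH + SOP]; every other alarm is false. The function below, whose name and minute-based time unit are assumptions, computes event-level sensitivity and FPR per hour from alarm times. It is not the authors' evaluation code, and the per-patient metrics in Table 1 may have been computed differently.

```python
def evaluate_alarms(alarm_times, seizure_onsets, interictal_hours, sph=5.0, sop=30.0):
    """Event-based sensitivity and false prediction rate (times in minutes).

    An alarm at time a is a true alarm if some seizure onset s satisfies
    a + sph <= s <= a + sph + sop; otherwise it is a false alarm."""
    predicted = set()
    false_alarms = 0
    for a in alarm_times:
        hits = [s for s in seizure_onsets if a + sph <= s <= a + sph + sop]
        if hits:
            predicted.update(hits)
        else:
            false_alarms += 1
    sensitivity = len(predicted) / len(seizure_onsets) if seizure_onsets else float("nan")
    fpr_per_hour = false_alarms / interictal_hours
    return sensitivity, fpr_per_hour

# Alarm at 70 min predicts the onset at 100 min (window 75-105 min);
# the alarm at 300 min is false because no onset falls in 305-335 min.
print(evaluate_alarms([70, 300], [100, 400], interictal_hours=10.0))  # (0.5, 0.1)
```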


5 Discussion
Our seizure prediction method combines a CNN (DenseNet) with a BiLSTM network. Unlike traditional CNNs, DenseNet preserves critical information without relearning redundant feature maps, which makes the model more parsimonious. In addition, the dense connections give each layer direct access to the gradient of the loss function and to the original input. This simplifies training, and the regularizing effect of the dense connections helps reduce overfitting on tasks with small training sets. Because EEG signals are temporal, BiLSTM networks are well suited to capturing temporal information and establishing temporal dependencies. They help exploit dynamic changes within the signals and make full use of bidirectional information to enhance the extraction of crucial features. Consequently, BiLSTM networks can effectively recognize features and patterns over extended time scales when processing EEG signals.

Overfitting occurs when a model is excessively complex, so that it fits the training data too closely but performs poorly on new data. To prevent overfitting in this experiment, we used early stopping. During training, the data are split into a training set and a validation set. At the end of each epoch, the model's performance on the validation set is assessed to monitor its progress, and the best validation accuracy is recorded. If the validation accuracy begins to decrease after a number of epochs, training is halted and the model with the best validation accuracy is selected as the final model. Early stopping balances training duration against generalization error by stopping training when the model starts to overfit the training data.
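The early-stopping rule can be sketched as a small framework-independent helper. The patience value, the class name, and the dictionary used as stand-in "weights" are hypothetical, because the paper does not specify them. Keras offers a comparable built-in callback, tf.keras.callbacks.EarlyStopping, whose restore_best_weights option restores the best epoch's weights.

```python
import copy

class EarlyStopping:
    """Track validation accuracy, keep the best weights, and signal when to stop."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience        # epochs without improvement before stopping
        self.min_delta = min_delta      # minimum gain that counts as an improvement
        self.best_acc = float("-inf")
        self.best_epoch = -1
        self.best_weights = None
        self.wait = 0

    def step(self, epoch, val_acc, weights):
        """Record one epoch; return True when training should stop."""
        if val_acc > self.best_acc + self.min_delta:
            self.best_acc, self.best_epoch = val_acc, epoch
            self.best_weights = copy.deepcopy(weights)   # snapshot of the best model
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience

stopper = EarlyStopping(patience=3)
val_accs = [0.70, 0.80, 0.85, 0.84, 0.83, 0.86, 0.85, 0.84, 0.83]
for epoch, acc in enumerate(val_accs):
    if stopper.step(epoch, acc, weights={"epoch": epoch}):   # toy "weights"
        break
print(epoch, stopper.best_epoch, stopper.best_acc)   # 8 5 0.86
```

A patience window tolerates short dips in validation accuracy (epochs 3 and 4 above) instead of stopping at the first decrease.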
Class imbalance is common in EEG classification tasks. To address it, we used oversampling: additional preictal segments were generated during the training phase to balance the ratio of preictal to interictal samples in the dataset. For DenseNet-based feature extraction, we used global average pooling to turn the feature maps into one-dimensional vectors instead of fully connected layers with an excessive number of parameters, which helped mitigate overfitting. The cross-entropy loss function measures the discrepancy between the probability distribution predicted by the network and the actual labels. We minimized the cross-entropy loss during training, and Fig. 5(a) shows the convergence of the loss on both the training and validation sets. To assess the validity of the proposed method, we partitioned the EEG samples into training and validation sets and used LOOCV for estimation. As shown in Fig. 5(b), the accuracy stabilized after 30 to 40 epochs.

Fig. 5. The loss curve and accuracy curve in the training phase of a randomly selected patient. (a) loss curve; (b) accuracy curve.

Table 2 compares the performance of this approach with other published methods. Shahbazi et al. [16] preprocessed the EEG signal with the STFT and fed it into a CNN-LSTM model, obtaining an FPR of 0.13/h and a sensitivity of 98.21%. Although their sensitivity is higher than ours, we used all patients in the dataset, a much larger sample than theirs. Ryu et al. [17] used the discrete wavelet transform for preprocessing and fed the transformed data into a DenseNet-LSTM network; on the CHB-MIT dataset they obtained 93.28% prediction accuracy, an FPR of 0.063/h, and an F1-score of 0.923. Ryu et al. use 2 dense blocks and 2 transition layers, whereas our model uses 4 dense blocks and 3 transition layers, which makes it richer and more effective at feature extraction. Moreover, the BiLSTM in our model captures bidirectional information in the input sequence, which helps the model better understand and use the timing information in the EEG signal. Although the performance of the two methods is comparable, our use of the SOP and SPH concepts helps the model predict seizures more accurately; Ryu et al. did not use these concepts, which may affect their prediction performance. Zhang et al. [8] extracted features with wavelet packet decomposition and a CSP feature extractor and used a shallow CNN to discriminate preictal from interictal states. Their model obtained a sensitivity of 92.2%, an FPR of 0.12/h, and an accuracy of 90%. Our proposed method outperforms theirs on all performance metrics. Recently, Gao et al. [18] introduced an end-to-end approach that predicts seizures with a spatiotemporal multiscale CNN with dilated convolutions. Evaluated on 16 cases of the same dataset, it achieved a sensitivity of 93.3% and an FPR of 0.007/h. Notably, our method used all patients in the dataset without any selection and achieved similar sensitivity. Lee et al. [19] developed a model combining ResNet and LSTM that takes STFT spectrogram images as input and validated it with LOOCV on the CHB-MIT dataset, obtaining 91.90% accuracy, 89.64% sensitivity, and an FPR of 0.058/h. Our method achieves higher accuracy and sensitivity than this method, with a comparable FPR.


Table 2. Performance comparison of different methods on the CHB-MIT dataset

Authors | Year | Number of patients | Method | Acc (%) | Sen (%) | FPR (/h) | AUC | F1-score
Khan et al. [7] | 2017 | 13 | CWT + CNN | – | 87.8 | 0.147 | – | –
Truong et al. [14] | 2018 | 13 | STFT + CNN | – | 81.2 | 0.16 | – | –
Shahbazi et al. [16] | 2018 | 14 | STFT + CNN-LSTM | – | 98.2 | 0.13 | – | –
Zhang et al. [8] | 2019 | 23 | Wavelet packet + CNN | 90.0 | 92.0 | 0.12 | 0.90 | 0.91
Ryu et al. [17] | 2021 | 24 | DWT + DenseNet-LSTM | 93.28 | 92.92 | 0.063 | – | 0.923
Gao et al. [18] | 2022 | 16 | Raw EEG + Dilated CNN | – | 93.3 | 0.007 | – | –
Lee et al. [19] | 2022 | 24 | STFT + ResNet + LSTM | 91.90 | 89.64 | 0.058 | – | –
This work | 2023 | 24 | STFT + DenseNet-BiLSTM | 92.45 | 92.66 | 0.066 | 0.936 | 0.923

6 Conclusion
This study introduced a novel deep learning approach to seizure prediction from scalp EEG data. The DenseNet-BiLSTM hybrid model is trained on EEG signals preprocessed with the STFT: DenseNet automatically extracts time-frequency features, and the BiLSTM network performs the subsequent classification. The method avoids intricate manual feature extraction and offers advantages in computational cost and time, improving efficiency when processing large-scale data. Validation on the CHB-MIT dataset and comparison with prior studies support the reliability and suitability of the proposed model for seizure prediction, and its strong performance suggests potential practical applications. Future work should validate the model on a broader range of datasets to confirm and further improve its performance.

Acknowledgment. This work was supported by the Program for Youth Innovative Research Team in the University of Shandong Province in China (No. 2022KJ179), and jointly supported by the National Natural Science Foundation of China (No. 61972226, No. 62172253).

References
1. Preux, P.-M., Druet-Cabanac, M.: Epidemiology and aetiology of epilepsy in sub-Saharan Africa. Lancet Neurol. 4(1), 21–31 (2005)
2. Savadkoohi, M., Oladunni, T., Thompson, L.: A machine learning approach to epileptic seizure prediction using electroencephalogram (EEG) signal. Biocybern. Biomed. Eng. 40(3), 1328–1341 (2020)
3. Usman, S.M., Usman, M., Fong, S.: Epileptic seizures prediction using machine learning methods. Comput. Math. Methods Med. (2017)
4. Williamson, J.R., Bliss, D.W., Browne, D.W., Narayanan, J.T.: Seizure prediction using EEG spatiotemporal correlation structure. Epilepsy Behav. 25(2), 230–238 (2012)
5. Liang, W., Pei, H., Cai, Q., Wang, Y.: Scalp EEG epileptogenic zone recognition and localization based on long-term recurrent convolutional network. Neurocomputing 396, 569–576 (2020)
6. Daoud, H., Bayoumi, M.A.: Efficient epileptic seizure prediction based on deep learning. IEEE Trans. Biomed. Circuits Syst. 13(5), 804–813 (2019)
7. Khan, H., Marcuse, L., Fields, M., Swann, K., Yener, B.: Focal onset seizure prediction using convolutional networks. IEEE Trans. Biomed. Eng. 65(9), 2109–2118 (2017)
8. Zhang, Y., Guo, Y., Yang, P., Chen, W., Lo, B.: Epilepsy seizure prediction on EEG using common spatial pattern and convolutional neural network. IEEE J. Biomed. Health Inform. 24(2), 465–474 (2019)
9. Shoeb, A.H.: Application of machine learning to epileptic seizure onset detection and treatment. PhD thesis, Massachusetts Institute of Technology (2009)
10. Stockwell, R.G., Mansinha, L., Lowe, R.: Localization of the complex spectrum: the S transform. IEEE Trans. Signal Process. 44(4), 998–1001 (1996)
11. Huang, G., Liu, Z., Van Der Maaten, L., Weinberger, K.Q.: Densely connected convolutional networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4700–4708 (2017)
12. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
13. Siami-Namini, S., Tavakoli, N., Namin, A.S.: The performance of LSTM and BiLSTM in forecasting time series. In: IEEE International Conference on Big Data (Big Data), pp. 3285–3292 (2019)
14. Truong, N.D., et al.: Convolutional neural networks for seizure prediction using intracranial and scalp electroencephalogram. Neural Netw. 105, 104–111 (2018)
15. Maiwald, T., Winterhalder, M., Aschenbrenner-Scheibe, R., Voss, H.U., Schulze-Bonhage, A., Timmer, J.: Comparison of three nonlinear seizure prediction methods by means of the seizure prediction characteristic. Physica D 194(3–4), 357–368 (2004)
16. Shahbazi, M., Aghajan, H.: A generalizable model for seizure prediction based on deep learning using CNN-LSTM architecture. In: 2018 IEEE Global Conference on Signal and Information Processing (GlobalSIP), pp. 469–473 (2018)
17. Ryu, S., Joe, I.: A hybrid DenseNet-LSTM model for epileptic seizure prediction. Appl. Sci. 11(16), 7661 (2021)
18. Gao, Y., et al.: Pediatric seizure prediction in scalp EEG using a multi-scale neural network with dilated convolutions. IEEE J. Transl. Eng. Health Med. 10, 1–9 (2022)
19. Lee, D., et al.: A ResNet-LSTM hybrid model for predicting epileptic seizures using a pretrained model with supervised contrastive learning. Res. Square (2022)

Data Augmentation for Environmental Sound Classification Using Diffusion Probabilistic Model with Top-K Selection Discriminator

Yunhao Chen1(B), Zihui Yan1, Yunjie Zhu2, Zhen Ren1, Jianlu Shen1, and Yifan Huang1

1 Jiangnan University, Wuxi 214000, China
2 University of Leeds, Leeds LS2 9JT, UK

Abstract. Despite consistent advances in powerful deep learning techniques in recent years, models still need large amounts of training data to avoid overfitting. Synthetic datasets generated with generative adversarial networks (GANs) have recently been used to overcome this problem. Nevertheless, GAN-based methods are usually hard to train or fail to generate high-quality data samples. In this paper, we propose a data augmentation technique for environmental sound classification (ESC) based on the diffusion probabilistic model (DPM), using DPM-Solver++ for fast sampling. In addition, to ensure the quality of the generated spectrograms, we propose a top-k selection technique that filters out low-quality synthetic data samples. According to the experimental results, the synthetic data samples have features similar to those of the original dataset and, compared with traditional data augmentation techniques, significantly increase the classification accuracy of different state-of-the-art models. The code is publicly available at https://github.com/JNAIC/DPMs-for-Audio-Data-Augmentation.
Keywords: Diffusion Probabilistic Models · Data Augmentation · Environmental Sound Classification

1 Introduction
Deep learning models such as CNNs, transformers, and CNN-RNN networks have significantly improved sound classification [2–5]. However, because of their many parameters, these models need large amounts of data to perform well. Limited training data poses a major challenge for deep learning methods, and laborious, expensive data annotation makes supervised models even harder to develop. Data augmentation has been proposed to overcome this scarcity of training samples. Traditional data augmentation techniques such as flipping, shifting, and masking [2, 6] are limited by their simplicity and linearity and cannot effectively improve the performance of deep learning models such as CNNs or transformers. Improving accuracy therefore requires either augmentation techniques of similar complexity to the deep learning models or generative models that represent the probability distribution of real data samples.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023. D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 283–295, 2023. https://doi.org/10.1007/978-981-99-4742-3_23

Researchers have turned to generative adversarial networks (GANs) [2, 4, 6] to synthesize new data samples that closely resemble real ones, thereby improving deep learning performance. However, despite their many applications in data augmentation, GANs suffer from unstable training, mode collapse, and failure to represent a sufficiently broad data distribution [7, 8]. Consequently, GANs are difficult to scale rapidly to new applications. DPMs [9, 10] have gained popularity over GANs, particularly in tasks such as image synthesis [10], medical image synthesis with limited data [11], and topology optimization [12]. We leverage DPMs for high-quality data augmentation. However, the conventional sampling procedures of popular diffusion models, such as denoising diffusion implicit models (DDIM) [13], require 50 to 100 steps to generate high-quality data samples, which is time-consuming. In contrast, DPM-Solver++ [14] achieves similar results in just 10 to 20 steps, so we adopt DPM-Solver++ for efficient sampling.

DPMs can generate diverse data samples from complex distributions by reversing a Markov chain of Gaussian diffusion steps. However, the quality of the samples they generate is not always satisfactory: the samples may contain artefacts, blurriness, or inconsistencies with the target distribution. We therefore propose a top-k selection method that uses a pretrained model to filter out inappropriate data samples. Our method can improve the quality of data generation without modifying the DPMs.

To summarize, the main contributions of this paper are as follows:
1. We present the first study applying DPMs to generate high-quality data samples for ESC tasks, based on the popular sound dataset UrbanSound8K [15].
2. We propose a post-processing approach called top-k selection, based on a pretrained discriminator, that automatically eliminates low-quality and insufficiently representative samples after training.
3. We evaluate seven state-of-the-art (SOTA) deep learning (DL) models for ESC on UrbanSound8K with real and generated synthetic data. We train the models from scratch and show significant accuracy improvements with synthetic data.
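The selection rule is not detailed in this section. As a hedged illustration, one plausible reading of "top-k selection based on a pretrained discriminator" is to keep a synthetic sample only if its intended class is among the k classes that a pretrained classifier ranks highest. The sketch below implements that assumed rule on predicted class probabilities; the function name and the rule itself are assumptions, not the authors' code.

```python
import numpy as np

def top_k_filter(class_probs, intended_labels, k):
    """Indices of synthetic samples whose intended label is in the classifier's top-k.

    class_probs: (n_samples, n_classes) probabilities from a pretrained classifier.
    intended_labels: (n_samples,) class each sample was generated for."""
    top_k = np.argsort(-class_probs, axis=1)[:, :k]           # k highest-scoring classes
    keep = (top_k == intended_labels[:, None]).any(axis=1)
    return np.flatnonzero(keep)

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.5, 0.4, 0.1]])
labels = np.array([0, 0, 1])
print(top_k_filter(probs, labels, k=1))   # [0]
print(top_k_filter(probs, labels, k=2))   # [0 2]
```

A larger k keeps more synthetic data at the cost of admitting samples the classifier finds less typical of their class.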

2 Method
2.1 DPMs
DPMs [16, 17] are a class of generative models that convert Gaussian noise into samples from a learned data distribution through an iterative denoising process. Diffusion models are grounded in non-equilibrium thermodynamics. [16, 17] define a Markov chain of diffusion steps that gradually adds random noise to the data, and then use deep learning to reverse the diffusion process so that the desired data samples can be created from noise.
Forward Diffusion Process. A forward diffusion process adds Gaussian noise to a data sample $x_0$ drawn from the real data distribution $q(x)$ over $T$ steps, producing a sequence of noisy samples $x_1, \ldots, x_T$. The amount of noise added at each step is determined by a variance schedule $\{\beta_t \in (0, 1)\}_{t=1}^{T}$:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t; \sqrt{1-\beta_t}\,x_{t-1}, \beta_t \mathbf{I}\right), \qquad q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1}) \tag{1}$$

The data sample $x_0$ gradually loses its distinctive features as $t$ increases, and as $T \to \infty$, $x_T$ converges to an isotropic Gaussian distribution. A useful property of this process is that $x_t$ can be sampled at any arbitrary time step in closed form using the reparameterization trick. With $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{i=1}^{t} \alpha_i$, we obtain

$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon, \qquad q(x_t \mid x_0) = \mathcal{N}\left(x_t; \sqrt{\bar\alpha_t}\,x_0, (1-\bar\alpha_t)\mathbf{I}\right) \tag{2}$$

where $\epsilon \sim \mathcal{N}(0, \mathbf{I})$ merges the per-step noise terms $\epsilon_{t-1}, \epsilon_{t-2}, \ldots \sim \mathcal{N}(0, \mathbf{I})$ into a single Gaussian.
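The closed-form sampling of Eq. (2) can be checked numerically against step-by-step application of Eq. (1). This NumPy sketch assumes a linear schedule from 1e-4 to 0.02 over T = 1000 steps, a common choice that the text does not specify; the function names are hypothetical.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)   # assumed linear schedule; index s stands for step s + 1
alphas = 1.0 - betas                  # alpha_t = 1 - beta_t
alpha_bar = np.cumprod(alphas)        # alpha_bar_t = product of alpha_1 ... alpha_t

def q_sample(x0, t, rng):
    """Draw x_t ~ q(x_t | x_0) in one shot with the reparameterization of Eq. (2)."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps, eps

def q_sample_iterative(x0, t, rng):
    """Reach step t by applying the per-step transition of Eq. (1) t + 1 times."""
    x = x0
    for s in range(t + 1):
        x = np.sqrt(1.0 - betas[s]) * x + np.sqrt(betas[s]) * rng.standard_normal(x0.shape)
    return x

rng = np.random.default_rng(0)
x0 = np.ones(100_000)
xt, _ = q_sample(x0, 499, rng)
xi = q_sample_iterative(x0, 499, rng)
# Both estimates should be close to sqrt(alpha_bar) and sqrt(1 - alpha_bar).
print(xt.mean(), xi.mean(), np.sqrt(alpha_bar[499]))
print(xt.std(), xi.std(), np.sqrt(1.0 - alpha_bar[499]))
```

The one-shot form is what makes training efficient: any step t can be sampled directly without simulating the chain.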

Reverse Diffusion Process. Reversing this process and sampling from $q(x_{t-1} \mid x_t)$ would allow us to reconstruct a true sample from Gaussian noise $x_T \sim \mathcal{N}(0, \mathbf{I})$. If $\beta_t$ is sufficiently small, $q(x_{t-1} \mid x_t)$ is also Gaussian. However, $q(x_{t-1} \mid x_t)$ cannot be estimated directly, because doing so would require the entire dataset. We therefore learn a model $p_\theta$ that approximates these conditional probabilities in order to run the reverse diffusion process:

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t)\right) \tag{3}$$

where $\theta$ is the learnable parameter vector of the Gaussian's mean function $\mu_\theta(x_t, t)$ and covariance function $\Sigma_\theta(x_t, t)$. The joint distribution of the generated samples can then be written as

$$p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t) \tag{4}$$

New samples are thus generated by applying the learned parameters $\theta$ to the mean function $\mu_\theta(x_t, t)$ and the covariance function $\Sigma_\theta(x_t, t)$. In brief, the forward diffusion process adds noise to the data sample, while the reverse diffusion process removes the noise and creates new data samples.

Training Objective of DPMs. As in the Variational Autoencoder [25], the variational lower bound can be used to optimize the negative log-likelihood as follows:


Y. Chen et al.

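$$L_{\mathrm{VLB}} = \mathbb{E}_{q(x_{0:T})}\left[\log \frac{q(x_{1:T} \mid x_0)}{p_\theta(x_{0:T})}\right] \geq -\mathbb{E}_{q(x_0)} \log p_\theta(x_0) \tag{5}$$

To make each term analytically computable, the objective can be rewritten as a combination of several KL-divergence and entropy terms:

$$L_{\mathrm{VLB}} = L_T + L_{T-1} + \dots + L_0, \quad L_T = D_{\mathrm{KL}}\big(q(x_T \mid x_0)\,\|\,p(x_T)\big), \quad L_t = D_{\mathrm{KL}}\big(q(x_t \mid x_{t+1}, x_0)\,\|\,p_\theta(x_t \mid x_{t+1})\big), \quad L_0 = -\log p_\theta(x_0 \mid x_1)$$

Since $x_0$ follows a fixed data distribution and $x_T$ is Gaussian noise, $L_T$ is a constant. $L_0$ can be interpreted as the entropy of a multivariate Gaussian distribution, because $p_\theta(x_0 \mid x_1)$ is Gaussian with mean $\mu_\theta(x_1, 1)$ and covariance matrix $\Sigma_\theta$. The loss term $L_t$, $t \in \{1, 2, \dots, T-1\}$, can be parameterized as

$$L_t = \mathbb{E}_{x_0, \epsilon_t}\left[\frac{\beta_t^2}{2\alpha_t(1-\bar{\alpha}_t)\|\Sigma_\theta\|_2^2}\big\|\epsilon_t - \epsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon_t, t\big)\big\|^2\right] + C \tag{6}$$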

According to [17], the diffusion model can be trained more effectively with a simplified objective that drops the weighting term:

$$L_{\mathrm{simple}}(\theta) = \mathbb{E}_{x_0, t, \epsilon_t}\Big[\big\|\epsilon_t - \epsilon_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon_t, t\big)\big\|^2\Big] + C \tag{7}$$

Classifier-Free Guidance for Conditional Generation. Conditional diffusion can be performed by combining the scores of a conditional and an unconditional DPM. The unconditional DPM $p_\theta(x)$ is parameterized by a score estimator $\epsilon_\theta(x_t, t)$, while the conditional model is parameterized by $\epsilon_\theta(x_t, t, y)$; a single neural network can learn both models simultaneously. The gradient of the implicit classifier can be expressed with the conditional and unconditional score estimators:

$$\nabla_{x_t} \log p(y \mid x_t) = -\frac{1}{\sqrt{1-\bar{\alpha}_t}}\big(\epsilon_\theta(x_t, t, y) - \epsilon_\theta(x_t, t)\big) \tag{8}$$

The modified score that incorporates this gradient therefore does not depend on a separately trained classifier:

$$\bar{\epsilon}_\theta(x_t, t, y) = (w + 1)\,\epsilon_\theta(x_t, t, y) - w\,\epsilon_\theta(x_t, t) \tag{9}$$

where $w$ is the guidance scale.
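Eq. (9) needs only the conditional and unconditional noise predictions of the same network. A minimal sketch on plain Python lists; the prediction values are made-up placeholders and `w` is the guidance scale:

```python
def guided_eps(eps_cond, eps_uncond, w):
    """Classifier-free guidance, Eq. (9): (w + 1) * eps(x_t, t, y) - w * eps(x_t, t)."""
    return [(w + 1.0) * c - w * u for c, u in zip(eps_cond, eps_uncond)]


# Made-up noise predictions of one network, evaluated with and without the label y
eps_c = [0.2, -0.1]
eps_u = [0.1, 0.1]
print(guided_eps(eps_c, eps_u, w=0.0))  # w = 0 returns the conditional prediction
print(guided_eps(eps_c, eps_u, w=2.0))
```

Algebraically, Eq. (9) equals $\epsilon_\theta(x_t, t) + (w+1)\big(\epsilon_\theta(x_t, t, y) - \epsilon_\theta(x_t, t)\big)$: the unconditional prediction is pushed along the difference that Eq. (8) relates to the implicit classifier gradient, and larger $w$ pushes further.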

In this study, we present the first application of classifier-free guided DPMs to generating high-quality data samples for ESC tasks.

2.2 DPM-Solver++ and DPM-Solver

One of the main challenges of working with DPMs is the high computational cost and time required to generate data samples from the complex posterior distributions. To overcome this limitation, we adopt DPM-Solver++ [14] as our sampling method. DPM-Solver is a high-order solver that can generate high-quality samples in around 10 steps by solving the diffusion ODE. DPM-Solver++ is an improved version of DPM-Solver that solves the diffusion ODE with a data prediction model and handles guided sampling by using thresholding methods to keep the solution matching the training data distribution. Both methods significantly improve the efficiency of training-free samplers in the "few-step sampling" regime, where sampling can be accomplished within approximately 10 sequential function evaluations. They do so by solving the diffusion ordinary differential equations (ODEs) associated with DPMs. The diffusion ODEs have a semi-linear structure: a linear function of the data variable plus a nonlinear function parameterized by a neural network. By computing the linear part of the solution analytically, DPM-Solver++ and DPM-Solver avoid the discretization error of that part. In addition, the solution can be simplified by a change of variables, which enables efficient computation with numerical methods for exponential integrators. The customized solver for diffusion ODEs is:

$$\tilde{x}_{t_{i-1} \to t_i} = \frac{\alpha_{t_i}}{\alpha_{t_{i-1}}}\tilde{x}_{t_{i-1}} - \alpha_{t_i}\sum_{n=0}^{k-1}\hat{\epsilon}_\theta^{(n)}\big(\hat{x}_{\lambda_{t_{i-1}}}, \lambda_{t_{i-1}}\big)\int_{\lambda_{t_{i-1}}}^{\lambda_{t_i}} e^{-\lambda}\frac{(\lambda - \lambda_{t_{i-1}})^n}{n!}\,\mathrm{d}\lambda \tag{10}$$

where $x_s$ is an initial value at time $s > 0$, $x_t$ is the solution at time $t \in [0, s]$, and $\lambda_t := \log(\alpha_t/\sigma_t)$ is half the log signal-to-noise ratio. In this solver notation, $\alpha_t$ and $\sigma_t$ are the signal and noise coefficients of $q(x_t \mid x_0) = \mathcal{N}(\alpha_t x_0, \sigma_t^2 I)$, i.e., $\sqrt{\bar{\alpha}_t}$ and $\sqrt{1-\bar{\alpha}_t}$ in the notation of Sect. 2.1. Because $\lambda(t) = \lambda_t$ is strictly decreasing in $t$, it has an inverse function $t_\lambda(\cdot)$ satisfying $t = t_\lambda(\lambda(t))$. The solver further changes the subscripts of $x$ and $\epsilon_\theta$ from $t$ to $\lambda$, denoting $\hat{x}_\lambda := x_{t_\lambda(\lambda)}$ and $\hat{\epsilon}_\theta(\hat{x}_\lambda, \lambda) := \epsilon_\theta(x_{t_\lambda(\lambda)}, t_\lambda(\lambda))$. The high-order error term $\mathcal{O}(h_i^{k+1})$, with $h_i = \lambda_{t_i} - \lambda_{t_{i-1}}$, is omitted in Eq. (10). DPM-Solver was designed for sampling without guidance, while DPM-Solver++ is tailored to guided sampling to improve sample quality and diversity. When additional information such as text or images is used for conditioning, DPM-Solver++ improves the sampling process of DPMs and yields better samples.
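For k = 1 the sum in Eq. (10) has a single term and the integral has a closed form, giving $\tilde{x}_{t_i} = \frac{\alpha_{t_i}}{\alpha_{t_{i-1}}}\tilde{x}_{t_{i-1}} - \sigma_{t_i}(e^{h_i} - 1)\hat{\epsilon}_\theta$ with $h_i = \lambda_{t_i} - \lambda_{t_{i-1}}$ (first-order DPM-Solver, which coincides with DDIM). The sketch assumes a variance-preserving diffusion with $\alpha_t = \sqrt{\bar{\alpha}_t}$, $\sigma_t = \sqrt{1-\bar{\alpha}_t}$ and $\lambda_t = \log(\alpha_t/\sigma_t)$; this $\alpha_t$ is the solver's signal coefficient, not the $\alpha_t = 1-\beta_t$ of Sect. 2.1.

```python
import math


def vp_coeffs(alpha_bar):
    """Solver notation for a VP diffusion: alpha_t, sigma_t and lambda_t = log(alpha_t / sigma_t)."""
    alpha, sigma = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return alpha, sigma, math.log(alpha / sigma)


def dpm_solver_1_step(x_s, eps_pred, alpha_bar_s, alpha_bar_t):
    """One k = 1 step of Eq. (10) from time s to a less noisy time t (alpha_bar_t > alpha_bar_s)."""
    a_s, _, lam_s = vp_coeffs(alpha_bar_s)
    a_t, sig_t, lam_t = vp_coeffs(alpha_bar_t)
    h = lam_t - lam_s
    # alpha_t * (e^{-lam_s} - e^{-lam_t}) simplifies to sigma_t * (e^h - 1)
    return [a_t / a_s * x - sig_t * math.expm1(h) * e for x, e in zip(x_s, eps_pred)]


x0, eps = [0.5, -1.0], [0.3, 0.7]
a_s, sig_s, _ = vp_coeffs(0.3)
x_s = [a_s * a + sig_s * e for a, e in zip(x0, eps)]
x_t = dpm_solver_1_step(x_s, eps, 0.3, 0.7)  # oracle noise prediction
print(x_t)
```

With an exact noise prediction the step lands exactly on $\alpha_t x_0 + \sigma_t \epsilon$, a convenient sanity check; the higher-order variants (k = 2, 3) add correction terms built from additional model evaluations.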


DPM-Solver++ incorporates thresholding to adapt to guided sampling. Thresholding restricts the solution of the diffusion ODE to the range of the training data distribution, enhancing sample quality and diversity. Two types of thresholding are employed: dynamic and static. Dynamic thresholding adjusts the threshold according to the noise level and the guidance scale, while static thresholding uses a fixed value. By combining both approaches, DPM-Solver++ achieves a balance between stability and efficiency.

2.3 Data Augmentation

Data augmentation is a powerful way to broaden the range of the existing data and to train models without collecting new data. In this research, both standard and intelligent data augmentation methods are considered. Three distinct audio deformations are applied in conventional data augmentation. First, background noises, including crowd, street, and restaurant sounds, are added to the data samples (the background noises were taken from publicly available recordings on the "freesound.org" website [18]). Second, the recordings are pitch-shifted [2]: the pitch of the audio samples is shifted by a half-octave (up and down) to produce varied sounds. Third, time stretching [22] is applied to the audio samples. Each deformation is applied to the audio stream before it is transformed into the input representation. For intelligent data augmentation, we use the U-Net structure from [20] for the DPMs and DPM-Solver++ as the sampling schedule.

2.4 Top-k Selection Pretrained Discriminator

One of the challenges of using DPMs for data augmentation is that the quality of the generated samples may vary with the amount of available data and computational resources. To address this issue, we use a pretrained discriminator network to filter out low-quality samples and retain only those that are realistic and diverse enough to augment the training data.
The discriminator network is an Xception [27] trained on the entire dataset to classify the images into their corresponding labels. The filtering criterion is based on the discriminator's top-k accuracy: a generated sample is accepted if the label it was generated for is among the top-k predictions of the discriminator, and rejected otherwise. This helps ensure that the generated samples are visually plausible and semantically consistent with their labels. The number of accepted samples can be expressed as

$$G = \sum_{i=1}^{N} \mathbb{I}\big(f_k(x_i, c_i) = c_i\big) \tag{11}$$

where $c_i$ is the label of the $i$-th generated sample $x_i$, $f_k$ is the discriminator network with top-k prediction (returning $c_i$ when $c_i$ is among its $k$ highest-scoring classes), $\mathbb{I}(\cdot)$ is the indicator function, and $N$ is the number of generated samples.
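The acceptance rule behind Eq. (11) can be sketched as follows. The discriminator here is a hypothetical stand-in (any function returning one score per class); in the paper it is the pretrained Xception.

```python
def topk_labels(scores, k):
    """Class indices of the k highest scores (ties broken by the lower index)."""
    return sorted(range(len(scores)), key=lambda j: (-scores[j], j))[:k]


def topk_filter(samples, labels, discriminator, k=1):
    """Keep generated samples whose conditioning label c_i is among the discriminator's
    top-k classes; also return G, the number of accepted samples in Eq. (11)."""
    accepted = [(x, c) for x, c in zip(samples, labels) if c in topk_labels(discriminator(x), k)]
    return accepted, len(accepted)


# Hypothetical stand-in for the pretrained Xception: each "sample" maps to class scores.
fake_scores = {"a": [0.7, 0.2, 0.1], "b": [0.3, 0.4, 0.3], "c": [0.1, 0.1, 0.8]}
samples, labels = ["a", "b", "c"], [0, 0, 2]
accepted, G = topk_filter(samples, labels, fake_scores.get, k=1)
print(G, [x for x, _ in accepted])
```

Raising k loosens the filter; once k equals the number of classes, every generated sample is accepted.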


2.5 DL Models for ESC

One of the main objectives of this study is to assess the quality of the synthetic data samples generated by DPMs for environmental sound classification tasks. To this end, we use the synthetic data samples to augment the original training data and then train different deep learning (DL) classifiers on the augmented data. We hypothesize that the augmented data improve the models' performance. To test this hypothesis, we select seven SOTA DL models spanning different architectures and paradigms: ResNet-50 [19], Xception [27], ConViT-tiny [28], MobileViTv2-50 [29], MobileViTv2-150 [29], ConvNext-tiny [30], and DeiT III [31]. These models are implemented with the timm library [23] using their default hyperparameters. We evaluate their performance on the UrbanSound8K dataset.

3 Experiments

3.1 Experiments Pipeline and Dataset

The proposed experimental procedure (Fig. 1) involves several steps. Audio is first transformed into a mel-spectrogram. Training samples are augmented, resized to 128 × 128, and added to the original dataset. A discriminator model is trained on this augmented dataset, and, in parallel, a DPM is trained on it to generate data samples. The trained discriminator filters out unqualified samples, resulting in 8730 qualified synthetic data samples. Two comparison experiments are conducted: one uses the original data samples with traditional augmentation techniques, and the other uses both original and synthetic data samples for training. UrbanSound8K is an audio dataset that contains 8732 labelled sound excerpts (≤ 4 s) of urban sounds from 10 classes. It can be used for various tasks such as urban sound classification, sound event detection, and acoustic scene analysis.

Fig. 1. An overview of the proposed pipeline for generating synthetic data samples. The pipeline consists of four main components: data augmentation, diffusion probabilistic modelling, discriminator filtering, and comparison experiments.


3.2 Hyperparameters Setting

Data Preprocessing. This section investigates the performance of the DPM in augmenting the UrbanSound8K data samples and optimizing the classification model for environmental sound recognition. The data samples are up to 5 s long. MFCC feature extraction is used, and mel-spectrograms are generated. The original images are 768 × 384. To reduce the number of parameters in DPM training, the images are resized to 128 × 128, i.e., each image has 128 frames and 128 bands.

Hyperparameters Setting for Data Augmentation. We apply stochastic data augmentation controlled by the hyperparameter p to improve the stability of model training. Each data sample undergoes at least one and at most two random transformations (see Table 1 for details). To augment a data sample with city ambience noise, we randomly select a noise segment of equal length and superimpose it by adding the amplitudes. Pitch shifting alters a sound's pitch by changing its frequency. The pitch shift factor quantifies the degree of pitch shifting; a value of 2 doubles the frequency. Up and down pitch shifting are performed separately. Time stretching changes the speed or duration of an audio signal without affecting its pitch; the minimum and maximum rate parameters bound the range of speed change factors.

Table 1. Stochastic data augmentation settings.

| Transformation | p | Setting |
|---|---|---|
| Noise from City Ambience | 0.6 | Ambience's weight is 0.6 |
| Up Pitch Shifting | 0.8 | The pitch shift factor is 2 |
| Down Pitch Shifting | 0.8 | The pitch shift factor is 2 |
| Time stretch | 0.7 | Min: 0.8, Max: 1.25 |
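The rule "at least one and up to two random transformations" can be implemented in several ways. The sketch is one hypothetical reading of it: each transformation in Table 1 is drawn independently with its probability p, at most two draws are kept, and if nothing is drawn one transformation is picked uniformly. Only the city-noise mixing is spelled out (a random noise segment of equal length added with weight 0.6); pitch shifting and time stretching would come from an audio library and are represented only by name.

```python
import random

# (transformation, probability p) pairs taken from Table 1
AUGMENTATIONS = [("city_noise", 0.6), ("pitch_up", 0.8), ("pitch_down", 0.8), ("time_stretch", 0.7)]


def choose_augmentations(rng):
    """Hypothetical selection rule: draw each transformation with probability p,
    shuffle the draws, keep at most two, and fall back to one uniform pick if none was drawn."""
    chosen = [name for name, p in AUGMENTATIONS if rng.random() < p]
    rng.shuffle(chosen)
    if not chosen:
        chosen = [rng.choice(AUGMENTATIONS)[0]]
    return chosen[:2]


def add_city_noise(signal, noise, weight=0.6, rng=random):
    """Superimpose a randomly placed noise segment of the same length, scaled by `weight`."""
    if len(noise) < len(signal):
        raise ValueError("noise recording must be at least as long as the signal")
    start = rng.randrange(len(noise) - len(signal) + 1)
    segment = noise[start:start + len(signal)]
    return [s + weight * n for s, n in zip(signal, segment)]


rng = random.Random(0)
print(choose_augmentations(rng))
print(add_city_noise([0.1, -0.2, 0.3], [0.5, 0.5, 0.5, 0.5], rng=rng))
```

Drawing each transformation independently keeps the per-transformation rates close to the p values in Table 1 while still capping the total distortion per sample.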

DPM Training Setting. We use the U-Net [21] structure from [20] to estimate $p_\theta$. The U-Net uses skip connections and four upsampling operations to capture fine details and enable multi-scale prediction, which limits the loss of information during sampling. It is configured with dim = 64, dim_mults = (1, 2, 4, 8), resnet_block_groups = 8, and learned_sinusoidal_dim = 16.


We train the U-Net with AdamW [24], a variant of Adam with decoupled weight decay regularization, using a learning rate of 0.0001 and a weight decay of 0.05. The mean squared error (MSE) loss measures the discrepancy between the predicted and ground-truth images. The network is trained for 3500 epochs on the augmented dataset. To demonstrate the effectiveness of the approach itself, we do not use techniques such as model transfer, EMA, pretraining, or other tricks.

DPM Sampling Setting. DPM-Solver++ comes in two versions: DPM-Solver++(2S), a second-order single-step solver, and DPM-Solver++(2M), a second-order multistep solver. The latter addresses the instability of high-order solvers by reducing the effective step size. In this study, we use the 2M version as the sampling schedule, implemented with diffusers [32]. The initial and final values of β for inference are 0.0001 and 0.02, respectively, where $\beta_t$ is the variance of the Gaussian noise added at step $t$ of the forward process (Eq. (1)). A linear β schedule is used, increasing β from the initial to the final value over the diffusion time steps. The second-order solver uses the midpoint method, which approximates the solution of the differential equation using the midpoint of each interval. The number of inference steps is 20.

Classification Models' Training Setting. We train the classification models for 500 epochs with a batch size of 30, using the AdamW optimizer with the same hyperparameters as the DPM training optimizer. We use the cross-entropy loss with label smoothing of 0.1 to prevent overfitting and improve generalization, and keep the other hyperparameters at their timm defaults. The hyperparameter k of top-k selection is set to 1 in the following experiments unless stated otherwise. The experiments are performed on a computer with a 13th Gen Intel® Core™ i9-13900KF CPU and a GeForce RTX 4090 GPU.
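The sampling setting above (DPM-Solver++(2M), linear β from 0.0001 to 0.02, 20 steps) can be illustrated with a pure-Python sketch of the 2M update from the DPM-Solver++ paper: $\tilde{x}_{t_i} = \frac{\sigma_{t_i}}{\sigma_{t_{i-1}}}\tilde{x}_{t_{i-1}} - \alpha_{t_i}(e^{-h_i}-1)D_i$, with $D_i = (1+\frac{1}{2r_i})x_\theta(t_{i-1}) - \frac{1}{2r_i}x_\theta(t_{i-2})$ and $r_i = h_{i-1}/h_i$. Assumptions of ours: 1000 training steps, evenly spaced inference timesteps, a variance-preserving parameterization ($\alpha = \sqrt{\bar{\alpha}}$, $\sigma = \sqrt{1-\bar{\alpha}}$), and a hypothetical `data_pred` model. This is not the diffusers implementation; in diffusers the described configuration probably corresponds to `DPMSolverMultistepScheduler` with `beta_start=0.0001`, `beta_end=0.02`, `beta_schedule="linear"`, `algorithm_type="dpmsolver++"`, `solver_order=2` and `solver_type="midpoint"`, followed by `set_timesteps(20)`, but we have not seen the authors' code.

```python
import math


def linear_alpha_bars(beta_start=1e-4, beta_end=0.02, num_train_steps=1000):
    """alpha_bar_n for a linear beta schedule (num_train_steps = 1000 is an assumption)."""
    alpha_bars, prod = [], 1.0
    for n in range(num_train_steps):
        prod *= 1.0 - (beta_start + (beta_end - beta_start) * n / (num_train_steps - 1))
        alpha_bars.append(prod)
    return alpha_bars


def _coeffs(alpha_bar):
    """alpha, sigma and lambda = log(alpha / sigma) for a VP diffusion."""
    alpha, sigma = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return alpha, sigma, math.log(alpha / sigma)


def dpm_solver_pp_2m(x_T, data_pred, alpha_bars, num_steps=20):
    """DPM-Solver++(2M) sampling on a list of values.

    data_pred(x, n) is a hypothetical model returning the predicted clean sample x_0 at
    training step n. The first step has no history and uses the first-order update."""
    T = len(alpha_bars)
    steps = [round((T - 1) * (1 - i / num_steps)) for i in range(num_steps + 1)]  # T-1 ... 0
    x, prev_x0, prev_h = list(x_T), None, None
    for s_n, t_n in zip(steps[:-1], steps[1:]):
        _, sig_s, lam_s = _coeffs(alpha_bars[s_n])
        a_t, sig_t, lam_t = _coeffs(alpha_bars[t_n])
        h = lam_t - lam_s
        x0 = data_pred(x, s_n)
        if prev_x0 is None:
            d = x0
        else:
            r = prev_h / h
            d = [(1 + 0.5 / r) * m0 - (0.5 / r) * m1 for m0, m1 in zip(x0, prev_x0)]
        # x_t = (sigma_t / sigma_s) x_s - alpha_t (e^{-h} - 1) D
        x = [sig_t / sig_s * xi - a_t * math.expm1(-h) * di for xi, di in zip(x, d)]
        prev_x0, prev_h = x0, h
    return x


ab = linear_alpha_bars()
x_clean, noise = [0.5, -1.0], [0.3, 0.7]
a_T, s_T = math.sqrt(ab[-1]), math.sqrt(1 - ab[-1])
x_T = [a_T * a + s_T * e for a, e in zip(x_clean, noise)]
print(dpm_solver_pp_2m(x_T, lambda x, n: x_clean, ab))  # oracle denoiser
```

With an oracle denoiser that always returns the true $x_0$, each step maps $\alpha_s x_0 + \sigma_s\epsilon$ exactly to $\alpha_t x_0 + \sigma_t\epsilon$, so the sampler ends at the forward-process sample of the last timestep; a real network replaces the oracle.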
Discriminator Training Setting. This paper uses the Xception model from timm [23], known for its high accuracy in image recognition, as the backbone. The model is configured with 3 input channels, a dropout rate of 0, and average pooling for global pooling. The training hyperparameters are the same as for the classification models.
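Label smoothing with ε = 0.1, used for both the classifiers and the discriminator, replaces the one-hot target with $1-\varepsilon$ on the true class plus $\varepsilon/K$ spread over all K classes. A pure-Python sketch of that loss for one example; it follows the common formulation (timm's `LabelSmoothingCrossEntropy` is, to our understanding, written the same way), but it is not taken from the authors' code.

```python
import math


def label_smoothing_ce(logits, target, smoothing=0.1):
    """Cross-entropy against the target (1 - smoothing) * one_hot(target) + smoothing / K."""
    m = max(logits)  # subtract the max for a numerically stable log-softmax
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    log_probs = [z - log_z for z in logits]
    nll = -log_probs[target]                     # usual cross-entropy term
    uniform = -sum(log_probs) / len(logits)      # cross-entropy against the uniform distribution
    return (1.0 - smoothing) * nll + smoothing * uniform


print(label_smoothing_ce([2.0, 0.5, -1.0], target=0))
```

Because the smoothed target never puts all its mass on one class, the loss keeps penalizing extremely confident logits, which is the intended regularizing effect.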


3.3 Experiments Results

Results for Different SOTA Models. Figure 2 shows random examples of spectrograms generated by the DPM. The DPM produces spectrograms whose structures closely resemble those of the real samples. The test performance and generalization of the DL models are evaluated with 10-fold cross-validation, and the final accuracy is the mean of the 10 test results.

Fig. 2. Real (right panel) and generated (left panel) audio samples from intelligent augmentation.

Table 2. Evaluation of DL models on ESC trained with and without the samples generated by the proposed data augmentation. The best values are in bold.

| Models | Parameters | Real + Traditional Top-1 Accuracy (%) | Real + Synthetic Top-1 Accuracy (%) |
|---|---|---|---|
| ResNet-50 | 23528522 | 73.8 | **80.1** |
| Xception | 20827442 | 72.5 | **80.2** |
| ConViT-tiny | 5494098 | 65.1 | **68.6** |
| MobileViTv2-50 | 1116163 | 59.2 | **62.0** |
| MobileViTv2-150 | 9833443 | 68.3 | **74.6** |
| ConvNext-tiny | 27827818 | 60.7 | **65.9** |
| DeiT III | 85722634 | 67.4 | **72.2** |

Table 2 presents the test performance of the seven DL models on UrbanSound8K, comparing training on the real dataset with synthetic augmentation against training on the real dataset with traditional data augmentation. Adding synthetic images to the training data improves the accuracy of all seven models, by 2.8 to 7.7 percentage points. Training also becomes more stable (Fig. 3), with higher accuracy and lower losses. For instance, with synthetic augmented images, ResNet-50 and Xception achieve classification accuracies of 80.1% and 80.2%, respectively, improvements of 6.3 and 7.7 percentage points over the baselines without synthetic augmentation.
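The per-model gains implied by Table 2 can be recomputed directly; a small check script with the table values hard-coded:

```python
# Top-1 accuracies (%) from Table 2: (real + traditional, real + synthetic)
TABLE2 = {
    "ResNet-50": (73.8, 80.1),
    "Xception": (72.5, 80.2),
    "ConViT-tiny": (65.1, 68.6),
    "MobileViTv2-50": (59.2, 62.0),
    "MobileViTv2-150": (68.3, 74.6),
    "ConvNext-tiny": (60.7, 65.9),
    "DeiT III": (67.4, 72.2),
}

# Gain in percentage points, rounded to the table's precision
gains = {name: round(syn - trad, 1) for name, (trad, syn) in TABLE2.items()}
for name, g in sorted(gains.items(), key=lambda kv: -kv[1]):
    print(f"{name:16s} +{g:.1f} pp")
```

Reporting absolute percentage-point differences avoids ambiguity with relative improvements, which would be larger for the weaker baselines.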

Fig. 3. Training loss curves of different DL models on the UrbanSound8K dataset with and without synthetic data augmentation.

Table 3. Effect of k on the accuracy of DL models trained on the synthetic + real dataset. When k ≥ 5, the generated data samples are unaffected by the selection in this training setup.

| Method | k = 1 | k = 2 | k = 3 | k = 4 | k ≥ 5 |
|---|---|---|---|---|---|
| ResNet-50 | 80.1% | 80.2% | 77.3% | 76.7% | 77.4% |
| Xception | 80.2% | 78.7% | 74.8% | 74.5% | 74.0% |
| ConViT-tiny | 68.6% | 68.7% | 64.5% | 64.5% | 64.6% |
| MobileViTv2-50 | 62.0% | 62.8% | 63.0% | 65.5% | 69.0% |
| MobileViTv2-150 | 74.6% | 73.5% | 73.5% | 76.4% | 72.6% |
| ConvNext-tiny | 65.9% | 62.3% | 62.0% | 61.0% | 62.8% |
| DeiT III | 72.2% | 72.8% | 71.3% | 69.2% | 67.8% |

Influence of Hyperparameter k for Top-k Selection. To assess how the hyperparameter k affects the outcome of synthetic data generation, we run experiments with values of k from 1 to 10 and compare the results in Table 3. The top-k selection strategy improves the performance of most DL models, presumably because it reduces noise and ambiguity in the augmented data. For six of the seven models, accuracy with k = 1 is higher than with k ≥ 5; the exception is MobileViTv2-50, which reaches its best accuracy (69.0%) at k ≥ 5.
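The comparison between the strictest filter (k = 1) and k ≥ 5 in Table 3 can be checked the same way, with the table values hard-coded:

```python
# Accuracy (%) from Table 3 at k = 1 and at k >= 5
TABLE3_K1_VS_K5 = {
    "ResNet-50": (80.1, 77.4),
    "Xception": (80.2, 74.0),
    "ConViT-tiny": (68.6, 64.6),
    "MobileViTv2-50": (62.0, 69.0),
    "MobileViTv2-150": (74.6, 72.6),
    "ConvNext-tiny": (65.9, 62.8),
    "DeiT III": (72.2, 67.8),
}

improved = [m for m, (k1, k5) in TABLE3_K1_VS_K5.items() if k1 > k5]
print(f"{len(improved)} of {len(TABLE3_K1_VS_K5)} models are more accurate at k = 1 than at k >= 5")
print("exceptions:", [m for m in TABLE3_K1_VS_K5 if m not in improved])
```

Comparing against k ≥ 5 isolates the effect of filtering, since at that setting the selection no longer changes the generated data.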


4 Conclusion

This paper presents a novel application of diffusion models to generating high-quality synthetic images from sound recordings. It is the first study to explore diffusion models for data augmentation in ESC, and it introduces a new selection method based on top-k confidence scores that filters out low-quality synthetic images and retains informative ones for each sound class. Extensive experiments on the UrbanSound8K dataset validate the effectiveness of diffusion models in generating realistic and diverse synthetic images. These images significantly improve classification accuracy, reduce losses, and, through the top-k selection method, enhance data balance among classes.

References
1. Ho, J., et al.: Denoising diffusion probabilistic models. Adv. Neural. Inf. Process. Syst. 33, 6840–6851 (2020)
2. Salamon, J., Bello, J.P.: Deep convolutional neural networks and data augmentation for environmental sound classification. IEEE Sig. Process. Lett. 24, 279–283 (2016)
3. Gong, Y., et al.: AST: Audio Spectrogram Transformer. ArXiv abs/2104.01778 (2021)
4. Bahmei, B., et al.: CNN-RNN and data augmentation using deep convolutional generative adversarial network for environmental sound classification. IEEE Sign. Process. Lett. 29, 682–686 (2022)
5. Hershey, S., et al.: CNN architectures for large-scale audio classification. In: IEEE International Conference on Acoustics, Speech and Signal Processing, pp. 131–135 (2016)
6. Zhu, X., et al.: Emotion classification with data augmentation using generative adversarial networks. In: Pacific-Asia Conference on Knowledge Discovery and Data Mining, pp. 349–360 (2018)
7. Arjovsky, M., et al.: Wasserstein GAN. ArXiv abs/1701.07875 (2017)
8. Zhao, H., et al.: Bias and generalization in deep generative models: an empirical study. Neural Inf. Process. Syst. 13 (2018)
9. Ho, J., et al.: Denoising Diffusion Probabilistic Models. ArXiv abs/2006.11239 (2020)
10. Dhariwal, P., Nichol, A.: Diffusion Models Beat GANs on Image Synthesis. ArXiv abs/2105.05233 (2021)
11. Müller-Franzes, G., et al.: Diffusion Probabilistic Models beat GANs on Medical Images. ArXiv abs/2212.07501 (2022)
12. Mazé, F., Ahmed, F.: Diffusion Models Beat GANs on Topology Optimization (2022)
13. Song, J., et al.: Denoising Diffusion Implicit Models. ArXiv abs/2010.02502 (2020)
14. Lu, C., et al.: DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models. ArXiv abs/2211.01095 (2022)
15. Cordts, M., et al.: The cityscapes dataset for semantic urban scene understanding. In: 2016 IEEE Conference on Computer Vision and Pattern Recognition, pp. 3213–3223 (2016)
16. Sohl-Dickstein, J., et al.: Deep Unsupervised Learning using Nonequilibrium Thermodynamics. ArXiv abs/1503.03585 (2015)
17. Saharia, C., et al.: Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding. ArXiv abs/2205.11487 (2022)
18. Font, F., et al.: Freesound technical demo. In: Proceedings of the 21st ACM International Conference on Multimedia, pp. 411–412 (2013)
19. He, K., et al.: Deep residual learning for image recognition. In: 2016 IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2015)


20. lucidrains: denoising-diffusion-pytorch (2023). https://github.com/lucidrains/denoising-diffusion-pytorch
21. Ronneberger, O., et al.: U-Net: Convolutional Networks for Biomedical Image Segmentation. ArXiv abs/1505.04597 (2015)
22. Iwana, B.K., Uchida, S.: An empirical survey of data augmentation for time series classification with neural networks. Plos One 16 (2020)
23. Wightman, R.: PyTorch Image Models (2019). https://github.com/rwightman/pytorch-image-models
24. Loshchilov, I., Hutter, F.: Decoupled weight decay regularization. In: International Conference on Learning Representations (2017)
25. Kingma, D.P., Welling, M.: Auto-Encoding Variational Bayes. ArXiv abs/1312.6114 (2013). Accessed 22 March 2023
26. Ho, J.: Classifier-Free Diffusion Guidance. ArXiv abs/2207.12598 (2022)
27. Chollet, F.: Xception: deep learning with depthwise separable convolutions. In: 2017 IEEE Conference on Computer Vision and Pattern Recognition, pp. 1251–1258 (2016)
28. d'Ascoli, S., et al.: ConViT: improving vision transformers with soft convolutional inductive biases. J. Statist. Mech. Theory Experiment 2022 (2021)
29. Mehta, S., Rastegari, M.: MobileViT: Light-weight, General purpose, and Mobile-friendly Vision Transformer. ArXiv abs/2110.02178 (2021)
30. Liu, Z., et al.: A ConvNet for the 2020s. In: 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11966–11976 (2022)
31. Touvron, H., et al.: DeiT III: Revenge of the ViT. ArXiv abs/2204.07118 (2022)
32. von Platen, P., et al.: Diffusers: State-of-the-art diffusion models (2022). https://github.com/huggingface/diffusers
33. Chen, Y., et al.: Effective audio classification network based on paired inverse pyramid structure and dense MLP Block. ArXiv abs/2211.02940 (2022)

Improved DetNet Algorithm Based on GRU for Massive MIMO Systems

Hanqing Ding, Bingwei Li(B), and Jin Xu

Zhengzhou University of Light Industry, Zhengzhou 450000, China

Abstract. Massive Multiple-Input Multiple-Output (MIMO) technology is widely used to achieve high system capacity and spectral efficiency in wireless communication systems. With the development of artificial intelligence in recent years, the application of deep learning algorithms to massive MIMO signal detection has attracted great attention. DetNet is a deep detection network for massive MIMO systems. By introducing the reset gate and update gate mechanism of the gated recurrent unit (GRU), this paper proposes an improved massive MIMO detection algorithm, GRU-DetNet. A hybrid neural network model with a parallel structure (Hybrid-DetNet) is further proposed. Simulation results show that the proposed GRU-DetNet and Hybrid-DetNet achieve performance gains of about 0.5 dB and 1 dB, respectively, over the existing DetNet scheme, while the running time is reduced by about 50% under the same conditions. In addition, the proposed method generalizes well, since a single trained network can handle various modulation modes.

Keywords: Massive MIMO detection · GRU-DetNet · Hybrid-DetNet · BER performance

1 Introduction

Massive multiple-input multiple-output (MIMO) can provide more flexible spatial multiplexing and higher diversity gain, thus achieving higher system capacity and spectral efficiency [1]. However, as the number of antennas and the order of the modulation constellation increase, traditional MIMO detection algorithms struggle to balance bit error rate (BER) performance against computational complexity [2]. In recent years, with the development of artificial intelligence, the application of deep learning algorithms in communications has attracted much attention [3, 4]. In [5], a partial learning-based detection algorithm was proposed for massive MIMO systems. By optimizing trainable parameters, a model-driven massive MIMO detection algorithm was proposed in [6]. To further improve the model-driven detection algorithm, Peiyan Ao et al. incorporated a semi-supervised learning approach [7]. Meanwhile, to improve the conventional MIMO detection algorithm, Tan et al. considered a deep neural network-assisted message-passing detector, DNN-MPD [8], and Zhou et al. proposed an RCNet model for symbol detection in MIMO-OFDM with online learning [9]. Many researchers have considered adding elements such as modulation identification [10], channel estimation [11], channel feedback [12], and spatial reuse [13] into the model. A comparative study in [14] showed that hybrid neural networks with parallel structures achieve better detection performance than single networks. In [15], DetNet, a MIMO detector that unfolds the iterations of the projected gradient descent algorithm into a deep network, was proposed. Although DetNet achieves good performance in massive MIMO detection, there is still room for improvement in both performance and complexity.

The contributions of this paper are as follows. By improving the DetNet model, the GRU-DetNet and Hybrid-DetNet models are proposed for massive MIMO detection. A comparative study under different channel models and modulation constellations shows that the proposed models achieve better detection performance with lower complexity. In addition, the parallel-structured Hybrid-DetNet model significantly reduces the training and testing time.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 296–307, 2023. https://doi.org/10.1007/978-981-99-4742-3_24

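2 System Model

2.1 Real-valued Models for MIMO Systems

The complex-valued baseband model of the MIMO system can be expressed as

$$\mathbf{y} = \mathbf{H}\mathbf{x} + \mathbf{n} \tag{1}$$

where $\mathbf{y} \in \mathbb{C}^{N \times 1}$ is the received signal vector, $\mathbf{x} \in S^{K \times 1}$ is the transmitted signal vector, $\mathbf{H} \in \mathbb{C}^{N \times K}$ is the channel coefficient matrix, and $\mathbf{n} \in \mathbb{C}^{N \times 1}$ is a white noise vector in complex baseband form. To simplify the detection process of a large-scale MIMO system, the complex model is converted into the equivalent real-valued model

$$\bar{\mathbf{y}} = \bar{\mathbf{H}}\bar{\mathbf{x}} + \bar{\mathbf{n}}, \quad \bar{\mathbf{y}} = \begin{bmatrix}\Re(\mathbf{y})\\ \Im(\mathbf{y})\end{bmatrix}, \quad \bar{\mathbf{n}} = \begin{bmatrix}\Re(\mathbf{n})\\ \Im(\mathbf{n})\end{bmatrix}, \quad \bar{\mathbf{x}} = \begin{bmatrix}\Re(\mathbf{x})\\ \Im(\mathbf{x})\end{bmatrix}, \quad \bar{\mathbf{H}} = \begin{bmatrix}\Re(\mathbf{H}) & -\Im(\mathbf{H})\\ \Im(\mathbf{H}) & \Re(\mathbf{H})\end{bmatrix} \tag{2}$$

where $\bar{\mathbf{y}} \in \mathbb{R}^{2N}$ is the received vector, $\bar{\mathbf{H}} \in \mathbb{R}^{2N \times 2K}$ is the channel matrix, and $\bar{\mathbf{x}} \in S^{2K}$ and $\bar{\mathbf{n}} \in \mathbb{R}^{2N}$ are the equivalent real-valued signal and noise vectors. In addition, each modulation scheme re-parameterizes the constellation points with a one-hot code. One-hot encoding of QPSK:

$$s_1 = -1 \leftrightarrow \mathbf{u}_1 = [1, 0], \qquad s_2 = 1 \leftrightarrow \mathbf{u}_2 = [0, 1] \tag{3}$$

where $s_1$ and $s_2$ are the real-valued elements of the modulation constellation, and $\mathbf{u}_i$ is the one-hot vector corresponding to the symbol $s_i$, $1 \le i \le 2$.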
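The complex-to-real conversion of Eq. (2) can be checked numerically. A minimal sketch using Python's built-in complex numbers and lists of rows for matrices; the helper names and the example channel are ours:

```python
def to_real_model(H, y):
    """Map a complex channel H (list of N rows of K entries) and received vector y (length N)
    to the real-valued model of Eq. (2)."""
    y_real = [v.real for v in y] + [v.imag for v in y]
    top = [[h.real for h in row] + [-h.imag for h in row] for row in H]     # [Re(H), -Im(H)]
    bottom = [[h.imag for h in row] + [h.real for h in row] for row in H]   # [Im(H),  Re(H)]
    return top + bottom, y_real


def mat_vec(A, x):
    """A x for A given as a list of rows (works for real or complex entries)."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]


H = [[1 + 2j, -1j], [0.5 + 0j, 2 - 1j], [1j, 1 + 0j]]  # N = 3 receive, K = 2 transmit antennas
x = [1 + 1j, -1 + 1j]                                  # QPSK symbols
y = mat_vec(H, x)                                      # noise-free complex received vector
H_r, y_r = to_real_model(H, y)
x_r = [v.real for v in x] + [v.imag for v in x]
print(mat_vec(H_r, x_r), y_r)
```

The two printed vectors agree because $\Re(\mathbf{H}\mathbf{x}) = \Re(\mathbf{H})\Re(\mathbf{x}) - \Im(\mathbf{H})\Im(\mathbf{x})$ and $\Im(\mathbf{H}\mathbf{x}) = \Im(\mathbf{H})\Re(\mathbf{x}) + \Re(\mathbf{H})\Im(\mathbf{x})$; for QPSK, each real entry of $\bar{\mathbf{x}}$ is then ±1.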


H. Ding et al.

One-hot encoding of 16-QAM:

$$s_1 = -3 \leftrightarrow \mathbf{u}_1 = [1, 0, 0, 0], \quad s_2 = -1 \leftrightarrow \mathbf{u}_2 = [0, 1, 0, 0], \quad s_3 = 1 \leftrightarrow \mathbf{u}_3 = [0, 0, 1, 0], \quad s_4 = 3 \leftrightarrow \mathbf{u}_4 = [0, 0, 0, 1] \tag{4}$$

The mapping between the one-hot codes and the constellation points $s_i$, $i = 1, \dots, |S|$, is given by the function $f_{oh}$, defined as

$$x = f_{oh}(\mathbf{x}_{oh}) = \sum_{i=1}^{|S|} s_i\,[\mathbf{x}_{oh}]_i \tag{5}$$
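Eqs. (4) and (5) amount to a lookup table and a weighted sum. A sketch for the real-valued 16-QAM levels; because $f_{oh}$ is linear, it also maps a soft, probability-like network output to the corresponding expected symbol:

```python
QAM16_LEVELS = [-3, -1, 1, 3]  # real-valued 16-QAM constellation points s_1..s_4, Eq. (4)


def one_hot(symbol, levels=QAM16_LEVELS):
    """Encode a constellation point s_i as its one-hot vector u_i."""
    return [1 if s == symbol else 0 for s in levels]


def f_oh(x_oh, levels=QAM16_LEVELS):
    """Eq. (5): x = sum_i s_i [x_oh]_i (works for hard or soft one-hot vectors)."""
    return sum(s * w for s, w in zip(levels, x_oh))


print(one_hot(1), f_oh([0, 0, 1, 0]), f_oh([0.5, 0.5, 0.0, 0.0]))
```

Passing `levels=[-1, 1]` gives the QPSK case of Eq. (3).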

In this paper, the output of the reset gate at the current time step is multiplied element-wise by the hidden state of the previous time step. If an element of the reset gate is close to 0, the corresponding part of the previous hidden state is discarded; if it is close to 1, it is retained. Taking 16-QAM as an example, if the target symbol is [1, 0, 0, 0] and the previous hidden state corresponds to [0, 1, 0, 0] or [0, 0, 0, 1], the reset-gate elements should be close to 0, i.e., the previous hidden state is discarded; otherwise, it is retained.

2.2 DetNet

DetNet is a deep neural network designed for massive MIMO detection. It is obtained by unfolding the iterations of the projected gradient descent algorithm, and its inputs are formed by multiplying the received signal vector and the current estimate by the transposed channel matrix and the channel matrix ($\mathbf{H}^T\mathbf{y}$ and $\mathbf{H}^T\mathbf{H}\hat{\mathbf{x}}_k$), as shown in Fig. 1.

Fig. 1. DetNet single-layer flow chart (legend: × multiplication, + addition, f_oh one-hot function, Con concatenation, ρ activation function)


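From the process shown in Fig. 1, the following relationships are obtained:

$$\hat{\mathbf{x}}_{k+1} = \Pi\left[\hat{\mathbf{x}}_k - \delta_k \frac{\partial \|\mathbf{y} - \mathbf{H}\mathbf{x}\|^2}{\partial \mathbf{x}}\bigg|_{\mathbf{x} = \hat{\mathbf{x}}_k}\right] = \Pi\left[\hat{\mathbf{x}}_k - \delta_k \mathbf{H}^T\mathbf{y} + \delta_k \mathbf{H}^T\mathbf{H}\hat{\mathbf{x}}_k\right] \tag{6}$$

$$\mathbf{z}_k = \rho\left(\mathbf{W}_{1k}\begin{bmatrix}\mathbf{H}^T\mathbf{y}\\ \hat{\mathbf{x}}_k\\ \mathbf{H}^T\mathbf{H}\hat{\mathbf{x}}_k\\ \hat{\mathbf{v}}_k\end{bmatrix} + \mathbf{b}_{1k}\right) \tag{7}$$

$$\hat{\mathbf{x}}_{oh,k} = \mathbf{W}_{2k}\mathbf{z}_k + \mathbf{b}_{2k} \tag{8}$$

$$\hat{\mathbf{x}}_{k+1} = f_{oh}(\hat{\mathbf{x}}_{oh,k}) \tag{9}$$

$$\hat{\mathbf{v}}_{k+1} = \mathbf{W}_{3k}\mathbf{z}_k + \mathbf{b}_{3k} \tag{10}$$

$$\hat{\mathbf{x}}_0 = \mathbf{0}, \quad \hat{\mathbf{v}}_0 = \mathbf{0} \tag{11}$$

where $\hat{\mathbf{x}}_k$ is the estimate at the $k$-th iteration, $\Pi[\cdot]$ is the nonlinear projection operator, $\rho$ is the activation function, $\mathbf{z}_k$ is the output of the hidden layer, $\hat{\mathbf{x}}_{oh,k}$ is the output of the output layer, $\hat{\mathbf{x}}_{k+1}$ is the input for the next iteration, $\hat{\mathbf{v}}_k$ is a nonlinear factor passed between layers, and $\delta_k$ is the gradient step size, which is learned during training.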
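DetNet unfolds the projected gradient iteration of Eq. (6) and lets training learn the step sizes and the projection. The sketch runs the plain, untrained iteration for BPSK under assumptions of ours: a fixed step δ = 0.5, the projection Π taken as clipping to [−1, 1], and the descent direction written as $\hat{\mathbf{x}} + \delta\mathbf{H}^T(\mathbf{y} - \mathbf{H}\hat{\mathbf{x}})$ (with learned step sizes, the factor 2 and the sign in Eq. (6) are absorbed into $\delta_k$).

```python
def mat_vec(A, x):
    """A x for A given as a list of rows."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]


def mat_t_vec(A, v):
    """A^T v for A given as a list of rows."""
    return [sum(A[i][j] * v[i] for i in range(len(A))) for j in range(len(A[0]))]


def projected_gradient_detect(H, y, step=0.5, iters=50):
    """Untrained projected-gradient iteration for BPSK, the scheme DetNet unfolds into layers."""
    x = [0.0] * len(H[0])  # start from x_0 = 0, as in Eq. (11)
    for _ in range(iters):
        residual = [yi - hx for yi, hx in zip(y, mat_vec(H, x))]
        direction = mat_t_vec(H, residual)  # H^T (y - H x), i.e. minus half the gradient
        # gradient step, then the projection Pi: clip every entry to [-1, 1]
        x = [max(-1.0, min(1.0, xi + step * d)) for xi, d in zip(x, direction)]
    return x


H = [[1.0, 0.2], [0.1, 1.0]]
x_true = [1.0, -1.0]
y = mat_vec(H, x_true)  # noise-free received signal
print(projected_gradient_detect(H, y))
```

Each DetNet layer corresponds to one pass of this loop, but with its own learned step sizes and with the fixed clipping replaced by the learned nonlinearity of Eqs. (7)-(10).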

3 Improved DetNet Model

3.1 GRU-DetNet

In this section, an improved neural network model for DetNet is proposed. As shown in Fig. 2, a gated recurrent unit (GRU) is introduced to control the information flow [16]. The reset gate discards historical information that is irrelevant to the prediction, and the update gate lets the hidden state be updated by the candidate hidden state computed from the current time step. The structure of the hidden layer is modified accordingly, yielding a new neural network model, GRU-DetNet. The resulting model has a memory function and passes the necessary information on to the next prediction. The gating is intended to mitigate vanishing or exploding gradients and to reduce the risk of overfitting. The output of the input layer of the GRU-DetNet model is calculated as

$$\mathbf{q}_t = \hat{\mathbf{x}}_{t-1} - \delta_{1t}\mathbf{H}^T\mathbf{y} + \delta_{2t}\mathbf{H}^T\mathbf{H}\hat{\mathbf{x}}_{t-1} \tag{12}$$

where $\mathbf{q}_t$ is the output of the input layer, $\hat{\mathbf{x}}_t$ is the estimate of $\mathbf{x}_t$, and $\mathbf{x}_{t-1}$ and $\mathbf{x}_t$ are the inputs at the previous and current time steps, respectively.


Fig. 2. GRU-DetNet single-layer flow chart (reset gate R_t, update gate Z_t, hidden states h_{t−1} and h_t; inputs H^T y and H^T H x_t weighted by δ_{1,t} and δ_{2,t})

The status updates of the reset and update gates of GRU-DetNet neural network and the output of the hidden layer are calculated as follows.    qt Wxr + ht−1 Whr + br , (13) Rt = ρ1 vt    qt Zt = ρ1 Wxz + ht−1 Whz + bz , (14) vt    qt h˜ t = ρ2 Wxh + Rt ht−1 Whh + bn , (15) vt ht = Zt ht−1 + (1 − Zt )h˜ t ,

(16)

where $R_t$ is the output of the reset gate, $Z_t$ is the output of the update gate, $\tilde{h}_t$ is the candidate hidden state, and $h_t$ and $h_{t-1}$ are the outputs of the hidden layer at the current and previous time steps. $W_{xr}$, $W_{xz}$ and $W_{xh}$ are the weights from the input layer to the reset gate, the update gate and the candidate hidden state, respectively; $W_{hr}$, $W_{hz}$ and $W_{hh}$ are the corresponding weights from the hidden state; and $b_r$, $b_z$ and $b_n$ are the reset-gate, update-gate and candidate-state biases, respectively. $\rho_1$ and $\rho_2$ are the sigmoid and tanh functions, respectively:

$$\rho_1(x) = f(x) = \frac{1}{1 + e^{-x}}, \qquad (17)$$

$$\rho_2(x) = \tanh x = \frac{\sinh x}{\cosh x} = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}. \qquad (18)$$
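Equations (13)–(16) describe one GRU-DetNet recurrence step. The sketch below implements them in plain Python under stated assumptions: row vectors multiply the weight matrices from the left, as written in the equations; $[q_t, v_t]$ is formed by concatenation; and the parameter-dictionary keys (`W_xr`, `b_r`, and so on) and helper names are hypothetical, chosen to mirror the symbols. It illustrates the gating arithmetic only and is not the authors' TensorFlow implementation.

```python
import math
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]


def vecmat(x: Vector, w: Matrix) -> Vector:
    """Row vector x times matrix w (len(x) rows, one column per hidden unit)."""
    return [sum(x_i * w[i][j] for i, x_i in enumerate(x)) for j in range(len(w[0]))]


def sigmoid(s: float) -> float:
    """rho_1 in Eq. (17)."""
    return 1.0 / (1.0 + math.exp(-s))


def gru_detnet_step(q: Vector, v: Vector, h_prev: Vector,
                    p: Dict[str, list]) -> Vector:
    """One hidden-state update following Eqs. (13)-(16)."""
    z = q + v  # concatenated input [q_t, v_t]

    def pre_activation(wx: str, wh: str, b: str, h: Vector) -> Vector:
        # [q_t, v_t] W_x* + h W_h* + b_*
        return [a + c + d for a, c, d in zip(vecmat(z, p[wx]), vecmat(h, p[wh]), p[b])]

    r = [sigmoid(s) for s in pre_activation("W_xr", "W_hr", "b_r", h_prev)]  # reset gate, Eq. (13)
    u = [sigmoid(s) for s in pre_activation("W_xz", "W_hz", "b_z", h_prev)]  # update gate, Eq. (14)
    r_h = [ri * hi for ri, hi in zip(r, h_prev)]                              # R_t (elementwise) h_{t-1}
    h_cand = [math.tanh(s) for s in pre_activation("W_xh", "W_hh", "b_n", r_h)]  # Eq. (15)
    return [ui * hp + (1.0 - ui) * hc for ui, hp, hc in zip(u, h_prev, h_cand)]  # Eq. (16)


def zero_params(in_dim: int, hidden: int) -> Dict[str, list]:
    """All-zero parameters, handy for checking the gate arithmetic."""
    p: Dict[str, list] = {}
    for gate, bias in (("r", "b_r"), ("z", "b_z"), ("h", "b_n")):
        p["W_x" + gate] = [[0.0] * hidden for _ in range(in_dim)]
        p["W_h" + gate] = [[0.0] * hidden for _ in range(hidden)]
        p[bias] = [0.0] * hidden
    return p


print(gru_detnet_step([0.3, -0.7], [0.1], [1.0, -2.0], zero_params(3, 2)))  # [0.5, -1.0]
```

With all parameters at zero, both gates equal 0.5 and the candidate state is tanh(0) = 0, so the step halves the previous hidden state; pushing the update-gate bias $b_z$ to a large positive value instead makes $Z_t \approx 1$ and carries $h_{t-1}$ through unchanged.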

The output layer of the improved GRU-DetNet is calculated as

$$\hat{x}_{oh,t} = W_{hy} h_t + b_{hy}, \qquad (19)$$

Improved DetNet Algorithm Based on GRU


$$\hat{x}_{t+1} = f_{oh}\left(\hat{x}_{oh,t}\right), \qquad (20)$$

$$\hat{v}_{t+1} = W_{hv} h_t + b_{hv}, \qquad (21)$$

where $\hat{x}_{oh,t}$ is the output of the output layer and $\hat{x}_{t+1}$ is the input data for the next iteration. $W_{hy}$ and $W_{hv}$ are the weights, and $b_{hy}$ and $b_{hv}$ the biases, that map the hidden layer to the quantities passed to the next iteration. Following [17], the loss function $l$ takes the outputs of all $L$ layers into account:

$$l\left(x_{oh}; \hat{x}_{oh}(H, y; \theta)\right) = \sum_{k=1}^{L} \log(k)\, \left\lVert x_{oh} - \hat{x}_{oh,k} \right\rVert^{2}. \qquad (22)$$
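Equation (22) weights the error of the $k$-th layer by $\log(k)$, so deeper layers count more and the first layer ($\log 1 = 0$) does not contribute at all. A minimal sketch, assuming the natural logarithm and the squared Euclidean error (the function name `detnet_loss` is hypothetical):

```python
import math
from typing import List, Sequence


def detnet_loss(x_true: Sequence[float], layer_outputs: List[Sequence[float]]) -> float:
    """Sum over layers k = 1..L of log(k) * ||x_oh - x_hat_{oh,k}||^2, as in Eq. (22)."""
    total = 0.0
    for k, x_hat in enumerate(layer_outputs, start=1):
        squared_error = sum((a - b) ** 2 for a, b in zip(x_true, x_hat))
        total += math.log(k) * squared_error
    return total


# Layer 1 is wrong but weighted by log(1) = 0, layer 3 is exact,
# so only layer 2 contributes: log(2) * 2.
loss = detnet_loss([1.0, 1.0], [[0.0, 0.0], [0.0, 0.0], [1.0, 1.0]])
print(round(loss, 4))  # 1.3863
```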

3.2 Hybrid-DetNet

It is difficult to reach the global optimum when training a single neural network model. Exploiting the parallelism and robustness of deep neural networks, we also propose Hybrid-DetNet, a hybrid model in which DetNet and the improved GRU-DetNet run in parallel, as shown in Fig. 3.

Fig. 3. Hybrid-DetNet single-layer flow chart (DetNet and GRU-DetNet in parallel, followed by the demodulated output $x = f_{oh}(x_{oh})$)

Inspired by the residual connections of ResNet [18], we take a weighted average of the hidden-layer outputs of the two models in the hybrid network:

$$h'_t = \alpha h^{1}_t + (1 - \alpha) h^{2}_t, \qquad (23)$$

$$\hat{x}'_{oh,t} = W_{hy} h'_t + b_{hy}, \qquad (24)$$

$$\hat{x}'_{t+1} = f_{oh}\left(\hat{x}'_{oh,t}\right), \qquad (25)$$

$$\hat{v}'_{t+1} = W_{hv} h'_t + b_{hv}, \qquad (26)$$


where $h'_t$ is the output of the hidden layer of Hybrid-DetNet, $h^{1}_t$ and $h^{2}_t$ are the outputs of the DetNet and GRU-DetNet hidden layers, respectively, $\hat{x}'_{oh,t}$ is the output of the output layer of Hybrid-DetNet, $\hat{x}'_{t+1}$ is the input data for the next iteration of Hybrid-DetNet, and $0 < \alpha < 1$.
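Equations (23)–(26) fuse the two hidden states before a shared output layer. The sketch below assumes both branches have the same hidden size, writes the output layers as $W h + b$ as in Eqs. (24) and (26), and takes $f_{oh}$ as a caller-supplied function; `clip_unit` is a hypothetical stand-in for $f_{oh}$ used only for the example, and all names are illustrative.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]


def linear(w: Matrix, h: Sequence[float], b: Sequence[float]) -> Vector:
    """Affine map W h + b."""
    return [sum(w_ij * h_j for w_ij, h_j in zip(row, h)) + b_i for row, b_i in zip(w, b)]


def hybrid_step(h_detnet: Sequence[float], h_gru: Sequence[float], alpha: float,
                w_hy: Matrix, b_hy: Sequence[float],
                w_hv: Matrix, b_hv: Sequence[float],
                f_oh: Callable[[float], float]) -> Tuple[Vector, Vector]:
    """Return (x'_{t+1}, v'_{t+1}) following Eqs. (23)-(26)."""
    if not 0.0 < alpha < 1.0:
        raise ValueError("alpha must lie in the open interval (0, 1)")
    h_fused = [alpha * a + (1.0 - alpha) * b for a, b in zip(h_detnet, h_gru)]  # Eq. (23)
    x_oh = linear(w_hy, h_fused, b_hy)                                          # Eq. (24)
    x_next = [f_oh(s) for s in x_oh]                                             # Eq. (25)
    v_next = linear(w_hv, h_fused, b_hv)                                         # Eq. (26)
    return x_next, v_next


def clip_unit(s: float) -> float:
    """Hypothetical stand-in for f_oh: clip to [-1, 1]."""
    return max(-1.0, min(1.0, s))


identity = [[1.0, 0.0], [0.0, 1.0]]
x_next, v_next = hybrid_step([1.0, 0.0], [0.0, 1.0], 0.25,
                             identity, [0.0, 0.0], identity, [0.0, 0.0], clip_unit)
print(x_next, v_next)  # [0.25, 0.75] [0.25, 0.75]
```

Because $\alpha$ is restricted to $(0, 1)$, the fused state is always a convex combination of the two branch states, which is what the range check enforces.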

4 Simulation Results Analysis

4.1 Simulation Settings

In this paper, the parameters that affect system performance are treated as learnable parameters and are optimized by backpropagation while the network model is trained, which effectively improves the detection performance of the system. All experiments are implemented with the TensorFlow 2.0 library in Python 3.8. For BPSK modulation, the numbers of neurons $z_k$ and $v_k$ are 4K and 2K, respectively; for 8PSK modulation they are 12K and 4K. Both of these modulation schemes are trained for 200,000 iterations. For 16QAM modulation, the numbers of neurons $z_k$ and $v_k$ are 8K and 4K, and training is carried out in three separate runs of 200,000, 300,000 and 500,000 iterations. For all three modulation schemes, the training data consist of 3000 samples. Testing uses batches of 1000 samples over 200 iterations.

4.2 Simulation Results

As shown in Fig. 4, the proposed GRU-DetNet and Hybrid-DetNet models achieve better accuracy than DetNet under FC channel conditions, and Hybrid-DetNet keeps improving as the number of iterations increases. Both also reach a much lower BER than the traditional zero-forcing (ZF) and decision feedback (DF) detectors, while the traditional approximate message passing (AMP) detector fails to converge. Figures 5 and 6 show the BER performance of 60 × 30 MIMO with BPSK modulation and 25 × 15 MIMO with 16QAM modulation over VC channels, respectively. For both modulation schemes, the proposed GRU-DetNet and Hybrid-DetNet achieve larger performance gains than the conventional ZF, DF and AMP detectors, and Hybrid-DetNet becomes more effective as the number of iterations increases.
The proposed GRU-DetNet model achieves a performance gain of about 0.5 dB compared with the DetNet model, while the Hybrid-DetNet model achieves a performance gain of about 1 dB under 16QAM modulation, as shown in Fig. 6.
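For reference, the zero-forcing (ZF) baseline in these comparisons can be sketched as a least-squares solve followed by a hard decision. This minimal version assumes a real-valued model $y = Hx + w$ with BPSK symbols in {-1, +1} and at least as many receive as transmit antennas; it solves the normal equations $H^{T} H x = H^{T} y$ by Gaussian elimination. The function names are hypothetical, and this is not the authors' baseline code.

```python
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def solve(a: Matrix, b: Vector) -> Vector:
    """Solve a x = b by Gaussian elimination with partial pivoting."""
    n = len(a)
    m = [list(row) + [bi] for row, bi in zip(a, b)]  # augmented matrix [a | b]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:
            raise ValueError("matrix is singular")
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def zf_detect_bpsk(h: Matrix, y: Vector) -> List[int]:
    """Zero-forcing: x_ls = (H^T H)^{-1} H^T y, followed by a sign decision."""
    n_rx, n_tx = len(h), len(h[0])
    hth = [[sum(h[i][p] * h[i][q] for i in range(n_rx)) for q in range(n_tx)] for p in range(n_tx)]
    hty = [sum(h[i][p] * y[i] for i in range(n_rx)) for p in range(n_tx)]
    return [1 if s >= 0 else -1 for s in solve(hth, hty)]


def bit_error_rate(x_true: List[int], x_hat: List[int]) -> float:
    """Fraction of BPSK symbols (bits) detected incorrectly."""
    return sum(1 for a, b in zip(x_true, x_hat) if a != b) / len(x_true)


# Usage: a noiseless 6 x 3 real channel is inverted exactly, so the BER is 0.
rng = random.Random(0)
H = [[rng.gauss(0.0, 1.0) for _ in range(3)] for _ in range(6)]
x = [1, -1, 1]
y = [sum(h_ij * x_j for h_ij, x_j in zip(row, x)) for row in H]
print(bit_error_rate(x, zf_detect_bpsk(H, y)))  # 0.0
```

ZF inverts the channel without regard to noise, which amplifies noise on ill-conditioned channels; this is the weakness that learned detectors such as DetNet aim to overcome.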


Fig. 4. Comparison of detection performance of BPSK modulation 25 × 15 MIMO in FC channel.

Fig. 5. Comparison of detection performance of BPSK modulation 25 × 15 MIMO in VC channel.

Figures 7 and 8 compare the running times of the algorithms for batch sizes of 1, 10 and 100, for 60 × 30 MIMO with BPSK modulation and 25 × 15 MIMO with 16-QAM modulation over VC channels, respectively. For the same batch size, GRU-DetNet has a running time comparable to AMP and DetNet, while Hybrid-DetNet runs faster than AMP, DetNet and GRU-DetNet. Furthermore, batch processing significantly improves the running-time efficiency of the models.


Fig. 6. Comparison of detection performance of 16-QAM modulation 25 × 15 MIMO in VC channel.

Fig. 7. Comparison of the algorithm running time of different batches of BPSK modulation 60 × 30 MIMO in VC channel.

Figures 9 and 10 show the BER performance of 30 × 20 MIMO with QPSK modulation and 25 × 15 MIMO with 8-PSK modulation over VC channels, respectively. Under both modulations, GRU-DetNet and Hybrid-DetNet perform comparably to DetNet, and Hybrid-DetNet improves as the number of iterations increases. Under QPSK modulation, the proposed GRU-DetNet and Hybrid-DetNet achieve larger performance gains than the conventional ZF, DF and AMP detectors, as shown in Fig. 9.


Fig. 8. Comparison of the algorithm running time of different batches of 16-QAM modulation 25 × 15 MIMO in VC channel.

Fig. 9. Comparison of detection performance of QPSK modulation 30 × 20 MIMO in VC channel.


Fig. 10. Comparison of detection performance of 8-PSK modulation 25 × 15 MIMO in VC channel.

5 Conclusions

In this paper, two deep detection networks, GRU-DetNet and Hybrid-DetNet, are proposed as low-complexity, high-performance detection algorithms for massive MIMO. Comparative experiments on different channel models and constellations show that the proposed models achieve better detection performance with lower complexity. The results confirm the effectiveness of deep learning-based massive MIMO signal detection and offer a new technical approach for future research on intelligent wireless communication systems.

References

1. Andrews, J.G., Buzzi, S., Choi, W., et al.: What will 5G be? IEEE J. Sel. Areas Commun. 32(6), 1065–1082 (2014)
2. Zappone, A., Di Renzo, M., Debbah, M.: Wireless networks design in the era of deep learning: model-based, AI-based, or both? IEEE Trans. Commun. 67(10), 7331–7376 (2019)
3. Baek, M.S., Kwak, S., Jung, J.Y., Kim, H.M., Choi, D.J.: Implementation methodologies of deep learning-based signal detection for conventional MIMO transmitters. IEEE Trans. Broadcast. 65(3), 636–642 (2019)
4. Li, L., Hou, H., Meng, W.: Convolutional-neural-network-based detection algorithm for uplink multiuser massive MIMO systems. IEEE Access 8(3), 64250–64265 (2020)
5. Jia, Z., Cheng, W., Zhang, H.: A partial learning-based detection scheme for massive MIMO. IEEE Wirel. Commun. Lett. 8(4), 1137–1140 (2019)
6. He, H., Wen, C.K., Jin, S., Li, G.Y.: A model-driven deep learning network for MIMO detection. In: 2018 IEEE Global Conference on Signal and Information Processing (GlobalSIP), Anaheim, CA, pp. 584–588 (2018)


7. Ao, P., Li, R., Sun, R., Xue, J.: Semi-supervised learning for MIMO detection. In: 2022 14th International Conference on Wireless Communications and Signal Processing (WCSP), Nanjing, China, pp. 1023–1027 (2022)
8. Tan, X., Zhong, Z., Zhang, Z., You, X., Zhang, C.: Low-complexity message passing MIMO detection algorithm with deep neural network. In: 2018 IEEE Global Conference on Signal and Information Processing (GlobalSIP), Anaheim, CA, USA, pp. 559–563 (2018)
9. Zhou, Z., Liu, L., Jere, S., Zhang, J., Yi, Y.: RCNet: incorporating structural information into deep RNN for online MIMO-OFDM symbol detection with limited training. IEEE Trans. Wirel. Commun. 20(6), 3524–3537 (2021)
10. Bouchenak, S., Merzougui, R., Harrou, F., Dairi, A., Sun, Y.: A semi-supervised modulation identification in MIMO systems: a deep learning strategy. IEEE Access 10(5), 76622–76635 (2022)
11. Zhang, Y., Sun, J., Xue, J., Li, G.Y., Xu, Z.: Deep expectation-maximization for joint MIMO channel estimation and signal detection. IEEE Trans. Signal Process. 70(16), 4483–4497 (2022)
12. Jeong, D., Kim, J.: Signal detection for MIMO SC-FDMA systems exploiting block circulant channel structure. IEEE Trans. Veh. Technol. 65(9), 7774–7779 (2016)
13. Choi, J., Joung, J.: Generalized space-time line code with receive combining for MIMO systems. IEEE Syst. J. 16(2), 1897–1908 (2022)
14. Jin, X., Kim, H.N.: Parallel deep learning detection network in the MIMO channel. IEEE Commun. Lett. 24(1), 126–130 (2020)
15. Samuel, N., Diskin, T., Wiesel, A.: Learning to detect. IEEE Trans. Signal Process. 67(5), 2554–2564 (2019)
16. Mirza, A.H.: Online additive updates with FFT-IFFT operator on the GRU neural networks. In: 2018 26th Signal Processing and Communications Applications Conference (SIU), Izmir, pp. 1–4 (2018)
17. Szegedy, C., et al.: Going deeper with convolutions. In: 2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Boston, pp. 1–9 (2015)
18. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), Las Vegas, pp. 770–778 (2016)

Epileptic Seizure Detection Based on Feature Extraction and CNN-BiGRU Network with Attention Mechanism

Jie Xu, Juan Wang, Jin-Xing Liu, Junliang Shang, Lingyun Dai, Kuiting Yan, and Shasha Yuan (corresponding author)

School of Computer Science, Qufu Normal University, Rizhao 276826, China

Abstract. Epilepsy is one of the most widespread neurological disorders of the brain. In this paper, we propose an efficient seizure detection system that combines traditional feature extraction with a deep learning model. First, the wavelet transform is applied to filter the EEG signals, and the subband signals that carry the main feature information are selected. Then several EEG features, covering statistical, frequency and nonlinear properties of the signals, are extracted. To emphasize the extracted feature representation of the EEG signals and to speed up model convergence, the extracted features are fed into the proposed CNN-BiGRU network with an attention mechanism for classification. Finally, the output of the classification model is post-processed to obtain the final classification results. On the 21 patients of the Freiburg EEG dataset, the method achieved an average sensitivity of 93.68%, specificity of 98.35%, accuracy of 98.33%, and false detection rate of 0.397/h. These results demonstrate the advantages of the attention-based CNN-BiGRU network for seizure detection and its great potential for further seizure detection research.

Keywords: Electroencephalography · Seizure detection · Feature extraction · Deep learning

1 Introduction

Epilepsy is a complex disorder characterized by sudden, recurrent seizures that significantly impair patients' quality of life [1]. Electroencephalography (EEG) records the neuroelectrical activity of the brain and is one of the most common and reliable tools for studying brain disorders. Traditionally, seizures have been identified by visual inspection of EEG recordings, relying on the prior knowledge and experience of experts. Because this process is costly in resources and time, high-performance automatic seizure detection has become a major research direction [2].

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 308–319, 2023. https://doi.org/10.1007/978-981-99-4742-3_25

Epileptic Seizure Detection Based on Feature Extraction


With a deepening understanding of epilepsy, a growing body of literature has proposed a variety of automatic seizure detection algorithms. Traditional techniques for automatic seizure detection mainly rely on machine learning and consist of two parts: feature extraction and classification. Feature extraction involves manually selecting relevant features from the EEG, such as time-domain, frequency-domain, time-frequency-domain and nonlinear features [3]. The extracted features are then fed into classifiers such as ensemble learning or Bayesian linear discriminant analysis (BLDA) [4, 5]. However, traditional machine learning algorithms often converge slowly and classify poorly when applied to EEG signal classification. Recently, deep learning models, particularly CNNs and RNNs, have been increasingly used for seizure detection owing to their powerful hierarchical feature learning. Gramacki and Gramacki proposed a CNN-based deep learning framework to detect epileptic activity and tested its usability on a neonatal EEG dataset [6]. Recurrent neural networks (RNNs) and their variants, such as long short-term memory (LSTM) and gated recurrent units (GRUs), are designed to efficiently capture local and global patterns in sequences. Meanwhile, several studies have proposed hybrid models combining CNNs and RNNs, and a few have combined the attention mechanism with RNNs for seizure detection; these attention-based models outperformed comparable models without attention [7]. In this work, we propose a method that combines handcrafted feature extraction with an attention-based CNN-BiGRU model to automatically classify seizure and non-seizure signals in offline EEG recordings from epileptic patients, effectively handling the complexity and variability of EEG.
The main contributions of this study can be summarized as follows:

• Handcrafted features, including statistical, frequency and nonlinear features, are extracted from different frequency bands of the preprocessed EEG recordings.
• The model combines the advantages of CNN and BiGRU to further extract intrinsic EEG features while avoiding vanishing and exploding gradients. We also introduce an attention mechanism that assigns different weights to the hidden states of BiGRU, which reduces the loss of historical information and highlights the influence of important information.

The rest of the paper is organized as follows. Section 2 describes the EEG database used in this study. Section 3 elaborates the proposed seizure detection method. Section 4 presents the detailed experimental results, and Sect. 5 discusses the performance of the approach. Finally, Sect. 6 concludes the paper.


J. Xu et al.

2 EEG Database

The EEG dataset used to evaluate our system consists of intracranial EEG (iEEG) recordings of 21 patients provided by the Epilepsy Center of the University Hospital of Freiburg, Germany; further details on the Freiburg dataset are given in Ref. [8]. In this study, we used the EEG data recorded from the three focal channels, which carry rich pathological information. Seizure recordings and non-seizure recordings of the same length were randomly selected as the training dataset. The training dataset consists of 35 seizure events from 21 patients, totaling 105.32 min of seizure data, plus non-seizure data of equal length. The remaining data, containing 58 seizure events in 574 h of EEG, form the testing dataset.
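The balanced training set described above (seizure data plus an equal amount of randomly chosen non-seizure data) can be sketched at the level of 4-s segments. The segment identifiers, the fixed seed and the function name are hypothetical; the paper does not state how the random selection was seeded or implemented.

```python
import random
from typing import List, Sequence, Tuple


def balanced_training_set(seizure_ids: Sequence[int], nonseizure_ids: Sequence[int],
                          seed: int = 0) -> List[Tuple[int, int]]:
    """Pair the seizure segments with an equal number of random non-seizure segments.

    Returns (segment_id, label) tuples with label 1 = seizure, 0 = non-seizure,
    shuffled so that the two classes are interleaved.
    """
    if len(nonseizure_ids) < len(seizure_ids):
        raise ValueError("not enough non-seizure segments to balance the classes")
    rng = random.Random(seed)
    picked = rng.sample(list(nonseizure_ids), len(seizure_ids))  # without replacement
    pairs = [(i, 1) for i in seizure_ids] + [(i, 0) for i in picked]
    rng.shuffle(pairs)
    return pairs


train = balanced_training_set(range(10), range(100, 200))
print(len(train), sum(label for _, label in train))  # 20 10
```

As a rough scale check, 105.32 min of seizure data corresponds to about 105.32 × 60 / 4 ≈ 1580 four-second windows per class.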

3 Methods

The automatic seizure detection system consists of four modules: preprocessing and segmentation, feature extraction, classification by CNN-BiGRU with attention, and postprocessing, as shown in Fig. 1. Each module is described in the following subsections.

Fig. 1. The overall flowchart of the proposed epileptic seizure detection system.

3.1 Pre-processing

Since seizures are associated with EEG activity in specific frequency bands, the EEG recordings are first preprocessed. The raw EEG signal of each channel is divided into 4-s segments. The discrete wavelet transform (DWT) with the Daubechies-4 wavelet is then applied to each segment to decompose it into five scales. The frequency bands of the detail coefficients at these five scales are 64–128 Hz (d1), 32–64 Hz (d2), 16–32 Hz (d3), 8–16 Hz (d4), and 4–8 Hz (d5). Because seizure activity lies mainly below 30 Hz, the d3, d4 and d5 detail coefficients were selected for signal reconstruction, and every reconstructed signal has the same length as the original EEG segment. Traditional feature extraction methods were then applied to the reconstructed signals.
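The segmentation and the listed band edges follow from dyadic decomposition: detail level $j$ covers roughly $f_s/2^{j+1}$ to $f_s/2^{j}$, so the bands above imply a sampling rate of $f_s = 256$ Hz (inferred here from the 128 Hz upper edge of d1, not stated in this section). The sketch covers only the segmentation and the band arithmetic; the wavelet decomposition and reconstruction themselves would come from a library such as PyWavelets (`pywt.wavedec` and `pywt.waverec` with `'db4'`), which is an assumption about tooling rather than what the authors used.

```python
from typing import List, Sequence, Tuple


def segment(signal: Sequence[float], fs: int, seconds: float = 4.0) -> List[List[float]]:
    """Split one channel into non-overlapping windows; a trailing partial window is dropped."""
    n = int(fs * seconds)
    return [list(signal[i:i + n]) for i in range(0, len(signal) - n + 1, n)]


def detail_bands(fs: float, levels: int) -> List[Tuple[float, float]]:
    """Nominal (low, high) band in Hz of the detail coefficients d1..d_levels."""
    return [(fs / 2 ** (j + 1), fs / 2 ** j) for j in range(1, levels + 1)]


bands = detail_bands(256, 5)
print(bands)  # [(64.0, 128.0), (32.0, 64.0), (16.0, 32.0), (8.0, 16.0), (4.0, 8.0)]
kept = bands[2:5]  # d3, d4, d5 used for reconstruction
print(min(lo for lo, _ in kept), max(hi for _, hi in kept))  # 4.0 32.0
print(len(segment([0.0] * 2100, fs=256)))  # 2
```

Keeping d3 to d5 therefore retains roughly the 4–32 Hz range, which covers the sub-30 Hz seizure activity while discarding the high-frequency d1 and d2 content.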


3.2 Feature Extraction

For each channel, 11 statistical features, 4 frequency features and 3 nonlinear features are extracted from the wavelet-reconstructed EEG signal, and the features of all channels are then concatenated horizontally. Experiments showed that this feature computation greatly reduces computational complexity while keeping the features representative, so that the changes in the brain's neural activity during seizures are characterized from different perspectives.

Table 1. Description of the extracted features.

| Feature | Mathematics |
|---|---|
| Max | $Max_X = \max(X)$ |
| Min | $Min_X = \min(X)$ |
| Mean | $Mean_X = E(X)$ |
| PTP | $PTP = Max_X - Min_X$ |
| ARV | $ARV_X = Mean(\lvert X \rvert)$ |
| Variance | $Var_X = \frac{1}{N}\sum_{n=1}^{N} \left(X_n - Mean_X\right)^2$ |
| Std | $Std_X = \sqrt{\frac{1}{N}\sum_{n=1}^{N} \left(X_n - Mean_X\right)^2}$ |
| Kurtosis | $Kurtosis = E\left[\left(\frac{X - Mean_X}{Std_X}\right)^4\right]$ |
| Skewness | $Skewness = E\left[\left(\frac{X - Mean_X}{Std_X}\right)^3\right]$ |
| RMS | $RMS = \sqrt{\frac{1}{N}\sum_{n=1}^{N} X_n^2}$ |
| Energy | $E_X = \sum_{n=1}^{N} X_n^2$ |
| FMF | $F_{MF} = \frac{1}{N}\sum_{n=1}^{N} u(n)$ |
| FFC | $F_{FC} = \frac{\sum_{n=1}^{N} \dot{u}(n)\,u(n)}{2\pi \sum_{n=1}^{N} u(n)^2}$ |
| FRMSF | $F_{RMSF} = \sqrt{\frac{\sum_{n=1}^{N} \dot{u}(n)^2}{4\pi^2 \sum_{n=1}^{N} u(n)^2}}$ |
| FRVF | $F_{RVF}$, standard deviation of frequency [10] |
| Permutation Entropy | $PE = -\sum_{r=1}^{N} P_r \log_2 P_r$ |
| Approximate Entropy | $\Phi^{m}(r) = \frac{1}{N-m+1}\sum_{i=1}^{N-m+1} \log C_i^{m}(r)$ |
| Sample Entropy | $SampEn = -\log\frac{A}{B}$ |

Note: a time series signal of length $N$ is denoted by $X = X_1, X_2, \cdots, X_N$.
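The time-domain statistics in Table 1 and one of the entropy measures can be computed directly from a reconstructed segment. This sketch uses the population (1/N) moments as written in Table 1 and permutation entropy with embedding order 3 and delay 1; those two settings, and the convention of returning 0 for kurtosis and skewness of a constant segment, are assumptions, since the paper does not report them. The frequency features and approximate/sample entropy are omitted for brevity.

```python
import math
from typing import Dict, Sequence


def statistical_features(x: Sequence[float]) -> Dict[str, float]:
    """The 11 statistical features of Table 1 for one segment."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # population variance, 1/N as in Table 1
    std = math.sqrt(var)
    feats = {
        "max": max(x), "min": min(x), "mean": mean, "ptp": max(x) - min(x),
        "arv": sum(abs(v) for v in x) / n,
        "variance": var, "std": std,
        "rms": math.sqrt(sum(v * v for v in x) / n),
        "energy": sum(v * v for v in x),
    }
    # Standardized 4th and 3rd moments; defined as 0 for a constant segment (assumption).
    feats["kurtosis"] = sum(((v - mean) / std) ** 4 for v in x) / n if std > 0 else 0.0
    feats["skewness"] = sum(((v - mean) / std) ** 3 for v in x) / n if std > 0 else 0.0
    return feats


def permutation_entropy(x: Sequence[float], order: int = 3, delay: int = 1) -> float:
    """Shannon entropy (bits) of the ordinal patterns of length `order`."""
    counts: Dict[tuple, int] = {}
    for i in range(len(x) - (order - 1) * delay):
        window = [x[i + j * delay] for j in range(order)]
        pattern = tuple(sorted(range(order), key=lambda j: window[j]))  # ordinal pattern
        counts[pattern] = counts.get(pattern, 0) + 1
    total = sum(counts.values())
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


feats = statistical_features([1.0, 2.0, 3.0, 4.0])
print(feats["mean"], feats["variance"], round(feats["kurtosis"], 2))  # 2.5 1.25 1.64
print(permutation_entropy([1, 3, 2, 4, 3, 5]))  # 1.0
```

In the usage line, the zig-zag series contains two ordinal patterns with equal frequency, so its permutation entropy is exactly one bit, whereas a monotonic series has a single pattern and an entropy of zero.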

Statistical features include the maximum, minimum, mean, average rectified value (ARV), peak-to-peak value (PTP), variance, kurtosis, skewness, standard deviation (Std), root mean square (RMS), and energy [9]. Four frequency-based statistics are used in the frequency domain: mean frequency ($F_{MF}$), center of gravity frequency ($F_{FC}$), root mean square frequency ($F_{RMSF}$), and standard deviation of frequency ($F_{RVF}$) [10]. Three kinds of entropy, namely permutation entropy, approximate entropy and sample entropy, were used to extract nonlinear EEG features [11]. All extracted features are listed in Table 1.

3.3 Classification Model (CNN-BiGRU with Attention Mechanism Model)

With the development of deep learning techniques, many deep learning approaches have been developed for seizure detection. In this model we employ bidirectional gated recurrent units (BiGRU), an RNN variant that helps mitigate vanishing gradients and enhances feature propagation. In addition, an attention mechanism is applied to help the BiGRU network extract information efficiently. The resulting hybrid model combines CNN, BiGRU and attention to form a new EEG classification framework for seizure detection. Its architecture consists of four modules, namely a convolutional layer, a BiGRU layer, an attention module and an output layer, as illustrated in Fig. 2.

Fig. 2. The detailed steps of the proposed deep learning classification model.

Convolutional Block. The convolutional layer, which contains the convolutional computation layer and the activation layer, is the core of a convolutional neural network (CNN). The convolutional computation slides a kernel matrix over the input and multiplies it element-wise with the data at each corresponding position. The layer uses 32 convolutional filters with a kernel size of 3 × 1.

Bidirectional Gated Recurrent Unit. A unidirectional GRU propagates information in one direction only, from the start of the sequence to the end, so each output can use only preceding information and none of the subsequent sequence. BiGRU obtains information from both forward and backward propagation to fully extract data-related features. The output of BiGRU at step $t$, $h_t = \mathrm{Bi}\left(H_{C,t-1}, H_{C,t}\right)$, $t \in [1, i]$, is determined jointly by the states of both GRUs.

Attention Mechanism Block. Connecting the BiGRU output directly to the fully connected layer wastes feature information; the attention mechanism avoids this and strengthens the model's ability to capture important local information. It assigns weights to the feature vectors, highlighting key features and enhancing the feature expression of the EEG signals, which improves the model's judgment accuracy. Figure 3 depicts the structure of the attention mechanism; more details can be found in Ref. [12].

Output Layer. The feature vectors obtained from the attention layer are fed into a fully connected layer, and a softmax activation then gives the probability that each EEG segment belongs to each category.
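The attention block can be illustrated with one common additive formulation: score each BiGRU output $h_t$ (the concatenation of the forward and backward states), normalize the scores with softmax, and return the weighted sum as the context vector. The scoring form $e_t = \tanh(w^{T} h_t + b)$ and all names below are assumptions for illustration; the structure actually used by the authors is the one in Fig. 3 and Ref. [12].

```python
import math
from typing import List, Sequence, Tuple


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def bigru_output(forward: Sequence[float], backward: Sequence[float]) -> List[float]:
    """BiGRU output at one step: forward and backward hidden states concatenated."""
    return list(forward) + list(backward)


def attention_pool(states: Sequence[Sequence[float]], w: Sequence[float],
                   b: float) -> Tuple[List[float], List[float]]:
    """Return (context vector, attention weights) for a sequence of hidden states."""
    scores = [math.tanh(sum(wi * hi for wi, hi in zip(w, h)) + b) for h in states]
    alphas = softmax(scores)
    dim = len(states[0])
    context = [sum(a * h[d] for a, h in zip(alphas, states)) for d in range(dim)]
    return context, alphas


states = [bigru_output([1.0], [0.0]), bigru_output([0.0], [1.0])]
context, alphas = attention_pool(states, w=[2.0, -2.0], b=0.0)
print([round(a, 3) for a in alphas])  # [0.873, 0.127]
```

The same `softmax` also serves the output layer, where it turns the fully connected layer's two logits into seizure and non-seizure probabilities.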

Fig. 3. The structure of attention mechanism.

The categorical cross-entropy is used as the loss function, and the model is trained with the Adam optimizer. The training parameters are as follows: 200 epochs, a batch size of 128, and a learning rate of 0.001.

3.4 Post-processing

Since the deep learning classification model outputs probabilities rather than exact category labels, a series of post-processing operations is required to obtain the category labels of the EEG samples. Post-processing mainly consists of smoothing, thresholding and the adaptive collar technique. Specifically, the model produces two probability values for each test sample, one for the seizure state and one for the non-seizure state. The two values are subtracted, and a linear moving average filter (MAF) is applied to the difference to limit its fluctuation, which removes some short-term misclassifications [13]. The output of the smoothing filter is then compared with a threshold to produce a binary seizure or non-seizure label for each test sample. The threshold is patient-specific and is chosen to minimize the classification error on the training dataset; the smoothing length is determined in the same way.


Moreover, EEG changes at the onset and end of a seizure are subtle, and smoothing blurs them further. To tackle this issue, we employ the adaptive collar technique, which extends the duration of each initial detection output by the deep network model.
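The post-processing chain can be sketched as follows: subtract the two class probabilities, smooth the difference with a moving average, threshold it, and extend each detection. The following assumptions are hypothetical and not taken from the paper: the moving average is causal (a trailing window), the smoothing length and threshold are passed in rather than fitted per patient, and a fixed collar of `collar` segments on each side stands in for the adaptive collar, whose adaptation rule is not specified here.

```python
from typing import List, Sequence


def moving_average(x: Sequence[float], length: int) -> List[float]:
    """Trailing moving average; early outputs average over the samples seen so far."""
    out: List[float] = []
    acc = 0.0
    for i, v in enumerate(x):
        acc += v
        if i >= length:
            acc -= x[i - length]
        out.append(acc / min(i + 1, length))
    return out


def apply_collar(labels: Sequence[int], collar: int) -> List[int]:
    """Extend every positive label by `collar` segments on both sides."""
    out = list(labels)
    for i, lab in enumerate(labels):
        if lab == 1:
            for j in range(max(0, i - collar), min(len(labels), i + collar + 1)):
                out[j] = 1
    return out


def postprocess(p_seizure: Sequence[float], p_nonseizure: Sequence[float],
                length: int, threshold: float, collar: int) -> List[int]:
    """Probability difference -> MAF smoothing -> threshold -> collar."""
    diff = [a - b for a, b in zip(p_seizure, p_nonseizure)]
    smoothed = moving_average(diff, length)
    labels = [1 if s > threshold else 0 for s in smoothed]
    return apply_collar(labels, collar)


p_seiz = [0.1, 0.1, 0.9, 0.1, 0.1, 0.9, 0.9, 0.9, 0.1, 0.1]
p_non = [1.0 - p for p in p_seiz]
print(postprocess(p_seiz, p_non, length=3, threshold=0.0, collar=1))
# [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]: the isolated spike at index 2 is smoothed away
```

The example shows the intended division of labour: smoothing suppresses the single-segment false alarm, while the collar restores the edges of the sustained event that smoothing had shifted and blurred.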

4 Results

Two kinds of evaluation, segment-based and event-based, were employed to assess the validity and feasibility of the proposed automated seizure detection method. The segment-based assessment uses sensitivity, specificity and recognition accuracy. The event-based assessment uses event sensitivity and the false detection rate (FDR), and counts epileptic seizures rather than EEG segments. The FDR is the average number of false seizure events detected per hour during non-seizure periods.

Table 2. Evaluation results of the proposed method on the Freiburg dataset.

| Patient | Sensitivity (%) | Specificity (%) | Accuracy (%) | Seizures marked by experts | Seizures detected | FDR (/h) |
|---|---|---|---|---|---|---|
| pat001 | 50 | 99.46 | 99.45 | 2 | 1 | 0.2258 |
| pat002 | 100 | 99.97 | 99.97 | 2 | 2 | 0.0645 |
| pat003 | 100 | 98.30 | 98.31 | 4 | 4 | 0.2903 |
| pat004 | 100 | 97.82 | 97.83 | 4 | 4 | 0.3548 |
| pat005 | 100 | 99.53 | 99.44 | 4 | 4 | 0.6452 |
| pat006 | 100 | 98.16 | 98.16 | 2 | 2 | 0.9032 |
| pat007 | 100 | 98.44 | 98.44 | 2 | 2 | 0.3226 |
| pat008 | 73.30 | 95.55 | 95.51 | 1 | 1 | 0.7097 |
| pat009 | 96.69 | 95.61 | 95.61 | 4 | 4 | 0.5161 |
| pat010 | 100 | 98.04 | 98.04 | 1 | 1 | 0.9355 |
| pat011 | 100 | 99.96 | 99.66 | 3 | 3 | 0.0968 |
| pat012 | 100 | 99.93 | 99.93 | 3 | 3 | 0.0968 |
| pat013 | 100 | 98.75 | 98.75 | 1 | 1 | 0.3226 |
| pat014 | 100 | 98.53 | 98.54 | 3 | 3 | 0.1935 |
| pat015 | 88.06 | 99.62 | 99.59 | 3 | 3 | 0.0968 |
| pat016 | 100 | 98.77 | 98.77 | 4 | 4 | 0.4194 |
| pat017 | 100 | 99.23 | 99.24 | 4 | 4 | 0.1935 |
| pat018 | 100 | 94.27 | 94.27 | 2 | 2 | 0.6129 |
| pat019 | 70 | 99.96 | 99.95 | 2 | 1 | 0.0645 |
| pat020 | 89.29 | 95.68 | 95.66 | 4 | 3 | 0.871 |
| pat021 | 100 | 99.86 | 99.87 | 3 | 3 | 0.0968 |
| Mean / Total | 93.68 | 98.35 | 98.33 | 58 | 55 | 0.397 |

Our experiments were implemented in MATLAB 2016 and Keras 2.9.0 with Python 3.9.10. Long-term intracranial EEG data from 21 patients were analyzed to evaluate the proposed seizure detection system, and the results are reported here according to the evaluation criteria above. In the segment-based assessment, Table 2 shows that the average sensitivity, specificity and accuracy over the 21 patients were 93.68%, 98.35% and 98.33%, respectively. Nine patients had specificity and accuracy above 99%, and 15 patients had 100% sensitivity. Patients 1, 8, 15, 19 and 20 had the lowest sensitivities, which we attribute to the impact of epileptic-like activity. Table 2 also reports the event-based assessment: the average FDR was 0.397/h, and 55 of the 58 seizures in the test data were detected. Furthermore, the FDR was below 0.2/h in 8 patients, and the highest FDRs occurred in the testing data of patients 6, 10 and 20.
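The segment-based and event-based metrics can be sketched as below, together with aggregate counts computed from the per-patient values transcribed from Table 2. The function names are hypothetical; sensitivity, specificity and accuracy use the standard confusion-matrix definitions, with label 1 meaning seizure.

```python
from typing import Sequence, Tuple


def segment_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Tuple[float, float, float]:
    """Return (sensitivity, specificity, accuracy) in percent; label 1 = seizure."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    return 100 * tp / (tp + fn), 100 * tn / (tn + fp), 100 * (tp + tn) / len(y_true)


def false_detection_rate(false_events: int, nonseizure_hours: float) -> float:
    """Event-based FDR: false seizure detections per hour of non-seizure EEG."""
    return false_events / nonseizure_hours


# Per-patient values for pat001..pat021, transcribed from Table 2.
SENSITIVITY = [50, 100, 100, 100, 100, 100, 100, 73.30, 96.69, 100, 100,
               100, 100, 100, 88.06, 100, 100, 100, 70, 89.29, 100]
MARKED = [2, 2, 4, 4, 4, 2, 2, 1, 4, 1, 3, 3, 1, 3, 3, 4, 4, 2, 2, 4, 3]
DETECTED = [1, 2, 4, 4, 4, 2, 2, 1, 4, 1, 3, 3, 1, 3, 3, 4, 4, 2, 1, 3, 3]
FDR = [0.2258, 0.0645, 0.2903, 0.3548, 0.6452, 0.9032, 0.3226, 0.7097, 0.5161, 0.9355, 0.0968,
       0.0968, 0.3226, 0.1935, 0.0968, 0.4194, 0.1935, 0.6129, 0.0645, 0.871, 0.0968]

print(round(sum(SENSITIVITY) / len(SENSITIVITY), 2))     # 93.68
print(sum(DETECTED), sum(MARKED))                          # 55 58
print(SENSITIVITY.count(100), sum(f < 0.2 for f in FDR))  # 15 8
```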

5 Discussion

In this work, we propose an attention-based CNN-BiGRU deep network model for seizure detection, which demonstrates superior performance on a long-term EEG dataset. To explore how handcrafted features and features learned by the deep model contribute to classification performance, we applied t-SNE to each of them. We randomly selected one patient's handcrafted features and some of the features learned by the deep learning model; the results are shown in Fig. 4, where each point in the scatter plot represents one sample. With the handcrafted features alone, the seizure and non-seizure classes can already be roughly distinguished, but the difference between the two classes becomes more pronounced after learning by the deep model.


Fig. 4. Scatter plots of randomly selected handcrafted features and deep-learning-learned features. Blue represents the seizure class and yellow the non-seizure class.

Furthermore, the number of training epochs was determined from the accuracy and loss of the classification model during training. Figure 5 shows examples of the training and validation accuracy and the training and validation loss of the deep network model. Since the model performs best at 200 epochs, where both metrics level off, the number of epochs was set to 200.

Fig. 5. Accuracy and loss curves of the training and validation sets for patient 2. (a) Training and validation accuracy. (b) Training and validation loss.

To further verify the detection and classification ability of the model, we compared the average detection results over all patients after removing the CNN, and after removing both the CNN and the attention mechanism. Figure 6 shows evaluation metrics for the BiGRU model, the attention-based BiGRU model and the proposed model, which lead to the following findings:

(1) Adding the attention mechanism to the BiGRU model raised sensitivity, specificity and accuracy by 4.04%, 0.97% and 1.03%, respectively. This indicates that the attention mechanism strengthens important information, helping the classification model learn informative features more efficiently and improving its detection accuracy.

(2) Compared with the attention-based BiGRU model, the proposed CNN-BiGRU model with attention increased the three metrics by 10.32%, 5.64% and 5.61%, respectively. This comparison shows that adding the CNN enhances the information in the input handcrafted feature data.

Therefore, the proposed deep model combining CNN, attention and BiGRU is more capable and more sensitive in detecting seizures in EEG signals than the other two models.

Fig. 6. Classification results of three deep learning models on the Freiburg long-term dataset.

Table 3 compares several existing seizure detection approaches, all evaluated on the Freiburg database. Ma et al. detected seizures by computing the tensor distance as an EEG feature and combining it with BLDA [4]. In comparison, our system, evaluated on 58 seizures, achieved higher specificity and accuracy and a lower FDR. Mu et al. performed EEG data reduction and feature extraction using graph-regularized non-negative matrix factorization (GNMF) [5]; our approach outperforms theirs on all four metrics in Table 3. Malekzadeh et al. combined various handcrafted features with deep learning features and then used a CNN-RNN deep learning method for seizure detection [14]. Their work achieved very satisfactory performance, with an accuracy of 99.13%. Although it surpasses our method in accuracy, our algorithm is simpler and still produces strong results with less training data; moreover, their study does not explain how the training and testing data were used. In addition, the seizure detection methods in Refs. [15] and [16] both use LSTM; Hussain et al. combined CNN and LSTM networks as seizure detectors to distinguish among interictal, ictal and pre-ictal segments [15]. In contrast, our algorithm uses GRUs instead of LSTMs, which have a simpler structure and lower computational cost. In summary, the proposed attention-based CNN-BiGRU detection technique matches the detection capability of other deep neural network-based methods while having a simpler structure and lower computational cost. It yields comparable or superior results to existing methods while requiring less training data, and it was evaluated on every patient in the Freiburg dataset.
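The claim that GRUs are cheaper than LSTMs can be made concrete by counting recurrent-layer parameters: a GRU has three gated blocks (reset, update, candidate) where an LSTM has four (input, forget, output, cell). This sketch assumes the classic formulation with one bias vector per block; the input size 32 (matching the 32 convolutional filters) and hidden size 64 in the example are hypothetical, because the paper does not report the BiGRU width. Keras' default GRU (`reset_after=True`) keeps a second bias vector per block, so its count is slightly higher.

```python
def recurrent_params(input_dim: int, hidden: int, blocks: int) -> int:
    """Input weights, recurrent weights and one bias vector per gated block."""
    return blocks * (hidden * input_dim + hidden * hidden + hidden)


def gru_params(input_dim: int, hidden: int) -> int:
    return recurrent_params(input_dim, hidden, blocks=3)  # reset, update, candidate


def lstm_params(input_dim: int, hidden: int) -> int:
    return recurrent_params(input_dim, hidden, blocks=4)  # input, forget, output, cell


def bidirectional(n_params: int) -> int:
    """A bidirectional layer holds two independent copies (forward and backward)."""
    return 2 * n_params


print(bidirectional(gru_params(32, 64)), bidirectional(lstm_params(32, 64)))  # 37248 49664
```

For equal input and hidden sizes, a GRU therefore needs three quarters of the parameters, and roughly three quarters of the per-step multiply-adds, of an LSTM.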


J. Xu et al. Table 3. Properties of different methodologies compared on the Freiburg database.

Authors                | Method                             | Number of patients | Sensitivity (%) | Specificity (%) | Accuracy (%) | FDR/h
Ma et al. [4]          | Tensor distance and BLDA           | 21                 | 95.12           | 97.60           | 97.60        | 0.76
Mu et al. [5]          | GNMF and BLDA                      | 21                 | 93.20           | 98.16           | 98.16        | 0.50
Malekzadeh et al. [14] | CNN and RNN                        | –                  | 98.96           | 98.96           | 99.13        | –
Hussain et al. [15]    | 1D-convolutional and LSTM          | 21                 | 99.46           | 98.45           | 99.27        | –
Jaafar et al. [16]     | LSTM                               | 11                 | –               | –               | 97.75        | –
Our work               | Handcrafted features and CNN-BiGRU | 21                 | 93.68           | 98.35           | 98.33        | 0.397

(– = value not given in the table)
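To make the columns of Table 3 concrete, the sketch below shows how the usual epoch-based detection metrics are computed. It assumes sensitivity, specificity and accuracy are computed over classified EEG epochs and that FDR/h counts false detection events per hour of recording; the cited studies may define their counting windows differently, and the counts in the example are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class EpochCounts:
    tp: int  # seizure epochs detected as seizure
    fn: int  # seizure epochs missed
    tn: int  # non-seizure epochs correctly rejected
    fp: int  # non-seizure epochs flagged as seizure


def detection_metrics(c: EpochCounts, false_detections: int, hours: float) -> dict:
    """Sensitivity, specificity and accuracy in %, plus false detections per hour."""
    total = c.tp + c.fn + c.tn + c.fp
    return {
        "sensitivity": 100.0 * c.tp / (c.tp + c.fn),
        "specificity": 100.0 * c.tn / (c.tn + c.fp),
        "accuracy": 100.0 * (c.tp + c.tn) / total,
        "fdr_per_hour": false_detections / hours,
    }


if __name__ == "__main__":
    # Hypothetical counts for illustration only.
    m = detection_metrics(EpochCounts(tp=90, fn=10, tn=950, fp=50),
                          false_detections=12, hours=24.0)
    print({k: round(v, 2) for k, v in m.items()})
```

Because FDR/h is a rate of false alarms, lower is better, unlike the other three columns.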

6 Conclusion
This research presents a novel seizure detection method that combines handcrafted features with a CNN-BiGRU deep network model with an attention mechanism. The attention mechanism lets the CNN-BiGRU classifier capture more critical information and enhance the EEG features, which leads to more accurate recognition of seizure and non-seizure signals. The results show the potential of the proposed algorithm for clinical epilepsy detection. In future work, we plan to study large-scale databases with more EEG electrodes, such as the CHB-MIT database, to achieve higher performance. In addition, the spatial topological relationships between EEG channels are worth studying.
Acknowledgment. This work was supported by the Program for Youth Innovative Research Team in the University of Shandong Province in China (No. 2022KJ179), and jointly supported by the National Natural Science Foundation of China (No. 61972226, No. 62172253).

References
1. Cl, A., Yc, A., Zc, B., Yl, B., Zw, B.: Automatic epilepsy detection based on generalized convolutional prototype learning. Measurement 184, 109954 (2021)
2. Gao, B., Zhou, J., Yang, Y., Chi, J., Yuan, Q.: Generative adversarial network and convolutional neural network-based EEG imbalanced classification model for seizure detection. Biocybernetics Biomed. Eng. 42(1), 1–15 (2022)
3. Tuncer, E., Bolat, E.D.: Channel based epilepsy seizure type detection from electroencephalography (EEG) signals with machine learning techniques. Biocybernetics Biomed. Eng. 42(2), 575–595 (2022)

Epileptic Seizure Detection Based on Feature Extraction


4. Ma, D., et al.: The automatic detection of seizure based on tensor distance and Bayesian linear discriminant analysis. Int. J. Neural Syst. 31(05), 2150006 (2021)
5. Mu, J., et al.: Automatic detection for epileptic seizure using graph-regularized nonnegative matrix factorization and Bayesian linear discriminate analysis. Biocybernetics Biomed. Eng. 41(4), 1258–1271 (2021)
6. Gramacki, A., Gramacki, J.: A deep learning framework for epileptic seizure detection based on neonatal EEG signals. Sci. Rep. 12(1), 13010 (2022)
7. Choi, W., Kim, M.-J., Yum, M.-S., Jeong, D.-H.: Deep convolutional gated recurrent unit combined with attention mechanism to classify preictal from interictal EEG with minimized number of channels. J. Personal. Med. 12(5), 763 (2022)
8. Malekzadeh, A., Zare, A., Yaghoobi, M., Alizadehsani, R.: Automatic diagnosis of epileptic seizures in EEG signals using fractal dimension features and convolutional autoencoder method. Big Data Cogn. Comput. 5(4), 78 (2021)
9. Yu, Z., et al.: Epileptic seizure prediction using deep neural networks via transfer learning and multi-feature fusion. Int. J. Neural Syst. 32(07), 2250032 (2022)
10. Wu, Q., Dey, N., Shi, F., Crespo, R.G., Sherratt, R.S.: Emotion classification on eye-tracking and electroencephalograph fused signals employing deep gradient neural networks. Appl. Soft Comput. 110, 107752 (2021)
11. Yedurkar, D.P., Metkar, S.P., Stephan, T.: Multiresolution directed transfer function approach for segment-wise seizure classification of epileptic EEG signal. Cogn. Neurodyn. 1–15 (2022). https://doi.org/10.1007/s11571-021-09773-z
12. Niu, D., Yu, M., Sun, L., Gao, T., Wang, K.: Short-term multi-energy load forecasting for integrated energy systems based on CNN-BiGRU optimized by attention mechanism. Appl. Energy 313, 118801 (2022)
13. Yuan, S., et al.: Automatic epileptic seizure detection using graph-regularized non-negative matrix factorization and kernel-based robust probabilistic collaborative representation. IEEE Trans. Neural Syst. Rehabil. Eng. 30, 2641–2650 (2022)
14. Malekzadeh, A., Zare, A., Yaghoobi, M., Kobravi, H.-R., Alizadehsani, R.: Epileptic seizures detection in EEG signals using fusion handcrafted and deep learning features. Sensors 21(22), 7710 (2021)
15. Hussain, W., Sadiq, M.T., Siuly, S., Rehman, A.U.: Epileptic seizure detection using 1D-convolutional long short-term memory neural networks. Appl. Acoust. 177, 107941 (2021)
16. Jaafar, S.T., Mohammadi, M.: Epileptic seizure detection using deep learning approach. UHD J. Sci. Technol. 3(2), 41 (2019)

Information Potential Based Rolling Bearing Defect Classification and Diagnosis
Hui Li (✉), Ruijuan Wang, and Yonghui Xie
School of Mechanical and Electrical Engineering, Weifang Vocational College, Weifang 262737, China

Abstract. To effectively extract the fault characteristics of rolling bearing vibration signals, information potential is introduced into rolling bearing fault diagnosis and classification. First, rolling bearing vibration signals are collected, and training and test sets are constructed. Second, the information potential of each data sample is calculated. Third, the information potentials of the samples are used as characteristic parameters to construct a standard feature space for each fault type. Finally, the Euclidean distance between a test sample and each standard feature space is used to identify the fault type. An experimental example shows how information potential can be used to identify different types of rolling bearing faults. Analysis of measured rolling bearing vibration signals indicates that the information potential based defect classification method can effectively diagnose rolling bearing faults, and that information potential is an effective indicator for distinguishing fault types.
Keywords: Bearing · Fault Diagnosis · Information Theory Learning · Information Potential · Classification

1 Introduction
Rolling bearings, common components for supporting and transmitting power in mechanical equipment, play an important role in almost all large-scale equipment. Advanced condition monitoring and fault diagnosis of bearings make it possible to move from post-failure and scheduled maintenance to condition-based maintenance, reducing unnecessary waste of human and material resources and improving economic benefits. When the inner ring, outer ring, or rolling element of a rolling bearing is damaged, shock pulses are generated whenever the damaged surface makes contact during operation. The vibration signal therefore contains a large amount of fault-related information, and vibration signal analysis has become one of the most important fault diagnosis methods. Signal-based fault diagnosis mainly involves data acquisition, feature extraction, fault identification, and diagnosis, of which feature extraction is the key step. Traditional feature extraction methods include time-domain, frequency-domain, and time-frequency-domain analysis [1]. However, because mechanical transmission devices have complex structures and rolling bearing vibration signals are nonlinear, non-stationary, and non-Gaussian, it is difficult to determine the fault type accurately with traditional fault diagnosis methods [2], and as equipment becomes more complex, traditional vibration signal processing sometimes fails to extract effective features. New and more effective feature extraction methods are therefore needed.
Machine learning methods such as artificial neural networks [3, 4], support vector machines [5], and random forests have been used to identify rolling bearing faults. These approaches extract fault features from the raw vibration signals with transforms such as the wavelet transform and the Fourier transform, and then diagnose faults with a classifier. However, they involve complex processing and ignore the relationship between feature extraction and classification, so they cannot always maintain high accuracy. In recent years, information theoretic learning (ITL) has attracted widespread attention because it can effectively process non-Gaussian and non-stationary signals [6–8]. The main purpose of this paper is to extend ITL to the fault diagnosis of mechanical and electrical equipment and thereby improve its reliability.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 320–329, 2023. https://doi.org/10.1007/978-981-99-4742-3_26
In this paper, a bearing fault diagnosis and classification method based on information potential is proposed. Information potential is used as a characteristic index to determine the bearing fault type, which enables the identification of different rolling bearing faults and provides theoretical and technical support for bearing fault diagnosis. The rest of the paper is organized as follows. Section 2 gives the definition of information potential and the main steps of information potential based bearing fault detection. Section 3 validates the method on experimental data. Conclusions are given in Sect. 4.

2 Information Potential for Bearing Fault Detection
2.1 The Definition of Information Potential
For a set of N samples $\{x_i\}_{i=1}^{N}$ of a real random variable, the kernel matrix can be defined as

$$K_x(i, j) = \kappa_\sigma(x_i, x_j) \tag{1}$$

where $\kappa_\sigma(\cdot)$ is a Mercer kernel function, such as the polynomial, Laplacian, exponential, or Gaussian kernel. The Gaussian kernel is used here to compute the kernel matrix; it is defined as

$$\kappa_\sigma(x_i, x_j) = \frac{1}{\sqrt{2\pi}\,\sigma}\exp\left(-\frac{|x_i - x_j|^2}{2\sigma^2}\right) \tag{2}$$


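where σ is the kernel width (bandwidth), |·| is the absolute value operator, and exp(·) is the exponential function. From Eqs. (1) and (2), the kernel matrix is

$$K_x = \begin{bmatrix} \kappa_\sigma(x_1, x_1) & \kappa_\sigma(x_1, x_2) & \kappa_\sigma(x_1, x_3) & \cdots & \kappa_\sigma(x_1, x_N) \\ \kappa_\sigma(x_2, x_1) & \kappa_\sigma(x_2, x_2) & \kappa_\sigma(x_2, x_3) & \cdots & \kappa_\sigma(x_2, x_N) \\ \kappa_\sigma(x_3, x_1) & \kappa_\sigma(x_3, x_2) & \kappa_\sigma(x_3, x_3) & \cdots & \kappa_\sigma(x_3, x_N) \\ \vdots & \vdots & \vdots & \ddots & \vdots \\ \kappa_\sigma(x_N, x_1) & \kappa_\sigma(x_N, x_2) & \kappa_\sigma(x_N, x_3) & \cdots & \kappa_\sigma(x_N, x_N) \end{bmatrix}_{N \times N} \tag{3}$$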
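Eqs. (1)–(3) amount to evaluating the Gaussian kernel on every pair of samples. The following is a minimal NumPy sketch using broadcasting; the function name and test values are illustrative, and σ must be supplied by the user because its value is not stated here. For a 2048-point sample the matrix has about 4.2 million entries (roughly 34 MB in float64), which is still practical.

```python
import numpy as np


def gaussian_kernel_matrix(x, sigma: float) -> np.ndarray:
    """N x N matrix K[i, j] = kappa_sigma(x_i, x_j) with the Gaussian kernel of Eq. (2)."""
    x = np.asarray(x, dtype=float).ravel()
    diff = x[:, None] - x[None, :]  # all pairwise differences x_i - x_j
    return np.exp(-diff**2 / (2.0 * sigma**2)) / (np.sqrt(2.0 * np.pi) * sigma)


if __name__ == "__main__":
    K = gaussian_kernel_matrix([0.0, 1.0, 2.0], sigma=1.0)
    print(np.round(K, 4))
    print("symmetric:", np.allclose(K, K.T))
    print("smallest eigenvalue:", np.linalg.eigvalsh(K).min())
```

Every diagonal entry equals $1/(\sqrt{2\pi}\,\sigma)$, the largest value the kernel can take, and the matrix is symmetric because $|x_i - x_j| = |x_j - x_i|$.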

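Because the Gaussian kernel is a positive definite function, the kernel matrix is symmetric and positive semidefinite (positive definite when all $x_i$ are distinct). The information potential (IP) is the average of all entries of the kernel matrix in Eq. (3):

$$IP_x = \frac{1}{N^2}\sum_{i=1}^{N}\sum_{j=1}^{N}\kappa_\sigma(x_i, x_j) \tag{4}$$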
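Eq. (4) is simply the mean of the kernel matrix. The sketch below computes it and illustrates two consequences used in the next paragraph: the IP can never exceed $1/(\sqrt{2\pi}\,\sigma)$ (reached when all samples are equal), and it falls as the amplitudes spread apart. The kernel width 0.1 and the random test signal are hypothetical choices for illustration.

```python
import numpy as np


def information_potential(x, sigma: float) -> float:
    """IP_x of Eq. (4): the mean of all N^2 Gaussian kernel evaluations."""
    x = np.asarray(x, dtype=float).ravel()
    diff = x[:, None] - x[None, :]
    k = np.exp(-diff**2 / (2.0 * sigma**2)) / (np.sqrt(2.0 * np.pi) * sigma)
    return float(k.mean())  # mean over N*N entries == (1/N^2) * double sum


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    base = rng.standard_normal(1024)
    sigma = 0.1  # hypothetical kernel width
    print("upper bound  :", 1.0 / (np.sqrt(2.0 * np.pi) * sigma))
    print("small spread :", information_potential(0.1 * base, sigma))
    print("large spread :", information_potential(2.0 * base, sigma))
```

The same signal scaled up by a factor of 20 gives a much smaller IP, which is why impulsive, large-amplitude fault signals are expected to have lower IP than a smooth normal-condition signal at the same σ.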

Information potential measures the similarity between signal amplitudes: the larger the IP, the more similar the amplitudes; the smaller the IP, the less similar they are. When a rolling bearing fails, faults in the inner ring, outer ring, and rolling element produce vibration signals with different information potentials, so different fault types can be distinguished by the magnitude of the IP.
2.2 Feature Extraction and Classification for Rolling Bearing
When a rolling bearing fails, its vibration signal changes, and so does the computed information potential. Different fault states produce different vibration signals, and these differences are reflected in the IP of each signal. In information theoretic learning, the magnitude of the IP directly reflects the similarity between vibration signal amplitudes. Therefore, the IP of a bearing vibration signal can be used as its characteristic parameter. By comparing the IP of an unknown sample with the IPs of labeled samples, the unknown sample can be recognized and classified. For rolling bearing vibration signals, the characteristic parameters of samples with the same fault type differ little, whereas those of samples with different fault types differ greatly. Therefore, the absolute difference between the characteristic parameter of a test sample and those of the training samples can be used to classify rolling bearing faults.
Rolling Bearing Fault Feature Extraction.
Calculate the information potential of the bearing vibration signal according to Eq. (4), and use it as the characteristic parameter of the rolling bearing vibration signal. Each vibration signal sample thus yields one scalar characteristic parameter.
Rolling Bearing Fault Classification. After the IP of each bearing vibration signal has been calculated, the mean IP of the training samples of each fault type is computed. Since the IP is a scalar, the Euclidean distance between a test sample and a class is the absolute difference between the test sample's IP and that class mean, and the test sample is assigned to the fault type with the smallest distance:

$$\hat{c} = \arg\min_{c \in \{1, 2, 3, 4\}} \left| IP_x - IP_c \right| \tag{5}$$

where $\hat{c}$ is the predicted fault type, c is the bearing fault type number (c = 1, 2, 3, 4), $IP_x$ is the characteristic parameter of the test sample, and $IP_c$ is the mean characteristic parameter of the class-c training samples.
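Eq. (5) is a nearest-class-mean rule on a single scalar feature. A minimal sketch follows; the function names and the numbers in the usage example are hypothetical and only illustrate the mechanics.

```python
import numpy as np


def class_means(train_ip, train_labels):
    """Return the sorted class labels and IP_c, the mean training IP of each class."""
    train_ip = np.asarray(train_ip, dtype=float)
    train_labels = np.asarray(train_labels)
    classes = np.unique(train_labels)
    means = np.array([train_ip[train_labels == c].mean() for c in classes])
    return classes, means


def classify(test_ip, classes, means):
    """Eq. (5): assign each test sample to argmin_c |IP_x - IP_c|."""
    test_ip = np.asarray(test_ip, dtype=float)
    dist = np.abs(test_ip[:, None] - means[None, :])  # shape (n_test, n_classes)
    return classes[np.argmin(dist, axis=1)]


if __name__ == "__main__":
    # Hypothetical IP values for classes 1-4, two training samples each.
    classes, means = class_means([1.1, 1.2, 0.8, 0.75, 2.4, 2.5, 3.8, 3.7],
                                 [1, 1, 2, 2, 3, 3, 4, 4])
    print(classify([1.0, 0.7, 2.3, 3.9], classes, means))
```

With these values the test samples are assigned to classes 1, 2, 3 and 4 in turn. Because the feature is one-dimensional, the decision boundaries are simply the midpoints between neighbouring class means.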

2.3 Main Steps of Information Potential Based Bearing Fault Classification
The main steps of information potential based bearing fault classification are as follows:
1) Sample the vibration signals of rolling bearings in each of the four states N times at a fixed sampling frequency to obtain 4N data samples.
2) Calculate the information potential of each data sample according to Eq. (4).
3) Randomly select part of the samples of each state as training samples (half of them in Sect. 3), and record the mean information potential of the training samples of state c as $IP_c$.
4) Classify each test sample according to Eq. (5) to determine the state of the rolling bearing and thereby recognize the fault type.
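The four steps can be strung together into a small end-to-end sketch. Real CWRU recordings are not bundled here, so each bearing state is replaced by a hypothetical stand-in: Gaussian noise with a state-specific amplitude (real fault signals are impulsive rather than Gaussian). The kernel width, segment length and amplitude scales are assumptions; only the 40-per-state, 20/20 split follows the paper.

```python
import numpy as np

SIGMA = 0.1       # hypothetical kernel width (not reported in the paper)
SEG_LEN = 512     # shorter than the paper's 2048 points to keep the demo fast
N_PER_CLASS = 40  # 40 samples per state, split 20/20 as in Sect. 3
STATES = ["inner", "outer", "roller", "normal"]
# Hypothetical stand-ins: Gaussian noise with a different amplitude per state.
SCALES = {"inner": 1.0, "outer": 2.0, "roller": 0.5, "normal": 0.2}


def information_potential(x, sigma=SIGMA):
    """Eq. (4) with the Gaussian kernel of Eq. (2)."""
    diff = x[:, None] - x[None, :]
    k = np.exp(-diff**2 / (2 * sigma**2)) / (np.sqrt(2 * np.pi) * sigma)
    return float(k.mean())


def run_pipeline(seed=0):
    rng = np.random.default_rng(seed)
    train_ip, train_y, test_ip, test_y = [], [], [], []
    for c, state in enumerate(STATES):
        # Step 1: draw N_PER_CLASS samples of this state.
        segs = rng.normal(0.0, SCALES[state], size=(N_PER_CLASS, SEG_LEN))
        # Step 2: one information potential per sample.
        ips = np.array([information_potential(s) for s in segs])
        # Step 3: random half for training, the rest for testing.
        idx = rng.permutation(N_PER_CLASS)
        half = N_PER_CLASS // 2
        train_ip.extend(ips[idx[:half]])
        train_y += [c] * half
        test_ip.extend(ips[idx[half:]])
        test_y += [c] * half
    train_ip, train_y = np.array(train_ip), np.array(train_y)
    test_ip, test_y = np.array(test_ip), np.array(test_y)
    means = np.array([train_ip[train_y == c].mean() for c in range(len(STATES))])
    # Step 4: nearest class mean, Eq. (5).
    pred = np.argmin(np.abs(test_ip[:, None] - means[None, :]), axis=1)
    return means, float(np.mean(pred == test_y))


if __name__ == "__main__":
    means, acc = run_pipeline()
    print(dict(zip(STATES, means.round(3))), "test accuracy:", acc)
```

With these synthetic stand-ins the class means come out ordered normal > roller > inner > outer, the same ordering of mean IP against amplitude that Sect. 3 reports for the measured signals.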

3 Experimental Verification of Bearing Fault Detection
All data used in this paper come from the rolling bearing fault simulation test rig of the Case Western Reserve University (CWRU) Electrical Engineering Laboratory in the United States [9], which consists of a 2-hp motor, a torque sensor, and a dynamometer. The bearings under test are located at both ends of the motor: the drive-end bearing is an SKF 6205 and the fan-end bearing is an SKF 6203. Single-point faults were introduced by electrical discharge machining, with fault diameters of 0.1778 mm, 0.3556 mm, and 0.5334 mm. Outer ring faults are located at the 3 o’clock, 6 o’clock, and 12 o’clock positions. Vibration data are collected by acceleration sensors mounted on the motor housing at a sampling frequency of 12 kHz, and power and rotational speed are measured by a torque transducer/encoder.
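The three fault diameters above are the metric equivalents of the CWRU specification, which lists seeded fault sizes in inches (0.007, 0.014 and 0.021 in). The short check below makes the conversion explicit.

```python
INCH_TO_MM = 25.4  # exact by definition

fault_diameters_in = [0.007, 0.014, 0.021]  # CWRU seeded-fault sizes in inches
fault_diameters_mm = [round(d * INCH_TO_MM, 4) for d in fault_diameters_in]
print(fault_diameters_mm)  # [0.1778, 0.3556, 0.5334]
```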


Each sample contains 2048 points. The vibration signals for inner ring fault, outer ring fault, rolling element fault, and normal condition are shown in Fig. 1. For each of the four states (inner fault, outer fault, roller fault, and normal condition), 40 vibration signal samples were drawn at random; 20 were used for training and the other 20 for testing, giving 80 training samples and 80 test samples in total. As Fig. 1 shows, the outer ring fault signal has the strongest impacts, with large amplitude variations in the time domain; the inner ring fault signal ranks second, the rolling element fault signal third, and the normal condition signal has the weakest impacts.
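At 12 kHz a 2048-point sample spans about 0.17 s, consistent with the time axis of Fig. 1. The paper does not describe how the random samples were cut from the recordings, so the sketch below assumes windows at uniformly random start positions (windows may overlap) followed by a random 20/20 split; the function names and the stand-in record are hypothetical.

```python
import numpy as np

FS = 12_000     # sampling frequency of the CWRU data (Hz)
SEG_LEN = 2048  # samples per segment, as in the paper


def random_segments(signal, n_segments: int, seg_len: int = SEG_LEN, seed: int = 0):
    """Cut n_segments windows of seg_len samples at random start positions."""
    signal = np.asarray(signal, dtype=float).ravel()
    if len(signal) < seg_len:
        raise ValueError("signal is shorter than one segment")
    rng = np.random.default_rng(seed)
    # integers() excludes the upper bound, so the last valid start is len - seg_len.
    starts = rng.integers(0, len(signal) - seg_len + 1, size=n_segments)
    return np.stack([signal[s:s + seg_len] for s in starts])


def train_test_split(segments, n_train: int, seed: int = 0):
    """Randomly assign n_train segments to training and the rest to testing."""
    idx = np.random.default_rng(seed).permutation(len(segments))
    return segments[idx[:n_train]], segments[idx[n_train:]]


if __name__ == "__main__":
    print(f"segment duration: {SEG_LEN / FS:.4f} s")  # 0.1707 s
    record = np.random.default_rng(1).standard_normal(120_000)  # stand-in record
    train, test = train_test_split(random_segments(record, 40), n_train=20)
    print(train.shape, test.shape)  # (20, 2048) (20, 2048)
```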

[Fig. 1 panels, top to bottom: inner fault, outer fault, roller fault, normal; acceleration (m/s²) versus time (s), 0–0.16 s]

Fig. 1. Rolling bearing vibration signal

Figure 2 shows the information potential of the training samples: samples 1–20 are inner ring fault signals, 21–40 outer ring fault signals, 41–60 rolling element fault signals, and 61–80 normal condition signals. As Fig. 2 shows, the information potential differs between the fault states of the rolling bearing, while the information potentials of vibration signals of the same fault type are relatively close.


[Fig. 2: boxplot of information potential (axis 0–4) for each bearing fault type: inner fault, outer fault, roller fault, normal]

Fig. 2. Information potential of rolling bearing vibration signal boxplot

[Fig. 3: bar chart of information potential by bearing fault type (inner fault, outer fault, roller fault, normal); the values 3.804, 2.4168, 1.1688, and 0.7871 are shown on the bars]

Fig. 3. The mean and standard deviation of Information potential

Figure 3 shows the mean and standard deviation of the information potential of the training samples. Both differ between the fault states of the rolling bearing. The mean information potential is largest for the normal condition bearing and smallest for the outer ring fault. Hence the information potential of a bearing vibration signal directly reflects the impact (shock) content of the signal amplitude, and its magnitude can be used to detect the impact components of rolling bearing fault signals, which helps distinguish different fault types.
First, the four mean values IP1, IP2, IP3, and IP4 are obtained from the training samples using steps (2) and (3) of the classification procedure. They serve as the standard characteristic parameters for inner ring fault, outer ring fault, rolling element fault, and normal condition, respectively. Second, the characteristic parameter IPx of each of the 80 test samples is extracted as in step (2): samples 1–20 are inner ring fault test samples, 21–40 outer ring fault test samples, 41–60 rolling element fault test samples, and 61–80 normal condition test samples. Finally, as in step (4), the Euclidean distance between IPx of each test sample and each standard characteristic parameter IPc is calculated, and the fault type with the minimum distance is taken as the predicted type. As Fig. 4 shows, the four groups of test samples are clearly separated, and the training state with the minimum distance gives the recognized state of each test sample. In Fig. 4, the distance |IPx − IP1| is smallest for test samples 1–20, so these are identified as inner ring faults. Similarly, test samples 21–40 are identified as outer ring faults (Fig. 5), samples 41–60 as rolling element faults (Fig. 6), and samples 61–80 as normal condition (Fig. 7).
These results agree with the true fault types of the test samples, so the rolling bearing fault types are identified successfully.
[Fig. 4: absolute value (0–4) versus number of samples (0–80); legend: inner fault, outer fault, roller fault, normal]

Fig. 4. Euclidean distance between test sample and inner ring fault training sample


[Fig. 5: absolute value (0–4) versus number of samples (0–80); legend: inner fault, outer fault, roller fault, normal]
Fig. 5. Euclidean distance between test sample and outer ring fault training sample
[Fig. 6: absolute value (0–4) versus number of samples (0–80); legend: inner fault, outer fault, roller fault, normal]

Fig. 6. Euclidean distance between test sample and roller fault training sample

Figure 8 shows the confusion matrix of the test samples. Its horizontal axis is the actual class and its vertical axis the predicted class; diagonal entries count correctly identified samples and off-diagonal entries count misidentified ones. The test set contains 20 samples of each fault type, and the classification accuracy for all four bearing states is 100%. This shows that information potential is an effective indicator for distinguishing rolling bearing fault types. The results also indicate that information potential can measure the impact characteristics of bearing fault vibration signals, reflect the internal structural characteristics of data from different fault types, facilitate the extraction of fault features, and improve the accuracy of fault diagnosis.
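The confusion matrix in Fig. 8 can be tabulated as below, using the same layout (rows are predicted classes, columns are actual classes). The sketch reproduces the all-correct outcome with 20 test samples per class; the function names are hypothetical.

```python
import numpy as np

LABELS = ["IF", "OF", "RF", "NC"]  # inner fault, outer fault, roller fault, normal condition


def confusion_matrix(actual, predicted, n_classes: int = len(LABELS)) -> np.ndarray:
    """Rows = predicted class, columns = actual class, as in Fig. 8."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (np.asarray(predicted), np.asarray(actual)), 1)
    return cm


def per_class_accuracy(cm: np.ndarray) -> np.ndarray:
    """Correct predictions divided by the number of test samples of each actual class."""
    return np.diag(cm) / cm.sum(axis=0)


if __name__ == "__main__":
    actual = np.repeat(np.arange(4), 20)  # 20 test samples per class
    predicted = actual.copy()             # the all-correct outcome reported in Fig. 8
    cm = confusion_matrix(actual, predicted)
    print(cm)
    print(dict(zip(LABELS, per_class_accuracy(cm))))
    print("overall accuracy:", np.trace(cm) / cm.sum())
```

A single misclassified inner-fault sample would move one count off the diagonal and lower that class's accuracy to 19/20 = 95%.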

[Fig. 7: absolute value (0–4) versus number of samples (0–80); legend: inner fault, outer fault, roller fault, normal]

Fig. 7. Euclidean distance between test sample and normal condition training sample

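[Fig. 8: confusion matrix, predicted class (rows) versus actual class (columns); color scale 0–20]

Predicted \ Actual | IF | OF | RF | NC
IF                 | 20 |  0 |  0 |  0
OF                 |  0 | 20 |  0 |  0
RF                 |  0 |  0 | 20 |  0
NC                 |  0 |  0 |  0 | 20

IF = inner fault, OF = outer fault, RF = roller fault, NC = normal condition.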

Fig. 8. Confusion matrix of rolling bearing fault


4 Conclusions
The information potential differs between the fault states of rolling bearings, while vibration signals of the same fault type have similar information potentials. The information potential of a bearing vibration signal therefore directly reflects the internal structure of the vibration amplitude, that is, how impulsive or smooth the signal is. As a result, information potential facilitates the extraction of features for different fault types and improves the accuracy of fault diagnosis.
Acknowledgement. This research is a part of the research that is sponsored by the Science and Technology Planning Project of Tianjin (Grant No. 22YDTPJC00740).

References
1. Rai, A., Upadhyay, S.H.: A review on signal processing techniques utilized in the fault diagnosis of rolling element bearings. Tribol. Int. 96, 289–306 (2016)
2. Abboud, D., Elbadaoui, M., Smith, W.A., et al.: Advanced bearing diagnostics: a comparative study of two powerful approaches. Mech. Syst. Signal Process. 114, 604–627 (2019)
3. Zhao, Z., Li, T., Wu, J., et al.: Deep learning algorithms for rotating machinery intelligent diagnosis: an open source benchmark study. ISA Trans. 107, 224–255 (2020)
4. Zhao, Z., Zhang, Q., Yu, X., et al.: Applications of unsupervised deep transfer learning to intelligent fault diagnosis: a survey and comparative study. IEEE Trans. Instrum. Meas. 70, 3525828 (2021)
5. Pang, B., Tang, G., Tian, T.: Rolling bearing fault diagnosis based on SVDP-based Kurtogram and iterative autocorrelation of teager energy operator. IEEE Access 7, 77222–77237 (2019)
6. Santamaria, I., Pokharel, P.P., Principe, J.C.: Generalized correlation function: definition, properties, and application to blind equalization. IEEE Trans. Signal Process. 54(6), 2187–2197 (2006)
7. Liu, W., Pokharel, P.P., Principe, J.C.: Correntropy: properties and applications in non-Gaussian signal processing. IEEE Trans. Signal Process. 55(11), 5286–5298 (2007)
8. Li, H., Hao, R.: Rolling bearing fault diagnosis based on sensor information fusion and generalized cyclic cross correntropy spectrum density. J. Vibrat. Shock 41(2), 200–207 (2022)
9. Smith, W.A., Randall, R.B.: Rolling element bearing diagnostics using the Case Western Reserve University data: a benchmark study. Mech. Syst. Signal Process. 64–65, 100–131 (2015)

Electrocardiogram Signal Noise Reduction Application Employing Different Adaptive Filtering Algorithms
Amine Essa1, Abdullah Zaidan1, Suhaib Ziad1, Mohamed Elmeligy2, Sam Ansari1, Haya Alaskar3, Soliman Mahmoud1, Ayad Turky4, Wasiq Khan5, Dhiya Al-Jumeily OBE6, and Abir Hussain1,5 (✉)
1 Department of Electrical Engineering, University of Sharjah, Sharjah, United Arab Emirates
{aessa,azaidan,ssalah,samansari,solimanm,abir.hussain}@sharjah.ac.ae
2 Department of Electrical Engineering, American University of Sharjah, Sharjah, United Arab Emirates
3 Department of Computer Science, College of Computer Engineering and Sciences, Prince Sattam Bin Abdulaziz University, Alkharj 11942, Saudi Arabia
4 Department of Computer Science, University of Sharjah, Sharjah, United Arab Emirates
5 School of Computer Science and Mathematics, Liverpool John Moores University, Liverpool L3 3AF, UK
6 School of Computer Science and Mathematics, Liverpool John Moores University, Liverpool L3 5UX, UK

Abstract. Almost all real-world signals are corrupted by noise to some degree. Audio signals, for example, require efficient noise cancellation in most hearing devices for user comfort. Various filtering techniques are employed for noise cancellation, enabling the system to improve the signal-to-noise ratio, and adaptive filters are currently preferred over other filter types for their higher efficiency. This study presents and examines four adaptive filter algorithms: least-mean-square, normalized least-mean-square, recursive-least-square, and the Wiener filter. The selected models are simulated, benchmarked, and compared on several performance characteristics, and they are applied in four different experiments/environments to further examine their behavior. All experiments use different step sizes to monitor two competing result parameters: performance and execution time. Finally, the best adaptive filter, with optimal parameters and step size, is identified for electrocardiogram signals, helping physicians and health professionals process electrocardiogram signals efficiently and diagnose signs of heart problems accurately and quickly. Simulation results further demonstrate the superiority of the presented models.
Keywords: Electrocardiogram signal · Least mean square · Noise cancellation · Normalized least mean square · Recursive least square · Wiener filter
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 330–342, 2023. https://doi.org/10.1007/978-981-99-4742-3_27


1 Introduction
A central task of signal processing is detecting and removing noise and interference from input signals. This removal process, called filtering, is an essential part of signal processing [12]. Adaptive systems, whose parameters are adjusted according to current conditions, can be used for this purpose [3]; they filter the noise from the input to recover the original signal. Many unpredictable types of noise can interfere with speech signals, such as quantization noise arising from the coding and processing of transmitted signals, noise from the transmission path itself, or background noise added to the signal. Adaptive learning continuously adjusts the parameters with an algorithm to achieve the best results [10], combining the input information with the already defined parameters to shape the output. To minimize the error signal e, the filter's impulse-response (coefficient) vector is updated continuously. The general adaptive linear filter is shown in Fig. 1.

Fig. 1. General block diagram of adaptive linear filter.

The filter coefficients are chosen to optimize (minimize) an error function, which can be done with either stochastic or deterministic approaches. The least-mean-square (LMS) and normalized least-mean-square (NLMS) algorithms use a stochastic approach [9]. In contrast, the recursive-least-square (RLS) algorithm computes its solution over a large number of past samples, so a deterministic approach is employed [7]. The paper is organized as follows. Section 2 describes the noise cancellation methodology and the theory behind each selected algorithm. Section 3 presents the implementation results and discussion. Finally, Sect. 4 gives the conclusion and future work.
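To illustrate the deterministic approach mentioned above, the following is a sketch of the standard exponentially weighted RLS recursion as found in adaptive filtering textbooks. It is not taken from this paper, whose own RLS formulation is not part of this excerpt; the forgetting factor, the initialisation constant and the test system are assumptions.

```python
import numpy as np


def rls_filter(x, d, order: int = 4, lam: float = 0.99, delta: float = 0.01):
    """Exponentially weighted RLS adaptive FIR filter (standard textbook form).

    lam is the forgetting factor; P (inverse correlation estimate) starts at I / delta.
    Returns the output y, the a-priori error e and the final weights w.
    """
    x = np.asarray(x, dtype=float)
    d = np.asarray(d, dtype=float)
    w = np.zeros(order)
    P = np.eye(order) / delta
    y = np.zeros(len(x))
    e = np.zeros(len(x))
    for n in range(len(x)):
        # Tap-input vector [x(n), x(n-1), ..., x(n-order+1)], zeros before n = 0.
        u = np.array([x[n - k] if n >= k else 0.0 for k in range(order)])
        pu = P @ u
        g = pu / (lam + u @ pu)          # gain vector
        y[n] = w @ u
        e[n] = d[n] - y[n]               # a-priori error
        w = w + g * e[n]
        P = (P - np.outer(g, pu)) / lam  # update of the inverse correlation matrix
    return y, e, w


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal(2000)
    h = np.array([0.5, -0.3, 0.2, 0.1])  # hypothetical unknown system
    d = np.convolve(x, h)[: len(x)]
    _, _, w = rls_filter(x, d)
    print(np.round(w, 4))                # close to h
```

Unlike LMS, each RLS step solves the weighted least-squares problem over all past samples exactly, at a cost of O(order²) operations per sample instead of O(order).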


2 Noise Cancellation
In any communication system, or even inside electronic circuits, the received signal is not always the clean transmitted signal the user wants; some unwanted noise is always added to it. Filters can be used to remove or reduce the destructive effect of noise on the original signal. For instance, noise-cancelling headphones pass the desired sound while rejecting noise that would degrade it. Modern headphones use active noise cancellation (ANC), which lets the desired sound through while cancelling the noise with a generated, inverted copy of it: a 180° out-of-phase version of the noise is added to the noisy signal [5]. Ideally, the noise and its inverted copy cancel each other completely. This paper uses additive white Gaussian noise (AWGN) to corrupt the original message, mimicking the effect of random processes that occur in nature. MATLAB provides built-in functions for generating AWGN, which are used in the simulations [8].
2.1 Least-Mean-Square (LMS)
LMS is one of the most widely used adaptive algorithms today, owing to its simple mathematical form and the wide range of applications it supports. However, the LMS technique is not well suited to dynamic noise. The algorithm adjusts the coefficients from sample to sample to minimize the mean squared error (MSE) of the output error signal [4]. The output signal is

$$y(n) = \mathbf{w}^T(n)\,\mathbf{x}(n) \tag{1}$$

where x(n) represents the input vector and w(n) denotes the filter coefficient vector. The error is computed as:

$$e(n) = d(n) - y(n) \qquad (2)$$

where d(n) is the desired signal and y(n) is the output of the filter. The coefficient update (recursion) of the filter is:

$$\mathbf{w}(n+1) = \mathbf{w}(n) + 2\mu\, e(n)\, \mathbf{x}(n) \qquad (3)$$

where μ denotes the step size of the adaptive filter. The step size strongly affects the convergence rate [5].
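The LMS recursion in Eqs. (1)–(3) can be sketched in a few lines of Python. This is a minimal illustration, not the authors' MATLAB implementation: the function name `lms_filter`, the 8-tap filter length, the step size μ = 0.01, and the toy adaptive-noise-cancellation setup (a sine wave standing in for the clean signal, white Gaussian reference noise, and an assumed three-tap noise path) are all assumptions chosen for the example.

```python
import numpy as np

def lms_filter(x, d, n_taps=8, mu=0.01):
    """Adaptive FIR filter trained with the LMS recursion of Eqs. (1)-(3).

    x: reference input signal, d: desired (primary) signal, mu: step size.
    Returns the filter output y, the error e and the final coefficients w.
    """
    w = np.zeros(n_taps)                                 # w(0) = 0
    y = np.zeros(len(x))
    e = np.zeros(len(x))
    x_pad = np.concatenate([np.zeros(n_taps - 1), x])    # zeros before the first sample
    for n in range(len(x)):
        # Tap-input vector [x(n), x(n-1), ..., x(n-n_taps+1)]
        x_vec = x_pad[n:n + n_taps][::-1]
        y[n] = w @ x_vec                                 # Eq. (1): y(n) = w^T(n) x(n)
        e[n] = d[n] - y[n]                               # Eq. (2): e(n) = d(n) - y(n)
        w = w + 2 * mu * e[n] * x_vec                    # Eq. (3): w(n+1) = w(n) + 2*mu*e(n)*x(n)
    return y, e, w

# Toy noise-cancellation demo (assumed setup, not the paper's ECG data)
rng = np.random.default_rng(0)
t = np.arange(4000)
s = np.sin(2 * np.pi * 0.01 * t)                  # clean signal
v0 = rng.normal(size=t.size)                      # white Gaussian reference noise
v1 = np.convolve(v0, [0.8, -0.4, 0.2])[:t.size]   # noise as it reaches the primary input
d = s + v1                                        # noisy primary signal
y, e, w = lms_filter(v0, d)                       # e(n) is the estimate of the clean signal
```

Because the reference noise is uncorrelated with the clean signal, minimizing the MSE of e(n) removes only the noise component, so e(n) converges toward the clean sine wave and the first taps of w approach the assumed noise path.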

Electrocardiogram Signal Noise Reduction Application Employing


A fixed step size for every iteration is a disadvantage, because choosing it requires prior knowledge of the input signal before the filter is applied, which is rarely available in practice. The condition for convergence is: 0 0, ∀k: ‖w_k‖ = 1. Parameter λ is the weight of the prior sparse density of H and is calculated as λ = 0.14σ, where σ is the mean square error of an image; α and β are the sparsity control parameters, calculated according to the sparseness measure defined as follows:

$$\text{sparseness}(X) = \frac{\sqrt{n} - \left(\sum_i |x_i|\right) \Big/ \sqrt{\sum_i x_i^{2}}}{\sqrt{n} - 1} \qquad (4)$$


L. Shang et al.

where X is a non-negative vector and n is the dimension of X. The larger α and β are, the sparser the matrices W and H are. When α is small, each column of the feature basis matrix W reflects global features of an image; conversely, each column of H reflects local features of an image. The larger β is, the sparser the feature coefficients, but the less information is retained. Therefore, α and β govern both the sparsity of the feature vectors and the retention of feature information, and they strongly affect recognition precision. In our tests, their best values had to be determined over multiple experiments. In the second term of Eq. (3), the sparse penalty function f(·) equals the negative logarithmic density of h_j, namely f(h_j) = −log p(h_j). The density model was selected as follows:

$$p(h_j) = \frac{1}{2b}\,\frac{(d+2)\left[0.5d(d+1)\right]^{(0.5d+1)}}{\left[\sqrt{0.5d(d+1)} + \left|h_j/b\right|\right]^{(d+3)}} \qquad (5)$$

where b, d > 0; b is the scale adjusting parameter and d is the sparseness adjusting parameter. They are selected according to the following equations:

$$b = \sqrt{E\{h_j^{2}\}}, \qquad d = \frac{2 - k + \sqrt{k(k+4)}}{2k - 1} \qquad (6)$$

where k = b² f(0)² = b² (−log p(0))². To illustrate the difference in sparseness, Fig. 1 shows densities with different degrees of sparseness: the dotted line denotes the Gaussian density, the solid line denotes the Laplace density, and the dotted line marked "*" denotes the strongly sparse density of Eq. (5).

[Fig. 1 plot: legend Gaussian density, Laplace density, strongly sparse density; horizontal axis from −3 to 3, vertical axis from 0 to 0.8]

Fig. 1. The sparse distribution of different sparse density models. The dotted line is the Gaussian density, the solid line is the Laplace density, and the dotted line marked "*" is the sparse density defined in Eq. (5).
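As a numerical check on Eqs. (4) and (5), the sketch below evaluates the sparseness measure and the strongly sparse density. The function names and the default parameters b = 1 and d = 1 are illustrative assumptions; the density formula is transcribed directly from Eq. (5).

```python
import numpy as np

def sparseness(x):
    """Sparseness measure of Eq. (4): 1 for a single non-zero entry, 0 for a constant vector."""
    x = np.asarray(x, dtype=float)
    n = x.size
    l1 = np.abs(x).sum()
    l2 = np.sqrt((x ** 2).sum())
    return (np.sqrt(n) - l1 / l2) / (np.sqrt(n) - 1)

def sparse_density(h, b=1.0, d=1.0):
    """Strongly sparse density p(h) of Eq. (5) with scale b > 0 and sparseness parameter d > 0."""
    a = 0.5 * d * (d + 1)
    return (d + 2) * a ** (0.5 * d + 1) / (2 * b * (np.sqrt(a) + np.abs(h / b)) ** (d + 3))

print(sparseness([0, 0, 3, 0]))   # 1.0 (maximally sparse)
print(sparseness([2, 2, 2, 2]))   # 0.0 (not sparse at all)
```

Integrating `sparse_density` numerically over a wide interval gives a value close to 1, which is consistent with Eq. (5) being a normalized density.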

Palmprint Recognition Utilizing Modified LNMF Method


3.2 The Updating Rules

The cost function (see Eq. (3)) can be minimized by updating W and H alternately. First, W and H are initialized randomly; then, with the feature basis matrix W fixed, the coefficient matrix H is updated, and in the same way, with H fixed, W is updated. Both matrices are updated by gradient descent, with the following learning rules:

$$\begin{cases} (H_{\text{new}})_{ij} = (H_{\text{old}})_{ij} - \mu_H \left[ (W^{T})_{ij}\left( (W H_{\text{old}})_{ij} - V_{ij} \right) + \lambda f'(h_j) - \beta h_j \right] \\ (W_{\text{new}})_{ij} = (W_{\text{old}})_{ij} - \mu_W \left[ (W_{\text{old}} H)_{ij} - V_{ij} + \alpha w_i \right] \\ \lVert W_{\text{new}} \rVert_2 = 1 \end{cases} \qquad (7)$$

where μ_H and μ_W are the step sizes for updating H and W, respectively, and f'(h_j) is the first-order derivative of the sparse penalty function f(h_j) = −log p(h_j), with p(h_j) defined in Eq. (5). During the iterations, H and W must remain non-negative. By adjusting the step sizes μ_H and μ_W, the objective function value can be driven toward a minimum.
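The alternating scheme above (update H with W fixed, then W with H fixed, keep both non-negative, and normalize W) can be sketched as projected gradient descent. This is a simplified stand-in, not a transcription of Eq. (7): the cost here is assumed to be ½‖V − WH‖²_F plus an L1 penalty of weight `lam` on H, the strongly sparse penalty of Eq. (5) and the α and β terms are omitted, and the step sizes are set to the reciprocal of the largest eigenvalue of WᵀW or HHᵀ instead of hand-tuned μ_H and μ_W. The function name is hypothetical.

```python
import numpy as np

def sparse_nmf_pgd(V, r, lam=0.01, n_iter=200, seed=0):
    """Alternating projected-gradient NMF with an L1 penalty on H (simplified stand-in for Eq. (7)).

    Assumed cost: 0.5 * ||V - W H||_F^2 + lam * sum(H), with W >= 0 and H >= 0.
    """
    rng = np.random.default_rng(seed)
    m, n = V.shape
    W = rng.random((m, r))                    # random non-negative initial values
    H = rng.random((r, n))
    for _ in range(n_iter):
        # H-step (W fixed): gradient W^T (W H - V) + lam; step = 1 / largest eigenvalue of W^T W
        step_h = 1.0 / (np.linalg.norm(W.T @ W, 2) + 1e-12)
        H = np.maximum(H - step_h * (W.T @ (W @ H - V) + lam), 0.0)
        # W-step (H fixed): gradient (W H - V) H^T; step = 1 / largest eigenvalue of H H^T
        step_w = 1.0 / (np.linalg.norm(H @ H.T, 2) + 1e-12)
        W = np.maximum(W - step_w * ((W @ H - V) @ H.T), 0.0)
        # Unit-norm columns of W; rescale rows of H so that W H is unchanged
        norms = np.linalg.norm(W, axis=0) + 1e-12
        W /= norms
        H *= norms[:, None]
    return W, H

rng = np.random.default_rng(1)
V = rng.random((64, 30))                      # toy non-negative data matrix
W, H = sparse_nmf_pgd(V, r=8)
rel_err = np.linalg.norm(V - W @ H) / np.linalg.norm(V)
```

Clipping negative entries to zero after each gradient step is one simple way to keep H and W non-negative; rescaling the rows of H when the columns of W are normalized leaves the product WH unchanged.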

4 RBPNN Classifier

The radial basis probabilistic neural network (RBPNN) model [5] is shown in Fig. 2. It consists of four layers. The first hidden layer is a nonlinear processing layer, generally consisting of centers selected from the training samples. The second hidden layer selectively sums the outputs of the first hidden layer according to the categories to which the hidden centers belong; that is, the connection weights between the first and second hidden layers are 1s or 0s. For pattern recognition problems, the outputs of the second hidden layer need to be normalized. The last layer of the RBPNN is the output layer.

[Fig. 2 diagram: input nodes X1, X2, …, Xm, …, XN; output nodes y1, …, yk, …, yM]

Fig. 2. The structure of RBPNN model.

For an input vector x, the actual output of the ith output neuron of the RBPNN, y_i^a, can be expressed as:

$$y_i^{a} = \sum_{k=1}^{M} w_{ik}\, h_k(\mathbf{x}) \qquad (8)$$

$$h_k(\mathbf{x}) = \sum_{i=1}^{n_k} \phi_i\left(\lVert \mathbf{x} - \mathbf{c}_{ki} \rVert_2\right), \qquad k = 1, 2, \ldots, M \qquad (9)$$

where h_k(x) is the kth output of the second hidden layer of the RBPNN, w_ik is the synaptic weight between the kth neuron of the second hidden layer and the ith neuron of the output layer, c_ki is the ith hidden center vector of the kth pattern class in the first hidden layer, n_k is the number of hidden center vectors for the kth pattern class, ‖·‖_2 is the Euclidean norm, and M denotes the number of neurons in the output layer and the second hidden layer, i.e., the number of pattern classes in the training set; φ_i(·) is the kernel function, generally a Gaussian kernel:

$$\phi_i\left(\lVert \mathbf{x} - \mathbf{c}_{ki} \rVert_2\right) = \exp\left(-\frac{\lVert \mathbf{x} - \mathbf{c}_{ki} \rVert_2^{2}}{\sigma_i^{2}}\right) \qquad (10)$$

where σ_i is the shape parameter of the Gaussian kernel. Training algorithms for the RBPNN include the orthogonal least squares algorithm (OLSA) and the recursive least squares algorithm (RLSA) [11, 12]. Both converge quickly and accurately. The RLSA, however, requires good initial conditions and is suited to problems with large training sets. Because the OLSA makes full use of matrix computation, such as the orthogonal decomposition of matrices, its training speed and convergence accuracy are higher than those of the RLSA. Therefore, the OLSA is used to train the RBPNN in this paper. For N training samples from M pattern classes, Eq. (8) can be written in matrix form as [10]:

$$Y^{a} = HW \qquad (11)$$

where Y^a and H are both N × M matrices and W is an M × M square matrix. According to [11], the synaptic weight matrix W between the output layer and the second hidden layer of the RBPNN can be solved as follows:

$$W = R^{-1}\hat{Y} \qquad (12)$$

where R is an M × M upper triangular matrix with the same rank as H, and Ŷ is an M × M matrix. Both are obtained as follows:

$$H = Q \begin{bmatrix} R \\ 0 \end{bmatrix}, \qquad Q^{T} Y = \begin{bmatrix} \hat{Y} \\ \tilde{Y} \end{bmatrix} \qquad (13)$$

where Q is an N × N orthogonal matrix satisfying QQ^T = Q^TQ = I, and Ỹ is an (N − M) × M matrix. Equation (13) expresses the orthogonal (QR) decomposition of the output matrix H of the second hidden layer of the RBPNN.
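The RBPNN forward pass of Eqs. (8)–(10) and the QR-based weight solution of Eqs. (11)–(13) can be sketched with NumPy. The assumptions here: every training sample is used as a hidden center, all kernels share one shape parameter σ, the targets Y are one-hot class labels, the output normalization of the second hidden layer is omitted, and the function names are hypothetical. `np.linalg.qr` returns the reduced decomposition (Q of size N × M), which yields the same R and Ŷ as the full N × N decomposition in Eq. (13).

```python
import numpy as np

def rbpnn_hidden(X, centers, center_labels, n_classes, sigma=1.0):
    """Second-hidden-layer outputs h_k(x) of Eqs. (9)-(10) for every row of X; shape (N, M)."""
    # Squared Euclidean distance between each input and each hidden center, shape (N, n_centers)
    d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
    phi = np.exp(-d2 / sigma ** 2)              # first hidden layer, Gaussian kernel of Eq. (10)
    S = np.eye(n_classes)[center_labels]        # 0/1 links: each kernel feeds its own class unit
    return phi @ S                              # class-wise sums, Eq. (9)

def rbpnn_fit_weights(Hmat, Y):
    """Least-squares solution of Y = H W (Eq. (11)) via QR: W = R^{-1} Y_hat (Eqs. (12)-(13))."""
    Q, R = np.linalg.qr(Hmat)                   # reduced QR: Q is N x M, R is M x M upper triangular
    Y_hat = Q.T @ Y                             # the first M rows of Q^T Y in Eq. (13)
    return np.linalg.solve(R, Y_hat)

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-2.0, 1.0, size=(20, 2)),   # class 0 samples
               rng.normal(2.0, 1.0, size=(20, 2))])   # class 1 samples
labels = np.array([0] * 20 + [1] * 20)
Y = np.eye(2)[labels]                                 # one-hot targets
H_train = rbpnn_hidden(X, X, labels, n_classes=2)     # training samples used as hidden centers
W = rbpnn_fit_weights(H_train, Y)
pred = np.argmax(H_train @ W, axis=1)                 # Eq. (8) in matrix form, then the largest output
```

Solving the triangular system instead of forming a matrix inverse is the usual way to apply Eq. (12) numerically; the result matches an ordinary least-squares fit of Y on H.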


5 Experimental Results and Analysis

5.1 Test Data Preprocessing

In our tests, the Hong Kong Polytechnic University (PolyU) palmprint database is used for palmprint recognition. The database contains 600 palm images of 128 × 128 pixels from 100 users, with 6 images per person. For each person, the first three images were used for training and the remaining three for testing. Treating pixels as variables and palm images as observations, the training set X (each column is an image) has size 128² × 300. For computational convenience, PCA is used to whiten the training data and reduce its dimension from 128² to a suitable dimension k. Let P_k denote the matrix whose columns are the first k principal component axes and whose rows correspond to pixels. The principal component coefficient matrix is then R_k = X^T P_k. With k = 16, the first 16 principal component axes of the image set (columns of P_k) are shown in Fig. 3. Thus, to reduce the input dimensionality, the feature extraction algorithm is applied not directly to the 128² image pixels but to the first k PCA coefficients of the palmprint images. These coefficients R_k^T form the columns of the input data matrix. The coefficients for the training images are computed as U_train = W ∗ R_k^T, where W is a k × k matrix, giving k coefficients in U_train for each palmprint image. The representation of the test images is given by the columns of U_test:

$$U_{\text{test}} = W \ast R_{\text{test}}^{T} = W \ast \left(X_{\text{test}}^{T} \ast P_k\right)^{T} \qquad (14)$$

where each column of the weight matrix W^{-1} found by the modified LNMF algorithm attempts to get close to a cluster of images that look similar across pixels.

Fig. 3. First 16 PCA bases of palmprint images.
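The PCA projection and Eq. (14) can be illustrated on toy data. In this sketch, the principal axes are taken as the leading left singular vectors of the mean-centered training matrix, the whitening step is omitted, the images are 8 × 8 random arrays instead of 128 × 128 palmprints, and W is a random stand-in for the matrix learned by the modified LNMF; all of these are assumptions for illustration, and `pca_axes` is a hypothetical helper.

```python
import numpy as np

def pca_axes(X, k):
    """First k principal component axes of X (pixels x images), returned as the columns of P_k."""
    Xc = X - X.mean(axis=1, keepdims=True)        # centre each pixel across the images
    U, _, _ = np.linalg.svd(Xc, full_matrices=False)
    return U[:, :k]                               # shape (n_pixels, k), orthonormal columns

rng = np.random.default_rng(0)
X_train = rng.random((64, 30))      # toy stand-in: 30 training images of 8 x 8 = 64 pixels
X_test = rng.random((64, 30))       # toy stand-in: 30 test images
k = 16
P_k = pca_axes(X_train, k)
R_k = X_train.T @ P_k               # R_k = X^T P_k, one row of k coefficients per image
W = rng.random((k, k))              # hypothetical k x k matrix (learned by the modified LNMF in the paper)
U_train = W @ R_k.T                 # k coefficients per training image, one column each
U_test = W @ (X_test.T @ P_k).T     # Eq. (14)
```

Because (X_testᵀ P_k)ᵀ = P_kᵀ X_test, Eq. (14) is simply a k × (pixels) linear map applied to each test image.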


5.2 Learning of Feature Bases

For the training set, L image patches of size p × p were randomly sampled from each 128 × 128 palmprint image, and each patch was converted into one column. Each image was thus converted into a matrix of size p² × L, so the training set X_train had size p² × 300L. In addition, X_train must be non-negative; if a negative value occurs in X_train, the iteration is terminated. Feature extraction algorithms based on NMF, LNMF, and the modified LNMF were compared for different feature dimensions. Owing to space limitations, only feature dimensions of 16, 36, 64, and 121 are considered. Figure 4 shows the feature basis images of different dimensions obtained by the modified LNMF algorithm; for comparison, the feature bases obtained by LNMF and NMF are shown in Fig. 5 and Fig. 6, respectively. For all three algorithms, the larger the feature dimension, the better the locality of the feature bases. Under the same feature dimension, the sparsity and locality of the modified LNMF feature bases are clearly better than those of NMF and LNMF. In addition, regardless of the algorithm, the larger the feature dimension, the slower the convergence. Considering both computation time and the validity of the extracted features, the maximum feature basis dimension was set to 64, based on the PCA test.

(a) 16 dimension

(b) 36 dimension

(c) 64 dimension

(d) 121 dimension

Fig. 4. The modified LNMF basis images with different feature dimensions.


(a) 16 dimension

(b) 36 dimension

(c) 64 dimension

(d) 121 dimension

Fig. 5. The LNMF basis images with different feature dimensions.

(a) 16 dimension

(b) 36 dimension

(c) 64 dimension

(d) 121 dimension

Fig. 6. The NMF basis images with different feature dimensions.
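The patch-based construction of X_train in Sect. 5.2 can be sketched as follows. The function name, the use of uniformly random patch positions, and raising an error (rather than terminating an iteration) when a negative value appears are assumptions; the demo uses 10 random images instead of the 300 PolyU training images.

```python
import numpy as np

def sample_patches(images, p, L, seed=0):
    """Randomly sample L patches of size p x p from each image, one flattened patch per column.

    images: array of shape (n_images, height, width); returns an array of shape (p*p, n_images*L).
    """
    rng = np.random.default_rng(seed)
    n_images, height, width = images.shape
    columns = []
    for img in images:
        for _ in range(L):
            r = rng.integers(0, height - p + 1)   # top-left row of the patch
            c = rng.integers(0, width - p + 1)    # top-left column of the patch
            columns.append(img[r:r + p, c:c + p].reshape(-1))
    X_train = np.stack(columns, axis=1)
    if X_train.min() < 0:
        raise ValueError("X_train must be non-negative for NMF-type algorithms")
    return X_train

palms = np.random.default_rng(1).random((10, 128, 128))   # toy stand-in for palmprint images
X_patches = sample_patches(palms, p=8, L=50)
```

With p = 8 and L = 50, each image contributes 50 columns of length 64, so 10 images give a 64 × 500 matrix; 300 images would give the p² × 300L matrix described above.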


5.3 Recognition Results

Following Subsect. 5.1, the feature coefficients U_train of the training set and U_test of the test set are obtained with three algorithms: the modified LNMF, LNMF, and NMF. Palmprint recognition is then performed with three classifiers: Euclidean distance, the probabilistic neural network (PNN), and the radial basis probabilistic neural network (RBPNN). First, to determine an appropriate feature length, we perform recognition with PCA using different numbers k of principal components and the three classifiers. The results are shown in Table 1. Table 1 shows that PCA with 64 principal components yields the best performance for the Euclidean distance and RBPNN classifiers (tied with 49 and 81 components, respectively); beyond this point the recognition rate changes little or drops, so more principal components are not necessarily better. In addition, when the number of principal components is 36 or fewer, the PNN classifier outperforms the Euclidean distance classifier; beyond that, the recognition performance of PNN drops below that of Euclidean distance. Considering also the computational complexity, a feature length of 64 was selected. The recognition rates of our modified LNMF are listed in Table 2, together with those of LNMF and NMF for 64 principal components. Under each classifier, the recognition rate based on our modified LNMF feature coefficients is clearly higher than those based on LNMF and NMF, and under each algorithm the RBPNN classifier outperforms the PNN and Euclidean distance classifiers. These results indicate that the palmprint recognition method based on the modified LNMF and the RBPNN classifier achieves a higher recognition rate than the compared methods on the PolyU database.

Table 1. Recognition results (%) of PCA using different classifiers.

Number of principal coefficients | Euclidean distance | PNN | RBPNN
25 | 70.33 | 74.36 | 77.67
36 | 88.52 | 89.51 | 90.75
49 | 91.33 | 91.00 | 92.52
64 | 91.33 | 84.33 | 93.33
81 | 89.67 | 88.13 | 93.33
121 | 86.56 | 84.33 | 91.33


Table 2. Recognition rates (%) of three classifiers with different algorithms (k = 64).

Recognition method | Euclidean distance | PNN | RBPNN
NMF | 87.65 | 89.53 | 91.67
LNMF | 91.72 | 92.97 | 94.56
Modified LNMF | 93.83 | 95.75 | 97.67

6 Conclusions

A novel palmprint recognition method based on the modified LNMF algorithm is discussed in this paper. In the space of principal coefficients, the modified LNMF algorithm successfully extracts palmprint features that exhibit clearer locality and sparseness than those of LNMF and NMF. Using these features, the recognition task is carried out with three classifiers. The simulation results show that our algorithm performs best under each classifier and that, under each algorithm, RBPNN achieves a higher recognition rate than the other two classifiers. In summary, the experimental results demonstrate the effectiveness of the proposed palmprint recognition method.

Acknowledgement. This work was supported by the grants of National Science Foundation of China (No. 61972002).

References

1. Cappelli, R., Ferrara, M., Maio, D.: A fast and accurate palmprint recognition system based on minutiae. IEEE Trans. Syst. Man Cybern. Part B Cybern. 42(3), 956–962 (2012)
2. Zhang, D., Zuo, W., Yue, F.: A comparative study of palmprint recognition algorithms. ACM Comput. Surv. 44(1), Article 2, 1–37 (2012)
3. Fei, L., Lu, G.M., Jia, W., Teng, S.H., Zhang, D.: Feature extraction methods for palmprint recognition: a survey and evaluation. IEEE Trans. Syst. Man Cybern. Syst. 49(2), 346–362 (2019)
4. Li, L., Zhang, Y.J.: A survey on algorithms of non-negative matrix factorization. Acta Electronica Sinica 36, 737–743 (2008)
5. Qin, X.L., Wu, Q.G., Han, Z.Y., Yu, J.: Palmprint recognition method based on multi-layer wavelet transform with different wavelet basis functions. J. Zhengzhou Inst. Light Ind. (Nat. Sci. Edn) 26(3), 29–32 (2011)
6. Draper, B.A., Baek, K., Bartlett, M.S.: Recognizing faces with PCA and ICA. Comput. Vis. Image Underst. 91(1–2), 115–137 (2003)
7. Feng, T., Li, S.Z., Shum, H.Y., Zhang, H.J.: Local non-negative matrix factorization as a visual representation. In: Proceedings of the 2nd International Conference on Development and Learning, Cambridge, MA, USA, pp. 178–183, 12–15 June 2002
8. Hosseini, B., Hammer, B.: Non-negative local sparse coding for subspace clustering. In: Duivesteijn, W., Siebes, A., Ukkonen, A. (eds.) Advances in Intelligent Data Analysis XVII. IDA 2018. LNCS, vol. 11191, pp. 137–150. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01768-2_12
9. Zhao, S., Zhang, B.: Learning salient and discriminative descriptor for palmprint feature extraction and identification. IEEE Trans. Neural Netw. Learn. Syst. 31(12), 5219–5230 (2020)
10. Fei, L., Zhang, B., Teng, S., Jia, W., Wen, J., Zhang, D.: Feature extraction for 3-D palmprint recognition: a survey. IEEE Trans. Instrum. Meas. 69(3), 645–654 (2020)
11. Huang, D.S.: Radial basis probabilistic neural networks: model and application. Int. J. Pattern Recognit. Artif. Intell. 13(7), 1083–1101 (1999)
12. Zhao, W.B., Huang, D.S.: Application of recursive orthogonal least squares algorithm to the structure optimization of radial basis probabilistic neural networks. In: Proceedings of the 6th International Conference on Signal Processing (ICSP 2002), Beijing, China, pp. 1211–1214, August 2002

PECA-Net: Pyramidal Attention Convolution Residual Network for Architectural Heritage Images Classification

Shijie Li, Yifei Yang, and Mingyang Zhong(B)

College of Artificial Intelligence, Southwest University, Chongqing, China

Abstract. Heritage buildings have essential cultural and economic value, but they are subject to constant change and destruction. Technologies that classify architectural heritage images can support the recording and protection of heritage buildings. However, the backgrounds of architectural heritage images are complex and variable, which makes it difficult to extract key feature information. To address this problem, this paper proposes a pyramidal attention convolution residual network (PECA-Net) based on ResNet18, which adopts pyramidal convolution (PyConv) and an attention mechanism. PyConv replaces standard convolution in the residual blocks and extracts detailed information at different scales without increasing the parameter space of the model. We also propose a dual-pooling channel attention mechanism (DP-ECA) that effectively improves the ability to extract key information. Compared with the original ResNet18, experimental results show that the accuracy of our model increases by 1.86%, while the parameters and FLOPs (floating point operations) are reduced by 0.48M and 0.28G, respectively.

Keywords: Heritage buildings · Residual network · Pyramidal convolution · Attention mechanism

1 Introduction

Heritage buildings have essential historical, artistic, cultural, and scientific value [1]. However, many social and natural factors have seriously affected their structure and number. Correct detection, classification, and segmentation of heritage buildings are important prerequisites for their recording and protection. Owing to the lack of professional knowledge and expert resources, it is difficult to identify a large number of heritage buildings manually. This paper therefore focuses on how to use state-of-the-art techniques to build an effective building classification method.

Most traditional classification methods are built on hand-crafted features and traditional machine learning methods [2–5]. Examples of hand-crafted features include the Scale Invariant Feature Transform (SIFT) [6], the Histogram of Oriented Gradients (HOG) [7], and the Deformable Part Model (DPM) [8]. However, hand-crafted features are designed specifically for given objects, which limits the applicability of the models. In recent years, deep models such as convolutional neural networks (CNNs) have shown a strong ability to learn features automatically and can quickly process large amounts of image data without hand-crafted features [9]. CNN-based methods have therefore been widely used for heritage building classification [10, 11]. However, the backgrounds of architectural heritage images are complex and variable, which leads to problems such as low recognition accuracy and loss of key information. To address these problems, we propose a pyramidal attention convolution residual network (PECA-Net) based on ResNet18 [12]. The main contributions of this paper are as follows:

• Pyramidal convolution (PyConv) is adopted to extract detailed information at different scales without increasing the parameter space.
• A dual-pooling attention mechanism module (DP-ECA) based on Efficient Channel Attention (ECA) is proposed to effectively extract key information.
• Comprehensive evaluations demonstrate the effectiveness and efficiency of the proposed method.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 389–400, 2023.
https://doi.org/10.1007/978-981-99-4742-3_32

S. Li et al.

2 Related Work

In this section, we review related work on heritage building classification and then briefly introduce the multi-scale convolution and attention mechanism algorithms related to this study.

2.1 Heritage Buildings Classification

Traditional heritage building classification techniques extract hand-crafted features from architectural images and then classify them with machine learning. The authors in [2] use a method based on clustering and learning of local features to classify the architectural style of building facade towers. The authors in [3] propose another method for the same problem, which extracts features with SIFT and then classifies them with an SVM. The authors in [4] propose a feature extraction module based on DPM image preprocessing to extract the morphological features of buildings and then apply an SVM to classify them. Similarly, the authors in [5] use DPM to capture the morphological characteristics of basic architectural components and propose Multinomial Latent Logistic Regression (MLLR), which can solve the multi-class problem in latent variable models. However, these methods rely heavily on hand-crafted features, which leads to poor generalization and accuracy. With the development of deep learning, CNNs have come to dominate image recognition. The authors in [10] use ResNet to classify images of heritage buildings for the first time. The authors in [11] put forward a deep convolutional neural network (DCNN) that combines sparse features with primary color pixel values at the network input to classify Mexican historical buildings.

Pyramidal Attention Convolution Residual Network


In addition, some researchers have developed automatic systems to classify heritage buildings. The authors in [1] present an automatic multi-category damage detection technique that uses CNN models based on image classification and feature extraction to detect damage to historic structures. The authors in [13] create the monumental heritage image dataset (MonuMAI) and design a CNN-based system that automatically recognizes the style of architectural elements. Compared with traditional machine learning methods, CNN-based methods offer better applicability and performance in heritage building classification tasks.

2.2 Multi-scale Convolution

Multi-scale convolution modules extract features with different receptive fields and enhance the model's feature extraction ability. The authors in [14] propose the Atrous Spatial Pyramid Pooling (ASPP) module, which uses multiple parallel atrous convolution layers with different sampling rates to obtain multi-scale object information. The authors in [15] propose a Scale Aggregation Block, which realizes multi-scale feature extraction based on downsampling and convolutions with different sampling rates. Recently, the authors in [16] propose Pyramidal Convolution (PyConv), which processes the input at multiple filter scales to extract information at different scales. Compared with standard convolution, PyConv does not add computational cost or parameters. Furthermore, PyConv is flexible and extensible, offering a large space of potential network architectures for various applications.

2.3 Attention Mechanism

Attention mechanisms have been widely used in deep learning, for example in object detection, speech recognition, and image classification [17]. The authors in [18] propose the Squeeze-and-Excitation (SE) module, which learns the importance of each channel and weights the channels accordingly.
Later, the authors in [19] propose the Convolutional Block Attention Module (CBAM), which introduces two analytical dimensions, channel attention and spatial attention, applied sequentially from channel to space. To reduce computational complexity, the authors in [20] propose the Efficient Channel Attention (ECA) module, which effectively captures cross-channel interaction while avoiding dimensionality reduction and can be regarded as an improved version of the SE module. Furthermore, ECA is simple and lightweight, and it can improve model performance with negligible additional computational overhead.

3 PECA-Net

In this section, we introduce PECA-Net in detail. Section 3.1 presents the overall framework of PECA-Net, and Sects. 3.2 and 3.3 detail its two key structures, the PyConv block and the DP-ECA module, respectively.


3.1 The Overall Framework of PECA-Net

The structure of PECA-Net is illustrated in Fig. 1. It is mainly composed of four pyramidal attention convolution (PECA) modules, each consisting of two PyConv blocks and a DP-ECA module. PyConv replaces the 3 × 3 convolution in the residual blocks, which allows detailed information to be extracted at different scales without increasing the parameter space of the model. The DP-ECA module is then placed after PyConv to improve the model's ability to extract key feature information over a large receptive field.

Fig. 1. The structure of PECA-Net.

3.2 Pyramidal Convolution

The original ResNet18 uses a single small convolution kernel size, so it cannot capture enough context and other useful detailed information. To enhance the model's ability to represent the feature maps, we introduce pyramidal convolution (PyConv) to improve multi-scale feature extraction. PyConv contains four levels of convolution kernels of different types, arranged as a pyramid [21]. The kernel size increases from the first level (the bottom of the pyramid) to the fourth level (the top), while the kernel depth decreases from the first level to the fourth. As illustrated in Fig. 2, the kernel sizes of the four levels are 3 × 3, 5 × 5, 7 × 7, and 9 × 9, respectively. Each level contains a different type of filter, with its own size and depth. To use a different kernel depth at each level, PyConv relies on grouped convolution, which splits the input feature maps into groups and processes each group independently with its own kernels [16]. Two examples illustrate grouped convolution. As shown in Fig. 3, each example has four input feature maps. Figure 3(a) shows standard convolution, which has a single group of input feature maps. In this case, the kernel depth equals the number of input feature maps, and each output feature map is connected to all input feature maps. Figure 3(b) shows grouped convolution with G = 2, in which the input feature maps are split into two groups. In this case, the kernels are applied independently to each group, and the kernel depth in each group becomes half the number of input feature maps. Kernel depth decreases as the number of groups increases, and the computational cost and number of parameters of the convolution decrease along with it.
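The parameter savings of grouped convolution follow from simple arithmetic: a bias-free k × k convolution with G groups has k² · (C_in / G) · C_out weights. The sketch below compares a standard 3 × 3 convolution with a four-level PyConv that splits 64 output channels evenly across its levels. The group numbers 1, 4, 8, and 16 for kernel sizes 3, 5, 7, and 9 follow the original PyConv design and are an assumption here, since this paper does not state them; the function names are hypothetical.

```python
def conv_params(kernel, c_in, c_out, groups=1):
    """Weights of a bias-free 2D convolution; each filter sees only c_in / groups input channels."""
    if c_in % groups or c_out % groups:
        raise ValueError("channel counts must be divisible by the number of groups")
    return kernel * kernel * (c_in // groups) * c_out

def pyconv_params(c_in, c_out, kernels=(3, 5, 7, 9), groups=(1, 4, 8, 16)):
    """Total weights of a 4-level PyConv; each level produces c_out / len(kernels) output maps."""
    per_level = c_out // len(kernels)
    return sum(conv_params(k, c_in, per_level, g) for k, g in zip(kernels, groups))

print(conv_params(3, 64, 64))      # standard 3 x 3 convolution: 36864
print(conv_params(3, 64, 64, 2))   # grouped convolution with G = 2: 18432
print(pyconv_params(64, 64))       # 4-level PyConv: 27072
```

Under these assumptions, the four-level PyConv uses fewer weights (27,072) than the standard 3 × 3 convolution (36,864), which is consistent with the claim that PyConv does not enlarge the parameter space.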

Fig. 2. The structure of pyramidal convolution.

Fig. 3. Grouped convolution.

3.3 Dual-Pooling ECA Attention Mechanism Module

Attention mechanisms have demonstrated great potential for improving CNN performance. Their function is to select features, highlighting important information and suppressing unnecessary information. However, many existing methods focus on designing more sophisticated attention modules to obtain better performance, which inevitably increases model complexity [20]. ECA is a lightweight channel attention module that helps the model resist interference from redundant information and focus on the key information of effective targets. To reduce computational complexity, ECA builds on SE but replaces the fully connected layers with a 1D convolution to learn channel attention, which avoids dimensionality reduction and effectively captures cross-channel interaction. Furthermore, ECA uses an adaptive kernel size for the 1D convolution, so the range of cross-channel interaction depends on the number of channels of the input feature map. The kernel size is defined as follows:

$$k = \psi(C) = \left| \frac{\log_2(C)}{\gamma} + \frac{b}{\gamma} \right|_{\text{odd}} \qquad (1)$$

where k is the convolution kernel size, C is the number of channels, |t|_odd denotes the odd number nearest to t, γ is set to 2, and b is set to 1. Pooling layers are basic components of CNNs and are mainly used for feature reduction and data compression. Commonly used pooling operations include global average pooling (GAP) and global max pooling (GMP). The original ECA module uses GAP to compress the spatial features of the input feature map. However, GAP summarizes the feature map as a whole, which may lose some detailed feature information during compression. To make full use of the feature information, we propose a dual-pooling ECA (DP-ECA) module based on the ECA attention mechanism; its structure is illustrated in Fig. 4. Unlike the ECA module, DP-ECA uses GAP and GMP together. GAP better retains the global features of buildings, while GMP better retains their local features. Therefore, the weighted dual-pooling structure based on GAP and GMP can better retain the key information of buildings and extract more diverse features.
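Eq. (1) can be evaluated for the channel widths of the ResNet18 stages (64, 128, 256, 512). How to resolve |t|_odd when t is not an integer is an assumption in this sketch: it truncates t and bumps even values up to the next odd number, which is a common way ECA is implemented. The function name is hypothetical.

```python
import math

def eca_kernel_size(channels, gamma=2, b=1):
    """Adaptive 1D kernel size k = |log2(C)/gamma + b/gamma|_odd from Eq. (1)."""
    t = int(abs((math.log2(channels) + b) / gamma))
    return t if t % 2 == 1 else t + 1     # force an odd size so the 1D kernel has a centre tap

for c in (64, 128, 256, 512):
    print(c, eca_kernel_size(c))
```

With γ = 2 and b = 1, this rule gives k = 3 for 64 channels and k = 5 for 128, 256, and 512 channels.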

Fig. 4. The structure of dual-pooling ECA attention mechanism module.

The DP-ECA module learns channel attention as follows:

$$\omega = \sigma\left(1D_k(V_A)\cdot\alpha + 1D_k(V_M)\cdot(1-\alpha)\right) \qquad (2)$$

where σ denotes the sigmoid activation function, 1D_k denotes a 1D convolution with kernel size k, V_A and V_M denote the feature vectors obtained by GAP and GMP, respectively, and α is the weighting factor of the GAP branch.
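A NumPy forward pass of Eq. (2) makes the data flow concrete. This is a sketch, not the authors' PyTorch module: it processes a single feature map of shape (C, H, W), assumes the same 1D kernel is shared by the GAP and GMP branches (Eq. (2) uses the same symbol 1D_k for both), uses zero padding so the convolution keeps C outputs, and takes the kernel weights as an input instead of learning them. The function names are hypothetical.

```python
import numpy as np

def conv1d_same(v, kernel):
    """1D cross-correlation over the channel axis with zero padding, so the output keeps length C."""
    kernel = np.asarray(kernel, dtype=float)
    k = len(kernel)                                  # assumed odd, as produced by Eq. (1)
    padded = np.pad(v, (k // 2, k // 2))
    return np.array([padded[i:i + k] @ kernel for i in range(len(v))])

def dp_eca(x, kernel, alpha=0.75):
    """Dual-pooling ECA of Eq. (2) for one feature map x of shape (C, H, W)."""
    v_a = x.mean(axis=(1, 2))                        # GAP: one value per channel
    v_m = x.max(axis=(1, 2))                         # GMP: one value per channel
    z = conv1d_same(v_a, kernel) * alpha + conv1d_same(v_m, kernel) * (1 - alpha)
    omega = 1.0 / (1.0 + np.exp(-z))                 # sigmoid -> channel weights in (0, 1)
    return x * omega[:, None, None]                  # re-weight each channel of the input

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 7, 7))                      # toy feature map with C = 64
kernel = 0.1 * rng.normal(size=3)                    # stand-in for learned weights, k = 3 for C = 64
out = dp_eca(x, kernel)
```

Setting α = 1 recovers a GAP-only (original ECA-style) branch and α = 0 a GMP-only branch, matching the two extreme cases examined below.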


First, the input feature maps are passed through GAP and GMP separately to obtain two 1 × 1 × C compressed feature vectors. Second, a 1D convolution with kernel size k captures the cross-channel interaction information of each compressed vector. The two resulting vectors are then weighted and added channel by channel, and the sum is passed through the sigmoid activation function to produce the final channel-weight vector. Finally, the input feature map is multiplied channel by channel by this weight vector to obtain the feature map with channel attention. To study the influence of the weighting factor α on the DP-ECA module, we assign weight α to GAP and 1 − α to GMP, use ResNet18 as the backbone network, and report the results on the validation set in Table 1. When α is 0 or 1, only GMP or only GAP is used, respectively. The overall accuracy is higher when α is 0.5 or 0.75 than in the other cases, and the model reaches its highest accuracy at α = 0.75. These results show that GAP and GMP complement each other, and that two weighted pooling layers work better than a single one.

Table 1. Experimental results under different weighting factors α.

Trial | Weighting factor α | Accuracy (%)
1 | 0 | 95.69
2 | 0.25 | 95.88
3 | 0.5 | 96.41
4 | 0.75 | 96.47
5 | 1 | 96.04

4 Experiments and Results

4.1 Datasets
The dataset used in this paper, AHE_Dataset, was published in [10]. It contains 10,235 images of architectural heritage elements classified into 10 types of heritage buildings. The images are divided into a training set and a validation set in the ratio 8:2. In addition, an independent test set of 1,404 images is provided.

4.2 Data Augmentation
Deep learning models need sufficient data to avoid overfitting. Because preparing large-scale datasets is costly, appropriate data augmentation plays an important role in improving network performance. We therefore apply common data augmentation operations to each image during training (e.g., rotation, random flipping, brightness jitter, and random erasing).
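The 8:2 training/validation split described in Sect. 4.1 can be sketched as a seeded random partition in Python. The seed, the function name, the placeholder file names, and the use of a purely random (rather than per-class, stratified) split are all assumptions; the paper does not state how the split was drawn.

```python
import random
from typing import List, Sequence, Tuple


def split_8_2(items: Sequence[str], seed: int = 0) -> Tuple[List[str], List[str]]:
    """Shuffle a copy of items and cut it 80/20 into training and validation lists."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    n_train = len(shuffled) * 8 // 10  # integer arithmetic avoids float rounding
    return shuffled[:n_train], shuffled[n_train:]


# Hypothetical file names standing in for the 10,235 AHE_Dataset images.
images = [f"img_{i:05d}.jpg" for i in range(10235)]
train, val = split_8_2(images)
print(len(train), len(val))  # 8188 2047
```

A stratified split would cut each of the 10 classes 8:2 separately, so its totals could differ from these by a few images.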


S. Li et al.

4.3 Experimental Setup and Evaluation Metrics
The proposed algorithm is implemented in Python 3.7 with the PyTorch 1.11.0 + cu113 build. During training, the optimizer is stochastic gradient descent (SGD) with a learning rate of 0.001, momentum of 0.9, weight decay of 0.0001, and a batch size of 128. Cross entropy is used as the loss function. For evaluation, we use top-1 accuracy, which is commonly used in image classification tasks. In addition, the number of parameters and FLOPs (floating point operations) are used to evaluate model complexity.

4.4 Impact of Insertion Position on DP-ECA Module Performance
To make the most of the attention mechanism and further improve the classification performance of the model, we insert DP-ECA at different positions in ResNet18, as shown in Fig. 5, and compare the classification results in Table 2. The standard insertion position (Fig. 5(a)) achieves higher accuracy than the other two insertion positions. With the DP-ECA module, the model can focus on the useful feature information in architectural heritage images, which improves classification performance in complex contexts.

Fig. 5. Different insertion positions of DP-ECA module.

Table 2. Experimental results under different insertion positions.
Trial | Design | Accuracy/%
1 | DP-ECA_Std | 96.47
2 | DP-ECA_Pre | 96.35
3 | DP-ECA_Post | 96.18


4.5 Comparison with Other Models
To verify the effectiveness of the proposed method, we compare PECA-Net with other deep models, including AlexNet [22], VGG16 [23], Inception V3 [24], ResNet18, ResNet50 [12], RepVgg_A2, and RepVgg_B1 [25]. The experimental results are shown in Table 3. Compared with the original ResNet18, our model improves accuracy by 1.86 percentage points (from 93.32% to 95.18%) while reducing the parameters by 0.48M and the FLOPs by 0.28G. Furthermore, compared with the other deep models, our method achieves the highest accuracy with relatively few parameters and FLOPs, which verifies the effectiveness of the proposed model.

Table 3. The classification results of different models.
Model name | Accuracy/% | Params/M | FLOPs/G
AlexNet | 89.73 | 61.11 | 0.71
Vgg16 | 90.57 | 138.36 | 15.4
Inception V3 | 92.96 | 23.83 | 2.86
ResNet18 | 93.32 | 11.69 | 1.82
ResNet50 | 93.71 | 25.56 | 4.13
RepVgg_A2 | 93.64 | 26.82 | 5.72
RepVgg_B1 | 94.48 | 57.42 | 13.2
Ours | 95.18 | 11.21 | 1.54

4.6 Ablation Experiment
To verify the influence of each proposed component on the model, we use ResNet18 and ResNet50 as backbone networks and add one or more of three modules (PyConv, the ECA module, and the DP-ECA module) to compare the classification performance of the different schemes. The experimental results are shown in Table 4. Adding PyConv to ResNet18 increases accuracy by 0.42 percentage points and reduces the parameters and FLOPs by 0.48M and 0.28G, respectively. Adding the ECA or DP-ECA module increases accuracy by 0.61 and 1.05 percentage points, respectively, without increasing model complexity. Combining PyConv with ECA or with DP-ECA improves accuracy by 1.51 and 1.86 percentage points, respectively. Adding PyConv to ResNet50 increases accuracy by 0.31 percentage points and reduces the parameters and FLOPs by 0.71M and 0.26G, respectively. Adding the ECA or DP-ECA module increases accuracy by 0.54 and 0.87 percentage points, respectively; because both modules are very lightweight, the parameters and FLOPs barely change (FLOPs rise from 4.13G to 4.14G). Combining PyConv with ECA or with DP-ECA improves accuracy by 1.26 and 1.59 percentage points, respectively.


Compared with the backbone network, all five modified configurations improve performance to varying degrees, which verifies the effectiveness of the modules added to the model.

Table 4. Ablation experiment results (√ = module added, – = not added).
Backbone | PyConv | ECA | DP-ECA | Acc/% | Params/M | FLOPs/G
ResNet18 | – | – | – | 93.32 | 11.69 | 1.82
ResNet18 | √ | – | – | 93.74 | 11.21 | 1.54
ResNet18 | – | √ | – | 93.93 | 11.69 | 1.82
ResNet18 | – | – | √ | 94.37 | 11.69 | 1.82
ResNet18 | √ | √ | – | 94.83 | 11.21 | 1.54
ResNet18 | √ | – | √ | 95.18 | 11.21 | 1.54
ResNet50 | – | – | – | 93.71 | 25.56 | 4.13
ResNet50 | √ | – | – | 94.02 | 24.85 | 3.87
ResNet50 | – | √ | – | 94.25 | 25.56 | 4.14
ResNet50 | – | – | √ | 94.58 | 25.56 | 4.14
ResNet50 | √ | √ | – | 94.97 | 24.85 | 3.88
ResNet50 | √ | – | √ | 95.30 | 24.85 | 3.88
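The percentage-point gains quoted in Sect. 4.6 can be recomputed directly from Table 4. The short Python sketch below stores the table rows (values copied from Table 4) and prints each configuration's change in accuracy, parameters, and FLOPs relative to its plain backbone; the dictionary layout and names are illustrative only.

```python
from typing import Dict, Tuple

# (Acc/%, Params/M, FLOPs/G) from Table 4, keyed by (backbone, modules added).
TABLE4: Dict[Tuple[str, str], Tuple[float, float, float]] = {
    ("ResNet18", "none"): (93.32, 11.69, 1.82),
    ("ResNet18", "PyConv"): (93.74, 11.21, 1.54),
    ("ResNet18", "ECA"): (93.93, 11.69, 1.82),
    ("ResNet18", "DP-ECA"): (94.37, 11.69, 1.82),
    ("ResNet18", "PyConv+ECA"): (94.83, 11.21, 1.54),
    ("ResNet18", "PyConv+DP-ECA"): (95.18, 11.21, 1.54),
    ("ResNet50", "none"): (93.71, 25.56, 4.13),
    ("ResNet50", "PyConv"): (94.02, 24.85, 3.87),
    ("ResNet50", "ECA"): (94.25, 25.56, 4.14),
    ("ResNet50", "DP-ECA"): (94.58, 25.56, 4.14),
    ("ResNet50", "PyConv+ECA"): (94.97, 24.85, 3.88),
    ("ResNet50", "PyConv+DP-ECA"): (95.30, 24.85, 3.88),
}


def deltas(backbone: str, modules: str) -> Tuple[float, ...]:
    """Change in (accuracy [pp], params [M], FLOPs [G]) versus the plain backbone."""
    base = TABLE4[(backbone, "none")]
    row = TABLE4[(backbone, modules)]
    return tuple(round(r - b, 2) for r, b in zip(row, base))


for backbone, modules in TABLE4:
    if modules != "none":
        print(backbone, modules, deltas(backbone, modules))
```

Each printed accuracy change matches the corresponding figure in the text, for example +1.86 for ResNet18 with PyConv and DP-ECA.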

5 Conclusion
In this paper, we propose a novel PECA-Net based on ResNet18 for architectural heritage image classification. We replace the standard convolution with PyConv, which extracts detailed information at different scales without increasing model complexity. Inspired by the ECA attention mechanism, we propose the DP-ECA module, which enhances the model's ability to extract key information by connecting two weighted pooling layers in parallel. Compared with the original ResNet18, PECA-Net improves accuracy by 1.86 percentage points while reducing the parameters and FLOPs by 0.48M and 0.28G, respectively. Furthermore, compared with other deep models, PECA-Net achieves higher accuracy with relatively few parameters and FLOPs. The proposed method can serve as a reference for the identification and classification of heritage buildings in future work.

Acknowledgement. This work is supported by the Oversea Study and Innovation Foundation of Chongqing (No. cx2021105) and the Fundamental Research Funds for the Central Universities (No. SWU021001).


References
1. Samhouri, M., Al-Arabiat, L., Al-Atrash, F.: Prediction and measurement of damage to architectural heritages facades using convolutional neural networks. Neural Comput. Appl. 34(20), 18125–18141 (2022)
2. Shalunts, G., Haxhimusa, Y., Sablatnig, R.: Classification of gothic and baroque architectural elements. In: 2012 19th International Conference on Systems, Signals and Image Processing (IWSSIP), pp. 316–319. IEEE (2012)
3. Shalunts, G.: Architectural style classification of building facade towers. In: Bebis, G., et al. (eds.) ISVC 2015. LNCS, vol. 9474, pp. 285–294. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-27857-5_26
4. Zhao, P., Miao, Q., Song, J., Qi, Y., Liu, R., Ge, D.: Architectural style classification based on feature extraction module. IEEE Access 6, 52598–52606 (2018)
5. Xu, Z., Tao, D., Zhang, Y., Wu, J., Tsoi, A.C.: Architectural style classification using multinomial latent logistic regression. In: Fleet, D., Pajdla, T., Schiele, B., Tuytelaars, T. (eds.) ECCV 2014. LNCS, vol. 8689, pp. 600–615. Springer, Cham (2014). https://doi.org/10.1007/978-3-319-10590-1_39
6. Lowe, D.G.: Object recognition from local scale-invariant features. In: Proceedings of the Seventh IEEE International Conference on Computer Vision, vol. 2, pp. 1150–1157. IEEE (1999)
7. Dalal, N., Triggs, B.: Histograms of oriented gradients for human detection. In: 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR 2005), vol. 1, pp. 886–893. IEEE (2005)
8. Felzenszwalb, P., McAllester, D., Ramanan, D.: A discriminatively trained, multiscale, deformable part model. In: 2008 IEEE Conference on Computer Vision and Pattern Recognition, pp. 1–8. IEEE (2008)
9. Alzubaidi, L., et al.: Review of deep learning: concepts, CNN architectures, challenges, applications, future directions. J. Big Data 8, 1–74 (2021)
10. Llamas, J.M., Lerones, P., Medina, R., Zalama, E., Gómez-García-Bermejo, J.: Classification of architectural heritage images using deep learning techniques. Appl. Sci. 7(10), 992 (2017)
11. Obeso, A.M., Benois-Pineau, J., Acosta, A.Á.R., Vázquez, M.S.G.: Architectural style classification of mexican historical buildings using deep convolutional neural networks and sparse features. J. Electron. Imaging 26(1), 011016 (2017)
12. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
13. Lamas, A., et al.: Monumai: dataset, deep learning pipeline and citizen science based app for monumental heritage taxonomy and classification. Neurocomputing 420, 266–280 (2021)
14. Chen, L.C., Papandreou, G., Kokkinos, I., Murphy, K., Yuille, A.L.: Deeplab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected crfs. IEEE Trans. Pattern Anal. Mach. Intell. 40(4), 834–848 (2017)
15. Zhang, J., Zhang, J., Hu, G., Chen, Y., Yu, S.: Scalenet: a convolutional network to extract multi-scale and fine-grained visual features. IEEE Access 7, 147560–147570 (2019)
16. Duta, I.C., Liu, L., Zhu, F., Shao, L.: Pyramidal convolution: rethinking convolutional neural networks for visual recognition. arXiv preprint arXiv:2006.11538 (2020)
17. Niu, Z., Zhong, G., Yu, H.: A review on the attention mechanism of deep learning. Neurocomputing 452, 48–62 (2021)
18. Hu, J., Shen, L., Sun, G.: Squeeze-and-excitation networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7132–7141 (2018)


19. Woo, S., Park, J., Lee, J.-Y., Kweon, I.S.: CBAM: convolutional block attention module. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11211, pp. 3–19. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01234-2_1
20. Wang, Q., Wu, B., Zhu, P., Li, P., Zuo, W., Hu, Q.: Eca-net: efficient channel attention for deep convolutional neural networks. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11534–11542 (2020)
21. Xia, Y., Xu, X., Pu, F.: Pcba-net: pyramidal convolutional block attention network for synthetic aperture radar image change detection. Remote Sens. 14(22), 5762 (2022)
22. Krizhevsky, A., Sutskever, I., Hinton, G.E.: Imagenet classification with deep convolutional neural networks. Commun. ACM 60(6), 84–90 (2017)
23. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556 (2014)
24. Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., Wojna, Z.: Rethinking the inception architecture for computer vision. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2818–2826 (2016)
25. Ding, X., Zhang, X., Ma, N., Han, J., Ding, G., Sun, J.: Repvgg: making VGG-STYLE convnets great again. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13733–13742 (2021)

Convolutional Self-attention Guided Graph Neural Network for Few-Shot Action Recognition
Fei Pan (corresponding author), Jie Guo, and Yanwen Guo
National Key Lab for Novel Software Technology, Nanjing University, Nanjing, China
{guojie,ywguo}@nju.edu.cn

Abstract. The goal of few-shot action recognition is to recognize unseen action classes from only a few labeled videos. In this paper, we propose a Convolutional Self-Attention Guided Graph Neural Network (CSA-GNN) for few-shot action recognition. First, for each video, we extract features from frames sampled from the video to obtain a sequence of feature vectors. Then, a convolutional self-attention function is applied to the sequence to capture long-term temporal dependencies. Finally, a graph neural network explicitly predicts the distance between two sequences of feature vectors, which approximates the distance between the corresponding videos. In this way, we effectively learn the distance between the support and query videos without estimating their temporal alignment. The proposed method is evaluated on four action recognition datasets and achieves state-of-the-art results in the experiments.

Keywords: Action recognition · Few-shot learning · Neural Network

1 Introduction
Action recognition plays an important role in real-world scenarios, with wide applications including video surveillance and content-based video retrieval. Most recently proposed action recognition methods [1, 2] are based on deep neural networks, which require a large number of labeled examples. The high cost of collecting and labeling many videos poses great challenges to these methods in real-world applications. To reduce this cost, few-shot action recognition, which aims to recognize unseen classes from a limited number of labeled videos, has recently attracted increasing research attention.

Few-shot action recognition is challenging because only a few labeled examples are available for each new action class, which makes it difficult to learn discriminative features that generalize well to unseen action categories. Most state-of-the-art few-shot action recognition approaches are based on metric learning and can be roughly divided into two categories: global feature-based methods [3, 4] and temporal alignment-based methods [5, 6]. Global feature-based methods compress videos into fixed-length feature vectors, which may lose some of the temporal information in the videos and cause a relatively large deviation in the learned distance between videos. Temporal alignment-based methods calculate the similarity between two videos by estimating their inter-frame correspondences. For instance, OTAM [5] aligns the support and query videos with DTW [7]. However, temporal alignment-based approaches have several shortcomings. First, they may suffer from high computational complexity, which limits their applicability to long video sequences. Second, they can be sensitive to the choice of alignment algorithm and parameters, which may lead to suboptimal results. Finally, they may struggle to align sequences with complex temporal structures, such as videos captured in real-world scenarios.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 401–412, 2023. https://doi.org/10.1007/978-981-99-4742-3_33

Fig. 1. Overview of distance learning between the support and query videos. First, the feature vectors of a few frames sampled from each video are obtained with the feature embedding function fα. Then, the convolutional self-attention function gβ is employed to capture long-term temporal dependencies in videos. Subsequently, we construct a graph with pairwise distances between frames in the support and query videos. Finally, the graph is fed into the graph neural network hθ, which explicitly returns the distance between the two videos.

Motivated by these observations, we propose a Convolutional Self-Attention Guided Graph Neural Network (CSA-GNN) for few-shot action recognition, whose key component is a sequence distance learning module consisting of a convolutional self-attention function and a graph neural network. For each video, we first sparsely sample a few frames and extract per-frame features to obtain a sequence of feature vectors. Then, the convolutional self-attention function is applied to this sequence. Self-attention [8] can model long-term temporal dependencies in sequences by matching queries with keys at each timestamp. Unlike canonical self-attention, convolutional self-attention [9] also exploits the temporal context of video frames during matching. Subsequently, for the two feature sequences of a given video pair, we compute the pairwise distances between their elements and build a King's graph from these distance scores. Finally, the graph is fed into a graph neural network to obtain the distance between the two videos. CSA-GNN can thus effectively learn the distance between the support and query videos without estimating their temporal alignment, which relieves the limitation of temporal alignment-based methods and improves the efficiency of data utilization in the few-shot setting. The distance learning between the support and query videos is outlined in Fig. 1.

In summary, our paper makes three main contributions.
– We propose a Convolutional Self-Attention Guided Graph Neural Network (CSA-GNN) for few-shot action recognition, which explicitly predicts the distance between sequences through a graph neural network.
– We introduce convolutional self-attention into few-shot action recognition and combine it with a graph neural network to effectively learn the distance between videos.
– Extensive experiments on four challenging action recognition datasets (i.e., Kinetics, Something-Something V2 (SSv2), UCF101, and HMDB51) show that the proposed method achieves state-of-the-art performance.

2 Related Work
Existing few-shot action recognition methods are mostly based on metric learning and focus on learning the distance between the support and query videos. CMN [4] introduces a multi-saliency embedding algorithm for the feature representation of videos. Zhu et al. [10] propose a simple classifier-based model that does not estimate the temporal alignment between videos and still outperforms several state-of-the-art approaches. TA2N [11] uses a Two-Stage Action Alignment Network to address the misalignment problem in few-shot action recognition. Fu et al. [12] use depth information to augment video representations and propose a depth guided adaptive instance normalization module to fuse the RGB and depth streams. He et al. [13] propose an implicit temporal alignment method, ITANet, to compute the similarity between videos. TRX [14] compares the query video with sub-sequences of the videos in the support set to obtain query-specific class prototypes. PAL [15] introduces a hybrid attentive learning strategy for few-shot action recognition. Wu et al. [6] approach few-shot action recognition with task-specific motion modulation and multi-level temporal fragment alignment: their end-to-end Motion-modulated Temporal Fragment Alignment Network (MTFAN) incorporates a motion modulator and a segment attention mechanism, enabling accurate recognition of actions in video sequences. Nguyen et al. [16] use appearance and temporal alignments to estimate the similarity between the support and query videos. In contrast, the proposed method introduces a sequence distance learning module that learns the distance between a pair of sequences with a graph neural network, without estimating their temporal alignment.

3 Method

3.1 Problem Formulation
In the few-shot setting of video action recognition, there are three meta sets with no overlapping classes: the meta-training set, the meta-validation set, and the meta-test set. The goal is to train a model on the meta-training set that can recognize novel classes in the meta-test set from only a small number of labeled videos per class. The meta-validation set is used to tune the model's hyperparameters.


The model is trained with an episode-based strategy [17], which learns from a large number of episodes. For an N-way K-shot problem, each episode contains a support set S and a query set Q. The support set contains videos from N classes, with K examples per class. The goal is to classify each video in the query set into one of the N classes.

3.2 Feature Embedding
Adjacent video frames are often highly redundant. Following TSN [18], we first split a given video V into T segments of equal length and then randomly sample one frame from each segment. With this sampling strategy, the obtained frames z1, z2, …, zT are uniformly distributed over the entire video V. Next, we employ a convolutional neural network fα parameterized by α to extract feature representations of the frames in the sequence (z1, z2, …, zT). After feature embedding, the video V is represented by a sequence of feature vectors:

$$ f_\alpha(V) = \left( f_\alpha(z_1), f_\alpha(z_2), \dots, f_\alpha(z_T) \right) \quad (1) $$
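A minimal Python sketch of the TSN-style sampling described above: the index range [0, num_frames) is cut into T nearly equal segments and one frame index is drawn per segment. Taking the segment midpoint when no random generator is supplied is an assumption (a common deterministic choice at test time), not something stated in the text, and the function name is hypothetical.

```python
import random
from typing import List, Optional


def tsn_sample_indices(num_frames: int, T: int,
                       rng: Optional[random.Random] = None) -> List[int]:
    """Return T frame indices, one per equal-length segment of the video.

    With rng, a random frame is drawn from each segment (training-style);
    without it, the middle frame of each segment is used (assumption).
    """
    if num_frames < T:
        raise ValueError("the video must contain at least T frames")
    indices = []
    for t in range(T):
        start = t * num_frames // T          # segment t covers [start, end)
        end = (t + 1) * num_frames // T      # end > start because num_frames >= T
        if rng is None:
            indices.append((start + end) // 2)
        else:
            indices.append(rng.randrange(start, end))
    return indices


print(tsn_sample_indices(80, 8))                    # [5, 15, 25, 35, 45, 55, 65, 75]
print(tsn_sample_indices(80, 8, random.Random(0)))  # one random frame per 10-frame segment
```

Because every segment contributes exactly one frame, the T sampled frames always span the whole video regardless of its length, which is the property the text relies on.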

3.3 Sequence Distance Learning Module
We use the matrices $A = [a_1, \dots, a_T]^\top \in \mathbb{R}^{T \times d}$ and $B = [b_1, \dots, b_T]^\top \in \mathbb{R}^{T \times d}$ to represent two sequences of vectors of length T, where each element is a d-dimensional vector. In this subsection, we describe the sequence distance learning module, which calculates the distance between the sequences A and B. This module consists of a convolutional self-attention function and a graph neural network.

Fig. 2. (a) Canonical self-attention. (b) Convolutional self-attention. Conv,1 and Conv,k in the figure represent 1D convolution with kernel size 1 and k (k equals 3 in this example), respectively.

Convolutional Self-attention. As illustrated in Fig. 2(a), canonical self-attention in the Transformer [8] matches queries with keys at each timestamp in the sequence. However, this matching mechanism does not take the temporal context of the elements in the sequence into account. In some real-world scenarios, multiple adjacent frames are needed to determine the action at a certain timestamp in the video. We therefore introduce convolutional self-attention [9] into few-shot action recognition to exploit the temporal information of video frames during matching (see Fig. 2(b)). Specifically, given a sequence of vectors represented by the matrix $H \in \mathbb{R}^{T \times d}$, we transform H into a query matrix $Q \in \mathbb{R}^{T \times d_q}$, a key matrix $K \in \mathbb{R}^{T \times d_k}$, and a value matrix $V \in \mathbb{R}^{T \times d_v}$:

$$ Q = \phi_q(H), \quad K = \phi_k(H), \quad V = \phi_v(H) \quad (2) $$


where the functions φq, φk, and φv denote 1D convolutions with stride 1 and kernel sizes k, k, and 1, respectively, and dq, dk, and dv (with dq = dk and dv = d) are the output dimensions of φq, φk, and φv, respectively. When k is greater than 1, the query and key generated at each timestamp contain contextual information from the temporal domain. After these transformations, the output matrix of convolutional self-attention for the input H is:

$$ O = H + \mathrm{softmax}\left( \frac{QK^\top}{\sqrt{d_k}} \right) V \quad (3) $$

where softmax(·) is applied row-wise so that each row of the attention matrix sums to 1. In addition, we adopt the residual connection in Eq. (3) to reduce model complexity and the chance of overfitting. In the following, this convolutional self-attention is denoted by a function gβ parameterized by β. Feeding the matrices A and B into gβ gives:

$$ X = g_\beta(A) \in \mathbb{R}^{T \times d}, \quad Y = g_\beta(B) \in \mathbb{R}^{T \times d} \quad (4) $$
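To illustrate Eqs. (2) and (3), here is a small, framework-free Python sketch of convolutional self-attention. Assumptions beyond the text: the 1D convolutions are zero-padded so the sequence length stays T, they have no bias, and a kernel is stored as a list of k weight matrices of shape d_in × d_out. All names are hypothetical, and a practical version would use batched tensor operations.

```python
import math
from typing import List

Matrix = List[List[float]]  # row-major: rows are timestamps


def conv1d_time(H: Matrix, kernel: List[Matrix]) -> Matrix:
    """Stride-1, zero-padded 1D convolution along time.

    H is T x d_in; kernel holds k matrices of shape d_in x d_out (k odd).
    """
    T, k = len(H), len(kernel)
    pad, d_out = k // 2, len(kernel[0][0])
    out = [[0.0] * d_out for _ in range(T)]
    for t in range(T):
        for j in range(k):
            s = t + j - pad  # neighbouring timestamp covered by tap j
            if 0 <= s < T:
                for i, h in enumerate(H[s]):
                    for o in range(d_out):
                        out[t][o] += h * kernel[j][i][o]
    return out


def softmax_rows(M: Matrix) -> Matrix:
    """Numerically stable row-wise softmax."""
    result = []
    for row in M:
        m = max(row)
        e = [math.exp(v - m) for v in row]
        z = sum(e)
        result.append([v / z for v in e])
    return result


def conv_self_attention(H: Matrix, Wq: List[Matrix], Wk: List[Matrix],
                        Wv: List[Matrix]) -> Matrix:
    """O = H + softmax(Q K^T / sqrt(d_k)) V, with Q, K, V from 1D convolutions (Eqs. 2-3)."""
    Q, K, V = conv1d_time(H, Wq), conv1d_time(H, Wk), conv1d_time(H, Wv)
    T, d_k, d = len(H), len(K[0]), len(H[0])
    scores = [[sum(q * kk for q, kk in zip(Q[a], K[b])) / math.sqrt(d_k)
               for b in range(T)] for a in range(T)]
    A = softmax_rows(scores)
    # Residual connection: V must have the same width d as H (dv = d).
    return [[H[a][c] + sum(A[a][b] * V[b][c] for b in range(T)) for c in range(d)]
            for a in range(T)]


H = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]                  # T = 3, d = 2
zeros_k3 = [[[0.0, 0.0], [0.0, 0.0]] for _ in range(3)]  # k = 3 query/key kernels
identity_k1 = [[[1.0, 0.0], [0.0, 1.0]]]                 # kernel size 1, so V = H
print(conv_self_attention(H, zeros_k3, zeros_k3, identity_k1))
```

With all-zero query and key kernels every score is 0, the attention is uniform, and each output row equals H[a] plus the column mean of H; nonzero kernels with k > 1 let each query and key mix neighbouring timestamps before matching.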

Graph Neural Network. We first review the DTW algorithm [7], which computes the distance between two sequences by searching for an optimal alignment path. Given the sequences of vectors X and Y, DTW aims to find an optimal path I* such that:

$$ I^* = \arg\min_{I} \sum_{(m,n) \in I} s\left( (X_{[m]})^\top, (Y_{[n]})^\top \right) \quad (5) $$

where I = ((m1, n1), (m2, n2), …, (mL, nL)) is an alignment path of length L, X[m] and Y[n] denote the m-th row of the matrix X and the n-th row of the matrix Y, respectively, and the function s computes the distance between two vectors. DTW can be solved by dynamic programming with the following recurrence:

$$ D_{m,n} = s\left( (X_{[m]})^\top, (Y_{[n]})^\top \right) + \min\left\{ D_{m-1,n},\, D_{m,n-1},\, D_{m-1,n-1} \right\} \quad (6) $$

where D_{m,n} is the distance along the optimal alignment path between the first m elements of X and the first n elements of Y. To make the minimum operator in Eq. (6) differentiable, the OTAM algorithm [5] approximates it with the LogSumExp function: $\min(a_1, \dots, a_n) \approx -\frac{1}{\gamma} \log \sum_{i=1}^{n} \exp(-\gamma a_i)$. However, DTW requires the temporal ordering to be strictly preserved: $m_l \le m_{l+1} \le m_l + 1$, $n_l \le n_{l+1} \le n_l + 1$, $l \in \{1, \dots, L-1\}$. This is a stringent condition that is difficult to satisfy for videos taken in real-world scenarios. Let $E \in \mathbb{R}^{T \times T}$ be the matrix with elements $E_{mn} = s((X_{[m]})^\top, (Y_{[n]})^\top)$, $m, n \in \{0, \dots, T-1\}$. The distance between the sequences X and Y is then:

$$ \rho_{\mathrm{DTW}}(X, Y) = D_{T,T} = \sum_{(m^*, n^*) \in I^*} s\left( (X_{[m^*]})^\top, (Y_{[n^*]})^\top \right) = \sum_{(m^*, n^*) \in I^*} E_{m^* n^*} \quad (7) $$
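The recurrence in Eq. (6) and OTAM's LogSumExp relaxation can be sketched in a few lines of Python. The Euclidean distance for s, the 0-based indexing, and the function names are assumptions for illustration. The soft variant only swaps min for the smoothed minimum; since that smoothed minimum never exceeds the true minimum and is increasing in each argument, the soft distance is never larger than hard DTW.

```python
import math
from typing import Callable, List, Sequence

Vector = Sequence[float]


def soft_min(values: List[float], gamma: float) -> float:
    """-1/gamma * log(sum(exp(-gamma * a))), shifted by min(values) for stability."""
    m = min(values)
    return m - math.log(sum(math.exp(-gamma * (v - m)) for v in values)) / gamma


def dtw(X: List[Vector], Y: List[Vector], gamma: float = 0.0) -> float:
    """DTW distance via Eq. (6); gamma > 0 uses the OTAM-style soft minimum."""
    pick: Callable[[List[float]], float] = (
        min if gamma <= 0 else (lambda vals: soft_min(vals, gamma)))
    M, N = len(X), len(Y)
    D = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            cost = math.dist(X[m], Y[n])  # s(X[m], Y[n]) = Euclidean distance (assumption)
            prev = []
            if m > 0:
                prev.append(D[m - 1][n])
            if n > 0:
                prev.append(D[m][n - 1])
            if m > 0 and n > 0:
                prev.append(D[m - 1][n - 1])
            D[m][n] = cost + (pick(prev) if prev else 0.0)
    return D[M - 1][N - 1]


X = [[0.0], [0.0], [1.0]]
Y = [[0.0], [1.0], [1.0]]
print(dtw(X, Y))             # 0.0: the path (0,0),(1,0),(2,1),(2,2) has zero cost
print(dtw(X, Y, gamma=1.0))  # at most the hard DTW value (here negative)
```

The example also shows the ordering constraint at work: the path may repeat an index (X[0] and X[1] both align to Y[0]) but can never move backwards in time.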


Eq. (7) shows that the distance between the sequences X and Y can be obtained from the pairwise distances between their elements (i.e., E_mn, m, n = 0, …, T − 1). We therefore construct a graph from the matrix E and explicitly learn the distance between X and Y with a graph neural network. We adopt the King's graph as the graph structure; it represents all legal moves of the king piece on a chessboard (Fig. 1 contains an example of a 5 × 5 King's graph). More specifically, we construct a T × T King's graph in which the vertex at position (m, n) is annotated with the value E_mn, and each vertex is connected by edges to its (up to) 8 adjacent vertices. We denote the adjacency matrix of the T × T King's graph by $A_K \in \mathbb{R}^{T^2 \times T^2}$. The King's graph is then fed into a graph neural network hθ parameterized by θ, which consists of k1 GCN [19] layers, a readout layer, and k2 fully connected layers. The layer-wise propagation rule of the GCN layers is:

$$ H^{(l+1)} = \mathrm{ReLU}\left( \tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)} \right) \quad (8) $$

where $\tilde{A} = A_K + I_{T^2}$ is the adjacency matrix of the T × T King's graph with self-connections, and $I_{T^2}$ and $W^{(l)}$ denote the identity matrix and the weight matrix of layer l, respectively. $\tilde{D}$ is a diagonal matrix with elements $\tilde{D}_{ii} = \sum_j \tilde{A}_{ij}$, $i = 0, \dots, T^2 - 1$. The feature matrix of the l-th layer is denoted by $H^{(l)} \in \mathbb{R}^{T^2 \times D_l}$, and $H^{(0)} \in \mathbb{R}^{T^2 \times 1}$ is initialized as $H^{(0)}_{i \times T + j,\, 0} = E_{ij}$, $i, j = 0, \dots, T - 1$.
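A minimal Python sketch of the King's-graph construction and one GCN layer of Eq. (8) follows. Node (m, n) is flattened to index m·T + n, matching the initialization of H^(0). The dense nested-list matrices, the absence of a bias term, the toy distance matrix, and all function names are assumptions made for illustration.

```python
import math
from typing import List

Matrix = List[List[float]]


def kings_graph_adjacency(T: int) -> Matrix:
    """Adjacency A_K of a T x T King's graph; node (m, n) has index m * T + n."""
    N = T * T
    A = [[0.0] * N for _ in range(N)]
    for m in range(T):
        for n in range(T):
            for dm in (-1, 0, 1):
                for dn in (-1, 0, 1):
                    mm, nn = m + dm, n + dn
                    if (dm or dn) and 0 <= mm < T and 0 <= nn < T:
                        A[m * T + n][mm * T + nn] = 1.0
    return A


def normalized_adjacency(A: Matrix) -> Matrix:
    """D~^{-1/2} (A + I) D~^{-1/2}, the propagation matrix of Eq. (8)."""
    N = len(A)
    A_t = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(N)] for i in range(N)]
    d = [sum(row) for row in A_t]  # degrees including the self-loop, always >= 1
    return [[A_t[i][j] / math.sqrt(d[i] * d[j]) for j in range(N)] for i in range(N)]


def gcn_layer(A_hat: Matrix, H: Matrix, W: Matrix) -> Matrix:
    """ReLU(A_hat @ H @ W) without bias (assumption)."""
    N, d_in, d_out = len(H), len(W), len(W[0])
    AH = [[sum(A_hat[i][k] * H[k][c] for k in range(N)) for c in range(d_in)]
          for i in range(N)]
    return [[max(0.0, sum(AH[i][c] * W[c][o] for c in range(d_in)))
             for o in range(d_out)] for i in range(N)]


# H^(0): one feature per node, H0[i*T + j][0] = E[i][j].
E = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]]  # toy 3 x 3 distance matrix
T = len(E)
H0 = [[E[i][j]] for i in range(T) for j in range(T)]
H1 = gcn_layer(normalized_adjacency(kings_graph_adjacency(T)), H0, [[1.0, -1.0]])
print(len(H1), len(H1[0]))  # 9 nodes, 2 features
```

Interior vertices get 8 neighbours and corner vertices only 3, which is why the symmetric normalization by the degree matrix matters: without it, nodes in the middle of the distance matrix would dominate the aggregated features.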

Subsequently, the feature matrix $H^{(k_1)} \in \mathbb{R}^{T^2 \times D_{k_1}}$ generated by the multi-layer GCN is fed into a readout function, which aggregates the node features into a feature representation of the graph. We adopt the following diagonal sum operation for feature aggregation:

$$ v_s = \sum_{i=0}^{T-1} \sum_{j=0}^{T-1} \mathbb{I}(i, j)\, H^{(k_1)}_{[i \times T + j]}, \qquad \mathbb{I}(i, j) = \begin{cases} 1, & i = j \\ 0, & i \neq j \end{cases} \quad (9) $$

where $H^{(k_1)}_{[i \times T + j]}$ denotes the (i × T + j)-th row of the matrix $H^{(k_1)}$. Finally, several fully connected layers are applied to the vector v_s, with the output dimension of the last layer set to 1. In summary, the graph neural network hθ predicts the distance between the sequences X and Y from the adjacency matrix A_K and the distance matrix E:

$$ \rho(X, Y) = h_\theta(A_K, E) \quad (10) $$
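The diagonal-sum readout of Eq. (9) keeps only the King's-graph nodes (i, i), i.e. the rows i·T + i of H^(k1). A minimal Python sketch with hypothetical names follows; a single bias-free linear layer stands in for the k2 fully connected layers, which is an assumption since their activations and biases are not specified.

```python
from typing import List

Matrix = List[List[float]]


def diagonal_sum_readout(H: Matrix, T: int) -> List[float]:
    """v_s = sum_i H[i*T + i]: add the feature rows of the diagonal nodes (i, i)."""
    if len(H) != T * T:
        raise ValueError("H must have T*T rows, one per King's-graph node")
    width = len(H[0])
    return [sum(H[i * T + i][c] for i in range(T)) for c in range(width)]


def linear(v: List[float], W: Matrix) -> List[float]:
    """Bias-free fully connected layer: out[o] = sum_c v[c] * W[c][o] (assumption)."""
    return [sum(v[c] * W[c][o] for c in range(len(v))) for o in range(len(W[0]))]


# Toy example with T = 2 (4 nodes) and 2 features per node.
H = [[1.0, 2.0],   # node (0, 0): diagonal
     [9.0, 9.0],   # node (0, 1)
     [9.0, 9.0],   # node (1, 0)
     [3.0, 4.0]]   # node (1, 1): diagonal
v_s = diagonal_sum_readout(H, T=2)
print(v_s)                           # [4.0, 6.0]
print(linear(v_s, [[0.5], [0.25]]))  # [3.5]: a 1-dimensional distance score
```

The off-diagonal nodes still influence v_s, but only indirectly: after k1 GCN layers their distances have already been propagated into the diagonal nodes' features.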

3.4 Training and Inference
For a given task in the N-way K-shot problem, the support and query sets are defined as $S = \{(V_i, t_i)\}_{i=1}^{NK}$ and $Q = \{(V_i, t_i)\}_{i=1}^{N_Q}$, where V_i and t_i (t_i ∈ {1, 2, …, N}) denote a video and its label, respectively. In addition, for each class c (c ∈ {1, 2, …, N}), we construct a set S_c containing all examples of class c in the support set S.


For a given video V_q in the query set Q, the average distance between V_q and the videos in S_c is:

$$ \psi(V_q, S_c) = \frac{1}{|S_c|} \sum_{(V_i, t_i) \in S_c} \rho\left( g_\beta(f_\alpha(V_q)),\, g_\beta(f_\alpha(V_i)) \right) \quad (11) $$

where |S_c| is the number of elements in S_c. During training, the model is optimized by minimizing the cross-entropy loss:

$$ \mathcal{L}(S, Q) = -\frac{1}{N_Q} \sum_{(V_q, t_q) \in Q} \log \frac{\exp(-\psi(V_q, S_{t_q}))}{\sum_{c=1}^{N} \exp(-\psi(V_q, S_c))} \quad (12) $$

During inference, given the support set S, the label of a test video V is the class that minimizes the distance function defined in Eq. (11):

$$ t = \arg\min_{c \in \{1, \dots, N\}} \psi(V, S_c) \quad (13) $$
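Given precomputed distances ρ between a query and each support video, Eqs. (11)-(13) reduce to averaging, a softmax over negative distances, and an argmin. The Python sketch below assumes the distances are already available as plain floats (in the model they come from hθ) and uses 0-based class indices; the function names and the toy numbers are hypothetical.

```python
import math
from typing import Dict, List


def class_distances(dists_per_class: Dict[int, List[float]]) -> List[float]:
    """Eq. (11): psi(V_q, S_c) is the mean distance to the support videos of class c.

    Keys must be 0, ..., N-1; each value lists rho(query, support video) for that class.
    """
    return [sum(d) / len(d) for _, d in sorted(dists_per_class.items())]


def query_loss(psi: List[float], label: int) -> float:
    """One term of Eq. (12): -log softmax(-psi)[label], computed stably."""
    neg = [-p for p in psi]
    m = max(neg)
    log_z = m + math.log(sum(math.exp(v - m) for v in neg))
    return log_z - neg[label]


def episode_loss(psis: List[List[float]], labels: List[int]) -> float:
    """Eq. (12): average the per-query losses over the N_Q query videos."""
    return sum(query_loss(p, t) for p, t in zip(psis, labels)) / len(labels)


def predict(psi: List[float]) -> int:
    """Eq. (13): the class with the smallest average distance."""
    return min(range(len(psi)), key=psi.__getitem__)


# 3-way 2-shot toy episode for a single query video.
psi = class_distances({0: [0.8, 1.2], 1: [2.0, 2.0], 2: [2.5, 3.5]})
print(psi)                 # [1.0, 2.0, 3.0]
print(predict(psi))        # 0
print(query_loss(psi, 0))  # log(1 + e^-1 + e^-2), about 0.408
```

Using negative distances as logits means that shrinking the distance to the correct class and enlarging the distances to the other classes both lower the loss, which is exactly what makes the argmin rule of Eq. (13) consistent with training.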

4 Experiment

4.1 Datasets
To validate the effectiveness of the proposed method, we conduct experiments on four action recognition datasets. Kinetics [20] is a large-scale action recognition dataset with 400 classes, each containing at least 400 videos. Following CMN [10], 64, 12, and 24 classes are used for meta-training, meta-validation, and meta-testing, respectively. The Something-Something V2 (SSv2) dataset [21] contains more than 100,000 videos with durations ranging from 2 to 6 s. Following OTAM [5], we use 64 classes as the meta-training set, 12 classes as the meta-validation set, and 24 classes as the meta-test set. UCF101 [22] contains 13,320 videos of 101 classes collected from YouTube. Following ARN [3], 70, 10, and 21 classes are used for meta-training, meta-validation, and meta-testing, respectively. HMDB51 [23] includes 6,849 videos of 51 classes. Following ARN [3], we use 31, 10, and 10 classes for meta-training, meta-validation, and meta-testing, respectively.

4.2 Implementation Details
In the preprocessing stage, T frames are sparsely sampled from each video, with T = 8 in our experiments. Following TSN [18], each frame is first resized to 256 × 256. During training, a 224 × 224 sub-image is randomly cropped from the resized image; during inference, center cropping is used instead of random cropping. ResNet-50 initialized with ImageNet [24] pre-trained weights is used for feature embedding. For the convolutional self-attention function in the sequence distance learning module, we set k and dk to 5 and 2048, respectively. For the graph neural network hθ, we set the number of GCN layers k1 to 2, with output dimensions 64 and 128, and the number of fully connected layers k2 to 2, with output dimensions 128 and 1. During training, the model is optimized end-to-end with SGD at a learning rate of 0.001. Finally, we report the average classification accuracy over 10,000 episodes randomly sampled from the meta-test set. The experiments are performed with the PyTorch framework on a system with four NVIDIA GeForce RTX 3090 GPUs.

4.3 Experimental Results

Table 1. Results in the 5-way 1-shot and 5-way 5-shot settings on Kinetics and SSv2.
Method | Kinetics 1-shot | Kinetics 5-shot | SSv2 1-shot | SSv2 5-shot
CMN [4] | 60.5 | 78.9 | 36.2 | 48.8
ARN [3] | 63.7 | 82.4 | - | -
OTAM [5] | 73.0 | 85.8 | 42.8 | 52.3
TRX [14] | 63.6 | 85.9 | 42.0 | 64.6
TA2N [11] | 72.8 | 85.8 | 47.6 | 61.0
MTFAN [6] | 74.6 | 87.4 | 45.7 | 60.4
Nguyen et al. [16] | 74.3 | 87.4 | 43.8 | 61.1
CSA-GNN (Ours) | 74.8 | 86.6 | 48.0 | 63.1

Results on Kinetics and SSv2. In order to verify the effectiveness of CSA-GNN, we first conduct experiments in the few-shot setting on the Kinetics and SSv2 datasets. We compare the experimental results with several recently proposed methods, including CMN [4], ARN [3], OTAM [5], TRX [14], TA2N [11], MTFAN [6], and Nguyen et al. [16]. As illustrated in Table 1, our method achieves state-of-the-art results in the 5-way 1-shot and 5-way 5-shot settings. On the SSv2 dataset, where the discrimination of actions in videos relies on temporal modeling (e.g., pushing something from right to left and pushing something from left to right), our method outperforms the temporal alignment-based methods (e.g., OTAM [5] and MTFAN [6]) by a relatively large margin. This means that the proposed method has superiority over the temporal alignment-based approaches in few-shot action recognition. Results on UCF101 and HMDB51. We also compare the performance of CSA-GNN with several state-of-the-art methods on the UCF101 and HMDB51 datasets. These

Convolutional Self-attention Guided Graph Neural Network


Table 2. Results (accuracy, %) in the 5-way 1-shot and 5-way 5-shot settings on UCF101 and HMDB51.

| Method | UCF101 1-shot | UCF101 5-shot | HMDB51 1-shot | HMDB51 5-shot |
|---|---|---|---|---|
| ARN [3] | 66.3 | 83.1 | 45.5 | 60.6 |
| TRX [14] | 81.7 | 96.1 | 52.1 | 75.6 |
| TA2N [11] | 81.9 | 95.1 | 59.7 | 73.9 |
| MTFAN [6] | 84.8 | 95.1 | 59.0 | 74.6 |
| Nguyen et al. [16] | 84.9 | 95.9 | 59.6 | 76.9 |
| CSA-GNN (Ours) | 85.2 | 95.0 | 59.7 | 74.9 |

methods include ARN [3], TRX [14], TA2N [11], MTFAN [6], and Nguyen et al. [16]. The classification accuracies in the 5-way 1-shot and 5-way 5-shot settings are presented in Table 2. In the 5-way 1-shot setting, CSA-GNN outperforms TRX [14] by 3.5% on UCF101 and 7.6% on HMDB51. It achieves the best 1-shot accuracy on UCF101 (85.2%) and matches the best result on HMDB51 (59.7%, tied with TA2N [11]). In the 5-way 5-shot setting, however, our method is slightly behind TRX [14] (by 1.1% on UCF101 and 0.7% on HMDB51) and behind Nguyen et al. [16]. Overall, CSA-GNN achieves state-of-the-art 1-shot performance and competitive 5-shot performance on the two datasets.

4.4 Ablation Study

In this subsection, ablation studies are performed on the SSv2 dataset to evaluate the effectiveness of the proposed method.

Table 3. Comparison with the baseline on SSv2 (accuracy, %).

| Method | 5-way 1-shot | 5-way 5-shot |
|---|---|---|
| Baseline | 37.6 | 49.8 |
| CSA-GNN (Ours) | 48.0 | 63.1 |

Baseline. To investigate the importance of the sequence distance learning module, we first introduce a baseline model. For a given video, we extract per-frame features from the sampled frames. We then apply average pooling along the temporal dimension to obtain a video-level feature vector. For each class, the feature vectors of that class's support videos are averaged to obtain a class prototype. Next, we compute the Euclidean distances between the query video and the class prototypes. The model is trained end-to-end with a cross-entropy loss over these distance scores. As shown in Table 3, CSA-GNN outperforms the baseline by 10.4% and 13.3% in the 5-way 1-shot and 5-way 5-shot settings, respectively. This verifies the effectiveness of


F. Pan et al.

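the sequence distance learning module for obtaining better action recognition results in the few-shot setting.

Table 4. Ablation results of the model components on SSv2 (accuracy, %). ✓ marks the components that are used.

| gβ | hθ | 5-way 1-shot | 5-way 5-shot |
|---|---|---|---|
|   | ✓ | 45.5 | 60.2 |
| ✓ |   | 45.1 | 57.4 |
| ✓ | ✓ | 48.0 | 63.1 |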
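The prototype baseline described in Sect. 4.4 works in four steps: temporal average pooling, per-class prototypes, Euclidean distances, and a cross-entropy loss. It can be sketched in plain Python as below. This is a minimal illustration under our own assumptions, not the authors' implementation:

- Short lists of floats stand in for per-frame backbone features.
- Negative Euclidean distances are used directly as softmax logits. This is an assumed choice; the text only says the loss is based on the distance scores.
- All function names are hypothetical.

```python
import math
from typing import Dict, List, Sequence

Vector = List[float]


def mean_vector(vectors: Sequence[Vector]) -> Vector:
    """Element-wise mean of a non-empty list of equal-length vectors."""
    n = len(vectors)
    return [sum(v[d] for v in vectors) / n for d in range(len(vectors[0]))]


def video_feature(frames: Sequence[Vector]) -> Vector:
    """Temporal average pooling: one vector per video from its per-frame features."""
    return mean_vector(frames)


def class_prototypes(support: Dict[str, List[Sequence[Vector]]]) -> Dict[str, Vector]:
    """Average the pooled features of each class's support videos."""
    return {label: mean_vector([video_feature(v) for v in videos])
            for label, videos in support.items()}


def class_probabilities(query: Sequence[Vector],
                        prototypes: Dict[str, Vector]) -> Dict[str, float]:
    """Softmax over negative Euclidean distances (assumed logits)."""
    q = video_feature(query)
    logits = {c: -math.dist(q, p) for c, p in prototypes.items()}
    top = max(logits.values())  # subtract the max for numerical stability
    exps = {c: math.exp(v - top) for c, v in logits.items()}
    total = sum(exps.values())
    return {c: e / total for c, e in exps.items()}


def cross_entropy(probs: Dict[str, float], true_label: str) -> float:
    """Cross-entropy loss of one query given its ground-truth class."""
    return -math.log(probs[true_label])


# Toy 2-way 1-shot episode with 2-dimensional "frame features".
support = {
    "left_to_right": [[[0.0, 1.0], [0.0, 1.0]]],
    "right_to_left": [[[1.0, 0.0], [1.0, 0.0]]],
}
query = [[0.1, 0.9], [0.0, 1.0]]
protos = class_prototypes(support)
probs = class_probabilities(query, protos)
loss = cross_entropy(probs, "left_to_right")
```

Because the pooling step discards frame order, reversing a query's frames leaves its prediction unchanged. This is the loss of temporal information that the sequence distance learning module is meant to address.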



Analysis of Model Components. We also analyze the components of the sequence distance learning module, namely the convolutional self-attention function gβ and the graph neural network hθ. The ablation results for each component in the 5-way 1-shot and 5-way 5-shot settings are summarized in Table 4.

First, the convolutional self-attention function is removed from the model, which is equivalent to setting gβ to the identity map. As Table 4 shows, removing it lowers the accuracy of our model by 2.5% in the 5-way 1-shot setting and 2.9% in the 5-way 5-shot setting.

Next, we verify the importance of the graph neural network for sequence distance learning. In this experiment, we keep the convolutional self-attention part and remove the graph neural network from the sequence distance learning module. For given sequences A and B, we first apply the convolutional self-attention function to obtain the sequences X and Y with Eq. (4). The distance between X and Y is then computed by:

$$\rho'(X, Y) = 1 - \frac{1}{T} \sum_{t=1}^{T} \frac{X[t]\,(Y[t])^{T}}{\|(X[t])^{T}\|_2 \, \|(Y[t])^{T}\|_2},$$

where the cosine similarity is used to compare the t-th frame vectors X[t] and Y[t]. We observe from Table 4 that when the graph neural network is removed, the accuracy of our model drops by 2.9% in the 5-way 1-shot setting and 5.7% in the 5-way 5-shot setting.
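The distance ρ′ used in this ablation can be computed directly from frame-aligned features. Below is a minimal pure-Python sketch. It assumes X and Y are already the outputs of gβ, given as lists of T frame vectors each. The `eps` guard against zero-norm vectors is our addition and is not part of the equation.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine_similarity(u: Vector, v: Vector, eps: float = 1e-12) -> float:
    """u v^T / (||u||_2 ||v||_2), with eps guarding against zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / max(norm, eps)


def sequence_distance(X: List[Vector], Y: List[Vector]) -> float:
    """rho'(X, Y) = 1 - (1/T) * sum_t cos(X[t], Y[t]) for equal-length sequences."""
    if len(X) != len(Y) or not X:
        raise ValueError("X and Y must be non-empty and have the same length T")
    T = len(X)
    return 1.0 - sum(cosine_similarity(x, y) for x, y in zip(X, Y)) / T


X = [[1.0, 0.0], [0.0, 1.0]]
print(sequence_distance(X, X))                         # 0.0: identical sequences
print(sequence_distance(X, [[1.0, 0.0], [1.0, 0.0]]))  # 0.5: one frame matches, one is orthogonal
```

Note that ρ′ compares frame t of X only with frame t of Y. In other words, it implicitly treats the two sequences as already temporally aligned, whereas the full model lets the graph neural network predict the distance instead.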

5 Conclusion

Existing state-of-the-art few-shot action recognition approaches can be roughly categorized into global feature-based methods and temporal alignment-based methods. Global feature-based methods may lose temporal information and exhibit a significant deviation in the learned distances. Temporal alignment-based methods, in turn, can suffer from high computational complexity, sensitivity to the alignment algorithm and its parameters, and difficulty in aligning sequences with complex temporal structures. In this work, we proposed the Convolutional Self-Attention Guided Graph Neural Network (CSA-GNN) for few-shot action recognition. Specifically, a convolutional self-attention function is first applied to the features of a few frames sparsely sampled from each video to capture long-term temporal dependencies. Then, a graph neural network predicts the distance between two videos explicitly, without estimating their temporal alignment. Experimental results show that CSA-GNN achieves state-of-the-art or competitive performance on four challenging benchmarks: Kinetics, SSv2, UCF101, and HMDB51.

Acknowledgements. This work is supported by the National Natural Science Foundation of China under Grants 62032011, 61972194, and 61772257.

References

1. Ou, Y., Mi, L., Chen, Z.: Object-relation reasoning graph for action recognition. In: IEEE Conference on Computer Vision and Pattern Recognition, pp. 20101–20110. IEEE (2022)
2. Xiang, W., Li, C., Wang, B., Wei, X., Hua, X.-S., Zhang, L.: Spatiotemporal self-attention modeling with temporal patch shift for action recognition. In: Avidan, S., Brostow, G., Cissé, M., Farinella, G.M., Hassner, T. (eds.) Computer Vision – ECCV 2022: 17th European Conference, Tel Aviv, Israel, October 23–27, 2022, Proceedings, Part III, pp. 627–644. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-20062-5_36
3. Zhang, H., Zhang, L., Qi, X., Li, H., Torr, P.H.S., Koniusz, P.: Few-shot action recognition with permutation-invariant attention. In: Proceedings of the European Conference on Computer Vision, Glasgow, UK (2020)
4. Zhu, L., Yang, Y.: Label independent memory for semi-supervised few-shot video classification. IEEE Trans. Pattern Anal. Mach. Intell. 44(1), 273–285 (2022)
5. Cao, K., Ji, J., Cao, Z., Chang, C., Niebles, J.C.: Few-shot video classification via temporal alignment. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, Seattle, WA, USA, pp. 10615–10624 (2020)
6. Wu, J., Zhang, T., Zhang, Z., Wu, F., Zhang, Y.: Motion-modulated temporal fragment alignment network for few-shot action recognition. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9141–9150. IEEE (2022)
7. Sakoe, H., Chiba, S.: Dynamic programming algorithm optimization for spoken word recognition. IEEE Trans. Acoust. Speech Signal Process. 26(1), 43–49 (1978)
8. Vaswani, A., et al.: Attention is all you need. In: Advances in Neural Information Processing Systems, pp. 5998–6008 (2017)
9. Li, S., et al.: Enhancing the locality and breaking the memory bottleneck of transformer on time series forecasting. In: Advances in Neural Information Processing Systems, pp. 5243–5253 (2019)
10. Zhu, Z., Wang, L., Guo, S., Wu, G.: A closer look at few-shot video classification: a new baseline and benchmark. In: Proceedings of the British Machine Vision Conference, BMVA Press (2021)
11. Li, S., et al.: TA2N: two-stage action alignment network for few-shot action recognition. In: Proceedings of the Thirty-Sixth AAAI Conference on Artificial Intelligence, AAAI Press (2022)
12. Fu, Y., Zhang, L., Wang, J., Fu, Y., Jiang, Y.: Depth guided adaptive meta-fusion network for few-shot video recognition. In: Proceedings of the 28th ACM International Conference on Multimedia, Seattle, WA, USA, pp. 1142–1151. ACM (2020)
13. Zhang, S., Zhou, J., He, X.: Learning implicit temporal alignment for few-shot video classification. In: Proceedings of the Thirtieth International Joint Conference on Artificial Intelligence, Montreal, Canada, pp. 1309–1315. ijcai.org (2021)
14. Perrett, T., Masullo, A., Burghardt, T., Mirmehdi, M., Damen, D.: Temporal-relational CrossTransformers for few-shot action recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 475–484. IEEE (2021)
15. Zhu, X., Toisoul, A., Pérez-Rúa, J., Zhang, L., Martínez, B., Xiang, T.: Few-shot action recognition with prototype-centered attentive learning. In: Proceedings of the British Machine Vision Conference, BMVA Press (2021)
16. Nguyen, K.D., Tran, Q.-H., Nguyen, K., Hua, B.-S., Nguyen, R.: Inductive and transductive few-shot video classification via appearance and temporal alignments. In: Avidan, S., Brostow, G., Cissé, M., Farinella, G.M., Hassner, T. (eds.) Computer Vision – ECCV 2022: 17th European Conference, Tel Aviv, Israel, October 23–27, 2022, Proceedings, Part XX, pp. 471–487. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-20044-1_27
17. Vinyals, O., Blundell, C., Lillicrap, T., Kavukcuoglu, K., Wierstra, D.: Matching networks for one shot learning. In: Advances in Neural Information Processing Systems, Barcelona, Spain, pp. 3630–3638 (2016)
18. Wang, L., Xiong, Y., Wang, Z., Qiao, Y., Lin, D., Tang, X., Gool, L.V.: Temporal segment networks: towards good practices for deep action recognition. In: Proceedings of the European Conference on Computer Vision, pp. 20–36 (2016)
19. Kipf, T.N., Welling, M.: Semi-supervised classification with graph convolutional networks. In: International Conference on Learning Representations (2017)
20. Kay, W., et al.: The Kinetics human action video dataset. CoRR abs/1705.06950 (2017)
21. Goyal, R., Kahou, S.E., Michalski, V., Materzynska, J., et al.: The "something something" video database for learning and evaluating visual common sense. In: Proceedings of the IEEE International Conference on Computer Vision, Venice, Italy, pp. 5843–5851 (2017)
22. Soomro, K., Zamir, A.R., Shah, M.: UCF101: a dataset of 101 human actions classes from videos in the wild. CoRR abs/1212.0402 (2012)
23. Kuehne, H., Jhuang, H., Garrote, E., Poggio, T.A., Serre, T.: HMDB: a large video database for human motion recognition. In: Proceedings of the IEEE International Conference on Computer Vision, Barcelona, Spain, pp. 2556–2563 (2011)
24. Deng, J., Dong, W., Socher, R., Li, L., Li, K., Li, F.: ImageNet: a large-scale hierarchical image database. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, Miami, Florida, USA, pp. 248–255 (2009)

DisGait: A Prior Work of Gait Recognition Concerning Disguised Appearance and Pose

Shouwang Huang(B), Ruiqi Fan, and Shichao Wu

College of Artificial Intelligence, Nankai University, Tianjin 300071, China
{2013128,2014076,wusc}@mail.nankai.edu.cn

Abstract. Gait recognition, which can identify humans at a distance without the subjects' cooperation, has significant security applications. As gait recognition moves closer to practical deployment, addressing the challenges posed by different scenarios becomes increasingly essential. This study focuses on the issue of disguises. Specifically, it introduces a novel benchmark gait dataset (DisGait) for investigating the performance of gait recognition methods under disguised appearance and pose, a setting not explored in contemporary gait research. The dataset consists of 3,200 sequences from 40 subjects under four walking conditions: normal clothing (NM), thick coats (CO), different types of shoes (SH), and a uniform white lab gown (GO). The primary distinction between this dataset and others lies in its varying degrees of appearance and pose disguise. We evaluated various state-of-the-art (SOTA) silhouette-based and skeleton-based gait recognition approaches on DisGait. None of them achieves a satisfactory score, with every average accuracy below 40%. This confirms that disguises of appearance and pose are a serious challenge for gait recognition. Moreover, we evaluated the influence of sequence length and extensive occlusions on gait recognition, analyzed the strengths and limitations of current SOTA approaches, and discussed potential avenues for improvement.

Keywords: Gait recognition · Disguise · Benchmark dataset · Security

1 Introduction

Gait recognition is a biometric identification technique based on human walking characteristics. With advances in computer vision and pattern recognition, it has become an important biometric technique with broad application prospects [8, 13]. In security, gait recognition is widely used for identity recognition and surveillance [11]. Compared with traditional biometric technologies such as fingerprint and face recognition, gait recognition has a distinct advantage for security: it can identify individuals without direct contact.

S. Huang and R. Fan contributed equally to this work and should be considered co-first authors.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 413–425, 2023.
https://doi.org/10.1007/978-981-99-4742-3_34


S. Huang et al.

With the continuous development of deep learning, gait recognition has made significant progress [14]. Some methods, such as GaitSet [1] and CSTL [4], have achieved high accuracy on the CASIA-B [19] dataset under normal walking conditions. In real-world applications, however, the accuracy of gait recognition is easily influenced by various factors. Existing datasets mainly focus on the impact of viewpoint, clothing changes, walking speed, and background. Few datasets consider the effect of occlusion [3], and almost none address substantial appearance disguises. Here, we define a disguise as a factor that significantly alters appearance or pose. Although many datasets include clothing changes, these changes do not noticeably alter appearance.

This paper introduces DisGait, a gait dataset that considers disguise factors. Participants walk under four conditions: wearing normal clothing, wearing a thick coat, wearing a different pair of shoes, and wearing a uniform white lab gown. The DisGait dataset comprises 3,200 sequences from 40 participants, each labeled with its walking condition. We conducted three experiments to evaluate state-of-the-art methods on the dataset and analyzed the main challenges of gait recognition in practical applications. We believe that DisGait will serve as a valuable resource for gait recognition research and provide further reference for gait recognition in realistic application scenarios. The main contributions of this paper are as follows:

1. We propose a public gait recognition dataset that considers appearance and pose disguises. Both model-based gait data (skeletons) and appearance-based gait data (silhouettes) are provided for evaluation.
2. We present gait recognition baselines for various disguise factors using existing leading gait recognition methods.
3. We analyze the advantages and shortcomings of current state-of-the-art methods and discuss possible directions for improvement.

The rest of this paper is organized as follows. Section 2 briefly reviews influential public gait recognition datasets. Section 3 describes the construction of DisGait, including the silhouette and skeleton extraction methods. Section 4 presents the overall pipeline of three SOTA gait recognition methods and the evaluation protocol of DisGait. Section 5 reports and discusses gait recognition results on DisGait. Finally, Sect. 6 concludes the paper.

2 Related Work

Table 1 summarizes the most commonly used gait recognition datasets, sorted by year of publication. The datasets from Soton [9] through OU-ISIR MVLP [15] have already been reviewed in ReSGait [10] or GREW [20] and are not described again here. The ReSGait dataset comprises 870 sequences from 172 subjects walking along unconstrained indoor curved paths. Its covariates include regular clothes, carrying items, and holding a cell phone. The GREW dataset includes 128,671 sequences from 26,345 subjects walking along unconstrained outdoor curved paths.

DisGait: A Prior Work of Gait Recognition Concerning


Table 1. Some popular public gait recognition datasets, sorted by release time. NM, CO, SH, and GO denote wearing normal clothing, thick coats, different types of shoes, and a uniform white lab gown, respectively. '-' denotes a non-fixed viewpoint.

| Dataset | Subjects | Seq | Views | Walking route | Dressing style |
|---|---|---|---|---|---|
| Soton [9] | 25 | 2,280 | 12 | line route | NM, SH |
| USF [12] | 20 | 1,870 | 2 | curved route | NM, SH |
| CASIA-B [19] | 124 | 13,640 | 11 | line route | NM, CO |
| CASIA-C [16] | 153 | 1,530 | 1 | line route | NM |
| TUM-IITKGP [3] | 35 | 840 | 1 | line route | NM, GO |
| OU-ISIR LP [5] | 4,007 | 7,842 | 4 | line route | NM |
| TUM-GAID [2] | 305 | 3,370 | 1 | line route | NM, SH |
| OU-ISIR MVLP [15] | 10,307 | 288,596 | 14 | line route | NM |
| ReSGait [10] | 172 | 870 | 1 | curved route | CO |
| GREW [20] | 26,345 | 128,671 | - | curved route | NM |
| Ours | 40 | 3,200 | 1 | curved route | NM, CO, SH, GO |

GREW's covariates consist of regular clothes, carrying items, and walking surfaces. Almost all of the datasets introduced in this section consider only normal clothes as a dressing covariate. Normal clothes are the subjects' own everyday clothing, which does not significantly change their appearance or body shape. None of these datasets considers disguises, i.e., factors that can severely change a subject's appearance or pose. In our dataset, the gait collection conditions are normal clothing, a thick coat, changed shoes, and a uniform white lab gown. The lab gown partially obscures the lower body but also prevents interference from personal clothing styles. Changing shoes can alter one's pose, while a thick coat can significantly affect both appearance and pose. Compared with previous datasets, the most distinctive feature of ours is therefore that it covers disguises of appearance and pose, which can substantially influence walking patterns.

3 The DisGait Dataset

3.1 Data Collection Protocol

The experimental scenario is shown in Fig. 1. Data collection occurs in a standard 6.6 × 6.6 m classroom with ample lighting, white walls or curtains, and a smooth tiled floor. The walking trajectory is a circle with a radius of about 3 m, centered in the room. The camera is placed in one corner of the room at a height of approximately 1.6 m, ensuring that all positions of the walking trajectory can be captured. We use the Azure Kinect sensor as our camera, with a resolution of 2560 × 1440 pixels and a frame rate of 30 fps.



Data collection took place from March 16 to 20, 2022, during the cold early spring when the outdoor temperature was around 0 °C. As a result, the subjects wore thick coats, which affect gait appearance more noticeably than ordinary coats.

Fig. 1. The experimental scenario. On the left is a schematic diagram of the experimental setup, walking paths, and camera locations, while on the right is a record of the actual scenes.

Before gait sequences were collected, each subject was asked to bring a thick coat and a second pair of shoes. After the experiment began, subjects started walking from a fixed starting point. We collected gait sequences under four walking conditions: wearing their thick coat, with the coat taken off (normal clothing), after changing shoes, and wearing a uniform white lab gown. Note that in the lab-gown condition, subjects kept the changed shoes on without altering them further. This allows the lab-gown condition to be compared with the shoe-changing condition. Each condition was recorded for one minute, approximately two to three laps, totaling 1,800 frames.

3.2 Data Statistics

We recruited 40 subjects for data collection, provided them with compensation, and obtained their permission to collect the data. The participants included 23 males and 17 females, a roughly balanced gender distribution. Their ages ranged from 20 to 30; all were university students without movement disorders or disabilities. Each subject walked under four different conditions, with 1,800 frames collected for each condition. We divided every 90 consecutive frames into a sequence, resulting in 20 sequences per condition. Therefore, the entire dataset consists of 40 × 4 × 20 = 3,200 sequences. In addition to the normal walking condition, our dataset includes three disguise conditions:

• Thick coats: Participants wore their own thick coats, which were generally heavier because of the cold outdoor temperature during data collection. On the one hand, the substantial thickness of the coats significantly alters the shape of the upper body, and long coats may also partially cover the lower body. On the other hand, the weight of the coat shifts the body's center of gravity, thereby modifying the individual's walking posture. Additionally, overly tight coats may restrict the range of motion of the body joints.
• Changing shoes: Participants brought an additional pair of shoes that differed from their original ones. Differences in sole friction, thickness, and hardness can significantly change a person's walking posture by altering the body's center of gravity and support points during walking. This makes changing shoes a notable factor in gait analysis.
• Uniform white lab gown: The laboratory provided a uniform white lab gown, sized according to the participant's height and weight. With its close-fitting upper part and covered lower body, the lab gown notably changes the individual's appearance, making it a common form of disguise. However, a uniform lab gown also removes the individual's personal clothing style. This excludes interference from clothing and may allow the algorithm to focus more on the gait itself.
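The segmentation arithmetic of Sect. 3.2 cuts 1,800 frames per condition into 90-frame chunks. It can be checked with a short sketch, shown below. The constant and function names are illustrative only. Dropping an incomplete tail chunk is our assumption, although it never applies here because 1,800 is a multiple of 90.

```python
from typing import List

FPS = 30                    # Azure Kinect frame rate (Sect. 3.1)
SECONDS_PER_CONDITION = 60  # one minute of walking per condition
SEQ_LEN = 90                # frames per sequence
NUM_SUBJECTS = 40
CONDITIONS = ("NM", "CO", "SH", "GO")


def split_into_sequences(num_frames: int, seq_len: int = SEQ_LEN) -> List[range]:
    """Cut frame indices 0..num_frames-1 into consecutive, non-overlapping chunks.

    An incomplete tail chunk would be dropped (assumption; it does not occur
    here because 1,800 is a multiple of 90).
    """
    return [range(start, start + seq_len)
            for start in range(0, num_frames - seq_len + 1, seq_len)]


frames_per_condition = FPS * SECONDS_PER_CONDITION
sequences = split_into_sequences(frames_per_condition)
total_sequences = NUM_SUBJECTS * len(CONDITIONS) * len(sequences)
print(frames_per_condition, len(sequences), total_sequences)  # 1800 20 3200
```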

Fig. 2. The RGB, skeleton, and silhouette images of some participants from the DisGait dataset. Each row represents a walking condition, from top to bottom: wearing a thick coat, wearing normal clothes, changing shoes, and wearing a uniform white lab gown. Each column in the figure represents a type of data for the same participant under the four walking conditions.

3.3 Gait Sequence Extraction

We use Detectron2 [18] to extract silhouette images. First, a model pre-trained on COCO [8] performs object detection and instance segmentation to obtain a mask of the person. Then, morphological operations from the OpenCV library are applied to the mask to remove the background and preserve the contour of the foreground object. Finally, the processed mask is converted into a binary image to obtain the silhouette. The silhouette images are 640 × 360 pixels.

Skeleton keypoints are extracted with the COCO-Keypoints model of the Detectron2 framework. This model uses a feature pyramid network (FPN [6]) to process features at different scales. It also uses a multi-task loss to jointly predict person bounding boxes and keypoint locations. The model yields 17 keypoints per person. The RGB, skeleton, and silhouette images of some participants are shown in Fig. 2.
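The text does not specify which OpenCV morphological operations, kernel sizes, or thresholds were applied. As an illustration only, the sketch below assumes three steps:

- a soft mask thresholded at 0.5;
- a 3 × 3 morphological opening (in OpenCV this would typically be `cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)`);
- conversion to a 0/255 silhouette.

It is written in plain Python so that it is self-contained. Its border handling treats outside pixels as background, which may differ from OpenCV's defaults. The function names are hypothetical.

```python
from typing import List

Mask = List[List[int]]


def binarize(prob: List[List[float]], threshold: float = 0.5) -> Mask:
    """Turn a soft instance mask into a 0/1 mask."""
    return [[1 if p > threshold else 0 for p in row] for row in prob]


def _filter3x3(mask: Mask, require_all: bool) -> Mask:
    """3x3 erosion (require_all=True) or dilation (require_all=False).

    Pixels outside the image are treated as background.
    """
    h, w = len(mask), len(mask[0])
    out = [[0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            window = [mask[i + di][j + dj] if 0 <= i + di < h and 0 <= j + dj < w else 0
                      for di in (-1, 0, 1) for dj in (-1, 0, 1)]
            out[i][j] = int(all(window)) if require_all else int(any(window))
    return out


def opening(mask: Mask) -> Mask:
    """Erosion then dilation: removes foreground specks that cannot contain a full 3x3 square."""
    return _filter3x3(_filter3x3(mask, require_all=True), require_all=False)


def to_silhouette(mask: Mask) -> Mask:
    """Map a 0/1 mask to a 0/255 binary silhouette image."""
    return [[255 if v else 0 for v in row] for row in mask]


# A 3x3 "person" blob plus one isolated noise pixel in a 7x7 soft mask.
prob = [[0.9 if 2 <= i <= 4 and 2 <= j <= 4 else 0.0 for j in range(7)] for i in range(7)]
prob[0][6] = 0.8
silhouette = to_silhouette(opening(binarize(prob)))
```

With these inputs, the isolated pixel is removed while the 3 × 3 blob survives intact.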

4 Baselines on DisGait

4.1 SOTA Methods

Three SOTA gait recognition approaches are selected to evaluate their performance on the DisGait dataset. For silhouette data, we choose GaitSet [1] and CSTL [4]. For skeleton data, we select GaitGraph [17], which is specifically designed for human joint data in gait recognition. The overall pipeline is shown in Fig. 3. The skeletons and silhouettes are obtained separately from the raw data using the Detectron2 [18] framework. The skeleton information is transformed into a matrix of size 51 × batch size (51 = 17 keypoints × 3), while the silhouette images are preprocessed and stacked. The resulting data are then fed into the corresponding neural network models, which extract a gait feature vector from each sequence.
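The figure of 51 is consistent with 17 COCO keypoints × 3 values each. Below is a minimal sketch of that flattening. It assumes each keypoint is stored as (x, y, confidence), the usual layout of Detectron2 keypoint predictions, and that each matrix column holds one frame. Whether GaitGraph consumes exactly this layout is not stated in the text, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple

# Assumed per-keypoint layout: (x, y, confidence score).
Keypoint = Tuple[float, float, float]
NUM_KEYPOINTS = 17  # COCO keypoint set


def frame_to_vector(keypoints: Sequence[Keypoint]) -> List[float]:
    """Flatten 17 (x, y, score) triples into one 51-dimensional vector."""
    if len(keypoints) != NUM_KEYPOINTS:
        raise ValueError(f"expected {NUM_KEYPOINTS} keypoints, got {len(keypoints)}")
    return [value for kp in keypoints for value in kp]


def stack_as_columns(frames: Sequence[Sequence[Keypoint]]) -> List[List[float]]:
    """Build a 51 x N matrix with one column per input frame."""
    columns = [frame_to_vector(f) for f in frames]
    return [[col[row] for col in columns] for row in range(3 * NUM_KEYPOINTS)]


frame_a = [(float(k), float(k) + 0.5, 0.9) for k in range(NUM_KEYPOINTS)]
frame_b = [(float(k) + 1.0, float(k), 0.8) for k in range(NUM_KEYPOINTS)]
matrix = stack_as_columns([frame_a, frame_b])
print(len(matrix), len(matrix[0]))  # 51 2
```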

Fig. 3. The pipeline of the DisGait dataset, including data preprocessing and feature extraction.

4.2 Evaluation Protocol

In this study, we adopt a subject-independent protocol and split the dataset into training and testing sets in a ratio similar to that of CASIA-B [19] (74:50). The training set consists of data from the first 23 subjects, and the testing set comprises data from the remaining 17 subjects, with no overlapping subjects. The training set contains a total of 1,840 sequences (23 × 4 × 20) covering all walking conditions. In the testing set, we use the first 4 of the 20 sequences under the normal walking condition as the gallery and the remaining 16 as probes. For consistency, each of the other three walking conditions contributes the same number of probes as the normal condition. We therefore select only their last 16 sequences as probes and, for rigor, leave their first four sequences unused.
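The split can be written out explicitly. The sketch below is a hypothetical helper that enumerates sequence identifiers under the protocol described above. Each identifier is (subject ID, condition, 1-based sequence index within the condition). The sketch confirms the stated count of 1,840 training sequences.

```python
from typing import Dict, List, Tuple

CONDITIONS = ("NM", "CO", "SH", "GO")
SEQS_PER_CONDITION = 20
SeqId = Tuple[int, str, int]  # (subject ID, condition, 1-based sequence index)


def build_protocol(num_subjects: int = 40, num_train: int = 23):
    """Subject-independent split with an NM gallery and last-16 probes (Sect. 4.2)."""
    train: List[SeqId] = [(s, c, k)
                          for s in range(1, num_train + 1)
                          for c in CONDITIONS
                          for k in range(1, SEQS_PER_CONDITION + 1)]
    test_subjects = range(num_train + 1, num_subjects + 1)
    # Gallery: first 4 NM sequences of each test subject.
    gallery: List[SeqId] = [(s, "NM", k) for s in test_subjects for k in range(1, 5)]
    # Probes: last 16 sequences of every condition; sequences 1-4 of CO/SH/GO stay unused.
    probes: Dict[str, List[SeqId]] = {
        c: [(s, c, k) for s in test_subjects for k in range(5, SEQS_PER_CONDITION + 1)]
        for c in CONDITIONS
    }
    return train, gallery, probes


train, gallery, probes = build_protocol()
print(len(train), len(gallery), {c: len(p) for c, p in probes.items()})
# 1840 68 {'NM': 272, 'CO': 272, 'SH': 272, 'GO': 272}
```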



5 Experiments

In this study, we conduct three experiments on the proposed DisGait dataset and provide a detailed analysis of the results. Experiments 1 and 2 use models trained on the CASIA-B [19] training set and test them on the entire DisGait dataset of 40 subjects. The test data are divided into gallery and probe sets as described in Sect. 4.2. Experiment 3 trains and tests on DisGait, split according to Sect. 4.2. The data used in the three experiments and their objectives are summarized in Table 2.

Table 2. Experimental design. Each row is an experiment; the columns give the subjects used for training and testing and the experimental objective. '-' means that no DisGait data are used for training (the models are trained on CASIA-B).

| Experiment | Train subjects | Test subjects | Goal |
|---|---|---|---|
| Exp1 | - | ID 1 to 40 | Baseline results and study of gait covariates |
| Exp2 | - | ID 1 to 40 | Study the impact of the number of frames |
| Exp3 | ID 1 to 23 | ID 24 to 40 | Study the impact of body parts |

5.1 Experiment 1: Baseline Results

In Experiment 1, we test the generalization ability of existing baseline methods on DisGait and analyze how the different disguise variables influence the results. We take GaitSet, CSTL, and GaitGraph models trained on the CASIA-B training set and test them on the entire DisGait dataset. Before testing, the DisGait images must be adapted to each model's input size. The original DisGait images are 640 × 360 pixels. For GaitSet, we crop the images to 128 × 128 and then resize them to 64 × 64; for CSTL, we crop the images to 128 × 88. Note that the cropping does not discard any silhouette information. For GaitGraph, we directly use the skeleton data obtained in Sect. 3.3. The results are shown in Table 3.

Table 3. Rank-1 accuracy (%) of the three models under the four walking conditions, the average rank-1 accuracy across all conditions, and the ranking of the conditions by accuracy.

| Method | NM | CO | SH | GO | Avg | Ranking |
|---|---|---|---|---|---|---|
| GaitGraph | 12.66 | 8.59 | 7.81 | 10.31 | 9.84 | NM > GO > CO > SH |
| GaitSet | 62.50 | 17.00 | 27.50 | 41.50 | 37.13 | NM > GO > SH > CO |
| CSTL | 43.00 | 11.00 | 22.50 | 33.50 | 27.50 | NM > GO > SH > CO |
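The Avg column and the rankings in Table 3 follow directly from the per-condition values. A quick check is given below, with the numbers copied from Table 3; the dictionary and helper names are ours.

```python
from typing import Dict, Tuple

TABLE3: Dict[str, Dict[str, float]] = {
    "GaitGraph": {"NM": 12.66, "CO": 8.59, "SH": 7.81, "GO": 10.31},
    "GaitSet": {"NM": 62.50, "CO": 17.00, "SH": 27.50, "GO": 41.50},
    "CSTL": {"NM": 43.00, "CO": 11.00, "SH": 22.50, "GO": 33.50},
}


def summarize(row: Dict[str, float]) -> Tuple[float, str]:
    """Mean accuracy over conditions and the conditions ordered from best to worst."""
    mean = sum(row.values()) / len(row)
    ranking = " > ".join(sorted(row, key=row.get, reverse=True))
    return mean, ranking


for method, row in TABLE3.items():
    mean, ranking = summarize(row)
    print(f"{method}: avg {mean:.4f}, {ranking}")
```

The unrounded means are 9.8425, 37.125, and 27.5. These agree with the reported averages after rounding half up to two decimals. The computed rankings match the last column of Table 3.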



All three disguise conditions yield lower accuracy than the non-disguised (NM) condition, indicating that each disguise affects existing recognition methods to some degree. Specifically, the accuracy ranking for the skeleton-based method GaitGraph is NM > GO > CO > SH. This suggests that:
• Changing shoes yields the lowest accuracy, indicating that it has a significant impact on a person's pose;
• Under the CO condition, the skeleton-based method is fairly robust despite substantial changes in appearance. Because CO affects gait less than SH does, it yields higher accuracy than SH;
• In the GO condition, participants' gait differs from NM because they have changed shoes. However, the uniform clothing eliminates interference from personal attire on skeleton extraction, yielding higher accuracy than the other disguise conditions.
For the silhouette-based methods GaitSet and CSTL, the accuracy ranking is NM > GO > SH > CO. This suggests that:
• CO affects not only pose but also, more strongly, appearance and silhouette, resulting in the lowest accuracy for the silhouette-based methods;
• SH considerably affects pose but does not alter appearance, so it achieves higher accuracy than CO;
• In the GO condition, participants' gait is still modified by the change of shoes. However, the uniform white lab gowns mitigate interference from attire, which may let the algorithm concentrate on gait, giving higher accuracy than the other disguise conditions.
Comparing the two types of methods, the skeleton-based method ranks CO above SH, whereas the silhouette-based methods rank SH above CO. This indicates that the skeleton-based method is somewhat robust to changes in appearance, while silhouette-based methods are more easily affected by them.
The average accuracy ranking of the three methods is GaitSet > CSTL > GaitGraph. On the one hand, this indicates that silhouette-based gait recognition methods still outperform skeleton-based methods overall. On the other hand, CSTL is more accurate than GaitSet on CASIA-B but less accurate on DisGait, suggesting that its generalization ability may be weaker than that of GaitSet. Overall, the test scores are relatively low, indicating that the generalization ability of existing methods needs to be improved. This also suggests that the dataset can help promote the development of gait recognition.
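Rank-1 accuracy counts a probe as correct when its nearest gallery sequence belongs to the same subject. The text does not state the distance measure. The sketch below therefore assumes Euclidean distance between sequence-level feature vectors, and `rank1_accuracy` is a hypothetical helper. Individual methods may compute distances differently, for example over several part-level features.

```python
import math
from typing import List, Sequence, Tuple

Sample = Tuple[int, Sequence[float]]  # (subject ID, feature vector)


def rank1_accuracy(gallery: List[Sample], probes: List[Sample]) -> float:
    """Percentage of probes whose nearest gallery feature has the same subject ID."""
    if not gallery or not probes:
        raise ValueError("gallery and probes must be non-empty")
    correct = 0
    for probe_id, probe_feat in probes:
        nearest_id, _ = min(((gid, math.dist(probe_feat, gfeat)) for gid, gfeat in gallery),
                            key=lambda pair: pair[1])
        correct += int(nearest_id == probe_id)
    return 100.0 * correct / len(probes)


gallery = [(1, [0.0, 0.0]), (2, [10.0, 10.0])]
probes = [(1, [1.0, 0.0]), (2, [9.0, 9.0]), (1, [8.0, 8.0])]  # last probe is misidentified
print(rank1_accuracy(gallery, probes))
```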



5.2 Experiment 2: Different Frames

The purpose of Experiment 2 is to investigate how the number of frames per sequence affects the test results of the three SOTA methods. In practical gait recognition applications, captured gait sequences are usually short and contain few frames, which may significantly affect the results. To explore this issue, we test on DisGait, whose sequences contain 90 frames each. We test sequence lengths starting from 1 frame and increasing by 6 frames at a time, 15 lengths in total. The models, test set, and preprocessing methods are the same as in Experiment 1. Figure 4 shows the results. Overall, accuracy increases with the number of frames for every method, but the trends differ.

Specifically, GaitSet reaches relatively high accuracy at around 24 frames, and later gains are minimal, with slight fluctuations. This is consistent with GaitSet treating the frames of a gait sequence as an unordered set, without considering temporal information. In this dataset, one gait cycle spans approximately 20–25 frames. When a sequence is shorter than one gait cycle, adding frames significantly improves recognition accuracy. Once it exceeds one gait cycle, additional frames do not improve accuracy, which is consistent with the results reported in the original GaitSet work [1].
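Read literally, the schedule "start from 1 frame, add 6 frames at a time, 15 tests" gives the lengths below. The longest tested sequence would then be 85 of the 90 available frames. The sketch also assumes that each shortened sequence keeps its first n frames. The text does not say which frames are kept (a randomly placed contiguous window would be an equally plausible reading), and the helper names are hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def tested_lengths(start: int = 1, step: int = 6, count: int = 15) -> List[int]:
    """Sequence lengths evaluated in Experiment 2: start, start + step, ..."""
    return [start + step * k for k in range(count)]


def truncate(sequence: Sequence[T], n: int) -> List[T]:
    """Keep the first n frames (which frames were kept is not stated; assumed)."""
    if not 1 <= n <= len(sequence):
        raise ValueError(f"n must be in [1, {len(sequence)}], got {n}")
    return list(sequence[:n])


lengths = tested_lengths()
print(lengths)  # [1, 7, 13, 19, 25, 31, 37, 43, 49, 55, 61, 67, 73, 79, 85]
clips = [truncate(range(90), n) for n in lengths]  # frame indices of each shortened clip
```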

Fig. 4. Rank-1 accuracies of GaitSet, CSTL, and GaitGraph on DisGait under constrained sequence lengths. The accuracy values are averaged over all conditions.

The accuracy of CSTL increases steadily over the entire range of tested frame counts. CSTL uses a two-branch network to learn context-sensitive temporal features, with one branch for temporal aggregation and another for salient spatial features. Because it places more emphasis on temporal information, its test accuracy keeps rising as more frames are provided.


S. Huang et al.

GaitGraph's accuracy increases gradually over the first 50 frames and then saturates. Approximately 50 frames correspond to two gait cycles, indicating that skeleton-based methods require relatively long gait sequences to extract sufficient gait features. For practical gait recognition applications, especially in security, this experiment provides an additional criterion for model evaluation: a model should achieve good results with as few frames as possible. By this criterion, GaitSet performs best among the three methods.

5.3 Experiment 3: Different Body Parts
In most practical gait recognition scenarios, complete full-body gait information cannot be obtained because of occlusion by objects or other people, which may be the main factor limiting the widespread application of gait recognition. We therefore designed a body-part truncation experiment in which full-body, upper-body-only, and lower-body-only gait information is used for training and testing. For upper-body silhouette data, we set the lower half of the image to background, and vice versa for the lower body. For skeleton data, based on the COCO keypoint layout, we use nodes 0–11 as the upper body and nodes 12–16 as the lower body. We train on the DisGait training set (23 subjects) and test on the test set (17 subjects); the results are shown in Table 4. In terms of the average accuracy over the four probes, shown in the top-left subfigure of Fig. 5, all three methods rank full-body > upper-body > lower-body. For GaitGraph, the models trained on upper-body and full-body data perform comparably (61.09% vs. 63.29%), whereas for GaitSet and CSTL, models trained on partial-body information are significantly less accurate than those trained on full-body information.
This finding suggests that skeleton-based methods are more robust to lower-body occlusion, whereas silhouette-based methods are more sensitive to occlusion. Furthermore, GaitGraph achieves significantly higher average accuracy with upper-body data than with lower-body data (61.09% vs. 33.25%). GaitSet also achieves higher average accuracy with upper-body than with lower-body data, but the margin is relatively small (74.12% vs. 67.94%), and CSTL achieves similar accuracy for the two (57.65% vs. 55.00%). This result may indicate that GaitGraph cannot yet extract rich gait information from the lower limbs, whereas GaitSet and CSTL extract gait features from the upper and lower limbs comparably well. It also suggests a possible direction of improvement for skeleton-based gait recognition methods.
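The body-part truncation used in Experiment 3 can be sketched for both input types. This is a minimal illustration under assumptions: silhouettes are (T, H, W) arrays with background 0, the image is split at its vertical midpoint, and excluded joints are dropped rather than zeroed. The joint split follows the text; `PART_NODES`, `mask_silhouette`, and `select_joints` are hypothetical names.

```python
import numpy as np

# Joint split stated in the text for COCO-format skeletons (17 keypoints).
# In the standard COCO ordering, index 11 is the left hip and index 12 the right hip.
PART_NODES = {
    "full": list(range(17)),
    "upper": list(range(0, 12)),   # nodes 0-11
    "lower": list(range(12, 17)),  # nodes 12-16
}

def mask_silhouette(frames, part):
    """Set the excluded half of each (T, H, W) silhouette frame to background (0)."""
    if part not in PART_NODES:
        raise ValueError(f"unknown body part: {part!r}")
    out = frames.copy()
    half = frames.shape[1] // 2        # assumed: split at the vertical midpoint of the image
    if part == "upper":
        out[:, half:, :] = 0           # lower half becomes background
    elif part == "lower":
        out[:, :half, :] = 0           # upper half becomes background
    return out

def select_joints(keypoints, part):
    """Keep only the joints of one body part from a (T, 17, C) skeleton sequence."""
    return keypoints[:, PART_NODES[part], :]

silhouettes = np.ones((30, 64, 44))
skeleton = np.zeros((30, 17, 3))       # e.g. (x, y, confidence) per joint
print(mask_silhouette(silhouettes, "upper").sum(), select_joints(skeleton, "lower").shape)
```

Dropping joints changes the graph size seen by a skeleton model; zeroing them instead would keep the input shape fixed. The text does not say which choice was made.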


Table 4. All results of Experiment 3 (rank-1 accuracy, %)

Method     Body part        NM     CO     SH     GO    Avg  Ranking
GaitSet    full-body     96.41  84.71  89.41  88.24  89.69  NM > SH > GO > CO
GaitSet    upper-body    90.59  55.29  77.65  72.94  74.12  NM > SH > GO > CO
GaitSet    lower-body    82.35  56.47  70.59  62.35  67.94  NM > SH > GO > CO
CSTL       full-body     89.71  63.97  88.24  72.06  78.50  NM > SH > GO > CO
CSTL       upper-body    74.12  37.65  63.53  55.29  57.65  NM > SH > GO > CO
CSTL       lower-body    75.29  44.71  51.76  48.24  55.00  NM > SH > GO > CO
GaitGraph  full-body     79.33  54.81  61.79  57.21  63.29  NM > SH > GO > CO
GaitGraph  upper-body    77.04  63.94  49.53  53.85  61.09  NM > CO > GO > SH
GaitGraph  lower-body    45.67  27.88  33.96  25.48  33.25  NM > SH > CO > GO

Comparing the accuracy of the four probes, the silhouette-based methods (GaitSet and CSTL) show a ranking of NM > SH > GO > CO for full-body, upper-body, and lower-body data, as shown in the top-right and bottom-left subfigures of Fig. 5. Accuracy is lowest for subjects wearing thick coats, indicating that silhouette-based methods are easily affected by appearance changes. In contrast, as shown in the bottom-right subfigure of Fig. 5, the skeleton-based method trained on upper-body information shows an NM > CO > GO > SH ranking, reasonably consistent with Experiment 1. Its accuracy with thick coats is higher than under the other two disguise conditions, reflecting the robustness of skeleton-based methods to appearance changes. However, the lower-body results show a ranking of NM > SH > CO > GO, which differs from the upper-body results. This indicates, to some extent, that skeleton-based methods do not pay enough attention to lower-body gait features and therefore do not fully exploit their advantages.

In summary, comparing the performance of different methods on the truncated DisGait data shows that the skeleton-based method is somewhat robust to lower-body occlusion and appearance changes, reflecting the potential of skeleton-based methods and supporting the wider practical application of gait recognition. At the same time, the experiment shows that GaitGraph is insufficient at extracting skeleton features from the lower body, which points to a possible direction of improvement for skeleton-based gait recognition methods. In addition, with practical applications in mind, this experiment offers a new perspective for evaluating gait recognition models: not only their accuracy on full-body data but also their robustness under large-area occlusion.
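The Avg and Ranking columns of Table 4 follow directly from the per-probe accuracies, so they can be re-derived as a quick consistency check on the rankings discussed above. The dictionary simply copies Table 4; `mean_accuracy` and `probe_ranking` are illustrative helper names.

```python
# Per-probe rank-1 accuracies (%) copied from Table 4.
TABLE4 = {
    "GaitSet": {
        "full-body": {"NM": 96.41, "CO": 84.71, "SH": 89.41, "GO": 88.24},
        "upper-body": {"NM": 90.59, "CO": 55.29, "SH": 77.65, "GO": 72.94},
        "lower-body": {"NM": 82.35, "CO": 56.47, "SH": 70.59, "GO": 62.35},
    },
    "CSTL": {
        "full-body": {"NM": 89.71, "CO": 63.97, "SH": 88.24, "GO": 72.06},
        "upper-body": {"NM": 74.12, "CO": 37.65, "SH": 63.53, "GO": 55.29},
        "lower-body": {"NM": 75.29, "CO": 44.71, "SH": 51.76, "GO": 48.24},
    },
    "GaitGraph": {
        "full-body": {"NM": 79.33, "CO": 54.81, "SH": 61.79, "GO": 57.21},
        "upper-body": {"NM": 77.04, "CO": 63.94, "SH": 49.53, "GO": 53.85},
        "lower-body": {"NM": 45.67, "CO": 27.88, "SH": 33.96, "GO": 25.48},
    },
}

def mean_accuracy(probes):
    """Unweighted mean over the four probe conditions (the Avg column)."""
    return sum(probes.values()) / len(probes)

def probe_ranking(probes):
    """Probe names ordered from highest to lowest accuracy (the Ranking column)."""
    return " > ".join(sorted(probes, key=probes.get, reverse=True))

for method, parts in TABLE4.items():
    for part, probes in parts.items():
        print(f"{method:<10} {part:<11} avg={mean_accuracy(probes):6.2f}  {probe_ranking(probes)}")
```

The means agree with the Avg column to two decimals, and the orderings reproduce the rankings quoted in the text, including the two GaitGraph rows that differ from the silhouette-based methods.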


Fig. 5. The top-left subfigure offers the average rank-1 accuracy of the three methods using different body parts. The top-right subfigure shows the rank-1 accuracy of GaitSet using different body parts under different walking conditions. The bottom-left subfigure shows the related results of CSTL. The bottom-right subfigure shows the related results of GaitGraph.

6 Conclusion
This paper presents a new gait dataset, DisGait, which focuses on disguises. Subjects use three types of disguise (thick coats, changed shoes, and a uniform white lab gown) while walking along a fixed curve. The dataset disguises appearance and pose to varying degrees, posing challenges to existing gait recognition methods. To evaluate existing gait recognition methods on our dataset, we designed three experiments that investigate, respectively, the generalization performance of existing models on DisGait, the impact of the number of sequence frames, and the effect of large-area body occlusion on gait recognition. The experimental results show that the average rank-1 accuracy of existing models on DisGait does not exceed 40%, reflecting the challenging nature of the dataset. Furthermore, by varying the number of sequence frames and the occluded body parts, we compared the characteristics and shortcomings of different methods, providing further guidance for model improvement and for practical applications of gait recognition. Future work will focus on improving skeleton-based methods to address the limitations of current approaches, with particular emphasis on feature extraction for the lower limbs.


References
1. Chao, H., He, Y., Zhang, J., Feng, J.: GaitSet: regarding gait as a set for cross-view gait recognition (2018)
2. Hofmann, M., Geiger, J., Bachmann, S., Schuller, B., Rigoll, G.: The TUM Gait from Audio, Image and Depth (GAID) database: multimodal recognition of subjects and traits. J. Vis. Commun. Image Represent. 25(1), 195–206 (2014)
3. Hofmann, M., Sural, S., Rigoll, G.: Gait recognition in the presence of occlusion: a new dataset and baseline algorithms (2011)
4. Huang, X., et al.: Context-sensitive temporal feature learning for gait recognition. In: 2021 IEEE/CVF International Conference on Computer Vision (ICCV), pp. 12889–12898 (2021)
5. Iwama, H., Okumura, M., Makihara, Y., Yagi, Y.: The OU-ISIR gait database comprising the large population dataset and performance evaluation of gait recognition. IEEE Trans. Inf. Forensics Secur. 7(5), 1511–1521 (2012)
6. Lin, T.Y., Dollár, P., Girshick, R., He, K., Hariharan, B., Belongie, S.: Feature pyramid networks for object detection. In: 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 936–944 (2017)
7. Lin, T.Y., et al.: Microsoft COCO: common objects in context. In: European Conference on Computer Vision (2014)
8. Makihara, Y., Nixon, M.S., Yagi, Y.: Gait recognition: databases, representations, and applications, pp. 1–13. Springer, Cham (2020)
9. Matovski, D., Nixon, M., Mahmoodi, S., Carter, J.: The effect of time on gait recognition performance. IEEE Trans. Inf. Forensics Secur. 7(2), 543–552 (2012)
10. Mu, Z., et al.: ResGait: the real-scene gait dataset. In: 2021 IEEE International Joint Conference on Biometrics (IJCB), pp. 1–8 (2021)
11. Rafi, M., Raviraja, S., Wahidabanu, R.: Gait recognition: a biometric for security, October 2009
12. Sarkar, S., Phillips, P., Liu, Z., Vega, I., Grother, P., Bowyer, K.: The HumanID gait challenge problem: data sets, performance, and analysis. IEEE Trans. Pattern Anal. Mach. Intell. 27(2), 162–177 (2005)
13. Sepas-Moghaddam, A., Etemad, A.: Deep gait recognition: a survey. IEEE Trans. Pattern Anal. Mach. Intell. 45(1), 264–284 (2023)
14. Shen, C., Yu, S., Wang, J., Huang, G.Q., Wang, L.: A comprehensive survey on deep gait recognition: algorithms, datasets and challenges (2022)
15. Takemura, N., Makihara, Y., Muramatsu, D., Echigo, T., Yagi, Y.: Multi-view large population gait dataset and its performance evaluation for cross-view gait recognition. IPSJ Trans. Comput. Vis. Appl. 10, 4 (2018)
16. Tan, D., Huang, K., Yu, S., Tan, T.: Efficient night gait recognition based on template matching. In: 18th International Conference on Pattern Recognition (ICPR'06), vol. 3, pp. 1000–1003 (2006)
17. Teepe, T., Khan, A., Gilg, J., Herzog, F., Hörmann, S., Rigoll, G.: GaitGraph: graph convolutional network for skeleton-based gait recognition. In: 2021 IEEE International Conference on Image Processing (ICIP), pp. 2314–2318 (2021)
18. Wu, Y., Kirillov, A., Massa, F., Lo, W.Y., Girshick, R.: Detectron2 (2019). https://github.com/facebookresearch/detectron2
19. Yu, S., Tan, D., Tan, T.: A framework for evaluating the effect of view angle, clothing and carrying condition on gait recognition. In: 18th International Conference on Pattern Recognition (ICPR 2006), vol. 4, pp. 441–444 (2006)
20. Zhu, Z., et al.: Gait recognition in the wild: a benchmark. In: 2021 IEEE/CVF International Conference on Computer Vision (ICCV), pp. 14769–14779 (2021)

Image Processing

An Experimental Study on MRI Denoising with Existing Image Denoising Methods

Guang Yi Chen1(B), Wenfang Xie2, and Adam Krzyzak1

1 Department of Computer Science and Software Engineering, Concordia University, Montreal, QC H3G 1M8, Canada
{guang_c,krzyzak}@cse.concordia.ca
2 Department of Mechanical, Industrial and Aerospace Engineering, Concordia University, Montreal, QC H3G 1M8, Canada

Abstract. In this paper, we perform a systematic study of existing 2D denoising methods for reducing noise in magnetic resonance imaging (MRI). We conduct experiments on six MRI images with the following eight denoising methods: wiener2, wavelet-based denoising, bivariate shrinkage (BivShrink), SURELET, Non-local Means (NLM), block matching and 3D filtering (BM3D), denoising convolutional neural networks (DnCNN), and weighted nuclear norm minimization (WNNM). Our experiments show that BM3D and WNNM are the two best methods for MRI denoising. However, WNNM is the slowest in terms of CPU time, so BM3D is the preferable choice for MRI denoising.
Keywords: Magnetic resonance imaging (MRI) · Denoising · Block matching and 3D filtering (BM3D) · Denoising convolutional neural networks (DnCNN)

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 429–437, 2023.
https://doi.org/10.1007/978-981-99-4742-3_35

1 Introduction
Magnetic resonance imaging (MRI) of the brain plays a very important role in diagnosing neurological conditions such as Parkinson's disease, Alzheimer's disease (AD), brain tumors, and stroke. Studying MRI images can help surgeons make sound decisions. However, noise degrades MRI image quality, which negatively affects subsequent processing and analysis tasks such as registration, segmentation, classification, and visualization. To obtain reliable analysis results, MRI images must be denoised before further processing. MRI denoising is therefore an important research topic in the literature. MRI images are often corrupted by random noise, which limits the accuracy of any quantitative measurements derived from the data. Given a noisy MRI image, it is desirable to find a method that reduces the noise effectively while remaining fast in terms of CPU time. Many denoising methods have been published so far. Which one is best for MRI denoising? In this paper, we perform a comparative study of MRI denoising with eight denoising methods. Our experiments show that block matching and 3D filtering (BM3D) and weighted nuclear norm minimization (WNNM) achieve the best denoising results. However, BM3D is much faster than WNNM, so BM3D should be adopted for MRI denoising. The rest of this paper is organized as follows. In the next section, we conduct a systematic study of MRI denoising with eight image denoising methods and six MRI images. We then draw conclusions and propose future research directions.

2 An Experimental Study
In this section, we conduct MRI denoising experiments with eight image denoising methods. Additive Gaussian white noise (AGWN) is added to the clean MRI images, and denoising is then performed. The noise standard deviation (STD) σn in our experiments takes the values σn ∈ {10, 20, 30, 40, 50, 60, 70, 80, 90, 100}. The denoising methods studied in this paper are wiener2 [1], wavelet-based denoising [2], bivariate shrinkage (BivShrink [3]), SURELET [4], Non-local Means (NLM [5]), BM3D [6], denoising convolutional neural networks (DnCNN [7]), and WNNM [8]. Tables 1, 2, 3, 4, 5 and 6 show the peak signal-to-noise ratio (PSNR) of the different denoising methods for the images Abdomen, Feet, Head, Pelvis, Thigh, and Thorax, respectively. All these images were downloaded from https://www.nlm.nih.gov/research/visible/mri.html. The best results in these tables are highlighted in bold font. The PSNR is defined as

$$\mathrm{PSNR}(A,B) = 10\log_{10}\left(\frac{M \times N \times 255^{2}}{\sum_{i,j}\left(B(i,j) - A(i,j)\right)^{2}}\right)$$

where M × N is the number of pixels in the image, and A and B are the noise-free image and the noisy or denoised image, respectively. Based on our experiments, BM3D and WNNM are the two best methods for MRI denoising. However, WNNM is the slowest in terms of CPU time, so BM3D is preferable for MRI denoising. Even though deep CNN methods are very popular at present, DnCNN does not perform very well for MRI denoising in our experiments. It is a reasonably good denoising method, but it does not outperform BM3D or WNNM for MRI denoising.
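The noise model and the PSNR formula above can be written as a short NumPy sketch. It assumes 8-bit intensities (peak value 255, as in the formula) and does not clip the noisy image to [0, 255], since the text does not say whether clipping was applied; `add_gaussian_noise` and `psnr` are illustrative names, and the random test image is a placeholder, not one of the Visible Human MRI slices.

```python
import numpy as np

def add_gaussian_noise(image, sigma, rng=None):
    """Add zero-mean white Gaussian noise with standard deviation sigma (no clipping)."""
    rng = np.random.default_rng() if rng is None else rng
    return np.asarray(image, dtype=np.float64) + rng.normal(0.0, sigma, size=np.shape(image))

def psnr(clean, test, peak=255.0):
    """PSNR(A, B) = 10 log10(M * N * peak^2 / sum_ij (B(i, j) - A(i, j))^2), in dB."""
    a = np.asarray(clean, dtype=np.float64)
    b = np.asarray(test, dtype=np.float64)
    sse = np.sum((b - a) ** 2)           # sum of squared errors over all M x N pixels
    return 10.0 * np.log10(a.size * peak ** 2 / sse)

rng = np.random.default_rng(0)
clean = rng.uniform(0.0, 255.0, size=(256, 256))   # placeholder image
for sigma in range(10, 101, 10):
    noisy = add_gaussian_noise(clean, sigma, rng)
    print(sigma, round(psnr(clean, noisy), 2))
```

Without clipping, the noisy PSNR is close to 20 log10(255/σn), about 28.13 dB at σn = 10, which is close to the Noisy rows of Tables 1–6.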


Table 1. The PSNR (dB) of different denoising methods for Abdomen with different noise levels. The best results are highlighted in bold font.

Method / σn           10     20     30     40     50     60     70     80     90    100
Noisy              28.15  22.13  18.61  16.11  14.18  12.59  11.25  10.09   9.07   8.15
Wieners            29.19  27.40  25.90  24.60  23.43  22.37  21.39  20.49  19.66  18.89
Wavelets           30.11  27.12  25.27  24.10  23.38  22.83  22.39  22.03  21.74  21.53
BivShrink          32.85  29.35  27.46  26.22  25.29  24.59  24.05  23.57  23.15  22.81
SURELET            31.03  28.43  26.79  25.66  24.85  24.19  23.67  23.24  22.86  22.52
NLM                32.39  28.90  26.74  25.28  24.20  23.34  22.60  21.94  21.33  20.78
BM3D               33.21  29.97  28.09  26.61  25.70  25.00  24.44  23.94  23.51  23.13
DnCNN              32.98  29.66  27.52  25.96  24.70  23.58  22.54  21.56  20.62  19.76
WNNM               33.10  29.94  28.02  26.70  25.69  24.98  24.32  23.77  23.31  22.89

Table 2. The PSNR (dB) of different denoising methods for Feet with different noise levels. The best results are highlighted in bold font.

Method / σn           10     20     30     40     50     60     70     80     90    100
Noisy              28.15  22.13  18.61  16.11  14.18  12.59  11.29  10.09   9.07   8.15
Wieners            32.91  29.57  27.20  25.36  23.87  22.62  21.53  20.57  19.71  18.92
Wavelets           33.23  29.37  27.26  25.65  24.46  23.68  23.10  22.57  22.06  21.68
BivShrink          36.12  32.12  29.94  28.49  27.40  26.49  25.72  25.05  24.48  23.98
SURELET            35.44  31.53  29.37  27.89  26.80  25.93  25.22  24.60  24.07  23.60
NLM                35.49  31.90  29.23  27.30  25.77  24.50  23.41  22.47  21.64  20.91
BM3D               38.08  34.25  32.06  30.40  29.32  28.51  27.83  27.17  26.63  26.15
DnCNN              34.47  29.46  26.40  24.14  22.35  20.85  19.56  18.44  17.46  16.59
WNNM               38.29  34.48  32.25  30.70  29.54  28.54  27.79  27.19  26.67  26.20

We also provide figures to support our findings. Figure 1 displays the six MRI images used in the experiments. Figures 2, 3, 4, 5, 6 and 7 show the images denoised by the different methods for Abdomen, Feet, Head, Pelvis, Thigh, and Thorax, respectively. As these figures show, BM3D and WNNM produce the best denoising results for all six MRI images.


We measure the CPU time of all eight denoising methods on the Abdomen image (256 × 256 pixels). Our experiments are implemented in Matlab under the Linux operating system on an Intel(R) Xeon(R) CPU E5-2697 v2 at 2.70 GHz with 131 GB of random-access memory (RAM). The times taken by the different denoising methods are: wiener2 (0.1547 s), wavelet-based denoising (3.7856 s), BivShrink (0.5315 s), SURELET (0.4007 s), Non-local Means (13.4964 s), BM3D (0.4656 s), DnCNN (4.6286 s), and WNNM (107.8025 s). Although WNNM gives good denoising results, it is too slow.

Table 3. The PSNR (dB) of different denoising methods for Head with different noise levels. The best results are highlighted in bold font.

Method / σn           10     20     30     40     50     60     70     80     90    100
Noisy              28.15  22.13  18.61  16.11  14.18  12.59  11.25  10.09   9.07   8.15
Wieners            30.41  28.06  26.11  24.55  23.24  22.10  21.09  20.19  19.38  18.63
Wavelets           31.03  27.06  24.96  23.54  22.55  21.87  21.28  20.75  20.32  19.96
BivShrink          34.38  30.19  27.91  26.38  25.22  24.32  23.60  23.00  22.50  22.07
SURELET            32.20  28.75  26.75  25.39  24.40  23.63  23.01  22.50  22.06  21.67
NLM                32.29  29.40  27.24  25.63  24.38  23.32  22.42  21.63  20.93  20.29
BM3D               35.49  31.28  28.97  26.78  25.73  24.96  24.38  23.90  23.50  23.16
DnCNN              32.88  28.44  25.70  23.66  22.04  20.68  19.50  18.47  17.56  16.75
WNNM               35.50  31.35  28.98  27.25  26.16  25.26  24.60  24.06  23.60  23.20
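To put "much faster" into numbers, the CPU times reported above for the Abdomen image can be ranked and compared with a few lines of arithmetic. Nothing is re-measured here: the values are copied from the text, and `CPU_TIME_S` is an illustrative name.

```python
# CPU times in seconds, as reported in the text (Matlab, Abdomen image, 256 x 256).
CPU_TIME_S = {
    "wiener2": 0.1547,
    "wavelet-based": 3.7856,
    "BivShrink": 0.5315,
    "SURELET": 0.4007,
    "NLM": 13.4964,
    "BM3D": 0.4656,
    "DnCNN": 4.6286,
    "WNNM": 107.8025,
}

fastest_first = sorted(CPU_TIME_S, key=CPU_TIME_S.get)
wnnm_vs_bm3d = CPU_TIME_S["WNNM"] / CPU_TIME_S["BM3D"]

print(" < ".join(fastest_first))
print(f"WNNM takes about {wnnm_vs_bm3d:.0f}x as long as BM3D")
```

By these figures WNNM takes roughly 230 times as long as BM3D on this image, while BM3D is the third fastest method overall, behind only wiener2 and SURELET.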

Table 4. The PSNR (dB) of different denoising methods for Pelvis with different noise levels. The best results are highlighted in bold font.

Method / σn           10     20     30     40     50     60     70     80     90    100
Noisy              28.15  22.13  18.61  16.11  14.18  12.59  11.25  10.09   9.07   8.15
Wieners            31.54  29.04  27.08  25.45  24.06  22.84  21.76  20.79  19.91  19.09
Wavelets           31.45  28.29  26.59  25.45  24.60  24.01  23.50  23.12  22.81  22.47
BivShrink          34.19  30.74  28.89  27.60  26.64  25.89  25.26  24.73  24.28  23.89
SURELET            32.80  29.88  28.17  27.00  26.13  25.45  24.92  24.44  24.05  23.67
NLM                33.71  30.36  28.15  26.58  25.37  24.38  23.50  22.73  22.02  21.37
BM3D               34.72  31.50  29.73  28.30  27.37  26.61  26.00  25.51  25.08  24.70
DnCNN              34.07  30.51  28.20  26.46  25.02  23.76  22.60  21.52  20.54  19.64
WNNM               34.80  31.61  29.76  28.51  27.55  26.78  26.14  25.60  25.14  24.71


Table 5. The PSNR (dB) of different denoising methods for Thigh with different noise levels. The best results are highlighted in bold font.

Method / σn           10     20     30     40     50     60     70     80     90    100
Noisy              28.15  22.13  18.61  16.11  14.18  12.59  11.25  10.09   9.07   8.15
Wieners            31.35  28.91  26.94  25.29  23.90  22.69  21.63  20.67  19.81  19.01
Wavelets           31.66  28.45  26.58  25.31  24.40  23.75  23.16  22.72  22.32  21.86
BivShrink          34.44  30.89  28.98  27.67  26.68  25.90  25.25  24.71  24.22  23.80
SURELET            33.45  30.27  28.44  27.14  26.16  25.39  24.74  24.19  23.70  23.27
NLM                34.26  30.77  28.46  26.78  25.44  24.33  23.37  22.53  21.78  21.09
BM3D               35.51  32.33  30.55  29.14  28.18  27.39  26.74  26.15  25.58  25.11
DnCNN              34.20  30.17  27.57  25.63  24.04  22.67  21.44  20.35  19.37  18.49
WNNM               35.69  32.52  30.75  29.39  28.43  27.59  26.94  26.36  25.83  25.35

Table 6. The PSNR (dB) of different denoising methods for Thorax with different noise levels. The best results are highlighted in bold font.

Method / σn           10     20     30     40     50     60     70     80     90    100
Noisy              28.15  22.13  18.61  16.11  14.18  12.59  11.25  10.09   9.07   8.15
Wieners            31.00  28.84  27.04  25.48  24.13  22.92  21.84  20.86  19.97  19.15
Wavelets           31.31  28.30  26.62  25.58  24.86  24.33  23.89  23.52  23.23  22.97
BivShrink          33.82  30.43  28.58  27.40  26.54  25.85  25.27  24.78  24.37  24.00
SURELET            32.31  29.74  28.12  26.98  26.17  25.52  24.98  24.53  24.14  23.80
NLM                33.09  29.87  27.88  26.46  25.34  24.40  23.58  22.82  22.13  21.48
BM3D               33.91  30.83  29.19  27.98  27.18  26.54  25.98  25.52  25.13  24.79
DnCNN              33.51  29.92  27.54  25.72  24.23  21.76  21.76  20.70  19.73  18.85
WNNM               33.86  30.88  29.15  28.00  27.09  26.44  25.80  25.28  24.80  24.43


Fig. 1. The six MRI images used in the experiments.

Fig. 2. The denoised MRI images with different methods for Abdomen.

Fig. 3. The denoised MRI images with different methods for Feet.


Fig. 4. The denoised MRI images with different methods for Head.

Fig. 5. The denoised MRI images with different methods for Pelvis.

Fig. 6. The denoised MRI images with different methods for Thigh.


Fig. 7. The denoised MRI images with different methods for Thorax.

3 Conclusions and Future Research
The denoising of MRI images is an extremely important research topic. Among all the image denoising methods in the literature, which one is best for MRI denoising? To our knowledge, the literature offers no answer to this question, which is the main reason we conducted these MRI denoising experiments. In this paper, we have performed a comparative study of MRI image denoising, considering six MRI images and eight image denoising methods. The experiments demonstrate that BM3D and WNNM are the two best methods for MRI denoising. However, BM3D is preferred because it is much faster than WNNM. None of the other denoising methods performs as well as BM3D and WNNM for MRI denoising. Future research will proceed in the following directions. We will propose novel methods for 3D MRI denoising instead of 2D MRI denoising. We will consider the Poisson noise model and spatially varying noise models for MRI denoising. MRI images can also be corrupted by Rician noise, which is signal dependent and is computed from both the real and imaginary images; Rician noise makes image-based quantitative measurement difficult. We will also investigate deep CNNs for MRI denoising, since they are very popular at present, and we would like to improve our previously published image denoising algorithms [9–11] and apply them to MRI denoising.
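The Rician noise mentioned above arises when the magnitude image is computed from real and imaginary channels that each carry Gaussian noise. The sketch below assumes the noise-free image lies entirely in the real channel and that both channels receive independent noise with the same standard deviation; `add_rician_noise` is an illustrative name, not an API from the cited works.

```python
import numpy as np

def add_rician_noise(image, sigma, rng=None):
    """Return |(A + n_re) + i * n_im| with n_re, n_im ~ N(0, sigma^2) drawn independently."""
    rng = np.random.default_rng() if rng is None else rng
    real = np.asarray(image, dtype=np.float64) + rng.normal(0.0, sigma, size=np.shape(image))
    imag = rng.normal(0.0, sigma, size=np.shape(image))
    return np.hypot(real, imag)        # magnitude image, always >= 0

rng = np.random.default_rng(0)
background = np.zeros((400, 400))      # zero-signal region
noisy_bg = add_rician_noise(background, sigma=10.0, rng=rng)
# Unlike zero-mean Gaussian noise, the magnitude is biased upwards where the signal is zero;
# for zero signal the magnitude is Rayleigh distributed with mean sigma * sqrt(pi / 2).
print(noisy_bg.mean(), 10.0 * np.sqrt(np.pi / 2))
```

The two printed values should agree closely. This upward bias, which depends on the underlying signal level, is one reason Rician noise complicates quantitative measurements and why additive Gaussian denoisers are not directly matched to it.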

4 Data Availability Statement
Data sharing is not applicable to this paper, as no datasets were generated or analyzed during the current study.

Conflict of Interest. The authors declare that they have no conflict of interest.

References
1. Lim, J.S.: Two-Dimensional Signal and Image Processing. Prentice Hall, Englewood Cliffs, NJ (1990)


2. Donoho, D.L.: De-noising by soft-thresholding. IEEE Trans. Inf. Theory 41(3), 613–627 (1995)
3. Sendur, L., Selesnick, I.W.: Bivariate shrinkage with local variance estimation. IEEE Signal Process. Lett. 9(12), 438–441 (2002)
4. Blu, T., Luisier, F.: The SURELET approach to image denoising. IEEE Trans. Image Process. 16(11), 2778–2786 (2007)
5. Buades, A., et al.: A non-local algorithm for image denoising. In: IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 60–65 (2005)
6. Dabov, K., Foi, A., Katkovnik, V., Egiazarian, K.: Image denoising by sparse 3D transform-domain collaborative filtering. IEEE Trans. Image Process. 16(8), 2080–2095 (2007)
7. Zhang, K., Zuo, W., Chen, Y., Meng, D., Zhang, L.: Beyond a Gaussian denoiser: residual learning of deep CNN for image denoising. IEEE Trans. Image Process. 26(7), 3142–3155 (2017)
8. Gu, S., Zhang, L., Zuo, W., Feng, X.: Weighted nuclear norm minimization with application to image denoising. In: IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 2862–2869 (2014)
9. Chen, G.Y., Kegl, B.: Image denoising with complex ridgelets. Pattern Recogn. 40(2), 578–585 (2007)
10. Chen, G.Y., Bui, T.D., Krzyzak, A.: Image denoising using neighbouring wavelet coefficients. Integr. Comput.-Aided Eng. 12(1), 99–107 (2005)
11. Chen, G.Y., Bui, T.D., Krzyzak, A.: Image denoising with neighbour dependency and customized wavelet and threshold. Pattern Recogn. 38(1), 115–124 (2005)

Multi-scale and Self-mutual Feature Distillation

Nianzu Qiao1(B), Jia Sun2, and Lu Dong3

1 College of Electronic and Information Engineering, Tongji University, Shanghai 201804, China
2 School of Automation, Southeast University, Nanjing 210096, Jiangsu, China
3 School of Cyber Science and Engineering, Southeast University, Nanjing 210096, Jiangsu, China

Abstract. In the past few years, there have been major strides in the study of knowledge distillation, of which feature distillation is an important branch. A growing number of researchers have studied this topic extensively and have produced a variety of feature distillation techniques. However, there is still much room for improvement in this area. To address this, this paper designs an innovative feature distillation strategy consisting of self-mutual feature distillation and multi-scale feature distillation. Self-mutual feature distillation improves the performance of student networks by effectively combining the self-supervised properties of student networks with the mutually supervised properties of student-teacher networks. Furthermore, an alternating feature distillation training technique is designed to maximise the advantages of both properties. Multi-scale feature distillation addresses the limitations of single-scale feature distillation by encouraging the student network to pay more attention to the multi-scale hidden information of the teacher network. Extensive comparative experiments show that the proposed approach produces the best results on all application tasks evaluated. For example, on the CIFAR100 and ImageNet classification tasks, our approach obtains the lowest error rate, and on the semantic segmentation task, it obtains the highest mIoU. Moreover, the ablation study further demonstrates the contribution of each proposed component.
Keywords: Feature distillation · Self-mutual · Multi-scale

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 438–448, 2023.
https://doi.org/10.1007/978-981-99-4742-3_36

1 Introduction
Deep learning networks currently enjoy great success in many areas of computer vision. Although deeper networks bring large performance gains, they also increase network complexity, which makes it difficult to meet the real-time processing requirements of practical applications. As a result, many researchers have worked on network compression, with the goal of making models smaller while maintaining their performance. Proposed network compression techniques fall into three broad categories, namely network pruning, model quantization, and knowledge distillation, of which knowledge distillation is currently the most popular area of research. It is based on a 'teacher-student' training approach, in which the prior knowledge contained in a trained teacher model (a large model) is distilled into a student model (a small model). Unlike other techniques, it reduces the size of the network without requiring changes to the network structure. Knowledge distillation was proposed by Hinton et al. [1] in 2015. It applies a softmax function with a distillation temperature T to the outputs of both the teacher and the student networks and then computes a loss between the two softened outputs. However, the technique has obvious drawbacks; for example, instructing the student network with the outputs of the teacher network is often not significantly different from instructing it with the ground truth. Therefore, to exploit as much of the latent information in the teacher network as possible, researchers have proposed a variety of feature distillation methods. This class of methods lets the student network learn the internal latent features of the teacher network, thus boosting the performance of the student network. Romero et al. [2] presented the FitNets feature distillation approach in 2015, which encourages the student network to mimic the hidden features of the teacher network; however, it did not achieve significant performance improvements. Zagoruyko et al. [3] proposed the attention transfer (AT) feature distillation strategy in 2017, which improves the student network by making it mimic the attention maps of the teacher network. Kim et al. [4] proposed a factor transfer (FT) feature distillation strategy, which implements the transfer and translation through two convolutional modules. Heo et al. [5] designed an activation boundaries (AB) knowledge distillation approach.
It focuses on the design of an activation transfer loss. Yang et al. [6] designed a feature distillation approach dedicated to object detection, which improves the student network by emphasizing both local and global knowledge. Heo et al. [7] improved the feature distillation loss by incorporating the choice of activation function, distillation location, and distance function into its design. However, each of these approaches has limitations, which shows that there is still room to improve student networks.
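The temperature-scaled loss of Hinton et al. [1] described in the introduction can be sketched in a few lines of NumPy. This is a minimal illustration, not this paper's method: it computes the KL divergence from the softened teacher distribution to the softened student distribution and multiplies it by T^2, the scaling Hinton et al. recommend so that gradient magnitudes stay comparable across temperatures. The default T = 4 and the names `log_softmax` and `kd_loss` are assumptions for illustration.

```python
import numpy as np

def log_softmax(logits, T=1.0):
    """Numerically stable log-softmax of logits / T along the last axis."""
    z = np.asarray(logits, dtype=np.float64) / T
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def kd_loss(student_logits, teacher_logits, T=4.0):
    """Batch-mean KL(teacher || student) between temperature-softened outputs, scaled by T**2."""
    log_p_t = log_softmax(teacher_logits, T)
    log_p_s = log_softmax(student_logits, T)
    kl = np.sum(np.exp(log_p_t) * (log_p_t - log_p_s), axis=-1)
    return float(kl.mean() * T ** 2)

teacher = np.array([[5.0, 2.0, -1.0]])
student = np.array([[1.0, 1.5, 0.5]])
print(kd_loss(student, teacher), kd_loss(teacher, teacher))  # second value is 0.0
```

In typical use this soft-target term is added to the ordinary cross-entropy with the ground-truth labels using a weighting factor; the feature distillation methods surveyed above add further terms on intermediate features instead of, or in addition to, this output-level loss.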


Fig. 1. Diagram of self-feature distillation.


In this work, we design a novel and reliable feature distillation technique. Unlike existing feature distillation approaches, it combines the self-supervised properties of the student network with the mutually supervised properties of the student-teacher network, making more comprehensive use of the available information to further strengthen the student network. In addition, a new alternating feature distillation training strategy is proposed to accommodate the respective characteristics of the two kinds of supervision, maximising the advantages of self-mutual feature distillation. Finally, to address the limitations of single-scale feature distillation, a multi-scale feature distillation strategy is proposed, which lets the student network exploit the multi-scale hidden information of the teacher network and thus further improves its performance.

2 The Proposed Approach
In this section, a self-mutual feature distillation technique is first described in Sect. 2.1. Then, a multi-scale feature distillation technique is formulated in Sect. 2.2. Finally, the overall loss function is given in Sect. 2.3.

2.1 Self-mutual Feature Distillation
In this subsection, the self-supervised properties of the student network are combined with the mutually supervised properties of the student-teacher network to effectively boost the performance of the student network. Furthermore, to maximise the benefits of both, an alternate feature distillation training technique is developed.

Self-Feature Distillation. Over-parameterisation of student networks can lead to redundant features [8]. To address this issue, a self-feature distillation method for student networks is devised, which uses a salient feature map in the spatial dimension to supervise the other, redundant feature maps. As shown in Fig. 1, given a student network and training data, self-feature distillation distills the redundant features of a layer by making them learn from the salient features of the same layer. Specifically, a single-channel salient feature map is first obtained in the spatial dimension by Avgpool and Maxpool. Then, each redundant feature map is aligned one-to-one with the salient feature map along the channel dimension. Finally, the loss $L_{self}$ between the redundant feature maps and the salient feature map is calculated as follows:

$$L_{self} = \frac{1}{C}\sum_{i=1}^{C} f\left(F_s^S, F_i^S\right) = \frac{1}{C}\sum_{i=1}^{C} \left\| F_s^S - F_i^S \right\|^2 \tag{1}$$

where C is the total number of channels in the feature map, f denotes the L2 loss function, $F_s^S$ denotes the salient feature map of the student network, and $F_i^S$ denotes the i-th redundant feature map of the student network.

Mutual-Feature Distillation. Many feature distillation approaches have demonstrated that student networks guided by a

Multi-scale and Self-mutual Feature Distillation


teacher network can achieve excellent performance. The open question is how the student network can make the most of the information inherited from the teacher network. The attention mechanism CBAM [9] shows that focusing on key features in the channel and spatial dimensions can effectively boost the performance of CNNs. Since mutual-feature distillation should not introduce extra parameters that require training, a new simplified attention mechanism is designed in this paper. It extracts the key information of a feature map in both the channel and spatial dimensions, yet it is computed directly and has no weight parameters to train. Based on this simplified attention mechanism and the teacher network, a mutual-feature distillation technique for student-teacher networks is proposed. As shown in Fig. 2, given a student-teacher network and training data, mutual-feature distillation distills the features of each student layer by focusing on the key features of the corresponding teacher layer. Specifically, in the simplified attention mechanism, Avgpool and Maxpool over the spatial dimensions are used to compute key features and Sigmoid is used to compute the attention probabilities, as in Eq. 2:

$$M^C = C \cdot \mathrm{Sigmoid}\left(\left(\frac{1}{HW}\sum_{i=1}^{H}\sum_{j=1}^{W} F_{i,j} + \max_{i,j} F_{i,j}\right) \Big/ T\right) \tag{2}$$

where $M^C$ is the channel attention probability vector, H, W and C are the height, width and number of channels of the feature map, respectively, $F_{i,j}$ is the pixel value at location (i, j), and T is the temperature hyperparameter introduced by Hinton et al. [1] to control the softness of the distribution.

[Figure: teacher and student feature maps are weighted by the simplified attention (Avgpool and Maxpool with Sigmoid; GlobalAvgpool and GlobalMaxpool with Softmax) before feature distillation.]

Fig. 2. Diagram of mutual-feature distillation.


In the channel dimension, Global-Avgpool and Global-Maxpool are used to compute key features and Softmax is used to compute the attention probabilities, as in Eq. 3:

$$M^S = H \cdot W \cdot \mathrm{Softmax}\left(\left(\frac{1}{C}\sum_{k=1}^{C} \left|F_k\right| + \max_{k} F_k\right) \Big/ T\right) \tag{3}$$

where $M^S$ is the spatial attention probability matrix and $F_k$ is the feature map of the k-th channel. The attention probability matrices of the student network and the teacher network can differ significantly, so the attention probabilities of the teacher network are used to guide the student network. The mutual-feature distillation loss $L_{mutual}$ of the student-teacher network is defined as follows:

$$L_{mutual} = \frac{1}{HWC}\sum_{i=1}^{H}\sum_{j=1}^{W}\sum_{k=1}^{C} M_{i,j}^S M_k^C \left(\delta_T\left(F_{k,i,j}^T\right) - \delta_S\left(F_{k,i,j}^S\right)\right)^2 \tag{4}$$

where $F^T$ and $F^S$ denote the feature maps of the teacher network and the student network, respectively, and $\delta_T$ and $\delta_S$ are the teacher and student transformation functions used to align the feature dimensions.
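A minimal NumPy sketch of Eqs. 2 to 4 follows, for a single layer whose teacher and student features are assumed to be already aligned to the same (C, H, W) shape by δ_T and δ_S (the transforms themselves are not modelled). Assumptions not fixed by the text: the Softmax of Eq. 3 runs over all H × W positions, both attention maps are computed from the teacher features only, and the function names are hypothetical. It illustrates the equations; it is not the authors' implementation.

```python
import numpy as np


def channel_attention(feat, T=1.0):
    """Eq. 2 sketch: parameter-free channel attention M^C of a (C, H, W) map."""
    C = feat.shape[0]
    avg = feat.mean(axis=(1, 2))                 # (1/HW) * sum over i, j (Avgpool)
    mx = feat.max(axis=(1, 2))                   # max over i, j (Maxpool)
    return C / (1.0 + np.exp(-(avg + mx) / T))   # C * Sigmoid((avg + max) / T)


def spatial_attention(feat, T=1.0):
    """Eq. 3 sketch: parameter-free spatial attention M^S of a (C, H, W) map."""
    _, H, W = feat.shape
    score = (np.abs(feat).mean(axis=0) + feat.max(axis=0)) / T   # (H, W)
    e = np.exp(score - score.max())              # subtract the max for numerical stability
    return H * W * e / e.sum()                   # H * W * Softmax over all positions


def mutual_loss(feat_t, feat_s, T=1.0):
    """Eq. 4 sketch: squared error weighted by the teacher's attention maps.

    feat_t and feat_s stand for delta_T(F^T) and delta_S(F^S), assumed to be
    aligned to the same (C, H, W) shape already.
    """
    C, H, W = feat_t.shape
    m_c = channel_attention(feat_t, T)[:, None, None]   # (C, 1, 1)
    m_s = spatial_attention(feat_t, T)[None, :, :]      # (1, H, W)
    weighted = m_s * m_c * (feat_t - feat_s) ** 2
    return float(weighted.sum() / (H * W * C))


rng = np.random.default_rng(0)
f_t = rng.standard_normal((4, 8, 8))               # hypothetical aligned teacher features
f_s = f_t + 0.1 * rng.standard_normal((4, 8, 8))   # hypothetical aligned student features
print(mutual_loss(f_t, f_s))
```

Scaling the softmax by H · W makes the spatial weights average exactly 1, so $M^S$ redistributes emphasis across positions without changing the overall scale of the loss.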

[Figure: Step 1 applies mutual-feature distillation between the teacher and student networks; Step 2 applies self-feature distillation within the student network.]

Fig. 3. Diagram of alternate feature distillation.

To bring the attention maps of the student network closer to those of the teacher network, an attention loss $L_{al}$ is defined as:

$$L_{al} = f\left(M_T^S\left(F^T\right), M_S^S\left(F^S\right)\right) + f\left(M_T^C\left(F^T\right), M_S^C\left(F^S\right)\right) \tag{5}$$

where the subscripts T and S denote the teacher and the student, respectively, and the superscripts S and C denote the spatial (Eq. 3) and channel (Eq. 2) attention.

Alternate Feature Distillation. To fully combine self-feature distillation and mutual-feature distillation, an alternate feature distillation training strategy is proposed.


[Figure: teacher and student features are downsampled with Maxpool 2×2, 4×4 and 8×8, and feature distillation is applied at each scale.]

Fig. 4. Diagram of multi-scale feature distillation.

Fig. 3 shows the flow of alternate feature distillation. In step 1, the student network is trained with mutual-feature distillation. After step 1, the student network, which has by then acquired initial knowledge from the teacher network, is trained with self-feature distillation. The two feature distillation techniques are executed alternately until the end of training. Alternate feature distillation exploits the respective advantages of self- and mutual-feature distillation and significantly improves the performance of the student network.
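The alternating scheme, together with the self-feature loss of Eq. 1, can be sketched as below. The text only says the salient map comes from Avgpool and Maxpool (Fig. 1 also shows convolution layers), so this hypothetical version averages channel-wise average and max pooling with no learned layers. It also assumes f is the squared L2 distance and that the two steps alternate every `period` epochs; neither the switching granularity nor the names come from the paper.

```python
import numpy as np


def salient_map(feat):
    """Hypothetical single-channel salient map of a (C, H, W) student feature map:
    the mean of channel-wise average pooling and channel-wise max pooling."""
    return 0.5 * (feat.mean(axis=0) + feat.max(axis=0))    # (H, W)


def self_feature_loss(feat):
    """Eq. 1 sketch: L_self = (1/C) * sum_i ||F_s^S - F_i^S||^2."""
    target = salient_map(feat)               # F_s^S, assumed to act as a fixed target
    diffs = target[None, :, :] - feat        # compare every channel F_i^S with the salient map
    return float((diffs ** 2).sum(axis=(1, 2)).mean())   # squared L2 norm per channel, averaged over C


def active_loss(epoch, period=1):
    """Alternate schedule: step 1 uses mutual-feature distillation, step 2 uses
    self-feature distillation, switching every `period` epochs (hypothetical)."""
    return "mutual" if (epoch // period) % 2 == 0 else "self"


for epoch in range(4):
    print(epoch, active_loss(epoch))   # prints: 0 mutual, 1 self, 2 mutual, 3 self
```

In a real training loop the string returned by `active_loss` would select which distillation term is added to the student's objective for that epoch.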

2.2 Multi-scale Feature Distillation
The usefulness of multi-scale features has been demonstrated in several fields [10, 11]. Feature maps at different scales provide different feature descriptions: large-scale feature maps concentrate on global feature representation, while small-scale feature maps concentrate on local feature representation, so feature distillation at a single scale is limited. Therefore, a multi-scale feature distillation strategy is proposed in this paper, as shown in Fig. 4. Maxpool is used to transform features to multiple scales because it focuses on salient features, which lets the student network learn more useful information and thus improves its performance. As Fig. 4 shows, three Maxpool sizes, 2 × 2, 4 × 4 and 8 × 8, are used to generate multi-scale features. Based on these three scales and the original scale, the multi-scale feature distillation loss $L_{scale}$ is defined as follows:

$$L_{scale} = \frac{1}{4}\left[f\left(F_{1/1}^T, F_{1/1}^S\right) + f\left(F_{1/2}^T, F_{1/2}^S\right) + f\left(F_{1/4}^T, F_{1/4}^S\right) + f\left(F_{1/8}^T, F_{1/8}^S\right)\right] \tag{6}$$

where $F_{1/1}^T$, $F_{1/2}^T$, $F_{1/4}^T$ and $F_{1/8}^T$ (and likewise $F^S$ for the student) denote feature maps at the original scale, 1/2 scale, 1/4 scale and 1/8 scale, respectively.
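Eq. 6 can be sketched as follows, assuming teacher and student feature maps already share the same (C, H, W) shape with H and W divisible by 8, non-overlapping Maxpool (stride equal to the kernel size), and f taken as the squared L2 distance. The function names are hypothetical.

```python
import numpy as np


def max_pool(feat, k):
    """Non-overlapping k x k max pooling of a (C, H, W) map (H and W divisible by k)."""
    C, H, W = feat.shape
    return feat.reshape(C, H // k, k, W // k, k).max(axis=(2, 4))


def l2(a, b):
    """f in Eqs. 1, 5 and 6, taken here as the squared L2 distance (an assumption)."""
    return float(((a - b) ** 2).sum())


def multi_scale_loss(feat_t, feat_s, kernels=(2, 4, 8)):
    """Eq. 6 sketch: average of f over the original scale and the Maxpool scales."""
    losses = [l2(feat_t, feat_s)]                                   # original scale 1/1
    losses += [l2(max_pool(feat_t, k), max_pool(feat_s, k)) for k in kernels]
    return sum(losses) / len(losses)                                # 1/4 with three pooled scales


rng = np.random.default_rng(0)
f_t = rng.standard_normal((16, 32, 32))   # hypothetical aligned teacher features
f_s = rng.standard_normal((16, 32, 32))   # hypothetical aligned student features
print(multi_scale_loss(f_t, f_s))
```

Max pooling keeps only the strongest activation in each window, which matches the text's motivation that Maxpool emphasises salient features at the coarser scales.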


2.3 Overall Optimisation Objective
Our feature distillation approach combines all of the distillation strategies proposed above. The overall optimisation objective $L_{total}$ is therefore defined as follows:

$$L_{total} = \alpha L_{self} + \beta L_{mutual} + \gamma L_{al} + \varepsilon L_{scale} \tag{7}$$

where α, β, γ and ε are weighting factors, $L_{self}$ is the self-feature distillation loss, $L_{mutual}$ is the mutual-feature distillation loss, $L_{al}$ is the attention loss, and $L_{scale}$ is the multi-scale feature distillation loss.

3 Experiment
We compare the proposed approach with other methods on three tasks. The first is CIFAR100 [12] classification, a common benchmark in machine learning. The second is classification on the large-scale ImageNet [13] dataset. The third is semantic segmentation. For a fair comparison, we reproduce the results of the compared approaches from the authors' publicly available code and papers. All experiments are run with PyTorch on an NVIDIA GeForce RTX 3090 graphics card. To verify the effectiveness of each component of the proposed feature distillation strategy, we also perform an ablation study.

3.1 CIFAR100 Classification
CIFAR100 classification is the most common benchmark for validating knowledge distillation methods. The dataset contains 100 categories, with 50,000 training images and 10,000 test images. To verify the generalisation and effectiveness of the proposed approach, teacher-student pairs of various architectures are compared, mainly WideResNet [14] and PyramidNet [15] networks, whose depth and width can be easily modified. All teacher networks are deeper and fully trained, while all student networks are shallower and are trained without pre-trained weights. Each network is trained for 200 epochs with a learning rate of 0.1. Table 1 reports the classification results of the different techniques for the different teacher-student pairs. As Table 1 shows, the proposed approach achieves the lowest error in all six settings, (a) to (f), which indicates that it can be adapted to networks with different structures, from deep and large networks to shallow and small ones.


Table 1. Results of the CIFAR100 classification for different approaches in different teacher-student structures. All values are error rates (%); lower is better.

Setting       | (a)             | (b)             | (c)             | (d)             | (e)                  | (f)
Teacher       | WideResNet 28–4 | WideResNet 28–4 | WideResNet 28–4 | WideResNet 28–4 | PyramidNet-200 (240) | PyramidNet-200 (240)
Teacher error | 21.09           | 21.09           | 21.09           | 21.09           | 15.57                | 15.57
Student       | WideResNet 16–4 | WideResNet 28–2 | WideResNet 16–2 | ResNet 56       | WideResNet 28–4      | PyramidNet-110 (84)
Student error | 22.72           | 24.88           | 27.32           | 27.68           | 21.09                | 22.58
KD [1]        | 21.69           | 23.43           | 26.47           | 26.76           | 20.97                | 21.68
FitNets [2]   | 21.85           | 23.94           | 26.30           | 26.35           | 22.16                | 23.79
AT [3]        | 22.07           | 23.80           | 26.56           | 26.66           | 19.28                | 19.93
FT [4]        | 21.72           | 23.41           | 25.91           | 26.20           | 19.04                | 19.53
AB [5]        | 21.36           | 23.19           | 26.02           | 26.04           | 20.46                | 20.89
CO [7]        | 20.89           | 21.98           | 24.08           | 24.44           | 17.80                | 18.89
FG [6]        | 20.73           | 21.81           | 24.03           | 24.41           | 17.71                | 18.64
Ours          | 20.46           | 21.39           | 23.95           | 24.27           | 17.50                | 17.93

3.2 ImageNet Classification
To further validate the proposed feature distillation scheme on larger images, we conduct experiments on the ImageNet dataset, which contains 1.2 million training images and 50,000 validation images with an average size of 469 × 387. During training, input images are randomly cropped to 224 × 224, and all networks are trained for 200 epochs with a learning rate of 0.1. Comparison experiments are conducted on two teacher-student pairs: ResNet152 [16] as the teacher with ResNet50 as the student, and ResNet50 as the teacher with MobileNet [17] as the student. Table 2 presents the classification results of the different methods on the two pairs. As Table 2 shows, the proposed technique substantially reduces the error of the student network (top-1 error from 23.72% to 21.59% for ResNet50 and from 31.13% to 27.43% for MobileNet) and outperforms the compared state-of-the-art methods.

3.3 Semantic Segmentation
This section evaluates the approach on a more challenging task, semantic segmentation, to further demonstrate its generalisation and effectiveness. DeepLabV3+ [18] with a ResNet101 backbone is used as the teacher network, and DeepLabV3+ with ResNet18 and MobileNetV2 [19] backbones are used as the student networks. The PASCAL VOC 2012 [20] segmentation dataset is used for the experiments. All networks are trained for 50 epochs with the learning rate used in [18]. Table 3 displays the results of semantic


Table 2. ImageNet classification results of different techniques in different teacher-student structures. All values are error rates (%); lower is better.

Method                        | ResNet152 → ResNet50 Top-1 | ResNet152 → ResNet50 Top-5 | ResNet50 → MobileNet Top-1 | ResNet50 → MobileNet Top-5
Teacher (ResNet152 / ResNet50) | 21.69 | 5.95 | 23.84 | 7.14
Student (ResNet50 / MobileNet) | 23.72 | 6.97 | 31.13 | 11.24
KD [1]                        | 22.85 | 6.55 | 31.42 | 11.02
AT [3]                        | 22.75 | 6.35 | 30.44 | 10.67
FT [4]                        | 22.80 | 6.49 | 30.12 | 10.50
AB [5]                        | 23.47 | 6.94 | 31.11 | 11.29
CO [7]                        | 21.65 | 5.83 | 28.75 | 9.66
FG [6]                        | 21.73 | 5.98 | 27.62 | 8.74
Ours                          | 21.59 | 5.81 | 27.43 | 8.21

segmentation, where the proposed feature distillation technique effectively improves the segmentation performance of the student networks (mIoU from 71.79 to 75.45 with ResNet18 and from 68.44 to 72.98 with MobileNetV2). These experiments further indicate that the proposed technique can be applied to different machine learning tasks to meet practical needs.

Table 3. Semantic segmentation results (mIoU; higher is better).

Backbone    | Method  | mIoU
ResNet101   | Teacher | 77.39
ResNet18    | Student | 71.79
ResNet18    | Ours    | 75.45
MobileNetV2 | Student | 68.44
MobileNetV2 | Ours    | 72.98

3.4 Ablation Study
To verify the effectiveness of the individual components of the proposed feature distillation approach, ablation experiments are performed on the CIFAR100 classification task with WideResNet 16–4 as the student network. Table 4 shows the results of the ablation study. The first entry is the classification result of the student network without any feature distillation component. The second is the result for the student network that employs only self-mutual feature


distillation. The third is the result for a student network that adopts only multi-scale feature distillation. The fourth is the result of the student network with all components. As Table 4 shows, each proposed component reduces the error of the student network relative to the baseline, and combining them gives the lowest error.

Table 4. Results of the ablation study.

Setting                   | Error (%)
Baseline                  | 22.72
Self-mutual               | 20.73
Multi-scale               | 20.79
Self-mutual + Multi-scale | 20.46

4 Conclusion
In this work, a self-mutual feature distillation approach is proposed that exploits both the characteristics of the student network itself and the mutual characteristics of the teacher-student network, significantly improving the performance of the student network. Moreover, an alternate feature distillation training strategy is designed to maximise the advantages of both kinds of feature distillation, and a multi-scale feature distillation technique is designed to exploit the multi-scale information of features effectively. The experimental results on CIFAR100, ImageNet and PASCAL VOC 2012 demonstrate the merits of the proposed strategy.

References
1. Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531 (2015)
2. Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C., Bengio, Y.: FitNets: hints for thin deep nets. In: 3rd International Conference on Learning Representations (ICLR), pp. 1–13 (2015)
3. Zagoruyko, S., Komodakis, N.: Paying more attention to attention: improving the performance of convolutional neural networks via attention transfer. In: 5th International Conference on Learning Representations (ICLR), pp. 1–13 (2017)
4. Kim, J., Park, S.U., Kwak, N.: Paraphrasing complex network: network compression via factor transfer. In: 32nd International Conference on Neural Information Processing Systems (NeurIPS), pp. 2765–2774 (2018)
5. Heo, B., Lee, M., Yun, S., Choi, J.Y.: Knowledge transfer via distillation of activation boundaries formed by hidden neurons. In: 33rd AAAI Conference on Artificial Intelligence (AAAI), pp. 3779–3787 (2019)
6. Li, H., Kadav, A., Durdanovic, I., Samet, H., Graf, H.P.: Pruning filters for efficient convnets. In: 5th International Conference on Learning Representations (ICLR), pp. 1–13 (2017)
7. Woo, S., Park, J., Lee, J.-Y., Kweon, I.S.: CBAM: convolutional block attention module. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11211, pp. 3–19. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01234-2_1


8. Yang, Z., et al.: Focal and global knowledge distillation for detectors. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 4643–4652 (2022)
9. Heo, B., Kim, J., Yun, S., Park, H., Kwak, N., Choi, J.Y.: A comprehensive overhaul of feature distillation. In: IEEE/CVF International Conference on Computer Vision (ICCV), pp. 1921–1930 (2019)
10. Wu, D., Wang, C., Wu, Y., Wang, Q.C., Huang, D.S.: Attention deep model with multi-scale deep supervision for person re-identification. IEEE Trans. Emerg. Top. Comput. Intell. 5(1), 70–78 (2021)
11. Wu, Y., et al.: Person reidentification by multiscale feature representation learning with random batch feature mask. IEEE Trans. Cogn. Dev. Syst. 13(4), 865–874 (2021)
12. Krizhevsky, A., Hinton, G.: Learning multiple layers of features from tiny images. Technical report, Citeseer (2009)
13. Russakovsky, O., et al.: ImageNet large scale visual recognition challenge. Int. J. Comput. Vis. 115, 211–252 (2015)
14. Zagoruyko, S., Komodakis, N.: Wide residual networks. In: 27th British Machine Vision Conference (BMVC), pp. 1–12 (2016)
15. Han, D., Kim, J., Kim, J.: Deep pyramidal residual networks. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 5927–5935 (2017)
16. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778 (2016)
17. Howard, A.G., Zhu, M., Chen, B., Kalenichenko, D., Wang, W., Weyand, T., Andreetto, M., Adam, H.: MobileNets: efficient convolutional neural networks for mobile vision applications. arXiv preprint arXiv:1704.04861 (2017)
18. Chen, L.-C., Zhu, Y., Papandreou, G., Schroff, F., Adam, H.: Encoder-decoder with atrous separable convolution for semantic image segmentation. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11211, pp. 833–851. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01234-2_49
19. Sandler, M., Howard, A., Zhu, M., Zhmoginov, A., Chen, L.C.: MobileNetV2: inverted residuals and linear bottlenecks. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 4510–4520 (2018)
20. Everingham, M., Eslami, S.A., Gool, L.V., Williams, C.K., Winn, J., Zisserman, A.: The PASCAL visual object classes challenge: a retrospective. Int. J. Comput. Vis. 111(1), 98–136 (2015)

A Novel Multi-task Architecture for Vanishing Point Assisted Road Segmentation and Guidance in Off-Road Environments

Yu Liu¹, Xue Fan¹, Shiyuan Han¹ (corresponding author), and Weiwei Yu²

¹ School of Information Science and Engineering, University of Jinan, Jinan, China
² Shandong Big Data Center, Jinan, China

Abstract. Although convolutional neural network (CNN) based and transformer based models are widely used for road segmentation to provide driving vehicles with valuable information, there is currently no reliable and safe solution designed specifically for harsh off-road environments. To address this challenge, we propose a multi-task network (VPrs-Net) that simultaneously learns two tasks: vanishing point (VP) detection and road segmentation. By using the road cues provided by the VP, VPrs-Net identifies drivable areas in harsh off-road environments more accurately. Moreover, guidance from the VP can further improve the driving safety of vehicles. We further propose a multi-attention architecture that learns task-specific features from the global features to address attentional imbalance in multi-task learning. The public ORFD off-road dataset is used to evaluate the performance of VPrs-Net. Experimental results show that, compared with several state-of-the-art algorithms, our model achieves 96.91% accuracy in the segmentation task and a mean NormDist error of 0.03288 in the road VP detection task. These results demonstrate the potential of the proposed model in challenging off-road environments.

Keywords: Off-Road environments · Multi-task learning · Segmentation · Vanishing point

1 Introduction
In recent years, automatic driving technology has become one of the most popular research directions in artificial intelligence. Road segmentation, which involves accurately segmenting road areas and other objects or people from the driving environment, is an important component of automated driving and has attracted considerable attention from researchers [1]. Many excellent models have been proposed since CNN-based architectures were introduced to semantic segmentation [2–5, 27]. At present, the most advanced image segmentation models are transformer-based models such as SegFormer [8] and RTFormer [10], which exploit the advantages of transformers to greatly improve segmentation performance.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 449–460, 2023. https://doi.org/10.1007/978-981-99-4742-3_37


Y. Liu et al.

Fig. 1. Typical difficult samples in off-road images. a. Dark road. b. Complex road. c. Urban road.

Existing models arguably achieve satisfactory performance in road image segmentation of urban scenes, reaching state-of-the-art trade-offs between accuracy and speed on datasets such as CamVid and Cityscapes. However, these models share a common drawback: segmentation of off-road environments has received little consideration. As shown in Fig. 1, the roads in images (a) and (b) represent typical off-road environments. Off-road generally refers to roads with a low degree of structure, such as urban secondary streets, rural roads and other unstructured roads. These roads lack adequate external lighting, clear lane markings and clear road boundaries, all of which degrade segmentation performance. This makes it difficult for existing algorithms to segment drivable road areas accurately and completely, which may compromise safe driving in dark areas. Although perfect segmentation of road areas and boundaries may not currently be required in such off-road environments, providing safe guidance to vehicles for the road ahead is crucial. However, existing algorithms pay little attention to these situations, which can seriously impair the driving experience of drivers or robots. From a practical perspective, the optimization of existing algorithms should not exclude off-road environments. The VP is an important cue for road safety guidance and lane segmentation, and serves as a key constraint in scene understanding [11]. For straight lanes, the VP is the point where the lanes converge in the distance, while for curved lanes it is the intersection point of the lane tangents [12, 13]. In deep learning (DL) algorithms, researchers have attempted to combine the VP with lane detection. However, owing to pedestrians and other objects that can obstruct the view on urban roads, together with the excellent performance of CNN-based lane

A Novel Multi-task Architecture for Vanishing Point Assisted Road Segmentation


detection methods, the application of the VP has been limited. Therefore, this paper proposes a multi-task network that uses road VP information to assist off-road segmentation. As prior knowledge, the VP can improve segmentation accuracy and act as safety guidance for vehicles. Moreover, to make the shared features more suitable for multiple tasks, we also propose a multi-attention architecture, which extracts features with a shared backbone and generates task-specific attention weights with independent attention networks, so that the network adaptively focuses on task-specific areas. Experimental results show that the proposed model improves accuracy over state-of-the-art semantic segmentation algorithms and locates the VP accurately in harsh off-road environments. In summary, our main contributions are as follows:
• To the best of our knowledge, the proposed method is the first attempt to use the VP to solve the road segmentation problem with deep learning networks, and we further labeled the VPs of 12,198 road images.
• We propose a new multi-task model, VPrs-Net, which effectively integrates VP detection and road segmentation. VPrs-Net significantly improves the accuracy of both tasks in harsh off-road environments.
• To address imbalanced feature representation in the multi-task network, a multi-attention architecture is proposed in VPrs-Net, which adaptively identifies task-specific priorities in the shared features through attention-based networks, thereby improving the performance of both tasks.
The rest of this paper is organized as follows. Related work is reviewed in Sect. 2. The proposed algorithm is explained in Sect. 3. The effectiveness of the proposed method is illustrated through detailed comparisons in Sect. 4, and our work is summarized in Sect. 5.
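The idea of one shared backbone feeding independent, task-specific attention networks can be illustrated with the sketch below. The text has not specified the attention design at this point, so the branch used here (a per-task 1 × 1 convolution followed by a sigmoid mask multiplied onto the shared features, similar in spirit to the feature-level attention of Liu et al. [22]) is an assumption, and all names and sizes are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def task_attention(shared, weight, bias):
    """One task-specific attention branch (hypothetical design).

    shared: (C, H, W) backbone features shared by all tasks.
    weight: (C, C) parameters of a 1x1 convolution owned by this task only.
    bias:   (C,) bias of that convolution.
    Returns the shared features re-weighted by a task-specific mask in (0, 1).
    """
    logits = np.einsum("oc,chw->ohw", weight, shared) + bias[:, None, None]
    return sigmoid(logits) * shared          # element-wise attention on the shared features


rng = np.random.default_rng(0)
shared = rng.standard_normal((8, 16, 16))    # one backbone output shared by both tasks
params = {task: (0.1 * rng.standard_normal((8, 8)), np.zeros(8))
          for task in ("segmentation", "vanishing_point")}
task_feats = {task: task_attention(shared, w, b) for task, (w, b) in params.items()}
print({task: feats.shape for task, feats in task_feats.items()})
```

Because each task owns its own mask parameters, the two heads can emphasise different regions of the same shared features, which is the imbalance the multi-attention architecture is meant to address.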

2 Related Work
2.1 Semantic Segmentation
Semantic segmentation aims to identify different objects at the pixel level using annotations. With the rapid development of deep learning, many outstanding segmentation algorithms have emerged in recent years. Since the proposal of the fully convolutional network (FCN) [14], segmentation algorithms based on deep learning have made significant breakthroughs. U-Net uses an encoder-decoder structure to improve segmentation performance [2]. HRNet uses a multi-resolution fusion approach to combine feature maps of different resolutions [4]. DeepLabv3+ [3] and PSPNet [15] capture contextual information from different levels of the feature hierarchy and enlarge the receptive field with spatial pyramid pooling modules. BiSeNet captures both global contextual information and fine-grained spatial detail by fusing information from low-resolution and high-resolution branches, improving segmentation accuracy [5]. DDRNet develops a dual-resolution network that achieves real-time semantic segmentation through multi-scale contextual fusion based on low-resolution feature maps [6]. The latest algorithms combine transformers and CNNs to achieve state-of-the-art results [7–10].


2.2 Multi-task Learning
Compared with single-task learning, multi-task learning aims to learn better representations by sharing information among multiple tasks. Its main advantages are that a single network can be trained to perform multiple tasks simultaneously and that it generalizes better. Currently, hard parameter sharing is the most common choice in popular multi-task networks, especially in panoramic driving perception [16]. MultiNet combines multimodal learning with multi-task learning and can simultaneously perform scene classification, object detection and drivable area segmentation [17]. DLTNet is a multi-task deep learning model that simultaneously detects drivable areas, lane lines and traffic objects [18]. YOLOP is a groundbreaking solution that enables panoramic driving perception on embedded devices [19]; it achieves high accuracy and speed in detecting traffic targets, segmenting drivable areas and detecting lanes simultaneously. HybridNets is an end-to-end network that proposes many optimizations, such as an efficient segmentation head and box/class prediction networks, customized anchors for each level of the feature network, and an efficient training loss function and training strategy [20]. YOLOPv2 proposes an effective and efficient multi-task learning network that simultaneously performs traffic object detection, drivable road area segmentation and lane detection [21].

2.3 Learning Task Inter-dependencies
Any multi-task network, whether CNN-based or transformer-based, faces the problem of how to balance the weights of its tasks. Currently, learning task-specific features and balancing task losses are the two most common strategies for this problem. Liu et al. propose a multi-task learning architecture that learns task-specific feature-level attention [22]. Kendall et al. propose a principled approach to multi-task deep learning that weighs multiple loss functions by the homoscedastic uncertainty of each task [23]. Chen et al. present a gradient normalization algorithm that automatically balances training in deep multi-task models by dynamically tuning gradient magnitudes [24]. Lin et al. propose a general framework for generating Pareto-efficient recommendations that accounts for differences in the loss magnitudes or learning speeds of different tasks and finds the optimal combination of task weights [25]. Bhattacharjee et al. introduce shared attention between the transformer decoders of multiple tasks to model the dependencies between vision tasks [26].
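As an illustration of the loss-balancing strategy, one common form of the homoscedastic-uncertainty weighting of Kendall et al. [23] is sketched below; each task learns s_i = log σ_i², and its loss L_i contributes L_i / (2σ_i²) + log σ_i. This section does not say whether VPrs-Net itself uses such a weighting, and the function name and loss values are hypothetical.

```python
import math


def uncertainty_weighted_loss(task_losses, log_vars):
    """Homoscedastic-uncertainty weighting in the spirit of Kendall et al. [23].

    task_losses: per-task losses L_i.
    log_vars:    learnable s_i = log(sigma_i^2), one per task.
    Each task contributes 0.5 * exp(-s_i) * L_i + 0.5 * s_i,
    i.e. L_i / (2 * sigma_i^2) + log(sigma_i).
    """
    return sum(0.5 * math.exp(-s) * loss + 0.5 * s
               for loss, s in zip(task_losses, log_vars))


# For a fixed task loss L the objective is minimised at s = log L (sigma^2 = L),
# so a task with a larger loss ends up with a smaller weight 1 / (2 * sigma^2).
seg_loss, vp_loss = 0.8, 0.05                 # hypothetical per-task losses
best = [math.log(seg_loss), math.log(vp_loss)]
print(uncertainty_weighted_loss([seg_loss, vp_loss], best))
```

The 0.5 · s_i term stops the optimiser from driving every weight to zero, which is what makes the learned weighting non-trivial.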

3 Method
In earlier works, a variety of deep learning based models have achieved impressive success in road image segmentation, but they mainly focus on urban road datasets. Our goal is to provide an effective solution for segmenting drivable road areas in harsh off-road environments and to provide reliable guidance for vehicles or robots.

A Novel Multi-task Architecture for Vanishing Point Assisted Road Segmentation


Fig. 2. The proposed network, which mainly consists of a feature extraction stage, a multi-attention learning stage, and task-specific output heads.

3.1 The Proposed VPrs-Net
Figure 2 provides an overview of the proposed architecture, which performs both VP detection and road segmentation. The model is simple: it mainly consists of a feature extraction stage, a multi-attention learning stage, and task-specific output heads. Specifically, the input road image is first fed into a shared backbone to extract features. The features generated at different stages are then sent to the multi-attention learning architecture, which learns task-specific features for the two tasks. Finally, the heads produce the road segmentation and VP detection results. Both output subnetworks use conv blocks and upsampling blocks to generate their results progressively, where a conv block consists of a 3 × 3 convolution followed by BatchNorm and ReLU. Note that the upsampling module of the segmentation head must restore the features to the size of the original image, whereas the VP head only enlarges the features to 1/4 of the image size. This is because keypoint detection is sensitive to positional information, and excessive upsampling can lead to a loss of spatial detail.
Backbone. The backbone network plays a crucial role in extracting features from input images and is typically implemented with a well-established image classification network. Since this paper focuses on road segmentation, DDRNet was chosen as the backbone because of its strong accuracy and real-time performance in semantic segmentation.
VP Information. Multi-task learning leverages the information shared among multiple related tasks to improve the performance of all of them. The VP of a road provides important cues for guidance and for segmenting the drivable area, but it must be incorporated into road segmentation in a reasonable and effective way.
Because heatmaps are well suited to keypoint detection, CNN-based heatmap regression is applied to predict the road VP [13]. Note that manually assigned task weights can significantly affect the accuracy of the output.

Y. Liu et al.
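The sketch below shows how a VP ground-truth heatmap can be rendered and decoded in heatmap regression. It assumes one VP per image, an unnormalized Gaussian with peak value 1 using the std = 7 reported in Sect. 4.2 (applied here at heatmap resolution, which is our assumption), argmax decoding, and a VP head that outputs at 1/4 of the input resolution as described in Sect. 3.1. The function names are hypothetical, and this is not the authors' code.

```python
import math


def gaussian_heatmap(height, width, center, sigma=7.0):
    """Render an unnormalized 2-D Gaussian (peak value 1) centred on the VP.

    `center` is (x, y) in heatmap pixel coordinates.
    """
    cx, cy = center
    two_sigma_sq = 2.0 * sigma * sigma
    return [
        [math.exp(-((x - cx) ** 2 + (y - cy) ** 2) / two_sigma_sq) for x in range(width)]
        for y in range(height)
    ]


def decode_vp(heatmap, stride=4):
    """Return the argmax location scaled back to input-image pixels.

    `stride=4` assumes the VP head outputs at 1/4 of the input size.
    """
    best_val, best_xy = float("-inf"), (0, 0)
    for y, row in enumerate(heatmap):
        for x, v in enumerate(row):
            if v > best_val:
                best_val, best_xy = v, (x, y)
    return best_xy[0] * stride, best_xy[1] * stride


if __name__ == "__main__":
    # A 1280 x 720 image gives a 320 x 180 heatmap at 1/4 resolution.
    vp_in_image = (640, 300)
    vp_on_map = (vp_in_image[0] // 4, vp_in_image[1] // 4)
    hm = gaussian_heatmap(180, 320, vp_on_map)
    print(decode_vp(hm))  # (640, 300)
```

Argmax decoding limits precision to the heatmap stride. Sub-pixel refinement, for example a local weighted average around the peak, is a common extension but is not described in the text.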

Fig. 3. The multi-scale convolutional attention (MSCA) module in SegNeXt

Multi-attention Network. A multi-task network must handle multiple tasks, and different tasks depend on the features in different ways. Inspired by [22], we designed a multi-attention architecture that feeds the features extracted by the backbone into task-specific attention networks, so that each task receives features with its own attention focus. Specifically, the backbone only extracts features, while the attention networks learn task-specific features in a self-supervised, end-to-end manner. Unlike other studies, our attention network requires only the final output features of each stage, and the number of modules can be chosen according to the backbone. This flexibility allows the network to learn more expressive feature combinations tailored to each task. As shown in Fig. 2, each feature extraction module of the backbone generates a set of features, which are combined with the features generated by the previous attention stage (at Stage 1, the network uses only the features generated by the backbone). The combined features are then refined by a 1 × 1 convolution followed by BatchNorm and ReLU. The self-attention module is the MSCA module from SegNeXt, shown in Fig. 3. MSCA consists of three components. First, a depth-wise convolution gathers information from local regions. Then, multi-branch depth-wise strip convolutions capture context at multiple scales. Finally, a 1 × 1 convolution models the relationships between channels. The output of the 1 × 1 convolution serves as attention weights that reweight the input of the MSCA.
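The MSCA computation described above can be sketched in plain Python for a single C × H × W feature map. This is a minimal illustration with the following assumptions:

- The branch layout follows SegNeXt's reference design as we understand it: a 5 × 5 depth-wise convolution, then strip-convolution branches with k = 7, 11, 21 applied to its output, summed with it, and mixed by a 1 × 1 convolution.
- Padding is zero, and there are no biases.
- All kernel weights are hypothetical inputs.

A real implementation would use a tensor library.

```python
def depthwise_conv2d(x, kernels):
    """Depth-wise 'same' convolution (cross-correlation) with zero padding.

    x: list of C channels, each an H x W list of lists of floats.
    kernels: list of C kernels, each kh x kw with odd kh and kw.
    """
    out = []
    for ch, k in zip(x, kernels):
        h, w = len(ch), len(ch[0])
        kh, kw = len(k), len(k[0])
        ph, pw = kh // 2, kw // 2
        res = [[0.0] * w for _ in range(h)]
        for i in range(h):
            for j in range(w):
                acc = 0.0
                for a in range(kh):
                    for b in range(kw):
                        ii, jj = i + a - ph, j + b - pw
                        if 0 <= ii < h and 0 <= jj < w:
                            acc += k[a][b] * ch[ii][jj]
                res[i][j] = acc
        out.append(res)
    return out


def pointwise_conv(x, weight):
    """1 x 1 convolution: weight[o][c] mixes input channel c into output o."""
    h, w = len(x[0]), len(x[0][0])
    return [
        [[sum(wo[c] * x[c][i][j] for c in range(len(x))) for j in range(w)] for i in range(h)]
        for wo in weight
    ]


def add(a, b):
    """Element-wise sum of two C x H x W feature maps."""
    return [[[p + q for p, q in zip(ra, rb)] for ra, rb in zip(ca, cb)] for ca, cb in zip(a, b)]


def msca(x, k5, strips, mix):
    """SegNeXt-style multi-scale convolutional attention (no biases).

    k5: per-channel 5 x 5 depth-wise kernels (local context).
    strips: one (row_kernels, col_kernels) pair per scale, i.e. a 1 x k
            followed by a k x 1 depth-wise strip convolution.
    mix: C x C weights of the final 1 x 1 convolution (channel mixing).
    """
    attn = depthwise_conv2d(x, k5)
    total = attn
    for row_k, col_k in strips:
        branch = depthwise_conv2d(depthwise_conv2d(attn, row_k), col_k)
        total = add(total, branch)
    attn = pointwise_conv(total, mix)
    # The attention map re-weights the MSCA input element-wise.
    return [[[a * v for a, v in zip(ra, rx)] for ra, rx in zip(ca, cx)] for ca, cx in zip(attn, x)]


def delta(kh, kw):
    """Identity kernel: 1 at the centre, 0 elsewhere (handy for testing)."""
    k = [[0.0] * kw for _ in range(kh)]
    k[kh // 2][kw // 2] = 1.0
    return k


if __name__ == "__main__":
    C, H, W = 2, 4, 4
    x = [[[float(c + i + j) for j in range(W)] for i in range(H)] for c in range(C)]
    k5 = [delta(5, 5) for _ in range(C)]
    strips = [([delta(1, k) for _ in range(C)], [delta(k, 1) for _ in range(C)]) for k in (7, 11, 21)]
    identity = [[1.0 if o == c else 0.0 for c in range(C)] for o in range(C)]
    y = msca(x, k5, strips, identity)
    # With identity kernels every branch returns x, so attn = 4x and y = 4x^2.
    print(y[1][2][3])  # 144.0
```

Strip convolutions approximate a large k × k receptive field with 2k weights per branch instead of k², which keeps multi-scale context affordable.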


4 Experiments
4.1 Datasets
To compare the performance of different algorithms, we evaluate the proposed VPrs-Net on the ORFD dataset [28], which contains three classes (traversable area, non-traversable area, and unreachable area). The ORFD dataset was collected in off-road environments and contains 12,198 annotated RGB images of size 1280 × 720. It covers various scenes (woodland, farmland, grassland, and countryside), weather conditions (sunny, rainy, foggy, and snowy), and lighting conditions (bright light, daylight, twilight, and darkness). In addition, we manually annotated the VPs in the ORFD dataset.
4.2 Implementation Details
For training, we use the AdamW optimizer with an initial learning rate of 0.00015, a batch size of 16, a weight decay of 0.0125, and betas of (0.9, 0.999). All experiments are performed on an NVIDIA RTX 3090 GPU using Python 3.7 with CUDA 11.2. We trained our model for 30K iterations on the ORFD dataset, which is split into 8,398 training images, 1,245 validation images, and 2,555 test images; these splits serve as the training, validation, and test sets in our experiments, respectively. All images keep their original size of 1280 × 720. Finally, a Gaussian kernel with the same standard deviation (std = 7) was applied to all ground-truth heatmaps.
4.3 Metrics
Following common practice [1], we use four metrics to evaluate segmentation performance:
$$\mathrm{mIoU} = \frac{1}{k+1}\sum_{i=0}^{k}\frac{TP_i}{TP_i+FP_i+FN_i},\qquad \mathrm{Precision} = \frac{TP}{TP+FP},\qquad \mathrm{Recall} = \frac{TP}{TP+FN},\qquad \mathrm{Accuracy} = \frac{TP+TN}{TP+TN+FP+FN},$$
where TP, TN, FP, and FN denote the numbers of true-positive, true-negative, false-positive, and false-negative pixels, respectively (TP_i, FP_i, and FN_i for class i, with k + 1 classes). We use NormDist to evaluate VP detection:
$$\mathrm{NormDist} = \frac{\lVert P_g - P_v \rVert}{\mathrm{Diag}(I)}$$
where P_g and P_v denote the ground-truth and estimated VP, respectively, and Diag(I) is the length of the diagonal of the input image.
4.4 Loss
The total loss of VPrs-Net, which combines segmentation and VP estimation, is
$$L_{all} = \lambda_{seg} L_{seg} + \lambda_{vp} L_{vp} \tag{1}$$

where L_vp and L_seg are the heatmap loss and the segmentation loss, respectively. We use the mean-squared error for the VP heatmap loss and the binary cross-entropy loss for road segmentation. λ_vp and λ_seg are the training weights of the VP loss and the segmentation loss, respectively; we set λ_vp = 15 and λ_seg = 1 to balance the contributions of the two loss terms during training.
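The segmentation metrics and NormDist of Sect. 4.3 and the weighted objective of Eq. (1) reduce to a few lines of arithmetic. The sketch below makes the following assumptions:

- Masks are binary and given as flat lists of 0/1 values, so mIoU averages the IoU of the two classes {background, road} (k = 1).
- NormDist uses the Euclidean norm.
- Eq. (1) is applied to precomputed loss values with the weights stated above.

The function names are hypothetical, and the ratios are undefined when a denominator is zero.

```python
import math


def confusion(pred, gt):
    """Count TP, TN, FP, FN for binary masks given as flat 0/1 sequences."""
    tp = tn = fp = fn = 0
    for p, g in zip(pred, gt):
        if p == 1 and g == 1:
            tp += 1
        elif p == 0 and g == 0:
            tn += 1
        elif p == 1 and g == 0:
            fp += 1
        else:  # p == 0 and g == 1
            fn += 1
    return tp, tn, fp, fn


def seg_metrics(pred, gt):
    """mIoU over {background, road}; Precision, Recall, Accuracy for road."""
    tp, tn, fp, fn = confusion(pred, gt)
    iou_road = tp / (tp + fp + fn)
    # For the background class, TN plays the role of TP and FP/FN swap roles.
    iou_background = tn / (tn + fn + fp)
    return {
        "mIoU": (iou_road + iou_background) / 2,
        "Precision": tp / (tp + fp),
        "Recall": tp / (tp + fn),
        "Accuracy": (tp + tn) / (tp + tn + fp + fn),
    }


def norm_dist(p_gt, p_est, width, height):
    """NormDist: Euclidean VP error divided by the image diagonal Diag(I)."""
    error = math.hypot(p_gt[0] - p_est[0], p_gt[1] - p_est[1])
    return error / math.hypot(width, height)


def total_loss(l_seg, l_vp, lam_seg=1.0, lam_vp=15.0):
    """Eq. (1): L_all = lam_seg * L_seg + lam_vp * L_vp."""
    return lam_seg * l_seg + lam_vp * l_vp


if __name__ == "__main__":
    gt = [1, 1, 1, 0, 0, 0, 1, 0]
    pred = [1, 1, 0, 0, 1, 0, 1, 0]
    # tp=3, tn=3, fp=1, fn=1 -> mIoU 0.6; Precision, Recall, Accuracy 0.75
    print(seg_metrics(pred, gt))
    # A 15-pixel error on a 1280 x 720 image (diagonal ~1468.6 px)
    print(norm_dist((640, 300), (652, 309), 1280, 720))
    print(total_loss(0.30, 0.01))
```

In practice the counts are accumulated over the whole test set before the ratios are taken, rather than averaging per-image scores.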


Table 1. Segmentation performance of the proposed VPrs-Net and other models.

| Method | Backbone | mIoU | Precision | Recall | Accuracy | Speed |
|---|---|---|---|---|---|---|
| U-Net [2] | – | 76.57% | 86.07% | 87.66% | 90.01% | 57 ms |
| HRNet [4] | HRNet-W48 | 91.23% | 95.07% | 95.13% | 96.37% | 193 ms |
| DeepLabV3+ [3] | ResNet50 | 90.09% | 95.59% | 94.83% | 96.17% | 63 ms |
| BiSeNet [5] | BiSeNetV2 | 80.48% | 91.87% | 86.99% | 89.70% | 38 ms |
| STDC [27] | STDC2-Seg | 87.86% | 95.01% | 93.24% | 94.87% | 52 ms |
| DDRNet [6] | DDRNet23-Slim | 90.91% | 95.67% | 94.71% | 95.21% | 40 ms |
| SegFormer [8] | SegFormer-B2 | 89.60% | 95.34% | 93.76% | 95.47% | 83 ms |
| RTFormer [10] | RTFormer32 | 88.44% | 93.56% | 94.09% | 95.35% | 45 ms |
| VPrs-Net | DDRNet23-Slim | 92.35% | 95.98% | 96.00% | 96.91% | 44 ms |

4.5 The Results of Road Segmentation
Table 1 reports the results of the proposed algorithm on the ORFD test set. For road segmentation, we compared VPrs-Net with various popular segmentation networks. U-Net [2], HRNet [4], DeepLabV3+ [3], BiSeNet [5], STDC [27], and DDRNet [6] are CNN-based networks, while SegFormer [8] and RTFormer [10] are transformer-based networks. All of these models are widely used for image segmentation and achieved state-of-the-art performance when they were proposed. The comparison shows that our model achieves the best performance on all segmentation metrics: mIoU (92.35%), Precision (95.98%), Recall (96.00%), and Accuracy (96.91%). Figure 4 presents a visual comparison of the segmentation results in harsh off-road environments and shows that the proposed model produces more accurate segmentation. In particular, Fig. 4 shows that the competing results fail to segment the road area completely in dark road images. In contrast, VPrs-Net performs well on the test set, providing not only a better drivable area but also visual guidance for vehicles. We attribute this performance to co-training the detection and segmentation tasks and to assigning task-specific features through the multi-attention network, so that the shared layers benefit segmentation.
4.6 The Results of VP Detection
We evaluated VP detection by comparing VPrs-Net with single-task models built on different backbones, namely ResNet50 and HRNet-W48. As shown in Table 2, the mean NormDist error of VPrs-Net is 0.03288, compared with 0.03767 for ResNet50 and 0.03419 for HRNet-W48. Unlike these single-task CNN-based models, our approach uses shared layers to learn several related tasks jointly and thereby improves performance. This allows our model to extract more useful information from off-road images, which further improves VP detection.

Fig. 4. Segmentation visualization results of the proposed VPrs-Net and other models. From left to right: input images, ground truth, results of DDRNet, results of our model, and the VP heatmap.

Table 2. VP detection performance of the proposed VPrs-Net and other models.

| Method | Backbone | NormDist |
|---|---|---|
| ResNet | ResNet-50 | 0.03767 |
| HRNet [4] | HRNet-W48 | 0.03419 |
| VPrs-Net | DDRNet23-Slim | 0.03288 |

4.7 Ablation Study
To further demonstrate the effectiveness of the proposed model, we conducted an ablation study on the ORFD validation set to quantify the impact of each component. The experiments use DDRNet as the baseline and the same training and inference settings. Table 3 shows that adding VP detection and the multi-attention network each improves segmentation accuracy, while the inference speed decreases slightly. With both proposed components, VPrs-Net achieves 92.35% mIoU, 95.98% Precision, 96.00% Recall, and 96.91% Accuracy. These are improvements of 1.44, 0.31, 1.29, and 1.70 percentage points, respectively, over the baseline model. In summary, the proposed components are effective for semantic segmentation.


Table 3. Ablation results for the VP and multi-attention network components.

| Baseline | VP | Multi-att | mIoU | Precision | Recall | Accuracy | Speed |
|---|---|---|---|---|---|---|---|
| √ | | | 90.91% | 95.67% | 94.71% | 95.21% | 40 ms |
| √ | √ | | 92.03% | 95.80% | 95.75% | 96.37% | 41 ms |
| √ | √ | √ | 92.35% | 95.98% | 96.00% | 96.91% | 44 ms |

5 Conclusions
In this paper, we aim to improve road segmentation in harsh off-road environments by exploiting the richer cues provided by the road VP. To this end, we proposed a multi-task model, VPrs-Net, that simultaneously detects the road VP and segments the drivable road region. VPrs-Net consists of a feature extraction layer, a multi-attention network, and task-specific head layers. The feature extraction layer extracts features from the images, the multi-attention network learns task-specific features, and the head layers produce the task-specific outputs. The model provides more comprehensive information for driving vehicles, and its optimization for harsh environments can enhance driving safety. Experimental results on the test set show that our model achieves high accuracy in both road segmentation and VP detection. Compared with previous single-task models, the proposed multi-task model obtains more cues through cross-task learning, and its improved segmentation performance helps address driving in harsh off-road environments.
Acknowledgments. This research was funded by the Natural Science Foundation of Shandong Province for Key Project under Grant ZR2020KF006, the National Natural Science Foundation of China under Grant 62273164, the Development Program Project of Youth Innovation Team of Institutions of Higher Learning in Shandong Province, and a Project of Shandong Province Higher Educational Science and Technology Program under Grants J16LB06 and J17KA055.

References
1. Shen, W., Peng, Z., Wang, X., et al.: A survey on label-efficient deep image segmentation: bridging the gap between weak supervision and dense prediction. IEEE Trans. Pattern Anal. Mach. Intell. 45, 1–20 (2023)
2. Ronneberger, O., Fischer, P., Brox, T.: U-net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W., Frangi, A. (eds.) Medical Image Computing and Computer-Assisted Intervention – MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
3. Chen, L.C., Zhu, Y., Papandreou, G., et al.: DeepLab v3+: Encoder-decoder with atrous separable convolution for semantic image segmentation. In: Proceedings of the European Conference on Computer Vision, pp. 801–818. Springer, Munich (2018)
4. Wang, J., Sun, K., Cheng, T., et al.: Deep high-resolution representation learning for visual recognition. IEEE Trans. Pattern Anal. Mach. Intell. 43(10), 3349–3364 (2020)
5. Yu, C., Wang, J., Peng, C., et al.: BiSeNet: bilateral segmentation network for real-time semantic segmentation. In: Proceedings of the European Conference on Computer Vision, pp. 325–341. Springer, Munich (2018)
6. Hong, Y., Pan, H., Sun, W., et al.: Deep dual-resolution networks for real-time and accurate semantic segmentation of road scenes. IEEE Trans. Intell. Transp. Syst. 24(3), 3448–3460 (2022)
7. Chu, X., Tian, Z., Wang, Y., et al.: Twins: revisiting the design of spatial attention in vision transformers. Adv. Neural Inf. Process. Syst. 34, 9355–9366 (2021)
8. Xie, E., Wang, W., Yu, Z., et al.: SegFormer: simple and efficient design for semantic segmentation with transformers. Adv. Neural Inf. Process. Syst. 34, 12077–12090 (2021)
9. Liu, Z., Lin, Y., Cao, Y., et al.: Swin transformer: hierarchical vision transformer using shifted windows. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10012–10022. IEEE, Montreal (2021)
10. Wang, J., Gou, C., Wu, Q., et al.: RTFormer: efficient design for real-time semantic segmentation with transformer. arXiv:2210.07124 (2022)
11. Lin, Y., Wiersma, R., Pintea, S.L., et al.: Deep vanishing point detection: geometric priors make dataset variations vanish. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 6103–6113. IEEE, New Orleans (2022)
12. Lee, S., Kim, J., Shin Yoon, J., et al.: VPGNet: vanishing point guided network for lane and road marking detection and recognition. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 1965–1973. IEEE, Venice (2017)
13. Liu, Y.-B., Zeng, M., Meng, Q.-H.: Heatmap-based vanishing point boosts lane detection. arXiv:2007.15602 (2020)
14. Long, J., Shelhamer, E., Darrell, T.: Fully convolutional networks for semantic segmentation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3431–3440. IEEE, Boston (2015)
15. Zhao, H., Shi, J., Qi, X., et al.: Pyramid scene parsing network. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2881–2890. IEEE, Honolulu (2017)
16. Ruder, S.: An overview of multi-task learning in deep neural networks. arXiv:1706.05098 (2017)
17. Teichmann, M., Weber, M., Zoellner, M., et al.: MultiNet: real-time joint semantic reasoning for autonomous driving. In: IEEE Intelligent Vehicles Symposium, pp. 1013–1020. IEEE, Changshu (2018)
18. Qian, Y., Dolan, J.M., Yang, M.: DLT-Net: joint detection of drivable areas, lane lines, and traffic objects. IEEE Trans. Intell. Transp. Syst. 21(11), 4670–4679 (2019)
19. Wu, D., Liao, M.W., Zhang, W.T., et al.: YOLOP: you only look once for panoptic driving perception. Mach. Intell. Res. 19, 1–13 (2022)
20. Vu, D., Ngo, B., Phan, H.: HybridNets: end-to-end perception network. arXiv:2203.09035 (2022)
21. Han, C., Zhao, Q., Zhang, S., et al.: YOLOPv2: better, faster, stronger for panoptic driving perception. arXiv:2208.11434 (2022)
22. Liu, S., Johns, E., Davison, A.J.: End-to-end multi-task learning with attention. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1871–1880. IEEE, Seoul (2019)
23. Kendall, A., Gal, Y., Cipolla, R.: Multi-task learning using uncertainty to weigh losses for scene geometry and semantics. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7482–7491. IEEE, Salt Lake City (2018)
24. Chen, Z., Badrinarayanan, V., Lee, C.Y., et al.: GradNorm: gradient normalization for adaptive loss balancing in deep multitask networks. In: 35th International Conference on Machine Learning, pp. 794–803. Stockholm (2018)
25. Lin, X., Chen, H., Pei, C., et al.: A Pareto-efficient algorithm for multiple objective optimization in e-commerce recommendation. In: 13th ACM Conference on Recommender Systems, pp. 20–28. Association for Computing Machinery, Copenhagen (2019)
26. Bhattacharjee, D., Zhang, T., Süsstrunk, S., et al.: MulT: an end-to-end multitask learning transformer. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12031–12041. IEEE, New Orleans (2022)
27. Fan, M., Lai, S., Huang, J., et al.: Rethinking BiSeNet for real-time semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9716–9725. IEEE (2021)
28. Min, C., Jiang, W., Zhao, D., et al.: ORFD: a dataset and benchmark for off-road freespace detection. In: 2022 International Conference on Robotics and Automation (ICRA), pp. 2532–2538. IEEE, Philadelphia (2022)

DCNet: Glass-Like Object Detection via Detail-Guided and Cross-Level Fusion
Jianhao Zhang1,2, Gang Yang1(B), and Chang Liu1
1 Northeastern University, Shenyang 110819, China
2 DUT Artificial Intelligence Institute, Dalian 116024, China

Abstract. Glass-like object detection aims to detect and segment whole glass objects against complex backgrounds. Because glass is transparent, existing detection methods often suffer from blurred object boundaries. Several recent methods introduce edge information to boost performance. However, glass boundary pixels are far sparser than other pixels, so relying only on edge pixels may hurt glass detection because of the unbalanced distribution of edge and non-edge pixels. In this study, we propose a new detail-guided and cross-level fusion network, called DCNet, to tackle these issues. First, we exploit label decoupling to obtain detail labels and propose a multi-scale detail interaction module (MDIM) to explore finer detail cues. Second, we design a body-induced cross-level fusion module (BCFM), which guides the integration of features at different levels and leverages discontinuities and correlations to refine the glass boundary. Finally, we design an attention-induced aggregation module (AGM) that mines local pixel and global semantic cues from glass-like object regions and fuses the features from all steps. Extensive experiments on a benchmark dataset demonstrate the effectiveness of our framework.
Keywords: Glass-like Object Detection · Feature Fusion · Deep Learning · Salient Object Detection

1 Introduction
Glass-like objects, such as doors, windows, and drinking glasses, are ubiquitous in daily life. Because glass is transparent and has low contrast, it interferes to some extent with the operation of visual sensing systems [15, 22]. For instance, drones must avoid hitting the glass of tall buildings while flying, and a grasping robot must accurately identify glass-like objects to grasp them. In recent years, visual detection of glass-like objects has received increasing attention from researchers. However, glass-like objects are usually transparent and simply reflect or transmit the contents of their surroundings. As shown in Fig. 1, it can be difficult to distinguish glass objects from the background because their boundaries often share characteristics with the background. Furthermore, glass objects are not salient and are sometimes occluded or covered by reflections, which makes precise glass detection even harder. State-of-the-art semantic segmentation [27] and salient object detection [19] methods cannot address these issues satisfactorily.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 461–472, 2023. https://doi.org/10.1007/978-981-99-4742-3_38

J. Zhang et al.

Fig. 1. Two challenging glass scenes with ill-defined boundaries and dimly lit objects. Our model performs better than GDNet [15].

In recent years, several deep-learning-based approaches to glass-like object detection have been presented. They can be roughly divided into two categories. The first uses specific network modules to mine glass region features. For instance, Mei et al. [15] design a contextual feature module that explores a large field of view over the glass region. The second introduces auxiliary information to enhance the feature representation of the glass region. For example, Mei et al. [14] exploit trichromatic intensity and spectral polarization cues from glass images to mine cross-modal global contextual information, and Huo et al. [10] use the transmission property of thermal energy to segment glass regions precisely. These methods have steadily improved glass-like object detection. Nevertheless, two main challenges remain:
(i) Because boundary pixels are sparse, existing glass detection models can produce large prediction errors for pixels close to the object boundary [21], and simply introducing boundary supervision may mislead the learning process.
(ii) Existing glass segmentation methods mainly focus on exploring useful information from different auxiliary cues [6, 23]. However, the resulting auxiliary features usually differ in level and characteristics, and without an effective mechanism for fusing multi-scale or multi-level features, it is difficult to predict glass objects at different scales.
Because of these two challenges, existing methods may struggle to generate glass maps with fine boundaries and structural integrity. As in other segmentation tasks [5, 28], the most straightforward way to improve boundary detection is to introduce edge information as additional supervision. However, edge labels only represent pixels on the boundaries, and glass boundary pixels are much sparser than other pixels.
Using only edge pixels may therefore degrade glass detection because of the unbalanced distribution of edge and non-edge pixels. To address this issue, we introduce label decoupling [21], which explicitly decomposes the original glass labels into body labels and detail labels. The decomposed detail labels consist of edge pixels and their neighbors. This makes full use of the pixels near the edge and balances the pixel distribution. In contrast to [21, 29], we only use the detail information to supervise the intermediate output, and our approach can be trained efficiently end-to-end.
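The distance transformation behind the detail label (Eq. (1) in Sect. 2.1) can be sketched with a brute-force search over boundary pixels. This minimal illustration makes the following assumptions:

- The input is a binary H × W mask.
- A glass pixel counts as a boundary pixel if any in-image 4-neighbor is background.
- Distance is Euclidean.
- The final inversion to 1 − d / max(d) gives pixels near the boundary larger values, as in Fig. 2. This inversion is our assumption, because the text does not state how the distance map is normalized.

The helper names are hypothetical. Library routines such as SciPy's exact Euclidean distance transform are far faster than this brute-force search.

```python
import math


def boundary_pixels(mask):
    """Glass pixels (value 1) with at least one 4-neighbour that is background.

    Neighbours outside the image are ignored (an assumption).
    """
    h, w = len(mask), len(mask[0])
    pts = []
    for y in range(h):
        for x in range(w):
            if mask[y][x] != 1:
                continue
            for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w and mask[ny][nx] == 0:
                    pts.append((y, x))
                    break
    return pts


def distance_to_boundary(mask):
    """Eq. (1): distance from each glass pixel to its nearest boundary pixel; 0 elsewhere."""
    pts = boundary_pixels(mask)
    h, w = len(mask), len(mask[0])
    dist = [[0.0] * w for _ in range(h)]
    if not pts:
        return dist
    for y in range(h):
        for x in range(w):
            if mask[y][x] == 1:
                dist[y][x] = min(math.hypot(y - by, x - bx) for by, bx in pts)
    return dist


def detail_label(mask):
    """Assumed normalisation: 1 on the boundary, decreasing inward, 0 off-glass."""
    dist = distance_to_boundary(mask)
    d_max = max(max(row) for row in dist)
    return [
        [(1.0 - d / d_max if d_max > 0 else 1.0) if m == 1 else 0.0 for d, m in zip(drow, mrow)]
        for drow, mrow in zip(dist, mask)
    ]


if __name__ == "__main__":
    # A 3 x 3 glass block inside a 5 x 5 image: only the centre is interior.
    mask = [[0] * 5] + [[0, 1, 1, 1, 0] for _ in range(3)] + [[0] * 5]
    for row in detail_label(mask):
        print(row)
```

In this toy mask the eight ring pixels receive 1.0 and the single interior pixel receives 0.0. Supervision is thus concentrated on the edge and its neighborhood rather than on one-pixel-wide edges alone.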

DCNet: Glass-like Object Detection via Detail-Guided


In summary, our contributions are as follows:

Fig. 2. Examples of label decoupling. "Detail" denotes the decoupled detail label of the ground truth, in which pixels near the target boundary have larger values.

1) We propose a new cascade framework for glass-like object detection, called DCNet. The network mines and integrates detail information to boost glass detection performance in an exploration-to-aggregation manner.
2) We propose a new multi-scale detail interaction module, MDIM, which effectively explores the contextual information of the detail map. We also propose a new body-induced cross-level fusion module, BCFM, which guides the effective integration of features at different levels and leverages discontinuities and correlations to refine the glass boundary. Moreover, we propose an attention-induced aggregation module, AGM, which uses global attention to further integrate body features into detailed appearance features and improve glass-like object detection.
3) The proposed network achieves outstanding performance compared with 15 state-of-the-art methods on the GDD dataset.

2 Method
2.1 Label Decoupling
The contents visible on a glass surface come from reflections of opaque objects in its background, so the boundary information is extremely susceptible to background interference. In contrast, the consistency of the internal region makes body features easier to detect. Hence, a more principled way is needed to handle the pixel imbalance between edge information and region information. Unlike [21, 29], we only use the decoupled detail labels; Fig. 2 shows two examples. Concretely, given the ground-truth glass label, we transform it into the detail label using a distance transformation DT, in which each pixel in the original glass region is matched to the boundary pixel at minimum distance. The distance transformation can be expressed as:
$$G_d(p,q) = \begin{cases} |G(p,q) - B(p,q)|, & G(p,q) = 1 \\ 0, & G(p,q) = 0 \end{cases} \tag{1}$$


Fig. 3. The overall architecture of the proposed model.

where G_d denotes the detail label, G(p, q) denotes a glass pixel, and B(p, q) denotes the boundary pixel with the minimum Euclidean distance to G(p, q).
2.2 Multi-scale Detail Interaction Module
We design MDIM to capture large-receptive-field features at different scales through a multi-branch structure and to perform cross-scale interactive learning, so that it can perceive various local contexts. Take the structure of MDIM in Fig. 4 as an example. MDIM consists of four parallel DIM branches and one residual branch. Together they enlarge the receptive field while emphasizing the interaction among multi-scale information. Specifically, given the input feature F_i, MDIM captures different characteristics with its four DIM branches.
- The first layer of each DIM branch is a 3 × 3 convolution followed by Batch Normalization (BN) and ReLU, which compresses the number of channels.
- The second layer uses two parallel spatially separable convolutions, {k × 1, 1 × k} and {1 × k, k × 1}, to capture local region information while reducing computation. The kernel size k is set to 3, 5, 7, and 9, and the dilation rate starts at 1 in the first branch and increases in the subsequent branches.
- The third layer integrates the contextual features of the different receptive fields.
In addition, we add short connections between the DIM branches to strengthen information exchange and obtain perceptual results from multiple branches. The process can be formulated as:
$$\mathrm{DIM}_i^n = \begin{cases} \mathrm{Sconv}(F_i), & n = 1 \\ \mathrm{Sconv}\big(\mathrm{Cat}(F_i, \mathrm{DIM}_i^{n-1})\big), & n = 2, 3, 4 \end{cases} \tag{2}$$
where F_i denotes the i-th feature map produced by the backbone network, i ∈ {0, 1, 2, 3, 4}, n is the branch index, DIM_i^n denotes the output of the n-th branch, Cat(·) denotes concatenation, and Sconv(·) denotes the three convolutional layers of a DIM branch described above. Next, we apply a 3 × 3 convolution to the concatenated outputs of the four DIM branches, compressing the channels to 64, and add the result to the residual branch to prevent information loss. This yields the output features MDIM_i, i ∈ {2, 3, 4}, which perceive various local context information:
$$\mathrm{MDIM}_i = \mathrm{Conv}(F_i) \oplus \mathrm{Conv}_d\big(\mathrm{Cat}_{n=1}^{4}(\mathrm{DIM}_i^n)\big) \tag{3}$$
where Conv_d(·) denotes a 3 × 3 convolution, Conv(·) denotes a 1 × 1 convolution followed by BN and ReLU, and Cat_{n=1}^{4}(·) denotes the concatenation of all four branches.

Fig. 4. Structure of the Multi-scale Detail Interaction Module (MDIM).

2.3 Body-Induced Cross-Level Fusion Module
After the multi-scale detail interaction module produces high-quality detail prediction maps, the body encoder at the bottom of Fig. 3(a) is fed with the original input image and the generated detail map to extract a new body flow. The body encoder consists of three convolutional blocks and passes the body flow to the body decoder, which generates the predicted glass map. The body decoder consists of two successive Body-induced Cross-level Fusion Modules (BCFMs). To obtain global and local features more efficiently, we introduce the multi-scale channel attention mechanism (MSCA) [2]. BCFM aims to highlight the detailed discontinuities and semantic relevance of features at different scales in order to refine the glass boundaries. Each BCFM takes two flows: the body flow from the body encoder and the flow from the previous block. We then exploit discontinuity and correlation in both flows to aggregate features at different levels:
$$F_b = \omega(F_l \uplus F_h) \odot F_l \oplus \big(1 - \omega(F_l \uplus F_h)\big) \odot F_h \tag{4}$$
$$\mathrm{BCFL}_i = f_l(F_b) - f_s(F_b) \tag{5}$$
$$\mathrm{BCFG}_i = f_l(F_b) - f_g\big(G(F_b)\big) \tag{6}$$
$$\mathrm{BCF}_i = \mathrm{BCFL}_i \oplus \mathrm{BCFG}_i \tag{7}$$

where ω denotes the MSCA, and F_h and F_l are two adjacent features. F_h is upsampled by a factor of two to match the size of F_l and then summed element-wise with F_l (denoted ⊎); ⊕ denotes element-wise addition, and ⊙ denotes the Hadamard product. The term (1 − ω(F_l ⊎ F_h)) corresponds to the dotted line in Fig. 3(b). f_l denotes local feature extraction using a 3 × 3 convolution with a dilation rate of 1, and f_s denotes surrounding-context feature extraction using a 5 × 5 convolution with a dilation rate of 2; both are followed by BN and ReLU. G is a global average pooling operation, and f_g is a 1 × 1 convolution followed by BN and ReLU.
2.4 Attention-Induced Aggregation Module
Figure 3(c) illustrates the detailed structure of the attention-induced aggregation module (AGM). AGM mines glass-like object regions for local pixel and global semantic cues while combining the features from all steps. We introduce multi-scale channel attention to enhance the features and generate (query, key, value) triplets from the input features to realize self-attention. Specifically, for the feature map F_bh, which contains high-level semantic information, we apply two linear mapping functions to reduce its dimensionality and generate the feature maps Q and K. The feature maps F_bl and F_d contain rich appearance details. We first use the attention module to enhance F_bl and F_d separately and then combine them by element-wise addition to obtain F_t. We apply a convolution unit Conv(·) to F_t so that it has the same channel dimension (64), and interpolate it to the size of F_bh. We then apply a Softmax function along the channel dimension and take the second channel as the attention map, which gives the updated F_t. These operations are denoted SF(·) in Fig. 3(c). Adaptive pooling followed by a crop operation then yields the feature map V.
We obtain the relationship between each pixel and all other pixels by computing the matrix inner product of K and V. A Softmax operation on the resulting autocorrelation features produces weights f in the range [0, 1], which serve as the self-attention coefficients. We multiply these coefficients back into the feature map Q and feed the result into a graph convolution to reconstruct the original features. We then multiply the reconstructed features by the attention coefficients again. Finally, we apply a convolution unit Conv(·) to reduce the number of channels, and we add a residual connection to the original input F_bh to obtain the information-enhanced output F_o. The process can be formulated as follows:

$$F_t = \mathrm{SA}\big(\mathrm{CA}\big(F_d \odot \omega(F_d) \oplus F_{bl} \odot \omega(F_{bl})\big)\big) \tag{8}$$

$$Q = \mathrm{Conv}(F_{bh}), \quad K = \mathrm{Conv}(F_{bh}) \tag{9}$$

$$V = \mathrm{AP}\big(K \odot \mathrm{SF}(\mathrm{Conv}(F_t))\big) \tag{10}$$

DCNet: Glass-like Object Detection via Detail-Guided


$$f = \sigma\big(V \otimes K^{T}\big) \tag{11}$$

$$F_o = \mathrm{Conv}\big(f^{T} \otimes \mathrm{GCN}(f \otimes Q)\big) \oplus F_{bh} \tag{12}$$

where:

- CA(·) and SA(·) denote channel attention and spatial attention, respectively.
- Conv(·) is a convolution with a kernel size of 1.
- ⊗ denotes the inner product operation.
- σ denotes the Softmax function.
- AP(·) denotes the pooling and crop operations.
- f is the correlation attention map.
- K^T and f^T are the transposes of K and f, which correspond to the red line in Fig. 3(c).

2.5 Loss Function

The total loss L_T is defined as:

$$L_T = \frac{1}{2}\left(L_D + L_L\right) \tag{13}$$

where L_D and L_L denote the detail loss and the label loss, respectively.

Detail loss: We use the proposed detail label G_d for supervision. The detail loss L_D is defined as:

$$L_D = L_{bce}(F_d, G_d) + L_{iou}(F_d, G_d) + L_{dice}(F_d, G_d) \tag{14}$$

Label loss: The label loss L_L is defined as:

$$L_L = L_{bce}(F_{pre}, G) + L_{iou}(F_{pre}, G) \tag{15}$$

where F_pre is the predicted glass map and G is the glass label.
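The paper does not spell out L_bce, L_iou, and L_dice. The minimal sketch below assumes the common soft formulations on probability maps: pixel-averaged binary cross-entropy, soft IoU loss, and soft Dice loss. It shows how Eqs. (13)–(15) combine these terms. The function names and the epsilon guard are illustrative choices, not taken from the paper.

```python
import numpy as np

EPS = 1e-7  # numerical guard (illustrative)


def bce_loss(pred, gt):
    # Pixel-averaged binary cross-entropy; pred holds probabilities in (0, 1).
    p = np.clip(pred, EPS, 1 - EPS)
    return float(-np.mean(gt * np.log(p) + (1 - gt) * np.log(1 - p)))


def iou_loss(pred, gt):
    # Soft IoU loss: 1 - sum(p*g) / sum(p + g - p*g).
    inter = np.sum(pred * gt)
    union = np.sum(pred + gt - pred * gt)
    return float(1 - (inter + EPS) / (union + EPS))


def dice_loss(pred, gt):
    # Soft Dice loss: 1 - 2*sum(p*g) / (sum(p) + sum(g)).
    inter = np.sum(pred * gt)
    return float(1 - (2 * inter + EPS) / (np.sum(pred) + np.sum(gt) + EPS))


def total_loss(F_d, G_d, F_pre, G):
    L_D = bce_loss(F_d, G_d) + iou_loss(F_d, G_d) + dice_loss(F_d, G_d)  # Eq. (14)
    L_L = bce_loss(F_pre, G) + iou_loss(F_pre, G)                         # Eq. (15)
    return 0.5 * (L_D + L_L)                                              # Eq. (13)


# Example: an uninformative prediction of 0.5 everywhere.
gt = np.array([[1.0, 0.0], [1.0, 0.0]])
pred = np.full_like(gt, 0.5)
print(total_loss(pred, gt, pred, gt))
```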

3 Experiments

3.1 Datasets and Evaluation Metrics

To evaluate the proposed approach, we conduct experiments on GDD [15]. GDD is a large-scale glass object segmentation dataset consisting of 2980 training images and 936 test images that cover a variety of everyday glass scenes. We use intersection over union (IoU), weighted F-measure (Fβω) [12], mean absolute error (MAE), and balanced error rate (BER) [18] to evaluate the performance of our network.

3.2 Comparisons with State-of-the-Art Methods

To verify the superiority of the proposed model, we compare it with 15 state-of-the-art methods selected from related fields. For a fair comparison, the evaluated glass maps are either provided by the authors or generated with the authors' released models and code.

Quantitative Evaluation. Table 1 reports the quantitative comparison on the four evaluation metrics, together with the backbone of each method.
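For reference, the sketch below computes three of these metrics using their standard definitions:

- IoU on binary masks.
- MAE on maps with values in [0, 1].
- BER in percent, as 100 · (1 − ½(TP/N_p + TN/N_n)), as is common in shadow and glass detection.

Binarizing the predictions before computing IoU and BER is an assumption of this sketch. The weighted F-measure [12] involves pixel-dependency weighting and is not reproduced here.

```python
import numpy as np


def iou(pred, gt):
    # Intersection over union of two binary masks.
    pred, gt = pred.astype(bool), gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    return float(np.logical_and(pred, gt).sum() / union) if union else 1.0


def mae(prob, gt):
    # Mean absolute error between a [0, 1] prediction map and the binary label.
    return float(np.mean(np.abs(prob.astype(float) - gt.astype(float))))


def ber(pred, gt):
    # Balanced error rate in percent; assumes gt contains both classes.
    pred, gt = pred.astype(bool), gt.astype(bool)
    tp = np.sum(pred & gt)
    tn = np.sum(~pred & ~gt)
    n_pos, n_neg = gt.sum(), (~gt).sum()
    return float(100 * (1 - 0.5 * (tp / n_pos + tn / n_neg)))


gt = np.array([[1, 1], [0, 0]])
pred = np.array([[1, 0], [0, 0]])
print(iou(pred, gt), mae(pred, gt), ber(pred, gt))  # 0.5 0.25 25.0
```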


Fig. 5. Qualitative results of our method and some state-of-the-art methods.

As the results show, on GDD our method outperforms the other methods on three of the four evaluation metrics. Specifically, compared with GDNet-B [16], the second-best model, our method improves IoU by about 0.6%. Compared with PGSNet [26], the third-best model, our method improves IoU by 0.6% and Fβω by 0.4%. We attribute the performance advantage of DCNet mainly to its well-exploited detail information and its effective aggregation of cross-level features.

Qualitative Evaluation. Fig. 5 shows visual results for several hard scenes: large glass areas (1st–3rd rows), multiple glass regions (4th–5th rows), and small glass regions (6th row). Our method produces better results than the other methods. For example, the scenes in the first three rows of Fig. 5 contain large glass areas with low color contrast, which makes detection harder; even so, our method achieves higher accuracy in boundary regions and a more complete structure. Because of the proposed detail guidance, our network focuses more on boundary regions. Our method also performs well on glass areas with cluttered backgrounds, as shown in the last two rows of Fig. 5.

3.3 Ablation Studies

To verify the effectiveness of the detail guidance strategy and of each key module in the proposed method, we conduct ablation studies on GDD.


Table 1. Quantitative comparison of IoU, weighted F-measure (Fβω), MAE, and BER with state-of-the-art methods on the GDD dataset. * denotes CRF post-processing. The top three results are highlighted in red, green, and blue, respectively.

Methods          Pub’Year   Backbone     IoU ↑   BER ↓   MAE ↓
DeepLabv3+ [1]   ECCV’18    ResNet-50    0.700   15.49   0.147
CCNet [9]        ICCV’19    ResNet-50    0.843   8.63    0.085
FaPN [8]         ICCV’21    ResNet-101   0.867   5.69    0.062
DSS [30]         TPAMI’19   ResNet-50    0.802   9.73    0.123
EGNet [28]       ICCV’19    ResNet-50    0.851   7.43    0.083
F3Net [20]       AAAI’20    ResNet-50    0.848   7.38    0.082
DSC [7]          CVPR’18    ResNet-50    0.836   7.97    0.090
PraNet [4]       MICCAI’20  ResNet-50    0.821   9.33    0.098
MirrorNet* [25]  ICCV’19    ResNeXt-101  0.851   7.67    0.083
TransLab [23]    ECCV’20    ResNet-50    0.816   9.70    0.097
Trans2seg [24]   IJCAI’21   ResNet-50    0.844   7.36    0.078
GDNet [15]       CVPR’20    ResNeXt-101  0.876   5.62    0.063
GSD [11]         CVPR’21    ResNeXt-101  0.875   5.90    0.066
PGSNet [26]      TIP’22     ResNeXt-101  0.878   5.56    0.062
GDNet-B [16]     TPAMI’22   ResNeXt-101  0.878   5.52    0.061
DCNet            Ours       ResNeXt-101  0.884   5.81    0.058

Fβω ↑ (one of the 16 values is missing from the source table, so the per-method alignment cannot be recovered): 0.767 0.867 0.887 0.799 0.870 0.870 0.855 0.847 0.866 0.849 0.872 0.898 0.895 0.901 0.905

Table 2. Quantitative evaluation of the proposed detail guidance on GDD. M: detail supervision removed. The best results are shown in bold.

Methods   IoU ↑   BER ↓   MAE ↓
M         0.873   6.36    0.069
Ours      0.884   5.81    0.058

Fig. 6. Visual illustration of detail guidance.

Effectiveness of Detail Guidance: We conduct experiments to validate the importance of detail guidance, the major component of the proposed network. We remove detail supervision; this setting is denoted M in Table 2. All metrics decline noticeably, especially IoU, which decreases by 1.1%. Fig. 6 shows the effect of detail guidance more intuitively. Without detail guidance, the model loses many structural details (Fig. 6(b)), which leads to blurred boundaries and incorrect segmentation in its predictions. Adding detail guidance effectively conveys detailed structural information and further integrates body features into detailed appearance features.

Table 3. Ablation studies on GDD. The best results are shown in bold.

No.    Setting     IoU ↑   BER ↓   MAE ↓
No.1   baseline    0.861   6.41    0.074
No.2   w/o MDIM    0.875   6.14    0.063
No.3   w/o BCFM    0.871   6.29    0.069
No.4   w/o AGM     0.877   5.97    0.065
No.5   Ours        0.884   5.81    0.058

Effectiveness of Proposed Modules: As shown in Table 3, the ablation studies consist of settings No.1 to No.4, which are compared against the full model (No.5). In the No.1 (baseline) experiment, we remove all MDIMs, all BCFMs, and the AGM, as well as the decoupled detail label supervision, while keeping the cascade framework. The features of the last three layers are then simply fused by element-wise summation after upsampling. In the No.2 (w/o MDIM), No.3 (w/o BCFM), and No.4 (w/o AGM) experiments, we remove one of the three modules at a time from the full network. The No.2 model reduces IoU by 0.9% compared with the final model (No.5), which shows that MDIM helps the model detect objects accurately. Similarly, removing BCFM (No.3) and AGM (No.4) reduces IoU by 1.3% and 0.7%, respectively. This indicates that BCFM and AGM improve the detection of the main parts of glass-like objects.

4 Conclusion

In this paper, we propose DCNet, a detail-guided network that addresses glass-like object detection in an exploration-to-aggregation manner. We first use the decoupled detail label as supervision to provide rich boundary pixel information. We then propose MDIM and BCFM to mine boundary information and global semantic information effectively, and we leverage discontinuities and correlations to refine the glass boundaries. In addition, our AGM integrates the features of the two steps. Experimental results show that our method achieves state-of-the-art glass-like object detection performance.

Acknowledgment. This work is supported by the National Natural Science Foundation of China under Grant No. 62076058.


References

1. Chen, L.C., Papandreou, G., Kokkinos, I., Murphy, K., Yuille, A.L.: Deeplab: semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected crfs. IEEE Trans. Pattern Anal. Mach. Intell. 40(4), 834–848 (2017)
2. Dai, Y., Gieseke, F., Oehmcke, S., Wu, Y., Barnard, K.: Attentional feature fusion. In: Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 3560–3569 (2021)
3. De Boer, P.T., Kroese, D.P., Mannor, S., Rubinstein, R.Y.: A tutorial on the cross-entropy method. Ann. Oper. Res. 134(1), 19–67 (2005)
4. Fan, D.P., et al.: Pranet: parallel reverse attention network for polyp segmentation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 263–273. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59725-2_26
5. Feng, M., Lu, H., Ding, E.: Attentive feedback network for boundary-aware salient object detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1623–1632 (2019)
6. He, H., et al.: Enhanced boundary learning for glass-like object segmentation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 15859–15868 (2021)
7. Hu, X., Zhu, L., Fu, C.W., Qin, J., Heng, P.A.: Direction-aware spatial context features for shadow detection. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7454–7462 (2018)
8. Huang, S., Lu, Z., Cheng, R., He, C.: Fapn: feature-aligned pyramid network for dense image prediction. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 864–873 (2021)
9. Huang, Z., Wang, X., Huang, L., Huang, C., Wei, Y., Liu, W.: Ccnet: criss-cross attention for semantic segmentation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 603–612 (2019)
10. Huo, D., Wang, J., Qian, Y., Yang, Y.H.: Glass segmentation with rgb-thermal image pairs. arXiv preprint arXiv:2204.05453 (2022)
11. Lin, J., He, Z., Lau, R.W.: Rich context aggregation with reflection prior for glass surface detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13415–13424 (2021)
12. Margolin, R., Zelnik-Manor, L., Tal, A.: How to evaluate foreground maps? In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 248–255 (2014)
13. Mattyus, G., Luo, W., Urtasun, R.: Deeproadmapper: extracting road topology from aerial images. In: International Conference on Computer Vision (2017)
14. Mei, H., et al.: Glass segmentation using intensity and spectral polarization cues. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12622–12631 (2022)
15. Mei, H., et al.: Don’t hit me! glass detection in real-world scenes. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3687–3696 (2020)
16. Mei, H., Yang, X., Yu, L., Zhang, Q., Wei, X., Lau, R.W.: Large-field contextual feature learning for glass detection. IEEE Trans. Pattern Anal. Mach. Intell. 45, 3329–3346 (2022)
17. Milletari, F., Navab, N., Ahmadi, S.A.: V-net: fully convolutional neural networks for volumetric medical image segmentation. In: 2016 Fourth International Conference on 3D Vision (3DV), pp. 565–571. IEEE (2016)
18. Nguyen, V., Yago Vicente, T.F., Zhao, M., Hoai, M., Samaras, D.: Shadow detection with conditional generative adversarial networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 4510–4518 (2017)
19. Qin, X., Zhang, Z., Huang, C., Gao, C., Dehghan, M., Jagersand, M.: Basnet: boundary-aware salient object detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 7479–7489 (2019)
20. Wei, J., Wang, S., Huang, Q.: F3net: fusion, feedback and focus for salient object detection. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, pp. 12321–12328 (2020)
21. Wei, J., Wang, S., Wu, Z., Su, C., Huang, Q., Tian, Q.: Label decoupling framework for salient object detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13025–13034 (2020)
22. Whelan, T., et al.: Reconstructing scenes with mirror and glass surfaces. ACM Trans. Graph. 37(4), 102–111 (2018)
23. Xie, E., Wang, W., Wang, W., Ding, M., Shen, C., Luo, P.: Segmenting transparent objects in the wild. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.M. (eds.) Computer Vision – ECCV 2020. ECCV 2020. Lecture Notes in Computer Science, vol. 12358, pp. 696–711. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58601-0_41
24. Xie, E., et al.: Segmenting transparent object in the wild with transformer. arXiv preprint arXiv:2101.08461 (2021)
25. Yang, X., Mei, H., Xu, K., Wei, X., Yin, B., Lau, R.W.: Where is my mirror? In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 8809–8818 (2019)
26. Yu, L., et al.: Progressive glass segmentation. IEEE Trans. Image Process. 31, 2920–2933 (2022)
27. Zhao, H., Shi, J., Qi, X., Wang, X., Jia, J.: Pyramid scene parsing network. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2881–2890 (2017)
28. Zhao, J.X., Liu, J.J., Fan, D.P., Cao, Y., Yang, J., Cheng, M.M.: Egnet: edge guidance network for salient object detection. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 8779–8788 (2019)
29. Zheng, C., et al.: Glassnet: label decoupling-based three-stream neural network for robust image glass detection. In: Computer Graphics Forum, vol. 41, pp. 377–388. Wiley Online Library (2022)
30. Zhou, H., Xie, X., Lai, J.H., Chen, Z., Yang, L.: Interactive two-stream decoder for accurate and fast saliency detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9141–9150 (2020)

A Method for Detecting and Correcting Specular Highlights in Capsule Endoscope Images Based on Independent Cluster Distribution

Jiarui Ma(B) and Yangqing Hou

Jiangnan University, Lihu Avenue 1800, Wuxi, Jiangsu, China

Abstract. Highlights in capsule endoscope images can significantly interfere with tasks such as three-dimensional reconstruction of the digestive tract and lesion recognition. In the non-highlight regions of capsule endoscope images, red-channel pixel values are higher than green- and blue-channel values, whereas in highlight regions the three channel values are similar. Based on this observation, this paper proposes a specular highlight detection and correction method for capsule endoscope images based on independent cluster distribution. First, the method determines the detection threshold from the difference between the red channel and the other channels, computes the RGB channel ratio, and detects the highlight regions. Then, the detected highlight regions are divided into independent clusters using a disconnected-partitioning strategy, and each clustered highlight region is compensated from its surrounding color using filtering and mask decay. Finally, more detailed results are obtained through Gaussian blurring and mask decay. In the experiments, the proposed method improved highlight detection accuracy by 29.81%–41.38% over existing well-known color-model-based methods and improved recall by 28.83% over the adaptive RPCA method. For correction quality, it reduced the NIQE index by 34.9905% compared with a well-known filtering method.

Keywords: Highlight detection · Highlight correction · Endoscope image processing

1 Introduction

The capsule endoscope is small, its light source is tightly enclosed and very close to the reflecting inner wall, and the inner wall of the digestive tract is very smooth. As a result, the highlights in the captured images are inevitably pure specular reflections. In addition, the wrinkled structure of the digestive tract causes irregular highlights and even color distortion in the captured images [1–4]. These highlight areas reduce image quality. They significantly hinder determining the location of lesions, observing the surgical process, and navigating and positioning the capsule during image acquisition. They have become the main source of error in many vision-based computing tasks.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 473–484, 2023.
https://doi.org/10.1007/978-981-99-4742-3_39


J. Ma and Y. Hou

In recent decades, many specular highlight detection and correction algorithms have been proposed. Existing detection methods can roughly locate such highlights, but missed and false detections remain widespread. A common approach is to identify specular highlights with preset thresholds in color spaces such as HSV [5, 6], grayscale [7, 8], and RGB [9, 10]. However, the non-highlight areas of capsule endoscope images tend to be red, and the edges of highlight areas are prone to color distortion. Consequently, color-model-based algorithms cannot use channel-value differences to exclude color-distorted areas when setting thresholds [9, 10]. For highlight correction, capsule endoscope images contain pure specular reflection highlights with irregular shapes and color-distorted edges. Many correction algorithms designed for natural images taken outside the body cavity are therefore not suitable for capsule endoscope images [10–12]. Because of these irregular shapes and color distortion, filter-based methods and adaptive robust principal component analysis (RPCA) methods [13, 14] can produce poor results in large highlight areas. Therefore, this paper proposes an adaptive detection and correction method based on the color distribution characteristics of capsule endoscope images. The main contributions are as follows:
(1) A threshold-adaptive method based on color-channel differences is proposed, which enables accurate detection of both large highlight areas and weak highlight areas.
(2) A highlight compensation method based on peripheral pixels and region-clustering distribution is designed to resolve the inconsistency between the corrected highlights (including their edges) and the surrounding area.

2 Related Work

2.1 Highlight Detection of Capsule Endoscope Images

In recent years, researchers have obtained color thresholds in different color spaces and detected highlights by comparing channel values with these thresholds. Most thresholds are selected based on manual experience. Arnold et al. [10] proposed a segmentation method based on nonlinear filtering and color image thresholding. This method targets highlights in endoscopic images, but its locally estimated color threshold cannot accurately locate the boundaries of large reflective areas. Li et al. [15] exploited the sensitivity of RPCA to outliers to detect scattered weak highlights in dark areas, iteratively decomposing the image into a low-rank matrix and a sparse matrix. However, when a large highlight area appears on the digestive tract surface close to the light source, the highlight component can no longer be treated as a sparse matrix. Meslouh et al. [16] proposed a method based on a dichromatic (two-color) reflection model. It converts an image from the CIE-RGB to the CIE-XYZ color space and identifies specular areas by thresholding the brightness and normalized chromaticity in CIE-XYZ. However, this method is not suitable for reddish endoscopic images captured in soft, constantly changing environments, because highlight areas are difficult to distinguish from the surrounding tissue.

A Method for Detecting and Correcting Specular Highlights


2.2 Highlight Correction of Capsule Endoscope Images

Most existing highlight correction methods focus on processing a single image or a static scene, and they struggle to remove highlights effectively across different capsule endoscope image sequences. Shen et al. [17] classified pixels into clusters in the pseudo-chromaticity space and computed the specular reflection component of each pixel from its intensity. However, capsule endoscope images are reddish and dark, so the intensity-based estimate of the reflection component becomes less accurate, which reduces the accuracy of highlight removal. Arnold et al. [10] proposed a segmentation method and an efficient inpainting method based on nonlinear filtering and color image thresholds. Because this method cannot accurately identify large highlight areas, its correction step often over-corrects. Another category comprises image-sequence methods based on texture functions. Fu et al. [18] proposed exploiting the sparsity of L0-norm coding coefficients to restore the diffuse reflection components of areas with specular highlights. This fills highlights well, but it blends poorly with the edges of the highlight areas, which produces unnatural results. Unlike network-based methods, the method proposed in this paper does not require large amounts of training data and is suitable for large, high-intensity highlight regions. It divides the detected highlight areas into independent clusters and compensates each clustered highlight area with filtering and mask decay. This achieves natural transitions with the edge texture and leaves no artifacts at the edges.

3 Proposed Method

3.1 Method Overview

The framework of the method is shown in Fig. 1. First, pixel significance is computed for the capsule endoscope image with specular highlights, that is, whether each pixel's value is higher than that of its surrounding area. In capsule endoscope images, highlight pixels have similar values in all three channels, whereas non-highlight pixels have a prominent red channel. The ratio of the red channel to the other two channels is therefore compared against an adaptive channel-ratio threshold to screen candidate pixels. Some non-highlight regions also have similar values in all three channels. To exclude them, an adaptive gray threshold is selected based on the histogram gradient, and the candidates are filtered again to obtain a highlight region mask X. The highlight mask is then divided into multiple disconnected highlight regions. For each connected highlight region, the pixels around it are taken as its surrounding area. The surrounding pixel closest to the region's centroid is found, and its value is filled into the corresponding highlight region to obtain the filled image. Finally, Gaussian smoothing and mask decay are used to soften the fill color, which makes the image look more natural and improves the correction.


Fig. 1. Framework of the endoscopic image specular highlight detection and correction method based on color distribution.

3.2 Threshold Adaptive Highlight Detection for Color Channel Differences

Specular highlights arise from mirror-like reflection when a light source shines directly on the surface of an object, and they appear as local highlights in the image. In the image, a highlight area is brighter than the surrounding pixels, which is expressed as

$$I(x) > \alpha \cdot I_{mean} \tag{1}$$

where I(x) is the brightness intensity of pixel x, I_mean is the average brightness intensity of the entire image, and α is a constant coefficient. To determine from Eq. (1) whether the image contains highlight areas, αI_mean < I(x) is first evaluated for every pixel, and the result is recorded in a binary matrix I_mean_flag. Where αI_mean < I(x) holds, the corresponding entry of I_mean_flag is set to 1; otherwise it is set to 0. This criterion is highly susceptible to ambient lighting. Moreover, if I_mean is high overall, large highlight areas are difficult to distinguish with this criterion. To solve these problems, this paper proposes a criterion based on the RGB color distribution. In the digestive tract, hemoglobin makes most gastrointestinal images captured by a capsule endoscope appear flesh-red under normal circumstances. That is, the red channel values are higher than the green and blue channel values. Under specular highlights, however, the three RGB channel values are almost equal and tend toward saturation, especially in large highlight areas. Therefore, we define the discrimination criterion as the ratio of the red channel to the average of the other two channels:

$$R = \frac{I_R(x)}{\frac{1}{2}\left(I_G(x) + I_B(x)\right)} \tag{2}$$

where I_R(x), I_G(x), and I_B(x) are the intensities of the R, G, and B channels at pixel x.


According to Eq. (2), R is higher for non-highlight pixels than for highlight pixels. Therefore, a threshold t is introduced to distinguish highlight from non-highlight pixels, defined as

$$t = \mu_1 d^2 + \mu_2 d + \mu_3 \tag{3}$$

where d is the difference between the red channel and the average of the green and blue channels, averaged over the pixels x:

$$d = \mathrm{mean}\left(I_R(x) - \frac{1}{2}\left(I_G(x) + I_B(x)\right)\right) \tag{4}$$

Using this criterion for saturated specular pixels, pixels whose R falls below t are marked as highlight pixels:

$$R < t \tag{5}$$

$$I(x) > \mathrm{thresh} \tag{6}$$

where thresh is the adaptive gray-level threshold. Pixels that satisfy both the channel-ratio criterion in Eq. (5) and the adaptive gray-level condition in Eq. (6) are detected as highlight areas. To further improve the adaptability of the detection thresholds, the image is first divided into overlapping, equal-sized patches. Setting adaptive thresholds for each patch addresses the uneven illumination caused by a single light source. Specifically, I_mean in Eq. (1) is defined for each patch A as

$$I_{mean} = \mathrm{mean}(I(x)), \quad x \in A \tag{7}$$

Likewise, d in Eqs. (3) and (4) is computed from the information in each patch, so the threshold t in Eq. (5), which is derived from d, is set adaptively for each patch. Before the correction algorithm is applied, the pixels that satisfy both the red-channel saturation criterion and the gray-level threshold are selected as the detected highlight pixels. Highlight pixels are assigned the value (255, 255, 255) and all other pixels (0, 0, 0), which produces a binary highlight region mask.


3.3 Highlight Correction for Region Clustering

After the specular highlight mask is obtained, it is divided into regions according to the connectivity of the highlight areas. As shown in Fig. 2, the algorithm identifies a total of n disconnected highlight regions in the original image and marks them with n different colors. In Fig. 2, (a) is the original image, (b) is the current highlight region mask, and (c) shows the highlight regions of the current image divided into n disconnected regions, each labeled with a different color.

Fig. 2. Dividing image area (n = 46). Panels: (a) original image; (b) highlight area mask; (c) subarea image.

After the highlight areas are partitioned, the fill colors are taken from non-highlight pixels in the range closest to each specular highlight area, so that the filling looks natural. Non-highlight pixel blocks whose edges adjoin the edge pixel blocks of the specular highlight area are selected. The surrounding area AroundX_i is obtained by removing from these selected areas any pixels that lie within highlight areas. Figure 3 illustrates this for the global image and one of the connected highlight regions, Local 1.

Fig. 3. Schematic diagram of selecting peripheral pixels. Panels: (a) original image (full); (b) subarea map (full); (c) filled area (full).
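Region partitioning and the surrounding area AroundX_i can be sketched with plain NumPy and a breadth-first search. The sketch assumes 4-connectivity and a one-pixel-wide ring obtained by 3 × 3 dilation, because the paper states neither the connectivity nor the size of the peripheral pixel blocks. The function names are hypothetical.

```python
from collections import deque

import numpy as np


def label_regions(mask):
    """4-connected component labeling of a boolean mask; returns (labels, n)."""
    h, w = mask.shape
    labels = np.zeros((h, w), dtype=int)
    n = 0
    for y in range(h):
        for x in range(w):
            if mask[y, x] and labels[y, x] == 0:
                n += 1                      # start a new region and flood-fill it
                labels[y, x] = n
                queue = deque([(y, x)])
                while queue:
                    cy, cx = queue.popleft()
                    for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                        if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and labels[ny, nx] == 0:
                            labels[ny, nx] = n
                            queue.append((ny, nx))
    return labels, n


def dilate3x3(region):
    """Binary dilation with a 3x3 square structuring element (zero padding)."""
    h, w = region.shape
    padded = np.pad(region, 1)
    out = np.zeros_like(region)
    for dy in range(3):
        for dx in range(3):
            out |= padded[dy:dy + h, dx:dx + w]
    return out


def surrounding_area(labels, i, highlight_mask):
    """AroundX_i: pixels adjacent to region i that are not highlight pixels."""
    return dilate3x3(labels == i) & ~highlight_mask


mask = np.zeros((6, 6), dtype=bool)
mask[1:3, 1:3] = True
mask[4, 4] = True
labels, n = label_regions(mask)
print(n, int(surrounding_area(labels, 1, mask).sum()))  # 2 12
```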

To simplify the calculation, the centroid of each highlight area is used as its representative point. For a specular highlight region, the centroid centroid_i is the center point of the region: centroid_i(row) and centroid_i(col) are the averages of the row and column indices of all points in the region AroundX_i, defined as

$$\mathrm{Centroid}_i(\mathrm{row}) = \mathrm{mean}(\mathrm{row\_index}_i) \tag{8}$$

$$\mathrm{Centroid}_i(\mathrm{col}) = \mathrm{mean}(\mathrm{col\_index}_i) \tag{9}$$
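The centroid of Eqs. (8)–(9), combined with the nearest-pixel fill formalized next in Eq. (10), can be sketched as follows. As an assumption, the centroid is computed from the highlight region's own pixels, matching "the centroid of each highlight area". The text instead averages the indices over AroundX_i, which typically gives a nearby point for a thin ring around a compact region. Distance ties are broken by row-major scan order. The names are illustrative.

```python
import numpy as np


def centroid(region):
    # Eqs. (8)-(9): mean row and column index of the True pixels of a boolean mask.
    rows, cols = np.nonzero(region)
    return rows.mean(), cols.mean()


def fill_from_nearest(img, region, around):
    """Fill `region` with the colour of the `around` pixel nearest to the region centroid."""
    c_row, c_col = centroid(region)
    rows, cols = np.nonzero(around)
    dist = np.sqrt((rows - c_row) ** 2 + (cols - c_col) ** 2)   # distances as in Eq. (10)
    k = int(np.argmin(dist))                                    # nearest surrounding pixel
    out = img.copy()
    out[region] = img[rows[k], cols[k]]                         # broadcast its RGB value
    return out


img = np.zeros((5, 5, 3))
img[2, 3] = (10, 20, 30)
img[0, 0] = (99, 99, 99)
region = np.zeros((5, 5), dtype=bool)
region[2, 2] = True
around = np.zeros((5, 5), dtype=bool)
around[2, 3] = True
around[0, 0] = True
print(fill_from_nearest(img, region, around)[2, 2])  # [10. 20. 30.]
```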

All pixels in the surrounding area AroundX_i are traversed, with the centroid of the current highlight area as the representative point, and the Euclidean distance between each pixel and the centroid is computed as shown in Eq. (10). The pixel closest to the centroid is selected, and its value is assigned to the entire highlight area to obtain the filled image:

$$\mathrm{MinDis} = \min_{(row,\,col) \in \mathrm{AroundX}_i} \sqrt{(row - c\_row)^2 + (col - c\_col)^2} \tag{10}$$

Because each highlight area is filled with a single pixel value, a color difference can easily arise between the highlight area and other areas, which produces visible boundaries. To make the color transition natural, this method applies Gaussian smoothing to the filled image to obtain a smoothed filled image f_img. In addition, capsule endoscope images have multi-level color. Mask decay is therefore applied to the filled region, which gives the image more gradation and brings it closer to its real appearance. A decay mask of size 3 × 3 is used to filter the image and obtain a decay matrix d. The corrected image is computed from the smoothed filled image f_img, the decay matrix d, and the original image img as

$$r\_img = d \ast f\_img + (1 - d) \ast img \tag{11}$$
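The smoothing and blending step of Eq. (11) can be sketched as follows. The Gaussian σ and radius are hypothetical. The decay matrix d is assumed to be the 3 × 3 mean filter of the binary highlight mask, since the paper only states that a 3 × 3 decay mask is used. Under this assumption, d equals 1 well inside a highlight and 0 far from it, and it takes fractional values in a thin band at the boundary.

```python
import numpy as np


def gaussian_kernel1d(sigma, radius):
    # Normalized 1-D Gaussian kernel of length 2*radius + 1.
    x = np.arange(-radius, radius + 1, dtype=float)
    k = np.exp(-x ** 2 / (2 * sigma ** 2))
    return k / k.sum()


def filter_separable(channel, k):
    # Same-size separable filtering (rows, then columns) with edge replication.
    r = len(k) // 2
    padded = np.pad(channel, r, mode="edge")
    tmp = np.apply_along_axis(lambda v: np.convolve(v, k, mode="valid"), 1, padded)
    return np.apply_along_axis(lambda v: np.convolve(v, k, mode="valid"), 0, tmp)


def correct_highlights(img, filled, mask, sigma=2.0, radius=4):
    """img, filled: (h, w, 3) float arrays; mask: boolean highlight mask."""
    g = gaussian_kernel1d(sigma, radius)
    f_img = np.stack([filter_separable(filled[..., c], g) for c in range(3)], axis=2)
    box = np.full(3, 1.0 / 3.0)                                  # 3x3 mean filter, separable
    d = filter_separable(mask.astype(float), box)[..., None]     # assumed decay matrix in [0, 1]
    return d * f_img + (1 - d) * img                             # Eq. (11)


img = np.full((11, 11, 3), 100.0)
filled = np.full((11, 11, 3), 40.0)
mask = np.zeros((11, 11), dtype=bool)
mask[3:8, 3:8] = True
out = correct_highlights(img, filled, mask)
print(round(float(out[5, 5, 0]), 3), round(float(out[2, 5, 0]), 3))  # 40.0 80.0
```

In practice, `filled` would be the output of the nearest-pixel fill, which equals the original image outside the highlight regions.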

4 Experiments

4.1 Highlight Detection

The experiments are run in MATLAB R2021b on Windows 11, on an Intel i7-10400F CPU with a main frequency of 3.3 GHz. The experimental data are taken from CVC-ClinicDB [19]. Based on a large number of clinical experiments, α in Eq. (1) is set to 2.4. The parameters μ1, μ2, and μ3 in Eq. (3) are set to −2.151 × 10⁻⁵, 2.031 × 10⁻³, and 1.221, respectively.

4.2 Highlight Detection

Evaluation Metrics for Highlight Detection. The performance of the proposed method is evaluated quantitatively.


J. Ma and Y. Hou

The definitions of TP, TN, FP and FN (numbers of pixels) are as follows.

Real situation \ Prediction result | Positive example | Negative example
Positive example | TP | FN
Negative example | FP | TN

In this study, precision, recall and accuracy are used to quantify the performance of the proposed highlight detection method.

Evaluation Indicators Related to Highlight Correction. A highlight-corrected image necessarily differs from the original image, so full-reference image quality indicators are not suitable for evaluating corrected capsule endoscope images [20]. This paper therefore adopts the no-reference Natural Image Quality Evaluator (NIQE). Compared with other indicators used in image correction, NIQE agrees better with human subjective judgments of image quality, including the perception of texture detail, which makes it suitable for this task. A larger NIQE value indicates worse quality [21]. The final NIQE value is calculated as

$$D(\nu_1, \nu_2, \Sigma_1, \Sigma_2) = \sqrt{(\nu_1 - \nu_2)^T \left(\frac{\Sigma_1 + \Sigma_2}{2}\right)^{-1} (\nu_1 - \nu_2)} \tag{12}$$

where ν1, Σ1 and ν2, Σ2 are the mean vectors and covariance matrices of the multivariate Gaussian models fitted to natural images and to the test image, respectively.

The maximum pixel difference reflects whether a highlight image has been effectively restored: if the highlight area is effectively corrected, the maximum pixel difference decreases compared with the original image. The pixel average reflects whether a highlight image keeps a brightness similar to the original while being corrected. An effectively corrected image should have a lower pixel average, but a large drop indicates that the overall brightness of the image may have changed significantly.

4.3 Highlight Detection
Subjective Evaluation. The superiority of the proposed method was further verified.
As shown in Fig. 4, the method of Shen DF [7] can detect scattered small highlights, but it tends to over-cover them, so the detected area is larger than the actual highlight area. The method of Arnold [10] tends to label non-highlight areas as highlights, resulting in over-detection. The method of Meslouhi [16] tends to miss some small highlight areas. The method of Li [15] performs less well on images with large highlight areas. The proposed method obtains accurate highlight detection results, detecting both large highlight areas and weak highlights.

A Method for Detecting and Correcting Specular Highlights

Fig. 4. Effects of different methods of highlight detection (columns: Initial Image, Shen DF [7], Meslouhi [16], Arnold [10], Li [15], Proposed; rows: Pic1, Pic2, Pic3).

Objective Evaluation. As shown in Table 1, the proposed algorithm improves precision by 41.38 and 29.81 percentage points compared with [7] and [10], respectively, improves recall by 29.78 and 28.83 percentage points compared with [16] and [15], respectively, and improves accuracy by 1.88, 0.95, 0.71 and 0.82 percentage points compared with [7], [10], [16] and [15], respectively.

Table 1. Comparison of TN, FN, TP, FP, precision, recall and accuracy.

Method | TN | FN | TP | FP | Precision (%) | Recall (%) | Accuracy (%)
Shen DF [7] | 105625 | 183 | 2880 | 2504 | 53.49 | 94.03 | 97.57
Meslouhi [16] | 107519 | 1380 | 1679 | 14 | 99.17 | 54.89 | 98.74
Arnold [10] | 105914 | 45 | 3014 | 1619 | 65.06 | 98.53 | 98.50
Li [15] | 107179 | 1501 | 1898 | 14 | 99.27 | 55.84 | 98.63
Proposed | 107393 | 469 | 2590 | 140 | 94.87 | 84.67 | 99.45
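The precision, recall and accuracy columns of Table 1 follow directly from the confusion-matrix counts defined above. A short Python check using the counts of the Proposed row (the function name `detection_metrics` is illustrative):

```python
def detection_metrics(tp, tn, fp, fn):
    """Return precision, recall and accuracy (in %) from pixel counts."""
    precision = 100.0 * tp / (tp + fp)                   # correct among predicted highlights
    recall = 100.0 * tp / (tp + fn)                      # found among real highlights
    accuracy = 100.0 * (tp + tn) / (tp + tn + fp + fn)   # correct among all pixels
    return precision, recall, accuracy

# Counts taken from the "Proposed" row of Table 1
p, r, a = detection_metrics(tp=2590, tn=107393, fp=140, fn=469)
print(f"{p:.2f} {r:.2f} {a:.2f}")  # → 94.87 84.67 99.45
```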

Ablation Experiment. Table 2 evaluates the key components of the highlight detection algorithm: the three-channel ratio R and the adaptive grayscale thresholds (the I_mean threshold and the I(x) threshold). The three-channel ratio R and the adaptive grayscale threshold clearly have a significant impact on detection: the absence of both causes non-highlight pixels to be detected as highlight pixels. Although recall increases, precision and accuracy drop significantly.

Table 2. Ablation study of highlight detection.

R | I_mean threshold | I(x) threshold | Precision (%) | Recall (%) | Accuracy (%)
✓ | ✓ | ✓ | 94.87 | 84.67 | 99.45
 |  |  | 65.54 | 94.83 | 98.48
 |  |  | 60.11 | 95.33 | 98.12

4.4 Highlight Correction
Subjective Evaluation. Shen's method locates highlight areas well and fills them with grayscale values over a small range. Fu et al.'s method locates highlight areas more accurately and over a larger range, but it mainly masks the highlight areas and does not correct them well. Arnold et al.'s method cannot effectively avoid the influence of the capsule endoscope image edges on the repair and correction process. Visually, the proposed method corrects both dot-shaped and patch-shaped highlight areas, so that their colors blend better with the surrounding areas (Fig. 5).

Fig. 5. Effects of different methods of highlight correction (columns: Initial Image, Shen [17], Fu [18], Arnold [10], Proposed; rows: Pic1, Pic2, Pic3).

Objective Evaluation. Table 3 shows that the proposed method achieves the lowest NIQE value on all three images and lowers NIQE relative to the initial images. It also performs better than the methods of Fu et al. and Arnold et al., improving the quality of the corrected images. Compared with the method of Shen et al., the proposed method performs better on images with sufficient light. Table 4 shows the recovery results of each highlight correction method: every method reduces the maximum pixel difference of the original images substantially, indicating that the highlight areas are corrected.
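The last step of NIQE, behind the scores in Table 3, is the distance of Formula (12) between two multivariate Gaussian (MVG) models. The NumPy sketch below assumes the mean vectors and covariance matrices have already been fitted; the natural-scene-statistics feature extraction that full NIQE performs on image patches is not shown, and the function name `niqe_distance` is illustrative. A pseudo-inverse is used in place of the plain inverse so that a singular pooled covariance does not raise an error.

```python
import numpy as np

def niqe_distance(nu1, nu2, sigma1, sigma2):
    """Formula (12): Mahalanobis-type distance between the natural-image MVG
    model (nu1, sigma1) and the MVG model of the test image (nu2, sigma2)."""
    diff = np.asarray(nu1, dtype=float) - np.asarray(nu2, dtype=float)
    pooled = (np.asarray(sigma1, dtype=float) + np.asarray(sigma2, dtype=float)) / 2.0
    return float(np.sqrt(diff @ np.linalg.pinv(pooled) @ diff))

# Toy 2-D example: identity covariances and a mean shift of (3, 4)
print(niqe_distance([0, 0], [3, 4], np.eye(2), np.eye(2)))  # ≈ 5.0
```

With identity covariances the formula reduces to the Euclidean distance between the means; larger pooled covariances shrink the distance, so the score measures the mean shift relative to the spread of the features.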


Table 3. NIQE test results.

Image | Initial Image | Shen [17] | Fu [18] | Arnold [10] | Proposed
Pic1 | 4.8334 | 4.8789 | 39.7983 | 5.0196 | 4.8078
Pic2 | 5.5140 | 5.7849 | 8.6031 | 5.6982 | 5.0629
Pic3 | 6.5942 | 7.3507 | 10.3578 | 6.0750 | 6.0740

Table 4. Maximum pixel difference.

Image | Initial Image | Shen [17] | Fu [18] | Arnold [10] | Proposed
Pic1 | 249 | 194 | 177 | 206 | 224
Pic2 | 246 | 170 | 168 | 189 | 172
Pic3 | 253 | 181 | 182 | 197 | 143

Table 5 shows the (normalized) average pixel values for each method. The proposed method reduces the pixel average of the highlight images only slightly, keeping it close to that of the original images, which indicates that the original brightness is largely preserved.

Table 5. Pixel average.

Image | Initial Image | Shen [17] | Fu [18] | Arnold [10] | Proposed
Pic1 | 0.3784 | 0.3496 | 0.3048 | 0.3665 | 0.3758
Pic2 | 0.2885 | 0.2763 | 0.2398 | 0.2807 | 0.2868
Pic3 | 0.2325 | 0.2061 | 0.1884 | 0.2265 | 0.2154
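The two auxiliary indicators in Tables 4 and 5 can be computed as follows. The paper does not define the maximum pixel difference precisely, so this sketch assumes it is the largest minus the smallest gray level of the image, and it assumes 8-bit images when normalizing the pixel average to [0, 1]; both function names are illustrative.

```python
import numpy as np

def max_pixel_difference(gray):
    """Assumed definition: largest minus smallest gray level of the image."""
    g = np.asarray(gray, dtype=np.int64)  # avoid uint8 wrap-around
    return int(g.max() - g.min())

def pixel_average(gray, max_value=255.0):
    """Mean gray level normalized to [0, 1] (8-bit input assumed)."""
    return float(np.mean(gray) / max_value)

img = np.array([[0, 255], [51, 102]], dtype=np.uint8)
print(max_pixel_difference(img), pixel_average(img))  # → 255 0.4
```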

5 Conclusion
This paper proposes an adaptive specular highlight detection method for capsule endoscope images. Taking into account the color distribution of capsule endoscope images, it introduces a highlight detection criterion with adaptive thresholds. After the highlight areas are accurately detected, they are divided into disconnected clusters, each cluster is filled with colors from its surroundings, and Gaussian blur and mask decay are applied so that the corrected image looks more natural. Experiments show that the scheme detects highlight areas accurately and achieves good subjective and objective restoration results.


References
1. Xia, W., Chen, E., Pautler, S.E.: A global optimization method for specular highlight removal from a single image. IEEE Access 7, 125976–125990 (2019)
2. Kang, H., Hwang, D., Lee, J.: Specular highlight region restoration using image clustering and inpainting. J. Vis. Commun. Image Representation 77, 103106 (2021)
3. Wu, C.H., Su, M.Y.: Specular highlight detection from endoscopic images for shape reconstruction. Int. Symp. Measure. Technol. Intell. Instrum. 870, 357–362 (2017)
4. Yu, B., Chen, W., Zhong, Q.: Specular highlight detection based on color distribution for endoscopic images. Frontiers Phys. 8, 616930 (2021)
5. Oh, J.H., Hwang, S., Lee, J.K.: Informative frame classification for endoscopy video. Med. Image Anal. 11(2), 110–127 (2007)
6. Francois, A., Medioni, G.: Adaptive color background modeling for real-time segmentation of video streams. In: International Conference on Imaging Science, Systems, and Technology CISST'99 (2022)
7. Shen, D.F., Guo, J.J., Lin, G.S.: Content-aware specular reflection suppression based on adaptive image inpainting and neural network for endoscopic images. Comput. Methods Programs Biomed. 192, 105414 (2020)
8. Chu, Y., Li, H., Li, X.: Endoscopic image feature matching via motion consensus and global bilateral regression. Comput. Methods Programs Biomed. 190, 105370 (2020)
9. Alsaleh, S.M., Aviles-Rivero, A.I., Hahn, J.K.: ReTouchImg: fusioning from-local-to-global context detection and graph data structures for fully-automatic specular reflection removal for endoscopic images. Comput. Med. Imaging Graph. 73, 39–48 (2019)
10. Arnold, M., Ghosh, A., Ameling, S.: Automatic segmentation and inpainting of specular highlights for endoscopic imaging. J. Image Video Process. 2010, 1–12 (2010)
11. Wu, Z., Zhuang, C., Shi, J.: Deep specular highlight removal for single real-world image. In: SA 2020: SIGGRAPH Asia 2020 (2020)
12. Nguyen, T., Nhat, V.Q., Kim, S.H.: A novel and effective method for specular detection and removal by tensor voting. In: IEEE International Conference on Image Processing, IEEE (2015)
13. Banik, P.P., Saha, R., Kim, K.D.: HDR image from single LDR image after removing highlight. In: 2018 IEEE International Conference on Consumer Electronics (ICCE), IEEE (2018)
14. Nie, C., Xu, C., Feng, B.: Specular reflections removal for endoscopic images based on improved Criminisi algorithm. In: 2021 IEEE 6th International Conference on Computer and Communication Systems (ICCCS), IEEE (2021)
15. Li, R., Pan, J., Si, Y.: Specular reflections removal for endoscopic image sequences with adaptive-RPCA decomposition. IEEE Trans. Med. Imaging 39(2), 328–340 (2020)
16. El Meslouhi, O.: Automatic detection and inpainting of specular reflections for colposcopic images. Central Eur. J. Comput. Sci. 1, 341–354 (2011)
17. Shen, H.L., Zheng, Z.H.: Real-time highlight removal using intensity ratio. Appl. Opt. 52(19), 4483–4493 (2013)
18. Fu, G., Zhang, Q., Song, C.: Specular highlight removal for real world images. Comput. Graph. Forum 38(7), 253–263 (2019)
19. Sánchez, F.J., Bernal, J., Sánchez-Montes, C., de Miguel, C.R., Fernández-Esparrach, G.: Bright spot regions segmentation and classification for specular highlights detection in colonoscopy videos. Mach. Vis. Appl. 28(8), 917–936 (2017). https://doi.org/10.1007/s00138-017-0864-0
20. Mittal, A., Soundararajan, R., Bovik, A.C.: Making a "Completely Blind" image quality analyzer. IEEE Signal Process. Lett. 20(3), 209–212 (2013)
21. Yue, C., Li, Z., Xu, C., Feng, B.: High light removal algorithm for medical capsule endoscope. Comput. Appl. 1–6 (2022)

Siamese Adaptive Template Update Network for Visual Tracking
Jia Wen1,2, Kejun Ren1,2(B), Yang Xiang1,2, and Dandan Tang1,2
1 School of Information Science and Engineering, Yanshan University, Qinhuangdao 066004, China
2 The Key Laboratory for Computer Virtual Technology and System Integration of Hebei Province, Shijiazhuang, China

Abstract. Siamese-based trackers have achieved strong performance in single-target tracking. Effective feature response maps are fundamental to tracker performance in challenging scenes. However, most Siamese-based trackers keep the template features constant during tracking, which greatly limits their effectiveness in complex scenes. To solve this issue, we propose a novel tracking framework, termed SiamATU, which adaptively updates the template features. The update module is trained with a multi-stage strategy so that the template update is optimized step by step. In addition, we design a feature enhancement module that strengthens the discriminability and robustness of the features: it exploits the rich correlation between the template image and the search image, making the model focus more on the tracked object and track it more precisely. Extensive experiments on GOT-10K, UAV123, OTB100 and other datasets show that SiamATU achieves leading performance while running at 26.23 FPS, above the real-time level of 25 FPS.
Keywords: Feature enhancement · Template update · Siamese network · Single-target tracking

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 485–497, 2023.
https://doi.org/10.1007/978-981-99-4742-3_40

1 Introduction
Target tracking is a highly active research area in computer vision. Given the initial position of the target object, a tracker predicts the position of the object in each subsequent frame of the video. However, existing tracking models face numerous challenges, especially in real-world applications, where tracking is often affected by illumination variation, appearance change, various types of occlusion, motion blur and other factors. In addition, most models cannot balance accuracy and speed: regression-based models are fast but have low accuracy, whereas high-accuracy models are mostly classification-based but slow. This severely limits their range of applications. The most popular deep learning tracking model is the Siamese network tracker, which reduces tracking to an object matching problem and determines the location of the target in the search area by learning a general similarity map generated by cross-correlation. For example, SiamFC [1] was the first to apply Siamese networks to single-target tracking, in 2016. In 2020, researchers found that introducing a region proposal network (RPN) made these trackers sensitive to the number, size and aspect ratio of the anchors. Therefore, numerous anchor-free Siamese tracking algorithms were proposed, taking Siamese tracking to a new level.

Since most Siamese trackers are trained offline, they cannot update their templates online. Fixed template features make it difficult to track targets accurately under large deformation, fast motion or occlusion, which significantly increases the risk of tracking drift or target loss. Moreover, although the backbone network can extract sufficient feature information for tracking, this information is not used effectively. As a result, the similarity score map generated by cross-correlation contains too much redundant information in the presence of similar distractors or background clutter. To improve this situation, we propose an Enhanced Multi-Attention module that acts before the cross-correlation of the Siamese network. Given these developments and open issues, we chose to increase accuracy considerably at the cost of a small loss in speed; nevertheless, the proposed model, SiamATU, still meets the real-time requirement. Its principal contributions are as follows:
1. UpdateNet [5] is incorporated so that SiamATU generates updated templates in a learnable manner. Trained with a multi-stage method, SiamATU achieves higher tracking accuracy and greater robustness.
2. An enhanced multi-attention module (EMA) is proposed to improve matching in the cross-correlation step. Unlike conventional cross-correlation, it greatly enhances the matching of salient information and effectively delivers target information to the search region, enhancing the target features in the search area and improving matching accuracy.
3. SiamATU achieves leading performance on several public benchmarks, including GOT-10K [6], DTB70 [7], UAV123 [8] and OTB100 [9]. Its tracking speed of 26.23 FPS exceeds the 25 FPS real-time level.

2 Related Works
This section reviews the current status of tracking frameworks, template update and attention mechanisms.
Tracking Framework: Most existing deep learning trackers fall into three main categories, among which are correlation filter trackers and Siamese network trackers. Current studies show that Siamese network trackers achieve the best balance between accuracy and efficiency. In 2020, researchers proposed a large number of anchor-free trackers, including SiamCAR [2], SiamFC++ [3] and Ocean [4], which have become the mainstream of Siamese trackers.
Template Update: Most trackers either update the template at every frame by simple linear interpolation [10] or do not update it at all. The template update mechanism of most trackers is far from adequate under deformation, fast motion or occlusion.
Attention Mechanism: Attention mechanisms are widely used in various tasks. Hu et al. [11] proposed SENet, which strengthens the representational capability of a network by modeling the relationships between channels. Wang et al. [12] proposed RASNet, which uses attention for feature extraction in a Siamese tracker.

3 Proposed Methods
In this section, we describe the SiamATU model, shown in Fig. 1. The model consists of four main components: the Siamese backbone network, the template update module, the feature enhancement module and the classification and regression sub-network.

Fig. 1. SiamATU framework structure

3.1 Siamese Backbone Network
In the feature extraction phase, the input of the template branch is the image of the first frame, denoted Z, and the input of the search branch is the image of each subsequent frame, denoted X. The backbone is a 5-layer network, and Fi(·) denotes the features extracted by its i-th layer. Because different layers extract different feature information, retaining more of it enables more accurate tracking, so the deep features of the last three layers are used in the subsequent tracking process. The template features of these three layers are denoted F3(Z), F4(Z), F5(Z), and the search features F3(X), F4(X), F5(X). ★ denotes the cross-correlation operation, and Ri is the similarity map produced by cross-correlation at the i-th layer:

$$R_3 = F_3(Z) \star F_3(X) \tag{1}$$

$$R_4 = F_4(Z) \star F_4(X) \tag{2}$$

$$R_5 = F_5(Z) \star F_5(X) \tag{3}$$

These three layers of deep features are then fused. Let T3,4,5 denote the fused output features:

$$T_{3,4,5} = \mathrm{Cat}(R_3, R_4, R_5) \tag{4}$$

where R3, R4 and R5 each have 256 channels, so T3,4,5 is a feature response map with 3 × 256 = 768 channels. To reduce the number of model parameters and facilitate subsequent operations, a convolution layer with 1 × 1 kernels is applied to T3,4,5 to compress it to 256 channels.
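Equations (1)–(4) and the 1 × 1 compression can be sketched in NumPy as follows. The paper does not say which cross-correlation variant ★ is; this sketch assumes depth-wise cross-correlation (each template channel is slid over the matching search channel), as used in SiamCAR-style trackers, uses a toy channel count in place of 256, and treats the 1 × 1 convolution as a bias-free per-position linear map with a hypothetical weight matrix.

```python
import numpy as np

def depthwise_xcorr(z, x):
    """z: (C, h, w) template features, x: (C, H, W) search features.
    Returns the (C, H-h+1, W-w+1) similarity map R."""
    C, h, w = z.shape
    _, H, W = x.shape
    out = np.empty((C, H - h + 1, W - w + 1))
    for i in range(H - h + 1):
        for j in range(W - w + 1):
            out[:, i, j] = np.sum(z * x[:, i:i + h, j:j + w], axis=(1, 2))
    return out

def fuse_layers(templates, searches, weight):
    """R_k = xcorr(F_k(Z), F_k(X)) for k = 3, 4, 5; T = Cat(R3, R4, R5) along
    channels; a 1x1 convolution (weight: (C, 3C)) compresses T back to C channels."""
    responses = [depthwise_xcorr(z, x) for z, x in zip(templates, searches)]
    t = np.concatenate(responses, axis=0)           # T_{3,4,5}: (3C, H', W')
    return np.einsum("oc,chw->ohw", weight, t)      # 1x1 conv: (C, H', W')

rng = np.random.default_rng(0)
C = 4                                               # toy stand-in for 256 channels
zs = [rng.standard_normal((C, 2, 2)) for _ in range(3)]
xs = [rng.standard_normal((C, 5, 5)) for _ in range(3)]
print(fuse_layers(zs, xs, rng.standard_normal((C, 3 * C))).shape)  # → (4, 4, 4)
```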

Fig. 2. Framework of template update module

Fig. 3. The process of online template update training

3.2 Adaptive Template Update Module
Most Siamese trackers either use a simple weighted-averaging strategy or select a subset of historical frames to update the template. Although these methods bring improvements, their simplicity also limits the accuracy of the template update. The simple linear update is

$$Z_i = (1 - \alpha) Z_{i-1} + \alpha T_i \tag{5}$$

where $T_i$ represents the features extracted from the prediction box of the current frame, $Z_{i-1}$ the current cumulative template, and $Z_i$ the generated template used to predict the next frame. This simple strategy often fails in challenging scenarios. We therefore adopt UpdateNet [13], a simple two-layer learnable network that learns the update strategy adaptively from the training data:

$$\tilde{Z}_i = \phi\left(T_0^{GT}, \tilde{Z}_{i-1}, Z_i\right) \tag{6}$$


where $\phi(\cdot)$ denotes the learned function, $T_0^{GT}$ the initial template from the first frame of each video sequence, $\tilde{Z}_{i-1}$ the previous accumulated template, $Z_i$ the template obtained from the prediction result of the i-th frame, and $\tilde{Z}_i$ the template used to predict frame i + 1. The detailed architecture and update process of the template update module are shown in Fig. 2. Because the module always has access to the template of the initial frame, the robustness of the template update is greatly improved (Fig. 3).
3.3 Enhanced Multi-attention Module
We propose an Enhanced Multi-Attention module, referred to as the EMA module, to improve the salience of feature information. Figure 4 shows its details. The module takes the updated template features and the search features as input and consists of three sub-modules. The template features and the search features are denoted Z and X, with shapes C × h × w and C × H × W, respectively. The search features X are used below to introduce the three sub-modules.
Channel-Attention. The channel attention module focuses on the relationships between channels. Unlike object detection and image classification, which have predefined categories, single-object tracking is a category-independent task, although the object category stays fixed throughout the tracking process. Each channel of the convolutional features of a deep network typically responds to a specific object category, so treating all channels equally hinders feature expressiveness. Specifically, let the input features of this module be $X \in \mathbb{R}^{C \times H \times W}$. Keeping the channel dimension unchanged, we first apply an average pooling layer and a max pooling layer to X to generate $X_A \in \mathbb{R}^{C \times 1 \times 1}$ and $X_M \in \mathbb{R}^{C \times 1 \times 1}$, respectively. These two features are reshaped into $X'_A, X'_M \in \mathbb{R}^{1 \times C'}$, where $C' = C \times 1 \times 1$. $X'_A$ and $X'_M$ are then passed through a multi-layer perceptron (MLP) and reshaped to obtain $X''_A, X''_M \in \mathbb{R}^{C \times 1 \times 1}$, which are summed to obtain the channel attention weight $A_C \in \mathbb{R}^{C \times 1 \times 1}$:

$$A_C = \mathrm{MLP}(\mathrm{AvgPooling}(X)) + \mathrm{MLP}(\mathrm{MaxPooling}(X)) \tag{7}$$

Finally, the sigmoid of the attention weight is multiplied with the module input X to obtain the channel attention feature $X_C \in \mathbb{R}^{C \times H \times W}$:

$$X_C = \mathrm{Sigmoid}(A_C) \otimes X \tag{8}$$
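A minimal NumPy sketch of Eqs. (7) and (8) for a single feature map (batch dimension omitted). The paper does not give the MLP's hidden size or activation; this sketch assumes a shared, bias-free two-layer MLP with a ReLU hidden layer and reduction ratio r, as in CBAM-style channel attention, with hypothetical weights `w1` and `w2`.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def shared_mlp(v, w1, w2):
    """FC (C -> C/r) + ReLU + FC (C/r -> C), shared by both pooled vectors."""
    return w2 @ np.maximum(w1 @ v, 0.0)

def channel_attention(x, w1, w2):
    """x: (C, H, W); w1: (C//r, C); w2: (C, C//r). Returns X_C of shape (C, H, W)."""
    avg = x.mean(axis=(1, 2))                                # AvgPooling(X) -> (C,)
    mx = x.max(axis=(1, 2))                                  # MaxPooling(X) -> (C,)
    a_c = shared_mlp(avg, w1, w2) + shared_mlp(mx, w1, w2)   # A_C, Eq. (7)
    return sigmoid(a_c)[:, None, None] * x                   # X_C, Eq. (8)
```

With zero MLP weights every channel is scaled by Sigmoid(0) = 0.5; learned weights make the scaling channel-specific, which is the point of the module.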

Spatial-Attention. When template features are correlated with search features, the receptive field is limited because the template feature size is fixed, so the feature computed at each spatial location of the search image only carries information about a local region. This ignores the fact that understanding the global context of the entire image is also essential. The input of this module is the output of the channel attention module, i.e., the two modules are connected in series. Denote the input features as $X \in \mathbb{R}^{C \times H \times W}$. To reduce the number of channels while keeping the spatial size unchanged, X is first passed through an average pooling layer and a max pooling layer along the channel dimension to obtain $X_A, X_M \in \mathbb{R}^{1 \times H \times W}$. The two are then concatenated along the channel dimension to obtain $X_{A+M} \in \mathbb{R}^{2 \times H \times W}$. After that, a convolution layer with a 7 × 7 kernel, which also compresses the channels, is applied to $X_{A+M}$ to generate the spatial attention weight map $A_S \in \mathbb{R}^{1 \times H \times W}$:

$$A_S = \mathrm{conv2d}(\mathrm{concat}(\mathrm{AvgPooling}(X), \mathrm{MaxPooling}(X))) \tag{9}$$

Finally, the spatial attention weight map and the input feature X are multiplied to generate the spatial attention feature $X_S \in \mathbb{R}^{C \times H \times W}$.
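Eq. (9) and the final multiplication can be sketched as follows for one feature map. Eq. (9) as written applies no sigmoid to A_S, so none is applied here (CBAM-style modules usually add one); the 7 × 7 kernel `k` is a hypothetical learned weight, and zero "same" padding is assumed so that the spatial size is preserved.

```python
import numpy as np

def conv2d_same(x, k):
    """Cross-correlate x (Cin, H, W) with one kernel k (Cin, kh, kw) using
    zero padding so the output keeps the spatial size (H, W)."""
    _, kh, kw = k.shape
    H, W = x.shape[1:]
    p = np.pad(x, ((0, 0), (kh // 2, kh // 2), (kw // 2, kw // 2)))
    out = np.zeros((H, W))
    for i in range(kh):
        for j in range(kw):
            out += np.einsum("c,chw->hw", k[:, i, j], p[:, i:i + H, j:j + W])
    return out

def spatial_attention(x, k):
    """x: (C, H, W) channel-attention output; k: (2, 7, 7) kernel. Returns X_S."""
    pooled = np.stack([x.mean(axis=0), x.max(axis=0)])   # X_{A+M}: (2, H, W)
    a_s = conv2d_same(pooled, k)                         # A_S, Eq. (9): (H, W)
    return a_s[None, :, :] * x                           # X_S: (C, H, W)
```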

Fig. 4. The structure of the Enhanced Multi-Attention Module

Self-attention. Distracting information such as occlusion and similar objects often appears simultaneously during tracking, so a single-branch feature enhancement module tends to lose some valid feature information. To compensate, we use a self-attention module as a separate parallel branch that produces complete feature information. The self-attention module has two branches, both taking $X \in \mathbb{R}^{C \times H \times W}$ as input. One branch reshapes X into $Q \in \mathbb{R}^{1 \times C \times N}$, where N = H × W. The other branch uses a convolutional layer with a 1 × 1 kernel followed by a reshape to generate $K \in \mathbb{R}^{1 \times N \times 1}$. Q is then multiplied by the softmax of K to obtain the self-attention matrix $A_K \in \mathbb{R}^{1 \times C \times 1}$:

$$A_K = Q\,\mathrm{softmax}(K) \tag{10}$$

We reshape $A_K$ to C × 1 × 1 and add it element-wise to the input feature X (broadcasting over spatial positions) to obtain the self-attention feature $X_{SF} \in \mathbb{R}^{C \times H \times W}$.
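Eq. (10) and the broadcast addition can be sketched as follows (batch dimension omitted). The 1 × 1 convolution that produces K is modeled as a bias-free weight vector `w_k` mapping C channels to one, which is an assumption; the structure resembles global-context pooling, where a softmax over spatial positions weights a per-channel sum.

```python
import numpy as np

def softmax(v):
    e = np.exp(v - v.max())   # subtract max for numerical stability
    return e / e.sum()

def self_attention(x, w_k):
    """x: (C, H, W); w_k: (C,) weights of a hypothetical 1x1 conv (C -> 1).
    Returns X_SF = X + A_K broadcast over all spatial positions."""
    C, H, W = x.shape
    q = x.reshape(C, H * W)                  # Q: (C, N), N = H * W
    k = w_k @ q                              # K: (N,), one score per position
    a_k = q @ softmax(k)                     # A_K = Q softmax(K): (C,)
    return x + a_k[:, None, None]            # reshape to C x 1 x 1 and add
```

With all-zero `w_k` the softmax is uniform, so A_K reduces to the per-channel spatial mean; learned weights let the module emphasize the positions that matter for the target.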


We concatenate the spatial attention feature $X_S$ and the self-attention feature $X_{SF}$ along the channel dimension to obtain the combined feature $X_{SFS} \in \mathbb{R}^{2C \times H \times W}$. To keep the input and output dimensions of the Enhanced Multi-Attention module consistent while retaining the salient feature information, a 1 × 1 convolution layer that reduces the number of channels from 2C to C is applied to $X_{SFS}$. The result is then normalized and summed with the input feature X to obtain the output of the EMA module $X_{EMA} \in \mathbb{R}^{C \times H \times W}$:

$$X_{EMA} = \mathrm{Sigmoid}(\mathrm{BN}(X_{SFS})) + X \tag{11}$$

where $X_{SFS}$ here denotes the combined feature after the 1 × 1 convolution. The template features are processed by the EMA module in the same way to obtain $Z_{EMA} \in \mathbb{R}^{C \times h \times w}$.
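The fusion step of Eq. (11) can be sketched as below. For a single feature map, batch normalization is simplified here to per-channel standardization using the map's own statistics with no learned scale or shift, and the channel-reducing 1 × 1 convolution is a hypothetical weight matrix `w_reduce`; both are assumptions rather than details given in the paper.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def ema_fuse(x, x_s, x_sf, w_reduce, eps=1e-5):
    """x, x_s, x_sf: (C, H, W); w_reduce: (C, 2C). Returns X_EMA: (C, H, W)."""
    x_sfs = np.concatenate([x_s, x_sf], axis=0)            # X_SFS: (2C, H, W)
    y = np.einsum("oc,chw->ohw", w_reduce, x_sfs)          # 1x1 conv: 2C -> C
    mu = y.mean(axis=(1, 2), keepdims=True)
    var = y.var(axis=(1, 2), keepdims=True)
    y = (y - mu) / np.sqrt(var + eps)                      # simplified BN
    return sigmoid(y) + x                                  # Eq. (11)
```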

3.4 Training Details and Structure
Offline Training. SiamATU directly classifies and regresses the object bounding box at each location. The classification branch outputs foreground and background scores for each position, denoted $A_{cls} \in \mathbb{R}^{w \times h \times 2}$. The regression branch outputs the distances from each location to the four edges of the target bounding box, denoted $A_{reg} \in \mathbb{R}^{w \times h \times 4}$. The center-ness branch suppresses predicted boxes that are displaced far from the target center; its response map is denoted $A_{cen} \in \mathbb{R}^{w \times h \times 1}$. The classification branch is optimized with the cross-entropy loss and the regression branch with the IoU loss. Let $(x_0, y_0)$ and $(x_1, y_1)$ be the upper-left and lower-right corners of the ground-truth bounding box, and let position (i, j) of the response map correspond to position (x, y) in the original image. The 4-dimensional regression target at position (i, j) is

$$T^0_{(i,j)} = \tilde{l} = x - x_0,\quad T^1_{(i,j)} = \tilde{t} = y - y_0,\quad T^2_{(i,j)} = \tilde{r} = x_1 - x,\quad T^3_{(i,j)} = \tilde{b} = y_1 - y \tag{12}$$

From these targets the position of the prediction box can be computed, and the regression loss is

$$L_{reg} = \frac{1}{\sum_{i,j} \vartheta(T_{(i,j)})} \sum_{i,j} \vartheta(T_{(i,j)})\, L_{IOU}\left(A_{reg}(i,j,:), T_{(x,y)}\right) \tag{13}$$

where $L_{IOU}$ is the IoU loss and $\vartheta(\cdot)$ is defined as

$$\vartheta(T_{(i,j)}) = \begin{cases} 1, & \text{if } T^k_{(i,j)} > 0,\ k = 0, 1, 2, 3 \\ 0, & \text{otherwise} \end{cases} \tag{14}$$
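Eqs. (12)–(14) can be sketched for a handful of positions as follows. The paper only says "IOU loss"; this sketch assumes the −log(IoU) form introduced by UnitBox (1 − IoU is another common choice), and the helper names are illustrative.

```python
import numpy as np

def regression_targets(x, y, box):
    """Eq. (12): (l, t, r, b) distances from image position (x, y) to the edges
    of the ground-truth box (x0, y0, x1, y1)."""
    x0, y0, x1, y1 = box
    return np.array([x - x0, y - y0, x1 - x, y1 - y], dtype=float)

def inside(t):
    """Eq. (14): 1.0 if all four distances are positive, else 0.0."""
    return float(np.all(np.asarray(t) > 0))

def iou_loss(pred, target):
    """-log(IoU) between two (l, t, r, b) boxes sharing the same anchor point."""
    pl, pt, pr, pb = pred
    tl, tt, tr, tb = target
    area_p = (pl + pr) * (pt + pb)
    area_t = (tl + tr) * (tt + tb)
    inter = (min(pl, tl) + min(pr, tr)) * (min(pt, tt) + min(pb, tb))
    return float(-np.log(inter / (area_p + area_t - inter)))

def regression_loss(preds, targets):
    """Eq. (13): IoU loss averaged over positions that fall inside the box."""
    mask = [inside(t) for t in targets]
    total = sum(iou_loss(p, t) for p, t, m in zip(preds, targets, mask) if m)
    return total / max(sum(mask), 1.0)
```

Positions outside the ground-truth box are excluded before the logarithm is taken, so the loss never sees a non-positive overlap.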


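Each element C(i, j) of the center-ness map $A_{cen} \in \mathbb{R}^{w \times h \times 1}$ is defined as

$$C(i,j) = \vartheta(T_{(i,j)}) \ast \sqrt{\frac{\min(\tilde{l}, \tilde{r})}{\max(\tilde{l}, \tilde{r})} \times \frac{\min(\tilde{t}, \tilde{b})}{\max(\tilde{t}, \tilde{b})}} \tag{15}$$

The center-ness loss function is defined as

$$L_{CEN} = \frac{-1}{\sum \vartheta(T_{(i,j)})} \sum_{\vartheta(T_{(i,j)}) = 1} \left[ C(i,j) \ast \log A_{cen}(i,j) + (1 - C(i,j)) \ast \log\left(1 - A_{cen}(i,j)\right) \right] \tag{16}$$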
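Eqs. (15) and (16) can be sketched for lists of positions as below. The names are illustrative, and clipping the predictions away from 0 and 1 to keep the logarithms finite is an implementation assumption.

```python
import numpy as np

def centerness(t):
    """Eq. (15): sqrt(min(l,r)/max(l,r) * min(t,b)/max(t,b)); 0 outside the box."""
    left, top, right, bottom = t
    if not np.all(np.asarray(t) > 0):      # vartheta(T) = 0
        return 0.0
    return float(np.sqrt(min(left, right) / max(left, right)
                         * min(top, bottom) / max(top, bottom)))

def centerness_loss(c_targets, a_cen, eps=1e-7):
    """Eq. (16): binary cross-entropy between target C(i, j) and prediction
    A_cen(i, j), averaged over inside positions (C > 0 exactly when vartheta = 1)."""
    c = np.asarray(c_targets, dtype=float)
    a = np.clip(np.asarray(a_cen, dtype=float), eps, 1 - eps)
    m = c > 0
    bce = -(c * np.log(a) + (1 - c) * np.log(1 - a))
    return float(bce[m].sum() / max(m.sum(), 1))
```

A position at the exact center of the box gets C = 1, and C falls towards 0 as the position approaches an edge, which is what lets the branch down-weight off-center predictions.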

The overall loss function of SiamATU is defined as

$$L_{total} = L_{cls} + \mu_1 L_{cen} + \mu_2 L_{reg} \tag{17}$$

where $L_{cls}$ is the classification loss and $\mu_1$ and $\mu_2$ are the weights of the center-ness loss and the regression loss, respectively. During model training, we set $\mu_1 = 1$ and $\mu_2 = 3$.
Online Training. After offline training, the trained model parameters are loaded and the template update module is unfrozen for further online training, which aims to make the updated template features more effective. The training details are shown in Fig. 3. Online training minimizes the Euclidean distance between the predicted template and the ground-truth template of the next frame:

$$L_2 = \left\| \phi\left(T_0^{GT}, \tilde{Z}_{i-1}, Z_i\right) - T_{i+1}^{GT} \right\|_2 \tag{18}$$

A multi-stage strategy is used to train the template update module. The first stage uses the standard linear update to generate cumulative templates and actual predicted positions for every frame of the training dataset:

$$\tilde{Z}^0_i = (1 - \gamma)\tilde{Z}^0_{i-1} + \gamma Z^0_i \tag{19}$$

where γ is set to 0.0102. The IoU between each generated prediction box and the ground-truth box is used to determine whether the target has been lost, so that training does not run through many lost frames. In the latter two stages, the best-performing model from the previous stage is used to generate the cumulative templates and object position predictions. The online training dataset consists of 20 of the 280 sequences in the LaSOT dataset. Each stage comprises two training sessions, which use different learning rates.
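The template-update pieces of this subsection can be sketched in NumPy: the first-stage linear update of Eq. (19) with γ = 0.0102, the learned update ϕ of Eq. (6), and the loss of Eq. (18). The internals of ϕ are not specified in this paper; following the original UpdateNet design, the sketch assumes the three templates are concatenated and passed through two 1 × 1 convolutions with a ReLU in between, plus a skip connection from $T_0^{GT}$, with hypothetical weights `w1` and `w2`. Seeding the running average with the first frame's template is also an assumption.

```python
import numpy as np

GAMMA = 0.0102  # first-stage update rate given in the text

def linear_update(z_acc_prev, z_cur, gamma=GAMMA):
    """Eq. (19): Z~_i = (1 - gamma) * Z~_{i-1} + gamma * Z_i."""
    return (1.0 - gamma) * z_acc_prev + gamma * z_cur

def accumulate(templates, gamma=GAMMA):
    """Stage-1 cumulative templates for one sequence, seeded with the first frame."""
    acc = [templates[0]]
    for z in templates[1:]:
        acc.append(linear_update(acc[-1], z, gamma))
    return acc

def update_net(t0_gt, z_acc_prev, z_cur, w1, w2):
    """phi in Eq. (6). Templates are (C, h, w); w1: (D, 3C); w2: (C, D).
    Two 1x1 convolutions with a ReLU, plus a skip connection from T0_GT."""
    stacked = np.concatenate([t0_gt, z_acc_prev, z_cur], axis=0)      # (3C, h, w)
    hidden = np.maximum(np.einsum("dc,chw->dhw", w1, stacked), 0.0)   # 1x1 conv + ReLU
    return np.einsum("cd,dhw->chw", w2, hidden) + t0_gt               # 1x1 conv + skip

def template_l2_loss(pred_template, gt_next_template):
    """Eq. (18): Euclidean distance to the ground-truth template of frame i + 1."""
    return float(np.linalg.norm(np.ravel(pred_template - gt_next_template)))
```

Because γ is small, the accumulated template changes slowly: starting from a zero template and updating n times towards a constant new template leaves it at a fraction 1 − (1 − γ)^n of that template.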


4 Experiments
4.1 Implementation Details
The proposed SiamATU is implemented with PyTorch 1.11 and CUDA 11.1 on an NVIDIA RTX 3090 GPU. ResNet-50, pre-trained on ImageNet [13], is used as the backbone network. During training, the batch size is set to 16, and the model is trained for 20 epochs with stochastic gradient descent (SGD) and an initial learning rate of 0.001. The backbone is frozen for the first 10 epochs; for the last 10 epochs, its last three layers are unfrozen and trained together with the rest of the model.
4.2 Results on GOT-10K
To assess the generality of SiamATU, we tested it on the GOT-10K [6] dataset and compared it with state-of-the-art trackers. As shown in Fig. 5, SiamATU performs best among the 20 trackers. Compared with SiamCAR [2], SiamATU improves AO, SR0.5 and SR0.75 by 3%, 3.9% and 3.3%, respectively. Fig. 6 visualizes the tracking results of SiamATU and four well-known trackers (SiamDW [14], SiamFC [1], SiamCAR [2] and SiamRPN++ [15]) on 12 video sequences from GOT-10K. In the presence of similar objects, fast motion, scale change and full occlusion, SiamATU produces more accurate bounding boxes that fit the target more closely. This is mainly due to the template update and feature enhancement modules of SiamATU. Even when the tracked object deforms, the template update module still predicts accurate template features, which enhances the robustness of the tracker. Because the feature enhancement module acts on the template and search features separately before the Siamese cross-correlation, SiamATU locates the salient information of the target more accurately. Since the evaluation follows the GOT-10K protocol, in which the ground-truth boxes of the test set are hidden from both the tracker and the authors, the GOT-10K results are more reliable than those on other datasets.

Fig. 5. Comparisons of SiamATU and other trackers on GOT-10K


4.3 Results on Other Datasets
On UAV123 [8], SiamATU surpasses the baseline by 1.1% in success and 1.6% in precision. On OTB100 [9], SiamATU achieves a precision of 0.920 and a success rate of 0.707. It achieves leading performance on each of these datasets.

4.4 Ablation Experiment
We investigated the influence of the individual components of SiamATU through ablation studies on the DTB70 [7] and GOT-10K [6] datasets, using SiamCAR [2] as the baseline model. As shown in Table 1, adding the EMA module to the baseline improves success and precision to 0.620 and 0.809. Adding the ATU module alone improves success and precision by 2.2% and 1.5%. With both modules, our model's success and precision increase by 4.7% and 6.1%, respectively, over the baseline. All improvements reported in Tables 1 and 2 are absolute differences from the baseline (percentage points). SR0.5 and SR0.75 denote the percentage of frames whose overlap exceeds 50% and 75%, respectively. Table 2 details the ablation study for both modules on GOT-10K [6]. The ATU module outperforms the EMA module on SR0.5 and AO, while its SR0.75 is slightly lower than that of the EMA module. This indicates that the ATU module improves the stability of the tracker by generating adaptive template features, which reduces the likelihood of predicting low-quality bounding boxes during tracking; however, it is less effective than the EMA module at predicting high-quality bounding boxes. The EMA module is applied before the cross-correlation of the two branch features to enhance key features and suppress secondary ones. Fig. 7 shows that the tracker combining both modules locates the target better: its response area matches the target's appearance more closely, and it suppresses background clutter and interference from similar objects better than the single-module variants. This suggests that the EMA and ATU modules improve the tracker in complementary ways, and their combination yields the best results.
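The SR and AO definitions above can be made concrete with a small helper. This is a minimal sketch, assuming per-frame IoU values for a single sequence are already available and that "exceeding" means a strict inequality; the official GOT-10k evaluation also averages per sequence before averaging across sequences, which this sketch does not do. The function name `got10k_metrics` is hypothetical.

```python
from statistics import mean


def got10k_metrics(ious: list[float]) -> dict[str, float]:
    """AO, SR0.5 and SR0.75 for one sequence of per-frame IoU values."""
    if not ious:
        raise ValueError("need at least one frame")
    n = len(ious)
    return {
        "AO": mean(ious),                               # average overlap
        "SR0.5": sum(iou > 0.5 for iou in ious) / n,    # share of frames with IoU > 0.5
        "SR0.75": sum(iou > 0.75 for iou in ious) / n,  # share of frames with IoU > 0.75
    }


# Hypothetical per-frame IoUs for a five-frame sequence.
print(got10k_metrics([0.9, 0.8, 0.6, 0.4, 0.3]))
```

Because AO averages every overlap while SR0.75 only counts high-overlap frames, a module can raise AO and SR0.5 yet slightly lower SR0.75, which is the pattern Table 2 shows for ATU versus EMA.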

5 Conclusion
In this paper, we proposed SiamATU, an adaptive online template update tracking framework that contains an adaptive template update module (ATU) and a feature enhancement module (EMA). The ATU module fuses information from the initial, historical, and current frames to obtain new adaptive template features, which strengthens the salient features of the target under illumination changes, scale changes, object deformation, and low resolution. The EMA module integrates three attention mechanisms that jointly weight the three dimensions of the feature information, highlighting critical features and weakening secondary ones. Because the EMA module acts before the template features and search features are cross-correlated, it creates a closer contextual relationship between the template and the search region. By generating more accurate similarity maps for locating targets, these modules improve the accuracy and robustness of the tracker. SiamATU is trained and optimized with a multi-stage, step-by-step strategy. Experimental results show that SiamATU achieves competitive performance on five mainstream tracking benchmarks while meeting real-time requirements.

Siamese Adaptive Template Update Network for Visual Tracking

Fig. 6. Tracking visualization of SiamATU and four well-known trackers on GOT-10K

Table 1. Ablation study on DTB70.
Setting                        Success  Precision  ΔSuccess  ΔPrecision
Baseline                       0.595    0.781      -         -
Baseline + EMA                 0.620    0.809      +2.5%     +2.8%
Baseline + ATU                 0.617    0.796      +2.2%     +1.5%
Baseline + EMA + ATU (Ours)    0.642    0.842      +4.7%     +6.1%

Table 2. Ablation study on GOT-10K.
Setting                        AO     SR0.5  SR0.75  ΔAO
Baseline                       0.581  0.685  0.444   -
Baseline + EMA                 0.592  0.696  0.458   +1.1%
Baseline + ATU                 0.596  0.704  0.456   +1.5%
Baseline + EMA + ATU (Ours)    0.611  0.724  0.477   +3.0%
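The improvement columns of Tables 1 and 2 are absolute differences from the baseline (percentage points), not relative gains. The following sketch recomputes them from the table values; the dictionaries simply transcribe the two tables, and `pp_gain` and `rel_gain` are hypothetical helper names.

```python
# Values transcribed from Table 1 (DTB70) and Table 2 (GOT-10K).
TABLE1 = {  # setting -> (success, precision)
    "Baseline": (0.595, 0.781),
    "Baseline + EMA": (0.620, 0.809),
    "Baseline + ATU": (0.617, 0.796),
    "Baseline + EMA + ATU (Ours)": (0.642, 0.842),
}
TABLE2_AO = {  # setting -> AO
    "Baseline": 0.581,
    "Baseline + EMA": 0.592,
    "Baseline + ATU": 0.596,
    "Baseline + EMA + ATU (Ours)": 0.611,
}


def pp_gain(value: float, base: float) -> float:
    """Absolute gain in percentage points, rounded to one decimal place."""
    return round((value - base) * 100, 1)


def rel_gain(value: float, base: float) -> float:
    """Relative gain in percent, shown for contrast; not what the tables report."""
    return round((value - base) / base * 100, 1)


base_s, base_p = TABLE1["Baseline"]
for name, (s, p) in TABLE1.items():
    print(f"{name}: success +{pp_gain(s, base_s)} pp, precision +{pp_gain(p, base_p)} pp")
for name, ao in TABLE2_AO.items():
    print(f"{name}: AO +{pp_gain(ao, TABLE2_AO['Baseline'])} pp")
```

For example, the EMA row's success gain is 2.5 percentage points, which would be about 4.2% if expressed relative to the baseline.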


Fig. 7. Visualization of heat maps with proposed ATU module (the second row), with EMA module (the third row), with ATU + EMA module (the fourth row) on 4 sequences from DTB70.

Acknowledgements. This work was supported by the National Natural Science Foundation of China (Grant No. 62273293), the Shandong Provincial Natural Science Foundation, China (Grant No. ZR2022LZH002), and the Innovation Capability Improvement Plan Project of Hebei Province (No. 22567626H).

References
1. Bertinetto, L., Valmadre, J., Henriques, J.F., Vedaldi, A., Torr, P.H.S.: Fully-convolutional Siamese networks for object tracking. In: Hua, G., Jégou, H. (eds.) Computer Vision – ECCV 2016 Workshops: Amsterdam, The Netherlands, October 8–10 and 15–16, 2016, Proceedings, Part II, pp. 850–865. Springer International Publishing, Cham (2016). https://doi.org/10.1007/978-3-319-48881-3_56
2. Guo, D., Wang, J., Cui, Y., Wang, Z., Chen, S.: SiamCAR: Siamese fully convolutional classification and regression for visual tracking. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 6269–6277 (2020)
3. Xu, Y., Wang, Z., Li, Z., Yuan, Y., Yu, G.: SiamFC++: towards robust and accurate visual tracking with target estimation guidelines. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, pp. 12549–12556 (2020)
4. Zhang, Z., Peng, H., Fu, J., Li, B., Hu, W.: Ocean: object-aware anchor-free tracking. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.M. (eds.) Computer Vision – ECCV 2020. LNCS, vol. 12366, pp. 771–787. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58589-1_46
5. Zhang, L., Gonzalez-Garcia, A., Weijer, J.V.D., Danelljan, M., Khan, F.S.: Learning the model update for Siamese trackers. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4010–4019 (2019)
6. Huang, L., Zhao, X., Huang, K.: GOT-10k: a large high-diversity benchmark for generic object tracking in the wild. IEEE Trans. Pattern Anal. Mach. Intell. 43(5), 1562–1577 (2019)


7. Li, S., Yeung, D.-Y.: Visual object tracking for unmanned aerial vehicles: a benchmark and new motion models. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 31 (2017)
8. Mueller, M., Smith, N., Ghanem, B.: A benchmark and simulator for UAV tracking. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9905, pp. 445–461. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46448-0_27
9. Wu, Y., Lim, J., Yang, M.-H.: Object tracking benchmark. IEEE Trans. Pattern Anal. Mach. Intell. 37(9), 1834–1848 (2015)
10. Danelljan, M., Hager, G., Shahbaz Khan, F., Felsberg, M.: Adaptive decontamination of the training set: a unified formulation for discriminative visual tracking. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1430–1438 (2016)
11. Hu, J., Shen, L., Sun, G.: Squeeze-and-excitation networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7132–7141 (2018)
12. Wang, Q., Teng, Z., Xing, J., Gao, J., Hu, W., Maybank, S.: Learning attentions: residual attentional Siamese network for high performance online visual tracking. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4854–4863 (2018)
13. Deng, J., Dong, W., Socher, R., Li, L.-J., Li, K., Fei-Fei, L.: ImageNet: a large-scale hierarchical image database. In: 2009 IEEE Conference on Computer Vision and Pattern Recognition, pp. 248–255. IEEE (2009)
14. Zhang, Z., Peng, H.: Deeper and wider Siamese networks for real-time visual tracking. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4591–4600 (2019)
15. Li, B., Wu, W., Wang, Q., Zhang, F., Xing, J., Yan, J.: SiamRPN++: evolution of Siamese visual tracking with very deep networks. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4282–4291 (2019)

Collaborative Encoder for Accurate Inversion of Real Face Image
YaTe Liu1,2, ChunHou Zheng1, Jun Zhang1, Bing Wang3, and Peng Chen1(B)
1 National Engineering Research Center for Agro-Ecological Big Data Analysis & Application, Information Materials and Intelligent Sensing Laboratory of Anhui Province, School of Internet & Institutes of Physical Science and Information Technology, Anhui University, Hefei 230601, Anhui, China
2 School of Computer Science and Technology, Anhui University, Hefei 230601, Anhui, China
3 School of Electrical and Information Engineering, Anhui University of Technology, Ma'anshan 243032, Anhui, China

Abstract. In recent times, there has been notable progress in the effectiveness of Generative Adversarial Networks (GANs) for synthesizing images. Consequently, numerous studies have started utilizing GANs for image editing purposes. To enable editing of real images, it is crucial to embed a real image into the latent space of GANs. This involves obtaining the latent code of the real image and subsequently modifying the image by altering the latent code. However, accurately reconstructing the real image using the obtained latent code remains a challenge. This paper introduces a novel inversion scheme that achieves high accuracy. In contrast to conventional approaches that employ a single encoder for image inversion, our method utilizes collaborative encoders to accomplish the inversion task. Specifically, two encoders are employed to invert distinct regions in the image, namely the face and background regions. By distributing the inversion task between these encoders, the burden on a single encoder is reduced. Furthermore, to optimize efficiency, this paper adopts a lightweight network structure, resulting in faster inference speed. Experimental results demonstrate that our proposed method significantly enhances visual quality and improves the speed of inference. By leveraging collaborative encoders and a lightweight network structure, we achieve notable improvements in image inversion, thus enabling more effective image editing capabilities.
Keywords: Face Image Inversion · Collaborative Encoder · Generative Adversarial Network · Efficient

1 Introduction
In recent years, Generative Adversarial Networks (GANs) [1] have made significant progress in image synthesis, especially for face images. Existing methods can synthesize face images with diverse styles and high visual quality. The current state-of-the-art generative adversarial network is StyleGAN [2, 3], which introduces a network structure consisting of a Mapping Network and a Synthesis Network. Many studies [4] have shown that the latent space of StyleGAN exhibits attribute disentanglement, which makes it well suited to image editing, so a pre-trained StyleGAN model can be used for editing [5]. Editing operations [6] are usually applied to images generated by StyleGAN. To edit a real image, the image must first be inverted into the latent space of StyleGAN, that is, its corresponding latent code must be obtained; feeding this latent code into the pre-trained generator then reconstructs the real image. Only once the latent code of a real image has been obtained can its attributes be edited [4] by vector operations on the latent code. A high-quality inversion scheme is therefore crucial for such editing techniques.

In this paper, we propose a collaborative encoder architecture that embeds real images into the W+ space. Based on its semantic information, a real image can be divided into a face region and a background region. Face regions share common characteristics across images, such as contours and facial features, whereas background regions differ enormously. Unlike previous methods, we construct two encoders that invert the face and background regions independently and cooperate to complete the final inversion. Experiments show that our encoder not only substantially improves reconstruction quality over previous work but also achieves faster inference because it is lightweight.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 498–505, 2023. https://doi.org/10.1007/978-981-99-4742-3_41
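Editing by "vector operations on the latent code" usually means moving the inverted code along a semantic direction, w' = w + αn. The sketch below illustrates this under stated assumptions: the direction `n` is a made-up vector (in practice it would come from a method such as those in [4, 5]), the per-layer dimension is shrunk from 512 to 4 for readability, and `edit_latent` is a hypothetical helper, not an API of any StyleGAN codebase.

```python
from typing import Optional

# Toy W+ latent: 18 style vectors, one per StyleGAN layer at 1024x1024.
# Real W+ codes are 18 x 512; DIM is reduced here purely for readability.
NUM_LAYERS, DIM = 18, 4

Latent = list[list[float]]


def edit_latent(w: Latent, direction: list[float], alpha: float,
                layers: Optional[range] = None) -> Latent:
    """Return w + alpha * direction on the selected layers (all layers by default)."""
    targets = range(len(w)) if layers is None else layers
    edited = [list(row) for row in w]  # copy so the inverted code stays intact
    for i in targets:
        edited[i] = [x + alpha * d for x, d in zip(edited[i], direction)]
    return edited


w_inverted = [[0.0] * DIM for _ in range(NUM_LAYERS)]  # stand-in for an encoder output
n = [1.0, 0.0, 0.0, 0.0]                               # hypothetical attribute direction
w_edited = edit_latent(w_inverted, n, alpha=2.0, layers=range(0, 4))  # coarse layers only
print(w_edited[0], w_edited[10])
```

Restricting the edit to the first (coarse) layers is one way to target pose- or shape-level attributes while leaving fine details untouched; which layers to edit is a design choice that the text does not specify.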

2 Method
2.1 Motivations
Previous methods [7, 8, 16] use a single encoder to extract information from the complete image, which contains both facial features common to all faces and complex, diverse background features; one encoder therefore has to model two very different kinds of content. We instead use two independent sub-encoders to extract the features of the different regions: a face feature encoder and a background feature encoder. The two encoders work together to complete the inversion task, and the latent codes they produce are input into the generator to reconstruct the image accurately.
2.2 Encoder Architecture
Given a real image, our goal is to embed it well into the StyleGAN latent space, that is, to invert it into a latent code. Following previous studies, we choose W+ as the embedding latent space because W+ can hold more facial information and thus enables more accurate reconstruction. Figure 1 shows the structure of the face encoder. Convolutional layers and CBAM modules [10] are stacked to extract facial features from the face image. During encoding, the shallow, middle, and deep feature maps are stored and then fused by a feature pyramid structure [11]: the deep feature map is upsampled and added to the middle-layer feature map, and the middle-layer feature map, after convolution and upsampling, is added to the shallow-layer feature map. The three fused feature maps correspond to the style levels of StyleGAN. Finally, we design a module that converts feature maps into latent codes, combining convolutional layers, pooling layers, CBAM modules, and fully connected layers. It converts the deep feature map into coarse latent codes (w1, w2, w3, w4), the middle feature map into intermediate latent codes (w5, w6, w7, w8), and the shallow feature map into detail latent codes (w9, w10, ..., w17, w18).
Figure 1 also shows the structure of our background encoder, which encodes the background region. Because the background occupies only a small proportion of the whole image, a lightweight convolutional network with fewer convolutional layers and CBAM modules is sufficient to extract its information.
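The top-down fusion and the split of the 18 W+ style vectors across pyramid levels can be sketched as follows. This is a simplified, hypothetical illustration: feature maps are single-channel 2D lists, the convolutions the text applies during fusion and the map-to-latent module are omitted, nearest-neighbour 2x upsampling is assumed, and, as in a standard feature pyramid network, the already-fused middle map is the one propagated to the shallow level.

```python
Grid = list[list[float]]


def upsample2x(fmap: Grid) -> Grid:
    """Nearest-neighbour 2x upsampling: each value becomes a 2x2 block."""
    out: Grid = []
    for row in fmap:
        wide = [v for v in row for _ in range(2)]
        out.append(wide)
        out.append(list(wide))
    return out


def add(a: Grid, b: Grid) -> Grid:
    """Element-wise sum of two equally sized maps."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def fuse(shallow: Grid, middle: Grid, deep: Grid) -> tuple[Grid, Grid, Grid]:
    """Top-down fusion deep -> middle -> shallow (convolutions omitted)."""
    middle_fused = add(middle, upsample2x(deep))
    shallow_fused = add(shallow, upsample2x(middle_fused))
    return shallow_fused, middle_fused, deep


# 1-based W+ indices decoded from each fused map, as listed in the text.
STYLE_LEVELS = {
    "deep -> coarse": range(1, 5),          # w1..w4
    "middle -> intermediate": range(5, 9),  # w5..w8
    "shallow -> detail": range(9, 19),      # w9..w18
}

# Toy 4x4 / 2x2 / 1x1 pyramid.
shallow = [[0.0] * 4 for _ in range(4)]
middle = [[0.0] * 2 for _ in range(2)]
deep = [[1.0]]
s_fused, m_fused, d_fused = fuse(shallow, middle, deep)
print(m_fused)
print(sum(len(r) for r in STYLE_LEVELS.values()))
```

The deep map's single activation reaches every position of the shallow map, which is the point of the top-down pathway: fine-level codes still see coarse context.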

Fig. 1. Face encoder and Background encoder

2.3 Losses
First, we compute the perceptual loss L_lpips [12] between the real and synthesized images. L_lpips uses a neural network to extract deep features of both images and takes the difference between these features as the loss value, which agrees better with human perception:

$$\mathcal{L}_{lpips} = \left\| F(X) - F\big(G(E_{\mathrm{face}}(X_{\mathrm{face}}), E_{\mathrm{background}}(X_{\mathrm{background}}))\big) \right\|_2 \tag{1}$$

where X is the real image, X_face and X_background are its face and background regions, E_face and E_background are the face and background encoders, G is the pre-trained StyleGAN generator, and F(·) denotes the network used to extract deep features. Besides deep features, we also use a pixel-level L2 loss to limit the difference between the input and output images:

$$\mathcal{L}_{2} = \left\| X - G(E_{\mathrm{face}}(X_{\mathrm{face}}), E_{\mathrm{background}}(X_{\mathrm{background}})) \right\|_2 \tag{2}$$
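Equations (1) and (2) take the face region X_face and the background region X_background as separate encoder inputs. The text does not specify how these inputs are formed from the segmentation output; one plausible choice, assumed here purely for illustration, is to multiply the image by a soft matte M in [0, 1] and by its complement 1 − M, so that the two inputs sum back to the original image. `split_by_matte` is a hypothetical name, and a single channel is used for brevity (an RGB image would apply the same matte to each channel).

```python
Grid = list[list[float]]


def split_by_matte(image: Grid, matte: Grid) -> tuple[Grid, Grid]:
    """Return (X_face, X_background) = (M * X, (1 - M) * X), element-wise."""
    face = [[m * x for x, m in zip(row, mrow)] for row, mrow in zip(image, matte)]
    background = [[(1.0 - m) * x for x, m in zip(row, mrow)]
                  for row, mrow in zip(image, matte)]
    return face, background


image = [[0.8, 0.2], [0.5, 1.0]]   # toy single-channel image
matte = [[1.0, 0.0], [0.25, 1.0]]  # 1 = face, 0 = background, fractional at edges
x_face, x_bg = split_by_matte(image, matte)
print(x_face, x_bg)
```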


Finally, we use a face recognition network [13] to compute a face similarity loss L_id, so that the latent code accurately expresses the facial features of the real image:

$$\mathcal{L}_{id} = \left\| \mathrm{ID}(X) - \mathrm{ID}\big(G(E_{\mathrm{face}}(X_{\mathrm{face}}), E_{\mathrm{background}}(X_{\mathrm{background}}))\big) \right\|_2 \tag{3}$$

where ID(·) denotes the face recognition network used to extract face features from an image. In summary, the overall objective for our encoder is:

$$\mathcal{L} = \lambda_1 \mathcal{L}_{lpips} + \lambda_2 \mathcal{L}_{2} + \lambda_3 \mathcal{L}_{id} \tag{4}$$

where λ1, λ2, λ3 are weights that balance the loss terms. We set λ1 = 0.73, λ2 = 0.93, and λ3 = 0.46.
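To make Eq. (4) concrete, the following sketch evaluates the weighted objective with the reported weights. It is illustrative only: images are flattened into short lists, the deep-feature network F(·) and the face recognition network ID(·) are replaced by hypothetical stand-in functions, and `x_hat` stands for G(E_face(X_face), E_background(X_background)). Practical implementations typically use a pretrained LPIPS model and a face-recognition backbone, and many use a cosine-based identity loss rather than the L2 norm written in Eq. (3); this sketch follows the equations as written.

```python
import math
from typing import Callable

Vector = list[float]
LAMBDA_LPIPS, LAMBDA_L2, LAMBDA_ID = 0.73, 0.93, 0.46  # weights reported in the text


def l2_distance(a: Vector, b: Vector) -> float:
    """Euclidean norm ||a - b||_2."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def total_loss(x: Vector, x_hat: Vector,
               feat: Callable[[Vector], Vector],
               ident: Callable[[Vector], Vector]) -> float:
    """lambda1 * L_lpips + lambda2 * L_2 + lambda3 * L_id, following Eqs. (1)-(4)."""
    l_lpips = l2_distance(feat(x), feat(x_hat))  # Eq. (1)
    l_2 = l2_distance(x, x_hat)                  # Eq. (2)
    l_id = l2_distance(ident(x), ident(x_hat))   # Eq. (3)
    return LAMBDA_LPIPS * l_lpips + LAMBDA_L2 * l_2 + LAMBDA_ID * l_id


def toy_features(v: Vector) -> Vector:
    """Hypothetical stand-in for F(.): scales every value by 2."""
    return [2.0 * t for t in v]


def toy_identity(v: Vector) -> Vector:
    """Hypothetical stand-in for ID(.): a one-dimensional 'embedding'."""
    return [sum(v)]


print(total_loss([0.0, 0.0], [3.0, 4.0], toy_features, toy_identity))
```

With these stand-ins the three terms are 10, 5, and 7, so the weighted sum is 0.73·10 + 0.93·5 + 0.46·7 = 15.17.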

3 Experiment and Results
In this section, we describe the experimental setup in detail and compare our method with existing methods, demonstrating its high reconstruction accuracy and visual quality. We also provide detailed ablation experiments to demonstrate the effectiveness of two lightweight encoders collaborating on inversion.
3.1 Dataset
We conduct face inversion experiments on the high-quality face datasets FFHQ [2] and CelebA-HQ [14]. FFHQ, which contains 70,000 high-quality face images, is used as the training set, and the CelebA-HQ test set is used for quantitative evaluation. We use a segmentation network [9] to divide each image into face and background regions. CelebA-HQ contains 30,000 high-quality face images, and its test images undergo the same segmentation process as the training set.
3.2 Metrics
We evaluate our method against other encoder-based methods on five criteria:
1. Perceptual similarity, measured by LPIPS [12].
2. The pixel-level L2 distance between the real and synthesized images.
3. Face similarity between the real and synthesized images. To keep this metric independent of our loss function, we evaluate face similarity with CurricularFace [15].
4. The similarity between the feature distributions of real and synthesized images, measured by FID.
5. Runtime.
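FID compares Gaussian fits of real and generated feature distributions, FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}), where the features are normally Inception-v3 activations. The sketch below is a deliberately reduced one-dimensional version of that formula, in which the trace term collapses to σ_r² + σ_g² − 2σ_rσ_g. It only shows what the metric measures and does not reproduce the reported FID values; `fid_1d` is a hypothetical helper.

```python
import math
from statistics import mean, variance


def fid_1d(real: list[float], fake: list[float]) -> float:
    """Frechet distance between 1-D Gaussians fitted to two feature samples."""
    mu_r, mu_f = mean(real), mean(fake)
    var_r, var_f = variance(real), variance(fake)  # sample variances (need >= 2 values each)
    return (mu_r - mu_f) ** 2 + var_r + var_f - 2.0 * math.sqrt(var_r * var_f)


print(fid_1d([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # identical samples
print(fid_1d([0.0, 2.0], [10.0, 12.0]))          # same spread, shifted mean
```

Identical samples give a distance of zero, and a pure shift in the mean contributes its square, so lower FID means the synthesized features sit closer to the real distribution.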


3.3 Quantitative Evaluation
We quantitatively evaluate the quality of the reconstructed images; the results are shown in Table 1. LPIPS is computed between each input real image and its reconstruction. Our method achieves the lowest LPIPS, meaning that our reconstructions are perceptually closest to the real images. MSE is the mean squared error between corresponding pixels of the two images; our lower MSE means that, at the fine-grained pixel level, our reconstructions deviate less from the real images. The face similarity results show that the proposed inversion method better preserves the facial information of the real image, so the reconstructed images have higher face similarity with the inputs.

Table 1. Quantitative comparison of inversion quality on faces (↓ lower is better, ↑ higher is better).
Method                  LPIPS↓  MSE↓   Face Similarity↑  FID↓  Runtime↓
pSp [7]                 0.17    0.034  0.57              18.5  0.060 s
e4e [8]                 0.20    0.052  0.50              22.0  0.060 s
ReStyle-e4e [16]        0.19    0.041  0.52              17.7  0.179 s
ReStyle-pSp [16]        0.13    0.030  0.66              11.9  0.179 s
Style Transformer [18]  0.16    0.033  0.59              14.1  0.043 s
HFGI [17]               0.13    0.022  0.68              5.5   0.120 s
Ours                    0.11    0.021  0.71              11.9  0.054 s
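The comparison drawn from Table 1 can be checked mechanically. The sketch below transcribes the table and ranks the methods on each metric (lower is better except for face similarity). It shows that our method is best on LPIPS, MSE, and face similarity, while HFGI has the lowest FID and Style Transformer the lowest runtime. `ranking` is a hypothetical helper name.

```python
# Rows transcribed from Table 1: (LPIPS, MSE, face similarity, FID, runtime in s).
RESULTS = {
    "pSp": (0.17, 0.034, 0.57, 18.5, 0.060),
    "e4e": (0.20, 0.052, 0.50, 22.0, 0.060),
    "ReStyle-e4e": (0.19, 0.041, 0.52, 17.7, 0.179),
    "ReStyle-pSp": (0.13, 0.030, 0.66, 11.9, 0.179),
    "Style Transformer": (0.16, 0.033, 0.59, 14.1, 0.043),
    "HFGI": (0.13, 0.022, 0.68, 5.5, 0.120),
    "Ours": (0.11, 0.021, 0.71, 11.9, 0.054),
}
# (metric name, higher is better?) in column order.
METRICS = [("LPIPS", False), ("MSE", False), ("Face Similarity", True),
           ("FID", False), ("Runtime", False)]


def ranking(metric: str) -> list[str]:
    """Methods ordered best to worst on one metric; ties keep table order."""
    for col, (name, higher_better) in enumerate(METRICS):
        if name == metric:
            return sorted(RESULTS, key=lambda m: RESULTS[m][col], reverse=higher_better)
    raise KeyError(metric)


for name, _ in METRICS:
    order = ranking(name)
    print(f"{name}: best = {order[0]}, runner-up = {order[1]}")
```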

Our method achieves an FID of 11.9, which is higher (worse) than that of HFGI (5.5), equal to that of ReStyle-pSp, and lower than those of the remaining methods. This indicates that the images StyleGAN synthesizes from our latent codes follow the real data distribution more closely than most competing methods. We also measured inference time. Although the collaborative structure uses two encoders to decompose the inversion task, each encoder is relatively streamlined, so our runtime (0.054 s) is lower than that of every compared method except Style Transformer (0.043 s), while our reconstruction accuracy is higher.
3.4 Qualitative Evaluation
Fig. 2 shows the results of inverting and reconstructing real images with different methods. To show differences in detail more intuitively, the even rows of Fig. 2 are heat maps of the difference between the synthesized and real images: for each method, we compute the absolute per-pixel difference between its synthesized image and the real image and convert it into a heat map, so that larger absolute differences appear brighter and smaller ones darker. Compared with previous approaches, our method reconstructs backgrounds more robustly. For example, the input images in the first and third rows have complex background patterns. In the corresponding heat maps (second and fourth rows), the heat maps of our method are darker, indicating small absolute differences between corresponding pixels of the synthesized and real images; our method thus better preserves the background patterns and textures of the real image and makes the synthesized image more faithful. Our method also improves details in the face region. For example, the input image in the fifth row has a simpler background pattern, and in its heat map our method is darker in the hair and facial features, indicating that the face region of the synthesized image closely matches that of the real image.

Fig. 2. Qualitative results of image inversion. Our method is compared with pSp, e4e, ReStyle-pSp, ReStyle-e4e, Style Transformer and HFGI.
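The difference heat maps described above can be computed as follows. This is a minimal sketch, assuming 8-bit single-channel images stored as nested lists. The text does not say how the differences are scaled, so a fixed scale of 255 is assumed rather than per-image normalization, and the final mapping to a color map is omitted. `difference_heatmap` is a hypothetical name.

```python
Grid = list[list[float]]


def difference_heatmap(real: Grid, synth: Grid, max_value: float = 255.0) -> Grid:
    """Per-pixel |synth - real| scaled to [0, 1]; brighter means a larger error."""
    return [[min(abs(s - r) / max_value, 1.0) for r, s in zip(real_row, synth_row)]
            for real_row, synth_row in zip(real, synth)]


real = [[0, 128], [255, 64]]
synth = [[0, 0], [0, 64]]
print(difference_heatmap(real, synth))
```

A fixed scale keeps brightness comparable across methods; normalizing each heat map by its own maximum would make every method's worst pixel equally bright and hide exactly the differences the comparison relies on.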

4 Application
Once a real image has been embedded into the latent space of StyleGAN by the encoder, it can be used for various image processing tasks. In this section, we use the latent codes output by the encoder for image colorization and demonstrate the results on the CelebA-HQ [14] dataset.


To achieve image colorization, we use the FFHQ dataset as the training set and convert its images to grayscale. Each grayscale image is converted into a latent code by the collaborative encoder, a new image is synthesized with StyleGAN, and the original color image serves as the optimization target during training. At inference time, a grayscale image is converted into a latent code, which is fed into StyleGAN to synthesize the corresponding color image.
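The colorization training data pairs each grayscale input with its original color image as the target. The text does not state which grayscale conversion is used; the sketch below assumes the common ITU-R BT.601 luma weights (0.299, 0.587, 0.114) and replicates the gray value into three channels, on the further assumption that the encoder expects RGB input. `make_colorization_pair` is a hypothetical helper.

```python
Pixel = tuple[float, float, float]
Image = list[list[Pixel]]

LUMA_WEIGHTS = (0.299, 0.587, 0.114)  # ITU-R BT.601 (assumed; not given in the text)


def to_grayscale(image: Image) -> Image:
    """Replace each RGB pixel by its luma, replicated into three channels."""
    out: Image = []
    for row in image:
        gray_row: list[Pixel] = []
        for pixel in row:
            y = sum(w * c for w, c in zip(LUMA_WEIGHTS, pixel))
            gray_row.append((y, y, y))
        out.append(gray_row)
    return out


def make_colorization_pair(color_image: Image) -> tuple[Image, Image]:
    """(encoder input, optimization target) for one training example."""
    return to_grayscale(color_image), color_image


gray, target = make_colorization_pair([[(255.0, 0.0, 0.0), (255.0, 255.0, 255.0)]])
print(gray)
```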

Fig. 3. Image Colorization result

As shown in Fig. 3, our encoder can colorize a given grayscale image. In Fig. 3, the first column shows real color images sampled from the CelebA-HQ dataset, the second column shows the model input (the real image converted to grayscale), and the third column shows the colorized output of the model.

5 Conclusion
In this paper, we decompose the inversion task according to the semantic information of facial images and use two lightweight encoders that collaborate to accomplish it. Experimental results show that our encoder network improves image reconstruction accuracy while also reducing inference time, indicating the promise of this approach for image processing. We also demonstrate an image processing application on real images, image colorization, which highlights its practical value. We anticipate that our work will be useful in applications that demand fast and accurate reconstruction of real images and that it will contribute to the advancement of image processing technologies and their practical implementations.
Acknowledgement. This work was supported by the National Natural Science Foundation of China (Nos. 62072002, 62172004 and U19A2064), and Special Fund for Anhui Agriculture Research System (2021–2025).


References
1. Goodfellow, I.P.A., et al.: Generative adversarial networks. Commun. ACM 63(11), 139–144 (2020)
2. Karras, T.L., et al.: A style-based generator architecture for generative adversarial networks. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4401–4410 (2019)
3. Karras, T.L., et al.: Analyzing and improving the image quality of StyleGAN. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8110–8119 (2020)
4. Härkönen, E., et al.: GANSpace: discovering interpretable GAN controls. Adv. Neural Inf. Process. Syst. 33, 9841–9850 (2020)
5. Shen, Y., et al.: Interpreting the latent space of GANs for semantic face editing. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9243–9252 (2020)
6. Nitzan, Y., et al.: Face identity disentanglement via latent space mapping. arXiv preprint arXiv:2005.07728 (2020)
7. Richardson, E., et al.: Encoding in style: a StyleGAN encoder for image-to-image translation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2287–2296 (2021)
8. Tov, O., et al.: Designing an encoder for StyleGAN image manipulation. ACM Trans. Graph. (TOG) 40(4), 1–14 (2021)
9. Lin, S., et al.: Robust high-resolution video matting with temporal guidance. In: Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 238–247 (2022)
10. Woo, S., Park, J., Lee, J.-Y., Kweon, I.S.: CBAM: convolutional block attention module. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11211, pp. 3–19. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01234-2_1
11. Lin, T.-Y., et al.: Feature pyramid networks for object detection. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2117–2125 (2017)
12. Zhang, R., et al.: The unreasonable effectiveness of deep features as a perceptual metric. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 586–595 (2018)
13. Deng, J., et al.: ArcFace: additive angular margin loss for deep face recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4690–4699 (2019)
14. Karras, T., et al.: Progressive growing of GANs for improved quality, stability, and variation. arXiv preprint arXiv:1710.10196 (2017)
15. Huang, Y., et al.: CurricularFace: adaptive curriculum learning loss for deep face recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 5901–5910 (2020)
16. Alaluf, Y., et al.: ReStyle: a residual-based StyleGAN encoder via iterative refinement. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 6711–6720 (2021)
17. Wang, T., et al.: High-fidelity GAN inversion for image attribute editing. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11379–11388 (2022)
18. Hu, X., et al.: Style transformer for image inversion and editing. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11337–11346 (2022)

Text-Guided Generative Adversarial Network for Image Emotion Transfer
Siqi Zhu1, Chunmei Qing1,3, and Xiangmin Xu2,3(B)
1 School of Electronic and Information Engineering, South China University of Technology, Guangzhou, Guangdong, China
2 School of Future Technology, South China University of Technology, Guangzhou 510641, China
3 Pazhou Lab, Guangzhou 510330, China

Abstract. In recent years, image emotion computing has attracted widespread attention from computer vision researchers because of its demonstrated potential in various domains. The main research direction in this field has extended from image emotion classification and recognition to tasks such as image emotion transformation and generation. However, the semantic gap between high-level image emotion semantics and low-level image features biases how existing image style transfer models represent image emotion, resulting in inconsistencies between the emotional features and the semantic content of the generated images. Inspired by text-guided image generation models, we first introduce the guiding role of text information into emotional image transformation. Compared with existing text-guided image generation models, the proposed model performs significantly better at text-based editing of image emotion. By leveraging the inherent connection between text and both image emotion and image content, we improve the accuracy of image emotion transformation and generate more natural and realistic images than existing image style transfer and image emotion transfer models.
Keywords: Image emotion transfer · generative adversarial network · affective computing · semantically multi-modal image synthesis

1 Introduction
Images, one of the most commonly used multimedia formats, are extensively shared on social networking and content-sharing websites, and their visual emotional analysis has diverse applications in various fields. The emotional information contained in images and videos can influence people's views on public events and their decisions about products or services [1]. Artists use images to elicit emotions and provoke particular reactions from viewers through photo enhancement, film post-production, and other methods [2]. In psychological research, images with intricate emotional characteristics have been used to improve the accuracy of detecting mental illness and to provide personalized treatment or services for patients [3]. In artificial intelligence, designers can edit multimedia content in virtual reality (VR) and augmented reality (AR) environments to enhance users' immersion, and images with appropriate emotional features help make human-computer interaction more natural [4, 5]. The vast potential of visual emotional analysis in real-world applications and academic research has made it one of the focal points of computer vision. Previous research on image emotion mainly focused on recognizing and classifying image emotions [6, 7], using images of human faces and natural scenes as research objects. More recently, some researchers have turned their attention to image sentiment transfer. Ali et al. [8] transformed emotional images using color features extracted from reference images by a neural network. An et al. [9] transferred the global sentimental state of input images based on reference images retrieved with target sentiment tags. However, these methods still represent image emotion only through low-level image features and cannot fundamentally solve the semantic gap problem of emotional images. Alternatively, some researchers [10, 11] treat image sentiment transfer as a domain adaptation problem, similar to image style transfer [12]. Although these methods improve the accuracy of sentiment transfer by semantically matching source and reference images, the transferred emotional distribution is still limited by the sentiment categories of the reference image groups. Moreover, the sentimental features extracted from the reference domain may introduce further interference into the transferred images. To address this problem, Zhu et al. proposed the EGAN [13] model, which improves the accuracy of image emotion transfer by decoupling the semantic content and emotional features of the image. In general, further research is still needed to improve the harmony between emotional features and semantic content in generated images.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 506–522, 2023. https://doi.org/10.1007/978-981-99-4742-3_42

Fig. 1. The text-guided structure of TGE-GAN model for image emotion transfer.

In recent years, there have been advances in cross-modal image research in image classification [14, 15], recognition [16], and transformation [17], and some researchers have started to focus on cross-modal image style transfer and generation. Among these efforts, text-guided image generation based on CLIP [18], a contrastively pre-trained image-text model, has gained widespread attention owing to CLIP's outstanding performance on image-related tasks. Building on this work, some researchers have applied text guidance to image style transfer and facial expression editing [19]. However, these works still focus on editing the semantic content of images or transferring image style represented by low-level features, without involving the high-level emotional semantics of images. For an image, a textual label can simultaneously capture the intrinsic connection between its emotional features and its semantic content. We therefore introduce text guidance into image emotion transfer. As shown in Fig. 1, we design TGE-GAN (Text-Guided-Emotion GAN), a model that uses text information to refine image emotion features during emotion transformation. It applies a Nature-High-Concept (NHC) encoder and an Emotion-Low-Feature (ELF) extractor to embed images from different domains into a shared latent NHC space and a shared latent ELF space, and then embeds text describing the image's emotion category as a label into an existing generative adversarial model. By exploiting the connection between the text information and the image content, we refine the emotion features so that they better reflect the low-level emotion features extracted from the image. Decoupling the semantic content from the emotional features of the image improves the accuracy and effectiveness of text guidance and prevents the emotion label from directly affecting the semantic content of the input image. The main contributions of this work are as follows:
1) We introduce a text-guided model into image sentiment transfer. Using the sentiment categories of images as text labels, we constrain the generation of image sentiment features with a CLIP loss network pre-trained on large-scale datasets, which improves the accuracy of image sentiment transformation.
2) By decoupling emotion features from content features, we can more accurately locate and refine the transferable emotion information extracted from the image. This reduces the interference of text information with the image content and improves emotion transformation.
3) Experiments show that, compared with existing text-guided image generation models, the proposed model generates the expected emotional images. Compared with existing image style transfer and image sentiment transformation models, it is more accurate in sentiment transformation, and its generated images look more harmonious and natural in both semantic content and sentiment features.

Text-Guided Generative Adversarial Network

2 Related Works

2.1 Image Emotion Transfer

With the development of image emotion computing in recent years, image emotion transformation has gradually attracted research attention. Most existing work builds on image style transfer or image generation and designs models around the differences between the emotion distributions of the source and target images. Existing emotion transformation methods fall into two stages. Early image emotion models mainly transformed the emotion distribution by changing the low-level features of the image. For example, Kim et al. [20] proposed a system that finds reference image segments in a set of labeled images and then achieves an emotional state transition by re-coloring the source image segments according to the selected target image segments. Ali et al. [8] used neural networks to extract the color features of reference images for image emotion transformation. These algorithms usually ignore how the same emotion domain is distributed over different semantic content, which reduces the accuracy of the emotion expressed by the generated image. More recently, researchers have also considered whether the semantic content of an image is compatible with the target emotion distribution. Chen et al. [11] proposed SentiGAN as the core component for object-level image emotion transformation. Zhao et al. [10] proposed a domain-adaptive method that learns the discrete probability distribution of the target image emotions and adapts it to the original image, converting the image into another emotion domain. Such models often suffer from domain interference, and it is difficult to keep the generated images consistent in content with the originals. An et al. [9] transferred the global emotional state of the input image using reference images retrieved with target emotion labels. Semantic matching between source and reference images improves the accuracy of emotion transfer. However, the transferred emotion distribution remains limited to the emotion categories of the reference image group. In addition, emotional features extracted from the reference domain may introduce further interference into the transferred image. To address this issue, Zhu et al. [13] proposed a model that decouples the semantic content and the emotional features of images to improve the accuracy of image emotion transformation. Overall, harmonizing image semantics with the emotional feature distribution in emotion transformation and generation still requires further research.
2.2 Text-Guided Image Generation and Manipulation

In recent years, GANs have achieved significant success in image synthesis [21, 22] and image translation [23, 24]. CycleGAN [25] used a cycle-consistency scheme to preserve the information shared by source and translated images, and DualGAN [26] and DiscoGAN [27] used the same scheme. Researchers then began to introduce multimodal information into image generation and transformation to improve the quality of generated images and to strengthen the model's ability to represent the real world. Reed et al. [28] were the first to use text to guide image generation, by embedding a pre-trained text encoder into a conditional GAN. Others then improved image quality within this approach, for example with the multi-scale GANs of Zhang et al. [29, 30] and the attention mechanism between text and image features in AttnGAN [31]. Building on this foundation, researchers further expanded the methods and domains of text-guided image generation. StackGAN [32] used a two-stage GAN to improve the quality of text-generated images. The first stage generates the outline of the image, and the second enriches it with details. Both generators incorporate text representations to keep the two stages semantically consistent. XMC-GAN [33] captured the relationship between two modalities, i.e., text and image, through multiple contrastive learning objectives. It ensured semantic consistency between the generated images and the descriptive text by maximizing the mutual information between the two domains. StyleCLIP [19] offered a single framework that combines high-quality images generated by StyleGAN [34] with the diverse multi-domain semantics learned by CLIP [18] for text-guided image manipulation. Another notable feature is that it builds on existing pre-trained large-scale models, which reduces the hardware cost of training. Existing text-guided image generation models perform well at producing image content. In natural-image emotion transformation and generation, however, they often fail to produce images that match the target emotion distribution specified by emotion-related text. This implies that current generation models still cannot represent the semantic content and the emotion distribution of images in a unified and coherent manner.

3 Text-Driven Emotional Generative Adversarial Network

3.1 Disentangled Representations of NHC and ELF

In this work, we first decouple the high-level semantic information and the low-level emotion information of the source image. We then use text to further refine the extracted emotion features and improve image emotion transformation. In image emotion transfer, an image that elicits emotional reactions can be represented by two parts [8]. The first is neutral high-level concepts (NHC), which represent the content of the image. The second is emotional low-level features (ELF), which characterize its emotional information. The key to image emotion transfer is extracting the distribution of these emotion-related features. As shown in Fig. 2, our framework contains an NHC encoder and an ELF extractor. Let $x \in X$ and $y \in Y$ be images from two different emotion domains $X$ and $Y$. The NHC encoder $E^{nhc}$ encodes an image into the NHC space ($E^{nhc}: X, Y \to \mathrm{NHC}$), while the ELF extractor $E^{elf}$ maps an image onto the ELF space ($E^{elf}: X, Y \to \mathrm{ELF}$). Thus $x$ and $y$ can be disentangled as:

$$\left(z_x^{nhc}, z_x^{elf}\right) = \left(E^{nhc}(x), E^{elf}(x)\right), \quad z_x^{nhc} \in \mathrm{NHC},\ z_x^{elf} \in \mathrm{ELF}, \tag{1}$$

$$\left(z_y^{nhc}, z_y^{elf}\right) = \left(E^{nhc}(y), E^{elf}(y)\right), \quad z_y^{nhc} \in \mathrm{NHC},\ z_y^{elf} \in \mathrm{ELF}. \tag{2}$$

In the above equations, $(z_x^{nhc}, z_x^{elf})$ and $(z_y^{nhc}, z_y^{elf})$ are the disentangled latent representations of images $x$ and $y$, respectively. Both domains share the NHC encoder and the ELF extractor, so images from either domain are embedded into the same two shared spaces. We then use text to edit the emotional features in the ELF space, which yields text-guided image emotion transfer. The text embedded in our model only names the emotion category of the image and does not describe the semantic content of the input image. Decoupling the content and emotional features of the input image therefore further improves the accuracy of emotion transformation. The architecture of our text-guided generation is depicted in Fig. 2. The ELF extractor inverts the reference image $y$ into a latent code $z_y^{elf}$.


Fig. 2. The training structure of the TGE-GAN model.

Then, a mapping function $M^{clip}$ is trained to generate residuals $z_y^{clip}$, controlled by the text $t_y$ of the target emotion category. These residuals are concatenated with $z_y^{elf}$ to form a new emotional feature code $z_y^{ce}$. A generator $G$ then produces an image from this code, and the CLIP and ELF reconstruction losses evaluate that image:

$$z_y^{ce} = \left[z_y^{elf};\ z_y^{clip}\right]. \tag{3}$$

The shared NHC encoder and ELF extractor disentangle the input image. Its emotion code is passed through the mapping function $M^{clip}$, and the generator $G$ finally reconstructs the result back to RGB space. The transferred images $u$ and $v$ are thus obtained as:

$$u = G\left(z_y^{nhc}, z_x^{ce}\right), \tag{4}$$

$$v = G\left(z_x^{nhc}, z_y^{ce}\right). \tag{5}$$

Hence, when the high-level concepts of the input image are emotionally neutral and do not induce any feelings, the NHC representation extracted from the input image can be combined with two other inputs. These are the ELF representation extracted from a reference image and the emotion-related text. In this work, $E^{nhc}$, $E^{elf}$, $M^{clip}$ and $G$ are deterministic functions that are learned by deep neural networks.

3.2 Loss Functions

In addition to the NHC encoder $E^{nhc}$, the ELF extractor $E^{elf}$, the mapping function $M^{clip}$ and the generator $G$, TGE-GAN contains two domain discriminators, $D_X$ and $D_Y$. The discriminator $D_X$ distinguishes images from domain $X$ from the transferred image $u$. Likewise, $D_Y$ distinguishes images from domain $Y$ from $v$. To this end, the NHC encoder and the ELF extractor disentangle the input image pair $x, y$ into the representations of Eqs. (1) and (2). The ELF representations are then swapped, and the transferred pair $u$ and $v$ is reconstructed as:

$$u = G\left(z_y^{nhc}, z_x^{ce}\right) = G\left(E^{nhc}(y), E^{elf}(x), M^{clip}(x)\right), \tag{6}$$

$$v = G\left(z_x^{nhc}, z_y^{ce}\right) = G\left(E^{nhc}(x), E^{elf}(y), M^{clip}(y)\right). \tag{7}$$
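The data flow of Eqs. (3) to (7) can be illustrated with a toy sketch. It is only a sketch under stated assumptions. The functions `e_nhc`, `e_elf`, `m_clip` and `generate` are hypothetical stand-ins operating on short Python lists, not the learned networks. The mapper is assumed to output a fixed residual, and the toy generator is assumed to add that residual to the ELF code. Only the concatenation of Eq. (3) and the code swap of Eqs. (6) and (7) follow the text.

```python
from typing import List

Vector = List[float]

# Hypothetical toy "image": first half stands for content, second half for emotion.


def e_nhc(img: Vector) -> Vector:
    """Stand-in NHC encoder: keep the content half."""
    return img[: len(img) // 2]


def e_elf(img: Vector) -> Vector:
    """Stand-in ELF extractor: keep the emotion half."""
    return img[len(img) // 2 :]


def m_clip(z_elf: Vector) -> Vector:
    """Stand-in text-controlled mapper: a fixed residual z^clip (assumption)."""
    return [0.5 for _ in z_elf]


def feature_code(img: Vector) -> Vector:
    """Eq. (3): z^ce = [z^elf ; z^clip], a plain concatenation."""
    z_elf = e_elf(img)
    return z_elf + m_clip(z_elf)


def generate(z_nhc: Vector, z_ce: Vector) -> Vector:
    """Stand-in generator: content code followed by z^elf + z^clip (assumption)."""
    n = len(z_ce) // 2
    emotion = [a + b for a, b in zip(z_ce[:n], z_ce[n:])]
    return z_nhc + emotion


x = [1.0, 2.0, 10.0, 20.0]  # content (1, 2), emotion (10, 20)
y = [3.0, 4.0, 30.0, 40.0]  # content (3, 4), emotion (30, 40)

u = generate(e_nhc(y), feature_code(x))  # Eq. (6): content of y, emotion code of x
v = generate(e_nhc(x), feature_code(y))  # Eq. (7): content of x, emotion code of y
print(u)  # [3.0, 4.0, 10.5, 20.5]
print(v)  # [1.0, 2.0, 30.5, 40.5]
```

Only the emotion-side code is swapped while each content code is reused. This is what allows a single shared encoder, extractor and generator to serve both domains.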

The objective of TGE-GAN is that neither transferred image, $u$ nor $v$, can be distinguished from the real images of its domain, $X$ or $Y$ respectively. For this purpose, the model shares one encoder, one extractor and one generator across both domains. This disentangles images into two shared spaces and makes the transfer possible. In addition, several loss functions are defined to facilitate model training [35].

Self-reconstruction Loss. The self-reconstruction loss $L_{recon}$ is used to accelerate model training. It not only reduces training time but also helps the model learn to decompose and reconstruct images. For an image pair $x$ and $y$, the self-reconstruction loss is:

$$L_{recon}\left(E^{nhc}, E^{elf}, M^{clip}, G\right) = \mathbb{E}_{x,y}\left[\left\|x - G\left(z_x^{nhc}, z_x^{ce}\right)\right\|_1 + \left\|y - G\left(z_y^{nhc}, z_y^{ce}\right)\right\|_1\right]. \tag{8}$$
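The following is a minimal sketch of Eq. (8). It assumes that images are flattened to Python lists and that the expectation is replaced by an average over a mini-batch of pairs. `reconstruct` is a hypothetical callable standing in for $G\left(E^{nhc}(\cdot), z^{ce}(\cdot)\right)$.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def l1(a: Sequence[float], b: Sequence[float]) -> float:
    """||a - b||_1: sum of absolute element-wise differences."""
    if len(a) != len(b):
        raise ValueError("shape mismatch")
    return sum(abs(p - q) for p, q in zip(a, b))


def self_reconstruction_loss(
    pairs: Sequence[Tuple[Vector, Vector]],
    reconstruct: Callable[[Vector], Vector],
) -> float:
    """Eq. (8): batch mean of ||x - G(x's codes)||_1 + ||y - G(y's codes)||_1."""
    total = sum(l1(x, reconstruct(x)) + l1(y, reconstruct(y)) for x, y in pairs)
    return total / len(pairs)


def shifted(img: Vector) -> Vector:
    """Hypothetical imperfect reconstruction: every value is off by 0.5."""
    return [p + 0.5 for p in img]


pairs = [([0.0, 1.0], [2.0, 3.0]), ([1.0, 1.0], [0.0, 0.0])]
print(self_reconstruction_loss(pairs, shifted))  # 2.0
print(self_reconstruction_loss(pairs, list))  # 0.0 (perfect reconstruction)
```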

NHC Reconstruction Loss. As shown in Fig. 2, this loss encourages the model to preserve the NHC of the original image in the transferred image:

$$L_{NHC}\left(E^{nhc}, E^{elf}, M^{clip}, G\right) = \mathbb{E}_{x,y}\left[\left\|z_x^{nhc} - z_v^{nhc}\right\|_1 + \left\|z_y^{nhc} - z_u^{nhc}\right\|_1\right] = \mathbb{E}_{x,y}\left[\left\|E^{nhc}(x) - E^{nhc}(v)\right\|_1 + \left\|E^{nhc}(y) - E^{nhc}(u)\right\|_1\right]. \tag{9}$$
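The pairing in Eq. (9) is easy to get backwards. The image $v$ carries the content of $x$, and $u$ carries the content of $y$. The small sketch below makes this explicit under toy assumptions: images are flattened lists, and the hypothetical `e_nhc` simply keeps the first half of a list.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def l1(a: Sequence[float], b: Sequence[float]) -> float:
    """||a - b||_1."""
    return sum(abs(p - q) for p, q in zip(a, b))


def nhc_loss(e_nhc: Callable[[Vector], Vector],
             x: Vector, y: Vector, u: Vector, v: Vector) -> float:
    """Eq. (9) for one pair: ||E(x) - E(v)||_1 + ||E(y) - E(u)||_1."""
    return l1(e_nhc(x), e_nhc(v)) + l1(e_nhc(y), e_nhc(u))


def e_nhc(img: Vector) -> Vector:
    """Hypothetical content encoder: first half of the list."""
    return img[: len(img) // 2]


x, y = [1.0, 2.0, 10.0, 20.0], [3.0, 4.0, 30.0, 40.0]
u_good, v_good = [3.0, 4.0, 10.0, 20.0], [1.0, 2.0, 30.0, 40.0]
u_bad = [0.0, 4.0, 10.0, 20.0]  # content of y partly lost in u

print(nhc_loss(e_nhc, x, y, u_good, v_good))  # 0.0
print(nhc_loss(e_nhc, x, y, u_bad, v_good))  # 3.0
```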

Perceptual ELF Reconstruction Loss. The ELF extractor is designed to extract emotional features, so emotional differences are better measured in its feature space. A perceptual ELF reconstruction loss therefore ensures that the transferred image matches the reference image in its emotional feature representation. Figure 2 illustrates this loss:

$$L_{ELF}\left(E^{nhc}, E^{elf}, M^{clip}, G\right) = \mathbb{E}_{x,y}\left[\sum_j \frac{1}{C_j H_j W_j}\left(\left\|\phi_j(x) - \phi_j(u)\right\|_1 + \left\|\phi_j(y) - \phi_j(v)\right\|_1\right)\right], \tag{10}$$


where $\phi_j$ is the feature map of the $j$-th layer of the ELF extractor, with shape $C_j \times H_j \times W_j$. During training, $j$ takes only odd values.

CLIP Loss. The text-driven emotion features can be represented by leveraging CLIP. Specifically, given a latent code $z_x^{ce}$ and a text prompt $t$ that names the emotion category, the CLIP loss is defined as:

$$L_{clip}\left(E^{nhc}, E^{elf}, M^{clip}, G\right) = \mathbb{E}_{x,y}\left[D_{clip}\left(G(z_x^{ce}), t_x\right) + D_{clip}\left(G(z_y^{ce}), t_y\right)\right], \tag{11}$$

where $G$ is the generator and $D_{clip}$ is the cosine distance between the CLIP embeddings of its two arguments.

Cross-Cycle Reconstruction Loss. Once the transferred images $u$ and $v$ are generated, the model disentangles them again to obtain their representations. It then performs a second transfer by swapping the ELF representations and mapping-function outputs, which yields the reconstructed input images $\hat{x}$ and $\hat{y}$. This loss lets the model learn to map images between the two domains from unpaired data. Figure 2 shows this process, and the cross-cycle reconstruction loss is:

$$L_{CC}\left(E^{nhc}, E^{elf}, M^{clip}, G\right) = \mathbb{E}_{x,y}\left[\|x - \hat{x}\|_1 + \|y - \hat{y}\|_1\right] = \mathbb{E}_{x,y}\left[\left\|x - G\left(z_v^{nhc}, z_u^{ce}\right)\right\|_1 + \left\|y - G\left(z_u^{nhc}, z_v^{ce}\right)\right\|_1\right]. \tag{12}$$
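Two ingredients of Eqs. (10) and (11) are sketched below. The first is the per-layer normalization by $C_j H_j W_j$, restricted to odd layers. The second is the cosine distance used as $D_{clip}$. The sketch makes three assumptions. Feature maps are nested lists of shape $C \times H \times W$. "Odd $j$" is read as the 1-based layers 1, 3, 5, and so on. The embedding vectors are placeholders, because real CLIP embeddings would require the pre-trained CLIP model, which is not shown.

```python
import math
from typing import List, Sequence

FeatureMap = List[List[List[float]]]  # C x H x W


def l1_map(a: FeatureMap, b: FeatureMap) -> float:
    """||a - b||_1 over all C*H*W entries."""
    return sum(
        abs(p - q)
        for ch_a, ch_b in zip(a, b)
        for row_a, row_b in zip(ch_a, ch_b)
        for p, q in zip(row_a, row_b)
    )


def elf_term(phi_a: Sequence[FeatureMap], phi_b: Sequence[FeatureMap]) -> float:
    """One half of Eq. (10): sum over odd 1-based j of ||phi_j(a) - phi_j(b)||_1 / (C_j H_j W_j)."""
    total = 0.0
    for j, (fa, fb) in enumerate(zip(phi_a, phi_b), start=1):
        if j % 2 == 0:  # assumption: only layers 1, 3, 5, ... are used
            continue
        c, h, w = len(fa), len(fa[0]), len(fa[0][0])
        total += l1_map(fa, fb) / (c * h * w)
    return total


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """D_clip: 1 - cos(a, b) between two embedding vectors."""
    dot = sum(p * q for p, q in zip(a, b))
    norm = math.sqrt(sum(p * p for p in a)) * math.sqrt(sum(q * q for q in b))
    return 1.0 - dot / norm


# Two layers of 1 x 1 x 2 features; layer 2 differs but is skipped (even j).
phi_x = [[[[0.0, 0.0]]], [[[5.0, 5.0]]]]
phi_u = [[[[1.0, 1.0]]], [[[0.0, 0.0]]]]
print(elf_term(phi_x, phi_u))  # 1.0
print(cosine_distance([1.0, 0.0], [0.0, 1.0]))  # 1.0
```

A full $L_{ELF}$ term for one pair would add `elf_term(phi_x, phi_u) + elf_term(phi_y, phi_v)`, averaged over the batch.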

Adversarial Loss. The training objective of a GAN is adversarial: at convergence, the generated images should be classified as real. LSGAN [36] uses a least-squares objective that both stabilizes training and improves the quality of the generated images. Following LSGAN, the adversarial losses for the two domains are defined as:

$$L_{adv}^{X}\left(E^{nhc}, E^{elf}, M^{clip}, G, D_X\right) = \mathbb{E}_{x}\left[\left(D_X(x) - 1\right)^2\right] + \mathbb{E}_{x,y}\left[D_X(u)^2\right], \tag{13}$$

$$L_{adv}^{Y}\left(E^{nhc}, E^{elf}, M^{clip}, G, D_Y\right) = \mathbb{E}_{y}\left[\left(D_Y(y) - 1\right)^2\right] + \mathbb{E}_{x,y}\left[D_Y(v)^2\right]. \tag{14}$$
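The following is a minimal sketch of the least-squares objectives. It assumes that discriminator outputs for a batch are given as plain lists of scores. The discriminator side follows Eqs. (13) and (14). The generator-side term $\mathbb{E}[(D(\cdot) - 1)^2]$ is the usual LSGAN choice [36]. It is an assumption here, because the text writes out only the discriminator form.

```python
from statistics import mean
from typing import Sequence


def lsgan_d_loss(d_real: Sequence[float], d_fake: Sequence[float]) -> float:
    """Eqs. (13)-(14): E[(D(real) - 1)^2] + E[D(fake)^2]; minimized by the discriminator."""
    return mean((s - 1.0) ** 2 for s in d_real) + mean(s ** 2 for s in d_fake)


def lsgan_g_loss(d_fake: Sequence[float]) -> float:
    """Common LSGAN generator loss E[(D(fake) - 1)^2] (assumed, not given in the text)."""
    return mean((s - 1.0) ** 2 for s in d_fake)


# Hypothetical D_X scores on real images x and on transferred images u.
print(lsgan_d_loss([1.0, 1.0], [0.0, 0.0]))  # 0.0: D separates perfectly
print(lsgan_d_loss([0.5, 0.5], [0.5, 0.5]))  # 0.5: D cannot tell them apart
print(lsgan_g_loss([0.5, 0.5]))  # 0.25
```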

Total Loss. The total objective is:

$$\min_{E^{nhc}, E^{elf}, M^{clip}, G}\ \max_{D_X, D_Y}\ L\left(E^{nhc}, E^{elf}, M^{clip}, G, D_X, D_Y\right) = L_{adv}^{X} + L_{adv}^{Y} + \lambda_{recon} L_{recon} + \lambda_{NHC} L_{NHC} + \lambda_{ELF} L_{ELF} + \lambda_{clip} L_{clip} + \lambda_{CC} L_{CC}, \tag{15}$$

where $\lambda_{recon}$, $\lambda_{NHC}$, $\lambda_{ELF}$, $\lambda_{clip}$ and $\lambda_{CC}$ are hyper-parameters that balance the loss terms. The experiments section analyzes how these settings, in particular $\lambda_{clip}$, affect model performance.
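Once the individual terms are computed, Eq. (15) reduces to a weighted sum. The sketch below uses the default weights reported in Sect. 3.3 ($\lambda_{recon} = 10$, $\lambda_{NHC} = 5$, $\lambda_{ELF} = 5$, $\lambda_{clip} = 5$, $\lambda_{CC} = 10$). The dictionary keys and the function name are our own.

```python
from typing import Dict, Optional

# Default weights from Sect. 3.3.
DEFAULT_LAMBDAS: Dict[str, float] = {
    "recon": 10.0, "NHC": 5.0, "ELF": 5.0, "clip": 5.0, "CC": 10.0,
}


def total_loss(adv_x: float, adv_y: float, terms: Dict[str, float],
               lambdas: Optional[Dict[str, float]] = None) -> float:
    """Eq. (15): L_adv^X + L_adv^Y + sum_k lambda_k * L_k."""
    weights = DEFAULT_LAMBDAS if lambdas is None else lambdas
    missing = set(weights) - set(terms)
    if missing:
        raise KeyError(f"missing loss terms: {sorted(missing)}")
    return adv_x + adv_y + sum(weights[k] * terms[k] for k in weights)


unit = {"recon": 1.0, "NHC": 1.0, "ELF": 1.0, "clip": 1.0, "CC": 1.0}
print(total_loss(0.5, 0.5, unit))  # 36.0
# The lambda_clip = 1 setting studied in Sect. 4.3:
print(total_loss(0.5, 0.5, unit, {**DEFAULT_LAMBDAS, "clip": 1.0}))  # 32.0
```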


3.3 Implementation Details

All experiments in this study use images of size 256 × 256. The NHC encoder $E^{nhc}$ consists of two stride-2 convolutional down-sampling layers followed by 5 residual blocks, with Instance Normalization [37]. The ELF extractor $E^{elf}$ is fully convolutional. The emotional features it extracts are used to compute the affine parameters of AdaIN (Adaptive Instance Normalization) [38]. The mapping function $M^{clip}$ consists of 3 fully connected layers. The generator $G$ mirrors the NHC encoder, but AdaIN layers accompany its residual blocks and up-sampling layers. To prevent checkerboard artifacts in the reconstructed images, the generator uses nearest-neighbor up-sampling followed by convolution instead of transposed convolution. AdaIN is defined as:

$$\mathrm{AdaIN}(p, q) = \sigma(q)\,\frac{p - \mu(p)}{\sigma(p)} + \mu(q), \tag{16}$$

where $p$ is a non-linear activation within the generator and $q$ is the emotional feature map. The ELF extractor extracts $q$ from the reference image, and the text modulates it through the mapping function. $\mu$ and $\sigma$ are the channel-wise mean and standard deviation operators. The framework is optimized with Adam [39] using exponential decay rates $(\beta_1, \beta_2) = (0.5, 0.999)$. The batch size is 1 and the learning rate is 0.0001. The hyper-parameters are $\lambda_{recon} = 10$, $\lambda_{NHC} = 5$, $\lambda_{ELF} = 5$, $\lambda_{clip} = 5$ and $\lambda_{CC} = 10$. The model for each pair of emotion categories is trained for $2 \times 10^5$ iterations on a single NVIDIA GTX 1080Ti, without learning-rate decay or dropout. One training epoch takes 15.67 min on average.

4 Experiments

4.1 Image Dataset

In this work, we use the emotion image database proposed by Zhu et al. [13]. It integrates several well-known emotion image databases, including ArtPhoto [40], Emotion6 [41], FI [42], FlickrLDL and TwitterLDL [43], and re-classifies and re-annotates some of the images. The integrated dataset contains six emotion categories (Anger, Disgust, Joy, Fear, Sadness and Surprise), and Table 1 lists the details of each dataset. After integration, the quantity and quality of the images in each category meet the usual requirements for training GAN-type models. To keep training consistent across the emotion conversion models, we randomly selected 1000 images of each emotion category from the database for training and testing TGE-GAN. Of these, 900 images are used for training and 100 for testing. In the experiments, we evaluate the proposed model from two aspects: visual quality and accuracy of emotion conversion.
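The per-category sampling described above can be sketched as follows. The file names and the fixed random seed are hypothetical, because the text does not say how sampling was seeded. The category counts are those of the combined dataset in Table 1.

```python
import random
from typing import Dict, List, Tuple

Split = Dict[str, List[str]]


def split_per_class(images_by_class: Split, n_per_class: int = 1000,
                    n_train: int = 900, seed: int = 0) -> Tuple[Split, Split]:
    """Randomly pick n_per_class images per emotion, then split them into train/test."""
    rng = random.Random(seed)
    train: Split = {}
    test: Split = {}
    for label, items in images_by_class.items():
        if len(items) < n_per_class:
            raise ValueError(f"class {label!r} has only {len(items)} images")
        chosen = rng.sample(items, n_per_class)  # sampling without replacement
        train[label] = chosen[:n_train]
        test[label] = chosen[n_train:]
    return train, test


combined_counts = {"Anger": 1284, "Disgust": 1800, "Fear": 1612,
                   "Joy": 1930, "Sadness": 1772, "Surprise": 1879}
images = {k: [f"{k.lower()}_{i:04d}.jpg" for i in range(n)]
          for k, n in combined_counts.items()}
train, test = split_per_class(images)
print({k: (len(train[k]), len(test[k])) for k in train})  # every class maps to (900, 100)
```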


Table 1. Image database summary. The combined dataset used in this study is shown in the last row ("-" marks an empty cell).

Dataset     Amusement  Anger  Awe   Contentment  Disgust  Excitement  Fear  Joy   Sadness  Surprise
Emotion6    -          31     -     -            245      -           329   638   308      104
ArtPhoto    101        77     103   70           70       105         115   -     166      -
FI          4942       1266   3151  5374         1658     2963        1032  2922  5374     1658
FlickrLDL   1126       178    1355  6030         445      500         570   -     719      -
TwitterLDL  872        185    255   6870         178      702         226   -     209      -
Combined    -          1284   -     -            1800     -           1612  1930  1772     1879

Fig. 3. Comparison with text-driven image manipulation model [19] in image emotion transfer task.

4.2 Qualitative Results

The proposed model performs text-guided image emotion conversion. No similar model has previously been proposed in this field. The first part of the visual comparison therefore compares emotion conversion performance against StyleCLIP [19], a representative text-guided image manipulation method. StyleCLIP edits the semantic content of images according to the input text and focuses mainly on human faces. When the input text contains emotion-related words, it changes facial expression features and thereby the emotion the image conveys, as shown in Fig. 3. However, its training data and design impose a limitation. If the image lacks content, such as a human face, whose semantics can be edited directly to produce an emotional change, StyleCLIP cannot perform the emotion conversion requested by the text. StyleCLIP converts image emotion by changing semantic content. TGE-GAN, in contrast, mainly changes the low-level feature distribution of the image according to the reference image and the text. This allows it to convert the emotion distribution of images with different types of semantic content, including human faces and scenery.


Our model converts image emotion by changing the low-level feature distribution of images. Beyond the functional comparison with text-guided image editing models, we therefore also compare it with several leading general image style transfer algorithms. We additionally include EGAN [13], an image emotion transfer algorithm, as a baseline. We compare these algorithms with the proposed one both qualitatively and quantitatively.

• AdaAttN [44]. AdaAttN is an arbitrary style transfer method that improves quality by regularizing local features of the generated image. We use its original settings.
• SWAG [45]. SWAG proposes a simple but effective solution: a softmax transformation of the feature activations, which increases their entropy and makes style transfer with ResNet features more robust. We use its original feature-extraction and softmax-transformation settings.
• EFDM [46]. EFDM is an arbitrary style transfer approach that applies the Exact Histogram Matching (EHM) algorithm in feature space. We use its default hyper-parameter values.
• EGAN [13]. EGAN is a GAN-based image emotion transfer model. By decoupling image content from image emotion, it converts between different emotion categories. We use its default parameter settings.

In this section, we show emotion conversion results between four of the most typical image emotion categories: joy, anger, sadness and surprise. Figure 4 shows example results for joy ↔ sadness, joy ↔ anger, surprise ↔ sadness and surprise ↔ anger. AdaAttN and SWAG do not always preserve the high-level semantics of the input images after transformation, which directly harms the representation of image emotion.
EFDM's results are often similar to the input image. This indicates that it does not transform the emotional distribution of the image according to the reference image. Both EGAN and the proposed method apply an NHC encoder and an ELF extractor to the input images, which makes the visual results of their emotion transformations more consistent. A closer comparison of the two shows that TGE-GAN is more consistent with the reference image in emotion representation and also produces better local texture. This indicates that text has a positive guiding effect on image emotion representation. These qualitative results indicate that general-purpose image style transfer algorithms are not necessarily suitable for image emotion transfer. Disentangling image content features from image emotion keeps the reconstructed images authentic and realistic and also extracts emotional features effectively. Guiding these emotional features with text further improves the accuracy of emotion representation.


Fig. 4. Examples of emotion transfer results. Comparisons between our method and the four baselines.

4.3 Quantitative Evaluation

Emotion Deception Rate. We adopt the emotion deception rate metric proposed by Zhu to measure the success rate of emotion transfer. A pre-trained VGG16 [47] network automatically classifies the transferred images into six emotion categories, which avoids the cost of large-scale manual labeling. The emotion deception rate is the ratio of transferred images classified as the target emotion category to all transferred images. It reflects how often the model succeeds in transferring the emotion distribution of the input image to the target emotion distribution. Let $c_i$ be the class indicator of the $i$-th transferred image. If the trained VGG16 model classifies the $i$-th transferred image as the target emotion class, $c_i = 1$, and otherwise $c_i = 0$. The emotion deception rate $r$ is:

$$r = \frac{1}{N}\sum_{i=1}^{N} c_i, \tag{17}$$

where $N$ is the number of transferred images. Table 2 reports the mean deception rates. TGE-GAN achieves a rate of 0.5136. For reference, the mean test accuracy of the classification model is 0.9299.

Reconstructed Error. The quality of the generated images reflects model performance, so we use the reconstructed error to evaluate the transferred images. The reconstructed error is the mean absolute deviation between the input image and the reconstructed image in RGB space. Smaller errors indicate better performance. For a reconstructed image, the reconstructed error $e$ is:

$$e = \frac{1}{HWC}\sum_{k=1}^{C}\sum_{j=1}^{H}\sum_{i=1}^{W}\left|x_{i,j,k} - \tilde{x}_{i,j,k}\right|, \tag{18}$$

where $x$ is the input image, $\tilde{x}$ is the reconstructed image, and $H$, $W$ and $C$ are the height, width and number of channels of the input image, respectively. Table 2 reports the mean reconstructed errors. The proposed method achieves a competitive mean reconstructed error of 8.37.

Parameter Settings Analysis. To better evaluate the impact of text guidance on image emotion transfer, we compared model performance under different text constraints. In the default setting, $\lambda_{clip} = 5$. Table 2 reports the performance of the proposed model with $\lambda_{clip}$ set to 1 and to 10, with all other parameters unchanged, and Fig. 5 shows example outputs. Both the input image and the input text influence the emotional features of the proposed model, so the choice of this parameter directly affects the accuracy of image emotion transfer. At the same time, the relationship between text and emotion comes from a pre-trained model. Increasing the weight of text guidance therefore also affects the accuracy of image reconstruction. In our experiments, setting $\lambda_{clip}$ to 1 gave the best reconstruction, with a reconstructed error 0.23% lower than in the default setting. However, the emotion transfer success rate decreased by 2.9%. The default setting balances the two performance indicators.

Fig. 5. Examples of emotion transfer images with different references and parameter settings.
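The metrics of Eqs. (17) and (18) are both simple averages. In the sketch below, the predicted labels stand in for the output of the pre-trained VGG16 classifier, which is not reproduced here. Images are nested lists of shape $H \times W \times C$. The function names are our own.

```python
from typing import List, Sequence

Image = List[List[List[float]]]  # H x W x C


def deception_rate(predicted: Sequence[str], target: Sequence[str]) -> float:
    """Eq. (17): r = (1/N) * sum_i c_i, with c_i = 1 iff image i is classified as its target emotion."""
    if len(predicted) != len(target) or not predicted:
        raise ValueError("need two non-empty sequences of equal length")
    hits = sum(1 for p, t in zip(predicted, target) if p == t)
    return hits / len(predicted)


def reconstructed_error(x: Image, x_rec: Image) -> float:
    """Eq. (18): mean absolute difference over all H*W*C entries."""
    h, w, c = len(x), len(x[0]), len(x[0][0])
    total = 0.0
    for row, row_rec in zip(x, x_rec):
        for px, px_rec in zip(row, row_rec):
            total += sum(abs(a - b) for a, b in zip(px, px_rec))
    return total / (h * w * c)


preds = ["joy", "sadness", "joy", "anger"]  # hypothetical classifier outputs
print(deception_rate(preds, ["joy"] * 4))  # 0.5

x = [[[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]]]  # 1 x 2 x 3 image
x_rec = [[[13.0, 17.0, 30.0], [40.0, 56.0, 60.0]]]
print(reconstructed_error(x, x_rec))  # 2.0
```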

Table 2. Comparisons of quantitative metrics.

Method             Deception rate  Reconstructed error
AdaAttN            0.4315          9.87
SWAG               0.3288          17.02
EFDM               0.4611          14.20
EGAN               0.4925          8.54
Ours (λclip = 5)   0.4611          8.37
Ours (λclip = 1)   0.4985          8.35
Ours (λclip = 10)  0.4658          8.55

5 Conclusion

This paper proposes a text-guided image emotion transfer strategy that decouples the image with an NHC encoder and an ELF extractor. Emotion-related text can therefore act directly on the emotion feature space of the image, which reduces the influence of the emotional text on the semantic content of the image. The experimental results validate the effectiveness of this method for text-guided image emotion transfer. The qualitative and quantitative results demonstrate the superiority of the proposed framework over existing general style transfer models and emotion transfer models. In future work, text could also be used to edit the semantic content of the image, further enriching the expressive power of the model.

Acknowledgment. This work is partially supported by the following grants: National Natural Science Foundation of China (61972163, U1801262), Natural Science Foundation of Guangdong Province (2023A1515012568, 2022A1515011555), National Key R&D Program of China (2022YFB4500600), Guangdong Provincial Key Laboratory of Human Digital Twin (2022B1212010004) and Pazhou Lab, Guangzhou, 510330, China.

References

1. Ortis, A., Farinella, G.M., Battiato, S.: Survey on visual sentiment analysis. Image Process. 14(8), 1440–1456 (2020)
2. Rao, K.S., Saroj, V.K., Maity, S., et al.: Recognition of emotions from video using neural network models. Expert Syst. Appl. 38(10), 13181–13185 (2020)
3. Gupta, R., Ariefdjohan, M.: Mental illness on instagram: a mixed method study to characterize public content, sentiments, and trends of antidepressant use. J. Ment. Health 30(4), 518–525 (2021)


4. Diognei, M., Washington, R., Michel, S., et al.: A multimodal hyperlapse method based on video and songs' emotion alignment. Pattern Recogn. Lett. 166 (2022)
5. Liam, S., Alice, O., Hazem, A.: Leveraging recent advances in deep learning for audio-visual emotion recognition. Pattern Recogn. Lett. 146, 1–7 (2021)
6. Deepak, K., Pourya, S., Paramjit, S.: Extended deep neural network for facial emotion recognition. Pattern Recogn. Lett. 120, 69–74 (2019)
7. Kai, G., Xu, X., Lin, W., et al.: Visual sentiment analysis with noisy labels by reweighting loss. In: IEEE International Conference on Systems, Man, and Cybernetics, pp. 1873–1878 (2018)
8. Ali, M., Ali, A.R.: Automatic image transformation for inducing affect. In: BMVC (2017)
9. An, J., Chen, T., Zhang, S., Luo, J.: Global image sentiment transfer. In: 25th International Conference on Pattern Recognition, ICPR 2020, pp. 6267–6274
10. Zhao, S., Zhao, X., Ding, G., Keutzer, K.: EmotionGAN: unsupervised domain adaptation for learning discrete probability distributions of image emotions. In: ACM Multimedia Conference, pp. 1319–1327 (2018)
11. Chen, T., Xiong, W., Zheng, H., Luo, J.: Image sentiment transfer. In: The 28th ACM International Conference on Multimedia, pp. 4407–4415 (2020)
12. Li, X., Liu, S., Kautz, J.: Learning linear transformations for fast image and video style transfer. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3804–3812 (2019)
13. Zhu, S., Qing, C., Chen, C., Xu, X.: Emotional generative adversarial network for image emotion transfer. Expert Syst. Appl. 216, 119485 (2022)
14. Wang, H., Bai, X., Yang, M., et al.: Scene text retrieval via joint text detection and similarity learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4558–4567 (2021)
15. Mafla, A., Dey, S., Biten, A.F., et al.: Multi-modal reasoning graph for scene-text based fine-grained image classification and retrieval. In: Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 4023–4033 (2021)
16. Tian, Z., Shen, C., Chen, H., et al.: FCOS: fully convolutional one-stage object detection. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 9627–9636 (2019)
17. Fu, T.J., Xin, E.W., William, Y.W.: Language-driven image style transfer. In: Proceedings of the International Conference on Learning Representations (2022)
18. Radford, A., Kim, J., Hallacy, C., et al.: Learning transferable visual models from natural language supervision. In: Proceedings of the 38th International Conference on Machine Learning, pp. 8748–8763 (2021)
19. Patashnik, O., Wu, Z., Shechtman, E., et al.: StyleCLIP: text-driven manipulation of StyleGAN imagery. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 2085–2094 (2021)
20. Kim, H.R., Kang, H., Lee, I.K.: Image recoloring with valence-arousal emotion model. Comput. Graph. Forum 35(7), 209–216 (2016)
21. Radford, A., Metz, L., Chintala, S.: Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434 (2015)
22. Karras, T., Aila, T., Laine, S., et al.: Progressive growing of GANs for improved quality, stability, and variation. In: International Conference on Learning Representations (2018)
23. Benaim, S., Wolf, L.: One-sided unsupervised domain mapping. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
24. Isola, P., Zhu, J.Y., Zhou, T., et al.: Image-to-image translation with conditional adversarial networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1125–1134 (2017)


25. Zhu, J.Y., Park, T., Isola, P.: Unpaired image-to-image translation using cycle-consistent adversarial networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2223–2232 (2017)
26. Yi, Z., Zhang, H., Tan, P., et al.: DualGAN: unsupervised dual learning for image-to-image translation. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2849–2857 (2017)
27. Kim, T., Cha, M., Kim, H., Lee, J.K., et al.: Learning to discover cross-domain relations with generative adversarial networks. In: International Conference on Machine Learning, pp. 1857–1865 (2017)
28. Reed, S., Akata, Z., Yan, X., et al.: Generative adversarial text to image synthesis. In: International Conference on Machine Learning, pp. 1060–1069. PMLR (2016)
29. Zhang, H., Xu, T., Li, H., et al.: StackGAN: text to photo-realistic image synthesis with stacked generative adversarial networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 5907–5915 (2017)
30. Zhang, H., Xu, T., Li, H., et al.: StackGAN++: realistic image synthesis with stacked generative adversarial networks. IEEE Trans. Pattern Anal. Mach. Intell. 41(8), 1947–1962 (2018)
31. Xu, T., Peng, Z., Qiu, H., et al.: AttnGAN: fine-grained text to image generation with attentional generative adversarial networks. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1316–1324 (2018)
32. Zhang, H., et al.: StackGAN: text to photo-realistic image synthesis with stacked generative adversarial networks. In: IEEE International Conference on Computer Vision, pp. 5908–5916 (2017)
33. Zhang, H., Koh, J.Y., Baldridge, J., et al.: Cross-modal contrastive learning for text-to-image generation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 833–842 (2021)
34. Karras, T., Laine, S., Aila, T.: A style-based generator architecture for generative adversarial networks. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4396–4405 (2019)
35. Lee, H.Y., Tseng, H.Y., Huang, J.B., et al.: Diverse image-to-image translation via disentangled representations. In: Proceedings of the European Conference on Computer Vision, pp. 35–51 (2018)
36. Mao, X., Li, Q., Xie, H., et al.: Least squares generative adversarial networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2794–2802 (2017)
37. Ulyanov, D., Vedaldi, A., Lempitsky, V.: Improved texture networks: maximizing quality and diversity in feed-forward stylization and texture synthesis. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 6924–6932 (2017)
38. Huang, X., Belongie, S.: Arbitrary style transfer in real-time with adaptive instance normalization. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 1501–1510 (2017)
39. Kingma, D.P., Ba, J.: Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014)
40. Machajdik, J., Hanbury, A.: Affective image classification using features inspired by psychology and art theory. In: Proceedings of the 18th ACM International Conference on Multimedia, pp. 83–92 (2010)
41. Peng, K.C., Chen, T., Sadovnik, A., et al.: A mixed bag of emotions: model, predict, and transfer emotion distributions. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 860–868 (2015)
42. You, Q., Luo, J., Jin, H., et al.: Building a large scale dataset for image emotion recognition: the fine print and the benchmark. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 30, no. 1 (2016)


S. Zhu et al.

43. Yang, J., Sun, M., Sun, X.: Learning visual sentiment distributions via augmented conditional probability neural network. Proc. AAAI Conf. Artif. Intell. 31(1) (2017)
44. Liu, S., Lin, T., He, D., et al.: AdaAttN: Revisit attention mechanism in arbitrary neural style transfer. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, 2021, pp. 6649–6658
45. Wang, P., Li, Y., Vasconcelos, N.: Rethinking and improving the robustness of image style transfer. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2021, pp. 124–133
46. Zhang, Y., Li, M., Li, R., et al.: Exact feature distribution matching for arbitrary style transfer and domain generalization. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, 2022, pp. 8035–8045
47. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. In: International Conference on Learning Representations (2015)

What Constitute an Effective Edge Detection Algorithm?

Prashan Premaratne(B) and Peter Vial

School of Electrical, Computer and Telecommunications Engineering, University of Wollongong, North Wollongong, NSW 2522, Australia

Abstract. Edge detection is essential in every aspect of computer vision, from vehicle number plate recognition to object detection in images or video. The preprocessing stage of many image processing tasks associated with Artificial Intelligence relies on separating multiple objects from each other before any other information is accessed. Over the years, researchers have used the Sobel, Prewitt or Roberts filters and then relied more on the robust Canny method. Today, there have been many improvements on these earlier approaches, which now rely heavily on Convolutional Neural Networks to help determine effective edges that greatly assist object detection. From that perspective, it is quite evident that effective edge detection is all about eventual object detection. With this notion in mind, it is easy to see which methods work and which would not achieve the goals. Deep Learning (DL) approaches have been gaining popularity over the years. Do DL algorithms outperform the conventional edge detection algorithms? If they do, is it time for us to forget about the conventional approaches and resort to the new state of the art? Are they transparent when they perform poorly? How reliable are they? Do they perform consistently when unknown data are presented? This article analyzes the existing and emerging edge detection methods with a view to determining their usability and limitations in computer vision applications, an analysis that would undoubtedly advance the field of image processing and computer vision.

Keywords: Deep Learning · Convolutional Neural Networks · Edge Detection

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 523–530, 2023. https://doi.org/10.1007/978-981-99-4742-3_43

1 Introduction

Image processing is a set of tasks that aim to extract information or features from images. One of its foremost operations is edge detection, the process of finding the boundaries of objects within an image as discontinuities of brightness across pixels. Computer vision, which tries to extract information from images or videos, relies heavily on edge detection to segment objects within frames. In many natural images, ambient light or purposeful illumination highlights objects within frames as discontinuities or abrupt transitions in pixel intensities. It is common knowledge that such sharp intensity variations are
associated with objects within frames. However, certain illuminations can also give rise to such sharp intensity transitions. These transitions are related to depth discontinuities, surface orientation discontinuities or illumination changes. When edge detection is applied to an image, the result can be a set of connected curves that correspond to object boundaries, surface markings or discontinuities in surface orientation. Edge detection thus removes the texture and color of the image, drastically reducing its information content, which is attractive for faster processing. It is important to realize that such preprocessing still leaves the image's structural content intact, which helps extract a variety of feature measures. In reality, the edge detection approaches explored in this paper produce varying outcomes, some of which are not favorable for information extraction. In many nontrivial real images, the curves or edges obtained are often broken, depending on the edge detection algorithm used. Hence, tasks such as character recognition in vehicle number plate recognition face difficulties in feature extraction. Ultimately, the application alone determines the usefulness of the dozens of edge detection algorithms developed so far.

1.1 Canny Edge Detection

Canny edge detection, developed by J. Canny in 1986, makes use of Sobel kernels to extract edges both horizontally and vertically [1]. There are five major steps in Canny edge detection. In the first step, because noise affects edge detection, a 5 × 5 Gaussian filter is used to suppress it: when noise is present, pixel values differ greatly from their neighbors, and edges can disappear as a result.

Figure 1(a)–(d) shows that the Prewitt and Sobel filters produce very similar outcomes, whereas Canny edge detection provides far more unbroken contours or curves, which are more pleasing to the human eye when interpreting objects. This is a good indication of why Canny edge detection is so widely used in practice: it produces cleaner, smoother and more continuous curves than both the Prewitt and Sobel filters [2].

Next, Sobel kernels are applied in the horizontal and vertical directions to detect edges in the noise-suppressed image, and the gradient direction at each pixel is quantized into the horizontal, vertical and two diagonal directions. This leads to accurate edge detection. Sobel filtering may, however, produce thick edges that are more than one pixel wide, so the algorithm next thins them using non-maximum suppression: along the gradient direction, only the pixel with the local maximum gradient magnitude is preserved, while the other edge values are set to zero. This removes thick edges and yields edges a single pixel wide. However, in many cases many edges still clutter the image. Next, a process called double thresholding is performed: pixel values above the upper threshold are kept as strong edges, values below the lower threshold are set to zero, and values between the two thresholds are given another chance. Once this thresholding stage is complete, the values between the upper and lower thresholds are evaluated by their connectedness to strong edges (the edges above the upper threshold); edge sections connected to strong edges are preserved, and those not connected are removed. In this way, Canny edge detection performs a more advanced pruning process that discards isolated weak responses in favor of thinner and smoother edges.

The Canny edge detector remains a state-of-the-art algorithm even today, despite being conceived in an era when computer vision was in its infancy owing to a lack of adequate computing power outside large mainframe computers. Edge detectors that perform better than the Canny edge detector appear to be more computationally expensive, operating with a greater number of parameters [3]. In 1987, Rachid Deriche developed an optimized version of Canny edge detection, sometimes known as the Canny-Deriche edge detector. It follows steps similar to those of the Canny edge detector but relies on an Infinite Impulse Response (IIR) filter for edge detection instead of the Sobel filters used in the Canny edge detector. Furthermore, Deriche adopts several optimization criteria: the system should detect all edges, with no false edges present; the extracted edges in the edge map should be as close as possible to the real edges in the image; and no edge should be marked more than once.

Fig. 1. Most well-known conventional edge detection outcomes on the Lena image.
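The five Canny steps described above can be sketched in NumPy as follows. This is an illustrative reimplementation of the common textbook formulation, not Canny's original code: the Gaussian width (sigma = 1.4), the 5 × 5 kernel size, and thresholds expressed as fractions of the strongest response (`low_ratio`, `high_ratio`) are assumptions chosen for the example, and all helper names are hypothetical.

```python
import numpy as np
from collections import deque

# 3x3 Sobel kernels (cross-correlation: positive response for dark-to-bright).
SOBEL_X = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=float)
SOBEL_Y = SOBEL_X.T

def filter2d(img, kernel):
    """Same-size 2D cross-correlation with edge-replicated borders."""
    kh, kw = kernel.shape
    ph, pw = kh // 2, kw // 2
    padded = np.pad(img, ((ph, ph), (pw, pw)), mode="edge")
    out = np.zeros(img.shape, dtype=float)
    for dy in range(kh):
        for dx in range(kw):
            out += kernel[dy, dx] * padded[dy:dy + img.shape[0], dx:dx + img.shape[1]]
    return out

def gaussian_kernel(size=5, sigma=1.4):
    """Normalised size x size Gaussian kernel used for noise suppression (step 1)."""
    ax = np.arange(size) - size // 2
    g = np.exp(-(ax ** 2) / (2.0 * sigma ** 2))
    k = np.outer(g, g)
    return k / k.sum()

def non_max_suppression(mag, angle):
    """Step 3: keep a pixel only if it is the maximum along its quantised gradient direction."""
    h, w = mag.shape
    out = np.zeros_like(mag)
    q = np.round(angle / 45.0).astype(int) % 4              # 0, 45, 90 or 135 degrees
    step = {0: (0, 1), 1: (1, 1), 2: (1, 0), 3: (1, -1)}     # (dy, dx) along the gradient
    for y in range(1, h - 1):
        for x in range(1, w - 1):
            dy, dx = step[q[y, x]]
            ahead, behind = mag[y + dy, x + dx], mag[y - dy, x - dx]
            # ">=" on one side and ">" on the other keeps only one pixel of a two-pixel ridge.
            if mag[y, x] >= ahead and mag[y, x] > behind:
                out[y, x] = mag[y, x]
    return out

def hysteresis(nms, low, high):
    """Steps 4-5: double thresholding, then keep weak pixels 8-connected to strong ones."""
    strong = nms >= high
    weak = (nms >= low) & ~strong
    edges = strong.copy()
    queue = deque(zip(*np.nonzero(strong)))
    h, w = nms.shape
    while queue:
        y, x = queue.popleft()
        for dy in (-1, 0, 1):
            for dx in (-1, 0, 1):
                ny, nx = y + dy, x + dx
                if 0 <= ny < h and 0 <= nx < w and weak[ny, nx] and not edges[ny, nx]:
                    edges[ny, nx] = True
                    queue.append((ny, nx))
    return edges

def canny(img, low_ratio=0.1, high_ratio=0.3):
    """Return a boolean edge map for a 2D grey-level image."""
    smoothed = filter2d(np.asarray(img, dtype=float), gaussian_kernel())   # step 1
    gx, gy = filter2d(smoothed, SOBEL_X), filter2d(smoothed, SOBEL_Y)      # step 2
    mag = np.hypot(gx, gy)
    angle = np.rad2deg(np.arctan2(gy, gx)) % 180.0
    thin = non_max_suppression(mag, angle)                                # step 3
    peak = thin.max()
    if peak == 0:
        return np.zeros(thin.shape, dtype=bool)  # flat image: no edges at all
    return hysteresis(thin, low_ratio * peak, high_ratio * peak)           # steps 4-5
```

On a synthetic vertical step, the sketch returns a one-pixel-wide edge in every interior row. Replacing `SOBEL_X` (and its transpose) with a kernel whose rows are all `[-1, 0, 1]` turns step 2 into the Prewitt operator compared in Fig. 1.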

1.2 Kovalevsky Edge Detection

Vladimir A. Kovalevsky [4] proposed preprocessing the image with the Sigma filter [5] and with a special filter for the dilution of the ramps. His approach does not use luminance information but instead uses the intensity differences of the colour channels of adjacent pixels, so the method can distinguish between two adjacent pixels of the same intensity but different colours. Whereas many previous edge detectors used the luminance component of images, the Kovalevsky method relies on colour, making it useful in situations where colour information is present. Kovalevsky uses the Mexican Hat filter, a high-pass filter, to compute a new image containing large positive values on one side of an edge and large negative values on the other. Zero crossings are detected as the points where positive values meet negative values, and these zero crossings correspond to edges, as depicted in Fig. 2.

Fig. 2. Kovalevsky edge detection approach using Mexican Hat.
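The zero-crossing idea described above can be sketched in one dimension with NumPy. The Mexican Hat is sampled as a Ricker wavelet (proportional to the negated second derivative of a Gaussian) and shifted to sum to zero; sigma, the kernel radius and the relative threshold that suppresses round-off crossings are assumptions. Two small helpers illustrate why colour differences can separate pixels that a single intensity value cannot; taking the channel mean as "intensity" and the largest per-channel difference as the edge measure are simplifying assumptions, not necessarily Kovalevsky's exact definitions. All names are hypothetical.

```python
import numpy as np

def mexican_hat(sigma, radius=None):
    """Sampled 1D Mexican-hat (Ricker) kernel, shifted so its taps sum to zero."""
    radius = int(4 * sigma) if radius is None else radius
    t = np.arange(-radius, radius + 1, dtype=float)
    k = (1.0 - t ** 2 / sigma ** 2) * np.exp(-t ** 2 / (2.0 * sigma ** 2))
    return k - k.mean()  # zero sum: flat regions give (almost) no response

def zero_crossings(signal, sigma=3.0, rel_thresh=0.1):
    """Filter a 1D profile; return (response, indices i with a sign change between i and i+1)."""
    k = mexican_hat(sigma)
    r = len(k) // 2
    padded = np.pad(np.asarray(signal, dtype=float), r, mode="edge")
    response = np.convolve(padded, k, mode="valid")   # same length as the input signal
    thresh = rel_thresh * np.abs(response).max()      # ignore tiny round-off "crossings"
    crossings = [i for i in range(len(response) - 1)
                 if response[i] * response[i + 1] < 0
                 and abs(response[i] - response[i + 1]) > thresh]
    return response, crossings

def intensity(pixel):
    """Channel mean, used here as the 'intensity' of an RGB pixel (assumption)."""
    return sum(pixel) / len(pixel)

def colour_difference(p, q):
    """Largest per-channel absolute difference of two adjacent RGB pixels (assumed measure)."""
    return max(abs(int(a) - int(b)) for a, b in zip(p, q))
```

For a step from 0 to 1, the response is negative just before the step and positive just after it, so the single zero crossing marks the edge. The pixels `(90, 60, 30)` and `(30, 60, 90)` have the same mean intensity, yet their colour difference is 60.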

2 Machine Learning Techniques

Deep Learning (DL) is currently a popular machine learning technique. In Deep Learning, some edge detection techniques are based on codecs (encoder-decoder networks) and others on network reconstruction. However, our discussion will be limited to edge detection based on multi-scale feature fusion.

2.1 Edge Detection Method Based on Multi-scale Feature Fusion

In many Convolutional Neural Network based approaches, kernels or filters are used to detect features such as edges across an entire image. During training, the kernel weights are updated so that the kernels become tuned to detect certain features in an image. When a CNN is used to detect features such as edges or faces, the kernel (or kernels) is spatially convolved with the image so that, if the feature exists, it can be detected [6]. A few DL-based edge detection methods have appeared since 2015, and one of the most recent, by Su et al. in 2021, stands out from the rest of such recent
approaches [7]. Their motivation for a new DL-based edge detector was to investigate whether high-performing conventional edge detectors such as Canny could be incorporated into an even more robust DL-based edge detector [8]. In their paper, they proposed pixel difference convolution (PDC), in which pixel differences in the image are first computed and then convolved with the kernel weights to generate output features. As Fig. 3 shows, PDC can effectively improve the quality of the output edge maps.

Fig. 3. PiDiNet configured with pixel difference convolution (PDC) vs. the baseline with vanilla convolution. Both models were trained only using the BSDS500 dataset. Compared with vanilla convolution, PDC can better capture gradient information from the image that facilitates edge detection [7].
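The pixel difference convolution described above can be illustrated with its simplest form, in which every neighbour is differenced against the centre pixel before weighting: y = sum_i w_i * (x_i - x_c). The sketch below is a minimal NumPy version under that assumption; Su et al. describe further difference patterns that are not shown, and the helper names are hypothetical. It also shows that this central-difference form equals a vanilla convolution whose centre weight is reduced by the kernel sum, and that a flat patch produces no response.

```python
import numpy as np

def vanilla_conv3x3(img, w):
    """Vanilla 3x3 convolution (cross-correlation, 'valid' region): y = sum_i w_i * x_i."""
    h, wd = img.shape
    out = np.zeros((h - 2, wd - 2))
    for dy in range(3):
        for dx in range(3):
            out += w[dy, dx] * img[dy:dy + h - 2, dx:dx + wd - 2]
    return out

def central_pdc3x3(img, w):
    """Central pixel difference convolution: y = sum_i w_i * (x_i - x_c)."""
    h, wd = img.shape
    centre = img[1:h - 1, 1:wd - 1]
    out = np.zeros((h - 2, wd - 2))
    for dy in range(3):
        for dx in range(3):
            out += w[dy, dx] * (img[dy:dy + h - 2, dx:dx + wd - 2] - centre)
    return out

def pdc_to_vanilla(w):
    """Equivalent vanilla kernel: sum_i w_i (x_i - x_c) = sum_i w_i x_i - x_c * sum_i w_i."""
    w_eq = np.array(w, dtype=float)
    w_eq[1, 1] -= w_eq.sum()
    return w_eq
```

Because the pixel differences cancel any constant offset, the central form behaves like a learned gradient operator; the equivalence in `pdc_to_vanilla` means such a layer can be evaluated as an ordinary convolution once its weights are fixed.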

Since the introduction of Deep Learning algorithms, their progress has been rapid owing to ever-increasing computing power, higher memory capacity, increasing image sensor resolution, and better optics. Deep Learning has enabled engineers to implement computer vision applications with increased accuracy and cost-effectiveness for tasks such as image classification, semantic segmentation and object detection. The prime reason for using neural networks in Deep Learning algorithms is that they are trained on data rather than programmed, which requires less expert analysis and fine-tuning and exploits the explosive amount of video data available. Deep Learning models implemented using Convolutional Neural Networks (CNNs) can be trained on a custom dataset for any specific application, unlike conventional computer vision algorithms, offering greater flexibility in achieving high domain-specific accuracy [9]. As Fig. 4 indicates, most state-of-the-art computer vision approaches now map the input directly to the output (Fig. 4(b)), whereas a few years ago the task was handled using handcrafted features, even when neural networks were used (Fig. 4(a)). In 2015, a new edge detection method, the holistically-nested edge detector (HED), was proposed to tackle two critical issues: (1) holistic image training and prediction, inspired by fully convolutional neural networks for image-to-image classification (the system takes an image as input and directly produces the edge map image as output); and (2) nested multi-scale feature learning, inspired by deeply-supervised nets that perform deep layer supervision to “guide” early classification results [10].
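The two HED ingredients described above, side outputs at several depths fused into one edge map and deep supervision of every side output, can be sketched as a loss computation. This minimal NumPy sketch assumes the side-output logits have already been upsampled to the label resolution and represents the learned fusion as a weighted sum; the class-balancing weight beta (the share of non-edge pixels) follows our reading of the HED paper and is not discussed in this chapter. All function names are hypothetical.

```python
import numpy as np

def class_balanced_bce(logits, labels):
    """Class-balanced binary cross-entropy for one edge map (labels: 1 = edge)."""
    labels = labels.astype(bool)
    beta = 1.0 - labels.mean()               # fraction of non-edge pixels, |Y-| / |Y|
    log_p = -np.logaddexp(0.0, -logits)      # log sigmoid(z), computed stably
    log_not_p = -np.logaddexp(0.0, logits)   # log(1 - sigmoid(z))
    # Rare edge pixels get the large weight beta, abundant background gets 1 - beta.
    return -(beta * log_p[labels].sum() + (1.0 - beta) * log_not_p[~labels].sum())

def fuse_side_outputs(side_logits, weights):
    """Weighted sum of same-size side-output logits (a learned 1x1 fusion in HED)."""
    return sum(w * s for w, s in zip(weights, side_logits))

def hed_loss(side_logits, fuse_weights, labels):
    """Deep supervision: every side output and the fused output are penalised."""
    side = sum(class_balanced_bce(s, labels) for s in side_logits)
    fused = class_balanced_bce(fuse_side_outputs(side_logits, fuse_weights), labels)
    return side + fused
```

Supervising each side output, rather than only the final fused map, is what lets the shallow, fine-scale layers and the deep, coarse-scale layers each learn a usable edge response.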

Fig. 4. (a) Traditional Computer Vision workflow vs. (b) Deep Learning workflow [8].

Fig. 5. HED Edge Detection [10]

They found that the favorable characteristics of these underlying techniques made HED both accurate and computationally efficient. According to their work [10], their edge detector outperforms many conventional and other deep learning edge detection algorithms. However, this has been shown for test images drawn from the same datasets as the training data. It is quite unclear how the method would perform, in terms of speed and accuracy, on entirely new images. Figure 5 shows the success of their approach against Canny edge detection.

Figure 6 depicts the advantage of DL in detecting faint edges, which is quite difficult with traditional edge detection approaches. Figure 7 demonstrates that DL-based edge detection is indeed useful and practicable in medical imaging, where traditional approaches completely fail [14].

Fig. 6. Example of a medical image with many curved edges, and the edge maps computed by deep learning and classic approaches. (a) The original image. (b) Results of the DL FED-CNN approach [11]. (c) Results of the classic FastEdges method [12]. Both methods achieve high-quality detection, but the DL method runs in milliseconds while the classic method takes seconds [13].

Fig. 7. Edge detection output of a cancer image [14]

3 Conclusion

While these DL algorithms outperform classic methods in terms of accuracy and development time, they tend to have higher resource requirements and are unable to perform outside their training space. Even though many recent studies claim that DL-based edge detection methods outperform conventional ones in both computational time and accuracy, it is doubtful whether these claims would stand for different testing data with different resolutions. Despite many publications touting the success of DL approaches, deep uncertainty hangs over such claims. Hence, when engaging with DL techniques, not just in edge detection but in many classification problems, healthy skepticism should be maintained about their performance, despite glowing
remarks from many contemporary studies. Moreover, classical algorithms are more transparent, which facilitates their adoption in real-life applications. Both approaches have their advantages, and the choice depends on the application. In certain predictable scenarios, such as surveillance camera systems, ample training data would undoubtedly allow a DL algorithm to be fine-tuned to outperform conventional edge detectors in accuracy and efficiency. In medical imaging, since it is highly unlikely that unknown tissues of the body will be imaged, DL could readily outperform conventional techniques.

References

1. Canny, J.: A computational approach to edge detection. In: Readings in Computer Vision, pp. 184–203. Elsevier (1987)
2. Marr, D., Hildreth, E.: Theory of edge detection. Proc. R. Soc. Lond. B 207(1167), 187–217 (1980)
3. Gao, W., Zhang, X., Yang, L., Liu, H.: An improved Sobel edge detection. In: IEEE International Conference on Computer Science and Information Technology, vol. 5, pp. 67–71. IEEE (2010)
4. Kovalevsky, V.: Image Processing with Cellular Topology, pp. 113–138 (2021). ISBN: 978-981-16-5771-9
5. Lee, J.-S.: Digital image smoothing and the sigma filter. Comput. Vis. Graph. Inf. Process. 24(2), 255–269 (1983)
6. Dumoulin, V., Visin, F.: A guide to convolution arithmetic for deep learning. arXiv preprint arXiv:1603.07285v2 (2018)
7. Su, Z., Liu, W., Yu, Z.: Pixel difference networks for efficient edge detection. arXiv preprint arXiv. 07009 (2021)
8. Sun, R., et al.: Survey of image edge detection. Front. Signal Process. Sec. Image Process. 2 (2022). https://doi.org/10.3389/frsip.2022.82696
9. O'Mahony, N., et al.: Deep learning vs. traditional computer vision (2019). https://doi.org/10.48550/arXiv.1910.13796
10. Wang, J., Ma, Y., Zhang, L., Gao, R.X.: Deep learning for smart manufacturing: methods and applications. J. Manuf. Syst. 48, 144–156 (2018). https://doi.org/10.1016/J.JMSY.2018.01.003
11. Xie, S., Tu, Z.: Holistically-nested edge detection. In: 2015 IEEE International Conference on Computer Vision (ICCV), pp. 1395–1403 (2015). https://doi.org/10.1109/ICCV.2015.164
12. Ofir, N., Keller, Y.: Multi-scale processing of noisy images using edge preservation losses. In: International Conference on Pattern Recognition, pp. 1–8. IEEE (2021)
13. Ofir, N., Galun, M., Nadler, B., Basri, R.: Fast detection of curved edges at low SNR. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 213–221 (2016)
14. Ofir, N., Nebel, J.C.: Classic versus deep learning approaches to address computer vision challenges (2021). https://doi.org/10.48550/arXiv.2101.09744
15. Li, X., Jiao, H., Wang, Y.: Edge detection algorithm of cancer image based on deep learning. Bioengineered 11(1), 693–707 (2020). https://doi.org/10.1080/21655979.2020.1778913. PMID: 32564648; PMCID: PMC8291821

SporeDet: A Real-Time Detection of Wheat Scab Spores

Jin Yuan¹, Zhangjin Huang¹(B), Dongyan Zhang², Xue Yang³, and Chunyan Gu³

¹ University of Science and Technology of China, Hefei 230027, China

² College of Mechanical and Electronic Engineering, Northwest A&F University, Yangling 712100, Shaanxi, China

³ Institute of Plant Protection and Agro-Products Safety, Anhui Academy of Agricultural Sciences, Hefei 230031, China

Abstract. Wheat scab is a destructive plant disease that has caused significant damage to wheat crops worldwide. The detection of wheat scab spores is essential to ensure the safety of wheat production. However, traditional detection methods require expert opinion in their detection processes, leading to lower efficiency and higher cost. In response to this problem, this paper proposes a spore detection method, SporeDet, based on a holistic architecture called ‘backbone-FPN-head’. Specifically, the method utilizes RepGhost with FPN to fuse feature information from the backbone while minimizing the model’s parameters and computation. Additionally, a task-decomposition channel attention head (TDAHead) is designed to predict the classification and localization of FPN features separately, thereby improving the accuracy of spore detection. Furthermore, a feature reconstruction loss (RecLoss) is introduced to further learn the features of RGB images during training, which accelerates the convergence of the model. The proposed method is evaluated on spore detection datasets collected from the Anhui Academy of Agricultural Sciences. Experimental results demonstrate that SporeDet achieves a mean average precision (mAP) of 88%, with a model inference time of 4.6 ms on a 24 GB RTX 3090 GPU. Therefore, the proposed method can effectively improve spore detection accuracy and provide a reference for detecting fungal spores.

Keywords: Wheat scab · Spore detection · Channel attention · Feature reconstruction loss

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 531–543, 2023. https://doi.org/10.1007/978-981-99-4742-3_44

1 Introduction

With the rapid development of deep learning networks in detection tasks, the performance of object detectors has significantly improved. As a result, these detectors have been widely applied in various agricultural scenarios with good results. However, the sampling and monitoring of airborne plant-pathogenic fungi, such as those responsible for wheat scab, still rely on spore trap technology [1]. This method requires taking spore-adherent slides or trap tapes back to a laboratory for identification
and detection under a microscope, which can be time-consuming and labor-intensive, particularly given the numerous, scattered observation points. These limitations make it challenging to track the real-time, dynamic changes of fungal spores in large-scale farmland. To address this challenge, it is necessary to explore methods that can automatically and intelligently detect wheat scab spores in real time. To the best of our knowledge, no deep learning method currently exists for this purpose. Implementing such a method would not only reduce workload and increase monitoring efficiency but also provide necessary data support for forecasting wheat disease. In recent years, spore detection methods for microscopic images have been primarily categorized into traditional [2, 3] and deep learning-based [4, 5] methods. Traditional methods for spore detection require significant time and effort, and the process is highly complicated. Typically, these methods involve extracting features from spore microscopic images and processing them to detect the object. However, traditional methods are sensitive to environmental factors, which results in poor detection outcomes. With the rapid development of computer vision technology, deep learning methods have emerged to address these challenges. Although several deep learning techniques for detecting powdery mildew spores and corn stamens are found in the literature, these methods have high model complexity, low inference speed, and slow convergence. Hence, we propose a real-time spore detection method (SporeDet) that aims to improve the accuracy of spore detection by utilizing a decoupled head and designing a feature reconstruction loss (RecLoss).
The main contributions of our paper are as follows: 1) we simplify the model by combining the RepGhost bottleneck (RG-bneck) with FPN to fuse features; 2) we design a task-decomposition channel attention head (TDAHead) that alleviates the conflict between the classification and regression tasks by letting each task focus only on its corresponding features; 3) we propose a feature reconstruction loss (RecLoss) that learns features from RGB images while boosting convergence speed. Furthermore, we conducted extensive experiments on spore detection datasets to evaluate the effectiveness of SporeDet. Our approach achieved 88% mAP, surpassing other detectors [6–10] by a large margin. We also compared our approach with the baseline method to verify the effectiveness of SporeDet.

2 Related Work

In this section, we review the research most relevant to our work. Specifically, we introduce methods for spore detection and object detection and highlight their respective advantages and disadvantages.

2.1 Spore Detection

Recent spore detection methods can be categorized into two distinct groups: traditional methods based on machine learning and methods based on deep learning. The traditional approach, involving image segmentation, feature extraction, and
classification recognition, has been utilized with some success. For instance, Li et al. [2] achieved automatic counting of the summer spores of wheat stripe rust through image scaling based on nearest neighbor interpolation, segmentation based on the K-means clustering algorithm, morphological processing, and watershed segmentation. Qi et al. [3], on the other hand, corrected for illumination using a block background extraction method and detected the edge features of spores using the Canny operator. They then applied morphological opening and closing operations and, finally, a distance transform and an improved watershed algorithm to detect adherent spores, achieving an accuracy of 98.5%. However, spore detection based on traditional methods involves a complex feature extraction process and requires human participation, which makes these algorithms very inefficient. In contrast, deep learning-based approaches, which primarily involve convolutional neural networks, have gained popularity. Zhang et al. [5] aimed to improve detection performance by constructing a detector called FSNet for tiny spores. This method uses the two-stage object detector Faster R-CNN to design a feature extraction framework and proposes a loss to reduce erroneous boxes. Liang et al. [4] performed image segmentation of wheat powdery mildew spores based on U-Net, focusing on the pyramid pooling module and the skip connection, thus achieving automatic detection of wheat powdery mildew spores with a segmentation rate of 91.4%. However, the method based on Faster R-CNN has low inference speed, while the method based on U-Net must recognize the objects in an image and segment each of them at the pixel level, which is costly. Therefore, our proposed method is a one-stage detection method based on bounding box regression that balances speed and accuracy, providing an effective solution to spore detection.
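
The traditional pipelines above (for example, K-means segmentation followed by counting) can be illustrated with a much simplified sketch: a two-cluster K-means on grey levels separates spores from the background, and 4-connected components are counted as spores. This is not the method of Li et al. [2]; it omits the interpolation, morphological processing and watershed steps that separate touching spores, and it assumes spores appear darker than the background. All function names are hypothetical.

```python
import numpy as np
from collections import deque

def kmeans_1d(values, k=2, iters=20):
    """Plain Lloyd iterations on scalar grey levels; centres start at evenly spaced quantiles."""
    centres = np.quantile(values, np.linspace(0, 1, k))
    for _ in range(iters):
        labels = np.argmin(np.abs(values[:, None] - centres[None, :]), axis=1)
        for j in range(k):
            if np.any(labels == j):
                centres[j] = values[labels == j].mean()
    return labels, centres

def count_blobs(mask):
    """Count 4-connected foreground components with a breadth-first search."""
    h, w = mask.shape
    seen = np.zeros(mask.shape, dtype=bool)
    count = 0
    for y in range(h):
        for x in range(w):
            if mask[y, x] and not seen[y, x]:
                count += 1
                seen[y, x] = True
                queue = deque([(y, x)])
                while queue:
                    cy, cx = queue.popleft()
                    for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                        if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and not seen[ny, nx]:
                            seen[ny, nx] = True
                            queue.append((ny, nx))
    return count

def count_spores(gray):
    """Segment spores with 2-means on grey levels, then count connected blobs."""
    labels, centres = kmeans_1d(np.asarray(gray, dtype=float).ravel(), k=2)
    spore_cluster = int(np.argmin(centres))  # assumption: spores are darker than background
    mask = (labels == spore_cluster).reshape(np.shape(gray))
    return count_blobs(mask)
```

Touching spores would be merged into one blob by this sketch, which is exactly the case the distance-transform and watershed steps in [2, 3] are meant to handle.
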

2.2 Object Detection

Object detection is one of the most active topics in computer vision and focuses on designing efficient and accurate detection frameworks. Deep learning-based object detection can be broadly classified into two categories: two-stage algorithms based on region proposals, such as R-CNN [11], Fast R-CNN [12], Faster R-CNN [8], and Mask R-CNN [13], and one-stage algorithms based on bounding box regression, such as the YOLO [14] and SSD [15] series and RetinaNet [7]. Although two-stage detectors achieve high accuracy in object localization and recognition, this typically comes at the cost of reduced speed. Conversely, one-stage detectors offer faster detection while still maintaining reasonable accuracy. Therefore, we have chosen the one-stage object detector YOLOv7 [9] as our baseline method. YOLOv7 uses the same ‘backbone-FPN-head’ architecture as previous YOLO models: the backbone network extracts feature maps at multiple scales, which are fused by the FPN before being processed by the head. YOLOv7 introduces an extended efficient layer aggregation network (E-ELAN) to enhance the learning capability of the network, but this can increase the computational cost on hardware devices. To address this concern, we have incorporated the hardware-efficient RepGhost module proposed by Chen et al. [16], which employs a structural re-parameterization technique to promote feature reusability. Because object detection aims to recognize and localize objects of interest in images, there is often a conflict between the classification and regression tasks. YOLOX [6] proposes a novel decoupled head to provide more accurate predictions for classification
and localization, and we have designed a task-decomposition channel attention head to further promote task-specific decoupling. By combining these techniques, we aim to develop an efficient and accurate spore detection framework. In this paper, we are committed to exploring spore detection methods. Compared with previous works, this paper strikes a balance between speed and accuracy by designing the RepGhostNeck, the TDAHead, and the RecLoss.
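Structural re-parameterization, mentioned above for RepGhost, trains a block with several branches and then folds them into a single convolution for inference. The sketch below shows the idea on a simplified block in the spirit of RepGhost (an assumption, not the RepGhost code): a 1 × 1 convolution followed by BatchNorm, plus an identity branch with its own BatchNorm, written as matrix operations on a (channels × pixels) array. Function names are hypothetical.

```python
import numpy as np

def bn_scale_shift(gamma, beta, mean, var, eps=1e-5):
    """Inference-time BatchNorm as a per-channel affine map: BN(z) = scale * z + shift."""
    scale = gamma / np.sqrt(var + eps)
    return scale, beta - mean * scale

def two_branch_block(x, w, bn_conv, bn_id):
    """Training-time form BN(W x) + BN(x); x is (C, N), W is a 1x1 conv as a (C, C) matrix."""
    s1, t1 = bn_scale_shift(*bn_conv)
    s2, t2 = bn_scale_shift(*bn_id)
    conv_branch = s1[:, None] * (w @ x) + t1[:, None]
    identity_branch = s2[:, None] * x + t2[:, None]
    return conv_branch + identity_branch

def reparameterize(w, bn_conv, bn_id):
    """Fold both branches into one 1x1 conv (weight, bias) used at inference time."""
    s1, t1 = bn_scale_shift(*bn_conv)
    s2, t2 = bn_scale_shift(*bn_id)
    w_fused = s1[:, None] * w + np.diag(s2)   # diag(s1) W + diag(s2) I
    b_fused = t1 + t2
    return w_fused, b_fused
```

The fused weight and bias reproduce the two-branch output exactly (up to floating-point error), so the extra branch costs nothing at inference time; this is why such blocks are described as hardware-efficient.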

Fig. 1. Proposed network architecture for SporeDet. 1) Backbone to extract multi-scale feature maps; 2) RepGhostNeck to fuse different scale features; 3) TDAHead to improve the accuracy; 4) RecLoss to accelerate the model convergence speed. The dashed line indicates the computation flow of the feature reconstruction loss. © indicates concatenation operation.

3 Proposed Method

In this section, we present the architecture of our proposed network along with the details of its constituent modules. Our proposed SporeDet model has an overall pipeline similar to that of the current one-stage detector YOLOv7 [9], using the ‘backbone-FPN-head’ structure. The framework of our SporeDet model is illustrated in Fig. 1 and consists of the following parts: 1) the Backbone to extract multi-scale feature maps, 2) the RepGhostNeck to fuse features of different scales, 3) the proposed TDAHead to improve the accuracy, and 4) the RecLoss to accelerate model convergence.

3.1 The Base Framework

In this subsection, we describe the overall framework of the proposed method. First, we input a microscopic image of scab spores into the backbone network to extract their features at different scales. These feature maps are then used by the next feature fusion module, RepGhostNeck. To simplify the model, we replace the E-ELAN module, which stacks many convolutional layers, with the RepGhost bottleneck. In addition, the re-parameterization technique of RepGhost makes features reusable and helps maximize the accuracy of the model. TDAHead is designed to alleviate the conflict between the classification and regression tasks. Furthermore, the classification and localization tasks of TDAHead focus on different types of features to improve detection accuracy. We design a channel attention mechanism that computes task-specific features to promote task decomposition. At the same time, we introduce the RecLoss during training to accelerate the convergence of the model and improve the detection accuracy of spores. Our main contributions lie in the proposed TDAHead and RecLoss, which are introduced in detail in Sect. 3.2 and Sect. 3.3.

3.2 Task-Decomposition Channel Attention Head

High-level feature maps have a larger receptive field, allowing them to extract more global information that aids in identifying object classes. However, because each of their positions corresponds to a larger area of the original image, they are unsuitable for locating objects accurately. In contrast, low-level feature maps have a smaller receptive field and correspond to a smaller area of the original image, making them more effective for object localization. However, such feature maps contain only local information about the object, so they cannot effectively identify its class.
This conflict between classification and regression affects both the accuracy and the efficiency of the model, which is why decoupled heads are widely used in one-stage and two-stage object detectors. Inspired by YOLOX [6], we design TDAHead, shown in Fig. 2.

Fig. 2. Architecture of the TDAHead.

During feature fusion, the input image is downsampled by factors of 8, 16, and 32 to generate three feature maps of size 80 × 80, 40 × 40, and 20 × 20 (for a 640 × 640 input), which focus on small, medium, and large objects, respectively:

$$X_i^{fpn}, \quad \forall i \in \{1, 2, 3\}, \tag{1}$$


J. Yuan et al.

where $X_i^{fpn}$ denotes the obtained feature maps. TDAHead uses three parallel branches: a classification branch, a confidence branch, and a regression branch. The classification branch determines the object category; its output has shape $[B, H, W, C]$, where $B$, $H$, $W$, and $C$ are the batch size, the height of the feature map, the width of the feature map, and the number of categories, respectively. The confidence branch determines whether an object is present; its output has shape $[B, H, W, 1]$, where the single channel holds the object probability. The regression branch predicts the box regression parameters; its output has shape $[B, H, W, 4]$, where the 4 channels are the box coordinates $(x, y, w, h)$. Finally, the outputs of the three branches are concatenated to obtain the final result. Motivated by TOOD [17], we design an intra-layer attention mechanism in the decoupled head to emphasize task-specific features, as shown in Fig. 3.

Fig. 3. Architecture of channel attention mechanism.
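The shape bookkeeping of the decoupled head can be checked with a few lines of plain Python. This is a minimal sketch, assuming a 640 × 640 input (the test resolution used in Sect. 4), strides of 8, 16, and 32, and C = 7 spore classes; `head_output_shapes` is a hypothetical helper, not part of the authors' code.

```python
def head_output_shapes(img_size=640, strides=(8, 16, 32), batch=1, num_classes=7):
    """Shapes of the concatenated decoupled-head outputs at each FPN level."""
    shapes = []
    for s in strides:
        h = w = img_size // s              # 640 // 8 = 80, // 16 = 40, // 32 = 20
        cls = (batch, h, w, num_classes)   # classification branch
        obj = (batch, h, w, 1)             # confidence (objectness) branch
        reg = (batch, h, w, 4)             # box regression (x, y, w, h)
        # Concatenating along the last axis gives C + 1 + 4 channels per location.
        shapes.append((batch, h, w, cls[-1] + obj[-1] + reg[-1]))
    return shapes


levels = head_output_shapes()
for shape in levels:
    print(shape)
# (1, 80, 80, 12)
# (1, 40, 40, 12)
# (1, 20, 20, 12)
total = sum(b * h * w for b, h, w, _ in levels)
print(total)  # 8400 grid locations per image
```

With seven classes, each grid location carries 7 + 1 + 4 = 12 output channels, and the three levels together provide 6,400 + 1,600 + 400 = 8,400 locations.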

First, we add a learned bias term $A$ to the FPN features obtained from RepConv to enrich them; we call this the Shift Operator. Formally, let $X^{fpn} \in \mathbb{R}^{C \times H \times W}$ denote the FPN features, where $C$, $H$, and $W$ are the number of channels, the height, and the width, respectively. Then

$$\hat{X}_i^{fpn} = X_i^{fpn} + A, \quad \forall i \in \{1, 2, 3\}, \tag{2}$$

where $A \in \mathbb{R}^{C \times 1 \times 1}$ denotes the learned bias term and $C$ is the number of channels. Second, weights $w_i$ for the multi-scale features are computed to capture the dependencies between channels:

$$w_i = \sigma\left(\mathrm{conv}_2\left(\delta\left(\mathrm{conv}_1\left(x_i^{inter}\right)\right)\right)\right), \tag{3}$$

where $x_i^{inter}$ is obtained by applying adaptive average pooling to $\hat{X}_i^{fpn}$, $\mathrm{conv}_1$ and $\mathrm{conv}_2$ are 1 × 1 convolutions, $\delta$ denotes the ReLU activation function, and $\sigma$ denotes the sigmoid activation function. We then multiply $w_i$ by $X_i^{fpn}$ to compute the features $Y$ for the classification and regression tasks:

$$Y = w_i \times X_i^{fpn}, \tag{4}$$
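Eqs. (2)–(4) can be sketched in NumPy for a single FPN level. Because the pooled tensor is 1 × 1 spatially, each 1 × 1 convolution reduces to a matrix-vector product over channels. This is an illustrative sketch, not the authors' implementation: the reduction width `C_r`, the omission of convolution biases, and the random weights are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)


def channel_attention(x_fpn, A, W1, W2):
    """Eqs. (2)-(4) for one FPN level.

    x_fpn: (C, H, W) feature map; A: (C, 1, 1) learned bias (Shift Operator);
    W1: (C_r, C) and W2: (C, C_r) weights of the two 1x1 convolutions.
    """
    x_hat = x_fpn + A                          # Eq. (2): shift operator
    x_inter = x_hat.mean(axis=(1, 2))          # adaptive average pooling to 1x1 -> (C,)
    hidden = np.maximum(W1 @ x_inter, 0.0)     # conv1 (1x1) + ReLU
    w = 1.0 / (1.0 + np.exp(-(W2 @ hidden)))   # conv2 (1x1) + sigmoid, Eq. (3)
    return w[:, None, None] * x_fpn            # Eq. (4): reweight the channels of X^fpn


C, H, W, C_r = 8, 4, 4, 2
x = rng.standard_normal((C, H, W))
A = rng.standard_normal((C, 1, 1)) * 0.1
W1 = rng.standard_normal((C_r, C))
W2 = rng.standard_normal((C, C_r))
Y = channel_attention(x, A, W1, W2)
print(Y.shape)  # (8, 4, 4)
```

Each channel of $X_i^{fpn}$ is scaled by a single weight in (0, 1), so the attention only re-balances channels and leaves the spatial layout untouched.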


The classification or localization results are predicted from $Y$:

$$R = \delta\left(\mathrm{conv}_3(Y)\right), \tag{5}$$

where $\mathrm{conv}_3$ is a 1 × 1 convolution and $\delta$ denotes the ReLU activation function. Finally, the results are scaled by a learnable parameter $M$ to obtain a better bounding-box prediction:

$$\hat{R} = R \times M, \tag{6}$$

where $M \in \mathbb{R}^{C \times 1 \times 1}$ denotes a learnable scaling parameter.

3.3 Feature Reconstruction Loss

To further accelerate convergence and improve the accuracy of the spore detection model, we design a feature reconstruction loss. By reconstructing the input features, the loss encourages the network to learn a denser and more useful feature representation. The feature reconstruction loss is illustrated in Fig. 4. In Fig. 1, the dashed line indicates the computation flow of this loss; because it is computed only during training, it adds no computational overhead at inference time.

Fig. 4. Calculation of feature reconstruction loss.

After the feature fusion module, three feature maps of different scales are obtained. Because these maps differ in spatial size, an MSE loss cannot be computed between them directly. To minimize the distance between the feature maps and capture the global structure, we upsample the coarser (lower-resolution) features and combine them with the finer ones to obtain high-resolution features:

$$F_i = \mathrm{UpSample}\left(X_i^{fpn}\right) \oplus X_{i+1}^{fpn}, \quad \forall i \in \{1, 2\}, \tag{7}$$

where $X_i^{fpn}$ denotes the feature maps at different scales and $F_i$ denotes the feature map obtained by combining two of them. Passing $F_i$ through a VGGBlock and repeating this process fuses all three scales into a single feature map downsampled by a factor of eight, which we call the reconstruction features:

$$U_i = \mathrm{VGGBlock}(F_i), \quad \forall i \in \{1, 2\}, \tag{8}$$
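The fusion in Eqs. (7) and (8) can be sketched as follows. The paper does not specify the ⊕ operator, the upsampling mode, or the VGGBlock internals, so this sketch assumes channel concatenation for ⊕, nearest-neighbour 2× upsampling, and a hypothetical `vgg_block_standin` (a single random 1 × 1 convolution with ReLU) in place of VGGBlock. It also assumes the three FPN maps are ordered from coarsest to finest, so that each upsampled map matches the size of the next one.

```python
import numpy as np

rng = np.random.default_rng(0)


def upsample2x(x):
    # Nearest-neighbour 2x upsampling of a (C, H, W) map.
    return x.repeat(2, axis=1).repeat(2, axis=2)


def vgg_block_standin(x, c_out):
    # Hypothetical stand-in for VGGBlock: a random 1x1 conv followed by ReLU.
    w = rng.standard_normal((c_out, x.shape[0])) / np.sqrt(x.shape[0])
    return np.maximum(np.einsum("oc,chw->ohw", w, x), 0.0)


# FPN outputs for a 640x640 input, ordered from coarsest to finest (assumption).
x_fpn = [rng.standard_normal((64, 20, 20)),   # stride 32
         rng.standard_normal((64, 40, 40)),   # stride 16
         rng.standard_normal((64, 80, 80))]   # stride 8

u = x_fpn[0]
for nxt in x_fpn[1:]:
    f = np.concatenate([upsample2x(u), nxt], axis=0)  # Eq. (7), with ⊕ taken as concatenation
    u = vgg_block_standin(f, 64)                      # Eq. (8)
print(u.shape)  # (64, 80, 80): the 1/8-resolution reconstruction features U_2
```

If ⊕ is instead element-wise addition, the concatenation line becomes `upsample2x(u) + nxt`, which requires both maps to have the same channel count.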

To minimize the gap between the original input image and the learned features, we downsample the input image $img$ to obtain features (Down features, $D$) at the same scale as the reconstruction features:

$$D = \mathrm{DownSample}(img), \tag{9}$$

where $D$ denotes the features obtained by downsampling the original input image. A 1 × 1 convolution converts the reconstruction features to three channels, and the similarity between the input image and the encoded features is measured by the mean squared error:

$$L_{fea} = \mathrm{MSE}\left(\mathrm{Conv}(U_2), D\right), \tag{10}$$

where $L_{fea}$ denotes the feature reconstruction loss.

3.4 Overall Optimization

The framework of SporeDet is shown in Fig. 1. The total loss is the weighted sum of the individual losses:

$$L = \alpha_1 L_{cls} + \alpha_2 L_{reg} + \alpha_3 L_{obj} + \alpha_4 L_{fea}, \tag{11}$$

where the classification loss $L_{cls}$, regression loss $L_{reg}$, confidence loss $L_{obj}$, and feature reconstruction loss $L_{fea}$ are balanced by the weights $\alpha_1$, $\alpha_2$, $\alpha_3$, and $\alpha_4$, respectively. When $\alpha_4 = 0$, Eq. (11) reduces to the original loss of the baseline model.
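The remaining steps, Eqs. (9)–(11), can be sketched in NumPy as below. The downsampling operator is not specified in the paper; average pooling by a factor of 8 is an assumption, as are the random 1 × 1 convolution weights and the placeholder values of the detection losses. The α values are those reported in Sect. 4.2.

```python
import numpy as np

rng = np.random.default_rng(0)


def downsample(img, factor=8):
    # Average-pool a (3, H, W) image by `factor` (assumed DownSample operator), Eq. (9).
    c, h, w = img.shape
    return img.reshape(c, h // factor, factor, w // factor, factor).mean(axis=(2, 4))


def rec_loss(u2, img, w_out):
    d = downsample(img)                          # D
    pred = np.einsum("oc,chw->ohw", w_out, u2)   # 1x1 conv to 3 channels
    return np.mean((pred - d) ** 2)              # Eq. (10): MSE


def total_loss(l_cls, l_reg, l_obj, l_fea, alphas=(0.3, 0.05, 0.7, 0.00005)):
    # Eq. (11) with the weights reported in Sect. 4.2.
    a1, a2, a3, a4 = alphas
    return a1 * l_cls + a2 * l_reg + a3 * l_obj + a4 * l_fea


img = rng.random((3, 640, 640))           # placeholder input image
u2 = rng.standard_normal((64, 80, 80))    # placeholder reconstruction features U_2
w_out = rng.standard_normal((3, 64)) * 0.01
l_fea = rec_loss(u2, img, w_out)
print(total_loss(1.0, 2.0, 0.5, l_fea))
```

Setting the last weight to zero drops the reconstruction term entirely, which is the baseline case noted after Eq. (11).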

4 Experiments

4.1 Dataset and Evaluation Metrics

All experiments are conducted on a spore detection dataset collected from the Anhui Academy of Agricultural Sciences in 2022. The dataset contains 3,000 microscopic images of spores with bounding-box annotations for 7 classes, split into 2,400 training images and 600 validation images. The classes include the primary fungi causing wheat scab, Fusarium graminearum and Fusarium moniliforme, mixed with five fungal spores common in the field: Fusarium pseudograminearum, Colletotrichum glycines, Colletotrichum gloeosporioides, Pestalotiopsis clavispora, and pear anthracnose spores. The dataset thus simulates the natural field environment, where various fungal species coexist. For quantitative evaluation, mean average precision (mAP) [18] is used to compare accuracy, and the number of floating-point operations (FLOPs) and frames per second (FPS) are used to compare speed.

4.2 Implementation Details

We use the PyTorch [19] framework for all experiments and choose the one-stage object detector YOLOv7 [9] as our baseline. All networks are trained for 300 epochs with the stochastic gradient descent (SGD) optimizer, with a weight decay of 0.0005 and a momentum of 0.937. The initial learning rate is 0.01 with a batch size of 16 and decays according to a cosine schedule. Following the YOLO series [6, 9], we use an exponential moving average (EMA) of the weights and grouped weight decay to improve robustness. For the hyper-parameters of SporeDet, we empirically set $\alpha_1 = 0.3$, $\alpha_2 = 0.05$, $\alpha_3 = 0.7$, and $\alpha_4 = 0.00005$ in the overall loss $L$. As is common practice, Mosaic and Mixup augmentation are used to increase data diversity.

4.3 Comparisons with Existing Methods

We compare SporeDet with previous methods, namely Faster R-CNN [8], RetinaNet [7], CenterNet [10], YOLOX [6], and YOLOv7 [9], on the spore detection dataset. These methods are provided by mmdetection [20], and we fine-tune them from pre-trained models to obtain better results. As Table 1 shows, the proposed method reaches 88% mAP, 3 points higher than the strongest competitor, Faster R-CNN (85%).
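The learning-rate schedule of Sect. 4.2 can be sketched as a cosine decay over 300 epochs starting from 0.01. The paper gives only the initial rate and the schedule shape; the final learning-rate fraction `final_frac = 0.01` below is an assumption, and YOLO-series code bases may parameterize the curve differently.

```python
import math


def cosine_lr(epoch, epochs=300, lr0=0.01, final_frac=0.01):
    """Cosine decay from lr0 at epoch 0 to lr0 * final_frac at the last epoch.

    final_frac is an assumption; the paper only states that the rate starts
    at 0.01 and decays along a cosine schedule.
    """
    lr_min = lr0 * final_frac
    return lr_min + 0.5 * (lr0 - lr_min) * (1 + math.cos(math.pi * epoch / epochs))


for e in (0, 150, 300):
    print(e, round(cosine_lr(e), 6))
```

The rate falls slowly at first, fastest around the midpoint, and flattens out near the end of training, which is the usual motivation for choosing a cosine rather than a step schedule.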
4.4 Ablation Studies

We validate the method on the collected dataset with both qualitative and quantitative analyses, showing the effectiveness of each component step by step.

Qualitative Analysis. We compare the detection results of the baseline and our method in Fig. 5. Fig. 5(a) shows the baseline results, where red boxes mark errors: the left and middle red boxes show impurities mistaken for spores, and the right red box marks a missed spore. In contrast, our method detects correctly at the corresponding left, middle, and right positions, as shown in Fig. 5(b).


Table 1. Comparison of the accuracy of different object detectors on the spore detection dataset. All models are tested at 640 × 640 resolution. AP50 and AP75 denote mAP at IoU thresholds of 0.50 and 0.75, respectively.

Model          mAP (↑)   AP50 (↑)   AP75 (↑)
Faster R-CNN   85%       94.7%      90.9%
RetinaNet      84.4%     93.9%      89.8%
CenterNet      77.7%     85.9%      82.7%
YOLOX          76.7%     89.3%      82.6%
YOLOv7         77.7%     90.1%      83.9%
Ours           88%       95.9%      92.3%
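AP50 and AP75 differ only in the IoU threshold a prediction must reach to count as a true positive (COCO-style mAP [18] averages AP over thresholds from 0.50 to 0.95 in steps of 0.05). The sketch below, which assumes boxes in (x1, y1, x2, y2) corner format rather than the (x, y, w, h) parameterization predicted by the head, shows how a 2-pixel offset on a 10 × 10 box passes one threshold but not the other.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2) corners."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


gt = (0, 0, 10, 10)
pred = (2, 0, 12, 10)   # the same box shifted 2 px to the right
v = iou(gt, pred)
print(round(v, 3))           # 0.667
print(v >= 0.50, v >= 0.75)  # True False: a hit for AP50, a miss for AP75
```

This is why AP75 is the stricter column in Tables 1 and 2: it rewards tighter localization.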

Fig. 5. The results of the baseline and our method.

Table 2. Ablation study of the proposed method. All models are tested at 640 × 640 resolution. AP50 and AP75 denote mAP at IoU thresholds of 0.50 and 0.75, respectively. FLOPs are computed for a 640 × 640 input.

     Model                          Params   FLOPs    mAP (↑)   AP50 (↑)   AP75 (↑)
(a)  baseline                       36.5M    103.3G   77.7%     90.1%      83.9%
(b)  baseline (RepGhostNeck)        30.2M    82.1G    81.6%     93.5%      87.9%
(c)  baseline + TDAHead             37.6M    104.7G   80.8%     92.3%      86.2%
(d)  baseline + RecLoss             36.5M    103.3G   82.4%     92.5%      87.9%
(e)  baseline + TDAHead + RecLoss   37.6M    104.7G   85%       94.5%      89.9%
(f)  Ours                           31.3M    83.6G    88%       95.9%      92.3%


Quantitative Analysis. To verify the effectiveness of SporeDet, we evaluate RepGhostNeck, TDAHead, and RecLoss step by step, as shown in Table 2.

Effect of RepGhostNeck. RepGhostNeck is designed to efficiently aggregate multi-scale information into a more robust object feature representation. The hardware-efficient RepGhost bottleneck reduces the computational cost of the model on hardware, and re-parameterization lets implicit features be reused. In Table 2, the baseline detector reaches 77.7% mAP, and RepGhostNeck raises this to 81.6%. It also reduces the number of parameters from 36.5M to 30.2M and the computation from 103.3G to 82.1G FLOPs, which improves real-time performance.

Effect of TDAHead. Adding a channel attention mechanism to each branch of TDAHead (classification, regression, and confidence) lets each branch emphasize the features relevant to its task. As shown in Table 2, TDAHead increases mAP from 77.7% to 80.8% with a slight increase in parameters and FLOPs.

Effect of RecLoss. We obtain two 1/8-resolution feature maps, one from the multi-scale features and the other from the original image, and minimize the gap between them with an MSE loss. As shown in Table 2, RecLoss raises mAP from 77.7% to 82.4%. Rows (c) and (e) of Table 2 have identical parameter counts (37.6M) and FLOPs (104.7G), as do rows (a) and (d) (36.5M and 103.3G), because RecLoss is computed only during training. RecLoss therefore improves detection accuracy at the cost of extra training computation, without increasing inference cost. Beyond these accuracy gains, the feature reconstruction loss also speeds up convergence, as depicted in Fig. 6.
We recorded the trend of the total loss $L$ on the spore detection dataset with and without the feature reconstruction loss $L_{fea}$. As Fig. 6 shows, adding $L_{fea}$ makes the total loss converge noticeably faster, which gives an intuitive view of the effectiveness of the proposed loss.

Fig. 6. Comparison of model convergence speed. The blue line indicates total loss curves for baseline. The orange line shows total loss curves for our method.

In addition, we measured model speed on a 24 GB RTX 3090 GPU. As shown in Table 3, our method reduces latency from 6.7 ms to 4.6 ms per image.


Table 3. Comparison of model speed. All models are tested on a 24 GB RTX 3090 GPU. Latency and FPS are measured without post-processing; FPS is the number of images processed per second.

     Model      Params   FLOPs    Latency   FPS
(a)  baseline   36.5M    103.3G   6.7 ms    149
(b)  Ours       31.3M    83.6G    4.6 ms    217
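The relative gains in Table 3 follow from simple arithmetic. The sketch assumes FPS was computed as 1000 / latency for one image at a time; the reported values (149 and 217) are consistent with that.

```python
baseline = {"params_m": 36.5, "gflops": 103.3, "latency_ms": 6.7}
ours = {"params_m": 31.3, "gflops": 83.6, "latency_ms": 4.6}


def fps(latency_ms):
    # Images per second when one image is processed every `latency_ms` milliseconds.
    return 1000.0 / latency_ms


print(int(fps(baseline["latency_ms"])), int(fps(ours["latency_ms"])))  # 149 217
speedup = baseline["latency_ms"] / ours["latency_ms"]
param_cut = 100 * (1 - ours["params_m"] / baseline["params_m"])
flop_cut = 100 * (1 - ours["gflops"] / baseline["gflops"])
print(round(speedup, 2))    # 1.46
print(round(param_cut, 1))  # 14.2
print(round(flop_cut, 1))   # 19.1
```

A 19.1% reduction in FLOPs yields a 1.46× speedup, so measured latency improves more than the FLOP count alone would suggest, which fits the hardware-efficiency motivation of RepGhost.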

5 Conclusions and Future Work

In this paper, we propose a real-time spore detection method based on the 'backbone-FPN-head' architecture. First, we combine RepGhost with the FPN to fuse features and lighten the model. Next, we design a task-decomposition channel attention head to improve spore detection accuracy. Finally, we present a feature reconstruction loss to speed up convergence. With these improvements, our method achieves higher accuracy than the compared detectors on our spore detection dataset while running faster than the YOLOv7 baseline. In future work, we plan to extend our approach to semi-supervised spore detection.

Acknowledgements. This work was supported in part by the National Natural Science Foundation of China (grant no. 42271364), the National Key R&D Program of China (Nos. 2022YFB3303402 and 2021YFF0500901), and the National Natural Science Foundation of China (Nos. 71991464/71991460, and 61877056).

References

1. Cao, X., Zhou, Y., Duan, X.: The application of volumetric spore trap in plant disease epidemiology. In: Proceedings of the Annual Meeting of Chinese Society for Plant Pathology (2008)
2. Li, X., Ma, Z., Sun, Z., Wang, H.: Automatic counting for trapped urediospores of Puccinia striiformis f. sp. tritici based on image processing. Trans. Chin. Soc. Agric. Eng. 29 (2013)
3. Qi, L., Jiang, Y., Li, Z., Ma, X., Zheng, Z., Wang, W.: Automatic detection and counting method for spores of rice blast based on micro image processing. Trans. Chin. Soc. Agric. Eng. 31 (2015)
4. Liang, X., Wang, B.: Wheat powdery mildew spore images segmentation based on U-Net. In: 2nd International Conference on Artificial Intelligence and Computer Science (2020)
5. Zhang, Y., Li, J., Tang, F., Zhang, H., Cui, Z., Zhou, H.: An automatic detector for fungal spores in microscopic images based on deep learning. Appl. Eng. Agric. 37 (2021)
6. Ge, Z., Liu, S., Wang, F., Li, Z., Sun, J.: YOLOX: Exceeding YOLO series in 2021. arXiv preprint arXiv:2107.08430 (2021)
7. Lin, T.-Y., Goyal, P., Girshick, R., He, K., Dollar, P.: Focal loss for dense object detection. IEEE Trans. Patt. Anal. Mach. Intell. 42 (2020)
8. Ren, S., He, K., Girshick, R., Sun, J.: Faster R-CNN: towards real-time object detection with region proposal networks. IEEE Trans. Patt. Anal. Mach. Intell. 39 (2017)
9. Wang, C.-Y., Bochkovskiy, A., Liao, H.-Y.M.: YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors. arXiv preprint arXiv:2207.02696 (2022)


10. Zhou, X., Wang, D., Krähenbühl, P.: Objects as points. arXiv preprint arXiv:1904.07850 (2019)
11. Girshick, R., Donahue, J., Darrell, T., Malik, J.: Rich feature hierarchies for accurate object detection and semantic segmentation. In: 2014 IEEE Conference on Computer Vision and Pattern Recognition (2014)
12. Girshick, R.: Fast R-CNN. In: 2015 IEEE International Conference on Computer Vision (2015)
13. He, K., Gkioxari, G., Dollar, P., Girshick, R.: Mask R-CNN. In: 2017 IEEE International Conference on Computer Vision (2017)
14. Redmon, J., Divvala, S., Girshick, R., Farhadi, A.: You only look once: unified, real-time object detection. In: 29th IEEE Conference on Computer Vision and Pattern Recognition (2016)
15. Liu, W., et al.: SSD: single shot MultiBox detector. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) Computer Vision – ECCV 2016. ECCV 2016. LNCS, vol. 9905. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46448-0_2
16. Chen, C., Guo, Z., Zeng, H., Xiong, P., Dong, J.: RepGhost: a hardware-efficient ghost module via re-parameterization. arXiv preprint arXiv:2211.06088 (2022)
17. Feng, C., Zhong, Y., Gao, Y., Scott, M.R., Huang, W.: TOOD: Task-aligned one-stage object detection. In: 2021 IEEE/CVF International Conference on Computer Vision (2021)
18. Lin, T.Y., et al.: Microsoft COCO: common objects in context. In: Fleet, D., Pajdla, T., Schiele, B., Tuytelaars, T. (eds.) Computer Vision – ECCV 2014. ECCV 2014. LNCS, vol. 8693. Springer, Cham (2014). https://doi.org/10.1007/978-3-319-10602-1_48
19. Paszke, A., et al.: Automatic differentiation in PyTorch. In: NIPS-W (2017)
20. Chen, K., et al.: MMDetection: open mmlab detection toolbox and benchmark. arXiv preprint arXiv:1906.07155 (2019)

Micro-expression Recognition Based on Dual-Branch Swin Transformer Network

Zhihua Xie(B) and Chuwei Zhao

Key Lab of Optic-Electronic and Communication, Jiangxi Science and Technology Normal University, 605 Fenglin Avenue, Nanchang 330031, China

Abstract. A micro-expression (ME) is a facial expression that flashes briefly and can reveal a person's true feelings and emotions. Compared with ordinary facial expressions, MEs are hard to identify because of their short duration and low intensity. This paper uses Swin Transformer as the backbone and a dual-branch framework to extract temporal and spatial features for micro-expression recognition (MER). The first branch preprocesses the ME sequences with an optical flow operator, and the resulting optical flow maps are fed into the first Swin Transformer to extract motion features. The second branch feeds the apex frame of each ME clip directly into the second Swin Transformer to learn spatial features. Finally, the features from the two branches are fused for the final MER task. Extensive experiments on three widely used public ME benchmarks show that the proposed method outperforms state-of-the-art MER approaches.

Keywords: Swin Transformer · Dual-branch · Micro-expression Recognition · Optical Flow · Spatial Features

1 Introduction

Facial expression is an important channel of information between people, and recognizing it helps us understand a person's mental state. Facial expressions are generally divided into macro-expressions and micro-expressions. Macro-expressions are the ordinary expressions people show in daily interactions; they involve several parts of the face at once and usually last between 1/2 and 4 s [1]. Because macro-expressions are easily disguised, they cannot reliably reveal people's hidden thoughts and inner feelings. In contrast, a micro-expression (ME) usually appears in tense or high-stakes situations. It is a spontaneous expression that cannot be suppressed or hidden, and it appears only in local areas of the face, with subtle intensity and short duration. ME recognition can therefore be applied in fields such as criminal interrogation, psychoanalysis, clinical diagnosis, and public safety. In recent years, with the advance of micro-expression research and the rapid development of computer vision, research on automatic micro-expression recognition has grown steadily; it mainly comprises traditional methods and deep learning methods.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 544–554, 2023.
https://doi.org/10.1007/978-981-99-4742-3_45

Micro-expression Recognition

Traditional methods mainly use hand-crafted features, such as Local Binary Pattern-Three Orthogonal Planes (LBP-TOP) [2], 3D Histograms of Oriented Gradients (3DHOG) [3], Histograms of Image Gradient Orientation (HIGO) [4], Histograms of Oriented Optical Flow (HOOF) [5], and their variants. With advances in machine learning, artificial intelligence, and powerful hardware such as graphics processing units (GPUs), many deep learning models, such as VGG [6], AlexNet [7], GoogLeNet [8], and ResNet [9], have made great progress in micro-expression recognition (MER).

Recently, Swin Transformer [10] has attracted wide attention in the computer vision community for its strong ability to model long-range dependencies. By dividing an image into small patches, a two-dimensional image is converted into a one-dimensional sequence, so spatial relations become relationships between sequence elements. Swin Transformer computes self-attention within local windows, which greatly reduces computation while extracting image information at different scales. To connect different windows, Swin Transformer introduces a shifted window partitioning strategy. With these designs, Swin Transformer can be applied to a wide range of vision problems, and it has outperformed CNNs on various vision tasks.

In this work, we are the first to apply the Swin Transformer network to MER for extracting diverse motion features. The main contributions of our model are as follows:

(1) We propose a Swin Transformer based framework with dual-branch fusion for MER. The model relies entirely on Swin Transformer and uses no CNN at any stage.
(2) In the first branch of the dual-branch framework, we use the two frames that differ most within a micro-expression, the onset frame and the apex frame, to compute an optical flow map, which is fed into a Swin Transformer to extract the motion information of the ME.

(3) In the second branch, we feed the apex frame into a Swin Transformer to extract the spatial information of the ME.

(4) The dual-branch structure fuses the temporal and spatial features from the two Swin Transformers along the channel dimension before classification. This enriches the learned ME representation and provides complementary cues for classification, yielding higher MER accuracy.

The proposed method outperforms related methods on three independent datasets (SMIC-HS [11], CASME II [12], and SAMM [13]), which suggests that the dual-branch model is effective and robust for ME representation.

2 Related Works

Because micro-expressions have low intensity, short duration, and strong locality, recognition mainly involves the following steps: database construction, ROI detection and pre-processing, robust spatial-temporal feature extraction, and recognition. To support research on micro-expression recognition, researchers have established databases such as SMIC-HS [11], CASME II [12], and SAMM [13]. Robust spatial-temporal feature representation is a crucial step in micro-expression recognition: it directly affects the performance of the subsequent classification, and it is both the focus and the main difficulty of current research. From the perspective of feature extraction, MER methods fall into two categories: (1) traditional hand-crafted feature extraction and (2) deep neural network-based methods.

Z. Xie and C. Zhao

Traditionally, hand-crafted features have been developed to describe facial expressions, such as LBP-TOP [2], 3DHOG [3], HIGO [4], HOOF [5], and their variants. Recently, deep learning has greatly advanced MER research. In [14], dilated convolution is used to enlarge the receptive field and extract the subtle features of micro-expressions. In [15], an attention module is combined with a residual network to prioritize important channels and spatio-temporal features. In [1], the horizontal optical flow sequence, the vertical optical flow sequence, and the grayscale image sequence are stacked as channels to form a new three-channel micro-expression sequence, which preserves network performance while keeping complexity low, greatly reducing the number of parameters and the training time. Regarding Transformers, Zhang et al. [16] applied the Transformer to micro-expression recognition and proposed MSAD to learn the relationships between facial parts.
In [17], a Transformer-based deep learning framework that uses no CNN at all learns the short-term and long-term relationships between pixels in the spatial and temporal directions of the sample videos. Zhao et al. [18] used two-branch networks to extract emotional features, in which the optical flow, the difference image, and the apex frame represent short-term relations, long-term relations, and spatial features, respectively, and the three features are integrated. Zhu et al. [19] proposed a sparse Transformer to extract spatial-temporal features from ME images. These approaches report significant improvements over CNN-based models, suggesting that the Transformer is an effective and promising tool for MER. The main motivation of this work is to improve the ability to represent diverse ME information using Swin Transformer as the backbone.

3 Proposed Method

Considering the characteristics of MEs and the limitations of existing MER methods, we propose a dual-branch deep learning network based on Swin Transformer. The detailed framework of the network is shown in Fig. 1. The network uses a dual-branch architecture to extract the temporal and spatial information of micro-expressions, respectively. In one branch, optical flow is used to compute the motion between the onset frame and the apex frame of the video sequence, inferring the direction and amplitude of motion from the changes in pixel appearance between frames. Since the Swin Transformer can capture global information, the second branch feeds the apex frame, which contains the most important information in the video sequence, into another Swin Transformer to capture the spatial features of the ME. The Swin Transformers in the two branches have the same architecture but do not share weights. The feature maps extracted from the two branches are concatenated along the channel dimension, a straightforward fusion of the spatial and temporal domains.
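The channel-wise fusion can be sketched in NumPy as follows. This assumes each branch's Swin Transformer output has been globally average-pooled to a vector, that a single linear layer with softmax performs the three-class classification, and that the feature width is 768 (the final-stage width of Swin-T); none of these details is stated in the paper, and the two backbones themselves are not modelled here.

```python
import numpy as np

rng = np.random.default_rng(0)


def fuse_and_classify(flow_feat, apex_feat, w_cls, b_cls):
    """Concatenate the two branch features along the channel axis, then classify.

    flow_feat: (B, C1) pooled features from the optical-flow branch
    apex_feat: (B, C2) pooled features from the apex-frame branch
    """
    fused = np.concatenate([flow_feat, apex_feat], axis=1)  # (B, C1 + C2)
    logits = fused @ w_cls + b_cls                          # linear classifier
    logits -= logits.max(axis=1, keepdims=True)             # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)         # softmax over classes


B, C = 2, 768   # 768 = final-stage width of Swin-T (assumed variant)
classes = ["negative", "positive", "surprise"]
flow_feat = rng.standard_normal((B, C))
apex_feat = rng.standard_normal((B, C))
w_cls = rng.standard_normal((2 * C, len(classes))) * 0.01
b_cls = np.zeros(len(classes))
probs = fuse_and_classify(flow_feat, apex_feat, w_cls, b_cls)
print(probs.shape)  # (2, 3)
```

Because the backbones do not share weights, the classifier sees two independent views of the same clip; concatenation keeps both intact and lets the linear layer learn how to weight them.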

Fig. 1. The framework of the proposed dual-branch deep learning network

3.1 Motion Feature Extraction Using Optical Flow

Optical flow can extract representative motion features between adjacent frames of a micro-expression. Compared with raw pixel data, it offers a higher signal-to-noise ratio and provides rich, important features across adjacent frames. It can also reduce the domain differences between datasets, which helps improve the extraction of ME motion information across databases. Optical flow extraction is typically based on the brightness constancy assumption, which is used to estimate moving objects in video and extract motion features between adjacent frames. Suppose the pixel intensity at $(x, y)$ in frame $t$ of a micro-expression sequence is $I(x, y, t)$, and after a time $\Delta t$ the pixel moves by $(\Delta x, \Delta y)$ to frame $t + 1$, where its intensity is $I(x + \Delta x, y + \Delta y, t + \Delta t)$. By brightness constancy:

$$I(x, y, t) = I(x + \Delta x, y + \Delta y, t + \Delta t) \tag{1}$$

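A Taylor series expansion of the right-hand side of Eq. (1) gives:

$$I(x + \Delta x, y + \Delta y, t + \Delta t) = I(x, y, t) + \frac{\partial I}{\partial x}\Delta x + \frac{\partial I}{\partial y}\Delta y + \frac{\partial I}{\partial t}\Delta t + \varepsilon \tag{2}$$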


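where $\varepsilon$ denotes the higher-order (second order and above) terms, which can be neglected. Substituting Eq. (2) into Eq. (1) and dividing by $\Delta t$ gives:

$$\frac{\partial I}{\partial x}\frac{\Delta x}{\Delta t} + \frac{\partial I}{\partial y}\frac{\Delta y}{\Delta t} + \frac{\partial I}{\partial t} = 0 \tag{3}$$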

Let $p$ and $q$ be the velocity components of a pixel along the horizontal and vertical directions, respectively:

$$p = \frac{\Delta x}{\Delta t}, \quad q = \frac{\Delta y}{\Delta t} \tag{4}$$

Let $I_x = \frac{\partial I}{\partial x}$, $I_y = \frac{\partial I}{\partial y}$, and $I_t = \frac{\partial I}{\partial t}$ denote the partial derivatives of the pixel intensity at $(x, y, t)$ along each direction. Eq. (3) can then be written as:

$$I_x p + I_y q + I_t = 0 \tag{5}$$
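As a minimal illustration of one classical way to solve Eq. (5), the Lucas–Kanade method assumes that (p, q) is constant within a small window and solves the stacked per-pixel equations by least squares. The paper does not state which optical flow algorithm it uses, so this is not necessarily the authors' choice; `solve_flow` is a hypothetical helper, and the gradients below are synthetic so that the true flow is known.

```python
import numpy as np


def solve_flow(ix, iy, it):
    """Least-squares (p, q) over a window, assuming the flow is constant there.

    Each pixel contributes one instance of Eq. (5): ix*p + iy*q = -it.
    """
    A = np.stack([ix.ravel(), iy.ravel()], axis=1)  # (N, 2)
    b = -it.ravel()                                 # (N,)
    (p, q), *_ = np.linalg.lstsq(A, b, rcond=None)
    return p, q


rng = np.random.default_rng(0)
ix = rng.standard_normal((5, 5))
iy = rng.standard_normal((5, 5))
p_true, q_true = 0.8, -0.3
it = -(ix * p_true + iy * q_true)  # synthetic temporal derivative consistent with Eq. (5)
p, q = solve_flow(ix, iy, it)
print(round(p, 3), round(q, 3))  # 0.8 -0.3
```

In practice $I_x$, $I_y$, and $I_t$ come from finite differences between the onset and apex frames, and the least-squares system becomes ill-conditioned in textureless regions where the gradients are nearly zero or all point the same way.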

Here, $I_x$, $I_y$, and $I_t$ can be computed from the image data, and $(p, q)$ is the micro-expression optical flow vector, which encodes the magnitude and direction of each pixel's motion. Because Eq. (5) provides only one equation per pixel for the two unknowns $p$ and $q$, optical flow methods add further constraints, such as local or global smoothness of the flow field, to obtain a unique solution.

3.2 Spatial Feature Representation Based on Swin Transformer

In recent years, CNNs have become the backbone of most deep learning algorithms in computer vision. However, convolutions operate on fixed-size local windows, so they tend to neglect long-range relationships between pixels. The Transformer, originally introduced for NLP, relies on a self-attention mechanism to capture long-range dependencies between sequence elements. For this reason, Transformer-based architectures have received increasing attention from the computer vision community and are beginning to play a vital role in many vision tasks. To bring these advantages to MER, we adopt the general-purpose Transformer backbone Swin Transformer [10]. Compared with the standard Transformer architecture, it builds hierarchical feature maps and shifts the window partition between successive self-attention layers, which reduces computational complexity and enhances modeling ability.

3.2.1 Swin Transformer Blocks

Figure 2 shows the framework of two successive Swin Transformer blocks. Swin Transformer replaces the standard multi-head self-attention (MSA) module of the Transformer with a module based on (shifted) windows. Each block consists of a window-based MSA module followed by a 2-layer MLP (multilayer perceptron). A LayerNorm (LN) layer is applied before each MSA module and each MLP, and a residual connection is applied after each module. The computation is given by Eqs. (6)–(9):

$$\hat{z}^{l} = \text{W-MSA}\left(\mathrm{LN}\left(z^{l-1}\right)\right) + z^{l-1} \tag{6}$$

$$z^{l} = \mathrm{MLP}\left(\mathrm{LN}\left(\hat{z}^{l}\right)\right) + \hat{z}^{l} \tag{7}$$


Fig. 2. Detailed structure of two successive Swin Transformer blocks

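$$\hat{z}^{l+1} = \text{SW-MSA}\left(\mathrm{LN}\left(z^{l}\right)\right) + z^{l} \tag{8}$$

$$z^{l+1} = \mathrm{MLP}\left(\mathrm{LN}\left(\hat{z}^{l+1}\right)\right) + \hat{z}^{l+1} \tag{9}$$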

where $\hat{z}^{l}$ and $z^{l}$ denote the outputs of the (S)W-MSA module and the MLP module of block $l$, respectively, and W-MSA and SW-MSA denote window-based multi-head self-attention with regular and shifted window partitions, respectively.

3.2.2 Self-attention Based on Shifted Windows

Shifting the window partition produces more windows, some of them smaller than the regular windows. To keep the window size consistent after shifting, Swin Transformer uses a cyclic shift toward the top-left (Fig. 3). After this shift, a batched window may be composed of several sub-windows that are not adjacent in the original feature map, so a masking mechanism limits self-attention computation to within each sub-window. The number of batched windows thus remains the same as in the regular window partition, which keeps the computation efficient.

Fig. 3. Shifted window partitioning
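The cyclic shift and masking described in Sect. 3.2.2 can be illustrated with a small pure-Python sketch. It assumes a single-channel H × W map stored as nested lists, H and W divisible by the window size M, and a shift 0 < s < M (typically s = M // 2). The region-labelling idea is in the spirit of public Swin Transformer implementations, but the helper names and the bias value of -100 are illustrative assumptions, not details taken from this paper.

```python
from typing import List

Grid = List[List[int]]


def cyclic_shift(x: Grid, s: int) -> Grid:
    """Roll the map by -s rows and -s columns (toward the top-left), wrapping around."""
    H, W = len(x), len(x[0])
    return [[x[(i + s) % H][(j + s) % W] for j in range(W)] for i in range(H)]


def window_partition(x: Grid, M: int) -> List[List[int]]:
    """Split an H x W map into non-overlapping M x M windows in row-major order;
    each window is returned as a flat list of M * M tokens."""
    H, W = len(x), len(x[0])
    return [
        [x[i][j] for i in range(r, r + M) for j in range(c, c + M)]
        for r in range(0, H, M)
        for c in range(0, W, M)
    ]


def region_labels(H: int, W: int, M: int, s: int) -> Grid:
    """Label each position of the shifted map with the region it came from.
    Within one window, tokens with different labels come from parts of the
    original map that were not adjacent (they were joined by the cyclic shift)."""
    labels = [[0] * W for _ in range(H)]
    rows = [(0, H - M), (H - M, H - s), (H - s, H)]
    cols = [(0, W - M), (W - M, W - s), (W - s, W)]
    region = 0
    for h0, h1 in rows:
        for w0, w1 in cols:
            for i in range(h0, h1):
                for j in range(w0, w1):
                    labels[i][j] = region
            region += 1
    return labels


def shifted_window_masks(H: int, W: int, M: int, s: int,
                         neg: float = -100.0) -> List[List[List[float]]]:
    """Additive attention mask per window: 0 inside a sub-window, `neg` across sub-windows."""
    return [
        [[0.0 if a == b else neg for b in win] for a in win]
        for win in window_partition(region_labels(H, W, M, s), M)
    ]


masks = shifted_window_masks(8, 8, M=4, s=2)
print(len(masks))  # 4: the same number of windows as the regular partition
```

The mask is added to the attention logits before the softmax, so a token effectively ignores tokens from other sub-windows while the number of windows, and hence the cost, stays unchanged.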

550

Z. Xie and C. Zhao

4 Experiments and Evaluation

In this section, we present experiments to evaluate the proposed MER model, describing the databases and then reporting and analyzing the experimental results.

4.1 Databases

In this paper, experiments are conducted on three publicly available spontaneous micro-expression datasets, namely SMIC [11], CASMEII [12], and SAMM [13], and on their fusion. The details of the datasets are listed in Table 1.

Table 1. The details of the micro-expression datasets

| Datasets | SMIC | CASMEII | SAMM | Fusion datasets |
|---|---|---|---|---|
| Release time | 2013 | 2014 | 2018 | 2019 |
| Participants | 16 | 24 | 28 | 68 |
| Sample size | 164 | 147 | 133 | 444 |
| Positive | 51 | 32 | 26 | 109 |
| Negative | 70 | 90 | 92 | 252 |
| Surprised | 43 | 25 | 15 | 83 |
| Frames per second | 100 | 200 | 200 | -- |

To avoid inconsistent categories when the three datasets are combined, each sample is relabeled as positive, negative, or surprised. After merging the datasets according to these common categories, the fused dataset contains 444 ME samples from 68 participants (16 from SMIC, 24 from CASMEII, and 28 from SAMM). The distribution of ME samples is therefore more diverse and closer to real-world conditions.

4.2 Experimental Results and Analysis

4.2.1 Comparisons with the Related Methods

Table 2 compares the micro-expression recognition (MER) method proposed in this paper with other recognition methods in terms of accuracy (Acc) and F1-score (F1):

$$\text{Acc} = \frac{\sum_{i=1}^{S} T_i}{\sum_{i=1}^{S} N_i} \times 100 \tag{10}$$

where $S$ is the number of subjects, $T_i$ is the number of correctly predicted samples of the $i$-th subject, and $N_i$ is the total number of test samples of the $i$-th subject.

$$\text{F1-score} = \frac{2 p_i \times r_i}{p_i + r_i} \tag{11}$$


where $p_i$ and $r_i$ denote the precision and recall of the $i$-th ME class, respectively. Table 2 shows that our model achieves the highest accuracy on CASME II and SAMM among the compared methods; on SMIC it outperforms most of the compared methods but trails Dynamic [21] and MFSTrans [18]. These results indicate that the proposed method has strong spatio-temporal feature extraction and generalization capabilities for ME.

Table 2. Comparison results with other methods

| Methods | CASME II Acc | CASME II F1 | SAMM Acc | SAMM F1 | SMIC Acc | SMIC F1 |
|---|---|---|---|---|---|---|
| AlexNet | 62.96 | 66.75 | 52.94 | 42.60 | 59.76 | 60.13 |
| DSSN [20] (2019) | 70.78 | 72.97 | 57.35 | 46.44 | 63.41 | 64.62 |
| Micro-attention [21] (2020) | 65.90 | 53.90 | 48.50 | 40.20 | 49.40 | 49.60 |
| Dynamic [21] (2020) | 72.61 | 67.00 | -- | -- | 76.06 | 71.00 |
| GEME [22] (2021) | 75.20 | 73.54 | 55.88 | 45.38 | 64.63 | -- |
| LFM [23] (2021) | 73.98 | 71.65 | -- | -- | 71.34 | 71.34 |
| Sparse Transformer [19] (2022) | 76.11 | 71.92 | 80.15 | 75.47 | -- | -- |
| MFSTrans [18] (2022) | 81.50 | 80.09 | -- | -- | 79.99 | 78.12 |
| LFVTrans [24] (2022) | 70.68 | 71.06 | -- | -- | 73.17 | 74.47 |
| Ours | 83.90 | 88.90 | 80.70 | 66.70 | 74.20 | 75.00 |
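Eqs. (10) and (11) can be computed as in the following sketch. It assumes that $T_i$ and $N_i$ are collected per test subject (as in a leave-one-subject-out protocol, which is an assumption here) and that F1 is first computed per class; averaging the per-class scores into one number (unweighted F1) is a common MER convention but is not stated above. The function names are illustrative.

```python
from typing import Dict, Hashable, Sequence


def subject_accuracy(correct: Sequence[int], total: Sequence[int]) -> float:
    """Eq. (10): pooled accuracy in percent over S subjects."""
    if len(correct) != len(total) or sum(total) == 0:
        raise ValueError("need matching, non-empty per-subject counts")
    return 100.0 * sum(correct) / sum(total)


def per_class_f1(y_true: Sequence[Hashable],
                 y_pred: Sequence[Hashable]) -> Dict[Hashable, float]:
    """Eq. (11) for every class i, from its precision p_i and recall r_i."""
    scores = {}
    for c in set(y_true) | set(y_pred):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        scores[c] = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    return scores


y_true = ["positive", "negative", "negative", "surprised"]
y_pred = ["positive", "negative", "surprised", "surprised"]
f1 = per_class_f1(y_true, y_pred)
unweighted_f1 = sum(f1.values()) / len(f1)  # mean over classes (assumed convention)
print(subject_accuracy([3, 4], [4, 5]), f1, unweighted_f1)
```

Pooling the counts before dividing, as Eq. (10) does, weights every test sample equally rather than every subject equally.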

4.2.2 The Influence of Optical Flow Operator

Optical flow estimates the relative motion of pixels between frames under the brightness constancy assumption. To analyze the impact of the optical flow map on our MER model, we compare two settings: the RGB stream (the apex frame) fed directly into Swin Transformer for classification, and the optical flow map (computed from the onset and apex frames) fed directly into Swin Transformer for classification. If the latter performs significantly better than the former, this confirms the value of optical flow processing. The comparison results are shown in Fig. 4: the optical flow feature improves MER accuracy by 11.50, 10.40, 5.80, and 8.90 percentage points on the full fusion dataset, CASMEII, SAMM, and SMIC, respectively. This suggests that optical flow features effectively capture ME motion information and thus improve micro-expression recognition.

4.2.3 The Validity of the Dual-Branch Framework

This paper adopts a dual-branch design that extracts the motion features and the spatial features of ME separately and fuses them along the channel dimension, so that the classifier receives more comprehensive information and can classify more accurately.
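A minimal sketch of channel-wise fusion of the two branches is given below. It assumes each branch outputs a feature map of shape C × H × W with the same spatial size, that fusion is concatenation along the channel axis, and that classification uses global average pooling followed by a linear layer. The text only states that the features are fused channel-wise, so concatenation, the pooling step, the toy weights, and the function names are all assumptions.

```python
from typing import List

FeatureMap = List[List[List[float]]]  # C x H x W


def fuse_channels(flow_feat: FeatureMap, rgb_feat: FeatureMap) -> FeatureMap:
    """Concatenate the optical-flow and RGB feature maps along the channel axis."""
    flow_hw = (len(flow_feat[0]), len(flow_feat[0][0]))
    rgb_hw = (len(rgb_feat[0]), len(rgb_feat[0][0]))
    if flow_hw != rgb_hw:
        raise ValueError("branch outputs must share the same spatial size")
    return flow_feat + rgb_feat


def global_avg_pool(fm: FeatureMap) -> List[float]:
    """Average every channel over its H x W positions."""
    return [sum(map(sum, ch)) / (len(ch) * len(ch[0])) for ch in fm]


def linear(x: List[float], weights: List[List[float]], bias: List[float]) -> List[float]:
    """Class logits: one weight row per class."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, bias)]


flow = [[[1.0, 1.0], [1.0, 1.0]] for _ in range(2)]  # 2 channels from the flow branch
rgb = [[[2.0, 2.0], [2.0, 2.0]] for _ in range(3)]   # 3 channels from the RGB branch
fused = fuse_channels(flow, rgb)
pooled = global_avg_pool(fused)
# Toy weights for the three classes: positive, negative, surprised.
logits = linear(pooled, [[1.0] * 5, [0.0] * 5, [-1.0] * 5], [0.0, 0.0, 0.0])
```

Concatenation keeps both descriptors intact and lets the classifier learn how to weight motion against appearance, which is one way to realize the complementarity examined in Fig. 5.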


Fig. 4. Ablation experiments on optical flow (bar chart of Accuracy (%) for the RGB branch and the optical flow branch on Fusion Datasets, CASME2, SAMM, and SMIC)
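The optical flow branch compared in Fig. 4 relies on the quantities $I_x$, $I_y$, $I_t$, and $(p, q)$ introduced before Sect. 3.2. As a hedged illustration, the sketch below solves the standard brightness-constancy constraint $I_x p + I_y q + I_t = 0$ by least squares over a small window, which is the classical Lucas–Kanade approach. This is not necessarily the optical flow operator used in this paper; the window radius, the central-difference gradients, and the function name are assumptions.

```python
from typing import List, Tuple

Image = List[List[float]]


def lucas_kanade(I1: Image, I2: Image, cy: int, cx: int, r: int = 2) -> Tuple[float, float]:
    """Estimate the flow (p, q) at pixel (cy, cx) from a (2r+1) x (2r+1) window.

    Minimizes sum (Ix p + Iy q + It)^2 via its 2 x 2 normal equations.
    The window must stay at least one pixel away from the image border.
    """
    sxx = sxy = syy = sxt = syt = 0.0
    for y in range(cy - r, cy + r + 1):
        for x in range(cx - r, cx + r + 1):
            ix = (I1[y][x + 1] - I1[y][x - 1]) / 2.0  # central difference in x
            iy = (I1[y + 1][x] - I1[y - 1][x]) / 2.0  # central difference in y
            it = I2[y][x] - I1[y][x]                  # temporal difference
            sxx += ix * ix
            sxy += ix * iy
            syy += iy * iy
            sxt += ix * it
            syt += iy * it
    det = sxx * syy - sxy * sxy
    if abs(det) < 1e-12:
        raise ValueError("degenerate gradients in this window (aperture problem)")
    p = (-syy * sxt + sxy * syt) / det
    q = (sxy * sxt - sxx * syt) / det
    return p, q


# Synthetic onset/apex pair: the pattern x * y moves one pixel to the right.
onset = [[float(x * y) for x in range(11)] for y in range(11)]
apex = [[float((x - 1) * y) for x in range(11)] for y in range(11)]
p, q = lucas_kanade(onset, apex, cy=5, cx=5)
```

Dense flow maps apply such an estimate at every pixel, often with coarse-to-fine refinement for larger motions; micro-expression motion between onset and apex is usually small, which suits this linearized constraint.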

Fig. 5. Ablation experiments on each branch (bar chart of Accuracy (%) for the RGB branch, the optical flow branch, and the dual-branch model on Fusion Datasets, CASME2, SAMM, and SMIC)

To validate the complementarity of the two branches for MER, each branch is used on its own to classify micro-expressions: the optical flow map produced by the optical flow operator is fed into Swin Transformer for direct classification, and the apex frames of the video sequences are likewise fed directly into Swin Transformer. Comparing these single-branch results with the results obtained after fusion (as shown in Fig. 5) shows that the dual-branch model outperforms either single branch on all four datasets. We attribute this mainly to the dual-branch features integrating both spatial and temporal information, which improves MER performance.


5 Conclusion

In this paper, a MER framework based on Swin Transformer is explored for learning diverse information. A dual-branch structure extracts the temporal and spatial features of ME, respectively, and feature fusion along the channel dimension enriches the ME feature description, giving the classifier a more informative basis. Extensive experiments on three public datasets show that the accuracy of the proposed method is greatly improved compared with the benchmark, which verifies the effectiveness and feasibility of the dual-branch MER framework based on Swin Transformer.

Acknowledgements. This paper is supported by the Natural Science Foundation of Jiangxi Province of China (No. 20224ACB202011), the National Natural Science Foundation of China (No. 61861020), and the Jiangxi Province Graduate Innovation Special Fund Project (No. YC2022-s790).

References
1. Hong, T., Longjiao, Z., Sen, F.: Micro-expression recognition based on optical flow method and pseudo-3D residual network. J. Signal Process. 38(5), 13–21 (2022)
2. Zhao, G., Pietikainen, M.: Dynamic texture recognition using local binary patterns with an application to facial expressions. IEEE Trans. Pattern Anal. Mach. Intell. 29(6), 915–928 (2007)
3. Polikovsky, S., Kameda, Y., Ohta, Y.: Facial micro-expressions recognition using high speed camera and 3D-gradient descriptor. In: International Conference on Crime Detection & Prevention. IET (2010)
4. Li, X., Hong, X., Moilanen, A., et al.: Towards reading hidden emotions: a comparative study of spontaneous micro-expression spotting and recognition methods. IEEE Trans. Affect. Comput. 9(4), 563–577 (2018)
5. Jin, Y., Kai, J., et al.: A main directional mean optical flow feature for spontaneous micro-expression recognition. IEEE Trans. Affect. Comput. 7(4), 299–310 (2016)
6. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556 (2014)
7. Krizhevsky, A., Sutskever, I., Hinton, G.: ImageNet classification with deep convolutional neural networks. Commun. ACM 60(6), 84–90 (2017)
8. Szegedy, C., Liu, W., Jia, Y.: Going deeper with convolutions. In: IEEE Conference on Computer Vision and Pattern Recognition, pp. 1–9 (2015)
9. He, K., Zhang, X., Ren, S., et al.: Deep residual learning for image recognition. In: IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
10. Liu, Z., Lin, Y., Cao, Y., et al.: Swin transformer: Hierarchical vision transformer using shifted windows. In: IEEE/CVF International Conference on Computer Vision, pp. 10012–10022 (2021)
11. Li, X., Pfister, T., Huang, X., et al.: A spontaneous micro-expression database: inducement, collection and baseline. In: 2013 10th IEEE International Conference and Workshops on Automatic Face and Gesture Recognition (FG). IEEE, pp. 1–6 (2013)
12. WenJing, Y., Xiaobai, L., Su-Jing, W.: CASME II: an improved spontaneous micro-expression database and the baseline evaluation. PLoS ONE 9(1) (2014)


13. Davison, A.K., Lansley, C., Costen, N., Tan, K.: SAMM: a spontaneous micro-facial movement dataset. IEEE Trans. Affect. Comput. 9(99), 116–129 (2018)
14. Zhenyi, L., Renhe, C., Yurong, Q.: CNN real-time micro-expression recognition algorithm based on dilated convolution. Appl. Res. Comput. 37(12), 5–13 (2020)
15. Gajjala, V.R., et al.: MERANet: facial micro-expression recognition using 3D residual attention network. In: The Twelfth Indian Conference on Computer Vision, Graphics and Image Processing, pp. 1–10 (2021)
16. Xue, F., Wang, Q., Guo, G.: Transfer: Learning relation-aware facial expression representations with transformers. In: IEEE/CVF International Conference on Computer Vision, pp. 3601–3610 (2021)
17. Zhang, L., Hong, X., Arandjelović, O., et al.: Short and long range relation based spatiotemporal transformer for micro-expression recognition. IEEE Trans. Affect. Comput. 13(4), 1973–1985 (2022)
18. Zhao, X., Lv, Y., Huang, Z.: Multimodal fusion-based swin transformer for facial recognition micro-expression recognition. In: 2022 IEEE International Conference on Mechatronics and Automation (ICMA). IEEE, pp. 780–785 (2022)
19. Zhu, J., Zong, Y., Chang, H., et al.: A sparse-based transformer network with associated spatiotemporal feature for micro-expression recognition. IEEE Signal Process. Lett. 29, 2073–2077 (2022)
20. Khor, H.Q., See, J., Liong, S.T., et al.: Dual-stream shallow networks for facial micro-expression recognition. In: 2019 IEEE International Conference on Image Processing (ICIP). IEEE, pp. 36–40 (2019)
21. Chongyang, W., Min, P., Tao, B., Tong, C.: Micro-attention for micro-expression recognition. Neurocomputing 410, 354–362 (2020)
22. Bo, S., Siming, C., Dongliang, L., Jun, H., Lejun, Y.: Dynamic micro-expression recognition using knowledge distillation. IEEE Trans. Affect. Comput. 13, 1037–1043 (2020)
23. Nie, X., Takalkar, M.A., Duan, M., et al.: GEME: dual-stream multi-task GEnder-based micro-expression recognition. Neurocomputing 427, 13–28 (2021)
24. Liu, Z., Ning, J., Cao, Y., et al.: Video swin transformer. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3202–3211 (2022)
25. Hong, J., Lee, C., Jung, H.: Late fusion-based video transformer for facial micro-expression recognition. Appl. Sci. 12(3) (2022)

CC-DBNet: A Scene Text Detector Combining Collaborative Learning and Cascaded Feature Fusion

Wenheng Jiang1, Yuehui Chen2(✉), Yi Cao3, and Yaou Zhao2

1 School of Information Science and Engineering, University of Jinan, Jinan 250024, China
2 Artificial Intelligence Institute (School of Information Science & Engineering), University of Jinan, No. 336, Jinan, China
3 Shandong Provincial Key Laboratory of Network Based Intelligent Computing, School of Information Science & Engineering, University of Jinan, No. 336, Jinan, China

Abstract. In recent years, scene text detection technologies have received increasing attention and have made rapid progress. However, they still face challenges such as fracture detection in text instances and the poor robustness of detection models. To address these issues, we propose a scene text detector called CC-DBNet. This detector combines Intra-Instance Collaborative Learning (IICL) and a Cascaded Feature Fusion Module (CFFM) to detect arbitrary-shaped text instances. Specifically, we introduce dilated convolution blocks in IICL, which expand the receptive fields and improve the text feature representation ability. We replace the FPN in DBNet++ with a CFFM incorporating efficient channel attention (ECA) to better utilize features at various scales, thereby improving the detector's performance and robustness. Experimental results demonstrate the superiority of the proposed detector: CC-DBNet achieves F-measures of 88.1%, 86%, and 88.6% on three publicly available datasets, ICDAR2015, CTW1500, and MSRA-TD500, respectively, improvements of 0.8%, 0.7%, and 1.4% over the baseline DBNet++.

Keywords: Text detection · Cascaded feature fusion · Collaborative learning · Deep learning

1 Introduction The task of scene text detection involves the detection of text instances in images or frames extracted from videos captured in natural environments. Recent advances in hardware capable of processing large amounts of image data, combined with the development of deep learning, have significantly impacted research in this field [1–6]. Scene text detection techniques have become increasingly popular in recent years and are essential in various practical applications [4], which include intelligent transportation, unmanned vehicles, security surveillance, and text translation.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 555–566, 2023. https://doi.org/10.1007/978-981-99-4742-3_46

556

W. Jiang et al.

Scene text detection is a challenging task due to factors like complex backgrounds, diverse text appearances, variable text sizes, orientations, and language variations. Various methods have been proposed to address these challenges [5, 6]. These methods can be broadly categorized into regression-based and segmentation-based approaches [4]. Regression-based methods draw inspiration from object detection techniques. For instance, Ma et al. [1] proposed Rotation Region Proposal Networks (RRPN), which focus on the orientation of text instances. Liao et al. [6] introduced TextBoxes++ for accurately detecting text in any direction. He et al. [5] developed a Deep Direct Regression (DDR) method to detect multi-oriented text instances. While regression-based methods have achieved significant progress in detecting horizontal text, they have limitations in detecting curved text and text of arbitrary shapes. In contrast, segmentation-based methods have shown outstanding performance by leveraging pixel-level prediction results to detect text instances of various shapes [3]. For example, PixelLink [7] links pixels in text instances and separates them using instance segmentation. PSENet [8] generates kernels of different scales and gradually dilates minimal-scale kernels to capture complete text shapes. PAN [2] combines a computationally efficient segmentation head with Pixel Aggregation (PA) to enhance the speed of scene text detection. ContourNet [9] integrates adaptive-RPN and a contour prediction branch to improve accuracy. DBNet [4] introduces differentiable binarization (DB) that simplifies post-processing by optimizing it during training. Building upon DBNet, DBNet++ [3] incorporates an Adaptive Scale Fusion (ASF) module for effective feature fusion. Segmentation-based methods have the advantage of detecting text of arbitrary shapes. However, they face challenges such as fracture detection and poor robustness [7–10]. 
This paper presents CC-DBNet, a scene text detector built on the DBNet++ architecture. CC-DBNet integrates improved Intra-Instance Collaborative Learning (IICL) modules and an improved Cascaded Feature Fusion Module (CFFM) to address both the negative impact of fracture detection on detection accuracy and the limited robustness of the text detector. DB is also incorporated to simplify post-processing. Experiments on three publicly available datasets show that our method outperforms existing methods. The contributions of this paper are summarized as follows:
1. The proposed CC-DBNet is a scene text detector that can detect arbitrary-shaped text instances and achieves competitive results on three publicly available datasets.
2. An improved Intra-Instance Collaborative Learning (IICL) module is introduced in feature extraction to strengthen a uniform feature representation of the gap regions between characters within text instances, which effectively alleviates fracture detection.
3. To address the poor robustness of text detectors on natural scene images, an improved Cascaded Feature Fusion Module (CFFM) incorporating efficient channel attention (ECA) is inserted into CC-DBNet.

CC-DBNet

557

2 Methodology

2.1 Overall Architecture

The proposed CC-DBNet is an extension of the original DBNet++ [3]. Its architecture, shown in Fig. 1, contains three parts: feature extraction, feature fusion, and post-processing. Taking the ResNet-18 [15] backbone as an example: first, in feature extraction, the backbone extracts the feature maps C2, C3, C4, and C5, with 64, 128, 256, and 512 channels, respectively. Our modified Intra-Instance Collaborative Learning (IICL) module [10] then maps each of them to 128 channels, yielding the feature maps P2, P3, P4, and P5, which form the feature pyramid Fr. Second, in feature fusion, Fr is processed by up-scale and down-scale enhancements in the Feature Pyramid Enhancement Module (FPEM) [2] combined with efficient channel attention (ECA) [11]. Each FPEM generates an enhanced feature pyramid Fn. The Feature Fusion Module (FFM) [2] then fuses the feature pyramids F1, F2, …, Fn to obtain Ff. Next, Adaptive Scale Fusion (ASF) [3] is used to obtain the feature Fs with 512 channels. In post-processing, Fs is used to predict the probability map and the threshold map. The approximate binary map is obtained via Differentiable Binarization (DB) [4], and the detection results are then produced by the label generation algorithm [3].

Fig. 1. The overall architecture of our proposed CC-DBNet.
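The channel bookkeeping of Sect. 2.1 can be traced with a short sketch. It assumes a 640 × 640 input (the training resolution given in Sect. 3.2) and the usual ResNet stage strides of 4, 8, 16, and 32 for C2 to C5, and it assumes that FFM concatenates the four 128-channel levels after upsampling them to the C2 resolution, giving 4 × 128 = 512 channels, consistent with the 512 channels reported for Fs after ASF. The dictionary and function names are hypothetical.

```python
from typing import Dict, Tuple

# ResNet-18 stage outputs and their channel counts, as stated in the text.
BACKBONE_CHANNELS = {"C2": 64, "C3": 128, "C4": 256, "C5": 512}
# Standard ResNet strides relative to the input image (assumption).
STRIDES = {"C2": 4, "C3": 8, "C4": 16, "C5": 32}


def trace_shapes(img_size: int = 640,
                 pyramid_channels: int = 128) -> Dict[str, Tuple[int, int, int]]:
    """Return (channels, height, width) for C2-C5, P2-P5, and the FFM output Ff."""
    shapes = {}
    for name, channels in BACKBONE_CHANNELS.items():
        side = img_size // STRIDES[name]
        shapes[name] = (channels, side, side)
        shapes["P" + name[1:]] = (pyramid_channels, side, side)  # IICL output: 128 channels
    base = img_size // STRIDES["C2"]
    # P3, P4, P5 are upsampled by 2, 4, 8 to the P2 resolution, then concatenated.
    shapes["Ff"] = (pyramid_channels * len(BACKBONE_CHANNELS), base, base)
    return shapes


for key, shape in trace_shapes().items():
    print(key, shape)
```

The upsampling factors 2, 4, and 8 used by FFM follow directly from the stride ratios 8/4, 16/4, and 32/4.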

2.2 Intra-instance Collaborative Learning Scene text detectors often struggle with fracture detection, as depicted in Fig. 2(a). When characters are highly dispersed, the detection model fails to adequately represent the text features in the regions between characters, resulting in fractured detection. To address this issue, this work enhances the Intra-Instance Collaborative Learning (IICL) proposed by Du et al. [10] to improve the text feature representation capability and effectively alleviate fracture detection. This study treats the text instances as a sequence of characters and gaps, as illustrated in Fig. 2(b). The character regions are separated by gap regions, which are surrounded by character regions on both sides. To enhance the ability of text feature representation,


we facilitate information interaction between gap and character regions. Specifically, the improved IICL consists of three convolutional blocks with different receptive fields, as shown in Fig. 2(c). Each convolutional block contains three parallel convolutional layers: a regular convolution, a horizontal convolution, and a vertical convolution, with kernel sizes k × k, 1 × k, and k × 1, respectively. The regular convolution attends to the whole text region and learns a generic text feature representation. Unlike Du et al. [10], we use horizontal and vertical dilated convolutions [13] to focus specifically on adjacent characters in horizontal and vertical text instances. Dilated convolution expands the receptive field and, by adjusting the dilation rate, captures multi-scale text information. In addition, the IICL module uses a residual connection to help the model converge quickly.

Fig. 2. (a) An example of fracture detection. (b) The internal composition of text instance. (c) The architecture of the IICL.
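The three parallel branches of one IICL convolutional block can be illustrated on a single-channel map. This pure-Python sketch uses all-ones kernels as placeholders for learned weights, merges the regular, dilated horizontal, and dilated vertical branches by element-wise summation, and adds the residual input; the merge rule, kernel values, dilation rate, and function names are assumptions rather than the paper's implementation, and only one of the three blocks is shown. It also shows that a length-k kernel with dilation d spans d(k − 1) + 1 pixels.

```python
from typing import List, Tuple

Map = List[List[float]]


def effective_span(k: int, d: int) -> int:
    """Pixels covered along one axis by a length-k kernel with dilation d."""
    return d * (k - 1) + 1


def ones(h: int, w: int) -> Map:
    """Placeholder kernel; in the real model these weights are learned."""
    return [[1.0] * w for _ in range(h)]


def conv2d_same(x: Map, kernel: Map, dilation: Tuple[int, int] = (1, 1)) -> Map:
    """Single-channel dilated cross-correlation with zero padding.
    Odd kernel sizes keep the output at H x W."""
    H, W = len(x), len(x[0])
    kh, kw = len(kernel), len(kernel[0])
    dh, dw = dilation
    ph, pw = dh * (kh - 1) // 2, dw * (kw - 1) // 2
    out = [[0.0] * W for _ in range(H)]
    for i in range(H):
        for j in range(W):
            acc = 0.0
            for a in range(kh):
                for b in range(kw):
                    y, xx = i - ph + a * dh, j - pw + b * dw
                    if 0 <= y < H and 0 <= xx < W:
                        acc += kernel[a][b] * x[y][xx]
            out[i][j] = acc
    return out


def iicl_block(x: Map, k: int = 3, d: int = 2) -> Map:
    """Regular k x k, dilated 1 x k, and dilated k x 1 branches, summed, plus the input."""
    regular = conv2d_same(x, ones(k, k))
    horizontal = conv2d_same(x, ones(1, k), dilation=(1, d))
    vertical = conv2d_same(x, ones(k, 1), dilation=(d, 1))
    return [
        [x[i][j] + regular[i][j] + horizontal[i][j] + vertical[i][j] for j in range(len(x[0]))]
        for i in range(len(x))
    ]


impulse = [[0.0] * 9 for _ in range(9)]
impulse[4][4] = 1.0
response = iicl_block(impulse)
# The dilated horizontal branch reaches (4, 2) and (4, 6); the 3 x 3 branch does not.
```

Feeding an impulse shows the receptive field directly: the strip branches extend further along rows and columns than the square branch, which is the property used to link adjacent characters across gaps.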

2.3 Cascaded Feature Fusion Module

Text instances in natural scene images have multiple orientations and various sizes. To address these challenges and improve the robustness of the detector, we enhance the Cascaded Feature Fusion Module (CFFM) proposed by Wang et al. [2] and use it to replace the FPN [12] in DBNet++. CFFM fully fuses shallow features, which are rich in spatial and detailed information, with deep features, which are rich in semantic and global information. This fusion captures more scale information and further improves the detector's robustness. The CFFM module comprises the Feature Pyramid Enhancement Module (FPEM) and the Feature Fusion Module (FFM). As shown in Fig. 3(a), FPEM is a "U"-shaped structure with up-scale and down-scale enhancements. The up-scale enhancement fully integrates shallow and deep features, and the down-scale enhancement generates the output features through successive scaling and element-wise addition. Unlike Wang et al. [2], we add efficient channel attention (ECA) [11] to the enhancement


operations, as shown in Fig. 3(b). Given the channel dimension C, the kernel size k is determined adaptively by

$$k = \psi(C) = \left| \frac{\log_2(C)}{\gamma} + \frac{b}{\gamma} \right|_{odd} \tag{1}$$

where $|t|_{odd}$ denotes the odd number nearest to $t$. Empirically, $\gamma$ and $b$ are set to 1 and 2, respectively. ECA enables local cross-channel interaction, retains useful information, and boosts model performance with minimal parameter overhead. In addition, we replace the depth-wise convolution in FPEM with dilated convolution [13], which expands the receptive field.

Fig. 3. The details of FPEM (a) and ECA (b). In (a), dashed arrows indicate where ECA is used.
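Eq. (1) and the ECA operation can be sketched as follows. The kernel-size helper realizes $|t|_{odd}$ by truncating $t$ and moving to the next odd integer when the result is even, as is common in public ECA implementations (treat this as an assumption); its defaults use the values stated above ($\gamma$ = 1, $b$ = 2), whereas the original ECA-Net paper uses $\gamma$ = 2 and $b$ = 1. The attention step average-pools each channel, applies a zero-padded 1-D convolution across neighboring channels, and rescales each channel by a sigmoid gate; the convolution weights here are placeholders for learned parameters.

```python
import math
from typing import List, Sequence

FeatureMap = List[List[List[float]]]  # C x H x W


def eca_kernel_size(C: int, gamma: float = 1.0, b: float = 2.0) -> int:
    """Eq. (1): k = |log2(C)/gamma + b/gamma|_odd."""
    t = int(abs((math.log2(C) + b) / gamma))
    return t if t % 2 else t + 1


def eca(feature: FeatureMap, weights: Sequence[float]) -> FeatureMap:
    """Efficient channel attention with a 1-D kernel of len(weights) taps over channels."""
    C, k = len(feature), len(weights)
    pad = k // 2
    pooled = [sum(map(sum, ch)) / (len(ch) * len(ch[0])) for ch in feature]  # global avg pool
    out = []
    for c in range(C):
        # Zero-padded 1-D convolution over the pooled channel descriptors.
        z = sum(weights[j] * pooled[c - pad + j] for j in range(k) if 0 <= c - pad + j < C)
        gate = 1.0 / (1.0 + math.exp(-z))  # sigmoid
        out.append([[gate * v for v in row] for row in feature[c]])
    return out


print(eca_kernel_size(128), eca_kernel_size(128, gamma=2, b=1))
```

Because the 1-D kernel only has k taps regardless of C, the parameter overhead stays tiny, which matches the motivation given for ECA above.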

The structure of FFM is shown in Fig. 4. First, the n feature pyramids F1, F2, …, Fn produced by the FPEMs are combined by element-wise addition. Then, the last three feature maps in the combined feature pyramid Fc are upsampled by factors of 2, 4, and 8, respectively, so that all maps share the same resolution. Next, all feature maps in Fc are concatenated to form the feature map Ff. Finally, Adaptive Scale Fusion (ASF) [3] is used to obtain the feature map Fs.

2.4 Differentiable Binarization

Binarization is a critical stage in segmentation-based scene text detection. The standard binarization in Eq. (2) is not differentiable and therefore cannot be optimized as part of network training:

$$B_{i,j} = \begin{cases} 1, & P_{i,j} \ge t \\ 0, & \text{otherwise} \end{cases} \tag{2}$$


Fig. 4. The architecture of the FFM.

where $t$ denotes the preset threshold and $(i, j)$ denotes the pixel coordinates. This study applies the differentiable binarization (DB) [4] method. After the ASF module [3], the feature map Fs is used to predict the probability map $P$ and the threshold map $T$, which are related by Eq. (3) to generate an approximate binary map $\hat{B}$:

$$\hat{B}_{i,j} = \frac{1}{1 + e^{-k \left(P_{i,j} - T_{i,j}\right)}} \tag{3}$$

where $k$ is an amplification factor set to 50. Formally, DB is differentiable everywhere within its domain and can therefore be optimized along with the network during training. It helps to distinguish text instances from the background and to split text instances that lie very close together.

2.5 Label Generation

This section follows the label generation method of PSENet [8]. As shown in Fig. 5, each text instance in the image can be represented by a polygon $G$:

$$G = \{S_k\}_{k=1}^{n} \tag{4}$$

where $n$ is the number of vertices of $G$. The label $G_S$ of the probability map and the approximate binary map is obtained by shrinking the polygon $G$ with the Vatti clipping algorithm [16]. The shrinkage offset $D$ is calculated by

$$D = \frac{A \left(1 - r^{2}\right)}{L} \tag{5}$$

where $r$ is the shrink ratio, empirically set to 0.4, and $L$ and $A$ are the perimeter and area of the polygon $G$, respectively. The label of the threshold map is generated as follows: first, the original polygon $G$ is dilated with the offset $D$ to obtain $G_L$. The gap between $G_S$ and $G_L$ is the boundary region of the text, where the label of the threshold map is generated by computing the distance to the nearest segment of $G$. The text bounding boxes generated from the probability map and from the approximate binary map are almost identical [3]. For efficiency, we use the probability map to


generate text bounding boxes during inference. The generation of text boxes involves three steps: (1) the probability map is binarized with a constant threshold to obtain a binary map; (2) the shrunk text regions are obtained from the binary map; (3) the shrunk regions are dilated with an offset $D'$ using the Vatti clipping algorithm [16]. The offset $D'$ used for this expansion is calculated as

$$D' = \frac{A' \times r'}{L'} \tag{6}$$

where $A'$ and $L'$ are the area and perimeter of the shrunk polygon, respectively, and $r'$ is a constant set to 1.5.
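Both offsets depend only on a polygon's area and perimeter, so they can be computed directly. The sketch below uses the shoelace formula for the area and summed edge lengths for the perimeter, evaluates Eq. (5) with r = 0.4 and Eq. (6) with r′ = 1.5, and computes the distance from a point to the nearest polygon edge, the quantity used for the threshold-map label. The actual shrinking and dilation of polygons with the Vatti clipping algorithm is left to a polygon-clipping library and is not shown; the function names are hypothetical.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def polygon_area(pts: List[Point]) -> float:
    """Shoelace formula (absolute value, so vertex order does not matter)."""
    s = 0.0
    for (x0, y0), (x1, y1) in zip(pts, pts[1:] + pts[:1]):
        s += x0 * y1 - x1 * y0
    return abs(s) / 2.0


def polygon_perimeter(pts: List[Point]) -> float:
    """Sum of the closed polygon's edge lengths."""
    return sum(math.dist(a, b) for a, b in zip(pts, pts[1:] + pts[:1]))


def shrink_offset(pts: List[Point], r: float = 0.4) -> float:
    """Eq. (5): D = A (1 - r^2) / L, used to shrink G into G_S."""
    return polygon_area(pts) * (1.0 - r * r) / polygon_perimeter(pts)


def expand_offset(shrunk: List[Point], r_prime: float = 1.5) -> float:
    """Eq. (6): D' = A' r' / L', used to dilate a predicted shrunk region."""
    return polygon_area(shrunk) * r_prime / polygon_perimeter(shrunk)


def distance_to_polygon(p: Point, pts: List[Point]) -> float:
    """Distance from p to the nearest edge of the polygon (threshold-map label)."""
    best = math.inf
    for (ax, ay), (bx, by) in zip(pts, pts[1:] + pts[:1]):
        dx, dy = bx - ax, by - ay
        seg_len2 = dx * dx + dy * dy
        # Project p onto the segment and clamp the projection to its endpoints.
        t = 0.0 if seg_len2 == 0 else ((p[0] - ax) * dx + (p[1] - ay) * dy) / seg_len2
        t = max(0.0, min(1.0, t))
        best = min(best, math.dist(p, (ax + t * dx, ay + t * dy)))
    return best


text_box = [(0.0, 0.0), (10.0, 0.0), (10.0, 4.0), (0.0, 4.0)]  # a 10 x 4 word box
D = shrink_offset(text_box)  # 40 * 0.84 / 28, approximately 1.2
```

For long, thin text the offset stays small relative to the height, so shrinking separates adjacent instances without making short words vanish.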

Fig. 5. The process of label generation.

2.6 Optimization

The loss function $L$ in this paper consists of $L_s$, $L_b$, and $L_t$:

$$L = L_s + \alpha \times L_b + \beta \times L_t \tag{7}$$

where $L_s$, $L_b$, and $L_t$ denote the losses of the probability map, the approximate binary map, and the threshold map, respectively. Empirically, the weight coefficients $\alpha$ and $\beta$ are set to 1.0 and 10, respectively. $L_s$ and $L_b$ are binary cross-entropy losses:

$$L_s = L_b = -\sum_{i \in S_l} \left[ y_i \log x_i + \left(1 - y_i\right) \log\left(1 - x_i\right) \right] \tag{8}$$

where $S_l$ is the set of pixels sampled by OHEM [17] with a positive-to-negative ratio of 1:3, and $y_i$ and $x_i$ are the label and the predicted value of pixel $i$. The threshold-map loss $L_t$ is an L1 loss:

$$L_t = \sum_{i \in R_d} \left| y_i^{*} - x_i^{*} \right| \tag{9}$$

where $R_d$ is the index set of pixels inside the dilated polygon, $y^{*}$ denotes the label of the threshold map, and $x^{*}$ the predicted threshold map.
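Eq. (3) and the loss of Eqs. (7) to (9) can be combined in a compact sketch over flattened maps. It assumes that OHEM keeps all positive pixels plus the hardest negatives up to three times the number of positives, and that each term is averaged over the selected pixels rather than summed as written in Eqs. (8) and (9); both are assumptions, as are the function names. The values k = 50, α = 1.0, and β = 10 follow the text.

```python
import math
from typing import List, Sequence


def approx_binary_map(P: Sequence[float], T: Sequence[float], k: float = 50.0) -> List[float]:
    """Eq. (3): B_hat = 1 / (1 + exp(-k (P - T))), element-wise."""
    return [1.0 / (1.0 + math.exp(-k * (p - t))) for p, t in zip(P, T)]


def ohem_bce(pred: Sequence[float], gt: Sequence[int],
             neg_ratio: float = 3.0, eps: float = 1e-6) -> float:
    """Eq. (8): BCE over all positives and the hardest negatives (1:3 ratio).
    In this simplified version, no negatives are kept when there are no positives."""
    pos = [-math.log(max(x, eps)) for x, y in zip(pred, gt) if y == 1]
    neg = sorted((-math.log(max(1.0 - x, eps)) for x, y in zip(pred, gt) if y == 0),
                 reverse=True)
    kept = pos + neg[: int(len(pos) * neg_ratio)]
    return sum(kept) / max(len(kept), 1)


def masked_l1(pred: Sequence[float], gt: Sequence[float], inside: Sequence[int]) -> float:
    """Eq. (9): L1 distance on pixels inside the dilated polygon (R_d)."""
    idx = [i for i, m in enumerate(inside) if m]
    return sum(abs(gt[i] - pred[i]) for i in idx) / max(len(idx), 1)


def db_loss(P: Sequence[float], T: Sequence[float], gt_prob: Sequence[int],
            gt_thresh: Sequence[float], inside: Sequence[int],
            alpha: float = 1.0, beta: float = 10.0) -> float:
    """Eq. (7): L = Ls + alpha * Lb + beta * Lt."""
    B = approx_binary_map(P, T)
    return (ohem_bce(P, gt_prob) + alpha * ohem_bce(B, gt_prob)
            + beta * masked_l1(T, gt_thresh, inside))
```

With k = 50, a margin of only 0.1 between P and T already pushes the approximate binary value to within about 0.007 of 0 or 1, while the sigmoid keeps gradients flowing to both maps.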


3 Experiments

3.1 Datasets

SynthText [18] is a synthetic English text dataset of 800k images annotated at the word level; it is used for pre-training. CTW1500 [19] is a curved text detection dataset that includes 1000 training images and 500 testing images. It is a pure English text dataset annotated at the text-line level, with text instances represented by polygons. MSRA-TD500 [20] is a multi-language text detection dataset that contains 300 training images and 200 testing images, annotated at the text-line level. ICDAR2015 [21] is a multi-oriented text detection dataset that includes 1000 training images and 500 testing images. The images in this dataset are blurry and contain many text instances with large scale variations; it is a pure English text dataset annotated at the word level.

3.2 Implementation Details

The detector is implemented in the deep learning framework PyTorch, with a ResNet-18 or ResNet-50 [15] backbone with deformable convolution [14]. The model is first pre-trained on the SynthText dataset for 100k iterations and then fine-tuned on each of the other datasets for 1200 epochs. The training batch size is set to 16, and the learning rate lr is continuously reduced by

$$\mathrm{lr} = l_0 \times \left(1 - \frac{\text{epoch}}{\text{max\_epoch}}\right)^{p} \tag{10}$$

where $l_0$ is set to 0.007 and $p$ to 0.9. Stochastic Gradient Descent (SGD) is used to optimize the model during training, with the weight decay and momentum set to 0.0001 and 0.9, respectively. The number of cascaded FPEMs is set to 2. The data augmentation used during training consists of random rotation within an angle range of (−10°, 10°), random cropping, and random flipping. All processed images are resized to 640 × 640 to improve training efficiency.

3.3 Evaluation Metrics

The evaluation metrics are those commonly used in object detection tasks. When the Intersection over Union (IoU) between a predicted bounding box and the ground truth exceeds a set threshold (e.g., 0.5), the detection is considered a true positive (TP). The standard evaluation metrics Precision, Recall, and F-measure are calculated as follows:

$$\text{Precision} = \frac{TP}{TP + FP} \tag{11}$$

$$\text{Recall} = \frac{TP}{TP + FN} \tag{12}$$

$$\text{F-measure} = \frac{2 \times \text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}} \tag{13}$$
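Eqs. (11) to (13) can be evaluated with a simple matching procedure. For brevity, this sketch uses axis-aligned boxes (x1, y1, x2, y2) and greedy one-to-one matching at IoU ≥ 0.5; the official ICDAR2015, CTW1500, and MSRA-TD500 protocols typically use polygon or rotated-box overlaps and their own matching rules, so this is only an illustration with hypothetical function names. The last line checks Eq. (13) against the CC-DBNet ICDAR2015 precision and recall reported in Table 2.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection over Union of two axis-aligned boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def match_counts(preds: List[Box], gts: List[Box], thr: float = 0.5) -> Tuple[int, int, int]:
    """Greedy one-to-one matching; returns (TP, FP, FN)."""
    unmatched = list(range(len(gts)))
    tp = 0
    for p in preds:
        best = max(unmatched, key=lambda g: iou(p, gts[g]), default=None)
        if best is not None and iou(p, gts[best]) >= thr:
            unmatched.remove(best)
            tp += 1
    return tp, len(preds) - tp, len(gts) - tp


def prf(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Eqs. (11)-(13) from raw counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f


def f_measure(precision: float, recall: float) -> float:
    """Eq. (13) from precision and recall."""
    return 2 * precision * recall / (precision + recall)


gts = [(0, 0, 10, 10), (20, 20, 30, 30)]
preds = [(0, 0, 10, 10), (50, 50, 60, 60)]
print(prf(*match_counts(preds, gts)))   # one TP, one FP, one FN
print(round(f_measure(91.7, 84.8), 1))  # 88.1, as reported for CC-DBNet on ICDAR2015
```

Because the F-measure is a harmonic mean, it stays close to the weaker of precision and recall, which is why recall gains drive most of the F-measure improvements discussed in Sect. 3.4.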


3.4 Ablation Study

We conducted ablation studies on the ICDAR2015 and CTW1500 datasets with DBNet++ as the baseline to show the effectiveness of the improved IICL and CFFM. The detailed experimental results are shown in Table 1.

Table 1. Detection results with different settings. "P", "R", and "F" indicate Precision, Recall, and F-measure, respectively.

| Backbone | IICL | CFFM | ICDAR2015 P | ICDAR2015 R | ICDAR2015 F | CTW1500 P | CTW1500 R | CTW1500 F |
|---|---|---|---|---|---|---|---|---|
| ResNet-18 | | | 90.1 | 77.2 | 83.1 | 86.7 | 81.3 | 83.9 |
| ResNet-18 | ✓ | | 89.5 | 79.4 | 84.2 | 86.8 | 82.0 | 84.3 |
| ResNet-18 | | ✓ | 90.3 | 77.7 | 83.5 | 86.5 | 81.9 | 84.1 |
| ResNet-18 | ✓ | ✓ | 90.1 | 80.5 | 85.1 | 87.6 | 82.7 | 85.1 |
| ResNet-50 | | | 90.9 | 83.9 | 87.3 | 87.9 | 82.8 | 85.3 |
| ResNet-50 | ✓ | | 91.0 | 84.6 | 87.7 | 88.7 | 83.0 | 85.8 |
| ResNet-50 | | ✓ | 90.7 | 84.3 | 87.4 | 88.6 | 82.4 | 85.4 |
| ResNet-50 | ✓ | ✓ | 91.7 | 84.8 | 88.1 | 88.1 | 84.0 | 86.0 |

The Effectiveness of IICL: As shown in Table 1, IICL improves performance because it expands the receptive field during feature extraction and alleviates the text fracture detection problem. On the ICDAR2015 dataset, IICL improves the F-measure by 1.1% and 0.4% with the ResNet-18 and ResNet-50 backbones, respectively. On the CTW1500 dataset, compared with the model without IICL, it improves the F-measure by 0.4% and 0.5% with the ResNet-18 and ResNet-50 backbones, respectively.

The Effectiveness of CFFM: Table 1 shows that the improved CFFM also enhances the performance of both ResNet-18 and ResNet-50 on both datasets. On the ICDAR2015 dataset, CFFM improves the F-measure by 0.4% with the ResNet-18 backbone and the recall by 0.4% with the ResNet-50 backbone. On the CTW1500 dataset, CFFM improves the F-measure by 0.2% with the ResNet-18 backbone and the precision by 0.7% with the ResNet-50 backbone.

The Effectiveness of Backbone: Combining the improved CFFM with the improved IICL module improves performance on both backbones. On the ICDAR2015 dataset, under the same settings, CC-DBNet improves the F-measure by 2% and 0.8% with the ResNet-18 and ResNet-50 backbones, respectively. On the CTW1500 dataset, under the same settings, CC-DBNet improves the F-measure by 1.2% and 0.7% with the ResNet-18 and ResNet-50 backbones, respectively.

3.5 Comparisons with Previous Methods We compare our proposed method with previous methods on three publicly available datasets. Figure 6 and Table 2 show some detection results.


Fig. 6. Some visualization results from CC-DBNet and DBNet++ for text instances of different shapes and languages, with a ResNet-50 backbone. In each column, the results of DBNet++ are shown above and those of CC-DBNet below.

Table 2. Detection results on the ICDAR2015, CTW1500, and MSRA-TD500 datasets; the backbone is ResNet-50.

| Methods | ICDAR2015 P | ICDAR2015 R | ICDAR2015 F | CTW1500 P | CTW1500 R | CTW1500 F | MSRA-TD500 P | MSRA-TD500 R | MSRA-TD500 F |
|---|---|---|---|---|---|---|---|---|---|
| TextSnake [22] | 80.4 | 84.9 | 82.6 | 82.7 | 77.8 | 80.1 | 84.2 | 81.7 | 82.9 |
| EAST [23] | 83.3 | 78.3 | 80.7 | 78.7 | 49.7 | 60.4 | 87.3 | 67.4 | 76.1 |
| PSENet [8] | 88.7 | 85.5 | 87.1 | 84.8 | 79.7 | 82.2 | - | - | - |
| PAN [2] | 84.0 | 81.9 | 82.9 | 86.4 | 81.2 | 83.7 | 84.4 | 83.8 | 84.1 |
| LOMO [24] | 91.3 | 83.5 | 87.2 | 85.7 | 76.5 | 80.8 | - | - | - |
| CRAFT [25] | 89.8 | 84.3 | 86.9 | 86.0 | 81.1 | 83.5 | 87.6 | 79.9 | 83.6 |
| ABCNet v2 [26] | 90.4 | 86.0 | 88.1 | 85.6 | 83.8 | 84.7 | 89.4 | 81.3 | 85.2 |
| DBNet [4] | 91.8 | 83.2 | 87.3 | 86.9 | 80.2 | 83.4 | 91.5 | 79.2 | 84.9 |
| DBNet++ [3] | 90.9 | 83.9 | 87.3 | 87.9 | 82.8 | 85.3 | 91.5 | 83.3 | 87.2 |
| CC-DBNet | 91.7 | 84.8 | 88.1 | 88.1 | 84.0 | 86.0 | 92.5 | 85.1 | 88.6 |

Multi-oriented Text Detection: ICDAR2015. According to the results in Table 2, CC-DBNet achieves the best F-measure on the multi-oriented dataset (88.1%, tied with ABCNet v2 [26]). Specifically, it improves the F-measure by 7.4% over EAST [23] and by 5.2% over PAN [2]. Curved Text Detection: CTW1500. As shown in Fig. 6(b) and Table 2, CC-DBNet mitigates the fracture detection problem on the curved text dataset and achieves the best F-measure, recall, and precision. For example, CC-DBNet

CC-DBNet

565

improves the F-measure by 2.5% compared to CRAFT [25] and the precision by 2.4% compared to LOMO [24]. Multi-language Text Detection: MSRA-TD500. Table 2 and Fig. 6(c) show the robustness and higher performance of CC-DBNet on multi-language text datasets. Specifically, the precision of CC-DBNet++ by 8.1% compared to PAN [2], the recall of CC-DBNet++ by 3.8% compared to ABCNet v2 [26], and the F-measure of CC-DBNet++ is 1.4% over DBNet++ [3].
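As a quick consistency check on Table 2, the F-measure reported by these benchmarks is the harmonic mean of precision and recall, F = 2PR/(P + R). The minimal sketch below recomputes F for the three CC-DBNet rows from their P and R values; the function and variable names are illustrative only.

```python
def f_measure(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (inputs and output in percent)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# CC-DBNet rows of Table 2: (P, R, reported F)
CC_DBNET_ROWS = {
    "ICDAR2015": (91.7, 84.8, 88.1),
    "CTW1500": (88.1, 84.0, 86.0),
    "MSRA-TD500": (92.5, 85.1, 88.6),
}

for dataset, (p, r, f_reported) in CC_DBNET_ROWS.items():
    print(f"{dataset}: computed F = {f_measure(p, r):.1f}, reported F = {f_reported}")
```

Rounded to one decimal, the computed values match the reported F-measures for all three datasets.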

4 Conclusion

In this paper, we propose a new scene text detector, CC-DBNet, which can detect arbitrary-shaped text instances. CC-DBNet combines Intra-Instance Collaborative Learning (IICL) modules with an improved Cascaded Feature Fusion Module (CFFM) to mitigate the text fracture detection problem and enhance the robustness of the detection model. Ablation studies demonstrate the effectiveness of CFFM and IICL. In addition, experiments on the ICDAR2015, CTW1500, and MSRA-TD500 datasets show that CC-DBNet matches or outperforms the compared state-of-the-art scene text detection methods and effectively detects text instances in complex scenes.

Acknowledgments. This work was supported in part by the University Innovation Team Project of Jinan (2019GXRC015), Shandong Provincial Natural Science Foundation, China (ZR2021MF036).

References
1. Ma, J., et al.: Arbitrary-oriented scene text detection via rotation proposals. IEEE Trans. Multimedia 20(11), 3111–3122 (2018)
2. Wang, W., Xie, E., Song, X., et al.: Efficient and accurate arbitrary-shaped text detection with pixel aggregation network. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 8440–8449 (2019)
3. Liao, M., Zou, Z., Wan, Z., Yao, C., Bai, X.: Real-time scene text detection with differentiable binarization and adaptive scale fusion. IEEE Trans. Pattern Anal. Mach. Intell. 45(1), 919–931 (2022)
4. Liao, M., Wan, Z., Yao, C., Chen, K., Bai, X.: Real-time scene text detection with differentiable binarization. Proc. AAAI Conf. Artif. Intell. 34(7), 11474–11481 (2020)
5. He, W., Zhang, X.Y., Yin, F., Liu, C.L.: Deep direct regression for multi-oriented scene text detection. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 745–753 (2017)
6. Liao, M., Shi, B., Bai, X.: TextBoxes++: a single-shot oriented scene text detector. IEEE Trans. Image Process. 27(8), 3676–3690 (2018)
7. Deng, D., Liu, H., Li, X., Cai, D.: PixelLink: detecting scene text via instance segmentation. In: Proceedings of the AAAI Conference on Artificial Intelligence (2018)
8. Wang, W., Xie, E., Li, X., et al.: Shape robust text detection with progressive scale expansion network. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9336–9345 (2019)


9. Wang, Y., Xie, H., Zha, Z., et al.: ContourNet: taking a further step toward accurate arbitrary-shaped scene text detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11753–11762 (2020)
10. Du, B., Ye, J., Zhang, J., Liu, J., Tao, D.: I3CL: intra- and inter-instance collaborative learning for arbitrary-shaped scene text detection. Int. J. Comput. Vision 130(8), 1961–1977 (2022)
11. Wang, Q., Wu, B., Zhu, P., et al.: ECA-Net: efficient channel attention for deep convolutional neural networks. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11534–11542 (2020)
12. Lin, T.Y., Dollár, P., Girshick, R., et al.: Feature pyramid networks for object detection. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2117–2125 (2017)
13. Yu, F., Koltun, V.: Multi-scale context aggregation by dilated convolutions. arXiv preprint arXiv:1511.07122 (2015)
14. Zhu, X., Hu, H., Lin, S., Dai, J.: Deformable ConvNets v2: more deformable, better results. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9308–9316 (2019)
15. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
16. Vatti, B.R.: A generic solution to polygon clipping. Commun. ACM 35(7), 56–63 (1992)
17. Shrivastava, A., Gupta, A., Girshick, R.: Training region-based object detectors with online hard example mining. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 761–769 (2016)
18. Gupta, A., Vedaldi, A., Zisserman, A.: Synthetic data for text localisation in natural images. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2315–2324 (2016)
19. Liu, Y., Jin, L., Zhang, S., Luo, C., Zhang, S.: Curved scene text detection via transverse and longitudinal sequence connection. Pattern Recogn. 90, 337–345 (2019)
20. Yao, C., Bai, X., Liu, W., Ma, Y., Tu, Z.: Detecting texts of arbitrary orientations in natural images. In: 2012 IEEE Conference on Computer Vision and Pattern Recognition, pp. 1083–1090 (2012)
21. Karatzas, D., Gomez-Bigorda, L., Nicolaou, A., et al.: ICDAR 2015 competition on robust reading. In: 2015 13th International Conference on Document Analysis and Recognition (ICDAR), pp. 1156–1160 (2015)
22. Long, S., Ruan, J., Zhang, W., He, X., Wu, W., Yao, C.: TextSnake: a flexible representation for detecting text of arbitrary shapes. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11206, pp. 19–35. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01216-8_2
23. Zhou, X., Yao, C., Wen, H., et al.: EAST: an efficient and accurate scene text detector. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5551–5560 (2017)
24. Zhang, C., Liang, B., Huang, Z., et al.: Look more than once: an accurate detector for text of arbitrary shapes. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10552–10561 (2019)
25. Baek, Y., Lee, B., Han, D., et al.: Character region awareness for text detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9365–9374 (2019)
26. Liu, Y., Shen, C., Jin, L., et al.: ABCNet v2: adaptive Bezier-curve network for real-time end-to-end text spotting. IEEE Trans. Pattern Anal. Mach. Intell. 44(11), 8048–8064 (2021)

A Radar Video Compression and Display Method Based on FPGA

Daiwei Xie, Jin-Wu Wang, and Zhenmin Dai

China State Shipbuilding Corporation Limited No. 723 Research Institute, Yangzhou 225001, China

Abstract. An FPGA-based hardware implementation of radar video compression is proposed. The method receives FC video through the high-speed serial ports of the FPGA. Compression within a single frame unifies radar videos of different ranges and point counts to 512 points. Compression across multiple frames reduces the brightness of noise, improves the contrast of the video, and reduces the amount of video data. The compressed radar video is then sent to a DSP through SRIO, and the DSP forwards it to the display software over the network. When the video is displayed, isolated noise is eliminated through morphological operations, which makes targets easier to observe. The FPGA-based hardware implementation has a simple structure and processes data in parallel, which suits radar video with its large data volume and fast data update rate. The morphological image processing effectively eliminates isolated points when a scene contains many of them.

Keywords: FPGA · Radar FC video · Compression · Open operation

1 Background

In older radars, radar video is transmitted as an analog signal, which suffers from short transmission distance, susceptibility to interference, and small bandwidth. Newer radars transmit radar video over optical fiber as FC video, which is increasingly widely used because of its long transmission distance, low interference, and wide bandwidth. However, FC video carries a very large amount of data: extracting points in software consumes a large amount of CPU resources and also requires dedicated FC video receiving hardware, which makes the hardware more complicated. An FPGA is suitable not only for implementing hardware interfaces but also for hardware compression, which frees the CPU and improves its efficiency. The parallel processing of an FPGA also suits radar video display, given its large data volume and fast data update rate [1–8].

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 567–575, 2023. https://doi.org/10.1007/978-981-99-4742-3_47


D. Xie et al.

2 System Composition

Fig. 1. System composition (FC radar video → optical module → GTX → FPGA → SRIO → DSP → NET → display software)

As shown in Fig. 1, the external FC radar video is connected to the optical module, which is connected to the high-speed serial port (GTX) of the FPGA. The optical module converts the optical signal carried by the fiber into an electrical signal. The FPGA parses, processes, and compresses the video, and the processed radar video is then sent to the DSP through SRIO. The DSP forwards each video to the display software over the network and receives commands from the display software, which it uses to control the operation of the FPGA through EMIF. The frame format of the FC protocol is shown in Fig. 2.

Fig. 2. Frame format of FC protocol

An FC frame consists of a start-of-frame flag, the frame content, and an end-of-frame flag, where the frame content is composed of a frame header, numerical fields, and a CRC check.

3 Compression in a Single Frame

Compression in a single frame means uniformly extracting 512 points from the sampling points of one azimuth code, whatever the range and the number of sampling points. The P-display (plan position indicator) has a resolution of 1024 × 1024, so it shows 512 points along a radius in any direction. However, because the sampling accuracy and the working range vary with the radar repetition frequency, the number of points in the radar video also varies. The number of sampling points may far exceed 512, for example 3333 for a 100 km range and 1667 for a 50 km range (at a 30 m range resolution, 100 km / 30 m ≈ 3333). These counts are not necessarily integer multiples of 512, so selecting one point out of every fixed number of points cannot yield 512 evenly spaced points and has low accuracy.


To extract 512 sampling points evenly, this article adopts a lookup-table method. Take extracting 512 of 3333 points as an example: 512/3333 ≈ 0.1536 is slightly less than 2/13 ≈ 0.1538, so two points are extracted from every 13 consecutive points. The 256 groups of 13 points cover 256 × 13 = 3328 points, and the last point extracted is point 3328. In this way, the IDs of the extracted points can be saved in a ROM. When points are needed, the IDs of the 512 points are read from the ROM, and the corresponding 512 points are then obtained by reading the RAM that stores all points. To change which points are extracted, only the ROM's initialization file needs to be changed.

3.1 Brightness Data Conversion

In the message, the amplitude of a single range cell is 16 bits, while the final brightness value is 8 bits, so the 16-bit data must be converted to 8-bit data. The simplest method is to take the highest or the lowest 8 bits, but this makes the image too dark or too bright, and the uneven conversion degrades the final radar video display. Therefore, this article adopts a nonlinear (logarithmic) mapping:

$$f(x) = 53.12 \log_{10} x, \quad x \in [1, 65535] \tag{1}$$

Here x is the 16-bit input and f(x) is the converted 8-bit value. There is no FPGA IP core that evaluates Eq. (1) directly, but it can be obtained indirectly. There are two common methods for computing logarithms.

3.1.1 Using the Lookup Table Method

Taking the logarithm is essentially a fixed mapping with 65536 possible inputs, so it can be computed with a lookup table. Because the mapping is fixed, a ROM can store it: the input is used as the address, the stored result is read out, and the logarithm is obtained within one clock cycle. This is the FPGA idea of trading area for speed. First, all possible results of Eq. (1) are calculated in MATLAB and saved as a .coe file. Then the Block RAM IP core of the FPGA is used to create a ROM with a depth of 65536 (16-bit address) and a width of 8 bits, initialized with the .coe file. Each 16-bit input is used as an address to look up the corresponding stored result. The advantage of the lookup-table method is that it is easy to implement, since it is essentially a memory read. The disadvantage is also obvious: for a wide input such as 4 bytes, the ROM needs 2^32 addresses, so the hardware cost grows exponentially with the input width. Moreover, when the design needs several logarithm operations, each additional instance is easy to create but inevitably increases resource consumption.
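The paper generates the ROM contents in MATLAB; the sketch below is a Python equivalent under stated assumptions. It tabulates Eq. (1) for every 16-bit address and writes the result in the Xilinx .coe memory-initialization text format. Three choices are assumptions not given in the paper: address 0 (where log10 is undefined) stores 0, results are truncated toward zero, and values are clamped to 255, since 53.12 · log10(65535) ≈ 255.85 would otherwise overflow 8 bits. All function names are hypothetical.

```python
import math

COEFF = 53.12  # coefficient of Eq. (1)


def log_brightness(x: int) -> int:
    """8-bit brightness for a 16-bit amplitude x according to Eq. (1).

    Assumptions (not from the paper): x = 0 maps to 0, results are
    truncated toward zero, and they are clamped to 255.
    """
    if x <= 0:
        return 0
    return min(255, int(COEFF * math.log10(x)))


def build_lut(depth: int = 65536) -> list:
    # One ROM word per 16-bit address; the address itself is the input x.
    return [log_brightness(x) for x in range(depth)]


def to_coe(values) -> str:
    # Xilinx memory-initialization (.coe) text with a decimal radix.
    vector = ",\n".join(str(v) for v in values)
    return ("memory_initialization_radix=10;\n"
            "memory_initialization_vector=\n" + vector + ";\n")
```

For example, `open("log_lut.coe", "w").write(to_coe(build_lut()))` would produce a 65536-entry file (the file name is arbitrary) for the ROM initialization step described above.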


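3.1.2 Using Algorithm

(1) Calculation using the Floating-Point IP. The FPGA provides a Floating-Point IP core that computes the natural logarithm; it supports various operations such as fixed-point to floating-point conversion, floating-point to fixed-point conversion, and floating-point calculations (natural logarithm, square root, division, etc.). Since log10 x = ln x / ln 10, Eq. (1) can be rewritten as:

$$f(x) = 22.99 \ln x, \quad x \in [1, 65535] \tag{2}$$

where the coefficient 22.99 ≈ 255/ln 65535 maps the largest input exactly to 255 (converting the coefficient of Eq. (1) directly would give 53.12/ln 10 ≈ 23.07).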
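Eq. (2) reduces the conversion to one natural logarithm followed by multiplication by a constant, and that multiplication can be done with shifts and additions only. The sketch below models this under hypothetical choices that the paper does not specify: ln x is held as a fixed-point value with 8 fractional bits (math.log stands in for the Floating-Point IP), and 22.99 is approximated as 16 + 4 + 2 + 1 − 2^−7 = 22.9922, followed by rounding back to an integer.

```python
import math

FRAC_BITS = 8  # assumed fixed-point precision of ln(x) (Q.8); not from the paper


def ln_fixed(x: int) -> int:
    # Models int->float, natural log, float->fixed; math.log stands in for the IP core
    return round(math.log(x) * (1 << FRAC_BITS))


def times_22_99(v: int) -> int:
    # 22.99 ~ 23 - 1/128 = 16 + 4 + 2 + 1 - 2**-7, using only shifts and adds
    return (v << 4) + (v << 2) + (v << 1) + v - (v >> 7)


def brightness(x: int) -> int:
    # Round the Q.8 product back to an integer brightness in 0..255
    return (times_22_99(ln_fixed(x)) + (1 << (FRAC_BITS - 1))) >> FRAC_BITS
```

With these choices the result stays within ±1 of round(22.99 · ln x) over the whole 16-bit input range, and x = 65535 maps to 255.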

The complete logarithm calculation proceeds as follows:
(1) Use the Floating-Point IP to convert the 16-bit fixed-point range-cell value to a floating-point number.
(2) Calculate the natural logarithm with the Floating-Point IP core.
(3) Convert the floating-point result back to a fixed-point number.
(4) Perform the multiplication by shifts and additions.
Only the natural logarithm uses DSP slices (4 DSP48E1s per channel), so 8 channels use 32 DSP48E1s, which is an acceptable resource cost. By comparison, the lookup-table method stores the mapping in a ROM with a depth of 65536 (16-bit address) and a width of 8 bits. Because one FC video corresponds to 4 channels in each of the H and I bands, this design would require 8 such ROMs, each holding 65536 × 8 = 524,288 bits, which would greatly consume the FPGA's on-chip memory resources. Therefore, this paper uses the Floating-Point IP core method.

3.2 Process of Point Selection

Fig. 3. Module division of point extraction

As shown in Fig. 3, ROM_Lut stores the point-extraction relationships, RAM_Store stores the converted amplitudes of one frame, and RAM_Msg stores the SRIO data to be sent. On the hardware side, the FC interface uses the FPGA's GTX high-speed serial port. The message-parsing module extracts the length, sequence number, azimuth code, and the range-cell amplitudes of the four channels of each message and stores them in RAM_Msg. Each 2-byte amplitude is converted into a brightness value of 0–255 by the logarithm and stored in RAM_Store. ROM_Lut stores the mapping of the extracted points: for a single video, the ROM has a depth of 512 and a width of 12 bits, and each word holds the ID of an extracted point. When the EOF signal of a message frame is raised, indicating the end of the message, ROM_Lut is read and each word read out is used as an address to read RAM_Store. The resulting 512 points are written to RAM_Msg. After ROM_Lut has been fully read, RAM_Msg is read and the selected message is sent to the DSP through SRIO. The flowchart is shown in Fig. 4.

Fig. 4. Flow Chart of Point Extraction Method 1
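The ROM_Lut contents and the Fig. 4 read path can be sketched as follows for the 3333-point example. The text fixes only the "two of every 13" rule and the last ID, 3328. The positions taken inside each group (the 7th and 13th point) are a hypothetical choice, and IDs are 1-based as in the text. Every ID fits in the 12-bit ROM width, since 3328 < 2^12.

```python
def build_rom_lut(n_groups=256, group_size=13, picks=(7, 13)):
    """Hypothetical ROM_Lut contents: 2 IDs from every 13 sampling points.

    IDs are 1-based. `picks` are the positions taken inside each group;
    they are an assumption, chosen so that the last ID is 3328.
    """
    return [g * group_size + p for g in range(n_groups) for p in picks]


def extract_points(ram_store, rom_lut):
    # Method 1 (Fig. 4): each ROM word is used as the read address of RAM_Store.
    return [ram_store[i - 1] for i in rom_lut]


rom_lut = build_rom_lut()
print(len(rom_lut), rom_lut[:4], rom_lut[-1])  # 512 [7, 13, 20, 26] 3328
```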

Point extraction method 1 depends only on the range cell and not on the data amplitude, so a target that occupies only a few range cells may fall on unselected points and be missed. Method 2 builds on method 1 but depends on both the range cell and the data amplitude: each output point takes the maximum amplitude over the range cells since the previously selected point. For example, suppose RAM_Store holds the brightness values 10, 9, 8, 7, 6, 5, 4, 3, 2, 1 and ROM_Lut holds the IDs 4 and 9. Method 1 extracts 7 and 2. Method 2 takes the maximum of points 1–4 and the maximum of points 5–9, namely 10 and 6. If RAM_Store instead holds 10, 9, 8, 0, 6, 5, 4, 3, 0, 1, method 1 returns 0 and 0, whereas method 2 still returns 10 and 6. Method 2 thus avoids missing targets that fall on unselected points (Fig. 5).
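The difference between the two extraction methods can be reproduced with the two numerical examples from the text. In this sketch the IDs are 1-based, as in the text, and cells after the last selected ID are ignored, as in the example. The function names are illustrative.

```python
def method1(ram_store, rom_lut):
    """Method 1: brightness at each selected ID (IDs are 1-based)."""
    return [ram_store[i - 1] for i in rom_lut]


def method2(ram_store, rom_lut):
    """Method 2: maximum brightness over the cells after the previous
    selected ID up to and including the current one."""
    out, prev = [], 0
    for i in rom_lut:
        out.append(max(ram_store[prev:i]))
        prev = i
    return out


# The two examples from the text
ram_a = [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
ram_b = [10, 9, 8, 0, 6, 5, 4, 3, 0, 1]
ids = [4, 9]
print(method1(ram_a, ids), method2(ram_a, ids))  # [7, 2] [10, 6]
print(method1(ram_b, ids), method2(ram_b, ids))  # [0, 0] [10, 6]
```

In hardware, method 2 amounts to a running maximum that is reset each time a ROM_Lut ID is reached.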


Fig. 5. Flow Chart of Point Extraction Method 2

4 Compression in Interframe

The original azimuth code is 16 bits (0–65535), whereas the final azimuth code is 11 bits (0–2047), so up to 2^16/2^11 = 32 consecutive messages map to the same displayed azimuth. This redundancy greatly increases the amount of data processed by the back end, which affects the smoothness of the display. In addition, random noise reduces the contrast between real targets and noise. Exploiting the redundancy between frames, this paper proposes an interframe compression method based on a weighted average. After the azimuth code and sequence number of a message frame are parsed, they are compared with those of the previous message frame. If the high 11 bits of the azimuth code are the same but the sequence number differs, the current message is redundant: the frame count is incremented, the accumulated data of the previous frames is read from RAM, the current data is added to it, and the sum is written back to RAM. If the high 11 bits of the azimuth code differ, the current message belongs to a new azimuth. On the one hand, the accumulated data is read from RAM and divided by the number of redundant message frames, and the resulting weighted average is sent to the DSP through SRIO for processing. On the other hand, the current data is stored in RAM as the start of the accumulation for the new azimuth code (Fig. 6).
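The accumulate-and-average flow of Fig. 6 can be modeled in software as follows. This is a sketch with several assumptions. Each message is a (16-bit azimuth code, sequence number, list of brightness samples) tuple. All frames are weighted equally (sum divided by count, floored as integer hardware division would do). A message that repeats the previous sequence number is treated as a duplicate and skipped, which is one reading of the "sequence number is different" condition.

```python
def interframe_compress(messages, keep_bits=11, code_bits=16):
    """Average consecutive messages whose azimuth codes share the same high bits.

    messages: iterable of (azimuth_code, seq_no, samples).
    Returns a list of (display_azimuth, averaged_samples).
    """
    shift = code_bits - keep_bits          # 16 - 11 = 5, i.e. up to 32 messages per output
    out = []
    acc, count, cur_az, last_seq = None, 0, None, None

    for az_code, seq, samples in messages:
        az = az_code >> shift              # keep the high 11 bits
        if az == cur_az:
            if seq == last_seq:            # assumption: same sequence number = duplicate
                continue
            acc = [a + s for a, s in zip(acc, samples)]  # accumulate redundant frame
            count += 1
        else:
            if cur_az is not None:         # new azimuth: emit the previous average
                out.append((cur_az, [a // count for a in acc]))
            acc, count, cur_az = list(samples), 1, az
        last_seq = seq

    if cur_az is not None:                 # flush the last azimuth
        out.append((cur_az, [a // count for a in acc]))
    return out


msgs = [(0x0000, 1, [10, 0]), (0x0001, 2, [10, 6]), (0x0001, 2, [99, 99]), (0x0020, 3, [4, 4])]
print(interframe_compress(msgs))  # [(0, [10, 3]), (1, [4, 4])]
```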

Fig. 6. Flowchart of Compression in Interframe

After interframe compression, the brightness difference between the target and the noise widens. Because noise is random while a target is stable, the weighted average lowers the brightness of the noise, while the brightness of the target remains essentially unchanged from its initial value. This makes targets much easier for the human eye to observe. In general, the system's white noise follows a Gaussian distribution, and the weighted average effectively suppresses it.
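The noise-suppression effect can be quantified under a simplifying assumption that is not stated in the paper: the target brightness s stays constant over the N messages merged into one displayed azimuth, and the noise samples n_k are independent with variance σ².

```latex
v_k = s + n_k,\quad k = 1,\dots,N,\qquad
\bar{v} = \frac{1}{N}\sum_{k=1}^{N} v_k = s + \frac{1}{N}\sum_{k=1}^{N} n_k,\qquad
\operatorname{Var}(\bar{v}) = \frac{\sigma^2}{N}
```

With N up to 32, the standard deviation of the noise drops by up to a factor of √32 ≈ 5.7, while the target term s is unchanged.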

5 Display of Radar Video

For software display, after receiving the SRIO doorbell, the DSP fetches the corresponding data and sends it to the display software over the network. The display software calculates the corresponding two-dimensional positions and displays the video on the screen.


For hardware display, the sine and cosine of each azimuth are calculated with CORDIC to obtain the corresponding two-dimensional positions. The video is stored at these positions in RAM or another storage device and read out with DVI timing, so that the radar video is displayed on the screen.
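The polar-to-raster step for hardware display can be sketched with a rotation-mode CORDIC, which needs only shifts, additions, and a small arctangent table. This is a floating-point model, not the authors' HDL. The iteration count (24), the quadrant folding, and the screen convention (azimuth code 0 points up, angles increase clockwise, 512 range points on a 1024 × 1024 screen centered at pixel 512) are all assumptions.

```python
import math

N_ITER = 24
ANGLES = [math.atan(2.0 ** -i) for i in range(N_ITER)]  # arctangent table (a ROM in hardware)
GAIN = 1.0
for i in range(N_ITER):
    GAIN /= math.sqrt(1.0 + 2.0 ** (-2 * i))             # CORDIC gain K, about 0.60725


def cordic_sin_cos(theta):
    """Rotation-mode CORDIC; returns (sin, cos) of theta in radians."""
    theta = math.fmod(theta, 2 * math.pi)
    if theta > math.pi:
        theta -= 2 * math.pi
    elif theta < -math.pi:
        theta += 2 * math.pi
    sign = 1.0
    if theta > math.pi / 2:      # fold into [-pi/2, pi/2]; sin and cos both flip sign
        theta -= math.pi
        sign = -1.0
    elif theta < -math.pi / 2:
        theta += math.pi
        sign = -1.0
    x, y, z = GAIN, 0.0, theta
    for i in range(N_ITER):
        d = 1.0 if z >= 0 else -1.0
        # multiplying by 2**-i is a right shift in fixed-point hardware
        x, y = x - d * y * 2.0 ** -i, y + d * x * 2.0 ** -i
        z -= d * ANGLES[i]
    return sign * y, sign * x


def polar_to_pixel(az_code, range_idx, az_bits=11, center=512):
    """Map an 11-bit azimuth code and a range index (0-511) to screen (x, y).

    Assumed convention: code 0 points up (north), azimuth grows clockwise,
    and screen y grows downward.
    """
    theta = 2 * math.pi * az_code / (1 << az_bits)
    s, c = cordic_sin_cos(theta)
    return round(center + range_idx * s), round(center - range_idx * c)


print(polar_to_pixel(0, 511), polar_to_pixel(512, 511))  # (512, 1) (1023, 512)
```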

6 A Morphological Method for Eliminating Isolated Points

Let A and B be sets in Z^2 (the two-dimensional integer space). The dilation of A by B is defined as [9]:

$$A \oplus B = \{ z \mid (\hat{B})_z \cap A \neq \emptyset \} \tag{3}$$

where $\hat{B}$ is the reflection of B and $(B)_z$ denotes the translation of B by z. The erosion of A by B is defined as:

$$A \ominus B = \{ z \mid (B)_z \subseteq A \} \tag{4}$$

Dilation enlarges objects in the image, while erosion shrinks them. Opening generally smooths the contour of an object, breaks narrow isthmuses, and eliminates thin protrusions. Closing also smooths contours, but in contrast to opening it generally fuses narrow breaks and long thin gulfs, eliminates small holes, and fills gaps in the contour. The opening of A by the structuring element B, denoted A ∘ B, is defined as:

$$A \circ B = (A \ominus B) \oplus B \tag{5}$$

That is, opening A by B means eroding A by B and then dilating the result by B. Similarly, the closing of A by B, denoted A • B, is defined as:

$$A \bullet B = (A \oplus B) \ominus B \tag{6}$$

That is, closing A by B means dilating A by B and then eroding the result by B. In radar video, isolated points hinder observation, and because they are smaller than targets, opening can be used to eliminate them (Fig. 7).
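The opening of Eq. (5) can be sketched directly on an 8-bit brightness image. For grayscale images with a flat structuring element, erosion becomes a local minimum and dilation a local maximum, and the binary definitions above are the special case of 0/1 images. The paper does not specify the structuring element or the border handling. This sketch assumes a flat 3 × 3 square (r = 1) and windows clipped at the image border.

```python
def _window_filter(img, r, reduce_fn):
    """Apply reduce_fn over a (2r+1) x (2r+1) window, clipped at the border."""
    h, w = len(img), len(img[0])
    out = [[0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            window = [img[yy][xx]
                      for yy in range(max(0, y - r), min(h, y + r + 1))
                      for xx in range(max(0, x - r), min(w, x + r + 1))]
            out[y][x] = reduce_fn(window)
    return out


def erode(img, r=1):
    return _window_filter(img, r, min)   # grayscale erosion: local minimum


def dilate(img, r=1):
    return _window_filter(img, r, max)   # grayscale dilation: local maximum


def opening(img, r=1):
    return dilate(erode(img, r), r)      # Eq. (5): erode, then dilate


# Demo: a 3x3 target plus one isolated bright point
target = [[100 if 2 <= y <= 4 and 2 <= x <= 4 else 0 for x in range(7)] for y in range(7)]
noisy = [row[:] for row in target]
noisy[0][6] = 200                        # isolated noise point
cleaned = opening(noisy)                 # equals `target`: the point is gone, the target survives
```

Any bright region smaller than the structuring element disappears under opening, which is why the element size should stay below the smallest target of interest.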

7 Conclusion

The basis of radar video compression within a frame is that the number of sampling points in a single frame is much greater than the number of display points, so only a subset of the points needs to be extracted. To extract points as evenly as possible, this article uses a lookup table that records the IDs of the sampling points to extract. The method is flexible: different extraction patterns can be obtained simply by modifying the ROM initialization file.


Fig. 7. The compression and display process of radar video

In addition, radar video compression across frames is based on the significant redundancy between multiple frames: the actual azimuth resolution is much finer than the resolution required for display, so frames must be merged. To keep the result as uniform as possible, this article adopts a multi-frame weighted average. This method not only compresses the radar video but also reduces the impact of noise, improves the contrast of the video, and makes targets easier to observe. Taking a radar video with a 100 km range and a 30 m range resolution as an example, each frame has 3333 sampling points, of which 512 are kept, so the single-frame compression ratio is about 6.5:1. The compression across frames, based on the azimuth code, has a maximum compression ratio of 32:1. Therefore, the method can theoretically achieve a compression ratio of about 208:1 without affecting the video display. To display radar video better, this paper also proposes a morphological method (opening) that eliminates isolated points. FPGA-based processing is suitable not only for implementing various hardware interfaces but also for radar video with its high data rate and large data volume. Owing to its parallel processing, the FPGA-based method has a high performance ceiling.
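The compression ratios quoted above follow from the figures given earlier in the paper (30 m range resolution, 512 displayed points, and 16-bit versus 11-bit azimuth codes):

```latex
\frac{100\ \text{km}}{30\ \text{m}} \approx 3333,\qquad
r_{\text{single}} = \frac{3333}{512} \approx 6.5,\qquad
r_{\text{inter}} \le \frac{2^{16}}{2^{11}} = 32,\qquad
r_{\text{total}} \le 6.5 \times 32 = 208
```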

References
1. Xiao, X., Lv, L.: Hardware design and implementation of a radar display technology. Ship Electron. Eng. 28(7), 113–115 (2008)
2. Wang, X., Zhang, G.: Radar and Detection. National Defense Industry Press, Beijing, 7p (2008)
3. Sun, B.: Design and Implementation of Radar Signal Processing Algorithm Based on FPGA. Beijing University of Technology, Beijing (2014)
4. Cao, Y., Yao, Y., et al.: Design of radar velocity measurement system based on TMS320F28335. Electron. Dev. 37(1), 45–49 (2014)
5. Zhai, G., Ji, Y.: High performance raster radar display system based on DVI technology. Radar Countermeas. 2, 49–56 (2009)
6. Liu, C., Wen, D.: Long afterglow simulation of PPI radar on raster scanning display. Comput. Simul. 3, 42–47 (2012)
7. Li, H., Zhu, X., Gu, C.: Development, Design and Application of Verilog HDL and FPGA, pp. 125–127. National Defense Industry Press, Beijing (2013)
8. Gao, Y.: Digital Signal Processing Based on FPGA. Electronic Industry Press, Beijing (2012)
9. Gonzalez, R.C., Woods, R.E.: Digital Image Processing, 2nd edn. Electronic Industry Press, Beijing, pp. 423–427 (2007)

SA-GAN: Chinese Character Style Transfer Based on Skeleton and Attention Model

Jian Shu, Yuehui Chen, Yi Cao, and Yaou Zhao

School of Information Science and Engineering, University of Jinan, Jinan, China {yhchen,isecaoy,isezhaoyo}@ujn.edu.cn

Abstract. A Chinese character transfer task must meet two requirements: the transferred image should retain as much of the content structure of the original Chinese character as possible, and it should present the different reference styles. Some earlier methods required training with large amounts of paired data, which is time-consuming. Existing methods follow the style-content disentanglement paradigm and realize style transfer by combining the content with the style of a reference Chinese character. These methods are prone to missing stroke content and inaccurate overall style transfer. To address these issues, a generation network based on the Chinese character skeleton and an attention model is proposed. To further ensure the completeness of the content of the transferred Chinese characters, a more efficient upsampling module is introduced to improve their quality. Extensive experiments show that the model requires only one reference Chinese character to produce higher-quality Chinese character images than current state-of-the-art methods.

Keywords: Chinese character transfer · Attention · Chinese character skeleton

1 Introduction

As a written form of language, Chinese characters are closely related to our daily life. As ideograms, Chinese characters have complex structures and semantics, so designing a complete set of Chinese characters takes a great deal of manpower and time. At the same time, the demand for Chinese characters in culture, media, and business has increased dramatically, creating a sharp gap between this demand and what existing Chinese character design technology can supply. Chinese characters also far outnumber the letters of the English alphabet, and because of the diversity of their structures, the character transfer task is challenging. Chinese characters are highly complex and fine-grained, so simple texture analysis cannot achieve a successful style transfer. Two early projects, "Rewrite" and "Zi2zi", were trained on thousands of paired Chinese characters to achieve mutual transfer between styles. Since then, HAN [1], PEGAN [2], and DC-Font [3] have improved on this basis and achieved good results. However, these font style transfer methods are trained on paired data, which makes such models difficult to apply widely.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 576–587, 2023. https://doi.org/10.1007/978-981-99-4742-3_48

SA-GAN: Chinese Character Style Transfer


To solve the above problems, Zhu et al. combined adversarial training with a cycle-consistency constraint to address the lack of paired data in image translation, yielding CycleGAN, an image translation method that does not require paired data. However, applying it directly to Chinese character transfer does not yield satisfactory results. DG-Font, FTransGAN [22], and DM-Font [4] use more targeted style/content encoders to improve the accuracy of the transferred style and the completeness of the transferred content. DM-Font [5] is similar to LF-Font [6]: it designs multiple encoding modules to predict the stroke labels of Chinese characters and recombines the predictions into new characters, which partially solves the problem of missing content caused by sampling. However, for Chinese characters with complex strokes, producing stroke labels is time-consuming and labor-intensive. DG-Font [7] introduces the deformable convolution module [8] to predict the correspondence between strokes in different styles. However, inaccurate style transfer and missing stroke content still occur in actual transfer tasks. To address these problems, we design an encoder based on the attention mechanism to capture global features, which makes the transfer more accurate when the model processes multi-style input. The skeleton information of the original Chinese characters is added to the decoder, and an upsampling module called CARAFE is introduced. The method aims to realize many-to-many style transfer. At test time, a Chinese character can be transferred simply by passing one style-reference Chinese character into the model. Extensive experimental analysis shows that our method can produce high-quality new Chinese characters from only one style reference, with some improvement in various metrics.

2 Related Work

2.1 Image-To-Image Translation

Image-to-image translation learns the features of different styles in order to transfer images between styles. Pix2Pix [9] is an image-to-image translation model based on CGAN [10] that generalizes the model structure and loss function. To reduce the cost of producing paired data, many unsupervised methods have been proposed. CycleGAN uses a cycle-consistency loss to handle unpaired image datasets and realizes bidirectional style transfer of unpaired images in an unsupervised way, but such methods can only transfer between two styles. To solve this problem, recent work transfers among multiple styles with a single model. StarGAN [11] addresses many-to-many style transfer: it feeds the input image together with a one-hot vector encoding the target style into the generator, so that a single generator can transfer to multiple different styles. Karras et al. [12] designed a new network architecture, StyleGAN, which extracts high-level image features such as pose and gender and, to some extent, controls the details of the transferred images.


J. Shu et al.

2.2 Font Generation For the style transfer task of Chinese character images, Tian (2016) designed a font style transfer network Rewrite [13] composed of multiple convolution layers including batch normalized [14], activation function and max pooling. The overall appearance of the Chinese character image transfer by this method is not good, and the phenomenon of blurred strokes is easy to occur for fonts with thinner strokes, and the content may be missing for fonts with thicker strokes. Chang et al. [15] designed the font style transfer network Unet-GAN on the basis of the image style transfer network Pix2Pix, which ensures the completeness of the transfer Chinese character content by increasing the number of generator network layers. PEGAN [16] and DC-Font [17] are both modified based on Rewrite. PEGAN modifies the Unet-GAN model by introducing cascade refinement connections in the encoder, while DC-FONT improves the accuracy of transfer styles by introducing a new style encoder to obtain better style representation. However, the above methods are all based on supervised methods and require a large amount of pairing data as training samples. In order to improve the scalability of the model and reduce the cost of pairing data production. Chang et al. [18] implements mutual transfer of two specific fonts based on a modified CycleGAN. However, due to the limitations of CycleGAN itself, the transferred Chinese character images may have problems of missing stroke content and pattern collapse. Gao proposed a three-stage (ENet-TNet-RNet) Chinese character style transfer model. ENet uses a set of mask matrices to extract the skeleton structure of input Chinese characters. TNet transfers the skeleton structure of the source character into the skeleton structure of the target style. RNet learns the stroke information of the target font, renders the stroke details on the target font Chinese character skeleton to generate the target font Chinese character image. 
Compared with CycleGAN, this model generates Chinese character images more stably. DG-Font introduces a feature deformation skip connection that predicts pairs of displacement maps and uses the predicted maps to apply deformable convolution to the low-level feature maps of the content encoder, producing high-quality transferred images. SC-Font [19] uses stroke-level constraints and integrates stroke features into the backbone network step by step to improve the quality of the transferred images. However, obtaining stroke-level labels is time-consuming, and stroke information is difficult to obtain for some complex fonts. Mx-Font extracts multiple style features that are not explicitly tied to component labels; instead, multiple experts automatically represent different local concepts, which lets the model generalize to fonts with unseen component structures and makes cross-domain shape transfer possible.

3 Method
Given a content image I_c and a style image I_s, our model aims to produce a Chinese character image with the content of I_c and the style of I_s. As shown in Fig. 1, the proposed network consists of a style encoder, a content encoder, a content decoder, and two skeleton-based skip connection layers. Both encoders are built on the self-attention mechanism. The style encoder extracts multiple feature vectors {s_1, s_2, ..., s_k} from the style image I_s layer by layer and then compresses the feature

SA-GAN: Chinese Character Style Transfer


vectors {s_1, s_2, ..., s_k} with the Layer Attention module to obtain the style feature Z_s. The content encoder extracts the feature vector Z_c of the content image with the self-attention module; it consists of three self-attention-based convolution layers and two residual structures, and the self-attention design helps preserve the complete structural information of the content image. AdaIN [16] is used to inject the style features into the content decoder, and the upsampling module in the content decoder is replaced with CARAFE [38]. In addition, the skeleton-based feature skip connection (SCSK) passes the processed low-level features to the content decoder. See Sect. 3.2 for details.

Fig. 1. Overview of the proposed method. (a) Overview of our generative network. The style/content encoder maps the style/content image to features. SCSK combines content skeleton features with features from the upper and lower layers. (b) A detailed illustration of the SCSK module. (c) The discriminator distinguishes between generated and real images.

3.1 Attention Module
The Style Encoder in Fig. 1 shows what each module S_r (r = 1, 2, 3) contains. The input to each module is the result of a convolution, a feature map of size C × W × H. In the self-attention module, each region of the feature map is represented as v_i. The self-attention is computed as

$$ h_i = \frac{1}{C(x)} \sum_{j=1}^{H \times W} f(x_i, x_j)\, g(x_j) \tag{1} $$


where C(x) is a normalization factor, i is the index of an output position, and j enumerates all possible positions. h and x are feature vectors of the same size. The function g transforms (preprocesses) the input, and the function f computes the correlation between each pixel and every other pixel. The process can therefore be understood as a weighted average over all positions.
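To make Eq. (1) concrete, the following Python sketch computes h_i for a handful of positions. It is a minimal illustration under assumptions not stated in the text: f is taken to be the embedded-Gaussian similarity exp(x_i · x_j) with identity embeddings, g is the identity, and C(x) = Σ_j f(x_i, x_j), so each output is a softmax-weighted average of all positions, matching the weighted-averaging interpretation above.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    """Inner product used as the pairwise similarity score."""
    return sum(x * y for x, y in zip(a, b))


def non_local(x: List[Vector]) -> List[Vector]:
    """h_i = (1 / C(x)) * sum_j f(x_i, x_j) * g(x_j) over all H*W positions.

    Assumed choices (not specified in the text):
      f(x_i, x_j) = exp(x_i . x_j), g = identity, C(x) = sum_j f(x_i, x_j).
    """
    out: List[Vector] = []
    for xi in x:
        scores = [dot(xi, xj) for xj in x]
        m = max(scores)  # shifting by the max leaves the ratio f / C(x) unchanged
        f = [math.exp(s - m) for s in scores]
        c = sum(f)  # normalization factor C(x)
        hi = [sum(fj * xj[d] for fj, xj in zip(f, x)) / c for d in range(len(xi))]
        out.append(hi)
    return out


# Three positions (H * W = 3) with 2-dimensional features.
positions = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
for h in non_local(positions):
    print([round(v, 3) for v in h])
```

Because every weight is positive and the weights for each i sum to one, each h_i lies inside the range spanned by the input features, which is what "weighted averaging" means here.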

Fig. 2. Self-attention network module structure

We believe that not all regions are equally important, so we measure the importance of each region with an attention mechanism. The operations in Fig. 2 are as follows:

$$ u_i = \tanh(W_c h_i + b_c) \tag{2} $$

$$ a_i = \mathrm{softmax}\left(u_i^{\top} u_c\right) \tag{3} $$

$$ f = \sum_{i=1}^{H \times W} \mathrm{softmax}(a_i v_i) \tag{4} $$

The NN module in Fig. 2 combines flattening and activation; u_i is a feature vector that measures the importance of each region, and u_c is the context vector obtained by flattening the feature vector.
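Eqs. (2)–(4) follow the familiar attention-pooling pattern: score each region, normalize the scores over regions, and pool. The sketch below is a hedged illustration of that pattern, not the authors' code. It takes f to be the a-weighted sum of the v_i and omits the extra softmax written in Eq. (4), which is our interpretation; W_c, b_c and u_c are plain inputs here rather than trained parameters.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[Vector]


def matvec(w: Matrix, v: Vector) -> Vector:
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def softmax(xs: Vector) -> Vector:
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]


def attention_pool(h: List[Vector], v: List[Vector], w_c: Matrix,
                   b_c: Vector, u_c: Vector) -> Tuple[Vector, Vector]:
    """Score every region, then pool the region features v_i."""
    # Eq. (2): u_i = tanh(W_c h_i + b_c)
    u = [[math.tanh(z + b) for z, b in zip(matvec(w_c, hi), b_c)] for hi in h]
    # Eq. (3): a_i = softmax over regions of u_i^T u_c
    a = softmax([sum(p * q for p, q in zip(ui, u_c)) for ui in u])
    # Pooled feature f = sum_i a_i v_i (assumed reading of Eq. (4))
    f = [sum(ai * vi[d] for ai, vi in zip(a, v)) for d in range(len(v[0]))]
    return a, f


regions = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights, pooled = attention_pool(regions, regions,
                                 w_c=[[1.0, 0.0], [0.0, 1.0]],
                                 b_c=[0.0, 0.0], u_c=[1.0, -1.0])
print([round(a, 3) for a in weights], [round(p, 3) for p in pooled])
```

With u_c = [1, -1], the first region aligns best with the context vector and receives the largest weight.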

Fig. 3. Architecture of the proposed Layer Attention Network

Unlike the content encoder, the style encoder includes a layer-attention module, because for the input style image I_s we should attend more to the global style than to local style. As shown in Fig. 3, the Layer Attention module produces a feature vector Z_s that evaluates the importance of each region and indicates which regions the model should attend to. The formulas are as follows:

$$ w = \tanh(W_c f_m + b_c) \tag{5} $$

$$ Z = \sum_{r=1}^{3} \mathrm{softmax}(w f_r) \tag{6} $$

$$ Z_s = \frac{1}{k} \sum_{i=1}^{k} Z_i \tag{7} $$
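The following sketch illustrates Eqs. (5)–(7) for one style encoder. Because the text describes Z as a weighted sum of three feature vectors, the sketch assumes that softmax(w f_r) means: score each layer feature f_r by its inner product with w, normalize the three scores with a softmax, and sum the f_r with those weights. It also assumes all f_r have the same length; the function names are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]


def matvec(w: Matrix, v: Vector) -> Vector:
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def softmax(xs: Vector) -> Vector:
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [v / s for v in e]


def layer_attention(f_layers: List[Vector], f_m: Vector, w_c: Matrix, b_c: Vector) -> Vector:
    """Fuse the per-layer features f_1..f_3 of one style image into Z."""
    # Eq. (5): w = tanh(W_c f_m + b_c)
    w = [math.tanh(z + b) for z, b in zip(matvec(w_c, f_m), b_c)]
    # Assumed reading of Eq. (6): score each f_r against w, softmax over r,
    # then take the weighted sum of the layer features.
    alpha = softmax([sum(p * q for p, q in zip(w, f_r)) for f_r in f_layers])
    return [sum(a * f_r[d] for a, f_r in zip(alpha, f_layers)) for d in range(len(f_layers[0]))]


def style_code(z_per_image: List[Vector]) -> Vector:
    """Eq. (7): Z_s = (1/k) * sum_i Z_i over the k reference style images."""
    k = len(z_per_image)
    return [sum(z[d] for z in z_per_image) / k for d in range(len(z_per_image[0]))]


layers = [[3.0, 0.0], [0.0, 3.0], [3.0, 3.0]]
z1 = layer_attention(layers, f_m=[1.0, 0.0], w_c=[[1.0, 0.0], [0.0, 1.0]], b_c=[0.0, 0.0])
z2 = layer_attention(layers, f_m=[0.0, 1.0], w_c=[[1.0, 0.0], [0.0, 1.0]], b_c=[0.0, 0.0])
print(style_code([z1, z2]))
```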

where f_m is the feature map in Fig. 3, w is the feature vector obtained by flattening and activating the feature map, and Z is the weighted sum of the three feature vectors. Each style encoder accepts k images, so the final style code Z_s is the average of the k feature vectors Z.
3.2 SCSK Module
To solve the problem of missing stroke content in transferred images, we introduce a skeleton-based feature skip connection. As shown in Fig. 1(b), the module mixes the style feature Z_s with the content feature Z_c and then adds the skeleton information of the content image to preserve the completeness of the transferred content. The skeleton feature k_s is obtained from a pre-trained ENet. The operations are:

$$ k_s = \mathrm{ENet}(I_c) \tag{8} $$

$$ k_c = \mathrm{AdaIN}(Z_s, Z_c) \tag{9} $$

$$ k_{sc} = f(k_c \oplus \omega \cdot k_s) \oplus k_c \tag{10} $$
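A minimal sketch of Eqs. (9)–(10) on single feature channels follows. AdaIN here is the standard per-channel form: normalize the content feature, then rescale it with the style feature's standard deviation and shift it by the style mean. Two further assumptions are labeled in the code: ⊕ is read as element-wise addition, and f is a plain ReLU instead of a learned convolution plus activation. The skeleton feature k_s is simply given, standing in for the output of the pre-trained ENet in Eq. (8), and the default ω = 0.5 is an arbitrary placeholder.

```python
import math
from typing import List, Tuple

Channel = List[float]  # one feature channel, spatial positions flattened


def mean_std(x: Channel, eps: float = 1e-5) -> Tuple[float, float]:
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return mu, math.sqrt(var + eps)


def adain(style: Channel, content: Channel) -> Channel:
    """Eq. (9), per channel: give the content feature the style's mean and std."""
    mu_c, sd_c = mean_std(content)
    mu_s, sd_s = mean_std(style)
    return [sd_s * (v - mu_c) / sd_c + mu_s for v in content]


def scsk_fuse(k_c: Channel, k_s: Channel, omega: float = 0.5) -> Channel:
    """Eq. (10) under two assumptions: '⊕' is element-wise addition and f is a
    ReLU standing in for the paper's convolution + activation."""
    mixed = [max(0.0, c + omega * s) for c, s in zip(k_c, k_s)]  # f(k_c ⊕ ω·k_s)
    return [m + c for m, c in zip(mixed, k_c)]  # ... ⊕ k_c (residual path)


z_s = [0.0, 2.0, 4.0, 6.0]   # style feature (one channel)
z_c = [1.0, 2.0, 3.0, 4.0]   # content feature (one channel)
k_s = [1.0, 0.0, 1.0, 0.0]   # skeleton feature, e.g. from a pre-trained ENet
k_c = adain(z_s, z_c)
print([round(v, 3) for v in scsk_fuse(k_c, k_s, omega=0.5)])
```

The outer ⊕ k_c acts as a residual path, so the stylized content feature always reaches the decoder even when the skeleton branch contributes little.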

Adaptive Instance Normalization (AdaIN) is used to inject the style of Z_s into Z_c, which yields the stylized content feature k_c. The importance of the skeleton feature k_s is adjusted by the hyperparameter ω, and f denotes convolution followed by activation.
3.3 Loss Function
Our model performs Chinese font transfer with unpaired data, so we use four losses: (1) an adversarial loss that makes the generated images realistic; (2) a content consistency loss that keeps the content of the transferred character consistent with the original character; (3) a style consistency loss that ensures the transferred style is correct; and (4) an image reconstruction loss that keeps features unchanged when the generator reconstructs its input. The overall objective is

$$ L = L_{adv} + \lambda_{img} L_{img} + \lambda_{cnt} L_{cnt} + \lambda_{sty} L_{sty} \tag{11} $$

where λ_img, λ_cnt and λ_sty are the weights of the three corresponding losses.
Adversarial loss: The network is trained by solving a minimax optimization problem in which the generator G tries to deceive the discriminator D with generated images; when a real or transferred image is input


to the discriminator, the adversarial loss penalizes wrong judgments:

$$ L_{adv} = \max_{D} \min_{G} \; \mathbb{E}_{s \in P_s, c \in P_c} \left[ \log D(s) + \log\left(1 - D(G(s, c))\right) \right] \tag{12} $$

Content consistency loss: In style transfer, the transferred character must keep the same structure as the source character. A pixel-level loss does not separate style from content, so style details would be ignored and style transfer hindered. Comparing features from the content encoder instead ignores style features and compares content only:

$$ L_{cnt} = \mathbb{E}_{s \in P_s, c \in P_c} \left\| Z_c - f_c(G(s, c)) \right\|_1 \tag{13} $$

where f_c denotes the content encoder in the generator.
Style consistency loss: Different characters of the same font share the same style features, but the adversarial loss alone cannot fully exploit the style encoder. We therefore design a style loss, analogous to the content loss, that optimizes the style encoder to retain the style features of the target font:

$$ L_{sty} = \mathbb{E}_{s \in P_s, c \in P_c} \left\| Z_s - f_s(G(s, c)) \right\|_1 \tag{14} $$

where f_s denotes the style encoder in the generator.
Image reconstruction loss: To ensure that the generator can reconstruct the original image I_c when the content image itself is also provided as the style input, a reconstruction loss is applied:

$$ L_{img} = \mathbb{E}_{c \in P_c} \left\| c - G(c, c) \right\|_1 \tag{15} $$

This objective helps the content encoder extract complete features from the input image I_c.
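The overall objective can be sketched for a single training sample from the generator's side, where the adversarial term is log(1 − D(G(s, c))) and the content, style and reconstruction terms are L1 distances. This is an illustrative sketch only: the feature vectors are passed in directly rather than computed by encoders, the expectations in the paper would be averages over a batch, and the λ defaults of 1.0 are placeholders because the weights are not given here.

```python
import math
from typing import List

Vector = List[float]


def l1(a: Vector, b: Vector) -> float:
    """||a - b||_1"""
    return sum(abs(x - y) for x, y in zip(a, b))


def generator_loss(d_fake: float, z_c: Vector, z_c_fake: Vector,
                   z_s: Vector, z_s_fake: Vector, c: Vector, c_rec: Vector,
                   lam_img: float = 1.0, lam_cnt: float = 1.0, lam_sty: float = 1.0) -> float:
    """Single-sample, generator-side value of the overall objective.

    d_fake   = D(G(s, c)), the discriminator's probability that the output is real
    z_c_fake = f_c(G(s, c)), content features of the generated glyph
    z_s_fake = f_s(G(s, c)), style features of the generated glyph
    c_rec    = G(c, c), reconstruction of the content image
    The lambda defaults are placeholders, not the paper's settings.
    """
    l_adv = math.log(1.0 - d_fake)   # the term G minimizes in the adversarial loss
    l_cnt = l1(z_c, z_c_fake)        # content consistency
    l_sty = l1(z_s, z_s_fake)        # style consistency
    l_img = l1(c, c_rec)             # image reconstruction
    return l_adv + lam_img * l_img + lam_cnt * l_cnt + lam_sty * l_sty


print(generator_loss(0.3, [1.0, 2.0], [0.8, 2.1], [0.5], [0.4], [0.0, 1.0], [0.1, 0.9]))
```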

4 Experiment
To evaluate our Chinese character transfer model, we collected a dataset covering a variety of font styles, including imitation-handwriting and printed fonts, each with 800–1000 commonly used Chinese characters. All images are 80 × 80 pixels. This dataset was used for training; the test data were drawn from the 3,000 most commonly used Chinese characters and mix several types of Chinese character data.
4.1 Comparison with State-of-the-Art Methods
In this section, we compare our model with the following Chinese font transfer methods:
(1) Zi2zi: an improved version of the Pix2Pix model that uses Gaussian noise as category embeddings to achieve multi-style transfer.


(2) Unet-GAN: an autoencoder based on the U-Net structure, trained adversarially with a discriminator on paired data to perform one-to-one Chinese character image transfer.
(3) Cycle-GAN: consists of two generative networks that transfer between two Chinese fonts through a cycle-consistency loss. Cycle-GAN is an unsupervised method that performs one-to-one transfer of Chinese character images with unpaired training.
(4) StarGAN: a network architecture for multi-style transfer that uses a single generative network to transfer among various font styles.
(5) DG-Font: replaces traditional convolution with deformable convolution; unlike StarGAN, DG-Font adds a style encoder to the generative network to make font style transfer more accurate.

Fig. 4. Comparisons to the state-of-the-art methods for font generation.


Quantitative comparison. The quantitative results are shown in Table 1. On pixel-level metrics such as L1, RMSE and SSIM, which measure the pixel similarity between the transferred image and the input content image but ignore the feature similarity that is closer to human perception, our model performs similarly to the other methods. On the perception-level metrics FID and LPIPS, our method improves on the other methods.

Table 1. Quantitative evaluation on the whole dataset (lower is better for L1, RMSE, LPIPS and FID; higher is better for SSIM).

Methods        1-N   Training   L1       RMSE     SSIM     LPIPS    FID
Zi2zi [15]           Paired     0.4043   4.7086   0.2343   0.330    177.536
Unet-GAN [27]        Paired     0.3576   4.2681   0.2257   0.4351   97.054
Cycle-GAN [5]        Unpaired   0.3855   4.5587   0.2595   0.3862   199.208
StarGAN [8]          Unpaired   0.3943   4.5023   0.2443   0.3605   53.771
DG-Font [40]         Unpaired   0.3796   4.4263   0.2640   0.3761   44.156
Ours                 Unpaired   0.3323   4.3803   0.2749   0.3221   39.447
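For reference, the two simplest pixel-level metrics in Table 1 can be computed as below. This is a hedged sketch: the paper does not state the pixel range or averaging used for L1 and RMSE, so values from this code are not directly comparable to Table 1. SSIM, LPIPS and FID require windowed statistics or pretrained networks and are not shown.

```python
import math
from typing import List, Tuple

Image = List[List[float]]  # grayscale image, rows x columns


def _pixel_pairs(a: Image, b: Image) -> List[Tuple[float, float]]:
    return [(p, q) for row_a, row_b in zip(a, b) for p, q in zip(row_a, row_b)]


def l1_error(a: Image, b: Image) -> float:
    """Mean absolute pixel difference between two equally sized images."""
    pairs = _pixel_pairs(a, b)
    return sum(abs(p - q) for p, q in pairs) / len(pairs)


def rmse(a: Image, b: Image) -> float:
    """Root-mean-square pixel difference."""
    pairs = _pixel_pairs(a, b)
    return math.sqrt(sum((p - q) ** 2 for p, q in pairs) / len(pairs))


generated = [[0.0, 0.0], [0.0, 0.0]]
reference = [[1.0, 1.0], [1.0, 3.0]]
print(l1_error(generated, reference), rmse(generated, reference))
```

RMSE squares the per-pixel differences before averaging, so a single badly wrong pixel (the 3.0 above) raises RMSE more than L1.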

Qualitative comparison. We compare our model with the other methods on print, cartoon, and imitation-handwriting fonts, as shown in Fig. 4. (1) Zi2zi transfers simple fonts well, but the smoothness of the generated glyphs needs improvement, e.g., “无” and “相” in Fig. 4(a, (1,3)). On some challenging fonts it produces noise and broken or missing strokes, as shown in Fig. 4(b, (1–5)). (2) Unet-GAN works better on simple fonts, but its images for challenging fonts are of poor quality; the characters in Fig. 4(b, (1,5)) have broken strokes. (3) Although Cycle-GAN transfers Chinese character styles without paired data, it is deficient in both the completeness of the transferred content and the accuracy of the transferred style: in Fig. 4(a, 1), strokes of “接” and “莲” are missing; in Fig. 4(b, 1), the image contains a lot of noise; and in Fig. 4(b, 2), mode collapse occurs. (4) Compared with Cycle-GAN, StarGAN produces clearer character content. (5) Compared with the previous methods, DG-Font greatly improves the smoothness of the characters and the accuracy of style transfer. Our model preserves the integrity of the strokes, and its transferred style is very close to the target style. There are still failure cases, however: for example, “总” in Fig. 4(a, 4) does not achieve the expected effect on the unpaired dataset.


5 Ablation Experiment
In this section, we add the components to the model one at a time and analyze the impact of each: the skeleton-based feature skip connection, attention-based feature extraction, and the CARAFE operator. The ablation experiments were performed on 20 fonts, and we report average results. The baseline is the DG-Font model (Fig. 5).

Fig. 5. Effect of different components in our method. We add different parts into our baseline successively.

The Validity of Adding the Original Chinese Character Skeleton Feature to the Content Generator. Figure 5(b) shows the transfer results obtained by decoding after skeleton feature information is added to the feature skip connection module. The L1, RMSE and SSIM results are similar, but LPIPS improves significantly. LPIPS is closer to human perception, and a lower value indicates that two images are more similar, which shows that the skeleton features of the original characters influence the transferred image. Effectiveness of Feature Extraction Based on the Attention Model. Figure 5(c) shows that after the content encoder and style encoder are replaced, the quality of the transferred images improves significantly, and so does the FID score. Effectiveness of Replacing the Upsampling Operator in the Original Decoder with CARAFE. As can be seen from Fig. 5(d), the completeness and quality of the transferred content improve considerably, and most metrics improve. The detailed results are shown in Table 2:


Table 2. Comparison of results using different upsampling operators.

Methods         L1      RMSE   SSIM    LPIPS   FID
Upsampling      0.376   4.38   0.286   0.358   42.66
Deconvolution   0.364   4.33   0.282   0.362   49.09
CARAFE          0.338   4.32   0.282   0.343   39.80
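The content-aware reassembly step that distinguishes CARAFE from fixed upsampling can be sketched on a single channel as follows. Each output location takes the k_up × k_up neighborhood of its source location in the input and combines it with a softmax-normalized kernel specific to that output location. In CARAFE these kernels are predicted from the input features by a small convolutional encoder; here they come from a hypothetical kernel_logits callable, so this sketches the reassembly only, not the full operator.

```python
import math
from typing import Callable, List

Grid = List[List[float]]  # one channel, rows x columns


def softmax(xs: List[float]) -> List[float]:
    m = max(xs)
    e = [math.exp(v - m) for v in xs]
    s = sum(e)
    return [v / s for v in e]


def carafe_reassemble(x: Grid, kernel_logits: Callable[[int, int], List[float]],
                      scale: int = 2, k_up: int = 3) -> Grid:
    """Upsample x by `scale` using a separate normalized k_up x k_up kernel
    for every output location (zero padding at the borders)."""
    h, w = len(x), len(x[0])
    r = k_up // 2
    out: Grid = []
    for oi in range(h * scale):
        row = []
        for oj in range(w * scale):
            si, sj = oi // scale, oj // scale          # source location in x
            weights = softmax(kernel_logits(oi, oj))   # k_up * k_up values summing to 1
            acc, idx = 0.0, 0
            for di in range(-r, r + 1):
                for dj in range(-r, r + 1):
                    ni, nj = si + di, sj + dj
                    if 0 <= ni < h and 0 <= nj < w:    # positions outside x count as 0
                        acc += weights[idx] * x[ni][nj]
                    idx += 1
            row.append(acc)
        out.append(row)
    return out


def center_heavy(oi: int, oj: int) -> List[float]:
    """Hypothetical 3x3 kernel logits with almost all weight on the center tap."""
    return [0.0] * 4 + [50.0] + [0.0] * 4


x = [[1.0, 2.0], [3.0, 4.0]]
for row in carafe_reassemble(x, center_heavy):
    print([round(v, 3) for v in row])
```

With center-heavy kernels the result is close to nearest-neighbor upsampling; learned kernels let each output pixel instead blend the neighborhood according to local content.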

6 Conclusion
Our model achieves mutual transfer of Chinese character styles through unpaired training. To preserve the integrity of the transferred character content, we propose a generator based on skeleton information, and to make the transferred style closer to the target style, we adopt attention-based feature extraction. Extensive Chinese font generation experiments verify the effectiveness of the proposed model. However, strokes are still missing for some characters with complex structures, a problem we will address in future work.

References
1. Gao, Y., Wu, J.: GAN-based unpaired Chinese character image translation via skeleton transformation and stroke rendering. In: National Conference on Artificial Intelligence. Association for the Advancement of Artificial Intelligence (AAAI) (2020)
2. Sun, D., Zhang, Q., Yang, J.: Pyramid embedded generative adversarial network for automated font generation. In: IEEE (2018)
3. Heusel, M., Ramsauer, H., Unterthiner, T., et al.: GANs trained by a two time-scale update rule converge to a local Nash equilibrium. In: 30th Proceedings of International Conference on Advances in Neural Information Processing Systems (2017)
4. Zhu, J.Y., Park, T., Isola, P.: Unpaired image-to-image translation using cycle-consistent adversarial networks. In: 2017 IEEE International Conference on Computer Vision (ICCV), pp. 2242–2251 (2017)
5. Cha, J., Chun, S., Lee, G., Lee, B., Kim, S., Lee, H.: Few-shot compositional font generation with dual memory. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12364, pp. 735–751. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58529-7_43
6. Park, S., Chun, S., Cha, J., et al.: Few-shot font generation with localized style representations and factorization. Proc. AAAI Conf. Artif. Intell. 35(3), 2393–2402 (2021)
7. Xie, Y., Chen, X., Sun, L., et al.: DG-Font: deformable generative networks for unsupervised font generation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 5130–5140 (2021)
8. Dai, J., Qi, H., Xiong, Y., et al.: Deformable convolutional networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 764–773 (2017)
9. Isola, P., Zhu, J.Y., Zhou, T., et al.: Image-to-image translation with conditional adversarial networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1125–1134 (2017)


10. Mirza, M., Osindero, S.: Conditional generative adversarial nets. arXiv preprint arXiv:1411.1784 (2014)
11. Choi, Y., Choi, M., Kim, M., et al.: StarGAN: unified generative adversarial networks for multi-domain image-to-image translation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 8789–8797 (2018)
12. Karras, T., Laine, S., Aittala, M., et al.: Analyzing and improving the image quality of StyleGAN. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8110–8119 (2020)
13. Rewrite. https://github.com/kaonashi.tyc/rewrite
14. Ioffe, S., Szegedy, C.: Batch normalization: accelerating deep network training by reducing internal covariate shift. In: International Conference on Machine Learning. PMLR, pp. 448–456 (2015)
15. Lin, Y., Yuan, H., Lin, L.: Chinese typography transfer model based on generative adversarial network. In: 2020 Chinese Automation Congress (CAC), pp. 7005–7010. IEEE (2020)
16. Sun, D., Zhang, Q., Yang, J.: Pyramid embedded generative adversarial network for automated font generation. In: 2018 24th International Conference on Pattern Recognition (ICPR), pp. 976–981. IEEE (2018)
17. Jiang, Y., Lian, Z., Tang, Y., et al.: DCFont: an end-to-end deep Chinese font generation system. In: SIGGRAPH Asia 2017 Technical Briefs, pp. 1–4 (2017)
18. Chang, B., Zhang, Q., Pan, S., et al.: Generating handwritten Chinese characters using CycleGAN. In: 2018 IEEE Winter Conference on Applications of Computer Vision (WACV), pp. 199–207. IEEE (2018)
19. Jiang, Y., Lian, Z., Tang, Y., et al.: SCFont: structure-guided Chinese font generation via deep stacked networks. Proc. AAAI Conf. Artif. Intell. 33(01), 4015–4022 (2019)

IL-YOLOv5: A Ship Detection Method Based on Incremental Learning
Wenzheng Liu¹ and Yaojie Chen¹,² (corresponding author)
¹ School of Computer Science and Technology, Wuhan University of Science and Technology, Wuhan, China
² Hubei Province Key Laboratory of Intelligent Information Processing and Real-time Industrial System, Wuhan University of Science and Technology, Wuhan, China

Abstract. Traditional deep-learning-based object detection algorithms must be trained on a large ship dataset to achieve good detection performance. However, collecting many samples is difficult, and training only on newly added samples can cause catastrophic forgetting. To address these issues, this paper proposes an incremental-learning-based ship detection method called IL-YOLOv5 (Incremental Learning YOLOv5). IL-YOLOv5 employs an improved BiFPN and a coordinate attention mechanism to strengthen the extraction of ship-related features. A base dataset of ship classes is first trained to create a standard ship detection model, and the incremental learning method is then used to continuously update the model and learn the characteristics of new ship samples. The experimental results show that the model detects ships well on both the old and the new ship datasets: the accuracy reaches nearly 72%, and mAP@0.5 increases by 5.3%, effectively addressing the difficulty of collecting large datasets and the problem of catastrophic forgetting.
Keywords: Incremental learning · IL-YOLOv5 · Ship detection · Attention mechanism · Feature extraction

1 Introduction
Object detection is an important research task in computer vision that aims to automatically identify and locate objects of interest in images or videos. Object detection techniques have been widely applied in face recognition, vehicle recognition, intelligent transportation, and robot vision; for example, Gong et al. [1] achieved promising results in vehicle recognition. Detecting ships at sea is a challenging object detection task, and many researchers in China and abroad have proposed feasible ship detection algorithms [2–4]. Deep-learning-based detectors such as Faster R-CNN, YOLO and SSD have gradually become mainstream. Cui et al. [5] achieved good ship detection accuracy and recognition rates with the Fast R-CNN network. However, two-stage algorithms such as Faster R-CNN still cannot meet real-time requirements because of their limited detection speed.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 588–600, 2023. https://doi.org/10.1007/978-981-99-4742-3_49
IL-YOLOv5
Liu et al. [6] improved the precision and speed of the SSD network by introducing L2 regularization to extract more ship feature information and to balance the feature information obtained at each layer of the network. Zhou et al. [7], building on YOLOv5 [8], replaced standard convolution in the backbone with mixed depthwise convolution, introduced an attention mechanism, and used Focal Loss [9] and CIoU Loss to improve detection. However, in current research on ship detection at sea, ship detection technology still suffers from false alarms, missed detections, and low accuracy in practical applications. It also faces a serious problem: reaching a high level of recognition and detection requires training on a large sample dataset, and collecting such data by hand takes considerable time and manpower. Traditional ship detection methods can achieve moderate results with a small sample dataset, but when new samples are added to improve performance, training only on the new samples without the old ones causes catastrophic forgetting on the old dataset. Incremental learning can continuously learn new features from new samples while retaining previously learned features, achieving "non-forgetting learning". Combining traditional ship detection with incremental learning, this paper proposes an incremental learning YOLOv5 (IL-YOLOv5) ship detection method. First, an image-feature skip connection is introduced between the feature extraction layer and the classification layer of the network, and a coordinate attention mechanism is added to extract the important information in ship images. Then, a base dataset of ships at sea is trained to create a standard ship detection model.
Finally, the incremental learning method is used to modify the model so that it can continuously learn new ship sample features and achieve good detection for both new and old ship datasets.

2 Related Works
2.1 Object Detection
Existing deep object detectors generally fall into two categories: (1) two-stage detectors and (2) one-stage detectors. Faster R-CNN is a classic two-stage model that uses a Region Proposal Network (RPN) to propose candidate object regions and then classifies and refines these proposals. For small objects, however, it is difficult to match the size of the candidate regions to the size of the objects, which often produces too many or too few candidate regions and therefore less accurate detection results. YOLO [8, 12, 13, 15] and SSD are one-stage detectors; compared with SSD, YOLO is simpler and easier to train and tune. Duan et al. [14] addressed the low detection and recognition rates of YOLOv2 by using support vector machines to classify the detected targets, which greatly improved both rates. Li et al. [15] introduced an attention mechanism into YOLOv3 and replaced pooling with convolution, improving the detection of small targets. Kong et al. [16] improved the K-means clustering algorithm on top of YOLOv4 to redesign the prior


W. Liu and Y. Chen

anchor boxes and used data augmentation to increase the number of small samples in the imbalanced dataset, improving the accuracy of ship target detection; finally, Softer-NMS (Softer Non-Maximum Suppression) was introduced to improve detection capability and localization accuracy for densely packed ships. Han et al. [17], also building on YOLOv4, first introduced dilated convolutions into the Spatial Pyramid Pooling (SPP) module [18] to help the model capture spatial information about small ships, then used an attention mechanism and residual ideas to improve the feature pyramid for better feature extraction, and finally increased inference speed by fusing convolution kernels at detection time. Nevertheless, accurately extracting feature information from overlapping ship targets remains challenging. To address this issue, Li et al. [19] introduced an attention mechanism into YOLOv5 and used PolyLoss to obtain ship information. However, YOLOv5 also suffers from catastrophic forgetting, a common issue in traditional object detection methods.
2.2 Incremental Learning Method
To alleviate these problems of YOLOv5, this paper employs incremental learning. Incremental learning is a machine learning approach that mimics the way humans learn: after a period of learning, the model remembers previous knowledge and continues to learn new knowledge without forgetting the old. Initially, the model parameters are set to w_0. Suppose we have n data samples (x_1, y_1), (x_2, y_2), …, (x_n, y_n), where x_i is the input and y_i is its label. At time step t, a new sample (x_{n+t}, y_{n+t}) arrives, and the current parameters w_t are used to make a prediction ŷ_{n+t}. The difference between the prediction and the true label is then used to update the parameters: a loss function L(y_{n+t}, ŷ_{n+t}) measures the prediction error, and a method such as gradient descent minimizes it. The updated parameters w_{t+1} are computed as

$$ w_{t+1} = w_t - \alpha_t \nabla L\left(y_{n+t}, \hat{y}_{n+t}\right) \tag{1} $$

where α_t is the learning rate and ∇L(y_{n+t}, ŷ_{n+t}) is the gradient of the loss function with respect to the model parameters. By repeating this process, the model parameters are updated gradually without retraining the entire model, which achieves incremental learning.
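The update rule in Eq. (1) can be illustrated with a toy model. The sketch below assumes a one-dimensional linear model ŷ = w·x + b with squared-error loss L = ½(ŷ − y)², a constant learning rate, and a synthetic stream of samples from y = 2x + 1. It shows only the per-sample update, not the YOLOv5 training used in the paper, and on its own it does not guard against catastrophic forgetting.

```python
from typing import Iterable, Tuple


def sgd_step(w: float, b: float, x: float, y: float, lr: float) -> Tuple[float, float]:
    """One update w_{t+1} = w_t - lr * grad L for a 1-D linear model y_hat = w*x + b
    with squared-error loss L = 0.5 * (y_hat - y) ** 2."""
    y_hat = w * x + b
    err = y_hat - y                          # dL/dy_hat
    return w - lr * err * x, b - lr * err    # dL/dw = err * x, dL/db = err


def incremental_fit(stream: Iterable[Tuple[float, float]], w: float = 0.0,
                    b: float = 0.0, lr: float = 0.1) -> Tuple[float, float]:
    """Consume samples one at a time, in arrival order, without storing them."""
    for x, y in stream:
        w, b = sgd_step(w, b, x, y, lr)
    return w, b


# Samples of y = 2x + 1 arriving as a stream (the same four points recur).
stream = [(x, 2 * x + 1) for x in (0.0, 0.5, 1.0, 1.5)] * 500
w, b = incremental_fit(stream)
print(round(w, 4), round(b, 4))
```

Each call to sgd_step touches only the current sample, which is what lets the parameters be refined as data arrive instead of retraining from scratch.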

3 Methodology
This paper proposes the IL-YOLOv5 ship detection method, which combines the YOLOv5 network with incremental learning. The ship dataset used in this paper is small, so to extract more image feature information from it, the network is first improved to strengthen feature extraction, and the incremental learning method is then used to update the model so that it can continuously learn the feature information


of new ship sample datasets. In this paper, the CA [20] module is added to the original YOLOv5 network model, and extra feature information extraction is performed between the YOLOv5-Backbone structure and the YOLOv5-Neck structure to form an improved structure based on the BiFPN [21] structure, as shown in Fig. 1.

Fig. 1. IL-YOLOv5 network structure

3.1 Image Feature Information Enhancement
3.1.1 Improved BiFPN
Because a large ship dataset is difficult to collect, more effective ship image features must be extracted from a small dataset. The YOLOv5 network uses the PANet structure for multi-scale information fusion; the proposed IL-YOLOv5 network replaces PANet with a BiFPN structure and further modifies the original BiFPN to better suit small ship datasets, as shown in Fig. 2. The modification makes the BiFPN smaller and more appropriate for small-sample ship datasets.

Fig. 2. Improved BiFPN

This method achieves bidirectional fusion of deep and shallow features, enhancing the transfer of feature information between different network layers. It can combine


small-scale, medium-scale, and large-scale ship features in the image, so that ship information is represented in a more balanced way in the model weights. The P5 layer extracts large-scale ship features, the P4 layer medium-scale features, and the P3 layer small-scale features. This method significantly improves the ship detection accuracy of the YOLOv5 algorithm. The improved BiFPN learns the importance of the different input features and fuses them accordingly, using

$$ O = \sum_i \frac{w_i}{\epsilon + \sum_j w_j} \cdot I_i \tag{2} $$
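Eq. (2) is the fast normalized fusion used in BiFPN. A minimal sketch with one scalar weight per input follows. Clamping the weights at zero with a ReLU comes from the original EfficientDet formulation of BiFPN and is an assumption here, since the text only states that the weights are learnable. The Resize and Conv steps of Eqs. (3)–(4) are left out, so the inputs must already have the same shape.

```python
from typing import List

Feature = List[float]  # one feature map, flattened


def fast_normalized_fusion(inputs: List[Feature], weights: List[float],
                           eps: float = 1e-4) -> Feature:
    """Eq. (2): O = sum_i w_i * I_i / (eps + sum_j w_j)."""
    w = [max(0.0, wi) for wi in weights]  # ReLU keeps w_i >= 0 (as in EfficientDet)
    denom = eps + sum(w)
    return [sum(wi * feat[p] for wi, feat in zip(w, inputs)) / denom
            for p in range(len(inputs[0]))]


# Three-input fusion in the style of Eq. (4), assuming the inputs have already
# been resized to the same shape and leaving out the trailing Conv.
p4_in, p4_td, p3_out_resized = [1.0, 2.0], [2.0, 2.0], [3.0, 0.0]
print(fast_normalized_fusion([p4_in, p4_td, p3_out_resized], [1.0, 0.5, 0.5]))
```

Dividing by the weight sum keeps each normalized weight between 0 and 1 without computing exponentials, which is why this is cheaper than a softmax over the weights.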

The learnable weights w_i and w_j can be scalars (per feature), vectors (per channel), or multidimensional tensors (per pixel). I_i is the i-th input feature, and ε = 0.0001 is a small constant that prevents numerical instability. Like Softmax, this normalization constrains each weight to the range 0 to 1, but it does so without computing exponentials, which improves training efficiency and speed. In this paper, an additional skip connection is added between the feature extraction layer and the classification layer to strengthen image feature extraction, which constitutes the improvement to BiFPN. The improved BiFPN is computed as follows:

$$ P_4^{td} = \mathrm{Conv}\left( \frac{w_1 \cdot P_4^{in} + w_2 \cdot \mathrm{Resize}(P_5^{in})}{w_1 + w_2 + \epsilon} \right) \tag{3} $$

$$ P_4^{out} = \mathrm{Conv}\left( \frac{w_1' \cdot P_4^{in} + w_2' \cdot P_4^{td} + w_3' \cdot \mathrm{Resize}(P_3^{out})}{w_1' + w_2' + w_3' + \epsilon} \right) \tag{4} $$

The Resize operation performs downsampling or upsampling, and w_i is a weight learned at each layer to express the importance of each feature during fusion. P_4^{td} is the intermediate feature of level 4 on the top-down path, with inputs P_4^{in} and P_5^{in}. P_4^{out} is the output feature of level 4 on the bottom-up path, with three inputs: P_4^{in}, P_4^{td}, and P_3^{out}.
3.1.2 CA Attention Mechanism
YOLOv5 is a general-purpose object detection model designed to detect many types of objects. When the backbone of YOLOv5 extracts image features, multiple convolution kernels extract local information from the input image, so both the ship and redundant image information are extracted. For ships on the water, this redundant information is generally the water-surface background, which is of little use yet has a significant influence on the extracted features. To make the model focus on extracting ship information, this paper

IL-YOLOv5


proposes improving the YOLOv5 model by adding a Coordinate Attention (CA) mechanism to the network. The CA mechanism is simple, flexible, and efficient. It guides the model to focus on the ship and its image coordinate information while suppressing water-surface background information. In addition, adding a CA module to a lightweight network incurs almost no extra computational overhead. The specific structure of a CA module is shown in Fig. 3.

Fig. 3. CA module

A CA module can be regarded as a computing unit that enhances feature representation. In this study, a CA module was added between each of the three Detect structures and the YOLOv5-Neck structure in the YOLOv5-Head, as shown in Fig. 3. The three Detect structures process small-scale, medium-scale, and large-scale information in the image, respectively. The intermediate tensor output of the C3 module is used as the input of the CA module. Coordinate attention then emphasizes the ship-related features and suppresses the background information. The output of the CA module is then fed into the YOLOv5-Neck structure for subsequent processing. This modification helps the model focus on extracting ship information, resulting in better detection performance. The CA module takes

$$ X = [x_1, x_2, \cdots, x_C] \in \mathbb{R}^{C \times H \times W} \tag{5} $$

as its input. Its output is a tensor of the same size with enhanced expressive ability:

$$ Y = [y_1, y_2, \cdots, y_C] \in \mathbb{R}^{C \times H \times W} \tag{6} $$

Here, X and Y are feature tensors of a network layer, C is the number of channels, H is the height, and W is the width. When the intermediate tensor $X_i$ (i = 1, 2, 3) at each of the three scales enters the CA module, each channel is first encoded along the horizontal and vertical coordinates of $X_i$. This encoding uses pooling kernels of sizes (H, 1) and (1, W). The output of channel c at height h is then expressed as follows:

$$ z_c^h(h) = \frac{1}{W} \sum_{0 \le i < W} x_c(h, i) \tag{7} $$

$$ th = u_t \cdot \max_{i,j} \mathrm{CAM}_k(i, j) \tag{6} $$

Here, i and j denote the pixel coordinates. The value of $u_t$ decreases adaptively as the training epoch increases, until it reaches the lower bound l:

$$ u_t = \max(l,\ u_{t-1} \cdot d) \tag{7} $$

Here, t is the epoch and d denotes the decay rate. We set d = 0.985 and l = 0.65 in practice. Finally, the progressive dropout attention map $A_{pda}$ is computed as follows:

$$ A_{pda} = \frac{1}{K} \sum_{k=1}^{K} \mathrm{CAM}_k \tag{8} $$

and the final feature map is obtained as

$$ F = A_{pda} \otimes F_{ca} \tag{9} $$

The feature F is then processed by global average pooling (GAP) and fed into a fully connected layer (FC-Layer-1) to obtain the predicted result:

$$ y = FC_1(\mathrm{GAP}(F)) \tag{10} $$
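The quantities in Eqs. (6)–(8) can be illustrated with a short plain-Python sketch. The text gives only the decay rate d = 0.985 and the lower bound l = 0.65, so the starting value u_0 = 1.0 is a hypothetical choice. CAMs are represented as nested lists, and the function names are illustrative. The step that applies th to the CAM is not part of this excerpt and is left out.

```python
from typing import List

Map = List[List[float]]  # a CAM stored as rows of floats

def decay_u(u0: float, epochs: int, d: float = 0.985, lower: float = 0.65) -> float:
    """Eq. (7): apply u_t = max(l, u_{t-1} * d) once per epoch, starting from u0."""
    u = u0
    for _ in range(epochs):
        u = max(lower, u * d)
    return u

def dropout_threshold(cam: Map, u_t: float) -> float:
    """Eq. (6): th = u_t * max over (i, j) of CAM_k(i, j)."""
    return u_t * max(max(row) for row in cam)

def average_maps(cams: List[Map]) -> Map:
    """Eq. (8): A_pda = (1/K) * sum_k CAM_k, element-wise over K maps."""
    k = len(cams)
    h, w = len(cams[0]), len(cams[0][0])
    return [[sum(cam[i][j] for cam in cams) / k for j in range(w)] for i in range(h)]

# u0 = 1.0 is hypothetical; u stops shrinking once it reaches the lower bound.
for epoch in (0, 1, 10, 40):
    print(epoch, round(decay_u(1.0, epoch), 4))
```

With d = 0.985, u falls below 0.65 after roughly 29 epochs, so from then on the threshold in Eq. (6) stays at 0.65 times the CAM peak.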


X. Lan et al.

2.3 Class Re-Activation Mapping
Binary cross-entropy (BCE) is widely used in CAM-based methods. Because BCE computes the classification loss for each class independently, a pixel may respond to more than one class. Therefore, we reactivate the CAM by training another fully connected layer (FC-Layer-2) with a softmax cross-entropy (SCE) loss. The main idea is as follows. After the multi-label classification network converges, we decompose the multi-label vector into multiple single-label vectors (as shown in Fig. 1). We then extract the CAM of each class and multiply it with each channel of the preceding feature map to generate a single-label feature:

$$ F_k = \mathrm{CAM}_k \otimes F_{ca} \tag{11} $$

Here, $F_{ca}$ and $F_k$ denote the feature maps before and after the multiplication. The predicted classification result $y_k$, the loss $L_{sce}$, and the total loss $L_{total}$ are defined as follows:

$$ y_k = FC_2(\mathrm{GAP}(F_k)) \tag{12} $$

$$ L_{sce} = -\frac{1}{\sum_{k=1}^{K} y_k} \sum_{k=1}^{K} y_k \log \frac{\exp(y_k)}{\sum_{j} \exp(y_j)} \tag{13} $$

$$ L_{total} = L_{bce} + L_{sce} \tag{14} $$
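A minimal plain-Python sketch of the SCE term in Eq. (13) follows. Eq. (13) uses $y_k$ both outside and inside the logarithm, so the sketch has to pick a reading:

- The factor outside the logarithm is taken to be the binary image-level label of class k.
- The softmax is taken over the K class scores $FC_2(\mathrm{GAP}(F_k))$ of the single-label feature $F_k$.

Both points are our interpretation. The BCE term of Eq. (14) is not shown, and the function names are illustrative.

```python
import math
from typing import List, Sequence

def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log(exp(z_j) / sum_i exp(z_i)) for every j."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]

def sce_loss(labels: Sequence[int], logits_per_class: Sequence[Sequence[float]]) -> float:
    """Softmax cross-entropy in the spirit of Eq. (13).

    labels[k]           -- binary image-level label of class k (assumed reading of
                           the y_k outside the logarithm).
    logits_per_class[k] -- K class scores FC_2(GAP(F_k)) of the single-label feature F_k.
    """
    n_pos = sum(labels)
    if n_pos == 0:
        return 0.0  # assumption: no class present, so nothing to re-activate
    total = 0.0
    for k, y_k in enumerate(labels):
        if y_k:
            # The k-th single-label feature should be classified as class k.
            total += log_softmax(logits_per_class[k])[k]
    return -total / n_pos

# Two classes, only class 0 present; its single-label feature is undecided.
print(sce_loss([1, 0], [[0.0, 0.0], [2.0, -1.0]]))  # log(2) ≈ 0.6931
```

Unlike BCE, the softmax makes the classes compete for each single-label feature. This is the property the section relies on to stop one pixel from responding to several classes.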

Once the model converges again, we extract the class re-activation maps (ReCAM) as follows:

$$ \mathrm{ReCAM}_k = \frac{\mathrm{ReLU}(A_k)}{\max(\mathrm{ReLU}(A_k))}, \quad A_k = w_k * F_{ca} \tag{15} $$

Here, $w_k$ denotes the weight of FC-Layer-1 after re-activation. Since the classification result is more reliable than the CAM, we consider a tissue absent when its prediction $y_k$ is below a threshold τ:

$$ \mathrm{ReCAM}_k = \begin{cases} \mathrm{ReCAM}_k, & y_k \ge \tau \\ 0, & y_k < \tau \end{cases} \tag{16} $$

In practice, we set τ to 0.15.

2.4 Implementation Details
In our experiments, all neural networks are implemented in PyTorch. The model is trained on a server with an NVIDIA GTX 1080 Ti GPU. WiderResNet38 is used as the classification backbone and is pre-trained on the ILSVRC 2012 classification dataset [12].

The specific experimental settings are as follows. In the second part, the batch size is set to 20 and the number of training epochs to 20. All input images are augmented by random horizontal and vertical flips, each with a probability of 0.5. We use a learning rate of 0.01 with a polynomial decay policy. In the third part, the batch size, number of training epochs, and learning rate are set to 16, 10, and 0.005, respectively.

We evaluate the proposed model using the following metrics:
- IoU for each category
- mean IoU (MIoU)
- frequency-weighted IoU (FwIoU)
- pixel-level accuracy (Acc)

In all experiments, the white background inside the alveoli was excluded when computing performance.
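The metrics listed above can all be computed from a pixel-level confusion matrix. The sketch below assumes the standard definitions:

- IoU = TP / (TP + FP + FN).
- MIoU is the mean IoU over the classes that occur.
- FwIoU weights each class IoU by its ground-truth pixel frequency.
- Acc is the fraction of correctly labelled pixels.

It also assumes a hypothetical ignore label 255 for the excluded white background. The paper does not state its exact implementation, and the function names are illustrative.

```python
from typing import Dict, List, Sequence

def confusion_matrix(gt: Sequence[int], pred: Sequence[int],
                     num_classes: int, ignore_index: int = 255) -> List[List[int]]:
    """cm[g][p] counts pixels with ground truth g that were predicted as p."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for g, p in zip(gt, pred):
        if g == ignore_index:
            continue  # e.g. excluded white background (hypothetical label value)
        cm[g][p] += 1
    return cm

def segmentation_metrics(cm: List[List[int]]) -> Dict[str, object]:
    n = len(cm)
    total = sum(map(sum, cm))
    if total == 0:
        raise ValueError("no evaluated pixels")
    ious, fwiou = [], 0.0
    for c in range(n):
        tp = cm[c][c]
        gt_c = sum(cm[c])                          # pixels whose label is c
        pred_c = sum(cm[r][c] for r in range(n))   # pixels predicted as c
        union = gt_c + pred_c - tp
        iou = tp / union if union else float("nan")
        ious.append(iou)
        if union:
            fwiou += gt_c / total * iou            # weight by class frequency
    present = [v for v in ious if v == v]          # drop NaN (class absent)
    return {
        "IoU": ious,
        "MIoU": sum(present) / len(present),
        "FwIoU": fwiou,
        "Acc": sum(cm[c][c] for c in range(n)) / total,
    }

cm = confusion_matrix(gt=[0, 0, 1, 1, 255], pred=[0, 1, 1, 1, 0], num_classes=2)
print(segmentation_metrics(cm))
```

For the four LUAD-HistoSeg tissues, num_classes would be 4 (TE, NEC, LYM, TAS).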

A Weakly Supervised Semantic Segmentation Method


2.5 Datasets
We evaluate our method on LUAD-HistoSeg, a histopathology image dataset proposed by Han et al. [9]. The goal is to achieve pixel-level semantic segmentation of four tissue types using only image-level labels: tumor epithelium (TE), tumor-associated stroma (TAS), necrosis (NEC), and lymphocytes (LYM). LUAD-HistoSeg is split into three sets, and each patch is 224 × 224 pixels:
- a training set (16,678 patches with image-level labels)
- a validation set (300 patches with pixel-level labels)
- a test set (307 patches with pixel-level labels)

3 Results
3.1 Ablation Experiments
To verify the effectiveness of our method, we conducted seven groups of experiments. The first four groups compare ResNet50, ResNet101, ResNet152, and WiderResNet38 as backbones. Based on these results, we chose WiderResNet38 and used it for the remaining ablation experiments. Specifically, the experiments were designed as follows: (1) WiderResNet38 served as the baseline. (2) We integrated only the progressive dropout attention (PDA) into the baseline. (3) We jointly integrated PDA and channel attention (CA) into the baseline. (4) We jointly integrated PDA, CA, and ReCAM into the baseline.

Table 1. The effect of different modules in our method (in percentage). "WR38" denotes WiderResNet38.

Backbone    PDA  CA  ReCAM  TE     NEC    LYM    TAS    FwIoU  MIoU   Acc
ResNet50    –    –   –      69.29  73.86  71.00  64.81  68.23  69.76  81.05
ResNet101   –    –   –      69.34  72.57  67.05  64.08  67.32  68.26  80.42
ResNet152   –    –   –      71.54  69.31  66.14  67.20  69.04  68.55  81.70
WR38        –    –   –      72.86  72.69  71.30  68.95  71.19  71.45  83.11
WR38        √    –   –      75.67  78.08  73.70  69.70  73.32  74.26  84.51
WR38        √    √   –      76.49  81.32  75.41  71.36  74.81  76.15  85.50
WR38        √    √   √      76.50  82.20  75.70  71.74  75.05  76.53  85.66

The results of all modules are presented in Table 1.
- In the fifth group of experiments, the network with PDA achieved 73.32% FwIoU, 74.26% MIoU, and 84.51% Acc. These are improvements of 2.13, 2.81, and 1.40 percentage points, respectively, over the baseline.
- In the sixth group, the network integrating CA and PDA achieved 74.81% FwIoU, 76.15% MIoU, and 85.50% Acc. These are improvements of 1.49, 1.89, and 0.99 percentage points over the fifth group, which shows that adding CA clearly enhances the network's performance.
- In the seventh group, adding ReCAM improved the model's predictions for TAS and NEC, raising their IoU by 0.38 and 0.88 percentage points, respectively.

These experimental results validate the effectiveness of each module.

To further demonstrate the effectiveness of our method, we compared the CAMs generated by the different modules. As shown in Fig. 2(b), the baseline network locates the tissue effectively, but it fails to cover the entire object and its boundaries are unclear. In Fig. 2(c), the class activation maps are expanded compared with the baseline, although the activation of the features in the red dashed box is relatively low. Figure 2(d) shows increased activation of important features. Notably, ReCAM reduces erroneous activation in the CAM, yielding more accurate tissue boundaries, as shown in the white dashed box in Fig. 2(e).

Fig. 2. Each column is a 50% blend of the maps with the original image. Red dashed boxes mark the target tissue regions for easier observation.

3.2 Comparison Experiments
Table 2 compares the proposed method with advanced existing methods. Two of the methods are designed for histopathology images (Table 2 (1) and (6)), and the remaining four for natural images (Table 2 (2)–(5)). The proposed method achieves 76.53% MIoU, the highest among the compared methods and 0.93 percentage points higher than Han et al. [9]. Its FwIoU and Acc are slightly lower than those of Han et al. [9] (75.05% vs. 75.13% and 85.66% vs. 85.70%). Notably, our method achieves IoUs of 82.20% on NEC and 75.70% on LYM, increases of 2.88 and 2.29 percentage points, respectively. Overall, these experiments indicate that our method distinguishes lung adenocarcinoma tissues better using only image-level labels.

In Fig. 3, we compare the segmentation results of advanced methods with ours. Each row shows an example:
- the first column is the original image;
- the remaining columns show the ground truth, the results of Han et al. [9], and our results.

For clarity, the ground truth and the results are blended with the original images at a ratio of 50%, and black dashed boxes highlight the differences. As shown in the first row of Fig. 3, the original image contains TE, NEC, and LYM.


Table 2. Quantitative comparison with existing methods (in percentage).

No.  Methods           TE     NEC    LYM    TAS    FwIoU  MIoU   Acc
(1)  HistoSegNet [7]   45.59  36.30  58.28  50.82  48.54  47.75  65.97
(2)  SC-CAM [13]       68.29  64.28  62.06  61.79  64.74  64.10  78.70
(3)  OAA [14]          69.56  53.56  67.18  62.90  65.58  63.30  79.25
(4)  Grad-CAM++ [15]   72.90  74.18  67.93  66.01  69.78  70.26  81.97
(5)  GGNet [16]        71.85  73.30  69.09  67.26  68.89  70.38  82.21
(6)  Han et al. [9]    77.70  79.32  73.41  71.98  75.13  75.60  85.70
(7)  Ours              76.50  82.20  75.70  71.74  75.05  76.53  85.66

Compared with the other methods, our method better distinguishes NEC from TE. In the second row of Fig. 3, the method of Han et al. [9] shows limited ability to recognize LYM and TAS. Fig. 3 shows that our method has an advantage in recognizing NEC and LYM. This is also reflected in the numerical results, with IoUs of 82.20% on NEC and 75.70% on LYM. Overall, our method achieves the best MIoU and visually more accurate segmentation than the compared methods.

Fig. 3. Qualitative results of our proposed method.

4 Discussion
We selected several samples for discussion, shown in Fig. 4. Each column depicts an example: the first row is the original image, the second row the ground truth, and the third row the result of our method.
- In examples 1 and 2, the tissue boundaries are relatively clear and the shapes relatively simple, so our method segments them well.
- In example 3, the LYM is scattered, which makes it difficult for our method to identify the boundaries accurately.
- In example 4, the color and texture of NEC are similar to those of LYM, which causes the model to recognize the NEC in the lower right corner as LYM.

In summary, the analysis of the experimental and segmentation results indicates that the proposed modules enhance the network's ability to differentiate between tissues. Our approach also has advantages over advanced methods in locating tissues and predicting their boundaries.

Fig. 4. The figure shows the segmentation results of tissues that have shapes and textures with different degrees of complexity.

5 Conclusions
In this paper, we proposed a novel weakly supervised learning method for the semantic segmentation of lung adenocarcinoma histopathology images. Extensive experiments showed that our method accomplishes the pixel-level segmentation task using only image-level labels, which reduces annotation costs. In addition, our method achieved the highest MIoU among the compared methods. In the future, we will explore further ways to reduce the annotation cost of medical image datasets.


Acknowledgements. This work is supported by the Guangxi Key Laboratory of Image and Graphic Intelligent Processing (GIIP2004), the National Natural Science Foundation of China (61862017), and the Innovation Project of GUET (Guilin University of Electronic Technology) Graduate Education (2022YCXS063).

References
1. Dai, J., He, K., Sun, J.: BoxSup: exploiting bounding boxes to supervise convolutional networks for semantic segmentation. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 1635–1643 (2015)
2. Bearman, A., Russakovsky, O., Ferrari, V., Fei-Fei, L.: What's the point: semantic segmentation with point supervision. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9911, pp. 549–565. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46478-7_34
3. Lin, D., Dai, J., Jia, J., He, K., Sun, J.: ScribbleSup: scribble-supervised convolutional networks for semantic segmentation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3159–3167 (2016)
4. Kolesnikov, A., Lampert, C.H.: Seed, expand and constrain: three principles for weakly-supervised image segmentation. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9908, pp. 695–711. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46493-0_42
5. Wei, Y., Feng, J., Liang, X., Cheng, M.-M., Zhao, Y., Yan, S.: Object region mining with adversarial erasing: a simple classification to semantic segmentation approach. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1568–1576 (2017)
6. Singh, K.K., Lee, Y.J.: Hide-and-seek: forcing a network to be meticulous for weakly-supervised object and action localization. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 3524–3533 (2017)
7. Chan, L., Hosseini, M.S., Rowsell, C., Plataniotis, K.N., Damaskinos, S.: HistoSegNet: semantic segmentation of histological tissue type in whole slide images. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10662–10671 (2019)
8. Selvaraju, R.R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., Batra, D.: Grad-CAM: visual explanations from deep networks via gradient-based localization. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 618–626 (2017)
9. Han, C., et al.: Multi-layer pseudo-supervision for histopathology tissue semantic segmentation using patch-level classification labels. Med. Image Anal., 102487 (2022)
10. Chen, Z., Wang, T., Wu, X., Hua, X.-S., Zhang, H., Sun, Q.: Class re-activation maps for weakly-supervised semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 969–978 (2022)
11. Wu, Z., Shen, C., Van Den Hengel, A.: Wider or deeper: revisiting the ResNet model for visual recognition. Pattern Recognit. 90, 119–133 (2019)
12. Ahn, J., Kwak, S.: Learning pixel-level semantic affinity with image-level supervision for weakly supervised semantic segmentation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4981–4990 (2018)
13. Chang, Y.-T., et al.: Weakly-supervised semantic segmentation via subcategory exploration. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8991–9000 (2020)
14. Jiang, P.-T., Han, L.-H., Hou, Q., Cheng, M.-M., Wei, Y.: Online attention accumulation for weakly supervised semantic segmentation. IEEE Trans. Pattern Anal. Mach. Intell. 44(10), 7062–7077 (2021)
15. Chattopadhay, A., Sarkar, A., Howlader, P., Balasubramanian, V.N.: Grad-CAM++: generalized gradient-based visual explanations for deep convolutional networks. In: 2018 IEEE Winter Conference on Applications of Computer Vision (WACV), pp. 839–847. IEEE (2018)
16. Kweon, H., Yoon, S.-H., Kim, H., Park, D., Yoon, K.-J.: Unlocking the potential of ordinary classifier: class-specific adversarial erasing framework for weakly supervised semantic segmentation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 6994–7003 (2021)

Improved Lane Line Detection Algorithms Based on Incomplete Line Fitting
QingYu Ren(B), BingRui Zhao, TingShuo Jiang, and WeiZe Gao
School of Computer Science and Technology, Jilin University, Changchun 130000, China

Abstract. Lane detection is a crucial component of autonomous driving, as it enables vehicles to perceive their environment accurately and navigate roads safely. However, existing lane detection algorithms often struggle to balance real-time processing and robustness. In this paper, we propose an improved sliding window lane detection algorithm that addresses this challenge. Our algorithm first transforms the image into a bird's-eye view using inverse perspective transformation and then generates a binary map using the Sobel operator. It then dynamically selects between quadratic curves and straight lines to fit the missing portions of the left and right lane lines within the current window, and determines the position of the next sliding window. Finally, a quadratic curve or a straight line is dynamically selected to fit the entire lane trajectory using the least squares method. Our test results show that the algorithm detects lane lines quickly while maintaining high robustness and accuracy. This makes it a potentially valuable contribution to computer vision and image processing for autonomous driving.
Keywords: Lane detection · Computer vision · Image processing · Incomplete line fitting · Sliding windows

1 Introduction
Rapid advances in science and technology have significantly contributed to the growth of the automatic driving field [1]. One of the fundamental functions of automatic driving is lane detection. It indicates how far the vehicle deviates from the lane lines, which is essential for traffic safety. As a result, this function plays a significant role in the development of automatic driving technology. Existing lane detection methods can be divided into three categories: feature-based, model-based, and learning-based.
1. Feature-based methods generally rely on features such as lane line color [2], gradient changes [3], and edges [4] to segment or extract lane lines. They can be further subdivided into color-based, edge-based, and V-based [5] methods.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 699–709, 2023. https://doi.org/10.1007/978-981-99-4742-3_58


Q. Ren et al.

2. Model-based methods model lane lines as straight lines [6], parabolic curves [7, 8], hyperbolic curves [9], or spline curves [10, 11] according to the lane trajectory. The model parameters are then solved with methods such as least squares.
3. Learning-based methods mainly extract lane line information using neural networks [12, 13] or boosting methods [14], which improves detection accuracy over traditional methods. However, these algorithms tend to have poor real-time performance and high computational complexity.

Existing lane detection algorithms find it difficult to balance real-time performance and robustness. This paper therefore proposes a lane detection method based on an improved sliding window algorithm. The proposed method is characterized by high accuracy, robustness, and real-time performance, making it a valuable source of lane information for autonomous driving systems.

2 Algorithm Design
This paper proposes a lane detection algorithm based on an improved sliding window approach. It consists of four major parts: image preprocessing, lane detection, lane fitting, and lane display. Figure 1 shows a flowchart of the algorithm.

Image preprocessing includes camera calibration and perspective transformation. Because camera distortion distorts the original image, calibration uses a distortion parameter vector to correct the pixel coordinates of the original image. The input video frame is then transformed into a bird's-eye view using perspective transformation, which reduces the interference of irrelevant information with lane detection.

Lane detection includes Sobel-based edge detection and a sliding window search. Sobel-based edge detection produces a binary image from the original image. Next, the starting base points of the left and right lanes are determined from the pixel distribution histogram, and lane pixels are extracted with a vertical sliding window search. Each sliding window is defined by the midpoint of its bottom edge, its width, and its height. If there are no lane pixels on the upper edge of the current window, the lane is extended through the window. A quadratic curve equation is used when the vertical coordinate differences between adjacent pixel pairs are large, and a linear equation otherwise.

The least squares method is then used for lane fitting. If the vertical coordinate differences between adjacent pixel pairs are large, the lane is fitted with a quadratic curve equation; otherwise, a linear equation is used. Finally, the sliding window search results are overlaid onto the original video frame using perspective transformation to obtain the output frame.

3 Image Preprocessing
3.1 Camera Calibration
The imaging principle of a camera is based on transformations between coordinate systems. Limited precision in lens manufacturing and errors in the assembly process introduce distortion into the image. Camera distortion can be divided into two types: radial distortion and tangential distortion.

Improved Lane Line Detection Algorithms


Fig. 1. Algorithm flow chart

Radial distortion arises from irregular lens shapes and the modeling approach used, which cause the focal length to vary across different regions of the lens. With radial distortion, the distortion is zero at the center of the imaging device (the optical axis) and becomes more severe toward the edges.

Tangential distortion arises from the assembly of the camera, in which manufacturing defects leave the lens itself not parallel to the image plane.

The distortion coefficients are written as $D = [k_1, k_2, p_1, p_2, k_3]$, where $k_1$, $k_2$, and $k_3$ are radial distortion coefficients and $p_1$ and $p_2$ are tangential distortion coefficients. D was computed with OpenCV's calibrateCamera() function from 10 calibration images captured by the on-board camera, yielding

D = [−0.32740709, −0.36273064, 0.01802072, −0.00244828, 1.47735848]
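The coefficients in D follow OpenCV's ordering [k1, k2, p1, p2, k3]. The sketch below shows how OpenCV's distortion model maps ideal normalized camera coordinates (x, y) to distorted ones. It illustrates the model only. In practice the correction itself would typically be applied with cv2.undistort(image, camera_matrix, D). The camera matrix is not reported in the text and is not needed here, because the sketch works in normalized coordinates. The function name `distort` is illustrative.

```python
from typing import Sequence, Tuple

# Distortion vector reported above, in OpenCV order [k1, k2, p1, p2, k3].
D = [-0.32740709, -0.36273064, 0.01802072, -0.00244828, 1.47735848]

def distort(x: float, y: float, dist: Sequence[float] = D) -> Tuple[float, float]:
    """Map ideal normalized coordinates (x, y) to distorted normalized coordinates."""
    k1, k2, p1, p2, k3 = dist
    r2 = x * x + y * y
    radial = 1 + k1 * r2 + k2 * r2 ** 2 + k3 * r2 ** 3          # radial term
    x_d = x * radial + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)   # plus tangential terms
    y_d = y * radial + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
    return x_d, y_d

print(distort(0.0, 0.0))  # the optical axis is not displaced
print(distort(0.1, 0.0))
```

Because every distortion term contains x or y, the point on the optical axis maps to itself. This matches the statement above that distortion is zero at the center.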

3.2 Inverse Perspective Transformation
The captured video images contain information about the current lane as well as other irrelevant objects, and processing them directly would reduce the efficiency of lane detection. Irrelevant information, such as trees and signs, is mostly concentrated in the upper part of the image. Therefore, only the road information within the region of interest needs to be processed, which significantly improves detection efficiency. Four points are chosen in the original image and four corresponding points in the transformed image. Together, these eight points determine the transformation matrix from the original image to the bird's-eye view. The inverse perspective matrix can then be used to map points between the original and transformed images one to one.
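Four point pairs give eight linear equations in the eight unknown entries of a 3 × 3 perspective matrix whose bottom-right entry is fixed to 1. This is the system that OpenCV's cv2.getPerspectiveTransform(src, dst) solves. The plain-Python sketch below solves it directly with Gaussian elimination. The source and destination points are hypothetical, since the text does not list the points it uses, and the function names are illustrative. Swapping src and dst gives the inverse mapping.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float]
Matrix = List[List[float]]

def solve_linear(a: Matrix, b: List[float]) -> List[float]:
    """Solve a @ x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:
            raise ValueError("degenerate points (e.g. three of them collinear)")
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):  # back-substitution
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x

def perspective_matrix(src: Sequence[Point], dst: Sequence[Point]) -> Matrix:
    """3x3 matrix H (H[2][2] = 1) that maps the four src points onto the four dst points."""
    a, b = [], []
    for (x, y), (u, v) in zip(src, dst):
        # u = (h11 x + h12 y + h13) / (h31 x + h32 y + 1); v uses the same denominator.
        a.append([x, y, 1.0, 0.0, 0.0, 0.0, -x * u, -y * u])
        b.append(u)
        a.append([0.0, 0.0, 0.0, x, y, 1.0, -x * v, -y * v])
        b.append(v)
    h = solve_linear(a, b)
    return [h[0:3], h[3:6], h[6:8] + [1.0]]

def warp_point(H: Matrix, x: float, y: float) -> Point:
    w = H[2][0] * x + H[2][1] * y + H[2][2]
    return ((H[0][0] * x + H[0][1] * y + H[0][2]) / w,
            (H[1][0] * x + H[1][1] * y + H[1][2]) / w)

# Hypothetical road trapezoid and the bird's-eye rectangle it should map to.
src = [(200.0, 720.0), (1100.0, 720.0), (685.0, 450.0), (595.0, 450.0)]
dst = [(300.0, 720.0), (980.0, 720.0), (980.0, 0.0), (300.0, 0.0)]
H = perspective_matrix(src, dst)       # original image -> bird's-eye view
H_inv = perspective_matrix(dst, src)   # bird's-eye view -> original image
print(warp_point(H, 640.0, 600.0))
```

A whole frame would be warped with cv2.warpPerspective(frame, H, (width, height)) rather than point by point.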


4 Lane Lines Detection
4.1 Edge Detection Based on Sobel Operator
The Sobel operator is a first-order derivative edge detection operator. In the implementation, a 3 × 3 template is used as the kernel and convolved with each pixel of the image, and a suitable threshold is then selected to extract the edges. The function of the convolution is determined by the form of the kernel, a fixed-size array of numerical parameters whose reference point is at its center. The convolution proceeds as follows:
1. The reference point of the kernel is positioned at the first pixel of the image, and the remaining kernel elements cover the corresponding neighboring pixels.
2. Each kernel value is multiplied by the value of the image pixel it covers, and the products are summed.
3. The result is placed at the position of the reference point in the output image.
4. The kernel is scanned over the entire image, repeating this operation for every pixel.

The gradient magnitude of each pixel is calculated as

$$ G = \sqrt{G_x^2 + G_y^2} \tag{1} $$

where $G_x$ and $G_y$ are the gradient values in the horizontal and vertical directions, respectively. The gradient direction is

$$ \theta = \arctan\left(\frac{G_y}{G_x}\right) \tag{2} $$

For an original image A, the Sobel operator produces two output images. $G_x$ is the horizontal derivative, which responds to vertical edges; $G_y$ is the vertical derivative, which responds to horizontal edges. They are computed as

$$ G_x = \begin{bmatrix} -1 & 0 & +1 \\ -2 & 0 & +2 \\ -1 & 0 & +1 \end{bmatrix} * A \tag{3} $$

$$ G_y = \begin{bmatrix} +1 & +2 & +1 \\ 0 & 0 & 0 \\ -1 & -2 & -1 \end{bmatrix} * A \tag{4} $$

The cv2.Sobel() function can compute the image derivative in the x or y direction, which can then be thresholded to obtain a binary image. Experiments showed that thresholding the x-direction gradient to the range 30 to 100 produces a binary image that captures the lane lines. Figures 2(a)–2(d) show images after threshold filtering. As Fig. 2 shows, Sobel-based edge detection finds the edges of the left and right lane lines while excluding interference from other irrelevant factors. This provides a foundation for the subsequent steps.


Fig. 2. Image after threshold filtering
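The following plain-Python sketch applies the x-direction Sobel filter of Eq. (3) and then the 30–100 threshold. One assumption is not stated in the text: |G_x| is rescaled to 0–255 before thresholding. This is a common practice. With OpenCV, the derivative itself would typically come from cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3). The function names and the toy image are illustrative.

```python
from typing import List

Image = List[List[float]]

# Eq. (3) kernel, applied below as a correlation (no kernel flip). For this kernel,
# flipping only changes the sign of G_x, which the absolute value removes.
SOBEL_X = [[-1, 0, 1],
           [-2, 0, 2],
           [-1, 0, 1]]

def sobel_x(img: Image) -> Image:
    """G_x for every interior pixel; the one-pixel border is left at 0."""
    h, w = len(img), len(img[0])
    out = [[0.0] * w for _ in range(h)]
    for r in range(1, h - 1):
        for c in range(1, w - 1):
            out[r][c] = sum(SOBEL_X[i][j] * img[r + i - 1][c + j - 1]
                            for i in range(3) for j in range(3))
    return out

def sobel_x_binary(img: Image, lo: int = 30, hi: int = 100) -> List[List[int]]:
    gx = sobel_x(img)
    peak = max(abs(v) for row in gx for v in row) or 1.0
    # Assumption: |G_x| is rescaled to 0..255 before the 30-100 range is applied.
    scaled = [[int(255 * abs(v) / peak) for v in row] for row in gx]
    return [[1 if lo <= s <= hi else 0 for s in row] for row in scaled]

row = [0, 0, 10, 20, 20, 20, 120]   # a soft ramp followed by a strong step
print(sobel_x_binary([row, row, row]))
```

In this toy input, the weak ramp pixels (scaled to 25) and the strongest step (scaled to 255) both fall outside the 30–100 window, so only one pixel survives. The upper bound therefore discards very strong edges as well as weak ones.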

4.2 Locate the Left and Right Base Points of the Lane Lines
The base points of the lane lines are the x coordinates at which the lanes first appear in the image. Lane line pixels are concentrated within a limited range along the x-axis, so the image can be split into left and right halves. In each half, the peak of the pixel distribution along the x-axis gives the base point of the corresponding lane line. Figure 3 shows the pixel distribution of the lane lines in Fig. 2(a). The x coordinate of the left lane line's base point is approximately 380, and that of the right lane line's base point is approximately 960.

Fig. 3. Pixel Distribution Histogram
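The base-point search of Sect. 4.2 amounts to a column histogram followed by one argmax in each image half. Below is a minimal sketch on a hypothetical toy binary image, with an illustrative function name. It counts all rows, as the text describes. Counting only the lower part of the image is a common variant, but the text does not say which rows are used.

```python
from typing import List, Tuple

def lane_base_points(binary: List[List[int]]) -> Tuple[int, int]:
    """x positions of the histogram peaks in the left and right image halves."""
    width = len(binary[0])
    # Column histogram: number of lane (non-zero) pixels at each x.
    hist = [sum(1 for row in binary if row[x]) for x in range(width)]
    mid = width // 2
    left = max(range(mid), key=hist.__getitem__)
    right = max(range(mid, width), key=hist.__getitem__)
    return left, right

binary = [
    [0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
    [0, 1, 1, 0, 0, 0, 0, 0, 1, 0],
    [0, 0, 1, 0, 0, 0, 0, 1, 0, 0],
]
print(lane_base_points(binary))  # (2, 7)
```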


4.3 Sliding Window Search
After the approximate positions of the left and right lane lines are determined, a sliding window method is used to search for points on these lines. The search proceeds as follows.

Step 1. After multiple tests, the sliding window width w is set to 6 times the lane line width, and the window height h to 1/9 of the image height.

Step 2. Pixel search is performed. During the k-th search iteration, all non-zero pixels within the current window are collected into an array $S_k$ with $n_k$ elements. The midpoint of the bottom edge of the search window in the k-th iteration is denoted $M_k(x_k, y_k)$.

Step 3. The bottom-edge midpoint $M_{k+1}(x_{k+1}, y_{k+1})$ of the next search window is computed. During the k-th iteration, if there are no lane line pixels on the top edge of the current window, the lane line is extrapolated through the window using the points in $S_k$. A straight line is fitted when the vertical coordinate differences between adjacent pixels in $S_k$ are small, and a quadratic curve when they are large.

If the differences are small, $M_{k+1}$ is the intersection of the two equations below. Here, a is the slope, b is the intercept, h is the window height, and $y_i$ is the vertical coordinate of the bottom-edge midpoint of the i-th sliding window:

$$ \begin{cases} y = ax + b \\ y = y_i + h \end{cases} \tag{5} $$

If the differences are large, $M_{k+1}$ is the intersection of the two equations below. Here, c is the curvature of the lane line, d is the slope, e is the intercept, and h and $y_i$ are as above:

$$ \begin{cases} y = cx^2 + dx + e \\ y = y_i + h \end{cases} \tag{6} $$

Step 4. During the k-th search, if there are lane line pixels on the upper edge of the current window, then

$$ x_{k+1} = \frac{1}{n_k} \sum_{i=1}^{n_k} x_i, \qquad y_{k+1} = y_k + h \tag{7} $$

If there are no lane line pixels in the current window, then

$$ x_{k+1} = x_k, \qquad y_{k+1} = y_k + h \tag{8} $$


The position of the next sliding window is then determined by the midpoint of its bottom edge together with the window width and height.

Step 5. During the k-th search, if $y_k$ in $M_k(x_k, y_k)$ is less than 0, the top of the lane line image is considered reached and the search stops.

4.4 Quadratic Polynomial Fitting
The final fit depends on the vertical coordinate differences between adjacent pixels:
- If the differences are relatively large, a quadratic polynomial f(x) = mx^2 + nx + p is fitted using the least squares method. Here, m represents the curvature of the lane line, n the slope, and p the intercept.
- If the differences are relatively small, a straight line f(x) = sx + v is fitted using the least squares method. Here, s is the slope and v the intercept.

Figures 4(a)–4(d) show the lane line fitting results for Figs. 2(a)–2(d). The green boxes are the sliding windows, and the yellow line in the middle is the fitted lane line.

Fig. 4. Sliding Window Search Results
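Steps 3 and 4 and the final fitting of Sect. 4.4 can be sketched together, since all of them use a least-squares line or quadratic y = f(x). The choice between the two depends on how large the vertical gaps between neighboring pixels are. Several details are assumptions, because the text does not define them:

- "Large" is measured as the mean |Δy| between x-sorted neighbors, compared with a hypothetical `dy_threshold`.
- The fallbacks for degenerate windows are our own choices.
- The update $y_{k+1} = y_k + h$ is kept exactly as written in Eqs. (5)–(8).

All function names are illustrative.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]  # (x, y) pixel coordinates

def least_squares(points: Sequence[Point], degree: int) -> List[float]:
    """Coefficients [c0, c1, ...] of y = c0 + c1*x + ... from the normal equations."""
    n = degree + 1
    # Augmented normal-equation matrix [A^T A | A^T y] with A[i][j] = x_i ** j.
    m = [[sum(x ** (r + c) for x, _ in points) for c in range(n)]
         + [sum(y * x ** r for x, y in points)] for r in range(n)]
    for col in range(n):  # Gauss-Jordan elimination with partial pivoting
        piv = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[piv] = m[piv], m[col]
        for r in range(n):
            if r != col:
                f = m[r][col] / m[col][col]
                m[r] = [a - f * b for a, b in zip(m[r], m[col])]
    return [m[r][n] / m[r][r] for r in range(n)]

def choose_degree(points: Sequence[Point], dy_threshold: float) -> int:
    """2 (quadratic) if x-sorted neighbors differ a lot in y, else 1 (line)."""
    pts = sorted(points)
    diffs = [abs(q[1] - p[1]) for p, q in zip(pts, pts[1:])]
    mean_dy = sum(diffs) / len(diffs) if diffs else 0.0
    return 2 if mean_dy > dy_threshold else 1

def next_midpoint(s_k: Sequence[Point], top_edge_hit: bool, x_k: float,
                  y_k: float, h: float, dy_threshold: float = 3.0) -> Point:
    """Bottom-edge midpoint M_{k+1} of the next window per Eqs. (5)-(8)."""
    y_next = y_k + h                          # sign as written in the text
    if not s_k:                               # Eq. (8): no lane pixels
        return x_k, y_next
    mean_x = sum(x for x, _ in s_k) / len(s_k)
    if top_edge_hit:                          # Eq. (7): mean x of S_k
        return mean_x, y_next
    degree = choose_degree(s_k, dy_threshold)
    if len({x for x, _ in s_k}) <= degree:    # too few distinct x to fit (assumed fallback)
        return mean_x, y_next
    coef = least_squares(s_k, degree)
    if degree == 1:                           # Eq. (5): solve a*x + b = y_next
        b, a = coef
        return ((y_next - b) / a if a else x_k), y_next
    e, d, c = coef                            # Eq. (6): solve c*x^2 + d*x + e = y_next
    if abs(c) < 1e-12:                        # curve degenerates to a line
        return ((y_next - e) / d if d else x_k), y_next
    disc = d * d - 4 * c * (e - y_next)
    if disc < 0:                              # curve never reaches y_next (assumed fallback)
        return x_k, y_next
    roots = [(-d + s * math.sqrt(disc)) / (2 * c) for s in (1.0, -1.0)]
    return min(roots, key=lambda x: abs(x - x_k)), y_next

def fit_lane(points: Sequence[Point], dy_threshold: float = 3.0) -> List[float]:
    """Sect. 4.4: least-squares line or quadratic over all collected lane pixels."""
    return least_squares(points, choose_degree(points, dy_threshold))

line_pts = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0), (4.0, 9.0)]  # y = 2x + 1
print(next_midpoint(line_pts, top_edge_hit=False, x_k=2.0, y_k=9.0, h=4.0))  # (6.0, 13.0)
```

Fitting y as a function of x follows Eqs. (5), (6) and Sect. 4.4. For nearly vertical lane lines in a bird's-eye view, however, this fit becomes ill-conditioned. For that reason, many lane-finding implementations fit x as a function of y instead.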


5 Experiment
In actual driving scenarios, factors such as lane line occlusion and variations in lighting conditions can affect the accuracy of lane line detection [15]. To assess the effectiveness of the proposed algorithm, experiments were conducted using Python 3.6.9 on a PC with an Intel(R) Core(TM) i7-9750H 2.59 GHz CPU, on video data from the Udacity self-driving car dataset.

5.1 Actual Lane Detection Results
The video was clipped and classified by road type into four categories: typical road, marking interference, light intensity variation, and traffic interference. The results are shown in Table 1 and Fig. 5. Table 1 shows an accuracy above 90% and a false detection rate below 4% for every road type. Together with the detection examples in Fig. 5, this indicates that the algorithm is only slightly affected by road marking interference, lighting changes, and traffic disturbances, and that it is accurate and robust.

Table 1. Test results of different road types

| Road type | Number of valid frames | Missing rate/% | False detection rate/% | Accuracy/% |
|---|---|---|---|---|
| Typical road, marking interference | 1260 | 1.98 | 1.04 | 96.98 |
| Light intensity change | 1175 | 4.44 | 2.56 | 93.00 |
| Traffic interference | 480 | 4.22 | 3.00 | 92.78 |
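In every row of Table 1 the missing rate, false detection rate, and accuracy sum to 100%. The paper does not define the three rates explicitly, so the sketch below assumes that each valid frame is counted as exactly one of correctly detected, missed, or falsely detected. The 100-frame example is hypothetical and is not data from the paper.

```python
def frame_rates(correct, missed, false_detected):
    """Missing rate, false detection rate and accuracy, all in percent.

    Assumes every valid frame is exactly one of: correctly detected,
    missed, or falsely detected (so the three rates sum to 100%).
    """
    total = correct + missed + false_detected
    if total == 0:
        raise ValueError("no valid frames")
    return (100.0 * missed / total,
            100.0 * false_detected / total,
            100.0 * correct / total)


# Rows of Table 1: (missing rate, false detection rate, accuracy), in percent.
TABLE_1 = {
    "Typical road / marking interference": (1.98, 1.04, 96.98),
    "Light intensity change": (4.44, 2.56, 93.00),
    "Traffic interference": (4.22, 3.00, 92.78),
}

for road, rates in TABLE_1.items():
    print(f"{road}: rates sum to {sum(rates):.2f}%")  # 100.00% for each row

# Hypothetical clip of 100 valid frames (not from the paper).
print(frame_rates(correct=90, missed=6, false_detected=4))  # (6.0, 4.0, 90.0)
```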

5.2 Analysis and Comparison with Other Lane Detection Methods
To verify the performance of the proposed algorithm, we compared it with the methods of [16], [17], and [18] on the same set of samples (1260 frames of valid video on typical roads). Table 2 shows that the proposed algorithm has higher detection accuracy than the other three algorithms. Its single-frame processing time is slightly longer than that of the improved SegNet algorithm (20.02 ms vs. 18.0 ms), but shorter than those of the edge feature point clustering algorithm and the improved Hough transform coupling algorithm.

Improved Lane Line Detection Algorithms

Fig. 5. Lane line detection results for different road types: (a) typical road, (b) marking interference, (c) traffic interference, (d) light intensity variation.

Table 2. Algorithm evaluation comparison

| Algorithm | Accuracy/% | Single-frame processing time/ms |
|---|---|---|
| Edge feature point clustering algorithm | 92.01 | 36.1 |
| Improved SegNet algorithm | 96.50 | 18.0 |
| Improved Hough transform coupling | 95.25 | 30.0 |
| The improved lane detection algorithm in this paper | 96.98 | 20.02 |
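Converting the single-frame times in Table 2 into throughput makes the real-time comparison concrete. The sketch assumes frames are processed one after another with no other overhead; whether a method keeps up with a video stream then depends on that stream's frame rate, which the paper does not state.

```python
# Single-frame processing times from Table 2, in milliseconds.
TIMES_MS = {
    "Edge feature point clustering algorithm": 36.1,
    "Improved SegNet algorithm": 18.0,
    "Improved Hough transform coupling": 30.0,
    "Proposed algorithm": 20.02,
}


def fps(ms_per_frame):
    """Frames per second sustainable at a given per-frame latency."""
    return 1000.0 / ms_per_frame


for name, ms in TIMES_MS.items():
    print(f"{name}: {fps(ms):.1f} fps")
```

With these numbers the proposed algorithm runs at about 50 fps, second only to the improved SegNet algorithm at about 55.6 fps.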

6 Conclusion
This paper presents a lane detection algorithm with high accuracy, real-time performance, and robustness. It has significant potential in applications such as automated driving and traffic safety, where it can compensate for deep learning algorithms that lack real-time capability. The key innovations are as follows:
1. Sobel edge detection is applied to the bird's-eye view to eliminate irrelevant information, increase detection speed, and improve real-time performance.


2. The sliding window search model is improved by dynamically selecting a quadratic curve or a straight line to fit the missing part of the lane lines in the current window, enhancing real-time performance and robustness.
3. When the lane lines found by the sliding window are fitted by the least squares method, the algorithm dynamically selects a quadratic curve or a straight line, further improving real-time performance.

Experimental results show that the algorithm is little affected by marking interference, light intensity changes, and traffic interference, and that it is accurate and robust. Although the algorithm performs well, detection may be ineffective when the video is blurred by high-speed vehicle motion, which will be the focus of future research.
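The first innovation listed above, Sobel edge detection on a bird's-eye-view image, can be sketched with NumPy alone. This is an illustration under stated assumptions, not the authors' implementation: the input is assumed to be an already-warped grayscale image (the perspective warp is omitted), only the horizontal gradient is used because it responds to near-vertical lane lines, and the threshold range (20, 255) is a hypothetical choice.

```python
import numpy as np


def sobel_x_binary(gray, lo=20, hi=255):
    """Binary mask of strong horizontal-gradient (near-vertical edge) pixels.

    gray: 2-D array, assumed to be an already-warped bird's-eye-view image.
    lo, hi: hypothetical thresholds on the gradient magnitude rescaled to 0..255.
    """
    kx = np.array([[-1, 0, 1],
                   [-2, 0, 2],
                   [-1, 0, 1]], dtype=float)  # 3x3 Sobel kernel for d/dx
    h, w = gray.shape
    padded = np.pad(gray.astype(float), 1, mode="edge")
    gx = np.zeros((h, w))
    for i in range(3):  # correlate by summing weighted, shifted copies
        for j in range(3):
            gx += kx[i, j] * padded[i:i + h, j:j + w]
    mag = np.abs(gx)
    if mag.max() > 0:
        mag = 255.0 * mag / mag.max()  # rescale to 0..255
    return ((mag >= lo) & (mag <= hi)).astype(np.uint8)


img = np.zeros((10, 10))
img[:, 4:6] = 255  # a bright vertical "lane line"
print(sobel_x_binary(img)[0])  # edges flagged on both sides of the line
```

Restricting the search to the warped road region is what removes irrelevant information before the sliding window step.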

References
1. Guo, L., Wang, J.Q., Li, K.Q.: Lane keeping system based on THASV-II platform. In: IEEE International Conference on Vehicular Electronics and Safety, pp. 305–308 (2006)
2. Bar Hillel, A., Lerner, R., Levi, D., et al.: Recent progress in road and lane detection: a survey. Mach. Vis. Appl. 25(3), 727–745 (2014)
3. Wang, Y.Z.H., Wang, X.Y., Wen, C.H.L.: Gradient-pair constraint for structure lane detection. J. Image Graph. 17(6), 657–663 (2012)
4. Peng, H., Xiao, J.S.H., Cheng, X., et al.: Lane detection algorithm based on extended Kalman filter. J. Optoelectron. Laser (3), 567–574 (2015)
5. Lee, S., Kim, J., Yoon, J.S., et al.: VPGNet: vanishing point guided network for lane and road marking detection and recognition. In: IEEE International Conference on Computer Vision (ICCV), pp. 1965–1973 (2017)
6. Wang, K.N., Chu, X.M., Zhang, W.G., et al.: Curved lane detection algorithm based on piecewise linear model and heuristic search. J. Electron. Meas. Instrum. 27(8), 689–695 (2013)
7. Tan, T., Yin, S., Ouyang, P., et al.: Efficient lane detection system based on monocular camera. In: IEEE International Conference on Consumer Electronics (ICCE), pp. 202–203 (2015)
8. Wang, J., Gu, F., Zhang, C., et al.: Lane boundary detection based on parabola model. In: IEEE International Conference on Information and Automation, pp. 1729–1734 (2010)
9. Chen, Q., Wang, H.: A real-time lane detection algorithm based on a hyperbola-pair model. In: IEEE Intelligent Vehicles Symposium, pp. 510–515 (2006)
10. Xu, H.R., Wang, X.D., Fang, Q.: Structure road detection algorithm based on B-spline curve model. Acta Automatica Sinica 37(3), 270–275 (2011)
11. Wang, Y., Teoh, E.K., Shen, D.: Lane detection and tracking using B-snake. Image Vis. Comput. 22(4), 269–280 (2004)
12. Huval, B., Wang, T., Tandon, S., et al.: An empirical evaluation of deep learning on highway driving. arXiv:1504.01716 (2015)
13. Li, X., Wu, Q., Kou, Y., et al.: Lane detection based on spiking neural network and Hough transform. In: The 8th International Congress on Image and Signal Processing (CISP), pp. 626–630 (2015)
14. Gopalan, R., Hong, T., Shneier, M., et al.: A learning approach towards detection and tracking of lane markings. IEEE Trans. Intell. Transp. Syst. 13(3), 1088–1098 (2012)
15. McCall, J.C., Trivedi, M.M.: Video-based lane estimation and tracking for driver assistance: survey, system, and evaluation. IEEE Trans. Intell. Transp. Syst. 7(1), 20–37 (2006)
16. Deng, T.M., Wang, L., Yang, Q.Z., et al.: Lane detection method based on improved SegNet algorithm. Sci. Technol. Eng. 20(36), 14988–14993 (2020)
17. Lv, K.H., Zhang, D.X.: Lane detection algorithm based on improved Hough transform coupled with density space clustering. J. Electron. Meas. Instrum. 34(12), 172–180 (2020)
18. Qin, X.Z., Lu, R.Y., Chen, L.M., et al.: Research on multi-scene lane detection and deviation warning method. Mech. Sci. Technol. 39(9), 1439–1449 (2020)

A Deep Transfer Fusion Model for Recognition of Acute Lymphoblastic Leukemia with Few Samples

Zhihua Du¹, Xin Xia¹(B), Min Fang², Li Yu³, and Jianqiang Li¹

¹ Shenzhen University, Shenzhen, People's Republic of China
{duzh,lijq}@szu.edu.cn, [email protected]
² Education Center of Experiments and Innovations, Harbin Institute of Technology (ShenZhen), Shenzhen, China
[email protected]
³ Shenzhen University Health Science Center, Shenzhen, China
[email protected]

Abstract. Distinguishing between sub-classes of acute lymphoblastic leukemia (ALL) based on morphological differences in blood smear images is challenging. Deep learning methods have been successful in morphological classification. This paper develops a deep transfer fusion model (TFDNet) that predicts ALL sub-classes from a small number of blood cell images. TFDNet is a customized convolutional neural network (CNN) with two transfer learning modules, Xception and Dense, that extract features in parallel. TFDNet then uses a two-branch feature extraction layer to fuse the multi-scale features for diagnosing ALL sub-classes. To evaluate the effectiveness and generalizability of TFDNet, we compare it with seven state-of-the-art methods on five datasets: three small-sample ALL datasets, a skin cancer dataset, and a brain cancer dataset. TFDNet outperforms the seven state-of-the-art methods on four of the five datasets and is second only to Dense on ALL-IDB2.
Keywords: ALL · deep learning · CNN · transfer learning · diagnosis

1 Introduction
Acute lymphoblastic leukemia (ALL) [1] is a type of cancer that commonly affects both children and adults. Early and rapid diagnosis of ALL sub-types is crucial for preoperative diagnosis, treatment choice, and postoperative prognosis. Common methods for determining ALL sub-classes, such as complete blood counts and peripheral smear examination, rely on experienced professionals, so the evaluation is subjective. Furthermore, manual analysis of large amounts of histopathology image data [2] is time-consuming, laborious, and costly. There is therefore increasing interest in methods that use image analysis and pattern recognition to quantify and identify leukocytes [3–5], reducing the need for manual analysis.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 710–721, 2023. https://doi.org/10.1007/978-981-99-4742-3_59

A Deep Transfer Fusion Model for Recognition


One traditional approach uses hand-crafted feature engineering to extract image features, which are then classified with mainstream machine learning methods such as SVM, KNN, K-means, and random forest. However, these methods are hard to reuse or port to other categories of image recognition, because good hand-crafted features are usually specific to a particular type of image. In recent years, deep learning has achieved significant breakthroughs in fields such as computer vision, pattern recognition, and image analysis [6–8]. Convolutional neural networks (CNNs) can extract features from input images automatically and efficiently, which has made them widely used for image classification. However, the performance of neural networks usually depends on sufficient training samples, which are often limited for medical images, and especially for ALL data. To address this challenge, this paper proposes a transfer fusion model framework, TFDNet, that uses a small number of blood smear images to identify and diagnose medical image classes such as ALL sub-classes rapidly and accurately, supporting early treatment of patients. TFDNet is a customized CNN that extracts Xception [9] and Dense [10] features of ALL images in parallel and uses a two-branch feature extraction layer to fuse the multi-scale features for classification, which suggests its potential for clinical environments with little training data.

1.1 Related Work
Among traditional methods, several studies have proposed leukemia image recognition methods that rely on manual features such as shape, color, and texture. Reta [11] proposed a method for distinguishing between ALL and AML that segments blood cells using contextual color and texture information to identify nuclear and cytoplasmic regions and to separate overlapping blood cells.
Similar methods proposed by Putzu [12], Patel and Mishra [13], and Singhal [14] also rely on manual features. Laosai and Chamnongthai [15] proposed an AML classification framework that segments nuclei using k-means and contour features, then extracts morphological features such as cell size and color for classification. Patel and Mishra [13] proposed an automatic leukemia classification system that denoises and filters images before segmenting cells with k-means and the Zack algorithm; color, statistical, geometric, and texture features are then extracted and classified with an SVM. Singhal [16] implemented an effective feature extraction algorithm based on local binary patterns (LBP) and geometric characteristics, again followed by SVM classification. However, hand-crafted feature design strongly affects recognition performance and requires manual intervention. A better method is therefore needed that extracts features automatically for ALL diagnosis and can be transferred to other categories of pathological images. Deep neural networks are increasingly used because they extract features automatically and perform well in many recognition tasks. Recently, many deep learning methods for diagnosing cancer from histopathology images have been proposed, with impressive results [6, 17, 18]. Furthermore, computer analysis systems based on deep neural networks have been successfully developed and applied to the


Z. Du et al.

recognition of various pathological images, including gastric cancer [17], lung cancer [6], and esophageal cancer [18], significantly reducing the workload of pathologists. In ALL recognition, several studies have used deep neural networks and transfer learning to address the challenge of limited sample data. Rehman [19] and Shafique [20] used transfer learning models based on AlexNet to recognize ALL images with fewer training iterations. Anaya-Isaza and Mera-Jiménez [21] used data augmentation and transfer learning for brain tumor detection, with excellent results. Pansombut et al. [22] trained a convolutional neural network with a data augmentation technique that expanded the training data twentyfold. Habibzadeh [23] proposed a classification model based on transfer learning and deep learning: the dataset is first preprocessed, transfer learning is then used for feature extraction, and finally Inception and ResNet are used for white blood cell classification. In summary, whereas traditional methods for ALL recognition rely on manual feature engineering, deep neural networks and transfer learning have shown great potential for improving the accuracy and efficiency of ALL diagnosis, especially with few samples. TFDNet builds on these studies with a transfer fusion framework that extracts multiple features of pathological images and fuses them into more comprehensive features for classification.

2 Proposed Network Framework
The proposed TFDNet model is divided into three main parts: the image processing module, the feature extraction module, and the feature fusion and classification module. An overview of the TFDNet framework is shown in Fig. 1.

Fig. 1. Overview of TFDNet. (a) Image preprocessing. (b) Feature extraction module, which extracts Xception and Dense features in parallel: (b1) Xception feature extraction module; (b2) Dense feature extraction module. (c) Feature fusion and classification module.

1) Image processing module: We use the Keras ImageDataGenerator class, an image generator in the Keras high-level neural network API, to apply data augmentation during training. Its parameters are set to horizontal flips, scaling between 0.8 and 1.2, shearing within 0.2 radians in the clockwise direction, and random rotation of up to 4°. Because the augmented images are generated on the fly, this is intended to help TFDNet extract image features more fully without increasing the training cost.
2) Feature extraction module: This module consists of two transfer learning [24] sub-modules, Xception and Dense, which extract features from the input images. We chose Xception and Dense because they achieve high recognition rates on natural images while being smaller, having fewer parameters, and having relatively shallow networks compared with other advanced models. In experiments on the ALL dataset comparing them with ResNet50V2, InceptionV3, and InceptionResNetV2 under 10-fold cross-validation, and under 2-fold cross-validation with a smaller training set, both Xception and Dense performed better. The input data is fed into the two sub-modules, which extract the Xception and Dense features in parallel, and the resulting multi-scale features are passed to the next module. Both sub-modules are pretrained on ImageNet, with all fully connected layers after the last convolutional layer removed; the remaining learnable layers serve as the feature extraction layers. Loading the pretrained parameters reduces the amount of retraining needed on new datasets.
3) Feature fusion module: This module takes the multi-scale features, namely the Xception features and Dense features, flattens them separately, and maps them to the same dimension. The features are then regularized and added together to obtain the fused features:
$$F_{\text{feature}} = N(X_{\text{feature}}) + N(D_{\text{feature}}) \qquad (1)$$

where N denotes the regularization of the obtained features, and X_feature and D_feature are the extracted Xception and Dense features, respectively. A softmax classifier is then trained on the fused features. We use the binary cross-entropy loss for binary classification data and the categorical cross-entropy loss for multi-class data, defined below; in Eq. (3), y_{i,n} and ŷ_{i,n} denote the n-th entries of the one-hot label and of the predicted class-probability vector of the i-th sample:
$$B_{\text{loss}} = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i \ln \hat{y}_i + (1 - y_i)\ln(1 - \hat{y}_i)\right] \qquad (2)$$
$$M_{\text{loss}} = -\sum_{n=1}^{c} y_{i,n} \ln \hat{y}_{i,n} \qquad (3)$$

where ŷ_i is the model's predicted value for the i-th sample, y_i is the label of the i-th sample, N is the number of training samples, and c is the number of label categories.
4) Network parameters: TFDNet has 39,416,939 parameters, of which 38,792,971 are trainable and 623,968 are non-trainable. Training stops early once performance on the validation set has not improved for 5 consecutive epochs. This early stopping helps prevent the model from memorizing the training data and losing generalization performance.
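Equations (1)–(3) and the early-stopping rule can be sketched in NumPy as follows. This is an illustrative sketch, not the released TFDNet code: the paper does not specify N(·) beyond "regularization", so L2 normalization is assumed; the projection matrices `w_x` and `w_d` are hypothetical stand-ins for the layers that map both branches to the same dimension; the categorical loss is averaged over a batch; and the stopping check assumes a higher-is-better validation score, in the spirit of the `patience` argument of Keras `EarlyStopping`.

```python
import numpy as np


def l2_normalize(v, eps=1e-12):
    """Assumed reading of N(.) in Eq. (1): scale each row to unit L2 norm."""
    return v / (np.linalg.norm(v, axis=-1, keepdims=True) + eps)


def fuse(x_feat, d_feat, w_x, w_d):
    """Eq. (1): flatten each branch, project to a shared size, normalize, add.

    w_x, w_d: hypothetical projection matrices (stand-ins for trained layers).
    """
    x = x_feat.reshape(len(x_feat), -1) @ w_x  # (batch, dim)
    d = d_feat.reshape(len(d_feat), -1) @ w_d  # (batch, dim)
    return l2_normalize(x) + l2_normalize(d)


def binary_cross_entropy(y, y_hat, eps=1e-7):
    """Eq. (2): mean over the N samples; clipping avoids log(0)."""
    y_hat = np.clip(y_hat, eps, 1 - eps)
    return float(-np.mean(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat)))


def categorical_cross_entropy(y_onehot, probs, eps=1e-7):
    """Eq. (3) for each sample, then averaged over the batch."""
    probs = np.clip(probs, eps, 1.0)
    return float(-np.mean(np.sum(y_onehot * np.log(probs), axis=1)))


def should_stop(val_scores, patience=5):
    """True if none of the last `patience` epochs beat the best earlier score.

    Assumes a higher-is-better validation score, one entry per epoch.
    """
    if len(val_scores) <= patience:
        return False
    return max(val_scores[-patience:]) <= max(val_scores[:-patience])


rng = np.random.default_rng(0)
fused = fuse(rng.normal(size=(2, 4, 4, 3)), rng.normal(size=(2, 2, 2, 5)),
             rng.normal(size=(48, 8)), rng.normal(size=(20, 8)))
print(fused.shape)  # (2, 8)
```

Normalizing each branch before the sum keeps one backbone's feature scale from dominating the fused vector.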


3 Experiments and Results
3.1 Dataset Description
Five public datasets are used to train and validate TFDNet. Three of them are ALL datasets; the remaining two, a skin cancer dataset and a brain tumor dataset, are used to verify the generalization performance of TFDNet.
1) ALL: This dataset was obtained from Pansombut et al. [22]. It consists of 363 blood smear images: 93 normal cell images and 135 images each of two ALL subtypes, precursor T-cell (pre-T) and precursor B-cell (pre-B) lymphoblastic leukemia. Each image contains only one cell and is 256 × 256 pixels.
2) ALL-IDB2: This open-source dataset is provided by the Università degli Studi di Milano [25]. It contains 260 images, each with a central cell to be identified, in two categories: 130 healthy and 130 abnormal images. All images are 2592 × 1944 pixels.
3) ISBI2019+ALL: This dataset builds on the first one. To eliminate errors caused by class imbalance, we added normal-morphology cells from the ALL challenge of ISBI 2019 [26] to the first dataset. It contains 405 images, 135 each of B-type, T-type, and normal cells, with the same image settings as the ALL dataset.
4) Skin Cancer: This dataset [27] is a collection of JPEG images that are either cancerous or non-cancerous. It contains 288 images, 84 with cancer and 204 without. Image sizes are not consistent across the dataset.
5) Brain Tumor: This dataset [28] contains 3064 T1-weighted contrast-enhanced images from 233 patients with three kinds of brain tumor: meningioma (708 slices), glioma (1426 slices), and pituitary tumor (930 slices).
3.2 Evaluation Metrics
To evaluate the performance of the different methods, we use precision (PRE), recall (REC), accuracy (ACC), area under the ROC curve (AUC), and average precision (AP) as evaluation metrics. PRE is the proportion of samples predicted as positive that are actually positive. REC is the proportion of actually positive samples that the model predicts as positive. ACC is the proportion of all samples that the model classifies correctly. AUC is the area under the receiver operating characteristic (ROC) curve, which plots the true positive rate against the false positive rate as the decision threshold varies. AP summarizes the precision-recall curve as the precision averaged over recall levels. Larger values of all five metrics indicate better performance.

3.3 Model Evaluation and Comparison Results
In this subsection, we compare the classification performance of TFDNet with seven state-of-the-art models, namely SVM, ConVnet [22], AlexNet [29], Xception [9], LeukNet [30], ALNet [30], and Dense [10], and conduct two sets of experiments on the ALL dataset. In the first set, the seven models were trained under the conditions of their respective papers, with data augmentation to increase the number of samples, while TFDNet was trained under the conditions proposed in this paper. In the second set, all methods were trained under our proposed training conditions, which use data augmentation via ImageDataGenerator and early stopping, to evaluate the effect of the training strategy on model performance. Tables 1 and 2 report the five evaluation metrics for all models. Table 1 shows that TFDNet outperforms the second-best model, Dense, by 2.1, 0.9, 2.3, 2.1, and 2.0 percentage points in ACC, AUC, AP, precision, and recall, respectively. In Table 2, most evaluation metrics of the seven competing models improve under our training strategy of data augmentation and early stopping; the main exceptions are ALNet and Dense, whose ACC, precision, and recall decrease slightly. For clarity, Figs. 2 and 3 show the corresponding results of both sets of experiments, where (a) and (b) denote the first and second sets, respectively.

Table 1. Results of all models trained with the methods of their original papers

| Method | Acc | AUC | AP | Precision | Recall |
|---|---|---|---|---|---|
| Svm | 0.555 | 0.666 | 0.481 | 0.631 | 0.555 |
| ConVnet | 0.686 | 0.852 | 0.778 | 0.733 | 0.698 |
| AlexNet | 0.729 | 0.865 | 0.783 | 0.752 | 0.741 |
| LeukNet | 0.824 | 0.932 | 0.884 | 0.844 | 0.840 |
| Xception | 0.843 | 0.950 | 0.924 | 0.873 | 0.858 |
| ALNet | 0.848 | 0.942 | 0.891 | 0.866 | 0.861 |
| Dense | 0.873 | 0.966 | 0.945 | 0.883 | 0.885 |
| TFDNet | 0.894 | 0.975 | 0.968 | 0.904 | 0.905 |

Table 2. Results of all models trained with our proposed training method

| Method | Acc | AUC | AP | Precision | Recall |
|---|---|---|---|---|---|
| Svm | 0.724 | 0.793 | 0.652 | 0.737 | 0.747 |
| ConVnet | 0.742 | 0.888 | 0.812 | 0.775 | 0.743 |
| AlexNet | 0.791 | 0.933 | 0.865 | 0.812 | 0.810 |
| ALNet | 0.821 | 0.956 | 0.911 | 0.840 | 0.840 |
| LeukNet | 0.824 | 0.957 | 0.903 | 0.850 | 0.842 |
| Dense | 0.867 | 0.972 | 0.947 | 0.882 | 0.881 |
| Xception | 0.883 | 0.973 | 0.951 | 0.896 | 0.895 |
| TFDNet | 0.894 | 0.975 | 0.968 | 0.904 | 0.905 |
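The metrics defined in Section 3.2 can be sketched for the binary case in NumPy. This is a sketch, not the evaluation code used for the tables: the paper does not state how the metrics are averaged over the three ALL classes, so multi-class averaging is left out; AUC is computed as the probability that a randomly chosen positive is scored above a randomly chosen negative; and AP is the sum of precision weighted by recall increments, assuming no tied scores.

```python
import numpy as np


def binary_metrics(y_true, y_pred):
    """Precision, recall and accuracy from 0/1 labels and 0/1 predictions."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    precision = tp / (tp + fp) if tp + fp else 0.0  # correct among predicted positives
    recall = tp / (tp + fn) if tp + fn else 0.0     # found among actual positives
    accuracy = np.mean(y_true == y_pred)            # correct among all samples
    return float(precision), float(recall), float(accuracy)


def roc_auc(y_true, scores):
    """AUC = P(random positive scored above random negative); ties count 1/2."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos, neg = scores[y_true == 1], scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]
    return float((np.sum(diff > 0) + 0.5 * np.sum(diff == 0)) / diff.size)


def average_precision(y_true, scores):
    """AP = sum over ranks of (R_n - R_{n-1}) * P_n; assumes no tied scores."""
    y_true = np.asarray(y_true)
    order = np.argsort(-np.asarray(scores, dtype=float), kind="stable")
    hits = y_true[order]
    tp = np.cumsum(hits)
    precision = tp / np.arange(1, len(hits) + 1)
    recall = tp / hits.sum()
    prev_recall = np.concatenate(([0.0], recall[:-1]))
    return float(np.sum((recall - prev_recall) * precision))


labels = [1, 0, 1, 1, 0]
print(binary_metrics(labels, [1, 0, 0, 1, 1]))
print(roc_auc(labels, [0.9, 0.1, 0.4, 0.8, 0.6]))
print(average_precision(labels, [0.9, 0.1, 0.4, 0.8, 0.6]))
```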


In the box plots of Fig. 2, the height of each box represents the range of accuracy and the middle horizontal line represents the average accuracy. The third column of Fig. 2 shows that TFDNet has the highest average and best accuracy and is relatively stable. Figure 3 shows two-dimensional t-SNE representations of the test-set image features produced, during cross-validation, by the best model of each of the eight methods. The panel in the first row and third column shows TFDNet, which clearly separates the three classes, with only a few samples overlapping at the boundary between pre-B and pre-T; the other methods show considerably more overlap.

Fig. 2. Box plot of cross-validated accuracy for all models: (a) first set of experiments; (b) second set of experiments.

Fig. 3. t-SNE plots for all models, where the number of clusters equals the number of classes. Red represents the pre-B type, blue the pre-T type, and white the normal type.

3.4 Generalization Performance
To validate the robustness of TFDNet, we conducted comparative experiments on the other two few-sample ALL datasets, ALL-IDB2 and ISBI2019+ALL.


On ALL-IDB2 (Table 3), TFDNet outperforms the other state-of-the-art models except Dense, which leads it by 0.9 percentage points in accuracy (0.982 vs. 0.973). This result is still acceptable, as TFDNet reaches a classification accuracy of 97.3%. On ISBI2019+ALL (Table 4), TFDNet achieves the best ACC (0.913), AUC (0.986), AP (0.913), precision (0.917), and recall (0.913), outperforming the second-best model, Xception, by 2.9, 1.0, 2.9, 3.0, and 2.9 percentage points, respectively. TFDNet can therefore effectively identify ALL images with few samples, both single-cell images and multi-cell images with a central cell.

Table 3. Results of all models on ALL-IDB2

| Method | Acc | AUC | AP | Precision | Recall |
|---|---|---|---|---|---|
| Svm | 0.736 | 0.736 | 0.677 | 0.742 | 0.736 |
| ConVnet | 0.832 | 0.928 | 0.942 | 0.844 | 0.832 |
| AlexNet | 0.936 | 0.973 | 0.974 | 0.942 | 0.936 |
| LeukNet | 0.950 | 0.984 | 0.996 | 0.954 | 0.950 |
| ALNet | 0.964 | 0.992 | 0.994 | 0.967 | 0.964 |
| Xception | 0.968 | 0.993 | 0.997 | 0.972 | 0.968 |
| TFDNet | 0.973 | 0.996 | 0.997 | 0.974 | 0.973 |
| Dense | 0.982 | 0.997 | 0.998 | 0.983 | 0.982 |

Table 4. Results of all models on ISBI2019+ALL

| Method | Acc | AUC | AP | Precision | Recall |
|---|---|---|---|---|---|
| ConVnet | 0.738 | 0.873 | 0.738 | 0.748 | 0.738 |
| Svm | 0.765 | 0.824 | 0.765 | 0.761 | 0.765 |
| AlexNet | 0.798 | 0.945 | 0.798 | 0.813 | 0.798 |
| ALNet | 0.837 | 0.959 | 0.837 | 0.841 | 0.837 |
| LeukNet | 0.851 | 0.961 | 0.851 | 0.857 | 0.851 |
| Dense | 0.874 | 0.969 | 0.874 | 0.890 | 0.874 |
| Xception | 0.884 | 0.976 | 0.884 | 0.887 | 0.884 |
| TFDNet | 0.913 | 0.986 | 0.913 | 0.917 | 0.913 |

Beyond ALL image classification, we also explore other pathological image recognition tasks. We conducted multiple experiments on the Skin Cancer and Brain Tumor datasets to compare TFDNet with the other state-of-the-art models and to evaluate its generalization performance. Table 5 shows that TFDNet outperforms the other models on the Skin Cancer dataset. However, no model exceeds 90% classification accuracy, which may be related to the class imbalance of the test set, which has 162 non-cancer images and only 42 cancer images.

Table 5. Results of all models on the skin cancer data

| Method | Acc | AUC | AP | Precision | Recall |
|---|---|---|---|---|---|
| ConVnet | 0.703 | 0.875 | 0.823 | 0.760 | 0.701 |
| Svm | 0.725 | 0.794 | 0.657 | 0.745 | 0.749 |
| AlexNet | 0.772 | 0.923 | 0.865 | 0.811 | 0.784 |
| LeukNet | 0.817 | 0.914 | 0.890 | 0.847 | 0.829 |
| ALNet | 0.817 | 0.945 | 0.887 | 0.838 | 0.832 |
| Xception | 0.833 | 0.956 | 0.926 | 0.857 | 0.856 |
| Dense | 0.845 | 0.962 | 0.932 | 0.868 | 0.879 |
| TFDNet | 0.888 | 0.972 | 0.944 | 0.898 | 0.900 |

Table 6 shows strong results on the Brain Tumor dataset: TFDNet outperforms the second-best model, Xception, by 0.8 percentage points in ACC, 0.7 in AUC, 1.2 in AP, 1.2 in precision, and 1.3 in recall. Overall, TFDNet shows promising potential for clinical environments with little training data.

Table 6. Results of all models on the brain tumor data

| Method | Acc | AUC | AP | Precision | Recall |
|---|---|---|---|---|---|
| Svm | 0.781 | 0.823 | 0.814 | 0.872 | 0.864 |
| ConVnet | 0.794 | 0.898 | 0.850 | 0.907 | 0.877 |
| ALNet | 0.846 | 0.884 | 0.891 | 0.932 | 0.904 |
| AlexNet | 0.862 | 0.908 | 0.890 | 0.930 | 0.906 |
| LeukNet | 0.871 | 0.903 | 0.912 | 0.935 | 0.929 |
| Dense | 0.903 | 0.952 | 0.936 | 0.965 | 0.946 |
| Xception | 0.927 | 0.970 | 0.963 | 0.971 | 0.960 |
| TFDNet | 0.935 | 0.977 | 0.975 | 0.983 | 0.973 |

3.5 Image Feature Visualization
To evaluate the interpretability and performance of TFDNet, we generated heat maps from the feature maps of the convolutional layers, inspired by Grad-CAM [32]: the spatial average of each gradient map gives the weight of the corresponding feature map. Figure 4 shows the heat maps generated by TFDNet for the three image types of the ALL dataset, combined with the corresponding original images to visualize the model's feature activations. The heat maps are darker in the middle of the cell, where the nucleus and chromatin are concentrated, and lighter toward the cell edges, indicating that the edges are not features of interest to the model. This is consistent with the medical classification of ALL subtypes, which relies on the nucleus, chromatin, and nucleolus, and suggests that the model learns features that distinguish the ALL sub-classes.

Fig. 4. Examples of activation maps. Blue tones indicate low activation, i.e., regions that are unimportant for classification and that the model ignores; red tones indicate the most critical regions, which form the basis of the model's final classification.
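The Grad-CAM-style weighting described in Section 3.5 can be sketched in NumPy. The sketch assumes that the feature maps of one convolutional layer and the gradients of the predicted class score with respect to them are already available (in TensorFlow they would typically be obtained with `tf.GradientTape`); upsampling the map to the input size and overlaying it on the image are omitted.

```python
import numpy as np


def grad_cam(feature_maps, grads):
    """Grad-CAM-style heat map from one convolutional layer.

    feature_maps: activations A of shape (H, W, K).
    grads: gradients of the predicted class score w.r.t. A, same shape.
    Returns an (H, W) map scaled to [0, 1].
    """
    weights = grads.mean(axis=(0, 1))  # alpha_k: spatial mean of each gradient map
    cam = np.tensordot(feature_maps, weights, axes=([2], [0]))  # sum_k alpha_k * A_k
    cam = np.maximum(cam, 0.0)  # keep only positive evidence (ReLU)
    peak = cam.max()
    return cam / peak if peak > 0 else cam


A = np.zeros((4, 4, 2))
A[1, 1, 0] = 2.0  # channel 0 fires at (1, 1)
A[2, 2, 1] = 3.0  # channel 1 fires at (2, 2)
G = np.stack([np.ones((4, 4)), -np.ones((4, 4))], axis=-1)
print(grad_cam(A, G))  # only (1, 1) is highlighted
```

Because channel 1 has a negative average gradient, its activation lowers the class score and is suppressed by the ReLU, which is why only regions that support the prediction appear in red.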

4 Conclusion
In summary, we propose TFDNet, a novel transfer-learning-based fusion model that effectively classifies ALL images with few training data and provides an interpretable classification basis through visualization. We also explore TFDNet's generalization performance on different ALL datasets and on other categories of histopathological images. Our study shows that fusing models to extract multi-scale image features is an effective recognition method for histopathological image analysis with few training samples, both for ALL data and for other categories.

Acknowledgment. This work was supported by the National Key R&D Program of China under Grant 2020YFA0908700, the National Natural Science Foundation of China under Grants 62176164 and 62203134, the Natural Science Foundation of Guangdong Province under Grant 2023A1515010992, and the Science and Technology Innovation Committee Foundation of Shenzhen City under Grant JCYJ20220531101217039.

Availability and Implementation. All code and image data in this paper are available at https://github.com/xin242328/TFDNet, so that researchers can replicate our experiments to verify the results and reuse the methods and data in their own fields. All models were implemented with tensorflow-gpu (version 2.4.0), and all training was performed on a GeForce GTX 1080 GPU under Linux.

References
1. Steven, H., et al.: The 2016 revision of the World Health Organization classification of lymphoid neoplasms. Blood 127(20), 2375–2390 (2016)
2. Li, Z., Zhang, P., Xie, N., Zhang, G., Wen, C.-F.: A novel three-way decision method in a hybrid information system with images and its application in medical diagnosis. Eng. Appl. Artif. Intell. 92, Article no. 103651 (2020)
3. Putzu, L., Caocci, G., Ruberto, C.D.: Leucocyte classification for leukaemia detection using image processing techniques. Artif. Intell. Med. 62(3), 179–191 (2014)
4. Saraswat, M., Arya, K.V.: Automated microscopic image analysis for leukocytes identification: a survey. Micron 65, 20–33 (2014)
5. Nazlibilek, S., Karacor, D., Ercan, T., Sazli, M.H., Kalender, O., Ege, Y.: Automatic segmentation, counting, size determination and classification of white blood cells. Measurement 55, 58–65 (2014)
6. Coudray, N., et al.: Classification and mutation prediction from non–small cell lung cancer histopathology images using deep learning. Nat. Med. 24(10), 1559–1567 (2018)
7. Yari, Y., Nguyen, T.V., Nguyen, H.T.: Deep learning applied for histological diagnosis of breast cancer. IEEE Access 8, 162432–162448 (2020)
8. Sermanet, P., Eigen, D., Zhang, X., Mathieu, M., Fergus, R., LeCun, Y.: OverFeat: integrated recognition, localization and detection using convolutional networks. arXiv:1312.6229 (2013)
9. Chollet, F.: Xception: deep learning with depthwise separable convolutions (2017)
10. Huang, G., Liu, Z., van der Maaten, L.: Densely connected convolutional networks (2018)
11. Reta, C., Robles, L.A., Gonzalez, J.A., Diaz, R., Guichard, J.S.: Segmentation of bone marrow cell images for morphological classification of acute leukemia. In: Proceedings of the 23rd International FLAIRS Conference, Daytona Beach, FL, USA, May 2010
12. Putzu, L., Caocci, G., Ruberto, C.D.: Leucocyte classification for leukaemia detection using image processing techniques. Artif. Intell. Med. 62, 179–191 (2014)
13. Patel, N., Mishra, A.: Automated leukaemia detection using microscopic images. Procedia Comput. Sci. 58, 635–642 (2015)
14. Singhal, V., Singh, P.: Texture features for the detection of acute lymphoblastic leukemia. In: Satapathy, S.C., Joshi, A., Modi, N., Pathak, N. (eds.) Proceedings of International Conference on ICT for Sustainable Development. AISC, vol. 409, pp. 535–543. Springer, Singapore (2016). https://doi.org/10.1007/978-981-10-0135-2_52
15. Laosai, J., Chamnongthai, K.: Acute leukemia classification by using SVM and K-means clustering. In: Proceedings of the 2014 IEEE International Electrical Engineering Congress (iEECON), Chonburi, Thailand, 19–21 March 2014, pp. 1–4 (2014)
16. Singhal, V., Singh, P.: Local binary pattern for automatic detection of acute lymphoblastic leukemia. In: Proceedings of the 2014 Twentieth National Conference on Communications (NCC), Kanpur, India, 28 February–2 March 2014, pp. 1–5 (2014)
17. Song, Z., et al.: Clinically applicable histopathological diagnosis system for gastric cancer detection using deep learning. Nat. Commun. 11(1), 1–9 (2020)
18. Gehrung, M., Crispin-Ortuzar, M., Berman, A.G., O’Donovan, M., Fitzgerald, R.C., Markowetz, F.: Triage-driven diagnosis of Barrett’s esophagus for early detection of esophageal adenocarcinoma using deep learning. Nat. Med. 27(5), 833–841 (2021)

A Deep Transfer Fusion Model for Recognition


19. Rehman, A., Abbas, N., Saba, T., Rahman, S.I.u., Mehmood, Z., Kolivand, H.: Classification of acute lymphoblastic leukemia using deep learning. Microsc. Res. Tech. 81, 1310–1317 (2018)
20. Shafique, S., Tehsin, S.: Acute lymphoblastic leukemia detection and classification of its subtypes using pretrained deep convolutional neural networks. Technol. Cancer Res. Treat. 17(1533033818802789) (2018). 31
21. Anaya-Isaza, A., Mera-Jiménez, L.: Data augmentation and transfer learning for brain tumor detection in magnetic resonance imaging, 24 February 2022
22. Pansombut, T., Wikaisuksakul, S., Khongkraphan, K., Phon-On, A.: Convolutional neural networks for recognition of lymphoblast cell images. Comput. Intell. Neurosci. 2019 (2019)
23. Habibzadeh, M., Jannesari, M., Rezaei, Z., Baharvand, H., Totonchi, M.: Automatic white blood cell classification using pre-trained deep learning models: ResNet and inception. In: Proceedings of the Tenth International Conference on Machine Vision (ICMV 2017), International Society for Optics and Photonics, Vienna, Austria, 13–15 November 2017, vol. 10696, p. 1069612 (2017)
24. Pan, S.J., Yang, Q.: A survey on transfer learning. IEEE Trans. Knowl. Data Eng. 22(10), 1345–1359 (2010)
25. Labati, R.D., Piuri, V., Scotti, F.: The acute lymphoblastic leukemia image database for image processing. In: 18th IEEE International Conference on Image Processing (ICIP); Università degli Studi di Milano, Department of Information Technology, via Bramante 65, vol. 26013, pp. 2089–2092 (2011)
26. Honomichl, N.: The cancer imaging archive (TCIA), C_NMC_2019 dataset: All challenge dataset of ISBI (2019). https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=52758223
27. Kaggle (2023). https://www.kaggle.com/datasets/kylegraupe/skin-cancer-binary-classification-dataset
28. Kaggle (2021). https://www.kaggle.com/datasets/denizkavi1/brain-tumor
29. Loey, M., Naman, M., Zayed, H.: Deep transfer learning in diagnosing leukemia in blood cells, 15 April 2020
30. Vogado, L., et al.: Diagnosis of leukaemia in blood slides based on a fine-tuned and highly generalisable deep learning model. Sensors 21, 2989 (2021)
31. Boldú, L., Merino, A., Acevedo, A., Molina, A., Rodellar, J.: A deep learning model (ALNet) for the diagnosis of acute leukaemia lineage using peripheral blood cell images (2021)
32. Selvaraju, R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., Batra, D.: Grad-CAM: visual explanations from deep networks via gradient-based localization. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 618–626 (2017)

IntrNet: Weakly Supervised Segmentation of Thyroid Nodules Based on Intra-image and Inter-image Semantic Information

Jie Gao (1,2,3, corresponding author), Shaoqi Yan (1,2,3), Xuzhou Fu (1,2,3), Zhiqiang Liu (1,2,3), Ruiguo Yu (1,2,3,4), and Mei Yu (1,2,3)

1 College of Intelligence and Computing, Tianjin University, Tianjin, China
2 Tianjin Key Laboratory of Cognitive Computing and Application, Tianjin, China
3 Tianjin Key Laboratory of Advanced Networking, Tianjin, China
4 Tianjin International Engineering Institute, Tianjin University, Tianjin, China

Abstract. In diagnosing thyroid nodules, weakly supervised semantic segmentation methods alleviate the dependence on pixel-level segmentation labels. However, under the supervision of image-level labels, existing methods underutilize the image information of thyroid nodules, both intra-image and inter-image. Firstly, the imaging quality of ultrasound images is poor, which makes it hard for the model to mine semantic information and leads to misclassification of the background. We propose an equivariant attention mechanism that enhances the nodules and the background, enabling the model to extract more accurate intra-image semantic information. Secondly, thyroid nodules have fine-grained properties such as cystic, solid, and calcified; existing methods ignore the semantic information between different fine-grained nodules, making it difficult for the model to learn a comprehensive feature representation. We propose to collect the features of nodules and background across the dataset in a memory pool and to provide the connections between these features through semantic sharing and contrast. Experiments on the TUI dataset show that our method significantly outperforms existing methods, reaching an mIoU of 58.0% and a Dice score of 73.5%.

Keywords: Deep Learning · Medical Image · Thyroid Nodule · Weakly Supervised Semantic Segmentation

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 722–734, 2023.
https://doi.org/10.1007/978-981-99-4742-3_60

1 Introduction

With the rapid development of deep learning, Computer-Aided Diagnosis (CAD) has become increasingly popular in medical image analysis [15]. Medical image segmentation, an essential component of CAD, requires many pixel-level segmentation labels and therefore incurs high annotation costs. To ease the burden of obtaining datasets, some researchers have focused on Weakly Supervised Semantic Segmentation (WSSS), which trains using only image-level labels (i.e., category information). Existing research maps the category vector to pixel regions in the image through the Class Activation Map (CAM) [18]. However, due to the poor imaging quality and low contrast of thyroid nodule ultrasound images, existing WSSS methods can hardly extract the information a model needs to distinguish nodules from the background accurately. Moreover, because the dataset contains various fine-grained nodule types, nodules of the same category differ significantly in size, shape, and location. Existing WSSS methods mine information only within a single image and ignore semantic information between images; this makes it difficult to learn a comprehensive representation of thyroid nodules, causes performance to vary significantly across thyroid nodule images, and degrades the overall segmentation results. In contrast, the backgrounds of different thyroid nodule ultrasound images are similar to one another, so optimizing background recognition can also promote nodule segmentation, a property existing methods have not fully exploited. Yu et al. [12] highlighted thyroid nodules in the feature maps through soft erasure so that the model learns nodule features; however, this cannot help the model extract features from weakly discriminative nodule regions. Yu et al. [13] used an edge attention mechanism to enhance the extraction of thyroid nodule edge features, allowing the model to segment more complete nodules of different sizes. Nevertheless, because the model's discriminative ability is weak, the constraint on the nodule edge has limited effectiveness, and the model cannot learn a comprehensive feature representation. For natural images, some methods use low-level features such as texture and color [1, 3, 8, 9, 16] to help the model learn the relationships between pixels in the image, but they cannot capture deep contextual relationships.
Other researchers employ region erasure methods [4, 6, 11–13, 17], which erase strongly discriminative object regions on the feature maps or the image so that the model recognizes more weakly discriminative regions, but this may lead to over-erasure. Others introduce self-supervised learning [2, 10, 14], augmenting the image and constraining the consistency between the CAMs of the original and augmented images, thus providing additional supervision for the model. However, these methods only employ random augmentation, which may not provide sufficient information. Another line of work explores cross-image relationships [5, 8, 19], providing additional inter-image context during training, but places insufficient emphasis on the background and on intra-image information. This paper presents a novel WSSS method that uses both intra-image and inter-image semantic information to segment thyroid nodules in ultrasound images accurately, as shown in Fig. 1. To extract intra-image semantic information, we propose an Equivariant Attention Mechanism (EAM) that enhances the features of both nodules and background, enabling the model to extract more effective information from the image in real time. For inter-image semantic information, we introduce Semantic Contrast and Sharing Methods based on Foreground and Background features (SCSM-FB). This approach collects the nodule and background features extracted by the model during training and provides the model with the semantic relationships between features across images. SCSM-FB reduces intra-class feature distances while enlarging inter-class feature distances, allowing the model to learn a more comprehensive feature representation and improving its robustness.


J. Gao et al.

In summary, our main contributions are as follows:
(1) To address the model's difficulty in extracting effective information within the image, we propose EAM to mine more information from the foreground and background of the image, narrowing the gap between the supervision available to WSSS and to fully supervised semantic segmentation and thus improving the model's recognition ability.
(2) To address the large intra-class variation of thyroid nodules, we explore the cross-image relationships commonly overlooked in WSSS and provide inter-image semantic information through SCSM-FB, enabling the model to learn more comprehensive feature representations.
(3) Our method combines intra-image and inter-image semantic information. On the thyroid nodule ultrasound image dataset (TUI), its segmentation performance is significantly better than that of existing methods, with an mIoU of 58.0% and a Dice score of 73.5%.

Fig. 1. Overview of our method.

2 Methods

We first generate an original CAM for the original image and predict the category vector using Global Average Pooling (GAP). Then, we generate the corresponding FB-Pair for the original image through EAM and constrain the CAMs of the FB-Pair to be consistent with the CAM of the original image. Finally, we use a Memory Pool (MP) built on the foreground and background features to store the features of each object category, enabling the model to learn a more comprehensive feature representation through semantic sharing and contrast. The overall structure of IntrNet, integrating all of our methods, is shown in Fig. 2.


2.1 EAM

Compared with pixel-level labels, image-level labels do not provide enough supervision for the model to identify thyroid nodules accurately. In addition, ultrasound images of thyroid nodules carry little semantic information, which makes it difficult for the model to mine effective information from them.

Fig. 2. The overall structure of IntrNet.

To address these problems, we propose EAM, which makes full use of intra-image information by constraining the objects identified in the foreground area, the background area, and the original image to be consistent. Suppose there is an ideal semantic segmentation model $F_{ideal}$ that predicts the correct segmentation result $s$ for every thyroid nodule ultrasound image $img$, i.e., $F_{ideal}(img) = s$. Let $A(\cdot)$ be a set of image transformations, including affine transformations such as shrinking and flipping; then the segmentation label of the transformed image $A(img)$ is $A(s)$, which transforms in the same way. Therefore, the following equation holds for $F_{ideal}$:

$$F_{ideal}(A(img)) = A(F_{ideal}(img)) = A(s) \tag{1}$$

For a general WSSS method $F_{WSSS}$, the semantic information in the image cannot be extracted accurately in most cases, so the segmentation label generated on the transformed image $A(img)$ does not satisfy the above equation, i.e., $F_{WSSS}(A(img)) \neq A(F_{WSSS}(img))$. By constraining $F_{WSSS}(A(img))$ and $A(F_{WSSS}(img))$ to be equal, the model $F_{WSSS}$ obtains additional supervision information during training and moves closer to $F_{ideal}$:

$$\lVert F_{WSSS}(A(img)) - A(F_{WSSS}(img)) \rVert_1 = \lVert A^{-1}[F_{WSSS}(A(img))] - F_{WSSS}(img) \rVert_1 \tag{2}$$

where $A^{-1}(\cdot)$ is the inverse transformation applied to the CAM, which restores it to the layout of the original CAM. Existing image enhancement methods, such as scaling, rotation, flipping, and translation, all follow fixed rules, and the information gain they provide to the model is random. SEAM [10] explored these enhancement methods and found that the model's performance varied greatly depending on which methods were selected and which parameter ranges were set. The optimal enhancement methods and parameters differ between datasets, and finding them requires a costly search over combinations. Therefore, we propose EAM, which guides the enhancement process with the CAM the model generates in real time and enhances the original image into a complementary foreground-background pair (FB-Pair) to obtain information gain effectively. Given an input thyroid nodule ultrasound image $img \in \mathbb{R}^{3\times H\times W}$, the model produces an original CAM $y_o \in \mathbb{R}^{C\times H\times W}$, where $C$, $H$, and $W$ denote the total number of categories (including background), the height, and the width, respectively. From $y_o$ we generate a mask $m \in \mathbb{R}^{H\times W}$ with $m(i,j) \in \{0, 1\}$ that separates foreground from background according to a threshold $\theta$. We then erase the original image according to the mask to enhance the features of the nodule and the background separately:

$$EAM(img) = (img \times m,\ img \times \bar{m}) = (img_f,\ img_b) \tag{3}$$
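The FB-Pair construction of Eq. (3) can be sketched in NumPy as follows. This is a minimal illustration, not the authors' code: the paper does not say how the threshold θ is applied to a multi-channel CAM, so the rule used here (a pixel is foreground when its max-normalized nodule activation reaches θ, with channel 0 taken as background and the CAM assumed non-negative) and the function name `eam_fb_pair` are assumptions.

```python
import numpy as np

def eam_fb_pair(img, cam, theta=0.5, eps=1e-5):
    """Split img into the FB-Pair (img_f, img_b) of Eq. (3).

    img: (3, H, W) image; cam: (C, H, W) non-negative CAM, channel 0 = background.
    Assumed rule: a pixel is foreground when its largest max-normalised
    nodule activation reaches theta.
    """
    fg = cam[1:]                                       # nodule channels only
    peak = fg.reshape(fg.shape[0], -1).max(axis=1)     # per-channel maximum
    norm = fg / (peak[:, None, None] + eps)            # scale each channel to about [0, 1]
    m = (norm.max(axis=0) >= theta).astype(img.dtype)  # binary mask m, shape (H, W)
    m_bar = 1.0 - m                                    # complement of m
    img_f = img * m                                    # keep nodule pixels, zero the rest
    img_b = img * m_bar                                # keep background pixels, zero the rest
    return img_f, img_b, m
```

Because $m$ and $\bar{m}$ partition the pixels, `img_f + img_b` always reproduces the input image.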

where $\bar{m} = 1 - m$ is the complement of $m$, and the foreground image $img_f \in \mathbb{R}^{3\times H\times W}$ and the background image $img_b \in \mathbb{R}^{3\times H\times W}$ together form the FB-Pair. In $img_f$, where the nodule accounts for a higher proportion of the image, incorrectly identified background pixels become more prominent; the opposite holds in $img_b$. The CAMs $y_f$ and $y_b$ generated on $img_f$ and $img_b$ therefore extract more semantic information about the nodule and the background, respectively. To enforce the equivariance constraint of EAM, we first remove the interference caused by the zero-padded pixels in $y_f$ and $y_b$, retaining only the activation responses produced at the original image pixels, and then fuse them:

$$EAM^{-1}(y_f, y_b) = m \times y_f + \bar{m} \times y_b = y_f' + y_b' = y_{fuse} \tag{4}$$

where $EAM^{-1}(\cdot)$ denotes the inverse operation of $EAM(\cdot)$, and $y_{fuse}$ is the CAM the model produces at the original image pixel positions after the nodule and background features have been enhanced. It is constrained through equivariant consistency:

$$\lVert y_{fuse} - y_o \rVert_1 = \lVert EAM^{-1}[F_{WSSS}(EAM(img))] - F_{WSSS}(img) \rVert_1 \tag{5}$$
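The fusion of Eq. (4) and the consistency term of Eq. (5) reduce to masking and an L1 distance. The sketch below takes as inputs the CAMs y_f and y_b that the network would produce on img_f and img_b, plus the binary mask m; the helper names are hypothetical, and the L1 norm is summed over all entries (implementations often average instead).

```python
import numpy as np

def fuse_cams(y_f, y_b, m):
    """EAM^{-1} of Eq. (4): keep each CAM only where its image kept real pixels."""
    y_f_masked = m * y_f          # activations on foreground pixels of img_f
    y_b_masked = (1.0 - m) * y_b  # activations on background pixels of img_b
    return y_f_masked + y_b_masked  # y_fuse, shape (C, H, W); m broadcasts over C

def fbcr_loss(y_o, y_fuse):
    """Eq. (5): summed L1 distance between the fused CAM and the original CAM."""
    return np.abs(y_fuse - y_o).sum()
```

In training, `y_o`, `y_f`, and `y_b` would come from three forward passes of the same network (on `img`, `img_f`, and `img_b`).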

In addition, EAM introduces zero-padded areas without semantic information when generating the FB-Pair, whereas the ideal segmentation model $F_{ideal}$ would not be affected by such interference. Taking $img_f$ as an example:

$$F_{ideal}(img \times m) = F_{ideal}(img) \times m \tag{6}$$

The $F_{WSSS}$ approach, by contrast, is prone to such interference because it does not learn the semantic features in the image thoroughly. We therefore apply the equivariance constraint to this process as well, again taking $img_f$ as an example:

$$\lVert F_{WSSS}(img_f \times m) - F_{WSSS}(img_f) \times m \rVert_1 \tag{7}$$

Since $img_f \times m = img_f$, Eq. 7 reduces to $\lVert y_f - m \times y_f \rVert_1$, which penalizes activations that fall in the zero-padded region. By using EAM, the model acquires more intra-image information during training, extracts nodule and background features more effectively, and distinguishes the different nodule categories from the background more accurately.

2.2 SCSM-FB

Because thyroid nodule lesions vary among patients, the morphology and size of nodules in the same category can differ greatly. When extracting features of a nodule category, existing methods may be limited to certain types of nodules and struggle to learn comprehensive feature representations. Therefore, we propose to explore inter-image semantic information based on the foreground and background features to improve the robustness of the model. For each image $img$ in the training set, the model generates an original CAM $y_o$. We then use Masked Average Pooling (MAP) to convert it into a feature embedding vector $e_c$ for each class $c$ (including background):

$$e_c = \frac{\mathrm{Sum}(M_c \times y_o)}{\mathrm{Sum}(M_c)} \in \mathbb{R}^{C} \tag{8}$$
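Masked average pooling in Eq. (8) can be sketched as below. The paper only describes $M_c$ as the "high response regions" of class $c$, so the relative threshold `tau` in the hypothetical helper `class_masks` is an assumption; the pooling itself follows Eq. (8), with the spatial mask broadcast over the $C$ channels.

```python
import numpy as np

def class_masks(y_o, tau=0.5):
    """Hypothetical M_c: pixel is 'high response' if y_o[c] >= tau * max(y_o[c])."""
    peak = y_o.reshape(y_o.shape[0], -1).max(axis=1)[:, None, None]
    return (y_o >= tau * peak).astype(y_o.dtype)

def masked_average_pool(y_o, masks):
    """Eq. (8): row c of the result is e_c in R^C.

    y_o: (C, H, W) original CAM; masks: (C, H, W) binary masks M_c.
    """
    C = y_o.shape[0]
    E = np.zeros((C, C))
    for c in range(C):
        Mc = masks[c]                      # (H, W), broadcast over the C channels
        num = (Mc * y_o).sum(axis=(1, 2))  # Sum(M_c x y_o): spatial sum per channel
        den = Mc.sum()                     # Sum(M_c): number of selected pixels
        E[c] = num / max(den, 1e-5)        # guard against an empty mask
    return E
```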

where $M_c$ is a binary mask in which the high-response regions of the activation map for class $c$ are set to 1. $\mathrm{Sum}(\cdot)$ sums the entries of a 2D matrix; for the 3D tensor $M_c \times y_o$, the mask is broadcast across channels and the sum is taken over the spatial dimensions of each channel, which yields a $C$-dimensional vector. To collect objects of different classes across the entire dataset, we create and maintain a dynamic Memory Pool $MP = \{MP_0, MP_1, MP_2\}$ during training, including the background class. In each training iteration, we write the embedding vector of the current image to its corresponding instance $mp_c$ in the memory pool. Based on the memory pool, we introduce Noise Contrastive Estimation (NCE), commonly used in contrastive learning, to enhance the discriminative ability of the model. Specifically, for the embedding vector $e_c$ of the nodules in an image (Eq. 9), we increase its similarity to the positive memory features $\{mp_c^+ \in MP_c\}$ and decrease its similarity to the negative memory features $\{mp_c^- \in MP - MP_c\}$:

$$L_{NCE}(e_c) = \frac{1}{|MP_c|} \sum_{mp_c^+ \in MP_c} -\log \frac{e^{-\lVert e_c - mp_c^+ \rVert_2}}{e^{-\lVert e_c - mp_c^+ \rVert_2} + \sum_{mp_c^- \in MP - MP_c} e^{-\lVert e_c - mp_c^- \rVert_2}} \tag{9}$$

where $|MP_c|$ is the number of instances in the memory pool that belong to the corresponding class, and $\lVert \cdot \rVert_2$ denotes the $l_2$ distance between vectors.
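A minimal sketch of the memory pool and of the NCE loss in Eq. (9), using $\exp(-l_2\ \text{distance})$ as the similarity. The one-slot-per-image layout of the hypothetical `MemoryPool` class is an assumption based on the description that each iteration overwrites the current image's instance; the log-denominator is computed with `np.logaddexp` for numerical stability, and at least one negative is required.

```python
import numpy as np

class MemoryPool:
    """Hypothetical memory pool: one slot per (class, image id), overwritten on update."""

    def __init__(self, num_classes):
        self.slots = [dict() for _ in range(num_classes)]

    def update(self, c, image_id, e_c):
        self.slots[c][image_id] = np.asarray(e_c, dtype=float)

    def positives(self, c):   # MP_c
        return np.stack(list(self.slots[c].values()))

    def negatives(self, c):   # MP - MP_c
        return np.stack([v for k, s in enumerate(self.slots) if k != c for v in s.values()])

def nce_loss(e_c, pos, neg):
    """Eq. (9): mean over positives of -log(e^{-d+} / (e^{-d+} + sum_j e^{-d-_j}))."""
    d_pos = np.linalg.norm(pos - e_c, axis=1)    # l2 distances to positives
    d_neg = np.linalg.norm(neg - e_c, axis=1)    # l2 distances to negatives
    log_neg = np.logaddexp.reduce(-d_neg)        # log(sum_j e^{-d-_j})
    log_den = np.logaddexp(-d_pos, log_neg)      # log(e^{-d+} + sum_j e^{-d-_j})
    return float(np.mean(d_pos + log_den))       # -log(num/den) = d+ + log(den)
```

The loss decreases as $e_c$ moves toward its own class's memory entries and away from the other classes' entries.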


However, because the supervision information available during training is limited, the memory pool contains a lot of noise. Therefore, we also apply NCE to the background of the image. In thyroid nodule ultrasound images, the backgrounds of different images are also similar, and giving the model a stronger discriminative ability on the background also helps it segment nodules. Thus, we compute the Semantic Contrast loss by additionally calculating the NCE for the background feature embedding $e_0$:

$$L_{SC} = \sum_{c} L_{NCE}(e_c) + L_{NCE}(e_0) \tag{10}$$

where the sum runs over the nodule classes $c$.

A large memory pool can not only enhance the discriminative ability of the model through semantic contrast but also help it learn more comprehensive feature representations through semantic sharing. However, it also contains a lot of redundant information: performing semantic sharing directly on all of it is computationally expensive, and the redundancy may cause the model to overfit. Therefore, for each class-specific memory pool $MP_c$, we use the k-means algorithm to obtain $K$ representative embeddings, yielding a matrix $Q_c \in \mathbb{R}^{K\times C}$. Concatenating the matrices of all classes gives $Q = [Q_0, Q_1, Q_2] \in \mathbb{R}^{C\times K\times C}$. We then obtain the similarity matrix $S$ by matrix multiplication between $Q$ and the original CAM $y_o$:

$$S = \mathrm{softmax}(Q \otimes y_o) \in \mathbb{R}^{(CK)\times(HW)} \tag{11}$$
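The class-wise representatives that make up $Q$ can be obtained with any k-means implementation. The sketch below uses a plain Lloyd's iteration in NumPy as a stand-in; the iteration count, the random initialization, and the helper names `kmeans` and `build_Q` are assumptions, since the paper does not give its k-means settings.

```python
import numpy as np

def kmeans(X, K, iters=20, seed=0):
    """Plain Lloyd's k-means on the rows of X; returns K centroids."""
    rng = np.random.default_rng(seed)
    # initialise with K distinct rows of X (requires len(X) >= K)
    centers = X[rng.choice(len(X), size=K, replace=False)].copy()
    for _ in range(iters):
        # squared l2 distance from every embedding to every centroid: (N, K)
        d = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d.argmin(axis=1)
        for k in range(K):
            members = X[labels == k]
            if len(members) > 0:      # keep the old centroid if a cluster empties
                centers[k] = members.mean(axis=0)
    return centers

def build_Q(pools, K):
    """Stack K representatives per class: each Q_c is (K, C), so Q is (C*K, C)."""
    return np.concatenate([kmeans(np.asarray(P, dtype=float), K) for P in pools], axis=0)
```

Recomputing $Q$ only every few iterations, rather than every step, would be one way to keep this cost down; how often the paper refreshes $Q$ is not stated here.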

For convenience of matrix computation, $Q$ and $y_o$ are reshaped to $Q \in \mathbb{R}^{CK\times C}$ and $y_o \in \mathbb{R}^{C\times HW}$, respectively. The $\mathrm{softmax}(\cdot)$ function normalizes the entries of the matrix to lie between 0 and 1, so that they represent the similarity between the objects of each category in the current image and the representative embeddings. The CAM based on semantic sharing is then obtained by multiplying with $Q$:

$$y_S = S \otimes Q \tag{12}$$

Here $S$ is transposed to $S \in \mathbb{R}^{(HW)\times(CK)}$, $Q \in \mathbb{R}^{CK\times C}$, and the result is reshaped to $y_S \in \mathbb{R}^{C\times H\times W}$. Because $y_S$ is computed from the memory pool maintained over the whole dataset, it supplies object regions of each category that the original CAM may have missed, enabling the model to recognize object regions more completely. Accordingly, we define the Semantic Sharing loss as follows:

$$L_{SS} = \lVert y_o - y_S \rVert_1 \tag{13}$$
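Eqs. (11)–(13) amount to two matrix products and an L1 distance. In this sketch the softmax of Eq. (11) is taken over the $CK$ representatives for each pixel, which is an assumption (the paper does not state the axis); with that choice every pixel of $y_S$ is a convex combination of the representative embeddings. The function name `semantic_sharing` is hypothetical.

```python
import numpy as np

def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)   # shift for numerical stability
    ex = np.exp(x)
    return ex / ex.sum(axis=axis, keepdims=True)

def semantic_sharing(Q, y_o):
    """Eqs. (11)-(13). Q: (CK, C) representatives; y_o: (C, H, W) original CAM."""
    C, H, W = y_o.shape
    y_flat = y_o.reshape(C, H * W)            # y_o viewed as (C, HW)
    S = softmax(Q @ y_flat, axis=0)           # (CK, HW); assumed: normalise over representatives
    y_S = (S.T @ Q).T.reshape(C, H, W)        # (HW, CK) @ (CK, C) -> (HW, C) -> (C, H, W)
    l_ss = np.abs(y_o - y_S).sum()            # Eq. (13), summed L1 norm
    return S, y_S, l_ss
```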

2.3 Loss

The loss function of the proposed model consists of two parts: the intra-image loss $L_{intra}$ and the inter-image loss $L_{inter}$. The intra-image loss first computes the classification loss: GAP produces a predicted class vector $z$, which is compared with the class label $l$ using a cross-entropy loss:

$$L_{CLS} = -\frac{1}{C-1} \sum_{c=1}^{C-1} \left[ l_c \log\frac{1}{1+e^{-z_c}} + (1-l_c) \log\frac{e^{-z_c}}{1+e^{-z_c}} \right] \tag{14}$$

where $l_c$ and $z_c$ denote the values of $l$ and $z$ at class $c$, respectively. We propose EAM to mine more accurate intra-image semantic information. The FB-Pair Cross Regularization (FBCR) constrains the equivariance consistency between the FB-Pair and the original image, while the Feature Extraction Regularization (FER) is the equivariance-consistency loss for the erasure performed when generating the FB-Pair:

$$L_{FBCR} = \lVert y_o - y_{fuse} \rVert_1 \tag{15}$$

$$L_{FER} = \lVert y_f - y_f' \rVert_1 + \lVert y_b - y_b' \rVert_1 \tag{16}$$

where $y_f' = m \times y_f$ and $y_b' = \bar{m} \times y_b$ (with $\bar{m} = 1 - m$) are the masked CAMs of Eq. 4.

$$L_{intra} = L_{CLS} + L_{FBCR} + L_{FER} \tag{17}$$

The inter-image loss includes both the Semantic Contrast loss and the Semantic Sharing loss:

$$L_{inter} = L_{SC} + L_{SS} \tag{18}$$

Finally, the overall loss of the model is the sum of the intra-image and inter-image losses:

$$L_{all} = L_{intra} + L_{inter} \tag{19}$$
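For reference, the classification loss of Eq. (14) is a multi-label soft-margin (sigmoid cross-entropy) loss over the $C-1$ foreground classes, and the FER term of Eq. (16) penalizes activations inside the zero-padded region of each FB-Pair image. A NumPy sketch, assuming channel 0 is the background and using `np.logaddexp` for numerically stable log-sigmoids; the function names are hypothetical.

```python
import numpy as np

def cls_loss(z, l):
    """Eq. (14): multi-label soft-margin loss over classes 1..C-1.

    z: (C,) GAP scores; l: (C,) binary image-level labels; index 0 (background) is skipped.
    """
    z = np.asarray(z, dtype=float)[1:]
    l = np.asarray(l, dtype=float)[1:]
    log_p = -np.logaddexp(0.0, -z)      # log(1 / (1 + e^{-z}))
    log_not_p = -np.logaddexp(0.0, z)   # log(e^{-z} / (1 + e^{-z}))
    return float(-np.mean(l * log_p + (1.0 - l) * log_not_p))

def fer_loss(y_f, y_b, m):
    """Eq. (16): ||y_f - m*y_f||_1 + ||y_b - (1-m)*y_b||_1 (activations in the erased areas)."""
    return np.abs(y_f - m * y_f).sum() + np.abs(y_b - (1.0 - m) * y_b).sum()
```

With all scores at zero the loss is $\log 2$ regardless of the labels, which is a quick sanity check of the sign convention.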

3 Experiments

We conducted experiments on the TUI dataset, provided by a collaborating hospital, to verify the effectiveness and stability of the proposed WSSS method. The training set includes 1334 benign and 1331 malignant thyroid nodule images, while the validation set contains 300 benign and 300 malignant thyroid nodule images. We used IoU and Dice as evaluation metrics for all experiments.

3.1 Comparison Experiments

Table 1 compares our method with existing methods in the same experimental environment. Under the IoU metric, existing methods perform poorly at recognizing nodules, and, notably, their performance differs significantly between benign and malignant nodules. For instance, SSE and its derivative SSE-WSSN show differences of approximately 10%, and even SEAM, which has the smallest difference among them, shows 4.2%. This indicates that existing methods have not fully learned the features of both benign and malignant nodules and cannot achieve optimal performance on both. Using intra-image semantic information, our method allows the model to fully learn the semantic features of the different categories and narrows the performance difference between benign and malignant nodules to 2.9%. Although its performance on malignant nodules is slightly below that of SSE-WSSN, the overall performance improves by 5.7%. After also exploiting inter-image semantic information, IntrNet further improves its performance on benign nodules to 58.8% and surpasses SSE-WSSN on malignant nodules with 57.2%, narrowing the difference to 1.6%. Overall, the performance improves to 58.0%, significantly outperforming existing methods. The Dice coefficient focuses more on measuring the overlap between the prediction and the ground truth. The relative performance of the methods under Dice is consistent with that under IoU, and our method again significantly outperforms existing methods, reaching 73.5%.

Table 1. Comparison with existing methods on the TUI dataset (values in %; B. = benign, M. = malignant).

| Method | IoU B. | IoU M. | IoU Diff. | IoU Mean | Dice B. | Dice M. | Dice Diff. | Dice Mean |
|---|---|---|---|---|---|---|---|---|
| Baseline | 44.6 | 38.0 | 6.6 | 41.3 | 61.7 | 55.1 | 6.6 | 58.4 |
| SEAM | 47.0 | 42.8 | 4.2 | 44.9 | 63.9 | 59.9 | 4.0 | 61.9 |
| Puzzle-CAM | 46.4 | 39.4 | 7.0 | 42.9 | 63.4 | 56.5 | 6.9 | 60.0 |
| CPN | 51.9 | 47.1 | 4.8 | 49.5 | 68.3 | 64.0 | 4.3 | 66.2 |
| SSE | 44.9 | 55.2 | 10.3 | 50.1 | 61.8 | 71.2 | 9.4 | 66.5 |
| SSE-WSSN | 46.2 | 56.1 | 9.9 | 51.2 | 63.2 | 71.9 | 8.7 | 67.6 |
| IntrNet (intra-image) | 58.3 | 55.4 | 2.9 | 56.9 | 73.7 | 71.3 | 1.8 | 72.5 |
| IntrNet (all) | 58.8 | 57.2 | 1.6 | 58.0 | 74.1 | 72.8 | 1.3 | 73.5 |
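IoU and Dice are standard overlap measures. The sketch below computes both for one binary prediction/ground-truth pair; how the paper aggregates them over images and classes is not specified here, and the function name `iou_dice` is hypothetical. For a single pair, Dice = 2·IoU/(1 + IoU), and the per-class values in Table 1 are consistent with this relation (e.g., an IoU of 58.8% corresponds to a Dice of about 74.1%).

```python
import numpy as np

def iou_dice(pred, gt, eps=1e-7):
    """IoU and Dice for one pair of binary masks (True/1 = nodule)."""
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    inter = np.logical_and(pred, gt).sum()   # |P ∩ G|
    union = np.logical_or(pred, gt).sum()    # |P ∪ G|
    iou = inter / (union + eps)
    dice = 2 * inter / (pred.sum() + gt.sum() + eps)
    return iou, dice
```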

Figure 3 displays qualitative results on the TUI dataset. The segmentation results of Baseline, SEAM, and Puzzle-CAM exhibit obvious over-activation and fail to extract semantic features from difficult samples in the TUI dataset (first column of images). CPN alleviates over-activation but still struggles to learn semantic features from difficult samples. SSE-WSSN successfully extracts semantic features from difficult images, but its segmentation results suffer from under-activation. Using only intra-image semantic information, IntrNet accurately extracts semantic features from difficult samples but still over-activates. By additionally exploiting inter-image semantic information through semantic sharing and contrast, IntrNet effectively alleviates the over-activation problem and produces the results closest to the Ground Truth.

Fig. 3. Qualitative results on the TUI dataset.

3.2 Ablation Study

Ablation on Each Module. Table 2 shows the effectiveness of each module of our method. The Baseline, which uses only the classification loss $L_{CLS}$ during training, produces CAM-based segmentation results with an mIoU of 41.3% over benign and malignant nodules in the TUI dataset. Adding $L_{FBCR}$ or $L_{FER}$ from the intra-image part significantly improves on the Baseline, by 9.6% and 9.5%, respectively. When all intra-image semantic information is used during training, the result further improves to 56.9%, indicating that the components of the intra-image semantic information are compatible with one another. Building on this, our method raises the segmentation performance on TUI to 57.5% and 58.0% by further exploring cross-image relationships and adding semantic sharing and then semantic contrast during training. This also indicates that the intra-image and inter-image semantic information in our method are compatible, and the Dice coefficient results show the same trend.

Table 2. Ablation experiment on each module. $L_{CLS}$, $L_{FBCR}$, and $L_{FER}$ are intra-image losses; $L_{SS}$ and $L_{SC}$ are inter-image losses.

| $L_{CLS}$ | $L_{FBCR}$ | $L_{FER}$ | $L_{SS}$ | $L_{SC}$ | mIoU (%) |
|---|---|---|---|---|---|
| ✔ | | | | | 41.3 |
| ✔ | ✔ | | | | 50.9 |
| ✔ | | ✔ | | | 50.8 |
| ✔ | ✔ | ✔ | | | 56.9 |
| ✔ | ✔ | ✔ | ✔ | | 57.5 |
| ✔ | ✔ | ✔ | ✔ | ✔ | 58.0 |

Ablation on Intra-image Semantic Information. EAM allows the model to fully learn intra-image semantic information, enabling accurate localization of nodules and reducing misclassification of noisy areas in ultrasound images. As shown in Fig. 4, SEAM and CPN were not designed to handle the low contrast, low resolution, and other defects of ultrasound images; as a result, on some thyroid nodule ultrasound images, the intra-image information they provide still identifies background regions, such as noise at the image corners, as nodules, while the true nodule regions receive only low responses. After integrating the EAM module, in contrast, the model fully learns intra-image semantic information and accurately distinguishes nodules from the background.

Ablation on Inter-image Semantic Information. To demonstrate the improvement in segmentation performance brought by inter-image semantic information, this section visualizes the difference between the predicted labels of the proposed method and the Ground-Truth segmentation labels under different levels of image-information utilization. The results are shown in Fig. 5, where white areas mark incorrectly predicted regions. (a) shows representative thyroid nodule ultrasound images from the TUI dataset, some of which are difficult samples; the first three images contain benign nodules and the last three malignant nodules. (b) shows the segmentation labels generated by the proposed method using only intra-image semantic information. In this case, there are significant over-activated areas at the edges of some nodules, such as in the first and fifth images, while under-activation occurs at the edges of others, such as in the second, third, and sixth images. (c) shows the segmentation labels obtained by adding the semantic sharing module to the method in (b): under-activation at the nodule edges is significantly reduced in the second, third, and sixth images, and over-activation in the first and fifth images is slightly reduced. (d) shows the segmentation labels generated by the complete method, which additionally incorporates the semantic contrast module. Both over-activation and under-activation are markedly reduced, and the predictions at the nodule edges agree more closely with the ground-truth segmentation labels, indicating that the model's discriminative ability at nodule edges has been enhanced.


Fig. 4. Visualization of semantic features extracted by SEAM and CPN.

Fig. 5. Ablation on inter-image semantic information.

4 Conclusion

In this paper, we train an efficient segmentation model using only image-level labels. We use intra-image and inter-image information to alleviate the over-activation and under-activation problems of WSSS methods on thyroid nodule ultrasound images. To exploit intra-image information, we generate enhanced images through EAM and constrain their equivariance with the original images to provide additional supervision. To exploit inter-image information, we balance the model's performance across different nodule types through semantic sharing and semantic contrast. Experiments on the TUI dataset show that the segmentation performance of our method is significantly better than that of current state-of-the-art methods.

References
1. Ahn, J., Kwak, S.: Learning pixel-level semantic affinity with image-level supervision for weakly supervised semantic segmentation. In: CVPR, pp. 4981–4990. IEEE (2018)
2. Jo, S., Yu, I.J.: Puzzle-CAM: improved localization via matching partial and full features. In: ICIP, pp. 639–643. IEEE (2021)
3. Kolesnikov, A., Lampert, C.H.: Seed, expand and constrain: three principles for weakly-supervised image segmentation. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9908, pp. 695–711. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46493-0_42
4. Lee, J., Kim, E., Lee, S., Lee, J., Yoon, S.: FickleNet: weakly and semi-supervised semantic image segmentation using stochastic inference. In: CVPR, pp. 5267–5276 (2019)
5. Li, X., Zhou, T., Li, J., Zhou, Y., Zhang, Z.: Group-wise semantic mining for weakly supervised semantic segmentation. In: AAAI, vol. 35, pp. 1984–1992 (2021)
6. Stammes, E., Runia, T.F., Hofmann, M.: Find it if you can: end-to-end adversarial erasing for weakly-supervised semantic segmentation. In: ICDIP, pp. 610–619. SPIE (2021)
7. Sun, G., Wang, W., Dai, J., Van Gool, L.: Mining cross-image semantics for weakly supervised semantic segmentation. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12347, pp. 347–365. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58536-5_21
8. Wang, X., Liu, S., Ma, H., Yang, M.H.: Weakly-supervised semantic segmentation by iterative affinity learning. Int. J. Comput. Vis. 128, 1736–1749 (2020)
9. Wang, X., You, S., Li, X., Ma, H.: Weakly-supervised semantic segmentation by iteratively mining common object features. In: CVPR, pp. 1354–1362 (2018)
10. Wang, Y., Zhang, J., Kan, M.: Self-supervised equivariant attention mechanism for weakly supervised semantic segmentation. In: CVPR, pp. 12275–12284 (2020)
11. Wei, Y., Feng, J., Liang, X.: Object region mining with adversarial erasing: a simple classification to semantic segmentation approach. In: CVPR, pp. 1568–1576 (2017)
12. Yu, M., et al.: Adaptive soft erasure with edge self-attention for weakly supervised semantic segmentation: thyroid ultrasound image case study. Comput. Biol. Med. 144, 105347 (2022)
13. Yu, M., Han, M., Li, X.: SSE: scale-adaptive soft erase weakly supervised segmentation network for thyroid ultrasound images. In: BIBM, pp. 1615–1618. IEEE (2021)
14. Zhang, F., Gu, C., Zhang, C., Dai, Y.: Complementary patch for weakly supervised semantic segmentation. In: ICCV, pp. 7242–7251 (2021)
15. Zhang, M., Chen, Y.: PPO-CPQ: a privacy-preserving optimization of clinical pathway query for e-healthcare systems. IEEE Internet Things J. 7(10), 10660–10672 (2020)
16. Zhang, X., Peng, Z., Zhu, P.: Adaptive affinity loss and erroneous pseudo-label refinement for weakly supervised semantic segmentation. In: ACM Multimedia, pp. 5463–5472 (2021)
17. Zhang, X., Wei, Y., Feng, J., Yang, Y., Huang, T.S.: Adversarial complementary learning for weakly supervised object localization. In: CVPR, pp. 1325–1334 (2018)
18. Zhou, B., Khosla, A., Lapedriza, A., Oliva, A., Torralba, A.: Learning deep features for discriminative localization. In: CVPR, pp. 2921–2929 (2016)
19. Zhou, T., Zhang, M., Zhao, F., Li, J.: Regional semantic contrast and aggregation for weakly supervised semantic segmentation. In: CVPR, pp. 4299–4309 (2022)

Computational Intelligence and Its Application

Novel Ensemble Method Based on Improved k-nearest Neighbor and Gaussian Naive Bayes for Intrusion Detection System

Lina Ge1,2,3, Hao Zhang1,2(B), and Haiao Li1,2

1 School of Artificial Intelligence, Guangxi Minzu University, Nanning, China
[email protected]
2 Key Laboratory of Network Communication Engineering, Guangxi Minzu University, Nanning, China
3 Guangxi Key Laboratory of Hybrid Computation and IC Design Analysis, Nanning, China

Abstract. Frequent network intrusion events strain the performance of intrusion detection systems. These events generally include many unknown attacks, which intrusion detection systems struggle to handle effectively. This study proposes a novel machine-learning-based ensemble method that improves how well intrusion detection systems detect unknown attacks. The method uses a deep sparse autoencoder with a unique topological structure to reduce data dimensionality. The k-nearest neighbor algorithm is then improved and integrated with the Gaussian naive Bayes algorithm through voting. Experiments on the NSL-KDD and CICIDS2017 benchmark datasets verified the method's effectiveness and gave encouraging results compared with several well-known competitors.

Keywords: Intrusion detection systems · Autoencoder · K-nearest neighbor · Gaussian naive Bayes

1 Introduction

With the rapid development of network technology, new forms of network intrusion keep emerging, so secure network environments are increasingly necessary. The essence of a network intrusion is to obtain sensitive data by illegally gaining administrator privileges. This seriously undermines the integrity, availability, and confidentiality of user data. As network intrusions increase, traditional technologies such as firewalls, data encryption, and authentication can no longer meet security requirements. Intrusion detection systems are vital to network security because they can monitor potential intrusions. An intrusion detection system is a network security device that detects, analyzes, and responds to network traffic in real time. Intrusion detection systems can be divided into misuse-based and anomaly-based systems [1]. Misuse-based systems detect intrusions from signatures and can therefore detect only known attack types. Anomaly-based systems detect anomalous behaviors and can discover unknown attack types [2]. For this reason, anomaly-based intrusion detection systems have attracted extensive research. However, they are still at the development stage and are less widely deployed in practice than misuse-based systems.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 737–748, 2023.
https://doi.org/10.1007/978-981-99-4742-3_61

In recent years, machine learning has been applied successfully to intrusion detection. Typical methods such as decision trees, support vector machines, and random forests have shown high accuracy in detecting abnormal traffic. Compared with traditional methods, machine learning can efficiently filter abnormal traffic out of massive traffic data, so it plays a vital role in intrusion detection. Ensemble learning has received considerable attention because of its strong classification performance. Combining multiple weak classifiers into a strong classifier, so that the resulting model generalizes and predicts well, is an important research direction in machine learning [3]. Ensemble learning is applied in many fields, such as natural language processing and intrusion detection [4]. By combining the advantages of multiple classifiers, it outperforms a single classifier. Feature reduction can mitigate the impact of high-dimensional data on the model: feeding dimensionality-reduced data into the model can shorten training time and improve performance [5]. Recently, combining feature reduction with ensemble learning has become a new trend in intrusion detection.

Datasets can be used in two ways in intrusion detection [6]. The first research method divides the training set into training and testing parts according to a certain proportion. Models evaluated this way often achieve high test accuracy. However, they generally perform poorly against zero-day attacks, because they are tested on a portion of the training set that contains only attack types the trained model already knows. The second research method does not divide the training set. Instead, it tests the model on additional test sets that contain network attack types never seen in the training set. A model evaluated this way is therefore tested on its ability to detect new attack types. Unfortunately, few studies have followed this second research method [7], and in those that did, accuracy on the test set was generally low. This study adopts the second research method and addresses this problem. We propose using the improved k-nearest neighbor (IKNN) algorithm and the Gaussian naive Bayes (GNB) algorithm to increase model accuracy. Because a single classifier has limitations, the IKNN and GNB algorithms are integrated by voting, which can average the classification error among the base classifiers and thus improve accuracy [8]. Furthermore, we use a deep sparse autoencoder (DSAE) to reduce the data dimensionality. The remainder of this paper is organized as follows. Section 2 describes the proposed methods, including dimensionality reduction and the ensemble method. Section 3 presents the experimental details and results. Section 4 presents conclusions and prospects.

2 Methodology

The second research method tends to yield poor performance. To address this, a DSAE is proposed for data dimensionality reduction, and an ensemble model combining IKNN and GNB is implemented for classification. Performance is evaluated with four metrics: accuracy, precision, recall, and F1 score. Figure 1 shows a schematic diagram of the proposed method, which consists of three modules: data preprocessing, data extraction, and classification. In the data preprocessing module, the training and test sets are processed in the same way. First, symbolic (character) features are one-hot encoded. Second, the dataset is normalized to speed up model training. The training set is then used to train a DSAE model. In the data extraction module, we design a deep sparse autoencoder with a unique topology and use it to reduce the dimensionality of the training and test sets, as described in Sect. 2.1. In the classification module, the IKNN and GNB algorithms are combined through voting to improve overall performance.

Fig. 1. A schematic diagram of the proposed method (Datasets → Data Preprocessing → Data Extraction → Classification)

2.1 Feature Reduction

With the advent of the big data era, data dimensionality has grown rapidly. High dimensionality leads to the curse of dimensionality, which increases the model's training cost and makes it hard for the model to learn helpful information from the data. The dimensionality of the dataset must therefore be reduced. PCA is currently the most popular feature reduction method. However, PCA is a linear method, so it is limited when reducing nonlinear data.

A shallow sparse autoencoder generally consists of an input layer, a single hidden (bottleneck) layer, and an output layer. In a sparse autoencoder, x ∈ R^n is the input vector, y ∈ R^m is the bottleneck layer, h_l ∈ R^k is the lth hidden layer, and x̂ ∈ R^n is the output vector. A DSAE is formed by stacking multiple hidden layers. A typical shallow sparse autoencoder learns only limited information from the original data. This study therefore implements a DSAE with multiple stacked hidden layers, so that the autoencoder can learn further hidden information from the original data. The structure comprises a four-layer encoder and a four-layer decoder. For the NSL-KDD dataset, the encoder layers have sizes 121, 64, 32, and 16, and the decoder layers have sizes 16, 32, 64, and 121. Similarly, for the CICIDS2017 dataset, the encoder and decoder structures are 78-40-20-10 and 10-20-40-78, respectively. Each encoding layer compresses the data by a factor of about 0.5, and each decoding layer expands it by a factor of about 2. For NSL-KDD, the encoder input and the decoder output have 121 dimensions, which is the dimensionality of the preprocessed dataset. Dimensionality reduction proceeds as follows. The encoder compresses the original data into 16-dimensional features, and the decoder then reconstructs the original data from them. The parameters are updated by error backpropagation, and this process is iterated until the loss function converges. Finally, the encoder with the optimal parameters is saved.

2.2 Ensemble Method

Traditional approaches, such as classifying the reduced data with a single classifier, no longer meet practical needs [4]. Ensemble methods compensate for the shortcomings of a single classifier by combining multiple weak classifiers into a strong classifier, and they perform better than any single classifier. Ensembles are divided into homogeneous and heterogeneous ones according to whether their base classifiers are alike. A homogeneous ensemble comprises several identical base classifiers, whereas a heterogeneous ensemble is composed of different base classifiers. Zhou et al. [9] proposed a heterogeneous ensemble based on C4.5, random forest, and ForestPA, which outperformed other methods on the NSL-KDD test set. Andalib et al. [10] proposed a heterogeneous ensemble based on a GRU, a CNN, and random forest and achieved accuracy similar to that of [9] on the NSL-KDD test set. This study adopts a heterogeneous ensemble that integrates the IKNN and GNB algorithms through voting.

2.2.1 IKNN Algorithm

The KNN algorithm is a classic supervised machine-learning algorithm. It measures the similarity between a test sample and the training samples and selects the k training samples most similar to the test sample. The test sample then receives the label held by the majority of these k neighbors. KNN is widely used in classification tasks because it is simple to implement and achieves high classification accuracy. Sameera et al. [11] demonstrated that the KNN classifier classifies dimensionality-reduced data well. The voting weights of the k neighbors affect KNN's accuracy. In the traditional KNN algorithm, all k neighbors have the same voting weight. In practice, however, different voting weights usually fit the data better, because the neighbors differ in similarity. To overcome this limitation of equal voting weights, this study improves the traditional KNN algorithm and proposes the IKNN algorithm.

IKNN uses the reciprocal of the Gaussian function to assign different weights to the k neighbors. We use the Euclidean distance as the similarity measure and sort the distances of the k neighbors in ascending order; a smaller distance indicates greater similarity between the neighbor and the test sample. In the reciprocal Gaussian function, the abscissa is the distance and the ordinate is the weight. When the distances of the k neighbors are used as input, the neighbors at both ends of the sorted list receive larger weights and those in the middle receive smaller weights. This strategy of "removing the middle and taking the two ends" lowers the voting weight of neighbors whose distances lie in the middle and raises it for neighbors whose distances lie at the two ends. The most similar neighbors thus receive large voting weights, while less similar neighbors can still exert an influence. The weights are calculated as shown in Eq. (1):

W^{1\times k} = 1 \Big/ \left[ \frac{1}{\sqrt{2\pi}\,\sigma} \exp\left( -\frac{\left(D_k^{1\times k} - u\right)^2}{2\sigma^2} \right) \right] \qquad (1)

where D_k^{1×k}, σ^2, and u denote the 1×k matrix of the k neighbor distances sorted in ascending order, the variance of this matrix, and its mean, respectively. The operations are applied element-wise, which yields a 1×k weight matrix W^{1×k}. Algorithm 1 presents the IKNN algorithm, where N and N′ are the sizes of the training and test sets, respectively, x_i is the ith sample in the training set, and x_j is the jth sample in the test set.
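The NumPy sketch below makes Eq. (1) and the weighted vote of Eq. (3) concrete. It is written under stated assumptions and is not the authors' Algorithm 1. σ^2 is taken as the population variance of the k sorted distances. Equidistant neighbors (σ = 0) fall back to equal weights. The helper names `reciprocal_gaussian_weights` and `iknn_predict` are hypothetical.

```python
import numpy as np


def reciprocal_gaussian_weights(dist):
    """Eq. (1): weight = 1 / Gaussian pdf, evaluated on the k sorted neighbour distances.

    Assumption: sigma^2 is the population variance of these k distances.
    """
    u = dist.mean()
    sigma = dist.std()
    if sigma < 1e-12:  # all neighbours equidistant: equal weights (assumption)
        return np.ones_like(dist)
    pdf = np.exp(-((dist - u) ** 2) / (2.0 * sigma ** 2)) / (np.sqrt(2.0 * np.pi) * sigma)
    return 1.0 / pdf  # large at both ends of the sorted list, smallest near the mean


def iknn_predict(X_train, y_train, X_test, k=20):
    """Weighted k-NN vote in the spirit of Eq. (3), using reciprocal-Gaussian weights."""
    preds = []
    for x in X_test:
        d = np.linalg.norm(X_train - x, axis=1)  # Euclidean distance to every training sample
        idx = np.argsort(d)[:k]                  # k nearest, distances in ascending order
        w = reciprocal_gaussian_weights(d[idx])
        labels = y_train[idx]
        classes = np.unique(labels)
        scores = [w[labels == c].sum() for c in classes]  # total weight per candidate label
        preds.append(classes[int(np.argmax(scores))])
    return np.array(preds)
```

For a single test sample, the factor √(2π)σ is shared by all k weights, so it does not affect the vote. Only the term exp((d − u)^2/(2σ^2)) changes the relative influence of the neighbors.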

2.2.2 GNB

GNB uses statistics and probability theory to analyze sample attributes probabilistically. As a variant of naive Bayes, GNB assumes that, given the class, the attributes of each sample are conditionally independent of one another and that each attribute follows a Gaussian distribution. By Bayes' formula, the probability of the target category given the attributes x_{11}, x_{12}, …, x_{1n} (where n is the feature dimension) is:

p(y \mid x_{11}, x_{12}, \ldots, x_{1n}) = \frac{p(y)\, p(x_{11}, x_{12}, \ldots, x_{1n} \mid y)}{p(x_{11}, x_{12}, \ldots, x_{1n})} \qquad (2)

where y, p(y), and p(x_{11}, x_{12}, …, x_{1n} | y) denote the target category, its prior probability, and the likelihood, respectively; the left-hand side is the posterior probability. Under the independence assumption, the likelihood factorizes into the product of the per-attribute Gaussian densities p(x_{1i} | y). Studies have shown that GNB classifies dimensionality-reduced data well [5].

2.2.3 Ensemble Model Based on IKNN and GNB

Bagging, boosting, stacking, and voting are currently the most common methods for integrating multiple classifiers. Bagging and boosting generally build homogeneous ensembles. Because the proposed method combines two different classifiers, these two techniques are not suitable here. Stacking can integrate different base classifiers, but it is not considered because of its high computational cost. Voting is therefore the most suitable choice, and this study uses it to integrate the two classifiers. In the sklearn library, voting can aggregate classification results in two ways: soft voting and hard voting. Soft voting bases the final result on the averaged predicted class probabilities. Hard voting uses the majority principle: it chooses the target category predicted most often. Both schemes apply to classification; this study adopts hard voting. Let c_j (j = 1, 2, …, k) be the labels of the k nearest neighbors in the IKNN algorithm, N(k) the set of k nearest neighbors of the test sample in the training set, and I(·) the indicator function. y_k and y_g are the labels that IKNN and GNB predict for a test sample, calculated using Eqs. (3) and (4), respectively. The final result y is calculated using Eq. (5).

Algorithm 2 details the ensemble classification. Here y_i ∈ {y_1, y_2, …, y_m} denotes the sample labels, m is the number of labels, and (x_{j1}, x_{j2}, …, x_{jn}) is the feature vector of the jth test sample, with n the data dimension. In the decision phase of Algorithm 2, if IKNN and GNB give the same prediction for an instance, the ensemble adopts this common prediction as the final result (lines 8–10 of Algorithm 2). If IKNN and GNB give different results (lines 11–15), the ensemble decides according to the overall performance of the base classifiers on the dataset. The base classifiers are first sorted in ascending order of performance, so that later classifiers perform better. For example, if IKNN is more accurate than GNB on the dataset, the ensemble selects the IKNN prediction as the final result, and vice versa.

y_k = \arg\max_{y_i} \sum_{j=1}^{k} I\left(y_i = c_j\right) \cdot W^{1\times k}_{j}, \quad i = 1, \ldots, m \qquad (3)

y_g = \arg\max_{y_i} P\left(y_i \mid x_{j1}, x_{j2}, \ldots, x_{jn}\right), \quad i = 1, \ldots, m,\; j = 1, \ldots, N' \qquad (4)

y = \mathrm{Voting}(y_k, y_g) \qquad (5)
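The sketch below illustrates Eq. (4) and the decision rule described above; it is not the authors' Algorithm 2. It uses a from-scratch Gaussian naive Bayes that evaluates Eq. (2) in log space. The evidence term is dropped because it is the same for every class. The `var_smoothing` constant, the rule that favors IKNN when both accuracies tie, and the names `GaussianNBSketch` and `vote` are all illustrative assumptions.

```python
import numpy as np


class GaussianNBSketch:
    """Minimal Gaussian naive Bayes for Eqs. (2) and (4); hypothetical, not sklearn's GaussianNB."""

    def fit(self, X, y, var_smoothing=1e-9):
        self.classes_ = np.unique(y)
        self.log_prior_ = np.log([np.mean(y == c) for c in self.classes_])
        self.mean_ = np.array([X[y == c].mean(axis=0) for c in self.classes_])
        eps = var_smoothing * X.var(axis=0).max()  # keeps per-class variances strictly positive
        self.var_ = np.array([X[y == c].var(axis=0) for c in self.classes_]) + eps
        return self

    def predict(self, X):
        # log p(y) + sum_i log N(x_i; mean, var); the evidence p(x) is identical for all classes
        diff2 = (X[:, None, :] - self.mean_[None, :, :]) ** 2
        log_lik = -0.5 * (np.log(2.0 * np.pi * self.var_)[None, :, :]
                          + diff2 / self.var_[None, :, :]).sum(axis=2)
        return self.classes_[np.argmax(self.log_prior_ + log_lik, axis=1)]


def vote(y_iknn, y_gnb, acc_iknn, acc_gnb):
    """Decision rule of Algorithm 2 as described in the text.

    Agreement -> keep the common label; disagreement -> take the label of the base
    classifier ranked higher by accuracy (ties favour IKNN: an assumption).
    """
    y_iknn, y_gnb = np.asarray(y_iknn), np.asarray(y_gnb)
    better = y_iknn if acc_iknn >= acc_gnb else y_gnb
    return np.where(y_iknn == y_gnb, y_iknn, better)
```

Note one consequence of the rule as described. Because only two base classifiers vote, the output equals the common label when they agree and the higher-ranked classifier's label when they disagree. In both cases, it coincides with the higher-ranked classifier's prediction.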

3 Experiment and Discussion

3.1 Benchmark Datasets

The NSL-KDD dataset used in this study is an improved version of the KDDCup99 dataset. KDDCup99 originates from an intrusion detection evaluation project at MIT Lincoln Laboratory in the United States [2], which collected network connection and system audit data over nine weeks. According to Tavallaee et al. [6], the KDDCup99 training and test sets contain 78% and 75% redundant records, respectively. To eliminate this redundancy, Tavallaee et al. derived the NSL-KDD dataset, which contains no redundant records, from KDDCup99.

The CICIDS2017 dataset was collected by the Canadian Institute for Cybersecurity in 2017. It contains benign traffic and up-to-date common attacks that resemble real-world data (PCAPs). It also includes the results of network traffic analysis with CICFlowMeter, with flows labeled by timestamp, source and destination IPs, source and destination ports, protocols, and attacks. CICIDS2017 covers five days of traffic: the first day contains only normal traffic, and the remaining four days contain both normal and abnormal traffic. Compared with NSL-KDD, CICIDS2017 contains many recent types of network attacks, which makes it well suited to testing model generalization. Because the original CICIDS2017 dataset is large, we randomly sampled normal and abnormal traffic from the five days of data at a ratio of 8:2. The processing of this dataset is described in [12]. Specifically, we extracted ten subsets of 100,000 samples each. One subset served as the training set (100 K samples), and the other nine served as test sets (900 K samples in total).

3.2 Dataset Preprocessing

Each NSL-KDD sample has 41 features, three of which (protocol_type, service, and flag) are symbolic. Because the model cannot handle symbolic data, these features are converted to numeric ones with one-hot encoding. For example, in the NSL-KDDTrain+ training set, protocol_type takes three values, TCP, UDP, and ICMP, which are one-hot encoded as 100, 010, and 001, respectively. The service feature takes 70 values and is encoded as a 70-dimensional binary vector. The flag feature takes 11 values and is encoded as an 11-dimensional binary vector. After one-hot encoding of both the training and test sets, the feature dimension was 122: 38 numeric features plus 3 + 70 + 11 one-hot columns. Further analysis of the training and test data showed that the feature num_outbound_cmds was always zero and thus contributed nothing to model performance. It was therefore deleted, leaving 121 dimensions. Because the features have different value ranges, model training converges slowly.
We therefore normalize the data to the range [0, 1] as follows:

\mathrm{MaxMinScaler}(x) = \frac{x - x_{\min}}{x_{\max} - x_{\min}} \qquad (6)

where x_max and x_min are the maximum and minimum values of the column containing feature x. In the CICIDS2017 dataset, all feature attributes are numeric, so every column except the label was scaled to [0, 1] with the same min-max normalization. Because the Flow_Bytes column contains null values, these were replaced with the mean of that column. The processed dataset contains 78 features, excluding the labels.
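The preprocessing steps of Sect. 3.2 can be sketched in NumPy as follows. The text does not say which set supplies x_min and x_max, so this sketch assumes they come from the training set and are reused for the test set. The category lists are passed in explicitly, and the helper names are hypothetical. The last lines reproduce the 41 → 122 → 121 feature count.

```python
import numpy as np


def one_hot(column, categories):
    """One-hot encode a symbolic column, e.g. protocol_type -> tcp=100, udp=010, icmp=001."""
    index = {c: i for i, c in enumerate(categories)}
    out = np.zeros((len(column), len(categories)))
    for row, value in enumerate(column):
        out[row, index[value]] = 1.0  # raises KeyError for an unseen category
    return out


def fill_nan_with_column_mean(X):
    """Replace missing values (as in CICIDS2017's Flow_Bytes column) with the column mean."""
    col_mean = np.nanmean(X, axis=0)
    return np.where(np.isnan(X), col_mean, X)


def min_max_scale(X_train, X_test):
    """Eq. (6); assumption: x_min and x_max are taken from the training set only."""
    x_min, x_max = X_train.min(axis=0), X_train.max(axis=0)
    span = np.where(x_max > x_min, x_max - x_min, 1.0)  # constant columns map to 0, no 0/0
    return (X_train - x_min) / span, (X_test - x_min) / span


# Feature count for NSL-KDD after preprocessing
NUMERIC = 41 - 3                      # 38 numeric features
ONE_HOT = 3 + 70 + 11                 # protocol_type, service, flag
NSL_KDD_DIM = NUMERIC + ONE_HOT - 1   # 122, minus the all-zero num_outbound_cmds -> 121
```

With training-set statistics, test values outside the training range map slightly outside [0, 1]. Clipping them is a possible design choice that the text does not discuss.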

3.3 Experimental Procedure

The experimental procedure consisted of two steps. In the first step, the training set was divided into training and validation parts at a ratio of 8:2, and the DSAE encoder was used to reduce the dimensionality of the dataset. The validation part was used to monitor the reconstruction error during dimensionality reduction. Dimensionality reduction reduced the feature dimension from 121 to 16 for NSL-KDD and from 78 to 10 for CICIDS2017. Because the model uses only the output of the bottleneck layer, the encoder is saved for the next step. Fig. 2 shows the loss during DSAE dimensionality reduction. The reconstruction error on both the training and validation sets fell below 0.002, indicating good dimensionality reduction.

In the second step, the ensemble model based on IKNN and GNB classified the dimensionality-reduced data. The GNB classifier uses the default parameters of the sklearn library. The IKNN classifier uses the default settings except for the n_neighbors and weights parameters. The weights parameter is set to the reciprocal of the Gaussian function, which gives the k neighbors different voting weights. To compare IKNN and KNN, we validated both on the NSL-KDD test set. Fig. 3 plots the accuracy of the traditional KNN and IKNN algorithms for different values of n_neighbors. sklearn defines a uniform KNN, in which all neighbors have equal weights, and a distance KNN, in which each weight is the reciprocal of the distance; both are compared with IKNN. The figure shows that IKNN achieves noticeably higher accuracy than either unimproved variant. According to the curve, IKNN's accuracy stabilizes as n_neighbors approaches 50. From n_neighbors = 10 to n_neighbors = 50, however, it rises only from 84% to 85%, an overall gain of about 1 percentage point. Because a larger n_neighbors incurs a higher computational cost, n_neighbors was set to 20. The same tuning was applied to the CICIDS2017 dataset.
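The paper does not specify the activation functions, the form of the sparsity penalty, or the optimizer. The NumPy sketch below of the DSAE training step in Sect. 3.3 therefore assumes sigmoid units, an L1 penalty on the 16-unit bottleneck, and plain full-batch gradient descent with hand-written backpropagation. Random data stands in for the preprocessed NSL-KDD features, and the class name `DeepSparseAutoencoder` is hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


class DeepSparseAutoencoder:
    """Hypothetical DSAE: 121-64-32-16 encoder, 16-32-64-121 decoder, sigmoid units.
    Loss = MSE reconstruction + l1 * mean L1 norm of the bottleneck code (assumed penalty)."""

    def __init__(self, sizes=(121, 64, 32, 16, 32, 64, 121), l1=1e-4, lr=0.5, seed=0):
        rng = np.random.default_rng(seed)
        self.code = len(sizes) // 2  # index of the bottleneck activation (16 units)
        self.W = [rng.normal(0.0, np.sqrt(1.0 / m), (m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
        self.b = [np.zeros(n) for n in sizes[1:]]
        self.l1, self.lr = l1, lr

    def _forward(self, X):
        acts = [X]
        for W, b in zip(self.W, self.b):
            acts.append(sigmoid(acts[-1] @ W + b))
        return acts

    def loss(self, X):
        acts = self._forward(X)
        recon = np.mean((acts[-1] - X) ** 2)
        return recon + self.l1 * np.abs(acts[self.code]).sum(axis=1).mean()

    def train_step(self, X):
        """One full-batch gradient-descent step."""
        acts = self._forward(X)
        grad = 2.0 * (acts[-1] - X) / X.size          # d(MSE)/d(output)
        for layer in reversed(range(len(self.W))):
            a = acts[layer + 1]
            if layer + 1 == self.code:                 # sparsity term acts on the code
                grad = grad + self.l1 * np.sign(a) / X.shape[0]
            delta = grad * a * (1.0 - a)               # back through the sigmoid
            grad = delta @ self.W[layer].T             # gradient for the layer below (old weights)
            self.W[layer] -= self.lr * (acts[layer].T @ delta)
            self.b[layer] -= self.lr * delta.sum(axis=0)

    def encode(self, X):
        """Keep only the encoder half: the 16-dimensional bottleneck features."""
        h = X
        for W, b in zip(self.W[: self.code], self.b[: self.code]):
            h = sigmoid(h @ W + b)
        return h


rng = np.random.default_rng(0)
X = rng.random((500, 121))                       # stand-in for preprocessed rows in [0, 1]
perm = rng.permutation(len(X))
X_fit, X_val = X[perm[:400]], X[perm[400:]]      # 8:2 training/validation split
model = DeepSparseAutoencoder()
fit_before, val_before = model.loss(X_fit), model.loss(X_val)
for _ in range(300):
    model.train_step(X_fit)
fit_after, val_after = model.loss(X_fit), model.loss(X_val)
codes = model.encode(X_val)                      # features that would be passed to IKNN/GNB
```

In practice, a deep-learning framework with automatic differentiation and a classic KL-divergence sparsity term would be a more common choice. The manual version above only keeps the example dependency-free.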

Fig. 2. Feature dimensionality reduction


Fig. 3. IKNN vs KNN
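sklearn's `KNeighborsClassifier` accepts a callable for `weights`. The callable receives the array of neighbor distances and returns weights of the same shape, which is one way to realize the IKNN setting of Sect. 3.3. The sketch below compares uniform, distance, and reciprocal-Gaussian weights on synthetic data, not NSL-KDD. It then combines IKNN with `GaussianNB` in a hard-voting `VotingClassifier`. Assumption: passing the two validation accuracies as voting `weights` is our way of reproducing the rule of deferring to the more accurate classifier on disagreement; the paper does not prescribe it.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import VotingClassifier
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier


def reciprocal_gaussian(distances):
    """Callable for weights=...: (n_queries, k) distances -> same-shape 1 / Gaussian pdf (Eq. (1))."""
    u = distances.mean(axis=1, keepdims=True)
    sigma = distances.std(axis=1, keepdims=True)
    sigma = np.where(sigma < 1e-12, 1.0, sigma)  # equidistant neighbours -> equal weights
    pdf = np.exp(-((distances - u) ** 2) / (2.0 * sigma ** 2)) / (np.sqrt(2.0 * np.pi) * sigma)
    return 1.0 / pdf


X, y = make_classification(n_samples=600, n_features=10, n_informative=5, random_state=0)
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.2, random_state=0)

# Compare the three weighting schemes at n_neighbors = 20
scores = {}
for name, w in [("uniform", "uniform"), ("distance", "distance"), ("iknn", reciprocal_gaussian)]:
    knn = KNeighborsClassifier(n_neighbors=20, weights=w).fit(X_tr, y_tr)
    scores[name] = knn.score(X_val, y_val)
acc_gnb = GaussianNB().fit(X_tr, y_tr).score(X_val, y_val)

# With two voters and unequal weights, a disagreement goes to the higher-weighted classifier
ensemble = VotingClassifier(
    estimators=[("iknn", KNeighborsClassifier(n_neighbors=20, weights=reciprocal_gaussian)),
                ("gnb", GaussianNB())],
    voting="hard",
    weights=[scores["iknn"], acc_gnb],
).fit(X_tr, y_tr)
```

If the two accuracies happen to be equal, sklearn's weighted hard vote breaks ties in favor of the smallest encoded class label rather than a particular classifier.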

3.4 Results and Discussion

Table 1 compares the proposed method with other methods. Some studies did not report all four metrics, because their authors had different priorities. Most research focuses on test accuracy, but this does not make the other metrics unimportant. Accuracy reflects the overall performance of a model, so it is the main metric of this comparison; for completeness, this study evaluated all four metrics. Table 1 shows that the proposed method reaches an accuracy of 92% on the NSL-KDDTest+ test set and 86% on the NSL-KDDTest-21 test set. On NSL-KDDTest-21, its accuracy is lower than that reported in [13] (86.00% vs. 86.51%). The difference is small, however, and still supports the effectiveness of the proposed method. In addition, the results on the CICIDS2017 test set indicate that the proposed method generalizes well.

Table 1. Performance of the models on NSL-KDDTest+, NSL-KDDTest-21, and CICIDS2017

Dataset | Method | Accuracy (%) | Precision (%) | Recall (%) | F1 score (%)
NSL-KDDTest+ | Proposed method | 92.00 | 89.55 | 96.17 | 92.74
 | DAE+GNB [5] | 83.34 | \ | \ | \
 | NBTree [6] | 82.02 | \ | \ | \
 | AE [7] | 88.28 | 91.23 | 87.86 | 89.51
 | DAE [7] | 88.65 | 96.48 | 83.08 | 89.28
 | C4.5+RF+Forest [9] | 87.37 | \ | \ | \
 | GRU+CNN+RF [10] | 87.28 | \ | \ | \
 | GBM [13] | 91.82 | \ | \ | \
NSL-KDDTest-21 | Proposed method | 86.00 | 88.59 | 94.71 | 91.55
 | NBTree [6] | 66.16 | \ | \ | \
 | C4.5+RF+Forest [9] | 73.57 | \ | \ | \
 | GRU+CNN+RF [10] | 76.61 | \ | \ | \
 | GBM [13] | 86.51 | \ | \ | \
CICIDS2017Test | Proposed method | 98.95 | 96.59 | 98.24 | 97.41
 | RGB [12] | \ | \ | \ | 89.00
 | MINDFUL [14] | 97.70 | \ | \ | 94.93
 | CLAIRE [15] | 98.01 | \ | \ | 95.20
 | Stacking [16] | 98.82 | \ | \ | \
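As a quick consistency check on Table 1: F1 is the harmonic mean of precision and recall, so each reported F1 value can be recomputed from the other two columns. The snippet below does this for every row that reports both, using values copied from the table. The dictionary keys are only labels.

```python
def f1(precision, recall):
    """Harmonic mean of precision and recall (percentages in, percentage out)."""
    return 2.0 * precision * recall / (precision + recall)


# (precision, recall, reported F1) copied from Table 1
table_rows = {
    "Proposed, NSL-KDDTest+": (89.55, 96.17, 92.74),
    "AE [7], NSL-KDDTest+": (91.23, 87.86, 89.51),
    "DAE [7], NSL-KDDTest+": (96.48, 83.08, 89.28),
    "Proposed, NSL-KDDTest-21": (88.59, 94.71, 91.55),
    "Proposed, CICIDS2017Test": (96.59, 98.24, 97.41),
}

# Round to two decimals, as in the table
recomputed = {name: round(f1(p, r), 2) for name, (p, r, _) in table_rows.items()}
```

All five recomputed values agree with the reported F1 scores to two decimal places.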

4 Conclusions

As an active defense method, an intrusion detection system plays an important role in preventing network intrusion. A DSAE can perform both linear and nonlinear data mapping and therefore achieves better results than traditional dimensionality reduction methods. Ensemble methods, a typical machine learning technique, are used extensively in intrusion detection systems because they can combine several weak classifiers into one strong classifier. The KNN algorithm measures the correlation between samples by distance; it is simple to implement, and its mathematical theory is mature. GNB is sensitive to the representation of data features, which makes it perform better on dimensionality-reduced data. To address the shortcomings of the traditional KNN algorithm, this study improved it and proposed the IKNN algorithm. Experiments show that IKNN achieves higher classification accuracy than KNN. IKNN and GNB are fundamentally different and learn from data from different perspectives, so combining them in a heterogeneous ensemble can improve overall performance. To address the shortcomings of the first research method mentioned previously, this study used a DSAE to reduce the data dimension, and the ensemble of IKNN and GNB improves classification accuracy on the two NSL-KDD test sets. Applying the proposed method to the more recent CICIDS2017 dataset further demonstrated its effectiveness. The proposed method compensates for the shortcomings of the first method and improves the model's performance on different test sets, which makes it a useful contribution to intrusion detection research. In future work, the model will be further optimized, and deep learning techniques will be introduced into the proposed ensemble method to further improve its performance.

References
1. Elike, H., et al.: Shallow and Deep Networks Intrusion Detection System: A Taxonomy and Survey (2017). https://arxiv.org/abs/1701.02145
2. Thakkar, A., Lohiya, R.: A review of the advancement in intrusion detection datasets. Procedia Comput. Sci. 167, 636–645 (2020)
3. Aytug, O.: Hybrid supervised clustering based ensemble scheme for text classification. Kybernetes (2017)
4. Amin, A.A., Reaz, M.B.I.: A survey of intrusion detection systems based on ensemble and hybrid classifiers. Comput. Secur. 65, 135–152 (2017)
5. Mahmood, Y.-A., et al.: Autoencoder-based feature learning for cyber security applications. In: 2017 International Joint Conference on Neural Networks (IJCNN), pp. 3854–3861. IEEE (2017)
6. Tavallaee, M., et al.: A detailed analysis of the KDD CUP 99 data set. In: 2009 IEEE Symposium on Computational Intelligence for Security and Defense Applications, pp. 1–6. IEEE (2009)
7. Can, A.R., Yavuz, A.G.: Network anomaly detection with stochastically improved autoencoder based models. In: 2017 IEEE 4th International Conference on Cyber Security and Cloud Computing (CSCloud), pp. 193–198. IEEE (2017)
8. Sumaiya, T.I., Kumar, C., Ahmad, A.: Integrated intrusion detection model using chi-square feature selection and ensemble of classifiers. Arabian J. Sci. Eng. 44(4), 3357–3368 (2019)
9. Zhou, Y., et al.: Building an efficient intrusion detection system based on feature selection and ensemble classifier. Comput. Netw. 174 (2020)
10. Andalib, A., Vakili, V.T.: An autonomous intrusion detection system using an ensemble of advanced learners. In: 2020 28th Iranian Conference on Electrical Engineering (ICEE), pp. 1–5. IEEE (2020)
11. Nerella, S., Shashi, M.: Encoding Approach for Intrusion Detection Using PCA and KNN Classifier. ICCII 2018, pp. 187–199 (2020)
12. Kim, T., Suh, S.C., et al.: An encoding technique for CNN based network anomaly detection. In: 2018 IEEE International Conference on Big Data (Big Data), pp. 2960–2965. IEEE (2018)
13. Tama, B.A., Rhee, K.-H.: An in-depth experimental study of anomaly detection using gradient boosted machine. Neural Comput. Appl. 31(4), 955–965 (2017). https://doi.org/10.1007/s00521-017-3128-z
14. Giuseppina, A., et al.: Multi-channel deep feature learning for intrusion detection. IEEE Access 8, 53346–53359 (2020)
15. Andresini, G., Appice, A., Malerba, D.: Nearest cluster-based intrusion detection through convolutional neural networks. Knowl.-Based Syst. 216, 106798 (2021)
16. Adhi, B., et al.: An enhanced anomaly detection in web traffic using a stack of classifier ensemble. IEEE Access 8, 24120–24134 (2020)

A Hybrid Queueing Search and Gradient-Based Algorithm for Optimal Experimental Design

Yue Zhang1, Yi Zhai1(B), Zhenyang Xia1, and Xinlong Wang2

1 School of Computer Science and Technology, Qilu University of Technology (Shandong Academy of Sciences), Jinan 250353, China
[email protected]
2 College of Information and Electrical Engineering, China Agricultural University, Beijing 100107, China

Abstract. Optimal experimental design is a crucial part of statistical planning in many scientific fields. Traditional gradient-based optimization algorithms struggle with complex experimental conditions and with large numbers of factors and parameters. Heuristic algorithms have been used as an alternative; however, they may converge prematurely and cannot guarantee optimal solutions. This study develops a hybrid algorithm that combines the strengths of gradient-based and heuristic optimization to improve solution quality and alleviate premature convergence in experimental design optimization. The proposed algorithm integrates the Multiplicative Algorithm (MA) with the Queueing Search Algorithm (QSA). Its performance is evaluated on numerical examples from generalized linear models (GLM), specifically logistic and Poisson models. It is compared with state-of-the-art algorithms such as DE, SaDE, GA, and PSO. The numerical results show that the proposed algorithm consistently outperforms the other algorithms, reaching higher objective function values in fewer iterations. Integrating MA with QSA therefore yields a more effective and robust optimizer for experimental design problems. It improves solution quality and keeps a competitive convergence speed while alleviating premature convergence, which makes it a promising approach for complex optimal design problems.
Keywords: optimal experimental design · design efficiency · queueing search algorithm

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 749–760, 2023.
https://doi.org/10.1007/978-981-99-4742-3_62

1 Introduction

Optimal Experimental Design (OED) is a statistical and mathematical methodology for planning experiments so that they are carried out efficiently and effectively. Its primary goal is to arrange an experiment with as few trials and as low a cost as possible while still obtaining results from which sound scientific conclusions can be drawn. Finding the optimal design is a challenging task in many areas of science and engineering, and Welch [1] showed that the problem is NP-hard. In this study, we tackle the OED problem with new optimization techniques, aiming to obtain the best possible design given the constraints and the nominal values of the parameters.

In the past, gradient-based algorithms such as the Fedorov-Wynn algorithm [2, 3] and the Multiplicative Algorithm (MA) [4] were commonly used to find optimal experimental designs. However, for complex models with many factors and parameters, these algorithms may fail to find the optimal design within a reasonable time. More recently, the Cocktail algorithm [5] and the Randomized Exchange algorithm [6] have been proposed, and they are significantly more computationally efficient than the older algorithms. Even so, gradient-based algorithms remain limited in finding optimal designs in complex settings.

In recent years, heuristic algorithms have been increasingly applied to a variety of complex optimization problems. Particle Swarm Optimization (PSO) [7] and Differential Evolution (DE) [8] have recently been used to solve various optimal design problems in the statistical literature [9, 10]. However, heuristic algorithms may also converge prematurely and yield suboptimal solutions. Optimal experimental design therefore remains a challenging problem that requires further research. Different algorithms have distinct strengths and weaknesses across model types and experimental situations, so the algorithm must be chosen flexibly and adaptively for the problem at hand. The Queueing Search Algorithm (QSA) [11] has been reported to suit optimal design problems because it avoids both premature convergence and excessive computational time [12].
To further improve the optimality of QSA solutions, we integrate a gradient-based algorithm with QSA: in each iteration, a multiplicative update is applied to part of the solution population. This combination produces higher-quality solutions within the population and hence a more efficient and effective search for the optimal experimental design. The paper is organized as follows. Section 2 introduces the background of optimal design. Section 3 introduces QSA and the multiplicative algorithm, presents the proposed algorithm, and explains how it is implemented and applied to optimal design problems. In Sect. 4, the proposed algorithm is used to generate optimal designs for the logistic and Poisson models and is compared with other algorithms, and the results are analyzed. Section 5 concludes the paper.

2 Background

In an experiment, the goal is to understand how various factors, or independent variables, influence a specific outcome, or response variable. This relationship can be represented by a mathematical model such as $y = \eta(X, \theta) + \varepsilon$. Here, $y$ denotes the outcome, $\varepsilon$ is a random error term, and $X = (x_1, x_2, \ldots, x_n)^T$ is an $n$-dimensional vector of independent variables. The $p$-dimensional vector $\theta = (\theta_1, \theta_2, \ldots, \theta_p)$ contains the unknown parameters to be estimated.


An experimental design with $t$ points consists of chosen values of the independent variables (the support points) together with their weights. It is written as

$$\xi = \begin{pmatrix} X_1 & X_2 & \cdots & X_t \\ w_1 & w_2 & \cdots & w_t \end{pmatrix},$$

where each support point $X_i$ is the vector of independent-variable values at the $i$-th experimental point, $i = 1, 2, \ldots, t$. The weight $w_i$ is the proportion of experimental runs allocated to the $i$-th design point. The weights must satisfy $w_i \in (0, 1)$ for $i = 1, 2, \ldots, t$ and $\sum_{i=1}^{t} w_i = 1$; that is, they sum to 1.

The effectiveness of a design $\xi$ is evaluated with the Fisher information matrix. For a design $\xi$ with $t$ support points, the information matrix is

$$M(\xi, \theta) = \sum_{i=1}^{t} w_i \, f(X_i, \theta) \, f^T(X_i, \theta), \tag{1}$$

where $f(X_i, \theta)$ is the $p$-dimensional vector of partial derivatives of the model with respect to each parameter in $\theta$, evaluated at $X_i$:

$$f(X, \theta) = \left( \frac{\partial \eta(X, \theta)}{\partial \theta_1}, \frac{\partial \eta(X, \theta)}{\partial \theta_2}, \ldots, \frac{\partial \eta(X, \theta)}{\partial \theta_p} \right)^T. \tag{2}$$

Optimality criteria are used to select the best experimental design. They identify the designs that provide the most informative data for parameter estimation, so that designs can be compared and the most efficient one chosen for a specific purpose. Many criteria exist, each with its own mathematical definition addressing a different aspect of the design. The choice of criterion depends on the research question and the goals of the experiment. In this paper, we focus on the D-optimal criterion. It maximizes the determinant of the Fisher information matrix, which is equivalent to maximizing its log-determinant because the logarithm is monotonic:

$$\xi_D = \arg\max_{\xi} |M(\xi, \theta)|. \tag{3}$$
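The short sketch below evaluates Eqs. (1) and (3) numerically for a given design. It assumes NumPy. The function names and the simple linear example model with $f(x) = (1, x)^T$ are hypothetical illustrations, not the authors' implementation. For a nonlinear model, `f` would be evaluated at the nominal parameter values.

```python
import numpy as np


def information_matrix(points, weights, f):
    """Eq. (1): M(xi, theta) = sum_i w_i f(X_i, theta) f(X_i, theta)^T."""
    rows = np.array([f(x) for x in points], dtype=float)  # t x p matrix, row i = f(X_i, theta)
    w = np.asarray(weights, dtype=float)
    return rows.T @ (w[:, None] * rows)


def log_det_criterion(points, weights, f):
    """log|M(xi, theta)|, maximized by a D-optimal design (Eq. (3)); -inf if M is singular."""
    sign, logdet = np.linalg.slogdet(information_matrix(points, weights, f))
    return logdet if sign > 0 else -np.inf


def f_linear(x):
    """Hypothetical example model eta = theta_1 + theta_2 * x, so f(x) = (1, x)^T."""
    return np.array([1.0, x])


print(log_det_criterion([-1.0, 1.0], [0.5, 0.5], f_linear))             # M is the identity here
print(log_det_criterion([-1.0, 0.0, 1.0], [1/3, 1/3, 1/3], f_linear))   # equals log(2/3)
```

Working with `slogdet` rather than `det` avoids overflow or underflow for larger information matrices. This matters because the paper optimizes the log-determinant directly.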

The D-optimal criterion offers several advantages in experimental design. It minimizes the determinant of the variance-covariance matrix of the parameter estimates, which is asymptotically proportional to $M(\xi, \theta)^{-1}$, and so improves the efficiency and precision of parameter estimation. In addition, the D-optimal criterion is known to have desirable properties such as invariance under re-parameterization and robustness to model misspecification [13]. For nonlinear models, the information matrix depends on the parameter values. Chernoff [14] therefore proposed substituting guessed (nominal) parameter values, which yields what is known as a locally optimal design. Designs obtained in this way perform well when the guessed parameters are close to the true ones [15].


The Equivalence Theorem, introduced by Kiefer and Wolfowitz [16], is an essential theoretical tool in optimal experimental design for verifying whether a design is optimal. For the D-optimal criterion, a design $\xi$ is optimal if and only if

$$d(X, \xi) - p \le 0 \quad \forall X \in \chi, \tag{4}$$

where the standardized variance $d(X, \xi)$ is

$$d(X, \xi) = f(X, \theta)^T M(\xi, \theta)^{-1} f(X, \theta). \tag{5}$$

Equality holds at the support points of $\xi$. The theorem therefore provides a way to verify whether a design is optimal: check the standardized variance for all $X$ in the design space $\chi$.
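In practice, Eq. (4) is usually checked on a discretized design space. The sketch below assumes NumPy, a hypothetical grid over $[-1, 1]$, and the same hypothetical linear model with $f(x) = (1, x)^T$. For that model the design with weight 1/2 at each endpoint satisfies the theorem, and a three-point uniform design does not.

```python
import numpy as np


def standardized_variance(x, points, weights, f):
    """Eq. (5): d(X, xi) = f(X)^T M(xi)^{-1} f(X)."""
    rows = np.array([f(xi) for xi in points], dtype=float)
    w = np.asarray(weights, dtype=float)
    M = rows.T @ (w[:, None] * rows)                 # Eq. (1)
    fx = f(x)
    return float(fx @ np.linalg.solve(M, fx))        # solve instead of forming M^{-1} explicitly


def check_d_optimality(points, weights, f, grid, tol=1e-8):
    """Numerical check of Eq. (4): max over the grid of d(X, xi) - p should be <= 0."""
    p = len(f(grid[0]))
    gap = max(standardized_variance(x, points, weights, f) for x in grid) - p
    return gap <= tol, gap


def f_linear(x):
    return np.array([1.0, x])                        # hypothetical model eta = theta_1 + theta_2 * x


grid = np.linspace(-1.0, 1.0, 201)                   # assumed discretization of chi = [-1, 1]
print(check_d_optimality([-1.0, 1.0], [0.5, 0.5], f_linear, grid))
print(check_d_optimality([-1.0, 0.0, 1.0], [1/3, 1/3, 1/3], f_linear, grid))
```

A grid check can only confirm optimality up to the grid resolution. The returned gap also serves as a convergence indicator for iterative algorithms.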

3 Queueing Search Algorithm and Its Improvement

3.1 Queueing Search Algorithm

The Queueing Search Algorithm (QSA) is an optimization technique inspired by the behavior of queues in real-world systems. QSA runs three consecutive phases, called businesses, that guide the search toward the global optimum. By simulating the interactions between staff and customers in a queueing system, the algorithm explores and exploits the search space. QSA has two types of solutions: staff members, which represent the best solutions in the population, and customers, which represent the remaining solutions. Its versatility and robustness make QSA well suited to complex optimization problems, including those in optimal experimental design. The rest of this subsection describes how QSA works.

In the first business of QSA, there are three queues, $queue_{11}$, $queue_{12}$, and $queue_{13}$, each served by one staff member. The remaining population is divided among these three queues according to fitness value. Each customer's state is updated in one of two ways:

$$C_i^{new} = A + \beta \alpha \left( E \cdot |A - C_i^{old}| \right) + e \left( A - C_i^{old} \right), \quad i = 1, 2, \ldots, N, \tag{6}$$

$$C_i^{new} = C_i^{old} + \beta \alpha \left( E \cdot |A - C_i^{old}| \right), \quad i = 1, 2, \ldots, N, \tag{7}$$

where $A$ is the staff member serving the queue that customer $C_i$ belongs to. The remaining symbols are defined as follows:
- $\alpha$ is a uniform random number in $[-1, 1]$;
- $E$ is an exponentially distributed random vector with mean 0.5;
- $e$ is an exponentially distributed random number;
- $N$ is the total number of customers;
- $\beta = g^{(g/iter_{max})^{0.5}}$, where $g$ is the current iteration and $iter_{max}$ is the maximum number of iterations.
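The two first-business update rules, Eqs. (6) and (7), can be sketched as follows. This is a minimal illustration assuming NumPy. Several details are assumptions not fixed by the text:
- $E \cdot |A - C_i^{old}|$ is taken as an elementwise product.
- $e$ is given the same mean (0.5) as $E$.
- The helper names are hypothetical.

```python
import numpy as np


def beta(g, iter_max):
    """Step factor beta = g ** ((g / iter_max) ** 0.5)."""
    return g ** ((g / iter_max) ** 0.5)


def first_business_update(c_old, staff, g, iter_max, rng, use_eq6=True):
    """Update one customer with Eq. (6) or Eq. (7).

    Assumptions: alpha ~ U[-1, 1]; E (vector) and e (scalar) are exponential
    with mean 0.5; "E . |A - C|" is an elementwise product.
    """
    c_old = np.asarray(c_old, dtype=float)
    staff = np.asarray(staff, dtype=float)
    alpha = rng.uniform(-1.0, 1.0)
    E = rng.exponential(scale=0.5, size=c_old.shape)
    e = rng.exponential(scale=0.5)
    step = beta(g, iter_max) * alpha * (E * np.abs(staff - c_old))
    if use_eq6:
        return staff + step + e * (staff - c_old)   # Eq. (6): search around the staff member A
    return c_old + step                             # Eq. (7): search around the customer itself


rng = np.random.default_rng(0)
A = np.array([0.2, -0.4, 0.9])     # hypothetical staff member of the customer's queue
C = np.array([0.5, 0.1, -0.3])     # hypothetical customer state
print(first_business_update(C, A, g=10, iter_max=2000, rng=rng, use_eq6=True))
print(first_business_update(C, A, g=10, iter_max=2000, rng=rng, use_eq6=False))
```

Both rules scale the random perturbation by the distance $|A - C_i^{old}|$. As a result, a customer that already coincides with its staff member is left unchanged.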


During the first business, each customer's state $C^{old}$ is updated with one of the two equations above, and the fitness of the resulting state $C^{new}$ is evaluated. The first customer uses Eq. (6). If $C^{new}$ has better fitness than $C^{old}$, the next customer uses the same equation; otherwise, the next customer switches to the other equation. This continues until all customers have been updated. After the first business, each customer is assigned a probability $Pr_{i,1}$ given by

$$Pr_i = \frac{rank(f_i)}{N}, \quad i = 1, 2, \ldots, N, \tag{8}$$

where $f_i$ is the fitness value of customer $i$, and $rank(f_i)$ is its rank when the fitness values of all customers are sorted in descending order. In addition, a random number $rand_{i,1}$ is drawn from $[0, 1]$ for each customer. If $rand_{i,1} < Pr_{i,1}$, customer $i$ is updated in the second business. The number of customers participating in the second business is denoted $N_2$.

In the second business of QSA, there are three queues, $queue_{21}$, $queue_{22}$, and $queue_{23}$. Each customer's state is again updated in one of two ways:

$$C_i^{new} = C_i^{old} + e \left( C_{r1} - C_{r2} \right), \quad i = 1, 2, \ldots, N_2, \tag{9}$$

$$C_i^{new} = C_i^{old} + e \left( A - C_{r1} \right), \quad i = 1, 2, \ldots, N_2. \tag{10}$$

Here $C_{r1}$ and $C_{r2}$ are two randomly selected customers. The equation a customer uses is determined by the confusion degree $cv = \frac{T_{21}}{T_{22} + T_{23}}$, where $T_{2i}$, $i = 1, 2, 3$, are the fitness values of the staff members $A_{2i}$ in $queue_{2i}$. Each customer draws a uniform random number in $[0, 1]$. If this number is less than $cv$, the customer is updated with Eq. (9); otherwise, Eq. (10) is used.

In the third business, there is only one queue, and only some of the customers processed in the second business take part. For each dimension $d$ of customer $i$'s state $C_i$, a random number $rand_{i,d}$ is drawn uniformly from $[0, 1]$. If $rand_{i,d}$ is larger than $Pr_{i,2}$, computed with Eq. (8), the $d$-th dimension of $C_i$ is updated in the third business by

$$C_{i,d}^{new} = C_{r1,d} + e \left( C_{r2,d} - C_{i,d}^{old} \right), \quad i = 1, 2, \ldots, N_3, \tag{11}$$

where $N_3$ is the number of customers updated in the third business, and $C_{r1,d}$ and $C_{r2,d}$ are the $d$-th components of two randomly selected customers. Further details of QSA, including an experimental evaluation of its performance, are given in [11].
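The selection and update rules of the second and third businesses are sketched below. The sketch covers the rank probability of Eq. (8), the confusion degree $cv$, and Eqs. (9) to (11). It is a minimal illustration assuming NumPy and a maximized fitness such as the log-determinant. The exponential random number $e$ is assumed to have mean 0.5, and all helper names are hypothetical.

```python
import numpy as np


def rank_probabilities(fitness):
    """Eq. (8): Pr_i = rank(f_i) / N, with fitness values ranked in descending order.

    For a maximized objective, the best customer gets rank 1 and hence the smallest probability.
    """
    fitness = np.asarray(fitness, dtype=float)
    n = fitness.size
    order = np.argsort(-fitness, kind="stable")   # indices from best to worst
    ranks = np.empty(n, dtype=int)
    ranks[order] = np.arange(1, n + 1)
    return ranks / n


def select_for_second_business(fitness, rng):
    """Customer i enters the second business when rand_{i,1} < Pr_{i,1}."""
    pr = rank_probabilities(fitness)
    return np.flatnonzero(rng.random(pr.size) < pr)


def confusion_degree(t21, t22, t23):
    """cv = T21 / (T22 + T23), from the fitness values of staff members A21, A22, A23."""
    return t21 / (t22 + t23)


def second_business_update(c_old, c_r1, c_r2, staff, cv, rng):
    """Eq. (9) when a uniform draw is below cv, otherwise Eq. (10)."""
    e = rng.exponential(scale=0.5)          # assumption: e has mean 0.5
    if rng.random() < cv:
        return c_old + e * (c_r1 - c_r2)    # Eq. (9)
    return c_old + e * (staff - c_r1)       # Eq. (10)


def third_business_update(c_old, pr_i2, c_r1, c_r2, rng):
    """Eq. (11): dimension d is updated only when rand_{i,d} > Pr_{i,2}."""
    c_new = np.array(c_old, dtype=float)
    for d in range(c_new.size):
        if rng.random() > pr_i2:
            e = rng.exponential(scale=0.5)  # assumption: e has mean 0.5
            c_new[d] = c_r1[d] + e * (c_r2[d] - c_old[d])
    return c_new


rng = np.random.default_rng(0)
fitness = [3.0, 1.0, 2.0]                   # hypothetical fitness values of three customers
print(rank_probabilities(fitness))
print(select_for_second_business(fitness, rng))
C, R1, R2, A = (np.array(v) for v in ([0.1, 0.2], [0.4, -0.3], [-0.2, 0.5], [0.3, 0.3]))
cv = confusion_degree(2.0, 1.0, 3.0)
print(second_business_update(C, R1, R2, A, cv, rng))
print(third_business_update(C, 0.4, R1, R2, rng))
```

Because $Pr_i$ grows with worse rank, poorly ranked customers are more likely to enter the second business. The worst customer, with $Pr = 1$, always does.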


3.2 Multiplicative Algorithm

The Multiplicative Algorithm is a widely used gradient-based algorithm for constructing optimal designs. It was first proposed by Silvey, Titterington and Torsney [4] and has been studied extensively in the literature. It is a weight-adjustment algorithm: it iteratively modifies the weights of the design points while keeping the points themselves fixed. At each iteration, the weights are adjusted to improve efficiency until a satisfactory design is reached. Formally, let $\xi^{old}$ denote the design at the beginning of an iteration. The MA then updates the weights to obtain the new design $\xi^{new}$:

$$\xi^{new} = MA(\xi^{old}). \tag{12}$$

The expanded form of the iteration step $\xi^{new} = MA(\xi^{old})$ can be written as

$$w_i^{new} = w_i^{old} \, \frac{d(X_i, \xi^{old})}{\sum_{j=1}^{t} w_j^{old} \, d(X_j, \xi^{old})}, \quad i = 1, 2, \ldots, t,$$

where $w_i^{new}$ and $w_i^{old}$ are the new and old weights, respectively, of the $i$-th design point. The convergence of MA has been studied by many researchers, and monotonic convergence has been proved for a class of optimality criteria [17].

3.3 Proposed Algorithm

Gradient-based algorithms are computationally efficient and theoretically well founded when the objective function is smooth and well behaved, which makes convergence easy. This advantage disappears, however, when the experimental conditions are complex or involve many factors and parameters. In such cases, other optimization methods such as heuristic algorithms may be more suitable. Heuristic algorithms, in turn, may converge prematurely and cannot guarantee convergence to the optimal solution, so their solutions may be less accurate than those of gradient-based algorithms. This article proposes an algorithm that integrates MA with QSA to improve the quality of the solutions obtained by QSA and to alleviate premature convergence.

The MA updates are inserted into QSA at two points:
- Second business: after it begins, each staff member $A_{2i}$, $i = 1, 2, 3$, is assigned a uniform random number $rand_{i,1}^{MA}$ in $[0, 1]$ and a probability $Pr_{i,1}^{MA}$ computed with Eq. (8). If $Pr_{i,1}^{MA} > rand_{i,1}^{MA}$, the weights of $A_{2i}$ are updated by the multiplicative process of Eq. (12). This improves the solutions of the best-performing members of the population.
- Third business: after it begins, each customer $i$, $i = 1, 2, \ldots, N_3$, taking part in it is assigned a new random number $rand_{i,2}^{MA}$ and probability $Pr_{i,2}^{MA}$. If $Pr_{i,2}^{MA} > rand_{i,2}^{MA}$, the customer's weights are updated by the multiplicative process. This improves the solutions of customers that may be performing poorly in the population.

In the optimal design problem, each customer state represents an experimental design composed of $t$ support points and $t$ weights. Each support point has $m$ factors, so a customer state has dimension $t(m + 1)$.
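The following minimal sketch illustrates three ingredients of the hybrid: the multiplicative weight update of Sect. 3.2, the probabilistic rule for applying it to a solution, and the $t(m + 1)$ encoding of a customer state. It assumes NumPy. The state layout (all support points first, then the weights) and the helper names are assumptions, not the authors' code. The usage example applies repeated MA steps to the hypothetical linear model with $f(x) = (1, x)^T$.

```python
import numpy as np


def multiplicative_step(points, weights, f):
    """One MA iteration: w_i <- w_i d(X_i, xi) / sum_j w_j d(X_j, xi)."""
    rows = np.array([f(x) for x in points], dtype=float)            # row i = f(X_i, theta)
    w = np.asarray(weights, dtype=float)
    M = rows.T @ (w[:, None] * rows)                                 # Eq. (1)
    d = np.einsum("ij,ij->i", rows, np.linalg.solve(M, rows.T).T)   # d(X_i, xi), Eq. (5)
    # sum_j w_j d(X_j, xi) = trace(M^{-1} M) = p, so this is w_i d_i / p in exact arithmetic
    return w * d / np.sum(w * d)


def maybe_apply_ma(points, weights, f, pr, rng):
    """Hybrid rule (hypothetical helper): apply one MA step only when Pr^MA > rand^MA."""
    if pr > rng.random():
        return multiplicative_step(points, weights, f)
    return np.asarray(weights, dtype=float)


def decode_customer(state, t, m):
    """Assumed layout: t support points with m factors each, followed by the t weights."""
    state = np.asarray(state, dtype=float)
    if state.size != t * (m + 1):
        raise ValueError("customer state must have length t * (m + 1)")
    return state[: t * m].reshape(t, m), state[t * m:]


def f_linear(x):
    return np.array([1.0, x])      # hypothetical model eta = theta_1 + theta_2 * x


points = [-1.0, -0.5, 0.0, 0.5, 1.0]
w = np.full(5, 0.2)
for _ in range(200):
    w = multiplicative_step(points, w, f_linear)
print(np.round(w, 4))   # the weight concentrates on the endpoints -1 and 1
```

MA only moves weights, never support points. This is why, in the hybrid, QSA remains responsible for relocating the points while MA refines the weights.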

4 Numerical Examples Settings and Analyses

In this section, we evaluate the proposed algorithm in finding locally D-optimal designs for two generalized linear models (GLM): the logistic model and the Poisson model. We compare it against several state-of-the-art algorithms, including DE [8], SaDE [18], GA [19], and PSO [7], as well as against the original QSA. All algorithms are implemented in Python 3.8. For each model, we generated two sets of random parameter values and conducted 100 runs for each parameter setting. The log-determinant of the Fisher information matrix was used as the objective function. For all algorithms, the population size was set to 500 and the number of iterations to 2000. The tuning parameters were set as follows:
- DE: F = 0.5 and CR = 0.9, as suggested by [8].
- SaDE: CRm = 0.5, with F drawn from the normal distribution N(0.5, 0.3), following [9].
- PSO: c1 = c2 = 2 and w = 0.9, as recommended by [7].
- GA: crossover probability 0.95 and mutation probability 0.025.

After careful analysis, the number of support points was set to 11 for the logistic model and 9 for the Poisson model.

4.1 Numerical Examples Settings

Logistic Model

The logistic model is used for binary (dichotomous) response variables, which can take only two values, usually 0 and 1. It assumes that the logit of the probability that the response equals 1 is a linear function of the predictor variables. The linear predictor is a linear combination of the predictor variables, each multiplied by its regression coefficient. The logistic function is then applied to the linear predictor to obtain the predicted probability that the response equals 1. The logistic model is widely used in fields such as medicine, epidemiology, finance, and the social sciences. It is expressed as

$$P(Y = 1 \mid X) = \frac{1}{1 + e^{-\mu}}, \tag{13}$$

where $Y$ is the binary response variable and $\mu$ is the linear predictor, given by

$$\mu = X^T \theta = \theta_0 + \theta_1 x_1 + \theta_2 x_2 + \cdots + \theta_p x_p, \tag{14}$$


$X = (1, x_1, x_2, \ldots, x_p)^T$ is the vector of predictor variables, and $\theta_0, \theta_1, \theta_2, \ldots, \theta_p$ are the parameters. For the logistic model, the Fisher information matrix takes a different form from that of the linear model:

$$M(\xi, \theta) = \sum_{i=1}^{t} w_i \, \frac{e^{\mu_i}}{(1 + e^{\mu_i})^2} \, X_i X_i^T, \tag{15}$$

where $\mu_i = X_i^T \theta$ is the linear predictor at the $i$-th support point. For the logistic model, we specify the linear predictor as

$$\mu = \theta_0 + \theta_1 x_1 + \theta_2 x_2 + \theta_3 x_3 + \theta_4 x_4 + \theta_5 x_5. \tag{16}$$
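Eq. (15) with the five-factor predictor of Eq. (16) can be evaluated as follows. This minimal sketch assumes NumPy and uses the nominal values $\theta^{(1)}$ given in this section. The $2^5$ factorial starting design is a hypothetical choice, not a design from the paper.

```python
import itertools

import numpy as np


def logistic_information_matrix(points, weights, theta):
    """Eq. (15) with X_i = (1, x_i1, ..., x_i5)^T and mu_i = X_i^T theta."""
    pts = np.atleast_2d(np.asarray(points, dtype=float))
    X = np.column_stack([np.ones(len(pts)), pts])
    mu = X @ np.asarray(theta, dtype=float)
    v = np.exp(mu) / (1.0 + np.exp(mu)) ** 2     # equals pi_i (1 - pi_i), with pi_i from Eq. (13)
    w = np.asarray(weights, dtype=float)
    return X.T @ ((w * v)[:, None] * X)


def logistic_log_det(points, weights, theta):
    sign, logdet = np.linalg.slogdet(logistic_information_matrix(points, weights, theta))
    return logdet if sign > 0 else -np.inf


theta1 = [-0.54, -0.45, 0.97, -0.53, 0.39, 0.28]                    # theta^(1) from the text
corners = np.array(list(itertools.product([-1.0, 1.0], repeat=5)))  # hypothetical 2^5 design
print(logistic_log_det(corners, np.full(len(corners), 1 / 32), theta1))
```

Unlike the linear case, the weights $e^{\mu_i}/(1 + e^{\mu_i})^2$ depend on $\theta$. This is why nominal parameter values are needed to obtain a locally D-optimal design.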

The design space $\chi$ was set to $[-1, 1]^5$. Two sets of nominal parameter values were randomly generated within $[-1, 1]^6$ to construct locally D-optimal designs: $\theta^{(1)} = (-0.54, -0.45, 0.97, -0.53, 0.39, 0.28)$ and $\theta^{(2)} = (-0.34, -0.56, -0.55, -0.70, 0.77, -0.51)$.

Poisson Model

The Poisson model describes the relationship between a count variable and predictor variables, assuming that the count variable follows a Poisson distribution. It is often used in fields such as epidemiology, biology, and environmental sciences to model the occurrence of events such as disease outbreaks, species populations, or environmental pollution levels. The Poisson model can be written as

$$y \sim \text{Poisson}(\lambda), \tag{17}$$

where $y$ is the count variable and $\lambda$ is the parameter of the Poisson distribution. The model assumes that the mean and the variance of $y$ both equal $\lambda$. Without interactions, the Poisson model can be extended to include one or more predictor variables within the generalized linear model framework:

$$\log(\lambda) = \theta_0 + \theta_1 x_1 + \theta_2 x_2 + \cdots + \theta_p x_p, \tag{18}$$

where $\theta_0$ is the intercept and $\theta_1, \theta_2, \ldots, \theta_p$ are the coefficients of the predictor variables $x_1, x_2, \ldots, x_p$. This specification is known as the Poisson regression model, and it can be used to estimate the effect of each predictor variable on the count variable $y$. The Fisher information matrix of the Poisson model is

$$M(\xi, \theta) = \sum_{i=1}^{t} w_i \, e^{\log(\lambda_i)} \, X_i X_i^T = \sum_{i=1}^{t} w_i \, \lambda_i \, X_i X_i^T, \tag{19}$$

where $\lambda_i$ is the mean response at the $i$-th support point.
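Eq. (19) can be sketched in the same way. Here the regression vector is either the first-order form of Eq. (18) with three factors or the two-factor-interaction form used for the Poisson examples in this section. The sketch assumes NumPy and uses the nominal values $\theta^{(3)}$ from the text. The $2^3$ factorial design and the helper names are hypothetical.

```python
import itertools

import numpy as np


def regressors(x, interactions=False):
    """Regression vector X: (1, x1, x2, x3), optionally followed by x1x2, x1x3, x2x3."""
    x1, x2, x3 = x
    base = [1.0, x1, x2, x3]
    return np.array(base + [x1 * x2, x1 * x3, x2 * x3] if interactions else base)


def poisson_information_matrix(points, weights, theta, interactions=False):
    """Eq. (19): M = sum_i w_i lambda_i X_i X_i^T with lambda_i = exp(X_i^T theta)."""
    X = np.array([regressors(x, interactions) for x in points])
    lam = np.exp(X @ np.asarray(theta, dtype=float))
    w = np.asarray(weights, dtype=float)
    return X.T @ ((w * lam)[:, None] * X)


theta3 = [0.52, 0.23, 0.91, -0.14, -0.38, 0.80, -0.71]    # theta^(3) from the text
corners = list(itertools.product([-1.0, 1.0], repeat=3))  # hypothetical 2^3 design, equal weights
M = poisson_information_matrix(corners, np.full(8, 1 / 8), theta3, interactions=True)
print(np.linalg.slogdet(M))
```

When $\theta = 0$, every $\lambda_i = 1$ and the matrix reduces to the linear-model information matrix $\sum_i w_i X_i X_i^T$. This gives a simple sanity check.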


For the Poisson model, we specify the linear predictor as $\log(\lambda) = \theta_0 + \theta_1 x_1 + \theta_2 x_2 + \theta_3 x_3 + \theta_4 x_1 x_2 + \theta_5 x_1 x_3 + \theta_6 x_2 x_3$. To find locally D-optimal designs, two sets of randomly generated parameters are used: $\theta^{(3)} = (0.52, 0.23, 0.91, -0.14, -0.38, 0.80, -0.71)$ and $\theta^{(4)} = (0.80, -0.21, -0.13, -0.78, 0.16, 0.62, -0.05)$. The design space was set to $[-1, 1]^3$.

4.2 Results and Analyses

This subsection presents and discusses the results of the numerical examples. Because heuristic algorithms are inherently random, we use the mean optimization history as a performance indicator. The objective function is the log-determinant of the information matrix, so a larger objective value indicates better performance. Figure 1 shows the average optimization history over 100 runs of the proposed algorithm and its competitors on the logistic model with parameters $\theta^{(1)}$ and $\theta^{(2)}$. Similarly, Fig. 2 shows the average optimization history over 100 runs on the Poisson model with parameters $\theta^{(3)}$ and $\theta^{(4)}$.

These four panels show that the proposed algorithm reaches better fitness values in fewer iterations than the other algorithms. A closer look reveals that the original QSA converges quickly but to an unsatisfactory final solution, a problem that PSO shares. DE, SaDE, and GA have smooth optimization history curves, but they do not reach satisfactory results and require long runtimes. In contrast, the proposed algorithm trades some of QSA's convergence speed for higher solution quality. Notably, despite this slower convergence, its speed remains competitive with that of the other algorithms.
Table 1 presents the objective function values obtained from 100 independent runs on the logistic and Poisson models. The worst fitness value (Worst) and the best fitness value (Best) are used as comparison metrics, and the best result in each column is highlighted in bold. The table shows that the proposed algorithm finds better best-case designs than all competing algorithms, including designs that the other algorithms do not reach. Furthermore, compared with the original QSA, the proposed algorithm considerably raises the worst-case (lower-bound) solution quality.


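Fig. 1. Average optimization history of each algorithm over 100 runs in finding locally D-optimal designs for the logistic model with parameters $\theta^{(1)}$ (a) and $\theta^{(2)}$ (b)

Table 1. Comparison results of the algorithms over 100 runs on the logistic and Poisson models

| Algorithm | Logistic Model: Parameters | Logistic Model: Worst | Logistic Model: Best | Poisson Model: Parameters | Poisson Model: Worst | Poisson Model: Best |
|---|---|---|---|---|---|---|
| Proposed Algorithm | $\theta^{(1)}$ | **3.87e-5** | **5.21e-5** | $\theta^{(3)}$ | 220.46 | **255.72** |
| Proposed Algorithm | $\theta^{(2)}$ | **3.85e-5** | **5.24e-5** | $\theta^{(4)}$ | **495.45** | **550.34** |
| QSA | $\theta^{(1)}$ | 2.47e-5 | 4.50e-5 | $\theta^{(3)}$ | 119.43 | 231.38 |
| QSA | $\theta^{(2)}$ | 2.93e-5 | 4.56e-5 | $\theta^{(4)}$ | 342.98 | 496.14 |
| DE | $\theta^{(1)}$ | 3.37e-5 | 4.53e-5 | $\theta^{(3)}$ | 219.09 | 231.33 |
| DE | $\theta^{(2)}$ | 3.42e-5 | 4.51e-5 | $\theta^{(4)}$ | 478.80 | 496.83 |
| SaDE | $\theta^{(1)}$ | 3.59e-5 | 4.54e-5 | $\theta^{(3)}$ | **226.76** | 231.38 |
| SaDE | $\theta^{(2)}$ | 3.51e-5 | 4.56e-5 | $\theta^{(4)}$ | 479.10 | 496.84 |
| PSO | $\theta^{(1)}$ | 8.16e-6 | 3.92e-5 | $\theta^{(3)}$ | 29.46 | 193.69 |
| PSO | $\theta^{(2)}$ | 1.08e-5 | 4.00e-5 | $\theta^{(4)}$ | 72.66 | 477.42 |
| GA | $\theta^{(1)}$ | 3.76e-5 | 4.53e-5 | $\theta^{(3)}$ | 221.29 | 230.34 |
| GA | $\theta^{(2)}$ | 3.62e-5 | 4.57e-5 | $\theta^{(4)}$ | 462.75 | 494.38 |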


Fig. 2. Average optimization history of each algorithm over 100 runs in finding locally D-optimal designs for the Poisson model with parameters $\theta^{(3)}$ (a) and $\theta^{(4)}$ (b)

5 Conclusion

This article presented a hybrid algorithm that integrates the Queueing Search Algorithm with the Multiplicative Algorithm for finding optimal experimental designs. The proposed algorithm combines the strengths of QSA and MA, improving the quality of the solutions obtained by QSA and alleviating premature convergence. Its performance was demonstrated on two generalized linear models, the logistic model and the Poisson model. The numerical examples showed that it reaches better objective values than the competing algorithms while keeping a competitive convergence speed, which indicates its effectiveness for OED problems. Future research could extend the proposed algorithm to other optimization problems and domains. In addition, hybridizing QSA with other algorithms and developing parallel and distributed implementations could further improve its performance. We expect the proposed algorithm to contribute to the field of optimal experimental design and to open new possibilities for optimization methods in diverse applications.

Acknowledgment. The research work is supported by National Natural Science Foundation of China under Grant No. 11901325.

References

1. Welch, W.J.: Algorithmic complexity: three NP-hard problems in computational statistics. J. Stat. Comput. Simul. 15(1), 17–25 (1982)
2. Fedorov, V.: Theory of optimal experiments. Translated from the Russian and edited by W.J. Studden and E.M. Klimko. Probabil. Math. Stat. 12 (1972)
3. Wynn, H.P.: The sequential generation of D-optimum experimental designs. Ann. Math. Stat. 41(5), 1655–1664 (1970)
4. Silvey, S., Titterington, D., Torsney, B.: An algorithm for optimal designs on a design space. Commun. Statist.-Theory Methods 7(14), 1379–1389 (1978)
5. Yu, Y.: D-optimal designs via a cocktail algorithm. Stat. Comput. 21, 475–481 (2011)
6. Harman, R., Filová, L., Richtárik, P.: A randomized exchange algorithm for computing optimal approximate designs of experiments. J. Am. Stat. Assoc. 115(529), 348–361 (2020)
7. Kennedy, J., Eberhart, R.: Particle swarm optimization. In: Proceedings of ICNN'95 International Conference on Neural Networks, pp. 1942–1948. IEEE (1995)
8. Storn, R., Price, K.: Differential evolution - a simple and efficient heuristic for global optimization over continuous spaces. J. Global Optim. 11(4), 341 (1997)
9. Qiu, X., Xu, J., Tan, K.C.: A novel differential evolution (DE) algorithm for multi-objective optimization. In: 2014 IEEE Congress on Evolutionary Computation (CEC), pp. 2391–2396. IEEE (2014)
10. Tong, L., Wong, W.K., Kwong, C.K.: Differential evolution-based optimal Gabor filter model for fabric inspection. Neurocomputing 173, 1386–1401 (2016)
11. Zhang, J., Xiao, M., Gao, L., Pan, Q.: Queuing search algorithm: a novel metaheuristic algorithm for solving engineering optimization problems. Appl. Math. Model. 63, 464–490 (2018)
12. Nguyen, B.M., Hoang, B., Nguyen, T., Nguyen, G.: nQSV-Net: a novel queuing search variant for global space search and workload modeling. J. Ambient. Intell. Humaniz. Comput. 12(1), 27–46 (2021). https://doi.org/10.1007/s12652-020-02849-4
13. Abdelbasit, K.M., Plackett, R.: Experimental design for binary data. J. Am. Stat. Assoc. 78(381), 90–98 (1983)
14. Chernoff, H.: Locally optimal designs for estimating parameters. Ann. Math. Statist. 586–602 (1953)
15. Li, G., Majumdar, D.: D-optimal designs for logistic models with three and four parameters. J. Statist. Plan. Inferen. 138(7), 1950–1959 (2008)
16. Kiefer, J., Wolfowitz, J.: The equivalence of two extremum problems. Can. J. Math. 12, 363–366 (1960)
17. Yu, Y.: Monotonic convergence of a general algorithm for computing optimal designs (2010)
18. Qin, A.K., Suganthan, P.N.: Self-adaptive differential evolution algorithm for numerical optimization. In: 2005 IEEE Congress on Evolutionary Computation, pp. 1785–1791. IEEE (2005)
19. Holland, J.H.: Genetic algorithms. Sci. Am. 267(1), 66–73 (1992)

A Review of Client Selection Mechanisms in Heterogeneous Federated Learning

Xiao Wang1,2, Lina Ge1,2,3(B), and Guifeng Zhang1,2,3

1 School of Artificial Intelligence, Guangxi Minzu University, Nanning 530006, China
[email protected]
2 Key Laboratory of Network Communication Engineering, Guangxi Minzu University, Nanning 530006, China
3 Guangxi Key Laboratory of Hybrid Computation and IC Design Analysis, Nanning 530006, China

Abstract. Federated learning is a distributed machine learning approach that keeps data on local devices. It makes use of fragmented data while protecting client privacy to a certain extent. However, data heterogeneity can make federated training unstable and inefficient, and system heterogeneity can waste resources and further reduce efficiency. Moreover, because communication bandwidth and other resources are limited, selecting which clients participate in training becomes important. Research has shown that a well-designed client selection mechanism can improve training efficiency and model accuracy. This article discusses the heterogeneity that federated learning faces in heterogeneous environments and the highly dynamic challenges it is likely to face in the future. It reviews recent client selection mechanisms from the perspectives of client reputation, time thresholds, and other factors. Finally, it outlines research directions for client selection in federated learning.
Keywords: Federated Learning · Client Selection · Node Selection · Heterogeneous Federated Learning · Machine Learning

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 761–772, 2023.
https://doi.org/10.1007/978-981-99-4742-3_63

1 Introduction

The development of machine learning inevitably requires large amounts of data. Several trends now make traditional centralized machine learning struggle to cope with massive, scattered data: the growth of IoT technology and IoT devices, rising privacy awareness, and legal and policy constraints. Federated learning, proposed by Google, makes it possible to train machine learning models collaboratively while private data remain distributed across their owners' devices. It is more dynamic and flexible than traditional centralized machine learning and provides a distributed machine learning solution for edge computing environments. The original data stay on the local devices and are not shared among participants, so federated learning protects local private data to some extent. In addition, as a distributed machine learning paradigm, federated learning adapts well to large-scale model training scenarios. It maintains reasonable training efficiency while still delivering reliable prediction accuracy.

To obtain an ideal model, federated learning generally requires three conditions:
- each participant's data is independent and identically distributed (IID);
- the data is thoroughly labeled;
- every participant takes part in the entire training process.

In practice, these conditions are rarely met. On the one hand, data distributed across client devices mostly lack labels, participants may not share enough common features or samples, and their datasets may differ substantially in distribution and scale. Moreover, federated learning protects data by keeping it on user devices, so the central server cannot access the participants' local data or data statistics. This limits preprocessing operations such as data cleaning, duplicate elimination, and augmentation. On the other hand, real-world networks are unreliable. Participants have uneven network resources, so gradient information or model parameters sent by the central server or uploaded by participants may arrive late, be lost, or be corrupted in transmission. In each round of training, participants may also drop out, voluntarily or involuntarily, for various reasons. Furthermore, active and passive network attacks seriously threaten participants' data security. Recent studies have begun to address these challenges [1–5]. We believe that reasonable client selection can improve the efficiency and accuracy of federated learning and, to some extent, keep malicious nodes out of training and protect data.
The main contributions of this paper are as follows: (1) Discuss the challenges that federated learning faces due to heterogeneity. (2) Review the research on improving federated learning efficiency from the perspectives of statistical heterogeneity and system heterogeneity. (3) Summarize the latest client selection methods from different angles. (4) Propose some research directions for future client selection.

2 Background
2.1 Research on Statistical Heterogeneity
Compared with IID data, statistically heterogeneous data is known to harm federated learning in several respects, including performance, efficiency, and data privacy. To mitigate these effects, researchers have worked in the following areas. To limit model divergence, some studies focus on improving data quality, for example through data sharing [3] and data augmentation [6]. Others modify the model aggregation method [7], customize the loss function [8], or combine federated learning with reinforcement learning [9] or knowledge distillation [10]. Some studies argue that federated learning protects privacy at the cost of significant communication overhead during training. The empirical speedups provided by distributed training often fall short of the ideal linear scaling, and communication overhead may be the main cause of this saturation. [11, 12] reduce communication costs by compressing communication, [13] further improves robustness, and [14] reduces costs by adjusting the model structure. Reference [15] analyzed the convergence bound of gradient-descent-based federated learning from a theoretical perspective and proposed a control algorithm that accounts for the data distribution, system dynamics, and model characteristics, using them to adjust the frequency of global aggregation dynamically in real time. However, the method proposed in [15] applies only to IID data. Selecting suitable clients to participate in training offers a different route. [16] requests resource information from clients to select participants for the current round. It addresses the loss of efficiency when some clients have limited computational resources (i.e., longer update times) or poor wireless channels (i.e., longer upload times), allowing the server to aggregate updates from as many clients as possible and speed up model improvement. [17] proposes a stratified client selection scheme that reduces variance and achieves better convergence and higher accuracy. [18] uses multi-agent reinforcement learning in the FedMarl framework, which selects the best clients at runtime from device performance statistics and training behavior, according to objectives specified by the application designer. [19] focuses on the energy cost of resource-constrained mobile devices. Building on the classic FedAvg, it finds and sends only the weight updates that carry important information, reducing clients' energy cost while preserving the model accuracy achieved in FL.
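Communication compression of the kind used in [11, 12] can be illustrated with sign-based compression in the style of signSGD [12]: each client transmits only the sign of every gradient coordinate, and the server combines the signs by element-wise majority vote. The sketch below is a simplified, hypothetical illustration on plain Python lists; the function names and the learning rate are assumptions, not the cited authors' code.

```python
def sign(x: float) -> int:
    """Return +1, -1, or 0 according to the sign of x."""
    return (x > 0) - (x < 0)


def sign_compress(grad: list[float]) -> list[int]:
    """Client side: send one sign per coordinate instead of a full float."""
    return [sign(g) for g in grad]


def majority_vote(sign_vectors: list[list[int]]) -> list[int]:
    """Server side: element-wise majority vote over the clients' sign vectors."""
    return [sign(sum(column)) for column in zip(*sign_vectors)]


def signsgd_step(weights: list[float], client_grads: list[list[float]], lr: float = 0.01) -> list[float]:
    """One global step: compress on clients, vote on the server, update weights."""
    vote = majority_vote([sign_compress(g) for g in client_grads])
    return [w - lr * v for w, v in zip(weights, vote)]


grads = [[0.5, -0.2, 0.0], [0.1, -0.3, 0.4], [-0.7, -0.1, 0.2]]
print(signsgd_step([0.0, 0.0, 0.0], grads))  # [-0.01, 0.01, -0.01]
```

Each client uploads one sign per parameter rather than a full-precision value, which is where the communication saving comes from; the price is a coarser update direction.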
In addition, [20] explores privacy protection in FL with non-IID data and proposes a differential privacy algorithm, 2DP-FL, which offers advantages in privacy protection, learning convergence, and model accuracy. Table 1 compares these studies.

Table 1. Comparison of research on statistical heterogeneity.

No. | Classification | Methods
[3] | Limiting model divergence | Improving data quality through data sharing
[6] | | Improving data quality through data augmentation
[7] | | Modifying the model aggregation method
[8] | | Customizing the loss function
[9] | | Building a client portfolio with quasi-optimal performance using reinforcement learning
[10] | | Transferring knowledge through transfer learning and knowledge distillation without sharing data or models
[11] | Reducing communication costs | Model compression
[12], [13] | | Data compression with good robustness
[14] | | Receiving only the public part applicable to the server model
[15] | | Dynamically computing the number of local update rounds
[16] | Designing customized client selection methods | Maximizing the number of participants
[17] | | Client stratification to reduce variance
[18] | | Reinforcement learning
[19] | | Reducing energy costs
[20] | | Using differential privacy-based algorithms to protect privacy

2.2 Research on System Heterogeneity
System heterogeneity arises mainly from the different resources each client owns. A key issue in building effective communication methods for federated learning is how gradients are updated, that is, the choice between synchronous and asynchronous updating. Asynchronous updating can solve the speed problem caused by resource differences [21], but it faces the problem of stale gradients. To address these issues, [8] adds a proximal term that constrains the "distance" between the local model and the global model, so that the local models of clients performing more computation (or more update steps) do not deviate too far from the global model, thereby alleviating client drift. [22] assigns aggregation weights to clients according to their numbers of training epochs, giving higher weights to clients with fewer epochs so that their local models are not overwhelmed by those of clients with more epochs. However, these studies on system heterogeneity in federated learning do not account for client volatility. The methods above have greatly advanced the development and application of federated learning, but they share a common problem: they improve one aspect of performance at the expense of another. Client selection offers another way to solve this problem. Through client selection, the server can choose preferred users to participate before each round of training, based on considerations such as the data the users own, the performance of their devices, or the degree of trust in them. This can reduce the overall communication cost of the system and improve its efficiency without affecting the final model accuracy, or achieve higher accuracy within the same training time, while the selection process itself adds only minimal time and communication cost.
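The proximal term used in [8] adds mu/2 * ||w - w_global||^2 to each client's local loss, where w_global is the current global model. A minimal sketch of the resulting local gradient step follows; it assumes a toy quadratic loss, plain gradient descent, and hypothetical names (`fedprox_local_update`, `mu`, `lr`, `steps`), so it illustrates the mechanism rather than reproducing the reference implementation.

```python
from typing import Callable

Vector = list[float]


def fedprox_local_update(
    w_global: Vector,
    grad_fn: Callable[[Vector], Vector],
    mu: float = 0.1,
    lr: float = 0.1,
    steps: int = 10,
) -> Vector:
    """Run local gradient descent on f(w) + mu/2 * ||w - w_global||^2."""
    w = list(w_global)
    for _ in range(steps):
        g = grad_fn(w)
        # The proximal term contributes mu * (w - w_global) to the gradient,
        # pulling the local model back toward the global model.
        w = [wi - lr * (gi + mu * (wi - wg)) for wi, gi, wg in zip(w, g, w_global)]
    return w


# Toy local loss f(w) = 0.5 * (w - 1)^2, whose own minimizer is w = 1.
def toy_grad(w: Vector) -> Vector:
    return [wi - 1.0 for wi in w]


plain = fedprox_local_update([0.0], toy_grad, mu=0.0, lr=0.1, steps=200)
prox = fedprox_local_update([0.0], toy_grad, mu=1.0, lr=0.1, steps=200)
print(plain, prox)  # plain approaches 1.0; prox approaches 0.5 = (1 + mu * 0) / (1 + mu)
```

A larger `mu` keeps local models closer to the global model, which limits drift among clients that run different numbers of local steps, at the cost of slower local progress.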

3 Client Selection Methods in Federated Learning
FedAvg [7] selects clients at random. However, because clients differ in computing and communication capabilities and hold different data samples, randomly selected participants may degrade the performance of the FL process. The following subsections summarize research on alternative selection strategies.
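For reference, the random selection and sample-weighted averaging of FedAvg [7] can be sketched as follows. This is a simplified illustration on plain Python lists; the participation fraction, the seeded random generator, and the function names are assumptions made for the example.

```python
import random


def select_clients_uniform(client_ids: list[int], fraction: float, rng: random.Random) -> list[int]:
    """Sample max(1, fraction * K) distinct clients uniformly at random."""
    m = max(1, int(fraction * len(client_ids)))
    return rng.sample(client_ids, m)


def fedavg_aggregate(updates: list[tuple[int, list[float]]]) -> list[float]:
    """Average client models weighted by their number of local samples.

    updates: (num_samples, model_weights) pairs from the selected clients.
    """
    total = sum(n for n, _ in updates)
    dim = len(updates[0][1])
    return [sum(n * w[j] for n, w in updates) / total for j in range(dim)]


rng = random.Random(0)
selected = select_clients_uniform(list(range(10)), fraction=0.3, rng=rng)
global_model = fedavg_aggregate([(1, [0.0, 2.0]), (3, [4.0, 2.0])])
print(selected, global_model)  # global_model is [3.0, 2.0]
```

Because every client is equally likely to be drawn regardless of its resources or data, a round can be held up by slow clients or skewed by unrepresentative data, which is what the selection strategies below try to avoid.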


3.1 Client Selection Based on Trust Level
In existing federated learning systems, clients may upload unreliable data, leading to fraud in federated learning tasks. Because clients are heterogeneous and dynamic, they may also intentionally or unintentionally perform unreliable updates, such as data poisoning attacks or updates computed on low-quality data. Therefore, some research focuses on finding trustworthy clients. [23] proposed a reputation-based client selection scheme that uses a multi-weight subjective logic model, combining direct interactions with recommendations from other task publishers to compute each candidate's overall reputation value. The model weighs three attributes: interaction frequency, interaction timeliness, and interaction effectiveness. Reputations are stored in a blockchain to achieve non-repudiable, tamper-resistant reputation management. During federated learning, Intel's SGX technology is used to detect "lazy" clients that do not train on all their local data, and poisoning attack detection schemes are used to identify poisoning attacks and unreliable clients, whose local model updates are then removed. After each training round, new local reputation opinions are generated for clients based on their performance in that round and recorded in the reputation blockchain. [24] introduces a reputation-based regional FL framework for highly dynamic intelligent transportation systems. Because of the dynamic nature of the traffic system, vehicles may leave their region and join a new one. Each regional leader calculates the reputation of vehicles from their honesty, accuracy, and interaction timeliness, and selects vehicles with high reputation to participate in the FL process. Positive contributions shared by a vehicle client with the server improve the accuracy of the global model and raise the vehicle's reputation, whereas negative behavior lowers it.
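The reputation computation in [23] builds on subjective logic, in which an opinion about a client is split into belief, disbelief, and uncertainty. The sketch below shows a simplified version of that idea: belief and disbelief come from counts of positive and negative interactions, the reputation value is belief plus a share of the uncertainty, and direct and recommended opinions are combined by a weighted average. The uncertainty input, the base rate of 0.5, the weights (standing in for the interaction frequency, timeliness, and effectiveness attributes of [23]), and the selection threshold are all assumptions; the exact multi-weight formulas of [23] are not reproduced here.

```python
from dataclasses import dataclass


@dataclass
class Opinion:
    belief: float
    disbelief: float
    uncertainty: float


def opinion_from_counts(positive: int, negative: int, uncertainty: float) -> Opinion:
    """Split the non-uncertain mass (1 - u) in proportion to interaction outcomes."""
    total = positive + negative
    if total == 0:
        return Opinion(0.0, 0.0, 1.0)  # no history: fully uncertain
    return Opinion(
        (1 - uncertainty) * positive / total,
        (1 - uncertainty) * negative / total,
        uncertainty,
    )


def reputation(op: Opinion, base_rate: float = 0.5) -> float:
    """Expected reputation: belief plus a base-rate share of the uncertainty."""
    return op.belief + base_rate * op.uncertainty


def combined_reputation(weighted: list[tuple[Opinion, float]], base_rate: float = 0.5) -> float:
    """Weighted average of direct and recommended opinions (weights are hypothetical)."""
    total_w = sum(w for _, w in weighted)
    return sum(w * reputation(op, base_rate) for op, w in weighted) / total_w


def select_trusted(reputations: dict[str, float], threshold: float) -> list[str]:
    """Keep only clients whose reputation reaches the (assumed) threshold."""
    return sorted(cid for cid, r in reputations.items() if r >= threshold)


direct = opinion_from_counts(8, 2, uncertainty=0.2)       # own interactions
recommended = opinion_from_counts(0, 0, uncertainty=0.2)  # another publisher, no history
rep = combined_reputation([(direct, 3.0), (recommended, 1.0)])
print(round(rep, 2), select_trusted({"a": rep, "b": 0.4}, threshold=0.6))
```

A client with no interaction history starts at the base rate rather than at zero, so newcomers are neither trusted nor excluded outright.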
[25] introduces a reputation model based on the beta distribution to measure the trustworthiness of terminal devices. Trust values are evaluated from participants' contributions to the global model: client behavior is classified as positive or negative, and each terminal's reputation value is updated by evaluating how its behaviors contribute to the global model. The authors also propose a scheduling strategy based on reputation and successful transmission rates that considers both reliability and fairness. To counter clients that intentionally submit unreliable updates for destructive purposes, or that contribute low-quality data because of their own resource limitations in federated learning over wireless networks, current reputation-based client selection mechanisms mainly select participants with a high overall reputation. Although these mechanisms require no a priori information from clients, most do not consider the impact of system and data heterogeneity on learning efficiency. Table 2 compares the client selection methods based on trust level.

Table 2. Comparison of client selection based on trust level.

No. | Heterogeneity | Dynamicity | Fairness | Target
[23] | × | × | × | Improve the reliability of FL
[24] | × | √ | × | Improve the accuracy of knowledge
[25] | × | × | √ | Improve reliability and convergence performance

3.2 Client Selection Based on Time Threshold
To cope with aggregation in unstable training environments, clients may be required to complete local training within a specific deadline. Unlike the static time threshold proposed in [16], [26] proposes an adaptive deadline determination algorithm for mobile devices: the deadline for each round is determined adaptively from the performance differences among mobile devices rather than fixed in advance. [27] aims to maximize the number of participating clients by predicting, from client information, each client's ability to perform training tasks. Additionally, to account for data heterogeneity among clients, the authors prioritize the participants with the highest event rates. Deadline-based client selection mechanisms mainly study the impact of system heterogeneity on federated efficiency, and the shift from static to dynamic deadlines makes client selection more robust. However, time thresholds derived from clients' own resources cannot guarantee fairness, and although the server requires a priori information from clients, clients may be unwilling to report their resources to it. Current time-threshold-based methods therefore lack consideration of fairness in selection and of client vulnerability in dynamic environments. Table 3 compares the client selection methods based on time threshold.

Table 3. Comparison of client selection based on time threshold.

No. | Statistical Heterogeneity | Fairness | Reliability
[16] | × | × | ×
[26] | × | × | ×
[27] | √ | × | ×
Target: Maximizing the number of participants while considering system heterogeneity
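A deadline-based selection step in the spirit of [16] can be sketched as a greedy loop: the server keeps adding the client that would finish earliest until the next one would miss the round deadline. The timing model is an assumption for illustration (local training runs in parallel, uploads share one channel and happen one after another), and the client times are hypothetical values that such schemes would obtain from clients' resource reports.

```python
def select_within_deadline(
    clients: dict[str, tuple[float, float]], deadline: float
) -> tuple[list[str], float]:
    """Greedily select clients whose updates can arrive before the deadline.

    clients maps a client id to (local_training_time, upload_time).
    Assumed timing model: all selected clients train in parallel from time 0,
    and their uploads are serialized over a single shared channel.
    """
    selected: list[str] = []
    channel_free_at = 0.0
    remaining = dict(clients)
    while remaining:
        # Finishing time of each candidate if it were scheduled next.
        best_id, best_finish = min(
            ((cid, max(channel_free_at, train) + upload)
             for cid, (train, upload) in remaining.items()),
            key=lambda item: item[1],
        )
        if best_finish > deadline:
            break  # the earliest-finishing candidate is already too late
        selected.append(best_id)
        channel_free_at = best_finish
        del remaining[best_id]
    return selected, channel_free_at


clients = {"a": (2.0, 1.0), "b": (1.0, 1.0), "c": (5.0, 3.0), "d": (1.5, 0.6)}
chosen, finish = select_within_deadline(clients, deadline=6.0)
print(chosen, finish)  # ['b', 'd', 'a'] finishing at about 3.6
```

Client "c" is left out because of its long training time; under a fixed deadline it would be left out in every round, which is the fairness concern raised above.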

3.3 Client Selection Based on Reinforcement Learning
In reinforcement learning, an agent learns to achieve a goal in a complex, uncertain environment by maximizing its reward. [28] formulates client selection as a multi-armed bandit problem that minimizes clients' local computation time and data transmission time, in both ideal and non-ideal environments, under the assumption that reaching the target accuracy requires a certain number of communication rounds. Recognizing that in practice the server cannot know clients' local training times, and that most selection methods favor high-performance clients, [29] predicts clients' training times based on their reputation and, to ensure fairness, constrains and adjusts the minimum average selection rate of each participant.
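To make the bandit view concrete, the sketch below applies the generic UCB1 rule to client selection: each client is an arm, and the server picks the clients with the highest estimated reward plus an exploration bonus. This is not the algorithm of [28]; the reward signal (for example, a normalized inverse of the observed round time), the exploration constant, and the class name are assumptions, and fairness constraints such as the minimum selection rate in [29] are omitted.

```python
import math


class UCBClientSelector:
    """Generic UCB1 over clients; rewards are assumed to lie in [0, 1]."""

    def __init__(self, num_clients: int, c: float = 2.0) -> None:
        self.counts = [0] * num_clients   # times each client was selected
        self.means = [0.0] * num_clients  # running mean reward per client
        self.c = c                        # exploration strength

    def _score(self, i: int, total: int) -> float:
        if self.counts[i] == 0:
            return math.inf  # try every client at least once
        bonus = math.sqrt(self.c * math.log(total) / self.counts[i])
        return self.means[i] + bonus

    def select(self, k: int) -> list[int]:
        total = max(1, sum(self.counts))
        ranked = sorted(range(len(self.counts)), key=lambda i: self._score(i, total), reverse=True)
        return ranked[:k]

    def update(self, client: int, reward: float) -> None:
        self.counts[client] += 1
        self.means[client] += (reward - self.means[client]) / self.counts[client]


# Hypothetical, fixed per-client rewards (e.g. faster clients earn more).
true_reward = [0.2, 0.9, 0.5, 0.1]
selector = UCBClientSelector(num_clients=4)
for _ in range(200):
    for client in selector.select(k=1):
        selector.update(client, true_reward[client])
print(selector.counts)  # client 1 ends up selected most often
```

Because the bonus shrinks as a client is chosen more often, rarely chosen clients are still revisited occasionally, which is one simple way to trade exploitation of fast clients against exploration.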


[30] proposes FAVOR, an experience-driven control framework that determines the optimal ratio of devices to select in each round so as to minimize the number of rounds and counteract non-IID data distributions. The authors formulate the client selection process of FL as a deep reinforcement learning problem and, to select devices in each communication round, propose a double deep Q-learning mechanism that improves the accuracy of the global model and reduces the number of communication rounds. [31] considers the training quality and efficiency of heterogeneous nodes, screens out malicious nodes, and optimizes training latency while improving the accuracy of federated learning models. Because some client information is not available a priori, reinforcement learning-based client selection is better suited to real-world federated learning. It accounts for more factors than the previous mechanisms, but the reliability of participants still needs to be addressed. Table 4 compares the client selection methods based on reinforcement learning.

Table 4. Comparison of client selection based on reinforcement learning.
No. [28] [29] [30] [31]

Statistical Heterogeneity √ × √ √

System Heterogeneity √

Dynamicity √

Fairness √







× √

×

×

×

×
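The double deep Q-learning used by FAVOR [30] rests on the double DQN target, in which the online network chooses the next action and a separate target network evaluates it, reducing the overestimation of plain Q-learning. The sketch below computes only these targets from precomputed Q-values; in a client selection setting each action index could correspond to choosing one client, but the state encoding, reward design, and network architecture of [30] are not reproduced, and the function name and discount factor are assumptions.

```python
def double_dqn_targets(
    rewards: list[float],
    next_q_online: list[list[float]],
    next_q_target: list[list[float]],
    dones: list[bool],
    gamma: float = 0.99,
) -> list[float]:
    """Compute y = r + gamma * Q_target(s', argmax_a Q_online(s', a)) per transition."""
    targets = []
    for r, q_on, q_tg, done in zip(rewards, next_q_online, next_q_target, dones):
        if done:
            targets.append(r)  # terminal transition: no bootstrapping
            continue
        best_action = max(range(len(q_on)), key=q_on.__getitem__)  # chosen by online net
        targets.append(r + gamma * q_tg[best_action])               # valued by target net
    return targets


ys = double_dqn_targets(
    rewards=[1.0, 0.5],
    next_q_online=[[0.1, 0.9], [0.3, 0.2]],
    next_q_target=[[0.4, 0.2], [0.7, 0.6]],
    dones=[False, True],
    gamma=0.9,
)
print(ys)  # approximately [1.18, 0.5]
```

A plain DQN target would instead take the maximum of the target network's values (0.4 in the first transition, giving 1.36), which is the overestimation the double estimator is meant to curb.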

3.4 Client Selection Based on Probability
Because of limited communication resources or geographic distribution, federated learning algorithms such as FedAvg typically run multiple local iterations on randomly selected clients and periodically aggregate the new local model updates at a central server. However, because of system and data heterogeneity, these schemes often converge slowly. [32] observed that previous research mostly lacked joint consideration of system heterogeneity and statistical heterogeneity. Considering both simultaneously, the authors derived a new convergence upper bound for arbitrary client selection probabilities and formulated a non-convex training-time minimization problem. [33] derived a new convergence upper bound for non-convex loss functions in federated learning with arbitrary device selection probabilities. Using the Lyapunov drift-plus-penalty framework and the current channel conditions, they formulated a stochastic optimization problem that minimizes a weighted sum of the convergence bound and communication time. [34] argued that the global model need not aggregate the parameters of all clients, and that excluding some unfavorable local updates may yield a more accurate global model. On this basis, they proposed a probabilistic node selection framework (FedPNS), which dynamically adjusts each device's selection probability according to its contribution to the model.
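When clients are sampled with non-uniform probabilities, the aggregation usually has to be reweighted so that the global update is not biased toward frequently chosen clients. A common choice (not necessarily the exact rule of each cited work), sketched below, samples K clients with replacement according to probabilities q_i and scales each sampled update by p_i / (K q_i), where p_i is the client's data weight; its expectation then equals the full-participation update. The probabilities are given as inputs here, whereas [32] derives them from its convergence bound; the function names are hypothetical.

```python
import random


def sample_clients(probs: list[float], k: int, rng: random.Random) -> list[int]:
    """Draw k client indices i.i.d. (with replacement) according to probs."""
    return rng.choices(range(len(probs)), weights=probs, k=k)


def unbiased_aggregate(
    updates: list[list[float]],
    data_weights: list[float],
    probs: list[float],
    sampled: list[int],
) -> list[float]:
    """Sum p_i / (k * q_i) * update_i over the sampled indices.

    Its expectation equals sum_i p_i * update_i, the full-participation update.
    """
    k = len(sampled)
    agg = [0.0] * len(updates[0])
    for i in sampled:
        scale = data_weights[i] / (k * probs[i])
        for j, value in enumerate(updates[i]):
            agg[j] += scale * value
    return agg


updates = [[1.0], [2.0], [4.0]]
p = [0.5, 0.3, 0.2]  # data weights
q = [0.2, 0.3, 0.5]  # selection probabilities (hypothetical)
full = sum(pi * u[0] for pi, u in zip(p, updates))
expected = sum(q[i] * unbiased_aggregate(updates, p, q, [i])[0] for i in range(3))
print(full, expected, sample_clients(q, k=4, rng=random.Random(1)))
```

Without the 1 / (K q_i) factor, clients with a high selection probability would count more than their data share warrants, biasing the global model toward them.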


[35] designs a participant selection algorithm to address the accuracy degradation caused by statistical heterogeneity. The authors use weight divergence to estimate the degree of non-IIDness: by evaluating the weight differences between client models and an auxiliary model, they select participants with lower degrees of non-IID data and train them more frequently. After analyzing the main factors that affect model accuracy and convergence speed, [36] proposes loss value and training time as two indicators of model quality. The resulting client quality guides client selection, which is combined with a random selection strategy in a given proportion to improve global model accuracy. However, this method gives clients with higher loss values and faster computation a higher probability of being selected, which is unfair. To address the heterogeneity of client devices, [37] considers each client's contribution to accelerating global model convergence and the communication cost of the system. It proposes a client selection method based on the knapsack model whose optimization objective is to maximize the weight change of the locally trained models. However, this method requires collecting clients' resource information in advance.

Table 5. Comparison of client selection based on probability.
No. [32] [33] [34] [35] [36] [37]

Statistical Heterogeneity √

System Heterogeneity √

Dynamicity

Fairness





× √

×

×

×

×



× √

×

×





×

×

×

×

√ √

×
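The knapsack formulation in [37] can be illustrated with the textbook 0/1 knapsack dynamic program: each client has a value (here, the magnitude of its local weight change) and a cost (here, an integer communication cost), and the server picks the subset with the largest total value within a budget. The values, integer costs, and budget are hypothetical inputs; obtaining them requires prior information from clients, and the exact value and cost definitions in [37] may differ.

```python
def knapsack_select(values: list[float], costs: list[int], budget: int) -> tuple[list[int], float]:
    """0/1 knapsack: choose clients maximizing total value with total cost <= budget."""
    n = len(values)
    # best[i][b]: best total value using the first i clients with budget b
    best = [[0.0] * (budget + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        v, c = values[i - 1], costs[i - 1]
        for b in range(budget + 1):
            best[i][b] = best[i - 1][b]  # skip client i - 1
            if c <= b and best[i - 1][b - c] + v > best[i][b]:
                best[i][b] = best[i - 1][b - c] + v  # take client i - 1
    # Backtrack to recover which clients were taken.
    chosen, b = [], budget
    for i in range(n, 0, -1):
        if best[i][b] != best[i - 1][b]:
            chosen.append(i - 1)
            b -= costs[i - 1]
    return sorted(chosen), best[n][budget]


weight_change = [6.0, 10.0, 12.0]  # hypothetical per-client values
comm_cost = [1, 2, 3]              # hypothetical integer costs
print(knapsack_select(weight_change, comm_cost, budget=5))  # ([1, 2], 22.0)
```

The table costs O(n * budget) time and memory, so costs must be discretized into integer units before the dynamic program can be applied.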

Research on probability-based client selection mechanisms has studied the impact of statistical and system heterogeneity on federated efficiency. As with reinforcement learning-based selection, no a priori client information is required, but the reliability of the clients involved in training is not guaranteed. Table 5 compares the client selection methods based on probability allocation.

3.5 Discussion
Other methods also exist. For example, [38] introduced a group-based participant selection mechanism that divides participants into groups based on the Group Earth Mover's Distance (GEMD) to balance the clients' label distributions. In summary, the impact of heterogeneity on federated efficiency has been explored to some extent, and trust-based client selection mechanisms can partly mitigate threats from unreliable clients. However, current research on client selection mechanisms in federated learning is not yet comprehensive and often depends on specific conditions that are difficult to meet in real-world environments. We summarize the open challenges as follows: (1) Existing federated learning systems rely heavily on a central server, whose reliability determines whether the entire system can operate stably. (2) Most methods require prior knowledge of clients' resources, which is difficult to obtain in real-life scenarios. On the one hand, clients may be unwilling to disclose their resources, and transmitting this information risks privacy leakage. On the other hand, collecting it is not easy in large-scale environments. (3) System heterogeneity, especially network instability, makes clients highly dynamic. Clients participating in a round may drop out actively or passively, for their own reasons or because of system issues, wasting resources and lowering system efficiency. (4) Cryptographic methods introduced for privacy preservation generate additional communication overhead, which reduces system efficiency. (5) Most client selection methods prefer high-performance devices, which is unfair to other clients and may also hurt system efficiency, because some low-performance devices may hold high-quality data.
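The grouping idea behind [38] can be sketched as follows. We assume here that a group's GEMD is the L1 distance between the pooled label distribution of its members and the global label distribution, and that clients are assigned greedily, largest first, to the non-full group whose GEMD would be smallest after adding them. Both the distance definition and the greedy rule are simplifying assumptions for illustration and may differ from the mechanism in [38].

```python
import math


def label_distribution(counts: list[int]) -> list[float]:
    total = sum(counts)
    return [c / total for c in counts]


def gemd(group_counts: list[int], global_dist: list[float]) -> float:
    """L1 distance between a group's pooled label distribution and the global one."""
    return sum(abs(p - g) for p, g in zip(label_distribution(group_counts), global_dist))


def greedy_grouping(client_counts: list[list[int]], num_groups: int) -> list[list[int]]:
    """Assign clients to equal-capacity groups so each group's labels look global."""
    num_classes = len(client_counts[0])
    global_dist = label_distribution(
        [sum(c[k] for c in client_counts) for k in range(num_classes)]
    )
    capacity = math.ceil(len(client_counts) / num_groups)
    groups: list[list[int]] = [[] for _ in range(num_groups)]
    pooled = [[0] * num_classes for _ in range(num_groups)]
    # Place clients with more samples first; they shift distributions the most.
    order = sorted(range(len(client_counts)), key=lambda i: -sum(client_counts[i]))
    for i in order:
        open_groups = [g for g in range(num_groups) if len(groups[g]) < capacity]
        best = min(
            open_groups,
            key=lambda g: gemd([a + b for a, b in zip(pooled[g], client_counts[i])], global_dist),
        )
        groups[best].append(i)
        pooled[best] = [a + b for a, b in zip(pooled[best], client_counts[i])]
    return groups


# Hypothetical label counts for four clients over two classes.
clients = [[10, 0], [0, 10], [10, 0], [0, 10]]
print(greedy_grouping(clients, num_groups=2))  # [[0, 1], [2, 3]]
```

Each resulting group pairs clients with complementary labels, so every group's pooled distribution matches the global one even though each individual client is highly skewed.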

4 Future Research Directions
Research on client selection mechanisms in federated learning is still at an early stage, and several challenges must be addressed in future work. This article identifies the following research directions: (1) Decentralized client selection in federated learning: Although existing client selection mechanisms reduce, to some extent, the training problems caused by errors at client nodes, existing research has not considered failure of the central server. Introducing a decentralized mechanism based on blockchain could help greatly in this regard. (2) Transfer learning and reinforcement learning in federated learning: Transfer learning can use knowledge from different but related source domains to build effective models for small datasets or label-scarce applications in the target domain, which helps with the highly diverse datasets held by participants in practical scenarios. Reinforcement learning can select the best clients without prior knowledge of their resources, which suits real-world situations. Integrating transfer learning and reinforcement learning in federated learning could further improve its efficiency. (3) Dynamic client selection: With the full adoption of IPv6 on the horizon, we are moving toward an era of ubiquitous connectivity. Future smart cities will rely heavily on highly dynamic mobile devices, such as smart wearables and smart transportation, that interact with their surroundings and with other smart facilities. Future federated learning must therefore handle large numbers of devices that can join or leave training at any time and from anywhere, which makes dynamic client selection a major challenge. (4) Clients' resistance to attacks: Cryptographic methods, differential privacy, and adversarial training are currently the main means of privacy protection in federated learning. These methods introduce additional computation and communication overhead, sacrificing system efficiency for stronger privacy protection. Adversarial training still carries privacy risks because it requires clients to upload seed data samples. Making client selection resistant to attacks is therefore another major challenge in building efficient and reliable federated learning. (5) Improvements for system heterogeneity: There are two main solutions to performance differences among devices. The first uses asynchronous updating so that faster devices do not wait for slower ones, but it suffers from stale gradients. The second moves local computation to edge servers with consistent performance to close the performance gap between devices, but it is still limited by communication resources. How to choose update methods suited to different scenarios, and how best to use edge servers, are also worth researching.
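One common way to soften the stale-gradient problem of asynchronous updating in direction (5) is to shrink the mixing weight of late-arriving updates, for example with a FedAsync-style polynomial staleness discount, a technique not covered in this review. The sketch below is a minimal illustration under that assumption; the base mixing rate `alpha`, the exponent `a`, the version counters, and the function names are hypothetical choices.

```python
def staleness_weight(alpha: float, staleness: int, a: float = 0.5) -> float:
    """Polynomial discount: alpha * (1 + staleness) ** (-a)."""
    return alpha * (1 + staleness) ** (-a)


def async_server_update(
    global_w: list[float],
    client_w: list[float],
    current_version: int,
    client_version: int,
    alpha: float = 0.6,
    a: float = 0.5,
) -> list[float]:
    """Mix a client model into the global model, trusting stale updates less.

    client_version is the global model version the client started training from.
    """
    staleness = current_version - client_version
    mix = staleness_weight(alpha, staleness, a)
    return [(1 - mix) * g + mix * c for g, c in zip(global_w, client_w)]


fresh = async_server_update([0.0], [1.0], current_version=5, client_version=5)
stale = async_server_update([0.0], [1.0], current_version=5, client_version=2)
print(fresh, stale)  # the stale update moves the global model half as far
```

A fresh update is mixed in at the full rate `alpha`, while in this example an update computed from a model three versions old is mixed in at half that rate, so slow devices still contribute without dragging the global model backward.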

5 Conclusion
Reasonable client selection mechanisms are of great significance in heterogeneous federated learning environments. On the one hand, using reinforcement learning to select suitable clients for each round tends to maximize the reward of each round, so that a given model accuracy can be reached in fewer communication rounds, or a higher accuracy can be reached in the same number of rounds. On the other hand, reputation-based client selection mechanisms can effectively counter the poisoning attacks that currently affect federated learning by filtering out unreliable clients. Ensuring fairness in client selection and designing selection methods for highly dynamic environments are directions worth further study. In this article, starting from the challenges posed by data heterogeneity and system heterogeneity, we have summarized client selection methods that optimize communication and efficiency in federated learning. Having identified the current challenges of heavy dependence on the central server, the difficulty of obtaining prior knowledge of client resources, the high dynamism of clients, and fairness in selection, we believe that using transfer learning and reinforcement learning to optimize client selection, especially in environments with highly dynamic clients, can greatly reduce communication costs and accelerate model convergence. We will study this topic further.

Acknowledgement. This work was supported by the National Natural Science Foundation of China under Grant 61862007, and Guangxi Natural Science Foundation under Grant 2020GXNSFBA297103.

References
1. Chen, A., Fu, Y., Sha, Z., et al.: An emd-based adaptive client selection algorithm for federated learning in heterogeneous data scenarios. Front. Plant Sci. 13 (2022)
2. Rubner, Y., Tomasi, C., Guibas, L.J.: The earth mover's distance as a metric for image retrieval. Int. J. Comput. Vision 40(2), 99 (2000)
3. Zhao, Y., Li, M., Lai, L., et al.: Federated learning with non-iid data. arXiv preprint arXiv:1806.00582 (2018)
4. Xia, Z., Chen, Y., Yin, B., et al.: Fed_ADBN: an efficient intrusion detection framework based on client selection in AMI network. Expert Syst. (2022)
5. Huang, T., Lin, W., Shen, L., et al.: Stochastic client selection for federated learning with volatile clients. IEEE Internet Things J. 9(20), 20055–20070 (2022)
6. Jeong, E., Oh, S., Kim, H., et al.: Communication-efficient on-device machine learning: federated distillation and augmentation under non-iid private data. arXiv preprint arXiv:1811.11479 (2018)
7. McMahan, B., Moore, E., Ramage, D., et al.: Communication-efficient learning of deep networks from decentralized data. In: Artificial Intelligence and Statistics, pp. 1273–1282. PMLR (2017)
8. Li, T., Sahu, A.K., Zaheer, M., et al.: Federated optimization in heterogeneous networks. Proc. Mach. Learn. Syst. 2, 429–450 (2020)
9. Pang, J., Huang, Y., Xie, Z., et al.: Realizing the heterogeneity: a self-organized federated learning framework for IoT. IEEE Internet Things J. 8(5), 3088–3098 (2020)
10. Li, D., Wang, J.: FedMD: heterogenous federated learning via model distillation. arXiv preprint arXiv:1910.03581 (2019)
11. Wang, H., Sievert, S., Liu, S., et al.: ATOMO: communication-efficient learning via atomic sparsification. Adv. Neural Inform. Process. Syst. 31 (2018)
12. Bernstein, J., Wang, Y.X., Azizzadenesheli, K., et al.: signSGD: compressed optimisation for non-convex problems. In: International Conference on Machine Learning, pp. 560–569. PMLR (2018)
13. Sattler, F., Wiedemann, S., Müller, K.R., et al.: Robust and communication-efficient federated learning from non-iid data. IEEE Trans. Neural Networks Learn. Syst. 31(9), 3400–3413 (2019)
14. Kang, D., Ahn, C.W.: Communication cost reduction with partial structure in federated learning. Electronics 10(17), 2081 (2021)
15. Wang, S., Tuor, T., Salonidis, T., et al.: Adaptive federated learning in resource constrained edge computing systems. IEEE J. Sel. Areas Commun. 37(6), 1205–1221 (2019)
16. Nishio, T., Yonetani, R.: Client selection for federated learning with heterogeneous resources in mobile edge. In: ICC 2019–2019 IEEE International Conference on Communications (ICC), pp. 1–7. IEEE (2019)
17. Shen, G., Gao, D., Yang, L., et al.: Variance-reduced heterogeneous federated learning via stratified client selection. arXiv preprint arXiv:2201.05762 (2022)
18. Zhang, S.Q., Lin, J., Zhang, Q.: A multi-agent reinforcement learning approach for efficient client selection in federated learning. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 36, issue 8, pp. 9091–9099 (2022)
19. Zhao, J., Feng, Y., Chang, X., et al.: Energy-efficient client selection in federated learning with heterogeneous data on edge. Peer-to-Peer Network. Appl. 15(2), 1139–1151 (2022)
20. Xiong, Z., Cai, Z., Takabi, D., et al.: Privacy threat and defense for federated learning with non-iid data in AIoT. IEEE Trans. Industr. Inf. 18(2), 1310–1321 (2021)
21. Lu, Y., Huang, X., Dai, Y., et al.: Differentially private asynchronous federated learning for mobile edge computing in urban informatics. IEEE Trans. Industr. Inf. 16(3), 2134–2143 (2019)
22. Ruan, Y., Zhang, X., Liang, S.C., et al.: Towards flexible device participation in federated learning. In: International Conference on Artificial Intelligence and Statistics, pp. 3403–3411. PMLR (2021)
23. Kang, J., Xiong, Z., Niyato, D., et al.: Incentive mechanism for reliable federated learning: a joint optimization approach to combining reputation and contract theory. IEEE Internet Things J. 6(6), 10700–10714 (2019)
24. Zou, Y., Shen, F., Yan, F., et al.: Reputation-based regional federated learning for knowledge trading in blockchain-enhanced IoV. In: 2021 IEEE Wireless Communications and Networking Conference (WCNC), pp. 1–6. IEEE (2021)
25. Song, Z., Sun, H., Yang, H.H., et al.: Reputation-based federated learning for secure wireless networks. IEEE Internet Things J. 9(2), 1212–1226 (2021)
26. Lee, J., Ko, H., Pack, S.: Adaptive deadline determination for mobile device selection in federated learning. IEEE Trans. Veh. Technol. 71(3), 3367–3371 (2021)
27. AbdulRahman, S., Tout, H., Mourad, A., et al.: FedMCCS: multicriteria client selection model for optimal IoT federated learning. IEEE Internet Things J. 8(6), 4723–4735 (2020)
28. Xia, W., Quek, T.Q.S., Guo, K., et al.: Multi-armed bandit-based client scheduling for federated learning. IEEE Trans. Wireless Commun. 19(11), 7108–7123 (2020)
29. Huang, T., Lin, W., Wu, W., et al.: An efficiency-boosting client selection scheme for federated learning with fairness guarantee. IEEE Trans. Parallel Distrib. Syst. 32(7), 1552–1564 (2020)
30. Wang, H., Kaplan, Z., Niu, D., et al.: Optimizing federated learning on non-iid data with reinforcement learning. In: IEEE INFOCOM 2020-IEEE Conference on Computer Communications, pp. 1698–1707. IEEE (2020)
31. He, W., Guo, S., Qiu, X., Chen, L., Zhang, S.: Node selection method in federated learning based on deep reinforcement learning. J. Commun. 42(06), 62–71 (2021)
32. Luo, B., Xiao, W., Wang, S., et al.: Tackling system and statistical heterogeneity for federated learning with adaptive client sampling. In: IEEE INFOCOM 2022-IEEE Conference on Computer Communications, pp. 1739–1748. IEEE (2022)
33. Perazzone, J., Wang, S., Ji, M., et al.: Communication-efficient device scheduling for federated learning using stochastic optimization. In: IEEE INFOCOM 2022-IEEE Conference on Computer Communications, pp. 1449–1458. IEEE (2022)
34. Wu, H., Wang, P.: Node selection toward faster convergence for federated learning on non-iid data. IEEE Trans. Network Sci. Eng. 9(5), 3099–3111 (2022)
35. Zhang, W., Wang, X., Zhou, P., et al.: Client selection for federated learning with non-iid data in mobile edge computing. IEEE Access 9, 24462–24474 (2021)
36. Wen, Y., Zhao, N., Zeng, Y., Han, M., Yue, L., Zhang, J.: A client selection strategy based on local model quality. Comput. Eng. 1–16 (2023). https://doi.org/10.19678/j.issn.1000-3428.0065658
37. Guo, J., Chen, Z., Gao, W., Wang, X., Sun, X., Gao, L.: Clients selection method based on knapsack model in federated learning. Chin. J. Internet Things 6(04), 158–168 (2022)
38. Ma, J., Sun, X., Xia, W., et al.: Client selection based on label quantity information for federated learning. In: 2021 IEEE 32nd Annual International Symposium on Personal, Indoor and Mobile Radio Communications (PIMRC), pp. 1–6. IEEE (2021)

ARFG: Attach-Free RFID Finger-Tracking with Few Samples Based on GAN

Sijie Li¹, Lvqing Yang¹ (✉), Sien Chen²,³, Jianwen Ding¹, Wensheng Dong⁴, Bo Yu⁴, Qingkai Wang⁵,⁶, and Menghao Wang¹

¹ School of Informatics, Xiamen University, Xiamen, China


² School of Navigation, Jimei University, Xiamen, China
³ School of Management, Xiamen University, Xiamen, China
⁴ Zijin Zhixin (Xiamen) Technology Co., Ltd., Xiamen, China
⁵ State Key Laboratory of Process Automation in Mining and Metallurgy, Beijing, China
⁶ Beijing Key Laboratory of Process Automation in Mining and Metallurgy, Beijing, China

Abstract. Traditional finger-tracking methods based on sensors, cameras, and other devices are limited by high cost, sensitivity to the environment, and inadequate protection of user privacy. To address these challenges, we propose ARFG, an attach-free RFID finger-tracking system based on a Generative Adversarial Network (GAN). ARFG captures the time-series changes in the reflected signal caused by finger movements in front of an RFID tag array. These signals are transformed into feature maps that serve as inputs to DS-GAN, a fully supervised classification model that uses a semi-supervised algorithm. ARFG recognizes finger traces accurately, and soft thresholding is used to overcome the difficulty conventional GANs face when trained on limited datasets. Extensive experiments demonstrate ARFG's high accuracy, with an average accuracy of 94.69% and up to 97.50% under conditions with few samples and various traces, users, finger speeds, and surroundings, showcasing its robustness and cross-environment capability.

Keywords: Finger-Tracking · RFID · Few Samples · GAN

1 Introduction

The emergence of the Artificial Intelligence of Things (AIoT) has led to the adoption of non-typing input devices. While computer-vision-based finger tracking [1] offers high accuracy and convenience, it is susceptible to interference from light intensity and raises privacy concerns. To address these concerns, researchers have explored alternatives such as WiFi [2], multi-frequency continuous-wave radar [3], and electromagnetic polarization discrimination [4], but the high cost of the required hardware impedes their widespread adoption. Consequently, researchers have shifted their focus to RFID [5] as a promising low-cost solution. Early RFID-based finger-tracking techniques relied on attach-based methods, which degraded the user experience by requiring a tag to be physically attached to the user's body or to an object. More recently, researchers have developed attach-free methods in which the tags capture signals influenced by human gestures for behavior recognition [6]. RF-Finger [7], for example, employs tag arrays to enable writing in the air and significantly enhances the user experience.

With the wide application of deep learning in various fields, researchers have made several attempts to combine artificial intelligence and RFID. For instance, ReActor [8] applies a Random Forest to extracted gesture-signal features to recognize gestures. However, collecting a large number of training samples for such models is time-consuming and labor-intensive. GANs and VAEs have been shown to effectively address the problem of limited datasets, but VAE implementations require significant computational resources and time and are prone to instability, which limits their practical use in some scenarios. EC-GAN [9] has shown promising results by controlling the relative importance of generated data, but the feature extraction capability of its classifier needs improvement to handle synthesized images with noise and missing features.

To this end, we present ARFG, a novel finger-tracking approach that extracts trace features from attach-free finger signals, visualizes finger traces by computing the finger position at each moment, and achieves high tracking accuracy with limited training samples. To enhance image feature learning, we introduce soft thresholding [10] into the classifier for denoising, and add self-attention and spectral normalization [11] to the generator and discriminator. Our main contributions are summarized as follows:

1. To the best of our knowledge, ARFG is the first system that employs a GAN to successfully address the shortage of training data in RFID-based finger-tracking tasks.
2. For this scenario, we propose DS-GAN, which effectively improves the training performance of the classifier in the presence of noisy synthesized images from the generator.
3. Comprehensive experiments demonstrate that our system performs well across a range of users, finger speeds, and environments, with an overall recognition accuracy of 94.69%.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 773–784, 2023. https://doi.org/10.1007/978-981-99-4742-3_64

2 Related Work

2.1 Attach-Based Recognition

Attach-based approaches require the user to wear or touch a device. Some works classify gestures by measuring the acceleration of a gesture with an accelerometer and applying template matching based on Dynamic Time Warping (DTW) distance. Zhou Ren et al. [12] leverage the Kinect sensor to classify ten different gestures with high accuracy. RF-glove [13] uses three antennas and five commercial tags attached to the five fingers to build a contactless smart sensing system with good fine-grained gesture recognition capability.

2.2 Attach-Free Recognition

With the continuous development of computer vision, attach-free methods have received much attention, and the signal media used to capture human movements have become diverse. RF-Finger built a theoretical model to extract fine-grained reflection features from the raw RFID signal that describe the influence of the finger on the tag array, and achieved an accuracy of 88% in the finger-tracking scenario.

2.3 Few-Shot Learning in RFID

To accomplish RFID-based gesture recognition with inadequate samples, researchers have applied few-shot learning to RFID. Unsoo Ha et al. [14] proposed RF-EATS, which uses variational autoencoders (VAEs) to generate synthetic training samples and uses transfer learning to migrate trained VAEs to new environments with good results. RF-Siamese [15] uses Siamese (twin) networks to achieve an accuracy of 93% with only one sample per gesture. MetaSense [16] and RF-net [17] use meta-learning to reduce the number of training samples and perform well on static gesture recognition.

3 Preliminaries

In this section, we present the fundamental theory of RFID, describe the prior knowledge used in our experiments, and evaluate the impact of finger movements on RFID signal propagation.

3.1 Definition of RFID Signal

When a passive tag enters the magnetic field of an RFID reader, it receives the reader's signal and uses the energy of the induced current to transmit the data stored in its chip. According to [7], the Received Signal Strength Indication (RSSI) and phase exhibit a brief but noticeable fluctuation when a finger approaches the tag. The signal returned by the tag can be expressed as:

$$S = \sqrt{\frac{10^{R/10}}{1000}}\cos\theta + J\sqrt{\frac{10^{R/10}}{1000}}\sin\theta, \qquad (1)$$

where R is the RSSI value and θ is the phase value. When the finger is placed in front of the tag, the returned signal S consists of two components: the signal $S_{\text{direct}}$ transmitted directly to the tag, and the signal $S_{\text{reflect}}$ reflected by the finger, i.e., $S = S_{\text{direct}} + S_{\text{reflect}}$. $S_{\text{direct}}$ is directly measurable: it equals the signal the tag returns when the hand is down. Therefore, the signal reflected off the finger can be obtained by subtracting $S_{\text{direct}}$ from the S measured while the hand faces the tag array. This reflected signal also represents the impact of finger motions on the RFID signal.
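A minimal NumPy sketch of Eq. (1) and the background subtraction described above may make the signal model concrete. It assumes RSSI in dBm, phase in radians, and the amplitude form sqrt(10^(R/10)/1000), so that |S|² is the received power in watts; the function names are illustrative and not taken from the paper.

```python
import numpy as np


def rfid_complex_signal(rssi_dbm, phase_rad):
    """Eq. (1): complex tag response from RSSI (dBm) and phase (rad).

    10**(R/10) / 1000 converts dBm to watts; the square root turns that
    power into an amplitude, so |S|**2 is the received power in watts.
    """
    rssi_dbm = np.asarray(rssi_dbm, dtype=float)
    phase_rad = np.asarray(phase_rad, dtype=float)
    amplitude = np.sqrt(10.0 ** (rssi_dbm / 10.0) / 1000.0)
    return amplitude * (np.cos(phase_rad) + 1j * np.sin(phase_rad))


def reflected_component(s_total, s_direct):
    """S = S_direct + S_reflect, so S_reflect = S - S_direct.

    s_direct is the baseline measured while the hand is down.
    """
    return np.asarray(s_total) - np.asarray(s_direct)


def reflected_power(s_reflect):
    """Reflected energy |S_reflect|^2 (used later as P_a)."""
    return np.abs(s_reflect) ** 2


# One tag: hand-down baseline vs. a reading with the finger in front of it.
s_direct = rfid_complex_signal(-60.0, 1.0)
s_total = rfid_complex_signal(-58.0, 1.3)
p_a = reflected_power(reflected_component(s_total, s_direct))
print(p_a)
```

Working with complex values rather than with RSSI and phase separately is what makes the subtraction meaningful: the direct and reflected paths add as phasors, not as scalar powers.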


4 System Design

4.1 Problem Definition

In this study, the received signal from a tag is modeled as a time series. Specifically, we denote the signal received from tag i as $x_i(t)$, the data acquired by the i-th tag at moment t. Consequently, the temporal trace data can be represented as:

$$X = \begin{bmatrix} x_1(1) & x_1(2) & \cdots & x_1(T) \\ x_2(1) & x_2(2) & \cdots & x_2(T) \\ \vdots & \vdots & \ddots & \vdots \\ x_i(1) & x_i(2) & \cdots & x_i(T) \end{bmatrix}, \qquad (2)$$

where i represents the number of tags and T is the time interval between the hand-raising and hand-lowering actions. The dataset provided as input to the ARFG system is denoted as $D = \{(X_n, l_n)\}_{n=1}^{K}$, where K is the number of data samples and $X_n$ is the raw data. Each raw sample $X_n$ is associated with a label $l_n \in \{0, 1, \dots, 7\}$, corresponding to the 8 traces ("a", "b", "c", "d", "e", "f", "a_l", "a_r") recognized in the experiments of this study; "a_l" and "a_r" denote arrows pointing left and right, respectively. The selected letters and shapes represent the finger motions most frequently used in real-world scenarios, and the significant differences between the shapes and letters support the reliability and validity of the results. In the feature extraction module, the preprocessed and segmented temporal data $X_{\text{after}}$ is transformed into images $T_p$ with distinct trace features, which are then processed by the trace recognition module. The ultimate objective of this study is to accurately predict the labels of unlabeled samples $X_i$.

4.2 System Overview

The ARFG framework, shown in Fig. 1, consists of two key modules: feature extraction and trace recognition. The former preprocesses and visualizes the temporal data, while the latter uses DS-GAN to recognize finger traces. To augment the training dataset, a 100-dimensional random vector is fed into DS-GAN in the trace recognition module.

Feature Extraction. After the RFID signal containing RSSI and phase information is received from each tag in the tag array, a series of operations extracts the feature image of the finger trace from these signals. The procedure consists of the following steps.

Normal Data Preprocessing. The phase angle can jump by 2π at θ = π, as depicted in Fig. 2(a).
The large and erratic fluctuations in the phase value before unwrapping make it unsuitable for use as experimental data. Hence, we adopt the one-dimensional phase unwrapping method [18] to correct the phase. Furthermore, we handle misreads, lost reads, and the random selection of reply time slots with linear interpolation. Finally, a Savitzky-Golay filter is applied to suppress noise and smooth the RFID signal, as a jittery and noisy signal can compromise subsequent feature extraction.

Fig. 1. Framework of ARFG. The Feature Extraction module converts the original temporal signals into feature images; the Trace Recognition module then predicts the trace labels from the extracted feature images.

Fig. 2. The results of modules in Feature Extraction: (a) original phase (left) and phase after unwrapping (right); (b) trace segmentation.

Trace Segmentation. Accurately identifying the start and end points of a gesture is essential for extracting finger-movement features from RFID signals while minimizing noise and preserving key information. We modify the maximum-variance-stream segmentation algorithm; our approach comprises three steps:

i) Signal Normalization. The preprocessed phase and RSSI data are normalized so that the signal variations of all tags in the tag array have the same magnitude.
ii) Weight Distribution. A sliding window, adjusted dynamically for each tag, is used to compute separate variance streams for phase and RSSI, and a suitable weight is selected to combine the two streams.
iii) Variance Maximization. The maximum variance stream is formed by taking the maximum variance in each sliding window, and its first and last peaks determine the beginning and end of the gesture.

An example of the finger movement after trace segmentation is shown in Fig. 2(b).

Data Visualization. Viewing the tag array as a 6 × 6 grid, the Pearson correlation coefficient [19] can be used to determine the likelihood that the finger is above the i-th tag at time t:

$$L(t, i) = \frac{E\left[P_t(i)\,P_a(t)\right] - \mu_{P_t(i)}\,\mu_{P_a(t)}}{\sqrt{\left(E\left[P_t(i)^2\right] - \mu_{P_t(i)}^2\right)\left(E\left[P_a(t)^2\right] - \mu_{P_a(t)}^2\right)}}, \qquad (3)$$

where $d_{HT_i}$ denotes the distance between the finger and the tag, N is the number of tags in the tag array, and C is a constant set to 1 in this paper. μ denotes the mean value of $P_t(i)$ and $P_a$. $P_t(i) = C \times 1/d_{HT_i}^4$ is the theoretical reflected energy distribution matrix of the whole array corresponding to the i-th tag when the finger is directly above it. $P_a(t)$ is the actual reflected energy distribution matrix obtained from the phase and RSSI information of each tag at time t; for each tag at each moment, the actual reflected energy is computed as $P_a = |S_{\text{reflect}}|^2$. The likelihoods L(t, i) of the finger being located above each tag at moment t form a 6 × 6 likelihood estimation matrix that reflects the probability of the finger being at each position on the tag array. The feature map is created from the smoothed finger trace, which consists of the position with the highest location probability at each moment.

Trace Recognition. This module consists of DS-GAN, our proposed model, which has three key components: the generator, the discriminator, and the classifier, as depicted in Fig. 3.
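Returning briefly to the Feature Extraction module, its main steps can be sketched end to end in NumPy. This is an illustration under stated assumptions, not the authors' code. `np.unwrap` stands in for the one-dimensional unwrapping method of [18]. The Savitzky-Golay smoother is a hand-written least-squares version with edge padding. The segmentation uses a fixed window of 10 samples, weight 0.5, and a peak threshold of 30% of the stream maximum, whereas the paper adapts the window per tag and does not give these values. The 5 cm tag pitch and 12 cm finger height used for $d_{HT_i}$ are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


# --- Normal data preprocessing ---------------------------------------------
def savgol_smooth(x, window=7, polyorder=2):
    """Savitzky-Golay smoothing: least-squares local polynomial, centre value."""
    half = window // 2
    offsets = np.arange(-half, half + 1)
    design = np.vander(offsets, polyorder + 1, increasing=True)  # 1, t, t^2
    weights = np.linalg.pinv(design)[0]             # fitted value at t = 0
    padded = np.pad(np.asarray(x, dtype=float), half, mode="edge")
    return np.convolve(padded, weights[::-1], mode="valid")


def preprocess_phase(phase):
    """Unwrap 2*pi jumps, fill lost reads (NaN) linearly, then smooth."""
    phase = np.asarray(phase, dtype=float)
    idx = np.arange(phase.size)
    valid = ~np.isnan(phase)
    unwrapped = np.unwrap(phase[valid])
    filled = np.interp(idx, idx[valid], unwrapped)
    return savgol_smooth(filled)


# --- Trace segmentation (maximum variance stream) --------------------------
def segment_trace(phase, rssi, win=10, weight=0.5, peak_ratio=0.3):
    """Return (start, end) sample indices from (num_tags, T) phase/RSSI arrays."""
    def normalize(x):
        lo = x.min(axis=1, keepdims=True)
        span = x.max(axis=1, keepdims=True) - lo
        return (x - lo) / np.where(span > 0, span, 1.0)

    var_ph = sliding_window_view(normalize(phase), win, axis=1).var(axis=-1)
    var_rs = sliding_window_view(normalize(rssi), win, axis=1).var(axis=-1)
    stream = (weight * var_ph + (1 - weight) * var_rs).max(axis=0)
    if stream.max() <= 0:
        return None
    thr = peak_ratio * stream.max()
    peaks = [k for k in range(1, stream.size - 1)
             if stream[k] >= thr
             and stream[k] >= stream[k - 1] and stream[k] >= stream[k + 1]]
    if not peaks:
        return None
    # window k covers samples k .. k + win - 1
    return peaks[0], peaks[-1] + win - 1


# --- Data visualization (Eq. 3) --------------------------------------------
def template_energy(shape=(6, 6), pitch=0.05, height=0.12, c=1.0):
    """Row i is P_t(i) = C / d^4 over all tags, with the finger above tag i."""
    rows, cols = np.divmod(np.arange(shape[0] * shape[1]), shape[1])
    planar = (rows[:, None] - rows[None, :]) ** 2 + (cols[:, None] - cols[None, :]) ** 2
    d = np.sqrt(pitch ** 2 * planar + height ** 2)
    return c / d ** 4


def finger_trace(p_actual_series, templates, shape=(6, 6)):
    """Most likely (row, col) at each moment; p_actual_series is (T, num_tags)."""
    trace = []
    for p_a in p_actual_series:
        likelihood = np.array([np.corrcoef(t, p_a)[0, 1] for t in templates])
        trace.append(np.unravel_index(int(np.argmax(likelihood)), shape))
    return trace
```

Pearson correlation compares the shape of the measured energy pattern with each template rather than its absolute level. That is a plausible reason it suits RSSI-derived energies, whose scale varies with distance to the antenna.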

Fig. 3. The network architecture of DS-GAN. "SN(ConvTrans)" and "SN(Conv2d)" employ spectral normalization (SN) for transposed and strided convolution layers, respectively. "BN" refers to batch normalization and "GAP" refers to global average pooling layers.
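The generator described below uses four kernel-4 transposed convolutions to reach a 32 × 32 map, and the discriminator mirrors it with strided convolutions. The spatial sizes can be checked with the standard output-size formulas. The sketch assumes DCGAN-style settings that the paper does not state: the 100-dimensional vector is treated as a 100 × 1 × 1 map, the first generator layer uses stride 1 and padding 0, the others use stride 2 and padding 1, and the discriminator reverses this order.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Output length of a transposed convolution (no dilation or output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


def conv_out(size, kernel=4, stride=2, padding=1):
    """Output length of an ordinary (strided) convolution."""
    return (size + 2 * padding - kernel) // stride + 1


# Generator: 1 x 1 latent map up to a 32 x 32 trace map (assumed strides/paddings).
gen_sizes = [1]
for stride, padding in [(1, 0), (2, 1), (2, 1), (2, 1)]:
    gen_sizes.append(conv_transpose_out(gen_sizes[-1], 4, stride, padding))

# Discriminator: 32 x 32 image down to a single real/fake score.
disc_sizes = [32]
for stride, padding in [(2, 1), (2, 1), (2, 1), (1, 0)]:
    disc_sizes.append(conv_out(disc_sizes[-1], 4, stride, padding))

print(gen_sizes)   # [1, 4, 8, 16, 32]
print(disc_sizes)  # [32, 16, 8, 4, 1]
```

Under these assumptions, each stride-2 layer exactly doubles (or halves) the spatial size. This is why four layers with kernel size 4 suffice to reach the 3 × 32 × 32 output.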

Generator. The generator transforms a 100-dimensional random vector through four transposed convolution layers with a kernel size of 4, progressively reshaping the input into a virtual RFID trace map of size 3 × 32 × 32. The images synthesized by the generator are then evaluated by the discriminator.

Discriminator. The architecture of the DS-GAN discriminator resembles that of the generator, except that transposed convolutions are replaced with strided convolutions. The activation function also changes from ReLU to LeakyReLU, and the output layer uses the sigmoid activation function. The primary objective of the discriminator is to determine whether an image is real or generated.

Classifier. The DS-GAN classifier employs the Deep Residual Shrinkage Network with channel-shared thresholds (DRSN-CS). The DRSN is a variant of the basic residual network that embeds a sub-network for adaptively setting thresholds. To extract features from the RFID finger-trace feature map, the network first computes the absolute values of all features, followed by a global pooling layer. The pooled RFID gesture feature map is then processed by two paths, whose outputs are multiplied to obtain the threshold:

$$\tau = \alpha \cdot A, \qquad (4)$$

where α is the normalized output of the sigmoid function and A is the result of the first path, averaged over the RFID gesture feature map. To reduce the impact of noise in generated images, we apply soft thresholding within the classifier and assign labels to generated images using a pseudo-labeling technique based on the classifier's predictions. Specifically, we retain a generated image and its label only when the classifier predicts the sample's label with high confidence, i.e., when the predicted probability exceeds a predetermined threshold. During training, the generator produces new images for each mini-batch, which are fed directly to the classifier. The classifier loss is defined as:

$$L_c(T_p, l_t, v) = \mathrm{CE}\big(C(T_p), l_t\big) + \lambda\,\mathrm{CE}\big(C(G(v)), \arg\max C(G(v))\big), \quad \text{if } \max C(G(v)) > \xi, \qquad (5)$$

where λ is the adversarial weight, a parameter that controls the relative importance of the generated data compared to the real data; CE is the cross-entropy loss; C is the classifier; G is the generator; v is the random input vector; $l_t$ denotes the true labels; and ξ is the pseudo-label threshold [20]. When the confidence condition is not met, the second term is omitted.
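The two classifier-side mechanisms can be illustrated with a small NumPy sketch. The first is channel-shared soft thresholding with τ = α·A (Eq. 4), shrinking each value as sign(x)·max(|x| − τ, 0), the soft-threshold function used in deep residual shrinkage networks [10]. The second is the confidence-filtered pseudo-label term of Eq. (5). In this sketch, α is passed in directly instead of being produced by the fully connected path and sigmoid, and the losses are averaged over the batch. Both are simplifying assumptions. λ = 0.1 and ξ = 0.7 follow the values reported in Sect. 5.1.

```python
import numpy as np


def soft_threshold(features, alpha):
    """Channel-shared soft thresholding for one feature map.

    A is the mean absolute activation (global-pooled |x|) and tau = alpha * A.
    Values with |x| <= tau become 0; the rest shrink toward 0 by tau.
    """
    features = np.asarray(features, dtype=float)
    tau = alpha * np.abs(features).mean()
    return np.sign(features) * np.maximum(np.abs(features) - tau, 0.0)


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def cross_entropy(probs, labels):
    """Mean negative log-likelihood of integer labels."""
    eps = 1e-12
    return -np.mean(np.log(probs[np.arange(len(labels)), labels] + eps))


def classifier_loss(real_logits, real_labels, fake_logits, lam=0.1, xi=0.7):
    """Eq. (5): supervised CE plus lambda * CE on confident pseudo-labels."""
    loss = cross_entropy(softmax(real_logits), real_labels)
    fake_probs = softmax(fake_logits)
    keep = fake_probs.max(axis=1) > xi               # confidence filter
    if keep.any():
        pseudo = fake_probs[keep].argmax(axis=1)     # pseudo-labels
        loss += lam * cross_entropy(fake_probs[keep], pseudo)
    return loss
```

Filtering by the maximum softmax probability means that generated images the classifier is unsure about contribute nothing to the loss. This is the mechanism meant to keep noisy synthesized samples from pulling the classifier in the wrong direction.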

5 Performance Evaluation

5.1 Experimental Setup

The experimental setup consists of a 6 × 6 array of AZ-9629 RFID tags and an Impinj Speedway R420 RFID reader with an S9028PCL directional antenna. The tag array was mounted on one side of a white plastic box; the RFID antenna, perpendicular to the ground, was placed on the opposite side, approximately 50 cm from the tag array. The user faces the tag array and moves a finger 10–15 cm in front of it. To confirm the viability of ARFG, experiments were conducted in two environments: Env_A (an empty office with little interference) and Env_B (a laboratory with more interference). For DS-GAN, learning rates of 0.0002, 0.0004, and 0.0001 are used for the classifier, discriminator, and generator, respectively; Adam is used as the optimizer, λ is set to 0.1, and ξ to 0.7. A total of 10 participants (6 men and 4 women) took part across the two experimental settings. Each participant wrote the 6 letters and 2 shapes ten times each, half with fast and half with slow motions. A total of 800 samples were collected, 100 for each gesture. In our studies, 20% of them served as the training set and the remaining 80% as the test set. Classification accuracy and macro F1-score are used to evaluate the performance of ARFG, and we compare ARFG with four recognition methods: RF-Finger [7], Random Forest [8], ResNet, and EC-GAN [9].

5.2 Experimental Results

The test results are presented in Table 1. Compared with RF-Finger, Random Forest, ResNet, and EC-GAN, ARFG achieves superior recognition accuracy. Across the 8 finger traces, ARFG achieves an overall accuracy of 94.69%, with 4 traces reaching 95% or higher, notably "d" and "a_r" at 97.5%. These results demonstrate that even with limited samples, ARFG provides reliable results for a variety of finger motions. They also indicate that ARFG (i) uses DS-GAN for data augmentation to overcome the low performance of Random Forest and ResNet on small-sample data, (ii) overcomes RF-Finger's inability to recognize fast gestures through an effective gesture segmentation algorithm, and (iii) addresses the poor performance of EC-GAN on high-noise virtual images through soft thresholding.

Table 1. Quantitative comparison of ARFG and other approaches on different finger traces (accuracy, %).

| Methods | Overall | a | b | c | d | e | f | a_l | a_r |
| ARFG | 94.69 | 93.75 | 92.50 | 91.25 | 97.50 | 93.75 | 95.00 | 96.25 | 97.50 |
| EC-GAN | 91.72 | 92.50 | 93.75 | 93.75 | 90.00 | 92.50 | 88.75 | 92.50 | 90.00 |
| RF-Finger | 89.06 | 87.50 | 93.75 | 91.25 | 90.00 | 90.00 | 90.00 | 88.75 | 81.25 |
| ResNet | 80.63 | 85.00 | 83.75 | 85.00 | 76.25 | 80.00 | 80.00 | 77.50 | 77.50 |
| Random Forest | 75.93 | 76.25 | 73.75 | 73.75 | 78.65 | 72.50 | 78.75 | 78.75 | 75.00 |
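The two evaluation metrics can be sketched as follows. Accuracy is the fraction of correct predictions. Macro F1 is the unweighted mean of the per-class F1 scores, so each of the 8 traces counts equally; scikit-learn's `f1_score(y_true, y_pred, average="macro")` should give the same value. The snippet also checks Table 1 for consistency. If the 80% test split is stratified over the 100 samples per gesture, every trace has 80 test samples, and the overall accuracy then equals the mean of the per-trace accuracies.

```python
import numpy as np

TRACES = ["a", "b", "c", "d", "e", "f", "a_l", "a_r"]   # label n <-> TRACES[n]


def accuracy(y_true, y_pred):
    """Fraction of test samples whose predicted label is correct."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))


def macro_f1(y_true, y_pred):
    """Unweighted mean over classes of F1 = 2TP / (2TP + FP + FN)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    scores = []
    for c in np.union1d(y_true, y_pred):     # classes seen in either array
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        scores.append(2 * tp / (2 * tp + fp + fn))
    return float(np.mean(scores))


# Balanced test set: overall accuracy is the mean of the per-trace accuracies.
ARFG_PER_TRACE = [93.75, 92.50, 91.25, 97.50, 93.75, 95.00, 96.25, 97.50]  # Table 1
print(sum(ARFG_PER_TRACE) / len(ARFG_PER_TRACE))  # 94.6875, reported as 94.69
```

With 80 test samples per trace, per-trace accuracies move in steps of 1.25 percentage points. This matches the granularity of most entries in Table 1.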

Evaluation of Different Users. We evaluate the robustness of ARFG by comparing its recognition accuracy across the 10 participants; the results are shown in Fig. 4(a). Finger-trace recognition accuracy exceeds 90% for all ten participants, with user 7 achieving the highest accuracy (98.44%) and user 10 (a woman) the lowest (90.63%). We hypothesize that the more slender fingers of female participants cause an unsteady reflection signal, which lowers recognition accuracy. Overall, ARFG performs well at identifying finger movements made by different people.

Evaluation of Different Finger Speeds. We designed experiments to test the recognition accuracy of ARFG at different finger speeds. Of the 800 samples covering 8 finger traces, 50 fast and 50 slow samples were collected for each trace; the results are displayed in Fig. 4(b). "c", "e", and "a_l" reach an accuracy of 95% at fast speed, and "a_l" reaches 100% at slow speed. The results demonstrate that finger speed has minimal impact on ARFG, as the accuracy for slow and fast traces is consistently above 92.5% and 87.5%, respectively.

Evaluation of Different Environments. We assess the cross-environment performance of ARFG by training in one environment and testing in the other. For this purpose, we collected four datasets from the two environments: TRA and TRB are the training sets, and TEA and TEB the test sets, collected in Env_A and Env_B, respectively. The results are shown in Fig. 4(c). ARFG trained on TRA and tested on TEB achieves an accuracy of 91.71%, and ARFG trained on TRB and tested on TEA achieves 87.97%. This shows that ARFG remains robust across different environments.

Fig. 4. Evaluation results: (a) accuracy with different users; (b) accuracy with different finger speeds; (c) accuracy with different environments.

5.3 Ablation Study

We compared the performance of ARFG with three modified versions of the model: one without self-attention and spectral normalization, one without the soft thresholding module, and one without the data visualization module. The results are presented in Table 2. Our experiments reveal that the self-attention layer placed before the output convolutional layer, together with spectral normalization applied to the first three convolutional layers, captures associations among higher-level features and across larger feature regions, improving accuracy by 1.87%. This overcomes the limitation of traditional convolutional layers in modeling features over a wide range. In contrast, removing the soft thresholding module from ARFG decreases trace recognition accuracy by 1.04%, showing that its denoising ability substantially helps the classifier cope with generated data. Through Eq. (4), a positive and appropriately sized threshold lets ARFG recognize traces accurately by setting the features corresponding to noise to 0 while retaining the features related to traces. Moreover, visualizing the traces plays a critical role in ARFG's performance: accuracy drops to 81.25% when the traces are not visualized. Overall, these results demonstrate the effectiveness of each investigated module in ARFG.

Table 2. Accuracy and macro F1-score of ARFG with different modules removed.

| Removed module | Accuracy (%) | Macro F1 (%) |
| None (full ARFG) | 94.69 | 94.52 |
| Self-Attention and Spectral Normalization | 92.97 | 92.91 |
| Soft Thresholding | 93.75 | 93.68 |
| Data Visualization | 81.25 | 81.01 |

Table 3. Accuracy and macro F1-score of ARFG with different dataset sizes.

| Dataset size (%) | Accuracy (%) | Macro F1 (%) |
| 5 | 86.45 | 86.23 |
| 10 | 89.58 | 89.44 |
| 15 | 92.21 | 92.19 |
| 20 | 94.69 | 94.52 |

Dataset Size. To investigate the minimum amount of data ARFG needs to achieve satisfactory performance, we conducted experiments using 5%, 10%, 15%, and 20% of the samples for training. The results, presented in Table 3, indicate that ARFG still achieves 86.45% recognition accuracy with only 5% of the samples used for training. Notably, accuracy and macro F1-score remain closely matched at every dataset size, highlighting ARFG's advantage in the finger-trace recognition scenario.

6 Conclusion

In this paper, we introduce ARFG, a GAN-based attach-free RFID finger-tracking system for few-sample settings, which achieves high recognition accuracy when trained on only a small number of finger-trace feature maps. It addresses the data dependence and time-consuming data acquisition of RFID-based attach-free recognition solutions. ARFG not only leverages the feature extraction module to convert raw RFID signals into finger feature maps that can be clearly identified, but also uses DS-GAN for trace recognition and generates virtual images to supplement real RFID data that are difficult to acquire. Furthermore, it embeds soft thresholding into the classifier to improve its classification ability on noisy images. The experimental results show that ARFG achieves 94.69% accuracy on finger-tracking tasks in various scenarios. In the future, we plan to explore the potential of Transformers for extracting more information from temporal RFID data.

Acknowledgements. This paper is supported by the 2021 Fujian Foreign Cooperation Project (No. 2021I0001): Research on Human Behavior Recognition Based on RFID and Deep Learning; the 2021 Project of Xiamen University (No. 20213160A0474): Zijin International Digital Operation Platform Research and Consulting; the State Key Laboratory of Process Automation in Mining & Metallurgy and Beijing Key Laboratory of Process Automation in Mining & Metallurgy (No. BGRIMMKZSKL-2022-14): Research and application of mine operator positioning based on RFID and deep learning; and the National Key R&D Program of China, Sub-project of Major Natural Disaster Monitoring, Early Warning and Prevention (No. 2020YFC1522604).

References

1. Chang, X., Ma, Z., Lin, M., Yang, Y., Hauptmann, A.G.: Feature interaction augmented sparse learning for fast Kinect motion detection. IEEE Trans. Image Process. 26(8), 3911–3920 (2017)
2. Gupta, H.P., Chudgar, H.S., Mukherjee, S., Dutta, T., Sharma, K.: A continuous hand gestures recognition technique for human-machine interaction using accelerometer and gyroscope sensors. IEEE Sens. J. 16(16), 6425–6432 (2016)
3. Ma, Y., Hui, X., Kan, E.C.: 3D real-time indoor localization via broadband nonlinear backscatter in passive devices with centimeter precision. In: Proceedings of the 22nd Annual International Conference on Mobile Computing and Networking, pp. 216–229 (2016)
4. Shangguan, L., Jamieson, K.: Leveraging electromagnetic polarization in a two-antenna whiteboard in the air. In: Proceedings of the 12th International Conference on Emerging Networking EXperiments and Technologies, pp. 443–456 (2016)
5. Ding, H., et al.: A platform for free-weight exercise monitoring with passive tags. IEEE Trans. Mob. Comput. 16(12), 3279–3293 (2017)
6. Zou, Y., Xiao, J., Han, J., Wu, K., Li, Y., Ni, L.M.: GRfid: a device-free RFID-based gesture recognition system. IEEE Trans. Mob. Comput. 16(2), 381–393 (2016)
7. Wang, C., et al.: Multi-touch in the air: device-free finger tracking and gesture recognition via COTS RFID. In: IEEE INFOCOM 2018 - IEEE Conference on Computer Communications, pp. 1691–1699. IEEE (2018)
8. Zhang, S., et al.: Real-time and accurate gesture recognition with commercial RFID devices. IEEE Trans. Mob. Comput. (2022)
9. Haque, A.: EC-GAN: low-sample classification using semi-supervised algorithms and GANs (student abstract). In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, pp. 15797–15798 (2021)
10. Zhao, M., Zhong, S., Fu, X., Tang, B., Pecht, M.: Deep residual shrinkage networks for fault diagnosis. IEEE Trans. Industr. Inf. 16(7), 4681–4690 (2019)
11. Zhang, H., Goodfellow, I., Metaxas, D., Odena, A.: Self-attention generative adversarial networks. In: International Conference on Machine Learning, pp. 7354–7363. PMLR (2019)
12. Ren, Z., Yuan, J., Meng, J., Zhang, Z.: Robust part-based hand gesture recognition using Kinect sensor. IEEE Trans. Multimedia 15(5), 1110–1120 (2013)
13. Xie, L., Wang, C., Liu, A.X., Sun, J., Lu, S.: Multi-touch in the air: concurrent micromovement recognition using RF signals. IEEE/ACM Trans. Networking 26(1), 231–244 (2017)
14. Ha, U., Leng, J., Khaddaj, A., Adib, F.: Food and liquid sensing in practical environments using RFIDs. In: 17th USENIX Symposium on Networked Systems Design and Implementation (NSDI 20), pp. 1083–1100 (2020)
15. Ma, Z., et al.: RF-Siamese: approaching accurate RFID gesture recognition with one sample. IEEE Trans. Mob. Comput. (2022)
16. Gong, T., Kim, Y., Shin, J., Lee, S.J.: MetaSense: few-shot adaptation to untrained conditions in deep mobile sensing. In: Proceedings of the 17th Conference on Embedded Networked Sensor Systems, pp. 110–123 (2019)
17. Ding, S., Chen, Z., Zheng, T., Luo, J.: RF-net: a unified meta-learning framework for RF-enabled one-shot human activity recognition. In: Proceedings of the 18th Conference on Embedded Networked Sensor Systems, pp. 517–530 (2020)
18. Itoh, K.: Analysis of the phase unwrapping algorithm. Appl. Opt. 21(14), 2470 (1982)
19. Pearson, K.: VII. Note on regression and inheritance in the case of two parents. Proc. Royal Soc. London 58(347–352), 240–242 (1895)
20. Lee, D.H., et al.: Pseudo-label: the simple and efficient semi-supervised learning method for deep neural networks. In: Workshop on Challenges in Representation Learning, ICML, vol. 3, p. 896 (2013)

GM(1,1) Model Based on Parallel Quantum Whale Algorithm and Its Application

Huajuan Huang¹, Shixian Huang², Xiuxi Wei¹ (✉), and Yongquan Zhou¹

¹ College of Artificial Intelligence, Guangxi Minzu University, Nanning 530006, China

² College of Electronic Information, Guangxi Minzu University, Nanning 530006, China

Abstract. To address the wall thinning and perforation leakage caused by corrosion in the atmospheric distillation unit of a refinery, a GM(1,1) model is used to predict the corrosion trend at the elbow of the inlet distribution pipe of the overhead heat exchanger of the atmospheric distillation tower. First, an exponential function is introduced to transform the original data of the traditional GM(1,1) model, improving the smoothness of the original sequence. Then, the background weights of the improved GM(1,1) model (EFGM(1,1)) are optimized by the proposed parallel quantum whale optimization algorithm (QWOA), which supplies the model with optimal background weights and thereby improves its prediction accuracy. Finally, the QWOA-EFGM(1,1) model is established and used to predict the corrosion at the elbow of the inlet distributor of the overhead heat exchanger of the atmospheric distillation column. The prediction results on this real case show that the proposed QWOA-EFGM(1,1) model has higher prediction accuracy and better stability than the traditional GM(1,1) model and the EFGM(1,1) model.

Keywords: Whale optimization algorithm · Parallel quantum theory · GM(1,1) model · Distillation apparatus · Corrosion prediction

1 Introduction

In the atmospheric distillation system of a refinery, equipment corrosion is common. The impurities that corrode the equipment mainly include sulfur compounds, inorganic salts, naphthenic acids and nitrogen compounds. Corrosion directly threatens the long-term, stable, full-load and high-quality operation of the production unit and lowers the plant's operating rate. In severe cases, it causes equipment leakage or failure, resulting in major accidents such as fire and explosion. Detecting and evaluating corrosion conditions is therefore essential for safety and operating efficiency in refining. In production, accurate and reliable corrosion monitoring information requires several monitoring methods used together, yet applying high-cost corrosion monitoring methods too frequently is very expensive. Using a grey GM(1,1) prediction model to analyze corrosion in the atmospheric distillation unit can therefore reduce monitoring frequency and cost.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 785–797, 2023. https://doi.org/10.1007/978-981-99-4742-3_65


The grey system theory was first proposed by the Chinese scholar Deng Ju-Long [1]. It addresses uncertain systems characterized by small samples and poor information. The grey prediction model is an important part of grey system theory; because it places low requirements on modeling data, it has been applied in many fields, such as society, economy, energy, industry and agriculture. As research has deepened, many scholars have found that the traditional grey GM(1,1) prediction model has defects that lead to low prediction accuracy, and much work has gone into optimizing it.

Some scholars improve the prediction accuracy of the GM(1,1) model by improving the initial value. For example, Gao et al. [2] transformed the original sequence by translation and a sine function and optimized the initial value to obtain a new time response function, which they applied to corrosion prediction for submarine pipelines. Wu et al. [3] used a genetic algorithm to search the initial-value range of the accumulated sequence for the best initial value, which replaced the model's initial condition; examples verified the model's effectiveness. Lu Jie [4] solved the model parameters with a first-order linear difference equation, treated the initial value and background value as variables, and applied the optimized model to forecasting China's annual petroleum consumption. Yang et al. [5] used an improved fireworks optimization algorithm to optimize the weight coefficient of the background value and the initial-value correction term of the grey model, improving both fitting and prediction accuracy.

Other scholars improve the prediction accuracy of the GM(1,1) model by improving the background value. Fan Xinhai et al. [6] determined the weight by automatic optimization and used the weight with the highest prediction accuracy to build the background value; comparative analysis of examples showed good results. Yue Xi et al. [7] used the integral mean value theorem to fit the true background value and obtained improved prediction accuracy on a data set. Yang et al. [8] reconstructed the background value with three parameters, which improved the performance of the grey prediction model.

A third line of work improves the smoothness of the sequence. Li Cuifeng et al. [9] proposed a function transformation to improve sequence smoothness and proved theoretically that it is more effective than the logarithmic and power function methods. Gong Wenwei et al. [10] built a prediction model based on second exponential smoothing to forecast demand for a company's new car model, providing theoretical support for the company's procurement and production. Finally, some scholars optimize the parameters of the GM(1,1) model. Liu et al. [11] used the Weibull cumulative distribution function to construct a new parameter that replaces the traditional one, and effectively predicted renewable energy consumption in Europe and the world. Yu et al. [12] estimated the model parameters with the least absolute deviation method and the simplex method and applied them to satellite clock bias prediction.

This paper proposes a GM(1,1) prediction model based on the parallel quantum whale algorithm (QWOA-EFGM(1,1)). In the first stage, an exponential function is used to preprocess the original data of the traditional GM(1,1) model. In the second stage, the parallel quantum whale optimization algorithm (QWOA) optimizes the background weight of the EFGM(1,1) model. Finally, the newly established QWOA-EFGM(1,1) model is used to predict the corrosion trend at the elbow of the inlet distribution pipe of the overhead heat exchanger of the atmospheric distillation tower, and it is compared with the traditional GM(1,1) model and the EFGM(1,1) model. Comparative analysis of the prediction results of each model verifies the validity, accuracy and stability of the proposed model.

2 GM(1,1) Prediction Model

To build the GM(1,1) model, let the original sequence be

$$X^{(0)} = \left(x^{(0)}(1), x^{(0)}(2), \ldots, x^{(0)}(n)\right) \tag{1}$$

where each $x^{(0)}(k)$ is an original modeling datum and must be greater than 0, and n is the number of modeling data.

2.1 Data Testing and Processing

To ensure that the GM(1,1) model is feasible, the original data must be tested before modeling. Compute the level ratio of the sequence:

$$\delta(k) = \frac{x^{(0)}(k-1)}{x^{(0)}(k)}, \quad k = 2, 3, \ldots, n \tag{2}$$

If every $\delta(k) \in \left(e^{-2/(n+1)}, e^{2/(n+1)}\right)$, the sequence passes the level ratio test and the GM(1,1) model can be built and used for forecasting. Otherwise, the data must first be transformed, for example by the translation transform

$$y^{(0)}(k) = x^{(0)}(k) + c, \quad k = 1, 2, \ldots, n \tag{3}$$

where c is a constant chosen so that the transformed sequence passes the test.
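As a minimal sketch of the level ratio test of Eq. (2) and the translation transform of Eq. (3), the Python below checks a sequence and, if it fails, shifts it by a constant c. The stepwise search for c (the `step` and `max_tries` parameters) is an assumption for illustration; the text does not say how c is chosen. The sample data are the 12 monthly mean thicknesses of Table 1.

```python
import math


def level_ratios(x0):
    """Level ratios delta(k) = x0(k-1) / x0(k) for k = 2..n, as in Eq. (2)."""
    return [x0[k - 1] / x0[k] for k in range(1, len(x0))]


def passes_level_ratio_test(x0):
    """True if every level ratio lies in the open interval (e^{-2/(n+1)}, e^{2/(n+1)})."""
    n = len(x0)
    lower, upper = math.exp(-2 / (n + 1)), math.exp(2 / (n + 1))
    return all(lower < d < upper for d in level_ratios(x0))


def shift_until_admissible(x0, step=1.0, max_tries=1000):
    """Translation transform y0(k) = x0(k) + c of Eq. (3).

    Assumption: c is increased in fixed steps until the test passes. Adding a
    positive constant pulls every ratio toward 1, so a large enough c always
    exists for positive data.
    """
    c = 0.0
    for _ in range(max_tries):
        y0 = [v + c for v in x0]
        if passes_level_ratio_test(y0):
            return y0, c
        c += step
    raise ValueError("no admissible shift found within max_tries")


# Monthly mean wall thickness (mm) for months 1-12, from Table 1.
MEAN_THICKNESS = [8.87, 8.78, 8.47, 8.45, 8.115, 7.79,
                  7.73, 7.33, 7.285, 7.01, 6.855, 6.705]

print(passes_level_ratio_test(MEAN_THICKNESS))  # True
```

For the Table 1 means, every level ratio lies between about 1.002 and 1.055, well inside the admissible interval of roughly (0.857, 1.166) for n = 12, so no translation is needed.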

2.2 Testing and Processing of Data

Step 1: Accumulate the original sequence $X^{(0)}$ once (1-AGO) to generate the sequence $X^{(1)}$:

$$X^{(1)} = \left(x^{(1)}(1), x^{(1)}(2), \ldots, x^{(1)}(n)\right) \tag{4}$$

$$x^{(1)}(k) = \sum_{i=1}^{k} x^{(0)}(i), \quad k = 1, 2, \ldots, n \tag{5}$$

Step 2: Generate the immediate-mean (background) sequence of $X^{(1)}$:

$$Z^{(1)} = \left(z^{(1)}(2), z^{(1)}(3), \ldots, z^{(1)}(n)\right) \tag{6}$$


$$z^{(1)}(k) = \beta x^{(1)}(k) + (1-\beta)\, x^{(1)}(k-1), \quad k = 2, 3, \ldots, n \tag{7}$$

Here $z^{(1)}(k)$ is the background value of the model and $\beta$ is the background weighting factor, usually set to 0.5.

Step 3: Establish the single-variable, first-order grey differential equation that serves as the prediction model, namely the GM(1,1) model:

$$x^{(0)}(k) + a z^{(1)}(k) = b, \quad k = 2, 3, \ldots, n \tag{8}$$

The corresponding whitening differential equation of GM(1,1) is

$$\frac{dx^{(1)}}{dt} + a x^{(1)}(t) = b \tag{9}$$

where a and b are the parameters to be solved: a is the development coefficient, which mainly controls the development trend of $X^{(0)}$ and $X^{(1)}$, and b is the grey action quantity, which reflects the changing relationship among the data.

Step 4: Estimate a and b by the least squares method:

$$(a, b)^{T} = \left(B^{T}B\right)^{-1} B^{T} Y \tag{10}$$

$$B = \begin{pmatrix} -z^{(1)}(2) & 1 \\ -z^{(1)}(3) & 1 \\ \vdots & \vdots \\ -z^{(1)}(n) & 1 \end{pmatrix}, \qquad Y = \begin{pmatrix} x^{(0)}(2) \\ x^{(0)}(3) \\ \vdots \\ x^{(0)}(n) \end{pmatrix} \tag{11}$$

Step 5: Under the initial condition $x^{(1)}(1) = x^{(0)}(1)$, the solution of the whitening differential equation is

$$x^{(1)}(t) = \left(x^{(0)}(1) - \frac{b}{a}\right) e^{-a(t-1)} + \frac{b}{a} \tag{12}$$

Step 6: Letting t = k + 1, the time response sequence of the grey differential equation is

$$\hat{x}^{(1)}(k+1) = \left(x^{(0)}(1) - \frac{b}{a}\right) e^{-ak} + \frac{b}{a} \tag{13}$$

Step 7: The predicted values of the original sequence are restored by first-order differencing:

$$\hat{x}^{(0)}(k+1) = \hat{x}^{(1)}(k+1) - \hat{x}^{(1)}(k) = \left(1 - e^{a}\right)\left(x^{(0)}(1) - \frac{b}{a}\right) e^{-ak}, \quad k = 1, 2, \ldots, n \tag{14}$$
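The seven modeling steps translate directly into code. The sketch below is a plain-Python illustration (no external libraries), assuming the background weight β is passed in as `beta`; it solves the 2×2 normal equations of Eq. (10) in closed form instead of forming B explicitly. The helper names `gm11_fit` and `gm11_predict` are hypothetical.

```python
import math


# Minimal sketch of GM(1,1); helper names are illustrative.
def gm11_fit(x0, beta=0.5):
    """Estimate (a, b) of GM(1,1) by least squares (Eqs. 4-11)."""
    n = len(x0)
    # Step 1: 1-AGO sequence x1(k) = x0(1) + ... + x0(k), Eq. (5).
    x1, total = [], 0.0
    for v in x0:
        total += v
        x1.append(total)
    # Step 2: background values z1(k), k = 2..n, Eq. (7).
    z1 = [beta * x1[k] + (1 - beta) * x1[k - 1] for k in range(1, n)]
    y = x0[1:]
    # Step 4: with B = [-z1, 1], solve (B^T B)(a, b)^T = B^T Y in closed form.
    m = len(z1)
    s_zz = sum(z * z for z in z1)
    s_z = sum(z1)
    s_zy = sum(z * v for z, v in zip(z1, y))
    s_y = sum(y)
    det = m * s_zz - s_z * s_z
    a = (s_z * s_y - m * s_zy) / det
    b = (s_zz * s_y - s_z * s_zy) / det
    return a, b


def gm11_predict(x0, a, b, horizon=0):
    """Restored values x0_hat(1), ..., x0_hat(n + horizon) via Eqs. (13)-(14)."""
    total = len(x0) + horizon
    c = x0[0] - b / a
    # x1_hat[k] holds x1_hat(k + 1) = (x0(1) - b/a) e^{-ak} + b/a, Eq. (13).
    x1_hat = [c * math.exp(-a * k) + b / a for k in range(total)]
    # Eq. (14): first-order differencing; x0_hat(1) = x0(1) by construction.
    return [x1_hat[0]] + [x1_hat[k] - x1_hat[k - 1] for k in range(1, total)]


MEAN_THICKNESS = [8.87, 8.78, 8.47, 8.45, 8.115, 7.79,
                  7.73, 7.33, 7.285, 7.01, 6.855, 6.705]
a, b = gm11_fit(MEAN_THICKNESS)
forecast = gm11_predict(MEAN_THICKNESS, a, b, horizon=6)[-6:]  # months 13-18
```

With β = 0.5 this is the traditional GM(1,1) model. Because the wall thickness decreases, the fitted development coefficient a is positive and the forecast decays geometrically by the factor $e^{-a}$ per month.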


2.3 Improved GM(1,1) Model

1) Improving the smoothness of the original sequence. When the original sequence changes exponentially or approximately exponentially, the traditional GM(1,1) model yields accurate predictions. For non-smooth sequences, however, it produces large prediction errors, so preprocessing the original data to improve its smoothness is important. This paper improves the smoothness of the original sequence with an exponential function transformation (EFGM(1,1)) as follows.

Step 1: Apply the exponential function $q^{-x}$ (q > 1) to the original sequence $X^{(0)}$ to obtain $E^{(0)}$:

$$E^{(0)} = \left(q^{-x^{(0)}(1)}, q^{-x^{(0)}(2)}, \ldots, q^{-x^{(0)}(n)}\right) \tag{15}$$

Step 2: Model the new sequence $E^{(0)}$ with GM(1,1) to obtain a new time response sequence:

$$\hat{E}^{(1)}(k+1) = \left(E^{(0)}(1) - \frac{b}{a}\right) e^{-ak} + \frac{b}{a} \tag{16}$$

where a and b are now estimated from $E^{(0)}$. The new restored sequence is then

$$\hat{E}^{(0)}(k+1) = \hat{E}^{(1)}(k+1) - \hat{E}^{(1)}(k) \tag{17}$$

Step 3: Recover the predicted values of the original sequence from $\hat{E}^{(0)}$:

$$\hat{x}^{(0)}(k) = -\log_{q}\left(\hat{E}^{(0)}(k)\right), \quad k = 1, 2, \ldots, n \tag{18}$$

In repeated experiments, prediction accuracy was highest with q = 1.1.
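A minimal sketch of the EFGM(1,1) pipeline of Eqs. (15)–(18), assuming the same GM(1,1) estimator as Eqs. (4)–(14) (re-declared compactly here so the block runs on its own) and q = 1.1 as reported above. The function names are illustrative, not taken from the paper.

```python
import math


def gm11(seq, beta, horizon):
    """Fit GM(1,1) to seq and return restored values for len(seq) + horizon points."""
    n = len(seq)
    x1 = [sum(seq[: k + 1]) for k in range(n)]                           # Eq. (5)
    z1 = [beta * x1[k] + (1 - beta) * x1[k - 1] for k in range(1, n)]    # Eq. (7)
    y, m = seq[1:], n - 1
    s_zz, s_z = sum(z * z for z in z1), sum(z1)
    s_zy, s_y = sum(z * v for z, v in zip(z1, y)), sum(y)
    det = m * s_zz - s_z * s_z
    a = (s_z * s_y - m * s_zy) / det                                      # Eq. (10)
    b = (s_zz * s_y - s_z * s_zy) / det
    x1_hat = [(seq[0] - b / a) * math.exp(-a * k) + b / a for k in range(n + horizon)]
    return [x1_hat[0]] + [x1_hat[k] - x1_hat[k - 1] for k in range(1, n + horizon)]


def ef_transform(x0, q=1.1):
    """Eq. (15): E0(k) = q^(-x0(k)), with q > 1."""
    return [q ** (-v) for v in x0]


def ef_inverse(e0, q=1.1):
    """Eq. (18): x0_hat(k) = -log_q(E0_hat(k))."""
    return [-math.log(v, q) for v in e0]


def efgm11_predict(x0, beta=0.5, horizon=0, q=1.1):
    """EFGM(1,1): transform, model with GM(1,1) (Eqs. 16-17), transform back."""
    return ef_inverse(gm11(ef_transform(x0, q), beta, horizon), q)


MEAN_THICKNESS = [8.87, 8.78, 8.47, 8.45, 8.115, 7.79,
                  7.73, 7.33, 7.285, 7.01, 6.855, 6.705]
ef_forecast = efgm11_predict(MEAN_THICKNESS, horizon=6)[-6:]  # months 13-18
```

Since $q^{-x}$ is monotone, the decreasing thickness series becomes an increasing series in (0, 1); GM(1,1) is fitted in that space and the logarithm maps the forecast back to millimetres.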

3 Whale Optimization Algorithm and Quantum Whale Optimization Algorithm

3.1 Whale Optimization Algorithm

1) Encircling prey. In the WOA, whales first recognize the location of the prey and then encircle it. The current best position is assumed to be the target prey, and the other individuals in the group update their positions toward it. This behavior is represented by the following equations:

$$D = \left|C \cdot X^{*}(t) - X(t)\right| \tag{19}$$

$$X(t+1) = X^{*}(t) - A \cdot D \tag{20}$$

where t is the current iteration, $X^{*}(t)$ is the position vector of the best solution obtained so far, X(t) is the position vector of the current whale, and D is the distance between the current whale and the target prey. The coefficient vectors A and C are calculated as follows:

$$A = 2a \cdot rand_1 - a \tag{21}$$

$$C = 2\, rand_2 \tag{22}$$

where $rand_1$ and $rand_2$ are random vectors in [0, 1], and a decreases linearly from 2 to 0 over the course of the iterations:

$$a = 2 - \frac{2t}{Max\_iter} \tag{23}$$

where Max_iter is the maximum number of iterations.

2) Bubble-net attack. In the spiral position update, the distance between the whale and the prey is calculated first, and a spiral equation between the positions of the whale and the prey then mimics the helix-shaped movement of humpback whales:

$$X(t+1) = D' \cdot e^{bl} \cdot \cos(2\pi l) + X^{*}(t) \tag{24}$$

where $D' = |X^{*}(t) - X(t)|$ is the current distance between the whale and the target prey, l is a random number in [−1, 1], and b is a constant that defines the shape of the logarithmic spiral. Whales perform the bubble-net attack while encircling the prey, which is modeled as:

$$X(t+1) = \begin{cases} X^{*}(t) - A \cdot D, & p \ge 0.5 \\ D' \cdot e^{bl} \cdot \cos(2\pi l) + X^{*}(t), & p < 0.5 \end{cases} \tag{25}$$

where p is a random number in [0, 1].

3) Hunting prey. In this stage, humpback whales search for prey around randomly selected positions. The mathematical model is as follows:

$$D_{rand} = \left|C \cdot X_{rand}(t) - X(t)\right| \tag{26}$$

$$X(t+1) = X_{rand}(t) - A \cdot D_{rand} \tag{27}$$

where $X_{rand}$ is the position vector of a randomly chosen whale. In the standard WOA, this random search replaces the encircling update of Eq. (20) when |A| ≥ 1, which pushes whales away from the current best and promotes exploration.

3.2 Parallel Quantum Whale Algorithm

The quantum evolutionary algorithm is a probabilistic evolutionary algorithm that draws on the quantum-computing concepts of superposition, entanglement and correlation to perform parallel search. To improve the convergence accuracy of the traditional whale optimization algorithm, this paper proposes a parallel quantum whale optimization algorithm (QWOA) that combines the quantum evolutionary algorithm with the whale optimization algorithm.
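The three update rules of Section 3.1 can be combined into one iteration as in the sketch below. It is a minimal illustration of the standard WOA, not the authors' QWOA, and it rests on several assumptions: scalar random coefficients A and C per whale, the p ≥ 0.5 / p < 0.5 split of Eq. (25), the standard |A| < 1 test to choose between encircling (Eq. (20)) and random search (Eq. (27)), a spiral constant b = 1, and clipping of positions to the search bounds. `woa_step` and `woa_minimize` are hypothetical names.

```python
import math
import random


def woa_step(positions, best, t, max_iter, rng, spiral_b=1.0):
    """One WOA iteration (Eqs. 19-27); spiral_b = 1 is an assumed constant."""
    a = 2 - 2 * t / max_iter                                   # Eq. (23)
    new_positions = []
    for x in positions:
        A = 2 * a * rng.random() - a                            # Eq. (21), scalar form (assumed)
        C = 2 * rng.random()                                    # Eq. (22), scalar form (assumed)
        if rng.random() >= 0.5:
            # Assumption (standard WOA): encircle the best if |A| < 1, else a random whale.
            target = best if abs(A) < 1 else rng.choice(positions)
            # Eqs. (19)-(20) or (26)-(27): D = |C * target - X|, X' = target - A * D.
            new_x = [tj - A * abs(C * tj - xj) for tj, xj in zip(target, x)]
        else:
            l = rng.uniform(-1, 1)
            factor = math.exp(spiral_b * l) * math.cos(2 * math.pi * l)
            # Eq. (24): spiral around the best whale with D' = |X* - X|.
            new_x = [abs(bj - xj) * factor + bj for bj, xj in zip(best, x)]
        new_positions.append(new_x)
    return new_positions


def woa_minimize(f, dim, lb, ub, n_whales=30, max_iter=50, seed=0):
    """Minimize f over [lb, ub]^dim; returns the best position and best-so-far history."""
    rng = random.Random(seed)
    positions = [[rng.uniform(lb, ub) for _ in range(dim)] for _ in range(n_whales)]
    best = list(min(positions, key=f))
    history = [f(best)]
    for t in range(max_iter):
        positions = woa_step(positions, best, t, max_iter, rng)
        # Assumption: out-of-range coordinates are clipped back to the bounds.
        positions = [[min(max(v, lb), ub) for v in x] for x in positions]
        candidate = min(positions, key=f)
        if f(candidate) < f(best):
            best = list(candidate)
        history.append(f(best))
    return best, history


def sphere(x):
    return sum(v * v for v in x)


best, history = woa_minimize(sphere, dim=3, lb=-5.0, ub=5.0)
```

Keeping the best-so-far solution separately makes the recorded history non-increasing, which is the form in which convergence curves such as Fig. 1 are usually plotted.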


In the parallel quantum whale optimization algorithm, the minimum unit of information is the quantum bit (qubit), whose state is represented as

$$|\psi\rangle = \alpha|0\rangle + \beta|1\rangle \tag{28}$$

where α and β are the probability amplitudes of states 0 and 1, respectively, satisfying $|\alpha|^2 + |\beta|^2 = 1$. Introducing these quantum properties into the whale optimization algorithm makes each individual correspond to multiple pairs of probability amplitudes. The i-th whale of the quantum whale population is encoded as

$$q_i = \left[\begin{array}{ccc|ccc|c|ccc} \alpha_{11} & \cdots & \alpha_{1k} & \alpha_{21} & \cdots & \alpha_{2k} & \cdots & \alpha_{m1} & \cdots & \alpha_{mk} \\ \beta_{11} & \cdots & \beta_{1k} & \beta_{21} & \cdots & \beta_{2k} & \cdots & \beta_{m1} & \cdots & \beta_{mk} \end{array}\right] \tag{29}$$

where i = 1, 2, ..., m; j = 1, 2, ..., k; m is the size of the whale population; and k is the number of qubits. The quantum rotation gate is used to mutate and update each individual of the population as follows:

$$\begin{pmatrix} \alpha_i' \\ \beta_i' \end{pmatrix} = \begin{pmatrix} \cos\theta_i & -\sin\theta_i \\ \sin\theta_i & \cos\theta_i \end{pmatrix} \begin{pmatrix} \alpha_i \\ \beta_i \end{pmatrix} \tag{30}$$

where $(\alpha_i', \beta_i')$ is the updated pair of probability amplitudes of the individual and $\theta_i$ is the rotation angle.
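The qubit encoding, the measurement rule (Step 2 of Section 4) and the rotation gate of Eq. (30) can be sketched as follows. Each qubit is stored as an [alpha, beta] pair. The text does not specify how rotation angles are chosen or how a bit string is decoded into a real-valued background weight, so `decode` and the fixed angle used below are assumptions for illustration.

```python
import math
import random


def init_qubits(n_bits):
    """Equal superposition: alpha = beta = 1/sqrt(2), so |alpha|^2 + |beta|^2 = 1."""
    amp = 1 / math.sqrt(2)
    return [[amp, amp] for _ in range(n_bits)]


def measure(qubits, rng):
    """Collapse each qubit: bit = 1 if rand > alpha^2, otherwise 0."""
    return [1 if rng.random() > alpha ** 2 else 0 for alpha, _ in qubits]


def rotate(qubits, thetas):
    """Apply the rotation gate of Eq. (30) to each (alpha, beta) pair."""
    rotated = []
    for (alpha, beta), theta in zip(qubits, thetas):
        c, s = math.cos(theta), math.sin(theta)
        rotated.append([c * alpha - s * beta, s * alpha + c * beta])
    return rotated


def decode(bits, lower=0.0, upper=1.0):
    """Assumed decoding: map the bit string linearly onto [lower, upper]."""
    value = int("".join(str(b) for b in bits), 2)
    return lower + (upper - lower) * value / (2 ** len(bits) - 1)


rng = random.Random(1)
qubits = init_qubits(10)
bits = measure(qubits, rng)
weight = decode(bits)                      # candidate background weight in [0, 1]
# Assumed fixed rotation angle, for illustration only.
qubits = rotate(qubits, [0.05 * math.pi] * len(qubits))
```

A rotation moves probability between |0⟩ and |1⟩ while preserving $|\alpha|^2 + |\beta|^2 = 1$, which is why Eq. (30) can be applied repeatedly without renormalizing.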

4 QWOA-EFGM(1,1) Model Construction

The background weight β of the EFGM(1,1) model is optimized by the parallel quantum whale optimization algorithm, and the optimized EFGM(1,1) model is then used to predict the corrosion trend at the elbow of the inlet distribution pipe of the overhead heat exchanger of the atmospheric distillation tower. So that the parallel quantum whale optimization algorithm searches for the background weight that minimizes the mean absolute percentage error (MAPE) of the model, the MAPE is used as the objective function:

$$fitness_{MAPE} = \frac{1}{n}\sum_{i=1}^{n}\left|\frac{\hat{x}_i^{(0)} - x_i^{(0)}}{x_i^{(0)}}\right| \times 100\% \tag{31}$$

The main steps of QWOA-EFGM(1,1) model construction are as follows.

Step 1: Initialize the parameters of the whale optimization algorithm, such as the population size (N), the dimension of the search space (dim), the upper and lower bounds of the search range (ub and lb), and the maximum number of iterations (Max_iter).

Step 2: Initialize the quantum whale population, setting the qubit probability amplitudes α and β to $1/\sqrt{2}$ so that $|\alpha|^2 + |\beta|^2 = 1$. Measuring each individual of the quantum whale population yields the binary code $P = \{x_1, x_2, \ldots, x_n\}$, $x_1 = (x_{11}, x_{12}, \ldots, x_{1m})$, of each whale, where each bit $x_{ij}$ is 0 or 1. The measurement process is as follows: generate a random number rand ∈ [0, 1]; if rand > α², then x = 1; otherwise, x = 0.

Step 3: Obtain new positions with the position update formulas of the three stages: encircling prey, bubble-net attack, and hunting prey.

Step 4: Mutate the quantum whale population with the quantum rotation gate to obtain new positions.

Step 5: Merge the whale population and the quantum whale population, sort all whales by fitness, and keep the N best individuals.

Step 6: Terminate when the number of iterations of QWOA reaches the maximum; otherwise, return to Step 2.

After the iterations end, the best whale in the population is assigned to the background weight β of the EFGM(1,1) model, and the optimized EFGM(1,1) model is used to predict the corrosion trend at the elbow of the inlet distribution pipe of the overhead heat exchanger of the atmospheric distillation tower.
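To make the objective of Eq. (31) concrete, the sketch below evaluates the in-sample MAPE of EFGM(1,1) on the 12 modeling months as a function of β and then picks β by an exhaustive grid over [0, 1]. The grid search is a deliberately simple stand-in for QWOA, used only to show how any optimizer plugs into the fitness function; evaluating MAPE on the modeling data rather than a hold-out set, and scoring infeasible β values as infinity, are also assumptions.

```python
import math


def efgm11_fitted(x0, beta, q=1.1):
    """In-sample EFGM(1,1) values: Eq. (15), GM(1,1) on E0 (Eqs. 4-14), Eq. (18)."""
    e0 = [q ** (-v) for v in x0]
    n = len(e0)
    e1 = [sum(e0[: k + 1]) for k in range(n)]
    z1 = [beta * e1[k] + (1 - beta) * e1[k - 1] for k in range(1, n)]
    y, m = e0[1:], n - 1
    s_zz, s_z = sum(z * z for z in z1), sum(z1)
    s_zy, s_y = sum(z * v for z, v in zip(z1, y)), sum(y)
    det = m * s_zz - s_z * s_z
    a = (s_z * s_y - m * s_zy) / det
    b = (s_zz * s_y - s_z * s_zy) / det
    e1_hat = [(e0[0] - b / a) * math.exp(-a * k) + b / a for k in range(n)]
    e0_hat = [e1_hat[0]] + [e1_hat[k] - e1_hat[k - 1] for k in range(1, n)]
    return [-math.log(v, q) for v in e0_hat]


def mape(actual, predicted):
    """Eq. (31): mean absolute percentage error, in percent."""
    return 100.0 / len(actual) * sum(abs((p - x) / x) for x, p in zip(actual, predicted))


def fitness(beta, x0):
    """Objective for the optimizer; infeasible betas score inf (assumed handling)."""
    try:
        return mape(x0, efgm11_fitted(x0, beta))
    except (ValueError, ZeroDivisionError):
        return math.inf


MEAN_THICKNESS = [8.87, 8.78, 8.47, 8.45, 8.115, 7.79,
                  7.73, 7.33, 7.285, 7.01, 6.855, 6.705]
# Grid search over beta: an illustrative stand-in for QWOA, not the paper's method.
grid = [i / 100 for i in range(101)]
best_beta = min(grid, key=lambda beta: fitness(beta, MEAN_THICKNESS))
```

Because β = 0.5 is one of the grid points, the selected β can never do worse in-sample than the traditional setting; QWOA plays the same role with a guided rather than exhaustive search.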

5 Prediction and Analysis

The data come from the atmospheric distillation unit of a refinery. The elbow of the inlet distribution pipe of the overhead heat exchanger measures 219 mm × 9 mm, and its material is steel 20. The wall thickness of the elbow was measured monthly at four different positions, and the monthly averages of the four readings were collected for 18 months as measurement data. The averages of the first 12 months are used as the original sequence of the GM(1,1) and QWOA-EFGM(1,1) models; the specific elbow wall thickness values are shown in Table 1. The data of the remaining 6 months serve as the prediction sequence for testing forecast accuracy.

5.1 Inspection and Processing of Data

To fully verify the prediction performance of each model, this paper adopts the mean absolute error (MAE), root mean square error (RMSE), goodness of fit (R²) and absolute degree of grey incidence (σ), defined as follows:

(1) Mean absolute error (MAE)

$$MAE = \frac{1}{n}\sum_{i=1}^{n}\left|x_i^{(0)} - \hat{x}_i^{(0)}\right| \tag{32}$$

(2) Root mean square error (RMSE)

$$RMSE = \left(\frac{1}{n}\sum_{i=1}^{n}\left(x_i^{(0)} - \hat{x}_i^{(0)}\right)^2\right)^{1/2} \tag{33}$$


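(3) Goodness of fit (R²)

$$R^2 = 1 - \frac{\sum_{i=1}^{n}\left(x_i^{(0)} - \hat{x}_i^{(0)}\right)^2}{\sum_{i=1}^{n}\left(x_i^{(0)} - \bar{x}^{(0)}\right)^2} \tag{34}$$

where $\bar{x}^{(0)}$ is the average value of the original sequence.

(4) Absolute degree of grey incidence (σ)

$$|\varepsilon_0| = \left|\sum_{k=2}^{n-1} x^{(0)}(k) + \frac{1}{2}x^{(0)}(n)\right|, \qquad |\hat{\varepsilon}_0| = \left|\sum_{k=2}^{n-1} \hat{x}^{(0)}(k) + \frac{1}{2}\hat{x}^{(0)}(n)\right|$$

$$|\hat{\varepsilon}_0 - \varepsilon_0| = \left|\sum_{k=2}^{n-1}\left(\hat{x}^{(0)}(k) - x^{(0)}(k)\right) + \frac{1}{2}\left(\hat{x}^{(0)}(n) - x^{(0)}(n)\right)\right|$$

$$\sigma = \frac{1 + |\varepsilon_0| + |\hat{\varepsilon}_0|}{1 + |\varepsilon_0| + |\hat{\varepsilon}_0| + |\hat{\varepsilon}_0 - \varepsilon_0|} \tag{35}$$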

where $\varepsilon_0$ is the incidence parameter of the original sequence, $\hat{\varepsilon}_0$ is the corresponding parameter of the predicted sequence, and σ is the absolute degree of grey incidence. Smaller MAE and RMSE indicate better predictions, and R² and σ closer to 1 indicate a better fit.

5.2 Analysis of Experimental Results

5.2.1 QWOA and WOA Algorithm Convergence

Figure 1 shows the convergence of QWOA and WOA when each is used to optimize the background weight of the EFGM(1,1) model. Both algorithms use a population size of 30, 50 iterations, a search range of [0, 1] and a dimension of 12. Each algorithm was run 30 times, and the smallest error was taken as the final result. The figure shows that the parallel quantum whale algorithm reaches higher optimization accuracy than the traditional whale algorithm: compared with WOA, QWOA finds better background weights, so the model's prediction error is smaller.


5.2.2 Comparative Analysis of the Prediction Results of Each Model

1) Prediction results of each model. Table 2 lists the wall thickness predictions for the remaining 6 months and the prediction residuals of the GM(1,1), EFGM(1,1) and QWOA-EFGM(1,1) models. In months 15–18, the prediction residuals of the EFGM(1,1) model optimized by the parallel quantum whale algorithm are smaller than those of both the GM(1,1) model and the EFGM(1,1) model, markedly so relative to GM(1,1). The relative error of the traditional GM(1,1) model ranges from 0.0015 to 0.1057, a spread of 0.1042, while that of the QWOA-EFGM(1,1) model ranges from 0.0084 to 0.0538, a spread of 0.0454. The improved model therefore not only increases prediction accuracy but is also more stable. Figure 2 compares the prediction residuals of each model.

Table 1. Elbow wall thickness values (mm)

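| Test point / Month | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Point 1 | 8.83 | 8.81 | 8.85 | 8.74 | 8.36 | 8.36 | 7.85 | 7.65 | 7.53 | 7.14 | 7.54 | 7.36 |
| Point 2 | 8.92 | 8.88 | 8.05 | 8.87 | 8.56 | 7.81 | 7.62 | 7.52 | 7.44 | 6.83 | 6.85 | 6.53 |
| Point 3 | 8.87 | 8.87 | 8.23 | 8.23 | 7.89 | 7.63 | 7.99 | 7.14 | 7.06 | 7.22 | 6.32 | 6.36 |
| Point 4 | 8.86 | 8.56 | 8.75 | 7.96 | 7.65 | 7.36 | 7.46 | 7.01 | 7.11 | 6.85 | 6.71 | 6.57 |
| Mean value | 8.87 | 8.78 | 8.47 | 8.45 | 8.115 | 7.79 | 7.73 | 7.33 | 7.285 | 7.01 | 6.855 | 6.705 |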

Table 2. Predicted elbow wall thickness

| Month | Actual wall thickness (mm) | GM(1,1) predicted (mm) | GM(1,1) residual (mm) | GM(1,1) relative error | EFGM(1,1) predicted (mm) | EFGM(1,1) residual (mm) | EFGM(1,1) relative error | QWOA-EFGM(1,1) predicted (mm) | QWOA-EFGM(1,1) residual (mm) | QWOA-EFGM(1,1) relative error |
|---|---|---|---|---|---|---|---|---|---|---|
| 13 | 6.5375 | 6.4862 | 0.0513 | 0.0078 | 6.4492 | 0.0883 | 0.0135 | 6.4349 | 0.1026 | 0.0157 |
| 14 | 6.300 | 6.3097 | 0.0097 | 0.0015 | 6.2424 | 0.0576 | 0.0091 | 6.2249 | 0.0751 | 0.0119 |
| 15 | 5.965 | 6.1379 | 0.1729 | 0.029 | 6.0356 | 0.0706 | 0.0118 | 6.0148 | 0.0498 | 0.0084 |
| 16 | 5.625 | 5.9709 | 0.3459 | 0.0615 | 5.8287 | 0.2037 | 0.0362 | 5.8048 | 0.1798 | 0.0320 |
| 17 | 5.3325 | 5.8084 | 0.4759 | 0.0892 | 5.6219 | 0.2894 | 0.0543 | 5.5947 | 0.2622 | 0.0492 |
| 18 | 5.110 | 5.6503 | 0.5403 | 0.1057 | 5.415 | 0.3050 | 0.0597 | 5.3847 | 0.2747 | 0.0538 |

2) Comparative analysis of prediction results. To verify the effectiveness of the proposed model, the traditional grey GM(1,1) model, the improved EFGM(1,1) model and the EFGM(1,1) model optimized by the parallel quantum whale algorithm were each run 20 times, and the best prediction result of each was used for comparison. The prediction error indices of each model are shown in Table 3.


Fig. 1. Convergence of QWOA and WOA Algorithms

Fig. 2. Prediction residual comparison curves of each model

According to Table 3, the MAE of the proposed QWOA-EFGM(1,1) model is 0.109 and 0.013 lower than those of the GM(1,1) and EFGM(1,1) models, respectively, and its RMSE is 0.154 and 0.018 lower, indicating that QWOA-EFGM(1,1) predicts more accurately than both. For R² and σ, the QWOA-EFGM(1,1) model is closer to 1 than the other two models, showing that the EFGM(1,1) model optimized by the parallel quantum whale algorithm fits the actual values better. Every evaluation index of QWOA-EFGM(1,1) is better than those of the GM(1,1) and EFGM(1,1) models, which shows that it has higher prediction accuracy and smaller prediction error and can better predict the elbow wall thickness. For a more intuitive comparison, Fig. 3 shows each prediction model under the different evaluation indices; it confirms that the proposed QWOA-EFGM(1,1) model achieves higher prediction accuracy than the other two models.
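The four indices can be computed with a few lines of plain Python. The sketch below implements Eqs. (32)–(35) as written; Eq. (35) is applied to the raw sequences, while the textbook absolute degree of grey incidence applies the same sums to zero-starting-point images x(k) − x(1), which the assumed `zero_start` flag provides. MAE and RMSE are then recomputed from the absolute residual columns of Table 2 for months 13–18.

```python
import math


def mae(residuals):
    """Eq. (32) applied to residuals x_i - x_hat_i."""
    return sum(abs(r) for r in residuals) / len(residuals)


def rmse(residuals):
    """Eq. (33)."""
    return math.sqrt(sum(r * r for r in residuals) / len(residuals))


def r_squared(actual, predicted):
    """Eq. (34): 1 - SSE / total sum of squares about the mean of the actual values."""
    mean = sum(actual) / len(actual)
    sse = sum((x - p) ** 2 for x, p in zip(actual, predicted))
    sst = sum((x - mean) ** 2 for x in actual)
    return 1 - sse / sst


def grey_abs_incidence(actual, predicted, zero_start=False):
    """Eq. (35); zero_start=True uses the images x(k) - x(1) (assumed option)."""
    if zero_start:
        actual = [v - actual[0] for v in actual]
        predicted = [v - predicted[0] for v in predicted]

    def weighted_sum(seq):
        # Sum over k = 2..n-1 plus half of the last term.
        return sum(seq[1:-1]) + 0.5 * seq[-1]

    eps0 = abs(weighted_sum(actual))
    eps0_hat = abs(weighted_sum(predicted))
    diff = abs(weighted_sum([p - x for x, p in zip(actual, predicted)]))
    return (1 + eps0 + eps0_hat) / (1 + eps0 + eps0_hat + diff)


# Absolute residuals (mm) for months 13-18, from Table 2.
RESIDUALS = {
    "GM(1,1)": [0.0513, 0.0097, 0.1729, 0.3459, 0.4759, 0.5403],
    "EFGM(1,1)": [0.0883, 0.0576, 0.0706, 0.2037, 0.2894, 0.3050],
    "QWOA-EFGM(1,1)": [0.1026, 0.0751, 0.0498, 0.1798, 0.2622, 0.2747],
}
for name, res in RESIDUALS.items():
    print(f"{name}: MAE = {mae(res):.4f}, RMSE = {rmse(res):.4f}")
```

This prints MAE/RMSE values of 0.2660/0.3343, 0.1691/0.1976 and 0.1574/0.1804 for the three models, reproducing the reported gaps of 0.109 (MAE) and 0.154 (RMSE) relative to GM(1,1). Against EFGM(1,1) the recomputed gaps are about 0.012 and 0.017 rather than 0.013 and 0.018, a difference that may stem from rounding of the tabulated residuals.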

Table 3. Comparison of prediction results of each model

Fig. 3. Comparison of evaluation indexes of each prediction model

6 Conclusion

The example analysis shows that the QWOA-EFGM(1,1) model has strong prediction ability and good stability. Using it to predict the corrosion trend at the elbow of the inlet distribution pipe of the overhead heat exchanger of the atmospheric distillation tower can provide theoretical support for inspecting the atmospheric distillation unit and reduce monitoring frequency and cost.

Acknowledgments. This work is supported by the National Natural Science Foundation of China (62266007) and the Guangxi Natural Science Foundation (2021GXNSFAA220068, 2018GXNSFAA294068).

References

1. Deng, J.L.: Grey control system. J. Huazhong Inst. Technol. 03, 9–18 (1982). https://doi.org/10.13245/j.hust.1982.03.002
2. Gao, J.F., Hao, B.: Corrosion prediction of submarine pipeline based on improved grey model. China Water Transport 21(09), 48–50. 4 (2021)
3. Wu, Y., Zhang, D.-Q., Xu, N.: Initial value optimization of opposite grey model and its application to road surface subsidence prediction. Eng. Surv. Mapp. 23(12), 60–62. https://doi.org/10.19349/j.cnki.issn1006-7949.2014.12.015


4. Lu, J., Li, F.: Optimization and application of GM(1,1) model based on initial value and background value. Oper. Res. Manage. Sci. 29(09), 27–33 (2020)
5. Yang, N., Li, H., Yuan, J., Li, S., Wang, X.: Medium- and long-term load forecasting method considering grey correlation degree analysis. Proc. CSU-EPSA 30(06), 108–114 (2018)
6. Fan, X.H., Miao, Q.M., Wang, H.M.: Grey prediction GM(1,1) model and its improvement and application. J. Armored Force Eng. Inst. (02), 24–26 (2003)
7. Xi, Y., Yang, Y.: Optimization and application of GM(1,1) model based on minimum error. Appl. Res. Comput. 33(08), 2328–2330 (2016)
8. Yang, X.L., Zhou, M., Zeng, B.: A new method for constructing background value of grey prediction model. Statist. Dec. 34(19), 14–18 (2018)
9. Li, C., Dai, W.: An approach of the grey modelling based on cot x transformation. Syst. Eng. 2005(03), 110–114 (2005)
10. Weiwen, G., Jing, H.: A demand forecast model based on the gray theory and exponential smoothing method. Statist. Dec. 01, 72–76 (2017)
11. Liu, X., Xie, N.: A nonlinear grey forecasting model with double shape parameters and its application. Appl. Math. Comput. 360, 203–212 (2019)
12. Yu, Y., Huang, M., Wang, X., Hu, R.: Navigation satellite clock bias prediction based on grey model improved by least absolute deviations. Map. Bull. 2019(04), 1–6 (2019). https://doi.org/10.13474/j.cnki.11-2246.2019.0102

3D Path Planning Based on Improved Teaching and Learning Optimization Algorithm

Xiuxi Wei1, Haixuan He2, Huajuan Huang1(B), and Yongquan Zhou1

1 College of Artificial Intelligence, Guangxi Minzu University, Nanning 530006, China
2 College of Electronic Information, Guangxi Minzu University, Nanning 530006, China

Abstract. An improved teaching and learning algorithm for global path planning is proposed to solve the 3D path planning problem of unmanned aerial vehicles (UAVs). To overcome the shortcomings of the basic teaching and learning algorithm in practical applications, group teaching is introduced in the teaching stage to improve the algorithm's local exploitation ability, and an autonomous learning stage is added after the learning stage to improve its global optimization ability. Simulation results show that, compared with the basic teaching and learning algorithm, the artificial fish swarm algorithm and the ant colony algorithm, the improved teaching and learning algorithm plans UAV 3D paths more effectively and at lower cost.

Keywords: Teaching and learning algorithm · Path planning · Grouping · Autonomous learning

1 Introduction

Owing to their flexibility, small size and strong adaptability, drones have been widely used in fields such as border patrol and power line inspection [1, 2]. Drone path planning means planning an optimal flight path along which the drone reaches the target area smoothly within a specified range, while balancing other factors such as the flight environment and energy consumption. Many algorithms address this problem, such as the particle algorithm, ant colony algorithm, genetic algorithm and particle swarm algorithm. Although these algorithms have advantages, they also have shortcomings. Reference [3] applies a genetic algorithm to UAV path planning, but the path planning data must be encoded, which is difficult, and different encoding schemes give different results. Reference [4] uses an improved ant colony algorithm, but it has many parameters and its initial pheromone setting is subjective, which affects path planning. In reference [5], an adaptive particle swarm optimization algorithm incorporating tabu search obtains near-optimal paths, but it depends heavily on the maximum number of iterations. The teaching-learning-based optimization (TLBO) algorithm is a swarm intelligence algorithm proposed by Rao et al. [6] in 2011 that simulates the teaching and learning relationship between teachers and students in a classroom. It is easy to implement, simple in structure and has few parameters. However, the basic TLBO algorithm also has drawbacks, such as slow convergence and loss of the global optimal solution. This article therefore proposes an improved teaching and learning optimization algorithm (ITLBO). In the teaching stage, group teaching is carried out to improve the local exploration and exploitation ability of the algorithm. After the learning stage, an autonomous learning stage is added in which students actively seek improvement and make up for their differences, avoiding regression and improving the global optimization ability and accuracy of the algorithm. The ITLBO algorithm is applied to UAV path planning and compared with the basic TLBO algorithm, the ACA algorithm [7] and the ABC algorithm [8], and simulations demonstrate its effectiveness.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 798–809, 2023. https://doi.org/10.1007/978-981-99-4742-3_66

2 Establishment and Analysis of Models

2.1 Environmental Model

The information required for drone path planning must be obtained from terrain models of the actual natural environment, and comprehensive terrain modeling makes path planning more effective. The benchmark terrain model [9] used in this paper is given by Eq. (1):

$$Z_1(x, y) = \sin(y + a) + b\sin(x) + c\cos\left(d\sqrt{x^2 + y^2}\right) + e\cos(y) + f\sin\left(g\sqrt{x^2 + y^2}\right) \tag{1}$$

where x and y are the coordinates of a point projected onto the horizontal plane, $Z_1$ is the corresponding terrain height, and a, b, c, d, e, f and g are constant coefficients that adjust the relief of the benchmark terrain in the environmental model. Natural mountains in the flight environment are described by exponential functions. The total number of mountains is n, the center coordinates of the i-th peak are $(x_i, y_i)$, its height is controlled by $h_i$, and its attenuation (slope) along the x- and y-axis directions is given by $x_{si}$ and $y_{si}$. The environmental model is shown in Fig. 1, and its mathematical model is given by Eq. (2):

$$Z_2(x, y) = \sum_{i=1}^{n} h_i \exp\left[-\left(\frac{x - x_i}{x_{si}}\right)^2 - \left(\frac{y - y_i}{y_{si}}\right)^2\right] \tag{2}$$
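A minimal sketch of the environment model follows. The coefficient values and peak parameters are hypothetical (this section does not list the paper's values), and taking the terrain height as the maximum of the base terrain and the peak terrain is an assumption that is common for this benchmark but not stated here.

```python
import math


def base_terrain(x, y, a=1.0, b=0.2, c=0.1, d=0.6, e=0.1, f=0.1, g=0.1):
    """Benchmark terrain Z1 of Eq. (1); the default coefficients are hypothetical."""
    r = math.sqrt(x * x + y * y)
    return (math.sin(y + a) + b * math.sin(x) + c * math.cos(d * r)
            + e * math.cos(y) + f * math.sin(g * r))


def peak_terrain(x, y, peaks):
    """Peak terrain Z2 of Eq. (2); peaks holds (x_i, y_i, h_i, x_si, y_si) tuples."""
    return sum(h * math.exp(-((x - xc) / xs) ** 2 - ((y - yc) / ys) ** 2)
               for xc, yc, h, xs, ys in peaks)


def terrain_height(x, y, peaks):
    """Assumed combination: the higher of the two surfaces at (x, y)."""
    return max(base_terrain(x, y), peak_terrain(x, y, peaks))


# Hypothetical peaks: center (x_i, y_i), height h_i, spreads x_si and y_si.
PEAKS = [(20.0, 30.0, 60.0, 8.0, 10.0), (60.0, 50.0, 80.0, 12.0, 9.0)]
height_at_first_peak = terrain_height(20.0, 30.0, PEAKS)
```

At a peak center both exponents vanish, so that peak contributes exactly $h_i$; larger $x_{si}$ or $y_{si}$ give gentler slopes along the corresponding axis.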

2.2 Cubic B-spline Curve Path

To improve the computational efficiency of the algorithm and ensure smooth, flyable paths, a cubic B-spline curve is used to smooth the UAV flight path, as shown in Eq. (3):

$$p(u) = \sum_{i=0}^{n} d_i N_{i,k}(u) \tag{3}$$


Fig. 1. Environmental model

where $d_i$ (i = 0, 1, ..., n) are the control vertices of the B-spline curve, which constrain the range of the curve, and $N_{i,k}(u)$ are the degree-k basis functions of the parameter u, defined recursively by Eq. (4):

$$N_{i,0}(u) = \begin{cases} 1, & u_i \le u < u_{i+1} \\ 0, & \text{otherwise} \end{cases}, \qquad N_{i,k}(u) = \frac{u - u_i}{u_{i+k} - u_i} N_{i,k-1}(u) + \frac{u_{i+k+1} - u}{u_{i+k+1} - u_{i+1}} N_{i+1,k-1}(u) \tag{4}$$

where 0/0 is defined as 0 and $u_i$ (i = 0, 1, ..., n + k + 1) is the knot vector, whose sequence is non-decreasing. In this article, the knots at both ends have multiplicity k + 1, the interior knots are uniformly distributed, and k = 3; that is, the path is fitted with a cubic quasi-uniform B-spline curve.

2.3 Fitness Function

A correct fitness function is crucial for evaluating path quality, and both the flight range and the obstacle avoidance cost of the UAV must be considered. The length of the path fitted by Eq. (3) determines the range cost $f_v$, which is the total distance between all adjacent path nodes $(x_i, y_i, z_i)$ and $(x_{i+1}, y_{i+1}, z_{i+1})$, as shown in Eq. (5):

$$f_v = \sum_{i=1}^{m} \sqrt{(x_{i+1} - x_i)^2 + (y_{i+1} - y_i)^2 + (z_{i+1} - z_i)^2} \tag{5}$$


d min represents the minimum safe flight distance between the drone and the mountain peak. To ensure safe flight, the collision free distance between obstacles is d s , and the obstacle avoidance cost is represented by f a , as shown in Eq. (6).

$$f_a = \begin{cases} 0, & d_{min} \ge d_s \\ \infty, & d_{min} < d_s \end{cases} \tag{6}$$

Combining $f_a$ and $f_v$, the fitness function is defined by Eq. (7). When $d_{min} \ge d_s$, the UAV is not in danger and no obstacle-avoidance cost is incurred; when $d_{min} < d_s$, the UAV is considered to be in danger and the fitness f is infinite.

$$f = f_v + f_a \tag{7}$$
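A minimal sketch of the fitness in Eqs. (5) to (7), applied to points sampled from the smoothed path. The paper does not state how $d_{min}$ is measured, so the code assumes (hypothetically) that it is the smallest vertical clearance between a path point and the terrain surface; `terrain_height` can be any callable, for instance one that evaluates Eq. (2).

```python
import math


def voyage_cost(points):
    """f_v, Eq. (5): total length of the polyline through consecutive path points."""
    return sum(math.dist(p, q) for p, q in zip(points, points[1:]))


def avoidance_cost(d_min, d_s):
    """f_a, Eq. (6): zero for a safe path, infinite otherwise."""
    return 0.0 if d_min >= d_s else math.inf


def fitness(points, terrain_height, d_s):
    """f = f_v + f_a, Eq. (7).

    Assumption (not from the paper): d_min is approximated by the smallest
    vertical clearance z - terrain_height(x, y) over the sampled path points.
    """
    d_min = min(z - terrain_height(x, y) for x, y, z in points)
    return voyage_cost(points) + avoidance_cost(d_min, d_s)
```

Using an infinite penalty makes any unsafe path lose every comparison, so the optimizer never accepts it as an improvement.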

3 Teaching and Learning Algorithms
The standard TLBO algorithm consists of two stages: the teaching stage, in which the teacher strives to raise the overall level of the class, and the learning stage, in which students improve their own grades by interacting with and learning from one another.

3.1 Teaching Stage
In the teaching stage, the best individual is selected as the teacher $X_{teacher}$, who teaches the other students in order to raise the class mean. The class mean is computed by Eq. (8), and the current student $X_i$ is updated by Eq. (9).

$$mean = \frac{1}{n} \sum_{i=1}^{n} X_i \tag{8}$$

$$X_i^{new} = X_i^{old} + r_1 \times (X_{teacher} - T_F \times mean) \tag{9}$$

where $T_F = \mathrm{round}[1 + \mathrm{rand}(0,1)]$ is the teaching factor, which takes the value 1 or 2 and determines how fully the student learns from the teacher; mean is the average individual of the class; and the random number $r_1$ takes values in [0, 1]. After the teaching stage, $X_i^{old}$ is replaced by $X_i^{new}$ if $X_i^{new}$ is better; otherwise it is kept.

3.2 Learning Stage
In the learning stage, student $X_i$ randomly selects another student $X_j$ from the class to learn from and produces a new solution as shown in Eq. (10).

$$X_i^{new} = \begin{cases} X_i^{old} + r_2 \times (X_i^{old} - X_j^{old}), & f(X_i) < f(X_j) \\ X_i^{old} + r_3 \times (X_j^{old} - X_i^{old}), & f(X_i) > f(X_j) \end{cases} \tag{10}$$


where the random numbers $r_2$ and $r_3$ take values in [0, 1], and $f(X_i)$ and $f(X_j)$ are the grades (fitness values) of students $X_i$ and $X_j$. Since the fitness is minimized, Eq. (10) states that if student $X_j$ performs better than student $X_i$, then $X_i$ moves toward $X_j$ (cooperative learning); conversely, $X_i$ moves away from $X_j$ (reverse learning). After the learning stage, $X_i^{old}$ is replaced by $X_i^{new}$ if $X_i^{new}$ is better; otherwise it is kept.
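The two stages of the standard TLBO in Sects. 3.1 and 3.2 can be sketched as follows for a generic minimization problem. Assumptions: one random scalar is drawn per student update (as the notation of Eqs. (9) and (10) suggests), ties $f(X_i) = f(X_j)$ are treated as the second case of Eq. (10), no bound handling is applied, and the name `tlbo` and its defaults are illustrative.

```python
import random


def tlbo(f, dim, n=20, iters=100, lo=-5.0, hi=5.0, seed=0):
    """Standard TLBO (Sect. 3) minimizing f over real vectors of length dim."""
    rng = random.Random(seed)
    pop = [[rng.uniform(lo, hi) for _ in range(dim)] for _ in range(n)]
    fit = [f(x) for x in pop]
    history = []
    for _ in range(iters):
        # Teaching stage, Eqs. (8)-(9)
        mean = [sum(x[d] for x in pop) / n for d in range(dim)]
        teacher = pop[min(range(n), key=fit.__getitem__)]
        for i in range(n):
            tf = round(1 + rng.random())          # teaching factor T_F, 1 or 2
            r1 = rng.random()
            cand = [pop[i][d] + r1 * (teacher[d] - tf * mean[d]) for d in range(dim)]
            fc = f(cand)
            if fc < fit[i]:                       # keep X_new only if it is better
                pop[i], fit[i] = cand, fc
        # Learning stage, Eq. (10)
        for i in range(n):
            j = rng.choice([k for k in range(n) if k != i])
            r = rng.random()
            if fit[i] < fit[j]:                   # X_i better: move away from X_j
                cand = [pop[i][d] + r * (pop[i][d] - pop[j][d]) for d in range(dim)]
            else:                                 # X_j better (ties assumed here): move toward X_j
                cand = [pop[i][d] + r * (pop[j][d] - pop[i][d]) for d in range(dim)]
            fc = f(cand)
            if fc < fit[i]:
                pop[i], fit[i] = cand, fc
        history.append(min(fit))                  # best fitness after this iteration
    best = min(range(n), key=fit.__getitem__)
    return pop[best], fit[best], history
```

For example, `tlbo(lambda x: sum(v * v for v in x), dim=2)` minimizes a two-dimensional sphere function; because of greedy selection, the recorded best fitness never increases from one iteration to the next.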

4 Strategies for Improving Teaching and Learning Algorithms
4.1 Group Teaching
The standard TLBO algorithm updates every individual uniformly with Eq. (9), regardless of its fitness, and thus ignores the differences between individuals. To improve optimization accuracy and keep the algorithm from falling into local optima, a personalized teaching method is adopted that divides the “teaching” stage into two groups. Individuals that perform better than the class average, i.e. $f(X_i) < mean$ in Eq. (11) since the fitness is minimized, have strong acceptance ability, so they learn directly from the teacher, and the teacher increases the teaching effort: following particle swarm optimization, the learning factor is increased to c = 2. The remaining individuals still use the original update. This improvement helps the algorithm escape local optima and improves the accuracy of the solution. The improved update formula for the “teaching” stage is given by Eq. (11).

$$X_i^{new} = \begin{cases} X_i^{old} + 2 r_1 \times (X_{teacher} - X_i^{old}), & f(X_i) < mean \\ X_i^{old} + r_2 \times (X_{teacher} - T_F \times mean), & f(X_i) \ge mean \end{cases} \tag{11}$$
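The grouped update of Eq. (11) for a single student might look like the following sketch. Because Eq. (11) compares a fitness value with "mean", the code assumes that the threshold is the class-average fitness `mean_fit`, while `mean_pos` is the average position from Eq. (8); the function name and argument order are hypothetical.

```python
def grouped_teaching(x, fx, teacher, mean_pos, mean_fit, r, tf, c=2.0):
    """ITLBO teaching update, Eq. (11), for one student x with fitness fx.

    Assumption: the case threshold 'mean' is the class-average fitness mean_fit;
    mean_pos is the average position of Eq. (8) used in the original update.
    """
    if fx < mean_fit:
        # Better-than-average student: learns directly from the teacher with c = 2
        return [xi + c * r * (ti - xi) for xi, ti in zip(x, teacher)]
    # Otherwise: original TLBO update, Eq. (9)
    return [xi + r * (ti - tf * mi) for xi, ti, mi in zip(x, teacher, mean_pos)]
```

The first branch pulls a good student straight toward the teacher (with step up to 2 times the gap), while the second branch keeps the mean-shifting behavior of standard TLBO for weaker students.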

4.2 Autonomous Learning
After learning from and communicating with peers, student $X_i$ should actively close the gap between itself and the globally best individual. At the same time, to prevent its learning from regressing, it also learns from its own historical best position $X_{pbest}$. This meets higher learning requirements, strengthens the global exploration and exploitation ability of the algorithm, and accelerates its convergence. The update formula is given by Eq. (12).

$$X_i^{new} = X_i^{old} + r_4 \times (X_{teacher} - X_i^{old}) + r_5 \times (X_{pbest} - X_i^{old}) \tag{12}$$
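Eq. (12) and the accompanying personal-best bookkeeping can be sketched as two small helpers. The random numbers $r_4$ and $r_5$ are passed in explicitly (assumed to be drawn from [0, 1] like the other random factors), and both function names are hypothetical.

```python
def autonomous_learning(x, teacher, pbest, r4, r5):
    """Eq. (12): pull student x toward the teacher and toward its own best position."""
    return [xi + r4 * (ti - xi) + r5 * (pi - xi) for xi, ti, pi in zip(x, teacher, pbest)]


def update_pbest(x, fx, pbest, f_pbest):
    """Keep the better of the new position and the stored personal best (minimization)."""
    return (list(x), fx) if fx < f_pbest else (pbest, f_pbest)
```

For instance, with x = (0, 0), teacher = (2, 4), pbest = (1, 0) and r4 = r5 = 0.5, the new position is (1.5, 2.0).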


4.3 Implementation Steps of the Improved Algorithm
Step 1: Establish the three-dimensional environment model for UAV flight and set the coordinates of the start and end points.
Step 2: Initialize the population: randomly generate n individuals, each containing k spatial path points, and set the maximum number of iterations $t_{max}$.
Step 3: Calculate the class mean and the fitness value $f_i$ of each individual, take the individual with the smallest fitness as the teacher $X_{teacher}$, and update each individual's best position.
Step 4: Execute the “teaching” stage according to Eq. (11) and generate new solutions.
Step 5: Execute the “learning” stage according to Eq. (10) and generate new solutions.
Step 6: Execute the autonomous learning stage according to Eq. (12), and update the individual positions, the individual best positions, and the global best position.
Step 7: If the termination condition $t = t_{max}$ is reached, exit; otherwise, return to Step 3.
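Putting Steps 2 to 7 together gives the following sketch of the ITLBO loop for a generic objective. Several details are assumptions not fixed by the text: the teacher is taken as the best personal-best position found so far, the teaching and learning stages use greedy acceptance as in standard TLBO, the autonomous-learning move of Eq. (12) is applied unconditionally so that $X_{pbest}$ can differ from the current position, and the grouping threshold in Eq. (11) is the average fitness. The name `itlbo` and its defaults are illustrative.

```python
import random


def itlbo(f, dim, n=30, t_max=100, lo=-5.0, hi=5.0, seed=0):
    """Sketch of the ITLBO loop of Sect. 4.3 (Steps 2-7) for minimizing f."""
    rng = random.Random(seed)
    pop = [[rng.uniform(lo, hi) for _ in range(dim)] for _ in range(n)]  # Step 2
    fit = [f(x) for x in pop]
    pbest, f_pbest = [list(x) for x in pop], list(fit)
    history = []

    def record(i, x, fx):
        # Store the new position and refresh the personal best of student i.
        pop[i], fit[i] = x, fx
        if fx < f_pbest[i]:
            pbest[i], f_pbest[i] = list(x), fx

    for _ in range(t_max):                                   # Step 7 loop
        # Step 3: class means and teacher (best solution found so far)
        mean_pos = [sum(x[d] for x in pop) / n for d in range(dim)]
        mean_fit = sum(fit) / n
        teacher = pbest[min(range(n), key=f_pbest.__getitem__)]
        # Step 4: grouped teaching, Eq. (11), greedy acceptance
        for i in range(n):
            r = rng.random()
            if fit[i] < mean_fit:
                cand = [x + 2 * r * (t - x) for x, t in zip(pop[i], teacher)]
            else:
                tf = round(1 + rng.random())
                cand = [x + r * (t - tf * m) for x, t, m in zip(pop[i], teacher, mean_pos)]
            fc = f(cand)
            if fc < fit[i]:
                record(i, cand, fc)
        # Step 5: learning stage, Eq. (10), greedy acceptance
        for i in range(n):
            j = rng.choice([k for k in range(n) if k != i])
            r = rng.random()
            sign = 1.0 if fit[i] < fit[j] else -1.0   # away from X_j if X_i is better, else toward it
            cand = [x + r * sign * (x - y) for x, y in zip(pop[i], pop[j])]
            fc = f(cand)
            if fc < fit[i]:
                record(i, cand, fc)
        # Step 6: autonomous learning, Eq. (12), assumed to move unconditionally
        for i in range(n):
            r4, r5 = rng.random(), rng.random()
            cand = [x + r4 * (t - x) + r5 * (p - x) for x, t, p in zip(pop[i], teacher, pbest[i])]
            record(i, cand, f(cand))
        history.append(min(f_pbest))                         # global best so far
    b = min(range(n), key=f_pbest.__getitem__)
    return pbest[b], f_pbest[b], history
```

Applied to path planning, `f` would decode an individual into control vertices, smooth them with the B-spline of Eq. (3), and return the fitness of Eq. (7).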

5 Experimental Results and Analysis
5.1 Simulation Environment and Experimental Data
To verify the effectiveness of the proposed ITLBO algorithm, its results are compared with those of the basic TLBO, ABC, and ACA. The experiments were run on 64-bit Windows with an Intel Core i5-8300H processor (2.3 GHz) and 8 GB of RAM, using MATLAB 2016b as the simulation platform. The task space is 100 m × 100 m × 100 m and contains nine mountain obstacles; the path starts at [0, 0, 5] and ends at [100, 100, 5], with 3 path nodes. Table 1 lists the parameters of the mountain obstacles, and Table 2 lists the parameter settings of each comparison algorithm.
5.2 Simulation Results and Analysis
In all experiments the maximum number of iterations is $t_{max}$ = 100, the population size is 50, the number of path nodes is 3, and each algorithm is run independently 30 times. The optimal, worst, and average values and the variance of the four algorithms are listed in Table 3. The 30 simulation results of ITLBO, TLBO, ABC, and ACA are shown in Fig. 2, the box plot in Fig. 3, and the planned paths in Figs. 4, 5, 6 and 7. Table 3 and Fig. 3 show that TLBO has poor optimal values and poor stability, whereas the proposed ITLBO attains the lowest optimal, worst, and average values. Although its variance (3.775) is slightly higher than that of ACA (3.263) and ABC (3.433), Fig. 2 shows that the ITLBO results are generally lower, so it is more likely to obtain the best path.
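One way to encode an individual for this setup is as a flat vector holding the 3 path nodes, which are then framed by the fixed start and end points to form the control vertices $d_0, \ldots, d_4$ of Eq. (3). This encoding, the clipping to the 100 m task space, and the names `decode`, `START`, `GOAL`, and `BOUNDS` are assumptions for illustration; the paper does not spell out its encoding.

```python
START = (0.0, 0.0, 5.0)       # start point from Sect. 5.1
GOAL = (100.0, 100.0, 5.0)    # end point from Sect. 5.1
BOUNDS = (0.0, 100.0)         # 100 m x 100 m x 100 m task space


def decode(individual, n_nodes=3):
    """Turn a flat vector of 3 * n_nodes coordinates into B-spline control vertices.

    Assumption: the path nodes are used directly as the interior control vertices
    d_1, ..., d_n_nodes, with the start and end points as the first and last vertices.
    """
    if len(individual) != 3 * n_nodes:
        raise ValueError("expected 3 coordinates per path node")
    lo, hi = BOUNDS
    clipped = [min(max(v, lo), hi) for v in individual]   # keep nodes inside the task space
    nodes = [tuple(clipped[3 * j: 3 * j + 3]) for j in range(n_nodes)]
    return [START] + nodes + [GOAL]
```

With 3 path nodes this yields 5 control vertices, which matches the cubic quasi-uniform B-spline of Sect. 2.2 with a single interior knot.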


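Table 1. Peak Parameters

| Peak number | h_i | x_i | y_i | x_si | y_si |
|---|---|---|---|---|---|
| 1 | 30 | 20 | 30 | 8 | 12 |
| 2 | 40 | 40 | 60 | 7 | 9 |
| 3 | 75 | 55 | 45 | 4 | 8 |
| 4 | 50 | 75 | 68 | 8 | 7 |
| 5 | 65 | 80 | 10 | 8 | 4 |
| 6 | 57 | 50 | 15 | 13 | 5 |
| 7 | 66 | 15 | 70 | 5 | 13 |
| 8 | 43 | 85 | 40 | 6 | 12 |
| 9 | 35 | 60 | 90 | 7 | 4 |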

Table 2. Algorithm Parameter Settings

| Algorithm | Parameter settings |
|---|---|
| ITLBO | Teaching factor c = 2 |
| ACA | Pheromone factor α = 10, heuristic function factor β = 1, pheromone constant Q = 1, pheromone evaporation factor ρ = 0.1 |
| ABC | Threshold l = 5 |

Table 3. Results of 30 Simulation Runs

| Algorithm | Optimal value | Worst value | Average value | Variance |
|---|---|---|---|---|
| ITLBO | 143.753 | 154.078 | 147.948 | 3.775 |
| TLBO | 168.816 | 265.790 | 206.137 | 26.337 |
| ACA | 153.761 | 166.753 | 158.708 | 3.263 |
| ABC | 145.662 | 160.786 | 151.884 | 3.433 |


Fig. 2. 30 simulation results

Fig. 3. Variance

Figures 4, 5, 6 and 7 compare the paths planned by the four algorithms. The TLBO path is poorly smoothed, whereas the ITLBO, ACA, and ABC paths are smooth and avoid the obstacles.


Fig. 4. ITLBO Planning Results

Fig. 5. TLBO Path Planning Results


Their paths avoid the mountain obstacles and allow the UAV to reach its destination smoothly, but the paths planned by ACA and ABC are longer and show larger altitude changes. Fig. 8 shows that ITLBO achieves the lowest fitness. Although TLBO and ACA converge faster, they fall into local optima prematurely, while ABC converges more slowly than ITLBO and still fails to reach a fitness value as low as that of ITLBO. Hence ITLBO avoids premature convergence to local optima, converges faster than ABC, and effectively improves the quality of the planned path.

Fig. 6. ACA Path Planning Results


Fig. 7. ABC Path Planning Results

Fig. 8. Fitness Value Curve


6 Conclusion
An improved teaching–learning-based optimization algorithm (ITLBO) is proposed for three-dimensional UAV path planning. Group teaching in the teaching stage matches the evolutionary characteristics of the algorithm, keeps it from falling into local optima prematurely, and improves solution accuracy. After the “learning” stage, ITLBO adds an autonomous learning stage in which each student actively seeks improvement and closes the gap to the best individuals, strengthening the algorithm's global exploration and exploitation ability. Simulation results show that, compared with TLBO, ACA, and ABC, ITLBO avoids obstacles at a lower cost and searches for the optimal path more effectively, demonstrating its effectiveness and superiority.
Acknowledgments. This work is supported by the National Natural Science Foundation of China (62266007) and the Guangxi Natural Science Foundation (2021GXNSFAA220068, 2018GXNSFAA294068).

References
1. Ren, N., Zhang, N., Cui, Y., Zhang, R., Pang, X.: Method of semantic entity construction and trajectory control for UAV electric power inspection. J. Comput. Appl. 40(10), 3095–3100 (2020)
2. Zou, L., Zhang, M., Bai, J., Wu, J.: A survey of modeling and simulation of UAS swarm operation. Tactical Missile Technol. (03), 98–108 (2021)
3. He, J., Tan, D.: Three-dimensional terrain path planning based on sight range and genetic algorithm. Comput. Eng. Appl. 57(15), 279–285 (2021)
4. Chen, C., Zhang, L.: Three-dimensional path planning based on improved ant colony algorithm. Comput. Eng. Appl. 55(20), 192–196 (2019)
5. Liu, H., Yang, J., Wang, T., Wang, X.: Route planning in the three-dimensional space using improved adaptive PSO. Fire Control Command Control 38(11), 141–143 (2013)
6. Rao, R.V., Savsani, V.J., Vakharia, D.P.: Teaching–learning-based optimization: a novel method for constrained mechanical design optimization problems. Comput. Aided Des. 43(3), 303–315 (2011)
7. Dorigo, M., Birattari, M., Stutzle, T.: Ant colony optimization. IEEE Comput. Intell. Mag. 1(4), 28–39 (2006)
8. Gao, W., Liu, S.: A modified artificial bee colony algorithm. Comput. Oper. Res. 39(3), 687–697 (2012)
9. Lian, X., Liu, Y., Chen, Y., Huang, J., Gong, Y., Huo, L.: Research on multi-peak spectral line separation method based on adaptive particle swarm optimization. Spectroscopy Spectral Anal. 41(05), 1452–1457 (2021)

Author Index

A Alaskar, Haya 330 Ansari, Sam 330 B Bu, Lijing 639 C Cai, Zhiqi 354 Cao, Hua 247 Cao, Xiaoqun 235 Cao, Yi 555, 576 Chang, Liu 676 Chen, Bin 137 Chen, Cang 354 Chen, Dinghao 247 Chen, Guang Yi 263, 429 Chen, Guanyu 124 Chen, Jiahao 688 Chen, Li 199 Chen, Lifang 70 Chen, Peng 498, 601, 614, 626 Chen, Ran 85, 98 Chen, Sien 33, 773 Chen, Yaojie 588, 664 Chen, Yuehui 555, 576 Chen, Yunhao 70, 283 Cui, Xueyan 109 D Dai, Dong 639 Dai, Lingyun 308 Dai, Zhenmin 567 Deng, Mingjun 639 Deng, Xiaojun 109 Ding, Hanqing 296 Ding, Jianwen 773 Dong, Lu 438

Dong, Wensheng 247, 773 Du, Zhihua 710 E Elmeligy, Mohamed 330 Essa, Amine 330

F Fan, Ruiqi 413 Fan, Xue 449 Fang, Min 710 Feng, Guangsheng 162 Fu, Xianghua 222 Fu, Xuzhou 722 G Gao, Jie 722 Gao, WeiZe 699 Ge, Lina 737, 761 Geng, Jing 137 Gong, Zhaohui 174 Gu, Chunyan 531 Guo, Hao 21 Guo, Jie 401 Guo, Xin 187 Guo, Yanan 235 Guo, Yanwen 401 H Hameed, Iqra 367 Han, Shiyuan 58, 449 He, Haixuan 798 Hou, Yangqing 473 Hu, Jun 124 Hu, Junjie 137 Hu, Kai 85, 98 Hu, Yujun 343 Huang, Bo 377

© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNCS 14087, pp. 811–813, 2023. https://doi.org/10.1007/978-981-99-4742-3


Huang, Huajuan 785, 798 Huang, Ke-Yang 149 Huang, Shixian 785 Huang, Shouwang 413 Huang, Tao 137 Huang, Yifan 70, 283 Huang, Zhangjin 531 Hussain, Abir 330 J Jia, Tongyao 652 Jiang, TingShuo 699 Jiang, Wenheng 555 Jin, Jiyong 21 Jintao, Li 676 K Khan, Ameer Hamza 46 Khan, Wasiq 330 Koo, Insoo 367 Krzyzak, Adam 263, 429 L Lan, Xiaobin 688 Li, Bingwei 296 Li, Changyun 109 Li, Feng 211, 343 Li, Haiao 737 Li, Hui 320 Li, Jiafeng 652 Li, Jianqiang 710 Li, Nuo 149 Li, Qi 124 Li, Shijie 389 Li, Sijie 773 Li, Wei 162 Li, Zhongtao 58 Liang, Huanwen 222 Lin, Lanliang 33 Lin, Ruohan 688 Liu, Bin 3, 21 Liu, Chang 461 Liu, Jingyi 124 Liu, Jin-Xing 308 Liu, Wenzheng 588 Liu, YaTe 498 Liu, Yu 449 Liu, Zhiqiang 722


Lu, Kun 601, 614, 626 Lu, Yu 222, 354 Luo, Jiusong 211 Luo, Lei 58 M Ma, Jiarui 473 Mahmoud, Soliman 330 Mei, Jiaming 688 Meng, Xiangxu 162 Mi, Jian-Xun 149 N Ni, Jiawei 601, 614, 626

O OBE, Dhiya Al-Jumeily 330 P Pan, Fei 401 Pan, Lejun 601, 614, 626 Peng, Kecheng 235 Premaratne, Prashan 523 Q Qiao, Nianzu 438 Qing, Chunmei 506 R Ren, Kejun 485 Ren, QingYu 699 Ren, Zhen 70, 283 S Shang, Junliang 272, 308 Shang, Li 377 Shen, Jianlu 70, 283 Shi, Chenxi 33 Shu, Jian 576 Shuyuan, Tang 676 Sun, Jia 438 Sun, Zhuang 199 T Tang, Dandan 485 Tang, Linxia 137 Tian, Wenlong 235 Turky, Ayad 330


V Vial, Peter 523 W Wan, Lanjun 109 Wang, Bing 498, 601, 614, 626 Wang, Huiqiang 162 Wang, Jin-Wu 567 Wang, Juan 272, 308 Wang, Lele 3 Wang, Menghao 773 Wang, Qingkai 247, 773 Wang, Quanyu 124 Wang, Ruijuan 320 Wang, Wenyan 601, 614, 626 Wang, Xiao 761 Wang, Xinlong 749 Wei, Xiuxi 785, 798 Wen, Jia 485 Wu, Mengqi 174 Wu, Shaohui 199 Wu, Shichao 413 Wu, Yongrong 33 Wu, Yujia 187 X Xia, Xin 710 Xia, Zhenyang 749 Xiang, Yang 485 Xie, Daiwei 567 Xie, Weifang 354 Xie, Wenfang 263, 429 Xie, Xinyu 639 Xie, Yonghui 320 Xie, Zhihua 544 Xie, Zhuo 174 Xu, Jie 272, 308 Xu, Jin 296 Xu, Xiangmin 506 Y Yan, Kuiting 272, 308 Yan, Shaoqi 722 Yan, Zihui 70, 283 Yang, Benxin 664 Yang, Chen 46 Yang, Gang 461 Yang, Huali 137 Yang, Jun 58


Yang, Lvqing 33, 247, 773 Yang, Shuangyuan 33 Yang, Xixin 58 Yang, Xue 531 Yang, Yifei 389 Yang, Yinan 199 Yu, Bo 33, 247, 773 Yu, Li 710 Yu, Mei 722 Yu, Ruiguo 722 Yu, Weiwei 449 Yu, Zichang 222 Yuan, Jin 531 Yuan, Shasha 272, 308

Z Zaidan, Abdullah 330 Zhai, Yi 749 Zhan, Kangning 187 Zhang, Dongyan 531 Zhang, Guifeng 761 Zhang, Hao 737 Zhang, Jianhao 461 Zhang, Jun 498, 601, 614, 626 Zhang, Yanju 688 Zhang, Yinyan 46 Zhang, Yue 749 Zhang, Yuze 377 Zhang, Zhengpeng 639 Zhang, Zhihong 174 Zhao, BingRui 699 Zhao, Chuwei 544 Zhao, Guoping 174 Zhao, Haoxin 109 Zhao, Yaou 555, 576 Zhao, Zheng 162 Zhao, Zhong-Qiu 85, 98 Zheng, ChunHou 498 Zheng, Wenqi 162 Zhong, Mingyang 389 Zhou, Lijuan 174 Zhou, Mengge 235 Zhou, Yongquan 785, 798 Zhu, Chenlin 601, 614, 626 Zhu, Siqi 506 Zhu, Yunjie 70, 283 Zhuo, Li 652 Ziad, Suhaib 330