English Pages 833 [823] Year 2023
LNAI 14090
De-Shuang Huang · Prashan Premaratne · Baohua Jin · Boyang Qu · Kang-Hyun Jo · Abir Hussain (Eds.)
Advanced Intelligent Computing Technology and Applications 19th International Conference, ICIC 2023 Zhengzhou, China, August 10–13, 2023 Proceedings, Part V
Lecture Notes in Computer Science
Lecture Notes in Artificial Intelligence
Founding Editor
Jörg Siekmann
Series Editors
Randy Goebel, University of Alberta, Edmonton, Canada
Wolfgang Wahlster, DFKI, Berlin, Germany
Zhi-Hua Zhou, Nanjing University, Nanjing, China
14090
The series Lecture Notes in Artificial Intelligence (LNAI) was established in 1988 as a topical subseries of LNCS devoted to artificial intelligence. The series publishes state-of-the-art research results at a high level. As with the LNCS mother series, the mission of the series is to serve the international R & D community by providing an invaluable service, mainly focused on the publication of conference and workshop proceedings and postproceedings.
Editors
De-Shuang Huang, Department of Computer Science, Eastern Institute of Technology, Zhejiang, China
Prashan Premaratne, University of Wollongong, North Wollongong, NSW, Australia
Baohua Jin, Zhengzhou University of Light Industry, Zhengzhou, China
Boyang Qu, Zhong Yuan University of Technology, Zhengzhou, China
Kang-Hyun Jo, University of Ulsan, Ulsan, Korea (Republic of)
Abir Hussain, Department of Computer Science, Liverpool John Moores University, Liverpool, UK
ISSN 0302-9743, ISSN 1611-3349 (electronic)
Lecture Notes in Artificial Intelligence
ISBN 978-981-99-4760-7, ISBN 978-981-99-4761-4 (eBook)
https://doi.org/10.1007/978-981-99-4761-4
LNCS Sublibrary: SL7 – Artificial Intelligence
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
This work is subject to copyright. All rights are reserved by the Publisher, whether the whole or part of the material is concerned, specifically the rights of translation, reprinting, reuse of illustrations, recitation, broadcasting, reproduction on microfilms or in any other physical way, and transmission or information storage and retrieval, electronic adaptation, computer software, or by similar or dissimilar methodology now known or hereafter developed.
The use of general descriptive names, registered names, trademarks, service marks, etc. in this publication does not imply, even in the absence of a specific statement, that such names are exempt from the relevant protective laws and regulations and therefore free for general use.
The publisher, the authors, and the editors are safe to assume that the advice and information in this book are believed to be true and accurate at the date of publication. Neither the publisher nor the authors or the editors give a warranty, expressed or implied, with respect to the material contained herein or for any errors or omissions that may have been made. The publisher remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
This Springer imprint is published by the registered company Springer Nature Singapore Pte Ltd.
The registered company address is: 152 Beach Road, #21-01/04 Gateway East, Singapore 189721, Singapore
Preface
The International Conference on Intelligent Computing (ICIC) was started to provide an annual forum dedicated to emerging and challenging topics in artificial intelligence, machine learning, pattern recognition, bioinformatics, and computational biology. It aims to bring together researchers and practitioners from both academia and industry to share ideas, problems, and solutions related to the multifaceted aspects of intelligent computing.
ICIC 2023, held in Zhengzhou, China, August 10–13, 2023, constituted the 19th International Conference on Intelligent Computing. It built upon the success of ICIC 2022 (Xi’an, China), ICIC 2021 (Shenzhen, China), ICIC 2020 (Bari, Italy), ICIC 2019 (Nanchang, China), ICIC 2018 (Wuhan, China), ICIC 2017 (Liverpool, UK), ICIC 2016 (Lanzhou, China), ICIC 2015 (Fuzhou, China), ICIC 2014 (Taiyuan, China), ICIC 2013 (Nanning, China), ICIC 2012 (Huangshan, China), ICIC 2011 (Zhengzhou, China), ICIC 2010 (Changsha, China), ICIC 2009 (Ulsan, South Korea), ICIC 2008 (Shanghai, China), ICIC 2007 (Qingdao, China), ICIC 2006 (Kunming, China), and ICIC 2005 (Hefei, China).
This year, the conference concentrated mainly on theories and methodologies as well as emerging applications of intelligent computing. Its aim was to unify the picture of contemporary intelligent computing techniques as an integral concept that highlights the trends in advanced computational intelligence and bridges theoretical research with applications. Therefore, the theme for this conference was “Advanced Intelligent Computing Technology and Applications”. Papers that focused on this theme were solicited, addressing theories, methodologies, and applications in science and technology.
ICIC 2023 received 828 submissions from 12 countries and regions. All papers went through a rigorous peer-review procedure and each paper received at least three review reports.
Based on the review reports, the Program Committee finally selected 337 high-quality papers for presentation at ICIC 2023 and inclusion in five volumes of proceedings published by Springer: three volumes of Lecture Notes in Computer Science (LNCS) and two volumes of Lecture Notes in Artificial Intelligence (LNAI). This volume, LNAI 14090, includes 67 papers.
The organizers of ICIC 2023, including Eastern Institute of Technology, China, Zhongyuan University of Technology, China, and Zhengzhou University of Light Industry, China, made an enormous effort to ensure the success of the conference. We would like to thank the members of the Program Committee and the referees for their collective effort in reviewing and soliciting the papers. In particular, we thank all the authors for contributing their papers; without their high-quality submissions, the success of the conference would not have been possible. Finally, we are especially grateful to the International Neural Network Society and the National Science Foundation of China for their sponsorship.
June 2023
De-Shuang Huang
Prashan Premaratne
Boyang Qu
Baohua Jin
Kang-Hyun Jo
Abir Hussain
Organization
General Co-chairs
De-Shuang Huang, Eastern Institute of Technology, China
Shizhong Wei, Zhengzhou University of Light Industry, China
Program Committee Co-chairs
Prashan Premaratne, University of Wollongong, Australia
Baohua Jin, Zhengzhou University of Light Industry, China
Kang-Hyun Jo, University of Ulsan, Republic of Korea
Abir Hussain, Liverpool John Moores University, UK
Organizing Committee Co-chair
Hui Jing, Zhengzhou University of Light Industry, China
Organizing Committee Members
Fubao Zhu, Zhengzhou University of Light Industry, China
Qiuwen Zhang, Zhengzhou University of Light Industry, China
Haodong Zhu, Zhengzhou University of Light Industry, China
Wei Huang, Zhengzhou University of Light Industry, China
Hongwei Tao, Zhengzhou University of Light Industry, China
Weiwei Zhang, Zhengzhou University of Light Industry, China
Award Committee Co-chairs
Michal Choras, Bydgoszcz University of Science and Technology, Poland
Hong-Hee Lee, University of Ulsan, Republic of Korea
Tutorial Co-chairs
Yoshinori Kuno, Saitama University, Japan
Phalguni Gupta, Indian Institute of Technology Kanpur, India
Publication Co-chairs
Valeriya Gribova, Far Eastern Branch of Russian Academy of Sciences, Russia
M. Michael Gromiha, Indian Institute of Technology Madras, India
Boyang Qu, Zhengzhou University, China
Special Session Co-chairs
Jair Cervantes Canales, Autonomous University of Mexico State, Mexico
Chenxi Huang, Xiamen University, China
Dhiya Al-Jumeily, Liverpool John Moores University, UK
Special Issue Co-chairs
Kyungsook Han, Inha University, Republic of Korea
Laurent Heutte, Université de Rouen Normandie, France
International Liaison Co-chair
Prashan Premaratne, University of Wollongong, Australia
Workshop Co-chairs
Yu-Dong Zhang, University of Leicester, UK
Hee-Jun Kang, University of Ulsan, Republic of Korea
Publicity Co-chairs
Chun-Hou Zheng, Anhui University, China
Dhiya Al-Jumeily, Liverpool John Moores University, UK
Jair Cervantes Canales, Autonomous University of Mexico State, Mexico
Exhibition Contact Co-chair
Fubao Zhu, Zhengzhou University of Light Industry, China
Program Committee Members
Abir Hussain, Liverpool John Moores University, UK
Antonio Brunetti, Polytechnic University of Bari, Italy
Antonino Staiano, Università di Napoli Parthenope, Italy
Bin Liu, Beijing Institute of Technology, China
Bin Qian, Kunming University of Science and Technology, China
Bin Yang, Zaozhuang University, China
Bing Wang, Anhui University of Technology, China
Binhua Tang, Hohai University, China
Bingqiang Liu, Shandong University, China
Bo Li, Wuhan University of Science and Technology, China
Changqing Shen, Soochow University, China
Chao Song, Harbin Medical University, China
Chenxi Huang, Xiamen University, China
Chin-Chih Chang, Chung Hua University, Taiwan
Chunhou Zheng, Anhui University, China
Chunmei Liu, Howard University, USA
Chunquan Li, University of South China, China
Dahjing Jwo, National Taiwan Ocean University, Taiwan
Dakshina Ranjan Kisku, National Institute of Technology Durgapur, India
Dan Feng, Huazhong University of Science and Technology, China
Daowen Qiu, Sun Yat-sen University, China
Dharmalingam Muthusamy, Bharathiar University, India
Dhiya Al-Jumeily, Liverpool John Moores University, UK
Dong Wang, University of Jinan, China
Dunwei Gong, China University of Mining and Technology, China
Eros Gian Pasero, Politecnico di Torino, Italy
Evi Sjukur, Monash University, Australia
Fa Zhang, Beijing Institute of Technology, China
Fengfeng Zhou, Jilin University, China
Fei Guo, Central South University, China
Gaoxiang Ouyang, Beijing Normal University, China
Giovanni Dimauro, University of Bari, Italy
Guoliang Li, Huazhong Agricultural University, China
Han Zhang, Nankai University, China
Haibin Liu, Beijing University of Technology, China
Hao Lin, University of Electronic Science and Technology of China, China
Haodi Feng, Shandong University, China
Hongjie Wu, Suzhou University of Science and Technology, China
Hongmin Cai, South China University of Technology, China
Jair Cervantes, Autonomous University of Mexico State, Mexico
Jixiang Du, Huaqiao University, China
Jing Hu, Wuhan University of Science and Technology, China
Jiawei Luo, Hunan University, China
Jian Huang, University of Electronic Science and Technology of China, China
Jian Wang, China University of Petroleum, China
Jiangning Song, Monash University, Australia
Jinwen Ma, Peking University, China
Jingyan Wang, Abu Dhabi Department of Community Development, UAE
Jinxing Liu, Qufu Normal University, China
Joaquin Torres-Sospedra, Universidade do Minho, Portugal
Juan Liu, Wuhan University, China
Jun Zhang, Anhui University, China
Junfeng Xia, Anhui University, China
Jungang Lou, Huzhou University, China
Kachun Wong, City University of Hong Kong, China
Kanghyun Jo, University of Ulsan, Republic of Korea
Khalid Aamir, University of Sargodha, Pakistan
Kyungsook Han, Inha University, Republic of Korea
L. Gong, Nanjing University of Posts and Telecommunications, China
Laurent Heutte, Université de Rouen Normandie, France
Le Zhang, Sichuan University, China
Lejun Gong, Nanjing University of Posts and Telecommunications, China
Liang Gao, Huazhong Univ. of Sci. & Tech., China
Lida Zhu, Huazhong Agriculture University, China
Marzio Pennisi, University of Eastern Piedmont, Italy
Michal Choras, Bydgoszcz University of Science and Technology, Poland
Michael Gromiha, Indian Institute of Technology Madras, India
Ming Li, Nanjing University, China
Minzhu Xie, Hunan Normal University, China
Mohd Helmy Abd Wahab, Universiti Tun Hussein Onn Malaysia, Malaysia
Nicola Altini, Polytechnic University of Bari, Italy
Peng Chen, Anhui University, China
Pengjiang Qian, Jiangnan University, China
Phalguni Gupta, GLA University, India
Prashan Premaratne, University of Wollongong, Australia
Pufeng Du, Tianjin University, China
Qi Zhao, University of Science and Technology Liaoning, China
Qingfeng Chen, Guangxi University, China
Qinghua Jiang, Harbin Institute of Technology, China
Quan Zou, University of Electronic Science and Technology of China, China
Rui Wang, National University of Defense Technology, China
Saiful Islam, Aligarh Muslim University, India
Seeja K. R., Indira Gandhi Delhi Technical University for Women, India
Shanfeng Zhu, Fudan University, China
Shikui Tu, Shanghai Jiao Tong University, China
Shitong Wang, Jiangnan University, China
Shixiong Zhang, Xidian University, China
Sungshin Kim, Pusan National University, Republic of Korea
Surya Prakash, IIT Indore, India
Tatsuya Akutsu, Kyoto University, Japan
Tao Zeng, Guangzhou Laboratory, China
Tieshan Li, University of Electronic Science and Technology of China, China
Valeriya Gribova, Institute of Automation and Control Processes, Far Eastern Branch of Russian Academy of Sciences, Russia
Vincenzo Randazzo, Politecnico di Torino, Italy
Waqas Haider, Kohsar University Murree, Pakistan
Wen Zhang, Huazhong Agricultural University, China
Wenbin Liu, Guangzhou University, China
Wensheng Chen, Shenzhen University, China
Wei Chen, Chengdu University of Traditional Chinese Medicine, China
Wei Peng, Kunming University of Science and Technology, China
Weichiang Hong, Asia Eastern University of Science and Technology, Taiwan
Weidong Chen, Shanghai Jiao Tong University, China
Weiwei Kong, Xi’an University of Posts and Telecommunications, China
Weixiang Liu, Shenzhen University, China
Xiaodi Li, Shandong Normal University, China
Xiaoli Lin, Wuhan University of Science and Technology, China
Xiaofeng Wang, Hefei University, China
Xiao-Hua Yu, California Polytechnic State University, USA
Xiaoke Ma, Xidian University, China
Xiaolei Zhu, Anhui Agricultural University, China
Xiangtao Li, Jilin University, China
Xin Zhang, Jiangnan University, China
Xinguo Lu, Hunan University, China
Xingwei Wang, Northeastern University, China
Xinzheng Xu, China University of Mining and Technology, China
Xiwei Liu, Tongji University, China
Xiyuan Chen, Southeast Univ., China
Xuequn Shang, Northwestern Polytechnical University, China
Xuesong Wang, China University of Mining and Technology, China
Yansen Su, Anhui University, China
Yi Xiong, Shanghai Jiao Tong University, China
Yu Xue, Huazhong University of Science and Technology, China
Yizhang Jiang, Jiangnan University, China
Yonggang Lu, Lanzhou University, China
Yongquan Zhou, Guangxi University for Nationalities, China
Yudong Zhang, University of Leicester, UK
Yunhai Wang, Shandong University, China
Yupei Zhang, Northwestern Polytechnical University, China
Yushan Qiu, Shenzhen University, China
Yunxia Liu, Zhengzhou Normal University, China
Zhanli Sun, Anhui University, China
Zhenran Jiang, East China Normal University, China
Zhengtao Yu, Kunming University of Science and Technology, China
Zhenyu Xuan, University of Texas at Dallas, USA
Zhihong Guan, Huazhong University of Science and Technology, China
Zhihua Cui, Taiyuan University of Science and Technology, China
Zhiping Liu, Shandong University, China
Zhiqiang Geng, Beijing University of Chemical Technology, China
Zhongqiu Zhao, Hefei University of Technology, China
Zhuhong You, Northwestern Polytechnical University, China
Contents – Part V
Intelligent Computing in Computer Vision
ATY-SLAM: A Visual Semantic SLAM for Dynamic Indoor Environments . . . 3
Hao Qi, Zhuhua Hu, Yunfeng Xiang, Dupeng Cai, and Yaochi Zhao
Minimizing Peak Memory Footprint of Inference on IoTs Devices by Efficient Recomputation . . . 15
Xiaofeng Sun, Chaonong Xu, and Chao Li
AGAM-SLAM: An Adaptive Dynamic Scene Semantic SLAM Method Based on GAM . . . 27
Dupeng Cai, Zhuhua Hu, Ruoqing Li, Hao Qi, Yunfeng Xiang, and Yaochi Zhao
A Water Level Ruler Recognition Method Based on Deep Learning Technology . . . 40
Jingbo An, Kefeng Song, Di Wu, and Wanxian He
FRVidSwin:A Novel Video Captioning Model with Automatical Removal of Redundant Frames . . . 51
Zehao Dong, Yuehui Chen, Yi Cao, and Yaou Zhao
A Simple Mixed-Supervised Learning Method for Salient Object Detection . . . 63
Congjin Gong, Gang Yang, and Haoyu Dong
A Lightweight Detail-Fusion Progressive Network for Image Deraining . . . 75
Siyi Ding, Qing Zhu, and Wanting Zhu
SwinCGH-Net: Enhancing Robustness of Object Detection in Autonomous Driving with Weather Noise via Attention . . . 88
Shi Cao, Qing Zhu, and Wanting Zhu
MBDNet: Mitigating the “Under-Training Issue” in Dual-Encoder Model for RGB-d Salient Object Detection . . . 99
Shuo Wang, Gang Yang, Yunhua Zhang, Qiqi Xu, and Yutao Wang
W-net: Deep Convolutional Network with Gray-Level Co-occurrence Matrix and Hybrid Loss Function for Hyperspectral Image Classification . . . 112
Jinchao Jiao, Changqing Yin, and Fei Teng
Brain Tumor Image Segmentation Network Based on Dual Attention Mechanism . . . 125
Fuyun He, Yao Zhang, Yan Wei, Youwei Qian, Cong Hu, and Xiaohu Tang
A ConvMixEst and Multi-attention UNet for Intervertebral Disc Segmentation in Multi-modal MRI . . . 137
Sipei Lu, Hanqiang Liu, and Xiangkai Guo
One-Dimensional Feature Supervision Network for Object Detection . . . 147
Longchao Shen, Yongsheng Dong, Yuanhua Pei, Haotian Yang, Lintao Zheng, and Jinwen Ma
Use the Detection Transformer as a Data Augmenter . . . 157
Luping Wang and Bin Liu
An Unsupervised Video Summarization Method Based on Multimodal Representation . . . 171
Zhuo Lei, Qiang Yu, Lidan Shou, Shengquan Li, and Yunqing Mao
An Industrial Defect Detection Network with Fine-Grained Supervision and Adaptive Contrast Enhancement . . . 181
Ying Xiang, Hu Yifan, Fu Xuzhou, Gao Jie, and Liu Zhiqiang
InterFormer: Human Interaction Understanding with Deformed Transformer . . . 193
Di He, Zexing Du, Xue Wang, and Qing Wang
UCLD-Net: Decoupling Network via Unsupervised Contrastive Learning for Image Dehazing . . . 204
Zhitao Liu, Tao Hong, and Jinwen Ma
LMConvMorph: Large Kernel Modern Hierarchical Convolutional Model for Unsupervised Medical Image Registration . . . 216
Zhaoyang Liu, Xiuyang Zhao, Dongmei Niu, Bo Yang, and Caiming Zhang
Joint Skeleton and Boundary Features Networks for Curvilinear Structure Segmentation . . . 227
Yubo Wang, Li Chen, Zhida Feng, and Yunxiang Cao
A Driver Abnormal Behavior Detection Method Based on Improved YOLOv7 and OpenPose . . . 239
Xingquan Cai, Shun Zhou, Jiali Yao, Pengyan Cheng, and Yan Hu
Semi-supervised Semantic Segmentation Algorithm for Video Frame Corruption . . . 251
Jingyan Ye, Li Chen, and Jun Li
DSC-OpenPose: A Fall Detection Algorithm Based on Posture Estimation Model . . . 263
Lei Shi, Hongqiu Xue, Caixia Meng, Yufei Gao, and Lin Wei
Improved YOLOv5s Method for Nut Detection on Ultra High Voltage Power Towers . . . 277
Lang Xu, Yi Xia, Jun Zhang, Bing Wang, and Peng Chen
Improved Deep Learning-Based Efficientpose Algorithm for Egocentric Marker-Less Tool and Hand Pose Estimation in Manual Assembly . . . 288
Zihan Niu, Yi Xia, Jun Zhang, Bing Wang, and Peng Chen
Intelligent Computing in Communication Networks
Adaptive Probabilistic Broadcast in Ad Hoc Networks . . . 301
Shuai Xiaoying, Yin Yuxia, and Zhang Bin
A Light-Weighted Model of GRU + CNN Hybrid for Network Intrusion Detection . . . 314
Dong Yang, Can Zhou, and Songjie Wei
Reinforcement-Learning Based Preload Strategy for Short Video . . . 327
Zhicheng Ren, Yongxin Shan, Wanchun Jiang, Yijing Shan, Danfeng Shan, and Jianxin Wang
Particle Swarm Optimization with Genetic Evolution for Task Offloading in Device-Edge-Cloud Collaborative Computing . . . 340
Bo Wang and Jiangpo Wei
DBCS-SMJF: Designing a BLDCM Control System for Small Machine Joints Using FOC . . . 351
Leyi Zhang, Yingjie Long, Yingbiao Hu, and Huinian Li
Intelligent Data Analysis and Prediction
A Hybrid Tourism Recommendation System Based on Multi-objective Evolutionary Algorithm and Re-ranking . . . 363
Ruifen Cao, Zijue Li, Pijing Wei, Ye Tian, and Chunhou Zheng
Intelligence Evaluation of Music Composition Based on Music Knowledge . . . 373
Shuo Wang, Yun Tie, Xiaobing Li, Xiaoqi Wang, and Lin Qi
StockRanker: A Novelty Three-Stage Ranking Model Based on Deep Learning for Stock Selection . . . 385
Rui Ding, Xinyu Ke, and Shuangyuan Yang
Design and Application of Mapping Model for Font Recommendation System Based on Contents Emotion Analysis . . . 397
Young Seo Ji and Soon bum Lim
Time Series Prediction of 5G Network Data Based on Improved EEMD-BiLSTM Prediction Model . . . 409
Jianrong Li, Zheng Li, Jie Li, Gongcheng Shi, Chuanlei Zhang, and Hui Ma
CWA-LSTM: A Stock Price Prediction Model Based on Causal Weight Adjustment . . . 421
Qihang Zhang, Zhaoguo Liu, Zhuoer Wen, Da Huang, and Weixia Xu
StPrformer: A Stock Price Prediction Model Based on Convolutional Attention Mechanism . . . 433
Zhaoguo Liu, Qihang Zhang, Da Huang, and Dan Wu
Diagnosis of Lung Cancer Subtypes by Combining Multi-graph Embedding and Graph Fusion Network . . . 445
Siyu Peng, Jiawei Luo, Cong Shen, and Bo Wang
Detformer: Detect the Reliable Attention Index for Ultra-long Time Series Forecasting . . . 457
Xiangxu Meng, Wei Li, Zheng Zhao, Zhihan Liu, Guangsheng Feng, and Huiqiang Wang
An Ultra-short-Term Wind Speed Prediction Method Based on Spatio-Temporal Feature Decomposition and Multi Feature Fusion Network . . . 469
Xuewei Li, Guanrong He, Jian Yu, Zhiqiang Liu, Mei Yu, Weiping Ding, and Wei Xiong
A Risk Model for Assessing Exposure Factors Influence Oil Price Fluctuations . . . 482
Raghad Alshabandar, Ali Jaddoa, and Abir Hussain
A Dynamic Graph Convolutional Network for Anti-money Laundering . . . 493
Tianpeng Wei, Biyang Zeng, Wenqi Guo, Zhenyu Guo, Shikui Tu, and Lei Xu
Bearing Fault Detection Based on Graph Cyclostationary Signal Analysis and Convolutional Neural Network . . . 503
Cong Chen and Hui Li
Rolling Bearing Fault Diagnosis Based on GWVD and Convolutional Neural Network . . . 514
Xiaoxuan Lv and Hui Li
Expert Systems
Aggregation of S-generalized Distances . . . 527
Lijun Sun, Chen Zhao, and Gang Li
Lagrange Heuristic Algorithm Incorporated with Decomposition Strategy for Green Multi-depot Heterogeneous-Fleet Vehicle Routing Problem . . . 537
Linhao Xu, Bin Qian, Rong Hu, Naikang Yu, and Huaiping Jin
A Novel Algorithm to Multi-view TSK Classification Based on the Dirichlet Distribution . . . 549
Lei Nie, Zhenyu Qian, Yaping Zhao, and Yizhang Jiang
Expert Knowledge-Driven Clothing Matching Recommendation System . . . 559
Qianwen Tao, Jun Wang, ChunYun Chen, Shuai Zhu, and Youqun Shi
Research on Construction Method of IoT Knowledge System Based on Knowledge Graph . . . 573
Qidi Wu, Shuai Zhu, Qianwen Tao, Yucheng Zhao, and Youqun Shi
Reinforcement Learning
Robust Anti-forensics on Audio Forensics System . . . 589
Qingqing Wang and Dengpan Ye
Off-Policy Reinforcement Learning with Loss Function Weighted by Temporal Difference Error . . . 600
Bumgeun Park, Taeyoung Kim, Woohyeon Moon, Sarvar Hussain Nengroo, and Dongsoo Har
On Context Distribution Shift in Task Representation Learning for Online Meta RL . . . 614
Chenyang Zhao, Zihao Zhou, and Bin Liu
Dynamic Ensemble Selection with Reinforcement Learning . . . 629
Lihua Liu, Jibing Wu, Xuan Li, and Hongbin Huang
Reinforcement Learning for Routing Problems with Hybrid Edge-Embedded Networks . . . 641
Xinyu Ke, Rui Ding, and Shuangyuan Yang
Advancing Air Combat Tactics with Improved Neural Fictitious Self-play Reinforcement Learning . . . 653
Shaoqin He, Yang Gao, Baofeng Zhang, Hui Chang, and Xinchen Zhang
Recent Advances in Deep Learning Methods and Techniques for Medical Image Analysis
Power Grid Knowledge Graph Completion with Complex Structure Learning . . . 669
Zhou Zheng, Jun Guo, Feilong Liao, Qiyao Huang, Yingyue Zhang, Zhichao Zhao, Chenxiang Lin, and Zhihong Zhang
A Graph-Transformer Network for Scene Text Detection . . . 680
Yongrong Wu, Jingyu Lin, Houjin Chen, Dinghao Chen, Lvqing Yang, and Jianbing Xiahou
Hessian Non-negative Hypergraph . . . 691
Lingling Li, Zihang Li, Mingkai Wang, Taisong Jin, and Jie Liu
Explainable Knowledge Reasoning on Power Grid Knowledge Graph . . . 705
Yingyue Zhang, Qiyao Huang, Zhou Zheng, Feilong Liao, Longqiang Yi, Jinhu Li, Jiangsheng Huang, and Zhihong Zhang
A Novel Approach to Analyzing Defects: Enhancing Knowledge Graph Embedding Models for Main Electrical Equipment . . . 715
Yanyu Chen, Jianye Huang, Jian Qian, Longqiang Yi, Jinhu Li, Jiangsheng Huang, and Zhihong Zhang
Hybrid CNN-LSTM Model for Multi-industry Electricity Demand Prediction . . . 726
Haitao Zhang, Yuxing Dai, Qing Yin, Xin He, Jian Ju, Haotian Zheng, Fengling Shen, Wenjuan Guo, Jinhu Li, Zhihong Zhang, and Yong Duan
Improve Knowledge Graph Completion for Diagnosing Defects in Main Electrical Equipment . . . 738
Jianye Huang, Jian Qian, Yanyu Chen, Rui Lin, Yuyou Weng, Guoqing Lin, and Zhihong Zhang
Knowledge Graph-Based Approach for Main Transformer Defect Grade Analysis . . . 749
Shitao Cai, Zhou Zheng, Chenxiang Lin, Longqiang Yi, Jinhu Li, Jiangsheng Huang, and Zhihong Zhang
CSAANet: An Attention-Based Mechanism for Aligned Few-Shot Semantic Segmentation Network . . . 760
Guangpeng Wei and Pengjiang Qian
Unsupervised Few-Shot Learning via Positive Expansions and Negative Proxies . . . 773
Liangjun Chen and Pengjiang Qian
GAN for Blind Image Deblurring Based on Latent Image Extraction and Blur Kernel Estimation . . . 785
Xiaowei Huang and Pengjiang Qian
GLUformer: An Efficient Transformer Network for Image Denoising . . . 797
Chenghao Xue and Pengjiang Qian
Author Index . . . 809
Intelligent Computing in Computer Vision
ATY-SLAM: A Visual Semantic SLAM for Dynamic Indoor Environments
Hao Qi¹, Zhuhua Hu¹ (✉), Yunfeng Xiang¹, Dupeng Cai¹, and Yaochi Zhao²
¹ School of Information and Communication Engineering, Hainan University, Haikou 570228, China
² School of Cyberspace Security (School of Cryptology), Hainan University, Haikou 570228, China
Abstract. Visual Simultaneous Localization and Mapping (VSLAM) is a critical technology that enables mobile robots to accurately sense their surroundings and perform localization and map building. However, most VSLAM algorithms assume a static environment, which often leads to poor performance in highly dynamic indoor scenes, where accurately estimating the camera pose and achieving precise localization remain significant challenges. This paper presents ATY-SLAM (Adaptive Thresholding combining YOLOv7-tiny SLAM), a VSLAM method that culls dynamic feature points and optimizes keyframe selection in highly dynamic scenes, improving robustness and accuracy in dynamic indoor environments. ATY-SLAM combines the YOLOv7-tiny object detection network, motion consistency detection, and the Lucas-Kanade (LK) optical flow algorithm to detect dynamic regions in the image, and then removes the unstable feature points in those regions, further improving the stability of the system. It then applies an adaptive thresholding method to select stable keyframes, addressing the poor keyframe quality produced by the heuristic thresholds of existing methods. Experimental results on the public TUM RGB-D dataset show that the proposed algorithm reduces the absolute trajectory error in highly dynamic scenes by an average of 96.4% compared with ORB-SLAM3. In addition, while achieving accuracy similar to that of the classical DynaSLAM algorithm, it reduces the per-frame processing time of the tracking thread by more than 98.8%, achieving real-time performance.
Keywords: VSLAM · Highly dynamic scenes · ORB-SLAM3 · YOLOv7-tiny · Adaptive threshold
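The combination of object detection and motion consistency checking described in the abstract can be illustrated with a small sketch. It is not the ATY-SLAM implementation. The following are all assumptions made for illustration: the box format, the fundamental matrix `F` (which in practice would be estimated from the LK-tracked matches, for example with RANSAC), the 1-pixel threshold, and the rule of dropping a point only when it lies inside a dynamic-object box and also violates the epipolar constraint. The names `epipolar_distance`, `in_any_box`, and `cull_dynamic_points` are hypothetical. The geometric idea is standard: a static point seen in two frames must lie on the epipolar line l = F x₁ in the second frame, so a large point-to-line distance suggests that the point moved independently of the camera.

```python
# Minimal sketch (hypothetical, not the ATY-SLAM code): cull matched feature
# points that fall inside a detected dynamic-object box AND violate the
# epipolar constraint between the previous and current frame.
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]
Box = Tuple[float, float, float, float]  # assumed format: (x_min, y_min, x_max, y_max)


def epipolar_distance(F: Sequence[Sequence[float]], p1: Point, p2: Point) -> float:
    """Pixel distance from p2 to the epipolar line l = F @ [x1, y1, 1]."""
    x1 = (p1[0], p1[1], 1.0)
    a, b, c = [sum(F[r][k] * x1[k] for k in range(3)) for r in range(3)]
    return abs(a * p2[0] + b * p2[1] + c) / math.hypot(a, b)


def in_any_box(p: Point, boxes: List[Box]) -> bool:
    """True if p lies inside (or on the border of) any detection box."""
    return any(x0 <= p[0] <= x1 and y0 <= p[1] <= y1 for x0, y0, x1, y1 in boxes)


def cull_dynamic_points(
    matches: List[Tuple[Point, Point]],
    F: Sequence[Sequence[float]],
    dynamic_boxes: List[Box],
    max_epi_dist: float = 1.0,  # assumed threshold in pixels
) -> List[Tuple[Point, Point]]:
    """Keep each (previous, current) match unless the current point lies in a
    dynamic-object box and is farther than max_epi_dist from its epipolar line."""
    kept = []
    for p1, p2 in matches:
        moving = epipolar_distance(F, p1, p2) > max_epi_dist
        if in_any_box(p2, dynamic_boxes) and moving:
            continue  # treated as dynamic: excluded from pose estimation
        kept.append((p1, p2))
    return kept
```

With `F = [[0, 0, 0], [0, 0, -1], [0, 1, 0]]` (pure horizontal camera translation), every epipolar line is horizontal, so the distance reduces to the vertical displacement of the point. A point inside a person box that drifts vertically is dropped, whereas a point inside the same box that stays on its epipolar line (for example, on a seated person) is kept. Requiring both cues avoids discarding every point on a movable but currently static object. Other combination rules are equally plausible.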
This work was supported in part by the Key Research and Development Project of Hainan Province (ZDYF2022GXJS348, ZDYF2022SHFZ039), the Hainan Province Natural Science Foundation (623RC446) and the National Natural Science Foundation of China (62161010, 61963012). The authors would like to thank the referees for their constructive suggestions. © The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 3–14, 2023. https://doi.org/10.1007/978-981-99-4761-4_1
H. Qi et al.
1 Introduction
SLAM techniques have been extensively researched in recent years. They enable real-time map construction and sensor pose estimation in unknown environments. Among them, vision-based SLAM offers low cost and rich information, and it has therefore become an active area of research [1]. However, the static-environment assumption underlying traditional visual SLAM limits its application scenarios. The traditional random sample consensus (RANSAC) algorithm can identify feature points on dynamic objects as outliers and filter them out, but it is effective only in low-dynamic scenes with few dynamic elements. In highly dynamic scenes, dynamic objects occupy a large image area and many feature points fall on them. This significantly reduces the accuracy of traditional visual SLAM and may leave the estimated trajectories unusable. Visual SLAM in dynamic scenes has therefore become a frontier topic that attracts significant research effort [2].
1.1 Related Work
SLAM was first proposed by Smith and Cheeseman in 1986 and has been developed for over three decades. In 2007, Davison et al. proposed MonoSLAM [3], the first real-time monocular visual SLAM system, but it could only perform localization and map building offline. In the same year, Klein et al. proposed PTAM [4]. It was the first system to use nonlinear optimization as a back end and to run tracking and mapping in parallel, but it was only applicable to small scenes. In 2015, Mur-Artal et al. proposed ORB-SLAM [5], which estimates poses from ORB feature points and adopts a bag-of-words model to overcome cumulative errors. In 2021, Campos et al. proposed ORB-SLAM3 [6], a visual SLAM algorithm that supports multiple camera types.
This algorithm improves several components, such as map initialization, relocalization, loop closure detection, and keyframe selection. All of the VSLAM algorithms above assume a static environment, so moving objects in the scene degrade their localization and mapping accuracy. To solve this problem, researchers proposed motion segmentation methods based on the assumption that most image features are static. Kim et al. [7] constructed a static background model from the initial image depth and removed non-static objects in subsequent frames, but this approach was limited in filtering out pedestrians. Sun et al. [8] then proposed an RGB-D motion removal method that serves as a preprocessing module to filter out moving objects. With the rapid development of deep learning in image processing, motion segmentation methods based on a priori semantic information have been applied in dynamic SLAM systems. DS-SLAM [9], built on the ORB-SLAM2 [10] framework, uses the SegNet [11] semantic segmentation network to obtain semantic information about the scene and remove the influence of moving objects. However, performing semantic segmentation on every frame significantly increases the computational load, which makes the system unsuitable for real-time operation. The DynaSLAM [12] algorithm combines
ATY-SLAM: A Visual Semantic SLAM for Dynamic
the Mask R-CNN [13] segmentation network with multi-view geometry. It uses depth consistency to filter the dynamic feature points generated by moving objects, but it is not optimized for segmentation efficiency. To address the poor real-time performance of semantic methods in dynamic environments, Liu et al. [14] designed RDS-SLAM based on ORB-SLAM3. It adds semantic segmentation and optimization threads that run in parallel with the original threads. In addition, Fan et al. [15] designed Blitz-SLAM for indoor dynamic environments, which performs image semantic analysis with the Blitz network. However, the number and distribution of matched points in the image directly affect the accuracy of camera pose estimation. In summary, existing visual SLAM systems lack either robustness or real-time performance in highly dynamic environments. Improving both at once in highly dynamic indoor settings is therefore a pressing issue.
1.2 Paper Contributions
This paper addresses these issues by integrating an object detection network into the feature extraction process. The detection results are used to detect dynamic objects, eliminate outliers, and refine keyframe selection. The proposed algorithm improves localization accuracy and robustness while maintaining real-time performance.
1. In the front end of the proposed system, the lightweight object detection network YOLOv7-tiny is combined with motion consistency detection and the LK optical flow algorithm to detect dynamic objects in indoor scenes.
2.
An adaptive thresholding algorithm is applied to indoor highly dynamic scenes. It sets the threshold from the number of change points in the camera observation model, which improves the quality of the keyframes selected in dynamic scenes.
3. Experimental results on the TUM dynamic sequences show that our system reduces the absolute trajectory error (ATE) by an average of 96.4% compared to ORB-SLAM3, validating the effectiveness of the proposed approach.
2 System Description
In dynamic environments, moving objects degrade the localization accuracy and robustness of ORB-SLAM3. To address this issue, we introduce the YOLOv7-tiny object detection algorithm into the front end of ORB-SLAM3. It detects objects in the input image while feature points are being extracted. Once the semantic information of the image is obtained, dynamic objects are identified using motion consistency detection and the LK optical flow algorithm. The dynamic feature points are then removed based on these results. This
leaves only static feature points for pose estimation. Furthermore, we optimize keyframe selection with adaptive thresholding to obtain higher-quality keyframes. The overall system flowchart is shown in Fig. 1, where the red dashed boxes mark the improvements of this paper.
Fig. 1. Overview of our algorithm.
2.1 Dynamic Feature Point Culling Based on YOLOv7-Tiny
Object detection is a fundamental task in computer vision, and YOLOv7 [16] is one of the most widely used regression-based detectors. In this study, we adopt YOLOv7-tiny as the detection model. It consists of four main components: input, backbone, neck, and prediction head. The input stage preprocesses the images, including data augmentation, before feeding them to the backbone network, which extracts features. The neck then fuses these features into feature maps at three scales (large, medium, and small). The fused features are passed to the detection head, which produces the detection results. YOLOv7-tiny is used to detect potentially dynamic objects. Within the predicted bounding boxes of dynamic objects, dynamic feature points are determined using prior knowledge. The elimination process is as follows [17–19]. When the kth frame is input, the set of all feature points extracted by the visual odometry is denoted $F_k$:
$$F_k = \{f_1, f_2, f_3, \ldots, f_n\} \tag{1}$$
Once the image passes through the object detection network, all dynamic feature points within the predicted bounding boxes can be determined using a priori knowledge. Their set is denoted $D_k$:
$$D_k = \{d_1, d_2, d_3, \ldots, d_m\} \tag{2}$$
If $f_i \in D_k$ ($i = 1, 2, 3, \ldots, n$), the feature point $f_i$ is considered dynamic and is removed from $F_k$ in the tracking thread. The remaining feature points are quasi-static, and their set is denoted $P_k$, i.e., $P_k = F_k \setminus D_k$. Hence
$$P_k \cup D_k = F_k \tag{3}$$
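The set operations of Eqs. (1)–(3) can be sketched as follows. This is a minimal sketch under stated assumptions, not the authors' implementation. Feature points are taken to be pixel coordinates, and detections are `(label, x1, y1, x2, y2)` boxes. The prior "dynamic" class set (here just `"person"`) is a hypothetical choice. As a further simplification, every point inside a dynamic-class box goes into $D_k$, and the motion checks described next are not applied.

```python
from typing import Iterable, List, Sequence, Set, Tuple

Point = Tuple[float, float]                   # (u, v) pixel coordinates
Box = Tuple[str, float, float, float, float]  # (label, x1, y1, x2, y2)

# Hypothetical prior knowledge: object classes treated as potentially dynamic.
DYNAMIC_CLASSES: Set[str] = {"person"}


def in_box(p: Point, box: Box) -> bool:
    """True if pixel p lies inside the axis-aligned box (borders included)."""
    _, x1, y1, x2, y2 = box
    return x1 <= p[0] <= x2 and y1 <= p[1] <= y2


def split_features(features: Sequence[Point], boxes: Iterable[Box],
                   dynamic_classes: Set[str] = DYNAMIC_CLASSES
                   ) -> Tuple[List[Point], List[Point]]:
    """Split F_k into (P_k, D_k): quasi-static and dynamic feature points."""
    dyn_boxes = [b for b in boxes if b[0] in dynamic_classes]
    static: List[Point] = []
    dynamic: List[Point] = []
    for f in features:
        # A point inside any dynamic-class box belongs to D_k, otherwise to P_k.
        (dynamic if any(in_box(f, b) for b in dyn_boxes) else static).append(f)
    return static, dynamic


features = [(10, 10), (50, 60), (200, 120)]
boxes = [("person", 40, 40, 100, 100), ("chair", 150, 100, 250, 200)]
static, dynamic = split_features(features, boxes)
```

Only the person box removes points here. The point at (200, 120) lies in a chair box but stays in $P_k$, because "chair" is not in the assumed dynamic class set.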
To identify dynamic feature points, we combine the YOLOv7-tiny algorithm with motion consistency detection and the LK optical flow algorithm [9, 20–22]. First, motion consistency is checked with the epipolar geometry constraint. Following Chang et al. [20], the distance is
$$d = \frac{\left|P_2^{\top} F P_1\right|}{\sqrt{X^2 + Y^2}} \tag{4}$$
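The two motion checks of this subsection can be sketched in plain Python. These are the epipolar distance of Eq. (4) and the LK flow test of Eq. (5), which is introduced below. This is a minimal sketch, not the authors' code. $P_1$ and $P_2$ are homogeneous coordinates, and $(X, Y)$ are the first two components of the epipolar line $F P_1$. The LK system is solved by ordinary least squares over the window samples. The thresholds `eps_th` and `eps_th2` are placeholder values, not values from the paper.

```python
import math
from typing import Sequence, Tuple

Vec3 = Tuple[float, float, float]
Mat3 = Sequence[Sequence[float]]


def epipolar_distance(p1: Vec3, p2: Vec3, F: Mat3) -> float:
    """Distance from p2 to the epipolar line l1 = F p1, as in Eq. (4)."""
    # l1 = (X, Y, Z); the numerator |P2^T F P1| is |p2 . l1|.
    X, Y, Z = (sum(F[r][c] * p1[c] for c in range(3)) for r in range(3))
    return abs(p2[0] * X + p2[1] * Y + p2[2] * Z) / math.hypot(X, Y)


def lk_flow(Ix: Sequence[float], Iy: Sequence[float],
            It: Sequence[float]) -> Tuple[float, float]:
    """Least-squares (u, v) for [Ix Iy]_k [u v]^T = -It_k over all k, as in Eq. (5)."""
    a = sum(x * x for x in Ix)                 # normal matrix [[a, b], [b, c]]
    b = sum(x * y for x, y in zip(Ix, Iy))
    c = sum(y * y for y in Iy)
    d = -sum(x * t for x, t in zip(Ix, It))    # right-hand side [d, e]
    e = -sum(y * t for y, t in zip(Iy, It))
    det = a * c - b * b
    if abs(det) < 1e-12:                       # aperture problem: flow not recoverable
        return 0.0, 0.0
    return (c * d - b * e) / det, (a * e - b * d) / det


def is_dynamic(p1: Vec3, p2: Vec3, F: Mat3, Ix, Iy, It,
               eps_th: float = 1.0, eps_th2: float = 1.0) -> bool:
    """Flag a point if either the epipolar or the optical-flow check fires."""
    u, v = lk_flow(Ix, Iy, It)
    return epipolar_distance(p1, p2, F) > eps_th or math.hypot(u, v) > eps_th2


# Pure x-translation (normalized coordinates): epipolar lines are horizontal.
F = [[0, 0, 0], [0, 0, -1], [0, 1, 0]]
Ix, Iy = [1, 0, 1, 2], [0, 1, 1, -1]
It = [-(gx * 1.5 + gy * -0.5) for gx, gy in zip(Ix, Iy)]  # synthetic flow (1.5, -0.5)
```

With this $F$, a match that keeps its image row has $d = 0$. The synthetic gradients are built so that the least-squares flow comes out as exactly (1.5, -0.5).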
Assume that a point $P$ in the world coordinate system is projected onto the imaging planes of two cameras, yielding a pair of matched feature points $P_1$ and $P_2$ (in homogeneous coordinates). $F$ is the corresponding fundamental matrix, and $l_1 = F P_1 = [X, Y, Z]^{\top}$ is the epipolar line corresponding to $P_1$, so $X$ and $Y$ are the first two components of this line vector. $d$ is the distance from $P_2$ to the epipolar line $l_1$. If $d$ exceeds a threshold $\varepsilon_{th}$, the point is considered dynamic. Motion estimation from this check can be wrong when the camera and a dynamic object move in parallel, so we additionally employ the LK optical flow method. The image intensity is treated as a function of time, $I(x, y, t)$, where $(x, y)$ is the position of an ORB feature point in the image at time $t$. The brightness constancy assumption of optical flow gives [20]:
$$\begin{bmatrix} I_x & I_y \end{bmatrix}_k \begin{bmatrix} u \\ v \end{bmatrix} = -I_{t_k}, \quad k = 1, \ldots, 36 \tag{5}$$
Here $k$ indexes the 36 pixels of the local window around the feature point, and the system is solved in the least-squares sense. $u$ and $v$ are the velocities of the feature point in the $X$ and $Y$ directions. $I_x$ and $I_y$ are the image gradients at the corresponding pixel in the $X$ and $Y$ directions, and $I_t$ is the temporal change in intensity. If the optical flow magnitude of a feature point exceeds a threshold $\varepsilon_{th2}$, the point is identified as dynamic.
2.2 Keyframe Optimization Method Based on Adaptive Thresholding
Hosseininaveh et al. [23] generated high-accuracy, high-density point clouds using a camera observation model, and Azimi et al. [24] optimized keyframe selection using a camera observation model. However, these studies addressed only keyframes in static scenes, using monocular and stereo cameras and an IMU. Building on [24], we study this keyframe selection method in indoor highly dynamic scenes.
In each frame containing dynamic objects, the dynamic feature points are first eliminated with YOLOv7-tiny. The adaptive thresholding method then further reduces the influence of indoor dynamic objects on localization.
The camera observation model is abstracted as a 40° cone divided into four regions of 10° each. Keyframe selection involves the last keyframe, a reference frame, and the current frame. The reference frame is the frame most similar and closest to the last keyframe. Each frame tracks the feature points that match feature points of the last keyframe. Each matched point is connected to the cone vertex to form a view vector. The vector is assigned to one of the four regions according to its angle to the main axis of the cone. As the camera moves from the last keyframe to the current frame, the observation angles change and view vectors move between regions. The adaptive threshold models the number of these change points and decides whether the current frame is selected as a keyframe. Its initial value $T_{initial}$ is defined as
$$T_{initial} = \frac{\frac{N_1}{N_2} N_3 + \frac{N_4}{N_5} N_3}{2} \tag{6}$$
where:
- $N_1$ is the number of matching points between the current frame and the last keyframe;
- $N_2$ is the number of matching feature points between the reference frame and the last keyframe;
- $N_3$ is the number of feature points whose cone region changes between the reference frame and the last keyframe;
- $N_4$ is the number of all feature points in the current frame;
- $N_5$ is the number of all feature points in the reference frame.
The adaptive threshold is defined as
$$T_{adaptive} = \begin{cases} (1 + 0.1)\,T_{initial}, & \theta \le 0.1 \\ (1 + r\theta - \lambda - \varepsilon - \delta)\,T_{initial}, & \text{otherwise} \end{cases} \tag{7}$$
Here $\theta$ measures the relative decrease in the number of feature point matches from the reference frame to the current frame. The coefficient $r$ reduces the probability that the two frames following the last keyframe are selected as keyframes. It also ensures that another keyframe is selected once six frames have passed since the last keyframe.
$r$ acts as an adaptive factor multiplying $\theta$ to ensure a proper spatial distribution of keyframes. The remaining three coefficients each narrow the threshold:
- $\lambda$ handles cone regions that are prone to change but fail to exceed the threshold.
- $\varepsilon$ keeps the adaptive threshold from failing when the number of cone-region change points exceeds half of the matching points.
- $\delta$ keeps the adaptive threshold from failing when half of the number of matching points between the reference frame and the last keyframe exceeds the number of matching points between the current frame and the last keyframe.

The coefficients are defined as
$$\begin{cases} \theta = \dfrac{N_2 - N_1}{N_2} \\[4pt] r = \dfrac{6 - id}{3} \\[4pt] \lambda = \dfrac{N_6}{N_1} - \dfrac{N_3}{N_2} \\[4pt] \varepsilon = \dfrac{N_3}{N_1} - 0.5 \\[4pt] \delta = \dfrac{N_2 - 2N_1}{N_2 + 2N_1} \end{cases} \tag{8}$$
where $id$ in $r$ is the difference between the frame ids of the current frame and the last keyframe, and $N_6$ in $\lambda$ is the number of change points in the cone regions between the current frame and the last keyframe.
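Eqs. (6)–(8), together with the cone-region counting that produces $N_3$ and $N_6$, can be sketched as below. This is a hedged sketch under several assumptions. Points are 3D coordinates in the camera frame, with the cone's main axis taken as the optical (z) axis. The four regions are read as 10° bands of the angle to that axis, from 0° to 40°. A point leaving the cone counts as a change. The text does not spell out the final comparison between $T_{adaptive}$ and the observed changes, so no keyframe decision rule is included.

```python
import math
from typing import Dict, Optional, Tuple

Point3 = Tuple[float, float, float]  # camera coordinates; z is the main (optical) axis


def cone_region(p: Point3, cone_deg: float = 40.0, band_deg: float = 10.0) -> Optional[int]:
    """Index 0-3 of the 10-degree band containing the view vector, or None if outside."""
    x, y, z = p
    angle = math.degrees(math.atan2(math.hypot(x, y), z))  # angle to the main axis
    return None if angle >= cone_deg else int(angle // band_deg)


def count_changes(a: Dict[int, Point3], b: Dict[int, Point3]) -> int:
    """Matched points (same id in both frames) whose cone region differs."""
    return sum(1 for pid in a.keys() & b.keys()
               if cone_region(a[pid]) != cone_region(b[pid]))


def initial_threshold(n1: int, n2: int, n3: int, n4: int, n5: int) -> float:
    """Eq. (6): T_initial."""
    return (n1 / n2 * n3 + n4 / n5 * n3) / 2


def adaptive_threshold(n1: int, n2: int, n3: int, n4: int, n5: int,
                       n6: int, frame_gap: int) -> float:
    """Eqs. (7)-(8); frame_gap plays the role of `id` in the r coefficient."""
    t_init = initial_threshold(n1, n2, n3, n4, n5)
    theta = (n2 - n1) / n2
    if theta <= 0.1:
        return (1 + 0.1) * t_init
    r = (6 - frame_gap) / 3
    lam = n6 / n1 - n3 / n2
    eps = n3 / n1 - 0.5
    delta = (n2 - 2 * n1) / (n2 + 2 * n1)
    return (1 + r * theta - lam - eps - delta) * t_init
```

Under these assumptions, `count_changes(reference, last_keyframe)` would give $N_3$ and `count_changes(current, last_keyframe)` would give $N_6$.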
3 Experimental Results and Analysis
3.1 Datasets and Experimental Environment
The proposed SLAM algorithm is evaluated on the RGB-D dataset provided by the Technical University of Munich (TUM). This dataset is widely used in visual SLAM research because it provides ground-truth trajectories from a high-precision motion capture system. The system uses multiple high-speed cameras with inertial measurement systems to obtain real-time camera pose data. Five sequences are selected for testing, covering both low-dynamic and high-dynamic scenes. The fr3_sitting sequence is a low-dynamic scene in which people sit on chairs with no visible movement. The fr3_walking sequences are highly dynamic scenes in which people constantly walk through the scene. The object detection network was trained on a server equipped with an Intel Core i9-10900X CPU and an NVIDIA GeForce RTX 3080 GPU. All other experiments were conducted on a PC running Ubuntu 18.04 with an Intel Core i5-11320H CPU, an NVIDIA GeForce MX450 GPU, and 16 GB of RAM. For quantitative evaluation, we use the absolute trajectory error (ATE) [25], a common evaluation criterion in VSLAM. The error of our system is compared with that of ORB-SLAM3, and the relative improvement rate $\eta$ is calculated as
$$\eta = \frac{E_{\text{ORB-SLAM3}} - E_{\text{ours}}}{E_{\text{ORB-SLAM3}}} \times 100\% \tag{9}$$
where $E$ denotes the corresponding error statistic (RMSE, mean, median, or standard deviation).
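As a worked check of Eq. (9), the sketch below applies it to the RMSE values reported in Table 1 for ORB-SLAM3 and our algorithm. Averaging the rates over the four walking sequences reproduces the 96.4% figure quoted in the abstract. The dictionary and function names are illustrative only.

```python
# RMSE values (m) from Table 1, as (ORB-SLAM3, ours).
RMSE = {
    "walking_static": (0.3598, 0.0087),
    "walking_xyz": (0.6426, 0.0155),
    "walking_rpy": (0.8210, 0.0431),
    "walking_half": (0.5466, 0.0233),
    "sitting_xyz": (0.0093, 0.0121),
}


def improvement(baseline: float, ours: float) -> float:
    """Relative improvement rate eta of Eq. (9), in percent."""
    return (baseline - ours) / baseline * 100.0


rates = {seq: improvement(*vals) for seq, vals in RMSE.items()}
walking = [eta for seq, eta in rates.items() if seq.startswith("walking")]
mean_walking = sum(walking) / len(walking)

if __name__ == "__main__":
    for seq, eta in rates.items():
        print(f"{seq:15s} {eta:7.2f}%")
    print(f"mean over walking sequences: {mean_walking:.1f}%")
```

A negative rate, as on sitting_xyz, means our error is larger than that of ORB-SLAM3.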
3.2 Experimental Results
The experiments were conducted on five TUM sequences. The results were compared with those of ORB-SLAM3 [6], DynaSLAM [12], DS-SLAM [9], DVO-SLAM [26], and OFD-SLAM [27]. Table 1 compares our algorithm with ORB-SLAM3 in terms of absolute trajectory error, reporting the RMSE, mean error, median error, and standard deviation (STD). On the four highly dynamic walking sequences, our algorithm reduces the RMSE of ORB-SLAM3 by 94.75%–97.59%. On the low-dynamic sitting_xyz sequence it is less accurate, with an RMSE 30.11% higher. This indicates that ORB-SLAM3 already performs well without interference from dynamic objects. ATY-SLAM, in contrast, effectively addresses the low pose estimation accuracy caused by moving objects in dynamic environments.

Table 1. Comparison of absolute trajectory error between ORB-SLAM3 and our algorithm (m)
| Sequences | ORB-SLAM3 RMSE | ORB-SLAM3 Mean | ORB-SLAM3 Median | ORB-SLAM3 STD | Ours RMSE | Ours Mean | Ours Median | Ours STD | Improvement RMSE | Improvement Mean | Improvement Median | Improvement STD |
| walking_static | 0.3598 | 0.3256 | 0.3173 | 0.1524 | 0.0087 | 0.0082 | 0.0066 | 0.0066 | 97.58% | 97.48% | 97.92% | 95.67% |
| walking_xyz | 0.6426 | 0.5534 | 0.4824 | 0.3251 | 0.0155 | 0.0131 | 0.0113 | 0.0084 | 97.59% | 97.63% | 97.66% | 97.42% |
| walking_rpy | 0.8210 | 0.7378 | 0.7168 | 0.3558 | 0.0431 | 0.0279 | 0.0201 | 0.0329 | 94.75% | 96.22% | 97.20% | 90.75% |
| walking_half | 0.5466 | 0.5086 | 0.4889 | 0.1940 | 0.0233 | 0.0176 | 0.0146 | 0.0152 | 95.74% | 96.54% | 97.01% | 92.16% |
| sitting_xyz | 0.0093 | 0.0081 | 0.0073 | 0.0045 | 0.0121 | 0.0102 | 0.0089 | 0.0065 | –30.11% | –25.93% | –21.92% | –44.44% |

To verify the effectiveness of each module, we conduct ablation experiments on the YOLOv7-tiny-based dynamic feature point culling module and the adaptive thresholding module. Table 2 shows the results, where "1" means the module is enabled and "0" means it is removed. The evaluation criterion is the RMSE. On every sequence, the RMSE is smallest when both modules are used together.

Table 2. Ablation experiment (RMSE, m)
| YOLOv7-tiny | Adaptive Threshold | walking_xyz | walking_half | walking_static | walking_rpy |
| 1 | 1 | 0.0155 | 0.0233 | 0.0087 | 0.0431 |
| 1 | 0 | 0.0158 | 0.0239 | 0.0113 | 0.0535 |
| 0 | 1 | 0.3186 | 0.2654 | 0.0609 | 0.3476 |
| 0 | 0 | 0.6426 | 0.5466 | 0.3598 | 0.8210 |

Table 3 compares the ATY-SLAM algorithm with other SLAM algorithms in dynamic environments. DynaSLAM and ATY-SLAM achieve the highest localization accuracy. DynaSLAM is slightly more accurate on some sequences because its Mask R-CNN network segments objects pixel by pixel. This is more precise than the bounding boxes produced by the object detection model used in this paper. However, the lightweight YOLOv7-tiny network used in this paper runs much faster than DynaSLAM.

Table 3. The absolute trajectory error of different algorithms (m)
| Sequences | ORB-SLAM3 [6] | DynaSLAM [12] | DS-SLAM [9] | DVO-SLAM [26] | OFD-SLAM [27] | Ours |
| walking_static | 0.3598 | 0.0090 | 0.0081 | – | – | 0.0087 |
| walking_xyz | 0.6426 | 0.0150 | 0.0247 | 0.5966 | 0.0189 | 0.0155 |
| walking_rpy | 0.8210 | 0.0400 | 0.4442 | 0.7304 | 0.1533 | 0.0431 |
| walking_half | 0.5466 | 0.0250 | 0.0303 | 0.5287 | 0.1612 | 0.0233 |
| sitting_xyz | 0.0093 | 0.0140 | – | 0.0505 | – | 0.0121 |

To further evaluate real-time performance, we measured the time the tracking thread needs to process each frame of the walking_xyz sequence, for both the proposed algorithm and DynaSLAM (Table 4). DynaSLAM has poor real-time performance, whereas ATY-SLAM achieves a good balance between accuracy and real-time performance. ATY-SLAM therefore effectively mitigates the impact of moving objects on the stability of SLAM systems in dynamic environments.

Table 4. Tracking time comparison (ms)
| Algorithm | Time |
| DynaSLAM [12] | 2452 |
| Ours | 29 |
| Improvements | 98.81% |
Fig. 2. Comparison of estimated trajectories and real trajectories in highly dynamic environments.
Figure 2 shows the trajectories and error distributions estimated by ORB-SLAM3, DynaSLAM, and our algorithm. The black line represents the ground-truth trajectory of the dataset, the blue line the trajectory estimated by the system, and the red line the error between the two. The trajectory estimated by our algorithm is as close to the ground truth as that of DynaSLAM. Compared to ORB-SLAM3, the error values are reduced by about an order of magnitude.
3.3 Discussion
ATY-SLAM outperforms most dynamic visual SLAM algorithms on the highly dynamic TUM sequences, and DynaSLAM and ATY-SLAM achieve the highest localization accuracy. In highly dynamic scenes, ATY-SLAM reduces the absolute trajectory error by an average of 96.4% compared to ORB-SLAM3. Moreover, its tracking thread takes over 98.8% less time per frame than that of DynaSLAM. DynaSLAM achieves comparable, and on some sequences slightly better, localization accuracy. ATY-SLAM thus substantially improves real-time performance while also improving robustness and accuracy.
Nevertheless, on the low-dynamic TUM sequence the localization accuracy of the algorithm is noticeably lower than that of ORB-SLAM3. This is a common limitation of most visual SLAM algorithms designed for dynamic scenes. Future research should therefore develop strategies that improve accuracy in both highly and lowly dynamic scenes.
4 Conclusions
In dynamic indoor environments, visual SLAM systems are often affected by moving objects, which can lead to tracking failures and reduced pose estimation accuracy. To address this issue and improve localization accuracy, this paper integrates YOLOv7-tiny, an efficient object detection method based on deep neural networks, with ORB-SLAM3. YOLOv7-tiny identifies potentially moving objects. Motion consistency detection and the LK optical flow algorithm are then applied within the predicted bounding boxes of dynamic targets to eliminate dynamic objects in the scene, improving the accuracy and robustness of visual SLAM in dynamic environments. In addition, an adaptive thresholding method is used during keyframe selection to ensure keyframe quality in dynamic environments. On the TUM dataset, the proposed algorithm reduced the absolute trajectory error in highly dynamic scenes by an average of 96.4% compared to ORB-SLAM3. With accuracy similar to the classical DynaSLAM, it also reduced the per-frame processing time of the tracking thread by over 98.8%. In summary, this paper improves the robustness and accuracy of visual SLAM in highly dynamic indoor scenarios. In future work, the algorithm model can be optimized further to improve localization accuracy in indoor dynamic scenes. The semantic information extracted by object detection should also be used to construct a semantic map.
Acknowledgements. We thank Shenzhen Umouse Technology Development Co., Ltd. (HDKYH-2021307) for their support with equipment and experimental conditions.
References
1. Macario Barros, A., Michel, M., Moline, Y., Corre, G., Carrel, F.: A comprehensive survey of visual SLAM algorithms. Robotics 11(1), 24 (2022)
2. Wen, S., Li, P., Zhao, Y., Zhang, H., Sun, F., Wang, Z.: Semantic visual SLAM in dynamic environment. Auton. Robot. 45(4), 493–504 (2021)
3. Davison, A.J., Reid, I.D., Molton, N.D., Stasse, O.: MonoSLAM: real-time single camera SLAM. IEEE Trans. Pattern Anal. Mach. Intell. 29(6), 1052–1067 (2007)
4. Klein, G., Murray, D.: Parallel tracking and mapping for small AR workspaces. In: 2007 6th IEEE and ACM International Symposium on Mixed and Augmented Reality, pp. 225–234. IEEE (2007)
5. Mur-Artal, R., Montiel, J.M.M., Tardós, J.D.: ORB-SLAM: a versatile and accurate monocular SLAM system. IEEE Trans. Rob. 31(5), 1147–1163 (2015)
6. Campos, C., Elvira, R., Rodríguez, J.J.G., Montiel, J.M., Tardós, J.D.: ORB-SLAM3: an accurate open-source library for visual, visual–inertial, and multimap SLAM. IEEE Trans. Rob. 37(6), 1874–1890 (2021)
7. Kim, D.H., Kim, J.H.: Effective background model-based RGB-D dense visual odometry in a dynamic environment. IEEE Trans. Rob. 32(6), 1565–1573 (2016)
8. Sun, Y., Liu, M., Meng, M.Q.H.: Improving RGB-D SLAM in dynamic environments: a motion removal approach. Robot. Auton. Syst. 89, 110–122 (2017)
9. Yu, C., et al.: DS-SLAM: a semantic visual SLAM towards dynamic environments. In: 2018 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS), pp. 1168–1174. IEEE (2018)
10. Mur-Artal, R., Tardós, J.D.: ORB-SLAM2: an open-source SLAM system for monocular, stereo, and RGB-D cameras. IEEE Trans. Rob. 33(5), 1255–1262 (2017)
11. Badrinarayanan, V., Kendall, A., Cipolla, R.: SegNet: a deep convolutional encoder-decoder architecture for image segmentation. IEEE Trans. Pattern Anal. Mach. Intell. 39(12), 2481–2495 (2017)
12. Bescos, B., Fácil, J.M., Civera, J., Neira, J.: DynaSLAM: tracking, mapping, and inpainting in dynamic scenes. IEEE Robot. Autom. Lett. 3(4), 4076–4083 (2018)
13. He, K., Gkioxari, G., Dollár, P., Girshick, R.: Mask R-CNN. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2961–2969 (2017)
14. Liu, Y., Miura, J.: RDS-SLAM: real-time dynamic SLAM using semantic segmentation methods. IEEE Access 9, 23772–23785 (2021)
15. Fan, Y., Zhang, Q., Tang, Y., Liu, S., Han, H.: Blitz-SLAM: a semantic SLAM in dynamic environments. Pattern Recogn. 121, 108225 (2022)
16. Wang, C.Y., Bochkovskiy, A., Liao, H.Y.M.: YOLOv7: trainable bag-of-freebies sets new state-of-the-art for real-time object detectors. arXiv preprint arXiv:2207.02696 (2022)
17. Ai, Y., Rui, T., Lu, M., Fu, L., Liu, S., Wang, S.: DDL-SLAM: a robust RGB-D SLAM in dynamic environments combined with deep learning. IEEE Access 8, 162335–162342 (2020)
18. Fan, Y., et al.: Semantic SLAM with more accurate point cloud map in dynamic environments. IEEE Access 8, 112237–112252 (2020)
19. Han, S., Xi, Z.: Dynamic scene semantics SLAM based on semantic segmentation. IEEE Access 8, 43563–43570 (2020)
20. Chang, Z., Wu, H., Sun, Y., Li, C.: RGB-D visual SLAM based on YOLOv4-tiny in indoor dynamic environment. Micromachines 13(2), 230 (2022)
21. Kang, R., Shi, J., Li, X., Liu, Y., Liu, X.: DF-SLAM: a deep-learning enhanced visual SLAM system based on deep local features. arXiv preprint arXiv:1901.07223 (2019)
22. Xiao, L., Wang, J., Qiu, X., Rong, Z., Zou, X.: Dynamic-SLAM: semantic monocular visual localization and mapping based on deep learning in dynamic environment. Robot. Auton. Syst. 117, 1–16 (2019)
23. Hosseininaveh, A., Remondino, F.: An imaging network design for UGV-based 3D reconstruction of buildings. Remote Sens. 13(10), 1923 (2021)
24. Azimi, A., Ahmadabadian, A.H., Remondino, F.: PKS: a photogrammetric keyframe selection method for visual-inertial systems built on ORB-SLAM3. ISPRS J. Photogrammetry Remote Sens. 191, 18–32 (2022)
25. Sturm, J., Engelhard, N., Endres, F., Burgard, W., Cremers, D.: A benchmark for the evaluation of RGB-D SLAM systems. In: 2012 IEEE/RSJ International Conference on Intelligent Robots and Systems, pp. 573–580. IEEE (2012)
26. Kerl, C., Sturm, J., Cremers, D.: Dense visual SLAM for RGB-D cameras. In: 2013 IEEE/RSJ International Conference on Intelligent Robots and Systems, pp. 2100–2106. IEEE (2013)
27. Cheng, J., Sun, Y., Meng, M.Q.H.: Improving monocular visual SLAM in dynamic environments: an optical-flow-based approach. Adv. Robot. 33(12), 576–589 (2019)
Minimizing Peak Memory Footprint of Inference on IoTs Devices by Efficient Recomputation Xiaofeng Sun1 , Chaonong Xu1(B) , and Chao Li2 1 China University of Petroleum, Beijing 102249, China
[email protected] 2 Zhejiang Lab, Hangzhou 311121, China
Abstract. Deploying Deep Neural Networks (DNNs) on tiny IoT devices brings great convenience to daily life. However, the small data memory of tiny IoT devices based on Micro-Controller Units (MCUs) poses a great challenge for DNN inference. If the data memory of an MCU is smaller than the peak memory requirement of inference for a DNN, the DNN cannot be deployed on that MCU. In this paper, we introduce recomputation into DNN inference to reduce the data memory footprint. On one hand, we save the activations of some non-divergence nodes instead of those of divergence nodes. On the other hand, we restore the necessary activations by recomputation. Together these achieve a smaller peak memory footprint, which resolves the insufficient-data-memory challenge above. We propose an efficient inference scheduling algorithm based on recomputation and evaluate it on popular DNNs including InceptionV1, InceptionV3, and ResNet. Relative to TensorFlow Lite Micro, a widely used framework for intelligent applications on MCUs, our algorithm achieves a 1.07x−1.25x reduction in peak memory footprint, which is 0.01x−0.23x higher than that of the state-of-the-art algorithm.
Keywords: Inference · Recomputation · Neural Network · MCU
1 Introduction
According to relevant statistics, the number of tiny IoT devices based on MCUs, such as printers, smart cameras, and industrial sensors, has reached 250 billion. This large number shows their irreplaceable role in daily life. Deep Neural Networks (DNNs), which provide intelligent processing capability, therefore have broad application prospects in IoT. In other words, AIoT (Artificial Intelligence of Things) will be the focus of the next decade. Currently, developing NN applications on IoT devices typically involves two steps:
(1) training models on devices with high computing power, such as high-performance servers or clusters, to obtain a weighted computation graph;
(2) deploying the trained models on IoT devices and performing inference based on the computation graph.
The feasibility of inference on IoT devices is thus challenged by the gap between the memory a model requires and the memory capacity available. For tiny IoT devices, the challenge is
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 15–26, 2023. https://doi.org/10.1007/978-981-99-4761-4_2
X. Sun et al.
more severe. Many techniques, such as network pruning, low-bit quantization, and knowledge distillation, have been proposed to close this gap. However, they all modify the network model to some extent and may cause accuracy loss. Figure 1 illustrates the peak and average memory requirements for inference of ResNet-18 [1] and InceptionV1 [2], with input sizes of 42 × 42 × 3 and 86 × 86 × 3, respectively. We focus on peak rather than average memory: if the data memory of a hardware platform is smaller than the peak memory requirement of a DNN, the DNN cannot be deployed on that platform. In our experiments, the average memory requirement¹ is small, but the peak memory footprint far exceeds the data memory of a typical MCU². This gap poses a great challenge for developing applications on memory-constrained AIoT devices.
Fig. 1. Comparison of peak memory footprint and average memory footprint.
Analyzing the memory footprint during inference, we found that the activations of divergence operators³ must be stored until all of their successor operators have been executed. This results in large memory requirements, and it raises the question of whether the peak memory footprint can be reduced by freeing the activations of divergence operators. We employ the idea of recomputation: we save the activations of non-divergence nodes instead of divergence nodes, and we restore the latter when needed by recomputing them from the saved ones. Figure 2 gives a simple example of the recomputation approach. Figure 2(a) illustrates the inference scheduling sequence produced by the classic topological sorting algorithm. For every operator in the computation graph, the number in the red box is its scheduling index; for example, operator L1 is executed at the first step. Since L1 is a divergence node, its output l1 must be held in memory until step 5 for correct inference. The red line in Fig. 2(c) shows the memory footprint at each step. The peak, 30 memory units, is reached at step 3 and comprises the outputs of L1, L2, and L3. In comparison, Fig. 2(b) illustrates the inference scheduling sequence based on our recomputation strategy. Its
¹ The required average memory is the average size of the SRAM occupied while running each operator of a computation graph.
² For example, the popular MCU STM32F7 has only 512 kB of SRAM.
³ For example, L1 in Fig. 2(a) is a divergence node.
Minimizing Peak Memory Footprint of Inference on IoTs Devices
In Fig. 2(b), the operator scheduling sequence is as follows: we hold the input of L1, i.e., l0, discard l1 after step 2, and recompute l1 from l0 at step 4. In this way, the peak memory footprint is reached at step 2 and equals 23 memory units, comprising the outputs of L1 and L2 and l0. The memory footprint at each step is shown by the blue line in Fig. 2(c). These results demonstrate that the recomputation strategy can reduce the peak memory footprint of inference. Finding the inference scheduling sequence that minimizes peak memory is already a complex problem, and allowing recomputation further enlarges the search space. We solve the problem in four steps: (1) find a basic scheduling sequence using an ordinary scheduling algorithm based on topological sorting; (2) find all unit recomputation strategies (URSs, defined in Sect. 3.2) in the computation graph; (3) combine these URSs by constructing their power set and, for every combination, construct an inference scheduling strategy heuristically and compute its peak memory; (4) compare the peak memory footprints of all combinations exhaustively and select the optimal inference scheduling sequence.
Fig. 2. Example of recomputation on a CNN computation graph. The red box beside each operator denotes its scheduling sequence in (a) and (b), and the number on each edge represents the size of the activations. For any operator, its input and output space must be allocated in data memory for executing the operator. For instance, the data memory required for L6 is 20 (8 + 12) memory units. (a) and (b) are two distinct scheduling strategies, while (c) illustrates their memory footprints according to their scheduling sequence.
Our contributions can be summarized as follows:
1. To our knowledge, we are the first to introduce the idea of recomputation into model inference⁴. Although this method slightly increases the computation burden, it effectively reduces the memory footprint without any compromise on inference accuracy.
2. We develop a novel and efficient inference scheduling algorithm based on the idea of recomputation.

⁴ The recomputation technique has recently been proposed and used in model training. It consists of two steps: (1) discarding some results computed during the forward phase once they are no longer needed in that phase; (2) recomputing these results from the retained activations when they are needed during the backward phase.
X. Sun et al.
3. We conduct a series of real performance evaluations on typical DNNs. Compared to TensorFlow Lite Micro (TFLM), our proposed algorithm achieves a 1.07x to 1.25x reduction in peak memory footprint, which is 0.01x to 0.23x higher than the reduction achieved by the state-of-the-art method SERENITY.
2 Related Work
The process of scheduling a DNN model can be subdivided into two distinct levels: inter-operator scheduling and intra-operator scheduling. The latter is primarily concerned with exploiting parallelism within an operator, such as parallel matrix multiplication, while inter-operator scheduling focuses on exploiting parallelism among operators. IOS [6] and Rammer [7] employ both intra- and inter-operator parallelism to speed up model inference. In contrast, our approach focuses on reducing the peak memory of inference rather than on speed-up through parallelism. Peak memory reduction is currently a hot topic in AIoT. In MCUNet [8] and MCUNetV2 [9], Lin et al. proposed TinyEngine and patch-based inference, which change the computation mode of operators to address insufficient memory. Their focus is on operator implementation, whereas we focus on graph-level optimization. SERENITY [5] and HMCOS [3] also employ inter-operator scheduling to minimize the peak memory footprint: SERENITY designs a dynamic-programming-based algorithm to find the optimal topological order, while HMCOS uses a hierarchical scheduling algorithm for networks found by neural architecture search. However, both SERENITY and HMCOS execute each operator exactly once, which restricts the scheduling space. Our research aims to overcome this limitation and finds peak-memory-friendly schedules through recomputation. The concept of recomputation was first introduced in DNN training by [4], which viewed it as a trade-off between memory and computation. Subsequent research has further explored recomputation in DNN training, focusing on which activations generated during training should be saved, using techniques such as tree decomposition [10] or integer programming [11].
We, however, pioneer the application of recomputation to DNN inference, providing an alternative way to reduce the peak memory footprint by allowing recomputation instead of one-shot computation.
3 Preliminaries
In this section, we present the fundamental concepts and definitions needed to understand our work.
3.1 Computation Graph and Inference Schedule
A DNN can be defined by a computation graph G = (V, E), where V is the set of operators and E is the set of edges representing dependencies. A computation graph is a directed acyclic graph. Each operator in the graph describes a single operation, such as convolution or concatenation. An intermediate tensor that is an output of operator i and an input of operator j is represented by edge (i, j). Figure 2(a) is the computation graph of a simple DNN. We define an inference schedule sequence H of a computation graph G as $H = \langle v_1, v_2, \ldots, v_k \rangle$, where each $v_i$ is an operator in V, $k \geq |V|$, and $v_i$ specifies the operator executed at step i. Note that k may exceed |V| because recomputation executes some operators more than once. For example, the schedule for Fig. 2(a) is H = ⟨L1, L2, L3, L4, L5, L6, L7, L8⟩. A valid inference schedule must respect the topological relationships of the computation graph, i.e., every operator runs only after all of its input tensors are available.
3.2 Unit Recomputation Strategy
The recomputation strategy involves recomputed tensors and their corresponding predecessor tensors. To describe this relationship, we define a unit recomputation strategy (URS), denoted $(\overrightarrow{e_p e_s})$, as follows. For a URS $(\overrightarrow{e_p e_s})$, $e_p$ and $e_s$ are edges of the computation graph, $e_p$ precedes $e_s$ in the topological order of the graph, the size of $e_p$ is smaller than that of $e_s$, and no subpath of it is a URS. A URS $(\overrightarrow{e_p e_s})$ is in effect a memory reduction plan that does not keep $e_s$ in memory and instead obtains it by recomputing from $e_p$. In Fig. 2(b), $(\overrightarrow{l_0 l_1})$ is a URS. The indecomposability of a URS makes it the building block of our heuristic algorithm, introduced in Sect. 4.
3.3 Problem Formulation
Inference of a DNN under a scheduling sequence consists of multiple sequential steps, so the peak memory is the maximum memory footprint over all steps. Because the computation graph of a DNN is available, we can calculate the memory footprint of each operator, and the memory footprint of each step during inference can thus be calculated using formulas (2) to (4). The memory footprint at step t has two parts: (1) the memory holding tensors that will be used by inference or recomputation in later steps, together with the input tensors of step t, i.e., $I_{t-1}$ in formula (2); and (2) the pre-allocated buffer for the output tensor of step t, i.e., the right-hand term in formula (2).
After step t, the memory occupied by tensors that will not be used in subsequent steps is released, as described in formulas (3) and (4). We define $I_0$ as the size of the input of the computation graph G. We denote the inference scheduling sequence by H, and |H| is its number of inference steps. Based on the above analysis, we obtain the memory footprint at each step, i.e., $M_1, \ldots, M_{|H|}$. The problem of finding the minimum peak memory can be formulated as follows:

$$\min_{H} \max(M_1, \ldots, M_{|H|}) \tag{1}$$

s.t. for every step t,

$$M_t = I_{t-1} + \mathrm{mem\_usg}(\text{output of } v_t) \tag{2}$$

$$I_t = M_t - \sum_{e_i \in \mathrm{free}(t)} \mathrm{mem\_usg}(e_i), \quad I_0 = \text{the size of the input} \tag{3}$$

$$\mathrm{free}(t) = \{\text{tensors to be freed after step } t\} \tag{4}$$

where $v_t$ is the operator at step t of the inference schedule sequence H, the function mem_usg() gives the memory size needed to store a tensor, and free(t) is the set of tensors to be freed after step t, which can be derived from the inference schedule sequence and the computation graph.
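To make formulas (1) to (4) concrete, here is a minimal Python sketch of the per-step memory model. It assumes that each operator produces exactly one output tensor and interprets free(t) as the set of tensors that are not read again before they are recomputed. The toy graph, its tensor sizes, and the function names `step_footprints` and `next_use_is_read` are hypothetical and are not taken from Fig. 2.

```python
from typing import Dict, List, Sequence


def next_use_is_read(tensor: str, later_ops: Sequence[str],
                     op_inputs: Dict[str, List[str]],
                     op_output: Dict[str, str]) -> bool:
    """True if `tensor` is read by a later step before it is recomputed (hypothetical helper)."""
    for op in later_ops:
        if tensor in op_inputs[op]:
            return True
        if op_output[op] == tensor:  # recomputed before its next read, so it can be freed now
            return False
    return False  # never used again


def step_footprints(schedule: Sequence[str],
                    op_inputs: Dict[str, List[str]],
                    op_output: Dict[str, str],
                    size: Dict[str, int],
                    graph_input: str) -> List[int]:
    """Return [M_1, ..., M_|H|] for a schedule H that may repeat operators (sketch only)."""
    live = {graph_input}  # tensors held in memory; I_0 = size of the graph input
    footprints = []
    for t, op in enumerate(schedule):
        missing = [e for e in op_inputs[op] if e not in live]
        if missing:  # H must respect the dependencies of the computation graph
            raise ValueError(f"step {t + 1}: {op} needs {missing}")
        i_prev = sum(size[e] for e in live)               # I_{t-1}
        footprints.append(i_prev + size[op_output[op]])   # formula (2)
        live.add(op_output[op])
        later = schedule[t + 1:]
        # Formulas (3)-(4): drop tensors that are not read again before being recomputed.
        live = {e for e in live if next_use_is_read(e, later, op_inputs, op_output)}
    return footprints


# Hypothetical toy graph (not Fig. 2): A: x -> a, B: a -> b, C: b -> c, D: (a, c) -> d.
op_inputs = {"A": ["x"], "B": ["a"], "C": ["b"], "D": ["a", "c"]}
op_output = {"A": "a", "B": "b", "C": "c", "D": "d"}
size = {"x": 2, "a": 10, "b": 6, "c": 6, "d": 3}

basic = step_footprints(["A", "B", "C", "D"], op_inputs, op_output, size, "x")
recomp = step_footprints(["A", "B", "C", "A", "D"], op_inputs, op_output, size, "x")
print(basic, max(basic))    # [12, 16, 22, 19] 22
print(recomp, max(recomp))  # [12, 18, 14, 18, 19] 19
```

In this toy graph, a is the output of a divergence operator; keeping the small input x and recomputing a just before D lowers the peak from 22 to 19 units, which mirrors the effect described for Fig. 2.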
4 Algorithm Design
Finding the optimal inference scheduling sequence with recomputation is difficult even with brute-force search. Fortunately, we find that a computation graph contains many URSs. As discussed above, a URS is an indecomposable candidate memory reduction plan. Therefore, for a computation graph, we can find all of its URSs, enumerate all combinations of URSs, construct a scheduling sequence for each combination, and compute the peak memory footprint of that sequence. We then select the inference scheduling sequence with the smallest peak memory footprint. This process is depicted in Fig. 3. For convenience, we name this method Minimizing Memory in Inference by Recomputing (MMIR).
Fig. 3. Overview of the MMIR. For a given computation graph, we first find all URSs and a basic scheduling sequence. Then, we use a heuristic algorithm to generate new scheduling sequences for each combination case and select the one with the lowest peak memory footprint as the optimal inference scheduling sequence.
4.1 Acquisition of URSs
A URS may or may not reduce the memory footprint; in other words, it may be good or useless. We find that a good URS always satisfies Property 1, which allows useless URSs to be identified rapidly and discarded.
Property 1. For a good URS $(\overrightarrow{e_p e_s})$, $e_s$ must be the output of a divergence operator.
Proof. There are two topological cases for a useless URS.
Case 1. For a URS $(\overrightarrow{e_p e_s})$, there is no divergence node from $e_p$ to $e_s$ in the computation graph, e.g., $(\overrightarrow{l_2 l_4})$ in Fig. 2. In this case, $e_p$ and $e_s$ lie on a chain of the computation graph. Since $e_s$ is not the output of a divergence operator, it is used only once during inference; in other words, it never needs to be regenerated. Hence $(\overrightarrow{e_p e_s})$ is not a good URS.
Case 2. There is a divergence node from $e_p$ to $e_s$ in the computation graph, but $e_s$ is not the direct output tensor of that divergence node, e.g., $(\overrightarrow{l_0 l_2})$ in Fig. 2. Let $e_o$ be the direct output tensor of the divergence node; for $(\overrightarrow{l_0 l_2})$ in Fig. 2, $e_o$ is $l_1$. Then $(\overrightarrow{l_0 e_o})$ is a URS and a subpath of $(\overrightarrow{l_0 l_2})$, which contradicts the assumption that $(\overrightarrow{l_0 l_2})$ is a URS.
Combining the two cases, Property 1 is proved.
The primary advantage of a good URS is that it reduces the memory footprint by saving smaller tensors instead of larger ones. If all good URSs of a computation graph are found, the memory footprint during inference can be reduced. Brute-force search, however, is undesirable because of its high time complexity. Based on the definition and property of URSs, we propose two heuristic ideas to accelerate the search.
Heuristic Idea 1. For any divergence node, we start from its direct output tensor, assign it as $e_s$, and then search backward to find potential $e_p$.
Explanation. By Property 1, for any good URS, $e_s$ must be the direct output of a divergence operator. By the definition of a URS, we can find $e_p$ from $e_s$ by traversing the computation graph backward. Since the candidates for $e_s$ are easy to locate, the complexity of finding all good URSs should be low.
Heuristic Idea 2. When searching for URSs starting from $e_s$, once a URS $(\overrightarrow{e_p e_s})$ whose memory gain is less than a given threshold is found, the search from $e_s$ should be stopped. The threshold is the upper bound of the memory gain of a URS.
Explanation. By enumerating all topological orders, we can find a scheduling sequence H with the least peak memory. For this sequence H, we record its memory footprint $M_i$ at every inference step i and compute $\max(M_1, \ldots, M_{|H|}) - \min(M_1, \ldots, M_{|H|})$, which is an upper bound on the memory gain that recomputation strategies can achieve. For convenience, H is called the basic scheduling sequence. For a URS $(\overrightarrow{e_p e_s})$, if we recompute $e_s$ from $e_p$, the peak memory during this recomputation is in effect the cost of computing $e_s$ from $e_p$. If this cost is larger than the memory gain, the URS $(\overrightarrow{e_p e_s})$ brings no memory saving.
Based on the above two heuristic ideas, the algorithm for finding all good URSs is illustrated in Algorithm 1.
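Algorithm 1 itself is not reproduced in this section, so the following is only a minimal Python sketch of a search in its spirit, under several stated assumptions: each operator has a single output; a divergence operator is one whose output feeds two or more operators; the backward search only walks through single-input operators (so that $e_p$ alone suffices to replay the path); the nearest smaller ancestor is taken as the indecomposable candidate; and the optional `max_cost` cutoff with an (input + output) cost model is our interpretation of Heuristic Idea 2. The name `find_candidate_urs` and the toy graphs are hypothetical.

```python
from typing import Dict, List, Optional, Tuple

Urs = Tuple[str, str, List[str]]  # (e_p, e_s, operators to re-run from e_p to e_s)


def find_candidate_urs(op_inputs: Dict[str, List[str]],
                       op_output: Dict[str, str],
                       size: Dict[str, int],
                       max_cost: Optional[int] = None) -> List[Urs]:
    """Hypothetical URS search following Property 1 and Heuristic Ideas 1 and 2."""
    producer = {out: op for op, out in op_output.items()}
    n_consumers: Dict[str, int] = {}
    for ins in op_inputs.values():
        for e in ins:
            n_consumers[e] = n_consumers.get(e, 0) + 1

    found: List[Urs] = []
    for e_s in op_output.values():
        if n_consumers.get(e_s, 0) < 2:  # Property 1: e_s must leave a divergence operator
            continue
        path: List[str] = []
        cur = e_s
        while cur in producer:           # Heuristic Idea 1: search backward from e_s
            op = producer[cur]
            if len(op_inputs[op]) != 1:  # assumption: only single-input chains are replayed
                break
            path.insert(0, op)
            cur = op_inputs[op][0]
            if size[cur] < size[e_s]:    # nearest smaller tensor -> indecomposable candidate
                # Assumed cost model: peak of (input + output) over the replayed operators.
                cost = max(size[op_inputs[q][0]] + size[op_output[q]] for q in path)
                if max_cost is None or cost <= max_cost:
                    found.append((cur, e_s, list(path)))
                break                    # Heuristic Idea 2 (as interpreted): stop at first hit
    return found


# Hypothetical toy graph: A: x -> a, B: a -> b, C: b -> c, D: (a, c) -> d.
op_inputs = {"A": ["x"], "B": ["a"], "C": ["b"], "D": ["a", "c"]}
op_output = {"A": "a", "B": "b", "C": "c", "D": "d"}
size = {"x": 2, "a": 10, "b": 6, "c": 6, "d": 3}
print(find_candidate_urs(op_inputs, op_output, size))  # [('x', 'a', ['A'])]
```

Only a, the output of the divergence operator A, is considered as $e_s$; b and c each have a single consumer and are skipped, as Property 1 suggests.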
4.2 Inference Schedule Algorithm with Recomputation
Several existing methods can find a basic scheduling sequence H. Based on H, the proposed heuristic algorithm generates a new inference scheduling sequence by modifying H with URSs as follows.
We first construct the power set of all good URSs found. Without loss of generality, denote an element of the power set by $\mathcal{F}$ and assume $\mathcal{F} = \{(\overrightarrow{e_{p1} e_{s1}}), (\overrightarrow{e_{p2} e_{s2}}), \ldots, (\overrightarrow{e_{pk} e_{sk}})\}$. Then, for each $e_{si}$ in $\mathcal{F}$, we scan H sequentially and label all operators whose input tensor is $e_{si}$. We insert the computing path from $e_{pi}$ to $e_{si}$ before each of these labeled operators in H except the first one, since the first consumer can still use the originally computed $e_{si}$. In this way, we obtain an inference scheduling sequence RH based on the URS $(\overrightarrow{e_{pi} e_{si}})$. By repeating this process for each $e_{si}$ in $\mathcal{F}$, updating RH iteratively, we obtain the inference scheduling sequence corresponding to $\mathcal{F}$ and can compute its peak memory footprint. Finally, by comparing the peak memory footprints of all sets $\mathcal{F}$, we find the inference scheduling sequence with the minimum peak memory footprint. Algorithm 2 describes these steps.
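The power-set procedure above can be sketched in a few lines of Python. This is a minimal illustration rather than the authors' Algorithm 2: it assumes single-output operators, assumes every schedule it evaluates respects graph dependencies, and uses a compact peak-memory model in which a tensor is freed once it will not be read again before being recomputed. The names `apply_urs`, `best_schedule`, `peak`, and the toy graph are hypothetical.

```python
from itertools import combinations
from typing import Dict, List, Tuple

Urs = Tuple[str, str, List[str]]  # (e_p, e_s, operators that recompute e_s from e_p)


def next_use_is_read(tensor: str, later: List[str],
                     op_inputs: Dict[str, List[str]], op_output: Dict[str, str]) -> bool:
    """True if `tensor` is read again before it is recomputed."""
    for op in later:
        if tensor in op_inputs[op]:
            return True
        if op_output[op] == tensor:
            return False
    return False


def peak(schedule: List[str], op_inputs: Dict[str, List[str]],
         op_output: Dict[str, str], size: Dict[str, int], graph_input: str) -> int:
    """max_t M_t with M_t = I_{t-1} + size(output of v_t); dependencies assumed valid."""
    live, worst = {graph_input}, 0
    for t, op in enumerate(schedule):
        worst = max(worst, sum(size[e] for e in live) + size[op_output[op]])
        live.add(op_output[op])
        later = schedule[t + 1:]
        live = {e for e in live if next_use_is_read(e, later, op_inputs, op_output)}
    return worst


def apply_urs(schedule: List[str], op_inputs: Dict[str, List[str]], urs: Urs) -> List[str]:
    """Insert the path e_p -> e_s before every consumer of e_s except the first."""
    _, e_s, path = urs
    result: List[str] = []
    seen_first = False
    for op in schedule:
        if e_s in op_inputs[op]:
            if seen_first:
                result.extend(path)  # later consumers read a freshly recomputed e_s
            seen_first = True        # the first consumer reads the original e_s
        result.append(op)
    return result


def best_schedule(basic: List[str], urs_list: List[Urs],
                  op_inputs: Dict[str, List[str]], op_output: Dict[str, str],
                  size: Dict[str, int], graph_input: str) -> Tuple[int, List[str]]:
    """Evaluate every subset F of the URSs (the power set) and keep the lowest peak."""
    best = (peak(basic, op_inputs, op_output, size, graph_input), list(basic))
    for r in range(1, len(urs_list) + 1):
        for subset in combinations(urs_list, r):
            rh = list(basic)
            for urs in subset:  # build RH iteratively, one URS at a time
                rh = apply_urs(rh, op_inputs, urs)
            p = peak(rh, op_inputs, op_output, size, graph_input)
            if p < best[0]:
                best = (p, rh)
    return best


# Hypothetical toy graph: A: x -> a, B: a -> b, C: b -> c, D: (a, c) -> d.
op_inputs = {"A": ["x"], "B": ["a"], "C": ["b"], "D": ["a", "c"]}
op_output = {"A": "a", "B": "b", "C": "c", "D": "d"}
size = {"x": 2, "a": 10, "b": 6, "c": 6, "d": 3}
print(best_schedule(["A", "B", "C", "D"], [("x", "a", ["A"])],
                    op_inputs, op_output, size, "x"))
# (19, ['A', 'B', 'C', 'A', 'D'])
```

The empty subset corresponds to the basic schedule itself, so the result can never be worse than H; the cost of the exhaustive loop grows as 2 to the power of the number of URSs, which is why pruning useless URSs first (Property 1 and the heuristic ideas) matters.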
5 Experimental Evaluation
5.1 Customized Network Models
To better demonstrate the effectiveness of our algorithm, we manually customized 20 building blocks similar to the Inception block [2]. Each block is named BxSy, where Bx (x = 2, 3, 4, 5, 6 in our experiments) denotes the number of branches in the block, and Sy (y = 3, 4, 5, 6 in our experiments) denotes the number of convolution operators on each branch. Every block has a convolution operator at its entry and a concatenation operator at its end, so each block contains xy + 2 operators. Since a DNN is usually constructed by stacking the same block sequentially, its peak memory footprint is the same as that of its building block. Therefore, we focus on the peak memory footprint of these building blocks.
Table 1. Parameter setting of building blocks
Block | Output channels of the entry convolution operator | Output channels of the convolution operators in a branch
BxS3 | 24 | 8, 16, 4
BxS4 | 32 | 8, 16, 4, 4
BxS5 | 16 | 8, 16, 8, 16, 4
BxS6 | 32 | 16, 16, 8, 16, 8, 4
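The block construction described above can be sketched in Python using the channel settings of Table 1. The operator names are hypothetical, spatial sizes are omitted, and we assume that every branch uses the same channel list and that the concatenation stacks the last feature map of each branch; the sketch simply confirms the xy + 2 operator count and that the entry convolution is a divergence operator with x consumers, as Property 1 requires.

```python
from typing import Dict, List, Tuple

# Table 1: y -> (entry conv output channels, output channels of the y convs in each branch)
TABLE1: Dict[int, Tuple[int, List[int]]] = {
    3: (24, [8, 16, 4]),
    4: (32, [8, 16, 4, 4]),
    5: (16, [8, 16, 8, 16, 4]),
    6: (32, [16, 16, 8, 16, 8, 4]),
}


def build_block(x: int, y: int) -> Dict[str, Tuple[List[str], int]]:
    """Return {operator: (predecessor operators, output channels)} for block BxSy (sketch)."""
    entry_c, branch_c = TABLE1[y]
    ops: Dict[str, Tuple[List[str], int]] = {"entry": (["input"], entry_c)}
    tails = []
    for b in range(x):      # x parallel branches, all fed by the entry convolution
        prev = "entry"
        for s in range(y):  # y chained convolutions per branch
            name = f"b{b}_conv{s}"  # hypothetical naming scheme
            ops[name] = ([prev], branch_c[s])
            prev = name
        tails.append(prev)
    # Assumption: concat stacks the last feature map of every branch along channels.
    ops["concat"] = (tails, x * branch_c[-1])
    return ops


block = build_block(4, 5)
fan_out = sum("entry" in preds for preds, _ in block.values())
print(len(block), fan_out)  # 22 4
```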
Table 1 lists the structure of the building blocks. For example, for BxS3 (x = 2, 3, 4, 5, 6), the output channel number of the entry convolution operator is 24, and the output channel numbers of the three convolution operators in a branch are 8, 16, and 4, respectively. The input of every building block has 3 channels. The topology and dimensions are chosen deliberately to demonstrate the effectiveness of our algorithm. First, following Property 1, a multi-branch structure is used to form divergence operators, which are essential for creating URSs. Second, the output channel number of the entry convolution operator is larger than its input channel number, consistent with the definition of a URS; such channel amplification is common in practice [2, 12]. Third, following Heuristic Idea 2, if the output channel numbers in a branch are large, the threshold in Heuristic Idea 2 becomes small and the gain from recomputation may not be prominent. Therefore, the output channel numbers of the convolution operators in a branch are deliberately set no larger than that of the entry convolution operator.
We compare our recomputation approach with TensorFlow Lite Micro (TFLM) and SERENITY. TFLM schedules operators by topological sorting of the computation graph, while SERENITY is the state-of-the-art memory-constrained operator scheduling algorithm based on dynamic programming. For a fair comparison, we implement all three algorithms on the same platform in our lab. Figure 4 shows the reduction in peak memory footprint, with the peak memory footprint of the TFLM schedule as the baseline. MMIR achieves a peak memory reduction of 1.07x to 1.25x compared to TFLM, and on the blocks where it outperforms SERENITY, it provides an additional reduction of 0.07x to 0.23x.
[Fig. 4 bar chart: reduction in peak memory (y-axis, 0.00 to 1.50) of TFLM, SERENITY, and MMIR for blocks B2S3 to B6S6]
Fig. 4. Reduction in peak memory footprint of our approach against TensorFlow Lite Micro (TFLM) and SERENITY.
Meanwhile, we observe that for certain networks, such as B2S3, B2S4, B2S5, B3S4, B3S5, B4S5, B5S3, B5S5, and B6S5, MMIR and SERENITY achieve the same gain. The reason is as follows. For an inference scheduling sequence RH derived from a single URS $(\overrightarrow{e_p e_s})$, if the inference step at which peak memory occurs does not overlap the recomputation path from $e_p$ to $e_s$ in RH, the peak memory of RH is unaffected by that URS. If this holds for every URS in a URS combination, no peak memory gain is achieved. Our algorithm incurs additional computational cost because of recomputation; we quantify this extra workload by calculating the FLOPs of the different scheduling sequences. For the 20 customized networks, the computational workload of MMIR ranges from 1.04x to 1.17x that of the original schedule (Fig. 5).
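The extra-cost ratio reported in Fig. 5 is a FLOPs ratio between two schedules. A minimal Python sketch of that calculation follows, assuming stride-1 "same"-padded convolutions and the common convention of counting one multiply-accumulate as 2 FLOPs; the layer shapes and names (`conv_flops`, `extra_cost`) are hypothetical and are not the paper's blocks.

```python
from typing import Dict, Sequence


def conv_flops(h: int, w: int, k: int, c_in: int, c_out: int) -> int:
    """FLOPs of a stride-1, 'same'-padded k x k convolution on an h x w feature map.

    One multiply-accumulate is counted as 2 FLOPs (a common convention, assumed here).
    """
    return 2 * h * w * k * k * c_in * c_out


def extra_cost(schedule: Sequence[str], basic: Sequence[str], flops: Dict[str, int]) -> float:
    """Workload of a (possibly recomputing) schedule relative to the basic one."""
    return sum(flops[op] for op in schedule) / sum(flops[op] for op in basic)


# Hypothetical 3x3 layers on an 8x8 feature map (not the paper's blocks).
flops = {
    "A": conv_flops(8, 8, 3, 3, 24),
    "B": conv_flops(8, 8, 3, 24, 8),
    "C": conv_flops(8, 8, 3, 8, 16),
    "D": conv_flops(8, 8, 3, 16, 4),
}
ratio = extra_cost(["A", "B", "C", "A", "D"], ["A", "B", "C", "D"], flops)
print(f"{ratio:.2f}x")  # 1.16x
```

Recomputing the cheap entry operator A once adds about 16% to the workload here; the ratio depends only on which operators are replayed, not on the memory savings they buy.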
[Fig. 5 bar chart: extra computation cost (y-axis, 0.00 to 1.50) of the original schedule and MMIR for blocks B2S3 to B6S6]
Fig. 5. Increase in computational cost of MMIR
5.2 Real-World Network Models
When selecting network models for real-world evaluation, we found that some typical modern models contain no URS, in which case MMIR has no advantage over existing memory reduction methods. To demonstrate that MMIR outperforms existing methods on real DNN models, we made slight adjustments to three well-known neural networks: ResNet-18, InceptionV1, and InceptionV3. The adjustments are intended to keep accuracy acceptable for real applications; both the original and the adjusted models are trained on the CIFAR-10 dataset.
Table 2. Comparison of memory footprint gains (reduction relative to TFLM in parentheses)
Model | TFLM (KB) | SERENITY (KB) | MMIR (KB) | Accuracy
ResNet-18 | − | − | − | 0.7171
ResNet-18-A | 352 | 352 (1x) | 332 (1.06x) | 0.6969
ResNet-18-B | 336 | 336 (1x) | 332 (1.01x) | 0.6845
InceptionV1 | − | − | − | 0.74
InceptionV1-A | 496 | 400 (1.24x) | 396 (1.25x) | 0.7631
InceptionV1-B | 384 | 352 (1.09x) | 320 (1.2x) | 0.7471
InceptionV3 | − | − | − | 0.8859
InceptionV3-A | 8019 | 7290 (1.1x) | 6561 (1.22x) | 0.8917
The experimental results in Table 2 show that for ResNet-18-A/B, SERENITY cannot reduce the memory footprint, while MMIR achieves a 1.06x/1.01x reduction in peak memory footprint. For InceptionV1-A/B, MMIR reduces the memory footprint by 1.25x and 1.2x compared to TFLM, an improvement of 0.01x and 0.11x over SERENITY. For InceptionV3-A, MMIR achieves a 1.22x reduction in peak memory footprint, which is 0.12x higher than that of SERENITY. These results demonstrate that MMIR is effective for real DNN models.
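The reduction factors in Table 2 are simply the TFLM peak divided by each method's peak, and the gain over SERENITY is the difference of the two rounded factors. A short Python sketch that recomputes them from the KB values in Table 2 (the helper name `reduction` is ours):

```python
# Peak memory in KB from Table 2: (TFLM, SERENITY, MMIR)
TABLE2 = {
    "ResNet-18-A": (352, 352, 332),
    "ResNet-18-B": (336, 336, 332),
    "InceptionV1-A": (496, 400, 396),
    "InceptionV1-B": (384, 352, 320),
    "InceptionV3-A": (8019, 7290, 6561),
}


def reduction(baseline_kb: float, peak_kb: float) -> float:
    """Reduction factor relative to the TFLM baseline: baseline / peak."""
    return baseline_kb / peak_kb


ratios = {
    name: (round(reduction(tflm, seren), 2), round(reduction(tflm, mmir), 2))
    for name, (tflm, seren, mmir) in TABLE2.items()
}
for name, (r_seren, r_mmir) in ratios.items():
    # Gain over SERENITY, taken as the difference of the two rounded factors.
    print(f"{name}: SERENITY {r_seren:.2f}x, MMIR {r_mmir:.2f}x, "
          f"extra {r_mmir - r_seren:.2f}x")
```

For InceptionV3-A this gives 8019 / 7290 = 1.10x and 8019 / 6561 ≈ 1.22x, i.e., an extra gain of 0.12x under this subtraction convention.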
6 Conclusions
As a new form of intelligence, edge intelligence requires a systematic solution to the problem of insufficient memory during inference. In this paper, we introduce the idea of recomputation into neural network inference, which reduces the peak memory footprint by expanding the inference scheduling space. Experiments on customized and typical DNN models demonstrate that the proposed method achieves a lower peak memory footprint than SOTA methods. Our algorithm trades additional computation for a further reduction in peak memory footprint. For IoT devices, a modest increase in computation cost matters less than a lower peak memory requirement, which is vital for their usability; paying this small cost in exchange for the ability to run more DNN models on IoT devices is therefore worthwhile. However, the effectiveness of our method is limited by the structure of the original computation graph: if no good URS is found in the computation graph, our algorithm yields no improvement. Overall, our work contributes to the development of edge intelligence.
Acknowledgement. This study is supported by the National Key R&D Program of China (2022YFB4501600).
References
1. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778. IEEE, Las Vegas, NV, USA (2016)
2. Szegedy, C., et al.: Going deeper with convolutions. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1–9. IEEE, Boston, MA, USA (2015)
3. Wang, Z., Wan, C., Chen, Y., Lin, Z., Jiang, H., Qiao, L.: Hierarchical memory-constrained operator scheduling of neural architecture search networks. In: Proceedings of the 59th ACM/IEEE Design Automation Conference, pp. 493–498. ACM, San Francisco, CA, USA (2022)
4. Chen, T., Xu, B., Zhang, C., Guestrin, C.: Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174 (2016)
5. Ahn, B.H., Lee, J., Lin, J.M., Cheng, H.P., Hou, J., Esmaeilzadeh, H.: Ordering chaos: memory-aware scheduling of irregularly wired neural networks for edge devices. In: MLSys, pp. 44–57. Austin, TX, USA (2020)
6. Ding, Y., Zhu, L., Jia, Z., Pekhimenko, G., Han, S.: IOS: inter-operator scheduler for CNN acceleration. In: MLSys, pp. 167–180. Virtual Conference (2021)
7. Ma, L., et al.: Rammer: enabling holistic deep learning compiler optimizations with rTasks. In: Proceedings of the 14th USENIX Conference on Operating Systems Design and Implementation, pp. 881–897. USENIX, Virtual Conference (2020)
8. Lin, J., Chen, W.M., Lin, Y., Cohn, J., Gan, C., Han, S.: MCUNet: tiny deep learning on IoT devices. In: NeurIPS, pp. 11711–11722. MIT Press, Virtual Conference (2020)
9. Lin, J., Chen, W.M., Cai, H., Gan, C., Han, S.: Memory-efficient patch-based inference for tiny deep learning. In: NeurIPS, pp. 2346–2358. MIT Press, Virtual Conference (2021)
10. Kumar, R., Purohit, M., Svitkina, Z., Vee, E., Wang, J.: Efficient rematerialization for deep networks. In: NeurIPS, pp. 1–10. MIT Press, Vancouver, BC, Canada (2019)
11. Jain, P., et al.: Checkmate: breaking the memory wall with optimal tensor rematerialization. In: MLSys, pp. 497–511. Austin, TX, USA (2020)
12. Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., Wojna, Z.: Rethinking the inception architecture for computer vision. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2818–2826. IEEE, Las Vegas, NV, USA (2016)
AGAM-SLAM: An Adaptive Dynamic Scene Semantic SLAM Method Based on GAM
Dupeng Cai¹, Zhuhua Hu¹ (corresponding author), Ruoqing Li², Hao Qi¹, Yunfeng Xiang¹, and Yaochi Zhao²
¹ School of Information and Communication Engineering, Hainan University, Haikou 570228, China
² School of Cyberspace Security (School of Cryptology), Hainan University, Haikou 570228, China
Abstract. With rapid developments in autonomous driving, robot navigation, and augmented reality, visual SLAM (VSLAM) has become one of the core technologies in these fields. While VSLAM systems perform consistently in static scenes, dynamic objects such as people, vehicles, and animals make reliable map building and localization more difficult and accurate trajectory estimation more challenging. In this paper, we propose a semantic VSLAM system based on the global attention mechanism (GAM) and adaptive thresholding. First, GAM improves the segmentation accuracy of the Mask R-CNN network for dynamic objects, eliminating the influence of dynamic objects on the VSLAM system. In addition, adaptive thresholding generates adaptive factors from the number of keypoints extracted in the scene and dynamically adjusts the FAST threshold, enabling more stable feature point extraction in dynamic scenes. We verify our approach on the public TUM dataset and compare it with DynaSLAM. Both the absolute trajectory error (ATE) and the relative pose error (RPE) are reduced to some extent on this dataset; in particular, on the W_rpy sequence, the accuracy of our method improves by 38.78%. The experimental results show that the proposed method can significantly improve the overall performance of the system in highly dynamic environments.
Keywords: VSLAM · GAM · Highly dynamic environment · Deep learning · Adaptive thresholding
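The abstract reports ATE, RPE, and a 38.78% improvement without stating the formulas in this section. As a hedged illustration, the Python sketch below computes the translational ATE as a root-mean-square error and a relative improvement as (baseline − ours) / baseline × 100. It assumes the estimated and ground-truth positions are already time-associated and aligned (the TUM RGB-D benchmark tools perform such an alignment before computing ATE); whether the 38.78% figure follows exactly this definition is an assumption, and the function names are ours.

```python
import math
from typing import Sequence, Tuple

Position = Tuple[float, float, float]


def ate_rmse(estimated: Sequence[Position], ground_truth: Sequence[Position]) -> float:
    """RMSE of the translational error between associated, already aligned positions."""
    if len(estimated) != len(ground_truth) or not estimated:
        raise ValueError("trajectories must be non-empty and of equal length")
    sq_errors = [sum((a - b) ** 2 for a, b in zip(p, q))
                 for p, q in zip(estimated, ground_truth)]
    return math.sqrt(sum(sq_errors) / len(sq_errors))


def relative_improvement(baseline_err: float, ours_err: float) -> float:
    """Percentage by which our error is lower than the baseline's (assumed definition)."""
    return (baseline_err - ours_err) / baseline_err * 100.0


gt = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
est = [(0.5, 0.0, 0.0), (1.5, 0.0, 0.0)]
print(ate_rmse(est, gt))                # 0.5
print(relative_improvement(0.1, 0.05))  # 50.0
```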
1 Introduction
With the rapid development of artificial intelligence, significant progress has been made in autonomous driving, robot navigation, and augmented reality [1–3]. Simultaneous Localization and Mapping (SLAM) technology is one of the key pillars in these fields.
This work was supported in part by the Key Research and Development Project of Hainan Province (ZDYF2022GXJS348, ZDYF2022SHFZ039), the Hainan Province Natural Science Foundation (623RC446), and the National Natural Science Foundation of China (62161010, 61963012). The authors would like to thank the referees for their constructive suggestions.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 27–39, 2023.
https://doi.org/10.1007/978-981-99-4761-4_3
D. Cai et al.
SLAM provides basic navigation and perception functions for various intelligent systems by acquiring sensor data about an unknown environment in real time while simultaneously localizing itself and constructing a map of the environment. With the development of computer vision, SLAM gradually evolved into visual SLAM (VSLAM), which uses camera sensors to obtain information from the environment to estimate the robot's pose and build a map [4]. Nowadays, with the development of deep learning, semantic VSLAM has attracted considerable attention in the VSLAM field: it considers not only the spatial location of the robot and the map but also uses semantic information to improve VSLAM performance [5]. However, semantic VSLAM still faces numerous challenges in dynamic scenes [6]. First, semantic VSLAM often suffers from unstable matching when facing fast-moving objects, which can degrade the accuracy of the semantic segmentation network. Second, objects may occlude one another, making it difficult for the semantic segmentation network to recognize objects accurately when only parts of them are visible. Finally, under large illumination changes or low texture, a fixed FAST threshold cannot adapt to environmental changes, which reduces the number of detected feature points and thus degrades the performance of the semantic VSLAM system.
1.1 Related Work and Motivation
Mur-Artal et al. introduced ORB-SLAM [7], a monocular SLAM algorithm based on ORB feature points. It comprises essential components including initialization, tracking, and mapping. Nonetheless, its robustness is limited under challenges such as motion blur, illumination variation, and occlusion. Additionally, its mapping speed is relatively slow.
In response to these issues, the authors proposed ORB-SLAM2 [8], which builds on ORB-SLAM. This version not only addresses the aforementioned problems but also adds interfaces for stereo and RGB-D cameras to enhance robustness. Nevertheless, ORB-SLAM2 may be unstable when confronted with fast-moving objects. Such instability can arise from significant disparities between adjacent frame views, excessive displacement, unreliable feature point matching, mismatching, missed matching, and matching failures caused by changes in feature point descriptors. To address the challenges posed by dynamic scenes, researchers have explored deep learning techniques for dynamic SLAM. Yu et al. [9] introduced DS-SLAM, which combines a semantic segmentation network with a moving consistency check to mitigate the influence of dynamic objects; however, this method is relatively complex. Zhao et al. [10] proposed a monocular SLAM approach that incorporates the L-VO neural network [11] and dense 3D mapping for localization. This method effectively handles dynamic objects and surpasses most monocular SLAM techniques, but it may not be suitable for complex scenes. Wu et al. [12] integrated the YOLOv3 [13] object detection network with geometric constraints to detect dynamic objects in the environment and mitigate their impact, thereby improving the accuracy and robustness of the SLAM system; however, the accuracy of this approach is limited. Bescos et al. [14] presented a SLAM system designed for dynamic scenes, which combines depth estimation, semantic segmentation, and motion modeling to segment and remove dynamic objects. This method enhances the robustness and accuracy of the SLAM system, but its detection accuracy still requires improvement. Although these methods partially mitigate the challenges posed by dynamic objects, they still have difficulty with fast-moving objects or objects with significant deformations. Additionally, unconventional environments such as severe occlusions further limit existing approaches.
Illumination variation is another crucial factor affecting the performance of SLAM systems, and several researchers have addressed it by enhancing feature extraction methods. Alismail et al. [15] introduced a direct visual odometry approach that uses binary descriptors, improving feature point extraction under challenging illumination; however, its advantage is observed mainly in low-illumination environments. Zhang et al. [16] proposed LF-Net, a method for learning local features from images, which achieves superior feature point extraction and matching but requires a substantial amount of training data and computational resources. In general, these methods have made significant progress in feature point extraction and adaptive thresholding strategies. However, in practical applications they must be adapted and optimized for different scenarios and device limitations. Compared with other forms of SLAM, VSLAM can model the scene quickly and perform localization and pose updates in real time [8]. However, in dynamic scenes, dynamic objects can corrupt the information used by the VSLAM system and even degrade its accuracy. To solve these problems, some researchers have obtained better results by adding semantic information to VSLAM.
Such methods can often classify and distinguish objects based on semantic knowledge, thereby reducing the impact of dynamic objects on localization and map building. However, in dynamic scenes, semantic VSLAM faces two main problems. (1) Feature extraction: unstable lighting, rotation, and scale changes can degrade the feature extraction and matching ability of the VSLAM system, making it prone to tracking loss. Moreover, the fixed FAST threshold of the VSLAM system can hurt its performance in scenes with large changes in illumination intensity or texture, so making the threshold adaptive becomes especially important. (2) Semantic segmentation: when facing fast-moving objects, semantic VSLAM may encounter larger differences between consecutive frames and changing occlusion relationships. These make it difficult for the semantic segmentation network to identify object boundaries and classes accurately, reducing the segmentation accuracy of the semantic VSLAM system and, in turn, its overall performance.
D. Cai et al.
1.2 Key Contributions of the Paper
In Sect. 1.1, we described these problems in detail. In response, this paper proposes an adaptive FAST threshold to improve the stability of feature extraction and introduces the GAM attention mechanism so that the segmentation network focuses on more important information. The method brings significant performance improvement to semantic VSLAM, and the main contributions are as follows:
(1) In semantic VSLAM segmentation of dynamic objects, we note that focusing on more important feature information can improve the accuracy of object boundary and category recognition. Therefore, we integrate the global attention mechanism (GAM) into the Mask R-CNN network model so that the model considers global features in addition to local features, which effectively improves the segmentation accuracy of dynamic objects.
(2) The overall performance of VSLAM depends to some extent on the feature points, and extracting feature points with a single fixed threshold tends either to miss important feature points (if the threshold is too large) or to introduce a large amount of noise (if it is too small). We improve the accuracy and robustness of feature point extraction by adaptively adjusting the threshold according to the local contrast and background noise level in the image, meeting the needs of subsequent localization algorithms.
(3) A more comprehensive understanding of the environment can effectively improve the performance of semantic VSLAM in dynamic scenarios. Building on DynaSLAM, which performs well in dynamic scenes, we integrate the optimized Mask R-CNN and the adaptive threshold, which extracts highly stable features, to provide more accurate and stable scene perception and object detection for real-time visual SLAM systems.
1.3 Paper Structure
The rest of this paper is organized as follows. Section 1 introduces the significance and background of the research and summarizes recent research on visual SLAM in dynamic scenes. 
Section 2 introduces the attention-based Mask R-CNN network, adaptive thresholding, and semantic SLAM system proposed in this paper. In Sect. 3, the experimental results are analyzed and discussed, and the accuracy of the proposed model is evaluated. Section 4 summarizes the work of this paper and suggests areas for improvement.
2 Proposed Method
In this paper, we propose a new system that improves on ORB-SLAM2 [8]. The overall flow of the system is shown in Fig. 1, where the blocks with a blue background show the main work of this paper. First, adaptive thresholding is introduced into feature extraction and matching to adjust the number of feature points extracted in a dynamic environment. Second, we introduce the global attention mechanism GAM [17] into the front end of the Mask R-CNN [18] semantic segmentation network and identify dynamic objects using motion consistency detection [14].
We perform semantic segmentation of dynamic objects and provide the segmentation results to the VSLAM system as prior information. The front-end tracking thread then determines the contour of each dynamic object from its mask and removes the feature points that lie on it. Finally, the system uses the static feature points for local mapping and loop closure detection, improving the accuracy and robustness of the visual SLAM system in dynamic environments.
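The masking step can be illustrated with a minimal sketch. It assumes the segmentation output is a boolean per-pixel mask (True on dynamic objects) and that keypoints are (x, y) pixel coordinates. The function name `remove_dynamic_keypoints` is hypothetical and is not taken from the AGAM-SLAM or DynaSLAM code.

```python
import numpy as np


def remove_dynamic_keypoints(keypoints, dynamic_mask):
    """Keep only keypoints that fall on static pixels (hypothetical helper).

    keypoints: iterable of (x, y) pixel coordinates (x = column, y = row).
    dynamic_mask: (H, W) boolean array, True where a dynamic object was segmented.
    Keypoints that round to a position outside the image are dropped.
    """
    h, w = dynamic_mask.shape
    static = []
    for x, y in keypoints:
        col, row = int(round(x)), int(round(y))
        if 0 <= row < h and 0 <= col < w and not dynamic_mask[row, col]:
            static.append((x, y))
    return static


mask = np.zeros((5, 5), dtype=bool)
mask[1:3, 1:3] = True  # a small "person" region
static = remove_dynamic_keypoints([(0, 0), (1, 1), (2.4, 1.6), (4, 4), (10, 0)], mask)
print(static)  # [(0, 0), (4, 4)]
```

Only the static keypoints would then be passed on for tracking, local mapping, and loop closure.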
Fig. 1. The System Model of AGAM-SLAM.
2.1 Global Attention Module
An attention module enhances the performance and effectiveness of a model by enabling it to automatically focus on important information in the input and assign weights to it. Attention mechanisms can be broadly categorized into three types: self-attention, local attention, and global attention. Self-attention is suited to modeling sequential data, allowing the model to capture dependencies between different positions in a sequence. Local attention reduces computational complexity by restricting attention to a local region of the input. Global attention captures global contextual information by considering the entire input. The Global Attention Mechanism (GAM) [17] has a wide scope and can capture global dependencies, learning contextual information across the entire input and thus better capturing long-range dependencies between different positions. GAM improves upon the Convolutional Block Attention Module (CBAM) [19] and consists of two components: channel attention and spatial attention. Channel attention captures global information along the channel dimension of the feature map, while spatial attention captures spatial information within the feature map. By incorporating GAM, the proposed method leverages its ability to capture global dependencies and contextual information, improving performance in semantic VSLAM tasks. By combining channel and spatial attention, GAM captures global information about the input features at different levels. Integrating it into the Mask R-CNN network model therefore helps the model focus on more important features and improves its performance on the target task. The network model of GAM is shown in Fig. 2.
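[Fig. 2 diagram: channel attention M_C takes input feature F1 (C×H×W), permutes it to W×H×C, applies an MLP, reverses the permutation, and applies a sigmoid to give M_C(F1). Spatial attention M_S takes input feature F2 and applies a 7×7 convolution (to C/r channels), a second 7×7 convolution (back to C channels), and a sigmoid to give M_S(F2). The result is the output feature F.]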
Fig. 2. GAM Network Model.
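To make the two branches in Fig. 2 concrete, here is a minimal NumPy sketch of a GAM-style block following the structure described above and in [17]. Channel attention applies a permutation, an MLP, the reverse permutation, and a sigmoid. Spatial attention applies two 7×7 convolutions and a sigmoid. Each attention map then multiplies its own input. The following are our assumptions: a reduction ratio r = 4, ReLU between layers, no batch normalization, and random untrained weights, so only shapes and gating behavior are meaningful. This is an illustration, not the authors' implementation.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def conv2d_same(x, w, b):
    """Naive stride-1 'same' convolution (cross-correlation).

    x: (C_in, H, W), w: (C_out, C_in, k, k), b: (C_out,) -> (C_out, H, W)
    """
    c_out, _, k, _ = w.shape
    pad = k // 2
    xp = np.pad(x, ((0, 0), (pad, pad), (pad, pad)))
    _, h, wd = x.shape
    out = np.empty((c_out, h, wd))
    for i in range(h):
        for j in range(wd):
            patch = xp[:, i:i + k, j:j + k]
            out[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2])) + b
    return out


def channel_attention(f1, w1, b1, w2, b2):
    """M_C: permute C×H×W -> W×H×C, two-layer MLP over channels, reverse, sigmoid."""
    x = f1.transpose(2, 1, 0)               # (W, H, C)
    hidden = np.maximum(x @ w1 + b1, 0.0)   # (W, H, C/r), ReLU
    y = hidden @ w2 + b2                    # (W, H, C)
    return sigmoid(y.transpose(2, 1, 0))    # back to (C, H, W)


def spatial_attention(f2, k1, c1, k2, c2):
    """M_S: 7×7 conv to C/r channels, ReLU, 7×7 conv back to C channels, sigmoid."""
    hidden = np.maximum(conv2d_same(f2, k1, c1), 0.0)
    return sigmoid(conv2d_same(hidden, k2, c2))


def gam(f1, params):
    f2 = channel_attention(f1, *params["channel"]) * f1    # F2 = M_C(F1) ⊗ F1
    return spatial_attention(f2, *params["spatial"]) * f2  # F3 = M_S(F2) ⊗ F2


def init_gam_params(c, r=4, k=7, seed=0):
    """Random (untrained) weights, for shape checking only."""
    rng = np.random.default_rng(seed)
    cr = c // r
    return {
        "channel": (rng.normal(0, 0.1, (c, cr)), np.zeros(cr),
                    rng.normal(0, 0.1, (cr, c)), np.zeros(c)),
        "spatial": (rng.normal(0, 0.1, (cr, c, k, k)), np.zeros(cr),
                    rng.normal(0, 0.1, (c, cr, k, k)), np.zeros(c)),
    }


f1 = np.random.default_rng(1).normal(size=(8, 10, 12))
out = gam(f1, init_gam_params(8))
print(out.shape)  # (8, 10, 12)
```

Both attention maps lie in (0, 1), so the output never exceeds the input in magnitude: GAM reweights existing features rather than adding new ones.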
2.2 Improved Mask R-CNN Network Model Based on GAM Module
The Mask R-CNN [18] network model improves on Faster R-CNN [20] for object detection and segmentation. Its main components include the backbone, the Region Proposal Network (RPN), ROI Align, and the Feature Pyramid Network (FPN). For dynamic objects, Mask R-CNN identifies the objects and outlines the regions in motion. However, after multiple convolutions in the feature extraction network, the features of incomplete or moving objects are easily lost. Focusing on more important feature information should help the model, so we integrate GAM into the Mask R-CNN model as follows: (1) the original image is passed through the ResNet101 [21] backbone network to obtain feature maps C1, C2, C3, C4, and C5, where C2, C3, C4, and C5 have strides of 4, 8, 16, and 32, respectively, relative to the original image; (2) GAM is applied to each feature map: channel attention and spatial attention are first applied to C5 to obtain a GAM-processed feature map, and similar operations are then performed on C4, C3, and C2 in turn, each time upsampling the result of the previous level, adding it to the current level, and then applying GAM; (3) the FPN part of Mask R-CNN is applied to the GAM-processed feature maps to generate the pyramid feature maps (labeled P2–P6 in Fig. 3); (4) these pyramid feature maps are passed to the RPN to generate region proposals; (5) ROI Align is performed on the proposals to align their features; (6) the aligned features are passed to the head of Mask R-CNN, which performs classification, bounding box regression, and segmentation mask prediction; (7) the instance segmentation results are obtained by decoding the network's predictions. The improved model is shown in Fig. 3.
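[Fig. 3 diagram: the input image passes through backbone stages C1–C5; GAM is applied to C2, C3, C4, and C5; the pyramid maps P2–P6 feed a convolutional RPN and ROI Align, followed by a fully connected layer that outputs the mask, class, and box regression.]
Fig. 3. Improved Mask R-CNN Model with GAM.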
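Step (2) above can be sketched as a top-down pass. The text does not say how C2–C5, which have different channel counts, are made addable. This sketch therefore assumes FPN-style 1×1 lateral projections to a common width and nearest-neighbour 2× upsampling. The FPN's 3×3 output convolutions and P6 are omitted. The attention step is passed in as a function so that a GAM implementation can be plugged in. `toy_gate` is a hypothetical placeholder, not GAM.

```python
import numpy as np


def upsample2x(x):
    """Nearest-neighbour 2x upsampling of a (C, H, W) map."""
    return x.repeat(2, axis=1).repeat(2, axis=2)


def lateral_1x1(x, w):
    """1x1 convolution: w has shape (C_out, C_in); x is (C_in, H, W)."""
    return np.einsum("oc,chw->ohw", w, x)


def attended_top_down(c_maps, lateral_ws, attend):
    """c_maps: [C2, C3, C4, C5], fine to coarse (strides 4, 8, 16, 32).

    Starting from C5, each level is projected to a common width, summed with the
    upsampled result of the coarser level, and passed through `attend`.
    Returns the attended maps in fine-to-coarse order.
    """
    outputs, prev = [], None
    for c, w in zip(reversed(c_maps), reversed(lateral_ws)):
        x = lateral_1x1(c, w)
        if prev is not None:
            x = x + upsample2x(prev)  # merge the coarser, already-attended level
        prev = attend(x)
        outputs.append(prev)
    return outputs[::-1]


def toy_gate(x):
    """Hypothetical stand-in for GAM (NOT GAM): x times a sigmoid of the channel mean."""
    return x / (1.0 + np.exp(-x.mean(axis=0, keepdims=True)))


rng = np.random.default_rng(0)
sizes = [(16, 32), (32, 16), (64, 8), (128, 4)]  # (channels, side) for C2..C5
c_maps = [rng.normal(size=(ch, s, s)) for ch, s in sizes]
lateral_ws = [rng.normal(0, 0.1, (8, ch)) for ch, _ in sizes]
outs = attended_top_down(c_maps, lateral_ws, toy_gate)
print([o.shape for o in outs])  # [(8, 32, 32), (8, 16, 16), (8, 8, 8), (8, 4, 4)]
```

The loop order matters: each level sees the attended coarser level, which is how global context from C5 reaches the fine C2 map.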
2.3 Adaptive Thresholds
The robustness of ORB-SLAM2 [8] may be affected in a dynamic environment. To improve the robustness of the system in dynamic environments, we add adaptive thresholds to it: the FAST threshold is adjusted according to the actual number of extracted keypoints. The adaptive factor and the desired number of extracted keypoints are given by

$$\mathrm{DAF} = \frac{\mathrm{NKP}_{\mathrm{desired}}}{\mathrm{NKP}_{\mathrm{current}}} \tag{1}$$

$$\mathrm{NKP}_{\mathrm{desired}} = \mathit{nfeatures} \times 0.9 \tag{2}$$
DAF is the adaptive factor used to adjust the FAST threshold. NKP_desired denotes the desired number of keypoints to be extracted, and NKP_current denotes the number of keypoints actually extracted. In Eq. (2), nfeatures is the number of keypoints to be extracted for each layer of the image. Two thresholds are available, iniThFAST and minThFAST; initially, iniThFAST is set to 20 and minThFAST to 7. By scaling the FAST threshold according to the adaptive factor, the threshold is adjusted dynamically to the actual number of extracted keypoints. If the actual number of keypoints is below the desired value, the threshold is lowered to extract more keypoints. Conversely, if it is above the desired value, the threshold is raised to reduce the number of keypoints. Such adjustments help the system extract critical points more reliably and improve robustness in dynamic environments. The adaptive threshold process is shown in Fig. 4.
[Fig. 4 flowchart elements: RGB-D image, image pyramid (ScaleFactor), feature extraction with iniThFAST and minThFAST, a decision "NKP too low?" (Y/N), the adaptive factor DAF = NKP_desired / NKP_current, and output keypoints and descriptors.]
Fig. 4. Flowchart of Adaptive Threshold.
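The exact update rule is not spelled out above. Multiplying the threshold by DAF = NKP_desired / NKP_current would raise it when too few keypoints are found, which is the opposite of the stated behavior. This sketch therefore divides by DAF; that choice is our assumption. The clamp to [minThFAST, max_th] with a hypothetical max_th = 40 and the fallback when nothing is detected are also ours. The function name is hypothetical.

```python
def adaptive_fast_threshold(th, n_current, nfeatures, min_th=7, max_th=40):
    """Next FAST threshold from the keypoint count obtained with threshold `th`.

    DAF = NKP_desired / NKP_current (Eq. 1), NKP_desired = 0.9 * nfeatures (Eq. 2).
    DAF > 1 (too few keypoints) lowers the threshold; DAF < 1 raises it.
    min_th mirrors minThFAST = 7; max_th is a hypothetical upper bound.
    """
    desired = 0.9 * nfeatures
    if n_current == 0:
        return min_th                 # nothing detected: fall back to minThFAST
    daf = desired / n_current
    new_th = th / daf                 # assumption: divide so the direction matches the text
    return int(round(min(max(new_th, min_th), max_th)))


print(adaptive_fast_threshold(20, 600, 1000))   # 13: too few keypoints, lower threshold
print(adaptive_fast_threshold(20, 1800, 1000))  # 40: too many keypoints, raise threshold
```

Starting from iniThFAST = 20, a frame that yields 600 of the 900 desired keypoints lowers the threshold to 13 for the next attempt, while one that yields 1800 raises it (here up to the assumed cap).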
3 Experiment
We experimentally compared AGAM-SLAM with DynaSLAM [14] using an Intel Core i9-10900X CPU and an NVIDIA GeForce RTX 3080 GPU. We first retrained the optimized Mask R-CNN model on the COCO 2014 dataset [22], which contains 80 object categories, including people, vehicles, and animals. The trained model meets the needs of visual SLAM systems used in dynamic environments. We validate the effectiveness of our method on the TUM dataset. Its dynamic sequences mainly show people sitting on chairs or walking while the camera moves in different directions, so the moving people interfere with camera tracking. Among them, the fr3_sitting sequences are low-dynamic and the fr3_walking sequences are high-dynamic. Both types include four camera motion patterns: xyz (the camera moves along the x, y, and z axes), halfsphere (the camera moves along a hemispherical trajectory; abbreviated hf in the tables), rpy (the camera rotates about three axes), and static (the camera is stationary).
3.1 Evaluation Metrics
To analyze the experimental results, we used the Absolute Trajectory Error (ATE) and the Relative Pose Error (RPE) as evaluation metrics [23], each summarized by its root mean square error (RMSE). Here P_1, …, P_n ∈ SE(3) denote the estimated poses, Q_1, …, Q_n ∈ SE(3) the ground-truth poses, and Δ the time interval. ATE directly measures the difference between the estimated and true poses. After the timestamps are associated and the estimated trajectory is aligned to the ground truth by a rigid-body transformation S, the error of each pose pair is computed. The ATE of the i-th frame and its RMSE over all frames are:

$$F_i := Q_i^{-1} S P_i \tag{3}$$

$$\mathrm{RMSE}(F_{1:n}) := \left( \frac{1}{n} \sum_{i=1}^{n} \lVert \mathrm{trans}(F_i) \rVert^2 \right)^{1/2} \tag{4}$$
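The RPE compares the change of the true pose with the change of the estimated pose over the interval Δ and takes the difference between the two changes as the relative pose error. The RPE of frame i and its RMSE are:

$$E_i := \left(Q_i^{-1} Q_{i+\Delta}\right)^{-1} \left(P_i^{-1} P_{i+\Delta}\right) \tag{5}$$

$$\mathrm{RMSE}(E_{1:n}, \Delta) := \left( \frac{1}{m} \sum_{i=1}^{m} \lVert \mathrm{trans}(E_i) \rVert^2 \right)^{1/2} \tag{6}$$

where m = n − Δ is the number of relative pose pairs and trans(·) denotes the translational component.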
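Eqs. (3)–(6) translate directly into code. This minimal NumPy sketch makes two assumptions. First, poses are 4×4 homogeneous matrices already associated by timestamp. Second, the alignment S is passed in (identity by default) rather than estimated. For reported numbers, the TUM benchmark's own evaluation tools should be used. `translation_pose` is a hypothetical helper for building test poses.

```python
import numpy as np


def translation_pose(t):
    """Hypothetical helper: 4x4 SE(3) pose with identity rotation and translation t."""
    T = np.eye(4)
    T[:3, 3] = t
    return T


def _trans_rmse(errors):
    """RMSE of the translational parts of a list of 4x4 error poses."""
    t = np.array([E[:3, 3] for E in errors])
    return float(np.sqrt(np.mean(np.sum(t ** 2, axis=1))))


def ate_rmse(Q, P, S=None):
    """Eqs. (3)-(4): F_i = Q_i^-1 S P_i, RMSE of trans(F_i) over all n frames."""
    S = np.eye(4) if S is None else S
    return _trans_rmse([np.linalg.inv(q) @ S @ p for q, p in zip(Q, P)])


def rpe_rmse(Q, P, delta=1):
    """Eqs. (5)-(6): E_i = (Q_i^-1 Q_{i+Δ})^-1 (P_i^-1 P_{i+Δ}), m = n - Δ pairs."""
    m = len(Q) - delta
    errors = []
    for i in range(m):
        dq = np.linalg.inv(Q[i]) @ Q[i + delta]  # true motion over Δ
        dp = np.linalg.inv(P[i]) @ P[i + delta]  # estimated motion over Δ
        errors.append(np.linalg.inv(dq) @ dp)
    return _trans_rmse(errors)


Q = [translation_pose([i, 0.0, 0.0]) for i in range(5)]
P = [translation_pose([i, 0.1, 0.0]) for i in range(5)]  # constant 0.1 offset
print(round(ate_rmse(Q, P), 6), round(rpe_rmse(Q, P), 6))  # 0.1 0.0
```

The example shows the difference between the two metrics. A constant offset produces a 0.1 ATE but zero RPE, because RPE only measures drift in relative motion.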
In this paper, the errors of AGAM-SLAM and DynaSLAM are compared using the relative improvement rate η:

$$\eta = \frac{\mathrm{DynaSLAM} - \mathrm{Ours}}{\mathrm{DynaSLAM}} \times 100\% \tag{7}$$
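As a worked check of Eq. (7), the s_static ATE values in Table 1 (DynaSLAM 0.09819, Ours 0.06293) give:

$$\frac{0.09819 - 0.06293}{0.09819} \times 100\% = \frac{0.03526}{0.09819} \times 100\% \approx 35.91\%$$

A negative rate, such as −3.26% for s_xyz, means AGAM-SLAM has the larger error on that sequence. Recomputing from the rounded table entries can differ from the published rates in the last digit. For example, w_rpy gives about 38.79% rather than 38.78%.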
Table 1. Absolute trajectory error results (ATE)
Sequence   ORB-SLAM2  DynaSLAM  Ours     Improvement
s_static   0.15701    0.09819   0.06293  35.91%
s_rpy      0.13193    0.22308   0.18167  18.56%
s_hf       0.04515    0.02284   0.02179   4.6%
s_xyz      0.01618    0.02025   0.02091  -3.26%
w_static   2.37277    0.14728   0.11914  19.11%
w_rpy      2.34353    0.06211   0.03802  38.78%
w_hf       0.74442    0.02472   0.02226   9.95%
w_xyz      1.54280    0.02122   0.01932   8.95%
3.2 Experimental Results
We conducted experiments on eight dynamic-scene sequences, and the results are shown in Tables 1 and 2. Compared with DynaSLAM, our model reduces the ATE on seven of the eight sequences (all except s_xyz) and the RPE on six (all except s_static and s_rpy). The gains are largest on the high-dynamic walking sequences: the ATE improves by 19.11% on w_static and by 38.78% on w_rpy. Table 2. Relative pose error results (RPE)
Sequence   ORB-SLAM2  DynaSLAM  Ours     Improvement
s_static   0.00674    0.00617   0.00624  -1.13%
s_rpy      0.01744    0.01995   0.02024  -1.45%
s_hf       0.01242    0.01871   0.01733   7.38%
s_xyz      0.01133    0.01294   0.01286   0.62%
w_static   0.02440    0.00766   0.00685  10.57%
w_rpy      0.03337    0.02276   0.02028  10.9%
w_hf       0.04159    0.01601   0.01353  15.49%
w_xyz      0.03746    0.01475   0.01054  28.54%
Figure 5 shows the absolute trajectory error plots of ORB-SLAM2, DynaSLAM, and our model in highly dynamic environments. The black line indicates the ground-truth trajectory, the blue line the estimated trajectory, and the red line the difference between the true and estimated camera trajectories. Compared with DynaSLAM, the absolute trajectory error of AGAM-SLAM on the w_rpy sequence is markedly lower, and its estimated trajectory closely fits the ground truth. This suggests that our method removes dynamic objects more accurately and effectively reduces their impact on the SLAM system.
3.3 Discussion
The data show that AGAM-SLAM achieves higher accuracy and better pose estimation than DynaSLAM in highly dynamic environments, especially on the w_static and w_rpy sequences. We attribute this mainly to the adaptive thresholding and GAM added to our VSLAM system. Adaptive thresholding makes feature point extraction more stable, while GAM enhances the ability of the Mask R-CNN network model to extract global and local object features, improving the segmentation accuracy of dynamic objects.
[Fig. 5 panels, estimated vs. ground-truth trajectories: (a) ORB-SLAM2_w_static, (b) DynaSLAM_w_static, (c) AGAM-SLAM_w_static; (d) ORB-SLAM2_w_hf, (e) DynaSLAM_w_hf, (f) AGAM-SLAM_w_hf; (g) ORB-SLAM2_w_rpy, (h) DynaSLAM_w_rpy, (i) AGAM-SLAM_w_rpy; (j) ORB-SLAM2_w_xyz, (k) DynaSLAM_w_xyz, (l) AGAM-SLAM_w_xyz.]
Fig. 5. Comparison of estimated trajectories and real trajectories in highly dynamic environments.
4 Conclusion
To reduce the effect of dynamic objects on the visual SLAM system, we integrated the GAM [17] module into the Mask R-CNN [18] network model, retrained the network, and incorporated it into the DynaSLAM system. We also added adaptive thresholding to VSLAM. First, adaptive thresholding can increase the number of features extracted in dynamic environments and reduce tracking problems such as broken and missed frames, improving the robustness of the system. Second, the semantic segmentation module produces masks of dynamic objects, which are passed as prior information to the SLAM tracking thread so that dynamic feature points are removed and only static feature points are used for tracking. Finally, experimental validation on the TUM dataset shows that the overall performance of our system improves significantly in highly dynamic environments compared with DynaSLAM. However, the system still has some shortcomings. First, the semantic segmentation network leaves room for optimization, and its segmentation accuracy needs further improvement. Second, the Mask R-CNN network is complex, and adding GAM further increases the computational load. Finally, the system cannot yet run in real time on embedded devices; in future work we will try lightweight semantic segmentation networks to increase the running speed while maintaining segmentation accuracy. Acknowledgements. We thank Shenzhen Umouse Technology Development Co., Ltd. for their support with equipment and experimental conditions.
References
1. Chang, Y., Tian, Y., How, J.P., Carlone, L.: Kimera-Multi: a system for distributed multi-robot metric-semantic simultaneous localization and mapping. In: 2021 IEEE International Conference on Robotics and Automation (ICRA), pp. 11210–11218. IEEE (2021)
2. Cheng, J., Zhang, L., Chen, Q., Hu, X., Cai, J.: A review of visual SLAM methods for autonomous driving vehicles. Eng. Appl. Artif. Intell. 114, 104992 (2022)
3. Jinyu, L., Bangbang, Y., Danpeng, C., Nan, W., Guofeng, Z., Hujun, B.: Survey and evaluation of monocular visual-inertial SLAM algorithms for augmented reality. Virt. Real. Intell. Hardw. 1(4), 386–410 (2019)
4. Liu, Y., Miura, J.: RDS-SLAM: real-time dynamic SLAM using semantic segmentation methods. IEEE Access 9, 23772–23785 (2021)
5. Wang, H., Ko, J.Y., Xie, L.: Multi-modal semantic SLAM for complex dynamic environments (2022)
6. Li, A., Wang, J., Xu, M., Chen, Z.: DP-SLAM: a visual SLAM with moving probability towards dynamic environments. Inf. Sci. 556, 128–142 (2021)
7. Mur-Artal, R., Montiel, J.M., Tardós, J.D.: ORB-SLAM: a versatile and accurate monocular SLAM system. IEEE Trans. Robot. 31(5), 1147–1163 (2015)
8. Mur-Artal, R., Tardós, J.D.: ORB-SLAM2: an open-source SLAM system for monocular, stereo, and RGB-D cameras. IEEE Trans. Robot. 33(5), 1255–1262 (2017)
9. Yu, C., et al.: DS-SLAM: a semantic visual SLAM towards dynamic environments. In: 2018 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS), pp. 1168–1174. IEEE (2018)
10. Zhao, C., Sun, L., Purkait, P., Duckett, T., Stolkin, R.: Learning monocular visual odometry with dense 3D mapping from dense 3D flow. In: 2018 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS), pp. 6864–6871. IEEE (2018)
11. Konda, K.R., Memisevic, R.: Learning visual odometry with a convolutional network. In: VISAPP 2015, pp. 486–490 (2015)
12. Wu, W., Guo, L., Gao, H., You, Z., Liu, Y., Chen, Z.: YOLO-SLAM: a semantic SLAM system towards dynamic environment with geometric constraint. Neural Comput. Appl., 1–16 (2022)
13. Redmon, J., Farhadi, A.: YOLOv3: an incremental improvement (2018)
14. Bescos, B., Fácil, J.M., Civera, J., Neira, J.: DynaSLAM: tracking, mapping, and inpainting in dynamic scenes. IEEE Robot. Autom. Lett. 3(4), 4076–4083 (2018)
15. Alismail, H., Kaess, M., Browning, B., Lucey, S.: Direct visual odometry in low light using binary descriptors. IEEE Robot. Autom. Lett. 2(2), 444–451 (2016)
16. Ono, Y., Trulls, E., Fua, P., Yi, K.M.: LF-Net: learning local features from images. In: Advances in Neural Information Processing Systems, vol. 31 (2018)
17. Liu, Y., Shao, Z., Hoffmann, N.: Global attention mechanism: retain information to enhance channel-spatial interactions (2021)
18. He, K., Gkioxari, G., Dollár, P., Girshick, R.: Mask R-CNN. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2961–2969 (2017)
19. Woo, S., Park, J., Lee, J.Y., Kweon, I.S.: CBAM: convolutional block attention module. In: Proceedings of the European Conference on Computer Vision (ECCV), pp. 3–19 (2018)
20. Girshick, R.: Fast R-CNN. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 1440–1448 (2015)
21. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
22. Lin, T.-Y., et al.: Microsoft COCO: common objects in context. In: Fleet, D., Pajdla, T., Schiele, B., Tuytelaars, T. (eds.) ECCV 2014. LNCS, vol. 8693, pp. 740–755. Springer, Cham (2014). https://doi.org/10.1007/978-3-319-10602-1_48
23. Sturm, J., Engelhard, N., Endres, F., Burgard, W., Cremers, D.: A benchmark for the evaluation of RGB-D SLAM systems. In: 2012 IEEE/RSJ International Conference on Intelligent Robots and Systems, pp. 573–580. IEEE (2012)
A Water Level Ruler Recognition Method Based on Deep Learning Technology
Jingbo An1, Kefeng Song2(B), Di Wu3(B), and Wanxian He3(B)
1 Yellow River Water and Hydropower Development Group Co., Ltd., Chengdu, China
2 Yellow River Engineering Consulting Co., Ltd., Zhengzhou 450003, China
3 Guangxi Academy of Sciences, No. 98 Daling Road, Xixiangtang, Nanning, Guangxi, China
Abstract. Water level data are important information for flood control command and dispatch, and accurate, effective water level monitoring is essential for irrigation works. The water level ruler is one of the most common tools for measuring water level owing to its low cost and ease of use. Recently, some researchers have attempted to use computer vision technology to read the scale value of the water ruler automatically. However, most of them produced unsatisfactory results because of changing illumination, backgrounds, and other environmental variables. To improve the robustness of the model, we proposed a water level ruler recognition method based on deep learning technology. Specifically, we first used a deep learning-based object detection model to detect the water ruler. Then the character “E” on the ruler was detected by a YOLOv5 model. Finally, we designed a CRNN-based model to recognize the detected characters and obtained the water level by subtracting the detected length of the characters from the total length of the ruler. We collected a water ruler dataset from real scenes, which contains 1160 ruler images and 439 character images. Experimental results on the collected dataset and in real scenes show that our approach performs well, indicating practical value for water level monitoring.
Keywords: deep learning · water level monitoring · water ruler
1 Introduction
The water level of a river is a basic hydrological observation element. Monitoring the water level effectively and accurately is of great significance for reducing the risk of flood disasters and ensuring the safety of residents along the river. A common approach to obtaining the water level is to use water level sensors, such as ultrasonic sensors, rangefinders, optical sensors, and pressure sensors. However, these sensors have shortcomings such as high installation and maintenance costs and sensitivity to the environment. Because water rulers are inexpensive and can be read by visual inspection, many hydrological observation stations rely on staff reading water rulers to measure the water level [1]. Nevertheless, some observation points are in remote and dangerous areas, where manual reading is inefficient, labor-intensive, and inconvenient, and cannot meet high-precision and high-efficiency requirements at the same time.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 40–50, 2023. https://doi.org/10.1007/978-981-99-4761-4_4
A Water Level Ruler Recognition Method
Benefiting from the development of computer science, many scholars have attempted to solve this task with computer vision technology. Iwahashi et al. [2] used an image operator to obtain the water line in sequential video frames, but the method could not handle irregular water lines and was time-consuming. Feng et al. [3] obtained better results using additional handcrafted image features and machine learning technology. Later, more applicable approaches based on digital image processing were presented [5–7]. Since 2013, deep learning has come to dominate computer vision [8–11], and various deep convolutional neural network models have been applied to detect water level rulers or recognize water lines with higher accuracy. Dou et al. [12] proposed a water level recognition method based on a convolutional neural network and image processing technology. First, they used digital image processing techniques to pre-process the source images. Then a CNN was designed to identify the values of the digital characters. Finally, the water level value was calculated from the mathematical relationship between the scale lines and the character values. Jafari et al. [4] used reference objects in videos and images to estimate water levels over time. In particular, they developed a deep learning-based segmentation technique to estimate flood levels and verified the algorithm with laboratory experiments and two urban streams. Although the above-mentioned methods achieved some success, most of them are limited to a few specific situations and are hampered by changing illumination, unsteady backgrounds, and other environmental factors. Moreover, maintaining model accuracy while improving the efficiency of water level recognition remains challenging. To improve the robustness and efficiency of the recognition algorithm, this study proposes a novel water level recognition method based on deep learning technology. 
The method contains three major modules: 1) a water level ruler detection module, 2) a character detection module, and 3) a character recognition module. In module 1), we use a one-stage detector with a lightweight backbone to obtain the water level ruler region of interest, and we modify the anchors to simplify training. In module 2), rather than detecting all characters on the water level ruler, we focus on the character “E” because it is more easily recognized than the other alphanumeric characters. In module 3), we classify the character “E” nearest to the water surface into a finer set of categories. We conducted extensive experiments on both the collected dataset and real scene data to verify the effectiveness of the proposed method. In character detection, we achieved 0.981 precision and 0.980 recall. In character recognition, our method yielded 0.952 precision and 0.961 recall, an improvement of approximately 19% over traditional CV techniques and 3% over a fully connected neural network.
2 Related Work
Recently, with the advancement of deep learning technology, deep convolutional neural networks have been widely applied to a variety of computer vision tasks. Lin et al. [14] detected water level rulers with a deep learning model, extracted the characters by dilation and connected-component searches, and then classified each character with machine learning algorithms. Tao et al. [15] used SSD, a practical one-stage object detector, to obtain the bounding
J. An et al.
box of a water level ruler and localized each character with the Sobel operator. Wang et al. [16] also used a one-stage detector but with a deeper convolutional neural network, ResNet-34, as the backbone, improving detection accuracy over previous methods. Despite these improvements in detection, changing illumination and different camera positions make it exceedingly difficult to segment the characters with classic image processing techniques. Shan et al. [17] used a deep learning object detection model to perform water level ruler detection and character identification simultaneously, without image processing methods, and thus read water levels with an end-to-end model. However, the backbone of this model is ResNet-101, a deep convolutional neural network with many parameters, which requires substantial computing power for training and inference. Besides, an additional model was built to localize the water line, further increasing the computational cost.
3 Methods
To address the problems discussed in Sect. 2, we present a water level reading method based on character detection and recognition. It can recognize the complete or incomplete character nearest to the water surface as well as all other characters above the water, and it obtains the water level from the output. The model consists of three major modules: (1) water level ruler detection, (2) character detection, and (3) character recognition.
3.1 Water Level Ruler Detection
To detect the water level ruler efficiently and accurately, we use a one-stage deep model based on YOLOv5s, a lightweight CNN architecture designed for a favorable trade-off between speed and accuracy. Instead of a large backbone, the model uses a lightweight backbone built from internal blocks such as CBL, Res-Unit, and CSP1 [18] (see Fig. 1). The CBL is a basic block that combines a convolutional layer, batch normalization, and a leaky ReLU activation function. The Res-Unit is a basic residual block constructed from several CBL blocks. The CSP1 block is the major model component and has two branches: one applies the CBL, Res-Unit, and convolutional layers sequentially, and the other is a convolutional layer in parallel with the first. The outputs of the two branches are concatenated and passed to the subsequent layers. At the input, an equal-interval downsampling step (the Focus block) reduces the input size and enlarges the receptive field of each pixel while preserving semantic information. Following the backbone, the neck of the model uses an FPN + PAN [19] architecture with CSP2 blocks (similar to CSP1 blocks) to extract multiscale features.
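The equal-interval downsampling of the Focus block can be illustrated as a space-to-depth slice. Four pixel-interleaved sub-images are stacked along the channel axis, which halves the height and width while keeping every input value. This NumPy sketch mirrors the slicing pattern of YOLOv5's Focus module as we understand it. The convolution that YOLOv5 applies after the slicing is omitted.

```python
import numpy as np


def focus_slice(x):
    """Space-to-depth slicing in the style of YOLOv5's Focus block (before its conv).

    x: (C, H, W) with even H and W -> (4C, H/2, W/2). Every input pixel is kept,
    so the spatial size halves without discarding information.
    """
    return np.concatenate(
        [x[:, ::2, ::2], x[:, 1::2, ::2], x[:, ::2, 1::2], x[:, 1::2, 1::2]], axis=0
    )


x = np.arange(3 * 4 * 6).reshape(3, 4, 6)
y = focus_slice(x)
print(y.shape)  # (12, 2, 3)
```

Because no values are dropped, the following convolution sees the full input at half the resolution, which is why each output pixel covers a larger area of the original image.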
Fig. 1. YOLOv5s internal block. (a) CBL and Res-Unit. (b) CSP1_X and CSP2_X
YOLO defines anchors for feature maps at various scales, and each predicted bounding box is obtained by regressing and classifying an anchor. A water level ruler has a very uniform shape that sets it apart from other objects in the scene, so we removed anchors whose shapes differ markedly from that of the ruler, which makes the regression task easier than using anchors of all scales.
3.2 Character Detection
Character detection is performed on the output of water level ruler detection. Characters on the water level ruler fall into two categories: numeric characters belonging to {1, 2, 3, …, 9, 0}, and special characters that look like the letter “E” or its mirror image (see Fig. 2). We detect only the special “E” characters and ignore the others, because the numeric characters differ visually from one another, whereas the “E” characters are almost identical and easily recognized. Because the input images are much smaller than those used for ruler detection, we use the same network architecture as in Sect. 3.1 but remove the Focus block to prevent excessive downsampling.
Fig. 2. Characters on a water level ruler. Numeric characters and the special “E” character.
J. An et al.
3.3 Character Recognition
The “E” character nearest to the water surface may be complete or incomplete. Because of its layered structure, it changes shape more visibly than the numeric characters as the water covers it at different heights, and these changes follow a sequential pattern. We therefore apply a mixed deep model based on CRNN to recognize this character more precisely. The CRNN consists of three components: a CNN, an RNN, and a transcription layer [20]. The CNN component is a sequential stack of convolutional layers, pooling layers, and batch normalization (see Fig. 3 (a)). Since most text images are short and wide, the window of the final pooling layer is changed from 2 × 2 to 1 × 2. This downsamples the feature map while preserving its width. In practice, we convert the input image to grayscale with a fixed size (channel, height, width) = (1, 30, 15), and the CNN output has size (512, 1, 4).
The RNN component follows the CNN and takes its output as input. The CNN feature map is treated as a sequence of length 4 whose elements are 512-dimensional column vectors, and each column vector is the input for one RNN time step. A conventional RNN suffers from vanishing gradients and captures little context, so we use a Bi-LSTM instead. The Bi-LSTM captures long-range dependencies in both the forward and backward directions, so the stacked architecture extracts richer context features. It outputs logits that classify each time-step feature into a set of categories, comprising the complete “E” character and several incomplete classes. Each category corresponds to a length, which yields a more precise water level value.
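The map-to-sequence step turns the (512, 1, 4) CNN output into four 512-dimensional RNN inputs, and can be sketched as below. `map_to_sequence` is a hypothetical name, and nested lists replace framework tensors. In a deep learning framework the same effect is typically obtained by dropping the height axis and moving the width axis to the front.

```python
from typing import List

FeatureMap = List[List[List[float]]]  # indexed [channel][row][col]


def map_to_sequence(fmap: FeatureMap) -> List[List[float]]:
    """Turn a (C, 1, W) CNN feature map into W column vectors of length C.

    Column t of the map becomes the input of RNN time step t, so the
    sequence runs left to right across the image.
    """
    channels = len(fmap)
    height = len(fmap[0])
    if height != 1:
        raise ValueError(f"expected feature-map height 1, got {height}")
    width = len(fmap[0][0])
    return [[fmap[c][0][t] for c in range(channels)] for t in range(width)]


if __name__ == "__main__":
    # Dummy (512, 1, 4) map whose value encodes (channel, column) as 10*c + t.
    fmap = [[[10.0 * c + t for t in range(4)]] for c in range(512)]
    seq = map_to_sequence(fmap)
    print(len(seq), len(seq[0]))  # 4 time steps, 512 features each
```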
The RNN output is fed to the transcription component, and the connectionist temporal classification (CTC) loss is used to train the CNN and RNN jointly end to end. Figure 3 (b) illustrates the RNN and transcription components. To improve efficiency, we recognize only the character nearest to the water surface. The length of the water level ruler above the water surface is then computed using Eq. (1):

$$\text{length} = (N - 1)\,M + S \tag{1}$$
where N is the number of “E” characters in the character detection output, M is the length of a complete character, and S is the length associated with the category recognized for the character nearest to the water surface (one of the complete or incomplete categories).
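Equation (1) translates directly into a small function. The category-to-length table and the 5 cm character length below are hypothetical illustration values, because the paper lists neither its categories nor the ruler's dimensions.

```python
def length_above_water(n_chars: int, full_char_len: float, partial_len: float) -> float:
    """Eq. (1): length = (N - 1) * M + S.

    n_chars:       N, number of "E" characters detected above the water
    full_char_len: M, physical length of one complete "E" character
    partial_len:   S, length assigned to the category recognized for the
                   character nearest to the water surface
    """
    if n_chars < 1:
        raise ValueError("at least one character must be detected")
    # The N - 1 characters above the nearest one are assumed to be complete.
    return (n_chars - 1) * full_char_len + partial_len


# Hypothetical category-to-length table in cm (not taken from the paper).
CATEGORY_LENGTH_CM = {"complete": 5.0, "partial_3": 3.0, "partial_1": 1.0}

if __name__ == "__main__":
    print(length_above_water(12, 5.0, CATEGORY_LENGTH_CM["partial_3"]))  # 58.0
```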
4 Experiments
We collected two datasets from real-world sources. The first contains 1160 images with ruler (object) and character annotations. The second contains 439 images of complete and incomplete special “E” characters, some of which are shown in Fig. 4. All images were taken at the Yellow River and vary in brightness. Experiments are conducted on a PC with 32 GB of memory, an i7-9900K 3.7 GHz CPU, and an NVIDIA RTX 2080-Ti GPU.
Fig. 3. CRNN architecture. (a) CNN component (b) RNN and transcription components
Fig. 4. Character recognition dataset. (a) Complete special “E” character. (b) Incomplete “E” character. (c) “E” character labels with various flooding levels.
4.1 Comparisons of Model Complexity
We compare the model described in Sect. 2 of [17] with ours in terms of backbone size and running speed. The results in Table 1 show that our method has a lighter backbone and runs faster.
Table 1. Comparison of the model in Sect. 2 of [17] and our model

Model                      Parameters   Speed (FPS)
FasterRCNN (ResNet-101)    85205312     17.7
Ours (YOLOv5s)             1179648      52.8
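The relative figures behind Table 1 can be checked with a few lines of arithmetic. The values are copied from the table. The per-frame time is simply 1000 / FPS, which ignores any batching or pipelining effects.

```python
# Figures copied from Table 1.
params = {"FasterRCNN (ResNet-101)": 85_205_312, "Ours (YOLOv5s)": 1_179_648}
fps = {"FasterRCNN (ResNet-101)": 17.7, "Ours (YOLOv5s)": 52.8}

param_ratio = params["FasterRCNN (ResNet-101)"] / params["Ours (YOLOv5s)"]
speedup = fps["Ours (YOLOv5s)"] / fps["FasterRCNN (ResNet-101)"]
latency_ms = {name: 1000.0 / f for name, f in fps.items()}  # time per frame

print(f"parameter ratio: {param_ratio:.1f}x")  # 72.2x
print(f"speed-up:        {speedup:.2f}x")      # 2.98x
for name, ms in latency_ms.items():
    print(f"{name}: {ms:.1f} ms/frame")        # 56.5 and 18.9 ms/frame
```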
4.2 Results of the Detection Experiment
We compared our method with a boosting method and a template-matching method. The boosting method is based on the AdaBoost algorithm with local binary pattern (LBP) features, and the template-matching method uses six images as templates. The detection results in Fig. 5 show that our method outperforms both traditional CV methods.
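The LBP features used by the AdaBoost baseline can be illustrated with the basic 8-neighbour operator. The paper does not give the LBP radius, neighbour count or whether uniform patterns are used. This sketch assumes the simplest 3 × 3 variant with a 256-bin histogram, and the function names are hypothetical.

```python
from typing import List, Sequence

# Clockwise neighbour order starting at the top-left pixel of a 3x3 patch.
NEIGHBOUR_OFFSETS = [(0, 0), (0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0)]


def lbp_code(patch: Sequence[Sequence[int]]) -> int:
    """Basic 8-neighbour LBP code of the centre pixel of a 3x3 patch."""
    center = patch[1][1]
    code = 0
    for bit, (r, c) in enumerate(NEIGHBOUR_OFFSETS):
        if patch[r][c] >= center:  # neighbour at least as bright -> bit set
            code |= 1 << bit
    return code


def lbp_histogram(image: Sequence[Sequence[int]]) -> List[int]:
    """256-bin histogram of LBP codes over all interior pixels."""
    hist = [0] * 256
    for r in range(1, len(image) - 1):
        for c in range(1, len(image[0]) - 1):
            patch = [row[c - 1:c + 2] for row in image[r - 1:r + 2]]
            hist[lbp_code(patch)] += 1
    return hist
```

The histogram, computed over an image or over a grid of cells, is what a boosting classifier would typically receive as its feature vector.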
Fig. 5. Character detection results on the test set. (a)–(d) are the detection results for different brightnesses, angles, and deformations. The sub-diagrams on the left of (a)–(d) are the results of the multi-template-match method, those in the middle are the results of the Adaboost algorithm with LBP features, and those on the right are the results of our approach.
4.3 Results of the Recognition Experiment
Table 2 and Fig. 6 present the “E” character recognition results of several traditional CV/ML methods, a fully connected network (FCN) model, and our method. Our approach recognizes complete and incomplete “E” characters under various deformations and brightness levels. It outperforms all the other methods on recall, precision, and F1-score.
Table 2. Experimental results of character recognition on the test set

Method           Recall   Precision   F1-score
KNN (CD + HU)    0.742    0.733       0.725
GBDT (CD + HU)   0.803    0.804       0.794
FCN              0.929    0.937       0.928
Ours             0.961    0.952       0.952
F1-score = (2 × precision × recall)/(precision + recall). In this experiment, the KNN uses k = 7 neighbors, and the GBDT uses a tree depth of 8 and 128 leaves. Both take the image central distance (CD) and Hu moments (HU) as input features. The FCN consists of two hidden linear layers (256 nodes each) with the tanh activation function, and its input is a binary image of size 24 × 24.
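The F1 definition above can be paired with per-class averaging over a confusion matrix such as those in Fig. 6. The harmonic mean of the Table 2 values for our method, 2 × 0.961 × 0.952 / 1.913 ≈ 0.956, is not exactly the listed 0.952. That gap is consistent with averaging F1 per class, but the paper does not state its averaging scheme, so this macro-averaging sketch is an assumption. The function names are hypothetical.

```python
from typing import List, Tuple


def f1(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (0 when both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def macro_scores(cm: List[List[int]]) -> Tuple[float, float, float]:
    """Macro-averaged precision, recall and F1 from a confusion matrix.

    cm[i][j] counts samples whose true class is i and predicted class is j.
    """
    n = len(cm)
    precisions, recalls, f1s = [], [], []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum
        actual_k = sum(cm[k])                          # row sum
        p = tp / predicted_k if predicted_k else 0.0
        r = tp / actual_k if actual_k else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f1(p, r))
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n


if __name__ == "__main__":
    p, r, f = macro_scores([[2, 0], [1, 1]])
    # Macro F1 (about 0.733) differs from f1(p, r) (about 0.789).
    print(round(p, 3), round(r, 3), round(f, 3), round(f1(p, r), 3))
```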
Fig. 6. Confusion matrix of “E” character recognition. (a) KNN, (b) GBDT, (c) FCN, (d) Ours.
4.4 Experiment Results on the Real Scene Data
To verify the effectiveness of the method in practice, we embedded the proposed model into edge computing equipment and deployed it at the Madu dangerous-section project on the Yellow River. The equipment consists of three major parts: a high-definition infrared camera, a wireless communication module, and a data processing module. The camera captures video frames of the water level ruler. The data processing module runs the algorithm that produces the water level recognition result, and the wireless communication module uploads the results to the cloud. The verification lasted about four months. We used an infrared camera with 720P resolution and a 6 mm focal length, and set the observation distance between the equipment and the water level ruler to 50 m. For comparison, the water level was also read manually on site. The actual operating results are shown in Fig. 7: the water level rulers are detected correctly and the water level values are read accurately.
Fig. 7. Test results in Madu. The upper left corner of each sub-image shows the annotation of the length of the water level ruler above the water.
Table 3 presents part of the comparison results. Ground truth is the manually measured value, and our result is the value recognized by the proposed method. As the table shows, the reading errors of our method lie within ±0.03 m, with a mean absolute error of 0.015 m over these 12 samples. We attribute most of these errors to the water surface rising and falling under environmental factors such as wind and waves. These errors are therefore within the acceptable range for water level measurement.
Table 3. Comparison experiment results

Ground truth (m)   Our result (m)   Reading error (m)
1.14               1.14              0
1.05               1.07             +0.02
1.16               1.14             −0.02
1.16               1.16              0
1.18               1.17             −0.01
1.15               1.13             −0.02
0.97               0.94             −0.03
0.84               0.83             −0.01
1.40               1.41             +0.01
0.76               0.78             +0.02
1.10               1.11             +0.01
0.87               0.84             −0.03
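The error statistics for Table 3 can be recomputed directly from its rows. The sketch works in whole centimetres to avoid floating-point noise. It takes the reading error as our result minus ground truth, which matches the signs in the table.

```python
import math

# (ground truth, our result) pairs from Table 3, converted to centimetres.
PAIRS_CM = [
    (114, 114), (105, 107), (116, 114), (116, 116), (118, 117), (115, 113),
    (97, 94), (84, 83), (140, 141), (76, 78), (110, 111), (87, 84),
]

errors = [ours - gt for gt, ours in PAIRS_CM]  # reading error = our result - ground truth
mae = sum(abs(e) for e in errors) / len(errors)
bias = sum(errors) / len(errors)
rmse = math.sqrt(sum(e * e for e in errors) / len(errors))
max_abs = max(abs(e) for e in errors)

print(f"MAE  = {mae:.2f} cm")   # 1.50 cm
print(f"bias = {bias:.2f} cm")  # -0.50 cm
print(f"RMSE = {rmse:.2f} cm")  # 1.78 cm
print(f"max  = {max_abs} cm")   # 3 cm
```

The small negative bias means the method reads slightly low on average over these samples. The RMSE stays below 2 cm.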
5 Conclusion
In this paper, we discussed the problem of water level reading using computer vision technologies and proposed a water level reading method based on deep learning. The approach consists of three major modules: 1) water level ruler detection, 2) character detection, and 3) character recognition. Compared with existing methods, our method requires neither an additional model to detect water lines nor a traversal operation, and it maintains precision and real-time performance. Additionally, we contributed a dataset collected from the Yellow River in the real world. It includes 1160 scene images for detection and 439 special character images for recognition, and it served as the data foundation for this work. In the experiments, our method achieves 0.980 recall and 0.981 precision on detection and 0.961 recall and 0.952 precision on recognition. In summary, we believe that our method is efficient and practical and can be used to measure water levels in real-world environments. In future studies, we will concentrate on multi-object detection of water level rulers and attempt to employ an end-to-end model to simplify training and improve efficiency.
Acknowledgments. This work was supported in part by the National Science Foundation of GuangXi Province under Grant 2021JJA170199.
References
1. Chen, G., et al.: Method on water level ruler reading recognition based on image processing. SIViP 15(1), 33–41 (2020)
2. Iwahashi, M., Udomsiri, S.: Water level detection from video with fir filtering. In: 2007 16th International Conference on Computer Communications and Networks, pp. 826–831. IEEE (2007)
3. Feng, J., et al.: An estimation method of water level identification based on image analysis. Jiangsu Province: CN109522889A, 2019–03–26
4. Jafari, N.H., et al.: Real-time water level monitoring using live cameras and computer vision techniques. Comput. Geosci. 147, 104642 (2021)
5. Yu, L., et al.: Convolutional neural networks for water body extraction from Landsat imagery. Int. J. Comput. Intell. Appl. 16(01), 1750001 (2017)
6. Sabbatini, L., et al.: A computer vision system for staff gauge in river flood monitoring. Inventions 6(4), 79 (2021)
7. Narayanan, R.K., et al.: A novel approach to urban flood monitoring using computer vision. In: Fifth International Conference on Computing, Communications and Networking Technologies (ICCCNT), pp. 1–7. IEEE (2014)
8. Wu, D., et al.: Attention deep model with multi-scale deep supervision for person reidentification. IEEE Trans. Emerg. Top. Comput. Intell. 5(1), 70–78 (2021)
9. Wu, Y., et al.: Person reidentification by multiscale feature representation learning with random batch feature mask. IEEE Trans. Cogn. Dev. Syst. 13(4), 865–874 (2020)
10. Wu, D., et al.: Deep learning-based methods for person re-identification: a comprehensive review. Neurocomputing 337, 354–371 (2019)
11. Liang, X., Wu, D., Huang, D.S.: Image co-segmentation via locally biased discriminative clustering. IEEE Trans. Knowl. Data Eng. 31(11), 2228–2233 (2019)
12. Dou, G., et al.: Research on water-level recognition method based on image processing and convolutional neural networks. Water 14(12), 1890 (2022)
13. Zhang, Y., et al.: Water level recognition method based on water level ruler image. Jiangsu Province: CN108921165B, 2022–04–22 (2022)
14. Feng, L., et al.: A water level ruler recognition method based on deep learning[P]. Zhejiang Province: CN110427933A, 2019–11–08 (2019)
15. Zhuo, T., Qing-Chuan, T., Shen, J.J.: Video water level detection algorithm based on SSD target detection. Modern Comput. (Prof.) 09, 60–64 (2019)
16. Wang, L., et al.: Research on water level recognition method based on deep learning algorithm. Water Resour. Inf. 03, 39–43+56 (2020). https://doi.org/10.19364/j.1674-9405.2020.03.009
17. Shan, S.H., et al.: Deep learning based water level ruler e-zigzag scale recognition method. FuJian Province: CN110472636B, 2022–10–14 (2022)
18. Chae, J.W., et al.: Swoon monitoring system based on YOLOv4-CSP object detection algorithm. Trans. Korean Inst. Electr. Eng. 71(1), 239–245 (2022)
19. Liu, S., et al.: Path aggregation network for instance segmentation. In: 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8759–8768. IEEE (2018)
20. Shi, B., et al.: An end-to-end trainable neural network for image-based sequence recognition and its application to scene text recognition. IEEE Trans. Pattern Anal. Mach. Intell. 39(11), 2298–2330 (2017)
FRVidSwin: A Novel Video Captioning Model with Automatical Removal of Redundant Frames
Zehao Dong1, Yuehui Chen2(B), Yi Cao3, and Yaou Zhao2
1 University of Jinan, Jinan, Shandong, China
2 Artificial Intelligence Institute (School of Information Science and Engineering), University of Jinan, Jinan, Shandong, China
3 Shandong Provincial Key Laboratory of Network Based Intelligent Computing (School of Information Science and Engineering), University of Jinan, Jinan, Shandong, China
Abstract. Video captioning aims to generate natural language sentences that describe the visual content of given videos, which requires long-range temporal modeling and consumes significant computational resources. Existing methods typically operate on frames uniformly sampled from videos, leading to time scale inconsistency and redundancy in contiguous frames. In this paper, we propose a transformer-based architecture called Frame-Reduce Swin transformer (FRVidSwin) for video captioning. Our method takes a frame sequence along with the frame indices sampled from a video as input and outputs a natural language sentence describing its content. The FRVidSwin Encoder automatically evaluates the importance of each frame in the video using self-attention and discards redundant ones, reducing computational cost. This allows the model to focus on informative frames to generate high-quality features, improving text synthesis. We propose the Time Index Position Encoding based on Roformer, where the frame indices in the original video are kept and directly encoded. This preserves the time flow consistent with the original video, facilitating the model’s perception of slow and fast movements. Experimental results show that our model can generate high-quality captions and outperforms mainstream models, such as HMN and ORG-TRL, on MSVD and MSR-VTT benchmarks.
Keywords: Computer vision · Video caption · Machine Learning
1 Introduction
The task of video captioning is to generate a series of sentences that describe the content of a video clip. This requires a model to fully extract spatial-temporal features from the clip and then reorganize them into a representation suitable for generating proper text. Similar to image captioning [1], video captioning involves high-level understanding of visual data. However, in addition to detecting and recognizing fundamental elements in images such as objects and scenes, a video captioning model must also recognize the behavior of various objects over time based on their spatial-temporal change patterns. This task is more challenging than action recognition, which requires a model to capture key features to determine one main activity in the video clip. In contrast, video captioning requires comprehensive summarization to express the relationships among different activities, including causality.
Traditional methods [2, 3] construct a frame sequence by evenly drawing a fixed number of frames from a video clip. An image classification model or an action recognition model is then applied frame by frame to obtain features for each frame, capturing spatial information. Many pretrained models, such as a ResNet backbone [4], optical flow [5], and C3D [6], can be used in this phase. The resulting feature sequence is then processed by a sequence model to obtain temporal information. Finally, a language model generates the text based on these features. Since the extracted features greatly influence performance, researchers initially focused on designing powerful encoders to generate better representations [7]. These models are usually pretrained on large amounts of data and then applied to various tasks via transfer learning. However, data distributions and task characteristics generally differ considerably, which makes fine-tuning the models difficult and time-consuming.
Transformers have become widely used in video processing tasks due to their advantages in handling sequence data. Novel transformer architectures such as ViViT [8] and TimeSformer [24] incorporate spatial-temporal attention. ViViT employs self-attention to process video clips of variable lengths, and its flexibility and expressivity make it possible to design end-to-end video captioning models. VidSwin [9] introduces a local window mechanism to extract local features at different scales, reducing the computational cost. With these diverse features, VidSwin achieved state-of-the-art (SOTA) results in various video analysis tasks.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 51–62, 2023. https://doi.org/10.1007/978-981-99-4761-4_5
However, self-attention still faces two major issues when dealing with video data. (1) It consumes significant computational resources due to the high dimensionality of video data, even for low-resolution videos. (2) Most models focus only on spatial information extraction and neglect the correlations or differences between spatial and temporal information. A recent study [10] found that redundant information across contiguous frames does not improve downstream tasks. To address this, the authors proposed a sparse sampling strategy with a ClipBERT model, which was notably successful in video question answering and text-video retrieval. However, it remains unclear whether this sampling method can generate adequate video representations. Moreover, many existing models process temporal information inadequately, and some ignore it completely. For example, average pooling discards crucial temporal information. Other models only roughly order frames by time and thus cannot represent time precisely or distinguish between slow and fast actions.
In this paper, we introduce a new video encoder that addresses these challenges and learns comprehensive video representations end to end. Our proposed model, called Frame-Reduce Swin (FRVidSwin), is inspired by the Swin Transformer and uses self-attention layers to automatically filter out repetitive frames. This is achieved through an uneven sparse sampling strategy that is embedded within the network and trained to distinguish between redundant and informative frames. Furthermore, to keep the precise temporal information of each frame, we incorporate Time Index Position Encoding to encode the frame index, thereby improving the accuracy of the generated captions. The contributions of our work are as follows:
• The proposed FRVidSwin encoder performs sparse sampling without extra computation. With this design, FRVidSwin can process videos of variable length end to end. The performance of FRVidSwin is less sensitive to the number of input frames, showing better generalization;
• The proposed position encoding method solves the time scale inconsistency problem caused by sparse sampling.
2 Methods
2.1 Model Architecture
[Fig. 1 diagram: Video → frame sampling (frame image, frame number) → FRVidSwin Encoder → video tokens with frame token indices → Index Position Encoding → Sentence Decoder → caption, e.g. “a girl is riding a bike around a parking lot”]
Fig. 1. The architecture of FRVidSwin. Given a frame sequence along with the frame indices sampled from a video as the input, FRVidSwin outputs a natural language sentence describing its content.
Figure 1 depicts the overall architecture of the proposed FRVidSwin, which generates natural language descriptions of video content from input video frames and their associated indices. The model comprises three main components: an FRVidSwin Encoder, an Index Position Encoding Block, and a Sentence Decoder. The FRVidSwin Encoder extracts spatial-temporal information from the input video and automatically filters out redundant frames to enhance sparsity. The Index Position Encoding Block encodes the original time indices of the preserved video tokens, and thereby the temporal gaps between them, and integrates the resulting time position codes with the frame features. The Sentence Decoder uses a sequence-to-sequence (seq2seq) model to generate natural language sentences. Each module is explained in detail in the following subsections.
[Fig. 2 diagram: 3D W-MSA/3D SW-MSA → attention matrices → per-frame scores; frame index list [0 3 6 10 13 16 20 23 26 30 33 36 40 43 …]]
Fig. 2. Schematic diagram of the FR module. The average attention score of the tokens in each video feature frame is taken as the importance measure of that frame.
2.2 FRVidSwin Encoder
As mentioned in [1], long frame sequences or a dense sampling strategy can benefit video understanding. However, sampling more frames yields highly redundant video data, as most parts of a video change slowly. This repetitive information does not contribute to learning and increases computational cost if it is not filtered out. To address this issue, we designed a Frame-Reduce Video Swin Encoder (FRVidSwin Encoder) that evaluates the importance of each frame and drops unimportant ones. The FRVidSwin Encoder is adapted from the VidSwin model [2], pre-trained on the Kinetics action recognition dataset. As shown in Fig. 2, the attention matrix produced by the 3D W/SW-MSA operation in FRVidSwin measures the attention between each patch token and the other patch tokens of the same window during self-attention. After the 3D W/SW-MSA module, we measure the contribution of each video frame by the average attention score of the tokens in that video feature frame. This measure treats frames fairly because the attention scores are normalized within each window when the video feature sequence is divided into windows. We let F = {F1, F2, F3, F4} denote the number of frame tokens remaining after each block. The average attention score of a video token's patches in the W-MSA computation serves as the score of that video token.
The model input is a video clip with T frames, each an RGB image with resolution H × W. The clip is first converted into a tensor of shape [T, H, W, 3]. Then every 2 contiguous frames are grouped and divided into patches of shape [2 × 4 × 4 × 3], each called a 3D-patch token. The input tensor is thus converted into a token tensor of shape [T/2, H/4, W/4, 96] whose elements are 96-dimensional vectors. This token tensor then passes through 4 stages, as shown in Fig. 3. Within each stage, features are extracted by VidSwin transformer blocks, and after each stage, redundant frame features are dropped by the Frame Reduce Block.
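The frame-reduce step scores frames from a window's attention matrix, keeps the top-F frames in temporal order and carries their original indices along. It can be sketched as below. The paper does not spell out which axis of the attention matrix is averaged. Each softmax row sums to 1, so averaging rows per frame would give every frame the same score, and this sketch therefore assumes the average attention each frame's tokens receive (column averages). It treats a single window and head, and all names are hypothetical.

```python
from typing import List, Sequence, Tuple


def frame_scores(attn: Sequence[Sequence[float]], frame_of_token: Sequence[int],
                 num_frames: int) -> List[float]:
    """Average attention received by the tokens of each frame.

    attn[q][k] is a softmax-normalised weight from query token q to key
    token k within one window; frame_of_token[k] is the temporal slot of
    token k. Column averaging is an assumption about the paper's score.
    """
    totals = [0.0] * num_frames
    counts = [0] * num_frames
    for q_row in attn:
        for k, w in enumerate(q_row):
            totals[frame_of_token[k]] += w
            counts[frame_of_token[k]] += 1
    return [t / c if c else 0.0 for t, c in zip(totals, counts)]


def reduce_frames(scores: Sequence[float], frame_indices: Sequence[int],
                  keep: int) -> Tuple[List[int], List[int]]:
    """Keep the `keep` highest-scoring frames in their original temporal order.

    Returns (positions kept in the current sequence, original frame indices).
    """
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    kept = sorted(ranked[:keep])  # restore temporal order after ranking
    return kept, [frame_indices[i] for i in kept]
```

In the standard model, this step would be applied after each stage with the schedule F = (28, 24, 20, 16). The returned original indices are what the Time Index Position Encoding later consumes.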
In our standard model, the hyperparameter F is set to F1 = 28, F2 = 24, F3 = 20, F4 = 16. Finally, the FRVidSwin Encoder outputs a sequence of 16 video tokens together with the index of the input video frame corresponding to each token.
2.3 Time Index Position Encoding
Fig. 3. Schematic diagram of the FRVidSwin Encoder. The number of frames is reduced linearly over the 4 blocks according to the set target number of frames or frame rate.
As noted in [29], video clips are often sampled at a constant time interval to build a frame sequence, and the sampled frames are re-indexed sequentially. This approach may not be optimal for videos of varying lengths. Adjacent sampled frames in a short video are closer in time than those in a long video, yet both are re-indexed uniformly. Information is thus lost, because the model cannot distinguish between slowly and quickly changing movements. The FRVidSwin Encoder, which discards frames unevenly, exacerbates this issue. To address this problem, we introduce a Time Index Position Encoding method. A video clip is uniformly sampled to build a frame sequence, but the original frame indices in the video are kept in the sequence as absolute positions. Throughout the FRVidSwin Encoder, the original frame indices are preserved, so it is always known which redundant frames were dropped. By maintaining this positional information, the decoder can correctly perceive the time flow in the sequence. Therefore, before the Sentence Decoder, these frame indices are encoded using the non-trainable Roformer [30] positional encoding method, as shown in Eq. (1) and Eq. (2).

$$f(\mathbf{V}_m, I_m) = \Phi_{I_m}\mathbf{V}_m \tag{1}$$

$$\Phi_{I_m} = \begin{pmatrix}
\cos I_m\theta_0 & -\sin I_m\theta_0 & 0 & 0 & \cdots & 0 & 0\\
\sin I_m\theta_0 & \cos I_m\theta_0 & 0 & 0 & \cdots & 0 & 0\\
0 & 0 & \cos I_m\theta_1 & -\sin I_m\theta_1 & \cdots & 0 & 0\\
0 & 0 & \sin I_m\theta_1 & \cos I_m\theta_1 & \cdots & 0 & 0\\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots\\
0 & 0 & 0 & 0 & \cdots & \cos I_m\theta_{d/2-1} & -\sin I_m\theta_{d/2-1}\\
0 & 0 & 0 & 0 & \cdots & \sin I_m\theta_{d/2-1} & \cos I_m\theta_{d/2-1}
\end{pmatrix} \tag{2}$$
Here $\mathbf{V}_m$ is the m-th video token in the video token sequence and $I_m$ is its frame index. The token dimension is d, and $\theta_0, \dots, \theta_{d/2-1}$ are fixed rotation frequencies (RoFormer uses $\theta_i = 10000^{-2i/d}$). Because $\Phi_{I_m}$ is sparse (block diagonal), Eq. (1) can be computed more efficiently with Eq. (3), where ⊗ denotes element-wise multiplication.

$$f(\mathbf{V}_m, I_m) =
\begin{pmatrix} V_{m,0}\\ V_{m,1}\\ V_{m,2}\\ V_{m,3}\\ \vdots\\ V_{m,d-2}\\ V_{m,d-1} \end{pmatrix} \otimes
\begin{pmatrix} \cos I_m\theta_0\\ \cos I_m\theta_0\\ \cos I_m\theta_1\\ \cos I_m\theta_1\\ \vdots\\ \cos I_m\theta_{d/2-1}\\ \cos I_m\theta_{d/2-1} \end{pmatrix} +
\begin{pmatrix} -V_{m,1}\\ V_{m,0}\\ -V_{m,3}\\ V_{m,2}\\ \vdots\\ -V_{m,d-1}\\ V_{m,d-2} \end{pmatrix} \otimes
\begin{pmatrix} \sin I_m\theta_0\\ \sin I_m\theta_0\\ \sin I_m\theta_1\\ \sin I_m\theta_1\\ \vdots\\ \sin I_m\theta_{d/2-1}\\ \sin I_m\theta_{d/2-1} \end{pmatrix} \tag{3}$$

By the angle-difference identities of the trigonometric functions, Eq. (4) holds. Hence the Time Index Position Encoding reflects the temporal distance between two frames during self-attention.

$$(\Phi_{I_m}\mathbf{V}_m)^{\top}(\Phi_{I_n}\mathbf{V}_n) = \mathbf{V}_m^{\top}\Phi_{I_m}^{\top}\Phi_{I_n}\mathbf{V}_n = \mathbf{V}_m^{\top}\Phi_{I_n-I_m}\mathbf{V}_n \tag{4}$$
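Eq. (3) can be implemented without building the matrix of Eq. (2): rotate each coordinate pair by the frame index times its frequency. The frequencies follow the RoFormer default θ_i = 10000^(−2i/d), which is an assumption because the paper does not list its frequencies. `rotary_encode` and `dot` are hypothetical helper names.

```python
import math
from typing import List, Sequence


def rotary_encode(v: Sequence[float], index: int, base: float = 10000.0) -> List[float]:
    """Eq. (3): rotate each pair (v[2i], v[2i+1]) by the angle index * theta_i.

    theta_i = base ** (-2i / d) is the RoFormer default (assumed here).
    `index` is the original frame index I_m, not the position in the sequence.
    """
    d = len(v)
    if d % 2:
        raise ValueError("token dimension must be even")
    out = [0.0] * d
    for i in range(d // 2):
        theta = base ** (-2 * i / d)
        cos_a, sin_a = math.cos(index * theta), math.sin(index * theta)
        x0, x1 = v[2 * i], v[2 * i + 1]
        out[2 * i] = x0 * cos_a - x1 * sin_a      # V_{2i} cos - V_{2i+1} sin
        out[2 * i + 1] = x1 * cos_a + x0 * sin_a  # V_{2i+1} cos + V_{2i} sin
    return out


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))
```

As a check of Eq. (4), `dot(rotary_encode(q, 13), rotary_encode(k, 23))` matches `dot(rotary_encode(q, 0), rotary_encode(k, 10))` up to rounding. The attention score depends only on the index gap 23 − 13 = 10, however far apart the kept frames sit in the shortened sequence.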
2.4 Sentence Decoder
We build the Sentence Decoder and tokenizer using the bert-base-uncased configuration from the Hugging Face Transformers library, with a vocabulary size of 30522 and a hidden dimension of 768. The dimension transformation module consists of a 1D convolutional layer with a kernel size of 1. Sentences are generated in the standard autoregressive Transformer manner. The model takes the start token (BOS) as the initial input, generates one token at a time, and feeds the previously generated tokens back into the decoder until it outputs the predefined end token (EOS) or reaches the maximum output length.
2.5 Training
Model training has two stages. In the first stage, we initialize the FRVidSwin Encoder with the VidSwin parameters pre-trained on the K400 training set [3] and randomly initialize the other parts. The encoder parameters are frozen while the dimension transformation module and Sentence Decoder are trained. In the second stage, the entire network is fine-tuned. In this way, the FRVidSwin Encoder focuses more on the semantic information relevant to video description, which further improves feature extraction and helps the model select more relevant content from the video. The loss function is the cross-entropy in Eq. (5):

$$\mathcal{L}_{CE} = -\sum_{i=1}^{L_{\max}} \zeta(w_i)^{\top}\log W_i \tag{5}$$

where $w_i$ is the i-th word of the label sentence and $L_{\max}$ is the maximum sentence length. $W_i$ is the output of the sentence generator after softmax at the i-th position, and $\zeta(w_i) \in \mathbb{R}^{30522}$ is the one-hot encoding of $w_i$ over the 30,522-token vocabulary of the BERT base model.
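Because ζ(w_i) in Eq. (5) is one-hot, the dot product with log W_i simply picks out the log-probability of the reference token at each position. The sketch below shows this with plain Python lists. It assumes unpadded targets, since the text does not say how padded positions are handled, and the names are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def caption_cross_entropy(step_logits: Sequence[Sequence[float]],
                          target_ids: Sequence[int]) -> float:
    """Eq. (5): L_CE = -sum_i zeta(w_i)^T log W_i.

    step_logits[i] are the decoder logits at position i; target_ids[i] is
    the vocabulary id of the reference word w_i.
    """
    loss = 0.0
    for logits, target in zip(step_logits, target_ids):
        probs = softmax(logits)          # W_i
        loss -= math.log(probs[target])  # one-hot dot product selects one entry
    return loss
```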
3 Experimental Results
In this section, we use four widely used evaluation metrics, BLEU@4 [4], METEOR [5], ROUGE-L [6], and CIDEr [7], to measure the performance of the model on two classical datasets: MSVD [8] and MSR-VTT [9]. We compare the model with other models and perform ablation experiments to verify the effect of FRVidSwin.
3.1 Datasets
MSVD. MSVD is a classic dataset for video captioning. It consists of 1970 YouTube video clips, each with 35 human-written caption sentences. We use 1200 videos for training, 100 for validation, and 670 for testing, which is the standard split used by most video captioning models.
MSRVTT. MSR-VTT (Microsoft Research Video to Text) was released for the ACM Multimedia 2016 challenge. It contains 10,000 video clips from 20 categories, including music, people, and games. Each clip is labeled with about 20 English sentences, with a vocabulary of 28k words. Our experiments use the same split as existing methods (6513, 497, and 2990 videos for training, validation, and testing).
3.2 Implementation Details
Video Pre-processing. In the experiments, we resample all videos to 30 fps and resize them to 224 × 224.
Text Pre-processing. We first remove punctuation from the captions and convert them to lowercase. Caption sentences are then truncated or padded to 20 words, in line with existing benchmarks, and tokenized with the standard BertTokenizer from Hugging Face.
Other Details. The FRVidSwin Encoder is based on the Video Swin Transformer pretrained on the K400 dataset, with a window size of 8 × 7 × 7. It is divided into four blocks with depths [2, 2, 18, 2], and the numbers of self-attention heads in the four stages are [4, 8, 16, 32]. The number of video tokens is 16 and the hidden dimension is 768. The learning rate is set to 1e−5 for the first 10 epochs and then decreases at a rate of 0.2. At test time, we use beam search with a beam size of 3 to generate captions. The entire system is implemented in PyTorch, and all experiments are conducted on 2 RTX 3090 GPUs.
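Beam-search decoding with a beam size of 3 can be illustrated on a toy next-token model. This is a simple variant under stated assumptions. Hypotheses are ranked by summed log-probability without a length penalty, and a hypothesis that ends in EOS leaves the beam. The toy probability table and `max_len` = 20 are hypothetical, and the decoder's real scoring and stopping rules may differ.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

LogProbFn = Callable[[Sequence[str]], Dict[str, float]]


def beam_search(next_log_probs: LogProbFn, beam_size: int = 3, max_len: int = 20,
                bos: str = "[BOS]", eos: str = "[EOS]") -> List[str]:
    """Return the best token sequence (without BOS/EOS) found by beam search."""
    beams: List[Tuple[float, List[str]]] = [(0.0, [bos])]
    finished: List[Tuple[float, List[str]]] = []
    for _ in range(max_len):
        # Expand every open hypothesis by every candidate next token.
        candidates = [
            (score + lp, seq + [tok])
            for score, seq in beams
            for tok, lp in next_log_probs(seq).items()
        ]
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, seq in candidates[:beam_size]:
            if seq[-1] == eos:
                finished.append((score, seq))  # completed caption
            else:
                beams.append((score, seq))     # still open
        if not beams:
            break
    pool = finished + beams  # include hypotheses still open at max_len
    best = max(pool, key=lambda c: c[0])[1]
    return [tok for tok in best[1:] if tok != eos]


# Hypothetical toy model: next-token probabilities keyed by the last token.
TOY = {
    "[BOS]": {"a": 0.6, "the": 0.4},
    "a": {"girl": 0.9, "[EOS]": 0.1},
    "the": {"girl": 0.5, "[EOS]": 0.5},
    "girl": {"rides": 0.7, "[EOS]": 0.3},
    "rides": {"[EOS]": 1.0},
}


def toy_model(prefix: Sequence[str]) -> Dict[str, float]:
    return {tok: math.log(p) for tok, p in TOY[prefix[-1]].items()}


if __name__ == "__main__":
    print(" ".join(beam_search(toy_model, beam_size=3)))  # a girl rides
```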
3.3 Main Results
We compared FRVidSwin with previous results on both the MSVD and MSR-VTT datasets, as shown in Table 1. On MSVD, our model improves on all four metrics relative to the best previous results. It achieves 115.0 CIDEr (+11.0), 60.8 BLEU@4 (+1.6), 39.2 METEOR (+1.5), and 76.9 ROUGE-L (+1.8). On MSR-VTT, our model achieves the best METEOR (29.1) and CIDEr (52.9). Notably, CIDEr is designed to measure agreement with the consensus of human reference captions and therefore correlates well with human judgment. The large CIDEr gain suggests that our model generates sentences with flexible structure and accurate semantics.

Table 1. Comparison with state-of-the-art methods on MSVD and MSR-VTT benchmarks. The best results are marked with *.

                    MSVD                          MSR-VTT
Method              B@4    M      R      CIDEr    B@4    M      R      CIDEr
PickNet [10]        52.3   33.3   69.6   76.5     41.3   27.7   59.8   44.1
RecNet [11]         52.3   34.1   69.8   80.3     39.1   26.6   59.3   42.7
MARN [12]           48.6   35.1   71.9   92.2     40.4   28.1   60.7   47.1
OA-BTG [13]         56.9   36.2   –      90.6     41.4   28.2   –      46.9
POS-CG [14]         52.5   34.1   71.3   88.7     42.0   28.2   61.6   48.7
MGSA [15]           53.4   35.0   –      86.7     42.4   27.6   –      47.5
GRU-EVE [16]        47.9   35.0   71.5   78.1     38.3   28.4   60.7   48.1
STG-KD [17]         52.2   36.9   73.9   93.0     40.5   28.3   60.9   47.1
SAAT [18]           46.5   33.5   69.4   81.0     40.5   2.2    60.9   49.1
ORG-TRL [19]        54.3   36.4   73.9   95.2     43.6*  28.8   62.1   50.9
SGN [20]            52.8   35.5   72.9   94.3     40.8   28.3   60.8   49.5
MGRMP [21]          55.8   36.9   74.5   98.5     41.7   28.9   62.1   51.4
HMN [22]            59.2   37.7   75.1   104.0    43.5   29.0   62.7*  51.5
FRVidSwin (ours)    60.8*  39.2*  76.9*  115.0*   43.4   29.1*  62.4   52.9*

3.4 Ablation Study
Table 2. Effectiveness of the FR module on the video captioning task.

                    MSVD                          MSR-VTT
Video Encoder       B@4    M      R      CIDEr    B@4    M      R      CIDEr
VidSwin Encoder     61.0   38.9   76.6   112.1    43.2   29.0   62.3   52.6
FRVidSwin Encoder   60.8   39.2   76.9   115.0    43.4   29.2   62.5   52.9
Effectiveness of FRVidSwin Encoder. To verify the effectiveness of the FR block, Table 2 shows the results of an ablation experiment. First, we use an unmodified VidSwin Encoder as the video encoder to train a baseline model (first row of Table 2). The second row shows the result of our method on the same task. Adding the FR module improves seven of the eight metrics, most notably CIDEr on MSVD (112.1 → 115.0). Only BLEU@4 on MSVD decreases slightly (61.0 → 60.8).

Table 3. Transfer between different frame rates.

                                MSVD                          MSR-VTT
Video Encoder       Frames      B@4    M      R      CIDEr    B@4    M      R      CIDEr
VidSwin Encoder     64          61.0   38.9   76.6   112.1    43.2   29.0   62.3   52.6
                    64 → 32     60.6   38.6   76.2   110.2    42.5   29.0   62.0   52.5
                    64 → 16     58.4   38.0   75.2   105.5    41.7   28.6   61.5   51.3
FRVidSwin Encoder   64          60.8   39.2   76.9   115.0    43.4   29.2   62.5   52.9
                    64 → 32     60.5   38.9   76.5   113.1    43.4   29.1   62.4   52.9
                    64 → 16     60.6   38.6   76.2   111.2    42.7   28.7   62.3   52.3
Transfer Between Different Frame Rates. Because the proposed FRVidSwin Encoder focuses on the more important parts of the video during encoding, we expect our model to generalize better across different numbers of input frames, and the results in Table 3 support this. Table 3 compares the FRVidSwin Encoder with the original VidSwin Encoder when the model is transferred between different numbers of frames. When the input is reduced from 64 to 16 frames, the CIDEr of the FRVidSwin Encoder on MSVD/MSR-VTT drops from 115.0/52.9 to 111.2/52.3, a smaller drop than that of the original VidSwin Encoder (from 112.1/52.6 to 105.5/51.3), indicating that our model transfers better between different numbers of frames.

Table 4. Effectiveness of Time Index Position Encoding on video captioning task.

| Model | Dataset | B@4 | M | R | CIDEr |
|---|---|---|---|---|---|
| VidSwin Encoder | MSVD | 60.5 | 38.5 | 76.3 | 109.1 |
| VidSwin Encoder + Time Index PE | MSVD | 61.0 | 38.9 | 76.6 | 112.1 |
| FRVidSwin Encoder | MSVD | 59.9 | 37.7 | 75.0 | 107.8 |
| FRVidSwin Encoder + Time Index PE | MSVD | 60.8 | 39.2 | 76.9 | 115.0 |
Effectiveness of Time Index Position Encoding. Table 4 shows that removing redundant video tokens without positional encoding degrades the model's performance: without Time Index PE, the FRVidSwin Encoder reaches only 107.8 CIDEr on MSVD, below the 109.1 of the VidSwin Encoder. The reason is that removing the redundant tokens corrupts the time scale of the video data, which lowers the quality of the video representation and, in turn, of the generated captions.
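The section does not give the formula for Time Index Position Encoding, so the following is only a minimal sketch under an assumption: that it evaluates the standard sinusoidal Transformer encoding at each token's original frame index rather than at its position in the pruned sequence. The names `sinusoidal_encoding`, `time_index_pe`, and `sequential_pe` are hypothetical, and `sequential_pe` is included only as a contrast case.

```python
import math
from typing import List, Sequence


def sinusoidal_encoding(position: float, dim: int) -> List[float]:
    """Transformer-style sinusoidal encoding evaluated at an arbitrary position.

    Slot 2i holds sin(pos / 10000^(2i/dim)); slot 2i+1 holds the matching cos.
    """
    if dim <= 0 or dim % 2 != 0:
        raise ValueError("dim must be a positive even integer")
    enc: List[float] = []
    for i in range(dim // 2):
        angle = position / (10000 ** (2 * i / dim))
        enc.append(math.sin(angle))
        enc.append(math.cos(angle))
    return enc


def time_index_pe(kept_frame_indices: Sequence[int], dim: int) -> List[List[float]]:
    """Hypothetical Time Index PE: encode every surviving token at its ORIGINAL
    frame index, so the gaps left by removed frames remain visible."""
    return [sinusoidal_encoding(float(t), dim) for t in kept_frame_indices]


def sequential_pe(num_tokens: int, dim: int) -> List[List[float]]:
    """Contrast case: renumber the surviving tokens 0..n-1, which hides how far
    apart they were in the original video."""
    return [sinusoidal_encoding(float(p), dim) for p in range(num_tokens)]


if __name__ == "__main__":
    kept = [0, 1, 5, 6, 7]  # frames 2-4 were removed by a pruning step
    pe_time = time_index_pe(kept, dim=8)
    pe_seq = sequential_pe(len(kept), dim=8)
    # The third surviving token is frame 5: encoded at 5 here, but at 2 sequentially.
    print(pe_time[2] == sinusoidal_encoding(5.0, 8))  # True
    print(pe_seq[2] == sinusoidal_encoding(2.0, 8))   # True
```

Because the encoding depends on the original index, tokens that become neighbours after pruning but were far apart in the video still receive clearly different encodings, which is the property the paragraph above says is lost when positional encoding is omitted.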
3.5 Qualitative Results
Fig. 4. Qualitative results. The first row shows evenly sampled video frames for reference. Video frames marked with a red line have been removed by the FR module. The sentences generated by our model are compared with the ground truth below.
Figure 4 shows some qualitative results of FRVidSwin. Our model generates high-quality captions and, in most cases, keeps the video frames that help caption generation while discarding redundant ones. In the first example, the characters and environment change very little across the first five frames, so the model removes more of them; only two of the last eight frames are removed, because the camera movements and transitions in that part contain a lot of new information.
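This section does not spell out the FR module's selection rule beyond its being based on the attention matrix. Purely as an illustration, the sketch below assumes a simple criterion: score each frame by the mean attention it receives in a square attention matrix, keep the top `keep` frames, and return their indices in temporal order. `select_frames_by_attention` and this scoring rule are hypothetical stand-ins, not the authors' FR module.

```python
from typing import List, Sequence


def select_frames_by_attention(attn: Sequence[Sequence[float]], keep: int) -> List[int]:
    """Keep the `keep` frames that receive the most attention, in temporal order.

    attn[i][j] is the attention weight from query frame i to key frame j
    (assumed already averaged over heads).
    """
    n = len(attn)
    if any(len(row) != n for row in attn):
        raise ValueError("attention matrix must be square")
    if not 1 <= keep <= n:
        raise ValueError("keep must be between 1 and the number of frames")
    # Importance of frame j: mean attention it receives from all query frames.
    importance = [sum(attn[i][j] for i in range(n)) / n for j in range(n)]
    # Highest importance first; ties go to the earlier frame.
    ranked = sorted(range(n), key=lambda j: (-importance[j], j))
    # Restore temporal order so later layers still see frames chronologically.
    return sorted(ranked[:keep])


if __name__ == "__main__":
    attn = [
        [0.4, 0.1, 0.4, 0.1],
        [0.4, 0.1, 0.4, 0.1],
        [0.3, 0.2, 0.3, 0.2],
        [0.5, 0.0, 0.5, 0.0],
    ]
    print(select_frames_by_attention(attn, keep=2))  # [0, 2]
```

Frames 1 and 3 attract little attention in this toy matrix, so they are dropped first, mirroring the qualitative observation that near-duplicate frames are the ones removed.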
4 Conclusions

In this paper, we introduce FRVidSwin, an end-to-end Transformer-based architecture for video captioning. We propose a novel attention-matrix-based method for sparsifying video frame sequences and design the FRVidSwin Encoder to improve the model's generalization ability. To further enhance performance, we introduce Time Index Position Encoding, which generates positional encodings from the temporal distribution of the video tokens so that they represent the temporal distance between tokens and capture their temporal information. Our experimental results demonstrate the effectiveness of the proposed modules: our method achieves state-of-the-art performance on MSVD and the best METEOR and CIDEr scores on MSR-VTT.

Acknowledgments. This work was supported in part by the National Natural Science Foundation of China (No. 52001039), the National Natural Science Foundation of China under Grant (No. 52171310), the improvement project for small and medium-sized enterprises in Shandong Province (No. 2021TSGC1012), and the University Innovation Team Project of Jinan (2019GXRC015).
References
1. Donahue, J., et al.: Long-term recurrent convolutional networks for visual recognition and description. IEEE Trans. Pattern Anal. Mach. Intell. 39, 677–691 (2017). https://doi.org/10.1109/TPAMI.2016.2599174
2. Liu, Z., et al.: Video Swin Transformer. arXiv:2106.13230 [cs] (2021)
3. Kay, W., et al.: The kinetics human action video dataset. arXiv preprint arXiv:1705.06950 (2017)
4. Papineni, K., Roukos, S., Ward, T., Zhu, W.-J.: BLEU: a method for automatic evaluation of machine translation. In: Proceedings of the 40th Annual Meeting on Association for Computational Linguistics, pp. 311–318. Association for Computational Linguistics, USA (2002)
5. Banerjee, S., Lavie, A.: METEOR: an automatic metric for MT evaluation with improved correlation with human judgments. In: Proceedings of the ACL Workshop on Intrinsic and Extrinsic Evaluation Measures for Machine Translation and/or Summarization, pp. 65–72 (2005)
6. Lin, C.-Y.: Rouge: a package for automatic evaluation of summaries. In: Text Summarization Branches Out, pp. 74–81 (2004)
7. Vedantam, R., Zitnick, C.L., Parikh, D.: CIDEr: consensus-based image description evaluation. In: 2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 4566–4575 (2015)
8. Chen, D.L., Dolan, W.B.: Collecting highly parallel data for paraphrase evaluation. In: ACL, pp. 190–200 (2011)
9. Xu, J., Mei, T., Yao, T., Rui, Y.: MSR-VTT: a large video description dataset for bridging video and language. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5288–5296 (2016)
10. Chen, Y., Wang, S., Zhang, W., Huang, Q.: Less is more: picking informative frames for video captioning. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11217, pp. 367–384. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01261-8_22
11. Wang, B., Ma, L., Zhang, W., Liu, W.: Reconstruction network for video captioning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7622–7631 (2018)
12. Pei, W., Zhang, J., Wang, X., Ke, L., Shen, X., Tai, Y.-W.: Memory-attended recurrent network for video captioning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8347–8356 (2019)
13. Zhang, J., Peng, Y.: Object-aware aggregation with bidirectional temporal graph for video captioning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8327–8336 (2019)
14. Wang, B., Ma, L., Zhang, W., Jiang, W., Wang, J., Liu, W.: Controllable video captioning with POS sequence guidance based on gated fusion network. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 2641–2650 (2019)
15. Chen, S., Jiang, Y.-G.: Motion guided spatial attention for video captioning. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, pp. 8191–8198 (2019). https://doi.org/10.1609/aaai.v33i01.33018191
16. Aafaq, N., Akhtar, N., Liu, W., Gilani, S.Z., Mian, A.: Spatio-temporal dynamics and semantic attribute enriched visual encoding for video captioning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12487–12496 (2019)
17. Pan, B., et al.: Spatio-temporal graph for video captioning with knowledge distillation. In: 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 10867–10876 (2020)
18. Zheng, Q., Wang, C., Tao, D.: Syntax-aware action targeting for video captioning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13096–13105 (2020)
19. Zhang, Z., Shi, Y., Yuan, C., Li, B., Wang, P., Hu, W., Zha, Z.-J.: Object relational graph with teacher-recommended learning for video captioning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13278–13288 (2020)
20. Ryu, H., Kang, S., Kang, H., Yoo, C.D.: Semantic grouping network for video captioning. In: Proceedings of the AAAI Conference on Artificial Intelligence, pp. 2514–2522 (2021)
21. Chen, S., Jiang, Y.-G.: Motion guided region message passing for video captioning. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 1543–1552 (2021)
22. Ye, H., Li, G., Qi, Y., Wang, S., Huang, Q., Yang, M.-H.: Hierarchical modular network for video captioning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 17939–17948 (2022)
A Simple Mixed-Supervised Learning Method for Salient Object Detection

Congjin Gong1,2, Gang Yang1(B), and Haoyu Dong1

1 Northeastern University, Shenyang 110819, China
2 DUT Artificial Intelligence Institute, Dalian 116024, China
Abstract. Weakly supervised salient object detection (SOD) aims to overcome the heavy reliance of fully supervised methods on pixel-level data. However, the sparsity of weak labels often results in suboptimal detection accuracy. Inspired by human visual attention mechanisms, we propose a Mixed-Supervised Learning method to mitigate this issue. Mixed-Supervised Learning refers to training a neural network with hybrid data. Specifically, we propose a two-stage training strategy. In stage I, the model is supervised by a large number of scribble annotations so that it can roughly locate salient objects. In stage II, a small number of pixel-level labels are used to endow the model with detail decoding capability. Our training strategy decomposes the SOD task into two sub-tasks, object localization and detail refinement, and we design a corresponding network, LRNet, which includes a supplementary detail information encoder, a Feature Attention module (FA), and a Detail Fusion module (DF). The two-stage training strategy is simple and generalizable. Extensive experiments demonstrate its effectiveness: the detection accuracy of our model surpasses existing state-of-the-art weakly supervised models and approaches that of fully supervised ones. In addition, experiments on COD and RSI SOD tasks demonstrate the generality of our method. Our code will be released at https://github.com/nightmengna/LRNet.

Keywords: Salient object detection · Mixed-Supervised learning
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 63–74, 2023. https://doi.org/10.1007/978-981-99-4761-4_6

1 Introduction

Salient Object Detection (SOD) aims to detect and segment the objects or regions of an image that attract the most human attention. It has been widely applied to computer vision tasks such as object recognition, object tracking, and image segmentation [1]. With the recent development of deep convolutional neural networks (CNNs) [2] and Transformers [3], SOD methods with RGB image inputs (RGB SOD) [1, 4, 5] have made significant progress. However, these deep-learning-based methods rely heavily on large amounts of data with pixel-level labels for training, and such dense annotations are labor-intensive and costly. To reduce the high cost of pixel-level annotation, several weakly supervised methods have been proposed. Common weak labels include image-level [6], bounding box [7], scribble [8], and point [9] annotations. The information required for training an SOD model falls into two categories: location information and detail information. Weak annotations contain only location information, not detail information, which limits detection performance. Previous weakly supervised methods design loss functions [10] to constrain the region to be predicted, or introduce auxiliary tasks [11] to mine more of the information the SOD task requires. Although these methods do not need pixel-level labels, they require rigorous mathematical derivation, additional training cost, and complex training strategies.

In this paper, we propose a simple mixed-supervised method to address both the heavy burden of pixel-level annotation and the insufficient information in weak labels. Mixed-supervised learning refers to training a neural network with data carrying multiple types of supervision, usually a large number of weak annotations and a small number of pixel-level labels [12]. Cong et al. [13] were the first to use hybrid data in the SOD task: they iteratively use pixel-level labels to reduce noise in the weak annotations and train on the updated high-quality pseudo-labels. Unlike Cong et al. [13], our model learns progressively from the hybrid data. Because scribble annotations convey positional and structural information without background noise, we choose them over other weak labels for mixed supervision. Specifically, we use a large number of scribble annotations and a small number of pixel-level labels, in a ratio of 20:1, as the training set and introduce a two-stage training strategy. In stage I, the model learns from scribble annotations to roughly localize salient objects. In stage II, building on this rough localization, the model is trained on a small number of pixel-level labels to improve its detail detection capability.
This two-stage training strategy is consistent with the human visual attention mechanism, which focuses first on the area of interest and then on the salient objects. It decouples the SOD task into two sub-tasks: object localization and detail refinement. Based on these two sub-tasks and the training strategy, we design a localization and refinement network, LRNet. For the object localization sub-task, we design a U-shaped [14] network based on PVT [3] that predicts a rough saliency map by gradually integrating features. For the detail refinement sub-task, we propose a feature attention module (FA) to deliver the rough saliency information; it uses the side output of LRNet as spatial attention to guide the learning of low-level features. To enhance feature representation, we additionally introduce low-level features from ResNet [2] as a supplementary detail information encoder. Finally, considering the variability of detail features, we integrate features in a detail fusion module (DF) via a stepwise fusion strategy to avoid losing details. These modules integrate the information from the different types of data directly into the model. Furthermore, our training strategy may be applicable to other downstream segmentation tasks, so we also evaluate it on camouflaged object detection (COD) and salient object detection in remote sensing images (RSI SOD). Overall, our key contributions are summarized as follows:
1) We propose a simple mixed-supervised learning method for the SOD task that enables the model to learn from scribble annotations and pixel-level labels in stages, addressing both the difficulty of data annotation and the insufficient information in weak labels.
2) Based on our method, we propose a localization and refinement network, LRNet, consisting of a supplementary detail information encoder, an FA module for guiding detail decoding, and a DF module for feature fusion.
3) Our model outperforms existing weakly supervised methods on five SOD datasets and approaches the fully supervised level. Furthermore, it generalizes well to the COD and RSI SOD tasks.
2 Proposed Method

2.1 Training Strategy

Our training strategy is inspired by human visual attention mechanisms. When people enter a new scene, they usually notice a salient area first and then identify the salient objects. We therefore argue that a localization-then-segmentation strategy can reduce the difficulty of the SOD task, and accordingly design a two-stage training strategy.

Stage I (Salient Object Localization): In stage I, the model is expected to roughly detect salient regions, much as people are initially drawn to something. We use 10k scribble annotations from DUTS [8] for training in this stage. Scribble annotations indicate the rough location of the object, and because they are easy to annotate, a large number of samples can be obtained; the more samples we have, the wider the sample distribution the model sees. In addition, scribble annotations preserve the structural information of the object, which helps in segmenting irregular targets. By training with a large number of scribble annotations, the model learns a sufficiently broad sample distribution and predicts the approximate location of salient objects.

Stage II (Detail Refinement): Once people are attracted, they want to see clearly what attracted them, so they focus their eyes on the salient objects. Similarly, the model uses the stage I prediction map as a guide to focus on and refine the salient objects. In this paper, 500 images with pixel-level labels sampled from DUTS are used for training in stage II. Why can detail refinement be achieved with only a few pixel-level labels in stage II? Because only local information, not global information, needs to be considered, predicting details at the local scale is relatively simple. Moreover, detail features tend to be similar across images.
Different objects may share features such as boundaries, colors, and textures, so it is not necessary to annotate every sample with details. The prediction map obtained by refining the rough saliency map is the final result. The two-stage strategy matches the weak and strong annotations to the two sub-tasks of SOD and integrates the information from the hybrid data into the model.

2.2 Network Structure

Based on the above training strategy, we design a corresponding localization and refinement network, LRNet. The overall structure of the network is shown in Fig. 1. It consists of a U-shaped network based on PVT, a supplementary detail information encoder from ResNet, a feature attention module (FA), and a detail fusion module (DF). In stage I, scribble annotations supervise the two outputs of the network: the final output Pred. and the side output Pred.-S. In stage II, the side output Pred.-S is treated as a rough saliency map and sent to the FA module to guide the decoding of details. The detail maps are fed to the DF module together with the supplementary detail information and Pred.-S to produce the final output Pred., which is supervised by pixel-level labels.

Fig. 1. The overall framework of the proposed LRNet.

In the salient object localization stage, the model must combine global contextual information to localize the object. Because Transformers have stronger global modeling capability than CNNs, we use PVT as the backbone. Following the U-shaped network, we design a level-by-level encoding and decoding structure that gradually recovers the feature maps. For the detail refinement stage, we design a feature attention module (FA) and a detail fusion module (DF), and introduce a supplementary detail information encoder.

Feature Attention Module (FA): After stage I training, the model can roughly locate the salient objects. We then use Pred.-S as attention to guide the decoding of detail features. Since high-level features carry semantic information at low resolution, we perform detail extraction only on the features of the first two layers. The module structure is shown in Fig. 1 and is defined as:

$$F = \mathrm{Sigmoid}\left(\mathrm{Conv}\left(\mathrm{Cat}\left(\mathrm{Conv}(F_m \otimes P_s),\ F_m\right)\right)\right) \tag{1}$$
where $F_m$ denotes the feature map of the backbone, $P_s$ represents the side output Pred.-S, and $F$ is the output of FA.

Supplementary Detail Information Encoder: We additionally introduce the first two layers of ResNet as another encoder, for three main reasons: 1) In stage I, because of scribble supervision, the model focuses only on location and structure information, and its original detail decoding ability degrades. 2) CNNs and Transformers encode features differently, so introducing multiple forms of detail encoding enriches the detail features. 3) ResNet comes with pre-trained parameters, which can extract sufficient detail information. The supplementary detail information further improves the model's refinement capability.

Detail Fusion Module (DF): At the final step, we have a rough saliency map Pred.-S and four detail feature maps (F1, F2, Fr1, Fr2). Fusing these features directly may lose details because of feature misalignment, so we use a step-by-step fusion strategy. As shown in Fig. 2, we first fuse the features from the FA module by pixel-wise addition and convolution, and then fuse the features from ResNet by concatenation and convolution. Finally, the fused feature maps are refined through a zoom-in and zoom-out strategy [15] to obtain the final prediction map: the feature map is up- and down-sampled to different scales, and convolutions are applied at multiple scales to refine the details.
Fig. 2. The structure of DF module.
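Eq. (1) can be illustrated with a minimal NumPy sketch. It assumes that both Conv operations are single 1×1 convolutions without normalization and that Pred.-S has already been resized to the feature map's resolution; the text does not state kernel sizes or layer counts, and `conv1x1` and `feature_attention` are hypothetical names.

```python
import numpy as np


def conv1x1(x: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """1x1 convolution on a (C_in, H, W) map, i.e. a per-pixel linear map over channels.

    weight has shape (C_out, C_in) and bias has shape (C_out,).
    """
    return np.einsum("oc,chw->ohw", weight, x) + bias[:, None, None]


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def feature_attention(f_m, p_s, w1, b1, w2, b2):
    """Eq. (1): F = Sigmoid(Conv(Cat(Conv(F_m * P_s), F_m))).

    f_m: backbone feature map, shape (C, H, W)
    p_s: side output Pred.-S resized to (1, H, W), values in [0, 1]
    """
    guided = conv1x1(f_m * p_s, w1, b1)            # element-wise product broadcasts P_s over channels
    fused = np.concatenate([guided, f_m], axis=0)  # Cat along channels -> (2C, H, W)
    return sigmoid(conv1x1(fused, w2, b2))         # back to C channels, squashed into (0, 1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    c, h, w = 4, 8, 8
    f_m = rng.standard_normal((c, h, w))
    p_s = rng.random((1, h, w))
    out = feature_attention(
        f_m, p_s,
        w1=rng.standard_normal((c, c)), b1=np.zeros(c),
        w2=rng.standard_normal((c, 2 * c)), b2=np.zeros(c),
    )
    print(out.shape)  # (4, 8, 8)
```

Keeping the unguided $F_m$ in the concatenation means that pixels where Pred.-S is near zero still contribute their own features, so the rough map steers decoding rather than hard-masking it.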
2.3 Loss Functions

Scribble Loss: Under the supervision of the scribble annotations, we use partial cross-entropy to calculate the loss to train the model. The loss is defined as:

$$L_{pce} = \sum_{i \in Scr} \left[ -GT_i \log S_i - (1 - GT_i) \log (1 - S_i) \right] \tag{2}$$

where $GT$ denotes the ground truth, $S$ the predicted saliency map, and $Scr$ the pixel set of the scribble annotations.

Full-Supervised Loss: In the detail refinement stage, we train the model using pixel-level labels and calculate the loss using the IGL loss [16] and FLoss [17], which can be expressed as:

$$L_F = L_{IGL} + L_{FLoss} \tag{3}$$

where $L_{IGL}(\cdot)$ contains a BCE loss, a boundary loss, and an IoU loss. It can be formulated as:

$$L_{IGL} = L_{BCE} + \alpha L_B + \beta L_{IOU} \tag{4}$$

where α and β are balance weights, which we empirically set to 1 and 0.7, respectively.
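The losses in Eqs. (2) and (4) can be sketched in NumPy as follows. This sketch assumes a per-image sum over scribble pixels exactly as Eq. (2) is written (implementations often normalize by the number of labelled pixels instead) and treats the boundary and IoU terms of Eq. (4) as precomputed scalars, since their definitions come from [16]. The function names are hypothetical.

```python
import numpy as np


def partial_cross_entropy(pred, target, scribble_mask, eps=1e-7):
    """Eq. (2): binary cross-entropy summed only over scribble pixels (i in Scr).

    pred: predicted saliency S in (0, 1); target: labels on scribble pixels
    (1 = foreground scribble, 0 = background scribble); scribble_mask: True
    where a scribble exists. Unlabelled pixels contribute nothing.
    """
    s = np.clip(np.asarray(pred, dtype=float), eps, 1.0 - eps)  # avoid log(0)
    gt = np.asarray(target, dtype=float)
    mask = np.asarray(scribble_mask, dtype=bool)
    bce = -(gt * np.log(s) + (1.0 - gt) * np.log(1.0 - s))
    return float(bce[mask].sum())


def igl_loss(l_bce, l_boundary, l_iou, alpha=1.0, beta=0.7):
    """Eq. (4): L_IGL = L_BCE + alpha * L_B + beta * L_IOU with the reported weights."""
    return l_bce + alpha * l_boundary + beta * l_iou


if __name__ == "__main__":
    pred = np.array([[0.9, 0.2], [0.5, 0.5]])
    gt = np.array([[1, 0], [0, 0]])
    mask = np.array([[True, True], [False, False]])  # only the top row is scribbled
    print(partial_cross_entropy(pred, gt, mask))  # equals -log(0.9) - log(0.8)
```

Masking with `scribble_mask` is what makes the loss "partial": the bottom row of the example can hold any prediction without changing the stage I loss, which is why scribble supervision alone cannot teach fine boundaries.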
3 Experiments

3.1 Datasets

Training Datasets. We use pixel-level labels and scribble annotations in different stages of training. Concretely, we randomly sample 500 images from the DUTS-TR [18] dataset as fully supervised data with pixel-level annotations. The rest of DUTS-TR is used with scribble annotations, which are taken from [8].

Testing Datasets. To validate the effectiveness of the proposed model, we perform experiments on five public SOD datasets: DUTS-TE [18] (5019 images), DUT-OMRON [19] (5168 images), ECSSD [20] (1000 images), PASCAL-S [21] (850 images), and HKU-IS [22] (4447 images).

Evaluation Metrics. We adopt the weighted F-measure (Fβω) [23], Mean Absolute Error (MAE) [24], and S-measure (Sm) [25]. Higher weighted F-measure and S-measure values indicate better performance, whereas a lower MAE is better.

3.2 Comparison with State-of-the-Art Methods

We compare our method with other state-of-the-art models, including fully supervised methods (VST [4], GateNet-R [26], DCNet [5], SelfReformer [1], MINet [27], U2-Net [28]) and weakly supervised methods (WSSA [8], SCWSSOD [10], PSOD [9], MFNet [6], NRDNet [29]). For fairness, the compared prediction maps are provided by the authors, and the evaluation metrics are taken from the original papers or calculated under the same benchmark.

Quantitative Evaluation. The comparative results of the different methods are shown in Table 1. Compared with the weakly supervised models, our method achieves the best results on every metric except the S-measure on PASCAL-S. On DUTS-TE, our method improves Fβω, MAE, and Sm by 0.048, 0.011, and 0.024, respectively, over the second-best weakly supervised method on each metric. Notably, our method also approaches the level of fully supervised models: in MAE and Fβω, it outperforms most fully supervised models, although a gap remains compared with the state-of-the-art fully supervised methods.
Quantitative comparative results demonstrate the necessity and superiority of hybrid data. Although comparing a mixed-supervised method with weakly supervised ones is not entirely fair, because it uses some strongly labeled data, training with hybrid data yields a large performance gain from only a small number of additional pixel-level labels. Compared with fully supervised methods, we achieve 90%–96% of their performance at only 10% of the data labeling cost.

Qualitative Evaluation. Some of the prediction maps are shown in Fig. 3. Our method has advantages in several challenging scenarios: large objects (1st–2nd rows), small objects (3rd–4th rows), multiple objects (5th row), complex edges (6th–7th rows), and special structures (8th–9th rows). Compared with weakly supervised methods, our method uses only 500 additional pixel-level labels to obtain more complete and refined prediction maps with high-quality details.
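Of the three evaluation metrics, MAE is simple enough to sketch directly. The NumPy function below computes the per-image MAE between a saliency map in [0, 1] and a binary ground-truth mask; averaging it over a test set is a common convention that is assumed here rather than stated in the text. The weighted F-measure [23] and S-measure [25] involve more elaborate weighting and structural terms and are not reproduced; `mean_absolute_error` is a hypothetical name.

```python
import numpy as np


def mean_absolute_error(saliency, gt):
    """MAE between a saliency map with values in [0, 1] and a binary ground-truth mask.

    Lower is better; the score is averaged over all H x W pixels of one image.
    """
    s = np.asarray(saliency, dtype=float)
    g = np.asarray(gt, dtype=float)
    if s.shape != g.shape:
        raise ValueError("saliency map and ground truth must have the same shape")
    return float(np.mean(np.abs(s - g)))


if __name__ == "__main__":
    pred = [[1.0, 0.0], [0.5, 0.0]]
    mask = [[1, 0], [1, 0]]
    print(mean_absolute_error(pred, mask))  # 0.125
```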
Table 1. Quantitative results of different methods on five SOD benchmark datasets, ↑ and ↓ respectively indicate that the larger and smaller the score, the better. 'F' means fully supervision, 'W' means weakly supervision, and 'M' means mixed supervision. The best performance is marked in BOLD, and the second best performance is marked in UNDERLINE.

| Dataset | Metric | MINet | U2-Net | GateNet-R | DCNet | VST | SelfReformer | WSSA | SCWSSOD | MFNet | NRDNet | PSOD | Ours |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| | Year | 2020 | 2020 | 2020 | 2021 | 2021 | 2022 | 2020 | 2021 | 2021 | 2022 | 2022 | |
| | Supervision | F | F | F | F | F | F | W | W | W | W | W | M |
| DUTS-TE | Fβω ↑ | 0.825 | 0.804 | 0.829 | 0.840 | 0.828 | 0.872 | 0.710 | 0.792 | 0.635 | 0.691 | 0.778 | 0.840 |
| DUTS-TE | MAE ↓ | 0.037 | 0.044 | 0.035 | 0.035 | 0.037 | 0.026 | 0.062 | 0.049 | 0.076 | 0.073 | 0.045 | 0.034 |
| DUTS-TE | Sm ↑ | 0.884 | 0.861 | 0.895 | 0.892 | 0.896 | 0.911 | 0.804 | 0.841 | 0.775 | 0.781 | 0.853 | 0.877 |
| DUT-O | Fβω ↑ | 0.738 | 0.757 | 0.749 | 0.760 | 0.755 | 0.784 | 0.669 | 0.731 | 0.536 | 0.615 | 0.729 | 0.765 |
| DUT-O | MAE ↓ | 0.055 | 0.054 | 0.051 | 0.051 | 0.058 | 0.041 | 0.068 | 0.060 | 0.087 | 0.088 | 0.064 | 0.048 |
| DUT-O | Sm ↑ | 0.833 | 0.847 | 0.848 | 0.845 | 0.850 | 0.856 | 0.785 | 0.812 | 0.742 | 0.745 | 0.824 | 0.837 |
| ECSSD | Fβω ↑ | 0.911 | 0.910 | 0.906 | 0.920 | 0.910 | 0.926 | 0.835 | 0.875 | 0.765 | 0.808 | 0.902 | 0.913 |
| ECSSD | MAE ↓ | 0.033 | 0.033 | 0.035 | 0.032 | 0.034 | 0.027 | 0.061 | 0.050 | 0.084 | 0.077 | 0.036 | 0.034 |
| ECSSD | Sm ↑ | 0.925 | 0.928 | 0.929 | 0.928 | 0.932 | 0.935 | 0.866 | 0.882 | 0.834 | 0.834 | 0.914 | 0.916 |
| PASCAL-S | Fβω ↑ | 0.821 | 0.797 | 0.815 | 0.825 | - | 0.848 | 0.733 | 0.784 | 0.670 | 0.701 | 0.808 | 0.815 |
| PASCAL-S | MAE ↓ | 0.064 | 0.074 | 0.065 | 0.062 | 0.067 | 0.050 | 0.140 | 0.078 | 0.115 | 0.110 | 0.065 | 0.064 |
| PASCAL-S | Sm ↑ | 0.857 | 0.844 | 0.865 | 0.862 | 0.873 | 0.874 | 0.797 | 0.820 | 0.770 | 0.768 | 0.853 | 0.850 |
| HKU-IS | Fβω ↑ | 0.899 | 0.890 | 0.893 | 0.905 | 0.902 | 0.915 | 0.831 | 0.872 | 0.770 | 0.829 | 0.885 | 0.906 |
| HKU-IS | MAE ↓ | 0.028 | 0.031 | 0.029 | 0.027 | 0.030 | 0.024 | 0.047 | 0.037 | 0.059 | 0.051 | 0.032 | 0.028 |
| HKU-IS | Sm ↑ | 0.920 | 0.916 | 0.925 | 0.922 | 0.928 | 0.930 | 0.865 | 0.882 | 0.846 | 0.854 | 0.902 | 0.911 |
Fig. 3. Visual comparison with state-of-the-art models.
Table 2. Test metrics in different stages.

| Dataset | Stage I Fβω ↑ | Stage I MAE ↓ | Stage I Sm ↑ | Stage II Fβω ↑ | Stage II MAE ↓ | Stage II Sm ↑ |
|---|---|---|---|---|---|---|
| DUTS-TE | 0.637 | 0.076 | 0.746 | 0.840 | 0.034 | 0.877 |
| DUT-O | 0.591 | 0.093 | 0.723 | 0.765 | 0.048 | 0.837 |
| ECSSD | 0.781 | 0.072 | 0.824 | 0.913 | 0.034 | 0.916 |
| PASCAL-S | 0.703 | 0.093 | 0.776 | 0.815 | 0.064 | 0.850 |
| HKU-IS | 0.737 | 0.071 | 0.798 | 0.906 | 0.028 | 0.911 |
3.3 Ablation Studies

To demonstrate the effect of our training strategy and of each component, we conduct a series of ablation studies.

Effect of Our Training Strategy. We test the outputs of the two stages separately; in both cases, the predicted saliency maps are the final output Pred. of LRNet. Table 2 shows the results. With scribble annotations alone (stage I), the test metrics are unsatisfactory; with the addition of pixel-level supervision (stage II), our method improves significantly, for example raising Fβω on DUTS-TE from 0.637 to 0.840. Figure 4 shows the outputs of the network at different stages. Under the supervision of scribble annotations, the model locates the salient objects and outlines them roughly; in stage II, it predicts finer boundaries while maintaining the object's structure, indicating that sample distribution information and generic detail information are fused into the model at different stages.
Fig. 4. The output of the different stages of our model.
Feature Map Occlusion. To further demonstrate that the model learns detail decoding ability rather than salient information in stage II, we design a feature map occlusion experiment. We train LRNet with the two-stage training strategy; during testing, we occlude the right half of Pred.-S and observe the final prediction Pred. after the guidance is lost. The results are shown in Fig. 5. The prediction maps show a large gap between the left and right sides: without guidance, the predictions on the right side become confused. However, these confused predictions actually contain information about the boundaries and details of objects or the background, such as the boundaries of hats. This indicates that the model sufficiently learns the detail information in stage II and completes the prediction under the guidance of Pred.-S.

Fig. 5. Visual comparison of feature map occlusion experiment. "Pred." and "Pred.-O" refer to the test output without and with occlusion of Pred.-S, respectively.

Effect of Modules. We design ablation experiments for the modules in LRNet; the results are shown in Table 3. They quantify the contributions of the feature attention module (FA) and the detail fusion module (DF), and show that both modules enhance the decoding ability of the model.

Table 3. Effectiveness analysis of the modules in our proposed method.

| No. | FA | DF | DUTS-TE Fβω ↑ | DUTS-TE MAE ↓ | DUTS-TE Sm ↑ | ECSSD Fβω ↑ | ECSSD MAE ↓ | ECSSD Sm ↑ |
|---|---|---|---|---|---|---|---|---|
| 1 | | | 0.832 | 0.038 | 0.873 | 0.903 | 0.040 | 0.909 |
| 2 | ✓ | | 0.836 | 0.039 | 0.875 | 0.907 | 0.038 | 0.915 |
| 3 | | ✓ | 0.835 | 0.037 | 0.874 | 0.908 | 0.037 | 0.911 |
| 4 | ✓ | ✓ | 0.840 | 0.034 | 0.877 | 0.913 | 0.034 | 0.916 |
Applications in Downstream Tasks. Our method also works for camouflaged object detection (COD) and salient object detection in remote sensing images (RSI SOD). In the COD experiment, we use 6000 scribble annotations (5000 from COD10K [30] and 1000 from CAMO [31]) and 650 pixel-level labels for training, where the scribble annotations are generated following the method in [32]. In the RSI SOD experiment, the training data are from [33]; we use 1400 scribble samples and 300 pixel-level samples for training. Table 4 shows the results on COD and RSI SOD. After stage II of training, detection accuracy improves substantially, and our method outperforms the compared methods in MAE on all four datasets and in S-measure on all datasets except EORSSD. Figure 6 shows a partial visual comparison, in which our model produces better boundary and structure predictions.
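Table 4. Test metrics in camouflaged object detection and salient object detection in remote sensing image.

COD:

| Method | Sup | Stage | CAMO Fβω ↑ | CAMO MAE ↓ | CAMO Sm ↑ | COD10K Fβω ↑ | COD10K MAE ↓ | COD10K Sm ↑ |
|---|---|---|---|---|---|---|---|---|
| SINet [30] | F | – | 0.606 | 0.100 | 0.751 | 0.551 | 0.051 | 0.771 |
| CRNet [34] | W | – | 0.641 | 0.092 | 0.735 | 0.576 | 0.049 | 0.733 |
| Ours | M | I | 0.616 | 0.097 | 0.736 | 0.215 | 0.115 | 0.767 |
| Ours | M | II | 0.732 | 0.072 | 0.792 | 0.346 | 0.022 | 0.890 |

RSI SOD:

| Method | Sup | Stage | ORSSD Fβω ↑ | ORSSD MAE ↓ | ORSSD Sm ↑ | EORSSD Fβω ↑ | EORSSD MAE ↓ | EORSSD Sm ↑ |
|---|---|---|---|---|---|---|---|---|
| SBANet [33] | W | – | 0.762 | 0.026 | 0.862 | 0.726 | 0.014 | 0.861 |
| Ours | M | I | 0.582 | 0.048 | 0.747 | 0.338 | 0.057 | 0.614 |
| Ours | M | II | 0.853 | 0.018 | 0.869 | 0.804 | 0.010 | 0.842 |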
Fig. 6. Visual comparison in the fields of COD and RSI SOD
4 Conclusion

In this paper, we proposed a simple mixed-supervised learning method for the SOD task, which addresses both the difficulty of pixel-level annotation and the sparsity of scribble annotations. With the two-stage training strategy, the model sequentially learns localization and detail refinement capabilities from the scribble and pixel-level data, which is consistent with the human visual attention mechanism. Mixed-supervised learning is a compromise that balances data labeling efficiency and algorithm performance; in practice, the proportion of hybrid data should be set according to the specific situation. Based on our strategy, we designed a localization and refinement network, LRNet. Experimental evaluation demonstrated the effectiveness and generalization of our method.
A Simple Mixed-Supervised Learning Method
Acknowledgment. This work is supported by the National Natural Science Foundation of China [grant number 62076058].
References
1. Yun, Y.K., Lin, W.: SelfReformer: self-refined network with transformer for salient object detection. arXiv preprint arXiv:2205.11283 (2022)
2. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778 (2016)
3. Wang, W., et al.: PVT v2: improved baselines with pyramid vision transformer. Computational Visual Media 8(3), 415–424 (2022)
4. Liu, N., Zhang, N., Wan, K., Shao, L., Han, J.: Visual saliency transformer. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4722–4732 (2021)
5. Wu, Z., Su, L., Huang, Q.: Decomposition and completion network for salient object detection. IEEE Trans. Image Process. 30, 6226–6239 (2021)
6. Piao, Y., Wang, J., Zhang, M., Lu, H.: MFNet: multi-filter directive network for weakly supervised salient object detection, pp. 4136–4145 (2021)
7. Liu, Y., Wang, P., Cao, Y., Liang, Z., Lau, R.W.H.: Weakly-supervised salient object detection with saliency bounding boxes. IEEE Trans. Image Process. 30, 4423–4435 (2021)
8. Zhang, J., Yu, X., Li, A., Song, P., Liu, B., Dai, Y.: Weakly-supervised salient object detection via scribble annotations, pp. 12546–12555 (2020)
9. Gao, S., et al.: Weakly-supervised salient object detection using point supervision. National Conference on Artificial Intelligence (2022)
10. Yu, S., Zhang, B., Xiao, J., Lim, E.G.: Structure-consistent weakly supervised salient object detection with local saliency coherence. 35(4), 3234–3242 (2021)
11. Wang, X., Al-Huda, Z., Peng, B.: Weakly-supervised salient object detection through object segmentation guided by scribble annotations. In: 2021 16th International Conference on Intelligent Systems and Knowledge Engineering (ISKE), pp. 304–312. IEEE (2021)
12. Gao, F., et al.: Segmentation only uses sparse annotations: unified weakly and semi-supervised learning in medical images. Med. Image Anal. 80, 102515 (2022)
13. Cong, R., et al.: A weakly supervised learning framework for salient object detection via hybrid labels. IEEE (2022)
14. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 234–241. Springer (2015)
15. Pang, Y., Zhao, X., Xiang, T.-Z., Zhang, L., Lu, H.: Zoom in and out: a mixed-scale triplet network for camouflaged object detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2160–2170 (2022)
16. Zhou, B., Yang, G., Wan, X., Wang, Y., Liu, C., Wang, H.: A simple network with progressive structure for salient object detection, pp. 397–408 (2021)
17. Zhao, K., Gao, S., Wang, W., Cheng, M.-M.: Optimizing the F-measure for threshold-free salient object detection, pp. 8849–8857 (2019)
18. Wang, L., et al.: Learning to detect salient objects with image-level supervision, pp. 136–145 (2017)
19. Yang, C., Zhang, L., Lu, H., Ruan, X., Yang, M.-H.: Saliency detection via graph-based manifold ranking, pp. 3166–3173 (2013)
20. Yan, Q., Xu, L., Shi, J., Jia, J.: Hierarchical saliency detection, pp. 1155–1162 (2013)
21. Li, Y., Hou, X., Koch, C., Rehg, J.M., Yuille, A.L.: The secrets of salient object segmentation, pp. 280–287 (2014)
22. Li, G., Yu, Y.: Visual saliency based on multiscale deep features, pp. 5455–5463 (2015)
23. Margolin, R., Zelnik-Manor, L., Tal, A.: How to evaluate foreground maps? In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 248–255 (2014)
24. Perazzi, F., Krähenbühl, P., Pritch, Y., Hornung, A.: Saliency filters: contrast based filtering for salient region detection. In: 2012 IEEE Conference on Computer Vision and Pattern Recognition, pp. 733–740. IEEE (2012)
25. Fan, D.-P., Cheng, M.-M., Liu, Y., Li, T., Borji, A.: Structure-measure: a new way to evaluate foreground maps. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 4548–4557 (2017)
26. Zhao, X., Pang, Y., Zhang, L., Lu, H., Zhang, L.: Suppress and balance: a simple gated network for salient object detection. In: European Conference on Computer Vision, pp. 35–51. Springer (2020)
27. Pang, Y., Zhao, X., Zhang, L., Lu, H.: Multi-scale interactive network for salient object detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9413–9422 (2020)
28. Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O.R., Jagersand, M.: U2-Net: going deeper with nested U-structure for salient object detection. Pattern Recognit. 106, 107404 (2020)
29. Piao, Y., Wu, W., Zhang, M., Jiang, Y., Lu, H.: Noise-sensitive adversarial learning for weakly supervised salient object detection. IEEE Trans. Multimedia (2022)
30. Fan, D.-P., Ji, G.-P., Sun, G., Cheng, M.-M., Shen, J., Shao, L.: Camouflaged object detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2777–2787 (2020)
31. Le, T.-N., Nguyen, T.V., Nie, Z., Tran, M., Sugimoto, A.: Anabranch network for camouflaged object segmentation. Computer Vision and Image Understanding 184, 45–56 (2019)
32. Wu, W., Qi, H., Rong, Z., Liu, L., Su, H.: Scribble-supervised segmentation of aerial building footprints using adversarial learning. IEEE Access 6, 58898–58911 (2018)
33. Huang, Z., Xiang, T.-Z., Chen, H.-X., Dai, H.: Scribble-based boundary-aware network for weakly supervised salient object detection in remote sensing images. ISPRS J. Photogramm. Remote. Sens. 191, 290–301 (2022)
34. He, R., Dong, Q., Lin, J., Lau, R.W.H.: Weakly-supervised camouflaged object detection with scribble annotations. arXiv preprint arXiv:2207.14083 (2022)
A Lightweight Detail-Fusion Progressive Network for Image Deraining
Siyi Ding, Qing Zhu, and Wanting Zhu (corresponding author)
Beijing University of Technology, Beijing 100124, China
Abstract. Deep learning-based methods have significantly improved the performance of image deraining algorithms. However, their network structures have become increasingly complicated and diverse, which makes it difficult to balance deraining performance and processing speed. To address this issue, we propose the Lightweight Detail-fusion Progressive Network (LDPNet) for image deraining, which obtains more detailed rain-free images with fewer parameters and faster running speed. First, we decompose the challenging deraining task into multi-stage subtasks that gradually recover the degraded image. An effective combination of dense connections and the Gated Recurrent Unit allows our network not only to reuse features within each stage but also to transfer information between stages, achieving good performance with few parameters. Second, we design a multi-scale detail extraction block that recovers and reconstructs image details by extracting features with different receptive fields; its lightweight design is based on depthwise separable convolution. Furthermore, we integrate a lightweight coordinate attention mechanism that precisely localizes detailed information in the rain-removal region, which effectively improves deraining. Comprehensive experiments on benchmark datasets demonstrate the superiority of our algorithm.
Keywords: Image Deraining · Lightweight Design · Attention Mechanism · Multi-scale · Deep Learning
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023. D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 75–87, 2023. https://doi.org/10.1007/978-981-99-4761-4_7
1 Introduction
With the rapid development of technology, more and more computer vision tasks are deployed in outdoor scenes, where their reliability and accuracy are vulnerable to bad weather such as rain. Image deraining can effectively improve the performance of subsequent tasks such as image classification and object detection, so removing rain and restoring details in rainy images is an important research topic. Designing lightweight deraining algorithms that run well on resource-constrained end devices, such as smartphones and autonomous vehicles, is a further challenge.
Image deraining algorithms can be divided into two main types: traditional optimization-based methods and deep learning-based methods. Traditional methods often use various priors to model rain streaks, including sparse coding [1], Gaussian mixture models [2], and low-rank representation [3]. However, these methods perform poorly under heavy rain, and their running time is too long. With the rapid development of deep learning, more and more researchers have used it to achieve excellent deraining results, largely compensating for the shortcomings of traditional methods. Many of these methods use convolutional neural networks (CNNs) as the backbone, and several classic convolutional architectures [4–6] have also been applied to deraining. Other approaches [7, 8] use generative adversarial networks (GANs) for rain removal. Although deep learning has improved performance, problems such as color distortion and loss of detail remain. Moreover, these algorithms tend to rely on complex models to improve accuracy, which leads to long training and inference times. How to obtain high-quality rain-free images efficiently therefore remains the focus and main difficulty of current research. Scenes with strict real-time requirements, such as autonomous driving, in particular call for a lightweight deraining design. Because existing deraining algorithms lack a mechanism for detail recovery and run slowly, we propose the Lightweight Detail-fusion Progressive Network (LDPNet). We use a progressively optimized, densely connected structure with fast inference as the base framework of the deraining model.
Building on this framework, we introduce the proposed Multi-scale Detail Extraction Block (MDEB) and an Improved Coordinate Attention mechanism (I-CA) into the network to remove rain more thoroughly without destroying background details. Both components are lightweight, so they improve the model without adding much runtime. Finally, we compare the proposed algorithm with six existing deraining algorithms on the Rain100L [9], Rain100H [9], Rain1200 [10], and Rain800 [11] benchmark datasets and on real rainy images. The results show that our method is superior to the compared methods in both visual quality and objective evaluation metrics, and ablation experiments confirm the effectiveness of each block. In summary, this paper makes the following contributions:
• We propose a multi-stage deraining network with dense connections as its basic structure, and we introduce a Convolutional Gated Recurrent Unit to progressively erase the deeper rain streaks in the image. The network has few parameters and runs fast.
• We design a novel Multi-scale Detail Extraction Block (MDEB) to recover and reconstruct background information. It adds detail to the derained image at little additional runtime.
• We integrate a lightweight Improved Coordinate Attention mechanism (I-CA) into the network, which efficiently locates rain streaks and thus removes them more accurately.
2 Related Work
Early deraining methods were mostly based on traditional optimization, aiming to restore a clean background image from a rainy image; however, they could not achieve satisfactory deraining results. Driven by the success of deep learning in computer vision, deep learning-based deraining methods have progressed rapidly and largely make up for the shortcomings of traditional methods. In recent years, many CNN-based deraining methods have achieved advanced results; structurally, they can be broadly classified as single-stage or multi-stage.
Most deraining methods use a single-stage design. Inspired by ResNet, Fu et al. [12] designed a model that recovers clear images from images degraded by rain streaks. A network that ignores cross-scale relationships may lose information during training; to address this, Wang et al. [13] explored cross-scale and intra-scale fusion between networks for image deraining. Building on earlier work, Jiang et al. [14] designed three modules, a Coarse Fusion Module (CFM), a Fine Fusion Module (FFM), and a Reconstruction Module (RM), and combined a pyramid structure with a channel attention mechanism for rain removal.
Many other methods [15–17] remove rain with multi-stage image restoration. Li et al. [15] proposed a deep architecture that removes rain stage by stage, using a context aggregation network in every stage and a recurrent neural network to retain useful information from the previous stage. Zamir et al. [16] proposed a multi-stage design that balances spatial details and high-level contextual information, together with a per-pixel adaptive design that uses in-situ supervised attention to reweight local features. In addition, some methods [7, 8] use Generative Adversarial Networks (GANs) for rain removal; Cao et al. [8] integrated dilated convolution and gated features into a GAN and achieved good results by aggregating multi-level context information.
In summary, deep learning-based methods have become the mainstream approach to image deraining, but they cannot yet balance deraining quality with real-time performance. Deraining algorithms therefore still need lightweight designs that meet the needs of practical scenarios with strict real-time requirements.
3 Method
In this paper, we propose a Lightweight Detail-fusion Progressive Network (LDPNet) for scenarios with limited computational power or strict real-time requirements. In this section, we first introduce the overall framework, then describe each network module in detail, and finally discuss the loss function.
3.1 Network Structure
Because rain streaks are random, their directions and shapes are not fixed, and a simple deep network cannot remove them reliably. We therefore design a multi-stage network that divides deraining into several stages to better remove various complex rain streaks. Figure 1 outlines the overall network structure. We concatenate the original image X_input with the result X_t of stage t and use the concatenation as the input of stage t + 1; this structure is the basis of our model. Stacking several different subnetworks to build a multi-stage network would increase the number of parameters significantly, so our method instead runs the same subnetwork for T cycles. This achieves the deep deraining effect of a progressive structure without adding parameters.
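The stage-wise recurrence can be illustrated with a toy 1-D example. This is a minimal sketch under stated assumptions, not the authors' implementation: `toy_subnet` is a hypothetical stand-in for one LDPNet stage (a single 3-tap convolution over the two-channel input [X_input, X_{t-1}]), the first stage is assumed to receive X_0 = X_input, and the ConvGRU hidden state that RSDRM carries between stages is omitted. What it shows is that the same weights are reused in every cycle, so the parameter count does not depend on T.

```python
from typing import List, Sequence

Signal = List[float]


def toy_subnet(x_input: Sequence[float], x_prev: Sequence[float],
               weights: Sequence[Sequence[float]], bias: float) -> Signal:
    """Hypothetical stand-in for one stage: a zero-padded 3-tap convolution
    over the 2-channel input formed by concatenating x_input and x_prev."""
    n = len(x_input)
    out = []
    for i in range(n):
        acc = bias
        for channel, signal in enumerate((x_input, x_prev)):  # channel concat
            for k in range(3):
                j = i + k - 1  # taps at offsets -1, 0, +1
                if 0 <= j < n:
                    acc += weights[channel][k] * signal[j]
        out.append(acc)
    return out


def progressive_derain(x_input: Sequence[float], weights, bias: float,
                       num_stages: int) -> List[Signal]:
    """Run the SAME sub-network num_stages times; stage t sees [x_input, x_{t-1}]."""
    x_t = list(x_input)  # assumption: x_0 = x_input
    outputs = []
    for _ in range(num_stages):
        x_t = toy_subnet(x_input, x_t, weights, bias)  # shared weights
        outputs.append(x_t)
    return outputs


def num_params(weights) -> int:
    # 2 input channels x 3 taps, plus 1 for the scalar bias; independent of T
    return sum(len(row) for row in weights) + 1


w = [[0.1, 0.2, 0.1], [0.0, 0.5, 0.0]]  # hypothetical weights
stages = progressive_derain([0.2, 0.7, 0.4], w, 0.05, num_stages=6)
print(len(stages), num_params(w))  # 6 7
```

With T = 6 stages, as set in Sect. 4.2, the loop produces six intermediate results while the parameter count stays fixed; stacking six distinct sub-networks would instead multiply the weights by six.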
Fig. 1. Overview of the Lightweight Detail Fusion Progressive Network (LDPNet) framework
Fig. 2. Detailed description of the LDPNet network structure
Figure 2 details one of the stages of our network, which contains a Rain Streaks Deep Removal Module (RSDRM), a Multi-scale Information Recovery and Reconstruction Module (MIRRM), and a Detail Location Sensing Module (DLSM). Among them, RSDRM is used to gradually remove rain streaks at different depth levels, MIRRM is used to recover and reconstruct background information at different scales, and DLSM is used to pinpoint rain streaks and avoid losing background information due to excessive rain removal.
3.2 Rain Streaks Deep Removal Module
To fully learn rain-streak information and improve network performance, we design a DenseBlock that reuses features by connecting all of its layers while effectively deepening the network. Its structure is shown in Fig. 3: a DenseLayer consists of a 3 × 3 convolutional layer followed by a ReLU activation, and five identical DenseLayers are densely connected to each other. This connectivity strengthens feature propagation and uses feature information more effectively with fewer parameters. The DenseBlock is the main structure of RSDRM and performs the initial rain removal. On this basis, to avoid losing background information or leaving visible residual rain traces, RSDRM adopts the idea of recurrent neural networks to gradually remove deeper rain streaks. We introduce a Convolutional Gated Recurrent Unit (ConvGRU) [18] that derains stage by stage and learns image information through two gates, a reset gate and an update gate. Combining the DenseBlock with ConvGRU largely solves the problem of insufficient or excessive rain removal and helps gradually remove dense rain streaks in heavy-rain scenarios.
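The gating in a ConvGRU can be sketched with its simplest case. This is an illustrative assumption, not the paper's configuration: a single-channel ConvGRU with 1 × 1 kernels, which reduces to an independent GRU update at every pixel, written with the gate convention of Cho et al. [18] (h_new = z ⊙ h + (1 − z) ⊙ h̃). The scalar weights in `p` are hypothetical; in LDPNet the gates would be k × k convolutions over multi-channel feature maps.

```python
import math
from typing import Dict, List

Map = List[List[float]]


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def conv_gru_step_1x1(x: Map, h: Map, p: Dict[str, float]) -> Map:
    """One ConvGRU update with 1x1 kernels on single-channel maps.
    x: current input features, h: hidden state from the previous stage.
    p: hypothetical scalar weights (a real ConvGRU uses k x k conv kernels)."""
    new_h = []
    for x_row, h_row in zip(x, h):
        row = []
        for xv, hv in zip(x_row, h_row):
            z = sigmoid(p["w_z"] * xv + p["u_z"] * hv + p["b_z"])  # update gate
            r = sigmoid(p["w_r"] * xv + p["u_r"] * hv + p["b_r"])  # reset gate
            # candidate state: the reset gate scales how much old state is used
            h_tilde = math.tanh(p["w_h"] * xv + p["u_h"] * (r * hv) + p["b_h"])
            # Cho et al. convention: z near 1 keeps h, z near 0 takes h_tilde
            row.append(z * hv + (1.0 - z) * h_tilde)
        new_h.append(row)
    return new_h
```

Because every gate is computed from both the new input and the previous hidden state, the unit can decide per pixel whether to keep what earlier stages learned or overwrite it, which is the behavior the text relies on for progressive removal of deeper streaks.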
Fig. 3. Structure diagram of DenseBlock
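The feature reuse in a DenseBlock can be made concrete by tracking channel widths. The sketch assumes the DenseNet-style rule that each DenseLayer receives the concatenation of the block input and all earlier layer outputs; the input width (32) and growth rate (16) are hypothetical values, since the paper does not report them, and any fusion layer after the block is ignored.

```python
from typing import List


def dense_layer_inputs(in_channels: int, growth: int, num_layers: int = 5) -> List[int]:
    """Input width of each DenseLayer: block input plus all previous outputs."""
    return [in_channels + i * growth for i in range(num_layers)]


def dense_block_params(in_channels: int, growth: int, num_layers: int = 5,
                       kernel: int = 3) -> int:
    """Weights + biases of the 3x3 convs in the block (ReLU has no parameters)."""
    return sum(c * growth * kernel * kernel + growth
               for c in dense_layer_inputs(in_channels, growth, num_layers))


def dense_block_out_channels(in_channels: int, growth: int, num_layers: int = 5) -> int:
    # the block output concatenates the input with every layer's output
    return in_channels + num_layers * growth


# hypothetical widths: 32 input channels, growth rate 16, five DenseLayers
print(dense_layer_inputs(32, 16))        # [32, 48, 64, 80, 96]
print(dense_block_params(32, 16))        # 46160
print(dense_block_out_channels(32, 16))  # 112
```

Each layer adds only `growth` new channels yet sees every earlier feature map, which is how dense connectivity can keep individual layers narrow while still reusing features.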
3.3 Multi-scale Information Recovery and Reconstruction Module
MIRRM consists of a 3 × 3 convolutional layer and a Multi-scale Detail Extraction Block (MDEB); it fuses the acquired multi-scale information into the backbone network to compensate for details lost during deraining. Specifically, because rain streaks vary in size, direction, and density, we design MDEB to extract information under different receptive fields, as shown in Fig. 4. This block captures the local features of the image well, so it can recover and reconstruct lost details to a large extent. Depthwise separable convolution extracts features by combining a channel-wise (depthwise) convolution with a point-wise (1 × 1) convolution; compared with conventional convolution, it has fewer parameters and a lower computational cost. To keep the block lightweight, we therefore use depthwise separable convolution (D-Conv) in place of some of the standard convolutions. In addition, both intermediate branches first apply a 1 × 1 convolution to reduce the channel dimension, which further reduces the number of operations.
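The savings behind MDEB's lightweight choices (depthwise separable convolution here, plus the two stacked 3 × 3 kernels and SimpleGate described after Fig. 4) can be checked with a little arithmetic. This is a minimal sketch: bias terms are ignored, the 32-channel width is a hypothetical example, and `simple_gate` follows the description of SimpleGate in [19] (split the channels into two halves and multiply them) on plain Python lists rather than tensors.

```python
from typing import List, Sequence


def conv_params(c_in: int, c_out: int, k: int) -> int:
    """Standard k x k convolution (bias ignored)."""
    return k * k * c_in * c_out


def dconv_params(c_in: int, c_out: int, k: int) -> int:
    """Depthwise separable: one k x k filter per channel + a 1x1 pointwise conv."""
    return k * k * c_in + c_in * c_out


def receptive_field(kernel_sizes: Sequence[int]) -> int:
    """Receptive field of stacked stride-1, undilated convolutions."""
    rf = 1
    for k in kernel_sizes:
        rf += k - 1
    return rf


def simple_gate(features: List[List[float]]) -> List[List[float]]:
    """Split C channels into two halves and multiply them element-wise (C/2 out)."""
    c = len(features)
    if c % 2:
        raise ValueError("SimpleGate needs an even number of channels")
    half = c // 2
    return [[a * b for a, b in zip(features[i], features[i + half])]
            for i in range(half)]


C = 32  # hypothetical channel width
print(conv_params(C, C, 3), dconv_params(C, C, 3))     # 9216 1312
print(conv_params(C, C, 5), 2 * conv_params(C, C, 3))  # 25600 18432
print(receptive_field([5]), receptive_field([3, 3]))   # 5 5
print(simple_gate([[1.0, 2.0], [3.0, 4.0], [0.5, -1.0], [2.0, 0.0]]))  # [[0.5, -2.0], [6.0, 0.0]]
```

With these assumed sizes, the depthwise separable layer needs about 14% of the weights of a standard 3 × 3 convolution, and two stacked 3 × 3 kernels need 72% of the weights of one 5 × 5 kernel while covering the same 5 × 5 receptive field. SimpleGate itself has no weights.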
Fig. 4. Structure diagram of Multi-scale Detail Extraction Block (MDEB)
Large convolution kernels require more computation. We therefore replace one 5 × 5 convolution kernel with two stacked 3 × 3 kernels, which reduces the number of parameters while keeping the receptive field unchanged. To further lighten MDEB, we use SimpleGate [19] in place of an activation function: it splits the features into two halves along the channel dimension and multiplies them element-wise. SimpleGate adds nonlinear expressiveness to the model while reducing the number of parameters.
3.4 Detail Location Sensing Module
Attention mechanisms are widely used in computer vision tasks. Coordinate attention (CA), proposed by Hou et al. [20], has been shown to benefit computer vision tasks by using average pooling to aggregate information along the two spatial directions of the feature map. Inspired by CA, we add an improved coordinate attention (I-CA) to DLSM to further enhance the network's ability to perceive rain streaks. The structure of I-CA is shown in Fig. 5. It first decomposes channel attention into two feature-encoding processes that aggregate features along the two spatial directions. The resulting feature maps, which contain location information, are then combined and passed through a 1 × 1 convolution, batch normalization, and a nonlinear activation. Here we use the MetaAconC [21] activation function, which dynamically learns the degree of nonlinearity of the activation and significantly improves the performance of I-CA. After that, the feature map is split into two parts, each of which passes through a 1 × 1 convolution and a nonlinear activation, forming a pair of direction-aware and position-sensitive attention maps. These maps complement each other in enhancing the representation of the objects of interest. Finally, they are multiplied with the original feature map to form output features that contain location information.
DLSM can significantly enhance the network’s ability to perceive rain streaks and improve the network’s performance.
Fig. 5. Structure diagram of Improved Coordinate Attention (I-CA)
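The data flow of coordinate attention can be sketched on nested Python lists. This is a simplified illustration under stated assumptions, not the I-CA implementation: batch normalization is omitted, the shared and per-direction 1 × 1 convolutions are plain weight matrices (`w_shared`, `w_h`, and `w_w` are hypothetical names), and the ACON-C activation is used with a fixed β, whereas MetaAconC [21] predicts β from the input.

```python
import math
from typing import List

Matrix = List[List[float]]


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def acon_c(v: float, p1: float = 1.0, p2: float = 0.0, beta: float = 1.0) -> float:
    """ACON-C with a FIXED beta (assumption; Meta-ACON predicts beta from the input)."""
    d = (p1 - p2) * v
    return d * sigmoid(beta * d) + p2 * v


def matvec(w: Matrix, v: List[float]) -> List[float]:
    """A 1x1 convolution at one spatial position is a matrix-vector product."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def coordinate_attention(x: List[Matrix], w_shared: Matrix,
                         w_h: Matrix, w_w: Matrix) -> List[Matrix]:
    """x: C x H x W feature map; w_shared: M x C; w_h, w_w: C x M (hypothetical)."""
    C, H, W = len(x), len(x[0]), len(x[0][0])
    # 1) Directional average pooling: over the width gives one C-vector per row,
    #    over the height gives one C-vector per column.
    rows = [[sum(x[c][i]) / W for c in range(C)] for i in range(H)]
    cols = [[sum(x[c][i][j] for i in range(H)) / H for c in range(C)] for j in range(W)]
    # 2) Concatenate the H + W descriptors; shared 1x1 conv + activation (BN omitted).
    mid = [[acon_c(v) for v in matvec(w_shared, d)] for d in rows + cols]
    # 3) Split back into row and column parts; a 1x1 conv + sigmoid for each.
    a_h = [[sigmoid(v) for v in matvec(w_h, m)] for m in mid[:H]]  # H x C
    a_w = [[sigmoid(v) for v in matvec(w_w, m)] for m in mid[H:]]  # W x C
    # 4) Reweight each element by the attention of its row and of its column.
    return [[[x[c][i][j] * a_h[i][c] * a_w[j][c] for j in range(W)]
             for i in range(H)] for c in range(C)]
```

With all-zero weights every attention value is σ(0) = 0.5, so each output element is 0.25 times its input; learned weights instead make the factors depend on the row and column in which a rain streak appears, which is the position sensitivity the text describes.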
3.5 Loss Function
In rainy weather, rain streaks severely degrade videos and images. Since this paper addresses image deraining, we need to measure the similarity between the derained image and the original rain-free image. SSIM [22] is an important image quality metric that measures the similarity of two images of the same size. It accounts for luminance (l), contrast (c), and structure (s), which is consistent with human visual perception and makes the learned results more detailed. We therefore use negative SSIM as the loss function of the model. For a T-stage model, let X_T denote the derained image produced by the final stage and X_gt the rain-free image with a clean background:

$$\mathcal{L} = -\mathrm{SSIM}(X_T, X_{gt}) \tag{1}$$
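Equation (1) can be illustrated with a simplified, single-window SSIM. This sketch makes an explicit simplification: standard SSIM [22] computes the statistics in local (typically 11 × 11 Gaussian) windows and averages the resulting map, whereas here the means, variances, and covariance are taken over the whole flattened image. The constants C1 = (0.01·L)² and C2 = (0.03·L)² follow [22], with the dynamic range L assumed to be 1 for images scaled to [0, 1].

```python
from typing import Sequence


def global_ssim(x: Sequence[float], y: Sequence[float],
                data_range: float = 1.0, k1: float = 0.01, k2: float = 0.03) -> float:
    """SSIM over ONE window covering the whole flattened image (simplification)."""
    n = len(x)
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) * (a - mu_x) for a in x) / n
    var_y = sum((b - mu_y) * (b - mu_y) for b in y) / n
    cov = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    luminance = (2 * mu_x * mu_y + c1) / (mu_x * mu_x + mu_y * mu_y + c1)
    # contrast and structure combined (the usual C3 = C2 / 2 choice)
    contrast_structure = (2 * cov + c2) / (var_x + var_y + c2)
    return luminance * contrast_structure


def negative_ssim_loss(derained: Sequence[float], clean: Sequence[float]) -> float:
    """Eq. (1): L = -SSIM(X_T, X_gt); minimizing it maximizes SSIM."""
    return -global_ssim(derained, clean)
```

Since SSIM reaches its maximum of 1 only when the two images match, the loss is bounded below by −1, and gradient descent on it pushes the derained output toward the clean target in luminance, contrast, and structure simultaneously.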
4 Experimental Results and Analysis
4.1 Datasets
In real scenarios, it is difficult to capture a rainy image and a rain-free image that agree in every other respect, such as shooting angle and lighting. Our network is therefore evaluated on four synthetic datasets: Rain100L, Rain100H, Rain1200, and Rain800. Rain100L and Rain100H each contain 1800 image pairs for training and 100 image pairs for testing. Rain1200 has 12000 training pairs and 1200 test pairs. Rain800 consists of 700 training images and 100 test images. The four datasets differ in rain-streak intensity and direction, which lets us validate the effectiveness of our method from several aspects.
4.2 Training Details and Evaluation Metrics
The experiments are implemented in the PyTorch framework. We train on an NVIDIA RTX 2080 GPU and update the parameters with the Adam optimizer. We set the batch size to 12, the initial learning rate to 0.001, and the number of epochs to 100. The learning rate is multiplied by 0.2 when training reaches epochs 30, 50, and 80. The number of network cycles T is set to 6. We evaluate deraining performance with two commonly used metrics, Peak Signal-to-Noise Ratio (PSNR) and Structural Similarity (SSIM); for both, higher values indicate higher image quality, i.e., a derained image closer to the real clean image.
4.3 Evaluation on Synthetic Datasets
We compare LDPNet with six methods on the four synthetic datasets in terms of PSNR and SSIM. Table 1 shows the quantitative comparison in detail: LDPNet achieves the best or second-best PSNR and SSIM on all test datasets. Taking Rain100L as an example, our method achieves a PSNR of 36.94 dB, gaining 10.78 dB, 8.19 dB, 8.04 dB, 0.68 dB, 4.54 dB, and 1.17 dB over DIDMDN [23], RESCAN [15], UMRL [24], PReNet [25], MSPFN [14], and MPRNet [16], respectively, an average relative PSNR improvement of 19.45% over these methods. Table 1 also compares the number of parameters and the running time. LDPNet has only 0.10 million parameters and a running time of only 0.10 s. Compared with MPRNet, the most recent of the compared state-of-the-art algorithms, LDPNet uses only 2.75% of its parameters and runs almost twice as fast.
Figures 6 and 7 show the results of different algorithms on various datasets. DIDMDN removes rain poorly: too many rain traces remain, and the colors of the derained image are distorted. RESCAN and UMRL remove rain effectively, but a few rain traces remain. PReNet, MSPFN, and MPRNet remove rain streaks relatively cleanly, but image background details are still partially blurred.
Table 1. Comparison of average PSNR and SSIM on four widely used synthetic rain datasets (Rain100L, Rain100H, Rain1200, and Rain800), together with the number of parameters of each method and its runtime on 480 × 320 images
Methods | Rain100L PSNR/SSIM | Rain100H PSNR/SSIM | Rain1200 PSNR/SSIM | Rain800 PSNR/SSIM | Params (M) | Time (s)
DIDMDN | 26.16/0.868 | 16.94/0.635 | 28.74/0.923 | 22.49/0.835 | 0.14 | 0.49
RESCAN | 28.75/0.908 | 25.07/0.832 | 29.40/0.909 | 24.93/0.878 | 0.50 | 0.54
UMRL | 28.90/0.937 | 24.97/0.860 | 29.43/0.930 | 24.71/0.897 | 0.98 | 0.13
PReNet | 36.26/0.983 | 27.89/0.898 | 30.34/0.933 | 24.73/0.884 | 0.17 | 0.15
MSPFN | 32.40/0.933 | 28.66/0.860 | 31.60/0.913 | 27.01/0.851 | 15.82 | 0.50
MPRNet | 35.77/0.974 | 29.38/0.917 | 31.49/0.939 | 28.89/0.927 | 3.64 | 0.19
Ours | 36.94/0.985 | 28.97/0.921 | 32.08/0.951 | 29.26/0.933 | 0.10 | 0.10
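The training schedule of Sect. 4.2 and the Rain100L figures quoted above can be reproduced with a few lines of arithmetic. The sketch assumes 0-based epoch counting (the rate drops once the epoch index reaches 30, 50, or 80, in the spirit of a multi-step scheduler), a peak value of 1.0 in `psnr`, and reads "19.45% on average" as the mean of the per-method relative PSNR gains; the PSNR values are copied from Table 1.

```python
import math
from typing import Dict, Sequence


def step_lr(epoch: int, base_lr: float = 1e-3,
            milestones: Sequence[int] = (30, 50, 80), gamma: float = 0.2) -> float:
    """Learning rate after multiplying by gamma at every milestone reached."""
    return base_lr * gamma ** sum(epoch >= m for m in milestones)


def psnr(pred: Sequence[float], target: Sequence[float], peak: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB for images scaled to [0, peak]."""
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    return math.inf if mse == 0 else 10.0 * math.log10(peak ** 2 / mse)


# Rain100L PSNR (dB) of the compared methods, from Table 1
RAIN100L: Dict[str, float] = {"DIDMDN": 26.16, "RESCAN": 28.75, "UMRL": 28.90,
                              "PReNet": 36.26, "MSPFN": 32.40, "MPRNet": 35.77}
OURS = 36.94

gains_db = {name: round(OURS - v, 2) for name, v in RAIN100L.items()}
mean_rel_gain = 100 * sum((OURS - v) / v for v in RAIN100L.values()) / len(RAIN100L)

print([step_lr(e) for e in (0, 30, 50, 80)])
print(gains_db)
print(f"{mean_rel_gain:.2f}%")  # 19.45%
```

The dB gains reproduce the figures in the text, and the mean relative gain comes out at 19.45%, which supports reading the claim as an average of per-method ratios rather than a ratio of averages.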
Fig. 6. Comparison of rain removal effects of various methods on synthetic datasets
Fig. 7. Comparison of rain removal details of various methods on a sample image
Compared with the above algorithms, our method removes rain streaks more thoroughly and produces more natural and realistic images with more faithful color restoration. Our derained images show no obvious blurring or rain traces and retain more detail. This shows that the LDPNet proposed in this paper removes rain more reliably and handles details better than the compared algorithms.
4.4 Evaluation on Real Rainy Images
To assess the generality of our method, we also test it on real rainy images. Since real rainy images have no paired rain-free ground truth, we evaluate them with the no-reference Natural Image Quality Evaluator (NIQE), where a smaller value indicates higher image quality. As Table 2 shows, LDPNet achieves the best visual quality, with the smallest NIQE of 3.52. Figure 8 shows the results of LDPNet and the other algorithms on a real rainy image. These experiments show that our method removes rain effectively not only from synthetic images but also from real rainy images, which makes it highly practical.
Table 2. Comparison of average NIQE on real images
Methods | DIDMDN | RESCAN | UMRL | PReNet | MSPFN | MPRNet | Ours
NIQE | 4.13 | 3.98 | 4.11 | 3.85 | 4.06 | 3.59 | 3.52
Fig. 8. Comparison of rain removal effects of various methods on the real image
4.5 Ablation Experiments
To understand how each block contributes to performance, we conduct five ablation experiments on the Rain100L dataset. As shown in Table 3, removing rain streaks with DenseBlock alone yields relatively low performance. Applying ConvGRU in RSDRM removes rain streaks more completely. Introducing I-CA or MDEB, i.e., adding DLSM or MIRRM to the network, improves performance in each case, and introducing both at the same time gives the best performance. The complete LDPNet proposed in this paper therefore achieves the best results.
To determine the optimal number of stages, we compare networks with different numbers of stages T. As Table 4 shows, deraining performance improves markedly as T increases from 1 to 6, whereas at T = 7 the PSNR drops slightly (36.91 dB versus 36.94 dB) and the SSIM is unchanged. The best results are therefore achieved at T = 6 without redundant stages, and we set the network to 6 stages.
Table 3. Ablation experiments with different module combinations
Configuration | DenseBlock | ConvGRU | I-CA | MDEB | PSNR | SSIM
1 | √ | | | | 35.04 | 0.978
2 | √ | √ | | | 36.21 | 0.981
3 | √ | √ | √ | | 36.83 | 0.984
4 | √ | √ | | √ | 36.37 | 0.982
5 | √ | √ | √ | √ | 36.94 | 0.985
Table 4. Comparison of LDPNet models with different numbers of stages T
T | 1 | 2 | 3 | 4 | 5 | 6 | 7
PSNR | 30.41 | 33.06 | 34.96 | 36.02 | 36.65 | 36.94 | 36.91
SSIM | 0.940 | 0.967 | 0.978 | 0.982 | 0.984 | 0.985 | 0.985
5 Conclusion
In this paper, we propose a Lightweight Detail-fusion Progressive Network (LDPNet) for image deraining. Building on a multi-stage structure, we fuse each stage's deraining result with the original rainy image as the input of the next stage. The effective combination of densely connected structures and a Convolutional Gated Recurrent Unit enables deep rain removal. We also propose a new Multi-scale Detail Extraction Block (MDEB) to recover lost background details, and we incorporate an Improved Coordinate Attention (I-CA) to achieve accurate rain removal. Experimental results on public datasets show that LDPNet outperforms the compared algorithms with fewer parameters and faster running speed: on Rain100L it achieves a PSNR of 36.94 dB and an SSIM of 0.985, with an inference time of only 0.10 s. Owing to its lightweight design, we plan to apply the model to autonomous driving in the future, for example on portable robots and autonomous-driving chips. Its simplicity also makes it easy to integrate with other algorithms, such as road-obstacle detection on rainy days.
Acknowledgments. This work is supported by Beijing Natural Science Foundation (4232017).
References
1. Wang, Y., Liu, S., Chen, C., Xie, D., Zeng, B.: Rain removal by image quasi-sparsity priors. arXiv:1812.08348 (2018)
2. Li, R., Cheong, L.-F., Tan, R.T.: Single image deraining using scale-aware multi-stage recurrent network. arXiv:1712.06830 (2017)
3. Guo, X., Xie, X., Liu, G., Wei, M., Wang, J.: Robust low-rank subspace segmentation with finite mixture noise. Pattern Recognit. 93, 55–67 (2019)
4. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778. Las Vegas, NV, USA (2016)
5. Shi, X., Chen, Z., Wang, H., Yeung, D.-Y., Wong, W.-K., Woo, W.-C.: Convolutional LSTM network: a machine learning approach for precipitation nowcasting. In: Advances in Neural Information Processing Systems, pp. 802–810 (2015)
6. Szegedy, C., et al.: Going deeper with convolutions. In: 2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 1–9. Boston, MA, USA (2015)
7. Li, J., Feng, H., Deng, Z., Cui, X., Deng, H., Li, H.: Image derain method for generative adversarial network based on wavelet high frequency feature fusion. In: Pattern Recognition and Computer Vision. PRCV 2022. Lecture Notes in Computer Science, vol. 13537. Springer, Cham (2022)
8. Cao, M., Gao, Z., Ramesh, B., Mei, T., Cui, J.: Single image deraining integrating physics model and density-oriented conditional GAN refinement. IEEE Signal Process. Lett. 28, 1635–1639 (2021)
9. Yang, W., et al.: Deep joint rain detection and removal from a single image. In: 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 1685–1694. Honolulu, HI, USA (2017)
10. Zhang, H., Patel, V.M.: Density-aware single image de-raining using a multi-stream dense network. In: Proceedings of the 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 695–704 (2018)
11. Zhang, H., Sindagi, V., Patel, V.M.: Image de-raining using a conditional generative adversarial network. IEEE Trans. Circuits Syst. Video Technol. 30(11), 3943–3956 (2020)
12. Fu, X., Huang, J., Zeng, D., Huang, Y., Ding, X., Paisley, J.: Removing rain from single images via a deep detail network. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1715–1723 (2017)
13. Wang, C., Xing, X., Wu, Y., Su, Z., Chen, J.: DCSFN: deep cross-scale fusion network for single image rain removal. In: Proceedings of the 28th ACM International Conference on Multimedia, pp. 1643–1651. Association for Computing Machinery, New York, NY, USA (2020)
14. Jiang, K., et al.: Multi-scale progressive fusion network for single image deraining. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 8343–8352 (2020)
15. Li, X., Wu, J., Lin, Z., Liu, H., Zha, H.: Recurrent squeeze-and-excitation context aggregation net for single image deraining. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11211, pp. 262–277. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01234-2_16
16. Zamir, S.W., et al.: Multi-stage progressive image restoration. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 14816–14826 (2021)
17. Zheng, Y., Yu, X., Liu, M., Zhang, S.: Single-image deraining via recurrent residual multiscale networks. IEEE Trans. Neural Netw. Learn. Syst. 33(3), 1310–1323 (2022)
18. Cho, K., et al.: Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint (2014)
19. Chen, L., Chu, X., Zhang, X., Sun, J.: Simple baselines for image restoration. In: Avidan, S., Brostow, G., Cissé, M., Farinella, G.M., Hassner, T. (eds.) Computer Vision – ECCV 2022. LNCS, vol. 13667. Springer, Cham (2022)
20. Hou, Q., Zhou, D., Feng, J.: Coordinate attention for efficient mobile network design. In: 2021 IEEE Conference on Computer Vision and Pattern Recognition, pp. 13708–13717 (2021)
21. Ma, N., Zhang, X., Liu, M., Sun, J.: Activate or not: learning customized activation. arXiv:2009.04759 (2020)
22. Wang, Z., Bovik, A.C., Sheikh, H.R., Simoncelli, E.P.: Image quality assessment: from error visibility to structural similarity. IEEE Trans. Image Process. 13(4), 600–612 (2004)
23. Zhang, H., Patel, V.M.: Density-aware single image de-raining using a multi-stream dense network. In: 2018 IEEE Conference on Computer Vision and Pattern Recognition, pp. 695–704. Salt Lake City, UT, USA (2018)
A Lightweight Detail-Fusion Progressive Network
87
24. Yasarla, R., Patel, V.M.: Uncertainty guided multi-scale residual learning-using a cycle spinning CNN for single image de-raining. In: 2019 IEEE Conference on Computer Vision and Pattern Recognition, pp. 8397–8406. Long Beach, CA, USA (2019) 25. Ren, D., Zuo, W., Hu, Q., Zhu, P., Meng, D.: Progressive image deraining networks: a better and simpler baseline. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3932–3941 (2019)
SwinCGH-Net: Enhancing Robustness of Object Detection in Autonomous Driving with Weather Noise via Attention
Shi Cao, Qing Zhu, and Wanting Zhu (✉)
College of Software Engineering, Beijing University of Technology, Beijing, China
Abstract. Object detection in autonomous driving requires high accuracy and speed in all kinds of weather. Many CNN-based networks achieve high accuracy on academic datasets, but their performance degrades drastically when images contain various kinds of noise, which can be fatal for autonomous driving. In this paper, we propose SwinCGH-Net, a detection network based on the shifted windows Transformer (Swin Transformer) with a new detector head built on a lightweight convolution attention module, which makes full use of the attention mechanism in both the feature extraction and the detection stages. Specifically, we use Swin Transformer as the backbone to extract features, so as to obtain effective information from a small number of pixels while integrating global information. We then further improve the robustness of the network with a detector head that contains the lightweight attention block S-CBAM. Furthermore, we use Generalized Focal Loss to calculate the loss, which effectively enhances the representation ability of the model. Experiments on the Cityscapes and Cityscapes-C datasets demonstrate the effectiveness of our method under different weather conditions. As the level of weather noise increases, our method shows stronger robustness than previous methods, especially in small object detection.
Keywords: Object detection · Autonomous driving · Robustness · Attention mechanism · Weather noise
1 Introduction
Object detection is an important part of autonomous driving. Although many object detection models perform well on academic datasets (COCO, PASCAL VOC), neural networks are easily disturbed by noise such as blur and changes in brightness or color [1]. Such conditions are common in industrial applications like autonomous driving. As shown in Fig. 1, rain, fog, snow, strong light, and other weather conditions can cause a sharp decline in the performance of most models [2]. To solve this problem, many advanced noise-removal algorithms have been proposed, such as RCDNet [3] and HDCW-Net [4]. However, most of them are designed to remove a single kind of weather noise, whereas the weather is difficult to predict in real-world autonomous driving. Therefore, improving detection accuracy by removing one specific kind of weather noise cannot solve the underlying problem in autonomous driving. In addition, complex road conditions make matters worse: distant pedestrians and vehicles occupy few pixels in the image, so the model must also be sensitive to small objects [5].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 88–98, 2023. https://doi.org/10.1007/978-981-99-4761-4_8
SwinCGH-Net: Enhancing Robustness of Object Detection
Fig. 1. Autonomous vehicles need to maintain high detection precision in different weather conditions, such as snow, fog, strong light, and rain.
CNNs are extremely susceptible to noise during both training and testing. Their performance relies heavily on the quality of the annotated images, and they can even memorize randomly assigned labels [6]. As a result, when a model encounters data that do not resemble its training data in an actual detection task, its precision drops and misjudgments often happen. Even when the target object is still clearly visible to the human eye, a CNN may fail to recognize it [5], and the consequences are even more severe for autonomous driving in bad weather. The human visual system is more robust than neural networks [7], partly because humans naturally find salient regions in complex scenes. CNNs, in contrast, learn local features effectively but are limited by the local receptive field of convolution and cannot effectively learn global semantic information [8]. Inspired by this observation, attention mechanisms were introduced into computer vision to mimic this aspect of the human visual system. An attention mechanism can be viewed as a dynamic weight adjustment process based on the features of the input image [9]. In recent years, attention mechanisms such as DCNs [10] and ECANet [11] have attracted increasing interest and achieved remarkable results in image classification, detection, and segmentation. The proposal of self-attention [12] drove rapid progress in NLP, and ViT [13], Swin Transformer [14], and DETR [15] subsequently brought it to computer vision. Attention-based networks have shown great potential and even show a trend toward replacing convolutional networks as the general architecture in computer vision [9].
Through experiments, we found that reasonable use of the attention mechanism can effectively enhance the robustness of a neural network and improve its ability to generalize across weather conditions in object detection for autonomous driving. We propose SwinCGH-Net, which uses Swin Transformer as the backbone network for feature extraction; it replaces the long token sequence of ViT [13] with shifted windows and reproduces the growing receptive field of a CNN in a hierarchical manner. In addition, a lightweight convolution attention module, S-CBAM, which uses a smaller convolution kernel, is embedded in the detector head to find the regions that deserve attention in scenes with dense objects, and GFL is used to calculate the loss.
S. Cao et al.
The contributions can be summarized as follows:
• Multi-head self-attention and shifted windows are used to obtain global information and context during feature extraction, and Generalized Focal Loss is used to enhance the representation ability of the model.
• A lightweight convolution attention module, S-CBAM, with a smaller convolution kernel is designed to integrate the scattered information of small objects; the module is embedded in the GFL detector head to filter useful features.
• Experimental results and an ablation study on the Cityscapes dataset and the Cityscapes-C weather noise dataset show that SwinCGH-Net is strongly robust and improves the precision of small object detection. It is also faster than a two-stage detector, retaining the speed advantage of one-stage detectors.
Fig. 2. The architecture of the proposed SwinCGH-Net, which contains three parts. The first part is a Swin Transformer-based feature extractor. The second part is a feature fusion pyramid that handles features of various sizes. The third part is the detector head (CG Head), which contains the lightweight convolution attention module S-CBAM and uses generalized focal loss to calculate the loss.
2 Method
Following the overview in Fig. 2, the proposed framework consists of three parts. First, a self-attention-based feature extractor, a four-stage Swin Transformer, produces feature maps stage by stage. Second, the feature maps generated in the second, third, and fourth stages are passed into a Feature Pyramid Network (FPN) to fuse features extracted at different stages. Third, the lightweight convolution attention module S-CBAM in the detector head (CG Head) focuses attention on crucial features before the loss is calculated, and generalized focal loss is used to unify the classification score and the NMS score (the localization-quality score used to rank boxes in Non-Maximum Suppression).
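To make the data flow concrete, the following minimal sketch tracks the feature-map shapes through the four backbone stages for a Cityscapes-sized input, using the block counts (2, 2, 18, 2) and embedding width 96 given in Sect. 2.1. It assumes the standard Swin design, which the text does not spell out: a 4 × 4 patch embedding, and patch merging that halves the spatial size and doubles the channels at every later stage. It is an illustration, not the authors' code.

```python
# Hypothetical sketch: feature-map shapes produced by a four-stage Swin backbone.
# Assumptions (standard Swin design, not stated in the paper): 4 x 4 patch
# embedding, and patch merging that halves H and W and doubles C per stage.

def swin_stage_shapes(height, width, embed_dim=96, depths=(2, 2, 18, 2), patch=4):
    """Return (num_blocks, channels, h, w) for each backbone stage."""
    shapes = []
    h, w, c = height // patch, width // patch, embed_dim
    for stage, depth in enumerate(depths):
        if stage > 0:
            # Patch merging: 2 x 2 neighbouring tokens -> 1 token, channels doubled.
            h, w, c = h // 2, w // 2, c * 2
        shapes.append((depth, c, h, w))
    return shapes


shapes = swin_stage_shapes(1024, 2048)  # Cityscapes resolution (h x w)
for i, (depth, c, h, w) in enumerate(shapes, start=1):
    print(f"stage {i}: {depth:2d} blocks -> {c} x {h} x {w}")

# Only the outputs of stages 2-4 are passed to the FPN.
fpn_inputs = [(c, h, w) for _, c, h, w in shapes[1:]]
print("FPN inputs:", fpn_inputs)
```

Under these assumptions, stage 1 yields a 96 × 256 × 512 map and stage 4 a 768 × 32 × 64 map, so the FPN receives three maps whose resolutions differ by factors of two.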
The detection head then predicts the target objects. In each stage of the encoder, an even number of SwinBlocks is used for feature extraction, and successive downsampling steps build a hierarchical structure that enlarges the receptive field layer by layer, similar to a CNN. The features are then passed to the FPN, where upsampling resizes each feature map to the size of the corresponding Swin encoder stage and the two maps are added element-wise.
2.1 Swin Transformer Feature Extractor
The feature extractor is a four-stage Swin Transformer that captures the meaning of the feature maps and integrates global information. Before the first SwinBlock, Patch Embedding divides the input image (H × W × 3) into patches and embeds them as an (H/4) × (W/4) × 96 feature map. Because the Swin Transformer computes attention within windows, a 7 × 7 window divides this map into (H/(4 × 7)) × (W/(4 × 7)) windows, in which each feature is a 96-dimensional vector. The four stages differ in the number of SwinBlocks: we use 2, 2, 18, and 2 blocks, respectively, to integrate context. A SwinBlock consists of a window-based multi-head self-attention (W-MSA) module and a shifted-window-based multi-head self-attention (SW-MSA) module. W-MSA computes self-attention inside each window, while SW-MSA shifts the window partition so that attention is computed across the boundaries of the previous windows, which exchanges information between windows and yields global information. LayerNorm (LN) is applied before each MSA module and each MLP, and each of them is wrapped in a residual connection. The whole SwinBlock process can be expressed as follows:
$$\hat{z}^{l} = \text{W-MSA}\left(\mathrm{LN}\left(z^{l-1}\right)\right) + z^{l-1} \tag{1}$$
$$z^{l} = \mathrm{MLP}\left(\mathrm{LN}\left(\hat{z}^{l}\right)\right) + \hat{z}^{l} \tag{2}$$
$$\hat{z}^{l+1} = \text{SW-MSA}\left(\mathrm{LN}\left(z^{l}\right)\right) + z^{l} \tag{3}$$
$$z^{l+1} = \mathrm{MLP}\left(\mathrm{LN}\left(\hat{z}^{l+1}\right)\right) + \hat{z}^{l+1} \tag{4}$$
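The difference between the W-MSA and SW-MSA partitions can be illustrated on a toy token grid. This is a minimal sketch, not the Swin implementation: it assumes the shift is a cyclic roll by half a window (up and to the left), omits the attention mask that real SW-MSA uses to stop wrapped-around tokens from attending across the seam, and assumes the token map is zero-padded up to a multiple of the window size, as Swin implementations typically do. The helper names are hypothetical.

```python
import math


def pad_to_multiple(n, ws):
    """Smallest multiple of the window size ws that is >= n."""
    return math.ceil(n / ws) * ws


def num_windows(h, w, ws=7):
    """Number of ws x ws windows after padding an h x w token map."""
    return (pad_to_multiple(h, ws) // ws) * (pad_to_multiple(w, ws) // ws)


def cyclic_shift(grid, shift):
    """Roll a 2-D list up and left by `shift` tokens (shifted window partition)."""
    rows = grid[shift:] + grid[:shift]
    return [row[shift:] + row[:shift] for row in rows]


def window_partition(grid, ws):
    """Split a 2-D list whose sides are divisible by ws into flat ws*ws windows."""
    h, w = len(grid), len(grid[0])
    windows = []
    for top in range(0, h, ws):
        for left in range(0, w, ws):
            windows.append([grid[r][c] for r in range(top, top + ws)
                            for c in range(left, left + ws)])
    return windows


grid = [[4 * r + c for c in range(4)] for r in range(4)]  # toy 4 x 4 token ids
print(window_partition(grid, 2))                    # W-MSA windows
print(window_partition(cyclic_shift(grid, 1), 2))   # SW-MSA windows (shift = ws // 2)
print(num_windows(1024 // 4, 2048 // 4, 7))         # 2738
```

In the shifted partition, the first window holds tokens 5, 6, 9, and 10, one from each of the four regular windows, which is how SW-MSA passes information across window boundaries.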
In Eqs. (1)–(4), $\hat{z}^{l}$, $z^{l}$, $\hat{z}^{l+1}$, and $z^{l+1}$ denote the output features of the W-MSA, MLP, SW-MSA, and MLP sub-blocks of the extractor, respectively. Self-attention in the W-MSA and SW-MSA modules is computed from the query, key, and value matrices Q, K, and V: Q holds the features that issue queries, K the features that the queries are matched against, and V the features that are actually aggregated. Within a window, the query of each patch is scored against the keys of all patches in that window; the scores are normalized by softmax and applied to V, which gives the final context result:
$$\mathrm{sim}(Q, K) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right) \tag{5}$$
where $d_k$ is the dimension of the keys.
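Equation (5) and the weighting of V can be written out for a single window in plain Python. This sketch uses list-based vectors for clarity; it leaves out the learned Q/K/V projections and the relative position bias that Swin adds to the scores, so it shows the generic scaled dot-product attention rather than the full W-MSA module.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q, k, v):
    """Eq. (5) followed by weighting V; q, k, v are lists of row vectors.

    Returns (output, weights), where weights[i][j] = sim(Q, K)[i][j].
    """
    d_k = len(k[0])
    weights = []
    for q_row in q:
        scores = [sum(a * b for a, b in zip(q_row, k_row)) / math.sqrt(d_k) for k_row in k]
        weights.append(softmax(scores))
    # Each output row is a convex combination of the value rows.
    output = [[sum(wt * v_row[j] for wt, v_row in zip(w_row, v)) for j in range(len(v[0]))]
              for w_row in weights]
    return output, weights


# Toy window with two tokens (d_k = 2).
q = [[1.0, 0.0], [0.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
out, w = scaled_dot_product_attention(q, k, v)
print(w[1], out[1])  # [0.5, 0.5] [2.0, 3.0]
```

The second query is all zeros, so it scores every key equally and its output is simply the mean of the value rows; the first query favors the first key.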
Each set of Q, K, and V matrices (one head) produces one set of features. Multi-head attention uses several sets in parallel and concatenates the features extracted by each head to obtain the final features. In our framework, 3 heads process the initial 96-dimensional feature vector, so each head processes 32-dimensional features.
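The 96 = 3 × 32 split can be sketched as follows. In common implementations, Swin's included, a single linear layer projects every token to Q, K, and V of width C, and the result is sliced into heads, which is equivalent to giving each head its own projection matrix. The sketch shows only this slicing and the final concatenation; the per-head attention is omitted, and the function names are illustrative.

```python
def split_heads(x, num_heads):
    """Split a C-dimensional token vector into num_heads equal chunks."""
    if len(x) % num_heads:
        raise ValueError("feature dimension must be divisible by num_heads")
    d = len(x) // num_heads
    return [x[h * d:(h + 1) * d] for h in range(num_heads)]


def merge_heads(chunks):
    """Concatenate per-head outputs back into one vector."""
    return [value for chunk in chunks for value in chunk]


token = [float(i) for i in range(96)]      # one 96-dimensional patch feature
heads = split_heads(token, 3)
print([len(h) for h in heads])             # [32, 32, 32]
assert merge_heads(heads) == token         # splitting and concatenating is lossless
```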
Fig. 3. The detector head, called CG Head, consists of two parts. The first part is S-CBAM at the front of the detector head, which contains a channel attention module and a spatial attention module. The second part calculates the loss, which includes the NMS score and bbox regression.
2.2 CG Head
CG Head performs detection on the feature maps extracted by the feature extractor and fused by the FPN. In rain, snow, fog, frost, and similar conditions, object detection in autonomous vehicles faces a more complex environment than in sunny weather. Although the extracted feature maps already contain context and global information from self-attention, how to exploit these multi-level features for prediction is even more critical.
S-CBAM. When weather noise makes images harder to detect, the S-CBAM module focuses on important features and suppresses unnecessary ones through a convolution attention module with a smaller convolution kernel, which we found to be more sensitive to small objects. As shown in Fig. 3, the module consists of two parts, channel attention and spatial attention, which amounts to screening the extracted features twice. Channel attention mainly determines which feature channels are important. Global max pooling and global average pooling over the spatial dimensions of the feature map $F \in \mathbb{R}^{C\times H\times W}$ produce two descriptors, $M_C^{\max} \in \mathbb{R}^{C\times1\times1}$ and $M_C^{\mathrm{avg}} \in \mathbb{R}^{C\times1\times1}$. They are added and mapped to $M_C \in (0, 1)$ by the sigmoid function to give the channel attention result:
$$M_C = \mathrm{sigmoid}\left(M_C^{\max} + M_C^{\mathrm{avg}}\right) \tag{6}$$
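A minimal sketch of Eq. (6) on a tiny C × H × W feature map follows. It implements the equation exactly as written above; the original CBAM additionally passes both pooled descriptors through a shared MLP before summing, which is omitted here. Multiplying $M_C$ back onto every channel to form the input of the spatial module is an assumption taken from CBAM, since the text only says that the channel module's output feeds the spatial module.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def channel_attention(feature):
    """Eq. (6) as written in the text: M_C = sigmoid(max-pool + avg-pool).

    feature: C x H x W nested lists. Returns (m_c, refined), where refined is
    M_C broadcast-multiplied over each channel (assumed, as in CBAM).
    The shared MLP that CBAM applies to both pooled descriptors is omitted.
    """
    m_c = []
    for channel in feature:
        values = [v for row in channel for v in row]
        m_max = max(values)                 # global max pooling -> 1 x 1
        m_avg = sum(values) / len(values)   # global average pooling -> 1 x 1
        m_c.append(sigmoid(m_max + m_avg))
    refined = [[[wt * v for v in row] for row in channel] for wt, channel in zip(m_c, feature)]
    return m_c, refined


feature = [[[0.0, 0.0], [0.0, 0.0]],      # channel 0: no response
           [[1.0, 3.0], [2.0, 2.0]]]      # channel 1: strong response
m_c, refined = channel_attention(feature)
print([round(wt, 3) for wt in m_c])       # [0.5, 0.993]
```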
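The output of the channel module serves as the input feature map of the spatial module. First, max pooling and average pooling are applied along the channel axis, and the two results are concatenated along the channel dimension. A convolution then reduces them to a single channel, and the sigmoid function generates the spatial attention map:
$$M_S = \mathrm{sigmoid}\left(f^{3\times3}\left(\left[\mathrm{AvgPool}(F);\ \mathrm{MaxPool}(F)\right]\right)\right) \tag{7}$$
where $F$ here denotes the output of the channel module, $f^{3\times3}$ is the 3 × 3 convolution, and $[\cdot\,;\cdot]$ denotes concatenation along the channel dimension.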
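Equation (7) can be sketched with the kernel size k left as a parameter, which also shows how little the kernel size changes the parameter count: the 2 → 1 channel convolution has 2k² weights, so going from 7 × 7 to 3 × 3 cuts this layer from 98 to 18 weights. The sketch assumes zero "same" padding and ignores any bias or normalization layer; since this layer is tiny relative to the backbone, that is presumably why the whole-network FLOPs in Table 3 barely change across kernel sizes.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def spatial_attention(feature, kernel):
    """Eq. (7): sigmoid(conv_k([AvgPool_c(F); MaxPool_c(F)])), zero 'same' padding.

    feature: C x H x W nested lists (output of the channel module).
    kernel:  2 x k x k weights, input channel 0 = average map, 1 = max map; k odd.
    """
    c, h, w = len(feature), len(feature[0]), len(feature[0][0])
    avg_map = [[sum(feature[ch][i][j] for ch in range(c)) / c for j in range(w)] for i in range(h)]
    max_map = [[max(feature[ch][i][j] for ch in range(c)) for j in range(w)] for i in range(h)]
    pooled = [avg_map, max_map]            # concatenation along the channel axis
    k = len(kernel[0])
    pad = k // 2                           # keeps the H x W size
    m_s = []
    for i in range(h):
        row = []
        for j in range(w):
            s = 0.0
            for ch in range(2):
                for di in range(k):
                    for dj in range(k):
                        y, x = i + di - pad, j + dj - pad
                        if 0 <= y < h and 0 <= x < w:
                            s += kernel[ch][di][dj] * pooled[ch][y][x]
            row.append(sigmoid(s))
        m_s.append(row)
    return m_s


def spatial_conv_weights(k):
    """Weights of the 2 -> 1 channel k x k convolution (bias ignored)."""
    return 2 * k * k


print({k: spatial_conv_weights(k) for k in (3, 5, 7, 9)})  # {3: 18, 5: 50, 7: 98, 9: 162}
```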
In particular, the spatial attention convolution uses a 3 × 3 kernel, which is smaller than the 7 × 7 kernel of the original method; it achieves better results, especially in small object detection precision, while reducing parameters and computation.
GFL Loss Function. We use Generalized Focal Loss (GFL) [16] to calculate the loss, in order to model uncertainty in complex scenes and to combine the classification score with the NMS score. The loss function consists of two parts: Quality Focal Loss (QFL) and Distribution Focal Loss (DFL). QFL extends Focal Loss [17] to continuous labels: it keeps Focal Loss's ability to balance hard and easy samples while supporting supervision with continuous values. It replaces the discrete label $y \in \{0, 1\}$ with a quality label $y \in [0, 1]$, where $y = 0$ denotes a negative sample with a quality score of 0 and $0 < y \le 1$ denotes a positive sample whose target IoU score is $y$. QFL is defined as:
$$\mathrm{QFL}(\sigma) = -\left|y - \sigma\right|^{\beta}\left((1 - y)\log(1 - \sigma) + y\log(\sigma)\right) \tag{8}$$
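Equation (8) for a single prediction can be written as below. The value β = 2 is the default suggested in the GFL paper [16]; the text does not state which β was used here, so treat it as an assumption. The clipping constant `eps` is also an addition for numerical safety.

```python
import math


def quality_focal_loss(sigma, y, beta=2.0, eps=1e-12):
    """Eq. (8) for one predicted score sigma in (0, 1) and quality label y in [0, 1].

    beta = 2.0 follows the GFL paper's default; the value used here is not stated.
    """
    sigma = min(max(sigma, eps), 1.0 - eps)  # keep both logarithms finite
    bce = (1.0 - y) * math.log(1.0 - sigma) + y * math.log(sigma)
    return -(abs(y - sigma) ** beta) * bce


# A positive sample whose target IoU is 0.8: the loss shrinks as sigma
# approaches y and vanishes at sigma == y.
for s in (0.2, 0.5, 0.8):
    print(s, round(quality_focal_loss(s, 0.8), 4))
```

Setting y = 0 recovers the negative-sample case, where the modulating factor reduces to σ^β, as in the original Focal Loss.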
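where $\sigma$ is the predicted score and $|y - \sigma|^{\beta}$, the absolute distance between $\sigma$ and the continuous label $y$ raised to the power $\beta$, acts as the focal modulating factor. Training this joint score end to end resolves the inconsistency between training and evaluation. Because of weather noise, objects often have uncertain boundaries, but the true boundary is usually not far from the annotated position. DFL therefore lets the network quickly concentrate on the values near the annotated position, i.e., the two bins $y_i \le y \le y_{i+1}$ that bracket the continuous target $y$, making their probabilities $S_i$ and $S_{i+1}$ as high as possible:
$$\mathrm{DFL}(S_i, S_{i+1}) = -\left((y_{i+1} - y)\log(S_i) + (y - y_i)\log(S_{i+1})\right) \tag{9}$$
Finally, QFL and DFL are unified into GFL, where $y_l \le y \le y_r$ and $p_{y_l}$, $p_{y_r}$ are the predicted probabilities at $y_l$ and $y_r$:
$$\mathrm{GFL}(p_{y_l}, p_{y_r}) = -\left|y - \left(y_l\,p_{y_l} + y_r\,p_{y_r}\right)\right|^{\beta}\left((y_r - y)\log(p_{y_l}) + (y - y_l)\log(p_{y_r})\right) \tag{10}$$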
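A minimal sketch of Eq. (9) for one box edge follows. As in GFL, each edge distance is modelled as a discrete distribution over bins, here assumed to be the integers 0..n with unit spacing, and the predicted edge is the expectation of that distribution. The bin layout and helper names are assumptions for illustration.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def distribution_focal_loss(logits, y):
    """Eq. (9) for one box edge.

    logits: unnormalized scores for integer bins 0..n (unit spacing assumed).
    y: continuous target distance, 0 <= y <= n.
    """
    probs = softmax(logits)                       # S_0 .. S_n
    n = len(probs) - 1
    i = min(int(math.floor(y)), n - 1)            # y_i <= y <= y_{i+1} = y_i + 1
    y_i, y_next = float(i), float(i + 1)
    return -((y_next - y) * math.log(probs[i]) + (y - y_i) * math.log(probs[i + 1]))


def expected_edge(logits):
    """Box-edge prediction: expectation of the learned discrete distribution."""
    return sum(j * p for j, p in enumerate(softmax(logits)))


uniform = [0.0] * 5                       # no preference among bins 0..4
peaked = [0.0, 0.0, 4.0, 4.0, 0.0]        # mass concentrated on bins 2 and 3
print(distribution_focal_loss(uniform, 2.5) > distribution_focal_loss(peaked, 2.5))  # True
print(round(expected_edge(uniform), 6))   # 2.0
```

Concentrating probability on the two bins that bracket the target lowers the loss, which is the behaviour the text describes for uncertain, weather-blurred boundaries.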
3 Experiment
3.1 Dataset and Implementation Details
We first evaluate the normal (clean-image) precision on Cityscapes, one of the largest datasets for autonomous driving. It focuses on urban roads in real scenes, so its tasks are close to the needs of autonomous driving. We use leftImg8bit/gtFine, the finely labeled subset, which contains 5,000 finely labeled images from 50 European cities; each image is 1024 × 2048 (h × w). To evaluate the robustness of the model, we also use Cityscapes-C [5], a dataset that contains a large number of corruptions, to test the robust precision of the model. We use the weather noise corruptions in Cityscapes-C to simulate the weather conditions of autonomous driving, namely snow, fog, frost, and strong light, each with 5 severity levels. Our experiments are implemented with MMDetection and trained on an RTX 3090. Each model is trained for 36 epochs with the SGD optimizer; the initial learning rate is 0.01 and is decayed at epochs 27 and 33. We report average precision (AP), average recall (AR), and AP50 and AP75, i.e., AP at IoU thresholds of 50% and 75%. We also pay special attention to small object detection and therefore report small-object AP (APs) and AR (ARs).
3.2 Results on Cityscapes
Table 1. Normal precision and recall evaluation of our method and other methods on the Cityscapes dataset.
| Method | AP(%) | AP50(%) | AP75(%) | APs(%) | AR(%) | ARs(%) |
|---|---|---|---|---|---|---|
| Faster R-CNN | 35.0 | 55.8 | 35.2 | 11.1 | 40.6 | |
| Mask R-CNN | 33.7 | 56.5 | 35.2 | 10.3 | 41.3 | |
| YOLO V3 | 26.7 | 49.0 | 25.8 | 4.1 | 32.8 | |
| RetinaNet | 35.6 | 57.6 | 36.9 | 6.8 | 44.4 | |
| SwinCGH-Net | 42.1 | 65.3 | 43.2 | 16.1 | 52.3 | |
We compare our method with other methods on the Cityscapes dataset, including the one-stage detectors YOLO V3 [18] and RetinaNet [17] and the two-stage detectors Faster R-CNN [19] and Mask R-CNN [20]. As Table 1 shows, our method clearly improves precision and recall over the other networks; in particular, its AP on small objects reaches 16.1%, a better result in one of the most difficult areas of object detection. We attribute this to the self-attention mechanism in the feature extractor and the lightweight convolution attention module in the detector head, which focus on regions that occupy very few pixels but contain crucial information. In real traffic scenes, small objects are typically distant pedestrians and vehicles. Traffic scenes also contain many occlusions that make object edges uncertain, and GFL can better estimate the position of the bounding box in such cases.
3.3 Results on Cityscapes-C
We chose the Cityscapes-C weather noise corruptions to test the robust precision of our model under snow, frost, fog, and bright light, and found that our method is more robust than the others. As shown in Fig. 4, when the images contain weather noise, the AP curve of our model declines noticeably more gently than those of the other models. Under strong light in particular, the AP of our model drops by only 9.8% from level 0 to level 5. By contrast, weather noise has a catastrophic effect on Faster R-CNN: as the noise level increases, its AP drops the most, eventually falling below that of the one-stage YOLO V3. Weather noise makes the data considerably more complex than the clear images in Cityscapes, so the various corruptions reduce AP to different extents, and it becomes necessary to filter out the important feature information. Attention mechanisms are adept at finding “what” is meaningful and “where” the meaningful information lies, which greatly improves the representation ability of the network.
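The text reports the strong-light drop of 9.8% without saying whether it is measured in AP points or relative to the clean AP. The sketch below shows both conventions, plus the mPC and rPC summaries defined by the Cityscapes-C benchmark [5] (mean AP over corruption types and severity levels 1-5, and that mean divided by the clean AP). The input values are toy numbers for illustration only, not results from this paper.

```python
def absolute_drop(ap_clean, ap_corrupt):
    """AP lost in percentage points."""
    return ap_clean - ap_corrupt


def relative_drop(ap_clean, ap_corrupt):
    """AP lost as a percentage of the clean AP."""
    return 100.0 * (ap_clean - ap_corrupt) / ap_clean


def mean_performance_under_corruption(ap_by_corruption):
    """mPC: mean AP over corruption types and over severity levels 1-5."""
    per_type = [sum(levels) / len(levels) for levels in ap_by_corruption.values()]
    return sum(per_type) / len(per_type)


def relative_performance_under_corruption(ap_by_corruption, ap_clean):
    """rPC: mPC divided by the clean AP."""
    return mean_performance_under_corruption(ap_by_corruption) / ap_clean


# Toy numbers for illustration only; they are not results from this paper.
toy = {"snow": [30.0, 25.0, 20.0, 15.0, 10.0], "fog": [40.0, 38.0, 36.0, 34.0, 32.0]}
print(absolute_drop(40.0, 30.0), relative_drop(40.0, 30.0))   # 10.0 25.0
print(mean_performance_under_corruption(toy))                 # 28.0
print(relative_performance_under_corruption(toy, 40.0))
```

Reporting rPC alongside clean AP separates a model that is accurate from one that degrades gracefully, which is the property Fig. 4 is meant to show.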
Fig. 4. Robustness test of different networks on Cityscapes-C weather noise at various corruption levels.
3.4 Ablation Study
First, to verify the effectiveness of the Swin feature extractor, we replaced the feature extraction part of SwinCGH-Net with ResNet-101 while keeping the rest of the structure consistent, and found that performance dropped considerably after the backbone change (Table 2). Second, to verify that S-CBAM works best in the detector head, we also added S-CBAM to the backbone of SwinCGH-Net. Placing it only in the backbone lowers both precision and speed compared with placing it in the head, and placing it in both the backbone and the head raises AP slightly (42.3% vs. 42.1%) but lowers APs and FPS (Table 3). Finally, we tested convolution kernels of different sizes (3 × 3, 5 × 5, 7 × 7, and 9 × 9) for computing spatial attention and found that, among the head configurations, the 3 × 3 kernel gives the fastest speed and the highest precision, as well as the best small object detection results (Table 3).
Table 2. The ablation study demonstrates the effectiveness of the Swin feature extractor and of adding the S-CBAM attention module to the detector head. In the convolution attention module, using a smaller convolution kernel achieves better precision and speed.
| Method | Backbone | AP(%) | APs(%) | AR(%) | ARs(%) | FPS(img/s) |
|---|---|---|---|---|---|---|
| Faster R-CNN | ResNet-101 | 35.0 | 11.1 | 41.3 | 13.5 | 9.4 |
| GFL-r101 | ResNet-101 | 38.8 | 12.9 | 49.0 | 19.8 | 22.8 |
| GFL-Swin | Swin Transformer | 40.8 | 15.0 | 51.0 | 21.4 | 16.5 |
| SwinCGH-Net | Swin Transformer | 42.1 | 16.1 | 52.3 | 23.2 | 15.9 |
Table 3. Test of the S-CBAM position and of the convolution kernel size in the spatial attention part.
| Position | Kernel size | AP(%) | APs(%) | AR(%) | FLOPs(GFLOPs) | FPS(img/s) |
|---|---|---|---|---|---|---|
| Head | 9 × 9 | 41.6 | 14.7 | 51.7 | 313.78 | 14.7 |
| Head | 7 × 7 | 41.7 | 15.8 | 51.2 | 313.78 | 15.3 |
| Head | 5 × 5 | 41.7 | 15.4 | 51.7 | 313.70 | 15.7 |
| Head | 3 × 3 | 42.1 | 16.1 | 52.3 | 313.67 | 15.9 |
| Backbone | 3 × 3 | 41.7 | 14.4 | 51.8 | 313.80 | 12.8 |
| Backbone + Head | 3 × 3 | 42.3 | 15.5 | 52.0 | 313.81 | 12.5 |
4 Conclusion
In this paper, we present SwinCGH-Net, which significantly improves the robustness of object detection in bad weather for autonomous vehicles. We use Swin Transformer as the backbone to extract features, exploiting its self-attention mechanism to obtain global information and its hierarchical structure to keep the network structure consistent with CNNs while obtaining a dynamic and flexible receptive field. In addition, we embed the lightweight convolution attention module S-CBAM in the detector head to quickly screen out high-value information, and we use generalized focal loss to reinforce the representation ability of the model. Our method outperforms previous methods on Cityscapes and Cityscapes-C, showing better robustness on images with “weather noise” and improving AP and AR for small object detection. The ablation analyses demonstrate the importance of each component of our model; in particular, using a smaller convolution kernel in spatial attention achieves better detection results. In future work, compressing the model to speed up inference, so that it better meets the real-time demands of autonomous driving, is worth pursuing.
Acknowledgements. This work is supported by Beijing Natural Science Foundation (4232017).
References
1. Laugros, A., Caplier, A., Ospici, M.: Are adversarial robustness and common perturbation robustness independant attributes? In: Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops (2019)
2. Sakaridis, C., Dai, D., Van Gool, L.: Semantic foggy scene understanding with synthetic data. Int. J. Comput. Vision 126, 973–992 (2018)
3. Wang, H., Xie, Q., Zhao, Q., Meng, D.: A model-driven deep neural network for single image rain removal. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3103–3112 (2020)
4. Chen, W.-T., et al.: All snow removed: single image desnowing algorithm using hierarchical dual-tree complex wavelet representation and contradict channel loss. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4196–4205 (2021)
5. Michaelis, C., et al.: Benchmarking robustness in object detection: autonomous driving when winter is coming. arXiv preprint arXiv:1907.07484 (2019)
6. Wang, Y., Sun, X., Fu, Y.: Scalable penalized regression for noise detection in learning with noisy labels. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 346–355 (2022)
7. Geirhos, R., Temme, C.R., Rauber, J., Schütt, H.H., Bethge, M., Wichmann, F.A.: Generalisation in humans and deep neural networks. Adv. Neural Inf. Process. Syst. 31 (2018)
8. Hu, H., Zhang, Z., Xie, Z., Lin, S.: Local relation networks for image recognition. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 3464–3473 (2019)
9. Guo, M.-H., et al.: Attention mechanisms in computer vision: a survey. Computational Visual Media 8(3), 331–368 (2022)
10. Zhu, X., Hu, H., Lin, S., Dai, J.: Deformable ConvNets v2: more deformable, better results. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9308–9316 (2019)
11. Wang, Q., Wu, B., Zhu, P., Li, P., Zuo, W., Hu, Q.: ECA-Net: efficient channel attention for deep convolutional neural networks. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11534–11542 (2020)
12. Vaswani, A., et al.: Attention is all you need. Adv. Neural Inf. Process. Syst. 30 (2017)
13. Dosovitskiy, A., et al.: An image is worth 16x16 words: transformers for image recognition at scale. arXiv preprint arXiv:2010.11929 (2020)
14. Liu, Z., et al.: Swin transformer: hierarchical vision transformer using shifted windows. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10012–10022 (2021)
15. Carion, N., Massa, F., Synnaeve, G., Usunier, N., Kirillov, A., Zagoruyko, S.: End-to-end object detection with transformers. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12346, pp. 213–229. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58452-8_13
16. Li, X., et al.: Generalized focal loss: learning qualified and distributed bounding boxes for dense object detection. Adv. Neural Inf. Process. Syst. 33, 21002–21012 (2020)
17. Lin, T.-Y., Goyal, P., Girshick, R., He, K., Dollár, P.: Focal loss for dense object detection. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2980–2988 (2017)
18. Redmon, J., Farhadi, A.: YOLOv3: an incremental improvement. arXiv preprint arXiv:1804.02767 (2018)
19. Ren, S., He, K., Girshick, R., Sun, J.: Faster R-CNN: towards real-time object detection with region proposal networks. Adv. Neural Inf. Process. Syst. 28 (2015)
20. He, K., Gkioxari, G., Dollár, P., Girshick, R.: Mask R-CNN. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2961–2969 (2017)
MBDNet: Mitigating the “Under-Training Issue” in Dual-Encoder Model for RGB-d Salient Object Detection
Shuo Wang1,2, Gang Yang1 (✉), Yunhua Zhang1, Qiqi Xu1, and Yutao Wang1
1 Northeastern University, Shenyang 110819, China
2 DUT Artificial Intelligence Institute, Dalian 116024, China
Abstract. Existing RGB-D salient object detection methods generally rely on the dual-encoder structure for RGB and depth feature extraction. However, we observe that the encoders in such models are often not adequately trained to obtain superior feature representations. We name this problem the “under-training issue”. To this end, we propose a multi-branch decoding network (MBDNet) to suppress this issue. The MBDNet introduces additional decoding branches with supervision to form a multi-branch decoding (MBD) structure, facilitating the training of the encoders and enhancing the feature representation. Specifically, to ensure the effectiveness of the introduced supervision and improve the performance of additional decoding branches, we design an adaptive multi-scale decoding (AMSD) module. We also design a multi-branch feature aggregation (MBFA) module to aggregate the multi-branch features in MBD to further improve the detection accuracy. In addition, we design an enhancement complement fusion (ECF) module to achieve multi-modality feature fusion. Extensive experiments demonstrate that our MBDNet outperforms other state-of-the-art methods and mitigates the “under-training issue”.
Keywords: Salient object detection · RGB-D · “under-training issue” · Feature fusion
1 Introduction
Salient object detection (SOD) aims to imitate the human visual attention mechanism to locate and segment the most noticeable object in a given image. As an image preprocessing technique, it is widely applied in a variety of related computer vision tasks, such as person re-identification and action recognition [1]. The conventional SOD task uses only the RGB image for saliency map inference [27], but in challenging situations, such as cluttered backgrounds or poor texture contrast, the predictions tend to get worse. Moreover, with the continuous development of depth sensing technology, depth information is becoming easier to acquire. Depth images have therefore been introduced into the SOD task to boost detection performance with their abundant spatial and geometric information [23].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 99–111, 2023. https://doi.org/10.1007/978-981-99-4761-4_9
S. Wang et al.
The prevailing RGB-D models are built on the dual-encoder architecture [5], which normally employs two parallel encoders to extract RGB and depth features separately. However, different modalities carry unequal amounts of information: some modalities are strong and others weak. This imbalance prevents the encoder of each modality in a dual-encoder model from learning features adequately during training. As a result, the dual-encoder model cannot obtain the optimal prediction; we refer to this problem as the “under-training issue”. Related issues have been studied and discussed in other multi-modality learning fields, such as visual question answering (VQA) [4] and video classification [25], but the issue remains unexplored in RGB-D SOD. If more supervision is introduced into the model, the encoders can be trained more adequately. We therefore introduce additional decoding branches and supervise their prediction maps to suppress the “under-training issue”. We refer to this structure as multi-branch decoding (MBD) and further propose a multi-branch decoding network (MBDNet). The components of MBDNet are designed as follows: an adaptive multi-scale decoding (AMSD) module performs multi-scale decoding of cross-level features; a multi-branch feature aggregation (MBFA) module comprehensively aggregates the multi-branch features in MBD; and an enhancement complement fusion (ECF) module fuses the multi-modality features. Our contributions are summarized as follows:
– A novel MBDNet is proposed for the RGB-D SOD task, which suppresses the “under-training issue” of traditional dual-encoder models. Through the proposed multi-branch decoding (MBD) structure, it introduces two additional single-modality decoding branches with supervision.
– The AMSD module is designed to ensure the use of the supervision information by decoding cross-level features in a multi-scale way. Furthermore, the MBFA module and the ECF module are designed to comprehensively aggregate the multi-branch features in MBD and the multi-modal features in the encoding stage, respectively.
– Extensive experiments demonstrate that our proposed method outperforms 14 state-of-the-art RGB-D methods on five public datasets.
2 Proposed Method
2.1 Overview
Figure 1 illustrates the structure of the proposed MBDNet for RGB-D salient object detection. First, we extract multi-level features from paired RGB and depth images with two independent backbone networks (ResNet50 [11]) and uniformly reduce the number of channels of each level's features to 64 with RFB [26]. We then propose an enhancement complement fusion (ECF) module to fuse the multi-modality features. Next, to mitigate the adverse effects of the “under-training issue”, we design a multi-branch decoding (MBD) structure. Besides the conventional RGB-D decoding branch, MBD introduces two additional single-modality decoding branches, i.e., branches that use only RGB features or only depth features for saliency prediction, so that the encoder of each modality receives more training through supervision. In MBD, we design an adaptive multi-scale decoding (AMSD) module to obtain multi-scale information when decoding cross-level features. Last but not least, we design a multi-branch feature aggregation (MBFA) module by combining ECF and AMSD to make full use of the multi-branch features in MBD. Details of each component are provided in the following sections.
MBDNet: Mitigating the “Under-Training Issue”
Fig. 1. An overview of the proposed MBDNet. (Figure panels: Multi-Branch Decoding (MBD) structure; Enhancement Complement Fusion Unit; Adaptive Multi-Scale Decoding Unit; Multi-Branch Feature Aggregation Unit.)
2.2 Enhancement Complement Fusion Module
Taking into account the complementarity of features from different modalities, we propose an enhancement complement fusion (ECF) module that combines multiple attention mechanisms so that each modality can exploit the complementary information of the other and obtain a stronger feature representation. As shown in Fig. 2, the encoding features $f_R^i$ and $f_D^i$ from the i-th level are fed into ECF. We combine channel attention and spatial attention through multiplication to obtain the attention maps $M_R^i$ and $M_D^i$, which emphasize the useful regions in the current modality. From these, the corresponding reverse attention maps $\bar{M}_R^i$ and $\bar{M}_D^i$ can be computed quickly; they concentrate on the regions that are not useful in their respective modalities. Next, we use the positive attention map to highlight useful information and filter out useless information, such as background and noise, obtaining the enhanced features $f_{eR}^i$ and $f_{eD}^i$. To collect information that is valuable in the other modality but missing from the present one, we further multiply, element-wise, the reverse attention map with the enhanced features of the other modality to obtain the complementary features $f_{cR}^i$ and $f_{cD}^i$. Then, the enhanced features of each modality are added to its complementary features, and a residual connection is introduced to retain more of the original information, completing the complement and enhancement of the features.
Finally, we adaptively perform a weighted fusion of the two features according to the relative contribution of each modality on each channel. In addition, the output of the (i + 1)-th ECF is introduced to retain more high-level semantic information. Note that when i = 5, the weighted fusion feature is itself the output of the module.
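The enhance, complement, and fuse steps above can be traced on a toy example. The following is a minimal sketch under stated assumptions, not the authors' implementation: each feature is a plain list with one value per channel (standing in for a C × H × W tensor), the product of channel and spatial attention is replaced by a hypothetical sigmoid gate, the reverse attention map is assumed to be 1 − M, the per-channel modality weights come from a softmax over a hypothetical contribution score, and the (i + 1)-th ECF output is assumed to be merged by addition.

```python
import math
from typing import List, Optional

Vec = List[float]  # one value per channel; stands in for a C x H x W tensor


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def attention(f: Vec) -> Vec:
    # Hypothetical stand-in for channel attention x spatial attention:
    # a sigmoid gate in (0, 1) computed from the feature itself.
    return [sigmoid(v) for v in f]


def ecf(f_r: Vec, f_d: Vec, prev: Optional[Vec] = None) -> Vec:
    """Enhance, complement and fuse one level of RGB and depth features."""
    m_r, m_d = attention(f_r), attention(f_d)
    rev_r = [1.0 - m for m in m_r]  # reverse maps, assumed to be 1 - M
    rev_d = [1.0 - m for m in m_d]

    # Enhancement: the positive map keeps what is useful in each modality.
    fe_r = [f * m for f, m in zip(f_r, m_r)]
    fe_d = [f * m for f, m in zip(f_d, m_d)]

    # Complement: the reverse map of one modality selects, element-wise,
    # what the other modality's enhanced features can contribute.
    fc_r = [r * e for r, e in zip(rev_r, fe_d)]
    fc_d = [r * e for r, e in zip(rev_d, fe_r)]

    # Enhanced + complementary features, plus a residual to the input.
    g_r = [e + c + f for e, c, f in zip(fe_r, fc_r, f_r)]
    g_d = [e + c + f for e, c, f in zip(fe_d, fc_d, f_d)]

    # Channel-wise weighted fusion: a softmax over the two modalities,
    # using each channel's response as a hypothetical contribution score.
    fused = []
    for a, b in zip(g_r, g_d):
        top = max(a, b)
        wa, wb = math.exp(a - top), math.exp(b - top)
        fused.append((wa * a + wb * b) / (wa + wb))

    # Levels 1-4 also receive the (i+1)-th ECF output (assumed: addition).
    if prev is not None:
        fused = [x + p for x, p in zip(fused, prev)]
    return fused


out5 = ecf([0.5, -1.0], [0.2, 0.3])             # top level: no prev
out4 = ecf([1.0, 0.0], [-0.5, 0.8], prev=out5)  # next level down
```

Because each reverse map only admits what the other modality's enhanced features offer, a channel that is already confidently active in one modality (M close to 1) takes almost nothing from the other. A learned version would replace the sigmoid gate and the softmax scores with the attention and weighting layers of Fig. 2.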
Fig. 2. Diagram of the enhancement complement fusion (ECF) module.
2.3 Multi-branch Decoding Structure
To mitigate the adverse effect of the “under-training issue”, we design a multi-branch decoding (MBD) structure, shown in green in Fig. 1, which consists of two main components: two additional single-modality decoding branches with supervision and a traditional RGB-D decoding branch.
Single-Modality Decoding Branch: The additional single-modality decoding branches predict saliency maps from the multi-level single-modality features, and the supervision applied to these predictions lets the corresponding encoder be trained more adequately. To decode well and to make sure that the additional supervision effectively facilitates encoder training, we design an adaptive multi-scale decoding (AMSD) module, shown in Fig. 3. AMSD adaptively selects the most appropriate feature interaction and fully extracts the multi-scale information contained in the cross-level features to enhance decoding. Specifically, we integrate $f_m^i$ and $f_{m-de}^{i+1}$ in different ways to obtain $f_{add}^i$, $f_{mul1}^i$, $f_{mul2}^i$, and $f_{max}^i$, where m refers to the RGB or depth modality. We then concatenate these features and compress the number of channels back to 64, which lets the module learn the optimal interaction:
$$t_m = \mathrm{Conv}_3\big([f_{add}^i, f_{mul1}^i, f_{mul2}^i, f_{max}^i]\big) \tag{1}$$
where $[\cdot]$ denotes concatenation.
Next, we feed $t_m$ into a modified Res2Net [10] block and repeatedly apply convolution to gradually collect decoding information at larger scales. The multi-scale features $F_i$ are obtained as:
$$F_i = \begin{cases} t_m, & i = 1,\\ \mathrm{Conv}_3^i(t_m), & i = 2,\\ \mathrm{Conv}_3^i(t_m + F_{i-1}), & 2 < i \le 4. \end{cases} \tag{2}$$
Finally, we concatenate all $F_i$, pass them through $\mathrm{Conv}_1$, and introduce a residual structure to retain the original features, which improves the ability to detect multi-scale objects. The output $f_{m-de}^i$ of AMSD can be expressed as:
$$f_{m-de}^i = t_m + \mathrm{Conv}_1\big([F_1, F_2, F_3, F_4]\big) \tag{3}$$
Fig. 3. Diagram of the adaptive multi-scale decoding (AMSD) module.
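Equations (1)–(3) can be followed step by step in a small sketch. Everything named here is an assumption for illustration: features are plain lists, $f_{mul1}$ is taken as an element-wise product and $f_{mul2}$ as a hypothetical product-plus-residual (the text does not define the two multiplicative interactions), each $\mathrm{Conv}_3^i$ is a hypothetical callable such as the `scaled` helper, and "concatenate, then convolve" is approximated by an element-wise mean.

```python
from typing import Callable, List

Vec = List[float]
Op = Callable[[Vec], Vec]


def scaled(k: float) -> Op:
    # Hypothetical stand-in for a learned 3x3 convolution.
    return lambda v: [k * x for x in v]


def mix(features: List[Vec]) -> Vec:
    # Stand-in for "concatenate, then convolve back to one feature":
    # an element-wise mean, i.e. a fixed linear projection.
    return [sum(xs) / len(xs) for xs in zip(*features)]


def amsd(f_cur: Vec, f_up: Vec, convs: List[Op]) -> Vec:
    """Sketch of Eqs. (1)-(3). f_cur plays f_m^i, f_up plays f_{m-de}^{i+1},
    and convs holds Conv_3^2, Conv_3^3 and Conv_3^4."""
    # Eq. (1): four interactions, then fusion into t_m.
    f_add = [a + b for a, b in zip(f_cur, f_up)]
    f_mul1 = [a * b for a, b in zip(f_cur, f_up)]
    f_mul2 = [a * b + a for a, b in zip(f_cur, f_up)]  # hypothetical variant
    f_max = [max(a, b) for a, b in zip(f_cur, f_up)]
    t = mix([f_add, f_mul1, f_mul2, f_max])

    # Eq. (2): Res2Net-style hierarchy; each F_i sees one more conv than F_{i-1}.
    scales = [t, convs[0](t)]                 # F_1, F_2
    for conv in convs[1:]:                    # F_3, F_4
        scales.append(conv([a + b for a, b in zip(t, scales[-1])]))

    # Eq. (3): fuse all scales and add the residual t_m.
    return [a + b for a, b in zip(t, mix(scales))]


print(amsd([1.0], [2.0], [scaled(1.0)] * 3))  # [6.875]
```

In the real module, the learned $\mathrm{Conv}_3$ of Eq. (1) is what lets AMSD weight the four interactions adaptively, and stacking 3 × 3 convolutions in Eq. (2) is what enlarges the receptive field from $F_2$ to $F_4$; the scalar stand-ins show only the data flow.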
RGB-D Decoding Branch: The RGB-D decoding branch mainly uses the multi-level RGB-D features $f_{fuse}^i$ to generate the final predicted saliency map. In this process, based on the aforementioned ECF and AMSD, we develop a multi-branch feature aggregation (MBFA) module, shown in Fig. 4. MBFA introduces the decoding features of the additional branches in MBD into the RGB-D decoding process as auxiliary information to produce more accurate predictions.
Fig. 4. Diagram of the multi-branch feature aggregation (MBFA) module.
As in the single-modality decoding branch, the RGB-D encoding features $f_{fuse}^i$ and $f_{fuse-de}^{i+1}$ are fed into AMSD for preliminary processing to obtain the decoding features $f_{temp}^i$. Meanwhile, we simply employ a basic convolution block to integrate $f_{R-de}^i$ and $f_{D-de}^i$ from the same level into $f_{other}^i$:
$$f_{temp}^i = \mathrm{AMSD}\big(f_{fuse}^i, f_{fuse-de}^{i+1}\big), \qquad f_{other}^i = \mathrm{Conv}_3\big([f_{R-de}^i, f_{D-de}^i]\big) \tag{4}$$
Since $f_{temp}^i$ and $f_{other}^i$ contain complementary information, ECF is further adopted to achieve more comprehensive feature interaction and obtain better decoding features:
$$f_{fuse-de}^i = \mathrm{ECF}\big(f_{temp}^i, f_{other}^i\big) \tag{5}$$
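How Eqs. (4) and (5) tie the three MBD branches together can be sketched as a top-down loop over the five levels. This is an illustration under assumptions: `decode_mbd`, `add`, and the three callables passed in are hypothetical placeholders for AMSD, the $\mathrm{Conv}_3$ fusion of Eq. (4), and ECF; the topmost level, which has no decoded (i + 1)-th feature, is assumed to reuse its own input as the partner feature.

```python
from typing import Callable, Dict, List

Vec = List[float]
Pair = Callable[[Vec, Vec], Vec]


def decode_mbd(f_r: Dict[int, Vec], f_d: Dict[int, Vec], f_fuse: Dict[int, Vec],
               amsd: Pair, conv3: Pair, ecf: Pair) -> Dict[str, Vec]:
    """Top-down decoding of the three MBD branches over levels 5 -> 1.

    amsd, conv3 and ecf are hypothetical two-input callables standing in for
    the AMSD module, Conv_3 over a concatenation, and the ECF module."""
    dec_r: Dict[int, Vec] = {}
    dec_d: Dict[int, Vec] = {}
    dec_f: Dict[int, Vec] = {}
    for i in range(5, 0, -1):
        # Assumption: level 5 has no (i+1)-th decoded feature, so it reuses its input.
        up_r = dec_r.get(i + 1, f_r[i])
        up_d = dec_d.get(i + 1, f_d[i])
        up_f = dec_f.get(i + 1, f_fuse[i])
        # Single-modality branches: AMSD only; their outputs are supervised.
        dec_r[i] = amsd(f_r[i], up_r)
        dec_d[i] = amsd(f_d[i], up_d)
        # RGB-D branch through MBFA, Eqs. (4)-(5).
        f_temp = amsd(f_fuse[i], up_f)
        f_other = conv3(dec_r[i], dec_d[i])
        dec_f[i] = ecf(f_temp, f_other)
    # The level-1 outputs feed the predictions S_R, S_D and S_RGBD.
    return {"R": dec_r[1], "D": dec_d[1], "RGBD": dec_f[1]}


def add(a: Vec, b: Vec) -> Vec:
    return [x + y for x, y in zip(a, b)]


levels = {i: [float(i)] for i in range(1, 6)}
print(decode_mbd(levels, levels, levels, amsd=add, conv3=add, ecf=add))
# {'R': [20.0], 'D': [20.0], 'RGBD': [180.0]}
```

Changing the depth inputs leaves the `R` output unchanged, which mirrors the design: each single-modality branch depends only on its own encoder, so the supervision on $S_R$ and $S_D$ reaches each encoder without passing through the fusion path.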
2.4 Loss Function
In the proposed MBDNet, the last-layer features of each of the three decoding branches are used to predict the corresponding saliency maps, denoted $S_R$, $S_D$, and $S_{RGBD}$, respectively. During training, we supervise the three saliency maps simultaneously with the information-guided loss (IGL) proposed in [31]. The total loss is formulated as:
$$L = \alpha_1 \mathrm{IGL}(S_R) + \alpha_2 \mathrm{IGL}(S_D) + \alpha_3 \mathrm{IGL}(S_{RGBD}) \tag{6}$$
where $\alpha_1$, $\alpha_2$, and $\alpha_3$ are trade-off parameters, all of which we set to 1 based on experiments.
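Equation (6) amounts to summing one loss per branch against the same ground truth. A minimal sketch follows; the IGL of [31] is not reproduced, and per-pixel binary cross-entropy (`bce`) is swapped in purely as a stand-in criterion. Maps are assumed to be flattened lists of probabilities in [0, 1].

```python
import math
from typing import Callable, Sequence, Tuple

Map = Sequence[float]
Criterion = Callable[[Map, Map], float]


def bce(pred: Map, gt: Map, eps: float = 1e-7) -> float:
    """Mean binary cross-entropy: a stand-in for IGL [31], not reproduced here."""
    total = 0.0
    for p, g in zip(pred, gt):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= g * math.log(p) + (1.0 - g) * math.log(1.0 - p)
    return total / len(pred)


def total_loss(s_r: Map, s_d: Map, s_rgbd: Map, gt: Map,
               criterion: Criterion = bce,
               alphas: Tuple[float, float, float] = (1.0, 1.0, 1.0)) -> float:
    """Eq. (6): every branch prediction is supervised by the same ground truth."""
    a1, a2, a3 = alphas
    return (a1 * criterion(s_r, gt) + a2 * criterion(s_d, gt)
            + a3 * criterion(s_rgbd, gt))


gt = [1.0, 0.0, 1.0, 0.0]
loss = total_loss([0.8, 0.3, 0.6, 0.1], [0.6, 0.4, 0.5, 0.2],
                  [0.9, 0.1, 0.8, 0.1], gt)
```

With $\alpha_1 = \alpha_2 = \alpha_3 = 1$, the $S_R$ and $S_D$ terms give the RGB and depth encoders gradient signals of their own, in addition to the signal they share through $S_{RGBD}$; setting $\alpha_1 = \alpha_2 = 0$ leaves only the RGB-D supervision.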
3 Experiments
3.1 Datasets and Evaluation Metrics
To thoroughly validate the effectiveness of the proposed MBDNet, we evaluate it on five public RGB-D benchmark datasets: NJUD (1985 images) [14], NLPR (1000) [20], DUT (1200) [22], SIP (929) [8], and STERE (1000) [18]. The split into training and testing sets follows [15]. Four widely used metrics are adopted to evaluate the methods: mean F-measure ($F_\beta$) [2], weighted F-measure ($F_\beta^w$) [17], mean absolute error (MAE) [21], and enhanced-alignment measure ($E_m$) [7].
3.2 Comparisons with State-of-the-Art Methods
To fully illustrate the effectiveness of MBDNet, we compare it with 14 state-of-the-art RGB-D SOD methods from the last three years: HDFNet [19], S2MA [16], D3Net [8], DSA2F [23], DCF [12], UC-Net [29], JLDCF [9], CDNet [13], BBSNet [28], DFMNet [30], HINet [3], CIRNet [6], DCMF [24], and MVSalNet [32].
Quantitative Evaluation. Table 1 shows the quantitative comparison on the four evaluation metrics. As can be observed, the $F_\beta$, $F_\beta^w$, and MAE scores of MBDNet on the NJUD, NLPR, DUT, and STERE datasets are generally superior to those of the competing methods. Compared with CIRNet [6], which has the second-best overall performance on the NJUD, NLPR, and DUT datasets, MBDNet improves these three metrics by an average of about 1.2%, 1.6%, and 5.7%, respectively. MBDNet does not perform as well on the SIP dataset, but its $F_\beta$ and $F_\beta^w$ still rank best and second-best, respectively. These quantitative results demonstrate the advantages of the proposed method for the RGB-D SOD task.
Qualitative Evaluation. To further demonstrate the superiority of our method, Fig. 5 shows a visual comparison of the proposed MBDNet with other RGB-D SOD methods in various scenes.
It can be observed that our method not only performs excellently in general scenes (1st–2nd rows) but also achieves accurate segmentation in challenging scenes, including complex backgrounds (3rd–4th rows), low-quality depth images (5th–6th rows), multiple objects (7th–8th rows), and low contrast (9th–10th rows). These results are made possible by the proposed MBD, which helps the encoders extract stronger feature representations, and they further indicate the robustness of our method.

Table 1. Quantitative comparison of $F_\beta$, $F_\beta^w$, MAE, and $E_m$ on five widely used RGB-D datasets. ↑ and ↓ indicate that larger and smaller values are better, respectively. The best three results are shown in red, green, and blue fonts, respectively.
Method         HDFNet   S2MA     D3Net    DSA2F    DCFNet   UCNet    JLDCF    CDNet    BBSNet   DFMNet   HINet    CIRNet   DCMF     MVSalNet Ours
Venue          ECCV     CVPR     TNNLS    CVPR     CVPR     TPAMI    TPAMI    TIP      TIP      arXiv    PR       TIP      TIP      ECCV
Year           2020     2020     2020     2021     2021     2021     2021     2021     2021     2022     2022     2022     2022     2022
NJUD   Fβ↑     0.893    0.858    0.879    0.898    0.896    0.887    0.897    0.903    0.903    0.893    0.896    0.909    0.879    0.903    0.917
NJUD   Fβw↑    0.877    0.832    0.854    0.882    0.877    0.870    0.881    0.884    0.884    0.875    0.876    0.891    0.855    0.884    0.902
NJUD   MAE↓    0.038    0.058    0.046    0.039    0.039    0.041    0.039    0.038    0.038    0.043    0.039    0.035    0.045    0.036    0.033
NJUD   Em↑     0.944    0.927    0.939    0.939    0.943    0.940    0.954    0.949    0.949    0.946    0.945    0.955    0.947    0.949    0.953
NLPR   Fβ↑     0.895    0.874    0.873    0.896    0.897    0.891    0.896    0.898    0.896    0.893    0.883    0.901    0.875    0.906    0.917
NLPR   Fβw↑    0.882    0.852    0.848    0.880    0.884    0.878    0.882    0.882    0.879    0.877    0.871    0.884    0.856    0.893    0.904
NLPR   MAE↓    0.023    0.030    0.030    0.024    0.024    0.025    0.024    0.025    0.023    0.024    0.026    0.023    0.029    0.021    0.020
NLPR   Em↑     0.963    0.953    0.953    0.952    0.957    0.956    0.963    0.960    0.961    0.961    0.957    0.966    0.954    0.966    0.964
DUT    Fβ↑     0.885    0.884    0.717    0.924    0.921    0.851    0.860    0.924    0.904    0.915    /        0.921    0.905    0.911    0.929
DUT    Fβw↑    0.867    0.868    0.673    0.914    0.912    0.828    0.839    0.911    0.886    0.902    /        0.906    0.891    0.896    0.918
DUT    MAE↓    0.041    0.042    0.096    0.030    0.031    0.056    0.054    0.030    0.036    0.031    /        0.029    0.034    0.034    0.029
DUT    Em↑     0.945    0.935    0.833    0.956    0.956    0.903    0.919    0.958    0.954    0.956    /        0.959    0.956    0.953    0.960
SIP    Fβ↑     0.876    0.855    0.839    0.866    0.876    0.868    0.877    0.876    0.869    0.873    0.840    0.876    /        /        0.883
SIP    Fβw↑    0.848    0.818    0.799    0.828    0.840    0.836    0.849    0.839    0.830    0.842    0.800    0.840    /        /        0.848
SIP    MAE↓    0.047    0.057    0.063    0.057    0.052    0.051    0.047    0.053    0.055    0.049    0.660    0.052    /        /        0.052
SIP    Em↑     0.930    0.919    0.909    0.912    0.922    0.919    0.931    0.920    0.922    0.926    0.899    0.924    /        /        0.925
STERE  Fβ↑     0.871    0.854    0.866    0.892    0.890    0.884    0.887    0.888    0.883    0.878    0.960    /        0.873    0.896    0.902
STERE  Fβw↑    0.852    0.824    0.838    0.868    0.872    0.867    0.854    0.866    0.858    0.860    0.831    /        0.848    0.877    0.884
STERE  MAE↓    0.041    0.051    0.046    0.039    0.037    0.039    0.042    0.038    0.041    0.040    0.049    /        0.043    0.036    0.034
STERE  Em↑     0.943    0.932    0.938    0.942    0.948    0.944    0.942    0.946    0.942    0.948    0.933    /        0.946    0.949    0.948
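The two simpler metrics from Sect. 3.1, and the averaged gains over CIRNet quoted in the quantitative evaluation, can be checked with a short sketch. Assumptions: saliency maps are flattened lists in [0, 1]; β² = 0.3, the value commonly used since [2]; $F_\beta$ is computed at a single fixed threshold, whereas the "mean" F-measure of evaluation toolboxes averages over thresholds or uses an adaptive one; the weighted F-measure and E-measure are omitted. The pooled ratio in `pooled_gain` (total improvement divided by CIRNet's total over NJUD, NLPR, and DUT) is one computation that reproduces the 1.2%, 1.6%, and 5.7% figures from Table 1; the paper does not state its averaging rule.

```python
from typing import Sequence


def mae(pred: Sequence[float], gt: Sequence[float]) -> float:
    """Mean absolute error between a saliency map in [0, 1] and a binary mask."""
    return sum(abs(p - g) for p, g in zip(pred, gt)) / len(pred)


def f_beta(pred: Sequence[float], gt: Sequence[float],
           threshold: float = 0.5, beta2: float = 0.3) -> float:
    """F-measure of the binarized map; beta^2 = 0.3 weights precision more."""
    tp = sum(1 for p, g in zip(pred, gt) if p >= threshold and g > 0.5)
    fp = sum(1 for p, g in zip(pred, gt) if p >= threshold and g <= 0.5)
    fn = sum(1 for p, g in zip(pred, gt) if p < threshold and g > 0.5)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return (1 + beta2) * precision * recall / (beta2 * precision + recall)


def pooled_gain(ours: Sequence[float], other: Sequence[float],
                lower_is_better: bool = False) -> float:
    """Total improvement over `other`, relative to the total of `other`."""
    diff = sum(other) - sum(ours) if lower_is_better else sum(ours) - sum(other)
    return diff / sum(other)


# (MBDNet, CIRNet, lower_is_better) on NJUD, NLPR, DUT, copied from Table 1.
TABLE1 = {
    "F_beta": ((0.917, 0.917, 0.929), (0.909, 0.901, 0.921), False),
    "F_beta^w": ((0.902, 0.904, 0.918), (0.891, 0.884, 0.906), False),
    "MAE": ((0.033, 0.020, 0.029), (0.035, 0.023, 0.029), True),
}
for name, (ours, cir, lower) in TABLE1.items():
    print(f"{name}: {pooled_gain(ours, cir, lower):.1%}")  # 1.2%, 1.6%, 5.7%
```

Averaging the per-dataset relative gains instead would give a noticeably larger MAE figure (about 6.2%), because the NLPR improvement is large relative to its small baseline error; this is why the pooled form is the likely reading.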
Fig. 5. Visual comparison with eight state-of-the-art RGB-D models in various scenes. (Columns: (a) RGB, (b) Depth, (c) GT, (d) Ours, (e) CIRNet, (f) D3Net, (g) DSA2F, (h) DCF, (i) CDNet, (j) BBSNet, (k) S2MA, (l) HDFNet. Row groups: general scenes, complex background, low-quality depth image, multiple objects, low contrast.)
3.3 Ablation Studies
To demonstrate the “under-training issue” of the encoders in the dual-encoder model and the effectiveness of the proposed method, we conduct a series of ablation studies on the NJUD, NLPR, and DUT datasets.
Are the encoders in MBDNet really trained more fully? We design the experiment shown in Fig. 6 with the following procedure:
1) We construct five models, M1–M5, shown in the yellow region of Fig. 6 (details of each model are given in the figure caption). We train them from scratch and name their RGB encoders Er1 to Er5, respectively.
2) We then construct a referee model, shown in the orange region of Fig. 6, to evaluate the performance of an encoder.
3) Next, we load the parameters of each RGB encoder (Er1–Er5) into the referee model, train the referee model while keeping the encoder parameters fixed, and record its performance metrics.
4) Finally, we evaluate the depth encoders (Ed1–Ed5) of M1–M5 in the same way.
Fig. 6. Experiments to evaluate the performance of the encoders of different models. M1 and the referee model are both single-encoder models that predict with only one modality of data. M5 is the proposed MBDNet; we degenerate M5 into the typical dual-encoder model M2 by removing all additional decoding branches. An additional RGB or depth decoding branch with supervision is then added to M2 to create M3 and M4, respectively. (Diagram legend: E, encoder; D, decoder; Fusion, feature fusion; Supervision; the encoder Er i or Ed i (i = 1–5) is loaded into the referee.)
The results are shown in Table 2. The referee model's metrics measure the quality of the encoder loaded into it: the higher the metric, the more powerful the encoder. Since M1 contains only one encoder, its training does not involve simultaneously optimizing multiple encoders, so we consider Er1 and Ed1 adequately trained and regard them as baselines. M2 is a typical dual-encoder model. The results show that Er2 and Ed2 are worse than Er1 and Ed1 on every metric and dataset, i.e., in terms of encoder quality the dual-encoder model is worse than the single-encoder model. This means that even though the dual-encoder model has converged, its RGB and depth encoders have still not been adequately trained. We refer to this phenomenon as the “under-training issue”.
Table 2. Performance evaluation of the RGB and depth encoders of M1–M5 with the referee model in Fig. 6. The best results of each modality are shown in BOLD.
Encoder  NJUD                         NLPR                         DUT
         Fβ↑    Fβw↑   MAE↓   Em↑     Fβ↑    Fβw↑   MAE↓   Em↑     Fβ↑    Fβw↑   MAE↓   Em↑
Er1      0.897  0.881  0.037  0.941   0.895  0.885  0.024  0.955   0.917  0.905  0.032  0.954
Er2      0.884  0.862  0.044  0.930   0.885  0.874  0.025  0.952   0.911  0.895  0.035  0.948
Er3      0.896  0.875  0.040  0.940   0.894  0.881  0.025  0.955   0.914  0.898  0.035  0.951
Er4      0.887  0.865  0.042  0.932   0.885  0.871  0.027  0.952   0.909  0.892  0.036  0.947
Er5      0.896  0.877  0.039  0.939   0.894  0.880  0.025  0.953   0.915  0.901  0.034  0.951
Ed1      0.843  0.809  0.063  0.904   0.829  0.803  0.041  0.925   0.778  0.732  0.088  0.864
Ed2      0.822  0.784  0.072  0.888   0.785  0.750  0.052  0.900   0.690  0.616  0.119  0.818
Ed3      0.820  0.783  0.072  0.888   0.781  0.749  0.054  0.897   0.675  0.614  0.126  0.814
Ed4      0.826  0.791  0.070  0.890   0.811  0.784  0.047  0.914   0.743  0.691  0.100  0.848
Ed5      0.830  0.796  0.065  0.895   0.825  0.795  0.044  0.918   0.743  0.694  0.100  0.845
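One way to quantify the pattern in Table 2 is to ask how much of the plain dual-encoder gap (Er2 versus Er1, Ed2 versus Ed1) each later model closes. The sketch below does this for $F_\beta$ only, averaging over NJUD, NLPR, and DUT; the `gap_closed` ratio is an illustrative summary introduced here, not a measure reported in the paper.

```python
from statistics import mean

# Referee-model F_beta on (NJUD, NLPR, DUT), copied from Table 2.
F_BETA = {
    "Er1": (0.897, 0.895, 0.917), "Er2": (0.884, 0.885, 0.911),
    "Er3": (0.896, 0.894, 0.914), "Er4": (0.887, 0.885, 0.909),
    "Er5": (0.896, 0.894, 0.915),
    "Ed1": (0.843, 0.829, 0.778), "Ed2": (0.822, 0.785, 0.690),
    "Ed3": (0.820, 0.781, 0.675), "Ed4": (0.826, 0.811, 0.743),
    "Ed5": (0.830, 0.825, 0.743),
}


def mean_gap(enc: str, base: str) -> float:
    """Average F_beta difference to a baseline encoder (negative = worse)."""
    return mean(a - b for a, b in zip(F_BETA[enc], F_BETA[base]))


def gap_closed(enc: str, base: str, plain: str) -> float:
    """Share of the plain dual-encoder (M2) gap that `enc` closes."""
    return 1.0 - mean_gap(enc, base) / mean_gap(plain, base)


for k in (3, 4, 5):
    print(f"M{k}: RGB {gap_closed(f'Er{k}', 'Er1', 'Er2'):.0%}, "
          f"depth {gap_closed(f'Ed{k}', 'Ed1', 'Ed2'):.0%}")
# M3: RGB 83%, depth -14%
# M4: RGB 3%, depth 54%
# M5: RGB 86%, depth 66%
```

The numbers line up with the discussion that follows: adding the RGB branch (M3) closes most of the RGB gap but none of the depth gap, adding the depth branch (M4) does the reverse, and M5 closes most of both without reaching the single-encoder baselines.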
On top of M2, M3 adds an RGB decoding branch, M4 adds a depth decoding branch, and M5 adds both. The results show that Er3 and Er5 have better encoding quality than Er2, while Er4 is comparable to Er2. This indicates that an additional decoding branch does help the corresponding encoder receive more training and mitigates the “under-training issue”, but it does not help the other encoder. We also observe that Er3 and Er5 are still inferior to Er1, indicating that the proposed MBD does not completely resolve the “under-training issue” and only partially suppresses it. The evaluation of the depth encoders of the different models leads to the same conclusions as for the RGB modality, so we do not repeat it here.
Effectiveness of Multi-Branch Decoding Structure: To demonstrate that the proposed MBD enhances the overall performance of the model, we conduct additional experiments on M2, M3, M4, and M5 depicted in Fig. 6. The results are presented in Table 3. The quantitative results indicate that M3, M4, and M5 all perform better than M2, suggesting that introducing an additional decoding branch clearly benefits model performance. Furthermore, M5 performs best among the four models, which suggests that the more branches are introduced, the greater the performance improvement.
Table 3. Ablation studies on the proposed multi-branch decoding (MBD) structure. The best results are shown in BOLD.
Encoder  NJUD                         NLPR                         DUT
         Fβ↑    Fβw↑   MAE↓   Em↑     Fβ↑    Fβw↑   MAE↓   Em↑     Fβ↑    Fβw↑   MAE↓   Em↑
Er1      0.897  0.881  0.037  0.941   0.895  0.885  0.024  0.955   0.917  0.905  0.032  0.954
Er2      0.884  0.862  0.044  0.930   0.885  0.874  0.025  0.952   0.911  0.895  0.035  0.948
Er3      0.896  0.875  0.040  0.940   0.894  0.881  0.025  0.955   0.914  0.898  0.035  0.951
Er4      0.887  0.865  0.042  0.932   0.885  0.871  0.027  0.952   0.909  0.892  0.036  0.947
Er5      0.896  0.877  0.039  0.939   0.894  0.880  0.025  0.953   0.915  0.901  0.034  0.951
Table 4. Ablation studies on the proposed modules, including ECF, AMSD, and MBFA. The best results are shown in BOLD.
No.            NJUD                         NLPR                         DUT
               Fβ↑    Fβw↑   MAE↓   Em↑     Fβ↑    Fβw↑   MAE↓   Em↑     Fβ↑    Fβw↑   MAE↓   Em↑
1 (baseline)   0.902  0.883  0.039  0.940   0.899  0.886  0.024  0.958   0.917  0.900  0.035  0.947
2              0.903  0.888  0.036  0.940   0.903  0.895  0.021  0.960   0.923  0.909  0.031  0.957
3              0.905  0.888  0.036  0.942   0.904  0.891  0.022  0.958   0.923  0.903  0.032  0.950
4              0.904  0.885  0.037  0.942   0.904  0.893  0.022  0.960   0.926  0.909  0.030  0.953
5              0.907  0.893  0.035  0.944   0.906  0.896  0.021  0.960   0.927  0.913  0.029  0.955
6              0.909  0.894  0.035  0.946   0.907  0.896  0.021  0.960   0.928  0.916  0.029  0.958
7              0.911  0.895  0.033  0.946   0.905  0.895  0.022  0.958   0.924  0.912  0.030  0.954
8 (MBDNet)     0.917  0.902  0.033  0.953   0.917  0.904  0.020  0.964   0.929  0.918  0.029  0.960
Taken together, the above results show that the adequacy of encoder training does influence the performance of the dual-encoder model. The proposed MBD effectively facilitates more adequate training of each encoder and further enhances the overall detection performance of the model. These findings demonstrate the contribution of MBD to suppressing the “under-training issue”.
Effectiveness of Proposed Modules: To demonstrate the effectiveness of the proposed ECF, AMSD, and MBFA modules, we conduct various ablation studies; the quantitative results are shown in Table 4. To build the baseline, we replace each module with a more basic structure that preserves its essential functionality. We then add the modules to the baseline and observe the changes in the metrics. The results show that adding each module individually to the baseline improves performance, and adding modules in pairs generally improves it further. When all three modules are added at the same time, the complete model, i.e., the proposed MBDNet, achieves the best performance.
4 Conclusion
In this paper, we attempt to mitigate the adverse effect of the “under-training issue” in the dual-encoder model. We propose a novel multi-branch decoding network (MBDNet) that introduces additional supervised decoding branches into the network through the multi-branch decoding (MBD) structure. In addition, we design AMSD to ensure that the introduced supervision is adequately exploited to facilitate encoder training. We also design ECF and MBFA to comprehensively aggregate the multi-modality features in the encoding stage and the multi-branch features in MBD, respectively. Extensive experiments demonstrate that MBDNet outperforms state-of-the-art models on five benchmark datasets. Comprehensive ablation experiments also show that the encoders in MBDNet are more adequately trained, i.e., the effect of the “under-training issue” in the RGB-D SOD task is mitigated to a certain extent, which further verifies our contribution.
Acknowledgment. This work is supported by the National Natural Science Foundation of China under Grant No. 62076058.
References
1. Abdulmunem, A., Lai, Y.-K., Sun, X.: Saliency guided local and global descriptors for effective action recognition. Computational Visual Media 2(1), 97–106 (2016). https://doi.org/10.1007/s41095-016-0033-9
2. Achanta, R., Hemami, S., Estrada, F., Susstrunk, S.: Frequency-tuned salient region detection. In: 2009 IEEE Conference on Computer Vision and Pattern Recognition, pp. 1597–1604. IEEE (2009)
3. Bi, H., Wu, R., Liu, Z., Zhu, H., Zhang, C., Xiang, T.Z.: Cross-modal hierarchical interaction network for RGB-D salient object detection. Pattern Recogn. 136, 109194 (2023)
4. Cadene, R., Dancette, C., Cord, M., Parikh, D., et al.: RUBi: Reducing unimodal biases for visual question answering. Advances in Neural Information Processing Systems 32 (2019)
5. Chen, H., Li, Y.: Progressively complementarity-aware fusion network for RGB-D salient object detection. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3051–3060 (2018)
6. Cong, R., et al.: CIR-Net: Cross-modality interaction and refinement for RGB-D salient object detection. IEEE Trans. Image Process. 31, 6800–6815 (2022)
7. Fan, D.P., Gong, C., Cao, Y., Ren, B., Cheng, M.M., Borji, A.: Enhanced-alignment measure for binary foreground map evaluation. arXiv preprint arXiv:1805.10421 (2018)
8. Fan, D.P., Lin, Z., Zhang, Z., Zhu, M., Cheng, M.M.: Rethinking RGB-D salient object detection: Models, data sets, and large-scale benchmarks. IEEE Transactions on Neural Networks and Learning Systems 32(5), 2075–2089 (2020)
9. Fu, K., Fan, D.P., Ji, G.P., Zhao, Q., Shen, J., Zhu, C.: Siamese network for RGB-D salient object detection and beyond. IEEE Transactions on Pattern Analysis and Machine Intelligence (2021)
10. Gao, S.H., Cheng, M.M., Zhao, K., Zhang, X.Y., Yang, M.H., Torr, P.: Res2Net: A new multi-scale backbone architecture. IEEE Trans. Pattern Anal. Mach. Intell. 43(2), 652–662 (2019)
11. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
12. Ji, W., et al.: Calibrated RGB-D salient object detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9471–9481 (2021)
13. Jin, W.D., Xu, J., Han, Q., Zhang, Y., Cheng, M.M.: CDNet: Complementary depth network for RGB-D salient object detection. IEEE Trans. Image Process. 30, 3376–3390 (2021)
14. Ju, R., Ge, L., Geng, W., Ren, T., Wu, G.: Depth saliency based on anisotropic center-surround difference. In: 2014 IEEE International Conference on Image Processing (ICIP), pp. 1115–1119. IEEE (2014)
15. Li, C., Cong, R., Piao, Y., Xu, Q., Loy, C.C.: RGB-D salient object detection with cross-modality modulation and selection. In: European Conference on Computer Vision, pp. 225–241. Springer (2020)
16. Liu, N., Zhang, N., Han, J.: Learning selective self-mutual attention for RGB-D saliency detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13756–13765 (2020)
17. Margolin, R., Zelnik-Manor, L., Tal, A.: How to evaluate foreground maps? In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 248–255 (2014)
18. Niu, Y., Geng, Y., Li, X., Liu, F.: Leveraging stereopsis for saliency analysis. In: 2012 IEEE Conference on Computer Vision and Pattern Recognition, pp. 454–461. IEEE (2012)
19. Pang, Y., Zhang, L., Zhao, X., Lu, H.: Hierarchical dynamic filtering network for RGB-D salient object detection. In: European Conference on Computer Vision, pp. 235–252. Springer (2020)
20. Peng, H., Li, B., Xiong, W., Hu, W., Ji, R.: RGBD salient object detection: a benchmark and algorithms. In: European Conference on Computer Vision, pp. 92–109. Springer (2014)
21. Perazzi, F., Krähenbühl, P., Pritch, Y., Hornung, A.: Saliency filters: contrast based filtering for salient region detection. In: 2012 IEEE Conference on Computer Vision and Pattern Recognition, pp. 733–740. IEEE (2012)
22. Piao, Y., Ji, W., Li, J., Zhang, M., Lu, H.: Depth-induced multi-scale recurrent attention network for saliency detection. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 7254–7263 (2019)
23. Sun, P., Zhang, W., Wang, H., Li, S., Li, X.: Deep RGB-D saliency detection with depth-sensitive attention and automatic multi-modal fusion. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1407–1417 (2021)
24. Wang, F., Pan, J., Xu, S., Tang, J.: Learning discriminative cross-modality features for RGB-D saliency detection. IEEE Trans. Image Process. 31, 1285–1297 (2022)
25. Wang, W., Tran, D., Feiszli, M.: What makes training multi-modal classification networks hard? In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12695–12705 (2020)
26. Wu, Z., Su, L., Huang, Q.: Cascaded partial decoder for fast and accurate salient object detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3907–3916 (2019)
27. Yao, Z., Wang, L.: ERBANet: enhancing region and boundary awareness for salient object detection. Neurocomputing 448, 152–167 (2021)
28. Zhai, Y., et al.: Bifurcated backbone strategy for RGB-D salient object detection. IEEE Trans. Image Process. 30, 8727–8742 (2021)
29. Zhang, J., et al.: Uncertainty inspired RGB-D saliency detection. IEEE Transactions on Pattern Analysis and Machine Intelligence (2021)
30. Zhang, W., Fu, K., Wang, Z., Ji, G.P., Zhao, Q.: Depth quality-inspired feature manipulation for efficient RGB-D and video salient object detection. arXiv preprint arXiv:2208.03918 (2022)
31. Zhou, B., Yang, G., Wan, X., Wang, Y., Liu, C., Wang, H.: A simple network with progressive structure for salient object detection. In: Chinese Conference on Pattern Recognition and Computer Vision (PRCV) (2021)
32. Zhou, J., Wang, L., Lu, H., Huang, K., Shi, X., Liu, B.: MVSalNet: Multi-view augmentation for RGB-D salient object detection. In: Computer Vision – ECCV 2022: 17th European Conference, Tel Aviv, Israel, October 23–27, 2022, Proceedings, Part XXIX, pp. 270–287. Springer (2022)
W-net: Deep Convolutional Network with Gray-Level Co-occurrence Matrix and Hybrid Loss Function for Hyperspectral Image Classification
Jinchao Jiao1, Changqing Yin1(B), and Fei Teng2
1 School of Software Engineering, Tongji University, Shanghai 201804, China
[email protected]
2 Institute of Agricultural Resources and Regional Planning, Chinese Academy of Agricultural Sciences, Beijing 100081, China
Abstract. Hyperspectral image (HSI) classification is a significant and demanding research area in remote sensing and earth observation. Effectively extracting and utilizing texture features from HSI is a challenging problem. In addition, class imbalance is a common problem in remote sensing datasets, which further complicates achieving optimal HSI classification performance. To address these two problems, this paper proposes W-net, a deep learning network that uses the gray-level co-occurrence matrix (GLCM) for HSI classification, together with a hybrid loss function. The network uses GLCM to extract texture features from each band of the HSI. The extracted feature maps and the RGB image are downsampled simultaneously in the encoder and then upsampled in the decoder to obtain the final classification map. Meanwhile, the hybrid loss function combines focal loss (FL) with softmax equalization loss (SEQL) to adjust the balance between classes and suppress the gradients of negative samples for rare classes. Experimental results show that the proposed network effectively integrates texture features from hyperspectral data via GLCM during training while also addressing class imbalance, resulting in promising HSI classification performance.
Keywords: Deep learning · Hyperspectral image classification · Convolutional neural networks · Gray-level co-occurrence matrix · Texture extraction
1 Introduction
The utilization and statistical analysis of agricultural and rural land are of paramount importance to the national economy. Quick and efficient statistical reporting and accounting of agricultural land can improve the country's understanding of data of national significance, such as food production. Hyperspectral imaging (HSI) has found extensive use in several domains [1, 2]. In most of these applications, HSI classification is a pivotal preliminary step that assigns a semantic label to each pixel of the image.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 112–124, 2023. https://doi.org/10.1007/978-981-99-4761-4_10
W-net: Deep Convolutional Network with Gray-Level Co-occurrence Matrix
During the early stages of image classification, owing to limited computer hardware and the absence of standardized datasets, most image classification methods relied on manually engineered features [3–7]. With the development of neural networks, many networks have achieved increasingly better results in image classification tasks, such as ResNet [8], U-Net [3], and the DeepLab series [4, 9–11]. In the field of HSI classification, various CNN-based classification models have been proposed, which can be categorized into three types according to the features they extract: (1) spectral CNNs, which use 1D CNNs [12] to extract the spectral features of HSI; (2) spatial CNNs, which use 2D CNNs [13] to extract spatial features from HSI after dimensionality reduction; and (3) spectral-spatial CNNs, which use three-dimensional convolution [14] or dual-branch networks [15] to extract joint spectral-spatial features of HSI.
The gray-level co-occurrence matrix (GLCM) is a well-established texture analysis method, and researchers often combine it with other algorithms. For example, GLCM has been combined with k-nearest neighbors (KNN) classification for butterfly image recognition [16], achieving better results than traditional methods. Mohammadpour et al. [17] combined GLCM with random forests for vegetation mapping. In medical image processing, Aggarwal [18] also combined GLCM with a random forest approach to extract important textures from brain magnetic resonance imaging (MRI) images.
Class imbalance is a common problem in datasets: the majority categories have many samples, the rare categories have few, and the difference in sample counts between them is large.
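The GLCM mentioned above counts how often two gray levels occur at a fixed spatial offset. A minimal single-band sketch follows, assuming the band is first quantized to a small number of gray levels; the offset, the number of levels, and the statistics computed (contrast, homogeneity, energy) are illustrative choices, since this section does not say which ones W-net uses. For real data, scikit-image provides maintained implementations (`skimage.feature.graycomatrix` and `skimage.feature.graycoprops`).

```python
from typing import List, Sequence, Tuple

Band = Sequence[Sequence[float]]
Image = List[List[int]]
Matrix = List[List[float]]


def quantize(band: Band, levels: int) -> Image:
    """Map a band's values linearly onto integer gray levels 0..levels-1."""
    lo = min(min(row) for row in band)
    hi = max(max(row) for row in band)
    if hi == lo:
        return [[0 for _ in row] for row in band]
    return [[min(levels - 1, int((v - lo) / (hi - lo) * levels)) for v in row]
            for row in band]


def glcm(img: Image, levels: int, offset: Tuple[int, int] = (0, 1),
         symmetric: bool = True) -> Matrix:
    """Normalized co-occurrence matrix P[i][j] over pixel pairs (p, p + offset)."""
    dr, dc = offset
    counts = [[0.0] * levels for _ in range(levels)]
    rows, cols = len(img), len(img[0])
    for r in range(rows):
        for c in range(cols):
            r2, c2 = r + dr, c + dc
            if 0 <= r2 < rows and 0 <= c2 < cols:
                i, j = img[r][c], img[r2][c2]
                counts[i][j] += 1
                if symmetric:  # count the pair in both directions
                    counts[j][i] += 1
    total = sum(sum(row) for row in counts)
    return [[v / total for v in row] for row in counts] if total else counts


def contrast(p: Matrix) -> float:
    """Sum of P[i][j] * (i - j)^2: large when neighbouring levels differ."""
    n = len(p)
    return sum(p[i][j] * (i - j) ** 2 for i in range(n) for j in range(n))


def homogeneity(p: Matrix) -> float:
    """Sum of P[i][j] / (1 + (i - j)^2): large when neighbours are similar."""
    n = len(p)
    return sum(p[i][j] / (1 + (i - j) ** 2) for i in range(n) for j in range(n))


def energy(p: Matrix) -> float:
    """Angular second moment, sum of P[i][j]^2: large for uniform textures."""
    return sum(v * v for row in p for v in row)


band = [[0.1, 0.1, 0.9], [0.9, 0.9, 0.1]]
p = glcm(quantize(band, levels=2), levels=2, offset=(0, 1), symmetric=False)
print(p, contrast(p), homogeneity(p))  # [[0.25, 0.25], [0.25, 0.25]] 0.5 0.75
```

For a hyperspectral cube, the same routine would be applied band by band, giving one set of texture statistics (or one texture map, if computed in sliding windows) per band.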
Two approaches are typically used to address imbalanced datasets: (1) data augmentation and (2) modification of the loss function, for which various loss functions have been proposed. Focal loss (FL) [19] is a modification of the cross-entropy loss. FL performs well, but its reliance on two global hyperparameters can make it difficult to tune. Equalization loss [20] optimizes the loss from the perspective of gradients. It ignores the discouraging gradients that negative samples exert on rare categories.

Current CNN-based HSI models reduce spectral dimensionality in a relatively arbitrary way. This may fail to extract the relevant information from the abundant spectral data. In contrast, the GLCM approach can efficiently extract texture features of the target categories from the spectral data, and combining it with a CNN model yields better results.

Moreover, the categories in the dataset used in this study [21] are imbalanced, and sample sizes vary widely across categories. This poses a significant challenge for classification, particularly for rare categories, whose negative gradients can substantially affect overall accuracy. Existing loss functions may not adequately meet the dataset's requirements. FL improves model performance by emphasizing the learning of poorly classified voxels. The softmax variant of equalization loss (SEQL) alleviates the effect of overwhelming discouraging gradients during learning by introducing an ignoring strategy. Combining FL and SEQL can therefore exploit their respective strengths to address the challenges of the dataset and significantly improve performance.
J. Jiao et al.
The main contributions of this article are as follows:
(1) We introduce a deep learning architecture, called W-net, that combines a semantic segmentation network with the GLCM. In the upsampling stage, it fuses the multispectral GLCM with the spectral band information. This lets it efficiently extract multispectral texture features and improves accuracy.
(2) We propose a hybrid loss function that combines FL and SEQL. Its weighting coefficients and gradient suppression effectively address class imbalance in the dataset and yield satisfactory results.
(3) Our experiments show that W-net performs well on the datasets and is effective for HSI classification.
2 Datasets and Methods
2.1 Datasets
This paper uses two benchmark datasets for analysis and evaluation: the Xiongan New Area Matiwan Village (hereinafter Xiongan) dataset [21] and the AeroRIT scene [22]. The Xiongan dataset has three main attributes: many feature classes, high spectral resolution, and high spatial resolution. Its spectral range is 400–1000 nm with 256 bands, the image size is 3750 × 1580 pixels, and the spatial resolution is 0.5 m. The Xiongan dataset covers 19 categories of ground features, which have been mapped and labeled (Fig. 1).
Fig. 1. The Xiongan dataset. (a) false color map; (b) ground-truth map.
The AeroRIT dataset is an HSI dataset of the Rochester Institute of Technology's university campus. It was captured by a Headwall Micro E sensor, which has high spectral resolution and records 372 spectral bands. We use a subset of 51 bands, sampled at 10 nm intervals between 400 nm and 900 nm. The spatial resolution of this dataset is 0.4 m/px, and the image is 1973 × 3975 px. Its ground objects fall into five categories, as illustrated in Fig. 2.
Fig. 2. The AeroRIT dataset. (a) false color map; (b) ground-truth map.
The datasets are partitioned into training and test sets with a patch size of 256 × 256. We first crop each image into non-overlapping 256 × 256 patches. The patches are then randomly assigned to the training and test sets.
2.2 Methods
We propose a deep convolutional network with GLCM (W-net) for HSI classification, together with a hybrid loss function that combines FL and SEQL.
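The cropping and random split described in Sect. 2.1 can be sketched as follows. This is a minimal illustration, not the authors' code, and it rests on several assumptions. The scene is taken to be a NumPy array of shape (H, W, bands) with an (H, W) label map. Border strips narrower than 256 pixels are discarded. The split is controlled by a hypothetical `train_frac` and `seed`, because the text states neither the split ratio nor how image borders are handled.

```python
import numpy as np

def split_into_patches(image, labels, patch=256):
    """Crop non-overlapping patch x patch tiles from an (H, W, bands) image.

    Assumption: border strips narrower than `patch` are dropped.
    """
    h, w = labels.shape
    img_tiles, lbl_tiles = [], []
    for top in range(0, h - patch + 1, patch):
        for left in range(0, w - patch + 1, patch):
            img_tiles.append(image[top:top + patch, left:left + patch])
            lbl_tiles.append(labels[top:top + patch, left:left + patch])
    return np.stack(img_tiles), np.stack(lbl_tiles)

def random_split(img_tiles, lbl_tiles, train_frac=0.8, seed=0):
    """Randomly assign whole tiles to training or test (train_frac is hypothetical)."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(img_tiles))
    n_train = int(round(train_frac * len(order)))
    tr, te = order[:n_train], order[n_train:]
    return (img_tiles[tr], lbl_tiles[tr]), (img_tiles[te], lbl_tiles[te])
```

Under the discard-the-border assumption, a 3750 × 1580 Xiongan scene yields 14 × 6 = 84 whole patches before splitting.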
Fig. 3. The architecture of W-net for HSI classification.
W-net Architecture. The overall architecture of the proposed W-net for HSI classification is shown in Fig. 3. The network uses GLCM feature information to improve HSI classification accuracy. W-net is based on the Unet network and adds, on the other side, a convolutional auxiliary network fed with GLCM features. This forms a W-shaped architecture, hence the name W-net.
In the encoder, the left side of the network takes RGB images as input and uses basic blocks based on ResNet-18 [8] for progressive downsampling. The right side takes the GLCM feature maps as input and also downsamples them progressively, so that the feature map sizes of corresponding layers match. In the decoder, each layer concatenates the upsampled feature map with the same-level feature maps from the left and right sides of the network and then continues upsampling.

The GLCM features are computed from the spectral bands. A sliding window of size 7 is used to calculate the GLCM feature maps, so each pixel of a GLCM feature map already fuses information from its surrounding pixels. The left side of the network downsamples the input image to 1/32 of its original size, so the receptive field of each finally fused pixel is relatively large. A similar downsampling process is therefore applied to the GLCM branch. This keeps the feature map sizes consistent and enlarges the receptive field of the GLCM feature maps, which improves performance after fusion.

GLCM Feature Map Generation. In W-net, the GLCM feature maps are the inputs on the right side of the network. They help the network extract and use texture features in HSI more effectively. Eight GLCM statistics are used: mean, variance, correlation, dissimilarity, contrast, entropy, angular second moment (ASM), and homogeneity. Mean describes the average pixel value in a local region of the image. Variance measures the deviation of pixel values from the mean. Correlation describes the similarity of pixels in the row or column direction around a given pixel. Dissimilarity describes the difference in pixel values within a local area of the image.
Contrast describes the distribution and local variation of pixels in the image, and is used to describe the sharpness and depth of texture in the image. Entropy is used as a measure of the amount of information in an image. ASM describes the uniformity of pixel distribution and the degree of texture fineness in an image. Homogeneity is used to describe the uniformity of pixels in the local region of the image.

$$\mu_i = \sum_{i=0}^{N-1}\sum_{j=0}^{N-1} i\,P_{i,j} \tag{1}$$
$$\mu_j = \sum_{i=0}^{N-1}\sum_{j=0}^{N-1} j\,P_{i,j} \tag{2}$$
$$\sigma_i^2 = \sum_{i=0}^{N-1}\sum_{j=0}^{N-1} P_{i,j}\,(i-\mu_i)^2 \tag{3}$$
$$\sigma_j^2 = \sum_{i=0}^{N-1}\sum_{j=0}^{N-1} P_{i,j}\,(j-\mu_j)^2 \tag{4}$$
$$f_{correlation} = \sum_{i=0}^{N-1}\sum_{j=0}^{N-1} \frac{P_{i,j}\,(i-\mu_i)(j-\mu_j)}{\sqrt{\sigma_i^2\,\sigma_j^2}} \tag{5}$$
$$f_{dissimilarity} = \sum_{i=0}^{N-1}\sum_{j=0}^{N-1} |i-j|\,P_{i,j} \tag{6}$$
$$f_{contrast} = \sum_{i=0}^{N-1}\sum_{j=0}^{N-1} (i-j)^2\,P_{i,j} \tag{7}$$
$$f_{entropy} = -\sum_{i=0}^{N-1}\sum_{j=0}^{N-1} P_{i,j}\log P_{i,j} \tag{8}$$
$$f_{asm} = \sum_{i=0}^{N-1}\sum_{j=0}^{N-1} P_{i,j}^2 \tag{9}$$
$$f_{homogeneity} = \sum_{i=0}^{N-1}\sum_{j=0}^{N-1} \frac{P_{i,j}}{1+(i-j)^2} \tag{10}$$
where $i$ and $j$ are the row and column indices of the GLCM and $N$ is its number of rows (gray levels). $P_{i,j}$ is the probability recorded in cell $(i, j)$, $\mu_i$ and $\mu_j$ are the means, and $\sigma_i^2$ and $\sigma_j^2$ are the variances. Each spectral band yields 8 GLCM feature maps. The Xiongan dataset has 256 bands and each image patch is 256 × 256 pixels, so each hyperspectral image patch has 256 × 8 = 2048 GLCM feature maps.

Hybrid Loss Function. W-net uses a hybrid loss function combining SEQL [20] and FL [19]. SEQL alleviates the effect of overwhelming discouraging gradients on tail (rare) categories. FL improves model performance by emphasizing the learning of poorly classified voxels. The complete loss function is

$$L = L_{SEQL} + L_{FOCAL} = -\sum_{j=1}^{C} y_j \log \tilde{p}_j - \alpha \sum_{j=1}^{C} (1 - p_j)^{\mu}\, y_j \log p_j \tag{11}$$

where $C$ is the number of classes, $z_j$ is the network output (logit) for class $j$, $y_j$ is 1 for the ground-truth class and 0 otherwise, and

$$\tilde{p}_j = \frac{e^{z_j}}{\sum_{k=1}^{C} \tilde{w}_k\, e^{z_k}} \tag{12}$$
$$p_j = \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}} \tag{13}$$

The weight term $\tilde{w}_k$ is computed by

$$\tilde{w}_k = 1 - \beta\, T_{\lambda}(f_k)\,(1 - y_k) \tag{14}$$

where $\beta$ is a random variable that equals 1 with probability $\gamma$ and 0 with probability $1 - \gamma$; $\beta$ is used to randomly retain the gradients of negative samples. $f_k$ is the frequency of category $k$ in the dataset, calculated as the number of samples of class $k$ divided by the total number of samples. $T_{\lambda}(x)$ is a threshold function that outputs 1 when $x < \lambda$ and 0 otherwise.
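The eight GLCM statistics of Eqs. (1)–(10) can be sketched in NumPy as below. Only the 7 × 7 sliding window comes from the text; the function names are illustrative. The sketch also assumes several things the text leaves open:
- each band is first quantized to a hypothetical `levels = 16` gray levels;
- a single horizontal pixel offset (0, 1) is used;
- window borders are reflect-padded;
- entropy uses the natural logarithm, with 0 · log 0 taken as 0.

```python
import numpy as np

def glcm(window, levels, dy=0, dx=1):
    """Normalized co-occurrence matrix P for one non-negative offset (dy, dx)."""
    h, w = window.shape
    ref = window[: h - dy, : w - dx].ravel()  # reference pixels
    nbr = window[dy:, dx:].ravel()            # neighbor pixels at the offset
    counts = np.zeros((levels, levels))
    np.add.at(counts, (ref, nbr), 1)
    return counts / counts.sum()

def glcm_features(P):
    """The eight statistics of Eqs. (1)-(10) for a normalized GLCM P."""
    n = P.shape[0]
    i, j = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
    mu_i, mu_j = (i * P).sum(), (j * P).sum()                        # Eqs. (1)-(2)
    var_i, var_j = (P * (i - mu_i) ** 2).sum(), (P * (j - mu_j) ** 2).sum()  # Eqs. (3)-(4)
    denom = np.sqrt(var_i * var_j)
    nz = P > 0  # convention: 0 * log 0 = 0
    return {
        "mean": mu_i,
        "variance": var_i,
        "correlation": (P * (i - mu_i) * (j - mu_j)).sum() / denom if denom > 0 else 0.0,
        "dissimilarity": (P * np.abs(i - j)).sum(),
        "contrast": (P * (i - j) ** 2).sum(),
        "entropy": -(P[nz] * np.log(P[nz])).sum(),
        "asm": (P ** 2).sum(),
        "homogeneity": (P / (1.0 + (i - j) ** 2)).sum(),
    }

def glcm_feature_maps(band, levels=16, win=7):
    """(8, H, W) maps for one band already quantized to integers in [0, levels)."""
    pad = win // 2
    padded = np.pad(band, pad, mode="reflect")
    maps = np.zeros((8,) + band.shape)
    for r in range(band.shape[0]):
        for c in range(band.shape[1]):
            feats = glcm_features(glcm(padded[r:r + win, c:c + win], levels))
            maps[:, r, c] = list(feats.values())
    return maps
```

Applying `glcm_feature_maps` to each of the 256 Xiongan bands gives 8 × 256 = 2048 maps per image patch. The per-pixel Python loop is written for readability and would be far too slow for full-size patches, where a vectorized implementation is preferable.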
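A per-pixel NumPy sketch of the hybrid loss in Eqs. (11)–(14) follows, averaged over a batch of N pixels. The defaults α = 0.25, μ = 2 and γ = 0.9 are the values reported in Sect. 3.3. Several details are assumptions:
- β is drawn independently for every pixel and class;
- a small `eps` guards the logarithms;
- the frequency threshold λ is passed in explicitly as `lam`, since the text does not report a separate value for it.

```python
import numpy as np

def hybrid_loss(logits, labels, freq, lam, alpha=0.25, mu=2.0, gamma=0.9, rng=None, eps=1e-12):
    """Mean over N pixels of L = L_SEQL + L_FOCAL (Eqs. 11-14).

    logits: (N, C) scores z; labels: (N,) integer class ids;
    freq: (C,) class frequencies f_k; lam: threshold lambda of T_lambda.
    """
    rng = np.random.default_rng() if rng is None else rng
    n, c = logits.shape
    y = np.eye(c)[labels]                                     # one-hot y_j
    ez = np.exp(logits - logits.max(axis=1, keepdims=True))   # shift for numerical stability
    p = ez / ez.sum(axis=1, keepdims=True)                    # Eq. (13): standard softmax
    beta = (rng.random((n, c)) < gamma).astype(float)         # beta ~ Bernoulli(gamma) (assumed per pixel and class)
    rare = (freq < lam).astype(float)                         # T_lambda(f_k)
    w = 1.0 - beta * rare[None, :] * (1.0 - y)                # Eq. (14)
    p_tilde = ez / (w * ez).sum(axis=1, keepdims=True)        # Eq. (12): equalized softmax
    seql = -(y * np.log(p_tilde + eps)).sum(axis=1)
    focal = -alpha * ((1.0 - p) ** mu * y * np.log(p + eps)).sum(axis=1)
    return float((seql + focal).mean())
```

Consider the case γ = 1 with every negative class below the threshold. Then $\tilde{p}_j$ for the true class becomes 1 and the SEQL term vanishes, leaving only the focal term, so the rare negative classes receive no discouraging gradient from that pixel. When no class is rare, $\tilde{p}_j = p_j$ and the SEQL term reduces to ordinary cross-entropy.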
3 Experiments and Analysis
3.1 Implementation Details and Metrics
Our model is optimized with the AdamW optimizer, using an initial learning rate of 0.004 and a weight decay of 0.001. The experiments were performed on Ubuntu 16.04 on a machine with two GTX 1080 Ti GPUs and 64 GB of RAM. The following metrics are used: intersection over union (IoU), mean intersection over union (mIoU), pixel accuracy (acc), and mean pixel accuracy (macc).

$$IoU = \frac{|A \cap B|}{|A \cup B|} = \frac{TP}{FP + FN + TP} \tag{15}$$
$$mIoU = \frac{1}{k}\sum_{i=1}^{k} \frac{P_{ii}}{\sum_{j=1}^{k} P_{ij} + \sum_{j=1}^{k} P_{ji} - P_{ii}} \tag{16}$$

where $P_{ij}$ is the number of pixels that belong to class $i$ but are predicted as class $j$, and $k$ is the number of classes.

$$acc = \frac{\sum_i n_{ii}}{\sum_i t_i} \tag{17}$$
$$macc = \frac{1}{n_{cls}}\sum_i \frac{n_{ii}}{t_i} \tag{18}$$

where $n_{ij}$ is the number of pixels that belong to class $i$ but are predicted as class $j$, and $n_{cls}$ is the total number of target classes (including the background). $t_i = \sum_j n_{ij}$ is the total number of pixels of class $i$ in the ground truth.

3.2 Comparison with Our Method
Six algorithms were chosen for comparison. Random forest [23] is a traditional algorithm commonly used for HSI classification. Three models widely used for semantic segmentation in computer vision are also included: Unet [3], Deeplab v3+ [4], and Swin-Unet [5]. Finally, two advanced models represent the current state of the art in HSI classification: FSSANet [6] and CTAFNet [7]. For details and parameter settings of each model, please refer to the corresponding paper.

Figure 4 shows the visual results on the Xiongan dataset, and Table 1 lists the results of the different methods on the Xiongan test set. In Fig. 4, the W-net classification maps show complete segmentation regions with smooth edges and little noise within regions, and they differ only slightly from the ground truth. As shown in Table 1, W-net achieves the best performance with 97.23% macc and 87.32% mIoU. It outperforms the second-ranked CTAFNet by 1.06% in macc and 0.68% in mIoU. Among all models, Random Forest performs worst, with a macc of 78.32% and an mIoU of 58.23%. Among the deep learning networks, FSSANet performs worst, with a macc of 78.65% and an mIoU of 66.10%. W-net achieves the best performance in 14 of the 19 categories.
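The metrics in Eqs. (15)–(18) can all be computed from a single confusion matrix, as sketched below. The sketch assumes that `n[i, j]` counts pixels of true class i predicted as class j, as in the definitions above. As an additional assumption not stated in the text, classes with no ground-truth pixels are skipped when averaging class accuracy. Classes absent from both the ground truth and the prediction are skipped when averaging IoU.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_cls):
    """n[i, j] = number of pixels of true class i predicted as class j."""
    n = np.zeros((n_cls, n_cls), dtype=np.int64)
    np.add.at(n, (np.ravel(y_true), np.ravel(y_pred)), 1)
    return n

def segmentation_metrics(n):
    """IoU, mIoU, acc and macc (Eqs. 15-18) from a confusion matrix."""
    n = n.astype(float)
    n_ii = np.diag(n)
    t = n.sum(axis=1)            # t_i: ground-truth pixels of class i
    predicted = n.sum(axis=0)    # pixels predicted as class i
    with np.errstate(divide="ignore", invalid="ignore"):
        iou = n_ii / (t + predicted - n_ii)
        class_acc = n_ii / t
    return {
        "IoU": iou,
        "mIoU": float(np.nanmean(iou)),
        "acc": float(n_ii.sum() / t.sum()),
        "macc": float(np.nanmean(class_acc)),
    }
```

For example, with y_true = [0, 0, 1, 1] and y_pred = [0, 1, 1, 1], the per-class IoUs are 1/2 and 2/3, so mIoU ≈ 0.583, while acc and macc are both 0.75.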
Fig. 4. Visual result of different methods on the Xiongan dataset.
Table 1. Classification results for the Xiongan test set (%).
Class       Random Forest [23]  Unet [3]  DeepLab v3+ [4]  Swin-Unet [5]  FSSANet [6]  CTAFNet [7]  W-net
Rice        91.23               94.88     97.71            93.80          95.76        99.20        99.30
Stubble     84.12               87.50     98.01            95.03          97.44        99.96        99.95
Water       93.54               96.85     94.98            88.42          91.56        95.11        95.06
Grassland   71.20               73.20     68.08            61.46          67.35        92.90        92.50
Willow      84.23               85.60     77.02            70.83          72.57        97.74        97.89
Elm         76.32               79.96     57.28            62.28          45.38        92.53        93.12
Maple       64.20               68.51     52.15            63.14          49.41        89.29        90.02
White wax   78.64               80.43     77.83            49.27          74.28        85.19        85.64
Luan        98.70               99.66     93.08            87.27          97.97        99.52        99.50
Sophora     72.01               73.36     78.71            69.95          52.22        91.31        91.45
Peach       74.23               78.21     78.29            51.76          49.66        95.99        95.23
Vegetable   20.45               31.06     32.70            14.94          25.98        53.72        60.45
Corn        31.21               57.92     56.88            34.62          61.70        80.21        84.32
Poplar      70.21               74.74     71.34            59.24          48.67        78.53        80.53
Pear        70.32               75.90     80.07            61.76          65.82        93.86        94.63
Soybean     1.20                8.50      23.28            13.10          9.51         34.36        40.26
Table 1. (continued)
Class       Random Forest [23]  Unet [3]  DeepLab v3+ [4]  Swin-Unet [5]  FSSANet [6]  CTAFNet [7]  W-net
Lotus       18.23               40.19     41.34            5.32           10.17        18.69        20.05
Robinia     34.75               57.23     64.12            13.47          11.25        30.12        48.45
House       78.65               83.63     83.53            60.76          64.64        89.38        90.12
macc        78.32               86.78     87.24            80.21          78.65        96.17        97.23
mIoU        58.23               73.64     72.78            63.62          66.10        86.64        87.32
Figure 5 shows a visual comparison of the different methods on the AeroRIT test set. Overall, the W-net classification map is highly consistent with the ground truth. At object edges, W-net produces smoother boundaries than CTAFNet. Compared with Unet, W-net segments shadows effectively and discriminates better in mixed-pixel regions. For the cars category, the W-net segmentation agrees more closely with the ground truth.
Fig. 5. Visual comparison of different methods on the AeroRIT test set.
Table 2 lists the results of the different methods on the AeroRIT test set. W-net achieves the best performance with 95.78% macc and 82.43% mIoU. It outperforms the second-ranked CTAFNet by 0.71% in macc and 1.02% in mIoU. Among all models, Random Forest performs worst, with a macc of 73.09% and an mIoU of 55.62%. Among the deep learning networks, DeepLab v3+ performs worst, with a macc of 92.85% and an mIoU of 70.72%. W-net achieves the best performance in four of the five categories. For cars, it improves accuracy by 8.79% over the runner-up CTAFNet.
Table 2. Classification results on the AeroRIT test set (%).
Class       Random Forest [23]  Unet [3]  DeepLab v3+ [4]  Swin-Unet [5]  FSSANet [6]  CTAFNet [7]  W-net
Buildings   62.12               81.45     80.40            80.95          82.12        86.63        88.55
Vegetation  89.56               92.84     91.35            92.16          92.87        95.72        95.62
Roads       61.45               81.32     82.12            81.56          83.61        84.64        85.41
Water       31.21               67.78     68.21            65.45          70.28        78.42        78.86
Cars        32.45               46.65     43.12            44.60          49.76        61.64        70.43
macc        73.09               93.08     92.85            93.01          93.35        95.07        95.78
mIoU        55.62               73.68     70.72            71.43          72.30        81.41        82.43
3.3 Ablation Study
GLCM Feature Map. To verify the contribution of the GLCM feature maps to model performance, we conduct ablation experiments with three models: U-Net, W-net, and W-net without the GLCM feature maps on the right side of the network. All three are trained with the hybrid loss (FL + SEQL). The results are shown in Table 3. Removing the right part of the network (the GLCM feature maps) lowers macc and mIoU by 4.11% and 6.78%, respectively, relative to the full W-net. The accuracy of specific categories is also lower than that of W-net. These results suggest that the GLCM feature maps improve the accuracy of the model.
Table 3. Ablation study of GLCM feature map on the Xiongan dataset (%).
Metric  Unet [3]  W-net without GLCM  W-net
macc    86.78     93.12               97.23
mIoU    73.64     80.54               87.32
Hybrid Loss. We also evaluate the impact of different loss functions on the training process and overall performance of the model. To separate the effect of the loss function from network design decisions, we use the W-net architecture throughout and train it with different loss functions. This way, the analysis concentrates exclusively on the influence of the loss function. We try four loss functions: cross-entropy loss, FL, SEQL, and the hybrid of FL and SEQL. The four parameters α, μ, β, and γ in Eq. 11 and Eq. 14 are set according to the best performance on the training set: α = 0.25, μ = 2, β = 5.0 × 10^−3, and γ = 0.9. The performance of the model trained with each loss function is shown in Table 4.

Several observations can be made from this experiment. First, the hybrid loss outperforms FL or SEQL alone, which indicates that combining them improves performance. Second, SEQL significantly improves recognition accuracy for rare classes. Used alone, it yields a gain of approximately 15% on Soybean relative to cross-entropy loss, and combining FL with SEQL raises the gain to approximately 20%. Based on these observations, the hybrid loss combining FL and SEQL is used to train W-net.
Table 4. Performance of the model trained with different loss functions on the Xiongan dataset (%).
Loss Function       Soybean  Robinia  macc   mIoU
cross entropy loss  20.33    24.32    94.32  82.4
FL                  24.21    25.32    96.12  85.7
SEQL                35.64    36.23    95.23  83.4
FL + SEQL           40.26    48.45    97.23  87.0
4 Conclusions
This paper proposes W-net, a novel network for HSI classification, together with a hybrid loss that addresses class imbalance. The experimental results show that W-net outperforms the traditional random forest method and deep learning models, including U-Net, Deeplab v3+, and CTAFNet, in classifying various land cover types. We attribute this performance to W-net's ability to exploit multispectral information beyond the RGB bands. By calculating the gray-level co-occurrence matrices of the other bands, W-net preserves texture information that is useful for classification but hidden in those bands, and uses it to train the deep learning classification network. In addition, the hybrid loss helps the model learn the features of rare categories, which improves their recognition accuracy and alleviates class imbalance. In summary, W-net achieves excellent classification performance and incorporates hyperspectral texture features into training through gray-level co-occurrence matrices. It also effectively addresses sample imbalance, making it a viable solution for practical applications.
Acknowledgement. This work was supported in part by Calibration and Validation of High-resolution Aerial System Applications, in part by the High-resolution Earth Observation System Project under grant No. 09-H30G02-9001-20/22, and in part by the National Key Research and Development Program of China (No. 2020YFD1100602).
References
1. Khan, M.J., Khan, H.S., Yousaf, A., Khurshid, K., Abbas, A.: Modern trends in hyperspectral image analysis: a review. IEEE Access 6, 14118–14129 (2018)
2. Liu, B., et al.: Underwater hyperspectral imaging technology and its applications for detecting and mapping the seafloor: a review. Sensors 20(17), 4962 (2020)
3. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
4. Chen, L.-C., Zhu, Y., Papandreou, G., Schroff, F., Adam, H.: Encoder-decoder with atrous separable convolution for semantic image segmentation. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11211, pp. 833–851. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01234-2_49
5. Cao, H., et al.: Swin-Unet: Unet-like pure transformer for medical image segmentation. In: Computer Vision – ECCV 2022 Workshops, Tel Aviv, Israel, October 23–27, 2022, Proceedings, Part III, pp. 205–218. Springer (2023)
6. Sun, J., et al.: Fusing spatial attention with spectral-channel attention mechanism for hyperspectral image classification via encoder–decoder networks. Remote Sens. 14(9), 1968 (2022)
7. Li, J., Xing, H., Ao, Z., Wang, H., Liu, W., Zhang, A.: Convolution-transformer adaptive fusion network for hyperspectral image classification. Appl. Sci. 13(1), 492 (2023)
8. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
9. Chen, L.C., Papandreou, G., Kokkinos, I., Murphy, K., Yuille, A.L.: Semantic image segmentation with deep convolutional nets and fully connected CRFs. arXiv preprint arXiv:1412.7062 (2014)
10. Chen, L.C., Papandreou, G., Kokkinos, I., Murphy, K., Yuille, A.L.: DeepLab: semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected CRFs. IEEE Trans. Pattern Anal. Mach. Intell. 40(4), 834–848 (2017)
11. Chen, L.C., Papandreou, G., Schroff, F., Adam, H.: Rethinking atrous convolution for semantic image segmentation. arXiv preprint arXiv:1706.05587 (2017)
12. Chen, Y., Jiang, H., Li, C., Jia, X., Ghamisi, P.: Deep feature extraction and classification of hyperspectral images based on convolutional neural networks. IEEE Trans. Geosci. Remote Sens. 54(10), 6232–6251 (2016)
13. Zhao, W., Du, S.: Spectral–spatial feature extraction for hyperspectral image classification: a dimension reduction and deep learning approach. IEEE Trans. Geosci. Remote Sens. 54(8), 4544–4554 (2016)
14. Hu, W., Huang, Y., Wei, L., Zhang, F., Li, H.: Deep convolutional neural networks for hyperspectral image classification. J. Sens. 2015, 1–12 (2015)
15. Zhang, H., Li, Y., Zhang, Y., Shen, Q.: Spectral-spatial classification of hyperspectral imagery using a dual-channel convolutional neural network. Remote Sens. Lett. 8(5), 438–447 (2017)
16. Andrian, R., Maharani, D., Muhammad, M.A., Junaidi, A., et al.: Butterfly identification using gray level co-occurrence matrix (GLCM) extraction feature and k-nearest neighbor (KNN) classification. Jurnal Kupu-Kupu Indonesia 6(1), 11–21 (2022)
17. Mohammadpour, P., Viegas, D.X., Viegas, C.: Vegetation mapping with random forest using Sentinel 2 and GLCM texture feature—a case study for Lousã region, Portugal. Remote Sens. 14(18), 4585 (2022)
18. Aggarwal, A.K.: Learning texture features from GLCM for classification of brain tumor MRI images using random forest classifier. Trans. Signal Process. 18, 60–63 (2022)
19. Lin, T.Y., Goyal, P., Girshick, R., He, K., Dollár, P.: Focal loss for dense object detection. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2980–2988 (2017)
20. Tan, J., Wang, C., Li, B., Li, Q., Ouyang, W., Yin, C., Yan, J.: Equalization loss for long-tailed object recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11662–11671 (2020)
21. Cen, Y., et al.: Aerial hyperspectral remote sensing classification dataset of Xiongan new area (Matiwan village). J. Remote Sens. 24(11), 1299–1306 (2020)
22. Rangnekar, A., Mokashi, N., Ientilucci, E.J., Kanan, C., Hoffman, M.J.: AeroRIT: a new scene for hyperspectral image analysis. IEEE Trans. Geosci. Remote Sens. 58(11), 8116–8124 (2020)
23. Breiman, L.: Random forests. Mach. Learn. 45, 5–32 (2001)
Brain Tumor Image Segmentation Network Based on Dual Attention Mechanism
Fuyun He1,2 (corresponding author), Yao Zhang1, Yan Wei1, Youwei Qian1, Cong Hu3,4, and Xiaohu Tang1,2
1 School of Electronic and Information Engineering, Guangxi Normal University, Guilin 541004, Guangxi, China
2 Guangxi Key Laboratory of Brain-Inspired Computing and Intelligent Chips, School of Electronic and Information Engineering, Guangxi Normal University, Guilin 541004, Guangxi, China
3 Guangxi Key Laboratory of Automatic Detecting Technology and Instruments, Guilin University of Electronic Technology, Guilin 541004, Guangxi, China
4 Guangxi Key Laboratory of Wireless Wideband Communication and Signal Processing, Guilin University of Electronic Technology, Guilin 541004, Guangxi, China
Abstract. Brain tumors differ in shape, location and size, and their appearance is highly heterogeneous, which makes it difficult to establish effective rules for segmenting brain tumor images. Current medical image segmentation networks are excessively deep and lack connections between global and local feature information, which reduces segmentation accuracy. To address this, we propose an improved brain tumor image segmentation network based on deep residuals and a dual attention mechanism. First, inspired by residual networks, we replace the traditional convolutional block with a deep residual block. This block extracts more feature information while suppressing network degradation and accelerating convergence. Second, a dual attention mechanism is introduced in each skip connection to fuse richer context information. It makes the model focus on the features to be segmented in the tumor region while suppressing irrelevant regions. Third, the loss function of our model combines the cross-entropy function and the Dice similarity coefficient function. Finally, the model is trained and tested on MRI images from the BraTS2019 dataset, and segmentation accuracy is evaluated with the Dice coefficient. The average Dice coefficients for the whole tumor (WT), tumor core (TC) and enhancing tumor (ET) regions are 0.8214, 0.8408 and 0.7448, respectively. The experimental results show that our model improves the segmentation accuracy of brain tumor images compared with other deep segmentation models.
Keywords: Brain tumor segmentation · Dual attention mechanism · Depth residual · MRI images
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 125–136, 2023. https://doi.org/10.1007/978-981-99-4761-4_11
F. He et al.
1 Introduction
A brain tumor is a disease in which brain tissue becomes cancerous. Its incidence is relatively high and its mortality rate exceeds 3%. It can arise in any part of the brain, such as the cerebellum, brain stem, cerebral hemispheres, or skull base [1]. Because of the particular location of these lesions, once a tumor becomes malignant its cells can metastasize and become life-threatening [2]. Brain tumors are classified into four grades according to the degree of malignancy and the corresponding histopathological features [3]. Grade I and II tumors usually occur in children and adolescents, whereas grade III and IV tumors occur mostly in older adults. Timely early diagnosis is therefore very important, so that reasonable and effective treatment can be planned before the tumor deteriorates [4]. At present, brain tumor images are mainly segmented manually by experts. This is time-consuming, the results of different experts differ considerably, and the misjudgment rate is high, all of which seriously affect the follow-up treatment of patients.

In recent years, many excellent deep learning-based brain tumor image segmentation methods have been proposed [5]. Among them, U-Net is one of the most classical semantic segmentation networks. It was originally used for cell wall segmentation [6] and has greatly advanced medical image segmentation [7]. McKinley et al. applied dense connections to U-Net, replacing its simple convolution blocks with dense blocks of dilated convolutions [8]. Ref. [9] added a variational autoencoder (VAE) branch at the end of the encoder, providing extra guidance and regularization for the encoder branch, and won first place in the BraTS2018 challenge. Isensee et al. focused on the training process and argued that a well-trained U-Net can also achieve excellent segmentation performance [10], which opened a new direction for image segmentation.

In terms of overall structure, the U-Net encoder-decoder is a single network. As the network deepens, its structure becomes more complex and it learns more information, which improves segmentation performance. Many researchers have therefore studied cascades of multiple deep neural networks. Kamnitsas et al. argued that a single network easily overfits specific datasets, so they ensembled multiple deep neural networks with different parameter configurations [11]. Iqbal et al. adopted a similar approach and produced the final result through a voting strategy [12]. Wu et al. proposed a lightweight multi-scale feature learning block for deep supervision. It is used only during training and discarded at test time, so it adds no inference time [13].

These deep learning-based segmentation networks are limited by the variability in the shape and texture of brain tumors. This causes loss of local information, so their segmentation results are still not fine enough. In some networks, simple concatenation of feature maps introduces redundant information, and every region of the input image receives the same attention, which degrades segmentation. To solve these problems, we propose a brain tumor image segmentation network incorporating a dual attention mechanism (DAM-Net). It combines high- and low-resolution feature information and integrates deep residual and attention mechanisms. This enhances the segmentation features of brain tumor images and suppresses irrelevant regions, thereby improving segmentation accuracy.
Brain Tumor Image Segmentation Network
The rest of the paper is organized as follows. Section 2 describes the structure of the segmentation network, its training and testing, and the evaluation metrics. Section 3 describes the experimental dataset and preprocessing. Section 4 presents the experimental results and analysis, and Sect. 5 summarizes the research and outlines future work.
2 Methodology
The U-Net architecture makes good use of both shallow and high-level semantic information [14]. Its skip connections integrate shallow information into high-level feature maps and compensate for low-level details. In designing our architecture, we aim to propagate shallow information through the backbone network as well. We therefore introduce a residual module in each convolution block that maps the input directly to the next layer, making better use of shallow information. Encoder-decoder designs based on the U-shaped structure perform well in a growing number of fields, and fully convolutional networks are becoming the mainstream approach to medical image segmentation [15].
2.1 Deep Residual Module
In the ILSVRC 2015 image recognition competition, the 152-layer residual network proposed by He et al. won first place with a top-5 error of only 3.6% [16]. Residual networks have since been widely used in many fields with very good results. The Highway network [17] uses complex transform and carry gates. In contrast, the shortcuts in a residual network are identity mappings and therefore introduce no additional parameters. The structure of the deep residual module is shown in Fig. 1. The input x is passed to the output H(x) through the shortcut connection, which speeds up the flow of information. The module therefore no longer learns the complete output H(x), but only the difference between the output and the input, namely H(x) − x.
Fig. 1. Structure of the deep residual module [17]
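The benefit of learning the residual can be made concrete with a short derivation. Writing the residual branch as F(x) (our notation, not the paper's) and treating gradients as row vectors, the shortcut gives H(x) = F(x) + x, and the chain rule shows that the gradient reaching the block input always contains an identity term:

```latex
H(x) = F(x) + x, \qquad
\frac{\partial \mathcal{L}}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial H}\,\frac{\partial H}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial H}\left(\frac{\partial F}{\partial x} + I\right)
```

Even when the Jacobian of F is small, the identity term passes the upstream gradient back unchanged, which is one standard explanation of why residual shortcuts ease the training of deep networks; and if the best mapping is close to the identity, the block only has to drive F(x) toward zero.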
As shown in Fig. 2, Fig. 2(a) is a traditional U-Net convolution block and Fig. 2(b) is a deep residual module. In the proposed model, the convolution blocks of U-Net are replaced by deep residual modules [18]. The deep residual module applies batch normalization and ReLU pre-activation before each 3 × 3 convolution, and the input is passed directly to the later layer through an identity mapping.
F. He et al.
Fig. 2. Structure comparison between the original convolution module and the deep residual module. (a) The original convolution module; (b) the deep residual module.
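A minimal sketch of the pre-activation residual module in Fig. 2(b), written in plain Python for readability. It assumes single-channel 1D signals instead of 2D multi-channel feature maps, 3-tap kernels standing in for the 3 × 3 convolutions, and batch normalization in inference mode with fixed statistics (mean 0, variance 1, no learned scale or shift). All function names are hypothetical, not from the paper.

```python
import math
from typing import List

Vec = List[float]


def batch_norm_inference(x: Vec, mean: float = 0.0, var: float = 1.0,
                         gamma: float = 1.0, beta: float = 0.0,
                         eps: float = 1e-5) -> Vec:
    """Batch normalization at inference time, using stored statistics (assumed fixed)."""
    scale = gamma / math.sqrt(var + eps)
    return [(v - mean) * scale + beta for v in x]


def relu(x: Vec) -> Vec:
    return [max(0.0, v) for v in x]


def conv3(x: Vec, kernel: Vec) -> Vec:
    """1D cross-correlation with a 3-tap kernel and zero 'same' padding."""
    padded = [0.0] + list(x) + [0.0]
    return [sum(kernel[k] * padded[i + k] for k in range(3)) for i in range(len(x))]


def preact_residual_block(x: Vec, k1: Vec, k2: Vec) -> Vec:
    """(BN -> ReLU -> conv) twice, then add the identity shortcut: H(x) = F(x) + x."""
    f = conv3(relu(batch_norm_inference(x)), k1)
    f = conv3(relu(batch_norm_inference(f)), k2)
    return [fi + xi for fi, xi in zip(f, x)]
```

With all-zero kernels the residual branch F(x) vanishes and the block returns its input unchanged, which is exactly the identity-mapping behaviour the shortcut is meant to make easy to learn.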
2.2 Dual Attention Mechanism Simple direct concatenation introduces a large amount of redundant information, which not only wastes computing resources but also blurs the segmented regions and fails to highlight details. Our model therefore adds a dual attention mechanism to the skip connections to increase the attention paid to tumor regions [13]. Building on existing attention mechanisms, the dual attention module is designed to model the interdependence between the channels of each pixel effectively. Its structure is shown in Fig. 3, where Fg denotes the high-resolution feature vector after up-sampling and Fs denotes the low-level feature vector from the encoder unit of the skip connection. Fs and Fg are each transformed by a 1 × 1 convolution to the same number of channels, Cint. The 1 × 1 convolutions increase the number of nonlinear mappings, reduce the number of kernel parameters, and enable cross-channel information interaction. A weight matrix is then obtained through ReLU activation, a 1 × 1 convolution, and sigmoid activation, and it is multiplied point by point with Fs to obtain the output vector. This output feature vector is concatenated with Fg as the input of the next layer. 2.3 Network Architecture The proposed brain tumor image segmentation network based on the dual attention mechanism uses deep residual modules to replace the encoding and decoding layers of the original segmentation network, which allows a deeper network to be trained and alleviates the vanishing-gradient problem. By introducing dual attention modules, the network can better learn important feature information and suppress redundant features. The proposed DAM-Net structure is shown in Fig. 4. The network input is a four-channel image composed of slices of the four different modalities. In the deep residual module of the network backbone, batch
Brain Tumor Image Segmentation Network
Fig. 3. Dual attention module structure diagram
Fig. 4. DAM-Net network structure diagram
normalization and ReLU pre-activation are performed before each 3 × 3 convolution. After two BN, ReLU and 3 × 3 convolution operations, the resulting feature map is fused with the shallow feature information through the identity mapping. Down-sampling uses max pooling. In each skip connection, the feature map output by the attention mechanism is concatenated with the feature map obtained by up-sampling. The final feature map has the same size as the input image. At the end of the network, a 1 × 1 convolution followed by a sigmoid produces the per-pixel classification of the feature map. 2.4 Training and Testing The experimental hardware platform is configured with an Intel i7-13700F CPU (2.10 GHz), 256 GB of memory, and the Windows Server 2012 R2 operating system. All experiments were performed in PyTorch on an NVIDIA GeForce GTX 1050 Ti GPU. The Adam optimizer was used for training, with momentum (β1) set to 0.9, a batch size of 8, a learning rate of 0.0003, and a decay coefficient of 0.0001. An early stopping strategy was adopted to prevent the
model from over-fitting. In the training stage, the inputs are the preprocessed brain tumor images and the ground-truth segmentation labels. 80% of the BraTS2019 dataset was used as the training set and the remaining 20% as the validation set. The cross-entropy function measures the difference between the true and the predicted probability distributions. Because it is based on competition between classes, it learns inter-class information well. However, it only considers the predicted probability of the correct label and ignores the differences among the incorrect labels, and it is very sensitive to class imbalance, which can disperse the learned features. The cross-entropy loss is defined in formula (1):
$$\mathrm{CEL} = -\frac{1}{m}\sum_{i=1}^{m}\sum_{j=1}^{n} p(x_{ij})\,\log \hat{p}(x_{ij}) \qquad (1)$$
where p(x_ij) is the label value of the true class of pixel x_ij, p̂(x_ij) is the predicted value for pixel x_ij, m is the number of pixels in the output image, and n is the number of pixels in the label image. In brain tumor image segmentation, the most important evaluation index is the Dice similarity coefficient, which measures the degree of overlap between the segmentation result and the ground-truth label. All pixels of the same category are treated as a whole when computing the loss value. Because background pixels need not be considered when computing the overlap ratio, the amount of computation is greatly reduced and the imbalance between positive and negative samples is alleviated. The Dice similarity coefficient loss is computed as in formula (2):
$$\mathrm{Dice} = 1 - \frac{2\sum_{i=1}^{N} p_i\,\hat{p}_i}{\sum_{i=1}^{N} p_i + \sum_{i=1}^{N} \hat{p}_i + \varepsilon} \qquad (2)$$
where N is the number of pixels in the image, p_i is the label value of the i-th pixel, p̂_i is the predicted value of the i-th pixel, and ε is a small constant that prevents the denominator from being zero. Weighing the advantages and disadvantages of the two loss functions, the loss function of our model is their weighted sum:
$$L = \alpha \times \mathrm{CEL} + \beta \times \mathrm{Dice} \qquad (3)$$
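A minimal sketch of formulas (1) to (3) for a single foreground class, assuming flattened per-pixel labels p and predicted probabilities p̂ (for example, sigmoid outputs). With two classes, foreground (p) and background (1 − p), the class sum in formula (1) reduces to the familiar binary cross-entropy. The paper does not report α and β, so the defaults of 1.0 below are placeholders, not the authors' settings; the function names are hypothetical.

```python
import math
from typing import Sequence


def cross_entropy(p: Sequence[float], p_hat: Sequence[float], eps: float = 1e-7) -> float:
    """Formula (1) for two classes (foreground p, background 1 - p), averaged over pixels."""
    total = 0.0
    for t, q in zip(p, p_hat):
        q = min(max(q, eps), 1.0 - eps)  # clip so log() never sees 0
        total += t * math.log(q) + (1.0 - t) * math.log(1.0 - q)
    return -total / len(p)


def dice_loss(p: Sequence[float], p_hat: Sequence[float], eps: float = 1e-6) -> float:
    """Formula (2): 1 - 2*sum(p * p_hat) / (sum(p) + sum(p_hat) + eps)."""
    inter = sum(t * q for t, q in zip(p, p_hat))
    return 1.0 - 2.0 * inter / (sum(p) + sum(p_hat) + eps)


def combined_loss(p: Sequence[float], p_hat: Sequence[float],
                  alpha: float = 1.0, beta: float = 1.0) -> float:
    """Formula (3); alpha and beta are placeholder weights (not reported in the paper)."""
    return alpha * cross_entropy(p, p_hat) + beta * dice_loss(p, p_hat)
```

A perfect prediction drives both terms to (almost) zero, while a maximally uncertain prediction of 0.5 everywhere gives a cross-entropy of ln 2 and a Dice loss of about 0.5.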
2.5 Performance Evaluation Metrics The Dice coefficient and the Hausdorff distance were used to evaluate segmentation accuracy on the BraTS2019 dataset. The Dice coefficient is given by formula (4):
$$\mathrm{Dice} = \frac{2\,|P_1 \wedge T_1|}{|P_1| + |T_1|} \qquad (4)$$
The proportion of overlap between the ground truth and the predicted region gives a good measure of the similarity between predictions and labels. We divide the image into four regions, as shown in Fig. 5: T1 (blue) is the true tumor area; T0, outside the blue area, is the normal area; P1 (red) is the predicted tumor area; and P0, outside the red area, is the predicted normal area. P1 ∧ T1 denotes the overlap between the predicted region and the ground truth, and |·| denotes the number of pixels in a region.
Fig. 5. Schematic diagram of tumor region division
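A minimal sketch of formula (4) on flattened binary masks, where nonzero entries mark tumor pixels (P1 in the prediction, T1 in the label). Returning 1.0 when both masks are empty is an assumed convention, not something the paper specifies.

```python
from typing import Sequence


def dice_coefficient(pred: Sequence[int], truth: Sequence[int]) -> float:
    """Formula (4): 2|P1 ∧ T1| / (|P1| + |T1|) for flattened binary masks."""
    overlap = sum(1 for p, t in zip(pred, truth) if p and t)          # |P1 ∧ T1|
    total = sum(1 for p in pred if p) + sum(1 for t in truth if t)    # |P1| + |T1|
    if total == 0:
        return 1.0  # assumed convention: two empty masks agree perfectly
    return 2.0 * overlap / total
```

For example, a prediction [1, 1, 0, 0] against a label [1, 0, 1, 0] shares one tumor pixel out of two predicted and two true ones, giving 2 × 1 / (2 + 2) = 0.5.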
The Hausdorff distance measures the similarity between two point sets and represents the maximum degree of mismatch between the predicted result and the actual label. It is defined in formula (5):
$$\mathrm{Hausdorff} = \max\left\{\max_{a \in A}\min_{b \in B}\lVert a - b\rVert,\; \max_{b \in B}\min_{a \in A}\lVert b - a\rVert\right\} \qquad (5)$$
where a is a point in set A, b is a point in set B, and ‖a − b‖ is the Euclidean distance between a and b. The term max_{a∈A} min_{b∈B} ‖a − b‖ is the forward Hausdorff distance from A to B, and max_{b∈B} min_{a∈A} ‖b − a‖ is the backward Hausdorff distance from B to A. The smaller the value of formula (5), the better the two point sets match.
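A brute-force sketch of formula (5), assuming the two point sets are given as coordinate tuples (for instance, the pixel coordinates of a predicted and a true tumor mask). It costs O(|A|·|B|) distance evaluations, so it only suits small point sets, and both sets must be non-empty. The helper names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, ...]


def directed_hausdorff(a_set: Sequence[Point], b_set: Sequence[Point]) -> float:
    """max over a in A of the Euclidean distance from a to its nearest point in B."""
    return max(min(math.dist(a, b) for b in b_set) for a in a_set)


def hausdorff(a_set: Sequence[Point], b_set: Sequence[Point]) -> float:
    """Formula (5): the larger of the forward (A->B) and backward (B->A) distances."""
    return max(directed_hausdorff(a_set, b_set), directed_hausdorff(b_set, a_set))


def mask_to_points(mask: Sequence[Sequence[int]]) -> List[Tuple[int, int]]:
    """Coordinates (row, column) of all nonzero pixels in a 2D binary mask."""
    return [(r, c) for r, row in enumerate(mask) for c, v in enumerate(row) if v]
```

The two directed terms generally differ: for A = {(0, 0), (1, 0)} and B = {(0, 0), (3, 0)}, the forward distance is 1 but the backward distance is 2, so the Hausdorff distance is 2.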
3 Experimental Dataset and Preprocessing 3.1 Experimental Dataset The BraTS2019 dataset was collected from 19 institutions, each using different magnetic resonance imaging equipment [19]. The dataset includes 259 HGG (high-grade glioma) and 76 LGG (low-grade glioma) cases. Each case includes four modalities, FLAIR, T1, T2 and T1ce, and each modality volume has a size of 240 × 240 × 155. In this dataset, the red region is the necrotic and non-enhancing tumor core (NET, label value 1), the green region is the peritumoral edema (ED, label value 2), the yellow region is the enhancing tumor (ET, label value 4), and the background has label value 0. Three regions are evaluated: Whole Tumor (WT = ED + ET + NET), Tumor Core (TC = ET + NET) and Enhancing Tumor (ET). 3.2 Data Preprocessing Many characteristics of the original data strongly affect model performance, so data preprocessing has a large impact on segmentation accuracy. The
BraTS2019 dataset was officially preprocessed, including registration, resampling to 1 × 1 × 1 mm resolution, and skull stripping. However, owing to the acquisition process and the nature of MR imaging, a harmful artifact remains in the final images: a smooth intensity inhomogeneity known as the bias field. Although this artifact is hard to see with the naked eye, it can significantly affect segmentation performance. N4 bias field correction is a common method that substantially reduces the intensity inhomogeneity of brain tumor images. Figure 6 compares an image before and after N4 bias field correction.
Fig. 6. Comparison of an image before and after N4 bias field correction (left: before correction; right: after correction).
In addition, the gray-value range varies widely across the dataset, which increases computation and slows convergence during training. The data are therefore standardized after bias field correction. In the experiment, z-score standardization was applied so that each of the four modalities has a mean of 0 and a variance of 1. Figures 7 and 8 show the data before and after standardization.
Fig. 7. Before standardization
Fig. 8. After standardization
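A minimal sketch of the z-score step, applied separately to the (flattened) intensities of each modality. The paper does not say whether background voxels are included in the statistics; this sketch simply uses every value it is given.

```python
import statistics
from typing import List, Sequence


def zscore(values: Sequence[float]) -> List[float]:
    """Rescale one modality to mean 0 and (population) variance 1."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    if std == 0:
        return [0.0 for _ in values]  # constant input: nothing to rescale
    return [(v - mean) / std for v in values]
```

Note that z-scoring only shifts and rescales the intensities; it does not change the shape of their distribution.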
Models trained on unbalanced samples generally generalize poorly and are prone to over-fitting. Slices without lesions were therefore removed from the training set, to avoid the severe class imbalance caused by slices with few or no lesion pixels. Finally, the slices are cropped to a common size of 160 × 160 and augmented by flipping, rotation, random cropping and similar operations. The data augmentation results are shown in Fig. 9.
Fig. 9. Data enhancement effect
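The slice filtering and cropping steps above can be sketched as follows. The paper does not state where the 160 × 160 window is taken or which augmentation parameters are used, so the center crop and the horizontal flip below are illustrative assumptions, and the helper names are hypothetical.

```python
from typing import List, Sequence, Tuple

Grid = List[List[int]]


def has_lesion(label: Grid) -> bool:
    """True if any pixel of the label slice is non-background."""
    return any(v != 0 for row in label for v in row)


def keep_lesion_slices(images: Sequence, labels: Sequence[Grid]) -> List[Tuple]:
    """Drop (image, label) slice pairs whose label contains no lesion pixels."""
    return [(img, lab) for img, lab in zip(images, labels) if has_lesion(lab)]


def center_crop(grid: Grid, size: int) -> Grid:
    """Assumed crop position: a size x size window centred in the slice."""
    h, w = len(grid), len(grid[0])
    top, left = (h - size) // 2, (w - size) // 2
    return [row[left:left + size] for row in grid[top:top + size]]


def hflip(grid: Grid) -> Grid:
    """Horizontal flip, one of the augmentations mentioned in the text."""
    return [list(reversed(row)) for row in grid]
```

For a 240 × 240 BraTS slice, the centred 160 × 160 window starts 40 pixels in from each border.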
4 Experimental Results and Analysis Table 1 shows the Dice coefficients of three comparison networks and the proposed segmentation network (DAM-Net) for brain tumor image segmentation. DAM-Net reaches Dice values of 0.8214, 0.8408 and 0.7448 for WT, TC and ET, respectively, outperforming the other three networks in all three regions.
Table 1. Brain tumor image segmentation results of the four networks in terms of the Dice coefficient
Method                  | Dice WT | Dice TC | Dice ET
U-Net [6]               | 0.8047  | 0.8384  | 0.7351
U-Net + Attention [13]  | 0.8144  | 0.8283  | 0.7390
Res + U-Net [17]        | 0.8162  | 0.8387  | 0.7415
DAM-Net                 | 0.8214  | 0.8408  | 0.7448
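The WT, TC and ET scores are computed on nested binary regions built from the label values listed in Sect. 3.1 (1 = NET, 2 = ED, 4 = ET, 0 = background). A minimal sketch of that mapping on a flattened label map follows; each resulting mask can then be scored with formula (4). The names are hypothetical.

```python
from typing import Dict, List, Set

# Label values from Sect. 3.1: 1 = NET, 2 = ED, 4 = ET, 0 = background.
REGIONS: Dict[str, Set[int]] = {
    "WT": {1, 2, 4},  # whole tumor = ED + ET + NET
    "TC": {1, 4},     # tumor core = ET + NET
    "ET": {4},        # enhancing tumor
}


def region_masks(labels: List[int]) -> Dict[str, List[int]]:
    """Turn a flattened BraTS label map into one binary mask per evaluated region."""
    return {name: [1 if v in values else 0 for v in labels]
            for name, values in REGIONS.items()}
```

Because the regions are nested (ET ⊆ TC ⊆ WT), a single prediction error on an enhancing-tumor pixel can affect all three scores.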
Table 2 shows the Hausdorff distances of the three comparison networks and the proposed DAM-Net. DAM-Net reaches Hausdorff distances of 2.6880, 1.6600 and 2.8979 for WT, TC and ET, respectively, the lowest (best) values among the four networks in all three regions. Taking the Dice coefficient and the Hausdorff distance together, the brain tumor image segmentation network based on the dual attention mechanism performs better overall than the other three networks. The main advantage of our segmentation network is that the proposed dual attention module can transmit contextual semantic feature information
Table 2. Brain tumor image segmentation results of the four networks in terms of the Hausdorff distance
Method                  | Hausdorff WT | Hausdorff TC | Hausdorff ET
U-Net [6]               | 2.7347       | 1.7199       | 2.9186
U-Net + Attention [13]  | 2.7320       | 1.6852       | 2.8996
Res + U-Net [17]        | 2.7053       | 1.6783       | 2.9015
DAM-Net                 | 2.6880       | 1.6600       | 2.8979
more closely, greatly reduce the loss of feature information, and pay more attention to the details of brain tumor images during decoding. Compared with the traditional U-Net, DAM-Net not only combines the local and global feature information of images, but also draws on the U-Net + Attention and Res + U-Net structures. Attention-based structures insert channel or spatial attention modules into the decoding process, and these modules better capture the important feature information of the image and improve segmentation accuracy.
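The gating step of the dual attention module described in Sect. 2.2 can be sketched in plain Python. Feature maps are stored as [pixels][channels] with the spatial layout flattened, so a 1 × 1 convolution is the same matrix applied at every pixel, and Fs and Fg are assumed to share the same spatial size (Fg after up-sampling). The paper lists the operations (1 × 1 convolutions, ReLU, 1 × 1 convolution, sigmoid, point-wise product, concatenation) but not how the two projections are merged; summing them, as in the common additive attention gate, is an assumption here. All names are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]
FeatureMap = List[List[float]]  # [pixels][channels]; spatial layout flattened


def conv1x1(fmap: FeatureMap, weight: Matrix) -> FeatureMap:
    """1x1 convolution: the same (out x in) weight matrix applied at every pixel."""
    return [[sum(w * v for w, v in zip(row, pixel)) for row in weight] for pixel in fmap]


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def attention_gate(fs: FeatureMap, fg: FeatureMap, w_s: Matrix, w_g: Matrix,
                   w_psi: Matrix) -> Tuple[FeatureMap, List[float]]:
    """Weight the skip features Fs by a per-pixel attention value in (0, 1)."""
    proj_s = conv1x1(fs, w_s)  # Fs -> C_int channels
    proj_g = conv1x1(fg, w_g)  # Fg -> C_int channels
    # Assumption: the projections are summed before the ReLU (additive gate).
    hidden = [[max(0.0, a + b) for a, b in zip(ps, pg)] for ps, pg in zip(proj_s, proj_g)]
    alpha = [sigmoid(out[0]) for out in conv1x1(hidden, w_psi)]  # w_psi: C_int -> 1
    gated = [[a * v for v in pixel] for a, pixel in zip(alpha, fs)]
    return gated, alpha


def skip_output(fs: FeatureMap, fg: FeatureMap, w_s: Matrix, w_g: Matrix,
                w_psi: Matrix) -> FeatureMap:
    """Concatenate the gated skip features with Fg along the channel axis."""
    gated, _ = attention_gate(fs, fg, w_s, w_g, w_psi)
    return [g + d for g, d in zip(gated, fg)]
```

With all-zero weights every attention value is sigmoid(0) = 0.5, so the skip features are simply halved; learned weights let pixels with stronger tumor evidence receive values closer to 1.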
Fig. 10. Comparison of image segmentation results of four networks. (a) U-Net; (b) U-Net + Attention; (c) Res + U-Net; (d) DAM-Net; (e) Ground Truth
Figure 10 shows partial segmentation results of the four networks on BraTS2019. Columns (a) to (e) are, in turn, U-Net, U-Net + Attention, Res + U-Net, DAM-Net and the ground-truth label. U-Net can only roughly locate the brain tumor, while U-Net-based networks such as Res + U-Net segment the outline and shape of the tumor more clearly. This is because the residual structure added to the network connects the semantic information of the image more effectively.
Thus, the important feature information of the image can be highlighted, which helps feature extraction during encoding. DAM-Net adds the deep residual module on this basis, which better connects the local and global feature information of the image. However, Fig. 10 shows some redundant scattered points in the Res + U-Net segmentation results. Similarly, although the overall segmentation produced by U-Net + Attention is quite clear, the shallow feature information from the encoding layers is noisy, the image edges are jagged, and much redundant information remains, all of which can reduce segmentation accuracy. By contrast, the contour edges extracted by DAM-Net are more prominent and the details are better recovered, giving results closer to the ground-truth label.
5 Conclusion To address the low segmentation accuracy caused by the lack of local and global feature information in current tumor image segmentation networks, we propose a brain tumor image segmentation network based on a dual attention mechanism. Improvements to the segmentation network are considered from three aspects: the data input, the network structure and the loss function. N4 bias field correction of the input data eliminates the influence of harmful artifacts. The attention mechanism and the deep residual modules not only accelerate the convergence and information transmission of the network, but also increase the weight of tumor regions, highlighting them and suppressing useless feature information. The loss function combines the cross-entropy loss and the Dice similarity coefficient loss, which effectively alleviates class imbalance. The experimental comparison shows that the proposed network outperforms the other comparison networks. However, the proposed network still has limitations. For example, as a 2D network it segments slices of the 3D dataset and does not fully exploit the spatial information of the data. Extending the 2D model to a 3D model is therefore the focus of our future research. Acknowledgment. This work is supported by the National Natural Science Foundation of China (62062014), and Key Scientific Research Project of Guangxi Normal University (2018ZD007).
References
1. Hugues, D.: New philosophy, clinical pearls, and methods for intraoperative cognition mapping and monitoring “à la carte” in brain tumor patients. Neurosurgery 88(5), 919–930 (2021). https://doi.org/10.1093/neuros/nyaa363
2. Cinar, N., Ozcan, A., Kaya, M.: A hybrid DenseNet121-UNet model for brain tumor segmentation from MR Images. Biomed. Signal Process. Control 76, 103647 (2022). https://doi.org/10.1016/j.bspc.2022.103647
3. Wang, S., et al.: Label-free detection of the architectural feature of blood vessels in glioblastoma based on multiphoton microscopy. IEEE J. Sel. Topics Quantum Electron. 27(4), 1–7 (2021). https://doi.org/10.1109/JSTQE.2021.3058175
4. Akter, F., et al.: Pre-clinical tumor models of primary brain tumors: challenges and opportunities. Biochimica et Biophysica Acta (BBA) – Rev. Cancer 1875(1), 188458 (2021). https://doi.org/10.1016/j.bbcan.2020.188458
5. Zhao, X., et al.: A deep learning model integrating FCNNs and CRFs for brain tumor segmentation. Med. Image Anal. 43, 98–111 (2017). https://doi.org/10.1016/j.media.2017.10.002
6. Li, N., Ren, K.: Double attention U-Net for brain tumor MR image segmentation. Int. J. Intell. Comput. Cybern. 14(3), 467–479 (2021). https://doi.org/10.1108/IJICC-01-2021-0018
7. Tunga, P.P., et al.: U-net model based classification and description of brain tumor in MRI images. Int. J. Image Graph. 21, 2140005 (2020). https://doi.org/10.1142/S0219467821400052
8. McKinley, R., Meier, R., Wiest, R.: Ensembles of densely-connected CNNs with label-uncertainty for brain tumor segmentation. In: Crimi, A., Bakas, S., Kuijf, H., Keyvan, F., Reyes, M., van Walsum, T. (eds.) BrainLes 2018. LNCS, vol. 11384, pp. 456–465. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-11726-9_40
9. Nayak, U.A., et al.: Validation of segmented brain tumor from MRI images using 3D printing. Asian Pac. J. Cancer Prev. 22(2), 523–530 (2021). https://doi.org/10.31557/APJCP.2021.22.2.523
10. Isensee, F., Kickingereder, P., Wick, W., Bendszus, M., Maier-Hein, K.H.: No new-net. In: Crimi, A., Bakas, S., Kuijf, H., Keyvan, F., Reyes, M., van Walsum, T. (eds.) BrainLes 2018. LNCS, vol. 11384, pp. 234–244. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-11726-9_21
11. Kamnitsas, K., et al.: Ensembles of multiple models and architectures for robust brain tumour segmentation. In: Crimi, A., Bakas, S., Kuijf, H., Menze, B., Reyes, M. (eds.) BrainLes 2017. LNCS, vol. 10670, pp. 450–462. Springer, Cham (2018). https://doi.org/10.1007/978-3-319-75238-9_38
12. Iqbal, S., et al.: Deep learning model integrating features and novel classifiers fusion for brain tumor segmentation. Microsc. Res. Tech. 82(8), 1302–1315 (2019). https://doi.org/10.1002/jemt.23281
13. Di, W., Chao, W., Yong, W., et al.: Attention deep model with multi-scale deep supervision for person re-identification. IEEE Trans. Emerg. Topics Comput. Intell. 5(1), 70–78 (2021). https://doi.org/10.1109/TETCI.2020.3034606
14. Zunair, H., Hamza, A.B.: Sharp U-Net: depthwise convolutional network for biomedical image segmentation. Comput. Biol. Med. (2021). https://doi.org/10.48550/arXiv.2107.12461
15. Shelhamer, E., Long, J., Darrell, T.: Fully convolutional networks for semantic segmentation. IEEE Trans. Pattern Anal. Mach. Intell. 39(4), 640–651 (2017). https://doi.org/10.1109/TPAMI.2016.2572683
16. He, K., et al.: Deep residual learning for image recognition. In: IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778 (2016). https://doi.org/10.1109/CVPR.2016.90
17. Srivastava, R.K., Greff, K., Schmidhuber, J.: Highway Networks. Comput. Sci. (2015). https://doi.org/10.48550/arXiv.1505.00387
18. Li, F., et al.: Latent traits of lung tissue patterns in former smokers derived by dual channel deep learning in computed tomography images. Sci. Rep. 11(1), 4916 (2021). https://doi.org/10.1038/s41598-021-84547-5
19. Hu, K., et al.: Brain tumor segmentation using multi-cascaded convolutional neural networks and conditional random field. IEEE Access 7, 92615–92629 (2019). https://doi.org/10.1109/ACCESS.2019.2927433
A ConvMixEst and Multi-attention UNet for Intervertebral Disc Segmentation in Multi-modal MRI
Sipei Lu, Hanqiang Liu(B), and Xiangkai Guo
School of Computer Science, Shaanxi Normal University, Xi’an, China
Abstract. Accurate segmentation of spinal magnetic resonance imaging (MRI) plays a critical role in the diagnosis and evaluation of intervertebral discs. However, accurate disc segmentation is difficult because disc tissues are small, fine-grained and highly similar, and because slices vary greatly; multi-modal MRI disc segmentation is even more challenging. We therefore propose a model called ConvMixEst and Multi-Attention Unet (CAM-Unet), which combines an MLP with two attention mechanisms, Inverted Variational Attention (IVA) and Dilated Gated Attention (DGA). Specifically, we propose the IVA module to describe the overall feature information of the dataset in detail and capture feature information at different scales, and we design a ConvMixEst module to enhance global context information. On the MICCAI-2018 IVD challenge dataset, the model achieves a Dice similarity coefficient of 92.53%, a Jaccard coefficient of 86.10%, and a Precision of 94.09%. Keywords: Intervertebral disc segmentation · Deep learning · Multi-modal imaging · Attention mechanism
1 Introduction Magnetic resonance imaging (MRI) is one of the few rapid and accurate clinical diagnostic methods that is harmless to the body, and it is widely used to diagnose and detect intervertebral disc abnormalities. An intervertebral disc (IVD) is a cartilaginous joint between two adjacent vertebrae that withstands pressure, cushions shock, and protects the brain. A deformed disc may cause various problems, so assessing disc abnormalities by segmenting the discs is particularly important. Delineating disc areas in vertebral images is known as intervertebral disc segmentation; however, manual segmentation is tedious, time-consuming, and prone to subjectivity. Automated segmentation of intervertebral discs therefore offers a reliable and repeatable way to identify and diagnose disease. Most current disc segmentation networks use unimodal images, but multi-modal MR images provide different and complementary information that can improve recognition accuracy, so multi-modal disc segmentation can produce richer, more objective and more accurate results.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 137–146, 2023. https://doi.org/10.1007/978-981-99-4761-4_12
S. Lu et al.
Traditional segmentation methods have been widely used for vertebra and intervertebral disc segmentation, but with the development of deep learning, convolutional neural networks have far surpassed traditional methods in this area [1–4]. UNet [5] won the 2015 ISBI cell tracking competition by a wide margin and has had a significant impact on medical image segmentation. Many studies address the shortcomings of the U-shaped encoder-decoder structure or propose variants of it. For example, Zhao et al. [6] proposed a general multi-scale subtraction network (M2SNet) to reduce the large amount of redundant information produced by element-wise addition or concatenation in U-shaped structures. They designed a basic subtraction unit (SU) and extended it to obtain pixel-level and structure-level difference information as well as rich multi-scale difference information. In [7], the authors proposed a convolutional segmentation network combining hybrid convolution and multi-scale attention gates to overcome the inability of the encoder to extract global contextual information efficiently. Zhang et al. developed SAU-Net [8], whose inter-slice attention module (ISA) applies attention along the spatial dimension and uses information from adjacent slices to improve segmentation. In addition, Li et al. [9] developed a strip pooling attention mechanism based on encoder-decoder structures, exploiting the complete feature information of the entire anatomy and the long-range dependencies of the spine for disc and vertebra segmentation. However, none of these networks both resolves the loss of global and local feature information and enhances global contextual information. Inspired by UNeXt [10] and MALUNet [11], we propose an efficient image segmentation network, CMA-Unet, for 3D intervertebral disc segmentation.
A tokenized MLP (Tok-MLP) [10], which helps encode meaningful feature information, is applied in CMA-Unet, and the Dilated Gated Attention (DGA) block [11] is applied to obtain global and local information and help the model focus on the target region. Notably, we design an attention mechanism that enhances the feature map at different scales and abstraction levels, and we propose a ConvMixEst module to enhance global context information. In summary, our contributions are:
• We propose an attention module, the Inverted Variational Attention (IVA) block, that describes the overall feature information of the dataset more comprehensively.
• We propose a ConvMixEst module to enhance global context information.
• Extensive experiments on the MICCAI-2018 IVD challenge dataset show that our model effectively improves disc segmentation performance.
2 Methods Our network is based on a U-shaped encoder-decoder architecture, with Tok-MLP blocks in the first three stages and IVA and DGA in series in the final two stages (see Fig. 1). The numbers of channels are set to the usual {64, 128, 256, 512, 1024}. In this section, we first introduce the two modules proposed in this paper, IVA and ConvMixEst, and then describe the proposed CMA-Unet in detail.
A ConvMixEst and Multi-attention UNet
Fig. 1. Overview of the proposed CAM-UNet architecture.
2.1 Inverted Variational Attention Block Enhancing the overall feature information and the multi-scale information of the dataset is one of the keys to improving image segmentation performance. We therefore propose an attention module called IVA, shown in Fig. 2. Given an input x ∈ R^{C×H×W}, a pointwise convolution first processes it with its size unchanged. A depthwise convolution then expands the feature map fourfold to x ∈ R^{4C×H×W}, and a second pointwise convolution restores the original dimension. Finally, the residual (the block input) is added. Note that the depthwise separable convolution introduced in the Xception architecture [19] is intended to reduce the number of parameters and the computational cost of the model. However, we found that separating the pointwise and depthwise convolutions splits the spatial and channel dimensions of the feature map, so that each can focus on different features, which makes it easier for the network to capture feature information at different scales and abstraction levels. At the same time, multiplying the number of channels by 4 allows the overall feature information of the dataset to be described more comprehensively.
Fig. 2. Inverted Variational Attention Block
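A minimal sketch of the IVA data flow described in Sect. 2.1 (pointwise C → C, depthwise expansion C → 4C, pointwise 4C → C, plus the residual), written with nested lists in plain Python so the tensor shapes are easy to follow. The kernel size of the depthwise convolution is not given in the text, so the example uses 3 × 3 as an assumption; normalization and activation layers, if any, are omitted, and all names are hypothetical.

```python
from typing import List

Tensor = List[List[List[float]]]  # [channels][height][width]
Kernel = List[List[float]]        # k x k with k odd


def pointwise(x: Tensor, w: List[List[float]]) -> Tensor:
    """1x1 convolution: output channel o is sum_i w[o][i] * x[i] at every pixel."""
    h, wd = len(x[0]), len(x[0][0])
    return [[[sum(w[o][i] * x[i][r][c] for i in range(len(x))) for c in range(wd)]
             for r in range(h)] for o in range(len(w))]


def depthwise_expand(x: Tensor, kernels: List[List[Kernel]]) -> Tensor:
    """Depthwise conv with channel multiplier m: kernels[c] holds m kernels that each
    filter only input channel c (zero 'same' padding), giving m * C output channels."""
    h, wd = len(x[0]), len(x[0][0])
    out: Tensor = []
    for c, chan in enumerate(x):
        for kern in kernels[c]:
            pad = len(kern) // 2
            out.append([[sum(kern[i][j] * chan[r + i - pad][q + j - pad]
                             for i in range(len(kern)) for j in range(len(kern))
                             if 0 <= r + i - pad < h and 0 <= q + j - pad < wd)
                         for q in range(wd)] for r in range(h)])
    return out


def iva_block(x: Tensor, w_in: List[List[float]], dw_kernels: List[List[Kernel]],
              w_out: List[List[float]]) -> Tensor:
    """Pointwise (C -> C), depthwise expansion (C -> 4C), pointwise (4C -> C), residual."""
    y = pointwise(x, w_in)
    y = depthwise_expand(y, dw_kernels)
    y = pointwise(y, w_out)
    return [[[a + b for a, b in zip(ra, rb)] for ra, rb in zip(ca, cb)]
            for ca, cb in zip(y, x)]
```

Giving each input channel four depthwise kernels is what produces the 4C intermediate channels; the final pointwise convolution mixes those channels back down to C so that the residual addition is shape-compatible.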
2.2 ConvMixEst We propose the ConvMixEst module, based on a mixing concept, to enhance global contextual information. Specifically, a depthwise convolution mixes spatial locations, and a dilated convolution enlarges the receptive field of the convolution kernel. The implementation of ConvMixEst is shown in Fig. 3. The feature map passes through two residual convolution layers, a depthwise convolution with kernel size k × k and a dilated convolution with dilation rate 3 and kernel size 7, and the output feature map is then obtained by a 1 × 1 convolution. Each convolution layer is followed by a GELU activation and a BatchNorm normalization layer:
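$$f = \mathrm{BN}\left(\sigma_1\{\mathrm{DepthWiseConv}(x)\}\right) \qquad (1)$$
$$f = \mathrm{BN}\left(\sigma_1\{\mathrm{DilatedConv}(f)\}\right) \qquad (2)$$
$$f = \mathrm{BN}\left(\sigma_1\{\mathrm{Conv}(f)\}\right) \qquad (3)$$
where f denotes the output feature map in the ConvMixEst block, σ1 denotes the GELU activation, x denotes the input feature map, and BN denotes batch normalization.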
Fig. 3. ConvMixEst Block
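The effect of the dilated convolution on the receptive field can be checked with simple arithmetic: a k-tap kernel with dilation d spans d(k − 1) + 1 input positions, so the 7 × 7 kernel with dilation 3 covers 19 × 19. The sketch below stacks stride-1 layers along one axis; since the text leaves k unspecified for the depthwise layer, k = 7 is a hypothetical choice, and residual shortcuts are ignored because they do not enlarge the receptive field of the longest path.

```python
from typing import Iterable, Tuple


def effective_kernel(k: int, dilation: int) -> int:
    """A k-tap kernel with dilation d spans d * (k - 1) + 1 input positions."""
    return dilation * (k - 1) + 1


def receptive_field(layers: Iterable[Tuple[int, int]]) -> int:
    """Receptive field (along one axis) of stride-1 layers applied in sequence,
    each given as (kernel size, dilation)."""
    rf = 1
    for k, d in layers:
        rf += effective_kernel(k, d) - 1
    return rf


# Hypothetical: k = 7 for the depthwise layer (the paper leaves k unspecified).
CONVMIXEST_LAYERS = [(7, 1), (7, 3), (1, 1)]
print(receptive_field(CONVMIXEST_LAYERS))
```

Under this assumption the block sees a 25 × 25 neighbourhood with only 7 × 7 kernels; an undilated 7 × 7 second layer would reach just 13 × 13.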
2.3 CMA-Unet Combining the proposed modules with Tok-MLP and DGA, we obtain CMA-Unet, a U-shaped network with Tok-MLP blocks in the first three stages and IVA in series with DGA in the final two stages. Before tokenization in the MLP, the channel axis of the convolutional features is shifted, which helps the MLP focus only on specific locations of the convolutional features. In this module, the features are transformed and projected into tokens, as shown in Fig. 4(b), and the tokens are passed to a shifted MLP across the width. These tokens then go through a depthwise convolution and a GELU activation. Next, they pass through a shifted MLP across the height, after which the residual information is added. Finally,
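LN is applied to normalize them, and the output features are passed to the next block. The above computation can be summarized as follows:
$$Z_{shiftW} = \mathrm{Shift}(x) \qquad (4)$$
$$Z_T = \mathrm{Tokenize}(Z_{shiftW}) \qquad (5)$$
$$Z = \mathrm{DepthConv}(\mathrm{MLP}(Z_T)) \qquad (6)$$
$$Z_{shiftH} = \mathrm{Shift}(Z) \qquad (7)$$
$$Z_H = \mathrm{Tokenize}(Z_{shiftH}) \qquad (8)$$
$$Z = \mathrm{LN}(T + \mathrm{MLP}(\mathrm{GELU}(Z_H))) \qquad (9)$$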
where T represents the tokens, H represents height, W represents width, and LN represents layer normalization.
Fig. 4. (a) The Dilated Gated Attention block; (b) the Tok-MLP block
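A sketch of the Shift and Tokenize steps in formulas (4) and (5). The text does not specify the shift pattern, so this follows a UNeXt-style assumption: channels are split into contiguous groups, each group is shifted along the width by a different offset with zero fill, and tokenization flattens the H × W positions into one C-dimensional token per position (the learned projection used in practice is omitted). Shifting across the height, formula (7), is the same operation applied to the other spatial axis. The names and the default shift_size are hypothetical.

```python
from typing import List

Tensor = List[List[List[float]]]  # [channels][height][width]


def shift_width(x: Tensor, shift_size: int = 5) -> Tensor:
    """Shift contiguous channel groups along the width axis by offsets
    -shift_size//2 .. +shift_size//2, filling vacated positions with zeros."""
    channels = len(x)
    out: Tensor = []
    for c, chan in enumerate(x):
        group = c * shift_size // channels   # which channel group c belongs to
        offset = group - shift_size // 2     # e.g. -2..2 for shift_size = 5
        width = len(chan[0])
        out.append([[row[w - offset] if 0 <= w - offset < width else 0.0
                     for w in range(width)] for row in chan])
    return out


def tokenize(x: Tensor) -> List[List[float]]:
    """Flatten positions row by row into tokens, one C-dimensional vector each."""
    h, w = len(x[0]), len(x[0][0])
    return [[x[c][r][q] for c in range(len(x))] for r in range(h) for q in range(w)]
```

Because different channel groups are displaced by different amounts, each token ends up mixing features from neighbouring positions, which gives the subsequent MLP a local receptive field without any convolution.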
We arrange IVA and DGA in series: the former describes the overall feature map of the dataset more comprehensively and captures feature information at different scales, while the latter gathers both global and local feature information in the sample. The two complement each other, maximizing the advantages of both attention mechanisms. As shown in Fig. 4(a), DGA consists of a Split Dilated Conv unit (SDC) and a Gated Attention unit (GA). The SDC splits the feature map into four parts along the channel dimension and processes them with depthwise separable convolutions of different dilation rates. After concatenation and convolution, the feature map is passed to the GA, which uses depthwise separable convolution to generate attention maps with the same shape as the input features, suppressing the unimportant information in the features transmitted by the SDC.
In the encoder, we downsample using max pooling and feed the feature map to the ConvMixEst module in the fifth stage to enhance global context information; in the decoder, we upsample by bilinear interpolation, using a convolution kernel size of 3 × 3.
3 Experiments 3.1 Dataset In this section, we evaluate our model on the public MICCAI-2018 IVD challenge dataset. The dataset consists of 16 sets of 3D multi-modal MRI images of the lower spine, each containing at least seven IVDs (see Fig. 5). Each set has four modalities (fat, in-phase, opposed-phase, and water) with manual segmentation. CAM-UNet takes 2D slices of the multi-modal image volumes as input, and the slice size is set to 256 × 256.
Fig. 5. (a) Fat image; (b) In-phase; (c) Opp-phase; (d) water image; (e) the ground truth of the corresponding IVD image.
3.2 Implementation Details Our experiments are implemented in PyTorch on a 12 GB NVIDIA RTX 3060Ti GPU. All networks are trained with a batch size of 8 using the Adam optimizer with a learning rate of 10e−5 for 300 epochs. The raw 3D multi-modal MRI images are sliced along the sagittal axis to obtain 2D slices, and the training set is obtained after normalization and removal of slices that do not contain targets. No other data post-processing operations are used. 3.3 Ablation Studies Ablation Study on the Single Module To demonstrate and analyze the value of the proposed modules, we trained our model on the IVD dataset. As shown in Table 1, when only Tok-MLP blocks are used, the DSC is only 89.02% and the Jaccard coefficient just 80.21%. After the DGA block is introduced, both DSC and Precision increase by more than 1%, which shows that the block successfully collects both global and local feature information. Then we added the
IVA block in series with it to further improve performance; the experiment shows that IVA can describe the feature information of the dataset in detail and capture feature information at different scales. Finally, to obtain the best performance, we further incorporated the ConvMixEst block, which raises the DSC, Jaccard coefficient, and Precision to 92.53%, 86.10%, and 94.09%, respectively, demonstrating the effectiveness of ConvMixEst in enhancing global context information.
Table 1. Ablation study on the single module.
Model                               DSC (%)   Jaccard (%)   Precision (%)
Pure Tok-MLP                        89.02     80.21         87.12
Tok-MLP + DGA                       90.17     82.10         89.65
Tok-MLP + IVA + DGA                 91.83     84.90         93.28
Tok-MLP + IVA + DGA + ConvMixEst    92.53     86.10         94.09
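For reference, the three metrics used in Tables 1–4 can be computed from binary masks as below. This is a standard-definition sketch, not code from the paper; masks are assumed to be flattened sequences of 0/1 values, and empty predictions (TP + FP = 0) would raise a division error that real code should handle. The last function encodes the per-mask identity J = D / (2 − D); for scores averaged over many images the identity holds only approximately, although for the final row of Table 1 it maps 92.53% DSC to 86.10% Jaccard.

```python
def confusion_counts(pred, target):
    """TP, FP, FN for two flattened binary masks (sequences of 0/1 or bools)."""
    tp = sum(1 for p, t in zip(pred, target) if p and t)
    fp = sum(1 for p, t in zip(pred, target) if p and not t)
    fn = sum(1 for p, t in zip(pred, target) if t and not p)
    return tp, fp, fn

def dice(pred, target):
    """DSC = 2TP / (2TP + FP + FN)."""
    tp, fp, fn = confusion_counts(pred, target)
    return 2 * tp / (2 * tp + fp + fn)

def jaccard(pred, target):
    """Jaccard = TP / (TP + FP + FN)."""
    tp, fp, fn = confusion_counts(pred, target)
    return tp / (tp + fp + fn)

def precision(pred, target):
    """Precision = TP / (TP + FP)."""
    tp, fp, _ = confusion_counts(pred, target)
    return tp / (tp + fp)

def jaccard_from_dice(d):
    """Per-mask identity between the two overlap scores: J = D / (2 - D)."""
    return d / (2 - d)

pred = [1, 1, 0, 0, 1]
target = [1, 0, 0, 1, 1]
print(dice(pred, target), jaccard(pred, target), precision(pred, target))
```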
Ablation Study on IVA and DGA. To find the best configuration of CAM-UNet, we investigate how the two attention mechanisms should be coupled. Following MALUNet [11], which examined how the order of its modules affects performance, we study the performance variation caused by the order in which IVA and DGA are connected in series. Table 2 shows that placing IVA before DGA yields the best performance.
Table 2. Ablation study on IVA and DGA.
Model                               DSC (%)   Jaccard (%)   Precision (%)
Tok-MLP + IVA + DGA + ConvMixEst    92.53     86.10         94.09
Tok-MLP + DGA + IVA + ConvMixEst    90.79     83.13         90.32
4 Results
We compare CAM-UNet with other intervertebral disc segmentation networks in terms of DSC, Jaccard coefficient, and Precision. Some works do not report Precision or the Jaccard coefficient; these entries are marked with “/” in Table 3. IVD-Net, RIMNet, and the network of Mader et al. use the same dataset as our work. As Table 3 shows, CAM-UNet achieves the highest DSC of all listed networks: it improves DSC by 0.91 percentage points over the well-known IVD-Net and by 2.83 and 4.17 points over the recently proposed networks of Šušteršič et al. [15] and Liaskos et al. [16], respectively.
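The DSC margins quoted above are absolute differences between the values in Table 3, which a few lines of Python can reproduce. The dictionary simply copies the reported DSC values; the smallest margin (0.83 points) is over RIMNet.

```python
# DSC (%) values reported in Table 3 for the other disc segmentation networks.
REPORTED_DSC = {
    "BSU-Net [12]": 89.44,
    "RIMNet [13]": 91.70,
    "IVD-Net [4]": 91.62,
    "Mader et al. [14]": 90.40,
    "SAU-Net [8]": 89.86,
    "Šušteršič et al. [15]": 89.70,
    "Liaskos et al. [16]": 88.36,
}
CAM_UNET_DSC = 92.53

def dsc_margins(reported, ours):
    """Absolute DSC difference (percentage points) between ours and each network."""
    return {name: round(ours - dsc, 2) for name, dsc in reported.items()}

margins = dsc_margins(REPORTED_DSC, CAM_UNET_DSC)
for name, gap in sorted(margins.items(), key=lambda kv: kv[1]):
    print(f"{name:>24}: +{gap:.2f}")
```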
Table 3. Comparison with other methods.
Networks                 DSC (%)   Jaccard (%)   Precision (%)
BSU-Net [12]             89.44     /             /
RIMNet [13]              91.70     87.00         /
IVD-Net [4]              91.62     /             /
Mader et al. [14]        90.40     /             /
SAU-Net [8]              89.86     81.56         89.75
Šušteršič et al. [15]    89.70     81.30         /
Liaskos et al. [16]      88.36     /             /
CAM-UNet                 92.53     86.10         94.09
Under the same baseline settings, we compare CAM-UNet on the IVD dataset with medical image segmentation networks that improve U-Net using different techniques. UNeXt and MALUNet are both lightweight medical image segmentation models, with channel configurations of {32, 64, 128, 256, 512} and {8, 16, 24, 32, 48, 64}, respectively; for comparison purposes, we adjusted them to be consistent with this paper. This setting is also examined in the ablation studies of the UNeXt and MALUNet papers: MALUNet reports a DSC of 87.55% with channels {32, 64, 128, 256, 512}, and UNeXt reports a Jaccard index of 83.10%. As shown in Table 4, CAM-UNet improves almost all metrics compared with the other models: it achieves the highest DSC and Jaccard coefficient, and only CMU-Net attains a higher Precision (95.10% vs. 94.09%). Some segmentation results are shown in Fig. 6. The results of MALUNet-L contain some residual fragments, and while all other networks suffer to some degree from irrelevant speckles, our network produces more accurate segmentations at the edges.
Table 4. Comparison of models under the same baseline.
Networks           DSC (%)   Jaccard (%)   Precision (%)
U-Net [5]          89.51     81.01         88.33
U-net++ [17]       88.32     79.08         86.78
U-net3+ [18]       90.13     82.03         93.27
CMU-Net [7]        91.59     84.48         95.10
UNEXT-L [10]       90.21     82.17         90.10
MALUnet-L [11]     86.02     75.47         85.87
CAM-UNet           92.53     86.10         94.09
Fig. 6. Randomly selected segmentation results. (a) Ground Truth; (b) U-Net; (c) U-net++; (d) U-net3+; (e) CMU-Net; (f) UNEXT-L; (g) MALUnet-L; (h) CAM-Unet.
5 Conclusion
In this work, we propose a straightforward network for multi-modal MRI disc segmentation. It has a U-shaped framework with Tok-MLP blocks in the first three stages and IVA and DGA in series in the last two stages, and the fifth-stage feature information is fed as patches into the ConvMixEst module. ConvMixEst aims to enhance global context information, while the proposed IVA module describes the overall feature information of the dataset in detail and captures feature information at different scales. Comparative experiments on the MICCAI-2018 IVD challenge dataset show that CAM-UNet outperforms recently proposed intervertebral disc segmentation networks as well as classical medical image segmentation networks. In future work, we will try to further improve CAM-UNet, for example with larger convolutional kernels or higher-level encoder architectures, and explore its application to other tasks.
References
1. Dutande, P., Baid, U., Talbar, S.: Deep residual separable convolutional neural network for lung tumor segmentation. Comput. Biol. Med. 141, 105161 (2022). https://doi.org/10.1016/j.compbiomed.2021.105161
2. Cheng, J., Tian, S., Yu, L., et al.: ResGANet: residual group attention network for medical image classification and segmentation. Med. Image Anal. 76, 102313 (2022). https://doi.org/10.1016/j.media.2021.102313
3. Punn, N.S., Agarwal, S.: RCA-IUnet: a residual cross-spatial attention-guided inception UNet model for tumor segmentation in breast ultrasound imaging. Mach. Vis. Appl. 33(2), 37 (2022). https://doi.org/10.1007/s00138-022-01280-3
4. Dolz, J., Desrosiers, C., Ben Ayed, I.: IVD-Net: intervertebral disc localization and segmentation in MRI with a multi-modal UNet. In: Zheng, G., Belavy, D., Cai, Y., Li, S. (eds.) CSI 2018. LNCS, vol. 11397, pp. 130–143. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-13736-6_11
5. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
6. Zhao, X., Jia, H., Pang, Y., et al.: M2SNet: multi-scale in multi-scale subtraction network for medical image segmentation. arXiv preprint arXiv:2303.10894 (2023). https://doi.org/10.48550/arXiv.2303.10894
7. Tang, F., Wang, L., Ning, C., Xian, M., Ding, J.: CMU-Net: a strong ConvMixer-based medical ultrasound image segmentation network. arXiv preprint arXiv:2210.13012 (2022). https://doi.org/10.48550/arXiv.2210.13012
8. Zhang, Y., Yuan, L., Wang, Y., Zhang, J.: SAU-Net: efficient 3D spine MRI segmentation using inter-slice attention. In: Medical Imaging with Deep Learning, pp. 903–913 (2020)
9. Li, C., Liu, T., Chen, Z., et al.: SPA-RESUNET: strip pooling attention resunet for multiclass segmentation of vertebrae and intervertebral discs. In: 2022 IEEE 19th International Symposium on Biomedical Imaging (ISBI), pp. 1–5 (2022). https://doi.org/10.1109/ISBI52829.2022.9761577
10. Valanarasu, J.M.J., Patel, V.M.: UNeXt: MLP-based rapid medical image segmentation network. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) Medical Image Computing and Computer Assisted Intervention – MICCAI 2022. LNCS, vol. 13435, pp. 23–33. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16443-9_3
11. Ruan, J., Xiang, S., Xie, M., Liu, T., Fu, Y.: MALUNet: a multi-attention and light-weight unet for skin lesion segmentation. In: 2022 IEEE International Conference on Bioinformatics and Biomedicine (BIBM), pp. 1150–1156 (2022). https://doi.org/10.1109/BIBM55620.2022.9995040
12. Kim, S., Bae, W.C., Masuda, K., Chung, C.B., Hwang, D.: Fine-grain segmentation of the intervertebral discs from MR spine images using deep convolutional neural networks: BSU-Net. Appl. Sci. 8(9), 1656 (2018). https://doi.org/10.3390/app8091656
13. Das, P., Pal, C., Acharyya, A., et al.: Deep neural network for automated simultaneous intervertebral disc (IVDs) identification and segmentation of multi-modal MR images. Comput. Methods Programs Biomed. 205, 106074 (2021). https://doi.org/10.1016/j.cmpb.2021.106074
14. Mader, A.O., Lorenz, C., Meyer, C.: A general framework for localizing and locally segmenting correlated objects: a case study on intervertebral discs in multi-modality MR images. In: Zheng, Y., Williams, B.M., Chen, K. (eds.) MIUA 2019. CCIS, vol. 1065, pp. 364–376. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-39343-4_31
15. Šušteršič, T., Ranković, V., Milovanović, V., Kovačević, V., Rasulić, L., Filipović, N.: A deep learning model for automatic detection and classification of disc herniation in magnetic resonance images. IEEE J. Biomed. Health Inform. 26(12), 6036–6046 (2022). https://doi.org/10.1109/JBHI.2022.3209585
16. Liaskos, M., Savelonas, M.A., Asvestas, P.A., Lykissas, M.G., Matsopoulos, G.K.: Bimodal CT/MRI-based segmentation method for intervertebral disc boundary extraction. Information 11(9), 448 (2020). https://doi.org/10.3390/info11090448
17. Zhou, Z., Siddiquee, M.M.R., Tajbakhsh, N., Liang, J.: Unet++: redesigning skip connections to exploit multiscale features in image segmentation. IEEE Trans. Med. Imaging 39(6), 1856–1867 (2019). https://doi.org/10.1109/TMI.2019.2959609
18. Huang, H., Lin, L., Tong, R., et al.: Unet 3+: a full-scale connected Unet for medical image segmentation. In: ICASSP 2020–2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 1055–1059 (2020). https://doi.org/10.1109/ICASSP40776.2020.9053405
19. Chollet, F.: Xception: deep learning with depthwise separable convolutions. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1251–1258 (2017)
One-Dimensional Feature Supervision Network for Object Detection
Longchao Shen1, Yongsheng Dong1(B), Yuanhua Pei1, Haotian Yang1, Lintao Zheng1, and Jinwen Ma2
1 School of Information Engineering, Henan University of Science and Technology, Luoyang 471023, China
2 Department of Information and Computational Sciences, School of Mathematical Sciences and LMAM, Peking University, Beijing 100871, China
Abstract. Self-attention mechanisms have been widely used in object detection to distinguish the importance of different channels and reinforce important information in features, leading to strong results at all scales. However, most self-attention mechanisms and their variants focus only on the channel dimension and thus easily ignore the width and height dimensions of the feature map, which play an important role in capturing local contextual information. To alleviate this problem, in this paper we propose a one-dimensional feature supervision network for object detection (1DSNet). Specifically, we first propose a one-dimensional feature supervision module (1DSM), which uses lightweight one-dimensional feature vectors to weight the features along the width and the height, respectively, jointly reinforcing the important information in the features. Moreover, to improve the representation of multi-scale contextual information, we construct a receptive field dilated spatial pyramid pooling (RFD-SPP) module that obtains a larger field of view on the basis of spatial pyramid pooling. Experimental results demonstrate that the proposed 1DSNet is effective and competitive compared with representative methods.
Keywords: Object Detection · CNN · Feature Pyramid · Attention Mechanism
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 147–156, 2023.
https://doi.org/10.1007/978-981-99-4761-4_13
1 Introduction
Computer vision aims to enable computers to understand and recognize the content of images and videos so that machines acquire human-like vision. Object detection is a fundamental task in computer vision that aims to identify and locate specific objects in an image or video, and it has important applications in many fields. To achieve high accuracy, an object detection model needs features that carry both global contextual information and multi-level semantics, both of which are important for the detection task. To this end, researchers have made improvements in the backbone network [1], the feature extraction network [2, 3], and
the detection head [4], respectively. In addition, attention mechanisms [5] have been shown in numerous experiments to improve object detection accuracy. Attention mechanisms typically include channel attention [6] and spatial attention [7, 8], both of which focus mainly on the channel dimension of the feature map while ignoring its other two dimensions, width and height. Moreover, these self-attention-based models must generate a very large number of attention weights, which increases the number of parameters and consumes substantial computational resources. To address these issues, in this paper we first design a one-dimensional feature supervision module (1DSM), which aggregates features along the width and the height separately and uses each result to supervise the width and height information of multi-level features in the network. In addition, a receptive field dilated spatial pyramid pooling module (RFD-SPP) is constructed to enlarge the receptive field of the top-level feature map of the network and to facilitate the separation of the most salient contextual features. By combining 1DSM and RFD-SPP, our proposed method achieves competitive results on both the MS COCO and PASCAL VOC datasets. The contributions of this paper can be summarized as follows:
• We design a one-dimensional feature supervision module (1DSM) that uses lightweight one-dimensional convolutions to weight the width and height information of the feature map, respectively, enhancing the local contextual information of the features as well as the multi-level feature representation capability.
• We design a receptive field dilated spatial pyramid pooling (RFD-SPP) module that substantially enlarges the receptive field coverage of the feature map while retaining the effect of standard SPP, thereby better preserving contextual information.
• We construct a one-dimensional feature supervision network for object detection (1DSNet) by integrating 1DSM and RFD-SPP into an object detection network paradigm. Experiments demonstrate that the proposed 1DSNet achieves competitive results.
2 Methodology
In this section, we present the one-dimensional feature supervision network for object detection (1DSNet). In the following subsections, we introduce the one-dimensional feature supervision module (1DSM) and the receptive field dilated spatial pyramid pooling (RFD-SPP), which enhance multi-level semantic features and contextual information, respectively. Finally, we present the overall architecture of 1DSNet.
2.1 One-Dimensional Feature Supervision Module
The channel attention mechanism proposed in SENet [9] is widely used. Its main idea is to reweight features by learning a weight for each channel, thus emphasizing important feature channels and suppressing less important ones. However, channel attention only ranks the importance of the feature channels, ignoring the other two dimensions of the feature map, width and
height. Therefore, we construct a one-dimensional feature supervision module that uses one-dimensional convolution to weight the contextual information of the image along both the width and the height. The structure of the module is shown in Fig. 1.
Fig. 1. Structure of 1DSM. The left input is the original image and the right input is the feature map to be weighted.
The 1DSM consists of two parts. The left part of Fig. 1 shows the multi-level feature extraction structure, which uses a lightweight ResNet-18 network to extract the preliminary features that will be convolved in one dimension; the number of levels and the scales of the extracted features match those of the feature maps produced by the backbone network. The right part of Fig. 1 shows the one-dimensional feature supervision structure. The feature map produced by the multi-level feature extraction structure has shape B × C × H × W, where B is the batch size; since B does not change, it is omitted below. A one-dimensional convolution operates on features of shape C × H (or C × W), so the feature map must first be squeezed along one spatial dimension. This yields a single-direction feature map of shape C × H (C × W), which is then convolved in one dimension with the number of channels reduced to one, giving a 1 × H (1 × W) one-dimensional feature. The input features to be fused in this module have shape C × H × W; to weight them with the 1D features, both must have the same number of dimensions. Therefore, after the one-dimensional convolution, a singleton width or height dimension is appended, giving feature vectors of shape 1 × H × 1 and 1 × 1 × W, respectively. Then, similar to the channel attention module, these vectors are multiplied with the features to be fused to weight them, and a 1 × 1 two-dimensional convolution is applied for feature smoothing to obtain the final output feature map of the 1DSM.
2.2 Receptive Field Dilated Spatial Pyramid Pooling
Spatial pyramid pooling (SPP) [10] is a widely used deep learning technique, originally proposed to let convolutional neural networks handle input images of different sizes.
In object detection, SPP is used to pool information at different scales of the input while keeping the feature map size constant, which improves the model's ability to recognize objects, especially objects of different scales. Since pooling is essentially a downsampling operation on sub-regions of the input feature map, some important information
may be lost in the pooled features. This section therefore adds a receptive field expansion branch to SPP. With this branch, SPP can better retain contextual information and, by enlarging the receptive field, better capture the complete information of objects at different scales.
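The size-preserving pooling that SPP relies on, and the receptive-field growth that motivates a dilated branch, can be sketched in pure Python. Assumptions: pooling uses stride 1 with k // 2 padding (a common way to keep the map size constant), the kernel sizes 5, 9, and 13 are those given for RFD-SPP in the next paragraph, and the dilated branch below (a 3 × 3 depthwise separable convolution followed by two 3 × 3 convolutions with dilation 2 and 4) is a hypothetical configuration, since the rates are not stated.

```python
def max_pool_same(x, k):
    """Stride-1 k x k max pooling with k // 2 padding, so the output keeps x's size.

    Out-of-range positions are skipped, which matches padding with -infinity.
    """
    h, w, r = len(x), len(x[0]), k // 2
    return [[max(x[a][b]
                 for a in range(max(0, i - r), min(h, i + r + 1))
                 for b in range(max(0, j - r), min(w, j + r + 1)))
             for j in range(w)]
            for i in range(h)]

def pooling_branches(x, kernels=(5, 9, 13)):
    """Three max-pooling branches; all outputs share x's spatial size,
    so they can be concatenated along the channel axis."""
    return [max_pool_same(x, k) for k in kernels]

def receptive_field(layers):
    """Receptive field of stacked stride-1 convolutions given (kernel, dilation) pairs."""
    rf = 1
    for kernel, dilation in layers:
        rf += dilation * (kernel - 1)
    return rf

# Hypothetical dilated branch: 3x3 depthwise conv, then 3x3 convs with dilation 2 and 4.
DILATED_BRANCH = [(3, 1), (3, 2), (3, 4)]
print(receptive_field(DILATED_BRANCH))  # path through every dilated layer
```

With skip connections around each dilation block, the branch output mixes this largest receptive field with the smaller ones of the earlier layers, which is how the original receptive field is kept.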
Fig. 2. The architecture of 1DSNet. The backbone is DSPDarkNet-53. Red arrows indicate feature maps fed into 1DSM for weighting; green arrows indicate the weighted feature maps that are returned to replace the original feature maps.
The structure of RFD-SPP is shown in Fig. 2. It contains four branches. Three are max-pooling branches with pooling kernels of size 5 × 5, 9 × 9, and 13 × 13; the fourth is the receptive field dilated branch. In this branch, the input features are first enhanced by a 3 × 3 depthwise separable convolution and then passed through dilation blocks, each containing a dilation branch and a skip connection branch. The dilation branch applies a dilated convolution to expand the receptive field to a larger area, while the skip connection branch directly connects the features before and after processing, ensuring that the original receptive field is not lost.
2.3 The Overall Architecture of Our Proposed 1DSNet
We construct a one-dimensional feature supervision network for object detection with DSPDarkNet53 as the backbone, combined with 1DSM and RFD-SPP; its structure is shown in Fig. 2. A 1DSM is attached to the image input on a separate path. It outputs one-dimensional features at three scales, shaped to match the three levels of feature maps output by the backbone, and each level of backbone feature maps is weighted by them before the bottom-up fusion stage.
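The weighting path of 1DSM described in Sect. 2.1 can be sketched for one feature level in pure Python. Several details are not specified in the text and are assumptions here: "squeezing" is taken to be a mean over the other spatial axis, the 1D weights pass through a sigmoid as in channel attention, and the height and width vectors are applied jointly by multiplication. The ResNet-18 guide branch is reduced to an input (`guide`), the convolution taps are hypothetical parameters, and the final 1 × 1 smoothing convolution is omitted.

```python
import math

def squeeze(feat, keep):
    """Collapse C x H x W to C x H (keep="h", mean over W) or C x W (keep="w", mean over H)."""
    channels, height, width = len(feat), len(feat[0]), len(feat[0][0])
    if keep == "h":
        return [[sum(feat[c][i]) / width for i in range(height)] for c in range(channels)]
    return [[sum(feat[c][i][j] for i in range(height)) / height for j in range(width)]
            for c in range(channels)]

def conv1d_to_one_channel(x, weights, bias=0.0):
    """'Same'-padded 1D convolution from C channels to 1; weights[c][t] are hypothetical taps."""
    length, k = len(x[0]), len(weights[0])
    r = k // 2
    out = []
    for p in range(length):
        s = bias
        for c, row in enumerate(x):
            for t in range(k):
                q = p + t - r
                if 0 <= q < length:  # zero padding outside the sequence
                    s += weights[c][t] * row[q]
        out.append(s)
    return out

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def one_d_supervise(guide, target, w_height, w_width):
    """Weight `target` (C' x H x W) with height/width vectors computed from `guide` (C x H x W)."""
    h_vec = [sigmoid(v) for v in conv1d_to_one_channel(squeeze(guide, "h"), w_height)]  # 1 x H x 1
    w_vec = [sigmoid(v) for v in conv1d_to_one_channel(squeeze(guide, "w"), w_width)]   # 1 x 1 x W
    # Broadcast both vectors over every channel of the target and multiply element-wise.
    return [[[plane[i][j] * h_vec[i] * w_vec[j] for j in range(len(w_vec))]
             for i in range(len(h_vec))]
            for plane in target]
```

Each level needs only H + W weights instead of a full H × W attention map, which is the lightweight property the paper emphasizes.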
3 Experiments
In this section, we evaluate our method on two challenging large public datasets, MS COCO and PASCAL VOC, and compare it with representative existing networks. The experimental section consists of three parts: 1) implementation details; 2) ablation experiments; 3) comparison with state-of-the-art methods.
3.1 Implementation Details
The backbone of 1DSNet uses the Mish [14] activation function, which is smooth and non-monotonic everywhere, making gradient descent more effective. The input image size is 416 × 416, and the network parameters are optimized by stochastic gradient descent (SGD) with a momentum of 0.937, a weight decay of 5e−4, and a batch size of 16. The initial learning rate is set to 1e−6 and decays following a cosine schedule over 200 epochs, as shown in Fig. 3. These hyperparameters were chosen based on the available hardware and previous experience. To avoid overfitting, 1DSNet uses DropBlock [12], which, unlike conventional Dropout, removes contiguous local regions of a feature map and is effective on both fully connected and convolutional layers. Additionally, during training 1DSNet replaces the conventional bounding box regression loss with the CIoU loss [13], because minimizing a conventional bounding box regression loss is not exactly equivalent to optimizing IoU. This makes bounding box regression more stable.
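The activation and loss named above can be written out as a short sketch following their published definitions: Mish(x) = x · tanh(softplus(x)), and CIoU = 1 − IoU + ρ²/c² + αv, where ρ is the distance between box centres, c is the diagonal of the smallest enclosing box, and v measures aspect-ratio disagreement. This is an illustrative pure-Python version for single boxes in (x1, y1, x2, y2) form, not the training code; the small `eps` terms are added here only to avoid division by zero.

```python
import math

def softplus(x):
    """Numerically stable ln(1 + e^x)."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def mish(x):
    """Mish activation: x * tanh(softplus(x))."""
    return x * math.tanh(softplus(x))

def ciou_loss(pred, gt, eps=1e-9):
    """CIoU loss for two boxes given as (x1, y1, x2, y2)."""
    px1, py1, px2, py2 = pred
    gx1, gy1, gx2, gy2 = gt
    pw, ph, gw, gh = px2 - px1, py2 - py1, gx2 - gx1, gy2 - gy1
    inter = max(0.0, min(px2, gx2) - max(px1, gx1)) * max(0.0, min(py2, gy2) - max(py1, gy1))
    iou = inter / (pw * ph + gw * gh - inter + eps)
    # Squared distance between box centres, normalized by the squared diagonal
    # of the smallest box enclosing both.
    rho2 = ((px1 + px2 - gx1 - gx2) ** 2 + (py1 + py2 - gy1 - gy2) ** 2) / 4.0
    c2 = (max(px2, gx2) - min(px1, gx1)) ** 2 + (max(py2, gy2) - min(py1, gy1)) ** 2 + eps
    # Aspect-ratio consistency term v and its trade-off weight alpha.
    v = (4.0 / math.pi ** 2) * (math.atan(gw / (gh + eps)) - math.atan(pw / (ph + eps))) ** 2
    alpha = v / ((1.0 - iou) + v + eps)
    return 1.0 - iou + rho2 / c2 + alpha * v
```

Unlike a plain IoU loss, the centre-distance term still yields a gradient when the boxes do not overlap, so the loss exceeds 1 for disjoint boxes and keeps pushing them together.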
Fig. 3. Learning rate variation over 200 epochs. The backbone parameters are frozen during the first 50 epochs and unfrozen for the remaining 150 epochs.
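The schedule in Fig. 3 can be approximated as follows. This sketch assumes a single cosine curve over all 200 epochs that ends at a hypothetical minimum of 0, with the 50-epoch backbone freeze handled separately; the actual curve in Fig. 3 may differ, for example if the schedule restarts when the backbone is unfrozen.

```python
import math

def cosine_lr(epoch, total_epochs=200, lr_init=1e-6, lr_min=0.0):
    """Cosine decay from lr_init (epoch 0) to lr_min (epoch total_epochs)."""
    progress = epoch / total_epochs
    return lr_min + 0.5 * (lr_init - lr_min) * (1.0 + math.cos(math.pi * progress))

def backbone_trainable(epoch, freeze_epochs=50):
    """The backbone is frozen for the first `freeze_epochs` epochs, then unfrozen."""
    return epoch >= freeze_epochs

schedule = [(e, cosine_lr(e), backbone_trainable(e)) for e in range(200)]
```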
To reduce training time, instead of training from scratch we use transfer learning, loading DSPDarkNet-53 weights pre-trained on ImageNet. We also use 5-fold cross-validation, dividing the whole dataset into five equal-sized folds. Each fold in turn serves as the validation set while the remaining four form the training set, so the procedure is repeated five times. The final validation accuracy is the average of the accuracies of these five models on their validation folds.
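The 5-fold procedure can be sketched as an index generator plus an averaging loop. Shuffling with a fixed seed before splitting is an assumption (the text does not say whether the data are shuffled), and `train_and_evaluate` is a hypothetical callback standing in for one training run.

```python
import random

def k_fold_splits(n_samples, k=5, seed=0):
    """Yield (train_idx, val_idx) pairs; each sample is in exactly one validation fold."""
    order = list(range(n_samples))
    random.Random(seed).shuffle(order)
    base, extra = divmod(n_samples, k)
    start = 0
    for fold in range(k):
        size = base + (1 if fold < extra else 0)
        val = order[start:start + size]
        train = order[:start] + order[start + size:]
        yield train, val
        start += size

def cross_validate(n_samples, train_and_evaluate, k=5):
    """Run k train/validate rounds and return the mean validation score."""
    scores = [train_and_evaluate(train, val) for train, val in k_fold_splits(n_samples, k)]
    return sum(scores) / len(scores)
```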
3.2 Ablation Experiment
This section presents ablation experiments on 1DSNet and a detailed analysis of the results. We mainly study the impact of each of the two proposed modules on the network, as well as the impact of different spatial pyramid pooling structures on network performance.
Table 1. Effect of each component. Bold values indicate the maximum of each column.
1DSM       RFD-SPP    AP     AP50   AP75   APs    APm    APl
baseline              51.4   77.7   56.7   31.0   48.0   54.4
√          –          52.9   78.3   58.6   27.3   47.5   56.1
–          √          57.3   84.8   64.8   37.6   51.2   61.7
√          √          58.4   84.4   65.6   28.1   48.8   63.0
First, 1DSM and RFD-SPP are each applied to the baseline network. For all ablation experiments in this section, the baseline network consists of the DSPDarkNet53 backbone and the Yolo detection head, and VOC2007 is used as the validation dataset. The results are shown in Table 1. The baseline network achieves an AP of 51.4. Applying 1DSM raises the AP by 1.5 points over the baseline, which we attribute to 1DSM using one-dimensional features to enhance the important information in the feature map, making it easier for the network to learn that information. With RFD-SPP, the network obtains 57.3 AP, 5.9 points above the baseline, showing that this module has a significant effect. Finally, applying 1DSM and RFD-SPP together yields the highest AP, 58.4, showing that the two modules can be used together effectively.
3.3 Main Results
This section evaluates 1DSNet on the PASCAL VOC and MS COCO datasets and compares it with other state-of-the-art one-stage and two-stage detectors. For fairness, all results are obtained under the same single-scale training and inference settings.
Results on VOC2007. We first compare 1DSNet with seven methods on the VOC2007 dataset; the results are shown in Table 2. The two strongest competing methods are Yolo-v3 [19] and Yolo-v4 [11], which we attribute to their DarkNet-series backbones having a stronger feature extraction ability. The proposed 1DSNet achieves the highest mAP in the table, 84.73. It obtains the best result in 15 of the 20 object categories and competitive results in the remaining 5.
Table 2. PASCAL VOC 2007 dataset detection results. Bold values indicate the maximum of each column.
Method             mAP    aero   car    bike   bird   boat   bottle bus    cat    chair  cow
Faster R-CNN [15]  66.97  68.98  80.36  77.82  71.82  51.20  46.25  46.25  84.17  47.28  62.17
RetinaNet [3]      70.86  76.42  78.00  61.49  60.47  72.64  50.10  67.11  69.59  52.85  66.22
CenterNet [16]     77.45  86.17  88.45  85.56  76.30  64.73  59.20  84.32  88.55  62.38  82.04
SSD300 [17]        72.11  85.32  79.41  74.97  76.58  60.79  49.86  82.29  88.00  50.91  73.45
M2Det [18]         75.66  83.18  87.36  85.61  72.57  68.08  45.17  84.80  83.49  58.25  80.61
Yolo-v3 [19]       81.03  89.27  91.81  88.14  80.01  64.46  70.03  88.93  87.78  67.57  87.26
Yolo-v4 [11]       82.75  92.33  92.98  89.70  80.41  74.49  71.22  90.01  86.76  66.82  91.03
1DSNet             84.73  94.03  93.66  91.29  85.89  68.23  75.03  93.36  87.24  69.91  93.10

Method             mAP    table  dog    horse  moto   person plant  sheep  sofa   train  tv
Faster R-CNN [15]  66.97  56.41  77.15  79.28  72.74  77.52  37.93  69.18  65.30  76.71  64.56
RetinaNet [3]      70.86  37.48  83.12  76.79  85.75  72.53  56.34  72.25  69.49  68.15  69.51
CenterNet [16]     77.45  67.74  80.71  86.91  85.47  84.23  49.41  78.74  75.59  84.52  78.02
SSD300 [17]        72.11  53.79  82.30  78.80  78.79  83.52  44.21  74.66  66.27  83.81  74.45
M2Det [18]         75.66  75.48  79.46  85.45  82.75  80.89  44.32  77.28  78.91  84.90  74.71
Yolo-v3 [19]       81.03  77.88  86.41  88.42  90.29  87.85  47.22  78.34  80.30  89.17  79.45
Yolo-v4 [11]       82.75  73.52  89.05  91.51  90.41  89.23  53.67  81.83  73.65  93.33  83.06
1DSNet             84.73  79.59  86.30  92.77  92.24  91.73  60.46  88.41  76.83  90.95  83.55
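Two statements in the VOC2007 discussion can be checked directly against Table 2: that VOC mAP is the mean of the 20 per-class APs (for 1DSNet this gives the reported 84.73), and that 1DSNet has the best AP in 15 of the 20 categories. The lists below copy Table 2 in its column order; counting only strict wins is an assumption about how ties would be treated (there are none here).

```python
# Per-class AP (%) from Table 2, in the table's column order:
# aero, car, bike, bird, boat, bottle, bus, cat, chair, cow,
# table, dog, horse, moto, person, plant, sheep, sofa, train, tv
TABLE2 = {
    "Faster R-CNN": [68.98, 80.36, 77.82, 71.82, 51.20, 46.25, 46.25, 84.17, 47.28, 62.17,
                     56.41, 77.15, 79.28, 72.74, 77.52, 37.93, 69.18, 65.30, 76.71, 64.56],
    "RetinaNet": [76.42, 78.00, 61.49, 60.47, 72.64, 50.10, 67.11, 69.59, 52.85, 66.22,
                  37.48, 83.12, 76.79, 85.75, 72.53, 56.34, 72.25, 69.49, 68.15, 69.51],
    "CenterNet": [86.17, 88.45, 85.56, 76.30, 64.73, 59.20, 84.32, 88.55, 62.38, 82.04,
                  67.74, 80.71, 86.91, 85.47, 84.23, 49.41, 78.74, 75.59, 84.52, 78.02],
    "SSD300": [85.32, 79.41, 74.97, 76.58, 60.79, 49.86, 82.29, 88.00, 50.91, 73.45,
               53.79, 82.30, 78.80, 78.79, 83.52, 44.21, 74.66, 66.27, 83.81, 74.45],
    "M2Det": [83.18, 87.36, 85.61, 72.57, 68.08, 45.17, 84.80, 83.49, 58.25, 80.61,
              75.48, 79.46, 85.45, 82.75, 80.89, 44.32, 77.28, 78.91, 84.90, 74.71],
    "Yolo-v3": [89.27, 91.81, 88.14, 80.01, 64.46, 70.03, 88.93, 87.78, 67.57, 87.26,
                77.88, 86.41, 88.42, 90.29, 87.85, 47.22, 78.34, 80.30, 89.17, 79.45],
    "Yolo-v4": [92.33, 92.98, 89.70, 80.41, 74.49, 71.22, 90.01, 86.76, 66.82, 91.03,
                73.52, 89.05, 91.51, 90.41, 89.23, 53.67, 81.83, 73.65, 93.33, 83.06],
    "1DSNet": [94.03, 93.66, 91.29, 85.89, 68.23, 75.03, 93.36, 87.24, 69.91, 93.10,
               79.59, 86.30, 92.77, 92.24, 91.73, 60.46, 88.41, 76.83, 90.95, 83.55],
}

def mean_ap(per_class):
    """VOC-style mAP: the unweighted mean of the per-class APs."""
    return sum(per_class) / len(per_class)

def categories_won(method, table):
    """Number of classes where `method` has the strictly highest AP."""
    ours = table[method]
    others = [v for k, v in table.items() if k != method]
    return sum(all(ours[i] > row[i] for row in others) for i in range(len(ours)))

print(f"1DSNet mAP {mean_ap(TABLE2['1DSNet']):.2f}, best in {categories_won('1DSNet', TABLE2)} classes")
```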
Results on COCO2017. To further test the robustness of 1DSNet, we conduct experiments on the larger COCO2017 dataset and compare it with 10 methods. All models are trained on the COCO training set and evaluated on the test set. The results of FoveaBox [20], PISA [21], FCOS [22], and YoloF [23] are taken from the corresponding references.
FoveaBox, FCOS, and PISA were trained on four V100 GPUs and YoloF on eight GPUs; the remaining methods were trained on a single RTX 2080 Ti GPU. The comparison results are shown in Table 3. Among the competing methods trained on a single GPU, Yolo-v4, which uses the more advanced DSPDarkNet backbone, obtains the best result with 36.6 AP, followed by Mask R-CNN, a two-stage detector with the more general ResNet backbone, with 36.0 AP. Among the methods trained on four GPUs, PISA obtains the highest AP, 37.3, which we attribute to its hierarchical local ranking of sample importance and the reweighting of samples by these importance scores. YoloF, trained on eight GPUs, obtains the highest AP of all competing methods, 37.7. Our 1DSNet, trained on a single GPU, reaches 37.9 AP, the best value among all compared methods and 1.3 AP higher than Yolo-v4, which also uses a single GPU and DSPDarkNet53. Compared with PISA and YoloF, which rely on multi-GPU training, 1DSNet thus obtains better results with fewer GPU resources. 1DSNet also achieves the highest AP75, APs, and APm, at 41.4, 23.9, and 43.6, respectively. Despite these advantages, 1DSNet still has drawbacks: its AP50 (55.7) is lower than that of several competing methods, the highest being Yolo-v3 (57.9). We attribute this to the CIoU loss used in training: its additional aspect-ratio similarity term can yield higher IoU scores for positive predicted boxes, but this brings little benefit to AP50, which uses a low IoU threshold.
Table 3. Microsoft COCO 2017 dataset detection results. Bold values indicate the maximum of each column.
Method             Backbone       AP     AP50   AP75   APs    APm    APl
Faster R-CNN [15]  ResNet50       33.2   52.5   35.7   13.5   32.6   45.4
Mask R-CNN [25]    ResNet50       36.0   55.4   39.0   20.2   39.9   47.5
CenterNet [16]     ResNet50       32.5   51.0   34.8   13.8   38.6   46.8
FCOS [22]          ResNet50       37.1   55.9   39.8   21.3   41.0   47.8
FoveaBox [20]      ResNet50       37.1   57.2   39.5   21.6   41.4   49.1
Yolo-v2 [24]       DarkNet19      21.6   44.0   19.2   5.0    22.4   35.5
Yolo-v3 [19]       DarkNet53      33.0   57.9   34.4   18.3   35.4   41.9
Yolo-v4 [11]       DSPDarkNet53   36.6   57.4   38.4   9.7    34.1   55.0
PISA [21]          ResNet50       37.3   56.5   40.3   20.3   40.4   47.2
YoloF [23]         ResNet50       37.7   56.9   40.6   19.1   42.5   53.2
1DSNet             DSPDarkNet53   37.9   55.7   41.4   23.9   43.6   47.5
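The AP50 argument above rests on how COCO metrics treat localization quality: AP50 counts a match as correct whenever IoU ≥ 0.5, while the primary AP averages over the ten thresholds 0.50, 0.55, …, 0.95. The sketch below (an illustration of the metric definition, not the official evaluation code) shows that tightening a box from IoU 0.57 to 0.98 changes nothing at AP50 but turns it from a true positive at 2 of the 10 thresholds into one at all 10.

```python
def iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

# COCO's primary AP averages over these ten IoU thresholds; AP50 uses only the first.
COCO_IOU_THRESHOLDS = [round(0.50 + 0.05 * i, 2) for i in range(10)]

def thresholds_passed(overlap, thresholds=COCO_IOU_THRESHOLDS):
    """Number of thresholds at which a match with this IoU counts as a true positive."""
    return sum(overlap >= t for t in thresholds)

gt = (0, 0, 10, 10)
for pred in [(0, 0, 10, 5.7), (0, 0, 10, 7.2), (0, 0, 10, 9.8)]:
    overlap = iou(pred, gt)
    print(f"IoU {overlap:.2f}: TP at AP50? {overlap >= 0.5}; "
          f"TP at {thresholds_passed(overlap)}/10 AP thresholds")
```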
4 Conclusion
In this paper, we propose a novel one-dimensional feature supervision network for object detection (1DSNet). 1DSNet improves on previous attention mechanisms, which weight only the channel dimension of feature maps, by introducing a one-dimensional feature supervision module (1DSM). 1DSM uses one-dimensional feature vectors to weight the width and height dimensions of the feature map, jointly reinforcing important information. This additional supervision leads to better object detection performance than previous methods that focus only on the channel dimension. Second, a receptive field dilated spatial pyramid pooling (RFD-SPP) module is proposed to address the loss of contextual semantic information in the network. Integrating these two modules yields 1DSNet. Experimental results show that 1DSNet achieves competitive results compared with a variety of state-of-the-art methods under consistent experimental settings.
Acknowledgment. This work was supported by the Natural Science Foundation of Henan under Grant 232300421023.
References
1. Qiao, S., Chen, L.C., Yuille, A.: Detectors: detecting objects with recursive feature pyramid and switchable atrous convolution. In: 34th IEEE Conference on Computer Vision and Pattern Recognition, pp. 10213–10224. IEEE Press, Online (2021)
2. Tan, Z., Wang, J., Sun, X., Lin, M., Li, H.: Giraffedet: a heavy-neck paradigm for object detection. In: 10th International Conference on Learning Representations. Elsevier Press, Online (2022)
3. Lin, T.Y., Goyal, P., Girshick, R., He, K., Dollár, P.: Focal loss for dense object detection. In: 16th IEEE International Conference on Computer Vision, pp. 2980–2988. IEEE Press, Venice (2017)
4. Li, F., et al.: Lite detr: an interleaved multi-scale encoder for efficient detr. In: 36th IEEE Conference on Computer Vision and Pattern Recognition. IEEE Press, Vancouver (2023)
5. Vaswani, A., et al.: Attention is all you need. Adv. Neural. Inf. Process. 30 (2017)
6. Lee, H., Kim, H.E., Nam, H.: Srm: a style-based recalibration module for convolutional neural networks. In: 17th IEEE International Conference on Computer Vision, pp. 1854–1862. IEEE Press, Seoul (2019)
7. Deng, S., Liang, Z., Sun, L., Jia, K.: Vista: Boosting 3d object detection via dual cross-view spatial attention. In: 35th IEEE Conference on Computer Vision and Pattern Recognition, pp. 8448–8457. IEEE Press, New Orleans (2022)
8. Guo, M.H., Lu, C.Z., Hou, Q., Liu, Z.N., Cheng, M.M., Hu, S.M.: SegNeXt: rethinking convolutional attention design for semantic segmentation. In: 16th Advances in Neural Information Processing Systems. MIT Press, New Orleans (2022)
9. Hu, J., Shen, L., Sun, G.: Squeeze-and-excitation networks. In: 31st IEEE Conference on Computer Vision and Pattern Recognition, pp. 7132–7141. IEEE Press, Salt Lake City (2018)
10. He, K., Zhang, X., Ren, S., Sun, J.: Spatial pyramid pooling in deep convolutional networks for visual recognition. IEEE Trans. Pattern Anal. Mach. Intell. 37(9), 1904–1916 (2015)
11. Bochkovskiy, A., Wang, C.Y., Liao, H.Y.M.: Yolov4: optimal speed and accuracy of object detection. arXiv preprint arXiv:2004.10934 (2020)
L. Shen et al.
12. Ghiasi, G., Lin, T.Y., Le, Q.V.: DropBlock: a regularization method for convolutional networks. Adv. Neural. Inf. Process. 31 (2018)
13. Zheng, Z., Wang, P., Liu, W., Li, J., Ye, R., Ren, D.: Distance-IoU loss: faster and better learning for bounding box regression. In: 34th AAAI Conference on Artificial Intelligence, vol. 34, no. 07, pp. 12993–13000. AAAI Press, New York City (2020)
14. Misra, D.: Mish: a self regularized non-monotonic activation function. arXiv preprint arXiv:1908.08681 (2019)
15. Ren, S., He, K., Girshick, R., Sun, J.: Faster R-CNN: towards real-time object detection with region proposal networks. IEEE Trans. Pattern Anal. Mach. Intell. 39(6), 1137–1149 (2017)
16. Duan, K., Bai, S., Xie, L., Qi, H., Huang, Q., Tian, Q.: CenterNet: keypoint triplets for object detection. In: 17th IEEE International Conference on Computer Vision, pp. 6569–6578. IEEE Press, Seoul (2019)
17. Liu, W., Anguelov, D., Erhan, D., Szegedy, C., Reed, S., Fu, C.Y., Berg, A.C.: SSD: single shot multibox detector. In: 14th European Conference on Computer Vision, pp. 21–37. Springer Press, Amsterdam (2016)
18. Zhao, Q., et al.: M2Det: a single-shot object detector based on multi-level feature pyramid network. In: 33rd AAAI Conference on Artificial Intelligence, vol. 33, no. 1, pp. 9259–9266. AAAI Press, Hawaii (2019)
19. Redmon, J., Farhadi, A.: YOLOv3: an incremental improvement. arXiv preprint arXiv:1804.02767 (2018)
20. Kong, T., Sun, F., Liu, H., Jiang, Y., Li, L., Shi, J.: FoveaBox: beyond anchor-based object detection. IEEE Trans. Image Process. 29, 7389–7398 (2020)
21. Cao, Y., Chen, K., Loy, C.C., Lin, D.: Prime sample attention in object detection. In: 33rd IEEE Conference on Computer Vision and Pattern Recognition, pp. 11583–11591. IEEE Press, Seattle (2020)
22. Tian, Z., Shen, C., Chen, H., He, T.: FCOS: fully convolutional one-stage object detection. In: 17th IEEE International Conference on Computer Vision, pp. 9627–9636. IEEE Press, Seoul (2019)
23. Chen, Q., Wang, Y., Yang, T., Zhang, X., Cheng, J., Sun, J.: You only look one-level feature. In: 34th IEEE Conference on Computer Vision and Pattern Recognition, pp. 13039–13048. IEEE Press, Online (2021)
24. Redmon, J., Farhadi, A.: YOLO9000: better, faster, stronger. In: 30th IEEE Conference on Computer Vision and Pattern Recognition, pp. 7263–7271. IEEE Press, Hawaii (2017)
25. He, K., Gkioxari, G., Dollár, P., Girshick, R.: Mask R-CNN. In: 16th IEEE International Conference on Computer Vision, pp. 2961–2969. IEEE Press, Venice (2017)
Use the Detection Transformer as a Data Augmenter
Luping Wang and Bin Liu(B)
Research Center for Applied Mathematics and Machine Intelligence, Zhejiang Lab, Hangzhou 311121, China
{wangluping,liubin}@zhejianglab.com
Abstract. Detection Transformer (DETR) is an object detection model based on the Transformer architecture. In this paper, we demonstrate that it can also be used as a data augmenter. We term our approach DETR-assisted CutMix, or DeMix for short. DeMix builds on CutMix, a simple yet highly effective data augmentation technique that has gained popularity in recent years. CutMix improves model performance by cutting a patch from one image and pasting it onto another, yielding a new image. The label of this new example is the weighted average of the original labels, where the weight is proportional to the area of the patch. CutMix selects the patch to be cut at random. In contrast, DeMix deliberately selects a semantically rich patch, located by a pre-trained DETR. The label of the new image is specified in the same way as in CutMix. Experimental results on benchmark image classification datasets demonstrate that DeMix significantly outperforms prior data augmentation methods, including CutMix. Our code is available at https://github.com/ZJLAB-AMMI/DeMix.
Keywords: detection transformer · object detection · data augmentation · CutMix · image classification
1 Introduction
Data augmentation is a machine learning technique in which existing data are modified or transformed to create new data. The principle behind it is that even small changes to existing data can yield useful new training examples. For example, flipping an image horizontally creates a new training example that still represents the original object. In general, the goal of data augmentation is to increase the diversity of the training data while preserving its underlying structure. By introducing variations into the training data, models can learn to generalize better and perform well on unseen data. Data augmentation has been around for decades, but recent advances in deep learning have made it more widely used and more effective. It is commonly used in computer vision tasks such as image classification [8, 20, 24], object detection [21, 39, 40], and segmentation [5, 10, 31, 37], as well as in natural language processing [2, 6, 16]. Because it can significantly improve model performance with little additional effort or cost, data augmentation has been widely adopted and continues to be developed, which indicates its importance in modern machine learning.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 157–170, 2023. https://doi.org/10.1007/978-981-99-4761-4_14
L. Wang and B. Liu
There are many typical data augmentation methods, including flipping, rotating, scaling, cropping, adding noise, and changing color, which can be applied randomly or according to specific rules, depending on the task and the desired results [25]. Recent years have also witnessed notable developments in data augmentation techniques, among which CutMix [33], a method that combines parts of multiple images to create new training samples, is the focus of this work. CutMix is a simple yet highly effective data augmentation technique that has gained popularity in recent years. Given a pair of training examples A and B with labels y_A and y_B, CutMix selects a random crop region of A and replaces it with a patch of the same size cut from B, yielding a new example. The label of this new example is the weighted average of the original labels, where the weight is proportional to the area of the patch. Because the patch is selected at random, it may lie entirely in the background, entirely within the object, or in an area that mixes background and object, while its contribution to the label of the new image is fixed by its area alone. We argue that this is unreasonable. For example, if the patch is taken entirely from the background, its contribution to the label of the new image should be negligible, whereas CutMix makes its contribution proportional to its area. Motivated by this flaw of CutMix, we propose a knowledge-guided CutMix, where the knowledge comes from a pre-trained detection transformer (DETR) model. We term our approach DETR-assisted CutMix, or DeMix for short. DETR is an object detection model based on the Transformer architecture [3, 12, 22].
Unlike traditional object detection models, DETR formulates object detection directly as a set prediction problem, using a Transformer encoder to process the input image and a decoder to generate the set of objects, which achieves end-to-end object detection. DeMix takes advantage of the following desirable properties of DETR:
– Given an input image, a pre-trained DETR provides an estimate of the class, bounding box position, and confidence score of each object in the image;
– It can detect objects of different numbers and sizes simultaneously.
Because the pre-trained DETR model is trained on datasets different from our target dataset, the class labels it provides can be meaningless for our task, whereas the bounding box positions it provides are surprisingly informative for use in DeMix. DeMix cuts the image patch associated with one of the bounding boxes that DETR gives for an image example, resizes it, and pastes it onto a random crop region of another example to create a new example. The label of this new example is specified in the same way as in CutMix. In principle, DeMix refines CutMix by borrowing knowledge from a pre-trained DETR; even if the borrowed knowledge is completely inaccurate or meaningless, DeMix reduces to CutMix. Our major contributions can be summarized as follows:
– We demonstrate how DETR can be used as a tool for data augmentation, resulting in a new approach, DeMix;
– We evaluate DeMix on several fine-grained image classification datasets. Experimental results demonstrate that DeMix outperforms the competitor methods, including CutMix, in most settings;
– Our work sheds light on how to improve a general-purpose technique (here, data augmentation) using a special-purpose pre-trained model (here, DETR) without fine-tuning.
2 Related Works
In this section, we briefly introduce DETR and then review related work on data augmentation, especially CutMix, which is closely related to DeMix.
2.1 DETR
Detection Transformer (DETR) is an object detection model based on the Transformer architecture [3, 12, 22]. Unlike traditional object detection models, DETR formulates object detection directly as a set prediction problem, using a Transformer encoder to process the input image and a decoder to generate the set of objects, which achieves end-to-end object detection [3]. In DETR, a convolutional neural network (CNN) backbone first extracts a feature map from the image; the spatial positions of this map are flattened into a sequence of feature vectors, which is passed to the Transformer encoder so that features at all positions can interact through self-attention. The decoder then converts a fixed set of learned object queries, attending to the encoder output, into an object set that specifies the class, bounding box position, and confidence score of each object. DETR does not require hand-designed components such as non-maximum suppression or anchor boxes, and it can detect objects of different numbers and sizes simultaneously. Figure 1 shows detection results of a pre-trained DETR model on some image examples from the datasets used in our experiments. DETR was developed for object detection; in this paper, we demonstrate that it can also be used as a data augmenter. Specifically, we leverage DETR to locate a semantically rich patch associated with an object for use in CutMix.
2.2 Data Augmentation
Data augmentation is a machine learning technique that increases the size of a training dataset by generating new examples from existing ones. Its goal is to improve the generalization ability of machine learning models by introducing more variation into the training data.
There are many existing works on data augmentation, which can be broadly categorized into four groups as follows.
1. Traditional methods: These include commonly used techniques such as random cropping, resizing, flipping, rotating, color jittering, and adding noise [25]. The strength of traditional methods is that they are simple and easy to implement, and can effectively generate new samples from existing ones. However, they may not always work well for complex datasets or tasks, as they do not account for higher-level features and structures.
Fig. 1. Object detection with DETR
2. Generative methods: These use generative models, such as variational autoencoders (VAEs) and generative adversarial networks (GANs), to synthesize new data samples [4, 30, 32]. Their strength is that they can generate high-quality, diverse samples that resemble the real data distribution, which can help improve model generalization. However, they require substantial computational resources and training data to build the generative models, which can be a limitation in some cases.
3. Adversarial methods: These use adversarial attacks to perturb existing data samples to generate new ones [23, 36, 38]. Their strength is that they can generate realistic and diverse augmented samples, which can help improve model robustness. However, they require careful design and hyperparameter tuning to avoid overfitting.
4. Mixing-based methods: These create new training examples by mixing parts of multiple images, as in Mixup [34] and CutMix [33]. Mixup linearly combines data from different examples, while CutMix cuts and pastes image patches. These methods are much simpler to implement than generative and adversarial methods, while achieving much better performance than traditional methods.
In summary, each data augmentation method has its own strengths and limitations. DeMix, proposed here, is an algorithmic improvement to CutMix, which has gained popularity in recent years because it is simple to implement yet achieves much better performance than traditional methods. DeMix is as simple as CutMix, yet outperforms it significantly. In CutMix, the image patch to be cut is randomly selected, while the contribution of this patch to the label of the resulting new image is deterministically proportional to its area. As argued above, this is unreasonable.
For example, if the patch to be cut lies entirely in the background, its contribution to the label of the resulting new image should be negligible, yet CutMix can assign it a large weight. Several works have been proposed to address this flaw of CutMix. Their basic idea is to cut and paste a semantically rich patch instead of a random one; see, e.g., SaliencyMix [28], the class activation map (CAM) based method [35], KeepAugment [11], and SnapMix [15]. All these methods require image pre-processing, such as computing a saliency map or CAM, before data augmentation; in contrast, DeMix does not need any pre-processing before generating new image examples.
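For reference, the CutMix baseline discussed above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the reference implementation of [33]: images are single-channel H × W nested lists, λ is sampled from Beta(α, α) as in CutMix, and, following the notation of this paper, λ denotes the fraction of x_A's area replaced by the co-located patch from x_B. The function name `cutmix` and its parameters are hypothetical.

```python
import math
import random


def cutmix(x_a, y_a, x_b, y_b, alpha=1.0, rng=random):
    """Minimal CutMix sketch for single-channel images stored as H x W lists.

    Follows this paper's convention: lam is the fraction of x_a's area replaced
    by the co-located patch from x_b, and the label is (1 - lam) * y_a + lam * y_b.
    """
    h, w = len(x_a), len(x_a[0])
    lam = rng.betavariate(alpha, alpha)          # target area ratio of the box
    cut_h = int(h * math.sqrt(lam))
    cut_w = int(w * math.sqrt(lam))
    cy, cx = rng.randrange(h), rng.randrange(w)  # box centre, uniform over the image
    top, bottom = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, h)
    left, right = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, w)

    mixed = [row[:] for row in x_a]
    for i in range(top, bottom):
        for j in range(left, right):
            mixed[i][j] = x_b[i][j]              # patch comes from the same location in x_b

    # Recompute lam from the (possibly clipped) box so the label matches the pixels.
    lam = (bottom - top) * (right - left) / (h * w)
    label = [(1 - lam) * a + lam * b for a, b in zip(y_a, y_b)]
    return mixed, label, lam
```

Because the box is clipped at the image border, λ is recomputed from the pasted area, so the label weight always equals the fraction of pixels that came from x_B, regardless of what those pixels depict. This content-agnostic weighting is the behaviour DeMix aims to improve.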
3 DeMix: DETR Assisted CutMix
In this section, we describe DeMix in detail. Given a pair of image examples, DeMix uses two operations to generate a new image. The first operation employs a pre-trained DETR to identify the bounding box of each object in the source image, denoted x_B. The second operation cuts the image patch associated with one bounding box, resizes it, and pastes it onto a randomly selected crop region of the other image, termed the target image and denoted x_A, yielding the new image example. The label of the new image is a weighted average of the labels of the original images. Figure 2 illustrates the operations that make up DeMix.
Fig. 2. An illustration of the operations in DeMix. Given a target image x_A and a source image x_B, with one-hot labels y_A and y_B, respectively, DeMix starts by randomly selecting a crop region in x_A and locating the object region in x_B based on the object bounding box output by a pre-trained DETR. The 'Transforming' operation on x_B resizes and relocates the image patch located by DETR so that its size and position match those of the crop region in x_A.
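The two operations described above, together with the label rule, can be sketched as follows. This is an illustrative sketch only, under assumptions not specified in the text: images are single-channel H × W nested lists, the DETR box is already given in pixel coordinates, the crop region in x_A is sized as in CutMix but kept fully inside the image, and nearest-neighbour interpolation stands in for the transformation T(·). The names `demix` and `resize_nearest` are hypothetical.

```python
import math
import random


def resize_nearest(patch, out_h, out_w):
    """Nearest-neighbour resize of a non-empty H x W list-of-lists patch."""
    in_h, in_w = len(patch), len(patch[0])
    return [[patch[i * in_h // out_h][j * in_w // out_w] for j in range(out_w)]
            for i in range(out_h)]


def demix(x_a, y_a, x_b, y_b, detr_box, alpha=1.0, rng=random):
    """Sketch of the mixing rule: paste the DETR-located patch of x_b into a random crop of x_a.

    detr_box = (x0, y0, x1, y1) in pixel coordinates of x_b, with x1 > x0 and y1 > y0.
    """
    h, w = len(x_a), len(x_a[0])
    # Random crop region M_lambda in x_a, sized as in CutMix (lam ~ Beta(alpha, alpha))
    # but placed so that it lies fully inside the image (an assumption of this sketch).
    lam = rng.betavariate(alpha, alpha)
    cut_h = max(1, int(h * math.sqrt(lam)))
    cut_w = max(1, int(w * math.sqrt(lam)))
    top = rng.randrange(h - cut_h + 1)
    left = rng.randrange(w - cut_w + 1)

    # M_B ⊙ x_B: cut out the object patch located by DETR.
    x0, y0, x1, y1 = detr_box
    patch = [row[x0:x1] for row in x_b[y0:y1]]
    # T(.): resize the patch to the crop region's size (nearest neighbour here).
    patch = resize_nearest(patch, cut_h, cut_w)

    mixed = [row[:] for row in x_a]                   # start from x_A ...
    for i in range(cut_h):
        mixed[top + i][left:left + cut_w] = patch[i]  # ... and overwrite the crop region

    lam = cut_h * cut_w / (h * w)                     # exact area ratio of the crop region
    label = [(1 - lam) * a + lam * b for a, b in zip(y_a, y_b)]
    return mixed, label, lam
```

Because the crop region is kept inside x_A, λ equals the exact fraction of x_A that is overwritten, and the label weight of y_B depends only on that area, exactly as in CutMix; what changes is that the pasted pixels now come from the object region found by DETR.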
In mathematical terms, a new example is generated as follows:

$$\tilde{x} = (\mathbf{1} - M_\lambda) \odot x_A + T(M_B \odot x_B), \qquad \tilde{y} = (1 - \lambda)\, y_A + \lambda\, y_B, \qquad (1)$$

where y_A and y_B are the labels of x_A and x_B, respectively, in the form of one-hot vectors; M_λ and M_B are binary mask matrices of dimensions W × H, where W and H denote the width and height of the images; 0 < λ < 1 is a hyperparameter defined as the ratio of the area of the randomly selected crop region of x_A to the full area of x_A; M_λ is the binary mask that defines this crop region; M_B is the binary mask associated with the object bounding box given by DETR; ⊙ denotes element-wise multiplication; 1 is a matrix of appropriate size whose elements are all 1; and T(·) denotes a linear transformation that aligns the size and position of the cut patch with those of the crop region in x_A onto which the patch is pasted.
3.1 Discussions on the Algorithm Design of DeMix
In DeMix, the target and source images used to generate a new image are selected in the same way as in CutMix, namely, at random from the training set. Since DeMix uses a pre-trained DETR, the DETR model does not need to be trained; apart from DETR inference on the source images, the computational cost of DeMix is the same as that of CutMix. DeMix is based on CutMix and shares its "random cropping" and linear label generation operations. Random cropping has been widely used in deep learning (DL) data augmentation techniques such as CutMix [33] and Cutout [7]. Dropout, which is often used for DL regularization [1, 26], is in essence also a form of random cropping; however, instead of image patches, it drops neural network units. Empirically, random cropping is a simple yet effective strategy for DL regularization. The underlying mechanism is that DL models tend to perform "shortcut" learning, making predictions based on shortcut features embedded in the training dataset [9]. For instance, if all cows in the training set appear on grass, a DL model may associate grass features with the presence of a cow; if a test image shows a cow on a beach without grass, the model may predict that there is no cow. Random cropping lessens the model's reliance on shortcut features, reducing the probability of overfitting and enhancing its generalization ability. We display data augmentation results of MixUp [34], CutMix [33], SaliencyMix [28], and DeMix in Fig. 3.
As shown, DeMix cuts and resizes a patch that covers the whole object region of the source image, whereas SaliencyMix selects a patch corresponding to the most salient box region, which only covers part of the object. CutMix randomly selects the patch to be cut, which may contain only background. MixUp mixes the target and source images through linear combination, which may make the generated images locally ambiguous and unnatural, as noted in [33]. Because DeMix relies on a pre-trained DETR, its performance is closely tied to the quality of the DETR model. When DETR accurately locates the object region in the source image, DeMix selects a semantically rich patch to be cut and mixed. When DETR fails to locate the object region accurately, DeMix in principle reduces to CutMix, since the patch to be cut can then be regarded as randomly selected.
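The box-selection step and the fallback behaviour discussed above can be illustrated as follows. The paper does not specify which DETR box is used when several objects are detected, so this sketch assumes, hypothetically, that the most confident box above a threshold is chosen and that a random box is used when no detection is confident, mimicking the reduction to CutMix. It also assumes the standard DETR output format: per-query class logits whose last entry is the "no object" class, and boxes given as normalized (cx, cy, w, h). The threshold of 0.7, the random-box ranges, and all function names are illustrative choices.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cxcywh_to_xyxy(box, img_w, img_h):
    """Convert a normalized (cx, cy, w, h) box to a non-empty pixel box (x0, y0, x1, y1)."""
    cx, cy, bw, bh = box
    x0 = min(img_w - 1, max(0, int((cx - bw / 2) * img_w)))
    y0 = min(img_h - 1, max(0, int((cy - bh / 2) * img_h)))
    x1 = min(img_w, max(x0 + 1, math.ceil((cx + bw / 2) * img_w)))
    y1 = min(img_h, max(y0 + 1, math.ceil((cy + bh / 2) * img_h)))
    return x0, y0, x1, y1


def select_box(pred_logits, pred_boxes, img_w, img_h, threshold=0.7, rng=random):
    """Pick the most confident DETR box; fall back to a random box (CutMix-like)."""
    best_score, best_box = 0.0, None
    for logits, box in zip(pred_logits, pred_boxes):
        score = max(softmax(logits)[:-1])  # ignore the final "no object" class
        if score >= threshold and score > best_score:
            best_score, best_box = score, box
    if best_box is None:
        # No confident detection: use a random box, so DeMix behaves like CutMix.
        bw, bh = rng.uniform(0.1, 1.0), rng.uniform(0.1, 1.0)
        best_box = (rng.uniform(bw / 2, 1 - bw / 2), rng.uniform(bh / 2, 1 - bh / 2), bw, bh)
    return cxcywh_to_xyxy(best_box, img_w, img_h)
```

The class scores are used only to decide whether a detection is trustworthy; the predicted class itself is discarded, consistent with the observation that DETR's COCO labels can be meaningless for the target dataset while its boxes remain informative.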
Fig. 3. Comparison between related data augmentation techniques.
4 Experiments
In this section, we evaluate DeMix through image classification experiments involving different datasets and neural network architectures of different sizes. Related modern data augmentation techniques, listed in Subsect. 4.1, are used for performance comparison.
4.1 Experimental Setting
Datasets. We use three fine-grained image classification benchmarks: CUB-200-2011 [29], Stanford-Cars [17], and FGVC-Aircraft [19]. For simplicity, we refer to them as CUB, Cars, and Aircraft, respectively, in what follows.
Network Architectures. To evaluate our method comprehensively, we use 6 network architectures: ResNet18, ResNet34, ResNet50, and ResNet101 [13], InceptionV3 [27], and DenseNet121 [14].
Comparison Methods. For performance comparison, we use modern data augmentation techniques including CutMix [33], SaliencyMix [28], MixUp [34], and CutOut [7]. We also include a baseline, i.e., a model trained without any of these data augmentation techniques.
Further Details on Model Training. We used the open-source pre-trained DETR model detr_resnet50, available through PyTorch Hub. Its feature extractor was initialized with a ResNet50 pre-trained on the ImageNet-1K dataset [13], and the entire DETR model was trained on the MS-COCO dataset [18]. Figure 1 shows the detection performance of this DETR model on some image examples from our experiments. For the image classification tasks, we followed [15] in setting the training hyperparameters. Specifically, we used stochastic gradient descent (SGD) with momentum as the optimizer, with the momentum factor set to 0.9. The initial learning rate was 0.001 for the pre-trained weights and 0.01 for the other weights. When training from scratch, the initial learning rate for all trainable weights was 0.01. With pre-trained weights, the model was trained for 200 epochs, and the learning rate was decayed by a factor of 0.1 at epochs 80, 150, and 180; otherwise, the model was trained for 300 epochs, and the learning rate was decayed by a factor of 0.1 at epochs 150, 225, and 270.
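The step schedule above can be written as a small helper. This sketch assumes that "pre-trained weights" means the parameters inherited from the pre-trained network (e.g., the backbone) and "other weights" means newly initialized layers (e.g., the classification head), and that epochs are 0-indexed with each decay taking effect at the listed epoch; the function and argument names are hypothetical. In PyTorch, the same schedule would typically be realized with `torch.optim.SGD` using two parameter groups and `torch.optim.lr_scheduler.MultiStepLR` with `gamma=0.1`.

```python
def learning_rate(epoch, is_pretrained_param, finetuning):
    """Learning rate at a 0-indexed epoch under the step schedule of Subsect. 4.1.

    is_pretrained_param: the parameter comes from the pre-trained network (assumption).
    finetuning: True when starting from pre-trained weights, False when training from scratch.
    """
    if finetuning:
        base_lr = 0.001 if is_pretrained_param else 0.01
        milestones = (80, 150, 180)    # 200 epochs in total
    else:
        base_lr = 0.01                 # all trainable weights when training from scratch
        milestones = (150, 225, 270)   # 300 epochs in total
    n_decays = sum(epoch >= m for m in milestones)
    return base_lr * 0.1 ** n_decays
```

Keeping the two base rates as separate parameter groups lets the pre-trained weights move ten times more slowly than the new layers, while both groups share the same decay milestones.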
4.2 Experimental Results
We applied each data augmentation technique to image classification and evaluated it by the resulting classification accuracy. Tables 1 and 2 report the average top-1 accuracy for ResNet architectures of different depths. Results of the baseline, MixUp, CutOut, and CutMix are quoted directly from [15]. For the remaining methods, SaliencyMix and DeMix, we used the same training setting as [15] to ensure a fair comparison. DeMix achieves the best performance in most settings. The tables also show that SaliencyMix does not provide a clear gain over CutMix on these datasets, whereas DeMix does. We conjecture that this is because the discriminative parts of an image are not necessarily located in the salient region captured by SaliencyMix, but are located in the object region detected by the DETR model used in DeMix. Furthermore, DeMix performs more stably than the other methods as the network depth varies. For example, on CUB, CutMix and SaliencyMix perform poorly with the shallower ResNet18, falling below the baseline, but improve markedly with deeper architectures such as ResNet101. This may be because the samples generated by CutMix and SaliencyMix are noisier than those generated by DeMix, and deeper networks are better at handling noisy samples. Overall, DeMix outperforms the other methods in most settings regardless of network depth. We further evaluated DeMix with two other architectures, InceptionV3 and DenseNet121; the results are shown in Table 3. Again, DeMix achieves the best accuracy in five of the six settings.
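The comparisons above can be checked directly against the Table 1 values with a few lines of Python. The snippet below, whose names are illustrative, finds the best method in each column and the gain of DeMix over CutMix in percentage points. On these numbers, DeMix is best in four of the six columns (MixUp wins CUB with ResNet18 and SaliencyMix wins Aircraft with ResNet18), and it exceeds CutMix in all six.

```python
# Top-1 accuracy (%) from Table 1 (pre-trained ResNet18 / ResNet34).
COLUMNS = ["CUB-R18", "CUB-R34", "Cars-R18", "Cars-R34", "Aircraft-R18", "Aircraft-R34"]
TABLE1 = {
    "Baseline":    [82.35, 84.98, 91.15, 92.02, 87.80, 89.92],
    "MixUp":       [83.17, 85.22, 91.57, 93.28, 89.82, 91.02],
    "CutOut":      [80.54, 83.36, 91.83, 92.84, 88.58, 89.90],
    "CutMix":      [80.16, 85.69, 92.65, 93.61, 89.44, 91.26],
    "SaliencyMix": [80.69, 85.17, 93.17, 93.94, 90.61, 91.72],
    "DeMix":       [82.86, 86.69, 93.37, 94.49, 90.52, 93.10],
}


def best_per_column(table, columns):
    """Map each column name to the method with the highest accuracy in that column."""
    return {col: max(table, key=lambda m: table[m][i]) for i, col in enumerate(columns)}


def gain_over(table, columns, method, reference):
    """Accuracy difference (percentage points) of `method` over `reference` per column."""
    return {col: round(table[method][i] - table[reference][i], 2)
            for i, col in enumerate(columns)}


winners = best_per_column(TABLE1, COLUMNS)
gains = gain_over(TABLE1, COLUMNS, "DeMix", "CutMix")
print(winners)
print(gains)
```

The same two helpers can be applied unchanged to the values of Tables 2 to 4.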
In the previous experiments, we fine-tuned pre-trained classification models on the augmented training data. We also trained the classification models from scratch, which gives a clearer evaluation because it removes the influence of pre-training on the assessment of data augmentation. The results are shown in Table 4. DeMix performs best in most cases and comparably in the others. In particular, on CUB, DeMix gives a substantial improvement over CutMix and SaliencyMix. This may be because images of different classes in CUB differ only subtly from each other, which calls for higher-quality training samples, and DeMix can generate higher-quality samples than CutMix and SaliencyMix. To understand why DeMix performs better than the other methods in our experiments, we examine the class activation mapping (CAM) associated with each data augmentation technique. CAM is a technique used in deep-learning-based computer vision to visualize the image regions that are most important for a neural network's classification decision. It generates a heatmap that highlights the regions contributing most to the predicted class, allowing humans to better understand how the model makes its predictions.
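For completeness, the original CAM formulation, for a network that ends in global average pooling followed by a linear classifier, computes the map for class c as the weighted sum Σ_k w_{c,k} f_k(i, j) of the last convolutional feature maps f_k, using the classifier weights w_{c,k} of that class. The paper does not state which CAM variant was used for Fig. 4, so the sketch below only illustrates this basic formulation; the function name is hypothetical, and in practice the map would also be upsampled to the image size before being overlaid as a heatmap.

```python
def class_activation_map(feature_maps, fc_weights, class_idx):
    """CAM for a GAP + linear classifier: sum_k w[c][k] * f_k(i, j), scaled to [0, 1].

    feature_maps: K maps of size H x W from the last conv layer.
    fc_weights: C x K weights of the final linear layer.
    """
    weights = fc_weights[class_idx]
    h, w = len(feature_maps[0]), len(feature_maps[0][0])
    cam = [[sum(wk * fmap[i][j] for wk, fmap in zip(weights, feature_maps))
            for j in range(w)] for i in range(h)]
    lo = min(min(row) for row in cam)
    hi = max(max(row) for row in cam)
    span = (hi - lo) or 1.0   # avoid division by zero for a constant map
    return [[(v - lo) / span for v in row] for row in cam]
```

Normalizing to [0, 1] makes maps from different models comparable when they are visualized side by side, as in Fig. 4.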
Table 1. Top-1 accuracy (%) of each method for image classification tasks on datasets CUB, Cars, and Aircraft. The classification network is initialized by a pre-trained ResNet18 or ResNet34. The best result in each column is marked with an asterisk (*).

             CUB                 Cars                Aircraft
Method       ResNet18  ResNet34  ResNet18  ResNet34  ResNet18  ResNet34
Baseline     82.35     84.98     91.15     92.02     87.80     89.92
MixUp        83.17*    85.22     91.57     93.28     89.82     91.02
CutOut       80.54     83.36     91.83     92.84     88.58     89.90
CutMix       80.16     85.69     92.65     93.61     89.44     91.26
SaliencyMix  80.69     85.17     93.17     93.94     90.61*    91.72
DeMix        82.86     86.69*    93.37*    94.49*    90.52     93.10*
Table 2. Top-1 accuracy (%) of each method for image classification tasks on datasets CUB, Cars, and Aircraft. The classification network is initialized by a pre-trained ResNet50 or ResNet101. The best result in each column is marked with an asterisk (*).

             CUB                 Cars                Aircraft
Method       ResNet50  ResNet101 ResNet50  ResNet101 ResNet50  ResNet101
Baseline     85.49     85.62     93.04     93.09     91.07     91.59
MixUp        86.23     87.72     93.96     94.22     92.24     92.89
CutOut       83.55     84.70     93.76     94.16     91.23     91.79
CutMix       86.15     87.92     94.18     94.27     92.23     92.29
SaliencyMix  86.35     87.59     94.23     94.22     92.41     92.77
DeMix        86.93*    88.23*    94.59*    94.81*    93.76*    94.27*
Table 3. Top-1 accuracy (%) of each method for image classification tasks on datasets CUB, Cars, and Aircraft. The classification network is initialized by a pre-trained InceptionV3 or DenseNet121. The best result in each column is marked with an asterisk (*).

             CUB                       Cars                      Aircraft
Method       InceptionV3  DenseNet121  InceptionV3  DenseNet121  InceptionV3  DenseNet121
Baseline     82.22        84.23        93.22        93.16        91.81        92.08
MixUp        83.83        86.65        92.23        93.21        92.02        91.42
CutMix       84.31        86.11        93.94        94.25        92.71        93.40
SaliencyMix  85.07        85.26        94.18*       93.65        93.58        92.95
DeMix        85.12*       87.38*       94.13        94.29*       93.85*       94.27*
Table 4. Top-1 accuracy (%) of each method for image classification tasks on datasets CUB, Cars, and Aircraft. The classification network architecture is ResNet18 or ResNet50, as in Tables 1 and 2, but here the model is trained from scratch rather than initialized with pre-trained weights. The best result in each column is marked with an asterisk (*).

             CUB                 Cars                Aircraft
Method       ResNet18  ResNet50  ResNet18  ResNet50  ResNet18  ResNet50
Baseline     64.98     66.92     85.23     84.63     82.75     84.49
MixUp        67.63     72.39*    89.14     89.69     86.38     86.59
CutMix       60.03     65.28     89.11     90.13     85.60     86.95
SaliencyMix  65.60     67.03     88.53     89.81     86.95     88.81*
DeMix        70.00*    71.80     89.83*    91.72*    88.48*    88.66
We examined the CAMs of the classification models trained with the aid of different data augmentation techniques. Visualization results on 3 test images are shown in Fig. 4. With DeMix, the image regions that contribute most to the predicted class match the actual object regions more closely than with the other data augmentation techniques. For example, for the 2nd test image, the model trained with MixUp mainly uses the head of the bird and the model trained with SaliencyMix mainly uses its body, whereas the model trained with DeMix uses the head together with part of the body to predict the label. For the 3rd test image, it is even clearer that the model trained with DeMix relies on a more appropriate region when making class predictions.
4.3 Further Experiments
We conducted an experiment to investigate the influence of the hyperparameter λ, namely the area ratio of the randomly selected crop region, on the performance of DeMix. The results in Table 5 show that the performance of DeMix is not very sensitive to λ: as λ increases from 0.1 to 0.9, accuracy decreases gradually, by 3.00 points for ResNet18 (83.83 to 80.83) and 1.93 points for ResNet50 (87.33 to 85.40). Note that DeMix is built upon CutMix and, like CutMix, uses a random λ to enhance sample diversity in the augmented dataset, which has been shown to benefit the model's generalization. We also consider long-tailed recognition on the CUB dataset. Table 6 compares DeMix with CutMix and SaliencyMix under different imbalance ratios; for both ResNet18 and ResNet50, DeMix performs best.
Fig. 4. Visualizations of class activation mapping (CAM) on 3 test image examples.

Table 5. Influence of λ on the performance (top-1 accuracy, %) of DeMix on the image classification task using the CUB dataset.

λ          0.1    0.2    0.3    0.4    0.5    0.6    0.7    0.8    0.9
ResNet18   83.83  83.28  82.93  82.78  82.67  82.36  82.02  81.43  80.83
ResNet50   87.33  87.02  86.90  86.90  86.87  86.61  86.49  85.90  85.40
Table 6. Top-1 accuracy (%) comparison on the long-tailed CUB dataset with different imbalance ratios.

                 ResNet18          ResNet50
Imbalance ratio  50%      10%      50%      10%
CutMix           32.15    52.50    38.54    62.63
SaliencyMix      29.58    45.53    33.76    54.90
DeMix            32.97    52.99    38.89    62.91
5 Conclusion
In this paper, we demonstrated that a pre-trained object detection model, DETR, can be used as a tool for developing powerful data augmentation techniques. Specifically, we found that a DETR model pre-trained on the MS-COCO dataset [18] can locate semantically rich patches in images from other datasets, such as CUB-200-2011 [29], Stanford-Cars [17], and FGVC-Aircraft [19]. We then proposed DeMix, a novel data augmentation technique that employs DETR to assist CutMix in locating semantically rich patches to be cut and pasted. Experimental results on several fine-grained image classification tasks involving different network depths and architectures show that DeMix outperforms prior methods in most settings. Our work thus suggests (or confirms) that directly leveraging a pre-trained (large) model, without fine-tuning, is a promising direction for improving task-specific performance.
Acknowledgment. This work was supported by Exploratory Research Project (No. 2022RC0AN02) of Zhejiang Lab.
References
1. Baldi, P., Sadowski, P.J.: Understanding dropout. Adv. Neural Inf. Proc. Syst. 26 (2013)
2. Bayer, M., Kaufhold, M.A., Buchhold, B., Keller, M., Dallmeyer, J., Reuter, C.: Data augmentation in natural language processing: a novel text generation approach for long and short text classifiers. Int. J. Mach. Learn. Cybern. 14(1), 135–150 (2023)
3. Carion, N., Massa, F., Synnaeve, G., Usunier, N., Kirillov, A., Zagoruyko, S.: End-to-end object detection with transformers. In: Proceedings of the 16th European Conference on Computer Vision (ECCV 2020), Part I, pp. 213–229. Springer (2020)
4. Chadebec, C., Thibeau-Sutre, E., Burgos, N., Allassonniere, S.: Data augmentation in high dimensional low sample size setting using a geometry-based variational autoencoder. IEEE Trans. Pattern Anal. Mach. Intell. 45(3), 2879–2896 (2022)
5. Chaitanya, K., et al.: Semi-supervised task-driven data augmentation for medical image segmentation. Med. Image Anal. 68, 101934 (2021)
6. Chen, J., Shen, D., Chen, W., Yang, D.: HiddenCut: simple data augmentation for natural language understanding with better generalizability. In: Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing, vol. 1, pp. 4380–4390 (2021)
7. DeVries, T., Taylor, G.W.: Improved regularization of convolutional neural networks with cutout. arXiv preprint arXiv:1708.04552 (2017)
8. Fawzi, A., Samulowitz, H., Turaga, D., Frossard, P.: Adaptive data augmentation for image classification. In: 2016 IEEE International Conference on Image Processing (ICIP), pp. 3688–3692. IEEE (2016)
9. Geirhos, R., et al.: Shortcut learning in deep neural networks. Nat. Mach. Intell. 2(11), 665–673 (2020)
10. Ghiasi, G., et al.: Simple copy-paste is a strong data augmentation method for instance segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 2918–2928 (2021)
Use the Detection Transformer as a Data Augmenter
11. Gong, C., Wang, D., Li, M., Chandra, V., Liu, Q.: KeepAugment: a simple information-preserving data augmentation approach. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1055–1064 (2021)
12. Han, K., et al.: A survey on vision transformer. IEEE Trans. Pattern Anal. Mach. Intell. 45(1), 87–110 (2022)
13. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
14. Huang, G., Liu, Z., Van Der Maaten, L., Weinberger, K.Q.: Densely connected convolutional networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4700–4708 (2017)
15. Huang, S., Wang, X., Tao, D.: SnapMix: semantically proportional mixing for augmenting fine-grained data. In: Proceedings of the AAAI Conference on Artificial Intelligence, pp. 1628–1636 (2021)
16. Kafle, K., Yousefhussien, M., Kanan, C.: Data augmentation for visual question answering. In: Proceedings of the 10th International Conference on Natural Language Generation, pp. 198–202 (2017)
17. Krause, J., Stark, M., Deng, J., Fei-Fei, L.: 3D object representations for fine-grained categorization. In: Proceedings of the IEEE International Conference on Computer Vision Workshops, pp. 554–561 (2013)
18. Lin, T.Y., et al.: Microsoft COCO: common objects in context. In: 13th European Conference on Computer Vision (ECCV), pp. 740–755. Springer (2014)
19. Maji, S., Rahtu, E., Kannala, J., Blaschko, M., Vedaldi, A.: Fine-grained visual classification of aircraft. arXiv preprint arXiv:1306.5151 (2013)
20. Mikołajczyk, A., Grochowski, M.: Data augmentation for improving deep learning in image classification problem. In: 2018 International Interdisciplinary PhD Workshop (IIPhDW), pp. 117–122. IEEE (2018)
21. Montserrat, D.M., Lin, Q., Allebach, J., Delp, E.J.: Training object detection and recognition CNN models using data augmentation. Electron. Imaging 10, 27–36 (2017)
22. Parmar, N., et al.: Image transformer. In: International Conference on Machine Learning, pp. 4055–4064. PMLR (2018)
23. Peng, X., Tang, Z., Yang, F., Feris, R.S., Metaxas, D.: Jointly optimize data augmentation and network training: adversarial data augmentation in human pose estimation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2226–2234 (2018)
24. Perez, L., Wang, J.: The effectiveness of data augmentation in image classification using deep learning. arXiv preprint arXiv:1712.04621 (2017)
25. Shorten, C., Khoshgoftaar, T.M.: A survey on image data augmentation for deep learning. J. Big Data 6(1), 1–48 (2019)
26. Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., Salakhutdinov, R.: Dropout: a simple way to prevent neural networks from overfitting. J. Mach. Learn. Res. 15(1), 1929–1958 (2014)
27. Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., Wojna, Z.: Rethinking the inception architecture for computer vision. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2818–2826 (2016)
28. Uddin, A., Monira, M., Shin, W., Chung, T., Bae, S.H., et al.: SaliencyMix: a saliency guided data augmentation strategy for better regularization. arXiv preprint arXiv:2006.01791 (2020)
29. Wah, C., Branson, S., Welinder, P., Perona, P., Belongie, S.: The Caltech-UCSD Birds-200-2011 dataset (2011)
30. Wu, Z., Wang, S., Qian, Y., Yu, K.: Data augmentation using variational autoencoder for embedding based speaker verification. In: Interspeech, pp. 1163–1167 (2019)
L. Wang and B. Liu
31. Xu, J., Li, M., Zhu, Z.: Automatic data augmentation for 3D medical image segmentation. In: 23rd International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI), pp. 378–387. Springer (2020)
32. Yang, H., Zhou, Y.: IDA-GAN: a novel imbalanced data augmentation GAN. In: 2020 25th International Conference on Pattern Recognition (ICPR), pp. 8299–8305. IEEE (2021)
33. Yun, S., Han, D., Oh, S.J., Chun, S., Choe, J., Yoo, Y.: CutMix: regularization strategy to train strong classifiers with localizable features. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 6023–6032 (2019)
34. Zhang, H., Cisse, M., Dauphin, Y.N., Lopez-Paz, D.: mixup: beyond empirical risk minimization. arXiv preprint arXiv:1710.09412 (2017)
35. Zhang, W., Cao, Y.: A new data augmentation method of remote sensing dataset based on class activation map. J. Phys.: Conf. Ser. 1961, 012023 (2021)
36. Zhang, X., Wang, Z., Liu, D., Ling, Q.: DADA: deep adversarial data augmentation for extremely low data regime classification. In: IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 2807–2811. IEEE (2019)
37. Zhao, A., Balakrishnan, G., Durand, F., Guttag, J.V., Dalca, A.V.: Data augmentation using learned transformations for one-shot medical image segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 8543–8553 (2019)
38. Zhao, L., Liu, T., Peng, X., Metaxas, D.: Maximum-entropy adversarial data augmentation for improved generalization and robustness. Adv. Neural Inf. Process. Syst. 33, 14435–14447 (2020)
39. Zhong, Z., Zheng, L., Kang, G., Li, S., Yang, Y.: Random erasing data augmentation. In: Proceedings of the AAAI Conference on Artificial Intelligence, pp. 13001–13008 (2020)
40. Zoph, B., Cubuk, E.D., Ghiasi, G., Lin, T.Y., Shlens, J., Le, Q.V.: Learning data augmentation strategies for object detection. In: Proceedings of the European Conference on Computer Vision (ECCV), pp. 566–583. Springer (2020)
An Unsupervised Video Summarization Method Based on Multimodal Representation
Zhuo Lei (1,2; corresponding author), Qiang Yu (1), Lidan Shou (2), Shengquan Li (1), and Yunqing Mao (1)
1 City Cloud Technology (China) Co., Ltd., Hangzhou, China
2 Zhejiang University, Hangzhou, China
Abstract. A good video summary should convey the whole story and feature the most important content. However, the importance of video content is often subjective, and users should have the option to personalize the summary by using natural language to specify what is important to them. Moreover, existing methods usually rely only on visual cues to solve generic video summarization tasks. In contrast, this work introduces a single unsupervised multi-modal framework that addresses both generic and query-focused video summarization. We use a multi-head attention model to build the multi-modal representation. We then apply a Transformer-based model to learn frame scores, trained with representative, diversity, and reconstruction losses. In particular, we develop a novel representative loss that trains the model on both visual and semantic information. Our method outperforms previous state-of-the-art work on both generic and query-focused video summarization datasets.
Keywords: Video Summarization · Multi-modal Representation Learning · Unsupervised Learning
1 Introduction
User-generated videos are often long, poorly shot, and unedited, so video summarization techniques help users quickly glance at the important content. However, an effective video summary should also be personalizable: users should be able to indicate the video concepts they want in the summary through natural language queries. Existing video summarization methods mostly rely on visual cues and frame-score prediction, and they cannot be customized with natural language input. It is therefore desirable to have a single model that can handle both generic and query-focused video summarization.
We introduce a multi-modal unsupervised approach that uses both video and natural language text. Given an input video, it generates a video summary guided by either a user-defined natural language query or a system-generated description. First, we apply a bundling-center-based clustering method to temporally segment the original video into shots. Second, we use multi-modal attention to compute a fused representation of both inputs, and a frame-scoring Transformer to assign a score to each video frame. We develop a novel representative loss, combined with the commonly used diversity and reconstruction losses, to train the model. Finally, we use a knapsack algorithm to select the set of high-scoring shots that fits within the summary length.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 171–180, 2023. https://doi.org/10.1007/978-981-99-4761-4_15
Z. Lei et al.
Our main contributions are: 1) we show that the designed multi-modal representation and the representative loss yield an effective model that captures video concepts in the video summarization task; 2) we propose a multi-modal unsupervised model that addresses both generic and query-focused video summarization; and 3) we achieve new state-of-the-art results on both generic and query-focused datasets.
2 Related Work
Previous works used hand-designed features [2] or various heuristics [1] to represent the importance of video frames. TVSum [18] and SumMe [4] provided frame-level relevance scores annotated by multiple users. Their introduction led to many supervised methods [22, 23, 26] that extracted high-level semantic information. Fully convolutional sequence networks treated video summarization as a binary label prediction problem [15, 28] and showed competitive performance. Determinantal point processes and LSTM-based approaches modeled variable-range dependencies between frames [9, 22, 23]. Unsupervised methods, in contrast, commonly used reconstruction and diversity losses. Notable unsupervised approaches included an adversarial LSTM-based method [10], a generative adversarial network that learns from unpaired data [14], and a cycle-consistent learning objective [2].
On the other hand, [16] introduced the Query-Focused Video Summarization (QFVS) dataset for UT Egocentric videos [7]. It contains user-defined video summaries for a set of predefined concepts. [17] proposed a recurrent attention mechanism that attends to different video frames and shots with the user query as input. However, this mechanism was limited in modeling long-range temporal dependencies between video frames and could only work with predefined keyword-based queries. [29] also considered the task of video summarization using language queries. To address generic queries from different modalities, [31] introduced a graph convolutional network that is used in both its summary module and its intent module. [30] proposed a unified multi-modal transformer that covers different combinations of input modalities.
3 Framework
Our framework can generate both generic and query-focused video summaries in an unsupervised way, using the following steps. For image encoding, we extract visual features from the frames using the CLIP model. For text encoding, we generate a multi-sentence caption for the video using the Bi-Modal Transformer dense video captioning model [6]. We then encode the caption sentences (or a user-defined query) with the CLIP model (ViT). Next, we fuse the frame and text features with a multi-head attention model. To summarize the video, we first cluster the video frames into semantically coherent segments using a bundling center algorithm [8]. Second, we score the segments based on their visual and semantic relevance using a Frame-Scoring Transformer model. We develop a novel representative loss combining
An Unsupervised Video Summarization Method
with the commonly used diversity and reconstruction losses to train the model. Third, we select the most informative segments using a 0/1 knapsack algorithm [22] to form the final summary. Figure 1 illustrates our framework.
Fig. 1. Overview of the proposed framework. We encode the text (the caption sentences or a user-defined query) and the frames with the CLIP model. We then fuse the frame and text features with a multi-head attention model. Summarizing the video takes three steps. First, we cluster the video frames into semantically coherent segments using a bundling center algorithm. Second, we score the segments based on their visual and semantic relevance using a Frame-Scoring Transformer model, trained with a novel representative loss combined with the commonly used diversity and reconstruction losses. Third, we select the most informative segments to form the final summary.
3.1 Frame Representation
We use the CLIP model to encode the video frames, denoted by $I_i$, $i \in [1, \dots, n]$. We also use it to encode either a user-provided text query or a system-generated dense video caption. The caption consists of sentences denoted by $S_j$, $j \in [1, \dots, l]$. We then use the following settings for the query $Q$, key $K$, and value $V$:
$$Q = M_{image}(I_i), \quad i \in [1, \dots, n] \tag{1}$$
$$K, V = M_{text}(S_j), \quad j \in [1, \dots, l] \tag{2}$$
We use the multi-head attention mechanism from [20] to fuse the video and language features. We feed the frame and text embeddings into a multi-head attention model $M_{embedding}$, which produces a fused embedding $F_{embedding}$. This embedding captures the cross-modal interactions and long-term dependencies between the video and language modalities. The fusion process can be formulated as:
$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(head_1, \dots, head_h)\, W^O \tag{3}$$
$$\text{where } head_i = \mathrm{Attention}(Q W_i^Q, K W_i^K, V W_i^V) \tag{4}$$
$$\text{and } \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^T}{\sqrt{dim_K}}\right) V \tag{5}$$
where $W_i^Q$, $W_i^K$, $W_i^V$, and $W^O$ are learnable parameter matrices and $dim_K$ is the dimension of $K$. The fused embedding $F_i$ is the output of the multi-head attention module.
3.2 Video Temporal Segmentation
Unlike [8], we segment the video by clustering a deep affinity graph that incorporates multi-modal information. We model the video as a graph whose nodes are the video frames and whose edges are the pairwise cosine similarities. We construct and normalize the affinity graph $A(I, E)$, where $E = \{sim_{ij}\}$ are the edges between frames $i$ and $j$. We also use a time-constrained graph $A^{\chi}(I, E^{\chi})$ with a Gaussian function, where $E^{\chi} = \{sim^{\chi}_{ij}\}$ are the edges between frames within a time window. Each edge weight is computed as:
$$sim^{\chi}_{ij} = \frac{1}{\eta\sqrt{2\pi}}\, e^{-\frac{(i-j)^2}{2\eta^2}} \tag{6}$$
where $\eta$ is a control parameter that adjusts the temporal penalty and smoothness level. A temporally constrained graph $A^{\chi}_{tc}$ can therefore be formulated as:
$$A^{\chi}_{tc} = A \cdot A^{\chi} \tag{7}$$
Moreover, we use a bundling center to represent a cluster of similar frames instead of a single frame. We apply a dense-neighbor-based clustering method [24] to identify local clusters based on the edge connectivity of $A^{\chi}_{tc}$. A local cluster consists of elements that are locally similar to all of their neighbors, rather than being close to a single element. For more details on the cluster computation, please refer to [24].
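To make Eqs. (1)–(7) more concrete, here is a minimal NumPy sketch of the cross-modal fusion followed by the temporally constrained affinity graph. It is not the authors' implementation, and it rests on several assumptions. Random vectors stand in for CLIP features, and random matrices stand in for the learned projections $W_i^Q$, $W_i^K$, $W_i^V$, $W^O$. The normalization of $A$ is omitted. The product $A \cdot A^{\chi}$ in Eq. (7) is read as element-wise. The `window` argument is a hypothetical cutoff standing in for the time window mentioned above.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_modal_fusion(frames, sentences, w_q, w_k, w_v, w_o):
    """Eqs. (1)-(5): frame features act as queries, sentence features as keys/values.

    frames: (n, d) frame embeddings; sentences: (l, d) sentence embeddings.
    w_q, w_k, w_v: (h, d, d_k) per-head projections; w_o: (h * d_k, d) output projection.
    Returns the fused embeddings F with shape (n, d).
    """
    heads = []
    for wq, wk, wv in zip(w_q, w_k, w_v):
        q, k, v = frames @ wq, sentences @ wk, sentences @ wv
        weights = softmax(q @ k.T / np.sqrt(k.shape[-1]), axis=-1)  # (n, l)
        heads.append(weights @ v)  # (n, d_k)
    return np.concatenate(heads, axis=-1) @ w_o  # Concat(head_1..head_h) W^O


def temporal_affinity(fused, eta, window=None):
    """Eqs. (6)-(7): cosine affinity A combined with a Gaussian temporal kernel A^chi."""
    unit = fused / np.linalg.norm(fused, axis=1, keepdims=True)
    a = unit @ unit.T  # pairwise cosine similarities sim_ij
    idx = np.arange(len(fused))
    diff = idx[:, None] - idx[None, :]
    a_chi = np.exp(-diff**2 / (2 * eta**2)) / (eta * np.sqrt(2 * np.pi))
    if window is not None:
        a_chi[np.abs(diff) > window] = 0.0  # keep only edges inside the time window
    return a * a_chi  # element-wise product (assumed reading of A . A^chi)
```

The resulting matrix would then be passed to the bundling-center clustering of [24], which this sketch does not reproduce.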
3.3 Learning
Representative Loss. One of the challenges in video summarization is how to select the most significant content for the summary. There is no clear criterion for measuring the importance of video segments or frames, and even humans may disagree. A good summary should capture the main content of the video, so the chosen frames or segments should be representative of the rest of the video. Using $A^{\chi}(I, E^{\chi})$, we design a novel representative loss based on the TextRank method [11] from natural language processing. We construct a graph whose nodes are video frames and whose edges measure the similarity between frames. We compute the representative score of node $I_i$ as:
$$R_i = (1 - a) + a \sum_{I_j \in \mathrm{In}(I_i)} \frac{sim_{ij}}{\sum_{I_z \in \mathrm{Out}(I_j)} sim_{jz}}\, R_j \tag{8}$$
where $a$ is a damping factor with $0 \le a \le 1$, which controls the probability of jumping from a given node to another random node in the graph [11]. It can be seen as a random surfer switching between video contents. The algorithm starts with random values assigned to each node and iterates until convergence. We stop iterating when the difference between the importance scores of two consecutive iterations falls below a given threshold. After convergence, each node has a score representing the importance of its associated video frame. The final scores are independent of the initial values; only the number of iterations needed to converge may vary. We use a weighted cross-entropy loss:
$$L_{rep} = -\frac{1}{n}\sum_{i=1}^{n}\Big(W^R\,[R_i \log(R_i)] + (1 - W^R)\,[(1 - R_i)\log(1 - R_i)]\Big) \tag{9}$$
where $W^R$ is the weight assigned to the score $R_i$.
Reconstruction Loss. We select the keyframes based on the scores given by the Frame-Scoring Transformer. We use a decoder network with $1 \times 1$ convolution layers to obtain reconstructed feature vectors for the selected keyframes; these vectors have the same dimension as the original feature vectors.
The reconstruction loss $L_r$ is defined as the mean squared error between the reconstructed and original features of the selected keyframes:
$$L_r = \frac{1}{|X|}\sum_{i \in X} \|F_i - Z_i\|^2 \tag{10}$$
where $X$ is the set of selected keyframes and $Z_i$ denotes the reconstructed features.
Diversity Loss. We use a repelling regularization function [25] to encourage diversity among the selected keyframes. Following [15], we calculate the diversity loss $L_d$ as the mean pairwise cosine similarity between the selected keyframes:
$$L_d = \frac{1}{|X|(|X| - 1)}\sum_{i \in X}\sum_{j \in X, j \neq i} \frac{Z_i \cdot Z_j}{\|Z_i\|_2\, \|Z_j\|_2} \tag{11}$$
where $Z_i$ and $Z_j$ denote the reconstructed feature vectors of the $i$th and $j$th node. The final loss function is then:
$$L = \alpha \cdot L_{rep} + \beta \cdot L_d + \lambda \cdot L_r \tag{12}$$
where α, β, and λ are hyper-parameters that balance the three loss terms.
Frame-Scoring Transformer. To capture the correlation between frames, we use a Transformer that takes the image-text representations as input and outputs a score for each frame. Following [19], this module assigns relevance scores to the fused embeddings $F^{i}_{embedding}$. These embeddings are fed to both the encoder and decoder stacks of the Transformer. We also use positional encoding to add information about the relative positions of the tokens. This encoding is added to the input embeddings at the bottom of the encoder and decoder stacks.
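The representative scores of Eq. (8) and the loss terms of Eqs. (10)–(12) can be sketched in NumPy as follows. This is an illustrative sketch, not the authors' code. The damping value a = 0.85 (the value TextRank [11] adopts from PageRank), the clipping of negative similarities, and the loss weights are our assumptions. Eq. (9) is not included because the weighting $W^R$ is not specified.

```python
import numpy as np


def representative_scores(sim, a=0.85, tol=1e-6, max_iter=500, seed=0):
    """Eq. (8): TextRank-style representative score R_i for every frame node.

    sim: (n, n) frame-similarity matrix (e.g. built from A^chi).
    a: damping factor in [0, 1]; 0.85 is an assumed default.
    """
    w = np.clip(np.asarray(sim, dtype=float), 0.0, None)  # negative cosines clipped (assumption)
    np.fill_diagonal(w, 0.0)  # no self-loops
    out_strength = w.sum(axis=1, keepdims=True)  # sum of sim_jz over Out(I_j)
    out_strength[out_strength == 0] = 1.0  # isolated nodes: avoid division by zero
    transition = w / out_strength  # row j holds sim_jz / sum_z sim_jz
    r = np.random.default_rng(seed).random(len(w))  # random initial values
    for _ in range(max_iter):
        r_new = (1 - a) + a * (transition.T @ r)
        if np.max(np.abs(r_new - r)) < tol:  # stop when consecutive iterations agree
            return r_new
        r = r_new
    return r


def reconstruction_loss(f, z):
    """Eq. (10): mean over selected keyframes of ||F_i - Z_i||^2."""
    return float(np.mean(np.sum((f - z) ** 2, axis=1)))


def diversity_loss(z):
    """Eq. (11): mean pairwise cosine similarity among selected keyframes (i != j)."""
    unit = z / np.linalg.norm(z, axis=1, keepdims=True)
    cos = unit @ unit.T
    m = len(z)
    return float((cos.sum() - np.trace(cos)) / (m * (m - 1)))


def total_loss(l_rep, l_d, l_r, alpha=1.0, beta=1.0, lam=1.0):
    """Eq. (12); alpha, beta, lam are placeholder weights, not values from the paper."""
    return alpha * l_rep + beta * l_d + lam * l_r
```

On a star-shaped similarity graph, for instance, the hub frame receives the highest $R_i$. Orthogonal keyframe features give a diversity loss of 0, and identical features give 1.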
3.4 Summary Generation
Finally, we compute the relative importance score of each segment as the average importance of its frames. We average rather than sum, because summing would favor longer segments over shorter ones.
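Segment scoring and the 0/1 knapsack selection mentioned in Sect. 3 can be sketched as follows. The sketch assumes integer segment lengths (in frames) and an integer length budget. In practice the budget is often set to about 15% of the video length, a ratio commonly used on SumMe and TVSum but not stated in this section. The function names are hypothetical.

```python
import numpy as np


def segment_scores(frame_scores, boundaries):
    """Average frame score per segment; boundaries are (start, end) pairs, end exclusive."""
    return [float(np.mean(frame_scores[s:e])) for s, e in boundaries]


def knapsack_select(values, lengths, budget):
    """0/1 knapsack by dynamic programming.

    Returns the sorted indices of the segments that maximize the total score
    while keeping the total length within `budget`.
    """
    n = len(values)
    best = [[0.0] * (budget + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        value, length = values[i - 1], lengths[i - 1]
        for cap in range(budget + 1):
            best[i][cap] = best[i - 1][cap]  # option 1: skip segment i-1
            if length <= cap and best[i - 1][cap - length] + value > best[i][cap]:
                best[i][cap] = best[i - 1][cap - length] + value  # option 2: take it
    chosen, cap = [], budget
    for i in range(n, 0, -1):  # backtrack through the table
        if best[i][cap] != best[i - 1][cap]:
            chosen.append(i - 1)
            cap -= lengths[i - 1]
    return sorted(chosen)


if __name__ == "__main__":
    frame_scores = np.array([0.9, 0.9, 0.1, 0.1, 0.1, 0.5, 0.5, 0.5])
    segments = [(0, 2), (2, 5), (5, 8)]
    values = segment_scores(frame_scores, segments)
    lengths = [end - start for start, end in segments]
    print(knapsack_select(values, lengths, budget=5))  # prints [0, 2]
```

Because segment values are averages, a long segment is selected only if its mean frame score justifies the share of the budget it consumes.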
4 Experiments
Datasets. We evaluate our method on two generic video summarization datasets (TVSum and SumMe) and one query-focused video summarization dataset (QFVS). SumMe contains 25 videos (1–6 min long) recorded from both first-person and third-person views. Each video has been annotated by 15 to 18 users with key fragments. The dataset also provides multiple user summaries of different lengths (5–15% of the original video duration). TVSum contains 50 videos (1–11 min long) from 10 categories of the TRECVid MED task (5 videos per category). Each video has been annotated by 20 users with importance scores ranging from 1 (not important) to 5 (very important). QFVS provides ground-truth generic summaries for four videos from the UT Egocentric dataset. The summaries are constructed by dividing each video into shots and asking three users to select the relevant shots; the final ground truth is the average of the annotations from all users.
Evaluation Approach. We evaluate our method with the F-score metric, which is used by most video summarization methods. It measures the similarity between a machine-generated summary and a user-defined summary by computing their overlap. Following previous work, we compare the generated summary with each of the available user summaries for the same video and compute an F-score for each comparison. We then take the maximum of these F-scores as the final F-score for the video.
Data Configuration. Following previous works [21, 22], we evaluate our approach in three data settings: standard, augment, and transfer. In the standard setting, we use the same dataset for training and testing. In the augment setting, we combine the training set of one dataset with all the data from the other three datasets. In the transfer setting, we train a model on three datasets and test it on the fourth, unseen dataset. In [22], the four datasets are SumMe, TVSum, OVP, and YouTube. All data splits can be obtained from [21, 22].
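The evaluation protocol described above can be sketched as follows. The sketch assumes that both summaries are given as per-frame binary masks of equal length; the exact overlap computation used by the authors may differ.

```python
import numpy as np


def f_score(pred, user):
    """Harmonic mean of precision and recall of the frame overlap between two binary masks."""
    pred = np.asarray(pred, dtype=bool)
    user = np.asarray(user, dtype=bool)
    overlap = np.logical_and(pred, user).sum()
    if overlap == 0:
        return 0.0
    precision = overlap / pred.sum()  # fraction of the generated summary that the user kept
    recall = overlap / user.sum()  # fraction of the user summary that was recovered
    return float(2 * precision * recall / (precision + recall))


def video_f_score(pred, user_summaries):
    """Compare against every available user summary and keep the maximum F-score."""
    return max(f_score(pred, user) for user in user_summaries)
```

For example, a generated mask [1, 1, 0, 0] scores 0.5 against the user mask [1, 0, 1, 0] and 1.0 against [1, 1, 0, 0]. With both summaries available, the reported video score is therefore 1.0.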
Table 1. Comparison of our approach with state-of-the-art methods on SumMe and TVSum for generic video summarization. The numbers are F-scores (%). The best result in each column is marked with an asterisk (*).
| Method | SumMe standard | SumMe augment | SumMe transfer | TVSum standard | TVSum augment | TVSum transfer |
|---|---|---|---|---|---|---|
| Park et al. (SumGraph) [13] | 51.4 | 52.9 | 48.7 | 63.9 | 65.8 | 60.5 |
| Zhang et al. (LSTM) [23] | 38.6 | 42.9 | 41.8 | 54.7 | 59.6 | 58.7 |
| Rochan et al. (SUM-FCN) [15] | 47.5 | 51.1 | 44.1 | 56.8 | 59.2 | 58.2 |
| Rochan et al. (SUM-DeepLab) [15] | 48.8 | 50.2 | 45.0 | 58.4 | 59.1 | 57.4 |
| GoogleNet + Transformer [12] | 51.6 | 53.5 | 49.4 | 64.2 | 66.3 | 61.3 |
| ResNet + Transformer [12] | 52.8 | 54.9 | 50.3 | 65.0 | 67.5 | 62.8 |
| CLIP-Image + Transformer [12] | 53.5 | 55.3 | 51.0 | 65.5 | 68.1 | 63.4 |
| CLIP-It [12] | 54.2 | 56.4* | 51.9 | 66.3 | 69.0* | 65.5 |
| He et al. [5] | 46.0 | 47.0 | 44.5 | 58.5 | 58.9 | 57.8 |
| Zhou et al. [28] | 42.1 | 43.9 | 42.6 | 58.1 | 59.8 | 58.9 |
| Ours | 54.7* | 56.1 | 52.0* | 66.7* | 68.5 | 65.8* |
Results and Discussion. Our method achieves state-of-the-art performance on generic video summarization in the standard and transfer settings. Table 1 shows that we obtain the best results in these two settings on both SumMe and TVSum. In the augment setting, our results are slightly below the best ones, achieved by CLIP-It [12]. These results show that our method identifies important segments and produces informative summaries that are closer to human perception than those of other methods. Notably, in the standard and transfer settings our results also surpass those of the supervised methods. We argue that there are no standard rules defining what constitutes important content for video summarization. Human-generated summaries may therefore vary widely owing to different perceptions and personal experiences. We suspect that the training data for user-generated video summarization are insufficient for supervised methods, so the learned models cannot capture the properties of video summarization. For example, TVSum contains only 10 categories of videos. This should make it easy for supervised methods to learn video structure, yet our method still outperforms them. We also believe that structural analysis is crucial for video summarization. A segment-based summary is more consistent with human perception, and more reasonable, than a frame-based one because, unlike keyframes, segments contain motion information. In fact, participants were not required to select a segment-based summary. Hence, we believe that a good unsupervised method is better suited to generic video summarization.
For query-focused video summarization, we follow the same experimental setup as [17], which involves four rounds of experiments. In each round, one video is used for testing, one for validation, and the remaining two for training. Table 2 shows that
our approach achieves the best average performance on QFVS, with an F-score of 54.7%. The qualitative analysis is similar to the one for generic video summarization above. In theory, however, we would expect our method to gain more in the query-focused setting than in the generic one. A possible reason for this gap is that the text queries and the video frames are not fully aligned.
Table 2. Comparison of our approach with state-of-the-art methods on QFVS for query-focused video summarization. The numbers are F-scores (%). The best result in each column is marked with an asterisk (*).
| Method | Video 1 | Video 2 | Video 3 | Video 4 | Average |
|---|---|---|---|---|---|
| SeqDPP [3] | 36.6 | 43.7 | 25.2 | 18.1 | 31.0 |
| SH-DPP [16] | 35.7 | 42.7 | 36.5 | 18.6 | 33.4 |
| QFVS [17] | 48.7 | 41.7 | 56.5 | 30.0 | 44.2 |
| CLIP-Image + Query + bi-LSTM [12] | 54.5 | 48.6 | 62.8 | 38.6 | 51.1 |
| ResNet + Query + bi-LSTM [12] | 55.2 | 51.0 | 64.3 | 39.5 | 52.5 |
| CLIP-It [12] | 57.1 | 53.6* | 66.0 | 41.4* | 54.6 |
| Ours | 57.2* | 53.4 | 66.4* | 41.4* | 54.7* |
5 Conclusion
We propose a novel unsupervised multi-modal method for both generic and query-focused video summarization. Using CLIP features, we employ a multi-head attention model to fuse the video frames and the text queries or captions into a joint representation. We first segment the video based on an affinity graph computed from the fused embeddings. We then design a representative loss that, combined with the commonly used diversity and reconstruction losses, trains a Transformer-based model to assign relative scores to video frames. Finally, we select the segments with the highest scores to generate video summaries. We demonstrate that our method outperforms state-of-the-art methods.
References
1. Cai, S., Zuo, W., Davis, L., Zhang, L.: Weakly-supervised video summarization using variational encoder-decoder and web prior. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018, pp. 193–210. Springer, Cham (2018)
2. Avila, S., Lopes, A., da Luz Jr., A., Araújo, A.: VSUMM: a mechanism designed to produce static video summaries and a novel evaluation method. Pattern Recognit. Lett. 32(1), 56–68 (2011)
3. Gong, B., Chao, W., Grauman, K., Sha, F.: Diverse sequential subset selection for supervised video summarization. In: NIPS 2014, pp. 2069–2077. MIT Press, Cambridge, MA, USA (2014)
4. Gygli, M., Grabner, H., Riemenschneider, H., Van Gool, L.: Creating summaries from user videos. In: Fleet, D., Pajdla, T., Schiele, B., Tuytelaars, T. (eds.) ECCV 2014. LNCS, vol. 8695, pp. 505–520. Springer, Cham (2014). https://doi.org/10.1007/978-3-319-10584-0_33
5. He, X., et al.: Unsupervised video summarization with attentive conditional generative adversarial networks. In: ACM Multimedia 2019 (2019)
6. Iashin, V., Rahtu, E.: A better use of audio-visual cues: dense video captioning with bi-modal transformer. In: BMVC 2020 (2020)
7. Lee, Y., Ghosh, J., Grauman, K.: Discovering important people and objects for egocentric video summarization. In: CVPR 2012, pp. 1346–1353 (2012)
8. Lei, Z., Sun, K., Zhang, Q., Qiu, G.: User video summarization based on joint visual and semantic affinity graph. In: ACM Multimedia 2016 Workshop on Vision and Language Integration Meets Multimedia Fusion, pp. 45–52 (2016)
9. Li, Y., Wang, L., Yang, T., Gong, B.: How local is the local diversity? Reinforcing sequential determinantal point processes with dynamic ground sets for supervised video summarization. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11212, pp. 156–174. Springer, Cham (2018)
10. Mahasseni, B., Lam, M., Todorovic, S.: Unsupervised video summarization with adversarial LSTM networks. In: CVPR 2017, pp. 2982–2991 (2017)
11. Mihalcea, R., Tarau, P.: TextRank: bringing order into text. In: Conference on Empirical Methods in Natural Language Processing 2004, pp. 404–411 (2004)
12. Narasimhan, M., Rohrbach, A., Darrell, T.: CLIP-It! Language-guided video summarization. In: Ranzato, M., Beygelzimer, A., Dauphin, Y., Liang, P., Vaughan, J. (eds.) NIPS 2021, pp. 13988–14000 (2021)
13. Park, J., Lee, J., Kim, I.-J., Sohn, K.: SumGraph: video summarization via recursive graph modeling. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12370, pp. 647–663. Springer, Cham (2020)
14. Rochan, M., Wang, Y.: Video summarization by learning from unpaired data. In: CVPR 2019, pp. 7902–7911 (2019)
15. Rochan, M., Ye, L., Wang, Y.: Video summarization using fully convolutional sequence networks. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11216, pp. 358–374. Springer, Cham (2018)
16. Sharghi, A., Gong, B., Shah, M.: Query-focused extractive video summarization. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9912, pp. 3–19. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46484-8_1
17. Sharghi, A., Laurel, J., Gong, B.: Query-focused video summarization: dataset, evaluation, and a memory network based approach. In: CVPR 2017, pp. 2127–2136 (2017)
18. Song, Y., Vallmitjana, J., Stent, A., Jaimes, A.: TVSum: summarizing web videos using titles. In: CVPR 2015, pp. 5179–5187 (2015)
19. Vaswani, A., et al.: Attention is all you need. In: Guyon, I., et al. (eds.) NIPS 2017, pp. 5998–6008, Long Beach, CA, USA (2017)
20. Yuan, L., Tay, F., Li, P., Zhou, L., Feng, J.: Cycle-SUM: cycle-consistent adversarial LSTM networks for unsupervised video summarization. In: AAAI 2019 (2019)
21. Zhang, K., Chao, W., Sha, F., Grauman, K.: Summary transfer: exemplar-based subset selection for video summarization. In: CVPR 2016 (2016)
22. Zhang, K., Chao, W.-L., Sha, F., Grauman, K.: Video summarization with long short-term memory. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9911, pp. 766–782. Springer, Cham (2016)
23. Zhang, K., Grauman, K., Sha, F.: Retrospective encoders for video summarization. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11212, pp. 391–408. Springer, Cham (2018)
24. Zhang, Q., Qiu, G.: Bundling centre for landmark image discovery. In: ICMR 2015, pp. 179–186 (2015)
25. Zhao, B., Li, X., Lu, X.: Hierarchical recurrent neural network for video summarization. In: Liu, Q., et al. (eds.) ACM MM 2017, pp. 863–871 (2017)
26. Zhao, B., Li, X., Lu, X.: TTH-RNN: tensor-train hierarchical recurrent neural network for video summarization. IEEE Trans. Ind. Electron. 68(4), 3629–3637 (2021)
27. Zhou, K., Qiao, Y.: Deep reinforcement learning for unsupervised video summarization with diversity-representativeness reward. In: AAAI 2017 (2017)
28. Saquil, Y., Chen, D., He, Y., Li, C., Yang, Y.-L.: Multiple pairwise ranking networks for personalized video summarization. In: ICCV 2021, pp. 1718–1727 (2021)
29. Lei, J., Berg, T.L., Bansal, M.: QVHighlights: detecting moments and highlights in videos via natural language queries. In: NIPS 2021 (2021)
30. Liu, Y., Li, S., Wu, Y., Chen, C., Shan, Y., Qie, X.: UMT: unified multimodal transformers for joint video moment retrieval and highlight detection. In: CVPR 2022 (2022)
31. Wu, G., Lin, J., Silva, C.T.: IntentVizor: towards generic query guided interactive video summarization. In: CVPR 2022, pp. 10503–10512 (2022)
An Industrial Defect Detection Network with Fine-Grained Supervision and Adaptive Contrast Enhancement
Ying Xiang (1,2,3), Hu Yifan (1,2,3), Fu Xuzhou (1,2,3), Gao Jie (1,2,3), and Liu Zhiqiang (1,2,3; corresponding author)
1 College of Intelligence and Computing, Tianjin University, Tianjin 300350, China
{xiang.ying,huyifan,fuxuzhou,gaojie,tjubeisong}@tju.edu.cn
2 Tianjin Key Laboratory of Cognitive Computing and Application, Tianjin 300350, China
3 Tianjin Key Laboratory of Advanced Networking, Tianjin 300350, China
Abstract. Object detection approaches based on deep learning have achieved remarkable results in Automated Defect Inspection (ADI). However, some challenges remain. Firstly, many defect objects lack semantic information. This causes the convolutional kernels to capture simple gray-level anomalies, which makes it difficult for the network to distinguish defects from background interference. Secondly, poor image quality, such as low contrast, makes it even more difficult for convolutional networks to extract effective features. To address these issues, this paper proposes a one-stage defect detection network with two additions. The first is fine-grained supervision that enables the model to learn richer features beyond grayscale. The second is an image enhancement module that adaptively adjusts image contrast and highlights object areas. Comprehensive experiments show that our method significantly outperforms the baseline and other defect detection methods while maintaining high efficiency, which confirms the effectiveness of our model.
Keywords: Deep Learning · Convolutional Neural Networks · Object Detection · Automated Defect Inspection
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 181–192, 2023. https://doi.org/10.1007/978-981-99-4761-4_16
1 Introduction
Automatically locating and recognizing surface defects of industrial products, such as scratches, crazing, or damage, from digital images is a crucial step in computer vision-based defect inspection systems. These systems are widely applied in fields such as the automobile industry, machine manufacturing, and light industry. With the rapid development of deep learning, defect detection approaches have gained a significant boost in performance over traditional methods.

Despite these advances, defect detection in industrial images still faces several challenges. Firstly, poor image quality resulting from low contrast and noise makes some defect objects hardly visible. Secondly, most defects lack rich semantics in their interior, and these weak-semantic characteristics can reduce the accuracy of deep learning-based detectors. Due to the
lack of salient features inside the objects and the irregular or non-fixed shapes of most objects, convolution kernels tend to capture gray-value changes in the image. In other words, the detector is more sensitive to image regions whose gray values differ from the background. It is therefore more easily affected by imaging problems such as speckles and noise, which produce incorrect high-response areas and ultimately degrade detection accuracy. Similarly, grayscale variation is the main feature that distinguishes foreground from background regions. Performance therefore drops when image contrast is low, because the convolution kernels struggle to extract effective features without clear gray-value changes.
Some works address these issues by fusing multi-level features [5, 17] or integrating attention mechanisms [3, 16], and have achieved promising results. However, these methods essentially improve the model's ability to classify existing features rather than enabling it to learn richer feature representations, so detection performance is still influenced by background interference. Furthermore, the image quality issue persists and makes it difficult for models to extract suitable features from industrial images.
To enable the model to learn more diverse features, one possible solution is to convert the original bounding box labels into finer-grained, pixel-level labels. For instance, the model can be trained to label all pixels within a bounding box as defects. This prompts it to also attend to areas inside the box whose gray values are similar to the background, and so to learn features beyond gray-value changes. Based on this idea, we propose a feature enhancement module with fine-grained supervision that suppresses incorrect responses in the feature map and improves the accuracy of the detector.
Meanwhile, to address the imaging quality problem, this paper proposes an image enhancement module that adaptively adjusts the contrast between the foreground and background regions of the image. The module highlights the object area and facilitates feature extraction in the subsequent network, which ultimately improves detection performance. In summary, the main contributions of this paper are as follows:
1) We identify the weak-semantic characteristics of industrial defects and design a fine-grained supervision mechanism to enhance the feature extraction ability of deep learning detectors.
2) We propose an adaptive foreground enhancement method to address low contrast and inconspicuous object areas in the original images.
3) We propose a one-stage defect detection model that integrates the above modules, and validate its effectiveness in addressing the aforementioned issues through comprehensive experiments.
2 Related Works
Deep learning-based defect detection methods can be broadly divided into the following categories.
Feature Fusion. The visual features of defects manifest primarily in grayscale, color, and texture, and defects may vary considerably in scale. Feature fusion has therefore proven to be an effective strategy for improving detection accuracy. For example, He et al. [5] fuse the last layers of each residual block in ResNet [4], and the resulting
feature map is subsequently passed to an RPN [12] to generate more precise region proposals. In addition, Zeng et al. [17] propose a feature fusion method that balances features across stages. These approaches have been shown to detect defects effectively. However, they may not enable the network to learn better feature representations, which leaves the model susceptible to background interference.
Attention Mechanism. Using an attention mechanism to improve defect detection is a natural idea, since many studies have demonstrated its effectiveness in visual tasks [13]. Cheng et al. [3] introduce a difference channel attention module to reduce information loss during network inference. Additionally, Yeung et al. [16] propose a fused-attention network that enhances channel and spatial feature information for precise localization and classification of defects. Our approach can also be regarded as an attention mechanism. With the introduction of fine-grained supervision, however, our model not only enhances existing information but also encodes additional, richer information. Similar attention methods have also been adopted in other works [1, 6, 15] to improve performance.
Prior Knowledge. Some detection scenarios inherently contain strong prior knowledge. For example, in [10], the camera angle is fixed, so objects of a certain category always appear in a specific area of the image. Liu et al. therefore propose an unsupervised clustering algorithm that maps each image category to the coarse regions of its objects before the image is sent to the network. Applying such prior knowledge can greatly improve detection accuracy. However, not all detection scenarios offer such prior knowledge, so these methods may lack generalization ability.
3 Proposed Method
To address the issues discussed in Sect. 1, this paper proposes a one-stage defect detection network with two components:
- fine-grained supervision, which encourages the model to encode more diverse features;
- an image enhancement module, which adaptively enhances the contrast between the foreground and background.

Our model is built on RetinaNet [8], which uses focal loss to handle the imbalance between positive and negative samples. RetinaNet also includes FPN [7] for feature fusion, a technique commonly used in recent defect detection methods, as mentioned in Sect. 2. The overall structure is shown in Fig. 1.
Fig. 1. The overall structure.
3.1 Fine-grained Supervision
Label Generation. To encourage the network to learn more diverse features, the bounding box labels of the object detection task can be transformed into pixel-level supervision. Building on this idea, we introduce a fine-grained label called the Hard Mask, defined in Eq. (1):
$$M_{hard}(i, j) = \begin{cases} 1, & (i, j) \in \Omega \\ 0, & (i, j) \notin \Omega \end{cases} \tag{1}$$
where $M_{hard}$ is a binary mask of the same size as the input image, $(i, j)$ denotes the coordinates of any pixel in $M_{hard}$, and $\Omega$ is the set of all pixels inside the bounding boxes.
The Hard Mask is simple and efficient, but in some cases it may lose information. Fig. 2(a) illustrates this: the Hard Mask fails to distinguish overlapping objects, which makes it difficult to supervise the network effectively with the information from these regions. To preserve this information, the label must be able to assign different values to the same pixel for different objects.

We therefore propose an alternative label generation method, the Soft Mask, which uses Gaussian distributions to generate mask labels. For each bounding box in the image, a two-dimensional Gaussian distribution is derived. Its mean is the center point of the bounding box, and its standard deviations are half of the box's width and height, respectively. For convenience, the horizontal and vertical distributions are treated as independent. The distribution derived from each bounding box is normalized to the range (0, 1]. Finally, the Soft Mask takes the pixel-wise maximum over all distributions:
$$M_{soft}(i, j) = \max\left(M_{soft}^{0}(i, j), M_{soft}^{1}(i, j), \ldots, M_{soft}^{n}(i, j)\right)$$
$$M_{soft}^{k}(i, j) = \exp\left(-\frac{1}{2}\left[\frac{(i - \mu_x^k)^2}{(\sigma_x^k)^2} + \frac{(j - \mu_y^k)^2}{(\sigma_y^k)^2}\right]\right), \quad k = 0, 1, \ldots, n$$
$$\sigma_x^k = \frac{w_k}{2}, \quad \sigma_y^k = \frac{h_k}{2}, \quad \mu_x^k = x_k, \quad \mu_y^k = y_k \tag{2}$$
where $(x_k, y_k)$ denotes the center coordinates of the k-th bounding box, and $w_k$ and $h_k$ denote its width and height. $\mu$ and $\sigma$ denote the mean and standard deviation of the Gaussian distribution, respectively, and $\exp(\cdot)$ is the exponential function with base e. The Soft Mask generates mask labels that retain and differentiate each object in the supervision information, which avoids the information loss of the Hard Mask, as shown in Fig. 2(b).
Note that, unlike the Hard Mask, the Soft Mask also requires the network to attend to the context surrounding the objects. To investigate whether this behavior enhances model performance, we also introduce the Half-soft Mask, where $\Omega$ again denotes the set of pixels inside the bounding boxes:
$$M_{half\text{-}soft}(i, j) = \begin{cases} M_{hard}(i, j), & (i, j) \in \Omega \\ M_{soft}(i, j), & (i, j) \notin \Omega \end{cases} \tag{3}$$
The three mask labels are compared in Sect. 4.
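The three label types of Eqs. (1)–(3) can be sketched in NumPy as follows. This is a minimal illustration, not the authors' code, and it rests on three assumptions:
- boxes are given as (x_center, y_center, w, h) in pixels;
- i is the horizontal (column) index and j the vertical (row) index, matching the pairing of i with μ_x in Eq. (2);
- all function names are hypothetical.

```python
import numpy as np

def _pixel_grids(height, width):
    # i: horizontal (column) coordinate, j: vertical (row) coordinate
    j, i = np.mgrid[0:height, 0:width].astype(float)
    return i, j

def inside_boxes(boxes, height, width):
    """Boolean map of Omega, the set of pixels inside any bounding box."""
    i, j = _pixel_grids(height, width)
    inside = np.zeros((height, width), dtype=bool)
    for xc, yc, w, h in boxes:
        inside |= (np.abs(i - xc) <= w / 2) & (np.abs(j - yc) <= h / 2)
    return inside

def hard_mask(boxes, height, width):
    """Eq. (1): 1 inside any box, 0 elsewhere."""
    return inside_boxes(boxes, height, width).astype(float)

def soft_mask(boxes, height, width):
    """Eq. (2): pixel-wise max over one unnormalized 2D Gaussian per box."""
    i, j = _pixel_grids(height, width)
    mask = np.zeros((height, width))
    for xc, yc, w, h in boxes:
        sigma_x, sigma_y = w / 2, h / 2  # standard deviations = half width / half height
        g = np.exp(-0.5 * (((i - xc) / sigma_x) ** 2 + ((j - yc) / sigma_y) ** 2))
        mask = np.maximum(mask, g)  # each Gaussian peaks at exactly 1 at its box center
    return mask

def half_soft_mask(boxes, height, width):
    """Eq. (3): hard labels inside the boxes, soft (context) labels outside."""
    inside = inside_boxes(boxes, height, width)
    return np.where(inside, 1.0, soft_mask(boxes, height, width))

# Two overlapping boxes on a 32 x 32 image
boxes = [(10, 10, 8, 6), (14, 12, 8, 6)]
hm = hard_mask(boxes, 32, 32)
sm = soft_mask(boxes, 32, 32)
hsm = half_soft_mask(boxes, 32, 32)
```

In the overlap of the two boxes, `hm` is uniformly 1, so the two objects cannot be told apart. `sm` still peaks separately at each box center and dips between them, which is the property the text attributes to the Soft Mask.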
(a) Challenges of the Hard Mask in determining position and shape information for overlapping objects.
(b) The Soft Mask preserves the information of overlapping objects.
Fig. 2. Comparisons of the Hard Mask and the Soft Mask.
Applying the Fine-grained Labels. One obvious issue with the fine-grained labels is that they are generated from bounding box annotations. A model that learns the coordinates of a bounding box can therefore trivially classify every pixel inside it as foreground. To prevent the original bounding box annotations from weakening the supervisory effect of the fine-grained labels, we apply these labels to an independent branch called the Feature Enhancement Module (FEM), as shown in Fig. 3(a). In contrast, Fig. 3(b) shows the alternative approach of incorporating the mask prediction directly into the detector's head. We compare these two approaches in Sect. 4.
(a) Applying the fine-grained labels to an independent branch.
(b) Applying the fine-grained labels to the detector's head.
Fig. 3. Two approaches of applying the fine-grained labels.
The FEM generates a mask that filters out incorrect responses in the backbone's feature map. This prevents those responses from being passed to deeper layers of the network, where they would cause erroneous foreground-background classification. The process can be formalized as Eq. (4):
$$F_E = f_{FEM}(F_O), \quad f_{FEM}(F_O) = F_O \otimes \Phi(F_O) \tag{4}$$
where $f_{FEM}(\cdot)$ is the function fitted by the FEM, $F_E$ and $F_O$ denote the enhanced feature map and the input feature map, respectively, and $\otimes$ denotes element-wise multiplication. $\Phi(\cdot)$, the mask-generating function, is implemented by a set of dilated convolutions with different dilation rates, inspired by [18].
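Eq. (4) can be sketched in NumPy as below. The paper only states that $\Phi(\cdot)$ is built from dilated convolutions with several dilation rates. The following details are assumptions made for illustration:
- depthwise 3 × 3 kernels at rates (1, 2, 4);
- summation of all responses into a single spatial logit map;
- a sigmoid that keeps the mask in (0, 1).

All names are hypothetical.

```python
import numpy as np

def dilated_conv2d(x, kernel, dilation):
    """'Same'-size 2D cross-correlation of an (H, W) map with a (k, k) kernel (zero padding)."""
    k = kernel.shape[0]
    pad = dilation * (k // 2)
    xp = np.pad(x, pad)  # constant (zero) padding by default
    h, w = x.shape
    out = np.zeros((h, w))
    for u in range(k):
        for v in range(k):
            # tap (u, v) samples the input at offset ((u - k//2) * dilation, (v - k//2) * dilation)
            out += kernel[u, v] * xp[u * dilation:u * dilation + h, v * dilation:v * dilation + w]
    return out

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def phi(feat, kernels):
    """Hypothetical Phi: depthwise dilated convolutions at several rates, summed into one
    spatial logit map and squashed to (0, 1). kernels maps rate -> (C, 3, 3) weights."""
    logits = np.zeros(feat.shape[1:])
    for rate, weights in kernels.items():
        for c in range(feat.shape[0]):
            logits += dilated_conv2d(feat[c], weights[c], rate)
    return sigmoid(logits)

def fem(feat, kernels):
    """Eq. (4): F_E = F_O (element-wise) Phi(F_O); the (H, W) mask is broadcast over channels."""
    mask = phi(feat, kernels)
    return feat * mask[None, :, :], mask

rng = np.random.default_rng(0)
feat = rng.standard_normal((4, 16, 16))  # F_O with C = 4 channels on a 16 x 16 grid
kernels = {rate: 0.1 * rng.standard_normal((4, 3, 3)) for rate in (1, 2, 4)}
enhanced, mask = fem(feat, kernels)
```

During training, the output of `phi` would be compared via the mask loss with a label from Eqs. (1)–(3) resized to the feature-map resolution. Because the mask multiplies the features, regions the mask pushes toward 0 are suppressed before reaching deeper layers.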
3.2 Adaptive Image Enhancement
To address the problem of low contrast between foreground and background, this paper proposes an Image Enhancement Module (IEM), illustrated in Fig. 4. Given an original input image, the IEM serves as a learnable pre-processing module. It enhances the saliency of the defect area by generating a weight map, and the enhanced image is obtained by pixel-wise multiplication of the original image and the weight map.

To improve the IEM's perception of the foreground area, we introduce additional weak supervision, referred to as Counting Supervision: the IEM is required to predict the number of defect objects of each category in the image. If the IEM can predict these numbers correctly, it can perceive the foreground area in the image. For the reasons given in Sect. 3.1.2, this supervision should also be applied in an independent branch.

The IEM is therefore designed as a dual-branch structure:
- the Counting Supervision Branch (CSB) receives the counting supervision;
- the Weight Map Generation Branch (WMGB) generates the weight map and receives, through feature fusion, the foreground information encoded in the CSB.
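The counting targets and the final enhancement step of the IEM reduce to a few lines. This sketch is not the authors' code and makes several assumptions:
- class labels are 0-based integers in the order Cr, In, Pa, Pi, Ro, Sc;
- "normalized" counts in Eq. (6) means dividing by the total number of objects in the image;
- the weight map is given, since the WMGB itself is not modeled here.

The function names are hypothetical.

```python
import numpy as np

def count_targets(labels, num_classes):
    """Per-class object counts for one image, normalized to sum to 1."""
    counts = np.bincount(np.asarray(labels, dtype=int), minlength=num_classes).astype(float)
    total = counts.sum()
    return counts / total if total > 0 else counts

def apply_weight_map(image, weight_map):
    """Enhanced image = original image multiplied pixel-wise by the weight map."""
    if image.ndim == 3:  # (H, W, C): broadcast the (H, W) map over channels
        return image * weight_map[..., None]
    return image * weight_map

# Six NEU-DET classes; an image with two crazing (0) objects and one patches (2) object
y = count_targets([0, 0, 2], num_classes=6)
```

Treating the counts as a distribution (here roughly [0.67, 0, 0.33, 0, 0, 0]) is what makes the K-L divergence of Eq. (6) a natural loss for the CSB.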
Fig. 4. The structure of the IEM.
Experiments in [14] have shown that convolutional neural networks are more sensitive to high-frequency information in images, and can still make correct predictions when only high-frequency information is provided as input. Therefore, to improve the efficiency of the IEM, we introduce Octave Convolution [2] to decompose the input feature $F \in \mathbb{R}^{h \times w \times c}$ into two parts:
- high-frequency components $F^H \in \mathbb{R}^{h \times w \times (1-\alpha)c}$;
- low-frequency components $F^L \in \mathbb{R}^{\frac{h}{2} \times \frac{w}{2} \times \alpha c}$.

Here $\alpha \in [0, 1]$ is the decomposition coefficient, set to 0.5 in our experiments. Only $F^H$ is sent to the CSB to predict the object counts, while $F^L$ is used to generate the weight map and enhance the input image. To further reduce computational cost, both branches employ separable convolutions.
3.3 Loss Function
The model's overall loss function comprises four parts: counting loss, mask loss, classification loss, and bounding box loss. They are combined in Eq. (5):
$$L = L_{count} + L_{mask} + L_{cls} + L_{bbox} \tag{5}$$
where the mask loss is the mean squared error, the classification loss follows RetinaNet in using Focal Loss, and the bounding box loss is the L1 loss. Counting the objects of each class in an image is akin to predicting a probability distribution rather than a set of independent values. We therefore use the K-L divergence as the counting loss, as shown in Eq. (6), where C is the number of classes, and $p_c$ and $y_c$ are the predicted and true normalized values of the c-th class:
$$L_{count} = \sum_{c=1}^{C} y_c \log \frac{y_c}{p_c} \tag{6}$$
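The four loss terms of Eq. (5) can be sketched as follows. Only the loss types come from the text. The rest are simplifying assumptions:
- reductions: mean for the mask loss, sums for the others;
- focal-loss parameters α = 0.25 and γ = 2, the defaults reported for RetinaNet;
- RetinaNet's normalization by the number of positive anchors is omitted.

```python
import numpy as np

def count_loss(y, p, eps=1e-8):
    """Eq. (6): KL(y || p) = sum_c y_c log(y_c / p_c); classes with y_c = 0 contribute 0."""
    y = np.asarray(y, dtype=float)
    p = np.clip(np.asarray(p, dtype=float), eps, None)
    safe_y = np.where(y > 0, y, 1.0)  # avoid log(0); those entries are masked out below
    return float(np.sum(np.where(y > 0, y * np.log(safe_y / p), 0.0)))

def mask_loss(pred, target):
    """Mean squared error between the predicted mask and the fine-grained label."""
    pred, target = np.asarray(pred, dtype=float), np.asarray(target, dtype=float)
    return float(np.mean((pred - target) ** 2))

def focal_loss(prob, target, alpha=0.25, gamma=2.0, eps=1e-8):
    """Binary focal loss summed over anchor/class entries; target entries are 0 or 1."""
    prob = np.clip(np.asarray(prob, dtype=float), eps, 1.0 - eps)
    target = np.asarray(target)
    p_t = np.where(target == 1, prob, 1.0 - prob)
    alpha_t = np.where(target == 1, alpha, 1.0 - alpha)
    return float(np.sum(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)))

def bbox_loss(pred, target):
    """L1 loss between predicted and target box offsets."""
    return float(np.sum(np.abs(np.asarray(pred, dtype=float) - np.asarray(target, dtype=float))))

def total_loss(l_count, l_mask, l_cls, l_bbox):
    """Eq. (5): unweighted sum of the four terms."""
    return l_count + l_mask + l_cls + l_bbox
```

The K-L term is zero when the predicted class distribution matches the normalized ground-truth counts. It grows as probability mass is placed on the wrong classes, for example log 2 when all objects belong to one class but the prediction is split evenly between two.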
4 Experiments
4.1 Experimental Settings and Evaluation Metrics
Our experiments were conducted on the NEU-DET dataset [5], which contains 1800 steel surface-defect images and about 5000 defect objects, all captured in a real production environment. The dataset is randomly split into training and test sets at a ratio of 7:3, giving 1260 images for training and 540 for testing. The objects are annotated in VOC format and divided into six classes: Crazing (Cr), Inclusion (In), Patches (Pa), Pitted surface (Pi), Rolled-in scales (Ro), and Scratches (Sc). Each class contains 300 images, of which 210 are used for training and 90 for testing.

The model is based on RetinaNet and uses a stochastic gradient descent optimizer with a learning rate of $10^{-3}$, a momentum of 0.9, and a weight decay of $10^{-4}$. The hyperparameters of the other compared methods follow their default settings. In all experiments, input images are resized to 300 × 300, and data augmentation is performed by random flipping. All experiments are conducted on an NVIDIA TITAN RTX GPU.

Mean average precision (mAP) is used as the evaluation metric. mAP is the mean, over all categories, of the average precision across recall levels at a given intersection over union (IoU) threshold, usually 0.5. mAP can also be averaged over a series of IoU thresholds. To distinguish the two, we refer to the former as mAP and the latter as AP.
4.2 Comparison of the Generation of Fine-grained Labels
In Sect. 3.1.1, we proposed three mask labels: Hard Mask, Soft Mask, and Half-Soft Mask. The experiment results are presented in Table 1. FEM improves performance with all three mask labels, which indicates that introducing fine-grained supervision into defect detection is effective.
Table 1. Results of the comparison of three mask labels.
Method | mAP | Cr. | In. | Pa. | Pi. | Ro. | Sc.
RetinaNet | 69.8 | 39.1 | 72.7 | 88.1 | 80.9 | 62.3 | 75.6
RetinaNet + FEM (Hard Mask) | 71.2 | 40.1 | 69.9 | 87.3 | 84.2 | 64.3 | 81.6
RetinaNet + FEM (Soft Mask) | 71.8 | 39.3 | 74.8 | 87.6 | 83.7 | 67.4 | 77.7
RetinaNet + FEM (Half-Soft Mask) | 71.0 | 43.0 | 69.9 | 86.4 | 85.1 | 65.9 | 75.0
Among the three mask labels, the Soft Mask-based FEM achieves the best performance (71.8 mAP), indicating that the Soft Mask provides more effective supervision than the Hard Mask. Moreover, the Half-Soft Mask performs no better than the Hard Mask (71.0 vs. 71.2 mAP), which suggests that the additional context information is not the main reason for the Soft Mask's improvement.
In Sect. 3.1.1, we explained that the Soft Mask prevents the loss of supervision information caused by overlapping objects, which leads to improved detection performance. To validate this point further, we divided the test set into two mutually exclusive subsets according to whether an image contains overlapping objects, denoted $T_{overlap}$ and $T_{non\text{-}overlap}$, respectively. We then conducted comparative experiments on both subsets. The results are presented in Table 2.
Table 2. Results of the comparison on $T_{overlap}$ (mAP1) and $T_{non\text{-}overlap}$ (mAP2).
Method | mAP1 | mAP2
RetinaNet + FEM (Hard Mask) | 62.8 | 73.4
RetinaNet + FEM (Soft Mask) | 65.4 | 73.2
The above results show that the Soft Mask clearly outperforms the Hard Mask on images with overlapping objects (65.4 vs. 62.8 mAP), while their performance is roughly comparable on the other images (73.2 vs. 73.4 mAP). This suggests that the improvement of the Soft Mask stems mainly from its ability to retain supervision information on overlapping objects, which confirms our earlier observation.
4.3 Comparison of the Approaches of Applying Fine-grained Labels
As mentioned in Sect. 3.1.2, the original bounding box annotations may weaken the supervisory effect of the fine-grained labels, so these labels need to be applied to an independent branch. To verify this, we compared the two ways of applying the fine-grained labels proposed in Sect. 3.1.2:
- FEM_A applies the labels to an independent branch, as in this paper;
- FEM_B applies them in parallel with the classification and regression tasks.

Since this paper adopts the FPN structure, the mask labels for FEM_B must be resized to fit the different scales of the multi-level feature maps output by FPN. The experiment results are shown in Table 3.
Table 3. Results of the comparison of FEM_A and FEM_B.
Method | mAP | Cr. | In. | Pa. | Pi. | Ro. | Sc.
RetinaNet | 69.8 | 39.1 | 72.7 | 88.1 | 80.9 | 62.3 | 75.6
RetinaNet + FEM_A | 71.8 | 39.3 | 74.8 | 87.6 | 83.7 | 67.4 | 77.7
RetinaNet + FEM_B | 69.7 | 38.8 | 70.2 | 86.5 | 84.0 | 58.6 | 80.0
The experiments show that the additional supervision in FEM_B has no significant supervisory effect (69.7 mAP vs. 69.8 for the baseline), while FEM_A with an
independent branch brings significant performance improvements (71.8 mAP), thereby validating our hypothesis.
4.4 Comparison with Other Object Detection Methods
In addition to the baseline RetinaNet, we implemented and compared several common methods for defect detection tasks. The results are shown in Table 4, and Table 5 compares precision at different IoU thresholds. Since the NEU-DET dataset does not provide an official training/testing split, some results may differ from those reported in [5].
Table 4. Comparison with other object detection methods. Methods marked with "*" are specifically designed for defect detection.
Method | Backbone | mAP | Cr. | In. | Pa. | Pi. | Ro. | Sc.
Faster R-CNN (2015) [12] | ResNet50 | 70.2 | 37.6 | 71.6 | 87.7 | 78.3 | 57.9 | 88.1
SSD (2016) [9] | VGG16 | 68.6 | 38.0 | 72.7 | 86.2 | 78.2 | 65.8 | 71.1
FPN (2017) [7] | ResNet50 | 72.4 | 38.0 | 75.7 | 88.7 | 84.8 | 63.2 | 84.1
YOLOv3 (2018) [11] | Darknet53 | 64.6 | 18.1 | 72.1 | 84.6 | 74.2 | 52.3 | 85.9
DDN* (2019) [5] | ResNet50 | 70.7 | 34.4 | 71.3 | 86.9 | 77.7 | 66.9 | 86.9
D-DETR (2021) [19] | ResNet50 | 65.0 | 27.9 | 67.1 | 80.7 | 79.1 | 50.7 | 84.2
FANet* (2022) [16] | ResNet50 | 67.1 | 32.3 | 69.4 | 84.0 | 76.5 | 56.9 | 83.7
RetinaNet (2017) [8] | ResNet50 | 69.8 | 39.1 | 72.7 | 88.1 | 80.9 | 62.3 | 75.6
RetinaNet + IEM | ResNet50 | 71.8 | 39.7 | 73.2 | 88.1 | 86.5 | 63.0 | 80.4
RetinaNet + FEM | ResNet50 | 71.8 | 39.3 | 74.8 | 87.6 | 83.7 | 67.4 | 77.7
RetinaNet + IEM + FEM | ResNet50 | 72.6 | 38.2 | 71.3 | 88.0 | 86.7 | 68.7 | 83.0
Table 4 shows that incorporating either IEM or FEM improves defect detection performance for almost all categories, indicating the effectiveness of the proposed modules. Table 5 shows that our model achieves the highest AP50 (72.6) among the compared methods and ties FPN for the highest AP (35.5), although D-DETR achieves a higher AP75 (32.4 vs. 30.0). The Faster R-CNN model with FPN integration achieves accuracy similar to ours. As a two-stage detection framework, however, it is less efficient: 14.0 vs. 17.6 FPS, with 41.2M vs. 36.5M parameters.
Table 5. Comparison results under different IoU thresholds. AP is the mean of mAP over IoU thresholds from 0.5 to 0.95; AP50 and AP75 are the mAP at IoU thresholds of 0.5 and 0.75, respectively.
Method | AP | AP50 | AP75 | Params (M) | FPS
Faster R-CNN (2015) | 34.6 | 70.2 | 29.1 | 32.8 | 6.2
SSD (2016) | 32.2 | 68.6 | 25.7 | 24.4 | 21.2
FPN (2017) | 35.5 | 72.4 | 27.5 | 41.2 | 14.0
YOLOv3 (2018) | 32.2 | 64.6 | 26.7 | 61.6 | 22.8
DDN (2019) | 32.7 | 70.7 | 25.4 | 165.0 | 14.4
D-DETR (2021) | 34.8 | 65.0 | 32.4 | 39.8 | 13.4
FANet (2022) | 34.4 | 67.1 | 30.0 | 32.0 | 16.2
RetinaNet (2017) | 34.5 | 69.8 | 28.6 | 36.2 | 18.1
Ours | 35.5 | 72.6 | 30.0 | 36.5 | 17.6
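The metrics used in Sect. 4.1 and Table 5 can be illustrated with a short NumPy sketch. It assumes (x1, y1, x2, y2) boxes and the all-point interpolated average precision used by PASCAL VOC since 2010. The paper does not state which interpolation it uses, and the function names are hypothetical.

In this setup, mAP would average `average_precision` over the six classes at a single IoU threshold. AP would further average mAP over the ten thresholds 0.50, 0.55, ..., 0.95.

```python
import numpy as np

IOU_THRESHOLDS = np.linspace(0.5, 0.95, 10)  # 0.50, 0.55, ..., 0.95

def iou(a, b):
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def average_precision(scores, is_tp, num_gt):
    """All-point interpolated AP for one class. is_tp[k] says whether detection k
    matched a not-yet-matched ground truth at the chosen IoU threshold."""
    order = np.argsort(-np.asarray(scores, dtype=float))
    tp = np.asarray(is_tp, dtype=float)[order]
    tp_cum, fp_cum = np.cumsum(tp), np.cumsum(1.0 - tp)
    recall = tp_cum / max(num_gt, 1)
    precision = tp_cum / np.maximum(tp_cum + fp_cum, 1e-12)
    r = np.concatenate(([0.0], recall, [1.0]))
    p = np.concatenate(([0.0], precision, [0.0]))
    for k in range(len(p) - 2, -1, -1):  # precision envelope (non-increasing)
        p[k] = max(p[k], p[k + 1])
    steps = np.where(r[1:] != r[:-1])[0]  # indices where recall changes
    return float(np.sum((r[steps + 1] - r[steps]) * p[steps + 1]))

def ap_over_thresholds(map_at):
    """AP as in Table 5: mean of mAP(t) over the ten IoU thresholds."""
    return float(np.mean([map_at(t) for t in IOU_THRESHOLDS]))
```

Because AP averages over stricter thresholds as well, it rewards tight localization. This is why a method can lead on AP50 while trailing on AP75, as D-DETR and our model do in Table 5.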
4.5 Visual Comparisons
Visual Comparisons of the IEM. To demonstrate the enhancement effect of the IEM intuitively, this section presents representative examples of original and enhanced images, together with their grayscale distributions before and after enhancement.

As Fig. 5(a) shows, the image output by the IEM has stronger contrast between the defect and background regions than the original image, which makes the defect object more visually prominent. The grayscale histograms below the images confirm this: the grayscale distribution of the enhanced image is more uniform than that of the original.

The examples in Fig. 5(b) and (c) show original images that are overall brighter or darker, with grayscale distributions concentrated toward the right or left end of the horizontal axis, respectively. In these cases, the IEM improves image quality by equalizing the grayscale distribution and shifting it toward the center. These examples illustrate that the IEM can adaptively enhance images under different original conditions, improving image quality and making defect objects more prominent.

Visual Comparisons of the FEM. We also compare the feature-map activations of the original RetinaNet with those of our model incorporating the FEM, as shown in Fig. 6. The first row shows the original images with ground-truth bounding boxes. The second and third rows show the activation maps obtained by RetinaNet without and with the FEM, respectively.

The model with the FEM no longer responds only to regions with anomalous grayscale values; it also produces high responses across the entire object region. This demonstrates that the FEM can capture richer information and enhance the model's feature representation.
(a) General case.
(b) Image with overall high brightness.
(c) Image with overall low brightness.
Fig. 5. Examples of enhancement results in different grayscale cases.
Fig. 6. Visual comparison of activation maps.
5 Conclusion
In this paper, we propose a novel defect detection method based on RetinaNet that introduces two new modules: IEM and FEM. Experiment results show that IEM and FEM each improve mAP by 2.0, while the full model achieves a 2.8 mAP improvement over the baseline. Moreover, among the commonly used methods compared, our method achieves the highest AP50 and matches the best AP over IoU thresholds 0.5–0.95, demonstrating the effectiveness of the proposed approach. Our analysis also shows that the FEM enables the model to learn richer information and improves its features, especially in cases of overlapping objects. Overall, the proposed method provides a promising solution for defect detection tasks and has potential for practical application in various industries. Future work can explore the IEM and FEM modules further and extend the proposed method to other defect detection datasets.
References
1. Chen, H., et al.: DCAM-Net: a rapid detection network for strip steel surface defects based on deformable convolution and attention mechanism. IEEE Trans. Instrum. Meas. 72, 1–12 (2023)
2. Chen, Y., et al.: Drop an octave: reducing spatial redundancy in convolutional neural networks with octave convolution. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 3435–3444 (2019)
3. Cheng, X., Yu, J.: RetinaNet with difference channel attention and adaptively spatial feature fusion for steel surface defect detection. IEEE Trans. Instrum. Meas. 70, 1–11 (2020)
4. He, K., et al.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
5. He, Y., et al.: An end-to-end steel surface defect detection approach via fusing multiple hierarchical features. IEEE Trans. Instrum. Meas. 69(4), 1493–1504 (2019)
6. Li, M., Wang, H., Wan, Z.: Surface defect detection of steel strips based on improved YOLOv4. Comput. Electr. Eng. 102, 108208 (2022)
7. Lin, T.Y., et al.: Feature pyramid networks for object detection. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2117–2125 (2017)
8. Lin, T.Y., et al.: Focal loss for dense object detection. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2980–2988 (2017)
9. Liu, W., et al.: SSD: single shot multibox detector. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9905, pp. 21–37. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46448-0_2
10. Liu, Z., et al.: A high-precision positioning approach for catenary support components with multiscale difference. IEEE Trans. Instrum. Meas. 69(3), 700–711 (2019)
11. Redmon, J., Farhadi, A.: YOLOv3: an incremental improvement. arXiv:1804.02767 (2018)
12. Ren, S., et al.: Faster R-CNN: towards real-time object detection with region proposal networks. Adv. Neural Inf. Process. Syst. 28, 91–99 (2015)
13. de Santana Correia, A., Colombini, E.L.: Attention, please! A survey of neural attention models in deep learning. Artif. Intell. Rev. 55(8), 6037–6124 (2022)
14. Wang, H., et al.: High-frequency component helps explain the generalization of convolutional neural networks. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8684–8694 (2020)
15. Xiang, X., et al.: AGCA: an adaptive graph channel attention module for steel surface defect detection. IEEE Trans. Instrum. Meas. 72, 1–12 (2023)
16. Yeung, C.C., Lam, K.M.: Efficient fused-attention model for steel surface defect detection. IEEE Trans. Instrum. Meas. 71, 1–11 (2022)
17. Zeng, N., et al.: A small-sized object detection oriented multi-scale feature fusion approach with application to defect detection. IEEE Trans. Instrum. Meas. 71, 1–14 (2022)
18. Zhang, Z., et al.: Single-shot object detection with enriched semantics. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5813–5821 (2018)
19. Zhu, X., et al.: Deformable DETR: deformable transformers for end-to-end object detection. In: 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, May 3–7, 2021. OpenReview.net (2021)
InterFormer: Human Interaction Understanding with Deformed Transformer
Di He, Zexing Du, Xue Wang (✉), and Qing Wang
School of Computer Science, Northwestern Polytechnical University, Xi'an 710072, China
Abstract. Human interaction understanding (HIU) is a crucial and challenging problem that consists of two subtasks: individual action recognition and pairwise interaction recognition. Previous methods do not fully capture the temporal and spatial correlations involved in human interactions. To alleviate this problem, we decouple HIU into complementary parts to explore comprehensive correlations among individuals.

Specifically, we design a multi-branch network, named InterFormer, to jointly model these interactive relations. It contains two parallel encoders that generate spatial and temporal features separately, and Spatial-Temporal Transformers (STTransformer) that exploit spatial and temporal contextual information in a cross manner. Extensive experiments on two benchmarks show that the proposed InterFormer achieves state-of-the-art performance on these datasets.
Keywords: Human interaction · Spatial-temporal features · Transformer
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 193–203, 2023. https://doi.org/10.1007/978-981-99-4761-4_17
1 Introduction
Modeling human interaction is of great significance for video surveillance [23], social scene understanding [2], and sports analysis [18], which makes HIU a key research problem. Unlike traditional action recognition, HIU focuses on understanding the interactions among individuals. As shown in Fig. 2, besides recognizing individual action labels, HIU must also understand the interactions among individuals. This is relatively challenging, but it can in turn benefit individual action recognition.

Multiple individuals perform various interactions in a scene, and these interactions change with complex spatial and temporal variations [29]. Earlier methods [30, 31, 33] used deep graphical models or conditional random fields (CRFs) to model interactions, but their shallow graphical representations are inadequate for complicated human interactions. More recently, several attention-based methods [7, 9, 14, 17] have been proposed to model human interactions; the attention mechanism learns interactions between actors and selectively extracts important information [9].

However, human interactions evolve dynamically in space and time. Previous methods extract spatial and temporal features either separately or sequentially, so they cannot fully capture the spatial-temporal dynamics needed to recognize the individual action of a person of interest
or to identify the interaction relations of a pair. The relative importance of spatial and temporal information also varies across actions. Comprehensively capturing spatial-temporal relations is therefore critical for reasoning about interactive relations (Fig. 1).
Fig. 1. Example of HIU visualization in a social scene. Each rectangle and the text above it represent a detected person and the predicted individual action label. OT, KK, and BK denote others, kick, and be-kicked, respectively. The orange line between two persons denotes a predicted interactive pair.
To address this challenge, we present InterFormer, which models human interactions by integrating the properties of different spatial-temporal branches. InterFormer consists of two parallel encoders that learn spatial and temporal features separately, and a Spatial-Temporal Transformer (STTransformer) that models spatial-temporal context jointly. Our contributions are threefold:
– We present a new HIU framework, termed InterFormer, which integrates the merits of different spatial-temporal paths for individual action recognition and pairwise interaction understanding.
– We use two parallel spatial-temporal encoders and two spatial-temporal cross-attention models to comprehensively explore complementary correlations among individuals.
– Experimental results show that our approach achieves leading performance on two benchmarks.
2 Related Works
2.1 Human Action Recognition
Human action recognition (HAR) is a fundamental task in the computer vision community. HAR can be performed with conventional machine learning methods, but these require relevant features to be designed and selected by hand, and they may still achieve suboptimal performance [11]. Since the advent of deep learning, numerous HAR works have been proposed to learn features [5, 8, 19, 21, 27, 36]. [36] focuses on the temporal information in videos, while [19, 27] convolve jointly over the spatial and temporal dimensions to extract spatial-temporal features. These approaches are designed to predict action categories and leave the HIU task largely unsolved [32].
InterFormer: Human Interaction Understanding
2.2 Human Interaction Understanding
Conditional random fields [30, 31] have been widely applied to model interactive relations. These CRF-based methods use different potential functions and graph construction strategies. However, their shallow graphical representations cannot identify multiple interactions in the same video, particularly interactions that occur at the same time. Graph neural networks (GNNs) [18, 33] have been used to model pairwise interaction relationships between people. [33] uses a graph convolutional network (GCN) to build an actor relation graph that augments individual representations. However, the constructed relational graphs are limited to a few frames and do not consider latent temporal relationships between individuals. [18] presents an attentive semantic recurrent neural network (RNN) that combines a spatial-temporal attention mechanism with semantic graph modeling. These GNN-based methods model interactions between individuals on a predefined graph, which typically results in sub-optimal solutions. More recent methods [12, 16, 34] generate explicit representations of spatial-temporal relations and apply attention mechanisms to model them. [34] proposes a multi-modal interaction relation representation model based on a spatial-temporal attention mechanism. [12] integrates two complementary spatial-temporal views to learn complex actor relations in videos. The most closely related work is [16]. It uses a clustered spatial-temporal transformer to capture spatial-temporal contextual information jointly, which effectively augments individual and group representations. However, it concentrates on jointly modeling spatial-temporal context to infer interactive relations. It does not consider that temporal and spatial information are not equally important for different action categories.
In contrast to these methods, we propose a multi-branch network that models interactive relations and individual actions complementarily. Parallel encoders generate spatial and temporal features separately, and an STTransformer models spatial-temporal context jointly.
2.3 Transformer
The Transformer was first proposed for sequence-to-sequence machine translation [28]. It is a specific instance of the encoder-decoder model [4] in which attention is the only mechanism used to derive dependencies between inputs and outputs. Its self-attention mechanism is particularly suitable for capturing long-range dependencies. As a result, the Transformer has been widely used in natural language processing and was soon applied to other domains such as vision. Transformers and their variants have been successfully used in image recognition [6, 26], video understanding [10, 24], and several other applications [22, 35]. Spatial-temporal tokens are extracted in [1, 3]. TimeSformer [3] proposes divided spatial and temporal attention, and the Video Vision Transformer (ViViT) [1] extracts spatial and temporal tokens from video. The key enabler is the Transformer's self-attention mechanism. It learns interactions between actors, selectively extracts important information, and captures long-range dependencies. We therefore build on transformers for HIU.
Fig. 2. An overview of the proposed InterFormer, which includes a base model, two parallel encoders, and two spatial-temporal cross-attention models. $X_d$ denotes individual features and $X_i$ denotes interaction features. Both are fed to the Parallel Encoders module and the Spatial-Temporal Transformer module to exploit complementary correlations among individuals. The outputs of the two modules are fused for pairwise interaction and individual action recognition. During training, we use the cross-entropy loss for both individual and interaction classification.
3 Our Model
Figure 2 gives an overview of our model. We first use a CNN to extract features from the input image (Sect. 3.1). We then employ two parallel encoders to generate spatial and temporal features separately (Sect. 3.2). Sect. 3.3 describes our STTransformer in detail.
3.1 Feature Extractor
The base model takes an input image $X_{img}$ and the detected human bounding boxes, and a backbone CNN extracts features from the input. RoIAlign [13] is then applied to extract features for each of the $N$ bounding boxes in each image. Finally, an FC layer processes the extracted features to generate individual features $X_d$ and interaction features $X_i$.
3.2 Two Parallel Encoders
After feature extraction, we employ two encoders in parallel, a spatial encoder and a temporal encoder, to generate spatial and temporal features, respectively.
Spatial Encoder: The spatial encoder treats the temporal dimension as the batch dimension and applies the encoder to model the spatial context of every frame. The encoder consists of one layer with two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network [28]. Given the input feature $X$ for the $t$-th frame, the attention function is formulated as
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{D}}\right)V, \tag{1}$$
where $Q$, $K$, and $V$ are computed as
$$Q = XW_q, \quad K = XW_k, \quad V = XW_v, \tag{2}$$
where $W_q$, $W_k$, and $W_v$ are learnable parameters and $D$ is the channel dimension. In the multi-head attention mechanism, we maintain independent query, key, and value weight matrices for each head. The resulting query, key, and value matrices differ across heads, which helps the network capture richer features:
$$\hat{V} = \mathrm{MultiHead}(Q, K, V) + V. \tag{3}$$
The FFN is the fully connected feed-forward network of the canonical Transformer. The process of embedding spatial context for the $t$-th frame is thus formulated as
$$V' = \mathrm{FFN}(\hat{V}), \tag{4}$$
where $V'$ denotes the feature map for the $t$-th frame. The feature maps of all time steps ($t = 1, \ldots, T$) are packed together to obtain $V_s$. The input individual features $X_d$ and interaction features $X_i$ are both processed with the above operations to extract spatial context. Their feature maps over all time steps are packed as $V_{ds}$ and $V_{is}$, respectively.
Temporal Encoder: The temporal encoder mines cues about how features evolve over time. It enriches the temporal context by highlighting informative features along the temporal dimension. It follows the same procedure as the spatial encoder, except that it treats the spatial dimension as the batch dimension. $V_t$ denotes the output of the temporal encoder. Specifically, the individual feature maps are packed into $V_{dt}$ and the interaction feature maps into $V_{it}$.
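The batching convention above can be made concrete with a minimal NumPy sketch. It covers one encoder layer (Eqs. (1)–(4)) and shows that the spatial and temporal encoders differ only in which axis is folded into the batch. Several details are our assumptions for illustration, not the authors' implementation:

- the feature layout (T, N, D);
- the output projection `Wo`;
- a ReLU FFN with hidden width 2D;
- scaling by the per-head dimension, the usual multi-head choice (Eq. (1) writes √D);
- the parameter names.

The residual adds the value matrix V, as Eq. (3) is written. A canonical Transformer adds the layer input X instead.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)    # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(X, p, num_heads):
    """Eqs. (1)-(3) on a batch of token sequences X with shape (B, L, D)."""
    B, L, D = X.shape
    dh = D // num_heads

    def split(M):
        # (B, L, D) -> (B, heads, L, dh): each head gets its own channel slice.
        return M.reshape(B, L, num_heads, dh).transpose(0, 2, 1, 3)

    Q, K, V = X @ p["Wq"], X @ p["Wk"], X @ p["Wv"]              # Eq. (2)
    scores = split(Q) @ split(K).transpose(0, 1, 3, 2) / np.sqrt(dh)
    heads = softmax(scores) @ split(V)                           # Eq. (1), per head
    merged = heads.transpose(0, 2, 1, 3).reshape(B, L, D) @ p["Wo"]
    return merged + V                                            # Eq. (3) as written

def encoder_layer(X, p, num_heads=8):
    H = multi_head_attention(X, p, num_heads)
    return np.maximum(H @ p["W1"], 0.0) @ p["W2"]                # Eq. (4), ReLU FFN

def spatial_encoder(F, p):
    # F: (T, N, D). Frames act as the batch; attention mixes the N individuals.
    return encoder_layer(F, p)

def temporal_encoder(F, p):
    # Individuals act as the batch; attention mixes each person's T frames.
    return encoder_layer(F.transpose(1, 0, 2), p).transpose(1, 0, 2)

rng = np.random.default_rng(0)
T, N, D = 3, 5, 16
p = {k: 0.1 * rng.standard_normal((D, D)) for k in ("Wq", "Wk", "Wv", "Wo")}
p["W1"] = 0.1 * rng.standard_normal((D, 2 * D))
p["W2"] = 0.1 * rng.standard_normal((2 * D, D))
F = rng.standard_normal((T, N, D))
Vs, Vt = spatial_encoder(F, p), temporal_encoder(F, p)
print(Vs.shape, Vt.shape)   # (3, 5, 16) (3, 5, 16)
```

Because frames form the batch in the spatial encoder, changing one frame leaves the other frames' outputs untouched. The temporal encoder behaves the same way with respect to individuals.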
3.3 STTransformer
The STTransformer consists of two parallel encoders and two decoders. The spatial encoder and the temporal encoder generate spatial and temporal features, respectively. A spatial decoder and a temporal decoder are arranged in a cross scheme to exploit spatial-temporal context complementarily. We now explain the STTransformer in detail.
Encoders: The encoders follow the operations of the spatial encoder and the temporal encoder in Sect. 3.2.
Decoders: Each decoder contains a multi-head cross-attention mechanism and a feed-forward network. In the spatial decoder, the output of the spatial encoder $V_s$ serves as the query, and the output of the temporal encoder $V_t$ serves as the key and value:
$$\hat{V}_s = \mathrm{MultiHead}(V_s, V_t, V_t) + V_s. \tag{5}$$
The feed-forward network is then applied:
$$m_1 = \mathrm{FFN}(\hat{V}_s), \tag{6}$$
where $m_1$ is the output of the spatial decoder. In the temporal decoder, the output of the spatial encoder $V_s$ serves as the key and value, and the output of the temporal encoder $V_t$ serves as the query:
$$\hat{V}_t = \mathrm{MultiHead}(V_t, V_s, V_s) + V_t. \tag{7}$$
The feed-forward network is then applied:
$$m_2 = \mathrm{FFN}(\hat{V}_t), \tag{8}$$
where $m_2$ is the output of the temporal decoder. The outputs of the two decoders are combined as
$$m = m_1 + m_2. \tag{9}$$
The STTransformer block can be stacked repeatedly, which helps to learn the underlying semantic representations effectively.
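The cross scheme of Eqs. (5)–(9) can be sketched in NumPy as follows. This is a minimal illustration under several assumptions:

- single-head attention instead of the eight heads used in the paper;
- a ReLU FFN with hidden width 2D;
- tokens from all T frames and N individuals flattened into one sequence of length T·N, since the text does not specify which axis the decoders attend over;
- hypothetical parameter names and random initialization.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(q_in, kv_in, Wq, Wk, Wv):
    """Single-head cross-attention plus residual, mirroring Eqs. (5) and (7).
    q_in supplies the queries; kv_in supplies the keys and values."""
    Q, K, V = q_in @ Wq, kv_in @ Wk, kv_in @ Wv
    A = softmax(Q @ K.T / np.sqrt(Q.shape[-1]))
    return A @ V + q_in

def ffn(X, W1, W2):
    return np.maximum(X @ W1, 0.0) @ W2       # position-wise FFN with ReLU

def st_decoders(Vs, Vt, p):
    """Eqs. (5)-(9) for encoder outputs Vs, Vt of shape (T, N, D)."""
    T, N, D = Vs.shape
    s, t = Vs.reshape(T * N, D), Vt.reshape(T * N, D)
    m1 = ffn(cross_attention(s, t, *p["spatial_attn"]), *p["spatial_ffn"])    # Eqs. (5)-(6)
    m2 = ffn(cross_attention(t, s, *p["temporal_attn"]), *p["temporal_ffn"])  # Eqs. (7)-(8)
    return (m1 + m2).reshape(T, N, D)                                         # Eq. (9)

def make_params(rng, D):
    """Random weights for both decoders (hypothetical initialization)."""
    def attn():
        return tuple(0.1 * rng.standard_normal((D, D)) for _ in range(3))
    def ffn_w():
        return (0.1 * rng.standard_normal((D, 2 * D)),
                0.1 * rng.standard_normal((2 * D, D)))
    return {"spatial_attn": attn(), "spatial_ffn": ffn_w(),
            "temporal_attn": attn(), "temporal_ffn": ffn_w()}

rng = np.random.default_rng(0)
T, N, D = 3, 4, 8
Vs, Vt = rng.standard_normal((T, N, D)), rng.standard_normal((T, N, D))
m = st_decoders(Vs, Vt, make_params(rng, D))
print(m.shape)   # (3, 4, 8)
```

The scheme is "cross" because the two decoders swap the roles of $V_s$ and $V_t$. Each branch attends to the other branch's features before the outputs are summed in Eq. (9).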
3.4 End-to-End Learning
Our network is trained end to end. To ensure recognition consistency between the two sub-tasks of HIU, we use cross-entropy losses to guide the optimization:
$$L = L_1(\hat{y}_d, y_d) + \lambda L_2(\hat{y}_i, y_i), \tag{10}$$
where $L_1$ and $L_2$ denote cross-entropy losses, and $\hat{y}_d$ and $\hat{y}_i$ are the predicted individual action scores and human interaction scores. $y_d$ and $y_i$ are the corresponding ground-truth labels for individual actions and human interactions. $\lambda$ is a hyperparameter that balances the two terms.
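Here is a minimal NumPy sketch of Eq. (10). It assumes that individual actions are scored over C action classes and that each candidate pair is scored as interacting or not (two classes). The logits layout and function names are illustrative assumptions. With λ = 1, the value used in Sect. 4.1, both terms carry equal weight.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean cross-entropy for logits of shape (M, C) and integer labels (M,)."""
    shifted = logits - logits.max(axis=1, keepdims=True)          # stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def hiu_loss(ind_logits, ind_labels, pair_logits, pair_labels, lam=1.0):
    """Eq. (10): L = L1(individual actions) + lam * L2(pairwise interactions)."""
    return (cross_entropy(ind_logits, ind_labels)
            + lam * cross_entropy(pair_logits, pair_labels))

# Uninformative (all-zero) logits give log(C) per term: log(9) + log(2) = log(18).
loss = hiu_loss(np.zeros((3, 9)), np.array([0, 4, 8]),
                np.zeros((3, 2)), np.array([1, 0, 0]))
print(round(float(loss), 4))   # 2.8904
```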
4 Experiment
We evaluate the proposed method on two popular HIU datasets. We first introduce the datasets and the implementation details of our approach in Sect. 4.1. We then compare our approach with state-of-the-art methods in Sect. 4.2. Finally, Sect. 4.3 presents ablation studies that validate the effectiveness of each part of the proposed network. The numbers of instances in different classes are highly imbalanced, so we use several evaluation metrics: F1-score, overall accuracy, and mean IoU. For F1-score and overall accuracy, we report the mean over the two sub-tasks. To obtain the mean IoU, we first compute the IoU of each class and then average all IoU values [32].
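All three metrics can be computed from per-class confusion counts. The NumPy sketch below makes two averaging assumptions: F1 is macro-averaged within each sub-task, and the mean IoU is taken over the classes of both sub-tasks pooled together. These choices, and all names, are our assumptions for illustration rather than the exact protocol of [32].

```python
import numpy as np

def confusion(y_true, y_pred, num_classes):
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)          # rows: truth, cols: prediction
    return cm

def subtask_metrics(y_true, y_pred, num_classes):
    """Macro F1, overall accuracy, and per-class IoU for one sub-task."""
    cm = confusion(np.asarray(y_true), np.asarray(y_pred), num_classes)
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp                    # predicted as c, but wrong
    fn = cm.sum(axis=1) - tp                    # truly c, but missed
    with np.errstate(divide="ignore", invalid="ignore"):
        f1 = 2 * tp / (2 * tp + fp + fn)        # nan for classes never seen
        iou = tp / (tp + fp + fn)
    return np.nanmean(f1), tp.sum() / cm.sum(), iou

def hiu_metrics(individual, interaction):
    """Each argument is a (y_true, y_pred, num_classes) tuple."""
    f1_d, acc_d, iou_d = subtask_metrics(*individual)
    f1_i, acc_i, iou_i = subtask_metrics(*interaction)
    mean_iou = np.nanmean(np.concatenate([iou_d, iou_i]))
    return (f1_d + f1_i) / 2, (acc_d + acc_i) / 2, mean_iou

f1, acc, iou = subtask_metrics([0, 0, 1, 1], [0, 1, 1, 1], 2)
print(f1, acc, iou)
```

IoU divides by all true-positive, false-positive, and false-negative counts, whereas F1 counts true positives twice. IoU is therefore the stricter of the two on rare classes, which is why it is lower than F1 throughout Table 1.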
4.1 Datasets and Implementation Details
Datasets. We conduct experiments on two publicly available datasets, the UT dataset and the BIT dataset. UT [20] contains six action classes: handshake, hug, kick, punch, push, and no action. Following [30], we extend the original action classes by adding a passive class for each of the three asymmetric actions (kick, punch, and push), namely be-kicked, be-punched, and be-pushed. This gives 9 action classes in total. Following [16], we divide the 120 short videos of UT into two subsets for training and testing. BIT covers 9 action classes: box, handshake, highfive, hug, kick, pat, bend, push, and others (which includes the passive classes of the former actions). Each class contains 50 short videos. Following [15], 34 videos per class are chosen for training and the rest for testing.
Implementation Details. We adopt Inception V3 [25] as the backbone of the feature extractor. We apply RoIAlign with a crop size of 5 × 5, followed by a linear embedding, to obtain features with dimension D = 1024. We use one encoder/decoder layer with eight attention heads. For both datasets, we train the network for 100 epochs with a mini-batch size of 32 and optimize it with Adam. Following [12, 16], the hyperparameter λ in the cross-entropy loss is set to 1. Following [16], we do not use positional encoding. Our implementation is based on the PyTorch deep learning framework and runs on a workstation with two NVIDIA GeForce RTX 3090 GPUs. The average training time per epoch is 103 s on UT and 445 s on BIT, and the average testing time per epoch is 21 s and 83 s, respectively.
Table 1. Comparison with the state of the art on UT and BIT. Acc is short for accuracy.
Method            | UT F1 (%) | UT Acc (%) | UT IoU (%) | BIT F1 (%) | BIT Acc (%) | BIT IoU (%)
Joint + AS [30]   | 92.20     | 95.86      | 80.30      | 88.61      | 91.77       | 72.12
Modified GN [33]  | 93.38     | 96.39      | 84.13      | 89.95      | 91.61       | 76.42
CAGNet [32]       | 94.55     | 97.06      | 85.50      | 92.79      | 95.41       | 81.32
Ours              | 97.42     | 98.81      | 92.42      | 93.47      | 95.92       | 83.19
4.2 Comparison with the State-of-the-Art
We compare our approach with three state-of-the-art methods. For a fair comparison, all methods use Inception V3 as the backbone. Table 1 lists the results, and our method outperforms all of these methods for HIU. Compared with CAGNet [32], our method improves F1/accuracy/IoU by 2.87/1.75/6.92 percentage points on the UT dataset and by 0.68/0.51/1.87 percentage points on the BIT dataset. This suggests that learning spatial-temporal context in complementary orders is effective and important for HIU.
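The gains quoted above are absolute differences, in percentage points, between the "Ours" and "CAGNet [32]" rows of Table 1. The short Python check below recomputes them from the table values.

```python
# Rows "CAGNet [32]" and "Ours" from Table 1 (values in %).
names = ["UT F1", "UT Acc", "UT IoU", "BIT F1", "BIT Acc", "BIT IoU"]
cagnet = [94.55, 97.06, 85.50, 92.79, 95.41, 81.32]
ours = [97.42, 98.81, 92.42, 93.47, 95.92, 83.19]

# Absolute gains in percentage points, rounded to two decimals.
gains = {n: round(o - c, 2) for n, o, c in zip(names, ours, cagnet)}
for n, g in gains.items():
    print(f"{n}: +{g:.2f} pp")
```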
Table 2. Ablation studies on UT and BIT.
Variant           | UT F1 (%) | UT Acc (%) | UT IoU (%) | BIT F1 (%) | BIT Acc (%) | BIT IoU (%)
Inception V3 [25] | 91.44     | 95.44      | 79.35      | 87.84      | 91.61       | 72.00
Parallel          | 95.17     | 97.27      | 87.76      | 92.27      | 94.89       | 79.79
STTransformer     | 95.43     | 97.83      | 87.67      | 92.90      | 95.40       | 81.61
Ours              | 97.42     | 98.81      | 92.42      | 93.47      | 95.92       | 83.19
4.3 Ablation Studies
To validate the effectiveness of the different parts of our model, we perform ablation studies on the validation sets of the UT and BIT datasets. We compare the following variants: (1) Baseline: Inception V3 + FC layer. (2) Parallel encoders: a spatial encoder and a temporal encoder. (3) STTransformer: two STTransformer blocks. (4) Our full method. Table 2 shows that adopting spatial and temporal encoders in parallel significantly improves performance on both UT and BIT. This indicates that learning the spatial and temporal context separately is effective and important for HIU. The STTransformer performs slightly better than the parallel encoders on all metrics except IoU on UT. Our full model clearly improves on both the parallel encoders and the STTransformer, which suggests that it successfully integrates the merits of the different spatial-temporal paths (Fig. 3).
Fig. 3. Confusion matrices on the testing set of BIT for different model variants: (a) STTransformer; (b) our full model.
To further analyze the effectiveness of our method, Fig. 3 shows the confusion matrices of the STTransformer and of our complete model. Most errors arise from failing to distinguish box from push. Adding the parallel-encoder branch improves the recognition accuracy of box, highfive, and hug. This supports the effectiveness of combining different spatial-temporal paths.
Table 3. Comparisons of different settings for the number of STTransformer blocks.
Blocks | UT F1 (%) | UT Acc (%) | UT IoU (%) | BIT F1 (%) | BIT Acc (%) | BIT IoU (%)
0      | 91.44     | 95.44      | 79.35      | 87.84      | 91.61       | 72.00
1      | 96.48     | 98.09      | 90.72      | 92.79      | 95.65       | 81.16
2      | 97.42     | 98.81      | 92.42      | 93.47      | 95.92       | 83.19
3      | 96.82     | 98.28      | 91.38      | 93.34      | 95.79       | 82.50
4      | 96.36     | 98.09      | 90.20      | 92.61      | 95.73       | 81.21
The STTransformer module can be stacked with several blocks to better exploit spatial-temporal information, so Table 3 evaluates the influence of the number of blocks. With zero blocks, only one FC layer is used to embed features. A single block already outperforms this baseline, which demonstrates the effectiveness of the STTransformer module. Stacking two STTransformer blocks gives the best results on both benchmarks, so we adopt this setting throughout the paper.
5 Conclusion
In this work, we propose a new transformer-based architecture, termed InterFormer, which jointly models interactive relations in two complementary manners. Our network relies on a parallel encoder module and an STTransformer module. Ablation studies and comparisons against state-of-the-art methods on two benchmarks demonstrate the effectiveness of the proposed approach. In the future, we plan to extend our model to additional datasets (the Volleyball dataset, the Collective Activity dataset, and TVHI). This will let us fully explore spatial-temporal interactions among individuals, which can also contribute to group activity recognition.
Acknowledgements. This work was supported by NSFC under Grant 61801396 and by the Science and Technology Project of the State Grid Corporation of China (5700-202019186A-0-0-00).
References
1. Arnab, A., Dehghani, M., Heigold, G., Sun, C., Lučić, M., Schmid, C.: ViViT: A video vision transformer. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 6836–6846 (2021)
2. Bagautdinov, T., Alahi, A., Fleuret, F., Fua, P., Savarese, S.: Social scene understanding: End-to-end multi-person action localization and collective activity recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4315–4324 (2017)
3. Bertasius, G., Wang, H., Torresani, L.: Is space-time attention all you need for video understanding? (2021)
4. Cho, K., Van Merriënboer, B., Bahdanau, D., Bengio, Y.: On the properties of neural machine translation: Encoder-decoder approaches. arXiv preprint arXiv:1409.1259 (2014)
5. Diba, A., et al.: Temporal 3D ConvNets: New architecture and transfer learning for video classification. arXiv preprint arXiv:1711.08200 (2017)
6. Dosovitskiy, A., et al.: An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929 (2020)
7. Ehsanpour, M., Abedin, A., Saleh, F., Shi, J., Reid, I., Rezatofighi, H.: Joint learning of social groups, individuals action and sub-group activities in videos. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12354, pp. 177–195. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58545-7_11
8. Feichtenhofer, C., Pinz, A., Zisserman, A.: Convolutional two-stream network fusion for video action recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1933–1941 (2016)
9. Gavrilyuk, K., Sanford, R., Javan, M., Snoek, C.G.: Actor-transformers for group activity recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 839–848 (2020)
10. Girdhar, R., Carreira, J., Doersch, C., Zisserman, A.: Video action transformer network. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 244–253 (2019)
11. Gu, F., Chung, M.H., Chignell, M., Valaee, S., Zhou, B., Liu, X.: A survey on deep learning for human activity recognition. ACM Comput. Surv. 54(8), 1–34 (2021). https://doi.org/10.1145/3472290
12. Han, M., Zhang, D.J., Wang, Y., Yan, R., Yao, L., Chang, X., Qiao, Y.: Dual-AI: Dual-path actor interaction learning for group activity recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2990–2999 (2022)
13. He, K., Gkioxari, G., Dollár, P., Girshick, R.: Mask R-CNN. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2961–2969 (2017)
14. Hu, G., Cui, B., He, Y., Yu, S.: Progressive relation learning for group activity recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 980–989 (2020)
15. Kong, Y., Jia, Y., Fu, Y.: Learning human interaction by interactive phrases. In: Fitzgibbon, A., Lazebnik, S., Perona, P., Sato, Y., Schmid, C. (eds.) ECCV 2012. LNCS, vol. 7572, pp. 300–313. Springer, Heidelberg (2012). https://doi.org/10.1007/978-3-642-33718-5_22
16. Li, S., Cao, Q., Liu, L., Yang, K., Liu, S., Hou, J., Yi, S.: GroupFormer: Group activity recognition with clustered spatial-temporal transformer. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 13668–13677 (2021)
17. Pramono, R.R.A., Chen, Y.T., Fang, W.H.: Empowering relational network by self-attention augmented conditional random fields for group activity recognition. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12346, pp. 71–90. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58452-8_5
18. Qi, M., Wang, Y., Qin, J., Li, A., Luo, J., Van Gool, L.: StagNet: An attentive semantic RNN for group activity and individual action recognition. IEEE Trans. Circuits Syst. Video Technol. 30(2), 549–565 (2019)
19. Qiu, Z., Yao, T., Mei, T.: Learning spatio-temporal representation with pseudo-3D residual networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 5533–5541 (2017)
20. Ryoo, M.S., Aggarwal, J.K.: UT-Interaction Dataset, ICPR contest on Semantic Description of Human Activities (SDHA). http://cvrc.ece.utexas.edu/SDHA2010/Human_Interaction.html (2010)
21. Simonyan, K., Zisserman, A.: Two-stream convolutional networks for action recognition in videos. In: Advances in Neural Information Processing Systems, vol. 27 (2014)
22. Su, W., et al.: VL-BERT: Pre-training of generic visual-linguistic representations. arXiv preprint arXiv:1908.08530 (2019)
23. Sultani, W., Chen, C., Shah, M.: Real-world anomaly detection in surveillance videos. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 6479–6488 (2018)
24. Sun, C., Myers, A., Vondrick, C., Murphy, K., Schmid, C.: VideoBERT: A joint model for video and language representation learning. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 7464–7473 (2019)
25. Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., Wojna, Z.: Rethinking the inception architecture for computer vision. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2818–2826 (2016)
26. Touvron, H., Cord, M., Douze, M., Massa, F., Sablayrolles, A., Jégou, H.: Training data-efficient image transformers & distillation through attention. In: International Conference on Machine Learning, pp. 10347–10357. PMLR (2021)
27. Tran, D., Bourdev, L., Fergus, R., Torresani, L., Paluri, M.: Learning spatiotemporal features with 3D convolutional networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 4489–4497 (2015)
28. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł., Polosukhin, I.: Attention is all you need. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
29. Wang, Z., Ge, J., Guo, D., Zhang, J., Lei, Y., Chen, S.: Human interaction understanding with joint graph decomposition and node labeling. IEEE Trans. Image Process. 30, 6240–6254 (2021). https://doi.org/10.1109/TIP.2021.3093383
30. Wang, Z., et al.: Understanding human activities in videos: A joint action and interaction learning approach. Neurocomputing 321, 216–226 (2018)
31. Wang, Z., Liu, S., Zhang, J., Chen, S., Guan, Q.: A spatio-temporal CRF for human interaction understanding. IEEE Trans. Circuits Syst. Video Technol. 27(8), 1647–1660 (2017). https://doi.org/10.1109/TCSVT.2016.2539699
32. Wang, Z., Meng, J., Guo, D., Zhang, J., Shi, J.Q., Chen, S.: Consistency-aware graph network for human interaction understanding. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 13369–13378 (2021)
33. Wu, J., Wang, L., Wang, L., Guo, J., Wu, G.: Learning actor relation graphs for group activity recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9964–9974 (2019)
34. Xu, D., Fu, H., Wu, L., Jian, M., Wang, D., Liu, X.: Group activity recognition by using effective multiple modality relation representation with temporal-spatial attention. IEEE Access 8, 65689–65698 (2020)
35. Ye, H.J., Hu, H., Zhan, D.C., Sha, F.: Few-shot learning via embedding adaptation with set-to-set functions. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8808–8817 (2020)
36. Zhou, B., Andonian, A., Oliva, A., Torralba, A.: Temporal relational reasoning in videos. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11205, pp. 831–846. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01246-5_49
UCLD-Net: Decoupling Network via Unsupervised Contrastive Learning for Image Dehazing
Zhitao Liu, Tao Hong, and Jinwen Ma (✉)
School of Mathematical Sciences, Peking University, Beijing, China
Abstract. Image dehazing techniques have developed considerably, from traditional algorithms based on handcrafted priors to learning algorithms based on neural networks. Handcrafted prior-based methods must first estimate the transmission map and the atmospheric light of the atmospheric scattering model separately and then compute the final haze-free image, which often leads to a gradual accumulation of errors. In end-to-end neural network-based methods, by contrast, supervised learning with labels is a major factor in improving the dehazing effect. In real-world scenarios, however, paired (hazy, haze-free) images are difficult to collect, which limits the applicability of supervised dehazing. To address this deficiency, we propose UCLD-Net, a Decoupling Network for image dehazing via Unsupervised Contrastive Learning. Contrastive learning is a mechanism widely used in self-supervised representation learning. Specifically, we use the estimated transmission map and atmospheric light to design the structure of UCLD-Net, and we introduce prior knowledge to construct its loss function. Experiments show that UCLD-Net achieves comparable dehazing results on the benchmark RESIDE dataset, which verifies its effectiveness.
Keywords: Image Dehazing · Unsupervised Learning · Contrastive Loss
Z. Liu, T. Hong: equal contribution.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 204–215, 2023. https://doi.org/10.1007/978-981-99-4761-4_18
1 Introduction
Image dehazing is a representative low-level computer vision task with considerable application value, and it has attracted many researchers in recent years. Like similar tasks such as image denoising and image deraining, image dehazing can be cast as an image restoration problem. The atmospheric scattering model [1] is formulated as
$$I(x) = J(x)\,t(x) + A\,(1 - t(x)), \tag{1}$$
where $I(x)$ and $J(x)$ are the degraded hazy image and the target haze-free image, respectively, $A$ is the global atmospheric light, and $t(x)$ is the medium transmission map. Moreover, $t(x) = e^{-\beta d(x)}$, where $\beta$ is the atmospheric scattering parameter and $d(x)$ is the scene depth. The transmission map $t(x)$ and the atmospheric light $A$ are usually unknown in real scenarios, so image dehazing is an ill-posed problem. The core challenge is therefore to estimate $t(x)$ and $A$ properly. Once they are known, the haze-free image can be restored as
$$J(x) = \frac{I(x) - A}{t(x)} + A. \tag{2}$$
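To illustrate Eqs. (1) and (2), the NumPy sketch below synthesizes haze from a clear image and a depth map, then inverts the model using the same $t(x)$ and $A$. The following are illustrative assumptions:

- the random image and the depth range;
- β and the per-channel A;
- the lower clamp `t_min` on the transmission, a common safeguard against dividing by a near-zero t.

In practice $t(x)$ and $A$ are unknown and must be estimated, which is exactly what makes the problem ill-posed.

```python
import numpy as np

def synthesize_haze(J, depth, A, beta):
    """Eq. (1) with t(x) = exp(-beta * d(x)); J: (H, W, 3), depth: (H, W)."""
    t = np.exp(-beta * depth)
    I = J * t[..., None] + A * (1.0 - t[..., None])   # broadcast t over RGB
    return I, t

def recover(I, t, A, t_min=0.1):
    """Eq. (2): J = (I - A) / t + A, with t clamped from below."""
    return (I - A) / np.maximum(t, t_min)[..., None] + A

rng = np.random.default_rng(0)
J = rng.uniform(0.0, 1.0, size=(4, 4, 3))
depth = rng.uniform(0.0, 2.0, size=(4, 4))
A = np.array([0.90, 0.92, 0.95])       # global atmospheric light per channel
I, t = synthesize_haze(J, depth, A, beta=1.0)
J_hat = recover(I, t, A)
print(np.abs(J_hat - J).max())         # near 0: every t here exceeds t_min
```

Each hazy pixel is a convex blend of the clear pixel and $A$, weighted by $t(x)$. Distant regions (small $t$) are therefore pulled toward the atmospheric light.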
Dehazing methods fall into two classes: traditional prior-based methods and learning-based methods. Classic traditional methods include DCP [2], BCCR [3], and CAP [4]. With the rise of deep learning, many neural network methods have been proposed, such as DehazeNet [5], MSCNN [6], AOD-Net [7], GCA-Net [8], FFA-Net [9], and AECR-Net [10]. Driven by supervised data, these networks are well designed for the dehazing task and perform well on synthetic datasets. However, they all rely on supervised learning, and their performance may drop considerably on real hazy scenes. Moreover, paired (hazy, haze-free) images are hard to collect at scale. To address this, some researchers adopt semi-supervised or unsupervised learning for image dehazing to improve model robustness and reduce the risk of over-fitting. For example, inspired by the success of CycleGAN [11] in unpaired image translation, Cycle-Dehaze [12] and Dehaze-GLCGAN [13] were proposed to dehaze images with unsupervised GANs. However, existing unsupervised dehazing methods cannot yet match supervised dehazing and show a notable decline in evaluation metrics such as PSNR.
Fig. 1. The framework of UCLD-Net.
In this paper, we aim to make the most of unsupervised dehazing and to improve its performance in real environments. Contrastive learning is a popular and effective self-supervised learning mechanism that is widely used in representation learning today. We propose a Decoupling Network framework for image dehazing based on Unsupervised Contrastive Learning, named UCLD-Net and shown in Fig. 1, which no longer requires paired (hazy, haze-free) images. Prior knowledge helps us obtain finer dehazing performance. Specifically, we use the estimated transmission map and atmospheric light to design the networks. We also construct extra loss terms based on the dark channel prior (DCP) and other priors to constrain the physical attributes of the final dehazed images. We conduct experiments on the image dehazing benchmark RESIDE and achieve results comparable to existing state-of-the-art methods, which verifies the effectiveness of UCLD-Net. Our contributions can be summarized as follows.
• We propose an end-to-end image dehazing framework based on unsupervised contrastive learning and explore the role of the contrastive mechanism in detail.
• We exploit prior knowledge in both the network design and the loss function to achieve a finer dehazing effect.
• We achieve comparable image dehazing performance both quantitatively and qualitatively.
2 Related Work
As introduced in the previous section, image dehazing has evolved from traditional prior-based methods to learning-based methods. The dark channel prior (DCP) [2] was a brilliant discovery. The boundary constraint and contextual regularization (BCCR) [3] and the color attenuation prior (CAP) [4] were subsequently proposed. Neural network methods usually adopt an encoder-decoder structure to learn the restoration. AOD-Net [7], the All-in-One Dehazing Network, directly generates the clean image through a lightweight CNN. GCA-Net [8], the Gated Context Aggregation Network, adopts smoothed dilated convolution [14] to remove gridding artifacts. It also uses a gated sub-network to fuse features from different levels. FFA-Net [9], the Feature Fusion Attention Network, combines channel attention with a pixel attention mechanism. AECR-Net [10] proposes a contrastive regularization built on contrastive learning that exploits hazy images as negative samples and clear images as positive samples. Apart from these supervised methods, Chen et al. proposed the Principled Synthetic-to-real Dehazing (PSD) framework [15]. PSD pre-trains a backbone on synthetic data and then fine-tunes it without supervision on real hazy images. Going beyond synthetic hazy image pairs, Yang et al. proposed a disentangled dehazing network that generates realistic haze-free images using only unpaired supervision [16], which raises a new challenge. Zhu et al. [11] proposed CycleGAN, a breakthrough in image-to-image translation. Building on it, many researchers have constructed unsupervised dehazing networks that bridge the domain gap between synthetic datasets and real hazy images. For example, Engin et al. [12] proposed Cycle-Dehaze, Anvari et al. [13] proposed Dehaze-GLCGAN, and Zheng et al. [17] proposed Dehaze-AGGAN.
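The dark channel prior underlies UCLD-Net's extra loss terms, so a brief reminder of how the dark channel is computed may help. The NumPy sketch below follows the usual DCP definition. For each pixel it takes the minimum over the RGB channels and then over a square patch, and it estimates transmission as t = 1 − ω·dark(I/A). The patch size 15, ω = 0.95, and the edge padding are common DCP settings and are assumptions here. UCLD-Net's exact DCP-based loss terms are not specified in this section.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def dark_channel(img, patch=15):
    """Dark channel of an (H, W, 3) image: min over RGB, then over a patch.
    `patch` must be odd so the output keeps the input's height and width."""
    min_rgb = img.min(axis=2)
    r = patch // 2
    padded = np.pad(min_rgb, r, mode="edge")               # replicate borders
    windows = sliding_window_view(padded, (patch, patch))  # (H, W, patch, patch)
    return windows.min(axis=(2, 3))

def dcp_transmission(img, A, omega=0.95, patch=15):
    """DCP transmission estimate t = 1 - omega * dark_channel(I / A)."""
    return 1.0 - omega * dark_channel(img / A, patch)

img = np.full((10, 10, 3), 0.5)
img[5, 5, 1] = 0.1                    # one dark pixel in the green channel
dc = dark_channel(img, patch=3)
print(dc[4:7, 4:7].min(), dc[0, 0])   # 0.1 0.5
```

A single dark value spreads to its whole patch neighborhood. This is why haze-free natural images, which contain shadows and saturated colors, tend to have a dark channel close to zero, while hazy regions do not.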
UCLD-Net: Decoupling Network via Unsupervised Contrastive
3 Proposed Approach
In this section, we describe the notation for image dehazing and then introduce the design details of our proposed approach, UCLD-Net. In the dehazing task, the hazy image and the haze-free image are usually denoted as $I$ and $J$. Denoting the whole dehazing network as $D$, supervised dehazing is generally optimized as
$$ \min_{D} \mathcal{L}\big(D(I), J\big) \tag{3} $$
where $\mathcal{L}$ is the restoration loss function. As shown in Fig. 1, the training of UCLD-Net is divided into two stages: a Hazy2hazy stage followed by a Clear2clear stage. In the Hazy2hazy stage, a hazy image $I(x)$ is fed into the network to obtain the dehazing result $f_J(I)$, the transmission map $f_T(I)$ and the global atmospheric light $f_A(I)$, where $f_*$ denotes the corresponding sub-network. Then $f_J(I)$, $f_T(I)$ and $f_A(I)$ are fused into a reconstructed hazy image $\hat{I}(x)$ according to Eq. (1). In the second stage, Clear2clear, a clear image $J(x)$ is input and goes through a coupling and a decoupling process. It is combined with a corresponding transmission map $t(x)$ and atmospheric light $A$, generated by the methods described in Sec. 3.3, and these three components are fused to generate a hazy image $I(x)$. The generated $I(x)$ is then decoupled into the dehazing result $f_J(I)$, the transmission map $f_T(I)$ and the global atmospheric light $f_A(I)$, using the same three sub-networks as the Hazy2hazy stage. In every training step, the Hazy2hazy and Clear2clear stages are performed sequentially. Importantly, the hazy images used in Hazy2hazy and the clear images used in Clear2clear do not need to be paired, so the dehazing is truly unsupervised.
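The fusion step of the Hazy2hazy stage can be sketched as follows. This is a minimal pure-Python illustration that assumes Eq. (1) is the standard atmospheric scattering model I(x) = J(x)t(x) + A(1 − t(x)) with a scalar atmospheric light; the helper names `recompose_hazy` and `l1_loss` are hypothetical, and a real implementation would operate on batched tensors.

```python
from typing import List

Image = List[List[float]]  # single-channel image stored as rows of pixel values


def recompose_hazy(j: Image, t: Image, a: float) -> Image:
    """Fuse J, t and A with the atmospheric scattering model (assumed to be Eq. (1)).

    I_hat(x) = J(x) * t(x) + A * (1 - t(x)); a scalar A is used for simplicity.
    """
    if len(j) != len(t) or any(len(rj) != len(rt) for rj, rt in zip(j, t)):
        raise ValueError("J and t must have the same shape")
    return [
        [jv * tv + a * (1.0 - tv) for jv, tv in zip(row_j, row_t)]
        for row_j, row_t in zip(j, t)
    ]


def l1_loss(x: Image, y: Image) -> float:
    """Sum of absolute differences, i.e. the L1 norm used by L_I in Eq. (4)."""
    return sum(abs(a - b) for rx, ry in zip(x, y) for a, b in zip(rx, ry))


# Toy example: the sub-network outputs f_J(I), f_T(I), f_A(I) are made-up values.
hazy = [[0.5, 0.7], [0.6, 0.9]]
j_hat = [[0.2, 0.4], [0.3, 0.8]]
t_hat = [[0.5, 0.5], [0.5, 0.5]]
i_hat = recompose_hazy(j_hat, t_hat, a=0.8)
loss_i = l1_loss(hazy, i_hat)
```

Since Î depends on all three outputs, minimizing the reconstruction loss in this sketch sends a training signal to J-Net, T-Net and A-Net at the same time.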
Fig. 2. The detailed structure of the Hazy2hazy stage. C, ⊕ and ⊗ denote concatenation, element-wise sum and element-wise product, respectively. For details of the MIB, TU and FFRB modules, see Sec. 3.1.
3.1 Decoupling Sub-networks
As shown in Fig. 2, the decoupling network contains three sub-networks: J-Net to calculate the dehazing result, T-Net to calculate the transmission map, and A-Net to calculate
Z. Liu et al.
the atmospheric light. These three sub-networks are shared between the Hazy2hazy and Clear2clear stages. Moreover, they are not fixed: their structure can be changed as desired, e.g., replaced with existing network modules. The design of J-Net and T-Net is inspired by EAA-Net [19]. J-Net adopts three Multi-level Information Boost (MIB) modules to dehaze, and T-Net adopts three Transmission Units (TU), which have the same structure as MIB, to obtain the transmission map. Each MIB or TU is followed by a Feature Fusion Residual Block (FFRB). Notably, information flows between J-Net and T-Net: the outputs of the first two TUs of T-Net are passed to J-Net, which gives J-Net more information about depth and transmission.
In previous studies, the atmospheric light of an image is often estimated as a single value, which is not very reasonable because real images often contain sky and shadow regions. Under the reasonable assumption that the atmospheric light $A$ follows a latent multi-dimensional Gaussian distribution, i.e., $z \sim \mathcal{N}(\mu, \Sigma)$, UCLD-Net adopts the same A-Net structure as YOLY [20] to estimate $A$. A-Net uses a variational auto-encoder (VAE) [21] to fit this Gaussian distribution.
3.2 Hazy2hazy Stage
As Fig. 2 shows, the Hazy2hazy stage contains a discriminator module and a contrastive learning module.
Haze Discriminator. The discriminator determines whether the generated dehazed images are real or fake. J-Net can be regarded as a generator that, together with the discriminator, constitutes a GAN in which J-Net tries to make the generated dehazed images fool the discriminator. When training the discriminator, not only the dehazing results $f_J(I)$ but also additional images are used to improve training.
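A-Net's VAE can be illustrated with the usual reparameterization trick and the closed-form KL term that the loss later uses in Eq. (6). This is a sketch under stated assumptions: the latent distribution is taken to be a diagonal Gaussian parameterized by its log-variance (a common VAE convention that the text does not specify), the decoder that maps z to A is omitted, and the function names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def sample_latent(mu: Sequence[float], log_var: Sequence[float],
                  rng: random.Random) -> List[float]:
    """Reparameterized sample z = mu + sigma * eps with eps ~ N(0, 1) per dimension."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu: Sequence[float], log_var: Sequence[float]) -> float:
    """KL(N(mu, diag(sigma^2)) || N(0, I)) = 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2)."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, log_var))


rng = random.Random(0)
mu, log_var = [0.7, 0.2], [math.log(0.01), math.log(0.04)]  # sigma = 0.1 and 0.2
z = sample_latent(mu, log_var, rng)          # one latent draw, to be decoded into A
kl = kl_to_standard_normal(mu, log_var)      # penalty pulling the latent toward N(0, I)
```

In training, the sampled z would be decoded into the atmospheric light f_A(I), and the KL value would enter L_A as in Eq. (5).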
Unsupervised Contrastive Dehazing. In the Hazy2hazy stage, contrastive learning is also adopted to improve training. In the feature space, we select the final reconstructed hazy image $\hat{I}(x)$ as the anchor. The corresponding input hazy image $I(x)$ and other additional hazy images $I^i$ are set as positive samples that pull the anchor, while the dehazed image $f_J(I)$ and other additional clear images are set as negative samples that push the anchor away. There are two ways to choose the feature extractor that maps these images into a feature space: a pre-trained network such as VGG-16 [22], or one of our trained sub-networks such as J-Net.
Optimized Loss Functions. The loss functions used in this method mainly follow the YOLY network [20], which uses three loss functions ($\mathcal{L}_I$, $\mathcal{L}_A$, $\mathcal{L}_{reg}$). In addition, we add three loss terms ($\mathcal{L}_{DCP}$, $\mathcal{L}_{contrast}$ and $\mathcal{L}_{adv}$) to improve the training process. The $\mathcal{L}_I$ loss evaluates the difference between the reconstructed hazy image $\hat{I}(x)$ and the original hazy image $I(x)$, and is usually taken as the L1 norm:
$$ \mathcal{L}_I = \| I(x) - \hat{I}(x) \|_1 \tag{4} $$
The $\mathcal{L}_A$ loss constrains the atmospheric light and can be expressed as
$$ \mathcal{L}_A = \| f_A(I) - A_{other}(x) \|_2 + \mathrm{KL}(A) \tag{5} $$
The first term is the L2 norm of the gap between the atmospheric light $f_A(I)$ estimated by the sub-network and the atmospheric light $A_{other}(x)$ estimated by other methods such as DCP [2]. The second term uses the Kullback-Leibler (KL) divergence to measure the distance between the multi-dimensional Gaussian distribution of the atmospheric light $A$ and the standard Gaussian distribution. Mathematically,
$$ \mathrm{KL}(A) = \mathrm{KL}\big(\mathcal{N}(\mu, \Sigma) \,\|\, \mathcal{N}(0, I)\big) = \frac{1}{2} \sum_i \big( \mu_{z_i}^2 + \sigma_{z_i}^2 - 1 - \log \sigma_{z_i}^2 \big) \tag{6} $$
where $z$ denotes the latent variable of $A$, $z_i$ is the $i$-th dimension of $z$, and $\mu_{z_i}$ and $\sigma_{z_i}^2$ denote the mean and variance of $z_i$, respectively. The $\mathcal{L}_{reg}$ loss is a regularization term used to avoid overfitting. Formally,
$$ \mathcal{L}_{reg} = \frac{1}{2M} \sum_{i=1}^{M} \frac{1}{|\Omega(x_i)|} \sum_{y_i \in \Omega(x_i)} (x_i - y_i)^2 \tag{7} $$
where $\Omega(x_i)$ denotes the neighborhood of pixel $x_i$, $|\Omega(x_i)|$ denotes the number of pixels in $\Omega(x_i)$, and $M$ denotes the total number of pixels in an image.
DCP Loss. The dark channel prior [2] states that in most non-sky local patches of clear (haze-free) images, at least one of the three RGB channels has some pixels with very low intensity, which can be used to check whether an image is clear or hazy. Formally, the dark channel is calculated as
$$ J^{dark}(x) = \min_{y \in \Omega(x)} \Big( \min_{c \in \{R, G, B\}} J^c(y) \Big) \tag{8} $$
where $\Omega(x)$ denotes the neighborhood of pixel $x$. To ensure that the dehazed images satisfy DCP, we construct the $\mathcal{L}_{DCP}$ loss as
$$ \mathcal{L}_{DCP} = \| J^{dark}(f_J(I)) \|_1 \tag{9} $$
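The two pixel-level priors above, L_reg in Eq. (7) and L_DCP in Eqs. (8)–(9), can be sketched in plain Python as below. Several choices are assumptions rather than details from the text: Ω(x) is a square window clipped at the image border (8-connected for L_reg; radius 7 for the dark channel, i.e. the 15 × 15 patch used by He et al. [2]), L_reg is applied to a single-channel map, and all function names are hypothetical.

```python
from typing import List, Tuple

Gray = List[List[float]]
RGB = Tuple[float, float, float]


def neighborhood_reg(img: Gray) -> float:
    """Eq. (7): mean squared difference to the 8-connected neighbours, averaged over
    pixels and scaled by 1/2 (border pixels simply have fewer neighbours)."""
    h, w = len(img), len(img[0])
    total = 0.0
    for y in range(h):
        for x in range(w):
            diffs = [
                (img[y][x] - img[yy][xx]) ** 2
                for yy in range(max(0, y - 1), min(h, y + 2))
                for xx in range(max(0, x - 1), min(w, x + 2))
                if (yy, xx) != (y, x)
            ]
            if diffs:  # a 1x1 image has no neighbours
                total += sum(diffs) / len(diffs)
    return total / (2 * h * w)


def dark_channel(img: List[List[RGB]], radius: int = 7) -> Gray:
    """Eq. (8): per-pixel minimum over R, G, B, then a min filter over Omega(x)."""
    h, w = len(img), len(img[0])
    chan_min = [[min(px) for px in row] for row in img]
    return [
        [
            min(
                chan_min[yy][xx]
                for yy in range(max(0, y - radius), min(h, y + radius + 1))
                for xx in range(max(0, x - radius), min(w, x + radius + 1))
            )
            for x in range(w)
        ]
        for y in range(h)
    ]


def dcp_loss(dehazed: List[List[RGB]], radius: int = 7) -> float:
    """Eq. (9): L1 norm of the dark channel (values are non-negative, so a plain sum)."""
    return sum(sum(row) for row in dark_channel(dehazed, radius))


img = [[(0.9, 0.8, 0.7), (0.5, 0.6, 0.9)],
       [(0.4, 0.9, 0.9), (0.95, 0.9, 0.85)]]
dark = dark_channel(img, radius=0)  # radius 0 keeps only the channel minimum
```

Minimizing `dcp_loss` on f_J(I) pushes the dehazed output toward the low dark-channel statistics of haze-free images; note that, in an autodiff framework, the hard minimum passes gradients only through the selected pixels.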
Note that the brightest 0.1% of the pixels in the dark channel are selected to estimate the global atmospheric light, denoted $A_{DCP}$. We can then specify the $A_{other}(x)$ term in Eq. (5) as $A_{DCP}(x)$.
Contrastive Loss. When calculating the contrastive loss $\mathcal{L}_{contrast}$, we set the reconstructed hazy image $\hat{I}(x)$ as the anchor, take the original hazy image $I(x)$ as the positive sample, and take the clear image $J(x)$ as the negative sample. In addition, we take $N_1$ other hazy images $I^i$ as additional positive samples and $N_2$ other clear images $J^i$ as additional negative samples. Formally,
$$ \mathcal{L}_{contrast} = \frac{D(\text{anchor}, I) + \sum_{i=1}^{N_1} D(\text{anchor}, I^i)}{D(\text{anchor}, J) + \sum_{i=1}^{N_2} D(\text{anchor}, J^i)} \tag{10} $$
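Eq. (10) can be sketched as a ratio of summed distances, as below. The inputs stand for feature vectors E(·) produced by whichever extractor is chosen (VGG-16 or J-Net); the extractor itself is not implemented, P = 1 is an assumed choice of norm, the small `eps` in the denominator is an added numerical safeguard, and the names are hypothetical.

```python
from typing import Callable, Sequence

Vector = Sequence[float]


def l1_distance(u: Vector, v: Vector) -> float:
    """D(x, y) = ||E(x) - E(y)||_1 for already-extracted feature vectors."""
    return sum(abs(a - b) for a, b in zip(u, v))


def contrastive_ratio_loss(
    anchor: Vector,
    positive: Vector,
    negative: Vector,
    extra_positives: Sequence[Vector] = (),
    extra_negatives: Sequence[Vector] = (),
    dist: Callable[[Vector, Vector], float] = l1_distance,
    eps: float = 1e-7,
) -> float:
    """Eq. (10): positives in the numerator (pulled closer), negatives in the denominator (pushed away)."""
    num = dist(anchor, positive) + sum(dist(anchor, p) for p in extra_positives)
    den = dist(anchor, negative) + sum(dist(anchor, n) for n in extra_negatives)
    return num / (den + eps)


# Toy features: anchor = E(I_hat), positive = E(I), negative = E(J), plus one extra of each.
loss = contrastive_ratio_loss(
    anchor=[0.0, 0.0], positive=[1.0, 0.0], negative=[2.0, 2.0],
    extra_positives=[[0.0, 1.0]], extra_negatives=[[3.0, 0.0]],
)
```

Because the negatives sit in the denominator, the loss in this sketch decreases as the anchor moves away from clear-image features, matching the push/pull roles described for the contrastive module.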
where $D(x, y)$ denotes the distance between two elements in the feature space. If the feature extractor $E$ is used, $D(x, y)$ can be expressed as $D(x, y) = \| E(x) - E(y) \|_P$, where $\| \cdot \|_P$ denotes the general $P$-norm of a matrix or a vector.
Adversarial Loss. The haze discriminator determines whether the dehazing result $f_J(I)$ is real or fake, which helps prevent image distortion. Following the discriminator of the least-squares GAN [23], we define the loss terms with the least-squares function:
$$ \mathcal{L}_{adv}(f_J) = \mathbb{E}_{f_J(I) \sim P_{fake}} \big[ (\mathrm{Discriminator}(f_J(I)) - 1)^2 \big] \tag{11} $$
$$ \mathcal{L}_{adv}(D) = \mathbb{E}_{f_J(I) \sim P_{fake}} \big[ (\mathrm{Discriminator}(f_J(I)) - 0)^2 \big] + \mathbb{E}_{y \sim P_{real}} \big[ (\mathrm{Discriminator}(y) - 1)^2 \big] \tag{12} $$
where $y$ denotes another real clear image. $\mathcal{L}_{adv}(f_J)$ and $\mathcal{L}_{adv}(D)$ are used to train J-Net and the discriminator, respectively. Combining all the above loss terms, the final loss function $\mathcal{L}_{total}$ of UCLD-Net is
$$ \mathcal{L}_{total} = \lambda_1 \mathcal{L}_I + \lambda_2 \mathcal{L}_A + \lambda_3 \mathcal{L}_{reg} + \lambda_4 \mathcal{L}_{DCP} + \lambda_5 \mathcal{L}_{contrast} + \lambda_6 \mathcal{L}_{adv}(f_J) \tag{13} $$
where $\lambda_i$, $i = 1, 2, \ldots, 6$, are weighting coefficients (likewise below).
3.3 Clear2clear Stage
As shown in Fig. 1, a clear image $J(x)$ is input and goes through a coupling and decoupling process. Note that, unlike in supervised methods, the clear images selected here do not need to be paired. Given a clear image $J(x)$, we obtain its depth map $d(x)$ using a pre-trained depth-estimation network. Next, by randomly selecting the values of the atmospheric light $A$ and the atmospheric scattering coefficient $\beta$, we synthesize a hazy image $I(x)$ according to Eq. (1). Specifically, $A$ is taken from the interval $[0.6, 0.97]$ and $\beta$ from the interval $[0.3, 1.0]$. After obtaining the synthesized hazy image, we perform its decoupling process in the same way as in the Hazy2hazy stage: applying the different sub-networks to $I(x)$ yields $f_J(I)$, $f_T(I)$ and $f_A(I)$, respectively. The loss function of the Clear2clear stage consists of three main components. The first loss term is calculated between the input clear image $J(x)$ and the dehazing result $f_J(I)$ in both the raw space and the feature space:
$$ \mathcal{L}_J = \| J(x) - f_J(I) \|_1 + \lambda_J D\big(J(x), f_J(I)\big) \tag{14} $$
where $D(x, y)$ has the same meaning as in Eq. (10). The second loss term evaluates the difference between the transmission map $T(x)$ and $f_T(I)$. Formally,
$$ \mathcal{L}_T = \| T(x) - f_T(I) \|_1 + \lambda_T D\big(T(x), f_T(I)\big) \tag{15} $$
The third loss term describes the gap between $A$ and $f_A(I)$:
$$ \mathcal{L}_A = \| A(x) - f_A(I) \|_1 + \mathrm{KL}\big(\mathcal{N}(\mu, \Sigma) \,\|\, \mathcal{N}(0, I)\big) \tag{16} $$
where the second (KL) term is the same as in Eq. (6). Finally, the total loss function is calculated as
$$ \mathcal{L}_{total} = \lambda_1 \mathcal{L}_J + \lambda_2 \mathcal{L}_T + \lambda_3 \mathcal{L}_A \tag{17} $$
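The coupling step of the Clear2clear stage can be sketched as follows. It assumes the standard relation t(x) = exp(−β d(x)) between transmission and depth, together with the scattering model assumed for Eq. (1); the depth map is taken as given (the pre-trained depth-estimation network is not modeled), A and β are drawn uniformly from the stated intervals (the text gives the intervals but not the distribution), and the function name is hypothetical.

```python
import math
import random
from typing import List, Tuple

Image = List[List[float]]


def synthesize_hazy(
    j: Image,
    depth: Image,
    rng: random.Random,
    a_range: Tuple[float, float] = (0.6, 0.97),
    beta_range: Tuple[float, float] = (0.3, 1.0),
) -> Tuple[Image, Image, float, float]:
    """Return (hazy I, transmission t, A, beta) for a clear image J and its depth map d."""
    a = rng.uniform(*a_range)        # global atmospheric light
    beta = rng.uniform(*beta_range)  # atmospheric scattering coefficient
    t = [[math.exp(-beta * d) for d in row] for row in depth]
    hazy = [[jv * tv + a * (1.0 - tv) for jv, tv in zip(rj, rt)] for rj, rt in zip(j, t)]
    return hazy, t, a, beta


rng = random.Random(1)
clear = [[0.2, 0.2]]
depth = [[0.0, 50.0]]  # a near pixel and a very distant one
hazy, t, a, beta = synthesize_hazy(clear, depth, rng)
```

The returned t and A play the roles of the targets T(x) and A(x) in Eqs. (15) and (16), while the clear input J is the target of Eq. (14); distant pixels converge to A, as the second toy pixel shows.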
4 Experimental Results
4.1 Dataset and Implementation Details
The most widely adopted image dehazing benchmark today is RESIDE [24], which contains synthetic hazy images of both indoor and outdoor scenes. We use ITS (Indoor Training Set) and OTS (Outdoor Training Set) for training, and SOTS (Synthetic Objective Test Set) and HSTS (Hybrid Subjective Test Set) for testing. As evaluation metrics, we adopt the common PSNR (Peak Signal-to-Noise Ratio) and SSIM (Structural Similarity). For UCLD-Net, the optimizer is Adam with default parameters; the initial learning rate is set to α = 0.0001 and decays following a cosine schedule. During training, images are randomly cropped and resized to 240 × 240, and random rotation and flipping are applied for data augmentation. The number of positive and negative samples is set to 4. All experiments are conducted on NVIDIA TITAN Xp GPUs with the PyTorch framework.
4.2 Quantitative and Qualitative Evaluation for Image Dehazing
In this section, we compare UCLD-Net with previous state-of-the-art image dehazing methods both quantitatively and qualitatively. These methods fall into four categories: 1) prior-based: DCP [2] and BCCR [3]; 2) supervised: AOD-Net [7], GFN [25], EPDN [26] and GCA-Net [8]; 3) semi-supervised: Semi-dehazing [18]; 4) unsupervised: CycleGAN [11], Cycle-Dehaze [12], LIGHT-Net [27] and YOLY [20]. We train on the ITS dataset and test on SOTS (both indoor and outdoor subsets) and HSTS. Table 1 shows that UCLD-Net performs well on all three test benchmarks: SOTS (indoor), SOTS (outdoor) and HSTS. Across methods of all learning paradigms, UCLD-Net ranks in the top two for every PSNR and SSIM column, and it achieves the best results among the unsupervised methods. Note that GCA-Net and Semi-dehazing outperform UCLD-Net on some metrics.
Nevertheless, both of them use paired (hazy, haze-free) images for training, which UCLD-Net does not require. We also present qualitative comparisons of dehazing results in Fig. 3. We observe that DCP produces halo artifacts because of its underlying prior assumption. AOD-Net sometimes fails to remove the haze entirely, and the brightness of
Table 1. Quantitative results on RESIDE for different dehazing methods. The best and second-best results in each column are marked with * and †, respectively.

Method          Type             SOTS (indoor)     SOTS (outdoor)    HSTS
                                 PSNR    SSIM      PSNR    SSIM      PSNR    SSIM
DCP             Prior            16.62   0.818     18.38   0.819     17.01   0.803
BCCR            Prior            16.88   0.791     15.71   0.769     15.21   0.747
AOD-Net         Supervised       19.06   0.850     20.08   0.861     19.68   0.835
GFN             Supervised       22.31   0.880     21.49   0.838     22.94   0.894
EPDN            Supervised       21.55   0.907     22.57   0.763     20.37   0.877
GCA-Net         Supervised       30.23*  0.980*    21.66   0.867     21.37   0.874
Semi-dehazing   Semi-supervised  24.44   0.891     24.79*  0.923*    24.36*  0.889
CycleGAN        Unsupervised     14.21   0.576     17.32   0.706     16.05   0.824
Cycle-Dehaze    Unsupervised     17.16   0.693     18.60   0.797     17.96   0.905†
LIGHT-Net       Unsupervised     22.57   0.903     23.11   0.917     22.27   0.871
YOLY            Unsupervised     19.41   0.833     20.39   0.889     21.02   0.777
UCLD-Net        Unsupervised     25.13†  0.926†    23.55†  0.919†    23.82†  0.933*
its dehazed images tends to be low. GCA-Net loses some high-frequency information in its dehazed images. Cycle-Dehaze is less convincing, as a lot of haze remains. YOLY can cause color distortion and loss of detail to some extent, and Semi-dehazing also does not remove haze completely. Compared with these methods, UCLD-Net removes haze effectively while preserving finer details.
4.3 Ablation Study
To further illustrate the effectiveness of the proposed modules, we conduct ablation studies. As shown in Table 2, the Clear2clear stage and all the added loss terms in the Hazy2hazy stage have positive effects, and their combination yields the best results. Adding the Clear2clear stage brings the most significant improvement; we speculate that this is because this stage provides a large amount of distribution information about real clear images. Since UCLD-Net is partly inspired by YOLY [20], we also transplant the proposed modules of UCLD-Net into YOLY to improve its dehazing performance. The ablation results in Table 3 follow the same pattern as those of UCLD-Net and again verify the necessity and effectiveness of the proposed modules.
Fig. 3. Qualitative comparisons on SOTS (the top 2 rows for indoor and the middle 2 rows for outdoor) and real-world hazy images (the bottom 2 rows, without corresponding ground truth images) for different dehazing methods. Zoom in on the green rectangle area for more details.
Table 2. Ablation study with different modules for UCLD-Net.

LJ, LA, Lreg  Lcontrast     Clear2clear   LDCP          Ladv          PSNR    SSIM
√                                                                     10.35   0.3875
√             √                                                       11.15   0.3971
√             √             √                                         21.68   0.8322
√             √             √             √                           23.71   0.9048
√             √             √             √             √             25.13   0.9258
Table 3. Ablation study for transplanting the modules of our UCLD-Net to YOLY.

LJ, LA, Lreg  Lcontrast     Clear2clear   LDCP          Ladv          PSNR    SSIM
√                                                                     16.26   0.7455
√             √                                                       17.26   0.7380
√             √             √                                         19.47   0.8117
√             √             √             √                           19.51   0.8187
√             √             √             √             √             19.53   0.8147
Comparing Table 3 with Table 2, we notice that without the Clear2clear stage, the performance of UCLD-Net degrades considerably, while that of YOLY degrades only slightly. We attribute this to the network structure of UCLD-Net being much more complicated than that of YOLY, which makes UCLD-Net much more prone to overfitting, especially when little reference information is available during training. After adding the two-stage training that includes Clear2clear, the overfitting is greatly alleviated because much more information is available for learning. Furthermore, the constraints of the DCP loss and the adversarial loss further refine and improve the dehazing effect.
5 Conclusion
In this paper, we propose an end-to-end decoupling network for image dehazing via unsupervised contrastive learning, named UCLD-Net. By using prior knowledge to design the network structure and loss functions, UCLD-Net achieves competitive dehazing results without requiring paired (hazy, haze-free) images for training, which is valuable in practical applications. Its advantages over existing unsupervised methods are demonstrated both quantitatively and qualitatively. In future work, it is worth further combining the strengths of unsupervised dehazing and contrastive learning and exploring their interpretability.
Acknowledgements. This work is supported by the Natural Science Foundation of China under grant 62071171 and the high-performance computing platform of Peking University.
References
1. Narasimhan, S.G., Nayar, S.K.: Vision and the atmosphere. Int. J. Comput. Vis. 48(3), 233–254 (2002). https://doi.org/10.1023/A:1016328200723
2. He, K., Sun, J., Tang, X.: Single image haze removal using dark channel prior. IEEE Trans. Pattern Anal. Mach. Intell. 33(12), 2341–2353 (2010)
3. Meng, G., Wang, Y., Duan, J., Xiang, S., Pan, C.: Efficient image dehazing with boundary constraint and contextual regularization. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 617–624 (2013)
4. Zhu, Q., Mai, J., Shao, L.: A fast single image haze removal algorithm using color attenuation prior. IEEE Trans. Image Process. 24(11), 3522–3533 (2015)
5. Cai, B., Xu, X., Jia, K., Qing, C., Tao, D.: DehazeNet: an end-to-end system for single image haze removal. IEEE Trans. Image Process. 25(11), 5187–5198 (2016)
6. Ren, W., Liu, S., Zhang, H., Pan, J., Cao, X., Yang, M.-H.: Single image dehazing via multiscale convolutional neural networks. In: European Conference on Computer Vision, pp. 154–169. Springer (2016)
7. Li, B., Peng, X., Wang, Z., Xu, J., Feng, D.: AOD-Net: all-in-one dehazing network. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 4770–4778 (2017)
8. Chen, D., et al.: Gated context aggregation network for image dehazing and deraining. In: IEEE Winter Conference on Applications of Computer Vision (WACV), pp. 1375–1383. IEEE (2019)
9. Qin, X., Wang, Z., Bai, Y., Xie, X., Jia, H.: FFA-Net: feature fusion attention network for single image dehazing. Proc. AAAI Conf. Artif. Intell. 34, 11908–11915 (2020)
10. Wu, H., et al.: Contrastive learning for compact single image dehazing. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10551–10560 (2021)
11. Zhu, J.-Y., Park, T., Isola, P., Efros, A.A.: Unpaired image-to-image translation using cycle-consistent adversarial networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2223–2232 (2017)
12. Engin, D., Genç, A., Ekenel, H.K.: Cycle-Dehaze: enhanced CycleGAN for single image dehazing. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops, pp. 825–833 (2018)
13. Anvari, Z., Athitsos, V.: Dehaze-GLCGAN: unpaired single image de-hazing via adversarial training. arXiv preprint arXiv:2008.06632 (2020)
14. Wang, Z., Ji, S.: Smoothed dilated convolutions for improved dense prediction. Data Min. Knowl. Disc. 35(4), 1470–1496 (2021)
15. Chen, Z., Wang, Y., Yang, Y., Liu, D.: PSD: principled synthetic-to-real dehazing guided by physical priors. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 7180–7189 (2021)
16. Yang, X., Xu, Z., Luo, J.: Towards perceptual image dehazing by physics-based disentanglement and adversarial training. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 32 (2018)
17. Zheng, Y., Jia, S., Zhang, S., Tao, M., Wang, L.: Dehaze-AGGAN: unpaired remote sensing image dehazing using enhanced attention-guide generative adversarial networks. IEEE Trans. Geosci. Remote Sens. 60, 1–13 (2022)
18. Li, L., et al.: Semi-supervised image dehazing. IEEE Trans. Image Process. 29, 2766–2779 (2019)
19. Wang, C., et al.: EAA-Net: a novel edge assisted attention network for single image dehazing. Knowl.-Based Syst. 228, 107279 (2021)
20. Li, B., Gou, Y., Gu, S., Liu, J.Z., Zhou, J.T., Peng, X.: You only look yourself: unsupervised and untrained single image dehazing neural network. Int. J. Comput. Vis. 129(5), 1754–1767 (2021)
21. Kingma, D.P., Welling, M.: Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114 (2013)
22. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. In: International Conference on Learning Representations (2015)
23. Mao, X., Li, Q., Xie, H., Lau, R.Y.K., Wang, Z., Smolley, S.P.: Least squares generative adversarial networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2794–2802 (2017)
24. Li, B., et al.: RESIDE: a benchmark for single image dehazing, vol. 1. arXiv preprint arXiv:1712.04143 (2017)
25. Ren, W., et al.: Gated fusion network for single image dehazing. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3253–3261 (2018)
26. Qu, Y., Chen, Y., Huang, J., Xie, Y.: Enhanced pix2pix dehazing network. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8160–8168 (2019)
27. Dudhane, A., Patil, P.W., Murala, S.: An end-to-end network for image de-hazing and beyond. IEEE Trans. Emerg. Top. Comput. Intell. 6(1), 159–170 (2022)
LMConvMorph: Large Kernel Modern Hierarchical Convolutional Model for Unsupervised Medical Image Registration
Zhaoyang Liu1,2, Xiuyang Zhao2(B), Dongmei Niu2, Bo Yang2, and Caiming Zhang3
1 Shandong Provincial Key Laboratory of Network Based Intelligent Computing, University of Jinan, Jinan, China
2 School of Information Science and Engineering, University of Jinan, Jinan, China
3 School of Software, Shandong University, Jinan, China
Abstract. Medical image registration is a crucial task for numerous medical applications, such as radiation therapy planning and surgical navigation. While deep learning methods have shown promise in improving registration accuracy, existing unsupervised registration methods based on convolutional neural networks struggle to capture long-range spatial relationships between voxels. Additionally, Transformer-based unsupervised registration methods are limited by their dependence on the inductive bias of convolutional neural networks and by the complexity of global attention. To address these limitations, we present LMConvMorph, a large kernel modern hierarchical convolutional model for unsupervised deformable medical image registration. LMConvMorph leverages larger receptive fields to identify spatial correspondences and employs a hierarchical design with fewer parameters to extract features at different scales, enabling effective feature extraction from the moving and fixed images. Our approach yields significant improvements in registration performance. LMConvMorph is evaluated on a 3D human brain magnetic resonance image dataset, and the qualitative and quantitative results demonstrate its competitiveness with other baseline methods.
Keywords: Image Registration · Medical Image · Convolutional Neural Network
1 Introduction
Deformable medical image registration allows the alignment of image pairs with complex deformations, such as those found within organs or tissues. This technology is crucial in medical imaging [1–3]: it establishes a correspondence between the contents of two images and aligns corresponding content in position, providing more comprehensive information for diagnosis, treatment planning, and image-guided interventions. Traditional image registration algorithms typically compute parameter transformations iteratively, including feature matching, spatial transformation, and similarity calculation, to minimize a similarity metric between the input image pair. Traditional image registration methods such as symmetric normalization (SyN) [4] and Demons [5] have gained wide acceptance. However, these methods are computationally expensive because they iteratively optimize a large number of parameters, which limits their clinical application and adoption. Recently, many researchers have explored deep learning-based image registration [6–8]. Compared with traditional iterative methods, deep learning shifts the cost of iterative optimization to model training, so the actual inference time is much shorter. This substantial reduction in registration time is one of the main advantages of deep learning-based image registration.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 216–226, 2023. https://doi.org/10.1007/978-981-99-4761-4_19

In recent years, the focus of medical image registration has shifted to unsupervised methods that do not rely on labels. Balakrishnan et al. [9] proposed VoxelMorph, which uses a convolutional neural network (CNN) architecture and auxiliary segmentation to map input image pairs to deformation fields by maximizing a standard image matching objective function. De Vos et al. [10] achieved affine and deformable registration through unsupervised training by stacking several networks. Following current trends in computer vision and medical imaging, transformer-based models such as ViT-V-Net [11] have demonstrated good performance in image registration by simply bridging ViT [12] and V-Net [13]. Chen et al. [14] further extended ViT-V-Net into TransMorph, which captures semantic correspondences between the fixed and moving images by using the more advanced Swin Transformer [15] as its backbone. Because of an inherent limitation of convolutional neural networks, namely their local receptive field, existing CNN-based methods cannot capture long-range spatial relationships [12]. Although transformers have achieved remarkable success in deformable medical image registration, transformer-based approaches still rely on the prior knowledge of CNNs [16].
Therefore, further study of CNN-based medical image registration is necessary. To address the above issues, we propose a new registration framework that uses modernized hierarchical convolution with large kernel sizes, based on the fully convolutional architecture ConvNeXt [17]. This framework, called LMConvMorph, is designed for volumetric medical image registration and leverages large kernel hierarchical convolution to effectively capture spatial correspondences between moving and fixed images. The main contributions of this work are as follows:
– We use depthwise convolution with large kernel sizes to capture spatial correspondences between moving and fixed images through a larger receptive field.
– We introduce a hierarchical design that extracts features at different scales with fewer parameters, which enhances the network’s ability to recognize spatial relationships between moving and fixed images.
– We conduct a series of experiments on the brain MRI registration dataset ADNI to validate the effectiveness of the proposed deformable image registration model. Qualitative and quantitative comparisons with different types of registration algorithms demonstrate the accuracy and robustness of our method.
2 Method
2.1 Problem Formulation
The registration problem is an energy function optimization problem, with the energy function defined as
$$ \hat{\phi} = \arg\min_{\phi} \mathcal{L}_{sim}(I_f, I_m \circ \phi) + \lambda R(\phi) \tag{1} $$
where $I_m$ and $I_f$ represent the moving and fixed images, respectively, and $\phi$ represents the deformation field. $\mathcal{L}_{sim}(I_f, I_m \circ \phi)$ measures the alignment between the moved image $I_m \circ \phi$ and the fixed image $I_f$. The regularization term $R(\phi)$ enforces spatial smoothness on the deformation field, and $\lambda$ is the regularization weight. The optimal deformation field $\hat{\phi}$ is obtained by minimizing the sum of the similarity metric $\mathcal{L}_{sim}(I_f, I_m \circ \phi)$ and the regularization term $R(\phi)$.
Fig. 1. Overview of the proposed method for deformable image registration.
2.2 Framework Overview
An overview of our framework is presented in Fig. 1. The network is an encoder-decoder with skip connections that help recover spatial details. For the encoder, we propose a modern hierarchical convolutional network with large kernels that effectively captures spatial information between the input image pair through larger receptive fields. The encoder takes both the fixed and moving images as input to obtain the spatial correspondence between them, and the decoder generates a dense prediction, i.e., the deformation field, of the same size as the input images. The spatial misalignment between the output moved image and the fixed image is measured by the loss function, which is backpropagated to the encoder. Finally, a spatial transformer network applies the resulting nonlinear transformation to the moving image.
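The spatial transformer step can be illustrated in 2D as below. This is a simplified sketch, not the paper's implementation: LMConvMorph warps 3D volumes (typically with trilinear interpolation, e.g. via PyTorch's `torch.nn.functional.grid_sample`), whereas this version uses bilinear interpolation on a 2D grid, treats the deformation as φ(p) = p + u(p) with a displacement field u, and pads with zeros outside the image; the function names are hypothetical.

```python
import math
from typing import List

Image = List[List[float]]


def bilinear_sample(img: Image, y: float, x: float) -> float:
    """Bilinearly interpolate img at a real-valued (y, x); zero outside the image."""
    h, w = len(img), len(img[0])
    y0, x0 = math.floor(y), math.floor(x)
    value = 0.0
    for dy in (0, 1):
        for dx in (0, 1):
            yy, xx = y0 + dy, x0 + dx
            weight = (1.0 - abs(y - yy)) * (1.0 - abs(x - xx))
            if 0 <= yy < h and 0 <= xx < w:
                value += weight * img[yy][xx]
    return value


def warp(moving: Image, flow_y: Image, flow_x: Image) -> Image:
    """[I_m o phi](p) = I_m(p + u(p)) for every pixel p of the fixed-image grid."""
    h, w = len(moving), len(moving[0])
    return [
        [bilinear_sample(moving, y + flow_y[y][x], x + flow_x[y][x]) for x in range(w)]
        for y in range(h)
    ]


moving = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
zeros = [[0.0] * 3 for _ in range(2)]
half = [[0.5] * 3 for _ in range(2)]
moved = warp(moving, zeros, half)  # sample half a pixel to the right
```

Because the interpolation weights are piecewise linear in the displacement, the warped image is differentiable with respect to u almost everywhere, which is what allows the similarity loss to train the network end to end.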
2.3 The Design of LMConvMorph
Figure 2(a) illustrates the architecture of the proposed LMConvMorph. The method uses a large kernel modern hierarchical convolutional neural network (LMConv) as the encoder to obtain the moving-fixed correspondence. The decoder, a 3D convolutional neural network (Conv3D), processes the obtained spatial correspondence into a dense displacement field. Skip connections maintain the flow of local information between the encoder and decoder. As shown in Fig. 2(b), the LMConv block consists of a 7 × 7 × 7 depthwise convolution, two 1 × 1 × 1 convolutional layers, a layer normalization (LN), and a GELU non-linearity (Gaussian Error Linear Unit, a smoother variant of ReLU).
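To make the parameter argument concrete, the sketch below counts the weights of one LMConv block and compares them with a single dense 7 × 7 × 7 convolution of the same width. The inverted-bottleneck expansion ratio of 4, the bias terms and the layer-norm scale/shift follow ConvNeXt conventions [17] and are assumptions here, since the text does not state them; stem and downsampling layers are not counted, and the function names are hypothetical.

```python
def lmconv_block_params(channels: int, kernel: int = 7, expansion: int = 4) -> int:
    """Weights in one block: depthwise k^3 conv -> LN -> 1x1x1 expand -> GELU -> 1x1x1 project."""
    c, hidden = channels, expansion * channels
    depthwise = c * kernel ** 3 + c   # one k x k x k filter per channel, plus bias
    layer_norm = 2 * c                # per-channel scale and shift
    expand = c * hidden + hidden      # 1x1x1 conv, C -> expansion * C
    project = hidden * c + c          # 1x1x1 conv, expansion * C -> C
    return depthwise + layer_norm + expand + project


def dense_conv_params(channels: int, kernel: int = 7) -> int:
    """A standard (non-depthwise) k^3 convolution mapping C channels to C channels."""
    return channels * channels * kernel ** 3 + channels


depths, widths = (3, 3, 9, 3), (96, 192, 384, 768)  # blocks and channels per stage
encoder_block_params = sum(d * lmconv_block_params(c) for d, c in zip(depths, widths))
ratio = dense_conv_params(96) / lmconv_block_params(96)  # about 29x at the first stage
```

Under these assumptions, the 7 × 7 × 7 depthwise filter accounts for only about 31% of a first-stage block's weights, so enlarging the kernel is relatively cheap; most parameters sit in the two 1 × 1 × 1 layers.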
Fig. 2. (a) The architecture of the proposed LMConvMorph registration network. (b) LMConv Block
Large Convolutional Kernels. Large convolutional kernels, with their wider receptive fields, can effectively capture long-range spatial relationships in deformed areas [18], improving the registration accuracy of these anatomical structures. This motivates us to rethink the role of large convolutional kernels in the registration process.
Volumetric Depthwise Convolution. Similar to the weighted-sum operation in self-attention, depthwise convolution operates on each channel separately, so information is mixed only in the spatial dimensions [17]. We apply depthwise convolution with large kernel sizes to compress the medical image features, thereby expanding the width of the network. Each convolutional kernel channel convolves only with its corresponding input channel, so the output features have the same channel dimension as the input.
Inverted Bottleneck. The LMConv block uses the inverted bottleneck design originally proposed in MobileNetV2 [19]. The intermediate expansion layer leverages
depthwise convolutions to filter features, acting as a source of non-linearity. This structure can partially reduce the number of model parameters while improving overall performance, with an increase in accuracy.
Layer Normalization. Layer normalization (LN) [20] is applied before the 1 × 1 × 1 convolutional layer. Batch normalization (BN) [21] is a widely used strategy in existing deep convolutional neural networks that normalizes the convolutional representation to enhance convergence and reduce overfitting. However, the batch size for registration tasks is set to 1, so batch statistics would be estimated from a single sample. Therefore, this method uses simple layer normalization instead of batch normalization.
LMConv uses a large convolution kernel to map the original image to a low-resolution latent feature representation. This not only reduces memory usage, parameter count, and computation, but also enlarges the network’s receptive field and converts the input image into the format required by the encoder. The processed image is then fed into the encoder for feature extraction. The distinctive features of our method are its lightness and large receptive field. The LMConv encoder employs a multi-stage design in which each stage has a different feature map resolution, since learning information at multiple scales is crucial for registration. Following the design in [17], the number of blocks per stage is set to (3, 3, 9, 3), and the number of channels doubles from stage to stage (96, 192, 384, 768). LMConv adopts a "patchify" strategy, implemented as a patchify layer with 4 × 4 × 4 convolutions and a stride of four, to suit the multi-stage architecture. However, this transformer-style design may fail to provide high-resolution feature maps and to aggregate local information in the early stages.
To solve this issue, we use plain Conv3D blocks that take the original and downsampled images as input, extracting local information and producing high-resolution feature maps. The decoder is composed of convolutional layers, activation functions, and normalization layers: each convolution is a 3D convolution with a kernel size of 3 and padding of 1, the activation function is Leaky ReLU, and normalization is performed with batch normalization (BN). The decoder stages have (384, 192, 96, 48, 16) channels. The multi-scale outputs of each encoder stage are linked to the convolutional decoder via long skip connections, forming a U-shaped network for the downstream registration task.

2.4 Loss Functions
The loss function of our method is:

$$\mathcal{L}(I_f, I_m, \phi) = \mathcal{L}_{sim}(I_f, I_m, \phi) + \lambda R(\phi) \tag{2}$$
where $\mathcal{L}_{sim}(\cdot)$ is the image similarity measure, $R(\phi)$ is the deformation field regularizer, and $\lambda$ is the regularization weight.

Image Similarity Measure. In our unsupervised setting, we use the mean squared error (MSE) as the similarity metric $\mathcal{L}_{sim}(\cdot)$ to quantify the distance between $I_f$ and $I_m$:

$$\mathrm{MSE}(I_f, I_m, \phi) = \frac{1}{|\Omega|} \sum_{x \in \Omega} \left| I_f(x) - [I_m \circ \phi](x) \right|^2 \tag{3}$$
LMConvMorph: Large Kernel Modern Hierarchical Convolutional Model
where $x$ denotes the voxel location, $[I_m \circ \phi]$ denotes the moving image warped by the deformation field, and $\Omega$ denotes the image domain.

Deformation Field Regularization. To impose smoothness on the deformation field, a regularizer $R(\phi)$ is included in the loss function. $R(\phi)$ encourages the displacement at each position to be similar to those at its neighboring positions, thereby promoting a smooth deformation field. We use a diffusion regularizer on the spatial gradients of the displacement to obtain a smooth displacement field $\phi$:

$$R(\phi) = \sum_{x \in \Omega} \left\| \nabla \phi(x) \right\|^2 \tag{4}$$
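Eqs. (2)–(4) can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the authors' implementation: the warped image [Im ◦ φ] is assumed to be precomputed (e.g. by a spatial transformer), the displacement field is stored as an array of shape (3, D, H, W), spatial gradients are approximated by forward differences, and λ = 0.02 is borrowed from the VoxelMorph baseline setting in Sect. 3.3 purely as an example value.

```python
import numpy as np


def mse(fixed, warped):
    """Eq. (3): mean of |I_f(x) - [I_m o phi](x)|^2 over the image domain."""
    return float(np.mean((fixed - warped) ** 2))


def diffusion_regularizer(disp):
    """Eq. (4): sum over voxels of the squared displacement gradient.

    disp has shape (3, D, H, W). Gradients are forward differences along
    each spatial axis, so the last slice along that axis has no term.
    """
    total = 0.0
    for axis in (1, 2, 3):
        total += float(np.sum(np.diff(disp, axis=axis) ** 2))
    return total


def registration_loss(fixed, warped, disp, lam=0.02):
    """Eq. (2): similarity term plus lambda-weighted smoothness term."""
    return mse(fixed, warped) + lam * diffusion_regularizer(disp)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    fixed = rng.random((8, 8, 8))
    warped = fixed + 0.1            # hypothetical warped moving image
    disp = np.zeros((3, 8, 8, 8))   # identity transform: zero displacement
    print(registration_loss(fixed, warped, disp))  # approximately 0.01
```

Eq. (4) is written as a sum, so the regularizer grows with the volume size; many implementations average over voxels instead, which changes the effective scale of λ.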
3 Experiments and Results
3.1 Dataset and Preprocessing
This experiment uses the ADNI [22] dataset, comprising 450 brain MRI scans. The scans undergo standard preprocessing with FreeSurfer [23], including skull stripping, resampling, and affine transformation. FreeSurfer is also used to obtain label maps of 30 anatomical structures for evaluating registration performance. The MRI scans are resampled to 256 × 256 × 256 with an isotropic voxel size of 1 mm × 1 mm × 1 mm and then cropped to 160 × 192 × 224. ADNI is divided into 300, 50, and 100 volumes for training, validation, and testing, respectively.
3.2 Implementation Details
The proposed method is implemented in PyTorch [24] and run on an NVIDIA GeForce RTX 3090 GPU. For all experiments, we use the Adam optimizer with a learning rate of $1 \times 10^{-4}$ and a batch size of 1. The maximum number of training epochs is set to 500 for all methods.
3.3 Baseline Methods
To demonstrate the effectiveness of the proposed method, we compare it with a traditional registration method (SyN) and three deep learning-based registration methods (VoxelMorph, ViT-V-Net, TransMorph). For a fair comparison, all deep learning-based methods use the same loss function, consisting of the mean squared error (Eq. 3) and the diffusion regularizer (Eq. 4), and all methods are trained separately on the same dataset splits.
– SyN: We use the SyN implementation in the publicly available ANTs software package [25], with the mean squared difference (MSQ) as the objective function and the numbers of iterations set to (160, 80, 40) across resolution levels.
– VoxelMorph: We use VoxelMorph-2 from the VoxelMorph paper [9]; the loss function consists of MSE and diffusion regularization, and the regularization hyperparameter λ is set to 0.02.
– ViT-V-Net: A network built on the Vision Transformer [12]. We adopt the hyperparameter settings used in [11].
– TransMorph: A registration framework based on the Swin Transformer [15]. Its hyperparameters are set as suggested in [14].
3.4 Evaluation Metrics
The Dice Similarity Coefficient (DSC) is used to evaluate the registration performance of each model by measuring the voxel overlap between corresponding anatomical structures. DSC ranges from 0 to 1, with higher values indicating greater overlap between corresponding regions of the two images and thus better registration. The average DSC is computed over all anatomical structures in the test set. In addition, we evaluate the smoothness of the deformation field by calculating the percentage of voxels at which the Jacobian determinant of the deformation field is non-positive.
3.5 Results
For quantitative analysis, we adopt atlas-based registration and compute the deformation field φ between the atlas and each image in the dataset. First, each test image is registered to the atlas image (i.e., the fixed image in the registration framework) to obtain a deformation field φ. The deformation field φ is then used to warp the corresponding anatomical label maps, from which the DSC and the Jacobian determinant of φ are calculated. The quantitative evaluations are shown in Table 1, which reports the Dice score and the percentage of voxels with a non-positive Jacobian determinant (i.e., folded voxels) for the different methods. Bold numbers denote the highest scores. Supplementary metrics, namely the average running time required to register each pair of scans, the number of parameters, and the multiply-accumulate operations (MACs) on the ADNI dataset, are also provided.

Table 1. Quantitative evaluation results.

| Model | DSC | % of Jφ ≤ 0 | Time (s) | Params (M) | MACs (G) |
|---|---|---|---|---|---|
| Affine | 0.525 ± 0.062 | - | - | - | - |
| SyN | 0.725 ± 0.014 | < 0.0001 | 176 | - | - |
| VoxelMorph | 0.730 ± 0.016 | 0.343 ± 0.170 | 0.1529 | 0.301 | 399.808 |
| ViT-V-Net | 0.734 ± 0.013 | 0.356 ± 0.181 | 0.2753 | 31.495 | 388.912 |
| TransMorph | 0.741 ± 0.013 | 0.316 ± 0.129 | 0.2759 | 46.660 | 712.066 |
| Ours | **0.744 ± 0.013** | 0.316 ± 0.128 | 0.2431 | 52.491 | 733.494 |
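The two metrics reported in Table 1 can be sketched as follows. Assumptions not taken from the paper: segmentations are integer label maps of identical shape, the displacement field u has shape (3, D, H, W) in voxel units with φ(x) = x + u(x), and derivatives are taken with NumPy central differences (`np.gradient`); the paper does not specify its discretization, so this is one reasonable choice.

```python
import numpy as np


def dice(seg_a, seg_b, label):
    """Dice similarity coefficient for one anatomical label."""
    a, b = seg_a == label, seg_b == label
    denom = a.sum() + b.sum()
    return 2.0 * np.logical_and(a, b).sum() / denom if denom else 1.0


def mean_dice(seg_a, seg_b, labels):
    """Average DSC over the given labels (pass structures only, not background)."""
    return float(np.mean([dice(seg_a, seg_b, lab) for lab in labels]))


def folding_percentage(disp):
    """Percentage of voxels with det(J_phi) <= 0, where phi(x) = x + u(x).

    disp: displacement u with shape (3, D, H, W), in voxel units.
    """
    # grads[i][j] holds d u_i / d x_j (central differences in the interior).
    grads = [np.gradient(disp[i], axis=(0, 1, 2)) for i in range(3)]
    jac = np.empty(disp.shape[1:] + (3, 3))
    for i in range(3):
        for j in range(3):
            # Jacobian of phi = identity + Jacobian of u.
            jac[..., i, j] = grads[i][j] + (1.0 if i == j else 0.0)
    det = np.linalg.det(jac)  # batched determinant over all voxels
    return 100.0 * float(np.mean(det <= 0))
```

A zero displacement gives an all-identity Jacobian and therefore 0% folding, while a displacement that reverses the orientation of an axis drives the determinant negative everywhere.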
As shown in Table 1, LMConvMorph achieves the highest Dice score of 0.744. In a paired t-test, the p-values of LMConvMorph against all other methods are p < 0.05, with the exception of LMConvMorph against TransMorph (p = 0.054). We also compare the efficiency of the different methods. The traditional registration method SyN is the slowest, taking 176 s to register a pair of images. Our method registers faster than the Transformer-based methods (ViT-V-Net and TransMorph), completing the registration of a 3D image pair in about 0.25 s, which is sufficient for real-time clinical diagnosis.

Fig. 3. Qualitative comparison of various registration methods (panels: Moving, Fixed, SyN, VoxelMorph, ViT-V-Net, TransMorph, Ours). We select 2D slices from the 3D MR images to better demonstrate the qualitative results of the experiment.
We compare the visualized results of our method with those of the other medical image registration methods, showing the generated deformation fields and the registered images of each method. Figure 3 displays the visualized results of all the methods. Moreover, to examine the benefit of the proposed framework for specific anatomical regions, we visualize the Dice score of each structure as boxplots in Fig. 4.
Fig. 4. Boxplots of the Dice scores of SyN, VoxelMorph, ViT-V-Net, TransMorph, and our proposed method over 9 anatomical regions.
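The paired t-test mentioned in Sect. 3.5 compares two methods on the same test cases, so it operates on per-case differences. A minimal standard-library sketch of the statistic is shown below; the inputs are assumed to be per-case mean Dice scores listed in matching order, which the text does not specify.

```python
import math
import statistics


def paired_t(scores_a, scores_b):
    """Paired t statistic and degrees of freedom for per-case scores.

    scores_a[i] and scores_b[i] must refer to the same test case.
    """
    if len(scores_a) != len(scores_b) or len(scores_a) < 2:
        raise ValueError("need two equal-length lists with at least 2 cases")
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    n = len(diffs)
    # t = mean difference / standard error of the mean difference
    t = statistics.mean(diffs) / (statistics.stdev(diffs) / math.sqrt(n))
    return t, n - 1
```

The two-sided p-value then follows from the t distribution with n − 1 degrees of freedom; if SciPy is available, `scipy.stats.ttest_rel(a, b)` returns the same statistic together with the p-value.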
3.6 Ablation Study
We conduct ablation experiments to examine how the kernel size k of the depthwise convolution affects registration. The setting k = 0 denotes a variant in which all LMConv encoder blocks in the architecture of Fig. 2(a) are replaced with Conv3D blocks. The results are shown in Table 2. Without the large kernel modern hierarchical convolutional block designed in this paper, registration accuracy drops, with a DSC of 0.731. The experiment also indicates that as the kernel size increases, the receptive field grows accordingly, more structural information between image pairs can be extracted, and the folding rate of the deformation field generally decreases. Among the tested kernel sizes (3, 5, 7, 9, and 11), DSC reaches its best value of 0.744 at k = 7. These results show that the large kernel modern hierarchical convolutional block can effectively capture the spatial correspondence between the moving and fixed images.

Table 2. Quantitative evaluation results of the ablation study.

| k | DSC | % of Jφ ≤ 0 |
|---|---|---|
| 0 | 0.731 ± 0.017 | 0.351 ± 0.170 |
| 3 | 0.737 ± 0.014 | 0.338 ± 0.132 |
| 5 | 0.738 ± 0.013 | 0.341 ± 0.144 |
| 7 | 0.744 ± 0.013 | 0.316 ± 0.128 |
| 9 | 0.742 ± 0.015 | 0.311 ± 0.116 |
| 11 | 0.739 ± 0.014 | 0.303 ± 0.149 |
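To make the link between kernel size and receptive field in this ablation concrete, the sketch below applies the standard recurrence r_out = r_in + (k − 1) · j_in, where j is the cumulative stride, to the first encoder stage: the 4 × 4 × 4 stride-4 patchify layer followed by three stride-1 k × k × k depthwise convolutions. Assumptions: 1 × 1 × 1 convolutions are ignored because they do not enlarge the receptive field, and the k = 0 (Conv3D) variant is not modelled.

```python
def receptive_field(layers):
    """Theoretical receptive field (per axis, in input voxels).

    layers: sequence of (kernel_size, stride) pairs, applied in order.
    """
    rf, jump = 1, 1
    for k, s in layers:
        rf += (k - 1) * jump  # each layer widens the field by (k-1) input steps
        jump *= s             # cumulative stride seen by the next layer
    return rf


def first_stage_rf(k, n_blocks=3):
    # patchify stem (kernel 4, stride 4) followed by n_blocks depthwise k^3 convs
    return receptive_field([(4, 4)] + [(k, 1)] * n_blocks)


if __name__ == "__main__":
    for k in (3, 5, 7, 9, 11):
        print(k, first_stage_rf(k))
```

In this idealized count, each increase of k by 2 widens the first-stage receptive field by 24 voxels along each axis; the ablation nevertheless shows DSC peaking at k = 7, so a larger theoretical receptive field alone does not guarantee better registration.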
4 Conclusion
In this study, we introduced a novel registration framework based on modern hierarchical convolutional neural networks. Specifically, we designed a large kernel modern hierarchical convolutional block to efficiently extract features from the moving and fixed images, using a larger receptive field to enhance the network's ability to identify spatial correspondence between image pairs. In addition, we adopted a hierarchical structure to extract features at different scales, further improving the network's ability to recognize spatial relationships in medical images. We evaluated LMConvMorph on inter-patient brain MR registration. The results indicate that this registration model has great potential for brain MRI registration and could provide strong support for clinical medicine and neuroscience research.

Acknowledgements. This research was supported by the Natural Science Foundation of Shandong Province (Nos. ZR2019MF013, ZR2020KF015) and the Project of Jinan Scientific Research Leader's Laboratory (No. 2018GXRC023).
References
1. Bharati, S., Mondal, M.R.H., Podder, P., Prasath, V.B.S.: Deep learning for medical image registration: a comprehensive review. arXiv preprint arXiv:2204.11341 (2022)
2. Boveiri, H.R., Khayami, R., Javidan, R., Mehdizadeh, A.: Medical image registration using deep neural networks: a comprehensive review. Comput. Electr. Eng. 87, 106767 (2020)
3. Haskins, G., Kruger, U., Yan, P.: Deep learning in medical image registration: a survey. Mach. Vis. Appl. 31(1–2), 1–18 (2020). https://doi.org/10.1007/s00138-020-01060-x
4. Avants, B., Epstein, C., Grossman, M., Gee, J.: Symmetric diffeomorphic image registration with cross-correlation: evaluating automated labeling of elderly and neurodegenerative brain. Med. Image Anal. 12, 26–41 (2008)
5. Vercauteren, T., Pennec, X., Perchant, A., Ayache, N.: Diffeomorphic demons: efficient nonparametric image registration. Neuroimage 45, S61–S72 (2009)
6. Miao, S., Wang, Z.J., Zheng, Y., Liao, R.: Real-time 2D/3D registration via CNN regression. In: 2016 IEEE 13th International Symposium on Biomedical Imaging (ISBI), pp. 1430–1434 (2016)
7. Uzunova, H., Wilms, M., Handels, H., Ehrhardt, J.: Training CNNs for image registration from few samples with model-based data augmentation. In: Descoteaux, M., Maier-Hein, L., Franz, A., Jannin, P., Collins, D.L., Duchesne, S. (eds.) MICCAI 2017. LNCS, vol. 10433, pp. 223–231. Springer, Cham (2017). https://doi.org/10.1007/978-3-319-66182-7_26
8. Wang, J., Zhang, M.: DeepFLASH: an efficient network for learning-based medical image registration. In: 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 4443–4451 (2020)
9. Balakrishnan, G., Zhao, A., Sabuncu, M.R., Guttag, J., Dalca, A.V.: An unsupervised learning model for deformable medical image registration. In: 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9252–9260 (2018)
10. de Vos, B.D., Berendsen, F.F., Viergever, M.A., Sokooti, H., Staring, M., Išgum, I.: A deep learning framework for unsupervised affine and deformable image registration. Med. Image Anal. 52, 128–143 (2019)
11. Chen, J., Frey, E.C., He, Y., Segars, W.P., Li, Y., Du, Y.: TransMorph: transformer for unsupervised medical image registration. Med. Image Anal. 82, 102615 (2022). https://doi.org/10.1016/j.media.2022.102615
12. Dosovitskiy, A., et al.: An image is worth 16 × 16 words: transformers for image recognition at scale. In: International Conference on Learning Representations (2021)
13. Milletari, F., Navab, N., Ahmadi, S.-A.: V-Net: fully convolutional neural networks for volumetric medical image segmentation. In: 2016 Fourth International Conference on 3D Vision (3DV), pp. 565–571. IEEE (2016)
14. Chen, J., Frey, E.C., He, Y., Segars, W.P., Li, Y., Du, Y.: TransMorph: transformer for unsupervised medical image registration. Med. Image Anal. 82, 102615 (2022)
15. Liu, Z., et al.: Swin transformer: hierarchical vision transformer using shifted windows. In: 2021 IEEE/CVF International Conference on Computer Vision (ICCV), pp. 9992–10002. IEEE, Montreal, QC, Canada (2021)
16. Park, N., Kim, S.: How do vision transformers work? arXiv preprint arXiv:2202.06709 (2022)
17. Liu, Z., Mao, H., Wu, C.-Y., Feichtenhofer, C., Darrell, T., Xie, S.: A ConvNet for the 2020s. In: 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 11966–11976. IEEE, New Orleans, LA, USA (2022)
18. Jia, X., Bartlett, J., Zhang, T., Lu, W., Qiu, Z., Duan, J.: U-Net vs transformer: is U-Net outdated in medical image registration? In: Machine Learning in Medical Imaging: 13th International Workshop, MLMI 2022, Held in Conjunction with MICCAI 2022, Singapore, September 18, 2022, Proceedings, pp. 151–160. Springer (2022). https://doi.org/10.1007/978-3-031-21014-3_16
19. Sandler, M., Howard, A., Zhu, M., Zhmoginov, A., Chen, L.-C.: MobileNetV2: inverted residuals and linear bottlenecks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4510–4520 (2018)
20. Ba, J.L., Kiros, J.R., Hinton, G.E.: Layer normalization. arXiv preprint arXiv:1607.06450 (2016)
21. Ioffe, S., Szegedy, C.: Batch normalization: accelerating deep network training by reducing internal covariate shift. In: Proceedings of the 32nd International Conference on Machine Learning, pp. 448–456. PMLR (2015)
22. Mueller, S.G., et al.: Ways toward an early diagnosis in Alzheimer's disease: the Alzheimer's Disease Neuroimaging Initiative (ADNI). Alzheimer's & Dementia 1, 55–66 (2005)
23. Fischl, B.: FreeSurfer. NeuroImage 62, 774–781 (2012)
24. Paszke, A., et al.: PyTorch: an imperative style, high-performance deep learning library. In: Advances in Neural Information Processing Systems. Curran Associates, Inc. (2019)
25. Avants, B.B., Tustison, N.J., Song, G., Cook, P.A., Klein, A., Gee, J.C.: A reproducible evaluation of ANTs similarity metric performance in brain image registration. Neuroimage 54, 2033–2044 (2011)
Joint Skeleton and Boundary Features Networks for Curvilinear Structure Segmentation
Yubo Wang1,2, Li Chen1,2(✉), Zhida Feng1,2, and Yunxiang Cao1,2
1 School of Computer Science and Technology, Wuhan University of Science and Technology, Wuhan, China
2 Hubei Province Key Laboratory of Intelligent Information Processing and Real-Time Industrial System, Wuhan University of Science and Technology, Wuhan, China
Abstract. Curvilinear structure segmentation has wide-ranging practical applications across many fields. However, existing methods achieve low topological accuracy when segmenting curvilinear structures and struggle to maintain complete topological connectivity in their results. To address these problems, we propose an encoder-decoder segmentation network for curvilinear structures that jointly exploits skeleton and boundary features. The network has three decoding branches that extract semantic, skeleton, and boundary features, respectively, and after every decoder layer the outputs of the three branches are fused by a feature joint unit. Furthermore, adaptive connection units are added between the encoder and the decoder to selectively capture information from the encoder. We evaluate the method on three public datasets for curvilinear structure segmentation: retinal images for clinical diagnosis, coronary angiography images, and road crack images. Experimental results show that the method outperforms existing state-of-the-art methods in terms of both pixel-level accuracy and topological connectivity.
Keywords: Curvilinear Structure Segmentation · Skeleton Features · Boundary Features
1 Introduction
A curvilinear structure is a set of connected lines or curves of a certain width, whose structural features are particularly important. Such objects are widely found in medical, remote sensing, microscopy, and other images. Curvilinear structure segmentation produces a binary mask of the curvilinear objects in such images. For example, it can be used to segment blood vessels in medical images to help doctors diagnose and treat certain lesions [1, 2], to automatically detect cracks in road images captured by drones [3], and to extract road networks from remote sensing images [4, 5].

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 227–238, 2023.
https://doi.org/10.1007/978-981-99-4761-4_20

Segmenting curvilinear structures typically involves two main difficulties. The first is that these images are often complex in composition: the target curves show no obvious difference from the background, and the curves themselves vary in thickness and structure. This makes the extraction and segmentation of thin curves very challenging and makes it difficult to preserve complete topological connectivity; limitations of the imaging equipment, the acquisition environment, and other factors make the challenge even more daunting. The second difficulty concerns annotations. The ground-truth labels of curvilinear image datasets, such as medical image datasets, are regarded as the gold standard and are often labeled manually by experienced experts, who can use their rich prior knowledge to segment thin curves accurately even when these are indistinguishable from the background or entirely missing in the image. A deep learning network could acquire similar prior knowledge if it were trained on sufficient gold-standard data; unfortunately, the publicly available curvilinear segmentation datasets are far from adequate for training.

Existing advanced segmentation methods [6–9] achieve strong pixel-level segmentation accuracy, but the overall topological connectivity of the curves is broken by under-segmentation. These methods do not address the loss of thin curves described above: thin-curve pixels are far fewer than the pixels of thick, easy-to-segment curves, so missing them has little impact on overall pixel-level accuracy and is easily overlooked. However, thin curves are essential for the topological connectivity of the segmented structure. To address these problems, we propose an encoder-decoder segmentation network with joint skeleton and boundary features that simultaneously extracts the semantic, skeleton, and boundary features of an image and uses the skeleton and boundary features to constrain and supervise the segmentation result, improving topological connectivity while maintaining pixel-level accuracy.
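The argument above, that missing a few thin-curve pixels barely changes pixel-level scores yet breaks connectivity, can be illustrated with a toy example in pure Python. The mask layout, the use of 8-connectivity, and all helper names are illustrative choices, not part of the proposed method.

```python
from collections import deque


def pixel_scores(pred, label):
    """Precision, recall and F1 over foreground pixels (binary 2D lists)."""
    tp = fp = fn = 0
    for prow, lrow in zip(pred, label):
        for p, l in zip(prow, lrow):
            p, l = bool(p), bool(l)
            tp += p and l
            fp += p and not l
            fn += l and not p
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def count_components(mask):
    """Number of 8-connected foreground components (breadth-first search)."""
    h, w = len(mask), len(mask[0])
    seen = [[False] * w for _ in range(h)]
    count = 0
    for r in range(h):
        for c in range(w):
            if mask[r][c] and not seen[r][c]:
                count += 1
                seen[r][c] = True
                queue = deque([(r, c)])
                while queue:
                    y, x = queue.popleft()
                    for dy in (-1, 0, 1):
                        for dx in (-1, 0, 1):
                            ny, nx = y + dy, x + dx
                            if (0 <= ny < h and 0 <= nx < w
                                    and mask[ny][nx] and not seen[ny][nx]):
                                seen[ny][nx] = True
                                queue.append((ny, nx))
    return count


def make_toy(height=20, width=40, gap=None):
    """Thick horizontal vessel plus a one-pixel-wide vertical branch."""
    mask = [[0] * width for _ in range(height)]
    for r in range(2, 10):          # thick vessel: 8 rows
        for c in range(width):
            mask[r][c] = 1
    for r in range(10, height):     # thin branch hanging from the vessel
        mask[r][width // 2] = 1
    if gap is not None:             # drop a single branch pixel
        mask[gap][width // 2] = 0
    return mask


if __name__ == "__main__":
    label = make_toy()
    pred = make_toy(gap=12)                                 # one missing pixel
    print(pixel_scores(pred, label))                        # precision 1.0, recall and F1 above 0.99
    print(count_components(label), count_components(pred))  # 1 2
```

Removing one of 330 foreground pixels lowers recall by only about 0.3%, yet the predicted structure splits into two pieces, which is exactly the kind of error that pixel-level metrics barely register.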
2 Related Work
Curvilinear structure segmentation is a semantic segmentation task: given an input image, a prediction is made for each pixel to determine whether it belongs to the background or to the target. Earlier traditional methods, such as threshold-based, region-based, graph-theoretic, and clustering-based segmentation, have not achieved satisfactory results. With the rise of deep learning, several deep learning-based semantic segmentation methods have achieved impressive results, such as the Fully Convolutional Network (FCN) [10], SegNet [11], U-Net [12], and its variants [13, 14]. U-Net is a deep neural network with an encoder-decoder architecture: the encoder uses convolution and pooling to capture image context, the decoder gradually restores the original image size by up-sampling and cropping, and skip connections combine features from different layers. To improve pixel-level segmentation accuracy, Attention U-Net [13] adds an attention mechanism on top of the U-Net architecture, integrating attention gates into different decoder layers to suppress redundant information from the skip connections. Recurrent U-Net [15] borrows the concept of recurrent neural networks: it first introduces gated recurrent units between the encoder and decoder of the U-shaped network and then treats the entire U-shaped network as a recurrent unit, refining features over multiple iterations. DA-Net [17] uses a two-branch transformer module that combines local information at the patch level with global context at the image level to achieve high accuracy. To preserve the topological connectivity of segmentation results, SCR-Net [16] exploits the structural consistency of synthesized data to build a structure-consistent recovery network that extracts high-frequency components from the data while maintaining structural consistency. JTFN [29], on the other hand, replaces each decoder layer with an FIM to obtain boundary features of images and preserves topological connectivity through multiple feedback iterations.

These approaches have achieved impressive results in the semantic segmentation of curvilinear structures. However, they are insensitive to the structural characteristics of curvilinear objects, which leads to partial or even complete discontinuities and omissions in the segmented structures. Building on existing research, we propose a joint skeleton and boundary feature network (JSBF) for curvilinear structure segmentation, which additionally captures the skeleton and boundary features of the curvilinear structure to constrain it and improve its topological connectivity without losing pixel accuracy. Compared with previous methods, the contributions of this paper are:
(1) Two auxiliary decoder branches are designed to capture skeleton features and boundary features.
(2) A feature joint unit is introduced after each decoder layer so that the primary decoder can acquire the additional features extracted by the auxiliary decoders.
(3) Adaptive connection units are added between the encoder and the decoder to help the decoder selectively capture information from the encoder.
3 Method
3.1 Generation of Skeleton Labels and Boundary Labels
A skeleton is a structure-based descriptor of the target that carries information such as the position and topology of a curve, and thus provides an effective characterization of curve topology. The effectiveness of using skeleton features to constrain segmentation has been demonstrated experimentally in [18, 19]. Studies [20, 21] show that the topology and connectivity of boundaries are related and that topological connectivity can help identify boundaries. Inspired by these findings, we jointly use skeleton and boundary features to enhance the topological connectivity of the segmentation results.
The skeleton labels are generated with the Zhang-Suen algorithm [22], which iteratively removes contour pixels from the binary image until only the pixels on the skeleton remain. In each iteration, the value of a pixel depends on its own value and those of its 8 neighbors in the previous iteration. Each iteration consists of two sub-steps: the first removes south-east boundary points and north-west corner points, and the second removes north-west boundary points and south-east corner points.
To obtain boundary labels, we apply the Canny operator to the labels: the image is smoothed by Gaussian filtering, the gradient magnitude and direction are computed, non-maximum suppression is applied to find local maxima, strong and weak edges are determined by double thresholding, and the edges are linked by edge tracking to extract the boundary lines of the labeled image. The resulting skeleton and boundary labels are illustrated in Fig. 1.

Fig. 1. Skeleton labels and boundary labels (panels: Ground Truth, Skeleton, Boundary).

3.2 JSBF Architecture
As shown in Fig. 2, the proposed curvilinear structure segmentation network with joint skeleton and boundary features consists of a backbone encoder, a three-branch decoder with feature joint units, and four adaptive connection units.
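The Zhang-Suen thinning used in Sect. 3.1 to derive skeleton labels can be sketched in pure Python as follows. This is a textbook version written for illustration; the list-of-lists representation and the border handling are our choices, and libraries such as scikit-image offer tested implementations (e.g. `skimage.morphology.skeletonize`). The Canny step for boundary labels is not reproduced here.

```python
def zhang_suen(image):
    """Zhang-Suen thinning of a binary image given as a list of 0/1 rows.

    Pixels on the outer one-pixel border are never examined, so the input
    should have a background margin.
    """
    img = [[1 if v else 0 for v in row] for row in image]
    h, w = len(img), len(img[0])

    def neighbours(r, c):
        # P2..P9 clockwise, starting from the pixel directly above.
        return [img[r - 1][c], img[r - 1][c + 1], img[r][c + 1], img[r + 1][c + 1],
                img[r + 1][c], img[r + 1][c - 1], img[r][c - 1], img[r - 1][c - 1]]

    changed = True
    while changed:
        changed = False
        for step in (0, 1):
            to_delete = []
            for r in range(1, h - 1):
                for c in range(1, w - 1):
                    if not img[r][c]:
                        continue
                    p = neighbours(r, c)
                    p2, p3, p4, p5, p6, p7, p8, p9 = p
                    b = sum(p)  # number of foreground neighbours
                    a = sum(p[i] == 0 and p[(i + 1) % 8] == 1 for i in range(8))  # 0->1 transitions
                    if step == 0:   # south-east boundary / north-west corner
                        cond = p2 * p4 * p6 == 0 and p4 * p6 * p8 == 0
                    else:           # north-west boundary / south-east corner
                        cond = p2 * p4 * p8 == 0 and p2 * p6 * p8 == 0
                    if 2 <= b <= 6 and a == 1 and cond:
                        to_delete.append((r, c))
            for r, c in to_delete:  # delete in parallel after the scan
                img[r][c] = 0
            changed = changed or bool(to_delete)
    return img
```

A one-pixel-wide line is already a skeleton and is returned unchanged, while a filled rectangle is reduced to a thin line along its middle.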
Fig. 2. The overall structure of JSBF consists of an encoder and three decoders. Each layer of the decoder has an FJU, and the encoder and decoder are connected by the ACU.
Specifically, the input image T is fed into a VGG-like encoder with five convolutional blocks, and its high-dimensional features F(i) are obtained through successive downsampling. F(i) is then fed into three decoder branches, which upsample and decode it to extract semantic, skeleton, and boundary features, respectively; the three kinds of features then exchange information through the feature joint units. Each upsampling module in the decoder consists of a deconvolution layer and two convolution layers. The final layer of each decoder branch is the output layer, which contains a 1 × 1 convolution and a Sigmoid function. In addition to the features from the previous decoder layer, each decoder layer receives the features of the corresponding encoder layer after they pass through the adaptive connection unit.

3.3 Feature Joint Units
The feature joint unit (FJU) proposed in this paper serves as a channel of communication between the semantic, skeleton, and boundary features. It is introduced only in the decoder, because the decoding layers place more emphasis on semantic features than the encoding layers. Figure 3(a) shows the structure of the FJU. The input $d_s$ on the left is the skeleton feature output by the upsampling layer, the input $d_t$ in the middle is the feature used for semantic segmentation, and the input $d_b$ on the right is the boundary feature. The whole FJU can thus be represented as:
$$\hat{d}_s, \hat{d}_t, \hat{d}_b = \mathrm{FJU}(d_s, d_t, d_b) \tag{1}$$

where $\hat{d}_s$, $\hat{d}_t$, and $\hat{d}_b$ are the skeleton, semantic, and boundary features, respectively, after feature union. The $d_t$ branch first fuses $d_s$ and $d_b$ to obtain the joint semantic features, and the other two branches then fuse the joint semantic features to obtain the joint skeleton features and joint boundary features, respectively. This design fuses the semantic decoder features into the skeleton and boundary features, thereby improving the topological connectivity of the semantic segmentation results.

3.4 Adaptive Connection Unit
U-Net uses skip connections to combine encoder and decoder features directly during upsampling, integrating their information. This helps recover the spatial information of the curvilinear structure in the decoding stage and improves segmentation accuracy, but it does not eliminate the interference of image noise. We therefore introduce an adaptive connection unit (ACU) into the skip connection between the encoder and decoder, so that the decoder can adaptively and selectively combine encoder features and reduce the interference of background noise. The structure of the ACU is shown in Fig. 3(b). It takes as input the encoder features $e_t$ of the corresponding layer and the decoder features $d_t$ of the previous layer, and adaptively filters the information of both to obtain the final skip-connection information:

$$c_t = \mathrm{SA}(e_t + d_t) \times \mathrm{CA}(e_t + d_t) \tag{2}$$
$$o_t = e_t \times c_t + \phi^{2}(e_t + d_t) \times c_t \tag{3}$$

where SA denotes spatial attention, CA denotes channel attention, and $\phi^{2}$ denotes two convolution layers.
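The data flow of the ACU can be sketched in NumPy, reading Eq. (2) as c_t = SA(e_t + d_t) × CA(e_t + d_t) and Eq. (3) as o_t = e_t × c_t + φ²(e_t + d_t) × c_t. The internals are not given in the text, so everything below is a hypothetical stand-in: parameter-free attention maps (sigmoid of a spatial or channel mean), and φ² as two 1 × 1 convolutions with a ReLU in between. The real unit presumably uses learned attention weights.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def channel_attention(x):
    """Hypothetical CA: sigmoid of each channel's spatial mean. x: (C, H, W)."""
    return sigmoid(x.mean(axis=(1, 2), keepdims=True))       # (C, 1, 1)


def spatial_attention(x):
    """Hypothetical SA: sigmoid of the channel-wise mean map."""
    return sigmoid(x.mean(axis=0, keepdims=True))             # (1, H, W)


def two_pointwise_convs(x, w1, w2):
    """Stand-in for phi^2: two 1x1 convolutions (channel mixing) with a ReLU."""
    h = np.maximum(np.einsum("oc,chw->ohw", w1, x), 0.0)
    return np.einsum("oc,chw->ohw", w2, h)


def acu(e_t, d_t, w1, w2):
    """Adaptive connection unit following the reading of Eqs. (2)-(3) above."""
    s = e_t + d_t
    c_t = spatial_attention(s) * channel_attention(s)         # broadcasts to (C, H, W)
    return e_t * c_t + two_pointwise_convs(s, w1, w2) * c_t
```

Because c_t lies strictly between 0 and 1 at every position, both the raw encoder features and the transformed sum are attenuated wherever attention is low, which is how such a gate can suppress background responses.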
Fig. 3. (a) Feature Joint Unit (FJU). (b) Adaptive Connection Unit (ACU).
3.5 Joint Loss Function
Because of the two auxiliary decoder branches, the network has three outputs in total. The training losses of the skeleton extraction task and the boundary extraction task serve as regularization terms added to the training loss of the segmentation task. The binary cross-entropy (BCE) loss is used for all three terms, and the joint loss function is defined as:

$$L = L_{bce}(y, g) + \lambda_s L_{bce}(y_s, g_s) + \lambda_b L_{bce}(y_b, g_b) \tag{4}$$
where $y$ is the segmentation prediction and $g$ is the ground truth; $y_s$ is the skeleton prediction and $g_s$ is the skeleton label; $y_b$ is the boundary prediction and $g_b$ is the boundary label; and $\lambda_s$ and $\lambda_b$ are weighting factors.
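Eq. (4) can be sketched in NumPy as below, assuming each decoder branch outputs sigmoid probabilities of the same shape as its label. The default weights λs = λb = 1.0 are placeholders; the values used in the experiments are not stated in this section.

```python
import numpy as np


def bce(pred, target, eps=1e-7):
    """Mean binary cross-entropy; pred holds sigmoid probabilities."""
    p = np.clip(pred, eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(target * np.log(p) + (1 - target) * np.log(1 - p)))


def joint_loss(y, g, ys, gs, yb, gb, lam_s=1.0, lam_b=1.0):
    """Eq. (4): segmentation BCE plus weighted skeleton and boundary BCE terms."""
    return bce(y, g) + lam_s * bce(ys, gs) + lam_b * bce(yb, gb)
```

In PyTorch, `torch.nn.BCELoss` (mean reduction by default) computes essentially the same per-term quantity on probabilities.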
4 Experiments
4.1 Datasets
To verify that the proposed model applies to a wide range of curvilinear structure segmentation tasks, we evaluate it on three publicly available datasets: the retinal image dataset DRIVE [23], the X-ray angiography coronary artery disease dataset XCAD [24], and the road crack dataset CrackTree [25]. The DRIVE dataset contains 40 retinal images for clinical diagnosis, of which 33 show healthy retinas and 7 show signs of mild early diabetic retinopathy. The images are divided into a training set and a test set of 20 images each: the last 20 images are used for training and the first 20 for testing. The XCAD dataset comprises coronary angiograms acquired with a General Electric Innova IGS 520 system during stent placement; each image is a single-channel 512 × 512 image. The dataset contains 126 coronary angiograms annotated for vessel segmentation by an experienced radiologist.
Of these angiograms, 84 are allocated for model training, and 42 serve as test images, also with expert annotations for vessel segmentation. CrackTree has 206 images of the road surface, labeled with a mask of only one-pixel width. Therefore, CrackTree is sensitive to topology-aware evaluation. The dataset is not divided in the literature and the first 150 are used for training and the rest for testing. 4.2 Evaluation Metrics To evaluate the pixel accuracy and topological connectivity of the proposed network model for segmenting curvilinear structure corpora, two types of evaluation metrics are used for this experiment. For the pixel-level evaluation, we chose the F1 score, precision and recall that are widely used in existing semantic segmentation methods. The topologybased metric is relatively robust compared to the pixel-level metric, which is sensitive to small variations. To assess mask topological consistency, a similar approach to [26, 27] was used. A randomly selected connection path from the label was selected and the equivalent path in the binary prediction mask was analyzed. If such a path does not exist, the predicted path is classified as infeasible. If it exists and the length error between the two paths is within 10%, it is classified as the correct path. One thousand paths are samples for each test image. The percentage of infeasible paths and the percentage of correct paths are defined as: k ai (5) Infea. = ki=1 i=1 Ni k bi Correct. = ki=1 (6) i=1 Ni where ai and bi are the number of infeasible and correct paths for each test image, respectively, and Ni is the sum of the number of paths. 4.3 Comparison To verify the validity of the method, JSBF and other advanced curve segmentation methods, including U-Net [12], IterNet [28], JTFN [29], FR-Unet [30], are compared on three datasets, DRIVE, XCAD, and CrackTree200. To ensure the fairness and reliability of the experiments. 
We only compare against results from models we replicated under the same settings. The results are shown in Table 1. The proposed method achieves the best value on 13 of the 15 metrics. It achieves the highest pixel-level F1 scores of 78.77, 74.37, and 83.05 on XCAD, CrackTree, and DRIVE, respectively. It also achieves the lowest values of the topological connectivity metric Infeasible, at 31.91, 94.87, and 36.75, again ahead of the alternative methods.
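The pixel-level scores and the path-based metrics of Eqs. (5) and (6) can be illustrated with a minimal sketch. The sketch makes several assumptions. Binary masks are given as nested lists. Path length is the 8-connected shortest path measured in pixel steps and found by breadth-first search. Endpoints are sampled from the label, so a label path always exists. The helper names (`pixel_scores`, `path_length`, `classify_path`, `topology_rates`) are hypothetical. This is not the evaluation code of [26, 27] or of this paper.

```python
from collections import deque
from typing import Optional

Mask = list[list[int]]   # binary mask: 1 = curvilinear structure, 0 = background
Point = tuple[int, int]  # (row, col)


def pixel_scores(pred: Mask, label: Mask) -> tuple[float, float, float]:
    """Pixel-level precision, recall and F1 (in %) for one image."""
    tp = fp = fn = 0
    for pred_row, label_row in zip(pred, label):
        for p, l in zip(pred_row, label_row):
            tp += int(p == 1 and l == 1)
            fp += int(p == 1 and l == 0)
            fn += int(p == 0 and l == 1)
    precision = 100 * tp / (tp + fp) if tp + fp else 0.0
    recall = 100 * tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def path_length(mask: Mask, start: Point, end: Point) -> Optional[int]:
    """Shortest 8-connected path length (in steps) inside the mask, or None."""
    h, w = len(mask), len(mask[0])
    if not (mask[start[0]][start[1]] and mask[end[0]][end[1]]):
        return None
    dist = {start: 0}
    queue = deque([start])
    while queue:
        r, c = queue.popleft()
        if (r, c) == end:
            return dist[(r, c)]
        for dr in (-1, 0, 1):
            for dc in (-1, 0, 1):
                nr, nc = r + dr, c + dc
                if not (dr or dc) or not (0 <= nr < h and 0 <= nc < w):
                    continue
                if mask[nr][nc] and (nr, nc) not in dist:
                    dist[(nr, nc)] = dist[(r, c)] + 1
                    queue.append((nr, nc))
    return None


def classify_path(pred: Mask, label: Mask, start: Point, end: Point,
                  tolerance: float = 0.10) -> str:
    """Classify one sampled label path as 'infeasible', 'correct' or 'other'."""
    gt_len = path_length(label, start, end)
    if gt_len is None:
        raise ValueError("endpoints must be connected in the label")
    pred_len = path_length(pred, start, end)
    if pred_len is None:
        return "infeasible"
    if abs(pred_len - gt_len) <= tolerance * gt_len:
        return "correct"
    return "other"


def topology_rates(counts: list[tuple[int, int, int]]) -> tuple[float, float]:
    """Eqs. (5)-(6): counts holds (a_i, b_i, N_i) per test image; returns % values."""
    total = sum(n for _, _, n in counts)
    infeasible = 100 * sum(a for a, _, _ in counts) / total
    correct = 100 * sum(b for _, b, _ in counts) / total
    return infeasible, correct


label = [[1, 1, 1, 1, 1]]
broken = [[1, 1, 0, 1, 1]]
print(classify_path(broken, label, (0, 0), (0, 4)))  # infeasible
print(topology_rates([(1, 2, 4), (0, 3, 6)]))        # (10.0, 50.0)
```

The aggregate in `topology_rates` pools the counts over all images before dividing, as Eqs. (5) and (6) do. It does not average per-image percentages.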
Table 1. Comparison with other segmentation methods. Some of the results in the table are not given in the original paper, and the results we give are reproduced.

| Dataset | Method | Precision/%↑ | Recall/%↑ | F1/%↑ | Infeasible/%↓ | Correct/%↑ |
|---|---|---|---|---|---|---|
| XCAD | U-Net | 77.95 | 74.32 | 75.75 | 45.24 | 48.15 |
| | IterNet | 70.70 | 79.30 | 74.11 | 36.38 | 56.05 |
| | JTFN | 75.37 | 78.33 | 76.81 | 35.21 | 55.17 |
| | FR-Unet | 73.35 | 78.88 | 74.27 | 33.41 | 58.71 |
| | Ours | 77.01 | 81.41 | 78.77 | 31.91 | 60.13 |
| CrackTree | U-Net | 77.64 | 54.74 | 65.08 | 98.61 | 1.35 |
| | IterNet | 78.57 | 63.35 | 71.17 | 98.45 | 1.45 |
| | JTFN | 78.13 | 67.07 | 71.90 | 97.99 | 1.92 |
| | FR-Unet | 69.32 | 68.48 | 68.94 | 96.63 | 3.23 |
| | Ours | 81.15 | 69.13 | 74.37 | 94.87 | 4.78 |
| DRIVE | U-Net | 82.74 | 80.59 | 81.41 | 52.52 | 39.83 |
| | IterNet | 83.07 | 81.00 | 81.97 | 42.76 | 49.32 |
| | JTFN | 83.41 | 82.76 | 82.81 | 40.19 | 51.17 |
| | FR-Unet | 83.29 | 82.89 | 82.92 | 47.62 | 43.46 |
| | Ours | 83.01 | 83.10 | 83.05 | 36.75 | 52.27 |
These results demonstrate that our method improves topological connectivity while preserving the pixel-level accuracy of the segmentation. Its Recall values of 81.41, 69.13, and 83.10 on the three datasets exceed those of FR-Unet by 2.53, 0.65, and 0.21 points, respectively. Figure 4 shows segmentation examples from JSBF and the competing methods. Rows 1 to 6 correspond to coronary angiography images, road crack images, and retinal images. Compared with the labels, the method segments fine curvilinear structures accurately. Even where the structures are thin and have low contrast with the background, the extracted structures are rarely broken or missing. Some background noise is also suppressed.

4.4 Ablation Study
Ablation experiments on XCAD evaluate the skeleton feature branch, the boundary feature branch, and the adaptive information connectivity unit in terms of segmentation accuracy and topological connectivity. The baseline model M0 is a U-shaped network that has neither the skeleton and boundary feature branches nor the adaptive information connectivity unit. M1 and M2 add the skeleton and boundary feature branches to the baseline, respectively, and M3 adds the adaptive connectivity unit. The quantitative results are shown in Table 2.
[Fig. 4 columns: Image, Ground Truth, U-Net, FR-Unet, Ours]
Fig. 4. Examples of segmentation by JSBF and other methods. These examples suggest that the skeleton and boundary feature branches help the network compensate for broken parts, and that the Adaptive Connection Unit suppresses background noise.
Adding the skeleton feature extraction branch to the segmentation network increases the Correct value on XCAD by 9.51, from 47.33 to 56.84. This suggests that the branch improves the network's ability to capture topological structure and enhances the topological connectivity of the segmentation results. Adding the boundary feature extraction branch increases both F1 and Correct, to 76.79 and 54.77, respectively, which improves pixel-level accuracy and topological connectivity. Introducing the adaptive connectivity unit raises Precision from 75.88 to 76.98, indicating that interference from background noise is suppressed. A further ablation experiment examines whether the performance gains come simply from additional parameters. It compares our model with two others. The first is a U-Net model M4 with a similar number of parameters. The second is a model M5, in which the two auxiliary branches of our method are replaced by semantic feature extraction branches identical to the main branch. The quantitative results are shown in Table 3.
Table 2. Ablations of Skeleton and Boundary feature, ACU

| Method | Precision/%↑ | Recall/%↑ | F1/%↑ | Infeasible/%↓ | Correct/%↑ |
|---|---|---|---|---|---|
| M0 | 75.88 | 74.32 | 75.46 | 46.37 | 47.33 |
| M1 | 76.27 | 76.77 | 76.54 | 36.32 | 56.84 |
| M2 | 76.34 | 77.25 | 76.79 | 38.29 | 54.77 |
| M3 | 76.98 | 75.32 | 75.92 | 40.89 | 52.51 |
| Ours | 77.01 | 81.41 | 78.77 | 31.91 | 60.13 |
Table 3. Ablations of model parameters on XCAD.

| Method | Param | Precision/%↑ | Recall/%↑ | F1/%↑ | Infeasible/%↓ | Correct/%↑ |
|---|---|---|---|---|---|---|
| M4 | 13M | 77.95 | 74.32 | 75.75 | 45.24 | 48.15 |
| M5 | 16M | 76.35 | 75.24 | 76.08 | 45.43 | 46.99 |
| Ours | 16M | 77.01 | 81.41 | 78.77 | 31.91 | 60.13 |
Our method's F1 and Correct values are 3.02 and 11.98 points higher than those of the U-Net model M4, which has a similar number of parameters. They are also 2.69 and 13.14 points higher than those of the three-branch model M5, which has the same number of parameters. This suggests that the improvement does not come from the larger number of parameters. Rather, the additional skeleton and boundary features help the network achieve better segmentation results.
5 Conclusion
We propose a novel network to address broken and missing segmentation of curvilinear structures. Our approach couples skeleton and boundary feature encoding with segmentation decoding, and it employs a three-branch decoder to extract semantic, skeleton, and boundary features. The method improves topological connectivity and reduces broken and missing structures in the segmentation results while maintaining high pixel-level accuracy. Comparative and ablation experiments confirm the effectiveness of the skeleton and boundary feature extraction branches. This study opens several directions for future research. Next, we will explore extending the method to 3D images of curvilinear structures, whose spatial structure is more complex and therefore more challenging.

Acknowledgment. This work was supported by National Natural Science Foundation of China (62271359).
References
1. Sau, P.C., Bansal, A.: A novel diabetic retinopathy grading using modified deep neural network with segmentation of blood vessels and retinal abnormalities. Multimedia Tools and Applications 81(27), 39605–39633 (2022)
2. Kar, M.K., Neog, D.R., Nath, M.K.: Retinal vessel segmentation using multi-scale residual convolutional neural network (msr-net) combined with generative adversarial networks. Circuits Systems Signal Process. 42(2), 1206–1235 (2023)
3. Yang, F., Zhang, L., Yu, S., Prokhorov, D., Mei, X., Ling, H.: Feature pyramid and hierarchical boosting network for pavement crack-detection. IEEE Trans. Intell. Transp. Syst. 21(4), 1525–1535 (2019)
4. Batra, A., Singh, S., Pang, G., Basu, S., Jawahar, C., Paluri, M.: Improved road connectivity by joint learning of orientation and segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10385–10393 (2019)
5. Ventura, C., Pont-Tuset, J., Caelles, S., Maninis, K.K., Van Gool, L.: Iterative deep learning for road topology extraction. arXiv preprint arXiv:1808.09814 (2018)
6. Chen, L.C., Papandreou, G., Kokkinos, I., Murphy, K., Yuille, A.L.: Deeplab: Semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected crfs. IEEE Trans. Pattern Anal. Mach. Intell. 40(4), 834–848 (2017)
7. Chen, L.-C., Zhu, Y., Papandreou, G., Schroff, F., Adam, H.: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11211, pp. 833–851. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01234-2_49
8. Li, K., Qi, X., Luo, Y., Yao, Z., Zhou, X., Sun, M.: Accurate retinal vessel segmentation in color fundus images via fully attention-based networks. IEEE J. Biomed. Health Inform. 25(6), 2071–2081 (2020)
9. Zhu, Z., Xu, M., Bai, S., Huang, T., Bai, X.: Asymmetric non-local neural networks for semantic segmentation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 593–602 (2019)
10. Long, J., Shelhamer, E., Darrell, T.: Fully convolutional networks for semantic segmentation. In: Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 3431–3440 (2015)
11. Badrinarayanan, V., Kendall, A., Cipolla, R.: Segnet: A deep convolutional encoder-decoder architecture for image segmentation. IEEE Trans. Pattern Anal. Mach. Intell. 39(12), 2481–2495 (2017)
12. Ronneberger, O., Fischer, P., Brox, T.: U-net: Convolutional networks for biomedical image segmentation. In: Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5–9, 2015, Proceedings, Part III 18, pp. 234–241. Springer (2015)
13. Oktay, O., et al.: Attention u-net: Learning where to look for the pancreas. arXiv preprint arXiv:1804.03999 (2018)
14. Alom, M.Z., Yakopcic, C., Hasan, M., Taha, T.M., Asari, V.K.: Recurrent residual u-net for medical image segmentation. Journal of Medical Imaging 6(1), 014006 (2019)
15. Wang, W., Yu, K., Hugonot, J., Fua, P., Salzmann, M.: Recurrent u-net for resource-constrained segmentation. In: Proceedings of the IEEE/CVF international conference on computer vision, pp. 2142–2151 (2019)
16. Li, H., et al.: Structure-consistent restoration network for cataract fundus image enhancement. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2022: 25th International Conference, Singapore, September 18–22, 2022, Proceedings, Part II, pp. 487–496. Springer (2022)
17. Wang, C., Xu, R., Xu, S., Meng, W., Zhang, X.: Da-net: Dual branch transformer and adaptive strip upsampling for retinal vessels segmentation. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2022: 25th International Conference, Singapore, September 18–22, 2022, Proceedings, Part II, pp. 528–538. Springer (2022)
18. Jerripothula, K.R., Cai, J., Lu, J., Yuan, J.: Object co-skeletonization with co-segmentation. In: 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 3881–3889. IEEE (2017)
19. Wshah, S., Shi, Z., Govindaraju, V.: Segmentation of arabic handwriting based on both contour and skeleton segmentation. In: 2009 10th International Conference on Document Analysis and Recognition, pp. 793–797. IEEE (2009)
20. Jain, V., et al.: Boundary learning by optimization with topological constraints. In: 2010 IEEE Computer Society Conference on Computer Vision and Pattern Recognition, pp. 2488–2495. IEEE (2010)
21. Xu, Z., Sun, Y., Liu, M.: Topo-boundary: A benchmark dataset on topological road-boundary detection using aerial images for autonomous driving. IEEE Robotics and Automation Letters 6(4), 7248–7255 (2021)
22. Zhang, T.Y., Suen, C.Y.: A fast parallel algorithm for thinning digital patterns. Communications of the ACM 27(3), 236–239 (1984)
23. Staal, J., Abràmoff, M.D., Niemeijer, M., Viergever, M.A., Van Ginneken, B.: Ridge-based vessel segmentation in color images of the retina. IEEE transactions on medical imaging 23(4), 501–509 (2004)
24. Ma, Y., et al.: Self-supervised vessel segmentation via adversarial learning. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 7536–7545 (2021)
25. Zou, Q., Cao, Y., Li, Q., Mao, Q., Wang, S.: Cracktree: Automatic crack detection from pavement images. Pattern Recogn. Lett. 33(3), 227–238 (2012)
26. Araújo, R.J., Cardoso, J.S., Oliveira, H.P.: A deep learning design for improving topology coherence in blood vessel segmentation. In: Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part I 22, pp. 93–101. Springer (2019)
27. Wegner, J.D., Montoya-Zegarra, J.A., Schindler, K.: A higher-order crf model for road network extraction. In: Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 1698–1705 (2013)
28. Li, L., Verma, M., Nakashima, Y., Nagahara, H., Kawasaki, R.: Iternet: retinal image segmentation utilizing structural redundancy in vessel networks. In: Proceedings of the IEEE/CVF winter conference on applications of computer vision, pp. 3656–3665 (2020)
29. Cheng, M., Zhao, K., Guo, X., Xu, Y., Guo, J.: Joint topology-preserving and feature-refinement network for curvilinear structure segmentation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 7147–7156 (2021)
30. Liu, W., et al.: Full-resolution network and dual-threshold iteration for retinal vessel and coronary angiograph segmentation. IEEE J. Biomed. Health Inform. 26(9), 4623–4634 (2022)
A Driver Abnormal Behavior Detection Method Based on Improved YOLOv7 and OpenPose
Xingquan Cai, Shun Zhou, Jiali Yao, Pengyan Cheng, and Yan Hu(✉)
School of Information Science and Technology, North China University of Technology, Beijing 100144, China
Abstract. Current methods for detecting abnormal driver behavior are affected by complex backgrounds, body self-occlusion, and other factors. These lead to low accuracy in small object detection and large errors in human keypoint detection. In this paper, we propose a driver abnormal behavior detection method based on improved YOLOv7 and OpenPose. Firstly, we add an attention mechanism to the backbone network of YOLOv7 to increase the feature information in the shallow feature maps. We also introduce the Wise-IoU loss function to improve the detection accuracy of small and elongated objects (e.g., cigarettes, water glasses). Secondly, we calculate the IoU between the bounding boxes of the driver's hand and of small objects to detect abnormal driver-object interactions (e.g., smoking, drinking). Thirdly, we improve the OpenPose network by replacing convolutional kernels and changing how the convolutional layers are connected, which enables two-dimensional keypoint detection of the driver's driving posture. Finally, we use the FastDTW algorithm to calculate the similarity of the driver's two-dimensional keypoint information and thereby detect abnormal driving postures. The experimental results show that the proposed method achieves high small object detection accuracy and low keypoint error. It can effectively detect abnormal behaviors during driving such as smoking, drinking, and abnormal posture.
Keywords: Abnormal Behavior · Attentional Mechanisms · Driving Posture
1 Introduction
With the rapid development of the country's economy, the transportation industry has also boomed, and cars have become an important part of people's daily life and work. Demand for cars has skyrocketed and traffic problems have become increasingly prominent. Traffic accidents that endanger lives and property are not uncommon [1]. Recent advances in machine vision technology, computational performance, and public datasets have rapidly improved abnormal driving behavior recognition [2]. This paper adopts video-based driver abnormal behavior detection. By regulating and disciplining drivers' behavior, such detection can help prevent traffic accidents and benefit transportation centers, large fleets, and bus companies.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 239–250, 2023.
https://doi.org/10.1007/978-981-99-4761-4_21
2 Related Work
Current research on abnormal driver behavior follows two main directions. The first concerns abnormal behavior with physiological characteristics, chiefly fatigued driving. The second concerns abnormal behavior with psychological characteristics, chiefly distracted driving. Distracted driving is further divided into cognitive, manual, and visual distraction. Cognitive distraction cannot currently be identified reliably by technical means. Manual and visual distraction, however, can both be treated as distracting behavior with characteristics visible to the naked eye. With the development of machine vision technology, they can therefore be detected from images. The driver's head posture is also an important cue for recognizing distracted driving. Tawari [3] proposed a distributed camera framework called CoHMEt, in which each camera detects and tracks facial features and builds a geometric model of the head to output the head pose in real time. In 2016, Diaz-Chito [4] built a head model by locating the centers of both eyes and the tip of the nose to estimate the driver's head deflection angle. In 2018, Diaz-Chito estimated the driver's head deflection and pitch angles using histograms of oriented gradients, manifold embedding projection, and continuous regression [5]. Borghi proposed the POSEidon+ framework for estimating the driver's head and shoulder pose from depth images [6]. Inspired by the PointNet++ framework, Hu [7] proposed a point cloud-based driver head pose estimation method that produces 6D vectors for head pose estimation. Liu [8] proposed a lightweight framework called Recurrent Multitasking Thin Net to estimate the driver's pose, represented by nine body joints, five facial keypoints, and three head Euler angles.
Because driving is done through hand movements, monitoring the driver's hands is equally important for detecting distraction. Deng [9] used a Faster R-CNN-based system for hand detection and rotation estimation, with a rotation layer after ROI pooling. Wang [10] combined hand features with cascaded convolutional neural network classification to achieve multi-feature hand detection. Xia used an AdaBoost detector with aggregated channel features for driver hand detection based on the Fast R-CNN framework [11, 12]. Yuen [13] used a modified OpenPose model to estimate affinity heat maps of the wrist and elbow joints of the driver and passenger. Part Affinity Fields learned the positions of these joints and matched them by their correlations, which localized both hands of the driver and passenger in the vehicle. Chen [14] proposed a deep learning algorithm for abnormal driving behavior monitoring based on an improved YOLOv3, providing a methodological reference for identifying and controlling risks in pilot driving behavior. Based on this analysis, we combine YOLOv7-based [15] object detection with OpenPose-based [16] pose estimation to detect abnormal behaviors that occur while the driver is driving.
3 Proposed Method
Current driver abnormal behavior detection methods rely mainly on object detection. Object detection can only identify abnormal driving behavior from physical objects the driver holds, such as cigarettes, water cups, and cell phones. On its own, it can hardly recognize abnormal driving postures that involve no physical object, such as hands off the steering wheel, one-handed driving, or head deviation. The proposed method, based on improved YOLOv7 and OpenPose, recognizes both kinds of abnormal behavior in four main steps. Firstly, we add an attention mechanism to the YOLOv7 backbone and introduce the Wise-IoU loss function, which makes the network more sensitive to small targets and improves the detection accuracy of offending objects. Secondly, we combine the categories of the offending objects detected by YOLOv7 with the driver's hand positions to determine whether abnormal behavior is present. Then, we use the improved OpenPose method to detect the driver's joint points. Finally, we use the FastDTW algorithm to calculate the similarity between the driver's posture and the normal driving posture, which determines whether the driver shows any abnormal posture while driving.

3.1 Add Attention Module
During feature extraction, YOLOv7's backbone network does not fully extract the shallow and middle-level texture and contour information that matters for small target detection. This degrades small target detection and easily causes missed detections. This paper therefore introduces the ACmix attention module to strengthen the network's attention to small targets. As shown in Fig. 1, the ACmix attention module combines two modules, convolutional attention and self-attention, in parallel. It accounts for both the relationships among inputs and the relationships between inputs and outputs, which reduces missed detections. ACmix is a fusion module based on channel and spatial attention. Its main purpose is to improve network performance by strengthening the correlations between different channels of the feature map and between spatial locations. The attention weights are calculated as shown in Eq. (1):

$$W = \mathrm{sigmoid}(V) \tag{1}$$

where each element of W is the attention weight at the corresponding position and V denotes the fused feature vector. The original feature map is then weighted element-wise by W to obtain the final ACmix attention feature map. For an original feature map X and weight vector W, the ACmix attention feature map Y is calculated as shown in Eq. (2):

$$Y = W \odot X \tag{2}$$

where ⊙ denotes element-wise multiplication.
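Equations (1) and (2) amount to a sigmoid gate applied element-wise to the feature map. The NumPy sketch below illustrates only this gating step. The section does not specify how V is fused. The `fuse_paths` helper is therefore hypothetical: purely for illustration, it assumes a weighted sum of a convolution-path output and a self-attention-path output with scalar weights `alpha` and `beta`. This is not the ACmix implementation.

```python
import numpy as np


def sigmoid(v: np.ndarray) -> np.ndarray:
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-v))


def fuse_paths(conv_out: np.ndarray, attn_out: np.ndarray,
               alpha: float = 1.0, beta: float = 1.0) -> np.ndarray:
    """Hypothetical fusion of the two parallel paths into V (illustrative only)."""
    return alpha * attn_out + beta * conv_out


def gate_features(x: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Eq. (1): W = sigmoid(V); Eq. (2): Y = W ⊙ X (element-wise product)."""
    if x.shape != v.shape:
        raise ValueError("X and V must have the same shape")
    w = sigmoid(v)
    return w * x


# Usage on a (batch, channels, height, width) feature map.
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 8, 16, 16))
v = fuse_paths(rng.standard_normal(x.shape), rng.standard_normal(x.shape))
y = gate_features(x, v)
print(y.shape)  # (1, 8, 16, 16)
```

Because every weight lies in (0, 1), the gate can only attenuate each activation, never amplify it. The "attention" consists of deciding how much of each position of X to keep.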
Fig. 1. ACmix schematic diagram.
The core idea of the ACmix attention module is to obtain global and local feature information through multiple pooling methods and feature maps of different sizes. An attention mechanism then weights and fuses this information, improving the expressiveness and generalization ability of the model.

3.2 Replacement Loss Function
In this paper, the Wise-IoU loss function used for object detection combines an IoU loss with a localization loss. The localization term is the smooth L1 loss, which is equivalent to the Huber loss with threshold 1, calculated as shown in Eq. (3):

$$L_{loc}(p, g) = \sum_{i \in Pos} \mathrm{smooth}_{L1}(p_i - g_i) \tag{3}$$
where p denotes the coordinates of the predicted box output by the network, g denotes the coordinates of the ground-truth box, and Pos denotes the set of positive samples. smooth_L1 denotes the smooth L1 loss, calculated as shown in Eq. (4):

$$\mathrm{smooth}_{L1}(x) = \begin{cases} 0.5x^2, & \text{if } |x| < 1 \\ |x| - 0.5, & \text{otherwise} \end{cases} \tag{4}$$

The IoU loss is calculated as shown in Eq. (5):

$$L_{IoU}(p, g) = -\log \frac{\mathrm{IoU}(p, g)}{\mathrm{IoU}_{max}} \tag{5}$$
where IoU(p, g) denotes the intersection over union between the predicted box p and the ground-truth box g, and IoU_max is a threshold that normalizes the range of the IoU loss. Finally, the Wise-IoU loss is calculated as shown in Eq. (6):

$$L_{WIoU}(p, g) = \lambda_{loc} L_{loc}(p, g) + \lambda_{IoU} L_{IoU}(p, g) \tag{6}$$
where λloc and λIoU are two hyperparameters used to balance the contributions of the localization loss function and the IoU loss function.
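The combined loss of Eqs. (3)–(6) can be written out as a small pure-Python sketch. It follows the equations exactly as given in this section and makes the following assumptions:
- boxes are in (x1, y1, x2, y2) form;
- smooth L1 is applied per coordinate and summed over the positive (matched) pairs;
- the IoU term is summed over the same pairs.

The default values of `lam_loc`, `lam_iou`, and `iou_max` are placeholders, not values reported in the paper. The sketch is not the reference Wise-IoU implementation.

```python
import math

Box = tuple[float, float, float, float]  # (x1, y1, x2, y2)


def smooth_l1(x: float) -> float:
    """Eq. (4): quadratic near zero, linear for |x| >= 1."""
    return 0.5 * x * x if abs(x) < 1 else abs(x) - 0.5


def box_iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def loc_loss(preds: list[Box], gts: list[Box]) -> float:
    """Eq. (3): smooth L1 over the coordinates of every positive (matched) pair."""
    return sum(smooth_l1(p - g) for pb, gb in zip(preds, gts) for p, g in zip(pb, gb))


def iou_loss(pred: Box, gt: Box, iou_max: float = 1.0, eps: float = 1e-7) -> float:
    """Eq. (5): -log(IoU / IoU_max); eps avoids log(0) for disjoint boxes."""
    return -math.log(max(box_iou(pred, gt), eps) / iou_max)


def combined_loss(preds: list[Box], gts: list[Box], lam_loc: float = 1.0,
                  lam_iou: float = 1.0, iou_max: float = 1.0) -> float:
    """Eq. (6): weighted sum of the localization and IoU terms."""
    l_iou = sum(iou_loss(p, g, iou_max) for p, g in zip(preds, gts))
    return lam_loc * loc_loss(preds, gts) + lam_iou * l_iou


print(combined_loss([(0, 0, 2, 2)], [(0, 0, 2, 2)]))  # 0.0 for a perfect match
```

The `eps` clamp is a practical guard added in the sketch: Eq. (5) is undefined when the boxes do not overlap at all.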
Compared with YOLOv7's original loss function, Wise-IoU offers several advantages:
More accurate: Wise-IoU measures the overlap between predicted and ground-truth boxes more accurately, which improves detection accuracy.
More stable: YOLOv7's loss function can become unstable when a box contains no object, which causes problems in model training. Wise-IoU introduces a penalty factor that makes training more stable.
More flexible: the penalty coefficient of Wise-IoU can be adjusted for different application scenarios to balance localization and classification according to practical needs.

3.3 2D Human Pose Estimation
Abnormal driver postures must be detected quickly and accurately. The pose estimation network therefore needs both low computational complexity and high detection accuracy. In this paper, we improve the OpenPose network to detect the 2D coordinates of the driver's body keypoints during driving. The implementation is as follows.

Construction of OpenPose Network. OpenPose is a classical model for two-dimensional pose estimation. Because of its high detection accuracy and robustness, it is widely used in practical scenarios such as pedestrian detection [17] and fall detection [18]. First, a VGG19 network extracts features of the driver's actions to obtain a feature map F, which is fed into a two-branch, multi-stage convolutional network. The upper branch predicts part affinity fields, which encode the location and orientation of the connections between keypoints. The lower branch predicts part confidence maps for keypoint locations. Within each stage, the network uses large convolution kernels to capture semantic information between keypoints. The network structure is illustrated in Fig. 2.
Fig. 2. OpenPose network structure diagram.
Replacement Convolution Kernel. First, each 7 × 7 convolution kernel is replaced with three consecutive 3 × 3 convolution kernels. Then, each ordinary 3 × 3 convolution is replaced with a depthwise separable convolution, i.e., it is decomposed into a 3 × 3 depthwise convolution and a 1 × 1 pointwise convolution. The depthwise convolution convolves each input channel with its own kernel. The pointwise convolution linearly combines the channel feature maps output by the previous layer. Figure 3(a) shows the ordinary 3 × 3 convolution kernel structure, while Fig. 3(b) and Fig. 3(c) show the depthwise and pointwise convolution kernel structures, respectively. The improved convolution structure using depthwise separable convolution is shown in Fig. 4.
Fig. 3. Structure of the ordinary convolution kernel and the depthwise separable convolution kernel.

Fig. 4. Improving the ordinary convolution kernel using depthwise separable convolution.
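The computational cost of an ordinary convolutional layer is shown in Eq. (7):

$$\mathrm{conv} = D_K \times D_K \times M \times N \times D_F \times D_F \tag{7}$$

where D_K is the kernel size, M and N are the numbers of input and output channels, and D_F is the spatial width and height of the (square) feature map.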
The depthwise convolution uses M kernels of size D_K × D_K × 1. The number of kernels equals the number of channels of the input feature map, and each kernel convolves only a single channel. The pointwise convolution uses N kernels of size 1 × 1 × M, where M is the number of channels output by the previous layer. The total computational cost of the depthwise and pointwise convolutions is therefore given by Eq. (8), and Eq. (9) gives the ratio of this cost to that of the ordinary convolution. The replaced ordinary convolution kernels are 3 × 3, and the 1/N term is small for large N. The computation of the replaced convolutions in the improved two-branch multi-stage network is therefore reduced to nearly 1/9 of the original.

$$\mathrm{Depthwise} + \mathrm{Pointwise} = D_K \times D_K \times M \times D_F \times D_F + M \times N \times D_F \times D_F \tag{8}$$

$$\frac{\mathrm{Depthwise} + \mathrm{Pointwise}}{\mathrm{conv}} = \frac{1}{N} + \frac{1}{D_K^2} \tag{9}$$
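Equations (7)–(9) can be checked numerically with the short sketch below. It also confirms that three stacked 3 × 3 convolutions cover the same 7 × 7 receptive field as the kernel they replace. The layer sizes in the usage lines (128 input and output channels, a 46 × 46 feature map) are illustrative assumptions, not settings reported in the paper. Costs are counted as multiply-accumulate operations with stride 1.

```python
from fractions import Fraction


def standard_conv_cost(dk: int, m: int, n: int, df: int) -> int:
    """Eq. (7): multiply-accumulates of an ordinary DK x DK convolution."""
    return dk * dk * m * n * df * df


def separable_conv_cost(dk: int, m: int, n: int, df: int) -> int:
    """Eq. (8): depthwise (DK x DK per channel) plus pointwise (1 x 1) cost."""
    depthwise = dk * dk * m * df * df
    pointwise = m * n * df * df
    return depthwise + pointwise


def cost_ratio(dk: int, n: int) -> Fraction:
    """Eq. (9): separable cost divided by ordinary cost, 1/N + 1/DK^2."""
    return Fraction(1, n) + Fraction(1, dk * dk)


def stacked_receptive_field(num_layers: int, k: int = 3) -> int:
    """Receptive field of num_layers stacked k x k, stride-1 convolutions."""
    return num_layers * (k - 1) + 1


dk, m, n, df = 3, 128, 128, 46  # illustrative layer sizes, not from the paper
ratio = Fraction(separable_conv_cost(dk, m, n, df), standard_conv_cost(dk, m, n, df))
assert ratio == cost_ratio(dk, n)
print(float(ratio))                # about 0.119, i.e. close to 1/9
print(stacked_receptive_field(3))  # 7, same as one 7 x 7 kernel
```

Exact fractions make the check against Eq. (9) an equality rather than a floating-point approximation. For each input/output channel pair, the three stacked 3 × 3 kernels also hold 27 weights instead of the 49 of a single 7 × 7 kernel.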
Changing Connection of Convolution. Replacing the ordinary convolution kernels deepens the network, which can easily cause vanishing gradients and lower the final detection accuracy. To reduce computation while minimizing the impact on accuracy, this paper densely connects adjacent convolutional layers. With dense connectivity, each convolutional layer directly receives multi-scale feature information and gradient information from the preceding layers, and features are concatenated along the channel dimension for reuse. This improves the network's ability to learn features, eases the training difficulty caused by the greater depth, and improves the efficiency of the model.

3.4 Implement YOLOv7 and OpenPose Joint Anomalous Behavior Detection
The anomalous behaviors detected in this paper include human-object interaction anomalies and human posture anomalies. Some anomalies involve the driver using specific physical objects in violation of the rules while driving. For these, this paper uses the YOLOv7 object detector to identify the type of anomaly. Since the set of object detection labels is limited, the possible labels are grouped into the following categories: cell phone, cigarette, water cup, hand, and face. Abnormal behaviors are identified from the labels that YOLOv7 detects. If both a cell phone and a hand are detected and their bounding boxes overlap, the behavior is identified as phone use. Drinking and smoking are identified in the same way, from the presence of the corresponding labels and the overlap of their bounding boxes.
The overlap of the bounding boxes is measured by Intersection over Union (IoU). IoU is defined as the ratio of the intersection area of two bounding boxes to the area of their union, as shown in Eq. (10):

$$\mathrm{IoU} = \frac{S_{A \cap B}}{S_{A \cup B}} \tag{10}$$
where S_{A∩B} denotes the intersection area of the two bounding boxes and S_{A∪B} denotes the area of their union. If the IoU of the two bounding boxes exceeds a set threshold, the two boxes are considered to overlap. For abnormal driver postures, this paper uses the improved OpenPose method to estimate the driver's two-dimensional joint points. It then uses the FastDTW algorithm to calculate the similarity of the driver's two-dimensional joint point information, which determines whether the driver behaves abnormally while driving.
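The two decision rules of this subsection can be sketched in plain Python. The first function applies Eq. (10) to flag human-object interactions from YOLOv7-style detections. The second scores a sequence of 2D keypoint frames against a normal-driving reference using classic dynamic time warping. FastDTW approximates this quadratic-time recursion in roughly linear time, and a package such as `fastdtw` could be substituted. The label strings, the IoU threshold of 0.1, the behavior mapping, and the distance threshold are hypothetical illustration choices; the paper does not report these values here.

```python
import math

Box = tuple[float, float, float, float]  # (x1, y1, x2, y2)
Frame = list[tuple[float, float]]        # 2D keypoints of one video frame

# Hypothetical mapping from detected object label to abnormal behavior.
BEHAVIOR_BY_OBJECT = {"cell phone": "phone use", "water cup": "drinking",
                      "cigarette": "smoking"}


def iou(a: Box, b: Box) -> float:
    """Eq. (10): intersection area divided by union area."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def interaction_behaviors(detections: list[tuple[str, Box]],
                          threshold: float = 0.1) -> set[str]:
    """Flag a behavior when an object box overlaps a hand box with IoU > threshold."""
    hands = [box for label, box in detections if label == "hand"]
    found = set()
    for label, box in detections:
        behavior = BEHAVIOR_BY_OBJECT.get(label)
        if behavior and any(iou(box, hand) > threshold for hand in hands):
            found.add(behavior)
    return found


def frame_distance(f: Frame, g: Frame) -> float:
    """Mean Euclidean distance between corresponding keypoints."""
    return sum(math.dist(p, q) for p, q in zip(f, g)) / len(f)


def dtw_distance(seq: list[Frame], ref: list[Frame]) -> float:
    """Classic O(len(seq) * len(ref)) dynamic time warping distance."""
    n, m = len(seq), len(ref)
    cost = [[math.inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = frame_distance(seq[i - 1], ref[j - 1])
            cost[i][j] = d + min(cost[i - 1][j], cost[i][j - 1], cost[i - 1][j - 1])
    return cost[n][m]


def is_abnormal_posture(seq: list[Frame], ref: list[Frame], max_distance: float) -> bool:
    """Hypothetical rule: abnormal when the DTW distance exceeds max_distance."""
    return dtw_distance(seq, ref) > max_distance


dets = [("hand", (0, 0, 2, 2)), ("cell phone", (1, 1, 3, 3))]
print(interaction_behaviors(dets))  # {'phone use'}
```

Warping is useful here because a normal maneuver performed more slowly than the reference still aligns frame by frame and scores a distance near zero. Only a genuinely different posture accumulates a large distance.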
4 Experiments and Results Analysis

4.1 Constructing Dataset
To validate the proposed method, we built a driver abnormal driving behavior dataset from two sources: the dataset provided by the Kaggle competition and the DMS driver behavior dataset open-sourced by Magic Data. The shooting angles and categories of these datasets do not fully meet the requirements of this study, so we supplemented them with manually captured images and by other means. The final dataset contains 11038 images, all resized to 640 × 640 pixels. Of these, 7727 form the training set and 3311 the test set, roughly a 70/30 split.

4.2 Algorithm Feasibility Experiments
Ablation Experiment. To verify the effectiveness of the added attention mechanism and the replaced loss function, different modules are added to the YOLOv7 model under the same experimental conditions, and the impact of each module on detection performance is evaluated.

Table 1. Comparison results of ablation experiments

| ACmix | Wise-IoU | GFLOPs | mAP50 | FPS | Params |
|---|---|---|---|---|---|
| | | 105.2 | 69.41 | 121 | 71.3 |
| | √ | 105.2 | 70.24 | 118 | 71.3 |
| √ | | 105.5 | 71.47 | 87 | 72.6 |
| √ | √ | 105.6 | 72.30 | 83 | 72.7 |
The results are shown in Table 1. Replacing only the loss function or adding only the attention mechanism each improves detection accuracy. Adding both modules to YOLOv7 yields the largest improvement, raising mAP50 from 69.41 to 72.30, with only a small increase in the number of parameters. The model still runs in real time at 83 FPS.
Fig. 5. Detection results of the baseline YOLOv7 network model and the proposed method.
A Driver Abnormal Behavior Detection Method
Figure 5 compares the detection results of the proposed method and the original YOLOv7 network model on images of drivers engaging in different types of violating behavior during driving. The proposed method correctly identifies the cell phone in the first column, detects the water cup with higher confidence in the second column, and accurately detects smoking in the third column, whereas the original model misclassifies or misses these targets.
Comparison Experiments of Object Detection Algorithms. The improved YOLOv7 network model is compared with other network models under the same configuration environment and initial training parameters. The results in Table 2 show that the proposed model achieves the highest detection accuracy while maintaining real-time detection: its mAP@0.5 is 4.0, 28.7, 19.8, and 3.2 percentage points higher than that of SSD, Faster RCNN, YOLOv5, and YOLOv7, respectively.
Table 2. Comparison of experimental results of different network models

Network Model | Input size/pixel | mAP@0.5 | mAP@0.5:0.95
SSD [19] | 640 × 640 | 0.683 | 0.289
Faster RCNN [20] | 640 × 640 | 0.436 | 0.157
YOLOv5 | 640 × 640 | 0.525 | 0.231
YOLOv7 | 640 × 640 | 0.691 | 0.291
Proposed Method | 640 × 640 | 0.723 | 0.306
Driver Posture Visualization Experiment. To verify the feasibility of the improved OpenPose network for extracting the driver's driving posture, we designed experiments that visualize both normal and abnormal driving postures.
Fig. 6. Abnormal driving posture. Panels (a) and (e): normal driving posture; panels (b), (c), and (d): abnormal driving posture.
Figure 6 shows key point detection for drivers of different genders and body sizes in normal and abnormal driving postures. The improved OpenPose algorithm used in this study accurately detects the driver's upper-body key points, providing the data needed for the subsequent posture similarity calculation.
Comparison Experiments with the Original OpenPose Method. Comparative experiments were conducted to verify the effectiveness of the OpenPose improvements: computational complexity experiments were first performed on the original method, followed by detection accuracy experiments on the improved connections between convolution kernels. These experiments were conducted on the driver dataset created in this paper, with results shown in Table 3. AP@0.25, AP@0.5, and AP@0.75 denote the keypoint detection accuracy at similarity thresholds of 0.25, 0.5, and 0.75, respectively.
Table 3. Experimental comparison results of the method in this paper and the original method. AP columns: precision evaluation; GFLOPs: computational volume evaluation.

Method | AP | AP@0.25 | AP@0.5 | AP@0.75 | GFLOPs
OpenPose | 62.5 | 91.3 | 85.6 | 68.4 | 171.1
OpenPose+Improved Convolution kernel | 59.3 | 88.0 | 82.6 | 65.2 | 20.3
Proposed Method | 61.5 | 90.3 | 84.7 | 67.1 | 15.6
Table 3 compares computational complexity and detection accuracy. Improving only the convolution kernels of OpenPose cuts the computation from 171.1 to 20.3 GFLOPs, about 12% (roughly 1/8) of the original, but AP@0.25, AP@0.5, and AP@0.75 drop by 3.3, 3.0, and 3.2 percentage points, respectively (roughly 4% in relative terms). This optimization therefore reduces model computation substantially at the cost of some detection accuracy. Adding the dense connection method on top of the improved convolution kernels raises the AP values at all thresholds again (to 90.3, 84.7, and 67.1) while further reducing computation to 15.6 GFLOPs, because dense connections improve the model's ability to learn features and locate human key points accurately. Compared with the original OpenPose, the proposed method thus reduces computation by about 91% while losing only 0.9 to 1.3 percentage points of AP, which supports the feasibility and effectiveness of the method.
Driver Driving Posture Similarity Experiment. To verify the feasibility of the FastDTW algorithm for driver posture similarity analysis, this paper uses it to calculate the similarity between the 2D action sequences of the driving posture template video and the actual driving posture video. Only the upper body is analyzed, since the lower-body key points cannot be predicted accurately during driving. Figure 6(e) shows the normal driving posture. To identify abnormal driving behavior accurately, this paper estimates the similarity of the driver's head, left-hand, and right-hand postures separately and uses their average as the final evaluation criterion. Table 4 shows the similarity scores of some abnormal and normal driving postures.
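The paper uses FastDTW, which approximates dynamic time warping (DTW) by refining a warp path found at coarser resolutions. The sketch below instead computes exact DTW in O(nm) time, which is easier to follow and yields a cost no higher than the approximation. It assumes each frame of a sequence is a list of (x, y) keypoints for one body part (head, left hand, or right hand). The mapping from DTW cost to a similarity in (0, 1] is a hypothetical choice; the paper does not state how its scores are normalized.

```python
import math
from typing import List, Sequence, Tuple

Frame = Sequence[Tuple[float, float]]  # (x, y) keypoints of one body part in one frame


def frame_distance(a: Frame, b: Frame) -> float:
    """Mean Euclidean distance between corresponding keypoints of two frames."""
    return sum(math.dist(p, q) for p, q in zip(a, b)) / len(a)


def dtw_distance(seq_a: List[Frame], seq_b: List[Frame]) -> float:
    """Exact dynamic time warping cost between two keypoint sequences, O(n*m)."""
    n, m = len(seq_a), len(seq_b)
    inf = float("inf")
    cost = [[inf] * (m + 1) for _ in range(n + 1)]
    cost[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            d = frame_distance(seq_a[i - 1], seq_b[j - 1])
            # Extend the cheapest of the three admissible predecessor alignments.
            cost[i][j] = d + min(cost[i - 1][j], cost[i][j - 1], cost[i - 1][j - 1])
    return cost[n][m]


def similarity(template: List[Frame], actual: List[Frame]) -> float:
    """Hypothetical mapping of DTW cost to (0, 1]; higher means more similar."""
    normalized = dtw_distance(template, actual) / (len(template) + len(actual))
    return 1.0 / (1.0 + normalized)
```

Because DTW aligns frames non-linearly in time, a driver who performs the template motion more slowly (repeating a frame) still scores a cost of 0; in this setup the function would be called once per body part and the three scores averaged.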
Table 4. Partial driving posture similarity score table.

Driving posture | Head similarity value | Left hand similarity value | Right hand similarity value | Average similarity value | Whether it is abnormal behavior
[image] | 0.26 | 0.58 | 0.18 | 0.34 | yes
[image] | 0.48 | 0.85 | 0.91 | 0.74 | yes
[image] | 0.86 | 0.93 | 0.24 | 0.67 | yes
[image] | 0.94 | 0.89 | 0.91 | 0.91 | no
In this paper, the head, left-hand, and right-hand similarity values lie in the range 0–1, where higher values indicate higher similarity. We take the average of the three values as the decision score; over multiple rounds of experiments, a threshold of 0.8 produced relatively accurate identification of abnormal behavior. When the average similarity is below the threshold, the driver exhibits abnormal behavior during driving; when it is above the threshold, the driver shows no obvious abnormal behavior.
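The decision rule above reduces to averaging three part-level scores and comparing the mean with 0.8. The sketch applies it to the four rows of Table 4; the per-part values are taken from that table as reconstructed from the extracted cells, and the function name is illustrative.

```python
from statistics import mean

THRESHOLD = 0.8  # average-similarity threshold reported in the text


def is_abnormal(head: float, left_hand: float, right_hand: float,
                threshold: float = THRESHOLD) -> bool:
    """Flag abnormal driving when the mean part similarity falls below the threshold."""
    return mean([head, left_hand, right_hand]) < threshold


# Rows of Table 4 as (head, left hand, right hand) similarity values.
TABLE_4 = [
    (0.26, 0.58, 0.18),
    (0.48, 0.85, 0.91),
    (0.86, 0.93, 0.24),
    (0.94, 0.89, 0.91),
]
flags = [is_abnormal(*row) for row in TABLE_4]
print(flags)  # [True, True, True, False], matching the table's yes/yes/yes/no
```

Note that a single low part score (e.g. the right hand at 0.24 in the third row) can pull the mean below the threshold even when the other two parts match the template well.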
5 Conclusions and Future Work
This paper proposes a driver abnormal behavior detection method based on improved YOLOv7 and pose estimation, aimed at the safety hazards caused by common irregular driving behaviors and the low accuracy of small-target detection. An attention mechanism (ACmix) and the Wise-IoU loss function are added to enhance the perception capability of the YOLOv7 backbone network and improve detection accuracy. The improved OpenPose method is used for 2D keypoint detection of the driver's driving posture, and the FastDTW algorithm computes the similarity with the standard driving posture to detect abnormal driver postures. The paper also implements the improved driver abnormal behavior detection method and develops a detection system based on it. In future work, we will further study identification methods for abnormal driving posture, define various types of abnormal driving behavior, including specific categories of abnormal driver posture, and apply the method to passenger transportation centers, large fleets, bus groups, and similar operators to standardize the management and regulation of drivers' driving behavior [21].
Acknowledgments. This work was supported by the Funding Project of Beijing Social Science Foundation (No. 19YTC043).
References
1. Zhang, H., Zhuang, X., Zheng, J.: Optimization of YOLO network for human anomalous behavior detection. Comput. Eng. Appl. 59(7), 242–249 (2023)
2. Bao, G., Xi, X., Zhang, H.: A review of human abnormal behavior detection based on deep learning. Ind. Control Comput. 35(5), 102–103+106 (2022)
3. Tawari, A., Martin, S., Trivedi, M.: Continuous head movement estimator for driver assistance: issues, algorithms, and on-road evaluations. IEEE Trans. Intell. Transp. Syst. 15(2), 818–830 (2014)
4. Diaz-Chito, K., Hernández-Sabaté, A., López, A.: A reduced feature set for driver head pose estimation. Appl. Soft Comput. 45(3), 98–107 (2016)
5. Diaz-Chito, K., Del Rincon, M., Hernández-Sabaté, A.: Continuous head pose estimation using manifold subspace embedding and multivariate regression. IEEE Access 6(18), 325–334 (2018)
6. Borghi, G., Fabbri, M., Vezzani, R.: Face-from-depth for head pose estimation on depth images. IEEE Trans. Pattern Anal. Mach. Intell. 42(3), 596–609 (2018)
7. Hu, T., Jha, S., Busso, C.: Robust driver head pose estimation in naturalistic conditions from point-cloud data. In: 2020 IEEE Intelligent Vehicles Symposium (IV), pp. 1176–1182. IEEE Press (2020)
8. Liu, Y., Lasang, P., Pranata, S.: Driver pose estimation using recurrent lightweight network and virtual data augmented transfer learning. IEEE Trans. Intell. Transp. Syst. 20(10), 3818–3831 (2019)
9. Deng, X., Zhang, Y., Yang, S.: Joint hand detection and rotation estimation using CNN. IEEE Trans. Image Process. 27(4), 1888–1900 (2017)
10. Wang, Q., Zhang, G., Yu, S.: 2D hand detection using multi-feature skin model supervised cascaded CNN. J. Signal Process. Syst. 91(10), 1105–1113 (2019)
11. Xia, Y., Yan, S., Zhang, B.: Combination of ACF detector and multi-task CNN for hand detection. In: 2016 IEEE 13th International Conference on Signal Processing (ICSP), pp. 601–606. IEEE Press (2016)
12. Yan, S., Xia, Y., Smith, S.: Multiscale convolutional neural networks for hand detection. Appl. Comput. Intell. Soft Comput., 9830641–9830654 (2017)
13. Yuen, K., Trivedi, M.: Looking at hands in autonomous vehicles: a convnet approach using part affinity fields. IEEE Trans. Intell. Veh. 5(3), 361–371 (2019)
14. Chen, N., Man, Y., Ning, W.: A deep learning-based approach for monitoring abnormal pilot driving behavior. J. Saf. Environ. 22(1), 249–255 (2022)
15. Wang, C., Bochkovskiy, A., Liao, H.: YOLOv7: trainable bag-of-freebies sets new state-of-the-art for real-time object detectors. arXiv preprint arXiv:2207.02696 (2022)
16. Cao, Z., Simon, T., Wei, S.: Realtime multi-person 2D pose estimation using part affinity fields. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7291–7299 (2017)
17. Wang, L., Zhao, T., Wang, W.: Pedestrian detection and tracking algorithm with pose change robustness. Comput. Eng. Des. 43(10), 2877–2881 (2022)
18. Xiong, M., Li, J., Xiong, J.: Pedestrian fall detection method based on optical flow reconstruction and deep pose features. Sci. Technol. Eng. 22(35), 15688–15696 (2022)
19. Liu, W., et al.: SSD: single shot multibox detector. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9905, pp. 21–37. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46448-0_2
20. Girshick, R.: Fast R-CNN. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 1440–1448 (2015)
21. Dotse, J., Nicolson, R., Rowe, R.: Behavioral influences on driver crash risks in Ghana: a qualitative study of commercial passenger drivers. Traffic Inj. Prev. 20(2), 134–139 (2019)
Semi-supervised Semantic Segmentation Algorithm for Video Frame Corruption
Jingyan Ye1,2, Li Chen1,2(B), and Jun Li3
1 School of Computer Science and Technology, Wuhan University of Science and Technology, Wuhan 430065, Hubei, China
{yejy,chenli}@wust.edu.cn
2 Hubei Province Key Laboratory of Intelligent Information Processing and Real-Time Industrial System, Wuhan 430065, Hubei, China
3 Wuhan Dongzhi Technology Co., Ltd., Wuhan 430062, Hubei, China
Abstract. To address the lack of labeled data and the inaccurate segmentation in semantic segmentation of corrupted frames in surveillance video, a semi-supervised semantic segmentation method based on a weak-strong perturbation pseudo label filter and a horizontal continuity enhancement module is proposed. The weak-strong perturbation-based pseudo label filter performs selective re-training by prioritizing reliable unlabeled images according to holistic image-level stability. Concretely, weak and strong perturbations are applied to unlabeled images, and the discrepancy between their predictions serves as a measure of pseudo-label stability. In addition, the horizontal continuity enhancement module is designed to help the model learn clearer inter-class boundaries in corrupted frame data. To validate the proposed method, corrupted frames from a security surveillance system were collected to build a dataset for validation experiments. The experimental results show that our method outperforms other semi-supervised methods in terms of mean intersection over union (MIoU), demonstrating the effectiveness of the proposed method and of the horizontal continuity enhancement module for semantic segmentation of corrupted frames.
Keywords: Semi-Supervised Learning · Semantic Segmentation · Video Frame Corruption
1 Introduction
Video frame corruption refers to block distortion or faults in the picture during video playback, commonly appearing as vertical or horizontal bars or irregular shapes on the screen (Fig. 1). In security surveillance, video frame corruption can occur under adverse network conditions: high-definition surveillance video generates substantial data that can saturate network resources during real-time transmission and cause packet loss, and the resulting incomplete video information leads to frame corruption. Frame corruption seriously affects the reliability of video data and prevents surveillance footage from fulfilling its monitoring role. Moreover, AI-based image analysis is now widely used in surveillance, and frame corruption interferes with such intelligent analysis and reduces its accuracy. However, handling frame corruption manually across large amounts of surveillance video is inefficient and expensive.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 251–262, 2023. https://doi.org/10.1007/978-981-99-4761-4_22
J. Ye et al.
Fig. 1. Samples of video frame corruption.
To address corrupted frames in security surveillance systems, this paper applies deep learning-based visual semantic segmentation. Semantic segmentation [1, 2] classifies image pixels into predefined categories, assigning semantic categories to different regions of the image. By segmenting the corrupted regions in security surveillance video frames, the negative impact of frame corruption on security surveillance systems can be effectively mitigated. Applying semantic segmentation to frame corruption nevertheless faces two challenges. First, no dataset exists for semantic segmentation of frame corruption. Supervised segmentation methods have produced notable results, but the pixel-level labels they require are scarce: pixel-level labeling is laborious, and annotating large amounts of data is costly. To obtain high-performance segmentation while reducing annotation costs, researchers have explored various semi-supervised semantic segmentation strategies. Second, the target regions in corrupted images have characteristics that distinguish them from those in public datasets. Their texture is highly random, which makes the segmentation model prone to splitting a target region into fragments, an undesirable result for this task. In addition, the high texture similarity between target and background regions makes it difficult to segment an accurate inter-class boundary. These characteristics make the segmentation task more complex than on public datasets.
Semi-supervised Semantic Segmentation Algorithm
To address the above problems, this paper proposes a semantic segmentation method that combines stability estimation under weak and strong perturbations with a horizontal continuity enhancement module. Its main contributions are as follows:
• A reliability assessment strategy based on weak-strong perturbation performs selective re-training by prioritizing reliable images according to holistic prediction-level stability throughout the training process.
• A horizontal continuity enhancement module, designed for the specific characteristics of frame corruption, extracts horizontal continuity weights from the feature map, reweights the feature map with them, and fuses the result with the original feature map through a residual connection. This lets the model learn clearer inter-class boundaries and improves segmentation performance.
2 Related Work
Semi-supervised Learning. Semi-supervised learning has two main methodological branches: consistency regularization [3–5, 7–10] and entropy minimization [11–15]. Consistency regularization forces the model being optimized to produce stable, consistent predictions on the same unlabeled data under various perturbations [16], such as shape and color changes. Earlier work maintained a teacher whose parameters are the exponential moving average of the student's [10] to produce more reliable artificial labels for the student model. Entropy minimization exploits unlabeled data in an explicit bootstrapping fashion: unlabeled data are assigned pseudo-labels and trained jointly with manually labeled data. Among these methods, FixMatch [5] injects strong perturbations into unlabeled images and supervises training with the predictions on weakly perturbed versions, combining the merits of both branches. ST++ [17] designs a framework that leverages unlabeled images progressively.
Semi-supervised Semantic Segmentation. Earlier works [18, 19] in semi-supervised semantic segmentation employed Generative Adversarial Networks [20] as an auxiliary supervision signal for unlabeled data. However, GANs are difficult to optimize and can suffer from mode collapse [21]. Subsequently, numerous methods [22, 24–27] proposed simpler mechanisms for this task, such as enforcing similar predictions under multiple perturbed embeddings, two different contextual crops, or two differently initialized models. Building on FixMatch, PseudoSeg [13] adapted weak-to-strong consistency to segmentation and added a calibration module to refine pseudo masks. Along this line, French et al. [3] found that Cutout [28] and CutMix [29] are critical to the success of consistency regularization in segmentation. AEL [25] then designed an adaptive CutMix and sampling strategy to strengthen learning on underperforming classes. Inspired by contrastive learning, Lai et al. [26] enforced identical predictions for shared patches under different contextual crops. U2PL [30] treated uncertain pixels as reliable negative samples and contrasted them against corresponding positive samples. In the spirit of co-training and mutual learning [31, 32], CPS [23] introduced two independent models that supervise each other.
Stability Estimation. Previous methods estimate model stability using Bayesian analysis; because Bayesian inference is computationally expensive, other methods use Dropout [33] instead. In the semi-supervised setting, FixMatch simply sets a confidence threshold to filter uncertain samples, DMT [34] maintains two networks with different initializations to highlight inconsistent regions, and ST++ estimates image-level stability by measuring the overall prediction stability of image masks.
3 Method
3.1 Problem Definition
In semi-supervised semantic segmentation [15], the training set combines pixel-wise labeled images and unlabeled images. Let $D_l = \{(x_1^l, y_1), \ldots, (x_n^l, y_n)\}$ represent n labeled samples and $D_u = \{x_1^u, \ldots, x_m^u\}$ represent m unlabeled samples, where $x_i^u$ is the i-th unlabeled input sample and $x_i^l$ is the i-th labeled input sample, with spatial dimension H × W and corresponding pixel-level label $y_i$. In most works, the overall optimization target is formalized as:
$$L = L_s + \lambda L_u \qquad (1)$$
where λ balances the labeled and unlabeled data, $L_s$ is the supervised loss, and $L_u$ is the unsupervised loss. In our work, we follow the pseudo-label branch of semi-supervised semantic segmentation. The pseudo-labeling method consists of the following steps:
1) Pretrain: train the initial model on the available labeled data.
2) Generate pseudo labels: predict the unlabeled data with the trained model and generate pseudo labels from the predictions.
3) Retrain: tune the model on a new dataset consisting of the real labeled and pseudo-labeled data.
3.2 Weak-Strong Perturbation Based Pseudo Label Filter
A key issue in pseudo-labeling is the potential accumulation of incorrect knowledge: the rules that generate pseudo labels can produce inaccurate labels, which mislead the model and degrade its performance. Several techniques have been explored to tackle this problem. One estimates the reliability of images or pixels from the confidence distribution of the model's final prediction and filters low-confidence pixels with a defined threshold [5]. Another trains two models with different initializations to predict the same unlabeled samples and uses a stability loss to reduce the discrepancy between their predictions [34]. These approaches are commonly adopted to overcome the challenges of pseudo-labeling and improve model accuracy.
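To make Eq. (1) concrete, the sketch below computes $L = L_s + \lambda L_u$ with pixel-wise cross-entropy, using the pixel-level confidence-threshold filtering mentioned above [5] to build pseudo labels. This is the baseline idea the paper contrasts with, not its image-level filter. The 0.95 threshold, λ = 1, the ignore value 255, and all function names are illustrative assumptions.

```python
import numpy as np

IGNORE = 255  # label value excluded from the loss (assumed convention)


def softmax(logits: np.ndarray, axis: int = 0) -> np.ndarray:
    """Numerically stable softmax over the class axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def pixel_ce(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy over non-ignored pixels; logits (C, H, W), labels (H, W)."""
    valid = labels != IGNORE
    if not valid.any():
        return 0.0
    probs = softmax(logits, axis=0)
    rows, cols = np.nonzero(valid)
    picked = probs[labels[valid], rows, cols]  # probability of each pixel's target class
    return float(-np.log(picked + 1e-12).mean())


def pseudo_labels(teacher_logits: np.ndarray, conf_thresh: float = 0.95) -> np.ndarray:
    """Argmax pseudo labels; pixels below the hypothetical confidence threshold are ignored."""
    probs = softmax(teacher_logits, axis=0)
    labels = probs.argmax(axis=0)
    labels[probs.max(axis=0) < conf_thresh] = IGNORE
    return labels


def total_loss(lab_logits, lab_y, unl_logits, teacher_logits, lam: float = 1.0) -> float:
    """Eq. (1): L = L_s + lambda * L_u."""
    l_s = pixel_ce(lab_logits, lab_y)
    l_u = pixel_ce(unl_logits, pseudo_labels(teacher_logits))
    return l_s + lam * l_u
```

With uninformative (all-zero) teacher logits every pixel falls below the threshold, so $L_u$ contributes nothing; this per-pixel gating is exactly what the image-level filter in this section replaces.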
Fig. 2. Weak-strong perturbation stability estimation.
Furthermore, studies have suggested a positive correlation between a model's segmentation performance and the stability of its pseudo labels. It is therefore important, when generating pseudo labels, to select unlabeled images whose predictions are reliable. ST++ [17] saves checkpoints during model training and uses them to generate pseudo masks for unlabeled images; the mean Intersection over Union (MIoU) between each earlier pseudo mask and the final mask then serves as an indicator of the image's stability and reliability. However, the choice of checkpoints is critical and can significantly affect model performance, especially when the method is applied to new tasks. We aim to measure the reliability of unlabeled images more simply, without selecting checkpoints manually. In our work, a single model computes an image's reliability without requiring a manually chosen confidence threshold. To make the reliability assessment more stable, our method filters unreliable samples using image-level rather than the widely adopted pixel-level information (Fig. 2). Specifically, for an unlabeled image $x_i^u \in D_u$, different kinds of perturbations are applied to obtain an image $x_i^{uw}$ and an image set $D_i^{us}$, where $x_i^{uw}$ denotes the image obtained by weak perturbation and $D_i^{us}$ denotes the set of images obtained by strong perturbations. We collect the predictions $p_i^{uw}$ and $P_i^{us}$ of the pretrained model and use the mean squared error to measure the stability of $x_i^u$:
$$s_i = \sum_{p_{ij}^{us} \in P_i^{us}} \mathrm{MSE}\left(p_i^{uw}, p_{ij}^{us}\right) \qquad (2)$$
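A minimal sketch of Eq. (2) and of the ranking step that selects the R most stable images, assuming each prediction is a per-pixel class-probability map of shape (C, H, W) produced by the pretrained model; the perturbation functions themselves and the value of R are not shown, and the function names are illustrative.

```python
from typing import List, Sequence, Tuple

import numpy as np


def stability_score(p_weak: np.ndarray, p_strong: Sequence[np.ndarray]) -> float:
    """Eq. (2): sum of MSEs between the weak-view prediction and each strong-view prediction."""
    return float(sum(np.mean((p_weak - p) ** 2) for p in p_strong))


def split_by_stability(scores: Sequence[float], r: int) -> Tuple[List[int], List[int]]:
    """Return (indices of the R lowest-score, i.e. most stable, images, indices of the rest)."""
    order = np.argsort(np.asarray(scores), kind="stable")  # ascending: lowest score first
    return order[:r].tolist(), order[r:].tolist()
```

A lower score means the predictions barely change under strong perturbation, so the first list feeds the first retraining round and the second list is pseudo-labeled again afterwards.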
Here, $s_i$ is the stability score of image $x_i^u$; a lower score means that the predictions under weak and strong perturbations agree more closely. We compute stability scores for all unlabeled images, rank the whole set by these scores, and select the R images with the lowest scores for the first retraining phase. After this retraining, the remaining, less reliable images are pseudo-labeled again, and the model is retrained a second time on the full combination of manually labeled and pseudo-labeled data.
3.3 Horizontal Continuity Enhancement
Existing models can segment simple corrupted frame images accurately, but on complex corrupted frame images their results are often fragmented and their segmentation boundaries unclear. This paper therefore proposes a horizontal continuity enhancement module. In security video frame corruption data, most corrupted areas are distributed continuously in the horizontal direction: if frame corruption occurs at a certain height of the image, the pixels at that height are very likely all corrupted (see Fig. 3). To strengthen the model's ability to learn this feature, the Horizontal Continuity Enhancement module pools the model's feature maps in the horizontal direction to obtain a horizontal attention map, which is then fused with the feature map through a residual connection to yield a feature map that is more sensitive to horizontal continuity (Fig. 4). The calculation formula is as follows:
$$M_h(F) = \sigma\left(W \times \left(\mathrm{HAvgPool}(F, d) + \mathrm{HMaxPool}(F, d)\right)\right) \times F + F \qquad (3)$$
where σ denotes the sigmoid function, HAvgPool denotes average pooling along the horizontal direction with stride d, and HMaxPool denotes max pooling along the horizontal direction with stride d.
Fig. 3. Samples with horizontal continuity. (a), (b) show the image with horizontal continuity feature, (c), (d) show the label of frame corruption image in the same row.
Fig. 4. Semantic segmentation model with horizontal continuity enhancement module.
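The forward pass of Eq. (3) can be sketched in NumPy on a single (C, H, W) feature map. The paper does not fully specify W or how the pooled map is brought back to full width, so two assumptions are labeled in the code: W is treated as a C × C channel-mixing matrix (equivalent to a 1×1 convolution), and the pooled attention is upsampled by nearest-neighbour repetition. Pooling uses window = stride = d along the width axis; the function names are illustrative.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def horizontal_pool(feat: np.ndarray, d: int, reduce) -> np.ndarray:
    """Pool a (C, H, W) map along the width axis with window = stride = d."""
    c, h, w = feat.shape
    if w % d != 0:
        raise ValueError("this sketch assumes the width is divisible by d")
    return reduce(feat.reshape(c, h, w // d, d), axis=-1)  # (C, H, W // d)


def hce_forward(feat: np.ndarray, weight: np.ndarray, d: int) -> np.ndarray:
    """Eq. (3) sketch: sigmoid(W (HAvgPool(F, d) + HMaxPool(F, d))) * F + F."""
    pooled = horizontal_pool(feat, d, np.mean) + horizontal_pool(feat, d, np.max)
    # Assumption: W is a C x C channel-mixing matrix, i.e. a 1x1 convolution.
    attn = sigmoid(np.einsum("oc,chw->ohw", weight, pooled))
    # Assumption: nearest-neighbour upsampling restores the full width.
    attn = np.repeat(attn, d, axis=2)
    return attn * feat + feat  # element-wise weighting plus residual connection
```

When d equals the full width, each channel gets one attention value per image row, which matches the observation that corruption tends to span an entire row; smaller d keeps some horizontal locality.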
4 Experiment
4.1 Dataset
We collected a Frame Corruption dataset from the security monitoring system as the experimental dataset; it contains both corrupted frame images and normal images. The Frame Corruption dataset includes 800 labeled images and 3000 unlabeled training images. The labeled images are divided into a training set and a validation set of 400 images each (Fig. 5).
Fig. 5. Samples of the Frame Corruption dataset. The first row shows the images and the second row the corresponding labels.
4.2 Evaluation Metrics
Following previous papers [9, 17], we report the mean Intersection-over-Union (MIoU) on the validation set. All results are based on single-scale inference.
4.3 Implementation Details
The experiments are conducted on an NVIDIA RTX 3070 GPU with the PyTorch deep learning framework. During training, input images were resized to 512 × 512 pixels. Weak data augmentation (random cropping and resizing) and strong data augmentation (random color enhancement, random Gaussian blurring, and random grayscale) were used. The loss function was cross-entropy, and network parameters were optimized with the Stochastic Gradient Descent (SGD) optimizer with an initial learning rate of 0.0001, momentum of 0.9, and weight decay of 0.0001; a cosine learning rate schedule was used to accelerate convergence. These values were chosen empirically and used as defaults throughout the experiments. All methods used DeepLabV3+ with a ResNet-18 backbone as the semantic segmentation model, a setup commonly adopted in computer vision research for training and evaluating segmentation models.
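MIoU as reported here is usually computed from a confusion matrix accumulated over the whole validation set. The sketch assumes integer class maps (for this task, presumably two classes: background and corrupted region) and skips classes that never appear in either prediction or ground truth; the function names are illustrative.

```python
import numpy as np


def confusion_matrix(pred: np.ndarray, target: np.ndarray, num_classes: int) -> np.ndarray:
    """Rows = ground-truth class, columns = predicted class; out-of-range targets are skipped."""
    mask = (target >= 0) & (target < num_classes)
    idx = num_classes * target[mask].astype(int) + pred[mask].astype(int)
    counts = np.bincount(idx, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def mean_iou(conf: np.ndarray) -> float:
    """Average per-class IoU = TP / (TP + FP + FN) over classes that appear."""
    tp = np.diag(conf).astype(float)
    union = conf.sum(axis=0) + conf.sum(axis=1) - tp
    present = union > 0
    return float((tp[present] / union[present]).mean())
```

Summing per-image confusion matrices before calling `mean_iou` gives a dataset-level score, which differs from averaging per-image MIoU values when images contain different numbers of pixels per class.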
4.4 Comparison
The proposed method, the baseline, and other advanced semi-supervised semantic segmentation methods are compared on the Frame Corruption dataset; the results are shown in Table 1. Model performance is evaluated by mean intersection over union (MIoU). The table shows that our method matches or outperforms the baseline and the other semi-supervised methods at every label ratio, with the clearest gains when labeled data are scarce. With only 50 and 100 labeled images, our method exceeds the baseline by 1.7 and 1.2 percentage points, respectively; at 100 labeled images it exceeds the strongest competing semi-supervised method, ST++, by 0.4 points.
Table 1. Comparison of our method with other semi-supervised methods. The fraction (e.g., 1/16) and number (e.g., 25) denote the proportion and number of labeled images. The best results are marked in bold.

Method | 1/16 (25) | 1/8 (50) | 1/4 (100) | 1/2 (200) | 1/1 (400)
Baseline | 0.928 | 0.941 | 0.961 | 0.973 | 0.976
PS-MT [8] | 0.882 | 0.916 | 0.925 | 0.943 | 0.939
UniMatch [9] | 0.918 | 0.920 | 0.934 | 0.953 | 0.952
ST++ [17] | 0.917 | 0.938 | 0.969 | 0.974 | 0.978
MT [10] | 0.869 | 0.903 | 0.910 | 0.937 | 0.948
CPS [23] | 0.869 | 0.904 | 0.919 | 0.944 | 0.951
Ours | 0.933 | 0.958 | 0.973 | 0.977 | 0.978
Figure 6 shows qualitative results of several methods on the Frame Corruption dataset. The proposed method segments simple frame corruption more accurately and also performs better on complex scenes and on images that are difficult to segment.
Table 2. Ablation experiment result of pseudo label filter (PLF) and horizontal continuity enhancement module (HCE). The best results are marked in bold.
To verify the effectiveness of the horizontal continuity enhancement module, ablation experiments were conducted on the Frame Corruption dataset; the results are shown in Table 2. Adding the horizontal continuity enhancement module improves the model's MIoU on the Frame Corruption dataset in both fully supervised and semi-supervised learning, which indicates that the module enhances the model's semantic segmentation performance on corrupted frame images.
Fig. 6. Qualitative results on the Frame Corruption dataset. Columns: Image, Ground Truth, Ours, UniMatch.
5 Conclusion
In this work, we propose a semi-supervised semantic segmentation method based on weak-strong perturbation stability estimation to alleviate the performance degradation that incorrect pseudo labels can cause. For complex frame corruption samples, we also construct a horizontal continuity enhancement module that produces more accurate segmentation. Experiments on the Frame Corruption dataset under several label ratios show that our method matches or outperforms the compared methods, and ablation experiments examine the effectiveness of each component. In future work, we aim to design a more general semantic segmentation method and to study other ways of selecting stable pseudo labels.
Acknowledgments. This work was supported by National Natural Science Foundation of China (62271359).
References
1. Wang, R., Lei, T., Cui, R., Zhang, B., Meng, H., Nandi, A.K.: Medical image segmentation using deep learning: a survey. IET Image Process. 16(5), 1243–1267 (2022)
2. Peláez-Vegas, A., Mesejo, P., Luengo, J.: A Survey on Semi-Supervised Semantic Segmentation (2023)
3. French, G., Laine, S., Aila, T., Mackiewicz, M., Finlayson, G.: Semi-supervised semantic segmentation needs strong, varied perturbations (2020)
4. Kim, J., Jang, J., Park, H., Jeong, S.: Structured consistency loss for semi-supervised semantic segmentation (2021)
5. Sohn, K., et al.: FixMatch: Simplifying semi-supervised learning with consistency and confidence. Adv. Neural Inform. Process. Syst. 33, 596–608 (2020)
6. Xie, Q., Dai, Z., Hovy, E., Luong, M.T., Le, Q.V.: Unsupervised data augmentation for consistency training. Adv. Neural Inform. Process. Syst. 33, 6256–6268 (2020)
7. Chen, Z., Zhang, R., Zhang, G., Ma, Z., Lei, T.: Digging into pseudo label: a low-budget approach for semi-supervised semantic segmentation. IEEE Access 8, 41830–41837 (2020)
8. Liu, Y., Tian, Y., Chen, Y., Liu, F., Belagiannis, V., Carneiro, G.: Perturbed and strict mean teachers for semi-supervised semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4258–4267 (2022)
9. Yang, L., Qi, L., Feng, L., Zhang, W., Shi, Y.: Revisiting weak-to-strong consistency in semi-supervised semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 7236–7246 (2022)
10. Tarvainen, A., Valpola, H.: Mean teachers are better role models: weight-averaged consistency targets improve semi-supervised deep learning results. Adv. Neural Inform. Process. Syst. (2018)
11. Li, H., Zheng, H.: A residual correction approach for semi-supervised semantic segmentation. In: Ma, H., et al. (eds.) PRCV 2021. LNCS, vol. 13022, pp. 90–102. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-88013-2_8
12. Pham, H., Dai, Z., Xie, Q., Le, Q.V.: Meta pseudo labels. In: IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2021, virtual, June 19–25, 2021, pp. 11557–11568. Computer Vision Foundation/IEEE (2021)
13. Zou, Y., et al.: PseudoSeg: designing pseudo labels for semantic segmentation (2021)
14. Xie, Q., Luong, M.T., Hovy, E., Le, Q.V.: Self-training with noisy student improves ImageNet classification. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10687–10698 (2020)
15. Liu, S., Zhi, S., Johns, E., Davison, A.J.: Bootstrapping semantic segmentation with regional contrast (2022)
16. Sajjadi, M., Javanmardi, M., Tasdizen, T.: Regularization with stochastic transformations and perturbations for deep semi-supervised learning. Adv. Neural Inform. Process. Syst. 29 (2016)
17. Yang, L., Zhuo, W., Qi, L., Shi, Y., Gao, Y.: ST++: make self-training work better for semi-supervised semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4268–4277 (2022)
18. Hung, W.C., Tsai, Y.H., Liou, Y.T., Lin, Y.Y., Yang, M.H.: Adversarial learning for semi-supervised semantic segmentation (2018)
19. Souly, N., Spampinato, C., Shah, M.: Semi supervised semantic segmentation using generative adversarial network. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 5688–5696 (2017)
20. Creswell, A., White, T., Dumoulin, V., Arulkumaran, K., Sengupta, B., Bharath, A.A.: Generative adversarial networks: an overview. IEEE Sign. Process. Mag. 35(1), 53–65 (2018)
21. Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., Chen, X.: Improved techniques for training GANs. Adv. Neural Inform. Process. Syst. 29 (2016)
22. Alonso, I., Sabater, A., Ferstl, D., Montesano, L., Murillo, A.C.: Semi-supervised semantic segmentation with pixel-level contrastive learning from a class-wise memory bank. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 8219–8228 (2021)
262
J. Ye et al.
23. Chen, X., Yuan, Y., Zeng, G., Wang, J.: Semi-supervised semantic segmentation with cross pseudo supervision. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2613–2622 (2021) 24. He, R., Yang, J., Qi, X.: Re-distributing biased pseudo labels for semi-supervised semantic segmentation: a baseline investigation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 6930–6940 (2021) 25. Hu, H., Wei, F., Hu, H., Ye, Q., Cui, J., Wang, L.: Semi-supervised semantic segmentation via adaptive equalization learning. Adv. Neural. Inf. Process. Syst. 34, 22106–22118 (2021) 26. Lai, X., et al.: Semi-supervised semantic segmentation with directional context-aware consistency. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1205–1214 (2021) 27. Ouali, Y., Hudelot, C., Tami, M.: Semi-supervised semantic segmentation with crossconsistency training. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12674–12684 (2020) 28. DeVries, T., Taylor, G.W.: Improved regularization of convolutional neural networks with cutout. arXiv preprint arXiv:1708.04552 (2017) 29. Yun, S., Han, D., Oh, S.J., Chun, S., Choe, J., Yoo, Y.: Cutmix: regularization strategy to train strong classifiers with localizable features. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 6023–6032 (2019) 30. Wang, Y., et al.: Semi-supervised semantic segmentation using unreliable pseudo-labels. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4248–4257 (2022) 31. Qiao, S., Shen, W., Zhang, Z., Wang, B., Yuille, A.: Deep co-training for semi-supervised image recognition. In: Proceedings of the European conference on computer vision (ECCV), pp. 135–152 (2018) 32. Zhang, Y., Xiang, T., Hospedales, T.M., Lu, H.: Deep mutual learning. 
In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4320–4328 (2018) 33. Liang, X., et al.: R-Drop: regularized dropout for neural networks. Adv. Neural Inform. Process. Syst. 34, 10890–10905 (2021) 34. Feng, Z., et al.: DMT: dynamic mutual training for semi-supervised learning. Pattern Recogn. 130, 108777 (2022)
DSC-OpenPose: A Fall Detection Algorithm Based on Posture Estimation Model
Lei Shi¹,²,⁴, Hongqiu Xue¹, Caixia Meng²,³(B), Yufei Gao¹,⁴, and Lin Wei¹
1 School of Cyber Science and Engineering, Zhengzhou University, Zhengzhou 450002, China
2 School of Computer and Artificial Intelligence, Zhengzhou University, Zhengzhou 450001, China
3 Department of Image and Network Investigation Technology, Railway Police College, Zhengzhou 450053, China
4 SongShan Laboratory, Zhengzhou 450046, China
Abstract. Falls in crowded places can lead to public safety incidents, and real-time monitoring and early warning of falls can reduce these risks. To address the large model size and poor real-time performance of existing fall detection methods based on pose estimation, this paper proposes DSC-OpenPose, an OpenPose-based human fall detection algorithm that incorporates an attention mechanism. Following the dense connection idea of DenseNet, each layer is directly connected to all previous layers in the channel dimension, which achieves feature reuse and reduces the number of model parameters. To capture the spatial direction dependencies and precise location information of the feature map and to improve pose estimation accuracy, coordinate attention is introduced between stages. Fall behavior is then identified from the parameters of an ellipse fitted around the human body, together with head height and lower-limb height. Experiments show that the algorithm achieves a good balance between model size and accuracy on the COCO dataset. The fall detection approach achieves 98% accuracy and 96.5% precision on the RF dataset at a detection speed of 20.1 frames/s, and the model is small enough to meet the real-time inference requirements of embedded devices.
Keywords: Human Pose Estimation · Dense Connection · Attention Mechanism · Fall Detection
1 Introduction
According to the World Health Organization, falls are the second leading cause of unintentional injury deaths. In a crowded public area, a fall can easily trigger a stampede and create a serious safety hazard: both the 2014 Shanghai Bund stampede and the 2022 Itaewon stampede in Seoul, Korea were caused by falls. Because falls are abnormal events that threaten life, fall detection and early warning have become a focus of current research.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 263–276, 2023. https://doi.org/10.1007/978-981-99-4761-4_23
L. Shi et al.
There are three main types of fall detection methods: those based on wearable devices, those based on scene sensors, and those based on computer vision. Wearable device-based methods mainly use gyroscopes and accelerometers to detect falls, but their accuracy can suffer because the device must be worn and must maintain a stable connection. Scene sensor-based methods recognize falls by processing and analyzing data from depth, infrared, acoustic and vibration sensors, but these sensors are prone to interference and cannot be used outdoors. Computer vision-based fall detection, which mainly relies on pose estimation and human skeleton sequence extraction, is a better alternative. However, its real-time performance and accuracy in crowded public areas still need to be improved: pose estimation models are large, and the extracted keypoints and skeleton information are not accurate enough, which leads to fall detection errors. Our contributions are as follows:
1) The DenseNet dense connection idea is incorporated into the OpenPose detection approach so that parameters and features are shared within each dense block, considerably reducing the number of model parameters;
2) Coordinate attention mechanisms are inserted between DenseBlocks to improve the accuracy of heatmaps and keypoint identification;
3) A CNN-based fall detection method is developed that uses the improved pose estimation network together with an ellipse fitted around the human posture, and experiments indicate that it outperforms existing approaches.
2 Related Work
Most researchers have used deep learning-based pose estimation methods for fall detection. Kang et al. [1] introduced a fall detection system that uses PoseNet for pose estimation and feeds the extracted pose information into a GRU (Gated Recurrent Unit) to detect falls, but its real-time performance needs to be improved. Chen et al. [2] and Nogas et al. [3] proposed fall detection methods based on OpenPose: Chen et al. [2] combined OpenPose with three indicators, the “fall velocity of the hip center”, the “body centerline angle” and the “aspect ratio of the bounding box”, while Nogas et al. [3] detected falls with an LSTM network that incorporates OpenPose and correlates it with spatio-temporal information. However, the keypoints and skeleton information extracted by the pose estimation in these methods are not accurate enough, which leads to fall detection errors.
Pose estimation can be divided into single-person and multi-person pose estimation. Single-person pose estimation mainly uses regression-based methods and body-part detection methods, which offer high detection accuracy but are not suitable for fall detection in complex scenes. Multi-person pose estimation follows two basic strategies: top-down and bottom-up. Top-down approaches are mainly two-stage detection algorithms, such as RSN [4] and ViTPose [5]. In the first stage, a target detection method identifies each human body and marks a rectangular box around it; in the second stage, the keypoints within the marked area are detected. To overcome the high computational cost, Pishchulin et al. [6] proposed DeepCut, a detector based on the R-CNN framework that, like OpenPose, first finds and labels all candidate keypoint locations in the image and then combines them into final human poses using integer linear programming. Cheng et al. [9] proposed HigherHRNet, a model based on HRNet with higher-resolution detection performance; however, the model is complex and has a large number of parameters. Bottom-up approaches such as OpenPose achieve good detection accuracy at high detection speed and are suited to multi-person pose estimation in complex situations, but they still suffer from large models, redundant parameters and long detection times.
3 DSC-OpenPose
This paper proposes an improved OpenPose human pose estimation model that builds on the OpenPose [10] network structure and incorporates the dense connection idea of DenseNet. The aim is to address the problems caused by the large number of layers in the traditional OpenPose structure, including vanishing gradients, inefficient feature propagation and parameter redundancy. First, all convolutional layers within each stage are directly connected to form a DenseBlock, which improves parameter sharing, reduces parameter redundancy and model size, and maximizes information flow between network layers. Each layer takes the outputs of all preceding layers as input and passes its own output features to all succeeding layers in a feed-forward manner. A transition layer is inserted between consecutive DenseBlocks to reduce the feature map size and the number of channels, making the model more compact. Adding the Coordinate Attention mechanism between DenseBlocks addresses the low detection accuracy of the classic OpenPose network in complex environments without adding an excessive number of parameters. Finally, the model uses five successive refinement stages to generate the Confidence Map (Fig. 1).
Fig. 1. Improved posture estimation network DSC-OpenPose
3.1 DenseBlock
Assume that the input of the convolutional network is $X_0$ and that the network contains $n$ layers, where the $n$-th layer implements a nonlinear transformation $H_n(\cdot)$. These transformations can be composite functions of operations such as ReLU, convolution and pooling. To mitigate vanishing gradients, the classic OpenPose model adds a skip connection between the $(n-1)$-th and $n$-th layers, so that features are propagated as:

$$X_n = H_n(X_{n-1}) + X_{n-1} \tag{1}$$
In the proposed DenseBlock, the input of each layer is formed by concatenating the outputs of all preceding layers, and the feature maps learned by each layer are passed directly to all subsequent layers, so any two convolutional blocks inside the DenseBlock are connected. Feature propagation then becomes:

$$X_n = H_n([X_0, X_1, \ldots, X_{n-1}]) \tag{2}$$
where $[X_0, X_1, \ldots, X_{n-1}]$ denotes the concatenation of the feature maps produced by layers $0$ to $n-1$. Within a DenseBlock [11], the feature maps of every layer have the same size (identical $C$, $H$ and $W$), so they can be concatenated along the channel dimension. The internal structure of DenseBlock is shown in Fig. 2.
Fig. 2. DenseBlock internal structure
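The concatenation in Eq. (2) can be illustrated with a minimal, dependency-free sketch. Here each feature map is modeled as a list of channels, every channel is collapsed to a single number standing in for an $H \times W$ map, and `make_layer` is a hypothetical stand-in for $H_n$ (BN + ReLU + Conv) that simply emits `growth_rate` new channels. None of these names come from the paper; the sketch only shows how the input width of each layer grows when all earlier outputs are concatenated.

```python
from typing import Callable, List, Tuple

FeatureMap = List[float]  # one number per channel (stand-in for an H x W map)
Layer = Callable[[FeatureMap], FeatureMap]


def make_layer(growth_rate: int) -> Layer:
    """Hypothetical stand-in for H_n: emits `growth_rate` new channels."""
    def h(inputs: FeatureMap) -> FeatureMap:
        mean = sum(inputs) / len(inputs)
        return [mean] * growth_rate
    return h


def dense_block(x0: FeatureMap, layers: List[Layer]) -> Tuple[FeatureMap, List[int]]:
    """Eq. (2): layer n receives the channel-wise concatenation [X_0, ..., X_{n-1}]."""
    features = [x0]
    input_widths = []  # number of input channels seen by each layer
    for h in layers:
        concat = [ch for fmap in features for ch in fmap]  # channel concatenation
        input_widths.append(len(concat))
        features.append(h(concat))
    # The block output concatenates the block input and every layer's output.
    block_output = [ch for fmap in features for ch in fmap]
    return block_output, input_widths


out, widths = dense_block([1.0, 2.0], [make_layer(4) for _ in range(3)])
print(widths, len(out))  # [2, 6, 10] 14
```

Unlike the residual update of Eq. (1), which adds the new features onto the old ones, the input of each layer here widens by `growth_rate` channels, and the original input channels are reused unchanged at every layer.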
3.2 Transition Layer
Because every layer in a DenseBlock concatenates its output onto the features it receives, the number of channels grows with depth, and the final output of a DenseBlock has many channels. The Transition layer is therefore used for dimensionality reduction: it sits between two adjacent DenseBlocks and reduces the channel dimension with a BN + ReLU + 1 × 1 Conv unit. If a DenseBlock outputs $m$ feature maps, the Transition layer outputs $\lfloor \theta m \rfloor$ feature maps, where $\theta$ ($0 < \theta \le 1$) is the compression factor.
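The interaction between channel growth in the DenseBlocks and compression in the Transition layers can be traced with a short bookkeeping function. The paper does not state the growth rate, depth or $\theta$ used in DSC-OpenPose, so the example below plugs in the published DenseNet-121 configuration (64 initial channels, growth rate 32, blocks of 6, 12, 24 and 16 layers, $\theta = 0.5$) purely for illustration.

```python
import math
from typing import List, Sequence


def densenet_channels(init_channels: int, growth_rate: int,
                      block_layers: Sequence[int], theta: float) -> List[int]:
    """Channel count after each DenseBlock and after each Transition layer between blocks."""
    channels = init_channels
    trace = []
    for i, n_layers in enumerate(block_layers):
        channels += n_layers * growth_rate  # every layer appends growth_rate channels
        trace.append(channels)
        if i < len(block_layers) - 1:  # Transition layers sit only between consecutive blocks
            channels = math.floor(theta * channels)  # m -> floor(theta * m)
            trace.append(channels)
    return trace


print(densenet_channels(64, 32, (6, 12, 24, 16), 0.5))
# [256, 128, 512, 256, 1024, 512, 1024]
```

Without the Transition layers ($\theta = 1$) the channel count would simply keep accumulating across blocks, which is the growth the compression step is meant to contain.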
3.3 Coordinate Attention
In pose estimation, keypoint position information is crucial to the detection outcome. Coordinate attention embeds position information into the channel attention mechanism, so the model can focus on location information while adding only a small number of parameters and can therefore locate and identify the target more effectively. The structure of the coordinate attention module [15] used in this work is shown in Fig. 3.
Fig. 3. Coordinate Attention Module
The coordinate attention module [15] consists of a coordinate information embedding module and a coordinate attention generation module. The coordinate information embedding module (Fig. 3(1)) factorizes global pooling into two one-dimensional pooling operations along the vertical and horizontal directions, using average-pooling kernels of size (H, 1) and (1, W). For an input feature map of size $C \times H \times W$, each channel $c$ is encoded independently by the two branches as:

$$z^h_c(h) = \frac{1}{W}\sum_{0 \le i < W} x_c(h, i), \qquad z^w_c(w) = \frac{1}{H}\sum_{0 \le j < H} x_c(j, w)$$

The coordinate attention generation module (Fig. 3(2)) concatenates the two feature maps $z^h$ and $z^w$ and transforms them with a shared $1 \times 1$ convolution $F_1$:

$$f = \delta\left(F_1\left([z^h, z^w]\right)\right), \quad f \in \mathbb{R}^{C/r \times (H+W)} \tag{3}$$

where $\delta$ is a nonlinear activation function and $r$ is the channel reduction ratio.
The intermediate feature map $f$ encodes spatial information in both the horizontal and vertical directions and is split along the spatial dimension into two tensors, $f^h \in \mathbb{R}^{C/r \times H}$ and $f^w \in \mathbb{R}^{C/r \times W}$. Two $1 \times 1$ convolutions, $F^h$ and $F^w$, transform their channel dimensions back to that of the input, and the sigmoid function $\sigma$ then produces the attention weights:

$$g^h = \sigma(F^h(f^h)), \qquad g^w = \sigma(F^w(f^w)) \tag{4}$$

Finally, these weights are multiplied with the input feature map to perform the attention weighting.
The attention weighting not only captures the importance of each channel but also, by aggregating features along the two spatial directions, collects position information that helps pinpoint the regions of interest. For each channel $c$, the output of the coordinate attention module, which is passed to the next DenseBlock, is:

$$y_c(i, j) = x_c(i, j) \times g^h_c(i) \times g^w_c(j) \tag{5}$$
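The data flow of Eqs. (3)–(5) can be traced with a small stdlib-only sketch operating on nested lists. It is an illustration under stated assumptions, not the authors' implementation: the learned convolutions $F_1$, $F^h$ and $F^w$ are passed in as plain weight matrices (a $1 \times 1$ convolution is a matrix product over channels), biases and batch normalization are omitted, and ReLU is assumed for $\delta$ because the text does not name it.

```python
import math
from typing import List

Matrix = List[List[float]]


def relu(v: float) -> float:
    return max(0.0, v)


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def conv1x1(weights: Matrix, inp: Matrix) -> Matrix:
    """1x1 convolution on a (channels x positions) array: out[k][p] = sum_c w[k][c] * inp[c][p]."""
    n_pos = len(inp[0])
    return [[sum(w_kc * inp[c][p] for c, w_kc in enumerate(w_k)) for p in range(n_pos)]
            for w_k in weights]


def coordinate_attention(x, w1: Matrix, wh: Matrix, ww: Matrix):
    """x: C x H x W nested lists; w1: (C/r) x C; wh and ww: C x (C/r)."""
    C, H, W = len(x), len(x[0]), len(x[0][0])
    # Coordinate information embedding: directional average pooling.
    z_h = [[sum(x[c][i]) / W for i in range(H)] for c in range(C)]                       # C x H
    z_w = [[sum(x[c][j][w] for j in range(H)) / H for w in range(W)] for c in range(C)]  # C x W
    # Eq. (3): concatenate along the spatial axis, shared 1x1 conv F1, then delta (ReLU here).
    cat = [z_h[c] + z_w[c] for c in range(C)]                                            # C x (H+W)
    f = [[relu(v) for v in row] for row in conv1x1(w1, cat)]                             # (C/r) x (H+W)
    f_h, f_w = [row[:H] for row in f], [row[H:] for row in f]
    # Eq. (4): restore C channels with F^h and F^w, then apply the sigmoid.
    g_h = [[sigmoid(v) for v in row] for row in conv1x1(wh, f_h)]                        # C x H
    g_w = [[sigmoid(v) for v in row] for row in conv1x1(ww, f_w)]                        # C x W
    # Eq. (5): y_c(i, j) = x_c(i, j) * g^h_c(i) * g^w_c(j).
    return [[[x[c][i][j] * g_h[c][i] * g_w[c][j] for j in range(W)] for i in range(H)]
            for c in range(C)]


x = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],   # channel 0, H = 2, W = 3
     [[0.0, -1.0, 2.0], [3.0, 1.0, 0.5]]]  # channel 1
# C = 2, r = 2 -> C/r = 1. With all-zero weights every gate is sigmoid(0) = 0.5, so y = 0.25 * x.
y = coordinate_attention(x, [[0.0, 0.0]], [[0.0], [0.0]], [[0.0], [0.0]])
```

With non-zero weights the row gate $g^h_c(i)$ and the column gate $g^w_c(j)$ differ by position, which is how the module emphasizes particular rows and columns where keypoints lie.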
4 Fall Detection Algorithm Based on Pose Estimation
4.1 Fall Indicators
4.1.1 Human Posture External Ellipse
Studies that delineate the human body with an external rectangle suffer from the mismatch between a rectangle and the shape of the human body, which makes it difficult to characterize body shape accurately in the fall detection stage. Therefore, the posture of the human body is represented here by an external ellipse, whose shape is closer to that of the human body. The ellipse is described by four parameters: the center $O(x, y)$; the angle $\theta$ between the major axis and the horizontal direction; the length of the semi-major axis, $e_l$; and the length of the semi-minor axis, $e_s$. Based on the structure of the human body, the keypoints of the left and right eyes, left and right shoulders, left and right elbows, and left and right ankles are selected, and the ellipse is fitted to them by the least squares method. The ellipse equation is:

$$ax^2 + bxy + cy^2 + dx + ey = 1 \tag{6}$$
Let $\alpha = [a, b, c, d, e]^T$ and $X = [x^2, xy, y^2, x, y]^T$, so that the equation can be written as $\alpha^T X = 1$. The ellipse fitting problem can then be stated as:

$$\min_{\alpha} \lVert D\alpha \rVert^2 \quad \text{s.t.} \quad \alpha^T C \alpha = 1 \tag{7}$$
where $D$ is the $n \times 6$ matrix of data samples ($6$ is the dimension of each sample and $n$ the number of samples), $C$ is the constraint matrix, and $\alpha$ contains the parameters of the ellipse equation. Introducing the Lagrange multiplier $\lambda$ yields the following two equations:

$$2D^T D\alpha - 2\lambda C\alpha = 0, \qquad \alpha^T C\alpha = 1 \tag{8}$$
Letting $S = D^T D$, these equations can be rewritten as:

$$S\alpha = \lambda C\alpha, \qquad \alpha^T C\alpha = 1 \tag{9}$$
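Solving $S\alpha = \lambda C\alpha$ gives eigenvalue and eigenvector pairs $(\lambda_i, u_i)$. Because any scaled vector $\mu_i u_i$ is also an eigenvector, the constraint $\alpha^T C\alpha = 1$ fixes the scale through $\mu_i^2 u_i^T C u_i = 1$, i.e.:

$$\mu_i = \sqrt{\frac{1}{u_i^T C u_i}} = \sqrt{\frac{\lambda_i}{u_i^T S u_i}} \tag{10}$$

where the second form follows from $u_i^T S u_i = \lambda_i u_i^T C u_i$. Setting $\alpha_i = \mu_i u_i$, the solution associated with the eigenvalue $\lambda_i > 0$ is taken as the fitted ellipse. The four ellipse parameters $O(x, y)$, $\theta$, $e_l$ and $e_s$ are then determined from the fitted coefficients using the geometry of the ellipse.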
Figure 4 depicts the fitted ellipse for the standing and falling scenarios.
Fig. 4. Fitted ellipse and its parameters
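The step from fitted coefficients to $O(x, y)$, $\theta$, $e_l$ and $e_s$ can be sketched in plain Python. Note that this sketch swaps in a simpler technique: it fits Eq. (6) by ordinary linear least squares (minimizing $\sum_k (\alpha^T X_k - 1)^2$ through the normal equations) instead of the constrained eigenvalue formulation of Eqs. (7)–(10). It assumes at least five keypoints, that the origin does not lie on the ellipse, and that the fitted conic is an ellipse, which the unconstrained fit does not guarantee for noisy keypoints. All function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def solve_linear(m: List[List[float]], v: List[float]) -> List[float]:
    """Solve m x = v by Gaussian elimination with partial pivoting."""
    n = len(v)
    a = [row[:] + [v[i]] for i, row in enumerate(m)]  # augmented matrix
    for col in range(n):
        piv = max(range(col, n), key=lambda r: abs(a[r][col]))
        a[col], a[piv] = a[piv], a[col]
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            for k in range(col, n + 1):
                a[r][k] -= factor * a[col][k]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (a[r][n] - sum(a[r][k] * x[k] for k in range(r + 1, n))) / a[r][r]
    return x


def fit_conic(points: Sequence[Tuple[float, float]]) -> List[float]:
    """Ordinary least squares for a*x^2 + b*xy + c*y^2 + d*x + e*y = 1 (Eq. 6)."""
    rows = [[x * x, x * y, y * y, x, y] for x, y in points]  # one X^T per keypoint
    dtd = [[sum(r[i] * r[j] for r in rows) for j in range(5)] for i in range(5)]
    dt1 = [sum(r[i] for r in rows) for i in range(5)]
    return solve_linear(dtd, dt1)  # normal equations (D^T D) alpha = D^T 1


def ellipse_parameters(alpha):
    """Return centre O(x, y), major-axis angle theta in [0, pi), and semi-axes e_l >= e_s."""
    A, B, C, D, E = alpha
    F = -1.0  # Eq. (6) rewritten as A x^2 + B xy + C y^2 + D x + E y + F = 0
    if A + C < 0:  # flip signs so that the quadratic part is positive definite
        A, B, C, D, E, F = -A, -B, -C, -D, -E, -F
    det = 4 * A * C - B * B  # positive for an ellipse
    x0 = (B * E - 2 * C * D) / det
    y0 = (B * D - 2 * A * E) / det
    f0 = F + (D * x0 + E * y0) / 2  # conic value at the centre (negative)
    r = math.hypot((A - C) / 2, B / 2)
    lam_small, lam_large = (A + C) / 2 - r, (A + C) / 2 + r  # eigenvalues of [[A, B/2], [B/2, C]]
    e_l = math.sqrt(-f0 / lam_small)
    e_s = math.sqrt(-f0 / lam_large)
    # 0.5 * atan2(B, A - C) points along the larger eigenvalue (minor axis); add 90 degrees.
    theta = (0.5 * math.atan2(B, A - C) + math.pi / 2) % math.pi
    return (x0, y0), theta, e_l, e_s


def sample_ellipse(cx, cy, a, b, phi, n=12):
    """Points on an ellipse with centre (cx, cy), semi-axes a >= b, rotated by phi (test data)."""
    pts = []
    for k in range(n):
        t = 2 * math.pi * k / n
        pts.append((cx + a * math.cos(t) * math.cos(phi) - b * math.sin(t) * math.sin(phi),
                    cy + a * math.cos(t) * math.sin(phi) + b * math.sin(t) * math.cos(phi)))
    return pts


centre, theta, e_l, e_s = ellipse_parameters(
    fit_conic(sample_ellipse(2.0, 1.0, 3.0, 1.0, math.pi / 6)))
```

The angle is measured in the same coordinate frame as the input points; with image coordinates, where $y$ grows downwards, its sign is mirrored relative to the usual mathematical convention.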
4.1.2 Head Coordinates and Lower Limb Height
As the human body goes from standing to falling, the head coordinates $H(x_{head}, y_{head})$ and the lower-limb height $h_{limbs}$ change; both are defined in Fig. 5.
Fig. 5. Definition of head coordinates and lower limb heights
In the definition of $h_{limbs}$ in Fig. 5, $y_{rknee}$, $y_{rankle}$, $y_{lknee}$ and $y_{lankle}$ are the vertical coordinates of the right knee, right ankle, left knee and left ankle, respectively.
4.2 Fall Classifier
Based on the above fall indicators, eight features are extracted for each moving object detected in a video sequence and combined into a feature vector that is used to identify the object's fall behavior:

$$F = [x, y, \theta, e_l, e_s, x_{head}, y_{head}, h_{limbs}] \tag{11}$$
To distinguish fall from non-fall behavior, the feature vectors are fed into a CNN (Convolutional Neural Network) classifier, whose overall structure and specific inputs and outputs are shown in Fig. 6. The CNN is a linear stack of two 2D convolution layers, one max pooling layer and two dense layers, with the final dense layer producing the fall classification. Batch normalization is inserted between the two convolution layers to normalize the activations, whose mean and variance shift during training.
Fig. 6. Fall detection network overall structure
5 Experiments and Result Analysis
5.1 Fall Classifier
The pose estimation experiments use the COCO 2017 and AI Challenger datasets, whose annotations are stored in JSON format. The fall detection model uses the self-built RF (Real Fall) behavior dataset, which is derived from real road and railway surveillance footage and from Internet images. Most scenes take place in public areas such as street crossings and train station entrances and exits. The dataset contains more than 1,400 images and videos, with an 8:2 ratio between fall and normal behavior and a 9:1 ratio between the training and test sets. The experiments were run with the PyTorch deep learning framework and Python on Ubuntu 22.04 LTS, using an NVIDIA Tesla T4 GPU (16 GB).
5.2 Evaluation Indicators
In human keypoint detection, predicted keypoints cannot simply be compared with the ground truth one at a time, which makes it impossible to tell directly whether a prediction is a miss or a false detection. The similarity between predicted and ground-truth keypoints is therefore measured with the OKS (Object Keypoint Similarity) metric. AP50, AP75 and AP90 denote the average precision at OKS thresholds of 0.50, 0.75 and 0.90, and mAP (mean Average Precision) indicates the average accuracy over all keypoints. FPS (frames per second), the number of frames processed per second, is used to measure the algorithm's efficiency. Fall detection is evaluated with four commonly used metrics, accuracy, recall, precision and F1-score, calculated as follows:

$$\text{Precision} = \frac{TP}{TP + FP}, \quad \text{Recall} = \frac{TP}{TP + FN}, \quad \text{Accuracy} = \frac{TP + TN}{TP + FN + FP + TN}, \quad \text{F1-score} = 2 \cdot \frac{\text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} \tag{12}$$
Specifically, TP (True Positive) is the number of fall samples correctly detected as falls; FP (False Positive) is the number of non-fall samples incorrectly classified as falls; FN (False Negative) is the number of fall samples incorrectly classified as non-falls; and TN (True Negative) is the number of non-fall samples correctly identified as non-falls.
5.3 Model Comparison Experiment
The proposed model was trained for 400 iterations with an initial learning rate of 0.01 [16]. The learning rate followed a step-decay schedule: it was reduced to 1/10 of its current value at iteration 100 and again to 1/10 of its current value at iteration 250. Back-propagation used gradient descent with a batch size of 64 and a momentum of 0.9. DSC-OpenPose was compared with the bottom-up OpenPose, Lightweight OpenPose, Hourglass and HigherHRNet-w32 models on the COCO and AI Challenger datasets under the same experimental environment and configuration; the results are presented in Table 1.

Table 1. Comparison of results from different models on COCO and AI Challenger datasets

| Dataset | Model | mAP (%) | AP50 (%) | AP75 (%) | AP90 (%) |
|---|---|---|---|---|---|
| COCO 2017 | OpenPose [10] | 62.1 | 69.9 | 59.5 | 48.6 |
| COCO 2017 | Lightweight OpenPose [12] | 59.1 | 62.2 | 50.1 | 42.2 |
| COCO 2017 | Hourglass [7] | 62.5 | 70.4 | 59.8 | 49.1 |
| COCO 2017 | HigherHRNet-w32 [9] | 63.6 | 72.9 | 60.9 | 47.9 |
| COCO 2017 | Ours | 63.9 | 72.5 | 61.5 | 47.9 |
| AI Challenger | OpenPose [10] | 63.8 | 70.7 | 61.3 | 48.3 |
| AI Challenger | Lightweight OpenPose [12] | 58.9 | 63.7 | 50.5 | 43 |
| AI Challenger | Hourglass [7] | 64.3 | 71.6 | 61.7 | 49.2 |
| AI Challenger | HigherHRNet-w32 [9] | 65.1 | 71.9 | 61.4 | 48.9 |
| AI Challenger | Ours | 64.8 | 72.3 | 62.4 | 49.1 |
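The step-decay schedule described in Sect. 5.3 (initial rate 0.01, divided by 10 at iterations 100 and 250) can be written as a small pure-Python function. The function and parameter names are illustrative and do not come from the authors' code.

```python
def step_decay_lr(iteration: int, base_lr: float = 0.01,
                  milestones=(100, 250), gamma: float = 0.1) -> float:
    """Learning rate at a given iteration: multiplied by `gamma` at each milestone reached."""
    lr = base_lr
    for milestone in milestones:
        if iteration >= milestone:
            lr *= gamma
    return lr


for it in (0, 99, 100, 249, 250, 399):
    print(it, step_decay_lr(it))
```

In PyTorch the same schedule is available as `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 250], gamma=0.1)`, with `scheduler.step()` called once per iteration so that the milestones count iterations rather than epochs.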
To evaluate the impact of the dense connection mechanism, the models were also compared in terms of parameter count, computational cost (GFLOPS) and detection speed (FPS); the results are presented in Table 2.
Table 2. Comparison of different model sizes

| Model | Params (×10^7) | GFLOPS | FPS (frame/s) |
|---|---|---|---|
| OpenPose [10] | 34.2 | 33 | 12.1 |
| Lightweight OpenPose [12] | 14.1 | 8.3 | 23.2 |
| Hourglass [7] | 21.5 | 19.5 | 17.4 |
| HigherHRNet-w32 [9] | 28.6 | 23.4 | 13.6 |
| Ours | 19.4 | 22.3 | 20.6 |
On the COCO dataset, the proposed model improves mAP without increasing detection time. Its detection speed (20.6 FPS) is higher than those of OpenPose (12.1 FPS) and HigherHRNet-w32 (13.6 FPS), and its mAP is 4.8 percentage points higher than that of Lightweight OpenPose, the only compared network that runs faster. On COCO 2017 the proposed model achieves the highest mAP and AP75 of all models, although HigherHRNet-w32 scores slightly higher at AP50, and OpenPose and Hourglass score higher at AP90. On the AI Challenger dataset, the proposed model achieves a higher mAP than OpenPose and Lightweight OpenPose; compared with HigherHRNet-w32, whose mAP is slightly higher (65.1% vs. 64.8%), the proposed model detects faster with fewer parameters.

Table 3. Comparison of results from attention module ablation experiments

| Dataset | Model | mAP | Nose | Left Elbow | Right Elbow | Left Knee | Right Knee | Left Ankle | Right Ankle |
|---|---|---|---|---|---|---|---|---|---|
| COCO 2017 | Ours (no attention) | 62.3 | 82 | 78.4 | 78.3 | 75.4 | 72.3 | 64.5 | 65.9 |
| COCO 2017 | Ours (with attention) | 63.9 | 82.4 | 79.6 | 77.9 | 77 | 74.5 | 65.5 | 66.5 |
| AI Challenger | Ours (no attention) | 63.7 | 82.1 | 78.4 | 78 | 75.6 | 72.9 | 65.6 | 64.3 |
| AI Challenger | Ours (with attention) | 64.8 | 81.7 | 80.6 | 79.8 | 74.5 | 75.1 | 66 | 63.5 |
As shown in Table 2, DSC-OpenPose has 43% fewer parameters than OpenPose. This is mainly due to the dense connections, which enhance parameter sharing and greatly reduce parameter storage requirements. In terms of computing speed, however, the proposed model repeatedly concatenates the outputs of previous layers with the output of the current layer and passes the result to the next layer, and each concatenation allocates new memory to store the concatenated features. This increases memory usage and, given the large feature maps and the heavy convolution workload, slows down computation. Consequently, its computational speed does not improve significantly.
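As a quick check, the 43% figure follows from the parameter counts listed in Table 2 for OpenPose (34.2) and DSC-OpenPose (19.4):

$$\frac{34.2 - 19.4}{34.2} = \frac{14.8}{34.2} \approx 0.433$$

that is, a reduction of about 43.3%.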
To evaluate the impact of the added coordinate attention module on detection performance, an ablation study of the attention module was conducted on both the COCO and AI Challenger datasets; the results are presented in Table 3. With the attention module, the proposed model achieves a higher mAP and, for most keypoints, higher accuracy than the model without it. Specifically, adding the coordinate attention module increases mAP by 1.6 percentage points on the COCO dataset and by 1.1 percentage points on the AI Challenger dataset, and the model with the attention module also performs better on several of the keypoint accuracy metrics [17]. To show the effect of the attention module more intuitively, pose estimation models with and without the coordinate attention mechanism were each used for human skeleton extraction; the results are shown in Fig. 7.
Fig. 7. Comparison of pose estimation results. The left panel shows the skeleton map extracted by the algorithm without the attention mechanism; the right panel shows the skeleton map extracted by the algorithm with the attention mechanism.
The comparison shows that, after adding the coordinate attention module, the algorithm extracts more accurate pose information and generates a more accurate skeleton map. This indicates that the model focuses more on the key channels and keypoint locations and assigns them more weight, which improves keypoint detection accuracy and confirms the effectiveness of the attention module. In the fall detection part, DSC-OpenPose was first used to extract keypoints and generate skeleton maps for the RF dataset. The keypoint and skeleton information was then used to generate fall detection features, which were used in experiments with the proposed fall classifier. The performance of the proposed model was evaluated with ten-fold cross-validation: the input data was divided into ten groups, and in each fold nine groups were used for training and the remaining group for validation, so that every group served as the validation set exactly once. Binary cross-entropy was used as the loss function measuring the difference between predicted and true results, and it was the only error function used to assess performance during training. Training used the Nadam optimizer, which handles sparse gradients effectively and adapts the learning rate automatically. The changes in accuracy and loss during model training and testing are shown in Fig. 8.
Fig. 8. The graph on the left displays the accuracy changes during both model training and testing, while the graph on the right shows the loss changes during the same processes.
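The ten-fold protocol and the binary cross-entropy loss described above can be sketched with the standard library alone. The paper does not say whether samples were shuffled or stratified by class before splitting, so this sketch assumes contiguous folds over already-shuffled indices; the function names are hypothetical.

```python
import math
from typing import Iterator, List, Sequence, Tuple


def k_fold_indices(n_samples: int, k: int = 10) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_idx, val_idx); each of the k contiguous folds is used once for validation."""
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    indices = list(range(n_samples))
    start = 0
    for size in fold_sizes:
        val = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, val
        start += size


def binary_cross_entropy(y_true: Sequence[int], y_prob: Sequence[float],
                         eps: float = 1e-7) -> float:
    """Mean BCE, -[y log p + (1 - y) log(1 - p)], with p clipped away from 0 and 1."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)


for fold, (train_idx, val_idx) in enumerate(k_fold_indices(23, 10)):
    print(fold, len(train_idx), val_idx)
```

In practice the sample indices would be shuffled (and possibly stratified by fall/non-fall label) before being mapped to folds, so that each validation group reflects the 8:2 class ratio of the dataset.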
To verify the effectiveness of the proposed fall detection model, comparison experiments were conducted in the same experimental environment against other fall detection methods based on pose estimation models and against an optical flow-based method. The comparison considered accuracy, precision, recall, F1-score and FPS; the results are shown in Table 4.

Table 4. Comparison of results from different models on the RF dataset

| Model | Accuracy (%) | Precision (%) | Recall (%) | F1-score (%) | FPS (frame/s) |
|---|---|---|---|---|---|
| OpenPose+Conv+LSTM [2] | 96 | 94.5 | 97.4 | 95.9 | 18.8 |
| PoseNet + GRU [1] | 97.2 | 96.1 | 97.6 | 96.8 | 17.6 |
| 3D LPN+CNN [13] | 97.6 | 96.2 | 97.7 | 96.9 | 15.4 |
| VGG+Optical Flow [14] | 97.4 | 96 | 97.4 | 96.6 | 19.1 |
| Ours | 98 | 96.5 | 96.1 | 97.2 | 20.1 |
The experimental results show that the proposed fall detection model achieves an accuracy of 98%, a precision of 96.5% and a recall of 96.1%, meaning that most fall behaviors are correctly identified. Compared with the other detection algorithms, the proposed algorithm achieves the highest accuracy and precision as well as the highest detection speed (20.1 FPS), which makes it well suited to real-time detection. The detection results of the proposed method are shown in Fig. 9, where the yellow lines indicate the extracted skeleton and the red box marks an abnormal fall state. These results show that the method can identify fall behavior accurately.
Fig. 9. Display of fall detection results
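The metrics of Eq. (12) are straightforward to compute from confusion-matrix counts; the counts in the example call below are hypothetical and serve only as an illustration. A helper that recomputes F1 from a precision/recall pair is also useful for cross-checking Table 4: for the baseline rows it reproduces the listed F1 values to within 0.1 point (e.g. 95.9% for OpenPose+Conv+LSTM), whereas the pair reported for the proposed method (96.5% precision, 96.1% recall) gives about 96.3% rather than the listed 97.2%, so one of those three entries may have been mistranscribed.

```python
def classification_metrics(tp: int, fp: int, fn: int, tn: int) -> dict:
    """Accuracy, precision, recall and F1-score as defined in Eq. (12)."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    accuracy = (tp + tn) / (tp + fn + fp + tn)
    f1 = 2 * precision * recall / (precision + recall)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


def f1_score(precision: float, recall: float) -> float:
    """F1 as the harmonic mean of precision and recall (any consistent units, e.g. percent)."""
    return 2 * precision * recall / (precision + recall)


# Hypothetical confusion-matrix counts, for illustration only.
print(classification_metrics(tp=90, fp=5, fn=10, tn=95))
# F1 recomputed from precision/recall pairs reported in Table 4.
print(round(f1_score(94.5, 97.4), 1))  # OpenPose+Conv+LSTM [2]: 95.9
print(round(f1_score(96.5, 96.1), 1))  # Ours: 96.3
```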
6 Conclusions
This paper introduces DSC-OpenPose, a human fall detection algorithm with a fused attention mechanism. Its three main aspects are:
1) Tandem dense connection blocks replace the original CNN architecture to reduce the number of model parameters, which is 43% lower than that of the OpenPose algorithm, yielding a lightweight model;
2) The coordinate attention mechanism is introduced to capture channel and position information, locating keypoints precisely and producing more accurate pose estimation results;
3) The resulting pose estimates are used to build a CNN fall classifier that recognizes fall behavior with high detection accuracy and fast detection speed.
In the pose estimation part, the proposed algorithm achieves good detection accuracy on the COCO dataset; in the fall detection part, it achieves an accuracy of 98% and a precision of 96.5% on the RF dataset. Reducing memory usage is one direction for further research; achieving accurate estimation in occluded environments is another.
Acknowledgement. This work is supported by the National Key R&D Program of China (grant no. 2018YFB1701401, 2020YFB1712401-1), the National Natural Science Foundation of China (grant no. 62006210, 62001284), the Key Project of Public Benefit in Henan Province of China (grant no. 201300210500), the Science and technology public relations project of Henan Province (grant no. 212102210098, 202102210373) and the Research Foundation for Advanced Talents of Zhengzhou University (grant no. 32340306).
L. Shi et al.
Improved YOLOv5s Method for Nut Detection on Ultra High Voltage Power Towers
Lang Xu1, Yi Xia1, Jun Zhang1, Bing Wang2, and Peng Chen1(B)
1 National Engineering Research Center for Agro-Ecological Big Data Analysis and Application, Information Materials and Intelligent Sensing Laboratory of Anhui Province, School of Internet and Institutes of Physical Science and Information Technology, Anhui University, Hefei 230601, Anhui, China
2 School of Electrical and Information Engineering, Anhui University of Technology, Ma’anshan, Anhui 243032, China
Abstract. Ultra high voltage (UHV) power transmission towers are important power transmission facilities, and their stable and safe operation is crucial to energy supply and social stability. As science, technology, and artificial intelligence continue to develop, research on UHV tower maintenance robots has become an inevitable trend. However, research on vision-based maintenance robots is not yet mature. Existing methods suffer from poor real-time performance, low positioning accuracy, large distance measurement errors, and poor performance on embedded devices. To address these problems, this paper proposes a lightweight nut detection algorithm based on YOLOv5s and MobileNetV3 and evaluates it on a nut dataset built for this work. Compared with YOLOv5s, the proposed method reduces the model size by 77.78%, increases the detection speed by 4.17%, and improves accuracy by 0.73%. The experimental results show that the improved algorithm greatly reduces model size and increases detection speed while maintaining accuracy. It thereby addresses the poor real-time detection and poor embedded-device performance of existing methods.
Keywords: YOLOv5s · Nut Detection · MobileNetV3 · CBAM
1 Introduction UHV power transmission is one of the most advanced power transmission technologies in the world. It supports long-distance transmission with low loss and high capacity, so it has become an important part of the power systems of many countries. UHV transmission is now widely used in China's power industry, and high-voltage towers are found throughout the country [1]. A high-voltage tower is a single structure connected layer by layer, so damage to any part can trigger a chain reaction and lead to a catastrophic accident. However, high-voltage towers stand outdoors and are exposed to the natural environment for long periods, so their equipment inevitably deteriorates. Inspecting and maintaining power towers therefore requires substantial manpower, material, and financial resources.
Manual inspection is still the main power inspection method: inspectors check and record the state of electrical equipment by sight and hearing. This method has high labor costs, and its efficiency depends heavily on natural conditions such as the weather. The harsher the environment and the more complex the climate, the more often inspections are needed. The safety of the inspectors also cannot be guaranteed. Fig. 1 shows maintenance personnel servicing equipment on a tower. They must climb the high tower to do this work, which is extremely dangerous. Finding a cost-effective and safe method for inspection and maintenance is therefore a key task for the power industry.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 277–287, 2023. https://doi.org/10.1007/978-981-99-4761-4_24
Fig. 1. Manually performing maintenance of high-voltage power towers
In 2014, Schneider [2] proposed an object detection method that applies binocular vision to embedded mobile devices and runs binocular detection efficiently on low-power hardware. Methods of this type respond quickly, consume little power, and offer high accuracy and real-time performance. They have strongly influenced later research on embedded devices. In 2021, Liu et al. [3] from the University of Hong Kong proposed an efficient binocular stereo 3D detection algorithm that detects objects using 2D image information. The algorithm feeds the binocular images into two independent YOLOv3 networks. It then fuses the two detection results with a dedicated binocular synchronization mechanism to produce 3D detections. Compared with traditional stereo 3D detection algorithms, this method is computationally simple and more robust, and it can detect and localize 3D objects in real time.
YOLOv5-based nut detection [4] has two main shortcomings: a large model size and a heavy detection workload. To address them, this paper proposes a lightweight YOLOv5-based nut detection algorithm built on the MobileNetV3 network architecture. MobileNetV3 inherits the depthwise separable convolution of MobileNetV1 and the inverted residual structure with a linear bottleneck of MobileNetV2. It also adds lightweight squeeze-and-excitation (SE) attention [5] to the backbone and uses h-swish as a new activation function. Replacing the backbone with MobileNetV3 costs some accuracy. To recover it, we also replace the C3 modules in the neck with the Convolutional Block Attention Module (CBAM) [6], which improves the detection accuracy of the model. Finally, we tested the improved algorithm on the nut dataset and obtained good detection results.
2 YOLOv5s Network Structure Ultralytics released the YOLOv5 algorithm in June 2020, following YOLOv4 [7]. Its recognition accuracy is not greatly improved over earlier YOLO algorithms, but its recognition speed is much higher. YOLOv5 also has a strong advantage in deployment. The Darknet-based YOLOv4 model is as large as 244 MB, whereas the YOLOv5 model is only 27 MB. YOLOv5 comes in four versions that differ in network depth and width: YOLOv5s, YOLOv5m, YOLOv5l, and YOLOv5x. YOLOv5s has the smallest model and the shortest per-frame processing time, which makes it the best fit for lightweight power tower nut detection. This paper therefore focuses on improving YOLOv5s. Fig. 2 shows its network structure.
Fig. 2. YOLOv5s network architecture
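The four YOLOv5 versions share one layer definition and differ only in a depth multiplier (how many times blocks repeat) and a width multiplier (how many channels each layer has). The sketch below shows how one layer is scaled. The multiplier values and the rounding rules are taken from the public Ultralytics YOLOv5 configuration files. Treat them as an assumption, because they are not given in this paper and may change between releases. The example layer (9 repeats, 512 channels) is hypothetical.

```python
import math

# Assumed depth/width multipliers from the Ultralytics YOLOv5 model configs
# (yolov5s/m/l/x.yaml); not stated in this paper.
SCALES = {
    "yolov5s": (0.33, 0.50),
    "yolov5m": (0.67, 0.75),
    "yolov5l": (1.00, 1.00),
    "yolov5x": (1.33, 1.25),
}


def make_divisible(x: float, divisor: int = 8) -> int:
    """Round a channel count up to the nearest multiple of `divisor`."""
    return math.ceil(x / divisor) * divisor


def scale_layer(base_repeats: int, base_channels: int, version: str) -> tuple:
    """Return (repeats, output channels) of one layer after depth/width scaling."""
    depth, width = SCALES[version]
    # Repeated blocks are scaled by depth; single blocks stay single.
    repeats = max(round(base_repeats * depth), 1) if base_repeats > 1 else base_repeats
    channels = make_divisible(base_channels * width)
    return repeats, channels


if __name__ == "__main__":
    # Hypothetical layer: 9 repeated blocks with 512 output channels in the base config.
    for version in SCALES:
        print(version, scale_layer(9, 512, version))
```

Under these assumptions, YOLOv5s keeps only about a third of the repeated blocks and half of the channels, which is why it is the smallest and fastest version.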
3 YOLOv5s Improvements 3.1 MobileNetV3 MobileNetV3 [8] combines the depthwise separable convolution of MobileNetV1 [9] with the inverted residuals and linear bottlenecks of MobileNetV2 [10]. On top of these, MobileNetV3 adds the SE attention mechanism and uses neural architecture search (NAS) [11] to find the network configuration and parameters. It also introduces the h-swish activation function. Fig. 3 shows the network structure of MobileNetV3.
Fig. 3. MobileNetV3 network architecture
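h-swish, as defined for MobileNetV3 [8], replaces the sigmoid inside swish with the piecewise-linear term ReLU6(x + 3)/6. This removes the exponential, which is expensive on mobile hardware. The scalar sketch below compares the two functions. The function names are our own, chosen for illustration.

```python
import math


def relu6(x: float) -> float:
    """ReLU6: clip the input to the range [0, 6]."""
    return min(max(x, 0.0), 6.0)


def h_sigmoid(x: float) -> float:
    """Piecewise-linear stand-in for the sigmoid: ReLU6(x + 3) / 6."""
    return relu6(x + 3.0) / 6.0


def h_swish(x: float) -> float:
    """h-swish(x) = x * ReLU6(x + 3) / 6."""
    return x * h_sigmoid(x)


def swish(x: float) -> float:
    """Reference swish(x) = x * sigmoid(x), which h-swish approximates."""
    return x / (1.0 + math.exp(-x))


if __name__ == "__main__":
    for x in (-4.0, -1.0, 0.0, 1.0, 4.0):
        print(f"x = {x:+.1f}   h-swish = {h_swish(x):+.4f}   swish = {swish(x):+.4f}")
```

For x ≤ -3 h-swish is exactly zero, and for x ≥ 3 it equals x. Between those points it closely tracks swish.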
1) Depthwise Separable Convolution A depthwise separable convolution factorizes a standard convolution into two simpler operations: a depthwise convolution and a pointwise convolution. This greatly reduces computation and the number of parameters. Fig. 4 shows the two steps. In the depthwise convolution, a separate convolution kernel is applied to each input channel, so each channel is filtered independently. In the pointwise convolution, 1 × 1 convolution kernels are applied to the output of the depthwise convolution, and each output channel is a weighted combination of all depthwise output channels. Separating spatial filtering from channel mixing in this way greatly reduces the computational cost of the model.
Fig. 4. Depthwise separable convolution
2) Inverted residual structure The inverted residual structure is an improvement of the residual structure [12]. Fig. 5 shows both structures. The residual (bottleneck) structure works in three steps:
- a 1 × 1 convolution reduces the number of channels;
- a 3 × 3 convolution extracts features;
- another 1 × 1 convolution restores the number of channels.
It therefore has many channels at both ends and few in the middle. The inverted residual structure reverses this:
- a 1 × 1 convolution expands the number of channels;
- a 3 × 3 depthwise convolution extracts features;
- a 1 × 1 convolution reduces the number of channels again.
Unlike a regular residual block, an inverted residual block therefore extracts features with a depthwise convolution. In addition, in the inverted residual structure, the first two ReLU activation functions are replaced by ReLU6, and the last one is replaced by a linear activation (the linear bottleneck). The inverted residual structure improves computational efficiency and reduces the number of model parameters while maintaining accuracy.
Fig. 5. Residual structure and inverted residual structure
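The savings described in items 1) and 2) can be checked by counting weights. For a k × k kernel, a depthwise separable convolution needs k²·C_in + C_in·C_out weights instead of k²·C_in·C_out. This gives a ratio of 1/C_out + 1/k², the standard MobileNetV1 result [9]. The sketch below counts weights with biases ignored. It also traces the narrow, wide, narrow channel pattern of an inverted residual block. The channel counts (64, 24) are hypothetical, and the expansion factor 6 is an illustrative assumption (the value commonly used in MobileNetV2).

```python
def standard_conv_params(c_in: int, c_out: int, k: int) -> int:
    """Weights of a k x k standard convolution (biases ignored)."""
    return k * k * c_in * c_out


def depthwise_separable_params(c_in: int, c_out: int, k: int) -> int:
    """k x k depthwise conv (one filter per input channel) + 1 x 1 pointwise conv."""
    depthwise = k * k * c_in  # spatial filtering, channel by channel
    pointwise = c_in * c_out  # 1 x 1 conv mixes channels into c_out outputs
    return depthwise + pointwise


def inverted_residual(c_in: int, c_out: int, expansion: int, k: int = 3, stride: int = 1) -> dict:
    """Channel trace and weight count of an inverted residual block:
    1 x 1 expansion (ReLU6) -> k x k depthwise (ReLU6) -> 1 x 1 linear projection."""
    hidden = c_in * expansion
    layers = {
        "expand_1x1": c_in * hidden,
        "depthwise_kxk": k * k * hidden,
        "project_1x1_linear": hidden * c_out,  # no activation: linear bottleneck
    }
    return {
        "channels": (c_in, hidden, hidden, c_out),  # narrow -> wide -> wide -> narrow
        "params": sum(layers.values()),
        "residual": stride == 1 and c_in == c_out,  # shortcut only when shapes match
    }


if __name__ == "__main__":
    std = standard_conv_params(64, 64, 3)
    dsc = depthwise_separable_params(64, 64, 3)
    print(std, dsc, dsc / std, 1 / 64 + 1 / 3 ** 2)
    print(inverted_residual(24, 24, expansion=6))
```

For a 3 × 3 layer with 64 input and 64 output channels, the depthwise separable version needs about 12.7% of the weights of a standard convolution.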
3) SENet SENet is a widely used attention mechanism that improves the performance and stability of deep neural networks. The SE attention mechanism learns one weight per channel and uses it to rescale that channel's features adaptively. The network can then emphasize informative features, which improves the expressiveness and robustness of the model. Fig. 6 is a schematic diagram of the SE module.
Fig. 6. SE module
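The two SE steps shown in Fig. 6 can be sketched in plain Python on a C × H × W feature map stored as nested lists. Squeeze is global average pooling to a C-vector. Excitation is FC (C → C/r), ReLU, FC (C/r → C), and Sigmoid. The resulting weights then rescale the input channels. This is a minimal illustration, not the authors' implementation. The fully connected layers are bias-free for brevity, and the weights in the usage example are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]
FeatureMap = List[Matrix]  # C channels, each an H x W grid


def squeeze(x: FeatureMap) -> List[float]:
    """Global average pooling: C x H x W -> C (the 1 x 1 x C descriptor)."""
    return [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in x]


def fully_connected(w: Matrix, v: List[float]) -> List[float]:
    """Bias-free fully connected layer: returns w @ v."""
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def excitation(z: List[float], w1: Matrix, w2: Matrix) -> List[float]:
    """FC (C -> C/r) + ReLU, then FC (C/r -> C) + Sigmoid: one weight per channel."""
    hidden = [max(0.0, u) for u in fully_connected(w1, z)]
    return [1.0 / (1.0 + math.exp(-u)) for u in fully_connected(w2, hidden)]


def se_block(x: FeatureMap, w1: Matrix, w2: Matrix) -> FeatureMap:
    """Rescale every channel of x by its attention weight."""
    s = excitation(squeeze(x), w1, w2)
    return [[[s[c] * v for v in row] for row in ch] for c, ch in enumerate(x)]


if __name__ == "__main__":
    # Hypothetical 4-channel, 2 x 2 input and hand-picked weights with r = 2.
    x = [[[float(c + 1)] * 2 for _ in range(2)] for c in range(4)]
    w1 = [[0.1] * 4, [-0.1] * 4]                             # shape (C/r, C) = (2, 4)
    w2 = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [2.0, 0.0]]   # shape (C, C/r) = (4, 2)
    print(excitation(squeeze(x), w1, w2))
```

Because the sigmoid output lies in (0, 1), SE can only attenuate channels relative to one another. It never changes the spatial layout of the feature map.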
The SE attention mechanism reweights the feature channels in two operations. The first is the squeeze operation. It applies global average pooling [13] to each feature channel, which compresses the channel's H × W spatial map into a single value and produces a 1 × 1 × C channel descriptor. The second is the excitation operation, which uses a multi-layer perceptron (MLP) to generate a channel attention vector in four steps:
- a first fully connected layer reduces the output to 1 × 1 × C/r, where the reduction ratio r lowers the computational cost;
- a ReLU activation is applied;
- a second fully connected layer restores the dimension;
- a Sigmoid activation produces the final 1 × 1 × C output.
This output is then used to scale the channels of the input feature map. 3.2 Improved Nut Detection Network Architecture Existing object detection algorithms have high computational complexity, low detection accuracy, and long training times. To address these shortcomings, this paper lightens YOLOv5s by replacing its backbone with MobileNetV3. MobileNetV3 retains the depthwise separable convolution of MobileNetV1 and the inverted residual structure with a linear bottleneck of MobileNetV2. It adds the SE channel attention mechanism and the h-swish activation function to improve feature extraction. On top of this lightweight MobileNetV3 backbone, the C3 modules in the neck are replaced by CBAM modules. Fig. 7 shows the improved nut detection network. The input image has a size of 3 × 640 × 640. It is first processed by the CBH module, which applies convolution, batch normalization, and a LeakyReLU activation.
The result then passes through a series of Bneck blocks, and feature maps of different sizes are passed to the neck. The Upsample layer increases the spatial size of its input tensor. The Concat layer joins two input tensors along one dimension, here the channel dimension. All other dimensions must match, and the concatenated dimension of the output is the sum of those of the two inputs. The Head layer is the final layer of the detection model. It outputs the detection results, including the category, position, and confidence of each target.
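The Upsample and Concat steps in the neck can be illustrated by tracking tensor shapes. The sketch below assumes nearest-neighbour upsampling by a factor of 2, which is common in YOLOv5 configurations but not stated in this paper. Tensors are C × H × W nested lists, and the example shapes are hypothetical. Upsampling a deep, low-resolution map to the size of a shallower map makes the spatial sizes match. Concatenation then stacks the channels.

```python
from typing import List

Matrix = List[List[float]]
Tensor = List[Matrix]  # C x H x W


def shape(t: Tensor) -> tuple:
    """Return (C, H, W) of a nested-list tensor."""
    return (len(t), len(t[0]), len(t[0][0]))


def upsample_nearest(t: Tensor, scale: int = 2) -> Tensor:
    """Nearest-neighbour upsampling: each pixel is repeated scale x scale times."""
    return [
        [[v for v in row for _ in range(scale)] for row in ch for _ in range(scale)]
        for ch in t
    ]


def concat_channels(a: Tensor, b: Tensor) -> Tensor:
    """Concatenate along the channel dimension; spatial sizes must match."""
    if shape(a)[1:] != shape(b)[1:]:
        raise ValueError(f"spatial sizes differ: {shape(a)[1:]} vs {shape(b)[1:]}")
    return a + b  # channel lists are joined; H and W are unchanged


if __name__ == "__main__":
    deep = [[[0.0] * 2 for _ in range(2)] for _ in range(2)]     # 2 x 2 x 2
    shallow = [[[0.0] * 4 for _ in range(4)] for _ in range(3)]  # 3 x 4 x 4
    fused = concat_channels(upsample_nearest(deep), shallow)
    print(shape(fused))  # channels add up; spatial size stays 4 x 4
```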
Fig. 7. Improved nut detection network architecture
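CBAM [6], which replaces the C3 modules in the neck, applies channel attention and then spatial attention. Channel attention passes average-pooled and max-pooled channel descriptors through a shared MLP, adds the two results, and applies a sigmoid. Spatial attention pools across channels into a mean map and a max map, convolves the two maps with a k × k kernel (7 × 7 in the CBAM paper), and applies a sigmoid. The plain-Python sketch below follows that description, with assumptions:
- biases are omitted;
- the convolution is a zero-padded cross-correlation, as in deep-learning frameworks;
- all weights are supplied by the caller and are hypothetical.
The paper does not specify these details.

```python
import math
from typing import List

Matrix = List[List[float]]
Tensor = List[Matrix]  # C x H x W feature map as nested lists


def sigmoid(u: float) -> float:
    return 1.0 / (1.0 + math.exp(-u))


def shared_mlp(v: List[float], w1: Matrix, w2: Matrix) -> List[float]:
    """Two bias-free FC layers (C -> C/r -> C) with ReLU in between."""
    hidden = [max(0.0, sum(a * b for a, b in zip(row, v))) for row in w1]
    return [sum(a * b for a, b in zip(row, hidden)) for row in w2]


def channel_attention(x: Tensor, w1: Matrix, w2: Matrix) -> Tensor:
    """Weight channels by sigmoid(MLP(avg-pool) + MLP(max-pool))."""
    avg = [sum(map(sum, ch)) / (len(ch) * len(ch[0])) for ch in x]
    mx = [max(map(max, ch)) for ch in x]
    a_out, m_out = shared_mlp(avg, w1, w2), shared_mlp(mx, w1, w2)
    weights = [sigmoid(a + m) for a, m in zip(a_out, m_out)]
    return [[[wc * v for v in row] for row in ch] for wc, ch in zip(weights, x)]


def spatial_attention(x: Tensor, kernel: List[Matrix]) -> Tensor:
    """Weight pixels by sigmoid(conv_kxk([mean over C, max over C])), zero padding."""
    c, h, w = len(x), len(x[0]), len(x[0][0])
    k = len(kernel[0])
    pad = k // 2
    maps = [
        [[sum(x[ci][i][j] for ci in range(c)) / c for j in range(w)] for i in range(h)],
        [[max(x[ci][i][j] for ci in range(c)) for j in range(w)] for i in range(h)],
    ]
    att = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for m in range(2):  # the two pooled maps act as input channels
                for di in range(k):
                    for dj in range(k):
                        ii, jj = i + di - pad, j + dj - pad
                        if 0 <= ii < h and 0 <= jj < w:
                            acc += kernel[m][di][dj] * maps[m][ii][jj]
            att[i][j] = sigmoid(acc)
    return [[[att[i][j] * ch[i][j] for j in range(w)] for i in range(h)] for ch in x]


def cbam(x: Tensor, w1: Matrix, w2: Matrix, kernel: List[Matrix]) -> Tensor:
    """CBAM: channel attention followed by spatial attention."""
    return spatial_attention(channel_attention(x, w1, w2), kernel)
```

Unlike SE, which only reweights channels, CBAM also learns where to look within each feature map. This is the extra capacity the paper relies on to recover the accuracy lost when the MobileNetV3 backbone is introduced.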
4 Experimental Results and Analysis The experiments in this section were run under Windows with the PyCharm IDE. The hardware was an NVIDIA GeForce GTX 1080 Ti GPU with 11 GB of memory, and CUDA 10.2 and cuDNN 7.6.5 were used for training acceleration. The training framework is PyTorch (Torch 1.8.1), and the programming language is Python 3.8.13. 4.1 DataSet This paper focuses on nut detection for maintenance robots on power transmission towers, with the aim of identifying the nuts on the tower. The distribution and number of nuts vary across locations on a high-voltage tower. Single nuts appear on the tower cross arm, while many nuts are randomly distributed on the connection plates. By purpose, nuts are classified into hexagonal nuts, high-strength nuts, and locking nuts. This study focuses on hexagonal nuts, the most common type. We built the experimental dataset by capturing images of outdoor high-voltage transmission towers with a binocular camera. Data collection took into account the actual working environment and influencing factors such as lighting, weather, and background. Only the left view of the binocular camera was retained, giving a total of 1970 images. The dataset includes individual nuts, densely packed nuts on the tower connection plates, and bolt holes. The hole samples were added as noise to improve the fit of the model. The images were labeled manually with the LabelImg tool using two labels, “nut” and “hole”. Fig. 8 shows the nut dataset. The annotations were saved in XML format and then converted to TXT format with a script for training. In addition, data augmentation expanded the dataset to 5910 images for better training results.
Fig. 8. Nut dataset.
For model training, the nut dataset was divided into a training set of 4728 images, a test set of 591 images, and a validation set of 591 images, an 8:1:1 split. 4.2 Evaluation Metrics Common evaluation metrics for deep-learning object detection algorithms include model size, number of parameters, computational complexity, training time, precision, recall, and mean average precision (mAP). In this paper, we evaluate the algorithm by model size, precision, mAP, and detection speed. Precision, recall, and average precision (AP) are calculated as follows. TP, FP, and FN are the numbers of true positive, false positive, and false negative predictions, and p(r) is precision as a function of recall. mAP is the mean of AP over all classes.
$$\mathrm{Precision} = \frac{TP}{TP + FP} \tag{1}$$
$$\mathrm{Recall} = \frac{TP}{TP + FN} \tag{2}$$
$$AP = \int_{0}^{1} p(r)\,dr \tag{3}$$
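Eqs. (1)–(3) can be computed as shown below. In practice, the integral in Eq. (3) must be approximated from a finite precision-recall curve. The sketch uses all-point interpolation (the PASCAL VOC convention). This is an assumption, because the paper does not say which interpolation it used. The sketch also adds the IoU test that decides whether a prediction counts as a true positive, and the averaging over IoU thresholds 0.50:0.05:0.95 behind the mAP@0.5:0.95 column of Table 1. All function names are illustrative.

```python
from typing import Callable, List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def precision(tp: int, fp: int) -> float:
    """Eq. (1): fraction of predicted boxes that are correct."""
    return tp / (tp + fp) if tp + fp else 0.0


def recall(tp: int, fn: int) -> float:
    """Eq. (2): fraction of ground-truth objects that were found."""
    return tp / (tp + fn) if tp + fn else 0.0


def average_precision(recalls: Sequence[float], precisions: Sequence[float]) -> float:
    """Eq. (3) approximated by all-point interpolation.

    `recalls` must be in ascending order, as obtained by sweeping the confidence
    threshold from high to low; `precisions` are the matching values."""
    r = [0.0, *recalls, 1.0]
    p = [0.0, *precisions, 0.0]
    # Precision envelope: make p non-increasing when read from right to left.
    for i in range(len(p) - 2, -1, -1):
        p[i] = max(p[i], p[i + 1])
    # Area under the step curve; zero-width steps contribute nothing.
    return sum((r[i + 1] - r[i]) * p[i + 1] for i in range(len(r) - 1))


def iou(a: Box, b: Box) -> float:
    """Intersection over union, used to decide whether a prediction is a TP."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


# IoU thresholds 0.50, 0.55, ..., 0.95 averaged by the mAP@0.5:0.95 metric.
IOU_THRESHOLDS: List[float] = [round(0.5 + 0.05 * i, 2) for i in range(10)]


def map_over_thresholds(map_at: Callable[[float], float]) -> float:
    """Average a per-threshold mAP (mean of AP over classes) across IOU_THRESHOLDS."""
    return sum(map_at(t) for t in IOU_THRESHOLDS) / len(IOU_THRESHOLDS)
```

mAP@0.5 uses only the first threshold. mAP@0.5:0.95 also rewards tight box placement, which is why its values in Table 1 are much lower than the precision figures.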
4.3 Experimental Results To obtain better detection results, we compared different settings of epochs and batch_size. With batch_size set to 16, epochs were set to 100, 200, and 300, and an mAP of 96.1% was obtained with 300 epochs. With batch_size set to 32 and the same epoch settings, an mAP of 92.3% was obtained with 300 epochs. Fig. 9 shows the results.
Fig. 9. Experimental results with different batch_size and epochs (mAP_0.5 (%) versus epochs of 100, 200, and 300, for batch_size = 16 and batch_size = 32).
The training parameters were set as follows: a batch_size of 16, an initial learning rate of 0.01, 300 epochs, and the pre-trained model yolov5s.pt. With these settings, the nuts in the images are recognized more accurately. To verify the performance of the improved model, we compared YOLOv5m, YOLOv5s, YOLOv5s-CBAM, YOLOv5s-MobileNet, the proposed YOLOv5s-MC, and Faster-RCNN. YOLOv5s-CBAM replaces the C3 modules in the neck of YOLOv5s with CBAM. All models were trained on the prepared nut dataset and run for nut detection on the same equipment: Windows 10 with a GTX 1080 Ti graphics card. To measure detection speed, ten unlabeled nut images were randomly selected as detection targets, and the image batch size was set to 8 to speed up processing. Table 1 shows the experimental results.
Table 1. Results of a comparative experiment.
Algorithm           Precision (%)   mAP@0.5:0.95   Model Size (MB)   Detection Speed (ms)
YOLOv5m             94.3            0.504          40.6              17.9
YOLOv5s             95.7            0.547          14.4              14.4
YOLOv5s-CBAM        93.8            0.495          14.0              13.6
YOLOv5s-MobileNet   95.1            0.534          3.1               12.7
YOLOv5s-MC          96.1            0.552          3.3               13.4
Faster-RCNN         92.8            0.501          84.7              25.2
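The relative improvements discussed below follow from Table 1 as percentages of the YOLOv5s baseline. The short calculation below transcribes the relevant rows and computes the reductions. It reproduces the 78.47% model-size reduction stated for YOLOv5s-MobileNet. The paper does not say exactly how its speed percentages were derived, so the time reductions printed here are one possible reading of the table, not the authors' computation.

```python
# Values transcribed from Table 1 (model size in MB, detection time in ms).
TABLE1 = {
    "YOLOv5s":           {"size_mb": 14.4, "time_ms": 14.4, "precision": 95.7},
    "YOLOv5s-MobileNet": {"size_mb": 3.1,  "time_ms": 12.7, "precision": 95.1},
    "YOLOv5s-MC":        {"size_mb": 3.3,  "time_ms": 13.4, "precision": 96.1},
}


def relative_reduction(baseline: float, value: float) -> float:
    """Percentage by which `value` is smaller than `baseline`."""
    return 100.0 * (baseline - value) / baseline


def compare_to_baseline(name: str, baseline: str = "YOLOv5s") -> dict:
    """Size and time reductions (%) and precision change (points) versus the baseline."""
    base, row = TABLE1[baseline], TABLE1[name]
    return {
        "size_reduction_pct": relative_reduction(base["size_mb"], row["size_mb"]),
        "time_reduction_pct": relative_reduction(base["time_ms"], row["time_ms"]),
        "precision_gain_pts": row["precision"] - base["precision"],
    }


if __name__ == "__main__":
    for name in ("YOLOv5s-MobileNet", "YOLOv5s-MC"):
        stats = compare_to_baseline(name)
        print(name, {k: round(v, 2) for k, v in stats.items()})
```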
Table 1 shows that all six algorithms detect nuts with good accuracy. YOLOv5s and the two lightweight variants (YOLOv5s-MobileNet and YOLOv5s-MC) perform best. YOLOv5s-MobileNet has a significant advantage in model size, but it loses some detection accuracy.
With the MobileNetV3 backbone, the model size of YOLOv5s-MobileNet is reduced by 78.47% compared with YOLOv5s, and its speed is improved by 6.62%. YOLOv5s-MC reduces model size, increases detection speed, and also slightly improves accuracy. Table 1 lists the model sizes and detection speeds of the six algorithms. Among them, YOLOv5s-MC achieves the highest precision (96.1%) and mAP@0.5:0.95 (0.552). Its model size (3.3 MB) and detection time (13.4 ms) are close to those of the smallest and fastest model, YOLOv5s-MobileNet (3.1 MB, 12.7 ms). The improved YOLOv5s therefore best matches the goal of this paper: optimizing the network structure to reduce model size and make the algorithm suitable for embedded devices. YOLOv5s, YOLOv5s-MobileNet, and YOLOv5s-MC perform well in detection accuracy, detection speed, and model size, so these three models were used for further nut detection tests. All three were run on the same image containing labeled nuts and bolt holes, and Fig. 10 shows their detection results. The figure shows that the models detect nuts and bolt holes accurately, and YOLOv5s and the improved network give higher confidence values. The improved network gives slightly higher confidence for bolt holes.
Fig. 10. Three models’ detection performance.
5 Conclusion This paper proposes a lightweight nut detection algorithm to address the model size, accuracy, and detection speed issues of nut detection models. The main results are as follows:
(1) Based on the YOLOv5s architecture, MobileNetV3 is incorporated into the backbone for lightweighting. The model size was reduced by 78.47% and the speed improved by 6.62%, a significant lightweighting effect.
(2) The CBAM attention mechanism was added to the neck, which increased the average precision to 96.1%.
This research currently remains at the algorithmic level. In future work, the algorithm will be ported to hardware devices while maintaining good detection performance.
Acknowledgement. This work was supported by the National Natural Science Foundation of China (Nos. 62072002 and 62172004), and Special Fund for Anhui Agriculture Research System.
References
1. Shu, Y., Chen, W.: Research and application of UHV power transmission in China. High Volt. 3(1), 1–13 (2018)
2. Liu, M., Ding, X., Du, W.: Continuous, real-time object detection on mobile devices without offloading. In: 2020 IEEE 40th International Conference on Distributed Computing Systems (ICDCS), pp. 976–986. IEEE (2020)
3. Liu, Y., Wang, L., Liu, M.: YOLOStereo3D: a step back to 2D for efficient stereo 3D detection. In: 2021 IEEE International Conference on Robotics and Automation (ICRA), pp. 13018–13024. IEEE (2021)
4. Diwan, T., Anirudh, G., Tembhurne, J.V.: Object detection using YOLO: challenges, architectural successors, datasets and applications. Multimedia Tools Appl. 82(6), 9243–9275 (2023)
5. Yan, B., Fan, P., Lei, X., et al.: A real-time apple targets detection method for picking robot based on improved YOLOv5. Remote Sens. 13(9), 1619–1632 (2021)
6. Woo, S., Park, J., Lee, J.Y., et al.: CBAM: Convolutional Block Attention Module. Springer, Cham (2018)
7. Bochkovskiy, A., Wang, C.Y., Liao, H.Y.M.: YOLOv4: optimal speed and accuracy of object detection. arXiv preprint arXiv:2004.10934 (2020)
8. Howard, A., Sandler, M., Chu, G., et al.: Searching for MobileNetV3. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 1314–1324 (2019)
9. Howard, A.G., Zhu, M., Chen, B., et al.: MobileNets: efficient convolutional neural networks for mobile vision applications. arXiv preprint arXiv:1704.04861 (2017)
10. Sandler, M., Howard, A., Zhu, M., et al.: MobileNetV2: inverted residuals and linear bottlenecks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4510–4520 (2018)
11. Zoph, B., Vasudevan, V., Shlens, J., et al.: Learning transferable architectures for scalable image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 8697–8710 (2018)
12. He, K., Zhang, X., Ren, S., et al.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
13. Lin, M., Chen, Q., Yan, S.: Network in network. arXiv preprint arXiv:1312.4400 (2013)
Improved Deep Learning-Based Efficientpose Algorithm for Egocentric Marker-Less Tool and Hand Pose Estimation in Manual Assembly
Zihan Niu1, Yi Xia1, Jun Zhang1, Bing Wang2, and Peng Chen1(B)
1 National Engineering Research Center for Agro-Ecological Big Data Analysis and Application, Information Materials and Intelligent Sensing Laboratory of Anhui Province, School of Internet and Institutes of Physical Science and Information Technology, Anhui University, Hefei 230601, Anhui, China
2 School of Electrical and Information Engineering, Anhui University of Technology, Maanshan 243032, Anhui, China
Abstract. Different manual assembly orientations have a significant impact on assembly accuracy. Reliable pose estimation depends on accurately estimating the six degree-of-freedom (6DoF) position and orientation (pose) of the tracked objects. In this paper, we present an improved EfficientPose algorithm, a single-shot learning-based approach to hand and object pose estimation. Starting from the original EfficientPose algorithm, we made four changes:
- added a subnetwork for hand prediction;
- replaced some MBConv modules with Fused-MBConv modules;
- modified the number of network layers;
- used different training strategies.
We evaluated the method on a public dataset for monocular red-green-blue (RGB) 6DoF marker-less hand and surgical instrument pose tracking. Compared with other methods, it improves performance and shortens training time.
Keywords: Single-shot pose estimation · marker-less · deep learning · pose estimation
1 Introduction Assembly operations can be divided into standard parts assembly and non-standard parts assembly. In the automotive industry, for example, most standard parts assembly has been automated. However, some assembly stages still require manual work because of their structural complexity. Manual assembly remains indispensable in industrial digitalization, but it involves complex processes, low efficiency, and high error rates, and it makes working experience hard to accumulate and transfer. Pose estimation of marker-less objects remains challenging yet important. Researchers have proposed various solutions for recovering 6DoF object poses. These include template-based methods, point-to-point methods, and traditional machine learning, as well as deep learning techniques, which currently perform best. Deep learning uses large annotated datasets to learn discriminative feature representations.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 288–298, 2023. https://doi.org/10.1007/978-981-99-4761-4_25
Deep learning methods for 6DoF object pose estimation fall into two groups: one-step methods, which predict poses directly, and iterative refinement methods. Iterative refinement methods first predict 2D bounding boxes and then estimate 6DoF poses with Perspective-n-Point (PnP) or RANSAC algorithms. They are slower and more computationally intensive, but they can achieve higher accuracy. One-step methods are faster but may struggle in complex environments. Bukschat et al. proposed EfficientPose, a one-step 6DoF object pose estimation method. Combined with augmented reality (AR), EfficientPose can improve task positioning and accuracy in manual assembly. Our work improves EfficientPose to combine the speed of one-step methods with the accuracy of iterative refinement for real-time pose estimation in manual assembly.
2 Materials and Methods To jointly model the interaction between assembly tools and the user's hand, we build on EfficientNetV2, which Google proposed in 2021. Two subnetworks are added to the head of the network to predict the rotation vector r and the translation vector t of the target object, while the network keeps the efficient image classification capability of EfficientNetV2. Because manual work is essential in cable assembly operations, we also add a subnetwork for hand prediction (Fig. 1).
Fig. 1. The improved EfficientPose algorithm is a monocular deep learning framework for predicting hand and assembly tool related information from a single monocular red-green-blue (RGB) input image frame.
2.1 Multiscale Feature Fusion To address the common scaling problems of CNNs, Tan and Le proposed the EfficientNet framework, which is generated by uniformly scaling the network width, depth and resolution based on a fixed set of scaling coefficients [1]. Based on the EfficientNet backbone, Tan et al. proposed EfficientDet, a network that incorporates a weighted bidirectional feature pyramid network (BiFPN) for multi-scale feature fusion to effectively represent and process multi-scale features for object detection [2].
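The two EfficientNet/EfficientDet ideas cited above can be sketched briefly. For feature fusion, EfficientDet [2] uses "fast normalized fusion": O = Σᵢ wᵢ·Iᵢ / (ε + Σⱼ wⱼ), where ReLU keeps each learnable weight non-negative. For compound scaling, EfficientNet [1] scales depth, width, and resolution by α^φ, β^φ, and γ^φ. The baseline values are α = 1.2, β = 1.1, and γ = 1.15, chosen so that α·β²·γ² ≈ 2 and FLOPs roughly double per unit of φ. Those constants come from the EfficientNet paper. The improved EfficientPose may use a different configuration, and the feature values below are hypothetical flattened maps.

```python
from typing import List, Sequence, Tuple


def fast_normalized_fusion(inputs: Sequence[Sequence[float]],
                           weights: Sequence[float],
                           eps: float = 1e-4) -> List[float]:
    """BiFPN-style fusion: O = sum_i w_i * I_i / (eps + sum_j w_j).

    Each input is a flattened feature map already resized to a common shape;
    ReLU keeps the learnable weights non-negative."""
    w = [max(0.0, wi) for wi in weights]
    norm = eps + sum(w)
    return [sum(wi * feat[k] for wi, feat in zip(w, inputs)) / norm
            for k in range(len(inputs[0]))]


def compound_scaling(phi: float, alpha: float = 1.2, beta: float = 1.1,
                     gamma: float = 1.15) -> Tuple[float, float, float]:
    """EfficientNet compound scaling: multipliers for depth, width and resolution."""
    return alpha ** phi, beta ** phi, gamma ** phi


if __name__ == "__main__":
    print(fast_normalized_fusion([[1.0, 2.0], [3.0, 4.0]], [1.0, 1.0]))
    d, w, r = compound_scaling(1)
    print(d, w, r, d * w ** 2 * r ** 2)  # approximate FLOPs growth per unit of phi
```

The normalized weights act like a cheap softmax: each fused output is a convex combination of the inputs, so the network can learn how much each resolution contributes.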
We add a set of subnetworks to the EfficientPose architecture. They estimate the 3D vertex positions of the user's hands and the 6DoF transformation (rotation and translation) that describes the pose of the object in the image. As in object detection frameworks such as YOLO [3], the network does not predict 2D bounding boxes directly. Instead, it predicts probabilities for anchor boxes tiled over the image, with a separate set of predictions for each anchor box. The features within each anchor box serve as input to the subnetworks described below. 1) Rotation network EfficientPose's rotation subnetwork uses an axis-angle representation instead of quaternions because it requires fewer parameters. Mahendran et al. [4] also found that this representation performed slightly better in their experiments. The subnetwork directly regresses a rotation vector ($r \in \mathbb{R}^3$) from the convolutional features extracted from each anchor box. It is trained with the squared distance between the model points rotated by the ground-truth rotation $r$ and by the predicted rotation $r'$ [5]. The rotation loss is based on PoseLoss and ShapeMatch-Loss in PoseCNN [5] and is defined as
$$\mathrm{RLoss}(r', r) = \frac{1}{2m} \sum_{x \in \mathcal{M}} \lVert r'x - rx \rVert_2^2 \tag{1}$$
where M is the set of 3D model points, m is the number of points in the set, $r$ is the axis-angle rotation of the ground-truth pose, and $\tilde{r}$ is the axis-angle rotation of the predicted pose.
2) Translation network
Following PoseCNN [5], our translation subnetwork separately regressed the 2D object center point $c = (c_x, c_y)^T$ in pixel coordinates and the distance component $t_z$ from the convolutional features extracted within each anchor box. The object's translation vector $t \in \mathbb{R}^3$ was then constructed from $c$, $t_z$, and the camera intrinsic parameters; the missing components $t_x$ and $t_y$ were computed as

$$t_x = \frac{(c_x - p_x)\, t_z}{f_x} \qquad (2)$$

$$t_y = \frac{(c_y - p_y)\, t_z}{f_y}, \qquad (3)$$
where the principal point $p = (p_x, p_y)$ and the focal lengths $f_x, f_y$ were derived from the camera intrinsic parameters. We modeled the translation loss as the smooth L1 loss between the ground-truth translation $t$ and the predicted translation $\tilde{t}$:

$$\mathrm{TLoss}(\tilde{t}, t) = \frac{1}{2m} \sum_{x \in M} \mathrm{smooth}_{L1}\left(\tilde{t}x - tx\right), \qquad (4)$$

where

$$\mathrm{smooth}_{L1}(x) = \begin{cases} 0.5x^2, & |x| < 1, \\ |x| - 0.5, & \text{otherwise}, \end{cases} \qquad (5)$$
Improved Deep Learning-Based Efficientpose Algorithm
and where M is the set of 3D model points, m is the number of points in the set, $t$ denotes the translation of the ground-truth pose, and $\tilde{t}$ denotes the translation of the predicted pose. The final 6DoF transformation $T$ of the predicted object pose is composed of the rotation matrix $R \in \mathbb{R}^{3\times3}$, computed from the axis-angle representation $r$, and the translation $t$, as $T = \begin{bmatrix} R & t \\ 0^T & 1 \end{bmatrix} \in \mathbb{R}^{4\times4}$.
3) Hand network
Our hand subnetwork uses a vector representation to model a 3D skeleton described by 21 3D joint positions ($h \in \mathbb{R}^{3\times21}$). The ground-truth hand data used for training have the same vector format. Unlike parametric approaches such as MANO [6], which generates a 3D hand mesh from a low-dimensional set of pose and shape parameters, we directly regress the 3D hand skeleton vector from the convolutional features extracted within each anchor box. To train the hand subnetwork to predict the 3D skeleton vertices, we penalized the distance between the points of the ground-truth hand skeleton $h$ and their corresponding predictions $\tilde{h}$, defining the hand vertex loss as

$$\mathrm{HLoss}(\tilde{h}, h) = \frac{1}{2m} \sum_{x \in M} \mathrm{smooth}_{L1}\left(\tilde{h}_x - h_x\right), \qquad (6)$$
where M is the set of 3D model points, m is the number of points in the set, $h$ denotes the ground-truth hand vertex vector, and $\tilde{h}$ denotes the predicted hand vertex vector.

2.2 Improved Backbone Network
1) Fused-MBConv Module
The Fused-MBConv module consists of MBConv blocks, SE attention modules, and scaling gate units. The MBConv block is based on the inverted residual bottleneck structure of MobileNetV2. It contains a 1 × 1 convolution for channel-wise feature expansion, a 3 × 3 depthwise separable convolution for spatial feature extraction, and a 1 × 1 convolution for channel-wise feature projection. The SE attention module aggregates contextual information along the channel dimension through global average pooling and uses it to reweight the feature maps channel by channel. Finally, we introduce a learnable scaling factor that provides a data-driven mechanism for adjusting the scale of the feature maps (Fig. 2). The MBConv module itself is relatively simple, consisting mainly of a 1 × 1 expansion convolution layer, a 3 × 3 depthwise separable convolution layer, and a 1 × 1 projection convolution layer. Fused-MBConv adds the SE attention mechanism and the scaling gate mechanism on top of MBConv, which lets it extract richer feature representations and achieve higher accuracy without greatly increasing the computational complexity (Fig. 3).
2) Network Structure
The initialization module is located at the beginning of the EfficientNetV2 network and contains one or more Fused-MBConv modules that extract low-level features. The channel expansion factor (Expansion) of these Fused-MBConv modules is kept relatively small to
Fig. 2. Fused-MBConv module. The value of Expansion controls the complexity of the Fused-MBConv module: when Expansion is not 1, Fused-MBConv is more lightweight; when Expansion is 1, Fused-MBConv integrates more mechanisms to extract richer feature representations.
Fig. 3. The structural differences between the MBConv and Fused-MBConv modules.
ensure a moderate amount of computation. The feature maps output by the initialization module have relatively high resolution and contain rich spatial information, which provides a good basis for the subsequent modules. The extraction module contains several Fused-MBConv modules with the same basic structure as those of the initialization module, but with a slightly larger channel expansion factor (Expansion) so as to extract more abstract semantic features. The extraction module gradually reduces the feature map resolution to capture context over larger areas, while also reducing the computational complexity. The aggregation module contains one or more Fused-MBConv modules with fewer output channels than the extraction module, which aggregate feature representations at different scales. Reducing the number of channels avoids overly redundant features and reduces the model size to some extent. The feature map resolution of the aggregation
module remains unchanged or is slightly reduced, allowing it to effectively fuse the multiscale features output by the initialization and extraction modules. The classification module contains one MBConv module whose number of output channels equals the number of object categories, for the final classification prediction. It does not change the feature map resolution and directly uses the features output by the aggregation module for classification. As shown in Table 1, the EfficientNetV2 network aggregates features at different semantic levels through these modules, while gradually reducing the feature map resolution and the computational complexity.

Table 1. Backbone network structure.

Stage | Operator | Stride | Channels | Layers
----- | -------- | ------ | -------- | ------
0 | Conv3×3 | 2 | 24 | 1
1 | Fused-MBConv1, k3×3 | 1 | 24 | 2
2 | Fused-MBConv4, k3×3 | 2 | 48 | 4
3 | Fused-MBConv4, k3×3 | 2 | 64 | 4
4 | MBConv4, k3×3, SE0.25 | 2 | 128 | 6
5 | MBConv6, k3×3, SE0.25 | 1 | 160 | 9
6 | MBConv6, k3×3, SE0.25 | 2 | 256 | 15
7 | Conv1×1 & Pooling & FC | – | 1280 | 1
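Because Table 1 fixes the stride of every stage, the spatial size of the feature maps can be traced for the 256 × 256 inputs used in this paper. The Python sketch below assumes "same" padding and that only the first layer of each stage applies the listed stride (the usual EfficientNetV2 convention, not stated explicitly above); stage 7 (pooling and FC) is omitted because it collapses the spatial dimensions.

```python
import math

# (stage, operator, stride, channels, layers) transcribed from Table 1.
STAGES = [
    (0, "Conv3x3", 2, 24, 1),
    (1, "Fused-MBConv1, k3x3", 1, 24, 2),
    (2, "Fused-MBConv4, k3x3", 2, 48, 4),
    (3, "Fused-MBConv4, k3x3", 2, 64, 4),
    (4, "MBConv4, k3x3, SE0.25", 2, 128, 6),
    (5, "MBConv6, k3x3, SE0.25", 1, 160, 9),
    (6, "MBConv6, k3x3, SE0.25", 2, 256, 15),
]


def stage_resolutions(input_size):
    """Return (stage, output_size, channels) for each stage.

    Assumes 'same' padding, so a stride-s layer maps size n to ceil(n / s),
    and that the stride is applied once per stage.
    """
    size = input_size
    out = []
    for stage, _op, stride, channels, _layers in STAGES:
        size = math.ceil(size / stride)
        out.append((stage, size, channels))
    return out


for stage, size, ch in stage_resolutions(256):
    print(f"stage {stage}: {size}x{size}x{ch}")
```

Under these assumptions the final convolutional stage works on an 8 × 8 grid, i.e. an overall stride of 32.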
2.3 Improved Training Strategy
Large images take up a lot of memory during training, while the GPU memory available for training is limited, which forces smaller batch sizes. Touvron et al. [7] showed that training with smaller rather than larger image sizes can improve both accuracy and training speed, and the experimental results of Tan and Le [8] support this observation. In addition, Howard [9] and Tan and Le [8] suggested that when different image sizes are used during training, the regularization strength should be adjusted dynamically rather than kept fixed. The accuracy loss stems from inappropriate regularization: changing the image size dynamically to speed up training, without correspondingly tuning the regularization, can slow convergence and reduce accuracy. Our improved training method adopts the progressive learning strategy proposed by Tan and Le [8]. In the early stages of training, smaller images and weaker regularization are used so that the network quickly learns simple representations. The image size is then gradually increased and stronger regularization is added to make learning harder. The method of gradually changing the image size follows Howard et al. [10] and requires the regularization to be adjusted adaptively. Assume that training has N steps in total, the target image size is $S_e$,
and the target regularization is $\Phi_e = \{\phi_e^k\}$, where k denotes a type of regularization such as the dropout rate or the mixup ratio. Training is divided into M stages; each stage $1 \le i \le M$ trains the model with image size $S_i$ and regularization $\Phi_i = \{\phi_i^k\}$, and the last stage M uses the target image size $S_e$ and regularization $\Phi_e$. Three typical types of regularization are: Dropout [11], a network-level regularization that reduces co-adaptation by randomly dropping channels, controlled by the dropout rate γ; RandAugment [12], controlled by the magnitude ε of the data augmentation; and Mixup [13], a cross-image data augmentation. In this work, we use Dropout.

2.4 Evaluation Metrics
Our experiments were evaluated using the average 3D error. This metric computes the average 3D distance between the predicted and the ground-truth pose of the assembly tool, together with the position error (mm) and the orientation error (degrees) at the tip of the tool's 3D model. The tool error is defined as

$$\mathrm{TooLLoss} = \frac{1}{m} \sum_{x \in M} \left\lVert (Rx + t) - (\tilde{R}x + \tilde{t}) \right\rVert, \qquad (7)$$
where m is the number of points in the 3D model set, $R$ and $\tilde{R}$ are the ground-truth and predicted rotation matrices, and $t$ and $\tilde{t}$ are the ground-truth and predicted translation vectors, respectively. The points x are the 3D model points of the rigid surgical drill model set M. In addition, we evaluated our 3D hand vertex predictions by computing the mean end-point error over the 21 predicted joints [26, 27]. Our hand vertex error is defined as

$$\mathrm{HandLoss} = \frac{1}{m} \sum_{x \in M} \left\lVert h_x - \tilde{h}_x \right\rVert, \qquad (8)$$
where m is the number of points in the 3D set, and $h$ and $\tilde{h}$ are the ground-truth and predicted hand vertex vectors, respectively. To evaluate the runtime performance of our network, we report the total number of parameters of each model, the model size, the FLOPs for an input tensor of size (1, 3, 256, 256), the training time, and the inference latency, measured with a batch size of 1 on a single NVIDIA RTX 3090 GPU over 1000 test samples.
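The pose recovery and error formulas above can be made concrete with a small pure-Python sketch. It converts an axis-angle vector to a rotation matrix with Rodrigues' formula (the standard conversion, assumed here since the text does not name one), recovers the translation from the predicted center and depth as in Eqs. (2) and (3), and evaluates Eqs. (1), (5), (7), and (8). All function names and the example numbers are hypothetical; plain lists replace tensors, and the batched smooth-L1 losses of Eqs. (4) and (6) are only represented by the scalar `smooth_l1` helper.

```python
import math


def axis_angle_to_matrix(r):
    """Rodrigues' formula: axis-angle vector r (3 floats) -> 3x3 rotation matrix."""
    theta = math.sqrt(sum(comp * comp for comp in r))
    if theta < 1e-12:  # (near-)zero rotation
        return [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    kx, ky, kz = (comp / theta for comp in r)  # unit rotation axis
    c, s, v = math.cos(theta), math.sin(theta), 1.0 - math.cos(theta)
    return [
        [c + kx * kx * v, kx * ky * v - kz * s, kx * kz * v + ky * s],
        [ky * kx * v + kz * s, c + ky * ky * v, ky * kz * v - kx * s],
        [kz * kx * v - ky * s, kz * ky * v + kx * s, c + kz * kz * v],
    ]


def transform(R, t, x):
    """Apply the rigid transform R x + t to a 3D point x."""
    return [sum(R[i][j] * x[j] for j in range(3)) + t[i] for i in range(3)]


def translation_from_center(c, t_z, p, f):
    """Eqs. (2)-(3): recover t = (t_x, t_y, t_z) from the object center
    c = (c_x, c_y), the depth t_z, principal point p and focal lengths f."""
    t_x = (c[0] - p[0]) * t_z / f[0]
    t_y = (c[1] - p[1]) * t_z / f[1]
    return [t_x, t_y, t_z]


def smooth_l1(x):
    """Eq. (5) for a scalar residual x."""
    return 0.5 * x * x if abs(x) < 1.0 else abs(x) - 0.5


def rotation_loss(r_pred, r_gt, model_points):
    """Eq. (1): squared distances between model points rotated by the
    predicted and ground-truth axis-angle vectors, scaled by 1 / (2m)."""
    R_pred, R_gt = axis_angle_to_matrix(r_pred), axis_angle_to_matrix(r_gt)
    zero = [0.0, 0.0, 0.0]
    total = sum(math.dist(transform(R_pred, zero, x), transform(R_gt, zero, x)) ** 2
                for x in model_points)
    return total / (2 * len(model_points))


def tool_error(R, t, R_pred, t_pred, model_points):
    """Eq. (7): mean distance between model points under the two poses."""
    total = sum(math.dist(transform(R, t, x), transform(R_pred, t_pred, x))
                for x in model_points)
    return total / len(model_points)


def hand_error(h, h_pred):
    """Eq. (8): mean end-point error over corresponding 3D joints."""
    return sum(math.dist(a, b) for a, b in zip(h, h_pred)) / len(h)


# Hypothetical example: identity rotation, prediction 2 mm deeper along z.
R = axis_angle_to_matrix([0.0, 0.0, 0.0])
t = translation_from_center((128.0, 128.0), 500.0, (128.0, 128.0), (600.0, 600.0))
points = [[0.0, 0.0, 0.0], [10.0, 0.0, 0.0], [0.0, 10.0, 0.0]]
print(t)                                                # [0.0, 0.0, 500.0]
print(tool_error(R, t, R, [0.0, 0.0, 502.0], points))  # 2.0
```

A pure translation offset shifts every model point by the same vector, so the tool error in the example equals the size of that offset.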
3 Experiment and Result
3.1 Experimental Environment
The experiments were run on Ubuntu 20.04 with an Intel(R) Core(TM) i9-10900X CPU @ 3.70 GHz, 64 GB of RAM, and an NVIDIA RTX 3090 GPU with 24 GB of GDDR6X memory. The deep learning framework is PyTorch 1.8.0, the programming language is Python 3.8.0, and the CUDA version is 10.2.
3.2 Datasets
We used the synthetic and real datasets presented by Hein et al. [14] to evaluate the performance of the improved EfficientPose algorithm. Both datasets capture handheld tools for rigid tool tracking together with hand pose information during tool grasping. The images are annotated with the 6DoF pose of the tool in the tool coordinate frame, along with the 3D hand joint positions of the user grasping the tool.
1) Real image dataset
This dataset contains 3,746 real images of handheld tools at different angles and positions, each with a resolution of 256 × 256 pixels, as shown in Fig. 4(1). On this dataset, we start from the network weights that performed best on the synthetic dataset and train with a batch size of 32 for 500 epochs, using an SGD optimizer with a momentum of 0.9 and an initial learning rate of 0.001. A different optimizer is used on the synthetic dataset; on the real dataset, SGD with a low learning rate gives more stable convergence.
2) Synthetic image dataset
This dataset contains 10,500 synthetic images, each with a resolution of 256 × 256 pixels, as shown in Fig. 4(2); the images were generated with a data synthesis pipeline. For this dataset, the batch size was 32, the number of epochs was 500, the Adam optimizer was used, and the initial learning rate was 0.001.
Fig. 4. (1) Real image dataset; (2) synthetic image dataset.
3.3 Results and Analysis
Hein et al. [14] proposed several strategies for estimating surgical tool and hand poses from monocular RGB data; on the drilling datasets, their baselines are based on the PVNet [15] and HandObjectNet algorithms.
The PVNet algorithm addresses only 6DoF estimation and does not consider the joint interaction between the user's hand and the object. It estimates the object pose indirectly through a set of 2D keypoints corresponding to the center of the 3D bounding box and to selected locations on the 3D model, using a network architecture similar to U-Net [16] to estimate a segmentation mask and a 2D vector field for each keypoint. PVNet recovers the 2D keypoints by RANSAC-based voting on the vector fields and then recovers the 6DoF pose of the surgical tool with the PnP method, minimizing a Mahalanobis distance defined by the predicted keypoints and the mean and covariance of the keypoint hypotheses. The HandObjectNet algorithm not only estimates the 6DoF pose of the target object but also considers the articulated interaction between the hand and the tool. In HandObjectNet, the hand and the tool share the same ResNet-18 [17] encoder and decoder. The hand decoder estimates 18 pose parameters (15 coefficients describing the hand articulation and 3 parameters giving the hand's axis-angle rotation), as well as 10 shape parameters of the MANO hand model. The tool branch regresses an axis-angle rotation vector, a 2D translation vector, and a normalized focal depth offset, from which the 3D transformation is computed using the camera intrinsics. Table 2 shows that, compared with HandObjectNet and PVNet, our algorithm achieves lower TooLLoss and lower tool orientation (drill bit direction) error on the real image dataset. Our network also has 3.95 M parameters, compared with 12.49 M for HandObjectNet and 12.96 M for PVNet (about 3 times fewer), and its size is 16.3 MB, compared with 53.1 MB and 51.9 MB (about 3 times smaller).
In addition, our inference latency is lower: 16.85 ms, compared with 21.5 ms for HandObjectNet and 52.4 ms for PVNet (about 1.3 and 3.1 times lower, respectively).

Table 2. Performance comparison on the real image dataset.

Metric | Ours | HandObjectNet | PVNet
------ | ---- | ------------- | -----
TooLLoss (mm) | 11.85 ± 8.75 | 16.73 ± 16.97 | 20.59 ± 52.14
Drill Tip Error (mm) | 33.41 ± 35.64 | 44.45 ± 59.72 | 31.10 ± 67.18
Drill Bit Direction Error (deg) | 5.45 ± 6.52 | 6.59 ± 10.18 | 7.11 ± 21.78
HandLoss (mm) | 17.56 ± 6.83 | 17.15 ± 10.58 | –
Network Parameters (M) | 3.95 | 12.49 | 12.96
Network Size (MB) | 16.3 | 53.1 | 51.9
FLOPs (B) | 1.35 | 4.76 | 30.96
Latency (ms) | 16.85 ± 3.2 | 21.5 ± 3.3 | 52.4 ± 8.2
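The "about 3 times" comparisons can be checked directly against Table 2. This short Python calculation only re-derives ratios from the published values by dividing each baseline's figure by ours; the dictionary layout and names are illustrative.

```python
# Values transcribed from Table 2 (real image dataset).
TABLE2 = {
    "params_M": {"Ours": 3.95, "HandObjectNet": 12.49, "PVNet": 12.96},
    "size_MB": {"Ours": 16.3, "HandObjectNet": 53.1, "PVNet": 51.9},
    "flops_B": {"Ours": 1.35, "HandObjectNet": 4.76, "PVNet": 30.96},
    "latency_ms": {"Ours": 16.85, "HandObjectNet": 21.5, "PVNet": 52.4},
}


def ratio_vs_ours(row):
    """How many times larger each baseline's value is than ours."""
    return {name: round(value / row["Ours"], 2)
            for name, value in row.items() if name != "Ours"}


for metric, row in TABLE2.items():
    print(metric, ratio_vs_ours(row))
```

The ratios come out at roughly 3.2 to 3.3 for parameters and model size, 3.53 and 22.93 for FLOPs, and 1.28 and 3.11 for latency (HandObjectNet and PVNet, respectively), so the threefold latency gain holds only against PVNet.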
Table 3 shows that, on the synthetic image dataset, our algorithm achieves a TooLLoss of 10.32 mm, which is 3.46 mm and 29.40 mm lower than the 13.78 mm and 39.72 mm of HandObjectNet and PVNet, respectively. The tool tip error and tool orientation error also improve markedly: the tip error is reduced by 39.25 mm and 45.94 mm, and the orientation error by 5.10° and 9.80°, respectively.

Table 3. Performance comparison on the synthetic image dataset.

Metric | Ours | HandObjectNet | PVNet
------ | ---- | ------------- | -----
TooLLoss (mm) | 10.32 ± 8.32 | 13.78 ± 5.28 | 39.72 ± 66.49
Drill Tip Error (mm) | 26.86 ± 20.65 | 66.11 ± 26.91 | 72.80 ± 105.66
Drill Bit Direction Error (deg) | 3.61 ± 2.88 | 8.71 ± 3.98 | 13.41 ± 33.78
HandLoss (mm) | 16.98 ± 9.57 | 9.78 ± 4.54 | –
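The reductions quoted for the synthetic dataset follow from the mean values in Table 3. The small Python calculation below recomputes them (baseline minus ours) as a consistency check; it uses only the published means and ignores the standard deviations.

```python
# Mean values transcribed from Table 3 (synthetic image dataset).
OURS = {"TooLLoss_mm": 10.32, "tip_mm": 26.86, "direction_deg": 3.61}
BASELINES = {
    "HandObjectNet": {"TooLLoss_mm": 13.78, "tip_mm": 66.11, "direction_deg": 8.71},
    "PVNet": {"TooLLoss_mm": 39.72, "tip_mm": 72.80, "direction_deg": 13.41},
}


def reductions(ours, baseline):
    """Absolute reduction (baseline minus ours) for each error metric."""
    return {k: round(baseline[k] - ours[k], 2) for k in ours}


for name, values in BASELINES.items():
    print(name, reductions(OURS, values))
```

Note that the orientation reductions are angles (degrees), matching the "Drill Bit Direction Error (deg)" row.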
Table 4 compares training times with and without the progressive learning strategy. On the RTX 3090 GPU, training with progressive learning took about 14.5 h on the real dataset and about 47.7 h on the synthetic dataset, compared with 16.2 h and 54.3 h without it; progressive learning thus reduced the training time by about 10% on the real dataset and about 12% on the synthetic dataset.

Table 4. Training time with and without progressive learning.

Dataset | Non-progressive learning (h) | Progressive learning (h)
------- | ---------------------------- | ------------------------
Real image dataset | 16.2 | 14.5
Synthetic image dataset | 54.3 | 47.7
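The staged schedule of Sect. 2.3 (image size $S_i$ and regularization $\Phi_i$ growing towards $S_e$ and $\Phi_e$ over M stages) can be sketched as follows. This is a minimal illustration, not the authors' training code: it assumes linear interpolation between start and target values, as in the EfficientNetV2 progressive learning algorithm, uses dropout as the only regularizer (the choice made in Sect. 2.3), and all start values, stage counts, and step counts in the example are hypothetical, since the paper does not report them.

```python
def progressive_schedule(total_steps, num_stages, size_start, size_end,
                         dropout_start, dropout_end):
    """Return one (steps, image_size, dropout_rate) tuple per stage i = 1..M.

    Image size and dropout rate are interpolated linearly from hypothetical
    start values (stage 1) to the targets S_e and gamma_e (stage M).
    """
    if num_stages < 2:
        raise ValueError("progressive learning needs at least two stages")
    base_steps = total_steps // num_stages
    schedule = []
    for i in range(1, num_stages + 1):
        frac = (i - 1) / (num_stages - 1)  # 0 at stage 1, 1 at stage M
        size = round(size_start + (size_end - size_start) * frac)
        dropout = round(dropout_start + (dropout_end - dropout_start) * frac, 3)
        # The last stage absorbs any steps left over by integer division.
        if i < num_stages:
            steps = base_steps
        else:
            steps = total_steps - base_steps * (num_stages - 1)
        schedule.append((steps, size, dropout))
    return schedule


# Hypothetical settings: 4 stages, 128 -> 256 px, dropout 0.1 -> 0.3.
for stage in progressive_schedule(40000, 4, 128, 256, 0.1, 0.3):
    print(stage)
```

Early stages train on small images with light dropout, which is where the training-time savings in Table 4 would come from; the final stage matches the 256 × 256 target resolution used at test time.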
4 Conclusion
This work estimates the 6DoF pose of the assembly tool together with the 3D pose of the user's hand from monocular RGB data by modeling the interaction between the user's hand and the assembly tool. Our key innovation is to incorporate a hand pose estimation subnetwork into the EfficientPose algorithm and to modify its backbone network structure, which yields state-of-the-art performance in 6DoF pose estimation of hands and assembly tools during manual assembly, together with faster network inference. With these modifications and additions, the algorithm better solves the problem of pose estimation for hands and tools during manual assembly, and our method can be applied to manual assembly guidance, assembly posture specification, and similar tasks.
Acknowledgement. This work was supported by the National Natural Science Foundation of China (Nos. 62072002 and 62172004), and Special Fund for Anhui Agriculture Research System.
References
1. Tan, M., Le, Q.: EfficientNet: rethinking model scaling for convolutional neural networks. In: International Conference on Machine Learning, pp. 6105–6114. PMLR (2019)
2. Tan, M., Pang, R., Le, Q.V.: EfficientDet: scalable and efficient object detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10781–10790 (2020)
3. Redmon, J., Divvala, S., Girshick, R., Farhadi, A.: You only look once: unified, real-time object detection. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 779–788 (2016)
4. Mahendran, S., Ali, H., Vidal, R.: 3D pose regression using convolutional neural networks. In: Proceedings of the IEEE International Conference on Computer Vision Workshops, pp. 2174–2182 (2017)
5. Xiang, Y., Schmidt, T., Narayanan, V., et al.: PoseCNN: a convolutional neural network for 6D object pose estimation in cluttered scenes. arXiv preprint arXiv:1711.00199 (2017)
6. Romero, J., Tzionas, D., Black, M.J.: Embodied hands: modeling and capturing hands and bodies together. ACM Trans. Graph. (2017)
7. Touvron, H., Vedaldi, A., Douze, M., et al.: Fixing the train-test resolution discrepancy. Adv. Neural. Inf. Process. Syst. 356, 32 (2019)
8. Tan, M., Le, Q.: EfficientNetV2: smaller models and faster training. In: International Conference on Machine Learning, pp. 10096–10106. PMLR (2021)
9. Hoffer, E., Weinstein, B., Hubara, I., et al.: Mix & match: training convnets with mixed image sizes for improved accuracy, speed and scale resiliency. arXiv preprint arXiv:1908.08986 (2019)
10. You, Y., Zhang, Z., Hsieh, C.J., et al.: ImageNet training in minutes. In: Proceedings of the 47th International Conference on Parallel Processing, pp. 1–10 (2018)
11. Srivastava, N., Hinton, G., Krizhevsky, A., et al.: Dropout: a simple way to prevent neural networks from overfitting. J. Mach. Learn. Res. 15(1), 1929–1958 (2014)
12. Cubuk, E.D., Zoph, B., Shlens, J., et al.: RandAugment: practical automated data augmentation with a reduced search space. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops, pp. 702–703 (2020)
13. Zhang, H., Cisse, M., Dauphin, Y.N., et al.: Mixup: beyond empirical risk minimization. arXiv preprint arXiv:1710.09412 (2017)
14. Hein, J., et al.: Towards markerless surgical tool and hand pose estimation. Int. J. Comput. Assist. Radiol. Surg. 16(5), 799–808 (2021)
15. Peng, S., Liu, Y., Huang, Q., et al.: PVNet: pixel-wise voting network for 6DoF pose estimation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4561–4570 (2019)
16. Weng, W., Zhu, X.: INet: convolutional networks for biomedical image segmentation. IEEE Access 9, 16591–16603 (2021)
17. He, K., Zhang, X., Ren, S., et al.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
Intelligent Computing in Communication Networks
Adaptive Probabilistic Broadcast in Ad Hoc Networks
Shuai Xiaoying1(B), Yin Yuxia2, and Zhang Bin1
1 College of Information Engineering, Taizhou University, Taizhou 225300, Jiangsu, China
2 Taizhou University, Taizhou, Jiangsu 225300, China
Abstract. Broadcasting is an essential information dissemination method in wireless networks. However, flooding may cause broadcast storms that produce severe contention, collisions, and delays. In this study, we propose an adaptive probabilistic broadcasting (APB) algorithm based on neighbor discovery (ND) that reduces collisions and achieves higher reliability with fewer retransmissions. In APB, an INode (a node that has received the broadcast packet) rebroadcasts the packet to its 1-hop UNodes (nodes that have not received it). Each INode first counts the UNodes among its 1-hop neighbors and the 2-hop INodes connected through its 1-hop UNodes, and then computes its rebroadcast probability from these statistics. APB uses the rebroadcast delay to determine the forwarding priority of INodes, adaptively adjusting the probability and delay to reduce duplicate transmissions while covering more UNodes. The adaptive algorithm thus reduces the number of retransmissions while preserving efficiency and reliability. Simulation results show that the number of collisions, the number of rebroadcasts, and the average end-to-end delay of APB are significantly lower than those of the compared algorithms, and that its reliability is higher.
Keywords: Ad hoc · Neighbor Discovery · Adaptive Probabilistic Broadcast
1 Introduction
With the rapid development of wireless networks, research on ad hoc networks, such as wireless sensor networks (WSNs), mobile ad hoc networks (MANETs), and vehicular ad hoc networks (VANETs) [1], has gained widespread attention. A MANET consists of a set of wireless nodes that communicate with one another without relying on any network infrastructure. Owing to their resource-constrained, distributed, and self-organizing characteristics, MANETs have many attractive applications in rescue operations, traffic accidents, and sudden disasters. Broadcasting is a simple and effective propagation method; Shahid Latif et al. [2] reviewed the existing broadcast schemes in VANETs. Flooding is the basic information dissemination method in MANETs, but too many broadcast packets increase the network load as well as contention, collisions, and delay. Collisions may trigger retransmissions and increase the end-to-end delay. However, reducing the number of duplicate packets may reduce message reachability, particularly in sparse networks. It is therefore important for intermediate nodes to decide whether to forward a broadcast.

In this paper, we propose an adaptive probabilistic broadcast (APB) algorithm that dynamically adjusts the rebroadcast probability and delay based on neighbour information. The main innovations of APB are as follows:
1) A method for adaptively setting the rebroadcast probability. An INode determines its forwarding probability from the number of UNodes among its 1-hop neighbours and the number of its 2-hop INodes that are connected through its 1-hop UNodes. A single rebroadcast thus reaches as many UNodes as possible, and collisions caused by simultaneous transmissions of 2-hop INodes are reduced. This effectively reduces the number of rebroadcasts and the end-to-end delay.
2) The adaptive probability decision improves reliability. APB is applicable to any connected network and can deliver the broadcast to all connected nodes in the network.
3) A method for setting the rebroadcast delay. An INode with more 1-hop UNodes has a higher priority to rebroadcast the packet.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 301–313, 2023.
https://doi.org/10.1007/978-981-99-4761-4_26
S. Xiaoying et al.
2 Related Research
Blind flooding is the simplest broadcast pattern [3]: each node forwards a received packet to all of its 1-hop neighbours. Blind flooding is clearly expensive, since redundant broadcasts cause packet collisions and serious contention. Probabilistic broadcast algorithms can be classified as fixed or adaptive [4]. A fixed probabilistic broadcast algorithm rebroadcasts each received packet to all of its 1-hop neighbours with the same probability p and discards it with probability 1 − p [5, 6]. An adaptive probabilistic broadcast algorithm optimizes the forwarding probability according to density, distance, number of neighbours, and similar factors to reduce the number of retransmissions [7–9]. Some schemes [10] require global information, such as the total number of nodes, the average degree, and the minimum degree, which makes them difficult to apply in real dynamic MANETs. Several broadcasting mechanisms based on neighbour knowledge have been developed [11, 12], and many researchers have designed probabilistic broadcasting algorithms based on neighbour knowledge. Neighbour-knowledge probabilistic broadcast exchanges 1-hop or 2-hop neighbour information to calculate the forwarding probability [13]. Lim et al. [14] proposed a broadcast scheme based on self-pruning, in which a received packet is rebroadcast or discarded after a random assessment delay (RAD). Dominant pruning (DP) extends self-pruning by using 2-hop neighbour information to reduce the number of rebroadcasts [15]. To overcome the deficiencies of DP, the total dominant pruning (TDP) and partial dominant pruning (PDP) algorithms were proposed [16]. Peng et al. [17] proposed a scalable broadcasting algorithm (SBA). NCPR (neighbour coverage-based probabilistic rebroadcast) uses an additional coverage ratio and a connectivity factor to compute the rebroadcast probability [18]; the connectivity factor is based on the total number of nodes and the number of 1-hop neighbours. NCPR employs a rebroadcast delay and probability to improve broadcasting performance in wireless networks. DCFR [19] uses a dynamic connectivity formula and a dynamic connectivity factor to optimize NCPR, improving the efficiency of packet delivery. NPB [20] adopts the opposite strategy to NCPR and DCFR when determining the rebroadcast delay. Three dynamic probabilistic rebroadcast schemes based on 1-hop and 2-hop neighbours have been proposed in [21]. Dingzhu Lu [22] proposed a neighbour-knowledge-based broadcast (NKB) for MANETs to reduce delay, and extended it to a velocity-aware variant (NKVB) for highly mobile nodes. These schemes require global information, such as the total number of nodes, the average degree, and the minimum degree, which makes them difficult to implement in real applications. Moreover, the published literature does not consider the number of INodes among the 1-hop neighbours of a UNode. When many INodes surround a UNode, each of them should retransmit with a lower probability; otherwise, multiple INodes send the packet, and duplicate packets and collisions occur.
3 Adaptive Probabilistic Broadcast
3.1 Network Model
In Fig. 1, the wireless network is depicted as an undirected graph. A node that has received a broadcast is called an INode; otherwise, it is called a UNode. Solid circles denote INodes and unfilled circles denote UNodes. Two nodes located within each other's communication range are called 1-hop neighbours and are connected by a line.
Fig. 1. Network Topology.
Node S is the source. Node S broadcasts a packet to its 1-hop neighbours A, B, and C, which become INodes after receiving it. If nodes A and C each rebroadcast once, the packet reaches all remaining nodes of the network. If node B also rebroadcasts, it adds duplicate packets, and its transmission may collide with that of node C. Node A must rebroadcast; otherwise, nodes E and D cannot receive the packet. Each INode therefore decides whether to forward based on its neighbours. Each node obtains neighbour information through a neighbour discovery procedure [23]: when a UNode enters the communication range of an INode, the two establish an adjacency relationship and exchange information via neighbour discovery (ND) [24]. ND is the foundation for network initialization, route determination, and data dissemination in MANETs [25], and it was developed to discover the 1-hop neighbours of a node [26]. Nodes exchange beacons to maintain network topology information [27]. Each node maintains
and manages its own neighbour list, which it updates whenever neighbour information is exchanged. Let $N_1(Y)$ denote the set of 1-hop nodes of node Y; the notation is summarized in Table 1.

Table 1. Notation

Notation | Description
-------- | -----------
$N_1(Y)$ | the set of 1-hop nodes of node Y
$N_2(Y)$ | the set of 2-hop nodes of node Y
$N_2(Y, K)$ | the set of 2-hop nodes of node Y that are connected through its 1-hop node K
$N_1^i(Y)$ | the set of 1-hop nodes of node Y that have received the broadcast packet
$N_1^u(Y)$ | the set of 1-hop nodes of node Y that have not received the broadcast packet
$N_2^i(Y, K)$ | the set of nodes in $N_2(Y, K)$ that have received the broadcast packet
$p_x$ | broadcast probability of node x
INode | a node that has received (been covered by) the broadcast packet
UNode | a node that has not received (not been covered by) the broadcast packet
$F_b$ | the broadcast packet
$t_x$ | rebroadcast delay of node x
$d$ | constant delay
$\lvert\cdot\rvert$ | the size of a set
3.2 Goal of APB
Because wireless transmission is inherently broadcast, every INode among a UNode's 1-hop neighbours may rebroadcast the packet to it. The more such INodes there are, the more likely redundant retransmissions become, which may cause collisions or broadcast storms. An INode with more 1-hop UNodes and fewer 2-hop INodes should therefore rebroadcast preferentially: a single rebroadcast then covers more UNodes, and the probability that INodes within each other's 2-hop neighbourhood rebroadcast simultaneously is reduced, which reduces collisions. In this way, the number of rebroadcasting nodes can be reduced to mitigate collisions and redundancy. Each INode decides whether to forward by using a selection strategy. The goals of the adaptive probabilistic broadcast algorithm are to deliver messages to as many nodes as possible through a single flood, to ensure that the packet reaches each node with very high probability, and to reduce collisions. The node broadcast rules are as follows.
1) For an INode Y, the larger $|N_1^u(Y)|$ is, the more likely node Y is to rebroadcast the broadcast packet.
2) For a UNode K, the smaller $|N_1^i(K)|$ is, the higher the probability that each node in $N_1^i(K)$ forwards the broadcast packet.
3) If one of an INode's 1-hop UNodes is connected only to that INode, the INode must forward the received broadcast packet.
Rule 1 gives an INode with more 1-hop UNodes a higher rebroadcast probability and a shorter delay, so that each rebroadcast reaches more new nodes. Rule 2 reduces the number of duplicate packets and avoids collisions caused by the simultaneous forwarding of several 1-hop INodes of the same UNode: the more INode neighbours a UNode has, the lower the probability that each of them transmits the packet. Rule 3 improves reachability by ensuring that every UNode receives the broadcast, even if it is connected to only one INode.

3.3 Broadcast Probability
To achieve these objectives, the broadcast algorithm uses the adaptive probability in Eq. (1). When calculating its transmission probability, INode X considers not only the number of UNodes among its 1-hop neighbours but also the number of INodes around each such UNode k:

$$p_X = \frac{|N_1^u(X)|}{|N_1^u(X)| + |N_2^i(X, k)|} \qquad (1)$$
As shown in Fig. 2, the broadcast can be forwarded to UNode $K$ by any of its INodes. The rebroadcast probability of the INodes in Fig. 2(I) should therefore be lower than that of the INodes in Fig. 2(II). In Fig. 2(I), $N_1^u(X) = \{K\}$, $N_2^i(X,K) = \{V, J\}$, $N_1^u(V) = \{K\}$, $N_2^i(V,K) = \{X, J\}$, $N_1^u(J) = \{K\}$, and $N_2^i(J,K) = \{V, X\}$. The probabilities calculated using (1) are:
$$p_X = p_V = p_J = \frac{1}{1+2} \approx 0.33 \tag{2}$$
Fig. 2. UNode with different INodes
As shown in Fig. 2(II), only two INodes exist around UNode $K$. According to (1), $p_X = p_J = \frac{1}{1+1} = 0.5$. The retransmission probability of each INode in Fig. 2(II) is thus higher than that of each INode in Fig. 2(I), which is consistent with Rule 2. The adaptive broadcast probability therefore reduces the likelihood of collisions caused by multiple INodes around the same UNode.
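Eq. (1) together with Rule 3 can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation. The function name `rebroadcast_probability` is hypothetical, and the text does not say how $|N_2^i(X,K)|$ is combined when $X$ has several UNodes with different counts; the sketch assumes the average over $K$, which reproduces every worked example in this section. Rule 3 is applied as an override that returns 1.

```python
from typing import Mapping


def rebroadcast_probability(others_per_unode: Mapping[str, int]) -> float:
    """Eq. (1) plus Rule 3 for one INode X (illustrative sketch).

    others_per_unode maps each UNode K in N_1^u(X) to |N_2^i(X, K)|,
    the number of INodes other than X that are adjacent to K.
    """
    n_u = len(others_per_unode)  # |N_1^u(X)|
    if n_u == 0:
        return 0.0  # no uncovered neighbour, so X never forwards
    if 0 in others_per_unode.values():
        return 1.0  # Rule 3: some UNode can be reached only through X
    # Assumption: with several UNodes, |N_2^i(X, K)| is averaged over K.
    mean_others = sum(others_per_unode.values()) / n_u
    return n_u / (n_u + mean_others)


print(round(rebroadcast_probability({"K": 2}), 2))  # Fig. 2(I): 0.33
print(rebroadcast_probability({"K": 1}))            # Fig. 2(II): 0.5
```

For node C of Fig. 1, any mapping that contains `"G": 0` returns 1.0, as Rule 3 requires; the count for F is not given in the text, so any value used for it is hypothetical.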
S. Xiaoying et al.
The network topology is shown in Fig. 3. Nodes D, E, and X are INodes, while the other nodes have not yet received the broadcast. According to Rule 1, the rebroadcast probability of node X should be higher than those of nodes D and E. The APB uses (1) to calculate the transmission probability of each node as follows: $N_1^u(X) = \{A, B, C\}$, $N_1^u(D) = \{A\}$, $N_1^u(E) = \{B, C\}$; $N_2^i(X,A) = \{D\}$, $N_2^i(X,B) = \{E\}$, $N_2^i(X,C) = \{E\}$; $N_2^i(E,B) = \{X\}$, $N_2^i(E,C) = \{X\}$, $N_2^i(D,A) = \{X\}$. The resulting probabilities are $p_X = 0.75$, $p_D = 0.5$, and $p_E = 0.67$, which is consistent with Rule 1. As shown in Fig. 1, UNode G is connected only to node C, so C must retransmit the broadcast for G to receive it. Here $N_1^u(C) = \{F, G\}$ and $N_2^i(C,G) = \emptyset$, so the retransmission probability of INode C is 1, which is consistent with Rule 3.
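The sets used above can be derived mechanically from a node's neighbour list. The following sketch is an illustration under stated assumptions: it encodes only the links named in the text for Fig. 3 (the figure may contain others), takes $N_2^i(X,K)$ to be the covered neighbours of $K$ other than $X$, and reuses the assumption that $|N_2^i(X,K)|$ is averaged when $X$ has several UNodes. The helper names `neighbour_sets` and `probability` are hypothetical.

```python
from typing import Dict, Set, Tuple

Graph = Dict[str, Set[str]]


def neighbour_sets(adj: Graph, covered: Set[str], x: str) -> Tuple[Set[str], Dict[str, Set[str]]]:
    """Return N_1^u(x) and, for each K in it, N_2^i(x, K)."""
    n1u = {k for k in adj[x] if k not in covered}
    n2i = {k: {y for y in adj[k] if y in covered and y != x} for k in n1u}
    return n1u, n2i


def probability(adj: Graph, covered: Set[str], x: str) -> float:
    """Eq. (1) with Rule 3, evaluated on sets derived from the topology."""
    n1u, n2i = neighbour_sets(adj, covered, x)
    if not n1u:
        return 0.0
    if any(not others for others in n2i.values()):
        return 1.0  # Rule 3: a UNode depends on x alone
    # Assumption: average |N_2^i(x, K)| over the UNodes K.
    mean_others = sum(len(o) for o in n2i.values()) / len(n1u)
    return len(n1u) / (len(n1u) + mean_others)


# Links named in the text for Fig. 3; INodes are D, E, and X.
EDGES = [("X", "A"), ("X", "B"), ("X", "C"), ("D", "A"), ("E", "B"), ("E", "C")]
ADJ: Graph = {}
for a, b in EDGES:
    ADJ.setdefault(a, set()).add(b)
    ADJ.setdefault(b, set()).add(a)
COVERED = {"D", "E", "X"}

for node in ("X", "D", "E"):
    print(node, round(probability(ADJ, COVERED, node), 2))
# X 0.75
# D 0.5
# E 0.67
```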
Fig. 3. INode with different UNodes
3.4 Reliability Analysis
The broadcast must propagate from the source to all other nodes that have not yet received it. Reliability refers to the probability that a UNode receives the broadcast.
Fig. 4. Network topology
In Fig. 4, the nodes covered by $F_b$ are labelled $I_1, I_2, \ldots, I_n$, and the uncovered nodes are labelled $U_1, U_2, \ldots, U_m$, where $n = |N_1^i(U_j)|$. A UNode may receive $F_b$ from any INode among its 1-hop neighbours; for example, $U_2$ may receive copies of $F_b$ retransmitted by $I_1, I_2, \ldots, I_n$. The probability of $U_2$ receiving $F_b$ is given by (3):
$$\begin{aligned}
\sum_{x=1}^{n} p_x &= \sum_{x=1}^{n} \frac{|N_1^u(I_x)|}{|N_1^u(I_x)| + |N_2^i(I_x,U_2)|} \\
&= \sum_{x=1}^{n} \frac{|N_1^u(I_x)| + |N_2^i(I_x,U_2)| - |N_2^i(I_x,U_2)|}{|N_1^u(I_x)| + |N_2^i(I_x,U_2)|} \\
&= \sum_{x=1}^{n} \left(1 - \frac{|N_2^i(I_x,U_2)|}{|N_1^u(I_x)| + |N_2^i(I_x,U_2)|}\right)
\end{aligned} \tag{3}$$
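For every UNode $U_j$ ($j = 1, \ldots, m$), all sets $N_2^i(I_x, U_j)$ with $I_x \in N_1^i(U_j)$ have the same size, as in (4): each is $N_1^i(U_j)$ with $I_x$ removed, so $|N_2^i(I_x, U_j)| = n - 1$.
$$|N_2^i(I_1,U_2)| = |N_2^i(I_2,U_2)| = \cdots = |N_2^i(I_n,U_2)| \tag{4}$$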
The probability increases with $|N_1^u(I_x)|$. When $|N_1^u(I_x)| = 0$, the probability is zero and the INode does not send $F_b$. When $|N_1^u(I_x)| = 1$ for every $I_x \in N_1^i(U_j)$, the probability of $U_j$ receiving $F_b$ is given by (5):
$$\sum_{x=1}^{n} \frac{1}{1 + |N_2^i(I_x,U_j)|} = \sum_{x=1}^{n} \frac{1}{1 + n - 1} = 1, \qquad n = |N_1^i(U_j)| \tag{5}$$
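The sum in (3) can be evaluated exactly with rational arithmetic. The sketch below is illustrative only (the function name `reception_sum` is hypothetical): it computes the left-hand side of (3) from one $(|N_1^u(I_x)|, |N_2^i(I_x,U)|)$ pair per INode and confirms that the equal-neighbour case of (5) sums to exactly 1 for several values of $n$, while larger $|N_1^u(I_x)|$ only increases the sum.

```python
from fractions import Fraction
from typing import Iterable, Tuple


def reception_sum(inodes: Iterable[Tuple[int, int]]) -> Fraction:
    """Left-hand side of Eq. (3) for one UNode U.

    inodes: one (|N_1^u(I_x)|, |N_2^i(I_x, U)|) pair per INode I_x adjacent to U.
    """
    total = Fraction(0)
    for n1u, n2i in inodes:
        if n1u == 0:
            continue  # p_x = 0: I_x has no uncovered neighbours
        total += Fraction(n1u, n1u + n2i)
    return total


# Eq. (5): n INodes, each with one UNode and n - 1 other INodes around it.
for n in range(1, 6):
    print(n, reception_sum([(1, n - 1)] * n))  # always 1
```

For three INodes that each have two uncovered neighbours and two rival INodes, `reception_sum([(2, 2)] * 3)` returns `Fraction(3, 2)`.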
Therefore, the probability of each UNode receiving $F_b$, as given by (3), is greater than or equal to 1.
3.5 Rebroadcast Delay
To fully discover neighbours and improve coverage, some algorithms use a rebroadcast delay to determine the forwarding priority of INodes. In the NCPR and DCFR algorithms, the more common neighbours a receiving node shares with the transmitting node, the shorter its delay before rebroadcasting $F_b$. The NPB adopts the opposite strategy: a receiving node that shares more common neighbours with the transmitting node has a longer rebroadcast delay. In NPB, the transmission delay of INode X depends on the ratio of the number of common neighbours to the average number of neighbours of X. As a result, one node may have more uncovered neighbours than another and yet not have a shorter delay. This study also uses a rebroadcast delay to determine the forwarding priority of nodes. The delay algorithm, based on NPB and NCPR, is given by (6):
$$t_x(s) = \begin{cases} d \cdot \left(1 - \dfrac{|N_1^u(x)|}{|N_1(s)|}\right), & 1 \le |N_1^u(x)| < |N_1(s)| \\ 0, & |N_1^u(x)| \ge |N_1(s)| \\ \text{cancel}, & |N_1^u(x)| = 0 \end{cases} \tag{6}$$
The algorithm assigns a shorter delay to node X when $|N_1^u(X)|$ is larger, and vice versa. If $|N_1^u(X)|$ is the largest among all nodes receiving the broadcast transmitted by INode S, the transmission delay of X is the smallest.
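Eq. (6) maps directly to a small function. In this sketch (the function name is hypothetical), $d$ is treated as the maximum rebroadcast delay, which is an assumption since this section does not define it, and the "cancel" case is returned as `None`.

```python
from typing import Optional


def rebroadcast_delay(n_uncovered: int, n_sender_neighbours: int, d: float) -> Optional[float]:
    """Eq. (6): t_x(s) for a node x that received F_b from sender s.

    n_uncovered: |N_1^u(x)|; n_sender_neighbours: |N_1(s)|;
    d: scaling constant of Eq. (6), assumed here to be the maximum delay.
    Returns None for the "cancel" case.
    """
    if n_uncovered == 0:
        return None  # nothing left to cover: cancel the rebroadcast
    if n_uncovered >= n_sender_neighbours:
        return 0.0  # x covers at least as many nodes as s has neighbours
    return d * (1 - n_uncovered / n_sender_neighbours)


print(rebroadcast_delay(3, 5, 1.0), rebroadcast_delay(1, 5, 1.0))
```

Among the receivers of one transmission from S, the node with the most uncovered neighbours gets the smallest delay and therefore fires first, which is the priority order described above.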
Because of the dynamic network topology and the nature of broadcasting, node X may receive duplicate copies of $F_b$ from different broadcasting nodes before it rebroadcasts, so $t_x(s)$ changes over time. The delay is therefore updated as in (7):
$$t_x = \alpha t_x + (1 - \alpha) t_x(s) \tag{7}$$
where $\alpha \in [0, 1]$.
3.6 Algorithm
Each node runs a neighbour discovery program, through which it obtains neighbour information and stores it in its neighbour list. When node X receives packet $F_b$ sent by node S, it updates its neighbour list and, from that list, calculates $N_1^u(X)$, $N_1^i(X)$, $N_2^i(X,K)$, and $N_1(S)$. If X has at least one uncovered neighbour, it calculates the delay; otherwise, it discards the packet. If X receives a duplicate packet during the delay, it updates its neighbour list and then discards the duplicate. When $t_x$ expires, INode X calculates its rebroadcast probability and sends the packet with probability $p_X$.
Algorithm 1 APB
if node X receives F_b from node S then
    update the neighbour list by neighbour discovery
    count N_1^u(X), N_1^i(X), N_2^i(X, K), and N_1(S)
    if |N_1^u(X)| == 0 then
        return
    end if
    compute t_x(s) by (6)
    t_x = t_x(s)
    while not expired(t_x) do
        if X receives a duplicate F_b then
            update the neighbour list
            update N_1^u(X), N_1^i(X), N_2^i(X, K), and N_1(S)
            compute t_x by (7)
            discard the duplicate F_b
        end if
    end while
    compute the probability p_X by (1)
    if random(0, 1) ≤ p_X then
        broadcast F_b
    else
        discard F_b
    end if
end if
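A single node's path through Algorithm 1 can be sketched in Python as below. This is a simplified, hypothetical rendering rather than the authors' simulator code: the function names are invented, duplicate arrival times are assumed to be measured from the first reception, each duplicate is assumed to arrive with an already recomputed $t_x(s)$, and $p_X$ is passed in as the value of (1) when the timer expires.

```python
import random
from typing import Iterable, Optional, Tuple


def smoothed_delay(t_x: float, t_x_s: float, alpha: float) -> float:
    """Eq. (7): blend the current delay with a freshly computed t_x(s)."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return alpha * t_x + (1.0 - alpha) * t_x_s


def apb_decide(
    initial_delay: Optional[float],
    duplicates: Iterable[Tuple[float, float]],
    p_x: float,
    alpha: float,
    rng: random.Random,
) -> Tuple[bool, Optional[float]]:
    """Simplified single-node run of Algorithm 1 (hypothetical helper).

    initial_delay: t_x(s) from Eq. (6) at first reception, or None for "cancel".
    duplicates: (arrival_time, recomputed t_x(s)) pairs, times measured
        from the first reception (an assumption of this sketch).
    p_x: value of Eq. (1) on the neighbour list when the timer expires.
    Returns (whether F_b is rebroadcast, final delay t_x).
    """
    if initial_delay is None:
        return False, None  # no uncovered neighbours: discard F_b
    t_x = initial_delay
    for arrival, new_t_x_s in sorted(duplicates):
        if arrival >= t_x:
            break  # timer already expired; later copies arrive too late
        t_x = smoothed_delay(t_x, new_t_x_s, alpha)  # update, then discard duplicate
    rebroadcast = rng.random() <= p_x  # "if random(0,1) <= p_X then broadcast"
    return rebroadcast, t_x


sent, delay = apb_decide(0.4, [(0.1, 0.2), (0.35, 0.0)], p_x=1.0, alpha=0.5,
                         rng=random.Random(0))
print(sent, round(delay, 3))  # True 0.3
```

In the usage line, the first duplicate arrives before the 0.4 s timer and pulls the delay down to 0.3 s; the second arrives after the shortened timer has expired and is ignored.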
4 Simulation
In the simulation, different connection probabilities are selected to generate various connected networks. All nodes have the same transmission radius and data transfer rate. Each algorithm is simulated 1000 times. In each simulation, a network is first generated randomly in OMNeT++ with a random connection probability, and a node is then randomly selected as the source node. A message propagates from the source node to its neighbouring nodes until all nodes receive it or a specified time is reached. Figure 5 shows that as the number of nodes in the network increases, the number of APB broadcasts does not exceed that of the other algorithms, and the number of packets rebroadcast by APB does not change significantly. End-to-end delay is the time taken for a broadcast packet to be successfully delivered from the source to a destination node. The delay with which the broadcast packet reaches each node is recorded, and the sum of these delays is calculated at the end of the simulation. Each network is simulated several times to obtain the average delay. Figure 6 shows the average delay of the different algorithms: the average delay of APB is lower than that of the other three algorithms.
Fig. 5. Comparison of broadcast times in different networks
Fig. 6. Average delay of different algorithms
Figure 7 shows the average number of collisions of each algorithm as a function of the number of nodes. The simulation results show that APB causes fewer collisions than the other algorithms, and its collisions increase relatively slowly as the number of nodes grows. Reliability describes the number of nodes that receive the broadcast in a wireless network. A connected network of 20 nodes was randomly generated each time, a node was randomly selected as the source, and the simulation time was set to 1 s, 2 s, 3 s, and 4 s. After several simulations, the average number of nodes reached by the broadcast packets was calculated. APB delivers broadcast packets to 16.5 nodes in 1 s and to 19 nodes in 2 s. The results show that APB covers more nodes, and covers them faster, than the other algorithms. Figure 8 shows the average proportion of nodes that receive the broadcast under the different algorithms in different networks; the coverage rate of APB is higher than that of the other algorithms. In the previous simulations, the networks were randomly generated. Under p-Flood, NPB, and NCPR, some INodes occasionally do not rebroadcast, so some nodes cannot receive the broadcast. A spindle network, shown in Fig. 9, was therefore designed to check reliability, because the broadcast probabilities of p-Flood, NCPR, and NPB may be less than 1. Node 0 is the source and broadcasts a packet. Under p-Flood, the rebroadcast probability of node 1 increases with $p$. When the network uses NCPR or NPB and the broadcast reaches node 8, the algorithm calculates a forwarding probability $p_8 < 1$, so node 8 may not rebroadcast, in which case node 9 does not receive the broadcast. In simulations on both the randomly generated networks and the network in Fig. 9, APB delivers broadcast packets to all nodes.
Fig. 7. Average number of collisions
Fig. 8. Average coverage rate of different algorithms
Fig. 9. Spindle network
5 Conclusion
To reduce the number of broadcast packets while achieving high reachability in ad hoc networks, an adaptive probability broadcast algorithm is proposed. The algorithm dynamically discovers the numbers of INodes and UNodes among each node's neighbours and adaptively adjusts the probability of forwarding packets. If a UNode has many INode neighbours, each of those INodes lowers its transmission probability to reduce collisions and redundant duplicates. If an INode has more UNode neighbours, it is more likely to rebroadcast the packet, so although fewer copies are transmitted, each copy can be received by several UNodes. The algorithm does not use the global topology of the network and exchanges only 1-hop neighbour information. It adapts dynamically to changes in network topology and adjusts the transmission probability accordingly. The algorithm requires only simple calculations and little data exchange. It not only reduces the number of broadcasts and the delay but also ensures reachability.
References
1. Faping, W.: Review on research and applications of V2X key technologies. Chinese J. Automot. Eng. 1(1), 1–12 (2020)
2. Juliet, A., Vijayakumar, S., Joan, P.R.: A comparative study of broadcasting protocols in VANET. Veh. Commun. 13(22), 1–22 (2018)
3. Tseng, Y.C., Ni, S.Y., Chen, Y.S., et al.: The broadcast storm problem in a mobile ad hoc network. Wireless Netw. 8(2), 153–167 (2002)
4. Reina, D.G.: A survey on probabilistic broadcast schemes for wireless ad hoc networks. Ad Hoc Netw. 25, 263–292 (2015)
5. Pallai, G.K., Sankaran, M., Rath, A.K.: Self-pruning based probabilistic approach to minimize redundancy overhead for performance improvement in MANET. Int. J. Comput. Netw. Commun. 13(2), 15–36 (2021)
6. Haas, Z.J., Halpern, J.Y., Li, L.: Gossip-based ad hoc routing. IEEE/ACM Trans. Netw. 14(3), 479–491 (2006)
7. Shuai, X.Y., Zhang, B., Yin, Y.X.: Adaptive probabilistic broadcasting for floating content. J. Phys.: Conf. Ser. 1944(1), 1–6 (2021)
8. Ennaciri, A., Erritali, M., Cherkaoui, B., Sailhan, F.: Optimal broadcasting algorithm for VANET system. In: Bouzefrane, S., Laurent, M., Boumerdassi, S., Renault, E. (eds.) MSPN 2020. LNCS, vol. 12605, pp. 209–222. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-67550-9_14
9. Cartigny, J., Simplot, D.: Border node retransmission based probabilistic broadcast protocols in ad-hoc networks. Telecommun. Syst. 22(1), 189–204 (2003)
10. Lysiuk, I.S., Haas, Z.J.: Controlled gossiping in Ad hoc networks. In: Proceedings of the IEEE Wireless Communications and Networking Conference (WCNC) (2010)
11. Kwon, T.J., Gerla, M.: Efficient flooding with passive clustering (PC) in ad hoc networks. ACM SIGCOMM Comput. Commun. Rev. 32(1), 44–56 (2002)
12. Borgonovo, F., Capone, A., Cesana, M., et al.: ADHOC MAC: New MAC architecture for Ad Hoc networks providing efficient and reliable point-to-point and broadcast services. Wireless Netw. 10(4), 359–366 (2004)
13. Ruiz, P., Bouvry, P.: Survey on broadcast algorithms for mobile ad hoc networks. ACM Comput. Surv. 48(1), 1–35 (2015)
14. Lim, H.: Multicast tree construction and flooding in wireless ad hoc networks. In: Proceedings of the ACM International Workshop on Modeling, Analysis and Simulation of Wireless and Mobile Systems, Boston, MA (2000)
15. Lim, H., Kim, C.: Flooding in wireless ad hoc networks. Comput. Commun. 24(3), 353–363 (2001)
16. Lou, W., Wu, J.: On reducing broadcast redundancy in ad hoc wireless networks. IEEE Trans. Mob. Comput. 1(2), 111–122 (2002)
17. Wei, P., Lu, X.C.: On the reduction of broadcast redundancy in mobile ad hoc networks. In: Proceedings of the Workshop on Mobile & Ad Hoc Networking & Computing. IEEE (2000)
18. Zhang, X., Wang, E., Xia, J., et al.: A neighbor coverage-based probabilistic rebroadcast for reducing routing overhead in mobile ad hoc networks. IEEE Trans. Mob. Comput. 12(3), 424–433 (2013)
19. Ejmaa, A., Subramaniam, S., Zukarnain, Z.A., et al.: Neighbor-based dynamic connectivity factor routing protocol for mobile Ad hoc network. IEEE Access 4, 8053–8064 (2017)
20. Liu, W.: A neighbor-based probabilistic broadcast protocol for data dissemination in mobile IoT networks. IEEE Access 6, 12260–12268 (2018)
21. Ryu, J.P.: An adaptive probabilistic broadcast scheme for ad-hoc networks. In: Proceedings of the IEEE International Conference on High Speed Networks and Multimedia Communications (HSNMC), Toulouse, France, June 30–July. IEEE (2004)
22. Lu, D., Dong, S.: A neighbor knowledge and velocity based broadcast scheme for wireless ad hoc networks. Int. J. Distrib. Sens. Netw. 13(11), 1–14 (2017)
23. Yamamoto, R., Kashima, A., Yamazaki, T., et al.: Adaptive contents dissemination method for floating contents. In: Proceedings of the 90th IEEE Vehicular Technology Conference. IEEE Press, Piscataway, NJ (2019)
24. Sun, W., Yang, Z., Zhang, X.: Energy-efficient neighbor discovery in mobile ad hoc and wireless sensor networks: a survey. IEEE Commun. Surv. Tutorials 16(3), 1448–1459 (2014)
25. Morillo, R.: More the merrier: neighbor discovery on duty-cycled mobile devices in group settings. IEEE Trans. Wireless Commun. 21(7), 4754–4768 (2022)
26. Sun, G., Wu, F., Gao, X., Chen, G., Wang, W.: Time-efficient protocols for neighbor discovery in wireless Ad hoc networks. IEEE Trans. Veh. Technol. 62(6), 2780–2791 (2013)
27. Turcanu, I., Kim, M., Klingler, F.: Towards 2-hop neighbor management for heterogeneous vehicular networks. In: Proceedings of the 12th IEEE Vehicular Networking Conference, pp. 1–2. IEEE Press, NJ (2020)
A Light-Weighted Model of GRU + CNN Hybrid for Network Intrusion Detection
Dong Yang(✉), Can Zhou, and Songjie Wei
School of Computer Science and Engineering, Nanjing University of Science and Technology, Nanjing 210094, China
{12013010,120106222759,swei}@njust.edu.cn
Abstract. Network traffic is typically high-dimensional, polymorphic, and massive in volume, which poses a persistent challenge for pattern-based intrusion detection. Most detection models suffer from low efficiency and give little consideration to portability. We propose a light-weighted network intrusion detection model combining GRU and CNN to balance model complexity and performance. First, we prune redundant features from the dataset using extremely randomized trees. Then, a GRU performs feature extraction that accounts for long-term and short-term dependencies in the data, and all of its hidden-layer outputs are treated as the feature sequence for the next step. We construct a CNN with inverted residuals, depthwise separable convolution, and dilated convolution for spatial feature extraction, and accelerate model convergence with a channel attention mechanism. Experiments on the CIC-IDS2017 dataset verify that the proposed model achieves excellent detection performance while remaining simple: it has fewer parameters, a smaller model size, shorter training time, and faster detection.
Keywords: Network Intrusion Detection · Gated Recurrent Unit · Convolutional Neural Network · Light-weighted Detection Model · Extremely Randomized Tree
1 Introduction
While the rapid development of the global Internet has brought convenient services to people, the Internet also faces a variety of security threats, especially high-volume malicious traffic. For example, a massive DDoS attack crippled the U.S. Internet in 2016. The frequency of malware attacks worldwide has brought cybersecurity to the forefront of attention. To mitigate or avoid the damage that network attacks cause to devices, various intrusion detection systems (IDS) have been developed and deployed to monitor network data flows. In recent years, deep learning techniques have been applied in a variety of fields, such as speech recognition, image recognition, and text translation. Deep neural networks have particular advantages in analyzing and processing high-dimensional, large-volume data: they can extract feature information from raw data through a series of non-linear transformations, which makes them well suited to classifying malicious traffic in intrusion detection. Consequently, more and more deep learning methods have been designed and applied to network traffic intrusion detection [1].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 314–326, 2023. https://doi.org/10.1007/978-981-99-4761-4_27
A Light-Weighted Model of GRU + CNN Hybrid
Researchers have invested considerable effort in the detection performance of intrusion detection systems, but have given less consideration to the applicability and portability of these models and have largely ignored their number of parameters, model size, and training time. Because of resource limitations in network devices, complex intrusion detection models may not be successfully deployed and executed on network edge devices, which are vulnerable to attack [2]. Balancing detection performance against model lightness is therefore an important problem in intrusion detection research. This paper embeds a Gated Recurrent Unit (GRU) into a light-weighted Convolutional Neural Network (CNN) to obtain an integrated hybrid detection model named IGRU-LiCNN. The model is made lightweight through both feature reduction and model structure while maintaining excellent detection performance. Its advantages can be summarized as follows:
1. IGRU-LiCNN takes into account the long-term and short-term dependencies in the sequence data and concatenates the hidden states of all GRU time steps as the input to the CNN.
2. For spatial feature extraction, inverted residuals, depthwise separable convolution, dilated convolution, and channel shuffling are used for multi-scale feature extraction. This reduces the number of model parameters while still allowing adequate extraction of network traffic features. A channel attention mechanism assigns a different weight to each feature channel, which improves the data representation capability and accelerates model convergence.
3. We apply extremely randomized trees and the lightweight CNN to minimize attack detection time and improve the model's detection performance.
2 Related Work
Researchers originally used various traditional machine learning algorithms to solve classification problems. For example, Deng Z et al. [3] used the KNN algorithm to perform classification experiments on big data and medical imaging data, and Garg S et al. [4] used an SVM for anomaly detection on Internet of Vehicles data. Recently, there has been increasing interest in applying deep learning techniques to intrusion detection. Kasongo S M et al. [5] used a deep Gated Recurrent Unit for wireless network intrusion detection. Azizjon M et al. [6] designed a 1D-CNN model trained on serialized Internet protocol packets. Hybrid models have also been used for network traffic intrusion detection. Kunhare N et al. [7] used a genetic algorithm for feature selection and then applied a hybrid classifier for network traffic classification. Sun P et al. [8] designed a hybrid CNN-LSTM model for feature extraction and used class weights during classification to address class imbalance. These methods outperform traditional machine learning methods and have advanced intrusion detection, but they do not take into account the complexity and size of the model.
D. Yang et al.
Deploying neural network models on devices without GPUs remains a problem because of limited storage space and computational power. Recently, a number of researchers have begun to address this problem. Ren K et al. [9] used recursive feature elimination to reduce the dimensionality of the data features before feature extraction. Popoola S I et al. [10] used the encoder of a long short-term memory (LSTM) autoencoder to reduce feature dimensionality. These lightweight approaches all reduce model complexity by reducing the dimensionality of the input features; however, their feature extraction step still relies on complex models.
Fig. 1. The overall structure of the model
Algorithm 1 Pseudocode of the proposed IGRU-LiCNN method for IoT
Input: training dataset, testing dataset, total number of epochs
Output: IGRU-LiCNN model for intrusion detection
procedure Data_Preprocessing
    delete outliers from the dataset
    transform the category labels into a numerical representation
    compute the Z-score standardization of the dataset
    reduce the training dataset features by ERT
    reduce the testing dataset features by ERT
    return the new k-dimensional feature space
procedure IGRU-LiCNN
    while the total number of epochs has not been reached do
        load the lightweight network
        input the training data into the lightweight network for training
        use cross-entropy as the loss function
        save the model
    end while
    save the model as the IGRU-LiCNN model
    test and save the performance of the IGRU-LiCNN model on the testing dataset
    return the IGRU-LiCNN model
3 Proposed Method
The overall architecture of the proposed lightweight intrusion detection model IGRU-LiCNN is shown in Fig. 1. It consists of three main parts: data preprocessing, feature extraction, and classification. The IGRU-LiCNN method is implemented as in Algorithm 1.
3.1 Detection Principle
The feature distribution of traffic generated when a network is under attack differs from that of normal traffic, so analyzing the features of captured packets can reveal whether the network is under attack. For example, in a DDoS attack, attackers use multiple computer systems to send a large number of requests to the target server simultaneously, causing it to malfunction. DDoS attacks can be detected by analyzing traffic characteristics such as port numbers, protocol types, packet sizes, flow rates, and packet response times. These features involve complex patterns and behaviors in malicious traffic. Traditional detection methods based on statistical thresholds struggle to achieve good results, because the threshold must be determined manually and such methods can only passively flag traffic that exceeds the threshold as abnormal. For complex traffic patterns, threshold-based statistical methods have high false positive and false negative rates. In contrast, deep learning models can handle non-linear relationships and capture complex patterns and behaviors in malicious traffic by combining all data features, enabling more accurate detection. In addition, traffic generated by the same attack shows similarities, whereas traffic generated by different attacks differs. Deep learning models can identify and extract the feature information of malicious traffic and ultimately classify the different types of network traffic.
The IGRU-LiCNN model extracts features from both sequential and spatial perspectives, allowing it to learn the feature information in various types of traffic data more fully.
3.2 Data Preprocessing
A network traffic sample is represented as $T = [t_1, t_2, \ldots, t_n, C]$, where $t_i$ denotes the $i$-th traffic feature and $C$ denotes the label of the sample. The whole network traffic dataset can therefore be represented as:
$$T_A = \begin{bmatrix} t_1^1 & t_2^1 & \cdots & t_n^1 \\ t_1^2 & t_2^2 & \cdots & t_n^2 \\ \vdots & \vdots & \ddots & \vdots \\ t_1^m & t_2^m & \cdots & t_n^m \end{bmatrix} \tag{1}$$
where $n$ and $m$ are the number of features and the number of samples of the network traffic data, respectively. Data preprocessing involves four steps.
1. Data cleansing. Rows with null, missing, or infinite values are deleted from the dataset.
2. Data numerization. Non-numerical features in the dataset, such as the data labels, must be converted into numerical values using one-hot encoding.
3. Data standardization. Because the features have different value ranges, some differing by several orders of magnitude, the raw values can distort the classification results, so the feature values must be standardized. Z-score standardization is used: each feature is rescaled with its mean $\mu$ and standard deviation $\sigma$ so that the processed feature has a mean of 0 and a standard deviation of 1:
$$z = \frac{x - \mu}{\sigma} \tag{2}$$
4. Feature selection. The feature selection technique used in this paper is based on Extremely Randomized Trees (ERT) [11], an ensemble learning technique. ERT builds multiple decision trees by randomly sampling and partitioning the data, and the features are then ranked by the importance scores derived from these trees to perform feature selection. Compared with other feature selection methods, ERT is efficient and robust, which makes it particularly suitable for feature selection in high-dimensional data. With ERT, the feature dimensionality of the CIC-IDS2017 dataset is halved.
3.3 IGRU-LiCNN Intrusion Detection Model
When intrusion detection models are deployed on network devices, their time and computational costs, as well as their detection performance, are unavoidable concerns. Considering both the sequential and spatial feature information in network traffic data, this paper proposes a model that integrates the GRU and CNN structures, which extract sequential and spatial feature information from the data, respectively. The GRU keeps the model simple while capturing the key sequential information, and the CNN provides fine-grained detection on top of it. Furthermore, by improving the conventional CNN structure, the proposed model is lightweight and can still accurately detect network intrusion behavior. The intrusion detection part of the model is divided into the sequence feature extraction module, the spatial feature extraction module, and the classification output module.
Sequence Feature Extraction Module
We apply GRUs to extract the dependencies within sequential data while mitigating the vanishing and exploding gradient problems. Compared with LSTM, GRU has fewer parameters while achieving comparable performance in most circumstances. The input data are divided into windows of width $W$, each containing $W$ consecutive network traffic samples; a window $S$ can be represented as $S = [x_{t-W+1}, x_{t-W+2}, \ldots, x_t]$.
$S$ is then input into the GRU, which generates a hidden-state vector $h_i$ at each time step. To capture both long-term and short-term dependencies in the sequential data, the hidden states of all time steps are output, not just the last one:
$$h_1, h_2, \ldots, h_n = \mathrm{GRU}(x_1, x_2, \ldots, x_n) \tag{3}$$
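The windowing step can be sketched as follows. The text does not say whether consecutive windows overlap, so this hypothetical helper `make_windows` takes a `stride` parameter: `stride=1` gives overlapping windows ending at every $t$, and `stride=W` gives non-overlapping ones. Each window would then be fed to the GRU, which by (3) yields one hidden state per time step.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def make_windows(samples: Sequence[T], width: int, stride: int = 1) -> List[List[T]]:
    """Return windows S = [x_{t-W+1}, ..., x_t] of `width` consecutive samples."""
    if width < 1 or stride < 1:
        raise ValueError("width and stride must be positive")
    # Each window ends at index t; the first possible end index is width - 1.
    return [list(samples[t - width + 1 : t + 1])
            for t in range(width - 1, len(samples), stride)]


flows = ["x1", "x2", "x3", "x4", "x5"]  # placeholder traffic samples
print(make_windows(flows, width=3))
# [['x1', 'x2', 'x3'], ['x2', 'x3', 'x4'], ['x3', 'x4', 'x5']]
```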
Spatial Feature Extraction Module
We incorporate dilated convolutions into the convolutional layers as an alternative to downsampling: by changing the dilation rate, features can be extracted from multiple receptive fields without the computational cost of repeated pooling operations. We also use inverted residual structures to prevent vanishing gradients and network degradation. An ordinary residual structure compresses the feature map first, so only limited feature information can be extracted. In contrast, an inverted residual structure first maps the feature map from a low-dimensional space to a high-dimensional space, extracts features there, and then compresses the feature map. In this paper, the depthwise separable convolution is split into two independent parts: the depthwise convolution extracts features, while the pointwise convolution expands and compresses the feature map. Finally, these lightweight structures are combined to form the LiCNN unit shown in Fig. 2.
[Fig. 2 blocks: Input; Conv 1×1 + BN + ReLU; DWConv 3×3 + BN + ReLU; Concat; Channel Shuffle]
Fig. 2. Structure diagram of LiCNN
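The parameter savings of the depthwise separable convolution can be illustrated with a short calculation. The sketch counts weights only (bias terms are omitted, on the assumption that each convolution is followed by BN), and the channel count of 64 used in the example is hypothetical rather than taken from the paper.

```python
def standard_conv_params(k: int, c_in: int, c_out: int) -> int:
    """Weights of a k x k convolution mapping c_in to c_out channels."""
    return k * k * c_in * c_out


def depthwise_separable_params(k: int, c_in: int, c_out: int) -> int:
    """k x k depthwise convolution (one filter per input channel)
    followed by a 1 x 1 pointwise convolution."""
    return k * k * c_in + c_in * c_out


std = standard_conv_params(3, 64, 64)        # 36864
dws = depthwise_separable_params(3, 64, 64)  # 576 + 4096 = 4672
print(std, dws, round(dws / std, 3))         # 36864 4672 0.127
```

The ratio equals $1/C_{out} + 1/k^2$, so for 3×3 kernels the separable form needs roughly one-eighth of the weights once the channel count is moderately large.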
As shown in Fig. 1, the spatial feature extraction module is as follows: first, concatenate all hidden layer information output from the previous module as the input of this module. Then, a regular convolutional layer with a stride of 2 is used to implement downsampling and feature map size adjustment. Next, the processed feature map is input into three layers of lightweight units (LiCNN) for feature extraction. Finally, the channel attention mechanism is used to assign different weight information to different feature channels according to their contribution to classification. The main implementation process of LiCNN is as follows: 1. The input feature map is divided into two equal parts, xf 1 and xf 2 , according to the number of channels. The feature map xf 1 is mapped equally, and the feature map xf 2 is used for feature extraction. 2. Mapping the feature map xf 2 from low-dimensional tensor to high-dimensional tensor using 1*1 pointwise convolution to obtain the feature map xf 2_h . 3. Feature extraction of feature map xf 2_h using depthwise convolution with dilated convolution structure, resulting in feature map xf 2_h2 . 4. The feature map xf 2_h2 is compressed using 1 * 1 pointwise convolution so that it has the same number of layers as xf 2 to obtain the feature map xf 2_out . 5. Tensor stitching is performed on the feature maps xf 1 and xf 2_out . The channel shuffling technique helps achieve the information interaction between feature maps, so as to eliminate the boundary effect and obtain the feature map xc_s . 6. We use 1*1 pointwise convolution to obtain the final output xLi_out . In the LiCNN structure, a BN layer is used for normalization after each convolution step. This ensures that the output can meet or approximately follow a normal distribution,
D. Yang et al.
thereby accelerating model convergence and preventing vanishing gradients. A ReLU activation function is then applied to amplify the differences between features and produce the output. In the three-layer LiCNN structure, the depthwise convolution adopts a Hybrid Dilated Convolution (HDC) scheme with the dilation coefficients set to [1, 2, 3]. The module is computed as follows:

$$x_{out\_c} = \mathrm{CNN}\left(\mathrm{Concatenation}(h_1, h_2, \ldots, h_n)\right) \tag{4}$$

$$x_{out\_gf} = \mathrm{LiCNN}\left(x_{out\_c}\right) \tag{5}$$

where $h_1, \ldots, h_n$ are the hidden-layer outputs of the previous module that are concatenated as the module input.
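The six LiCNN steps can be sketched in NumPy for a single feature map of shape (channels, height, width). This is a minimal illustration under stated assumptions and not the authors' implementation. BN layers and biases are omitted. The expansion factor of 4 and the 8-channel 10*10 input are arbitrary. The output convolution is assumed to keep the channel count, and the weights are random rather than trained. The final loop applies dilation coefficients 1, 2, and 3 to the three stacked units.

```python
import numpy as np


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


def pointwise_conv(x: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """1*1 convolution; weight has shape (out_channels, in_channels)."""
    return np.einsum("oc,chw->ohw", weight, x)


def depthwise_conv3x3(x: np.ndarray, kernels: np.ndarray, dilation: int) -> np.ndarray:
    """Dilated 3*3 depthwise convolution (one kernel per channel), zero-padded to keep H and W."""
    c, h, w = x.shape
    padded = np.pad(x, ((0, 0), (dilation, dilation), (dilation, dilation)))
    out = np.zeros((c, h, w))
    for i in range(3):
        for j in range(3):
            window = padded[:, i * dilation:i * dilation + h, j * dilation:j * dilation + w]
            out += kernels[:, i, j][:, None, None] * window
    return out


def channel_shuffle(x: np.ndarray, groups: int) -> np.ndarray:
    """Interleave channels across groups so that the two halves exchange information."""
    c, h, w = x.shape
    return x.reshape(groups, c // groups, h, w).transpose(1, 0, 2, 3).reshape(c, h, w)


def licnn_unit(x: np.ndarray, params: dict, dilation: int) -> np.ndarray:
    c = x.shape[0]
    x_f1, x_f2 = x[: c // 2], x[c // 2:]                                    # step 1: split
    x_f2_h = relu(pointwise_conv(x_f2, params["expand"]))                    # step 2: expand
    x_f2_h2 = relu(depthwise_conv3x3(x_f2_h, params["depthwise"], dilation))  # step 3
    x_f2_out = relu(pointwise_conv(x_f2_h2, params["compress"]))             # step 4: compress
    x_c_s = channel_shuffle(np.concatenate([x_f1, x_f2_out]), 2)            # step 5
    return relu(pointwise_conv(x_c_s, params["output"]))                     # step 6


def init_params(channels: int, expansion: int, rng: np.random.Generator) -> dict:
    half = channels // 2
    hidden = half * expansion
    return {
        "expand": 0.1 * rng.standard_normal((hidden, half)),
        "depthwise": 0.1 * rng.standard_normal((hidden, 3, 3)),
        "compress": 0.1 * rng.standard_normal((half, hidden)),
        "output": 0.1 * rng.standard_normal((channels, channels)),
    }


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((8, 10, 10))
    for d in (1, 2, 3):  # HDC dilation coefficients of the three LiCNN layers
        x = licnn_unit(x, init_params(8, 4, rng), dilation=d)
    print(x.shape)  # (8, 10, 10)
```

Because every convolution is zero-padded and stride 1, each unit preserves the spatial size. In the full model, downsampling comes from the preceding stride-2 convolution.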
In the channel attention mechanism, global average pooling is used to compress the features, which avoids the large number of parameters a fully connected layer would require:

$$z_c = F_{sq}(u_c) = \frac{1}{H \times W}\sum_{i=1}^{H}\sum_{j=1}^{W} u_c(i, j) \tag{6}$$
where $u_c(i, j)$ is the feature value at position $(i, j)$ of the $c$-th channel, and $H$ and $W$ are the height and width of the feature map. Two fully connected layers then fuse the information of all channels:

$$s = F_{ex}(z, W) = \sigma\left(g(z, W)\right) = \sigma\left(W_2\,\delta(W_1 z)\right) \tag{7}$$
The first fully connected layer ($W_1$) compresses the number of feature channels to $C/r$, where $r$ is the scaling (reduction) parameter, and is followed by the ReLU activation $\delta$. The second fully connected layer ($W_2$) restores the number of channels to $C$, and the sigmoid function $\sigma$ produces the channel weights $s$. Finally, each channel feature is multiplied by its weight:

$$\tilde{x}_c = F_{scale}(u_c, s_c) = s_c \cdot u_c \tag{8}$$
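Equations (6)–(8) together form a squeeze-and-excitation style channel attention. Below is a minimal NumPy sketch for one feature map of shape (C, H, W). Like Eq. (7), it uses no bias terms, and the weight shapes (C/r by C, then C by C/r) follow the text. The function name and the random weights are illustrative.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))


def channel_attention(u: np.ndarray, w1: np.ndarray, w2: np.ndarray):
    """u: (C, H, W); w1: (C // r, C); w2: (C, C // r). Returns (reweighted u, s)."""
    z = u.mean(axis=(1, 2))                    # Eq. (6): squeeze to one value per channel
    s = sigmoid(w2 @ np.maximum(w1 @ z, 0.0))  # Eq. (7): FC -> ReLU -> FC -> sigmoid
    return s[:, None, None] * u, s             # Eq. (8): scale each channel by s_c


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    c, r = 8, 4
    u = rng.standard_normal((c, 5, 5))
    out, s = channel_attention(u, rng.standard_normal((c // r, c)), rng.standard_normal((c, c // r)))
    print(out.shape, s.round(3))
```

Each weight in s lies between 0 and 1, so the mechanism can only attenuate channels relative to one another and never amplifies a channel beyond its input scale.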
Classification Output Module
This module consists mainly of an adaptive average pooling (AdaptiveAvgPool) layer and an output layer. In our design, the fully connected layer is replaced by the AdaptiveAvgPool layer, which greatly reduces the number of parameters and the computational complexity of the model and improves its generalization. Finally, the softmax function produces the network traffic classification.
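A minimal NumPy sketch of the classification output follows. Adaptive average pooling to 1*1 collapses each channel to a single value, and an output layer followed by softmax yields the class probabilities. Two choices here are assumptions for illustration: the output layer is treated as a single linear map, and four classes (BENIGN, DoS, DDoS, Port Scan) are used. The helper names are hypothetical.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max()  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()


def classify(features: np.ndarray, w_out: np.ndarray, b_out: np.ndarray) -> np.ndarray:
    """features: (C, H, W); w_out: (num_classes, C); b_out: (num_classes,)."""
    pooled = features.mean(axis=(1, 2))  # adaptive average pooling to output size 1*1
    return softmax(w_out @ pooled + b_out)


if __name__ == "__main__":
    c, h, w, num_classes = 8, 10, 10, 4
    # Output-layer weights after pooling vs. a fully connected layer on the flattened map.
    print("pooled head weights:", num_classes * c)            # 32
    print("flattened FC weights:", num_classes * c * h * w)  # 3200
    rng = np.random.default_rng(0)
    probs = classify(rng.standard_normal((c, h, w)),
                     rng.standard_normal((num_classes, c)), np.zeros(num_classes))
    print(probs, probs.sum())
```

The weight count of the pooled head grows only with the number of channels, whereas a fully connected layer on the flattened map also grows with H and W. This is the source of the parameter reduction described above.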
4 Experimental Validation
4.1 Experimental Environment and Hyperparameter Settings
We train and test the model on a Windows platform with a 2.4 GHz CPU, an NVIDIA GeForce MX450 GPU, 16.0 GB of RAM, PyTorch (torch) 1.2.0, and Python 3.6.12. The proposed model uses the SGD optimizer with a learning rate of 0.001. The GRU hidden layer has 32 nodes, the batch size is set to 8, and the model is trained for 15 epochs.
A Light-Weighted Model of GRU + CNN Hybrid
4.2 Dataset Selection
We use the CIC-IDS2017 dataset [12], which contains 8 types of attacks and 78-dimensional features. The raw data distribution is shown in Table 1. Because the dataset is imbalanced, some categories have too few samples. In this experiment, only the four categories with a relatively large number of samples were selected: BENIGN, DoS, DDoS, and Port Scan. Accordingly, only the dataset files related to these categories were used. The ratio of training set to test set is 7:3. After data preprocessing, the distribution of each category is shown in Table 2.

Table 1. Distribution of raw data in CIC-IDS2017

Category     Instances   Category       Instances
BENIGN       2273097     FTP-Patator    7938
DoS          252661      SSH-Patator    5897
DDoS         128027      Heartbleed     11
Port Scan    158930      Infiltration   36
Bot          1966
Table 2. Sample counts in CIC-IDS2017 dataset after extraction

Category     Train Instances   Test Instances
BENIGN       466430            198856
DoS          176782            75879
DDoS         89563             38464
Port Scan    111173            47757
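The counts in Table 2 can be checked against the stated 7:3 ratio. One common way to produce such a split is a per-class (stratified) shuffle. The Python sketch below is an assumption about how a split of this kind could be made, since the paper does not say whether its split was stratified. The `stratified_split` helper and its seed are hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

# Train/test counts copied from Table 2.
TABLE_2 = {
    "BENIGN": (466430, 198856),
    "DoS": (176782, 75879),
    "DDoS": (89563, 38464),
    "Port Scan": (111173, 47757),
}


def stratified_split(labels: Sequence[str], train_ratio: float = 0.7,
                     seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle the sample indices of each class separately and cut each class at train_ratio."""
    by_class: Dict[str, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)
    train: List[int] = []
    test: List[int] = []
    for indices in by_class.values():
        rng.shuffle(indices)
        cut = round(len(indices) * train_ratio)
        train.extend(indices[:cut])
        test.extend(indices[cut:])
    return train, test


if __name__ == "__main__":
    for name, (n_train, n_test) in TABLE_2.items():
        total = n_train + n_test
        print(f"{name:<10} total={total:>7} train share={n_train / total:.3f}")
```

Every class in Table 2 has a training share of about 0.70. The per-class totals for DoS (252661), DDoS (128027), and Port Scan (158930) equal the raw counts in Table 1, so these three classes were used in full.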
4.3 Performance Evaluation Metrics
Following common practice, we use accuracy, recall, F1-score, and the number of parameters to evaluate the performance of IGRU-LiCNN and the other methods:

$$\mathrm{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN} \tag{9}$$

$$\mathrm{Recall} = \frac{TP}{TP + FN} \tag{10}$$

$$F1 = 2 \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}}{\mathrm{precision} + \mathrm{recall}} \tag{11}$$

where TP, TN, FP, and FN denote the numbers of true positives, true negatives, false positives, and false negatives, respectively, and precision = TP/(TP + FP).
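Equations (9)–(11) can be computed directly from the four counts. For the four-class task, one common convention is to derive TP, FP, and FN for each class one-vs-rest from a confusion matrix and then macro-average the per-class scores. This convention is an assumption here, since the text does not specify an averaging scheme. The sketch also assumes that rows are actual classes and columns are predicted classes, and the function names are illustrative.

```python
from typing import List, Sequence


def binary_metrics(tp: int, tn: int, fp: int, fn: int) -> dict:
    accuracy = (tp + tn) / (tp + tn + fp + fn)          # Eq. (9)
    recall = tp / (tp + fn) if tp + fn else 0.0          # Eq. (10)
    precision = tp / (tp + fp) if tp + fp else 0.0
    f1 = (2 * precision * recall / (precision + recall)  # Eq. (11)
          if precision + recall else 0.0)
    return {"accuracy": accuracy, "recall": recall, "precision": precision, "f1": f1}


def macro_scores(confusion: Sequence[Sequence[int]]) -> dict:
    """confusion[i][j] = number of samples of actual class i predicted as class j."""
    k = len(confusion)
    total = sum(sum(row) for row in confusion)
    correct = sum(confusion[i][i] for i in range(k))
    recalls: List[float] = []
    f1s: List[float] = []
    for c in range(k):
        tp = confusion[c][c]
        fn = sum(confusion[c]) - tp
        fp = sum(confusion[r][c] for r in range(k)) - tp
        tn = total - tp - fn - fp
        m = binary_metrics(tp, tn, fp, fn)
        recalls.append(m["recall"])
        f1s.append(m["f1"])
    return {"accuracy": correct / total,
            "macro_recall": sum(recalls) / k,
            "macro_f1": sum(f1s) / k}


if __name__ == "__main__":
    print(binary_metrics(tp=90, tn=80, fp=10, fn=20))  # f1 equals 6/7
```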
4.4 Experimental Results
The left panel of Fig. 3 shows the loss convergence and accuracy curves during model training. We use warm-up and cosine annealing to adjust the learning rate dynamically. This helps the model avoid getting stuck in local optima and reduces the risk of oscillation and overfitting. The right panel of Fig. 3 shows the effect of window length and ERT on model accuracy. The model performs best on the CIC-IDS2017 dataset with a window length of 10. Beyond that length, detection performance tends to decline as the window grows. Larger windows bring no significant improvement, and the growing number of parameters makes the model harder to train.
Fig. 3. Loss convergence and accuracy curves for model training (left); effect of window length and ERT on model detection performance (right)
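Warm-up followed by cosine annealing can be written as a per-step schedule. In the minimal Python sketch below, the base rate of 0.001 comes from Sect. 4.1. The linear warm-up shape, the number of warm-up steps, and the floor `min_lr` are assumptions, since the text does not give them, and the function name is hypothetical.

```python
import math


def warmup_cosine_lr(step: int, total_steps: int, base_lr: float = 1e-3,
                     warmup_steps: int = 100, min_lr: float = 0.0) -> float:
    """Learning rate at a given (0-based) optimizer step."""
    if step < warmup_steps:
        # Linear warm-up from base_lr / warmup_steps up to base_lr.
        return base_lr * (step + 1) / warmup_steps
    # Cosine decay from base_lr down to min_lr over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    total = 1000
    for step in (0, 99, 100, 550, 1000):
        print(step, round(warmup_cosine_lr(step, total), 6))
```

In PyTorch, such a function could be wrapped with `torch.optim.lr_scheduler.LambdaLR`. Because that scheduler expects a multiplicative factor, the wrapped function should return `warmup_cosine_lr(step, total) / base_lr`.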
We also analyze the impact of feature selection on the experimental results. Table 3 shows the confusion matrices obtained before and after feature reduction. In the first experiment, the proposed model was trained and evaluated on all features (G1). In the second experiment, it was trained and evaluated on the 30 optimal features (G2). The intrusion detection model trained on G2 is more accurate than the one trained on G1. This indicates that the ERT feature selection algorithm removes features irrelevant to classification and lets the model extract feature information more effectively.

Table 3. CIC-IDS2017 confusion matrix
                    G1                                   G2
Predicted / Actual  BENIGN  DDoS    DoS     Port Scan    BENIGN  DDoS    DoS     Port Scan
BENIGN              0.9958  0.0002  0.0176  0.0061       0.9976  0.0002  0.0124  0.0053
DDoS                0.0003  0.9990  0       0            0.0001  0.9994  0       0
DoS                 0.0031  0.0008  0.9824  0.0005       0.0021  0.0004  0.9876  0.0005
Port Scan           0.0008  0       0       0.9934       0.0002  0       0       0.9941
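The values in Table 3 can be checked quickly in Python, with the values copied from the table. The table does not state how it is normalized. In the printed values each column sums to about 1, and every diagonal entry is higher for G2 than for G1. The mean of the diagonal is computed only as a compact summary and is not a metric reported by the authors.

```python
G1 = [
    [0.9958, 0.0002, 0.0176, 0.0061],
    [0.0003, 0.9990, 0.0, 0.0],
    [0.0031, 0.0008, 0.9824, 0.0005],
    [0.0008, 0.0, 0.0, 0.9934],
]
G2 = [
    [0.9976, 0.0002, 0.0124, 0.0053],
    [0.0001, 0.9994, 0.0, 0.0],
    [0.0021, 0.0004, 0.9876, 0.0005],
    [0.0002, 0.0, 0.0, 0.9941],
]


def column_sums(m):
    """Sum of each column, rounded to the table's four decimal places."""
    return [round(sum(row[j] for row in m), 4) for j in range(len(m[0]))]


def diagonal(m):
    return [m[i][i] for i in range(len(m))]


if __name__ == "__main__":
    for name, m in (("G1", G1), ("G2", G2)):
        diag = diagonal(m)
        print(name, "column sums:", column_sums(m),
              "mean diagonal:", round(sum(diag) / len(diag), 5))
```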
4.5 Comparative Analysis
We benchmark the proposed model against four other models: IGWO-SVM [13], ANN-CFS [14], KNN-PCA [15], and OCNN-HMLSTM [16]. The results are shown in Fig. 4. Traditional machine learning algorithms are no longer well suited to intrusion detection on high-dimensional, complex data. The ANN-CFS model improves on various metrics, but it does not consider the multi-dimensional feature information in the data. OCNN-HMLSTM addresses this by using CNN and LSTM for feature extraction, but its model structure is relatively simple and it does not consider network optimization. The model proposed in this paper captures both long-term and short-term dependencies in the data and uses the LiCNN structure to extract network traffic features, which gives it a richer feature representation. This addresses both the insufficient feature extraction of traditional CNN structures and the excessive number of parameters in deep models. In addition, the BN layers and the attention mechanism accelerate model convergence, so the proposed model has a shorter training time of 824.7 s. The proposed model also takes 18 s on average to classify the entire test set, which corresponds to a detection time of about 3.9 ms for a small batch of data.
Fig. 4. Accuracy, recall, and F1-score for each model
For the ablation analysis, the GRU model in Fig. 4 uses the output of the last hidden layer for classification, while the CNN model uses ordinary convolution for feature extraction. Combining GRU and CNN improves the detection performance of the model. Adding the LiCNN structure improves performance further, owing to LiCNN's improvements over the original CNN structure. Specifically, the inverted residual structure, depthwise convolution, and pointwise convolution let the model fully extract feature information while reducing the number of parameters, and they also address network degradation. The channel shuffle structure eliminates the boundary effect of the feature map. Adding the attention mechanism improves detection performance further still.
4.6 Lightweight Performance of the Model
Figure 5 shows the detailed structural parameters of the two models, and Fig. 6 shows the number of parameters of each model on the CIC-IDS2017 dataset. Compared with the GRU + CNN model without feature selection, our proposed model has 48.17% fewer parameters. It can therefore better cope with the limited resources of network devices. In addition, the proposed model has 24.05% fewer parameters than the model without ERT. Thus, both the lightweight feature extraction model and the feature selection algorithm contribute to the lightweight design of the model. In our experiments, the training time of our model was 34.8% shorter than that of the GRU + CNN model, which took 1265.3 s to train. Furthermore, the detection results show that our proposed method improves accuracy by 2.35%. Our proposed method therefore provides a solution for lightweight network intrusion detection.
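As a consistency check, the 34.8% reduction follows from the GRU + CNN training time given here and the 824.7 s training time of the proposed model reported in Sect. 4.5:

$$\frac{1265.3\ \text{s} - 824.7\ \text{s}}{1265.3\ \text{s}} = \frac{440.6}{1265.3} \approx 0.348 = 34.8\%$$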
Fig. 5. IGRU-LiCNN structure diagram (top) GRU + CNN structure diagram (bottom)
Fig. 6. Comparison of the number of parameters in different models
5 Discussion and Conclusion
In this paper, a lightweight intrusion detection model is proposed. Extremely randomized trees (ERT) are first used to reduce the dimensionality of the high-dimensional data features. GRU and lightweight CNN structures are then used for feature extraction. The lightweight CNN is the main focus of this paper. It reduces the computational cost through dilated convolution and by expanding and compressing the feature map. The inverted residual and channel shuffle structures are also used to extract feature information more efficiently, and the channel attention mechanism accelerates model convergence. Experimental results verify that the proposed model maintains excellent intrusion detection performance while reducing the model's parameter count, volume,
and training time, with low detection delay. Furthermore, the model is shown to be applicable to intrusion detection on a CPU-based computing platform. For deploying intrusion detection models on resource-limited network devices, the proposed method therefore offers a feasible, practical solution. Several directions remain for future work. First, classification on imbalanced data can be studied and evaluated further. Second, intrusion detection models can be deployed and tested on devices in real network environments. While most researchers are approaching nearly perfect accuracy in anomaly detection, we believe that portable deployment and data simplicity are promising directions for IDS optimization, both in theory and in practice.
References
1. Yang, Z., Liu, X., Li, T., Wu, D., Wang, J., Zhao, Y., et al.: A systematic literature review of methods and datasets for anomaly-based network intrusion detection. Comput. Secur. 116, 102675 (2022)
2. Kan, X., Fan, Y., Fang, Z., Cao, L., Xiong, N.N., Yang, D., et al.: A novel IoT network intrusion detection approach based on adaptive particle swarm optimization convolutional neural network. Inf. Sci. 568, 147–162 (2021)
3. Deng, Z., Zhu, X., Cheng, D., Zong, M., Zhang, S.: Efficient KNN classification algorithm for big data. Neurocomputing 195, 143–148 (2016)
4. Garg, S., Kaur, K., Kaddoum, G., Gagnon, F., Kumar, N., Han, Z.: Sec-IoV: a multi-stage anomaly detection scheme for Internet of vehicles. In: Proceedings of the ACM MobiHoc Workshop on Pervasive Systems in the IoT Era, pp. 37–42 (2019)
5. Kasongo, S.M., Sun, Y.: A deep gated recurrent unit-based model for wireless intrusion detection system. ICT Express 7(1), 81–87 (2021)
6. Azizjon, M., Jumabek, A., Kim, W.: 1D CNN based network intrusion detection with normalization on imbalanced data. In: 2020 International Conference on Artificial Intelligence in Information and Communication, pp. 218–224 (2020)
7. Kunhare, N., Tiwari, R., Dhar, J.: Intrusion detection system using hybrid classifiers with meta-heuristic algorithms for the optimization and feature selection by genetic algorithm. Comput. Electr. Eng. 103, 108383 (2022)
8. Sun, P., Liu, P., Li, Q., Liu, C., Lu, X., Hao, R., et al.: DL-IDS: extracting features using CNN-LSTM hybrid network for intrusion detection system. Secur. Commun. Netw. 2020, 1–11 (2020)
9. Mohammadi, S., Mirvaziri, H., Ghazizadeh-Ahsaee, M., Karimipour, H.: Cyber intrusion detection by combined feature selection algorithm. J. Inf. Secur. Appl. 44, 80–88 (2019)
10. Popoola, S.I., Adebisi, B., Hammoudeh, M., Gui, G., Gacanin, H.: Hybrid deep learning for botnet attack detection in the Internet-of-Things networks. IEEE Internet Things J. 8(6), 4944–4956 (2020)
11. Shams, E.A., Rizaner, A., Ulusoy, A.H.: A novel context-aware feature extraction method for convolutional neural network-based intrusion detection systems. Neural Comput. Appl. 33(20), 13647–13665 (2021). https://doi.org/10.1007/s00521-021-05994-9
12. Sharafaldin, I., Lashkari, A.H., Ghorbani, A.A.: Toward generating a new intrusion detection dataset and intrusion traffic characterization. ICISSP 1, 108–116 (2018)
13. Safaldin, M., Otair, M., Abualigah, L.: Improved binary gray wolf optimizer and SVM for intrusion detection system in wireless sensor networks. J. Ambient. Intell. Humaniz. Comput. 12, 1559–1576 (2021)
14. Sumaiya-Thaseen, I., Saira Banu, J., Lavanya, K., Rukunuddin-Ghalib, M., Abhishek, K.: An integrated intrusion detection system using correlation-based attribute selection and artificial neural network. Trans. Emerg. Telecommun. Technol. 32(2), e4014 (2021)
15. Benaddi, H., Ibrahimi, K., Benslimane, A.: Improving the intrusion detection system for NSL-KDD dataset based on PCA-fuzzy clustering-KNN. In: 2018 6th International Conference on Wireless Networks and Mobile Communications, pp. 1–6 (2018)
16. Kanna, P.R., Santhi, P.: Unified deep learning approach for efficient intrusion detection system using integrated spatial–temporal features. Knowl.-Based Syst. 226, 107132 (2021)
Reinforcement-Learning Based Preload Strategy for Short Video
Zhicheng Ren¹, Yongxin Shan¹, Wanchun Jiang¹(B), Yijing Shan¹, Danfeng Shan², and Jianxin Wang¹
¹ School of Computer Science and Engineering, Central South University, Changsha, China
² School of Computer Science and Technology, Xi’an Jiaotong University, Xi’an, China
https://faculty.csu.edu.cn/wanchun
Abstract. Short video applications now have 1.02 billion users, accounting for 94.8% of all Internet users. The preload strategy for short video is the key to guaranteeing users' Quality of Experience (QoE). However, designing a preload strategy is challenging because its performance is influenced by factors including network bandwidth, video types, and user behavior. Existing preload strategies suffer from two issues. First, they ignore the impact of the current decision on future decisions and evaluate each decision independently, which leads to locally optimal decisions. Second, learning-based preload strategies use the predicted QoE of decisions as rewards, which may deviate from the actual rewards of the decisions. To address these issues, we design the Reinforcement Learning based Preload Strategy (RLPS) for short video to improve QoE. Specifically, RLPS constructs a delayed feedback mechanism to obtain the actual reward of each decision. In this way, the impact of the current decision on future decisions is also included in the reward function. Simulation results confirm the advantages of RLPS under different scenarios. Specifically, compared with the state-of-the-art strategy PDAS, RLPS improves the combined score of QoE and bandwidth usage by more than 17.3%.
Keywords: Short video · preload strategy · reinforcement learning · delayed feedback mechanism
1 Introduction
Nowadays, short video applications such as KuaiShou, DouYin, and TikTok are developing rapidly. According to the 51st statistical report on the development of the Internet in China, released by the China Internet Network Information Center (CNNIC), the number of short video users reached 1.02 billion in December 2022, accounting for 94.8% of all Internet users [1].

This work is supported by the National Natural Science Foundation of China (No. 61972421), the Key R&D Plan of Hunan Province (No. 2022SK2107), the Excellent Youth Foundation of Hunan Province (No. 2022JJ20078), and the Fundamental Research Funds for the Central Universities of Central South University in China (No. 2022ZZTS0705). This work uses the computing resources of the High Performance Computing Center of Central South University.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 327–339, 2023.
https://doi.org/10.1007/978-981-99-4761-4_28
Z. Ren et al.
The Quality of Experience (QoE) influences users' preference for short video applications. Generally, QoE consists of the bitrates of played video chunks, the bitrate switching between video chunks, and the rebuffering time. In addition, as the cost of bandwidth increases, reducing bandwidth usage while ensuring QoE has become another important issue. The preload strategy for short video is the key to improving users' QoE as well as reducing bandwidth usage. The preload strategy tries to download high-bitrate video chunks as much as possible while limiting rebuffering and bitrate switching. Moreover, in short video applications, users can swipe away a boring video and watch the next recommended video at any time. Users experience rebuffering if the buffer of the next video is empty, which reduces QoE. Meanwhile, the downloaded chunks of the swiped video are wasted, which increases the bandwidth cost. Therefore, besides selecting an appropriate bitrate for the video chunks, the preload strategy also needs to determine which video chunk to preload.

The performance of a preload strategy is affected by many factors, including variable network bandwidth, video chunk size, and diverse user behaviors. Firstly, the network bandwidth and video chunk size determine the download time of a video chunk. If the network bandwidth is poor or the video chunk is large, the chunk takes a long time to download, and the user is likely to experience rebuffering. To avoid rebuffering, the preload strategy downloads low-bitrate video chunks, which also reduces QoE because of the low bitrate and possible bitrate switching. Conversely, if the network bandwidth is good and the video chunk is small, users can enjoy high-definition, smooth playback. Moreover, user behavior determines when users swipe away from the currently played video and move to another video in the recommended queue. If the buffer of the video to be played is empty, the user experiences rebuffering.
Meanwhile, the downloaded chunks of the swiped video are wasted, which adds unnecessary bandwidth costs. The current video and the videos in the recommendation queue compete for the limited network bandwidth. When preloading videos, the preload strategy should conform to user behavior as closely as possible. In general, network bandwidth, video chunk size, and user behavior all have significant effects on the performance of a preload strategy.

Adaptive Bitrate (ABR) algorithms [6–10] select the bitrate of each video chunk according to the network bandwidth and the video buffer to ensure smooth playback. Because ABR algorithms do not consider the frequent video swiping in short video applications, they cannot be directly adopted as preload strategies for short video. To improve QoE, TikTok proposed a preload strategy for short video called Fixed [3]. Fixed preloads at most a fixed number of video chunks for each video. If the buffer of the currently played video exceeds a threshold, Fixed preloads the videos in the recommended queue in order. However, Fixed ignores users' preferences for the videos. If the user is not interested in the currently played video and soon swipes it away, preloading videos in a fixed order may cause frequent rebuffering. Therefore, works [4, 13] proposed adaptive preload strategies. These strategies use the user retention probability to predict the probability that each video chunk to be downloaded will be played, and they adjust the priority of decisions according to this played probability. The user retention probability is obtained from statistics on the played time of each video. However, works [4, 13] evaluate each decision independently. The current decision affects the future state of the client, and the future state of the client determines the next decision, i.e., decisions
Reinforcement-Learning Based Preload Strategy
are interrelated. For a preload strategy, evaluating each decision independently may only achieve local optima instead of the global optimum. DAM [2] is a preload strategy based on reinforcement learning. Although DAM seeks global optima and adapts to various environments through deep reinforcement learning networks, the QoE portion of its rewards is based on prediction. The deviation between the predicted and actual rewards of decisions may lead the learning algorithm in the wrong direction.

This paper makes the following contributions.
A) This paper reveals the above issues and proposes the Reinforcement Learning based Preload Strategy (RLPS) for short video. Specifically, RLPS develops a delayed feedback mechanism to train the reinforcement learning model. The reward function consists of the QoE and bandwidth usage of each decision. Moreover, the environment returns the reward of a decision only when the video chunk downloaded by that decision is played or swiped away. In this way, the agent obtains the actual reward of each decision instead of a predicted reward. After the reward is received, the network parameters are updated with the stored data, i.e., the corresponding past state and action. Furthermore, because preload decisions are interrelated, the impact of the current decision on future decisions is added to the reward function of RLPS.
B) To evaluate the performance of RLPS, we implement it on the simulation platform used in the short video grand challenge of ACM Multimedia 2022 and conduct experiments under different scenarios. The results confirm that RLPS performs well under various network conditions, video types, and user behaviors. Specifically, compared with the state-of-the-art strategy PDAS, RLPS improves the combined score of QoE and bandwidth usage by at least 17.3%.

The rest of this paper is organized as follows. Section 2 presents the background and related work, Sect. 3 presents the design of RLPS, Sect. 4 presents the evaluation, and Sect. 5 concludes this paper.
2 Background and Related Work
2.1 Background
In recent years, with the popularity of mobile terminals, vivid and convenient short video applications have become a mainstream part of people's lives. In short video applications, users can swipe away a boring video and watch the next recommended video at any time. Besides the quality of the short videos themselves, QoE also affects which short video application users choose. In this context, improving users' quality of experience (QoE), which typically consists of the bitrates of played video chunks, the bitrate switching between video chunks, and the rebuffering time, has become an important issue. The preload strategy is the key to improving QoE. Specifically, when a downloaded video chunk is played, the preload strategy determines the next video chunk to download and its bitrate. Based on this decision, the client downloads the video chunk and places it in the video buffer, so that the video plays smoothly without rebuffering. The bitrate
of downloaded video chunks should be as high as possible. However, video chunks with a higher bitrate take longer to download. Because bandwidth is limited, the video buffer may run empty before the next video chunk has finished downloading. The client then experiences severe rebuffering, and the user's QoE drops rapidly. When making the current decision, the preload strategy tries to avoid rebuffering (based on the video buffer) as well as bitrate switching between downloaded video chunks. Because the current decision affects the video buffer and the bitrate of the video chunks in the buffer, it has a significant impact on subsequent decisions. Decisions are therefore interrelated, which makes it harder for the preload strategy to make optimal decisions.

The performance of the preload strategy is mainly affected by factors such as network bandwidth, video chunk size, and user behavior. Bandwidth and video chunk size determine the download time of a video chunk. If the download time of a chunk exceeds the length of the video buffer, the client may experience frequent rebuffering. Moreover, network bandwidth fluctuates significantly, and current bandwidth estimation algorithms fail to predict short-term bandwidth accurately. In addition, because user behavior is uncertain, users can swipe away from the currently playing video at any time. The downloaded but unplayed video chunks are then wasted, and the client may also rebuffer because the buffers of the videos in the recommended queue are empty. Therefore, to maximize bandwidth utilization, the preload strategy needs to consider user behavior and avoid downloading video chunks that may not be played.

2.2 Related Work
In live [17] and on-demand [12] video, an ABR algorithm only selects the appropriate bitrate for the downloaded video chunks.
However, in short video applications, the sequence of played video chunks is determined by user behavior. ABR algorithms are not suitable for short video applications because they ignore the frequent video switching caused by user behavior, which may lead to frequent rebuffering on the client. To reduce rebuffering in short video applications, traditional preload strategies for short videos, such as Fixed [3], have been proposed. Fixed preloads the currently played video and the videos in the recommendation queue in a fixed order. It first preloads the currently played video. Once the currently played video has been preloaded, Fixed preloads the videos in the recommended queue in order, with at most four video chunks preloaded for each video. Fixed uses MPC [8] to determine the appropriate bitrate for the preloaded video chunks. However, because user behavior is diverse, Fixed fails to satisfy all users. To improve QoE and adapt to diverse user behavior, adaptive preload strategies [2, 4, 13] have been proposed. PDAS [4] is the most representative adaptive preload strategy. PDAS sets a threshold for each video based on the download time and played probability of its video chunks. When the video buffer exceeds the threshold, PDAS stops preloading that video. PDAS evaluates each decision based on the bandwidth usage of the downloaded video chunks and the expected QoE when these chunks are played in the future.
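The Fixed ordering rule described above can be sketched as a small selection function. This Python sketch is a hypothetical reading of the text, not TikTok's implementation. The `Video` fields, the buffer threshold for the current video, and the use of `None` to mean "pause" are assumptions. Bitrate selection, which Fixed performs with MPC, is left out.

```python
from dataclasses import dataclass
from typing import List, Optional

MAX_PRELOAD_CHUNKS = 4  # per-video cap for queued videos, as stated in the text


@dataclass
class Video:
    buffered_chunks: int   # downloaded but not yet played (hypothetical field)
    remaining_chunks: int  # not yet downloaded (hypothetical field)


def fixed_next_video(videos: List[Video], current_threshold: int) -> Optional[int]:
    """Return the index of the video to preload next (0 = currently played video)."""
    current = videos[0]
    # 1) Keep filling the currently played video until its buffer reaches the threshold.
    if current.remaining_chunks > 0 and current.buffered_chunks < current_threshold:
        return 0
    # 2) Otherwise walk the recommendation queue in order, capping each video.
    for idx, video in enumerate(videos[1:], start=1):
        if video.remaining_chunks > 0 and video.buffered_chunks < MAX_PRELOAD_CHUNKS:
            return idx
    return None  # nothing left to preload under these rules: pause downloading


if __name__ == "__main__":
    queue = [Video(3, 5), Video(4, 6), Video(1, 6)]
    print(fixed_next_video(queue, current_threshold=3))  # 2
```

The order never depends on how likely each video is to be watched. Adaptive strategies such as PDAS address exactly this weakness by using the played probability.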
DAM is a short video preload strategy based on deep reinforcement learning with action masking [2]. It constructs Actor-Critic (AC) networks [5] that map states to optimal decisions. Specifically, the Actor network generates a probability distribution over the available actions based on the state of the environment. It then selects the action with the highest probability and interacts with the environment to obtain rewards. The rewards include the bandwidth usage of downloading video chunks, the rebuffering during downloading, and the expected QoE of the downloaded video chunks when they are played in the future. DAM updates the Actor and Critic networks with the received rewards. To accelerate training, DAM also applies extensive action masking.

Traditional preload strategies make each decision independently. However, in short video applications, the current decision has a significant impact on the next decision, which traditional preload strategies ignore. Their decisions may therefore be only locally optimal rather than the best decisions over the entire playback process. Moreover, traditional preload strategies rely entirely on evaluation methods designed from prior knowledge, so they may fail to adapt to a variety of environmental conditions. Although DAM seeks global optima and adapts to various environments through deep reinforcement learning networks, the QoE portion of its rewards is based on prediction. The rewards returned by the environment may therefore deviate from the actual rewards of the decisions. As a result, the reinforcement learning network may learn in the wrong direction, and the algorithm may have difficulty converging. Moreover, DAM's extensive action masking may also prevent it from adapting to various environmental conditions.
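Action masking of the kind used by DAM (and by RLPS in Sect. 3.2) can be sketched as a masked softmax over the Actor's output scores. The encoding below is a hypothetical choice. It borrows the RLPS action space (five videos x three bitrate levels, plus one pause action) and masks every download action for a video whose buffer has reached its threshold. In a real strategy, the thresholds and scores would come from the strategy and the Actor network.

```python
import numpy as np

NUM_VIDEOS, NUM_BITRATES = 5, 3
PAUSE = NUM_VIDEOS * NUM_BITRATES  # index of the pause action (15)
NUM_ACTIONS = PAUSE + 1


def action_index(video: int, bitrate: int) -> int:
    return video * NUM_BITRATES + bitrate


def threshold_mask(buffers, thresholds) -> np.ndarray:
    """True = allowed. Disallow all bitrates of a video whose buffer reached its threshold."""
    mask = np.ones(NUM_ACTIONS, dtype=bool)
    for video, (buf, limit) in enumerate(zip(buffers, thresholds)):
        if buf >= limit:
            start = action_index(video, 0)
            mask[start:start + NUM_BITRATES] = False
    return mask


def masked_policy(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Softmax over allowed actions only; masked actions get probability 0."""
    if not mask.any():
        raise ValueError("at least one action must be allowed")
    shifted = np.where(mask, logits - logits[mask].max(), -np.inf)
    weights = np.exp(shifted)
    return weights / weights.sum()


if __name__ == "__main__":
    probs = masked_policy(np.zeros(NUM_ACTIONS), threshold_mask([5.0, 0, 0, 0, 0], [4.0] * 5))
    print(probs.round(3))          # the three actions of video 0 have probability 0
    print(int(np.argmax(probs)))   # greedy choice of the highest-probability action
```

Masking before the softmax, rather than zeroing probabilities afterwards, keeps the remaining distribution normalized without a separate renormalization step.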
3 Design of RLPS
To address the above issues, we design RLPS for short video in this section.
3.1 Framework
The framework of RLPS is shown in Fig. 1. Briefly, RLPS uses a deep reinforcement learning model to learn a mapping from states to actions. RLPS addresses the local optima problem by evaluating, in the reward function, the impact of the current decision on future decisions. Moreover, the Actor and Critic networks update their parameters θ and ω according to the state, action, and reward obtained at each step. For these updates, RLPS designs a delayed feedback mechanism that postpones each parameter update until the actual reward of the decision in the environment can be calculated. This solves the problem of an unreliable reward function. RLPS combines pre-training and online training to learn a more practical model.
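The delayed feedback mechanism holds each (state, action) pair until the chunk it downloaded is played or swiped away. Only then is the pair released, together with its actual reward, for a parameter update. The Python sketch below shows this bookkeeping under several assumptions. Chunks are keyed by a hypothetical (video, chunk) pair, pause actions (which download nothing) are not covered, and the Actor-Critic update itself is left out.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

ChunkId = Tuple[int, int]  # (video index, chunk index), a hypothetical key


@dataclass
class DelayedFeedback:
    pending: Dict[ChunkId, Tuple[Any, Any]] = field(default_factory=dict)
    ready: List[Tuple[Any, Any, float]] = field(default_factory=list)

    def record(self, chunk: ChunkId, state: Any, action: Any) -> None:
        """Store the decision that downloaded `chunk`; its reward is not known yet."""
        self.pending[chunk] = (state, action)

    def resolve(self, chunk: ChunkId, reward: float) -> None:
        """Call when `chunk` is played or swiped away and its actual reward is known."""
        state, action = self.pending.pop(chunk)
        self.ready.append((state, action, reward))

    def take_batch(self) -> List[Tuple[Any, Any, float]]:
        """Hand completed (state, action, reward) triples to the update step and clear them."""
        batch, self.ready = self.ready, []
        return batch


if __name__ == "__main__":
    fb = DelayedFeedback()
    fb.record((0, 3), state="s0", action="download v0 chunk 3")
    fb.record((1, 0), state="s1", action="download v1 chunk 0")
    fb.resolve((0, 3), reward=1.2)  # chunk played: its reward is now known
    print(fb.take_batch())          # only the resolved decision is released
    print(len(fb.pending))          # 1 decision is still waiting
```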
Fig. 1. Framework of RLPS
3.2 Detail of RLPS
RLPS is built on the classic AC model. The details are as follows.
State. The state is defined as $S_k = \{T_k, B_k, R_k, F_k, W_k\}$. Here $T_k$ is the average throughput of the past $j$ downloaded video chunks, and $B_k$ is the buffer of the $N$ videos. $R_k$ is the bitrate of the last downloaded video chunk of each of the $N$ videos, and $F_k$ is the size of the next video chunk of each video at each of the three bitrate levels. $W_k$ is the expected played probability of the next video chunk of each of the $N$ videos, obtained from the user retention probability of each video.
Action. When the Actor network receives the state $S_k$, it outputs the action $A_k = \{V, I, D\}$. The agent can choose among up to five videos and three bitrate levels. $V$ is the chosen video, $I$ is the corresponding bitrate, and $D$ is the download pause time. Because the agent can select the pause action repeatedly to extend the pause, we set a small value of 250 ms for each pause action.
Reward. After the Actor network outputs the current action $A_k$, the environment executes $A_k$ and returns the reward $R_k$ to the agent. The reward is defined as:

$$R_k = \mu \cdot bitrate(A_k) - v \cdot smooth(A_k) - \alpha \cdot rebuf(A_k) - \beta \cdot band\_use(A_k) + \gamma \cdot \frac{buf\_change(A_k)}{\max(e^{f} - 1, \zeta)} \tag{1}$$
where bitrate($A_k$) is the bitrate of the downloaded chunk and μ is the actual played time of that chunk. smooth($A_k$) is the bitrate switch when the downloaded chunk is played, and v ∈ {0, 1} indicates whether the downloaded chunk is played. These two terms are calculated by the environment when the downloaded chunk is played. rebuf($A_k$) is the rebuffering during the execution of $A_k$, and band_use($A_k$) is the bandwidth usage of $A_k$; α, β and γ weight the rebuffering, bandwidth and future-impact terms. buf_change($A_k$) represents the change in the buffer of video j after $A_k$ is executed, and f represents the buffer of video j at state $S_k$, where j is the first video, in order from the currently played video through the videos in the recommendation queue, that has not been fully preloaded. The term buf_change($A_k$)/max($e^f - 1$, ζ) represents the impact of the current decision on future decisions. The current decision affects the buffer of the video, and the client may suffer frequent rebuffering when that buffer is small. Therefore, if the current action increases the buffer of the video when the buffer is small, it reduces the risk of rebuffering for future actions and should be rewarded. Conversely, if $A_k$ decreases the buffer of the video when the buffer is small, $A_k$ should be punished. Moreover, when the buffer is sufficient, increasing it further does not necessarily improve the user's QoE, because the downloaded video chunk may not be played. We therefore use $e^f - 1$ to scale the effect of $A_k$ on future actions, and set ζ = 0.25 to prevent the denominator from becoming very small. The smaller the buffer, the larger buf_change($A_k$)/max($e^f - 1$, ζ), i.e., the more significant the impact of the current action on future actions.
Reinforcement-Learning Based Preload Strategy
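A direct transcription of Eq. (1) makes the behaviour of the last term easy to inspect. In this sketch the weights α, β and γ are hypothetical placeholders (the paper does not report their values), buffers are assumed to be in seconds, and all terms are assumed to be in consistent units; the function names are our own.

```python
import math


def future_impact(buf_change: float, buf_f: float, zeta: float = 0.25) -> float:
    """Last term of Eq. (1) without gamma: buf_change / max(e^f - 1, zeta)."""
    return buf_change / max(math.exp(buf_f) - 1.0, zeta)


def rlps_reward(bitrate: float, played_time: float, played: bool, smooth: float,
                rebuf: float, band_use: float, buf_change: float, buf_f: float,
                alpha: float = 1.0, beta: float = 0.5, gamma: float = 1.0,
                zeta: float = 0.25) -> float:
    """Reward R_k of Eq. (1); played_time plays the role of mu and played of v.

    alpha, beta and gamma are hypothetical default weights, not values from the paper.
    """
    v = 1.0 if played else 0.0
    return (played_time * bitrate      # mu * bitrate(A_k)
            - v * smooth               # v * smooth(A_k)
            - alpha * rebuf            # alpha * rebuf(A_k)
            - beta * band_use          # beta * band_use(A_k)
            + gamma * future_impact(buf_change, buf_f, zeta))
```

For example, `future_impact(1.0, 0.1)` returns 4.0 because $e^{0.1} - 1 \approx 0.105$ is clamped to ζ = 0.25, whereas `future_impact(1.0, 2.0)` is about 0.157: the same one-second buffer gain is rewarded far more when the buffer is nearly empty.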
Fig. 2. Delayed Feedback Mechanism
Action Masking. To accelerate the learning of the Actor-Critic model, RLPS sets a threshold for each video, similar to [4]. Once the buffer of a video exceeds its threshold, the Actor network no longer outputs decisions that preload this video. Specifically, for each video i, the threshold is set as follows:
$$H_i = 3.5 \cdot \max\left(Z^{i}_{\max},\ e^{-\lambda_1 C - \lambda_2 B^{i}_k}\right) \tag{2}$$
where $Z^{i}_{\max}$ represents the expected download time of the highest-bitrate chunk of video i, $B^{i}_k$ is the buffer of video i, and C is the predicted network bandwidth. The scaling coefficient is set to 3.5, and $\lambda_1$ and $\lambda_2$ are set to 0.3 and 0.15, respectively.
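Eq. (2) and the masking rule can be sketched as follows. The units (seconds for $Z^{i}_{\max}$ and $B^{i}_k$, Mbps for C), the action index layout (five videos × three bitrates followed by one pause action, which would match the 16-neuron output layer described in Sect. 4.1) and the renormalisation of the masked distribution are our assumptions; the function names are hypothetical.

```python
import math
from typing import List


def preload_threshold(z_max: float, bandwidth: float, buffer: float,
                      coef: float = 3.5, lam1: float = 0.3, lam2: float = 0.15) -> float:
    """Eq. (2): H_i = 3.5 * max(Z^i_max, exp(-lam1 * C - lam2 * B^i_k))."""
    return coef * max(z_max, math.exp(-lam1 * bandwidth - lam2 * buffer))


def preload_mask(z_max: List[float], buffers: List[float], bandwidth: float,
                 num_bitrates: int = 3) -> List[bool]:
    """True marks an allowed action; assumed layout: index v * num_bitrates + b, last = pause."""
    mask: List[bool] = []
    for z, buf in zip(z_max, buffers):
        allowed = buf <= preload_threshold(z, bandwidth, buf)
        mask.extend([allowed] * num_bitrates)
    mask.append(True)  # the pause action is never masked
    return mask


def apply_mask(probs: List[float], mask: List[bool]) -> List[float]:
    """Zero out masked actions and renormalise the remaining probabilities."""
    kept = [p if ok else 0.0 for p, ok in zip(probs, mask)]
    total = sum(kept)
    return [p / total for p in kept]
```

With $Z^{i}_{\max}$ = 0.5 s and C = 4 Mbps every threshold evaluates to 1.75 s, so a video holding 2 s of buffer has all three of its preload actions masked, while videos with 0 s or 1 s of buffer remain selectable.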
3.3 Delayed Feedback Mechanism
The traditional Actor-Critic algorithm uses a single-step update to adjust the network parameters θ and ω. Specifically, after the Actor network outputs action $A_k$ at state $S_k$, the environment executes $A_k$ and returns the reward $R_k$ together with the new state $S_{k+1}$. The Actor and Critic networks then update θ and ω based on $S_k$, $A_k$, $R_k$ and $S_{k+1}$. Different from this single-step update strategy, we set up a delayed feedback mechanism. For each action output by the Actor network, the reward of $A_k$ is not returned immediately after $A_k$ is executed; instead, only the state $S_{k+1}$ is returned immediately so that the Actor network can continue to generate the next action $A_{k+1}$. $R_k$ is returned to the agent only when the video chunk downloaded by $A_k$ has been played over or swiped away. After receiving $R_k$, the agent updates the network parameters based on $S_k$, $A_k$, $R_k$ and $S_{k+1}$. As shown in Fig. 2, for action $A_k$, the environment returns $R_k$ only at state $S_{k+2}$, when the downloaded video chunk has been played over. For the action of pausing the download, the environment returns the reward immediately.
The reason for adopting the delayed feedback mechanism is that, for a preloading strategy, it is difficult to know the actual played time of the currently downloaded video chunk; the chunk may not be played at all. However, only played video chunks bring actual QoE to the user, so adding the expected QoE to the reward of an action may lead the agent to learn in the wrong direction. Therefore, to truly reflect the impact of current decisions, we adopt the delayed feedback mechanism.
To sum up, RLPS makes the preloading strategy more accurate by feeding back the real rewards of decisions through the delayed feedback mechanism, and it accounts for the impact of current decisions on future decisions in the reward function, so that the preloading model pursues a globally rather than locally optimal strategy.
The RLPS algorithm proceeds as follows. First, RLPS initializes the AC network with preset values. Second, a simulator that models network bandwidth, videos and user behavior provides the environment in which the AC network learns: the environment executes action $A_k$, returns a new state $S_{k+1}$, and returns the reward R according to the delayed feedback mechanism. Third, the AC network updates its parameters θ and ω each time a reward R is returned. Finally, at the end of each video playback, the simulator updates the environment with new network bandwidth, videos and user behavior.
3.4 Training
RLPS combines pre-training and online learning to update the AC model. In the pre-training process, a training set consisting of different videos, network bandwidths and user behaviors is selected for model pre-training.
Once pre-training is complete, the pre-trained model is further trained online in the actual test scenario, where RLPS continuously updates the AC model. The model is pre-trained because the performance of a short-video preload strategy depends on many factors, including variable network bandwidth, different video chunk sizes and diverse user behaviors. These factors combine in many ways, so the preloading model needs to learn different scenarios in advance to meet the actual needs of users. In addition, because RLPS has a large space of actions and states, the model needs extensive training to become accurate. Moreover, because the combinations of these factors are endless in real application scenarios, online training lets RLPS perform better in the scenarios it actually encounters.
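The delayed feedback mechanism of Sect. 3.3 amounts to buffering transitions until their rewards are known. A minimal sketch, assuming chunks are identified by a hypothetical (video, chunk index) key and that the simulator calls `on_chunk_finished` when a chunk is played over or swiped away; the class and method names are our own. The drained transitions would then feed an ordinary Actor-Critic update during pre-training or online training.

```python
from dataclasses import dataclass
from typing import Dict, Hashable, List, Optional, Sequence


@dataclass
class Transition:
    state: Sequence[float]
    action: int
    next_state: Sequence[float]
    reward: Optional[float] = None


class DelayedFeedbackBuffer:
    """Holds (S_k, A_k, S_{k+1}) until the reward R_k of A_k is known."""

    def __init__(self) -> None:
        self._pending: Dict[Hashable, Transition] = {}
        self._ready: List[Transition] = []

    def record(self, state, action, next_state, chunk: Optional[Hashable] = None,
               immediate_reward: Optional[float] = None) -> None:
        """Store a transition; pause actions (chunk=None) are rewarded immediately."""
        t = Transition(state, action, next_state)
        if chunk is None:
            t.reward = immediate_reward
            self._ready.append(t)
        else:
            self._pending[chunk] = t

    def on_chunk_finished(self, chunk: Hashable, reward: float) -> None:
        """Called when the downloaded chunk is played over or swiped away."""
        t = self._pending.pop(chunk, None)
        if t is not None:
            t.reward = reward
            self._ready.append(t)

    def drain(self) -> List[Transition]:
        """Return the transitions whose rewards are known, ready for an AC update."""
        ready, self._ready = self._ready, []
        return ready
```

As in Fig. 2, a preload action recorded at step k is held back, while a pause recorded at step k + 1 becomes available for training immediately; the preload transition is released only once its chunk finishes.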
4 Evaluation
4.1 Setup
Similar to [2, 4, 13], RLPS is implemented on the short video simulator provided by [11] to evaluate its performance. In this simulator, the recommendation queue contains four videos in addition to the currently played video. All videos are encoded at three bitrates, [750, 1200, 1850] kbps, and split into 1-second chunks. The simulator generates user behavior by sampling video retention rate tables, which are based on the user retention probability of each video. Network bandwidth traces of different types are also provided, 20 groups each for low, medium, high and mixed bandwidth. The simulator provides only seven videos for evaluating an algorithm, which is far from sufficient for training RLPS. Therefore, we downloaded 93 additional videos from TikTok and generated random user retention probabilities for them.
The Actor network uses simple NN [14] layers to output the probability of each action. Specifically, the state is fed into three hidden layers with 128, 64 and 64 neurons, respectively, and the output layer has 16 neurons. The SoftMax function [15] is adopted as the activation function of the output layer. The Critic network uses a similar structure, with a single neuron in the last layer.
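The layer sizes above can be sketched in plain Python. Several details are our assumptions, not statements from the paper: the state is flattened to 1 + 5 + 5 + 15 + 5 = 31 inputs, the 16 outputs are read as 5 videos × 3 bitrates plus one 250 ms pause action, hidden layers use ReLU, and weights are randomly initialised.

```python
import math
import random
from typing import List, Tuple

Layer = Tuple[List[List[float]], List[float]]

N_VIDEOS, N_BITRATES = 5, 3
# T_k (1) + B_k (N) + R_k (N) + F_k (N x 3) + W_k (N), flattened (assumed layout)
STATE_DIM = 1 + N_VIDEOS + N_VIDEOS + N_VIDEOS * N_BITRATES + N_VIDEOS  # 31
N_ACTIONS = N_VIDEOS * N_BITRATES + 1  # 15 (video, bitrate) pairs + 1 pause = 16


def build(sizes: List[int], rng: random.Random) -> List[Layer]:
    """Fully connected layers, uniform initialisation scaled by 1/sqrt(fan_in)."""
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        bound = 1.0 / math.sqrt(n_in)
        w = [[rng.uniform(-bound, bound) for _ in range(n_in)] for _ in range(n_out)]
        layers.append((w, [0.0] * n_out))
    return layers


def forward(layers: List[Layer], x: List[float], softmax_out: bool) -> List[float]:
    for idx, (w, b) in enumerate(layers):
        z = [sum(wi * xi for wi, xi in zip(row, x)) + bj for row, bj in zip(w, b)]
        is_last = idx == len(layers) - 1
        x = z if is_last else [max(0.0, v) for v in z]  # ReLU on hidden layers (assumed)
    if softmax_out:  # actor output: probability of each action
        m = max(x)
        exps = [math.exp(v - m) for v in x]
        total = sum(exps)
        x = [e / total for e in exps]
    return x


def decode(action: int):
    """Map an output index to ('pause', 250 ms) or ('preload', video, bitrate level)."""
    if action == N_ACTIONS - 1:
        return ("pause", 250)
    return ("preload", action // N_BITRATES, action % N_BITRATES)


rng = random.Random(0)
actor = build([STATE_DIM, 128, 64, 64, N_ACTIONS], rng)  # Actor: 128-64-64-16
critic = build([STATE_DIM, 128, 64, 64, 1], rng)         # Critic: single output
```

The only structural difference between the two networks is the output head: a 16-way SoftMax for the Actor and a single linear value for the Critic.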
(a) CDF of scores under low bandwidth
(b) CDF of scores under medium bandwidth
(c) CDF of scores under high bandwidth
Fig. 3. Comparison of preload strategies under different network bandwidths. The further to the right a curve lies, the better the performance.
The performance of each algorithm is evaluated after the simulator has played all videos, using the following score:
$$\mathrm{Score} = \sum_{l} \left(\sigma \cdot \mathrm{Bitrate}_l - \eta \cdot \mathrm{Smooth}_l\right) - \sum_{d} \mu \cdot \mathrm{Rebuf}_d - \xi \cdot \sum_{y} \mathrm{Band}_y \tag{3}$$
where σ = 1, η = 1, μ = 1.85 and ξ = 0.5. $\mathrm{Bitrate}_l$ and $\mathrm{Smooth}_l$ represent the bitrate of video chunk l and the bitrate switching when chunk l is played. $\mathrm{Rebuf}_d$ represents the rebuffering during the video playback process, and $\mathrm{Band}_y$ represents the bandwidth usage of downloaded video chunk y [8, 16]. During pre-training, RLPS continuously learns in the various environments provided by the simulator, which consist of different user behaviors, videos and network bandwidths. After pre-training is completed, we compare Fixed, PDAS and RLPS under different network environments, user behaviors and videos. We do not compare with DAM because its network parameters and training process are not available.
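Because Eq. (3) is linear, it regroups into Score = QoE − ξ · Band; for instance, the RLPS high-bandwidth row of Table 1 satisfies 1824.943 − 0.5 × 2317.912 = 665.987. A sketch, assuming bitrates and bandwidth usage are expressed in the same unit (e.g. Mbps) and rebuffering in seconds; the per-chunk numbers in the usage example are hypothetical.

```python
from typing import Sequence


def score(bitrates: Sequence[float], smooths: Sequence[float], rebufs: Sequence[float],
          band_uses: Sequence[float], sigma: float = 1.0, eta: float = 1.0,
          mu: float = 1.85, xi: float = 0.5) -> float:
    """Eq. (3): QoE of played chunks minus xi times the bandwidth of downloaded chunks."""
    qoe = (sum(sigma * b - eta * s for b, s in zip(bitrates, smooths))
           - sum(mu * r for r in rebufs))
    return qoe - xi * sum(band_uses)


def score_from_totals(qoe: float, band: float, xi: float = 0.5) -> float:
    """Eq. (3) regrouped as Score = QoE - xi * Band."""
    return qoe - xi * band
```

In a hypothetical session where three chunks are played and a fourth is downloaded but swiped away, `score([1.85, 1.85, 1.2], [0.0, 0.0, 0.65], [0.5], [1.85, 1.85, 1.85, 1.2])` is about −0.05: the wasted chunk counts only in the bandwidth term, which is exactly the cost the delayed feedback mechanism tries to avoid.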
4.2 Results
Performance Comparison Under Different Network Bandwidths. To demonstrate that RLPS suits a variety of bandwidths, we compare RLPS with the Fixed and PDAS algorithms under low, medium and high bandwidth. Figure 3 shows the cumulative distribution functions (CDFs) of the scores of the different preload strategies under each bandwidth. The simulation results show that RLPS improves the score over Fixed and PDAS at every bandwidth level. To measure the gap between strategies more precisely, Table 1 lists the average scores of the strategies under the three bandwidths. Using the score of Fixed as the baseline, RLPS improves the score relative to PDAS by 33.7% under low bandwidth, 23.2% under medium bandwidth and 17.3% under high bandwidth. The table also shows that, compared with Fixed, RLPS improves QoE by different margins under different bandwidths, while its bandwidth usage is lower. Under low and medium bandwidth, bandwidth limitations prevent RLPS from downloading high-bitrate video chunks to improve QoE. There, the score improvement mostly stems from reduced bandwidth usage: the delayed feedback mechanism lets RLPS accurately identify video chunks that users do not play and thereby avoid downloading them. Under high bandwidth, RLPS frequently downloads high-bitrate video chunks to improve user QoE, and the reward function of RLPS is a key factor in ensuring smooth playback. However, RLPS uses more bandwidth than PDAS, which is due to the action masking setting during pre-training. Specifically, under high bandwidth the download time of a video chunk is relatively short, so there is no need to keep many video chunks in the buffer to avoid rebuffering during playback.
The action masking mechanism in RLPS keeps the buffer at a higher level, thereby increasing bandwidth usage.
Table 1. Average scores, QoE, and bandwidth usage of different preload strategies under different bandwidths
Bandwidth   Preload strategy   Average score   QoE        Bandwidth usage
Low         Fixed              −428.849        245.046    1347.784
Low         PDAS               −259.854        404.796    1329.301
Low         RLPS               −202.961        410.051    1226.024
Medium      Fixed              −167.092        635.820    1605.760
Medium      PDAS               1.229           648.287    1299.656
Medium      RLPS               40.923          658.341    1234.832
High        Fixed              110.333         1385.856   2551.046
High        PDAS               583.009         1686.975   2207.932
High        RLPS               665.987         1824.943   2317.912
(a) CDF of scores under low retention probability
(b) CDF of scores under high retention probability
(c) CDF of scores under mixed retention probability
Fig. 4. Comparison of preload strategies with different user behaviors under mixed network bandwidth. The further to the right a curve lies, the better the performance.
Performance Comparison Under Different User Behaviors. User behavior is one of the factors that influence the preloading strategy. To demonstrate that RLPS suits all kinds of user behaviors, we compare the scores of the preload strategies under different user behaviors. Experiments are conducted with low, high and mixed user retention probability under mixed network bandwidth. As shown in Fig. 4, RLPS improves the score over Fixed and PDAS under every user behavior. Through pre-training, RLPS keeps the video buffer within an appropriate range at different bandwidths, so it achieves a stable improvement under different user behaviors.
(a) CDF of scores under first group of videos
(b) CDF of scores under second group of videos
Fig. 5. Comparison of preload strategies in different groups of videos under mixed network bandwidth. The further to the right a curve lies, the better the performance.
Performance Comparison Under Different Videos. The preloading strategy is affected by the video chunk size, so different types of videos affect its performance. To demonstrate that RLPS suits a variety of videos, we compare RLPS with the Fixed and PDAS algorithms on two groups of videos, each containing 50 videos, under mixed network bandwidth. As shown in Fig. 5, RLPS improves the score over Fixed and PDAS for both groups. Chunk size has little impact on RLPS because its effect is already taken into account in the reward function, which lets RLPS make globally oriented decisions that avoid stalls when larger video chunks are encountered.
5 Conclusion
In this paper, we reveal that existing preload strategies suffer from local optima and unreliable evaluation functions. To address these problems, we propose RLPS, a reinforcement-learning-based preload strategy with a delayed feedback mechanism. The delayed feedback mechanism helps RLPS obtain the true reward of each decision, and the reward function considers the effect of the current decision on future decisions. Simulation results show that RLPS improves the combined score of QoE and bandwidth usage by at least 17.3% compared with the state-of-the-art strategy PDAS.
References
1. CNNIC: The 51st Statistical Report on the Development of Internet in China. https://cnnic.cn/n4/2023/0302/c199-10755.htmll. Accessed 2 Mar 2023
2. Qian, S.Z., Xie, Y., Pan, Z., et al.: DAM: deep reinforcement learning based preload algorithm with action masking for short video streaming. In: Proceedings of the 30th ACM International Conference on Multimedia, pp. 7030–7034 (2022)
3. Guo, J., Zhang, G.: A video-quality driven strategy in short video streaming. In: Proceedings of the 24th International ACM Conference on Modeling, Analysis and Simulation of Wireless and Mobile Systems, pp. 221–228 (2021)
4. Zhou, C., Ban, Y., Zhao, Y., et al.: PDAS: probability-driven adaptive streaming for short video. In: Proceedings of the 30th ACM International Conference on Multimedia, pp. 7021–7025 (2022)
5. Konda, V., Tsitsiklis, J.: Actor-critic algorithms. Adv. Neural Inf. Process. Syst. 12 (1999)
6. Huang, T., Zhou, C., Zhang, R.X., et al.: Stick: a harmonious fusion of buffer-based and learning-based approach for adaptive streaming. In: IEEE INFOCOM 2020-IEEE Conference on Computer Communications, pp. 1967–1976. IEEE (2020)
7. Lv, G., Wu, Q., Wang, W., et al.: Lumos: towards better video streaming QoE through accurate throughput prediction. In: IEEE INFOCOM 2022-IEEE Conference on Computer Communications, pp. 650–659. IEEE (2022)
8. Yin, X., Jindal, A., Sekar, V., et al.: A control-theoretic approach for dynamic adaptive video streaming over HTTP. In: Proceedings of the 2015 ACM Conference on Special Interest Group on Data Communication, pp. 325–338 (2015)
9. Zhang, X., Ou, Y., Sen, S., et al.: SENSEI: aligning video streaming quality with dynamic user sensitivity. In: 18th USENIX Symposium on Networked Systems Design and Implementation (NSDI 21), pp. 303–320 (2021)
10. Yan, F.Y., Ayers, H., Zhu, C., et al.: Learning in situ: a randomized experiment in video streaming. In: 17th USENIX Symposium on Networked Systems Design and Implementation (NSDI 20), pp. 495–511 (2020)
11. Zuo, X., Li, Y., Xu, M., et al.: Bandwidth-efficient multi-video prefetching for short video streaming. arXiv preprint arXiv:2206.09839 (2022)
12. Timmerer, C.: HTTP Streaming of MPEG Media [EB/OL]. [2022-04-15]. https://multimediacommunication.blogspot.co.at/2010/05/http-streaming-of-mpeg-media.html
13. Wu, X., Zhang, L., Cui, L.: QoE-aware download control and bitrate adaptation for short video streaming. In: Proceedings of the 30th ACM International Conference on Multimedia, pp. 7115–7119 (2022)
14. He, K., Gkioxari, G., Dollár, P., et al.: Mask R-CNN. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2961–2969 (2017)
15. Bishop, C.M. (ed.): Pattern Recognition and Machine Learning. ISS, Springer, New York (2006). https://doi.org/10.1007/978-0-387-45528-0
16. Mao, H., Netravali, R., Alizadeh, M.: Neural adaptive video streaming with pensieve. In: Proceedings of the Conference of the ACM Special Interest Group on Data Communication, pp. 197–210 (2017)
17. Kim, J., Jung, Y., Yeo, H., et al.: Neural-enhanced live streaming: improving live video ingest via online learning. In: Proceedings of the Annual Conference of the ACM Special Interest Group on Data Communication on the Applications, Technologies, Architectures, and Protocols for Computer Communication, pp. 107–125 (2020)
Particle Swarm Optimization with Genetic Evolution for Task Offloading in Device-Edge-Cloud Collaborative Computing
Bo Wang(B) and Jiangpo Wei
Software Engineering College, Zhengzhou University of Light Industry, Zhengzhou, China
[email protected]
Abstract. Several works have proposed meta-heuristic-based algorithms for the task offloading problem in Device-Edge-Cloud Collaborative Computing (DE3C) systems, owing to their better performance than heuristic-based approaches. However, these works do not fully exploit the complementarity of multiple meta-heuristic algorithms. In this paper, we combine the benefits of swarm intelligence and evolutionary algorithms to design a highly efficient task offloading strategy. Specifically, our proposed algorithm uses the iterative optimization framework of Particle Swarm Optimization (PSO) to exploit the cognitions of swarm intelligence, and applies the evolutionary strategy of the Genetic Algorithm (GA) to preserve diversity. Extensive experimental results show that our proposed algorithm achieves a better acceptance ratio and resource utilization than nine classical and up-to-date methods.
Keywords: Genetic Algorithm · Particle Swarm Optimization · Task Offloading · Edge Computing · Cloud Computing
1 Introduction
Meta-heuristics, including but not limited to swarm intelligence and evolutionary algorithms, use global search strategies inspired by natural rules or social behaviors. Meta-heuristics can achieve good or even globally optimal solutions for decision-making problems, and have thus been widely used in various fields. In this paper, we focus on applying meta-heuristics to task offloading for device-edge-cloud collaborative computing (DE3C). DE3C is an effective way to compensate for the insufficient local resources of user devices, e.g., smartphones and IoT devices, whose limited space constrains their capacity to satisfy users. DE3C places edge resources near devices for low-latency services, and rents resources from clouds for resource-intensive applications. Task offloading is a challenging problem that must be addressed in DE3C to optimize performance and resource efficiency. It decides which resources, and how many of them, process each user request (task assignment and resource allocation), as well as the processing order of multiple requests on each resource. Unfortunately, the task offloading problem is generally NP-hard [6].
Supported by the key scientific and technological projects of Henan Province (Grant No. 232102211084), and the Natural Science Foundation of Henan (Grant No. 222300420582).
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 340–350, 2023.
https://doi.org/10.1007/978-981-99-4761-4_29
Particle Swarm Optimization with Genetic Evolution
There are mainly two kinds of approaches for solving the task offloading problem: heuristics and meta-heuristics. Heuristics exploit problem-specific local search strategies to find approximate solutions. They usually consume little time but have limited performance. Thus, some other works exploit the global search abilities of meta-heuristics for task offloading to pursue better performance. However, these meta-heuristic-based algorithms have some issues that limit their effectiveness. Some works employ only one kind of meta-heuristic, which provides limited performance improvement, as each algorithm has its own strengths and weaknesses. Other works combine two or more meta-heuristics by performing them sequentially, which does not fully exploit the complementarity between different meta-heuristics. In addition, most existing works simplify the task offloading problem to some extent, e.g., by assuming homogeneous hybrid resources, ignoring the resource capacities of user devices, or considering only the task assignment sub-problem. Therefore, in this paper, we propose a hybrid heuristic task offloading algorithm for DE3C systems that exploits the complementarity between Particle Swarm Optimization (PSO), one of the most representative swarm intelligence algorithms, and the Genetic Algorithm (GA), a representative evolutionary algorithm. PSO converges fast but easily gets trapped in local optima. In contrast, GA has a powerful global search ability but converges slowly. Thus, PSO and GA are highly complementary. Specifically, our proposed hybrid heuristic algorithm uses the iterative framework of PSO, and applies the crossover and mutation operators of GA, guided by self- and social cognitions, to each particle. In brief, the contributions of this paper are as follows.
– We formulate the task offloading problem for heterogeneous DE3C as a binary non-linear programming (BNLP) problem that optimizes the acceptance ratio and the resource utilization.
– We propose a hybrid heuristic algorithm that solves the task offloading problem in polynomial time by combining the advantages of GA and PSO.
– We conduct extensive simulated experiments to evaluate our proposed algorithm, and the results show that it performs better than several classical and recent offloading algorithms.
The rest of this paper is organized as follows. Section 2 reviews related works. Section 3 formulates the task offloading problem, and Sect. 4 presents our proposed hybrid algorithm. Section 5 describes our experiments and discusses the results. Finally, Sect. 6 concludes our work.
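The hybrid update outlined in the introduction, PSO's iteration loop with GA operators applied to each particle under self- and social cognition, can be sketched as below. The concrete operators are our assumptions, not the design of the paper's Sect. 4: two-point crossover with the particle's personal best (self-cognition), then with the global best (social cognition), followed by random-reset mutation. Genes are integer indices, and the function names and toy fitness in the usage example are hypothetical.

```python
import random
from typing import Callable, List, Tuple

Genome = List[int]


def crossover(a: Genome, b: Genome, rng: random.Random) -> Genome:
    """Two-point crossover: copy a random segment of b into a copy of a."""
    i, j = sorted(rng.sample(range(len(a) + 1), 2))
    return a[:i] + b[i:j] + a[j:]


def mutate(g: Genome, n_values: int, rate: float, rng: random.Random) -> Genome:
    """Random-reset mutation: each gene is redrawn with probability `rate`."""
    return [rng.randrange(n_values) if rng.random() < rate else x for x in g]


def pso_ga(fitness: Callable[[Genome], float], length: int, n_values: int,
           swarm: int = 20, iters: int = 50, mut_rate: float = 0.05,
           seed: int = 0) -> Tuple[Genome, float, List[float]]:
    rng = random.Random(seed)
    particles = [[rng.randrange(n_values) for _ in range(length)] for _ in range(swarm)]
    pbest = [p[:] for p in particles]
    pbest_fit = [fitness(p) for p in particles]
    g = max(range(swarm), key=lambda i: pbest_fit[i])
    gbest, gbest_fit = pbest[g][:], pbest_fit[g]
    history = [gbest_fit]
    for _ in range(iters):                                 # PSO iteration framework
        for i in range(swarm):
            child = crossover(particles[i], pbest[i], rng)  # self-cognition
            child = crossover(child, gbest, rng)            # social cognition
            child = mutate(child, n_values, mut_rate, rng)  # preserve diversity
            particles[i] = child
            f = fitness(child)
            if f > pbest_fit[i]:
                pbest[i], pbest_fit[i] = child[:], f
                if f > gbest_fit:
                    gbest, gbest_fit = child[:], f
        history.append(gbest_fit)
    return gbest, gbest_fit, history
```

For task offloading, gene k would hold the index of the node chosen for $t_k$ (or a reject marker), and the fitness would evaluate the objective N + U of the formulation in Sect. 3; here any callable works, e.g. `pso_ga(lambda g: sum(x == k % 4 for k, x in enumerate(g)), length=30, n_values=4)`.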
B. Wang and J. Wei
2 Related Works
Recently, as smart devices have become more popular, more and more research has focused on the task offloading problem using various heuristics and meta-heuristics. Li et al. [10] proposed an offloading method that scheduled the task with the heaviest workload to the edge server or local device providing the minimum completion time within the energy budget of the device. Hao et al. [8] made offloading decisions for mobile cloud computing based on grey numbers to reduce energy consumption. Song et al. [14] exploited an ant colony optimization (ACO)-based offloading method to optimize the average delay through the collaboration of multiple edge servers. These works considered offloading tasks to either edge or cloud computing only, instead of exploiting their collaboration for better performance. Wang et al. [19] proposed two heuristic offloading methods for cloud-edge collaboration, aiming at faster response and load balancing, respectively. The first greedily offloads a task to the edge server providing the shortest service delay. The second selects the edge server with the most connected devices for each offloaded task. Almutairi et al. [1] formulated task offloading as a binary linear program and solved it with branch and bound after relaxing the binary variables into real ones. References [15, 20] exploited GA to decide the offloading of each task to optimize delay or resource cost. These works only considered the collaboration of edge and cloud computing for offloaded tasks, which wastes the resources of devices that are generally equipped with a certain amount of hardware. Therefore, to solve the task offloading problem with collaborative computing among devices, edges and the cloud, You & Tang [21] applied PSO for energy efficiency and low delay. Wang et al. [17, 18] used PSO and GA, respectively, to improve user satisfaction and resource efficiency. Chakraborty and Mazumdar [5] employed GA to optimize device energy consumption under delay constraints. Alqarni et al. [2] used PSO to minimize delay. All of the above works used only one heuristic or meta-heuristic method, without combining the advantages of multiple algorithms. Mahenge et al. [11] combined the position update strategies of PSO and the Grey Wolf Optimizer (GWO) to optimize the device energy cost for DE3C.
This work did not exploit the complementarity of different kinds of meta-heuristic algorithms in its hybrid algorithm, and may thus achieve only a limited performance improvement. To exploit the benefits of both GA and PSO, Hafsi et al. [7] proposed a hybrid offloading algorithm that performed the population update strategies of GA and PSO in the first and second halves of the evolutionary process, respectively. Nwogbaga et al. [12] added a GA mutation operator on each individual at the end of each PSO iteration. These hybrid algorithms simply perform each constituent algorithm separately, and thus poorly exploit the complementarity of different algorithms. In addition, all of these algorithms ignore the resource heterogeneity in DE3C, which can result in resource inefficiency. Thus, in this paper, we study a new integration strategy that exploits the complementarity of GA and PSO to address the task offloading problem with resource heterogeneity.
3 Problem Formulation
In this paper, we consider the three-tier DE3C architecture. The device tier consists of various user devices that launch tasks according to user requirements. The edge tier provides several edge servers (ESs) for processing tasks requested by the corresponding user devices. The cloud tier provides abundant computing and storage resources, in the form of cloud servers (CSs), to make up for the limited service capacity of devices and ESs. In a DE3C system, the edge tier provides communication connections for data transfer between ESs and devices over various communication technologies. Each ES has a coverage area, and only devices in this area can connect to the ES, as the connections between ESs and devices are generally wireless. When a user device has insufficient capacity to satisfy its requests, some of these requests can be offloaded to an ES covering the device and processed by the ES's resources. In this paper, we focus on the task offloading problem, which decides the computing node (local device, ES or CS) that processes each task.
A DE3C system consists of D devices, E ESs and V CSs, denoted $s_i$ (i = 1, ..., D), $s_i$ (i = D+1, ..., D+E) and $s_i$ (i = D+E+1, ..., D+E+V), respectively. Computing node $s_i$ has computing capacity $g_i$. The network transfer rate of ES/CS $s_i$ for data transfer between a device and the ES/CS is $r_i$. Each ES provides LAN connections for the devices it covers, whereas all devices connect to CSs over the Internet. Thus, the network transfer rates of CSs are generally lower than those of ESs. We use binary constants $a_{i,j}$ ($D+1 \le i \le D+E$, $1 \le j \le D$) to represent the connectivity between ESs and devices, where $a_{i,j} = 1$ if $s_j$ is covered by $s_i$, and $a_{i,j} = 0$ otherwise. There are T tasks launched by devices, denoted $t_k$ (k = 1, ..., T). We use binary constants $b_{i,k}$ ($1 \le i \le D$, $1 \le k \le T$) to denote the relationships between tasks and devices: $b_{i,k} = 1$ means $t_k$ is launched by device $s_i$, and $b_{i,k} = 0$ otherwise. A task can be processed locally by its device. When the device has insufficient resources, the task can be offloaded to an ES connected to its device. A task that is not latency-sensitive can also be offloaded to a CS rented from the cloud. Task $t_k$ requires $c_k$ computing resources to process its input data, whose amount is $m_k$.
In this paper, as in many related works, we ignore the time consumed by returning each task's result, because in DE3C the result data of a task is generally much smaller than its input data. The deadline of $t_k$ is $d_k$, i.e., $t_k$ must be finished before $d_k$. Without loss of generality, we assume that $d_1 \le d_2 \le \dots \le d_T$. If a task's deadline can be satisfied, DE3C accepts and processes the task; otherwise, the task is rejected, as processing it yields no profit. To formulate the task offloading problem in DE3C, we define binary variables $x_{i,k}$ to represent the scheduling of tasks to computing nodes, as in Eq. (1): $x_{i,k} = 1$ if $t_k$ is scheduled to $s_i$ for processing, and $x_{i,k} = 0$ otherwise.
$$x_{i,k} = \begin{cases} 1, & \text{if } t_k \text{ is scheduled to } s_i \text{ for its processing}, \\ 0, & \text{else}, \end{cases} \quad 1 \le i \le D+E+V,\ 1 \le k \le T. \tag{1}$$
A task cannot be processed by a device other than its own, or by an ES without a connection to its device. Thus, constraints (2) and (3) hold.
$$x_{i,k} \le b_{i,k}, \quad 1 \le i \le D. \tag{2}$$
$$x_{i,k} \le \sum_{j=1}^{D} (a_{i,j} \cdot b_{j,k}), \quad D+1 \le i \le D+E. \tag{3}$$
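Constraints (2) and (3) restrict where each task may run. A small checker, using 0-based indices (devices first, then ESs, then CSs); the matrix layout and function name are our own. It also enforces that a task is placed on at most one node, which the formulation states as constraint (8).

```python
from typing import List

Matrix = List[List[int]]


def satisfies_placement(x: Matrix, a: Matrix, b: Matrix, D: int, E: int, V: int) -> bool:
    """Check constraints (2) and (3) plus the at-most-one-node rule.

    x[i][k]: node i (0-based; devices, then ESs, then CSs) runs task k.
    a[e][j]: ES e (node D + e) covers device j.
    b[j][k]: device j launched task k.
    """
    T = len(b[0])
    for k in range(T):
        if sum(x[i][k] for i in range(D + E + V)) > 1:
            return False  # a task runs on at most one node
        for i in range(D):
            if x[i][k] > b[i][k]:
                return False  # Eq. (2): only the launching device
        for e in range(E):
            covered = sum(a[e][j] * b[j][k] for j in range(D))
            if x[D + e][k] > covered:
                return False  # Eq. (3): only an ES covering the launching device
    return True
```

With two devices, one ES that covers only device 0 and one CS, a task launched by device 1 may run locally or on the CS, but placing it on the ES violates constraint (3).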
When a task is processed locally, its processing time consists only of computing, as the input data is stored on the device. The finish time is then the start time plus the computing time of the task. Because earliest deadline first (EDF) scheduling is optimal for processing tasks on a computing node in terms of maximizing the acceptance ratio [4], we establish the acceptance ratio optimization model based on EDF. Thus, on each device, each locally processed task starts computing after the tasks with earlier deadlines have finished, and the finish time of a locally processed task can be calculated by Eq. (4). Here $c_k/g_i$ is the computing time of $t_k$, and $\sum_{k'=1}^{k-1} (x_{i,k'} \cdot c_{k'}/g_i)$ is the accumulated computing time of the tasks processed by $s_i$ ($x_{i,k'} = 1$, $1 \le i \le D$) whose deadlines are earlier than that of $t_k$.
$$f_{i,k} = \sum_{k'=1}^{k} (x_{i,k'} \cdot c_{k'}/g_i), \quad 1 \le i \le D. \tag{4}$$
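Under EDF, Eq. (4) is simply a prefix sum over a device's tasks in deadline order. A minimal sketch for one device, assuming tasks are already sorted by deadline as in the formulation; the function name is our own.

```python
from itertools import accumulate
from typing import List


def local_finish_times(assigned: List[int], c: List[float], g: float) -> List[float]:
    """Eq. (4) on one device s_i.

    assigned[k] = x_{i,k}; tasks are indexed in non-decreasing deadline order (EDF),
    so f_{i,k} is the running sum of c_{k'}/g_i over locally run tasks k' <= k.
    Values for tasks with assigned[k] == 0 are not meaningful.
    """
    return list(accumulate(x * ck / g for x, ck in zip(assigned, c)))
```

For a device with $g_i$ = 2 running the first and third of three tasks with $c$ = [4, 2, 6], the finish times are [2.0, 2.0, 5.0]; only the first and third entries matter, and the second task (not run locally) adds nothing to the running sum.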
When a task is offloaded to an ES or a CS, time is also consumed by transferring its input data from the device. In this paper, we assume that each ES/CS is equipped with one network interface card (NIC), and that the input data of multiple offloaded tasks are transferred sequentially at each ES or CS to avoid mutual interference. When $t_k$ is offloaded to $s_i$ ($D+1 \le i \le D+E+V$), the time consumed by its input data transfer is $m_k/r_i$, and its computing time is $c_k/g_i$. The data transfer of $t_k$ can start only when $s_i$ has finished the transfers of the tasks that have earlier deadlines and are offloaded to $s_i$. The accumulated data transfer time of these tasks is $\sum_{k'=1}^{k-1} (x_{i,k'} \cdot m_{k'}/r_i)$, and thus the data transfer of $t_k$ offloaded to $s_i$ finishes at $\sum_{k'=1}^{k} (x_{i,k'} \cdot m_{k'}/r_i)$. When a task is processed by an ES or a CS, there are two situations for its computing, in which the bottleneck is the data transfer or the computing, respectively. In the first situation, the task's computing starts once its data transfer finishes, and its finish time can be calculated by Eq. (5).
$$f^{R}_{i,k} = \sum_{k'=1}^{k} \left(x_{i,k'} \cdot \frac{m_{k'}}{r_i}\right) + \frac{c_k}{g_i}, \quad D+1 \le i \le D+E+V. \tag{5}$$
When the bottleneck is the computing, the task’s computing can be started when the computing of tasks processed before it in the same ES or CS, which is later than the finish time of its data transfer. In such case, the finish time can be got by Eq. (6). ck C = max (xi,k · fi,k ) + , D + 1 ≤ i ≤ D + E + V . (6) fi,k gi 1≤k ≤k−1 Combining both cases, we can achieve the finish time of tasks offloaded to an ES or a CS by Eq. (7). R C , fi,k }, D + 1 ≤ i ≤ D + E + V . fi,k = max{fi,k
(7)
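Similarly, Eqs. (5)–(7) can be evaluated for one ES or CS in a single pass. In this hypothetical sketch, one running transfer counter models the single NIC that receives input data sequentially, and a second counter tracks when the node's computing becomes free; tasks are again assumed to be indexed in deadline order.

```python
from typing import List


def offloaded_finish_times(data: List[float], compute: List[float],
                           assigned: List[int], r: float, g: float) -> List[float]:
    """Evaluate Eqs. (5)-(7) for one ES or CS s_i (hypothetical helper).

    data[k] = m_k, compute[k] = c_k, assigned[k] = x_{i,k} (tasks sorted by
    deadline); r = r_i is the transfer rate of the single NIC and g = g_i the
    computing capacity. Returns f_{i,k} for assigned tasks and 0.0 otherwise.
    """
    finish: List[float] = []
    transfer_done = 0.0  # sum of m_k'/r_i over assigned k' <= k (sequential NIC)
    compute_free = 0.0   # max f_{i,k'} over assigned k' < k
    for m, c, x in zip(data, compute, assigned):
        if not x:
            finish.append(0.0)
            continue
        transfer_done += m / r
        f_r = transfer_done + c / g   # Eq. (5): data transfer is the bottleneck
        f_c = compute_free + c / g    # Eq. (6): computing is the bottleneck
        f = max(f_r, f_c)             # Eq. (7)
        compute_free = max(compute_free, f)
        finish.append(f)
    return finish


# The second task waits for the first one's computing, so Eq. (6) dominates.
print(offloaded_finish_times([4.0, 4.0], [6.0, 2.0], [1, 1], r=4.0, g=2.0))  # [4.0, 5.0]
```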
As every task can be processed by at most one computing node, inequalities (8) hold. If task $t_k$ is rejected in the DE3C, $\sum_{i=1}^{D+E+V} x_{i,k} = 0$. Thus, the finish time of each task is given by Eq. (9), and if $t_k$ is rejected, its finish time $f_k$ is 0. The deadline requirements of tasks can then be formulated as Eq. (10).
$$\sum_{i=1}^{D+E+V} x_{i,k} \le 1, \quad 1 \le k \le T. \tag{8}$$
Particle Swarm Optimization with Genetic Evolution
$$f_k = \sum_{i=1}^{D+E+V} (x_{i,k} \cdot f_{i,k}), \quad 1 \le k \le T. \tag{9}$$
$$f_k \le d_k, \quad 1 \le k \le T. \tag{10}$$
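Constraints (8)–(10) amount to a simple feasibility test on an assignment. The sketch below is illustrative only; the matrix layout `x[i][k]` and the function name are assumptions.

```python
from typing import List


def meets_constraints(x: List[List[int]], node_finish: List[List[float]],
                      deadlines: List[float]) -> bool:
    """Check constraints (8)-(10) for an assignment x[i][k] in {0, 1}.

    node_finish[i][k] holds f_{i,k} from Eq. (4) or Eq. (7); deadlines[k] is d_k.
    """
    nodes = range(len(x))
    for k, d_k in enumerate(deadlines):
        if sum(x[i][k] for i in nodes) > 1:                    # Eq. (8)
            return False
        f_k = sum(x[i][k] * node_finish[i][k] for i in nodes)  # Eq. (9); 0 if rejected
        if f_k > d_k:                                          # Eq. (10)
            return False
    return True


x = [[1, 0], [0, 1]]               # task 0 on node 0, task 1 on node 1
finish = [[1.0, 0.0], [0.0, 3.0]]
print(meets_constraints(x, finish, [2.0, 3.0]))  # True
print(meets_constraints(x, finish, [2.0, 2.5]))  # False: f_1 = 3.0 > d_1
```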
For each computing node, the occupied time is the latest finish time of the tasks scheduled to it, which is $\max_k\{x_{i,k} \cdot f_k\}$ for $s_i$. Thus, the computing resource of $s_i$ occupied for task processing is $\max_k\{x_{i,k} \cdot f_k\} \cdot g_i$, whereas the computing resource actually consumed for task processing in $s_i$ is $\sum_{k=1}^{T} (x_{i,k} \cdot c_k)$. The computing resource utilization of $s_i$ is therefore $\sum_{k=1}^{T} (x_{i,k} \cdot c_k) / (\max_k\{x_{i,k} \cdot f_k\} \cdot g_i)$, and the overall computing resource utilization of the DE3C system is given by Eq. (11).
$$U = \frac{\sum_{i=1}^{D+E+V} \sum_{k=1}^{T} (x_{i,k} \cdot c_k)}{\sum_{i=1}^{D+E+V} (\max_k\{x_{i,k} \cdot f_k\} \cdot g_i)}. \tag{11}$$
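Eq. (11) divides the computing resource actually consumed by the resource occupied up to each node's latest finish time. A minimal sketch follows; the argument names are hypothetical, and nodes that receive no tasks contribute zero to both sums.

```python
from typing import List


def overall_utilization(x: List[List[int]], compute: List[float],
                        finish: List[float], capacity: List[float]) -> float:
    """Eq. (11): consumed computing resource / occupied computing resource.

    x[i][k] = x_{i,k}; compute[k] = c_k; finish[k] = f_k from Eq. (9);
    capacity[i] = g_i.
    """
    tasks = range(len(compute))
    consumed = sum(x[i][k] * compute[k] for i in range(len(x)) for k in tasks)
    occupied = sum(max((x[i][k] * finish[k] for k in tasks), default=0.0) * capacity[i]
                   for i in range(len(x)))
    return consumed / occupied if occupied > 0 else 0.0


# Node 0 is busy for its whole occupied time; node 1 computes 3 of 4 units.
print(round(overall_utilization([[1, 0], [0, 1]], [2.0, 3.0], [2.0, 2.0], [1.0, 2.0]), 4))  # 0.8333
```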
Based on the above formulations, the task offloading problem can be modelled as the following optimization problem, subject to constraints (1)–(11), where $N = \sum_{i=1}^{D+E+V} \sum_{k=1}^{T} x_{i,k}$ is the number of accepted tasks.
$$\text{Maximize} \quad N + U \tag{12}$$
The optimization objective is to maximize the number of accepted tasks ($N$) plus the overall computing resource utilization ($U$). When the total number of tasks is fixed, maximizing the number of accepted tasks is equivalent to maximizing the acceptance ratio, one of the most widely used metrics for quantifying user satisfaction. Since $U$ is no more than 1, the number of accepted tasks is the major optimization objective and the overall resource utilization is the minor one. The decision variables are $x_{i,k}$, $1 \le i \le D+E+V$, $1 \le k \le T$, which are binary. Therefore, the offloading problem is a binary non-linear programming (BNLP) problem. It can be solved by existing tools, e.g., CPLEX, but their time complexity grows exponentially with the problem scale, so they are not applicable to medium- and large-scale DE3C systems. Therefore, in the next section, we propose a hybrid heuristic algorithm that finds efficient solutions in polynomial time.
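The dominance of $N$ over $U$ can be illustrated with a toy comparison. Because $0 \le U \le 1$, a solution that accepts one more task always scores at least as high as one with fewer accepted tasks, whatever their utilizations. The candidate values below are hypothetical, not experimental results.

```python
def fitness(accepted: int, utilization: float) -> float:
    """Eq. (12): N + U, where 0 <= U <= 1."""
    if not 0.0 <= utilization <= 1.0:
        raise ValueError("U must lie in [0, 1]")
    return accepted + utilization


# Hypothetical candidates (N, U): (N + 1) + U_a >= N + 1 >= N + U_b,
# so the extra accepted task outweighs any utilization difference.
candidates = [(950, 0.62), (951, 0.10), (950, 0.91)]
best = max(candidates, key=lambda s: fitness(*s))
print(best)  # (951, 0.1)
```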
4 Hybrid Heuristic Task Offloading Algorithm
In this paper, we combine the benefits of PSO and GA and design a hybrid heuristic offloading algorithm (PSOGA), outlined in Algorithm 1. PSOGA exploits the iterative optimization framework of PSO and the evolutionary strategy of GA. In PSOGA, each particle position corresponds to a task offloading solution that specifies the computing node for each task, as described at the end of this section. To evaluate the quality of each particle, PSOGA uses the optimization objective (Eq. (12)) as the fitness function: the fitness of a particle is $N + U$ obtained by applying the corresponding task offloading solution in the DE3C and using EDF for task execution on each computing node.
In the initialization phase, PSOGA initializes a population of particles by randomly setting the position of every particle in each dimension (line 1) and evaluates the fitness of each particle (line 2). For each particle, PSOGA records the initial position as its personal best position (pb) (line 3). In addition, PSOGA assigns the position with the best fitness among all particles to the global best position (gb) (line 4). After the initialization, PSOGA iteratively evolves the population toward the optimal solution (lines 5–12). In each iteration, PSOGA performs three crossover operators and one mutation operator on each particle (lines 6–11). The first crossover is performed between the particle and another randomly selected particle; the other two are performed between the particle and its pb and gb, respectively. Each crossover produces two offspring, whose fitness values are evaluated. PSOGA then updates the position of the particle to the fitter of the two offspring, and if this offspring has better fitness than the pb, the pb is also updated to it. In this paper, we use the uniform crossover operator to improve population diversity.
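The uniform crossover on integer-coded positions can be sketched as follows. This is an illustrative reading of the description above, not the authors' implementation; the gene swap probability of 0.5 is an assumption, since the swap probability is not stated, and the helper names are hypothetical.

```python
import random
from typing import Callable, List, Tuple

Position = List[int]  # one gene per task: index of the chosen computing node


def uniform_crossover(a: Position, b: Position, rng: random.Random,
                      swap_prob: float = 0.5) -> Tuple[Position, Position]:
    """Uniform crossover: each gene is swapped between the parents with swap_prob."""
    child1, child2 = a[:], b[:]
    for d in range(len(a)):
        if rng.random() < swap_prob:
            child1[d], child2[d] = b[d], a[d]
    return child1, child2


def crossover_step(position: Position, partner: Position,
                   fitness: Callable[[Position], float],
                   rng: random.Random) -> Position:
    """Cross a particle with a partner (random particle, pb, or gb) and
    return the fitter of the two offspring, which replaces the particle."""
    c1, c2 = uniform_crossover(position, partner, rng)
    return max((c1, c2), key=fitness)


rng = random.Random(42)
p, gb = [0, 1, 2, 3], [4, 4, 4, 4]
print(crossover_step(p, gb, fitness=lambda s: s.count(4), rng=rng))
```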
Algorithm 1: PSOGA: Task offloading with hybrid PSO and GA
Input: information of tasks, devices, ESs, and CSs;
Output: a task offloading solution;
1  Initialize positions of particles randomly; //The initialization phase: lines 1-4
2  Evaluate the fitness of each particle;
3  Set the personal best position (pb) as the initialized one for each particle;
4  Set the global best position (gb) as the best one of all particles;
5  while the termination condition is not reached do //The evolution phase: lines 5-12
6      for each particle p do
7          Randomly select a particle p' other than p;
8          Cross p with p'; Cross p with pb; Cross p with gb;
9          Evaluate fitness of all offspring;
10         Replace p by the best offspring, and update pb;
11         Mutate p, evaluate its fitness, and update pb;
12     Update gb as the best pb if the best pb has better fitness than gb;
13 Return the offloading solution decoded from gb;
For each particle, after performing the three crossover operators, PSOGA mutates the particle to produce an offspring and replaces the particle with the offspring. If the offspring has better fitness than the particle's pb, the pb is updated to the offspring (line 11). As with crossover, we use the uniform mutation operator in this paper. At the end of every evolutionary iteration, PSOGA updates gb to the best pb if the latter has better fitness (line 12). After the evolution finishes, PSOGA decodes gb into a task offloading solution, which is the best solution produced by the algorithm (line 13).
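Putting the pieces together, Algorithm 1 can be sketched in Python as below. This is a simplified, hypothetical implementation: the termination condition is assumed to be a fixed number of iterations, the mutation rate and population size are arbitrary, and the toy fitness function stands in for the EDF-based evaluation of $N + U$ described above.

```python
import random
from typing import Callable, List

Position = List[int]


def psoga(num_tasks: int, num_nodes: int, fitness: Callable[[Position], float],
          pop_size: int = 20, iterations: int = 50, mutation_rate: float = 0.05,
          seed: int = 0) -> Position:
    """Sketch of Algorithm 1 on integer-coded positions (gene = node index)."""
    if pop_size < 2:
        raise ValueError("need at least two particles")
    rng = random.Random(seed)

    def crossover(a: Position, b: Position) -> Position:
        c1, c2 = a[:], b[:]
        for d in range(num_tasks):                 # uniform crossover
            if rng.random() < 0.5:
                c1[d], c2[d] = b[d], a[d]
        return max((c1, c2), key=fitness)          # keep the fitter offspring

    def mutate(a: Position) -> Position:           # uniform mutation
        return [rng.randrange(num_nodes) if rng.random() < mutation_rate else gene
                for gene in a]

    # Lines 1-4: random positions, pb = initial position, gb = best pb.
    swarm = [[rng.randrange(num_nodes) for _ in range(num_tasks)]
             for _ in range(pop_size)]
    pbest = [p[:] for p in swarm]
    gbest = max(pbest, key=fitness)[:]

    def accept(idx: int, child: Position) -> None:
        swarm[idx] = child                         # the offspring replaces p
        if fitness(child) > fitness(pbest[idx]):   # update pb if improved
            pbest[idx] = child[:]

    for _ in range(iterations):                    # lines 5-12
        for idx in range(pop_size):
            other = rng.choice([j for j in range(pop_size) if j != idx])
            accept(idx, crossover(swarm[idx], swarm[other]))  # with p'
            accept(idx, crossover(swarm[idx], pbest[idx]))    # with pb
            accept(idx, crossover(swarm[idx], gbest))         # with gb
            accept(idx, mutate(swarm[idx]))                   # line 11
        best = max(pbest, key=fitness)
        if fitness(best) > fitness(gbest):                    # line 12
            gbest = best[:]
    return gbest                                              # line 13: decode gb


# Hypothetical toy fitness: count genes that match a target assignment.
target = [0, 1, 2, 0, 1, 2, 0, 1]
solution = psoga(len(target), 3, lambda s: sum(a == b for a, b in zip(s, target)))
print(solution)
```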
In PSOGA, there is a one-to-one correspondence between particle positions and task offloading solutions. Each dimension of a position corresponds to a task, and the value in that dimension is the number of the computing node to which the task is scheduled, where candidate computing nodes are numbered sequentially from devices to ESs and then the cloud. As the cloud provides "infinite" CSs, we cannot number all of them, and thus we treat the cloud as a special computing node. In the real world, the cloud provides several types of CSs with different resource configurations. In this paper, when a task is scheduled to the cloud, we use first fit (FF) to decide the CS for the task. When there are available rented CSs, FF schedules the task to the first CS that satisfies its requirements. When no rented CS can finish the task in time, FF rents a new CS with a configuration that meets its deadline.
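The decoding step and the first-fit (FF) selection of cloud servers can be sketched as follows. The sketch assumes zero-based node numbering, that all tasks are available at time 0, that cloud tasks are handled in deadline order, and that candidate CS types are tried in list order when a new CS is rented; these details, and all names, are assumptions rather than the authors' specification. The finish time on a CS follows Eqs. (5)–(7).

```python
from dataclasses import dataclass
from typing import List, Optional


def node_kind(gene: int, num_devices: int, num_es: int) -> str:
    """Decode one gene: nodes numbered devices first, then ESs, then the cloud."""
    if gene < num_devices:
        return "device"
    if gene < num_devices + num_es:
        return "edge"
    return "cloud"


@dataclass
class CSType:
    capacity: float   # g_i, computing capacity
    bandwidth: float  # r_i, data transfer rate


@dataclass
class RentedCS:
    ctype: CSType
    transfer_done: float = 0.0  # when the NIC finishes its queued transfers
    compute_free: float = 0.0   # when the CS finishes its queued computing

    def finish_if_added(self, data: float, compute: float) -> float:
        transfer = self.transfer_done + data / self.ctype.bandwidth
        # max(transfer, free) + c/g equals max(f^R, f^C) from Eqs. (5)-(7)
        return max(transfer, self.compute_free) + compute / self.ctype.capacity

    def add(self, data: float, compute: float) -> float:
        finish = self.finish_if_added(data, compute)
        self.transfer_done += data / self.ctype.bandwidth
        self.compute_free = finish
        return finish


def first_fit_cloud(rented: List[RentedCS], types: List[CSType], data: float,
                    compute: float, deadline: float) -> Optional[RentedCS]:
    """FF: reuse the first rented CS that meets the deadline, else rent a new one."""
    for cs in rented:
        if cs.finish_if_added(data, compute) <= deadline:
            cs.add(data, compute)
            return cs
    for ctype in types:  # assumption: candidate types are tried in list order
        cs = RentedCS(ctype)
        if cs.finish_if_added(data, compute) <= deadline:
            cs.add(data, compute)
            rented.append(cs)
            return cs
    return None  # no CS type can meet the deadline: the task is rejected


types = [CSType(capacity=1.0, bandwidth=10.0), CSType(capacity=4.0, bandwidth=10.0)]
rented: List[RentedCS] = []
for deadline in (2.0, 3.5, 3.5):  # three identical cloud tasks: m = 10, c = 4
    first_fit_cloud(rented, types, data=10.0, compute=4.0, deadline=deadline)
print(len(rented))  # 2
```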
5 Performance Evaluation
In this section, we establish a simulated experimental platform to evaluate the performance of PSOGA, following related works [3, 13, 17]. In the simulated DE3C, there are 1000 tasks randomly associated with 10 devices. Each task requires 0.5–1.2 GHz, with 1.5–6 MB of input data, and its deadline is set to 1–5 s. Every device has 2–8 computing cores, each with a computing capacity of 1.8–2.5 GHz. In our experiments, the computing unit is the computing core rather than the computing node, as fine-grained resource allocation helps improve resource utilization [16]. There are 5 ESs and 10 CS types. Every ES has 4–32 computing cores, and each CS type is configured with 1–8 cores. The computing capacity of each ES or CS core is 1.8–3.0 GHz. The network bandwidths of an ES and of a CS type are set to 80–120 Mbps and 8–12 Mbps, respectively. Each device has a network connection to one random ES for offloading its tasks to that ES. To evaluate the performance of PSOGA, we compare it with FF, FFD, EDF, a random method (RAND), GA [18], PSO [17], GA with a replacement operator (GA_RPL) [9], and GA_PSO [7]. The performance metrics are the acceptance ratio (N/T) and the resource utilization (U), which correspond to the two optimization objectives of our work. We repeat the experiment more than 100 times and present each algorithm's metric values normalized by those of FF. Figure 1 shows boxplots of the acceptance ratios achieved by the various task offloading algorithms in the whole DE3C system. On average, PSOGA achieves a 4.87%–28.2% higher overall acceptance ratio than the other algorithms, and its advantage is statistically significant at the 99% level.
This is mainly because PSOGA more fully utilizes the abundant cloud resources, as it considers the relative capacities of different resources when making offloading decisions, as implemented at the end of Sect. 4. In the cloud, PSOGA processes more tasks than all other algorithms except GA_RPL. GA_RPL achieves the best acceptance ratio in the cloud because it finishes the fewest tasks on devices and ESs; thus, more tasks remain to be offloaded to the cloud, and more of them are likely to be satisfied there. In addition, as shown in Fig. 1, PSOM performs better than PSO, and GA_PSO and PSOGA perform better than GA and PSO in optimizing the acceptance ratio.
Thus, combining multiple meta-heuristics can improve on each of them by exploiting their complementarity. Among PSOM, GA_PSO, and PSOGA, which all aim to combine the benefits of GA and PSO, PSOGA achieves an average overall acceptance ratio more than 21% higher than the others. This result demonstrates the efficiency of the integration strategy used by PSOGA. PSOM only adds a mutation operator to PSO to reduce its tendency to become trapped in local optima, and GA_PSO applies the evolutionary strategies of GA and PSO sequentially; both bring limited improvements over PSO and GA. Our proposed algorithm, PSOGA, takes full advantage of the self- and social cognition of swarm intelligence and of evolutionary diversity, and thus provides a good combination of GA and PSO. Hence, the combination strategy must be designed carefully; otherwise, it may yield little improvement or even degrade performance.
Fig. 1. The acceptance ratios achieved by various offloading algorithms.
As shown in Fig. 2, PSOGA achieves over 10% higher utilization on average than the other meta-heuristic-based algorithms. However, PSOGA has lower utilization than the heuristic-based algorithms FF, FFD, and EDF. The main reason is that, to improve the overall acceptance ratio, PSOGA offloads more tasks to ESs and the cloud, as shown in Fig. 1, and processes fewer tasks locally. This reduces utilization by increasing the share of data transfer time, as tasks processed on local devices incur no data transfer. Thus, compared with heuristic-based algorithms, our offloading algorithm improves the acceptance ratio at the cost of resource utilization, which is a worthwhile trade-off, because the acceptance ratio determines not only the income but also the reputation of service providers in cloud computing systems, including DE3C. Among all meta-heuristic-based algorithms, PSOGA achieves the highest utilization with nearly the best acceptance ratio, which verifies the high resource efficiency of our algorithm for task offloading in DE3C systems.
Fig. 2. The computing resource utilizations achieved by various offloading algorithms.
6 Conclusion
In this paper, we focus on the task offloading problem for DE3C. First, we formulate the problem as a BNLP. Then, to solve the problem in polynomial time, we propose a hybrid PSO and GA algorithm (PSOGA), which exploits the iterative optimization framework of PSO and the evolutionary strategy of GA to overcome PSO's tendency to become trapped in local optima and GA's slow convergence. Extensive experiments are conducted, and the results verify the effectiveness and efficiency of our proposed algorithm. In this work, as in many others, we focus on independent tasks because they are prevalent in network computing. As applications grow more complex, for example in the structure of neural network models, many tasks also have interdependencies. Therefore, in future work, we will extend our model and algorithm to offloading with task dependencies, to broaden the applicability of our approach.
References
1. Almutairi, J., Aldossary, M., Alharbi, H.A., Yosuf, B.A., Elmirghani, J.M.H.: Delay-optimal task offloading for UAV-enabled edge-cloud computing systems. IEEE Access 10, 51575–51586 (2022)
2. Alqarni, M.A., Mousa, M.H., Hussein, M.K.: Task offloading using GPU-based particle swarm optimization for high-performance vehicular edge computing. J. King Saud Univ. – Comput. Inf. Sci. 34(10, Part B), 10356–10364 (2022)
3. Amazon Web Services, Inc.: Cloud Computing Services - Amazon Web Services (AWS) (2023). https://aws.amazon.com/
4. Baker, T.: An analysis of EDF schedulability on a multiprocessor. IEEE Trans. Parallel Distrib. Syst. 16(8), 760–768 (2005)
5. Chakraborty, S., Mazumdar, K.: Sustainable task offloading decision using genetic algorithm in sensor mobile edge computing. J. King Saud Univ. – Comput. Inf. Sci. 34(4), 1552–1568 (2022)
6. Du, J., Leung, J.Y.T.: Complexity of scheduling parallel task systems. SIAM J. Discret. Math. 2(4), 473–487 (1989)
7. Hafsi, H., Gharsellaoui, H., Bouamama, S.: Genetically-modified multi-objective particle swarm optimization approach for high-performance computing workflow scheduling. Appl. Soft Comput. 122 (2022)
8. Hao, Y., Wang, Q., Cao, J., Ma, T., Du, J., Zhang, X.: Interval grey number of energy consumption helps task offloading in the mobile environment. ICT Express 9, 1–6 (2022)
9. Hussain, A.A., Al-Turjman, F.: Hybrid genetic algorithm for IOMT-cloud task scheduling. Wirel. Commun. Mob. Comput. 2022 (2022)
10. Li, Y., Zeng, D., Gu, L., Zhu, A., Chen, Q., Yu, S.: PASTO: enabling secure and efficient task offloading in trustZone-enabled edge clouds. IEEE Trans. Veh. Technol., 1–5 (2023)
11. Mahenge, M.P.J., Li, C., Sanga, C.A.: Energy-efficient task offloading strategy in mobile edge computing for resource-intensive mobile applications. Digit. Commun. Netw. 8(6), 1048–1058 (2022)
12. Nwogbaga, N.E., Latip, R., Affendey, L.S., Rahiman, A.R.A.: Attribute reduction based scheduling algorithm with enhanced hybrid genetic algorithm and particle swarm optimization for optimal device selection. J. Cloud Comput. 11, 15 (2022)
13. Sang, Y., Cheng, J., Wang, B., Chen, M.: A three-stage heuristic task scheduling for optimizing the service level agreement satisfaction in device-edge-cloud cooperative computing. PeerJ Comput. Sci. 8(e851), 1–24 (2022)
14. Song, S., Ma, S., Yang, L., Zhao, J., Yang, F., Zhai, L.: Delay-sensitive tasks offloading in multi-access edge computing. Expert Syst. Appl. 198, 116730 (2022)
15. Song, S., Ma, S., Zhao, J., Yang, F., Zhai, L.: Cost-efficient multi-service task offloading scheduling for mobile edge computing. Appl. Intell. 52(4), 4028–4040 (2021). https://doi.org/10.1007/s10489-021-02549-2
16. Tirmazi, M., et al.: Borg: the next generation. In: Proceedings of the Fifteenth European Conference on Computer Systems, EuroSys 2020, Association for Computing Machinery, New York (2020)
17. Wang, B., Cheng, J., Cao, J., Wang, C., Huang, W.: Integer particle swarm optimization based task scheduling for device-edge-cloud cooperative computing to improve SLA satisfaction. PeerJ Comput. Sci. 8(e893), 1–22 (2022)
18. Wang, B., Lv, B., Song, Y.: A hybrid genetic algorithm with integer coding for task offloading in edge-cloud cooperative computing. IAENG Int. J. Comput. Sci. 49(2), 503–510 (2022)
19. Wang, C., Guo, R., Yu, H., Hu, Y., Liu, C., Deng, C.: Task offloading in cloud-edge collaboration-based cyber physical machine tool. Rob. Comput.-Integr. Manuf. 79, 102439 (2023)
20. Wang, H.: Collaborative task offloading strategy of UAV cluster using improved genetic algorithm in mobile edge computing. J. Rob. 2021 (2021)
21. You, Q., Tang, B.: Efficient task offloading using particle swarm optimization algorithm in edge computing for industrial internet of things. J. Cloud Comput. 10(1), 1–11 (2021). https://doi.org/10.1186/s13677-021-00256-4
DBCS-SMJF: Designing a BLDCM Control System for Small Machine Joints Using FOC
Leyi Zhang1, Yingjie Long2, Yingbiao Hu3, and Huinian Li4(B)
1 School of Automation, Guangdong University of Technology, Guangzhou 510006, China
2 School of Computer Science, South China Normal University, Guangzhou 510631, China
3 Faculty of Applied Sciences, Macao Polytechnic University, Macao 999078, China
4 School of Computer Science and Engineering, Macao University of Science and Technology, Macao 999078, China
Abstract. This paper proposes a new motor field-oriented control scheme for robot joint control that uses linear Hall position sensors. The proposed position sensor obtains the rotor position of the motor with high precision at low cost, overcoming the limitations of traditional motor control systems that rely on expensive and bulky encoders. The proposed scheme comprises a Hall sensor module, an electromagnetic detection module, a motor drive module, and a motor, and it has been verified using Ansys magnetic field simulation. Compared with traditional motor Field-Oriented Control (FOC) systems, the proposed system can improve the accuracy and stability of robot joint control while reducing cost, and the linear Hall position sensors make the system more compact and lightweight. Future work may focus on further optimizing and refining the design for real-world applications.
Keywords: SVPWM · FOC · Linear Hall Sensor · Motor · PI Controller · MCU
1 Introduction
The Brushless Direct Current Motor (BLDCM) is widely used in industrial and consumer products, including electronic devices, home appliances, automobiles, aerospace, medical equipment, and construction. In China, the annual power consumption of motors accounts for over 60% of the total power generation, indicating the essential role of motors in daily life. The BLDCM is crucial in many devices, such as drones, household robots, and lightweight linkage manipulators. Therefore, BLDCM vector control plays a vital role in machine joint design. One of the primary challenges in controlling brushless DC motors is maintaining accuracy and stability; torque ripple, which results from the motor's internal structure and characteristics, is the main issue during operation [1]. According to their control mode, brushless motor drives can be categorized into trapezoidal control and sine-wave control. Sine-wave control requires a high-resolution position sensor and a relatively complex controller circuit, but it is highly precise [2]. Therefore, most vector control designs adopt sine-wave control for better accuracy.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 351–359, 2023. https://doi.org/10.1007/978-981-99-4761-4_30
L. Zhang et al.
1.1 Research Status of the Motor Position Sensor
The motor position sensor is primarily used to measure the rotor angle. Currently, the most widely adopted method uses an encoder to measure the motor angle. Although this approach is highly accurate and mature in industrial applications, it suffers from high cost and large volume. To address these issues, this design proposes a low-cost angle sensor based on a linear Hall sensor that fulfills the requirements of rotor angle measurement for small BLDCMs (Fig. 1).
Fig. 1. Framework of the Motor control circuit.
Research on motor position sensors has focused on large servo systems and sensorless FOC designs. Advances in encoder-based systems have led to more specialized sensor types, such as planar magnetic induction angle sensors [3] and shaft angle encoders [4]. MEMS sensor solutions can reduce bulk, but magnetic field measurement based on the gyroscope principle suffers from hysteresis. After careful analysis, linear Hall sensors were chosen as the position sensing solution. Adaptive virtual current sensors [5] can estimate rotor position without physical sensors, but indirectly calculating the rotor position by detecting phase voltages and the DC current [6] is unsuitable for joint motor control.
1.2 Using Advanced Hall Effect for New FOC System Design
This paper proposes a Motor Field-Oriented Control (MFOC) scheme using a linear Hall position sensor for robot joint control. It introduces a new position sensor based on the linear Hall effect and a field-oriented control system for the BLDCM. The paper provides insights into achieving accurate position control of the motor drive in collaborative design. Section 1 covers the background and current status of the FOC vector control drive algorithm and hardware design, position sensors, and the advantages and disadvantages of linear Hall sensors. Section 2 discusses the latest advances in position sensors and
DBCS-SMJF: Designing a BLDCM Control System
FOC technology. Section 3 focuses on the BLDCM position sensor’s relevant technologies, such as the linear Hall sensor’s measurement principle and mathematical model, and presents the motor rotor position sensor and ADC amplification circuit design. Section 4 analyzes the BLDCM motor control system’s design principle, vector control system function, hardware circuit design principles, PCB drawing considerations, and communication instruction set. Finally, Sect. 5 provides the conclusion of the article.
2 Related Work
Since the 1970s, research on robotic arms in China has undergone several decades of development, with significant progress in their design and control. From the first robotic arm developed in Shanghai in 1972 to the "Diao Da Bai" robotic arm designed for the 2022 Beijing Winter Olympics, the technology has gone through several iterations and has been continuously refined.
2.1 Motor Rotor Position Sensor
This paper proposes a low-cost, high-precision Hall effect-based angle measurement solution for small brushless motors. Previous works using linear Hall sensors for angle measurement have limitations. To address this issue, the team used machine learning and neural network algorithms to keep the angle measurement error within 1.7° from standstill to 850 rpm. Similarly, Han et al. [7] proposed a two-step Kalman filter method based on phase compensation, significantly improving angle estimation accuracy.
2.2 The Driving Mode of BLDCM
Nicolae et al. [8] compared sinusoidal modulation (SSM) and space vector modulation (SVM) for generating PWM control signals and found SVM superior. Ciprian et al. [9] analyzed BLDCM performance under a scalar control strategy for automotive braking and steering systems and found it to perform better. Wang et al. [10] proposed a FOC-based BLDCM control system using a switched Hall sensor. Yang et al. [11] improved the speed-loop PID algorithm with a single-neuron neural network. Wang et al. [12] and Yin et al. [13] optimized the controller using RBF neural networks. Zhang et al. [14] used an MCU to implement hardware FOC chip control. Neural networks effectively tackle nonlinear BLDCM control problems, with specialized activation functions enhancing control performance and reducing operational requirements.
3 Proposed Approach
3.1 Design of BLDCM Rotor Position Detection Scheme
Linear Hall Sensor Mathematical Model. The basic principle of linear Hall sensors is the Hall effect, as shown in Fig. 2. A magnetic field of flux density B passes through a thin conductive plate perpendicular to the plate surface
and a current I flows through the conductor. The charge carriers in the plate are subject to the Lorentz force, whose direction can be determined using the left-hand rule. For a carrier of charge q moving with drift velocity v perpendicular to B, the magnitude of the force is F = qvB.
Fig. 2. Hall Effect Schematic.
Under the Lorentz force, free electrons gather on surface S1, while an equal amount of positive charge accumulates on surface S2, forming an electric field in the conductor. As charge accumulates, the electric field strength increases until it balances the Lorentz force. The resulting potential difference across the conductor, the Hall voltage, is given by Formula (1), where n is the carrier concentration, e the elementary charge, and d the plate thickness:
$$V_H = \frac{IB}{ned} \tag{1}$$
Analysis of Measuring Principle of Linear Hall Position Sensor. Fig. 3 illustrates the structure of the angle sensor, in which two magnets, F1 and F2, form a pole pair mounted on the rotor. The Hall sensor is installed on the stator, and its potential difference is measured using an external ADC module. A uniform magnetic field B passes through Hall sensors D1 and D2, generating a Hall voltage. As the rotor rotates, the potential difference changes with the angle between the magnetic field B and Hall sensor D1, so its magnitude can be used to determine the angle. The simulation results indicate that the magnetic field at the measurement center is close to uniform when the area is small. In Fig. 4, the blue-green arrows dominate the central region between the two permanent magnets, indicating a magnetic flux density B of about 0.080506 T, which corresponds to about 800 G and is within the measurement range of the Hall sensor.
3.2 BLDCM Vector Control Scheme Design
Analysis of BLDCM Motor Control System. The motor vector control system first obtains the current of each phase and simplifies the calculation by using the Clarke transformation to obtain Iα and Iβ. The Park transformation then uses the rotor angle to transform these stationary two-axis components into the rotating d–q frame, and PI controllers regulate the output to complete the FOC control process.
Real-time control may require adding a phase-locked loop to reduce control delay. Figure 6 shows the block diagram of the sampling circuit design, which includes a Wi-Fi module, serial port communication, MCU central control, an inverter circuit, an isolation amplifier circuit, and an ADC signal acquisition circuit.
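The Clarke and Park transformations and the PI regulation mentioned above can be sketched as follows. This is a generic textbook formulation, not the authors' firmware: it assumes the amplitude-invariant Clarke transform with balanced phase currents ($i_a + i_b + i_c = 0$), and the PI gains, sampling period, and output limit are hypothetical parameters.

```python
import math
from typing import Tuple


def clarke(i_a: float, i_b: float) -> Tuple[float, float]:
    """Amplitude-invariant Clarke transform, assuming i_a + i_b + i_c = 0."""
    i_alpha = i_a
    i_beta = (i_a + 2.0 * i_b) / math.sqrt(3.0)
    return i_alpha, i_beta


def park(i_alpha: float, i_beta: float, theta: float) -> Tuple[float, float]:
    """Rotate the stationary alpha-beta frame by the electrical angle theta."""
    c, s = math.cos(theta), math.sin(theta)
    return i_alpha * c + i_beta * s, -i_alpha * s + i_beta * c


class PIController:
    """Discrete PI controller with output clamping and integrator anti-windup."""

    def __init__(self, kp: float, ki: float, dt: float, limit: float) -> None:
        self.kp, self.ki, self.dt, self.limit = kp, ki, dt, limit
        self.integral = 0.0

    def update(self, error: float) -> float:
        self.integral += self.ki * error * self.dt
        self.integral = max(-self.limit, min(self.limit, self.integral))
        out = self.kp * error + self.integral
        return max(-self.limit, min(self.limit, out))


# Balanced phase currents of amplitude 1 map to i_d = 1, i_q = 0.
theta = 0.7
i_a, i_b = math.cos(theta), math.cos(theta - 2 * math.pi / 3)
i_d, i_q = park(*clarke(i_a, i_b), theta)
print(i_d, i_q)  # approximately 1.0 and 0.0
```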
Fig. 3. Schematic diagram of the measurement principle of the linear Hall position sensor.
Fig. 4. Simulation results of the permanent magnet structure.
4 Experiments
This paper proposes a simulation-based method for designing BLDCM control algorithms, covering position detection, current control, and position control using the Hall position sensor. A virtual system platform is used to construct a simulation model, and simulation analysis is carried out to verify the effectiveness of the control system. The design of a linear Hall-based BLDCM drive control system is also presented, along with its implementation (Fig. 5).
Fig. 5. Block diagram of linear Hall position sensor.
Fig. 6. Current sampling circuit design block diagram.
4.1 Linear Hall Position Sensor Design Scheme Verification
Matlab and Ansys simulations confirmed the theoretical feasibility of this design. For the experiment, the experimental device was designed on a PCB using existing equipment. To reduce its size, this design mounts the magnets on the transmission part of the machine joint by soldering them to pads. NdFeB is used because it maintains its magnetism at high temperatures up to 500 °C despite the small pad size.
Fig. 7. The physical picture of the machine joint installed on the motor.
After the magnet is soldered, the machine joint and the motor can be connected using M3 screws through four mounting holes. As shown in Fig. 7, the machine joint is installed onto the 2804 motor via the screw holes, and the permanent magnet is soldered onto the gray pad.
Position Sensor Measurement Experiment. ADC data are obtained from the Hall sensor while sensorless FOC drives the motor. The waveform on the oscilloscope resembles a sine wave as the magnet position changes. The angle, and from it the measurement error, is calculated using inverse trigonometric functions, and the resulting curve matches the Matlab simulation. The Hall sensor generates a signal with an amplitude of 130–200 mV and a noise range of 25 mV. To improve accuracy, an average filter is used to preprocess the signal, and a lookup table method is used to calculate the angle. Experimental results show that the actual measurement error is within 3°, indicating the high accuracy of the linear Hall sensor.
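The angle computation with averaging and a lookup table can be sketched as below. The sketch assumes that the two Hall sensors D1 and D2 produce signals proportional to the sine and cosine of the rotor angle around a known bias voltage; the window length, table size, and 1.65 V bias are illustrative assumptions, not values from the paper.

```python
import math
from collections import deque
from typing import Deque


class MovingAverage:
    """Sliding-window mean filter for noisy ADC samples."""

    def __init__(self, window: int) -> None:
        self.samples: Deque[float] = deque(maxlen=window)

    def update(self, sample: float) -> float:
        self.samples.append(sample)
        return sum(self.samples) / len(self.samples)


LUT_SIZE = 256
# atan(r) for r in [0, 1]; an MCU would store this table in flash.
ATAN_LUT = [math.atan(i / (LUT_SIZE - 1)) for i in range(LUT_SIZE)]


def lut_angle_deg(y: float, x: float) -> float:
    """Four-quadrant arctangent in degrees [0, 360) using only the table."""
    ax, ay = abs(x), abs(y)
    if ax == 0.0 and ay == 0.0:
        return 0.0
    ratio = min(ax, ay) / max(ax, ay)            # fold into [0, 1]
    a = ATAN_LUT[round(ratio * (LUT_SIZE - 1))]  # angle in [0, pi/4]
    if ay > ax:
        a = math.pi / 2 - a                      # unfold to [0, pi/2]
    if x < 0:
        a = math.pi - a                          # left half-plane
    if y < 0:
        a = -a                                   # lower half-plane
    return math.degrees(a) % 360.0


def rotor_angle_deg(v_sin: float, v_cos: float,
                    offset_sin: float, offset_cos: float) -> float:
    """Rotor angle from two Hall voltages assumed to vary as sin and cos."""
    return lut_angle_deg(v_sin - offset_sin, v_cos - offset_cos)


# Hypothetical readings: 165 mV amplitude around a 1.65 V bias at 123 degrees.
true_deg = 123.0
v1 = 1.65 + 0.165 * math.sin(math.radians(true_deg))
v2 = 1.65 + 0.165 * math.cos(math.radians(true_deg))
sin_filter, cos_filter = MovingAverage(8), MovingAverage(8)
angle = rotor_angle_deg(sin_filter.update(v1), cos_filter.update(v2), 1.65, 1.65)
print(round(angle, 1))  # close to 123
```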
4.2 Experimental Verification of BLDCM Vector Control System
FOC Control Accuracy Measurement. The MCU adjusts the PWM duty cycle to synthesize a sinusoidal output for smooth motor rotation. The angle measurement and error results are shown in Fig. 8 and Fig. 9. To test the system response, a set of test commands was established for precise time synchronization. The speed response time of the FOC control system designed in this study is within 0.5 s, compared with the result reported by Yu et al. [15]. In this section, the relevant parameters of the experimental platform were measured first, and the measurement data were then transmitted wirelessly to the PC via an ESP32. Using the platform's operating data, the speed and torque responses of the FOC control system driving the pan/tilt motor were analyzed in Matlab, with the results shown in Fig. 10. Before the FOC experimental verification, the measurement accuracy of the linear Hall position sensor was determined first. The position obtained from the linear Hall position sensor was then used as the rotor angle input of the FOC controller, resulting in precise control.
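Sinusoidal duty-cycle generation can be sketched as follows. The optional min–max (zero-sequence) injection is a common carrier-based equivalent of SVPWM, which is listed among the keywords; whether the authors' firmware uses it is not stated, so treat the function and its parameters as hypothetical. Mapping the duty cycles to timer compare registers is omitted.

```python
import math
from typing import Tuple


def phase_duties(theta: float, m: float, min_max_injection: bool = False
                 ) -> Tuple[float, float, float]:
    """Three half-bridge duty cycles in [0, 1] for electrical angle theta (rad).

    m is the modulation index relative to half the DC-link voltage: up to 1.0
    for plain sine PWM, up to 2/sqrt(3) with min-max (SVPWM-equivalent) injection.
    """
    refs = [m * math.sin(theta - k * 2.0 * math.pi / 3.0) for k in range(3)]
    if min_max_injection:
        offset = -(max(refs) + min(refs)) / 2.0  # common-mode (zero-sequence) term
        refs = [v + offset for v in refs]
    a, b, c = (min(1.0, max(0.0, 0.5 + 0.5 * v)) for v in refs)
    return a, b, c


# Sweep one electrical revolution, e.g. to fill a duty-cycle table.
for step in range(0, 360, 60):
    print(step, [round(d, 3) for d in phase_duties(math.radians(step), 0.8)])
```

The injected common-mode term cancels in the line-to-line voltages, so it changes only the phase duty cycles, not the voltages seen by the motor windings.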
Fig. 8. Actual measurement error output of Hall sensor.
Fig. 9. Measurement of rotational angle accuracy with 5° offset each time.
5 Conclusion
This paper presents new methods for designing brushless DC motor (BLDCM) controllers in collaborative machine design. It discusses the pros and cons of various position sensors and proposes solutions for small machine joint design. The study introduces a design approach for a BLDCM vector controller using a linear Hall sensor and a standardized PCB mechanical structure. Experimental results show a 3° measurement error on a 26 mm PCB board. Based on the motor control accuracy requirements of the machine joint, a simple field-oriented control (FOC) drive scheme with a 3.5° error range is proposed. The PCB machine joint structure is validated using only a few PCBs and two permanent magnets. The paper makes two innovative contributions: a position sensor design based on the linear Hall effect and two permanent magnets, and a new experimental platform for designing the entire machine joint using only a few PCBs. The paper provides a highly repeatable design approach for BLDCM drives.
Fig. 10. FOC system motor control response.
DBCS-SMJF: Designing a BLDCM Control System
Acknowledgement. We thank our team members for their valuable support and enthusiasm in this energy-intensive experiment. Our team is dedicated to exploring new technologies, and we take joy in our daily work. Special thanks to our supervisor for his expert guidance and constructive feedback, which was instrumental in shaping this paper. We look forward to applying our learnings to future projects and making even more significant strides in the rapidly evolving technology landscape.
References
1. Younesi, A., Tohidi, S., Feyzi, M.: Improved optimization process for nonlinear model predictive control of PMSM. Iranian J. Electr. Electr. Eng. 14(3), 278 (2018)
2. Chen, S., Liu, G., Zheng, S.: Sensorless control of BLDCM drive for a high-speed maglev blower using low-pass filter. IEEE Trans. Power Electron. 32(11), 8845–8856 (2016)
3. Tang, Q., Wu, L., Chen, X., Peng, D.: An inductive linear displacement sensor based on planar coils. IEEE Sens. J. 18(13), 5256–5264 (2018)
4. Kim, J.H., Hahn, Y.K., Chun, H.: Multiplexed detection of pathogens using magnetic microparticles encoded by magnetic axes. Sens. Actuat. B Chem. 285, 11–16 (2019)
5. Adamczyk, M., Orlowska-Kowalska, T.: Postfault direct field-oriented control of induction motor drive using adaptive virtual current sensor. IEEE Trans. Industr. Electron. 69(4), 3418–3427 (2021)
6. Jia, Z., Zhang, Q., Wang, D.: A sensorless control algorithm for the circular winding brushless DC motor based on phase voltages and DC current detection. IEEE Trans. Industr. Electron. 68(10), 9174–9184 (2020)
7. Han, B., Shi, Y., Li, H.: Position estimation for ultra-low speed gimbal servo system of SGMSCMG based on linear hall sensors. IEEE Sens. J. 20(20), 12174–12183 (2020)
8. Irimia, N.D., Lazar, F.I., Luchian, M.: Comparison between sinusoidal and space vector modulation techniques on the resulting electromagnetic torque ripple produced by a three-phase BLDC motor under field-oriented control. In: 2019 6th International Conference on Control, Decision and Information Technologies (CoDIT), pp. 640–645. IEEE (2019)
9. Bejenar, C., Irimia, N.D., Luchian, M., Lazar, F.I.: Dynamic behavior analysis of a three-phase BLDC motor under scalar control strategy for automotive actuation systems. In: 2020 International Conference on Development and Application Systems (DAS), pp. 7–15. IEEE (2020)
10. Lu, W., et al.: A kind of PWM DC motor speed regulation system based on STM32 with fuzzy-PID dual closed-loop control. In: Intelligent Computing Methodologies: 18th International Conference, ICIC 2022, Xi’an, China, August 7–11, 2022, Proceedings, Part III, pp. 106–113. Springer (2022)
11. Yang, X., Deng, W., Yao, J.: Neural network based output feedback control for DC motors with asymptotic stability. Mech. Syst. Signal Process. 164, 108288 (2022)
12. Wang, W., Pang, H., Li, X., Wu, Y., Song, X.: Research on speed control of permanent magnet synchronous motor based on RBF neural network tuning PID. J. Phys. Confer. Ser. 2264(1), 012018 (2022)
13. Yin, Z., Zhao, H.: Implementation of various neural-network-based adaptive speed PI controllers for dual-three-phase PMSM. In: IECON 2022–48th Annual Conference of the IEEE Industrial Electronics Society, pp. 1–6. IEEE (2022)
14. Huiling, Z.: The design of a dual servo drive and control integrated platform based on Vnet neural network. Secur. Commun. Netw. (2022)
15. Yu, Z., Qin, M., Chen, X., Meng, L., Huang, Q., Fu, C.: Computationally efficient coordinate transformation for field-oriented control using phase shift of linear hall-effect sensor signals. IEEE Trans. Industr. Electron. 67(5), 3442–3451 (2019)
Intelligent Data Analysis and Prediction
A Hybrid Tourism Recommendation System Based on Multi-objective Evolutionary Algorithm and Re-ranking
Ruifen Cao1, Zijue Li1, Pijing Wei3, Ye Tian3, and Chunhou Zheng2(B)
1 School of Computer Science and Technology, Anhui University, Hefei, China
2 School of Artificial Intelligence, Anhui University, Hefei, China
3 Institutes of Physical Science and Information Technology, Anhui University, Hefei, China
Abstract. A tourism recommendation system (TRS) can intelligently recommend the next attractions for users according to their historical visit records. Most tourism recommendation methods consider only accuracy as an evaluation index, which no longer meets the needs of users who increasingly seek diversified attractions. To solve this problem, this study proposes a hybrid recommendation model based on multi-objective optimization and re-ranking (MOEA/D-HRR). To overcome the limited predictive power of a single recommendation algorithm, MOEA/D-HRR mixes three classical techniques by a weighted summation of their recommendation results, and then improves the diversity of the recommendation list with a re-ranking algorithm based on Maximal Marginal Relevance (MMR). In addition, we construct a multi-objective optimization problem in which accuracy and diversity are formulated as two objectives, and use a multi-objective evolutionary algorithm (MOEA) to optimize the weights and the parameter of the re-ranking algorithm. Extensive experiments were carried out on two public tourism datasets and a new tourism dataset. The experimental results show that the model provides users with diversified attractions while maintaining the accuracy of the recommended attractions.
Keywords: Tourism Recommendation System · Multi-objective Optimization · Hybrid Algorithm · Re-ranking Algorithm
1 Introduction
With the improvement of residents’ living standards, tourism has become an important part of people’s leisure and entertainment. However, faced with huge and complex tourism data, people find it difficult to quickly locate the attractions they need. Therefore, to meet the increasingly urgent demand for efficient and intelligent tourism information, an effective tourism recommendation system (TRS) is needed to recommend more desirable attractions to users. A TRS analyzes users’ historical behavior with recommendation methods to provide a recommendation list.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 363–372, 2023. https://doi.org/10.1007/978-981-99-4761-4_31
R. Cao et al.
Traditional recommendation methods can be divided into user-based collaborative filtering methods (User_CF) [1], item-based collaborative filtering methods (Item_CF) [2], content-based methods [3], and knowledge-based methods [4]. Moreover, deep learning is developing rapidly, and many deep-learning-based recommendation methods have been proposed [5]. Compared with traditional methods, deep learning methods can better learn the complex relationships between users and attractions. However, most of these recommendation technologies focus mainly on improving recommendation accuracy. Accuracy-oriented methods tend to recommend attractions that are very similar to those the user has already visited, which can lead to aesthetic fatigue. Recommending more diverse attractions can reveal users’ potential interests. Accuracy and diversity, however, conflict with each other, so balancing them is an important and challenging problem. In addition, although each algorithm has its advantages, it also has disadvantages; for example, content-based recommendation methods suffer from the long-tail problem [6]. Therefore, hybrid recommendation systems that integrate multiple algorithms have been studied and developed to overcome the problems of single algorithms [7]. Hybrid recommendation systems can be divided into three types [8]: the monolithic hybrid paradigm, the parallel hybrid paradigm and the pipelined hybrid paradigm. To solve the above problems, we propose a hybrid recommendation model based on re-ranking and a multi-objective evolutionary algorithm (MOEA/D-HRR). We normalize the attraction ranking lists obtained by three basic recommendation algorithms, and then weight the three lists with appropriate proportions to obtain a hybrid attraction ranking list.
Finally, a re-ranking algorithm rearranges the attraction ranking list, greatly improving the diversity of the recommendation list. To find the best weights, we use a multi-objective evolutionary algorithm (MOEA) to obtain the optimal parameters of the model. To test the performance of the proposed algorithm, MOEA/D-HRR is compared with several advanced recommendation algorithms on two public tourism datasets and one dataset we constructed ourselves. The results show that MOEA/D-HRR can greatly improve the diversity of recommendations while maintaining accuracy.
2 Related Work
Many traditional recommendation methods remain widely used. Zhang [9] made full use of contextual information such as trust relationships between friends, user preferences and geography, and combined this information with collaborative filtering to recommend attractions of interest to visitors. Kang [10] proposed SASRec, a sequential recommendation model based on the attention mechanism of the Transformer, which better captures users’ historical access sequences. Zhang [5] proposed a model that learns the user’s short-term interests with a self-attention mechanism and long-term preferences with metric learning.
A Hybrid Tourism Recommendation System
Since each algorithm has its own advantages and disadvantages, many studies compensate for the shortcomings of individual algorithms by mixing different algorithms. Guo [11] proposed a hybrid recommendation model for label recommendation that extracts text features and visual features simultaneously. Zhou [12] combined two neural network models to exploit their respective strengths, extracting both semantic features and sequence features. Although these hybrid methods combine multiple algorithms from different dimensions to improve performance, the attractions in their results tend to be similar to one another and lack diversity, which greatly weakens the user experience. Most recommendation algorithms focus mainly on improving accuracy and often ignore other indicators, such as diversity, that are also very important to users. Zuo [13] proposed a multi-objective recommendation model that uses an MOEA to optimize the accuracy and diversity of recommendations. Cai [14] proposed a hybrid recommendation algorithm suited to high-dimensional multi-objective problems within the MOEA framework. However, these multi-objective recommendation methods have not been applied to tourism recommendation. Therefore, we propose a hybrid re-ranking recommendation model based on a multi-objective evolutionary algorithm that considers both the accuracy and the diversity of recommendations. We adopt a re-ranking algorithm to dynamically adjust the relationship between the two objectives, so that the final recommendation list has good accuracy and diversity, greatly improving the user’s tourism experience.
3 The Proposed MOEA/D-HRR
In this section, MOEA/D-HRR is described in detail. First, Sect. 3.1 describes the framework of MOEA/D-HRR (Fig. 1). Then, Sect. 3.2 introduces the two objective functions. Finally, Sect. 3.3 introduces the re-ranking algorithm.
Fig. 1. The framework of MOEA/D-HRR.
3.1 The Framework of MOEA/D-HRR
Because the proposed recommendation method is based on the user’s rating ranking of attractions, the user’s historical access sequence is first input into the three baseline recommendation algorithms to obtain their scoring lists. Then, three coefficients α, β and γ are used to linearly weight the scoring lists obtained by the three baseline algorithms, as shown in Eq. 1:

$$P(u,i) = \alpha P_{R_{Ucf}}(u,i) + \beta P_{R_{Icf}}(u,i) + \gamma P_{R_{SASRec}}(u,i) \qquad (1)$$
Here, $R_{Ucf}$, $R_{Icf}$ and $R_{SASRec}$ are the prediction scores of attractions obtained by User_CF [1], Item_CF [2] and SASRec [10], respectively, and P(u, i) is the score of user u for attraction i after combining the three algorithms. Different combinations of α, β and γ yield different hybrid recommendation results, so determining the best values of these three parameters is important; MOEA is used to obtain the optimal α, β and γ. In addition, α, β and γ must satisfy the following condition:

$$\alpha + \beta + \gamma = 1 \qquad (2)$$
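The weighted combination of Eq. 1 under the constraint of Eq. 2 can be sketched as follows. The paper says the three ranking lists are normalized but does not give the normalization; min-max scaling per algorithm is an assumption here, as are the function names and the choice to score an attraction missing from one list as 0 for that algorithm.

```python
def min_max_normalize(scores):
    """Scale one algorithm's {attraction: score} map to [0, 1] (assumed normalization)."""
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        return {item: 0.0 for item in scores}
    return {item: (s - lo) / (hi - lo) for item, s in scores.items()}

def hybrid_scores(ucf, icf, sasrec, alpha, beta, gamma):
    """Eq. (1): P(u, i) = alpha*UCF + beta*ICF + gamma*SASRec for one user."""
    if abs(alpha + beta + gamma - 1.0) > 1e-9:
        raise ValueError("alpha + beta + gamma must equal 1 (Eq. 2)")
    normalized = [min_max_normalize(s) for s in (ucf, icf, sasrec)]
    weights = (alpha, beta, gamma)
    items = set().union(*normalized)  # every attraction scored by any algorithm
    return {
        i: sum(w * n.get(i, 0.0) for w, n in zip(weights, normalized))
        for i in items
    }

# Toy scores for two attractions (illustrative values, not from the paper).
p = hybrid_scores({"a": 1, "b": 3}, {"a": 2, "b": 2}, {"a": 5, "b": 1}, 0.5, 0.3, 0.2)
```

Because the weights sum to 1 and each normalized score lies in [0, 1], the combined score P(u, i) also stays in [0, 1], which keeps it on a comparable scale to the similarity term used later in re-ranking.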
A new ranking list of attractions is obtained by combining the results of the three recommendation algorithms, and this list is then input into the re-ranking algorithm based on Maximal Marginal Relevance (MMR) [16] to further improve diversity. The re-ranking algorithm has a weight parameter λ that balances the accuracy and diversity of the recommendation results. To determine the best λ, we also use the multi-objective evolutionary algorithm. In other words, λ, together with α, β and γ, forms the decision variables of the multi-objective problem. MOEA/D [17] is used to optimize the two objective functions, accuracy and diversity, by adjusting these four parameters; we select MOEA/D as our multi-objective optimization algorithm because it converges faster than other MOEAs. After the best values of the four parameters are obtained, the recommendation system outputs the recommendation list for each user.
3.2 The Two Optimization Objectives
To recommend the best attractions to users, we use two objective functions, accuracy and diversity, to obtain the optimal parameters of our model. Accuracy directly indicates the usefulness of a TRS and measures how well the TRS predicts user behavior. The accuracy objective function is as follows:

$$\mathrm{Accuracy} = \frac{\sum_{u \in User} \left|\{R_u\} \cap \{click\}\right|}{|User|} \qquad (3)$$
where $R_u$ is the recommendation list for user u, and click is the last attraction in the user’s historical visit sequence. The term {R_u} ∩ {click} indicates whether the attraction actually visited by the user is in the recommendation list: it equals 1 if so and 0 otherwise. |User| is the number of users in the whole dataset.
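Eq. 3 amounts to a hit rate over users. A minimal sketch, assuming each user's recommendation list is already ranked and truncated at N (the name `hit_at_n` and the dictionary layout are illustrative, not from the paper):

```python
def hit_at_n(rec_lists, last_visit, n=10):
    """Eq. (3): fraction of users whose held-out last visit appears in the top-n of R_u.

    rec_lists: {user: ranked list of attractions}
    last_visit: {user: the user's last visited attraction (the 'click')}
    """
    if not rec_lists:
        return 0.0
    hits = sum(
        1 for user, rec in rec_lists.items() if last_visit.get(user) in rec[:n]
    )
    return hits / len(rec_lists)

# Toy data: u1's last visit 'b' is recommended, u2's last visit 'd' is not.
score = hit_at_n({"u1": ["a", "b"], "u2": ["c"]}, {"u1": "b", "u2": "d"})  # 0.5
```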
The second optimization objective of the proposed model is diversity. A TRS with high diversity can give users different experiences, discover their potential interests and avoid monotonous recommendation results. The diversity objective function is as follows:

$$\mathrm{Diversity} = \frac{1}{|User|} \sum_{u \in User} \left(1 - \frac{\sum_{m,n \in R_u,\; m \neq n} Sim(m,n)}{\frac{1}{2}|R_u|\left(|R_u|-1\right)}\right) \qquad (4)$$

$$Sim(m,n) = \frac{|U_m \cap U_n|}{\sqrt{|U_m|\,|U_n|}} \qquad (5)$$

where $|R_u|$ is the length of the user’s recommendation list, Sim(m, n) is the similarity between two attractions m and n, and $U_m$ and $U_n$ are the sets of all users who visited m and n, respectively. A high accuracy value indicates that the recommendation list is similar to the user’s historical visit list, whereas high diversity indicates that many attractions in the list are dissimilar to the user’s historical preferences. These two objective functions therefore conflict with each other. The optimization problem is constructed as follows:

$$\max\ \mathrm{Accuracy}, \qquad \max\ \mathrm{Diversity} \qquad (6)$$
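Eqs. 4 and 5 can be sketched as below. Two choices are assumptions not stated in the text: the pair sum is taken over unordered pairs, so that it matches the ½|R_u|(|R_u|−1) denominator, and a list with fewer than two attractions is counted as fully diverse. The helper names are illustrative.

```python
import math
from itertools import combinations

def item_sim(visitors, m, n):
    """Eq. (5): |U_m & U_n| / sqrt(|U_m| |U_n|), with U_x the set of visitors of x."""
    u_m, u_n = visitors.get(m, set()), visitors.get(n, set())
    if not u_m or not u_n:
        return 0.0
    return len(u_m & u_n) / math.sqrt(len(u_m) * len(u_n))

def diversity(rec_lists, visitors):
    """Eq. (4): mean over users of 1 - average pairwise similarity within R_u."""
    if not rec_lists:
        return 0.0
    total = 0.0
    for rec in rec_lists.values():
        pairs = list(combinations(rec, 2))  # |R_u|(|R_u|-1)/2 unordered pairs
        if not pairs:
            total += 1.0  # assumption: lists with < 2 items count as fully diverse
            continue
        mean_sim = sum(item_sim(visitors, m, n) for m, n in pairs) / len(pairs)
        total += 1.0 - mean_sim
    return total / len(rec_lists)

# Toy visitor sets: a and b share one visitor; c shares none.
visitors = {"a": {1, 2}, "b": {2, 3}, "c": {4}}
d = diversity({"u1": ["a", "b", "c"]}, visitors)  # 1 - (0.5 / 3)
```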
3.3 The Re-ranking Algorithm Based on MMR
To further improve the balance between the diversity and accuracy of the hybrid ranking list, a re-ranking algorithm based on MMR is introduced into the proposed model. At each step, the algorithm considers each attraction in the list that has not yet been selected, and calculates both the user’s rating for it and its similarity to the attractions already chosen for the final recommendation list. The chosen attraction should match the user’s preferences as closely as possible while being as different as possible from the attractions already selected. The MMR-based re-ranking algorithm can be summarized as follows:

$$MMR = \arg\max_{i \in R \setminus S} \left[\lambda P(u,i) - (1-\lambda) \max_{j \in S} Sim(i,j)\right] \qquad (7)$$

In the formula, R is the ranking list of recommendation scores input into the algorithm, S is the set of attractions in R already selected for recommendation to the user, R\S is the set of attractions not yet selected, and λ is the weight coefficient that adjusts the trade-off between the accuracy and diversity of the recommendation results. P(u, i) is the user’s rating of attraction i produced by the hybrid algorithm; a higher rating indicates that the attraction is more similar to the user’s historical preferences. After the user’s attraction rating list is obtained by the weighted summation of the three basic recommendation algorithms, the re-ranking algorithm produces the final recommendation list. As a result, the final list is more diverse, instead of consisting only of highly rated attractions that resemble the user’s visit history.
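Eq. 7 is applied greedily: each step adds the argmax attraction to S until the list reaches length k. A minimal sketch, assuming the hybrid scores P(u, i) are given as a dictionary and the similarity is any callable such as Eq. 5 (`mmr_rerank` and its parameters are illustrative names):

```python
def mmr_rerank(scores, sim, lam, k):
    """Greedy MMR re-ranking (Eq. 7).

    scores: {attraction: hybrid score P(u, i)} for one user (the list R)
    sim:    callable sim(i, j) -> similarity of two attractions
    lam:    lambda in [0, 1]; 1.0 ranks purely by score
    k:      length of the final list
    """
    remaining = dict(scores)  # R \ S
    selected = []             # S, in selection order
    while remaining and len(selected) < k:
        def mmr_value(i):
            redundancy = max((sim(i, j) for j in selected), default=0.0)
            return lam * remaining[i] - (1.0 - lam) * redundancy
        best = max(remaining, key=mmr_value)
        selected.append(best)
        del remaining[best]
    return selected

# Toy example: a and b are near-duplicates, c is different but scores lower.
toy_sim = lambda i, j: 1.0 if {i, j} == {"a", "b"} else 0.0
ranked = mmr_rerank({"a": 0.9, "b": 0.85, "c": 0.5}, toy_sim, lam=0.5, k=2)
```

With λ = 0.5 the near-duplicate b is penalized after a is chosen, so c takes the second slot; with λ = 1.0 the list reverts to plain score order. This is the lever that MOEA/D tunes through λ.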
4 Dataset
To verify the generality of the proposed method, we carried out experiments on two datasets, NYC and TKY, from the public Foursquare dataset [18]. NYC and TKY are commonly used in tourism recommendation; they record tourist check-ins in New York and Tokyo from April 2012 to September 2013, including visitor ID, location ID, visit time, location, and so on. Following a previous study [19], we filtered out users with fewer than 10 access records and locations with fewer than 10 visitors. To further demonstrate the effectiveness of MOEA/D-HRR, we also use tourism information for Huangshan City from trip.com, which includes the tourist ID, attraction ID, time, rating, and attraction coordinates. We sort the samples by tourist ID and visit time to obtain each tourist’s visit sequence, and combine these sequences into a tourism dataset named HSC for the experiments. For HSC, we filter out locations visited by fewer than four different users and users who visited fewer than four different locations. The details of the three datasets after processing are shown in Table 1.

Table 1. Three datasets after processing.
Dataset | Users | Attractions | Interactions
HSC | 967 | 100 | 6406
NYC | 829 | 1089 | 40479
TKY | 2214 | 2852 | 329115
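The preprocessing described above (sort by tourist and time, then drop sparse locations and users) can be sketched as follows. Details the text leaves open are assumptions: a single filtering pass that removes rare locations first and then recounts users, and thresholds counted as distinct users or locations. The defaults mirror the HSC thresholds; the function name is illustrative.

```python
from collections import defaultdict

def build_visit_sequences(records, min_users_per_item=4, min_items_per_user=4):
    """Filter check-ins and return {user: [attractions in visit-time order]}.

    records: iterable of (user_id, item_id, timestamp) tuples.
    """
    records = list(records)
    item_users = defaultdict(set)
    for user, item, _ in records:
        item_users[item].add(user)
    # Pass 1 (assumed order): drop locations with too few distinct visitors.
    kept = [r for r in records if len(item_users[r[1]]) >= min_users_per_item]
    user_items = defaultdict(set)
    for user, item, _ in kept:
        user_items[user].add(item)
    # Pass 2: drop users who visited too few distinct locations.
    kept = [r for r in kept if len(user_items[r[0]]) >= min_items_per_user]
    sequences = defaultdict(list)
    for user, item, _ in sorted(kept, key=lambda r: (r[0], r[2])):
        sequences[user].append(item)
    return dict(sequences)

# Toy check-ins with thresholds of 2: location c (one visitor) and then u3 are removed.
toy = [("u1", "b", 5), ("u1", "a", 1), ("u2", "a", 3), ("u2", "b", 2), ("u3", "c", 4)]
seqs = build_visit_sequences(toy, min_users_per_item=2, min_items_per_user=2)
```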
5 Experiments
5.1 Performance Metrics
In the experiments, we use two metrics, Hit@N and Diversity, which are widely used to evaluate recommendation methods. These two metrics correspond to the accuracy and diversity objective functions, respectively, so they are computed by Eq. 3 and Eq. 4. The value of N is set to 10.
5.2 Selection of Baseline Algorithms
To select effective recommendation algorithms as the base algorithms for hybrid recommendation, we first evaluate five recommendation algorithms on the three datasets. Item_CF, User_CF and BPR [20] obtain the user’s preferences from the user’s rating matrix, but none of them uses time-series information. ATTRec [5] and SASRec can learn short-term and long-term information about users from time series. These algorithms can be combined to learn the interactions between users and items from multiple perspectives and improve the quality of recommendations. In the experiments, we use the last attraction each user visited as the test set. For the deep learning models, the penultimate visited attraction is used as the validation set to determine the parameters, and the rest is used as the training set. For the machine learning models, all data except the last visited attraction is used as the training set. The experimental results on the three datasets are shown in Fig. 2. Item_CF, User_CF and SASRec are the three best performers in terms of accuracy on all three datasets, and they also perform well in terms of diversity. Therefore, based on the overall performance of the five single recommendation methods on the three datasets, User_CF, Item_CF and SASRec were selected to construct the proposed MOEA/D-HRR.
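The leave-one-out protocol described above can be sketched per user as follows (the function name and the `with_validation` flag are illustrative; the flag switches between the deep-learning split and the machine-learning split):

```python
def leave_one_out(sequence, with_validation=True):
    """Split one user's time-ordered visit sequence.

    Deep learning models: train = all but the last two visits,
    validation = penultimate visit, test = last visit.
    Machine learning models (with_validation=False): train = all but the last visit.
    """
    if len(sequence) < (3 if with_validation else 2):
        raise ValueError("sequence too short to split")
    test = sequence[-1]
    if with_validation:
        return sequence[:-2], sequence[-2], test
    return sequence[:-1], None, test

train, valid, test = leave_one_out(["a", "b", "c", "d"])  # (['a', 'b'], 'c', 'd')
```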
Fig. 2. Performance of different baseline recommendation algorithms on three data sets
5.3 Comparison with State-of-the-Art Methods
We compared MOEA/D-HRR with three single methods (Item_CF, User_CF and SASRec) and two hybrid methods (Vote-HRM [20] and RVEA-HRM [14]). Because the three single methods focus on accuracy, we select the solution with the highest accuracy from the Pareto solution sets obtained by RVEA-HRM and MOEA/D-HRR for comparison. Table 2 shows the performance of these algorithms on the three datasets. From the experimental results in Table 2, we conclude that the proposed MOEA/D-HRR outperforms all other compared methods on all datasets. Our algorithm shows clear advantages over the baseline methods, with large improvements in both metrics on all three datasets. MOEA/D-HRR also improves considerably on the two hybrid recommendation algorithms. On the HSC dataset, compared with Vote-HRM, MOEA/D-HRR improves Hit@10 by 3.74 and Diversity by 2.42 percentage points; compared with RVEA-HRM, it improves Hit@10 by 3.51 and Diversity by 1.31 percentage points. The other datasets and metrics lead to similar conclusions. MOEA/D-HRR outperforms the other two hybrid algorithms mainly because the hybrid list contains redundant attractions: from the attractions not yet selected, MOEA/D-HRR can choose those the user is likely to prefer while keeping them as different as possible from the attractions already selected. This finer-grained ordering reduces redundant items and improves recommendation performance.

Table 2. Comparison with state-of-the-art methods on three datasets.
Dataset | Algorithm | Hit@10 | Diversity
HSC | Vote-HRM | 0.6316 | 0.7588
HSC | RVEA-HRM | 0.6339 | 0.7699
HSC | MOEA/D-HRR | 0.6690 | 0.7830
NYC | Vote-HRM | 0.3493 | 0.9162
NYC | RVEA-HRM | 0.3534 | 0.9109
NYC | MOEA/D-HRR | 0.3839 | 0.9238
TKY | Vote-HRM | 0.5058 | 0.8799
TKY | RVEA-HRM | 0.5140 | 0.8986
TKY | MOEA/D-HRR | 0.5198 | 0.9044
5.4 The Effectiveness of the Re-ranking Algorithm
To verify the effectiveness of the re-ranking algorithm, a variant of the model without re-ranking, called MOEA/D-HR, was compared with MOEA/D-HRR. Figure 3 shows the Pareto fronts obtained by the two models. The front of MOEA/D-HRR largely dominates the front generated by MOEA/D-HR. Because the re-ranking algorithm reorders the hybrid attraction ranking, it maximizes user utility and thus achieves better diversification of attractions. The re-ranking algorithm is therefore an effective and important part of MOEA/D-HRR.
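Both the front comparison here and the "highest-accuracy solution" choice in Sect. 5.3 rest on Pareto dominance over the two maximized objectives. A minimal sketch, with each solution represented as an (accuracy, diversity) tuple; the function names and the toy points are illustrative, not results from the paper:

```python
def dominates(p, q):
    """True if objective vector p Pareto-dominates q (all objectives maximized)."""
    return all(a >= b for a, b in zip(p, q)) and any(a > b for a, b in zip(p, q))

def pareto_front(points):
    """Return the non-dominated points, preserving input order."""
    return [
        p for i, p in enumerate(points)
        if not any(dominates(q, p) for j, q in enumerate(points) if j != i)
    ]

def most_accurate(points):
    """Pick the front member with the highest accuracy (first objective)."""
    return max(pareto_front(points), key=lambda p: p[0])

# Toy (accuracy, diversity) pairs.
pts = [(0.60, 0.70), (0.65, 0.72), (0.50, 0.80), (0.64, 0.60)]
front = pareto_front(pts)  # [(0.65, 0.72), (0.50, 0.80)]
```

One front "dominating" another, as in Fig. 3, means that most points of the weaker front are dominated by some point of the stronger one under the same test.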
Fig. 3. Pareto fronts of MOEA/D-HR and MOEA/D-HRR on three data sets
6 Conclusion
In this paper, we propose a hybrid recommendation algorithm based on multi-objective optimization and re-ranking, which provides users with more diverse recommendations while maintaining recommendation accuracy. We obtain each user’s recommendation score ranking by combining three baseline algorithms, and then reorder it with the re-ranking algorithm to refine the recommendation list, improving both the accuracy and the diversity of recommendations. The experimental results show that the recommendation lists generated by the proposed method not only match users’ preferences but also offer a variety of recommendations, which greatly improves the user experience.
Acknowledgments. This work was supported in part by the University Synergy Innovation Program of Anhui Province under Grant (GXXT 2021-030), in part by the National Key Research and Development Program of China under Grant (2020YFA0908700) and in part by the National Natural Science Foundation of China under Grants (61873001 and U19A2064).
References
1. Zhao, Z.-D., Shang, M.-S.: User-based collaborative-filtering recommendation algorithms on Hadoop. In: 2010 Third International Conference on Knowledge Discovery and Data Mining. IEEE (2010)
2. Sarwar, B., et al.: Item-based collaborative filtering recommendation algorithms. In: Proceedings of the 10th International Conference on World Wide Web (2001)
3. Pazzani, M.J., Billsus, D.: Content-based recommendation systems. In: Brusilovsky, P., Kobsa, A., Nejdl, W. (eds.) The Adaptive Web. LNCS, vol. 4321, pp. 325–341. Springer, Heidelberg (2007). https://doi.org/10.1007/978-3-540-72079-9_10
4. Burke, R.: Knowledge-based recommender systems. Encyclopedia Libr. Inf. Syst. 69(Suppl. 32), 175–186 (2000)
5. Zhang, S., et al.: Next item recommendation with self-attentive metric learning. In: Thirty-Third AAAI Conference on Artificial Intelligence, vol. 9 (2019)
6. Park, Y.-J., Tuzhilin, A.: The long tail of recommender systems and how to leverage it. In: Proceedings of the 2008 ACM Conference on Recommender Systems (2008)
7. Burke, R.: Hybrid recommender systems: survey and experiments. User Model. User-Adap. Inter. 12, 331–370 (2002)
8. Çano, E., Morisio, M.: Hybrid recommender systems: a systematic literature review. Intell. Data Anal. 21(6), 1487–1524 (2017)
9. Zhang, Z., et al.: A context-awareness personalized tourist attraction recommendation algorithm. Cybern. Inf. Technol. 16(6), 146–159 (2016)
10. Kang, W.C., McAuley, J.: Self-attentive sequential recommendation. In: 2018 IEEE International Conference on Data Mining (ICDM) (2018). https://doi.org/10.1109/ICDM.2018.00035
11. Guo, H., et al.: DeepFM: a factorization-machine based neural network for CTR prediction. arXiv preprint arXiv:1703.04247 (2017)
12. Zhou, X., Li, Y., Liang, W.: CNN-RNN based intelligent recommendation for online medical pre-diagnosis support. IEEE/ACM Trans. Comput. Biol. Bioinform. 18(3), 912–921 (2020)
13. Zuo, Y., et al.: Personalized recommendation based on evolutionary multi-objective optimization [research frontier]. IEEE Comput. Intell. Mag. 10(1), 52–62 (2015)
14. Cai, X., et al.: A hybrid recommendation system with many-objective evolutionary algorithm. Expert Syst. Appl. 159, 113648 (2020)
15. Vargas, S., Castells, P.: Rank and relevance in novelty and diversity metrics for recommender systems. In: Proceedings of the Fifth ACM Conference on Recommender Systems (2011)
16. Zhang, Q., Li, H.: MOEA/D: a multiobjective evolutionary algorithm based on decomposition. IEEE Trans. Evol. Comput. 11(6), 712–731 (2007)
17. Yang, D., et al.: Modeling user activity preference by leveraging user spatial temporal characteristics in LBSNs. IEEE Trans. Syst. Man Cybern. Syst. 45(1), 129–142 (2014)
18. Zhao, K., et al.: Discovering subsequence patterns for next POI recommendation. In: IJCAI (2020)
19. Rendle, S., et al.: BPR: Bayesian personalized ranking from implicit feedback. arXiv preprint arXiv:1205.2618 (2012)
20. Mukherjee, R., Sajja, N., Sen, S.: A movie recommendation system–an application of voting theory in user modeling. User Model. User-Adap. Inter. 13, 5–33 (2003)
Intelligence Evaluation of Music Composition Based on Music Knowledge
Shuo Wang1(B), Yun Tie1, Xiaobing Li2, Xiaoqi Wang3, and Lin Qi1
1 School of Electrical and Information Engineering, Zhengzhou University, Zhengzhou, China
2 Central Conservatory of Music, Beijing, China 3 CSSC Systems Engineering Research Institute, Beijing 100094, China
Abstract. In recent years, owing to the application of neural networks to music generation, large amounts of music have appeared on the Internet, which makes screening music and selecting songs more challenging. Although subjective evaluation is currently the ultimate means of evaluating the beauty of songs, it is limited by human resources and efficiency. Therefore, this paper combines information science and music science to propose an aesthetic computing framework for music compositions on the Internet. First, more accurate basic and advanced musical features are extracted by transcribing the separated musical accompaniment. Next, these features are matched against music rules. Then, appropriate merging rules are used to determine the weights of the elements, so as to compute the aesthetic value of a music composition. This paper presents, for the first time, a musical aesthetic evaluation framework that combines music information dynamics, audio transcription and Zipf’s law. The experimental results show that the proposed objective aesthetic computation method is feasible and effective.
Keywords: Computing aesthetic · Objective evaluation · Musical dynamics · Comentropy · RNN
1 Introduction
The “2021 Chinese Digital Music Annual White Paper” released by the Tencent Music Data Research Institute points out that new songs sprang up like mushrooms in 2021: the total number of new songs reached 1.145 million, a year-on-year increase of 53.1% over 2020, which is equivalent to one new song every 27 s [1]. In daily life, we do not have time to listen carefully to every song, so using time more effectively to find pleasant, aesthetically appealing music has become a widespread need. More importantly, studies have shown that music has a great influence on people’s emotions [2], and music lacking aesthetic appeal can easily make people feel depressed. Therefore, computing the aesthetics of music has clear practical significance. As one of the most emotionally contagious human arts, music has attracted a large number of researchers to study it in detail. The quality of generated music, however, is uneven, and much of it cannot even be called music.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 373–384, 2023. https://doi.org/10.1007/978-981-99-4761-4_32
S. Wang et al.
Therefore, an evaluation system for screening generated music also has practical significance for the field of music generation: it can quickly pick out relatively high-quality music and save people’s time. Unfortunately, there is no unified standard for evaluating generated works in the field of art. This is not surprising: since the development of arts such as music, film and television, and literature, their evaluation systems have always involved subjectivity, and it is difficult to measure them with a strict standard [3]. Because music is a product of human creativity, and humans have no common definition of or consensus on creativity, an evaluation system is difficult to quantify and compute. Therefore, developing an objective evaluation system is worth considering [4]. In 1928, the American mathematician George David Birkhoff proposed the ratio between order and complexity as an aesthetic measure [5], and this idea has continued to develop since. Music dynamics also draws on his ideas. However, computational aesthetics is rarely used in fields other than visual art, graphics and design; in particular, it is rarely used in systems that automatically generate music [6]. Therefore, to study the application of computational aesthetics to music, this paper proposes a computational music aesthetics framework that incorporates music rules to evaluate the aesthetics of music.
2 Related Work
2.1 Music Computational Aesthetic Indicators
Not Including Music Knowledge. Metrics that do not include musical domain knowledge usually compare generated music with the original data and compute mathematical statistical properties to judge whether the music is good or bad. Probabilistic evaluation indicators such as density estimation and probability have been applied successfully in the image field [7] and have also been applied to music [8]. Huang et al. proposed a frame-level evaluation that measures the log-likelihood of model output frames treated as independent Bernoulli variables [9]. Johnson et al. also verified the log-likelihood performance of their prediction task [10], and this measure is likewise considered a meaningful quantitative indicator. For generative models, evaluation indicators usually include the loss, the BLEU score, etc. [11]. However, the loss only represents the difference between the predicted and true values; it does not reflect the quality or beauty of the music. The BLEU score measures the similarity between the test set and the model-generated music, so it cannot serve as an aesthetic evaluation either. These indicators for model-generated music all serve the purpose of prediction.
Intelligence Evaluation of Music Composition
Including Music Knowledge. Metrics that contain music domain knowledge are usually derived from musical concepts and computed on specific musical features, such as pitch, rhythm and chords. Chuan et al. verified the feasibility of musical indicators such as tonic tension and interval frequency, as well as their impact on model performance [12]. Wangtao et al. verified the feasibility of harmony degree and chord accuracy as evaluation indicators [13]. Kristy Choi et al. used factors such as Note Density, Pitch Range, Mean Pitch (MP)/Variation of Pitch (VP), Mean Velocity (MV)/Variation of Velocity (VV) and Mean Duration (MD)/Variation of Duration (VD) as possible evaluation measures [14]; these velocity-related performance attributes also open up additional possibilities. Baiyong et al. used a minimum-distance classifier to detect whether generated music conforms to the characteristics of classical music [15].
2.2 Music Transcription
Music transcription refers to extracting from music a human-readable record of a musical performance, which lists the notes played with their pitches and corresponding time stamps [16]. Although automatic pitch estimation for monophonic signals is considered solved, creating an automatic system that can transcribe polyphonic music without restrictions on the degree of polyphony or the type of instrument remains difficult [17]. Klapuri et al. described the application of signal processing methods to music transcription in detail [18].
3 Method
3.1 Overview
In this section, we describe in detail the full process of music aesthetics calculation, which is used to screen music from a large collection and to compute an aesthetic score for each piece. Figure 1 gives an overview of our framework, which is composed of three parts. The first part is audio transcription, shown in the box. The second part is feature extraction from MIDI, and the last part is aesthetic calculation.
Fig. 1. Overall diagram of aesthetic computing
S. Wang et al.
3.2 Audio Transcription
In this part, we describe our transcription process. First, we use Librosa to extract the Mel-frequency cepstral coefficients (MFCCs) of the music. MFCCs, obtained by a linear transformation of the logarithmic energy spectrum on the nonlinear Mel scale, are features widely used in acoustics. In this process, we set the FFT window size to 2048 and the sampling rate to 16 kHz. We define the logarithmic Mel spectrogram input to the neural network as $M \in \mathbb{R}^{K \times F}$, where K is the number of frames and F is the number of Mel frequency bins; M serves as the input of the neural network. We also define a pitch recognition variable $P_r \in \{0, 1\}^{K \times H}$, where K is the same number of frames and H is the number of pitches. Here H = 88, the number of piano keys (52 white keys and 36 black keys), and each entry of $P_r$ is 0 or 1 to indicate whether the corresponding pitch sounds in the target frame. The network output $P_o$ has the same dimension K × H as $P_r$, and the network learns the mapping from the input to the pitch recognition result. We set the loss function as follows:
$$l_f = \sum_{k=1}^{K} \sum_{h=1}^{H} l_b\big(P_r(k,h), P_o(k,h)\big) \tag{1}$$
Here $l_f$ is the loss function of the LSTM network and $l_b$ is the binary cross-entropy function. Since the recognition result $P_r$ in this paper takes only the values 0 and 1, the task is essentially a binary classification problem, and the binary cross-entropy function trains the network effectively. It is defined as:
$$l_b = -\frac{1}{N}\sum_{i=1}^{N}\Big[y_i \log p(y_i) + (1 - y_i)\log\big(1 - p(y_i)\big)\Big] \tag{2}$$
When the label $y_i$ is 1, a predicted value $p(y_i)$ close to 1 gives a loss close to 0, whereas a predicted value close to 0 gives a very large loss, which matches the behavior of the log function. The playing velocity of a note is related to its loudness, as shown in Eq. (3), where v is the note velocity:
$$r_{dB} = 20\log(m \cdot v + b)^2 \tag{3}$$
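The frame-level loss of Eqs. (1) and (2) can be sketched in NumPy as below. This is an illustration under stated assumptions, not the authors' training code: the helper `piano_roll` and its `fps` (frames per second) parameter are hypothetical, since the text does not give the hop size, and $l_b$ is applied to one element at a time (N = 1) and then summed over the K × H grid.

```python
import numpy as np

def piano_roll(notes, n_frames, fps, n_pitches=88, lowest_midi=21):
    """Binary target P_r in {0,1}^(K x H) built from (onset_s, offset_s, midi_pitch)
    tuples. MIDI note 21 (A0) maps to column 0. fps (frames per second) is a
    hypothetical parameter, because the hop size is not given in the text."""
    roll = np.zeros((n_frames, n_pitches))
    for onset, offset, pitch in notes:
        start = int(round(onset * fps))
        end = max(int(round(offset * fps)), start + 1)  # at least one frame
        roll[start:end, pitch - lowest_midi] = 1.0
    return roll

def bce(target, pred, eps=1e-7):
    """Element-wise binary cross-entropy: Eq. (2) applied to one element (N = 1).
    Predictions are clipped away from 0 and 1 so the logs stay finite."""
    p = np.clip(pred, eps, 1.0 - eps)
    return -(target * np.log(p) + (1.0 - target) * np.log(1.0 - p))

def frame_loss(P_r, P_o):
    """Eq. (1): l_b summed over all K frames and H pitches."""
    if P_r.shape != P_o.shape:
        raise ValueError("P_r and P_o must both have shape (K, H)")
    return float(bce(P_r, P_o).sum())

# Usage: two overlapping notes (C4 and E4) at a hypothetical 100 frames/s.
P_r = piano_roll([(0.0, 0.5, 60), (0.25, 1.0, 64)], n_frames=100, fps=100)
perfect = frame_loss(P_r, P_r)          # close to 0
inverted = frame_loss(P_r, 1.0 - P_r)   # very large
```

The loss is summed rather than averaged so that it matches the double sum in Eq. (1). A training framework's built-in BCE would normally average, which only rescales the loss.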
We determine the loudness of each transcribed note by estimating its velocity. MIDI files represent note velocity with integers 0–127, and larger values correspond to louder notes. Before training, we normalize [0, 127] to [0, 1] and denote the ground-truth and predicted velocities as $G_{vel}$ and $P_{vel}$. The velocity loss is defined as:
$$L_{vel} = \sum_{k=1}^{K}\sum_{h=1}^{H} l_b\big(G_{vel}(k,h), P_{vel}(k,h)\big) \tag{4}$$
Here $l_b$ is the binary cross-entropy function defined above. Finally, the predicted velocity is mapped back from [0, 1] to [0, 127]. We extract features with jSymbolic, which produces the feature matrix we need, including pitch, duration, chords, note density, beat histogram, and mean tempo. The most basic element is pitch, and many of the other features are computed from it.
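The velocity handling around Eqs. (3) and (4) is sketched below. The mapping covers normalization to [0, 1], a soft-target BCE summed over the grid, and restoration to MIDI integers. The constants `m` and `b` of Eq. (3) are not specified in the text, so they are free parameters here, and the base-10 logarithm is an assumption based on the usual decibel convention.

```python
import numpy as np

def velocity_to_db(v, m, b):
    """Eq. (3): r_dB = 20 * log10((m * v + b) ** 2). m and b are not given in
    the text; log base 10 is assumed (the usual dB convention)."""
    return 20.0 * np.log10((m * v + b) ** 2)

def normalize_velocity(vel):
    """Map MIDI velocities 0..127 to [0, 1] before training."""
    return np.asarray(vel, dtype=float) / 127.0

def restore_velocity(pred):
    """Map predictions in [0, 1] back to integer MIDI velocities 0..127."""
    return np.clip(np.rint(np.asarray(pred, dtype=float) * 127.0), 0, 127).astype(int)

def velocity_loss(G_vel, P_vel, eps=1e-7):
    """Eq. (4): element-wise binary cross-entropy with soft targets G_vel,
    summed over the K x H grid. Both arrays hold normalized velocities."""
    p = np.clip(P_vel, eps, 1.0 - eps)
    return float(-(G_vel * np.log(p) + (1.0 - G_vel) * np.log(1.0 - p)).sum())

# Usage with hypothetical velocities on a 2 x 2 (frame x pitch) grid
G = normalize_velocity([[64, 100], [0, 127]])
exact = velocity_loss(G, G)                     # minimal when P_vel == G_vel
off = velocity_loss(G, np.full_like(G, 0.5))    # larger for a poor prediction
```

With soft targets, the BCE is minimized when the prediction equals the target. This is why the same $l_b$ can supervise continuous velocities as well as binary pitch activity.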
3.3 Aesthetic Calculation
In the feature extraction step, we obtain the pitch and duration features of the accompaniment. We now describe the metrics used in this work.
Zipf's law: Proposed by the Harvard University linguist Zipf, Zipf's law is an empirical law in linguistics. Experiments have shown that it also applies to music [19]. In this work, we sort the extracted pitch features $p \in \{p_1, p_2, p_3, \ldots, p_n\}$ by frequency of occurrence, denote the sorted and normalized pitch sequence as $p^*$, and compute:
$$T(p_i) = -\log p_i^* \tag{5}$$
Three characteristics apply here: the smooth distribution of the note sequence, the smoothness of the melody curve, and the entropy maximization of melodic change. Given these, the sequence T extracted from a beautiful piece of music should fit a straight line. We therefore only need to fit the final sequence with the linear fitting routines in the optimize module of the SciPy library:
$$f = aT_i + b \tag{6}$$
The slope a has no direct physical meaning, but extensive experiments have shown that it lies between −0.1 and −0.5 [19]. Fitted lines with the same slope can still differ in how much the original data fluctuate around them. When slopes are equal, we therefore use the coefficient of determination $R^2$ to measure how well the music fits. In this work, as long as the slope is within this reasonable range, no further consideration is needed and we only compare the fitting coefficients, defined as:
$$R^2 = 1 - \frac{SS_{res}}{SS_{tot}} \tag{7}$$
$$SS_{tot} = \sum \left(y_{observe} - \bar{y}\right)^2 \tag{8}$$
$$SS_{res} = \sum \left(y_{observe} - y_{matching}\right)^2 \tag{9}$$
where $\bar{y}$ is the mean of the observed values and $y_{matching}$ is the corresponding value on the fitted line.
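The Zipf's-law indicator of Eqs. (5) to (9) can be sketched as follows. The text does not state which regressor is used in Eq. (6), so this sketch assumes that log p* (that is, −T) is fitted against log rank. It uses `numpy.polyfit` in place of SciPy's optimizer, because both give the same least-squares line. Under this assumption the slope is not directly comparable to the −0.1 to −0.5 range reported in [19], which may use a different setup.

```python
import numpy as np
from collections import Counter

def zipf_fit(pitches):
    """Sketch of Eqs. (5)-(9): returns (slope, R^2) for a list of pitches.

    Assumption: the line is fitted to log p* (= -T) against log rank; the
    text does not state the regressor. Needs at least two distinct
    frequencies, otherwise SS_tot is zero."""
    counts = sorted(Counter(pitches).values(), reverse=True)
    p_star = np.array(counts, dtype=float) / sum(counts)   # sorted, normalized
    T = -np.log(p_star)                                    # Eq. (5)
    x = np.log(np.arange(1, len(T) + 1))                   # log rank
    y = -T
    a, b = np.polyfit(x, y, 1)                             # least-squares line
    y_fit = a * x + b
    ss_tot = np.sum((y - y.mean()) ** 2)                   # Eq. (8)
    ss_res = np.sum((y - y_fit) ** 2)                      # Eq. (9)
    return float(a), float(1.0 - ss_res / ss_tot)          # slope, Eq. (7)

# Counts 12, 6, 4, 3 follow an exact 1/rank law, so the fit is a perfect line.
slope, r2 = zipf_fit([60] * 12 + [62] * 6 + [64] * 4 + [67] * 3)
```

Real melodies deviate from the ideal law, and $R^2$ then drops below 1. This is the quantity the merger rules later pass through sinh.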
Pitch Entropy: From the perspective of Musical Information Dynamics, music is the result of motivated creation, and the creation process must follow arrangement rules that limit how notes change. This does not mean, however, that the notes stay constant. An arrangement needs appropriate tension to express the author's emotion and resonate with the audience, so enough variation is needed. Linking Shannon's information entropy to pitch entropy in music is a good way to evaluate the information content of a song, that is, its musical tension or expressive power. It is defined as:
$$P_e = -\sum_i p_i \log p_i \tag{10}$$
We use the pitch entropy $P_e$ to measure musical tension.
Simultaneous Consonance: Consonant intervals arise from the relationships between the notes of a song, such as the perfect unison, fourth, and octave, major and minor thirds, augmented fifths, and so on. Over the centuries there have been many explanations of consonant intervals. The more generally accepted theories rest on the interaction of amplitude fluctuations between sounds and on the masking effect between sounds. Beating follows from the identity:
$$\cos(2\pi f_1 t) + \cos(2\pi f_2 t) = 2\cos(2\pi \bar{f} t)\cos(2\pi \delta t) \tag{11}$$
where $\bar{f} = (f_1 + f_2)/2$ and $\delta = (f_1 - f_2)/2$. When the two tones have similar frequencies, the listener hears the right-hand form. Slow amplitude fluctuations (0.1–5 Hz) are perceived as a not unpleasant oscillation in loudness. Fast amplitude fluctuations (20–30 Hz), by contrast, are described as rough, and this roughness is thought to be responsible for perceived dissonance. Masking is due to interference between two tones. Whatever the explanation, consonant interval combinations sound pleasant and make music aesthetically pleasing, whereas dissonant intervals make music sound unstable and unpleasant. We compute the proportion of consonant intervals among all intervals by looping over the interval between every two notes:
$$S_r = \frac{p_s}{p_i} \tag{12}$$
where $p_s$ counts the consonant intervals, namely the perfect unison, fourth, fifth, and octave, and $p_i$ counts all intervals. $S_r$ thus measures the beauty between adjacent tones from a microscopic, note-to-note perspective. Combined with the macroscopic Zipf's law, it makes the aesthetic calculation more complete.
Chord Entropy: A chord is a combination of two to seven tones, which may sound successively or simultaneously. Most chords composed of multiple tones sound harmonious and pleasant, which makes music rich in variation and the composition flexible. In this paper, we extract chord types for the subsequent calculation and divide them into eleven types: 1) chords containing only two pitch classes, 2) minor triads, 3) major triads, 4) diminished triads, 5) augmented triads, 6) other triads, 7) minor seventh chords, 8) dominant seventh chords, 9) major seventh chords, 10) other chords with four pitch classes, and 11) complex chords with more than four pitch classes. Vibrato and portamento are not considered in this work.
$$C_e = -\sum_i c_i \log c_i \tag{13}$$
In Eq. (13), the chord entropy $C_e$ is our output, capturing the amplitude and regularity of beating changes among multiple notes. The chord proportion $c_i$ is computed from the proportion of chords of each type in each beat over the continuous performance time:
$$c_i = \sum_{t=1}^{T} \frac{c(i,t)}{t} \tag{14}$$
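Pitch entropy (Eq. 10), the consonance ratio (Eq. 12), and chord entropy (Eq. 13) can be sketched in a few lines, under three assumptions. First, the log base is 2, because the text does not state it. Second, only intervals between adjacent notes are counted, with the perfect unison, fourth, fifth, and octave taken as exactly 0, 5, 7, and 12 semitones. Third, $c_i$ is simplified to the fraction of beats labeled with chord type i rather than following Eq. (14) literally.

```python
import numpy as np
from collections import Counter

# Perfect unison, fourth, fifth and octave, in semitones (Eq. 12)
CONSONANT_INTERVALS = {0, 5, 7, 12}

def entropy(labels, base=2.0):
    """-sum_i p_i log p_i over the empirical distribution of `labels`.
    Used for pitch entropy P_e (Eq. 10) and chord entropy C_e (Eq. 13);
    the log base is an assumption, as the text does not state it."""
    counts = np.array(list(Counter(labels).values()), dtype=float)
    p = counts / counts.sum()
    return float(-(p * np.log(p)).sum() / np.log(base))

def consonance_ratio(pitches):
    """S_r (Eq. 12): consonant intervals / all intervals between adjacent notes."""
    intervals = [abs(b - a) for a, b in zip(pitches, pitches[1:])]
    if not intervals:
        return 0.0
    return sum(i in CONSONANT_INTERVALS for i in intervals) / len(intervals)

# Usage with hypothetical data: MIDI pitches and one chord-type label per beat.
melody = [60, 67, 72, 74]
chords = ["major triad", "major triad", "minor triad", "dominant seventh"]
P_e = entropy(melody)           # four equally frequent pitches -> 2 bits
S_r = consonance_ratio(melody)  # intervals 7, 5, 2 -> 2 of 3 consonant
C_e = entropy(chords)
```

The same entropy helper serves both Eq. (10) and Eq. (13). Only the labels differ: pitches in one case and chord types in the other.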
Tone Span: TS measures the pitch range of a piece of music, which affects human evaluation of music to some extent. For example, a piece with a wide range spanning several registers sounds more striking and makes the music more attractive. In this work, the TS index serves only as a bonus item, because a small TS does not mean that a song is bad. When TS spans three octaves, we add bonus points to the final music score.
$$T_s = p_{max} - p_{min} \tag{15}$$
Pitch Variation: PV measures the number of distinct pitches in a piece of music. Songs with more pitches in the same time span give a more dynamic listening experience. This indicator also serves only as a bonus item and does not affect the final score much. When the number of distinct pitches reaches 60, bonus points are added to the final result. This completes the introduction of the aesthetic indicators used in this work, namely Zipf's law, pitch entropy, the ratio of simultaneous consonance, chord entropy, mean tempo, tone span, and pitch variation, together with their respective uses.
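The two bonus items can be sketched as a single helper. The thresholds follow the text: three octaves are taken as 36 semitones, and 60 distinct pitches is the PV threshold. The size of each bonus is not given, so `bonus` is a hypothetical placeholder.

```python
def tone_span_and_variation(pitches, bonus=1.0, span_threshold=36, variety_threshold=60):
    """Tone span T_s (Eq. 15), pitch variation PV and the resulting extra score.

    span_threshold = 36 semitones (three octaves) and variety_threshold = 60
    distinct pitches follow the text; the size of each bonus is not given,
    so `bonus` is a hypothetical placeholder."""
    ts = max(pitches) - min(pitches)
    pv = len(set(pitches))
    extra = 0.0
    if ts >= span_threshold:
        extra += bonus
    if pv >= variety_threshold:
        extra += bonus
    return ts, pv, extra

# Usage: a hypothetical melody spanning 40 semitones with three distinct pitches
print(tone_span_and_variation([40, 80, 60]))  # (40, 3, 1.0)
```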
3.4 Merger Rules
With the indicators described above, the values of the various music evaluation metrics can be calculated directly from the music. In this subsection, we describe how we aggregate the computed metrics into a final aesthetic score. The first metric is the fitting coefficient $R^2$, which lies in [0, 1] and clusters in [0.7, 1]. Because its values span such a narrow range, the results of different pieces need to be separated more clearly, so we use the sinh function to spread them out. The second step normalizes $P_e$. Extensive experiments show that its value generally lies in [4, 6], though not strictly, so we use $P_e/8$ to quantify pitch entropy so that it can enter the final scoring rule. The third step concerns $S_r$, whose value usually lies in [0.3, 0.6]. To increase discrimination, we transform it as $e^{S_r} - 1$. In the last step, we process the chord entropy $C_e$. Because this work computes chord proportions, the chord entropies of different pieces are hard to tell apart. We therefore take the tenths and hundredths digits of the result and multiply them by ten, that is, $10 \cdot \mathrm{modf}(C_e)$, where modf returns the fractional part. This gives the final formula:
$$Score = A \cdot \left[\sinh(R^2),\ \frac{P_e}{8},\ e^{S_r} - 1,\ 10 \cdot \mathrm{modf}(C_e)\right]^{\top} + Extra\ score \tag{16}$$
Here A is the coefficient matrix of indicator weights. Table 1 gives the relative importance required by the Analytic Hierarchy Process (AHP), obtained by averaging the judgments of multiple experts, and Table 2 gives the analysis results. This completes the merging of the indicators and the calculation of the aesthetic score.
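Eq. (16) with the Table 2 weights can be sketched as a weighted sum. The section does not say how this sum is mapped onto the 100-point scale used in the experiments, so that step is omitted. The outputs of this sketch are therefore not on the same scale as Tables 4 and 6. Python's `math.modf` returns the fractional part first, which is the quantity the text multiplies by ten.

```python
import math

# Indicator weights from Table 2, in the order of Eq. (16)
AHP_WEIGHTS = (0.29284, 0.28969, 0.21961, 0.19786)

def aesthetic_score(r2, pe, sr, ce, extra=0.0, weights=AHP_WEIGHTS):
    """Weighted sum of Eq. (16). Any rescaling to the 100-point scale used in
    the experiments is not specified here, so this sketch omits it."""
    frac_ce, _ = math.modf(ce)                 # fractional part of C_e
    terms = (
        math.sinh(r2),                         # spreads R^2 values near 1
        pe / 8.0,                              # pitch entropy, roughly 4..6
        math.exp(sr) - 1.0,                    # consonance ratio, roughly 0.3..0.6
        10.0 * frac_ce,                        # chord entropy decimals
    )
    return sum(w * t for w, t in zip(weights, terms)) + extra
```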
Table 1. AHP hierarchy analysis data
 | sinh(R²) | Pe/8 | e^Sr − 1 | 10 ∗ modf(Ce)
sinh(R²) | 1 | 1.25 | 1.111 | 1.429
Pe/8 | 0.8 | 1 | 1.429 | 1.667
e^Sr − 1 | 0.9 | 0.7 | 1 | 1
10 ∗ modf(Ce) | 0.7 | 0.6 | 1 | 1
Table 2. AHP analysis results
Item | Feature vector | Weights | Largest eigenvalue | CI value
sinh(R²) | 1.171 | 29.284% | 4.028 | 0.009
Pe/8 | 1.159 | 28.969% | |
e^Sr − 1 | 0.878 | 21.961% | |
10 ∗ modf(Ce) | 0.791 | 19.786% | |
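The AHP weights can be checked from the Table 1 comparison matrix with the standard principal-eigenvector method. The weights are the normalized Perron eigenvector, and the consistency index is CI = (λ_max − n)/(n − 1). This sketch should agree with Table 2 up to rounding. The "Feature vector" column corresponds to n times the normalized weights.

```python
import numpy as np

# Pairwise comparison matrix from Table 1 (rows/columns in the order of Eq. 16)
A = np.array([
    [1.0, 1.25, 1.111, 1.429],
    [0.8, 1.0, 1.429, 1.667],
    [0.9, 0.7, 1.0, 1.0],
    [0.7, 0.6, 1.0, 1.0],
])

def ahp_weights(M):
    """Principal-eigenvector AHP weights, largest eigenvalue and CI."""
    vals, vecs = np.linalg.eig(M)
    k = int(np.argmax(vals.real))          # Perron root of a positive matrix
    lam = float(vals[k].real)
    w = np.abs(vecs[:, k].real)            # eigenvector sign is arbitrary
    w = w / w.sum()
    n = M.shape[0]
    ci = (lam - n) / (n - 1)
    return w, lam, ci

w, lam, ci = ahp_weights(A)
feature_vector = w * A.shape[0]            # comparable to Table 2's first column
```

A CI close to zero means the expert judgments in Table 1 are nearly consistent. Table 2's value of 0.009 reflects this.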
4 Experiment
This section describes the experimental results from two aspects: 1) an example of the aesthetic indicators calculated for sample songs, and 2) a listening evaluation by experts and volunteers. The data come from three main sources. The first is a self-made collection of 200 MIDI pieces, including jazz, blues, Bach piano pieces, Beethoven piano pieces, Mozart piano pieces, and other world-famous works. The second is songs randomly selected from the popular music column of NetEase Cloud Music. The third is the piano dataset GiantMIDI-Piano provided by ByteDance [20].
4.1 Sample Song Example
For the sample songs, we use Jay Chou's "Sunny Day" and Zhong Lifeng's "Blue Traveler". Both songs have also been evaluated by experts: Sunny Day scored higher, and Blue Traveler scored relatively low. We next calculate the aesthetic indicators and the final evaluation result for each song with the method proposed in this paper. Note that the pitch values here are normalized probabilities, that is, the frequency with which each pitch appears in the piece. As shown in Fig. 2, Sunny Day, which scored higher, fits the line better, and its fitting coefficient R² is also higher. Table 3 gives the calculated indicators for each sample. Table 4 gives the scores of the two songs. The first column is calculated by the method proposed in this paper. The second is given by three professionally trained music experts from the Central Conservatory of Music, who scored the two songs based on their educational experience. The full score is 100. The scores of our method are close to those of the experts, which supports the feasibility of our aesthetic calculation method.
Fig. 2. Zipf's law index experiment of Sunny Day and Blue Traveler
Table 3. Sunny Day and Blue Traveler Index Calculation
Music | R² | Pe | Sr | Ce | Ts | Pv
Sunny Day | 0.97 | 5.04 | 0.47 | 4.48 | 83 | 45
Blue Traveler | 0.82 | 4.20 | 0.49 | 4.92 | 40 | 33
Table 4. Sunny Day and Blue Traveler Scores
Song | Our proposed method | Expert
Sunny Day | 81.85 | 87.5
Blue Traveler | 73.15 | 76.3
4.2 Human Ear Evaluation Experiment
In this subsection, we describe two further experiments. First, we selected some pieces from GiantMIDI-Piano, mixed them with the self-made MIDI collection, and provided them to 50 volunteers. The volunteers, from Zhengzhou University and the Central Conservatory of Music, evaluated the songs through a WeChat mini-program. In this experiment, we treated music scoring more than 75 points as aesthetic music and compared this judgment with the volunteers' feedback on whether each song is good. As shown in Table 5, the favorable rate of the proposed calculation method is 65.3%, lower than the volunteers' positive rate of 76.8%, but the method is still helpful for judging whether music is aesthetically pleasing.
Table 5. Human Ear Evaluation
Method | Favorable rate
Our proposed method | 65.3%
Human ear assessment | 76.8%
In addition, we verified the feasibility of the proposed method by randomly selecting 10 pieces from the popular music column of NetEase Cloud Music. For each piece, we compared three scoring approaches: the aesthetic score calculated by the system, expert scoring, and volunteer evaluation. The results are shown in Table 6.
Table 6. Scores for each method
Music | Our proposed method | Expert | Volunteer average score
A Tavernkeeper | 80.85 | 83.2 | 91.5
Hangzhou | 77.25 | 85.4 | 87.5
Jinan Jinan | 78.15 | 81.1 | 83.7
Travel Youth | 77.35 | 78.5 | 81.6
Story Telling Boy | 73.15 | 75.6 | 80.2
Lost | 76.35 | 79.3 | 79.3
Safflower Valley | 68.65 | 76.7 | 78.4
Letter | 71.75 | 74.2 | 77.5
Cheap Grocery Store | 70.45 | 75.9 | 74.6
Letter to Time | 69.15 | 72.3 | 70.4
Figure 3 a) compares the three approaches, with the songs in the same order as in Table 6. The scores of the proposed method follow roughly the same trend as the expert scores and the volunteers' average scores. Only a few songs deviate from this trend, which supports the feasibility and effectiveness of the proposed method. To present the results more clearly, Fig. 3 b) groups the songs into sets of 20 and uses violin plots to compare the proposed method with the expert scores for the two groups.
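The "same trend" observation can be quantified directly from the Table 6 scores. The sketch below computes Pearson correlations between the proposed method and the other two columns, plus the average gap to the experts. The choice of Pearson correlation is ours; the text compares trends only visually.

```python
import numpy as np

# Scores from Table 6, in table order
ours = np.array([80.85, 77.25, 78.15, 77.35, 73.15, 76.35, 68.65, 71.75, 70.45, 69.15])
expert = np.array([83.2, 85.4, 81.1, 78.5, 75.6, 79.3, 76.7, 74.2, 75.9, 72.3])
volunteer = np.array([91.5, 87.5, 83.7, 81.6, 80.2, 79.3, 78.4, 77.5, 74.6, 70.4])

def pearson(x, y):
    """Pearson correlation coefficient between two score columns."""
    return float(np.corrcoef(x, y)[0, 1])

r_expert = pearson(ours, expert)
r_volunteer = pearson(ours, volunteer)
mean_gap = float(expert.mean() - ours.mean())  # experts score higher on average
print(f"r(ours, expert) = {r_expert:.3f}, r(ours, volunteer) = {r_volunteer:.3f}, "
      f"mean expert - ours = {mean_gap:.2f}")
```

A high correlation combined with a consistent positive gap suggests that the method ranks songs much as the experts do, while scoring them somewhat lower.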
Fig. 3. Comparison of scoring methods
5 Conclusion
Music evaluation has lagged behind music generation, largely because evaluating aesthetic compositions often requires subjective human testing. Considering the resources such testing requires, this paper proposes a framework for computing the aesthetics of music compositions. The method combines audio transcription with music information dynamics, computes multiple indicators rich in music domain knowledge from the extracted features, and finally produces an aesthetic score. Comparison with human listening experiments verifies the feasibility of the method. The proposed approach makes it possible to screen music and to assess the quality of generative models.
References
1. 2021 Annual White Paper of Chinese Digital Music. Tencent Music Data Research Institute (2022)
2. Ji, S., Luo, J., Yang, X.: A comprehensive survey on deep music generation: multi-level representations, algorithms, evaluations, and future directions (2020)
3. Tikhonov, A., Yamshchikov, I.P.: Music generation with variational recurrent autoencoder supported by history. arXiv abs/1705.05458 (2020)
4. Yang, L.-C., Lerch, A.: On the evaluation of generative models in music. Neural Comput. Appl. 32(9), 4773–4784 (2018). https://doi.org/10.1007/s00521-018-3849-7
5. Iqbal, A.: Computational aesthetics. Encyclopedia Britannica (2015)
6. Wu, Y.-T., Chen, B., Su, L.: Multi-instrument automatic music transcription with self-attention-based instance segmentation. IEEE/ACM Trans. Audio Speech Lang. Process. 28, 2796–2809 (2020). https://doi.org/10.1109/TASLP.2020.3030482
7. Theis, L., et al.: A note on the evaluation of generative models. CoRR abs/1511.01844 (2016)
8. Dong, H.-W., et al.: MuseGAN: symbolic-domain music generation and accompaniment with multi-track sequential generative adversarial networks. arXiv abs/1709.06298 (2017)
9. Huang, C.-Z.A., et al.: Counterpoint by convolution. In: ISMIR (2017)
10. Johnson, D.D.: Generating polyphonic music using tied parallel networks. In: Correia, J., Ciesielski, V., Liapis, A. (eds.) EvoMUSART 2017. LNCS, vol. 10198, pp. 128–143. Springer, Cham (2017). https://doi.org/10.1007/978-3-319-55750-2_9
11. Gillick, J., et al.: Learning to groove with inverse sequence transformations. In: ICML (2019)
12. Chuan, C., Herremans, D.: Modeling temporal tonal relations in polyphonic music through deep networks with a novel image-based representation. In: AAAI (2018)
13. Jin, C., et al.: A transformer generative adversarial network for multi-track music generation. CAAI Trans. Intell. Technol. 7(3), 369–380 (2022)
14. Hadjeres, G., Pachet, F., Nielsen, F.: DeepBach: a steerable model for Bach chorales generation. In: International Conference on Machine Learning. PMLR (2017)
15. Jin, C., et al.: A style-specific music composition neural network. Neural Process. Lett. 52(3), 1893–1912 (2020)
16. Cemgil, A.T.: Bayesian music transcription (2004)
17. Benetos, E., Dixon, S., Giannoulis, D., Kirchhoff, H., Klapuri, A.: Automatic music transcription: challenges and future directions. J. Intell. Inf. Syst. 41(3), 407–434 (2013). https://doi.org/10.1007/s10844-013-0258-3
18. Klapuri, A., Davy, M. (eds.): Signal processing methods for music transcription (2007)
19. Nan, N., et al.: Common quantitative characteristics of music melodies - pursuing the constrained entropy maximization casually in composition. Sci. China Inf. Sci. 65, 1–3 (2022)
20. Kong, Q., et al.: GiantMIDI-Piano: a large-scale MIDI dataset for classical piano music. Trans. Int. Soc. Music. Inf. Retr. 5, 87–98 (2022)
StockRanker: A Novelty Three-Stage Ranking Model Based on Deep Learning for Stock Selection
Rui Ding, Xinyu Ke, and Shuangyuan Yang (corresponding author)
School of Informatics, Xiamen University, Xiamen 361100, China
Abstract. Using deep learning to identify stocks that will yield higher future returns, and buying them to achieve returns above the market average, is an attractive proposition. However, recent studies have revealed two major challenges in this task. The first is how to effectively extract features from historical stock price data that are useful for stock prediction, and the second is how to rank future stock returns using these features. To address these challenges, we propose StockRanker, an innovative three-stage ranking model for stock selection. In the first stage, we use an autoencoder to extract features embedded in historical stock price data through unsupervised learning. In the second stage, we construct a hypergraph that describes the relationships between stocks based on industry and market capitalization data. We then use hypergraph neural networks (HGNN) to enhance the features obtained in the first stage. In the third stage, we use a listwise ranking method to rank future stock returns based on the stock features obtained earlier. We conducted extensive experiments on real Chinese stock data, and the results show that our model significantly outperforms baseline models in terms of investment returns and ranking performance.
Keywords: Deep learning · Stock prediction · Learning to rank
1 Introduction
Investments in stocks can be classified into two types: passive and active. Passive investments aim to achieve the average market return, while active investments seek to outperform the market by selecting stocks that are likely to perform better in the future and thus obtain excess returns [1]. Conventional approaches to analyzing stocks include fundamental analysis and technical analysis. The former evaluates a stock's potential to generate higher returns by scrutinizing both the outlook of the industry in which the company operates and the company's own financial position [2]. The latter leverages historical stock price data and technical indicators derived from it to identify lucrative investment opportunities [3]. Given the remarkable achievements of deep learning in computer vision and natural language processing, considerable effort has been directed towards using deep learning algorithms to predict future stock returns [4].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 385–396, 2023.
https://doi.org/10.1007/978-981-99-4761-4_33
R. Ding et al.
Many current stock prediction techniques transform the forecasting problem into either a classification problem [5–7] or a regression problem [8, 9]. Classification approaches generally label future stock returns as up or down, and the model predicts the probability that a stock price will rise. Regression methods directly estimate future stock prices or returns. The most significant issue with these methods is that they do not take investment return directly as the optimization objective. As a result, a better-performing model does not necessarily yield higher investment returns, which ultimately makes such models unsuitable for real investment (Table 1). In light of this, stock ranking models have been proposed [10]. Their primary aim is to rank future stock returns so that the model can identify stocks with higher future returns and achieve excess returns. A stock ranking model can thus directly guide actual investment.
Table 1. The classification method measures model performance by accuracy and the regression method by MSE. C1 and R1 outperform C2 and R2, but have lower return on investment.
Stock | S1 | S2 | S3 | S4 | S5 | Model Performance | Top-1 Profit
Return (Ground Truth) | +20 | +15 | -5 | -10 | -20 | |
Classification C1 | 0.6 | 0.7 | 0.8 | 0.2 | 0.4 | 80% | -5
Classification C2 | 0.9 | 0.4 | 0.6 | 0.3 | 0.6 | 40% | +20
Regression R1 | 18 | 20 | -6 | -9 | -22 | 7 | +15
Regression R2 | 60 | 12 | -12 | -15 | -50 | 516.6 | +20
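The numbers in Table 1 can be reproduced with a short script. We assume that the classifiers' accuracy uses a 0.5 threshold on the predicted probability of a rise. Top-1 profit is the return of the single stock each model scores highest, and MSE is computed against the ground-truth returns.

```python
import numpy as np

# Table 1: ground-truth returns of S1..S5 and the four models' outputs
returns = np.array([20.0, 15.0, -5.0, -10.0, -20.0])
classifiers = {"C1": [0.6, 0.7, 0.8, 0.2, 0.4], "C2": [0.9, 0.4, 0.6, 0.3, 0.6]}
regressors = {"R1": [18, 20, -6, -9, -22], "R2": [60, 12, -12, -15, -50]}

def accuracy(prob_up, r, threshold=0.5):
    """Up/down accuracy; predicting 'up' when the probability exceeds 0.5
    is an assumption about how Table 1 was scored."""
    return float(np.mean((np.asarray(prob_up) > threshold) == (r > 0)))

def mse(pred, r):
    """Mean squared error between predicted and true returns."""
    return float(np.mean((np.asarray(pred, dtype=float) - r) ** 2))

def top1_profit(scores, r):
    """Return earned by buying only the stock the model scores highest."""
    return float(r[int(np.argmax(scores))])

for name, p in classifiers.items():
    print(name, accuracy(p, returns), top1_profit(p, returns))
for name, y in regressors.items():
    print(name, mse(y, returns), top1_profit(y, returns))
```

The script makes the point of Table 1 concrete. Better accuracy (C1) or lower MSE (R1) says nothing about which stock ends up at the top of the list, and the top stock is what determines the profit.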
Stock ranking models thus offer a novel framework for applying deep learning to stock prediction. Nevertheless, existing stock ranking models face two pressing challenges. The first is extracting effective features from historical price data for forecasting. Historical price data contain a wealth of features that can be used for stock prediction (e.g., the 5-day percentage change, whose stock-picking power is shown in Fig. 1), and technical analysis provides numerous technical indicators derived from historical prices. Contemporary deep learning methods include feature extraction modules whose ultimate optimization goal is return. However, deep learning models are opaque and the stock market is dynamic and complex, so it is hard to determine whether the features extracted by such a module are actually useful for stock prediction. The second challenge is designing models that accurately predict the order of stock returns from the available features. Prior research has added a pairwise ranking error term to the loss function, which amounts to predicting stock returns while only partially optimizing ranking performance. This approach does not directly predict the order of stock returns.
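To illustrate the "partial optimization" point, here is a generic loss of the kind described: a pointwise MSE on returns plus a pairwise hinge term that penalizes every pair of stocks whose predicted order disagrees with the true order. The exact form and weighting used in prior work may differ; `alpha` is a hypothetical weight. Each unordered pair is counted twice, once as (i, j) and once as (j, i), which only scales the term by a constant.

```python
import numpy as np

def pairwise_rank_loss(y, r, alpha=1.0):
    """Pointwise MSE plus a pairwise hinge term over all ordered pairs (i, j).
    A pair contributes only when the predicted difference y_i - y_j has the
    opposite sign to the true difference r_i - r_j. alpha is hypothetical."""
    y = np.asarray(y, dtype=float)
    r = np.asarray(r, dtype=float)
    dy = y[:, None] - y[None, :]          # predicted differences y_i - y_j
    dr = r[:, None] - r[None, :]          # true differences r_i - r_j
    hinge = np.maximum(0.0, -dy * dr).sum()
    return float(np.mean((y - r) ** 2) + alpha * hinge)

r = np.array([0.2, 0.1, -0.1])
correct_order = pairwise_rank_loss(r + 0.05, r)   # only the MSE part remains
reversed_order = pairwise_rank_loss(-r, r)        # hinge term is now active
```

A prediction can order every pair correctly and still be penalized by the MSE part. This is why such a loss only partially targets the ranking itself.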
Fig. 1. On each trading day, stocks are ranked from high to low by 5-day percentage change and divided into two groups, high and low. Each group is bought with equal weights and held until the next day, i.e., daily rebalancing. The curve labeled market is the average return of the whole market.
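The Fig. 1 procedure can be sketched as a small backtest. The steps are: rank by 5-day percentage change, split the stocks into high and low halves, hold each half with equal weights for one day, and rebalance daily. The prices in the usage example are synthetic random walks, not the paper's data. The function name and the convention of giving an odd middle stock to the low group are our own.

```python
import numpy as np

def momentum_backtest(close, lookback=5):
    """Daily-rebalanced high/low groups ranked by `lookback`-day percentage change.
    close: array (days, stocks) of closing prices. Returns the daily returns of
    the high group, the low group and the equal-weight market, each of length
    days - lookback - 1."""
    close = np.asarray(close, dtype=float)
    change = close[lookback:] / close[:-lookback] - 1.0   # change[t - lookback]: t-5 -> t
    next_ret = close[1:] / close[:-1] - 1.0               # next_ret[t]: t -> t+1
    n = close.shape[1]
    high, low, market = [], [], []
    for t in range(lookback, close.shape[0] - 1):
        order = np.argsort(-change[t - lookback])         # high to low
        r = next_ret[t]
        high.append(r[order[: n // 2]].mean())
        low.append(r[order[n // 2:]].mean())
        market.append(r.mean())
    return np.array(high), np.array(low), np.array(market)

# Usage on synthetic prices: 60 days, 8 stocks
rng = np.random.default_rng(0)
close = 10.0 * np.cumprod(1.0 + 0.01 * rng.standard_normal((60, 8)), axis=0)
high, low, market = momentum_backtest(close)
curves = {k: np.cumprod(1.0 + v) for k, v in {"high": high, "low": low, "market": market}.items()}
```

With an even number of stocks, the market return is exactly the average of the two group returns. Any gap between the high and low curves therefore comes from the ranking feature alone.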
To address these issues, we propose a novel stock selection method termed StockRanker, a three-stage ranking model that uses both unsupervised and supervised learning. First, we use an unsupervised autoencoder to extract features from historical stock price data. During training, the autoencoder uses the same time series as input and output, which can be interpreted as reducing the dimensionality of the historical price data. The extracted features capture only the variations in the historical price data and are not directly tied to future returns. In the second stage, we construct hypergraphs dynamically, encoding stock relationships based on industry and market capitalization, and use HGNN to augment the features from the first stage; these relationships are helpful for stock prediction. In the third stage, we use a listwise ranking approach to rank the stocks and design a novel loss function that directly optimizes the model's overall ranking error. Extensive experiments on real Chinese stock data show that StockRanker outperforms the baseline models in both ranking performance and investment return. Our contributions can be summarized as follows:
• We use unsupervised learning to extract features from historical stock price data, so that the extracted features reflect only the changes in historical stock prices.
• We propose StockRanker, an innovative three-stage ranking model for stock ranking that uses the listwise ranking method, and design a novel listwise ranking loss function specifically for stock ranking.
• We conducted extensive experiments on real Chinese stock data, and the results demonstrate the effectiveness of our model for investment.
2 Related Work
2.1 Stock Prediction Based on Classification and Regression
Various deep learning techniques have been applied to stock prediction, typically by transforming the problem into a classification or regression task. Wang et al. used ARWIA analysis to extract features from historical stock prices and a recurrent neural network to predict medium-term upward and downward trends [6]. Feng et al. combined adversarial training with LSTM to mitigate overfitting when predicting stock price movements [5, 11]. Kim et al. introduced HATS, a hierarchical attention network that selectively aggregates different types of relationship information to build stock features [7]. Inspired by the Discrete Fourier Transform, Zhang et al. proposed the SFM recurrent network to predict future stock prices. It decomposes the network's hidden state into components of different frequencies so that it can capture trading patterns of different frequencies in stock history [8]. Qin et al. proposed DA-RNN, which uses two stages of attention to better exploit long-term dependencies in historical stock prices for price prediction [9].
2.2 Stock Prediction Based on Ranking Models
Feng et al. pointed out that transforming stock prediction into a classification or regression problem does not directly take investment return as the optimization objective, and proposed RSR, the first stock ranking model [10]. RSR uses historical stock prices and stock relation information for stock ranking, and its loss function includes a pairwise ranking error term. RSR outperforms classification and regression models in terms of investment returns, demonstrating the superiority of ranking models for the first time. Inspired by RSR, Sawhney et al. introduced Hawkes attention into the feature extraction module and used a hypergraph and HGNN to model stock relations, outperforming RSR in both ranking performance and investment returns [12].
3 Methodology

3.1 Problem Formulation

Our approach transforms the stock selection task into a listwise stock ranking task. Let $S = \{s_1, s_2, \ldots, s_N\}$ be a set of $N$ stocks, where $s_i$ denotes one of the stocks. On trading day $t$, $s_i$ corresponds to a vector $x_i^t \in \mathbb{R}^D$ of original price features, including opening price, closing price, 5-day average price, etc. ($D$ denotes the number of original price features), a closing price $p_i^t$, and a daily return rate

$$r_i^t = \frac{p_i^t - p_i^{t-1}}{p_i^{t-1}}.$$

If we use $L$ days of historical price data to predict stock returns, then on trading day $t$, $s_i$ corresponds to a sequence of original price feature vectors $X_i^t = \left[x_i^{t-L+1}, x_i^{t-L+2}, \ldots, x_i^t\right] \in \mathbb{R}^{D \times L}$, and $P^t = \left[X_1^t, X_2^t, \ldots, X_N^t\right] \in \mathbb{R}^{N \times D \times L}$ represents the list of sequences of original price feature vectors for all stocks. StockRanker takes $P^t$ as input and outputs the predicted return
StockRanker: A Novelty Three-Stage Ranking Model
list $y^t = \left[y_1^t, y_2^t, \ldots, y_N^t\right]$, and the list of real stock returns is $r^t = \left[r_1^t, r_2^t, \ldots, r_N^t\right]$. If $y_i^t > y_j^t$, the model predicts $r_i^t > r_j^t$. Thus, sorting $y^t$ in descending order yields the predicted relative order of future returns for the set of stocks $S$ on trading day $t$.

3.2 Framework Overview
Fig. 2. The overall architecture of StockRanker.
Figure 2 illustrates the overall architecture of the proposed StockRanker, which consists of three modules.

Feature Extraction Module. We employ an autoencoder to derive latent features from the historical price data of stocks. The autoencoder is an unsupervised learning model built on the encoder-decoder architecture. In our model, both the encoder and the decoder are Multi-Layer Perceptrons (MLPs) composed of linear and LeakyReLU layers. Note that the decoder is used only during the training phase.

Feature Enhancement Module. StockRanker constructs stock relationship hypergraphs from the industry and market capitalization data of stocks, and uses an HGNN to enhance the features extracted in the first step.

Listwise Ranking Module. The listwise ranker is an MLP composed of linear and LeakyReLU layers. It takes the features produced by the feature enhancement module as input and predicts future stock returns. The ranker is optimized with a listwise ranking loss function so that it ranks future stock returns over the whole list.
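The first stage can be illustrated with a minimal NumPy sketch of an MLP autoencoder forward pass and its reconstruction error. It is not the authors' implementation: the single hidden layer per side, the layer sizes, the LeakyReLU slope of 0.01, the random initialization, and the class name `MLPAutoencoder` are all assumptions, and training (for example by gradient descent in a deep learning framework) is omitted. It shows how each window $X_i^t \in \mathbb{R}^{D \times L}$ is flattened into $u_i^t$, encoded into $z_i^t$, and decoded only to compute the reconstruction loss of Sect. 3.3.

```python
import numpy as np


def leaky_relu(x, slope=0.01):
    # identity for positive inputs, small linear slope for negative inputs
    return np.where(x > 0, x, slope * x)


class MLPAutoencoder:
    """Hypothetical MLP autoencoder: encoder f_theta and decoder g_theta'."""

    def __init__(self, in_dim, hidden_dim, latent_dim, seed=0):
        rng = np.random.default_rng(seed)

        def layer(n_in, n_out):
            # small random weights and zero bias for one linear layer
            return rng.normal(0.0, 0.1, size=(n_in, n_out)), np.zeros(n_out)

        self.enc = [layer(in_dim, hidden_dim), layer(hidden_dim, latent_dim)]  # theta
        self.dec = [layer(latent_dim, hidden_dim), layer(hidden_dim, in_dim)]  # theta'

    @staticmethod
    def _mlp(x, layers):
        # linear -> LeakyReLU for hidden layers, plain linear output layer
        for k, (W, b) in enumerate(layers):
            x = x @ W + b
            if k < len(layers) - 1:
                x = leaky_relu(x)
        return x

    def encode(self, u):
        return self._mlp(u, self.enc)  # z = f_theta(u)

    def decode(self, z):
        return self._mlp(z, self.dec)  # v = g_theta'(z), needed only in training

    def reconstruction_loss(self, u):
        # average over samples of the squared L2 error ||u - v||^2
        v = self.decode(self.encode(u))
        return float(np.mean(np.sum((u - v) ** 2, axis=1)))


# Usage with hypothetical sizes: N stocks, D price features, L-day window
N, D, L = 5, 10, 20
P_t = np.random.default_rng(1).random((N, D, L))  # stands in for P^t
U_t = P_t.reshape(N, D * L)                        # row i is u_i^t
ae = MLPAutoencoder(D * L, hidden_dim=64, latent_dim=16)
Z_t = ae.encode(U_t)                               # row i is z_i^t, shape (5, 16)
loss_t = ae.reconstruction_loss(U_t)
```

After training, only `encode` would be kept, which matches the statement that the decoder serves only during training.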
R. Ding et al.
3.3 Feature Extraction

Historical price data of stocks contain features that can be used to predict future returns [13]. The 5-day percentage change is a typical example, and Fig. 1 demonstrates its stock-picking capability. Stock prediction differs from tasks such as image recognition in its higher complexity and stochasticity. Supervised learning methods extract features by directly targeting a stock's future returns, which makes the resulting features hard to interpret. Therefore, the first step of our approach is to train an unsupervised model that extracts features objectively reflecting the historical price changes of a stock.

The autoencoder and its variants were originally designed for dimensionality reduction and feature extraction [14]. An autoencoder consists of an encoder and a decoder. The encoder receives an input vector $x$ and compresses it into a low-dimensional feature vector $z = f_\theta(x)$. The decoder takes $z$ as input and outputs a reconstructed vector $\hat{x} = g_{\theta'}(z)$. Here, $\theta$ and $\theta'$ are the learnable parameters of the encoder and decoder, respectively. The autoencoder is trained to minimize the average reconstruction error, making $\hat{x}$ as close as possible to $x$, by minimizing the mean squared error between the input and output vectors:

$$\theta^*, \theta'^* = \arg\min_{\theta, \theta'} \frac{1}{n}\sum_{i=1}^{n} L\left(x^{(i)}, \hat{x}^{(i)}\right) = \arg\min_{\theta, \theta'} \frac{1}{n}\sum_{i=1}^{n} L\left(x^{(i)}, g_{\theta'}\left(f_\theta\left(x^{(i)}\right)\right)\right),$$

where $L(x, \hat{x}) = \lVert x - \hat{x} \rVert^2$. In StockRanker, we employ a standard autoencoder whose encoder and decoder are both MLPs with LeakyReLU activations. For a stock $s_i$, we flatten its original price feature vector sequence $X_i^t$ on trading day $t$ into a vector $u_i^t$, which is fed into the encoder to extract the feature vector $z_i^t$. Next, $z_i^t$ is passed to the decoder to obtain the reconstruction $v_i^t$. During training, we minimize the reconstruction error so that $z_i^t$ effectively captures the historical price changes of $s_i$. This module is trained independently of the other modules; the trained encoder then serves as the initial feature extractor, and its output is passed to the next module.

3.4 Feature Enhancement

Stock price fluctuations are not mutually independent: related stocks, such as those in the same industry or of similar market capitalization, often rise and fall together [13, 15]. It is therefore meaningful to enhance the features generated in the first stage using the interrelationships among stocks, which our model does with an HGNN. A hypergraph is defined as $G = (V, E, W)$, where $V$ is the set of vertices, $E$ is the set of hyperedges, and $W$ contains the weight of each hyperedge. The hypergraph $G$ is typically represented by the incidence matrix $H \in \mathbb{R}^{|V| \times |E|}$, where $|V|$ is the number
of vertices and $|E|$ is the number of hyperedges. $W$ is a diagonal matrix, as are $D_e$ and $D_v$, which hold the hyperedge degrees and the vertex degrees, respectively. In our case, each stock is a vertex, hyperedges are built from daily stock industry and market capitalization data, and $W$ is set to the identity matrix. The HGNN consists of hypergraph convolution layers, denoted $\mathrm{HConv}(\cdot)$. We use the hypergraph Laplacian defined by Feng et al. [16], $\Delta = I - D_v^{-1/2} H W D_e^{-1} H^{T} D_v^{-1/2}$; the update rule of the hypergraph convolution is then

$$Z^{(l+1)} = \mathrm{HConv}\left(Z^{(l)}, H, P\right) = \mathrm{ELU}\left(D_v^{-1/2} H W D_e^{-1} H^{T} D_v^{-1/2} Z^{(l)} P\right),$$

where $Z^{(l)}$ and $Z^{(l+1)}$ are the input and output of the $l$-th layer. $Z^{(1)}$ is the output of the first stage, i.e., the features extracted by the feature extraction module, $P$ is a learnable parameter matrix, and ELU is the exponential linear unit activation. The final output of the HGNN is the input of the next module.

3.5 Listwise Ranking Method and Listwise Ranking Loss

The third component of the model is the ranking module, which uses a listwise approach [17]: both its input and its output are lists. Let $e_i^t$ denote the final features of stock $s_i$ on trading day $t$, i.e., the output of the feature enhancement module. $E^t = \left[e_1^t, e_2^t, \ldots, e_N^t\right]$ collects the final features of all stocks in $S$ and is fed into the ranking module as a list of feature vectors. The ranking module is an MLP with LeakyReLU activations. It takes $E^t$ as input and outputs the predicted list $y^t = \left[y_1^t, y_2^t, \ldots, y_N^t\right]$ for stock returns. If $y_i^t > y_j^t$, the model predicts $r_i^t > r_j^t$. The key to the listwise ranking method is a listwise loss function that directly optimizes the model's ranking performance on the entire list. We designed a ranking loss function specifically for ranking future stock returns:

$$L\left(y^t, r^t\right) = \frac{1}{N}\sum_{i=1}^{N}\left(\frac{y_i^t - \mu_y^t}{\sigma_y^t} - \frac{r_i^t - \mu_r^t}{\sigma_r^t}\right)^2$$

The loss measures the ranking error between predicted and true stock returns by standardizing the predicted and true returns cross-sectionally and then comparing them. Here, $\mu_y^t$ and $\sigma_y^t$ are the mean and standard deviation of the predicted returns, and $\mu_r^t$ and $\sigma_r^t$ are the mean and standard deviation of the true returns. In the ablation study (Table 3), replacing this loss with a pairwise ranking loss lowers RankIC from 0.0544 to 0.0392.
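A minimal NumPy sketch of this listwise loss for a single trading day follows. It assumes the population standard deviation (`np.std` with its default `ddof=0`) and adds a small `eps` to avoid division by zero; the paper specifies neither, and the function name is hypothetical. For training, the same expression would be written with a deep learning framework's tensor operations so that gradients reach the ranker.

```python
import numpy as np


def listwise_zscore_loss(y_pred, r_true, eps=1e-8):
    """Cross-sectionally standardized squared error for one trading day.

    y_pred, r_true: length-N predicted and realized returns.
    """
    y = np.asarray(y_pred, dtype=float)
    r = np.asarray(r_true, dtype=float)
    y_z = (y - y.mean()) / (y.std() + eps)  # (y_i - mu_y) / sigma_y
    r_z = (r - r.mean()) / (r.std() + eps)  # (r_i - mu_r) / sigma_r
    return float(np.mean((y_z - r_z) ** 2))
```

Because both lists are standardized, the loss does not change when predictions are shifted or multiplied by a positive constant. With population standard deviations it equals $2(1-\rho)$ up to the `eps` term, where $\rho$ is the Pearson correlation between $y^t$ and $r^t$, so minimizing it rewards predictions whose cross-sectional pattern matches realized returns rather than their absolute level.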
4 Experimental Setup

4.1 Dataset

The effectiveness of StockRanker is evaluated on real Chinese stock market data. The data consist of two parts: (1) price data used to extract historical price features, including the daily opening, highest, lowest, and closing prices and the 5-, 10-, 20-, 60-, 120-, and 250-day moving averages; and (2) stock industry classification data and market capitalization data used to construct the stock relation hypergraph. The data span 2016 to 2022 and are sourced from the Wind database. All data were divided into training, validation, and test sets in a ratio of 6:2:2.

4.2 Data Preprocessing

Price Data Standardization. Given the large disparities in price levels across stocks, the price data must be standardized before being fed into the model. Let $hhv$ denote the maximum value in the original price feature vector sequence $X_i^t$ of stock $s_i$ on trading day $t$. The standardized price feature vector sequence is then $\tilde{X}_i^t = X_i^t / hhv$.
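The return definition of Sect. 3.1 and the price standardization above can be sketched in NumPy as follows. The array layout (stocks × features × days), the function names, and the choice of one $hhv$ per stock taken over all features and days of its window are assumptions made for illustration. With the ten price features listed in Sect. 4.1, $D = 10$.

```python
import numpy as np


def daily_returns(close):
    """r_i^t = (p_i^t - p_i^{t-1}) / p_i^{t-1}; close is (N, T), result is (N, T-1)."""
    close = np.asarray(close, dtype=float)
    return (close[:, 1:] - close[:, :-1]) / close[:, :-1]


def build_window(features, t, L):
    """P^t = [X_1^t, ..., X_N^t] with X_i^t = [x_i^{t-L+1}, ..., x_i^t].

    features has shape (N, D, T) with trading days on the last axis (0-based t);
    the result has shape (N, D, L).
    """
    if t - L + 1 < 0 or t >= features.shape[2]:
        raise ValueError("window does not fit inside the available history")
    return features[:, :, t - L + 1 : t + 1]


def standardize_by_hhv(P_t):
    """Divide each stock's window by its own maximum hhv (assumes positive prices)."""
    hhv = P_t.max(axis=(1, 2), keepdims=True)  # one hhv per stock, shape (N, 1, 1)
    return P_t / hhv
```

Since all ten features are price levels in the same units, dividing by a single per-stock maximum keeps their relative magnitudes while putting every stock on a comparable scale.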
Stock Hypergraph Construction. Methods that use industry classification data to aid stock prediction typically assign each stock to a fixed industry. In reality, however, listed companies may change their primary business and hence their industry. We therefore update the industry and market capitalization data of stocks daily and use them to construct the stock hypergraph.

Candidate Stock Pool Construction. To better reflect real-world trading and ensure experimental validity, we filter certain stocks out of the daily candidate pool: stocks listed for less than one year, stocks suspended or placed under special treatment in the past 60 trading days, stocks in the bottom 5% by trading volume, stocks that have hit their daily price limit, and stocks with more than 10% missing values.
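The daily stock hypergraph and one hypergraph convolution layer from Sect. 3.4 can be sketched in NumPy as below. The paper states only that hyperedges come from industry and market capitalization data; grouping market capitalization into equal-frequency buckets (`cap_buckets`, `n_buckets`) is a hypothetical choice, as are all function names. In this construction, with $W = I$, every stock belongs to exactly one industry hyperedge and one size hyperedge, so no vertex degree is zero.

```python
import numpy as np


def cap_buckets(market_cap, n_buckets=3):
    """Hypothetical helper: equal-frequency market-cap buckets 0..n_buckets-1."""
    ranks = np.argsort(np.argsort(market_cap))  # 0 = smallest company
    return (ranks * n_buckets // len(market_cap)).tolist()


def build_incidence(industry, size_bucket):
    """Incidence matrix H (N x |E|): one hyperedge per industry and per size bucket."""
    edges = [("industry", g) for g in sorted(set(industry))]
    edges += [("size", b) for b in sorted(set(size_bucket))]
    H = np.zeros((len(industry), len(edges)))
    for e, (kind, label) in enumerate(edges):
        labels = industry if kind == "industry" else size_bucket
        for i, value in enumerate(labels):
            if value == label:
                H[i, e] = 1.0
    return H


def elu(x, alpha=1.0):
    # np.minimum keeps exp() from overflowing on large positive inputs
    return np.where(x > 0, x, alpha * (np.exp(np.minimum(x, 0.0)) - 1.0))


def hgnn_conv(Z, H, P, edge_weights=None):
    """One layer: ELU(Dv^-1/2 H W De^-1 H^T Dv^-1/2 Z P), with W = I by default."""
    w = np.ones(H.shape[1]) if edge_weights is None else np.asarray(edge_weights, float)
    dv = H @ w          # vertex degree: total weight of incident hyperedges
    de = H.sum(axis=0)  # hyperedge degree: number of member vertices
    Dv_inv_sqrt = np.diag(1.0 / np.sqrt(dv))
    A = Dv_inv_sqrt @ H @ np.diag(w) @ np.diag(1.0 / de) @ H.T @ Dv_inv_sqrt
    return elu(A @ Z @ P)


# Usage with hypothetical data for four stocks on one trading day
industry = ["bank", "bank", "tech", "tech"]
size_labels = cap_buckets([1.0, 5.0, 2.0, 8.0], n_buckets=2)  # -> [0, 1, 0, 1]
H = build_incidence(industry, size_labels)                      # shape (4, 4)
Z1 = np.random.default_rng(0).normal(size=(4, 16))  # stands in for Z^(1)
P = np.random.default_rng(1).normal(size=(16, 8))   # learnable in a real model
Z2 = hgnn_conv(Z1, H, P)                            # shape (4, 8)
```

In a full model, `P` would be learned and several such layers could be stacked on the encoder output.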
4.3 Evaluation Metrics

Excess Return. The purpose of the stock selection task is to identify stocks with excess returns in order to outperform the market. Each trading day, we rank the stocks from high to low by the model's predictions and divide them into two groups, high and low. We then buy the stocks in the high group with equal weights and hold them until the next trading day, rebalancing daily. Our evaluation metrics are the information ratio (IR) and the annualized excess return (AER), both classic measures of the performance of active management strategies relative to a market benchmark:

$$IR = \frac{\bar{r}_p}{\sigma_p}, \qquad AER = 250 \cdot \bar{r}_p,$$

$$\bar{r}_p = \frac{1}{n}\sum_{t=1}^{n}\left(r_p^t - r_m^t\right), \qquad \sigma_p = \sqrt{\frac{1}{n-1}\sum_{t=1}^{n}\left(r_p^t - r_m^t - \bar{r}_p\right)^2},$$

where $r_p^t$ and $r_m^t$ are the model portfolio's return and the market benchmark return on trading day $t$, and $n$ is the number of trading days in the evaluation period. $\bar{r}_p$ and $\sigma_p$ are the mean and standard deviation of the portfolio's daily excess returns, respectively.
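The two excess-return metrics can be computed from daily portfolio and benchmark returns with plain Python, as sketched below. The function name is hypothetical; `days_per_year=250` follows the annualization factor in the AER formula, and $\sigma_p$ is the sample standard deviation (divisor $n-1$) of the daily excess returns.

```python
import math


def excess_return_metrics(portfolio_returns, market_returns, days_per_year=250):
    """Return (mean daily excess return, sigma_p, IR, AER) from daily returns."""
    excess = [rp - rm for rp, rm in zip(portfolio_returns, market_returns)]
    n = len(excess)
    if n < 2:
        raise ValueError("need at least two trading days")
    mean = sum(excess) / n                                              # r_p bar
    sigma = math.sqrt(sum((e - mean) ** 2 for e in excess) / (n - 1))  # sigma_p
    ir = mean / sigma
    aer = days_per_year * mean
    return mean, sigma, ir, aer


# Hypothetical three-day example
mean, sigma, ir, aer = excess_return_metrics([0.02, 0.00, 0.01], [0.01, 0.00, 0.00])
```

Computed this way, IR is a daily ratio and is not annualized, in line with the formula above.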
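Ranking Performance. We use the Rank Information Coefficient (RankIC), one of the most widely used ranking metrics in finance, to assess the model's ability to rank future stock returns:

$$RankIC = \frac{1}{T}\sum_{t=1}^{T}\frac{\mathrm{cov}\left(R(y^t), R(r^t)\right)}{\sigma_{R(y^t)}\,\sigma_{R(r^t)}},$$

where $R(y^t)$ and $R(r^t)$ are the rankings of $y^t$ and $r^t$, $\sigma_{R(y^t)}$ and $\sigma_{R(r^t)}$ are the standard deviations of these rankings, $\mathrm{cov}\left(R(y^t), R(r^t)\right)$ is their covariance, and $T$ is the number of trading days evaluated.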
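RankIC can be computed in plain Python as the daily Pearson correlation between the rankings of predictions and realized returns, averaged over days (equivalently, an average of daily Spearman correlations). Assigning tied values their average rank is an assumption, since the paper does not say how ties are handled; the function names are hypothetical.

```python
import math


def average_ranks(values):
    """R(.): 1-based ranks, where tied values share their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # extend j over the run of values tied with values[order[i]]
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        for k in range(i, j + 1):
            ranks[order[k]] = (i + j) / 2 + 1
        i = j + 1
    return ranks


def pearson(a, b):
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    norm = math.sqrt(sum((x - ma) ** 2 for x in a) * sum((y - mb) ** 2 for y in b))
    return cov / norm  # the 1/n (or 1/(n-1)) factors cancel


def rank_ic(predictions_by_day, returns_by_day):
    """Mean over T days of corr(R(y^t), R(r^t))."""
    daily = [pearson(average_ranks(y), average_ranks(r))
             for y, r in zip(predictions_by_day, returns_by_day)]
    return sum(daily) / len(daily)
```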
5 Results and Analysis

5.1 Comparison with Baselines

Table 2 compares StockRanker with several baseline models, grouped as CLF (classification), REG (regression), and RAN (ranking), in terms of ranking performance and investment returns. Ranking models outperform classification and regression models in both ranking performance and investment returns, and classification models perform slightly better than regression models. These results support casting stock prediction as a learning-to-rank problem. In addition, models that exploit relationships between stocks outperform the other models, indicating that inter-stock relationship information is useful for predicting the order of future stock returns, which is consistent with real-world observations. StockRanker achieves a RankIC of 0.0544 and an IR of 0.287, which are 13.1% and 21.6% better than the second-best model, respectively. This demonstrates the effectiveness of the unsupervised feature extraction approach and the listwise ranking method employed in our model. Figure 3 depicts the cumulative excess returns of all models on the test set. The excess returns of the ranking models are smoother than those of the other models, and StockRanker achieves the highest cumulative excess return.
Table 2. Results of StockRanker versus baselines.

| Type | Method | Description | RankIC | AER | IR |
|---|---|---|---|---|---|
| CLF | A-LSTM [5] | LSTM + Adversarial training | 0.0252 | 0.083 | 0.124 |
| CLF | HATS [7] | Construct stock multigraphs by hierarchical graph attention | 0.0283 | 0.092 | 0.116 |
| REG | SFM [8] | LSTM + DFT-based hidden state decomposition | 0.0078 | 0.031 | 0.053 |
| REG | DA-RNN [9] | RNN + Dual-stage attentions | 0.0178 | 0.061 | 0.095 |
| RAN | RankLSTM [10] | LSTM + Loss function with pairwise ranking error term | 0.0302 | 0.105 | 0.168 |
| RAN | RSR-E [10] | Temporal GCN + Relation strength based on feature similarity | 0.0407 | 0.118 | 0.181 |
| RAN | RSR-I [10] | Temporal GCN + Relation strength based on neural net | 0.0426 | 0.124 | 0.207 |
| RAN | STHAN-SR [12] | Hawkes Attention + HGNN | 0.0481 | 0.135 | 0.256 |
| RAN | StockRanker (Ours) | AutoEncoder + HGNN + Listwise Ranker | 0.0544 | 0.167 | 0.257 |
Fig. 3. Comparison of cumulative excess returns for all models
5.2 Ablation Study

We conducted ablation experiments to evaluate the contribution of each component of StockRanker; Table 3 shows the results. Replacing the unsupervised feature extraction module with a simple MLP significantly degrades performance (RankIC falls from 0.0544 to 0.0414), which demonstrates the effectiveness of our autoencoder-based unsupervised feature extraction. Changing the loss function to one with a pairwise ranking error term, which means the model is no longer a listwise ranking model, greatly reduces performance (RankIC 0.0392), showing the effectiveness of the listwise ranking method for stock ranking. Removing the feature enhancement module based on stock relationships also lowers performance (RankIC 0.0493), so this module contributes to the model's effectiveness as well.

Table 3. Results of ablation experiments performed on StockRanker

| Model | Description | RankIC | AER | IR |
|---|---|---|---|---|
| StockRanker-AE | Replace AutoEncoder-based feature extraction module with MLP | 0.0414 | 0.112 | 0.173 |
| StockRanker-SR | Remove feature enhancement module based on stock relationships | 0.0493 | 0.147 | 0.256 |
| StockRanker-PRL | Replace listwise ranking loss with pairwise ranking loss | 0.0392 | 0.125 | 0.163 |
| StockRanker | Complete StockRanker | 0.0544 | 0.167 | 0.287 |
6 Conclusion and Future Work

This paper presents a novel approach to stock selection that transforms stock prediction into a ranking task using the proposed three-stage stock ranking model, StockRanker. An autoencoder, an unsupervised learning model, is used to extract features embedded in historical stock prices that can be leveraged for stock prediction. The model incorporates a listwise ranking method and a listwise ranking loss function designed for stock selection. Furthermore, inter-stock relationship information is incorporated to enhance the stock selection capability of the model. The efficacy of StockRanker is validated on real Chinese stock market data by assessing both its ability to rank future stock returns and its investment returns. Ablation experiments confirm the necessity of each component of the model. StockRanker is an innovative framework for stock selection consisting of three components, each of which may employ a different neural network. In future work, we intend to replace the existing components with other neural networks, for example by applying a sequence model such as LSTM in the feature extraction module, to examine whether the stock selection effectiveness of the model can be further improved.

Acknowledgements. This work was supported by the Natural Science Foundation of Fujian Province of China (No. 2022J01003).
References

1. Bodie, Z., Kane, A.: Investments (2020)
2. Nti, I.K., Adekoya, A.F., Weyori, B.A.: A systematic review of fundamental and technical analysis of stock market predictions. Artif. Intell. Rev. 53(4), 3007–3057 (2019). https://doi.org/10.1007/s10462-019-09754-z
3. Nazário, R.T.F., Silva, J.L., Sobreiro, V.A., Kimura, H.: A literature review of technical analysis on stock markets. Q. Rev. Econ. Finan. 66, 115–126 (2017)
4. Jiang, W.: Applications of deep learning in stock market prediction: recent progress. Expert Syst. Appl. 184, 115537 (2021)
5. Feng, F., Chen, H., He, X., Ding, J., Sun, M., Chua, T.S.: Enhancing stock movement prediction with adversarial training. In: IJCAI (2019)
6. Wang, J.H., Leu, J.Y.: Stock market trend prediction using ARIMA-based neural networks. In: Proceedings of International Conference on Neural Networks (ICNN 1996), vol. 4, pp. 2160–2165 (1996)
7. Kim, R., So, C.H., Jeong, M., Lee, S., Kim, J., Kang, J.: HATS: a hierarchical graph attention network for stock movement prediction. arXiv:1908.07999 (2019)
8. Zhang, L., Aggarwal, C.C., Qi, G.J.: Stock price prediction via discovering multi-frequency trading patterns. In: Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (2017)
9. Qin, Y., Song, D., Chen, H., Cheng, W., Jiang, G., Cottrell, G.: A dual-stage attention-based recurrent neural network for time series prediction. arXiv:1704.02971 (2017)
10. Feng, F., He, X., Wang, X., Luo, C., Liu, Y., Chua, T.S.: Temporal relational ranking for stock prediction. ACM Trans. Inf. Syst. (TOIS) 37, 1–30 (2018)
11. Hochreiter, S., Schmidhuber, J.: Long short-term memory. Neural Comput. 9, 1735–1780 (1997)
12. Sawhney, R., Agarwal, S., Wadhwa, A., Derr, T., Shah, R.R.: Stock selection via spatiotemporal hypergraph attention network: a learning to rank approach. In: AAAI Conference on Artificial Intelligence (2021)
13. Jeanblanc, M., Yor, M., Chesney, M.: Mathematical methods for financial markets (2009)
14. Vincent, P., Larochelle, H., Bengio, Y., Manzagol, P.A.: Extracting and composing robust features with denoising autoencoders. In: International Conference on Machine Learning (2008)
15. Hou, K., Dijk, M.A.V.: Resurrecting the size effect: firm size, profitability shocks, and expected stock returns. Econ. Appl. Econ. Model. J. (2018)
16. Feng, Y., You, H., Zhang, Z., Ji, R., Gao, Y.: Hypergraph neural networks. In: AAAI Conference on Artificial Intelligence (2018)
17. Xia, F., Liu, T.Y., Wang, J., Zhang, W., Li, H.: Listwise approach to learning to rank: theory and algorithm. In: International Conference on Machine Learning (2008)
Design and Application of Mapping Model for Font Recommendation System Based on Contents Emotion Analysis

Young Seo Ji1(B) and Soon bum Lim1,2

1 Sookmyung Women's University, Seoul 04310, Korea
{jyseo0102,sblim}@sookmyung.ac.ker
2 Research Institute of IT Convergence, Seoul 04310, Korea
Abstract. Fonts are an important tool for compensating for the absence of the nonverbal and paralinguistic cues that are present in real-world, face-to-face situations. However, selecting an appropriate font relies heavily on aesthetic sense and experience, which makes it difficult for members of the general public who are not proficient with fonts. Therefore, in this study, we implement a service that automatically recommends fonts matching the message when content such as facial expressions and sentences is entered. To this end, we designed an experiment to interpret the emotions associated with different fonts and a model to map actual content to fonts. To identify the emotions of fonts, we selected emotion keywords, verified their relevance to fonts, and quantified the fonts' emotional impressions. Since the emotional criteria for content extracted by a deep learning emotion analysis model differed from those for fonts, we devised a new mapping method: a mapping model that calculates the correlation between the two sets of emotional criteria and determines similarity. We applied this model to relate the emotions of content to fonts and developed a system that recommends fonts.

Keywords: Font Recommendation · Emotion Analysis · User Experiment
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 397–408, 2023. https://doi.org/10.1007/978-981-99-4761-4_34

1 Introduction

Using a font in a text environment not only enhances design value but also affects how the message is delivered. Unlike face-to-face interactions, text messages cannot convey non-verbal expressions, and fonts can compensate for the absence of non-verbal communication in such an environment. Therefore, graphic designers use fonts to convey and reinforce their intended messages correctly and to establish an overall emotional tone [1]. This topic is being actively investigated in various fields. According to a study on user awareness of applying fonts to mobile text messengers, more than 70% of respondents stated that the effectiveness of emotional communication could be increased by applying appropriate fonts and that it is better to use different fonts depending on the content [2]. Using fonts that represent brand identity in marketing
Y. S. Ji and S. Lim
doubled the rate at which the brand was chosen compared with when such fonts were not used, and 75% of consumers tended to choose products with more appropriate fonts over those with inappropriate fonts [3]. In the e-book field, users preferred different fonts depending on the type of book [4]. Thus, font selection is an essential factor in conveying the meaning of a text. However, selecting an appropriate font for each situation is difficult for ordinary people who lack design expertise and a sense of form, because identifying the impressions of fonts from one's own experience and relating them to one's purpose is very complicated. Users know their purposes but not the criteria for deciding which fonts suit which purpose, so font selection can take considerable time and inappropriate fonts may end up being selected. With existing font search services, users must browse through all fonts to find the desired one, without support for other criteria. Google Fonts, a representative font search service, sorts fonts by popularity and trend and categorizes them only by basic font properties: Serif, Sans-Serif, Display, Handwriting, and Monospace. Users must spend a long time searching, and they cannot confirm whether the mood they expect a font to evoke matches the mood it commonly evokes, so the font may give a different impression than expected. Another font search site, My Font, supports searching by tags, but this is of little help to users who do not know which tags fit their intended purpose. To address these problems, we designed a prototype font recommendation system. Since the system recommends fonts based on emotions, we first identified the emotions conveyed by fonts.
After selecting 19 keywords that can describe fonts and 26 representative fonts for the experiment, we asked participants to classify fonts as having high, medium, or low relevance to each keyword. This experiment confirmed the relationship between fonts and keywords, and the emotional impressions of fonts were quantified through the keywords. We then designed an interface that lets users select keywords according to context using these keyword values and evaluated it. The response was positive compared with existing search services, but many users had difficulty choosing keywords that fit their situation. To solve this problem, we designed a mapping model that connects the emotions extracted from content with fonts, so that fonts can be recommended automatically according to the emotion of the content. In this study, we design a new mapping method between fonts and content for recommendation and build a font recommendation system by applying this calculation method.
2 Related Works

Various studies have examined the emotions conveyed by fonts. For example, a study on English fonts quantified fonts by associating them with vocabulary keywords [5]. In that study, keywords suitable for describing English fonts were selected, and keyword values were calculated for each font. Vocabulary used to describe fonts in previous font studies was collected as candidate keywords, and the relationship between fonts and keywords was then checked through crowdsourcing. The keywords included both concrete characteristics such as “thin” and
Design and Application of Mapping Model
“angular” and nebulous characteristics such as “friendly” and “sloppy”. Thirty-one keywords were selected, along with six common properties of English fonts (capitals, cursive, display, italic, monospace, and serif), for a total of 37 keywords. In each task, a font and two keywords were presented to a worker, who selected the keyword more relevant to the font; that is, the task followed the Two-Alternative Forced Choice (2AFC) method. The crowdsourced data were aggregated with the maximum likelihood method into a single value for each keyword of each font. Research was also conducted to understand the emotions of Hangeul (Korean alphabet) fonts [6]. Emotions were assessed with the PADO model, which adds Organized (O) to the pleasure, arousal, and dominance (PAD) model [7]. Fonts were characterized along these four PADO axes, and property values were surveyed from users. For each axis, “positive-negative”, “active-passive”, “soft-strong”, and “organized-free” were presented as semantic differential scales with values ranging from 0 to 5. Because users may find these scales difficult to use, 75 emotional adjectives that are highly related to the scales and describe fonts well were provided as additional information to reduce individual variation. A study that changed the fonts in a text-based messenger according to font emotions filtered fonts into two categories, positive and negative, and ranked them within each category.
Using crowdsourcing, two fonts were presented to workers, who were asked which font better represented the emotion of the category; applying the ranked fonts to the messenger also confirmed the effect of fonts on message delivery [8]. A study that recommends fonts according to context [9] used a deep learning method rather than explicitly identifying the emotions of fonts. Sentences with various meanings were collected, and workers were asked to select the font that fits each sentence: two fonts were presented and the more suitable one was chosen, forming the dataset. A model trained on these sentence-font pairs recommends a font when a new sentence is entered. However, this method is limited in that adding new fonts is very difficult.
3 Font Recommendation System Based on Emotional Keyword Selection

3.1 Selection of Emotional Keywords and Fonts

To implement a system that expresses the impressions of fonts with keywords and recommends fonts based on these impressions, we first selected keywords that can represent fonts. Referring to a study [10] that described the impressions of Hangeul fonts with various vocabulary, we started from the 37 keywords used in Peter O'Donovan's study. With the help of font design experts, keywords that are difficult to apply to Hangeul fonts were excluded, and the remaining keywords were chosen to balance keywords describing the appearance of a font against keywords describing emotion. After this first selection, a survey was conducted to remove keywords that actual users found difficult to understand. The survey was intended to make the subsequent keyword-font experiments efficient and to support both user-driven keyword selection and emotion-based automatic recommendation. Evaluators were shown various fonts and asked to select all keywords that they thought represented each font well. Keywords selected by two or fewer of the 10 evaluators were removed. The 19 keywords ultimately selected were “angular”, “technical”, “formal”, “modern”, “harmonious”, “disorderly”, “gentle”, “stiff”, “attention-grabbing”, “boring”, “calm”, “delicate”, “friendly”, “warm”, “dramatic”, “graceful”, “strong”, “determined”, and “playful”.

Ideally, an experiment would include as many fonts as possible. However, too many fonts tire the participants, and the reliability of the results may decrease toward the second half of the experiment. Therefore, we selected a minimal set of fonts that is representative in terms of design. More than 200 fonts with high user preference that are free to use were collected. They were classified by the most common design attributes of Hangeul fonts into Serif, Sans-serif, and Handwriting (other), and within each class they were further divided into basic styles and variant designs. As shown in Fig. 1, the 26 selected fonts comprise four basic serif fonts, four basic sans-serif fonts, six variant serif fonts, five variant sans-serif fonts, and seven handwriting (other) fonts, with varied stroke thickness.
Fig. 1. Selected fonts by basic design category
3.2 Implementing Interface for Examining Keyword Attribute Values by Font

In the commonly used survey method, keyword values are collected for one font and the survey is then repeated for the next font. Because fonts are not evaluated relative to one another, such a survey may lack objectivity, and it is difficult to relate multiple fonts and keywords. In addition, the large amount of work can tire participants. We therefore implemented a new experimental interface using the Django framework (Fig. 2).
Fig. 2. Interface for associating keyword attributes with font
To investigate keyword attributes by font, a keyword was presented to the evaluator, who compared the fonts against it and classified them step by step. For each keyword, participants classified the fonts into three relevance levels (high, medium, and low) by drag and drop. All fonts were displayed on one screen so that they could be checked and compared visually, which also let participants correct any misclassification during the task. This design was intended to improve the convenience and accuracy of the experiment by letting users check and modify their classifications at the same time.
3.3 Calculating Keyword Property Values by Font
A total of 61 people participated in the experiment. Because each response placed a font into one of three levels for a given keyword, the responses were converted into a single keyword value between 0 and 1 for each font. For each font and keyword, the numbers of high, medium, and low responses were divided by the total number of respondents, the resulting ratios were weighted by 1, 0.5, and 0, respectively, and the weighted ratios were summed. Let f be the font and k the keyword, and let f_kh, f_km, and f_kl be the numbers of respondents who rated the relevance of keyword k to font f as high, medium, and low, respectively. The keyword value f_k of the font is then given by Eq. (1):
$$f_k = \frac{(f_{kh} \times 1) + (f_{km} \times 0.5) + (f_{kl} \times 0)}{n} \qquad (1)$$
Here, n is the total number of respondents, which is 61 in this experiment. The keyword values of all fonts were quantified using Eq. (1).
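Eq. (1) can be checked with a short Python sketch. It assumes, as the text implies, that each of the n respondents placed the font in exactly one of the three levels for a given keyword, so the three counts sum to n; the function name `keyword_value` and the counts in the example are illustrative, not taken from the study.

```python
def keyword_value(high: int, medium: int, low: int, n: int = 61) -> float:
    """Eq. (1): weighted share of respondents rating a font-keyword pair high/medium/low.

    Assumption: each of the n respondents chose exactly one level,
    so high + medium + low == n.
    """
    if high + medium + low != n:
        raise ValueError("the three response counts must sum to n")
    # Weights 1, 0.5 and 0 for the high, medium and low levels
    return (high * 1.0 + medium * 0.5 + low * 0.0) / n


# Hypothetical counts for one font and one keyword
print(round(keyword_value(40, 15, 6), 4))
```

Because the weights are 1, 0.5, and 0, a font rated "high" by every respondent scores 1 and one rated "low" by everyone scores 0.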
Y. S. Ji and S. Lim
3.4 Implementation and Evaluation of Keyword Selection Based Font Recommendation Service
To evaluate the effectiveness of a keyword-based font service compared with a basic font search service, a keyword selection-based font recommendation service was implemented. Its interface lets users select the keywords they want and recommends fonts according to the selected keywords; users may select as many keywords as they wish. The selection is one-hot encoded (1 for a selected keyword and 0 for an unselected one), and its similarity to the keyword values of each font is computed. Cosine similarity was used so that the comparison reflects the overall direction of the keyword selection. With n_k the number of keywords (19), the font similarity FS is the dot product of the one-hot user selection vector U_k and the font keyword vector f_k, divided by the product of the magnitudes (Euclidean norms) of the two vectors, as shown in Eq. (2):
$$FS = \frac{\sum_{k=1}^{n_k} (U_k \times f_k)}{\sqrt{\sum_{k=1}^{n_k} (U_k)^2} \times \sqrt{\sum_{k=1}^{n_k} (f_k)^2}} \qquad (2)$$
A user evaluation was conducted to assess the keyword-based font recommendation service. The evaluators consisted of seven users who were proficient in using fonts and had experience with a variety of fonts, and seven users who were not proficient and used only a limited set of fonts. The evaluation covered three aspects: convenience compared with existing services, accuracy (how well the recommended font matches the user's intention), and satisfaction with the recommended font (Fig. 3).
Fig. 3. Font Recommendation Service based on keyword value
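A minimal Python sketch of the one-hot selection and the similarity of Eq. (2) follows. The four-keyword vocabulary and the font values are hypothetical (the study uses 19 keywords whose values come from Eq. (1)), and returning 0 when either vector is all zeros is an added convention.

```python
import math

# Hypothetical subset of the 19 keywords and hypothetical Eq. (1) values per font
KEYWORDS = ["modern", "formal", "playful", "angular"]
FONT_VALUES = {
    "FontA": [0.9, 0.8, 0.1, 0.6],
    "FontB": [0.2, 0.1, 0.9, 0.3],
}


def one_hot(selected, keywords=KEYWORDS):
    """1 for each keyword the user selected, 0 otherwise."""
    return [1.0 if k in selected else 0.0 for k in keywords]


def font_similarity(u, f):
    """Eq. (2): cosine similarity between selection vector U and font vector f."""
    dot = sum(a * b for a, b in zip(u, f))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in f))
    return dot / norm if norm else 0.0  # convention: 0 when either vector is all zeros


def recommend(selected, font_values=FONT_VALUES, top=3):
    """Return up to `top` font names, most similar first."""
    u = one_hot(selected)
    scores = {name: font_similarity(u, vals) for name, vals in font_values.items()}
    return sorted(scores, key=scores.get, reverse=True)[:top]


print(recommend({"modern", "formal"}))
```

Cosine similarity ignores vector length, so selecting more keywords does not by itself inflate a font's score; only the direction of the selection matters.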
Before the evaluation, the evaluators were shown how to use the keyword-based font recommendation interface and were presented with four font-use situations: (i) choosing a font for a presentation, (ii) choosing a font for the subtitles of interesting videos, (iii) choosing a font for writing a report, and (iv) choosing a font for expressing negative opinions on social media. Each evaluation was conducted after the evaluator had obtained font recommendations for the situation using this service. The evaluation was
Design and Application of Mapping Model
graded on a 7-point Likert scale from 0 to 6. The keywords selected by the evaluators for each situation were also recorded during the evaluation. When choosing fonts for presentation materials, 13 of the 14 evaluators selected the keyword “modern,” followed by “formal” (nine evaluators) and “determined” (seven evaluators), whereas “boring,” “friendly,” “warm,” “dramatic,” and “disorderly” were not selected by any evaluator. When selecting fonts for the captions of interesting videos, 13 evaluators chose “attention-grabbing” and 11 chose “playful,” while “angular,” “formal,” “stiff,” “boring,” “calm,” “graceful,” and “detected” were not selected. For the font used to write a report, all 14 evaluators chose “formal” and nine chose “angular.” Finally, for expressing negative opinions on social networking services (SNS), the evaluators used a wide variety of keywords, and no keyword was selected by more than seven people, that is, half of the 14 evaluators. Evaluators thus tended to select similar keywords when the purpose was more explicit, whereas their selections diverged when the ambiguous criterion of emotion was involved. In the evaluation, convenience compared with the existing font search system averaged 4.5 points, a positive rating for this prototype system. By situation, convenience, satisfaction, and accuracy all averaged above 4.5 points for presentation materials and for video subtitles; for report writing, convenience and satisfaction averaged above 4.5 points and accuracy averaged 4.1.
For expressing negative opinions on SNS, the ratings varied considerably across users, and the keywords they selected were more diverse than in the other situations. The evaluators reported that it was difficult to relate keywords to situations involving negative opinions (Fig. 4).
Fig. 4. Usability test result of font recommendation service on keyword value
These results confirm that keyword-based emotional font recommendation is more convenient than the existing search service, but also that users may find it difficult even to select keywords that fit a situation. We therefore developed a mapping model that automatically recommends fonts when content is entered.
4 Automatic Font Recommendation Based on Emotion Mapping
The overall composition of the system is shown in Fig. 5. When a user inputs content of various types, such as facial expressions and text, the emotions of the content are analyzed using a deep-learning model. The extracted emotions are then compared with a pre-stored database of font-specific keyword values, built from the experiments described above, and fonts are recommended according to the computed similarity.
Fig. 5. Full system flow diagram of font recommendation system
Because the emotion categories used to classify content do not match the font keywords, the association between these emotion categories and the previously selected font keywords must be established before fonts can be recommended from the emotion of the content (Fig. 6).
Fig. 6. Utilizing emotion space to identify associations between different criteria
To establish the association between the two vocabularies, the PAD (pleasure, arousal, and dominance) emotion model proposed by Mehrabian was used. With the PAD model, any emotion can be expressed as a point in a space spanned by three axes. Two computational models based on this emotion model were designed to implement font recommendation.
4.1 Design Mapping Model
The first model converts both the font keywords and the content emotion into a single representative PAD value and calculates the Euclidean distance between them in the PAD space. To convert the keyword values of each font into PAD values, the PAD ratings for about 20,000 English words presented in [11] were used: for each font, each keyword value was multiplied by the PAD value of the corresponding keyword, and the average of these products was calculated. However, repeatedly multiplying and averaging fractional values in the above
calculation process caused the fonts to cluster at nearly a single point: owing to this skew, the results were concentrated between the existing median value and 0.5 (Fig. 7).
Fig. 7. Coordinate of emotion on PAD model
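The first (PAD-distance) model can be sketched in Python as follows, assuming each keyword has a PAD triple in [0, 1] as in the lexicon of [11]. The keyword PAD triples and font values below are invented for illustration. Because each coordinate is an average of products of numbers in [0, 1], font coordinates are pulled into a narrow range, which is one way to see the clustering reported above.

```python
import math

# Hypothetical PAD triples (pleasure, arousal, dominance) for two keywords
KEYWORD_PAD = {
    "playful": (0.9, 0.8, 0.6),
    "formal": (0.5, 0.3, 0.7),
}


def font_pad(font_values, keyword_pad=KEYWORD_PAD):
    """Per axis: average over keywords of (keyword value x keyword PAD)."""
    nk = len(keyword_pad)
    return tuple(
        sum(font_values[k] * pad[axis] for k, pad in keyword_pad.items()) / nk
        for axis in range(3)
    )


def nearest_fonts(content_pad, fonts, keyword_pad=KEYWORD_PAD):
    """Rank fonts by Euclidean distance between their PAD point and the content's."""
    distances = {
        name: math.dist(content_pad, font_pad(vals, keyword_pad))
        for name, vals in fonts.items()
    }
    return sorted(distances, key=distances.get)


# Hypothetical keyword values (Eq. (1) style) for two fonts
fonts = {
    "FontA": {"playful": 1.0, "formal": 0.0},
    "FontB": {"playful": 0.5, "formal": 0.5},
}
print(font_pad(fonts["FontA"]))
print(nearest_fonts((0.45, 0.4, 0.3), fonts))
```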
Therefore, a method was devised that determines similarity using the correlation coefficients between the emotion classification criteria and the font keywords. This model analyzes the correlation between the keywords and the emotion classification criteria and uses it for recommendation. The classification criteria used in content emotion analysis vary from researcher to researcher, but most are similar. For example, Ekman classified six emotions (happiness, sadness, disgust, fear, surprise, and anger) as basic human emotions [12], while Plutchik classified eight emotions (happiness, sadness, disgust, fear, surprise, anger, anticipation, and trust) as basic emotions [13]. Although there are some differences, most deep-learning multi-class emotion analysis methods use one of these two classification criteria. The classifications used in emotion classification studies were also summarized. Using the PAD values of the font keywords, the Pearson correlation coefficient between each keyword and each content emotion classification criterion was obtained, and a mapping model based on these correlation coefficients was created (Fig. 8). The emotion values extracted from the content are converted into font keyword values using the correlation coefficients. In Eq. (3), C_k is the value of font keyword k converted from the content; it equals the sum, over the emotions extracted from the content, of each emotion value C_e multiplied by the correlation coefficient Corr_ke between keyword k and emotion e, where n_e is the number of emotions extracted from the content:
$$C_k = \sum_{e=1}^{n_e} (C_e \times Corr_{ke}) \qquad (3)$$
In other words, the content also acquires a value for each font keyword. A Pearson correlation coefficient close to 1 indicates a strong positive linear relationship, and one close to −1 a strong negative linear relationship. Hence, when converting to a keyword according
Fig. 8. Pearson correlation coefficient heat map between emotion criteria and font keywords
to emotion, multiplying each emotion value by the correlation coefficient preserves the values of keywords whose correlation coefficient is close to 1, so that they remain similar to the original data. Because the content and the fonts then have values on the same criteria, the similarity between a font and the content can be compared. The similarity between fonts and content was determined using cosine similarity so that the comparison reflects the overall orientation of the emotions.
Recommendation Result and Evaluation Using Mapping Model. The current system implements the font recommendation model for facial expressions and sentences. Emotions were extracted from facial expressions using the DeepFace API [14] and from sentences using the IBM Watson API [15]. The approach can be extended to other types of content by applying the emotion classification criteria of any added analysis model (Fig. 9).
Fig. 9. Fonts recommended by applying face and text to the correlation model
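The correlation-based mapping of Eq. (3) and the cosine comparison between content and fonts can be sketched in Python as below. The text does not state over which values the Pearson coefficient is computed; here it is assumed to be computed across the three PAD coordinates of a keyword and an emotion. All PAD triples, emotion scores, and font values are hypothetical; in the actual system the emotion scores would come from an analysis service such as DeepFace or IBM Watson.

```python
import math


def pearson(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    if sxx == 0 or syy == 0:
        return 0.0  # undefined for a constant sequence; treated as no correlation
    return sxy / math.sqrt(sxx * syy)


def correlation_table(keyword_pad, emotion_pad):
    """Corr_ke for every (keyword k, emotion e), assumed computed across the PAD axes."""
    return {k: {e: pearson(kp, ep) for e, ep in emotion_pad.items()}
            for k, kp in keyword_pad.items()}


def content_keyword_values(emotion_scores, corr):
    """Eq. (3): C_k = sum over e of C_e * Corr_ke."""
    return {k: sum(score * row[e] for e, score in emotion_scores.items())
            for k, row in corr.items()}


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def rank_fonts(content_values, font_values):
    """Order fonts by cosine similarity over the shared keyword axes."""
    keys = sorted(content_values)
    c = [content_values[k] for k in keys]
    scores = {name: cosine(c, [vals[k] for k in keys]) for name, vals in font_values.items()}
    return sorted(scores, key=scores.get, reverse=True)


# Hypothetical PAD triples and scores
KEYWORD_PAD = {"playful": (0.9, 0.8, 0.4), "calm": (0.7, 0.1, 0.5)}
EMOTION_PAD = {"happiness": (0.95, 0.7, 0.6), "sadness": (0.1, 0.3, 0.2)}
corr = correlation_table(KEYWORD_PAD, EMOTION_PAD)
content = content_keyword_values({"happiness": 0.8, "sadness": 0.2}, corr)
fonts = {"FontA": {"playful": 0.9, "calm": 0.2}, "FontB": {"playful": 0.1, "calm": 0.8}}
print(rank_fonts(content, fonts))
```

Because the emotion categories only enter through the correlation table, swapping in a different emotion model only requires recomputing `corr` for its categories.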
Based on this model, the recommendation results were evaluated. Users were presented with fonts recommended from facial expressions and from sentences, together with the corresponding content, and rated the recommendation accuracy and the overall efficiency and satisfaction of the system on a 0–6 Likert scale. Here, accuracy refers to how well the font matches the presented content. Fifteen people participated in the evaluation, and the average accuracy was 3.85 for facial expressions and 4.77
for sentences, with an overall accuracy of 4.24. The efficiency of the service, which recommends fonts automatically from content, was rated highly, averaging 5.09, and overall satisfaction averaged 4.3 (Fig. 10).
Fig. 10. Usability test result of recommended font
5 Conclusion
In this study, keywords that express the characteristics of fonts and apply to Hangeul fonts were selected, and experiments were conducted on fonts of various designs to connect fonts with the keywords they represent. Based on the experimental results, the keyword values of each font were calculated as numbers between 0 and 1, and a prototype keyword-based font recommendation service, which recommends fonts based on keywords selected by a user, was evaluated. The evaluators agreed that fonts could be found more conveniently than with existing services and expressed positive opinions on the service. The evaluation confirmed the effectiveness of keyword-based font recommendation, but it also showed that real users found it difficult to select keywords that fit their desired situation. To solve this problem, we designed a service that automatically recommends fonts according to content. The mapping model, which applies correlation coefficients obtained through correlation analysis, lets users select fonts that harmonize with the content easily and quickly, reducing the cost of font selection. In addition, because the mapping model can accommodate various emotion classification criteria, it can be applied directly to the computational model even if a deep-learning emotion analysis model other than the current one is used. However, only a small number of fonts are currently available for recommendation. In future studies, we will therefore design a deep-learning model that automatically extracts keywords from font images, so that a wider range of fonts can be recommended, and implement an interface that is convenient for actual users.
References
1. Koch, B.E.: Human emotion response to typographic design. Dissertation, University of Minnesota (2011)
2. Koh, Y.-W., Sohn, E.-M., Lee, H.-J.: Users’ perception on fonts as a tool of communication and SMS. Arch. Des. Res. 20(1), 133–142 (2007)
3. Doyle, J.R., Bottomley, P.A.: Font appropriateness and brand choice. J. Bus. Res. 57(8), 873–880 (2004)
4. An, M.-N., Kim, N.-Y., Chung, J.-H., Lee, Y.-K., Guo, J.-Y., Yoon, J.-Y.: A preference study on serif and sans serif typefaces of e-book. J. HCI Soc. Korea 15(3), 5–12 (2020)
5. O’Donovan, P., Lībeks, J., Agarwala, A., Hertzmann, A.: Exploratory font selection using crowdsourced attributes. ACM Trans. Graph. (TOG) 33(4), 1–9 (2014)
6. Kim, H.-Y., Lim, S.-B.: Application and analysis of emotional attributes using crowdsourced method for Hangul font recommendation system. J. Korea Multimedia Soc. 20(4), 704–712 (2017)
7. Mehrabian, A.: Pleasure-arousal-dominance: a general framework for describing and measuring individual differences in temperament. Curr. Psychol. 14(4), 261–292 (1996)
8. Choi, S., Aizawa, K.: Emotype: expressing emotions by changing typeface in mobile messenger texting. Multimedia Tools Appl. 78(11), 14155–14172 (2018)
9. Shirani, A., Dernoncourt, F., Echevarria, J., Asente, P., Lipka, N., Solorio, T.: Let me choose: from verbal context to font selection. In: Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pp. 8607–8613 (2020)
10. Lee, Y.: Impression of Hangeul typeface. J. Korean Soc. Typogr. 9(2), 28–55 (2017)
11. Mohammad, S.: Obtaining reliable human ratings of valence, arousal, and dominance for 20,000 English words. In: Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics, pp. 174–184 (2018)
12. Ekman, P.: Universals and cultural differences in facial expressions of emotion. In: Nebraska Symposium on Motivation, University of Nebraska Press (1971)
13. Plutchik, R.: A psychoevolutionary theory of emotions (1982)
14. https://github.com/serengil/deepface. Accessed 14 Oct 2022
15. https://www.ibm.com/cloud/watson-natural-language-understanding. Accessed 14 Oct 2022
Time Series Prediction of 5G Network Data Based on Improved EEMD-BiLSTM Prediction Model
Jianrong Li1(B), Zheng Li1, Jie Li2, Gongcheng Shi1, Chuanlei Zhang1, and Hui Ma3
1 Tianjin University of Science and Technology, Tianjin 300000, China
2 China Unicom Research Institute, Beijing 100000, China
3 Yunsheng Intelligent Technology Co., Ltd., Tianjin 300000, China
Abstract. 5G networks are designed to support emerging applications such as the Internet of Things (IoT), autonomous vehicles, and smart cities, which require low latency, high throughput, and reliable connectivity. Accurately predicting traffic load and network performance in 5G networks is critical to meeting the demands of these applications. In this paper, a hybrid model combining the ensemble empirical mode decomposition (EEMD) method with a BiLSTM network is used to predict 5G network data. On this basis, an improved EEMD decomposition method is proposed to address the information leakage problem of EEMD. The experimental results demonstrate that the improved EEMD-BiLSTM model is more accurate than the EMD-BiLSTM model.
Keywords: Time series prediction · 5G network data · EEMD-BiLSTM · Optimized sliding-window
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 409–420, 2023. https://doi.org/10.1007/978-981-99-4761-4_35
1 Introduction
Accurate prediction of traffic load and network performance is crucial for the future development of 5G networks, and time-series prediction of 5G network data is an important task because it can help improve the performance and reliability of these networks. Time-series forecasting models are often used to analyze historical data and identify trends and patterns, giving insight into the future behavior of 5G networks. Such predictions can help network operators optimize resource allocation, improve network efficiency, and enhance user experience. For example, if a model predicts a traffic spike during a certain time period, operators can allocate more resources to that period to ensure that the network can handle the increased demand [1, 2]. Deep learning techniques have been used to model 5G network data; by mining the relationships relevant to the prediction target, they overcome the shortcomings of traditional time-series prediction methods and can greatly improve prediction accuracy [3]. Deep neural network algorithms commonly used in time-series forecasting include the recurrent neural network (RNN), the long short-term memory network (LSTM), and the convolutional neural
J. Li et al.
network (CNN). Li Xiaoxue et al. [4] introduced an attention mechanism into the LSTM model, assigning different weights to the LSTM hidden states through weight mapping and learned parameter matrices; this reduces the loss of historical information and strengthens the influence of important information. Gao [5] used seasonal differencing to make the time series stationary and constructed a smooth long short-term memory (SLSTM) traffic prediction model to predict 5G network traffic accurately. Hybrid models for 5G network traffic forecasting still leave considerable room for research, and hybrid forecasting models that combine signal decomposition techniques with neural networks perform better [6]. Commonly used sequence decomposition methods include empirical mode decomposition (EMD), ensemble empirical mode decomposition (EEMD), variational mode decomposition (VMD), and the wavelet transform (WT). EEMD and VMD overcome the mode mixing of EMD and are therefore better choices. The literature [7–9] shows that the BiLSTM neural network performs well in predicting time-series data. However, when EEMD is used for time-series forecasting, decomposing the original data causes an information leakage problem [10]. Because EEMD is a global, one-time decomposition, the training set and the test set are decomposed together, so the training-set components contain information from the test set. When the training set is fed into the model for training, the model should not have access to any information in the test set. However, the intrinsic mode function (IMF) components of the training set generated by EEMD contain information from the test set, so the model learns test-set information in advance, which constitutes information leakage.
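EEMD itself is not reproduced here. As a stand-in, the following Python sketch uses a centered moving average, another two-sided (non-causal) transform, to illustrate the same effect; EMD's spline envelopes are likewise fitted through extrema on both sides of each point. When the whole series is transformed at once, the values assigned to the last training points change depending on the test data. The series values are made up.

```python
def centered_moving_average(series, half_width=2):
    """Two-sided smoother: output t averages neighbours on BOTH sides of t."""
    n = len(series)
    out = []
    for t in range(n):
        lo, hi = max(0, t - half_width), min(n, t + half_width + 1)
        window = series[lo:hi]
        out.append(sum(window) / len(window))
    return out


train_part = [1.0, 2.0, 3.0, 4.0, 5.0]
test_part = [50.0, 60.0]  # made-up future values with a level shift

# Global, one-shot transform of the full series, then split (leaky)
global_component = centered_moving_average(train_part + test_part)[:len(train_part)]
# Transform of the training part only (no access to the test values)
train_only_component = centered_moving_average(train_part)

print(global_component[-1], train_only_component[-1])
```

The first training value is identical in both versions, but the last one jumps from 4.0 to 24.4 once the test values are visible, which is exactly the kind of future information a model should not see during training.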
Building on the EEMD-BiLSTM model for predicting 5G network traffic, this paper proposes an improved EEMD decomposition method that addresses the EEMD information leakage problem. Experiments show that the improved EEMD-BiLSTM model is more accurate than the EEMD-BiLSTM hybrid model.
2 Related Work
Predictive models that combine sequence decomposition techniques with neural network algorithms are currently very popular. Sequence decomposition can smooth and denoise the original data and thereby improve the prediction accuracy of neural networks. This paper proposes an improved EEMD-BiLSTM prediction model to address the EEMD information leakage problem, as described in detail below.
2.1 Ensemble Empirical Mode Decomposition
Ensemble empirical mode decomposition (EEMD) exploits the statistical property that white noise has a uniform frequency distribution: building on EMD, it adds Gaussian white noise to the signal so that the signal is continuous across different scales. This changes the distribution of the signal's extreme points and counteracts mode mixing, effectively avoiding mode overlap [11]. When EEMD is used for time-series prediction, the algorithmic process
Time Series Prediction of 5G Network Data
is to decompose the original data into several IMF components by EEMD. The data are then divided into training and test sets; the training set is fed into the prediction model to tune the parameters and obtain the optimal model, and the test set is then fed into this model for prediction. This decomposition is global [12]: the data that will later form the test set are already taken into account during EEMD decomposition, which is equivalent to knowing the future data to be predicted. Because the data set is divided only after decomposition, information leakage arises. Dividing the data into training and test sets first and then performing EEMD decomposition is not a workable alternative either. The training and test sets are generally divided in a ratio of 8:2, so they differ greatly in size, and series of different lengths generally yield different numbers of intrinsic mode functions (IMFs) [13]. The IMF components decomposed from the training set and the test set are therefore inconsistent, input data sets with the same structure cannot be constructed, and the test data cannot be fed into the model using the sample structure of the training set [14]. Hence, it is not feasible to divide the data set first and then decompose it.
2.2 BiLSTM Neural Network Model
The LSTM model is an extension of the RNN model. It consists of the input x_t, the cell state C_t, the temporary (candidate) cell state C̃_t, the hidden state h_t, the forget gate f_t, the memory (input) gate i_t, and the output gate o_t [15]. The BiLSTM model combines two LSTM models, one trained on the input sequence in the forward direction and one in the reverse direction, which allows it to better capture data features in both directions [16].
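To make the gate notation concrete, here is a Python sketch of a standard LSTM cell with a hidden size of one and hand-set scalar weights (a simplification; the study uses layers of 64 and 32 neurons), together with a bidirectional wrapper that runs one LSTM forward and one backward and pairs their hidden states. The parameter names and the `constant_params` helper are hypothetical.

```python
import math

GATE_PARAMS = ("wf", "uf", "bf", "wi", "ui", "bi", "wc", "uc", "bc", "wo", "uo", "bo")


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, h_prev, c_prev, p):
    """One step of a standard LSTM cell with hidden size 1 (scalar weights)."""
    f = sigmoid(p["wf"] * x + p["uf"] * h_prev + p["bf"])          # forget gate f_t
    i = sigmoid(p["wi"] * x + p["ui"] * h_prev + p["bi"])          # memory/input gate i_t
    c_tilde = math.tanh(p["wc"] * x + p["uc"] * h_prev + p["bc"])  # temporary cell state
    c = f * c_prev + i * c_tilde                                   # cell state C_t
    o = sigmoid(p["wo"] * x + p["uo"] * h_prev + p["bo"])          # output gate o_t
    h = o * math.tanh(c)                                           # hidden state h_t
    return h, c


def run_lstm(xs, p):
    h, c, hidden = 0.0, 0.0, []
    for x in xs:
        h, c = lstm_step(x, h, c, p)
        hidden.append(h)
    return hidden


def bilstm(xs, p_forward, p_backward):
    """Pair the forward hidden state with the backward one at each time step."""
    forward = run_lstm(xs, p_forward)
    backward = run_lstm(xs[::-1], p_backward)[::-1]
    return list(zip(forward, backward))


def constant_params(value):
    """Hypothetical helper: every weight and bias set to the same value."""
    return {name: value for name in GATE_PARAMS}


print(bilstm([0.2, 0.5, 0.1], constant_params(0.5), constant_params(-0.3)))
```

At each step, the forward state depends only on past inputs while the backward state depends only on future inputs, which is what lets the BiLSTM use context from both directions.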
2.3 EEMD-BiLSTM Model
Network data are non-stationary, nonlinear, and complex, and they span a wide range of values, so feeding them directly into the BiLSTM model for training and prediction often gives unsatisfactory results [17]. Therefore, the original data set is first decomposed into several relatively stationary intrinsic mode function (IMF) components using ensemble empirical mode decomposition (EEMD); the IMF components are shown in Fig. 1. Each IMF component is normalized and divided into a training set and a test set. The training set is then input to the bidirectional long short-term memory (BiLSTM) model for learning [18], after which the test data are input to the trained model for prediction, and the predicted values of all IMF components are summed to obtain the final prediction Y. The flow chart of the EEMD-BiLSTM model is shown in Fig. 2; evaluation criteria such as RMSE and MAE are calculated to judge the prediction performance of the model.
2.4 Chapter Summary
This chapter has shown that EEMD can avoid mode mixing, but some problems remain to be solved. EEMD-BiLSTM can fit 5G network traffic data very
Fig. 1. EEMD decomposition IMF diagram
Fig. 2. Flow chart of EEMD-BiLSTM model
well [19]. If the EEMD information leakage problem is solved, the effectiveness and feasibility of the hybrid prediction model will be further improved.
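The component-wise workflow of Sect. 2.3 (normalize each IMF, forecast it, undo the scaling, and sum the component forecasts) can be sketched in Python as follows. The persistence predictor is a hypothetical stand-in for the trained BiLSTM, the components are assumed to be already decomposed, and fitting the min-max scaling on the training part only is an assumption the text does not specify. This sketch does not remove the leakage discussed above when the components come from a global decomposition.

```python
def minmax_fit(values):
    """Return (minimum, range) for min-max scaling; a flat series gets range 1."""
    lo, hi = min(values), max(values)
    return lo, (hi - lo) or 1.0


def persistence_forecast(history, horizon):
    """Hypothetical stand-in for the trained BiLSTM: repeat the last value."""
    return [history[-1]] * horizon


def forecast_from_components(components, train_len, horizon, predictor=persistence_forecast):
    """Normalize, forecast and de-normalize each component, then sum the forecasts."""
    total = [0.0] * horizon
    for comp in components:
        train = comp[:train_len]
        lo, scale = minmax_fit(train)  # assumption: scaling fitted on the training part
        scaled = [(v - lo) / scale for v in train]
        predicted = [p * scale + lo for p in predictor(scaled, horizon)]
        total = [t + p for t, p in zip(total, predicted)]
    return total


# Two made-up components: a trend-like component and a flat residual
imfs = [[1.0, 2.0, 3.0, 4.0, 5.0], [10.0, 10.0, 10.0, 10.0, 10.0]]
print(forecast_from_components(imfs, train_len=4, horizon=1))
```

Replacing `persistence_forecast` with a trained model per component reproduces the structure of the flow chart in Fig. 2; the final summation is what turns component forecasts back into a forecast of the original series.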
3 Improved EEMD-BiLSTM Based Algorithm Model
3.1 Data Pre-processing
Data quality involves many factors, including accuracy, completeness, consistency, timeliness, credibility, and interpretability, which directly determine the prediction and generalization ability of the model [20]. The dataset used here is a 5G video-streaming trace dataset collected by major mobile operators in Ireland. It consists of network key performance indicators (KPIs), including channel-related metrics, signal-to-noise ratio, and download and upload rates. The data fluctuate strongly, show no obvious linear relationships, and contain missing values and some outliers, so the data set must be cleaned and denoised before the prediction model can be trained well. The preprocessing steps are as follows:
(1) Calculate the Pearson correlation coefficient between each feature variable and the target field, and remove feature variables with a correlation R < 0.3;
(2) Fill features with few missing values using the mean, and remove feature variables with more than 80% missing values;
(3) Apply outlier detection to the data set, find the abnormal values “2147483647” and “-”, and delete every record that contains an abnormal value.
Finally, 25,700 records of 5G network data with 15 feature variables were obtained. The dataset field names are listed in Table 1.
Table 1. Field names in the dataset
Timestamp, Longitude, Latitude, Speed, CellID, RSRP, RSRQ, SNR, CQI, RSSI, State, NRxRSRP, NRxRSRQ, DL_bitrate, UL_bitrate
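The three preprocessing steps can be sketched in plain Python as below. Several details are assumptions: missing values are represented as None, anomaly markers are matched on the string form of each value, the 0.3 threshold is applied to |R| so that strongly negative correlations are kept, and the steps run in the order (3), (2), (1) so that correlations are computed on cleaned, filled columns. The field names come from Table 1; the rows are invented.

```python
import math

ANOMALY_MARKERS = {"2147483647", "-"}  # abnormal values named in step (3)


def pearson(x, y):
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    if sxx == 0 or syy == 0:
        return 0.0  # a constant column has no linear relationship with the target
    return sxy / math.sqrt(sxx * syy)


def preprocess(rows, target, min_corr=0.3, max_missing=0.8):
    # Step (3): drop every record containing an anomaly marker (copies, so input is untouched)
    rows = [dict(r) for r in rows
            if not any(str(v) in ANOMALY_MARKERS for v in r.values())]
    features = [k for k in rows[0] if k != target]

    # Step (2): drop features missing in more than 80% of records, mean-fill the others
    filled = []
    for f in features:
        present = [r[f] for r in rows if r[f] is not None]
        if 1 - len(present) / len(rows) > max_missing:
            continue
        mean = sum(present) / len(present)
        for r in rows:
            if r[f] is None:
                r[f] = mean
        filled.append(f)

    # Step (1): keep features whose |Pearson R| with the target is at least min_corr
    y = [r[target] for r in rows]
    kept = [f for f in filled if abs(pearson([r[f] for r in rows], y)) >= min_corr]
    return [{k: r[k] for k in kept + [target]} for r in rows], kept


raw_rows = [
    {"RSRP": 1.0, "RSSI": 2.0, "SNR": 5.0, "CQI": None, "DL_bitrate": 10.0},
    {"RSRP": 2.0, "RSSI": None, "SNR": 5.0, "CQI": None, "DL_bitrate": 20.0},
    {"RSRP": 3.0, "RSSI": 6.0, "SNR": 5.0, "CQI": None, "DL_bitrate": 30.0},
    {"RSRP": 4.0, "RSSI": 8.0, "SNR": 5.0, "CQI": None, "DL_bitrate": 40.0},
    {"RSRP": 5.0, "RSSI": 10.0, "SNR": 5.0, "CQI": None, "DL_bitrate": 50.0},
    {"RSRP": "-", "RSSI": 1.0, "SNR": 5.0, "CQI": 3.0, "DL_bitrate": 60.0},
]
clean, kept = preprocess(raw_rows, "DL_bitrate")
print(kept, len(clean))
```

In this toy input, CQI is dropped as mostly missing, SNR is dropped because a constant column has no correlation with the target, and the gap in RSSI is filled with the mean of the remaining values.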
3.2 Sliding Window EEMD-BiLSTM Prediction Model for 5G Network Speed
In this experiment, the Irish mobile 5G network dataset contains 25,700 records in total. The standard EEMD-BiLSTM model applies EEMD to the original data and then inputs all decomposed IMF components into the BiLSTM model for training and prediction. The sliding window EEMD-BiLSTM prediction model instead divides the data into several windows, decomposes each window separately, and inputs the IMF components of that window into the BiLSTM model for training and prediction. During EEMD decomposition, white noise is added to the original data within a sliding window of size W_1 to obtain new data:
$$Y_i(t) = X(t) + r_i(t), \qquad (0 < i < W_1)$$
EMD decomposition is then performed on each new signal to obtain its IMF components:
$$Y_i(t) = \sum_{m=1}^{M-1} IMF_m^{(i)}(t) + r_M^{(i)}(t), \qquad (0 < i < W_1,\; M = \mathrm{length}(IMF))$$
EEMD decomposition is performed within one window, and the window then slides to decompose the data of the next window. The algorithm consists of an outer loop and an inner loop, described below.
(1) Outer loop. The algorithm uses two sliding windows: one for the EEMD decomposition, and one for constructing, from the combined IMF data, the (X, Y) data sets that are input to the model. The outer loop is the sliding-window EEMD decomposition. The sliding window has a size of 20,000 and a step of 50. The Irish network speed data are decomposed segment by segment with EEMD, one window of data at a time, starting with the data in the range [1:20000], which yields 9 IMF components and 1 residual component. These components are then processed by the inner loop. When the inner loop is complete, the window slides by 50 steps, and the window data [51:20050] are decomposed with EEMD to generate new IMF components. These steps are repeated until all the data have been decomposed.
(2) Inner loop. For each of the 9 IMF data sets generated by the outer loop, a sliding window of length 500 with a step of 50 is used.
This sliding window constructs the (X, Y) data sets that are input to the neural network for each IMF data set. For example, the first X is the window of length 500 covering [1:500] and its Y is the 50 values [501:550] of the IMF component; the second X covers [51:550] and its Y covers [551:600]. In this way, 391 (X, Y) pairs are constructed. The first 390 pairs are input into the BiLSTM model for training, whereas the Y of the last X, [19501:20000], corresponds to [20001:20050] and is unknown relative to the current window. The last X is therefore used for prediction, and the predicted value Y[20001:20050] is saved. This operation is repeated for each IMF data set, and finally the predicted values Y_imfi[20001:20050] of all components are added to obtain the predicted value Ỹ[20001:20050] of the original data, which completes one pass of the inner loop. The outer-loop window then slides by 50 steps, the inner loop is run again, the predicted values are saved, and so on. In the last outer loop, the first 390 pairs are input into the BiLSTM deep neural network for training only, without prediction, and the predicted values Y[20001:25700] are finally obtained. For the BiLSTM model, the input data are the training set of one IMF component within a window:
$$i_t = \sigma\left(W_{imf_i} \cdot [h_{t-1}, x_{imf_i,t}] + b_{imf_i}\right)$$
The output O_t of the BiLSTM model is the predicted value for this window, which lies in the next window:
$$O_t = \sigma\left(W_{imf_i} \cdot [h_{t-1}, x_{imf_i,t}] + b_{imf_o}\right)$$
Finally, the predicted values Ỹ[20001:25700] of the data are obtained by summing the predicted components:
$$\tilde{Y}(t) = \sum_{i} \tilde{Y}_{imf_i}(t), \qquad t = 20001, \dots, 25700$$
The flow chart of the sliding window decomposition is shown in Fig. 3.
Fig. 3. Flow chart of sliding window decomposition
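The window bookkeeping of the outer and inner loops can be checked with a Python sketch that only enumerates 1-based, inclusive index ranges; decomposition and training are omitted, and the function names are hypothetical. With a 20,000-point outer window and 500-point X windows, 50-point Y windows, and a step of 50, the enumeration yields 390 training pairs plus one prediction pair (391 in total) per outer window, and 115 outer windows over 25,700 points, of which the first 114 each predict 50 new points.

```python
def outer_windows(total_len=25_700, size=20_000, step=50):
    """1-based inclusive (start, end) of each EEMD decomposition window."""
    return [(s, s + size - 1) for s in range(1, total_len - size + 2, step)]


def inner_pairs(win_start, win_end, x_len=500, y_len=50, step=50):
    """(X range, Y range) pairs built from one decomposed window."""
    pairs = []
    s = win_start
    while s + x_len - 1 <= win_end:
        x = (s, s + x_len - 1)
        y = (s + x_len, s + x_len + y_len - 1)
        pairs.append((x, y))
        s += step
    return pairs


def split_pairs(pairs, win_end):
    """Pairs whose Y lies inside the window train the model; the rest are forecasts."""
    train = [p for p in pairs if p[1][1] <= win_end]
    forecast = [p for p in pairs if p[1][1] > win_end]
    return train, forecast


windows = outer_windows()
train, forecast = split_pairs(inner_pairs(*windows[0]), windows[0][1])
print(len(windows), len(train), forecast)
```

Keeping the bookkeeping separate from the model makes it easy to confirm that every forecast target lies strictly after the data that were decomposed, which is the property the sliding window is meant to guarantee.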
The four evaluation metrics RMSE, MAE, MAPE, and R² are calculated from the actual and predicted values. The algorithm flow chart, shown in Fig. 4, adds the two sliding windows to the flow of Fig. 2, and the structure of the model is shown in Fig. 5.
Fig. 4. Flow chart of sliding window-based EEMD-BiLSTM algorithm
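The four metrics can be computed with the standard definitions, sketched here in plain Python. The toy series is made up, and rejecting zero actual values (for which MAPE is undefined) is an added safeguard.

```python
import math


def evaluate(actual, predicted):
    """RMSE, MAE, MAPE (%) and R^2 between actual and predicted series."""
    if len(actual) != len(predicted) or not actual:
        raise ValueError("series must be non-empty and of equal length")
    if any(a == 0 for a in actual):
        raise ValueError("MAPE is undefined when an actual value is 0")
    n = len(actual)
    errors = [a - p for a, p in zip(actual, predicted)]
    ss_res = sum(e * e for e in errors)
    mean_actual = sum(actual) / n
    ss_tot = sum((a - mean_actual) ** 2 for a in actual)  # zero for a constant series
    return {
        "RMSE": math.sqrt(ss_res / n),
        "MAE": sum(abs(e) for e in errors) / n,
        "MAPE": 100.0 * sum(abs(e / a) for e, a in zip(errors, actual)) / n,
        "R2": 1.0 - ss_res / ss_tot,
    }


metrics = evaluate([100.0, 200.0, 300.0, 400.0], [110.0, 190.0, 300.0, 400.0])
print(metrics)
```

RMSE penalizes large errors more heavily than MAE, while MAPE can be inflated by small actual values, which is why the three error measures are reported alongside R².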
Fig. 5. Structure of sliding window EEMD-BiLSTM model
4 Experiment Evaluation and Discussion
Using the EMD and EEMD implementations in the PyEMD package, the target data are decomposed into several IMF components, which are input into the BiLSTM model. The BiLSTM network structure is shown in Fig. 6: a BiLSTM input layer with 64 neurons, a hidden layer with 32 neurons, and a fully connected output layer. The input and hidden layers use the ReLU activation function, the output layer uses a linear activation, and the dropout rate is set to 0.3. EMD-BiLSTM and EEMD-BiLSTM were selected as the control models. Fig. 7 compares the last 200 real values with the values predicted by the three models, and Fig. 8 compares the predictions of the three models. The evaluation metric values are listed in Table 2, and Fig. 9 is a line chart of the evaluation values of the three models. The sliding window EEMD-BiLSTM model not
Fig. 6. BiLSTM neural network structure
only avoids information leakage but also improves, rather than degrades, prediction accuracy relative to EMD-BiLSTM: the root mean square error (RMSE) decreased by 5.53%, the mean absolute error (MAE) by 14.18%, and the mean absolute percentage error (MAPE) by 5%, while R² increased from 0.981 to 0.994.
Fig. 7. Real value versus predicted value
In general, the use of sliding windows improves the prediction accuracy of the model and its fit to the data, which further improves the effectiveness of the model. Compared with the recently proposed 5G network traffic forecasting model (SLSTM) [5], the proposed model also has a smaller RMSE and a higher R², which indicates a better fit to the data.
Fig. 8. Comparison of the prediction results of the three models
Table 2. Comparison of evaluation criteria values
| Model | RMSE | MAE | MAPE (%) |
|---|---|---|---|
| EMD-BiLSTM | 723.53 | 282.1 | 40.42 |
| EEMD-BiLSTM | 689.15 | 249.56 | 39.58 |
| Sliding Window EEMD-BiLSTM | 683.76 | 242.47 | 38.39 |
[Fig. 9 plot "Evaluate Result": RMSE, MAE, MAPE (%) and R2 of EMD-BiLSTM, EEMD-BiLSTM and SW-EEMD-BiLSTM]
Fig. 9. Line chart of the evaluation criteria for the three models
5 Conclusion
This paper presents a hybrid model for predicting 5G network data that combines the ensemble empirical mode decomposition (EEMD) method and the bidirectional long short-term memory (BiLSTM) network. An improved, sliding-window EEMD decomposition method is proposed to address the issue of information leakage in EEMD. In addition, dropout is applied to each layer of the BiLSTM network to prevent overfitting and reduce complex co-adaptations among neurons. The experimental results show that the proposed method provides a better solution to the EEMD information leakage problem and makes the prediction model more rigorous and feasible. Some issues remain. Although EEMD decomposes the data more effectively than EMD, it is more time-consuming. Future work will address this cost with an optimization scheme. Because this concerns the internal computation of the EMD algorithm, it requires strong knowledge of mathematical theory.
References
1. Sheng, H., Zhang, Y.: Research on network traffic modeling and forecasting based on ARIMA. Commun. Technol. 52(4), 903–907 (2019)
2. Lim, B., Zohren, S.: Time-series forecasting with deep learning: a survey. Phil. Trans. R. Soc. A 379(2194), 20200209 (2021)
3. Zheng, J.D., Cheng, J.S., Yang, Y.: Improved EEMD algorithm and its application. Vibr. Shock 32(21), 21–26 (2013)
4. Li, X., Jiang, C., Chi, M., et al.: Research on traffic prediction based on attention mechanism and long short-term memory neural network. Comput. Inf. Technol. 30, 14–16 (2022)
5. Gao, Z.: 5G traffic prediction based on deep learning. Comput. Intell. Neurosci. 2022, 1–5 (2022)
6. Dai, S., Chen, Q., Liu, Z., et al.: Time series forecasting method based on EMD-LSTM. J. Shenzhen Univ. Sci. Technol. 37(3), 265–270 (2020)
7. Zheng, J., Zhang, B., Ma, J., et al.: A new model for remaining useful life prediction based on NICE and TCN-BiLSTM under missing data. Machines 10(11), 974 (2022)
8. Lin, Y., Chen, K., Zhang, X., et al.: Forecasting crude oil futures prices using BiLSTM-Attention-CNN model with Wavelet transform. Appl. Soft Comput. 130, 109723 (2022)
9. Yao, H.-R., Li, C.-X., Zheng, X.-J., et al.: Short-term load combination prediction model integrating adaptive chirp modal decomposition and BiLSTM. Power Syst. Prot. Control 50(19), 58–66 (2022)
10. Kong, D.-T., Liu, Q.-C., Lei, Y.-G., et al.: An improved EEMD method and its application. J. Vibr. Eng. 28(6), 1015–1021 (2015)
11. Xu, X., Xu, G., Wang, X., et al.: Empirical mode decomposition (EMD) and its applications. J. Electron. 37(3), 581 (2009)
12. Liu, S.F., Qin, S.R., Berlin: Problems and solutions in EMD. In: Proceedings of the 9th National Conference on Vibration Theory and Applications. National Conference on Vibration Theory and Applications, Hangzhou (2007)
13. Liu, H., Zhang, M., Cheng, J.: Processing of EMD endpoint problems based on polynomial fitting algorithm. Comput. Eng. Appl. 40(16), 84–86 (2004)
14. Qiu, J., Zheng, H., Cheng, Y.H.: Research on multi-scale LSTM-based prediction models. J. Syst. Simul. 34(7), 1593 (2022)
15. Liu, J.Y., Zhang, C.R., Qi, J.J.: PM2.5 concentration prediction model based on sliding window and LSTM. J. Qiqihar Univ. (Nat. Sci. Ed.) 38, 87–94 (2022)
16. Xu, Y., Wang, Z., Wu, Z.: A stock trend prediction model based on CNN-BiLSTM multi-feature fusion. Data Anal. Knowl. Disc. 5(7), 126–138 (2021)
17. Wang, J.D., Du, C.: Short-term load forecasting model based on Attention-BiLSTM neural network and meteorological data correction. Power Autom. Equip. (2022)
18. Abduljabbar, R.L., Dia, H., Tsai, P.-W.: Unidirectional and bidirectional LSTM models for short-term traffic prediction. J. Adv. Transp. 2021, 5589075 (2021). 16 pages
19. Zhan, Y., Sun, S., Li, X., Wang, F.: Combined remaining life prediction of multiple bearings based on EEMD-BILSTM. Symmetry 14(2), 251 (2022)
20. Aouedi, O., Piamrat, K., Parrein, B.: Intelligent traffic management in next-generation networks. Future Internet 14(2), 44 (2022)
CWA-LSTM: A Stock Price Prediction Model Based on Causal Weight Adjustment
Qihang Zhang, Zhaoguo Liu, Zhuoer Wen, Da Huang (corresponding author), and Weixia Xu
Institute for Quantum Information and State Key Laboratory of High Performance Computing, College of Computer Science and Technology, National University of Defense Technology, Changsha 410073, China
Abstract. With the advent of the era of big data, many kinds of data prediction models have been widely studied. Time series data are data in which the same features are recorded at consistent time intervals, reflecting the changing trend of random variables. Because such data are widely used in fields such as weather changes, financial market fluctuations and battlefield situational awareness, research on predicting time series data is vital. Traditional prediction models mine the correlations among factors in the data. However, correlations do not always reflect the interactions between factors, whereas causality reveals the essential relationships between factors in time series data. In this paper, we propose a novel method called Causal Weight Adjustment-LSTM (CWA-LSTM). It uses a causal discovery algorithm to identify causal factors and adjust their weights for time series prediction. The experimental results demonstrate the effectiveness of our method in stock price prediction.
Keywords: Causality Discovery · Time Series · Stock Price
1 Introduction
Time series data are ubiquitous in daily life and arise in weather and climate, urban transportation, finance and economics, and other fields. The prediction and analysis of time series data is considered one of the most challenging tasks in artificial intelligence and machine learning. Stock prices are a common example of time series data. In recent years, stock price prediction has received widespread attention from investors and researchers because of the potential for high returns. However, stock price data are nonlinear, highly noisy and unstable in response to unexpected events, which makes accurate prediction particularly challenging. In this paper, we focus on predicting stock prices and attempt to reduce prediction errors and improve accuracy with a method based on causal discovery. Previous research on stock price prediction has produced many mature methods, such as the autoregressive moving average model (ARMA) [1], the autoregressive integrated moving average model (ARIMA) [2] and the generalized autoregressive conditional heteroskedasticity model (GARCH) [3]. These representative methods use statistical models for prediction and make assumptions about the data distribution
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 421–432, 2023. https://doi.org/10.1007/978-981-99-4761-4_36
Q. Zhang et al.
such as Gaussian noise, and therefore sometimes fail to achieve good results in practice. The rapid development of deep learning makes it possible to avoid these flaws [4]. Unlike traditional methods, deep learning algorithms do not need to assume a data distribution and can model nonlinear relationships. As a result, more and more researchers are using deep learning for stock price prediction. They feed large amounts of historical stock price data into a model and train it to predict future stock movements. However, current methods have poor stability in long-term prediction tasks. We observe that, when deep neural networks are used, the input factors are chosen because they are correlated with the prediction objective. As is well known, correlation does not imply causation [5]. How to mine causality to improve prediction accuracy is therefore an important question for stock price prediction. To this end, we devised the following method. First, we selected a suitable set of stock factors as features and used a causal discovery algorithm to determine the causal relationships between these features and the objective. Next, we adjusted the weights of the features that have a causal relationship with the objective and used the adjusted weights to predict stock prices. The main contributions of this paper can be summarized as follows.
1) To the best of our knowledge, this is the first application of causal weight adjustment to stock price forecasting.
2) Based on the PC algorithm and LSTM, we propose CWA-LSTM, which discovers causality and adjusts feature weights to predict time series data.
3) When tested on a real stock dataset, our model reduced prediction errors relative to a vanilla LSTM, demonstrating the benefits of incorporating causality and causal weight adjustment into prediction models.
The remainder of this paper is organized as follows. In Sect. 2, we review related work. Section 3 describes the proposed method in detail. In Sect. 4, we conduct a series of comparison experiments to evaluate the effectiveness of the proposed method. Finally, we summarize our findings and conclude the paper in Sect. 5.
2 Related Work
Over the past few decades, fundamental and technical analysis have been used to predict stock prices and to produce high returns. With the explosion of data in recent years, however, machine learning techniques have gained significant traction in stock price prediction. Various methods, including support vector machines [6], perceptrons [7], artificial neural networks [8] and decision trees [9], have been applied to this domain and have considerably improved prediction accuracy. Qiu [10] combined global search techniques (GA/SA) with artificial neural networks. Despite this progress, investors have faced significant challenges such as complex feature engineering and poor adaptability. With the rapid development of technology, deep learning has become the primary technique for solving most AI problems. Stoean [11] developed stock prediction models based on LSTM and CNN, respectively, and used their predictions to construct buying and selling strategies. Kim [12] combined LSTM and CNN to predict stock prices from both time series data and stock trend charts. Siami-Namini [13] compared ARIMA and LSTM using the RMSE criterion and demonstrated that LSTM performed better
CWA-LSTM: A Stock Price Prediction Model
than ARIMA. Lu [14] proposed a CNN-BiLSTM-AM method to predict the next day's closing price. Most prediction models can only find potential correlations between variables and cannot reveal the causal relationships between factors and stock prices. Researchers have attempted to introduce causality into this field through the Granger causality test, proposed by C.W.J. Granger in 1969 [15], which stimulated the development of econometrics. Hiemstra and Jones [16] applied nonlinear causality tests to the data and confirmed the validity of the Granger causality test. Zhuo [17] used linear and nonlinear Granger causality tests to find a causal relationship between consumer attitudes surveyed by the University of Michigan and consumption trends in the U.S. However, the Granger causality test lacks a solid foundation in causal theory: it tests whether one series helps predict another, whereas proving causality is recognized to require counterfactual reasoning [18]. Causal relationships are more stable and do not change over time, which makes them of greater interest to equity investors. It is therefore important to study the real causal relationships between stock factors, yet the existing literature shows that identifying them has not received enough attention. Hu [19] proposed an additive noise model with conditional probability to address many-to-one causal discovery in the high-dimensional, dynamic stock market, and effectively uncovered the relationships between portfolio selection factors and stock returns. Zhang [20] developed a causal feature selection (CFS) algorithm based on causality theory to select the optimal set of factors for stock market analysis. Different factors have varying degrees of impact on stock price prediction. We believe that factors with causal relationships are more important for predicting stock prices and should therefore be given more weight in the prediction process.
To address this, we propose a method that combines causal discovery, multi-factor feature weight adjustment, and deep neural networks, resulting in a significant improvement in stock price prediction accuracy. Our approach differs significantly from existing research on stock price prediction.
3 CWA-LSTM Method
Despite the potential benefits of incorporating causal relationships into stock price prediction, there is limited research on this topic. In addition, feature weight adjustments in existing methods are mostly based on correlation, which may not reflect the real causal relationships. In this context, we propose a novel forecasting method, CWA-LSTM (Causal Weight Adjustment-LSTM). The method first uses a causal discovery algorithm to identify causal relationships among multiple stock factors. It then increases the weights of the factors that have a causal relationship with the stock price. Finally, the factors with adjusted weights are fed into an LSTM neural network as the feature set for prediction.
3.1 Multi-Factor Forecasting
In stock price forecasting, great attention is paid to factor selection and model design. Traditional stock forecasting methods typically rely on single factors, such as
opening and closing prices, for stock analysis. However, because such data have low dimensionality, it is difficult to capture the changing patterns of stock prices in detail, which reduces prediction accuracy. Multi-factor analysis, in contrast, mines multiple useful stock factors for prediction. Factor selection involves several considerations. First, the chosen factors need to be representative and able to reflect the actual trading situation and the future trend of the stock. Second, the selected factors should cover different aspects. Factors can be broadly divided into size factors, valuation factors, transaction factors and price factors. In this paper, a total of 33 factors, including price factors and transaction factors, are selected as initial factors. In multi-factor prediction methods, factors are typically assigned initial weights and fed into a neural network, which updates the weights by back-propagation. However, not all factors are equally important to the prediction. We believe that factors with a causal relationship to the predicted values should receive higher weights. We therefore use a causal discovery algorithm to identify the causal relationships between factors and adjust their initial weights accordingly before feeding them into the neural network.
3.2 PC Algorithm and Causal Weighting Adjustment
Traditional stock price prediction methods mainly focus on the correlations among multiple factors of a stock. Causality, however, strictly distinguishes cause variables from effect variables and can reveal the mechanisms behind stock price changes, guiding decisions in a way that correlations cannot. We therefore use a causal discovery method based on inter-temporal data to map out causal relationships and identify the underlying factors that drive stock price changes.
The current mainstream causal discovery methods include constraint-based methods, causal function model-based methods, and hybrid methods. Constraint-based methods have important applications in learning high-dimensional causal structures. These methods learn causal graph structures in a heuristic way by determining conditional independence. Causal function model-based methods start from the causal mechanism of data generation and use the causal function model to determine causal relationships and identify causal directions. Hybrid approaches combine constraint-based and causal function model-based methods in an attempt to effectively improve upon the shortcomings of causal function models while overcoming the challenge of controlling false discovery rates on high-dimensional data.
The PC algorithm is a classical constraint-based method and the pioneering algorithm of the PC family. It relies on the assumptions of causal sufficiency and causal faithfulness. The algorithm proceeds in three steps. The first step starts from a completely undirected graph and uses conditional independence tests to learn the causal skeleton and the corresponding separation sets. The second step orients the V-structures. The third step orients the remaining edges using the acyclicity constraint of the causal graph. The skeleton-learning procedure of the PC algorithm is shown in Table 1. Let adj(G, Xi) denote the set of nodes adjacent to node Xi in graph G, and let sepset(Xi, Xj) be the separation set of nodes Xi and Xj. The PC algorithm starts from a completely undirected graph G and removes redundant undirected edges from G in a loop. Let n be the order of the conditioning set, i.e. the number of conditioning variables. When n = 0, the marginal independence of every pair of variables is tested. If Xi ⊥ Xj, the edge Xi–Xj is deleted and sepset(Xi, Xj) and sepset(Xj, Xi) are set to the empty set. When all pairs have been tested, n is set to 1. For each pair (Xi, Xj) that is still connected by an edge, conditioning sets S ⊆ adj(G, Xi)\{Xj} with |S| = n are selected, and the conditional independence test Xi ⊥ Xj | S is performed. If independence is not rejected, the edge Xi–Xj is deleted and S is saved to sepset(Xi, Xj) and sepset(Xj, Xi). Once all eligible pairs and their conditioning sets have been tested, n is increased again. This continues until |adj(G, Xi)\{Xj}| < n for every pair of adjacent variables. At this point the skeleton-learning phase ends, yielding the causal skeleton G and the separation sets.
Next, the V-structures are identified. For every unshielded triple Xi – Xk – Xj in G, where Xi and Xj are both adjacent to Xk but not to each other, the triple is judged to be a V-structure if Xk does not belong to sepset(Xi, Xj), and it is oriented as Xi → Xk ← Xj. Finally, as many of the remaining undirected edges as possible are oriented by applying a standard set of orientation rules. In causal structure discovery, conditional independence tests determine whether variables are conditionally independent. Given a triplet of random variables (X, Y, Z), a conditional independence test uses independent and identically distributed samples from the joint distribution P(X, Y, Z). It tests the null hypothesis H0: X and Y are conditionally independent given Z, against the alternative H1: X and Y are not conditionally independent given Z. The partial correlation test and the kernel conditional independence test are classical methods, but the former is less stable for higher-order conditional independence judgments. Therefore, in this paper we adopt the kernel conditional independence (KCI) test. In our experiments, we feed the current day's data of 32 feature factors and the next day's closing price into the PC algorithm. We then observe which factors in the causal graph have a causal relationship with the next day's closing price. Next, we add the time series of these causal factors to the original multi-factor feature set by duplicating their columns. This increases the weight of the causal feature factors in the feature set, which is then fed into the neural network so that we can observe the effect on the predicted values.
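The column-duplication step can be sketched in a few lines. The sketch assumes the learned graph is available as a plain adjacency matrix. This is a hypothetical encoding, not the output format of any particular library. The helper collects every variable joined to the target by an edge and appends copies of those columns to the feature matrix. Setting `copies=2` corresponds to the "Double CWA-LSTM" setting evaluated in the experiments. The network's first layer gives each input column its own weight, so a duplicated factor enters the model through two independently initialised weights. This is one way to read the "increased weight" described above.

```python
import numpy as np


def causal_neighbors(adj, target):
    """Indices of variables joined to `target` by an edge in either direction.
    adj[i, j] != 0 is taken to mean that some edge links i and j (hypothetical encoding)."""
    return [j for j in range(adj.shape[0])
            if j != target and (adj[target, j] != 0 or adj[j, target] != 0)]


def duplicate_causal_columns(features, causal_cols, copies=1):
    """Append `copies` extra copies of the causal columns to the feature matrix."""
    features = np.asarray(features, dtype=float)
    extra = [features[:, causal_cols]] * copies
    return np.hstack([features] + extra)


# Toy example: variable 0 plays the role of next_close (X1); 1..3 are factors.
adj = np.zeros((4, 4), dtype=int)
adj[1, 0] = adj[2, 0] = 1          # factors 1 and 2 point to the target
adj[3, 2] = 1                      # factor 3 only affects factor 2
data = np.arange(20.0).reshape(5, 4)

causes = causal_neighbors(adj, target=0)                                   # [1, 2]
X = data[:, 1:]                                                            # drop the target column
X_cwa = duplicate_causal_columns(X, [c - 1 for c in causes])               # 3 + 2 columns
X_double = duplicate_causal_columns(X, [c - 1 for c in causes], copies=2)  # 3 + 4 columns
```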
Table 1. Skeleton learning of the PC algorithm.
Input: set of variables V
Output: causal skeleton G and separation sets sepset
1: Start with a completely undirected graph G on the set of variables V
2: n = −1
3: repeat
4:   n = n + 1
5:   repeat
6:     Select a pair of variables (Xi, Xj) that is connected by an edge in G, has not yet been selected, and satisfies |adj(G, Xi)\{Xj}| ≥ n
7:     repeat
8:       Choose a set of conditioning variables S ⊆ adj(G, Xi)\{Xj} with |S| = n
9:       if Xi and Xj are conditionally independent given S then
10:        Delete the undirected edge between Xi and Xj in G
11:        Let sepset(Xi, Xj) = sepset(Xj, Xi) = S
12:      end if
13:    until Xi and Xj are no longer connected by an edge or all S ⊆ adj(G, Xi)\{Xj} with |S| = n have been chosen
14:  until all pairs of adjacent variables in G satisfying |adj(G, Xi)\{Xj}| ≥ n have been selected
15: until all pairs of adjacent variables (Xi, Xj) in G satisfy |adj(G, Xi)\{Xj}| < n
16: return G, sepset
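Table 1 and the subsequent V-structure orientation step can be sketched compactly in Python. The sketch swaps the paper's KCI test for the simpler Fisher-z partial-correlation test, which assumes roughly linear-Gaussian data. Like the original PC formulation, it updates adjacency sets as edges are removed. The function names are hypothetical, and the final propagation of orientations by the remaining rules is omitted. Any other test with the same signature can be passed as `ci_test`.

```python
import itertools
import math

import numpy as np


def fisher_z_independent(data, i, j, cond, alpha=0.05):
    """Fisher-z test of X_i independent of X_j given X_cond, via partial correlation.
    Returns True when independence is NOT rejected at level alpha."""
    idx = [i, j] + list(cond)
    corr = np.corrcoef(data[:, idx], rowvar=False)
    prec = np.linalg.pinv(corr)                      # precision matrix of the subset
    r = -prec[0, 1] / math.sqrt(prec[0, 0] * prec[1, 1])
    r = min(max(r, -0.999999), 0.999999)             # keep the log finite
    z = 0.5 * math.log((1 + r) / (1 - r)) * math.sqrt(data.shape[0] - len(cond) - 3)
    p_value = math.erfc(abs(z) / math.sqrt(2))       # two-sided normal p-value
    return p_value > alpha


def pc_skeleton(data, alpha=0.05, ci_test=fisher_z_independent):
    """Skeleton phase (Table 1): start fully connected, then delete edges using
    conditioning sets S drawn from adj(G, Xi) minus {Xj} of growing size n."""
    d = data.shape[1]
    adj = {v: set(range(d)) - {v} for v in range(d)}
    sepset = {}
    n = 0
    # Continue while some adjacent pair still has at least n candidate conditioners.
    while any(len(adj[i] - {j}) >= n for i in adj for j in adj[i]):
        for i in range(d):
            for j in sorted(adj[i]):
                for S in itertools.combinations(sorted(adj[i] - {j}), n):
                    if ci_test(data, i, j, S, alpha):
                        adj[i].discard(j)
                        adj[j].discard(i)
                        sepset[(i, j)] = sepset[(j, i)] = set(S)
                        break
        n += 1
    return adj, sepset


def orient_v_structures(adj, sepset):
    """Return arrows (a, k) meaning a -> k for every unshielded triple a - k - b
    whose middle node k is not in sepset(a, b)."""
    arrows = set()
    for k in adj:
        for a, b in itertools.combinations(sorted(adj[k]), 2):
            if b not in adj[a] and k not in sepset.get((a, b), set()):
                arrows.update({(a, k), (b, k)})
    return arrows
```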
3.3 LSTM Algorithm
Our approach uses deep learning algorithms for stock price prediction. Recurrent neural networks (RNNs) have internal memory as well as feedback and feedforward connections, which allow them to handle input sequences of arbitrary length. However, their accuracy tends to decrease on longer sequences, because vanishing and exploding gradients make RNNs difficult to train on long time series. LSTM (Long Short-Term Memory) was developed on the basis of the RNN. It adds three gating units, the forget gate, the input gate and the output gate, to handle long-range dependencies. In addition to approximating complex non-linear relationships, LSTM offers strong learning ability, robustness and fault tolerance. As LSTM is considered one of the most effective models for time series, we choose it for stock price prediction. The LSTM is computed as follows:

$$f_t = \mathrm{sigm}(w_{fx} x_t + w_{fh} h_{t-1} + b_f) \tag{1}$$

$$i_t = \mathrm{sigm}(w_{ix} x_t + w_{ih} h_{t-1} + b_i) \tag{2}$$

$$g_t = \tanh(w_{gx} x_t + w_{gh} h_{t-1} + b_g) \tag{3}$$
$$c_t = f_t \times c_{t-1} + i_t \times g_t \tag{4}$$

$$o_t = \mathrm{sigm}(w_{ox} x_t + w_{oh} h_{t-1} + b_o) \tag{5}$$
where $f_t$ denotes the forget gate, which determines how much of the previous cell state $c_{t-1}$ is retained. $i_t$ denotes the input gate, which determines how much of the new input information is used. $g_t$ denotes the candidate values used to update the cell state $c_t$ at each time step. The current cell state $c_t$ and the output gate $o_t$ together determine the final output $h_t = o_t \times \tanh(c_t)$. Through its input, forget and output gates, the LSTM carries information across long time spans, which keeps useful earlier information available for network training.
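Equations (1)–(5), together with $h_t = o_t \times \tanh(c_t)$, translate directly into one NumPy time step. The dictionary keys mirror the weight names in the equations. The `make_params` helper is a hypothetical convenience; in the experiments these weights are learned by TensorFlow rather than set by hand.

```python
import numpy as np


def sigm(z):
    return 1.0 / (1.0 + np.exp(-z))


def make_params(n_in, n_hidden, rng=None):
    """Weights for Eqs. (1)-(5): random if an rng is given, otherwise all zeros (hypothetical helper)."""
    def mat(rows, cols):
        if rng is None:
            return np.zeros((rows, cols))
        return rng.normal(scale=0.1, size=(rows, cols))
    params = {}
    for gate in "figo":
        params[f"w_{gate}x"] = mat(n_hidden, n_in)
        params[f"w_{gate}h"] = mat(n_hidden, n_hidden)
        params[f"b_{gate}"] = np.zeros(n_hidden)
    return params


def lstm_step(x_t, h_prev, c_prev, p):
    """One LSTM time step; returns the new hidden output h_t and cell state c_t."""
    f_t = sigm(p["w_fx"] @ x_t + p["w_fh"] @ h_prev + p["b_f"])     # Eq. (1) forget gate
    i_t = sigm(p["w_ix"] @ x_t + p["w_ih"] @ h_prev + p["b_i"])     # Eq. (2) input gate
    g_t = np.tanh(p["w_gx"] @ x_t + p["w_gh"] @ h_prev + p["b_g"])  # Eq. (3) candidate values
    c_t = f_t * c_prev + i_t * g_t                                  # Eq. (4) cell state
    o_t = sigm(p["w_ox"] @ x_t + p["w_oh"] @ h_prev + p["b_o"])     # Eq. (5) output gate
    h_t = o_t * np.tanh(c_t)                                        # final output
    return h_t, c_t


# Usage: 5 days of 33 standardized factors through an 80-unit cell
# (80 matches the first LSTM layer size reported in Sect. 4.3).
rng = np.random.default_rng(0)
params = make_params(n_in=33, n_hidden=80, rng=rng)
h, c = np.zeros(80), np.zeros(80)
for x_t in rng.normal(size=(5, 33)):
    h, c = lstm_step(x_t, h, c, params)
```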
4 Experiments
In this section, we evaluate how well our method predicts real-world stock prices and demonstrate the performance improvement that results from adjusting the weights of the causal features.
4.1 Data Sources and Pre-processing
All stock data for the experiment were retrieved with the Tushare package in Python. A total of 33 factors were selected, including the opening price, the next day's closing price, volume, turnover, the adjustment factor and the moving average convergence divergence (MACD). The next day's closing price is used as the forecast objective. Ping An Bank (000001.SZ) is the first joint-stock commercial bank in China to publicly issue shares and is one of the leading retail banks. We chose Ping An Bank, listed on the Shenzhen Stock Exchange, as one of the stocks to study because it is representative. Table 2 shows partial factor data for Ping An Bank. We also chose Bank of Jiangyin (002807.SZ) and Bank of Qingdao (002948.SZ) as data samples, since they are also banking stocks on the Shenzhen Stock Exchange. Stocks in the same industry tend to respond consistently to policy changes and industry momentum, which avoids forecast errors caused by differing policies.
After obtaining the stock data, we standardize the features with the fit_transform method to eliminate the large differences in the orders of magnitude of the stock features. This helps to improve the prediction accuracy and convergence speed of the model. The operation gives each feature a mean of 0 and a variance of 1.
4.2 Causality Discovery Model
All selected multi-factor features are listed in Table 3. To establish the causal relationships between stock features, we apply the PC algorithm to the factors, with a significance level of 0.05 and KCI as the independence test.
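Standardization to zero mean and unit variance can be sketched as follows. The paper does not name the class whose fit_transform method it calls. `ZScoreScaler` below is a hypothetical NumPy stand-in with the same fit/transform pattern. The paper also does not state whether the statistics are fitted on the whole series or only on the training period. This sketch assumes the latter: fitting on the training rows and reusing those statistics for later rows avoids leaking test-period information. The sample rows are the close and pre_close values for 20190520–20190523 from Table 2.

```python
import numpy as np


class ZScoreScaler:
    """Hypothetical stand-in for a scaler with a fit_transform method:
    rescales each column to zero mean and unit variance."""

    def fit(self, X):
        X = np.asarray(X, dtype=float)
        self.mean_ = X.mean(axis=0)
        self.std_ = X.std(axis=0)
        self.std_[self.std_ == 0] = 1.0        # constant columns map to 0 instead of dividing by 0
        return self

    def transform(self, X):
        return (np.asarray(X, dtype=float) - self.mean_) / self.std_

    def fit_transform(self, X):
        return self.fit(X).transform(X)

    def inverse_transform(self, Z):
        """Map standardized values (e.g. model predictions) back to original units."""
        return np.asarray(Z, dtype=float) * self.std_ + self.mean_


# Usage: close and pre_close of Ping An Bank, 20190520-20190523 (Table 2).
features = np.array([[12.38, 12.44], [12.56, 12.38], [12.40, 12.56], [12.29, 12.40]])
train, test = features[:3], features[3:]
scaler = ZScoreScaler()
train_z = scaler.fit_transform(train)   # statistics come from the training rows only
test_z = scaler.transform(test)         # the same statistics are reused for later rows
```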
Table 2. Partial factor data of Ping An Bank.
| trade_date | close | open | high | low | pre_close | ... | boll_mid | boll_lower | cci |
|---|---|---|---|---|---|---|---|---|---|
| 20190520 | 12.38 | 12.35 | 12.54 | 12.25 | 12.44 | ... | 12.622 | 11.054 | -75.583 |
| 20190521 | 12.56 | 12.4 | 12.73 | 12.36 | 12.38 | ... | 12.538 | 11.007 | -45.013 |
| 20190522 | 12.4 | 12.57 | 12.57 | 12.32 | 12.56 | ... | 12.427 | 11.01 | -60.711 |
| 20190523 | 12.29 | 12.24 | 12.42 | 12.14 | 12.4 | ... | 12.339 | 10.96 | -99.072 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 20220523 | 14.83 | 15.07 | 15.07 | 14.76 | 15.02 | ... | 14.708 | 13.763 | 35.003 |
| 20220524 | 14.4 | 14.87 | 14.87 | 14.4 | 14.83 | ... | 14.639 | 13.744 | -50.466 |
| 20220525 | 14.39 | 14.43 | 14.49 | 14.3 | 14.4 | ... | 14.557 | 13.804 | -110.45 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
Table 3. Complete set of multi-factor features
| Number | Factor | Number | Factor | Number | Factor |
|---|---|---|---|---|---|
| X1 | next_close | X12 | open_qfq | X23 | macd |
| X2 | open | X13 | close_hfq | X24 | kdj_k |
| X3 | high | X14 | close_qfq | X25 | kdj_d |
| X4 | low | X15 | high_hfq | X26 | kdj_j |
| X5 | pre_close | X16 | high_qfq | X27 | rsi_6 |
| X6 | change | X17 | low_hfq | X28 | rsi_12 |
| X7 | pct_change | X18 | low_qfq | X29 | rsi_24 |
| X8 | vol | X19 | pre_close_hfq | X30 | boll_upper |
| X9 | amount | X20 | pre_close_qfq | X31 | boll_mid |
| X10 | adj_factor | X21 | macd_dif | X32 | boll_lower |
| X11 | open_hfq | X22 | macd_dea | X33 | cci |
The causal relationships among the factors of the three stocks are illustrated in Fig. 1(a), 1(b) and 1(c), respectively. A one-way arrow represents a causal relationship between two variables, with the tail of the arrow at the cause and the arrowhead at the effect. A two-way arrow indicates either a two-way causal relationship or a common cause of both variables. The causal graph correctly captures the two-way relationship between open (X2) and pre_close (X5). This is consistent with the fact that the day's opening price is largely determined by, or fluctuates near, the previous day's closing price. Similarly, the two-way relationships between the various forward-adjusted (qfq) and backward-adjusted (hfq) prices (X11–X12, X13–X14, X15–X16, X17–X18, X19–X20) are also correctly represented in the
(a) 000001.SZ
(b) 002807.SZ
(c) 002948.SZ
Fig. 1. Causal graphs of the three stocks
causal graph, indicating that the algorithm is effective at discovering directed causal relationships among multiple factors. Because of the large number of factors, our experiment focuses on the factors connected by a causal arrow to the next day's closing price (X1): open (X2), high (X3), low (X4) and the previous closing price (X5). We appended these causal factors to the full multi-factor feature set of the corresponding stock as duplicated columns, thereby weighting these features more heavily, and tested whether this improves the accuracy of the prediction model.
4.3 Model Construction and Parameter Setting
We constructed the experimental model with the TensorFlow framework in Python 3.10. Specifically, we used a Sequential model in tf.keras, consisting of two LSTM layers and one dense layer, to predict stock prices. The first LSTM layer contains 80 neurons and the second contains 100 neurons. The parameters were optimized for 200 epochs with the Adam optimizer, a learning rate of 0.001 and a batch size of 128.
4.4 Evaluation Metrics
Since we forecast stock prices, we use regression metrics to assess forecast accuracy. We choose three metrics: mean squared error (MSE), root mean squared error (RMSE) and mean absolute error (MAE), defined as follows:

$$\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2 \tag{6}$$

$$\mathrm{RMSE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2} \tag{7}$$

$$\mathrm{MAE} = \frac{1}{n}\sum_{i=1}^{n}\left|y_i - \hat{y}_i\right| \tag{8}$$

where $n$ is the number of samples, $y_i$ is the true value and $\hat{y}_i$ is the fitted value. All three metrics measure the deviation between the true value and
the predicted value. Smaller values indicate predictions closer to the true values and hence better model performance.
4.5 Experimental Results and Analysis
Tables 4, 5 and 6 compare the proposed method with a vanilla LSTM model. Our method reduces the prediction error for all three stocks under every evaluation metric. The largest improvement is for Bank of Qingdao (002948.SZ), whose MSE decreases by 67.64%.
Table 4. MSE of predicted values under different causal weights.
| Stocks | LSTM | CWA-LSTM | Double CWA-LSTM |
|---|---|---|---|
| 000001.SZ | 0.196097 | 0.172315 | 0.257173 |
| 002807.SZ | 0.007361 | 0.006939 | 0.007196 |
| 002948.SZ | 0.076136 | 0.024638 | 0.027733 |

Table 5. RMSE of predicted values under different causal weights.
| Stocks | LSTM | CWA-LSTM | Double CWA-LSTM |
|---|---|---|---|
| 000001.SZ | 0.442828 | 0.415108 | 0.507122 |
| 002807.SZ | 0.085796 | 0.083299 | 0.084828 |
| 002948.SZ | 0.275928 | 0.156966 | 0.166531 |

Table 6. MAE of predicted values under different causal weights.
| Stocks | LSTM | CWA-LSTM | Double CWA-LSTM |
|---|---|---|---|
| 000001.SZ | 0.348909 | 0.326530 | 0.397873 |
| 002807.SZ | 0.060039 | 0.053066 | 0.056839 |
| 002948.SZ | 0.159213 | 0.114280 | 0.127671 |
To assess whether increasing the weights of the causal feature factors consistently improves the prediction, we also tested adding the causal feature columns more than once (Double CWA-LSTM adds two copies). As Tables 4, 5 and 6 show, the best prediction results are obtained when one copy of the causal feature columns is added. Adding two copies of the same causal features, however, makes the prediction for Ping An Bank worse than the baseline. Meanwhile, although the prediction effect of
CWA-LSTM: A Stock Price Prediction Model
Jiangyin Bank and Bank of Qingdao remains better than the baseline but worse than that of CWA-LSTM. A possible explanation is that increasing the weight of the causal features also amplifies the noise in these factors.
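The causal weighting compared in Tables 4–6 duplicates the causal columns once (CWA-LSTM) or twice (Double CWA-LSTM). The paper gives no code for this step, so the following is a minimal plain-Python sketch under stated assumptions: features are stored as rows of floats, and the causal factors (Open, High, Low, previous close) are identified by column index. The function name `add_causal_copies` and the column layout in the example are hypothetical.

```python
from typing import List, Sequence


def add_causal_copies(features: Sequence[Sequence[float]],
                      causal_idx: Sequence[int],
                      copies: int = 1) -> List[List[float]]:
    """Append `copies` duplicates of the causal columns to every feature row.

    copies=1 corresponds to the CWA-LSTM setting (one extra column per causal
    factor) and copies=2 to "Double CWA-LSTM".
    """
    weighted = []
    for row in features:
        extra = [row[j] for j in causal_idx] * copies  # causal values, repeated
        weighted.append(list(row) + extra)
    return weighted


# Hypothetical column layout: [Open, High, Low, PrevClose, Volume]
rows = [[10.0, 10.5, 9.8, 10.1, 1.2e6],
        [10.2, 10.6, 10.0, 10.3, 1.5e6]]
causal = [0, 1, 2, 3]                                # Open, High, Low, previous close
single = add_causal_copies(rows, causal, copies=1)   # 5 + 4 = 9 columns
double = add_causal_copies(rows, causal, copies=2)   # 5 + 8 = 13 columns
```

One design consideration: the LSTM applies an affine map to its input. A duplicated column is therefore algebraically equivalent to a single column whose weight is the sum of the two copies. Any effect of duplication comes from initialization and training dynamics rather than from extra model capacity.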
5 Conclusion and Future Work
In this paper, we propose CWA-LSTM, a novel method that introduces causality into stock price prediction by adjusting the weights of causal factors, so that the causal factors receive greater weight among the multi-factor features. The method integrates causal discovery, multi-factor prediction, and the LSTM algorithm. Experiments on real stock data show that it outperforms the original LSTM model and effectively reduces stock price prediction errors. Mining causal relationships in data is a promising research direction, but challenges remain. Based on our research, future directions include:
1) Incorporating more features and conducting causal discovery to identify stable and reliable causal relationships and to minimize the directional errors caused by Markov equivalence classes.
2) Combining our model with advanced time-series models such as GRU and Transformer.
3) Investigating the rules governing causal weight adjustment for stocks and determining specific weighting parameters, to improve the generalizability of the model to different kinds of data.
StPrformer: A Stock Price Prediction Model Based on Convolutional Attention Mechanism
Zhaoguo Liu, Qihang Zhang, Da Huang (corresponding author), and Dan Wu
Institute for Quantum Information, State Key Laboratory of High Performance Computing, College of Computer Science and Technology, National University of Defense Technology, Changsha 410073, China
Abstract. Stock price prediction is a crucial task in quantitative trading. Recent advances in deep learning have sparked interest in using neural networks to identify stock market patterns. However, existing deep learning models are limited in exploring long-range dependencies in time-series data and in capturing local features, which makes it difficult for them to reflect the impact of feature factors on stock prices. To address this, we propose StPrformer, a stock price prediction model based on a convolutional attention mechanism. The model uses convolutional attention to mine temporal dependencies between stock prices and feature factors. In addition, the convolutional layer in the encoder provides direct prior information about the input features for prediction. Our experiments show that StPrformer achieves higher prediction accuracy than existing deep learning models: on the CSI 300 index, compared with the classical Transformer prediction model, it reduces the mean absolute error and mean squared error by 33.3% and 26.1%, respectively. These results demonstrate the generality and superiority of StPrformer.
Keywords: StPrformer · Convolutional Attention · Stock Price Prediction
1 Introduction
Time-series prediction is an important research area with significant applications in finance, healthcare, meteorology, transportation, power scheduling, and other domains. Since the 1990s, prediction methods based on statistical theory have dominated the field. In the early 21st century, methods based on traditional machine learning began to show promise. More recently, with advances in deep learning theory and growth in computing power, deep learning-based methods have outperformed traditional methods on many challenging machine learning tasks. As a result, applying deep learning to time-series prediction has become a new research frontier that attracts increasing attention.

Advances in artificial intelligence have led to the gradual integration of neural networks into a broad range of tasks, including intelligent recognition, image detection, and text translation. Neural networks have powerful learning capabilities and can acquire many features that humans cannot discern, and their application has helped resolve numerous challenging problems. In light of this performance in other fields, some researchers have used neural networks to predict stock prices and to uncover the underlying logic of the stock market. Compared with traditional time-series methods, neural networks have significantly improved prediction accuracy. In particular, Transformer-based models have a significant advantage in modeling long-term dependencies in sequential data because of their self-attention mechanism, which gives them stronger learning capability. Nevertheless, existing Transformer-based stock price prediction models do not adapt their network structure to the characteristics of stock prices, leaving considerable room for improvement in generalization and prediction accuracy. Specifically, the Vanilla Transformer was designed to model long time series, whereas the volatility of stock prices calls for short input and output sequences in this forecasting problem. Moreover, the dot-product computation in the self-attention mechanism of the Vanilla Transformer learns only the dependencies between individual time points and is insensitive to the local features of the series, so the model cannot accurately capture the impact of feature factors on stock prices. In contrast, convolutional layers tend to generalize better and converge faster because of their strong inductive bias, and they can effectively learn the local features of time-series data. Based on these considerations, we propose StPrformer, a stock price prediction model based on a convolutional attention mechanism.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 433–444, 2023. https://doi.org/10.1007/978-981-99-4761-4_37
Z. Liu et al.
StPrformer adopts residual connections and the encoder-decoder structure but replaces the self-attention mechanism with a convolutional attention mechanism. This enables the model to capture both the local relationships of feature factors, through the convolutional layer, and the long-range dependencies of time-series data. The contributions of this paper are summarized as follows:
• We propose a novel convolutional attention mechanism that integrates convolutional layers with self-attention and, based on it, introduce a new stock price prediction model called StPrformer. Our experiments show that StPrformer achieves better prediction results than existing models.
• In experiments on the CSI 300 index and on individual stocks from the finance, pharmaceutical, and technology sectors, we compare StPrformer with three baseline models, namely RNN, LSTM, and Vanilla Transformer. StPrformer substantially reduces both MSE and MAE, which supports its effectiveness and generalizability.
2 Related Work
Early stock forecasting methods relied on statistical and economic knowledge to analyze stock prices and used relatively simple stock characteristic data for forecasting. The most commonly used models included the autoregressive model (AR) [1], the autoregressive moving average model (ARMA) [2], and the autoregressive integrated moving average model (ARIMA) [3]. While statistically based autoregressive
StPrformer: A Stock Price Prediction Model
models have some advantages in computational efficiency because they rely on a single feature, they are overly simplistic and low-dimensional. This limits their ability to model non-linear, non-stationary financial time series and makes them sensitive to anomalous data. With the rapid development of big data, deep learning has made significant breakthroughs owing to its powerful learning and feature extraction capabilities. The primary deep learning techniques include convolutional neural networks (CNN), recurrent neural networks (RNN), and long short-term memory (LSTM) networks. Fischer and Krauss [4] built stock prediction models with LSTM and showed that they outperformed random forests and deep neural networks. Selvin [5] proposed three stock prediction models based on CNN, RNN, and LSTM, respectively, and compared them by predicting the stock prices of listed companies. The study concluded that LSTM, with its better memory, is most suitable for stock time-series prediction. Zhou [6] proposed a general framework for predicting high-frequency stock markets in which LSTM and CNN are trained adversarially; the experiments showed that it improves the accuracy of stock price prediction and reduces prediction errors. Zhang [7] combined a discrete optimization algorithm with the LSTM model to build a novel forecasting model for future stock returns and obtained good forecasting results. The Transformer is an encoder-decoder framework that efficiently extracts features from sequential data. It is designed to capture dependencies between elements of long sequences, allows parallel training, and offers powerful feature extraction, multimodal fusion, and interpretability.
The success of the Transformer has led to its increasing use in various fields, including image recognition, natural language processing, and economic and securities forecasting [8, 9]. For example, Ding [10] proposed a hierarchical multi-scale Gaussian-improved Transformer to capture the long-term and short-term dependencies of stock time series. Li [11] introduced an attention mechanism built on a gated recurrent unit (GRU) structure to focus on stock feature information at important time points, accurately reflecting stock price movement patterns. Yang [12] used the Transformer as an autoencoder for user data to model user behavior in depth.
3 StPrformer
The StPrformer model is based on the Transformer architecture. Its value-embedding and time-encoding modules extract time-series features from stock data. Both the encoder and the decoder use a convolutional attention mechanism to establish long-range dependencies across the time series. To extract temporal dependencies between stock prices and feature factors, each convolutional attention module, incorporated into both the encoder and the decoder, combines a dilated convolutional layer with a multi-head attention mechanism. The final prediction is produced by linear-layer regression. Figure 1 illustrates the structure of the model.
Fig. 1. StPrformer model structure.
3.1 Time Coding
Unlike recurrent neural networks, which process data sequentially, the self-attention mechanism of the Transformer eliminates sequential operations and allows parallel computation. As a consequence, however, the input carries no inherent temporal ordering, even though ordering is critical for capturing dependencies between elements of a time series. Temporal position information must therefore be embedded into the input series. One approach is to use timestamps to indicate the position of each element in the series. We call this module time encoding; it consists of two parts, relative time encoding and absolute time encoding.
Relative Time Coding. Relative time coding encodes the relative positions of time points within a time series of length L. The relative time encoding of the input sequence is

$$\mathrm{RTE}(t_i, 2j) = \sin\left(\frac{t_i - t_1}{10000^{2j/d_{model}}}\right) \qquad (1)$$
$$\mathrm{RTE}(t_i, 2j+1) = \cos\left(\frac{t_i - t_1}{10000^{2j/d_{model}}}\right) \qquad (2)$$

where RTE(·) denotes the relative time encoding; $t_i$ denotes the timestamp at position i in the sequence; $t_1$ denotes the first time point of the input sequence; j indexes the dimension; $i \in \{1, 2, \ldots, L\}$ and $j \in \{0, 1, \ldots, d_{model}/2 - 1\}$.
Absolute Time Coding. Because the timestamps themselves carry valuable information, encoding only the relative positions of points within the sequence is insufficient. We therefore express absolute time in the same form as the relative encoding, using the time difference between each timestamp in the input sequence and a fixed reference timestamp $t_c$. We regard this formulation as more concise and clearer than other existing methods, with an equivalent effect. The absolute time encoding is

$$\mathrm{ATE}(t_i, 2j) = \sin\left(\frac{t_i - t_c}{10000^{2j/d_{model}}}\right) \qquad (3)$$
$$\mathrm{ATE}(t_i, 2j+1) = \cos\left(\frac{t_i - t_c}{10000^{2j/d_{model}}}\right) \qquad (4)$$

The time encoding of an input sequence of length L and dimension $d_{model}$ is the sum of the relative and absolute time encodings:

$$\mathrm{TimeEnc}(t_i, j) = \mathrm{RTE}(t_i, j) + \mathrm{ATE}(t_i, j) \qquad (5)$$

where $t_i$ denotes the timestamp at position i in the sequence, j denotes the dimension, $i \in \{1, 2, \ldots, L\}$, and $j \in \{0, 1, \ldots, d_{model} - 1\}$.
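A minimal plain-Python sketch of Eqs. (1)–(5) follows. It assumes that timestamps are numeric (for example, trading days counted from some origin) and that the caller supplies the reference timestamp t_c; the paper specifies neither. Dimension indices are 0-based, so each pair (2j, 2j+1) shares the frequency 1/10000^(2j/d_model). The function names are illustrative only.

```python
import math
from typing import List, Sequence


def _sinusoid(delta: float, d_model: int) -> List[float]:
    """Sine/cosine vector for a time offset, in the form of Eqs. (1)-(4)."""
    if d_model % 2:
        raise ValueError("d_model must be even")
    vec = [0.0] * d_model
    for j in range(d_model // 2):
        angle = delta / (10000 ** (2 * j / d_model))
        vec[2 * j] = math.sin(angle)       # even dimension: sine
        vec[2 * j + 1] = math.cos(angle)   # odd dimension: cosine
    return vec


def time_encoding(timestamps: Sequence[float], t_c: float,
                  d_model: int) -> List[List[float]]:
    """TimeEnc = RTE + ATE for every timestamp, as in Eq. (5)."""
    t_1 = timestamps[0]
    out = []
    for t in timestamps:
        rte = _sinusoid(t - t_1, d_model)  # relative to the first point, Eqs. (1)-(2)
        ate = _sinusoid(t - t_c, d_model)  # relative to a fixed reference, Eqs. (3)-(4)
        out.append([r + a for r, a in zip(rte, ate)])
    return out
```

The resulting L × d_model matrix is what would be added to the value embedding to form the encoder input.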
3.2 Encoder-Decoder
StPrformer's encoder is composed of N identical layers, each consisting of a dilated convolutional attention layer and a position-wise feed-forward network layer. Both sub-layers employ residual connections and layer normalization. The input to the encoder is the sum of the Value Embedding and the Time Encoding of the original input sequence X. The Value Embedding projects the data from the d-dimensional input space to the $d_{model}$-dimensional model space so that the dimensions match, which facilitates the subsequent residual connections and layer stacking. The Time Encoding supplies the timestamp information of the embedded sequence. The decoder consists of M identical layers, each containing three sub-layers. In addition to the sub-layers of the encoder, the decoder contains a further multi-head attention module in which the decoder's intermediate result serves as the query Q and the encoder's outputs serve as the keys K and values V; it computes the degree of correlation between Q and K and uses it to weight V.
3.3 Convolutional Attention Mechanism
The Transformer is a general-purpose neural network model designed to capture long-term dependencies, but it makes almost no assumptions about the input data and features. It therefore needs a large amount of data to train effectively without overfitting. Given sufficient data, experiments in fields such as computer vision and natural language processing have shown that Transformer models can outperform convolutional and recurrent models. For a limited amount of data and computation, however, the Transformer may not outperform other state-of-the-art models, which reflects its lack of inductive bias. Convolutional neural networks benefit from inductive bias, which lets them converge faster and generalize better, whereas self-attention models have stronger learning capability and perform better on large datasets.
However, training sets for time-series prediction are typically much smaller than those in computer vision or natural language processing. Combining the convolutional layer with the attention mechanism can therefore leverage the advantages of both model types and improve the generalization and accuracy of the time-series prediction model. Figure 2 illustrates the convolutional attention module.
Fig. 2. Convolutional attention module.
Dilated Convolution. The Vanilla Transformer computes attention with dot products, which capture the dependencies among time points in a sequence but are insensitive to its local features. Consequently, noise or fluctuations at some points of a feature sequence, such as stock prices, can lead the self-attention layer to learn incorrect temporal dependencies and bias the predictions. To overcome this limitation, a convolution layer is needed to learn the local features of the sequence, and this layer must not let future information leak into the past. Causal convolution, unlike conventional convolution, prevents such leakage: the output at time t is obtained by convolving only the elements at time t and earlier, which preserves the causality of the time-series data. However, the receptive field of a plain causal convolution grows only linearly with network depth, which limits its use in time-series tasks, particularly those requiring long histories. To make full use of the local features of the sequences, and building on the ability of convolutional neural networks to capture local relevance, we use dilated convolution to extract local features for the self-attention module. Dilated convolution enlarges the receptive field without pooling, so no information is discarded, and each convolutional output covers a larger range of the input. Whereas a standard convolution applies its kernel to adjacent inputs, a dilated convolution inserts gaps (zeros) between the kernel elements, so that the kernel is applied to inputs spaced several steps apart (the dilation rate). This effectively increases the receptive field of the convolution operation.
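The dilated causal convolution described above can be sketched for a single channel in plain Python. The kernel values, the dilation rate, and the ReLU default activation are illustrative assumptions because the text does not state them. With dilation 1 the function reduces to the plain causal form of Eq. (6). With kernel size k and dilation r, one layer sees (k − 1)·r + 1 time steps.

```python
from typing import Callable, List, Sequence


def dilated_causal_conv1d(
    x: Sequence[float],
    kernel: Sequence[float],
    dilation: int = 1,
    activation: Callable[[float], float] = lambda v: max(0.0, v),  # ReLU (assumed)
) -> List[float]:
    """Single-channel dilated causal convolution.

    y[t] = activation(sum_m kernel[m] * x[t - (k - 1 - m) * dilation]).
    Positions before the start of the series count as zero (left padding),
    so y[t] never depends on x[t + 1:].
    """
    k = len(kernel)
    y = []
    for t in range(len(x)):
        acc = 0.0
        for m, w in enumerate(kernel):
            idx = t - (k - 1 - m) * dilation  # only current and past positions
            if idx >= 0:
                acc += w * x[idx]
        y.append(activation(acc))
    return y


# Example: a 2-tap kernel with dilation 2 mixes x[t] with x[t - 2].
print(dilated_causal_conv1d([1.0, 2.0, 3.0, 4.0, 5.0], [1.0, 1.0], dilation=2))
# [1.0, 2.0, 4.0, 6.0, 8.0]
```

Stacking such layers with dilations 1, 2, 4, … makes the receptive field grow exponentially with depth rather than linearly, which addresses the limitation of plain causal convolution noted above.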
The output of the dilated convolution at time t is obtained by convolving only the elements at time t and earlier. The output $y_t^{i}$ of convolution kernel i of size k at time t can be described as:

$$y_t^{i} = \mathrm{Activation}\left(W^{(i)}\left(x_{t-k+1}, x_{t-k+2}, \ldots, x_t\right)\right) \qquad (6)$$

where $W^{(i)}$ denotes the convolution operation with kernel i and Activation denotes the activation function. The kernel size k sets the extent of the local context perceived in the sequence.
Masked Multi-Head Attention. The self-attention mechanism of the Vanilla Transformer can identify and retain the correlation information between all elements of a sequence. It does so by learning the relevance of each input element to the target element and weighting the features according to their relative importance. To improve prediction accuracy, StPrformer uses masks to prevent information leakage along the temporal dimension. By restricting
correlation calculations to the historical direction, the model focuses its final decision on the feature dimensions that contribute positively to the prediction target. The attention mechanism of StPrformer is an enhanced version of self-attention and can be written compactly in matrix form as

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\mathrm{mask}\left(\frac{QK^{T}}{\sqrt{D_k}}\right)\right)V \qquad (7)$$

It involves a query matrix $Q \in \mathbb{R}^{L \times D}$, a key matrix $K \in \mathbb{R}^{L \times D}$, and a value matrix $V \in \mathbb{R}^{L \times D}$. A weight score is computed for each element of the sequence from these matrices, and the softmax function normalizes the scores so that they sum to 1, yielding the importance of each point on both the time and space scales. The softmax weights are then multiplied by the value matrix V to obtain the attention output. The mask prevents future information from leaking into the computation, and the weighting emphasizes strongly correlated features over weakly correlated ones, enabling the model to focus on essential information. When Q, K, and V are all derived from the same sequence, the mechanism is called self-attention, and the network learns the parameters that produce Q, K, and V. In practice, multiple sets of Q, K, and V are used to project the inputs into h different subspaces; computing the attention weight scores separately in each subspace strengthens the learning capability of the network by capturing different categories of correlation in the sequence. This approach is called multi-head attention. The output in each subspace is computed with the same formula, and after these parallel computations the outputs of all subspaces are concatenated and projected back to the $d_{model}$-dimensional model space.
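Equation (7) with a causal mask can be sketched in plain Python for a single head. This sketch assumes that the mask is lower-triangular, so that position i may attend only to positions up to i, which matches the aim of restricting correlations to the historical direction. Masked positions are skipped rather than set to −∞ before the softmax, which yields the same weights. The learned Q/K/V projections and the splitting into h heads are omitted for brevity.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def masked_attention(Q: Sequence[Sequence[float]],
                     K: Sequence[Sequence[float]],
                     V: Sequence[Sequence[float]]) -> Matrix:
    """softmax(mask(Q K^T / sqrt(D_k))) V with a causal (lower-triangular) mask.

    Q and K are L x D_k, V is L x D_v. Row i uses only key/value rows 0..i.
    """
    d_k = len(K[0])
    out: Matrix = []
    for i, q in enumerate(Q):
        # Scaled dot-product scores against past and current positions only.
        scores = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d_k)
                  for j in range(i + 1)]
        m = max(scores)                       # subtract max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]   # non-negative, sums to 1
        # Weighted sum of value rows: a matrix product, not an element-wise one.
        out.append([sum(w * V[j][c] for j, w in enumerate(weights))
                    for c in range(len(V[0]))])
    return out
```

Multi-head attention would run this function on h different learned projections of the input and concatenate the h outputs before a final linear projection.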
Because the attention distributions in different subspaces differ and multiple independent heads can relate different information in the sequence, multi-head attention gives StPrformer a more powerful feature extraction capability.
Feedforward Network, Residual Connection and Layer Normalization. The feed-forward network layer of StPrformer comprises two convolutional layers with a kernel size of 1. This design choice enhances the model's ability to fit complex temporal relationships. The layer can be expressed as:

$$F(X) = \mathrm{Conv}'\left(\mathrm{ReLU}\left(\mathrm{Conv}(X)\right)\right) \qquad (8)$$

where Conv(·) and Conv′(·) denote the two convolutional layers, connected through an inner dimension d. Residual connections address the difficulty of training deep neural networks: the identity shortcut makes the network easier to optimize and mitigates vanishing gradients. In addition, layer normalization normalizes the outputs of the different layers, which accelerates training and enables the model to converge more quickly.
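The feed-forward sub-layer of Eq. (8), wrapped in a residual connection and layer normalization, can be sketched in plain Python. A kernel-size-1 convolution over time is the same linear map applied at every time step, so `conv1x1` is written as a per-step matrix product. Two choices here are assumptions because the text does not specify them: the post-norm ordering LayerNorm(X + F(X)), as in the original Transformer, and the omission of learnable layer-norm gain and bias.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def conv1x1(X: Sequence[Sequence[float]], W: Sequence[Sequence[float]],
            b: Sequence[float]) -> Matrix:
    """Kernel-size-1 convolution over time: the same linear map at every step.

    X is L x d_in, W is d_in x d_out, and b has length d_out.
    """
    return [[sum(x[i] * W[i][o] for i in range(len(x))) + b[o] for o in range(len(b))]
            for x in X]


def feed_forward(X, W1, b1, W2, b2) -> Matrix:
    """F(X) = Conv'(ReLU(Conv(X))), following Eq. (8)."""
    hidden = [[max(0.0, v) for v in row] for row in conv1x1(X, W1, b1)]
    return conv1x1(hidden, W2, b2)


def layer_norm(row: Sequence[float], eps: float = 1e-5) -> List[float]:
    """Normalize one time step's feature vector to zero mean and unit variance."""
    mu = sum(row) / len(row)
    var = sum((v - mu) ** 2 for v in row) / len(row)
    return [(v - mu) / math.sqrt(var + eps) for v in row]


def ffn_sublayer(X, W1, b1, W2, b2) -> Matrix:
    """Residual connection followed by layer normalization: LayerNorm(X + F(X))."""
    F = feed_forward(X, W1, b1, W2, b2)
    return [layer_norm([x + f for x, f in zip(xr, fr)]) for xr, fr in zip(X, F)]
```

In a trained model, W1 would map d_model to the inner dimension d and W2 would map it back. The residual sum therefore requires the output width to equal d_model.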
4 Experiments
4.1 Experimental Description
Features and Datasets. Data and features determine the upper limit of what machine learning can achieve, and feature engineering plays a crucial role in stock price prediction, with a significant impact on prediction accuracy. The stock market has accumulated a vast amount of historical data reflecting its price movements. However, more features do not necessarily yield higher prediction accuracy, so the most valuable input indicators must be selected carefully to improve the model's prediction performance. To this end, in addition to historical stock price information, two indicators were chosen as feature factors: the Price Earnings Ratio (P/E) and the Moving Average Convergence/Divergence (MACD). The P/E ratio is the market price per share divided by the earnings per share and is often used to judge whether a stock is undervalued or overvalued. The MACD indicator reflects the market trend through the relationship between fast and slow moving averages. Given the large number of listed companies in China, and to improve the objectivity and generalizability of the analysis, stock index prices and individual stock prices are predicted separately. For the stock index, the CSI 300 index is selected. For individual stocks, three typical stocks from three popular sectors are selected: finance (ICBC), pharmaceuticals (Pharmaceuticals of China), and technology (China Mobile). All data cover the period from 1 January 2013 to 1 January 2023.
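The MACD feature above can be sketched in plain Python, together with the chronological 7:1:2 split and the MAE-based early-stopping rule described in the training setup that follows. The MACD parameters (12, 26, 9) are the conventional defaults; they and the EMA seeding are assumptions, since the paper does not state them. `run_epoch` and `val_mae` are hypothetical callbacks standing in for a real training loop.

```python
from typing import Callable, List, Sequence, Tuple


def ema(values: Sequence[float], span: int) -> List[float]:
    """Exponential moving average, alpha = 2 / (span + 1), seeded with the first value."""
    alpha = 2.0 / (span + 1)
    out = [float(values[0])]
    for v in values[1:]:
        out.append(alpha * v + (1.0 - alpha) * out[-1])
    return out


def macd(close: Sequence[float], fast: int = 12, slow: int = 26,
         signal: int = 9) -> Tuple[List[float], List[float]]:
    """MACD line (fast EMA minus slow EMA) and its signal line (EMA of the MACD line)."""
    line = [f - s for f, s in zip(ema(close, fast), ema(close, slow))]
    return line, ema(line, signal)


def chronological_split(samples: Sequence, ratios: Tuple[int, int, int] = (7, 1, 2)):
    """Split in time order (no shuffling) into train / validation / test lists."""
    n, total = len(samples), sum(ratios)
    n_train = n * ratios[0] // total
    n_val = n * ratios[1] // total
    return (list(samples[:n_train]),
            list(samples[n_train:n_train + n_val]),
            list(samples[n_train + n_val:]))


def train_with_early_stopping(run_epoch: Callable[[int], None],
                              val_mae: Callable[[], float],
                              max_epochs: int = 30, patience: int = 5):
    """Track the epoch with the lowest validation MAE; stop after `patience`
    consecutive epochs whose MAE does not beat the best so far."""
    best_mae, best_epoch, bad_rounds = float("inf"), -1, 0
    for epoch in range(max_epochs):
        run_epoch(epoch)           # one pass over the training set
        score = val_mae()          # MAE on the validation set
        if score < best_mae:
            best_mae, best_epoch, bad_rounds = score, epoch, 0  # a real loop saves weights here
        else:
            bad_rounds += 1
            if bad_rounds >= patience:
                break
    return best_epoch, best_mae
```

Splitting in time order rather than shuffling keeps every validation and test sample later than all training samples, so no future prices leak into training.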
The data are divided chronologically into training, validation, and test sets in the ratio 7:1:2, and samples are generated from each subset to form the training, validation, and test sample sets. All data were retrieved through the Tushare Python interface.
Model Parameters and Training Methods. StPrformer is trained with the Adam optimizer at a learning rate of 0.001, with Mean Squared Error (MSE) as the loss function and a batch size of 64. Training runs for up to 30 epochs over the training set. At the end of each epoch, the Mean Absolute Error (MAE) of the current model is evaluated on the validation set; if it is lower than the best MAE so far, the model is saved and the best MAE is updated. If the validation MAE is larger than the best MAE for five consecutive epochs, the model is considered to have reached its learning limit and training is stopped (early stopping). This strategy helps prevent overfitting and improves training efficiency. The best-performing model on the validation set is used as the final model, and its performance is evaluated on the test set. To further validate StPrformer, we selected three existing deep learning models, namely RNN, LSTM, and Vanilla Transformer, for comparison.
Model Evaluation Metrics. To assess the prediction accuracy of the models quantitatively, we use two widely used metrics: Mean Absolute Error (MAE) and Mean Squared Error (MSE). The MAE measures the absolute difference between
the predicted and true values, while the MSE, by squaring the errors, is more sensitive to large errors. Smaller values of both MAE and MSE indicate predictions closer to the true values and thus higher prediction accuracy. The metrics are defined as follows:
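$$\mathrm{MAE}(y, \hat{y}) = \frac{1}{n}\sum_{i=1}^{n}\left|y_i - \hat{y}_i\right| \qquad (9)$$
$$\mathrm{MSE}(y, \hat{y}) = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2 \qquad (10)$$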
where n is the total sample size, and $y_i$ and $\hat{y}_i$ denote the true and predicted values, respectively.
4.2 Results and Analysis
CSI 300 Index Forecasting. To validate the forecasting performance of StPrformer, we compared it side by side with the three baseline models on the CSI 300 index. Each experiment was repeated several times so that no single run could skew the outcome, and Table 1 reports the average results of the different models.

Table 1. Prediction results of CSI 300 index by different models.
Models               MAE    MSE
RNN                  2.473  9.947
LSTM                 1.749  3.924
Vanilla Transformer  1.308  3.182
StPrformer           0.872  2.352
Table 1 shows that the classical time-series model, RNN, has the highest MAE and MSE values, likely because of its limited capacity to fit strongly volatile time series such as stock prices. LSTM, an improved variant of the RNN, mitigates vanishing and exploding gradients during training and therefore predicts better, although its performance is still unsatisfactory. The Vanilla Transformer, which relies on self-attention, outperforms both RNN and LSTM owing to its ability to extract temporal information quickly through parallel computation, which reduces prediction error. Its performance can still be improved, however, as it may not effectively extract the local features of the factors. StPrformer, with its powerful feature extraction and learning capabilities, achieves the smallest MAE and MSE. Its MAE and MSE are 33.3% and 26.1% lower, respectively, than those of the Vanilla Transformer, and clearly better than those of all three baseline models.
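As a check, the reductions quoted above follow directly from the Table 1 averages for Vanilla Transformer and StPrformer:

```latex
\begin{aligned}
\text{MAE: } \frac{1.308 - 0.872}{1.308} &= \frac{0.436}{1.308} \approx 0.333 = 33.3\% \\
\text{MSE: } \frac{3.182 - 2.352}{3.182} &= \frac{0.830}{3.182} \approx 0.261 = 26.1\%
\end{aligned}
```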
Fig. 3. CSI 300 index forecast results: (a) RNN, (b) LSTM, (c) Vanilla Transformer, (d) StPrformer.
Figure 3 presents the predictions of the four deep learning models for the CSI 300 index over a specific period. The index was normalized to make the differences between the models easier to see. As Fig. 3 shows, RNN's predictions are erratic and highly volatile. LSTM is slightly better than RNN, but its fit remains poor, leaving ample room for improvement. Vanilla Transformer, based on the attention mechanism, predicts future trends more accurately, but its predicted values fluctuate strongly. StPrformer, by contrast, accurately captures inflection points with large trend changes and outperforms Vanilla Transformer. Its predicted and true curves almost overlap, which minimizes the error between predicted and true values and gives the best regression fit. These observations visually confirm the fitting ability and superiority of StPrformer.
Individual Stock Price Forecasting. To confirm the robustness and generalizability of StPrformer, we selected three representative stocks from three industries: ICBC for finance, Pharmaceuticals of China for pharmaceuticals, and China Mobile for technology. We then ran individual stock price prediction experiments with StPrformer and the baseline models. As shown in Table 2, StPrformer achieved the lowest MAE and MSE on every stock, which provides evidence of its robustness and generalizability for stock price prediction across industries.
StPrformer: A Stock Price Prediction Model
Table 2. Prediction results of typical individual stock prices by different models.
Model                 Stock                      MAE     MSE
RNN                   ICBC                       0.296   0.352
RNN                   Pharmaceuticals of China   0.278   0.343
RNN                   China Mobile               0.281   0.349
LSTM                  ICBC                       0.212   0.254
LSTM                  Pharmaceuticals of China   0.224   0.268
LSTM                  China Mobile               0.243   0.281
Vanilla Transformer   ICBC                       0.187   0.224
Vanilla Transformer   Pharmaceuticals of China   0.154   0.207
Vanilla Transformer   China Mobile               0.177   0.218
StPrformer            ICBC                       0.132   0.184
StPrformer            Pharmaceuticals of China   0.117   0.163
StPrformer            China Mobile               0.124   0.176
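As a worked example, the relative error reductions of StPrformer with respect to the Vanilla Transformer can be computed directly from Table 2. This sketch assumes the first value in each row is MAE and the second MSE, following the header order; the dictionary layout and function name are ours.

```python
# (MAE, MSE) per stock from Table 2; column order assumed to follow the header.
vanilla = {"ICBC": (0.187, 0.224), "Pharmaceuticals of China": (0.154, 0.207),
           "China Mobile": (0.177, 0.218)}
stprformer = {"ICBC": (0.132, 0.184), "Pharmaceuticals of China": (0.117, 0.163),
              "China Mobile": (0.124, 0.176)}

def reductions(base, new):
    """Percentage reduction of each error metric, rounded to one decimal."""
    return {stock: tuple(round(100 * (b - n) / b, 1) for b, n in zip(base[stock], new[stock]))
            for stock in base}

for stock, (d_mae, d_mse) in reductions(vanilla, stprformer).items():
    print(f"{stock}: MAE -{d_mae}%, MSE -{d_mse}%")
```

On these values the reductions come to about 24-30% for MAE and 18-21% for MSE across the three stocks.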
5 Conclusions
To enhance the accuracy of stock price time-series prediction, we propose a novel model called StPrformer, which is based on a convolutional attention mechanism. This model addresses the limitation of existing deep learning models in capturing the impact of feature factors on stock prices. StPrformer uses a masked multi-head attention layer to extract the time-series dependence between historical stock prices and feature factors, and incorporates a dilated convolution layer that provides direct a priori information on the features. Our experimental results demonstrate that StPrformer significantly outperforms classical deep learning models such as RNN, LSTM and Vanilla Transformer, and exhibits superior feature extraction, generalization and prediction abilities.
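The dilated convolution mentioned above can be illustrated with a generic single-channel sketch. Its configuration in StPrformer is not given in this section, so the causal (left-only) zero padding, the kernel and the dilation values here are assumptions, not the authors' layer.

```python
import numpy as np

def dilated_causal_conv1d(x, w, dilation):
    """y[t] = sum_j w[j] * x[t - j*dilation]; positions before t = 0 are zero-padded."""
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    pad = (len(w) - 1) * dilation
    xp = np.concatenate([np.zeros(pad), x])      # left padding keeps the filter causal
    y = np.zeros_like(x)
    for t in range(len(x)):
        for j in range(len(w)):
            y[t] += w[j] * xp[pad + t - j * dilation]
    return y

print(dilated_causal_conv1d([1, 2, 3, 4, 5], w=[1, 1], dilation=2))
```

With kernel size K and dilation r, each output sees (K - 1)·r + 1 past steps, so stacking layers with growing dilation widens the receptive field without enlarging the kernel.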
Diagnosis of Lung Cancer Subtypes by Combining Multi-graph Embedding and Graph Fusion Network
Siyu Peng, Jiawei Luo(B), Cong Shen, and Bo Wang
The College of Computer Science and Electronic Engineering, Hunan University, Changsha 410083, Hunan, China
Abstract. Cancer is a highly heterogeneous disease, and subtype diagnosis is a crucial stage in the whole process of cancer treatment. In recent years, the accumulation of multi-omics data has provided more complete support for the diagnosis of cancer subtypes. However, the large differences between omics types, and the incorrect connections that arise when graphs are constructed from them, remain significant obstacles to data fusion. Methods based on multi-view learning can effectively alleviate this problem. In this paper, we propose a novel method for Diagnosis of lung cancer subtypes by combining Multi-graph Embedding and Graph Fusion network (DMEGF). A graph fusion network generates a consensus graph of the multi-omics data, and a multi-graph autoencoder with an attention mechanism learns the common representation. In addition, the similarity of adjacent features in the common representation is maintained using mutual information. Finally, the comprehensive representation is classified to diagnose cancer subtypes. Experiments on TCGA lung cancer datasets show that DMEGF achieves good diagnostic performance and good interpretability.
Keywords: Cancer subtyping · Deep learning · Graph embedding · Graph convolution · Mutual information
1 Introduction
Lung cancer is one of the malignant tumors most dangerous to human health and life; its incidence and mortality are among the highest of all cancer types [13]. Pathologically, lung cancer is divided into two main types: small cell lung cancer and non-small cell lung cancer (NSCLC). NSCLC is the most common type, accounting for about 85% of cases [23]. NSCLC can be further categorized into three main histopathological subtypes: adenocarcinoma (LUAD, 35–40%), squamous cell carcinoma (LUSC, 30–35%), and large cell carcinoma (less than 10%). Different cancer subtypes respond differently to treatment and therefore require different treatment methods. Accurately predicting cancer subtypes is thus meaningful and necessary: it helps us understand cancer evolution and genomic mechanisms, improves diagnostic accuracy, and helps doctors design effective treatments [21].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 445–456, 2023. https://doi.org/10.1007/978-981-99-4761-4_38
Many methods have been used to diagnose lung cancer subtypes. Medical diagnosis relies mainly on pathology and imaging, and in recent years various omics-based methods have been developed and applied to cancer subtype diagnosis [10, 12]. For example, [19] used an automated deep learning framework for image analysis to capture broad histopathological features from whole H&E-stained tissue slide images and predict the four molecular subtypes of gastric adenocarcinoma. However, such image-based methods lack the ability to reveal the underlying pathology [16], whereas multi-omics data provide more comprehensive biological information and better describe organisms. Identifying the molecular subtypes of cancer from multi-omics data is therefore a challenging problem.
Recent work has improved diagnostic performance by fusing multi-omics data through multi-view learning (MVL). Compared with single-view data, multi-view data provide more comprehensive features and richer information [4, 7]. Existing MVL-based methods mainly rely on matrix factorization, co-training, canonical correlation analysis and multiple kernel learning to integrate heterogeneous omics data [8, 11]. Nevertheless, they capture only limited relationships between omics, and the limited information obtained constrains further gains in diagnostic performance. Many problems in multi-view learning also remain unsolved. Large differences between the data of different views can make the views inconsistent, potentially distorting the similarity matrix that describes similar samples [22]. In addition, differences in feature size across views make feature fusion difficult [24]. To overcome these challenges, we propose a method called Diagnosis of lung cancer subtypes by combining Multi-graph Embedding and Graph Fusion network (DMEGF), shown in Fig. 1.
We first obtain the predefined graph of each omics dataset and generate a consensus graph, then use a multi-head attention layer to convolve the original features into a comprehensive latent representation, and use graph-specific decoders, together with the consensus graph, for feedback. In addition, mutual information maximization promotes the fusion of the features of each omics. Finally, DMEGF classifies the synthesized representation to diagnose subtypes. The method is described in detail below. Our experiments on TCGA [6] lung cancer data demonstrate that DMEGF achieves better diagnostic performance and interpretability than other methods.
2 Methods
Our method supplements and improves the multi-view clustering method CMGEC [18]; the DMEGF model built on it is described below. Given a multi-omics dataset $X = \{x^{v}\}_{v=1}^{V}$ with $N$ samples and $V$ types of omics data, $x^{v} \in \mathbb{R}^{N \times d_v}$ denotes the $v$-th omics data with $d_v$-dimensional features, and $Y \in \{1, \dots, t\}^{N}$ denotes the subtypes of the $N$ samples.
Fig. 1. The illustration of the DMEGF model. The model consists of four modules: Graph Fusion Network (GFN), Multiple Graph AutoEncoder (M-GAE), Multi-graph Mutual Information Maximization (MMIM) and Multiple Layer Perceptron (MLP). First, predefined graphs are constructed for the preprocessed omics data, and GFN fuses them into a consensus graph. Then the consensus graph, together with the original features and graph of each omics, is fed into M-GAE to learn a comprehensive latent representation, with a multi-graph attention fusion network introduced to better integrate the data. In addition, MMIM maintains the similarity of adjacent features in the learned comprehensive representation. Finally, an MLP classifies the final representation to produce the subtype diagnosis.
2.1 Data Preprocessing
Because single-omics data have a high missing rate, we cannot guarantee that all required omics data are available for every patient sample, so the missing part of each omics must be supplemented. Based on each omics data $x^{v}$ ($v = 1, \dots, V$), we take the union of the patient samples, use zeros as the expression values of patient samples missing from an omics, and then apply a weight to the resulting complete multi-omics dataset $X$:
$$X_{i}^{v} = w^{v} \odot x_{i}^{v}$$ (1)
where $w^{v} \in \mathbb{R}^{d_v}$ is a softmax-normalized weight used to select informative features, and $\odot$ is the element-wise multiplication operator. We then construct the predefined graph and its adjacency matrix from the obtained $X_{i}^{v}$. The nodes of the predefined graph are the patients, and the number of edges is set by the k-NN construction, using the value that minimizes the error between the omics features and their corresponding labels.
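The preprocessing above can be sketched as zero-filled alignment over the union of patient IDs, the softmax-weighted features of Eq. (1), and a k-NN patient graph. The dictionary input format, Euclidean distance and symmetrization of the k-NN graph are assumptions; the label-driven choice of k is not reproduced, and in the model the weight logits would be learned rather than fixed.

```python
import numpy as np

def align_union(omics):
    """omics: list of {sample_id: feature list}, one dict per omics type.
    Returns the sorted union of IDs and one zero-filled N x d_v matrix per omics."""
    ids = sorted(set().union(*[o.keys() for o in omics]))
    mats = []
    for o in omics:
        d = len(next(iter(o.values())))
        m = np.zeros((len(ids), d))
        for i, s in enumerate(ids):
            if s in o:               # missing samples stay all-zero
                m[i] = o[s]
        mats.append(m)
    return ids, mats

def softmax(v):
    v = np.asarray(v, dtype=float)
    e = np.exp(v - v.max())
    return e / e.sum()

def weight_features(x, logits):
    """Eq. (1): X^v = w^v (element-wise) x^v, with w^v = softmax(logits)."""
    return x * softmax(logits)[None, :]

def knn_adjacency(x, k):
    """Symmetric 0/1 k-NN graph over patients (Euclidean distance, no self-loops)."""
    dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)
    nn = np.argsort(dist, axis=1)[:, :k]
    a = np.zeros((len(x), len(x)))
    a[np.repeat(np.arange(len(x)), k), nn.ravel()] = 1.0
    return np.maximum(a, a.T)
```

Note that zero-filled rows lie close to one another, so in practice patients missing an omics may cluster together in that omics' k-NN graph.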
2.2 Graph Fusion Network
In the graph fusion network, we use fully connected layers for learning, and the consensus graph learned by the $l$-th layer can be expressed as:
$$G^{l} = \delta_l\left(W_{GFN}^{(l)} G^{l-1} + b_{GFN}^{(l)}\right)$$ (2)
where $W_{GFN}^{(l)}$ is the weight matrix of the $l$-th layer of the network and $b_{GFN}^{(l)}$ is the corresponding bias. $\delta_l(\cdot)$ is the activation function of the $l$-th layer; with $len$ denoting the total number of layers, it is defined as
$$\delta_l(x) = \begin{cases} \mathrm{relu}(x), & l \neq len - 1 \\ x, & \text{otherwise} \end{cases}$$ (3)
Because the predefined graph of each omics contains different adjacencies, we combine them through the fully connected layer to generate a consensus graph for the entire fusion network:
$$A^{*} = \delta\left(\sum_{v=1}^{V}\left(W_{GFN(l)}^{v} A^{v} + b_{GFN(l)}^{v}\right)\right)$$ (4)
To avoid omitting information, the consensus graph $A^{*}$ should integrate the different graphs as comprehensively as possible, so we compute the loss between the consensus graph and the graph of each omics, using MSE as the loss function. The loss of the graph fusion network is therefore
$$L_{ge} = \sum_{v=1}^{V} loss\left(A^{*}, A^{v}\right)$$ (5)
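Eqs. (4) and (5) can be sketched for a single fully connected fusion layer. The shapes are assumptions (each W^v is an N x N matrix left-multiplying A^v and each b^v an N x N bias), as is averaging the MSE over matrix entries; the toy graphs are hypothetical.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def fuse_graphs(adjs, weights, biases, act=relu):
    """Eq. (4) for one layer: A* = act(sum_v (W^v @ A^v + b^v))."""
    return act(sum(w @ a + b for w, a, b in zip(weights, adjs, biases)))

def graph_fusion_loss(a_star, adjs):
    """Eq. (5) with entry-wise MSE: L_ge = sum_v mean((A* - A^v)^2)."""
    return float(sum(np.mean((a_star - a) ** 2) for a in adjs))

# Toy example: averaging two 3-node graphs via W^v = 0.5 I, b^v = 0.
A1 = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 0]], dtype=float)
A2 = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
A_star = fuse_graphs([A1, A2], [0.5 * np.eye(3)] * 2, [np.zeros((3, 3))] * 2)
print(A_star, graph_fusion_loss(A_star, [A1, A2]))
```

Edges shared by both graphs keep weight 1 in A*, while edges present in only one graph get 0.5, so the loss is driven entirely by the disagreements between views.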
In fact, the consensus graph obtained so far could already be used for classification. However, the graph fusion process uses only structural information and ignores node features, so further processing is needed.
2.3 Multiple Graph AutoEncoder
We design a separate convolution layer for each omics and apply the following operation to the predefined graph and original features of each omics:
$$z_{l}^{v} = \sigma_l\left(\tilde{T}^{-\frac{1}{2}} \tilde{A} \tilde{T}^{-\frac{1}{2}} z_{l-1}^{v} W_{GCN(l)}\right)$$ (6)
where $\sigma_l(\cdot)$ is an activation function analogous to $\delta_l(\cdot)$, $\tilde{A} = A^{v} + I$, $\tilde{T}_{ii} = \sum_{j} \tilde{A}_{ij}$, $I$ is the identity matrix, and $W_{GCN(l)}$ is the parameter matrix learned by the GCN. After obtaining the feature representation of each omics, we concatenate them into a comprehensive feature representation $Z$. To make full use of the consensus graph and make the learned representation more suitable for subtype diagnosis, we apply a further GCN layer to $Z$ and the consensus graph $A^{*}$ to obtain a common representation $Z'$:
$$Z' = \sigma\left((T^{*})^{-\frac{1}{2}} A^{*} (T^{*})^{-\frac{1}{2}} Z W_{l}\right)$$ (7)
where $T^{*}$ is the degree matrix of $A^{*}$.
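The GCN propagation of Eqs. (6) and (7) amounts to symmetric degree normalization followed by a linear map and an activation. In this sketch `add_self_loops=True` corresponds to adding I to A^v as in Eq. (6); calling it with `add_self_loops=False` on the consensus graph approximates Eq. (7) under the assumption that T* is the plain degree matrix of A*. ReLU as the activation is also an assumption.

```python
import numpy as np

def normalize_adjacency(a, add_self_loops=True):
    """T^{-1/2} A T^{-1/2}, where T is the degree matrix of A (optionally A + I)."""
    a = np.asarray(a, dtype=float)
    if add_self_loops:
        a = a + np.eye(len(a))
    d_inv_sqrt = 1.0 / np.sqrt(a.sum(axis=1))
    return a * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(a, z, w, add_self_loops=True, act=lambda x: np.maximum(x, 0.0)):
    """One GCN step: act(norm(A) @ Z @ W)."""
    return act(normalize_adjacency(a, add_self_loops) @ z @ w)
```

The per-omics outputs would then be concatenated, e.g. `np.concatenate([z1, z2], axis=1)`, to form Z before the consensus-graph layer.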
Next, $Z'$ is passed through the graph-specific decoders to obtain the reconstructed multi-omics graphs $\hat{A}^{1}, \dots, \hat{A}^{V}$. We then minimize the difference between each original adjacency graph $A^{v}$ and its reconstruction $\hat{A}^{v}$:
$$L_{rec} = \sum_{v=1}^{V} loss\left(A^{v}, \hat{A}^{v}\right)$$ (8)
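The graph-specific decoder is not specified in this section. The sketch below assumes an inner-product decoder with a per-omics projection P^v (a hypothetical parameter), a common choice in graph autoencoders, and uses entry-wise MSE for loss(·,·) in Eq. (8) by analogy with Eq. (5).

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def decode_graph(z, proj):
    """Hypothetical decoder for one omics: A_hat^v = sigmoid(H H^T), H = Z' P^v."""
    h = z @ proj
    return sigmoid(h @ h.T)

def reconstruction_loss(adjs, z, projs):
    """Eq. (8) with entry-wise MSE: sum_v mean((A^v - A_hat^v)^2)."""
    return float(sum(np.mean((a - decode_graph(z, p)) ** 2) for a, p in zip(adjs, projs)))
```

Sharing Z' across all decoders while giving each omics its own projection lets the common representation be checked against every view's graph separately.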
2.4 Multi-graph Mutual Information Maximization
By the properties of mutual information, the larger the mutual information, the more similar the samples are [2, 20]. The optimization goal of this module can therefore be expressed as
$$\max \{ I(X, Z^{-}) \}$$ (9)
where $Z^{-}$ is the nearest-neighbor representation of sample $X$, found with the k-NN algorithm. Our model uses the JS divergence form of mutual information, so the loss function of MMIM is
$$L_{mim} = -JS\left(p(z^{-} \mid x)p(x) \,\|\, p(z^{-})p(x)\right)$$ (10)
According to [9], the JS divergence can be expressed as
$$JS\left(p(x) \,\|\, q(x)\right) = \mathbb{E}_{x \sim p(x)}\left[\log \rho\left(\log \frac{2p(x)}{p(x)+q(x)}\right)\right] + \mathbb{E}_{x \sim q(x)}\left[\log\left(1 - \rho\left(\log \frac{2p(x)}{p(x)+q(x)}\right)\right)\right]$$ (11)
We can then compute $L_{mim}$ by combining Eq. (10) and Eq. (11). In summary, the total loss of the M-GAE module is
$$L_{gae} = L_{rec} + L_{mim}$$ (12)
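Eq. (11) can be estimated from samples by replacing the log-density ratio with a learned critic score. In this sketch ρ is assumed to be the sigmoid, the critic is a plain dot product, positive pairs match each sample with its k-NN neighbour (one neighbour index per sample here), and negative pairs are drawn by shuffling; none of these choices is confirmed by the source.

```python
import numpy as np

def log_sigmoid(s):
    """Numerically stable log(sigmoid(s))."""
    return -np.logaddexp(0.0, -s)

def js_mi_estimate(x, z, nbr, rng):
    """Sample estimate of Eq. (11) with rho = sigmoid and a dot-product critic.
    x, z: N x d representations; nbr[i]: index of sample i's nearest neighbour."""
    pos = np.sum(x * z[nbr], axis=1)                       # joint pairs (x_i, z_i^-)
    neg = np.sum(x * z[rng.permutation(len(z))], axis=1)   # shuffled (marginal) pairs
    # log(1 - sigmoid(s)) equals log_sigmoid(-s)
    return float(np.mean(log_sigmoid(pos)) + np.mean(log_sigmoid(-neg)))

def mim_loss(x, z, nbr, rng):
    """Eq. (10): L_mim = -JS(...)."""
    return -js_mi_estimate(x, z, nbr, rng)
```

Both terms are at most zero, so the estimate is bounded above by 0 and the loss is non-negative; minimizing the loss pushes neighbour pairs to score high and shuffled pairs to score low.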
2.5 Multiple Layer Perceptron
Finally, the representation $Z'$ is fed into the classification network to obtain the subtype result. We use a multiple layer perceptron to classify the lung cancer subtypes $Y$, with cross-entropy as the classification loss:
$$L_{mlp} = \frac{1}{N}\sum_{i=1}^{N} CE\left(f(z_i), y_i\right)$$ (13)
where $z_i \in Z'$ and $y_i \in Y$ are the feature representation and subtype label of the $i$-th patient, respectively. During training, each module is trained in turn within every epoch using its own loss function before the next epoch begins. In this way, DMEGF captures information from the different omics data and fuses it effectively, yielding more accurate cancer subtype diagnoses.
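Eq. (13) is the average cross-entropy over patients. A numerically stable sketch follows, assuming f(z_i) outputs unnormalized logits and labels are 0-based (the source indexes subtypes from 1).

```python
import numpy as np

def cross_entropy_loss(logits, labels):
    """Eq. (13): mean over patients of CE(f(z_i), y_i), with f(z_i) given as logits."""
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=1, keepdims=True)       # avoid overflow in exp
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-np.mean(log_probs[np.arange(len(labels)), labels]))
```

With two subtypes and uninformative (all-zero) logits the loss equals ln 2, which is a useful sanity check at the start of training.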
3 Results
3.1 Experimental Settings
We test our method on the TCGA lung cancer datasets processed by [16], classifying LUSC and LUAD, the two main subtypes of NSCLC, and we also perform ablation experiments on the components of DMEGF. The samples are randomly split 1:1 into a training set and a validation set. To obtain a more comprehensive assessment, four metrics are used to evaluate diagnostic performance: accuracy (ACC), AUROC, F1-score and AUPRC.
3.2 Comparison of Diagnosis Results
We download tissue whole slide images (WSI) [15], DNA methylation and miRNA [5] data to evaluate our model. These omics data are summarized in Table 1. The total number of samples is the union over the three omics types, and missing omics data are filled with zeros for alignment.
Table 1. Statistical information of TCGA Lung cancer datasets.
Omics             LUAD   LUSC   Feature size
WSI               506    488    3 × 256 × 256
DNA methylation   458    370    9816
miRNA             513    478    1881
Total             522    504
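The four evaluation metrics listed in Sect. 3.1 can be computed with NumPy alone, as sketched below. Which subtype is treated as the positive class is not stated in the source, so it is left to the caller; `auprc` computes average precision, which assumes untied scores and may differ slightly from trapezoidal PR-curve areas reported by some libraries.

```python
import numpy as np

def accuracy(y, pred):
    return float(np.mean(np.asarray(y) == np.asarray(pred)))

def f1(y, pred):
    y, pred = np.asarray(y), np.asarray(pred)
    tp = np.sum((pred == 1) & (y == 1))
    fp = np.sum((pred == 1) & (y == 0))
    fn = np.sum((pred == 0) & (y == 1))
    return float(2 * tp / (2 * tp + fp + fn)) if tp else 0.0

def auroc(y, score):
    """Probability that a random positive outranks a random negative (ties count 1/2)."""
    y, score = np.asarray(y), np.asarray(score, dtype=float)
    diff = score[y == 1][:, None] - score[y == 0][None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

def auprc(y, score):
    """Average precision: mean precision at the rank of each positive."""
    y = np.asarray(y)[np.argsort(-np.asarray(score, dtype=float), kind="stable")]
    precision = np.cumsum(y) / np.arange(1, len(y) + 1)
    return float(np.sum(precision * y) / y.sum())
```

ACC and F1-score depend on a decision threshold, whereas AUROC and AUPRC summarize the ranking of scores, which is why the source reports both kinds.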
We compare against four baseline methods, PAN-19 [3], MDNNMD [14], GIMPP [1] and LungDIG [17], selecting the optimal parameters within the recommended range for each. In each round, we run each algorithm five times and keep the best result; we repeat this for ten rounds and report the average of the ten best results. The results are shown in Table 2, and DMEGF achieves the best value on every metric. Comparing LungDIG with MDNNMD shows that introducing an attention mechanism makes the data produced by multi-dimensional fusion more helpful for subtype diagnosis. Compared with LungDIG, which also uses attention, DMEGF improves ACC by 4.3, AUROC by 3.5, F1-score by 4.9 and AUPRC by 4.3 percentage points, indicating that its gains do not come from the attention mechanism alone. Compared with GIMPP, the second-best method, DMEGF improves ACC by 1.1, AUROC by 0.2, F1-score by 1.3 and AUPRC by 1.2 percentage points, indicating that the representation learned by our method is effective.
Table 2. Comparative results of subtype diagnosis on TCGA lung cancer datasets.
Method    ACC     AUROC   F1-score  AUPRC
MDNNMD    0.8493  0.8936  0.8609    0.8992
LungDIG   0.8901  0.8983  0.8837    0.9072
PAN-19    0.9039  0.9212  0.9019    0.9207
GIMPP     0.9223  0.9317  0.9195    0.9384
DMEGF     0.9336  0.9339  0.9328    0.9506
3.3 Parameter Analysis
Because our lung cancer data require rebuilding the k-NN graph, the number of nearest neighbors $k_G$ is an important parameter, so we examine how performance changes with $k_G$. Since the number of edges in a graph grows with $k_G$, a larger $k_G$ may increase the cost of graph operations and convolution and slow the model down, so we run experiments over $k_G \in [3, 15]$. Fig. 2 shows that our method is relatively insensitive to $k_G$: the diagnostic ACC remains above 91%. This suggests that DMEGF can learn the structural information in graph data well even when the number of edges is small. With $k_G = 6$ the model performs slightly better than with other values, so we set $k_G = 6$ in our experiments.
Fig. 2. Diagnostic results of DMEGF when kM = 6.
In MMIM, $k_M$ nearest neighbors must be selected for each graph node, so we also test how varying $k_M$ from 1 to 10 affects the diagnosis results, as shown in Fig. 3. The model performs best when $k_M = 3$; the evaluation metrics change only slightly as $k_M$ increases, and the overall ACC remains above 91.5%. We therefore use $k_M = 3$ for all nodes.
Fig. 3. Diagnostic results of DMEGF when kG = 3.
3.4 Ablation Study
To verify the contribution of each module to subtype diagnosis, we conduct ablation experiments on three variants, DMEGF-GFN, DMEGF-MMIM and DMEGF-MREC, which remove the GFN loss, the MMIM loss and the M-GAE reconstruction loss, respectively. Table 3 reports the best results of DMEGF and its variants over five runs. Removing any one of these losses lowers every metric, so the modules jointly account for the model's subtype diagnostic performance.
We also run the algorithm five times with different combinations of omics data and report the best results in Table 4. With single-omics input, performance is very poor: for all three omics types the accuracy is only about 0.51–0.55, the AUROC is exactly 0.5 (chance level), and the F1-scores are also low. Using two or more omics types yields a very large improvement, showing that our model fuses different omics data effectively. Using all three omics types outperforms every two-omics combination, indicating that integration diversifies the patient characteristics and makes the information obtained more comprehensive, compensating for the limitations of single-omics data and improving diagnostic performance. We also plot PR curves based on the AUPRC results for the different omics inputs (Fig. 4) to compare them more intuitively.
Table 3. Performance comparison of DMEGF and its three variants.
Model        ACC     AUROC   F1-score  AUPRC
DMEGF-GFN    0.9220  0.9220  0.9206    0.9401
DMEGF-MMIM   0.9240  0.9242  0.9237    0.9394
DMEGF-MREC   0.9201  0.9198  0.9175    0.9383
DMEGF        0.9376  0.9377  0.9368    0.9514
Table 4. Comparison of diagnosis results for different omics combinations: single-omics data, two omics data and all three omics data are input into the model.
Input             ACC     AUROC   F1-score  AUPRC
WSI               0.5091  0.5     0.6586    0.7455
Methy             0.5531  0.5     0.6177    0.7234
miRNA             0.5172  0.5     0.6512    0.7414
WSI+Methy         0.9045  0.9046  0.9037    0.9264
WSI+miRNA         0.9064  0.9063  0.9040    0.9330
Methy+miRNA       0.9130  0.9123  0.9084    0.9353
WSI+Methy+miRNA   0.9376  0.9377  0.9368    0.9514
Fig. 4. PR curve when using different omics data.
3.5 Clinical Analysis
In this section, we study how the NSCLC subtypes diagnosed by our model relate to prognosis, treatment and survival. As shown in Fig. 5(A), the two survival curves differ clearly, and the P-value of 0.042 (< 0.05) indicates that the difference is statistically significant, supporting the effectiveness of our model's diagnoses. To further characterize each diagnosed subtype, we download from TCGA the gene expression data of the patients in the model validation set, remove duplicate genes by keeping the maximum expression of each gene to obtain the count matrix, and identify, for each subtype, genes that are at least 2-fold differentially expressed between tumor and normal samples with a corrected P-value < 0.05. This differential expression analysis is performed with the R package DESeq2, and the resulting volcano plot for each subtype is shown in Fig. 5(B), displaying the magnitude and significance of the expression differences in each subtype.
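The gene-selection rule above (at least 2-fold change and corrected P-value < 0.05) can be sketched as a filter on log2 fold changes and adjusted p-values. The source does not name the correction method; this sketch assumes Benjamini-Hochberg, which is the default adjustment in DESeq2's `results()`, and treats "2-fold" as |log2 FC| ≥ 1. The gene names and values in any usage are hypothetical.

```python
import numpy as np

def benjamini_hochberg(p):
    """Benjamini-Hochberg adjusted p-values."""
    p = np.asarray(p, dtype=float)
    n = len(p)
    order = np.argsort(p)
    ranked = p[order] * n / np.arange(1, n + 1)
    adj_sorted = np.minimum.accumulate(ranked[::-1])[::-1]   # enforce monotonicity
    adj = np.empty(n)
    adj[order] = np.minimum(adj_sorted, 1.0)
    return adj

def significant_genes(genes, log2fc, pvalues, fold=2.0, alpha=0.05):
    """Genes with |log2 fold change| >= log2(fold) and adjusted p < alpha."""
    padj = benjamini_hochberg(pvalues)
    keep = (np.abs(np.asarray(log2fc, dtype=float)) >= np.log2(fold)) & (padj < alpha)
    return [g for g, k in zip(genes, keep) if k]
```

In a volcano plot these two thresholds appear as vertical lines at log2 FC = ±1 and a horizontal line at -log10(0.05); the selected genes are the points beyond both.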
Fig. 5. (A) Kaplan-Meier survival curves of the two subtypes diagnosed by our model. (B) Volcano plots obtained by differential expression analysis of tumor samples versus normal samples (left: LUAD; right: LUSC). (C) GO functional enrichment analysis (top: LUAD; bottom: LUSC).
We further perform gene enrichment analysis on the lung cancer subtypes diagnosed by DMEGF and visualize the top 10 GO terms (Fig. 5(C)). The enriched GO terms of the two subtypes are broadly similar, but they differ markedly in the molecular function (MF) category. This may help explain the survival difference between the LUAD and LUSC subtypes, and suggests that our model can effectively distinguish them.
4 Conclusion
In this paper, we propose an approach called Diagnosis of lung cancer subtypes by combining Multi-graph Embedding and Graph Fusion network (DMEGF). DMEGF processes the data with a multi-graph attention fusion layer, uses the Graph Fusion Network (GFN) to learn a consensus graph of the multi-omics data, uses the Multiple Graph AutoEncoder (M-GAE) to learn a latent representation of the integrated data, and applies graph-specific decoders to reconstruct the multiple views. Finally, a Multiple Layer Perceptron (MLP) performs subtype diagnosis on the fused feature representation. Experiments on the TCGA lung cancer datasets show that DMEGF diagnoses cancer subtypes more accurately than the state-of-the-art methods it was compared with.
Funding. This work has been supported by the National Natural Science Foundation of China (Grant No. 62032007).
References
1. Arya, N., Saha, S.: Generative incomplete multi-view prognosis predictor for breast cancer: GIMPP. IEEE/ACM Trans. Comput. Biol. Bioinf. 19(4), 2252–2263 (2021)
2. Bachman, P., Hjelm, R.D., Buchwalter, W.: Learning representations by maximizing mutual information across views. In: Advances in Neural Information Processing Systems 32 (2019)
3. Cheerla, A., Gevaert, O.: Deep learning with multimodal representation for pancancer prognosis prediction. Bioinformatics 35(14), i446–i454 (2019)
4. Chen, Y., Xiao, X., Zhou, Y.: Jointly learning kernel representation tensor and affinity matrix for multi-view clustering. IEEE Trans. Multimedia 22(8), 1985–1997 (2019)
5. Chu, A., et al.: Large-scale profiling of microRNAs for the cancer genome atlas. Nucleic Acids Res. 44(1), e3–e3 (2016)
6. Clark, K., et al.: The Cancer Imaging Archive (TCIA): maintaining and operating a public information repository. J. Digit. Imaging 26, 1045–1057 (2013)
7. Gao, X., Mu, T., Goulermas, J.Y., Wang, M.: Topic driven multimodal similarity learning with multi-view voted convolutional features. Pattern Recogn. 75, 223–234 (2018)
8. Gligorijević, V., Pržulj, N.: Methods for biological data integration: perspectives and challenges. J. R. Soc. Interface 12(112), 20150571 (2015)
9. Jiang, W., Liu, W., Chung, F.-I.: Knowledge transfer for spectral clustering. Pattern Recogn. 81, 484–496 (2018)
10. Lehman, C.D., Wu, S.: Stargazing through the lens of AI in clinical oncology. Nat. Cancer 2(12), 1265–1267 (2021)
11. Li, Y., Wu, F.-X., Ngom, A.: A review on machine learning principles for multi-view biological data integration. Brief. Bioinform. 19(2), 325–340 (2018)
12. Menyhárt, O., Győrffy, B.: Multi-omics approaches in cancer research with applications in tumor subtyping, prognosis, and diagnosis. Comput. Struct. Biotechnol. J. 19, 949–960 (2021)
13. Siegel, R.L., Miller, K.D., Fuchs, H.E., Jemal, A.: Cancer statistics, 2021. CA Cancer J. Clin. 71(1), 7–33 (2021)
14. Sun, D., Wang, M., Li, A.: A multimodal deep neural network for human breast cancer prognosis prediction by integrating multi-dimensional data. IEEE/ACM Trans. Comput. Biol. Bioinf. 16(3), 841–850 (2018)
15. Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., Wojna, Z.: Rethinking the inception architecture for computer vision. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2818–2826 (2016)
16. Wang, X., Yu, G., Wang, J., Zain, A.M., Guo, W.: Lung cancer subtype diagnosis using weakly-paired multi-omics data. Bioinformatics 38(22), 5092–5099 (2022)
17. Wang, X., Yu, G., Yan, Z., Wan, L., Wang, W., Lizhen, L.C.C.: Lung cancer subtype diagnosis by fusing image-genomics data and hybrid deep networks. IEEE/ACM Trans. Comput. Biol. Bioinform. (2021)
18. Wang, Y., Chang, D., Fu, Z., Zhao, Y.: Consistent multiple graph embedding for multi-view clustering. IEEE Trans. Multimedia (2021)
19. Wang, Y., et al.: DEMoS: a deep learning-based ensemble approach for predicting the molecular subtypes of gastric adenocarcinomas from histopathological images. Bioinformatics 38(17), 4206–4213 (2022)
20. Wen, J., Han, N., Fang, X., Fei, L., Yan, K., Zhan, S.: Low-rank preserving projection via graph regularized reconstruction. IEEE Trans. Cybern. 49(4), 1279–1291 (2018)
21. Yang, Y., Tian, S., Qiu, Y., Zhao, P., Zou, Q.: MDICC: novel method for multi-omics data integration and cancer subtype identification. Brief. Bioinform. 23(3), bbac132 (2022)
22. Yin, M., Gao, J., Xie, S., Guo, Y.: Multiview subspace clustering via tensorial t-product representation. IEEE Trans. Neural Netw. Learn. Syst. 30(3), 851–864 (2018)
23. Zappa, C., Mousa, S.A.: Non-small cell lung cancer: current treatment and future advances. Transl. Lung Cancer Res. 5(3), 288 (2016)
24. Zheng, Y.: Methodologies for cross-domain data fusion: an overview. IEEE Trans. Big Data 1(1), 16–34 (2015)
Detformer: Detect the Reliable Attention Index for Ultra-long Time Series Forecasting
Xiangxu Meng1, Wei Li1,2(B), Zheng Zhao1, Zhihan Liu1, Guangsheng Feng1, and Huiqiang Wang1
1 College of Computer Science and Technology, Harbin Engineering University, Harbin 150001, China
{mxx,wei.li,zhaozheng,lzhlzh,fengguangsheng,wanghuiqiang}@hrbeu.edu.cn
2 Modeling and Emulation in E-Government National Engineering Laboratory, Harbin Engineering University, Harbin 150001, China
Abstract. Long sequence time-series forecasting is a challenging task that arises in many areas of production and daily life. It requires a model that efficiently predicts the future from temporal dependencies in the past. Although Transformer-based solutions deliver state-of-the-art forecasting performance, two issues remain for ultra-long time series. First, existing solutions use heuristic black-box sampling to reduce the quadratic time complexity of canonical self-attention, which leads to numerical instability and loss of accuracy. Second, attention-based models cannot be applied directly because they lack temporal modelling capability. To tackle these issues, we propose a stable and accurate model, named Detformer, which achieves $O(L \log L)$ time complexity. Specifically, we design a dual-feedback sparse attention mechanism to eliminate the poor numerical stability of heuristic sparse attention, and then propose a temporal dependency extraction mechanism that enables Detformer to perform temporal modelling from the perspective of the attention index. We further propose a noise-eliminating algorithm that identifies reliable attention to improve temporal modelling. Extensive experiments on four benchmark datasets demonstrate the effectiveness of Detformer and the efficiency of its dual-feedback attention mechanism.
Keywords: Time series · Forecasting · Transformer · Dual-feedback
1 Introduction
Time-series forecasting plays an important role in applications such as financial markets, climate change and disease propagation. Generally, we model temporal dependence from past time-series data and then predict unknown conditions over a long future horizon. Current research faces two challenges: 1) the length of the sequences; 2) the accuracy and stability of prediction [10]. Figure 1 shows empirical examples of LSTM and Informer [10]: when the prediction length is long, the MSE and MAE scores of LSTM become suboptimal, and the results of Informer fluctuate noticeably (Fig. 1(b)).
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 457–468, 2023. https://doi.org/10.1007/978-981-99-4761-4_39
Fig. 1. Results of LSTM and Informer. (a) LSTM: the red line represents the inference speed of predictions and the blue line represents the MSE of predictions. (b) Informer: the five points represent five MSE errors with large differences.
For long time series, attention-based approaches (e.g., [8, 10]) achieve acceptable results in most cases. However, they still have two limitations: (1) the quadratic computation of self-attention; (2) insufficient modelling of temporal dependency. In the well-known Transformer [8], the computational overhead comes mainly from canonical self-attention, which costs $O(L^2)$ because a dot product must be computed between every query and every key. Several works derive more efficient self-attention (e.g., [2, 4, 10]) with a one-size-fits-all approach that treats all attention indiscriminately, for example by randomly sampling Q or K, and thereby ignore how each attention point relates to model ability. Although Informer uses a measurement to select the "active" queries, this measurement starts from sampled keys, so the selected "active" queries are themselves limited. Because of this inherent indiscriminate sampling, some high-quality attention is discarded, which prevents further improvement in prediction accuracy and stability. Notably, most of these works address limitation (1) with various tricks to improve prediction performance, while (2) is often ignored. Although some recent works try to improve prediction by exploring temporal dependency (e.g., [3, 7, 9, 10]), they either start from a data perspective or are difficult to apply to attention-based models.
Based on the above motivations, we propose an efficient model, named Detformer, for long-term time-series forecasting. Specifically, by embedding the proposed dual-feedback attention as the inner operator, Detformer accurately identifies reliable attention instead of sampling it, which allows the model to refine the valid attention index positions dynamically during training with $O(L \log L)$ complexity.
Next, we design a temporal dependency extraction mechanism, which enables Detformer to capture long-term dependency and improves the accuracy and stability of prediction by combining weighted short-term dependencies. Inspired by [5], we further propose a lightweight algorithm to eliminate negative attention. To the best of our knowledge, this is the first work to improve the temporal modelling capability of attention-based models by examining the index positions of attention in detail. Our primary contributions are summarized as follows:
1. We propose a novel and efficient dual-feedback attention mechanism, which extracts reliable attention from the perspective of the attention index.
Detformer: Detect the Reliable Attention Index
2. We propose an explicit attention extraction algorithm that finds reliable attention with a local density-based method, thus enhancing accuracy.
3. We propose a lightweight algorithm to efficiently eliminate the accuracy loss caused by "noise attention".
4. Extensive experiments show that Detformer outperforms state-of-the-art approaches on four benchmarks.
2 Methodology
As mentioned above, time-series forecasting faces two prominent problems: the computational-efficiency bottleneck of canonical self-attention and intricate temporal patterns. To tackle these two issues, we propose Detformer, which identifies reliable attention indices and learns temporal dependency in time series. Figure 2 provides an overview of the Detformer time-series forecasting framework, which consists of three components: (1) a dual-feedback attention with O(L · logL) complexity as the basic block; (2) a temporal dependency extraction mechanism that extracts long-term and short-term dependencies for time-series forecasting; (3) a lightweight algorithm based on local attention density that finds negative attention indices.
Fig. 2. An overview of the Detformer architecture.
2.1 Dual-Feedback Attention
Informer reformulated canonical self-attention as ProbSparse self-attention:
$$\mathcal{A}(Q, K, V) = \mathrm{Softmax}\left(\frac{\bar{Q} K^{\top}}{\sqrt{d}}\right) V \qquad (1)$$
where $\bar{Q}$ is a sparse matrix of Q that keeps only the selected queries. Informer achieves good results with this ProbSparse self-attention, but our experiments show that its results fluctuate greatly. For example,
X. Meng et al.
(a) Sampling
(b) Taking all
Fig. 3. The active Q obtained by sampling K and taking all K are different.
on the ETTh2 dataset, when both the input and output lengths are 24, the highest MSE is 0.56 while the lowest is only 0.29. The detailed 1-shot results are shown in Fig. 1(b). This can be explained by the fact that sampling is indiscriminate, and the sampled K may be a lazy K (here, we call K lazy if its softmax score is low). The results depend heavily on the quality of K and Q: although Q is selected by K, K itself comes from sampling, so the selected Q is limited. To identify reliable attention while keeping complexity low, we propose dual-feedback attention with O(L · logL) complexity. We sample K in a different way from [10] and select the top-k active Q with the measurement in Eq. (2), which are then used to recompute self-attention:
$$M(q_i, K) = \ln \sum_{j=1}^{L_K} e^{\frac{q_i k_j^{\top}}{\sqrt{d}}} - \frac{1}{L_K} \sum_{j=1}^{L_K} \frac{q_i k_j^{\top}}{\sqrt{d}} \qquad (2)$$
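A minimal NumPy sketch of Eqs. (1) and (2), assuming single-head attention on 2-D arrays. For clarity it scores every query against all keys (the "taking all K" setting of Fig. 3(b)), so it costs O(L²) rather than the O(L · logL) of a sampled version. The names `sparsity_measurement`, `sparse_attention`, and `top_k`, and the choice to fill lazy-query rows with the mean of V (borrowed from Informer's encoder), are illustrative assumptions, not the authors' code.

```python
import numpy as np


def sparsity_measurement(Q, K):
    """M(q_i, K) of Eq. (2): log-sum-exp of the scaled scores minus their mean."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)                       # (L_Q, L_K)
    peak = scores.max(axis=1, keepdims=True)            # numerically stable log-sum-exp
    lse = peak[:, 0] + np.log(np.exp(scores - peak).sum(axis=1))
    return lse - scores.mean(axis=1)


def sparse_attention(Q, K, V, top_k):
    """Eq. (1) where Q-bar keeps only the top_k queries ranked by M(q_i, K).

    Assumption borrowed from Informer's encoder: rows of non-selected ("lazy")
    queries are filled with the mean of V instead of being computed.
    """
    d = Q.shape[-1]
    active = np.argsort(sparsity_measurement(Q, K))[::-1][:top_k]
    out = np.repeat(V.mean(axis=0, keepdims=True), Q.shape[0], axis=0)
    scores = Q[active] @ K.T / np.sqrt(d)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)       # row-wise softmax
    out[active] = weights @ V
    return out, active


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((96, 64)) for _ in range(3))
out, active = sparse_attention(Q, K, V, top_k=25)
```

Only the selected rows pay for a softmax over the keys; the ranking itself is where a sampling scheme (as in Informer) or the dual-feedback index reuse would cut the cost.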
However, implementing this approach directly loses nearly half of the reliable attention in each layer. Figure 3 shows the index positions of active Q computed by sampling K (Fig. 3(a)) and by taking all K (Fig. 3(b)); the difference between them is obvious. To reduce this loss from the perspective of the individual attention index, we introduce a dual-feedback attention mechanism (see Fig. 4). Specifically, we first compute the relatively reliable Q and then count how frequently each Q index is active across adjacent samples; the number of adjacent samples equals the input batch size, which is set to 32. We record the high-frequency indices and use these reliable indices, instead of K, as the starting point for calculating Q. Our experiments show that the indices of the active Q and the active K are very close within the same training batch at the same layer, and that once parameter fitting stabilizes, the positions of the active attention indices become fixed. Therefore, we can infer the indices of active Q from those of active K. However, because the attention index is highly correlated with the model parameters, the active Q from a single pass is not reliable, which leaves our approach in a suboptimal situation. To overcome this problem, we extract a general pattern from adjacent batches and apply it to the current 1-shot pass. Finally, we obtain reliable active Q with a time complexity of O(L · logL). Formally, we summarize the proposed measurement method in Eq. 3 and
Eq. 4 as follows:
$$F(Q) = \sum_{i=1}^{L_q} q_i y_i \qquad (3)$$
$$\text{Q\_set\_frequency} = \sum_{j=1}^{B_Q} \text{Q\_frequency}_j \qquad (4)$$
where y_i is a 0–1 variable indicating whether q_i is active, and L_q denotes the length of the sequence q. Q_frequency denotes the frequency of Q in a single (1-shot) pass, and B_Q denotes the number of adjacent batches we adopt. By summing Q_frequency over these batches, we obtain a stable set of frequencies, named Q_set_frequency. Finally, we select the top-k active Q from this set.
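Eqs. (3) and (4) amount to counting how often each query position is active across adjacent batches. The sketch below is a hypothetical reading: it treats the product q_i y_i in Eq. (3) as counting positions i with y_i = 1, represents each pass by its list of active indices, and uses three toy passes for readability. The function names are illustrative.

```python
import numpy as np


def q_frequency(active_idx, L_q):
    """Per-pass frequency (Eq. (3)): y_i = 1 if query position i is active, else 0."""
    y = np.zeros(L_q, dtype=int)
    y[np.asarray(active_idx, dtype=int)] = 1
    return y


def q_set_frequency(active_per_batch, L_q):
    """Eq. (4): sum the per-pass frequencies over B_Q adjacent batches."""
    return np.sum([q_frequency(idx, L_q) for idx in active_per_batch], axis=0)


def reliable_active_queries(active_per_batch, L_q, top_k):
    """Pick the top_k positions that were active most often; ties keep the lower index."""
    freq = q_set_frequency(active_per_batch, L_q)
    return np.argsort(-freq, kind="stable")[:top_k], freq


# Hypothetical active indices from B_Q = 3 adjacent batches (sequence length 24).
window = [[3, 7, 11, 20], [3, 7, 12, 20], [3, 8, 11, 20]]
top, freq = reliable_active_queries(window, L_q=24, top_k=3)
print(top.tolist())   # [3, 20, 7]
```

Positions 3 and 20 are active in every pass and survive, while indices that appear only once (8, 12) are filtered out, which is the stabilizing effect the aggregation over B_Q batches is meant to provide.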
(a) Full attention
(b) Dual-feedback attention
Fig. 4. Full attention and dual-feedback attention mechanism.
2.2 Dependency Extraction
We observe that, unlike traditional parameter updates, the past index disappears over iterations rather than being updated on the basis of the past. For instance, parameters are updated iteratively along the direction of $\left(\frac{\partial f}{\partial x_i}, \ldots, \frac{\partial f}{\partial x_j}\right)$, which is often called the gradient method. In this way, parameters are updated gradually, e.g., from 0.315 × 10⁻⁵ to 0.311 × 10⁻⁵. In contrast, updating an active index requires a discrete leap, e.g., from 0 to 1 or from 1 to 0. When facing unknown situations, the active index from a single (1-shot) pass is not enough to maintain stable prediction results. Therefore, for each pass, we want the model to obtain predictions not only from the current reliable index but also from previous experience, which is in line with the basic idea of deep learning; we call this "iterative updating of the attention index". From the empirical study in Fig. 5, we draw the following conclusion: "within one layer, the index of active attention is relatively fixed, whereas the quality of attention from a single pass is volatile". Based on this, we propose a mechanism to extract reliable dependencies from the past. Specifically, for each layer, we allocate a separate
(a) Example of one attention layer
(b) Example of another attention layer
Fig. 5. Index position of active attention in different attention layers during the training phase.
specific memory to store its past index dependencies. When a result is obtained, the mechanism extracts reliable dependencies from the memory and weights them. We define the index $\hat{\theta}_{now}$ used for inference by the weighting scheme in Eq. (5):
$$\hat{\theta}_{now} = (1 - \alpha)\,\theta_{past} + \alpha\,\theta_{now} \qquad (5)$$
where α is a smoothing-coefficient hyperparameter, θ_past denotes the past indices stored in memory, and θ_now denotes the 1-shot index.
2.3 Noise Elimination
Noise Elimination (NE) means "eliminating the negative attention". In this section, we aim to obtain more robust results through NE by using the active attention index as a measure of attention quality. Intuitively, we could compare the defined measurements of every attention index unit. In practice, however, given the feature dimensions of the hidden layer, the index dimension may reach dozens or hundreds, and during training the number of attention indices may reach thousands. Therefore, we design the noise elimination method along two lines: (1) reducing the number of samples; (2) reducing the measurement dimension. First, as observed in Fig. 5, the attention distributions of different samples in the same batch are quite similar, so we can operate on one batch of samples, which reduces the amount of attention to be processed to (total)/(batch size). Second, to reduce the measurement dimension, we develop a measurement method that balances accuracy and feasibility. In particular, we perform feature mapping on the index locations of each attention sample, where the mapping intervals depend on the frequency characteristics of each layer. For example, in the first layer, we can use indices 0–21, 22–42, and 43–95 as x, y, and z, respectively; in the second layer, we change the intervals to 0–16, 17–32, and 33–46. The resulting mapped features serve as the measurement for computing local density, analogous to x, y, and z in 3D clustering; Fig. 6(a) and (b) take two attention layers as examples. A straightforward method is to use clustering algorithms (e.g., K-means or K-medoids) to group attention by similarity. However, these methods ignore a lot of "good
(a) Example attention layer1's index
(b) Example attention layer2's index
Fig. 6. Interval mapping.
(a) Local-density
(b) K-means
Fig. 7. Difference between local-density and K-means.
attention". For example, Fig. 7 shows the difference between local density and K-means: some relatively small clusters are "killed" when K-means is used. Inspired by [5], we design an algorithm based on the local density of data points, which can detect clusters of arbitrary shape and identify negative attention in non-spherical clusters. By setting a density threshold and discarding points in regions whose density is below it, the algorithm achieves more reliable clustering than K-means and K-medoids. Specifically, we first calculate the local density of each sample and select reliable cluster-center candidates with higher local density. The local density is calculated by Eq. (6):
$$\rho_i = \sum_{j} Y(d_{ij} - d_c) \qquad (6)$$
where Y(x) = 1 if x < 0 and Y(x) = 0 otherwise, d_ij is the distance between samples i and j, and d_c is a cutoff distance. Thus, ρ_i equals the number of points closer to point i than d_c. We regard the sample attention with the highest local density as relatively reliable attention candidates. We then calculate the Euclidean distances between these candidates and select the one farthest from the other candidates as the final reliable attention. The distance is calculated according to Eq. (7):
$$\mathrm{Distance}_i = \operatorname{mean}_{j}\,(d_{ij}) \qquad (7)$$
We call lower-density attention "unreliable points". As can be observed, low-density attention lies mainly in the halo of the clusters. These unreliable points are replaced by reliable centers rather than deleted directly, which keeps the total number of samples constant and makes the attention calculation more stable.
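A minimal sketch of the noise elimination step (Eqs. (6) and (7)) under stated assumptions: each attention sample is represented by its set of active index positions; the interval mapping is read as counting how many active indices fall into each interval (our assumption); d_c, n_candidates, and the random data are hypothetical; n_replace = 5 mirrors the number of replaced noise units reported in the experimental setup. Following the text, the candidate with the largest mean distance to the other candidates becomes the reliable center.

```python
import numpy as np


def interval_features(active_idx, bounds):
    """Hypothetical interval mapping: count active indices falling in each interval."""
    idx = np.asarray(active_idx)
    return np.array([np.sum((idx >= lo) & (idx <= hi)) for lo, hi in bounds], dtype=float)


def pairwise_dist(X):
    """Euclidean distance matrix d_ij between the rows of X."""
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


def local_density(D, d_c):
    """Eq. (6): rho_i = number of other points j with d_ij < d_c."""
    close = (D - d_c) < 0                 # Y(d_ij - d_c)
    np.fill_diagonal(close, False)        # a point is not its own neighbour
    return close.sum(axis=1)


def eliminate_noise(X, d_c, n_candidates, n_replace):
    """Replace the n_replace lowest-density samples with one reliable center."""
    D = pairwise_dist(X)
    rho = local_density(D, d_c)
    order = np.argsort(-rho, kind="stable")         # highest density first
    cand = order[:n_candidates]
    # Eq. (7): mean distance to the other candidates; keep the farthest candidate.
    sub = D[np.ix_(cand, cand)]
    mean_d = sub.sum(axis=1) / max(len(cand) - 1, 1)
    center = cand[np.argmax(mean_d)]
    noisy = order[::-1][:n_replace]                 # lowest-density "unreliable points"
    X_clean = X.copy()
    X_clean[noisy] = X[center]                      # replace rather than delete
    return X_clean, rho, center, noisy


rng = np.random.default_rng(0)
bounds = [(0, 21), (22, 42), (43, 95)]              # first-layer intervals from the text
samples = [rng.choice(96, size=25, replace=False) for _ in range(32)]
X = np.stack([interval_features(s, bounds) for s in samples])
X_clean, rho, center, noisy = eliminate_noise(X, d_c=3.0, n_candidates=4, n_replace=5)
```

Because the unreliable rows are overwritten rather than dropped, X_clean keeps the batch size, which matches the stated goal of keeping the total number of samples constant.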
3 Experiments
We conduct a comprehensive empirical study on four benchmarks to evaluate the efficiency of the proposed Detformer approach.
3.1 Experimental Setup
Datasets. Following [10], we validate our framework on four benchmark datasets: ETTh1, ETTh2, ETTm1, and ETTm2. All are collected from electricity transformers; the ETTm datasets are recorded every 15 min and the ETTh datasets hourly. As in [10], the train/validation/test split is 12/4/4 months.
Hyper-Parameter Tuning. Detformer consists of a 2-layer encoder and a 1-layer decoder. We use the Adam optimizer with a learning rate starting from 1e−4 and decaying by a factor of 10 every 2 epochs; training runs for 10 epochs in total with early stopping. The weights of the long-term and short-term temporal dependencies are 0.7 and 0.3, respectively. The number of noise attention units replaced with reliable attention units is 5. We use the same position embedding as Informer and consider 3 random training/validation setups; the final result is the average of the three runs. All experiments are run on a server with a GeForce RTX 3090 24GB GPU.
Metrics. We use two evaluation metrics, $\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2$ and $\mathrm{MAE} = \frac{1}{n}\sum_{i=1}^{n}|y_i - \hat{y}_i|$, on each prediction window (averaged over variables for multivariate prediction), rolling over the whole set with stride = 1 [10].
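The two metrics and the stride-1 rolling evaluation can be sketched as follows. This assumes NumPy arrays, and `rolling_windows` is an illustrative helper rather than the evaluation code of [10]; averaging over all elements also averages across variables in the multivariate case.

```python
import numpy as np


def mse(y, y_hat):
    """Mean squared error over all elements (averages across variables too)."""
    return float(np.mean((np.asarray(y, dtype=float) - np.asarray(y_hat, dtype=float)) ** 2))


def mae(y, y_hat):
    """Mean absolute error over all elements."""
    return float(np.mean(np.abs(np.asarray(y, dtype=float) - np.asarray(y_hat, dtype=float))))


def rolling_windows(series, input_len, pred_len, stride=1):
    """Yield (input, target) pairs by sliding over the whole series with the given stride."""
    last_start = len(series) - input_len - pred_len
    for start in range(0, last_start + 1, stride):
        split = start + input_len
        yield series[start:split], series[split:split + pred_len]


print(mse([1, 2, 3], [1, 2, 5]), mae([1, 2, 3], [1, 2, 5]))
print(sum(1 for _ in rolling_windows(np.arange(10), input_len=4, pred_len=2)))  # 5
```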
3.2 Quantitative Analysis
We compare Detformer with state-of-the-art sparse-attention-based, self-attention-based, and LSTM-based methods, namely Informer [10], EDGruAtt [1], and EDLstm [6]. The major difference between EDGruAtt and EDLstm is that EDGruAtt adds a self-attention mechanism to the RNN model, while Informer sparsifies the self-attention mechanism to improve efficiency. These three methods represent the latest results in three directions and are used to verify the validity of the proposed Detformer. We fix the prediction length and evaluate the models over a wide range of input lengths: 24, 48, 168, 336, and 720. We report results in both the multivariate and univariate settings.
Eval-I: Multivariate/Univariate Time-series Forecasting. In this setting, Detformer achieves consistent state-of-the-art performance in most benchmarks and prediction length settings (see Table 1). The best results are highlighted in boldface, and "–" indicates failure due to out-of-memory. In particular, compared with the state-of-the-art sparse-attention-based method Informer, Detformer achieves an 8.0% (0.974 → 0.896) MSE reduction and a 5.2% (0.749 → 0.710) MAE reduction on ETTh1 (at 336), as well as a 20.3%
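Table 1. Multivariate long sequence time-series forecasting results.

| Methods | Metric | ETTh1 24 | ETTh1 48 | ETTh1 168 | ETTh1 336 | ETTh1 720 | ETTh2 24 | ETTh2 48 | ETTh2 168 | ETTh2 336 | ETTh2 720 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Detformer | MSE | 0.520 | 0.503 | 0.568 | 0.896 | 1.244 | 0.382 | 0.524 | 3.795 | 3.449 | 4.704 |
| Detformer | MAE | 0.523 | 0.509 | 0.550 | 0.710 | 0.852 | 0.471 | 0.562 | 1.721 | 1.617 | 1.897 |
| Informer | MSE | 0.547 | 0.554 | 0.557 | 0.974 | 1.247 | 0.356 | 0.531 | 4.319 | 3.869 | 4.888 |
| Informer | MAE | 0.535 | 0.542 | 0.553 | 0.749 | 0.847 | 0.450 | 0.571 | 1.768 | 1.704 | 1.908 |
| EDLstm | MSE | 1.223 | 1.690 | 2.647 | 2.637 | 2.938 | 1.147 | 1.756 | 3.030 | 2.043 | 4.070 |
| EDLstm | MAE | 0.862 | 1.039 | 1.294 | 1.238 | 1.330 | 0.850 | 1.042 | 1.356 | 1.092 | 1.581 |
| EDGruAtt | MSE | 0.787 | 0.813 | 0.815 | – | – | 0.693 | 0.829 | 2.171 | – | – |
| EDGruAtt | MAE | 0.644 | 0.668 | 1.188 | – | – | 0.644 | 0.668 | 1.188 | – | – |

| Methods | Metric | ETTm1 24 | ETTm1 48 | ETTm1 168 | ETTm1 336 | ETTm1 720 | ETTm2 24 | ETTm2 48 | ETTm2 168 | ETTm2 336 | ETTm2 720 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Detformer | MSE | 0.433 | 0.458 | 0.333 | 0.365 | 0.354 | 0.176 | 0.169 | 0.194 | 0.407 | 0.810 |
| Detformer | MAE | 0.429 | 0.438 | 0.382 | 0.407 | 0.397 | 0.306 | 0.296 | 0.323 | 0.505 | 0.746 |
| Informer | MSE | 0.473 | 0.434 | 0.338 | 0.345 | 0.368 | 0.183 | 0.181 | 0.197 | 0.398 | 1.017 |
| Informer | MAE | 0.459 | 0.429 | 0.392 | 0.383 | 0.405 | 0.314 | 0.313 | 0.321 | 0.496 | 0.821 |
| EDLstm | MSE | 0.366 | 0.455 | 0.483 | 1.354 | 1.520 | 0.986 | 1.431 | 0.529 | 1.342 | 1.633 |
| EDLstm | MAE | 0.537 | 0.580 | 0.585 | 1.112 | 1.271 | 0.723 | 0.881 | 0.638 | 0.953 | 1.067 |
| EDGruAtt | MSE | 0.693 | 0.829 | 2.171 | – | – | 0.645 | 0.782 | 0.661 | – | – |
| EDGruAtt | MAE | 0.644 | 0.668 | 1.188 | – | – | 0.547 | 0.583 | 0.553 | – | – |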
(1.017 → 0.810) MSE and 8.9% (0.821 → 0.748) MAE reduction on ETTm2 (at 720). Compared with the LSTM-based method EDLstm, Detformer shows significantly better results: on ETTh1, it reduces MSE by 78.5% (at 168), 66.0% (at 336), and 57.7% (at 720). Although EDLstm leads on ETTh2 in several cases, we attribute this to a specific case in which LSTM is particularly effective, which merits further study. Detformer also outperforms the self-attention-based method EDGruAtt, reducing MSE on ETTh1 by 38.1% (at 48), 30.3% (at 168), and 3.5% (at 336). In addition, Detformer shows long-term robustness and low computational cost. Table 2 lists the univariate results on the four datasets; compared with Informer, Detformer still achieves state-of-the-art performance in most settings.
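The percentages above follow the usual relative-reduction formula, (baseline − Detformer) / baseline × 100. A small Python check using MSE/MAE values from Table 1 (the dictionary labels are our own):

```python
def relative_reduction(baseline, ours):
    """Error reduction of `ours` relative to `baseline`, in percent."""
    return 100.0 * (baseline - ours) / baseline


# (baseline, Detformer) pairs taken from Table 1 as quoted in the text.
pairs = {
    "ETTh1-336 MSE vs Informer": (0.974, 0.896),
    "ETTh1-336 MAE vs Informer": (0.749, 0.710),
    "ETTm2-720 MSE vs Informer": (1.017, 0.810),
    "ETTh1-168 MSE vs EDLstm": (2.647, 0.568),
    "ETTh1-48 MSE vs EDGruAtt": (0.813, 0.503),
}
for name, (base, ours) in pairs.items():
    print(f"{name}: {relative_reduction(base, ours):.2f}%")
```

The quoted figures match this formula up to rounding; for example, the ETTm2 (720) MSE pair evaluates to about 20.35%.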
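Table 2. Univariate long sequence time-series forecasting results.

| Methods | Metric | ETTh1 24 | ETTh1 48 | ETTh1 168 | ETTh1 336 | ETTh1 720 | ETTh2 24 | ETTh2 48 | ETTh2 168 | ETTh2 336 | ETTh2 720 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Detformer | MSE | 0.081 | 0.111 | 0.132 | 0.138 | 0.110 | 0.078 | 0.078 | 0.107 | 0.198 | 0.244 |
| Detformer | MAE | 0.225 | 0.273 | 0.300 | 0.311 | 0.250 | 0.213 | 0.214 | 0.256 | 0.355 | 0.396 |
| Informer | MSE | 0.100 | 0.128 | 0.180 | 0.132 | 0.106 | 0.083 | 0.083 | 0.119 | 0.192 | 0.285 |
| Informer | MAE | 0.255 | 0.293 | 0.359 | 0.301 | 0.259 | 0.220 | 0.220 | 0.272 | 0.350 | 0.428 |

| Methods | Metric | ETTm1 24 | ETTm1 48 | ETTm1 168 | ETTm1 336 | ETTm1 720 | ETTm2 24 | ETTm2 48 | ETTm2 168 | ETTm2 336 | ETTm2 720 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| Detformer | MSE | 0.018 | 0.019 | 0.038 | 0.059 | 0.069 | 0.026 | 0.024 | 0.039 | 0.031 | 0.035 |
| Detformer | MAE | 0.105 | 0.104 | 0.159 | 0.202 | 0.220 | 0.110 | 0.109 | 0.149 | 0.129 | 0.160 |
| Informer | MSE | 0.020 | 0.023 | 0.042 | 0.041 | 0.079 | 0.029 | 0.025 | 0.026 | 0.035 | 0.024 |
| Informer | MAE | 0.106 | 0.115 | 0.162 | 0.154 | 0.233 | 0.116 | 0.114 | 0.127 | 0.140 | 0.110 |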
3.3 Ablation Study
We conduct an ablation study on ETTh1; the results are shown in Table 3, where positive values represent increased error. Without-DF denotes Detformer without Dual-Feedback Attention, and Without-TD denotes Detformer without Temporal Dependency and Noise Elimination.
Eval-II: The Performance of Dual-Feedback Attention. In this study, we use Detformer as the baseline and remove dual-feedback attention to isolate its effect. The other experimental settings are aligned with those of Eval-I. As Table 3 shows, Detformer performs worse in all cases after dual-feedback attention is removed, which indicates the effectiveness of this technique.
Eval-III: The Performance of Temporal Dependency and Noise Elimination. In this study, we test the effects of Temporal Dependency and Noise Elimination together. Table 3 shows that the prediction performance of Detformer improves when the Temporal Dependency and Noise Elimination techniques are used.
Table 3. Ablation study (change in error relative to Detformer).
Test groups: Without-DF, Without-TD. Input lengths: 24, 48, 96, 144, 168, 336, 720.
Values: 0.042, 0.017, 0.049, 0.032, 0.073, 0.075, 0.034, 0.031, 0.003, 0.035, 0.018, 0.048, 0.064, 0.042, 0.060, 0.034, 0.017, 0.031, 0.027, 0.023, 0.020, 0.036, 0.024, 0.015, 0.007, 0.011, 0.015, 0.026
3.4 Hyper-Parameter Sensitivity Analysis
Eval-IV: Impact of α. Here, α denotes the bandwidth of Detformer's information. In practice, the best α is 5, and we also compare α values in {3, 7, 8, 10}. In Table 4, a positive value indicates an increase in error and a negative value a decrease. With α = 8 and α = 10, the error decreases in some cases, but overall α = 5 performs best.
Eval-V: Impact of Sampling Factor c. The value of c determines the knowledge that Detformer learns from the past. As shown in Fig. 8, we vary c from 0.1 to 0.9 and set c = 0.7 in practice. Notably, even with c = 0, which means that the run does not involve attention, we still obtain acceptable results.
Table 4. Hyper-parameter sensitivity analysis with α (change in error relative to α = 5).
Test groups: α = 3, 7, 8, 10. Input lengths: 24, 48, 96, 144, 168, 336, 720.
α = 3 and α = 7: 0.040, 0.007, 0.003, 0.024, 0.042, 0.071, 0.055, 0.026, 0.004, 0.007, 0.011, 0.024, 0.044, 0.043, 0.049, 0.030, 0.067, 0.014, 0.062, 0.034, 0.017, 0.031, 0.018, 0.042, 0.006, 0.036, 0.017, 0.020
α = 8: 0.042, −0.006, 0.040, 0.003, 0.071, 0.015, 0.117, 0.023, −0.001, 0.026, −0.010, 0.044, −0.003, 0.072
α = 10: 0.035, 0.024, 0.067, −0.002, 0.055, 0.046, 0.034, 0.022, 0.015, 0.044, −0.010, 0.022, 0.035, 0.040
Fig. 8. Hyper-parameter sensitivity analysis with c.
4 Conclusion
We propose Detformer, a novel model that detects the attention able to provide reliable prediction results while maintaining O(L · logL) time complexity. Furthermore, we propose an effective approach to extract temporal dependencies based on the attention
index and a lightweight method that measures attention quality using local density to improve prediction robustness. Extensive experiments on four datasets demonstrate that our approach outperforms other state-of-the-art methods.
Acknowledgements. This research was sponsored by the National Natural Science Foundation of China (62272126) and the Fundamental Research Funds for the Central Universities (3072022TS0605).
References
1. Bahdanau, D., Cho, K., Bengio, Y.: Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473 (2014)
2. Cirstea, R.G., Guo, C., Yang, B., Kieu, T., Dong, X., Pan, S.: Triformer: triangular, variable-specific attentions for long sequence multivariate time series forecasting full version. arXiv preprint arXiv:2204.13767 (2022)
3. Du, Y., et al.: AdaRNN: adaptive learning and forecasting of time series. In: Proceedings of the 30th ACM International Conference on Information and Knowledge Management, pp. 402–411 (2021)
4. Kitaev, N., Kaiser, Ł., Levskaya, A.: Reformer: the efficient transformer. arXiv preprint arXiv:2001.04451 (2020)
5. Rodriguez, A., Laio, A.: Clustering by fast search and find of density peaks. Science 344(6191), 1492–1496 (2014)
6. Sutskever, I., Vinyals, O., Le, Q.V.: Sequence to sequence learning with neural networks. In: Advances in Neural Information Processing Systems 27 (2014)
7. Tzeng, E., Hoffman, J., Saenko, K., Darrell, T.: Adversarial discriminative domain adaptation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7167–7176 (2017)
8. Vaswani, A., et al.: Attention is all you need. In: Advances in Neural Information Processing Systems 30 (2017)
9. Wang, J., Chen, Y., Feng, W., Yu, H., Huang, M., Yang, Q.: Transfer learning with dynamic distribution adaptation. ACM Trans. Intell. Syst. Technol. (TIST) 11(1), 1–25 (2020)
10. Zhou, H., et al.: Informer: beyond efficient transformer for long sequence time-series forecasting. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, pp. 11106–11115 (2021)
An Ultra-short-Term Wind Speed Prediction Method Based on Spatio-Temporal Feature Decomposition and Multi Feature Fusion Network
Xuewei Li1,2,3,4, Guanrong He1,2,3, Jian Yu1,2,3, Zhiqiang Liu1,2,3, Mei Yu1,2,3,4, Weiping Ding5, and Wei Xiong6(B)
1 College of Intelligence and Computing, Tianjin University, Tianjin 300350, China
2 Tianjin Key Laboratory of Cognitive Computing and Application, Tianjin 300350, China
3 Tianjin Key Laboratory of Advanced Networking, Tianjin 300350, China
4 Tianjin International Engineering Institute, Tianjin University, Tianjin 300350, China
5 School of Information Science and Technology, Nantong University, Nantong 226019, China
6 TCU School of Civil Engineering, Tianjin Chengjian University, Tianjin 300384, China
Abstract. Wind energy plays an important role in alleviating global warming. To improve the efficiency of wind energy utilization, we propose an ultra-short-term wind speed forecasting method based on spatio-temporal feature decomposition and a multi-feature fusion network. The method has two stages: data preprocessing and prediction. The preprocessing stage constructs and decomposes the spatio-temporal features, which reduces the impact of wind speed fluctuations on the model while preserving the integrity of the spatio-temporal features. In the prediction stage, we propose a parallel prediction network consisting of multiple multi-feature fusion networks (MFFNets). Because spatio-temporal features have high semantic and information density, MFFNet integrates shallow, intermediate, and deep features to combine local detail with global information, which reduces the impact of local wind speed fluctuations on prediction accuracy. The proposed method was validated on a wind farm in the Midwestern United States. Compared with current advanced methods, MFFNet achieved improvements of more than 6.8% in all metrics. The results demonstrate that the proposed method has promising applications in large-scale wind farm forecasting.
Keywords: Wind Speed Prediction · Spatio-temporal feature decomposition · Convolutional Neural Network · Variational Mode Decomposition
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 469–481, 2023. https://doi.org/10.1007/978-981-99-4761-4_40
1 Introduction
With the increasingly serious issues of global climate change and environmental pollution, clean energy, including wind energy, has developed rapidly and extensively worldwide. Despite the vast potential for wind power generation, the intermittency and
X. Li et al.
variability of wind power can cause grid instability when it is integrated into the power grid. Wind speed forecasting is one of the mainstream solutions to wind power grid-integration problems: predicting future wind speed changes provides preparation time for grid regulation and other operations, increasing wind power utilization [1]. Currently, wind speed forecasting methods fall into three categories: physical methods, traditional statistical methods, and machine learning methods [2]. Physical methods use meteorological information such as humidity, temperature, wind direction, and atmospheric pressure, as well as geographical information, as model inputs to predict future wind speed [3]. However, numerical weather forecasting is difficult to model because it requires a large amount of complex computation, and it is not suitable for short-term wind speed forecasting because meteorological data are difficult to generate within a short period [4]. Traditional statistical methods determine a function relating meteorological factors to wind speed from historical data and then use this function to predict wind speed [5]. However, statistical models have difficulty capturing the nonlinear characteristics of wind speed, which limits further improvement of their prediction accuracy [6]. Machine learning methods have been widely applied to wind speed prediction because they can capture complex nonlinear features in wind speed data. For example, Yu et al. [7] proposed an improved Long Short-Term Memory with Enhanced Forgetting Gate (LSTM-EFG) for wind power prediction, which enhances the effectiveness of the forgetting gate and speeds up model convergence. Compared with traditional statistical methods, machine learning methods better extract the complex nonlinear features in wind speed sequences, but wind fluctuations still degrade their prediction accuracy.
In recent years, data preprocessing based on time series decomposition has become a research focus for reducing the impact of wind fluctuations on prediction accuracy. In wind speed prediction, commonly used time series decomposition methods include the wavelet transform (WT), empirical mode decomposition (EMD), ensemble empirical mode decomposition (EEMD), complementary ensemble empirical mode decomposition (CEEMD), and variational mode decomposition (VMD). Although WT has good time-frequency localization, it is non-adaptive and lacks criteria for selecting the wavelet basis. Compared with WT, EMD adaptively decomposes a time series into multiple intrinsic mode functions (IMFs) with different center frequencies and a residue, and achieves better decomposition. For example, Shang et al. [8] proposed a wind speed prediction method that combines EEMD, CNN, and an attention mechanism. However, EMD and its derivatives EEMD and CEEMD suffer from mode mixing, which reduces prediction accuracy [9]. In contrast to WT, EMD, and their derivatives, VMD determines the center frequency and bandwidth of each component by iteratively searching for the optimal solution of a variational model, adaptively decomposing the time series while mitigating the impact of mode mixing on predictions. Therefore, VMD is widely used in wind speed prediction. For example, Zhang et al. [10] proposed an adaptive wind speed prediction model based on VMD, the fruit fly optimization algorithm, ARIMA, and a deep belief network.
Incorporating spatial features among turbines into wind speed prediction is a promising research direction. The wind speeds of upstream and downstream turbines along the wind direction are correlated, so the wind speed sequences of neighboring turbines can be used to enhance the wind speed features of the target turbine. For example, Jiang et al. [11] computed the Pearson correlation coefficient between the target wind turbine and its neighbors and selected the highly correlated turbines to assist prediction. However, in large-scale wind farms, single-turbine spatio-temporal feature methods must model each turbine separately, which incurs significant time overhead and makes them unsuitable for short-term wind speed prediction. To improve prediction efficiency for large-scale wind farms, many researchers have begun to predict the wind farm as a whole. Khodayar et al. [12] proposed a graph-based spatio-temporal feature prediction method. Liu et al. [13] proposed a spatio-temporal wind speed prediction method based on probability and spatio-temporal features. However, current spatio-temporal feature prediction methods have two problems. First, existing methods do not consider the impact of the nonlinearity and nonstationarity of wind speed on prediction accuracy. Second, current end-to-end spatio-temporal feature prediction methods do not consider the impact of local wind speed changes on prediction accuracy. To address these issues, this paper proposes an ultra-short-term wind speed prediction method based on spatio-temporal feature decomposition and multi-feature fusion.
The main contributions and innovations of this research are as follows:
(1) A spatio-temporal feature decomposition method based on time series decomposition is proposed, which reduces the nonlinearity and nonstationarity of the data while preserving the integrity of the spatio-temporal features.
(2) A network structure called Multi-Feature Fusion Network (MFFNet) is proposed, which integrates global and local features to strengthen the model's attention to local wind speed mutations.
(3) A parallel prediction network consisting of multiple MFFNets is constructed, allowing the network to better focus on the inherent differences between sub-spatio-temporal features.
2 Materials and Methods
2.1 Method Process
This section describes in detail the entire process of the ultra-short-term wind speed prediction method based on spatio-temporal feature decomposition and MFFNet. In the data preprocessing stage, a spatio-temporal feature decomposition method based on VMD is proposed. In the prediction stage, a parallel prediction network composed of multiple MFFNets is constructed. The parallel network has two advantages. First, the parallel structure allows each sub-network to focus on the intrinsic features of a specific sub-spatio-temporal feature, avoiding the difficulty a single network faces in learning the differences among different sub-spatio-temporal features. Second, all sub-networks can be trained simultaneously, improving the fitting ability without increasing training time. The overall wind speed prediction process is shown in Fig. 1.
Fig. 1. Flow chart of the hybrid prediction method for ultra-short-term wind speed.
2.2 Spatio-Temporal Features Decomposition Method
The construction of spatio-temporal features is crucial for the accuracy of the prediction model. This paper uses matrix mapping to construct them. First, the longitudes and latitudes of all wind turbines in the wind farm are recorded, and duplicate values are removed. Let the numbers of distinct longitudes and latitudes be m and n, respectively. A two-dimensional matrix of size m × n is then constructed, and each wind turbine is mapped into the matrix according to its real geographical coordinates. Matrix elements without a corresponding wind turbine are set to 0. The resulting two-dimensional matrix is shown in Fig. 2(a), and the generated wind speed spatio-temporal feature map is shown in Fig. 2(b).
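A minimal NumPy sketch of this matrix mapping. It assumes that rows index the m distinct longitudes and columns the n distinct latitudes (the axis order is our reading of "m × n"), that turbines sharing a longitude or latitude have identically recorded coordinates, and that no two turbines share both; the coordinates and speeds are hypothetical.

```python
import numpy as np


def build_spatial_index(lons, lats):
    """Map each turbine to a cell of an m x n grid of distinct longitudes x latitudes."""
    uniq_lon = np.unique(lons)              # sorted, duplicates removed -> m values
    uniq_lat = np.unique(lats)              # -> n values
    rows = np.searchsorted(uniq_lon, lons)  # exact matches, so this is the value's position
    cols = np.searchsorted(uniq_lat, lats)
    return rows, cols, (len(uniq_lon), len(uniq_lat))


def speed_frame(speeds, rows, cols, shape):
    """One spatio-temporal frame: wind speed at each turbine's cell, 0 where there is none."""
    frame = np.zeros(shape)
    frame[rows, cols] = speeds
    return frame


# Hypothetical coordinates (degrees) and wind speeds (m/s) for four turbines.
lons = np.array([-93.1, -93.1, -93.0, -92.9])
lats = np.array([41.5, 41.6, 41.5, 41.7])
rows, cols, shape = build_spatial_index(lons, lats)
frame = speed_frame(np.array([5.2, 6.1, 4.8, 7.0]), rows, cols, shape)
```

Stacking one such frame per time step yields the spatio-temporal feature sequence; the zero cells correspond to the white pixels in Fig. 2(a).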
(a) Spatial matrix
(b) Visualization image
Fig. 2. (a) Geographic coordinate mapping matrix. Blue pixel points represent the presence of wind turbines, and white pixel points represent the absence of wind turbines. (b) The generated wind speed spatio-temporal features schematic. It is enlarged from the matrix to show more details, with warmer colors representing higher wind speed values.
In the wind speed prediction literature, time series decomposition methods are often used to reduce the nonlinearity and nonstationarity of wind speed. However, spatio-temporal feature decomposition has so far received little research attention. Its key challenge is to preserve the spatio-temporal features of the wind farm while reducing the nonlinearity and nonstationarity of the wind speed sequences. For the original wind speed sequence $S_i$ (i = 1, ..., n), where n is the number
An Ultra-short-Term Wind Speed Prediction Method
of wind turbines, a time series decomposition method decomposes it into m sub-sequences $\mathrm{IMF}_{ij}$ (j = 1, ..., m) and a residual sequence. The sub-sequences of each wind turbine are sorted by increasing central frequency. The j-th sub-sequence of every wind turbine is then mapped into the matrix according to the turbine's true geographical coordinates, forming $\mathrm{STF}_j$ (j = 1, ..., m + 1, with the residual as the last component); this preserves the complete spatial information during feature decomposition. In the wind speed prediction literature, the most commonly used sequence decomposition methods are EMD and its variants. However, EMD determines the number of sub-sequences adaptively from the characteristics of each time series, so different turbines can yield different numbers of sub-sequences. VMD, in contrast, fixes the number of sub-sequences in advance, which guarantees that every $\mathrm{STF}_j$ is complete, i.e., contains a value for every turbine. The spatio-temporal feature decomposition process based on VMD is shown in Fig. 3.
Fig. 3. The spatio-temporal feature decomposition process based on VMD.
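The assembly of the $\mathrm{STF}_j$ maps can be sketched as follows. This is a hedged illustration: `band_split` is a hypothetical stand-in that simply divides each turbine's spectrum into K contiguous frequency bands so the example runs with NumPy alone; it is not VMD. In practice a VMD implementation (for example the third-party `vmdpy` package) would be passed as `decompose`, under the same contract of returning K modes ordered by increasing centre frequency plus a residual. The grid indices `rows`, `cols`, and `shape` are assumed to come from the coordinate mapping step and are hard-coded here.

```python
import numpy as np

def band_split(x, K):
    """Stand-in decomposition (NOT VMD): split the spectrum into K bands.

    Returns K modes ordered from low to high frequency and a residual,
    so that modes.sum(axis=0) + residual == x.
    """
    X = np.fft.rfft(x)
    edges = np.linspace(0, len(X), K + 1).astype(int)
    modes = []
    for lo, hi in zip(edges[:-1], edges[1:]):
        Xk = np.zeros_like(X)
        Xk[lo:hi] = X[lo:hi]                 # keep only this frequency band
        modes.append(np.fft.irfft(Xk, n=len(x)))
    modes = np.array(modes)
    residual = x - modes.sum(axis=0)         # numerically ~0 for this stand-in
    return modes, residual

def decompose_to_stf(S, rows, cols, shape, K, decompose=band_split):
    """Build the K + 1 spatio-temporal feature sequences STF_1 .. STF_{K+1}.

    S: (n_turbines, T) wind speed series, one row per turbine.
    Returns an array of shape (K + 1, T, m, n); cells without a turbine stay 0.
    """
    n_turbines, T = S.shape
    stf = np.zeros((K + 1, T) + tuple(shape))
    for i in range(n_turbines):
        modes, residual = decompose(S[i], K)
        components = np.vstack([modes, residual[None, :]])  # (K + 1, T)
        stf[:, :, rows[i], cols[i]] = components
    return stf

rng = np.random.default_rng(0)
S = 8.0 + rng.normal(size=(4, 64))                 # 4 turbines, 64 time steps
rows, cols, shape = np.array([2, 2, 1, 0]), np.array([0, 1, 0, 1]), (3, 2)
stf = decompose_to_stf(S, rows, cols, shape, K=4)
print(stf.shape)  # (5, 64, 3, 2)
```

Because every turbine is decomposed into the same number of components, each of the K + 1 stacks is a complete map sequence; this is the property that motivates the fixed-K decomposition in the text.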
2.3 MFFNet Model
The pixels in the spatio-temporal feature map correspond one-to-one to the wind turbines, so the map has high information density and semantic density. In contrast, natural images are natural signals with high spatial redundancy: even when some regions of an image are randomly masked, deep features can still reconstruct an image similar to the original [14]. Therefore, prediction networks designed for natural images are not directly suitable for wind speed prediction based on spatio-temporal features. When MFFNet extracts features with residual networks, the receptive field of the feature map grows with the number of convolutions, so shallow, middle, and deep features capture wind speed trends over different ranges. As shown in Fig. 4, wind speed can change abruptly in local regions, and these local trends strongly affect the model's prediction accuracy. Reconstructing the prediction from deep features alone would therefore discard much local information and reduce accuracy. To address this issue, MFFNet concatenates shallow, middle, and deep features along the channel dimension. During this multi-feature fusion, adjacent wind turbines may have different wind speed trends because of wind speed fluctuations; MFFNet therefore uses two 1 × 1 convolutions to aggregate the shallow, middle, and deep features of each wind turbine individually, which avoids potential negative influence from neighboring turbines while still producing a prediction for the whole wind farm. The main structure of the MFFNet model is illustrated in Fig. 5.
Fig. 4. Illustration of Wind Speed Variation in Wind Farm at Different Time Points.
Fig. 5. Main structure of the MFFNet network.
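The fusion step can be illustrated with a small NumPy sketch. It assumes (the text does not say) that the shallow, middle, and deep feature maps share the spatial size of the turbine matrix and that a ReLU separates the two 1 × 1 convolutions; the channel counts and random weights are arbitrary. Because a 1 × 1 convolution is a per-pixel linear map across channels, the output at one turbine depends only on that turbine's own features, which is the property the text relies on.

```python
import numpy as np

def conv1x1(x, weight, bias):
    """1 x 1 convolution. x: (C_in, H, W); weight: (C_out, C_in); bias: (C_out,)."""
    return np.einsum("oc,chw->ohw", weight, x) + bias[:, None, None]

def fuse_features(shallow, middle, deep, w1, b1, w2, b2):
    """Concatenate features along channels, then aggregate each pixel with two 1 x 1 convs."""
    x = np.concatenate([shallow, middle, deep], axis=0)
    hidden = np.maximum(conv1x1(x, w1, b1), 0.0)  # ReLU between the convs (assumed)
    return conv1x1(hidden, w2, b2)                # one channel: wind speed map

rng = np.random.default_rng(0)
H, W = 30, 20                                     # turbine matrix size from Sect. 3.1
shallow = rng.normal(size=(8, H, W))
middle = rng.normal(size=(16, H, W))
deep = rng.normal(size=(32, H, W))
w1, b1 = rng.normal(size=(32, 56)), np.zeros(32)  # 56 = 8 + 16 + 32 input channels
w2, b2 = rng.normal(size=(1, 32)), np.zeros(1)
out = fuse_features(shallow, middle, deep, w1, b1, w2, b2)
print(out.shape)  # (1, 30, 20)

# Perturb the features of the turbine at (0, 0) only.
deep_perturbed = deep.copy()
deep_perturbed[:, 0, 0] += 10.0
out_perturbed = fuse_features(shallow, middle, deep_perturbed, w1, b1, w2, b2)
changed = ~np.isclose(out_perturbed, out)
print(changed.sum() <= 1)  # True: at most the (0, 0) output can change
```

A 3 × 3 convolution in place of the 1 × 1 layers would let the perturbation leak into the eight neighbouring cells, which is the cross-turbine influence the design avoids.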
3 Case Studies
3.1 Dataset
We validate the proposed method on wind speed data generated by the U.S. National Renewable Energy Laboratory with the Weather Research and Forecasting model [15]. The selected wind farm is located in the Midwestern United States, between longitudes 105.00°W and 105.34°W and latitudes 41.40°N and 41.90°N (17.3 km × 38.7 km), and contains 592 wind turbines, as shown in Fig. 6. As shown in Fig. 2, the wind turbines in this range are mapped by their geographical coordinates to a 30 × 20 matrix; the 8 positions without a corresponding wind turbine (30 × 20 − 592 = 8) are filled with zeros. The 52,704 wind speed samples from January to December 2004 are used for training and the 52,560 samples from January to December 2005 for testing, with wind speed sampled at 10-min intervals.
Fig. 6. Location of the dense wind turbine area on the map.
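The dataset sizes and forecast horizons above follow from the 10-min sampling: 2004 is a leap year, so 366 × 144 = 52,704 samples, and 2005 gives 365 × 144 = 52,560; a horizon of 10, 20, or 30 min is 1, 2, or 3 steps ahead. The sketch below builds supervised (input window, target) pairs from a sequence of wind speed maps. The input window length and the function name `make_supervised_pairs` are assumptions for illustration; the text does not specify them.

```python
import numpy as np

STEP_MIN = 10                            # sampling interval in minutes
SAMPLES_PER_DAY = 24 * 60 // STEP_MIN    # 144

def make_supervised_pairs(maps, window, horizon_min):
    """maps: (T, H, W) sequence of wind speed maps.

    Returns inputs of shape (N, window, H, W) and targets of shape (N, H, W),
    where each target lies horizon_min minutes after the last input frame.
    """
    h = horizon_min // STEP_MIN          # 10/20/30 min -> 1/2/3 steps ahead
    n = maps.shape[0] - window - h + 1
    idx = np.arange(n)[:, None] + np.arange(window)[None, :]
    inputs = maps[idx]
    targets = maps[np.arange(n) + window - 1 + h]
    return inputs, targets

print(366 * SAMPLES_PER_DAY, 365 * SAMPLES_PER_DAY)  # 52704 52560

maps = np.arange(20 * 2 * 3, dtype=float).reshape(20, 2, 3)  # toy sequence
X, y = make_supervised_pairs(maps, window=4, horizon_min=30)
print(X.shape, y.shape)  # (14, 4, 2, 3) (14, 2, 3)
```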
3.2 Evaluation Metrics
The Mean Absolute Error (MAE), Mean Square Error (MSE), and Symmetric Mean Absolute Percentage Error (SMAPE) are used to evaluate the wind speed prediction models. They are calculated as follows:
$$\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2 \tag{1}$$
$$\mathrm{MAE} = \frac{1}{n}\sum_{i=1}^{n}\left|y_i - \hat{y}_i\right| \tag{2}$$
$$\mathrm{SMAPE} = \frac{1}{n}\sum_{i=1}^{n}\frac{\left|y_i - \hat{y}_i\right|}{\left(y_i + \hat{y}_i\right)/2} \tag{3}$$
where $y_i$ and $\hat{y}_i$ are the actual and predicted values, respectively, and $n$ is the total number of predicted samples.
3.3 Results and Analysis of Ablation Experiments
To validate the design of MFFNet, two variants, MFFNet1 and MFFNet2, were evaluated. MFFNet1 uses only deep features, while MFFNet2 fuses middle and deep features. Table 1 shows the prediction results. Compared with MFFNet1, MFFNet2 improves every metric by more than 24%, and compared with MFFNet2, MFFNet improves every metric by more than 22%. This indicates that introducing middle and shallow features compensates for the local information lost during feature extraction and effectively enhances the expressive capability of the network's features.
3.4 Results and Analysis of Comparative Experiments
To verify the performance of the proposed network, MFFNet was compared with the benchmark Persistence Method (PM) and six advanced methods: CNNLSTM [16], Transformer+ [17], ImResNet [18], WSFNet [19], HRNet [20], and UNet [21]. HRNet and UNet are deep learning methods designed for natural images and use
Table 1. The prediction results of the MFFNet model and its variants when the prediction interval is 10 min, 20 min and 30 min.

| Forecast Horizon | Metrics | MFFNet1 | MFFNet2 | MFFNet |
|---|---|---|---|---|
| 10-min | MAE | 0.422 | 0.306 | **0.222** |
| | MSE | 0.365 | 0.199 | **0.128** |
| | SMAPE | 0.062 | 0.046 | **0.035** |
| 20-min | MAE | 0.425 | 0.304 | **0.226** |
| | MSE | 0.372 | 0.200 | **0.136** |
| | SMAPE | 0.062 | 0.047 | **0.036** |
| 30-min | MAE | 0.459 | 0.346 | **0.246** |
| | MSE | 0.441 | 0.273 | **0.157** |
| | SMAPE | 0.068 | 0.051 | **0.039** |

the same data preprocessing method as in this paper. CNNLSTM and Transformer+ are single-wind-turbine prediction methods, while ImResNet and WSFNet are single-wind-turbine prediction methods based on time series decomposition. The prediction intervals were set to 10 min, 20 min, and 30 min to evaluate the proposed model in ultra-short-term wind speed prediction. To present the experimental results more intuitively, the optimal value of each metric in all tables is highlighted in bold. During network training, the momentum, batch size, initial learning rate, and number of training epochs were set to 0.9, 16, 0.05, and 60, respectively. Table 2 shows the prediction results of all methods in 2005, and Figs. 7–9 show the prediction curves of all methods on a randomly selected wind turbine.

Table 2. Prediction results on the 2005 dataset with all methods when the prediction interval is 10 min, 20 min and 30 min.

| Methods | MAE (10 min) | MSE (10 min) | SMAPE (10 min) | MAE (20 min) | MSE (20 min) | SMAPE (20 min) | MAE (30 min) | MSE (30 min) | SMAPE (30 min) |
|---|---|---|---|---|---|---|---|---|---|
| PM | 0.339 | 0.367 | 0.047 | 0.591 | 0.936 | 0.081 | 0.791 | 1.513 | 0.107 |
| CNNLSTM | 0.278 | 0.195 | 0.045 | 0.444 | 0.451 | 0.069 | 0.562 | 0.723 | 0.079 |
| Transformer+ | 0.254 | 0.165 | 0.036 | 0.398 | 0.399 | 0.056 | 0.552 | 0.703 | 0.076 |
| ImResNet | 0.306 | 0.177 | 0.049 | 0.338 | 0.212 | 0.051 | 0.363 | 0.249 | 0.053 |
| WSFNet | 0.303 | 0.179 | 0.070 | 0.370 | 0.214 | 0.061 | 0.410 | 0.249 | 0.055 |
| HRNet | 0.292 | 0.197 | 0.045 | 0.314 | 0.222 | 0.048 | 0.319 | 0.227 | 0.052 |
| UNet | 0.248 | 0.138 | 0.040 | 0.269 | 0.159 | 0.043 | 0.285 | 0.182 | 0.046 |
| MFFNet | **0.222** | **0.128** | **0.035** | **0.226** | **0.136** | **0.036** | **0.246** | **0.157** | **0.039** |
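Equations (1)–(3) and the Persistence Method baseline translate directly into code. The sketch below is illustrative: the function names are hypothetical, SMAPE is returned as a fraction (the tables report values such as 0.035 rather than percentages), and PM is taken to be the usual persistence forecast that repeats the last observed value h steps ahead.

```python
import numpy as np

def mse(y, y_hat):
    return float(np.mean((y - y_hat) ** 2))

def mae(y, y_hat):
    return float(np.mean(np.abs(y - y_hat)))

def smape(y, y_hat):
    # Denominator is the mean of actual and predicted values, as in Eq. (3);
    # it assumes positive wind speeds so the denominator is never zero.
    return float(np.mean(np.abs(y - y_hat) / ((y + y_hat) / 2)))

def persistence_forecast(series, h):
    """Predict series[t + h] with series[t]; returns (actual, predicted)."""
    return series[h:], series[:-h]

y = np.array([4.0, 5.0, 6.0, 5.0])
y_hat = np.array([4.0, 4.0, 6.0, 6.0])
print(mse(y, y_hat), mae(y, y_hat))  # 0.5 0.5

actual, predicted = persistence_forecast(np.array([7.0, 7.5, 8.0, 7.0, 6.5]), h=1)
print(mae(actual, predicted))
```

For the 20-min and 30-min columns of the tables, PM corresponds to h = 2 and h = 3 at the 10-min sampling interval.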
MFFNet achieved the best prediction performance at the 10-min prediction interval. Among the four single-wind-turbine prediction methods, the decomposition-based ImResNet and WSFNet showed no significant advantage over CNNLSTM and Transformer+. Among the methods based on spatio-temporal feature decomposition, HRNet had a higher MSE than all four single-wind-turbine methods, while UNet and MFFNet were clearly superior to the other methods. This suggests that models designed for natural images are not necessarily suitable for wind speed prediction based on spatio-temporal features. Fig. 7 shows that the prediction curve of HRNet deviates greatly from the true curve, while the deviations of the other methods are relatively small.
Fig. 7. Prediction curves on a random wind turbine at a prediction interval of 10 min.
Fig. 8. Prediction curves on a random wind turbine at a prediction interval of 20 min.
When the prediction intervals are 20 min and 30 min, MFFNet still achieves the best prediction results. Among the four single-wind-turbine prediction methods, the advantage of ImResNet and WSFNet becomes more pronounced, indicating that the influence of wind speed nonlinearity and nonstationarity on prediction accuracy cannot be ignored. In addition, HRNet gradually overtakes ImResNet and WSFNet and outperforms both in all three metrics at the 30-min interval, suggesting that spatio-temporal features can enrich the input data and improve prediction accuracy. Figs. 8 and 9 show that, except for UNet and MFFNet, the methods deviate significantly from the true curve.
Fig. 9. Prediction curves on a random wind turbine at a prediction interval of 30 min.
To further validate the performance of MFFNet, the prediction results of all models in each season of 2005 were compared, as shown in Tables 3, 4, 5 and 6. All methods obtained smaller errors in spring and winter and larger errors in summer and autumn, indicating that wind speed changed more steadily in spring and winter and more turbulently in summer and autumn.

Table 3. The prediction results of all methods in the spring of 2005.

| Methods | MAE (10 min) | MSE (10 min) | SMAPE (10 min) | MAE (20 min) | MSE (20 min) | SMAPE (20 min) | MAE (30 min) | MSE (30 min) | SMAPE (30 min) |
|---|---|---|---|---|---|---|---|---|---|
| PM | 0.241 | 0.140 | 0.031 | 0.444 | 0.441 | 0.056 | 0.621 | 0.826 | 0.078 |
| CNNLSTM | 0.231 | 0.119 | 0.038 | 0.373 | 0.288 | 0.058 | 0.486 | 0.530 | 0.065 |
| Transformer+ | 0.213 | 0.087 | **0.028** | 0.322 | 0.221 | 0.042 | 0.461 | 0.437 | 0.059 |
| ImResNet | 0.290 | 0.139 | 0.047 | 0.319 | 0.172 | 0.048 | 0.331 | 0.186 | 0.048 |
| WSFNet | 0.301 | 0.169 | 0.120 | 0.367 | 0.193 | 0.061 | 0.413 | 0.229 | 0.053 |
| HRNet | 0.278 | 0.159 | 0.046 | 0.295 | 0.182 | 0.045 | 0.308 | 0.189 | 0.052 |
| UNet | 0.220 | 0.094 | 0.036 | 0.237 | 0.107 | 0.039 | 0.260 | 0.131 | 0.042 |
| MFFNet | **0.177** | **0.062** | **0.028** | **0.180** | **0.063** | **0.030** | **0.213** | **0.089** | **0.033** |
When the prediction interval was 10 min, ImResNet and WSFNet predicted the wind speed fluctuations in spring and winter less accurately than CNNLSTM and Transformer+, but performed similarly to them in summer and autumn. When the prediction interval was 20 or 30 min, ImResNet and WSFNet had a clear advantage over CNNLSTM and Transformer+ in summer and autumn. Therefore, time series decomposition can effectively reduce the impact of wind speed nonlinearity and nonstationarity on prediction accuracy. The proposed MFFNet achieved the best predictive performance in both spring and winter. Although MFFNet did not achieve the lowest MSE in several summer and autumn cases, where UNet (and occasionally ImResNet or WSFNet) performed better, the predictive gap between MFFNet and UNet gradually
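Table 4. The prediction results of all methods in the summer of 2005.

| Methods | MAE (10 min) | MSE (10 min) | SMAPE (10 min) | MAE (20 min) | MSE (20 min) | SMAPE (20 min) | MAE (30 min) | MSE (30 min) | SMAPE (30 min) |
|---|---|---|---|---|---|---|---|---|---|
| PM | 0.390 | 0.483 | 0.057 | 0.666 | 1.186 | 0.096 | 0.874 | 1.868 | 0.125 |
| CNNLSTM | 0.300 | 0.233 | 0.051 | 0.483 | 0.536 | 0.078 | 0.604 | 0.834 | 0.091 |
| Transformer+ | 0.274 | 0.200 | 0.042 | 0.438 | 0.486 | 0.067 | 0.603 | 0.874 | 0.089 |
| ImResNet | 0.298 | 0.164 | 0.051 | 0.338 | 0.203 | 0.056 | 0.376 | 0.281 | 0.060 |
| WSFNet | 0.307 | 0.177 | 0.058 | 0.364 | 0.209 | 0.065 | 0.384 | 0.233 | 0.057 |
| HRNet | 0.284 | 0.191 | 0.047 | 0.308 | 0.214 | 0.050 | 0.330 | 0.266 | 0.054 |
| UNet | 0.256 | **0.154** | 0.043 | 0.282 | **0.183** | 0.047 | 0.310 | 0.233 | 0.051 |
| MFFNet | **0.250** | 0.172 | **0.040** | **0.256** | 0.185 | **0.042** | **0.278** | **0.225** | **0.044** |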
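Table 5. The prediction results of all methods in the autumn of 2005.

| Methods | MAE (10 min) | MSE (10 min) | SMAPE (10 min) | MAE (20 min) | MSE (20 min) | SMAPE (20 min) | MAE (30 min) | MSE (30 min) | SMAPE (30 min) |
|---|---|---|---|---|---|---|---|---|---|
| PM | 0.448 | 0.649 | 0.070 | 0.751 | 1.522 | 0.117 | 0.968 | 2.272 | 0.149 |
| CNNLSTM | 0.307 | 0.250 | 0.051 | 0.494 | 0.556 | 0.081 | 0.626 | 0.895 | 0.097 |
| Transformer+ | 0.283 | 0.232 | **0.045** | 0.461 | 0.547 | 0.073 | 0.628 | 0.917 | 0.096 |
| ImResNet | 0.303 | **0.175** | 0.052 | 0.342 | 0.222 | 0.057 | 0.373 | 0.275 | 0.059 |
| WSFNet | 0.300 | 0.179 | 0.053 | 0.357 | 0.211 | 0.063 | 0.373 | **0.236** | 0.058 |
| HRNet | 0.296 | 0.218 | 0.051 | 0.325 | 0.248 | 0.056 | 0.348 | 0.307 | 0.060 |
| UNet | **0.269** | 0.177 | 0.047 | 0.296 | **0.209** | 0.052 | 0.325 | 0.267 | 0.056 |
| MFFNet | **0.269** | 0.203 | **0.045** | **0.276** | 0.220 | **0.047** | **0.301** | 0.269 | **0.050** |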
narrowed as the prediction interval increased, and MFFNet achieved the best (or tied-best) MAE and SMAPE values. This suggests that the relative advantage of MFFNet grows as the prediction interval increases.
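The seasonal comparison requires splitting the 2005 test year by season. The text does not define the seasons, so the sketch below assumes meteorological seasons (spring = March to May, summer = June to August, autumn = September to November, winter = January, February, and December of the same year); the names `SEASONS`, `month_of`, and `seasonal_mae` are hypothetical.

```python
import numpy as np

# Assumed meteorological seasons (calendar months 1-12).
SEASONS = {
    "spring": (3, 4, 5),
    "summer": (6, 7, 8),
    "autumn": (9, 10, 11),
    "winter": (12, 1, 2),
}

def month_of(ts):
    """Calendar month (1-12) of numpy datetime64 timestamps."""
    return ts.astype("datetime64[M]").astype(int) % 12 + 1

def seasonal_mae(ts, y_true, y_pred):
    """MAE of the predictions restricted to each season."""
    months = month_of(ts)
    result = {}
    for name, season_months in SEASONS.items():
        mask = np.isin(months, season_months)
        result[name] = float(np.mean(np.abs(y_true[mask] - y_pred[mask])))
    return result

# The 2005 test year at 10-min resolution.
ts = np.arange(np.datetime64("2005-01-01T00:00"),
               np.datetime64("2006-01-01T00:00"),
               np.timedelta64(10, "m"))
counts = {name: int(np.isin(month_of(ts), ms).sum()) for name, ms in SEASONS.items()}
print(len(ts), counts)
```

The four seasonal subsets together cover all 52,560 test samples, so every test sample appears in exactly one of Tables 3–6.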
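Table 6. The prediction results of all methods in the winter of 2005.

| Methods | MAE (10 min) | MSE (10 min) | SMAPE (10 min) | MAE (20 min) | MSE (20 min) | SMAPE (20 min) | MAE (30 min) | MSE (30 min) | SMAPE (30 min) |
|---|---|---|---|---|---|---|---|---|---|
| PM | 0.273 | 0.191 | 0.030 | 0.502 | 0.587 | 0.054 | 0.700 | 1.082 | 0.075 |
| CNNLSTM | 0.271 | 0.176 | 0.038 | 0.424 | 0.422 | 0.058 | 0.532 | 0.630 | 0.063 |
| Transformer+ | 0.245 | 0.142 | 0.028 | 0.370 | 0.338 | 0.044 | 0.515 | 0.581 | 0.058 |
| ImResNet | 0.332 | 0.229 | 0.045 | 0.354 | 0.250 | 0.045 | 0.372 | 0.252 | 0.046 |
| WSFNet | 0.303 | 0.194 | 0.056 | 0.391 | 0.246 | 0.057 | 0.469 | 0.295 | 0.049 |
| HRNet | 0.311 | 0.222 | 0.039 | 0.328 | 0.241 | 0.041 | 0.339 | 0.246 | 0.044 |
| UNet | 0.248 | 0.124 | 0.032 | 0.261 | 0.137 | 0.034 | 0.292 | 0.169 | 0.037 |
| MFFNet | **0.190** | **0.075** | **0.025** | **0.191** | **0.076** | **0.026** | **0.247** | **0.118** | **0.030** |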
4 Conclusion
This paper proposes an ultra-short-term wind speed prediction method based on spatio-temporal feature decomposition and a multi-feature fusion network (MFFNet). Compared with existing advanced methods, the proposed method effectively captures local details and performs reliably under wind speed fluctuations. In recent years, attention-based networks inspired by the human visual attention mechanism have attracted considerable research interest. Given the performance gains that attention mechanisms have delivered, future work will focus on using wind direction information to construct an attention model.
Acknowledgements. This work is supported by the National Natural Science Foundation of China (Grant No. 61976155).
References
1. Yuan, S., et al.: A novel multi-objective robust optimization model for unit commitment considering peak load regulation ability and temporal correlation of wind powers. Electr. Power Syst. Res. 169, 115–123 (2019)
2. Sari, A.P., et al.: Short-term wind speed and direction forecasting by 3DCNN and deep convolutional LSTM. IEEJ Trans. Electr. Electron. Eng. 17(11), 1620–1628 (2022)
3. Jacondino, W.D., et al.: Hourly day-ahead wind power forecasting at two wind farms in northeast Brazil using WRF model. Energy 230, 120841 (2021)
4. Liu, Z., et al.: Hybrid forecasting system based on data area division and deep learning neural network for short-term wind speed forecasting. Energy Convers. Manag. 238, 114136 (2021)
5. Ding, W., et al.: Point and interval forecasting for wind speed based on linear component extraction. Appl. Soft Comput. 93, 106350 (2020)
6. Jiang, P., et al.: A combined forecasting system based on statistical method, artificial neural networks, and deep learning methods for short-term wind speed forecasting. Energy 217, 119361 (2021)
7. Yu, R., et al.: LSTM-EFG for wind power forecasting based on sequential correlation features. Futur. Gener. Comput. Syst. 93, 33–42 (2019)
8. Shang, Z., et al.: Decomposition-based wind speed forecasting model using causal convolutional network and attention mechanism. Expert Syst. Appl. 223, 119878 (2023)
9. Wang, X., et al.: Adaptive support segment based short-term wind speed forecasting. Energy 249, 123644 (2022)
10. Zhang, J., et al.: An adaptive hybrid model for short term wind speed forecasting. Energy 190, 115615 (2020)
11. Jiang, Z., et al.: Ultra-short-term wind speed forecasting based on EMD-VAR model and spatial correlation. Energy Convers. Manag. 250, 114919 (2021)
12. Khodayar, M., et al.: Spatio-temporal graph deep neural network for short-term wind speed forecasting. IEEE Trans. Sustain. Energy 10(2), 670–681 (2018)
13. Liu, Y., et al.: Probabilistic spatiotemporal wind speed forecasting based on a variational Bayesian deep learning model. Appl. Energy 260, 114259 (2020)
14. He, K., et al.: Masked autoencoders are scalable vision learners. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 15979–15988. IEEE (2022)
15. King, J., et al.: Validation of power output for the WIND Toolkit, No. NREL/TP-5D00-61714. NREL, Golden, CO, USA (2014)
16. Shen, Z., et al.: Wind speed prediction of unmanned sailboat based on CNN and LSTM hybrid neural network. Ocean Eng. 254, 111352 (2022)
17. Qu, K., et al.: Short-term forecasting for multiple wind farms based on transformer model. Energy Rep. 8, 483–490 (2022)
18. Yildiz, C., et al.: An improved residual-based convolutional neural network for very short-term wind power forecasting. Energy Convers. Manag. 228, 113731 (2021)
19. Acikgoz, H., et al.: WSFNet: an efficient wind speed forecasting model using channel attention-based densely connected convolutional neural network. Energy 233, 121121 (2021)
20. Wang, J., et al.: Deep high-resolution representation learning for visual recognition. IEEE Trans. Pattern Anal. Mach. Intell. 43(10), 3349–3364 (2020)
21. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
A Risk Model for Assessing Exposure Factors Influence Oil Price Fluctuations
Raghad Alshabandar1(B), Ali Jaddoa2, and Abir Hussain3
1 Datactics, Belfast, Northern Ireland, UK
2 School of Engineering, Technology and Design, Canterbury Christ Church University, Canterbury, UK
3 University of Sharjah, Sharjah City, UAE
Abstract. Oil price volatility has a considerable impact on the global economy, and the uncertainty of crude oil prices is driven by many risk factors. Several prior studies have examined the factors that affect oil price fluctuations, but their methods cannot identify the dynamic non-fundamental factors. To address this issue, we propose a risk model inspired by Mean-Variance Portfolio theory. The model automatically constructs optimal portfolios that seek to maximize returns at the lowest level of risk, without human intervention. The results demonstrate a significant asymmetric cointegrating correlation between oil price volatility and non-fundamental factors.
Keywords: Modern Portfolio Theory · Conditional Value at Risk · Consumer Price Index
1 Introduction
Oil plays a crucial role in economic stability and stock markets; however, oil price volatility is a major challenge for stock markets. According to the International Energy Agency (IEA), oil will provide 30% of the world's energy, and uncertainty in oil prices may affect future cash flows and investors' portfolios. Because investors may become more cautious about spending their money, this can raise the risk of an economic slowdown [20]. Researchers have found that the Consumer Price Index (CPI) rises by 0.4 when the price of a barrel of crude oil increases by 10 dollars, resulting in inflation. This heightens investors' fears of uncertainty, which in turn harms the stock market. Conversely, a rising oil price could boost the economic growth of countries that rely on oil as their main source of income [10]. Because the world economy is globalized, oil price volatility is likely to be felt worldwide. Although empirical studies have highlighted many risk factors affecting oil prices, their dynamic association is difficult to investigate because oil prices fluctuate constantly [12]. Risk analysis can address this issue; it involves identifying the criteria that describe the dynamic risk factors. Risk analysis helps
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 482–492, 2023. https://doi.org/10.1007/978-981-99-4761-4_41
A Risk Model for Assessing Exposure Factors Influence
decision-makers determine the potential fundamental and non-fundamental factors that affect oil price fluctuations, so that they can control financial threats and develop preventive plans [9]. Although Modern Portfolio Theory has been widely applied to measure risk in the financial industry, evaluations of oil price fluctuation factors consider only the risk of fundamental factors such as supply and demand [5, 19]. However, dynamic non-fundamental factors such as climate, war, economic events, and disasters require further attention, because they significantly affect supply and demand and hence the oil price [5, 19]. In this paper, we propose a risk model framework inspired by Modern Portfolio Theory. The model dynamically identifies the factors that affect oil price fluctuations by assessing expected return, volatility, and weight. It can help decision-makers in the oil and gas industry define optimal portfolios that maximize expected return while minimizing risk. The remainder of this paper is organized as follows: Section 2 provides an overview of Modern Portfolio Theory and reviews related work; Section 3 presents the methodology together with the results and discussion; and Section 4 describes the conclusion and future work.
2 Background and Literature Review
2.1 Modern Portfolio Theory
Modern Portfolio Theory (MPT) is a mathematical model of finance that uses mean-variance analysis to find the portfolio whose expected return is maximized for a given level of risk. Decision-makers use MPT to make investment decisions by weighting the risk factors; the analysis determines the combination that gives the least risk for a given level of return. It measures variance, that is, the volatility of the returns produced by an asset, against the expected returns of that asset [3, 6]. In volatile stock markets, investors have become more concerned about the risk of their investments, particularly since the global financial crisis. MPT is one of the most popular investment strategies for addressing this challenge: it helps investors achieve better financial outcomes while keeping the downside risk as low as possible [6]. Two metrics are commonly used to measure the risk exposure of a portfolio: Value at Risk (VaR) and Conditional Value at Risk (CVaR). VaR is a risk management metric that measures the potential loss of an investment portfolio over a specified period; computing it requires the size of the loss, the probability that this loss occurs, and the time period. CVaR measures portfolio risk over a specific time horizon by averaging the losses in the tail of the return distribution beyond the VaR. Investment companies usually use VaR to assess and control financial risk, which can help them optimize their portfolios [13]. Shell, for example, employs VaR in its risk management process to monitor risk and maintain a strong capital position; according to Shell's annual report for 2020, VaR plays a vital role in determining the losses and gains from market movements over a specified time horizon.
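As a concrete reading of the two definitions above, the sketch below estimates VaR and CVaR by historical simulation: VaR is the α-quantile of the empirical loss distribution, and CVaR is the mean of the losses at or beyond that quantile. The confidence level α = 0.95, the synthetic returns, and the function name are illustrative assumptions; losses are reported as positive numbers.

```python
import numpy as np

def historical_var_cvar(returns, alpha=0.95):
    """Historical-simulation VaR and CVaR at confidence level alpha.

    returns: 1-D array of per-period portfolio returns (e.g. 0.01 = +1%).
    Both risk measures are returned as positive loss values.
    """
    losses = -np.asarray(returns, dtype=float)   # a negative return is a loss
    var = float(np.quantile(losses, alpha))      # exceeded with probability about 1 - alpha
    cvar = float(losses[losses >= var].mean())   # average loss in the tail beyond VaR
    return var, cvar

rng = np.random.default_rng(42)
daily_returns = rng.normal(loc=0.0005, scale=0.02, size=1000)  # synthetic data
var95, cvar95 = historical_var_cvar(daily_returns)
print(f"VaR95={var95:.4f}, CVaR95={cvar95:.4f}")
```

By construction CVaR is never smaller than VaR at the same level, which is why it is often preferred as a measure of tail risk.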
R. Alshabandar et al.
2.2 Literature Review
Several studies have examined the factors that affect oil price volatility. The authors of [17] show that economic uncertainty strongly affects oil price fluctuations. U.S. macroeconomic uncertainty and market volatility were investigated in [2], and the findings show a nonlinear relationship between the stock market and oil prices. A historical simulation model was employed to measure the risk of financial time series data, and the results show that the model gives the highest weight to the data points closest to the present [7]. VaR was also employed in [1] to measure an energy portfolio; the results suggest that VaR could help investors identify the risk associated with each energy source in their plan and thus reduce their reliance on any single source in the future. A non-negative matrix factorization model, combined with the historical simulation model, has been employed to measure the VaR of oil prices [20]. The non-negative matrix factorization model captured risk factors from an online news dataset and assigned them as weights in the historical simulation model, and the results show that risk factors from online news can improve the accuracy of VaR measurement [15]. However, non-negative matrix factorization suffers from several limitations, such as data sparsity, overfitting, and cold-start recommendations [14]. As a consequence, the VaR of oil prices could be computed from incorrectly determined risk factors. The author of [18] used Generalized Autoregressive Conditional Heteroscedasticity (GARCH) models based on covariance methods to evaluate crude oil prices; the findings suggest that adding fractional integration to the model is required for better VaR results. The author of [4] also employs GARCH to forecast a stock market index, using the forecast volatility to update the weights of the observed historical data.
The research demonstrates that models with adjusted weights perform better than the other models. An Autoregressive Distributed Lag (ADL) model is applied to evaluate oil price fluctuations during COVID-19 [15]. The findings show that COVID-19, stock market volatility, and the Russia–Saudi Arabia oil price war are the main factors behind the collapse of the oil price. The main drawback of ADL is that it gives reliable results only when the variables are stationary, i.e., when their statistical properties do not change over time. In summary, numerous studies have investigated the factors that affect oil price volatility with various methods; however, to the best of our knowledge, these methods cannot precisely examine the dynamic non-fundamental factors.
3 Methodology
3.1 Data Description
The data were obtained from the websites of the U.S. Energy Information Administration (EIA) and the International Energy Agency [8, 16] and cover the period from 1/1/1998 to 5/7/2022. The crude oil price is given in dollars per barrel. Table 1 briefly describes the data.
Table 1. Dataset description

| Input | Description |
|---|---|
| Crude Oil Price | Crude oil price, measured in dollars per barrel |
| Supply Number_Of_Days | Average number of days, on a weekly basis, that the warehouse holds oil before selling it |
| Field_Production_ Crude Oil | Number of barrels produced from the oil field each day, measured in thousand barrels per day |
| Crude Oil Imports | Number of barrels of crude oil imported, measured in thousand barrels per day |
3.2 Exploratory Data Analysis
Exploratory Data Analysis (EDA) uses graphical representations of information and data. In this study, EDA provides an accessible way to see and understand trends, outliers, and patterns in the oil price data. Figure 1 displays the oil price over a period of thirty years and highlights its fluctuations, although the price remained fairly stable between 1991 and 1996. Notably, the oil price rose between 2000 and 2008, while energy prices declined at the beginning of 2009, which can be attributed to the financial crisis. Random fluctuations in the oil price are also visible over time, particularly between 2010 and 2015. Furthermore, the figure shows a drop in oil prices between 2019 and 2020 due to the COVID-19 pandemic and the associated restrictions.
Fig. 1. Time series plot for oil price data
3.3 Data Pre-processing
Data pre-processing can significantly improve the performance of the risk model. It consists of two steps: data cleansing and data transformation. Data cleansing removes missing values and outlier data points. Data transformation uses two techniques. The first normalizes the data with Min-Max scaling, which rescales the values to the range [0, 1]. The second transforms the data into a stationary series, which makes patterns and relationships between variables easier to detect. A stationarity test is used to determine whether the observed data exhibit a trend or seasonal effect. Figure 2 visualizes this test by plotting the mean (µ) and standard deviation (SD) over time against the original series; the figure reveals that the data are not stationary, as they contain a trend. The Augmented Dickey-Fuller (ADF) test is then used to check whether the data are stationary after transformation. The ADF test evaluates the null hypothesis that the series has a unit root (is non-stationary) against the alternative that it is stationary [7]. The data were made stationary by differencing, i.e., by replacing each monthly observation with its difference from the previous month's observation. Table 9 compares the statistical results before and after the transformation: the p-value was 0.306304 before the transformation, so the null hypothesis of non-stationarity cannot be rejected, whereas it is 2.231302e-16 after the transformation, which indicates that the data have been successfully transformed into a stationary series.
Fig. 2. Stationary test for oil price data
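The two transformations can be sketched with NumPy. The order (scaling, then differencing) and the function names are assumptions; the ADF test itself is available as `adfuller` in `statsmodels.tsa.stattools`, whose second return value is the p-value, but it is left out here to keep the example dependency-free.

```python
import numpy as np

def min_max_scale(x):
    """Rescale a series to [0, 1]; a constant series maps to all zeros."""
    x = np.asarray(x, dtype=float)
    lo, hi = x.min(), x.max()
    if hi == lo:
        return np.zeros_like(x)
    return (x - lo) / (hi - lo)

def first_difference(x):
    """x[t] - x[t-1]; the output is one element shorter than the input."""
    return np.diff(np.asarray(x, dtype=float))

# A linear trend is non-stationary in the mean; differencing removes it.
t = np.arange(12, dtype=float)
prices = 50.0 + 2.0 * t
print(min_max_scale(prices)[[0, -1]])  # [0. 1.]
print(first_difference(prices))        # eleven values, all 2.0
```

Min-Max scaling changes only the range of the data, not the shape of its distribution, so it does not by itself remove a trend; that is the role of the differencing step.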
3.4 Features Engineering
A feature engineering procedure is applied in this study to extract time series features from the oil price data, as summarized below. The previous oil price is important for prediction, because the price at time t is affected by the price at time t-1. These past values are known as lags: t-1 is lag 1, t-2 is lag 2, and so on. The lag values were chosen using the Partial Autocorrelation Function (PACF) test. We tried lags of 1, 2, 7, and 30; the PACF test shows a strong correlation between the first lag and the seventh and twelfth lags, but a low correlation between the first and second lags. The rolling window feature is the mean of the values over the past month, shifted by one month. We therefore extract three main time series features: lag_feature_month, lag_feature_week, and Rolling mean.
3.5 Risk Model Construction
We developed a risk model based on Modern Portfolio Theory (MPT). Each feature is treated as a stock, and each pair of features as a portfolio of stock components. The risk model, which only uses the data for the last four years, has three input parameters, described below.
• Expected Returns: The expected returns (µ) represent the predicted profit or loss based on historical rates of return. In the risk model, they are the anticipated stock market returns associated with the relevant features, predicted from the historical returns [6].
• Covariance Matrix: The covariance matrix represents the volatility of the whole portfolio; the decision-maker can use it to reap the benefits of diversification and increase the return per unit of risk. In the risk model, the covariance matrix is used to explore the relationships between features and to measure the total risk of the portfolio.
In this study, the Ledoit-Wolf shrinkage covariance matrix is used to reduce the impact of noise and outliers in the dataset [11]. This estimator pulls the most extreme coefficients towards more central values, which yields more stable estimates for the features.
• Weight: In finance, a weight is the percentage of the total value of an investment held in a particular asset; in the risk model, it denotes the weight of a feature. The weights help decision-makers gain deeper insight into the diversification of their portfolios. They are computed by finding the optimal portfolio that maximizes the Sharpe ratio, which can be derived from the efficient frontier: the set of portfolios that offer the highest expected return (µ) for each level of volatility. The optimal portfolio helps stakeholders and stock market advisors find the combinations of features that decrease the risk of oil price fluctuation in the long term. The Sharpe ratio measures the risk-adjusted performance of the portfolio.
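The feature construction of Sect. 3.4 and the shrinkage covariance of Sect. 3.5 can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the authors' code, and the function names are hypothetical. The lag is a parameter (e.g. 7 or 30, two of the lag values tried above), and the rolling mean is shifted so that the value at time t uses only earlier observations. The paper does not state its shrinkage target. This sketch uses a scaled identity target (the variant implemented by scikit-learn's `LedoitWolf`), whereas the paper cited as [11] proposes a constant-correlation target.

```python
def lag_feature(series, lag):
    """Value of the series `lag` steps earlier; None where no history exists."""
    k = min(lag, len(series))
    return [None] * k + series[:len(series) - k]


def shifted_rolling_mean(series, window, shift=1):
    """Mean of `window` values ending `shift` steps before t.

    With shift >= 1 the feature at time t never sees the value at t itself.
    """
    out = []
    for t in range(len(series)):
        end = t - shift + 1  # exclusive end of the window
        start = end - window
        out.append(None if start < 0 else sum(series[start:end]) / window)
    return out


def ledoit_wolf(X):
    """Ledoit-Wolf shrinkage towards a scaled identity target (assumed target).

    X is a list of T observations, each a list of N feature values.
    Returns (shrunk covariance as a list of rows, shrinkage intensity delta).
    """
    T, N = len(X), len(X[0])
    means = [sum(row[j] for row in X) / T for j in range(N)]
    Xc = [[row[j] - means[j] for j in range(N)] for row in X]
    # Sample covariance with 1/T normalisation.
    S = [[sum(r[i] * r[j] for r in Xc) / T for j in range(N)] for i in range(N)]
    # Target m * I, where m is the average sample variance.
    m = sum(S[i][i] for i in range(N)) / N

    def target(i, j):
        return m if i == j else 0.0

    # Squared distance between S and the target (Frobenius norm scaled by 1/N).
    d2 = sum((S[i][j] - target(i, j)) ** 2 for i in range(N) for j in range(N)) / N
    # Estimated sampling error of S, averaged over the observations.
    b_bar2 = sum(
        sum((r[i] * r[j] - S[i][j]) ** 2 for i in range(N) for j in range(N)) / N
        for r in Xc
    ) / T ** 2
    b2 = min(b_bar2, d2)
    delta = b2 / d2 if d2 > 0 else 1.0  # S already equals the target if d2 == 0
    shrunk = [[delta * target(i, j) + (1 - delta) * S[i][j] for j in range(N)]
              for i in range(N)]
    return shrunk, delta
```

Shrinkage keeps the trace of the sample covariance and pulls every off-diagonal entry towards zero, which is the "towards more central values" behaviour described above.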
R. Alshabandar et al.
4 Result and Discussion
We analyzed the risk model results from 1/1/2019 to 30/9/2022. The expected returns are shown in Fig. 3. The "Field_Production_Crude Oil" and "Crude Oil Imports" features have the highest expected returns in 2019. The expected return of "Field_Production_Crude Oil" fell by approximately 10% in 2020, mainly because of the significant impact of the COVID-19 pandemic on the oil industry, particularly when prices dropped sharply in early 2020. The expected returns for "lag_feature_month" and "Rolling mean" increased as the oil price rose after COVID-19, in 2021 and 2022. This outcome indicates a surge in the stock market, particularly for countries whose GDP relies on revenue from oil sales. The covariance matrix is displayed in Fig. 4. It shows that "lag_feature_month" and "rolling_mean" were positively correlated with each other in 2019 and 2020, which can be attributed to the strong dependence between past and current oil prices, particularly for long-term patterns and trends. "Crude_Oil_Import" had a strong relationship with "Field_Production_Crude Oil" in 2021 and with "supply number_of_days" in 2022. In contrast, a very weak, negative correlation is observed between "lag_feature_month" and these features from 2020 to 2022, which may be due to OPEC's announced production cut in May 2020; such a cut can raise oil prices and cause a drop in the stock market. Table 2 compares the Sharpe ratio, expected return, and volatility across the four years. Notably, the highest expected return, approximately 7260.0%–7280.0%, was obtained between April and September 2022, followed by 2021. However, the expected return dropped significantly in 2020, with yields of only 19.6% and 20.2% in the second and third quarters, respectively.
Due to the Russian war in Ukraine, the oil price rose from $76 per barrel to $110 in March 2022, increasing profits for countries whose GDP relies heavily on oil. The highest volatility occurred in the first quarter of 2020 and the third quarter of 2022. The highest Sharpe ratio, 55.62, was obtained in the third quarter of 2022 (July–September), whereas the lowest, 0.28, was recorded in the first quarter of 2020. The third quarter of 2022 therefore yields the optimal risk portfolio, i.e., the best expected return per unit of risk.
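The Sharpe ratio and the maximum-Sharpe (tangency) portfolio discussed above can be illustrated with a short sketch. Two assumptions are ours. First, the risk-free rate: 2% here, since the paper does not state one. Second, the closed-form, unconstrained tangency portfolio w ∝ Σ⁻¹(µ − r_f) enforces no bounds on the weights, so it can differ from the constrained optimisation behind Table 3, whose weights all lie in [−1, 1]. With a 2% rate the formula reproduces several Table 2 entries, e.g. Jan–March-19: (68.6% − 2%)/38.1% ≈ 1.75, though not all of them (April–June-20 gives about 0.45 rather than 0.5). Function names are hypothetical.

```python
def sharpe_ratio(expected_return, volatility, risk_free=0.02):
    """(E[R] - r_f) / sigma with returns and volatility as decimals.

    risk_free=0.02 is an assumption; the paper does not state the rate.
    """
    return (expected_return - risk_free) / volatility


def solve(A, b):
    """Solve A x = b by Gaussian elimination with partial pivoting."""
    n = len(A)
    M = [[float(v) for v in A[i]] + [float(b[i])] for i in range(n)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(M[r][col]))
        M[col], M[pivot] = M[pivot], M[col]
        for r in range(col + 1, n):
            f = M[r][col] / M[col][col]
            for c in range(col, n + 1):
                M[r][c] -= f * M[col][c]
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (M[i][n] - sum(M[i][j] * x[j] for j in range(i + 1, n))) / M[i][i]
    return x


def tangency_weights(mu, cov, risk_free=0.02):
    """Unconstrained max-Sharpe weights, w proportional to cov^-1 (mu - r_f).

    Weights are normalised to sum to 1; no bounds are enforced.
    """
    raw = solve(cov, [m - risk_free for m in mu])
    total = sum(raw)
    if abs(total) < 1e-12:
        raise ValueError("excess returns give no well-defined tangency portfolio")
    return [w / total for w in raw]
```

For example, `round(sharpe_ratio(0.686, 0.381), 2)` gives 1.75, and two uncorrelated assets with µ = (10%, 11%) and variances (0.04, 0.09) receive weights of 2/3 and 1/3.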
Table 3 lists the weight of each feature per quarter across the four years. As can be seen, the weights lie between −1 and 1. Negative weights were assigned to features when there were fluctuations in the price and in the amount of oil produced. In 2019 and 2020, the weights for "lag_feature_month" lie approximately between −0.33 and 0.98, whereas the weight increased slightly in 2022, reaching its highest value in the third quarter of 2022. This can be interpreted as "lag_feature_month" carrying the highest risk for the stock market, mostly in the second quarter of 2020, when the oil price dropped below $20. The weight of "supply number_of_days" is roughly between 0.29 and 0.50 in 2019 and increased to a record 0.99 in 2020. On the other hand, the weight dropped dramatically in 2021 and 2022, reaching its lowest value of −0.8479 between July and September 2022. The high weight of "supply number_of_days" arises because the lockdown restrictions increased the number of days needed to supply oil; the oversupply of oil is another factor that encouraged oil companies to cut their production. In 2019, "Field_Production_Crude Oil" achieved the highest weight in the first, third, and fourth quarters, reaching a value of 0.99. However, its weight dropped sharply in 2020 and 2021, to the lowest value of −0.0703 in the first quarter of 2021. The low weight of "Field_Production_Crude Oil" is due to the sharp decline in oil prices and the oversupply, which negatively influenced the profitability of the oil industry and adversely affected the stock market.

Table 2. Risk model portfolio performance

| Date | Expected Return | Volatility | Sharpe Ratio |
|---|---|---|---|
| Jan–March-19 | 68.60% | 38.10% | 1.75 |
| April–June-19 | 37.00% | 25.90% | 1.35 |
| July–Sep-19 | 69.20% | 51% | 1.32 |
| Sep–Dec-19 | 64.20% | 48.60% | 1.28 |
| Jan–March-20 | 42.00% | 145.00% | 0.28 |
| April–June-20 | 19.60% | 39.00% | 0.5 |
| July–Sep-20 | 20.20% | 39.40% | 0.51 |
| Sep–Dec-20 | 41.20% | 121.90% | 0.32 |
| Jan–March-21 | 2012.60% | 49.30% | 40.76 |
| April–June-21 | 1812.20% | 55.50% | 32.62 |
| July–Sep-21 | 1330.50% | 43.70% | 30.41 |
| Sep–Dec-21 | 1374.20% | 40.40% | 33.96 |
| Jan–March-22 | 1814.50% | 35.90% | 50.52 |
| April–June-22 | 7620.60% | 47.20% | 16.1 |
| July–Sep-22 | 7280.00% | 132.30% | 55.62 |
Table 3. Risk model features weight

| Date | Lag feature month | Rolling mean | Supply Days | Oil Import |
|---|---|---|---|---|
| Jan–March-19 | −0.1192 | −0.1876 | 0.348 | −0.0311 |
| April–June-19 | −0.1214 | 0.1109 | 0.2886 | −0.0302 |
| July–Sep-19 | −0.2502 | −0.2333 | 0.5006 | −0.007 |
| Oct–Dec-19 | −0.3276 | −0.0111 | 0.4755 | −0.1267 |
| Jan–March-20 | −0.015 | 0.149 | 0.99 | −0.5213 |
| April–June-20 | 0.9899 | −0.6571 | 0.99 | −0.5328 |
| July–Sep-20 | 0.89 | −0.6005 | 0.99 | −0.4711 |
| Oct–Dec-20 | 0.3102 | −0.2135 | 0.98 | −0.1737 |
| Jan–March-21 | 0.5385 | 0.5591 | −0.0622 | 0.0348 |
| April–June-21 | 0.5061 | 0.4715 | −0.0323 | −0.0954 |
| July–Sep-21 | 0.6261 | 0.1546 | 0.1171 | 0.0013 |
| Oct–Dec-21 | 0.6962 | 0.1589 | 0.1184 | 0.0553 |
| Jan–March-22 | 0.1822 | 0.8466 | −0.0239 | −0.0142 |
| April–June-22 | 0.0961 | 0.99 | −0.2539 | 0.0215 |
| July–Sep-22 | 0.99 | 0.3655 | −0.8479 | −0.0398 |
5 Conclusion
In this paper, we proposed a risk model based on Modern Portfolio Theory to prioritize the underlying fundamental and non-fundamental factors that contribute to fluctuations in the crude oil price and their subsequent impact on the growth of stock markets. Our results demonstrate that non-fundamental factors significantly influence oil price volatility. The findings also reveal a nonlinear association between oil prices and the stock market. In future work, machine learning will be used to forecast the critical dates on which the price fluctuates strongly.
Fig. 3. Risk model expected returns plot (panels: 2019, 2020, 2021, 2022).
Fig. 4. Risk model covariance matrix plot (panels: 2019, 2020, 2021, 2022).
References
1. Ahmed Ghorbel, A.T.: Energy portfolio risk management using time-varying extreme value copula methods. Econ. Model. 38, 470–485 (2014)
2. Bakas, D., Triantafyllou, A.: Volatility forecasting in commodity markets using macro uncertainty. Energy Econ. 81, 79–94 (2019)
3. Chiu, M.C., Wong, H.Y.: Mean-variance portfolio selection with correlation risk. J. Comput. Appl. Math. 263, 432–444 (2014). https://doi.org/10.1016/j.cam.2013.12.050
4. Fries, C.P., Nigbur, T., Seeger, N.: Displaced relative changes in historical simulation: application to risk measures of interest rates with phases of negative rates. J. Empir. Finance 42, 175–198 (2017)
5. Deng, S., Sakurai, A.: Crude oil spot price forecasting based on multiple crude oil markets and timeframes. Energies 7, 2761–2779 (2014)
6. Fabozzi, F.J., Gupta, F., Markowitz, H.M.: The legacy of modern portfolio theory. J. Investing 11(3), 7–22 (2002)
7. Darryll, H.: Evaluation of value-at-risk models using historical data. Econ. Policy Rev. 2, 1 (1996)
8. IEA: Data and statistics
9. Jagoda, K., Wojcik, P.: Implementation of risk management and corporate sustainability in the Canadian oil and gas industry: an evolutionary perspective. Acc. Res. J. 32, 381–398 (2019)
10. Kilian, L., Zhou, X.: The impact of rising oil prices on US inflation and inflation expectations in 2020–23. Energy Econ. 113, 106228 (2022)
11. Ledoit, O., Wolf, M.: I shrunk the sample covariance matrix. 30(4), 110–119 (2003)
12. Lyu, Y., et al.: Good volatility, bad volatility and economic uncertainty: Evidence from the crude oil futures market. Energy 222, 119924 (2021). https://doi.org/10.1016/j.energy.2021.119924
13. Sarykalin, S., et al.: Value-at-risk vs. conditional value-at-risk in risk management and optimization. In: State-of-the-Art Decision-Making Tools in the Information-Intensive Age, pp. 270–294. INFORMS (2008). https://doi.org/10.1287/educ.1080.0052
14. Sherman, T.D., Gao, T., Fertig, E.J.: CoGAPS 3: Bayesian non-negative matrix factorization for single-cell analysis with asynchronous updates and sparse data structures. BMC Bioinform. 21, 6–11 (2020)
15. Le, T.-H., Le, A.T., Le, H.-C.: The historic oil price fluctuation during the Covid-19 pandemic: what are the causes? J. Empir. Finance 58, 101489 (2021)
16. U.S. Energy Information Administration: PETROLEUM & OTHER LIQUIDS
17. Watugala, S.W.: Economic uncertainty, trading activity, and commodity futures volatility. J. Futur. Mark. 39(8), 921–945 (2019)
18. Youssef, M., Belkacem, L., Mokni, K.: Value-at-Risk estimation of energy commodities: a long-memory GARCH–EVT approach. Energy Econ. 51, 99–110 (2015)
19. Zhang, Y.-J., Yao, T.: Interpreting the movement of oil prices: driven by fundamentals or bubbles? Econ. Model. 55, 226–240 (2016)
20. Zhao, L.T., et al.: Forecasting oil price volatility in the era of big data a text mining for VaR approach. Sustainability (Switz.) 11, 14 (2019). https://doi.org/10.3390/su11143892
A Dynamic Graph Convolutional Network for Anti-money Laundering
Tianpeng Wei1, Biyang Zeng1, Wenqi Guo1, Zhenyu Guo2, Shikui Tu1(B), and Lei Xu1(B)
1 Department of Computer Science and Engineering, Shanghai Jiao Tong University, Shanghai 200240, China
{tushikui,leixu}@sjtu.edu.cn
2 Big Data & Artificial Intelligence Laboratory, Industrial and Commercial Bank of China, Shanghai 200120, China
Abstract. Anti-money laundering (AML) is essential for safeguarding financial systems. One critical approach is to monitor the tremendous volume of daily transaction records to filter out suspicious transactions or accounts, which is time-consuming and requires rich experience and expert knowledge to construct filtering rules. Deep learning methods based on graph neural networks have been used to model transaction data and have achieved promising performance in AML. However, existing methods lack efficient modeling of the transaction time stamps, which provide important discriminative features for recognizing the accounts participating in money laundering. In this paper, we propose a dynamic graph attention (DynGAT) network for detecting suspicious accounts or users involved in illicit transactions. The daily transaction records are naturally constructed as graphs, with accounts as nodes and transaction relationships as edges. To take the time stamps into account, we construct one transaction graph from the records within each time interval and obtain a temporal sequence of transaction graphs. For every graph in the sequence, we not only compute the node embeddings with a vanilla graph attention network, but also explicitly develop a time embedding via a position-encoding block. Our method further captures the dynamics of the graph sequence through a multi-head self-attention block applied to the sequence of concatenated node and time embeddings. Moreover, we train the model with a weighted cross entropy loss to tackle the sample imbalance problem. Experiments demonstrate that our method outperforms existing ones on the AML task.
Keywords: Anti-money laundering · Graph attention network · Dynamic graph
1 Introduction
Money laundering aims to legalize illegal profits and is a severe financial crime. It may cause direct financial losses, or provide concealed funds to criminals for illegal activities. For example, according to [8], the amount of money laundered between Nigeria and the USA through financial crimes is up to one trillion US dollars per year, and the money is further diverted to illegitimate uses such as drug trafficking. However, detecting money laundering is difficult because illegal clients and transactions are extremely rare compared with the total number of clients and transactions. Moreover, illegal users try to cover themselves with normal transactions and change their money laundering techniques rapidly, so it is hard to establish a recognition pattern that holds in the long term. Besides, few public datasets are released by companies or banks because real transaction records are confidential, which further limits research on the anti-money laundering problem.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 493–502, 2023. https://doi.org/10.1007/978-981-99-4761-4_42

Considering the harmful consequences of money laundering, much effort has been devoted to identifying money laundering behaviors and locating the relevant suspects in order to prevent them. Identification methods have evolved with advances in computer science, from manually hard-coded rules to automatic detection with machine learning. Traditional machine learning methods, e.g., logistic regression, support vector machines, and multilayer perceptrons (MLP), have been used in AML to reduce the human burden [3, 5]. One limitation of these works is that they still rely heavily on manually computed features, which demands expert knowledge. For example, to characterize the trading counterparties, one needs to count the number of distinct individuals an account transacts with. Such important features do not exist in the original data, and it is difficult to calculate appropriate ones, especially when considering high-order counterparties. With the recent success of deep learning, deep learning methods have been developed for AML tasks [7]. For example, a deep autoencoder was developed in [10] to identify anomalies in the import and export data of Brazil and to determine the fraud possibilities of exporters.
Recently, graph neural networks (GNNs) [4] have become popular for AML tasks because financial transactions are naturally network data. Constructing a graph from transaction records is simple: each entity (i.e., an account or its holder) is taken as a node, and a flow of currency between two entities defines an edge between the corresponding nodes. All information about the transactions or the entities can be used as initial edge or node features. The overall idea of a GNN is to aggregate the information of a vertex's neighbors and update the representation of the current vertex. This makes an end-to-end AML pipeline possible, in which money laundering patterns are detected and learned automatically to improve the efficiency and performance of AML. In an attempt made by [12], the authors suggested that, before running traditional machine learning methods such as random forests, it is potentially beneficial to augment the node features with embeddings computed by graph convolutional networks (GCNs).
Financial transactions naturally come with time stamps, and the time stamps matter when constructing transaction graphs. See Fig. 1 for an example. The records in scenario (a) may be normal daily transactions, while the records in scenario (b) are suspicious, with money laundering patterns. If the time stamps are ignored, the transaction graphs for the two scenarios have the same topology; if one day is used as the time interval, the constructed graph sequences become discriminative of money laundering patterns. The work in [12] demonstrated that temporal GCN models, e.g., EvolveGCN [9], consistently outperform GCNs on static graphs, because the temporal evolution of transactions provides additional information for evaluating the risk of money laundering, which is lost in static graphs. Specifically, EvolveGCN [9] cuts transaction graphs into different snapshots by time and trains a recurrent neural network (RNN) to generate the GCN weight matrix at every time step. There are also other dynamic graph models. DyGCN [2] proposed a new GCN aggregation method that updates node features over time: defining information changes as the variation of the neighbors' embeddings, the node feature at the next time step is aggregated as a weighted sum of that at the current time step and its information changes. Lian Yu [15] uses a three-layer graph attention network (GAT) to generate node embeddings for each subgraph, and a random forest then performs classification with these embeddings as input. Temporal-GCN [1] learns temporal information with an additional long short-term memory (LSTM) layer, whose output is passed to two layers of TAGCN, a GCN variant, to obtain temporal node embeddings. A function is defined in [13] to assign weights to each graph in the sequence, so that the most recent graph has more influence on the final decision. TGAT [14] adopted the idea of position encoding from natural language processing and proposed an encoding function that generates time features for networks to learn sequential information.
Fig. 1. The transaction time stamps may result in different constructed graphs for a given time interval. Here, one day is used as the time interval in both (a) and (b). Notice that if the time stamps are ignored or one month is used as the time interval, the topologies of the two constructed graphs are the same.
In this paper, we propose a novel dynamic graph attention (DynGAT) network for the AML task. We transform the transaction data into a graph sequence consisting of snapshots from different time clips. The temporal patterns are then explicitly exposed in the graph sequence and can be modeled by a multi-head attention block on the node embeddings of the transaction graphs at each time interval. The node embeddings are computed by a graph attention network [11], which introduces an attention mechanism into the aggregation process and learns to assign different weights to the neighbors so as to focus on influential entities. Moreover, the order of the graphs in the sequence is computed as a time embedding via an encoding block and concatenated with the node embeddings. This strengthens the discriminative learning of deep representations for normal and illicit transactions.
The contributions of the paper can be summarized in two points:
1. The paper proposes a new method for extracting and applying temporal sequential information from dynamic transaction graphs in order to identify illegal clients. In particular, the experiments show that the proposed method improves performance on AML tasks compared with static and dynamic reference methods.
2. Instead of a traditional RNN-based method, the paper models temporal information explicitly through a time encoder and generates node embeddings for each subgraph. Studying the heatmap of the attention mechanism makes it possible to analyze which time period contributes most to recognizing illegal clients. The proposed model can therefore provide a degree of explainability for the regulatory review of financial institutions.
2 Method
2.1 An Overview of DynGAT
Fig. 2. An overview of the proposed DynGAT. For graphs at different time points, we employ a time encoder (TE), a continuous function $d(k)$ that maps time to a vector space, to explicitly extract the temporal order information of the transactions. This time embedding is then concatenated with the node embeddings $H^{(k)} = \{\mathbf{h}_1^{(k)}, \ldots, \mathbf{h}_i^{(k)}, \ldots\}$ obtained from a GAT encoder, where $\mathbf{h}_i^{(k)}$ is the GAT embedding of node $v_i$ for graph $G_k$ in the sequence. Finally, we use multi-head attention to capture temporal information within the graph sequences.
An overview of the proposed DynGAT is given in Fig. 2. Our method contains five main parts: construction of a graph sequence from the raw transaction records; node embeddings via GAT; time embeddings via a time encoder; multi-head self-attention to capture the temporal patterns evolving in the sequence of concatenated node and time embeddings; and finally a softmax classifier to predict whether a node (account) is suspicious. The key contribution of the model lies in the explicit modeling of temporal patterns on the transaction graph sequence. The learned representations of normal and suspicious nodes are more discriminative than those of existing methods, which improves prediction performance on the AML task.
2.2 Constructing the Graph Sequence from the Transaction Records
Given transaction records and client information records, we first define a static transaction graph $G = (V, E)$, where a vertex $v_i \in V$ is constructed from a client (either the transferor or the payee of a transaction) and a directed edge $e_k = (v_i, v_j) \in E$ represents a flow of funds from client $v_i$ to client $v_j$. Each transaction has a timestamp, and we use $t_k$ to denote the timestamp of transaction edge $e_k$. We record the largest timestamp value as $T_{\max}$ and the smallest as $T_{\min}$. To convert the static graph into a dynamic graph sequence, we use the following formula to determine the start time $t_{\mathrm{start}}^{(i)}$ and end time $t_{\mathrm{end}}^{(i)}$ of the $i$-th graph in the dynamic sequence:
$$t_{\mathrm{start}}^{(i)} = \frac{i \cdot (T_{\max} - T_{\min} + 1)}{n}, \qquad t_{\mathrm{end}}^{(i)} = t_{\mathrm{start}}^{(i+1)} - 1 \quad (1)$$
where $n$ is the number of time clips. We then define our dynamic graph sequence as $\mathcal{G} = \{G_0, G_1, \ldots, G_{n-1}, G_n\}$. Each graph $G_i$ in the sequence is expressed as follows:

$$G_i = (V, E_i), \qquad E_i = \{e_k \mid t_{\mathrm{start}}^{(i)} \le t_k \le t_{\mathrm{end}}^{(i)}\} \quad (2)$$
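The snapshot construction of Eqs. (1)–(2) can be sketched as follows, assuming integer timestamps. This is an illustration, not the authors' code, and the function name is hypothetical. Two choices are ours: time is measured from $T_{\min}$ so that the clips cover $[T_{\min}, T_{\max}]$, and clip boundaries are rounded by integer division, so every edge falls into exactly one of the $n$ snapshots.

```python
def build_graph_sequence(edges, n_clips):
    """Split timestamped transactions into n_clips snapshot edge lists.

    edges: list of (src, dst, t) tuples with integer timestamps t.
    Returns (nodes, snapshots): the node set shared by all snapshots and,
    for each clip i, the list of (src, dst) edges whose timestamp falls in it.
    """
    t_min = min(t for _, _, t in edges)
    t_max = max(t for _, _, t in edges)
    span = t_max - t_min + 1  # number of time units covered
    # Every snapshot uses the same node set; only the edge sets differ.
    nodes = {u for src, dst, _ in edges for u in (src, dst)}
    snapshots = [[] for _ in range(n_clips)]
    for src, dst, t in edges:
        # Clip index floor((t - T_min) * n / span), computed with integers
        # so that boundary timestamps are not misassigned by rounding.
        i = (t - t_min) * n_clips // span
        snapshots[i].append((src, dst))
    return nodes, snapshots
```

With Elliptic's 49 time steps and five clips, as in Sect. 3.1, this assignment gives four clips of 10 time steps and one of 9.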
Generally, each graph $G_i$ contains a different subset of the transaction edges, namely those that occur in the specified time period, but all graphs share the same set of nodes and attributes. This is reasonable because bank accounts are relatively stable: registering new accounts or closing old ones is rare, so the accounts and their attributes can be assumed unchanged over a given period.
2.3 Computing the Node Embedding and the Time Embedding
We employ a GAT as a feature extractor on the graph sequence: a node embedding is computed by the GAT for each graph in the sequence. Like most GNN variants, GAT is implemented by a message passing process and an aggregation function; see the original paper [11] for the network details. We adapt the standard GAT to the AML task by integrating the features of the transaction edges. Specifically, after graph construction, each account $v_i$ has a set of neighboring accounts $N_i$ in a given graph $G_k$. To calculate the node embedding of $v_i$, GAT first calculates the similarity $s_{ij}$ between $v_i$ and each of its transaction counterparties $v_j$ in the graph, and then normalizes the similarity to obtain the attention coefficient $\alpha_{ij}$ for each neighbor. In the AML task, the edges of a transaction graph often carry additional features, e.g., the transaction amount. Hence, the original GAT equation is generalized to integrate the edge features $\mathbf{f}_{ij}$ [6] as follows:

$$s_{ij} = \mathrm{softmax}\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}\left[\mathbf{W}\mathbf{h}_i \,\|\, \mathbf{W}\mathbf{f}_{ij} \,\|\, \mathbf{W}\mathbf{h}_j\right]\right)\right) \quad (3)$$

where $\mathbf{a}$ and $\mathbf{W}$ are learnable parameters, the operator "$\|$" denotes vector concatenation, and $\mathbf{h}_i$ is the embedding of node $v_i$, initialized with the raw feature vector of the node (e.g., the age of the account).
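The edge-aware attention of Eq. (3) and the weighted-sum update can be sketched for a single node and a single head. This is an illustration, not the authors' implementation, and all names are hypothetical. We assume a separate projection $\mathbf{W}_f$ for edge features: Eq. (3) applies the same $\mathbf{W}$ to nodes and edges, which only works when both have the same dimension. The softmax is taken over the neighbourhood $N_i$, and self-loops and the output nonlinearity of the original GAT are omitted.

```python
import math


def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def leaky_relu(x, slope=0.2):
    return x if x > 0 else slope * x


def edge_gat_update(h, edge_feats, neighbors, i, W, W_f, a):
    """One attention head for node i with edge features.

    h: dict node -> feature vector; edge_feats: dict (i, j) -> edge features;
    neighbors: dict node -> list of neighbours N_i.
    W projects node features, W_f (assumed) projects edge features, and
    a is the attention vector of length 3 * output dimension.
    Returns (attention coefficients alpha_ij, updated embedding of node i).
    """
    Wh_i = matvec(W, h[i])
    Wh = {j: matvec(W, h[j]) for j in neighbors[i]}
    scores = []
    for j in neighbors[i]:
        # [W h_i || W_f f_ij || W h_j] as list concatenation
        z = Wh_i + matvec(W_f, edge_feats[(i, j)]) + Wh[j]
        scores.append(leaky_relu(sum(ak * zk for ak, zk in zip(a, z))))
    # Softmax over the neighbourhood, shifted by the max for stability
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    alphas = [e / total for e in exps]
    # Updated embedding: attention-weighted sum of projected neighbours
    new_h = [sum(alpha * Wh[j][k] for alpha, j in zip(alphas, neighbors[i]))
             for k in range(len(Wh_i))]
    return alphas, new_h
```

In a two-neighbour example where the first neighbour's raw score is 2 and the second's is 1, the coefficients are e/(e + 1) and 1/(e + 1), so the transaction with the higher score dominates the update.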
The node embedding $\mathbf{h}_i$ is updated as the weighted sum of the neighbors' representations $\mathbf{h}_j$, where the weights are the attention coefficients $\alpha_{ij}$ computed from $s_{ij}$ in Eq. (3). The node embeddings computed this way can encode the topological patterns of money launderers, but they do not incorporate the time stamps. One remedy would be to use the time stamps directly as features, but this could harm generalization because the model may overfit to the time stamps. Instead, we develop a time encoder module that converts the one-dimensional time sequence into time embeddings, a trick very similar to position encoding in natural language processing. Our time encoding function $d(t)$ is adopted from [14]. Its input is the time sequence $\{0, 1, \ldots, n\}$; the encoder maps each time step to a $d$-dimensional vector and outputs $\{d(0), d(1), \ldots, d(n)\}$. Finally, we fuse the graph structure information and the transaction time information by concatenating the node embedding with the time embedding:

$$\tilde{\mathbf{h}}_i^{(k)} = \left[\mathbf{h}_i^{(k)} \,\|\, d(k)\right] \quad (4)$$

where $\tilde{\mathbf{h}}_i^{(k)}$ is the final deep representation of node $v_i$ for graph $G_k$ in the sequence. It contains not only the topological patterns of the laundering behaviors but also the timing features of the transactions, which are important for detecting money launderers.
2.4 Modeling the Temporal Information by the Multi-head Self Attention Block
The final step is to integrate the embeddings of node $i$ from different time steps. The embedding set of node $i$ is denoted $\{\tilde{\mathbf{h}}_i^{(k)} \mid k \in \{0, 1, \ldots, n\}\}$. A common approach for dynamic sequences is to train a classifier that gives a risk score for every hidden representation and then apply average or max pooling to integrate the scores. The drawback of this approach is that the sequential information across time steps is lost. For example, an account appears normal if it receives money from multiple persons, or if it transfers money to another account; but if it receives multiple transactions and then transfers money to a certain account, it becomes very suspicious. Self-attention was introduced to identify the relevant tokens within a sentence and, combined with position encoding, is very effective at modeling the positional relations between tokens. In the time-sequential scenario, we can likewise apply self-attention to learn the sequential information of the temporal hidden representations across time steps. The inputs of self-attention are three matrices $Q$, $K$, $V$. To compute them, we reorganize the embedding set into an embedding matrix $Z$ whose $k$-th column is $(\tilde{\mathbf{h}}_i^{(k)})^{\top}$. Multiplying $Z$ by the projection matrix $W_q$ gives the query matrix $Q = W_q Z$; the key matrix $K$ and the value matrix $V$ are obtained in the same manner. Let $d$ be the length of $\tilde{\mathbf{h}}_i^{(k)}$. The integrated representation matrix $F_i$ of node $i$ is then calculated as

$$F_i = \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d}}\right)V \quad (5)$$
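The temporal part of DynGAT (time encoding, the concatenation of Eq. (4), and the attention of Eq. (5)) can be sketched for one node and a single head. This is not the authors' code; the names are hypothetical and three assumptions are ours. The time encoding uses fixed cosine frequencies $10^{-k}$ in the spirit of TGAT's functional encoding [14], whereas TGAT learns its frequencies. The snapshot embeddings are stored as rows of $Z$, so the projections are written $Z W_q$ instead of $W_q Z$. The $d$ of Eq. (5) is taken as the key dimension. A multi-head version would repeat this with separate $W_q$, $W_k$, $W_v$ per head, concatenate the outputs and apply a final projection.

```python
import math


def time_encoding(t, dim, base=10.0):
    """Cosine time encoding d(t) with fixed (assumed) frequencies base**(-k)."""
    return [math.sqrt(1.0 / dim) * math.cos(t * base ** (-k)) for k in range(dim)]


def matmul(A, B):
    """Product of two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*B)] for row in A]


def softmax(xs):
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def temporal_self_attention(node_embs, Wq, Wk, Wv, time_dim):
    """node_embs[k] is one node's GAT embedding in snapshot G_k.

    Returns (F, A): one attended vector per snapshot, and the attention
    matrix whose row k shows how snapshot k weighs every snapshot.
    """
    # Eq. (4): concatenate each snapshot embedding with its time encoding d(k)
    Z = [list(h) + time_encoding(k, time_dim) for k, h in enumerate(node_embs)]
    Q, K, V = matmul(Z, Wq), matmul(Z, Wk), matmul(Z, Wv)
    scale = math.sqrt(len(K[0]))
    # Eq. (5): row-wise softmax(Q K^T / sqrt(d)), then multiply by V
    A = [softmax([sum(q * kv for q, kv in zip(q_row, k_row)) / scale for k_row in K])
         for q_row in Q]
    return matmul(A, V), A
```

The attention matrix `A` is the kind of quantity behind the attention heat-map analysis mentioned among the contributions.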
To stabilize the output, a multi-head scheme is usually applied: the self-attention function is computed several times independently, the outputs are concatenated, and a projection matrix restores the shape of the output matrix.
2.5 Training the Model
The loss function is the binary cross entropy. In the AML task, we care more about the minority (illegal) class. To overcome the class imbalance, we assign additional weights in the loss function: with a higher weight attached to the illegal class, the model is forced to focus on the illegal samples.
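A class-weighted binary cross entropy of the kind described above can be sketched as follows. This is an illustration, and the function name is hypothetical. Sect. 3.1 uses the weight parameter [0.7, 0.3] without saying which entry applies to which class. This sketch therefore exposes both weights as arguments and, following the text above, gives the larger weight to the illegal class in the example. Averaging over samples is our choice; some libraries instead normalise by the sum of the weights.

```python
import math


def weighted_bce(probs, labels, w_pos, w_neg, eps=1e-12):
    """Class-weighted binary cross entropy, averaged over samples.

    probs: predicted probability that each node is illegal (label 1).
    w_pos weights the illegal samples, w_neg the legal ones.
    """
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        if y == 1:
            total += -w_pos * math.log(p)
        else:
            total += -w_neg * math.log(1.0 - p)
    return total / len(probs)
```

With `w_pos=0.7` and `w_neg=0.3`, predicting 0.1 for an illegal node costs more than predicting 0.9 for a legal node, even though both predictions are equally wrong.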
3 Experiment
3.1 Experiment Setting
The AML task can be treated as a binary classification problem in which each node is classified as either legal or illegal. In our experiment, a 3-layer MLP is used as the classifier. The dynamic graph sequence is organized according to the definition in Sect. 2.2, and we set the number of clips to 5. We set the number of GAT layers to 2. To address the class imbalance problem, we use [0.7, 0.3] as the weight parameter of the loss function. Our models are trained on a Tesla V100 GPU with an Intel(R) Xeon(R) Gold 6130 CPU @ 2.10 GHz.
3.2 Dataset
• Bitcoin-Elliptic
The Bitcoin-Elliptic dataset records bitcoin transaction events in 49 time steps. Each node represents a transaction and each edge is a flow of payment between transactions. 2% of all nodes are labeled as illegal, 21% as legal, and the rest are unclassified. In a real banking system, clients that have not been identified by experts are treated as normal. We follow this routine and view unclassified nodes as legal, resulting in 4545 illegal nodes and 199224 legal nodes.
• AMLSim-10K
AMLSim is a synthetic data generator developed by IBM. It simulates a real banking environment and provides synthetic transaction records and clients' KYC (Know Your Customer) information according to revealed money laundering patterns. We use the default setting to generate a dataset with 12043 clients in total, of which 737 are labeled as suspects, and name it AMLSim-10K (Table 1).
Table 1. Overview of data sets

| Dataset | #Nodes | #Node features | #Edges | #Edge features |
|---|---|---|---|---|
| Bitcoin-Elliptic | 203769 | 166 | 234355 | 0 |
| AMLSim-10K | 12043 | 5 | 197905 | 2 |
3.3 Baseline Methods

We first choose static GAT to test the performance without any temporal information. Two implicit temporal models, EvolveGCN and GCN-GRU, are also selected as baselines. EvolveGCN uses a GCN to generate node embeddings on each subgraph, but it does not update the GCN weight matrix through backpropagation. Instead, an RNN is trained to produce the weight matrix for the next time step, taking the current weight matrix as input. GCN-GRU is a common method for dynamic graph sequences. It uses a GRU cell to connect the outputs of the time steps and takes the output of the final step as the final output. To analyze the effect of the time encoder, we create DynGAT-T by removing the time encoder module from DynGAT.

3.4 Performance

The results in Table 2 show that our method outperforms the baseline methods, which demonstrates its effectiveness. On the Elliptic data set, our method achieves the best performance on all four metrics. On the AMLSim-10K data set, our method achieves the best precision and ROC-AUC. Static GAT has an extremely high recall on AMLSim-10K because it tends to classify almost every account as illegal. This means it does not learn a pattern that separates the two classes. Our method, in contrast, maintains a better balance between precision and recall.

Table 2. Experiment results on the Elliptic and AMLSim-10K data sets

| Method | Elliptic precision | Elliptic recall | Elliptic F1 | Elliptic ROC-AUC | AMLSim-10K precision | AMLSim-10K recall | AMLSim-10K F1 | AMLSim-10K ROC-AUC |
|---|---|---|---|---|---|---|---|---|
| Static GAT | 0.603 | 0.494 | 0.543 | 0.952 | 0.056 | 0.985 | 0.105 | 0.379 |
| GCN-GRU | 0.229 | 0.413 | 0.294 | 0.777 | 0.137 | 0.311 | 0.190 | 0.560 |
| EvolveGCN | 0.235 | 0.161 | 0.191 | 0.769 | 0.097 | 0.322 | 0.149 | 0.518 |
| DynGAT-T | 0.516 | 0.547 | 0.531 | 0.923 | 0.284 | 0.470 | 0.354 | 0.653 |
| DynGAT | 0.605 | 0.654 | 0.628 | 0.962 | 0.421 | 0.242 | 0.308 | 0.695 |
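The F1 column of Table 2 is the harmonic mean of the precision and recall columns. Recomputing it from the rounded table entries is a quick consistency check, and differences of about 0.001 come from rounding. The sketch also makes the static-GAT observation concrete. A classifier that flags every client as illegal has recall 1 and precision equal to the share of suspects. Over the whole AMLSim-10K dataset that share is 737/12043 ≈ 0.061. The share in the actual test split is not reported and may differ slightly.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (0 if both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# (precision, recall, F1 as reported in Table 2)
table2 = {
    ("DynGAT", "Elliptic"): (0.605, 0.654, 0.628),
    ("Static GAT", "AMLSim-10K"): (0.056, 0.985, 0.105),
    ("DynGAT", "AMLSim-10K"): (0.421, 0.242, 0.308),
}
for (method, dataset), (p, r, reported) in table2.items():
    print(f"{method} on {dataset}: recomputed F1 = {f1_score(p, r):.3f}, reported {reported}")

# Flagging every AMLSim-10K client as illegal gives recall 1 and a precision
# equal to the fraction of suspects in the dataset.
all_illegal_precision = 737 / 12043
print(f"precision of 'everyone is illegal': {all_illegal_precision:.3f}")
```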
The PR curves in Fig. 3 show that, on the Elliptic data set, GAT-based methods generally perform better than GCN-based methods. Static GAT even performs better than the other temporal baseline methods. We attribute this to the sparsity of the graph constructed from Elliptic, where the average degree of each node is 1.1. When the graph is cut into different subgraphs, each subgraph becomes even sparser, and the GCN-based methods fail to extract useful patterns from each subgraph. DynGAT-T cannot use time encoding to learn sequential patterns, and it sees less global information than static GAT. As a result, static GAT is even able to perform better than DynGAT-T.
A Dynamic Graph Convolutional Network
Fig. 3. PR Curve of Elliptic data set (left) and AMLSim-10K (right)
AMLSim-10K is a more connected graph than Elliptic. All methods suffer a loss of performance on AMLSim-10K, while DynGAT and DynGAT-T stand out. Static GAT is not able to learn patterns from such a complicated graph, because information from different time periods is mixed up. For example, suppose an account receives transactions from 1 account and then from 11 accounts on two consecutive days. This abnormal increase in transaction count looks suspicious. On average, however, the account has 5 transactions per day, which looks normal. In such cases, the temporal models perform better.
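The example above can be made concrete by counting one account's transactions per time window instead of over the whole period. The sketch uses a hypothetical account with 1 transaction on day 0 and 11 on day 1, binned into equal-width windows. Equal-width windows are an assumption for illustration, not necessarily the clip definition of Sect. 2.1. The per-window counts expose an 11-fold jump, while their mean alone looks unremarkable.

```python
import numpy as np

def counts_per_clip(timestamps, t_start, t_end, num_clips):
    """Count transactions falling into each of `num_clips` equal-width time windows."""
    edges = np.linspace(t_start, t_end, num_clips + 1)
    counts, _ = np.histogram(timestamps, bins=edges)
    return counts

# Hypothetical account: 1 transaction on day 0, 11 transactions on day 1 (times in days).
timestamps = [0.5] + [1.5] * 11
per_day = counts_per_clip(timestamps, 0.0, 2.0, num_clips=2)
print(per_day.tolist(), per_day.max() / per_day.min(), per_day.mean())
```

A static graph only sees the aggregate. A sequence of per-clip graphs keeps the jump visible to the model.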
4 Conclusion

This paper presents a new dynamic graph learning model, DynGAT, for the anti-money laundering task. We expose temporal information by constructing a sequence of dynamic graphs from the transaction records. A time encoder computes temporal hidden representations for the nodes in the dynamic graphs. These representations are fused with the node embeddings computed by GAT, which capture the structural behavior of money laundering activities. We devise a multi-head attention mechanism to further capture the sequential information in the temporal hidden representations from different time steps. The user representations learned by our method discriminate better between normal users and money launderers. The proposed method explicitly encodes the time variable as part of the node features and applies attention mechanisms in both the graph encoder and the fusion process. This makes it possible to explore the importance and contribution of time in AML identification by analyzing the attention heat maps. Experiments show the effectiveness of our method, which outperforms the existing methods in detecting suspicious users or accounts.
Acknowledgement. This work was supported by Shanghai Municipal Science and Technology Major Project (2021SHZDZX0102), and ICBC (grant no. 001010000520220106). Shikui Tu and Lei Xu are co-corresponding authors.
Bearing Fault Detection Based on Graph Cyclostationary Signal Analysis and Convolutional Neural Network
Cong Chen1 and Hui Li1,2(✉)
1 School of Mechanical Engineering, Tianjin University of Technology and Education, Tianjin 300222, China
2 Tianjin Key Laboratory of Intelligent Robot Technology and Application, Tianjin 300222, China
Abstract. Traditional cyclostationary signal analysis ignores the relationships between individual data points and has difficulty processing data with an irregular spatial structure, such as graph signals. To address this, a graph cyclostationary signal preprocessing technique is proposed for bearing fault diagnosis. The method combines graph signal processing, which handles graph signals effectively, with cyclostationary signal analysis, which suppresses Gaussian noise effectively. First, the vibration signal is converted into a graph signal. The graph cyclic autocorrelation function (GCAF) and the graph spectral correlation density (GSCD) are then calculated with the proposed method, and both effectively demodulate a simulated signal. Second, the proposed method is combined with a convolutional neural network and verified experimentally on the rolling bearing dataset of Case Western Reserve University. Finally, comparative experiments show that the convolutional neural network achieves higher accuracy when graph cyclostationary signal analysis is used as the data preprocessing tool.
Keywords: Bearing · Fault Diagnosis · Cyclostationary Signal Analysis · Graph Signal Processing · Convolutional Neural Network
1 Introduction

With the rapid development of modern industrialization, production equipment is becoming more intelligent, larger in scale, and more complex, which increases the uncertainty of its safe operation. Rotating machinery, especially induction motors, plays an important role in industrial systems and is key equipment in many fields, such as the petrochemical, aerospace, and rail transit industries. These rotating machines consist of components such as stators, rotors, shafts, and bearings. The bearing, which guides and supports the shaft, is the most important mechanical component in rotating machinery. Moreover, research has shown that many mechanical failures in rotating machinery are caused by bearing failures. Therefore, accurate bearing fault diagnosis is extremely important [1].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 503–513, 2023. https://doi.org/10.1007/978-981-99-4761-4_43
C. Chen and H. Li
An effective bearing fault diagnosis program can improve the operating efficiency of a machine, reduce maintenance costs, and extend its service life [2]. Vibration signal processing is the most commonly used approach to bearing fault diagnosis. Without changing the operating state of the machine, changes in its vibration signals indicate changes in its health state. Traditional frequency-domain fault diagnosis methods, such as spectral kurtosis and cepstrum analysis, assume that the vibration signal is linear and stationary [3]. However, when a rolling bearing fails, its signal is mostly non-stationary. In this case, time-frequency analysis methods such as the Wigner-Ville distribution and the wavelet transform [4, 5], as well as cyclostationary signal methods, can be used to analyze non-stationary signals. However, the Wigner-Ville distribution suffers from cross-interference terms, and the wavelet transform requires selecting an optimal wavelet. Cyclostationary signals are a special class of non-stationary signals whose statistical characteristics vary periodically with time. Cyclostationary theory was proposed in 1950 and became widely used in the early 21st century [6, 7]. However, the above methods are designed for regularly structured data in Euclidean space and ignore the relationships between individual data points. Because actual signals are often irregular, these methods have limitations when processing them. In recent years, graph signal processing methods have gradually attracted the attention of researchers worldwide [8]. The main research object of graph signal theory is the graph structure, which is composed of vertices and edges and describes specific relationships between individuals. Graph signal processing mainly analyzes the representation matrices of a graph, using the eigenvalues and eigenvectors of the Laplacian matrix to perform spectral analysis of the graph [9].
Graph signal processing is now widely used in many fields, such as image recognition, smart grids, and social networks [10, 11]. However, extending traditional cyclostationary signal processing methods to graph signal processing has not yet been studied. To address this, this paper extends traditional cyclostationary signal analysis to graph signal processing and studies the basic theory and methods of graph cyclostationary signal processing. It then proposes a bearing fault diagnosis method based on graph cyclostationary signal analysis. The graph cyclostationary signal method is combined with convolutional neural networks to identify and classify fault types automatically, with the aim of enriching and expanding the theory and methods of fault diagnosis.
Bearing Fault Detection Based on Graph Cyclostationary Signal
2 Basic Theory

2.1 Cyclostationary Signal Analysis

According to the order of their periodic statistics, cyclostationary signals are mainly divided into first-order, second-order, and higher-order cyclostationary signals. This article is mainly based on second-order cyclostationary theory, which includes the cyclic autocorrelation function (CAF) and the spectral correlation density function (SCDF).

Cyclic Autocorrelation Function. Assume the signal x(t) has second-order time-varying statistical characteristics. Its quadratic time-varying statistic is defined as

$$R_x(t, \tau) = E\{x(t)\, x^*(t - \tau)\} \tag{1}$$
where * denotes the complex conjugate and τ is the time lag. The cyclic autocorrelation function is defined as the Fourier coefficient, with respect to time t, of this quadratic time-varying statistic:

$$R_x^{\alpha}(\tau) = \left\langle x(t)\, x^*(t - \tau)\, e^{-j2\pi\alpha t} \right\rangle_t \tag{2}$$

where ⟨·⟩_t denotes the time average. In general, a frequency α ≠ 0 at which R_x^α(τ) is nonzero is referred to as a cyclic frequency.

Spectral Correlation Density Function. According to the cyclic Wiener-Khinchin relation, the spectral correlation density function and the cyclic autocorrelation function form a Fourier transform pair, as shown in Eq. (3):

$$S_x(\alpha, f) = \int_{-\infty}^{+\infty} R_x^{\alpha}(\tau)\, e^{-j2\pi f\tau}\, d\tau \tag{3}$$

where f is the spectral frequency. When α = 0, the CAF degenerates into the traditional autocorrelation function and the SCDF degenerates into the power spectral density function, which can be used to process stationary signals. When α ≠ 0, the two functions describe the cyclostationary characteristics of a signal. The set of all α satisfying this condition is called the cyclic frequency set.
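Eq. (2) can be estimated from samples by replacing the time average with a mean over n. The sketch below assumes this simple estimator, with τ = lag/fs and no windowing; it is not the authors' code. It is applied to the amplitude-modulated test signal used later in Eq. (14), with fc = 250 Hz, f0 = 30 Hz, fs = 2000 Hz and N = 600. At zero lag it gives |R^α| = 0.75, 0.5, 0.125 and 0.375 at α = 0, 30, 60 and 500 Hz. These are the DC term, the modulation frequency, its second harmonic, and twice the carrier frequency.

```python
import numpy as np

def cyclic_autocorrelation(x, alpha, lag, fs):
    """Estimate R_x^alpha(tau) of Eq. (2) for a sampled signal, with tau = lag / fs.

    The time average <x(t) x*(t - tau) e^{-j 2 pi alpha t}>_t is replaced by a
    mean over the sample indices n for which x[n - lag] exists.
    """
    x = np.asarray(x, dtype=complex)
    n = np.arange(lag, len(x))
    lag_product = x[n] * np.conj(x[n - lag])
    return np.mean(lag_product * np.exp(-2j * np.pi * alpha * n / fs))

fs, N = 2000, 600
t = np.arange(N) / fs
y = (1 + np.cos(2 * np.pi * 30 * t)) * np.cos(2 * np.pi * 250 * t)   # AM test signal
for alpha in (0, 30, 60, 500):
    print(alpha, round(abs(cyclic_autocorrelation(y, alpha, 0, fs)), 4))
```

The values follow from expanding y² into cosines. All of these frequencies complete a whole number of cycles in 600 samples, so the estimator reproduces them exactly up to rounding.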
2.2 Graph Signal Processing

A graph is an abstract data structure that describes the geometric relationships between data. Data from practical applications such as social networks, transportation networks, and the internet can be represented as graph structures to simplify analysis, and graphs play an important role in the analysis of high-dimensional data. A graph can be represented by a matrix, such as its adjacency matrix or Laplacian matrix. Therefore, the rolling bearing fault diagnosis method based on graph signal processing theory mainly analyzes the graph structure matrix that represents the relationships in the vibration data, rather than the data and the graph themselves.

Representation Matrix of Graph Structure. Consider the undirected, unweighted path graph G = (V, E, A), and assume a time series signal of length N as shown in Eq. (4):

$$x(t) = [x(1), x(2), \cdots, x(N)]^T \tag{4}$$
where V is a finite set of vertices (|V| = N), E is the set of edges, and A is the N × N adjacency matrix. If an edge e = (i, j) connects vertices i and j, then A_ij = 1; otherwise A_ij = 0. In the path graph, each sample x(n) is connected to its neighbours x(n − 1) and x(n + 1). The graph Laplacian matrix is defined as

$$L = D - A \tag{5}$$
where D is the diagonal degree matrix whose nth diagonal entry is the degree of vertex n, d_n = Σ_{m≠n} A_mn.

Graph Fourier Transform (GFT) and Graph Translation. The GFT of a graph signal x uses the eigenvectors of L as basis functions. Performing an eigenvalue decomposition of L gives Eq. (6):

$$L = U \Lambda U^{-1} \tag{6}$$

where Λ = diag[λ1, λ2, λ3, …, λN] and λi is the ith eigenvalue of L. U = [u1, u2, u3, …, uN] is the eigenvector matrix, whose ith column ui is the eigenvector corresponding to λi. Because L is real and symmetric, U can be chosen orthogonal, so U^{-1} = U^T. The GFT X of the graph signal x is defined as

$$X = U^{-1} x \tag{7}$$

The inverse graph Fourier transform (IGFT) is defined as

$$x = U X \tag{8}$$
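For the path graph of Eq. (5), the Laplacian, GFT and IGFT can be sketched in a few lines of NumPy. The sketch uses `numpy.linalg.eigh`, which returns eigenvalues in ascending order and an orthonormal U, so U^{-1} = U^T. Eigenvector signs are arbitrary, so the sketch fixes them so that the first component of each eigenvector is positive. That sign convention is our assumption. For a path graph the eigenvalues are 2 − 2cos(πk/N), k = 0, …, N − 1, and the GFT basis coincides with the DCT-II basis; the test checks both facts.

```python
import numpy as np

def path_graph_laplacian(n):
    """L = D - A (Eq. 5) for an unweighted, undirected path graph on n vertices."""
    a = np.zeros((n, n))
    i = np.arange(n - 1)
    a[i, i + 1] = 1.0            # consecutive samples are joined by an edge
    a[i + 1, i] = 1.0
    return np.diag(a.sum(axis=1)) - a

def graph_fourier_basis(n):
    """Eigen-decomposition L = U diag(lam) U^T, eigenvalues in ascending order."""
    lam, u = np.linalg.eigh(path_graph_laplacian(n))
    u = u * np.sign(u[0])        # eigh fixes signs arbitrarily; make each u_k(1) positive
    return lam, u

n = 8
x = np.random.default_rng(0).standard_normal(n)
lam, u = graph_fourier_basis(n)
X = u.T @ x        # GFT, Eq. (7), using U^{-1} = U^T
x_back = u @ X     # IGFT, Eq. (8)
print(np.allclose(x_back, x), np.round(lam, 3))
```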
Assuming the graph signal x is translated to all vertices, the translation matrix is defined as

$$S_M = \sqrt{N}\, U \left( X_M \ast U^T \right) \tag{9}$$

where ∗ denotes element-wise matrix multiplication and X_M is the graph Fourier transform matrix, defined as

$$X_M = \begin{bmatrix} X(\lambda_1) & X(\lambda_1) & \cdots & X(\lambda_1) \\ X(\lambda_2) & X(\lambda_2) & \cdots & X(\lambda_2) \\ \vdots & \vdots & \ddots & \vdots \\ X(\lambda_N) & X(\lambda_N) & \cdots & X(\lambda_N) \end{bmatrix} \tag{10}$$

where X(λi), i = 1, 2, …, N, are the elements of X.
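Because every row of X_M repeats a single GFT coefficient, the element-wise product X_M ∗ U^T just scales row k of U^T by X(λk). Eq. (9) therefore equals √N U diag(X) U^T. Column n of S_M is then √N Σ_k X(λk) uk(n) uk, the translation of x to vertex n. The sketch assumes Eq. (9) is grouped as U(X_M ∗ U^T), which is the reading that gives this translation. X(λk) = uk^T x, so each term contains uk three times. Flipping the sign of an eigenvector therefore changes S_M, and the sketch fixes uk(1) > 0 as an assumed convention.

```python
import numpy as np

def path_graph_laplacian(n):
    """L = D - A for an unweighted, undirected path graph on n vertices."""
    a = np.zeros((n, n))
    i = np.arange(n - 1)
    a[i, i + 1] = a[i + 1, i] = 1.0
    return np.diag(a.sum(axis=1)) - a

def translation_matrix(x):
    """S_M = sqrt(N) * U (X_M * U^T) from Eqs. (9)-(10); '*' is element-wise."""
    n = len(x)
    _, u = np.linalg.eigh(path_graph_laplacian(n))
    u = u * np.sign(u[0])                    # assumed sign convention (S_M depends on it)
    X = u.T @ x                              # GFT coefficients X(lambda_k)
    X_M = np.tile(X[:, None], (1, n))        # Eq. (10): row k repeats X(lambda_k)
    return np.sqrt(n) * u @ (X_M * u.T), u, X

x = np.random.default_rng(1).standard_normal(6)
S_M, u, X = translation_matrix(x)
# Row-wise scaling by X_M is the same as left-multiplying U^T by diag(X).
print(np.allclose(S_M, np.sqrt(len(x)) * u @ np.diag(X) @ u.T))
```

The closed form also shows that S_M is symmetric, which the test checks.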
2.3 Graph Cyclostationary Signal Analysis

Extending traditional cyclostationary theory to the field of graph signal processing, a graph cyclostationary signal analysis method is proposed. The graph correlation matrix of a time-varying graph is defined as

$$F_X = S_M \ast \mathrm{Flip}(S_M) \tag{11}$$

where ∗ is element-wise matrix multiplication, i.e., the multiplication of corresponding matrix elements, and Flip(·) reverses the order of the elements in each column of a matrix. In traditional cyclostationary theory, the cyclic autocorrelation function is the Fourier coefficient of the signal's quadratic time-varying statistic. By analogy, the GCAF in the graph domain is defined as the transpose of the GFT of the time-varying graph correlation matrix:

$$\mathrm{GCAF} = \left( U^{-1} F_X \right)^T \tag{12}$$

The GSCD is the graph Fourier transform of the graph cyclic autocorrelation function, defined as

$$\mathrm{GSCD} = U^{-1}\, \mathrm{GCAF} \tag{13}$$
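Eqs. (9)–(13) chain together into a short pipeline. The sketch below applies it to the simulated signal of Eq. (14). It makes three assumptions: a path graph, U^{-1} = U^T with eigenvector signs fixed so that uk(1) > 0, and Flip implemented as `numpy.flipud`, which reverses each column. It is an illustrative reading of the equations, not the authors' implementation.

```python
import numpy as np

def path_graph_laplacian(n):
    """L = D - A for an unweighted, undirected path graph on n vertices."""
    a = np.zeros((n, n))
    i = np.arange(n - 1)
    a[i, i + 1] = a[i + 1, i] = 1.0
    return np.diag(a.sum(axis=1)) - a

def graph_cyclostationary(x):
    """Return (GCAF, GSCD) following Eqs. (9)-(13) on a path graph."""
    n = len(x)
    _, u = np.linalg.eigh(path_graph_laplacian(n))
    u = u * np.sign(u[0])                              # assumed sign convention: u_k(1) > 0
    X = u.T @ x                                        # GFT, Eq. (7)
    s_m = np.sqrt(n) * u @ (X[:, None] * u.T)          # translation matrix, Eqs. (9)-(10)
    f_x = s_m * np.flipud(s_m)                         # Eq. (11): Flip reverses each column
    gcaf = (u.T @ f_x).T                               # Eq. (12)
    gscd = u.T @ gcaf                                  # Eq. (13)
    return gcaf, gscd

fs, N = 2000, 600
t = np.arange(N) / fs
y = (1 + np.cos(2 * np.pi * 30 * t)) * np.cos(2 * np.pi * 250 * t)   # Eq. (14)
gcaf, gscd = graph_cyclostationary(y)
print(gcaf.shape, gscd.shape)
```

Both outputs are N × N matrices. The 600 × 600 GCAF is the kind of two-dimensional map that Sect. 4 feeds into the CNN.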
3 Simulative Signal Analysis

Consider the simulated signal given in Eq. (14):

$$y(t) = \left[ 1 + \cos(2\pi \cdot 30 t) \right] \cos(2\pi \cdot 250 t) \tag{14}$$

where the carrier frequency is fc = 250 Hz, the modulation frequency is f0 = 30 Hz, the sampling frequency is fs = 2000 Hz, the sampling duration is 0.3 s, and the total sampling length is N = 600. Figure 1(a) shows that the high-frequency carrier is modulated by the low-frequency signal. Figure 1(b) shows that the spectrum of y(t) is centered on the carrier frequency fc = 250 Hz. It has a sideband structure spaced by the modulation frequency f0 = 30 Hz, at fc ± f0 = 220 Hz and 280 Hz.
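By the identity cos a · cos b = ½[cos(a − b) + cos(a + b)], Eq. (14) equals cos(2π·250t) + 0.5 cos(2π·220t) + 0.5 cos(2π·280t). The FFT in Fig. 1(b) should therefore show exactly three lines. The sketch reproduces that spectrum with NumPy's real FFT. With N = 600 and fs = 2000 Hz the bin spacing is 10/3 Hz, so all three lines fall on exact bins, and the carrier has magnitude 0.5 against 0.25 for each sideband.

```python
import numpy as np

fs, N = 2000, 600                      # sampling rate (Hz) and length: 0.3 s of data
t = np.arange(N) / fs
y = (1 + np.cos(2 * np.pi * 30 * t)) * np.cos(2 * np.pi * 250 * t)    # Eq. (14)

spectrum = np.abs(np.fft.rfft(y)) / N      # one-sided magnitude spectrum
freqs = np.fft.rfftfreq(N, d=1 / fs)       # bin spacing fs / N = 10/3 Hz
peaks = np.sort(freqs[np.argsort(spectrum)[-3:]])
print(np.round(peaks, 1))
```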
(a) Simulated signal
(b) FFT of simulated signal
Fig. 1. Time domain waveform of simulative signal and its FFT
The three-dimensional plot of the GCAF in Fig. 2(a) shows that the GCAF separates the carrier signal and the modulating signal into two different frequency bands. Figure 2(b) shows the profile of the GCAF along the cyclic frequency axis. In the low-frequency part of the cyclic frequency, the peaks lie at the modulation frequency and its second harmonic. The high-frequency part contains both twice the carrier frequency and the modulation frequency.
(a) Three-dimensional plot of GCAF
(b) Outline of GCAF
Fig. 2. Three-dimensional diagram of GCAF and axial direction profile of cyclic frequency
Figure 3(a) shows that the GSCD also has peaks in different frequency bands. Figure 3(b) shows the profile of the GSCD along the cyclic frequency axis, indicating that the GSCD also demodulates the signal well.
(a) Three-dimensional graph of GSCD
(b) Outline of GSCD
Fig. 3. Three-dimensional diagram of GSCD and axial direction profile of cyclic frequency
4 Experimental Validation of Bearing Datasets

To further verify the reliability of the proposed method, the graph cyclostationary signal method is combined with a convolutional neural network to perform fault diagnosis and automatic classification on the rolling bearing dataset of Case Western Reserve University. The experimental flow chart is shown in Fig. 4.
Fig. 4. Experimental flow chart
4.1 Experimental Setup

The rolling bearing dataset of Case Western Reserve University includes vibration signals under four operating conditions: 0 hp, 1 hp, 2 hp, and 3 hp (horsepower). Each condition covers 10 data types, namely nine fault classes and the normal condition. This experiment uses vibration signals sampled at 12 kHz under a 2 hp load. Each sample contains 1024 data points, and the samples are collected by overlapping sampling with a step size of 100 points. The samples of each type are divided into training, validation, and test sets in a 7:2:1 ratio, as shown in Table 1.
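Table 1. Experimental data set

| Location | Diameter (in.) | Marker | Train | Val | Test |
|---|---|---|---|---|---|
| Ball | 0.007 | C1 | 700 | 200 | 100 |
| Ball | 0.014 | C2 | 700 | 200 | 100 |
| Ball | 0.021 | C3 | 700 | 200 | 100 |
| Inner | 0.007 | C4 | 700 | 200 | 100 |
| Inner | 0.014 | C5 | 700 | 200 | 100 |
| Inner | 0.021 | C6 | 700 | 200 | 100 |
| Outer | 0.007 | C7 | 700 | 200 | 100 |
| Outer | 0.014 | C8 | 700 | 200 | 100 |
| Outer | 0.021 | C9 | 700 | 200 | 100 |
| Normal | 0 | C0 | 700 | 200 | 100 |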
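The overlapping sampling and the 7:2:1 split can be sketched as below. The sketch assumes each sample is a 1024-point window, that window starts are 100 points apart, and that each class yields the 1000 samples listed in Table 1. The record is synthetic and only long enough for 1000 windows: (1000 − 1) × 100 + 1024 points. The paper does not say whether samples are shuffled before splitting. The sketch splits in time order, because shuffling overlapping windows would put nearly identical windows into both the training and test sets. An ordered split reduces this leakage but does not remove it at the split boundaries.

```python
import numpy as np

def overlapping_windows(signal, window=1024, step=100):
    """Cut a 1-D record into windows of `window` points whose starts are `step` points apart."""
    signal = np.asarray(signal)
    starts = range(0, len(signal) - window + 1, step)
    return np.stack([signal[s:s + window] for s in starts])

def split_7_2_1(samples):
    """Split samples, in time order, into train/val/test parts with a 7:2:1 ratio."""
    n = len(samples)
    n_train, n_val = n * 7 // 10, n * 2 // 10
    return samples[:n_train], samples[n_train:n_train + n_val], samples[n_train + n_val:]

# Hypothetical record just long enough for 1000 windows.
record = np.random.default_rng(0).standard_normal(999 * 100 + 1024)
windows = overlapping_windows(record)
train, val, test = split_7_2_1(windows)
print(windows.shape, len(train), len(val), len(test))
```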
The experiment uses the Python 3.8 programming environment and the PyTorch 1.12.1 deep learning framework to build the convolutional neural network. The input of the network is the calculated GCAF, because the GCAF gives a clearer time-frequency image. The network consists of one input layer, one convolutional layer, one pooling layer, one fully connected layer, and one output layer. The convolution kernel size is 3 × 3 with a stride of 1, and the pooling kernel size is 2 × 2 with a stride of 2. The fully connected layer has 1024 units. The hyperparameter settings are given in Table 2; they were selected as the best-performing values over repeated experiments.

Table 2. Hyperparameters

| Number | Name | Parameter setting |
|---|---|---|
| 1 | Epochs | 200 |
| 2 | Batch size | 128 |
| 3 | Learning rate | 0.01 |
| 4 | Weight decay | 0.0005 |
| 5 | Gamma | 0.9 |
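Table 2 lists gamma = 0.9 but does not name the learning rate scheduler. The sketch assumes per-epoch exponential decay, lr = 0.01 × 0.9^epoch. This is the schedule that PyTorch's `torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)` produces when stepped once per epoch. Under that assumption the rate after 200 epochs is about 7 × 10⁻¹², which is consistent with the "close to 0" final learning rate in Fig. 7(b). A step-wise schedule would decay more slowly.

```python
def exponential_lr(initial_lr, gamma, epoch):
    """Learning rate after `epoch` multiplicative decays by `gamma` (lr_0 * gamma**epoch)."""
    return initial_lr * gamma ** epoch

for epoch in (0, 10, 50, 100, 200):
    print(epoch, f"{exponential_lr(0.01, 0.9, epoch):.2e}")
```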
4.2 Analysis of Experimental Results

The accuracy, stability, and classification performance of the model can be analyzed through the accuracy curve, the loss curve, and t-SNE visualization. The loss curve in Fig. 5(a) shows that the model learns quickly, and Fig. 5(b) shows that the accuracy approaches 100% after 200 training epochs.
(a) Loss function curve
(b) Accuracy curve
Fig. 5. Loss function curve and accuracy curve
(a) Training set confusion matrix
(b) Test set confusion matrix
Fig. 6. Confusion matrix for training and testing sets
Figure 6 shows the confusion matrices for the training set and the test set. Figure 6(a) shows that the training accuracy reaches 100%, and Fig. 6(b) shows that only one test sample is misclassified. Figure 7 shows the t-SNE visualization of the output layer and the learning rate curve. In Fig. 7(a), the fault data are divided into 10 clusters of different colors. Only one magenta sample (C5) falls into the orange cluster (C6), which is consistent with the test-set confusion matrix and indicates that the model distinguishes the different fault types. Figure 7(b) shows the learning rate schedule. The relatively large initial learning rate lets the model learn quickly at first, and the learning rate finally decays to nearly 0, which helps the model converge stably.
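Overall accuracy is the trace of the confusion matrix divided by its sum. The sketch builds a hypothetical test-set matrix consistent with the description: 10 classes with 100 test samples each, as in Table 1, and a single C5 sample predicted as C6. Rows are true classes C0 to C9 and columns are predicted classes; this index ordering is an assumption. The one error lowers overall accuracy to 99.9% and the recall of C5 to 99%.

```python
import numpy as np

def accuracy_from_confusion(cm):
    """Overall accuracy: correctly classified samples (diagonal) over all samples."""
    cm = np.asarray(cm)
    return np.trace(cm) / cm.sum()

# Hypothetical test-set confusion matrix: rows = true class C0..C9, columns = predicted class.
cm = np.diag(np.full(10, 100))
cm[5, 5] -= 1      # one C5 sample ...
cm[5, 6] += 1      # ... predicted as C6
print(accuracy_from_confusion(cm), cm[5, 5] / cm[5].sum())
```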
(a) Visualization of t-SNE
(b) Learning rate curve
Fig. 7. Output layer t-SNE visualization and learning rate curve
4.3 Comparative Experiment

A comparative experiment is set up to verify that the proposed method can effectively extract bearing fault features and provide better preprocessed data for the neural network. Four inputs are fed to the neural network in turn: the GCAF and GSCD calculated by the proposed method, and the CAF and SCD calculated by the traditional cyclostationary signal method. The experimental data are vibration signals under three loads: 0 hp, 1 hp, and 2 hp. The test-set accuracy is shown in Fig. 8. When the proposed method is used for preprocessing, the model achieves higher accuracy.

| Load | GSCD | SCD | GCAF | CAF |
|---|---|---|---|---|
| 0 hp | 99.60% | 99.30% | 99.30% | 99.10% |
| 1 hp | 97.70% | 97.60% | 98.90% | 98.85% |
| 2 hp | 99.80% | 99.55% | 99.90% | 99.65% |

Fig. 8. Test Set Accuracy Comparison Histogram
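The claim can be checked pairwise from the Fig. 8 values: GSCD against SCD and GCAF against CAF at each load. Every gain is positive, but the margins are small, between 0.05 and 0.3 percentage points. If every load uses the same 1000-sample test set as Table 1 (an assumption), one misclassified sample corresponds to 0.1 points. Some margins are therefore below one test sample. The text does not say whether the reported figures are averaged over several runs.

```python
accuracy = {  # test-set accuracy (%) at 0 hp, 1 hp, 2 hp, read from Fig. 8
    "GSCD": [99.60, 97.70, 99.80],
    "SCD":  [99.30, 97.60, 99.55],
    "GCAF": [99.30, 98.90, 99.90],
    "CAF":  [99.10, 98.85, 99.65],
}
test_samples = 1000                      # assumed: 10 classes x 100 test samples (Table 1)
one_sample = 100 / test_samples          # percentage points per misclassified sample

gains = {}
for proposed, baseline in (("GSCD", "SCD"), ("GCAF", "CAF")):
    gains[proposed] = [round(p - b, 2) for p, b in zip(accuracy[proposed], accuracy[baseline])]
    print(f"{proposed} - {baseline}: {gains[proposed]} points (1 sample = {one_sample} points)")
```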
5 Conclusion

Based on cyclostationary signal analysis theory and graph signal processing methods, a graph cyclostationary signal analysis method is proposed. Simulation experiments show that this method can effectively demodulate signals. It also yields clearer time-frequency and bispectral images than the traditional cyclostationary signal method. Finally, a fault classification experiment on the rolling bearing dataset is conducted using a neural network. Compared with the traditional cyclostationary signal method, the model achieves higher accuracy when the proposed method is used to preprocess the vibration signals of rolling bearings.
Acknowledgement. This research is part of a project sponsored by the Science and Technology Planning Project of Tianjin (Grant No. 22YDTPJC00740).
References
1. Boudiaf, A., Moussaoui, A., Dahane, A., et al.: A comparative study of various methods of bearing faults diagnosis using the Case Western Reserve University data. J. Fail. Anal. Prev. 16(2), 271–284 (2016)
2. Tian, R.L., Yu, K., Tan, J.: Condition monitoring and fault diagnosis of roller element bearing. Bearing Technology (2017)
3. Wei, Z., Wang, Y., He, S., et al.: A novel intelligent method for bearing fault diagnosis based on affinity propagation clustering and adaptive feature selection. Knowl.-Based Syst. 116(C), 1–12 (2017)
4. Liu, G., Qu, L.: Feature extraction of mechanical faults based on the continuous wavelet transform. J. Xi'an Jiaotong Univ. 11, 74–77 (2000)
5. Feng, L., Kang, J., Meng, Y., et al.: Research on feature extraction of rolling bearing incipient fault based on Morlet wavelet transform. Chin. J. Sci. Instrum. 34(04), 920–926 (2013)
6. Li, H.: Study on high efficient algorithm for cyclic correntropy spectral analysis. J. Electron. Inf. Technol. 43(02), 310–318 (2021)
7. He, J., Chen, J., Bi, G., et al.: Frequency demodulation analysis of degree of cyclostationary and its application to gear defect detection. J. Shanghai Jiaotong Univ. (Chin. Ed.) 41(11), 1862–1866 (2007)
8. Xu, K., Hu, W., Leskovec, J., et al.: How powerful are graph neural networks? In: International Conference on Learning Representations (2019)
9. von Luxburg, U.: A tutorial on spectral clustering. Stat. Comput. 17(4), 395–416 (2007)
10. Zhou, J., Cui, G., Hu, S., et al.: Graph neural networks: a review of methods and applications. AI Open 1, 57–81 (2020)
11. Xu, B., Shen, H., Cao, Q., et al.: Graph wavelet neural network. In: International Conference on Learning Representations (2019)
Rolling Bearing Fault Diagnosis Based on GWVD and Convolutional Neural Network
Xiaoxuan Lv1 and Hui Li1,2(✉)
1 School of Mechanical Engineering, Tianjin University of Technology and Education, Tianjin 300222, China
2 Tianjin Key Laboratory of Intelligent Robot Technology and Application, Tianjin 300222, China
Abstract. Traditional signal processing methods cannot handle non-Euclidean data. To address this, a rolling bearing fault diagnosis method based on the graph Wigner-Ville distribution (GWVD) and a convolutional neural network (CNN) is proposed. GWVD combines graph signal processing with the Wigner-Ville distribution and can effectively suppress cross-interference terms. GWVD-based CNN benefits from the very high energy aggregation and resolution of GWVD, which improves the pattern recognition accuracy of the CNN in noisy environments. The principles of GWVD and of GWVD-and-CNN-based rolling bearing fault diagnosis are introduced. First, one-dimensional bearing fault vibration signals are converted into graph signals according to the similarity between sampling points, and the resulting data are divided proportionally. Second, the GWVD-preprocessed data serve as the input layer of the CNN, and a two-dimensional CNN performs bearing fault feature extraction and fault classification. Finally, the Case Western Reserve University standard bearing dataset is used for experimental verification, and the method is compared with traditional Wigner-Ville distribution preprocessing. The experimental results show that GWVD-based CNN has higher energy aggregation and resolution than traditional WVD-based CNN and performs better, making it an effective intelligent bearing fault diagnosis method.
Keywords: Graph Wigner-Ville Distribution · Graph Signal Processing · Convolutional Neural Network · Deep Learning · Fault Diagnosis
1 Introduction

Deep learning is a major breakthrough in artificial intelligence. Its deep network structure can learn the most essential features directly from vibration signals to perform fault diagnosis, avoiding manual feature extraction with various signal processing methods [1–3]. Deep learning methods such as the convolutional neural network (CNN), deep belief network (DBN), recurrent neural network (RNN), and stacked autoencoder (SAE) have begun to be applied in bearing fault diagnosis [4].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 514–523, 2023. https://doi.org/10.1007/978-981-99-4761-4_44
Rolling Bearing Fault Diagnosis Based on GWVD and CNN
In recent years, graph signal processing has opened up a new field of signal processing. Yu et al. [5] proposed a rolling bearing fault diagnosis method based on EMD and the AR model, in which the vibration signal is decomposed into several intrinsic mode functions (IMFs). An AR model is established for each IMF component. A discriminant function based on the Mahalanobis distance is then built, using the autoregressive parameters of the model and the variance of the residual as the feature vector, to determine the working state and fault type of the bearing. Fault signals of rolling bearings are difficult to extract under strong background noise and random pulse interference. To address this, Zhang et al. [6] proposed an improved kurtogram method for rolling bearing fault diagnosis and compared it with the fast kurtogram method; the improved method determines the resonance band more accurately. Traditional fault feature extraction for rolling bearings is susceptible to external noise and contains a large amount of redundant information. To address this, Zhou et al. [7] proposed a fault diagnosis method that extracts features from view graph signals. It achieves 8.34% higher accuracy than the traditional fault feature extraction method and effectively realizes bearing fault diagnosis. Kumara et al. [8] proposed a fault diagnosis model that optimizes the structure of a deep belief network with particle swarm optimization, adaptively adjusting the network's structural parameters to diagnose bearing faults effectively. Yiakopoulos et al. [9] realized rolling bearing fault diagnosis with a recurrent neural network. They detected faults with a gated recurrent unit built on a stacked autoencoder, which improved the generalization ability of the model. Existing research on the graph Fourier transform, the graph wavelet transform, and other methods can identify faults by establishing the relationship between time and frequency. However, local analysis is usually more effective than global analysis when large graphs are used as signal domains. To characterize the vertex-local and narrow-band spectral characteristics of signals, this paper extends classical time-frequency analysis to the vertex-frequency analysis of signals defined on graphs. The traditional vibration signal is converted into graph form, and the analysis is extended to the vertex domain. By studying the specific frequency content at each vertex and the relationships between vertices, we propose a new signal processing method that combines graph signal processing with the Wigner-Ville distribution. The processed signal is input into a one-dimensional convolutional neural network for intelligent rolling bearing fault diagnosis. Using a modulated simulation signal, an unweighted undirected ring graph is established, and the fast windowed graph Fourier transform and the fractional graph Fourier transform are calculated. The GWVD values are then obtained and visualized, and the method is used to verify the handling of cross-term interference. In this paper, the traditional WVD method is compared with the GWVD method on a simulated signal, and the GWVD method is verified on the Case Western Reserve University (CWRU) motor bearing standard data. The method is an efficient intelligent fault identification technique with high energy aggregation and resolution.
X. Lv and H. Li
2 The Definition of GWVD and Fault Diagnosis Steps
2.1 Graph Signal Processing
Graph Signal Processing (GSP) applies Discrete Signal Processing (DSP) theory to signals defined on graphs. It transfers basic signal processing concepts such as the Fourier transform and filtering to the graph domain, and uses them to study basic tasks such as the compression, transformation and reconstruction of graph signals. Generally, each sampling point is taken as a vertex of the graph, the similarity between time-domain amplitudes defines the edges, and the time-domain amplitudes themselves form the signal on the graph; this yields the graph and the graph signal. At present, the relationship between graph signal processing and traditional time-frequency signal processing can be established in two ways: with an unweighted undirected graph and its adjacency matrix, or with an unweighted undirected path graph and its Laplacian matrix.
2.2 The Definition of Wigner-Ville Distribution
Let the Fourier transform of the continuous signal x(t) be X(Ω). The Wigner-Ville distribution is then defined as
$$W_x(t,\Omega) = \int_{-\infty}^{\infty} x\left(t+\frac{\tau}{2}\right) x^*\left(t-\frac{\tau}{2}\right) e^{-j\Omega\tau}\, d\tau \tag{1}$$
It can also be expressed in terms of the spectrum X(Ω) of x(t):
$$W_x(t,\Omega) = \frac{1}{2\pi}\int_{-\infty}^{\infty} X\left(\Omega+\frac{\theta}{2}\right) X^*\left(\Omega-\frac{\theta}{2}\right) e^{jt\theta}\, d\theta \tag{2}$$
The Wigner-Ville distribution corresponds to the kernel function φ(τ, ν) = 1 [10], so no kernel function has to be chosen and the analysis is simple. It also has properties such as time-shift and frequency-shift invariance and satisfies the time and frequency marginal conditions, and it has therefore been widely used.
2.3 The Definition of Graph Wigner-Ville Distribution
For a graph signal x, the eigenvalue decomposition of the graph Laplacian is
$$L = U \Lambda U^T \tag{3}$$
where L is the Laplacian matrix of the graph on which x is defined, U is the eigenvector matrix and Λ is the diagonal eigenvalue matrix. The graph Fourier transform of x is defined as
$$X = U^T x \tag{4}$$
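Equations (3) and (4) can be illustrated on the unweighted, undirected ring graph mentioned in the introduction. The following is a minimal NumPy sketch; the helper names are ours, not the authors'. The sketch builds the ring Laplacian over N sampling points, eigendecomposes it and computes the graph Fourier transform. For a ring graph the Laplacian eigenvalues are 2 − 2cos(2πk/N), which is what ties the graph Fourier transform to the classical DFT.

```python
import numpy as np

def ring_graph_laplacian(n_vertices):
    """Laplacian L = D - A of an unweighted, undirected ring graph.

    Vertex n is joined to vertex n + 1 (mod N), so every vertex has degree 2.
    """
    A = np.zeros((n_vertices, n_vertices))
    for n in range(n_vertices):
        m = (n + 1) % n_vertices
        A[n, m] = A[m, n] = 1.0
    return np.diag(A.sum(axis=1)) - A

def graph_fourier_transform(x, L):
    """Eq. (3): L = U Lambda U^T; Eq. (4): X = U^T x."""
    lam, U = np.linalg.eigh(L)   # eigenvalues in ascending order, U orthonormal
    return lam, U, U.T @ x

N = 16
x = np.cos(2 * np.pi * 3 * np.arange(N) / N)   # toy graph signal on the ring
lam, U, X = graph_fourier_transform(x, ring_graph_laplacian(N))
# Ring-graph eigenvalues are 2 - 2cos(2*pi*k/N), k = 0..N-1
expected = np.sort(2 - 2 * np.cos(2 * np.pi * np.arange(N) / N))
print(np.allclose(lam, expected), np.allclose(U @ X, x))
```

Because U is orthogonal, U @ X recovers x, and the energy of this toy signal falls entirely on the graph frequency 2 − 2cos(2π·3/N).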
Rolling Bearing Fault Diagnosis Based on GWVD and CNN
The graph signal matrix x_M can be defined as
$$x_M = \begin{bmatrix} x_1 & x_1 & \cdots & x_1 \\ x_2 & x_2 & \cdots & x_2 \\ \vdots & \vdots & & \vdots \\ x_N & x_N & \cdots & x_N \end{bmatrix} \tag{5}$$
The Fourier transform matrix of the graph signal x can be defined as
$$X_M = \begin{bmatrix} X(\lambda_1) & X(\lambda_1) & \cdots & X(\lambda_1) \\ X(\lambda_2) & X(\lambda_2) & \cdots & X(\lambda_2) \\ \vdots & \vdots & & \vdots \\ X(\lambda_N) & X(\lambda_N) & \cdots & X(\lambda_N) \end{bmatrix} \tag{6}$$
Shifting the graph signal x to every vertex i, i = 1, 2, …, N, gives the translation matrix S_M:
$$S_M = \sqrt{N}\, U \left(X_M \circ U^T\right) \tag{7}$$
where the symbol "∘" denotes element-wise multiplication. The resulting GWVD formula is
$$\mathrm{GWVD} = U^T \left(x_M \circ S_M\right)^T \tag{8}$$
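Equations (5)–(8) translate almost line by line into NumPy. The sketch below is a literal transcription under our own assumptions (an unweighted undirected ring graph, NumPy's `eigh` for Eq. (3), hypothetical helper names). It is not the authors' implementation.

Because U is orthogonal, Eq. (8) simplifies to GWVD[l, n] = √N x(n) u_l(n) X(λ_l). As a result:
- summing over the graph frequencies returns √N x(n)² at each vertex;
- summing over the vertices returns √N X(λ_l)² at each frequency.

The usage lines check both marginals. The ring graph has repeated eigenvalues, so U, and with it the individual GWVD entries, depend on the eigenvector basis that the solver returns.

```python
import numpy as np

def ring_graph_laplacian(n_vertices):
    """Laplacian of an unweighted, undirected ring graph (degree 2 everywhere)."""
    A = np.roll(np.eye(n_vertices), 1, axis=1)   # edge n -> n + 1 (mod N)
    A = A + A.T                                  # make the graph undirected
    return np.diag(A.sum(axis=1)) - A

def gwvd(x, L):
    """Graph Wigner-Ville distribution following Eqs. (3)-(8).

    Returns (lam, U, G); G[l, n] is indexed by graph frequency lambda_l (rows)
    and vertex n (columns).
    """
    x = np.asarray(x, dtype=float)
    N = x.size
    lam, U = np.linalg.eigh(L)              # Eq. (3)
    X = U.T @ x                             # Eq. (4)
    xM = np.tile(x[:, None], (1, N))        # Eq. (5): row n repeats x_n
    XM = np.tile(X[:, None], (1, N))        # Eq. (6): row l repeats X(lambda_l)
    SM = np.sqrt(N) * U @ (XM * U.T)        # Eq. (7): column i = x translated to vertex i
    G = U.T @ (xM * SM).T                   # Eq. (8)
    return lam, U, G

fs, N = 2000, 64
t = np.arange(N) / fs
x = (1 + np.cos(2 * np.pi * 30 * t)) * np.cos(2 * np.pi * 250 * t)   # first 64 samples of Eq. (9)
lam, U, G = gwvd(x, ring_graph_laplacian(N))
X = U.T @ x
print(np.allclose(G.sum(axis=0), np.sqrt(N) * x**2),    # vertex marginal
      np.allclose(G.sum(axis=1), np.sqrt(N) * X**2))    # frequency marginal
```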
GWVD not only overcomes the cross-term shortcoming of the traditional WVD method but can also analyze local vertices, with stronger energy aggregation and higher resolution. The main steps of GWVD-CNN rolling bearing fault feature extraction are as follows:
1) Convert the vibration signal x(t) into graph data.
2) Compute the Laplacian matrix L and perform its eigenvalue decomposition.
3) Take the graph Fourier transform of the graph signal.
4) Divide the data into training, validation and test sets.
5) Feed the preprocessed data into the input layer of the two-dimensional convolutional neural network, then train the network and extract features through operations such as convolution and pooling.
6) Visualize the rolling bearing fault classification.
3 Rolling Bearing Fault Simulation Signal
When a single-point damage fault occurs in a rolling bearing, the spectrum of the vibration signal is generally centered on the fault characteristic frequency of the damaged component. Modulation sidebands appear around it, spaced at the bearing rotation frequency, which acts as the modulation frequency. Extracting this modulation information and analyzing its strength and frequency reveals the location and severity of the bearing fault.

For the simulation, the sampling frequency is 2000 Hz, the carrier frequency is 250 Hz, the modulation frequency is 30 Hz, the number of samples is 600 and the sampling time is 0.3 s. When the bearing is damaged, the damage point collides with the components in contact with it once per revolution, so the vibration signal is amplitude-modulated. The modulated simulation signal can be expressed as
$$x(t) = \left(1 + \cos(2\pi \cdot 30\, t)\right) \cos(2\pi \cdot 250\, t) \tag{9}$$
Figure 1 shows the time-domain waveform of the simulated signal and its fast Fourier transform (FFT). The spectrum has peaks at 220 Hz, 250 Hz and 280 Hz. The 250 Hz peak is the carrier, and the 220 Hz and 280 Hz peaks are the sidebands located at the carrier frequency minus and plus the 30 Hz modulation frequency.
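The spectrum of Eq. (9) can be reproduced with a short NumPy sketch, assuming the stated parameters (fs = 2000 Hz, 600 samples). Because 0.3 s contains a whole number of cycles at 220 Hz, 250 Hz and 280 Hz, each component falls exactly on an FFT bin (resolution fs/600 = 10/3 Hz). The three largest magnitudes therefore sit at those frequencies, and the carrier bin is twice as large as each sideband bin.

```python
import numpy as np

fs = 2000          # sampling frequency (Hz)
n_samples = 600    # 0.3 s of data
t = np.arange(n_samples) / fs
x = (1 + np.cos(2 * np.pi * 30 * t)) * np.cos(2 * np.pi * 250 * t)   # Eq. (9)

spectrum = np.abs(np.fft.rfft(x))
freqs = np.fft.rfftfreq(n_samples, d=1 / fs)        # bin spacing fs / n = 10/3 Hz
peaks = np.sort(freqs[np.argsort(spectrum)[-3:]])   # frequencies of the three largest bins
print(peaks)
```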
Fig. 1. Simulation signal diagram: (a) time-domain waveform; (b) fast Fourier transform.
Using the formulas derived above, the WVD and GWVD images are calculated directly.

The two- and three-dimensional WVD spectra show frequency components between 0 Hz–100 Hz and 200 Hz–400 Hz. However, the images contain severe cross terms, and the blurring prevents clear and useful information from being read off; the details are shown in Fig. 2.

The two- and three-dimensional GWVD spectra of the signal were then calculated, and their characteristic frequencies were compared with those of the traditional WVD method. Figure 3 clearly shows that the GWVD method concentrates the modulation and carrier components in the intermediate frequency range and eliminates the cross terms, leaving frequency components at 220 Hz, 250 Hz and 280 Hz. These results are consistent with those obtained by the WVD method.

This indicates that the GWVD method overcomes the cross-term interference of the traditional WVD method and has high energy aggregation and resolution. It is well suited to non-stationary and nonlinear signal processing and can identify faults effectively.
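To illustrate the cross terms discussed above, the sketch below evaluates a discrete, lag-limited version of Eq. (1) for a complex two-tone signal. It is a minimal NumPy sketch under our own assumptions (complex exponentials at 100 Hz and 300 Hz, fs = 1024 Hz, 256 frequency bins), not the simulation behind Fig. 2. Besides the auto terms at 100 Hz and 300 Hz, a cross term appears midway, at 200 Hz, where the signal has no energy. At the chosen time index, its magnitude is about twice that of either auto term.

```python
import numpy as np

def discrete_wvd(x, n_freq=256):
    """Discrete Wigner-Ville distribution W[n, k] of a complex (analytic) signal.

    For each time index n, the lag products x[n+m] * conj(x[n-m]) are collected for
    every lag |m| that stays inside the signal (and below n_freq/2), placed in a
    length-n_freq lag buffer (negative lags wrap around), and Fourier-transformed
    over m. Because the lag enters twice, bin k corresponds to k * fs / (2 * n_freq).
    """
    x = np.asarray(x, dtype=complex)
    N = x.size
    W = np.zeros((N, n_freq))
    for n in range(N):
        m_max = min(n, N - 1 - n, n_freq // 2 - 1)
        m = np.arange(-m_max, m_max + 1)
        lag_products = np.zeros(n_freq, dtype=complex)
        lag_products[m % n_freq] = x[n + m] * np.conj(x[n - m])
        W[n] = np.fft.fft(lag_products).real   # Hermitian in m, so the FFT is real
    return W

fs = 1024
n = np.arange(512)
x = np.exp(2j * np.pi * 100 * n / fs) + np.exp(2j * np.pi * 300 * n / fs)
W = discrete_wvd(x)
# Bins at time index 256: 100 Hz -> 50, 300 Hz -> 150, cross term at 200 Hz -> 100
print(W[256, 50], W[256, 150], W[256, 100])
```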
Fig. 2. Wigner-Ville distribution of simulated signals: (a) 2D WVD; (b) 3D WVD.
Fig. 3. Graph Wigner-Ville distribution of simulated signals: (a) 2D GWVD; (b) 3D GWVD.
4 GWVD-CNN Based Rolling Bearing Fault Diagnosis
4.1 Experimental Data of CWRU
The GWVD method was applied to the motor rolling bearing data set of Case Western Reserve University (CWRU) in the United States [11, 12]. The vibration signals were acquired from the deep groove ball bearing (model SKF6205) at the motor drive end (DE), with a sampling frequency of 12 kHz. According to the bearing damage diameter and defect location, the signals are divided into 10 categories:
- the defect location is the inner race, the outer race or the rolling element (ball);
- the damage diameter is 0.007 in, 0.014 in or 0.021 in;
- the normal-state vibration signal forms one further category.

Each category label is encoded for training. The experiment uses data recorded at a load of 0 hp (1797 rpm). Each sample is 1024 points long, and overlapping samples are taken with a step size of 100 points. 1000 samples were selected for each category, giving a data set of 10,000 samples. The training, validation and test sets were divided in a ratio of 7:2:1, as shown in Table 1.

The training set is preprocessed with the GWVD method and then fed into the CNN model for training. Training uses a batch size of 128, a maximum of 500 epochs, an initial learning rate of 0.001 and the cross-entropy loss function. The validation and test sets were likewise preprocessed by GWVD and fed into the trained CNN model. The validation set was used to adjust the model parameters, and the test set was used to evaluate the recognition accuracy and generalization ability of the trained model.

Table 1. Experiment dataset.
Label | Fault position | Fault diameter (in) | Training | Validation | Testing
1 | Ball | 0.007 | 700 | 200 | 100
2 | Ball | 0.014 | 700 | 200 | 100
3 | Ball | 0.021 | 700 | 200 | 100
4 | Inner Race | 0.007 | 700 | 200 | 100
5 | Inner Race | 0.014 | 700 | 200 | 100
6 | Inner Race | 0.021 | 700 | 200 | 100
7 | Outer Race | 0.007 | 700 | 200 | 100
8 | Outer Race | 0.014 | 700 | 200 | 100
9 | Outer Race | 0.021 | 700 | 200 | 100
10 | Normal | 0 | 700 | 200 | 100
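The overlapping segmentation and the 7:2:1 split described above can be sketched as follows. This is a minimal NumPy sketch; the function names, the random shuffling before splitting and the fixed seeds are our assumptions, since the text does not say whether samples were shuffled. The synthetic record only stands in for one CWRU class. With a 1024-point window and a 100-point step, 1000 samples require (1000 − 1) × 100 + 1024 = 100,924 points per class.

```python
import numpy as np

def sliding_window_samples(signal, length=1024, step=100, n_samples=1000):
    """Cut `n_samples` overlapping windows of `length` points, `step` points apart."""
    needed = (n_samples - 1) * step + length
    if len(signal) < needed:
        raise ValueError(f"need at least {needed} points, got {len(signal)}")
    starts = np.arange(n_samples) * step
    return np.stack([signal[s:s + length] for s in starts])

def split_7_2_1(samples, seed=0):
    """Shuffle (an assumption) and split samples into training/validation/test at 7:2:1."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(samples))
    n_train = len(samples) * 7 // 10
    n_val = len(samples) * 2 // 10
    return (samples[idx[:n_train]],
            samples[idx[n_train:n_train + n_val]],
            samples[idx[n_train + n_val:]])

# Synthetic record standing in for one CWRU class (not real bearing data)
record = np.random.default_rng(1).standard_normal(100_924)
samples = sliding_window_samples(record)
train, val, test = split_7_2_1(samples)
print(samples.shape, train.shape, val.shape, test.shape)
```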
4.2 GWVD-CNN Bearing Fault Diagnosis
The proposed GWVD-CNN bearing fault diagnosis process is analyzed and verified experimentally on the above data set in three stages: data preprocessing, CNN feature extraction and recognition, and visual analysis of the results. Data preprocessing converts the one-dimensional vibration signals into two-dimensional images with the GWVD method, reducing noise and dimensionality at the same time, so that the signals can be input into the CNN for training. The 2D CNN performs automatic fault feature extraction and classifies the bearing fault types. Visualizing the classification results demonstrates how effectively the GWVD-CNN model separates the different faults. The GWVD-CNN bearing fault diagnosis flow chart is shown in Fig. 4.
4.3 Data Preprocessing
Bearing vibration signals of different fault types differ clearly on the whole, but some signals of different types look similar on visual inspection. In practice, these differences are even harder to distinguish when the number of sampling points is small, especially under complex working conditions with strong background noise. Identifying the fault type directly from the one-dimensional vibration signal is therefore difficult. The data set was divided according to Table 1, and the training samples under 0 hp load were preprocessed with the GWVD method.
Fig. 4. Flow chart of GWVD-CNN bearing fault diagnosis
Fig. 5. 2D-CNN network model structure
4.4 2D CNN
A two-dimensional convolutional neural network (2D-CNN) is generally composed of an input layer, convolutional layers, pooling layers, fully connected layers and an output layer. The network used here consists of one convolutional layer, one pooling layer, two fully connected layers and one output layer.

The input is a 32 × 32 two-dimensional image, and 5 × 5 convolution kernels with a step size of 5 are used. After ReLU activation, the convolution output is pooled with a 2 × 2 pooling kernel, and the output size becomes 16 × 16 × 8. The first fully connected layer flattens the pooled result into a one-dimensional vector, with an input size of 1024 and an output size of 1024 × 1. The output size of the second fully connected layer is 10 × 1, i.e., the predicted probabilities of the 10 categories. Figure 5 shows the CNN network structure.
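The layer sizes above can be cross-checked with the standard output-size formula ⌊(W + 2P − K)/S⌋ + 1, for input width W, kernel size K, stride S and padding P. The plain-Python sketch below compares two readings. Padding is not given in the text, and the stride-1 reading is our assumption. The comparison shows:
- a 5 × 5 kernel with stride 5 and no padding turns a 32 × 32 input into 6 × 6;
- a stride of 1 with padding 2 keeps 32 × 32 and yields the stated 16 × 16 after 2 × 2 pooling.

With 8 feature maps, flattening that 16 × 16 result gives 2048 values.

```python
def layer_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1

# Reading 1: 5 x 5 kernel, stride 5, no padding (stride as stated in the text)
conv_a = layer_out(32, 5, stride=5)               # 6
pool_a = layer_out(conv_a, 2, stride=2)           # 3

# Reading 2 (assumed): stride 1 with padding 2 ("same" padding), 8 filters
conv_b = layer_out(32, 5, stride=1, padding=2)    # 32
pool_b = layer_out(conv_b, 2, stride=2)           # 16
flattened_b = pool_b * pool_b * 8                 # 2048
print(conv_a, pool_a, conv_b, pool_b, flattened_b)
```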
4.5 Visual Analysis of GWVD-CNN Fault Diagnosis
The rolling bearing fault classification experiment was carried out with the CNN model, and the accuracy and loss curves of the GWVD-CNN diagnosis method were plotted, as shown in Fig. 6. The accuracy on both the training set and the test set stabilizes after 100 iterations. The training accuracy reaches 99% and remains at that level throughout training. The test accuracy fluctuates slightly but is stable overall: it reaches 94% after 10 iterations, stabilizes after 50 iterations, and thereafter stays above 95%. The loss begins to converge after 50 iterations and then remains stable. These results confirm the stability of the fault diagnosis model. The final accuracies on the training and test sets are 99.4% and 98.4%, respectively.
Fig. 6. Accuracy and loss curves: (a) loss function curve; (b) accuracy curves.
Fig. 7. Confusion matrix: (a) training set; (b) testing set.
Figure 7 shows the confusion matrices of the classification results after 500 training iterations. The horizontal axis represents the predicted fault type and the vertical axis the actual fault type. The results show that the GWVD-2DCNN model separates samples with different fault degrees well.
Acknowledgement. This research is part of the research sponsored by the Science and Technology Planning Project of Tianjin (Grant No. 22YDTPJC00740).
References
1. Zhang, Y., Wang, F., Wang, W.: Research of motor fault diagnosis method based on noise analysis. Micromotors 45(08), 83–87 (2012)
2. Zheng, Y., Li, G., Li, Y.: Survey of application of deep learning in image recognition. Comput. Eng. Appl. 55(12), 20–36 (2019)
3. Ren, H., Qu, J., Chai, Y., et al.: Deep learning for fault diagnosis: the state of the art and challenge. Control Decis. 32(8), 1345–1358 (2017)
4. Zhang, Z., Xiao, N., Wang, C., et al.: Fault diagnosis analysis of rolling bearing based on vibration signal. Railway Qual. Control 50(06), 21–24 (2022)
5. Yu, D., Cheng, J., Yang, Y.: Rolling bearing fault diagnosis method based on EMD and AR model. J. Vib. Eng. 03, 84–87 (2004)
6. Zhang, H., Chen, B., Song, D.: Bogie bearing fault diagnosis based on improved Kurtogram. Urban Rail Transit Res. 22(02), 41–47 (2019)
7. Zhou, J., Huang, X., Xiong, W., et al.: Fault diagnosis of rolling bearing based on visual spectrum signal feature extraction. Manuf. Technol. Mach. Tools 723(9), 1005–2402 (2022)
8. Kumar, A., Kumar, R.: Role of signal processing, modeling and decision making in the diagnosis of rolling element bearing defect: a review. J. Nondestr. Eval. 38(1), 1–29 (2019)
9. Yiakopoulos, C.T., Gryllias, K.C., Antoniadis, I.A.: Rolling element bearing fault detection in industrial environments based on a K-means clustering approach. Expert Syst. Appl. 38(3), 2888–2911 (2011)
10. Chen, G., Ma, S., Liu, M., et al.: Wigner-Ville distribution and cross Wigner-Ville distribution of noisy signals. J. Syst. Eng. Electron. 05, 1053–1057 (2008)
11. Smith, W.A., Randall, R.B.: Rolling element bearing diagnostics using the case western reserve university data: a benchmark study. Mech. Syst. Sign. Proces., 64–65 (2015)
12. Pan, L., Gong, Y., Yan, S.: Research on bearing fault diagnosis method based on improved one-dimensional convolutional neural network. Softw. Guide, 1–5 (2023)
Expert Systems
Aggregation of S-generalized Distances
Lijun Sun, Chen Zhao, and Gang Li
School of Mathematics and Statistics, Qilu University of Technology (Shandong Academy of Sciences), Jinan 250353, China
Abstract. An aggregation function is a technique for combining data from several sources into a single representative value. In recent years, fuzzy binary relations and S-generalized distances have become objects of aggregation functions. S-generalized distances are commonly applied in computer science and database management. In this paper, we study S-generalized distance aggregation, which merges a family of Si-generalized distances into a new S-generalized distance. We characterize these aggregation functions by means of extended dominance and S-triangular triplets. Furthermore, we treat the case in which Si and S are strict t-conorms.
Keywords: Computer science · Dominance · Aggregation function · S-generalized distance · S-triangle triple
1 Introduction
In information fusion, we often need to merge information from different sources into a single piece of information. The mathematical model describing this process is often based on aggregation functions [11, 14]. Aggregation functions have a wide range of applications in probability, statistics, decision theory and computer science. For instance, in database management systems such as MongoDB, aggregation functions are frequently used to group, manipulate and analyze data [3, 8, 10]. In mathematics, aggregation functions can also be used to analyze information and solve equations.

Many aggregation techniques impose a restriction when selecting the most suitable aggregation function for the problem under consideration. In general, the output of the aggregation must have the same properties as the input data. A typical example is the case in which a collection of metrics (distances) must be merged into a new one. Since the metric plays a central role as a measurement tool in applied research, many authors have studied in depth how a collection of metrics can be combined into a single one by means of an aggregation function. In 1981, Borsik and Doboš thoroughly studied the general problem of merging a collection of distances (not necessarily finite) into a single one [1]. More recently, Pradera et al., in the spirit of Borsik and Doboš, provided a general solution to the problem of merging data represented by a finite family of generalized distances and pseudo-distances [17, 18]. Several general techniques for merging a finite number of distances into another one have been studied by Casasnovas and Roselló in [6].

A prototypical example in which aggregation function theory has been applied successfully, and in which S-generalized distances have proved particularly useful, is given by fuzzy databases, that is, databases that can manage uncertain information [2]. In addition, S-generalized distances such as ultrametrics are widely applied in computer science. Bisimulation [15, 16, 20] is one of the most significant contributions of concurrency theory to computer science. It describes when two systems can behave in a similar manner, for example when one system can emulate another and vice versa. Intuitively, two systems are bisimilar if they match each other's moves. Besides testing behavioral equivalence, bisimulation allows one to reduce the state space of a system by aggregating bisimilar states. References [5, 7, 9, 21] use an ultrametric to measure the similarity of states in a (nondeterministic) fuzzy transition system (FTS). It is therefore of interest to study the aggregation of S-generalized distances.

In this paper, we focus on S-generalized distance aggregation. We define extended dominance and S-triangular triplets to describe these aggregation functions, and we also address the case in which Si and S are strict t-conorms. The paper is structured as follows. Section 2 recalls the basic definitions of t-conorms used in the sequel and the notion of dominance. In Sect. 3, we characterize S-generalized distance aggregation functions by means of triangular triplets and the extended dominance relation. Finally, we close the contribution with a short summary.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 527–536, 2023. https://doi.org/10.1007/978-981-99-4761-4_45
2 Preliminaries
Before starting, we fix some notation used in this paper.
• $\vec a = (a_1,\dots,a_n)$; when no confusion arises, we abbreviate $\vec a$ as $(a_i)$.
• $\mathcal{F}^S_{S_1,S_2,\dots,S_n}$ denotes the set of functions that aggregate a family of $S_i$-generalized distances, $i = 1,\dots,n$, into an S-generalized distance.
We now give the definitions of aggregation functions, t-conorms, dominance and related notions.
Definition 1. [4] An aggregation function on $I^n$ is a function $A^{(n)}: I^n \to I$ that
(i) is non-decreasing (in each variable);
(ii) fulfills the boundary conditions
$$\inf_{\vec x \in I^n} A^{(n)}(\vec x) = \inf I \quad \text{and} \quad \sup_{\vec x \in I^n} A^{(n)}(\vec x) = \sup I,$$
where $\vec x = (x_1,\dots,x_n)$ and I is a nonempty real interval of the extended real number system.
A specific case of aggregation functions is presented below.
Definition 2. [12] A two-place function $S: [0,1]^2 \to [0,1]$ is called a t-conorm if it is associative, commutative, non-decreasing in each place and has 0 as its neutral element. Moreover, a t-conorm S is called strict if it is continuous and strictly increasing on $[0,1)^2$.
Remark 1. By associativity, a binary t-conorm can easily be extended to the n-ary case. The four basic t-conorms $S_M$, $S_P$, $S_L$ and $S_D$ usually discussed in the literature are defined, respectively, by
$$S_M(x,y) = \max(x,y), \quad S_P(x,y) = x + y - x \cdot y, \quad S_L(x,y) = \min(x + y, 1),$$
$$S_D(x,y) = \begin{cases} 1 & \text{if } (x,y) \in (0,1]^2, \\ \max(x,y) & \text{otherwise.} \end{cases}$$
Remark 2. For any t-conorm S, $S_M \le S \le S_D$, that is, $S_M(x,y) \le S(x,y) \le S_D(x,y)$ for all $(x,y) \in [0,1]^2$.
Definition 3. [19] A continuous t-conorm S is Archimedean if and only if there exists an increasing, continuous function $s: [0,1] \to [0,+\infty]$ with $s(0) = 0$ such that, for all $a, b \in [0,1]$,
$$S(a,b) = s^{-1}\big(\min(s(1), s(a) + s(b))\big).$$
The function s is called an additive generator of S and is uniquely determined up to a positive multiplicative constant.
We now recall the definition of an S-generalized distance based on t-conorms.
Definition 4. [19] Let X be a set and S a t-conorm. A function $d: X \times X \to [0,1]$ is said to be an S-generalized distance on X if it satisfies the following properties for any $x, y, z \in X$:
(i) $d(x,x) = 0$;
(ii) $d(x,y) = d(y,x)$;
(iii) $d(x,z) \le S(d(x,y), d(y,z))$.
Definition 5. [13] A function $F: [0,\infty)^n \to [0,\infty)$ is called amenable provided that $F(\vec a) = 0 \Leftrightarrow a_1 = a_2 = \dots = a_n = 0$ for all $\vec a = (a_i)$, $i = 1,\dots,n$.
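The four basic t-conorms, the bounds of Remark 2 and the generator form of Definition 3 can be checked numerically. The following plain-Python sketch uses function names of our own and the standard additive generators s(x) = −ln(1 − x) for S_P and s(x) = x for S_L. A grid check like this is a sanity test, not a proof.

```python
import math
from functools import reduce

def S_M(x, y): return max(x, y)                                 # maximum
def S_P(x, y): return x + y - x * y                             # probabilistic sum
def S_L(x, y): return min(x + y, 1.0)                           # Lukasiewicz (bounded sum)
def S_D(x, y): return 1.0 if x > 0 and y > 0 else max(x, y)     # drastic sum

def n_ary(S, values):
    """Extend a binary t-conorm to n arguments by associativity (Remark 1)."""
    return reduce(S, values, 0.0)            # 0 is the neutral element

def from_generator(s, s_inv):
    """Archimedean t-conorm S(a, b) = s^{-1}(min(s(1), s(a) + s(b))) (Definition 3)."""
    return lambda a, b: s_inv(min(s(1.0), s(a) + s(b)))

# Additive generators: s(x) = -ln(1 - x) for S_P (with s(1) = inf), s(x) = x for S_L
S_P_gen = from_generator(lambda x: math.inf if x >= 1.0 else -math.log1p(-x),
                         lambda y: -math.expm1(-y))
S_L_gen = from_generator(lambda x: x, lambda y: y)

grid = [i / 10 for i in range(11)]
pairs = [(x, y) for x in grid for y in grid]
tol = 1e-12
for S in (S_P, S_L):   # Remark 2: S_M <= S <= S_D
    assert all(S_M(x, y) - tol <= S(x, y) <= S_D(x, y) + tol for x, y in pairs)
assert all(abs(S_P_gen(x, y) - S_P(x, y)) < tol for x, y in pairs)
assert all(abs(S_L_gen(x, y) - S_L(x, y)) < tol for x, y in pairs)
print(n_ary(S_P, [0.5, 0.5, 0.5]))
```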
Definition 6. [13] Let A be a non-decreasing binary function defined on $X^2$ and B a non-decreasing n-ary function defined on $X^n$. We say that A dominates B if for all $\vec a, \vec b \in [0,1]^n$,
$$A(B(\vec a), B(\vec b)) \ge B(A(a_1,b_1),\dots,A(a_n,b_n)).$$
In the next section, we study the aggregation of S-generalized distances based on the dominance relation, triangular triplets and their extended forms.
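Dominance (Definition 6) can be explored numerically. The plain-Python sketch below, with names of our choosing, tests the inequality on a finite grid: a failure yields a genuine counterexample, while a pass is only evidence. It confirms that S_P dominates the maximum S_M, which is the property Example 1 relies on when S = S_P. It also finds that S_M does not dominate S_P. For instance, a = (0, 0.1) and b = (0.1, 0) give 0.1 on the left-hand side but S_P(0.1, 0.1) = 0.19 on the right.

```python
import itertools

def S_M(x, y): return max(x, y)
def S_P(x, y): return x + y - x * y

def dominates(A, B, n=2, steps=10, tol=1e-12):
    """Grid check of Definition 6: A(B(a), B(b)) >= B(A(a1, b1), ..., A(an, bn)).

    Returns (True, None) if no violation is found on the grid, otherwise
    (False, (a, b)) with a violating pair. A pass is evidence, not a proof.
    """
    pts = [i / steps for i in range(steps + 1)]
    for a in itertools.product(pts, repeat=n):
        for b in itertools.product(pts, repeat=n):
            lhs = A(B(*a), B(*b))
            rhs = B(*(A(ai, bi) for ai, bi in zip(a, b)))
            if lhs < rhs - tol:
                return False, (a, b)
    return True, None

print(dominates(S_P, S_M))   # S_P dominates max on the grid
print(dominates(S_M, S_P))   # max does not dominate S_P: a counterexample is returned
```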
3 Aggregation of S-generalized Distances
In this section, we focus on the aggregation of S-generalized distances. First, we recall the concept of an S-generalized distance aggregation function.
Definition 7. A function $F: [0,1]^n \to [0,1]$ is said to aggregate a family of $S_i$-generalized distances ($i = 1,\dots,n$) into an S-generalized distance if $F(d_1,\dots,d_n)$ is an S-generalized distance on X for any set X and any collection of $S_i$-generalized distances $d_1,\dots,d_n$ on X. Here $F(d_1,\dots,d_n)$ is given by $F(d_1,\dots,d_n)(x,y) = F(d_1(x,y),\dots,d_n(x,y))$ for $x, y \in X$. We denote the set of all such functions F by $\mathcal{F}^S_{S_1,\dots,S_n}$.
Next, we extend the concept of the dominance relation and use it to characterize S-generalized distance aggregation functions.
Definition 8. Let A and $A_1,\dots,A_n$ be non-decreasing binary functions defined on $X^2$, and B a non-decreasing n-ary function defined on $X^n$. We say that A extended dominates B with respect to $(A_1,\dots,A_n)$ if for all $\vec a, \vec b \in [0,1]^n$,
$$A(B(\vec a), B(\vec b)) \ge B(A_1(a_1,b_1),\dots,A_n(a_n,b_n)). \tag{1}$$
Remark 3. (i) If $A = A_1 = \dots = A_n$, then the extended dominance relation reduces to the dominance relation of Definition 6.
(ii) If A dominates B and $A \ge A_i$ for $i = 1,\dots,n$, then A extended dominates B with respect to $(A_1,\dots,A_n)$. Indeed, if A dominates B, then for all $\vec a, \vec b \in [0,1]^n$, $A(B(\vec a), B(\vec b)) \ge B(A(a_1,b_1),\dots,A(a_n,b_n))$. Since $A \ge A_i$ and B is non-decreasing, $B(A(a_1,b_1),\dots,A(a_n,b_n)) \ge B(A_1(a_1,b_1),\dots,A_n(a_n,b_n))$. Hence $A(B(\vec a), B(\vec b)) \ge B(A_1(a_1,b_1),\dots,A_n(a_n,b_n))$.
Theorem 1. Let S be a t-conorm and let $d_i$, $i = 1,\dots,n$, be a collection of $S_i$-generalized distances. Suppose a function $F: [0,1]^n \to [0,1]$ is non-decreasing and amenable, and S extended dominates F with respect to $(S_1,\dots,S_n)$, that is,
$$S(F(\vec a), F(\vec b)) \ge F(S_1(a_1,b_1),\dots,S_n(a_n,b_n)) \tag{2}$$
for all $\vec a, \vec b \in [0,1]^n$. Then $F(d_1,\dots,d_n)$ is an S-generalized distance.
Proof. Symmetry follows directly from the symmetry of each $d_i$. Therefore, to prove that $F(d_1,\dots,d_n)$ is an S-generalized distance, we only need to verify properties (i) and (iii) of Definition 4.
– First, we verify (i). As each $d_i$ is an $S_i$-generalized distance, $d_i(x,x) = 0$ for all $i = 1,\dots,n$, so
$$F(d_1,\dots,d_n)(x,x) = F(d_1(x,x),\dots,d_n(x,x)) = F(0,\dots,0). \tag{3}$$
Since F is amenable, $F(0,\dots,0) = 0$, and hence $F(d_1,\dots,d_n)(x,x) = 0$.
– Finally, we verify (iii). As each $d_i$ is an $S_i$-generalized distance, $d_i(x,z) \le S_i(d_i(x,y), d_i(y,z))$ for all $i = 1,\dots,n$. Since F is non-decreasing, we get
$$F(d_1(x,z),\dots,d_n(x,z)) \le F(S_1(d_1(x,y),d_1(y,z)),\dots,S_n(d_n(x,y),d_n(y,z))). \tag{4}$$
Taking $\vec a = (a_i) = (d_i(x,y))$ and $\vec b = (b_i) = (d_i(y,z))$, we obtain
$$\begin{aligned} S(F(d_1,\dots,d_n)(x,y), F(d_1,\dots,d_n)(y,z)) &= S(F(d_1(x,y),\dots,d_n(x,y)), F(d_1(y,z),\dots,d_n(y,z))) \\ &= S(F(\vec a), F(\vec b)) \\ &\ge F(S_1(a_1,b_1),\dots,S_n(a_n,b_n)) \\ &= F(S_1(d_1(x,y),d_1(y,z)),\dots,S_n(d_n(x,y),d_n(y,z))) \\ &\ge F(d_1(x,z),\dots,d_n(x,z)). \end{aligned}$$
So $F(d_1,\dots,d_n)$ is an S-generalized distance.
Example 1. Let $S = S_1 = S_2$, let $d_i$ ($i = 1, 2$) be $S_i$-generalized distances on X, and let $F = S_M$. Clearly, F is non-decreasing and amenable, and S extended dominates F with respect to $(S_1, S_2)$. As in the proof of Theorem 1, symmetry follows from the symmetry of each $d_i$, so we only need to verify (i) and (iii) of Definition 4. Since F is amenable, (i) holds. Since S extended dominates F with respect to $(S_1, S_2)$, we have
$$\begin{aligned} S(S_M(d_1(x,y),d_2(x,y)), S_M(d_1(y,z),d_2(y,z))) &\ge S_M(S(d_1(x,y),d_1(y,z)), S(d_2(x,y),d_2(y,z))) \\ &= S_M(S_1(d_1(x,y),d_1(y,z)), S_2(d_2(x,y),d_2(y,z))) \\ &\ge S_M(d_1(x,z), d_2(x,z)) \end{aligned}$$
for any $x, y, z \in X$. So $F(d_1, d_2)$ is an S-generalized distance.
Next, let us consider necessary conditions for Theorem 1.
Remark 4. If F is an S-generalized distance aggregation function, then in some cases S extended dominates F with respect to $(S_1,\dots,S_n)$.
Example 2. Let $S = S_P$ and let $d_i$ ($i = 1, 2$) be $S_i$-generalized distances. Consider the function $F: [0,1]^2 \to [0,1]$ defined by
$$F(\vec a) = \begin{cases} 0 & \text{if } \vec a = (0,0), \\ 2/3 & \text{if } \vec a = (1/3, 1/3), \\ 1/2 & \text{otherwise,} \end{cases} \tag{5}$$
for all $\vec a \in [0,1]^2$.
First, we verify that $F \in \mathcal{F}^{S_P}_{S_1,S_2}$. For any $x, y, z \in X$, three cases are considered:
– If $x = y \ne z$, then $S_P(F(d_1(x,y),d_2(x,y)), F(d_1(y,z),d_2(y,z))) = S_P(F(0,0), F(d_1(y,z),d_2(y,z))) = F(d_1(y,z),d_2(y,z)) \ge F(d_1(x,z),d_2(x,z))$.
– If $x = y = z$, then $S_P(F(d_1(x,y),d_2(x,y)), F(d_1(y,z),d_2(y,z))) = S_P(F(0,0), F(0,0)) = F(d_1(x,z),d_2(x,z))$.
– If $x \ne y \ne z$, then $S_P(F(d_1(x,y),d_2(x,y)), F(d_1(y,z),d_2(y,z))) = F(d_1(x,y),d_2(x,y)) + F(d_1(y,z),d_2(y,z)) - F(d_1(x,y),d_2(x,y)) \cdot F(d_1(y,z),d_2(y,z)) \ge F(d_1(x,z),d_2(x,z))$.
The other situations are similar, so $F \in \mathcal{F}^{S_P}_{S_1,S_2}$. In this case, S extended dominates F, that is,
$$S_P(F(d_1(x,y),d_2(x,y)), F(d_1(y,z),d_2(y,z))) \ge F(S_1(d_1(x,y),d_1(y,z)), S_2(d_2(x,y),d_2(y,z)))$$
for any $x, y, z \in X$. Again, three cases are considered:
– If $x = y \ne z$, then $S_P(F(d_1(x,y),d_2(x,y)), F(d_1(y,z),d_2(y,z))) = S_P(F(0,0), F(d_1(y,z),d_2(y,z))) = F(d_1(y,z),d_2(y,z)) = F(S_1(0,d_1(y,z)), S_2(0,d_2(y,z)))$.
– If $x = y = z$, then $S_P(F(d_1(x,y),d_2(x,y)), F(d_1(y,z),d_2(y,z))) = S_P(F(0,0), F(0,0)) = F(S_1(0,0), S_2(0,0)) = F(S_1(d_1(x,y),d_1(y,z)), S_2(d_2(x,y),d_2(y,z)))$.
– If $x \ne y \ne z$, then $\min S_P(F(d_1(x,y),d_2(x,y)), F(d_1(y,z),d_2(y,z))) = 3/4 \ge 2/3 = \max F(\vec a) \ge F(S_1(d_1(x,y),d_1(y,z)), S_2(d_2(x,y),d_2(y,z)))$.
We conclude that S extended dominates F in this case.
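On a three-point set X = {x, y, z}, the triangle inequality of Definition 4 involves only the triplet (d(x,y), d(y,z), d(x,z)). This is why the paper's Definition 9 and Theorem 2 work with triplets. The plain-Python sketch below, with names of our own, checks Example 2 in that form on an exact rational grid. It verifies that whenever vector triplets are S_i-triangular in each coordinate, the images under F of Eq. (5) are S_P-triangular. Example 2 does not fix S_1 and S_2, so we assume S_1 = S_2 = S_P. The sketch also confirms that F is not non-decreasing (F(1/3, 1/3) = 2/3 > F(1/2, 1/2) = 1/2), so Theorem 1 does not cover it directly.

```python
import itertools
from fractions import Fraction

def S_P(x, y):
    return x + y - x * y

def is_S_triangular(a, b, c, S):
    """1-dimensional S-triangular triplet: each value is bounded by S of the other two."""
    return a <= S(b, c) and b <= S(a, c) and c <= S(a, b)

def F_example2(a1, a2):
    """The function F of Eq. (5)."""
    if a1 == 0 and a2 == 0:
        return Fraction(0)
    if a1 == Fraction(1, 3) and a2 == Fraction(1, 3):
        return Fraction(2, 3)
    return Fraction(1, 2)

# Exact rational grid containing the special points 0, 1/3 and 1/2
grid = [Fraction(k, 6) for k in (0, 1, 2, 3, 6)]
vectors = list(itertools.product(grid, repeat=2))

checked = 0
for a, b, c in itertools.product(vectors, repeat=3):
    # keep triplets that are S_i-triangular in each coordinate (S1 = S2 = S_P assumed)
    if all(is_S_triangular(a[i], b[i], c[i], S_P) for i in range(2)):
        checked += 1
        assert is_S_triangular(F_example2(*a), F_example2(*b), F_example2(*c), S_P)

# F is not non-decreasing, so Theorem 1 does not apply to it directly
half, third = Fraction(1, 2), Fraction(1, 3)
assert F_example2(third, third) > F_example2(half, half)
print(checked, "triangular triplets checked")
```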
Next, we introduce the concept of an S-triangular triplet. It extends the triangular triplet and can be used to characterize S-generalized distance aggregation functions.
Definition 9. A triplet $(a, b, c) \in [0,1]^3$ is said to be (1-dimensional) S-triangular if $a \le S(b,c)$, $b \le S(a,c)$ and $c \le S(a,b)$. Let $\vec a, \vec b, \vec c \in [0,1]^n$, $n \ge 1$. We say that $(\vec a, \vec b, \vec c)$ is an (n-dimensional) S-triangular triplet if $(a_i, b_i, c_i)$ is an S-triangular triplet for all $i = 1,\dots,n$, where $\vec a = (a_1,\dots,a_n)$, $\vec b = (b_1,\dots,b_n)$ and $\vec c = (c_1,\dots,c_n)$.
Theorem 2. Consider a function $F: [0,1]^n \to [0,1]$. The following assertions are equivalent.
(i) $F \in \mathcal{F}^S_{S_1,\dots,S_n}$.
(ii) F satisfies the following properties:
 a) $F(0,\dots,0) = 0$;
 b) if $(\vec a, \vec b, \vec c)$ is an n-dimensional triplet such that $(a_i, b_i, c_i)$ is an $S_i$-triangular triplet for each $i = 1,\dots,n$, then $(F(\vec a), F(\vec b), F(\vec c))$ is a 1-dimensional S-triangular triplet.
Proposition 1. Let $\mathcal{S}$ be the set of all t-conorms and, for each $S \in \mathcal{S}$, denote by $S^*$ the set of all S-triangular triplets. Then the partially ordered sets $(\mathcal{S}, \le)$ and $(\{S^* : S \in \mathcal{S}\}, \subseteq)$ are order-isomorphic.
Remark 5. By Remark 2, $S_M^* \subseteq S^* \subseteq S_D^*$ for any t-conorm S.
Next, we characterize the functions that can merge a collection of S-generalized distances into a new one when all the involved t-conorms are strict.
Theorem 3. Let $S, S_1,\dots,S_n$ be strict t-conorms and $F: [0,1]^n \to [0,1]$. The following assertions are equivalent:
(i) F aggregates a collection of $S_i$-generalized distances $d_i$ into an S-generalized distance, $i = 1,\dots,n$.
(ii) There exists a function $G: [0,+\infty]^n \to [0,+\infty]$ which transforms n-dimensional triangular triplets into 1-dimensional triangular triplets and satisfies
$$G = s \circ F \circ \left(s_1^{-1} \times \dots \times s_n^{-1}\right), \tag{6}$$
where s and $s_i$ are additive generators of the t-conorms S and $S_i$, respectively.
Proof. (i) ⇒ (ii). Suppose that F aggregates a collection of $S_i$-generalized distances into an S-generalized distance. We prove that G transforms n-dimensional triangular triplets into 1-dimensional triangular triplets. To this end, assume that $(\vec a, \vec b, \vec c)$, with $\vec a, \vec b, \vec c \in [0,\infty]^n$, is an n-dimensional triangular triplet. Define fuzzy relations $d_i$, $i = 1,\dots,n$, on the set $X = \{x, y, z\}$ of three distinct elements by
$$d_i(x,y) = d_i(y,x) = s_i^{-1}(b_i), \quad d_i(x,z) = d_i(z,x) = s_i^{-1}(a_i), \quad d_i(z,y) = d_i(y,z) = s_i^{-1}(c_i),$$
and $d_i(x,x) = d_i(y,y) = d_i(z,z) = 0$. It is not hard to check that each $d_i$ is an $S_i$-generalized distance. Next, we prove that $(G(\vec a), G(\vec b), G(\vec c))$ is a 1-dimensional triangular triplet. Since F aggregates the collection of $S_i$-generalized distances into an S-generalized distance, we have
$$S(F(d_1,\dots,d_n)(x,y), F(d_1,\dots,d_n)(y,z)) \ge F(d_1,\dots,d_n)(x,z).$$
Because S is strict, $S(u,v) = s^{-1}(s(u) + s(v))$, so
$$s^{-1}\big(s \circ F(d_1,\dots,d_n)(x,y) + s \circ F(d_1,\dots,d_n)(y,z)\big) \ge F(d_1,\dots,d_n)(x,z).$$
Applying the increasing function s gives
$$s \circ F(d_1(x,y),\dots,d_n(x,y)) + s \circ F(d_1(y,z),\dots,d_n(y,z)) \ge s \circ F(d_1(x,z),\dots,d_n(x,z)).$$
Since $d_i(x,y) = s_i^{-1}(b_i)$, $d_i(x,z) = s_i^{-1}(a_i)$ and $d_i(y,z) = s_i^{-1}(c_i)$, this yields
$$s \circ F\big(s_1^{-1}(b_1),\dots,s_n^{-1}(b_n)\big) + s \circ F\big(s_1^{-1}(c_1),\dots,s_n^{-1}(c_n)\big) \ge s \circ F\big(s_1^{-1}(a_1),\dots,s_n^{-1}(a_n)\big),$$
that is, $G(\vec a) \le G(\vec b) + G(\vec c)$. The other two inequalities follow in the same way from the triangle inequalities at the other vertices. Hence G transforms an n-dimensional triangular triplet into a 1-dimensional triangular triplet.
(ii) ⇒ (i). Assume that G transforms n-dimensional triangular triplets into 1-dimensional triangular triplets. We prove that $F(d_1,\dots,d_n)$ is an S-generalized distance for every collection $d_i$ of $S_i$-generalized distances, $i = 1,\dots,n$. To this end, consider such a collection on a non-empty set X. For $x, y, z \in X$, set $d_i(x,y) = d_i(y,x) = a_i$, $d_i(z,y) = d_i(y,z) = b_i$ and $d_i(x,z) = d_i(z,x) = c_i$. We have $S_i(d_i(x,y), d_i(y,z)) \ge d_i(x,z)$, $S_i(d_i(x,z), d_i(z,y)) \ge d_i(x,y)$ and $S_i(d_i(y,x), d_i(x,z)) \ge d_i(y,z)$. Therefore $s_i(a_i) + s_i(b_i) \ge s_i(c_i)$, $s_i(c_i) + s_i(b_i) \ge s_i(a_i)$ and $s_i(c_i) + s_i(a_i) \ge s_i(b_i)$. Thus $(s_i(a_i), s_i(b_i), s_i(c_i))$ is a 1-dimensional triangular triplet for each i. Hence $(\vec s(\vec a), \vec s(\vec b), \vec s(\vec c))$ is an n-dimensional triangular triplet, where $\vec s(\vec a) = (s_1(a_1),\dots,s_n(a_n))$, $\vec s(\vec b) = (s_1(b_1),\dots,s_n(b_n))$ and $\vec s(\vec c) = (s_1(c_1),\dots,s_n(c_n))$. Therefore $(G(\vec s(\vec a)), G(\vec s(\vec b)), G(\vec s(\vec c)))$ is a 1-dimensional triangular triplet, and in particular
$$G(\vec s(\vec a)) \le G(\vec s(\vec b)) + G(\vec s(\vec c)), \qquad G(\vec s(\vec c)) \le G(\vec s(\vec b)) + G(\vec s(\vec a)).$$
Since $G = s \circ F \circ (s_1^{-1} \times \dots \times s_n^{-1})$, we have $s(F(\vec a)) = G(\vec s(\vec a)) \le G(\vec s(\vec b)) + G(\vec s(\vec c)) = s(F(\vec b)) + s(F(\vec c))$.
Aggregation of S-generalized Distances
From the preceding inequality we deduce that
$$F(\vec{a}) \le s^{-1}\big(s(F(\vec{b})) + s(F(\vec{c}))\big) = S\big(F(\vec{b}), F(\vec{c})\big),$$
where $F(\vec{a}) = F(a_1,\ldots,a_n) = F(d_1,\ldots,d_n)(x,y)$, and similarly $F(\vec{b}) = F(d_1,\ldots,d_n)(z,y)$ and $F(\vec{c}) = F(d_1,\ldots,d_n)(x,z)$. Therefore, we obtain $S\big(F(d_1,\ldots,d_n)(z,y), F(d_1,\ldots,d_n)(x,z)\big) \ge F(d_1,\ldots,d_n)(x,y)$. Similarly, we can show that $S\big(F(d_1,\ldots,d_n)(x,y), F(d_1,\ldots,d_n)(y,z)\big) \ge F(d_1,\ldots,d_n)(x,z)$ and $S\big(F(d_1,\ldots,d_n)(y,x), F(d_1,\ldots,d_n)(x,z)\big) \ge F(d_1,\ldots,d_n)(y,z)$. Moreover, $F(d_1,\ldots,d_n)(x,y) = F(d_1,\ldots,d_n)(y,x)$, since every $d_i$ is an $S_i$-generalized distance and is therefore symmetric. Consequently, $F$ aggregates $S_i$-generalized distances into an $S$-generalized distance.

Example 3. Aggregation of two $S_P$-generalized distances $(d_1, d_2)$, defined with respect to the strict t-conorm $S_P$, into a generalized distance of the same class ($S_1 = S_2 = S = S_P$). The additive generator of $S_P$ is $s(x) = -\ln(1 - x)$, and its inverse is $s^{-1}(x) = 1 - e^{-x}$. When $G = \mathrm{Max}$, we have
$$F(d_1, d_2) = s^{-1} \circ G \circ (s_1 \times s_2)(d_1, d_2) = s^{-1}\big(\mathrm{Max}(s_1(d_1), s_2(d_2))\big) = \mathrm{Max}(d_1, d_2),$$
where the last equality holds because $s_1 = s_2 = s$ and $s$ is strictly increasing. That is, $F = \mathrm{Max}$.
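Example 3 can be checked numerically. The sketch below is illustrative, not the authors' code. It implements the generator $s$, its inverse, the t-conorm $S_P$ they induce (the probabilistic sum $x + y - xy$), and $F = s^{-1} \circ \mathrm{Max} \circ (s \times s)$. It then builds random pairs of $S_P$-generalized distances on $\{x, y, z\}$, following the construction in the proof: a triangular triplet is chosen in generator space and mapped back through $s^{-1}$. The helper names and the sampling scheme are assumptions made for the illustration.

```python
import math
import random

def s(x: float) -> float:
    """Additive generator of S_P: s(x) = -ln(1 - x), with s(1) = +inf."""
    return math.inf if x >= 1.0 else -math.log(1.0 - x)

def s_inv(t: float) -> float:
    """Inverse generator: s^{-1}(t) = 1 - e^{-t}."""
    return 1.0 - math.exp(-t)

def S_P(x: float, y: float) -> float:
    """Strict t-conorm generated by s; equals the probabilistic sum x + y - x*y."""
    return s_inv(s(x) + s(y))

def F(d1: float, d2: float) -> float:
    """F = s^{-1} o Max o (s_1 x s_2) with S_1 = S_2 = S = S_P (Example 3)."""
    return s_inv(max(s(d1), s(d2)))

def check_example3(trials: int = 1000, seed: int = 0, tol: float = 1e-9) -> bool:
    rng = random.Random(seed)
    for _ in range(trials):
        # F collapses to Max, and S_P is the probabilistic sum.
        d1, d2 = rng.uniform(0.0, 0.99), rng.uniform(0.0, 0.99)
        if abs(F(d1, d2) - max(d1, d2)) > tol:
            return False
        if abs(S_P(d1, d2) - (d1 + d2 - d1 * d2)) > tol:
            return False
        # Two S_P-generalized distances on X = {x, y, z}: place x, y, z at random
        # points u, v, w of a line, so (|u - v|, |u - w|, |v - w|) is a triangular
        # triplet in generator space, then map each value back with s^{-1}.
        dxy, dxz, dyz = [], [], []
        for _ in range(2):
            u, v, w = rng.uniform(0.0, 3.0), rng.uniform(0.0, 3.0), rng.uniform(0.0, 3.0)
            dxy.append(s_inv(abs(u - v)))
            dxz.append(s_inv(abs(u - w)))
            dyz.append(s_inv(abs(v - w)))
        fxy, fxz, fyz = F(*dxy), F(*dxz), F(*dyz)
        # The aggregated relation must satisfy all three S_P-triangle inequalities.
        if (S_P(fxy, fyz) < fxz - tol or S_P(fxz, fyz) < fxy - tol
                or S_P(fxy, fxz) < fyz - tol):
            return False
    return True

print(check_example3())  # True
```

The small tolerance only absorbs floating-point rounding. The inequalities themselves are exact because $\max_i b_i + \max_i c_i \ge \max_i a_i$ whenever $a_i \le b_i + c_i$ for every $i$.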
4 Conclusion

We have addressed the problem of aggregating a collection of $S_i$-generalized distances into an $S$-generalized distance, where the t-conorms involved may differ. We characterized $S$-generalized distance aggregation functions by means of triangular triplets and an extended dominance relation. Moreover, we characterized these aggregation functions when $S$ and the $S_i$ are strict t-conorms. This characterization is given in terms of the additive generators of all the t-conorms involved and of a function that transforms $n$-dimensional triangular triplets into 1-dimensional triangular triplets. As future work, we will study whether some concrete families of aggregation functions are $S$-generalized distance aggregation functions. We will also consider how an $S$-generalized distance should be defined when the t-conorm $S$ is expressed in ordinal-sum form.

Acknowledgements. This work is supported by the National Natural Science Foundation of China under Grant 61977040 and the Natural Science Foundation of Shandong Province under Grant ZR2019MF055.
Lagrange Heuristic Algorithm Incorporated with Decomposition Strategy for Green Multi-depot Heterogeneous-Fleet Vehicle Routing Problem

Linhao Xu, Bin Qian(✉), Rong Hu, Naikang Yu, and Huaiping Jin

Faculty of Information Engineering and Automation, Kunming University of Science and Technology, Kunming 650500, China

Abstract. For the green multi-depot heterogeneous-fleet vehicle routing problem (GMDHFVRP), a mixed integer programming (MIP) model is established to minimize the total cost, and a Lagrange heuristic algorithm incorporated with a decomposition strategy (LHA_DS) is proposed to solve it. First, the original problem is decomposed into a number of vehicle sub-problems by relaxing the complicating constraints. A distance-based K-means clustering method assigns a set of customers to each parking lot, and the surrogate subgradient method is used to solve the sub-problems and update the Lagrange multipliers. Second, to obtain a good initial feasible solution, a greedy repair strategy converts the relaxed solution into a feasible one, and a neighborhood search method that determines the order of its operations is designed to improve the feasible solution. Finally, simulation experiments verify the effectiveness of LHA_DS in solving GMDHFVRP.

Keywords: green multi-depot heterogeneous-fleet vehicle routing problem · mixed integer programming · Lagrange heuristic · surrogate subgradient
1 Introduction

The vehicle routing problem (VRP) was first proposed by Dantzig and Ramser in 1959 [1]. As a classical combinatorial optimization problem, the VRP has been studied by many scholars. With the rapid development of the economy, modern logistics distribution systems have grown quickly, and the traditional VRP model no longer fits them well, because punctuality and efficiency have become increasingly important. Cross-regional joint distribution by multiple fleets can provide better logistics services for customers [2, 3]. In addition, against the background of green development, it is important to consider environmental factors in actual distribution. Studying the green multi-depot heterogeneous-fleet vehicle routing problem (GMDHFVRP) therefore has important practical significance. Moreover, GMDHFVRP is NP-hard, so research on it also has significant theoretical value.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 537–548, 2023. https://doi.org/10.1007/978-981-99-4761-4_46
L. Xu et al.
GMDHFVRP is a complex VRP variant that arises in practice, and several scholars have studied related problems. Baldacci and Mingozzi [4] studied the heterogeneous-fleet VRP, the site-dependent VRP and the multi-depot VRP, and proposed a Lagrange relaxation algorithm to solve all three classes. Their experimental results show that the lower bounds obtained by the algorithm for the three classes are better than those reported in other studies. Contardo and Martinelli [5] addressed the multi-depot VRP with capacity and route length constraints, with the goal of minimizing the total cost. They formulated vehicle-flow and set-partitioning models for the problem and solved it with a cutting-plane method and a column-and-cut generation method, and their experimental results verified the effectiveness of the proposed algorithm. Muter et al. [6] used an improved column generation algorithm to solve the multi-depot VRP with interdepot routes, taking the route cost as the optimization objective. Bettinelli et al. [7] established an integer linear model that minimizes route cost for the multi-depot heterogeneous VRP with time windows, and proposed a branch-and-cut-and-price (BCP) algorithm to solve it. At present, research on GMDHFVRP remains very limited in three respects: decomposition-based solution methods, further optimization of feasible solutions, and evaluation indicators for the resulting solutions. The problem therefore merits further research.

Lagrange relaxation (LR) is an effective method for solving complex combinatorial optimization problems. The main idea of LR is to move the difficult constraints of the problem into the objective function, weighted by Lagrange multipliers, to obtain a relaxed problem. Decomposing the relaxed problem into several sub-problems then makes it easy to solve. However, the solution of the relaxed problem obtained by traditional LR is generally infeasible for the original problem, so a corresponding algorithm must be designed to repair it into a feasible solution. By combining heuristic methods with LR, this paper designs a combined algorithm to solve GMDHFVRP. The main contributions of this paper are as follows: 1) a mixed integer programming model of GMDHFVRP is established; 2) the problem is decomposed and solved, and repair and neighborhood search algorithms are designed to improve the quality of feasible solutions. In short, for GMDHFVRP as it arises in practice, this paper establishes a mixed integer programming model that minimizes the total transportation cost, and proposes a Lagrange heuristic algorithm incorporated with a decomposition strategy (LHA_DS) to solve it. The rest of this paper is organized as follows: Sect. 2 briefly introduces GMDHFVRP and its mathematical model, Sect. 3 presents the design of the algorithm for solving GMDHFVRP, Sect. 4 reports simulation experiments, and Sect. 5 concludes the paper.
2 GMDHFVRP Description and Model

2.1 Relevant Descriptions and Assumptions of GMDHFVRP

GMDHFVRP can be described as follows. The inputs are the number of customers, the customer demands, the coordinates of the customers and depots, and the available fleet. The task is to plan the delivery routes of the vehicles so that all customers are served and the total distribution cost is minimized.
GMDHFVRP problem assumptions:
(1) Every parking lot (depot) has the same vehicle types and the same number of vehicles of each type, and all vehicles of all types in all parking lots take part in the distribution task.
(2) Each vehicle makes exactly one delivery trip.
(3) Each vehicle starts from its parking lot and must return to the same parking lot after completing its service.
(4) Each customer is served exactly once, by one vehicle.
(5) The load of a vehicle cannot exceed the capacity of its vehicle type.
(6) Vehicles travel at a constant speed, and the effects of road conditions and other factors on vehicles are ignored.

The parameters and related definitions of GMDHFVRP are shown in Table 1, and the relevant parameter values are shown in Table 2.

Table 1. Symbols and definitions.
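| Symbol | Definition |
|---|---|
| $Z$ | Total delivery cost |
| $c_m$ | Unit carbon emission cost (0.2/kg) |
| $N$ | Number of customers |
| $i, j$ | Node index, $i, j \in \{0, 1, 2, \ldots, N\}$ |
| $V$ | Customer set $\{1, 2, \ldots, N\}$ |
| $p$ | Vehicle type (model) index |
| $D$ | Number of parking lots |
| $f$ | Parking lot index |
| $D_o$ | Parking lot set $\{1, 2, \ldots, D\}$ |
| $m$ | Vehicle index |
| $M_c$ | Number of vehicle types |
| $Q_p$ | Load capacity of a type-$p$ vehicle |
| $M_s$ | Vehicle type set $\{1, 2, \ldots, M_c\}$ |
| $q_i$ | Demand of node $i$ |
| $H_p$ | Number of type-$p$ vehicles |
| $u_i$ | Continuous MTZ variable |
| $H_{ps}$ | Set of type-$p$ vehicles $\{1, \ldots, H_p\}$ |
| $d_{ij}$ | Distance from node $i$ to node $j$ |
| $c_1$ | Unit distance cost (1/km) |
| $w_{fpmij}$ | Vehicle load on arc $(i, j)$ |
| $c_p$ | Fixed cost of a type-$p$ vehicle |
| $x_{fpmij}$ | Binary variable indicating whether the vehicle travels from node $i$ to node $j$ |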
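2.2 GMDHFVRP Mixed Integer Programming Model

Optimization objective:
$$
\begin{aligned}
\min Z ={}& \sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{i=0}^{N}\sum_{j=0}^{N} c_1 x_{fpmij} d_{ij} + \sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{j=1}^{N} c_p x_{fpm0j} \\
&+ c_m \sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{i=0}^{N}\sum_{j=0}^{N} e\,\frac{\phi}{\mu\zeta}\cdot\frac{\alpha}{1000\varepsilon\eta}\, d_{ij} w_{fpmij} \\
&+ c_m \sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{i=0}^{N}\sum_{j=0}^{N} e\,\frac{\phi}{\mu\zeta}\left(\frac{\alpha W_p + \beta v^2}{1000\varepsilon\eta} + \frac{\varsigma S V_p}{v}\right) d_{ij} x_{fpmij}
\end{aligned}
\tag{1}
$$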
$$\alpha = a + g\sin\theta + g C_r \cos\theta \tag{1.1}$$
$$\beta = 0.5\, C_d A \rho \tag{1.2}$$
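Problem constraints:
$$\sum_{j=1}^{N} x_{fpm0j} = \sum_{j=1}^{N} x_{fpmj0} = 1, \quad \forall p \in M_s, \forall m \in H_{ps}, \forall f \in D_o \tag{2}$$
$$\sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{i=0}^{N} x_{fpmij} = 1, \quad i \ne j, \forall j \in V\setminus\{0\} \tag{3}$$
$$\sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{j=0}^{N} x_{fpmij} = 1, \quad i \ne j, \forall i \in V\setminus\{0\} \tag{4}$$
$$\sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{i=0}^{N} x_{fpmii} = 0 \tag{5}$$
$$\sum_{i=0}^{N} x_{fpmij} = \sum_{i=0}^{N} x_{fpmji}, \quad \forall j \in V\setminus\{0\}, \forall p \in M_s, \forall m \in H_{ps}, \forall f \in D_o \tag{6}$$
$$u_i - u_j + 1 - N\Big(1 - \sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p} x_{fpmij}\Big) \le 0, \quad i \ne j, \forall i \in V\setminus\{0\}, \forall j \in V\setminus\{0\} \tag{7}$$
$$w_{fpmij} \le x_{fpmij} Q_p, \quad i \ne j, \forall f \in D_o, \forall p \in M_s, \forall m \in H_{ps}, \forall i \in V, \forall j \in V \tag{8}$$
$$\sum_{i=1}^{N}\sum_{j=0}^{N} x_{fpmij} q_i = \sum_{j=1}^{N} w_{fpm0j}, \quad \forall f \in D_o, \forall p \in M_s, \forall m \in H_{ps} \tag{9}$$
$$\sum_{i=0}^{N} w_{fpmij} - q_j \sum_{i=0}^{N} x_{fpmij} = \sum_{i=0}^{N} w_{fpmji}, \quad \forall j \in V\setminus\{0\}, \forall f \in D_o, \forall p \in M_s, \forall m \in H_{ps}, i \ne j \tag{10}$$
$$w_{fpmij} \ge 0, \quad \forall i \in V, \forall j \in V, \forall f \in D_o, \forall p \in M_s, \forall m \in H_{ps} \tag{11}$$
$$x_{fpmij} = \begin{cases} 1, & \text{if the } m\text{-th vehicle of type } p \text{ in parking lot } f \text{ moves from node } i \text{ to node } j,\\ 0, & \text{otherwise,} \end{cases} \quad i \ne j, \forall i \in V, \forall j \in V, \forall f \in D_o, \forall p \in M_s, \forall m \in H_{ps} \tag{12}$$

In the above model, Eq. (1) is the objective function, and Eqs. (1.1) and (1.2) define the road-specific constant $\alpha$ and the vehicle-specific constant $\beta$, respectively.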
Equation (2) requires all vehicles to take part in the distribution and to return to their original parking lot afterwards. Equations (3), (4) and (5) together require that each customer is served exactly once, by a single vehicle from one parking lot. Equation (6) is the flow-balance constraint at customer nodes. Equation (7) is the MTZ subtour-elimination constraint. Equation (8) requires the load on each arc not to exceed the vehicle capacity. Equation (9) requires the load of each vehicle leaving its parking lot to equal the total demand of the customers it serves. Equation (10) states that the load of a vehicle after serving a node equals its load on arrival at that node minus the node's demand. Equations (11) and (12) define the decision variables.
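To make the emission terms of Eq. (1) concrete, the following sketch evaluates the cost contribution of one type-$p$ vehicle on a single route. It is an illustration, not the authors' implementation. The route format `[0, j1, ..., jk, 0]`, the function names and the dense distance matrix are assumptions. The parameter values are taken from Tables 1 and 2 as listed, with no unit conversion (Table 2 gives $W_p$ in t and $Q_p$ in kg). The load on each arc follows Eqs. (9) and (10).

```python
import math

# Table 2 values; list index 0, 1, 2 corresponds to vehicle types p = 1, 2, 3.
V = 14.0                            # v, vehicle speed
ENGINE_SPEED = 33.0                 # S
FRICTION = 0.2                      # ς, engine friction factor
DISPLACEMENT = [5.0, 7.0, 9.0]      # V_p
ETA, EPS = 0.9, 0.4                 # η, ε
E_CO2, MU, ZETA, PHI = 3.09, 44.0, 737.0, 1.0   # e, μ, ζ, φ
G, ACCEL, THETA = 9.8, 0.0, 0.0     # g, a, θ
CD, CR, AREA, RHO = 0.7, 0.01, 3.9, 1.2         # C_d, C_r, A, ρ
SELF_WEIGHT = [6.0, 8.0, 10.0]      # W_p
CAPACITY = [90.0, 160.0, 200.0]     # Q_p
FIXED_COST = [200.0, 300.0, 400.0]  # c_p
C1, CM = 1.0, 0.2                   # c_1 and c_m from Table 1

def alpha_beta():
    alpha = ACCEL + G * math.sin(THETA) + G * CR * math.cos(THETA)  # Eq. (1.1)
    beta = 0.5 * CD * AREA * RHO                                     # Eq. (1.2)
    return alpha, beta

def route_cost(route, p, dist, demand):
    """Eq. (1) contribution of one type-p vehicle driving route [0, j1, ..., jk, 0]."""
    alpha, beta = alpha_beta()
    k = E_CO2 * PHI / (MU * ZETA)          # e * φ / (μ ζ)
    gamma = 1.0 / (1000.0 * EPS * ETA)     # 1 / (1000 ε η)
    load = sum(demand[j] for j in route[1:-1])  # Eq. (9): leave with the route's total demand
    if load > CAPACITY[p]:
        raise ValueError("route violates capacity constraint (8)")
    empty_coef = k * (gamma * (alpha * SELF_WEIGHT[p] + beta * V ** 2)
                      + FRICTION * ENGINE_SPEED * DISPLACEMENT[p] / V)  # x-term factor
    load_coef = k * gamma * alpha                                       # w-term factor
    cost = FIXED_COST[p]
    for i, j in zip(route, route[1:]):
        d = dist[i][j]
        cost += C1 * d + CM * (empty_coef + load_coef * load) * d
        load -= demand[j]                  # Eq. (10): unload q_j at node j (q_0 = 0)
    return cost
```

The load term grows with the goods still on board. As a result, the same set of customers can cost different amounts depending on the visiting order, which the purely distance-based terms do not capture.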
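3 LHA_DS

3.1 Relaxation Constraint

Relaxing constraint (8) gives the relaxed problem
$$
\begin{aligned}
L(\lambda) = \min \Big\{ & \sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{i=0}^{N}\sum_{j=0}^{N} c_1 x_{fpmij} d_{ij} + \sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{j=1}^{N} c_p x_{fpm0j} \\
& + c_m \sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{i=0}^{N}\sum_{j=0}^{N} e\,\frac{\phi}{\mu\zeta}\cdot\frac{\alpha}{1000\varepsilon\eta}\, d_{ij} w_{fpmij} \\
& + c_m \sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{i=0}^{N}\sum_{j=0}^{N} e\,\frac{\phi}{\mu\zeta}\left(\frac{\alpha W_p + \beta v^2}{1000\varepsilon\eta} + \frac{\varsigma S V_p}{v}\right) d_{ij} x_{fpmij} \\
& + \sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{i=0}^{N}\sum_{j=0}^{N} \lambda_{fpmij}\,(w_{fpmij} - x_{fpmij} Q_p) \Big\}
\end{aligned}
\tag{13}
$$

From the relaxed problem (13) and the assumptions of GMDHFVRP, the fixed vehicle cost $\sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{j=1}^{N} c_p x_{fpm0j}$ is a constant. Constraint (2) makes every vehicle leave its parking lot exactly once, so this term equals $D\sum_{p=1}^{M_c} H_p c_p$. Therefore, the Lagrange dual (LD) problem is obtained by maximizing the relaxed problem after removing this term. The LD problem is further decomposed into vehicle sub-problems, each of which can be expressed as
$$
\begin{aligned}
Z_{fpm} ={}& \sum_{i=0}^{N}\sum_{j=0}^{N} c_1 x_{fpmij} d_{ij} + c_m \sum_{i=0}^{N}\sum_{j=0}^{N} e\,\frac{\phi}{\mu\zeta}\cdot\frac{\alpha}{1000\varepsilon\eta}\, d_{ij} w_{fpmij} \\
&+ c_m \sum_{i=0}^{N}\sum_{j=0}^{N} e\,\frac{\phi}{\mu\zeta}\left(\frac{\alpha W_p + \beta v^2}{1000\varepsilon\eta} + \frac{\varsigma S V_p}{v}\right) d_{ij} x_{fpmij} + \sum_{i=0}^{N}\sum_{j=0}^{N} \lambda_{fpmij}\,(w_{fpmij} - x_{fpmij} Q_p)
\end{aligned}
\tag{14}
$$
$$\max_{\lambda \in [0,+\infty)} L^*(\lambda) = \max_{\lambda \ge 0} \Big\{ \min \sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p} Z_{fpm} \Big\} \tag{15}$$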
Analyzing Eq. (14) shows that $Z_{fpm}$ can be decomposed further. Once constraint (8) is relaxed, the coupling between $x_{fpmij}$ and $w_{fpmij}$ can be eliminated by decomposition, which makes the problem easier to solve. $Z_{fpm}$ is thus split into two sub-problems, $Z^1_{fpm}$ and $Z^2_{fpm}$:
$$Z^1_{fpm} = \sum_{i=0}^{N}\sum_{j=0}^{N} (c_1 d_{ij} - \lambda_{fpmij} Q_p)\, x_{fpmij} + c_m \sum_{i=0}^{N}\sum_{j=0}^{N} e\,\frac{\phi}{\mu\zeta}\left(\frac{\alpha W_p + \beta v^2}{1000\varepsilon\eta} + \frac{\varsigma S V_p}{v}\right) d_{ij} x_{fpmij} \tag{16}$$
s.t. (2)-(7), (12)
$$Z^2_{fpm} = \sum_{i=0}^{N}\sum_{j=0}^{N} \Big(c_m e\,\frac{\phi}{\mu\zeta}\cdot\frac{\alpha}{1000\varepsilon\eta}\, d_{ij} + \lambda_{fpmij}\Big) w_{fpmij} \tag{17}$$
s.t. (9)-(11)
$$Z_{fpm} = Z^1_{fpm} + Z^2_{fpm} \tag{18}$$
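The split in Eqs. (16)-(18) amounts to rewriting the objective coefficients of one vehicle. The sketch below is a minimal illustration, not the authors' code. It assumes dense numpy arrays indexed by nodes $0..N$ and treats the two emission factors as precomputed scalars: `k_x` for the $x$-term and `k_w` for the $w$-term. Both names are hypothetical. The helper `original_cost` also makes the weak-duality argument visible: if $\lambda \ge 0$ and $(x, w)$ satisfies constraint (8), the added term $\sum \lambda (w - xQ_p)$ is non-positive, so $Z^1 + Z^2$ never exceeds the unrelaxed cost.

```python
import numpy as np

def subproblem_coefficients(d, lam, Q_p, c1, cm, k_x, k_w):
    """Coefficient matrices of Z1 (on x) and Z2 (on w) for one vehicle (f, p, m).

    d, lam : (N+1) x (N+1) arrays of distances d_ij and multipliers λ_fpmij >= 0
    k_x    : e*φ/(μζ) * ((α W_p + β v^2)/(1000 ε η) + ς S V_p / v)   (hypothetical name)
    k_w    : e*φ/(μζ) * α/(1000 ε η)                               (hypothetical name)
    """
    A = (c1 + cm * k_x) * d - lam * Q_p   # Eq. (16): Z1 = sum(A * x)
    B = cm * k_w * d + lam                # Eq. (17): Z2 = sum(B * w)
    return A, B

def original_cost(d, x, w, c1, cm, k_x, k_w):
    """Variable part of Z_fpm without the multiplier term."""
    return float(np.sum((c1 + cm * k_x) * d * x) + np.sum(cm * k_w * d * w))
```

For any $(x, w)$, `np.sum(A * x) + np.sum(B * w)` equals `original_cost(...)` plus `np.sum(lam * (w - x * Q_p))`. This identity is Eq. (18) written in matrix form.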
3.2 Customer Allocation Strategy

The specific steps are as follows:
Step 1: Calculate the Euclidean distance from every customer to every parking lot.
Step 2: Take the parking lots as cluster centers, use K-means to cluster the customers, and obtain one cluster of customers per parking lot.
Step 3: Output the customer set of each parking lot.

3.3 Repair Strategy Based on Greedy Thought

Since every parking lot in GMDHFVRP has the same vehicle types and numbers, the repair strategy is described below for parking lot $f$, whose total number of vehicles is $M_H$ ($M_H = M_c \times H_p$).
Step 1: Convert the solutions of the $M_c \times H_p$ sub-problems into sequence solutions. For example, if the infeasible solution of vehicle $k$ is $x_{p0j_1k} = 1$, $x_{pj_1j_2k} = 1$, $x_{pj_2j_3k} = 1$ and $x_{pj_30k} = 1$, the corresponding sequence is $\{0, j_1, j_2, j_3, 0\}$.
Step 2: Check in turn whether each vehicle $k_1$ ($k_1 = 1, 2, \ldots, M_H$) satisfies the load constraint. If it does, put $k_1$ into the feasible vehicle set $X_{feasible}$; otherwise put it into the infeasible vehicle set $X_{illegal}$. Continue until all vehicles have been checked.
Step 3: If the infeasible vehicle set is empty, go to Step 8.
Step 4: Sort the vehicles in $X_{illegal}$ by load capacity in ascending order. For each vehicle $k_2$, remove the customers that exceed its capacity and put them into the remaining customer set $\pi_{illegal}$. Then move $k_2$, which now satisfies the load constraint, into $X_{feasible}$. Repeat until all vehicles in $X_{illegal}$ have been processed.
Step 5: Sort all vehicles ($M_H$ vehicles) in $X_{feasible}$ by load capacity in ascending order, and set $k_3 = 1$.
Step 6: In $\pi_{illegal}$, find the customer $cs_1$ that is closest to the last customer of vehicle $k_3$ and can still be loaded onto it. Insert $cs_1$ into vehicle $k_3$ and remove it from $\pi_{illegal}$. Repeat Step 6 until no further customer can be inserted into vehicle $k_3$.
Step 7: Set $k_3 = k_3 + 1$; if $k_3 \le M_H$, go to Step 6.
Step 8: Stop and output the feasible vehicle set $X_{feasible}$.

3.4 Dynamic Neighborhood Search Strategy

This section designs a neighborhood search strategy to improve the feasible solution obtained in Sect. 3.3. The six neighborhood operations are shown in Fig. 1.
Fig. 1. Neighborhood search strategy
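Before the neighborhood operators are applied, the greedy repair of Sect. 3.3 (Steps 2-8) turns the relaxed routes of one parking lot into feasible ones. The sketch below is a minimal version, not the authors' code. The text does not say which customers are removed from an overloaded vehicle in Step 4. This sketch assumes they are removed from the end of the route, and that assumption is marked in the code. Routes are given without the parking lot 0, which is used as the "last customer" of an empty route.

```python
from typing import List, Sequence, Tuple

def greedy_repair(routes: List[List[int]], capacity: Sequence[float],
                  demand: Sequence[float], dist) -> Tuple[List[List[int]], List[int]]:
    """Greedy repair of one parking lot (Steps 2-8 of Sect. 3.3).

    routes[k]   : customer sequence of vehicle k, parking lot 0 omitted at both ends
    capacity[k] : Q_p of vehicle k's type
    Returns the repaired routes and any customers left in the pool (π_illegal).
    """
    routes = [list(r) for r in routes]
    load = [sum(demand[c] for c in r) for r in routes]
    pool: List[int] = []
    # Steps 2-4: overloaded vehicles, smallest capacity first, shed customers.
    overloaded = sorted((k for k in range(len(routes)) if load[k] > capacity[k]),
                        key=lambda k: capacity[k])
    for k in overloaded:
        while load[k] > capacity[k]:
            c = routes[k].pop()   # assumption: shed customers from the end of the route
            load[k] -= demand[c]
            pool.append(c)
    # Steps 5-7: visit every vehicle by ascending capacity and append the nearest
    # pooled customer that still fits, until nothing more fits.
    for k in sorted(range(len(routes)), key=lambda k: capacity[k]):
        while True:
            last = routes[k][-1] if routes[k] else 0
            fits = [c for c in pool if load[k] + demand[c] <= capacity[k]]
            if not fits:
                break
            nxt = min(fits, key=lambda c: dist[last][c])
            routes[k].append(nxt)
            load[k] += demand[nxt]
            pool.remove(nxt)
    return routes, pool  # Step 8
```

If the pool is not empty on return, the lot's fleet could not absorb all of the removed customers. The text does not say how that case is handled.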
3.4.1 Selection of Neighborhood Operation Order

Denote the six neighborhood operations Insert_vtv, Exchange_vtv, 2-opt_iv, Exchange_iv, Insert_iv and Swap_iv by $N(i)$ ($i = 1, 2, \ldots, 6$). Let $Fit(i)$ be the cumulative number of times $N(i)$ has improved the objective value, with initial value 1. The neighborhood operation order is determined as follows:
Step 1: Initialize the operation orders. The order of operations between vehicles is out_car = [N(1), N(2)], and the order of operations within a vehicle is in_car = [N(3), N(4), N(5), N(6)].
Step 2: Perform the between-vehicle operations in the order given by out_car. If the current operation improves the objective value, set $Fit(i) = Fit(i) + 1$ for it and skip the remaining operations in out_car.
Step 2.1: Sort the $N(i)$ in out_car by $Fit(i)$ in descending order and update out_car.
Step 3: Perform the within-vehicle operations in the order given by in_car. If the current operation improves the objective value, set $Fit(i) = Fit(i) + 1$ for it and skip the remaining operations in in_car.
Step 3.1: Sort the $N(i)$ in in_car by $Fit(i)$ in descending order and update in_car.
Step 4: Repeat Steps 2 and 3, updating out_car and in_car, until the neighborhood search is completed.
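The adaptive ordering can be sketched with a small counter-based class. This is an illustration under stated assumptions, not the authors' implementation. Only two of the six operators (2-opt_iv and Swap_iv) are implemented, both as first-improvement moves on a single route. Nodes are placed on a line, and the route length is the sum of coordinate differences, a toy metric chosen for the example. An operator returns `None` when it finds no improving move.

```python
from typing import Callable, Dict, List, Optional, Sequence

Route = List[int]
Operator = Callable[[Route, Sequence[float]], Optional[Route]]

def route_length(route: Route, pos: Sequence[float]) -> float:
    """Length of a route whose nodes lie on a line at coordinates pos[node] (toy metric)."""
    return sum(abs(pos[a] - pos[b]) for a, b in zip(route, route[1:]))

def two_opt_iv(route: Route, pos: Sequence[float]) -> Optional[Route]:
    """First-improvement 2-opt inside one route; the parking lot stays at both ends."""
    base = route_length(route, pos)
    for i in range(1, len(route) - 2):
        for j in range(i + 1, len(route) - 1):
            cand = route[:i] + route[i:j + 1][::-1] + route[j + 1:]
            if route_length(cand, pos) < base:
                return cand
    return None

def swap_iv(route: Route, pos: Sequence[float]) -> Optional[Route]:
    """First-improvement swap of two customers inside one route."""
    base = route_length(route, pos)
    for i in range(1, len(route) - 2):
        for j in range(i + 1, len(route) - 1):
            cand = route[:]
            cand[i], cand[j] = cand[j], cand[i]
            if route_length(cand, pos) < base:
                return cand
    return None

class AdaptiveOrder:
    """Fit(i) counters and a descending-Fit operator order (Sect. 3.4.1)."""

    def __init__(self, ops: Dict[str, Operator]):
        self.ops = ops
        self.fit = {name: 1 for name in ops}   # Fit(i) starts at 1
        self.order = list(ops)                 # initial order = insertion order

    def step(self, route: Route, pos: Sequence[float]) -> Route:
        for name in self.order:
            cand = self.ops[name](route, pos)
            if cand is not None:               # operators only return improving moves
                self.fit[name] += 1
                route = cand
                break                          # skip the remaining operators
        # Steps 2.1 / 3.1: re-sort by descending Fit(i); ties keep their current order.
        self.order.sort(key=lambda name: -self.fit[name])
        return route

pos = [0.0, 1.0, 2.0, 3.0]
search = AdaptiveOrder({"2-opt_iv": two_opt_iv, "swap_iv": swap_iv})
route = search.step([0, 3, 1, 2, 0], pos)
print(route, route_length(route, pos), search.fit)
# [0, 1, 3, 2, 0] 6.0 {'2-opt_iv': 2, 'swap_iv': 1}
```

Operators that keep succeeding move to the front of the list, so later iterations try them first. This is the self-reinforcing behavior that Steps 2.1 and 3.1 describe.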
3.4.2 Steps of Dynamic Neighborhood Search Algorithm

Since every parking lot has the same vehicle types and numbers, the dynamic neighborhood search algorithm is described for parking lot $f$. The specific steps are as follows:
Step 1: Let the feasible solution be $\sigma$ and set the current best solution $\sigma^* = \sigma$. Let the total number of vehicles be $M_H$, and set loop = 1, N_co = 1, N_ci = 1 and N_car = 1. Initialize out_car and in_car as in Sect. 3.4.1.
Step 2: Randomly select two vehicles in $\sigma$ and apply the between-vehicle operations to their customers in the order given by out_car. If a move violates the constraints, set $\sigma = \sigma^*$ and continue with the next operation in out_car. If a better solution is obtained, update $\sigma$ and $\sigma^*$ to it and skip the remaining operations in out_car.
Step 3: Update out_car as in Sect. 3.4.1 and set N_co = N_co + 1. If N_co ≤ $10M_H$, go to Step 2; otherwise set $\sigma = \sigma^*$ and N_co = 1.
Step 4: Apply the within-vehicle operations to the N_car-th vehicle of $\sigma$ in the order given by in_car. If a better solution is obtained, update $\sigma$ and $\sigma^*$ to it and skip the remaining operations in in_car. Otherwise, set $\sigma = \sigma^*$ and continue with the next operation in in_car.
Step 5: Update in_car as in Sect. 3.4.1 and set N_ci = N_ci + 1. If N_ci ≤ 20, go to Step 4.
Step 6: Set N_car = N_car + 1 and N_ci = 1. If N_car ≤ $M_H$, go to Step 4; otherwise set loop = loop + 1 and N_car = 1.
Step 7: If loop ≤ 20, go to Step 2; otherwise output $\sigma^*$.

3.5 Multiplier Update

The surrogate subgradient is
$$\tilde{g}_{fpmij}(x^n) = w_{fpmij} - x_{fpmij} Q_p \tag{19}$$
The step size is updated by
$$C^n = \frac{\omega\,\big(UB - \tilde{L}(\lambda^n)\big)}{\sum_{f=1}^{D}\sum_{p=1}^{M_c}\sum_{m=1}^{H_p}\sum_{i=0}^{N}\sum_{j=0}^{N} \big(\tilde{g}_{fpmij}(x^n)\big)^2} \tag{20}$$
and the Lagrange multipliers are updated by
$$\lambda^{n+1}_{fpmij} = \max\big\{0,\; \lambda^{n}_{fpmij} + C^n \tilde{g}_{fpmij}(x^n)\big\} \tag{21}$$
In these formulas, $C^n$ is the step size at iteration $n$, $UB$ is the best upper bound found up to iteration $n$, $\tilde{L}(\lambda^n)$ is the surrogate dual value at iteration $n$, and the parameter $\omega$ is set to 0.5.
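Eqs. (19)-(21) translate almost line by line into array code. The sketch below is an illustration, not the authors' code. It assumes the variables of the current sub-problems are stored as numpy arrays of equal shape, and that `Q_p` is a scalar or an array that broadcasts against `x`. Returning the old multipliers when the subgradient is zero is an added safeguard, not part of the text.

```python
import numpy as np

def multiplier_update(lam, x, w, Q_p, ub, dual_value, omega=0.5):
    """One surrogate-subgradient multiplier update, Eqs. (19)-(21).

    lam, x, w : arrays of λ_fpmij, x_fpmij and w_fpmij (same shape)
    ub        : best upper bound UB found so far
    dual_value: surrogate dual value L~(λ^n)
    Returns (new multipliers, step size C^n).
    """
    g = w - x * Q_p                                  # Eq. (19)
    norm2 = float(np.sum(g ** 2))
    if norm2 == 0.0:                                 # safeguard: zero subgradient
        return lam.copy(), 0.0
    step = omega * (ub - dual_value) / norm2         # Eq. (20)
    return np.maximum(0.0, lam + step * g), step     # Eq. (21): project onto λ >= 0
```

Multipliers grow on arcs where the relaxed load exceeds capacity ($\tilde{g} > 0$). On arcs with slack they shrink toward zero, which raises the penalty on overloaded arcs in the next round of sub-problems.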
3.6 Overall Steps of LHA_DS

In each iteration, $M_c \times H_p$ sub-problems are solved, namely all sub-problems of one parking lot. The first iteration solves the sub-problems of the first parking lot, the second iteration those of the second parking lot, and so on, so after $D$ iterations the sub-problems of every parking lot have been solved. Only these $M_c \times H_p$ sub-problems are re-solved in an iteration, and the solutions of the remaining sub-problems are kept from the previous iteration. Likewise, repair and neighborhood search are applied only to the current $M_c \times H_p$ sub-problem solutions, and the other solutions are kept unchanged. The specific steps of the algorithm are as follows:
Step 1: Decompose the whole problem with the methods of Sects. 3.1 and 3.2. Initialize the Lagrange multipliers $\lambda^0 = 0$ and the iteration counter $n = 0$, and set UB to a large value for the objective of the original problem.
Step 2: Solve all sub-problems to obtain the dual cost $L^0$, and repair all sub-problem solutions to obtain a feasible solution $UB^*$. Update the current best upper bound UB, and update the Lagrange multipliers by Eqs. (19), (20) and (21).
Step 3: Set $n = n + 1$. Solve the $M_c \times H_p$ sub-problems ($Z^1_{fpm}$, $Z^2_{fpm}$) of the current parking lot to obtain $x^n_{fpmij}$ and $w^n_{fpmij}$. Repair and optimize the current relaxed solution with the greedy repair strategy of Sect. 3.3 and the dynamic neighborhood search strategy of Sect. 3.4, and obtain an improved solution $UB_1$.
Step 4: Update the current UB from $UB_1$ and the solutions of the remaining sub-problems from the previous iteration, and update the Lagrange multipliers by Eqs. (19), (20) and (21).
Step 5: If the termination condition is met, output UB; otherwise go to Step 3.
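The control flow of Steps 1-5, with one parking lot re-solved per iteration and the others reused, can be written as a short driver loop. This is a skeleton under stated assumptions, not the authors' program. The three callables `solve_lot`, `repair_lot` and `update_lambda` are hypothetical hooks, standing for the sub-problem solver of Sect. 3.1, the repair and search of Sects. 3.3-3.4, and the update of Eqs. (19)-(21). Termination uses only an iteration limit; the 3000 s time limit of Sect. 4 is omitted.

```python
from typing import Any, Callable, List, Tuple

def lha_ds_loop(num_lots: int,
                solve_lot: Callable[[int, Any], Any],                  # hypothetical hook
                repair_lot: Callable[[int, Any], Tuple[float, Any]],    # hypothetical hook
                update_lambda: Callable[[Any, List[Any], float], Any],  # hypothetical hook
                lam: Any,
                max_iter: int = 30) -> Tuple[float, List[Any]]:
    """Control flow of Steps 1-5: one parking lot is re-solved per iteration."""
    # Step 2: solve and repair every parking lot once.
    relaxed = [solve_lot(f, lam) for f in range(num_lots)]
    repaired = [repair_lot(f, relaxed[f]) for f in range(num_lots)]
    best_ub = sum(cost for cost, _ in repaired)
    best_routes = [routes for _, routes in repaired]
    lam = update_lambda(lam, relaxed, best_ub)
    for n in range(1, max_iter + 1):
        f = (n - 1) % num_lots                 # Step 3: cycle through the parking lots
        relaxed[f] = solve_lot(f, lam)
        repaired[f] = repair_lot(f, relaxed[f])
        # Step 4: combine the new lot with the other lots' previous solutions.
        ub = sum(cost for cost, _ in repaired)
        if ub < best_ub:
            best_ub, best_routes = ub, [routes for _, routes in repaired]
        lam = update_lambda(lam, relaxed, best_ub)
    return best_ub, best_routes                # Step 5 (iteration limit only)
```

Re-solving one lot per iteration keeps each iteration cheap. This is the surrogate idea: a partial re-optimization is enough to produce a usable subgradient direction.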
4 Experimental Results

4.1 Experimental Environment and Calculation Example

The algorithms in this paper are implemented in Python 3.8. The experiments run on Windows 10 with 8 GB of memory and an Intel Core i5-4258U CPU at 2.4 GHz. The test instances come from the MDVRP data set at http://www.bernabe.dorronsoro.es/vrp/. The vehicle type data in Table 2 (vehicle self-weight, load and fixed cost) are added to this data set to form a multi-depot heterogeneous-fleet data set.

Table 2. Parameter definition and value.

| Parameter | Definition | Value |
|---|---|---|
| $v$ | Vehicle speed (m/s) | 14 |
| $S$ | Engine speed (r/s) | 33 |
| $\varsigma$ | Engine friction factor (kJ/r/l) | 0.2 |
| $a$ | Vehicle acceleration | 0 |
| $V_p$ | Displacement of type-$p$ vehicle (l) | [5, 7, 9] |
| $\theta$ | Road angle | 0 |
| $\eta$ | Efficiency parameter of diesel engine | 0.9 |
| $C_d$ | Air drag coefficient | 0.7 |
| $e$ | Carbon-to-fuel conversion ratio | 3.09 |
| $C_r$ | Rolling resistance coefficient | 0.01 |
| $\mu$ | Calorific value of diesel (kJ/g) | 44 |
| $A$ | Frontal surface area (m²) | 3.9 |
| $g$ | Gravitational acceleration (m/s²) | 9.8 |
| $\rho$ | Air density (kg/m³) | 1.2 |
| $W_p$ | Self-weight of type-$p$ vehicle (t) | [6, 8, 10] |
| $\phi$ | Fuel-to-air mass ratio | 1 |
| $Q_p$ | Load of type-$p$ vehicle (kg) | [90, 160, 200] |
| $\zeta$ | Conversion factor | 737 |
| $c_p$ | Fixed cost of type-$p$ vehicle | [200, 300, 400] |
| $\varepsilon$ | Vehicle drive train efficiency | 0.4 |

4.2 Computational Results

In this section, LHA_DS is compared with IACO [8], DTS [9] and the Gurobi solver. Depending on the instance size, the running time of Gurobi is set to 7200, 10800, 14400 or 18000 s. LHA_DS stops after at most 30 iterations or 3000 s, and IACO and DTS are given the same running time as LHA_DS. LHA_DS, IACO and DTS are each run independently 15 times in the same experimental environment, whereas the Gurobi solver is run once. The best value for each instance is shown in bold. The results are shown in Table 3, where Best and Avg denote the best and average total cost over the 15 runs, and T denotes the running time. LHA_DS obtains the best value on 9 of the 12 instances in Table 3, a dominance rate of 75% (Gurobi is best on p07, p21 and pr07). Its overall performance is therefore better than that of the other algorithms, and LHA_DS is an effective algorithm for solving GMDHFVRP.
Table 3. Comparison of algorithms with different scales.

| Test example | Scale (N × D) | Type number | LHA_DS Best | LHA_DS Avg | LHA_DS T (s) | IACO Best | IACO Avg | IACO T (s) | DTS Best | DTS Avg | DTS T (s) | Gurobi Avg | Gurobi T (s) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| p01 | 50 × 4 | (1,0,1) | **4385.2** | 4387.9 | 89.3 | 5162.3 | 5254.0 | 90.2 | 5283.8 | 5683.5 | 90.5 | 4496.1 | 7200 |
| p03 | 75 × 5 | (0,1,1) | **6858.7** | 6940.5 | 3207.8 | 7690.5 | 7841.1 | 3208.5 | 8594.9 | 8982.5 | 3215.6 | - | 14400 |
| p05 | 100 × 2 | (1,1,3) | **5237.2** | 5268.4 | 744.0 | 5801.3 | 5861.8 | 744.3 | 5824.6 | 6170.1 | 745.9 | 5539.0 | 10800 |
| p07 | 100 × 4 | (2,1,1) | 3078.7 | 3085.8 | 58.9 | 3443.1 | 3516.5 | 59.1 | 3361.1 | 3632.8 | 59.4 | **2790.7** | 7200 |
| p12 | 80 × 2 | (1,0,1) | **13012.2** | 13027.5 | 3451.0 | 17719.1 | 18433.3 | 3455.6 | 20157.4 | 20707.6 | 3460.0 | - | 18000 |
| p15 | 160 × 4 | (0,1,1) | **7445.5** | 7466.5 | 2505.4 | 9374.4 | 9533.3 | 2507.1 | 10929.6 | 12037.6 | 2510.3 | - | 18000 |
| p18 | 240 × 6 | (1,1,0) | **5783.2** | 5784.2 | 1882.7 | 6430.6 | 6581.1 | 1883.7 | 7089.5 | 7686.5 | 1885.2 | - | 14400 |
| p21 | 360 × 9 | (0,1,1) | 2701.5 | 2719.9 | 872.5 | 2731.2 | 2815.8 | 873.2 | 2915.9 | 3163.4 | 874.3 | **2432.5** | 7200 |
| pr01 | 48 × 4 | (1,1,0) | **5413.8** | 5426.2 | 1373.3 | 5775.6 | 5839.3 | 1374.3 | 5769.2 | 5965.3 | 1376.8 | - | 14400 |
| pr02 | 96 × 4 | (1,1,1) | **4413.0** | 4451.6 | 3194.4 | 4565.9 | 4603.7 | 3195.3 | 4581.4 | 4772.5 | 3198.5 | - | 10800 |
| pr03 | 144 × 4 | (2,1,1) | **4230.7** | 4231.9 | 82.1 | 4521.1 | 4578.4 | 83.0 | 4562.6 | 4724.6 | 83.4 | 4345.3 | 7200 |
| pr07 | 72 × 6 | (1,1,0) | 2994.8 | 2994.9 | 56.3 | 3157.3 | 3182.5 | 57.0 | 3152.5 | 3221.2 | 57.2 | **2848.3** | 7200 |

Note: “-” indicates that the Gurobi solver could not find a feasible solution within the specified time.
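The 75% dominance rate quoted in Sect. 4.2 can be recomputed directly from Table 3. The sketch below compares, for each instance, the Best values of LHA_DS, IACO and DTS with Gurobi's single-run value (the column labelled "Avg"), ignoring "-" entries. It uses only numbers printed in the table. Treating a missing Gurobi value as "no competitor" is the reading given in the table note.

```python
# Best total cost per instance from Table 3: (LHA_DS, IACO, DTS, Gurobi); None = "-".
TABLE3_BEST = {
    "p01": (4385.2, 5162.3, 5283.8, 4496.1),
    "p03": (6858.7, 7690.5, 8594.9, None),
    "p05": (5237.2, 5801.3, 5824.6, 5539.0),
    "p07": (3078.7, 3443.1, 3361.1, 2790.7),
    "p12": (13012.2, 17719.1, 20157.4, None),
    "p15": (7445.5, 9374.4, 10929.6, None),
    "p18": (5783.2, 6430.6, 7089.5, None),
    "p21": (2701.5, 2731.2, 2915.9, 2432.5),
    "pr01": (5413.8, 5775.6, 5769.2, None),
    "pr02": (4413.0, 4565.9, 4581.4, None),
    "pr03": (4230.7, 4521.1, 4562.6, 4345.3),
    "pr07": (2994.8, 3157.3, 3152.5, 2848.3),
}
METHODS = ("LHA_DS", "IACO", "DTS", "Gurobi")

def winner(values):
    """Name of the method with the lowest total cost, ignoring missing entries."""
    return min((v, m) for v, m in zip(values, METHODS) if v is not None)[1]

winners = {inst: winner(vals) for inst, vals in TABLE3_BEST.items()}
rate = sum(w == "LHA_DS" for w in winners.values()) / len(winners)
print(f"{rate:.0%}")                                           # 75%
print(sorted(i for i, w in winners.items() if w != "LHA_DS"))  # ['p07', 'p21', 'pr07']
```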
5 Conclusions and Future Research

This article establishes a MIP model for GMDHFVRP and proposes LHA_DS to solve it. The main work is as follows: (1) a MIP model is established with the objective of minimizing the total cost; (2) based on the characteristics of the GMDHFVRP model, a greedy repair strategy is designed to convert relaxed solutions into feasible solutions of GMDHFVRP, and a neighborhood search method that determines the order of its operations is designed to further improve solution quality. Comparative simulation experiments verify the effectiveness of LHA_DS in solving GMDHFVRP. In future work, we will build on this study by further considering time windows and variable speeds, and will design effective algorithms for the resulting problems.

Acknowledgment. This research was supported by the National Natural Science Foundation of China (62173169 and 61963022) and the Basic Research Key Project of Yunnan Province (202201AS070030).
References
1. Dantzig, G.B., Ramser, J.H.: The truck dispatching problem. Manage. Sci. 6(1), 80–91 (1959)
2. Li, H.Q., Yuan, J.L., Lv, T., et al.: The two-echelon time-constrained vehicle routing problem in linehaul-delivery systems considering carbon dioxide emissions. Transp. Res. Part D 49, 231–245 (2016)
3. Zhao, Y.W., Zhang, J.L., Wang, W.L.: Vehicle Routing Optimization Method for Logistics Distribution. Science Press, Beijing (2014)
4. Baldacci, R., Mingozzi, A.: A unified exact method for solving different classes of vehicle routing problems. Math. Program. 120(2), 347–380 (2009)
5. Contardo, C., Martinelli, R.: A new exact algorithm for the multi-depot vehicle routing problem under capacity and route length constraints. Discret. Optim. 12(1), 129–146 (2014)
6. Muter, I., Cordeau, J.-F., Laporte, G.: A branch-and-price algorithm for the multidepot vehicle routing problem with interdepot routes. Transp. Sci. 48(3), 425–441 (2014)
7. Bettinelli, A., Ceselli, A., Righini, G.: A branch-and-cut-and-price algorithm for the multi-depot heterogeneous vehicle routing problem with time windows. Transp. Res. Part C Emerg. Technol. 19(5), 723–740 (2011)
8. Li, Y., Soleimani, H., Zohal, M.: An improved ant colony optimization algorithm for the multi-depot green vehicle routing problem with multiple objectives. J. Clean. Prod. 227, 1161–1172 (2019)
9. Meliani, Y., Hani, Y., Elhaq, S.L., El Mhamedi, A.: A developed Tabu Search algorithm for heterogeneous fleet vehicle routing problem. IFAC-PapersOnLine 52(13), 1051–1056 (2019)
A Novel Algorithm to Multi-view TSK Classification Based on the Dirichlet Distribution

Lei Nie, Zhenyu Qian, Yaping Zhao, and Yizhang Jiang(✉)

School of Artificial Intelligence and Computer Science, Jiangnan University, 1800 Lihu Avenue, Wuxi 214122, Jiangsu, People’s Republic of China
Abstract. Multi-view classification can effectively improve classification performance. However, traditional multi-view TSK classification methods suffer from dimension explosion when the features of multiple views are stacked. This paper proposes an innovative multi-view TSK classification method that makes an individual decision in each view and then selects the most credible decision results to obtain a comprehensive classification result. The proposed multi-view decision TSK algorithm dynamically integrates the different views at the evidence level, offering a feasible approach to multi-view TSK classification that avoids the high dimensionality of traditional methods. Compared with other algorithms, it is highly competitive in accuracy, and experiments on datasets from different domains show that it performs well across multiple domains, supporting the feasibility and application potential of the proposed method. Keywords: Fuzzy inference system · multi-view classification · multi-view TSK classification · Dirichlet distribution
1 Introduction Fuzzy systems are widely used in various fields because they can handle uncertainty and are interpretable. Traditionally, multi-view TSK classification methods stack the features extracted from multiple views and feed them into a single TSK classifier [16]. However, this approach causes a dimensionality explosion, making it unsuitable for practical applications where computational efficiency is critical. We therefore propose a new multi-view TSK classification method that overcomes this limitation to meet the needs of real-world applications. The method makes an independent decision on each view and then combines these decisions into a final verdict. Our goal is to improve computational efficiency while remaining competitive in handling uncertainty and preserving interpretability.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 549–558, 2023. https://doi.org/10.1007/978-981-99-4761-4_47
L. Nie et al.
To overcome this limitation, the proposed method does not merge the features of multiple views; instead, it makes independent decisions on each view and then combines them into the final verdict. The class probability distribution is modeled with the Dirichlet distribution, and the weight given to each view depends on its contribution to the final decision. Compared with traditional multi-view TSK classification methods, the proposed method has several advantages. First, it offers a new way to address the excessive dimensionality of multi-view TSK classification. Second, it identifies the most trustworthy views, which can lead to more accurate predictions and helps us understand the impact of each view on the classification result. Third, it allows a different TSK classifier to be used for each view, which can improve performance. Experiments on multiple datasets show that our method performs well in terms of both computational efficiency and interpretability. The remainder of this paper is organized as follows. Section 2 briefly reviews related work. Section 3 describes the proposed method and its innovations in detail. Section 4 introduces the datasets and experimental settings and uses the experimental results to analyze the advantages of the proposed method. Finally, Section 5 summarizes the method and outlines directions for future improvement.
2 Related Work In this section, we first introduce the formulation of TSK fuzzy systems and their parameters, and then review the classical multi-view learning algorithm CCA as well as recent high-performance algorithms. 2.1 Traditional TSK Fuzzy Systems Owing to their interpretability and their ability to learn from data, TSK fuzzy systems are widely used in fields such as medicine and mechanics [7]. The k-th TSK fuzzy rule is

$$ R^k:\ \text{IF } x_1 \text{ is } A_1^k \wedge x_2 \text{ is } A_2^k \wedge \cdots \wedge x_d \text{ is } A_d^k,\ \text{THEN } f^k(x) = p_0^k + p_1^k x_1 + \cdots + p_d^k x_d,\quad k = 1, 2, \ldots, K \tag{1} $$

In the above definition, $x = [x_1, x_2, \ldots, x_d]^T$ is the input vector of the fuzzy system, $\wedge$ is the fuzzy conjunction operator, $A_i^k$ is the fuzzy set of the i-th input in the k-th rule, described by its membership function, and K is the number of rules in the fuzzy system. The Gaussian function is one of the most widely used membership functions:

$$ \mu_{A_i^k}(x_i) = \exp\left( -\frac{(x_i - e_i^k)^2}{2\delta_i^k} \right) \tag{2} $$
where $\delta_i^k$ and $e_i^k$ are the antecedent parameters. For each fuzzy rule, the firing strength for a given input vector x is calculated by

$$ \tilde{\mu}^k(x) = \prod_{i=1}^{d} \mu_{A_i^k}(x_i) \tag{3} $$
The final output is given by

$$ y_o = \sum_{k=1}^{K} \tilde{\mu}^k(x) f^k(x) = x_g^T P_g \tag{4} $$

where $f^k(x)$ is the consequent output of the k-th rule, and $x_g$ and $P_g$ are defined as follows:

$$ x_e = (1, x^T)^T \tag{5} $$

$$ \tilde{x}^k = \tilde{\mu}^k(x)\, x_e \tag{6} $$

$$ x_g = \left( (\tilde{x}^1)^T, (\tilde{x}^2)^T, \ldots, (\tilde{x}^K)^T \right)^T \tag{7} $$

$$ p^k = \left( p_0^k, p_1^k, \ldots, p_d^k \right)^T \tag{8} $$

$$ P_g = \left( (p^1)^T, (p^2)^T, \ldots, (p^K)^T \right)^T \tag{9} $$
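To make Eqs. (1) to (9) concrete, the NumPy sketch below computes the output of a single-output TSK system in two ways: as the rule-wise weighted sum on the left of Eq. (4), and as the product $x_g^T P_g$ assembled from Eqs. (5) to (9). The function names and array layout are our own illustrative choices; following Eq. (3) as written, the firing strengths are not normalized.

```python
import numpy as np


def firing_strengths(x, centers, widths):
    """Eqs. (2)-(3): product of Gaussian memberships for each of the K rules.

    x: (d,) input; centers (e_i^k) and widths (delta_i^k): (K, d) arrays.
    """
    memberships = np.exp(-((x - centers) ** 2) / (2.0 * widths))  # (K, d), Eq. (2)
    return memberships.prod(axis=1)                                 # (K,), Eq. (3)


def tsk_output_rulewise(x, centers, widths, P):
    """Left side of Eq. (4): sum_k mu^k(x) f^k(x); row k of P is (p^k)^T from Eq. (8)."""
    mu = firing_strengths(x, centers, widths)
    x_e = np.concatenate(([1.0], x))        # Eq. (5)
    f = P @ x_e                             # f^k(x) = p_0^k + sum_i p_i^k x_i, Eq. (1)
    return float(mu @ f)


def tsk_output_matrix(x, centers, widths, P):
    """Right side of Eq. (4): x_g^T P_g, built from Eqs. (5)-(9)."""
    mu = firing_strengths(x, centers, widths)
    x_e = np.concatenate(([1.0], x))        # Eq. (5)
    x_tilde = mu[:, None] * x_e[None, :]    # Eq. (6): row k is x~^k
    x_g = x_tilde.reshape(-1)               # Eq. (7): stack x~^1, ..., x~^K
    P_g = P.reshape(-1)                     # Eq. (9): stack p^1, ..., p^K
    return float(x_g @ P_g)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    K, d = 3, 4
    x = rng.normal(size=d)
    centers = rng.normal(size=(K, d))
    widths = rng.uniform(0.5, 2.0, size=(K, d))
    P = rng.normal(size=(K, d + 1))
    print(tsk_output_rulewise(x, centers, widths, P), tsk_output_matrix(x, centers, widths, P))
```

Both functions return the same value up to floating-point rounding, which is exactly the identity stated in Eq. (4); the stacked form is convenient because $P_g$ can then be fitted as a single linear parameter vector.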
2.2 Multi-view Learning In recent years, the number of publications on multi-view learning has grown steadily [11], reflecting increasing attention to multi-view learning as a research direction in pattern recognition [5]. It has found important applications in many fields, such as object recognition, face recognition, and speech recognition. In traditional single-view learning, the classifier can use information from only one view to make decisions [13], whereas in multi-view learning the classifier can draw on information from multiple views and make more accurate decisions. A popular approach is to use canonical correlation analysis (CCA) for multi-view learning [6]. CCA learns a linear projection for each view such that the projected views are maximally correlated, and the classifier then makes decisions based on these shared components. Beyond CCA, DMCCA has been proposed for multi-view learning; it uses deep networks to strengthen the correlation between views. DGCCA is a generalized CCA method that can capture not only linear but also nonlinear correlations. DTCCA is a tensor-based CCA method that combines information from different views into tensors and exploits three-way correlations to improve classification accuracy. In addition, DCCA_EigenGame is a deep CCA method built on the EigenGame formulation, which improves classifier performance. Driven by deep learning and kernel learning techniques [9], multi-view learning has been widely adopted because it can improve classifier accuracy. Although multi-view learning still faces many challenges in practical applications, continued advances in technology and research are expected to bring wider application and further development.
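For reference, classical linear CCA can be computed in closed form by whitening each view and taking a singular value decomposition of the whitened cross-covariance. The NumPy sketch below is a generic textbook formulation, not the implementation of DMCCA, DGCCA, DTCCA, or DCCA_EigenGame (the deep variants replace the linear projections with neural networks); the small ridge term `reg` is our own assumption for numerical stability.

```python
import numpy as np


def linear_cca(X, Y, n_components=1, reg=1e-6):
    """Classical linear CCA: find W_x, W_y maximizing corr(X W_x, Y W_y).

    X: (n, p) first view, Y: (n, q) second view.
    Returns the projection matrices and the canonical correlations.
    """
    Xc = X - X.mean(axis=0)
    Yc = Y - Y.mean(axis=0)
    n = X.shape[0]
    # Covariance blocks; the ridge term keeps Cxx and Cyy invertible.
    Cxx = Xc.T @ Xc / (n - 1) + reg * np.eye(X.shape[1])
    Cyy = Yc.T @ Yc / (n - 1) + reg * np.eye(Y.shape[1])
    Cxy = Xc.T @ Yc / (n - 1)

    def inv_sqrt(C):
        # Inverse square root of a symmetric positive-definite matrix.
        vals, vecs = np.linalg.eigh(C)
        return vecs @ np.diag(1.0 / np.sqrt(vals)) @ vecs.T

    Kx, Ky = inv_sqrt(Cxx), inv_sqrt(Cyy)
    T = Kx @ Cxy @ Ky                    # whitened cross-covariance
    U, s, Vt = np.linalg.svd(T)          # singular values = canonical correlations
    Wx = Kx @ U[:, :n_components]
    Wy = Ky @ Vt.T[:, :n_components]
    return Wx, Wy, s[:n_components]
```

Projecting each view with its matrix (X @ Wx, Y @ Wy) gives the shared low-dimensional representation that a CCA-based pipeline would pass on to a downstream classifier.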
3 Multi-View Decision TSK Fuzzy System This paper proposes an innovative multi-view TSK classification method that makes an individual decision in each view and then selects the most credible decision results to obtain a comprehensive classification result, thereby avoiding the dimension explosion of traditional multi-view TSK classification. Unlike traditional methods, we make separate decisions on the features of each view and then select the most credible decision result. Because using the softmax output as the confidence level can yield confident but wrong predictions, our method replaces softmax with an evidence-based alternative, which allows the whole model to obtain more accurate results. We also introduce evidence-based uncertainty estimation, which assesses decision uncertainty more accurately and provides the flexibility to integrate multiple views into informed decisions when needed, further improving accuracy and reliability. 3.1 Traditional Multi-view TSK Classification Task The main disadvantage of traditional multi-view TSK classification is that its input data are high-dimensional and the number of views is often large, which leads to dimension explosion [12]. Because the feature spaces of different views may differ significantly [8], simply stacking all feature vectors leads to redundant feature representations and loss of information (Fig. 1). The partition of the feature space is the basis of rule-base construction: adding feature dimensions enlarges the rule base but also makes it sparser, which degrades the performance of the classifier [10].
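One way to see why stacking views strains a TSK classifier is the firing strength of Eq. (3): it is a product of d membership values, each at most 1, so for stacked high-dimensional inputs it tends to shrink geometrically, a difficulty also analysed in [1]. The sketch below is only an illustration under assumed conditions (standard-normal inputs and rule centre, every $\delta_i^k$ set to 1); it is not an experiment from this paper.

```python
import numpy as np


def mean_firing_strength(d, n_samples=2000, width=1.0, seed=0):
    """Average Eq. (3) firing strength of one Gaussian rule for d-dimensional inputs.

    Illustrative assumptions: inputs and the rule centre are drawn from a
    standard normal distribution, and every delta_i^k equals `width`.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n_samples, d))
    center = rng.standard_normal(d)
    memberships = np.exp(-((x - center) ** 2) / (2.0 * width))  # Eq. (2), per dimension
    return memberships.prod(axis=1).mean()                       # Eq. (3), averaged over samples


if __name__ == "__main__":
    for d in (2, 10, 50, 200):
        print(d, mean_firing_strength(d))
```

As d grows, the printed averages fall toward zero (and eventually underflow), so the rule activations carry less and less usable signal; splitting the features by view keeps each classifier's d small.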
3.2 Multi-View Decision TSK Fuzzy System Our proposed method differs from traditional multi-view TSK classification algorithms in that it selects the most plausible decision results from decisions made on single views. In contrast, traditional multi-view TSK classification algorithms stack features from multiple views, which leads to the dimension explosion problem of TSK. We use the properties of the Dirichlet distribution to refine the output of each view's model and finally feed it into the decision level of the system (Fig. 2).
Fig. 1. Traditional multi-view TSK classification methods stack features extracted from multiple views and feed them into a TSK classifier.
Whereas traditional neural network classifiers treat the output as a point on the probability simplex, we model it with a Dirichlet distribution. As mentioned above, traditional models focus on fusing features before passing the data to the classifier, and their classification results mostly rely on a final softmax function, which is not always reliable: using the softmax output as a confidence level often leads to overconfidence. Our model avoids this problem by introducing an overall uncertainty. Specifically, introducing the Dirichlet distribution allows our system to express the decision of each view in the form of Dirichlet parameters. Existing methods tend to fuse features by stacking them, which can lead to dimension explosion. Deriving uncertainty from a parameterization of the Dirichlet distribution, however, raises many problems when it is incorporated into a system, and these difficulties have limited the development of uncertainty-based multi-view classification. For multi-class classification, we use subjective logic (SL), which relates the class probabilities to the Dirichlet parameters. The innovation of MVD-TSK-FS is that the uncertainty of each view can be computed independently and used as the output of that view's classifier (Fig. 2). Subjective logic provides a theoretical framework for obtaining, by collecting evidence from the data, the probability of each category (belief mass) and the overall uncertainty of a multi-class problem (uncertainty mass). Our system collects critical information from each view, which is treated as "evidence" (Fig. 2). This evidence plays a key role in the system and prepares the data for the next step.
The system uses SL to obtain an uncertainty value for each decision. For the v-th view, we assign K + 1 mass values such that

$$ u^v + \sum_{k=1}^{K} b_k^v = 1 \tag{10} $$

where $b_k^v \ge 0$ and $u^v \ge 0$ denote the belief mass of the k-th class and the uncertainty mass, respectively.
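The text does not spell out how a view's evidence becomes the masses of Eq. (10). A minimal sketch, assuming the subjective-logic mapping used in TMC [14] (Dirichlet parameters α = e + 1, strength S = Σα, belief $b_k = e_k / S$, uncertainty u = K / S), is shown below; how the TSK output is made non-negative (for example by a ReLU or softplus) is not stated, so the evidence is taken as given.

```python
import numpy as np


def evidence_to_opinion(evidence):
    """Map non-negative evidence e^v (length K) of one view to (b^v, u^v).

    Assumed TMC-style subjective-logic mapping:
    alpha = e + 1, S = sum(alpha), b_k = e_k / S, u = K / S.
    """
    e = np.asarray(evidence, dtype=float)
    if np.any(e < 0):
        raise ValueError("evidence must be non-negative")
    K = e.size
    alpha = e + 1.0          # Dirichlet parameters
    S = alpha.sum()          # Dirichlet strength
    belief = e / S           # belief masses b_k^v
    uncertainty = K / S      # uncertainty mass u^v
    return belief, uncertainty
```

By construction the belief masses and the uncertainty sum to 1, as Eq. (10) requires; a view with no evidence at all gets u = 1, and u shrinks as evidence accumulates.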
Fig. 2. An innovative multi-view TSK classification method: a decision is made separately on each view, and the most plausible decision results are then selected to produce a comprehensive classification result.
Our method draws on Dempster-Shafer theory and TMC [14, 15], which allow the system to merge the evidence collected from different sources into a single belief function that takes all available evidence into account. Specifically, V sets of independent probability assignments $\{M^v\}_{v=1}^{V}$ need to be merged, where $M^v = \{\{b_k^v\}_{k=1}^{K}, u^v\}$, and the combination yields a joint mass $M = \{\{b_k\}_{k=1}^{K}, u\}$. We first obtain these belief and uncertainty masses for each of the V views, and then integrate them using Dempster's combination rule:

$$ M = M^1 \oplus M^2 \oplus \cdots \oplus M^V \tag{11} $$
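Eq. (11) names the combination operator without expanding it. A minimal sketch, assuming the reduced two-opinion form of Dempster's rule used in TMC [14] (conflict $C = \sum_{i \ne j} b_i^1 b_j^2$, $b_k = (b_k^1 b_k^2 + b_k^1 u^2 + b_k^2 u^1)/(1 - C)$, $u = u^1 u^2/(1 - C)$), folds the V view opinions together from left to right. The function names are hypothetical.

```python
from functools import reduce

import numpy as np


def combine_two(m1, m2):
    """Reduced Dempster's rule for two opinions (b, u) over the same K classes."""
    b1, u1 = m1
    b2, u2 = m2
    b1, b2 = np.asarray(b1, dtype=float), np.asarray(b2, dtype=float)
    # Conflict: mass the two views assign to *different* classes.
    conflict = np.outer(b1, b2).sum() - np.dot(b1, b2)
    scale = 1.0 - conflict
    b = (b1 * b2 + b1 * u2 + b2 * u1) / scale
    u = (u1 * u2) / scale
    return b, u


def combine_views(opinions):
    """Eq. (11): M = M^1 (+) M^2 (+) ... (+) M^V, applied left to right."""
    return reduce(combine_two, opinions)
```

A useful property of this rule is that a vacuous opinion (all beliefs 0, u = 1) leaves the other opinion unchanged, so a view that has collected no evidence cannot distort the fused decision, while two confident, agreeing views reduce the combined uncertainty.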
By applying the combination rule above, our system gathers the critical evidence from each view and then uses the Dirichlet distribution to obtain refined parameters, from which the final probability and overall uncertainty of each category are inferred. This enables MVD-TSK-FS to refine decisions from multiple views while avoiding the dimension explosion caused by feature stacking. To facilitate understanding, we describe the whole process. For the i-th sample, the features of the v-th view are first passed through that view's TSK classifier, which outputs the evidence vector $e_i^v$. The evidence is mapped to the Dirichlet parameters by $\alpha_i^v = e_i^v + 1$. These parameters feed SL, which finally yields the multinomial opinion $D(p_i^v \mid \alpha_i^v)$, where $p_i^v$ is the class probability distribution of the i-th sample on the v-th view. Below, $\alpha_i$ denotes the Dirichlet parameters of either a single view or the combined opinion, and $p_{ij}$ is the probability of class j for the i-th sample after the multi-view decision. Our loss adapts the cross-entropy loss to this Dirichlet setting:

$$ L_{ace}(\alpha_i) = \int \Big[ \sum_{j=1}^{K} -y_{ij} \log p_{ij} \Big] \frac{1}{B(\alpha_i)} \prod_{j=1}^{K} p_{ij}^{\alpha_{ij}-1} \, dp_i = \sum_{j=1}^{K} y_{ij} \big( \psi(S_i) - \psi(\alpha_{ij}) \big) \tag{12} $$

where ψ(·) is the digamma function, B(α_i) is the multivariate beta function, and $S_i = \sum_{j=1}^{K} \alpha_{ij}$ is the Dirichlet strength. Although this loss ensures that each sample receives more evidence for its correct label than for the other classes, it does not drive the evidence for incorrect labels to 0, which is what we want. We therefore add the following Kullback-Leibler (KL) divergence term:

$$ KL\big[ D(p_i \mid \tilde{\alpha}_i) \,\|\, D(p_i \mid \mathbf{1}) \big] = \log\left( \frac{\Gamma\big(\sum_{k=1}^{K} \tilde{\alpha}_{ik}\big)}{\Gamma(K) \prod_{k=1}^{K} \Gamma(\tilde{\alpha}_{ik})} \right) + \sum_{k=1}^{K} (\tilde{\alpha}_{ik} - 1) \Big[ \psi(\tilde{\alpha}_{ik}) - \psi\Big( \sum_{j=1}^{K} \tilde{\alpha}_{ij} \Big) \Big] \tag{13} $$

where Γ(·) is the gamma function and $\tilde{\alpha}_i = y_i + (1 - y_i) \odot \alpha_i$ is the adjusted Dirichlet parameter, which avoids penalizing the evidence of the ground-truth class toward 0. In summary, the loss for a sample is

$$ L(\alpha_i) = L_{ace}(\alpha_i) + \lambda_t \, KL\big[ D(p_i \mid \tilde{\alpha}_i) \,\|\, D(p_i \mid \mathbf{1}) \big] \tag{14} $$

where $\lambda_t > 0$ is the balance factor. Since MVD-TSK-FS is a multi-view decision-based algorithm, multiple decisions need to be considered; they give the system more references for classification, from which the most appropriate decision is selected. We account for this when designing the overall loss function, so that the combined opinion and every view-specific opinion are trained to be reasonable at the same time:

$$ L_{overall} = \sum_{i=1}^{N} \Big[ L(\alpha_i) + \sum_{v=1}^{V} L(\alpha_i^v) \Big] \tag{15} $$
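A NumPy/SciPy sketch of the per-sample loss in Eqs. (12) to (14), assuming a one-hot label vector y and Dirichlet parameters α given as arrays; the value of $\lambda_t$ is left to the caller. Eq. (15) would then sum this loss over the combined opinion and each view's opinion for every sample.

```python
import numpy as np
from scipy.special import digamma, gammaln


def adjusted_ce(alpha, y):
    """Eq. (12): expected cross-entropy under Dir(alpha), with one-hot y."""
    S = alpha.sum()                                   # Dirichlet strength S_i
    return float(np.sum(y * (digamma(S) - digamma(alpha))))


def kl_to_uniform(alpha_tilde):
    """Eq. (13): KL[Dir(alpha_tilde) || Dir(1, ..., 1)]."""
    K = alpha_tilde.size
    S = alpha_tilde.sum()
    log_norm = gammaln(S) - gammaln(K) - gammaln(alpha_tilde).sum()
    return float(log_norm + np.sum((alpha_tilde - 1.0) * (digamma(alpha_tilde) - digamma(S))))


def sample_loss(alpha, y, lam):
    """Eq. (14): L = L_ace + lambda_t * KL, with alpha_tilde = y + (1 - y) * alpha."""
    alpha_tilde = y + (1.0 - y) * alpha               # removes the true-class evidence
    return adjusted_ce(alpha, y) + lam * kl_to_uniform(alpha_tilde)
```

Because $\tilde{\alpha}$ resets the true class to 1, the KL term only penalizes evidence placed on wrong classes; it is exactly 0 when no wrong class has any evidence.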
4 Experimental Details 4.1 Datasets To evaluate the method objectively, we selected five real-world public datasets as experimental data, described as follows: 1. Dermatology: a medical dermatology dataset that contains two views, histopathological and clinical features, for differential diagnosis.
2. Forest Type: remote-sensing images of forests taken by satellite, from which band and spectral features are extracted. 3. Epileptic EEG: epileptic EEG data, whose features are extracted with the discrete wavelet transform (DWT) and wavelet packet decomposition (WPD). 4. Caltech7: features extracted from 1474 images with three descriptors: Gabor, CENTRIST, and wavelet texture. 5. Handwritten: 2000 samples described by 6 views; it is one of the most commonly used public datasets in multi-view learning. 4.2 Experimental Settings We combine CCA-based algorithms with the traditional TSK algorithm: a CCA-based approach first obtains a latent representation, which is then classified by a TSK classifier. The other advanced TSK algorithms [1–4] follow the multi-view TSK classification approach of stacking features from multiple views. For a fair comparison, all data were normalized, and the number of fuzzy rules was set between 20 and 50. 4.3 Experimental Results In our experiments, the proposed multi-view TSK classification method obtains classification results that are competitive with traditional multi-view TSK classification algorithms (Table 1).
Table 1. Classification accuracy of each method
Model          Caltech7  Handwritten  Dermatology  Forest  EEG
TSK-FS         0.9153    0.9100       0.9075       0.9295  0.6050
HTSK           0.9390    0.9700       0.9600       0.9595  0.6750
MBGD-RDA       0.9356    0.9600       0.9700       0.9459  0.6650
TSK-MBGD-UR    0.9254    0.9600       0.9550       0.9459  0.5950
TSK-MBGD-BN    0.9254    0.9675       0.9324       0.9189  0.6000
HTSK-LN        0.9390    0.9500       0.9650       0.9324  0.6450
HTSK-LN-RELU   0.9322    0.9700       0.9600       0.9595  0.6550
DMCCA-TSK      0.9288    0.9100       0.9375       0.8514  0.4200
DGCCA-TSK      0.9390    0.8900       0.8650       0.8514  0.6150
DTCCA-TSK      0.9254    0.9400       0.8950       0.8243  0.5700
DCCAEG-TSK     0.9119    0.9300       0.9075       0.9459  0.6600
MVD-TSK-FS     0.9322    0.9100       0.9850       0.9595  0.5750
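To make the per-dataset comparison in Table 1 explicit, the short script below transcribes the table and reports which models attain the best accuracy on each dataset and where a given model ranks. It only re-reads the published numbers; no new results are introduced, and the function names are our own.

```python
# Accuracies transcribed from Table 1 (columns: Caltech7, Handwritten, Dermatology, Forest, EEG).
DATASETS = ["Caltech7", "Handwritten", "Dermatology", "Forest", "EEG"]
RESULTS = {
    "TSK-FS":       [0.9153, 0.9100, 0.9075, 0.9295, 0.6050],
    "HTSK":         [0.9390, 0.9700, 0.9600, 0.9595, 0.6750],
    "MBGD-RDA":     [0.9356, 0.9600, 0.9700, 0.9459, 0.6650],
    "TSK-MBGD-UR":  [0.9254, 0.9600, 0.9550, 0.9459, 0.5950],
    "TSK-MBGD-BN":  [0.9254, 0.9675, 0.9324, 0.9189, 0.6000],
    "HTSK-LN":      [0.9390, 0.9500, 0.9650, 0.9324, 0.6450],
    "HTSK-LN-RELU": [0.9322, 0.9700, 0.9600, 0.9595, 0.6550],
    "DMCCA-TSK":    [0.9288, 0.9100, 0.9375, 0.8514, 0.4200],
    "DGCCA-TSK":    [0.9390, 0.8900, 0.8650, 0.8514, 0.6150],
    "DTCCA-TSK":    [0.9254, 0.9400, 0.8950, 0.8243, 0.5700],
    "DCCAEG-TSK":   [0.9119, 0.9300, 0.9075, 0.9459, 0.6600],
    "MVD-TSK-FS":   [0.9322, 0.9100, 0.9850, 0.9595, 0.5750],
}


def best_models(results, datasets):
    """For each dataset, list every model that attains the highest accuracy (ties kept)."""
    best = {}
    for j, name in enumerate(datasets):
        top = max(acc[j] for acc in results.values())
        best[name] = [m for m, acc in results.items() if acc[j] == top]
    return best


def rank_of(model, results, datasets):
    """Per-dataset rank of `model` (1 = best; tied models share the better rank)."""
    return {
        name: 1 + sum(acc[j] > results[model][j] for acc in results.values())
        for j, name in enumerate(datasets)
    }
```

On the transcribed numbers, MVD-TSK-FS is the sole best model on Dermatology, ties for best on Forest, and trails several baselines on Caltech7, Handwritten, and EEG.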
Our experiments show that the proposed MVD-TSK-FS, which is based on single-view decisions, achieves accuracy competitive with traditional methods that stack multi-view features across different datasets. The results also suggest stronger robustness and generalization ability, and the method better avoids the dimension explosion problem. When traditional multi-view TSK classification methods use the features of multiple views, dimension explosion increases the classifier's computational cost and reduces its efficiency. MVD-TSK-FS instead makes an individual decision for each view and selects the final classification result based on the confidence information of each view, which avoids dimension explosion and improves the efficiency and practicality of the algorithm.
5 Conclusion Experiments on datasets from different domains demonstrate the generality and wide applicability of the MVD-TSK-FS proposed in this paper. The method makes a separate decision in each view and then selects the most reliable decision results to obtain a comprehensive classification result, and it is highly competitive in accuracy with other algorithms. In summary, this study developed an innovative multi-view TSK classification method, verified its performance on datasets from multiple domains, and thereby offered a new way to address the excessive dimensionality of traditional multi-view TSK classification. The experimental results demonstrate the feasibility and application potential of the proposed method. In future research, we will further optimize and extend the method so that the model achieves higher accuracy.
References
1. Cui, Y., Wu, D., Xu, Y.: Curse of dimensionality for TSK fuzzy neural networks: explanation and solutions. In: 2021 International Joint Conference on Neural Networks (IJCNN), pp. 1–8. IEEE (2021)
2. Wu, D., Yuan, Y., Huang, J., Tan, Y.: Optimize TSK fuzzy systems for regression problems: minibatch gradient descent with regularization, DropRule, and AdaBound (MBGD-RDA). IEEE Trans. Fuzzy Syst. 28(5), 1003–1015 (2019)
3. Cui, Y., Wu, D., Huang, J.: Optimize TSK fuzzy systems for classification problems: minibatch gradient descent with uniform regularization and batch normalization. IEEE Trans. Fuzzy Syst. 28(12), 3065–3075 (2020)
4. Cui, Y., Xu, Y., Peng, R., Wu, D.: Layer normalization for TSK fuzzy system optimization in regression problems. IEEE Trans. Fuzzy Syst. 31(1), 254–264 (2022)
5. Jiang, Y., Chung, F.L., Wang, S., Deng, Z., Wang, J., Qian, P.: Collaborative fuzzy clustering from multiple weighted views. IEEE Trans. Cybern. 45(4), 688–701 (2014)
6. Wong, H.S., Wang, L., Chan, R., Zeng, T.: Deep tensor CCA for multi-view learning. IEEE Trans. Big Data 8(6), 1664–1677 (2021)
7. Jiang, Y., Zhang, Y., Lin, C., Wu, D., Lin, C.T.: EEG-based driver drowsiness estimation using an online multi-view and transfer TSK fuzzy system. IEEE Trans. Intell. Transp. Syst. 22(3), 1752–1764 (2020)
8. Yin, J., Sun, S.: Incomplete multi-view clustering with cosine similarity. Pattern Recogn. 123, 108371 (2022)
9. Tian, Y., Fu, S., Tang, J.: Incomplete-view oriented kernel learning method with generalization error bound. Inf. Sci. 581, 951–977 (2021)
10. Liu, X., et al.: Late fusion incomplete multi-view clustering. IEEE Trans. Pattern Anal. Mach. Intell. 41(10), 2410–2423 (2018)
11. Perry, R., et al.: mvlearn: multiview machine learning in Python. J. Mach. Learn. Res. 22(1), 4938–4944 (2021)
12. Jiang, Y., et al.: Recognition of epileptic EEG signals using a novel multi-view TSK fuzzy system. IEEE Trans. Fuzzy Syst. 25(1), 3–20 (2016)
13. Li, X.L., Chen, M.S., Wang, C.D., Lai, J.H.: Refining graph structure for incomplete multi-view clustering. IEEE Trans. Neural Netw. Learn. Syst. (2016)
14. Han, Z., Zhang, C., Fu, H., Zhou, J.T.: Trusted multi-view classification. arXiv preprint arXiv:2102.02051 (2021)
15. Han, Z., Zhang, C., Fu, H., Zhou, J.T.: Trusted multi-view classification with dynamic evidential fusion. IEEE Trans. Pattern Anal. Mach. Intell. (2022)
16. Zhang, W., Deng, Z., Lou, Q., Zhang, T., Choi, K.S., Wang, S.: TSK fuzzy system towards few labeled incomplete multi-view data classification. arXiv preprint arXiv:2110.05610 (2021)
Expert Knowledge-Driven Clothing Matching Recommendation System
Qianwen Tao¹, Jun Wang², ChunYun Chen², Shuai Zhu¹, and Youqun Shi¹(B)
¹ School of Computer Science and Technology, Donghua University, Shanghai, China
{2212633,2212582}@mail.dhu.edu.cn
² School of Fashion and Art Design, Donghua University, Shanghai, China
{michael_wang,ccy}@dhu.edu.cn
Abstract. Online shopping has become a major consumer channel, but a lack of professional knowledge of clothing matching and the inability to try clothes on often lead to uninformed purchases and high return rates. Based on the experience of senior clothing experts, we established a digital clothing tagging system, quantified the weights of clothing matching, and designed customer image parameters to build a knowledge-graph-based clothing matching recommendation system. System tests show that it can serve at least 2000 concurrent users with a recommendation response time of no more than 1.5 s. Keywords: Expert System · Knowledge Graph · Matching Weights · Personalized Clothing Recommendation · Recommendation Algorithm
1 Introduction With rising consumption levels, people's demand for clothing quality and fashionable dress is also growing [1]. However, the recommendations on current clothing sales platforms still have three main shortcomings: (1) the cold-start problem, caused by the lack of historical purchase information and browsing records for new users [2]; (2) ignoring users' body parameters, which leads to high return and exchange rates on shopping platforms [3]; and (3) a lack of timely iteration in the recommendation system, which leads to recommendations that do not reflect current fashion trends [4]. The method described in this paper builds a professional clothing-matching knowledge base from expert knowledge in the clothing field by collecting literature, tracking fashion trends, and obtaining users' body parameters, and it calculates the matching degree of clothing with scoring algorithms. Temporal parameters [5] are then added to the triples to construct a dynamic knowledge graph [6], on which two types of recommendations are provided: top-and-bottom outfit recommendations based on purchased items, and personalized recommendations based on user feature points.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 559–572, 2023. https://doi.org/10.1007/978-981-99-4761-4_48
Q. Tao et al.
2 Expert Knowledge of Clothing Image design experts study the aesthetics of how clothing is worn on the human body [7], and we construct a complete knowledge base for clothing recommendation based on the theoretical methods they provide. Building the knowledge base involves six steps: knowledge acquisition and pre-processing, establishment of the knowledge system, consultation with experts, modification of the knowledge system, formation of the knowledge base, and digitization of clothing tags and user tags, as shown in Fig. 1.
Fig. 1. Flow chart of expert knowledge base construction.
2.1 Clothing Color Classification Coloro is a garment color classification system [8] that provides 3500 colors, classifies them into primary and secondary categories, and displays them as color blocks to help designers select and use garment colors more accurately. Coloro divides colors into 34 chromatic and 2 achromatic (black and white) categories [9]. The resulting categories group color blocks tightly, grade colors clearly, and present the blocks within each category intuitively, clearly showing the layered structure of colors and the properties of the color blocks, as shown in Fig. 2. 2.2 Garment Label Coding According to the expert rules, garment labels are divided into seven categories, namely color, fabric, style, fit, top collar type, bottom trouser type, and category, and the sub-category labels under each category are coded. For example, in the color label, TF101 denotes earthy yellow and TF102 denotes gold, as shown in Table 1.
Fig. 2. Renderings of the 36 color categories.
Table 1. Clothing label information.
Code  Clothing label     Classes
TF1   Color              36 classes, such as earthy yellow, gold, coffee, earthy red, brown, and light blue
TF2   Fabric             7 classes: cotton-type, linen-type, silk-type, woolen, pure chemical fiber, leather, and blended fabrics
TF3   Style              12 classes: elegant and noble, classical and traditional, ladylike and gentle, cute and sweet, avant-garde novelty, romantic and charming, natural and casual, simple and intellectual, dramatic and exaggerated, handsome teenager, sexy and charming, neutral and casual
TF4   Fit                Tops and bottoms: 3 classes (extra loose, loose, slim fit); suits: 5 classes (X, A, T, H, O)
TF5   Top collar style   With collar: stand-up collar, knotted collar, shirt collar, lace collar, etc.; collarless: round neck, V-neck, U-neck, one-piece collar, boat neck, square neck, etc.
TF6   Bottom style       Skirts: A-line skirt, sheath skirt, pencil skirt, knit skirt, skort, etc.; pants: skinny pants, dad pants, harem pants, wide-leg pants, straight pants, flared pants, etc.
TF7   Category           Upper garments: shirts, T-shirts/POLO shirts, vests, tops, coats, sweaters (non-cardigan), base-layer tops, dresses; lower garments: skirts, pants (long), shorts, leggings
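A hypothetical sketch of the coding scheme of Sect. 2.2 and 2.3, assuming a code is the category code followed by a two-digit, 1-based index in the order the classes are listed. This reproduces TF101 = earthy yellow, TF102 = gold, UF101 = hourglass, and UF102 = inverted triangle; codes for the other classes depend on that assumed order, and Table 1 lists only six of the 36 colors.

```python
# Hypothetical label schemes: (category code, classes in listed order).
COLOR = ("TF1", ["earthy yellow", "gold", "coffee", "earthy red", "brown", "light blue"])
BODY_TYPE = ("UF1", ["hourglass", "inverted triangle", "pear-shaped",
                     "apple-shaped", "straight", "petite"])


def encode(scheme, label):
    """Return the code of `label`, e.g. (COLOR, 'gold') -> 'TF102'."""
    prefix, classes = scheme
    return f"{prefix}{classes.index(label) + 1:02d}"


def decode(scheme, code):
    """Inverse of encode, e.g. (COLOR, 'TF102') -> 'gold'."""
    prefix, classes = scheme
    if not code.startswith(prefix):
        raise ValueError(f"{code!r} does not belong to category {prefix}")
    return classes[int(code[len(prefix):]) - 1]
```

Two digits are enough for every category in Tables 1 and 2, the largest of which (color) has 36 classes.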
2.3 User Tag Encoding According to the expert rules, user tags were divided into six categories, namely body type, face shape, skin color, hairstyle, personality, and height-weight index, and the sub-category tags under each category were coded. For example, in the body type label, the hourglass type is UF101 and the inverted triangle type is UF102, as shown in Table 2.
Table 2. User label information.
Code  User characteristic   Classes
UF1   Body type             Hourglass (X-shaped), inverted triangle (T-shaped), pear-shaped (A-shaped), apple-shaped (O-shaped), straight (H-shaped), petite (I-shaped)
UF2   Face shape            Diamond, oval, oblong, square, inverted triangle, round, heart-shaped, rectangular
UF3   Skin color            Light cool, neutral cool, dark cool, light warm, neutral warm, dark warm
UF4   Hairstyle             Straight hair, ponytail, pill hair, pear hair, bob, and braided hair
UF5   Personality           Romantic, lively, sedate, strict, traditional, cute
UF6   Height-weight index   Thin, normal, overweight, obese
2.4 Matching Rules A weighting factor [10] expresses the importance or influence of a label within the whole outfit; adding weight coefficients yields better visual effects and overall coordination. Color is the most basic element of clothing matching, and correct color matching highlights the overall effect of the outfit and reflects individual fashion taste. Style and category classify and position the overall design of a garment and therefore also strongly affect the overall matching effect. Fabric and fit concern the details of clothing; their matching also affects the overall effect, but more subtly than color and style. We therefore use $w_i$ (i = 1, …, 5) to denote the influence score of each pair of clothing labels on recommendations; these scores are provided by image design experts and can be maintained in the back-end system. For example, the top color and bottom color play an important role in clothing matching, with a weighting factor of 0.91; the remaining weighting factors are shown in Table 3.
Table 3. Clothing - Clothing weighting factors.
Code  Matching rule                                     Weighting factor
w1    Top color - Bottom color                          0.91
w2    Top fabric - Bottom fabric                        0.71
w3    Top style - Bottom style                          0.83
w4    Top fit - Bottom fit                              0.65
w5    Upper garment category - Lower garment category   0.75
In the personalized matching rules, $w'_i$ (i = 1, …, 6) denotes the influence score of each user label in clothing recommendation. According to the rules given by image design experts, the user's personality influences the choice of clothing style, because clothing style expresses an individual's personality and taste and different people have different personalities and aesthetic preferences; the user's personality therefore determines clothing style with a weighting factor of 0.78. The collar type strongly affects how the face shape and hairstyle look: different collar styles leave different impressions, and some can compensate for or highlight certain features. The user's face shape and hairstyle determine the top collar style with weighting factors of 0.93 and 0.65, respectively; the remaining weighting factors are shown in Table 4.
Table 4. Clothing - User feature weighting factors.
Code  Matching rule                                     Weighting factor
w'1   User body type - Garment fit                      0.87
w'2   User face shape - Neckline                        0.93
w'3   User hairstyle - Collar type                      0.65
w'4   User skin color - Clothing color                  0.84
w'5   User personality - Clothing style                 0.78
w'6   Height-weight index - Clothing style              0.77
The matching degree rating table is built from the image design experts' matching rules for different clothing tags and user tags. Its values reflect how well pairs of tags match, and the system computes the overall rating of a clothing combination from these scores. Taking clothing collar type and user face shape as an example, the experts defined rules for six face shapes against eight collar types, where a_ij (i = 1, …, n; j = 1, …, m) denotes the matching degree between face shape i and collar type j. For people with diamond-shaped (rhombus) faces, a deep V-neck may make the face look longer, which does not flatter it, so it is rated 0.2. Oval faces suit round and square collars, rated 0.8, but not pointed or complex collars, rated 0.1. Some of the clothing personalization and matching scoring rules are shown in Table 5.
Q. Tao et al.
Table 5. Clothing collar style - user face shape matching scores.
Face shape | Square collar | Small round neck | Large round neck | U-collar | Small V-neck | Deep V-neck | High collar | Pointed collar
Rhombus | 0.8 | 0.2 | 0.8 | 0.7 | 0.4 | 0.2 | 0.6 | 0.4
Heart | 0.1 | 0.6 | 0.5 | 0.3 | 0.9 | 0.9 | 0.6 | 0.4
Inverted triangle | 0.9 | 0.5 | 0.5 | 0.6 | 0.4 | 0.9 | 0.5 | 0.6
Oval | 0.8 | 0.8 | 0.8 | 0.4 | 0.6 | 0.3 | 0.6 | 0.1
Round | 0.8 | 0.7 | 0.7 | 0.8 | 0.7 | 0.6 | 0.2 | 0.3
Square | 0.9 | 0.9 | 0.3 | 0.7 | 0.6 | 0.3 | 0.9 | 0.1
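Table 5 maps naturally onto a lookup table. The Python sketch below stores each row and retrieves a_ij for a face shape i and collar type j, which is the lookup that Step 3 of the scoring algorithms performs. The names `COLLARS`, `FACE_COLLAR_SCORES`, `a_ij` and `best_collars` are our own, hypothetical choices; the scores are copied from Table 5.

```python
# Collar types in the column order of Table 5.
COLLARS = [
    "square collar", "small round neck", "large round neck", "U-collar",
    "small V-neck", "deep V-neck", "high collar", "pointed collar",
]

# Rows of Table 5: face shape i -> matching scores a_ij for each collar type j.
FACE_COLLAR_SCORES = {
    "rhombus":           [0.8, 0.2, 0.8, 0.7, 0.4, 0.2, 0.6, 0.4],
    "heart":             [0.1, 0.6, 0.5, 0.3, 0.9, 0.9, 0.6, 0.4],
    "inverted triangle": [0.9, 0.5, 0.5, 0.6, 0.4, 0.9, 0.5, 0.6],
    "oval":              [0.8, 0.8, 0.8, 0.4, 0.6, 0.3, 0.6, 0.1],
    "round":             [0.8, 0.7, 0.7, 0.8, 0.7, 0.6, 0.2, 0.3],
    "square":            [0.9, 0.9, 0.3, 0.7, 0.6, 0.3, 0.9, 0.1],
}


def a_ij(face: str, collar: str) -> float:
    """Matching degree between a face shape and a collar type (Table 5)."""
    return FACE_COLLAR_SCORES[face][COLLARS.index(collar)]


def best_collars(face: str, k: int = 3) -> list:
    """The k best-scoring collar types for a face shape; ties keep table order."""
    ranked = sorted(zip(COLLARS, FACE_COLLAR_SCORES[face]), key=lambda p: p[1], reverse=True)
    return ranked[:k]


print(a_ij("rhombus", "deep V-neck"))  # 0.2
print(best_collars("oval"))
```

Python's `sorted` is stable even with `reverse=True`, so tied collars (the three 0.8 scores for an oval face) keep their Table 5 order.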
3 Personalized Recommendation Algorithm
3.1 Top and Bottom Clothing Matching Recommendation
Top and bottom matching recommendation works as follows: if the garment the user is currently browsing is a top, bottoms with a high overall matching degree are recommended; similarly, if it is a bottom, tops with a high overall matching degree are recommended. When a user adds a garment to the shopping cart or decides to buy it, the features of the top are known, T = {color TF1, fabric TF2, style TF3, pattern TF4, collar type TF5, category TF7}. A database query with a random function, with the result set limited to 100 items, then retrieves the features of candidate bottoms, T = {color TF1, fabric TF2, style TF3, pattern TF4, trouser type TF6, category TF7}. The user-selected garment and each garment retrieved from the database are scored with Eq. (1):
$$S = w_1 a_1 + w_2 a_2 + w_3 a_3 + w_4 a_4 + w_5 a_5 = \sum_{i=1}^{5} w_i a_i \qquad (1)$$
Here, a_i (i = 1, …, 5) is the matching score between the i-th pair of clothing tags and w_i (i = 1, …, 5) is the corresponding weight; the full algorithm flow is shown in Table 6. In the output, fuzzy logic expresses the degree of matching between garments: a label-pair score of 0–0.2 means the match is “very unsuitable”, 0.2–0.4 “unsuitable”, 0.4–0.6 “average”, 0.6–0.8 “suitable”, and 0.8–1 “perfect”.
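A minimal Python sketch of Eq. (1) and the fuzzy labels, assuming each label pair has already been scored in [0, 1]. The weights are those of Table 3. The label keys (`color`, `fabric`, ...) and the convention that a boundary value such as 0.2 falls into the lower interval are our assumptions, since the text does not say how boundaries are handled.

```python
# Weights w1..w5 from Table 3 (clothing-clothing matching rules).
TABLE3_WEIGHTS = {
    "color": 0.91,     # w1: top color - bottom color
    "fabric": 0.71,    # w2: top fabric - bottom fabric
    "style": 0.83,     # w3: top style - bottom style
    "pattern": 0.65,   # w4: top pattern - bottom pattern
    "category": 0.75,  # w5: top category - bottom category
}

# Fuzzy terms for a single label-pair score a_i in [0, 1].
# Assumption: each interval includes its upper bound (0.2 -> "very unsuitable").
FUZZY_LEVELS = [
    (0.2, "very unsuitable"),
    (0.4, "unsuitable"),
    (0.6, "average"),
    (0.8, "suitable"),
    (1.0, "perfect"),
]


def fuzzy_label(a: float) -> str:
    """Return the fuzzy term for a label-pair matching score a_i."""
    if not 0.0 <= a <= 1.0:
        raise ValueError("matching scores must lie in [0, 1]")
    for upper, term in FUZZY_LEVELS:
        if a <= upper:
            return term
    raise AssertionError("unreachable for scores in [0, 1]")


def outfit_score(scores: dict, weights: dict = TABLE3_WEIGHTS) -> float:
    """Eq. (1): S = sum of w_i * a_i over the label pairs present in `scores`."""
    return sum(weights[label] * a for label, a in scores.items())


# Hypothetical label-pair scores for one top/bottom combination.
example = {"color": 0.8, "fabric": 0.6, "style": 0.9, "pattern": 0.5, "category": 0.7}
s = outfit_score(example)
print(round(s, 3))  # 2.751
print({k: fuzzy_label(v) for k, v in example.items()})
```

Because the five Table 3 weights sum to 3.85, S lies between 0 and 3.85 when every a_i is in [0, 1].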
Table 6. Clothing label matching scoring algorithm.
➀ Input: top features T = {TF1, TF2, TF3, TF4, TF5, TF7} (6 clothing features); bottom features T = {TF1, TF2, TF3, TF4, TF6, TF7} (6 clothing features).
➁ Calculation:
Step 1: Use a For loop to iterate through the weights in the weight table W and obtain the pair of clothing labels {f1i, f2i} corresponding to each weight w_i.
Step 2: According to the labels {f1i, f2i}, iterate through the input features and use an IF statement to check whether the features exist in the data table; if not, return to Step 1; if so, go to Step 3.
Step 3: Use a For loop to traverse the collocation score table and find the score a_ij corresponding to {f1i, f2i}.
Step 4: Calculate the matching score S of the top and bottom outfit from Eq. (1).
Step 5: Use an IF statement to check whether the collocation score meets the requirement; if S ≥ 2.5, go to Step 6.
Step 6: Save the bottom-garment information to the new result list Result.
Step 7: Repeat Steps 2 to 6 until all candidate bottom garments have been processed.
Step 8: Sort the data in Result from highest to lowest score.
Step 9: Return Result.
➂ Output: Result.
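The steps of Table 6 can be sketched in Python as follows. This is an illustration, not the authors' implementation: garments are plain dicts, the expert score table is a hypothetical dict keyed by (label, top value, bottom value), and the values in `SCORE_TABLE` and the sample garments are invented for the example. Dictionary lookups replace the For-loop traversal of Steps 1 and 3, and a label pair that is missing from either garment or from the score table is skipped, as in Step 2.

```python
# Weights w_i from Table 3, keyed by clothing label.
WEIGHTS = {"color": 0.91, "fabric": 0.71, "style": 0.83, "pattern": 0.65, "category": 0.75}

# Hypothetical expert score table: (label, top value, bottom value) -> a_ij.
SCORE_TABLE = {
    ("color", "white", "navy"): 0.9, ("color", "white", "black"): 0.8,
    ("color", "white", "white"): 0.5,
    ("fabric", "cotton", "denim"): 0.8, ("fabric", "cotton", "silk"): 0.4,
    ("style", "casual", "casual"): 0.9, ("style", "casual", "formal"): 0.3,
    ("pattern", "solid", "solid"): 0.7, ("pattern", "solid", "plaid"): 0.6,
    ("category", "shirt", "jeans"): 0.9, ("category", "shirt", "shorts"): 0.7,
    ("category", "shirt", "skirt"): 0.6,
}

THRESHOLD = 2.5  # Step 5 of Table 6


def rank_bottoms(top, candidates, weights=WEIGHTS, table=SCORE_TABLE, threshold=THRESHOLD):
    """Score each candidate bottom against `top` with Eq. (1); keep and sort those with S >= threshold."""
    result = []
    for bottom in candidates:                      # Step 7: every candidate bottom
        s = 0.0
        for label, w in weights.items():           # Step 1: each weight and its label pair
            if label not in top or label not in bottom:
                continue                           # Step 2: feature missing, next weight
            a = table.get((label, top[label], bottom[label]))  # Step 3: look up a_ij
            if a is None:
                continue                           # no expert rule for this value pair
            s += w * a                             # Step 4: accumulate Eq. (1)
        if s >= threshold:                         # Step 5
            result.append((bottom["id"], round(s, 3)))  # Step 6
    result.sort(key=lambda item: item[1], reverse=True)  # Step 8
    return result                                  # Step 9


top = {"color": "white", "fabric": "cotton", "style": "casual", "pattern": "solid", "category": "shirt"}
bottoms = [
    {"id": "B1", "color": "navy", "fabric": "denim", "style": "casual", "pattern": "solid", "category": "jeans"},
    {"id": "B2", "color": "white", "fabric": "silk", "style": "formal", "pattern": "plaid", "category": "skirt"},
    {"id": "B3", "color": "black", "fabric": "denim", "style": "casual", "pattern": "solid", "category": "shorts"},
]
print(rank_bottoms(top, bottoms))  # [('B1', 3.264), ('B3', 3.023)]
```

B2 scores about 1.83 and is dropped by the S ≥ 2.5 test.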
3.2 Personalized Clothing Recommendation
Personalized clothing recommendation suggests tops, bottoms, and suits that match the user well overall, once the user has entered several of his or her own characteristics. For example, given the user characteristics U = {body type UF1, face shape UF2, skin color UF3, hair type UF4, personality UF5, height and weight index UF6}, the system uses a database query with a random function to select 50 tops and 50 bottoms, and combines the weight coefficients with the matching scores as in Eq. (2):
$$S = w_1 a_1 + w_2 a_2 + w_3 a_3 + w_4 a_4 + w_5 a_5 + w_6 a_6 = \sum_{i=1}^{6} w_i a_i \qquad (2)$$
The personalized clothing recommendation algorithm is given in Table 7.
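As a worked check of Eq. (2), the sketch below uses the Table 4 weights. With every a_i = 1 the maximum score is 0.87 + 0.93 + 0.65 + 0.84 + 0.78 + 0.77 = 4.84, so the threshold S ≥ 3 of Table 7 keeps garments that reach roughly 62% of the maximum. The a_i values in the example are hypothetical, except a_2 = 0.8, which is the Table 5 score for a round face and a U-collar; the rule keys are our own labels.

```python
# Weights w1..w6 from Table 4, keyed by (user label, clothing label).
TABLE4_WEIGHTS = {
    ("body type", "garment fit"): 0.87,
    ("face shape", "neckline"): 0.93,
    ("hair style", "collar type"): 0.65,
    ("skin color", "clothing color"): 0.84,
    ("personality", "clothing style"): 0.78,
    ("height-weight index", "clothing style"): 0.77,
}

MAX_SCORE = sum(TABLE4_WEIGHTS.values())  # 4.84, reached when every a_i = 1
THRESHOLD = 3.0                           # Step 5 of Table 7


def personalized_score(a: dict) -> float:
    """Eq. (2): S = sum_{i=1}^{6} w_i * a_i; rules absent from `a` contribute 0."""
    return sum(w * a.get(rule, 0.0) for rule, w in TABLE4_WEIGHTS.items())


# Hypothetical scores for one user and one top; a_2 = 0.8 is Table 5 (round face, U-collar).
a = {
    ("body type", "garment fit"): 0.7,
    ("face shape", "neckline"): 0.8,
    ("hair style", "collar type"): 0.6,
    ("skin color", "clothing color"): 0.9,
    ("personality", "clothing style"): 0.8,
    ("height-weight index", "clothing style"): 0.7,
}
s = personalized_score(a)
print(round(s, 3), s >= THRESHOLD, round(THRESHOLD / MAX_SCORE, 2))  # 3.662 True 0.62
```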
Table 7. Personalized matching scoring algorithm.
➀ Input: user characteristics U = {UF1, UF2, UF3, UF4, UF5, UF6} (6 user characteristics); top features T = {TF1, TF2, TF3, TF4, TF5, TF7} (6 clothing features).
➁ Calculation:
Step 1: Use a For loop to iterate through the weights in the weight table W and obtain the user label and clothing label {u1i, f1i} corresponding to each weight w_i.
Step 2: According to the labels {u1i, f1i}, iterate through the input features and use an IF statement to check whether the features exist in the data table; if not, return to Step 1; if so, go to Step 3.
Step 3: Use a For loop to traverse the collocation score table and find the score a_ij corresponding to {u1i, f1i}.
Step 4: Calculate the matching score S between the user and the garment from Eq. (2).
Step 5: Use an IF statement to check whether the score meets the requirement; if S ≥ 3, go to Step 6.
Step 6: Save the garment information to the new result list Result.
Step 7: Repeat Steps 2 to 6 until all candidate garments have been processed.
Step 8: Sort the data in Result from highest to lowest score.
Step 9: Return Result.
➂ Output: Result.
4 Dynamic Knowledge Graph-Based Recommendation
Traditional knowledge graph representations [11, 12] cannot express knowledge that keeps changing and ignore its dynamic nature. To address the inability of knowledge graphs to represent time information, this section introduces dynamic knowledge graphs [13, 14].
4.1 Storage of Knowledge Graph
This paper stores data as RDF graphs [16]. Based on the clothing feature labels, color matching rules, and clothing matching rules defined at the schema level, the acquired knowledge is analyzed to obtain clothing recommendation data. The knowledge is stored in a triple structure, and each garment is stored as a separate clothing knowledge graph, as shown in Fig. 3. In the clothing knowledge graph, the clothing recommendation rules are stored as time-stamped triples (user information, score, clothing information, time), (user information, suit_for, suit information, time), (top information, suit_for, bottom information, time),
Fig. 3. Clothing label storage diagram.
and stored in the clothing recommendation graph. The graph is updated dynamically as time passes, new garments are launched, and more users join, as shown in Fig. 4.
Fig. 4. Knowledge graph matching diagram.
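One way to picture the time-stamped triples of Sect. 4.1 is an in-memory store in which each (subject, predicate, object) fact keeps the times at which it was asserted, so the graph can grow as garments and users are added and can still be queried as it stood at an earlier date. This is a hypothetical Python sketch using only the standard library; the paper itself stores RDF graphs, and the class name, method names and identifiers such as `top:T1` are ours.

```python
from collections import defaultdict
from datetime import datetime


class TemporalTripleStore:
    """In-memory store of (subject, predicate, object, time) facts.

    Each (s, p, o) triple keeps every timestamp at which it was asserted, so new
    garments and users can be added over time and the graph can still be
    queried as it stood at an earlier date.
    """

    def __init__(self) -> None:
        self._facts = defaultdict(list)  # (s, p, o) -> sorted timestamps

    def add(self, s: str, p: str, o: str, t: datetime) -> None:
        times = self._facts[(s, p, o)]
        times.append(t)
        times.sort()

    def objects(self, s: str, p: str, as_of=None) -> list:
        """Objects o such that (s, p, o) was asserted at or before `as_of` (any time if None)."""
        return sorted(
            o
            for (s2, p2, o), times in self._facts.items()
            if s2 == s and p2 == p and (as_of is None or times[0] <= as_of)
        )


store = TemporalTripleStore()
store.add("top:T1", "suit_for", "bottom:B1", datetime(2023, 1, 5))
store.add("top:T1", "suit_for", "bottom:B7", datetime(2023, 3, 1))  # garment launched later
store.add("user:U1", "suit_for", "suit:S2", datetime(2023, 2, 10))
print(store.objects("top:T1", "suit_for", as_of=datetime(2023, 2, 1)))  # ['bottom:B1']
print(store.objects("top:T1", "suit_for"))  # ['bottom:B1', 'bottom:B7']
```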
4.2 Dynamic Clothing Recommendation System
Once the knowledge graph is stored, recommendations are made in two modes. In the first, top and bottom outfit matching mode, when a user confirms the purchase of a top, the system searches the knowledge graph for bottoms that match it, sorts them by their stored ratings, and recommends the Top 10 to the user first, as shown in Fig. 5. The second, personalized clothing recommendation mode handles cold-start situations: it recommends suitable tops and bottoms based on the personal information the user enters, as shown in Fig. 6.
Fig. 5. Clothing matching diagram.
Fig. 6. User Matching Diagram.
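The first mode's step of sorting by stored rating and recommending the Top 10 could be sketched as below, assuming the stored ratings are available as a mapping from (item, partner) to the score S. The function name, the data layout and the sample ratings are hypothetical. `heapq.nlargest` avoids sorting every candidate when only the best k are needed.

```python
import heapq


def top_k_matches(ratings: dict, item: str, k: int = 10) -> list:
    """Up to k (partner, rating) pairs for `item`, highest stored rating first.

    `ratings` maps (item, partner) to the stored matching score S.
    """
    candidates = [(partner, s) for (it, partner), s in ratings.items() if it == item]
    return heapq.nlargest(k, candidates, key=lambda pair: pair[1])


ratings = {("T1", "B1"): 3.26, ("T1", "B3"): 3.02, ("T1", "B9"): 2.70, ("T2", "B1"): 3.50}
print(top_k_matches(ratings, "T1", k=2))  # [('B1', 3.26), ('B3', 3.02)]
```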
5 Application Examples and Test Results
The recommendation system was applied to a clothing-styling self-media platform to test whether it recommends clothing accurately, and its performance was tested in two respects: user concurrency and response time.
5.1 Clothing Recommendation Effectiveness
After registering, the user completes the collection of personal feature tags. The page lists feature images for basic information, body type, face shape, skin color, hairstyle, and style, and users select the options that match their own situation, as shown in Fig. 7. Based on this personal information and the assessment results, the system gives personalized dressing recommendations suited to the user. For example, the user’s gender
Fig. 7. User feature information collection page.
is female, weight is 56–60 kg, height is 166–170 cm, body type is X, face shape is round, skin color is “white paper”, hair style is ponytail, and the preferred style is natural and casual, as shown in Fig. 8-a. Suitable clothes are then recommended based on this information. For example, if the user adds a bottom garment to the shopping cart, a suitable top is recommended in the shopping cart, as shown in Fig. 8-b. When the user has an operation history stored in the system, recommendations are also shown on the home page when the user opens the app, as shown in Fig. 8-c.
a) Personalized recommendation results
b) Upper and lower clothing recommendation results
c) Homepage recommendations
Fig. 8. Rendering of recommended results.
5.2 Test Environment and Performance Testing
The tests were run on a Lenovo Xiaoxin Pro16 laptop running Windows 11 Home (Chinese edition), with an AMD Ryzen 7 6800H processor, 16 GB of LPDDR5 memory, a 512 GB SSD,
and the web browsers Google Chrome 100.0.4896.75 and Microsoft Edge 101.0.1210.32. Android phones and iPhones were also used for system testing.
Sample Size. The clothing recommendation system collects a large number of clothing samples covering various brands, garment types, occasions, and target groups; the numbers of samples are shown in Table 8.
Table 8. Number of clothing samples.
Clothing category | Brand source | Total samples
Tops | Rusty Code, Geun Master East, Laughing Han Court | 1356
Bottoms | Redding, Uniqlo, Rusty Code | 1288
Suits | Amore, Danfengye, Ochirly | 953
Concurrency Testing. To simulate users' actual operational behavior and monitor real-time performance, LoadRunner was selected as the stress-load testing tool for predicting system behavior, performance, and request response time. During the concurrency test, multiple users access the system simultaneously, and different concurrency levels are set to test overall system performance. The results are shown in Table 9.
Table 9. Concurrency test results.
Concurrent users | Expected response time | Actual response time | Success rate | CPU usage | Memory usage
200 | 1 s | 0.23 s | 100% | 11% | 10%
400 | 1 s | 0.42 s | 100% | 22% | 20%
600 | 2 s | 0.75 s | 100% | 36% | 34%
800 | 2 s | 0.89 s | 100% | 41% | 40%
1000 | 2 s | 1.05 s | 100% | 50% | 46%
2000 | 3 s | 2.55 s | 100% | 72% | 69%
Analysis of these metrics shows no anomalies during operation with up to 2000 concurrent users: the success rate remains 100% and every measured response time is within its expected value, so the system meets the requirement of 2000 concurrent threads.
Response Time Testing. This test focuses on the response time of the system's modules: the personalized recommendation module, the shopping cart module, and the home page module. The results are shown in Table 10.
Table 10. Response times of the recommendation modules.
Test scenario | Traditional recommendation method | Method in this article
Personalized recommendation module | 2.5 s | 1.4 s
Shopping cart module | 2.2 s | 1.1 s
Home module | 2.1 s | 0.9 s
The test results show that the response time of the proposed method is below 1.5 s in all three modules, 44% to 57% lower than that of the traditional recommendation method. This indicates that the recommendation system designed in this article meets users' needs and achieves the expected goal.
6 Conclusion
The expert knowledge-based clothing recommendation system proposed in this paper constructs a clothing expert knowledge base, implements the tag matching scoring algorithm through clothing tag encoding, and stores the results in a dynamic knowledge graph for efficient recommendation. The results show that the system can quickly recommend clothes that match public aesthetics, improve users' retrieval efficiency, and increase user satisfaction. It can help users who find buying clothes difficult and can also be applied on major clothing platforms to create greater value for merchants. In addition, letting users simulate how they would look wearing the clothing in the system, that is, virtual fitting, is another direction in the clothing field worth further research.
References
1. Iwendi, C., Ibeke, E., Eggoni, H., Velagala, S., Srivastava, G.: Pointer-based item-to-item collaborative filtering recommendation system using a machine learning model. Int. J. Inf. Technol. Decis. Mak. 21(01), 463–484 (2022)
2. Song, X., Han, X., Li, Y., Chen, J., Xu, X.S., Nie, L.: GP-BPR: personalized compatibility modeling for clothing matching. In: Proceedings of the 27th ACM International Conference on Multimedia, pp. 320–328 (2019)
3. Deldjoo, Y., Schedl, M., Cremonesi, P., Pasi, G.: Recommender systems leveraging multimedia content. ACM Computing Surveys (CSUR) 53(5), 1–38 (2020)
4. Chen, W., et al.: POG: personalized outfit generation for fashion recommendation at Alibaba iFashion. In: Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 2662–2670 (2019)
5. Ji, S., Pan, S., Cambria, E., Marttinen, P., Yu, P.S.: A survey on knowledge graphs: representation, acquisition, and applications. IEEE Transactions on Neural Networks and Learning Systems 33(2), 494–514 (2021)
6. Hui, B., Zhang, L., Zhou, X., Wen, X., Nian, Y.: Personalized recommendation system based on knowledge embedding and historical behavior. Appl. Intell. 52(1), 954–966 (2021). https://doi.org/10.1007/s10489-021-02363-w
7. Seymour, S.: Functional Aesthetics. Ambra Verlag (2019)
8. Hosseinian, S.M., Najafi Moghaddam Gilani, V., Mirbaha, B., Abdi Kordani, A.: Statistical analysis for study of the effect of dark clothing color of female pedestrians on the severity of accident using machine learning methods. Mathematical Problems in Engineering 2021, 1–21 (2021)
9. Kodžoman, D.: The psychology of clothing: meaning of colors, body image and gender expression in fashion. Textile & Leather Review 2(2), 90–103 (2019)
10. Chen, X., et al.: Personalized fashion recommendation with visual explanations based on multimodal attention network: towards visually explainable recommendation. In: Proceedings of the 42nd International ACM SIGIR Conference on Research and Development in Information Retrieval, pp. 765–774 (2019)
11. Guo, Q., et al.: A survey on knowledge graph-based recommender systems. IEEE Trans. Knowl. Data Eng. 34(8), 3549–3568 (2020)
12. Chen, X., Jia, S., Xiang, Y.: A review: knowledge reasoning over knowledge graph. Expert Syst. Appl. 141, 112948 (2020)
13. Wu, T., Khan, A., Yong, M., Qi, G., Wang, M.: Efficiently embedding dynamic knowledge graphs. Knowl.-Based Syst. 250, 109124 (2022)
14. Yan, Y., Liu, L., Ban, Y., Jing, B., Tong, H.: Dynamic knowledge graph alignment. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, no. 5, pp. 4564–4572 (2021)
15. Kazemi, S.M., et al.: Representation learning for dynamic graphs: a survey. The Journal of Machine Learning Research 21(1), 2648–2720 (2020)
16. Heath, T., Bizer, C.: Linked data: evolving the web into a global data space. Synthesis Lectures on the Semantic Web: Theory and Technology 1(1), 1–136 (2011)
Research on Construction Method of IoT Knowledge System Based on Knowledge Graph
Qidi Wu, Shuai Zhu, Qianwen Tao, Yucheng Zhao, and Youqun Shi(B)
School of Computer Science and Technology, Donghua University, Shanghai, China
{2212502,2212557,2212582,2212633}@mail.dhu.edu.cn
Abstract. This research aims to deconstruct and reconstruct professional knowledge in the field of the Internet of Things (IoT), extracting and organizing it into a hierarchical tree network structure based on knowledge points. The goal is to summarize and partition the entire IoT knowledge system and to give learners and others studying IoT-related knowledge a simple, easy-to-understand system that removes some learning obstacles. First, we design the architecture of an IoT knowledge system based on a knowledge graph, covering knowledge dimensions, knowledge association, knowledge representation, and knowledge storage. To enrich the knowledge system, we then design a crowdsourcing expansion method, including role division, the crowdsourcing process, and credibility and expertise mechanisms. Finally, we develop a system prototype that implements the basic functions on this foundation.
Keywords: Internet of Things · Knowledge Graph · Crowdsourcing
1 Introduction
Google first proposed the concept of the knowledge graph on May 17, 2012 [1] and released a new generation of knowledge search engine that displays information about people, places, and events associated with familiar entities or concepts. A knowledge graph is based on a semantic network structure: it expresses entities and their relationships through connections between network nodes and can be displayed visually. In the narrow sense, knowledge graphs are used by Internet companies to organize web data from a semantic perspective [2], providing intelligent search services over large-scale knowledge bases. Formally, a knowledge graph is a knowledge carrier represented by a graph data structure that describes objects in the objective world and their relationships, where nodes represent objects and edges represent the relationships between them [3]. In implementation, a knowledge graph uses the Resource Description Framework (RDF) of the Semantic Web to represent both the knowledge system and the instance data, forming a complete knowledge system. In the broad sense, knowledge descriptions, instance data, related standards, technologies, tools, and application systems together constitute the general knowledge graph [4].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 573–585, 2023.
https://doi.org/10.1007/978-981-99-4761-4_49
Q. Wu et al.
Building a knowledge graph for the IoT field enables quick discovery of valuable knowledge resources, acquisition of knowledge links, formation of a knowledge network, and construction of a highly reliable knowledge system that meets the learning and communication needs of experts, scholars, enterprises, and other users. However, building and expanding an IoT knowledge graph raises several issues. First, the knowledge does not cover the entire industry chain, and builders often lack a professional domain background, so external professional resources are needed to enrich the knowledge base through collective wisdom. Second, nodes in existing knowledge graphs often contain only text; this display form is not diverse enough, and the graph could be expanded with videos, pictures, PPTs, web pages, and so on. A knowledge system also requires a large amount of knowledge data, and in practice relying on experts alone to add knowledge is inefficient and cannot expand the data quickly. Crowdsourcing offers a solution to this problem. In 2006, Jeff Howe, a journalist at the American magazine Wired, coined the term “crowdsourcing” in the magazine and explained its meaning [5]. He saw crowdsourcing as a business transformation in which companies outsource work to the public through the Internet, and any participant (including amateur enthusiasts) can use the network platform to contribute ideas, solve problems, and receive corresponding remuneration. The key prerequisites for implementing it are building the network platform and connecting potential participants through the network. Crowdsourcing is now widely used worldwide, with well-known examples such as Wikipedia.
As a distributed model of problem solving and production, crowdsourcing is Internet-based and characterized by open production, independent participation, and autonomous collaboration. It has achieved good results in fields that are difficult for computers to handle alone. With the rapid development and widespread application of crowdsourcing, the drawbacks of traditional crowdsourcing systems have gradually emerged. Traditional centralized crowdsourcing systems are vulnerable to single points of failure and DDoS attacks, and they may encourage undesirable behaviors such as free-riding and false reporting [6]. Meanwhile, the task allocation process risks leaking participants' sensitive information. Blockchain technology offers a possible solution to these problems [7].
Research on Construction Method of IoT Knowledge System
2 The Architecture Design of the IoT Knowledge Graph
2.1 The Dimension of Knowledge
Following the definition of a knowledge graph, IoT knowledge points are called knowledge entities, and the relationships between knowledge points are called object properties. The characteristics of the IoT domain are analyzed along four dimensions: knowledge entities, object properties, data properties, and domain resources.
(1) Dimension of Knowledge Entities. The IoT domain has many types of entities, spanning the perception layer, the transmission layer, the application layer, and common technologies, so its knowledge entities fall into at least four categories. By contrast, knowledge graphs in other professional fields usually have around three categories of entities; for example, the clothing recommendation knowledge graph divides its entities into clothing elements, personal elements, and rule elements. When constructing an IoT knowledge graph, the classification of knowledge entities therefore deserves attention, and the types and quantities of the main domain entities should be chosen carefully.
(2) Dimension of Object Properties. Relationships between IoT knowledge entities are relatively complex. IoT spans many technical fields, such as communication, computation, and storage, and entities may be related in various ways, such as “support”, “drive”, and “depend”. Constructing an IoT knowledge graph therefore requires delineating a clear hierarchical structure and organizing an accurate knowledge system.
(3) Dimension of Data Properties. IoT knowledge points are described in various forms, such as web pages, rich text, images, tables, and videos. When constructing an IoT knowledge graph, the data properties of the knowledge points must be defined.
For example, properties such as sensor type, communication protocol, and energy consumption can be set for perception-layer knowledge points, and properties such as application scenarios, data processing methods, and security performance for application-layer knowledge points. Suitable data properties are needed to describe each knowledge point accurately.
(4) Dimension of Domain Resources. The IoT domain has a wide range of resource types whose forms and sources vary greatly. Resource forms include text, tables, images, and videos, and resources may come from professional books and online education platforms in the field or from IoT-related websites and forums. In real learning scenarios, learners usually have to search for related resources themselves to acquire specific knowledge points, and there is no efficient, unified way to manage resources. Therefore, to handle domain resources from such diverse sources and in such varied forms, an effective
resource management strategy combined with the knowledge graph is needed.
In summary, IoT knowledge entities are diverse, the associations between them are rich, the knowledge is highly structured, resources come from many sources, and the overall data scale is large. The construction of an IoT knowledge graph should consider all four dimensions (knowledge entities, object properties, data properties, and domain resources) to meet the needs of users in different fields and at different levels.
2.2 Knowledge Association
The intelligent system uses dynamic membership values to express the association structure of the knowledge graph, together with a rule-based reasoning mechanism. In addition to basic properties such as encoding and naming, the degree of association between nodes is represented by a number between 0 and 1, which can be adjusted dynamically according to users' access habits and frequency; this helps recommend associated learning content. Beyond association degrees, which express how related concepts are, deduction rules and backward reasoning rules guide users through progressive learning and review-style learning. There are thus three kinds of rules: association rules, deduction rules, and backward reasoning rules. Association rules link knowledge nodes within the same ontology concept; deduction rules link knowledge nodes across ontologies from the former to the latter and from higher to lower levels; backward reasoning rules link knowledge nodes across ontologies from the latter to the former and from lower to higher levels. Under association rules, knowledge nodes within the same ontology are joined by edges that record their degree of relatedness: the higher the relatedness, the more similar the content the nodes describe.
Under deduction rules, the knowledge system can infer the subsequent knowledge node from the preceding one, or a lower-level node from a higher-level one, based on ontologies with sequential and hierarchical relationships. Under backward reasoning rules, it can trace back from a subsequent node to the preceding one, or from a lower-level node to a higher-level one, based on the same sequential and hierarchical relationships (see Fig. 1).
Fig. 1. Association rules of knowledge graph
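The three rule types of Sect. 2.2 can be illustrated with a small Python model. It is a sketch under stated assumptions, not the system's implementation: sequential and hierarchical relations are merged into one directed list of (former/higher, latter/lower) pairs, and `record_access` uses a hypothetical update rule (move the degree a fraction `rate` of the way toward 1 on each co-access), since the text only says the degree is adjusted by access habits and frequency. The node codes and example knowledge points are illustrative.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class KnowledgeNode:
    code: str      # encoding of the knowledge point
    name: str
    ontology: str  # ontology concept the node belongs to


class KnowledgeGraph:
    def __init__(self) -> None:
        self.nodes = {}   # code -> KnowledgeNode
        self.degree = {}  # frozenset({a, b}) -> association degree in [0, 1]
        self.order = []   # (former/higher, latter/lower) pairs across ontologies

    def add(self, node: KnowledgeNode) -> None:
        self.nodes[node.code] = node

    def associate(self, a: str, b: str, degree: float) -> None:
        """Association rule edge: only between distinct nodes of the same ontology."""
        if a == b or self.nodes[a].ontology != self.nodes[b].ontology:
            raise ValueError("association edges join distinct nodes of one ontology")
        if not 0.0 <= degree <= 1.0:
            raise ValueError("association degree must lie in [0, 1]")
        self.degree[frozenset((a, b))] = degree

    def precede(self, former: str, latter: str) -> None:
        """Record a sequential or hierarchical link from former/higher to latter/lower."""
        self.order.append((former, latter))

    def related(self, code: str, min_degree: float = 0.5) -> list:
        """Association rule: same-ontology neighbors above min_degree, strongest first."""
        hits = []
        for pair, d in self.degree.items():
            if code in pair and d >= min_degree:
                (other,) = pair - {code}
                hits.append((d, other))
        return [other for _, other in sorted(hits, reverse=True)]

    def deduce(self, code: str) -> list:
        """Deduction rule: from a former/higher node to its latter/lower nodes."""
        return [latter for former, latter in self.order if former == code]

    def backtrack(self, code: str) -> list:
        """Backward reasoning rule: from a latter/lower node back to former/higher nodes."""
        return [former for former, latter in self.order if latter == code]

    def record_access(self, a: str, b: str, rate: float = 0.1) -> float:
        """Hypothetical update: each co-access moves the degree a fraction `rate` toward 1."""
        key = frozenset((a, b))
        self.degree[key] += rate * (1.0 - self.degree[key])
        return self.degree[key]


g = KnowledgeGraph()
for node in [
    KnowledgeNode("K1", "IoT architecture", "overview"),
    KnowledgeNode("K2", "RFID", "perception layer"),
    KnowledgeNode("K3", "NFC", "perception layer"),
    KnowledgeNode("K4", "Sensor", "perception layer"),
    KnowledgeNode("K5", "ZigBee", "transmission layer"),
]:
    g.add(node)
g.associate("K2", "K3", 0.8)
g.associate("K2", "K4", 0.4)
g.precede("K1", "K2")
g.precede("K2", "K5")
print(g.related("K2"), g.deduce("K2"), g.backtrack("K2"))  # ['K3'] ['K5'] ['K1']
```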
2.3 Knowledge Representation
The knowledge graph system supports multiple file formats, including PPT, PDF, Word, HTML and rich-text web pages, images, animations, videos, and audio. File resources are identified by URIs, and a dedicated index table stores the address of each knowledge point: each knowledge point ID in the table maps to its address (URL). In addition, as technology develops, timestamps and background information about the inventors and institutions behind each knowledge point are added, so that users can fully understand the technological background, the knowledge contributors, and how the achievements accumulated.
File classification standards and specifications are crucial for a knowledge graph system: they keep the representation style of the knowledge system consistent and ensure that different types of files are organized in a standard way and can be managed and retrieved effectively. The standards cover file naming rules, file storage paths, and file type categorization. File naming rules require short, clear names that reflect the file's topic and avoid duplication. File storage paths classify files hierarchically by type (images, documents, videos, etc.) and subdivide them further into subcategories such as PPT, Word, and HTML. File type categorization groups files by function and content into categories such as images, videos, documents, and animations, each with standard file extensions such as .jpg, .ppt, and .mp4.
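A file classification standard of this kind can be expressed as a lookup from file extension to category and subcategory. In the hypothetical Python sketch below, the extension table, the `<id>_<topic slug><extension>` naming rule and the path layout are our assumptions, chosen only to illustrate short topical names and hierarchical storage paths.

```python
from pathlib import PurePosixPath

# Hypothetical classification table: extension -> (category, subcategory).
FILE_CATEGORIES = {
    ".jpg": ("images", "jpg"), ".png": ("images", "png"),
    ".ppt": ("documents", "ppt"), ".pptx": ("documents", "ppt"),
    ".doc": ("documents", "word"), ".docx": ("documents", "word"),
    ".pdf": ("documents", "pdf"), ".html": ("documents", "html"),
    ".gif": ("animations", "gif"),
    ".mp4": ("videos", "mp4"),
    ".mp3": ("audio", "mp3"),
}


def storage_path(point_id: int, topic: str, filename: str) -> str:
    """Hierarchical storage path for a knowledge file.

    Assumed naming rule: '<knowledge point id>_<topic slug><extension>', which
    keeps names short, reflects the topic, and ties each file to its knowledge point.
    """
    ext = PurePosixPath(filename).suffix.lower()
    if ext not in FILE_CATEGORIES:
        raise ValueError(f"unsupported file type: {ext!r}")
    category, subcategory = FILE_CATEGORIES[ext]
    slug = "-".join(topic.lower().split())
    return f"{category}/{subcategory}/{point_id}_{slug}{ext}"


print(storage_path(42, "RFID basics", "lecture.MP4"))  # videos/mp4/42_rfid-basics.mp4
```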
2.4 File Storage
(1) File Indexing. Indexing files in a knowledge graph is a critical technology for efficiently retrieving and locating the entities, properties, and relationships within it. In this article, file resources are uniquely identified by Uniform Resource Identifiers (URIs) to facilitate knowledge sharing. To quickly locate content within knowledge graph files, a hash table is used as the index structure. A hash table is an array-based data structure that uses a hash function to map keys (here, entity URIs) to array indexes. The hash function of this index is given in Eq. (1):
$$H = \left( \sum_{i=0}^{n-1} \mathrm{URI}[i] \times c^{i} \right) \bmod M \qquad (1)$$
Here, URI[i] is the ASCII value of the i-th character of the URI string; c is a constant chosen to reduce the probability of hash collisions; i = 0, 1, 2, …, n − 1 is the position of the character in the URI string, where n is the length of the string; and M is the size of the hash table, usually chosen to be a prime number to further reduce collisions. The function maps a URI to an integer and, through the modulo operation, limits it to the range 0 to M − 1. Hash collisions are resolved by chaining.
(2) MySQL database table design. The knowledge point table stores the basic information of each knowledge point, such as its ID, name, description, and hierarchy level. Its structure is shown in Table 1.
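The index of Eq. (1) with separate chaining might look like the following Python sketch. The constants c = 31 and M = 1009 (a prime) are illustrative assumptions, as the text does not give values, and the example URIs are hypothetical. `ord` returns the Unicode code point, which equals the ASCII value for ASCII URIs. Reducing modulo M at every step gives the same H as evaluating the full sum first, while keeping intermediate values small.

```python
class URIIndex:
    """Hash index from entity URI to file address, using separate chaining.

    Implements Eq. (1), H = (sum_i URI[i] * c**i) mod M. The defaults c = 31 and
    M = 1009 (a prime) are illustrative assumptions.
    """

    def __init__(self, size: int = 1009, c: int = 31) -> None:
        self.size = size
        self.c = c
        self.buckets = [[] for _ in range(size)]  # each bucket is a chain of (uri, address)

    def hash_uri(self, uri: str) -> int:
        h, power = 0, 1
        for ch in uri:
            h = (h + ord(ch) * power) % self.size  # add URI[i] * c**i, reduced mod M
            power = (power * self.c) % self.size   # c**(i+1) mod M for the next character
        return h

    def put(self, uri: str, address: str) -> None:
        chain = self.buckets[self.hash_uri(uri)]
        for i, (key, _) in enumerate(chain):
            if key == uri:
                chain[i] = (uri, address)  # update an existing entry
                return
        chain.append((uri, address))       # colliding keys share the chain

    def get(self, uri: str):
        for key, address in self.buckets[self.hash_uri(uri)]:
            if key == uri:
                return address
        return None


index = URIIndex()
index.put("http://example.org/iot/RFID", "https://oss.example.com/kg/rfid.pdf")
print(index.get("http://example.org/iot/RFID"))  # https://oss.example.com/kg/rfid.pdf
```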
Table 1. Knowledge node storage table.

| Field name | Data type | Length | Field description |
|---|---|---|---|
| Id | bigint | 20 | Knowledge point table ID |
| point_id | bigint | 20 | Knowledge point ID |
| point_name | varchar | 32 | Knowledge point name |
| discription | varchar | 255 | Knowledge point description |
| Level | bigint | 20 | Knowledge level |
| create_time | datetime | -- | Creation time |
| update_time | datetime | -- | Update time |
The knowledge relation table contains detailed information about each relationship as well as about the knowledge points it connects: for example, the relationship description, the relationship name, the knowledge point ID and name, the ID and name of the preceding knowledge point, and the creation and update times. This storage structure effectively expresses the associations between knowledge points and can be applied to knowledge association, knowledge retrieval, knowledge recommendation, knowledge discovery, and similar tasks. The specific fields are shown in Table 2.

Research on Construction Method of IoT Knowledge System

Table 2. Knowledge relation storage table.

| Field name | Data type | Length | Field description |
|---|---|---|---|
| Id | bigint | 20 | Relation table ID |
| edge_id | bigint | 20 | Relationship ID |
| discription | varchar | 255 | Relationship content description |
| edge_name | varchar | 32 | Relationship name |
| point_id | bigint | 20 | Knowledge point ID |
| point_name | varchar | 32 | Knowledge point name |
| before_id | bigint | 20 | Preceding knowledge point ID |
| before_name | varchar | 32 | Preceding knowledge point name |
| create_time | datetime | -- | Creation time |
| update_time | datetime | -- | Update time |

The knowledge file address table contains the URL at which each knowledge point's knowledge file is stored, together with detailed information about the knowledge point: the address table ID, the file download address, the knowledge point ID, the knowledge name, a description of the address content, the file content address, the uploading user, the difficulty of the file content, file tags, the upload time, the file name, and the creation and update times. This information records where each knowledge file is stored and makes it convenient for back-end services to operate on it, as shown in Table 3.

(3) OSS distributed cloud storage service. Knowledge service systems in the IoT field must store massive amounts of structured, semi-structured, and unstructured data (such as images, audio, video, and log files). OSS, a distributed cloud storage service, offers elastic scalability, high availability and reliability, security, a simple API and SDK, and data processing capabilities, which makes it well suited to storing the knowledge files of the knowledge graph. This article therefore adopts the OSS distributed cloud storage service.
Table 3. Knowledge file address storage table.

| Field name | Data type | Length | Field description |
|---|---|---|---|
| Id | bigint | 20 | Address table ID |
| address_id | bigint | 20 | File download address |
| point_id | bigint | 20 | Knowledge point number |
| point_name | varchar | 32 | Knowledge name |
| discription | varchar | 255 | Address content description |
| content_address | varchar | 255 | File content address |
| uploader_id | bigint | 20 | Uploading user |
| Degree | bigint | 32 | File content difficulty |
| Tag | varchar | 32 | File labels |
| upload_time | datetime | -- | Upload time |
| file_name | varchar | 32 | File name |
| create_time | datetime | -- | Creation time |
| update_time | datetime | -- | Update time |
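To show how Tables 1–3 relate, the Python sketch below creates a reduced version of the schema and queries the prerequisites and files of a knowledge point. It uses the standard-library `sqlite3` module as a stand-in for MySQL and keeps only a subset of the columns; the table names `knowledge_point`, `knowledge_relation`, and `knowledge_file_address` and the helper functions are hypothetical. Column names follow the tables, including the original spelling `discription`.

```python
import sqlite3

# Reduced, hypothetical schema; SQLite stands in for MySQL.
SCHEMA = """
CREATE TABLE knowledge_point (
    id INTEGER PRIMARY KEY,
    point_id INTEGER UNIQUE NOT NULL,
    point_name VARCHAR(32) NOT NULL,
    discription VARCHAR(255),
    level INTEGER
);
CREATE TABLE knowledge_relation (
    id INTEGER PRIMARY KEY,
    edge_id INTEGER NOT NULL,
    edge_name VARCHAR(32),
    point_id INTEGER NOT NULL REFERENCES knowledge_point(point_id),
    before_id INTEGER REFERENCES knowledge_point(point_id)
);
CREATE TABLE knowledge_file_address (
    id INTEGER PRIMARY KEY,
    point_id INTEGER NOT NULL REFERENCES knowledge_point(point_id),
    content_address VARCHAR(255) NOT NULL,
    file_name VARCHAR(32)
);
"""

def prerequisites(conn, point_id):
    """Names of the knowledge points recorded as preceding `point_id` (via before_id)."""
    rows = conn.execute(
        """SELECT p.point_name
           FROM knowledge_relation r
           JOIN knowledge_point p ON p.point_id = r.before_id
           WHERE r.point_id = ?
           ORDER BY p.point_name""",
        (point_id,),
    )
    return [name for (name,) in rows]

def files_for(conn, point_id):
    """File addresses (e.g., OSS URLs) stored for a knowledge point."""
    rows = conn.execute(
        "SELECT content_address FROM knowledge_file_address WHERE point_id = ? ORDER BY id",
        (point_id,),
    )
    return [address for (address,) in rows]
```

Table 2 stores `point_name` and `before_name` redundantly, which avoids a join when reading; the sketch derives names with a join instead, which avoids keeping duplicated names in sync when a knowledge point is renamed.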
3 Content Expansion

The system expands its content through a crowdsourcing model. To motivate crowdsourcing workers to keep contributing to the knowledge graph and to protect their work from infringement by others, blockchain technology is incorporated into the crowdsourcing expansion method. Below, we describe this method from four aspects: role division, crowdsourcing methods, the basic process of crowdsourcing expansion, and the professionalism and credibility mechanism.

3.1 Role Division

The roles in the crowdsourcing expansion system fall into four categories: task requesters, task workers, task reviewers, and knowledge learners. To keep the quality of the knowledge system high, every role in the system is assigned a credit score C and a professional degree score (X, score), where X is the name of the professional field.

Requester: To ensure the high quality of the knowledge system, a user can issue knowledge tasks in a given field only after reaching a certain credit score and professional degree in that field. When issuing a task, the requester must pledge a certain amount of credit, which discourages malicious task publishing and prevents system failures.

Worker: Workers upload knowledge to the system, which makes them an extremely important role. Workers can search the system for tasks that match their professional expertise; if they meet a task's conditions, they can accept it and complete its content.
Reviewer: In order to ensure the professionalism and high quality of the knowledge, reviewers are added to score the quality of the uploaded knowledge. Before accepting a review task, a reviewer needs to meet certain credit and professional degree requirements. Learner: Learners can browse and learn the content that has already been constructed in the system. During the browsing and learning process, learners can selectively pay the corresponding fees to access the content. After learning the knowledge, the system will increase the user’s professional degree score in the corresponding field. In this system (see Fig. 2), all participants must register their identities on the blockchain, which involves being assigned public keys, private keys, and address information (with private keys securely held by users themselves).
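The eligibility rules for the roles can be sketched in Python as follows. The numeric thresholds, the class `Participant`, and the helper functions are hypothetical: the text states only that requesters and reviewers must reach "a certain" credit and professional degree, that requesters pledge credit when issuing a task, and that learning raises a learner's professional degree score. Workers would be checked with `meets` against each task's own thresholds.

```python
from dataclasses import dataclass, field

# Hypothetical thresholds; the text only says "a certain" credit and professional degree.
ISSUE_MIN_CREDIT, ISSUE_MIN_PROF = 60, 50
REVIEW_MIN_CREDIT, REVIEW_MIN_PROF = 70, 60

@dataclass
class Participant:
    address: str                                                 # on-chain address from registration
    credit: int                                                  # credit score C
    professional: dict[str, int] = field(default_factory=dict)  # field X -> score
    pledged: int = 0                                             # credit locked by issued tasks

    def prof(self, field_name: str) -> int:
        return self.professional.get(field_name, 0)

def meets(p: Participant, field_name: str, min_credit: int, min_prof: int) -> bool:
    """True if p has enough credit and professional degree in the given field."""
    return p.credit >= min_credit and p.prof(field_name) >= min_prof

def issue_task(requester: Participant, field_name: str, pledge: int) -> bool:
    """A requester must be qualified in the task's field and pledge part of their credit."""
    if not meets(requester, field_name, ISSUE_MIN_CREDIT, ISSUE_MIN_PROF) or pledge > requester.credit:
        return False
    requester.credit -= pledge        # pledged credit is locked until the task ends
    requester.pledged += pledge
    return True

def can_review(reviewer: Participant, field_name: str) -> bool:
    return meets(reviewer, field_name, REVIEW_MIN_CREDIT, REVIEW_MIN_PROF)

def record_learning(learner: Participant, field_name: str, gain: int = 1) -> None:
    """Learning content in a field raises the learner's professional degree score."""
    learner.professional[field_name] = learner.prof(field_name) + gain
```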
Fig. 2. Crowdsourcing expansion: role and task division
3.2 Basic Crowdsourcing Expansion Process

The basic crowdsourcing expansion process consists of the following six steps (see Fig. 3):

(1) User Registration: Owing to the inherent characteristics of blockchain technology, every user must register a unique certificate on the chain before interacting with it. Upon registration, each user receives a public/private key pair. Registration lets the system distinguish user identities and establish reputation profiles, and it lays the foundation for subsequent task issuance, acceptance, and reward allocation.

(2) Task Issuance: Tasks are the core of crowdsourcing. Once a user on the chain meets the system's credit requirements, they are granted the right to issue tasks within their own area of expertise. Based on the storage structure of the knowledge system, tasks fall into two categories: tasks that add new knowledge nodes, and tasks that supplement the learning files of existing knowledge nodes. When issuing a task, the requester must provide a task description, a deadline, and a reward.

(3) Task Acceptance: Users can search the chain for tasks that match their professional degree and credit standing and accept a task once they meet its basic requirements. They must submit their knowledge contribution before the task deadline; otherwise, they are considered in breach of contract and are penalized.

(4) Task Review: Users can search the chain for review tasks that match their professional degree and credit, and can start a review once they meet the basic reviewer requirements. Reviewers must score every knowledge contribution completed by workers and submit their review opinions.
(5) Task Submission: Once the review is complete, the requester can file a task submission request. The smart contract then selects the highest-scoring knowledge contribution, based on the reviewers' scores, for submission. After submission, the contribution officially becomes part of the system, and other users can browse and learn from it.

(6) Reward Allocation: After the task is submitted, the smart contract allocates reward amounts to the requester, workers, and reviewers according to the degree of task completion. The specific reward amounts are set by the requester when posting the task.
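The six-step process can be simulated off-chain with the Python sketch below; it is not a smart contract, and the class and method names are hypothetical. Several details are assumptions not given in the text: contributions are ranked by the mean of reviewer scores, a fixed share of the reward (here 20%) is split equally among reviewers, the rest goes to the winning worker, and the requester's pledged credit is refunded on completion.

```python
from dataclasses import dataclass, field
from statistics import mean

@dataclass
class Task:
    requester: str
    reward: float                 # total reward, set by the requester when posting
    deadline: int                 # submission deadline (e.g., a timestamp)
    pledge: float = 0.0           # credit pledged by the requester
    submissions: dict[str, str] = field(default_factory=dict)      # worker -> content
    scores: dict[str, list[float]] = field(default_factory=dict)   # worker -> reviewer scores
    reviewers: list[str] = field(default_factory=list)
    state: str = "issued"

    def submit(self, worker: str, content: str, now: int) -> None:
        """Step (3): a worker submits a contribution before the deadline."""
        if self.state != "issued":
            raise RuntimeError("task no longer accepts contributions")
        if now > self.deadline:
            raise RuntimeError("deadline missed: breach of contract")
        self.submissions[worker] = content

    def review(self, reviewer: str, scores_by_worker: dict[str, float]) -> None:
        """Step (4): a reviewer must score every contribution."""
        if set(scores_by_worker) != set(self.submissions):
            raise ValueError("every contribution must be scored")
        self.reviewers.append(reviewer)
        for worker, score in scores_by_worker.items():
            self.scores.setdefault(worker, []).append(score)

    def finalize(self, reviewer_share: float = 0.2) -> tuple[str, dict[str, float]]:
        """Steps (5)-(6): pick the best-scored contribution and split the reward."""
        winner = max(self.scores, key=lambda w: mean(self.scores[w]))
        self.state = "submitted"
        review_pool = self.reward * reviewer_share
        payouts = {winner: self.reward - review_pool}
        for reviewer in self.reviewers:
            payouts[reviewer] = payouts.get(reviewer, 0.0) + review_pool / len(self.reviewers)
        payouts[self.requester] = payouts.get(self.requester, 0.0) + self.pledge  # refund pledge
        return winner, payouts
```

Keeping every state change inside one object mirrors how a contract would enforce the order of steps; an on-chain version would also need signatures and token transfers, which are out of scope here.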
Fig. 3. Flowchart of the system's crowdsourcing expansion
4 System Design and Development

To build, step by step, a fully functional knowledge service system that integrates "content expression, association indexing, and multimedia teaching," we first develop a basic IoT knowledge graph system with the structure described in the previous sections; further improvements can then be made on this foundation. The goal of the system is to give students and others learning IoT-related knowledge a simple, easy-to-understand tool that removes some of the obstacles in the learning process, and to help beginners understand IoT-related knowledge better through this diversified approach.
4.1 System Architecture Design

The knowledge system adopts a Browser/Server (B/S) architecture (see Fig. 4). The front-end pages are developed mainly with HTML/CSS and JavaScript, using the React.js framework and the Ant Design component library. These front-end technologies present the system's functional modules and also visualize the knowledge graph. The back end is written in Java with the SpringBoot framework, which simplifies traditional Spring back-end development. Storage follows a two-tier design, with the lowest level storing data in OSS and in MySQL (accessed through MyBatis). SpringBoot together with OSS is used primarily to store the various types of knowledge files uploaded by users; once a file has been uploaded to OSS, users can share its URL with others for downloading, browsing, and learning. MySQL with the MyBatis framework stores the structured data other than user-uploaded knowledge files. Redis serves as a caching database to increase system concurrency, improve performance, and better serve users.
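The text states that Redis caches data in front of MySQL but does not describe the caching strategy. One common choice is the cache-aside pattern, sketched below in Python with an in-memory dictionary standing in for Redis; the class name, the time-to-live (TTL) value, and the `loader` callback (standing in for a MyBatis/MySQL query) are hypothetical.

```python
import time

class CacheAside:
    """Cache-aside lookup: serve from cache, fall back to the loader on a miss or expiry."""

    def __init__(self, loader, ttl_seconds=60.0, clock=time.monotonic):
        self.loader = loader          # e.g., a function that queries the database
        self.ttl = ttl_seconds
        self.clock = clock            # injectable clock, useful for testing
        self._store = {}              # key -> (value, expiry time)
        self.misses = 0

    def get(self, key):
        entry = self._store.get(key)
        now = self.clock()
        if entry is not None and entry[1] > now:
            return entry[0]           # fresh cache hit
        self.misses += 1
        value = self.loader(key)      # miss or expired: read from the database
        self._store[key] = (value, now + self.ttl)
        return value

    def invalidate(self, key):
        """Call after a write so that later reads reload fresh data."""
        self._store.pop(key, None)
```

Expiring entries bounds how stale a cached knowledge node can be, while explicit invalidation after writes keeps edits visible immediately.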
Fig. 4. Technical framework of knowledge graph system
4.2 Main Functions of the System

The system was developed following the structure described above. Because it has many functions, this article presents only the most important ones and their pages.

(1) Knowledge System Display Function: Visualizes the entities, relationships, and attributes of the knowledge graph to help users better understand and learn the structure and content of the IoT (see Fig. 5).
Fig. 5. The overall structure of the IoT knowledge graph system
(2) Search Function: The system provides a dedicated search interface. Users enter a query in the search box at the top, and the system retrieves and displays all related nodes by fuzzy keyword matching. Users can then operate on the returned nodes, for example by adding nodes, deleting nodes, or viewing details. The search results for the keyword "application" are shown in Fig. 6.
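The text does not define its fuzzy matching. A plausible minimal version, sketched below in Python as an assumption, is case-insensitive substring matching (the behavior of SQL `LIKE '%keyword%'`), with a fallback to the standard-library `difflib.get_close_matches` so that small typos still find a node. The function name and the cutoff value are hypothetical.

```python
import difflib

def fuzzy_search(keyword, node_names, cutoff=0.6):
    """Return node names matching `keyword`, tolerating small typos."""
    kw = keyword.strip().lower()
    if not kw:
        return []
    # Case-insensitive substring match, analogous to LIKE '%keyword%'.
    hits = [name for name in node_names if kw in name.lower()]
    if hits:
        return hits
    # Fallback: closest names by difflib's similarity ratio.
    lowered = {name.lower(): name for name in node_names}
    close = difflib.get_close_matches(kw, list(lowered), n=5, cutoff=cutoff)
    return [lowered[c] for c in close]
```

For example, searching "application" returns every node whose name contains that word, and a typo such as "RFDI" still finds "RFID" through the fallback.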
Fig. 6. Search Engine interface
(3) Knowledge File Display Function: Display the knowledge files in the knowledge system. After users open the node details, they can learn and browse knowledge files. This page displays the type and learning difficulty of the knowledge, and supports full browsing and downloading of knowledge files (see Fig. 7).
Fig. 7. Knowledge browsing interface
5 Conclusion

This article presents the design of a comprehensive knowledge system with independent intellectual property rights that features "content expression, flexible expansion, visual navigation, and associative learning," enabling ordinary users to learn and apply professional knowledge. The goal of the system is to give students and others learning computer-related knowledge a simple, easy-to-understand tool that removes some of the obstacles in their learning, and to help beginners better understand IoT-related knowledge through diversified methods. The current implementation mainly supports knowledge expansion within a campus and still has many shortcomings; future work will develop and improve it further to serve more users.
Reinforcement Learning
Robust Anti-forensics on Audio Forensics System

Qingqing Wang and Dengpan Ye (corresponding author)

Key Laboratory of Aerospace Information Security and Trusted Computing, Ministry of Education, School of Cyber Science and Engineering, Wuhan University, Wuhan, China
Abstract. Audio forensics systems provide effective evidence of whether an audio recording is genuine. Recent studies have proposed anti-forensics techniques that can deceive audio forensics systems; however, these techniques lose their effectiveness once the audio is compressed. In this paper, we first study the influence of known and unknown compression on audio anti-forensics and then improve the robustness of audio anti-forensics techniques. We add compression algorithms and a compression approximation algorithm to the iterative process of generating anti-forensics adversarial examples, producing adversarial examples that resist compression and thus solving the problem that audio anti-forensics adversarial examples fail to reduce the accuracy of audio forensics systems after compression. Three compression algorithms (AAC-HE, AAC-LC, and MP3) and a pre-trained compression approximation algorithm were each added to the iterative generation process; the resulting adversarial examples were uploaded to and downloaded from a real network platform (Himalaya) and then fed into the audio forensics system. Experimental results show that the proposed method effectively resists compression.

Keywords: audio anti-forensics · audio compression · robustness · adversarial example
1 Introduction

Audio forensics [1] is an important branch of multimedia security that can be used to assess the authenticity of digital audio. However, in the presence of malicious attacks, the accuracy of audio forensics methods decreases. In real situations, a skilled forger often hides edits during or after a forgery operation, thereby defeating audio forensics systems. This kind of operation is known as anti-forensics, which is an important branch of audio forensics security analysis [14].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 589–599, 2023. https://doi.org/10.1007/978-981-99-4761-4_50
Q. Wang and D. Ye
1.1 Anti-forensics Technology

Anti-forensics has been studied for images, video, and audio. Kirchner et al. [8, 9] proposed a digital image anti-forensics attack algorithm based on the periodic variation of residual signals. Fontani et al. [4] proposed an anti-forensics attack algorithm based on the statistical characteristics of local pixel relationships. Sayanna et al. [16] applied anti-forensics technology to information hiding in digital audio. Mascia et al. [15] discussed the game between audio forensics and anti-forensics. At present, there are few anti-forensics studies in the audio field. Liu et al. [14] proposed an anti-forensics framework based on a generative adversarial network to address the shortcomings of stereo forgery detectors. Tao et al. [18] proposed a simple method to identify the MP3 compression history of audio samples. Li et al. [11] proposed an anti-forensics framework based on a generative adversarial network (GAN) that forges audio source information by adding specific perturbations. Dong et al. [3] applied the momentum method to BIM [10] and obtained substantial gains when generating adversarial examples.

Adversarial examples can mislead deep neural networks (DNNs): they are almost imperceptible to human perception, yet they cause a neural network to produce wrong outputs [5]. Adversarial examples can be generated for various modalities; for instance, Carlini et al. [2] proposed a method that generates minimal adversarial perturbations to attack speech recognition systems. In automatic speaker verification (ASV), Li et al. [12] successfully attacked the i-vector speaker verification model with adversarial examples. For spoofing countermeasures, Liu et al. [13] also showed that anti-spoofing models are vulnerable to adversarial attacks.

1.2 Anti-forensics Technology for Compression

Shin et al. [17] generate adversarial images that survive JPEG compression by including a differentiable approximation of JPEG in the target model. Wang et al. [20] proposed ComReAdv, a generic framework for generating compression-resistant adversarial examples, which enables privacy-preserving photo sharing on social networks. Jiang et al. [6] proposed a black-box adversarial attack method that combines the natural evolution strategy (NES) [21] and the basic iterative method (BIM) [10]. The anti-forensics attack of [6] achieves a 99% success rate; that is, it reduces the accuracy of the audio forensics system to 1%, rendering the forensics model almost useless. After compression, however, the effectiveness of the method in [6] declines sharply. Existing adversarial audio does not account for the compression applied by networks, yet compression degrades both the quality of adversarial audio and the success rate of adversarial examples, and therefore the robustness of audio anti-forensics. This poses the central challenge: after compression, anti-forensics adversarial examples lose the ability to mislead the audio forensics system, so an algorithm is needed that overcomes this problem and makes anti-forensics adversarial examples robust.
To handle unknown compression in networks, and inspired by [20] and [6], we propose an anti-forensics method that incorporates a compression approximation model into the gradient estimation of a black-box adversarial attack (ComNESBIM).
Fig. 1. Comparison between the attack algorithm of [6] and the proposed ComNESBIM in a network scenario. The general routine for attacking DNN-based models behind a network is to (1) generate adversarial audios, (2) upload them to the network, where (3) the network platform compresses them, then (4) download the audios, and finally (5) classify them.
As shown in Fig. 1, the adversarial audio generated by the proposed ComNESBIM can still mislead the audio forensics system after being downloaded from the network. Because of compression, the attack of [6] fails to fool the classifier; by contrast, ComNESBIM resists this negative effect and misleads the classifier into predicting a false label. Our main contributions are summarized as follows:

1. To our knowledge, we are the first to investigate the negative effects of unknown network compression on audio anti-forensics adversarial examples. On this basis, we propose a new robust anti-forensics framework that generates compression-resistant audio anti-forensics adversarial examples.
2. We design an audio compression approximation model (ComAppModel) with an encoder-decoder structure that approximates the unknown audio compression method and thereby provides a differentiable form of it. This solves the problem that compressed adversarial examples cannot successfully attack the audio forensics system.
3. We evaluate the method on a typical audio network platform (Himalaya). Compared with the method in [6], the proposed method reduces the accuracy of the audio forensics system from 92% to 39%.
Fig. 2. ComNESBIM: Adversarial examples generation process based on natural evolution strategy with ComAppModel.
2 Proposed Method

2.1 Overview of ComNESBIM

The key to producing adversarial audios that resist an unknown compression method is to reconstruct the unknown compression process. We therefore treat the ComAppModel and the audio forensics algorithm as a whole and add the ComAppModel to the iterative process of generating audio adversarial examples. We first explore the effect of compression on audio adversarial examples in order to enhance the robustness of audio anti-forensics. Figure 2 shows the attack flow. Put simply, building on [6], the ComAppModel is placed in front of the audio forensics algorithm during gradient estimation and iteration. Specifically, attacking the audio forensics model is an optimization problem that minimizes an objective function, so the original NES algorithm is modified accordingly, and the BIM method is used to perform gradient descent, as shown in Eqs. (1) and (2):

$$\dot{x}_0 = x \qquad (1)$$

$$\dot{x}_i = \mathrm{Clip}_{x,\varepsilon}\left\{ \dot{x}_{i-1} - \gamma \, \mathrm{sign}\left( \nabla_x f(x) + \delta \right) \right\} \qquad (2)$$

Here, Clip_{x,ε} keeps each iterate within an ε-neighborhood of the original audio x, and γ is the step-size hyperparameter; γ is tuned by an adaptive adjustment algorithm that automatically reduces it when gradient descent stalls. The gradient ∇_x f(x) is estimated with the natural evolution strategy. The pipeline of the proposed ComNESBIM is shown in Fig. 3.
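The NES-plus-BIM loop of Eqs. (1) and (2) can be sketched with NumPy as follows. This is an illustrative reconstruction under stated assumptions, not the authors' code: the antithetic Gaussian sampling, σ, the number of samples, and the step-halving rule used for the adaptive γ are assumptions; the gradient is evaluated at the current iterate, as in standard BIM; and the δ term of Eq. (2) is omitted.

```python
import numpy as np

def nes_gradient(f, x, sigma=1e-3, n_samples=50, rng=None):
    """Antithetic NES estimate of the gradient of a black-box scalar loss f at x."""
    rng = np.random.default_rng() if rng is None else rng
    grad = np.zeros_like(x, dtype=float)
    for _ in range(n_samples):
        u = rng.standard_normal(x.shape)
        # Only loss values are queried, so f may contain non-differentiable steps.
        grad += (f(x + sigma * u) - f(x - sigma * u)) * u
    return grad / (2.0 * sigma * n_samples)

def nes_bim_attack(f, x, eps, gamma, n_iter, decay=0.5, **nes_kwargs):
    """Signed descent on f with NES gradients, kept within an eps-ball around x."""
    x_adv = np.array(x, dtype=float)                # Eq. (1): start from the clean audio
    loss = f(x_adv)
    for _ in range(n_iter):
        g = nes_gradient(f, x_adv, **nes_kwargs)
        step = x_adv - gamma * np.sign(g)           # BIM update of Eq. (2), without delta
        candidate = np.clip(step, x - eps, x + eps) # Clip_{x,eps}
        cand_loss = f(candidate)
        if cand_loss < loss:                        # accept steps that lower the loss
            x_adv, loss = candidate, cand_loss
        else:                                       # descent stalled: shrink gamma
            gamma *= decay
    return x_adv
```

Here `f` is any callable returning a scalar loss. In the setting of Fig. 2 it would be something like `lambda z: forensics_loss(com_app_model(z))`, where both names are hypothetical; because only loss values are queried, the compression step inside `f` does not need to be differentiable.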
Fig. 3. The pipeline of the proposed ComNESBIM. Block 1 collects a training dataset for compression approximation by querying a network whose compression method is unknown; Block 2 trains the ComAppModel on the collected training set; Block 3 generates adversarial audios with ComNESBIM.
Block 1 collects a training dataset for compression approximation by querying the network whose compression method is unknown. The network's compression method is a black box, but compressed audios can be obtained by querying it: an original audio is uploaded to the network and then downloaded to obtain the corresponding compressed audio. In this way, a large number of original-compressed audio pairs are collected to form the training set for the ComAppModel. Block 2 trains the ComAppModel on the collected training set. Using the original-compressed audio pairs for supervised learning, we design an encoding-decoding based compression approximation model, called ComAppModel, which effectively learns the transformation from original audios to compressed audios. Block 3 generates adversarial audios with ComNESBIM, whose goal is to produce compression-resistant adversarial audios. Whereas most existing attack algorithms consider only the DNN-based model, ComNESBIM includes the ComAppModel in the optimization process: the adversarial audios are first passed through the ComAppModel and then fed to the DNN. Thus, the effect of compression is taken into account while the adversarial audios are optimized.

2.2 Compression Approximation Model

The ComAppModel is the key component of ComNESBIM because it approximates the unknown compression and provides a differentiable model. Intuitively, the performance of the ComAppModel largely determines the robustness of the generated adversarial audios. The main technical difficulty is therefore training the ComAppModel well, which takes a long time and requires experiments to monitor its training progress. Our aim is to transform original audios into compressed audios. The ComAppModel consists of an encoder and a decoder.
The encoder-decoder structure forms two symmetric pyramids, which can effectively extract and reconstruct multi-level features.
The encoder performs audio feature extraction and the decoder performs audio reconstruction. The encoder-decoder structure can therefore modify the input audio from coarse to fine through multi-level features and produce the target output. Following this structure, we design the ComAppModel as illustrated in Fig. 4, where the encoder and decoder are implemented with convolution and deconvolution, respectively. The input is the original audio, and the output approximates the compressed audio; we call it the reconstructed audio because it is not the real compressed audio. The structures of the encoder and decoder are similar to those in [20].
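The exact layers of the ComAppModel follow [20] and are not given here, so the NumPy sketch below illustrates only the symmetric pyramid idea: strided convolutions halve the length while adding channels, and transposed convolutions mirror them back to the input length. The kernel size, stride, channel counts, activations, and untrained random weights are all hypothetical. Training would fit the weights on the original-compressed pairs from Block 1; the mean-squared-error loss included here is an assumed choice.

```python
import numpy as np

def conv_down(x, w):
    """Strided 1-D convolution, kernel 2, stride 2: (C_in, L) -> (C_out, L // 2)."""
    c_in, length = x.shape
    frames = x.reshape(c_in, length // 2, 2)          # non-overlapping windows
    return np.einsum("oik,itk->ot", w, frames)        # w has shape (C_out, C_in, 2)

def deconv_up(x, w):
    """Transposed 1-D convolution, kernel 2, stride 2: (C_in, T) -> (C_out, 2T)."""
    out = np.einsum("iok,it->otk", w, x)              # w has shape (C_in, C_out, 2)
    return out.reshape(w.shape[1], -1)                # interleave k to get length 2T

def relu(z):
    return np.maximum(z, 0.0)

class ComAppSketch:
    """Symmetric encoder-decoder with hypothetical channels 1 -> 8 -> 16 -> 8 -> 1."""

    def __init__(self, channels=(1, 8, 16), seed=0):
        rng = np.random.default_rng(seed)
        pairs = list(zip(channels[:-1], channels[1:]))
        self.enc = [rng.normal(0.0, 0.5, (co, ci, 2)) for ci, co in pairs]
        self.dec = [rng.normal(0.0, 0.5, (co, ci, 2)) for ci, co in reversed(pairs)]

    def __call__(self, audio):
        """audio: 1-D array whose length is divisible by 2 ** len(self.enc)."""
        h = audio[np.newaxis, :]                      # (1, L)
        for w in self.enc:
            h = relu(conv_down(h, w))                 # shorter, more channels
        for i, w in enumerate(self.dec):
            h = deconv_up(h, w)                       # longer, fewer channels
            if i < len(self.dec) - 1:
                h = relu(h)
        return np.tanh(h[0])                          # reconstructed audio in [-1, 1]

def reconstruction_loss(pred, target):
    """Assumed training objective: mean squared error against real compressed audio."""
    return float(np.mean((pred - target) ** 2))
```

The point of the sketch is the shape bookkeeping: because every encoder layer has a mirrored decoder layer, the reconstructed audio has the same length as the input and can be fed straight into the forensics model.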
Fig. 4. The flow of ComAppModel.
3 Experiments

3.1 Dataset and Victim Model

We use the ASVspoof 2019 dataset [19], which takes the VCTK corpus as real (bona fide) speech and generates spoofed speech from it. It is divided into three parts: training (trn), development (dev), and evaluation (eval). Our audio forensics model, and thus the victim model, is the CNN-GRU model [7]. After training, its detection accuracy reaches 99.9%.

3.2 Audio Anti-compression Anti-forensics Technology

The experiment is carried out in two steps. First, adversarial examples with anti-compression ability are generated locally against the audio forensics system by incorporating the MP3, AAC-LC, or AAC-HE compression algorithm, or the ComAppModel, into the generation process, and they are tested locally. Second, these adversarial examples are uploaded to the network, downloaded, and tested.
The local tests focus on the MP3, AAC-LC, and AAC-HE compression algorithms. Before evaluating the performance of adversarial examples after compression, the accuracy of the audio forensics system must be re-evaluated; it is 99%. Before generating adversarial examples with anti-compression ability, we test the anti-compression ability of the adversarial examples generated by [6]. We then incorporate a specific compression algorithm, or the ComAppModel, into gradient estimation. Because the compression process is difficult to differentiate directly, NES-based gradient estimation is used: in the gradient estimation stage, each example is compressed before being input into the audio forensics system, so that the estimated gradient accounts for both the compression process and the fake audio detection process.
4 Results and Analysis

4.1 Audio Anti-compression Adversarial Examples Local Test

Before generating audio adversarial examples with anti-compression ability, we test the anti-compression ability of the audio adversarial examples generated by [6]. In Fig. 5, the first point of each of the grey, orange, and blue series is the accuracy of the audio forensics system when adversarial examples generated without any compression algorithm are compressed with the corresponding algorithm and then input. The grey series shows the accuracy after AAC-HE compression of adversarial examples generated without the AAC-HE compression algorithm, with AAC-HE and noise threshold ε = 0.001, and with AAC-HE and noise threshold ε = 0.005, respectively. The orange and blue series show the same for AAC-LC and MP3, respectively. The blue and orange series in Fig. 5 show that, for MP3 and AAC-LC, increasing the noise threshold slightly decreases the accuracy of the audio forensics system, whereas for AAC-HE (grey) the accuracy increases with the noise threshold. Hence, a higher noise threshold does not necessarily lead to a lower accuracy of the audio forensics system. Fig. 6 shows more directly that when adversarial examples are generated without any compression algorithm and then compressed with MP3 or AAC-HE before being input into the audio forensics system, its accuracy remains very high. MP3 and AAC-HE therefore have a large impact on the adversarial examples, causing them to lose almost all of their attack ability, whereas AAC-LC does not affect them as strongly.
After the compression algorithm is incorporated into the gradient estimation, the accuracy of the audio forensics system drops sharply when testing uses the same compression algorithm as generation. However, Fig. 6 clearly shows that when the compression algorithms used for generation and testing do not match, the audio adversarial examples still fail to reduce the accuracy of the audio forensics system. In other words, adversarial examples generated by directly inserting a specific compression algorithm into the generation process do not transfer across compression algorithms and are therefore not robust. In contrast, adversarial examples generated by including our pre-trained ComAppModel in the generation process remain robust after MP3, AAC-LC, and AAC-HE compression and greatly reduce the accuracy of the audio forensics system. This shows that the method effectively counteracts the suppression of the adversarial noise by compression and improves the ability of audio adversarial examples to resist compression.

Fig. 5. Accuracy of the audio forensics system under AAC-HE, AAC-LC, and MP3 (lower is better). 1: training without audio compression; 2: training with audio compression; 3: training with audio compression and noise enhancement.

Fig. 6. Accuracy of the audio forensics system under attacks tested with AAC-HE, AAC-LC, and MP3 (lower is better). Yellow: training without audio compression; light blue: training with MP3; light green: with AAC-LC; dark green: with AAC-HE; dark blue: with ComAppModel.
4.2 Audio Anti-compression Adversarial Examples Network Test

Finally, we test the robustness of the generated anti-forensics adversarial examples in a real network environment: the audio is uploaded to the network, downloaded, and then input into the audio forensics system. The experiment is carried out on the Himalaya platform. We first investigated the audio coding used by Himalaya: it uses AAC coding with a sampling rate of 44100 Hz and a bit rate of 192 kb/s. We set the noise threshold ε = 0.005 and the attack power κ = 5 to generate adversarial examples. Table 1 shows the accuracy of the audio forensics system on adversarial examples generated with MP3, AAC-LC, AAC-HE, and the ComAppModel after they are uploaded to and downloaded from Himalaya.

Table 1. Accuracy of the audio forensics system.

| Compression method | Accuracy of the audio forensics system |
|---|---|
| None | 92% |
| AAC-HE | 92% |
| AAC-LC | 57% |
| MP3 | 56% |
| ComAppModel | 39% |
As Table 1 shows, when adversarial examples are generated only with the method of [6] and then uploaded to and downloaded from Himalayan, the accuracy of the audio forensics system remains 92%. The adversarial examples generated with AAC-HE do not reduce the accuracy of the audio forensics system, whereas those generated with AAC-LC and MP3 reduce it to 57% and 56%, respectively. The adversarial examples generated with ComNESBIM reduce the accuracy even further, to 39%. Because ComAppModel is trained on raw and compressed audio, it learns to approximate the compression algorithm, so the adversarial examples obtained by ComNESBIM retain some robustness against unknown compression. Compared with adversarial examples obtained with a known compression algorithm, those obtained by ComNESBIM are therefore more robust against the audio forensics system.
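The mechanism described above is differentiating through a learned stand-in for the codec during the attack. A deliberately tiny sketch can illustrate it. Everything in the sketch is hypothetical and is not the authors' implementation:

* `lowpass_matrix` is a fixed moving-average filter that plays the role of a trained ComAppModel.
* `detector_logit` is a linear scorer that stands in for the audio forensics model.
* The attack is a plain white-box BIM (sign-gradient) loop using the budget ε = 0.005 from Sect. 4.2.

Unlike a real attack on a deployed detector, the sketch assumes that exact detector gradients are available.

```python
import numpy as np

def lowpass_matrix(n, k=5):
    """Hypothetical stand-in for a trained compression approximator: an n x n
    moving-average low-pass filter. It is linear, so its Jacobian is the matrix itself."""
    half = k // 2
    A = np.zeros((n, n))
    for i in range(n):
        lo, hi = max(0, i - half), min(n, i + half + 1)
        A[i, lo:hi] = 1.0 / k
    return A

def detector_logit(x, A, w, b):
    """Toy linear detector applied to the approximately compressed signal A @ x.
    A larger logit means the detector is more confident the audio was manipulated."""
    return float(w @ (A @ x) + b)

def bim_through_compression(x, A, w, eps=0.005, alpha=0.0005, steps=20):
    """Basic iterative method that lowers the detector logit. The gradient is taken
    through the compression surrogate, and the perturbation stays in an L-infinity ball."""
    grad = A.T @ w                                   # d logit / dx (constant for this linear toy)
    x_adv = x.copy()
    for _ in range(steps):
        x_adv = x_adv - alpha * np.sign(grad)        # signed gradient step downhill
        x_adv = np.clip(x_adv, x - eps, x + eps)     # project back into the eps-ball around x
        x_adv = np.clip(x_adv, -1.0, 1.0)            # keep samples in a valid waveform range
    return x_adv

rng = np.random.default_rng(0)
n = 256
x = 0.1 * rng.standard_normal(n)                     # hypothetical audio frame
A = lowpass_matrix(n)
w, b = rng.standard_normal(n), 0.0
x_adv = bim_through_compression(x, A, w)
print(detector_logit(x, A, w, b), detector_logit(x_adv, A, w, b))
```

The loss is evaluated on `A @ x` rather than on `x`, so the perturbation is chosen for its effect after the (approximate) compression. Because the toy detector is linear, its gradient never changes, and BIM here reduces to repeated signed steps. With a nonlinear detector and a learned surrogate, the gradient would be recomputed at every iteration.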
5 Conclusions
In this paper, we make the first attempt to investigate the negative effects of unknown compression in online networks on audio anti-forensics adversarial examples. We propose a compression-resistant adversarial example generation framework named ComNESBIM, which makes audio adversarial examples robust against both known and unknown audio compression. Specifically, we propose ComAppModel, an encoding-decoding
Q. Wang and D. Ye
based compression model that can approximate unknown audio compression methods. With the trained ComAppModel as a differentiable approximation, we adapted existing gradient-based attack algorithms to efficiently generate compression-resistant adversarial examples. Extensive experiments validate that ComNESBIM achieves a high success rate in resisting MP3, AAC-LC and AAC-HE compression. Furthermore, the evaluation results on real social networks show that ComNESBIM can effectively approximate the unknown compression method and generate compression-resistant adversarial audio.
Acknowledgements. This work was supported by the National Natural Science Foundation of China (NSFC) (62072343), the National Key Research and Development Program of China (2019QY(Y)0206), and the Fundamental Research Funds for the Central Universities (2042023kf0228).
References
1. Bao, Y., Liang, R., Cong, Y.: Research progress on key technologies of audio forensics. J. Data Acquis. Process. 31(2), 252–259 (2016)
2. Carlini, N., Wagner, D.: Audio adversarial examples: targeted attacks on speech-to-text. In: 2018 IEEE security and privacy workshops (SPW), pp. 1–7. IEEE (2018)
3. Dong, Y., et al.: Boosting adversarial attacks with momentum. In: Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 9185–9193 (2018)
4. Fontani, M., Barni, M.: Hiding traces of median filtering in digital images. In: 2012 Proceedings of the 20th European Signal Processing Conference (EUSIPCO), pp. 1239–1243. IEEE (2012)
5. Goodfellow, I.J., Shlens, J., Szegedy, C.: Explaining and harnessing adversarial examples. arXiv preprint arXiv:1412.6572 (2014)
6. Jiang, Y., Ye, D.: Black-box adversarial attacks against audio forensics models. Secur. Commun. Netw. 2022, 1–8 (2022)
7. Jung, J.w., Shim, H.j., Heo, H.S., Yu, H.J.: Replay attack detection with complementary high-resolution information using end-to-end DNN for the ASVspoof 2019 challenge. arXiv preprint arXiv:1904.10134 (2019)
8. Kirchner, M., Böhme, R.: Hiding traces of resampling in digital images. IEEE Trans. Inf. Forensics Secur. 3(4), 582–592 (2008)
9. Kirchner, M., Fridrich, J.: On detection of median filtering in digital images. In: Media forensics and security II, vol. 7541, pp. 371–382. SPIE (2010)
10. Kurakin, A., Goodfellow, I.J., Bengio, S.: Adversarial examples in the physical world. In: Artificial Intelligence Safety and Security, pp. 99–112. Chapman and Hall/CRC (2018)
11. Li, X., Yan, D., Dong, L., Wang, R.: Anti-forensics of audio source identification using generative adversarial network. IEEE Access 7, 184332–184339 (2019)
12. Li, X., Zhong, J., Wu, X., Yu, J., Liu, X., Meng, H.: Adversarial attacks on GMM i-vector based speaker verification systems. In: ICASSP 2020–2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 6579–6583. IEEE (2020)
13. Liu, S., Wu, H., Lee, H.y., Meng, H.: Adversarial attacks on spoofing countermeasures of automatic speaker verification. In: 2019 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU), pp. 312–319. IEEE (2019)
14. Liu, T., Yan, D., Yan, N., Chen, G.: Anti-forensics of fake stereo audio using generative adversarial network. Multimed. Tools Appl. 81(12), 17155–17167 (2022)
15. Mascia, M., Canclini, A., Antonacci, F., Tagliasacchi, M., Sarti, A., Tubaro, S.: Forensic and anti-forensic analysis of indoor/outdoor classifiers based on acoustic clues. In: 2015 23rd European Signal Processing Conference (EUSIPCO), pp. 2072–2076. IEEE (2015)
16. Patole, B., Shinde, A., Bhatt, M., Shimpi, P.: Data hiding in audio-video using anti forensics technique for secret message and data in mp4 container (2018)
17. Shin, R., Song, D.: JPEG-resistant adversarial images. In: NIPS 2017 Workshop on Machine Learning and Computer Security, vol. 1, p. 8 (2017)
18. Tao, B., Wang, R., Yan, D., Jin, C.: Anti-forensics of double compressed MP3 audio. Intl. J. Digital Crime Forensics 12(3), 45–57 (2020)
19. Todisco, M., et al.: ASVspoof 2019: future horizons in spoofed and fake audio detection. arXiv preprint arXiv:1904.05441 (2019)
20. Wang, Z., et al.: Towards compression-resistant privacy-preserving photo sharing on social networks. In: Proceedings of the Twenty-First International Symposium on Theory, Algorithmic Foundations, and Protocol Design for Mobile Networks and Mobile Computing, pp. 81–90 (2020)
21. Wierstra, D., Schaul, T., Glasmachers, T., Sun, Y., Peters, J., Schmidhuber, J.: Natural evolution strategies. J. Mach. Learn. Res. 15(1), 949–980 (2014)
Off-Policy Reinforcement Learning with Loss Function Weighted by Temporal Difference Error
Bumgeun Park, Taeyoung Kim, Woohyeon Moon, Sarvar Hussain Nengroo, and Dongsoo Har
Korea Advanced Institute of Science and Technology, Daejeon 34101, Republic of Korea
Abstract. Training agents with off-policy deep reinforcement learning algorithms requires a replay memory that stores past experiences. These experiences are sampled uniformly or non-uniformly to create the training batches. When calculating the loss function, off-policy algorithms commonly assume that all samples are of equal importance. We introduce a novel algorithm that assigns unequal importance to each experience in the training objective, in the form of a weighting factor based on the distribution of temporal difference (TD) errors. Experiments with uniform sampling in eight environments of the OpenAI Gym suite show the following results. In one environment, the proposed algorithm achieves a 10% increase in convergence speed with a similar success rate. In the other seven environments, it achieves a 3%–46% increase in success rate or a 3%–14% increase in cumulative reward, with similar convergence speed. The algorithm can be combined with an existing prioritization method that employs non-uniform sampling. The combined technique achieves a 20% increase in convergence speed compared with the prioritization method alone.
Keywords: Experience · Off-Policy · Reinforcement learning · Replay memory · Weighting factor · TD error
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 600–613, 2023. https://doi.org/10.1007/978-981-99-4761-4_51
1 Introduction
Reinforcement learning (RL) enables an agent to learn a task by taking actions in an environment so as to maximize future reward [1]. Using deep neural networks to model a policy or a value function allows RL algorithms to handle complicated tasks, such as playing a range of Atari games and mastering the game of Go [2–5]. Recent work has also applied deep RL beyond game environments to practical problems. Examples include controlling a robotic arm [6–8], planning paths for mobile robots [9, 10], predicting traffic accidents [11], and controlling robots in robot soccer games [12–14].
Training with deep RL algorithms is usually carried out in one of two paradigms: on-policy and off-policy. In both paradigms, the policy is learned from experiences obtained by taking actions in the environment. On-policy deep RL algorithms use current
Off-Policy Reinforcement Learning with Loss Function Weighted
experiences to train the policy without storing them. Off-policy deep RL algorithms reuse current and past experiences by storing them in a replay memory (RM). Updating the policy from the RM, called experience replay (ER), can increase sample efficiency because each experience can be (probabilistically) used for training several times. Off-policy deep RL can be more stable than on-policy deep RL because the policy is updated with data obtained by the previous policy that collected the experiences currently stored in the RM. On-policy deep RL, in contrast, can be computationally expensive because it requires a large number of interactions with the environment [15].
A central challenge in off-policy deep RL is how to sample the experiences stored in the RM to create the training batches. In the vanilla implementation of off-policy algorithms, the batch is created by uniform sampling of the experiences in the RM, which assumes that all experiences are of equal importance. However, for a given task, particular types of experiences can be more important. To mitigate this issue and improve sampling efficiency, more sophisticated approaches have recently appeared in the literature. Combined experience replay is a sampling strategy that prioritizes the experiences most recently added to the RM [16]. A replay of recent experiences without explicit prioritization is proposed in [17]. Prioritized experience replay (PER) [18] samples experiences non-uniformly according to a probability distribution that depends on the temporal-difference (TD) errors obtained during training. PER uses the TD errors to sample experiences. Our work, by contrast, uses the distribution of TD errors in a non-probabilistic manner to assign the weighting factors of the loss function. Prioritization methods such as PER assume that each experience has unequal importance, so non-uniform sampling is performed when creating the batches.
Nonetheless, the sampled experiences are treated as equally important when the batch loss for training is calculated. We hypothesize that training can be enhanced by assigning unequal importance to each experience based on the distribution of TD errors. From this viewpoint, we introduce a novel method named distribution-based importance weighting (DBIW). DBIW assigns each experience a weighting factor that represents its importance when the loss function is calculated during training. DBIW is computationally efficient because it calculates the importance of experiences only in the batch, not in the whole RM. It can be combined with prioritization methods based on non-uniform sampling because it is applied after the training batches are generated. This paper shows that DBIW can improve convergence speed and other performance metrics when used with various off-policy deep RL methods in various environments of the OpenAI Gym suite.
2 Background
2.1 TD-Based Off-Policy Deep Reinforcement Learning
TD-based off-policy deep RL has been widely used in a variety of autonomous tasks [19]. It uses two policies: a behavior policy that explores the environment sufficiently, and a target policy that is learned toward the optimal policy [1]. The RM can be used in various ways, for instance, deciding which experiences to store, which experiences to replay, and which experiences to forget
B. Park et al.
[17]. In this paper, several off-policy deep RL algorithms that use the RM serve as baselines: deep Q-networks (DQN), deep deterministic policy gradient (DDPG), and soft actor-critic (SAC) [2, 20, 21].
2.2 Prioritization Methods
Prioritization is a technique used in off-policy deep RL for sampling from the RM [22]. Prioritization methods allow more meaningful experiences to be sampled with a higher probability. PER samples experiences according to a probability distribution that depends on the TD error [18]. An entropy-based prioritization is proposed in [23]; it takes into account specific properties of the sensed state space, measured by Shannon's entropy. Maximum entropy-based prioritization, which encourages the agent to pursue more diverse goals in multi-goal RL, is proposed in [24]. Batch prioritization based on the intended goal and on information about the valid goal space is proposed in [7].
2.3 Weighting Factor
Weighting factors applied to a loss function or an error to improve performance are proposed in [25–29]. In RL, the likelihood-free importance weighting method re-weights experiences with weighting factors based on their likelihood under the stationary distribution of the current policy [30].
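PER, cited above, samples experience i with probability P(i) = p_i^α / Σ_k p_k^α. In the proportional variant, p_i = |δ_i| + ε [18]. DBIW only acts on a batch that has already been drawn, so the two mechanisms can be stacked, as in the compatibility experiment of Sect. 4.3.

The following NumPy sketch is illustrative only and rests on these assumptions:

* The stored TD errors are a flat synthetic array. Real PER implementations use a sum-tree and refresh priorities after each update.
* α = 0.6.
* L_W = (1/N) Σ_j ω_j δ_j^2.
* `dbiw_weights` is a hypothetical helper following Eqs. (8)–(10) of Sect. 3.

PER's own importance-sampling correction is omitted for brevity.

```python
import numpy as np

def per_probabilities(td_errors, alpha=0.6, eps=1e-6):
    """Proportional PER: P(i) = p_i**alpha / sum_k p_k**alpha, with p_i = |delta_i| + eps."""
    p = (np.abs(td_errors) + eps) ** alpha
    return p / p.sum()

def dbiw_weights(batch_td):
    """Hypothetical helper following Eqs. (8)-(10): median-centred normalization of |delta|,
    zero-mean unit-variance Gaussian, then softmax over the batch."""
    mag = np.abs(batch_td)
    z = (mag - np.median(mag)) / (mag.std() + 1e-8)   # Eq. (8); 1e-8 guards against sigma = 0
    p = np.exp(-0.5 * z ** 2) / np.sqrt(2.0 * np.pi)  # Eq. (9)
    e = np.exp(p)
    return e / e.sum()                                # Eq. (10)

rng = np.random.default_rng(0)
replay_td = rng.standard_normal(10_000)   # synthetic TD errors, one per stored experience
N = 256

# Stage 1 (PER): non-uniform sampling decides which experiences enter the batch.
idx = rng.choice(replay_td.size, size=N, replace=True, p=per_probabilities(replay_td))

# In practice the batch TD errors would be recomputed with the current networks here.
batch_td = replay_td[idx]

# Stage 2 (DBIW): the weights decide how much each sampled experience counts in the loss.
w = dbiw_weights(batch_td)
weighted_loss = np.mean(w * batch_td ** 2)   # L_W = (1/N) * sum_j w_j * delta_j**2
```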
Fig. 1. Overall structure of the proposed method. The TD errors of the batch are element-wise multiplied by the weighting factors.
Fig. 2. Loss function of two TD errors. Left: the MSE loss L_MSE; right: the weighted loss L_W. The weighting factors multiplied by the TD errors distort the loss function and consequently change the direction of its gradient.
Fig. 3. Weighting factor according to the magnitude of the TD error.
3 Proposed Method
This section describes the proposed DBIW method in detail. DBIW assigns unequal importance to each experience based on the Distribution of the magnitude of the TD errors matched with the experiences in the Batch (DTB). With a sufficiently large batch, the DTB can follow the Distribution of the magnitude of the TD errors matched with the experiences in the RM (DTR). For the training objective, importance is assigned to each experience by element-wise multiplying the weighting factor with the TD error, as illustrated in Fig. 1. The loss function L is then given by

$$L_{MSE} = \frac{1}{N}\sum_{j=1}^{N} \delta_j^2, \qquad L_W = \frac{1}{N}\sum_{j=1}^{N} \omega_j \delta_j^2 \tag{1}$$

where L_MSE is the typical mean squared error loss, L_W is the weighted loss, N is the batch size, δ_j is the j-th TD error and ω_j is the j-th weighting factor. The weighting factors distort the shape of the loss function, which is a function of the TD errors, and thereby change the direction of its gradient, as shown in Fig. 2. This change of gradient direction can be described mathematically as follows:

$$\nabla_\theta L = \begin{bmatrix} \frac{\partial L}{\partial \theta_1} \\ \frac{\partial L}{\partial \theta_2} \\ \vdots \\ \frac{\partial L}{\partial \theta_K} \end{bmatrix} = \begin{bmatrix} \frac{\partial L}{\partial \delta_1}\frac{\partial \delta_1}{\partial \theta_1} + \cdots + \frac{\partial L}{\partial \delta_N}\frac{\partial \delta_N}{\partial \theta_1} \\ \frac{\partial L}{\partial \delta_1}\frac{\partial \delta_1}{\partial \theta_2} + \cdots + \frac{\partial L}{\partial \delta_N}\frac{\partial \delta_N}{\partial \theta_2} \\ \vdots \\ \frac{\partial L}{\partial \delta_1}\frac{\partial \delta_1}{\partial \theta_K} + \cdots + \frac{\partial L}{\partial \delta_N}\frac{\partial \delta_N}{\partial \theta_K} \end{bmatrix} \tag{2}$$

$$\nabla_\theta L = \frac{\partial L}{\partial \delta_1}\nabla_\theta \delta_1 + \cdots + \frac{\partial L}{\partial \delta_N}\nabla_\theta \delta_N \tag{3}$$

$$\frac{\partial L}{\partial \delta_j} = \begin{cases} \dfrac{2\delta_j}{N}, & L = L_{MSE} \\[4pt] \dfrac{2\omega_j\delta_j}{N}, & L = L_W \end{cases} \tag{4}$$

where θ, ∇_θ, K, N and θ_k represent the weights of the policy network, the gradient operator with respect to θ, the dimension of the weights of the policy network, the batch size and the k-th element of the weights of the policy network, respectively. Then,
the gradient descent update with the weighting factors is given by

$$\theta \leftarrow \theta - \alpha \nabla_\theta L(\boldsymbol{\omega} \circ \boldsymbol{\delta}; \theta) \tag{5}$$
where α, ω ∈ R^N, δ ∈ R^N and ∘ are the learning rate, the vector of weighting factors, the vector of TD errors of the batch and the element-wise multiplication operator, respectively. Equation (2) gives the gradient of the loss function, which can be rewritten as (3). Equation (4) is the partial derivative of the loss function with respect to the TD error δ_j. In (3) it multiplies the vector ∇_θ δ_j, so it scales that term's contribution to the gradient. The weighting factor ω_j is given by

$$\omega_j = \omega\left(\delta_j \mid \delta_1, \delta_2, \ldots, \delta_N\right) = \omega\left(\delta_j \mid \boldsymbol{\delta}\right) \tag{6}$$

As (6) indicates, the weighting factor ω_j corresponds to the TD error δ_j and is determined by the distribution of the TD errors δ_j, j = 1, ..., N. Figure 3 shows the weighting factor as a function of the magnitude of the TD error. The COR is the median magnitude of the TD errors in the batch. The weighting factor is assigned according to the criteria
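$$\begin{cases} \omega_j(\delta_j \mid \boldsymbol{\delta}) > \omega_i(\delta_i \mid \boldsymbol{\delta}), & |\delta_i| > |\delta_j| > COR \\ \omega_j(\delta_j \mid \boldsymbol{\delta}) < \omega_i(\delta_i \mid \boldsymbol{\delta}), & |\delta_j| < |\delta_i| < COR \end{cases} \tag{7}$$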
Fig. 4. Procedure for calculating the weighting factor
Equation (7) states that among TD errors whose magnitude is greater than the COR, the experience with the larger TD-error magnitude receives the smaller weighting factor. Conversely, among TD errors whose magnitude is lower than the COR, the experience with the smaller TD-error magnitude receives the smaller weighting factor. The highest weighting factors are therefore assigned to TD-error magnitudes around the COR, i.e., roughly at the center of the distribution. The rationale for using the median instead of the mean is given in Sect. 3.1. The pseudo-code for assigning the weighting factors to the TD errors is presented in Algorithm 1. The proposed DBIW consists of the three processes shown in Fig. 4, each described in a subsection below:
- Sect. 3.1: the TD errors are normalized by the median of the DTB.
- Sect. 3.2: a Gaussian function assigns the weighting factors according to (7).
- Sect. 3.3: the softmax operator prevents excessive differences between the weighting factors and normalizes them.
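A quick numerical check of Eqs. (1) and (4) follows. It takes L_W = (1/N) Σ_j ω_j δ_j^2 and treats the weights as fixed, non-trainable constants. The analytic derivative 2ω_jδ_j/N then matches a central finite difference, and each weight rescales its experience's term in Eq. (3). The TD errors and weights below are arbitrary illustrative values, not data from the paper.

```python
import numpy as np

def loss_mse(delta):
    return np.mean(delta ** 2)                       # L_MSE, Eq. (1)

def loss_weighted(delta, omega):
    return np.mean(omega * delta ** 2)               # L_W, Eq. (1)

def dloss_ddelta(delta, omega=None):
    """Eq. (4): 2*delta_j/N for L_MSE, or 2*omega_j*delta_j/N for L_W."""
    n = delta.size
    return 2.0 * delta / n if omega is None else 2.0 * omega * delta / n

delta = np.array([0.1, -0.4, 0.8, 2.5])              # illustrative TD errors
omega = np.array([0.24, 0.27, 0.27, 0.22])           # illustrative weights (sum to 1)

# Central finite difference of L_W, one TD error at a time.
h = 1e-6
fd = np.array([
    (loss_weighted(delta + h * e, omega) - loss_weighted(delta - h * e, omega)) / (2 * h)
    for e in np.eye(delta.size)
])
print(np.allclose(fd, dloss_ddelta(delta, omega)))   # True
```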
3.1 Normalization of TD Errors Magnitude
The range of TD errors differs depending on the task and its environment, so the TD errors are normalized before the weighting factors are assigned. The magnitude of each TD error is centered on the median of the DTB (the COR) and scaled by the standard deviation of the magnitudes:

$$\delta_{n,j} = \frac{|\delta_j| - COR}{\sigma(|\boldsymbol{\delta}|)} \tag{8}$$

where σ(|δ|) is the standard deviation of the magnitudes of the TD errors in the batch and δ_{n,j} is the j-th normalized TD error. The DTR is, in general, negatively skewed [31, 32]. With a sufficiently large batch, the DTB follows the DTR, so the DTB can also be negatively skewed, as shown in Fig. 5. For a negatively skewed DTB, the median represents the typical value of the DTB better than the mean.
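Eq. (8) can be written as a short NumPy function, followed by a toy comparison of the median and the mean on a negatively skewed sample. The Beta(5, 1) draw is only a hypothetical stand-in for a DTB; it is left-skewed, so its median lies above its mean. Taking σ as the population standard deviation (ddof = 0) is also an assumption, since the text does not specify it.

```python
import numpy as np

def normalize_td(delta):
    """Eq. (8): delta_n = (|delta| - COR) / sigma(|delta|), with COR = median of |delta|."""
    mag = np.abs(delta)
    cor = np.median(mag)
    sigma = mag.std()                       # population std (ddof=0), an assumption
    return (mag - cor) / (sigma + 1e-8)     # small constant guards against sigma = 0

rng = np.random.default_rng(0)
mag = rng.beta(5.0, 1.0, size=10_000)       # hypothetical negatively skewed |delta| sample
print(f"median={np.median(mag):.3f}  mean={mag.mean():.3f}")

z = normalize_td(mag)
print(np.median(z))                         # approximately 0: the median maps to zero
```

Because the normalization is an increasing affine map, the COR always lands at zero. That is where the Gaussian of Sect. 3.2 is largest.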
Fig. 5. Negatively skewed distribution of the magnitude of the TD error obtained during training in the FetchPush environment. Distributions of TD error after 30, 60 and 90 epochs show negative skewness.
3.2 Gaussian Function
To assign the weighting factors while reducing the impact of outliers, the widely used zero-mean, unit-variance Gaussian function is applied [33]:

$$p_j = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{1}{2}\delta_{n,j}^2\right) \tag{9}$$

where p_j is the variable representing the importance of the j-th TD error. The zero-mean, unit-variance Gaussian function assigns higher importance to normalized TD errors close to zero.
3.3 Softmax Operator
Excessive differences between the importance variables p result in overfitting to the experiences of high importance. To mitigate this problem, the importance variables are rescaled and normalized with the softmax operator [34], which is typically used for conversion into a probability distribution. As a result, the weighting factors also sum to 1. The softmax operator is given by

$$\omega_j = \frac{\exp(p_j)}{\sum_{i=1}^{N}\exp(p_i)} \tag{10}$$

where ω_j is the j-th weighting factor. Because p_j > 0, the softmax operator provides the lower limit exp(p_j) > 1. The softmax operator thus reduces the difference between the values of the importance variables; its effect is shown in the ablation study of Sect. 4.4.
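Putting Eqs. (8)–(10) together gives the whole weighting step of Fig. 4. The sketch below assumes a NumPy array of batch TD errors and uses an illustrative batch; it is not the authors' code. The accompanying checks cover three properties implied by the text:

* The weights sum to 1.
* The largest weight goes to the TD error whose magnitude equals the COR, and the ordering obeys criterion (7).
* Because 0 < p_j ≤ 1/√(2π), every exp(p_j) lies in (1, e^{1/√(2π)}] ≈ (1, 1.49]. No weight can therefore exceed about 1.49 times any other.

```python
import numpy as np

def dbiw_weights(delta):
    """Weighting factors for one batch of TD errors, following Eqs. (8)-(10)."""
    mag = np.abs(delta)
    z = (mag - np.median(mag)) / (mag.std() + 1e-8)     # Eq. (8), COR = batch median
    p = np.exp(-0.5 * z ** 2) / np.sqrt(2.0 * np.pi)    # Eq. (9): 0 < p_j <= 1/sqrt(2*pi)
    e = np.exp(p)                                       # Eq. (10) numerator, always > 1
    return e / e.sum()

# Illustrative batch; the median of |delta| is 0.9 (index 4).
delta = np.array([-3.0, -0.2, 0.05, 0.5, 0.9, 1.1, 4.0])
w = dbiw_weights(delta)
print(np.round(w, 4))
```

The bounded ratio is what the ablation in Sect. 4.4 attributes to the softmax step: without it, the raw Gaussian values p_j could differ by orders of magnitude for outlying TD errors.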
4 Experiments
The proposed method is evaluated in the OpenAI Gym environment [35]. It is combined with the TD-based off-policy deep RL algorithms DQN, DDPG and SAC. DQN, DDPG and SAC with uniform random sampling of experiences serve as baselines, and the same algorithms with the proposed DBIW are compared with them. Experiments are executed in two discrete control tasks (MountainCar and LunarLander) and six continuous control tasks (Reacher, FetchPush, FetchPickAndPlace, FetchSlide, HandManipulateEgg and HandManipulateBlock). All experiments in each environment are repeated with five random seeds, and the mean of the five resulting curves is presented as the result. The upper and lower boundaries of the light-colored area show the maximum and minimum values.
4.1 Discrete Control
For the discrete control tasks, DQN, a widely used RL algorithm for discrete control, is used. In each training run, the return is defined as the sum of the rewards during an episode.
Fig. 6. Learning curves for the suite of OpenAI Gym control tasks. Plots (a) and (b) show the learning curves for the discrete control tasks, and plots (c)–(h) show the learning curves for the continuous control tasks. In all plots, the blue curve represents the proposed method and the orange curve represents the baseline.
For the MountainCar environment, training is performed for 200 episodes. At the end of each episode, 20 evaluations are executed and the mean return over these 20 evaluations is recorded. In the original MountainCar environment, a reward of −1 is given at every step until the episode, which consists of 200 steps, ends. However, a reward-shaping technique is used to make the problem easier to learn [36]. Figure 6(a) shows that the proposed method increases the mean return at episode 200 from −193 to −185, representing a 3% increase compared with the baseline. To smooth the curves, a moving average of the returns over 20 episodes is taken.
4.2 Continuous Control
Experiments for the continuous control tasks are executed with the MuJoCo physics engine [37] and the robotics environments in OpenAI Gym. The environments considered in this study are Reacher, Fetch (Push, PickAndPlace and Slide tasks) and HandManipulate (Egg and Block tasks) [38]. In the Reacher environment, SAC, a state-of-the-art algorithm for this environment [39], is used. For the Fetch and HandManipulate environments, DDPG with hindsight experience replay (HER) [40] is used.
For the Reacher environment, training is performed for 700 episodes, and ten evaluations are executed after every episode. The proposed method increases the mean return at episode 700 from −5 to −4, representing a 14% increase in return compared with the baseline, as shown in Fig. 6(c). To smooth the curves, a moving average of the returns over 70 episodes is taken.
For the Fetch environment, training is performed for 100 epochs in the Push task and for 200 epochs in the PickAndPlace and Slide tasks. In each epoch, 50 episode cycles and 20 evaluations are executed sequentially. Figure 6(d), (e) and (f) show that the proposed method increases the mean success rate at the final epoch of training by approximately 23%, 3% and 4% compared with the baseline, respectively. To smooth the curves, moving averages of the success rate over 10, 20 and 20 episodes, respectively, are taken.
For the HandManipulate environment, training is performed for 1000 epochs, and in each epoch, 50 episode cycles and 20 evaluations are executed sequentially. Figure 6(g) and (h) present the results for the Egg and Block tasks, respectively. Figure 6(g) shows that the proposed method increases the mean success rate at the final epoch of training by approximately 46% compared with the baseline. Figure 6(h) shows that the proposed method achieves a similar level of performance to the baseline. To smooth the curves, a moving average of the success rate over 100 episodes is taken.
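The curve post-processing in Sect. 4 consists of a mean over five seeds with a min–max band, plus the moving averages quoted above (e.g. 20 episodes for MountainCar, 70 for Reacher). It can be sketched as follows. The trailing window, the averaging over fewer points at the start, smoothing each seed before aggregating, and the synthetic curves are all assumptions; the text specifies only the window lengths.

```python
import numpy as np

def moving_average(y, window):
    """Trailing moving average; the first window-1 points average over what is available."""
    c = np.cumsum(np.insert(np.asarray(y, dtype=float), 0, 0.0))
    idx = np.arange(1, len(y) + 1)
    start = np.maximum(idx - window, 0)
    return (c[idx] - c[start]) / (idx - start)

def aggregate_seeds(curves, window):
    """curves: shape (n_seeds, n_points). Smooth each seed, then return mean, min and max."""
    smoothed = np.array([moving_average(c, window) for c in curves])
    return smoothed.mean(axis=0), smoothed.min(axis=0), smoothed.max(axis=0)

rng = np.random.default_rng(0)
episodes = np.arange(200)
# Hypothetical returns for five seeds, slowly improving from about -200.
curves = -200 + 0.1 * episodes + rng.normal(0, 5, size=(5, episodes.size))
mean, lo, hi = aggregate_seeds(curves, window=20)
```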
Fig. 7. Learning curves of FetchPush task for compatibility with PER.
4.3 Compatibility with PER
This subsection examines the compatibility of the proposed method with PER. PER prioritizes experiences so that more surprising experiences are sampled more often. The two algorithms are applied at different steps: PER for sampling and the proposed method for calculating the loss function. The experiment in the FetchPush task is repeated with five random seeds, and the mean values of the five curves are shown in Fig. 7. DDPG with HER is the baseline. Fig. 7 uses the following curves:
- red: the baseline;
- green: the baseline with the proposed method;
- orange: the baseline with PER;
- blue: the baseline with both PER and the proposed method.

The baseline combined with PER and DBIW (blue curve) achieves the best result. It decreases the number of epochs required to converge by 20% compared with the baseline with PER (orange curve).
4.4 Ablation Studies
Two distinct ablation studies are carried out. The first examines the components of the proposed approach, while the second focuses on the impact of batch size.
Ablation Study on the Effect of the Components of the Proposed Method. The first ablation study addresses the procedure of the proposed method. That procedure consists of three processes: normalization of the TD errors, application of the Gaussian function, and application of the softmax operator. Two options for normalizing the TD errors are considered: the mean of the DTB and the median of the DTB. An experiment is carried out to evaluate the effect of median normalization. In this experiment, the weighting factor is assigned in accordance with (7) through the Gaussian function, and the softmax operator is then applied to rescale and normalize it. The weighting factor still satisfies (7) without the softmax operator, however. Further experiments therefore investigate whether the softmax operator is necessary. Table 1 shows the results of the ablation study. With the softmax operator, using the median instead of the mean to normalize the DTB increases the mean success rate at epochs 25, 50 and 75, and the agent converges before 100 epochs. We attribute this increase to the median representing the typical value of the DTB better than the mean. Without the softmax operator, the success rate does not exceed 15% in any case.
The softmax operator rescales the importance variables to reduce the differences between them, which prevents overfitting to specific samples.
Table 1. Ablation study in FetchPush task (average success rate; O: with softmax, ×: without softmax)
| Norm | Softmax | 25 Epochs | 50 Epochs | 75 Epochs | 100 Epochs |
|---|---|---|---|---|---|
| Mean | O | 0.18 ± 0.02 | 0.76 ± 0.03 | 0.88 ± 0.03 | 0.99 ± 0.01 |
| Mean | × | 0.09 ± 0.02 | 0.10 ± 0.03 | 0.04 ± 0.00 | 0.06 ± 0.06 |
| Median | O | 0.30 ± 0.07 | 0.88 ± 0.04 | 0.98 ± 0.01 | 0.99 ± 0.00 |
| Median | × | 0.10 ± 0.02 | 0.01 ± 0.03 | 0.04 ± 0.01 | 0.06 ± 0.05 |
Ablation Study on the Effect of the Batch Size. This study assumes that the DTB follows the DTR. If the number of samples is large enough, the distribution of the samples follows the distribution of the statistical population. The batch size should therefore be large enough for the DTB to follow the distribution of all TD errors in the RM. Table 2 shows the result of the ablation study.
- Batch size 128: the proposed method marginally improves performance compared with the baseline, decreasing the number of epochs required to converge by 7.8%.
- Batch size 256: the proposed method decreases the number of epochs required to converge by 43.8%.
- Batch size 512: the proposed method decreases the number of epochs required to converge by 18.0%.

In each case the comparison is against the baseline with the same batch size. The performance of the baseline also improves as the batch size increases. Consequently, the reduction rate for batch size 512 is smaller than that for batch size 256.
Table 2. Ablation study on the batch size in FetchPush task
| Batch size | Convergence epochs, w/o DBIW | Convergence epochs, w/ DBIW | Reduction rate (%) |
|---|---|---|---|
| 128 | 128.6 ± 68.9 | 118.6 ± 59.8 | 7.8 |
| 256 | 81.8 ± 65.7 | 46.0 ± 20.3 | 43.8 |
| 512 | 49.0 ± 20.3 | 40.2 ± 21.7 | 18.0 |
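The argument above is that a larger batch makes the DTB track the DTR. One way to illustrate it is to measure how well a batch median estimates the median of the whole replay memory, i.e. the COR used in Eq. (8). The snippet also recomputes the reduction rates in Table 2 as (w/o − w/)/w/o. The replay-memory TD magnitudes are a hypothetical Beta(5, 1) sample, and uniform sampling with replacement is assumed.

```python
import numpy as np

rng = np.random.default_rng(0)
replay_mag = rng.beta(5.0, 1.0, size=100_000)   # hypothetical |TD error| values in the RM
true_cor = np.median(replay_mag)

def median_error(batch_size, trials=2_000):
    """Mean absolute error between the batch median (COR estimate) and the RM median."""
    errs = [abs(np.median(rng.choice(replay_mag, size=batch_size)) - true_cor)
            for _ in range(trials)]
    return float(np.mean(errs))

for n in (32, 128, 256, 512):
    print(n, round(median_error(n), 4))

# Reduction rates of Table 2: (epochs w/o DBIW - epochs w/ DBIW) / epochs w/o DBIW.
table2 = {128: (128.6, 118.6), 256: (81.8, 46.0), 512: (49.0, 40.2)}
reduction = {b: round(100 * (wo - w) / wo, 1) for b, (wo, w) in table2.items()}
print(reduction)   # {128: 7.8, 256: 43.8, 512: 18.0}
```

The estimation error of a sample median shrinks roughly like 1/√N, which is consistent with the batch-size argument. It does not by itself explain the non-monotonic reduction rates, which the text attributes to the baseline also improving with batch size.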
5 Conclusion
Training agents with off-policy deep RL requires many experiences stored in a replay memory. When the loss function is calculated on a sampled batch, all experiences are usually assumed to be of equal importance. This study proposes a novel method that prioritizes experiences by assigning them different importance based on the DTB. Importance is assigned by multiplying the TD errors by the weighting factors in the training objective. The assignment of the weighting factors rests on the idea that off-policy deep RL algorithms perform better when they focus on the TD errors around the typical value of the DTB. The proposed method is validated on two discrete control tasks and six continuous control tasks of the OpenAI Gym suite, with the following results relative to the baseline:
- MountainCar: the mean return increases by 3%.
- LunarLander: the number of episodes required to converge decreases by 10%.
- Reacher: the mean return increases by 14%.
- Fetch (Push, PickAndPlace and Slide tasks): the mean success rate increases by 23%, 3% and 4%, respectively.
- HandManipulate: the mean success rate increases by 46% for the Egg task, and performance is similar to the baseline for the Block task.

Combining the proposed method with PER decreases the number of epochs required to converge by 20% compared with the baseline with PER. One of the most important findings is that the distribution of TD errors provides a criterion for assigning weighting factors that improve performance.
Acknowledgement. This work was supported by the Institute for Information & communications Technology Promotion (IITP) grant funded by the Korean government (MSIT) (No. 2020-000440, Development of Artificial Intelligence Technology that continuously improves itself as the situation changes in the real world).
References
1. Sutton, R.S., Barto, A.G.: Reinforcement Learning: An Introduction. MIT Press, Cambridge (2018)
2. Mnih, V., et al.: Playing Atari with deep reinforcement learning. arXiv preprint arXiv:1312.5602 (2013)
3. Silver, D., et al.: Mastering the game of Go with deep neural networks and tree search. Nature 529, 484–489 (2016)
4. Li, Y.: Deep reinforcement learning: an overview. arXiv preprint arXiv:1701.07274 (2017)
5. Silver, D., et al.: Mastering the game of Go without human knowledge. Nature 550, 354–359 (2017)
6. Seo, M., Vecchietti, L.F., Lee, S., Har, D.: Rewards prediction-based credit assignment for reinforcement learning with sparse binary rewards. IEEE Access 7, 118776–118791 (2019)
7. Vecchietti, L.F., Kim, T., Choi, K., Hong, J., Har, D.: Batch prioritization in multigoal reinforcement learning. IEEE Access 8, 137449–137461 (2020)
8. Vecchietti, L.F., Seo, M., Har, D.: Sampling rate decay in hindsight experience replay for robot control. IEEE Trans. Cybern. 52, 1515–1526 (2020)
9. Kim, I., Nengroo, S.H., Har, D.: Reinforcement learning for navigation of mobile robot with LiDAR. In: 2021 5th International Conference on Electronics, Communication and Aerospace Technology (ICECA), pp. 148–154. IEEE (2021)
10. Moon, W., Park, B., Nengroo, S.H., Kim, T., Har, D.: Path planning of cleaning robot with reinforcement learning. In: 2022 IEEE International Symposium on Robotic and Sensors Environments (ROSE), pp. 1–7. IEEE (2022)
11. Cho, I., Rajendran, P.K., Kim, T., Har, D.: Reinforcement learning for predicting traffic accidents. In: 2023 International Conference on Artificial Intelligence in Information and Communication (ICAIIC), pp. 684–688. IEEE (2023)
12. Hong, C., Jeong, I., Vecchietti, L.F., Har, D., Kim, J.-H.: AI world cup: robot-soccer-based competitions. IEEE Trans. Games 13, 330–341 (2021)
13. Kim, T., Vecchietti, L.F., Choi, K., Sariel, S., Har, D.: Two-stage training algorithm for AI robot soccer. PeerJ Comput. Sci. 7, e718 (2021)
14. Park, B., Lee, J., Kim, T., Har, D.: Kick-motion training with DQN in AI soccer environment. In: 2023 International Conference on Artificial Intelligence in Information and Communication (ICAIIC), pp. 689–692. IEEE (2023)
15. Yu, Y.: Towards sample efficient reinforcement learning. In: IJCAI, pp. 5739–5743 (2018)
16. Zhang, S., Sutton, R.S.: A deeper look at experience replay. arXiv preprint arXiv:1712.01275 (2017)
17. Novati, G., Koumoutsakos, P.: Remember and forget for experience replay. In: International Conference on Machine Learning, pp. 4851–4860. PMLR (2019)
B. Park et al.
18. Schaul, T., Quan, J., Antonoglou, I., Silver, D.: Prioritized experience replay. arXiv preprint arXiv:1511.05952 (2015)
19. Jeong, H.: Off-policy temporal difference learning for robotics and autonomous systems. University of Pennsylvania (2020)
20. Lillicrap, T.P., et al.: Continuous control with deep reinforcement learning. arXiv preprint arXiv:1509.02971 (2015)
21. Haarnoja, T., Zhou, A., Abbeel, P., Levine, S.: Soft actor-critic: off-policy maximum entropy deep reinforcement learning with a stochastic actor. In: International Conference on Machine Learning, pp. 1861–1870. PMLR (2018)
22. Moore, A.W., Atkeson, C.G.: Prioritized sweeping: reinforcement learning with less data and less time. Mach. Learn. 13, 103–130 (1993)
23. Ramicic, M., Bonarini, A.: Entropy-based prioritized sampling in deep Q-learning. In: 2017 2nd International Conference on Image, Vision and Computing (ICIVC), pp. 1068–1072. IEEE (2017)
24. Zhao, R., Sun, X., Tresp, V.: Maximum entropy-regularized multi-goal reinforcement learning. In: International Conference on Machine Learning, pp. 7553–7562. PMLR (2019)
25. Menon, A., Narasimhan, H., Agarwal, S., Chawla, S.: On the statistical consistency of algorithms for binary classification under class imbalance. In: International Conference on Machine Learning, pp. 603–611. PMLR (2013)
26. Cui, Y., Jia, M., Lin, T.-Y., Song, Y., Belongie, S.: Class-balanced loss based on effective number of samples. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9268–9277 (2019)
27. Zhang, K., et al.: Re-weighted interval loss for handling data imbalance problem of end-to-end keyword spotting. In: INTERSPEECH, pp. 2567–2571 (2020)
28. Li, M., Zhang, X., Thrampoulidis, C., Chen, J., Oymak, S.: Autobalance: optimized loss functions for imbalanced data. In: Advances in Neural Information Processing Systems, vol. 34, pp. 3163–3177 (2021)
29. Guo, D., Li, Z., Zhao, H., Zhou, M., Zha, H.: Learning to re-weight examples with optimal transport for imbalanced classification. In: Advances in Neural Information Processing Systems, vol. 35, pp. 25517–25530 (2022)
30. Sinha, S., Song, J., Garg, A., Ermon, S.: Experience replay with likelihood-free importance weights. In: Learning for Dynamics and Control Conference, pp. 110–123. PMLR (2022)
31. Zhang, L., Tang, K., Yao, X.: Log-normality and skewness of estimated state/action values in reinforcement learning. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
32. Chan, S.C., Lampinen, A.K., Richemond, P.H., Hill, F.: Zipfian environments for reinforcement learning. In: Conference on Lifelong Learning Agents, pp. 406–429. PMLR (2022)
33. Rasmussen, C.E., Williams, C.K.: Gaussian Processes for Machine Learning. Springer, Heidelberg (2006)
34. LeCun, Y., Bengio, Y., Hinton, G.: Deep learning. Nature 521, 436–444 (2015)
35. Brockman, G., et al.: OpenAI Gym. arXiv preprint arXiv:1606.01540 (2016)
36. Ng, A.Y., Harada, D., Russell, S.: Policy invariance under reward transformations: theory and application to reward shaping. In: ICML, pp. 278–287. Citeseer (1999)
37. Todorov, E., Erez, T., Tassa, Y.: MuJoCo: a physics engine for model-based control. In: 2012 IEEE/RSJ International Conference on Intelligent Robots and Systems, pp. 5026–5033. IEEE (2012)
38. Plappert, M., et al.: Multi-goal reinforcement learning: challenging robotics environments and request for research. arXiv preprint arXiv:1802.09464 (2018)
Off-Policy Reinforcement Learning with Loss Function Weighted
39. Chan, S.C., Fishman, S., Canny, J., Korattikara, A., Guadarrama, S.: Measuring the reliability of reinforcement learning algorithms. arXiv preprint arXiv:1912.05663 (2019)
40. Andrychowicz, M., et al.: Hindsight experience replay. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
On Context Distribution Shift in Task Representation Learning for Online Meta RL
Chenyang Zhao, Zihao Zhou, and Bin Liu (corresponding author)
Research Center for Applied Mathematics and Machine Intelligence, Zhejiang Lab, Hangzhou 311121, China
Abstract. Offline Meta Reinforcement Learning (OMRL) aims to learn transferable knowledge from offline datasets to enhance the learning process for new target tasks. Context-based Reinforcement Learning (RL) adopts a context encoder to quickly adapt the agent to new tasks by inferring the task representation and then adjusting the policy based on this inferred representation. In this work, we focus on context-based OMRL, specifically on the challenge of learning task representations for OMRL. Our experiments show that a context encoder trained on offline datasets may suffer from a distribution shift between the contexts used for training and those used for testing. To overcome this problem, we present a hard-sampling-based strategy to train a robust task context encoder. Experimental results on diverse continuous control tasks show that our approach yields more robust task representations and better test performance, in terms of accumulated returns, than baseline methods. Our code is available at https://github.com/ZJLAB-AMMI/HS-OMRL.
Keywords: Offline reinforcement learning · meta reinforcement learning · representation learning
1 Introduction
Reinforcement learning (RL) has emerged as a powerful technique, with remarkable success in domains such as video games [2, 19], robotics [8], and board games [26]. However, acquiring a satisfactory policy in a new setting still requires a large number of online interactions. This is a significant obstacle in scenarios where interacting with the environment is costly or unsafe, such as health care systems [7] and autonomous driving [32]. Meta RL is a promising approach to this issue: it learns transferable knowledge about the learning process itself and extracts a meta-policy that enables rapid adaptation with few samples in unseen target environments [6, 23]. Meta RL assumes that all environments share a similar structure, which it learns by interacting with a distribution of training tasks. Recent research suggests that a meta RL agent can learn to infer the task from a few interactions and adjust its policy accordingly [5, 12, 23].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 614–628, 2023. https://doi.org/10.1007/978-981-99-4761-4_52
Context-based meta RL approaches learn a universal policy conditioned on a latent task representation [23]. During the meta-test stage, the agent adapts the acting policy based on the task representation predicted from a few online interactions. Figure 1 illustrates the general framework of context-based meta RL in both the meta-train and meta-test stages.
Fig. 1. A general framework of context-based meta RL. During the meta-train stage, the agent learns to infer the task and to optimize its behavior in the meta-train environments via backpropagation, using the same data buffer. During the meta-test stage, the agent predicts the task representation from few-shot context information and adapts the contextual policy based on this representation. Solid lines denote forward passes; dashed lines denote backward passes.
Although meta RL methods improve sample efficiency during the meta-test stage, the meta-train stage usually requires a large batch of online experience, collected through intensive interaction with the training environments. Recent research [4, 17, 18] has combined offline RL ideas [15, 22] with meta RL to address data collection during meta-training. Li et al. [16] and Yuan et al. [31] trained context encoders through supervised contrastive learning, treating same-task samples as positive pairs and all other samples as negatives; they learned data representations by clustering positive samples and pushing negative ones apart in the embedding space. However, offline RL methods are susceptible to distributional discrepancies between the behavior policies used in the meta-train and meta-test phases [31]. In fully offline meta RL, where both the meta-train and meta-test contexts are gathered with unknown behavior policies, the performance of trained agents on test tasks depends heavily on the quality of the meta-test contexts. If the training data are imbalanced in policy quality (for instance, if they consist only of expert data), the learned context encoder may not generalize well to a broader range of meta-test policies.
Fig. 2. Top row: example test trajectories obtained with contexts of different quality. Bottom row: t-SNE visualization of the embedded contexts of different quality. Five test goals (colored dots) are sampled at random from the task distribution, and five contexts are sampled at random for each goal to visualize the trajectories. These comparisons show that the performance of the learned context encoder and contextual policy depends largely on the quality of the test contexts. (Color figure online)
We use the PointRobotGoal environment to illustrate how the predicted context embeddings and the resulting trajectories vary with the quality of the context data, as depicted in Fig. 2. Performance drops dramatically when the context information is pre-collected with policies dissimilar to those used for meta-training. Motivated by these observations, we tackle the distribution shift problem for OMRL and present an importance sampling method to learn a more robust task context encoder. We adopt the supervised contrastive learning framework [13] with task indices as labels. We then calculate the "hardness" of positive and negative samples separately, based on their distances to the anchor sample in the embedding space. Finally, we adjust the supervised contrastive loss by weighting both positive and negative samples according to their "hardness" values. Our main contributions are as follows:
• We analyze the distribution shift problem in learning a task context encoder for OMRL and demonstrate the impact of context quality on the performance of a learned context encoder.
• Based on this analysis, we propose a novel supervised contrastive objective that adopts hard positive and hard negative sampling to train a more robust context encoder. Experiments demonstrate the superior performance of the proposed approach compared with various baseline methods.
• We perform an ablation study to investigate the contribution of the proposed hard sampling strategy to robust context encoder learning, and we analyze the quality of the learned context encoder in terms of its uniformity and alignment.
2 Related Work
2.1 Meta Reinforcement Learning
Meta-learning has proven successful by enabling the learning of knowledge about the learning process itself, which improves learning efficiency and performance across various applications [1, 10, 25]. In meta RL, prior research has focused on developing a meta-policy and an adaptation approach that work together to enable sample-efficient task adaptation during testing. Optimization-based approaches fine-tune the meta-policy with gradient descent using only a few shots of data from the meta-test tasks [6, 18]. Alternatively, context-based meta RL methods learn a meta-policy conditioned on task context information, eliminating the need to update networks during the meta-test phase. Duan et al. [5] and Wang et al. [28] used recurrent neural networks to encode the context information, while Rakelly et al. [23], Zintgraf et al. [33], and Humplik et al. [12] learned a separate network that encodes context information as task-specific latent variables. Because they fine-tune in the meta-test stage, optimization-based methods are more robust on out-of-distribution test tasks, at the cost of increased computation during meta-testing. Context-based methods, on the other hand, achieve higher sample efficiency and better asymptotic performance when adapting to in-distribution target tasks, such as reaching different goal positions or running at different target speeds [23]. Recently, various works have extended meta RL to offline settings, in which agents can only access static datasets of pre-collected experience from each training environment rather than the environments themselves, to address scenarios where online data collection is not feasible [15, 22]. Mitchell et al.
[18] used an optimization-based method and applied a Model-Agnostic Meta-Learning (MAML)-style loss in policy learning. Among context-based methods, Li et al. [17] disentangled the learning of the encoder and the policy: they first learned an informative task context encoder from transition data only and then learned the conditional policy with standard offline RL methods, such as Behavior Regularized Actor Critic (BRAC) [30]. Li et al. [16] and Yuan et al. [31] employed contrastive objectives to learn task representations and achieved more robust performance in testing scenarios. However, most prior works assumed that the underlying distribution of context data remains unchanged between training and testing and did not explicitly consider distribution shift. Such shift is common in OMRL, for instance, when agents are trained on experience collected primarily with expert-level policies but must explore new environments with random or sub-optimal policies. To tackle distribution shift in OMRL, we propose a hard sampling approach to context encoder learning.
2.2 Contrastive Learning
Contrastive learning is a prevalent technique for representation learning [3, 9, 20]. It learns data representations by encouraging similar samples to remain close in the
embedding space while pushing dissimilar ones apart. In self-supervised settings, where no label information is available, positive pairs are different augmented views of the same sample, while negative samples are views of different samples [3]. In supervised settings, samples with the same class label are treated as positives, and the label information is embedded into the contrastive objective [13]. Several recent works have analyzed the behavior of contrastive learning. Robinson et al. [24] emphasized the importance of the negative sample distribution and proposed an importance sampling technique to mine hard negative samples. Furthermore, Wang et al. [29] and Wang et al. [27] identified two properties of learned representations that contribute to good performance in downstream tasks: alignment, which measures how close similar samples are in the embedding space, and uniformity, which measures how close the distribution of learned representations is to a uniform distribution. Theoretical results in [11] show that striking a balance between these two properties is crucial for finding good representations.
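Alignment and uniformity, which are also used later to analyze the trained encoders, are commonly computed as below. This is a minimal NumPy sketch assuming L2-normalized embeddings and the usual choices α = 2 and t = 2; the function names are hypothetical and not taken from the authors' code.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Project each row of x onto the unit hypersphere."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def alignment(z1: np.ndarray, z2: np.ndarray, alpha: float = 2.0) -> float:
    """Mean ||z1[i] - z2[i]||^alpha over positive pairs (row i of z1 with row i of z2).

    Lower values mean positive pairs sit closer together."""
    return float(np.mean(np.linalg.norm(z1 - z2, axis=1) ** alpha))


def uniformity(z: np.ndarray, t: float = 2.0) -> float:
    """log of the mean Gaussian potential exp(-t * ||z_i - z_j||^2) over pairs i < j.

    Lower values mean the embeddings are spread more uniformly on the sphere."""
    sq_dists = np.sum((z[:, None, :] - z[None, :, :]) ** 2, axis=-1)
    i, j = np.triu_indices(len(z), k=1)
    return float(np.log(np.mean(np.exp(-t * sq_dists[i, j]))))


# Usage: two slightly perturbed "views" of the same random embeddings.
rng = np.random.default_rng(0)
base = rng.normal(size=(64, 8))
view1 = l2_normalize(base + 0.05 * rng.normal(size=base.shape))
view2 = l2_normalize(base + 0.05 * rng.normal(size=base.shape))
print(alignment(view1, view2), uniformity(view1))
```

Both quantities are "lower is better"; the balance between them is what [11] argues matters.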
3 Preliminaries
In meta RL, we consider a distribution of tasks $p(\mathcal{T})$. Each task is formalized as a Markov decision process (MDP) $\mathcal{T} = \langle S, A, P, R, \gamma \rangle$, where S, A, P, R, and γ denote the state space, action space, transition function, reward function, and discount factor, respectively. As in previous works [17, 23], we assume that the tasks share the same state space S, action space A, and discount factor γ; they differ in their reward functions (e.g., reaching different goals) and transition functions (e.g., walking on different terrains). During the meta-train stage, the agent has full access to a set of meta-train tasks $\{\mathcal{T}\}_{\text{train}}$. The objective of meta RL is to train an agent with data from the meta-train tasks only, such that the trained agent can quickly adapt to unseen target tasks $\{\mathcal{T}\}_{\text{test}}$ with limited data.
3.1 Context-Based Meta RL
To enable quick adaptation to target tasks, context-based methods learn both a task context encoder $z \sim q_{\theta_{enc}}(z \mid x)$, parameterized by $\theta_{enc}$, and a contextual policy $a \sim \pi_\phi(a \mid s, z)$, parameterized by φ. The encoder predicts a latent task embedding z from some context information x, whereas the contextual policy predicts the optimal action a given the current state s and the embedding z. When there is no prior knowledge of the target task, the context information x may consist of a small amount of interaction experience $\{s_t, a_t, s_{t+1}, r_t\}_{t=1,\dots,T}$ from the target task. Specifically for OMRL, Li et al. [17] found that training the context encoder separately from the contextual value function and policy can be more effective. During the meta-train stage, the agent first learns the context encoder, which is then frozen before the RL components, such as the actor and critic networks, are trained.
At the meta-test stage, given some task context information $\tau^{te}$ (e.g., one trajectory collected in the test environment), the agent first samples
the task representation $z^{te} \sim q_{\theta_{enc}}(z \mid \tau^{te})$ and then deploys the corresponding acting policy $\pi_\phi(a \mid s, z^{te})$ in the new environment. In this work, we also decouple the learning of the context encoder and the contextual policy, since this approach has been shown to improve robustness and overall performance. By treating the latent task embedding z as an unobserved component of the state, we can frame context-based meta RL as a partially observable MDP (POMDP), in which the state comprises both the environment state s and the task embedding z. Since the task remains fixed within a trajectory, z also remains constant along it. Assuming the task context encoder performs well, the acting policy $\pi_\phi(a \mid s, z)$ and the associated value functions $V(s, z)$ and $Q(s, a, z)$ can be learned with any standard RL algorithm.
3.2 Offline RL
Offline RL assumes that the agent can only access a static buffer of trajectories collected by an unknown behavior policy $\pi_b$. A difficulty is that conventional techniques for learning the optimal Q function may query the values of actions unseen in the data, which causes estimation errors and an unstable learning process. To tackle this challenge, Kostrikov et al. [14] introduced implicit Q-learning (IQL), which employs expectile regression to learn the optimal value function $V_\psi(s)$ and Q function $Q_\theta(s, a)$. The corresponding loss functions are
$$L_V(\psi) = \mathbb{E}_{(s,a)\sim\mathcal{D}}\left[L_2^\tau\left(Q_{\hat\theta}(s,a) - V_\psi(s)\right)\right], \tag{1}$$
$$L_Q(\theta) = \mathbb{E}_{(s,a,s')\sim\mathcal{D}}\left[\left(r(s,a) + \gamma V_\psi(s') - Q_\theta(s,a)\right)^2\right], \tag{2}$$
where $\mathcal{D}$ is the offline dataset; θ, θ̂, and ψ parameterize the Q network, the target Q network, and the value network, respectively; and $L_2^\tau(x) = |\tau - \mathbb{1}(x < 0)|\,x^2$ is the expectile regression loss. To extract a policy parameterized by φ from the estimated optimal Q function, advantage-weighted regression is used [21], which maximizes
$$L_\pi(\phi) = \mathbb{E}_{(s,a)\sim\mathcal{D}}\left[\exp\left(\beta\left(Q_{\hat\theta}(s,a) - V_\psi(s)\right)\right)\log\pi_\phi(a \mid s)\right]. \tag{3}$$
Given its robustness and effectiveness in offline RL, we adopt IQL for downstream policy learning in this work.
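To make the three IQL objectives concrete, the sketch below evaluates them with NumPy on a batch of precomputed network outputs. It is a hedged illustration, not the authors' implementation: the hyperparameter values (τ = 0.7, β = 3.0, γ = 0.99), the omission of terminal-state masking, and the `max_weight` clipping of the advantage weights (a common practical safeguard) are assumptions.

```python
import numpy as np


def expectile_loss(u: np.ndarray, tau: float) -> np.ndarray:
    """L_2^tau(u) = |tau - 1(u < 0)| * u^2, applied elementwise."""
    return np.abs(tau - (u < 0).astype(float)) * u ** 2


def value_loss(q_target: np.ndarray, v: np.ndarray, tau: float = 0.7) -> float:
    """Eq. (1): expectile regression of V_psi(s) towards the target Q(s, a).

    With tau > 0.5, positive residuals (Q > V) are penalized more, so V
    tracks an upper expectile of Q over the dataset actions."""
    return float(np.mean(expectile_loss(q_target - v, tau)))


def q_loss(r: np.ndarray, v_next: np.ndarray, q: np.ndarray, gamma: float = 0.99) -> float:
    """Eq. (2): squared TD error with bootstrap target r + gamma * V_psi(s')."""
    return float(np.mean((r + gamma * v_next - q) ** 2))


def awr_weights(q_target: np.ndarray, v: np.ndarray, beta: float = 3.0,
                max_weight: float = 100.0) -> np.ndarray:
    """exp(beta * advantage) from Eq. (3), clipped at max_weight (assumed safeguard)."""
    return np.minimum(np.exp(beta * (q_target - v)), max_weight)


def policy_objective(log_prob: np.ndarray, q_target: np.ndarray, v: np.ndarray,
                     beta: float = 3.0) -> float:
    """Eq. (3): advantage-weighted log-likelihood of dataset actions (to be maximized)."""
    return float(np.mean(awr_weights(q_target, v, beta) * log_prob))


# Usage on a toy batch of three transitions.
q_t = np.array([1.0, 2.0, 0.5])
v = np.array([1.5, 1.5, 1.5])
print(value_loss(q_t, v), awr_weights(q_t, v))
```

Because the critic target uses $V_\psi(s')$ rather than a maximum over actions, no out-of-dataset action is ever queried, which is the point of IQL.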
4 Learning Task Representation
Offline RL poses the challenge of learning from data collected by unknown policies. As a result, distribution shifts between training and testing data, caused by performance differences among these policies, can significantly degrade test performance [15]. Similarly, imbalanced training data in task representation learning may lead to poor generalization when the encoder is presented with contexts from different behavior policies. For instance, if the offline data are primarily gathered with near-optimal policies, the learned context encoder may fail to infer the task from low-quality context data, such as trajectories generated by worse policies.
To overcome this challenge, we propose a task representation learning method, based on contrastive learning and importance sampling, for imbalanced offline training sets. Assuming that contexts of various quality are encountered uniformly during the meta-test stage, we weight both positive and negative samples in the supervised contrastive objective according to their "hardness", measured by the distance between samples in the embedding space.
4.1 Representation Learning Framework
Consider the common meta RL scenario in which no prior information about the tasks is available and the agent must infer the task from trajectory data $\tau = \{s_t, a_t, s_{t+1}, r_t\}_{t=0,1,\dots,T}$. To process a batch of data, we first apply data augmentation to generate two distinct views of the batch. For trajectory data, we use a transition encoding module and a trajectory aggregation module to produce normalized embeddings. During training, we also employ a projection module that generates lower-dimensional projections, on which the contrastive losses are computed; the projection module is discarded after the representation learning stage. The key components of the framework are as follows:
Data Augmentation. For every input trajectory τ, we create two distinct views, each containing a subset of the information in the original sample. Analogous to random cropping of images in computer vision, we randomly select two different segments τ1 and τ2 as the two views of τ.
Encoding Network. The encoding network maps input trajectories to latent representation vectors $z = \mathrm{Enc}(\tau) \in \mathbb{R}^{D_E}$, where $D_E$ is the dimension of the embedding space.
The encoding network comprises two components: first, a transition encoder that maps each transition to a transition embedding $v = f(s, a, s', r) \in \mathbb{R}^{D_T}$; second, an aggregator network that collects information from all transitions and produces the latent representation $z = g(\{v_t\}) \in \mathbb{R}^{D_E}$. We normalize v and z so that they lie on the unit hyperspheres in $\mathbb{R}^{D_T}$ and $\mathbb{R}^{D_E}$, respectively.
Projection Network. The projection network maps the context embedding z to the output vector $w = \mathrm{Proj}(z) \in \mathbb{R}^{D_P}$, which is used to compute the distance between a pair of samples. We normalize w to lie on the unit hypersphere. Following prior work [3, 13], we discard the projection head for downstream tasks, such as learning the contextual policy in OMRL.
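The data flow of this framework (random-segment augmentation, transition encoder f, aggregator g, and projection head, each output L2-normalized) can be sketched as below. This is a forward-pass-only illustration under stated assumptions: the single tanh layers, mean pooling as the aggregator, the segment length, and all dimensions are hypothetical choices, since the text does not fix them; in practice these would be trained networks.

```python
import numpy as np


def l2_normalize(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Scale vectors along `axis` to unit length (onto the unit hypersphere)."""
    return x / np.linalg.norm(x, axis=axis, keepdims=True)


def random_segment(traj: np.ndarray, seg_len: int, rng: np.random.Generator) -> np.ndarray:
    """Augmentation: crop a random contiguous segment of transitions (one 'view')."""
    start = rng.integers(0, len(traj) - seg_len + 1)
    return traj[start:start + seg_len]


class ContextEncoder:
    """Hypothetical single-layer stand-ins for f (transitions), g (aggregator), and Proj."""

    def __init__(self, transition_dim: int, d_t: int, d_e: int, d_p: int,
                 rng: np.random.Generator):
        self.w_f = rng.normal(size=(transition_dim, d_t)) / np.sqrt(transition_dim)
        self.w_g = rng.normal(size=(d_t, d_e)) / np.sqrt(d_t)
        self.w_p = rng.normal(size=(d_e, d_p)) / np.sqrt(d_e)

    def encode(self, segment: np.ndarray) -> np.ndarray:
        # v_t = f(s, a, s', r): one unit-norm embedding per transition, shape (T, D_T)
        v = l2_normalize(np.tanh(segment @ self.w_f))
        # z = g({v_t}): mean-pool over transitions (assumed aggregator), shape (D_E,)
        return l2_normalize(np.tanh(v.mean(axis=0) @ self.w_g))

    def project(self, z: np.ndarray) -> np.ndarray:
        # w = Proj(z), shape (D_P,); only the contrastive loss uses it
        return l2_normalize(z @ self.w_p)


rng = np.random.default_rng(0)
# Hypothetical trajectory: 200 transitions, each (s, a, s', r) flattened to 7 numbers.
traj = rng.normal(size=(200, 7))
enc = ContextEncoder(transition_dim=7, d_t=32, d_e=16, d_p=8, rng=rng)
tau1, tau2 = random_segment(traj, 50, rng), random_segment(traj, 50, rng)
z1, z2 = enc.encode(tau1), enc.encode(tau2)
w1, w2 = enc.project(z1), enc.project(z2)
print(float(w1 @ w2))  # cosine similarity between the two views' projections
```

Mean pooling makes z invariant to the order of transitions within a segment; an attention-based aggregator would be a drop-in alternative.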
4.2 Contrastive Objective
Consider a set of training tasks $\{\mathcal{T}^k\}$ and the corresponding offline datasets $\{B^k\}$. Each dataset $B^k$ comprises trajectories $\{\tau_n^k\}$ collected offline in task $\mathcal{T}^k$ using an
unknown behavior policy, where $k \in \{1, \dots, K\}$ denotes the task index and $n \in \{1, \dots, N\}$ the index of a trajectory sample. The objective of representation learning is to encode trajectory data τ into a latent representation vector $z_\tau$ such that trajectories from the same task are close to each other in the embedding space. To achieve this, we first apply data augmentation to generate two different views. Let $D_i \triangleq \{A_i(\tau) \mid \tau \in \{B^k\}\}$, $i = 1, 2$, denote the i-th view of the trajectories drawn from all tasks, and let $D \triangleq [D_1, D_2]$ denote the collection of both views.
Supervised Contrastive Learning (SCL). Since task labels are available in our OMRL setup, we use supervised contrastive learning to incorporate the task labels into the objective [13]. Given an anchor sample $\tau_q = A_1(\tau)$, we consider the other view $A_2(\tau)$. Pairs of samples with the same task label across the two views are treated as positives, while pairs with different task labels are treated as negatives. Following [20], we adopt the InfoNCE loss, resulting in the objective
$$\mathcal{L}_{SCL} = \sum_{\tau_q \in D} \frac{-1}{|P(\tau_q)|} \sum_{\tau_p \in P(\tau_q)} \log \frac{\exp(z_{\tau_q} \cdot z_{\tau_p}/\beta)}{\sum_{a \in D(\tau_q)} \exp(z_{\tau_q} \cdot z_a/\beta)}, \tag{4}$$
where $\tau_q$ denotes the anchor sample, $P(\tau_q)$ the set of its positive samples, $D(\tau_q) \equiv D \setminus \{\tau_q\}$ the collection of views excluding the anchor, and β a temperature parameter.
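Eq. (4) can be evaluated on a batch of embeddings as below. This NumPy sketch assumes that both views are stacked into one array `z` of unit-norm rows with their task labels, that positives are all other samples with the same label (as in standard supervised contrastive learning), and that each label appears at least twice; the temperature value and the function name are illustrative.

```python
import numpy as np


def scl_loss(z: np.ndarray, labels: np.ndarray, beta: float = 0.1) -> float:
    """Supervised contrastive loss of Eq. (4), summed over anchors.

    z: (M, D) unit-norm embeddings of both views; labels: (M,) task indices."""
    m = len(z)
    logits = z @ z.T / beta
    not_self = ~np.eye(m, dtype=bool)
    # Denominator runs over D(tau_q): every sample except the anchor itself.
    masked = np.where(not_self, logits, -np.inf)
    row_max = masked.max(axis=1, keepdims=True)  # stable log-sum-exp
    log_denom = np.log(np.exp(masked - row_max).sum(axis=1)) + row_max[:, 0]
    log_prob = logits - log_denom[:, None]
    # Positives P(tau_q): same task label, excluding the anchor.
    pos = (labels[:, None] == labels[None, :]) & not_self
    n_pos = np.maximum(pos.sum(axis=1), 1)
    per_anchor = -np.where(pos, log_prob, 0.0).sum(axis=1) / n_pos
    return float(per_anchor.sum())


z = np.array([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [-1.0, 0.0]])
print(scl_loss(z, np.array([0, 0, 1, 1])))  # labels match the clusters: close to 0
print(scl_loss(z, np.array([0, 1, 0, 1])))  # labels cut across the clusters: close to 80
```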
Hard Negative Sampling. To automatically select more informative negative samples, Robinson et al. [24] proposed hard negative sampling. Inspired by importance sampling, this technique places greater emphasis on challenging negatives: samples that lie close to the anchor in the embedding space but belong to different tasks. We extend hard negative sampling to the supervised contrastive learning framework by defining the loss
$$\mathcal{L}_{HG} = \sum_{\tau_q \in D} \frac{-1}{|P(\tau_q)|} \sum_{\tau_p \in P(\tau_q)} \log \frac{\exp(z_{\tau_q} \cdot z_{\tau_p}/\beta)}{Z^{pos} + Z^{neg}}, \tag{5}$$
$$Z^{pos} = \sum_{a \in P(\tau_q)} \exp(z_{\tau_q} \cdot z_a/\beta), \tag{6}$$
$$Z^{neg} = \sum_{a \in N(\tau_q)} \omega_a^{neg} \exp(z_{\tau_q} \cdot z_a/\beta), \tag{7}$$
where $Z^{pos}$ is computed over all positive samples $P(\tau_q)$ and $Z^{neg}$ over all negative samples $N(\tau_q)$. Here, $\omega_a^{neg}$ measures the "hardness" of a negative sample a, namely
$$\omega_a^{neg} = \frac{\exp(z_{\tau_q} \cdot z_a)}{\sum_{a' \in N(\tau_q)} \exp(z_{\tau_q} \cdot z_{a'})}.$$
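The hardness weights $\omega_a^{neg}$ amount to a softmax of the anchor-negative dot products taken over each anchor's negatives. A minimal NumPy sketch (the function name is hypothetical; it assumes unit-norm rows and at least two distinct tasks in the batch, so that every anchor has a negative):

```python
import numpy as np


def negative_hardness_weights(z: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Matrix W with W[q, a] = omega_a^neg for anchor q, or 0 if a is not a negative of q.

    Each row is a softmax of z_q . z_a over the anchor's negatives N(tau_q),
    so negatives closer to the anchor (harder ones) receive larger weights."""
    sims = z @ z.T
    neg = labels[:, None] != labels[None, :]
    scores = np.where(neg, sims, -np.inf)
    scores = scores - scores.max(axis=1, keepdims=True)  # stabilize the softmax
    w = np.exp(scores)
    return w / w.sum(axis=1, keepdims=True)


z = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
w = negative_hardness_weights(z, np.array([0, 1, 2]))
print(np.round(w[0], 4))  # anchor 0: the orthogonal sample outweighs the opposite one
```

With these weights, Eq. (7) for every anchor is `(w * np.exp(z @ z.T / beta)).sum(axis=1)`, because the weights are zero outside $N(\tau_q)$.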
Hard Positive Sampling. To further address imbalanced data in representation learning for OMRL, we propose re-weighting the positive samples in Eq. (5) according to their "hardness". Specifically, we assign higher weights to positive samples that are farther away from the anchor sample, which gives the loss
$$\mathcal{L}_{HP+HG} = \sum_{\tau_q \in D} \frac{-1}{|P(\tau_q)|} \sum_{\tau_p \in P(\tau_q)} \log \frac{\omega_{\tau_p}^{pos} \exp(z_{\tau_q} \cdot z_{\tau_p}/\beta)}{Z^{pos} + Z^{neg}}, \tag{8}$$
where $Z^{pos} = \sum_{a \in P(\tau_q)} \omega_a^{pos} \exp(z_{\tau_q} \cdot z_a/\beta)$ is now the re-weighted sum over positive samples, and $Z^{neg}$ remains unchanged from Eq. (7). Here, $\omega_a^{pos}$ measures the "hardness" of a positive sample a, namely
$$\omega_a^{pos} = \frac{\exp(-z_{\tau_q} \cdot z_a)}{\sum_{a' \in P(\tau_q)} \exp(-z_{\tau_q} \cdot z_{a'})}.$$
Note that the positive weights are computed from the negated dot products, so positive samples farther away from the anchor receive larger weights in the gradient computation.
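Putting Eqs. (5)–(8) together, the sketch below computes the hard-sampling loss for a batch with NumPy. The two flags are an illustrative device, not part of the paper: with both disabled, all weights become 1 and the loss reduces to Eq. (4); `hard_neg=True` alone gives Eq. (5); both together give Eq. (8). Unit-norm embeddings, at least one positive and one negative per anchor, and the temperature value are assumptions. The weights follow the formulas above and sum to one over each anchor's positives or negatives; other normalizations are possible.

```python
import numpy as np


def _masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Row-wise softmax restricted to entries where mask is True (others get weight 0)."""
    s = np.where(mask, scores, -np.inf)
    e = np.exp(s - s.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


def hard_sampling_loss(z: np.ndarray, labels: np.ndarray, beta: float = 0.1,
                       hard_neg: bool = True, hard_pos: bool = True) -> float:
    """Weighted supervised contrastive loss, summed over anchors.

    z: (M, D) unit-norm embeddings of both views; labels: (M,) task indices."""
    m = len(z)
    sims = z @ z.T
    not_self = ~np.eye(m, dtype=bool)
    pos = (labels[:, None] == labels[None, :]) & not_self
    neg = labels[:, None] != labels[None, :]
    # omega^pos uses negated dot products: farther positives get larger weights.
    w_pos = _masked_softmax(-sims, pos) if hard_pos else pos.astype(float)
    # omega^neg uses dot products: closer (harder) negatives get larger weights.
    w_neg = _masked_softmax(sims, neg) if hard_neg else neg.astype(float)
    # Shifting logits by 1/beta (the largest possible value for unit vectors)
    # avoids overflow; the shift cancels between numerator and denominator.
    exp_logits = np.exp((sims - 1.0) / beta)
    z_pos = (w_pos * exp_logits).sum(axis=1)
    z_neg = (w_neg * exp_logits).sum(axis=1)
    numer = np.where(pos, w_pos * exp_logits, 1.0)  # placeholder 1.0 off the positives
    log_ratio = np.log(numer) - np.log(z_pos + z_neg)[:, None]
    per_anchor = -np.where(pos, log_ratio, 0.0).sum(axis=1) / pos.sum(axis=1)
    return float(per_anchor.sum())


z = np.array([[1.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [-1.0, 0.0]])
labels = np.array([0, 1, 0, 1])
for hn, hp in [(False, False), (True, False), (True, True)]:
    print(hn, hp, hard_sampling_loss(z, labels, hard_neg=hn, hard_pos=hp))
```

When every anchor has a single positive, $\omega^{pos}$ is exactly 1, so hard positive sampling only changes the loss once several positives per anchor are in the batch.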
5 Experiments
Our main objective in the experiments is to assess whether the proposed method learns a more robust context encoder, particularly when the training data are imbalanced. We first highlight the issue of context distribution shift in task representation learning for OMRL. We then compare our technique with baseline methods in terms of their ability to handle contexts of varying quality. Finally, we conduct two additional experiments to gain further insight into our approach: a review of the sampling strategy and an analysis of the uniformity and alignment of the trained encoders.
5.1 Experimental Setup
We use five simulated continuous control environments: three with varying goal conditions (reward functions) and two with varying dynamics (transition functions). These tasks have been used in previous studies such as [17, 23], and [31].
Environments with Changing Reward Functions. PointRobotGoal: the agent controls a robot to navigate towards goal positions located on a unit circle; the reward is determined by the Euclidean distance between the robot and the goal position. AntDir: the agent controls a simulated ant robot to move in different two-dimensional directions; the reward is determined by the velocity projected onto the desired direction of movement. CheetahVel: the agent controls a simulated half-cheetah robot to run at different target speeds; the reward is based on the difference between the robot's actual speed and the target speed.
Environments with Changing Transition Functions. The WalkerParams and HopperParams environments are locomotion tasks simulated with MuJoCo, in which the agent must propel a walker or hopper robot forward as quickly as possible.
In each of these environments, physical parameters such as the robot's mass and friction coefficient are randomized. For every environment, 30 tasks are sampled uniformly to form the training task set, and 10 additional tasks are sampled from the same distribution as the target
task set. To generate the offline datasets, we train soft actor-critic (SAC) [8] on each task separately and store the data collected during training as a replay buffer, which serves as the offline dataset. For more details on the test environment distributions and offline datasets, see the appendix of [34].
5.2 The Distribution Shift Problem
Our experiments confirm the distribution shift issue described above. In the PointRobotGoal environment, we visualized the context embeddings and the downstream performance given contexts of varying quality, as shown in Fig. 2. To build test context buffers of different quality, we first ranked the trajectories in each task's offline context buffer by accumulated return and then split the ranking uniformly into 10 smaller buffers. The low-quality context buffer consists of trajectories whose returns fall in the lowest 10th percentile, while the medium- and high-quality buffers consist of trajectories between the 10th and 20th percentiles and in the top 10th percentile, respectively. The encoder and policy networks were meta-trained with the SCL objective (Eq. 4). During the meta-test stage, one trajectory was sampled from a buffer as context information each time and used to produce one outcome trajectory with the inferred task representation and the contextual policy. Fig. 2 shows t-SNE visualizations of the resulting embeddings along with example outcome trajectories. The learned context encoder separates high-quality contexts (right column) in the embedding space and reaches the target goals. However, as the quality of the context information decreases (left and middle columns), performance deteriorates drastically.
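The construction of the low-, medium-, and high-quality context buffers can be sketched as follows. This assumes that the accumulated returns of the trajectories in one task's buffer are available as an array; the function name and the synthetic returns are hypothetical.

```python
import numpy as np


def split_by_quality(returns: np.ndarray, n_buckets: int = 10) -> list:
    """Rank trajectories by accumulated return (ascending) and split the ranking
    into n_buckets (near-)equal groups of trajectory indices."""
    order = np.argsort(returns, kind="stable")
    return np.array_split(order, n_buckets)


# Hypothetical returns for 100 trajectories of one task (deliberately unsorted).
returns = np.arange(100, dtype=float)[::-1]
buckets = split_by_quality(returns)
low_idx = buckets[0]      # lowest 10% of returns
medium_idx = buckets[1]   # 10th to 20th percentile
high_idx = buckets[-1]    # top 10%
print(returns[low_idx].min(), returns[high_idx].max())
```

`np.array_split` tolerates buffer sizes that are not multiples of 10, in which case the first buckets receive one extra trajectory.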
When the context information is collected with policies dissimilar to those used in meta-training, the encoder trained with SCL fails to cluster the context embeddings properly, resulting in inferior test performance. Compared with SCL, our proposed method (Fig. 3) greatly improves the separability of task contexts in the embedding space, producing better test trajectories from low-quality contexts.
5.3 Task Adaptation Performance
To evaluate our hard sampling approach to OMRL, we compare it against the following baseline methods:
Offline PEARL. A natural baseline for OMRL is to extend off-policy meta RL approaches to the offline setting, so we consider the offline variant of PEARL [23]. In offline PEARL, the context encoder is trained jointly with the contextual value functions and policy, and the mean squared error (MSE) loss of the value function estimate is used to update the context encoder.
FOCAL. Li et al. [17] employed metric learning to train the context encoder, bringing positive pairs of samples together while pushing negative pairs apart. Positive and negative pairs are defined as transition samples from the same and different tasks,
624
C. Zhao et al.
Fig. 3. The t-SNE visualization and demonstrative trajectories corresponding to contexts of low quality in the PointRobotGoal environment. The comparison is made between training with hard sampling (a, b) and without hard sampling (c, d). Our results indicate that the implementation of a hard sampling strategy significantly improves the separability of task contexts in the embedding space, thereby producing better trajectories against low-quality contexts.
respectively. Another variant of FOCAL, called FOCAL++ [16], has also been considered. The primary differences include that FOCAL++ replaces the objective function with a contrastive objective with momentum [9] and employs attention blocks to encode the trajectories instead of transitions. CORRO. Yuan et al. [31] generated synthetic transition samples as negative samples and trained the encoder network using InfoNCE loss [20]. In our experiments, we utilize reward randomization to generate synthetic transitions. Note that for analyzing the impact of task representation learning, we have employed IQL [14] as the underlying offline RL method for all baselines. Although the original implementations of CORRO and PEARL employ SAC [8] and FOCAL uses BRAC [30], respectively, we have conducted two experiments where the agents are trained with the original offline RL algorithms and with IQL, respectively. We have found that IQL performs better in all cases, and therefore only report results trained with IQL. In the following experiments, we have evaluated the test performance of meta-trained policies in scenarios where distribution shifts exist between the meta-train and meta-test stages. Similar to the previous experiment, the low-quality context buffer comprises trajectory data with accumulated returns within the lowest 10th percentile. During testing,
On Context Distribution Shift in Task Representation Learning
625
Table 1. The comparisons of testing performances between our proposed method against baseline methods (top half) and different sampling strategies (bottom half). All learned context encoders are tested with low-quality contexts in target tasks. The bold numbers highlight the best performances over all compared methods. All results are averaged over 5 random seeds. Environment
PointRobotGoal
AntDir
WalkerParams
FOCAL
−80.4 ± 20.4
247.8 ± 54.7
154.8 ± 64.3
FOCAL++
−63.9 ± 10.9
289.8 ± 67.2
180.7 ± 73.2
CORRO
−52.7 ± 14.2
313.0 ± 74.3
259.0 ± 36.4
Offline PERAL
−129.4 ± 15.7
248.3 ± 74.8
193.8 ± 45.8
LHP+HG
−44.9 ± 3.5
352.1 ± 42.8
264.8 ± 41.4
LSCL
−60.6 ± 9.8
309.5 ± 63.4
160.4 ± 58.2
LHG
−49.8 ± 11.9
337.0 ± 48.9
247.2 ± 33.8
LHP
−55.4 ± 13.2
328.9 ± 43.0
183.9 ± 42.8
we randomly select one trajectory from the context buffer each time and test the agent in the corresponding task. The performance outcomes are measured by the accumulated return and are averaged over all tasks in the target task set.
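The CORRO baseline trains its encoder with the InfoNCE loss [20], contrasting an anchor embedding against one positive and a set of negatives (for CORRO, synthetic transitions). A minimal pure-Python sketch of the per-anchor loss follows. The dot-product similarity, the temperature value, and the Gaussian reward perturbation in the hypothetical `randomize_reward` helper are illustrative assumptions, not the exact setup used by CORRO or in these experiments.

```python
import math
import random
from typing import Sequence, Tuple

# (s, a, r, s'); vectors are plain sequences of floats in this sketch
Transition = Tuple[Sequence[float], Sequence[float], float, Sequence[float]]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def info_nce(anchor: Sequence[float], positive: Sequence[float],
             negatives: Sequence[Sequence[float]], temperature: float = 0.1) -> float:
    """-log( exp(sim(a,p)/T) / (exp(sim(a,p)/T) + sum_n exp(sim(a,n)/T)) ),
    computed with the log-sum-exp trick for numerical stability."""
    logits = [dot(anchor, positive) / temperature]
    logits += [dot(anchor, n) / temperature for n in negatives]
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_denominator - logits[0]


def randomize_reward(transition: Transition, rng: random.Random,
                     scale: float = 1.0) -> Transition:
    """Hypothetical negative generator: keep (s, a, s') and perturb only the reward."""
    s, a, r, s_next = transition
    return (s, a, r + rng.gauss(0.0, scale), s_next)
```

In a full pipeline the perturbed transition would be passed through the context encoder before being used as a negative; the sketch keeps the loss and the negative generation separate to stay self-contained.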
Fig. 4. Uniformity and alignment properties of the embedded contexts in the AntDir environment under different training strategies.
Table 1 summarizes the test results in three environments, where all policies are tested with low-quality contexts. FOCAL and CORRO struggle to generalize to low-quality contexts at test time. We believe this is largely because the training buffers consist mainly of high-quality data, so the baseline methods fail to capture the information contained in low-quality context data. Compared with these baselines, our hard sampling strategy achieves better overall performance in these environments. In two other environments, CheetahVel and HopperParams, our hard sampling method performs comparably to CORRO: L_HP+HG achieves −67.5 and 182.3, respectively, whereas CORRO achieves −65.3 and 180.7. We attribute this to the lower diversity of contexts in these environments. For instance, in the CheetahVel environment, tasks are defined by different target running speeds, and the robot is expected to move forward in every task. Consequently, the trajectory data are less diverse than in other environments such as AntDir. The baseline methods therefore suffer less from distribution shift in these environments, and the advantage of hard sampling disappears. For more details on the experimental results, refer to the appendix of [34].
5.4 Further Analysis
Sampling Strategies. Hard sampling is a crucial component of our method for addressing the data imbalance challenge in OMRL. To examine how hard positive and hard negative sampling affect the test performance of the learned policies, we used the same evaluation protocol to compare several sampling strategies: SCL (no hard sampling), HG (hard negative sampling only), HP (hard positive sampling only, where all negative samples are weighted equally, i.e., $Z_{neg} = \sum_{a \in N(\tau_q)} \exp(z_{\tau_q} \cdot z_a / \beta)$), and HP+HG. The test performances are presented in Table 1. Using L_HP+HG leads to the best performance. In addition, using HP or HG alone outperforms the SCL baseline when tested with low-quality contexts, which suggests that both hard positive and hard negative sampling help in learning robust context encoders.
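The uniform denominator $Z_{neg}$ used by the HP variant can be computed directly, as in `uniform_negative_term` below. For contrast, `hard_negative_term` shows one common way to emphasize hard negatives, in the spirit of Robinson et al. [24]: each negative receives an importance weight proportional to $\exp(\text{hardness} \cdot z_{\tau_q} \cdot z_a / \beta)$, normalized so that the weights average to 1. That weighting scheme and the `hardness` parameter are assumptions for illustration; the exact HG weighting is not reproduced in this section.

```python
import math
from typing import Sequence


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def uniform_negative_term(z_q: Sequence[float],
                          negatives: Sequence[Sequence[float]], beta: float) -> float:
    """Z_neg = sum over a in N(tau_q) of exp(z_q . z_a / beta); all negatives weighted equally."""
    return sum(math.exp(dot(z_q, z_a) / beta) for z_a in negatives)


def hard_negative_term(z_q: Sequence[float], negatives: Sequence[Sequence[float]],
                       beta: float, hardness: float) -> float:
    """Illustrative hardness weighting: weights grow with similarity to the query and
    are rescaled to average 1, so hardness = 0 recovers uniform_negative_term."""
    sims = [dot(z_q, z_a) / beta for z_a in negatives]
    raw = [math.exp(hardness * s) for s in sims]
    scale = len(raw) / sum(raw)
    return sum(scale * w * math.exp(s) for w, s in zip(raw, sims))
```

Because the weights are normalized, a larger `hardness` only shifts mass toward negatives that already sit close to the query, which is exactly the set of samples a hard-negative objective is meant to push away.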
Uniformity and Alignment. Uniformity and alignment are considered two vital properties of good representations in contrastive learning [11, 27, 29]. Uniformity describes how evenly the samples are spread over the representation space, while alignment refers to how close semantically similar samples are. Poor uniformity may indicate a collapse of the model, whereas poor alignment implies a loss of semantic information [27]. Wang et al. [29] found that hard negative mining methods help in learning embeddings with both low alignment loss and low uniformity loss. Ideally, we aim to achieve embeddings with both low alignment loss and low uniformity loss. Formally, given a batch of samples, these two properties can be quantified as follows:
$$\mathcal{L}_{\text{uniformity}} = \log \mathbb{E}_{\tau, \tau' \sim \mathcal{D}} \left[ \exp\left( -t \, \lVert z_{\tau} - z_{\tau'} \rVert_2^2 \right) \right], \quad (9)$$
$$\mathcal{L}_{\text{alignment}} = \mathbb{E}_{\tau \sim \mathcal{D}} \, \mathbb{E}_{\tau_1, \tau_2 \sim \mathcal{A}(\tau)} \left[ \lVert z_{\tau_1} - z_{\tau_2} \rVert_2^2 \right], \quad (10)$$
where $\tau_1, \tau_2$ are a pair of samples augmented from the same sample $\tau$, and $t > 0$ is a fixed hyperparameter. We compare the uniformity and alignment of hard sampling with those of the other variants in the AntDir task. As shown in Fig. 4, hard sampling enables the agent to attain a lower alignment loss, while the uniformity loss remains comparable across all runs.
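Eqs. (9) and (10) can be estimated on a batch as sketched below. The sketch assumes the embeddings are already L2-normalized, as is usual for these metrics following Wang and Isola [29]; it estimates the expectation in Eq. (9) over all distinct pairs in the batch and uses t = 2 as a default. These choices are assumptions rather than details stated in this section.

```python
import math
from itertools import combinations
from typing import Sequence, Tuple

Vector = Sequence[float]


def sq_dist(u: Vector, v: Vector) -> float:
    """Squared Euclidean distance ||u - v||_2^2."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def alignment_loss(pairs: Sequence[Tuple[Vector, Vector]]) -> float:
    """Eq. (10): mean squared distance between embeddings of two views of the same sample."""
    return sum(sq_dist(z1, z2) for z1, z2 in pairs) / len(pairs)


def uniformity_loss(embeddings: Sequence[Vector], t: float = 2.0) -> float:
    """Eq. (9): log of the mean Gaussian potential exp(-t * ||z - z'||^2) over distinct pairs."""
    potentials = [math.exp(-t * sq_dist(u, v)) for u, v in combinations(embeddings, 2)]
    return math.log(sum(potentials) / len(potentials))
```

A fully collapsed encoder, where every embedding is identical, yields the maximum uniformity loss of 0, which is why a low (more negative) uniformity loss is desirable.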
6 Conclusions
In this paper, we have highlighted that the context distribution shift problem is likely to occur during the task representation learning phase of offline meta-reinforcement learning (OMRL). To address this issue, we have proposed a novel technique that combines a hard sampling strategy with supervised contrastive learning in the context of OMRL. Our experimental results on several continuous control tasks demonstrate that, under context distribution shift, our approach can yield more robust context encoders and significantly better test performance in terms of accumulated returns than the baseline methods. We have open-sourced our code at https://github.com/ZJLAB-AMMI/HS-OMRL to facilitate future research on robust OMRL.
Acknowledgement. This work was supported by Exploratory Research Project (No. 2022RC0AN02) of Zhejiang Lab.
References
1. Bengio, Y., Bengio, S., Cloutier, J.: Learning a synaptic learning rule. Citeseer (1990)
2. Berner, C., et al.: Dota 2 with large scale deep reinforcement learning. arXiv preprint arXiv:1912.06680 (2019)
3. Chen, T., Kornblith, S., Norouzi, M., Hinton, G.: A simple framework for contrastive learning of visual representations. In: International Conference on Machine Learning, pp. 1597–1607. PMLR (2020)
4. Dorfman, R., Shenfeld, I., Tamar, A.: Offline meta reinforcement learning – identifiability challenges and effective data collection strategies. Adv. Neural Inf. Process. Syst. 34, 4607–4618 (2021)
5. Duan, Y., Schulman, J., Chen, X., Bartlett, P.L., Sutskever, I., Abbeel, P.: RL²: fast reinforcement learning via slow reinforcement learning. arXiv preprint arXiv:1611.02779 (2016)
6. Finn, C., Abbeel, P., Levine, S.: Model-agnostic meta-learning for fast adaptation of deep networks. In: International Conference on Machine Learning, pp. 1126–1135. PMLR (2017)
7. Gottesman, O., et al.: Guidelines for reinforcement learning in healthcare. Nat. Med. 25(1), 16–18 (2019)
8. Haarnoja, T., Zhou, A., Abbeel, P., Levine, S.: Soft actor-critic: off-policy maximum entropy deep reinforcement learning with a stochastic actor. In: International Conference on Machine Learning, pp. 1861–1870. PMLR (2018)
9. He, K., Fan, H., Wu, Y., Xie, S., Girshick, R.: Momentum contrast for unsupervised visual representation learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9729–9738 (2020)
10. Hospedales, T., Antoniou, A., Micaelli, P., Storkey, A.: Meta-learning in neural networks: a survey. IEEE Trans. Pattern Anal. Mach. Intell. 44(9), 5149–5169 (2021)
11. Huang, W., Yi, M., Zhao, X.: Towards the generalization of contrastive self-supervised learning. arXiv preprint arXiv:2111.00743 (2021)
12. Humplik, J., Galashov, A., Hasenclever, L., Ortega, P.A., Teh, Y.W., Heess, N.: Meta reinforcement learning as task inference. arXiv preprint arXiv:1905.06424 (2019)
13. Khosla, P., et al.: Supervised contrastive learning. In: Advances in Neural Information Processing Systems 33, pp. 18661–18673 (2020)
14. Kostrikov, I., Nair, A., Levine, S.: Offline reinforcement learning with implicit Q-learning. arXiv preprint arXiv:2110.06169 (2021)
15. Levine, S., Kumar, A., Tucker, G., Fu, J.: Offline reinforcement learning: tutorial, review, and perspectives on open problems. arXiv preprint arXiv:2005.01643 (2020)
16. Li, L., Huang, Y., Chen, M., Luo, S., Luo, D., Huang, J.: Provably improved context-based offline meta-RL with attention and contrastive learning. arXiv preprint arXiv:2102.10774 (2021)
17. Li, L., Yang, R., Luo, D.: FOCAL: efficient fully-offline meta-reinforcement learning via distance metric learning and behavior regularization. arXiv preprint arXiv:2010.01112 (2020)
18. Mitchell, E., Rafailov, R., Peng, X.B., Levine, S., Finn, C.: Offline meta-reinforcement learning with advantage weighting. In: International Conference on Machine Learning, pp. 7780–7791. PMLR (2021)
19. Mnih, V., et al.: Human-level control through deep reinforcement learning. Nature 518(7540), 529–533 (2015)
20. Oord, A.v.d., Li, Y., Vinyals, O.: Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748 (2018)
21. Peng, X.B., Kumar, A., Zhang, G., Levine, S.: Advantage-weighted regression: simple and scalable off-policy reinforcement learning. arXiv preprint arXiv:1910.00177 (2019)
22. Prudencio, R.F., Maximo, M.R., Colombini, E.L.: A survey on offline reinforcement learning: taxonomy, review, and open problems. arXiv preprint arXiv:2203.01387 (2022)
23. Rakelly, K., Zhou, A., Finn, C., Levine, S., Quillen, D.: Efficient off-policy meta-reinforcement learning via probabilistic context variables. In: International Conference on Machine Learning, pp. 5331–5340. PMLR (2019)
24. Robinson, J., Chuang, C.Y., Sra, S., Jegelka, S.: Contrastive learning with hard negative samples. arXiv preprint arXiv:2010.04592 (2020)
25. Schmidhuber, J.: Evolutionary principles in self-referential learning, or on learning how to learn: the meta-meta-... hook. Ph.D. thesis, Technische Universität München (1987)
26. Silver, D., et al.: Mastering chess and shogi by self-play with a general reinforcement learning algorithm. arXiv preprint arXiv:1712.01815 (2017)
27. Wang, F., Liu, H.: Understanding the behaviour of contrastive loss. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2495–2504 (2021)
28. Wang, J.X., et al.: Learning to reinforcement learn. arXiv preprint arXiv:1611.05763 (2016)
29. Wang, T., Isola, P.: Understanding contrastive representation learning through alignment and uniformity on the hypersphere. In: International Conference on Machine Learning, pp. 9929–9939. PMLR (2020)
30. Wu, Y., Tucker, G., Nachum, O.: Behavior regularized offline reinforcement learning. arXiv preprint arXiv:1911.11361 (2019)
31. Yuan, H., Lu, Z.: Robust task representations for offline meta-reinforcement learning via contrastive learning. In: International Conference on Machine Learning, pp. 25747–25759. PMLR (2022)
32. Zhou, M., et al.: SMARTS: scalable multi-agent reinforcement learning training school for autonomous driving. arXiv preprint arXiv:2010.09776 (2020)
33. Zintgraf, L., et al.: VariBAD: a very good method for Bayes-adaptive deep RL via meta-learning. arXiv preprint arXiv:1910.08348 (2019)
34. Zhao, C., Zhou, Z., Liu, B.: On context distribution shift in task representation learning for offline meta RL. arXiv preprint arXiv:2304.00354 (2023)
Dynamic Ensemble Selection with Reinforcement Learning
Lihua Liu (1,2), Jibing Wu (1,2), Xuan Li (1,2), and Hongbin Huang (1,2, corresponding author)
1 National University of Defense Technology, Changsha 410000, Hunan, China
2 Laboratory for Big Data and Decision, Changsha 410000, Hunan, China
Abstract. In this work, we propose a novel approach to ensemble learning, referred to as Dynamic Ensemble Selection using Reinforcement Learning. Traditional ensemble learning methods rely on static combinations of base models, which may not be optimal for diverse inputs and contexts. Our method addresses this limitation by using reinforcement learning to dynamically select the most appropriate ensemble member based on the current input and context. We formulate ensemble member selection as a Markov Decision Process and employ Q-learning to learn a selection policy. The learned policy is then used to adaptively choose the best ensemble member for a given input, potentially improving the overall performance of the ensemble learning system. Experiments on a range of datasets demonstrate the potential of the method for increased accuracy and robustness.
Keywords: Ensemble Selection · Reinforcement Learning · Ensemble Pruning
1 Introduction
Ensemble learning, in which a group of learners works together as a committee, has attracted considerable research attention because of its ability to improve generalization performance [1]. It has become a popular topic in machine learning because of its simplicity and effectiveness. The concept of ensembling originated from the work of Hansen and Salamon, who demonstrated that the generalization ability of a neural network could be significantly improved by combining multiple neural networks [1]. Ensembles have been applied successfully in various domains, including face recognition [2], character recognition [3], and image analysis [4]. There are many ensemble learning algorithms, such as Bagging [5], Boosting [6], random forest [7], negative correlation learning [8, 9], and evolutionary computation-based algorithms [10, 11]. Existing ensemble learning algorithms often produce excessively large ensembles that require significant amounts of memory, and generating a prediction for a new data point with such an ensemble can be costly. While these additional costs may be negligible for small datasets, they can become substantial when ensemble methods are applied at a larger scale.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 629–640, 2023. https://doi.org/10.1007/978-981-99-4761-4_53
L. Liu et al.
Furthermore, the notion that larger ensembles are always superior is not necessarily accurate. Both theoretical and empirical evidence suggest that small ensembles can outperform larger ones [12, 13]. For instance, boosting ensembles, such as AdaBoost [6] and Arcing [14], place greater emphasis on training samples that were misclassified by previous learners, which reduces the training error. However, boosting ensembles are susceptible to overfitting the noise in the training data [15, 16]. In such situations, it is essential to eliminate some of the overfitting members to achieve good generalization performance. Over the past few decades, various ensemble pruning algorithms have been proposed, including Kappa pruning [17], concurrency pruning [18], and statistical learning-based methods [19–22]. As an alternative, Yao et al. [12] used a global optimization method, the genetic algorithm (GA), to assign weights to ensemble members while constraining the weights to be positive, and Zhou et al. [13] demonstrated that smaller ensembles can often outperform larger ones. A similar GA-based method is described in [23]. However, these GA-based approaches search for optimal weight combinations by minimizing the training error, which makes them vulnerable to noise. To address this limitation, we propose an approach that employs reinforcement learning (RL) to dynamically select the most suitable ensemble member based on the feedback received from the environment. By leveraging the adaptability and decision-making capabilities of RL, this method aims to navigate the ensemble space intelligently and to adjust its selection strategy to achieve better performance, particularly in non-stationary environments or on previously unseen data. In this paper, we introduce a new approach to ensemble learning that challenges the conventional static combination of base models.
The proposed method, grounded in reinforcement learning, dynamically selects the best-suited ensemble member depending on the current input and context. We formulate the selection problem as a Markov Decision Process and use Q-learning to learn a selection policy. This policy, in turn, enables the adaptive choice of an ensemble member for each input, with the aim of improving the overall performance of the ensemble learning system. The novelty of this reinforcement learning-based dynamic ensemble selection method lies in its ability to adapt to changing environments and data distributions and to recognize the strengths and weaknesses of individual ensemble members. Unlike traditional ensemble methods that use static or predefined strategies, this approach continuously learns and refines its decision-making process through interaction with the environment. By treating ensemble member selection as a sequential decision-making task, the method aims to capture the underlying structure and dependencies within the data, resulting in more accurate and robust predictions. Furthermore, the proposed approach is flexible and general, making it applicable to a wide range of machine learning tasks and ensemble configurations.
The rest of this paper is organized as follows. Section 2 gives a concise overview of the relevant research background, and Sect. 2.2 presents the proposed method. Experimental results and their analysis are presented in Sect. 3. Section 4 concludes the paper, summarizing the findings and proposing potential avenues for future research.
2 Background
The primary aim of ensemble pruning is to reduce the number of learners in an ensemble while maintaining or improving its overall performance. The approach used to prune a set of learners is crucial, since it can determine the effectiveness of the entire system. There are two main types of ensemble pruning algorithms, selection-based and weight-based, which we discuss in turn below.
Selection-based algorithms do not assign weight coefficients to learners; instead, they select or reject each learner. A straightforward selection-based method ranks individual learners by their performance on a validation set and chooses the top performers. Nonetheless, this approach may not yield optimal results. For example, an ensemble of three identical classifiers, each with 95% accuracy, can be less effective than an ensemble of three classifiers with 67% accuracy whose errors are less correlated: if each of the latter errs on a different third of the samples, at most one member is wrong on any sample, so a majority vote is always correct. Kappa pruning [17] seeks to increase the pairwise differences among the selected ensemble members. Building on this, Prodromidis et al. designed a variety of pruning algorithms for their distributed data mining framework [24].
Weight-based ensemble optimization takes a broader approach, seeking to improve the ensemble's generalization performance by adjusting the weight assigned to each ensemble member. For regression ensembles, the optimal combination weights can be determined analytically [13, 25–27].
This topic has also been explored in other fields, such as financial forecasting [28] and operations research [29]. Despite its potential, this approach often falters in real-world applications, primarily because the presence of multiple estimators with similar performance levels can make the correlation matrix C ill-conditioned and hinder least squares estimation. Additional challenges with this formulation include (1) computing the optimal combination weights from the training set, which frequently leads to overfitting noise, and (2) the inability to reduce the ensemble size in most cases. A numerically stable algorithm for determining the combination weights is the least squares formulation, and in this paper we use least squares (LS) pruning as a baseline algorithm in our experiments. LS pruning can be applied to binary classification problems by transforming the classification task into a regression problem with targets of −1 or +1. However, LS pruning often produces negative combination weights, and strategies that permit negative combination weights are considered unreliable [30, 31]. To avoid negative weights, Yao et al. [12] suggested using a genetic algorithm (GA) to assign weights to ensemble members while keeping the weights positive. Subsequently, Zhou et al. [13] demonstrated that small ensembles can outperform large ones. A similar genetic algorithm approach is presented in [23]. Nonetheless,
GA-based algorithms tend to be sensitive to noise, as they attempt to achieve optimal combination weights by minimizing the training error. Demiriz et al. [32] later employed mathematical programming to search for effective weighting schemes. Empirically, these optimization approaches succeed in improving performance and can sometimes substantially reduce the ensemble size [32]. However, ensemble size reduction is not explicitly incorporated into these programs, and the resulting ensemble can still be quite large in some cases.
2.1 Reinforcement Learning
Reinforcement learning (RL) is a subfield of machine learning that focuses on learning optimal actions to take in a given environment to achieve a goal. It has been applied to various domains, such as robotics, game playing, natural language processing, and recommendation systems. In this section, we review some related work on reinforcement learning.
Model-free RL algorithms, such as Q-learning [33] and SARSA [34], learn the action-value function directly from interactions with the environment. These algorithms have been widely used in applications such as robotic control [35] and game playing [36]. The advent of deep learning brought significant improvements to model-free RL, leading to the development of Deep Q-Networks (DQN) [37], which use deep neural networks as function approximators for the action-value function.
Model-based RL algorithms build a model of the environment and use this model to plan and make decisions. One popular model-based approach is Monte Carlo Tree Search (MCTS) [38], which has been successfully applied to games such as Go [39]. Another approach is to learn a differentiable model of the environment dynamics and use gradient-based optimization techniques to improve the policy; Model-Agnostic Meta-Learning (MAML) [40] likewise relies on gradient-based optimization.
Policy gradient methods optimize the policy directly by computing the gradient of the expected return with respect to the policy parameters. Widely used policy gradient algorithms include REINFORCE [41], Trust Region Policy Optimization (TRPO) [42], and Proximal Policy Optimization (PPO) [43]. These algorithms have been successfully applied to tasks such as robotic manipulation [44] and locomotion [45].
2.2 Dynamic Ensemble Selection Using Reinforcement Learning
The idea behind dynamic ensemble selection is to choose the best ensemble member for a specific input based on the members' past performance and the current context. Reinforcement learning can be used to learn this selection policy. Consider an ensemble of N base models, denoted $M_1, M_2, \ldots, M_N$. For any given input $x_t$, we want to select the most appropriate model $M_i$ based on the feedback from the environment. We model this problem as a Markov Decision Process (MDP) with the following components:
State Space S: The state $s_t$ represents the current context of the environment, which can be derived from the input $x_t$, the past performance of the ensemble members, and any additional contextual information.
Action Space A: The action $a_t$ corresponds to selecting one of the ensemble members $M_i$ for the current input $x_t$. Thus, the action space consists of N actions, one for each ensemble member.
Reward Function $R(s_t, a_t)$: The reward represents the performance of the selected ensemble member $M_i$ on the input $x_t$. It can be defined based on various criteria, such as accuracy or loss minimization.
Transition Function $T(s_t, a_t, s_{t+1})$: This function models the transition from state $s_t$ to state $s_{t+1}$ after taking action $a_t$. In our case, it can be deterministic, as the next state depends solely on the current input and the selected ensemble member's performance.
Discount Factor γ: This scalar parameter determines the importance of future rewards relative to immediate rewards. A value close to 1 places more emphasis on future rewards, while a value close to 0 prioritizes immediate rewards.
We can use a reinforcement learning algorithm, such as Q-learning or Proximal Policy Optimization (PPO), to learn a policy $\pi(s_t)$ that maps states to actions (i.e., ensemble member selections) based on the MDP defined above. For instance, with Q-learning we define a Q-function $Q(s_t, a_t)$ that represents the expected cumulative reward of selecting ensemble member $a_t$ in state $s_t$. The Q-function is updated iteratively using the following rule:
$$Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ R(s_t, a_t) + \gamma \max_{a_{t+1}} Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \right], \quad (1)$$
where α is the learning rate. The policy $\pi(s_t)$ can be derived from the Q-function as follows:
$$\pi(s_t) = \arg\max_{a_t} Q(s_t, a_t). \quad (2)$$
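As a worked instance of Eqs. (1) and (2), the tabular update and the greedy policy fit in a few lines of Python. The numbers in the example (α = 0.5, γ = 0.9, a reward of 1, next-state values [2.0, 0.5]) are arbitrary illustrative choices, and the function names are hypothetical.

```python
from typing import Sequence


def q_update(q_sa: float, reward: float, next_q_values: Sequence[float],
             alpha: float, gamma: float) -> float:
    """Eq. (1): move Q(s_t, a_t) toward the TD target r + gamma * max_a' Q(s_{t+1}, a')."""
    td_target = reward + gamma * max(next_q_values)
    return q_sa + alpha * (td_target - q_sa)


def greedy_action(q_values: Sequence[float]) -> int:
    """Eq. (2): index of the ensemble member with the highest Q-value (first one on ties)."""
    return max(range(len(q_values)), key=lambda a: q_values[a])


# 0 + 0.5 * (1 + 0.9 * 2.0 - 0) = 1.4
new_q = q_update(0.0, 1.0, [2.0, 0.5], alpha=0.5, gamma=0.9)
```

With α = 0.5 the estimate moves halfway from its old value (0) toward the TD target (2.8), which is where the value 1.4 comes from.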
Once the policy is learned, it can be used to dynamically select the most appropriate ensemble member for a given input based on the current context and the past performance of the ensemble members. This dynamic selection allows the ensemble to adapt to the specific input and context at hand, potentially improving overall performance compared with traditional static ensembles. To summarize, dynamic ensemble selection using reinforcement learning involves the following steps:
1. Model the ensemble selection problem as an MDP with state space S, action space A, reward function $R(s_t, a_t)$, transition function $T(s_t, a_t, s_{t+1})$, and discount factor γ.
2. Choose a reinforcement learning algorithm (e.g., Q-learning or PPO) to learn the selection policy $\pi(s_t)$ based on the MDP.
3. Train the policy with the chosen algorithm by iteratively updating the Q-function (or equivalent) and deriving the policy from the learned Q-function (or equivalent).
4. Apply the learned policy $\pi(s_t)$ to dynamically select the most appropriate ensemble member for a given input based on the current context and the past performance of the ensemble members.
The proposed dynamic ensemble selection method has the potential to improve the performance of ensemble learning systems by allowing them to adapt to different
contexts and inputs. This flexibility can lead to more accurate and robust predictions in a variety of applications, such as classification, regression, and reinforcement learning tasks. Future research could focus on developing more advanced state representations that better capture the context and past performance of the ensemble members. In addition, other RL algorithms could be explored to optimize the policy learning process, and more sophisticated reward functions could be designed to encourage the selection of diverse and complementary ensemble members.
Algorithm 1 gives a formal description of the dynamic ensemble selection method using reinforcement learning, specifically Q-learning, as described above. The policy is trained over multiple episodes, each consisting of a dataset sampled from the environment. The Q-function is updated iteratively based on the current state, action, reward, and next state. The algorithm returns the learned policy, which can be used to dynamically select the most appropriate ensemble member for a given input.
3 Experiments
This section reports the experiments conducted to demonstrate the efficacy of the proposed method. For comparison, three baseline algorithms, namely Kappa-based (KBP), Random-based (RBP), and Least Squares-based (LSBP) pruning, are implemented along with the proposed method in Python using the Sklearn library. KBP, RBP, and LSBP were chosen for the following reasons:
1. These methods are widely used and accepted in the field of ensemble learning, so they provide a fair and accepted baseline against which to evaluate new methods.
2. The three methods represent different approaches. KBP is based on the Kappa statistic, a measure of agreement between classifiers. RBP is a random baseline that indicates the minimum expected performance. LSBP is based on the method of least squares, which minimizes the squared error between the predictions and the true values. This diversity allows new methods to be evaluated from multiple perspectives.
3. All three methods can be evaluated on the same tasks and datasets, so their results can be compared directly.
The algorithms are compared using the accuracy measure on 10 datasets, which vary in the number of samples, features, and classes. The wine, digits, and breast cancer datasets are taken from the Sklearn library; the remaining datasets are retrieved from the UCI machine learning repository. Table 1 lists, for each dataset, the number of features, classes, and samples, and the sizes of the training and testing sets. In all experiments, 70% of the data is used for training and 30% for testing. An ensemble pool of 500 decision tree classifiers is constructed, and 50 members are chosen to participate in the final vote. All decision tree classifier parameters are left at their Sklearn defaults.
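The KBP baseline is named but not specified here. A minimal sketch of the classic Kappa pruning idea [17] could look as follows, under the assumption that members are compared by pairwise Cohen's kappa on a held-out set and kept greedily from the least-agreeing pairs first. The function names are hypothetical; `predictions` holds one list of predicted labels per ensemble member (e.g., the 500 decision trees above), and `target_size` would be 50 in this setup.

```python
from collections import Counter
from itertools import combinations
from typing import Hashable, List, Sequence


def cohen_kappa(pred_a: Sequence[Hashable], pred_b: Sequence[Hashable]) -> float:
    """Cohen's kappa between two classifiers' predictions on the same samples."""
    n = len(pred_a)
    observed = sum(a == b for a, b in zip(pred_a, pred_b)) / n
    count_a, count_b = Counter(pred_a), Counter(pred_b)
    expected = sum(count_a[label] * count_b[label] for label in count_a) / (n * n)
    if expected == 1.0:  # both classifiers constant and identical: full agreement
        return 1.0
    return (observed - expected) / (1.0 - expected)


def kappa_prune(predictions: Sequence[Sequence[Hashable]], target_size: int) -> List[int]:
    """Greedily keep members from the least-agreeing (lowest-kappa) pairs first."""
    pairs = sorted(
        combinations(range(len(predictions)), 2),
        key=lambda ij: cohen_kappa(predictions[ij[0]], predictions[ij[1]]),
    )
    selected: List[int] = []
    for i, j in pairs:
        for k in (i, j):
            if k not in selected and len(selected) < target_size:
                selected.append(k)
        if len(selected) >= target_size:
            break
    return selected
```

Ranking pairs by kappa rather than individual members by accuracy favors diversity, which is the point of the 95% versus 67% example in Sect. 2: two identical members have kappa 1 and are the last to be selected together.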
Dynamic Ensemble Selection with Reinforcement Learning
Algorithm 1 Dynamic Ensemble Selection using Q-learning
Require: ensemble members $M_1, M_2, \dots, M_N$; learning rate $\alpha$; discount factor $\gamma$; number of training episodes $num_{episodes}$; state encoding function $f_{state}$; reward function $R$.
Initialize the Q-function $Q(s, a)$ for all $s \in S$ and $a \in A$
for episode = 1, 2, ..., $num_{episodes}$ do
    Sample a dataset $D = \{(x_1, y_1), (x_2, y_2), \dots, (x_T, y_T)\}$ from the environment
    for t = 1, 2, ..., T do
        Compute the state $s_t = f_{state}(x_t)$
        Choose the action $a_t = \arg\max_a Q(s_t, a)$ with probability $1 - \epsilon$, or a random action with probability $\epsilon$
        Predict the output $\hat{y}_t$ using the selected ensemble member $M_{a_t}$
        Calculate the reward $r_t = R(y_t, \hat{y}_t)$
        Compute the next state $s_{t+1} = f_{state}(x_{t+1})$
        Update the Q-function: $Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ r_t + \gamma \max_{a_{t+1}} Q(s_{t+1}, a_{t+1}) - Q(s_t, a_t) \right]$
    end for
end for
return the policy $\pi(s_t) = \arg\max_{a_t} Q(s_t, a_t)$
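Algorithm 1 can be prototyped in a few lines. The sketch below is a minimal, hypothetical tabular implementation in pure Python, not the authors' code: ensemble members are plain callables, `f_state` must return a hashable (discretized) state, and the reward function, ε-greedy defaults and all names are our assumptions. The last sample of an episode is treated as terminal because $x_{T+1}$ does not exist.

```python
import random
from collections import defaultdict
from typing import Callable, Dict, Hashable, List, Sequence, Tuple

Sample = Tuple[object, object]  # (input x, true label y)


def train_des_policy(
    members: Sequence[Callable[[object], object]],
    episodes: Sequence[Sequence[Sample]],
    f_state: Callable[[object], Hashable],
    reward: Callable[[object, object], float],
    alpha: float = 0.1,
    gamma: float = 0.9,
    epsilon: float = 0.1,
    seed: int = 0,
):
    """Tabular Q-learning over (state, member index) pairs, following Algorithm 1."""
    rng = random.Random(seed)
    n = len(members)
    q: Dict[Hashable, List[float]] = defaultdict(lambda: [0.0] * n)

    def greedy(s: Hashable) -> int:
        values = q[s]
        return max(range(n), key=values.__getitem__)

    for episode in episodes:
        for t, (x, y) in enumerate(episode):
            s = f_state(x)
            # Epsilon-greedy choice of the ensemble member to use for x.
            a = rng.randrange(n) if rng.random() < epsilon else greedy(s)
            r = reward(y, members[a](x))
            if t + 1 < len(episode):
                s_next = f_state(episode[t + 1][0])
                target = r + gamma * max(q[s_next])
            else:
                target = r  # last sample of the episode: no bootstrap term
            q[s][a] += alpha * (target - q[s][a])

    def policy(x: object) -> int:
        """Index of the ensemble member to use for input x."""
        return greedy(f_state(x))

    return policy, q
```

For example, with two members that always predict 1 and 0 respectively, `f_state = lambda x: x > 0` and a 0/1 correctness reward, the learned policy sends positive inputs to the first member and negative inputs to the second. Because inputs are sampled independently, the chosen member does not influence the next state, so in expectation the bootstrap term adds the same offset to every action; with γ = 0 the update reduces to a contextual-bandit estimate of each member's expected reward.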
Table 1. Datasets Information
Dataset | Features | Classes | Samples | Training Size | Testing Size
wine | 13 | 3 | 178 | 125 | 53
digits | 64 | 10 | 1797 | 1258 | 539
breast cancer | 30 | 2 | 569 | 398 | 171
seeds | 7 | 3 | 210 | 147 | 63
maternal health | 6 | 3 | 1014 | 710 | 304
phishing | 10 | 3 | 1352 | 946 | 406
bank notes | 4 | 2 | 1372 | 960 | 412
raisin | 7 | 2 | 900 | 630 | 270
wifi localization | 7 | 4 | 2000 | 1400 | 600
spambase | 57 | 2 | 4601 | 3221 | 1380
Table 2 reports the mean and standard deviation of accuracy over 10 runs on each dataset; the number in parentheses is the rank of each algorithm on that dataset. The algorithms' performance varies across datasets.
L. Liu et al.
Table 2. Performance Comparison among Algorithms on the Accuracy Measure (%)
Dataset | KBP | RBP | LSBP | Proposed Method
wine | 92.59 ± 2.62 (3) | 92.04 ± 2.15 (4) | 92.78 ± 4.04 (2) | 95.00 ± 3.71 (1)
digits | 88.30 ± 2.31 (2) | 87.02 ± 1.32 (3) | 86.98 ± 0.68 (4) | 93.81 ± 1.80 (1)
breast cancer | 94.68 ± 1.12 (2) | 93.45 ± 1.87 (4) | 93.63 ± 1.44 (3) | 96.43 ± 1.75 (1)
seeds | 93.33 ± 2.78 (1) | 93.02 ± 3.1 (3) | 93.18 ± 4.49 (2) | 91.75 ± 5.12 (4)
maternal health | 83.42 ± 1.26 (1) | 81.84 ± 2.82 (2) | 81.74 ± 1.47 (3) | 79.47 ± 2.44 (4)
phishing | 87.54 ± 1.74 (3) | 86.75 ± 1.08 (4) | 88.77 ± 1.87 (2) | 89.78 ± 0.9 (1)
bank notes | 98.28 ± 0.76 (3) | 98.28 ± 0.61 (3) | 98.40 ± 0.68 (2) | 98.62 ± 0.68 (1)
raisin | 83.48 ± 2.52 (3) | 81.04 ± 2.39 (4) | 94.27 ± 1.65 (1) | 85.82 ± 1.21 (2)
wifi localization | 97.87 ± 0.61 (1) | 97.08 ± 0.43 (2) | 94.21 ± 1.33 (3) | 97.87 ± 0.62 (1)
spambase | 92.20 ± 0.54 (2) | 91.44 ± 1.06 (4) | 92.18 ± 0.55 (3) | 93.51 ± 0.51 (1)
On the wine dataset, the proposed method achieves the best accuracy (95.00%), whereas the other algorithms remain around 92%. On the digits dataset, the proposed algorithm outperforms the other algorithms with 93.81% accuracy, whereas the other algorithms report relatively lower accuracy […]
$$x_c^0 = \begin{cases} x_g \,\Vert\, x^L_{\pi_{t-1}} \,\Vert\, c^i_{t-1}, & t > 0 \\ x_g \,\Vert\, x^L_0 \,\Vert\, c^i_0, & t = 0 \end{cases} \tag{9}$$
where $\Vert$ denotes the concatenation operation and $c^i_{t-1}$ represents the remaining capacity of vehicle $i$ at time $t-1$. When $t = 0$, the vehicle is empty and at the depot. The context $x_c^0$ is used to compute a single query $q_c$ in the edge-embedded multi-head attention (E-MHA). The keys and values are computed from the encoder outputs $x_j^L$, $j \in \{1, \dots, m\}$, and the edge embeddings $e_{ij}^l$ are computed from the initial distance matrix $\hat{e}_{ij}$:
$$q_c = Q^c x_c^0, \quad k_j = K^c x_j^L, \quad v_j = V^c x_j^L, \quad e_{ij}^l = W_{edge}\, \hat{e}_{ij}, \tag{10}$$
where $Q^c$, $K^c$, $V^c$ and $W_{edge}$ are learnable weight parameters.
Reinforcement Learning for Routing Problems with Hybrid Edge-Embedded Networks
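Next, the attention coefficients of the decoder at time $t$ are computed from the query $q_c$ and the keys $k = \{k_1, \dots, k_m\}$ as in Eq. (11); nodes that have already been visited receive $-\infty$:
$$u_{i,t} = \begin{cases} \dfrac{q_c^{\top} k_i}{\sqrt{d}}, & i \neq \pi_{t'} \ \ \forall t' < t \\ -\infty, & \text{otherwise} \end{cases}, \qquad \hat{u}_{i,j} = \operatorname{softmax}\left(u_{i,t} \cdot e_{ij}^{l}\right) \tag{11}$$
The context embedding $x_c^1$ is then used to compute the attention coefficients $u^1_{i,t}$ of the single-head attention (SHA) in Eq. (12), similar to Bello et al. [15]. The compatibilities are clipped to the range $[-C, C]$ with tanh, where $C = 10$, and then passed through a softmax to obtain the probability of each node:
$$x_c^1 = W_f \sum_{j \in N_i} \hat{u}_{i,j} v_j, \qquad u^1_{i,t} = \begin{cases} C \cdot \tanh\left(\dfrac{(x_c^1)^{\top} k_i}{\sqrt{d_k}}\right), & i \neq \hat{\pi}_{t'} \ \ \forall t' < t \\ -\infty, & \text{otherwise} \end{cases}$$
$$P_{i,t}(y_t \mid x, y_1, \dots, y_{t-1}) = p_\theta(\hat{\pi}_t \mid s, \hat{\pi}_{t'}, \forall t' < t) = \operatorname{softmax}\left(u^1_{i,t} \cdot e_{ij}^{l}\right) \tag{12}$$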
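To make Eqs. (10) to (12) concrete, the sketch below computes one decoding step in pure Python. It is a simplified, hypothetical illustration rather than the authors' implementation: it uses a single attention head, treats the edge term $e_{ij}^l$ as one scalar per candidate node (in the model it comes from $W_{edge}$ and the distance matrix), and passes weights as plain nested lists. The helper `cvrp_mask` encodes the CVRP masking rule (served customers are masked, and the depot may be revisited but not on consecutive steps); also masking customers whose demand exceeds the remaining capacity is our assumption, the usual CVRP feasibility check.

```python
import math
from typing import List, Sequence

Vector = List[float]
Matrix = List[List[float]]


def matvec(w: Matrix, x: Sequence[float]) -> Vector:
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


def masked_softmax(scores: Sequence[float], mask: Sequence[bool]) -> Vector:
    """Softmax in which masked nodes get score -inf, i.e. probability 0."""
    top = max(s for s, m in zip(scores, mask) if not m)
    exps = [0.0 if m else math.exp(s - top) for s, m in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


def cvrp_mask(visited: Sequence[bool], demands: Sequence[float],
              capacity_left: float, last_node: int) -> List[bool]:
    """True = node cannot be chosen next. Node 0 is the depot."""
    mask = [visited[i] or demands[i] > capacity_left for i in range(len(visited))]
    mask[0] = last_node == 0  # depot: revisits allowed, consecutive visits not
    return mask


def decode_step(x_c0, node_emb, edge_row, mask, Qc, Kc, Vc, Wf, C=10.0):
    """Node probabilities for one step: E-MHA glimpse, then clipped SHA."""
    q = matvec(Qc, x_c0)                      # q_c = Q^c x_c^0
    keys = [matvec(Kc, h) for h in node_emb]  # k_j = K^c x_j^L
    vals = [matvec(Vc, h) for h in node_emb]  # v_j = V^c x_j^L
    d = len(q)
    # Eq. (11): scaled compatibility times the edge term, masked softmax.
    u = [dot(q, k) / math.sqrt(d) * e for k, e in zip(keys, edge_row)]
    u_hat = masked_softmax(u, mask)
    # Glimpse: x_c^1 = W_f * sum_j u_hat_j * v_j
    glimpse = [sum(a * v[r] for a, v in zip(u_hat, vals)) for r in range(len(vals[0]))]
    x_c1 = matvec(Wf, glimpse)
    # Eq. (12): tanh clipping to [-C, C], edge term, masked softmax.
    logits = [C * math.tanh(dot(x_c1, k) / math.sqrt(d)) * e
              for k, e in zip(keys, edge_row)]
    return masked_softmax(logits, mask)


def greedy_next(probs: Sequence[float]) -> int:
    """Greedy decoding: index of the most probable node."""
    return max(range(len(probs)), key=lambda i: probs[i])
```

In sampling mode one would instead draw the next node from the returned distribution; the mask is recomputed after every step.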
This section discusses the decoding strategy for predicting the next node to visit from the probability distribution $P_{i,t}$. Either sampling or greedy decoding can be used; in greedy mode, the node with the highest probability is selected at each step, producing a trajectory. Once a node is selected, it is masked so that it cannot be selected at subsequent steps. In CVRP, however, the depot may be visited again after one time step, but it cannot be visited on consecutive steps.
(13)
Because visited nodes are masked, the effective graph structure changes substantially during decoding. To reflect this, an Embedding Glimpse Layer updates the node embeddings every p steps. Multiple decoders are used, and during training a regularization term, the Kullback-Leibler (KL) divergence between pairs of the decoders' output probability distributions, is maximized. This encourages the decoders to learn diverse construction patterns and generate distinct solutions.
2.3 Reinforcement Learning
The model architecture is depicted in Fig. 1. The encoder consists of E-GAT layers (Fig. 2) and E-MHA layers (Fig. 3), as described above. For each input instance, each decoder independently generates a trajectory $\pi_{d_i}$ and computes a separate REINFORCE loss using the same greedy roll-out baseline. The loss is based on the tour length $L(\pi_{d_i})$, as shown in Eq. (14), where $L(\pi_{d_i})$ denotes the length of the route:
(14)
To compute the baseline, we follow the best-performing variant in Kool et al. (2018) [14]. Specifically, we utilize the model with the best set of parameters from previous epochs
X. Ke et al.
as the baseline model and decode greedily to obtain the baseline $b(x)$, computed as in Eq. (15):
$$b(x) = \min_{d_i} L\left(\pi_{d_i} = \{y_{d_i}(1), \dots, y_{d_i}(T)\}\right), \qquad y_{d_i}(t) = \arg\max_{y_t} P^{d_i}_{\theta'}\left(y_t \mid x, y_{d_i}(1), \dots, y_{d_i}(t-1)\right) \tag{15}$$
where $d_i$ is the index of the decoder, $\theta$ denotes the current parameters, and $\theta'$ denotes the fixed parameters of the baseline model from previous training epochs. The model is optimized by gradient descent, as shown in Eq. (16):
(16)
where $k_{KL}$ weights the KL loss. Ideally, the KL loss would be computed for every state encountered by every decoder; to avoid this expensive computation, we impose the KL loss only on the first step.
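The bodies of Eqs. (14) and (16) did not survive extraction, so the sketch below assumes the REINFORCE-with-baseline form of AM [14], extended with the KL term described here: each decoder $d_i$ contributes $(L(\pi_{d_i}) - b(x)) \log p_\theta(\pi_{d_i} \mid x)$, with $b(x)$ the shortest greedy tour of the frozen baseline model from Eq. (15), and $k_{KL}$ times the first-step KL divergence between decoders is subtracted. It only computes the scalar objective for one instance, with each decoder's tour log-probability passed in as a number; in practice an autodiff framework (the paper uses PyTorch) would differentiate it. The default `k_kl = 0.01`, the use of all ordered decoder pairs, and the function names are hypothetical.

```python
import math
from itertools import permutations
from typing import Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-12) -> float:
    """KL(p || q) between two discrete distributions over the same nodes."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def rollout_baseline(baseline_tour_lengths: Sequence[float]) -> float:
    """b(x) in Eq. (15): shortest greedy tour over the frozen baseline's decoders."""
    return min(baseline_tour_lengths)


def surrogate_loss(tour_lengths: Sequence[float], log_probs: Sequence[float],
                   baseline_tour_lengths: Sequence[float],
                   first_step_probs: Sequence[Sequence[float]],
                   k_kl: float = 0.01) -> float:
    """Assumed objective; its gradient w.r.t. log_probs is the REINFORCE gradient."""
    b = rollout_baseline(baseline_tour_lengths)
    # One REINFORCE term per decoder d_i: (L(pi_{d_i}) - b(x)) * log p(pi_{d_i} | x).
    reinforce = sum((length - b) * lp for length, lp in zip(tour_lengths, log_probs))
    # Diversity term on the first decoding step only, over ordered decoder pairs.
    kl = sum(kl_divergence(p, q) for p, q in permutations(first_step_probs, 2))
    # Minimizing this objective maximizes the KL term, pushing decoders apart.
    return reinforce - k_kl * kl
```

Tours longer than the baseline get their log-probability pushed down and shorter ones pushed up, which is why the greedy rollout of a frozen, previously best model works as a variance-reducing baseline.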
3 Experiments
3.1 Experimental Dataset
Following the prior work of Kool et al. [14] and Nazari et al. [16], we generate TSP and CVRP instances. For TSP, we generate instances with 20, 50 and 100 nodes, whose coordinates are sampled at random in the unit square $[0, 1] \times [0, 1]$; inter-city distances are two-dimensional Euclidean distances. For CVRP, we generate instances with 21, 51 and 101 nodes, where the first node is the depot, and set the corresponding vehicle capacities to 30, 40 and 50. Customer demands lie in [1, 9], and both the demands and the vehicle capacity are normalized. In addition, an asymmetric distance matrix for the TSP and CVRP instances is generated by Eq. (17):
$$D = \begin{bmatrix} 0 & \cdots & d_{ij} \\ \vdots & \ddots & \vdots \\ d_{ji} & \cdots & 0 \end{bmatrix},\ d_{ij} = d_{ji}; \qquad D' = D + P = \begin{bmatrix} 0 & \cdots & d_{ij} + p_{ij} \\ \vdots & \ddots & \vdots \\ d_{ji} + p_{ji} & \cdots & 0 \end{bmatrix},\ d_{ij} = d_{ji},\ p_{ij} \neq p_{ji} \tag{17}$$
where $D$ is a symmetric distance matrix and $D'$ is an asymmetric distance matrix because of the asymmetric perturbation matrix $P$. In detail, the distance matrix $D$ is computed from the Euclidean distances between the node coordinates (Nodes), and a perturbation matrix $P$ is then added. $P$ is an $n \times n$ matrix, where $n$ denotes the number of cities. For each off-diagonal element $d_{i,j}$, $i \neq j$, a positive value $p_{i,j}$ is randomly generated within a certain range and added to the original distance $d_{i,j}$, yielding the perturbed distance $d'_{i,j}$. An example is given in Eq. (18).
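$$\text{Nodes} = \begin{bmatrix} 0.19 & 0.62 \\ 0.44 & 0.79 \\ 0.78 & 0.27 \\ 0.28 & 0.80 \end{bmatrix}, \quad D = \begin{bmatrix} 0 & 0.30 & 0.68 & 0.20 \\ 0.30 & 0 & 0.62 & 0.16 \\ 0.68 & 0.62 & 0 & 0.73 \\ 0.20 & 0.16 & 0.73 & 0 \end{bmatrix},$$
$$P = \begin{bmatrix} 0 & 0.06 & 0.05 & 0.02 \\ 0.01 & 0 & 0.01 & 0.01 \\ 0.05 & 0.06 & 0 & 0.03 \\ 0.09 & 0.06 & 0.05 & 0 \end{bmatrix}, \quad D' = D + P = \begin{bmatrix} 0 & 0.36 & 0.73 & 0.22 \\ 0.31 & 0 & 0.63 & 0.17 \\ 0.73 & 0.68 & 0 & 0.76 \\ 0.29 & 0.22 & 0.78 & 0 \end{bmatrix} \tag{18}$$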
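The instance generation of Section 3.1 and the perturbation of Eqs. (17) and (18) can be sketched as follows. Assumptions: customer demands are drawn as integers from 1 to 9 and divided by the vehicle capacity (one reading of "normalized"), and each $p_{ij}$ is drawn uniformly from [0, 0.2], the range reported for Table 3. Function names are hypothetical, and `math.dist` requires Python 3.8 or later.

```python
import math
import random
from typing import List, Sequence, Tuple

Point = Tuple[float, float]
Matrix = List[List[float]]

# Vehicle capacities quoted in the text for 20, 50 and 100 customers.
CAPACITY = {20: 30, 50: 40, 100: 50}


def random_nodes(n: int, rng: random.Random) -> List[Point]:
    """n nodes with coordinates uniform in the unit square [0, 1] x [0, 1]."""
    return [(rng.random(), rng.random()) for _ in range(n)]


def cvrp_demands(n_customers: int, rng: random.Random) -> List[float]:
    """Depot demand 0; customer demands in 1..9 divided by the capacity (assumed)."""
    capacity = CAPACITY[n_customers]
    return [0.0] + [rng.randint(1, 9) / capacity for _ in range(n_customers)]


def euclidean_matrix(nodes: Sequence[Point]) -> Matrix:
    """Symmetric D with d_ij = d_ji = ||node_i - node_j||_2."""
    return [[math.dist(a, b) for b in nodes] for a in nodes]


def asymmetric_matrix(nodes: Sequence[Point], rng: random.Random,
                      low: float = 0.0, high: float = 0.2) -> Matrix:
    """D' = D + P, where P has a zero diagonal and independent off-diagonal draws."""
    d = euclidean_matrix(nodes)
    n = len(nodes)
    p = [[0.0 if i == j else rng.uniform(low, high) for j in range(n)] for i in range(n)]
    return [[dij + pij for dij, pij in zip(drow, prow)] for drow, prow in zip(d, p)]
```

Because $p_{ij}$ and $p_{ji}$ are drawn independently, $D'$ is asymmetric almost surely. Applied to the four nodes of Eq. (18), `euclidean_matrix` gives $d_{01} \approx 0.302$ and $d_{13} \approx 0.160$, matching the corresponding entries of $D$.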
3.2 Experimental Results
We use a training set of 12,800 instances and a test set of 10,000 instances. Notably, our model outperforms the MDAM model [11] despite using a smaller training set. Because of memory limitations, training uses a batch size of 128, and the parameters are initialized from a previously trained model. All evaluations are conducted on a single RTX 2080 Ti GPU, and we compare our results with those reported in the original papers. The code is implemented in Python with PyTorch. For TSP, we use the exact solver Concorde [3] to obtain the baseline solutions; for the more challenging CVRP, we use the state-of-the-art heuristic solver LKH3 [7].
Table 1 presents the test results on TSP instances of different sizes, with the Concorde solver [3] as the baseline. Concorde obtains exact TSP solutions through mixed integer programming within a reasonable running time. Our solutions match or improve on those of AM (Kool et al. 2018) [14] and MDAM (Xin et al. 2020) [11] on all TSP test sets; for example, with beam size 50 our gaps on TSP20, TSP50 and TSP100 are 0.00%, 0.02% and 0.26%, compared with 0.00%, 0.03% and 0.38% for MDAM. As shown in Fig. 4, the validation route length for TSP20 converges to 3.84. TSP50 and TSP100 are initialized with the TSP20 parameters, and their validation route lengths converge to 5.71 and 7.89, respectively. On the test sets, greedy decoding yields 3.84, 5.72 and 7.91.
Table 2 shows the test results on CVRP instances of different sizes, with the LKH solver [7] as the baseline. In terms of solution quality, our model with beam size 50 outperforms the other learning-based algorithms listed. Table 3 shows that MDAM's performance degrades on asymmetric distance matrices and that the gap between MDAM and our model widens.
These results indicate that the proposed model remains robust on these more challenging instances. Future research could explore the model's effectiveness on instances with higher levels of perturbation, as well as on other combinatorial optimization problems with asymmetric distance matrices.
Table 1. Ours vs Baseline (Concorde (Applegate et al.) [3], AM (Kool, van Hoof, and Welling 2018) [14], MDAM (Xin et al. 2020) [11]) in TSP.
Method | n = 20: Obj | Gap | Time | n = 50: Obj | Gap | Time | n = 100: Obj | Gap | Time
Concorde | 3.84* | 0.00% | 1m | 5.70* | 0.00% | 2m | 7.76* | 0.08% | 3m
AM (greedy) | 3.85 | 0.34% | 0s | 5.80 | 1.76% | 2s | 8.12 | 4.53% | 6s
AM (sampling) | 3.84 | 0.00% | 5m | 5.73 | 0.52% | 24m | 7.94 | 2.26% | 1h
MDAM (bs50) | 3.84 | 0.00% | 3m | 5.70 | 0.03% | 14m | 7.79 | 0.38% | 44m
Ours (greedy) | 3.84 | 0.01% | 6s | 5.72 | 0.03% | 20s | 7.91 | 1.93% | 40s
Ours (bs30) | 3.84 | 0.00% | 2m | 5.70 | 0.02% | 10m | 7.79 | 0.39% | 25m
Ours (bs50) | 3.84 | 0.00% | 3m | 5.70 | 0.02% | 19m | 7.78 | 0.26% | 58m
Fig. 4. Training & validation average costs for Table 1.
Table 2. Ours vs Baseline (LKH3 (Helsgaun, 2017) [7], RL (Nazari et al. 2018) [16], AM (Kool, van Hoof, and Welling 2018) [14], NeuRewriter (Chen and Tian 2019) [17], MDAM (Xin et al. 2020) [11]) in CVRP.
Method | n = 20: Obj | Gap | Time | n = 50: Obj | Gap | Time | n = 100: Obj | Gap | Time
LKH | 6.14* | 0.00% | 2h | 10.38* | 0.00% | 7h | 15.65* | 0.00% | 13h
RL (beam 10) | 6.40 | 4.39% | 27m | 11.15 | 7.46% | 39m | 16.96 | 8.39% | 74m
AM greedy | 6.40 | 4.43% | 1s | 10.98 | 5.86% | 3s | 16.80 | 7.34% | 8s
AM sampling | 6.25 | 1.91% | 6m | 10.62 | 2.40% | 28m | 16.23 | 3.72% | 2h
NeuRewriter | 6.16 | 0.48% | 22m | 10.51 | 1.25% | 35m | 16.10 | 2.88% | 66m
MDAM (bs50) | 6.14 | 0.18% | 5m | 10.48 | 0.98% | 15m | 15.99 | 2.23% | 53m
Ours (greedy) | 6.22 | 1.31% | 8s | 10.71 | 3.18% | 22s | 16.33 | 4.34% | 1m
Ours (bs30) | 6.14 | 0.20% | 4m | 10.49 | 1.06% | 12m | 16.20 | 3.51% | 45m
Ours (bs50) | 6.14 | 0.08% | 7m | 10.45 | 0.68% | 20m | 15.91 | 1.66% | 68m
Table 3. Ours vs MDAM [11] with perturbation matrix $p_{i,j}$ in the range [0, 0.2]
Problem | Method | n = 20 | n = 50 | n = 100
TSP | MDAM | 5.68 | 10.62 | 17.56
TSP | Ours | 5.65* | 10.51* | 17.35*
CVRP | MDAM | 8.51 | 16.22 | 27.23
CVRP | Ours | 8.46* | 16.14* | 27.01*
4 Conclusion
In this paper, we propose a novel model that combines E-GAT and E-MHA to construct graph representations using edge information. The hybrid encoder and multiple decoders are trained with reinforcement learning and achieve superior performance with significantly fewer training samples. The model's ability to solve both symmetric and asymmetric TSP and CVRP instances demonstrates its compatibility and generalization ability, and it could be extended to other combinatorial optimization problems (COPs). The proposed model therefore offers a promising approach to real-world combinatorial optimization problems, with improved performance and fewer training samples. Further development of this model could have far-reaching implications in fields such as logistics, transportation and scheduling.
Acknowledgements. The work is supported by the Natural Science Foundation of Fujian Province of China (No. 2022J01003).
References
1. Velickovic, P., Cucurull, G., Casanova, A., Romero, A., Liò, P., Bengio, Y.: Graph attention networks. arXiv, abs/1710.10903 (2017)
2. Toth, P., Vigo, D.: Vehicle Routing: Problems, Methods, and Applications, 2nd edn. (2014)
3. Applegate, D.L., Bixby, R.E., Chvátal, V., Cook, W.J.: The traveling salesman problem: a computational study (2007)
4. Li, Y., Chu, F., Feng, C., Chu, C., Zhou, M.: Integrated production inventory routing planning for intelligent food logistics systems. IEEE Trans. Intell. Transp. Syst. 20, 867–878 (2019)
5. Brouer, B.D., Álvarez, J.F., Plum, C., Pisinger, D., Sigurd, M.: A base integer programming model and benchmark suite for liner-shipping network design. Transp. Sci. 48, 281–312 (2014)
6. Perboli, G., Rosano, M.: Parcel delivery in urban areas: opportunities and threats for the mix of traditional and green business models. Transp. Res. Part C Emerg. Technol. (2019)
7. Helsgaun, K.: An extension of the Lin-Kernighan-Helsgaun TSP solver for constrained traveling salesman and vehicle routing problems: Technical report (2017)
8. Festa, P.: A brief introduction to exact, approximation, and heuristic algorithms for solving hard combinatorial optimization problems. In: 2014 16th International Conference on Transparent Optical Networks (ICTON), pp. 1–20 (2014)
9. Khalil, E.B., Dai, H., Zhang, Y., Dilkina, B.N., Song, L.: Learning combinatorial optimization algorithms over graphs. In: NIPS (2017)
10. Nowak, A.W., Villar, S., Bandeira, A.S., Bruna, J.: A note on learning algorithms for quadratic assignment with graph neural networks. arXiv, abs/1706.07450 (2017)
11. Xin, L., Song, W., Cao, Z., Zhang, J.: Multi-decoder attention model with embedding glimpse for solving vehicle routing problems. In: AAAI Conference on Artificial Intelligence (2020)
12. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778 (2015)
13. Ioffe, S., Szegedy, C.: Batch normalization: accelerating deep network training by reducing internal covariate shift. arXiv, abs/1502.03167 (2015)
14. Kool, W., van Hoof, H., Welling, M.: Attention, learn to solve routing problems! In: International Conference on Learning Representations (2018)
15. Bello, I., Pham, H., Le, Q.V., Norouzi, M., Bengio, S.: Neural combinatorial optimization with reinforcement learning. arXiv, abs/1611.09940 (2016)
16. Nazari, M., Oroojlooy, A., Snyder, L.V., Takác, M.: Reinforcement learning for solving the vehicle routing problem. In: Neural Information Processing Systems (2018)
17. Chen, X., Tian, Y.: Learning to perform local rewriting for combinatorial optimization. In: Neural Information Processing Systems (2018)
Advancing Air Combat Tactics with Improved Neural Fictitious Self-play Reinforcement Learning
Shaoqin He1,2, Yang Gao1(B), Baofeng Zhang1,2, Hui Chang1, and Xinchen Zhang1
1 Institute of Automation, Chinese Academy of Sciences, Beijing, China
{heshaoqin2021,yang.gao,zhangbaofeng2022,hui.chang,xinchen.zhang}@ia.ac.cn
2 School of Artificial Intelligence, University of Chinese Academy of Sciences, Beijing, China
Abstract. We study the use of reinforcement learning for action control in 1v1 Beyond-Visual-Range (BVR) air combat. Unlike most reinforcement learning problems, 1v1 BVR air combat is a two-player zero-sum game with long decision-making periods and sparse rewards. The complexity of its action and state spaces makes it difficult to learn high-level air combat strategies from scratch. To address this problem, we propose a reinforcement learning self-play training framework that tackles it from two aspects: the decision model and the training algorithm. Our decision model uses the Soft Actor-Critic (SAC) algorithm, a maximum entropy method, for reinforcement learning-based action control, and introduces an action mask to achieve efficient exploration. Our training algorithm improves Neural Fictitious Self-Play (NFSP) and proposes the best response history correction (BRHC) version of NFSP. These two components enable efficient training in a high-fidelity simulation environment. Results on the 1v1 BVR air combat problem show that the improved NFSP-BRHC algorithm outperforms both the NFSP and the Self-Play (SP) algorithms.
Keywords: Air Combat · Reinforcement Learning · Neural Fictitious Self-Play
1 Introduction
The development of intelligent algorithms for unmanned air combat has long been a significant research topic. Since the 1960s, researchers have been developing intelligent air combat systems that attempt to replace pilots in air combat decision-making with artificial intelligence methods. Prior methods mainly include: expert systems built from the knowledge of air combat domain experts [2]; decision systems built with heuristic algorithms such as genetic algorithms and fuzzy logic [3]; modeling and solving air combat decisions with game theory [20]; and supervised deep learning that performs behavior cloning on expert-annotated air combat decision data. In recent years, following the success of AlphaGo, the Go AI developed by DeepMind, in defeating the human world champion, more and more reinforcement learning-based algorithms have appeared in the air combat domain [5, 11, 13–15, 17, 19].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 653–666, 2023. https://doi.org/10.1007/978-981-99-4761-4_55
S. He et al.
The expert system method relies on human experts to design a large number of air combat rules, and the performance of a rule-based air combat algorithm cannot exceed the ability of its designers. Heuristic algorithms require experts to design utility functions. Game-theoretic methods face the difficulty of modeling real air combat problems and solving for the Nash equilibrium. Behavior cloning methods are limited by the size of the data and by the skill of the pilots who annotate it. Existing reinforcement learning work suffers from overly simple simulation environments and fixed opponent strategies. To address these problems, this paper combines reinforcement learning with our improved version of the NFSP training method. The contributions of this paper are as follows:
• In the decision-making model, we use the Soft Actor-Critic (SAC) algorithm, a maximum entropy method, for reinforcement learning-based action control, and introduce an action mask mechanism that helps the algorithm explore efficiently.
• In the training algorithm, we propose an improved version of NFSP, called NFSP-BRHC, which is closer to the Nash equilibrium than the original NFSP algorithm.
• Ablation and comparison experiments in a high-fidelity simulation environment show that the NFSP-BRHC training algorithm outperforms the NFSP and SP training algorithms.
2 Related Works
With the continuous innovation of reinforcement learning algorithms, more and more air combat decision-making methods based on reinforcement learning have emerged. Liu [11] used the Deep Q-Network (DQN) [12] algorithm for short-range air combat maneuvers, training against an opponent based on a Min-Max recursive search with a depth of 5. Yang [19] used DQN in the same scenario, applied curriculum learning for basic training, and trained against an opponent strategy based on statistical rules in later stages. Guo [5] verified the Deep Deterministic Policy Gradient (DDPG) [10] algorithm in the same scenario. Qiu [15] added a missile-firing action in a two-dimensional environment and used an improved Twin Delayed Deep Deterministic Policy Gradient (TD3) [4] algorithm to learn maneuvers that evade missiles. Piao [13] used the Proximal Policy Optimization (PPO) [16] algorithm in 1v1 BVR air combat and trained it with self-play. In the AlphaDogfight Trials held by DARPA, Adrian P. Pope et al. [14] developed an air combat agent based on a hierarchical reinforcement learning architecture built on the SAC algorithm [7], and won second place in the competition.
Advancing Air Combat Tactics with Improved Neural Fictitious
The first four works [5, 11, 15, 19] all train against rule-based target strategies, which leads to problems such as overfitting and performance limited by the target strategy. The last two works [13, 14] use self-play training frameworks to achieve autonomous evolution of air combat AI, but self-play can become trapped in strategy cycles and fail to converge to the Nash equilibrium. Different from the above methods, our study uses an improved version of NFSP, modified to converge faster and closer to an approximate Nash equilibrium during training. To explore the state space, we propose action masks that improve the agent's exploration efficiency. These improvements allow agents to converge faster during training and perform better in a high-fidelity simulation environment.
3 Proposed Method
In this section, the BVR air combat problem is first modeled as a Markov game [1]. We then describe the decision model and network architecture of our BVR air combat agent and introduce an action mask that helps the agent explore the state space more effectively. Finally, we discuss several self-play algorithms and propose a Neural Fictitious Self-Play algorithm based on best response history correction (NFSP-BRHC) for the tactical evolution of air combat agents.
3.1 BVR Air Combat Modeling
In this paper, the 1v1 BVR air combat problem is modeled as a fully competitive multi-agent Markov game defined by the tuple $(s_0, S, A, P, r, \gamma)$, where $s_0$ is the initial state, $S$ is the joint state space, $A$ is the joint action space, $P: S \times A \to S$ represents the state transition, $r: S \times A \to \mathbb{R}$ is the reward function, and $\gamma \in (0, 1]$ is the discount factor. For agent $i$, the state space is $S^i$, the action space is $A^i$, the reward is $r_i$, and its policy $\pi_i$ is a mapping $S^i \times A^i \to [0, 1]$. The goal of each agent in the Markov game is to maximize its expected cumulative discounted reward
$$R_i = \mathbb{E}_{\rho_{\pi_i}}\left[\sum_{t=0}^{T} \gamma^t r_i(s_t, a_t)\right],$$
where $T$ is the time horizon and $\rho_{\pi_i}$ denotes the distribution of trajectories produced by following policy $\pi_i$.
State Space. The state space of BVR air combat contains two parts: the state $S_d$ that can be obtained directly and the state $S_{id}$ that must be obtained indirectly by further calculation. $S_d$ includes each aircraft's three-dimensional coordinates $x, y, z$, the three attitude angles $\varphi_p, \varphi_r, \varphi_y$, the velocities $v_x, v_y, v_z$ along the three axes, the magnitude of the combined velocity $|v|$, the radar lock signal $lo$ and the number of remaining missiles $m$.
$S_{id}$ includes the distance $R$ between the two aircraft, the projected distances $r_x, r_y, r_z$ along the three axes, the closing rate $\dot{R}$, the aspect angle $AA$, the antenna train angle $ATA$, the elevation angle $EA$, the heading crossing angle $HCA$, and finally the energy-related velocity squared $E$ and the velocity-squared difference $E_d$. The entire state space is defined as follows:
$$S^i = [x, y, z, \varphi_p, \varphi_r, \varphi_y, v_x, v_y, v_z, |v|, lo, m, R, r_x, r_y, r_z, \dot{R}, AA, ATA, EA, HCA, E, E_d]. \tag{1}$$
There are 38 state variables in total. Every quantity is included for both aircraft except the distance $R$, the projected distances $r_x, r_y, r_z$, the closing rate $\dot{R}$, the elevation angle $EA$, the heading crossing angle $HCA$ and the velocity-squared difference $E_d$, which appear only once: the 12 directly observed quantities plus $AA$, $ATA$ and $E$ give $2 \times 15 = 30$ entries, and the eight shared quantities bring the total to 38.
Action Space. Fighter pilots can perform many complex tactical maneuvers during their missions, and these maneuvers are composed of basic atomic actions. As long as the action space contains these basic atomic actions, it can therefore cover higher-level complex maneuvers. Following the seven basic maneuvers designed by NASA researchers [automated maneuvering decisions for air-to-air combat], our action space contains seven basic actions (uniform straight flight, upward flight, downward flight, left-turn flight, right-turn flight, accelerate and decelerate) plus a launch action that controls the missile, as shown in Table 1.
Table 1. Action Space.
Category | Serial | Basic Atomic Action
Maneuvers | a1 | uniform straight flight
Maneuvers | a2 | upward flight
Maneuvers | a3 | downward flight
Maneuvers | a4 | left-turn flight
Maneuvers | a5 | right-turn flight
Maneuvers | a6 | accelerate
Maneuvers | a7 | decelerate
Attack | a8 | launch missile
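The state and action spaces described above can be summarized in code. This is a hypothetical sketch: the feature names and their ordering are ours (Eq. (1) lists the same quantities in a different order), and it only checks the bookkeeping, namely that 15 per-aircraft features for each of the two aircraft plus 8 shared relative features give 38 inputs, and that the action set contains seven maneuvers plus one attack action.

```python
from typing import Dict, List

# Per-aircraft quantities: the own aircraft and the opponent each contribute a copy.
PER_AIRCRAFT = ["x", "y", "z", "phi_p", "phi_r", "phi_y", "vx", "vy", "vz",
                "speed", "radar_lock", "missiles_left", "AA", "ATA", "E"]
# Relative quantities that describe the pair and appear only once.
SHARED = ["R", "rx", "ry", "rz", "R_dot", "EA", "HCA", "E_d"]

# Actions a1..a8 from Table 1; a8 (index 7) is the only attack action.
ACTIONS = ["uniform straight flight", "upward flight", "downward flight",
           "left-turn flight", "right-turn flight", "accelerate", "decelerate",
           "launch missile"]


def state_vector(own: Dict[str, float], enemy: Dict[str, float],
                 shared: Dict[str, float]) -> List[float]:
    """Concatenate own, enemy and shared features in a fixed order."""
    return ([own[k] for k in PER_AIRCRAFT] + [enemy[k] for k in PER_AIRCRAFT]
            + [shared[k] for k in SHARED])
```

A fixed ordering matters because the actor and critic networks consume the state as a flat vector after normalization.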
Reward Definition. In past work, some methods relied on fine-tuned per-step reward signals designed by human experts, while others used sparse reward signals from important events [13]. Per-step reward signals can solve the cold-start problem, but they are difficult to tune, and performance is limited by the designer's subjective understanding of air combat advantage. Important-event rewards reflect objective air combat events but are too sparse and suffer from the cold-start problem. We therefore combine the two. Per-step rewards with small weights steer the agent toward more valuable states, solving the cold-start problem, while the objective important-event rewards help the agent correctly recognize the causal relationship between its policy and winning.
The per-step rewards cover whether the radar detects the opponent, the speed advantage in the direction of the enemy's defense line, and a fixed negative reward that encourages the agent to end the fight as soon as possible. Important-event rewards include sparse rewards for launching missiles and for successfully evading missiles. The reward for launching a missile is negative, which makes the agent more cautious about launching; this is because each agent carries only two missiles and must maximize the benefit of each launch. Important-event rewards also include rewards for crossing the enemy's line, rewards for hitting the enemy with a missile and thereby winning, and penalties for going out of bounds. The coefficients of the per-step rewards, the missile rewards and the final-result rewards are set so that the winning side receives a much larger total reward than the losing side. The reward settings are listed in Table 2, where n denotes the maximum number of steps in the simulation.
Table 2. Reward Definition.
Category | Event | Weight
Per-step Rewards | Detect | 0.5
Per-step Rewards | Be detected | −0.5
Per-step Rewards | Speed advantage | [−1, 1]
Per-step Rewards | Fixed negative | −0.1
Important Event Rewards | Launch missile | −50
Important Event Rewards | Escape missile | 50
Important Event Rewards | Cross the line first | 2n
Important Event Rewards | Cross the line late | −2n
Important Event Rewards | Hit | 2n
Important Event Rewards | Be hit | −2n
Important Event Rewards | Out of bounds | −4n
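Table 2 can be turned into reward functions as sketched below. The event names, the clipping of the speed advantage to [−1, 1], and the split into `step_reward` and `event_reward` are our assumptions; the text does not specify how the speed advantage is computed, so it is passed in as a precomputed value.

```python
from typing import Iterable

FIXED_STEP_PENALTY = -0.1


def step_reward(detects: bool, detected: bool, speed_advantage: float) -> float:
    """Dense per-step reward from Table 2."""
    r = FIXED_STEP_PENALTY
    r += 0.5 if detects else 0.0
    r -= 0.5 if detected else 0.0
    r += max(-1.0, min(1.0, speed_advantage))  # clipped to [-1, 1]
    return r


def event_reward(events: Iterable[str], n: int) -> float:
    """Sparse important-event rewards from Table 2; n is the maximum step count."""
    table = {
        "launch_missile": -50.0, "escape_missile": 50.0,
        "cross_line_first": 2.0 * n, "cross_line_late": -2.0 * n,
        "hit": 2.0 * n, "be_hit": -2.0 * n, "out_of_bounds": -4.0 * n,
    }
    return sum(table[e] for e in events)
```

Since the per-step reward is at most 0.5 + 1 − 0.1 = 1.4, the accumulated per-step reward over an n-step episode stays below a single 2n outcome term, which is consistent with the winning side collecting a much larger return.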
3.2 Decision Model
Soft Actor-Critic. Soft actor-critic (SAC) [7] is a maximum entropy reinforcement learning (MERL) algorithm that, like PPO, outputs a stochastic policy. Unlike PPO, however, SAC is an off-policy actor-critic algorithm, so its sample efficiency is higher. DDPG is also an off-policy actor-critic algorithm, but it uses a deterministic policy, whereas SAC uses a stochastic one. A stochastic policy helps the agent explore better and converge more stably, while a deterministic policy is more likely to fall into a local optimum, and its performance and convergence are also unstable. What most distinguishes SAC from other reinforcement learning algorithms is that it maximizes the entropy of the policy while also maximizing the cumulative return. This makes the policy behave more randomly while still securing the return, so it explores the state space more fully and further stabilizes the algorithm's performance. With maximum entropy, the training goal of the SAC actor is to maximize the expected sum of rewards and policy entropy:
$$\pi^*_{\mathrm{MaxEnt}} = \arg\max_{\pi} \sum_{t} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}\left[r(s_t, a_t) + \alpha \mathcal{H}(\pi(\cdot \mid s_t))\right], \tag{2}$$
where $\rho_\pi$ denotes the distribution of state-action pairs that the agent encounters under policy $\pi$, and $\alpha$ is the temperature coefficient that adjusts the emphasis on entropy. Because the optimization goal differs, the value iteration of maximum entropy reinforcement learning differs slightly from standard reinforcement learning. The Bellman equations for the soft Q function and soft V function in SAC are:
$$Q^\pi_{\mathrm{soft}}(s_t, a_t) = \mathbb{E}_{s_{t+1} \sim p(s_{t+1} \mid s_t, a_t)}\left[r(s_t, a_t) + \gamma V^\pi_{\mathrm{soft}}(s_{t+1})\right], \tag{3}$$
$$V^\pi_{\mathrm{soft}}(s_t) = \mathbb{E}_{a_t \sim \pi}\left[Q^\pi_{\mathrm{soft}}(s_t, a_t) - \alpha \log \pi(a_t \mid s_t)\right]. \tag{4}$$
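For a discrete action set such as the eight actions of Table 1, the expectation over actions in Eq. (4) can be evaluated exactly. The sketch below (pure Python; function names and the `done` flag are our assumptions) computes the soft value and the one-sample soft Bellman target $r + \gamma V_{\mathrm{soft}}(s_{t+1})$ that a critic regresses onto; replacing the expectation over $s_{t+1}$ in Eq. (3) by the sampled next state is standard in replay-based SAC. In full SAC, `next_q` would typically come from the target critics, taking the minimum of two of them (the double-Q trick).

```python
import math
from typing import Sequence


def soft_value(q_values: Sequence[float], probs: Sequence[float], alpha: float) -> float:
    """Eq. (4), discrete actions: V(s) = sum_a pi(a|s) * (Q(s, a) - alpha * log pi(a|s))."""
    return sum(p * (q - alpha * math.log(p)) for q, p in zip(q_values, probs) if p > 0)


def soft_q_target(reward: float, next_q: Sequence[float], next_probs: Sequence[float],
                  alpha: float, gamma: float, done: bool = False) -> float:
    """One-sample estimate of Eq. (3): r + gamma * V_soft(s_{t+1})."""
    if done:
        return reward  # no bootstrap from a terminal state
    return reward + gamma * soft_value(next_q, next_probs, alpha)
```

With a uniform policy over two actions, the entropy bonus adds $\alpha \log 2$ to the average Q value, which is how the soft value rewards keeping the policy stochastic.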
The critic network corresponds to the soft Q function, and its training goal is to minimize the following loss function:
1 ˆ t , at ) , Q(st , at ) − Q(s JQ = E(at ,st )∼D (5) 2 ˆ is the target critic network. where D represents the replay buffer, Q The agent policy approximates the energy-based policy (EBP) [6] by minimizing the KL divergence between the policy represented by the state-conditioned stochastic neural network [6] and EBP. Finally, the loss function of the actor network is derived as follows:
1 (6) Jπ = Est ∼D,at ∼π logπ (at |st ) − Q(st , at ) . α Many performance-improving techniques are used in SAC, including drawing on the double Q network [18] and target Q network in DQN to alleviate the problem of overestimation of the Q value. One of the most important improvements is the automatic entropy adjustment mechanism. α, as a hyperparameter that controls MERL’s emphasis on entropy, has a significant impact on the performance of the algorithm. It has different suitable values in different reinforcement learning tasks or different training stages of the same task [7]. The authors of SAC propose an improvement to automatically adjust
Advancing Air Combat Tactics with Improved Neural Fictitious
α so that the policy entropy always stays above a threshold while the expected reward is maximized. The resulting loss function for α is:

$$J(\alpha) = \mathbb{E}_{a_t \sim \pi_t}\big[-\alpha \log \pi_t(a_t \mid s_t) - \alpha \mathcal{H}_0\big] \tag{7}$$

where $\mathcal{H}_0$ denotes the entropy threshold.

Action Mask
Because the joint action and state space of 1v1 BVR air combat is huge, we use an action mask to prune the exploration of reinforcement learning and improve training efficiency. The action mask encodes little expert knowledge; it mainly rules out clearly unreasonable behavior during training:
1) Going out of bounds, either by crashing into the sea or by crossing the horizontal boundary of the battlefield. The mask makes the aircraft turn or climb in advance, which avoids losses caused by leaving the battlefield.
2) Launching missiles. Launches are masked when the missile cannot possibly hit or when no missiles remain.
The action mask thus cuts down on exploration and lets the algorithm spend its limited resources on more meaningful states.

Network Structure
The network structure of the reinforcement learning part of a single agent is shown in Fig. 1. The environment observation consists of two parts: the direct observation $O_d$ and the indirect observation $O_{id}$, which requires further calculation. Normalizing them yields $S_d$ and $S_{id}$, respectively, and together these form the state of the problem. The state is fed to one actor network and four critic networks (the other three critic networks, which share the same structure, are omitted from the figure). The actor network and each critic network consist of four fully connected layers with 512 hidden units each and ReLU activations. In addition, a module called Action Mask computes the action mask of the current state from the environment observation.
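The sketch below illustrates how the action mask, the actor loss (6), and the temperature loss (7) could fit together for a single state. It assumes a small discrete action set; the action names and the thresholds (`ACTIONS`, `min_altitude_m`, `min_boundary_m`) are hypothetical, since the paper lists neither its actions nor its mask thresholds. Masked actions receive zero probability before the losses are evaluated.

```python
import numpy as np

# Hypothetical discrete action set; the paper does not list its actions.
ACTIONS = ["straight", "turn_left", "turn_right", "climb", "dive", "fire"]


def build_action_mask(altitude_m: float, dist_to_boundary_m: float, missiles_left: int,
                      target_in_range: bool, min_altitude_m: float = 1000.0,
                      min_boundary_m: float = 5000.0) -> np.ndarray:
    """Boolean mask (True = allowed). Thresholds are illustrative assumptions."""
    mask = np.ones(len(ACTIONS), dtype=bool)
    if altitude_m < min_altitude_m:
        # Rule 1 (sea): only climbing is allowed close to the sea surface.
        mask[:] = False
        mask[ACTIONS.index("climb")] = True
    elif dist_to_boundary_m < min_boundary_m:
        # Rule 1 (horizontal boundary): force a turn before leaving the battlefield.
        mask[:] = False
        mask[[ACTIONS.index("turn_left"), ACTIONS.index("turn_right")]] = True
    if missiles_left == 0 or not target_in_range:
        # Rule 2: no launch without missiles or without a chance to hit.
        mask[ACTIONS.index("fire")] = False
    return mask


def masked_policy(logits: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Softmax restricted to allowed actions; masked actions get probability 0."""
    z = np.where(mask, logits, -np.inf)
    z = z - z[mask].max()  # subtract the max for numerical stability
    p = np.exp(z)
    return p / p.sum()


def actor_loss(probs: np.ndarray, q_values: np.ndarray, alpha: float) -> float:
    """Eq. (6) for one state, with the expectation over actions computed exactly."""
    nz = probs > 0
    return float(np.sum(probs[nz] * (np.log(probs[nz]) - q_values[nz] / alpha)))


def temperature_step(log_alpha: float, probs: np.ndarray, target_entropy: float,
                     lr: float = 1e-3) -> float:
    """One gradient step on Eq. (7), parameterizing alpha = exp(log_alpha).

    With exact expectations J(alpha) = alpha * (H(pi) - H_0), hence
    dJ/d(log_alpha) = alpha * (H(pi) - H_0).
    """
    nz = probs > 0
    entropy = -np.sum(probs[nz] * np.log(probs[nz]))
    alpha = np.exp(log_alpha)
    return float(log_alpha - lr * alpha * (entropy - target_entropy))
```

Because masked actions get zero probability, the policy entropy can be at most the logarithm of the number of allowed actions. In this sketch, a target entropy above that bound would make α grow without limit.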
Fig. 1. The neural network structure of the RL part of the agent
S. He et al.
3.3 Neural Fictitious Self-play with Best Response History Correction
The training methods used in previous work applying reinforcement learning to air combat [13, 14, 17] are often based on self-play. However, using self-play in a non-transitive game can make the models of successive generations fall into a cycle. For example, with three policies A, B, and C in a non-transitive game, A > B and B > C do not imply A > C. Fictitious Self-Play [8] can alleviate this problem, and it has been proved to converge to an approximate Nash equilibrium in imperfect-information poker games. In Fictitious Self-Play, the agent's current policy is a weighted average of its best response (BR) and its historical average policy. The BR is the best response to the opponent's historical average policy, and the historical average policy is the average of the agent's own past best responses. In Neural Fictitious Self-Play (NFSP) [9], both the best response and the historical average policy are represented by neural networks, as shown in Fig. 2. Each NFSP agent consists of a best response policy network and a historical average policy network. The best response policy network is trained by a reinforcement learning algorithm, in this paper SAC with an action mask. The historical average policy network is trained by supervised learning on samples of the past behavior of the best response policy network, using the logarithmic loss function:

$$J = \mathbb{E}_{(s, a) \sim \mathcal{M}_{SL}}\big[-\log \Pi(a \mid s)\big] \tag{8}$$

where $\mathcal{M}_{SL}$ is the supervised learning buffer and $\Pi$ is the historical average policy network. The final output policy of an agent trained by NFSP is a mixture of the best response policy and the historical average policy, in which the best response policy has weight η.
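The NFSP action selection and the supervised loss (8) can be sketched as follows, assuming callables `best_response` and `average_policy` that map a state to an action (hypothetical names). As in the original NFSP [9], only actions taken while following the BR are added to $\mathcal{M}_{SL}$; the original uses reservoir sampling for that buffer, which this sketch replaces with a plain list for brevity.

```python
import numpy as np


def nfsp_act(state, best_response, average_policy, eta, sl_buffer, rng):
    """Mixed NFSP action selection.

    With probability eta the agent follows its best response (BR) and stores the
    (state, action) pair in the supervised-learning buffer M_SL; otherwise it
    follows the historical average policy Pi.
    """
    if rng.random() < eta:
        action = best_response(state)
        sl_buffer.append((state, action))
    else:
        action = average_policy(state)
    return action


def sl_log_loss(avg_probs: np.ndarray, actions: np.ndarray) -> float:
    """Eq. (8): mean of -log Pi(a|s) over a batch sampled from M_SL.

    avg_probs[i] holds Pi(.|s_i); actions[i] is the stored BR action for s_i.
    """
    picked = avg_probs[np.arange(len(actions)), actions]
    return float(-np.mean(np.log(picked)))
```

Usage: create `rng = np.random.default_rng(0)` once and pass the same generator to every call so the mixing is reproducible.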
Fig. 2. The NFSP framework
Fictitious Play converges to a Nash equilibrium in theory, but Neural Fictitious Self-Play approximates the best response policy with a neural network, so there is a gap between the learned best response and the true one. Each time the best response network is trained to approximate the best response, it may not converge completely, and some deviation always remains. For example, in rock-paper-scissors, suppose the opponent always plays rock and, after a round of training, our best response network outputs 25% rock, 5% scissors, and 70% paper. This is not the best response, which is to always play paper. If every output is stored, about 30% of the samples (rock or scissors) in the supervised learning (SL) buffer used to learn the historical average policy are non-best-response samples, which introduces errors into the learned historical average policy. We therefore propose NFSP with best response history correction (NFSP-BRHC) to solve this problem. Instead of storing every output of the best response network in the SL buffer of the historical average policy, NFSP-BRHC stores only the best response outputs from episodes that end in a win. The reason is that in a two-player zero-sum game with long-horizon sparse rewards, the only true reward signal is the final win or loss. In such a game, a winning decision sequence matches the definition of a best response more closely than a losing sequence produced by an incompletely converged network, and it is also more valuable for training the historical average policy. In addition, since the SL buffer holds few samples early in training, the best response policy should carry a larger weight in the mixed policy output at that stage. 
As training proceeds and the SL buffer accumulates more best response samples, the weight of the average policy should keep increasing. This gives the second improvement of NFSP-BRHC: the fixed mixing coefficient is replaced by one that decays to a threshold as the number of training generations grows. The corrected SL buffer samples and the decaying mixing coefficient together constitute the complete NFSP-BRHC training algorithm. Combining NFSP-BRHC with SAC and the action mask yields NFSP-BRHC with SAC, shown in Algorithm 1.
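The two NFSP-BRHC modifications can be sketched as below: BR samples are held per episode and committed to the SL buffer only after a win, and η decays from 1 to the floor of 0.1 reported in Sect. 4.2. The class and function names are hypothetical, and the exponential decay form and its rate are assumptions; the paper only states that η decays to a threshold over training generations.

```python
class BRHCBuffer:
    """SL buffer with best response history correction (hypothetical sketch).

    BR (state, action) pairs are held per episode and committed to the SL
    buffer only if that episode is won; pairs from other episodes are dropped.
    """

    def __init__(self):
        self.sl_buffer = []  # samples used to train the historical average policy
        self._episode = []   # BR samples of the episode in progress

    def record(self, state, action):
        self._episode.append((state, action))

    def end_episode(self, won: bool):
        if won:
            self.sl_buffer.extend(self._episode)
        self._episode = []


def eta_schedule(generation: int, decay: float = 0.99,
                 eta_start: float = 1.0, eta_min: float = 0.1) -> float:
    """Decaying BR mixing weight: starts at 1 and never drops below 0.1.

    The exponential form and the decay rate are assumptions, not the paper's.
    """
    return max(eta_min, eta_start * decay ** generation)
```

Holding samples until the episode ends is what makes the correction possible: the win or loss is only known at the last step.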
4 Experiments
1v1 BVR air combat is a two-player zero-sum game with long-horizon sparse rewards, so we add auxiliary rewards to help the algorithm converge faster and explore better. In this section, we verify the effectiveness of the proposed improvements on the 1v1 BVR air combat problem.
4.1 Simulation Environment and Problem Setting
The 1v1 BVR air combat simulation environment used in this paper is a high-fidelity 6-DOF simulation environment developed in C++, and its flight dynamics models of the aircraft and missiles are consistent with those in DCS World. The aircraft used in our experiments is the F-16 fighter jet. The battlefield is a rectangular area 100 km long from north to south and 50 km wide from east to west. Each side's base measures 10 km from north to south and 50 km from east to west, and the two bases lie at the southern and northern ends of the battlefield, respectively. The two opposing sides carry two
AIM-120 missiles each and set off from the centers of their respective bases at an altitude of 3 km. An aircraft that goes out of bounds, either by leaving the battlefield area or by crashing into the sea, immediately loses.
4.2 Experiment Setup
The discount factor is 0.99. The SAC algorithm in the reinforcement learning part uses the Adam optimizer with initial learning rates of $\alpha_\pi = 5.0\times10^{-5}$ for the actor network and $\alpha_Q = 8.0\times10^{-5}$ for the critic network, and a target entropy of −10. The neural network in the supervised learning part has five hidden layers and uses the Adam optimizer with an initial learning rate of $5.0\times10^{-5}$. The probability η of selecting the BR policy in NFSP-BRHC starts at 1 and decays to a threshold of 0.1 as training proceeds. Each action lasts 0.5 s, which we consider appropriate given the difference between BVR air combat and dogfighting: much longer actions make complex maneuvers hard to learn, while much shorter ones make training difficult.
4.3 Ablation and Contrast Experiments
In a two-player zero-sum game, comparing the rewards obtained by different algorithms is not meaningful. We therefore use the win rate in simulated engagements between agents trained by different algorithms as the comparison criterion.

Action Mask Ablation Experiment
To demonstrate the effectiveness of the action mask, we conduct ablation experiments. Each of the three algorithms, SP, NFSP, and NFSP-BRHC, is used to train for a total of 5000 episodes, about two million (2M) steps. For each training algorithm, the trained red/blue model with the action mask module then fights 100 simulated BVR engagements against the red/blue model without it, and the average win rates are shown in Fig. 3(a).
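For reproduction, the settings reported in Sects. 4.1 and 4.2 can be collected in one configuration object, together with the out-of-bounds failure rule. This is a sketch: the field and function names are our own, and the coordinate frame (origin at the south-west corner of the battlefield, sea level at altitude 0) is an assumption, since the paper does not define one.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BVRConfig:
    """Settings reported in Sects. 4.1-4.2 (field names are our own)."""
    length_ns_m: float = 100_000.0     # battlefield extent, north-south
    width_ew_m: float = 50_000.0       # battlefield extent, east-west
    base_depth_ns_m: float = 10_000.0  # each base, north-south
    start_altitude_m: float = 3_000.0
    missiles_per_aircraft: int = 2     # AIM-120
    action_duration_s: float = 0.5
    gamma: float = 0.99
    actor_lr: float = 5.0e-5
    critic_lr: float = 8.0e-5
    target_entropy: float = -10.0
    sl_lr: float = 5.0e-5
    eta_start: float = 1.0
    eta_min: float = 0.1


def start_position(side: str, cfg: BVRConfig = BVRConfig()):
    """Center of the 'south' or 'north' base as (x_east, y_north, altitude), assumed frame."""
    half = cfg.base_depth_ns_m / 2
    y = half if side == "south" else cfg.length_ns_m - half
    return (cfg.width_ew_m / 2, y, cfg.start_altitude_m)


def out_of_bounds(x_east_m: float, y_north_m: float, altitude_m: float,
                  cfg: BVRConfig = BVRConfig()) -> bool:
    """Failure rule: leaving the rectangular battlefield or reaching sea level."""
    inside = 0.0 <= x_east_m <= cfg.width_ew_m and 0.0 <= y_north_m <= cfg.length_ns_m
    return (not inside) or altitude_m <= 0.0
```

Freezing the dataclass makes it safe to use an instance as a default argument, because no call can mutate the shared defaults.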
Fig. 3. a) The ablation experiment of Action Mask. b) Comparative experiment with baseline training algorithm (Color figure online)
For all three training methods, the policy model trained with the action mask module performs better than the one trained without the action mask
module. These results show that the action mask effectively helps the agent prune the exploration space and focus its resources on the states that need more exploration, which leads to a higher win rate.

Comparison with Baseline Algorithms
After the same number of training episodes as in the ablation experiment, we run 100 BVR air combat simulations between the models obtained with the three training algorithms SP, NFSP, and NFSP-BRHC. The average win rates are shown in Fig. 3(b). The three training algorithms rank NFSP-BRHC > NFSP > SP, which is consistent with our expectations. We attribute this to the model trained by NFSP-BRHC being closer to the Nash equilibrium and converging faster, whereas the model trained by NFSP deviates from the Nash equilibrium because of the erroneous samples generated while the best response network has not fully converged. Nevertheless, the models trained by both methods outperform SP. The experimental results confirm the analysis in the method section and demonstrate the effectiveness of the proposed NFSP-BRHC training method.
4.4 Tactical Evolution During Training
As shown in Fig. 4, the agents adapt to each other's policies during training, and their tactics evolve continuously. The tactics of the agents trained by NFSP-BRHC keep evolving throughout training: the policies of the two sides adapt to each other and approach the Nash equilibrium in the process. This mutual adaptation is consistent with our analysis in the Methods section.
Fig. 4. a) Initially, the agent did not launch missiles or fly directly towards the finish line. b) As training progressed, the agent learned to accelerate at the start to maintain level flight and win the race. c) However, after learning to launch missiles, the strategy of accelerating towards the line became ineffective, and the agent learned to judge possible missile trajectories to escape them. d) The agent mastered missile launching, often performing a crank maneuver to evade opponent missiles launched at a closer range. e) The agent learned to launch two missiles at different positions. f) Additionally, the agent learned to climb first in the game and increase the threat of launching missiles, which also helped it gain a situational advantage.
5 Conclusion
In this paper, we propose a new training method, NFSP-BRHC. It corrects the best response history samples used during training and uses a decaying best response weight, which brings the average policy much closer to the historical average of the best responses. We combine NFSP-BRHC with SAC and an action mask and apply it to 1v1 BVR air combat decision-making. The experimental results show that in 1v1 BVR air combat, the agent trained by our algorithm has an advantage over agents trained with SP or NFSP. Moreover, the NFSP-BRHC training framework is not limited to 1v1 BVR air combat and can also be used in other two-player zero-sum games with sparse rewards. In future work, we intend to apply it to more games of this kind to further verify its effectiveness and continue improving it.

Acknowledgement. This work was supported by the National Major Science and Technology Plan Project, the National Defense Science and Technology Foundation Reinforcement Program Key Project, and the Strategic Priority Research Program of the Chinese Academy of Sciences.
References
1. Bansal, T., Pachocki, J., Sidor, S., Sutskever, I., Mordatch, I.: Emergent complexity via multiagent competition. arXiv preprint arXiv:1710.03748 (2017)
2. Burgin, G.H.: Improvements to the adaptive maneuvering logic program. Technical report (1986)
3. Ernest, N., Carroll, D., Schumacher, C., et al.: Genetic fuzzy based artificial intelligence for unmanned combat aerial vehicle control in simulated air combat missions. J. Def. Manag. 6(1), 2167–2374 (2016)
4. Fujimoto, S., Hoof, H., Meger, D.: Addressing function approximation error in actor-critic methods. In: International Conference on Machine Learning, pp. 1587–1596. PMLR (2018)
5. Guo, J., et al.: Maneuver decision of UAV in air combat based on deterministic policy gradient. In: 2022 IEEE 17th International Conference on Control & Automation (ICCA), pp. 243–248. IEEE (2022)
6. Haarnoja, T., Tang, H., Abbeel, P., Levine, S.: Reinforcement learning with deep energy-based policies. In: International Conference on Machine Learning, pp. 1352–1361. PMLR (2017)
7. Haarnoja, T., Zhou, A., Abbeel, P., Levine, S.: Soft actor-critic: off-policy maximum entropy deep reinforcement learning with a stochastic actor. In: International Conference on Machine Learning, pp. 1861–1870. PMLR (2018)
8. Heinrich, J., Lanctot, M., Silver, D.: Fictitious self-play in extensive-form games. In: International Conference on Machine Learning, pp. 805–813. PMLR (2015)
9. Heinrich, J., Silver, D.: Deep reinforcement learning from self-play in imperfect information games. arXiv preprint arXiv:1603.01121 (2016)
10. Lillicrap, T.P., et al.: Continuous control with deep reinforcement learning. arXiv preprint arXiv:1509.02971 (2015)
11. Liu, P., Ma, Y.: A deep reinforcement learning based intelligent decision method for UCAV air combat. In: Mohamed Ali, M., Wahid, H., Mohd Subha, N., Sahlan, S., Md. Yunus, M., Wahap, A. (eds.) AsiaSim 2017. CCIS, vol. 751, Part I, pp. 274–286. Springer, Singapore (2017). https://doi.org/10.1007/978-981-10-6463-0_24
12. Mnih, V., et al.: Human-level control through deep reinforcement learning. Nature 518(7540), 529–533 (2015)
13. Piao, H., et al.: Beyond-visual-range air combat tactics auto-generation by reinforcement learning. In: 2020 International Joint Conference on Neural Networks (IJCNN), pp. 1–8. IEEE (2020)
14. Pope, A.P., et al.: Hierarchical reinforcement learning for air-to-air combat. In: 2021 International Conference on Unmanned Aircraft Systems (ICUAS), pp. 275–284. IEEE (2021)
15. Qiu, X., Yao, Z., Tan, F., Zhu, Z., Lu, J.G.: One-to-one air-combat maneuver strategy based on improved TD3 algorithm. In: 2020 Chinese Automation Congress (CAC), pp. 5719–5725. IEEE (2020)
16. Schulman, J., Wolski, F., Dhariwal, P., Radford, A., Klimov, O.: Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347 (2017)
17. Sun, Z., et al.: Multi-agent hierarchical policy gradient for air combat tactics emergence via self-play. Eng. Appl. Artif. Intell. 98, 104112 (2021)
18. Van Hasselt, H., Guez, A., Silver, D.: Deep reinforcement learning with double q-learning. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 30 (2016)
19. Yang, Q., Zhang, J., Shi, G., et al.: Maneuver decision of UAV in short-range air combat based on deep reinforcement learning. IEEE Access 8, 363–378 (2019)
20. Zheng, H., Deng, Y., Hu, Y.: Fuzzy evidential influence diagram and its evaluation algorithm. Knowl. Based Syst. 131, 28–45 (2017)
Recent Advances in Deep Learning Methods and Techniques for Medical Image Analysis
Power Grid Knowledge Graph Completion with Complex Structure Learning
Zhou Zheng¹, Jun Guo¹, Feilong Liao¹, Qiyao Huang², Yingyue Zhang², Zhichao Zhao¹, Chenxiang Lin¹, and Zhihong Zhang²(✉)
¹ State Grid Fujian Electric Power Research Institute, Fuzhou 350007, Fujian, China
² School of Informatics, Xiamen University, Xiamen 361005, Fujian, China
Abstract. In recent years, the knowledge graph has become a common way to store large-scale knowledge in the power grid domain, and it has proved highly effective in helping people access specialized knowledge. However, automatically generating new knowledge from an incomplete knowledge graph, a task known as knowledge graph completion, remains an urgent open problem. Previous work pays too little attention to the structural information in power grid knowledge graphs, which results in poor performance. In this paper, we propose a novel framework called Complex Structure Entropy Network (CSEN) that performs multi-hop reasoning over a power grid knowledge graph by combining a two-stage reasoning process based on cognitive theory with von Neumann graph entropy. We evaluate the model on a power grid defects dataset in the link prediction task and show the effectiveness of the proposed method compared with a variety of baselines.
Keywords: Power Grid · Knowledge Graph Completion · Von Neumann Graph Entropy
1 Introduction
The power grid plays a vital role in modern society by providing a safe, reliable, and cost-effective energy supply to industrial, commercial, and residential customers. However, malfunctions in electrical and mechanical components often compromise the reliability of the power grid, which makes it imperative to identify faulty equipment using expert knowledge. Technical knowledge about the power system must therefore be stored efficiently, and useful information must be accessible quickly, to keep the power grid functioning optimally. To this end, knowledge graphs have emerged as a popular storage method for large-scale knowledge and have performed well in the power grid domain [5, 11]. Despite their effectiveness, knowledge graphs are often incomplete, and new knowledge must be generated automatically to fill the gaps. This challenge, known as knowledge graph completion, has become a pressing issue in recent years. Existing knowledge graph reasoning methods fall into two types: structure-based methods and text-based methods. However, both fall short in capturing structural information, which leads to poor performance on unseen paths.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 669–679, 2023.
https://doi.org/10.1007/978-981-99-4761-4_56
Z. Zheng et al.
To overcome this limitation, we propose an approach that uses von Neumann graph entropy [7], which is widely used to characterize salient features of static and dynamic network systems, to capture structural information in the power grid knowledge graph. We build on recent work that enables efficient computation of von Neumann graph entropy through a quantum analogy. Furthermore, we use a two-stage reasoning model based on cognitive theory as the backbone of our model. This model accounts for the cognitive processes involved in problem-solving and decision-making, which enhances the overall performance of our approach [6]. By combining these techniques, we expect our approach to significantly improve knowledge graph completion in the power grid domain and thereby contribute to a more reliable and efficient energy supply. Our contributions can be summarized as follows:
1. To the best of our knowledge, we are the first to use von Neumann graph entropy in knowledge graph completion to capture the structural relationships in knowledge.
2. We propose a novel framework, Complex Structure Entropy Network (CSEN), that performs multi-hop reasoning over a power grid knowledge graph by combining a two-stage reasoning process based on cognitive theory with von Neumann graph entropy.
3. We evaluate our model on the power grid defects dataset in the link prediction task. The results show the effectiveness of our method compared with a variety of baselines.
2 Problem Formulation
Definition 1. Knowledge Graph. A knowledge graph is typically viewed as a collection of existing facts. Formally, we denote a knowledge graph by $G = (E, R, T)$, where $E$ and $R$ are the entity and relation sets, respectively, and $T$ is the set of triples. A triple consists of a head entity $e_s$, a relation $r$, and a tail entity $e_o$; the triple $(e_s, r, e_o) \in T$ indicates that relation $r \in R$ holds between the entities $e_s \in E$ and $e_o \in E$. To enhance connectivity within the graph, we also add an inverse link $(e_o, r^{-1}, e_s)$ for every $(e_s, r, e_o) \in T$.
Definition 2. Knowledge Graph Completion. Knowledge graph completion uses known facts to infer unknown or missing ones, and it is frequently used to evaluate the performance of knowledge graph reasoning or completion models. More specifically, given an incomplete triple $(e_s, r, ?)$ or $(?, r, e_o)$, the goal of link prediction is to identify one or more correct answers for the missing entity (Fig. 1).
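Adding inverse links means a head query $(?, r, e_o)$ can be answered as a tail query $(e_o, r^{-1}, ?)$ on the augmented graph. The sketch below shows this with plain Python containers. The `"^-1"` naming of inverse relations and the relation name `has_component` are hypothetical, chosen only to mirror the example in Fig. 1.

```python
from collections import defaultdict


def add_inverse_links(triples):
    """Return T plus (e_o, r^-1, e_s) for every (e_s, r, e_o) in T.

    Inverse relations are named by appending "^-1" (our naming convention).
    """
    triples = list(triples)
    return triples + [(o, r + "^-1", s) for (s, r, o) in triples]


def build_index(triples):
    """Map (head, relation) to the set of tails, for answering (e_s, r, ?) queries."""
    index = defaultdict(set)
    for s, r, o in triples:
        index[(s, r)].add(o)
    return index


T = [("Oil-immersed Transformer", "has_component", "Iron Core"),
     ("Oil-immersed Reactor", "has_component", "Iron Core")]
index = build_index(add_inverse_links(T))
# The head query (?, has_component, Iron Core) becomes the tail query below.
owners = index[("Iron Core", "has_component^-1")]
```

Link prediction then asks the model to rank candidate tails for triples that are missing from such an index.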
Power Grid Knowledge Graph Completion with Complex Structure Learning
Fig. 1. An example of power grid knowledge graph completion. From a human perspective, since the Oil-immersed Reactor and the Oil-immersed Transformer share the same component, the Iron Core, this defect should also exist in the Oil-immersed Reactor. In the incomplete knowledge graph, the task is to infer an unseen relation such as "Failure Phenomenon" between the "Oil-immersed Reactor" and the "Destruction of Insulation". This requires the model to capture latent structural information and to reason more like a human.
3 Related Work
3.1 Knowledge Graph Completion Methods
Knowledge graph completion methods can be broadly categorized into two types: structure-based methods and text-based methods [2].
Structure-based methods. Structure-based methods learn the paths between entities in the knowledge graph and evaluate the confidence of candidate entities with statistical functions or confidence metrics. For instance, Yang et al. [17] propose an end-to-end module that learns first-order logical rules together with structural information. DRUM [14] handles unseen nodes and relates the confidence scores to a low-rank tensor approximation. NBFNet [20] represents node pairs by the paths between them and solves the path formulation with the Bellman-Ford algorithm. Although path-based methods are explainable, embedding-based methods generally perform better. SMORE [13] is an embedding-based method trained with contrastive learning that can perform both single-hop and multi-hop reasoning. FuzzQE [3] is a fuzzy-logic-based query embedding framework that answers first-order logic (FOL) queries over KGs. ReLMLKG [1] performs joint reasoning over a pre-trained language model and the associated knowledge graph to bridge the gap between the question and the knowledge graph.
Text-based methods. Text-based methods, a recent trend in knowledge graph completion, build knowledge embeddings with BERT. Researchers have
also explored integrating knowledge graphs and BERT in specialized domains [9]. For instance, Yao et al. [10] used BERT to embed knowledge graphs and successfully transferred specialized knowledge to the model. KG-BERT [18], another development in this area, achieved positive results in knowledge graph completion and link prediction by feeding the entity and relation descriptions of a triple into the KG-BERT language model and computing the triple's score.
3.2 Von Neumann Graph Entropy
Since its introduction, the von Neumann graph entropy has been the subject of numerous studies and has found applications in fields such as network analysis, information theory, and statistical mechanics [7, 16, 19]. In this section, we review the works most relevant to ours. One of the earliest applications of von Neumann graph entropy was in network analysis. De Domenico et al. used it for structural reduction in multiplex networks [4]. Li et al. employed the von Neumann entropy for network-ensemble comparison [8] and showed that it is a valuable tool for characterizing the topology of complex networks and distinguishing between network classes. Furthermore, Wang used an approximation of the von Neumann graph entropy based on node degree statistics to model network evolution over time [15]. Compared with other types of graph entropy, the von Neumann entropy has lower time complexity, because a quadratic approximation of it yields a simple expression in terms of the degrees of the node pairs that form edges [7]. It has also been found that this entropy measure is an excellent predictor of network robustness and can help identify critical nodes and edges in a network. 
However, to our knowledge, there is limited research on utilizing von Neumann graph entropy in knowledge graph completion.
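The cost difference mentioned above can be made concrete. The sketch below, assuming NumPy and small dense graphs, computes the exact spectral entropy of Eq. (1) in Sect. 4.1 through an eigendecomposition and the degree-based approximation of Eq. (2) from in- and out-degrees alone. Both function names are hypothetical.

```python
import numpy as np


def vn_entropy_exact(adj):
    """Eq. (1): -sum_i (lam_i/|V|) log(lam_i/|V|) over the normalized Laplacian spectrum.

    Assumes a symmetric adjacency matrix without isolated nodes; the
    eigendecomposition costs O(|V|^3).
    """
    adj = np.asarray(adj, dtype=float)
    n = adj.shape[0]
    d_inv_sqrt = 1.0 / np.sqrt(adj.sum(axis=1))
    lap = np.eye(n) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    lam = np.linalg.eigvalsh(lap) / n  # spectrum of rho = L_norm / |V|
    lam = lam[lam > 1e-12]             # drop zeros: 0 * log 0 := 0
    return float(-np.sum(lam * np.log(lam)))


def vn_entropy_approx(edges, n):
    """Eq. (2): degree-statistics approximation for directed edges (u, v), nodes 0..n-1.

    Assumes every node with an incoming edge also has an outgoing one
    (true once inverse links are added to each triple).
    """
    d_out = np.zeros(n)
    d_in = np.zeros(n)
    for u, v in edges:
        d_out[u] += 1
        d_in[v] += 1
    s = sum((d_in[u] / d_out[u] + d_in[v] / d_out[v]) / (d_out[u] * d_in[v])
            for u, v in edges)
    return float(1.0 - 1.0 / n - s / (2.0 * n * n))
```

For an undirected graph entered as two opposite arcs per edge, Eq. (2) equals the quadratic expansion 1 − Σ_i (λ_i/|V|)² of the spectral entropy. For the triangle graph, for instance, the exact value is ln 2 ≈ 0.693 and the approximation gives 0.5.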
4 Complex Structure Entropy Network
In this section, we describe the proposed framework, Complex Structure Entropy Network (CSEN). We first present the approximate calculation of von Neumann graph entropy. Next, we show how CSEN generates entity embeddings with the two-stage cognitive theory and von Neumann graph entropy. Finally, we describe how CSEN is trained.
4.1 Von Neumann Graph Entropy Calculation
The von Neumann entropy of a graph can be defined by interpreting the scaled normalized Laplacian as a density operator. Using the density matrix definition of Severini et al., the von Neumann entropy can be calculated from the normalized Laplacian spectrum as follows [12]:

$$S_{VN} = -\mathrm{Tr}(\rho \log \rho) = -\sum_{i=1}^{|V|} \frac{\lambda_i}{|V|} \log \frac{\lambda_i}{|V|} \tag{1}$$
where $\lambda_1, \ldots, \lambda_{|V|}$ are the eigenvalues of the normalized Laplacian matrix, so that ρ is the normalized Laplacian scaled by $1/|V|$. This form of von Neumann entropy has been shown to be effective for network characterization. In fact, Han et al. [7] have shown how to approximate the von Neumann entropy in terms of simple degree statistics. Their approximation reduces the cubic complexity of computing the von Neumann entropy from the Laplacian spectrum to quadratic complexity by using simple edge degree statistics, i.e.

$$S_{VN} \approx 1 - \frac{1}{|V|} - \frac{1}{2|V|^2} \sum_{(u,v) \in E} \frac{\dfrac{d_u^{in}}{d_u^{out}} + \dfrac{d_v^{in}}{d_v^{out}}}{d_u^{out}\, d_v^{in}} \tag{2}$$
where $V$ is the node set and $E$ is the edge set of the directed graph $G = (V, E)$, $(u, v) \in E$ is an edge from node $u$ to node $v$, and $d_u^{out}$ and $d_u^{in}$ are the out-degree and in-degree of node $u$. In summary, we obtain an approximate von Neumann entropy and apply it in the entity embedding module with the two-stage cognitive theory.
4.2 Two-Stage Power Grid Entity Embedding Module
Overview. Motivated by the dual process theory in cognitive science, the entity embedding module models the relationships between reasoning paths, which contain essential structural information about the graph. It has two stages. Stage 1 mainly searches for clue paths whose compositional semantics relate to the given query. The resulting candidate entities are then passed to Stage 2, which models structural information in detail with von Neumann graph entropy and produces the final results.
Stage 1: Searching. For a query entity, we obtain a candidate entity set consisting of its neighbor entities and initialize their expanded scores. The expanded score $S_t\{e_n\}$ of candidate entity $e_n$ at step $t$ is calculated as follows, where ∥ denotes concatenation:

$$Q_t\{e_{n-1}\} = \mathrm{MLP}\big(X(e_{n-1}) \,\Vert\, v_{r_q} \,\Vert\, v_q\big) \tag{3}$$

$$N_t\{e_n\} = \mathrm{MLP}\big(X(e_n) \,\Vert\, v_{r_n}\big) \tag{4}$$

$$S_t\{e_n\} = \sigma\big(N_t\{e_n\} \cdot Q_t\{e_{n-1}\}\big) \tag{5}$$
where $e_n$ is a neighbor of $e_{n-1}$, $v_{r_q}$ is the vector of the query relation, $v_q$ is the vector of the query entity, $v_{r_n}$ is the vector of the relation between $e_{n-1}$ and $e_n$, and $Q_t\{e_{n-1}\}$ and $N_t\{e_n\}$ are the embeddings of $e_{n-1}$ and $e_n$ at step $t$. We select the top-k edges with the largest scores and use them to update the cognitive graph G: the selected edges are added to the edge set E, and tail entities that have not yet been visited are added to the node set V, which completes the expansion of G. Thus, after t steps, we obtain the expanded score of every entity in the expanded graph, and we normalize these scores to obtain the new attention distribution:
$$p_t(e_q) = \mathrm{Softmax}\big(S_t\{e_1\}, S_t\{e_2\}, \ldots, S_t\{e_n\}\big) \tag{6}$$

where $p_t(e_q)$ is the answer score distribution over all entities in the graph and $S_t\{e_n\}$ is the expanded score of entity $e_n$. If an entity has not been added to the expanded graph, its score S is zero.
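One Stage 1 expansion step, Eqs. (3)–(6), can be sketched with NumPy as below. The operator joining the vectors in Eqs. (3)–(4) is treated as concatenation, each MLP is a hypothetical two-layer ReLU network (the paper gives no depth or width), and the function names are our own.

```python
import numpy as np


def mlp(x, w1, w2):
    """Two-layer ReLU MLP; depth and width are assumptions."""
    return np.maximum(x @ w1, 0.0) @ w2


def sigmoid(z):
    return 0.5 * (1.0 + np.tanh(0.5 * z))  # overflow-free form of 1 / (1 + exp(-z))


def stage1_step(x_prev, v_rq, v_q, x_nbrs, v_rn, params, k):
    """Eqs. (3)-(5) for one expansion step from entity e_{n-1}.

    x_nbrs and v_rn hold one row per neighbor e_n (entity state, relation vector).
    Returns the indices of the top-k neighbors and all expanded scores S_t{e_n}.
    """
    q = mlp(np.concatenate([x_prev, v_rq, v_q]), *params["query"])        # Q_t{e_{n-1}}
    n = mlp(np.concatenate([x_nbrs, v_rn], axis=1), *params["neighbor"])  # N_t{e_n}, per row
    scores = sigmoid(n @ q)                                                # S_t{e_n}
    top = np.argsort(-scores)[:k]
    return top, scores


def answer_distribution(num_entities, expanded, expanded_scores):
    """Eq. (6): softmax over all entities; entities never added keep score 0."""
    s = np.zeros(num_entities)
    s[expanded] = expanded_scores
    e = np.exp(s - s.max())
    return e / e.sum()
```

Usage: `params` maps `"query"` and `"neighbor"` to weight pairs of shapes `(3*D, H), (H, D)` and `(2*D, H), (H, D)` for embedding size `D` and hidden size `H`.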
In summary, at each step the Stage 1 module expands the candidates related to the question using the embeddings from Stage 2, and it outputs the answer scores when the expansion ends.
Stage 2: Embedding. To update the entity representations, we use a graph neural network, which models each entity in the graph by aggregating its neighbor entities. The message passed to entity $e_i$ and the resulting update can be represented as:

$$\mathrm{Update}(e_i, m_{e_i}) = \frac{1}{|E_{e_i}|}\, m_{e_i} \tag{7}$$

$$m_{e_i} = \sum_{(e_k, r_k, e_i) \in E} \mathrm{Message\_Layer}(e_k, r_k, e_i, S_{VN}) \tag{8}$$

$$\mathrm{Message\_Layer}(e_k, r_k, e_i, S_{VN}) = \mathrm{GRU}\big(X[e_k], v_{r_k}, v_{e_i}, S_{VN}\big) \tag{9}$$
where $M(e_k, r_k, e_i, S_{VN})$ is the message vector passed from $e_k$ to $e_i$, and $S_{VN}$ is the von Neumann graph entropy of the edge $(e_k, r_k, e_i)$. We update all representations on the same layer in sequence instead of computing them from the previous layer. Since most reasoning paths are relatively short, we choose a GRU as the message function instead of a more complex network. The GRU is updated with the previous latent representation $X[e_k]$ of entity $e_k$ and the concatenation of the relation embedding $v_{r_k}$ and the entity embedding $v_{e_i}$. Let $E_{e_i} = \{(e_k, r_k, e_i) \mid (e_k, r_k, e_i) \in E\}$; the update result is the average of all messages. Finally, CSEN is trained with the following loss:

$$\ell(p_t, e) = \begin{cases} -\log p_t(e), & p_t(e) > 0 \\ -\log\big(1 + \epsilon - \sum_{e'} p_t(e')\big), & p_t(e) = 0 \end{cases}$$
where is a hyperparameter close to zero. The function can avoid cases that division by zero. Table 1. Information about training, validation and test dataset Dataset
Triples
entities
relations
Training
221356
87224
13
Validation
37412
35757
13
Test
37413
35756
13
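The two-case loss above can be written directly in Python. This is a minimal sketch for a single query, assuming the answer distribution is a dictionary from entity to probability, with entities absent from it treated as having $p^t(e) = 0$; the function name `csen_loss` and the default value of the small hyperparameter (written ε here) are illustrative choices, not values reported for CSEN.

```python
import math


def csen_loss(p, target, eps=1e-6):
    """Loss for one query.

    p      -- dict mapping entities in the expanded graph to p^t(e); entities
              missing from p are treated as having p^t(e) = 0.
    target -- the correct answer entity e.
    eps    -- the small hyperparameter epsilon (1e-6 is an illustrative value).
    """
    p_target = p.get(target, 0.0)
    if p_target > 0:
        return -math.log(p_target)
    # The answer was never reached: penalize the mass placed on other entities.
    # eps keeps the log argument positive when that mass sums to exactly 1.
    return -math.log(1.0 + eps - sum(p.values()))


print(csen_loss({"a": 0.5, "b": 0.5}, "a"))  # ln 2, about 0.693
print(csen_loss({"a": 0.5, "b": 0.5}, "c"))  # about -ln(1e-6), i.e. 13.8
```

The second case grows large only when the model commits its probability mass to wrong entities, so queries whose answer lies outside the expanded graph are still penalized without ever evaluating $\log 0$.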
5 Experiment
5.1 Experimental Setup
Power Grid Defects Dataset. This is a real-world dataset derived from records of grid defects, which include key details such as location, description, cause, and level. The records are gathered manually during routine maintenance and then converted
Power Grid Knowledge Graph Completion with Complex Structure Learning
into triples from the structured data. To guarantee that a reasoning path between the head entity and the tail entity is possible, we only include nodes whose total degree (in-degree plus out-degree) is greater than three. The resulting dataset contains 105535 entities and 296181 triples, and the power grid knowledge graph has a graph density of 2.66e-5. We divided the triples into training, validation, and test sets with an 8:1:1 ratio; Table 1 gives the details.
Evaluation Metrics. For evaluation, we mask the correct tail entities of the triples in the test and validation sets and ask the model to find the correct answer entities. We use two kinds of metrics:
– Hits@k. Hits@k measures link-prediction accuracy on a knowledge graph: it is the proportion of test queries for which the correct entity appears among the top k ranked candidates. Higher values indicate better performance.
– MRR. MRR (Mean Reciprocal Rank) evaluates the quality of the ranking of predicted links: it is the average, over all test queries, of the reciprocal of the rank at which the first correct answer appears. Higher values indicate better performance.
We also use MRR to evaluate the current model checkpoint, and we report Hits@k for k = 1 and 10.
Baselines. We compare our proposed model with five strong baselines that are widely used or highly regarded in the field:
– TransE. TransE represents entities and relations as vectors in a common space and predicts a missing entity by translating the head-entity embedding by the relation embedding so that the result approximates the tail-entity embedding.
Table 2. Experiment results of link prediction

| Method | Hits@1 | Hits@10 | MRR |
|---|---|---|---|
| TransE | 2.80 | 6.92 | 4.34 |
| TransR | 0.01 | 1.54 | 0.57 |
| TransH | 2.71 | 6.89 | 4.27 |
| RED-GNN | 33.60 | 47.81 | 38.94 |
| EIGAT | 52.01 | 73.27 | 59.97 |
| CSEN w/o Stage2 | 65.01 | 80.76 | 71.13 |
| CSEN | 66.72 | 81.04 | 72.32 |
– TransH. TransH extends TransE by introducing a relation-specific hyperplane in the embedding space to capture the heterogeneity of relations.
– TransR. TransR further extends this line of work by associating each relation with a projection matrix that maps head and tail entity embeddings into a relation-specific space, allowing more flexible modeling of complex relationships between entities.
– RED-GNN. RED-GNN introduces relational directed graphs to capture more complex structural information than traditional path-based methods.
– EIGAT. EIGAT computes graph attention while taking global entity importance into account.
Implementation Details. For our proposed model, we chose a hidden representation size of 200 and an embedding size of 768. We set the maximum number of edges selected when updating the cognitive graph to top-k = 64 and the maximum number of reasoning steps to l = 4. The model was trained for 100,000 steps in total, and each entity was allowed up to 100 random walks during training. We optimized the model with Adam at a learning rate of 1e-4. We compared the performance of our model against the five baselines above, and the experiments show that it outperforms all of them and achieves state-of-the-art results on this graph reasoning task.
5.2 Performance Comparison
Table 2 presents the results of our approach on Hits@k and MRR, which show state-of-the-art performance on the power grid dataset. Our model significantly outperforms all baselines, including EIGAT, the strongest baseline. Specifically, as shown in Fig. 2(a) and Fig. 2(b), CSEN improves Hits@1 by 14.71 points over EIGAT (66.72 vs. 52.01).
This significant improvement indicates that our approach successfully captures the structural information of the power grid knowledge graph and makes good use of it for the KGC tasks.
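The Hits@k and MRR values in Table 2 are computed from the rank of the correct entity for each test query. The sketch below illustrates the two definitions from Sect. 5.1; it assumes raw (unfiltered) ranking with ties resolved in favor of the correct entity, since the exact protocol is not specified here, and the helper names are hypothetical. Table 2 reports these fractions multiplied by 100.

```python
def answer_ranks(score_lists, answers):
    """1-based rank of each query's correct entity (higher score = better).

    Assumes raw ranking: ties are resolved in favour of the correct entity.
    """
    ranks = []
    for scores, answer in zip(score_lists, answers):
        better = sum(1 for s in scores.values() if s > scores[answer])
        ranks.append(better + 1)
    return ranks


def hits_at_k(ranks, k):
    """Fraction of queries whose correct entity is ranked within the top k."""
    return sum(r <= k for r in ranks) / len(ranks)


def mean_reciprocal_rank(ranks):
    """Average of 1 / rank over all queries."""
    return sum(1.0 / r for r in ranks) / len(ranks)


ranks = answer_ranks([{"a": 0.9, "b": 0.1}, {"a": 0.5, "b": 0.7, "c": 0.6}], ["a", "a"])
print(ranks, hits_at_k(ranks, 1), mean_reciprocal_rank(ranks))
# ranks [1, 3]: Hits@1 = 0.5 and MRR = (1 + 1/3) / 2
```

Because MRR rewards a correct answer at rank 2 or 3 only partially, it sits between Hits@1 and Hits@10, which is the pattern seen for every method in Table 2.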
Fig. 2. Performance in (a) Hits@1 and (b) Hits@10.
Furthermore, as shown in Fig. 3(a), our model also improves MRR by 12.35 points over EIGAT (72.32 vs. 59.97), indicating that the two-stage process and von Neumann graph entropy support precise reasoning, which is particularly valuable for the power grid company. Traditional embedding-based methods such as TransE, TransR, and TransH perform poorly on this dataset (Hits@10 below 7%), suggesting that they are ill-suited to grid defect diagnosis and further highlighting the ability of our approach to meet the requirements of the power grid domain. In summary, these results demonstrate the superiority of our model in the power grid domain and show the importance of capturing structural information and precise reasoning for KGC tasks.
5.3 Ablation Study
In this experiment, we remove stage 2, which updates the entity embeddings in the expanded graph, from our model. As shown in Fig. 3(b) and Table 2, performance drops on all three metrics without stage 2 (Hits@1 from 66.72 to 65.01, Hits@10 from 81.04 to 80.76, and MRR from 72.32 to 71.13). This suggests that stage 2 plays an important role in the model's ability to perform precise reasoning: by updating the entity embeddings in the expanded graph, it lets the model better capture the relationships and dependencies between entities in the knowledge graph. Overall, the ablation study provides insight into how our model works and underscores the contribution of stage 2 to precise reasoning.
Fig. 3. Performance in MRR and ablation study: (a) MRR; (b) ablation study of stage 2.
6 Conclusion
A power grid knowledge graph stores technical knowledge about the power system and makes useful knowledge faster to query. However, automatically inferring new knowledge in an incomplete knowledge graph, a task known as knowledge graph completion, is still challenging. In this paper, we propose a novel framework called Complex Structure Entropy Network (CSEN) that conducts multi-hop reasoning over a power grid knowledge graph using a two-stage cognitive theory and von Neumann graph entropy. We evaluate the model on the power grid defects dataset in the link prediction task and show the effectiveness of the proposed method compared with a variety of baselines.
Acknowledgment. This work is supported by the Research Funds from State Grid Fujian (SGFJDK00SZJS2200162).
A Graph-Transformer Network for Scene Text Detection
Yongrong Wu¹, Jingyu Lin¹, Houjin Chen¹, Dinghao Chen¹, Lvqing Yang¹(✉), and Jianbing Xiahou²(✉)
¹ School of Information, Xiamen University, Xiamen 361000, China
² Quanzhou Normal University, Quanzhou 362000, Fujian, China
Abstract. Detecting text in natural images with varying orientations and shapes is challenging, and existing detectors often fail on text instances with extreme aspect ratios. This paper introduces GTNet, a Graph-Transformer network for scene text detection. GTNet uses a Graph-based Shared Feature Learning Module (GSFL) for feature extraction and a Transformer-based Regression Module (TRM) for bounding box prediction. The architecture offers a flexible receptive field, combining global attention with local features for enhanced text representation. Extensive experiments show that our method surpasses existing detectors in accuracy and effectiveness.
Keywords: Scene Text Detection · Transformer · Graph convolutional network
1 Introduction
Scene text detection, which aims to identify and localize text in natural images, has received increasing research attention because of its diverse applications, such as sign recognition and autonomous driving. However, the task is challenging: text in natural scenes has irregular and curved shapes, varied scales, and random rotations, which often leads to subpar performance. Accurately localizing text instances of varied shapes and sizes is therefore a significant challenge in scene text detection. In recent years, segmentation-based methods have gained popularity because they use pixel-level predictions and post-processing algorithms to improve detection accuracy. However, binary segmentation remains a bottleneck, particularly for text instances with extreme aspect ratios, such as long text lines and text with wide character spacing. Fig. 1 illustrates such cases. Context information is crucial for understanding image semantics in scene text detection, and the range of the receptive field, which determines the overall perception capability of the model, plays a key role in this regard. In this work, we introduce a Graph-Transformer framework for scene text detection, offering a novel way to extract local semantic information and to decode text box information.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 680–690, 2023.
https://doi.org/10.1007/978-981-99-4761-4_57
Fig. 1. Some detection results of text instances in natural scenes
Our approach, which merges graph-based methods with deep learning, improves the quality of semantic information extraction and speeds up convergence by focusing on regions of interest. For arbitrarily shaped text, the decoder produces task-specific text representations by describing each text boundary with 8 control points of Bezier curves [1] and regressing their coordinates. We have conducted extensive experiments on five prevalent datasets to show the effectiveness of the proposed scheme. With little post-processing, our method achieves excellent F1-scores, outperforming several state-of-the-art scene text detectors.
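To make the Bezier representation concrete, the following Python sketch evaluates a Bezier curve from its control points with the Bernstein polynomials and turns eight control points (four for the top boundary, four for the bottom) into a closed polygon. It is an illustrative sketch in the spirit of ABCNet [1], not GTNet's code; the point ordering (both curves given left to right) and the function names are assumptions.

```python
from math import comb


def bezier_point(ctrl, t):
    """Point at parameter t in [0, 1] on the Bezier curve with control points ctrl."""
    n = len(ctrl) - 1  # curve degree; 3 for four control points (cubic)
    weights = [comb(n, i) * (1 - t) ** (n - i) * t ** i for i in range(n + 1)]
    x = sum(w * px for w, (px, _) in zip(weights, ctrl))
    y = sum(w * py for w, (_, py) in zip(weights, ctrl))
    return (x, y)


def bezier_polygon(control_points, samples=10):
    """Turn 8 control points (top curve, then bottom curve) into a closed polygon.

    Assumes both curves are given left to right, so the bottom samples are
    reversed to walk around the text region in a consistent direction.
    """
    top, bottom = control_points[:4], control_points[4:]
    ts = [i / (samples - 1) for i in range(samples)]
    upper = [bezier_point(top, t) for t in ts]
    lower = [bezier_point(bottom, t) for t in reversed(ts)]
    return upper + lower


# A straight, axis-aligned text line 3 units long and 1 unit high.
ctrl = [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1), (2, 1), (3, 1)]
print(bezier_polygon(ctrl, samples=4))
```

With 8 control points the regression target is 16 scalars per text instance regardless of how strongly the text is curved, which is what makes this representation attractive for a fixed-size prediction head.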
2 Related Work
The advent of deep learning has greatly improved scene text detection, replacing older heuristic methods. Inspired by object detection, deep convolutional neural networks (CNNs) are now widely used for text detection. They model text geometry through bounding box regression, mask generation, contour prediction, or centerline generation, and they fall into two categories: segmentation-based and regression-based methods.
2.1 Segmentation-based Methods
Segmentation-based text detection is a popular approach that uses pixel-level prediction and post-processing algorithms. To achieve real-time performance, DB [2] proposed a framework that combines segmentation and binarization for detecting arbitrarily shaped text. MOTD [3] uses semantic segmentation to detect text blocks and MSER [4] to find character proposals within these blocks. PixelLink treats text detection as a binary per-pixel prediction problem and uses a connected-components-based method to generate the final detection results. PSENet [5] uses kernels of different scales to segment text instances, while SegLink [6] predicts text line segments and then learns the links between them to obtain the bounding boxes. SegLink++ [7] proposes an instance-aware component grouping algorithm to better distinguish adjacent text instances. However, such methods require additional segmentation heads and post-processing steps, leading to long inference times.
2.2 Regression-based Methods
Regression-based methods adapt generic object detectors to text detection and are often preferred for their straightforwardness. Methods such as TextBoxes [8], TextBoxes++ [9], Deep TextSpotter [10], EAST [11], and Rotation-Sensitive Regression [12] typically build on popular object detectors like Faster R-CNN [13] and SSD [14] to predict text regions end to end. EAST, for example, directly regresses the target region as a rotated rectangle or a quadrilateral. However, these methods are designed for multi-oriented text and may not perform well on curved text, which is common in real-world scenarios. To address this limitation, recent works such as LOMO [15], ABC-Net [1], and FCENet [16] have introduced novel representations and architectures. LOMO uses Mask R-CNN as its backbone to combine segmentation- and regression-based approaches, ABC-Net introduces the Bezier curve representation to detect multi-oriented and curved text instances, and FCENet uses the inverse Fourier transform to reconstruct the contours of arbitrarily shaped text instances. Most existing regression-based techniques must regress a set of coordinate points to form a text area, and they struggle to detect irregular text in natural scenes. In offline scenarios, where text detection precedes text recognition, segmentation-based methods perform better.
3 Method
This section presents an overview of our scene text detection method, which comprises two main modules. We first introduce the overall architecture and then discuss the benefits of our graph-based encoder and Transformer-based decoder, analyzing each module in detail to illustrate its effectiveness.
3.1 Overall Architecture
Our proposed GTNet uses a graph-based encoder and an efficient Transformer decoder for scene text detection. The encoder extracts and combines multi-scale feature maps with positional encodings, which the Transformer decoder then refines. To attend to different text instances, the decoder uses a set of learnable vectors called text queries. Figure 2 shows the GTNet architecture.
3.2 Graph-based Shared Feature Learning Module
The input image is processed in a pipeline, as shown in Fig. 2. First, the image is divided into non-overlapping tiles with a sliding window. These tiles are used to build a graph G = (V, E), where V contains the tile feature embeddings and E connects each tile to the tiles bordering it in its 8-neighborhood. This graph captures the spatial and feature relations of the image, which are crucial for further analysis.
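A minimal NumPy sketch of the tile graph and of one graph-convolution step in the form of Eq. (1) follows. It assumes a rows × cols grid of tiles indexed row by row, random tile features standing in for the unspecified tile embeddings, and ReLU as σ; `tile_graph` and `gcn_layer` are hypothetical names, not part of GTNet's code.

```python
import numpy as np


def tile_graph(rows, cols):
    """Adjacency matrix A of a rows x cols tile grid (row-major), 8-neighbourhood."""
    n = rows * cols
    A = np.zeros((n, n))
    for r in range(rows):
        for c in range(cols):
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    rr, cc = r + dr, c + dc
                    if (dr, dc) != (0, 0) and 0 <= rr < rows and 0 <= cc < cols:
                        A[r * cols + c, rr * cols + cc] = 1.0
    return A


def gcn_layer(A, H, W):
    """One graph-convolution step: ReLU(D~^-1/2 (A + I) D~^-1/2 H W)."""
    A_tilde = A + np.eye(A.shape[0])                 # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(A_tilde.sum(axis=1))  # diagonal of D~^-1/2
    A_hat = d_inv_sqrt[:, None] * A_tilde * d_inv_sqrt[None, :]
    return np.maximum(A_hat @ H @ W, 0.0)            # sigma = ReLU


A = tile_graph(3, 3)
rng = np.random.default_rng(0)
H = rng.normal(size=(9, 16))   # stand-in tile embeddings, one row per tile
W = rng.normal(size=(16, 8))   # layer weights
H_next = gcn_layer(A, H, W)
print(A[4].sum(), H_next.shape)  # the centre tile has 8 neighbours; output is 9 x 8
```

The symmetric normalization keeps feature magnitudes comparable between interior tiles (8 neighbours) and corner tiles (3 neighbours), so border tiles are not systematically down-weighted.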
Fig. 2. The overall framework of the proposed GTNet. The architecture is composed of a graph-based encoder, a Transformer decoder, and a polygon prediction head.
Our proposed Graph-Transformer architecture starts by extracting a low-level representation of graph G with a Graph Convolution (GC) layer. As shown in previous research [17], a GNN learns representations of graph nodes from their neighborhood features. Because this aggregation captures local and short-range correlations between graph nodes, the GNN is well suited as a layer shared by multiple tasks. Message propagation and aggregation on the graph are defined as:

$$H^{(l+1)} = \sigma\left(\tilde{D}^{-\frac{1}{2}} \tilde{A} \tilde{D}^{-\frac{1}{2}} H^{(l)} W^{(l)}\right) \tag{1}$$
Here, $H^{(l)}$ is the feature matrix of layer $l$ and $H^{(l+1)}$ is the feature matrix of layer $l+1$. $\sigma$ is an activation function such as ReLU. $\tilde{A} = A + I_N$ is the adjacency matrix with self-loops, and $\tilde{D}$ is the degree matrix of $\tilde{A}$. After feature aggregation, we flatten the aggregated features into a one-dimensional feature vector and feed it to the decoder of the next stage.
3.3 Transformer-based Regression Module
In the decoding stage, the Transformer-based Regression Module decodes the output sequence of the encoder, of shape L × d, where L is the total number of patches and d is the length of each feature vector. The module also uses a set of M learnable vectors as text queries, where M is the number of candidate text instances. M is usually set to a fixed value much larger than the maximum number of text instances that may appear in an image of the dataset; in this model, M = 100. The decoder is composed of multi-head attention modules, which allow the model to extract semantic features of text instances at different scales and significantly improve the decoder's performance. Using these features and the query vectors, the decoder adaptively selects the most relevant features and aggregates them into complete text instances.
In the latter part of the network, the decoded features are fed into a Feed-Forward Network (FFN) that regresses the text box information, namely whether a corresponding text instance is present and the coordinates of the vertices of its polygonal text box. During training, the module uses the Hungarian matching algorithm to match the predicted text boxes with the ground-truth text boxes before the loss is computed. When dealing with irregularly shaped or curved text, the bounding-box
regression module based on the Transformer decoder follows ABCNet [1] and generates curved text boxes by computing the control points of Bezier curves. With n control points, 2 × n scalar values need to be regressed, namely the x and y coordinates of the n Bezier control points.
3.4 Optimization
Detection tasks rely on two losses: a classification loss, which evaluates how accurately the model predicts the category of each detection box, and a detection loss, which measures how accurately it predicts each box's position and size. GTNet, designed for natural scene text detection, incorporates both in its overall loss function:
$$L = \lambda L_{class} + (1 - \lambda) L_{det} \tag{2}$$
The category loss $L_{class}$ measures the error between the model's predicted probability for the correct category of each text instance and the ground-truth annotation. The detection loss $L_{det}$ measures the difference between the predicted and ground-truth text boxes. A scaling factor $\lambda \in (0, 1)$ balances the two losses in the total loss. Before the loss is computed, the Hungarian matching algorithm matches the predicted text instances with the ground-truth annotations, so that each prediction is compared with its corresponding target.
Classification Loss Function. The model uses the binary cross-entropy loss to compute the error between the matched predictions and the ground truth:
$$\hat{L}_{class} = \frac{1}{N} \sum_{x} -\left[\hat{g}_x \log(\hat{p}_x) + (1 - \hat{g}_x) \log(1 - \hat{p}_x)\right] \tag{3}$$
Here, N is the total number of text queries, a fixed value that must be greater than the maximum number of text instances that may appear in a single image of the dataset; $\hat{g}_x$ is the text/non-text label and $\hat{p}_x$ is the model's confidence score for predicting a text instance.
Detection Loss Function. Text instances in natural scenes often have extreme aspect ratios, and small or tiny instances contribute less to the overall optimization under traditional detection losses, which leads to poor recall. To balance detection performance across text instances of different scales, we adopt a GIoU-based loss. Unlike the conventional IoU loss, GIoU also accounts for the area of the smallest region enclosing both the predicted and ground-truth boxes, so it still provides a useful training signal when the boxes overlap little or not at all. This makes it better suited to text targets of varying sizes and aspect ratios and improves the handling of overlapping text instances, which in turn improves the accuracy and robustness of the detector. The loss is defined as:

$$\hat{L}_{det} = L_{giou}(\hat{b}_i, b_j) = 1 - \mathrm{GIoU}(\hat{b}_i, b_j) \tag{4}$$
Here, $\hat{b}_i$ and $b_j$ are the i-th predicted text instance and the j-th ground-truth label, respectively. For any two oriented text regions $\hat{b}_i$ and $b_j$, GIoU is computed as:

$$\mathrm{GIoU}(\hat{b}_i, b_j) = \mathrm{IoU}(\hat{b}_i, b_j) - \frac{\mathrm{Area}\left(C \setminus (\hat{b}_i \cup b_j)\right)}{\mathrm{Area}(C)} \tag{5}$$
Here, C is the smallest convex polygon that fully encloses both $\hat{b}_i$ and $b_j$, and Area(·) computes the area of a region. For irregular or curved text, the model follows ABCNet [1] and uses Bezier curves, regressing their control points to generate curved text boxes.
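The GIoU of Eq. (5) can be computed for convex text regions in plain Python: the shoelace formula gives areas, Sutherland-Hodgman clipping gives the intersection, and a monotone-chain convex hull gives the enclosing region C. This is a sketch that assumes both regions are convex polygons given as vertex lists (curved Bezier regions would first have to be sampled into polygons, and the clipping step is only exact for convex inputs); the function names are hypothetical.

```python
def polygon_area(poly):
    """Shoelace formula: absolute area of a simple polygon given as (x, y) vertices."""
    s = 0.0
    for (x1, y1), (x2, y2) in zip(poly, poly[1:] + poly[:1]):
        s += x1 * y2 - x2 * y1
    return abs(s) / 2.0


def cross(o, a, b):
    """z-component of (a - o) x (b - o); positive if b lies left of the ray o -> a."""
    return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])


def convex_hull(points):
    """Andrew's monotone chain; returns hull vertices in counter-clockwise order."""
    pts = sorted(set(points))
    if len(pts) <= 2:
        return pts
    lower, upper = [], []
    for p in pts:
        while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)
    for p in reversed(pts):
        while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)
    return lower[:-1] + upper[:-1]


def clip_convex(subject, clipper):
    """Sutherland-Hodgman: part of `subject` inside the convex CCW polygon `clipper`."""
    output = list(subject)
    for a, b in zip(clipper, clipper[1:] + clipper[:1]):
        inputs, output = output, []
        for i, cur in enumerate(inputs):
            prev = inputs[i - 1]
            cp, cc = cross(a, b, prev), cross(a, b, cur)
            if (cc >= 0) != (cp >= 0):
                # The edge prev -> cur crosses the clipping line a -> b.
                t = cp / (cp - cc)
                output.append((prev[0] + t * (cur[0] - prev[0]),
                               prev[1] + t * (cur[1] - prev[1])))
            if cc >= 0:
                output.append(cur)
    return output


def giou(poly_a, poly_b):
    """GIoU of two convex polygons, with C taken as the convex hull of both."""
    a, b = convex_hull(poly_a), convex_hull(poly_b)
    inter = polygon_area(clip_convex(a, b))
    union = polygon_area(a) + polygon_area(b) - inter
    c_area = polygon_area(convex_hull(a + b))
    return inter / union - (c_area - union) / c_area


def giou_loss(pred, target):
    """L_giou = 1 - GIoU, as in Eq. (4)."""
    return 1.0 - giou(pred, target)


square = [(0, 0), (2, 0), (2, 2), (0, 2)]
shifted = [(1, 1), (3, 1), (3, 3), (1, 3)]
print(giou(square, shifted))  # IoU 1/7 minus penalty 1/8, i.e. 1/56
```

For two disjoint unit squares one unit apart the IoU is 0 but GIoU is -1/3, so the loss still distinguishes near misses from distant ones; this is the property the text relies on for small and elongated text.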
4 Experimental Results and Analysis
We introduce the datasets used for validation, describe the implementation details, and conduct extensive experiments on various text detection benchmarks. Our method is compared with state-of-the-art (SOTA) approaches, and ablation studies and further comparisons demonstrate its effectiveness and generalizability.
4.1 Datasets
The datasets used in this study are as follows.
SynthText [18]. This synthetic dataset consists of over 800,000 images created by pasting words onto natural scenes according to specific rules. Its large size makes it particularly effective for pre-training, providing valuable prior knowledge for real-world datasets.
Total-Text [19]. Total-Text comprises 1,255 training images and 300 test images containing text of diverse shapes, including horizontal, multi-oriented, and curved text. Text regions are annotated at the word level with polygons.
ICDAR2015 [20]. ICDAR2015 is widely used in text detection. It was introduced in the ICDAR 2015 Robust Reading Competition for detecting incidental scene text and comprises 1,500 images with a resolution of 720 × 1280 captured with Google Glass; 1,000 images are used for training and the remaining 500 for testing. The annotations are word-level, and text locations are indicated with quadrilateral boxes.
MLT-2017 [21]. MLT-2017 is a comprehensive multilingual scene text dataset covering nine languages and six scripts. It consists of 7,200 training images, 1,800 validation images, and 9,000 test images, all annotated with the four vertices of quadrilaterals that delimit text regions.
4.2 Implementation Details
We implemented the proposed framework in PyTorch and conducted all experiments on a workstation equipped with four NVIDIA RTX 3090 GPUs.
During training, all methods were trained using a batch size of 8 with the Adam optimizer.
Our models were pre-trained on the synthetic dataset [18] for 200 epochs and then fine-tuned on each real-world dataset for another 200 epochs. We adopted a "poly" learning rate policy, which gradually decays the learning rate during training, with initial learning rates of 1 × 10⁻³ for pre-training and 5 × 10⁻⁴ for fine-tuning.
4.3 Ablation Study
We performed a series of ablation experiments on the MLT-2017 dataset to evaluate the crucial components of our approach. Table 1 presents the quantitative results of four variants of our model: (a) the baseline, (b) the baseline with GSFL, (c) the baseline with TRM, and (d) our proposed GTNet detector.
(a) Baseline. The baseline is a classical detector with a convolutional neural network (CNN) encoder-decoder architecture (i.e., DB [2]) and a ResNet-50 backbone. It achieves an F1-score of 74.7% on the MLT-2017 dataset.

Table 1. Ablation study on the MLT-2017 validation set. The results are obtained by four variants of our proposed GTNet. The best result is highlighted in bold.

| Module | (a) | (b) | (c) | (d) |
|---|---|---|---|---|
| GSFL | | √ | | √ |
| TRM | | | √ | √ |
| F1-score (%) | 74.7 | 75.2 ↑0.5 | 76.3 ↑1.6 | **78.2** ↑3.5 |
Fig. 3. Qualitative results w/o GSFL and TRM modules. Pictures are sampled from Total-Text.
(b) Baseline with GSFL. Replacing the CNN backbone of the baseline detector with the proposed graph-based encoder yields the Baseline with GSFL model, which achieves an F1-score of 75.2% on MLT-2017, a 0.5-point increase over the baseline. We attribute this improvement to the graph-based encoder's ability to produce better representations that handle extreme aspect ratios: the encoder aggregates spatial features from adjacent scales, which enhances the deteriorated spatial features of text instances.
(c) Baseline with TRM. Using the Transformer decoder in the baseline detector increases the F1-score on MLT-2017 from 74.7% to 76.3%. The Transformer blocks let the algorithm capture richer global information and directly regress the coordinates and rotation angle for each query, which improves performance.
(d) GTNet. GTNet enhances the baseline detector with both the graph-based encoder and the Transformer decoder. It achieves an F1-score of 78.2% on MLT-2017, a substantial improvement of 3.5 points over the baseline. In GTNet, the graph-based encoder and the Transformer decoder work together to aggregate the discriminative features of text instances and enhance the degraded features present in natural scene images, and this joint learning improves detection performance. These results validate our design: by enlarging the receptive field, GTNet obtains richer fused features and more discriminative representations, which benefits text detection. For further inspection, we visualize our detection results in Fig. 3.
4.4 Comparisons with Previous Methods
To assess the efficacy and adaptability of GTNet, we test it on multiple text detection datasets, including ICDAR2015 and Total-Text. The results in Tables 2 and 3 verify GTNet's capacity to handle various complex text types.

Table 2. Comparison of the proposed GTNet method with other methods on the ICDAR 2015 dataset.

| Methods | Precision (%) | Recall (%) | F1-score (%) |
|---|---|---|---|
| PSENet [5] | 86.9 | 84.5 | 85.7 |
| LOMO [15] | 91.3 | 83.5 | 87.2 |
| PAN [22] | 84.0 | 81.9 | 82.9 |
| DB [2] | 91.8 | 83.2 | 87.3 |
| DRRG [23] | 88.5 | 84.7 | 86.6 |
| Raisi et al. [24] | 89.8 | 78.3 | 83.7 |
| DBNet++ [25] | 90.9 | 83.9 | 87.3 |
| GTNet (Ours) | 89.7 | 86.1 | 87.9 |
Multi-oriented text detection. In this experiment, GTNet is compared with mainstream methods based on bounding box regression and image segmentation on the ICDAR 2015 dataset, which contains challenging text instances such as multi-oriented, small, blurry, and narrow text. The experimental comparison results are shown in Table 2.
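The F1-scores in Table 2 are the standard harmonic mean of precision and recall, which explains how GTNet reaches the highest F1 despite lower precision than DB: its recall is higher. The short check below recomputes F1 from three reported precision/recall pairs; rounded to one decimal, it reproduces the reported 87.3, 87.3, and 87.9. The helper name is illustrative.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (inputs and result in percent)."""
    return 2 * precision * recall / (precision + recall)


# Precision/recall pairs as reported in Table 2 (ICDAR 2015).
reported = {"DB [2]": (91.8, 83.2), "DBNet++ [25]": (90.9, 83.9), "GTNet (Ours)": (89.7, 86.1)}
for name, (p, r) in reported.items():
    print(f"{name}: F1 = {f1_score(p, r):.1f}")
```

Because the harmonic mean is pulled toward the smaller of the two values, raising recall from 83.2 to 86.1 outweighs the 2.1-point drop in precision.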
Table 2 shows that GTNet achieves the best detection performance, with an F1-score of 87.9% on the ICDAR 2015 dataset. Compared with the segmentation-based scene text detection methods PSENet [5] and DB [2], GTNet improves the F1-score by 2.2% and 0.6%, respectively. PSENet employs a U-Net-like CNN for feature extraction and obtains the predicted text regions through post-processing that includes binarization. DB mitigates the instability and noise of traditional binarization and preserves more image details through differentiable binarization: it learns the binarization threshold jointly with the network through the training objective, which improves performance. Nonetheless, GTNet still surpasses DB. GTNet also clearly outperforms conventional bounding-box-regression methods for scene text detection; for instance, it exceeds the F1-score of the DETR-based method of Raisi et al. [24] by 4.2%. Whereas that method struggles to detect the blurry and small text that is prevalent in ICDAR 2015, GTNet handles such text well, confirming its effectiveness in these scenarios. In summary, GTNet outperforms other leading scene text detection methods on the ICDAR 2015 dataset, which verifies the benefits of optimized feature extraction and enhanced global semantic information.
Table 3. Comparison results of the proposed GTNet method with other methods on the Total-Text dataset.
Methods              Precision(%)   Recall(%)   F1-score(%)
PAN [22]             89.3           81.0        85.0
DB [2]               87.1           82.5        84.7
ContourNet [26]      86.9           83.9        85.4
Raisi et al. [24]    90.9           83.8        87.2
TextBPN [27]         90.7           85.2        87.9
DBNet++ [25]         88.9           83.2        86.0
GTNet (Ours)         90.7           85.7        88.1
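Each F1-score in Table 3 is the harmonic mean of precision and recall, $F_1 = 2PR/(P + R)$. The short Python check below recomputes it from the tabulated values. Because precision and recall are themselves rounded to one decimal place, the recomputed F1-scores agree with the reported ones only to within about 0.1 (for example, PAN gives 84.95 against a reported 85.0). The function and variable names are illustrative.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall, both given in percent."""
    if precision + recall == 0:
        return 0.0
    return 2.0 * precision * recall / (precision + recall)


# (precision, recall, reported F1) copied from Table 3 (Total-Text).
TABLE3 = {
    "PAN [22]": (89.3, 81.0, 85.0),
    "DB [2]": (87.1, 82.5, 84.7),
    "ContourNet [26]": (86.9, 83.9, 85.4),
    "Raisi et al. [24]": (90.9, 83.8, 87.2),
    "TextBPN [27]": (90.7, 85.2, 87.9),
    "DBNet++ [25]": (88.9, 83.2, 86.0),
    "GTNet (Ours)": (90.7, 85.7, 88.1),
}

for name, (p, r, reported) in TABLE3.items():
    print(f"{name:18s} recomputed={f1_score(p, r):6.2f} reported={reported:5.1f}")
```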
A Graph-Transformer Network for Scene Text Detection
Curved text detection. The Total-Text dataset is a scene text dataset that contains a large number of irregular and curved texts. In this experiment, GTNet is compared on this dataset with currently popular methods based on bounding box regression and on image segmentation; the results are shown in Table 3. DBNet++ builds on the DB method with an adaptive scale fusion module, which improves the robustness of the detector by adaptively fusing features of different scales. As shown in Table 3, the F1-score of GTNet exceeds that of the segmentation-based DBNet++ by 2.1% on Total-Text, and GTNet achieves the highest F1-score among all compared methods on this dataset, which contains abundant curved and irregular text instances. Traditional scale-fused features are inadequate for such complex shapes, whereas GTNet locates them effectively by using Bézier curves to regress the vertices of text polygons.
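The paragraph above credits Bézier-curve regression of polygon vertices for GTNet's results on curved text, but this excerpt does not give the curve parameterization. The Python sketch below assumes, in the style of ABCNet [1], that a text instance is described by two cubic Bézier curves (upper and lower boundary, four control points each), and shows how polygon vertices follow from the control points through the Bernstein basis. The function names and sample coordinates are hypothetical.

```python
from math import comb

import numpy as np


def bezier_points(control: np.ndarray, num: int = 10) -> np.ndarray:
    """Sample `num` points on a Bezier curve with (degree + 1, 2) control points."""
    degree = control.shape[0] - 1
    t = np.linspace(0.0, 1.0, num)[:, None]  # (num, 1) curve parameters
    # Bernstein basis B_{i,n}(t) = C(n, i) t^i (1 - t)^(n - i), one column per control point.
    basis = np.hstack(
        [comb(degree, i) * t**i * (1.0 - t) ** (degree - i) for i in range(degree + 1)]
    )
    return basis @ control  # (num, 2) points on the curve


def bezier_text_polygon(top_ctrl, bottom_ctrl, num: int = 10) -> np.ndarray:
    """Closed polygon: upper boundary left-to-right, then lower boundary right-to-left."""
    top = bezier_points(np.asarray(top_ctrl, dtype=float), num)
    bottom = bezier_points(np.asarray(bottom_ctrl, dtype=float), num)
    return np.vstack([top, bottom[::-1]])


# Hypothetical cubic control points for an arched text line (pixel coordinates).
top_ctrl = [(0, 10), (30, 0), (70, 0), (100, 10)]
bottom_ctrl = [(0, 30), (30, 20), (70, 20), (100, 30)]
polygon = bezier_text_polygon(top_ctrl, bottom_ctrl, num=10)
print(polygon.shape)  # (20, 2)
```

A curve passes through its first and last control points, so only 2 × 4 control points need to be regressed per instance, however many polygon vertices are sampled afterwards.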
5 Conclusion
In this paper, we present GTNet, a novel Graph-Transformer architecture for scene text detection. It uses a graph convolutional network to encode images and a Transformer decoder to make NMS-free predictions, which lowers the computational time. By combining graph operations with a Transformer decoder, the method effectively extracts text instance features and provides a robust approach to scene text detection.
References
1. Liu, Y., Chen, H., Shen, C., He, T., Jin, L., Wang, L.: ABCNet: real-time scene text spotting with adaptive Bezier-curve network. CVPR, pp. 9809–9818 (2020)
2. Liao, M., Wan, Z., Yao, C., Chen, K., Bai, X.: Real-time scene text detection with differentiable binarization. AAAI, pp. 11474–11481 (2019)
3. Zhang, Z., Zhang, C., Shen, W., Yao, C., Liu, W., Bai, X.: Multi-oriented text detection with fully convolutional networks. CVPR, pp. 4159–4167 (2016)
4. Matas, J., Chum, O., Urban, M., Pajdla, T.: Robust wide-baseline stereo from maximally stable extremal regions. Image and Vision Computing (2004)
5. Wang, W., et al.: Shape robust text detection with progressive scale expansion network. CVPR, pp. 9336–9345 (2019)
6. Shi, B., Bai, X., Belongie, S.: Detecting oriented text in natural images by linking segments. CVPR, pp. 2550–2558 (2017)
7. Tang, J., Yang, Z., Wang, Y., Zheng, Q., Xu, Y., Bai, X.: SegLink++: detecting dense and arbitrary-shaped scene text by instance-aware component grouping. Pattern Recognition (2019)
8. Liao, M., Shi, B., Bai, X., Wang, X., Liu, W.: TextBoxes: a fast text detector with a single deep neural network. AAAI (2017)
9. Liao, M., Shi, B., Bai, X.: TextBoxes++: a single-shot oriented scene text detector. IEEE Trans. Image Process. 27(8), 3676–3690 (2018)
10. Liu, Y., Jin, L.: Deep matching prior network: toward tighter multi-oriented text detection. CVPR, pp. 1962–1969 (2017)
11. Zhou, X., et al.: EAST: an efficient and accurate scene text detector. CVPR, pp. 5551–5560 (2017)
12. Liao, M., Zhu, Z., Shi, B., Xia, G.-S., Bai, X.: Rotation-sensitive regression for oriented scene text detection. CVPR, pp. 5909–5918 (2018)
13. Ren, S., He, K., Girshick, R., Sun, J.: Faster R-CNN: towards real-time object detection with region proposal networks. IEEE Trans. Pattern Anal. Mach. Intell. 39(6), 1137–1149 (2017)
14. Liu, W., et al.: SSD: single shot multibox detector. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9905, pp. 21–37. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46448-0_2
15. Zhang, C., et al.: Look more than once: an accurate detector for text of arbitrary shapes. CVPR, pp. 10552–10561 (2019)
16. Zhu, Y., Chen, J., Liang, L., Kuang, Z., Jin, L., Zhang, W.: Fourier contour embedding for arbitrary-shaped text detection. CVPR, pp. 3123–3131 (2021)
17. Wu, Z., Jain, P., Wright, M., Mirhoseini, A., Gonzalez, J.E., Stoica, I.: Representing long-range context for graph neural networks with global attention. Adv. Neural Inf. Process. Syst. 34, 13266–13279 (2021)
18. Gupta, A., Vedaldi, A., Zisserman, A.: Synthetic data for text localisation in natural images. CVPR, pp. 2315–2324 (2016)
19. Chng, C.-K., Chan, C.S.: Total-Text: a comprehensive dataset for scene text detection and recognition. International Conference on Document Analysis and Recognition (2017)
20. Karatzas, D., Gomez-Bigorda, L., Nicolaou, A., Ghosh, S., Valveny, E.: ICDAR 2015 competition on robust reading. ICDAR, pp. 1156–1160 (2015)
21. Nayef, N., Fei, Y., Bizid, I., Choi, H., Ogier, J.M.: ICDAR2017 robust reading challenge on multi-lingual scene text detection and script identification (RRC-MLT). ICDAR, pp. 1454–1459 (2017)
22. Li, H., Xiong, P., An, J., Wang, L.: Pyramid attention network for semantic segmentation. arXiv:1805.10180 (2018)
23. Zhang, S.-X., et al.: Deep relational reasoning graph network for arbitrary shape text detection. CVPR, pp. 9699–9708 (2020)
24. Raisi, Z., Naiel, M.A., Younes, G., Wardell, S., Zelek, J.: Transformer-based text detection in the wild. CVPR, pp. 3162–3171 (2021)
25. Liao, M., Zou, Z., Wan, Z., Yao, C., Bai, X.: Real-time scene text detection with differentiable binarization and adaptive scale fusion. IEEE Trans. Pattern Anal. Mach. Intell. (2022)
26. Wang, Y., Xie, H., Zha, Z., Xing, M., Fu, Z., Zhang, Y.: ContourNet: taking a further step toward accurate arbitrary-shaped scene text detection. CVPR, pp. 11750–11759 (2020)
27. Zhang, S.-X., Zhu, X., Yang, C., Wang, H., Yin, X.-C.: Adaptive boundary proposal network for arbitrary shape text detection. ICCV, pp. 1305–1314 (2021)
Hessian Non-negative Hypergraph
Lingling Li1, Zihang Li2, Mingkai Wang3, Taisong Jin3, and Jie Liu4,5(B)
1 Henan Key Laboratory of General Aviation Technology, Zhengzhou University of Aeronautics, Zhengzhou 450046, Henan, China
2 School of Computing and Data Science, Xiamen University Malaysia, 43900 Jalan Sunsuria, Malaysia
3 School of Informatics, Xiamen University, Xiamen 361005, China
4 School of Information Science, North China University of Technology, Beijing 100144, China
5 China Language Intelligence Research Center, Capital Normal University, Beijing 100144, China
Abstract. Hypergraphs are a vital relationship-modeling tool applied to many learning tasks. Compared with a common graph, a hypergraph can model high-order relationships among more than two data samples, which gives it more flexibility and the capability to model complex data correlations. However, real-world image data inevitably lie on or close to a thin layer of a manifold while also containing noise. Such data disturb the hypergraph learning procedure and cause notable performance deterioration. To this end, we propose a novel Hessian non-negative hypergraph model. Specifically, Hessian energy regularized, non-negatively constrained data reconstruction is used to generate manifold-respecting and noise-resistant hyperedges. Extensive experiments on benchmark image datasets demonstrate that the proposed hypergraph model can significantly enhance the performance of hypergraph-based image clustering and classification.
Keywords: Hypergraph · Hessian energy · Non-negative · Data representation
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 691–704, 2023. https://doi.org/10.1007/978-981-99-4761-4_58
L. Li et al.
1 Introduction
The development of data acquisition technologies, such as digital cameras, has led to a rapid accumulation of high-dimensional visual data. Effective data analysis algorithms are designed to extract important and useful information that can enhance quality of life and spawn new business models. Among the various algorithms, graph-based learning models [1] have received considerable attention. Hypergraphs extend graphs by modeling high-order relationships among data samples, and hypergraph-based learning methods have therefore achieved promising performance in several visual applications, such as topic-sensitive influencer mining [2], image segmentation [3, 4], image clustering [5–7], and image classification [8–15]. For hypergraph-based learning tasks, an informative hypergraph that effectively captures the underlying manifold of the data must be constructed. A neighborhood-based strategy [2, 4] is commonly used to generate the hyperedge set, where each hyperedge can model complex correlations among the data. In detail, this strategy takes each sample as a central node and selects its k nearest neighbors to form a hyperedge; the entire hypergraph is then constructed from the resulting hyperedge set. However, the neighborhood-based strategy is sensitive to the neighborhood size and is only suitable for uniformly distributed data. Moreover, real-world data are usually contaminated by noise, and noisy samples may be linked into a set of hyperedges, which dramatically degrades the learning performance of neighborhood-based models.
To tackle the noise issue, some recent studies use data representation to generate noise-resistant hyperedges. For instance, ℓ1-Hypergraph [8] uses sparse representation to reconstruct each sample and selects the samples with nonzero sparse codes to generate each hyperedge. ℓ1-Hypergraph has further been extended by varying the regularization values of the sparse representation and by adopting different hyperedge weighting schemes [9]. ℓ2-Hypergraph [5] uses ridge regression instead of sparse coding for hyperedge generation. To be robust to significant data errors such as sample-specific outliers, Elastic-net Hypergraph [7] generates flexible hyperedges by solving a matrix elastic-net problem. Because only the samples with large representation coefficients are linked to generate hyperedges, representation-based hypergraph models are, to some extent, insensitive to noise. However, the existing representation-based hypergraph construction methods still cannot sufficiently reflect the data manifold [5–7]. Specifically, real-world image data lie on or close to a thin layer of a manifold embedded in a high-dimensional space, and at the same time the sampled image data may be contaminated by noise.
For such image data, the existing hypergraph construction models, including both neighborhood-based and representation-based methods, cannot effectively capture the high-order correlations among the images, which degrades learning performance. To this end, we address the practical problem of hypergraph learning with noisy image data sampled from an underlying manifold. The contributions of this article are summarized as follows: 1) We propose a novel hypergraph learning method that introduces Hessian energy into the process of hyperedge generation. Hessian energy encourages the data representation to vary linearly along the geodesics of the underlying image manifold. 2) The proposed method generates a set of hyperedges based on Hessian energy regularized data representation. Specifically, a non-negative constraint is imposed on the data representation, and the image samples with the largest representation coefficients are then chosen to generate the hyperedges. 3) We conduct extensive experiments on noisy image datasets, demonstrating the superior learning performance of the proposed method over the baseline methods.
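The neighborhood-based strategy described at the start of the Introduction can be sketched in a few lines of NumPy: every sample acts as a central node, and its hyperedge contains it together with its k nearest neighbors. This is a minimal sketch; the Euclidean metric, the column-per-hyperedge layout of the incidence matrix, and the function name are assumptions for illustration. Because membership depends only on distances, a noisy sample that happens to lie close to a central node is linked into its hyperedge, which is the sensitivity discussed above.

```python
import numpy as np


def knn_hypergraph_incidence(X: np.ndarray, k: int) -> np.ndarray:
    """Incidence matrix of a neighborhood-based hypergraph.

    X is (n, m) with one sample per column. Hyperedge e_i contains sample i
    (the central node) and its k nearest neighbors in Euclidean distance.
    Returns H of shape (m, m) with H[v, i] = 1 if node v belongs to e_i.
    """
    m = X.shape[1]
    sq_norms = np.sum(X**2, axis=0)
    dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * X.T @ X  # squared distances
    np.fill_diagonal(dist, -np.inf)  # guarantees the central node is ranked first
    H = np.zeros((m, m))
    for i in range(m):
        members = np.argsort(dist[i])[: k + 1]  # central node + k neighbors
        H[members, i] = 1.0
    return H


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 12))  # 12 synthetic samples in R^5
H = knn_hypergraph_incidence(X, k=3)
print(H.sum(axis=0))  # every hyperedge has k + 1 = 4 nodes
```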
2 Related Works
Hypergraphs have been used to model complex relationships among data samples in clustering tasks. For instance, Fang et al. [2] proposed a topic-sensitive influencer mining framework that uses a hypergraph for influence estimation. Ducournau et al. [3] exploited directed hypergraphs to solve an image segmentation problem.
Furthermore, Kim et al. [4] suggested a hypergraph-based image segmentation method that employs higher-order correlation clustering. Jin et al. [5] proposed ℓ2-Hypergraph, where the hyperedges are generated by solving a ridge regression problem. To be robust to non-Gaussian noise, Jin et al. [6] put forth a correntropy-induced hypergraph model, which uses the correntropy-induced metric to measure data errors and imposes a low-rank constraint on the data representation. Liu et al. [7] proposed the Elastic-net Hypergraph, which generates a set of hyperedges by solving a matrix elastic-net problem. Besides clustering, hypergraphs are also used for image classification and image ranking tasks. For instance, Zhang et al. [18] revisited semi-supervised learning on hypergraphs, confirming the efficacy of the interval approach on hypergraphs. Yu et al. [17] formulated a hypergraph-based semi-supervised learning method in which the hyperedge weights are adaptively coordinated. An et al. [19] proposed a multi-hypergraph-based person re-identification method, which uses multiple hypergraphs to fuse complementary information embedded in different feature representations. Hypergraphs are also used for other learning tasks. For instance, Ji et al. [20] used a bi-layer multimodal hypergraph learning method for robust sentiment prediction of multimodal microblog tweets. Jin et al. [16] incorporated a neighborhood hypergraph into the low-rank matrix factorization framework, which makes the derived data representation capture the underlying manifold. Zhang et al. [9] used a hypergraph learning approach for feature selection that simultaneously learns the hyperedge weights and selects features. All these methods demonstrate that hypergraphs can significantly enhance the performance of machine learning approaches by capturing high-order relationships among data samples. This article therefore introduces the Hessian energy operator into hypergraph learning, which results in a manifold-respecting hypergraph.
3 Methodology
Given a data matrix $X = [x_1, x_2, \cdots, x_m] \in \mathbb{R}^{n\times m}$, each column $x_i \in \mathbb{R}^n$ is an image data vector, which corresponds to a node of the hypergraph. To model high-order image correlations, each hyperedge of a hypergraph can link a different number of images. A hypergraph is thus represented as $HG = (V, E, W)$, where $V$ is the set of images, known as nodes or vertices, and $E$ is a set of non-empty subsets of $V$, called hyperedges; a hyperedge $e$ usually contains more than two nodes. A hyperedge weight matrix $W$ models the different importance of the hyperedges (Fig. 1).
3.1 Hessian Energy-based Manifold Preserving
Considering that $f(x_i)$ is a function that maps the image $x_i$ to the corresponding data representation, we define $\|f_k\|^2_{\mathcal{M}}$ to measure the smoothness of $f_k(x_i)$ along the geodesics of the image manifold $\mathcal{M}$. The Hessian energy [21] is an important tool for measuring the smoothness of a function and is defined as
$$\|f_k\|^2_{\mathcal{M}} = \int_{\mathcal{M}} \left\|\nabla_a \nabla_b f_k\right\|^2_{T_x^*\mathcal{M}\otimes T_x^*\mathcal{M}} \, dV(x) \tag{1}$$
Fig. 1. The framework of the proposed method. For noisy image data (images contaminated by Gaussian noise), the images are reconstructed linearly, and the Frobenius norm is used to measure the reconstruction errors. Both the Hessian energy and the Frobenius norm are used to regularize the data representation. Finally, a non-negative constraint is imposed on the data representation to derive a parts-based data representation, and the images with large representation coefficients are used to generate manifold-respecting and noise-resistant hyperedges.
where $\nabla_a\nabla_b f$ is the second covariant derivative of $f$ and $dV(x)$ is the natural volume element. $\|\nabla_a\nabla_b f\|^2$ is the $\ell_2$-norm of the second covariant derivative, calculated as the Frobenius norm of the Hessian of $f$ in normal coordinates. Given normal coordinates $x_r$ at $x_i$ and the function values $f_k(x_j)$ on $N_k(x_i)$, we want an operator $H$ that estimates the Hessian of $f_k$ at $x_i$, i.e.,
$$\left.\frac{\partial^2 f_k}{\partial x_r\,\partial x_s}\right|_{x_i} \approx \sum_{j=1}^{p} H^{(i)}_{rsj}\, f_k(x_j) \tag{2}$$
Furthermore, we can compute the Hessian operator $H$ via a second-order polynomial fit in normal coordinates, i.e.,
$$q(x) = f_k(x_i) + \sum_{r=1}^{n} B_r x_r + \sum_{r=1}^{n}\sum_{s=r}^{n} A_{rs}\, x_r x_s \tag{3}$$
where $q(x)$ is the second-order polynomial whose zeroth-order term is fixed at $f_k(x_i)$. In the limit as the neighborhood size tends to zero, it becomes the second-order Taylor expansion of $f_k$ around $x_i$, i.e., $B_r = \left.\frac{\partial f_k}{\partial x_r}\right|_{x_i}$ and $A_{rs} = \left.\frac{\partial^2 f_k}{\partial x_r\,\partial x_s}\right|_{x_i}$.
Because the smoothness is measured locally, we first estimate the local tangent space: we perform PCA on the k nearest neighbors of each image, and the t leading eigenvectors form an orthogonal basis of the local tangent space. We then fit the second-order polynomial by linear least squares, which results in
$$\arg\min_{w\in\mathbb{R}^{P}} \sum_{j=1}^{k}\left(f_k(x_j) - f_k(x_i) - (\Psi w)_j\right)^2 \tag{4}$$
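where $\Psi \in \mathbb{R}^{k\times P}$ is the design matrix with $P = m + \frac{m(m+1)}{2}$. The last $\frac{m(m+1)}{2}$ components of $w$ correspond to the coefficients $A_{rs}$ of the polynomial. Based on Eq. (4), we can derive the desired form of $H^{(i)}_{rsj}$. The squared Frobenius norm of the Hessian of $f_k$ at $x_i$ is then estimated as
$$\left\|\nabla_a\nabla_b f_k\right\|^2 \approx \sum_{r,s=1}^{n}\left(\sum_{\alpha=1}^{p} H^{(i)}_{rs\alpha}\, v_{\alpha k}\right)^2 = \sum_{\alpha,\beta=1}^{p} v_{\alpha k}\, v_{\beta k}\, L^{(i)}_{\alpha\beta} \tag{5}$$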
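where $L^{(i)}_{\alpha\beta} = \sum_{r,s=1}^{n} H^{(i)}_{rs\alpha} H^{(i)}_{rs\beta}$. The estimated Hessian energy $R_k$ is then the sum over all data points:
$$R_k = \sum_{i=1}^{m}\sum_{r,s=1}^{n}\left(\left.\frac{\partial^2 f_k}{\partial x_r\,\partial x_s}\right|_{x_i}\right)^2 = \sum_{i=1}^{m}\ \sum_{\alpha\in N_p(x_i)}\ \sum_{\beta\in N_p(x_i)} v_{\alpha k}\, v_{\beta k}\, L^{(i)}_{\alpha\beta} \tag{6}$$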
where $m$ is the number of data points.
3.2 Hessian Regularized Non-negative Data Representation
Based on the self-expression property of image data, we reconstruct each image from the entire dataset; that is, each image $x_i$ is represented by a new coordinate vector with respect to the whole image dataset. We further add the Hessian energy regularization term to this data reconstruction framework, which yields the following optimization problem:
$$\min_{F}\ \|X - XF\|_F^2 + \lambda\sum_k R_k + \gamma\|F\|_F^2 = \min_{F}\ \|X - XF\|_F^2 + \lambda\,\mathrm{Tr}\left(F L_{Hess} F^T\right) + \gamma\|F\|_F^2 \quad \text{s.t. } F \ge 0 \tag{7}$$
where $\mathrm{Tr}(\cdot)$ denotes the trace of a matrix and $L_{Hess} = \sum_{i=1}^{m} L^{(i)}$, with each $L^{(i)}$ placed into an $m \times m$ matrix by zero-padding.
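The estimation steps of Eqs. (2)–(6) can be sketched in NumPy as follows. This is a minimal sketch in the spirit of the Hessian energy estimator of [21], not the authors' implementation, and it rests on several assumptions: neighborhoods are Euclidean kNN sets that include $x_i$ itself, the tangent basis comes from PCA (an SVD of the mean-centered neighborhood), the polynomial of Eq. (3) is fitted with a pseudo-inverse, and each $L^{(i)}$ is zero-padded into the $m \times m$ matrix before summation. The tangent dimension `t` and all names are illustrative. The sketch uses the standard Taylor convention, in which the coefficient of $x_r^2$ is half of $\partial^2 f/\partial x_r^2$, so diagonal Hessian entries are $2A_{rr}$.

```python
import numpy as np


def hessian_energy_matrix(X: np.ndarray, k: int, t: int) -> np.ndarray:
    """Estimate L_Hess (m x m) such that f @ L_Hess @ f approximates the Hessian energy of f.

    X: (n, m) data matrix with one sample per column.
    k: neighborhood size; t: assumed dimension of the local tangent space.
    """
    m = X.shape[1]
    sq_norms = np.sum(X**2, axis=0)
    dist = sq_norms[:, None] + sq_norms[None, :] - 2.0 * X.T @ X
    np.fill_diagonal(dist, -np.inf)          # each point comes first in its own neighborhood
    rows, cols = np.triu_indices(t)           # index pairs (r, s) with r <= s
    L = np.zeros((m, m))
    for i in range(m):
        idx = np.argsort(dist[i])[: k + 1]    # x_i followed by its k nearest neighbors
        nbrs = X[:, idx]
        # Tangent basis: the t leading principal directions of the neighborhood.
        centered = nbrs - nbrs.mean(axis=1, keepdims=True)
        U = np.linalg.svd(centered, full_matrices=False)[0][:, :t]
        coords = U.T @ (nbrs - X[:, [i]])     # (t, k+1) local coordinates around x_i
        # Design matrix: linear terms x_r, then quadratic terms x_r * x_s (r <= s).
        design = np.vstack([coords, coords[rows] * coords[cols]]).T  # (k+1, P)
        # Least squares w = pinv(design) @ (f_nbrs - f_i); C subtracts f_i = f_nbrs[0].
        C = np.eye(k + 1)
        C[:, 0] -= 1.0
        A = (np.linalg.pinv(design) @ C)[t:]  # quadratic coefficients as linear maps of f_nbrs
        # Hessian operator: H_rs = A_rs off the diagonal and 2 * A_rr on it.
        Hop = np.zeros((t, t, k + 1))
        Hop[rows, cols] = A
        Hop[cols, rows] = A
        diag = np.arange(t)
        Hop[diag, diag] *= 2.0
        Hflat = Hop.reshape(t * t, k + 1)
        # L^(i)_{ab} = sum_{r,s} H_rsa H_rsb, zero-padded into the m x m matrix.
        L[np.ix_(idx, idx)] += Hflat.T @ Hflat
    return L


# Points on a flat 2-D patch embedded in R^3.
rng = np.random.default_rng(0)
uv = rng.uniform(-1.0, 1.0, size=(2, 60))
basis = np.linalg.qr(rng.normal(size=(3, 2)))[0]   # orthonormal basis of a plane
X = basis @ uv
L_hess = hessian_energy_matrix(X, k=8, t=2)
linear = 3.0 * uv[0] - 2.0 * uv[1] + 1.0
quadratic = uv[0] ** 2
print(linear @ L_hess @ linear, quadratic @ L_hess @ quadratic)
```

On a flat patch, affine functions have zero Hessian energy, while $f = u^2$ has a constant Hessian with Frobenius norm 2, so its estimated energy is about 4 per sample (about 240 here). For a representation matrix `F`, the regularizer of Eq. (7) would then be `np.trace(F @ L_hess @ F.T)`.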
3.3 Hyperedge Generation
Given a node set and a hyperedge set, an incidence matrix $H$ indicates whether a node belongs to a hyperedge: $H(v, e) = 1$ if $v \in e$, and $H(v, e) = 0$ otherwise. Hyperedges are flexible enough to model high-order image correlations, because each one can link an image to a different number of other images. We compute the incidence matrix, which encodes the relationship between each image and each hyperedge, as
$$H(v_j, e_i) = \begin{cases} 1, & \text{if } f_{ij} \in \mathrm{top}(q) \\ 0, & \text{otherwise} \end{cases} \tag{8}$$
where $e_i$ is the hyperedge associated with data point $x_i$, $f_{ij}$ is the $j$-th representation coefficient of $x_i$, and $\mathrm{top}(q)$ is the set of the $q$ largest representation coefficients of $x_i$.
Finally, we present the weighting scheme for a given hyperedge. In our method, the weight of hyperedge $e_i$ is the sum of the similarities between $x_i$ and the other images in $e_i$, where the similarity between two images is the dot product of their representation vectors. Following [6], the weight of the hyperedge is
$$w(e_i) = \sum_{v_j \in e_i,\ j \neq i} S(f_i, f_j) \tag{9}$$
where $S(f_i, f_j) = \langle f_i, f_j\rangle$ measures the similarity between $x_i$ and $x_j$.
3.4 The Proposed Algorithm
Algorithm 1 lists the main procedure for constructing a Hessian-induced non-negative hypergraph.
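Eqs. (8) and (9) translate directly into a few lines of NumPy. The sketch below illustrates only the hyperedge-generation step, not Algorithm 1 as a whole. It assumes that column $i$ of `F` holds the representation $f_i$ of $x_i$ (matching $X \approx XF$), uses $q = 5$ by default as in Sect. 5.5, and does not force $x_i$ into its own hyperedge, since Eq. (8) does not state this. Function and variable names are hypothetical.

```python
import numpy as np


def build_hyperedges(F: np.ndarray, q: int = 5):
    """Incidence matrix and hyperedge weights from a representation matrix.

    F is (m, m); column F[:, i] is the representation f_i of sample x_i.
    Hyperedge e_i links the samples with the q largest coefficients of f_i (Eq. (8)),
    and its weight sums <f_i, f_j> over members j != i (Eq. (9)).
    """
    m = F.shape[1]
    H = np.zeros((m, m))      # H[v, e] = 1 if node v belongs to hyperedge e
    w = np.zeros(m)
    S = F.T @ F               # S[i, j] = <f_i, f_j>
    for i in range(m):
        members = np.argsort(-F[:, i])[:q]   # indices of the q largest coefficients
        H[members, i] = 1.0
        w[i] = sum(S[i, j] for j in members if j != i)
    return H, w


# Toy non-negative representation of four samples, with q = 2.
F = np.array([[0.0, 0.9, 0.1, 0.0],
              [0.8, 0.0, 0.0, 0.2],
              [0.3, 0.1, 0.0, 0.7],
              [0.0, 0.0, 0.6, 0.0]])
H, w = build_hyperedges(F, q=2)
print(H)
print(w)
```

Here hyperedge $e_0$ contains $x_1$ and $x_2$, and its weight is $\langle f_0, f_1\rangle + \langle f_0, f_2\rangle = 0.03 + 0 = 0.03$.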
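4 Optimization
4.1 The Update Rule
Noting that $\|A\|_F^2 = \mathrm{Tr}(AA^T)$, we have
$$\begin{aligned} O &= \mathrm{Tr}\left((X - XF)(X - XF)^T\right) + \lambda\,\mathrm{Tr}\left(F L_{Hess} F^T\right) + \gamma\,\mathrm{Tr}\left(FF^T\right) \\ &= \mathrm{Tr}\left(XX^T - 2XFX^T + XFF^TX^T + \lambda F L_{Hess} F^T + \gamma FF^T\right) \end{aligned} \tag{10}$$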
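Let $\phi_{kj}$ be the Lagrange multiplier for the constraint $f_{kj} \ge 0$, and define the matrix $\Phi = [\phi_{kj}]$. The Lagrange function $Q$ is
$$Q = \mathrm{Tr}\left(XX^T - 2XFX^T + XFF^TX^T + \lambda F L_{Hess} F^T\right) + \gamma\,\mathrm{Tr}\left(FF^T\right) + \mathrm{Tr}\left(\Phi F^T\right) \tag{11}$$
The partial derivative of $Q$ with respect to $F$ is
$$\frac{\partial Q}{\partial F} = -2X^TX + 2X^TXF + 2\lambda F L_{Hess} + 2\gamma F + \Phi \tag{12}$$
Using the KKT condition $\phi_{kj} f_{kj} = 0$, we obtain the following equation for $f_{kj}$:
$$(X^TXF)_{kj}\, f_{kj} + \lambda (F L_{Hess})_{kj}\, f_{kj} + \gamma f_{kj}\, f_{kj} = (X^TX)_{kj}\, f_{kj} \tag{13}$$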
We adopt a trick similar to that in [23] to solve the optimization problem. We introduce
$$L_{Hess} = L_{Hess}^{+} - L_{Hess}^{-} \tag{14}$$
where $(L_{Hess}^{+})_{ij} = \frac{|(L_{Hess})_{ij}| + (L_{Hess})_{ij}}{2}$ and $(L_{Hess}^{-})_{ij} = \frac{|(L_{Hess})_{ij}| - (L_{Hess})_{ij}}{2}$. Substituting Eq. (14) into Eq. (13), we obtain
$$-(X^TX)_{kj}\, f_{kj} - \lambda (F L_{Hess}^{-})_{kj}\, f_{kj} + (X^TXF)_{kj}\, f_{kj} + \lambda (F L_{Hess}^{+})_{kj}\, f_{kj} + \gamma f_{kj}\, f_{kj} = 0 \tag{15}$$
Equation (15) leads to the following update rule:
$$f_{kj} \leftarrow f_{kj}\,\frac{\left(X^TX + \lambda F L_{Hess}^{-}\right)_{kj}}{\left(X^TXF + \lambda F L_{Hess}^{+} + \gamma F\right)_{kj}} \tag{16}$$
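The update of Eq. (16) can be sketched in NumPy as below. This is a hedged illustration, not the authors' code: the positive random initialization, the iteration count, the default $\lambda$ and $\gamma$, and the small `eps` added to the denominator are assumptions. It also assumes non-negative data (e.g., raw pixel intensities), so that $X^TX \ge 0$ and the numerator of Eq. (16) stays non-negative; the bound in Eq. (24) relies on the same property. Any symmetric matrix can stand in for $L_{Hess}$; the toy run uses a path-graph Laplacian.

```python
import numpy as np


def objective(X, F, L, lam, gamma):
    """O(F) of Eq. (10): ||X - XF||_F^2 + lam * Tr(F L F^T) + gamma * ||F||_F^2."""
    residual = X - X @ F
    return np.sum(residual**2) + lam * np.trace(F @ L @ F.T) + gamma * np.sum(F**2)


def hessian_nn_representation(X, L, lam=0.1, gamma=0.1, iters=200, seed=0, eps=1e-12):
    """Run the multiplicative update of Eq. (16) from a random positive start.

    Returns F and the objective value recorded after every iteration.
    """
    m = X.shape[1]
    F = np.random.default_rng(seed).uniform(0.1, 1.0, size=(m, m))
    L_pos = (np.abs(L) + L) / 2.0   # L^+ of Eq. (14)
    L_neg = (np.abs(L) - L) / 2.0   # L^- of Eq. (14)
    K = X.T @ X
    history = [objective(X, F, L, lam, gamma)]
    for _ in range(iters):
        numer = K + lam * F @ L_neg
        denom = K @ F + lam * F @ L_pos + gamma * F + eps   # eps guards against division by zero
        F = F * numer / denom                              # element-wise update, F stays >= 0
        history.append(objective(X, F, L, lam, gamma))
    return F, history


# Toy run: non-negative data and a path-graph Laplacian standing in for L_Hess.
rng = np.random.default_rng(1)
X = rng.uniform(size=(20, 10))
adj = np.diag(np.ones(9), 1) + np.diag(np.ones(9), -1)
L = np.diag(adj.sum(axis=1)) - adj
F, history = hessian_nn_representation(X, L, iters=100)
print(history[0], history[-1])
```

The recorded `history` makes it easy to inspect empirically the non-increasing behaviour argued in Sect. 4.2.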
4.2 Convergence Analysis
The update rule of Eq. (16) is applied iteratively to update $F$. In the following, we prove that the objective function is non-increasing under this update rule.
Definition: $G(f, f')$ is an auxiliary function for $g(f)$ if the conditions
$$G(f, f') \ge g(f), \qquad G(f, f) = g(f) \tag{17}$$
are satisfied.
Lemma: If $G$ is an auxiliary function of $g$, then $g$ is non-increasing under the update
$$f^{(t+1)} = \arg\min_{f} G\left(f, f^{(t)}\right) \tag{18}$$
Proof:
$$g\left(f^{(t+1)}\right) \le G\left(f^{(t+1)}, f^{(t)}\right) \le G\left(f^{(t)}, f^{(t)}\right) = g\left(f^{(t)}\right) \tag{19}$$
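For each element $f_{kj}$ of $F$, let $g_{kj}$ denote the part of the objective function that depends only on $f_{kj}$. Its first and second derivatives are
$$g'_{kj} = \left(\frac{\partial O}{\partial F}\right)_{kj} = \left(-2X^TX + 2X^TXF + 2\lambda F L_{Hess} + 2\gamma F\right)_{kj} \tag{20}$$
$$g''_{kj} = 2(X^TX)_{kk} + 2\lambda (L_{Hess})_{jj} + 2\gamma \tag{21}$$
Lemma: The function
$$G\left(f, f_{kj}^{(t)}\right) = g_{kj}\left(f_{kj}^{(t)}\right) + g'_{kj}\left(f_{kj}^{(t)}\right)\left(f - f_{kj}^{(t)}\right) + \frac{\left(X^TXF + \lambda F L_{Hess}^{+} + \gamma F\right)_{kj}}{f_{kj}^{(t)}}\left(f - f_{kj}^{(t)}\right)^2 \tag{22}$$
is an auxiliary function for $g_{kj}$, the part of the objective function that depends only on $f_{kj}$.
Proof: Clearly $G(f, f) = g_{kj}(f)$, so we only need to prove that $G(f, f_{kj}^{(t)}) \ge g_{kj}(f)$. The Taylor series expansion of $g_{kj}(f)$ is
$$g_{kj}(f) = g_{kj}\left(f_{kj}^{(t)}\right) + g'_{kj}\left(f_{kj}^{(t)}\right)\left(f - f_{kj}^{(t)}\right) + \left[(X^TX)_{kk} + \lambda (L_{Hess})_{jj} + \gamma\right]\left(f - f_{kj}^{(t)}\right)^2 \tag{23}$$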
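We have
$$(X^TXF)_{kj} = \sum_{l=1}^{K} (X^TX)_{kl}\, f_{lj}^{(t)} \ge f_{kj}^{(t)}\, (X^TX)_{kk} \tag{24}$$
and
$$(F L_{Hess}^{+})_{kj} = \sum_{l=1}^{N} f_{kl}^{(t)}\, (L_{Hess}^{+})_{lj} \ge f_{kj}^{(t)}\, (L_{Hess}^{+})_{jj} \ge f_{kj}^{(t)}\, (L_{Hess}^{+} - L_{Hess}^{-})_{jj} = f_{kj}^{(t)}\, (L_{Hess})_{jj} \tag{25}$$
Then we arrive at
$$\frac{\left(X^TXF + \lambda F L_{Hess}^{+} + \gamma F\right)_{kj}}{f_{kj}^{(t)}} \ge (X^TX)_{kk} + \lambda (L_{Hess})_{jj} + \gamma.$$
Thus, $G(f, f_{kj}^{(t)}) \ge g_{kj}(f)$.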
Lemma: The objective function in Eq. (10) is non-increasing under the update rule in Eq. (16).
Proof: Replacing $G(f, f_{kj}^{(t)})$ in Eq. (18) by Eq. (22) results in the update rule
$$f_{kj}^{(t+1)} = f_{kj}^{(t)} - f_{kj}^{(t)}\,\frac{g'_{kj}\left(f_{kj}^{(t)}\right)}{2\left(X^TXF + \lambda F L_{Hess}^{+} + \gamma F\right)_{kj}} = f_{kj}^{(t)}\,\frac{\left(X^TX + \lambda F L_{Hess}^{-}\right)_{kj}}{\left(X^TXF + \lambda F L_{Hess}^{+} + \gamma F\right)_{kj}} \tag{26}$$
Because Eq. (22) is an auxiliary function, $g_{kj}$ is non-increasing under this update rule.
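The Lemma around Eq. (22) can be sanity-checked numerically. Because the objective is quadratic in each single entry $f_{kj}$, the gap $G(f, f_{kj}^{(t)}) - g_{kj}(f)$ equals the difference between the quadratic coefficients of Eqs. (22) and (23) times $(f - f_{kj}^{(t)})^2$, and Eqs. (24)–(25) show that this difference is non-negative. The sketch below evaluates the gap on a grid of values. It is a check, not a proof; the toy data, the path-graph Laplacian used in place of $L_{Hess}$, the chosen entry $(k, j)$, and the function names are illustrative assumptions, and the data are non-negative as Eq. (24) requires.

```python
import numpy as np


def objective(X, F, L, lam, gamma):
    """O(F) of Eq. (10)."""
    residual = X - X @ F
    return np.sum(residual**2) + lam * np.trace(F @ L @ F.T) + gamma * np.sum(F**2)


def auxiliary_gap(X, F, L, lam, gamma, k, j, grid):
    """Return G(f, f_kj^(t)) - g_kj(f) of Eqs. (22)-(23) for each f in grid."""
    K = X.T @ X
    L_pos = (np.abs(L) + L) / 2.0
    f_t = F[k, j]
    grad = 2.0 * (-K + K @ F + lam * F @ L + gamma * F)[k, j]   # g'_kj of Eq. (20)
    curv = (K @ F + lam * F @ L_pos + gamma * F)[k, j] / f_t      # quadratic coefficient of Eq. (22)
    g_t = objective(X, F, L, lam, gamma)
    gaps = []
    for f in grid:
        F_f = F.copy()
        F_f[k, j] = f                                            # vary only the entry f_kj
        g = objective(X, F_f, L, lam, gamma)
        G = g_t + grad * (f - f_t) + curv * (f - f_t) ** 2
        gaps.append(G - g)
    return np.array(gaps)


rng = np.random.default_rng(2)
X = rng.uniform(size=(8, 5))
F = rng.uniform(0.1, 1.0, size=(5, 5))
adj = np.diag(np.ones(4), 1) + np.diag(np.ones(4), -1)
L = np.diag(adj.sum(axis=1)) - adj
grid = np.append(np.linspace(0.0, 2.0, 21), F[1, 2])   # last grid point is f_kj^(t) itself
gaps = auxiliary_gap(X, F, L, lam=0.5, gamma=0.1, k=1, j=2, grid=grid)
print(gaps.min())   # non-negative up to rounding
```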
5 Experimental Results and Discussion
5.1 Datasets and Baselines
We validate the proposed method on two real-world image datasets, Coil 100 [22] and USPS [5]. To demonstrate the superiority of our method, we compare it with multiple baselines, including the neighborhood-based common graph (KNN-Gr), AdaHypergraph (Ada-HG) [17], KNN-Hypergraph (KNN-HG) [22], ℓ1-Hypergraph (ℓ1-HG) [8], ℓ2-Hypergraph (ℓ2-HG) [5], and Elastic-Net Hypergraph (EN-HG) [7]. We conduct cross-validation tests to tune the parameter values of the baseline methods.
5.2 Image Clustering on Corrupted Data
Hypergraph-based clustering belongs to the family of graph-based machine learning models. Once the constructed hypergraph models the relationships among the images, the corresponding hypergraph Laplacian is embedded into the spectral clustering framework to perform image clustering. To simulate noisy image data, we add Gaussian noise at different corruption ratios to each image $x$ of each dataset, i.e., $\tilde{x} = x + \eta n$, where $\eta$ is the corruption ratio and $n$ is noise drawn from a Gaussian distribution. To measure the clustering performance of the different methods, we adopt accuracy (AC) and normalized mutual information (NMI) [16]. The results of the different methods are listed in Table 1 and Table 2. From the two tables, we observe the following.
Table 1. The image clustering results on Coil 100 dataset
                AC(%) at corruption ratio       NMI(%) at corruption ratio
Method          10%    20%    30%    40%        10%    20%    30%    40%
KNN-Gr          70.9   65.4   61.2   56.3       79.5   74.3   69.7   63.2
KNN-HG [22]     71.9   67.5   63.2   57.6       80.2   76.3   72.5   67.2
EN-HG [7]       70.1   68.5   66.3   64.2       80.2   77.5   75.6   73.2
Ada-HG [17]     73.2   70.6   64.3   60.1       86.3   80.3   75.2   69.5
ℓ1-HG [8]       72.5   70.1   68.9   66.5       83.2   81.5   79.5   77.6
ℓ2-HG [5]       74.3   72.1   70.3   68.5       85.2   83.5   81.1   77.5
Hs-HG (ours)    75.6   73.1   71.2   69.8       86.5   84.1   82.1   79.6
Table 2. The image clustering results on USPS dataset
                AC(%) at corruption ratio       NMI(%) at corruption ratio
Method          10%    20%    30%    40%        10%    20%    30%    40%
KNN-Gr          68.5   64.3   60.3   56.2       72.3   66.5   62.9   56.8
KNN-HG [22]     70.2   66.3   61.5   57.6       73.6   69.2   64.2   60.3
EN-HG [7]       74.3   73.1   71.2   69.1       76.5   73.5   70.2   68.5
Ada-HG [17]     78.9   74.1   70.2   64.3       80.3   76.2   71.2   64.2
ℓ1-HG [8]       76.5   74.8   75.2   72.8       78.9   75.4   73.1   72.3
ℓ2-HG [5]       80.2   78.5   76.2   74.3       81.5   80.0   78.5   75.3
Hs-HG (ours)    81.3   79.6   77.2   76.3       82.9   81.2   79.6   77.5
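The clustering protocol of Sect. 5.2 (corrupt the images, build a hypergraph, embed its Laplacian into spectral clustering) can be sketched as below. The paper does not state which hypergraph Laplacian it uses, so this sketch assumes the normalized hypergraph Laplacian of Zhou, Huang, and Schölkopf (NIPS 2006), $\Delta = I - D_v^{-1/2} H W D_e^{-1} H^T D_v^{-1/2}$, followed by a row-normalized spectral embedding and plain k-means with farthest-point initialization. It also assumes every vertex has a positive degree. All function names are hypothetical; in practice `H` and `w` could come from Eqs. (8) and (9).

```python
import numpy as np


def corrupt(X, eta, seed=0):
    """Additive Gaussian corruption x~ = x + eta * n with n ~ N(0, I)."""
    return X + eta * np.random.default_rng(seed).standard_normal(X.shape)


def hypergraph_laplacian(H, w):
    """Normalized hypergraph Laplacian I - Dv^{-1/2} H W De^{-1} H^T Dv^{-1/2}."""
    dv = H @ w                      # vertex degrees d(v) = sum_e w(e) h(v, e)
    de = H.sum(axis=0)              # hyperedge degrees delta(e) = |e|
    dv_isqrt = 1.0 / np.sqrt(dv)
    theta = (dv_isqrt[:, None] * H) @ np.diag(w / de) @ (H.T * dv_isqrt[None, :])
    return np.eye(H.shape[0]) - theta


def spectral_clustering(lap, c, iters=50):
    """Cluster nodes using the eigenvectors of the c smallest eigenvalues plus k-means."""
    _, vecs = np.linalg.eigh(lap)   # eigenvalues in ascending order
    Y = vecs[:, :c]
    Y = Y / np.linalg.norm(Y, axis=1, keepdims=True)
    centers = [Y[0]]                # farthest-point initialization
    for _ in range(1, c):
        d = np.min(((Y[:, None, :] - np.array(centers)[None]) ** 2).sum(-1), axis=1)
        centers.append(Y[np.argmax(d)])
    centers = np.array(centers)
    for _ in range(iters):          # Lloyd iterations
        labels = np.argmin(((Y[:, None, :] - centers[None]) ** 2).sum(-1), axis=1)
        centers = np.array([Y[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
                            for j in range(c)])
    return labels


# Two groups of nodes whose hyperedges never overlap.
H = np.array([[1, 1, 0, 0],
              [1, 1, 0, 0],
              [1, 0, 0, 0],
              [0, 0, 1, 0],
              [0, 0, 1, 1],
              [0, 0, 1, 1]], dtype=float)
w = np.ones(4)
labels = spectral_clustering(hypergraph_laplacian(H, w), c=2)
print(labels)  # [0 0 0 1 1 1]
```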
The proposed hypergraph-based clustering method (Hs-HG) achieves the best AC and NMI at every corruption ratio on both datasets. Its margin over the second-best method ranges from 0.2 to 2.0 percentage points on Coil 100 and from 1.0 to 2.2 percentage points on USPS, and the margins are generally largest at the 40% corruption ratio, i.e., the gain of Hs-HG becomes more evident as the image noise level increases. We attribute this to the Hessian energy introduced into the data reconstruction scheme, which is robust to image noise and discovers the discriminative structure of the underlying image manifold. Moreover, the non-negative constraint makes the data reconstruction scheme learn a parts-based data representation. Together, the Hessian energy, the data reconstruction scheme, and the parts-based data representation give the proposed method its superior performance over the baseline methods.
Table 3. The image classification results on Coil 100 dataset
                ACC(%) at corruption ratio
Method          10%    20%    30%    40%
KNN-Gr          89.5   84.3   79.2   75.6
KNN-HG [22]     90.6   86.5   81.3   77.9
EN-HG [7]       88.6   86.2   84.1   82.2
Ada-HG [17]     91.2   88.3   83.1   79.2
ℓ1-HG [8]       91.2   89.2   86.3   83.5
ℓ2-HG [5]       92.1   90.2   88.4   86.5
Hs-HG (ours)    94.1   93.1   91.2   89.1
Table 4. The image classification results on USPS dataset
                ACC(%) at corruption ratio
Method          10%    20%    30%    40%
KNN-Gr          94.5   91.2   85.3   80.5
KNN-HG [22]     96.1   92.3   88.7   83.2
EN-HG [7]       93.1   90.4   89.8   87.6
Ada-HG [17]     95.4   93.4   89.2   83.1
ℓ1-HG [8]       95.6   93.2   91.1   89.8
ℓ2-HG [5]       96.2   94.5   90.2   91.0
Hs-HG (ours)    97.2   95.8   94.2   92.2
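Tables 3 and 4 are obtained with "the common transductive classification framework [5, 8]" (Sect. 5.3), which the text does not restate. As a hedged illustration, the sketch below assumes the transductive hypergraph formulation of Zhou, Huang, and Schölkopf (NIPS 2006), in which class scores are proportional to $(I - \alpha\Theta)^{-1}Y$ with $\Theta = D_v^{-1/2} H W D_e^{-1} H^T D_v^{-1/2}$ and $Y$ the one-hot matrix of known labels. The value $\alpha = 0.9$, the helper names, and the toy hypergraph are assumptions.

```python
import numpy as np


def hypergraph_theta(H, w):
    """Theta = Dv^{-1/2} H W De^{-1} H^T Dv^{-1/2} (assumes positive vertex degrees)."""
    dv = H @ w
    de = H.sum(axis=0)
    dv_isqrt = 1.0 / np.sqrt(dv)
    return (dv_isqrt[:, None] * H) @ np.diag(w / de) @ (H.T * dv_isqrt[None, :])


def transductive_classify(H, w, labels, num_classes, alpha=0.9):
    """Propagate labels with scores = (I - alpha * Theta)^{-1} Y.

    labels[i] is the class of a labeled sample and -1 for an unlabeled one.
    """
    m = H.shape[0]
    Y = np.zeros((m, num_classes))
    known = labels >= 0
    Y[known, labels[known]] = 1.0   # one-hot rows for the labeled samples
    scores = np.linalg.solve(np.eye(m) - alpha * hypergraph_theta(H, w), Y)
    return scores.argmax(axis=1)


H = np.array([[1, 1, 0, 0],
              [1, 1, 0, 0],
              [1, 0, 0, 0],
              [0, 0, 1, 0],
              [0, 0, 1, 1],
              [0, 0, 1, 1]], dtype=float)
w = np.ones(4)
labels = np.array([0, -1, -1, 1, -1, -1])  # one labeled sample per group
print(transductive_classify(H, w, labels, num_classes=2))  # [0 0 0 1 1 1]
```

Solving the linear system rather than forming the inverse explicitly is the usual, numerically safer choice; the positive scaling factor of the closed-form solution does not change the predicted classes.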
5.3 Image Classification on Corrupted Data
We use the common transductive classification framework [5, 8] for classification. We adopt classification accuracy (ACC) [6] to evaluate the classification performance and use the same noisy image datasets as in the clustering experiments. For the Coil 100 dataset, we randomly select 30% of the images of each object as the labeled training set and use the remaining images as the test set. For the USPS dataset, we adopt the common experimental setting of 7,291 training images and 2,007 test images. Table 3 and Table 4 list the experimental results on the two corrupted datasets. The two tables show that the proposed method (Hs-HG) consistently outperforms the baseline methods and achieves the best image classification performance. The main reason is that the proposed method effectively captures the image manifold via Hessian energy regularization: compared with the baseline methods, the Hessian energy makes the data representations of samples from the same class vary more smoothly over the image manifold. Moreover, the non-negative constraint makes the data reconstruction scheme learn a parts-based data representation, resulting in more discriminative hyperedges.
Table 5. The ablation studies on Coil 100 dataset
             AC(%) at corruption ratio   NMI(%) at corruption ratio   ACC(%) at corruption ratio
Method       10%    20%    40%           10%    20%    40%            10%    20%    40%
He-HG        73.8   72.5   67.5          83.3   81.4   77.6           92.1   91.2   87.2
Hc-HG        70.4   68.5   64.3          76.5   73.2   71.5           88.3   86.2   80.1
Hs-HG        75.6   73.1   69.8          86.5   84.1   79.6           94.1   93.1   89.1
Table 6. The ablation studies on USPS dataset
             AC(%) at corruption ratio   NMI(%) at corruption ratio   ACC(%) at corruption ratio
Method       10%    20%    40%           10%    20%    40%            10%    20%    40%
He-HG        80.3   78.1   75.1          81.0   77.6   75.4           96.1   95.1   91.2
Hc-HG        78.6   77.5   72.0          78.9   77.1   71.8           92.3   90.4   87.3
Hs-HG        81.3   79.6   76.3          82.9   81.2   77.5           97.2   95.8   92.2
5.4 Ablation Studies of Two Strategies
To further validate the effectiveness of the two proposed schemes, namely the Hessian energy regularization and the non-negative constraint for representation-based hypergraph learning, we perform ablation studies on the two noisy image datasets and compare the proposed method with two variants. (1) We remove the Hessian energy regularization from the proposed framework and preserve the non-negative constraint; this variant is termed Hc-HG. (2) We remove the non-negative constraint from the proposed framework and preserve the Hessian energy term; this variant is termed He-HG. The results are listed in Table 5 and Table 6. The two tables show that He-HG, which keeps the Hessian energy regularization, outperforms Hc-HG, which keeps only the non-negative constraint, in every setting on both noisy datasets.
Fig. 2. The performance under different parameters: (a) k, (b) λ, and (c) γ.
With both strategies, the proposed method (Hs-HG) improves on both single-strategy variants (Hc-HG and He-HG) in every setting; for example, on Coil 100 at the 10% corruption ratio, AC rises from 73.8% (He-HG) to 75.6%. These results demonstrate that both the Hessian energy regularization and the non-negative constraint play important roles in enhancing the performance of hypergraph-based learning tasks.
5.5 Parameter Setting
Parameter tuning is crucial for the proposed method to achieve encouraging results on different image datasets. The proposed method has three essential parameters: k, λ, and γ, where k is the neighborhood size, λ is the Hessian energy regularization parameter, and γ is the Frobenius-norm regularization parameter. For the q parameter, we keep the top-5 most significant representation coefficients and set the others to zero. We study the parameters on the Coil 100 dataset (20% Gaussian noise) by varying one parameter while fixing the other two. From Fig. 2, we observe the following. As k increases, the performance improves accordingly until k reaches 9, after which it degrades slightly. When λ is too small or too large, the learning performance degrades slightly; for the other values of λ, it is relatively stable. The performance under different values of γ behaves similarly: a very small γ degrades the performance, whereas a very large γ makes the learning procedure overfit to the regularization term.
6 Conclusion
In this article, we have proposed a novel Hessian-induced hypergraph model. Compared with existing methods, the proposed model has two key differences: (1) a Hessian energy-based data representation generates manifold-respecting hyperedges; (2) a non-negative constraint imposed on the data representation generates noise-resistant hyperedges. Experimental results on noisy image datasets demonstrate that the proposed hypergraph construction method outperforms the baselines on both clustering and classification. In this work, we apply the proposed method to medium-sized image datasets rather than large ones; we will address the scalability issue in future work.
Acknowledgement. This work is supported by National Key Research and Development Program of China (No. 2020AAA0109700), National Natural Science Foundation of China (Nos. 62072386, 62076167), the Henan Center for Outstanding Overseas Scientists (GZS2022011) and the open foundation of Henan Key Laboratory of General Aviation Technology (no. ZHKF-230212).
References
1. Belkin, M., Niyogi, P.: Laplacian eigenmaps and spectral techniques for embedding and clustering. In: Advances in Neural Information Processing Systems 14, pp. 585–591 (2001)
2. Fang, Q., Sang, J., Xu, C., et al.: Topic-sensitive influencer mining in interest-based social media networks via hypergraph learning. IEEE Trans. Multimedia 16(3), 796–812 (2014)
3. Ducournau, A., Bretto, A.: Random walks in directed hypergraphs and application to semi-supervised image segmentation. Comput. Vis. Image Underst. 120, 91–102 (2014)
4. Kim, S., Yoo, C.D., Nowozin, S., et al.: Image segmentation using higher-order correlation clustering. IEEE Trans. Pattern Anal. Mach. Intell. 36(9), 1761–1774 (2014)
5. Jin, T., Yu, Z., Gao, Y., Gao, S., Sun, X., Li, C.: Robust L2-hypergraph and its applications. Inf. Sci. 501, 708–723 (2019)
6. Jin, T., Ji, R., Gao, Y., Sun, X., Zhao, X., Tao, D.: Correntropy-induced robust low-rank hypergraph. IEEE Trans. Image Process. 28(6), 2755–2769 (2019)
7. Liu, Q., Sun, Y., Wang, C., Liu, T., Tao, D.: Elastic-Net hypergraph learning for image clustering and semi-supervised classification. IEEE Trans. Image Process. 26(1), 452–463 (2017)
8. Wang, M., Liu, X., Wu, X.: Visual classification by ℓ1-hypergraph modeling. IEEE Trans. Knowl. Data Eng. 27(9), 2564–2574 (2015)
9. Zhang, Z., Bai, L., Liang, Y., Hancock, E.: Joint hypergraph learning and sparse regression for feature selection. Pattern Recogn. 63, 291–309 (2017)
10. Wang, M., Jin, T., Zhang, M., et al.: CSMPQ: Class Separability Based Mixed-Precision Quantization. arXiv preprint arXiv:2212.10220 (2022)
11. Zheng, X., et al.: DDPNAS: Efficient neural architecture search via dynamic distribution pruning. Int. J. Comput. Vision, pp. 1–16 (2023)
12. Zheng, X., et al.: MIGO-NAS: Towards fast and generalizable neural architecture search. IEEE Trans. Pattern Anal. Mach. Intell. 43(9), 2936–2952 (2021)
13. Zhang, S., et al.: You Only Compress Once: Towards Effective and Elastic BERT Compression via Exploit-Explore Stochastic Nature Gradient. arXiv preprint arXiv:2106.02435 (2021)
L. Li et al.
14. Zhang, S., et al.: Targeted hyperparameter optimization with lexicographic preferences over multiple objectives. In: The Eleventh International Conference on Learning Representations (2023)
15. Zheng, X., et al.: Towards optimal fine grained retrieval via decorrelated centralized loss with normalize-scale layer. In: Proceedings of the AAAI Conference on Artificial Intelligence 33(1) (2019)
16. Jin, T., Yu, J., You, J., et al.: Low-rank matrix factorization with multiple hypergraph regularizer. Pattern Recogn. 48(3), 1011–1022 (2015)
17. Yu, J., Tao, D., Wang, M.: Adaptive hypergraph learning and its application in image classification. IEEE Trans. Image Process. 21(7), 3262–3272 (2012)
18. Zhang, C., Hu, S., Tang, Z., Chan, T.H.: Re-revisiting learning on hypergraphs: confidence interval, subgradient method and extension to multiclass. IEEE Trans. Knowl. Data Eng. 32(3), 506–518 (2020)
19. An, L., Chen, X., Yang, S., Li, X.: Person re-identification by multi-hypergraph fusion. IEEE Trans. Neural Netw. Learn. Syst. 28(11), 2763–2774 (2016)
20. Ji, R., Chen, F., Cao, L., Gao, Y.: Cross-modality microblog sentiment prediction via bi-layer multimodal hypergraph learning. IEEE Trans. Multimedia 21(4), 1062–1075 (2018)
21. Kim, K., Steinke, F., Hein, M.: Semi-supervised regression using Hessian energy with an application to semi-supervised dimensionality reduction. In: Advances in Neural Information Processing Systems, pp. 979–987 (2009)
22. Huang, Y., Liu, Q., Lv, F., Gong, Y., Metaxas, D.: Unsupervised image categorization by hypergraph partition. IEEE Trans. Pattern Anal. Mach. Intell. 33(6), 1266–1273 (2011)
23. Min, X., Chen, Y., Ge, S.: Nonnegative matrix factorization with Hessian regularizer. Pattern Anal. Appl. 21(2), 501–513 (2017)
Explainable Knowledge Reasoning on Power Grid Knowledge Graph
Yingyue Zhang1, Qiyao Huang1, Zhou Zheng2, Feilong Liao2, Longqiang Yi3, Jinhu Li4, Jiangsheng Huang4, and Zhihong Zhang1(B)
1 School of Informatics, Xiamen University, Xiamen 361005, Fujian, China
2 State Grid Fujian Electric Power Research Institute, Fuzhou 350007, Fujian, China
3 Kehua Data Co., Ltd., Xiamen 361006, Fujian, China
4 State Grid Info-Telecom Great Power Science and Technology Co., Ltd., Fuzhou 350003, Fujian, China
Abstract. The smooth operation of the power grid is closely related to the national economy and people’s livelihood. The knowledge graph, as a widely used technology, has made considerable contributions to power grid dispatching and query answering. However, explainable reasoning on grid defects datasets remains a great challenge: most models cannot balance effectiveness and explainability, so their assistance in grid defects diagnosis is minimal. To address this issue, we propose a rule-enhanced cognitive graph for power grid knowledge reasoning. Our model consists of two modules: expansion and reasoning. Because path-based methods often ignore graph structure and global information, the expansion module combines the local cognitive graph with the global degree distribution. The reasoning module provides reasoning evidence from two aspects: logical rule learning for strong evidence and cognitive reasoning for possible paths. Experimental results on our grid defects dataset show that our model achieves better performance while remaining explainable.
Keywords: Power Grid Knowledge Graph · Knowledge Reasoning · Cognitive Graph
1 Introduction
The power system plays a critical role in modern society by providing a safe, reliable, and affordable supply of electricity to meet the diverse needs of industrial, commercial, and residential consumers. However, grid defects, such as electrical faults and equipment faults, often threaten its normal operation. Traditional grid defects diagnosis is a process of state recognition, which identifies abnormalities by analyzing historical and current operating conditions and then analyzes the causes of anomalies or faults. A knowledge base, or knowledge graph [1], is now a standard storage format for large-scale datasets and has been used to serve the power system [2, 3]. In addition to its benefits for storage and query, knowledge graph reasoning has also made considerable progress. Knowledge graph reasoning seeks to infer new triples from existing facts to accomplish various downstream tasks such as link prediction and relation prediction. Link prediction is helpful in grid defects diagnosis because it can provide related information, or even directly identify the faulty equipment, to maintenance personnel in a short time. A well-executed graph reasoning system could significantly increase job efficiency.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 705–714, 2023. https://doi.org/10.1007/978-981-99-4761-4_59
Y. Zhang et al.
The common approaches to knowledge graph reasoning can be roughly divided into two categories: embedding-based methods and path-based methods [4]. Embedding-based methods generate a vector representation for each entity and relation and obtain answers via distance metrics. However, they ignore the hidden relations needed for multi-hop reasoning and usually lack explainability. Path-based methods infer the relations between entities from path information, so they can output the whole reasoning path, but they often perform poorly on unseen paths. When diagnosing grid defects, obtaining the reasoning path is just as essential as obtaining the right answer, so reasoning with explainability is a significant challenge. In recent years, the dual process theory [5] has been applied to graph reasoning. It draws on the human cognitive system: System 1 retrieves relevant information and System 2 performs logical reasoning. Ding et al. first introduced dual process theory into multi-hop reasoning with CogQA [6], a question-answering framework based on BERT and a graph neural network (GNN), highlighting its great potential for explainable knowledge graph reasoning. Furthermore, our grid defects dataset poses two major issues. First, the graph may be very sparse because of the large amount of equipment and the relatively low failure rate, so traditional embedding-based methods, particularly complex ones, have difficulty learning adequate representations. Second, the reasoning path for a grid defect may involve more than one hop, which increases the demand for multi-hop reasoning.
To tackle these challenges, we adopt the cognitive graph to enhance inference in our model and, considering the characteristics of the domain-specific knowledge graph, use logical rules to improve the explainability of the reasoning results. Specifically, we first generate rules by constrained random walks, then expand the cognitive graph from the current node, and finally reason over the cognitive graph. To expose our model to global structural information, we add the degree distribution to the entity representation. The corresponding logical rule serves as strong evidence for inference, and the reasoning path serves as a supplement. Our contributions are summarized as follows:
• We use a rule-enhanced cognitive graph to conduct multi-hop reasoning over the power grid knowledge graph, which provides explainable results.
• We capture many rules about grid defects by constrained random walks. These rules are meaningful for grid defects diagnosis.
• We evaluate our model on the grid defects dataset through the task of link prediction, and our model outperforms all baselines.
2 Related Work
2.1 Power Knowledge Graph
With the development of big power data, it is increasingly necessary to explore natural language processing methods for handling various power data. Knowledge graph technology, with its more flexible knowledge storage and retrieval, is widely used to describe the relations between different entities. In recent years, managing power data with knowledge graphs has become popular. Fan et al. [7] describe in detail the process of constructing a power knowledge graph from dispatching data. Huang et al. [8] propose a semi-automatic knowledge graph construction technology that improves the efficiency of the power grid’s daily operations. Tang et al. [9] further design an intelligent question-answering system that understands natural language questions and captures answers directly. Power knowledge graph technology is also used to solve specific challenges of the power network. Wang et al. [10] apply a power knowledge graph and a neural network to automatically identify the topology of power networks. Liang et al. [11] propose PF2RM, a module that combines power fault retrieval and serpolymorphic retrieval recommendation, to cope with the rapid growth of grid defects arising from energy saving and emission reduction. Wu et al. [12] develop an intelligent search engine for power equipment management. In summary, power knowledge graphs effectively support the operation of the power grid and are expected to be applied to more specific scenarios.
Explainable Knowledge Reasoning on Power Grid Knowledge Graph
2.2 Knowledge Graph Reasoning
Traditional knowledge reasoning is mainly based on logical rules. Later, statistical learning methods represented by random walks [13] were used to enrich inference. The path ranking algorithm [14] was first developed to conduct path reasoning. DeepPath [15] then introduced reinforcement learning into knowledge graph reasoning. Yang et al. [16] propose neural logic programming, an end-to-end module that combines first-order logical rules with structure learning. DRUM [17] mines first-order logical rules from the knowledge graph and connects confidence scores with low-rank tensor approximation to deal with unseen entities.
NBFNet [18] represents a pair of nodes as the generalized sum of all paths between them and solves this path formulation with the Bellman-Ford algorithm. SMORE [19] is a general framework for both single-hop and multi-hop reasoning that trains embedding-based models with contrastive learning. RED-GNN [20] introduces a relational directed graph, which consists of relational paths, and encodes it recursively to achieve strong performance. Path-based methods are still mainstream in knowledge graph reasoning. However, embedding-based methods often perform better in downstream tasks; as a result, various approaches have been proposed to improve the performance of path-based methods, and ‘neural for symbolic’ reasoning is now an active research topic.
3 Problem Formulation
3.1 Knowledge Graph
A knowledge graph is often regarded as a collection of existing facts. We represent the knowledge graph as $G = (\mathcal{E}, \mathcal{R}, \mathcal{T})$, where $\mathcal{E}$, $\mathcal{R}$ and $\mathcal{T}$ are the entity set, the relation set and the triple set, respectively. A triple consists of a head entity $e_s$, a relation $r$ and a tail entity $e_o$; for example, $(e_s, r, e_o) \in \mathcal{T}$ states the fact that relation $r \in \mathcal{R}$ holds between entities $e_s \in \mathcal{E}$ and $e_o \in \mathcal{E}$. To enhance connectivity, we add an inverse link $(e_o, r^{-1}, e_s)$ for every $(e_s, r, e_o) \in \mathcal{T}$.
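The inverse-link augmentation can be sketched with plain tuples. This is a minimal illustration, not the authors' preprocessing code: the `"_inv"` suffix used to name $r^{-1}$ and the equipment names are hypothetical.

```python
from typing import List, Tuple

Triple = Tuple[str, str, str]  # (head entity e_s, relation r, tail entity e_o)


def add_inverse_triples(triples: List[Triple]) -> List[Triple]:
    """Augment the triple set T with an inverse link (e_o, r^-1, e_s) for every fact."""
    augmented = list(triples)
    for head, relation, tail in triples:
        # The "_inv" suffix is an illustrative naming convention for r^-1.
        augmented.append((tail, relation + "_inv", head))
    return augmented


facts = [
    ("transformer_1", "located_in", "substation_A"),
    ("transformer_1", "has_defect", "oil_leak"),
]
for triple in add_inverse_triples(facts):
    print(triple)
```

Note that the inverse link swaps the head and tail, so every entity that appears only as a tail also gains outgoing edges for the walks and expansions used later.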
3.2 Link Prediction
Link prediction is the task of inferring unknown or missing facts from known ones. It is widely used to evaluate the performance of a model in knowledge graph reasoning or completion. Specifically, given an incomplete triple $(e_s, r, ?)$ or $(?, r, e_o)$, link prediction aims to find one or more correct answers for the missing entity. In our experiments, we remove the tail entities of the triples in the validation and test sets to evaluate the performance of our model.
3.3 Logical Rule
A logical rule of length $l$ is defined as
$$P_{l+1}(E_1, E_{l+1}) \leftarrow \bigwedge_{i=1}^{l} P_i(E_i, E_{i+1}),$$
where each $E_i$ is a variable ranging over the entities in $\mathcal{E}$ and each $P_i \in \mathcal{R}$ is a predicate. The rule head contains the predicate $P_{l+1}$, which is known in link prediction. For a knowledge graph, we regard the relation $r$ as the predicate $P_{l+1}$, and the head entity $e_s$ and the tail entity $e_o$ as $E_1$ and $E_{l+1}$, respectively. A logical rule thus describes a path that yields the triple $(e_s, r, e_o)$, i.e., the rule body corresponds to a walk on the graph. We therefore capture rules by constrained random walks.
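The correspondence between rules and walks, together with the two walk constraints given in Section 4.2 (no self-loop relation, no immediate backtracking through $r^{-1}$), can be sketched as follows. This is a minimal illustration under assumptions of our own: inverse relations carry an `"_inv"` suffix, the self-loop relation is named `"self_loop"`, and a walk that reaches $e_o$ yields its relation sequence as a rule body; the function names are hypothetical.

```python
import random
from collections import defaultdict
from typing import Set, Tuple

SELF_LOOP = "self_loop"  # hypothetical name of the added self-loop relation


def inverse(relation: str) -> str:
    """Reverse relation r^-1 (assumes inverse relations carry an "_inv" suffix)."""
    return relation[: -len("_inv")] if relation.endswith("_inv") else relation + "_inv"


def build_adjacency(triples):
    adjacency = defaultdict(list)
    for head, relation, tail in triples:
        adjacency[head].append((relation, tail))
    return adjacency


def sample_rules(adjacency, query, max_len=3, num_walks=100, seed=0):
    """Collect rule bodies (r_1, ..., r_l) for the rule head (e_s, r, e_o)."""
    e_s, r, e_o = query
    rng = random.Random(seed)
    rules: Set[Tuple[str, ...]] = set()
    for _ in range(num_walks):
        node, body = e_s, []
        for _ in range(max_len):
            previous = body[-1] if body else None
            choices = [
                (rel, nxt) for rel, nxt in adjacency[node]
                # Constraint 1: the self-loop relation never appears in a walk.
                # Constraint 2 (backtrackless): r_i is never the reverse of r_{i-1}.
                if rel != SELF_LOOP and (previous is None or rel != inverse(previous))
            ]
            if not choices:
                break
            rel, node = rng.choice(choices)
            body.append(rel)
            if node == e_o:
                if body != [r]:  # ignore the trivial rule r <- r
                    rules.add(tuple(body))
                break
    return rules


triples = [("a", "r1", "b"), ("b", "r2", "c"), ("a", "r3", "c")]
processed = triples + [(t, inverse(rel), h) for h, rel, t in triples]
processed += [(e, SELF_LOOP, e) for e in ["a", "b", "c"]]
adjacency = build_adjacency(processed)
print(sample_rules(adjacency, ("a", "r3", "c")))  # → {('r1', 'r2')}
```

In the toy graph, the only non-trivial walk from `a` to `c` is `r1` followed by `r2`, which corresponds to the rule r3(E1, E3) ← r1(E1, E2) ∧ r2(E2, E3); the backtracking step `r1_inv` and the self-loops are filtered out by the two constraints.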
4 Approach
4.1 Cognitive Reasoning
The cognitive graph $\mathcal{G} = (E, T, X)$ is a subgraph of the knowledge graph $G$, where $E \subseteq \mathcal{E}$ and $T \subseteq \mathcal{T}$. Here $X \in \mathbb{R}^{|E| \times d}$ is the latent representation of the nodes in the cognitive graph, and $d$ is the dimension of the latent representation. The cognitive graph is constructed during the expansion process. Given an unknown triple $(e_s, r, ?)$, its cognitive graph is initialized with the start entity $e_s$. During expansion, some of the neighbour entities of $e_s$ are added to the cognitive graph, and their latent representations in $X$ are updated. In the first stage, the neighbours of the current nodes are collected and scored. The representation of a candidate node is the concatenation of its hidden representation, the relation embedding and the entity embedding, and only the top-k nodes are added to the cognitive graph. In the second stage, a GRU is used to update the entity embeddings, and the attention flow is calculated by aggregating the probability distribution (Fig. 1). The flow is also used to identify nodes in the cognitive graph. After n steps of expansion and calculation, a cognitive graph containing the entities within n hops of $e_s$ is constructed. Compared with traditional path-based methods, reasoning over cognitive graphs has two advantages. First, a cognitive graph can model the interactions between paths, which carry essential structural information of the graph. Second, the cognitive graph allows the model to reason over ‘unseen’ paths, i.e., there may be no relation between two adjacent hops on a reasoning path. This helps reasoning on incomplete graphs.
Fig. 1. The process of one step of our model. The full cognitive reasoning process can be divided into two stages. In the first stage, the neighbours of the current nodes are collected.
Expansion Module. The expansion process collects usable evidence from the knowledge graph $G$ to serve subsequent reasoning. Specifically, it recursively searches for relevant entities in the neighbourhood and updates their representations. For a start entity $e_s$, we obtain the candidate entity set $A_0 = \{(r, e) \mid (e_s, r, e) \in G\}$, consisting of the neighbour entities of $e_s$, and initialize the attention $a$ with $a(e_s) = 1$ and $a(e) = 0$ for all other entities. Since node $e_s$ may have no outgoing edges, we add a self-loop to each node and a reverse edge to each edge. The representation of a candidate entity $e_k \in A_0$ is the concatenation of its entity embedding $v_k$, the relation embedding $v_r$ and its latent representation $X(e_k)$ in the cognitive graph. To avoid neighbourhood explosion, we rank the edges of each adjacency list by PageRank and set a parameter λ to constrain the size of the cognitive graph. We set $C_0 = \{e_s\}$ as the current entity set. The representation of $C_0$ is the concatenation of the entity embedding $v_s$, the query representation $v_q$ and the latent representation $X(e_s)$. The score of entity $e_k$ is calculated as
$$s_k = \sigma\big([X(e_k) \,\|\, v_r \,\|\, v_k]\, W_1\big) \cdot \sigma\big(W_2\, [X(e_s) \,\|\, v_q \,\|\, v_s]\big) \tag{1}$$
where $v_q$, $v_s$, $v_r$ and $v_k$ are trainable vectors associated with the query $q$, entity $e_s$, relation $r$ and entity $e_k$, respectively. The query $q$ can come from support pairs or simply be the relation. σ is an activation function and ∥ denotes concatenation. Similarly, at each step $t$ of the expansion process, we obtain the current entity set $C_{t-1}$ and the candidate entity set $A_{t-1}$. Because the cognitive graph learns only from the neighbourhood and ignores global structural information, we add the global degree distribution $d$ to the representation. For an entity $e_c \in C_{t-1}$, let $M(e_c) \in \mathbb{R}^{|A_{t-1}| \times 3d}$ be the candidate matrix; the score distribution and the probability distribution of the candidate entities are
$$s_{t-1}(e_c) = \sigma\big(M(e_c)\, W_1\big) \cdot \sigma\big(W_2\, [X(e_c) \,\|\, v_q \,\|\, v_c \,\|\, d]\big) \tag{2}$$
$$p_{t-1}(e_c) = a_{t-1}(e_c)\, \mathrm{Softmax}\big(s_{t-1}(e_c)\big) \tag{3}$$
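Eqs. (2) and (3), followed by the top-k edge selection described next, can be sketched in NumPy as below. This is an illustration under assumptions: σ is taken to be the logistic sigmoid, the degree feature is encoded as a single log-degree value, and all shapes, weights and helper names are hypothetical placeholders rather than the authors' settings.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def softmax(x):
    shifted = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return shifted / shifted.sum()


def candidate_probabilities(M, x_c, v_q, v_c, degree, a_c, W1, W2):
    """Eqs. (2)-(3) for one current entity e_c.

    M      : (n_candidates, 3d) matrix whose rows are [X(e_k) || v_r || v_k]
    x_c    : latent representation X(e_c), shape (d,)
    v_q    : query vector, shape (d,); v_c: entity embedding of e_c, shape (d,)
    degree : global degree feature of e_c (its encoding is an assumption)
    a_c    : current attention a_{t-1}(e_c)
    W1     : (3d, h) weights; W2: (h, 3d + len(degree)) weights
    """
    query = np.concatenate([x_c, v_q, v_c, degree])  # [X(e_c) || v_q || v_c || d]
    scores = sigmoid(M @ W1) @ sigmoid(W2 @ query)    # s_{t-1}(e_c), one score per candidate
    return a_c * softmax(scores)                      # p_{t-1}(e_c)


def select_top_k(probabilities, k):
    """Indices of the k candidate edges with the largest probability."""
    k = min(k, probabilities.shape[0])
    return np.argsort(-probabilities)[:k]


rng = np.random.default_rng(0)
d, h, n = 4, 8, 6
M = rng.normal(size=(n, 3 * d))
x_c, v_q, v_c = rng.normal(size=(3, d))
degree = np.array([np.log1p(12.0)])  # hypothetical encoding: log(1 + degree of e_c)
W1 = rng.normal(size=(3 * d, h))
W2 = rng.normal(size=(h, 3 * d + 1))

p = candidate_probabilities(M, x_c, v_q, v_c, degree, 1.0, W1, W2)
top = select_top_k(p, k=3)
print(p.round(3), top)
```

Because the softmax is scaled by $a_{t-1}(e_c)$, an entity that currently holds little attention can only pass a small probability mass on to its candidates.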
We select the top-k edges with the largest probabilities and update the cognitive graph with them. Specifically, the selected edges are added to the edge set E, and tail entities that have not been visited are added to the node set V, which completes one expansion of the cognitive graph. The score function in Eq. (2) combines the local representation with the global degree distribution, so after the expansion process the attention a reflects both local and global information. Because more than one edge and node can be added at each step, the reasoning system searches over a directed acyclic graph rather than a set of paths, which makes it possible to reason over unseen paths or incomplete graphs.
Reasoning Module. The reasoning process operates on the cognitive graph; more specifically, it updates both the entity representations and the attention a. To update the entity representations, we use a graph neural network, which models each entity by aggregating its neighbour entities. The message passed to entity e is
$$m_e = \sum_{(e_k, r_k, e) \in E} M(e_k, r_k, e) \tag{4}$$
where $M(e_k, r_k, e)$ is the message vector passed from $e_k$ to $e$. Here we update all representations in the same layer sequentially instead of computing them from the previous layer. We use a gated recurrent unit (GRU) as the message function:
$$M(e_k, r_k, e) = \mathrm{GRU}\big(X[e_k],\ v_{r_k} \,\|\, v_e\big) \tag{5}$$
The GRU is a recurrent neural network (RNN) with a gating mechanism that can adaptively capture dependencies at different time scales [21]. Since most reasoning paths are relatively short, we choose the GRU as the message function instead of a more complex network such as the LSTM. Here GRU(·) takes the previous latent representation $X[e_k]$ as its hidden state and the concatenation $v_{r_k} \,\|\, v_e$ of the relation embedding and the entity embedding as its input. Let $E_e = \{(e_k, r_k, e) \mid (e_k, r_k, e) \in E\}$ be the set of incoming edges of $e$; the update is the average of all messages:
$$U(e, m_e) = \frac{1}{|E_e|}\, m_e \tag{6}$$
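Eqs. (4) to (6), together with the attention update of Eq. (7) that follows, can be sketched in NumPy as below. This is an illustration under assumptions: the GRU uses the gate equations written by Chung et al. [21], Normalized(·) is taken to be division by the total mass, entities without incoming edges keep their representation, and all names, sizes and random weights are hypothetical.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


class GRUCell:
    """GRU cell: the hidden state is X[e_k], the input is v_{r_k} || v_e."""

    def __init__(self, input_size, hidden_size, rng):
        # Stacked weights for the update gate (0), reset gate (1) and candidate state (2).
        self.W = rng.normal(scale=0.1, size=(3, hidden_size, input_size))
        self.U = rng.normal(scale=0.1, size=(3, hidden_size, hidden_size))
        self.b = np.zeros((3, hidden_size))

    def __call__(self, h, x):
        z = sigmoid(self.W[0] @ x + self.U[0] @ h + self.b[0])              # update gate
        r = sigmoid(self.W[1] @ x + self.U[1] @ h + self.b[1])              # reset gate
        h_tilde = np.tanh(self.W[2] @ x + self.U[2] @ (r * h) + self.b[2])  # candidate state
        return (1.0 - z) * h + z * h_tilde


def update_entity(e, edges, X, rel_emb, ent_emb, gru):
    """Eqs. (4)-(6): average the GRU messages over the incoming edges E_e of entity e."""
    incoming = [(e_k, r_k) for e_k, r_k, tail in edges if tail == e]
    if not incoming:
        return X[e]  # no incoming edge: keep the current representation (our choice)
    m_e = sum(gru(X[e_k], np.concatenate([rel_emb[r_k], ent_emb[e]])) for e_k, r_k in incoming)
    return m_e / len(incoming)


def update_attention(edges, edge_prob):
    """Eq. (7): sum p_t over selected incoming edges, then normalize to sum to one (assumed)."""
    totals = {}
    for edge in edges:
        tail = edge[2]
        totals[tail] = totals.get(tail, 0.0) + edge_prob[edge]
    norm = sum(totals.values())
    return {e: value / norm for e, value in totals.items()} if norm > 0 else totals


rng = np.random.default_rng(0)
d = 4
entities, relations = ["a", "b", "c"], ["r1", "r2"]
X = {e: rng.normal(size=d) for e in entities}
ent_emb = {e: rng.normal(size=d) for e in entities}
rel_emb = {r: rng.normal(size=d) for r in relations}
gru = GRUCell(input_size=2 * d, hidden_size=d, rng=rng)

edges = [("a", "r1", "b"), ("a", "r2", "c"), ("b", "r2", "c")]
new_c = update_entity("c", edges, X, rel_emb, ent_emb, gru)
print(update_attention(edges, {edges[0]: 0.2, edges[1]: 0.3, edges[2]: 0.5}))
```

Averaging rather than summing the messages keeps the scale of $X[e]$ independent of how many selected edges point into $e$.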
To update the attention a, we aggregate the probabilities of all selected incoming edges of entity e to obtain the new attention distribution:
$$a_t(e) = \mathrm{Normalized}\Big(\sum_{(e_k, r_k, e) \in E} p_t(e_k, r_k, e)\Big) \tag{7}$$
where Normalized(·) is a normalization function. This completes the construction of the cognitive graph and the relational reasoning with the GNN.
4.2 Logical Rule Learning
A logical rule consists of a rule head and a rule tail. Treating a known triple $(e_s, r, e_o)$ as the rule head, we aim to find a set of paths from $e_s$ to $e_o$ as the
rule tail. Generally, such paths could simply be obtained from random walks. However, because reverse edges and self-loops are added during data processing, ordinary random walks can be inefficient and yield meaningless rules. A simple solution is to learn logical rules on the unprocessed graph, but this might lose important information because the knowledge graph is sparse. To overcome this challenge, we use constrained random walks to obtain rule tails. We express a walk w in the knowledge graph as a sequence of relations $r_1, r_2, \dots, r_l$, where $l$ is the maximum walk length. To reduce the influence of reverse edges and self-loops, two constraints are applied to the walks. First, the self-loop relation $r_{loop}$ must not appear in a walk. Second, the walks should resemble backtrackless walks: for each $r_i$ $(i > 1)$ in a walk, $r_{i-1}$ must not be $r_i^{-1}$, the reverse relation of $r_i$.
4.3 Optimization
After the reasoning process, candidate tail entities can be selected according to the attention distribution a. A direct optimization approach is therefore to maximize the attention score of the correct tail entity $e_o$ for a triple $(e_s, r, e_o)$, for which the cross-entropy loss is a common choice. However, only entities within the sampled hops receive attention scores, so the attention score of $e_o$ may be 0. To handle this case, we minimize the sum of a when $a_T(e_o) = 0$. The final loss function is
$$\ell(a_T, e_o) = \begin{cases} -\log a_T(e_o), & a_T(e_o) > 0,\\ -\log\big(1 + \epsilon - \sum_{e} a_T(e)\big), & a_T(e_o) = 0, \end{cases} \tag{8}$$
where ε is a hyperparameter close to zero.
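A minimal sketch of Eq. (8), assuming the attention $a_T$ is stored as a dictionary over the entities reached in the cognitive graph (unreached entities implicitly have zero attention) and that its total mass does not exceed 1; ε = 1e-6 and the function name are arbitrary placeholders.

```python
import math


def cognitive_loss(attention, answer, eps=1e-6):
    """Eq. (8): cross-entropy on the answer's attention, with a fallback when it is unreached.

    attention: dict entity -> a_T(e) for entities inside the cognitive graph.
    eps:       the hyperparameter close to zero (1e-6 is an arbitrary choice here).
    """
    a_answer = attention.get(answer, 0.0)
    if a_answer > 0:
        return -math.log(a_answer)
    # The answer received no attention: push the total attention mass down instead.
    return -math.log(1.0 + eps - sum(attention.values()))


print(cognitive_loss({"b": 0.25, "c": 0.75}, "c"))  # answer reached: -log 0.75
print(cognitive_loss({"b": 0.25, "c": 0.25}, "d"))  # answer unreached: -log(1 + eps - 0.5)
```

In the second branch, the loss shrinks as the attention mass spread over wrong entities shrinks, which is the "minimize the sum of a" behaviour described above.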
5 Experiment
5.1 Dataset
We conduct link prediction on our grid defects dataset. It is a real-world dataset obtained from grid defect records, containing information such as location, description, cause and level. These records are collected manually during routine maintenance and then transformed from structured data into triples. To ensure that a reasoning path can exist between the head entity and the tail entity, we keep only nodes whose in-degree plus out-degree exceeds 3. There are 105,535 entities and 296,181 triples in our dataset. The graph density of the resulting power knowledge graph is 2.66e-5, so the graph is extremely sparse. We divided the data into training, validation and test sets with a ratio of 8:1:1. Detailed information about our dataset is shown in Table 1.
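The reported density can be reproduced from the entity and triple counts, assuming the usual directed-graph definition |T| / (|E|(|E| − 1)); the paper does not state which formula it uses. The snippet also checks that the split sizes listed in Table 1 add up to the total number of triples.

```python
num_entities = 105_535
num_triples = 296_181
split_sizes = {"training": 221_356, "validation": 37_412, "test": 37_413}  # Table 1

# The three splits partition the triple set.
assert sum(split_sizes.values()) == num_triples

# Density of a directed graph without self-loops: |T| / (|E| * (|E| - 1)).
density = num_triples / (num_entities * (num_entities - 1))
print(f"{density:.2e}")  # → 2.66e-05
```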
Table 1. Information about the training, validation and test sets
Dataset      Triples   Entities   Relations
Training     221356    87224      13
Validation   37412     35757      13
Test         37413     35756      13
5.2 Experimental Setup
We evaluate our model on the link prediction task: the correct tail entities of the triples in the validation and test sets are removed, and the model predicts the missing entities. No candidate constraints are applied in either our model or the baselines. For our model, we set the embedding size to 768 and the hidden representation size to 200. The maximum number of edges selected when updating the cognitive graph (top k) is 64, and the maximum number of reasoning steps is l = 4. The batch size is 28 and the maximum number of training steps is 100,000. The maximum number of random walks per entity is 100. We train with the Adam optimizer at a learning rate of 1e-4. Three metrics, Hit@1, Hit@10 and mean reciprocal rank (MRR), are used to evaluate the models, and MRR is used to select the best model checkpoint.
5.3 Baselines
We compare our model with five strong or commonly used baselines. EIGAT [22] is an embedding-based method that computes graph attention with global entity importance. RED-GNN [20] is a path-based method that introduces relational directed graphs to capture more complex structural information than paths. TransE [23], TransR [24] and TransH [25] are embedding-based methods widely used in graph reasoning.
Table 2. Experimental results of link prediction
Method     Hit@1    Hit@10   MRR
TransE     2.80%    6.92%    0.0434
TransR     0.01%    1.54%    0.0057
TransH     2.71%    6.89%    0.0427
RED-GNN    33.60%   47.81%   0.3894
EIGAT      52.01%   73.27%   0.5997
Ours       69.06%   82.69%   0.7430
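Hit@k and MRR in Table 2 are computed from the rank of the correct tail entity for each test query. A minimal sketch, assuming raw (unfiltered) ranks and pessimistic tie-breaking, neither of which the paper specifies; the helper names and the example ranks are hypothetical.

```python
from typing import Dict, List


def rank_of_answer(scores: Dict[str, float], answer: str) -> int:
    """1-based rank of the answer among all candidates (higher score is better).

    Ties are resolved pessimistically: tied candidates are ranked ahead of the answer.
    """
    answer_score = scores[answer]
    return 1 + sum(1 for e, s in scores.items() if e != answer and s >= answer_score)


def hits_at_k(ranks: List[int], k: int) -> float:
    """Fraction of queries whose correct answer is ranked within the top k."""
    return sum(1 for r in ranks if r <= k) / len(ranks)


def mean_reciprocal_rank(ranks: List[int]) -> float:
    """Average of 1 / rank over all queries."""
    return sum(1.0 / r for r in ranks) / len(ranks)


ranks = [1, 3, 12, 1, 2]
print(hits_at_k(ranks, 1), hits_at_k(ranks, 10), round(mean_reciprocal_rank(ranks), 4))
# → 0.4 0.8 0.5833
```

Unlike Hit@k, MRR also rewards near misses (a rank of 2 still contributes 0.5), which is one reason it is a common choice for checkpoint selection.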
5.4 Result Analysis
Table 2 reports the results of link prediction on our power dataset. Generally, embedding-based methods perform better than path-based methods. Our method achieves the best performance on all metrics. In particular, it outperforms EIGAT, the second-best method, by 17.05 percentage points on Hit@1, 9.42 on Hit@10 and 0.1433 on MRR. Notably, the largest improvement is on Hit@1, which indicates that the dual process theory enhances precise reasoning and shows the value of cognitive reasoning and structural information. TransE, TransR and TransH all perform poorly on the power dataset, which means that traditional embedding-based methods cannot meet the requirements of grid defects diagnosis.
5.5 Ablation Study
We conduct an ablation study to validate the effectiveness of the components of our method. The result is reported in Table 2. The structure variant represents our method without structural information. Removing structural information degrades the performance of our method, and the degradation is larger on Hit@10 than on the other metrics, which suggests that global structural information is beneficial for collecting evidence.
6 Conclusion
In this paper, we propose the rule-enhanced cognitive graph for knowledge reasoning, a model based on dual process theory. We improve explainability through logical rule learning and improve performance by combining the local subgraph structure with the global degree distribution. The experimental results on link prediction demonstrate the effectiveness of our model. Because our model is mainly path-based, its performance is seriously affected by inductive settings and orphan nodes, and node attributes are not fully utilized. In future research, we hope to introduce other approaches, such as pre-trained language models, to further improve our model.
Acknowledgment. This work is supported by Major Program of Xiamen (3502Z20231006); National Natural Science Foundation of China (62176227, U2066213); Fundamental Research Funds for the Central Universities (20720210047).
References
1. Pujara, J., Miao, H., Getoor, L., Cohen, W.: Knowledge graph identification. In: Alani, H., et al. (eds.) ISWC 2013. LNCS, vol. 8218, pp. 542–557. Springer, Heidelberg (2013). https://doi.org/10.1007/978-3-642-41335-3_34
2. Meng, F., Yang, S., Wang, J., Xia, L., Liu, H.: Creating knowledge graph of electric power equipment faults based on BERT–BiLSTM–CRF model. J. Electrical Eng. Technol. 17(4), 2507–2516 (2022)
3. Ding, H., Qiu, Y., Yang, Y., Ma, J., Wang, J., Hua, L.: A review of the construction and application of knowledge graphs in smart grid. In: 2021 IEEE Sustainable Power and Energy Conference (iSPEC), pp. 3770–3775. IEEE (2021)
4. Chen, X., Jia, S., Xiang, Y.: A review: knowledge reasoning over knowledge graph. Expert Syst. Appl. 141, 112948 (2020)
5. Gawronski, B., Creighton, L.A.: Dual Process Theories (2013)
6. Ding, M., Zhou, C., Chen, Q., Yang, H., Tang, J.: Cognitive Graph for Multi-Hop Reading Comprehension at Scale. arXiv preprint arXiv:1905.05460 (2019)
7. Fan, S., et al.: How to construct a power knowledge graph with dispatching data? Sci. Program. 2020, 1–10 (2020)
8. Huang, H., Hong, Z., Zhou, H., Wu, J., Jin, N.: Knowledge graph construction and application of power grid equipment. Math. Probl. Eng. 2020, 1–10 (2020)
9. Tang, Y., Han, H., Yu, X., Zhao, J., Liu, G., Wei, L.: An intelligent question answering system based on power knowledge graph. In: 2021 IEEE Power & Energy Society General Meeting (PESGM), pp. 1–5. IEEE (2021)
10. Wang, C., An, J., Mu, G.: Power system network topology identification based on knowledge graph and graph neural network. Frontiers in Energy Res. 8, 613331 (2021)
11. Liang, K., Zhou, B., Zhang, Y., Li, Y., Zhang, B., Zhang, X.: PF2RM: a power fault retrieval and recommendation model based on knowledge graph. Energies 15(5), 1810 (2022)
12. Wu, X., Tang, Y., Zhou, C., Zhu, G., Song, J., Liu, G.: An intelligent search engine based on knowledge graph for power equipment management. In: 2022 5th International Conference on Energy, Electrical and Power Engineering (CEEPE), pp. 370–374. IEEE (2022)
13. Lovász, L.: Random walks on graphs. Combinatorics, Paul Erdős is Eighty 2(1–46), 4 (1993)
14. Page, L., Brin, S., Motwani, R., Winograd, T.: The PageRank citation ranking: bringing order to the web. Technical report, Stanford University (1998)
15. Xiong, W., Hoang, T., Wang, W.Y.: DeepPath: A Reinforcement Learning Method for Knowledge Graph Reasoning. arXiv preprint arXiv:1707.06690 (2017)
16. Yang, F., Yang, Z., Cohen, W.W.: Differentiable learning of logical rules for knowledge base reasoning. Adv. Neural Inf. Process. Syst. 30 (2017)
17. Sadeghian, A., Armandpour, M., Ding, P., Wang, D.Z.: DRUM: end-to-end differentiable rule mining on knowledge graphs. Adv. Neural Inf. Process. Syst. 32 (2019)
18. Zhu, Z., Zhang, Z., Xhonneux, L.P., Tang, J.: Neural Bellman-Ford networks: a general graph neural network framework for link prediction. Adv. Neural Inf. Process. Syst. 34, 29476–29490 (2021)
19. Ren, H., et al.: SMORE: knowledge graph completion and multi-hop reasoning in massive knowledge graphs. In: Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, pp. 1472–1482 (2022)
20. Zhang, Y., Yao, Q.: Knowledge graph reasoning with relational digraph. In: Proceedings of the ACM Web Conference 2022, pp. 912–924 (2022)
21. Chung, J., Gulcehre, C., Cho, K., Bengio, Y.: Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling. arXiv preprint arXiv:1412.3555 (2014)
22. Zhao, Y., et al.: EIGAT: incorporating global information in local attention for knowledge representation learning. Knowl.-Based Syst. 237, 107909 (2022)
23. Bordes, A., Usunier, N., Garcia-Duran, A., Weston, J., Yakhnenko, O.: Translating embeddings for modeling multi-relational data. Adv. Neural Inf. Process. Syst. 26 (2013)
24. Lin, Y., Liu, Z., Sun, M., Liu, Y., Zhu, X.: Learning entity and relation embeddings for knowledge graph completion. In: Proceedings of the AAAI Conference on Artificial Intelligence, 29 (2015)
25. Wang, Z., Zhang, J., Feng, J., Chen, Z.: Knowledge graph embedding by translating on hyperplanes. In: Proceedings of the AAAI Conference on Artificial Intelligence, 28 (2014)
A Novel Approach to Analyzing Defects: Enhancing Knowledge Graph Embedding Models for Main Electrical Equipment
Yanyu Chen1, Jianye Huang2, Jian Qian2, Longqiang Yi3, Jinhu Li4, Jiangsheng Huang4, and Zhihong Zhang1(B)
1 School of Informatics, Xiamen University, Xiamen 361005, China
2 State Grid Fujian Electric Power Research Institute, Fuzhou 350000, China
3 Kehua Data Co., Ltd, Xiamen 361005, China
4 State Grid Info-Telecom Great Power Science and Technology Co., Ltd., Xiamen 361005, China
Abstract. Defects in main electrical equipment can threaten the safety of electric power grids, creating significant risks and pressures for dispatching operations. To analyze defects in main electrical equipment, we adopt a knowledge graph link prediction approach. We found that using pre-trained models such as BERT to extract node features as initial embeddings significantly improves the effectiveness of knowledge graph embedding models (KGEMs). However, this approach does not always work and can lead to performance degradation. To address this, we propose a transfer learning method that fine-tunes the pre-trained model on a small domain-specific electric power corpus. The PCA algorithm is used to reduce the dimensionality of the extracted features, thereby lowering the computational cost of KGEMs. Experimental results show that our model effectively improves link prediction performance in analyzing defects in main electrical equipment.
Keywords: Main electrical equipment defects · Knowledge graph embedding models · Pre-training models
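The PCA step mentioned in the abstract can be sketched as below. This is an illustration only: the random matrix stands in for BERT node features (768 is the BERT-base hidden size), the target dimension of 32 is an arbitrary placeholder since the excerpt does not state one, and `pca_reduce` is a hypothetical helper; `sklearn.decomposition.PCA` offers an equivalent off-the-shelf implementation.

```python
import numpy as np


def pca_reduce(features: np.ndarray, n_components: int) -> np.ndarray:
    """Project row-wise node features onto their top principal components.

    features: (num_nodes, hidden_size) matrix, e.g. one BERT vector per node.
    """
    centered = features - features.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, sorted by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:n_components].T


rng = np.random.default_rng(0)
node_features = rng.normal(size=(50, 768))  # stand-in for BERT features of 50 nodes
reduced = pca_reduce(node_features, n_components=32)
print(reduced.shape)  # → (50, 32)
```

The reduced vectors would then serve as initial entity embeddings for a KGEM, so the embedding dimension (and hence the model's cost) is set by `n_components` rather than by the pre-trained model.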
1 Introduction
The electric power grid is crucial infrastructure in modern society, and defects in main electrical equipment are a significant factor affecting its safe and reliable operation. These defects are typically categorized as either electrical faults or equipment faults. Electrical faults are abnormal power system operations caused by short circuits, overloads, or grounding issues, while equipment faults are abnormal behaviors of system devices. Main electrical equipment defects that occur during grid operation can cause severe consequences such as power outages and equipment damage, significantly impacting the system’s safe and stable operation. It is therefore essential to analyze and diagnose these defects promptly and accurately to maintain the system’s reliability.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 715–725, 2023. https://doi.org/10.1007/978-981-99-4761-4_60
Following previous work [1], we use a knowledge graph to describe defect events, defect phenomena, defect attributes, and their relationships in a structured way, which better expresses complex defect relationships. To analyze main electrical equipment defects, we adopt the knowledge graph link prediction technique [2], which provides a more structured and comprehensive approach. It helps identify the root cause of a defect, allowing targeted repairs and maintenance. Moreover, link prediction can significantly enhance the efficiency and accuracy of defect analysis, enabling timely and effective maintenance and ultimately reducing downtime and costs. A simple example is shown in Fig. 1.
Fig. 1. In this example, we are provided with head entities and relations, and the objective is to predict the tail entities. The complete set of entities in the knowledge graph forms the pool of all possible answers for the tail entity, and the best entity is selected based on its probability. The figure shows only a subset of the possible tail entity candidates.
In analyzing main electrical equipment defects, we encountered serious problems when using embedding-based KGEMs [3] for link prediction. The defect dataset contains many technical terms, abbreviations, and hard-to-interpret words and sentences, so entities whose names look similar can have quite different meanings. As a result, their embeddings are not well separated in the embedding space, which makes it hard for the model to distinguish them and leads to poor performance. To address this challenge, we were inspired by the success of pre-trained language models such as BERT [4] and RoBERTa [5] in sentence embeddings [6], and attempted to use BERT to extract features from knowledge graph nodes. However, this approach did not always work. Our analysis revealed that BERT pre-trained on public datasets not only fails to produce more discriminative initial embeddings for KGEMs, but also yields misleading embeddings for some models, ultimately degrading link prediction performance.
To overcome this difficulty, we propose fine-tuning BERT on a domain-specific corpus. Although a sufficient corpus on main electrical equipment defects is hard to obtain in practice, we found that very little corpus is enough to achieve excellent results. Another difficulty is the high dimensionality of BERT’s hidden-layer output, which increases computational complexity and can cause CUDA out-of-memory errors. Our aim was therefore to apply the proposed method in low-resource environments, where only a small amount of specialized domain corpus and limited computing resources are available. To this end, we conducted experiments using only a few MB of corpus and low-dimensional embeddings. To reduce the dimensionality of the initial embeddings, we employed the simple PCA algorithm [7] and found it effective; those aiming to improve accuracy further may explore more sophisticated dimensionality reduction methods instead. In summary, we make two contributions:
– We propose a method for constructing a knowledge graph for analyzing main electrical equipment defects.
– We enhance BERT’s ability to comprehend domain-specific knowledge by fine-tuning it on such knowledge; using the features extracted by the fine-tuned BERT as initial embeddings improves the link prediction performance of KGEMs.
2 Related Work
The problem of learning sentence embeddings has been extensively researched in NLP, and a recent trend is to exploit the power of BERT for sentence embeddings [8]. In professional domains, researchers have also explored combining knowledge graphs with BERT. Liu et al. [9] used BERT to embed knowledge graphs, effectively transferring expert knowledge into the model. KG-BERT [10] achieved promising results in knowledge graph completion and link prediction by taking the entity and relation descriptions of a triple as input and computing the triple’s score with the KG-BERT language model. Another way to combine pre-trained models with knowledge graphs is to pre-train on the knowledge graphs themselves [11] instead of relying on language models to learn initial embeddings; through multi-task pre-training and fine-tuning, this approach has shown better performance on knowledge graph tasks such as link prediction and knowledge graph completion. For domain adaptation, Gururangan et al. [12] proposed a second phase of pre-training to narrow the gap between general-purpose pre-training corpora and domain-specific text. Continuing to pre-train a language model on domain-specific text in this incremental way can significantly improve its performance in specific domains. Our work draws inspiration from this approach to incremental training and fine-tuning, but the corpus used in our study is much smaller and more practical for real-world applications. Sushil et al. [13] investigated the gap between general BERT and domain-specific knowledge and found a significant difference. They concluded that unsupervised text retrieval can be used to bridge the gap and facilitate inference by leveraging existing information.
3 Proposed Approach
We begin by constructing a knowledge graph, which involves designing an ontology [14] structure and then extracting information from unstructured text according to this structure. The extraction process requires manual input from a professional team; in this paper, we focus only on the ontology design, which is explained in detail in Sect. 3.1. In Sect. 3.2, we describe our link prediction network architecture: we fine-tune a pre-trained language model on domain-specific corpora, use it as a feature extractor, and feed the extracted features into KGEMs to improve link prediction results. For the KGEMs themselves, we directly adopt well-established implementations by previous researchers.
3.1 Ontology Design
This section presents our methodology for constructing the ontology. The ontology of a knowledge graph is a critical component that describes and organizes the entities, relationships, and attributes within the graph; it also specifies the hierarchical structure, constraints, and semantic relationships among these components. Our ontology is tailored to our electric power defect dataset. It comprises 12 entity types and 12 relationship types, with “defect phenomenon” at the core, since it has the largest number of incident edges. The 12 entity types are: defect phenomenon, power station/line, power/line type, voltage level, defect attribute, defect location, defect nature, defect description, equipment type, defect equipment, equipment part type, and equipment part. The relationship types are intuitive: a relationship between head entity h and tail entity t is denoted h-t. The detailed entity-relationship diagram is shown in Fig. 2.
3.2 Our Network Structure
This section describes the network structure of our model.
The complete architecture of the model is illustrated in Fig. 3. We express our knowledge graph as triples, with each triple in the format of (h, r, t), where h denotes the head entity, r denotes the relationship, and t denotes the tail entity. We extract textual representations for each node and relationship, and input them into a pre-trained language model that has been fine-tuned on domain-specific text. The output of the last hidden layer of the pre-trained model is used as the representation of the node.
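The text uses the last hidden layer as the node representation but does not say how the per-token vectors of a node’s text are combined into a single vector. The sketch below assumes masked mean pooling (taking the [CLS] vector is a common alternative). With HuggingFace Transformers, `last_hidden` would correspond to the model output’s `last_hidden_state` and `attention_mask` to the tokenizer’s mask; here a toy array stands in for both so that only the pooling step is shown.

```python
import numpy as np


def masked_mean_pool(last_hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average the last-layer token vectors of each text, ignoring padding tokens.

    last_hidden:    (batch, seq_len, hidden), e.g. hidden = 768 for BERT-Base
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding
    returns:        (batch, hidden), one feature vector per node or relation text
    """
    mask = attention_mask[:, :, None].astype(last_hidden.dtype)  # (batch, seq_len, 1)
    summed = (last_hidden * mask).sum(axis=1)                    # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)               # (batch, 1), avoids division by zero
    return summed / counts


# Toy stand-in for a model's last hidden layer: 2 texts, 3 token positions, hidden size 4.
hidden = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
mask = np.array([[1, 1, 1], [1, 1, 0]])  # the second text has one padding position
node_features = masked_mean_pool(hidden, mask)
print(node_features.shape)  # (2, 4)
```

Because padded positions are masked out, the second row averages only its two real tokens, giving [14, 15, 16, 17] rather than a value dragged toward the padding.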
We fine-tune a pre-trained model on domain-specific long-form texts. We merge all corpus documents into a single long text file in which each line is one sentence; each line serves as a training sample for fine-tuning BERT, with all layers except the last one frozen. The standard hidden-layer output of BERT is a 768-dimensional vector, which is not ideal as an initial embedding for link prediction because of its high computational and memory demands. We therefore apply principal component analysis (PCA) to reduce the dimensionality of the output vectors, preserving essential information while reducing computational complexity and memory usage. The reduced feature vectors are used as the initial embeddings for the link prediction task. For the link prediction model, we adopt existing KGEM implementations from prior research, keeping their loss functions, optimizers, and evaluation metrics consistent with the original implementations. Given a head entity and a relationship, our prediction task is to select the most likely tail entity from the pool of all entities; to do so, we compute a plausibility score for every candidate entity and select the highest-scoring one.
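The paper applies “the simple PCA algorithm” to shrink the 768-dimensional BERT features to the KGEM input size (300 in Sect. 4.2) without naming an implementation. A minimal SVD-based sketch, assuming the features of all nodes are stacked row-wise into one matrix; scikit-learn’s `PCA(n_components=300)` should give an equivalent projection up to the sign of each component. The random matrix is only a stand-in for real BERT features.

```python
import numpy as np


def pca_reduce(features: np.ndarray, k: int) -> np.ndarray:
    """Project row vectors onto their top-k principal components.

    features: (num_nodes, d), e.g. d = 768 BERT features; requires k <= min(num_nodes, d)
    returns:  (num_nodes, k) reduced features to use as initial KGEM embeddings
    """
    centered = features - features.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:k].T


rng = np.random.default_rng(0)
bert_features = rng.normal(size=(1000, 768)).astype(np.float32)  # stand-in for real BERT outputs
init_embeddings = pca_reduce(bert_features, k=300)
print(init_embeddings.shape)  # (1000, 300)
```

Whether PCA is fitted jointly on entity and relation texts or separately for each is a design choice the text does not specify.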
Fig. 2. Design of our Ontology. As an illustration of edge definition, the relationship between the head entity “defect phenomenon” and the tail entity “defect location” is denoted as “defect phenomenon -- defect location”. For better clarity, the edge labels in the figure have been hidden.
Fig. 3. Design of our Network. Various types of long-text corpora can be used to fine-tune BERT; for simplicity, the figure shows only three: a maintenance question bank, equipment defect classification standards, and fault analysis cases. The BERT encoder encodes the nodes and edges of the knowledge graph, and the resulting feature representations are used as the initial embeddings of the KGEMs. To keep the figure concise, only the head node is shown entering the BERT encoder.
4 Experiments
4.1 Dataset Description
Our dataset is derived from real-world data obtained from the State Grid Fujian Electric Power Research Institute. It mainly comprises defect information on main electrical equipment, which was collected and recorded manually, organized into a standard format, and finally transformed into triples based on the ontology design. The dataset contains 58,820 triples, with 53,889 entities and 12 relations. We also created inverse edges by adding a “tail-to-head” relationship edge for each “head-to-tail” relationship edge. We divided the dataset into training, validation, and test sets in an 8:1:1 ratio, using bidirectional edges in the training set and unidirectional edges in the validation and test sets. The fine-tuning corpus used in this paper is only 4.96 MB and consists of materials related to electric power defects, such as substation operation and maintenance question banks, substation primary equipment defect classification standards, and distribution network fault analysis cases. This is far smaller than the training corpus of BERT and other publicly available datasets. Our corpus comprises 31,079 sentences, each on a separate line. We divided the corpus into training, validation, and test sets in an 8:1:1 ratio.
4.2 Experimental Setup
We used several components of the PyKEEN [15] library; in particular, we used its built-in functionality to split our dataset automatically. Fine-tuning was carried out with the pre-training tools provided by HuggingFace [16], and since our dataset is in Chinese, we chose BERT-Chinese-Base as the pre-trained model. For fine-tuning, we set the learning rate to 2e-5 and train_batch_size to 128. Based on the sentence length distribution of our corpus (Fig. 4), we chose max_seq_length = 256 as a trade-off: longer sentences are truncated and shorter ones are padded. The number of training epochs is set to 40.
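The triple preparation of Sect. 4.1 (an 8:1:1 split, with “tail-to-head” inverse edges added only to the training set) can be sketched as follows. This is a plain random split under stated assumptions, not PyKEEN’s built-in splitter that the authors actually used, which may additionally enforce that every entity appears in the training set. The inverse-relation naming and the toy entity names are hypothetical.

```python
import random


def split_with_inverse_edges(triples, ratios=(0.8, 0.1, 0.1), seed=0):
    """Randomly split (h, r, t) triples 8:1:1 and add inverse edges to the training set only.

    Validation and test sets keep only the original (unidirectional) edges.
    """
    shuffled = list(triples)
    random.Random(seed).shuffle(shuffled)
    n_train = int(len(shuffled) * ratios[0])
    n_valid = int(len(shuffled) * ratios[1])
    train = shuffled[:n_train]
    valid = shuffled[n_train:n_train + n_valid]
    test = shuffled[n_train + n_valid:]
    # Hypothetical naming scheme for the "tail-to-head" relation of each training edge.
    inverse = [(t, f"{r} (inverse)", h) for (h, r, t) in train]
    return train + inverse, valid, test


toy = [(f"phenomenon_{i}", "defect phenomenon -- defect location", f"location_{i}") for i in range(10)]
train, valid, test = split_with_inverse_edges(toy)
print(len(train), len(valid), len(test))  # 16 1 1
```

Adding the inverses after splitting, rather than before, keeps the reverse of a validation or test triple out of the training set.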
Fig. 4. Sentence length distribution of our fine-tuning corpus.
In our KGEM training setup, we set the maximum number of training epochs to 300 and the evaluation frequency to 10 epochs, with early stopping enabled (patience of 5 and relative delta of 0.01). The input dimension of the KGEMs was set to 300 by applying PCA to reduce BERT’s 768-dimensional hidden-layer output to 300 dimensions. We used TransH and DistMult, as implemented in PyKEEN, as our KGEMs, with all other settings consistent with the original implementations. TransH is one of the translation-based (Trans-series) models [17] and was proposed by Wang et al. in 2014 [18]. For each relation, it defines a hyperplane (given by a normal vector) and a translation vector lying on that hyperplane. Head and tail entities are projected onto the relation-specific hyperplane, so the same entity can have different representations under different relations; the model then measures the distance between the translated head projection and the tail projection and uses it to score triples and predict missing ones. DistMult was proposed by Yang et al. in 2014 [19]. Unlike the translation-based TransE [20] model, it uses a bilinear scoring function that corresponds to a restricted tensor decomposition of the interactions between entities and relations. Specifically, DistMult represents each relation as a diagonal matrix and each entity as a vector; the score of a triple is the sum of the element-wise product of the head vector, the diagonal of the relation matrix, and the tail vector.
4.3 Results and Analysis
We carried out three experiments. The first applied KGEMs directly to the link prediction task. The second used the pre-trained model with its original settings to extract features from knowledge graph nodes and edges as initial KGEM embeddings. The third, our proposed method, fine-tuned the pre-trained model on the domain-specific corpus and then repeated the second experiment. For evaluation, we adopt two metrics: hit@k, with k set to 1, 5, and 10, and AAMR (Adjusted Arithmetic Mean Rank) [21]. AAMR divides the arithmetic mean rank by its expected value under random scoring, which makes it comparable across candidate sets of different sizes; lower is better, and a value of 1 corresponds to random ranking. To mitigate the effect of chance, we ran each experiment 10 times and report the maximum and average hit@k values and the minimum and average AAMR values. The results are shown in Table 1 and Table 2, which give the optimal value (maximum hit@k or minimum AAMR) with the average value in parentheses. Our method removes the Hit@1 loss caused by using BERT directly and improves over the baseline KGEMs: TransH-Ours achieves the best value on every metric, and DistMult-Ours achieves the best Hit@1 and Hit@5, although DistMult-BERT remains slightly better on Hit@10 and AAMR.
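To make the two scoring functions and the two metrics concrete, the sketch below implements their textbook forms: TransH projects head and tail onto a relation-specific hyperplane with normal w_r and then translates by d_r; DistMult takes the trilinear product; hit@k counts ranks within k; AAMR divides the mean rank by its expectation (n + 1)/2 under random scoring. It is an illustrative NumPy sketch, not PyKEEN’s implementation: PyKEEN’s TransH details (norm, regularization) may differ, and its evaluator may use filtered ranks and different tie handling, whereas `rank_of` here is unfiltered and optimistic. All embeddings are toy values.

```python
import numpy as np


def transh_scores(h, d_r, w_r, tails):
    """TransH score of (h, r, t) for every candidate tail; higher means more plausible.

    h, d_r, w_r: (dim,) head embedding, relation translation, hyperplane normal
    tails:       (num_candidates, dim) candidate tail embeddings
    """
    w = w_r / np.linalg.norm(w_r)            # unit normal of the relation hyperplane
    h_proj = h - (h @ w) * w                 # project the head onto the hyperplane
    t_proj = tails - np.outer(tails @ w, w)  # project every candidate tail
    return -np.linalg.norm(h_proj + d_r - t_proj, axis=1)


def distmult_scores(h, r, tails):
    """DistMult score sum_i h_i * r_i * t_i for every candidate tail."""
    return tails @ (h * r)


def rank_of(scores, true_index):
    """1-based rank of the true tail (ties counted optimistically, no filtering)."""
    return 1 + int(np.sum(scores > scores[true_index]))


def hits_at_k(ranks, k):
    return float(np.mean(np.asarray(ranks) <= k))


def adjusted_mean_rank(ranks, num_candidates):
    """Mean rank divided by its expected value (n + 1) / 2 under uniformly random scoring."""
    ranks = np.asarray(ranks, dtype=float)
    expected = (np.asarray(num_candidates, dtype=float) + 1.0) / 2.0
    return float(ranks.mean() / expected.mean())


# TransH toy case: the hyperplane normal is the first axis, so projection drops that axis.
h = np.array([1.0, 2.0, 0.0])
d_r = np.array([0.0, 1.0, 0.0])
w_r = np.array([1.0, 0.0, 0.0])
tails = np.array([[5.0, 3.0, 0.0], [0.0, 0.0, 4.0]])
s = transh_scores(h, d_r, w_r, tails)  # the first tail fits exactly (score 0), the second scores -5

# DistMult toy case: scores are 6 and 0.
d = distmult_scores(np.array([1.0, 0.0]), np.array([2.0, 1.0]), np.array([[3.0, 0.0], [0.0, 5.0]]))

ranks = [1, 3, 10]
print(hits_at_k(ranks, 1), hits_at_k(ranks, 10), adjusted_mean_rank(ranks, [100] * 3))
```

With 53,889 candidate entities, a random model would have an expected rank near 26,945, so AAMR values around 0.003 to 0.05, as in Tables 1 and 2, mean the true tail is typically ranked far above chance.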
We attribute the degradation observed with plain BERT features to BERT being trained on general corpora: it struggles with professional-field vocabulary, which includes slang, abbreviations, and jargon. Extracting inaccurate semantic information can lead to misleading node feature initialization, making the embeddings less effective than random ones. Fine-tuning BERT on a domain-specific corpus allows it to better comprehend sentence semantics and extract more valuable features, making it an effective feature extractor.
Table 1. The Results of TransH.
Model          Hit@1          Hit@5          Hit@10         AAMR
TransH         0.616 (0.571)  0.893 (0.878)  0.908 (0.898)  0.0056 (0.0140)
TransH-BERT    0.604 (0.473)  0.886 (0.784)  0.914 (0.812)  0.0061 (0.0524)
TransH-Ours    0.679 (0.614)  0.920 (0.889)  0.951 (0.923)  0.0026 (0.0091)
Table 2. The Results of DistMult.
Model          Hit@1          Hit@5          Hit@10         AAMR
DistMult       0.613 (0.582)  0.899 (0.888)  0.912 (0.906)  0.0351 (0.0411)
DistMult-BERT  0.578 (0.521)  0.897 (0.889)  0.921 (0.916)  0.0238 (0.0282)
DistMult-Ours  0.637 (0.602)  0.901 (0.892)  0.918 (0.912)  0.0246 (0.0297)
We analyzed the effect of the PCA dimensionality reduction by repeating the experiments with initial embeddings of 50, 100, 150, 200, 250, and 300 dimensions. The 300-dimensional embedding gave the best performance; we did not try higher dimensions because they caused CUDA out-of-memory errors on our machine. Overall, the results indicate that more dimensions tend to yield better performance, although not monotonically: performance declined from 100 to 200 dimensions, possibly because instability in the PCA reduction caused slight information loss, which the 300-dimensional embedding recovered. We report the performance of our proposed method at these dimensions using the hit@1 metric; for brevity, we report only the maximum value, plotted as a line graph in Fig. 5.
Fig. 5. As the dimension increases, hit@1 shows an erratic increase and reaches its peak at 300.
5 Conclusion
In this paper, we investigated the effectiveness of fine-tuning a pre-trained language model on a small domain-specific corpus to address the performance degradation associated with using the original BERT as a feature extractor. We proposed an approach that constructs a knowledge graph via ontology design and uses a fine-tuned pre-trained language model as a feature extractor for KGEMs to enhance link prediction results. We conducted our experiments on a real-world dataset of main electrical equipment defects, and our model achieved notable improvements in link prediction with the KGEMs TransH and DistMult.
Acknowledgment. This work is supported by the Major Program of Xiamen (3502Z20231006); the National Natural Science Foundation of China (62176227, U2066213); and the Fundamental Research Funds for the Central Universities (20720210047).
References
1. Fan, S., Liu, X., Chen, Y., et al.: How to construct a power knowledge graph with dispatching data? Sci. Program. 2020, 1–10 (2020)
2. Lü, L., Zhou, T.: Link prediction in complex networks: a survey. Physica A 390(6), 1150–1170 (2011)
3. Wang, Q., Mao, Z., Wang, B., et al.: Knowledge graph embedding: a survey of approaches and applications. IEEE Trans. Knowl. Data Eng. 29(12), 2724–2743 (2017)
4. Devlin, J., Chang, M.W., Lee, K., et al.: BERT: Pre-Training of Deep Bidirectional Transformers for Language Understanding. arXiv preprint arXiv:1810.04805 (2018)
5. Liu, Y., Ott, M., Goyal, N., et al.: RoBERTa: A Robustly Optimized BERT Pretraining Approach. arXiv preprint arXiv:1907.11692 (2019)
6. Reimers, N., Gurevych, I.: Sentence-BERT: Sentence Embeddings Using Siamese BERT-Networks. arXiv preprint arXiv:1908.10084 (2019)
7. Maćkiewicz, A., Ratajczak, W.: Principal Components Analysis (PCA). Comput. Geosci. 19(3), 303–342 (1993)
8. Liu, Q., Kusner, M.J., Blunsom, P.: A Survey on Contextual Embeddings. arXiv preprint arXiv:2003.07278 (2020)
9. Liu, W., Zhou, P., Zhao, Z., et al.: K-BERT: Enabling language representation with knowledge graph. In: Proceedings of the AAAI Conference on Artificial Intelligence 34(03), 2901–2908 (2020)
10. Yao, L., Mao, C., Luo, Y.: KG-BERT: BERT for Knowledge Graph Completion. arXiv preprint arXiv:1909.03193 (2019)
11. Li, D., Yi, M., He, Y.: LP-BERT: Multi-Task Pre-Training Knowledge Graph BERT for Link Prediction. arXiv preprint arXiv:2201.04843 (2022)
12. Gururangan, S., Marasović, A., Swayamdipta, S., et al.: Don’t Stop Pretraining: Adapt Language Models to Domains and Tasks. arXiv preprint arXiv:2004.10964 (2020)
13. Sushil, M., Suster, S., Daelemans, W.: Are we there yet? Exploring clinical domain knowledge of BERT models. In: Proceedings of the 20th Workshop on Biomedical Language Processing, pp. 41–53 (2021)
14. Gruber, T.R.: Toward principles for the design of ontologies used for knowledge sharing? Int. J. Hum. Comput. Stud. 43(5–6), 907–928 (1995)
15. Ali, M., Berrendorf, M., Hoyt, C.T., et al.: PyKEEN 1.0: a Python library for training and evaluating knowledge graph embeddings. J. Mach. Learn. Res. 22(1), 3723–3728 (2021)
16. Wolf, T., Debut, L., Sanh, V., et al.: Transformers: State-of-the-art natural language processing. In: Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations, pp. 38–45 (2020)
17. Ji, S., Pan, S., Cambria, E., et al.: A survey on knowledge graphs: representation, acquisition, and applications. IEEE Trans. Neural Netw. Learn. Syst. 33(2), 494–514 (2021)
18. Wang, Z., Zhang, J., Feng, J., et al.: Knowledge graph embedding by translating on hyperplanes. In: Proceedings of the AAAI Conference on Artificial Intelligence 28(1) (2014)
19. Yang, B., Yih, W., He, X., et al.: Embedding Entities and Relations for Learning and Inference in Knowledge Bases. arXiv preprint arXiv:1412.6575 (2014)
20. Bordes, A., Usunier, N., Garcia-Duran, A., et al.: Translating embeddings for modeling multi-relational data. Adv. Neural Inf. Process. Syst. 26 (2013)
21. Suchanek, F.M., Kasneci, G., Weikum, G.: Yago: a large ontology from Wikipedia and WordNet. J. Web Semant. 6(3), 203–217 (2008)
Hybrid CNN-LSTM Model for Multi-industry Electricity Demand Prediction
Haitao Zhang1, Yuxing Dai2, Qing Yin1, Xin He1, Jian Ju1, Haotian Zheng1, Fengling Shen1, Wenjuan Guo1, Jinhu Li3, Zhihong Zhang2(B), and Yong Duan1
1 State Grid Shaanxi Marketing Service Center (Metrology Center), Xi’an, China
2 School of Informatics, Xiamen University, Xiamen, China
3 State Grid Info-Telecom Great Power Science and Technology Co., Ltd., Xiamen, China
Abstract. Accurately predicting electricity demand is crucial for optimizing power resource allocation and improving the safety and economic performance of power grid operations, bringing significant economic and social benefits. To address this challenge, we propose a hybrid model that combines Convolutional Neural Networks (CNNs) and Long Short-Term Memory (LSTM) networks to predict electricity demand across different industries in urban areas. The model uses the LSTM component to capture the temporal patterns of the time series data and the CNN component to extract spatial features of electricity demand across different areas. We evaluate our model on a diverse dataset of electricity demand from multiple city areas and industries. The experimental results demonstrate that our model outperforms state-of-the-art methods, significantly improving the accuracy of electricity demand prediction. Overall, the proposed hybrid model provides a valuable framework for accurate electricity demand prediction and has practical implications for power grid operation and management in urban areas.
Keywords: Convolutional Neural Network · Time Series Network · Electricity Demand Prediction
1 Introduction
Accurate electricity demand prediction can optimize the allocation of power resources and improve the economy and safety of power grid operation, which is of great economic and social significance. Accurately predicting the electricity demand of different regions, different types of electricity use, and even different industries at different time scales is therefore important both for the healthy and stable development of power grid enterprises and for the steady construction of the power trading market system. Electricity demand prediction is widely used in energy management and power system planning. It helps electricity suppliers manage and plan production and distribution effectively by analyzing historical electricity demand data, weather data, economic indicators, and other relevant factors to predict future electricity demand. In recent years, with the development of data science and artificial intelligence, methods and tools for electricity demand prediction have been continually updated and improved. Machine learning and deep learning models have increasingly replaced traditional time series models, and combining multiple factors has become common in electricity demand prediction. These new methods have markedly improved the accuracy and efficiency of electricity demand prediction and enabled better services for both electricity suppliers and customers.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 726–737, 2023. https://doi.org/10.1007/978-981-99-4761-4_61
Commonly used electricity prediction models include regression analysis models, grey system models, and others. Box et al. [1] proposed the time series-based ARIMA model. Vahakyla et al. [2] applied an autoregressive integrated moving average model to load prediction with significant results. Christiaanse et al. [3] used an exponential smoothing model for electricity demand prediction and obtained more accurate results. Taylor et al. [4] used a model with double seasonality to improve the accuracy of load prediction. Papalexopoulos et al. [5] modeled load prediction with multiple regression and achieved better results. Amjady et al. [6] designed a new ARIMA model for short-term load prediction that focused on the relationship between temperature and electricity data to describe day-to-day variability. Feinberg et al. [7] took into account periodic factors, meteorological variations, and a multidimensional classification of electricity users, setting various types of parameters to predict short-term power values.
Electricity demand prediction plays a crucial role in ensuring the reliable and efficient operation of power systems. Accurate prediction facilitates the planning and operation of the power grid and enables utilities to make informed decisions on investment and pricing strategies. In this paper, we focus on predicting electricity demand for different industries in a city using a combination of Convolutional Neural Network (CNN) and Long Short-Term Memory (LSTM) models. Our contributions to electricity demand prediction are summarized as follows.
1. The CNN-LSTM model leverages the strengths of both components: the CNN is effective at capturing spatial dependencies in the input data, and the LSTM at capturing long-term temporal dependencies. This combination can improve prediction accuracy.
2. The proposed model uses historical electricity demand data, weather data, and calendar information as input features, allowing it to capture complex relationships between multiple input variables and the target variable, which is crucial for accurate demand prediction.
3. We evaluate the proposed method on a real-world electricity demand dataset; qualitative and quantitative analyses show its strong performance and verify the validity of the model.
2 Proposed Method
In this section, we first present the overall framework of the model for predicting electricity demand in different sectors of the city. Second, we summarize the model’s input information and analyze the relationships between variables. Finally, we describe each component of the CNN-LSTM model for the time series prediction task.
2.1 Overall Framework
Electricity demand prediction and time series prediction are closely related because electricity demand data typically exhibit time series patterns. Time series prediction methods analyze historical data to identify patterns and trends and use this information to predict future electricity demand; they are therefore a critical component of electricity demand prediction. In this paper, we combine CNN and LSTM networks to exploit the advantages of both. The CNN-LSTM framework, shown in Fig. 1, consists of two main parts: the CNN is mainly responsible for feature extraction, and the LSTM network is primarily responsible for electricity demand prediction.
Fig. 1. CNN-LSTM framework.
The CNN-LSTM model can be seen as a two-stage process: feature extraction with the CNN and prediction with the LSTM. The CNN is used to extract features from the input time series data and the LSTM is used to predict future demand values based on the extracted features. To do this, the output from the CNN is fed into the LSTM network as a sequence of features. The LSTM takes these features as inputs and learns to predict future demand values based on past demand values and the extracted features.
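For the LSTM to learn to predict future demand from past demand and extracted features, the historical series first has to be cut into supervised samples. The paper does not state its window length or forecast horizon, so the sketch below uses a hypothetical 7-step window and one-step horizon, and assumes demand occupies column 0 of a (time, features) array with covariates such as temperatures in the remaining columns.

```python
import numpy as np


def make_windows(series: np.ndarray, window: int, horizon: int = 1):
    """Build (past window -> future demand) training pairs from a multivariate series.

    series: (T, num_features); column 0 is assumed to hold the demand to predict,
            the remaining columns are covariates.
    returns X: (num_samples, window, num_features) and y: (num_samples,)
    """
    inputs, targets = [], []
    for start in range(len(series) - window - horizon + 1):
        inputs.append(series[start:start + window])              # past `window` steps
        targets.append(series[start + window + horizon - 1, 0])  # demand `horizon` steps ahead
    return np.stack(inputs), np.array(targets)


# Toy daily series: demand 0..9 plus one covariate column.
days = np.arange(10, dtype=float)
series = np.column_stack([days, 0.5 * days])
X, y = make_windows(series, window=7)
print(X.shape, y)  # (3, 7, 2) [7. 8. 9.]
```

Each sample of X is what the CNN stage would consume; the matching entry of y is the value the LSTM stage is trained to predict.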
One of the key advantages of the CNN-LSTM model is that the CNN can learn to extract features that capture the temporal dependencies and patterns in the input data, and the LSTM can use these extracted features to make accurate predictions of future demand. The dataset used in this paper contains features including daily and monthly electricity demand for each power supply region, maximum load, and the type of electricity use in the province. Meteorological information is also indispensable, so the dataset contains variables such as daily maximum and minimum temperatures. Integrating these different features is very helpful for electricity demand prediction. We represent the data along these dimensions as high-dimensional time series, where $N$ is the number of series and $p$ the number of time steps:
$$Y_n = \{y_{n1}, y_{n2}, \ldots, y_{np}\}, \quad n = 1, 2, \ldots, N \tag{1}$$
Similarly, the covariates form time series with different characteristics:
$$X_n = \{x_{n1}, x_{n2}, \ldots, x_{np}\}, \quad n = 1, 2, \ldots, N \tag{2}$$
For such data, the focus of this paper is how to build a prediction model for the target sequence Y_n using suitable statistical and artificial intelligence methods, and how to predict each dimension with the target accuracy.

2.2 Feature Extraction Module

Convolutional Neural Networks (CNNs) have been widely used in image recognition, but they can also process time series data such as electricity demand. Electricity demand data can be treated as a one-dimensional signal indexed by time step, where each time step carries the input features and the output is the corresponding electricity demand value. The convolution operation detects local patterns in the data, which makes it useful for extracting features from electricity demand series. A convolutional layer consists of multiple filters, each acting as a feature detector: the filters slide across the input data and produce feature maps that capture its patterns. For electricity demand data, a CNN can extract features such as daily, weekly, or monthly demand patterns, trends, and seasonality. For example, a CNN can be trained to detect the daily demand pattern; for hourly data, the filter would have a length of 24, corresponding to the 24 h in a day, and the feature map produced by sliding it over the input would represent the daily demand pattern. Using multiple filters of different lengths helps capture different patterns: for daily data, a filter of length 7 can capture weekly demand patterns, while a filter of length 30 can capture monthly demand patterns. We design the CNN feature extractor according to the following formula:

$$ h_i = f(W \cdot x_i + b), \quad i = 1, 2, \ldots, p \tag{3} $$
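Eq. 3 can be illustrated with a small pure-Python sketch of a one-dimensional convolution followed by a ReLU activation f. This is an illustrative assumption rather than the authors' code: it uses stride 1, no padding, a single input channel, and fixed averaging filters of lengths 7 and 30 in place of learned weights, so it only shows how filters of different lengths yield different feature maps. As in common deep-learning frameworks, the kernel is not flipped (strictly a cross-correlation).

```python
from typing import Dict, List, Sequence


def relu(z: float) -> float:
    return max(0.0, z)


def conv1d_valid(series: Sequence[float], kernel: Sequence[float],
                 bias: float = 0.0) -> List[float]:
    """Eq. (3) for one filter: h_i = f(W . x_i + b), stride 1, no padding.

    x_i is the length-k segment of the series starting at position i.
    """
    k = len(kernel)
    out = []
    for i in range(len(series) - k + 1):
        segment = series[i:i + k]
        z = sum(w * x for w, x in zip(kernel, segment)) + bias
        out.append(relu(z))
    return out


def multi_scale_features(series: Sequence[float],
                         lengths: Sequence[int] = (7, 30)) -> Dict[int, List[float]]:
    """One feature map per filter length (e.g. weekly and monthly windows of daily data)."""
    maps = {}
    for k in lengths:
        kernel = [1.0 / k] * k      # fixed averaging filter; a trained CNN learns these weights
        maps[k] = conv1d_valid(series, kernel)
    return maps
```

A series of length L gives a feature map of length L − k + 1 for a filter of length k, so longer filters summarize wider spans of history with fewer outputs.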
H. Zhang et al.
In Eq. 3, x_i represents the input at the i-th time step of the electricity demand data Y_n or the covariate data X_n, W is the weight matrix, b is the bias term, and f(·) is the activation function; h_i is the feature vector extracted by the convolutional kernel at the i-th time step.

2.3 Time Series Prediction Module

LSTM networks are a type of recurrent neural network (RNN) designed to handle sequential data such as time series. They can learn long-term dependencies and retain information over longer periods than traditional RNNs. In electricity demand prediction, the features extracted by the CNN can be used as inputs to an LSTM network, which is trained on historical demand data to learn the patterns and relationships in the data and uses them to predict future demand values.

Unlike standard neural units, an LSTM layer has specially designed memory units, which give it a memory of the input sequence. Each memory unit contains three gates that manage its state, the forget gate (z_f), the input gate (z_i), and the output gate (z_o), together with the memory cell (c_t) and the hidden state (h_t). When processing the input sequence, each gate uses the activation function σ(·) to control how strongly it is triggered, so that state changes and the addition of information through the unit are subject to certain conditions.
1. Forget gate (z_f): Let x_t be the input at the current time step and h_{t−1} the output of the hidden layer at the previous time step. When the input flows through the LSTM layer, it first passes through the forget gate, which conditionally decides which irrelevant information to discard from the cell state.
2. Input gate (z_i): Let c̃_t be the candidate cell state, c_{t−1} the cell state at the previous time step, and c_t the current cell state. After the forget gate has acted on the cell state, the input gate conditionally decides which values from the current input to add to the current cell state.
3. Output gate (z_o): Let h_t be the output at the current time step. After the state update, the output gate conditionally decides which information to output.
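The paper describes the gates only in words. The sketch below assumes the standard LSTM update equations, which the paper does not list explicitly: z_f = σ(W_f[h_{t−1}, x_t] + b_f), z_i = σ(W_i[h_{t−1}, x_t] + b_i), c̃_t = tanh(W_c[h_{t−1}, x_t] + b_c), c_t = z_f ⊙ c_{t−1} + z_i ⊙ c̃_t, z_o = σ(W_o[h_{t−1}, x_t] + b_o), and h_t = z_o ⊙ tanh(c_t). It is written in pure Python for readability; the parameter layout (a dict of (W, b) pairs keyed "f", "i", "c", "o") is hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]
Params = Dict[str, Tuple[Matrix, Vector]]


def sigmoid(z: float) -> float:
    """Numerically stable logistic function σ(z)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def affine(W: Matrix, b: Vector, v: Sequence[float]) -> Vector:
    """Return W v + b for a row-major matrix W."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]


def lstm_step(x_t: Sequence[float], h_prev: Vector, c_prev: Vector,
              params: Params) -> Tuple[Vector, Vector]:
    """One LSTM time step; every W acts on the concatenation [h_prev, x_t]."""
    v = list(h_prev) + list(x_t)

    def gate(name: str, act) -> Vector:
        W, b = params[name]
        return [act(a) for a in affine(W, b, v)]

    z_f = gate("f", sigmoid)        # forget gate: how much of c_{t-1} to keep
    z_i = gate("i", sigmoid)        # input gate: how much of the candidate to add
    c_hat = gate("c", math.tanh)    # candidate cell state
    z_o = gate("o", sigmoid)        # output gate: how much of the cell to expose
    c_t = [f * c + i * g for f, c, i, g in zip(z_f, c_prev, z_i, c_hat)]
    h_t = [o * math.tanh(c) for o, c in zip(z_o, c_t)]
    return h_t, c_t


def run_lstm(features: Sequence[Sequence[float]], params: Params,
             hidden_size: int) -> Vector:
    """Feed a sequence of feature vectors (e.g. CNN outputs h_i) through the cell."""
    h = [0.0] * hidden_size
    c = [0.0] * hidden_size
    for x_t in features:
        h, c = lstm_step(x_t, h, c, params)
    return h    # last hidden state, to be mapped to a demand prediction
```

Setting a large positive forget-gate bias and a large negative input-gate bias makes the cell copy c_{t−1} almost unchanged, which is the mechanism that lets an LSTM carry information across many time steps.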
The CNN-LSTM framework is shown in Fig. 2. Each LSTM cell behaves like a miniature state machine, and the weights of each gate are learned during training. After the CNN extracts features from the electricity demand and covariate data, these features are passed to the LSTM for further processing: the extracted feature vector h_i serves as the input of the LSTM cell at the corresponding time step. By feeding the CNN features to the LSTM, the model can capture temporal dependencies and patterns in the electricity demand and covariate data over a longer time horizon, which can further improve prediction accuracy.

The CNN-LSTM model offers several advantages for predicting the electricity demand of various industries in cities. First, it can handle complex, multi-modal time series data, including the electricity demand itself and covariates such as weather factors. Second, the CNN component captures local patterns in the input data, while the LSTM component captures long-term dependencies and temporal dynamics. Third, the model can learn from both past and current information, making it suitable for real-time prediction. Overall, the CNN-LSTM model provides an efficient and accurate approach to electricity demand prediction, which can help optimize energy usage and planning in cities.
Hybrid CNN-LSTM Model for Multi-industry Electricity Demand Prediction
Fig. 2. CNN-LSTM framework.
3 Experiment

3.1 Experimental Dataset

This paper uses a dataset covering the period from January 1, 2020, to January 31, 2023. The electricity demand data from January 1, 2020, to June 17, 2022, are used for training, and the subsequent 225 data points are used for medium-term prediction, so the training and test sets are divided in a 4:1 ratio. The dataset also contains meteorological features and other relevant factors that affect electricity demand patterns, including holidays and weekends. The data were pre-processed to eliminate outliers and missing values. The dataset provides a valuable resource for developing and testing machine learning models that can capture the complex and dynamic nature of electricity demand in urban areas. The relevant features of the dataset are as follows.

Dataset characteristics:
a) Time-related characteristics: year, month, day, day of the week, date, and week of the year;
b) Temperature-related characteristics: average, minimum, and maximum temperature;
c) Weather-related characteristics: wind speed, surface pressure, and overall humidity;
d) Other characteristics:
Maximum load: this variable is given in the electricity demand dataset and is the most important covariate;
Holiday or not: this covariate mainly helps characterize the electricity use of urban and rural residents and of tertiary industries such as accommodation and catering; it covers all holidays, including weekends and all statutory holidays;
Summer vacation or not: used to fit the peak electricity demand.
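The stated split can be checked with a short Python sketch, assuming one data point per calendar day (the paper does not state the sampling interval explicitly): January 1, 2020 to June 17, 2022 inclusive is 899 days, and 899 : 225 is approximately 4 : 1. The split keeps time order (no shuffling), so the model is always evaluated on data that comes after its training period. The helper names are illustrative.

```python
from datetime import date, timedelta
from typing import List, Sequence, Tuple


def daily_range(start: date, end: date) -> List[date]:
    """Every calendar day from start to end, inclusive."""
    return [start + timedelta(days=k) for k in range((end - start).days + 1)]


def chronological_split(values: Sequence, train_fraction: float = 0.8) -> Tuple[list, list]:
    """Earliest `train_fraction` of the series for training, the rest for testing (no shuffling)."""
    cut = int(len(values) * train_fraction)
    return list(values[:cut]), list(values[cut:])


train_days = daily_range(date(2020, 1, 1), date(2022, 6, 17))
n_train, n_test = len(train_days), 225
```

Shuffling before splitting would let the model see days that lie after the test period, which overstates accuracy for forecasting; a chronological cut avoids this.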
3.2 Experimental Baselines

The baseline methods used for comparison are described below.
1. Prophet: a time series prediction model developed by Facebook's Data Science team. It is designed to handle a variety of time series, including those with seasonality and trends. For electricity demand prediction, Prophet can capture patterns at different time scales, such as daily, weekly, and yearly trends. It can also handle missing data and outliers, which are common in time series data. Prophet is easy to use, requires minimal hyperparameter tuning, and can be applied to a wide range of time series prediction tasks.
2. GBDT (Gradient Boosting Decision Tree): a machine learning model widely used for regression and classification. For electricity demand prediction, GBDT can learn the non-linear relationship between the input features and the target variable, electricity demand. It handles high-dimensional, non-linear data and can be trained efficiently on large datasets.
3. LSTM: the classical LSTM network for time series prediction. Note that this baseline is trained on, and predicts, only the electricity demand data; unlike the model proposed in this paper, it does not use the other covariates.

3.3 Experimental Settings

We use min-max normalization: each column of data is normalized to the interval [0, 1] before being input to the network for training, and the final predictions are obtained by inverting the normalization:

$$ x^{*} = \frac{x - x_{\min}}{x_{\max} - x_{\min}} \tag{4} $$

where x denotes the electricity demand data, and x_max and x_min denote the maximum and minimum values of the data, respectively. We use the mean squared error (MSE) as the loss function for this regression task, as given in Eq. 5, where x_i denotes the model's predicted electricity demand for the i-th sample, y_i denotes the corresponding ground truth, and n denotes the number of training samples. The model is trained with the Adam optimizer at a learning rate of 10^{−3}, and ReLU is used as the activation function so that the model can learn nonlinear relationships.

$$ \mathrm{MSE}_{\mathrm{loss}} = \frac{1}{n} \sum_{i=1}^{n} (x_i - y_i)^2 \tag{5} $$
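Eqs. 4 and 5 translate directly into code. In the minimal pure-Python sketch below, fitting x_min and x_max on the training portion only (and reusing them for the test portion and for inverting predictions) is our assumption, a common way to avoid leaking test information, since the paper does not say which data the extrema come from. The constant-column guard is also an addition.

```python
from typing import List, Sequence, Tuple


def min_max_fit(column: Sequence[float]) -> Tuple[float, float]:
    """Return (x_min, x_max) of a column, e.g. of the training data."""
    return min(column), max(column)


def min_max_transform(column: Sequence[float], x_min: float, x_max: float) -> List[float]:
    """Eq. (4): x* = (x - x_min) / (x_max - x_min)."""
    span = x_max - x_min
    if span == 0:
        return [0.0 for _ in column]    # constant column: avoid dividing by zero
    return [(x - x_min) / span for x in column]


def min_max_inverse(scaled: Sequence[float], x_min: float, x_max: float) -> List[float]:
    """Invert Eq. (4) so predictions are reported in the original units."""
    return [s * (x_max - x_min) + x_min for s in scaled]


def mse(pred: Sequence[float], truth: Sequence[float]) -> float:
    """Eq. (5): mean of (x_i - y_i)^2 over n samples."""
    if len(pred) != len(truth) or not pred:
        raise ValueError("pred and truth must be non-empty and of equal length")
    return sum((x - y) ** 2 for x, y in zip(pred, truth)) / len(pred)
```

Because the loss is computed on normalized values, errors in industries with very different demand magnitudes contribute on a comparable scale during training.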
The CNN-LSTM model takes a time series feature map as its input. Electricity history data, temperature information, holiday information, and similar data are mutually independent time series. To couple these features that affect electricity demand, we borrow the word vector representation from natural language processing and represent the electricity value at a given time by a vector of its associated features, stacking these vectors into a matrix.

3.4 Evaluation Metrics

We use two evaluation metrics to assess the model on the electricity demand prediction task: the monthly error and the daily error, calculated as shown in Eqs. 6 and 7, respectively.

$$ \mathrm{month}_{\mathrm{error}} = \frac{\sum_{i=1}^{n} y_{\mathrm{pred}}(i) - \sum_{i=1}^{n} y_{\mathrm{act}}(i)}{\sum_{i=1}^{n} y_{\mathrm{act}}(i)} \tag{6} $$

$$ \mathrm{daily}_{\mathrm{error}} = \frac{1}{n} \sum_{i=1}^{n} \frac{y_{\mathrm{pred}}(i) - y_{\mathrm{act}}(i)}{y_{\mathrm{act}}(i)} \tag{7} $$
where y_pred(i) and y_act(i) denote the predicted value and the ground truth of electricity demand on day i, respectively, and n is the number of days in the evaluated period. We ran electricity demand prediction tasks for different cities. The industries are divided into the following 13 categories, abbreviated for ease of reading:
1) Urban and rural residential electricity demand (Urre);
2) Agriculture, forestry, animal husbandry, and fishery (Afaf);
3) Accommodation and catering industry (Aci);
4) Construction industry (Ci);
5) Real estate industry (Rei);
6) Industrial sector (Is);
7) Information transmission, software, and information technology services industry (Itsitsi);
8) Total electricity demand of the society (Teds);
9) Financial industry (Fi);
10) Wholesale and retail industry (Twri);
11) Rental and business services industry (Rbsi);
12) Public service and management service (Psms);
13) Transportation, warehousing, and postal industry (Twpi).
The results of predicting the next three months are shown in Table 1. The monthly error is abbreviated as ME, the daily error as DE, and "Model-Error" denotes the monthly or daily error of the given method. For data confidentiality, we refer to the predicted city as A-City. The model abbreviations are: CNN-LSTM (CL), GBDT (GD), Prophet (Pr), and LSTM (L). In Table 1, the monthly and daily errors with the smallest absolute value are marked in red as the best results.
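Eqs. 6 and 7 can be sketched in Python as follows. The extracted equations show no absolute-value bars and Table 1 contains negative errors, so this sketch reads both metrics as signed relative errors; that reading is our assumption.

```python
from typing import Sequence


def monthly_error(pred: Sequence[float], actual: Sequence[float]) -> float:
    """Eq. (6): signed relative error of the period total."""
    total_actual = sum(actual)
    return (sum(pred) - total_actual) / total_actual


def daily_error(pred: Sequence[float], actual: Sequence[float]) -> float:
    """Eq. (7): mean of the signed relative errors of the individual days."""
    if len(pred) != len(actual) or not actual:
        raise ValueError("pred and actual must be non-empty and of equal length")
    return sum((p - a) / a for p, a in zip(pred, actual)) / len(actual)
```

The two metrics weight days differently: ME implicitly weights each day by its actual demand, while DE gives every day equal weight. With predictions [120, 100] against actual values [100, 200], ME is −80/300 while DE is (0.2 − 0.5)/2 = −0.15. Under the signed reading, over- and under-predictions on different days can also cancel, so a value near zero does not by itself imply a close day-by-day fit.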
Table 1 reports the electricity demand prediction errors for the 13 industries for the month of August. Compared with the baseline methods, the CNN-LSTM model proposed in this paper achieves the lowest monthly and daily errors in seven industries, and the lowest value of at least one of the two errors in 11 industries, which supports its effectiveness. In contrast, the model using only the LSTM achieves the lowest error in only 2 industries, which illustrates the importance of the CNN module in the CNN-LSTM model. Moreover, the Prophet model performs worst, with the highest errors across all 13 industries. The GBDT model achieves slightly better results than Prophet but still lags behind CNN-LSTM in accuracy. Overall, the results in Table 1 indicate that the proposed CNN-LSTM model can accurately predict electricity demand in diverse industries in August, making it a promising approach for urban electricity demand prediction with potential applications in power grid operation and management.
Table 1. Electricity demand prediction error results of different models in 13 industries in A-City
Evaluation metrics (Model-Error)
Industry   CL-ME    CL-DE    GD-ME    GD-DE    Pr-ME    Pr-DE    L-ME     L-DE
Urre       -0.270   -0.177   -0.380    0.380    4.600    4.600   -0.240    0.340
Afaf       -0.012    0.090   -0.200    0.300    1.900    1.900   -0.028    0.280
Aci        -0.020    0.007   -0.140    0.200    0.210    0.360   -0.077    0.084
Ci          0.044    0.066   -0.150    0.170    0.550    0.670   -0.055    0.203
Rei        -0.027    0.0007  -0.180    0.200    0.400    0.500    0.0288   0.138
Is          0.330    0.330    0.120    0.120    0.530    0.620    0.083    0.088
Itsitsi     0.029    0.030   -0.060    0.080    0.520    0.550    0.015    0.040
Teds       -0.040    0.0006  -0.120    0.160    0.860    0.940   -0.028    0.100
Fi          0.012    0.029   -0.100    0.200    0.200    0.300   -0.018    0.128
Twri        0.024    0.045   -0.130    0.160    0.400    0.640   -0.025    0.094
Rbsi       -0.003    0.029   -0.200    0.220    0.510    0.620   -0.060    0.137
Psms       -0.077   -0.047   -0.090    0.170    0.330    0.400   -0.050    0.153
Twpi        0.007    0.017    0.004    0.060    0.004    0.060    0.007    0.057
We also plot the August electricity demand predictions of CNN-LSTM and the comparison models. The blue curve represents the model proposed in this paper, and the purple curve represents the ground truth of electricity demand. The curve comparisons are shown in Figs. 3, 4, and 5. The CNN-LSTM predictions closely follow the ground truth for the construction industry; the transportation, warehousing, and postal industry; the information transmission, software, and information technology services industry; the accommodation and catering industry; the financial industry; and the rental and business services industry. As the figures show, CNN-LSTM outperforms the other three methods in most industries, especially the “construction industry”, the “transportation, warehousing, and postal industry”, the “information transmission, software, and information technology services industry”, and the “financial industry”. The LSTM and Prophet methods also perform well in some industries, but not as well as CNN-LSTM. The GBDT method performs poorly in all industries compared with the other three methods. Overall, the results demonstrate the effectiveness of the proposed CNN-LSTM model for electricity demand prediction and its advantage over the traditional LSTM, Prophet, and GBDT methods.

In conclusion, we propose a CNN-LSTM model for electricity demand prediction that combines the strengths of CNN and LSTM models in capturing temporal and spatial features of the data. The proposed model outperforms the traditional LSTM, Prophet, and GBDT methods in accuracy and reliability. This research provides insights for future studies on electricity demand prediction and has practical implications for energy management in various industries.
Fig. 3. Comparative results of electricity demand predictions for different models in the construction industry and transportation, warehousing, and postal industry. (Color figure online)
Fig. 4. Comparative results of electricity demand predictions for different models in information transmission, software, and information technology services industry and accommodation and catering industry. (Color figure online)
Fig. 5. Comparative results of electricity demand predictions for different models in financial industry and rental and business services industry. (Color figure online)
4 Conclusion

In this paper, we propose a CNN-LSTM model to predict the electricity demand of different industries in a city. The model uses convolutional neural networks to extract meaningful features from the raw electricity data and long short-term memory networks to capture temporal dependencies and generate predictions. We evaluate the model on a real-world electricity demand dataset from a major city, where it achieves low prediction errors across different industries, and our experiments show that it compares favorably with the baseline methods. For future research, we suggest the following directions:
1. Exploring alternative network architectures, such as attention mechanisms and graph neural networks, to capture more complex relationships among industries and regions.
2. Investigating the transferability of the proposed model to other cities or regions, and adapting it to different datasets and scenarios.
Overall, the proposed CNN-LSTM model shows promise for predicting the electricity demand of different industries in a city, and further research can improve its performance and expand its applicability.

Acknowledgements. This work is supported by the Research Funds from State Grid Shaanxi (SGSNYX00SCJS2310024); the Major Program of Xiamen (3502Z20231006); the National Natural Science Foundation of China (62176227, U2066213); and the Fundamental Research Funds for the Central Universities (20720210047).
Improve Knowledge Graph Completion for Diagnosing Defects in Main Electrical Equipment
Jianye Huang¹, Jian Qian¹, Yanyu Chen², Rui Lin³, Yuyou Weng¹, Guoqing Lin¹, and Zhihong Zhang²(✉)
¹ State Grid Fujian Electric Power Research Institute, Fuzhou 350000, China
² School of Informatics, Xiamen University, Xiamen 361005, China
[email protected]
³ Fuzhou Power Supply Company of State Grid Fujian Electric Power Co., Ltd., Fuzhou 350000, China
Abstract. Defects in main electrical equipment can potentially cause power grid collapses, leading to catastrophic risks and immense pressure on dispatch operations. We adopt a Knowledge Graph Completion (KGC) approach to diagnose defects in main electrical equipment. The conventional approach for knowledge graph completion involves randomly initializing the feature representation of nodes and edges. In the context of main electrical equipment, we propose an innovative method based on ChatGPT to generate a corpus for fine-tuning the pre-trained model. We then employ the pre-trained model as a feature extractor for nodes and edges, enhancing the relationships between nodes and improving the initial embedding quality of both nodes and edges. Our pre-trained model fine-tuning process can be efficiently executed on a CPU with less than 10% of the parameters required for fine-tuning the entire pre-trained model. Our experimental results show that our model effectively improves the performance of KGC in diagnosing defects in main electrical equipment. Keywords: Knowledge graph completion · Defect diagnosis · Pre-training model
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 738–748, 2023.
https://doi.org/10.1007/978-981-99-4761-4_62

1 Introduction

The safe and reliable operation of the power grid is of paramount importance, as defects in main electrical equipment can lead to severe consequences such as power outages, equipment damage, and significant societal disruption. Prompt and accurate diagnosis of such defects is essential to maintain system reliability and mitigate the negative impacts on people's daily lives and the economy. Knowledge graphs, an innovative knowledge engineering technology, provide a structured way to describe defect events, their properties, and their relationships, allowing a more comprehensive representation of complex defect relationships. In main electrical equipment defect diagnosis, our objective is to predict the defective components from the given defect phenomena and their relationships. More generally, the model takes an entity from the knowledge graph and a given relationship as inputs and outputs the tail entity under this relationship. To this end, we extract triplets from tabular defect data, construct a knowledge graph, and apply a Knowledge Graph Completion (KGC) [1] algorithm for defect diagnosis. Given an input defect phenomenon or other entity, together with the relationship to be queried, our KGC model outputs the most probable candidate answers. To predict the tail entity, the model scores the probability of every entity node in the knowledge graph and sorts the nodes accordingly; the entity with the highest probability is taken as the answer, and the top-k candidates are reported for the hit@k metric. For example, given the defect phenomenon “the wind turbine does not rotate, and the circuit breaker for fan #4 in the cooling box trips and opens after being closed” and the relation “defect phenomenon – defective part”, the objective is to predict the defective part, which in this case is the “cooling box”. The general design of our system is illustrated in Fig. 1.
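The ranking step described above can be sketched in a few lines of Python. In practice the scores would come from a trained KGC model; here they are made-up numbers, and the helper names are hypothetical. Hits@k is computed in its usual form, as the fraction of test queries whose correct tail entity appears among the top k candidates.

```python
from typing import Dict, List, Sequence


def rank_candidates(scores: Dict[str, float]) -> List[str]:
    """Sort candidate tail entities by score, most probable first."""
    return sorted(scores, key=lambda entity: scores[entity], reverse=True)


def top_k(scores: Dict[str, float], k: int) -> List[str]:
    """The k candidate answers reported to the user."""
    return rank_candidates(scores)[:k]


def hits_at_k(ranked_lists: Sequence[List[str]], answers: Sequence[str], k: int) -> float:
    """Fraction of queries whose correct tail entity is among the top-k candidates."""
    hits = sum(1 for ranked, answer in zip(ranked_lists, answers) if answer in ranked[:k])
    return hits / len(answers)


# Hypothetical scores for the query (defect phenomenon, "defect phenomenon - defective part", ?).
scores = {"cooling box": 0.7, "fan motor": 0.2, "circuit breaker": 0.1}
```

Reporting several top-ranked candidates rather than a single answer lets maintenance staff check the most likely defective parts in order.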
Fig. 1. In this example, the input for the KGC model is the head entity (defect phenomenon) and the queried relation (phenomenon to part), with the model required to predict the tail entity. The KGC model is trained using defect data, which is derived from structured tables and converted into triplets. These triplets construct the knowledge graph, thereby supporting the training of the KGC model.
Inspired by previous studies, we recognized that a second round of pre-training on text from a specialized field can yield better results [2]. We fine-tuned the language model on a carefully designed task, choosing sentence classification. Specifically, we constructed our own CoLA (Corpus of Linguistic Acceptability) dataset, in which 80% of the sentences are original texts labeled as grammatically correct, while the remaining 20% were generated by ChatGPT and are labeled as incorrect. Our work explores generating binary classification datasets from a basic corpus and fine-tuning the language model to better understand professional knowledge. To save computational resources during fine-tuning and enable running in a CPU environment, we use Adapter BERT [3], which adds adapter layers to BERT [4] instead of directly fine-tuning BERT's parameters; this significantly reduces computational costs. We also add a fully connected layer after the output of BERT to reduce the dimensionality of the input to the KGC model, which further reduces the computational cost of the KGC algorithm. Our goal is to reduce the computational demands of the application so that it can be deployed even on hardware with average capabilities. In summary, our contributions are twofold:
– We propose a novel approach that uses ChatGPT to generate a labeled CoLA corpus from a long-text corpus of power systems, facilitating the fine-tuning of pre-trained models for main electrical equipment. This enables the model to better understand the specialized knowledge required for defect diagnosis in main electrical equipment.
– Using the language model fine-tuned on the electric power defect corpus, we extract features of the nodes and edges in the knowledge graph to initialize the embeddings of the KGC model. This improves the model's performance in main electrical equipment defect diagnosis, allowing more accurate and efficient predictions.
Our proposed approach demonstrates the potential of combining knowledge graphs and advanced natural language processing techniques to improve the diagnosis of defects in main electrical equipment.
By enhancing the model’s ability to accurately predict defective components, we contribute to the ongoing efforts to maintain the reliability and safety of power grid systems, ultimately benefiting society as a whole.
2 Related Work

Various techniques have been developed to detect and diagnose defects in power equipment, such as text mining [5]; however, these techniques face challenges in practical applications. Another approach uses machine learning with the technical condition index as an informative parameter for quick and accurate assessment of technological disturbances in power equipment operation [6]. Fault Detection and Diagnosis (FDD) techniques help ensure the safety and reliability of power equipment. Both traditional and more recent signal processing-based FDD approaches have been developed, and artificial intelligence-based FDD methods are gaining attention [7]. Despite these advances, open challenges remain, and new approaches for diagnosing power equipment defects are still needed. Knowledge graphs have shown promise in various domains, particularly in defect diagnosis [8]. A large-scale knowledge graph of electric power equipment faults can support automatic fault diagnosis and intelligent question answering in the electric power industry. A BERT-BiLSTM-CRF model [4, 9, 10] has been proposed for this purpose; it improves the accuracy of Chinese entity recognition and extraction from technical literature, enabling the construction of a comprehensive and accurate Chinese knowledge graph of electric power equipment faults. KGs integrate data, aiding fault disposal and enhancing power system emergency handling [11]. A review of KGs in fault diagnosis offers a guide for further research. A multimodal knowledge graph addresses limitations of traditional KG methods [12], and its successful applications suggest potential for KGC in diagnosing power equipment defects. Pre-trained models (PTMs) have revolutionized natural language processing (NLP) by significantly improving performance on various tasks [13]. Models such as BERT demonstrate the advantage of learning general language representations that can be fine-tuned for specific domains and tasks. Adapting PTMs to downstream tasks has proven effective in boosting performance and has opened up new research directions. ChatGPT, an NLP model built on the transformer architecture, has been successful in natural language generation across diverse applications, including customer service, healthcare, education, and research [14]. Fine-tuning with human feedback can better align language models with user intent across a wide range of tasks [15]; InstructGPT, a fine-tuned version of GPT-3, improves truthfulness and reduces toxic output while maintaining performance on public NLP datasets. These advances in PTMs and their integration with knowledge graphs indicate potential for further innovation in NLP and defect diagnosis. In BERT fine-tuning, a significant advance is Adapter BERT, a parameter-efficient transfer learning approach for NLP. Houlsby et al. proposed adapter modules that yield a compact and scalable model, requiring only a few additional trainable parameters per task.
This method keeps the original network’s parameters fixed, allowing for a high degree of parameter sharing and reducing the computational overhead for fine-tuning.
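The adapter idea can be made concrete with a minimal pure-Python sketch of one bottleneck adapter in the style of Houlsby et al.: a down-projection from d to m dimensions, a nonlinearity, an up-projection back to d, and a skip connection. ReLU is used as the nonlinearity here for simplicity, and the function names are illustrative. The count 2md + d + m covers the two projection matrices and their biases of a single adapter; only these weights would be trained while the BERT weights stay frozen.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def matvec(W: Matrix, v: Sequence[float]) -> List[float]:
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def adapter_forward(h: Sequence[float], W_down: Matrix, b_down: Sequence[float],
                    W_up: Matrix, b_up: Sequence[float]) -> List[float]:
    """h + W_up relu(W_down h + b_down) + b_up, with W_down of shape (m, d) and W_up of shape (d, m)."""
    z = [max(0.0, a + b) for a, b in zip(matvec(W_down, h), b_down)]   # bottleneck, m << d
    up = [a + b for a, b in zip(matvec(W_up, z), b_up)]                # back to d dimensions
    return [x + u for x, u in zip(h, up)]                              # skip connection


def adapter_param_count(d: int, m: int) -> int:
    """Trainable parameters of one adapter, including biases: 2*m*d + d + m."""
    return 2 * m * d + d + m


# For a hidden size of 768 and a bottleneck of 64, one adapter has 99,136 parameters.
n_params = adapter_param_count(768, 64)
```

If the up-projection starts at zero, the adapter is exactly the identity, so inserting it does not disturb the pre-trained network before fine-tuning begins.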
3 Proposed Approach

In Sect. 3.1, we describe in detail how we generate the fine-tuning corpus. The corpus consists of real data: some sentences are kept and labeled as correct, and the remainder are generated from the original sentences by ChatGPT and labeled as incorrect. In Sect. 3.2, we describe the fine-tuning procedure, in which a pre-trained language model is fine-tuned on the CoLA dataset from Sect. 3.1. Section 3.3 covers how we use the pre-trained model as a feature extractor for the nodes and edges of the knowledge graph; these features are reduced in dimensionality by a fully connected layer and then fed into a knowledge graph completion model to obtain the final result. For the knowledge graph completion task itself, we adopt well-established models implemented by other researchers. The architecture of our proposed method is illustrated in Fig. 2.
Fig. 2. Our method architecture. We fine-tune BERT using a sentence classification task, distinguishing grammatically correct and incorrect sentences. Our CoLA dataset comprises 80% correct sentences from the original corpus and 20% incorrect sentences generated by ChatGPT. The finetuned BERT serves as a feature extractor for nodes and edges, and the last layer’s pooled output produces a feature vector. Dimensionality reduction is achieved via a fully connected layer before sending the initial embeddings to the KGC model for completion predictions based on candidate entity probabilities.
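The corpus construction summarized in Fig. 2 (80% original sentences labeled correct, 20% rewritten into ungrammatical sentences) can be sketched as follows. The `generate` argument is a hypothetical stand-in for the ChatGPT API call, which we do not reproduce; the label convention (1 = acceptable, 0 = unacceptable) follows the usual CoLA format, and the random selection of which sentences to rewrite is our assumption.

```python
import random
from typing import Callable, List, Sequence, Tuple


def build_prompt(task: str, description: str, requirements: str, sentence: str) -> str:
    """Assemble the four prompt sections: task, description with examples, requirements, corpus."""
    return "\n\n".join([task, description, requirements,
                        f'The original sentence is: "{sentence}"'])


def build_cola_corpus(sentences: Sequence[str], generate: Callable[[str], str],
                      task: str, description: str, requirements: str,
                      incorrect_fraction: float = 0.2, seed: int = 0) -> List[Tuple[str, int]]:
    """Keep ~80% of sentences as acceptable (1); rewrite ~20% into unacceptable ones (0)."""
    rng = random.Random(seed)
    indices = list(range(len(sentences)))
    rng.shuffle(indices)
    n_incorrect = int(round(len(sentences) * incorrect_fraction))
    to_corrupt = set(indices[:n_incorrect])
    corpus = []
    for idx, sentence in enumerate(sentences):
        if idx in to_corrupt:
            prompt = build_prompt(task, description, requirements, sentence)
            corpus.append((generate(prompt), 0))    # generated, grammatically unacceptable
        else:
            corpus.append((sentence, 1))            # original sentence, acceptable
    return corpus
```

Passing the generator in as a function keeps the corpus logic testable offline; a real deployment would wrap the chat-completion request, with rate limiting and retries, behind the same one-argument interface.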
3.1 Corpus Generation Based on ChatGPT

CoLA, short for “Corpus of Linguistic Acceptability”, is a language understanding task that involves binary labeling of sentences to indicate whether they conform to the rules of grammar. We transformed our domain-specific long-text data into the CoLA format and then fine-tuned BERT on it, which we expect to enhance BERT's understanding of domain-specific language. Specifically, 80% of the sentences were extracted directly from the original corpus and labeled as correct. The remaining 20% were sent to the ChatGPT API to generate syntactically incorrect sentences. The key technique here is prompt design, which significantly affects the quality of the data returned by ChatGPT. Our prompt has four sections: the first defines the task, the second describes the task in detail with examples, the third contains additional requirements, and the fourth contains the corpus (i.e., the original content to be transformed). An example task prompt is:
Your task is to transform a Chinese sentence into a new sentence with extremely incorrect grammar but related to the original sentence. For example,
1. Part of speech error: such as using adjectives as verbs or using nouns as adjectives.
2. Word order error: such as incorrect word order of phrases or sentences.
3. Sentence structure error: such as subject-verb disagreement, missing subject or object, or incomplete grammar.
4. Language expression error: such as improper word choice, inaccurate word usage, or unsmooth sentences.
Please note that you can be creative and flexible in generating different types of grammatical errors.
The original sentence is: “Highly sensitive and reliable sensors are the foundation for the digital development of electric power and equipment status perception, which can monitor equipment status by capturing signals such as sound, light, and electricity generated during equipment operation.” 3.2 Fine-Tuning Adapter BERT is a variant of BERT that can be trained using adapters, which allows the model to be trained on new tasks while adjusting only a small number of parameters for each task. In our paper, we chose to use the implementation of Adapter BERT as our base model for fine-tuning due to our focus on resource efficiency. As mentioned in the original paper, fine-tuning with a minimal setup only requires 3.6% of the parameters needed for fine-tuning the entire model, while still maintaining high performance. We selected our CoLA dataset for our fine-tuning task, and once the accuracy threshold was met, we froze all the layer parameters, treating the model as a feature extractor. We then pooled the output of the final hidden layer to obtain the final vector representation. 3.3 Knowledge Graph Completion In Knowledge Graph Completion, a knowledge graph comprises a set of triples (h, r, t), where h and t are entities, and r is the relation between them. The objective of KGC is to predict the missing tail entity t given a head entity h and a relation r. This can be
J. Huang et al.
achieved by learning a function f(h, r) = t from the available triples in the knowledge graph; once trained, f can be used to predict missing tail entities in the graph. The KGC models used in this paper are based on established implementations from previous studies. We modify only their input layer: the feature vectors obtained with the method proposed in this paper serve as the initial embeddings of the knowledge graph's nodes and edges. Because BERT's output dimension is high and lower-dimensional vectors save computation, we introduced a fully connected layer between the output layer of the pre-trained language model and the KGC model. Its input dimension is larger than its output dimension, so it reduces dimensionality; its parameters are trainable and are updated during the KGC training iterations. To select answers, the KGC model computes probabilities for the candidate entities, ranks them, and returns the top k entities as the final answers.
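The feature-extraction path of Sects. 3.2 and 3.3 (adapter-tuned encoder, pooling of the final hidden layer, and a trainable dimension-reducing layer) can be sketched in plain Python as below. This is a minimal illustration under stated assumptions, not the authors' implementation: the bottleneck adapter follows the general design of Houlsby et al. [3] (down-projection, nonlinearity, up-projection, residual connection), GELU is an assumed choice of nonlinearity, mean pooling over non-padding tokens is an assumed pooling method (the text only says that the final hidden layer was pooled), and all function names are hypothetical. Tiny dimensions stand in for BERT's 768.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def matvec(w: Matrix, x: Sequence[float]) -> List[float]:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def gelu(v: Sequence[float]) -> List[float]:
    """Element-wise GELU (tanh approximation); the nonlinearity is an assumption."""
    c = math.sqrt(2.0 / math.pi)
    return [0.5 * a * (1.0 + math.tanh(c * (a + 0.044715 * a ** 3))) for a in v]


def adapter(h: Sequence[float], w_down: Matrix, w_up: Matrix) -> List[float]:
    """Bottleneck adapter: h + W_up f(W_down h).

    Only w_down (m x d) and w_up (d x m) would be trained; the surrounding
    BERT weights stay frozen. With w_up = 0 the adapter is the identity.
    """
    z = gelu(matvec(w_down, h))            # down-project d -> m (m << d)
    up = matvec(w_up, z)                   # up-project m -> d
    return [a + b for a, b in zip(h, up)]  # residual (skip) connection


def mean_pool(hidden: Sequence[Sequence[float]], mask: Sequence[int]) -> List[float]:
    """Average final-layer token vectors, skipping padding positions (mask == 0)."""
    kept = [vec for vec, m in zip(hidden, mask) if m]
    return [sum(col) / len(kept) for col in zip(*kept)]


def project(v: Sequence[float], w: Matrix, b: Sequence[float]) -> List[float]:
    """Trainable fully connected layer that reduces dimensionality (e.g. 768 -> 300)."""
    return [y + bias for y, bias in zip(matvec(w, v), b)]
```

In the configuration of Sect. 4.2, `h` would be a 768-dimensional hidden state and `project` would map the pooled vector to a 300-dimensional initial embedding. Initializing `w_up` to zeros makes each adapter start as the identity, so the frozen BERT's behavior is unchanged at the start of fine-tuning.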
4 Experiments
4.1 Dataset Description
The defect dataset used in this paper was derived from real data provided by State Grid Fujian Electric Power Research Institute and consists mainly of information on defects in main electrical equipment. The data were collected manually, organized into a structured format, and then transformed into a set of triples. In total, the dataset contains 58,820 triples, comprising 53,889 entities and 12 distinct relations. We created inverse edges by adding a new relation r⁻¹ for each relation r in the dataset: for each triple (h, r, t), we also add the triple (t, r⁻¹, h). In other words, whenever there is a relation r from entity h to entity t, there is also a relation r⁻¹ from t to h. To evaluate the model, we divided the dataset into training, validation, and test sets in an 8:1:1 ratio; the bidirectional edges were included only in the training set.
The dataset has the following entity types: defect phenomenon, power station/line, power/line type, voltage level, defect attribute, defect location, defect nature, defect description, equipment type, defect equipment, defective part type, and defective part. Relation types are named in the form “entity–entity” and are formally defined as a set of triples (e_i, r_k, e_j), where e_i and e_j are entities of any type and r_k is the relation between them.
Our fine-tuning dataset consists of 31,079 sentences, split into training, validation, and test sets in an 8:1:1 ratio. For fine-tuning, we used a 4.96 MB corpus focused on the electric power domain. It includes a diverse range of materials, such as substation operation and maintenance question banks, defect classification standards for primary substation equipment, and distribution network fault analysis cases.
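The inverse-edge augmentation and 8:1:1 split described above can be sketched as follows. It is a minimal illustration, assuming that the split is made on the original triples and that inverse edges are then added to the training split only, as the text states; the `^-1` naming convention, the fixed random seed, and the function names are hypothetical. PyKEEN, which the authors used, also offers a `create_inverse_triples` option on its triples factories, but the text does not say whether it was used.

```python
import random
from typing import List, Tuple

Triple = Tuple[str, str, str]


def split_811(triples: List[Triple], seed: int = 0):
    """Shuffle triples and split them 8:1:1 into train/valid/test."""
    items = list(triples)
    random.Random(seed).shuffle(items)
    n_train = int(len(items) * 0.8)
    n_valid = int(len(items) * 0.1)
    return (items[:n_train],
            items[n_train:n_train + n_valid],
            items[n_train + n_valid:])


def add_inverse(triples: List[Triple]) -> List[Triple]:
    """For every (h, r, t) also emit (t, r^-1, h).

    The '^-1' suffix used to name inverse relations is a hypothetical convention.
    """
    return list(triples) + [(t, r + "^-1", h) for h, r, t in triples]


def prepare(triples: List[Triple], seed: int = 0):
    """Split first, then add inverse edges to the training split only."""
    train, valid, test = split_811(triples, seed)
    return add_inverse(train), valid, test
```

Splitting before augmenting keeps the validation and test sets free of inverse edges, so a test triple can never be answered trivially by looking up its own inverse in the training data.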
Improve Knowledge Graph Completion for Diagnosing Defects
4.2 Experimental Configuration
We used several components of the PyKEEN library [16], such as its data splitting module. Because our dataset is in Chinese, we used the BERT-Chinese-Base pre-trained model. For fine-tuning, we set a learning rate of 3e−4 and a train_batch_size of 32, and a max_seq_length of 128 to handle variation in sentence length: longer sentences are truncated and shorter ones are padded. Training was scheduled for 5 epochs. Fine-tuning the pre-trained language model on a CPU (an Intel Xeon 6338 with 32 cores) took approximately 30 min.
For the KGC models, training ran for at most 300 epochs with an evaluation every 10 epochs; if the model did not improve over 5 consecutive evaluations, training was stopped early. A fully connected layer with a 768-dimensional input and a 300-dimensional output links BERT's output to the KGC model's input. We selected DistMult [17], KG2E [18], and NodePiece [19] as KGC models and implemented them with the PyKEEN library. The KGC models were trained on four Tesla A100 GPUs, with each epoch taking approximately 1.5 s.
4.3 Results and Analysis
Our fine-tuning task achieved 97.7% accuracy on our CoLA dataset, compared with the 56.9 reported on the GLUE [20] CoLA dataset in the original Adapter BERT paper. Note that GLUE scores CoLA with the Matthews correlation coefficient rather than accuracy, so the two figures are not directly comparable. The high accuracy nevertheless indicates successful fine-tuning, and we see two reasons for it on our dataset. First, the incorrect sentences generated by ChatGPT are relatively simple and easy to identify. Second, because the dataset is in Chinese, it is less prone to subtle errors such as spelling mistakes.
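The construction of the CoLA-style fine-tuning data in Sect. 3.1 (a four-part prompt, 80% original sentences labeled acceptable, 20% rewritten into ungrammatical ones) can be sketched as below. This is an illustrative sketch, not the authors' code: the prompt sections are abbreviated, the `corrupt` callable is a hypothetical stand-in for the ChatGPT API call (no real API is invoked), and labels follow the CoLA convention of 1 for acceptable and 0 for unacceptable.

```python
import random
from typing import Callable, List, Sequence, Tuple

# Abbreviated versions of the four prompt sections described in Sect. 3.1.
TASK_DEFINITION = ("Your task is to transform a Chinese sentence into a new sentence with "
                   "extremely incorrect grammar but related to the original sentence.")
TASK_DESCRIPTION = ("For example, 1. Part of speech error ... 2. Word order error ... "
                    "3. Sentence structure error ... 4. Language expression error ...")
REQUIREMENTS = ("Please note that you can be creative and flexible in generating "
                "different types of grammatical errors.")


def build_prompt(sentence: str) -> str:
    """Join task definition, task description, requirements, and the corpus sentence."""
    corpus = f'The original sentence is: "{sentence}"'
    return "\n".join([TASK_DEFINITION, TASK_DESCRIPTION, REQUIREMENTS, corpus])


def build_cola(sentences: Sequence[str],
               corrupt: Callable[[str], str],
               wrong_ratio: float = 0.2,
               seed: int = 0) -> List[Tuple[int, str]]:
    """Return (label, sentence) pairs: 1 = acceptable original, 0 = corrupted.

    `corrupt` receives the full prompt and returns the rewritten sentence;
    it is a hypothetical stand-in for a call to an LLM API such as ChatGPT.
    """
    items = list(sentences)
    random.Random(seed).shuffle(items)
    n_wrong = round(len(items) * wrong_ratio)
    rows = [(0, corrupt(build_prompt(s))) for s in items[:n_wrong]]
    rows += [(1, s) for s in items[n_wrong:]]
    return rows
```

Keeping the prompt sections as separate constants makes it easy to vary one section (for example, the list of error types) while holding the others fixed, which is useful when testing how prompt design affects the quality of the generated negatives.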
We compared two methods for the KGC task: (1) using the KGC model directly, and (2) the approach proposed in this paper, in which ChatGPT turns the long-text corpus of the electric power system into a CoLA dataset, the pre-trained model is fine-tuned on that dataset, and the fine-tuned model then serves as a feature extractor that provides the initial embeddings for the KGC task. We evaluated the models with hit@k for k = 1 and k = 5 and, to assess the stability of the results, with the AMR (Arithmetic Mean Rank) metric [21], i.e., the average rank of the true triples in the knowledge graph across all test queries. To mitigate the impact of chance, we repeated each experiment 10 times and report the best result of each model (maximum hit@k or minimum AMR). The results are summarized in Table 1.
Table 1. The Results of KGC.
Model             Hit@1   Hit@5    AMR
DistMult          0.605   0.900    732
DistMult-Ours     0.613   0.896    708
KG2E              0.489   0.885    281
KG2E-Ours         0.540   0.829    254
NodePiece         0.230   0.498   3277
NodePiece-Ours    0.447   0.467   1288
Our method improves hit@1 and AMR for all three KGC models, but hit@5 does not improve (it decreases for all three), which we attribute to limitations of the KGC models themselves. Fine-tuning BERT provides a more suitable initial embedding, which improves the accuracy of the model during queries: the correct answer is ranked first more often, leading to a better hit@1 score. However, achieving higher hit@5 values may be constrained by the original performance of the KGC model.
We also conducted an experiment with BERT without fine-tuning, in which features were extracted directly from the nodes and edges of the knowledge graph with the original BERT and used as the initial embeddings of the KGC model. The aim was to analyze the effect of fine-tuning with domain-specific data. As shown in Fig. 3, the model performs significantly better after fine-tuning.
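The evaluation metrics used above can be computed from the rank of the true entity in each test query, as in the following sketch. It assumes raw (unfiltered) ranking over all candidates and counts ties against the true entity; both are assumptions, and PyKEEN's rank-based evaluator, which computes these metrics itself, also supports filtered ranking and other tie-handling rules. `top_k` mirrors the answer-selection step of Sect. 3.3. Function names are hypothetical.

```python
from typing import List, Sequence


def rank_of_true(scores: Sequence[float], true_idx: int) -> int:
    """1-based rank of the true entity among all candidates (higher score = better).

    Ties are counted against the true entity (pessimistic rank).
    """
    s = scores[true_idx]
    return 1 + sum(1 for i, x in enumerate(scores) if i != true_idx and x >= s)


def top_k(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k highest-scoring candidates, i.e. the returned answers."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def hits_at_k(ranks: Sequence[int], k: int) -> float:
    """Fraction of test queries whose true entity is ranked within the top k."""
    return sum(1 for r in ranks if r <= k) / len(ranks)


def arithmetic_mean_rank(ranks: Sequence[int]) -> float:
    """AMR: average rank of the true entity over all test queries (lower is better)."""
    return sum(ranks) / len(ranks)
```

Because AMR averages raw ranks, a few queries with very poor ranks can dominate it, which is why it can improve sharply (as for NodePiece in Table 1) even when hit@5 does not.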
Fig. 3. Mean and maximum Hit@1 obtained when the original BERT is used directly as a feature extractor, compared with our fine-tuned approach.
Analyzing the possible reasons for this improvement, we found that BERT trained on a public dataset has a limited understanding of technical terms. After fine-tuning, the vector representations encoded by BERT become more accurate, so nodes with different meanings are more easily distinguished in the vector space, which in turn improves the effectiveness of the KGC model. Our method fine-tunes BERT on a professional-domain corpus to improve feature extraction. The experimental results show that it improves hit@1 and AMR over using the KGC models directly, highlighting the importance of incorporating domain-specific knowledge to enhance model performance. Fine-tuned on a corpus from the professional field of interest, the model gains a deeper understanding of the semantic information in sentences, resulting in more accurate and useful features. Overall, our approach represents a promising direction for future research on knowledge graph completion.
5 Conclusion
In this paper, we present a method for improving knowledge graph completion for diagnosing defects in main electrical equipment. We use ChatGPT to generate a CoLA dataset from the original long-text data, fine-tune BERT on this dataset, and use the fine-tuned BERT to extract node and edge features from the knowledge graph that initialize the KGC model's embeddings. This approach enhances KGC performance in the domain of main electrical equipment and facilitates defect diagnosis. In experiments on a defect dataset with DistMult, KG2E, and NodePiece as KGC models, our method yielded significant improvements on the KGC task compared with both traditional KGC approaches and using the original BERT directly as a feature extractor.
Acknowledgements. This work is supported by the Research Funds from State Grid Fujian (SGFJDK00SZJS2200162).
References
1. Bordes, A., Usunier, N., Garcia-Duran, A., et al.: Translating embeddings for modeling multi-relational data. In: Advances in Neural Information Processing Systems 26 (2013)
2. Gururangan, S., Marasović, A., Swayamdipta, S., et al.: Don't stop pretraining: adapt language models to domains and tasks. arXiv preprint arXiv:2004.10964 (2020)
3. Houlsby, N., Giurgiu, A., Jastrzebski, S., et al.: Parameter-efficient transfer learning for NLP. In: International Conference on Machine Learning. PMLR, pp. 2790–2799 (2019)
4. Devlin, J., Chang, M.W., Lee, K., et al.: BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805 (2018)
5. Yang, S., Wang, J., Meng, F., et al.: Text mining techniques for knowledge of defects in power equipment. In: 2021 10th IEEE International Conference on Communication Systems and Network Technologies (CSNT). IEEE, pp. 205–210 (2021)
6. Shcherbatov, I., Lisin, E., Rogalev, A., et al.: Power equipment defects prediction based on the joint solution of classification and regression problems using machine learning methods. Electronics 10(24), 3145 (2021)
7. Abid, A., Khan, M.T., Iqbal, J.: A review on fault detection and diagnosis techniques: basics and beyond. Artif. Intell. Rev. 54, 3639–3664 (2021)
8. Meng, F., Yang, S., Wang, J., et al.: Creating knowledge graph of electric power equipment faults based on BERT–BiLSTM–CRF model. J. Electr. Eng. Technol. 17, 2507–2516 (2022)
9. Graves, A.: Long short-term memory. In: Supervised Sequence Labelling with Recurrent Neural Networks. Studies in Computational Intelligence, vol. 385, pp. 37–45. Springer, Heidelberg (2012). https://doi.org/10.1007/978-3-642-24797-2_4
10. Lafferty, J., McCallum, A., Pereira, F.C.N.: Conditional random fields: probabilistic models for segmenting and labeling sequence data (2001)
11. Chen, Q., Li, Q., Wu, J., et al.: Application of knowledge graph in power system fault diagnosis and disposal: a critical review and perspectives. Front. Energy Res. 10, 1307 (2022)
12. Zhang, T., Ding, J., Guo, Z.: Multimodal knowledge graph for power equipment defect data. In: Proceedings of the 7th International Conference on Cyber Security and Information Engineering, pp. 666–668 (2022)
13. Qiu, X., Sun, T., Xu, Y., et al.: Pre-trained models for natural language processing: a survey. Sci. China Technol. Sci. 63(10), 1872–1897 (2020)
14. Hai, H.N.: ChatGPT: The Evolution of Natural Language Processing. Authorea Preprints (2023)
15. Ouyang, L., Wu, J., Jiang, X., et al.: Training language models to follow instructions with human feedback. In: Advances in Neural Information Processing Systems 35, pp. 27730–27744 (2022)
16. Ali, M., Berrendorf, M., Hoyt, C.T., et al.: PyKEEN 1.0: a python library for training and evaluating knowledge graph embeddings. J. Mach. Learn. Res. 22(1), 3723–3728 (2021)
17. Yang, B., Yih, W., He, X., et al.: Embedding entities and relations for learning and inference in knowledge bases. arXiv preprint arXiv:1412.6575 (2014)
18. He, S., Liu, K., Ji, G., et al.: Learning to represent knowledge graphs with Gaussian embedding. In: Proceedings of the 24th ACM International on Conference on Information and Knowledge Management, pp. 623–632 (2015)
19. Galkin, M., Denis, E., Wu, J., et al.: NodePiece: compositional and parameter-efficient representations of large knowledge graphs. arXiv preprint arXiv:2106.12144 (2021)
20. Wang, A., Singh, A., Michael, J., et al.: GLUE: a multi-task benchmark and analysis platform for natural language understanding. arXiv preprint arXiv:1804.07461 (2018)
21. Opitz, J., Parcalabescu, L., Frank, A.: AMR similarity metrics from principles. Trans. Assoc. Comput. Linguist. 8, 522–538 (2020)
Knowledge Graph-Based Approach for Main Transformer Defect Grade Analysis
Shitao Cai¹, Zhou Zheng², Chenxiang Lin², Longqiang Yi³, Jinhu Li⁴, Jiangsheng Huang⁴, and Zhihong Zhang¹(B)
¹ School of Informatics, Xiamen University, Xiamen 361005, China
[email protected]
² State Grid Fujian Electric Power Research Institute, Fuzhou, China
³ Kehua Data Co., Ltd., Xiamen, China
⁴ State Grid Info-Telecom Great Power Science and Technology Co., Ltd., Xiamen, China
Abstract. The effective maintenance of power grid equipment is critical for ensuring the safe and stable operation of the power grid. In recent years, knowledge graphs have emerged as a powerful tool for representing complex relationships and knowledge in a structured and accessible format. In this paper, we propose a knowledge graph-based approach for analyzing and diagnosing defects in power grid transformers. We first designed an ontology for defect data in the field of main transformers in power grids, covering equipment information, defect descriptions, and industry-standard classification criteria. We then performed named entity recognition (NER) on textual data about main transformers with the BERT-BiLSTM-CRF [1–3] model to extract entities. The extracted entities were organized according to the ontology into a knowledge graph, which was then embedded with models such as TransE [4]. We conducted knowledge graph completion experiments to diagnose and analyze defect levels. Our experimental results demonstrate that the method constructs a knowledge graph of main transformers in power grids efficiently and automatically, and that the well-designed ontology and knowledge graph completion support the analysis of defect levels in main transformers. Additionally, the method can promote the understanding and management of complex systems in the field of power grid equipment.
Keywords: Knowledge Graph · Named entity recognition · KG completion · Main transformer
1 Introduction
The importance of maintaining power grid equipment cannot be overstated, as the safe and stable operation of power grids relies heavily on the reliability of transformers. Power grid equipment, including transformers, is critical to power generation, transmission, and distribution, and any failure or defect can cause severe consequences such as power outages, economic losses, and even safety hazards. However, due to the complexity of power grid equipment and the large volume of data generated during its operation, equipment maintenance and defect analysis can be challenging. Traditional approaches to equipment maintenance and defect analysis have relied on manual inspection and analysis of data, which can be time-consuming and error-prone.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 749–759, 2023.
https://doi.org/10.1007/978-981-99-4761-4_63
In recent years, knowledge graphs have emerged as a promising tool for representing complex relationships and knowledge in a structured and accessible format. Knowledge graphs consist of nodes representing entities and edges representing relationships between the entities. They provide a powerful means to integrate data from different sources and enable automatic inference and reasoning over complex relationships within the data. Several studies have demonstrated the effectiveness of knowledge graph-based approaches in other application domains, such as medicine, network security, news, and social media. For example, Ernst et al. [5] proposed an automated method for constructing biomedical knowledge graphs; the method integrated data from posts on various health portals but could not integrate standard health data. Jia et al. [6] built a knowledge graph for the network security domain and proposed a path-based approach in the graph to extend it with new knowledge. Similar work has been done in journalism: another study [7] constructs a knowledge graph in the news domain and applies it to fake news detection through link prediction. In social networks, knowledge graphs can also describe the relationships between users and protect user privacy through weighted matching [8].
Fig. 1. In previous applications, defect grading relied on manual judgment. Knowledge graphs can help automate defect analysis.
However, in the field of power grids, defect grading still relies on manual judgment: during inspections, workers rely on their experience and technical manuals to determine the severity of equipment defects. As shown in Fig. 1, a knowledge graph-based approach can extract defect information into a knowledge graph and predict the grade of a defect from the information it contains. In this paper, we extend the application of knowledge graphs to the specific context of power grid transformers and propose a knowledge graph-based approach for analyzing and diagnosing defects in power grid transformers. We designed an ontology for defect data in the field of main transformers in power grids that represents equipment information, defect descriptions, and industry-standard classification criteria. We used named entity recognition to extract entities from textual data and represented the extracted entity information with the ontology. We then embedded the resulting knowledge graph with models such as TransE and conducted knowledge graph completion experiments to diagnose and analyze the defect level. The proposed approach provides an efficient and automatic way to construct a knowledge graph of main transformers in power grids, which can support the understanding and management of complex systems in the field of power grid equipment. Our work builds on existing research on knowledge graph-based approaches for equipment maintenance and defect analysis.
2 Related Work
2.1 Knowledge Graph Construction
Constructing a knowledge graph involves extracting structured knowledge from unstructured text data. One of the most widely used techniques for this purpose is named entity recognition (NER) combined with relation extraction. For example, Bosselut et al. [9] proposed an automatic method for constructing commonsense knowledge graphs that effectively extracts commonsense knowledge. Han et al. [10] proposed a joint learning framework that fuses knowledge graphs and text, and that also effectively extracts relations and entities. Recent work has also used sequential model architectures such as LSTM-CNN [11], which learns character- and word-level features. MGNER [12] detects entities at different granularities and classifies them. Pre-trained language models, such as K-BERT [13], also play an important role in named entity recognition. All of these methods achieve good performance in entity recognition.
2.2 Knowledge Graph Completion
Knowledge graph completion (KGC) methods aim to fill in missing information in a knowledge graph (KG). Among the various KGC methods, embedding-based approaches have shown great promise in recent years. TransE, proposed by Bordes et al. [4], is one of the earliest and most influential embedding-based models for KGC. TransE represents entities and relations as low-dimensional vectors and measures the plausibility of a triple (h, r, t) by comparing the vector sum of entity h and relation r with entity t. Since TransE, a number of variants, such as TransH [14] and TransR [15], have been proposed to handle more complex relation semantics. Other models, such as the complex-valued bilinear model ComplEx [16] and the convolutional model ConvE [17], capture richer interactions between entities and relations.
In addition, embedding-based approaches typically predict a missing entity by replacing the head or tail entity of a triple and calculating the scores of all candidate entities. Models such as SimplE [18], HolE [19], and R-GCN [20] can also be used for knowledge graph completion.
2.3 Knowledge Graph Applications in the Electric Power Industry
In the electric power industry, knowledge graphs have been applied to tasks such as fault diagnosis, energy management, and smart grid operation. For example, Tang et al. [21] proposed a knowledge graph-based power grid equipment management system that builds a knowledge graph from multiple data sources, demonstrating the efficiency gains that knowledge graphs bring to power equipment management. Huang et al. [22] proposed a semi-automated technique for constructing knowledge graphs for the daily operation of grid devices. Overall, related work on knowledge graph construction, completion, and application in the electric power industry provides valuable insights and techniques for our research on applying knowledge graphs to electric power transformer defect classification.
3 Proposed Method
In this section, we describe our knowledge graph-based approach for analyzing and diagnosing defects in power grid transformers. The approach consists of three main steps: ontology design, knowledge graph construction, and knowledge graph completion. Sect. 3.1 details the ontology design, which also provides the basis for constructing the knowledge graph. Sect. 3.2 describes how entities are extracted with named entity recognition. Finally, Sect. 3.3 introduces the application of entity completion methods to the analysis of power grid equipment defects.
3.1 Ontology Design
The ontology for defect data in the field of main transformers in power grids is an essential component of our method. It contains 14 entity types and 14 relations, as shown in Fig. 2. The entity types are Station, Voltage level, Station-type, Defect-equipment, Equipment-type, Equipment-category, Component, Component-category, Position, Defect-phenomenon, Defect-attribute, Defect-description, Classification-basis, and Defect-level. Station is the primary key and identifies the transformer where the defect occurred. Position indicates the specific location of the defect. The defective equipment is usually the main transformer, which comes in different types; it consists of components, which also have different types. Defect-level is the most important entity type and indicates the severity of the defect. The ontology thus covers the type of equipment, the location of the defect, the voltage level, and other relevant details, and it also encodes expert knowledge on standard defect classification levels. Defect phenomena and attributes are subjective descriptions provided by users in daily application scenarios, whereas defect descriptions and classification criteria are based on standard expert knowledge.
By combining subjective descriptions with expert knowledge, we aim to automatically classify the defect levels.
Fig. 2. The Ontology of Main Transformer Knowledge graph.
Overall, our ontology provides a structured and standardized representation of defect data in the field of main transformers in power grids, which is essential for the subsequent steps of our method.
3.2 Construction of the Main Transformer Knowledge Graph
To extract entities from textual data in the field of main transformers and represent them with the ontology, we used named entity recognition (NER). Specifically, we used the BERT-BiLSTM-CRF [1–3] model, which combines three components: a pre-trained BERT [1] model, a bidirectional long short-term memory (BiLSTM) [2] network, and a conditional random field (CRF) [3] layer. Together they capture both local and global context and predict an entity label for each token in the text. BERT (Bidirectional Encoder Representations from Transformers) is first pre-trained on large amounts of unannotated text to learn general language features, which are then fine-tuned on a smaller annotated dataset for the NER task. The BiLSTM, which reads the sequence in both directions to use left and right context, takes BERT's output and further processes it to capture sequence information. Finally, the CRF layer is applied to the BiLSTM output to predict the most likely sequence of entity labels for the input text. By combining these components, BERT-BiLSTM-CRF achieves strong performance on NER tasks, including the extraction of entities in the field of main transformers.
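The final CRF step, predicting the most likely label sequence, is usually performed with Viterbi decoding over per-token emission scores (here, what the BiLSTM would output) and learned tag-transition scores. The sketch below shows only that decoding step, under the assumption of a standard linear-chain CRF; training the transition scores is not shown, and the function name and tag indices are hypothetical.

```python
from typing import List, Sequence


def viterbi(emissions: Sequence[Sequence[float]],
            transitions: Sequence[Sequence[float]]) -> List[int]:
    """Return the highest-scoring tag sequence for a linear-chain CRF.

    emissions[t][j]: score of tag j at position t (e.g. BiLSTM outputs).
    transitions[i][j]: score of moving from tag i to tag j.
    """
    n_tags = len(emissions[0])
    score = list(emissions[0])     # best path score ending in each tag so far
    backptrs = []                  # backptrs[t][j]: best previous tag for tag j
    for emit in emissions[1:]:
        new_score, ptr = [], []
        for j in range(n_tags):
            best_i = max(range(n_tags), key=lambda i: score[i] + transitions[i][j])
            new_score.append(score[best_i] + transitions[best_i][j] + emit[j])
            ptr.append(best_i)
        score = new_score
        backptrs.append(ptr)
    # Trace back from the best final tag.
    best = max(range(n_tags), key=lambda j: score[j])
    path = [best]
    for ptr in reversed(backptrs):
        best = ptr[best]
        path.append(best)
    return path[::-1]
```

Unlike choosing the best tag at each position independently, the transition scores let the CRF rule out invalid label sequences; for example, a strongly negative score for the transition from O to an I- tag prevents an entity from starting with an I- label.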
Fig. 3. Named entity recognition with the BERT-BiLSTM-CRF model. Trained on labeled data in BIO format, the model identifies the different classes of entities.
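The BIO format used as model input can be illustrated with a small function that converts annotated entity spans into per-character tags. Character-level tagging (common for Chinese BERT models) is an assumption, as are the function name and the example sentence and entity labels, which are drawn from the ontology's entity types for illustration only.

```python
from typing import List, Tuple


def to_bio(text: str, spans: List[Tuple[int, int, str]]) -> List[Tuple[str, str]]:
    """Tag each character of `text` with B-<type>, I-<type>, or O.

    spans: (start, end, entity_type) with `end` exclusive, assumed non-overlapping.
    """
    tags = ["O"] * len(text)
    for start, end, etype in spans:
        tags[start] = f"B-{etype}"          # first character of the entity
        for i in range(start + 1, end):
            tags[i] = f"I-{etype}"          # remaining characters of the entity
    return list(zip(text, tags))
```

For the hypothetical sentence “主变油位异常” (abnormal oil level in the main transformer) with spans (0, 2, "Defect-equipment") and (2, 6, "Defect-phenomenon"), the first two characters receive B-Defect-equipment and I-Defect-equipment, and the last four receive B-Defect-phenomenon followed by three I-Defect-phenomenon tags.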
We labeled 37,019 text records, annotating all entity types in the text, and divided them into training, validation, and test sets using ten-fold cross-validation. The annotated data were then converted into BIO format as the model input. As shown in Fig. 3, the BERT-BiLSTM-CRF model is trained on the BIO data to perform entity recognition; in practical applications, raw text sequences are fed to the trained model for automatic entity extraction. After extracting entities with NER, we represented the extracted entity information in the ontology: the extracted entities were mapped to ontology classes, and ontology properties were used to specify the relations between entities.
3.3 Graph Completion for Defect Classification
To perform knowledge graph completion and diagnose transformer defects, we employed the TransE [4], TransR [15], and SimplE [18] models. TransE is a popular embedding model that maps entities and relations into a common low-dimensional vector space, where each relation is represented as a translation from head to tail entities. The model learns a vector for each entity and relation that preserves the semantics of the knowledge graph: for each triple (h, r, t), where h and t are the head and tail entities and r is the relation, it learns embeddings such that h + r ≈ t. TransR extends TransE to better handle one-to-many, many-to-one, and many-to-many relations by introducing relation-specific projection matrices. Instead of placing entities and relations in the same space, TransR uses a relation-specific projection matrix to project the head and tail entities into a relation-specific subspace, where the relation vector resides, and then computes the translation-based score in that subspace. SimplE, in contrast, is not a translational model but a bilinear model based on canonical polyadic (CP) decomposition, designed to be fully expressive while remaining as efficient as simple bilinear models. It represents each relation with two diagonal matrices (equivalently, two vectors, one for the relation and one for its inverse), gives each entity separate head and tail embeddings, and scores a triple by averaging two trilinear products of these embeddings, one for the triple and one for its inverse.
To apply the embedding models, we first trained them on the existing edges of the knowledge graph and then used the learned embeddings to predict the likelihood of missing edges. Specifically, for each missing edge we computed the scores of all possible tail entities, selected the top-k entities as the predicted tail entities, and compared them with the ground truth to evaluate the models' performance. An example of knowledge graph completion in a defect classification task is shown in Fig. 4: missing edges prevent the defect level from being determined. Our method alleviates this problem by enabling automatic and efficient defect diagnosis, which provides valuable insights for maintenance and troubleshooting and aids the safe and stable operation of the power grid.
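The scoring functions of the three models and the top-k tail prediction described above can be sketched as follows, using their standard formulations: TransE scores a triple by −‖h + r − t‖ (the L2 norm is used here; L1 is also common), TransR first projects h and t with a relation-specific matrix M_r, and SimplE averages two trilinear products built from each entity's head-role and tail-role embeddings and the relation and inverse-relation vectors. Training is not shown, and the function names are hypothetical.

```python
import math
from typing import Callable, List, Sequence

Vec = Sequence[float]


def transe_score(h: Vec, r: Vec, t: Vec) -> float:
    """TransE: -||h + r - t||_2 (higher means more plausible)."""
    return -math.sqrt(sum((a + b - c) ** 2 for a, b, c in zip(h, r, t)))


def transr_score(h: Vec, r: Vec, t: Vec, m_r: Sequence[Vec]) -> float:
    """TransR: project h and t with the relation matrix M_r (k x d), then translate."""
    hp = [sum(m * x for m, x in zip(row, h)) for row in m_r]
    tp = [sum(m * x for m, x in zip(row, t)) for row in m_r]
    return transe_score(hp, r, tp)


def simple_score(head_as_head: Vec, head_as_tail: Vec,
                 tail_as_head: Vec, tail_as_tail: Vec,
                 r: Vec, r_inv: Vec) -> float:
    """SimplE: average of <h_head, v_r, t_tail> and <t_head, v_r_inv, h_tail>."""
    def tri(a: Vec, b: Vec, c: Vec) -> float:
        return sum(x * y * z for x, y, z in zip(a, b, c))
    return 0.5 * (tri(head_as_head, r, tail_as_tail) + tri(tail_as_head, r_inv, head_as_tail))


def predict_tails(score_fn: Callable[[str], float], candidates: List[str], k: int) -> List[str]:
    """Rank candidate tail ids by score_fn(tail_id) and return the top-k."""
    return sorted(candidates, key=score_fn, reverse=True)[:k]
```

With M_r set to the identity matrix, TransR reduces to TransE, which makes the relationship between the two models easy to check.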
Fig. 4. Defect-level classification for a specific defect description. The edges of the tap changer and the breather are missing from the knowledge graph; knowledge graph completion adds the missing information and thus increases the number of edges in the graph. The knowledge graph stores the query path for the on-load tap-changer silicone discoloration but is missing the path from the on-load tap-changer breather to the silicone discoloration. With the tap changer as the head entity, the model predicts plausible tail entities among the candidates to complete the correct edge.
4 Experiments and Results
In this section, we present the datasets and evaluation metrics used in our experiments, as well as the results of the proposed method.
4.1 Datasets and Evaluation Metrics
We collected a dataset of transformer defect reports from a power grid company. It consists of 37,019 text records with a total of 1,530,329 characters and was used for the named entity recognition experiments, from which a total of 48,082 entities were extracted. Following the ontology, entities and relations were organized into a knowledge graph with a total of 320,725 triples, which were used for the knowledge graph completion experiments.
To evaluate named entity recognition, we used precision, recall, and F1-score. Precision is the proportion of true positive entities among all predicted entities, recall is the proportion of true positive entities among all actual entities, and the F1-score is the harmonic mean of precision and recall. For the knowledge graph completion task, we use the mean reciprocal rank (MRR) as the evaluation metric.
4.2 Named Entity Recognition
We conducted the named entity recognition experiment with the BERT-BiLSTM-CRF model. We randomly divided the dataset into training, validation, and test sets in a ratio of 8:1:1, used the training set to train the model and the validation set to tune the hyperparameters, and evaluated the model on the test set. The results are shown in Table 1. The BERT-BiLSTM-CRF model achieved a precision of 0.92, a recall of 0.91, an F1-score of 0.91, and an accuracy of 0.93, indicating that it identifies named entities in the text data accurately.
Table 1. Named entity recognition results using the BERT-BiLSTM-CRF model.
Model             Precision  Recall  F1-score  Accuracy
BERT-BiLSTM-CRF   0.92       0.91    0.91      0.93
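The precision, recall, and F1 definitions above can be made concrete with a short sketch. It assumes strict entity-level matching, in which a predicted entity counts as a true positive only if its span boundaries and type both match a gold entity exactly; the paper does not state its matching criterion, and the entity types used here ("Component", "Defect") are hypothetical labels.

```python
def entity_prf(gold, pred):
    """Strict entity-level precision, recall and F1.

    gold, pred: iterables of (start, end, entity_type) tuples for one text;
    add a document id to each tuple when scoring a multi-document corpus.
    An entity is a true positive only if its span and type match exactly.
    """
    gold_set, pred_set = set(gold), set(pred)
    tp = len(gold_set & pred_set)
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(gold_set) if gold_set else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


gold = [(0, 3, "Component"), (5, 9, "Defect"), (12, 14, "Component")]
pred = [(0, 3, "Component"), (5, 9, "Component"), (12, 14, "Component"), (20, 22, "Defect")]
print(entity_prf(gold, pred))
```

Here two of the four predictions match a gold entity exactly (the span (5, 9) has the wrong type), so precision is 0.5 and recall is 2/3.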
4.3 Knowledge Graph Completion Experiment
We performed the knowledge graph completion experiment with the TransE, TransR, and SimplE models. We randomly selected 80% of the triples as the training set, used the remaining 20% as the test set, and evaluated the models with MRR (Mean Reciprocal Rank). Table 2 shows the results, where the suffix -c indicates that, during entity completion, the predicted tail entity is ranked only among the candidate entities rather than among all entities. With candidate ranking, TransR achieved the best performance with an MRR of 0.857, and TransE and SimplE also performed well, with MRRs of 0.845 and 0.847, respectively. Restricting the ranking to the candidate set substantially improves entity completion for all three models (for example, TransR rises from 0.763 to 0.857).
Knowledge Graph-Based Approach for Main Transformer Defect Grade Analysis
Table 2. Knowledge graph completion results using TransE, SimplE, and TransR.
Model      MRR
TransE     0.765
SimplE     0.733
TransR     0.763
TransE-c   0.845
SimplE-c   0.847
TransR-c   0.857
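The difference between the plain and the -c rows of Table 2 comes from the pool of entities that the true tail is ranked against. The sketch below computes MRR in both settings. It assumes higher scores are better, breaks ties optimistically, and uses a per-relation candidate map; the paper does not say how its candidate sets are built, so the relation name "has_part" and its candidates are hypothetical.

```python
def reciprocal_rank(scores, true_tail, candidates=None):
    """1 / rank of the true tail, ranked among all entities or a candidate set.

    scores: one score per entity id (higher means more plausible).
    candidates: optional entity ids to rank against (the "-c" setting); the true
        tail is assumed to be one of them.
    Ties are broken optimistically: rank = 1 + number of strictly better entities.
    """
    pool = range(len(scores)) if candidates is None else candidates
    true_score = scores[true_tail]
    better = sum(1 for e in pool if e != true_tail and scores[e] > true_score)
    return 1.0 / (1 + better)


def mean_reciprocal_rank(queries, candidate_map=None):
    """queries: list of (scores, true_tail, relation); candidate_map: relation -> ids."""
    total = 0.0
    for scores, true_tail, relation in queries:
        cands = None if candidate_map is None else candidate_map[relation]
        total += reciprocal_rank(scores, true_tail, cands)
    return total / len(queries)


# Five entities; the hypothetical relation "has_part" may only point to entities 2-4.
queries = [([0.9, 0.8, 0.7, 0.1, 0.2], 2, "has_part"),
           ([0.3, 0.95, 0.1, 0.6, 0.5], 3, "has_part")]
candidates = {"has_part": [2, 3, 4]}
print(mean_reciprocal_rank(queries))              # ranked against all entities
print(mean_reciprocal_rank(queries, candidates))  # ranked against candidates only
```

In this toy case the true tails rank 3rd and 2nd among all entities (MRR = 5/12), but 1st among the candidates (MRR = 1.0), which illustrates why candidate-restricted ranking raises MRR.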
4.4 Analysis
The results show that the BERT-BiLSTM-CRF model achieved high precision and recall on entity extraction, indicating that it identifies entities in the text accurately. For entity completion, TransR outperformed TransE and SimplE when ranking was restricted to candidate entities, suggesting that it better captures the complex relationships between entities in the knowledge graph; without this restriction, TransE (0.765) and TransR (0.763) performed almost identically. All three methods achieved high MRR scores in the candidate setting, indicating that they can effectively complete missing entities in the graph. In conclusion, the proposed BERT-BiLSTM-CRF model and the TransR-based knowledge graph completion approach show promising results for transformer defect analysis. The evaluation results demonstrate their effectiveness in detecting and diagnosing transformer defects, providing valuable insights for maintenance and troubleshooting in power grid companies.
5 Conclusion
In this paper, we proposed a knowledge-driven approach to identify potential defects in power grid transformers by leveraging natural language processing and knowledge graph completion techniques. We developed a pipeline that integrates a BERT-BiLSTM-CRF model for named entity recognition with the TransE, TransR, and SimplE models for knowledge graph completion. We extracted 48,082 entities and 320,725 triples to build a knowledge graph, and evaluated our approach using precision, recall, F1-score, accuracy, and MRR. Overall, the proposed approach provides a novel and effective solution for transformer defect diagnosis by integrating natural language processing and knowledge graph techniques. It can be extended to other domains that require knowledge-driven analysis and decision-making. Future research could explore other advanced neural network models for natural language processing and knowledge graph completion to improve the accuracy and efficiency of the approach.
Acknowledgements. This work is supported by the Major Program of Xiamen (3502Z20231006); the National Natural Science Foundation of China (62176227, U2066213); and the Fundamental Research Funds for the Central Universities (20720210047).
CSAANet: An Attention-Based Mechanism for Aligned Few-Shot Semantic Segmentation Network
Guangpeng Wei and Pengjiang Qian(B)
School of Artificial Intelligence and Computer Science, Jiangnan University, Wuxi 214122, Jiangsu, China
Abstract. Semantic segmentation, a fundamental task in computer vision, involves identifying and classifying the objects in an image. However, collecting a sizable volume of annotated data for prediction tasks is very costly. To overcome this constraint, few-shot semantic segmentation approaches aim to learn from a small amount of labeled data and generalize to new classes. The key challenge in this setting is learning to distinguish objects from only a few labeled samples. We therefore propose a Channel and Spatial Attention Alignment Network (CSAANet) to improve few-shot semantic segmentation. Our approach uses channel and spatial attention to obtain weighted classifiers for novel classes, which are then used to segment those classes in the image precisely. Additionally, we construct a semantically aligned auxiliary learning module that makes full use of the support image information and enhances the learned weights. Experimental results on the few-shot semantic segmentation datasets PASCAL-5i and COCO-20i demonstrate that our method outperforms most existing methods.
Keywords: few-shot semantic segmentation · attention mechanism · aligned auxiliary learning
1 Introduction
Semantic segmentation has various applications in medical image recognition, autonomous driving, and geological exploration [1]. Its performance has improved significantly with the advent of deep learning. However, training convolutional neural networks for semantic segmentation requires a considerable amount of labeled data. Few-shot semantic segmentation has been proposed as a remedy for this problem, since it aims to perform semantic segmentation with only a minimal quantity of labeled data. It requires training a model [2] that, given K annotated images in the support set, can segment objects of novel classes in the query images. However, few-shot semantic segmentation is a difficult task [3] with several obstacles. Firstly, traditional methods have difficulty extracting class-related information from only a few samples, which makes novel classes hard to handle. Secondly, because of the lack of data and the complexity of image content, information about a novel class may be suppressed by the base classes. Moreover, because adding a new classifier also changes the output probability space, it remains unclear how to preserve the base classifier while dynamically adding new classifiers.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 760–772, 2023. https://doi.org/10.1007/978-981-99-4761-4_64
CSAANet: An Attention-Based Mechanism
In this paper, we propose an attention-based few-shot semantic segmentation network, CSAANet. The CSAANet model employs channel and spatial attention modules to emphasize critical features of previously unseen classes and filter out irrelevant features. We also develop an auxiliary alignment learning module that builds a new support set from the query images and their predictions, and uses it to back-predict the segmentation masks of the original support images. Extensive experiments on the PASCAL-5i and COCO-20i few-shot semantic segmentation datasets show that our method outperforms most existing methods.
2 Related Work
2.1 Semantic Segmentation
Semantic segmentation, which assigns a categorical label to every pixel of an image, is an important task in computer vision. Deep learning techniques have been heavily used for this problem in recent years [4], leading to notable performance gains. Long et al. [5] introduced the Fully Convolutional Network (FCN), which has since become a cornerstone of research in the area. Ronneberger et al. [6] proposed the U-Net architecture, which has become a popular framework for biomedical image segmentation.
2.2 Attention Mechanism
Attention mechanisms have become increasingly popular for their ability to extract essential information from complex inputs. Oktay et al. [7] proposed an attention mechanism that learns to focus on relevant image regions for medical image segmentation. Fu et al. [8] proposed the Dual Attention Network, which combines spatial and channel attention to improve segmentation performance. Gidaris et al. [9] proposed an attention-based method for few-shot visual learning that adapts to novel classes without forgetting prior knowledge.
2.3 Few-Shot Semantic Segmentation
Few-shot semantic segmentation is difficult because only a few annotated examples are available to guide the segmentation of objects in an image. Interest in few-shot semantic segmentation algorithms has grown recently. Wang et al. [10] matched the query image's features with prototypes extracted from the support set. Rakelly et al. [11] presented a few-shot segmentation method based on guided networks that learn to propagate segmentation masks across images. Liu et al. [12] proposed a method that uses cross-referencing between object instances to improve segmentation accuracy. Siam et al. [13] proposed a method that learns adaptive prototypes for each object class in the support set.
G. Wei and P. Qian
3 Network
3.1 Problem Setting
Our study aims to build a network that accurately segments a target object in a query image using only a small number of annotated support images of the target class. In other words, our task is an N-way K-shot problem [14]: the support set contains K annotated examples from each of N novel classes, and these examples are used to segment the images in the query set. This work examines the 1-way 1-shot and 1-way 5-shot tasks.
Fig. 1. Overview of Channel and Spatial Attention Alignment Network (CSAANet)
Initially, we divide the classes into two disjoint sets, the base classes $D_{base}$ and the novel classes $D_{novel}$, i.e. $D_{base} \cap D_{novel} = \emptyset$. Next, we choose a set of image pairs (each an image and its mask) $C_{train} = \{(S_i, Q_i)\}_i$ from $D_{base}$ as the training set, and another set $C_{test} = \{(S_j, Q_j)\}_j$ from $D_{novel}$ as the test set. The semantic segmentation model $M$ is trained on $C_{train}$ and evaluated on $C_{test}$. Previous studies have shown that episodic training is an effective approach for few-shot tasks. We therefore divide the training and test sets into episodes, each consisting of an annotated support set $S$ and a query set $Q$, so that each episode can be treated as a segmentation task. Specifically, each support set consists of $K$ image pairs per class, $S = \{(I_k^c, M_k^c)\}$, where $I_k^c$ and $M_k^c$ denote the $k$-th image of the $c$-th class and its mask, with $c = 1, 2, \ldots, N$ and $k = 1, 2, \ldots, K$. For training and testing, the $C$ classes of the support set are drawn from $D_{base}$ and $D_{novel}$, respectively. We first train the model on $C_{train}$ to learn the feature weights of the $C$ classes from the support set, and then use the learned weights to perform semantic segmentation on the query set. Through extensive training on episodes that contain different semantic classes, the segmentation model $M$ acquires strong generalization capability. Finally, we assess the performance of the trained model $M$ on $C_{test}$.
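The episodic setup can be sketched as a sampler that draws one N-way K-shot episode from a class pool. This is a minimal illustration under stated assumptions, not the authors' data loader: the function name `sample_episode`, the `class_to_images` index, the image ids, and the one-query-per-class default are all hypothetical, and real episodes would also carry the binarized masks.

```python
import random


def sample_episode(class_to_images, classes, n_way=1, k_shot=1, n_query=1, rng=random):
    """Sample one N-way K-shot episode from a pool of classes.

    class_to_images: dict mapping a class id to the list of image ids containing it.
    classes: the class pool (D_base during training, D_novel during testing).
    Returns (support, query) as lists of (image_id, class_id) pairs; the mask of
    each image would then be binarized for class_id.
    """
    episode_classes = rng.sample(sorted(classes), n_way)
    support, query = [], []
    for c in episode_classes:
        # Draw distinct images so that no query image also appears in the support set.
        images = rng.sample(class_to_images[c], k_shot + n_query)
        support += [(img, c) for img in images[:k_shot]]
        query += [(img, c) for img in images[k_shot:]]
    return support, query


# Hypothetical split: base and novel classes must be disjoint (D_base ∩ D_novel = ∅).
base_classes, novel_classes = {1, 2, 3, 4}, {5}
assert base_classes.isdisjoint(novel_classes)
class_to_images = {c: [f"img_{c}_{i}" for i in range(10)] for c in base_classes | novel_classes}

train_support, train_query = sample_episode(class_to_images, base_classes, k_shot=5, rng=random.Random(0))
test_support, test_query = sample_episode(class_to_images, novel_classes, k_shot=1, rng=random.Random(1))
```

Passing a seeded `random.Random` instance makes episode sampling reproducible, which helps when comparing results across folds.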
We propose CSAANet, which uses the estimated weights of novel classes to construct new classifiers. Figure 1 shows the network architecture of our model.
3.2 Channel and Spatial Attention Module
In few-shot semantic segmentation, the limited data for novel classes can easily be drowned out by background information or other classes, producing new classifiers with poor performance. To address this issue, we propose a channel and spatial attention module (CSAM) that extracts class-specific information and provides a more accurate weight estimator $G(\cdot)$. CSAM has two components: the first weights feature channels by their importance, and the second weights spatial locations by their relevance to the task. First, we extract the support features $F_c^s$ with the backbone. Then, we estimate the weights $w_c$ of the novel classes from the extracted features $F_c^s$ and the support mask $M_c^s$:
$$w_c = G(F_c^s, M_c^s) \tag{1}$$
Figure 2 illustrates the three-step operation of CSAM. In the first step, the support feature map $F_c^s$ and the corresponding mask $M_c^s$ are input to the channel attention module. This module uses an SE block consisting of two 1 × 1 convolutional layers with a ReLU function between them, followed by a sigmoid function. The SE block takes as input the maximum and average values of the masked feature map and outputs a set of channel weights, which are then applied to the feature map. This procedure is formulated as
$$g_c = \sigma(\mathrm{Conv}(\mathrm{Pool}(F_c^s \otimes M_c^s))) \tag{2}$$
where $\otimes$ denotes the mask operation, which applies the binary mask to the feature map by element-wise multiplication so that the attention operates only on the relevant part of the feature map, and $\sigma$ denotes the sigmoid activation function. The attention vector $g_c$ is then used to suppress irrelevant features:
$$F_c^s = F_c^s \odot g_c \tag{3}$$
where $\odot$ denotes element-wise multiplication, with the channel weights broadcast over all spatial positions. In the second step, the spatial attention module applies an SC block, consisting of a 7 × 7 convolutional layer followed by a sigmoid function, to create a set of spatial weights, which are then applied to the channel-weighted feature maps. The SC block takes as input the maximum and average values of the masked feature maps:
$$g_s = \sigma(\mathrm{Conv}(\delta(F_c^s))) \tag{4}$$
where $\delta(\cdot)$ denotes the operation that takes the maximum and the mean of the feature maps along the channel dimension, and $\mathrm{Conv}$ denotes the convolution operation. The attention map $g_s$ is then used to suppress irrelevant spatial locations:
$$F_c^s = F_c^s \odot g_s \tag{5}$$
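The channel and spatial attention steps of Eqs. (2)-(5) can be sketched in NumPy as below. This is a minimal sketch under stated assumptions, not the authors' code: the paper says the SE block takes both the maximum and the average of the masked features but not how they are combined, so this sketch sums the two branches after a shared two-layer MLP (a CBAM-style choice); the reduction ratio and all weights are random stand-ins for learned parameters; and, following Eq. (4), the spatial step is applied to the channel-weighted features.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def channel_attention(feat, mask, w1, w2):
    """Eqs. (2)-(3): channel weights g_c from the masked support features.

    feat: (C, H, W) support features F_c^s; mask: (H, W) binary mask M_c^s.
    w1: (C // r, C) and w2: (C, C // r) play the role of the two 1x1 convolutions.
    """
    masked = feat * mask[None]              # F_c^s ⊗ M_c^s
    avg = masked.mean(axis=(1, 2))          # average-pooled channel descriptor
    mx = masked.max(axis=(1, 2))            # max-pooled channel descriptor

    def mlp(v):
        return w2 @ np.maximum(w1 @ v, 0.0)  # 1x1 conv -> ReLU -> 1x1 conv

    g_c = sigmoid(mlp(avg) + mlp(mx))       # (C,) weights in (0, 1); summing is assumed
    return feat * g_c[:, None, None], g_c   # Eq. (3): broadcast over H and W


def spatial_attention(feat, kernel):
    """Eqs. (4)-(5): spatial weights g_s from channel-wise max and mean maps.

    feat: (C, H, W) channel-weighted features; kernel: (2, k, k) weights of the
    k x k convolution (k = 7 in the paper), zero-padded so H and W are kept.
    """
    desc = np.stack([feat.max(axis=0), feat.mean(axis=0)])  # δ(F_c^s): (2, H, W)
    k = kernel.shape[-1]
    pad = k // 2
    padded = np.pad(desc, ((0, 0), (pad, pad), (pad, pad)))
    h, w = feat.shape[1:]
    logits = np.zeros((h, w))
    for i in range(k):                      # direct cross-correlation, as in conv layers
        for j in range(k):
            logits += np.tensordot(kernel[:, i, j], padded[:, i:i + h, j:j + w], axes=1)
    g_s = sigmoid(logits)                   # (H, W) weights in (0, 1)
    return feat * g_s[None], g_s            # Eq. (5): broadcast over channels


# Usage with random stand-ins for the learned weights (C = 8, reduction r = 4).
rng = np.random.default_rng(0)
feat = rng.normal(size=(8, 12, 12))
mask = np.zeros((12, 12))
mask[3:9, 3:9] = 1.0
refined, g_c = channel_attention(feat, mask, rng.normal(size=(2, 8)), rng.normal(size=(8, 2)))
refined, g_s = spatial_attention(refined, rng.normal(size=(2, 7, 7)))
print(g_c.shape, g_s.shape, refined.shape)  # (8,) (12, 12) (8, 12, 12)
```

In a real network both steps would be layers of a deep learning framework so that the weights can be trained end to end; the explicit loop here only makes the 7 × 7 convolution visible.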
The output of CSAM is a feature map refined for the current task: the original feature map is multiplied element-wise by the channel weights and then by the spatial weights. The refined feature map is more informative, since it captures key features while suppressing unimportant ones. Finally, the weight $w_c$ of the novel class is estimated as
$$w_c = \mathrm{Pool}(F_c^s \otimes M_c^s) \tag{6}$$
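Eq. (6) and the dynamic classifier of Section 3.3 can be illustrated together: the novel-class weight is pooled from the refined support features and appended to the base classifier, which then labels every query pixel. This is a sketch under assumptions: it reads Pool in Eq. (6) as masked average pooling, it scores pixels by cosine similarity as in cosine-classifier designs such as DENet (the paper does not state the scoring rule), and the function names and the three base classes are hypothetical.

```python
import numpy as np


def masked_average_pooling(feat, mask, eps=1e-6):
    """Eq. (6): average the (C, H, W) features over the foreground of the (H, W) mask."""
    return (feat * mask[None]).sum(axis=(1, 2)) / (mask.sum() + eps)


def extend_classifier(base_weights, novel_weight):
    """Append the estimated novel-class weight w_c as a new row of the classifier."""
    return np.vstack([base_weights, novel_weight[None]])


def classify_pixels(feat, weights, eps=1e-6):
    """Label each pixel with the class whose weight is most cosine-similar to it (assumed rule)."""
    f = feat / (np.linalg.norm(feat, axis=0, keepdims=True) + eps)         # (C, H, W)
    w = weights / (np.linalg.norm(weights, axis=1, keepdims=True) + eps)   # (classes, C)
    scores = np.tensordot(w, f, axes=1)                                    # (classes, H, W)
    return scores.argmax(axis=0)


# Usage: 3 hypothetical base classes plus one novel class estimated from a support image.
rng = np.random.default_rng(0)
support_feat = rng.normal(size=(16, 8, 8))
support_mask = np.zeros((8, 8))
support_mask[2:6, 2:6] = 1.0
w_novel = masked_average_pooling(support_feat, support_mask)
classifier = extend_classifier(rng.normal(size=(3, 16)), w_novel)
labels = classify_pixels(rng.normal(size=(16, 8, 8)), classifier)  # values in {0, 1, 2, 3}
```

Normalizing both features and weights keeps the newly estimated weight on the same scale as the trained base weights, which is one common motivation for cosine classifiers in this setting.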
Fig. 2. Illustration of the Channel and Spatial Attention Module (CSAM)
3.3 Dynamic Classifier
To enable end-to-end joint training of the estimated novel classifier and the base classifier, we adopt the training method of [15]. Specifically, after obtaining a support set $S$, CSAANet extracts the features $F_k^s$ of each support image, uses the channel and spatial attention modules to focus on the novel-class information, and estimates the weight $w_c$ of the novel class $c$ from the refined features to construct a new classifier. The dynamic classifier then incorporates the weights of the new classifier for class $c$.
3.4 Auxiliary Alignment Learning Module
Previous work [10] has demonstrated that adding an auxiliary alignment module can improve the model's generalization by making better use of the information in the support images. We develop an Auxiliary Alignment Learning (AAL) module that incorporates information from the support images to guide the training of the network. A schematic diagram of the AAL module is shown in Fig. 3. After obtaining the segmentation result of the query image, we use masked average pooling (MAP) [16] to compute prototypes from it.
Fig. 3. The Auxiliary Alignment Learning Module (AAL)
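The novel-class and background prototypes are computed as follows:
$$P_c = \frac{1}{K}\sum_{k} \frac{\sum_{x,y} F_k^c(x,y)\,\mathbb{1}[M_k^c(x,y) = c]}{\sum_{x,y} \mathbb{1}[M_k^c(x,y) = c]} \tag{7}$$
$$P_{bg} = \frac{1}{CK}\sum_{c,k} \frac{\sum_{x,y} F_k^c(x,y)\,\mathbb{1}[M_k^c(x,y) \notin \mathcal{C}_i]}{\sum_{x,y} \mathbb{1}[M_k^c(x,y) \notin \mathcal{C}_i]} \tag{8}$$
where $\mathbb{1}[\cdot]$ is the indicator function and $\mathcal{C}_i$ is the set of classes in the episode.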
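Eqs. (7) and (8) amount to masked average pooling per shot, averaged over shots for each class, plus one background prototype averaged over all $C \cdot K$ images. The sketch below assumes the masks are integer label maps (in AAL, the predicted query masks serve as pseudo-labels), adds a small `eps` to avoid division by zero, and uses a hypothetical function name and data layout.

```python
import numpy as np


def class_prototypes(feats, masks, episode_classes, eps=1e-6):
    """Eqs. (7)-(8): one foreground prototype per class and a shared background prototype.

    feats: dict class c -> list of K feature maps, each (C, H, W).
    masks: dict class c -> list of K label maps, each (H, W) of integer class ids.
    episode_classes: the class set C_i of the episode.
    """
    prototypes, bg_terms = {}, []
    for c in episode_classes:
        fg_terms = []
        for feat, mask in zip(feats[c], masks[c]):
            fg = (mask == c).astype(float)                             # 1[M(x, y) = c]
            bg = (~np.isin(mask, list(episode_classes))).astype(float)  # 1[M(x, y) not in C_i]
            fg_terms.append((feat * fg).sum(axis=(1, 2)) / (fg.sum() + eps))
            bg_terms.append((feat * bg).sum(axis=(1, 2)) / (bg.sum() + eps))
        prototypes[c] = np.mean(fg_terms, axis=0)   # average over the K shots
    background = np.mean(bg_terms, axis=0)          # average over all C * K images
    return prototypes, background


# Usage: one class (id 1), one shot, a 2-channel 2x2 feature map.
feat = np.stack([np.arange(4.0).reshape(2, 2), np.ones((2, 2))])
mask = np.array([[1, 1], [0, 0]])
protos, background = class_prototypes({1: [feat]}, {1: [mask]}, [1])
print(protos[1], background)
```

Here the top row is foreground, so the class prototype is approximately [0.5, 1.0] and the background prototype, pooled from the bottom row, is approximately [2.5, 1.0].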
After obtaining the novel-class and background prototypes from the segmentation result, we measure the distance $d(\cdot,\cdot)$ between each prototype and the support feature at every spatial location, and then apply a softmax over the prototypes to convert the distances into a probability distribution [17]. Let $\mathcal{P} = \{P_c \mid c \in \mathcal{C}_i\} \cup \{P_{bg}\}$. For each $p_j \in \mathcal{P}$,
$$\tilde{M}_s^{j}(x,y) = \frac{\exp(-d(F_s(x,y), p_j))}{\sum_{p_{j'} \in \mathcal{P}} \exp(-d(F_s(x,y), p_{j'}))} \tag{9}$$
The predicted segmentation mask of the support image is then obtained as
$$\hat{M}_s(x,y) = \arg\max_{j} \tilde{M}_s^{j}(x,y) \tag{10}$$
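Eqs. (9) and (10) are a per-pixel softmax over negative prototype distances followed by an argmax. The paper leaves $d(\cdot,\cdot)$ unspecified, so this sketch assumes the scaled cosine distance with scale 20 used by PANet [10]; the function names are hypothetical, and the prototype order (foreground prototypes first, background last) is an assumed convention.

```python
import numpy as np


def scaled_cosine_distance(feat, proto, scale=20.0, eps=1e-6):
    """d(F(x, y), p) = scale * (1 - cos) at every pixel; returns an (H, W) map.

    The choice of cosine distance and scale = 20 is an assumption borrowed from PANet.
    """
    f = feat / (np.linalg.norm(feat, axis=0, keepdims=True) + eps)
    p = proto / (np.linalg.norm(proto) + eps)
    return scale * (1.0 - np.tensordot(p, f, axes=1))


def prototype_probabilities(feat, prototypes, distance=scaled_cosine_distance):
    """Eq. (9): softmax over prototypes of -d(F_s(x, y), p_j).

    feat: (C, H, W) support features; prototypes: (P, C), e.g. [P_c, ..., P_bg].
    """
    logits = np.stack([-distance(feat, p) for p in prototypes])   # (P, H, W)
    logits -= logits.max(axis=0, keepdims=True)                   # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=0, keepdims=True)


def predict_mask(probs):
    """Eq. (10): index j of the most probable prototype at each pixel."""
    return probs.argmax(axis=0)


# Usage: two prototypes over C = 2 channels and a 1x2 support feature map.
protos = np.array([[1.0, 0.0], [0.0, 1.0]])     # e.g. [P_c, P_bg]
feat = np.zeros((2, 1, 2))
feat[:, 0, 0] = [3.0, 0.1]                      # close to P_c
feat[:, 0, 1] = [0.2, 5.0]                      # close to P_bg
print(predict_mask(prototype_probabilities(feat, protos)))  # [[0 1]]
```

Subtracting the per-pixel maximum before exponentiating leaves the softmax unchanged but prevents overflow when distances are large.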
The predicted segmentation of the support image is then compared with its ground-truth mask $M_s$, and the alignment loss $L_{par}$ is computed as
$$L_{par} = -\frac{1}{CK}\sum_{c,k,x,y}\sum_{p_j \in \mathcal{P}} \mathbb{1}[M_s(x,y) = j]\,\log \tilde{M}_s^{j}(x,y) \tag{11}$$
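Eq. (11) is a cross-entropy between the prototype probabilities of Eq. (9) and the ground-truth support labels. The sketch below follows the equation as written, summing over pixels and dividing by the number of support images $C \cdot K$; many implementations average over pixels instead, which only rescales the loss. The function name and the `eps` guard against log(0) are assumptions.

```python
import numpy as np


def par_loss(probs, gt_masks, eps=1e-12):
    """Eq. (11): negative log-probability of the ground-truth index j, summed over pixels.

    probs: list (one entry per support image, C * K in total) of (P, H, W)
        probability maps from Eq. (9).
    gt_masks: matching list of (H, W) integer maps with the ground-truth index j.
    The sum is divided by C * K, following Eq. (11) as written.
    """
    total = 0.0
    for p, gt in zip(probs, gt_masks):
        picked = np.take_along_axis(p, gt[None], axis=0)[0]   # p[gt(x, y), x, y]
        total -= np.log(picked + eps).sum()
    return total / len(probs)


probs = np.array([[[0.9, 0.2]], [[0.1, 0.8]]])   # (P = 2, H = 1, W = 2)
gt = np.array([[0, 1]])                           # ground-truth index j per pixel
print(par_loss([probs], [gt]))                   # -(log 0.9 + log 0.8) ≈ 0.3285
```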
The total training loss of our network is
$$L = L_{seg} + L_{par} \tag{12}$$
The complete training and testing procedures of CSAANet for few-shot semantic segmentation are summarized in Algorithm 1 and Algorithm 2.
Algorithm 1: Training CSAANet
Input: training set $D_{train}$
Initialize the model parameters of CSAANet
for each episode $(S_i, Q_i) \in D_{train}$ do
    Obtain the images $(I_k^c, M_k^c)$ from $S_i$ and $Q_i$, and extract the features $F_s$ and $F_q$
    Estimate the classifier $w'_c$ by Eq. (6)
    Record the $c$-th classifier $w_c$, then replace it: $w_c \leftarrow w'_c$
    Obtain the segmentation result $Q_{pre}$ of the query image
    Extract the prototypes $P$ from $Q_{pre}$ using Eq. (7) and Eq. (8)
    Obtain the support segmentation result $M_s$ by Eq. (10)
    Compute the total loss $L$ by Eq. (12)
    Compute the gradients and update the parameters with SGD
    Restore the recorded $c$-th classifier $w_c$
end
Algorithm 2: Testing CSAANet
Input: testing set $D_{test}$
for each episode $(S_i, Q_i) \in D_{test}$ do
    Obtain the support images $(I_{s,k}^c, M_{s,k}^c)$ from $S_i$ and extract the features $F_s$
    Estimate the classifier $w_c$ by Eq. (6) and add $w_c$ to the classifier
    Obtain the query images $(I_{q,k}^c, M_{q,k}^c)$ from $Q_i$ and extract the features $F_q$
    Obtain the segmentation results with the dynamic classifier
end
4 Experiments
4.1 Setup
Datasets. We evaluated our model on PASCAL-5i and COCO-20i. PASCAL-5i is a few-shot semantic segmentation dataset built from PASCAL VOC 2012 augmented with SBD annotations; its 20 categories are split into four folds of five classes, and it contains 10,582 training images and 1,449 validation images. COCO-20i is built from COCO2014 by splitting its 80 categories into four folds of 20 classes; it is a larger and more challenging few-shot semantic segmentation benchmark, with 82,783 training images and 40,504 validation images.
Evaluation Metrics. To evaluate our model, we used the mean Intersection over Union (mIoU), which is widely used in semantic segmentation. We computed the mIoU on each of the four test folds and report their average as the overall performance metric.
Implementation Details. We implemented our method in PyTorch. For feature extraction, we employed a dilated ResNet-50 network together with an ASPP module. The output of CSAANet was resized by bilinear interpolation. We resized the input images and their ground-truth masks to 320 × 320 and augmented the data with random horizontal flips. We trained the model with the SGD optimizer for 40,000 iterations on PASCAL-5i and for 150,000 iterations on COCO-20i.
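The evaluation protocol above (mIoU per fold, then the average over the four folds) can be sketched as follows. The per-fold part assumes the common few-shot convention of accumulating intersection and union per class over all test episodes before averaging over classes; the paper does not spell this out, and the function names are hypothetical.

```python
import numpy as np


def fold_miou(episodes, num_classes):
    """mIoU of one fold: per-class IoU accumulated over all test episodes, then averaged.

    episodes: iterable of (class_index, pred_mask, gt_mask), where both masks are
    binary (H, W) maps for the episode's target class and class_index is in
    0..num_classes-1 (e.g. 5 classes per PASCAL-5i fold, 20 per COCO-20i fold).
    """
    inter = np.zeros(num_classes)
    union = np.zeros(num_classes)
    for c, pred, gt in episodes:
        pred, gt = pred.astype(bool), gt.astype(bool)
        inter[c] += np.logical_and(pred, gt).sum()
        union[c] += np.logical_or(pred, gt).sum()
    seen = union > 0                     # ignore classes that never occurred
    return float((inter[seen] / union[seen]).mean())


def mean_over_folds(fold_scores):
    """Final metric: plain average of the four per-fold mIoU values."""
    return sum(fold_scores) / len(fold_scores)


# CSAANet 1-shot fold results on PASCAL-5i from Table 1.
print(mean_over_folds([57.43, 69.33, 64.26, 53.37]))
```

Averaging the four fold scores of CSAANet gives 61.0975, which is the 61.10 reported as the 1-shot mean in Table 1.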
4.2 Comparison
Using PASCAL-5i and COCO-20i as benchmarks, we compared our method with existing methods in terms of mIoU. We ran 1-way 1-shot and 1-way 5-shot experiments with a ResNet-50 backbone, and report both quantitative and qualitative results.
Table 1. Performance of CSAANet on PASCAL-5i in the 1-way 1-shot and 5-shot settings (mIoU, %). The symbol "*" indicates that the network model was re-implemented by us. "-" means that the results were not reported.
Model      1-shot                               5-shot
           fold0  fold1  fold2  fold3  mean     fold0  fold1  fold2  fold3  mean
Co-FCN     36.77  50.62  44.93  32.41  41.18    37.58  50.02  44.13  33.94  41.42
SG-One     40.26  58.43  48.46  38.45  46.40    41.93  58.63  48.65  39.47  47.17
PANet*     34.30  50.11  44.92  40.18  42.38    43.74  56.77  50.93  46.57  49.50
PANet      42.30  58.00  51.10  41.20  48.10    51.80  64.60  59.80  46.50  55.70
CANet      52.53  65.91  51.35  51.97  55.44    55.50  67.80  51.90  53.20  57.10
CRNet      -      -      -      -      55.70    -      -      -      -      58.80
DENet*     55.32  69.26  61.85  50.22  59.16    54.43  71.21  62.32  50.47  59.61
DENet      55.74  69.69  63.62  51.26  60.64    54.72  70.99  64.51  51.63  60.46
CSAANet    57.43  69.33  64.26  53.37  61.10    56.84  72.75  66.67  50.41  61.67
Results on PASCAL-5i. Table 1 compares the mIoU of our method with existing network models [10, 12, 15, 16, 19, 20] on PASCAL-5i. In the 1-way 1-shot and 5-shot tasks, our method achieves the highest mean mIoU, 61.10% and 61.67% respectively, which demonstrates its efficacy. We also observed that adding the AAL module alone does not yield results as good as some network models; however, it greatly increases the convergence speed of the network, as discussed in the ablation experiments below. These results indicate that the proposed CSAANet generalizes well to new classes, highlighting the effectiveness of the added modules.
Results on COCO-20i. Table 2 compares the mIoU of our CSAANet with existing network models [15, 20–22] on COCO-20i. In the 1-way 1-shot and 5-shot tasks, our method achieves mean mIoU of 42.57% and 43.47%, respectively. In the 5-shot setting this is the best mean among the compared models, whereas in the 1-shot setting it is slightly below DENet (42.77%), and it does not achieve the best result on every fold. Overall, our method outperforms most of the compared network models, which shows that the proposed modules are effective for enhancing few-shot semantic segmentation networks.
Table 2. Performance of CSAANet on COCO-20i in the 1-way 1-shot and 5-shot settings (mIoU, %). The symbol "*" indicates that the network model was re-implemented by us. "-" means that the results were not reported.
Model      1-shot                               5-shot
           fold0  fold1  fold2  fold3  mean     fold0  fold1  fold2  fold3  mean
PGNet*     39.54  39.68  33.90  33.49  36.65    42.33  38.81  32.69  36.84  37.67
ASGNet     -      -      -      -      34.60    -      -      -      -      42.50
MMNet      34.92  41.05  37.27  37.04  37.57    34.92  41.05  37.27  37.04  37.57
CANet*     42.19  42.35  37.68  40.88  40.78    44.66  43.08  37.52  42.69  41.99
DENet      42.90  45.78  42.16  40.22  42.77    45.40  44.86  41.57  40.26  43.02
CSAANet    43.16  44.68  42.82  39.63  42.57    46.06  45.22  40.83  41.78  43.47
Fig. 4. The semantic segmentation results in 1-way 1-shot
Qualitative Results. Figures 4 and 5 show qualitative results of 1-way segmentation. As illustrated in Fig. 4, CSAANet produces accurate segmentations of novel classes despite the limited support images. For instance, in row 1 of Fig. 4, the proposed method successfully segments the dog. Furthermore, it segments the target object accurately even when the query image contains multiple objects.
Table 3. The impact of the CSAM and AAL modules on network performance in the 1-way 1-shot case (mIoU, %). In the model names, "-" followed by a module name means that the module is not used in the network model.
Model          1-shot
               fold0  fold1  fold2  fold3  mean
DENet          55.32  69.26  61.85  50.22  59.16
CSAANet-AAL    58.09  70.23  65.09  54.35  61.94
CSAANet-CSAM   54.58  68.66  60.74  49.41  58.35
CSAANet        57.43  69.33  64.26  53.37  61.10
However, our experiments also revealed some remaining challenges, as shown in Fig. 5. In particular, when objects of the novel class are very small or lie tightly against the background, the segmentation results may be suboptimal because of the limited support information guiding the segmentation. Our model also has difficulty delineating the boundaries of novel-class objects accurately. Lastly, it may struggle to achieve satisfactory results on more complex and densely populated scenes.
4.3 Ablation Studies
We conducted ablation experiments on PASCAL-5i with DENet as the baseline to verify the design of our model. Table 3 reports the impact of the CSAM and AAL modules in the 1-way 1-shot case with a ResNet-50 backbone; for each experiment, the results on the four folds of PASCAL-5i and their average are listed. Adding the CSAM module increases the mean mIoU by 2.78 percentage points, whereas adding the AAL module alone decreases it by 0.81 points. However, we observed that the AAL module accelerates the convergence of the model; Fig. 6 compares training with and without the AAL module. Finally, combining the CSAM and AAL modules improves the mean mIoU by 1.94 points over the baseline while also reducing the training time required to converge. These results show that adding the CSAM and auxiliary alignment learning modules can improve the performance of the proposed model.
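The differences quoted in the ablation study can be recomputed directly from the mean column of Table 3, which makes the baseline comparison explicit. This is only a worked check of the arithmetic; the dictionary labels are descriptive names of our own, and the DENet row is the baseline used in Table 3.

```python
# Mean 1-shot mIoU (%) on PASCAL-5i, copied from Table 3.
table3 = {
    "DENet (baseline)": 59.16,
    "CSAANet-AAL (CSAM only)": 61.94,
    "CSAANet-CSAM (AAL only)": 58.35,
    "CSAANet (CSAM + AAL)": 61.10,
}
baseline = table3["DENet (baseline)"]
for name, miou in table3.items():
    # Difference from the baseline in percentage points.
    print(f"{name}: {miou - baseline:+.2f}")
```

The differences are +2.78 points for CSAM alone, -0.81 points for AAL alone, and +1.94 points for the full model.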
Fig. 5. Failure cases of CSAANet on the PASCAL-5i dataset.
Fig. 6. The impact of the presence or absence of the AAL module on the training of DENet and CSAANet.
5 Conclusion
In this paper, we present CSAANet, a few-shot semantic segmentation network. The network includes a novel attention module, CSAM, that combines channel and spatial attention mechanisms to obtain weight information for new classes from the support images and their masks. We also introduce a dynamic classifier that efficiently guides query image segmentation using the newly acquired class weights. Furthermore, we propose an auxiliary alignment learning module that makes full use of the support image information and enhances the model's generalization capability. Experiments on the PASCAL-5i and COCO-20i datasets show that our method outperforms most of the compared methods.
References 1. Zha, H., Liu, R., Yang, X., Zhou, D., Zhang, Q., Wei, X.: ASFNet: adaptive multiscale segmentation fusion network for real-time semantic segmentation. Comput. Animat. Virtual Worlds 32(3–4), e2022 (2021) 2. Rao, X., Lu, T., Wang, Z., Zhang, Y.: Few-shot semantic segmentation via frequency guided neural network. IEEE Signal Process. Lett. 29, 1092–1096 (2022) 3. Chang, Z., Lu, Y., Wang, X., Ran, X.: MGNet: mutual-guidance network for few-shot semantic segmentation. Eng. Appl. Artif. Intell. 116, 105431 (2022)
CSAANet: An Attention-Based Mechanism
4. Gong, C., Shi, K., Niu, Z.: Hierarchical text-label integrated attention network for document classification. In: Proceedings of the 2019 3rd High Performance Computing and Cluster Technologies Conference, pp. 254–260 (2019)
5. Long, J., Shelhamer, E., Darrell, T.: Fully convolutional networks for semantic segmentation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3431–3440 (2015)
6. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W., Frangi, A. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
7. Oktay, O., Schlemper, J., Folgoc, L.L., et al.: Attention U-Net: Learning Where to Look for the Pancreas. arXiv preprint arXiv:1804.03999 (2018)
8. Fu, J., Liu, J., Tian, H., et al.: Dual attention network for scene segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3146–3154 (2019)
9. Gidaris, S., Komodakis, N.: Dynamic few-shot visual learning without forgetting. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4367–4375 (2018)
10. Wang, K., Liew, J.H., Zou, Y., Zhou, D., Feng, J.: PANet: few-shot image semantic segmentation with prototype alignment. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 9197–9206 (2019)
11. Rakelly, K., Shelhamer, E., Darrell, T., Efros, A.A., Levine, S.: Few-shot segmentation propagation with guided networks. arXiv preprint arXiv:1806.07373 (2018)
12. Liu, W., Zhang, C., Lin, G., Liu, F.: CRNet: cross-reference networks for few-shot segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4165–4173 (2020)
13. Siam, M., Oreshkin, B.N., Jagersand, M.: AMP: adaptive masked proxies for few-shot segmentation. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 5249–5258 (2019)
14. Wang, Z., Liu, L., Li, F.: TAAN: task-aware attention network for few-shot classification. In: 2020 25th International Conference on Pattern Recognition (ICPR) (2021)
15. Liu, L., Cao, J., Liu, M., Guo, Y., Chen, Q., Tan, M.: Dynamic extension nets for few-shot semantic segmentation. In: Proceedings of the 28th ACM International Conference on Multimedia (2020)
16. Zhang, X., Wei, Y., Yang, Y., Huang, T.S.: SG-One: similarity guidance network for one-shot semantic segmentation. IEEE Trans. Cybern. 50, 3855–3865 (2020)
17. Snell, J., Swersky, K., Zemel, R.: Prototypical networks for few-shot learning. In: Advances in Neural Information Processing Systems, pp. 4077–4087 (2017)
18. Shaban, A., Bansal, S., Liu, Z., Essa, I., Boots, B.: One-shot learning for semantic segmentation. arXiv preprint arXiv:1709.03410 (2017)
19. Rakelly, K., Shelhamer, E., Darrell, T., Efros, A., Levine, S.: Conditional networks for few-shot semantic segmentation. In: International Conference on Learning Representations, Workshop Track Proceedings (2018)
20. Zhang, C., Lin, G., Liu, F., Yao, R., Shen, C.: CANet: class-agnostic segmentation networks with iterative refinement and attentive few-shot learning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5217–5226 (2019)
21. Zhang, C., Lin, G., Liu, F., Guo, J., Wu, Q., Yao, R.: Pyramid graph networks with connection attentions for region-based one-shot semantic segmentation. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 9587–9595 (2019)
22. Wu, Z., Shi, X., Lin, G., Cai, J.: Learning meta-class memory for few-shot semantic segmentation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 517–526 (2021)
Unsupervised Few-Shot Learning via Positive Expansions and Negative Proxies
Liangjun Chen and Pengjiang Qian(B)
School of Artificial Intelligence and Computer Science, Jiangnan University, Wuxi 214122, Jiangsu, China
Abstract. Few-shot learning has recently made significant progress thanks to pre-training methods and meta-learning approaches. These methods, however, require an extensive labeled dataset that is difficult to obtain. We propose an unsupervised few-shot learning method based on positive expansions and negative proxies to fully utilize abundant unlabeled data. Our approach learns meaningful representations through self-supervised pre-training on unlabeled data using a simple but effective positive and negative sampling strategy. Specifically, we sort the negative queue in descending order of similarity to the query embedding and select the top N negatives as positive expansions. From the negatives ranked immediately after these N, we choose M as proxies. We further incorporate this sampling strategy into a novel contrastive loss function, which learns the representation by minimizing the distance between the query and the positive expansions while maximizing the distance to the negative proxies. Our approach greatly narrows the performance gap between supervised and unsupervised learning on two widely used few-shot benchmarks. Under a linear evaluation protocol, our method also achieves performance comparable to current SOTA self-supervised learning methods.
Keywords: Few-shot Learning · Self-supervised Learning · Contrastive Learning
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 773–784, 2023. https://doi.org/10.1007/978-981-99-4761-4_65
1 Introduction
In recent years, supervised deep learning has achieved great success in a number of computer vision tasks. These results require abundant manually annotated data. Unfortunately, the cost of data collection and manual annotation is prohibitive in practical applications. This has drawn extensive attention to few-shot learning (FSL) [1–3, 14]. Meta-learning is a popular solution for few-shot learning: it learns general prior knowledge from manually constructed diverse tasks and then applies that knowledge to the downstream FSL task. However, the task construction phase usually requires a large number of manually annotated labels. Several unsupervised meta-learning approaches attempt to construct synthetic tasks without labels by using data augmentation [5] or generating pseudo-labels [4], but their performance still falls far short of supervised meta-learning methods. In addition, pre-training-based methods [6, 7]
L. Chen and P. Qian
attempt to train a transferable feature extractor on data from similar domains and apply it to downstream few-shot tasks. They achieve excellent results with supervision but still depend on scarce labeled samples. We follow this paradigm but focus on learning reliable prior knowledge from unlabeled images only during pre-training.
Self-supervised learning has demonstrated significant potential for learning representations from unlabeled data, with contrastive learning showing particularly promising performance. This approach aims to bring similar instances closer in the embedding space while pushing dissimilar instances away from each other. The sampling strategies for positive and negative instances are crucial to the effectiveness of contrastive learning. Several studies [12, 13, 20] have focused on developing image augmentation techniques that generate positives from distinct views of a single image. Although these techniques have improved the performance of contrastive learning, the effect of the number of positive instances on representation learning requires further investigation. Many existing negative sampling strategies define negatives as all views of the remaining images in the batch without considering their semantic information. Some of these negatives may belong to the same class as the anchor, so they are wrongly pushed away instead of being pulled closer. Moreover, using a large number of negatives to learn high-quality representations from unlabeled data incurs significant computational overhead. To overcome this issue, several works [11, 14, 16] proposed to abandon negative samples and rely solely on positive pairs for self-supervised learning. However, this may lead to representation collapse. Recently, some works have investigated how negatives affect contrastive learning and found that most negatives do not contribute to it.
Only a small fraction of negatives, known as hard negatives, are necessary and sufficient for learning good representations, as they carry the most valuable information.
Based on the above motivation, we follow the paradigm of learning representations during pre-training and propose an effective method that leverages a self-supervised network similar to BYOL [11]. Our network employs a simple and efficient positive and negative sampling strategy. Specifically, we store the embeddings of negatives in a momentum queue and sort the queue in descending order of similarity to the anchor image. We then select the top N negatives, which contain the most similar semantic information, as positive expansions. From the negatives ranked immediately after these N, we choose M as proxies. Our sampling strategy for few-shot learning differs from existing approaches in two ways. First, the positives include views of other, semantically similar images in addition to views of the same image. Second, the negatives consist of a carefully chosen subset of the remaining images rather than views of all of them. This encourages each image to move closer to multiple views with similar semantic information in the embedding space while learning rich variability from the negatives that provide the most learning value, thus reducing computational overhead. We also incorporate this sampling strategy into a novel contrastive loss function that combines BYOL and MoCo.
In brief, the key contributions of this paper are as follows:
1. We propose an effective FSL method that leverages contrastive learning to learn rich representations during pre-training. Our simple and efficient positive and negative sampling strategy improves the quality of the representations.
2. We introduce a novel contrastive loss function combining BYOL and MoCo, and integrate it with our sampling strategy to further enhance the effectiveness of our unsupervised few-shot learning method.
3. Our work significantly narrows the performance gap between unsupervised and supervised FSL on two standard benchmarks.
4. Our method achieves performance comparable to or better than current SOTA SSL methods under a linear evaluation protocol.
2 Related Work
Few-shot learning has become a prominent area of research, and meta-learning is a popular approach to solving few-shot problems. Meta-learning typically trains a model on a series of episodic tasks, from which prior knowledge is extracted so that the model can quickly adapt to new tasks without training from scratch. For instance, MAML [1] optimizes the initial parameters of the network so that it can quickly adapt to new tasks through gradient descent. To address the challenges of learning in high-dimensional parameter spaces, [15] proposed a latent embedding optimization algorithm, LEO, which combines encoders and relation networks to reduce the dimensionality of the embedding space and improve generalization. To eliminate the dependence on data labels, some researchers have proposed unsupervised meta-learning methods. For instance, CACTUs [4] obtains pseudo-labels through unsupervised clustering of the training data and then runs a standard meta-learning algorithm. In contrast, UMTRA [5] uses original and augmented images as the support and query sets, respectively, to construct few-shot tasks. While most of these works focus on constructing meta-learning tasks without labels, their performance still falls short of current supervised FSL methods. Additionally, recent research has demonstrated that a simple yet strong baseline in FSL is to pre-train the embedding network with supervised learning and then fine-tune it on the target task. Instead of designing intricate meta-learning strategies, many works [6, 7] have concentrated on learning effective embeddings. Despite significant progress, their reliance on annotations limits their practical application. In this paper, we follow this paradigm with one difference: we use unlabeled data during the pre-training stage to learn representations based on contrastive learning.
Contrastive learning has emerged as a promising self-supervised learning approach in computer vision, with the goal of bringing similar instances closer together in the embedding space while pushing dissimilar instances apart. Pioneering contrastive learning works such as SimCLR [10] and MoCo [8, 9] treat augmented views of the same image as positives and draw negatives from a large batch (SimCLR) or a memory queue (MoCo). Recent studies have proposed abandoning negatives altogether. SwAV [16] and Barlow Twins [14] compare positives indirectly, by clustering and by maximizing the similarity between distorted views, respectively. Although these
methods have shown remarkable results, they may lead to representation collapse, in which the model maps all samples to a single point. Other recent works have analyzed the role of negatives in contrastive learning and proposed various techniques to improve their quality. [17] demonstrated empirically that not all negatives benefit contrastive learning. [18] employed a sophisticated strategy to eliminate false negatives. [22] observed that using very close samples is detrimental to contrastive learning. [19] proposed a "de-biasing" approach that corrects for the fact that some negatives may not be true negatives. In summary, not all negatives help contrastive learning, but discarding negatives may lead to representation collapse. Our work therefore reconsiders the sampling strategy for both positives and negatives to avoid collapse.
3 Method
In this section, we first introduce some notation, formalize the few-shot problem, and review InfoNCE [23], a loss function widely used in contrastive learning. Next, we introduce our approach to contrastive learning using positive expansions and negative proxies (PENP), a novel self-supervised method by which meaningful embeddings can be learned in the few-shot pre-training phase.
3.1 Problem Statement
FSL with Unsupervised Pre-training. In unsupervised few-shot image classification, the objective is to pre-train an embedding network on an unlabeled base dataset $D_{base}$ so that it can adapt efficiently to a few-shot task $T$ during the fine-tuning phase. The few-shot task $T$ is typically formulated as N-way K-shot, where N classes are selected from $D_{novel}$. From each class, K and Q images are sampled to form a support set $S_t = \{(x_k^n, y_k^n)\}_{n,k=1}^{N,K}$ and a query set $Q_t = \{(x_q^n, y_q^n)\}_{n,q=1}^{N,Q}$, respectively. Note that $D_{base}$ and $D_{novel}$ are completely disjoint, i.e., $D_{base} \cap D_{novel} = \emptyset$. In the fine-tuning phase, a classifier is trained on $S_t$ and then makes predictions on $Q_t$. The evaluation metric is the average prediction accuracy on the query sets across all sampled few-shot tasks. We describe our approach to pre-training the embedding network next.
Contrastive learning aims to learn representations that capture meaningful information by reducing the distance between positives and increasing the distance between negatives. InfoNCE is typically used as the loss function in contrastive learning:
$$\mathcal{L}_{u, v^+, \{v^-\}} = -\log \frac{\exp(d(u, v^+)/\tau)}{\sum_{v^-} \exp(d(u, v^-)/\tau) + \exp(d(u, v^+)/\tau)} \quad (1)$$
where $\tau$ is the softmax temperature, $u$ is the embedding of the input image $X$, $v^+$ is the positive of $u$, $v^-$ are the negatives of $u$, and $d$ denotes the distance metric. Most self-supervised methods treat another augmentation of the same image as the positive, so an anchor $u$ has only one positive $v^+$. Negatives, on the other hand, are views of different images, taken either from the current batch or from a memory bank. To exploit as many negatives as possible, prior
works such as SimCLR and MoCo either greatly increase the batch size or maintain a large memory bank. However, studies have indicated that increasing the memory or batch size yields diminishing returns in performance. Furthermore, most negatives are non-essential, and only a few hard negatives are crucial for contrastive learning. In the following section, we propose our approach to address this issue (Fig. 1).
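As a minimal sketch of Eq. (1) for a single anchor, the plain-Python function below computes InfoNCE with a numerically stable log-sum-exp. It assumes, as in standard InfoNCE, that the score $d$ is a similarity (here the cosine similarity of ℓ2-normalized embeddings), because a larger score must mean "more alike" for the loss to pull the positive closer. The names `normalize`, `cosine` and `info_nce` are hypothetical and are not taken from the authors' code.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def normalize(x: Vector) -> List[float]:
    """Scale x to unit L2 norm."""
    norm = math.sqrt(sum(a * a for a in x))
    return [a / norm for a in x]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity, used here as the score d in Eq. (1) (an assumption)."""
    return sum(x * y for x, y in zip(normalize(a), normalize(b)))


def info_nce(u: Vector, v_pos: Vector, v_negs: Sequence[Vector], tau: float) -> float:
    """InfoNCE loss of Eq. (1) for one anchor u, one positive and a set of negatives."""
    pos_logit = cosine(u, v_pos) / tau
    # The denominator of Eq. (1) sums over the negatives and the positive.
    logits = [pos_logit] + [cosine(u, v) / tau for v in v_negs]
    # Stable log-sum-exp: subtract the largest logit before exponentiating.
    top = max(logits)
    log_denominator = top + math.log(sum(math.exp(s - top) for s in logits))
    return log_denominator - pos_logit


# Anchor identical to its positive, two orthogonal negatives, tau = 1:
# the loss equals log(e + 2) - 1, about 0.551.
print(info_nce([1.0, 0.0], [1.0, 0.0], [[0.0, 1.0], [0.0, -1.0]], tau=1.0))
```

Moving a negative closer to the anchor raises its logit and therefore the loss, which is why hard negatives carry the most gradient signal.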
Fig. 1. Overview of the proposed method. All notation is detailed in Sect. 3.2.
3.2 Positive Expansions and Negative Proxies
Similar to BYOL, our approach employs an asymmetric architecture consisting of an online network and a target network that are initialized in the same way. The online network is composed of three modules: an encoder $f$, a projector $g$, and a predictor $p$, while the target network has the same structure as the online network but no predictor. The online network parameters $\theta$ are updated by back-propagating the loss, whereas the target network parameters $\xi$ are updated as a running average of the online network parameters. In addition, we introduce a momentum queue $Q$, which stores negatives that are sorted and updated during training. Note that after the pre-training stage, only the online encoder is used for downstream tasks.
Given an unlabeled input image $X$ from the dataset $D_{base}$, two views $x_1$ and $x_2$ are generated by randomly applying two image augmentations. The representation vectors
$$u = \frac{g(f(x_1))}{\|g(f(x_1))\|_2}, \qquad v = \frac{g(f(x_2))}{\|g(f(x_2))\|_2}$$
are obtained by the encoder and the projector of the online network and the target network, respectively. We further compute $z = p(u)$ with the predictor $p$ of the online network. Then, before adding $v$ to the momentum queue $Q$, we compute the nearest neighbors of the representation $u$ in $Q$:
$$NN(u, Q) = \arg\min_{q \in Q} d(u, q) \quad (2)$$
where $d(u, q)$ is the distance between two embeddings; in our experiments, we let $d(u, q) = \|u - q\|_2$. Next, we select $\{v_i\}_{i=1}^{N}$ and $\{v_j\}_{j=N+1}^{N+M}$ from $NN(u, Q)$, i.e., the first $N$ entries of the ranked queue and the $M$ entries that follow them. The samples $\{v_i\}_{i=1}^{N}$ are the closest to the anchor. In fact, it has been found that
these negatives are not only unnecessary but also harmful, because they may be false negatives that belong to the same class as the anchor. Inspired by [24], we propose to treat these samples, which are semantically highly correlated with the anchor, as expansions of the positive; together with the true positive $v^+$, they serve as moving targets for the anchor $u$. The samples $\{v_j\}_{j=N+1}^{N+M}$ are the selected hard negatives. Since they carry information worth learning, we use them as negative proxies that participate in the comparison in place of all negatives. We combine the positive expansions and negative proxies in a new contrastive loss, defined as follows:
$$\mathcal{L} = -\log \frac{\exp\left(\frac{1}{(N+1)\tau}\left(d(z, v^+) + \sum_{i=1}^{N} d(z, NN(u, Q)_i)\right)\right)}{\sum_{j=N+1}^{N+M} \exp\left(d(u, NN(u, Q)_j)/\tau\right)} \quad (3)$$
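The sampling step and Eq. (3) can be sketched for a single anchor as follows. This is a plain-Python illustration under stated assumptions, not the authors' implementation, and all names are hypothetical. The queue is ranked by the Euclidean distance of Eq. (2), while the scores inside Eq. (3) are taken to be cosine similarities (our assumption, so that minimizing the loss pulls $z$ toward the positives and pushes $u$ away from the proxies). For ℓ2-normalized vectors, $\|u - q\|_2^2 = 2 - 2\,u^\top q$, so ranking by ascending distance is the same as ranking by descending cosine similarity, which matches the description in the abstract.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def normalize(x: Vector) -> List[float]:
    norm = math.sqrt(sum(a * a for a in x))
    return [a / norm for a in x]


def cosine(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(normalize(a), normalize(b)))


def euclidean(a: Vector, b: Vector) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def split_queue(u: Vector, queue: Sequence[Vector], n: int, m: int) -> Tuple[list, list]:
    """Rank the queue by distance to u (Eq. (2)) and split it into the
    N positive expansions and the M negative proxies ranked right after them."""
    ranked = sorted(queue, key=lambda q: euclidean(u, q))
    return ranked[:n], ranked[n:n + m]


def penp_loss(u: Vector, z: Vector, v_pos: Vector, queue: Sequence[Vector],
              n: int, m: int, tau: float) -> float:
    """Eq. (3) for one anchor, with cosine similarity as the score d (assumption)."""
    expansions, proxies = split_queue(u, queue, n, m)
    if not proxies:
        raise ValueError("Eq. (3) needs at least one negative proxy")
    # Numerator: z is compared with the true positive and the expansions;
    # dividing by len(pos_scores) = N + 1 gives the 1 / ((N + 1) * tau) factor.
    pos_scores = [cosine(z, v_pos)] + [cosine(z, v) for v in expansions]
    pos_logit = sum(pos_scores) / (len(pos_scores) * tau)
    # Denominator: the anchor u (not z) is compared with the M proxies only.
    neg_logits = [cosine(u, v) / tau for v in proxies]
    top = max(neg_logits)
    log_denominator = top + math.log(sum(math.exp(s - top) for s in neg_logits))
    return log_denominator - pos_logit


queue = [[0.0, 1.0], [0.8, 0.6], [-1.0, 0.0], [0.6, 0.8]]
expansions, proxies = split_queue([1.0, 0.0], queue, n=1, m=2)
print(expansions, proxies)  # [[0.8, 0.6]] [[0.6, 0.8], [0.0, 1.0]]
```

Negatives ranked beyond position N + M (here `[-1.0, 0.0]`) are simply ignored, which is where the saving over comparing against the whole queue comes from.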
The target network parameters $\xi$ are updated as an exponential moving average of the previous target parameters and the current online parameters $\theta$:
$$\xi \leftarrow m \cdot \xi + (1 - m) \cdot \theta \quad (4)$$
where $m$ denotes the momentum coefficient; we update $\xi$ after each iteration. Algorithm 1 summarizes the process of our method. Notably, as suggested by previous works such as BYOL and SimSiam, meaningful representations can be learned by exploiting the asymmetry between the online and target branches. Therefore, we first use the predictor $p$ to compute $z = p(u)$ in the online network and then align $z$ with the positives of $u$. To ensure a more uniform distribution of negatives, we follow SimCLR and MoCo and compare the anchor $u$ itself, rather than $z$, with the negatives. In this way, our loss combines the positive attraction of BYOL with the negative repulsion of MoCo, which effectively mitigates representation collapse. In the special case $N = 0$, no positive expansion is used. When $M = 0$, our method is identical to standard BYOL. When $M$ equals the length of the momentum queue $Q$, all negatives are used for comparison, and our approach is equivalent to a variant of BYOL combined with MoCo.
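Eq. (4) is the usual exponential-moving-average (momentum) update used by BYOL and MoCo. The minimal sketch below assumes that parameters are stored as flat lists of floats keyed by name, a hypothetical layout chosen only to keep the example dependency-free; `ema_update` is likewise a hypothetical name. Only the online parameters $\theta$ receive gradients, and the target $\xi$ is updated from them without back-propagation.

```python
from typing import Dict, List

Params = Dict[str, List[float]]  # hypothetical layout: name -> flat parameter list


def ema_update(target: Params, online: Params, m: float) -> None:
    """Apply Eq. (4), xi <- m * xi + (1 - m) * theta, in place on the target."""
    if not 0.0 <= m <= 1.0:
        raise ValueError("momentum coefficient m must lie in [0, 1]")
    for name, theta in online.items():
        xi = target[name]
        for k, theta_k in enumerate(theta):
            xi[k] = m * xi[k] + (1.0 - m) * theta_k


target = {"w": [0.0, 1.0]}
online = {"w": [1.0, 1.0]}
ema_update(target, online, m=0.99)  # called once per iteration
print(target["w"])  # approximately [0.01, 1.0]
```

With m close to 1 (Sect. 4.2 uses 0.99), the target changes slowly and provides stable regression targets for the online branch; m = 0 would simply copy the online weights.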
Algorithm 1 Positive Expansions and Negative Proxies
Require: $D$, $B$, $N$, $M$, $m$
1: for each training epoch do
2:   Load a minibatch $X$ of $B$ samples from $D$
3:   Generate two randomly augmented views $x_1$ and $x_2$ of $X$
4:   Calculate the online embeddings $u = g(f(x_1))/\|g(f(x_1))\|_2$ and $z = p(u)$, and the target embeddings $v = g(f(x_2))/\|g(f(x_2))\|_2$
5:   Compute the nearest neighbors $NN(u, Q)$ from the momentum queue $Q$
6:   Split $NN(u, Q)$ into positive expansions and negative proxies according to $N$ and $M$
7:   Calculate the loss $\mathcal{L}$ with Eq. (3)
8:   Enqueue $v$ into $Q$ and dequeue the oldest entries
9:   Update the weights of the online network $\theta$ by SGD
10:  Update the weights of the target network $\xi$ with Eq. (4)
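Steps 5 and 8 of Algorithm 1 rely on a fixed-length FIFO queue, as in MoCo. The minimal sketch below assumes that the queue holds plain lists of floats and that "dequeue" discards the oldest entries; the class name `MomentumQueue` is hypothetical. Following Sect. 3.2, $NN(u, Q)$ is computed from the queue contents before the current batch's target embeddings $v$ are enqueued.

```python
from collections import deque
from typing import Deque, Iterable, List, Sequence


class MomentumQueue:
    """Fixed-length FIFO store of target embeddings (hypothetical sketch of step 8)."""

    def __init__(self, length: int) -> None:
        if length <= 0:
            raise ValueError("queue length must be positive")
        # A deque with maxlen silently drops the oldest items once it is full.
        self._items: Deque[List[float]] = deque(maxlen=length)

    def enqueue(self, batch: Iterable[Sequence[float]]) -> None:
        """Append a batch of target embeddings v; the oldest ones are dequeued."""
        for v in batch:
            self._items.append(list(v))

    def snapshot(self) -> List[List[float]]:
        """Current contents, used to compute NN(u, Q) before enqueueing."""
        return list(self._items)

    def __len__(self) -> int:
        return len(self._items)


queue = MomentumQueue(length=3)
queue.enqueue([[1.0], [2.0]])
neighbors_source = queue.snapshot()  # steps 5-6 use Q before this batch is added
queue.enqueue([[3.0], [4.0]])
print(queue.snapshot())  # [[2.0], [3.0], [4.0]]
```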
4 Experiments
4.1 Datasets
We test the effectiveness of our proposed method on two standard FSL datasets, miniImageNet [26] and tieredImageNet [27]. In addition, we compare our method with SOTA self-supervised learning methods under a linear evaluation protocol, suggesting that the proposed method is applicable to a broader range of downstream tasks than few-shot learning alone.
miniImageNet is a well-known benchmark for FSL algorithms. It is a subset of ImageNet [29] containing 100 classes with 600 images per class, for a total of 60,000 images. We follow the protocol in [28], using 64 classes for training, 16 for validation, and 20 for testing.
tieredImageNet was introduced by Ren et al. [27] and comprises 608 classes, each with 1000 images. The classes are organized into 34 higher-level categories, with 20 categories (351 classes) for training, 6 categories (97 classes) for validation, and 8 categories (160 classes) for testing. Because of the substantial semantic differences between the categories, learning from limited amounts of data is especially challenging on this dataset.
4.2 Implementation Details
During the pre-training stage, for a fair comparison with the baseline, we use ResNet18, which is widely adopted in FSL, as the encoder. Both the projector and the predictor are multilayer perceptrons (MLPs) with similar architectures. Specifically, each MLP consists of a linear layer followed by a batch normalization layer, a ReLU
activation function, and a second linear layer. We set the length of the momentum queue to 128k and the momentum to 0.99. The number of positive expansions N is set to 5, and the number of negative proxies M equals one-third of the length of the momentum queue; the effect of N and M on performance is analyzed in Sect. 4.4. We use an SGD optimizer with a weight decay of 0.0004 and a learning rate of 0.3 for tieredImageNet and 0.1 for miniImageNet. The temperature coefficient τ is 2.0. In the few-shot evaluation phase, we freeze the parameters of the pre-trained encoder and train a logistic regression classifier on 5-way 1-shot and 5-way 5-shot tasks sampled from the test classes. Each task is evaluated for classification accuracy using 15 query samples.
4.3 Comparison with the State-of-the-Art
In this subsection, we compare our method with supervised and unsupervised few-shot learning (FSL) approaches on the miniImageNet and tieredImageNet benchmarks. On miniImageNet, all compared methods use ResNet18 as the backbone, while on tieredImageNet the compared methods use ResNet12. The results are presented in Table 1 and Table 2. Overall, our method outperforms the unsupervised approaches and is comparable to the SOTA supervised methods.
Table 1. Comparison with prior works on miniImageNet
Setting        Method             1-shot          5-shot
Supervised     MAML [1]           49.61 ± 0.92    65.72 ± 0.77
               MatchingNet [26]   52.91 ± 0.88    68.88 ± 0.69
               ProtoNet [2]       54.16 ± 0.82    73.68 ± 0.65
               Baseline++ [6]     51.87 ± 0.77    75.68 ± 0.63
               TADAM [3]          58.50 ± 0.30    76.70 ± 0.30
Unsupervised   CACTUs [4]         39.90 ± 0.74    53.97 ± 0.70
               UMTRA [5]          43.09 ± 0.35    53.42 ± 0.31
               SimCLR [10]        55.76 ± 0.88    75.59 ± 0.69
               MoCo [9]           54.19 ± 0.93    73.04 ± 0.61
               Ours               58.60 ± 0.43    78.20 ± 0.56
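The evaluation protocol of Sect. 4.2 (frozen encoder, 5-way 1-shot and 5-shot tasks drawn from the test classes, 15 query samples) relies on an episode sampler such as the one sketched below. Assumptions: the 15 query samples are drawn per class, a common convention that the text does not state explicitly; image identifiers stand in for images; and the logistic regression classifier trained on the frozen support features is not shown. The function name `sample_episode` is hypothetical.

```python
import random
from typing import Dict, List, Sequence, Tuple

Example = Tuple[str, int]  # (image identifier, episode-local label)


def sample_episode(images_by_class: Dict[str, Sequence[str]], n_way: int,
                   k_shot: int, n_query: int,
                   rng: random.Random) -> Tuple[List[Example], List[Example]]:
    """Sample one N-way K-shot task with n_query query images per class (assumption)."""
    classes = rng.sample(sorted(images_by_class), n_way)
    support: List[Example] = []
    query: List[Example] = []
    for label, cls in enumerate(classes):
        # Draw K + Q distinct images so that support and query never overlap.
        picks = rng.sample(list(images_by_class[cls]), k_shot + n_query)
        support += [(img, label) for img in picks[:k_shot]]
        query += [(img, label) for img in picks[k_shot:]]
    return support, query


# Toy stand-in for the 20 miniImageNet test classes with 600 images each.
test_classes = {f"class_{c}": [f"img_{c}_{i}" for i in range(600)] for c in range(20)}
support, query = sample_episode(test_classes, n_way=5, k_shot=1, n_query=15,
                                rng=random.Random(0))
print(len(support), len(query))  # 5 75
```

Labels are re-indexed from 0 to N − 1 inside each episode, so the classifier never sees the global class identities; accuracy is then averaged over many such episodes.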
On miniImageNet, our method achieves gains of up to 18.7 and 24.8 percentage points over the unsupervised meta-learning methods (CACTUs and UMTRA) in the 1-shot and 5-shot settings, respectively. This indicates that learning high-quality embeddings is a more practical and efficient approach than designing complex unsupervised meta-learning tasks. Moreover, our method achieves performance comparable or even superior to supervised FSL methods without using any labels. The large improvements over the unsupervised approaches demonstrate the efficacy of our method. Additionally, on the tieredImageNet benchmark, our
method, while not the best performer compared with the supervised SOTA approaches, substantially narrows the gap to them relative to the other unsupervised methods. We conjecture that the remaining gap to the supervised SOTA arises because the information from a larger number of labeled classes partly mitigates the overfitting of supervised methods on tieredImageNet.
Table 2. Comparison with prior works on tieredImageNet
Setting        Method             1-shot          5-shot
Supervised     MetaOptNet [30]    65.99 ± 0.72    81.56 ± 0.53
               LEO [15]           66.33 ± 0.05    81.44 ± 0.09
               BML [31]           68.99 ± 0.50    85.49 ± 0.34
Unsupervised   CC + Rot [32]      62.93 ± 0.45    79.87 ± 0.33
               CACTUs [4]         56.47 ± n/a     71.09 ± n/a
               SimCLR [10]        63.38 ± 0.42    79.17 ± 0.34
               SimSiam [25]       64.05 ± 0.40    81.40 ± 0.30
               Ours               65.41 ± 0.22    81.45 ± 0.52
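The gaps discussed in Sect. 4.3 can be checked against Tables 1 and 2 with a few lines of arithmetic. The sketch below simply recomputes differences in mean accuracy (in percentage points) from the table values; it ignores the reported confidence intervals, so it says nothing about statistical significance. The variable and function names are illustrative only.

```python
from typing import Dict, Tuple

Acc = Tuple[float, float]  # (1-shot, 5-shot) mean accuracy in %

# Values copied from Table 1 (miniImageNet) and Table 2 (tieredImageNet).
OURS_MINI: Acc = (58.60, 78.20)
UNSUP_META_MINI: Dict[str, Acc] = {"CACTUs": (39.90, 53.97), "UMTRA": (43.09, 53.42)}
OURS_TIERED: Acc = (65.41, 81.45)
BEST_SUPERVISED_TIERED: Acc = (68.99, 85.49)  # BML


def gap(a: Acc, b: Acc) -> Acc:
    """Difference a - b in percentage points, rounded to two decimals."""
    return (round(a[0] - b[0], 2), round(a[1] - b[1], 2))


for name, acc in UNSUP_META_MINI.items():
    print(name, gap(OURS_MINI, acc))
# CACTUs (18.7, 24.23)
# UMTRA (15.51, 24.78)
print("BML", gap(BEST_SUPERVISED_TIERED, OURS_TIERED))  # BML (3.58, 4.04)
```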
Table 3. Accuracy under the linear evaluation protocol
Method   CIFAR-10   CIFAR-100   STL-10
MoCo     79.33      52.34       81.11
BYOL     81.47      49.68       80.09
Ours     82.73      50.02       83.45
To verify whether the embedding network learns good representations as expected, we adopt the linear evaluation protocol commonly used in self-supervised learning. First, we pre-train the feature extractor on unlabeled training data. Then, we train a linear classifier on top of the frozen backbone using the labeled testing data. Unlike in FSL, under the linear evaluation protocol the classes seen in the test phase are also seen in the training phase. Experiments are conducted on CIFAR-10, CIFAR-100, and STL-10, and we compare our method with BYOL and MoCo. The results in Table 3 show that our method outperforms BYOL on all three datasets and MoCo on CIFAR-10 and STL-10, while MoCo remains ahead on CIFAR-100. This suggests that our method may be applicable to a variety of tasks beyond FSL.
4.4 Ablation Study
We investigate the impact of positive expansions and negative proxies on our method. To speed up training and obtain results quickly, we train the embedding network for 100 epochs and the classifier for 40 epochs in our ablation experiments. The remaining settings are consistent with those described in Sect. 4.2.
Table 4. Accuracy with different numbers of positive expansions
         N = 0   N = 2   N = 5   N = 10
1-shot   43.70   44.89   45.45   42.15
5-shot   59.33   61.12   60.83   59.17
Effect of Positive Expansions. We examine the effect of the number of positive expansions (N) on performance without negative proxies (i.e., M = L, where L is the momentum queue length). We select N values of 0, 2, 5, and 10 and take N = 0, which is equivalent to combining BYOL and MoCo, as the baseline. The results in Table 4 show that moderate numbers of positive expansions improve performance over the baseline: N = 5 gives the best 1-shot accuracy and N = 2 the best 5-shot accuracy, whereas N = 10 falls slightly below the baseline in both settings. Consequently, N defaults to 5 in this paper when positive expansions are used, unless stated otherwise (Table 5).
Table 5. Accuracy with different numbers of negative proxies
                 M = L/4   M = L/3   M = L/2   M = L
N = 0   1-shot   43.82     46.61     46.95     43.70
        5-shot   61.59     61.88     62.31     59.33
N = 5   1-shot   43.88     45.21     47.98     45.45
        5-shot   63.63     64.92     64.72     60.83
Effect of Negative Proxies. Similarly, we examine the impact of negative proxies with (N = 5) and without (N = 0) positive expansions. We vary the number of negative proxies M and let L denote the momentum queue length. We choose M values of L/4, L/3, L/2, and L; M = L means that negative proxies are not used. Overall, using negative proxies improves accuracy in almost all settings (the exception being N = 5, M = L/4 in the 1-shot case), with the largest gains obtained when selecting L/3 or L/2 negatives as proxies. Adding positive expansions further improves 5-shot accuracy in every setting and yields the best 1-shot accuracy (47.98 at M = L/2). These results suggest that hard negatives play a critical role in representation learning and that the hardest negatives can be treated as positives because of their semantic similarity to the anchor, rather than being forced apart from it.
5 Conclusion
We propose a self-supervised pre-training network for few-shot learning that learns good embeddings from unlabeled samples by using a sampling strategy of positive expansions and negative proxies during the pre-training stage, and that generalizes well to downstream few-shot tasks. In addition, we improve the loss function by combining BYOL with MoCo. Experimental results demonstrate that our method significantly narrows the performance gap between supervised and unsupervised FSL on two popular benchmarks, and that it achieves competitive performance under the self-supervised linear evaluation protocol. Although our positive and negative sampling strategy significantly improves the performance of contrastive learning, we believe that the impact of the backbone architecture on performance deserves further study; in particular, we plan to apply the now widely used Vision Transformer (ViT) in future work. In addition, although our method generalizes well to downstream tasks, further research is needed to address the cross-domain few-shot problem.
References
1. Finn, C., Abbeel, P., Levine, S.: Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (2017)
2. Snell, J., Swersky, K., Zemel, R.: Prototypical Networks for Few-shot Learning (2017)
3. Oreshkin, B.N., López, P., Lacoste, A.: TADAM: task dependent adaptive metric for improved few-shot learning (2018)
4. Hsu, K., Levine, S., Finn, C.: Unsupervised Learning via Meta-Learning (2018)
5. Khodadadeh, S., Bölöni, L., Shah, M.: Unsupervised Meta-Learning For Few-Shot Image Classification (2018)
6. Chen, W.-Y., Liu, Y.-C., Kira, Z., Wang, Y.-C., Huang, J.-B.: A Closer Look at Few-shot Classification (2019)
7. Tian, Y., Wang, Y., Krishnan, D., Tenenbaum, J.B., Isola, P.: Rethinking few-shot image classification: a good embedding is all you need? In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.M. (eds.) ECCV 2020. LNCS, vol. 12359, pp. 266–282. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58568-6_16
8. He, K., Fan, H., Wu, Y., Xie, S., Girshick, R.: Momentum contrast for unsupervised visual representation learning. In: 2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (2020)
9. Chen, X., Fan, H., Girshick, R., He, K.: Improved Baselines with Momentum Contrastive Learning (2020)
10. Chen, T., Kornblith, S., Norouzi, M., Hinton, G.: A simple framework for contrastive learning of visual representations. In: ICML (2020)
11. Grill, J.-B., et al.: Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning (2020)
12. Zheng, M., et al.: ReSSL: Relational Self-Supervised Learning with Weak Augmentation (2021)
13. Zhang, H., Cisse, M., Dauphin, Y., Lopez-Paz, D.: mixup: Beyond Empirical Risk Minimization
14. Zbontar, J., Jing, L., Misra, I., LeCun, Y., Deny, S.: Barlow Twins: Self-Supervised Learning via Redundancy Reduction (2021)
15. Rusu, A., et al.: Meta-Learning with Latent Embedding Optimization (2018)
L. Chen and P. Qian
16. Caron, M., Misra, I., Mairal, J., Goyal, P., Bojanowski, P., Joulin, A.: Unsupervised Learning of Visual Features by Contrasting Cluster Assignments (2020)
17. Cai, T., Frankle, J., Schwab, D.J., Morcos, A.S.: Are all negatives created equal in contrastive instance discrimination (2021)
18. Huynh, T., Kornblith, S., Walter, M.R., Maire, M., Khademi, M.: Boosting contrastive self-supervised learning with false negative cancellation. In: 2022 IEEE/CVF Winter Conference on Applications of Computer Vision (WACV) (2022)
19. Chuang, C.-Y., Robinson, J., Lin, Y.-C., Torralba, A., Jegelka, S.: Debiased Contrastive Learning (2020)
20. Wu, M., Zhuang, C., Mosse, M., Yamins, D., Goodman, N.: On Mutual Information in Contrastive Learning for Visual Representations (2020)
21. Peng, X., Wang, K., Zhu, Z., You, Y.: Crafting Better Contrastive Views for Siamese Representation Learning
22. Harwood, B., Vijay Kumar, B.G., Carneiro, G., Reid, I., Drummond, T.: Smart Mining for Deep Metric Learning (2017)
23. Gutmann, M.U., Hyvärinen, A.: Noise-contrastive estimation: a new estimation principle for unnormalized statistical models (2010)
24. Khosla, P., et al.: Supervised Contrastive Learning (2020)
25. Chen, X., He, K.: Exploring simple Siamese representation learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 15750–15758 (2021)
26. Vinyals, O., Blundell, C., Lillicrap, T., Kavukcuoglu, K., Wierstra, D.: Matching Networks for One Shot Learning (2016)
27. Triantafillou, E., et al.: Meta-Learning for Semi-Supervised Few-Shot Classification (2018)
28. Ravi, S., Larochelle, H.: Optimization as a Model for Few-Shot Learning (2017)
29. Russakovsky, O., et al.: ImageNet large scale visual recognition challenge. IJCV 115, 211–252 (2015)
30. Lee, K., Maji, S., Ravichandran, A., Soatto, S.: Meta-learning with differentiable convex optimization. In: 2019 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (2020)
31. Zhou, Z., Qiu, X., Xie, J., Wu, J., Zhang, C.: Binocular mutual learning for improving few-shot classification. In: 2021 IEEE/CVF International Conference on Computer Vision (ICCV) (2022)
32. Gidaris, S., Bursuc, A., Komodakis, N., Perez, P.P., Cord, M.: Boosting few-shot visual learning with self-supervision. In: 2019 IEEE/CVF International Conference on Computer Vision (ICCV) (2020)
GAN for Blind Image Deblurring Based on Latent Image Extraction and Blur Kernel Estimation Xiaowei Huang and Pengjiang Qian(B) School of Artificial Intelligence and Computer Science, Jiangnan University, Wuxi 214122, Jiangsu, China [email protected]
Abstract. We propose a GAN for single-image blind deblurring based on latent image extraction and blur kernel estimation. We introduce the fast Fourier transform (FFT) into latent image extraction because it converts the image from the spatial domain to the frequency domain; this helps the convolutional neural network learn frequency-domain information and yields sharper images. We also exploit the cross-scale reproducibility of natural images to extract the blur kernel: regularization constraints on the kernel improve estimation precision, and the kernel of the image is obtained by fusing local kernels estimated over numerous iterations. Meanwhile, a multi-scale discriminator combining RSGAN and PatchGAN is used: RSGAN serves as a global discriminator with more accurate classification criteria, while PatchGAN serves as a local discriminator that assesses the accuracy of the local blur kernel. Experiments demonstrate the effectiveness of our method. Keywords: Image blind deblurring · Generative adversarial network · Latent image extraction · Blur kernel estimation
1 Introduction Image deblurring is a fundamental and active problem in low-level computer vision that has drawn sustained attention from the image processing and computer vision communities. Given a blurred image, whose blur may be caused by multiple factors, single-image deblurring attempts to restore a sharp version of it by reducing or removing the blurring effects. Blur can have many causes, including defocus, camera shake, and fast target motion. Mathematically, the blurring process may be represented as follows:

$$I_b = I_s \otimes M_{kernel} + M_{noise} \quad (1)$$

where $I_b$, $I_s$, $M_{kernel}$ and $M_{noise}$ are the blurred image, the sharp image, the blur kernel, and additive noise, respectively, and $\otimes$ denotes the convolution operator. © The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 785–796, 2023. https://doi.org/10.1007/978-981-99-4761-4_66
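To make Eq. (1) concrete, the following minimal NumPy sketch synthesizes a blurred image from a sharp image, a blur kernel and additive noise. It is only an illustration under our own assumptions (a grayscale image, zero padding at the borders, i.i.d. Gaussian noise); the function name `synthesize_blur` is hypothetical and is not part of the authors' code.

```python
import numpy as np


def synthesize_blur(sharp, kernel, noise_std=0.0, seed=0):
    """Return I_b = I_s (conv) M_kernel + M_noise for a 2-D grayscale image.

    Zero padding keeps the output the same size as the input; the kernel is
    flipped so that this is a true convolution rather than a correlation.
    """
    sharp = np.asarray(sharp, dtype=float)
    kernel = np.asarray(kernel, dtype=float)
    kh, kw = kernel.shape
    top, left = kh // 2, kw // 2
    padded = np.pad(sharp, ((top, kh - 1 - top), (left, kw - 1 - left)))
    flipped = kernel[::-1, ::-1]
    h, w = sharp.shape
    blurred = np.zeros((h, w))
    # Accumulate shifted copies of the image, each weighted by one kernel tap.
    for i in range(kh):
        for j in range(kw):
            blurred += flipped[i, j] * padded[i:i + h, j:j + w]
    # M_noise: zero-mean Gaussian noise (an assumption of this sketch).
    rng = np.random.default_rng(seed)
    return blurred + rng.normal(0.0, noise_std, size=(h, w))
```

With a centred delta kernel and `noise_std=0.0` the image is returned unchanged, while a normalized box kernel leaves the interior of a constant image unchanged and darkens the zero-padded borders.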
X. Huang and P. Qian
There are two types of image deblurring methods: non-blind deblurring and blind deblurring. The former assumes that the blur kernel is known, whereas the latter assumes that it is unknown. Before the advent of deep learning, image deblurring was traditionally treated as an inverse filtering problem, and non-blind deblurring used conventional deconvolution methods such as Richardson–Lucy or Wiener deconvolution to recover sharp images. Blind deblurring must restore both the sharp image and the blur kernel. Because Eq. (1) shows that this problem is highly ill-posed, various constraints are introduced to regularize the solution. Although non-deep-learning algorithms perform well in some circumstances, they do not work well in complicated, or even ordinary natural, environments.

Advances in deep learning have transformed computer vision. Several domains, including image classification and object detection, have made substantial progress, and image deblurring is no exception. For blind deblurring, significant progress has been made both in basic layer blocks (convolutional, recurrent, residual, dense, and attention layers, etc.) and in network structures (DAE, GAN, cascade networks, multi-scale networks, etc.). Deep learning-based blind deblurring takes two distinct forms: one obtains sharp images directly through neural network learning, and the other uses a neural network to estimate the blur kernel and then obtains the sharp image by deconvolution. This paper combines the main ideas of the two kinds of methods and proposes a blind deblurring method that obtains the image content and the blur kernel at the same time and then obtains the sharp image by deconvolving the former with the latter. The main contributions of this paper are as follows:
1. A residual structure with the fast Fourier transform is used to transform the image from the spatial domain to the frequency domain, enabling the neural network to learn information from both the high- and low-frequency regions of the image at the same time.
2. The cross-scale reproducibility of natural images is used to extract the blur kernel.
3. The extracted latent image is deconvolved with the pre-estimated blur kernel to obtain a sharper image.
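For the non-blind case mentioned above, where the kernel is known, Wiener deconvolution has a closed form in the frequency domain. The sketch below is illustrative only: it assumes periodic (circular) image boundaries and replaces the noise-to-signal power ratio with a single constant `k`; the helper names `kernel_to_otf`, `circular_blur` and `wiener_deconvolve` are hypothetical.

```python
import numpy as np


def kernel_to_otf(kernel, shape):
    """Zero-pad a small kernel to `shape` and move its centre to index (0, 0)."""
    kh, kw = kernel.shape
    padded = np.zeros(shape)
    padded[:kh, :kw] = kernel
    padded = np.roll(padded, (-(kh // 2), -(kw // 2)), axis=(0, 1))
    return np.fft.fft2(padded)


def circular_blur(image, kernel):
    """Blur with periodic boundaries, i.e. Eq. (1) without the noise term."""
    otf = kernel_to_otf(kernel, image.shape)
    return np.real(np.fft.ifft2(np.fft.fft2(image) * otf))


def wiener_deconvolve(blurred, kernel, k=1e-3):
    """Estimate the sharp image as F = conj(H) / (|H|^2 + k) * G."""
    otf = kernel_to_otf(kernel, blurred.shape)
    wiener = np.conj(otf) / (np.abs(otf) ** 2 + k)
    return np.real(np.fft.ifft2(wiener * np.fft.fft2(blurred)))
```

For example, with the kernel `[[0, 0.1, 0], [0.1, 0.6, 0.1], [0, 0.1, 0]]`, whose frequency response never drops below 0.2, a noise-free blurred image is recovered almost exactly when `k` is tiny. A larger `k` suppresses noise amplification at frequencies where the kernel response is small, at the cost of some residual blur; this trade-off is one reason blind methods add priors instead of inverting the kernel directly.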
2 Related Works In this section, we briefly review work related to the main modules of our method.
2.1 Blind Deblurring Blind image deblurring has advanced significantly in recent years, thanks to the rapid growth of deep learning. Many methods estimate blur kernels and latent images through various network structures to obtain sharp images. MSCNN [7] employs a multi-scale convolutional neural network to restore images blurred by complicated factors (Seungjun Nah et al. 2017). SRN [8] proposed a scale-recurrent network for image deblurring, which also opened a series of studies based on this idea (Xin Tao et al. 2018). SelfDeblur [9] uses two generative networks to obtain the blur kernel and the latent image respectively; an unconstrained neural network is trained for optimization, and a sharp image is obtained by blind deconvolution (Dongwei Ren et al. 2019). ASNet [10] divides the deblurring network into two parts: an analysis network that estimates the blur and a synthesis network that deblurs images using the estimated kernel (Adam Kaufman et al. 2020). DeFMO [11] embeds an image of a blurred object into a latent space to disentangle the background and render the appearance, outputs the appearance and position of the object in a series of sub-frames, and achieves temporal super-resolution (Denys Rozumnyi et al. 2021). These approaches fall into two types. One type trains a network that outputs the sharp image directly and improves generalization through various structures; the other extracts the latent image and the blur kernel from the blurred image and obtains the sharp image through deconvolution or other means.
2.2 GAN-Based Deblurring A GAN [16] is composed of two models, a discriminator D and a generator G, which play a minimax game against each other; training through this game yields the final model. The generator learns the features of the original images, aiming to capture the true data distribution and to generate samples close enough to real samples to deceive the discriminator. The discriminator's job is to determine whether an input sample is real or generated. The hope is that the generator will produce increasingly realistic images, making it harder for the discriminator to distinguish fake images from real ones. In recent years, GANs have made remarkable progress in image deblurring. In this task, the generator is trained to learn the features of the original image, with the aim of capturing the true data distribution and generating a sharper image.
The discriminator's job is to distinguish between the sharp image produced by the generator and the true sharp image. DeblurGAN-v2 [1] is an end-to-end method based on a relativistic conditional GAN with a two-scale discriminator; it was the first to introduce feature pyramids into deblurring and allows the backbone network to be replaced flexibly (Orest Kupyn et al., 2019). UCSDBN [12] proposed an end-to-end deblurring architecture trained without supervision on unpaired data. The model is a GAN that uses an adversarial loss to learn priors of the sharp-image domain and maps blurred images to their sharp equivalents (T. M. Nimisha et al. 2018). DBRBGAN [13] proposed a new training scheme with two modules, BGAN and DBGAN: BGAN learns to blur sharp images from a dataset of unpaired sharp and blurred images and is then used to guide DBGAN in deblurring such images (Kaihao Zhang et al. 2020). MSG-GAN [14] allows gradients to flow from the discriminator to the generator at multiple scales. This addresses a known GAN problem: when the supports of the real and fake distributions do not overlap sufficiently, the gradients passed from the discriminator to the generator become uninformative (Animesh Karnewar et al., 2020). Owing to the complementary characteristics of GANs and CNNs, combining the two yields more satisfactory image deblurring results. Our work likewise proposes a CNN-based GAN that extracts the latent image and the blur kernel of a blurred image simultaneously. The generator extracts the latent image and the blur kernel at the same time and obtains the sharp image through deconvolution. The discriminator combines RSGAN [2] and PatchGAN [3]. To improve accuracy, a regularization term that encourages more precise blur kernels is added to the loss function.
3 Methodology In this section, we introduce a generative adversarial network that simultaneously acquires the latent image of the blurred image and the blur kernel. The generator consists of three parts: a CNN with a fast Fourier transform residual structure that obtains the latent image, a CNN based on the cross-scale reproducibility of natural images that estimates the blur kernel, and a deconvolution network that generates the sharp image. Following DeblurGAN-v2, the discriminator is a dual-scale design combining a global and a local discriminator. Figure 1 depicts the network structure:
Fig. 1. Network structure.
3.1 LIE (Latent Image Extraction) The backbone of the LIE module is MobileNet-v3-small [5], chosen for the following reasons. First, the MobileNet-v3-small structure was found by platform-aware neural architecture search, so its structure is well designed, and it is more compact than MobileNet-v3-large, which was generated by the same method. Second, MobileNet-v3 includes the SE (squeeze-and-excitation) module, which brings channel attention to the network: it pools each channel globally and uses the result to reweight the channels, improving network performance. Third, MobileNet-v3 has an inverted residual structure, which first expands the features, then extracts features, and finally compresses them. In this paper, the fast Fourier transform (FFT) [6] is introduced into this module to improve latent image extraction. Figure 2 depicts the structure of the improved FFT-Resblock. The theoretical reasons [4] for introducing the FFT are as follows. First, ResBlock cannot learn global information effectively because of its small receptive field, yet in image deblurring, learning global information is important for capturing the difference between sharp and blurred images. Second, experiments have shown that ResBlock learns the high-frequency regions of natural images well but performs unsatisfactorily in the low-frequency regions. Meanwhile, frequency-domain analysis shows that a sharp image contains less low-frequency and more high-frequency content than a blurred image. Third, the Fourier transform maps images from the spatial domain to the frequency domain, where 1 × 1 convolutions let the network learn both the high- and low-frequency information of the image.
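The frequency-domain argument above (a blurred image carries relatively less high-frequency content than its sharp counterpart) can be checked numerically. The sketch below measures the share of spectral energy outside a centred low-frequency disc; the cutoff radius is an arbitrary choice of ours rather than a value from the paper, and the 3 × 3 box blur is only a stand-in for a real blur. Both function names are hypothetical.

```python
import numpy as np


def high_frequency_fraction(image, cutoff=0.25):
    """Share of spectral energy at radii above `cutoff` * (half the shorter side)."""
    spectrum = np.fft.fftshift(np.fft.fft2(image))  # zero frequency moved to the centre
    power = np.abs(spectrum) ** 2
    h, w = image.shape
    yy, xx = np.mgrid[0:h, 0:w]
    radius = np.hypot(yy - h // 2, xx - w // 2)
    high = power[radius > cutoff * min(h, w) / 2].sum()
    return float(high / power.sum())


def box_blur(image):
    """3 x 3 mean filter with periodic boundaries, built from shifted copies."""
    shifted = [np.roll(image, (dy, dx), axis=(0, 1))
               for dy in (-1, 0, 1) for dx in (-1, 0, 1)]
    return sum(shifted) / 9.0
```

Applied to a random texture and its box-blurred version, the blurred image yields a clearly smaller high-frequency fraction, which is the kind of spectral difference the FFT branch is meant to exploit.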
Fig. 2. Structure diagram of the FFT-Res module. The module consists of an FFT branch, a convolution branch and an identity branch.
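A minimal NumPy sketch of a three-branch FFT-Res block is given below, loosely following the design of Mao et al. [6]: the FFT branch applies a real 2-D FFT, stacks the real and imaginary parts as channels, mixes them with two 1 × 1 convolutions around a ReLU, and transforms the result back. To keep the example short, the convolution branch is simplified to two 1 × 1 convolutions instead of 3 × 3 convolutions, and the weights are placeholders; this is an assumption-laden sketch, not the authors' implementation.

```python
import numpy as np


def conv1x1(x, weight):
    """1 x 1 convolution: mix channels of x (C_in, H, W) with weight (C_out, C_in)."""
    return np.einsum("oc,chw->ohw", weight, x)


def relu(x):
    return np.maximum(x, 0.0)


def fft_branch(x, w1, w2):
    """Frequency-domain branch: rFFT -> 1x1 conv -> ReLU -> 1x1 conv -> inverse rFFT."""
    c, h, w = x.shape
    freq = np.fft.rfft2(x, axes=(-2, -1))               # (C, H, W//2 + 1), complex
    z = np.concatenate([freq.real, freq.imag], axis=0)  # (2C, H, W//2 + 1), real
    z = conv1x1(relu(conv1x1(z, w1)), w2)               # still 2C channels
    freq_out = z[:c] + 1j * z[c:]
    return np.fft.irfft2(freq_out, s=(h, w), axes=(-2, -1))


def spatial_branch(x, v1, v2):
    """Simplified stand-in for the 3x3 convolution branch (1x1 convs here)."""
    return conv1x1(relu(conv1x1(x, v1)), v2)


def fft_res_block(x, params):
    """Identity + convolution branch + FFT branch, summed element-wise."""
    return (x + spatial_branch(x, params["v1"], params["v2"])
            + fft_branch(x, params["w1"], params["w2"]))
```

Because every frequency coefficient depends on all pixels, the 1 × 1 convolutions in the FFT branch act globally on the image, which is how this branch compensates for the small receptive field of the convolution branch. With all weights set to zero, the block reduces to the identity branch.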
3.2 FKE (Blur Kernel Extraction) The theoretical basis of this module is the cross-scale reproducibility of natural images, i.e., the recurrence of small-scale crops within a single image. Put simply, for a natural-scene image, a crop of a given region taken at the original size and one taken after downsampling cover different fields of view, but their overall pixel distributions are consistent. Using this consistency, we can obtain an approximate local blur kernel for each crop and, after several iterations, generate several local blur kernels for the original image. The final blur kernel is formed by fusing these local blur kernels.
As illustrated in Fig. 3, the FKE module is divided into three parts: image cropping, downsampling, and blur kernel update. The downsampling part is a linear neural network. The module crops and downsamples the input blurred image to generate training crops. Meanwhile, real crops of the same size as the training crops are cut from the original image according to a probability matrix. Since the training crops are downsampled, their field of view is smaller than that of the real crops; in this way, both match the application scenario of cross-scale reproducibility of natural images. Next, the blur kernel update part judges the patch distributions of the two.
Fig. 3. Schematic diagram of the FKE module. The input image is processed by three modules (data cropping, downsampling, and kernel extraction and update) to generate the local estimated kernel. The cropped image is sent directly to the kernel extraction and update module to update the estimated kernel. All local blur kernels are extracted and fused.
The blur kernel update part is a 7-layer linear network in which bicubic linear interpolation is carried out. The process is iterated until the blur kernel update part judges the downsampled training crop to come from the original image. Once this is achieved, the bicubic linear interpolation matrix can be regarded as the estimated local blur kernel of the crop. Since a kernel generated in this way is only part of the entire blur kernel, we treat each partial blur kernel as a single blur kernel and impose the following constraints on it, as shown in Eq. (2):

$$R = \alpha L_{norm} + \beta L_{\text{bound-to-zero}} + \gamma L_{centralize} \quad (2)$$
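where α = 0.6, β = 0.5, γ = 1, and:

$$L_{norm} = \Big|1 - \sum_{i,j} m_{i,j}\Big| \quad (3)$$

$$L_{\text{bound-to-zero}} = \sum_{i,j} \big|m_{i,j} \cdot n_{i,j}\big| \quad (4)$$

$$L_{centralize} = \left\| (x_{cen}, y_{cen}) - \frac{\sum_{i,j} m_{i,j} \cdot (i,j)}{\sum_{i,j} m_{i,j}} \right\|_2 \quad (5)$$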
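The three kernel constraints of Eqs. (2)–(5) can be sketched in NumPy as follows, using the weights α = 0.6, β = 0.5 and γ = 1 given above. The paper only states that the mask n grows exponentially with distance from the centre; the concrete choice n = exp(d) − 1 below, which is zero at the centre, is our own assumption, as is the function name `kernel_constraints`.

```python
import numpy as np


def kernel_constraints(m, alpha=0.6, beta=0.5, gamma=1.0):
    """R = alpha * L_norm + beta * L_bound-to-zero + gamma * L_centralize (Eq. 2)."""
    m = np.asarray(m, dtype=float)
    h, w = m.shape
    ii, jj = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    centre = np.array([(h - 1) / 2.0, (w - 1) / 2.0])
    # Hypothetical mask n: zero at the centre, growing exponentially outwards.
    dist = np.hypot(ii - centre[0], jj - centre[1])
    n = np.exp(dist) - 1.0
    # Eq. (3): the kernel should sum to 1.
    l_norm = abs(1.0 - m.sum())
    # Eq. (4): values near the boundary should tend to 0.
    l_bound = np.sum(np.abs(m * n))
    # Eq. (5): the centre of mass should coincide with the kernel centre.
    centroid = np.array([np.sum(m * ii), np.sum(m * jj)]) / m.sum()
    l_centralize = np.linalg.norm(centre - centroid)
    return float(alpha * l_norm + beta * l_bound + gamma * l_centralize)
```

A delta kernel at the centre satisfies all three constraints (R = 0), whereas the same delta moved to a corner is penalized by both the boundary and the centralization terms.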
Since the bicubic linear interpolation matrix produced by the blur kernel update module also represents the similarity between the training crop and the real crop, the constraint in Eq. (3) drives the sum of this matrix toward 1. Since the training crop is regarded as a local part of the sharp image, the local kernel is treated as a single blur kernel; therefore, Eq. (4) constrains the boundary values of the blur kernel toward 0, where n is a constant weight mask whose weights increase exponentially with distance from the center. Equation (5) is a very important regularization constraint for blur kernel estimation: if the centroid is not at the center of the kernel, the estimate becomes ambiguous. We therefore constrain the centroid of the local blur kernel to coincide with its center. After local blur kernels have been extracted over many iterations, they are fused to obtain the final blur kernel.
3.3 Discriminator Considering the diversity of the causes of natural image blur and the uncertainty of the blur kernel, we use a global and a local discriminator, composed of RAGAN and PatchGAN respectively. The global discriminator uses RAGAN. We choose RAGAN because its discriminator not only judges the fake samples produced by the generator as fake but also takes into account that real samples should be judged as real. It also accounts for the fact that, as the generator is optimized and its fake samples become increasingly realistic, the probability that real samples are judged real decreases. Based on this prior, instead of estimating the probability that an input sample is real or fake, RAGAN estimates the probability that one sample is more realistic than another. The local discriminator uses PatchGAN, for the following reasons. First, the receptive field of PatchGAN corresponds to a small area of the original blurred image, which matches the crop-wise pre-estimation of local blur kernels in our kernel estimation, so PatchGAN fits our needs. Second, PatchGAN is designed for high-resolution images with fine detail, which makes it the most suitable choice.
3.4 Loss Functions The loss function has three components. The first is the deblurring loss. Charbonnier loss [17] is a variant of the L1 loss; by introducing a small constant, it is made differentiable in the neighborhood of 0.
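Here we compute the Charbonnier distance between the image produced by the generator and the sharp image, as shown in Eq. (6):

$$L_{deblur} = \frac{1}{N}\sum_{i=1}^{N}\sqrt{x_i^{2} + \alpha^{2}} \quad (6)$$

where $x_i$ is the difference between the i-th pixel of the generated image and that of the sharp image, and α is a small constant, set to α = 1e−3 in the experiments. The second part is the adversarial loss computed by the discriminator D, as shown in Eq. (7):

$$L_D^{RPGAN} = \frac{1}{2}\mathbb{E}_{x\sim P_{data}(x)}\big[(D(x)-1)^2\big] + \frac{1}{2}\mathbb{E}_{x\sim P_{patch}(x)}\big[(D(x)-1)^2\big]$$

$$L_G^{RPGAN} = \frac{1}{2}\mathbb{E}_{x\sim P_{data}(x)}\big[(D(x)-1)^2\big] + \frac{1}{2}\mathbb{E}_{z\sim P_{patch}(z)}\big[(D(G(z))-1)^2\big] \quad (7)$$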
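Section 3.3 describes a relativistic discriminator that estimates whether real samples are more realistic than generated ones. To illustrate that idea, the sketch below implements relativistic average least-squares losses, following the RaGAN-LS formulation adopted in DeblurGAN-v2 [1]. This is a standard, swapped-in formulation shown for clarity; the ½ weighting is our choice, and it may differ in detail from the authors' Eq. (7). Because it averages over all elements, the same function can score a global discriminator (one value per image) or a PatchGAN-style map (one value per patch). The name `ragan_ls_losses` is hypothetical.

```python
import numpy as np


def ragan_ls_losses(d_real, d_fake):
    """Relativistic average least-squares discriminator and generator losses.

    d_real, d_fake: raw discriminator scores for real (sharp) and generated
    (deblurred) inputs; any shape, e.g. (batch,) or (batch, H_patch, W_patch).
    """
    d_real = np.asarray(d_real, dtype=float)
    d_fake = np.asarray(d_fake, dtype=float)
    mean_real, mean_fake = d_real.mean(), d_fake.mean()
    # Discriminator: real should score 1 above the average fake,
    # and fake should score 1 below the average real.
    loss_d = 0.5 * (np.mean((d_real - mean_fake - 1.0) ** 2)
                    + np.mean((d_fake - mean_real + 1.0) ** 2))
    # Generator: the same targets with the roles of real and fake swapped.
    loss_g = 0.5 * (np.mean((d_fake - mean_real - 1.0) ** 2)
                    + np.mean((d_real - mean_fake + 1.0) ** 2))
    return float(loss_d), float(loss_g)
```

When the discriminator perfectly separates real (score 1) from fake (score 0) inputs, its loss is 0 and the generator loss is 4; this holds for per-image scores and for patch maps alike.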
The third part is the regularization term, which consists of five terms: the FFT regularization term, the bicubic regularization term, and the three blur kernel constraint terms. The blur kernel constraint terms are given in Eq. (2), and the complete regularization term is shown in Eq. (8):

$$R = \alpha L_{norm} + \beta L_{\text{bound-to-zero}} + \gamma L_{centralize} + \delta L_{FFT} + \eta L_{bicubic} \quad (8)$$

where $L_{FFT} = \frac{1}{N}\sum_{1}^{N}|x - y|$, with x the generated image and y the original image, and $L_{bicubic} = \frac{1}{NM}\sum_{i=1}^{N}\sum_{j=1}^{M}\left(x_{i,j} - y_{i,j}\right)^{2}$, with x the estimated blur kernel and y the given blur kernel. In the experiments, δ = 0.7 and η = 5. The joint loss is therefore given by Eq. (9):

$$L = \lambda_1 L_{deblur} + \lambda_2 L_{adv} + R \quad (9)$$

where λ1 = 0.7 and λ2 = 0.4.
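Putting Eqs. (6), (8) and (9) together, the sketch below computes the Charbonnier deblurring loss, the two additional regularization terms and the joint loss with the weights reported above. It is a hedged illustration: the loss values passed in are placeholders (the adversarial term would come from the discriminator), reading L_FFT as a mean absolute error and L_bicubic as a mean squared error follows our reconstruction of the printed formulas, and all function names are hypothetical.

```python
import numpy as np


def charbonnier_loss(pred, target, alpha=1e-3):
    """Eq. (6): mean of sqrt(x^2 + alpha^2), with x the per-pixel difference."""
    diff = np.asarray(pred, dtype=float) - np.asarray(target, dtype=float)
    return float(np.mean(np.sqrt(diff ** 2 + alpha ** 2)))


def l_fft(generated, original):
    """L_FFT as printed: mean absolute difference between x and y."""
    return float(np.mean(np.abs(np.asarray(generated) - np.asarray(original))))


def l_bicubic(kernel_est, kernel_given):
    """L_bicubic: mean squared difference between estimated and given kernels."""
    return float(np.mean((np.asarray(kernel_est) - np.asarray(kernel_given)) ** 2))


def regularization(l_norm, l_bound, l_centralize, l_fft_val, l_bicubic_val,
                   alpha=0.6, beta=0.5, gamma=1.0, delta=0.7, eta=5.0):
    """Eq. (8): weighted sum of the five regularization terms."""
    return (alpha * l_norm + beta * l_bound + gamma * l_centralize
            + delta * l_fft_val + eta * l_bicubic_val)


def joint_loss(l_deblur, l_adv, reg, lambda1=0.7, lambda2=0.4):
    """Eq. (9): L = lambda1 * L_deblur + lambda2 * L_adv + R."""
    return lambda1 * l_deblur + lambda2 * l_adv + reg
```

Note that for identical images the Charbonnier loss equals α rather than 0; the small offset is the price of keeping the loss differentiable at zero error.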
4 Experiments 4.1 Implementation Details Our experiments run on Windows 11 with an 11th Gen Intel(R) Core(TM) i7-11800H CPU @ 2.30 GHz and a single NVIDIA GeForce RTX 3070 Laptop GPU; all models are implemented in PyTorch. We use the GOPRO dataset, divided into 22 groups, each containing roughly 100 training images and 100 test images. We adopt the commonly used PSNR and SSIM as evaluation metrics and report values averaged over multiple experiments. The FKE module has three hyperparameters: the scale factor applied to each image, the crop size, and the number of cropping iterations. Experiments show that the network works best with scale_factor = 0.5, crop_size = 64 and max_iteration = 1000. The generative adversarial network has three further hyperparameters: lr is the learning rate, b1 is the decay rate of the first-order momentum, and b2 is the decay rate of the second-order (squared-gradient) momentum. Experiments show that the network works best with lr = 0.0008, b1 = 0.5 and b2 = 0.999.
4.2 Selection of the Backbone Network for FDCE We evaluated three backbone networks, MobileNet-v2, MobileNet-v3-large and MobileNet-v3-small [5], and applied the FFT to the third, which greatly improved performance. Considering accuracy, memory and running time, we finally chose MobileNet-v3-small. In our experiments, the fourth test group consistently gave better results, while the seventh group showed the opposite. Further experiments showed that the two groups differ greatly in the frequency-domain content of their images: group 4 has large low-frequency regions with small variation and small high-frequency regions with large variation, whereas group 7 is the reverse.
Because convolutional neural networks are more sensitive to high-frequency changes, the training results of the fourth group are better. Therefore, we introduce a residual module with FFT into the backbone network to convert the image from the spatial domain to the frequency domain, so that the network can better learn information across the full frequency range of the image. As Table 1 shows, PSNR and SSIM improve across the board on both the training and test sets.

Table 1. Selection of backbone network and experimental results.

| Metric | MobileNet-v2 [21] | MobileNet-v3-large | MobileNet-v3-small | FFT-MobileNet-v3-small |
|---|---|---|---|---|
| PSNR-avg | 20.5871 | 20.6699 | 20.6687 | 21.8784 |
| PSNR-max | 23.8796 | 24.1202 | 24.0875 | 25.2601 |
| PSNR-min | 17.0860 | 17.0217 | 17.0441 | 19.0245 |
| SSIM-avg | 0.7434 | 0.7544 | 0.7524 | 0.8023 |
| SSIM-max | 0.9073 | 0.9165 | 0.9157 | 0.9235 |
| SSIM-min | 0.6076 | 0.6112 | 0.6101 | 0.7214 |

4.3 Contrast Experiment In the comparative experiment, we compare our method with several deblurring models that our work builds on and with several recently published deblurring models; Table 2 displays the results. As Table 2 shows, our method has a clear advantage over the models it builds on, KernelGAN and DeblurGAN-v2. Compared with advanced methods from the past four years, it outperforms most models and performs comparably to the latest MIMO-UNet and HINet.

Table 2. Contrast experiment results.

| Method | PSNR | SSIM |
|---|---|---|
| DeblurGANv2 [1] | 24.0716 | 0.9073 |
| KernelGAN [4] | 24.0442 | 0.8994 |
| OID [19] | 25.6241 | 0.9224 |
| ASNet [10] | 26.0124 | 0.9312 |
| Ours | 26.5210 | 0.9375 |
| MIMO-UNet [18] | 26.6421 | 0.9402 |
| HINet [20] | 27.021 | 0.9421 |
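For reference, the PSNR values reported in Tables 1–3 are conventionally computed from the mean squared error between the restored and reference images. The sketch below assumes intensities scaled to [0, 1] (so the peak value is 1.0); the paper does not give its evaluation code, and SSIM, which relies on local windowed statistics, is omitted here.

```python
import numpy as np


def psnr(estimate, reference, peak=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(peak^2 / MSE)."""
    estimate = np.asarray(estimate, dtype=float)
    reference = np.asarray(reference, dtype=float)
    mse = np.mean((estimate - reference) ** 2)
    if mse == 0.0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(peak ** 2 / mse))
```

As a quick sanity check, a uniform error of 0.1 on a [0, 1] image gives an MSE of 0.01 and therefore a PSNR of 20 dB; each further 10 dB corresponds to a tenfold reduction in MSE.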
Table 3. The results of the ablation experiment.

| Module | PSNR/SSIM (avg) | PSNR/SSIM (max) | PSNR/SSIM (min) |
|---|---|---|---|
| MobileNet-v3-small | 20.6691/0.7544 | 24.1201/0.9187 | 17.0217/0.611 |
| MobileNet-v3-small + FFT | 21.8784/0.8023 | 25.2601/0.9235 | 19.0245/0.7214 |
| FKE + Deconv | 21.6214/0.7925 | 23.0127/0.8574 | 17.0427/0.6412 |
| FDCE + FKE | 25.0174/0.9204 | 26.7745/0.9401 | 19.4251/0.7642 |
| Generator + RAGAN | 25.4125/0.9278 | 26.7412/0.9387 | 19.4452/0.7701 |
| Generator + PatchGAN | 25.5714/0.9288 | 26.8412/0.9398 | 20.0214/0.7902 |
| Gen + RAGAN + PatchGAN | 26.5210/0.9375 | 27.2451/0.9541 | 22.1472/0.8754 |
4.4 Ablation Experiment Our ablation experiment was divided into seven steps. The first step uses MobileNet-v3 as the network backbone. The second step introduces the residual structure with FFT. The third step adds the FKE module for blur kernel extraction. The fourth step combines the latent image extraction network and the blur kernel estimation network. The fifth step replaces the discriminator with RAGAN. The sixth step incorporates PatchGAN into the discriminator. The seventh step combines the generator with the multi-scale discriminator composed of RAGAN and PatchGAN. We evaluate each step with two metrics, PSNR and SSIM, each reported as the mean, maximum and minimum value. Table 3 summarizes the results.
4.5 Display of Experimental Results After training, we tested the network on images from the test set. Figure 4 shows the results on a subset of the test images: the left image is the blurred input, the middle image is the sharp ground truth, and the right image is our result. Our results are highly similar to the sharp images.
Fig. 4. Part of the experimental results.
5 Conclusion In this work, we implement an improved image deblurring model that uses a GAN to extract the latent image and the blur kernel. The model first learns from the information in the whole frequency domain of the image: by introducing a residual structure with FFT, it extracts information well both from high-frequency regions, where color changes sharply, and from low-frequency regions, where color changes gently. It then extracts the blur kernel based on the cross-scale reproducibility of natural images and deblurs the image by deconvolution. High-quality results are obtained through adversarial training against global and local dual-scale discriminators.
References
1. Kupyn, O., Martyniuk, T., Wu, J., Wang, Z.: DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 8878–8887 (2019)
2. Yadav, S., Chen, C., Ross, A.: Synthesizing Iris Images Using RaSGAN with Application in Presentation Attack Detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops (2019)
3. Pan, Y., Pi, D., Chen, J., Meng, H.: FDPPGAN: remote sensing image fusion based on deep perceptual patchGAN. Neural Comput. Appl. 33, 9589–9605 (2021)
4. Bell-Kligler, S., Shocher, A., Irani, M.: Blind Super-Resolution Kernel Estimation Using an Internal-GAN. Advances in Neural Information Processing Syst. 32 (2019)
5. Kavyashree, P.S., El-Sharkawy, M.: Compressed MobileNet v3: A Light Weight Variant for Resource-Constrained Platforms. In: 2021 IEEE 11th Annual Computing and Communication Workshop and Conference (CCWC), pp. 0104–0107. IEEE (2021)
6. Mao, X., Liu, Y., Shen, W., Li, Q., Wang, Y.: Deep Residual Fourier Transformation for Single Image Deblurring. arXiv preprint arXiv:2111.11745 (2021)
7. Nah, S., Hyun Kim, T., Mu Lee, K.: Deep Multi-Scale Convolutional Neural Network for Dynamic Scene Deblurring. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3883–3891 (2017)
8. Tao, X., Gao, H., Shen, X., Wang, J., Jia, J.: Scale-Recurrent Network for Deep Image Deblurring. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 8174–8182 (2018)
9. Ren, D., Zhang, K., Wang, Q., Hu, Q., Zuo, W.: Neural Blind Deconvolution Using Deep Priors. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3341–3350 (2020)
10. Kaufman, A., Fattal, R.: Deblurring Using Analysis-Synthesis Networks Pair. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 5811–5820 (2020)
11. Rozumnyi, D., Oswald, M.R., Ferrari, V., Matas, J., Pollefeys, M.: DeFMO: Deblurring and Shape Recovery of Fast Moving Objects. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3456–3465 (2021)
12. Nimisha, T.M., Sunil, K., Rajagopalan, A.N.: Unsupervised class-specific deblurring. In: Proceedings of the European Conference on Computer Vision (ECCV), pp. 353–369 (2018)
13. Zhang, K., et al.: Deblurring by realistic blurring. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2737–2746 (2020)
14. Karnewar, A., Wang, O.: MSG-GAN: multi-scale gradients for generative adversarial networks. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 7799–7808 (2020)
15. Xu, L., Ren, J.S., Liu, C., Jia, J.: Deep convolutional neural network for image deconvolution. Advances in Neural Information Processing Syst. 27 (2014)
16. Creswell, A., White, T., Dumoulin, V., Arulkumaran, K., Sengupta, B., Bharath, A.A.: Generative adversarial networks: an overview. IEEE Signal Process. Mag. 35(1), 53–65 (2018)
17. Gajera, B., Kapil, S.R., Ziaei, D., Mangalagiri, J., Siegel, E., Chapman, D.: CT-scan denoising using a charbonnier loss generative adversarial network. IEEE Access 9, 84093–84109 (2021)
18. Shocher, A., Cohen, N., Irani, M.: “Zero-Shot” super-resolution using deep internal learning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3118–3126 (2018)
19. Cho, S.J., Ji, S.W., Hong, J.P., Jung, S.W., Ko, S.J.: Rethinking coarse-to-fine approach in single image deblurring. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4641–4650 (2021)
20. Chen, L., Lu, X., Zhang, J., Chu, X., Chen, C.: HINet: half instance normalization network for image restoration. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 182–192 (2021)
21. Sandler, M., Howard, A., Zhu, M., Zhmoginov, A., Chen, L.C.: MobileNetV2: inverted residuals and linear bottlenecks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4510–4520 (2018)
GLUformer: An Efficient Transformer Network for Image Denoising Chenghao Xue and Pengjiang Qian(B) School of Artificial Intelligence and Computer Science, Jiangnan University, Wuxi 214122, Jiangsu, China [email protected]
Abstract. In this article, we introduce GLUformer, an efficient Transformer-based image denoising network with a hierarchical encoder-decoder architecture. To further reduce computational cost, we propose a new Global-Local Window Transformer block (GLWin). Unlike conventional global self-attention, GLWin computes self-attention within non-overlapping windows, which effectively reduces the computational complexity on the feature map. It also combines local and global window partitions, which effectively captures the contextual relationships among image patches. In addition, we add a dynamic position bias (DPB) to the GLWin block, which removes the restriction on input image size. With these two designs, GLUformer achieves excellent image denoising performance, and it obtains good results on the SIDD dataset with a relatively small amount of computation.
Keywords: GLWin · Dynamic position bias · Transformer
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 797–807, 2023. https://doi.org/10.1007/978-981-99-4761-4_67
1 Introduction
Advances in technology have steadily increased smartphone ownership, and more and more people use mobile phones or cameras to record everything around them. However, photos taken with mobile phones are more affected by various factors than those taken with professional cameras, and the demand for removing unwanted degradations (such as noise and blur) from photos keeps growing. This is the task of image restoration. Most recent restoration methods are based on Convolutional Neural Networks (CNNs) [1]. Compared with traditional restoration methods, the translation invariance and feature-extraction ability of convolution make CNNs more effective. However, the limited receptive field of convolution prevents it from capturing long-range dependencies between pixels. Transformers [2] address this limitation: they are good at modeling long-range dependencies in sequences, and they process all elements of a sequence in parallel, which makes them efficient on long sequences. Their attention mechanism also enables them to capture complex relationships between elements of image sequences. However, because the computational complexity of self-attention (SA) grows quadratically with spatial resolution, recent work divides input images into non-overlapping 48 × 48 patches and computes SA independently on each patch [3].
To this end, we propose an efficient image denoising network named GLUformer. Its overall structure is a U-shaped network built from GLWin blocks, which are one of our contributions. A GLWin block captures local attention while also attending to global context, relating each pixel to the semantic information around it, which is very important for image restoration. Because it computes self-attention within non-overlapping windows, it has a lower computational cost and converges faster than conventional multi-head self-attention. Recent work has adopted similar designs for image classification [4]. Second, in natural language processing the relative position between tokens is often more important than their absolute position; in language modeling, for example, the relationship between a word and its neighboring words is crucial for predicting the next word. Similarly, the relative position between patches in an image is also very important. Previous work showed that by incorporating a relative position bias, Swin Transformer can capture more complex and subtle relationships between tokens in the input sequence [4]. However, the size of this bias is fixed, which limits the input image size. We therefore propose a dynamic position bias that removes this restriction. With these two designs, GLUformer achieves excellent image restoration results.
On SIDD, a classic image denoising dataset, the best result of our model is 0.11 dB lower than that of Uformer [5], previously among the most advanced image restoration networks. In exchange, GLUformer needs 0.42 G fewer GMACs, each training epoch is 998.17 s shorter, and the total training time is 7181.17 s shorter, which greatly reduces both training time and computational cost. In summary, our contributions are as follows. We propose GLUformer, a U-shaped network for image denoising built from GLWin blocks, which are both effective and efficient. We adopt a dynamic position bias to improve the quality of image restoration. Extensive experiments demonstrate the efficiency of GLUformer.
2 Related Work
Image restoration aims to recover clean, complete images from damaged or noisy observations. In recent years, the success of CNNs has shown that neural networks clearly outperform traditional restoration methods [1]. Among CNNs, U-shaped encoder-decoder networks are very effective for image restoration [6]: their multi-level structure captures multi-scale information, and their skip connections make it easier to learn residual signals, which greatly helps restoration. The Transformer was originally developed for sequence processing in natural language tasks [7]. Unlike a CNN, its self-attention mechanism is naturally suited to capturing long-range dependencies in images.
Vision Transformer (ViT) pioneered the application of Transformers to image processing and achieved substantial improvements over CNNs on a range of image processing tasks. Despite these excellent results, ViT is computationally expensive because of its multi-head self-attention [8]. To avoid the quadratic complexity of global self-attention, Swin Transformer was proposed [4]. It divides images into smaller non-overlapping windows and processes them hierarchically, which lets the model handle large images efficiently without running out of memory. A key innovation of Swin Transformer is the shifted window: the window size stays fixed, but the window partition is shifted in alternating consecutive blocks, so information flows across window boundaries and the model captures information from a larger portion of the image while keeping each window small. This design allowed Swin Transformer to outperform earlier vision Transformers. As positional encoding has developed, relative position encoding has increasingly been applied to Transformers, and many networks have replaced the original absolute position embedding (APE) with a relative position bias (RPB) [4], because relative position encoding is highly learnable and robust. Positional information is needed because self-attention by itself is permutation-equivariant: shuffling the input embeddings only shuffles the outputs in the same way, so the model cannot exploit where an embedding is located, even when that position carries important information. Many positional representations have been proposed to address this problem, and relative position encoding, owing to its strong learnability, is currently among the most effective. However, a relative position bias is defined over a fixed range of positions. Recently, CrossFormer [8] showed that a dynamic position bias can solve this problem.
This idea is similar to ours. Although Transformers have been widely explored in computer vision, they have seen far less use in low-level vision. Recently, Uformer [5] was proposed and applied to image deblurring and deraining, achieving good results. However, its computation time is still long. We therefore design a new Transformer-based U-shaped network that combines global-window and local-window attention, and it achieves good results.
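The permutation argument above can be checked numerically. The following is a minimal single-head sketch in NumPy, not code from the paper: `self_attention` is a hypothetical helper with random projection weights, and `bias` stands in for any fixed, position-indexed bias B.

```python
import numpy as np

def self_attention(x, wq, wk, wv, bias=None):
    """Single-head scaled dot-product self-attention over tokens x of shape (n, d)."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if bias is not None:
        scores = scores + bias                          # bias indexed by token position
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
n, d = 6, 4
x = rng.standard_normal((n, d))
wq, wk, wv = (rng.standard_normal((d, d)) for _ in range(3))
perm = np.roll(np.arange(n), 1)                         # a non-trivial shuffle of the tokens

# Without positional information, shuffling the tokens only shuffles the outputs.
print(np.allclose(self_attention(x[perm], wq, wk, wv),
                  self_attention(x, wq, wk, wv)[perm]))  # True

# With a position-indexed bias, the output depends on where each token sits.
bias = rng.standard_normal((n, n))
print(np.allclose(self_attention(x[perm], wq, wk, wv, bias),
                  self_attention(x, wq, wk, wv, bias)[perm]))  # False
```

The second check fails because the bias stays attached to positions while the tokens move, which is exactly the information a position encoding is meant to inject.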
3 Method
In this section, we first describe the overall hierarchical structure of GLUformer and then explain in detail GLWin, the basic module from which the entire network is built.
3.1 Overall Pipeline
As shown in Fig. 1(a), GLUformer is a U-shaped hierarchical encoder-decoder network with skip connections between the encoder and the decoder [6]. Specifically, a noisy real image $I \in \mathbb{R}^{3 \times H \times W}$ is fed into an input layer consisting of a 4 × 4 convolution followed by LeakyReLU, which extracts features $X_0 \in \mathbb{R}^{C \times H \times W}$. The features $X_0$ then pass through K encoder stages (K = 4 by default). Each encoder stage consists of a GLWin block and a downsampling layer. The GLWin block captures long-range dependencies through self-attention; by combining self-attention over local non-overlapping windows with self-attention over global non-overlapping windows, it links context better while reducing computational cost. Unlike previous work, we replace the relative position bias (RPB) [4] with a dynamic position bias (DPB), so that the network is no longer limited by the size of the input image. In the downsampling layer, we first reshape the flattened features into a 2D spatial feature map and then downsample it with a 4 × 4 convolution with stride 2. After the last encoder stage we add a bottleneck consisting of GLWin blocks; because the network is hierarchical and the feature map is smallest at this point, the Transformer can capture longer-range dependencies. Feature reconstruction follows, and the decoder likewise has 4 stages. Each decoder stage consists of an upsampling layer and a GLWin block like those in the encoder. For upsampling we use a 2 × 2 transposed convolution with stride 2, which halves the number of channels and doubles the spatial size of the feature maps. The features from the corresponding encoder stage and the upsampled features are concatenated through the skip connection and fed into the GLWin block. After the 4 decoder stages, the flattened features are reshaped into 2D feature maps, and a 3 × 3 convolutional layer produces a residual image $R \in \mathbb{R}^{3 \times H \times W}$. The restored image is obtained by adding this residual to the noisy input, $\hat{I} = I + R$. We train GLUformer with the Charbonnier loss [9]:
$$\ell(I', \hat{I}) = \sqrt{\lVert I' - \hat{I} \rVert^2 + \epsilon^2} \qquad (1)$$
where $I'$ is the ground-truth (clean) image and $\epsilon = 10^{-3}$ is a constant.
Fig. 1. (a) The overall structure of GLUformer
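To make the data flow concrete, the sketch below traces feature shapes through the encoder, bottleneck, and decoder and evaluates Eq. (1) on a residual output. It is an illustration under assumptions, not the authors' code: the text does not state how channels change on downsampling, so we assume, as in Uformer [5], that each downsampling halves H and W while doubling the channels, and that each decoder stage works on the concatenation of upsampled and skip features. The names `trace_shapes` and `charbonnier_loss`, and the sizes used, are hypothetical.

```python
import numpy as np

def trace_shapes(c, h, w, k=4):
    """(channels, height, width) after each encoder stage, the bottleneck,
    and each decoder stage (after concatenating the skip features).

    Assumption (Uformer-style): each stride-2 downsampling halves H and W and
    doubles the channels; each stride-2 transposed convolution reverses this.
    """
    encoder = [(c * 2**i, h // 2**i, w // 2**i) for i in range(k)]
    bottleneck = (c * 2**k, h // 2**k, w // 2**k)
    # Decoder stages run from the coarsest level back to full resolution; the
    # skip connection concatenates encoder features, doubling the channel count.
    decoder = [(2 * ch, hh, ww) for ch, hh, ww in reversed(encoder)]
    return encoder, bottleneck, decoder

def charbonnier_loss(restored, target, eps=1e-3):
    """Eq. (1): sqrt(||target - restored||^2 + eps^2) over the whole image."""
    diff = np.asarray(target, dtype=np.float64) - np.asarray(restored, dtype=np.float64)
    return float(np.sqrt(np.sum(diff**2) + eps**2))

encoder, bottleneck, decoder = trace_shapes(c=32, h=128, w=128)
print(encoder)     # [(32, 128, 128), (64, 64, 64), (128, 32, 32), (256, 16, 16)]
print(bottleneck)  # (512, 8, 8)
print(decoder)     # [(512, 16, 16), (256, 32, 32), (128, 64, 64), (64, 128, 128)]

rng = np.random.default_rng(0)
noisy = rng.random((3, 8, 8))                   # stand-in for the noisy input I
residual = rng.normal(0.0, 0.01, (3, 8, 8))     # stand-in for the predicted residual R
restored = noisy + residual                     # I_hat = I + R
clean = rng.random((3, 8, 8))                   # stand-in for the ground truth I'
loss = charbonnier_loss(restored, clean)
```

With a perfect prediction the loss reduces to ε. Many public implementations instead average the per-pixel term sqrt(d² + ε²); the sketch follows Eq. (1) as written, with ε applied once to the whole-image norm.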
3.2 GLWin Transformer Block
Applying Transformers in practice raises several problems. In ViT [10], for example, self-attention for each pixel of the feature map is computed against all other pixels, so the computational cost grows quadratically with the number of pixels. To address this issue, we propose the GLWin block, shown in Fig. 2(a). It consists of a Local-windows Multi-Head Self-Attention (LW-MSA) module, a Global-windows Multi-Head Self-Attention (GW-MSA) module, and a Feedforward Network with Local Enhancement (LeFF) module [11]. As shown in Fig. 2, LW-MSA and GW-MSA modules alternate between consecutive blocks. In both LW-MSA and GW-MSA, we use a dynamic position bias (DPB) to obtain the positional representation. Each block uses residual connections, which improve the flow of image information. Below we describe the LW-MSA, GW-MSA, LeFF, and DPB modules in turn.
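The block composition described above can be sketched as follows. This is a minimal NumPy sketch under our own assumptions, not the authors' implementation: LW-MSA is used in even-indexed blocks and GW-MSA in odd-indexed ones, each sub-module is preceded by LayerNorm (pre-norm, as in Swin Transformer and Uformer), and `lw_msa`, `gw_msa`, and `leff` are hypothetical placeholders for the real modules.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the channel axis, without learnable scale and shift."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def attention_schedule(depth):
    """Attention type of each block in a stage: LW-MSA and GW-MSA alternate."""
    return ["LW-MSA" if i % 2 == 0 else "GW-MSA" for i in range(depth)]

def glwin_stage(x, attn_fns, leff_fn, depth):
    """Apply `depth` GLWin blocks to tokens x of shape (num_tokens, channels)."""
    for kind in attention_schedule(depth):
        x = x + attn_fns[kind](layer_norm(x))  # residual connection around LW-/GW-MSA
        x = x + leff_fn(layer_norm(x))         # residual connection around LeFF
    return x

# Hypothetical stand-ins that only scale their input, to show the data flow.
def lw_msa(t):
    return 0.1 * t

def gw_msa(t):
    return 0.2 * t

def leff(t):
    return 0.05 * t

tokens = np.random.default_rng(0).standard_normal((64, 32))
out = glwin_stage(tokens, {"LW-MSA": lw_msa, "GW-MSA": gw_msa}, leff, depth=4)
print(attention_schedule(4))  # ['LW-MSA', 'GW-MSA', 'LW-MSA', 'GW-MSA']
print(out.shape)              # (64, 32)
```

Because every sub-module sits inside a residual connection, a block whose modules output zeros leaves its input unchanged, which is what lets information pass easily through deep stacks of blocks.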
Fig. 2. (a) The structure of the GLWin block (b) LW-MSA (c) GW-MSA
Local-windows Multi-Head Self-Attention (LW-MSA). As mentioned above, ViT [10] computes multi-head self-attention over the entire feature map. For a feature map of height H, width W, and C channels, its computational cost is
$$\Omega(\mathrm{MSA}) = 4HWC^2 + 2(HW)^2C.$$
To reduce this cost, we divide the feature map into non-overlapping windows and compute self-attention within each window separately [4], as shown in Fig. 2(b). Specifically, the input $X \in \mathbb{R}^{C \times H \times W}$ is divided into windows of size M × M (M = 2), which gives $\frac{H}{M} \times \frac{W}{M}$ windows in total, and self-attention is computed independently in each window. This greatly reduces the amount of computation. Following previous work [4], the cost of each window is $4(MC)^2 + 2M^4C$, so the total cost of LW-MSA is
$$\Omega(\mathrm{LW\text{-}MSA}) = \left(4(MC)^2 + 2M^4C\right) \times \frac{H}{M} \times \frac{W}{M} = 4HWC^2 + 2M^2HWC \qquad (2)$$
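The two cost formulas can be compared directly. The pure-Python sketch below only evaluates the expressions given above; the feature-map size (128 × 128, C = 32) and the window size M = 8 are hypothetical values chosen for illustration, not settings reported in this paper.

```python
def msa_cost(h, w, c):
    """Cost of global multi-head self-attention: 4HWC^2 + 2(HW)^2 C."""
    return 4 * h * w * c**2 + 2 * (h * w) ** 2 * c

def lw_msa_cost(h, w, c, m):
    """Cost of LW-MSA: per-window cost times the number of M x M windows (Eq. 2)."""
    if h % m or w % m:
        raise ValueError("window size must divide the feature map")
    per_window = 4 * (m * c) ** 2 + 2 * m**4 * c
    return per_window * (h // m) * (w // m)

h, w, c, m = 128, 128, 32, 8
# The window-by-window sum equals the closed form 4HWC^2 + 2M^2 HWC.
assert lw_msa_cost(h, w, c, m) == 4 * h * w * c**2 + 2 * m**2 * h * w * c
print(msa_cost(h, w, c) / lw_msa_cost(h, w, c, m))  # 128.5
```

The second term of the global cost grows with (HW)², so the ratio keeps increasing for larger feature maps, whereas LW-MSA stays linear in HW for a fixed window size.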
Compared with ViT [10], LW-MSA replaces the quadratic term $2(HW)^2C$ with $2M^2HWC$, which grows only linearly with the number of pixels for a fixed window size M, so its computational cost is much lower.
Global-windows Multi-Head Self-Attention (GW-MSA). Although LW-MSA greatly reduces computation, it does not exchange information between different windows. To solve this problem, we propose Global-windows Multi-Head Self-Attention (GW-MSA), which better captures long-range dependencies in the image. We partition the windows in a way similar to atrous convolution [12]: as shown in Fig. 2(c), pixels of the same color belong to the same window. This sampling also yields 4 windows, but because the sampled pixels are not adjacent, each window covers nearly the whole image and therefore represents near-global information. Combining local-window and global-window attention in one module brings together the advantages of local and global attention while greatly reducing computation.
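The difference between the two partitions can be illustrated on a 4 × 4 map with M = 2, which gives the four windows of Fig. 2(c). This NumPy sketch reflects our reading of the figure rather than released code. We assume that a global window gathers pixels spaced H/M rows and W/M columns apart, a dilated sampling grid in the spirit of atrous convolution. For readability the sketch uses a single-channel map, and `local_windows`/`global_windows` are hypothetical helper names.

```python
import numpy as np

def local_windows(x, m):
    """Contiguous m x m windows of an (H, W) map -> array of shape (num_windows, m*m)."""
    h, w = x.shape
    # axes: (window row, row inside window, window col, col inside window)
    t = x.reshape(h // m, m, w // m, m).transpose(0, 2, 1, 3)
    return t.reshape(-1, m * m)

def global_windows(x, m):
    """Dilated m x m windows: each window samples pixels H/m rows and W/m cols apart."""
    h, w = x.shape
    # axes: (sample row, row offset, sample col, col offset); group by the offsets
    t = x.reshape(m, h // m, m, w // m).transpose(1, 3, 0, 2)
    return t.reshape(-1, m * m)

x = np.arange(16).reshape(4, 4)
print(local_windows(x, 2))
# [[ 0  1  4  5]
#  [ 2  3  6  7]
#  [ 8  9 12 13]
#  [10 11 14 15]]
print(global_windows(x, 2))
# [[ 0  2  8 10]
#  [ 1  3  9 11]
#  [ 4  6 12 14]
#  [ 5  7 13 15]]
```

In both cases self-attention is then computed independently inside each window, so GW-MSA keeps the per-window cost of Eq. (2) while each of its windows spans the whole map.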
Fig. 3. Feedforward Network with Local Enhancement
Feedforward Network with Local Enhancement (LeFF). A conventional FFN cannot capture the spatial structure of images. To address this, we incorporate convolutions into the FFN [13]. As shown in Fig. 3, the input first passes through a linear projection layer that increases its dimensionality, followed by a GELU activation (our default), which introduces non-linearity so that the network can learn more complex patterns in the data. The tokens are then reshaped into a 2D feature map, and a 3 × 3 convolutional layer captures local information that a conventional feed-forward network would miss. The features are then flattened and fed into a second linear projection layer that restores their original dimension. Finally, a skip connection adds the input to the output. This preserves the spatial structure of the input as far as possible, eases the flow of information through the network, and helps alleviate vanishing gradients.
Dynamic Position Bias (DPB). As positional encoding has developed, relative position encoding has increasingly been applied to Transformers, and many networks use a relative position bias (RPB) [4] instead of the original absolute position embedding (APE) [14]. The advantage is that relative position encoding is highly learnable and robust, and it can be added directly to the attention scores:
$$\mathrm{Attention}(Q, K, V) = \mathrm{Softmax}\left(QK^{T}/\sqrt{d} + B\right)V \qquad (3)$$
where Q, K, and V are the query, key, and value matrices, $\sqrt{d}$ is a normalizing constant, and $B \in \mathbb{R}^{G^2 \times G^2}$ is the RPB matrix for a G × G window. Although RPB is easy to use, relative position encoding has a serious limitation. Following previous work [4], $B_{i,j} = \hat{B}_{x_{ij}, y_{ij}}$, where $\hat{B}$ is a fixed-size matrix and $(x_{ij}, y_{ij})$ is the coordinate offset between the i-th and j-th embeddings. If the image size changes, $(x_{ij}, y_{ij})$ can fall outside the range of $\hat{B}$, so the input image size is restricted. To address this issue, we propose a module that generates the relative position bias dynamically.
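Equation (3) and the fixed-size lookup $B_{i,j} = \hat{B}_{x_{ij}, y_{ij}}$ can be sketched in NumPy as below. The sketch assumes Swin-style indexing into a $(2M-1)^2$ table and uses a single head, a random table in place of learned parameters, and hypothetical helper names. It shows why a table sized for one window cannot cover a larger range of offsets.

```python
import numpy as np

def relative_position_index(m):
    """Index into a (2m-1)^2 bias table for every token pair in an m x m window."""
    ys, xs = np.meshgrid(np.arange(m), np.arange(m), indexing="ij")
    coords = np.stack([ys.ravel(), xs.ravel()])      # (2, m*m) token coordinates
    rel = coords[:, :, None] - coords[:, None, :]     # (2, m*m, m*m), offsets in [-(m-1), m-1]
    rel = rel + (m - 1)                               # shift offsets to [0, 2m-2]
    return rel[0] * (2 * m - 1) + rel[1]              # (m*m, m*m) flat table index

def attention_with_bias(q, k, v, bias):
    """Eq. (3) for one head: softmax(Q K^T / sqrt(d) + B) V."""
    scores = q @ k.T / np.sqrt(q.shape[-1]) + bias
    scores = scores - scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v

m, d = 4, 8
rng = np.random.default_rng(0)
table = rng.standard_normal((2 * m - 1) ** 2)    # B-hat: learned in practice, random here
bias = table[relative_position_index(m)]         # B[i, j] = B-hat[offset(i, j)]
q, k, v = (rng.standard_normal((m * m, d)) for _ in range(3))
out = attention_with_bias(q, k, v, bias)
print(out.shape)                                  # (16, 8)

# A 6 x 6 window produces offsets beyond what the 7 x 7 table covers.
print(relative_position_index(m + 2).max() >= table.size)   # True
```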
Fig. 4. Dynamic Position Bias
This module takes a bias tensor as input, each element of which corresponds to a positional bias. Such bias tensors are usually computed with a fixed set of sine functions, whereas our module computes them dynamically from the input. Specifically, as shown in Fig. 4, the input first passes through a linear projection layer that maps it to a lower-dimensional space. The projected bias then passes through three successive layers, each consisting of layer normalization [15], a ReLU activation [16], and a linear projection. These three layers compute the final positional bias.
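A minimal NumPy sketch of such a bias generator follows. We assume, as in the dynamic position bias of CrossFormer [8], that the module's input is the 2-D relative offset between two tokens. The hidden width, head count, random weights, and all helper names are hypothetical, and LayerNorm is shown without learnable parameters. Because the bias is computed from the offset itself rather than looked up in a fixed table, the same weights serve windows of any size.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, without learnable scale and shift."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def init_dpb(hidden, num_heads, seed=0):
    """Weights for Linear(2 -> hidden) followed by three [LayerNorm, ReLU, Linear] layers."""
    rng = np.random.default_rng(seed)
    dims = [2, hidden, hidden, hidden, num_heads]
    return [(0.1 * rng.standard_normal((a, b)), np.zeros(b)) for a, b in zip(dims[:-1], dims[1:])]

def dpb(offsets, params):
    """Map relative offsets of shape (N, 2) to biases of shape (N, num_heads)."""
    w, b = params[0]
    h = offsets @ w + b                              # project to the hidden space
    for w, b in params[1:]:
        h = np.maximum(layer_norm(h), 0.0) @ w + b   # LayerNorm -> ReLU -> Linear
    return h

def window_bias(m, params):
    """Bias of shape (num_heads, m*m, m*m) for every token pair in an m x m window."""
    ys, xs = np.meshgrid(np.arange(m), np.arange(m), indexing="ij")
    coords = np.stack([ys.ravel(), xs.ravel()], axis=1)                # (m*m, 2)
    offsets = (coords[:, None, :] - coords[None, :, :]).astype(float)  # (m*m, m*m, 2)
    bias = dpb(offsets.reshape(-1, 2), params)                          # (m^4, num_heads)
    return bias.reshape(m * m, m * m, -1).transpose(2, 0, 1)

params = init_dpb(hidden=16, num_heads=3)
print(window_bias(4, params).shape)  # (3, 16, 16)
print(window_bias(8, params).shape)  # (3, 64, 64): same weights, larger window
```

In a full model, each head's slice would be added to the attention scores of Eq. (3). A practical implementation would evaluate the MLP once on the $(2M-1)^2$ distinct offsets and index the result, rather than on every token pair as here.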
4 Experiments
We evaluate GLUformer on real-image denoising using the SIDD dataset [17]. We also perform ablation experiments on each newly proposed module to verify that it is effective. All experiments are run with PyTorch 1.8.0 on an NVIDIA GeForce RTX 3060 Laptop GPU.
4.1 Experimental Setup
Basic Settings. The initial learning rate is $2 \times 10^{-4}$. GLUformer contains 4 encoder stages and 4 decoder stages, all built from GLWin blocks. We use the common peak signal-to-noise ratio (PSNR) [17] to measure the quality of restored images. In the ablation experiments, we also report each model's total training time and computational cost (GMACs) to compare the models from a computational perspective.
4.2 Real Image Noise Reduction Experiment
Image Noise Reduction Experiment. Table 1 shows our results for denoising real images from the SIDD dataset [17]. We compare GLUformer with five well-known image denoising networks: BM3D [19], DnCNN [18], RIDNet [21], CBDNet [20], and Uformer [5]. GLUformer reaches 37.22 dB PSNR, 0.11 dB lower than Uformer.
Table 1. Experimental results of image denoising on the SIDD dataset
Method             PSNR (dB)
DnCNN [18]         23.66
BM3D [19]          25.65
CBDNet [20]        30.07
RIDNet [21]        38.71
Uformer [5]        37.33
GLUformer (ours)   37.22
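For reference, PSNR is $10 \log_{10}(\mathrm{MAX}^2/\mathrm{MSE})$. A minimal NumPy version is below. The default MAX = 255 assumes 8-bit images, and `psnr` is a hypothetical helper rather than the evaluation code behind Table 1. Because PSNR is logarithmic in MSE, a 0.11 dB gap corresponds to an MSE about 2.6% higher ($10^{0.011} \approx 1.026$).

```python
import numpy as np

def psnr(restored, target, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two images of the same shape."""
    restored = np.asarray(restored, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    mse = np.mean((restored - target) ** 2)
    if mse == 0:
        return float("inf")          # identical images
    return float(10.0 * np.log10(max_val**2 / mse))

clean = np.zeros((3, 4, 4))
noisy = clean + 0.1                  # constant error of 0.1 -> MSE = 0.01
print(psnr(noisy, clean, max_val=1.0))   # about 20 dB
print(10 ** (0.11 / 10))                 # MSE ratio implied by a 0.11 dB gap
```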
Figure 5 compares the denoising results of several networks on the SIDD dataset [17]. GLUformer removes noise while preserving image details.
4.3 Ablation Study
We ablate each module of GLUformer; the results are shown in Table 2 and Table 3.
GLWin vs. LeWin. Because our overall architecture is similar to Uformer, we compare LeWin, the building block of Uformer [5], with our GLWin. As Table 2 shows, although our network trails Uformer by 0.11 dB, it requires fewer parameters and computations (11.58 vs. 12.00 GMACs), and its total running time is shorter (15423.65 s vs. 22604.82 s). This is the purpose for which we designed the module.
Fig. 5. Evaluate the effectiveness of several networks for image denoising on the SIDD dataset and compare their performance
Table 2. Comparison of experimental results of GLUformer and Uformer
Method        PSNR (dB)   GMACs     Total running time
Uformer [5]   37.33       12.00 G   22604.82 s
GLUformer     37.22       11.58 G   15423.65 s
DPB vs. RPB. We compare the results of adding the dynamic position bias (DPB) module and the relative position bias (RPB) module [4] to GLWin. As Table 3 shows, DPB improves PSNR by 0.12 dB at the cost of 0.15 G more GMACs and a longer total running time (15423.65 s vs. 11419.78 s). We consider this trade-off acceptable, and the result demonstrates the effectiveness of DPB for image restoration.
Table 3. Results with RPB and DPB added to GLWin, respectively
Method          PSNR (dB)   GMACs     Total running time
GLUformer-RPB   37.10       11.43 G   11419.78 s
GLUformer-DPB   37.22       11.58 G   15423.65 s
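The trade-offs in Tables 2 and 3 can be expressed as relative changes. The short snippet below only re-derives figures from the table values. We treat the GLUformer row of Table 2 as GLUformer-DPB, since the two rows report identical numbers. Against Uformer, it recovers the 0.42 G and 7181.17 s savings from the introduction as −3.5% GMACs and −31.8% running time.

```python
def relative_change(new, old):
    """Signed change of `new` relative to `old`, in percent."""
    return 100.0 * (new - old) / old

# (PSNR in dB, GMACs in G, total running time in s), copied from Tables 2 and 3.
results = {
    "Uformer": (37.33, 12.00, 22604.82),
    "GLUformer-RPB": (37.10, 11.43, 11419.78),
    "GLUformer-DPB": (37.22, 11.58, 15423.65),
}

new = results["GLUformer-DPB"]
for baseline in ("Uformer", "GLUformer-RPB"):
    old = results[baseline]
    print(
        f"GLUformer-DPB vs. {baseline}: "
        f"PSNR {new[0] - old[0]:+.2f} dB, "
        f"GMACs {relative_change(new[1], old[1]):+.1f}%, "
        f"time {relative_change(new[2], old[2]):+.1f}%"
    )
# GLUformer-DPB vs. Uformer: PSNR -0.11 dB, GMACs -3.5%, time -31.8%
# GLUformer-DPB vs. GLUformer-RPB: PSNR +0.12 dB, GMACs +1.3%, time +35.1%
```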
5 Conclusion
We propose GLUformer, an efficient image denoising network that combines local and global attention. Its main component, the GLWin block, processes local and global context together and captures long-range dependencies well, while avoiding the heavy computation of conventional self-attention. Experiments show that GLUformer achieves competitive image denoising performance with a shorter training time and a smaller amount of computation. We have not yet evaluated GLUformer on other image restoration tasks, and we look forward to exploring more of its applications.
References
1. Cheng, S., Wang, Y., Huang, H.: NBNet: Noise Basis Learning for Image Denoising with Subspace Projection. arXiv (2021)
2. Zhang, Y., Li, K., Li, K.: Image Super-Resolution Using Very Deep Residual Channel Attention Networks. arXiv (2018)
3. Dosovitskiy, A., Beyer, L., Kolesnikov, A.: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. arXiv (2021)
4. Liu, Z., Lin, Y., Cao, Y.: Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. arXiv (2021)
5. Wang, Z., Cun, X., Bao, J.: Uformer: A General U-Shaped Transformer for Image Restoration. arXiv (2021)
6. Ronneberger, O., Fischer, P., Brox, T.: U-Net: Convolutional Networks for Biomedical Image Segmentation. arXiv (2015)
7. Vaswani, A., Shazeer, N., Parmar, N.: Attention Is All You Need. arXiv (2017)
8. Wang, W., Yao, L., Chen, L.: CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention. arXiv (2021)
9. Charbonnier, P., Blanc-Feraud, L., Aubert, G.: Two deterministic half-quadratic regularization algorithms for computed imaging. In: Proceedings of 1st International Conference on Image Processing, vol. 2, pp. 168–172 (1994)
10. Ranftl, R., Bochkovskiy, A., Koltun, V.: Vision Transformers for Dense Prediction. arXiv (2021)
11. Li, Y., Zhang, K., Cao, J.: LocalViT: Bringing Locality to Vision Transformers. arXiv (2021)
12. Chen, L.-C., Papandreou, G., Kokkinos, I.: DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs. arXiv (2017)
13. Wu, H., Xiao, B.: CvT: Introducing Convolutions to Vision Transformers. 10 Apr 2023. http://export.arxiv.org/abs/2103.15808
14. Sukhbaatar, S., Grave, E., Bojanowski, P.: Adaptive Attention Span in Transformers. arXiv (2019)
15. Ba, J.L., Kiros, J.R., Hinton, G.E.: Layer Normalization. arXiv (2016)
16. Agarap, A.F.: Deep Learning using Rectified Linear Units (ReLU). arXiv (2019)
17. Image Quality Assessment: From Error Visibility to Structural Similarity. 11 Apr 2023
18. Zhang, K., Zuo, W., Chen, Y.: Beyond a Gaussian denoiser: residual learning of deep CNN for image denoising. IEEE Trans. Image Process. 26(7), 3142–3155 (2017)
19. Dabov, K., Foi, A., Katkovnik, V.: Image denoising by sparse 3-D transform-domain collaborative filtering. IEEE Trans. Image Process. 16(8), 2080–2095 (2007)
20. Guo, S., Yan, Z., Zhang, K.: Toward Convolutional Blind Denoising of Real Photographs. arXiv (2019)
21. Zhuo, S., Jin, Z., Zou, W.: RIDNet: recursive information distillation network for color image denoising. In: 2019 IEEE/CVF International Conference on Computer Vision Workshop (ICCVW), pp. 3896–3903 (2019)
Author Index
A Alshabandar, Raghad 482 An, Jingbo 40
B Bin, Zhang 301
C Cai, Dupeng 3, 27 Cai, Shitao 749 Cai, Xingquan 239 Cao, Ruifen 363 Cao, Shi 88 Cao, Yi 51 Cao, Yunxiang 227 Chang, Hui 653 Chen, ChunYun 559 Chen, Cong 503 Chen, Dinghao 680 Chen, Houjin 680 Chen, Li 227, 251 Chen, Liangjun 773 Chen, Peng 277, 288 Chen, Yanyu 715, 738 Chen, Yuehui 51 Cheng, Pengyan 239
D Dai, Yuxing 726 Ding, Rui 385, 641 Ding, Siyi 75 Ding, Weiping 469 Dong, Haoyu 63 Dong, Yongsheng 147 Dong, Zehao 51 Du, Zexing 193 Duan, Yong 726
F Feng, Guangsheng 457 Feng, Zhida 227
G Gao, Yang 653 Gao, Yufei 263 Gong, Congjin 63 Guo, Jun 669 Guo, Wenjuan 726 Guo, Wenqi 493 Guo, Xiangkai 137 Guo, Zhenyu 493 H Har, Dongsoo 600 He, Di 193 He, Fuyun 125 He, Guanrong 469 He, Shaoqin 653 He, Wanxian 40 He, Xin 726 Hong, Tao 204 Hu, Cong 125 Hu, Rong 537 Hu, Yan 239 Hu, Yingbiao 351 Hu, Zhuhua 3, 27 Huang, Da 421, 433 Huang, Hongbin 629 Huang, Jiangsheng 705, 715, 749 Huang, Jianye 715, 738 Huang, Qiyao 669, 705 Huang, Xiaowei 785 Hussain, Abir 482 J Jaddoa, Ali 482 Ji, Young Seo 397 Jiang, Wanchun 327 Jiang, Yizhang 549
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 D.-S. Huang et al. (Eds.): ICIC 2023, LNAI 14090, pp. 809–812, 2023. https://doi.org/10.1007/978-981-99-4761-4
Jiao, Jinchao 112 Jie, Gao 181 Jin, Huaiping 537 Jin, Taisong 691 Ju, Jian 726 K Ke, Xinyu 385, 641 Kim, Taeyoung 600 L Lei, Zhuo 171 Li, Chao 15 Li, Gang 527 Li, Hui 503, 514 Li, Huinian 351 Li, Jianrong 409 Li, Jie 409 Li, Jinhu 705, 715, 726, 749 Li, Jun 251 Li, Lingling 691 Li, Ruoqing 27 Li, Shengquan 171 Li, Wei 457 Li, Xiaobing 373 Li, Xuan 629 Li, Xuewei 469 Li, Zheng 409 Li, Zihang 691 Li, Zijue 363 Liao, Feilong 669, 705 Lim, Soon bum 397 Lin, Chenxiang 669, 749 Lin, Guoqing 738 Lin, Jingyu 680 Lin, Rui 738 Liu, Bin 157, 614 Liu, Hanqiang 137 Liu, Jie 691 Liu, Lihua 629 Liu, Zhaoguo 421, 433 Liu, Zhaoyang 216 Liu, Zhihan 457 Liu, Zhiqiang 469 Liu, Zhitao 204 Long, Yingjie 351 Lu, Sipei 137 Luo, Jiawei 445 Lv, Xiaoxuan 514
M Ma, Hui 409 Ma, Jinwen 147, 204 Mao, Yunqing 171 Meng, Caixia 263 Meng, Xiangxu 457 Moon, Woohyeon 600 N Nengroo, Sarvar Hussain 600 Nie, Lei 549 Niu, Dongmei 216 Niu, Zihan 288
P Park, Bumgeun 600 Pei, Yuanhua 147 Peng, Siyu 445 Q Qi, Hao 3, 27 Qi, Lin 373 Qian, Bin 537 Qian, Jian 715, 738 Qian, Pengjiang 760, 773, 785, 797 Qian, Youwei 125 Qian, Zhenyu 549 R Ren, Zhicheng 327
S Shan, Danfeng 327 Shan, Yijing 327 Shan, Yongxin 327 Shen, Cong 445 Shen, Fengling 726 Shen, Longchao 147 Shi, Gongcheng 409 Shi, Lei 263 Shi, Youqun 559, 573 Shou, Lidan 171 Song, Kefeng 40 Sun, Lijun 527 Sun, Xiaofeng 15 T Tang, Xiaohu 125 Tao, Qianwen 559, 573
Teng, Fei 112 Tian, Ye 363 Tie, Yun 373 Tu, Shikui 493
W Wang, Bing 277, 288 Wang, Bo 340, 445 Wang, Huiqiang 457 Wang, Jianxin 327 Wang, Jun 559 Wang, Luping 157 Wang, Mingkai 691 Wang, Qing 193 Wang, Qingqing 589 Wang, Shuo 99, 373 Wang, Xiaoqi 373 Wang, Xue 193 Wang, Yubo 227 Wang, Yutao 99 Wei, Guangpeng 760 Wei, Jiangpo 340 Wei, Lin 263 Wei, Pijing 363 Wei, Songjie 314 Wei, Tianpeng 493 Wei, Yan 125 Wen, Zhuoer 421 Weng, Yuyou 738 Wu, Dan 433 Wu, Di 40 Wu, Jibing 629 Wu, Qidi 573 Wu, Yongrong 680
X Xia, Yi 277, 288 Xiahou, Jianbing 680 Xiang, Ying 181 Xiang, Yunfeng 3, 27 Xiaoying, Shuai 301 Xiong, Wei 469 Xu, Chaonong 15 Xu, Lang 277 Xu, Lei 493 Xu, Linhao 537 Xu, Qiqi 99
Xu, Weixia 421 Xue, Chenghao 797 Xue, Hongqiu 263 Xuzhou, Fu 181 Y Yang, Bo 216 Yang, Dong 314 Yang, Gang 63, 99 Yang, Haotian 147 Yang, Lvqing 680 Yang, Shuangyuan 385, 641 Yao, Jiali 239 Ye, Dengpan 589 Ye, Jingyan 251 Yi, Longqiang 705, 715, 749 Yifan, Hu 181 Yin, Changqing 112 Yin, Qing 726 Yu, Jian 469 Yu, Mei 469 Yu, Naikang 537 Yu, Qiang 171 Yuxia, Yin 301 Z Zeng, Biyang 493 Zhang, Baofeng 653 Zhang, Caiming 216 Zhang, Chuanlei 409 Zhang, Haitao 726 Zhang, Jun 277, 288 Zhang, Leyi 351 Zhang, Qihang 421, 433 Zhang, Xinchen 653 Zhang, Yao 125 Zhang, Yingyue 669, 705 Zhang, Yunhua 99 Zhang, Zhihong 669, 705, 715, 726, 738, 749 Zhao, Chen 527 Zhao, Chenyang 614 Zhao, Xiuyang 216 Zhao, Yaochi 3, 27 Zhao, Yaou 51 Zhao, Yaping 549 Zhao, Yucheng 573 Zhao, Zheng 457
Zhao, Zhichao 669 Zheng, Chunhou 363 Zheng, Haotian 726 Zheng, Lintao 147 Zheng, Zhou 669, 705, 749 Zhiqiang, Liu 181
Zhou, Can 314 Zhou, Shun 239 Zhou, Zihao 614 Zhu, Qing 75, 88 Zhu, Shuai 559, 573 Zhu, Wanting 75, 88