English Pages 744 [741] Year 2023
Mohammad Tanveer · Sonali Agarwal · Seiichi Ozawa · Asif Ekbal · Adam Jatowt (Eds.)
Communications in Computer and Information Science
1791
Neural Information Processing 29th International Conference, ICONIP 2022 Virtual Event, November 22–26, 2022 Proceedings, Part IV
Communications in Computer and Information Science
Editorial Board Members
Joaquim Filipe, Polytechnic Institute of Setúbal, Setúbal, Portugal
Ashish Ghosh, Indian Statistical Institute, Kolkata, India
Raquel Oliveira Prates, Federal University of Minas Gerais (UFMG), Belo Horizonte, Brazil
Lizhu Zhou, Tsinghua University, Beijing, China
Rationale
The CCIS series is devoted to the publication of proceedings of computer science conferences. Its aim is to efficiently disseminate original research results in informatics in printed and electronic form. While the focus is on publication of peer-reviewed full papers presenting mature work, inclusion of reviewed short papers reporting on work in progress is welcome, too. Besides globally relevant meetings with internationally representative program committees guaranteeing a strict peer-reviewing and paper selection process, conferences run by societies or of high regional or national relevance are also considered for publication.
Topics
The topical scope of CCIS spans the entire spectrum of informatics, ranging from foundational topics in the theory of computing to information and communications science and technology and a broad variety of interdisciplinary application fields.
Information for Volume Editors and Authors
Publication in CCIS is free of charge. No royalties are paid; however, we offer registered conference participants temporary free access to the online version of the conference proceedings on SpringerLink (http://link.springer.com) by means of an HTTP referrer from the conference website and/or a number of complimentary printed copies, as specified in the official acceptance email of the event.
CCIS proceedings can be published in time for distribution at conferences or as postproceedings, and delivered in the form of printed books and/or electronically as USBs and/or e-content licenses for accessing proceedings at SpringerLink. Furthermore, CCIS proceedings are included in the CCIS electronic book series hosted in the SpringerLink digital library at http://link.springer.com/bookseries/7899. Conferences publishing in CCIS are allowed to use Online Conference Service (OCS) for managing the whole proceedings lifecycle (from submission and reviewing to preparing for publication) free of charge.
Publication process
The language of publication is exclusively English. Authors publishing in CCIS have to sign the Springer CCIS copyright transfer form; however, they are free to use their material published in CCIS for substantially changed, more elaborate subsequent publications elsewhere. For the preparation of the camera-ready papers/files, authors have to strictly adhere to the Springer CCIS Authors’ Instructions and are strongly encouraged to use the CCIS LaTeX style files or templates.
Abstracting/Indexing
CCIS is abstracted/indexed in DBLP, Google Scholar, EI-Compendex, Mathematical Reviews, SCImago, and Scopus. CCIS volumes are also submitted for inclusion in ISI Proceedings.
How to start
To start the evaluation of your proposal for inclusion in the CCIS series, please send an e-mail to [email protected].
Editors
Mohammad Tanveer, Indian Institute of Technology Indore, Indore, India
Sonali Agarwal, Indian Institute of Information Technology Allahabad, Prayagraj, India
Seiichi Ozawa, Kobe University, Kobe, Japan
Asif Ekbal, Indian Institute of Technology Patna, Patna, India
Adam Jatowt, University of Innsbruck, Innsbruck, Austria
ISSN 1865-0929 ISSN 1865-0937 (electronic)
Communications in Computer and Information Science
ISBN 978-981-99-1638-2 ISBN 978-981-99-1639-9 (eBook)
https://doi.org/10.1007/978-981-99-1639-9
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
This work is subject to copyright. All rights are reserved by the Publisher, whether the whole or part of the material is concerned, specifically the rights of translation, reprinting, reuse of illustrations, recitation, broadcasting, reproduction on microfilms or in any other physical way, and transmission or information storage and retrieval, electronic adaptation, computer software, or by similar or dissimilar methodology now known or hereafter developed.
The use of general descriptive names, registered names, trademarks, service marks, etc. in this publication does not imply, even in the absence of a specific statement, that such names are exempt from the relevant protective laws and regulations and therefore free for general use.
The publisher, the authors, and the editors are safe to assume that the advice and information in this book are believed to be true and accurate at the date of publication. Neither the publisher nor the authors or the editors give a warranty, expressed or implied, with respect to the material contained herein or for any errors or omissions that may have been made. The publisher remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
This Springer imprint is published by the registered company Springer Nature Singapore Pte Ltd.
The registered company address is: 152 Beach Road, #21-01/04 Gateway East, Singapore 189721, Singapore
Preface
Welcome to the proceedings of the 29th International Conference on Neural Information Processing (ICONIP 2022) of the Asia-Pacific Neural Network Society (APNNS), held virtually from Indore, India, during November 22–26, 2022.
The mission of the Asia-Pacific Neural Network Society is to promote active interactions among researchers, scientists, and industry professionals who are working in neural networks and related fields in the Asia-Pacific region. APNNS has Governing Board Members from 13 countries/regions: Australia, China, Hong Kong, India, Japan, Malaysia, New Zealand, Singapore, South Korea, Qatar, Taiwan, Thailand, and Turkey. The society’s flagship annual conference is the International Conference on Neural Information Processing (ICONIP). The ICONIP conference aims to provide a leading international forum for researchers, scientists, and industry professionals who are working in neuroscience, neural networks, deep learning, and related fields to share their new ideas, progress, and achievements.
Due to the current situation regarding the pandemic and international travel, ICONIP 2022, which was planned to be held in New Delhi, India, was organized as a fully virtual conference.
The proceedings of ICONIP 2022 consist of a multi-volume set in LNCS and CCIS, which include 146 and 213 papers, respectively, selected from 1003 submissions, reflecting the increasingly high quality of research in neural networks and related areas. The conference focused on four main areas, namely “Theory and Algorithms,” “Cognitive Neurosciences,” “Human Centered Computing,” and “Applications.” The conference also had special sessions in 12 niche areas, namely:
1. International Workshop on Artificial Intelligence and Cyber Security (AICS)
2. Computationally Intelligent Techniques in Processing and Analysis of Neuronal Information (PANI)
3. Learning with Fewer Labels in Medical Computing (FMC)
4. Computational Intelligence for Biomedical Image Analysis (BIA)
5. Optimized AI Models with Interpretability, Security, and Uncertainty Estimation in Healthcare (OAI)
6. Advances in Deep Learning for Biometrics and Forensics (ADBF)
7. Machine Learning for Decision-Making in Healthcare: Challenges and Opportunities (MDH)
8. Reliable, Robust and Secure Machine Learning Algorithms (RRS)
9. Evolutionary Machine Learning Technologies in Healthcare (EMLH)
10. High Performance Computing Based Scalable Machine Learning Techniques for Big Data and Their Applications (HPCML)
11. Intelligent Transportation Analytics (ITA)
12. Deep Learning and Security Techniques for Secure Video Processing (DLST)
Our great appreciation goes to the Program Committee members and the reviewers who devoted their time and effort to our rigorous peer-review process. Their insightful reviews and timely feedback ensured the high quality of the papers accepted for publication. The submitted papers in the main conference and special sessions were reviewed following the same process, and we ensured that every paper had at least two high-quality single-blind reviews. The PC Chairs discussed the reviews of every paper meticulously before making a final decision. Finally, thank you to all the authors of papers, presenters, and participants, who made the conference a grand success. Your support and engagement made it all worthwhile.
December 2022
Mohammad Tanveer Sonali Agarwal Seiichi Ozawa Asif Ekbal Adam Jatowt
Organization
Program Committee
General Chairs
M. Tanveer, Indian Institute of Technology Indore, India
Sonali Agarwal, IIIT Allahabad, India
Seiichi Ozawa, Kobe University, Japan
Honorary Chairs
Jonathan Chan, King Mongkut’s University of Technology Thonburi, Thailand
P. N. Suganthan, Nanyang Technological University, Singapore
Program Chairs
Asif Ekbal, Indian Institute of Technology Patna, India
Adam Jatowt, University of Innsbruck, Austria
Technical Chairs
Shandar Ahmad, JNU, India
Derong Liu, University of Chicago, USA
Special Session Chairs
Kai Qin, Swinburne University of Technology, Australia
Kaizhu Huang, Duke Kunshan University, China
Amit Kumar Singh, NIT Patna, India
Tutorial Chairs
Swagatam Das, ISI Kolkata, India
Partha Pratim Roy, IIT Roorkee, India
Finance Chairs
Shekhar Verma, Indian Institute of Information Technology Allahabad, India
Hayaru Shouno, University of Electro-Communications, Japan
R. B. Pachori, IIT Indore, India
Publicity Chairs
Jerry Chun-Wei Lin, Western Norway University of Applied Sciences, Norway
Chandan Gautam, A*STAR, Singapore
Publication Chairs
Deepak Ranjan Nayak, MNIT Jaipur, India
Tripti Goel, NIT Silchar, India
Sponsorship Chairs
Asoke K. Talukder, NIT Surathkal, India
Vrijendra Singh, IIIT Allahabad, India
Website Chairs
M. Arshad, IIT Indore, India
Navjot Singh, IIIT Allahabad, India
Local Arrangement Chairs
Pallavi Somvanshi, JNU, India
Yogendra Meena, University of Delhi, India
M. Javed, IIIT Allahabad, India
Vinay Kumar Gupta, IIT Indore, India
Iqbal Hasan, National Informatics Centre, Ministry of Electronics and Information Technology, India
Regional Liaison Committee
Sansanee Auephanwiriyakul, Chiang Mai University, Thailand
Nia Kurnianingsih, Politeknik Negeri Semarang, Indonesia
Md Rafiqul Islam, University of Technology Sydney, Australia
Bharat Richhariya, IISc Bangalore, India
Sanjay Kumar Sonbhadra, Shiksha ‘O’ Anusandhan, India
Mufti Mahmud, Nottingham Trent University, UK
Francesco Piccialli, University of Naples Federico II, Italy
Program Committee
Balamurali A. R., IITB-Monash Research Academy, India
Ibrahim A. Hameed, Norwegian University of Science and Technology (NTNU), Norway
Fazly Salleh Abas, Multimedia University, Malaysia
Prabath Abeysekara, RMIT University, Australia
Adamu Abubakar Ibrahim, International Islamic University, Malaysia
Muhammad Abulaish, South Asian University, India
Saptakatha Adak, Philips, India
Abhijit Adhikary, King’s College, London, UK
Hasin Afzal Ahmed, Gauhati University, India
Rohit Agarwal, UiT The Arctic University of Norway, Norway
A. K. Agarwal, Sharda University, India
Fenty Eka Muzayyana Agustin, UIN Syarif Hidayatullah Jakarta, Indonesia
Gulfam Ahamad, BGSB University, India
Farhad Ahamed, Kent Institute, Australia
Zishan Ahmad, Indian Institute of Technology Patna, India
Mohammad Faizal Ahmad Fauzi, Multimedia University, Malaysia
Mudasir Ahmad Ganaie, Indian Institute of Technology Indore, India
Sangtae Ahn, Kyungpook National University, South Korea
Md. Shad Akhtar, Indraprastha Institute of Information Technology, Delhi, India
Abdulrazak Yahya Saleh Alhababi, University of Malaysia, Sarawak, Malaysia
Ahmed Alharbi, RMIT University, Australia
Irfan Ali, Aligarh Muslim University, India
Ali Anaissi, CSIRO, Australia
Ashish Anand, Indian Institute of Technology, Guwahati, India
C. Anantaram, Indraprastha Institute of Information Technology and Tata Consultancy Services Ltd., India
Nur Afny C. Andryani, Universiti Teknologi Petronas, Malaysia
Marco Anisetti, Università degli Studi di Milano, Italy
Mohd Zeeshan Ansari, Jamia Millia Islamia, India
J. Anuradha, VIT, India
Ramakrishna Appicharla, Indian Institute of Technology Patna, India
V. N. Manjunath Aradhya, JSS Science and Technology University, India
Sunil Aryal, Deakin University, Australia
Muhammad Awais, COMSATS University Islamabad, Wah Campus, Pakistan
Mubasher Baig, National University of Computer and Emerging Sciences (NUCES) Lahore, Pakistan
Sudhansu Bala Das, NIT Rourkela, India
Rakesh Balabantaray, International Institute of Information Technology Bhubaneswar, India
Sang-Woo Ban, Dongguk University, South Korea
Tao Ban, National Institute of Information and Communications Technology, Japan
Dibyanayan Bandyopadhyay, Indian Institute of Technology, Patna, India
Somnath Banerjee, University of Tartu, Estonia
Debajyoty Banik, Kalinga Institute of Industrial Technology, India
Mohamad Hardyman Barawi, Universiti Malaysia, Sarawak, Malaysia
Mahmoud Barhamgi, Claude Bernard Lyon 1 University, France
Kingshuk Basak, Indian Institute of Technology Patna, India
Elhadj Benkhelifa, Staffordshire University, UK
Sudip Bhattacharya, Bhilai Institute of Technology Durg, India
Monowar H Bhuyan, Umeå University, Sweden
Xu Bin, Northwestern Polytechnical University, China
Shafaatunnur Binti Hasan, UTM, Malaysia
David Bong, Universiti Malaysia Sarawak, Malaysia
Larbi Boubchir, University of Paris, France
Himanshu Buckchash, UiT The Arctic University of Norway, Norway
George Cabral, Federal Rural University of Pernambuco, Brazil
Michael Carl, Kent State University, USA
Dalia Chakrabarty, Brunel University London, UK
Deepayan Chakraborty, IIT Kharagpur, India
Tanmoy Chakraborty, IIT Delhi, India
Rapeeporn Chamchong, Mahasarakham University, Thailand
Ram Chandra Barik, C. V. Raman Global University, India
Chandrahas, Indian Institute of Science, Bangalore, India
Ming-Ching Chang, University at Albany - SUNY, USA
Shivam Chaudhary, Indian Institute of Technology Gandhinagar, India
Dushyant Singh Chauhan, Indian Institute of Technology Patna, India
Manisha Chawla, Amazon Inc., India
Shreya Chawla, Australian National University, Australia
Chun-Hao Chen, National Kaohsiung University of Science and Technology, Taiwan
Gang Chen, Victoria University of Wellington, New Zealand
He Chen, Hebei University of Technology, China
Hongxu Chen, University of Queensland, Australia
J. Chen, Dalian University of Technology, China
Jianhui Chen, Beijing University of Technology, China
Junxin Chen, Dalian University of Technology, China
Junyi Chen, City University of Hong Kong, China
Junying Chen, South China University of Technology, China
Lisi Chen, Hong Kong Baptist University, China
Mulin Chen, Northwestern Polytechnical University, China
Xiaocong Chen, University of New South Wales, Australia
Xiaofeng Chen, Chongqing Jiaotong University, China
Zhuangbin Chen, The Chinese University of Hong Kong, China
Long Cheng, Institute of Automation, China
Qingrong Cheng, Fudan University, China
Ruting Cheng, George Washington University, USA
Girija Chetty, University of Canberra, Australia
Manoj Chinnakotla, Microsoft R&D Pvt. Ltd., India
Andrew Chiou, CQ University, Australia
Sung-Bae Cho, Yonsei University, South Korea
Kupsze Choi, The Hong Kong Polytechnic University, China
Phatthanaphong Chomphuwiset, Mahasarakham University, Thailand
Fengyu Cong, Dalian University of Technology, China
Jose Alfredo Ferreira Costa, UFRN, Brazil
Ruxandra Liana Costea, Polytechnic University of Bucharest, Romania
Raphaël Couturier, University of Franche-Comte, France
Zhenyu Cui, Peking University, China
Zhihong Cui, Shandong University, China
Juan D. Velasquez, University of Chile, Chile
Rukshima Dabare, Murdoch University, Australia
Cherifi Dalila, University of Boumerdes, Algeria
Minh-Son Dao, National Institute of Information and Communications Technology, Japan
Tedjo Darmanto, STMIK AMIK Bandung, Indonesia
Debasmit Das, IIT Roorkee, India
Dipankar Das, Jadavpur University, India
Niladri Sekhar Dash, Indian Statistical Institute, Kolkata, India
Satya Ranjan Dash, KIIT University, India
Shubhajit Datta, Indian Institute of Technology, Kharagpur, India
Alok Debnath, Trinity College Dublin, Ireland
Amir Dehsarvi, Ludwig Maximilian University of Munich, Germany
Hangyu Deng, Waseda University, Japan
Mingcong Deng, Tokyo University of Agriculture and Technology, Japan
Zhaohong Deng, Jiangnan University, China
V. Susheela Devi, Indian Institute of Science, Bangalore, India
M. M. Dhabu, VNIT Nagpur, India
Dhimas Arief Dharmawan, Universitas Indonesia, Indonesia
Khaldoon Dhou, Texas A&M University Central Texas, USA
Gihan Dias, University of Moratuwa, Sri Lanka
Nat Dilokthanakul, Vidyasirimedhi Institute of Science and Technology, Thailand
Tai Dinh, Kyoto College of Graduate Studies for Informatics, Japan
Gaurav Dixit, Indian Institute of Technology Roorkee, India
Youcef Djenouri, SINTEF Digital, Norway
Hai Dong, RMIT University, Australia
Shichao Dong, Ping An Insurance Group, China
Mohit Dua, NIT Kurukshetra, India
Yijun Duan, Kyoto University, Japan
Shiv Ram Dubey, Indian Institute of Information Technology, Allahabad, India
Piotr Duda, Institute of Computational Intelligence/Czestochowa University of Technology, Poland
Sri Harsha Dumpala, Dalhousie University and Vector Institute, Canada
Hridoy Sankar Dutta, University of Cambridge, UK
Indranil Dutta, Jadavpur University, India
Pratik Dutta, Indian Institute of Technology Patna, India
Rudresh Dwivedi, Netaji Subhas University of Technology, India
Heba El-Fiqi, UNSW Canberra, Australia
Felix Engel, Leibniz Information Centre for Science and Technology (TIB), Germany
Akshay Fajge, Indian Institute of Technology Patna, India
Yuchun Fang, Shanghai University, China
Mohd Fazil, JMI, India
Zhengyang Feng, Shanghai Jiao Tong University, China
Zunlei Feng, Zhejiang University, China
Mauajama Firdaus, University of Alberta, Canada
Devi Fitrianah, Bina Nusantara University, Indonesia
Philippe Fournier-Viger, Shenzhen University, China
Wai-Keung Fung, Cardiff Metropolitan University, UK
Baban Gain, Indian Institute of Technology, Patna, India
Claudio Gallicchio, University of Pisa, Italy
Yongsheng Gao, Griffith University, Australia
Yunjun Gao, Zhejiang University, China
Vicente García Díaz, University of Oviedo, Spain
Arpit Garg, University of Adelaide, Australia
Chandan Gautam, I2R, A*STAR, Singapore
Yaswanth Gavini, University of Hyderabad, India
Tom Gedeon, Australian National University, Australia
Iuliana Georgescu, University of Bucharest, Romania
Deepanway Ghosal, Indian Institute of Technology Patna, India
Arjun Ghosh, National Institute of Technology Durgapur, India
Sanjukta Ghosh, IIT (BHU) Varanasi, India
Soumitra Ghosh, Indian Institute of Technology Patna, India
Pranav Goel, Bloomberg L.P., India
Tripti Goel, National Institute of Technology Silchar, India
Kah Ong Michael Goh, Multimedia University, Malaysia
Kam Meng Goh, Tunku Abdul Rahman University of Management and Technology, Malaysia
Iqbal Gondal, RMIT University, Australia
Puneet Goyal, Indian Institute of Technology Ropar, India
Vishal Goyal, Punjabi University Patiala, India
Xiaotong Gu, University of Tasmania, Australia
Radha Krishna Guntur, VNRVJIET, India
Li Guo, University of Macau, China
Ping Guo, Beijing Normal University, China
Yu Guo, Xi’an Jiaotong University, China
Akshansh Gupta, CSIR-Central Electronics Engineering Research Institute, India
Deepak Gupta, National Library of Medicine, National Institutes of Health (NIH), USA
Deepak Gupta, NIT Arunachal Pradesh, India
Kamal Gupta, NIT Patna, India
Kapil Gupta, PDPM IIITDM, Jabalpur, India
Komal Gupta, IIT Patna, India
Christophe Guyeux, University of Franche-Comte, France
Katsuyuki Hagiwara, Mie University, Japan
Soyeon Han, University of Sydney, Australia
Palak Handa, IGDTUW, India
Rahmadya Handayanto, Universitas Islam 45 Bekasi, Indonesia
Ahteshamul Haq, Aligarh Muslim University, India
Muhammad Haris, Universitas Nusa Mandiri, Indonesia
Harith Al-Sahaf, Victoria University of Wellington, New Zealand
Md Rakibul Hasan, BRAC University, Bangladesh
Mohammed Hasanuzzaman, ADAPT Centre, Ireland
Takako Hashimoto, Chiba University of Commerce, Japan
Bipan Hazarika, Gauhati University, India
Huiguang He, Institute of Automation, Chinese Academy of Sciences, China
Wei He, University of Science and Technology Beijing, China
Xinwei He, University of Illinois Urbana-Champaign, USA
Enna Hirata, Kobe University, Japan
Akira Hirose, University of Tokyo, Japan
Katsuhiro Honda, Osaka Metropolitan University, Japan
Huy Hongnguyen, National Institute of Informatics, Japan
Wai Lam Hoo, University of Malaya, Malaysia
Shih Hsiung Lee, National Cheng Kung University, Taiwan
Jiankun Hu, UNSW@ADFA, Australia
Yanyan Hu, University of Science and Technology Beijing, China
Chaoran Huang, UNSW Sydney, Australia
He Huang, Soochow University, Taiwan
Ko-Wei Huang, National Kaohsiung University of Science and Technology, Taiwan
Shudong Huang, Sichuan University, China
Chih-Chieh Hung, National Chung Hsing University, Taiwan
Mohamed Ibn Khedher, IRT-SystemX, France
David Iclanzan, Sapientia Hungarian University of Transylvania, Romania
Cosimo Ieracitano, University “Mediterranea” of Reggio Calabria, Italy
Kazushi Ikeda, Nara Institute of Science and Technology, Japan
Hiroaki Inoue, Kobe University, Japan
Teijiro Isokawa, University of Hyogo, Japan
Kokila Jagadeesh, Indian Institute of Information Technology, Allahabad, India
Mukesh Jain, Jawaharlal Nehru University, India
Fuad Jamour, AWS, USA
Mohd. Javed, Indian Institute of Information Technology, Allahabad, India
Balasubramaniam Jayaram, Indian Institute of Technology Hyderabad, India
Jin-Tsong Jeng, National Formosa University, Taiwan
Sungmoon Jeong, Kyungpook National University Hospital, South Korea
Yizhang Jiang, Jiangnan University, China
Ferdinjoe Johnjoseph, Thai-Nichi Institute of Technology, Thailand
Alireza Jolfaei, Federation University, Australia
Ratnesh Joshi, Indian Institute of Technology Patna, India
Roshan Joy Martis, Global Academy of Technology, India
Chen Junjie, IMAU, The Netherlands
Ashwini K., Global Academy of Technology, India
Asoke K. Talukder, National Institute of Technology Karnataka Surathkal, India
Ashad Kabir, Charles Sturt University, Australia
Narendra Kadoo, CSIR-National Chemical Laboratory, India
Seifedine Kadry, Noroff University College, Norway
M. Shamim Kaiser, Jahangirnagar University, Bangladesh
Ashraf Kamal, ACL Digital, India
Sabyasachi Kamila, Indian Institute of Technology Patna, India
Tomoyuki Kaneko, University of Tokyo, Japan
Rajkumar Kannan, Bishop Heber College, India
Hamid Karimi, Utah State University, USA
Nikola Kasabov, AUT, New Zealand
Dermot Kerr, University of Ulster, UK
Abhishek Kesarwani, NIT Rourkela, India
Shwet Ketu, Shambhunath Institute of Engineering and Technology, India
Asif Khan, Integral University, India
Tariq Khan, UNSW, Australia
Thaweesak Khongtuk, Rajamangala University of Technology Suvarnabhumi (RMUTSB), Thailand
Abbas Khosravi, Deakin University, Australia
Thanh Tung Khuat, University of Technology Sydney, Australia
Junae Kim, DST Group, Australia
Sangwook Kim, Kobe University, Japan
Mutsumi Kimura, Ryukoku University, Japan
Uday Kiran, University of Aizu, Japan
Hisashi Koga, University of Electro-Communications, Japan
Yasuharu Koike, Tokyo Institute of Technology, Japan
Ven Jyn Kok, Universiti Kebangsaan Malaysia, Malaysia
Praveen Kolli, Pinterest Inc, USA
Sunil Kumar Kopparapu, Tata Consultancy Services Ltd., India
Fajri Koto, MBZUAI, UAE
Aneesh Krishna, Curtin University, Australia
Parameswari Krishnamurthy, University of Hyderabad, India
Malhar Kulkarni, IIT Bombay, India
Abhinav Kumar, NIT, Patna, India
Abhishek Kumar, Indian Institute of Technology Patna, India
Amit Kumar, Tarento Technologies Pvt Limited, India
Nagendra Kumar, IIT Indore, India
Pranaw Kumar, Centre for Development of Advanced Computing (CDAC) Mumbai, India
Puneet Kumar, Jawaharlal Nehru University, India
Raja Kumar, Taylor’s University, Malaysia
Sachin Kumar, University of Delhi, India
Sandeep Kumar, IIT Patna, India
Sanjaya Kumar Panda, National Institute of Technology, Warangal, India
Chouhan Kumar Rath, National Institute of Technology, Durgapur, India
Sovan Kumar Sahoo, Indian Institute of Technology Patna, India
Anil Kumar Singh, IIT (BHU) Varanasi, India
Vikash Kumar Singh, VIT-AP University, India
Sanjay Kumar Sonbhadra, ITER, SoA, Odisha, India
Gitanjali Kumari, Indian Institute of Technology Patna, India
Rina Kumari, KIIT, India
Amit Kumar Singh, National Institute of Technology Patna, India
Sanjay Kumar Sonbhadra, SSITM, India
Vishesh Kumar Tanwar, Missouri University of Science and Technology, USA
Bibekananda Kundu, CDAC Kolkata, India
Yoshimitsu Kuroki, Kurume National College of Technology, Japan
Susumu Kuroyanagi, Nagoya Institute of Technology, Japan
Retno Kusumaningrum, Universitas Diponegoro, Indonesia
Dwina Kuswardani, Institut Teknologi PLN, Indonesia
Stephen Kwok, Murdoch University, Australia
Hamid Laga, Murdoch University, Australia
Edmund Lai, Auckland University of Technology, New Zealand
Weng Kin Lai, Tunku Abdul Rahman University of Management & Technology (TAR UMT), Malaysia
Kittichai Lavangnananda, King Mongkut’s University of Technology Thonburi (KMUTT), Thailand
Anwesha Law, Indian Statistical Institute, India
Thao Le, Deakin University, Australia
Xinyi Le, Shanghai Jiao Tong University, China
Dong-Gyu Lee, Kyungpook National University, South Korea
Eui Chul Lee, Sangmyung University, South Korea
Minho Lee, Kyungpook National University, South Korea
Shih Hsiung Lee, National Kaohsiung University of Science and Technology, Taiwan
Gurpreet Lehal, Punjabi University, India
Jiahuan Lei, Meituan-Dianping Group, China
Pui Huang Leong, Tunku Abdul Rahman University of Management and Technology, Malaysia
Chi Sing Leung, City University of Hong Kong, China
Man-Fai Leung, Anglia Ruskin University, UK
Bing-Zhao Li, Beijing Institute of Technology, China
Gang Li, Deakin University, Australia
Jiawei Li, Tsinghua University, China
Mengmeng Li, Zhengzhou University, China
Xiangtao Li, Jilin University, China
Yang Li, East China Normal University, China
Yantao Li, Chongqing University, China
Yaxin Li, Michigan State University, USA
Yiming Li, Tsinghua University, China
Yuankai Li, University of Science and Technology of China, China
Yun Li, Nanjing University of Posts and Telecommunications, China
Zhipeng Li, Tsinghua University, China
Hualou Liang, Drexel University, USA
Xiao Liang, Nankai University, China
Hao Liao, Shenzhen University, China
Alan Wee-Chung Liew, Griffith University, Australia
Chern Hong Lim, Monash University Malaysia, Malaysia
Kok Lim Yau, Universiti Tunku Abdul Rahman (UTAR), Malaysia
Chin-Teng Lin, UTS, Australia
Jerry Chun-Wei Lin, Western Norway University of Applied Sciences, Norway
Jiecong Lin, City University of Hong Kong, China
Dugang Liu, Shenzhen University, China
Feng Liu, Stevens Institute of Technology, USA
Hongtao Liu, Du Xiaoman Financial, China
Ju Liu, Shandong University, China
Linjing Liu, City University of Hong Kong, China
Weifeng Liu, China University of Petroleum (East China), China
Wenqiang Liu, Hong Kong Polytechnic University, China
Xin Liu, National Institute of Advanced Industrial Science and Technology (AIST), Japan
Yang Liu, Harbin Institute of Technology, China
Zhi-Yong Liu, Institute of Automation, Chinese Academy of Sciences, China
Zongying Liu, Dalian Maritime University, China
Jaime Lloret, Universitat Politècnica de València, Spain
Sye Loong Keoh, University of Glasgow, Singapore, Singapore
Hongtao Lu, Shanghai Jiao Tong University, China
Wenlian Lu, Fudan University, China
Xuequan Lu, Deakin University, Australia
Xiao Luo, UCLA, USA
Guozheng Ma, Shenzhen International Graduate School, Tsinghua University, China
Qianli Ma, South China University of Technology, China
Wanli Ma, University of Canberra, Australia
Muhammad Anwar Ma’sum, Universitas Indonesia, Indonesia
Michele Magno, University of Bologna, Italy
Sainik Kumar Mahata, JU, India
Shalni Mahato, Indian Institute of Information Technology (IIIT) Ranchi, India
Adnan Mahmood, Macquarie University, Australia
Mohammed Mahmoud, October University for Modern Sciences & Arts MSA University, Egypt
Mufti Mahmud, University of Padova, Italy
Krishanu Maity, Indian Institute of Technology Patna, India
Mamta, IIT Patna, India
Aprinaldi Mantau, Kyushu Institute of Technology, Japan
Mohsen Marjani, Taylor’s University, Malaysia
Sanparith Marukatat, NECTEC, Thailand
José María Luna, Universidad de Córdoba, Spain
Archana Mathur, Nitte Meenakshi Institute of Technology, India
Patrick McAllister, Ulster University, UK
Piotr Milczarski, Lodz University of Technology, Poland
Kshitij Mishra, IIT Patna, India
Pruthwik Mishra, IIIT-Hyderabad, India
Santosh Mishra, Indian Institute of Technology Patna, India
Sajib Mistry, Curtin University, Australia
Sayantan Mitra, Accenture Labs, India
Vinay Kumar Mittal, Neti International Research Center, India
Daisuke Miyamoto, University of Tokyo, Japan
Kazuteru Miyazaki, National Institution for Academic Degrees and Quality Enhancement of Higher Education, Japan
U. Mmodibbo, Modibbo Adama University Yola, Nigeria
Aditya Mogadala, Saarland University, Germany
Reem Mohamed, Mansoura University, Egypt
Muhammad Syafiq Mohd Pozi, Universiti Utara Malaysia, Malaysia
Anirban Mondal, University of Tokyo, Japan
Anupam Mondal, Jadavpur University, India
Supriyo Mondal, ZBW - Leibniz Information Centre for Economics, Germany
J. Manuel Moreno, Universitat Politècnica de Catalunya, Spain
Francisco J. Moreno-Barea, Universidad de Málaga, Spain
Sakchai Muangsrinoon, Walailak University, Thailand
Siti Anizah Muhamed, Politeknik Sultan Salahuddin Abdul Aziz Shah, Malaysia
Samrat Mukherjee, Indian Institute of Technology, Patna, India
Siddhartha Mukherjee, Samsung R&D Institute India, Bangalore, India
Dharmalingam Muthusamy, Bharathiar University, India
Abhijith Athreya Mysore Gopinath, Pennsylvania State University, USA
Harikrishnan N. B., BITS Pilani K K Birla Goa Campus, India
Usman Naseem, University of Sydney, Australia
Deepak Nayak, Malaviya National Institute of Technology, Jaipur, India
Hamada Nayel, Benha University, Egypt
Usman Nazir, Lahore University of Management Sciences, Pakistan
Vasudevan Nedumpozhimana, TU Dublin, Ireland
Atul Negi, University of Hyderabad, India
Aneta Neumann, University of Adelaide, Australia
Hea Choon Ngo, Universiti Teknikal Malaysia Melaka, Malaysia
Dang Nguyen, University of Canberra, Australia
Duy Khuong Nguyen, FPT Software Ltd., FPT Group, Vietnam
Hoang D. Nguyen, University College Cork, Ireland
Hong Huy Nguyen, National Institute of Informatics, Japan
Tam Nguyen, Leibniz University Hannover, Germany
Thanh-Son Nguyen, Agency for Science, Technology and Research (A*STAR), Singapore
Vu-Linh Nguyen, Eindhoven University of Technology, Netherlands
Nick Nikzad, Griffith University, Australia
Boda Ning, Swinburne University of Technology, Australia
Haruhiko Nishimura, University of Hyogo, Japan
Kishorjit Nongmeikapam, Indian Institute of Information Technology (IIIT) Manipur, India
Aleksandra Nowak, Jagiellonian University, Poland
Stavros Ntalampiras, University of Milan, Italy
Anupiya Nugaliyadde, Sri Lanka Institute of Information Technology, Sri Lanka
Anto Satriyo Nugroho, Agency for Assessment & Application of Technology, Indonesia
Aparajita Ojha, PDPM IIITDM Jabalpur, India
Akeem Olowolayemo, International Islamic University Malaysia, Malaysia
Toshiaki Omori, Kobe University, Japan
Shih Yin Ooi, Multimedia University, Malaysia
Sidali Ouadfeul, Algerian Petroleum Institute, Algeria
Samir Ouchani, CESI Lineact, France
Srinivas P. Y. K. L., IIIT Sri City, India
Neelamadhab Padhy, GIET University, India
Worapat Paireekreng, Dhurakij Pundit University, Thailand
Partha Pakray, National Institute of Technology Silchar, India
Santanu Pal, Wipro Limited, India
Bin Pan, Nankai University, China
Rrubaa Panchendrarajan, Sri Lanka Institute of Information Technology, Sri Lanka
Pankaj Pandey, Indian Institute of Technology, Gandhinagar, India
Lie Meng Pang, Southern University of Science and Technology, China
Sweta Panigrahi, National Institute of Technology Warangal, India
T. Pant, IIIT Allahabad, India
Shantipriya Parida, Idiap Research Institute, Switzerland
Hyeyoung Park, Kyungpook National University, South Korea
Md Aslam Parwez, Jamia Millia Islamia, India
Leandro Pasa, Federal University of Technology - Parana (UTFPR), Brazil
Kitsuchart Pasupa, King Mongkut’s Institute of Technology Ladkrabang, Thailand
Debanjan Pathak, Kalinga Institute of Industrial Technology (KIIT), India
Vyom Pathak, University of Florida, USA
Sangameshwar Patil, TCS Research, India
Bidyut Kr. Patra, IIT (BHU) Varanasi, India
Dipanjyoti Paul, Indian Institute of Technology Patna, India
Sayanta Paul, Ola, India
Sachin Pawar, Tata Consultancy Services Ltd., India
Pornntiwa Pawara, Mahasarakham University, Thailand
Yong Peng, Hangzhou Dianzi University, China
Yusuf Perwej, Ambalika Institute of Management and Technology (AIMT), India
Olutomilayo Olayemi Petinrin, City University of Hong Kong, China
Arpan Phukan, Indian Institute of Technology Patna, India
Chiara Picardi Francesco Piccialli Josephine Plested Krishna Reddy Polepalli Dan Popescu Heru Praptono Mukesh Prasad Yamuna Prasad Krishna Prasadmiyapuram Partha Pratim Sarangi Emanuele Principi Dimeter Prodonov Ratchakoon Pruengkarn
Michal Ptaszynski Narinder Singh Punn Abhinanda Ranjit Punnakkal Zico Pratama Putra Zhenyue Qin Nawab Muhammad Faseeh Qureshi Md Rafiqul Saifur Rahaman Shri Rai Vartika Rai Kiran Raja Sutharshan Rajasegarar Arief Ramadhan Mallipeddi Rammohan Md. Mashud Rana Surangika Ranathunga Soumya Ranjan Mishra Hemant Rathore Imran Razzak Yazhou Ren Motahar Reza Dwiza Riana Bharat Richhariya
University of York, UK University of Naples Federico II, Italy University of New South Wales, Australia IIIT Hyderabad, India University Politehnica of Bucharest, Romania Bank Indonesia/UI, Indonesia University of Technology Sydney, Australia Thompson Rivers University, Canada IIT Gandhinagar, India KIIT Deemed to be University, India Università Politecnica delle Marche, Italy Imec, Belgium College of Innovative Technology and Engineering, Dhurakij Pundit University, Thailand Kitami Institute of Technology, Japan Mayo Clinic, Arizona, USA UiT The Arctic University of Norway, Norway Queen Mary University of London, UK Tencent, China SU, South Korea UTS, Australia City University of Hong Kong, China Murdoch University, Australia IIIT Hyderabad, India Norwegian University of Science and Technology, Norway Deakin University, Australia Bina Nusantara University, Indonesia Kyungpook National University, South Korea Commonwealth Scientific and Industrial Research Organisation (CSIRO), Australia University of Moratuwa, Sri Lanka KIIT University, India Birla Institute of Technology & Science, Pilani, India UNSW, Australia University of Science and Technology of China, China GITAM University Hyderabad, India STMIK Nusa Mandiri, Indonesia BITS Pilani, India
Pattabhi R. K. Rao Heejun Roh Vijay Rowtula Aniruddha Roy Sudipta Roy Narendra S. Chaudhari Fariza Sabrina Debanjan Sadhya Sumit Sah Atanu Saha Sajib Saha Snehanshu Saha Tulika Saha Navanath Saharia Pracheta Sahoo Sovan Kumar Sahoo Tanik Saikh Naveen Saini Fumiaki Saitoh Rohit Salgotra Michel Salomon Yu Sang
Suyash Sangwan Soubhagya Sankar Barpanda Jose A. Santos Kamal Sarkar Sandip Sarkar Naoyuki Sato Eri Sato-Shimokawara Sunil Saumya Gerald Schaefer Rafal Scherer Arvind Selwal Noor Akhmad Setiawan Mohammad Shahid Jie Shao
AU-KBC Research Centre, India Korea University, South Korea IIIT Hyderabad, India IIT Kharagpur, India Jio Institute, India Indian Institute of Technology Indore, India Central Queensland University, Australia ABV-IIITM Gwalior, India IIT Dharwad, India Jadavpur University, India Commonwealth Scientific and Industrial Research Organisation, Australia BITS Pilani K K Birla Goa Campus, India IIT Patna, India Indian Institute of Information Technology Manipur, India University of Texas at Dallas, USA Indian Institute of Technology Patna, India L3S Research Center, Germany Indian Institute of Information Technology Lucknow, India Chiba Institute of Technology, Japan Swansea University, UK Univ. Bourgogne Franche-Comté, France Research Institute of Institute of Computing Technology, Exploration and Development, Liaohe Oilfield, PetroChina, China Indian Institute of Technology Patna, India VIT-AP University, India Ulster University, UK Jadavpur University, India Jadavpur University, India Future University Hakodate, Japan Tokyo Metropolitan University, Japan Indian Institute of Information Technology Dharwad, India Loughborough University, UK Czestochowa University of Technology, Poland Central University of Jammu, India Universitas Gadjah Mada, Indonesia Aligarh Muslim University, India University of Science and Technology of China, China
Nabin Sharma Raksha Sharma Sourabh Sharma Suraj Sharma Ravi Shekhar Michael Sheng Yin Sheng Yongpan Sheng Liu Shenglan Tomohiro Shibata Iksoo Shin Mohd Fairuz Shiratuddin Hayaru Shouno Sanyam Shukla Udom Silparcha Apoorva Singh Divya Singh Gitanjali Singh Gopendra Singh K. P. Singh Navjot Singh Om Singh Pardeep Singh Rajiv Singh Sandhya Singh Smriti Singh Narinder Singhpunn Saaveethya Sivakumar Ferdous Sohel Chattrakul Sombattheera Lei Song Linqi Song Yuhua Song Gautam Srivastava Rajeev Srivastava Jérémie Sublime P. N. Suganthan
University of Technology Sydney, Australia IIT Bombay, India Avantika University, India International Institute of Information Technology Bhubaneswar, India Queen Mary University of London, UK Macquarie University, Australia Huazhong University of Science and Technology, China Southwest University, China Dalian University of Technology, China Kyushu Institute of Technology, Japan University of Science & Technology, China Murdoch University, Australia University of Electro-Communications, Japan MANIT, Bhopal, India KMUTT, Thailand Indian Institute of Technology Patna, India Central University of Bihar, India Indian Institute of Technology Patna, India Indian Institute of Technology Patna, India IIIT Allahabad, India IIIT Allahabad, India NIT Patna, India Jawaharlal Nehru University, India Banasthali Vidyapith, India Indian Institute of Technology Bombay, India IIT Bombay, India Mayo Clinic, Arizona, USA Curtin University, Malaysia Murdoch University, Australia Mahasarakham University, Thailand Unitec Institute of Technology, New Zealand City University of Hong Kong, China University of Science and Technology Beijing, China Brandon University, Canada Banaras Hindu University (IT-BHU), Varanasi, India ISEP - Institut Supérieur d’Électronique de Paris, France Nanyang Technological University, Singapore
Derwin Suhartono Indra Adji Sulistijono John Sum Fuchun Sun Ning Sun Anindya Sundar Das Bapi Raju Surampudi Olarik Surinta Maria Susan Anggreainy M. Syafrullah Murtaza Taj Norikazu Takahashi Abdelmalik Taleb-Ahmed Hakaru Tamukoh Choo Jun Tan Chuanqi Tan Shing Chiang Tan Xiao Jian Tan Xin Tan Ying Tan Gouhei Tanaka Yang Tang Zhiri Tang Tanveer Tarray Chee Siong Teh Ya-Wen Teng Gaurish Thakkar Medari Tham Selvarajah Thuseethan Shu Tian Massimo Tistarelli Abhisek Tiwari Uma Shanker Tiwary
Bina Nusantara University, Indonesia Politeknik Elektronika Negeri Surabaya (PENS), Indonesia National Chung Hsing University, Taiwan Tsinghua University, China Nankai University, China Indian Institute of Technology Patna, India International Institute of Information Technology Hyderabad, India Mahasarakham University, Thailand Bina Nusantara University, Indonesia Universitas Budi Luhur, Indonesia Lahore University of Management Sciences, Pakistan Okayama University, Japan Polytechnic University of Hauts-de-France, France Kyushu Institute of Technology, Japan Wawasan Open University, Malaysia BIT, China Multimedia University, Malaysia Tunku Abdul Rahman University of Management and Technology (TAR UMT), Malaysia East China Normal University, China Peking University, China University of Tokyo, Japan East China University of Science and Technology, China City University of Hong Kong, China Islamic University of Science and Technology, India Universiti Malaysia Sarawak (UNIMAS), Malaysia Academia Sinica, Taiwan University of Zagreb, Croatia St. Anthony’s College, India Sabaragamuwa University of Sri Lanka, Sri Lanka University of Science and Technology Beijing, China University of Sassari, Italy IIT Patna, India Indian Institute of Information Technology, Allahabad, India
Alex To Stefania Tomasiello Anh Duong Trinh Enkhtur Tsogbaatar Enmei Tu Eiji Uchino Prajna Upadhyay Sahand Vahidnia Ashwini Vaidya Deeksha Varshney Sowmini Devi Veeramachaneni Samudra Vijaya Surbhi Vijh Nhi N. Y. Vo Xuan-Son Vu Anil Kumar Vuppala Nobuhiko Wagatsuma Feng Wan Bingshu Wang Dianhui Wang Ding Wang Guanjin Wang Jiasen Wang Lei Wang Libo Wang Meng Wang Qiu-Feng Wang Sheng Wang Weiqun Wang Wentao Wang Yongyu Wang Zhijin Wang Bunthit Watanapa Yanling Wei Guanghui Wen Ari Wibisono Adi Wibowo Ka-Chun Wong
University of Sydney, Australia University of Tartu, Estonia Technological University Dublin, Ireland Mongolian University of Science and Technology, Mongolia Shanghai Jiao Tong University, China Yamaguchi University, Japan IIT Delhi, India University of New South Wales, Australia IIT Delhi, India Indian Institute of Technology, Patna, India Mahindra University, India Koneru Lakshmaiah Education Foundation, India JSS Academy of Technical Education, Noida, India University of Technology Sydney, Australia Umeå University, Sweden IIIT Hyderabad, India Toho University, Japan University of Macau, China Northwestern Polytechnical University Taicang Campus, China La Trobe University, Australia Beijing University of Technology, China Murdoch University, Australia City University of Hong Kong, China Beihang University, China Xiamen University of Technology, China Southeast University, China Xi’an Jiaotong-Liverpool University, China Henan University, China Institute of Automation, Chinese Academy of Sciences, China Michigan State University, USA Michigan Technological University, USA Jimei University, China KMUTT-SIT, Thailand TU Berlin, Germany RMIT University, Australia Universitas Indonesia, Indonesia Diponegoro University, Indonesia City University of Hong Kong, China
Kevin Wong Raymond Wong Kuntpong Woraratpanya Marcin Woźniak Chengwei Wu Jing Wu Weibin Wu Hongbing Xia Tao Xiang Qiang Xiao Guandong Xu Qing Xu Yifan Xu Junyu Xuan Hui Xue Saumitra Yadav Shekhar Yadav Sweta Yadav Tarun Yadav Shankai Yan Feidiao Yang Gang Yang Haiqin Yang Jianyi Yang Jinfu Yang Minghao Yang Shaofu Yang Wachira Yangyuen Xinye Yi Hang Yu Wen Yu Wenxin Yu Zhaoyuan Yu Ye Yuan Xiaodong Yue
Murdoch University, Australia Universiti Malaya, Malaysia King Mongkut’s Institute of Technology Ladkrabang (KMITL), Thailand Silesian University of Technology, Poland Harbin Institute of Technology, China Shanghai Jiao Tong University, China Sun Yat-sen University, China Beijing Normal University, China Chongqing University, China Huazhong University of Science and Technology, China University of Technology Sydney, Australia Tianjin University, China Huazhong University of Science and Technology, China University of Technology Sydney, Australia Southeast University, China IIIT-Hyderabad, India Madan Mohan Malaviya University of Technology, India University of Illinois at Chicago, USA Defence Research and Development Organisation, India Hainan University, China Microsoft, China Renmin University of China, China International Digital Economy Academy, China Shandong University, China BJUT, China Institute of Automation, Chinese Academy of Sciences, China Southeast University, China Rajamangala University of Technology Srivijaya, Thailand Guilin University of Electronic Technology, China Shanghai University, China Cinvestav, Mexico Southwest University of Science and Technology, China Nanjing Normal University, China Xi’an Jiaotong University, China Shanghai University, China
Aizan Zafar Jichuan Zeng Jie Zhang Shixiong Zhang Tianlin Zhang Mingbo Zhao Shenglin Zhao Guoqiang Zhong Jinghui Zhong Bo Zhou Yucheng Zhou Dengya Zhu Xuanying Zhu Hua Zuo
Indian Institute of Technology Patna, India Bytedance, China Newcastle University, UK Xidian University, China University of Manchester, UK Donghua University, China Zhejiang University, China Ocean University of China, China South China University of Technology, China Southwest University, China University of Technology Sydney, Australia Curtin University, Australia ANU, Australia University of Technology Sydney, Australia
Additional Reviewers
Acharya, Rajul Afrin, Mahbuba Alsuhaibani, Abdullah Amarnath Appicharla, Ramakrishna Arora, Ridhi Azar, Joseph Bai, Weiwei Bao, Xiwen Barawi, Mohamad Hardyman Bhat, Mohammad Idrees Bhat Cai, Taotao Cao, Feiqi Chakraborty, Bodhi Chang, Yu-Cheng Chen Chen, Jianpeng Chen, Yong Chhipa, Priyank Cho, Joshua Chongyang, Chen Cuenat, Stéphane Dang, Lili Das Chakladar, Debashis Das, Kishalay Dey, Monalisa
Doborjeh, Maryam Dong, Zhuben Dutta, Subhabrata Dybala, Pawel El Achkar, Charbel Feng, Zhengyang Galkowski, Tomasz Garg, Arpit Ghobakhlou, Akbar Ghosh, Soumitra Guo, Hui Gupta, Ankur Gupta, Deepak Gupta, Megha Han, Yanyang Han, Yiyan Hang, Bin Harshit He, Silu Hua, Ning Huang, Meng Huang, Rongting Huang, Xiuyu Hussain, Zawar Imran, Javed Islam, Md Rafiqul
Jain, Samir Jia, Mei Jiang, Jincen Jiang, Xiao Jiangyu, Wang Jiaxin, Lou Jiaxu, Hou Jinzhou, Bao Ju, Wei Kasyap, Harsh Katai, Zoltan Keserwani, Prateek Khan, Asif Khan, Muhammad Fawad Akbar Khari, Manju Kheiri, Kiana Kirk, Nathan Kiyani, Arslan Kolya, Anup Kumar Krdzavac, Nenad Kumar, Lov Kumar, Mukesh Kumar, Puneet Kumar, Rahul Kumar, Sunil Lan, Meng Lavangnananda, Kittichai Li, Qian Li, Xiaoou Li, Xin Li, Xinjia Liang, Mengnan Liang, Shuai Liquan, Li Liu, Boyang Liu, Chang Liu, Feng Liu, Linjing Liu, Xinglan Liu, Xinling Liu, Zhe Lotey, Taveena Ma, Bing Ma, Zeyu Madanian, Samaneh
Mahata, Sainik Kumar Mahmud, Md. Redowan Man, Jingtao Meena, Kunj Bihari Mishra, Pragnyaban Mistry, Sajib Modibbo, Umar Muhammad Na, Na Nag Choudhury, Somenath Nampalle, Kishore Nandi, Palash Neupane, Dhiraj Nigam, Nitika Nigam, Swati Ning, Jianbo Oumer, Jehad Pandey, Abhineet Kumar Pandey, Sandeep Paramita, Adi Suryaputra Paul, Apurba Petinrin, Olutomilayo Olayemi Phan Trong, Dat Pradana, Muhamad Hilmil Muchtar Aditya Pundhir, Anshul Rahman, Sheikh Shah Mohammad Motiur Rai, Sawan Rajesh, Bulla Rajput, Amitesh Singh Rao, Raghunandan K. R. Rathore, Santosh Singh Ray, Payel Roy, Satyaki Saini, Nikhil Saki, Mahdi Salimath, Nagesh Sang, Haiwei Shao, Jian Sharma, Anshul Sharma, Shivam Shi, Jichen Shi, Jun Shi, Kaize Shi, Li Singh, Nagendra Pratap Singh, Pritpal
Singh, Rituraj Singh, Shrey Singh, Tribhuvan Song, Meilun Song, Yuhua Soni, Bharat Stommel, Martin Su, Yanchi Sun, Xiaoxuan Suryodiningrat, Satrio Pradono Swarnkar, Mayank Tammewar, Aniruddha Tan, Xiaosu Tanoni, Giulia Tanwar, Vishesh Tao, Yuwen To, Alex Tran, Khuong Varshney, Ayush Vo, Anh-Khoa Vuppala, Anil Wang, Hui Wang, Kai Wang, Rui Wang, Xia Wang, Yansong
Wang, Yuan Wang, Yunhe Watanapa, Saowaluk Wenqian, Fan Xia, Hongbing Xie, Weidun Xiong, Wenxin Xu, Zhehao Xu, Zhikun Yan, Bosheng Yang, Haoran Yang, Jie Yang, Xin Yansui, Song Yu, Cunzhe Yu, Zhuohan Zandavi, Seid Miad Zeng, Longbin Zhang, Jane Zhang, Ruolan Zhang, Ziqi Zhao, Chen Zhou, Xinxin Zhou, Zihang Zhu, Liao Zhu, Linghui
Contents – Part IV
Theory and Algorithms I
Knowledge Transfer from Situation Evaluation to Multi-agent Reinforcement Learning (p. 3)
Min Chen, Zhiqiang Pu, Yi Pan, and Jianqiang Yi
Sequential Three-Way Rules Class-Overlap Under-Sampling Based on Fuzzy Hierarchical Subspace for Imbalanced Data (p. 15)
Qi Dai, Jian-wei Liu, and Jia-peng Yang
Two-Stage Multilayer Perceptron Hawkes Process (p. 28)
Xiang Xing, Jian-wei Liu, and Zi-hao Cheng
The Context Hierarchical Contrastive Learning for Time Series in Frequency Domain (p. 40)
Ming-hui Wang and Jian-wei Liu
Hawkes Process via Graph Contrastive Discriminant Representation Learning and Transformer Capturing Long-Term Dependencies (p. 53)
Ze Cao, Jian-wei Liu, and Zi-hao Cheng
A Temporal Consistency Enhancement Algorithm Based on Pixel Flicker Correction (p. 65)
Junfeng Meng, Qiwei Shen, Yangliu He, and Jianxin Liao
Data Representation and Clustering with Double Low-Rank Constraints (p. 79)
Haoming He, Deyu Zeng, Chris Ding, and Zongze Wu
RoMA: A Method for Neural Network Robustness Measurement and Assessment (p. 92)
Natan Levy and Guy Katz
Independent Relationship Detection for Real-Time Scene Graph Generation (p. 106)
Tianlei Jin, Wen Wang, Shiqiang Zhu, Xiangming Xi, Qiwei Meng, Zonghao Mu, and Wei Song
A Multi-label Feature Selection Method Based on Feature Graph with Ridge Regression and Eigenvector Centrality (p. 119)
Zhiwei Ye, Haichao Zhang, Mingwei Wang, and Qiyi He
O3 GPT: A Guidance-Oriented Periodic Testing Framework with Online Learning, Online Testing, and Online Feedback (p. 130)
Yimeng Ren, Yuhu Shang, Kun Liang, Xiankun Zhang, and Yiying Zhang
AFFSRN: Attention-Based Feature Fusion Super-Resolution Network (p. 142)
Yeguang Qin, Fengxiao Tang, Ming Zhao, and Yusen Zhu
Temporal-Sequential Learning with Columnar-Structured Spiking Neural Networks (p. 153)
Xiaoling Luo, Hanwen Liu, Yi Chen, Malu Zhang, and Hong Qu
Graph Attention Transformer Network for Robust Visual Tracking (p. 165)
Libo Wang, Si Chen, Zhen Wang, Da-Han Wang, and Shunzhi Zhu
GCL-KGE: Graph Contrastive Learning for Knowledge Graph Embedding (p. 177)
Qimeng Guo, Huajuan Duan, Chuanhao Dong, Peiyu Liu, and Liancheng Xu
Towards a Unified Benchmark for Reinforcement Learning in Sparse Reward Environments (p. 189)
Yongxin Kang, Enmin Zhao, Yifan Zang, Kai Li, and Junliang Xing
Effect of Logistic Activation Function and Multiplicative Input Noise on DNN-kWTA Model (p. 202)
Wenhao Lu, Chi-Sing Leung, and John Sum
A High-Speed SSVEP-Based Speller Using Continuous Spelling Method (p. 215)
Bang Xiong, Jiayang Huang, Bo Wan, Changhua Jiang, Kejia Su, and Fei Wang
AAT: Non-local Networks for Sim-to-Real Adversarial Augmentation Transfer (p. 227)
Mengzhu Wang, Shanshan Wang, Tianwei Yan, and Zhigang Luo
Aggregating Intra-class and Inter-class Information for Multi-label Text Classification (p. 239)
Xianze Wu, Dongyu Ru, Weinan Zhang, Yong Yu, and Ziming Feng
Fast Estimation of Multidimensional Regression Functions by the Parzen Kernel-Based Method (p. 251)
Tomasz Gałkowski and Adam Krzyżak
ReGAE: Graph Autoencoder Based on Recursive Neural Networks (p. 263)
Adam Małkowski, Jakub Grzechociński, and Paweł Wawrzyński
Efficient Uncertainty Quantification for Under-Constraint Prediction Following Learning Using MCMC (p. 275)
Gargi Roy and Dalia Chakrabarty
SMART: A Robustness Evaluation Framework for Neural Networks (p. 288)
Yuanchun Xiong and Baowen Zhang
Time-aware Quaternion Convolutional Network for Temporal Knowledge Graph Reasoning (p. 300)
Chong Mo, Ye Wang, Yan Jia, and Cui Luo
SumBART - An Improved BART Model for Abstractive Text Summarization (p. 313)
A. Vivek and V. Susheela Devi
Saliency-Guided Learned Image Compression for Object Detection (p. 324)
Haoxuan Xiong and Yuanyuan Xu
Multi-label Learning with Data Self-augmentation (p. 336)
Yuhang Ge, Xuegang Hu, Peipei Li, Haobo Wang, Junbo Zhao, and Junlong Li
MnRec: A News Recommendation Fusion Model Combining Multi-granularity Information (p. 348)
Laiping Cui, Zhenyu Yang, Guojing Liu, Yu Wang, and Kaiyang Ma
Infinite Label Selection Method for Mutil-label Classification (p. 361)
Yuchen Pan, Jun Li, and Jianhua Xu
Simultaneous Perturbation Method for Multi-task Weight Optimization in One-Shot Meta-learning (p. 373)
Andrei Boiarov, Kostiantyn Khabarlak, and Igor Yastrebov
Searching for Textual Adversarial Examples with Learned Strategy (p. 385)
Xiangzhe Guo, Ruidan Su, Shikui Tu, and Lei Xu
Multivariate Time Series Retrieval with Binary Coding from Transformer (p. 397)
Zehan Tan, Mingyu Zhao, Yun Wang, and Weidong Yang
Learning TSP Combinatorial Search and Optimization with Heuristic Search (p. 409)
Hua Yang and Ming Gu
A Joint Learning Model for Open Set Recognition with Post-processing (p. 420)
Qinglin Li, Guanyu Xing, and Yanli Liu
Cross-Layer Fusion for Feature Distillation (p. 433)
Honglin Zhu, Ning Jiang, Jialiang Tang, Xinlei Huang, Haifeng Qing, Wenqing Wu, and Peng Zhang
MCHPT: A Weakly Supervise Based Merchant Pre-trained Model (p. 446)
Zehua Zeng, Xiaohan She, Xuetao Qiu, Hongfeng Chai, and Yanming Yang
Progressive Latent Replay for Efficient Generative Rehearsal (p. 457)
Stanisław Pawlak, Filip Szatkowski, Michał Bortkiewicz, Jan Dubiński, and Tomasz Trzciński
Generalization Bounds for Set-to-Set Matching with Negative Sampling (p. 468)
Masanari Kimura
ADA: An Attention-Based Data Augmentation Approach to Handle Imbalanced Textual Datasets (p. 477)
Amit Kumar Sah and Muhammad Abulaish
Countering the Anti-detection Adversarial Attacks (p. 489)
Anjie Peng, Chenggang Li, Ping Zhu, Xiaofang Huang, Hui Zeng, and Wenxin Yu
Evolving Temporal Knowledge Graphs by Iterative Spatio-Temporal Walks (p. 501)
Hao Tang, Donghong Liu, Xinhai Xu, and Feng Zhang
Improving Knowledge Graph Embedding Using Dynamic Aggregation of Neighbor Information (p. 513)
Guangbin Wang, Yuxin Ding, Yiqi Su, Zihan Zhou, Yubin Ma, and Wen Qian
Generative Generalized Zero-Shot Learning Based on Auxiliary-Features (p. 526)
Weimin Sun and Gang Yang
Learning Stable Representations with Progressive Autoencoder (PAE) (p. 538)
Zhouzheng Li, Dongyan Miao, Junfeng Gao, and Kun Feng
Effect of Image Down-sampling on Detection of Adversarial Examples (p. 550)
Anjie Peng, Chenggang Li, Ping Zhu, Zhiyuan Wu, Kun Wang, Hui Zeng, and Wenxin Yu
Boosting the Robustness of Neural Networks with M-PGD (p. 562)
Chenghai He, Li Zhou, Kai Zhang, Hailing Li, Shoufeng Cao, Gang Xiong, and Xiaohang Zhang
StatMix: Data Augmentation Method that Relies on Image Statistics in Federated Learning (p. 574)
Dominik Lewy, Jacek Mańdziuk, Maria Ganzha, and Marcin Paprzycki
Classification by Components Including Chow’s Reject Option (p. 586)
Mehrdad Mohannazadeh Bakhtiari and Thomas Villmann
Community Discovery Algorithm Based on Improved Deep Sparse Autoencoder (p. 597)
Dianying Chen, Xuesong Jiang, Jun Chen, and Xiumei Wei
Fairly Constricted Multi-objective Particle Swarm Optimization (p. 610)
Anwesh Bhattacharya, Snehanshu Saha, and Nithin Nagaraj
Argument Classification with BERT Plus Contextual, Structural and Syntactic Features as Text (p. 622)
Umer Mushtaq and Jérémie Cabessa
Variance Reduction for Deep Q-Learning Using Stochastic Recursive Gradient (p. 634)
Haonan Jia, Xiao Zhang, Jun Xu, Wei Zeng, Hao Jiang, and Xiaohui Yan
Optimizing Knowledge Distillation via Shallow Texture Knowledge Transfer (p. 647)
Xinlei Huang, Jialiang Tang, Haifeng Qing, Honglin Zhu, Ning Jiang, Wenqing Wu, and Peng Zhang
Unsupervised Domain Adaptation Supplemented with Generated Images (p. 659)
S. Suryavardan, Viswanath Pulabaigari, and Rakesh Kumar Sanodiya
MAR2MIX: A Novel Model for Dynamic Problem in Multi-agent Reinforcement Learning (p. 671)
Gaoyun Fang, Yang Liu, Jing Liu, and Liang Song
Adversarial Training with Knowledge Distillation Considering Intermediate Representations in CNNs (p. 683)
Hikaru Higuchi, Satoshi Suzuki, and Hayaru Shouno
Deep Contrastive Multi-view Subspace Clustering (p. 692)
Lei Cheng, Yongyong Chen, and Zhongyun Hua
Author Index (p. 705)
Theory and Algorithms I
Knowledge Transfer from Situation Evaluation to Multi-agent Reinforcement Learning
Min Chen1,2, Zhiqiang Pu2(B), Yi Pan2, and Jianqiang Yi2
1 School of Artificial Intelligence, University of Chinese Academy of Sciences, Beijing, China
2 Institute of Automation, Chinese Academy of Sciences, Beijing, China
Abstract. Recently, multi-agent reinforcement learning (MARL) has achieved impressive performance on complex tasks. However, it still suffers from sparse rewards and from the contradiction between consistent cognition and policy diversity. In this paper, we propose novel methods for transferring knowledge from a situation evaluation task to MARL tasks. Specifically, we use offline data from a single-agent scenario to train two situation evaluation models for: (1) constructing guiding dense rewards (GDR) in multi-agent scenarios, which help agents reach the real sparse rewards faster and escape locally optimal policies without changing the globally optimal policy; (2) transferring a situation comprehension network (SCN) to multi-agent scenarios, which balances the contradiction between consistent cognition and policy diversity among agents. Our methods can easily be combined with existing MARL methods. Empirical results show that our methods achieve state-of-the-art performance on Google Research Football, which combines all of the above challenges.
Keywords: Multi-agent reinforcement learning · Transfer learning · Football
1 Introduction
Deep reinforcement learning (DRL) has achieved super-human performance on complex games [2,16,28]. However, in the multi-agent setting, task complexity grows exponentially with the number of agents. In addition, some tasks provide only sparse rewards and require agents to develop diverse behaviors while cooperating. These challenges make it difficult for agents to learn a robust and satisfactory policy.
For the challenge of sparse rewards, dense rewards are designed in three major ways to guide the learning process of agents. Human knowledge rewards [13,25] give agents a numerical reward when they complete sub-goals that, in the designers' experience, help with the original task. Intrinsic rewards [3,30,35] encourage agents to exhibit certain characteristics, such as curiosity and diversity. However, these two kinds of reward shaping can change the optimal policy of the agents. [17] proposes potential-based reward shaping (PBRS), which guarantees that the optimal policy remains unchanged in the single-agent setting. [31] additionally considers the effect of agent actions on PBRS, and [7] considers dynamic potentials. [9] gives a method to translate an arbitrary reward function into PBRS. [6] extends PBRS to the multi-agent setting and proves that the Nash equilibria of the underlying stochastic game are not modified. However, all three kinds of reward shaping require careful hyper-parameter tuning; otherwise they may produce worse results. A commonly adopted solution is to treat the hyper-parameter adjustment of reward shaping as a second optimization objective in addition to the reinforcement learning task. [11] uses population-based optimization to learn the optimal size of human knowledge rewards by assigning different parameters to agents and iteratively improving the best reward. [34] proposes a scalable meta-gradient framework for learning useful intrinsic reward functions. [10] formulates the utilization of PBRS as a bi-level optimization problem that adjusts the weights of the various rewards. However, these methods are difficult to deploy on complex multi-agent tasks because of the exponentially growing complexity.
As for the contradiction between consistent cognition and policy diversity, policy decentralization with shared parameters (PDSP) has been widely used [15,26] to speed up training and endow agents with a consistent cognition of their task. However, it also makes it difficult for individual agents to develop distinct personalities. Some approaches encourage player roles [30] and diversity [4] by introducing intrinsic rewards. As mentioned above, these approaches change the original optimization objective, which makes agents less eager to complete their original tasks.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 3–14, 2023.
https://doi.org/10.1007/978-981-99-1639-9_1
Transfer learning studies how to use the knowledge learned in a source domain to improve performance in a target domain. Because deep learning has strong representation learning capability and transferability, transfer learning has achieved many successes in supervised learning [5,20,23,27]. RL tasks are likewise expected to benefit from related source tasks.
Based on the above discussion, to address the challenges of sparse rewards and of the contradiction between consistent cognition and policy diversity in MARL, we transfer the knowledge learned from an offline dataset of a single-agent scenario to multi-agent scenarios. Specifically, we (1) construct guiding dense rewards (GDR) in multi-agent scenarios to help agents reach the real sparse rewards faster and escape locally optimal policies without changing the globally optimal policy; (2) transfer a situation comprehension network (SCN) to balance the contradiction between consistent cognition and policy diversity among agents. Empirical results show that our methods achieve state-of-the-art performance on Google Research Football (GRF) [12], which combines all of the above challenges.
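For readers unfamiliar with PBRS [17], the minimal sketch below adds the shaping term F(s, s') = γΦ(s') − Φ(s) to each environment reward. It then checks numerically why a fixed start state keeps its optimal policy: over an episode whose terminal potential is zero, the discounted shaping terms telescope to −Φ(s_0). That quantity is a constant that does not depend on the actions taken. The toy chain of integer states and the potential function `phi` are hypothetical. This illustrates generic PBRS, not the GDR construction used in this paper.

```python
from typing import Callable, List, Sequence


def discounted_sum(values: Sequence[float], gamma: float) -> float:
    """Return sum_t gamma**t * values[t]."""
    return sum((gamma ** t) * v for t, v in enumerate(values))


def shaped_rewards(states: Sequence[int], rewards: Sequence[float],
                   phi: Callable[[int], float], gamma: float) -> List[float]:
    """Add the PBRS term F(s_t, s_{t+1}) = gamma * phi(s_{t+1}) - phi(s_t) to each r_t.

    `states` holds one more entry than `rewards`, because it also contains
    the state reached after the last transition.
    """
    if len(states) != len(rewards) + 1:
        raise ValueError("states must have exactly one more entry than rewards")
    return [r + gamma * phi(states[t + 1]) - phi(states[t])
            for t, r in enumerate(rewards)]


TERMINAL = 3  # hypothetical terminal state of the toy chain 0 -> 1 -> 2 -> 3


def phi(state: int) -> float:
    """Hypothetical potential: grows toward the goal and is zero at the terminal state."""
    return 0.0 if state == TERMINAL else 0.2 * (state + 1)


if __name__ == "__main__":
    gamma = 0.9
    states = [0, 1, 2, TERMINAL]
    rewards = [0.0, 0.0, 1.0]  # sparse: only the final transition is rewarded
    shaped = shaped_rewards(states, rewards, phi, gamma)
    print("dense shaped rewards:", [round(x, 3) for x in shaped])
    print("original return:", round(discounted_sum(rewards, gamma), 3))
    print("shaped return:", round(discounted_sum(shaped, gamma), 3))
```

In this toy run, the sparse reward sequence becomes a dense one, [0.16, 0.14, 0.4]. The shaped return is 0.61, which is the original return 0.81 minus Φ(s_0) = 0.2, exactly as the telescoping argument predicts.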
2 Related Works
Advanced deep MARL approaches include value-based [21,24,29] and policy-gradient-based [14,33] algorithms. In principle, our methods can be
Knowledge Transfer from Situation Evaluation to MARL
combined with any of the above approaches; we choose to validate them on MAPPO because it performs better than the others on sparse-reward tasks. Although constructing GDR resembles inverse reinforcement learning (IRL) [1,8,19], in that both construct dense rewards from offline data, GDR differs fundamentally from IRL. The goal of IRL is to imitate the offline policy, and it assumes that this policy is already optimal. In contrast, we keep the original task goal and impose no strict requirement on the performance of the offline policy. The transfer of the SCN is similar to the pre-training and fine-tuning paradigm [32]: both reuse a network from an upstream task in a downstream task.
3 Background
Multi-agent reinforcement learning can be formulated as a Dec-POMDP [18], defined as a tuple $\langle N, S, A, P, R, O, \Omega, n, \gamma \rangle$. At each step, the environment is characterized by a state $s \in S$. Each of the $n$ agents in $N$ receives a local observation $o \in \Omega$ according to the observation function $O$, and the agents then choose a joint action $a \in A$ according to their joint policy $\pi$. The environment transitions to a new state $s'$ according to the state transition function $P$ and gives the agents a reward $r$ according to the reward function $R$. The task of the agents is to learn the optimal joint policy that maximizes the expected accumulated reward $\mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^{t} r_t\right]$ with discount factor $\gamma$.
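As a small illustration of the objective above, the following Python sketch computes the discounted return $\sum_t \gamma^t r_t$ of a single episode, truncated to a finite length (an assumption for illustration). The function name `discounted_return` is hypothetical, and the example rewards only mimic a sparse GRF-style reward in which the goal step is the sole non-zero entry.

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**t * r_t over one finite episode (hypothetical helper)."""
    total = 0.0
    # Walk backwards so each step is one multiply-add: G_t = r_t + gamma * G_{t+1}.
    for r in reversed(rewards):
        total = r + gamma * total
    return total


# Sparse reward: a single goal (+1) at the fourth and last step.
print(discounted_return([0.0, 0.0, 0.0, 1.0], gamma=0.99))  # equals 0.99**3 up to rounding
```

The backward recursion avoids computing powers of $\gamma$ explicitly and makes visible why a sparse reward far in the future contributes only a heavily discounted signal to early steps.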
4 Methods
In this section, we introduce the details of GDR and SCN, and illustrate and verify our methods with a football game. Note that our methods are also applicable to other tasks that have sparse rewards and require cooperation and diversity among agents.
4.1 Situation Evaluation Task
The situation evaluation task is defined as follows.
– Given: a vector representing the game state $s_t$.
– Do: estimate $\varphi(s_t)$, the value of $s_t$.
Once the form of $s_t$, $\varphi(s_t)$, and the learning algorithm are determined, a corresponding model is obtained. To collect data for this task, we train an agent with the RL algorithm PPO [22] in the single-agent 11 vs. 11 scenario and test it against the easy baseline, the hard baseline, and itself, collecting 3000 episodes for each team combination. For each episode, we record all raw observations, which consist of:
– player information: the two-dimensional position, two-dimensional speed, fatigue value, and role of every player in the game.
– ball information: the position, speed, rotation speed, and ball possession, which indicates the team and player controlling the ball.
M. Chen et al.
– game information: the score, remaining time, and game mode, which is one of normal, kickoff, goal-kick, free-kick, corner, throw-in, and penalty.
The score information is used to label samples, and the other information can be combined into $s_t$.
4.2 Construction of Guiding Dense Reward
The first situation evaluation model is used to construct the guiding dense reward. First, we label samples as

$$\varphi(s_t) = \gamma^{t - t_g}\,\mathrm{Sign}(t_g) \tag{1}$$

where $t_g$ is the time of the first goal scored after $t$, and $\mathrm{Sign}(t_g)$ is the sign function: it is 1 if the home team scores at $t_g$ and −1 otherwise. That is, the absolute value of a game state is positively related to the number of steps required to score. We then choose the two-dimensional position of the ball to represent the game state $s_t$, divide the pitch into grids, and take the mean value of each grid as its expected state value:

$$\varphi(G) = \frac{1}{n} \sum_{s_t \in G} \varphi(s_t) \tag{2}$$
where $G$ is a grid cell and $n$ is the number of samples in $G$. In addition, as shown in Fig. 1, we make the situation evaluation symmetric in the horizontal direction, because the agent that generated the offline data excels at scoring in the bottom half of the pitch; this is a bias in that agent's knowledge which needs to be removed.
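Equation (2) and the horizontal symmetrization can be sketched as follows, starting from samples already labelled with Eq. (1). The pitch bounds, the 20 × 8 grid resolution, the helper names, and the rule of averaging each cell with its mirror image across the centre line are all assumptions made for illustration; the text does not state the grid size or exactly how symmetry is enforced.

```python
from collections import defaultdict

# Assumption: GRF ball coordinates lie in x in [-1, 1] and y in [-0.42, 0.42].
X_MIN, X_MAX, Y_MIN, Y_MAX = -1.0, 1.0, -0.42, 0.42


def grid_index(x, y, nx, ny):
    """Map a ball position to an (ix, iy) cell of an nx-by-ny grid, clamped to the pitch."""
    ix = int((x - X_MIN) / (X_MAX - X_MIN) * nx)
    iy = int((y - Y_MIN) / (Y_MAX - Y_MIN) * ny)
    return min(max(ix, 0), nx - 1), min(max(iy, 0), ny - 1)


def grid_potentials(samples, nx=20, ny=8):
    """Eq. (2): average the labelled values phi(s_t) of all samples whose ball lies in each cell."""
    sums, counts = defaultdict(float), defaultdict(int)
    for x, y, value in samples:
        cell = grid_index(x, y, nx, ny)
        sums[cell] += value
        counts[cell] += 1
    return {cell: sums[cell] / counts[cell] for cell in sums}


def symmetrize(phi, ny):
    """Average each cell with its mirror image across the horizontal centre line (y -> -y)."""
    out = {}
    for (ix, iy), v in phi.items():
        mirror = (ix, ny - 1 - iy)
        m = phi.get(mirror, v)  # an empty mirror cell borrows this cell's value
        out[(ix, iy)] = out[mirror] = (v + m) / 2
    return out


# Usage with made-up labelled samples (x, y, phi(s_t)).
samples = [(0.85, -0.3, 1.0), (0.85, -0.3, 0.5), (0.85, 0.3, 0.0)]
phi = grid_potentials(samples)
sym = symmetrize(phi, ny=8)
```

Because the assumed y-range is symmetric about 0, row `iy` and row `ny - 1 - iy` are exact mirror images, so averaging them removes any top/bottom preference in the offline data.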
(a) Asymmetric state value.
(b) Symmetrical state value.
Fig. 1. State value visualization (the axis limits coincide with the pitch boundary of GRF). The color of each point represents the state value of dribbling the ball there.
At each step, agents obtain a PBRS reward $F(s_t, s_{t+1})$ for the transition of the game state from $s_t$ to $s_{t+1}$:

$$F(s_t, s_{t+1}) = \gamma\,\varphi(G') - \varphi(G) \tag{3}$$

where $s_t$ is the game state at time $t$, $\varphi(\cdot)$ the potential, $\gamma$ the discount factor, and $G$ and $G'$ the grids containing $s_t$ and $s_{t+1}$, respectively.
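Equation (3) can be sketched directly on a table of grid potentials such as the one produced by Eq. (2). The function names and the zero default for cells without offline samples are assumptions. The usage example checks the standard PBRS property from [17] that the discounted sum of shaping rewards along a trajectory telescopes to $\gamma^{T}\varphi(G_T) - \varphi(G_0)$, which depends only on the first and last cells.

```python
def pbrs_reward(phi, cell, next_cell, gamma, default=0.0):
    """Eq. (3): F(s_t, s_{t+1}) = gamma * phi(G') - phi(G).

    `phi` maps grid cells to potentials; cells never seen in the offline data
    fall back to `default` (an assumption, not specified in the text).
    """
    return gamma * phi.get(next_cell, default) - phi.get(cell, default)


def shaped_return(phi, cells, gamma):
    """Discounted sum of the shaping rewards along a trajectory of grid cells."""
    return sum(
        gamma ** t * pbrs_reward(phi, cells[t], cells[t + 1], gamma)
        for t in range(len(cells) - 1)
    )


phi = {(0, 0): 0.1, (1, 0): 0.3, (2, 0): 0.6}
trajectory = [(0, 0), (1, 0), (2, 0), (1, 0)]
# The sum telescopes to gamma**T * phi(G_T) - phi(G_0) with T = 3 transitions.
print(shaped_return(phi, trajectory, 0.99), 0.99 ** 3 * 0.3 - 0.1)
```

The telescoping is why PBRS leaves the optimal policy unchanged: intermediate potentials cancel, so the shaping cannot reward a detour for its own sake.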
4.3 Transfer of Situation Comprehension Network
The second situation evaluation model is used to transfer the situation comprehension network. The model input $s_t$ includes the position, velocity, and fatigue factor of the 22 players; the position, velocity, and rotation of the ball; ball possession; the scores of the two teams; and the game mode. The model contains five fully connected layers, and its output is the probability of scoring before ball possession changes. We divide the observation of each agent into two parts: information related to the game situation, which is the same as the input of the situation evaluation network mentioned above, and private information. Transferring deep network structures among similar tasks is widely practised and benefits from the strong representation-learning ability of deep networks. Much research suggests that the low layers of a deep network output general semantics, the middle layers output hidden variables, and the high layers output task-specific semantics. Therefore, if the right layers of the situation evaluation network are extracted, they can be transferred to the RL task as a situation comprehension network that turns the raw situation observation into situation comprehension information. Hence, we let agents share the situation comprehension network (the first several layers of the situation evaluation network). Once agents obtain their raw observations, the situation comprehension information given by the SCN and the private information are concatenated into the new observation. Each agent then trains its own actor separately, which balances consistent cognition and policy diversity among agents.
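The observation pipeline described above can be sketched in plain Python. This is a minimal, dependency-free illustration under stated assumptions: the layer widths, the random weights standing in for the pre-trained evaluator, the ReLU activations, and the helper names (`make_situation_evaluator`, `SituationComprehension`, `build_observation`) are all hypothetical; a real implementation would load trained weights and use a deep-learning framework. The sketch shows only the data flow: one shared SCN built from the first k layers, whose output is concatenated with each agent's private information before it reaches that agent's own actor.

```python
import random


def relu(v):
    return [max(0.0, a) for a in v]


def dense(layer, v):
    """Fully connected layer y = W v + b; `layer` is a (W, b) pair, W stored as a list of rows."""
    W, b = layer
    return [sum(w * a for w, a in zip(row, v)) + bias for row, bias in zip(W, b)]


def init_layer(n_in, n_out, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)], [0.0] * n_out


def make_situation_evaluator(sizes, rng):
    """Stand-in for the pre-trained 5-layer evaluator; `sizes` lists widths from input to output."""
    return [init_layer(a, b, rng) for a, b in zip(sizes[:-1], sizes[1:])]


class SituationComprehension:
    """Shared SCN (hypothetical): the first k layers of the evaluator, used by every agent."""

    def __init__(self, evaluator, k):
        self.layers = evaluator[:k]

    def __call__(self, situation):
        h = situation
        for layer in self.layers:
            h = relu(dense(layer, h))
        return h


def build_observation(scn, situation, private):
    """New observation = SCN features of the situation part + the agent's private part."""
    return scn(situation) + list(private)


rng = random.Random(0)
evaluator = make_situation_evaluator([16, 32, 32, 16, 8, 1], rng)  # hypothetical widths
scn = SituationComprehension(evaluator, k=2)  # transfer the first two layers
obs = build_observation(scn, [0.1] * 16, [1.0, 0.0, 0.5])
print(len(obs))  # 32 SCN features + 3 private values = 35
```

Changing `k` corresponds to the number of transferred layers studied in Sect. 5.3; whether the shared layers are frozen or fine-tuned during MARL training is left open in this sketch.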
5 Experiments
In this section, we validate our methods on multi-agent scenarios of GRF to illustrate the effectiveness of GDR and SCN.
5.1 Performance on Google Research Football
We validate our methods on the following seven multi-agent scenarios: (a) academy 3 vs 1 with keeper, (b) academy pass and shoot with keeper, (c) academy run pass and shoot with keeper, (d) academy counterattack hard, (e) academy corner, (f) 5 vs 5, and (g) 11 vs 11 easy stochastic (we control four agents, and the other seven agents are controlled by the built-in rules of GRF). In (a)–(e), agents need to learn offensive skills to score a goal within a short time span. In (f) and (g), agents need to learn offensive and defensive skills to score more goals than the opponents over a whole match. To verify the effectiveness of our methods, we compare the following four methods, where SCN is used only in the scenarios with 22 players on the pitch (d, e, g):
Fig. 2. Comparison of our approach against baseline algorithms on Google Research Football.
– Score: agents obtain a sparse reward of +1 when they score a goal and −1 when the opponents score.
– Checkpoint: the built-in human-knowledge reward of GRF. It divides the pitch into 10 zones based on the Euclidean distance from the goal, and agents obtain a reward of 0.1 the first time they possess the ball in each zone, which encourages them to run toward the opponents' goal with the ball.
– GDR: agents obtain, at each step, the dense reward constructed in Sect. 4.2.
– GDR+SCN: agents share the SCN to process situation information and obtain GDR at each step.
As can be seen in Fig. 2, the winning rate of our methods is significantly higher than that of the baselines (Score and Checkpoint) in all scenarios. For Score, in simple tasks (a, b, c, e) agents tend to fall into the first locally optimal policy they encounter because of the sparse reward: they find that sparse rewards can only be obtained by adopting this policy, and, without the guidance of a dense reward, they gradually reinforce it and eventually converge to it. In complex tasks (d, f), agents can hardly sample a long enough sequence of good actions to reach the sparse rewards, so most samples are invalid. Checkpoint plays a guiding role in some scenarios (d, f, g). However, it changes the optimization objective, which leads to mediocre converged policies. Moreover, because human knowledge is biased and incomplete, Checkpoint even hinders the learning of agents in some scenarios (a, b, e). The role of GDR can be summarized as: (1) helping agents escape random locally optimal policies in simple tasks (a, b, c, e); and (2) gradually guiding agents toward the real sparse rewards and then continuously stimulating them to explore better policies in difficult tasks (d, f, g). As shown in Fig. 3, in academy pass and shoot with keeper, a simple scenario, the learning process of both Score-agents and GDR-agents is divided into two stages. In the
Fig. 3. The test results of the same seed trained with Score and GDR in academy pass and shoot with keeper. (a) The curve of winning rate over time. (b), (c) The curve of average shot position over time.
first stage, the Score-agents obtain sparse rewards by randomly sampling actions. At the end of the first stage, the Score-agents find a locally optimal policy with a 30% test winning rate by reinforcing the state-action pairs that obtained real sparse rewards. The effect of GDR is approximately equivalent to specifying different paths for agents to reach sparse rewards. Therefore, unlike the Score-agents, the GDR-agents learn to approximate GDR with their critics instead of relying entirely on randomly sampled actions to discover sparse rewards. Hence, although both kinds of agents appear to learn the same policy at the end of the first stage, the GDR-agents have in fact collected many more valid samples than the Score-agents. In the second stage, the GDR-agents quickly escape the local optimum, while the Score-agents remain trapped in it. Academy counterattack hard is a typical difficult scenario, which requires more steps and more cooperation among players than easy tasks. For Score-agents, it is difficult to find the real sparse rewards by randomly sampling actions; as can be seen in Fig. 4, their winning rate stays below 10% during the entire learning process. The learning process of GDR-agents, in contrast, is naturally divided into two stages under the guidance of GDR. In the first stage, GDR-agents mainly focus on dribbling skills to pursue higher dense rewards, deliberately learning to approach the opponents' goal. Although shooting is the key action of this task, GDR-agents
Fig. 4. The test results of the same seed trained with Score and GDR in academy counterattack hard. (a) The curve of winning rate over time. (b), (d) The curve of average dribbling position over time. (c) The curve of the frequency of the shot action over time. The shot action is blocked when agents are too far from the goal; Score-agents do not shoot at the beginning of training because they are always far from the goal.
gradually reduce the frequency of shooting. They find that shooting from the positions they have dribbled to is not good enough to obtain real sparse rewards, and instead makes the ball go out of control and leads to lower dense rewards. At the end of the first stage, GDR-agents have found good shooting positions, from which they may obtain real sparse rewards when they shoot; hence the shooting frequency increases sharply. In the second stage, GDR-agents mainly focus on shooting skills to pursue higher sparse rewards, and the winning rate increases rapidly.
5.2 Parametric Sensitivity of GDR
We study the parameter sensitivity of GDR in two scenarios to show that GDR does not require precise tuning of its parameter. As can be seen in Fig. 5, regardless of the scale of GDR, the winning rate is not significantly
affected. The relative size of potentials in different states is more important than their absolute size: the role of GDR is to guide agents from lower-potential states to higher ones. The parameter only affects the absolute values, while the relative relationships are learned from offline data.
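Under the definition of Eq. (3), the scaling argument can be made explicit with a short derivation (a standard PBRS identity, written here for a finite episode of $T$ transitions, with $G_t$ the grid containing $s_t$). The discounted sum of shaping rewards telescopes, and scaling the potential by a constant $c > 0$ scales every shaping reward, and hence the whole sum, by $c$:

$$
\sum_{t=0}^{T-1} \gamma^{t} F(s_t, s_{t+1})
= \sum_{t=0}^{T-1} \left[\gamma^{t+1}\varphi(G_{t+1}) - \gamma^{t}\varphi(G_t)\right]
= \gamma^{T}\varphi(G_T) - \varphi(G_0),
\qquad
\varphi \mapsto c\,\varphi \;\Rightarrow\; F \mapsto c\,F .
$$

Because $c > 0$, the sign of $\gamma\varphi(G') - \varphi(G)$ for every transition and the ordering of states by potential are both unchanged, so the direction in which GDR guides the agents does not depend on its scale.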
Fig. 5. We train agents respectively with 1, 0.5 and 0.25 times GDR in (a) academy 3 vs 1 with keeper and (b) academy pass and shoot with keeper, and contrast with Score-agents.
5.3 Transfer Layers Study
As can be seen in Fig. 6, when one or two layers are transferred, performance in both the academy counterattack hard and academy corner scenarios is better than without transfer, because the low layers of deep networks learn general features. On the one hand, these features carry deeper semantic information than raw observations; on the other hand, they are not limited to a specific task. Therefore, transferring low layers is effective and general. In contrast, deep layers (the third and fourth layers in our task) learn abstract features related to the specific task. When three or four layers are transferred, agents perform exceptionally well in academy corner but poorly in academy counterattack hard. The former is a fast-paced task, so its state space is smaller than that of the latter, and the deep-layer features of the situation evaluation network happen to suit this specific MARL task. In counterattack, by contrast, it is difficult for these features to remain unbiased over such a large state space.
Fig. 6. We respectively transfer 1, 2, 3, 4 layers from situation evaluation network as situation comprehension network of GDR-agents in (a) academy counterattack hard and (b) academy corner, and contrast with GDR-agents.
6 Conclusion
In this paper, we study viable solutions to two challenges in MARL: (1) sparse rewards and (2) the contradiction between consistent cognition and policy diversity. On the one hand, the output of situation evaluation models can be used to construct guiding dense rewards. On the other hand, the situation evaluation network can be transferred to the MARL task as a situation comprehension network. Our method can be easily combined with existing MARL methods. Acknowledgment. This work was supported by the National Key Research and Development Program of China under Grant 2020AAA0103404 and the National Natural Science Foundation of China under Grant 62073323.
References
1. Arora, S., Doshi, P.: A survey of inverse reinforcement learning: challenges, methods and progress. Artif. Intell. 297(C), 103500 (2021)
2. Berner, C., et al.: Dota 2 with large scale deep reinforcement learning. arXiv preprint arXiv:1912.06680 (2019)
3. Burda, Y., Edwards, H., Pathak, D., Storkey, A., Darrell, T., Efros, A.A.: Large-scale study of curiosity-driven learning. In: International Conference on Learning Representations (2018)
4. Chenghao, L., Wang, T., Wu, C., Zhao, Q., Yang, J., Zhang, C.: Celebrating diversity in shared multi-agent reinforcement learning. In: Advances in Neural Information Processing Systems, vol. 34 (2021)
5. Devlin, J., Chang, M.W., Lee, K., Toutanova, K.: BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805 (2018)
6. Devlin, S., Kudenko, D.: Theoretical considerations of potential-based reward shaping for multi-agent systems. In: The 10th International Conference on Autonomous Agents and Multiagent Systems, pp. 225–232. ACM (2011)
7. Devlin, S.M., Kudenko, D.: Dynamic potential-based reward shaping. In: Proceedings of the 11th International Conference on Autonomous Agents and Multiagent Systems, pp. 433–440. IFAAMAS (2012)
8. Finn, C., Levine, S., Abbeel, P.: Guided cost learning: deep inverse optimal control via policy optimization. In: International Conference on Machine Learning, pp. 49–58. PMLR (2016)
9. Harutyunyan, A., Devlin, S., Vrancx, P., Nowé, A.: Expressing arbitrary reward functions as potential-based advice. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 29 (2015)
10. Hu, Y., et al.: Learning to utilize shaping rewards: a new approach of reward shaping. Adv. Neural Inf. Process. Syst. 33, 15931–15941 (2020)
11. Jaderberg, M., et al.: Human-level performance in 3D multiplayer games with population-based reinforcement learning. Science 364(6443), 859–865 (2019)
12. Kurach, K., et al.: Google research football: a novel reinforcement learning environment. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, pp. 4501–4510 (2020)
13. Lample, G., Chaplot, D.S.: Playing FPS games with deep reinforcement learning. In: Thirty-First AAAI Conference on Artificial Intelligence (2017)
14. Lowe, R., Wu, Y.I., Tamar, A., Harb, J., Pieter Abbeel, O., Mordatch, I.: Multi-agent actor-critic for mixed cooperative-competitive environments. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
15. Ma, X., Yang, Y., Li, C., Lu, Y., Zhao, Q., Jun, Y.: Modeling the interaction between agents in cooperative multi-agent reinforcement learning. arXiv preprint arXiv:2102.06042 (2021)
16. Mnih, V., et al.: Playing Atari with deep reinforcement learning. arXiv preprint arXiv:1312.5602 (2013)
17. Ng, A.Y., Harada, D., Russell, S.: Policy invariance under reward transformations: theory and application to reward shaping. In: ICML, vol. 99, pp. 278–287 (1999)
18. Oliehoek, F.A., Amato, C.: A concise introduction to decentralized POMDPs. BRIEFSINSY, Springer, Cham (2016). https://doi.org/10.1007/978-3-319-28929-8
19. Peng, X.B., Kanazawa, A., Toyer, S., Abbeel, P., Levine, S.: Variational discriminator bottleneck: improving imitation learning, inverse RL, and GANs by constraining information flow (2020)
20. Radford, A., Narasimhan, K., Salimans, T., Sutskever, I.: Improving language understanding by generative pre-training (2018)
21. Rashid, T., Samvelyan, M., Schroeder, C., Farquhar, G., Foerster, J., Whiteson, S.: QMIX: monotonic value function factorisation for deep multi-agent reinforcement learning. In: International Conference on Machine Learning, pp. 4295–4304. PMLR (2018)
22. Schulman, J., Wolski, F., Dhariwal, P., Radford, A., Klimov, O.: Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347 (2017)
23. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556 (2014)
24. Son, K., Kim, D., Kang, W.J., Hostallero, D.E., Yi, Y.: QTRAN: learning to factorize with transformation for cooperative multi-agent reinforcement learning. In: International Conference on Machine Learning, pp. 5887–5896. PMLR (2019)
25. Song, S., Weng, J., Su, H., Yan, D., Zou, H., Zhu, J.: Playing FPS games with environment-aware hierarchical reinforcement learning. In: IJCAI, pp. 3475–3482 (2019)
26. Sunehag, P., et al.: Value-decomposition networks for cooperative multi-agent learning. arXiv preprint arXiv:1706.05296 (2017)
27. Vaswani, A., et al.: Attention is all you need. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
28. Vinyals, O., Babuschkin, I., Czarnecki, W.M., Mathieu, M., Dudzik, A., Chung, J., Choi, D.H., Powell, R., Ewalds, T., Georgiev, P., et al.: Grandmaster level in StarCraft II using multi-agent reinforcement learning. Nature 575(7782), 350–354 (2019)
29. Wang, J., Ren, Z., Liu, T., Yu, Y., Zhang, C.: QPLEX: duplex dueling multi-agent Q-learning. arXiv preprint arXiv:2008.01062 (2020)
30. Wang, T., Dong, H., Lesser, V., Zhang, C.: ROMA: multi-agent reinforcement learning with emergent roles. In: International Conference on Machine Learning, pp. 9876–9886. PMLR (2020)
31. Wiewiora, E., Cottrell, G.W., Elkan, C.: Principled methods for advising reinforcement learning agents. In: Proceedings of the 20th International Conference on Machine Learning (ICML-2003), pp. 792–799 (2003)
32. Yosinski, J., Clune, J., Bengio, Y., Lipson, H.: How transferable are features in deep neural networks? In: Advances in Neural Information Processing Systems, vol. 27 (2014)
33. Yu, C., Velu, A., Vinitsky, E., Wang, Y., Bayen, A., Wu, Y.: The surprising effectiveness of PPO in cooperative, multi-agent games. arXiv preprint arXiv:2103.01955 (2021)
34. Zheng, Z., et al.: What can learned intrinsic rewards capture? In: International Conference on Machine Learning, pp. 11436–11446. PMLR (2020)
35. Zheng, Z., Oh, J., Singh, S.: On learning intrinsic rewards for policy gradient methods. In: Advances in Neural Information Processing Systems, vol. 31 (2018)
Sequential Three-Way Rules Class-Overlap Under-Sampling Based on Fuzzy Hierarchical Subspace for Imbalanced Data
Qi Dai1, Jian-wei Liu1(B), and Jia-peng Yang2
1 Department of Automation, China University of Petroleum, Beijing, China
2 College of Science, North China University of Science and Technology, Tangshan, China
Abstract. Imbalanced data classification is one of the most critical challenges in data mining. State-of-the-art class-overlap under-sampling algorithms assume that the majority-class nearest neighbors of minority instances are the most prone to class overlap. When the number of minority instances is small, such methods do not remove the overlapping instances thoroughly. Therefore, inspired by sequential three-way decision, we propose a Sequential Three-way Rules class-overlap Under-sampling method (S3RCU) based on fuzzy hierarchical subspace. First, the concept of fuzzy hierarchical subspace (FHS) is proposed to construct the fuzzy hierarchical subspace structure of the dataset. Then, sequential three-way rules are constructed to find, in the fuzzy hierarchical subspace, the majority instances that are equivalent to minority instances; we assume that these equivalent majority instances overlap with the minority class. Finally, to preserve the information of the majority instances in each equivalence class, we keep the majority instances with the largest Mahalanobis distance from the center of the equivalence class. Experimental results on 18 real datasets show that S3RCU outperforms or partially outperforms state-of-the-art class-overlap under-sampling methods on two evaluation metrics, F-measure and KAPPA. Keywords: imbalanced data · class-overlap · fuzzy hierarchical subspace · sequential three-way rules · undersampling
1 Introduction
The class imbalance problem is a research focus in machine learning and data mining. Imbalanced data has a severely skewed class distribution. Traditional classification models therefore cannot effectively represent the structural characteristics of imbalanced data, and it is difficult for them to find the true classification boundary of the dataset; the resulting models are biased toward the majority class, which seriously degrades the classification performance of traditional classifiers [1, 2]. Class imbalance problems are common in many fields, such as network intrusion detection [3, 4], text classification [5, 6], and biomedical diagnosis [7, 8]. In these problems, the majority class (also called the negative class) has far more instances than the minority class (also called the positive class) [9, 10], and the minority class is the one of interest, because in the real world minority events can bring huge economic losses and even casualties. For example, in network intrusion detection, intrusions account for only a small fraction of network events, yet it is very important for an intrusion detection system to identify real intrusions correctly. If traditional classifiers are applied directly to imbalanced datasets, they often fail to give satisfactory results. Suppose that a dataset contains 1000 instances, 980 of the majority class and only 20 of the minority class. If the model we design assigns all instances to the majority class, its accuracy still reaches 98%. Judged by accuracy, the model looks excellent, but it does not identify a single minority instance, which defeats the purpose of building a classification model. In recent years, much research has been done on imbalanced data mining and many solutions have been proposed. Related research shows that when a dataset has no class overlap, the class distribution has little effect on classification performance. The latest class-overlap under-sampling methods assume that the closer a majority instance is to the minority instances, the higher its probability of overlapping.
This work was supported by the Science Foundation of China University of Petroleum, Beijing (No. 2462020YXZZ023).
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 15–27, 2023. https://doi.org/10.1007/978-981-99-1639-9_2
Q. Dai et al.
However, for highly imbalanced data this assumption clearly does not hold: a highly imbalanced dataset contains few minority instances, so eliminating instances in this way is not thorough. Granular computing (GrC) is a paradigm that simulates human thinking in solving problems [11]. Information granule computing and information granulation are key issues in granular computing research, and several preprocessing methods based on granular computing have been proposed for the class imbalance problem [12, 13]. Inspired by granular computing and sequential three-way decision, we propose a sequential three-way rules class-overlap under-sampling method based on fuzzy hierarchical subspace. The method uses the idea of granular computing to find potential overlapping instances in the boundary region of the dataset and thereby improves the classification performance of the base classifier. The main contributions of this paper are summarized as follows: (1) we define a fuzzy Euclidean distance and construct the fuzzy hierarchical subspace structure; (2) based on the fuzzy hierarchical subspace, we construct sequential three-way rules to form the original granular structure; (3) in class-overlap under-sampling, we use the Mahalanobis distance to preserve majority instances in each equivalence class. The remainder of the paper is organized as follows. Sect. 2 reviews related work on resampling techniques for imbalanced data. Sect. 3 describes the process and pseudo-code of the model in detail. Sect. 4 analyzes the experimental results
Sequential Three-Way Rules Class-Overlap Under-Sampling
between this method and other popular sampling techniques. Conclusions and future work are given in Sect. 5.
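The 98% accuracy example from the introduction can be reproduced in a few lines of Python. Minority-class recall is used here only to illustrate what accuracy hides (the paper itself evaluates with F-measure and KAPPA), and the function name is hypothetical.

```python
def accuracy_and_minority_recall(y_true, y_pred, minority=1):
    """Overall accuracy and the fraction of minority instances predicted correctly."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    minority_total = sum(t == minority for t in y_true)
    minority_hits = sum(t == minority and p == minority for t, p in zip(y_true, y_pred))
    return correct / len(y_true), minority_hits / minority_total


y_true = [0] * 980 + [1] * 20  # 980 majority (0) and 20 minority (1) instances
y_pred = [0] * 1000  # a trivial model that predicts the majority class everywhere
acc, recall = accuracy_and_minority_recall(y_true, y_pred)
print(acc, recall)  # 0.98 0.0
```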
2 Related Work
Solutions to imbalanced data distribution can usually be divided into four categories: data-level methods [14, 15], algorithm-level methods [16, 17], cost-sensitive learning [18, 19], and ensemble learning [20, 21]. The method studied in this paper is a data-level method, so this section focuses on data-level methods. Data-level methods for imbalanced data are also called data preprocessing methods: they change the class distribution of the training set before the classification model is built, thereby improving its classification performance without changing the classification algorithm. These methods are therefore usually simple, easy to implement, and general, and can be combined with any classifier. Roughly speaking, they can be divided into under-sampling [22], over-sampling [23], and hybrid sampling [24]. Since the proposed method is a class-overlap under-sampling method, we mainly introduce under-sampling methods. Under-sampling balances the class distribution by eliminating redundant instances of the majority class. However, valuable majority instances are easily eliminated in the process, which loses information. The simplest such algorithm is random under-sampling (RUS), which balances the dataset by randomly eliminating majority instances; because the deletion is random, its results are unstable. Tsai et al. [25] proposed cluster-based instance selection (CBIS), which combines clustering with instance selection to under-sample imbalanced datasets. Xie et al.
[26] proposed density peak progressive under-sampling, which introduces two indicators to evaluate the importance of each instance and gradually eliminates majority instances according to their importance. Few studies have addressed under-sampling that considers the class overlap problem. For the potential overlap problem in fraudulent transaction detection, Li et al. [27] adopted a divide-and-conquer idea and proposed a hybrid sampling strategy that addresses both class overlap and class imbalance. Yuan et al. [28] proposed a new random forest algorithm (OIS-RF) that accounts for class overlap and imbalance sensitivity. These methods can effectively improve the performance of classical classification models. However, they generally assume that the majority-class nearest neighbors of minority instances are the most likely to overlap, and removing all such nearest neighbors is prone to over-elimination. Therefore, for each equivalence class in the boundary region, we calculate the Mahalanobis distance from every instance to the center of the equivalence class and retain the majority instance farthest from the center, which protects the majority-instance information in the equivalence class.
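The retention rule above can be sketched as follows, assuming that the equivalence-class centre is the mean of all of its instances and that the covariance is the sample covariance of those same instances with a small ridge term added so it stays invertible for tiny classes; neither choice is specified in the text. The helper names are hypothetical, and the matrix inversion is a plain Gauss-Jordan routine to keep the sketch dependency-free.

```python
def mean(vectors):
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def covariance(vectors, ridge=1e-6):
    """Sample covariance (n - 1 denominator) plus a small ridge so it stays invertible."""
    mu, n = mean(vectors), len(vectors)
    d = len(mu)
    cov = [[0.0] * d for _ in range(d)]
    for v in vectors:
        diff = [a - m for a, m in zip(v, mu)]
        for i in range(d):
            for j in range(d):
                cov[i][j] += diff[i] * diff[j] / max(n - 1, 1)
    for i in range(d):
        cov[i][i] += ridge
    return cov


def invert(matrix):
    """Gauss-Jordan inversion with partial pivoting."""
    d = len(matrix)
    aug = [row[:] + [float(i == j) for j in range(d)] for i, row in enumerate(matrix)]
    for col in range(d):
        pivot = max(range(col, d), key=lambda r: abs(aug[r][col]))
        aug[col], aug[pivot] = aug[pivot], aug[col]
        p = aug[col][col]
        aug[col] = [a / p for a in aug[col]]
        for r in range(d):
            if r != col:
                f = aug[r][col]
                aug[r] = [a - f * b for a, b in zip(aug[r], aug[col])]
    return [row[d:] for row in aug]


def mahalanobis(v, mu, inv_cov):
    diff = [a - m for a, m in zip(v, mu)]
    d = len(diff)
    return sum(diff[i] * inv_cov[i][j] * diff[j] for i in range(d) for j in range(d)) ** 0.5


def undersample_class(instances, labels, majority=0):
    """Within one boundary-region equivalence class, keep every minority instance and only
    the majority instance farthest (in Mahalanobis distance) from the class centre."""
    majority_idx = [i for i, y in enumerate(labels) if y == majority]
    if not majority_idx:
        return list(range(len(instances)))
    mu = mean(instances)
    inv_cov = invert(covariance(instances))
    keep = max(majority_idx, key=lambda i: mahalanobis(instances[i], mu, inv_cov))
    return [i for i, y in enumerate(labels) if y != majority or i == keep]


instances = [[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [2.0, 2.0], [0.05, 0.05]]
labels = [0, 0, 0, 0, 1]  # 0 = majority, 1 = minority
print(undersample_class(instances, labels))  # [3, 4]
```

Unlike a plain Euclidean distance, the Mahalanobis distance rescales each direction by the spread of the class, so "farthest from the centre" is judged relative to the class's own shape.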
3 Sequential Three-Way Rules Class-Overlap Under-Sampling Based on Fuzzy Hierarchical Subspace (S3RCU)
In this section, we describe in detail the sequential three-way rules class-overlap under-sampling method based on fuzzy hierarchical subspace (S3RCU) proposed in this paper. S3RCU consists of three phases: constructing the fuzzy hierarchical subspace, determining the sequential three-way rules, and under-sampling. The S3RCU flowchart is shown in Fig. 1.
Fig. 1. The flowchart of the proposed S3RCU.
The specific steps of the algorithm are as follows.
Phase 1: (Construct Fuzzy Hierarchical Subspace)
Step 1. Since the original features have different units (dimensions) and value ranges, the Cauchy distribution function is empirically selected as the membership function, and the instance features are fuzzified:

$$\mu(x_m) = \begin{cases} 1, & x_m \le a_m \\ \dfrac{1}{1 + (x_m - a_m)^{0.5}}, & x_m > a_m \end{cases} \tag{1}$$

where $\mu(x_m)$ is the membership degree of the $m$-th feature and $a_m$ is the minimum value of that feature.
Step 2. Use the fuzzy Euclidean distance to calculate the fuzzy distance between features. The fuzzy Euclidean distance is defined as follows:
Definition 1 (Fuzzy Euclidean Distance). Assume that a dataset $S$ contains $n$ instances, and let $\mu(x_{im})$ denote the membership degree of the $m$-th feature of the $i$-th instance. The fuzzy Euclidean distance between features $i$ and $j$ is defined as

$$d(i, j) = \sqrt{\sum_{x=1}^{n} \left[\mu_i(x) - \mu_j(x)\right]^2} \tag{2}$$

where $\mu_i(x)$ is the membership degree of feature $i$ for instance $x$.
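Equations (1) and (2) can be sketched in plain Python as follows. The sketch reads Eq. (2) as an ordinary Euclidean distance (with a square root) between two feature columns of the fuzzified data and also assembles the feature-distance matrix used in Step 3; the function names and the toy data are hypothetical.

```python
def fuzzify(data):
    """Eq. (1): Cauchy-type membership per feature, with a_m the minimum of feature m."""
    mins = [min(col) for col in zip(*data)]
    return [
        [1.0 if x <= a else 1.0 / (1.0 + (x - a) ** 0.5) for x, a in zip(row, mins)]
        for row in data
    ]


def fuzzy_euclidean(mu, i, j):
    """Eq. (2): distance between feature columns i and j over all n instances."""
    return sum((row[i] - row[j]) ** 2 for row in mu) ** 0.5


def similarity_matrix(mu):
    """R = [d(i, j)] over all m features (symmetric, zero diagonal)."""
    m = len(mu[0])
    return [[fuzzy_euclidean(mu, i, j) for j in range(m)] for i in range(m)]


data = [[1.0, 10.0], [2.0, 14.0], [5.0, 19.0]]  # 3 instances, 2 features (made-up)
mu = fuzzify(data)
R = similarity_matrix(mu)
print(R[0][1])
```

Fuzzifying first maps every feature into $(0, 1]$, so distances between features measured in different units become comparable.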
Step 3. Use the fuzzy Euclidean distances to construct the fuzzy similarity matrix between features:

$$R = [d(i,j)]_{m \times m} = \begin{bmatrix} d(1,1) & d(1,2) & \cdots & d(1,m) \\ d(2,1) & d(2,2) & \cdots & d(2,m) \\ \vdots & \vdots & \ddots & \vdots \\ d(m,1) & d(m,2) & \cdots & d(m,m) \end{bmatrix} \tag{3}$$

This matrix is symmetric, since d(i, j) = d(j, i), and each diagonal entry is the distance from a feature to itself, so all diagonal entries are 0.

Step 4. Select the distance values in the fuzzy similarity matrix in ascending order to construct a hierarchical structure, and obtain the subspaces from this hierarchical structure.

Definition 2 (Hierarchical Structure). Let U be a finite universe and R̃ a fuzzy equivalence relation on U, and let R̃_λ denote the λ-cut relation of R̃. The set π_R̃(U) = {U/R̃_λ | λ ∈ D} is called the hierarchical quotient space structure of R̃. We call it a hierarchical structure, and each layer is called a hierarchical subspace.

Phase 2 (Determine the Sequential Three-Way Rules)

Step 1 (Discretization). Discretization is an essential step in the granulation process. For simplicity, we discretize each feature with the equal-width (equidistant) division method. Because the number of equal-width intervals is not known in advance, we introduce a granulation factor τ, whose value is the number of equal-width intervals into which each feature is divided. A small τ yields fewer intervals and a coarser granulation space, and vice versa. The granulation factor τ is a hyperparameter that must be set manually by the user.

Step 2 (Binary Relation Granulation). This step is the core of granulation. The definitions used in the granulation process are as follows.

Definition 3 (Binary Relations). Let U = {x_1, x_2, ..., x_n} denote the universe of discourse containing n instances, and let R be a binary relation on U. If R is reflexive, symmetric, and transitive [29], then R is called an equivalence relation.
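The remaining granulation steps can be sketched as follows, under stated assumptions. `feature_hierarchy` reads Step 4 of Phase 1 as single-linkage merging of features whose distance in R does not exceed each successive λ (one plausible reading of the λ-levels in Definition 2, not confirmed by the text); `discretize` implements equal-width binning with the granulation factor τ; and `granulate` groups instances that fall in the same interval on every selected feature, i.e., the equivalence classes of an indiscernibility relation in the sense of Definition 3. All function names are hypothetical.

```python
from collections import defaultdict

import numpy as np


def feature_hierarchy(R):
    """Scan the distances in R from small to large and merge features whose distance
    is <= lambda (single-linkage closure). Returns one (lambda, partition) per level."""
    m = len(R)
    parent = list(range(m))

    def find(i):
        while parent[i] != i:
            i = parent[i]
        return i

    pairs = sorted((float(R[i][j]), i, j) for i in range(m) for j in range(i + 1, m))
    levels = []
    for k, (d, i, j) in enumerate(pairs):
        parent[find(i)] = find(j)
        if k + 1 == len(pairs) or pairs[k + 1][0] > d:   # last pair at this level
            blocks = defaultdict(list)
            for f in range(m):
                blocks[find(f)].append(f)
            levels.append((d, sorted(blocks.values())))
    return levels


def discretize(X, tau):
    """Split every feature into tau equal-width intervals, coded 0 .. tau-1."""
    X = np.asarray(X, dtype=float)
    lo, hi = X.min(axis=0), X.max(axis=0)
    width = np.where(hi > lo, (hi - lo) / tau, 1.0)      # constant features fall in interval 0
    return np.clip(np.floor((X - lo) / width).astype(int), 0, tau - 1)


def granulate(codes, features):
    """Equivalence classes of 'same interval on every selected feature'."""
    groups = defaultdict(list)
    for idx, row in enumerate(np.asarray(codes)[:, list(features)]):
        groups[tuple(row.tolist())].append(idx)
    return list(groups.values())
```

Grouping on a union of feature subsets, e.g. `granulate(codes, B + C)` for index lists B and C, yields the non-empty pairwise intersections of the granules obtained on B and on C, which corresponds to the refinement operation of Definition 5.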
Definition 4 (Information Granules). Let R be an equivalence relation on the universe of discourse U. If G ∈ U/R, then G is called an information granule of the equivalence relation R.

Definition 5 (Binary Relational Information Granule Refinement Operation). In the decision table S, let B, C ⊆ At, and let G_B ∈ U/R_B and G_C ∈ U/R_C be the information granules formed by the binary relations R_B and R_C, respectively. Then the refined granules are the intersections G_B ∩ G_C, i.e., U/(R_B ∩ R_C) = {G_B ∩ G_C | G_B ∈ U/R_B, G_C ∈ U/R_C, G_B ∩ G_C ≠ ∅}.

Step 3 (Determine the Three-Way Rules). Three-way decision (3WD) is a granular computing method developed in recent years for handling uncertain decisions. Sequential three-way decision is a variant of 3WD designed mainly for delayed decisions in the decision process [30]: after the initial partition is obtained, only the instances in the boundary domain are processed each time a feature is added. In our algorithm, features are added layer by layer according to the differences between features across the hierarchical subspaces, and the boundary domain is processed with the three-way rules. Processing each subset of the granulated data separately would be computationally expensive, so we propose a simple three-way rule to divide the subsets formed by data granulation.

Definition 6 (Three-Way Rules). Let G be an information granule formed after data granulation, and let C = {c_1, c_2} be the class labels in the decision table, where c_1 is the minority class and c_2 is the majority class. For two different instances x_1, x_2 ∈ G, we have the following rules:
If x_1, x_2 ∈ c_1, G is called an isolated information granule, and the domain formed by such granules is called the isolated domain R_i.
If x_1 ∈ c_1 and x_2 ∈ c_2, G is called a borderline information granule, and the domain formed by such granules is called the borderline domain R_b.
If x_1, x_2 ∈ c_2, G is called a redundant information granule, and the domain formed by such granules is called the redundant domain R_r.
In particular, when the information granule G contains only one instance, G is classified as a redundant information granule if the instance belongs to the majority class, and as an isolated information granule if it belongs to the minority class.

Phase 3 (Class-Overlap Under-Sampling)

In current re-sampling methods, most researchers still use clustering algorithms as benchmark methods, because clustering partitions data effectively. From the perspective of granular computing, binary relations and clustering are both common granulation methods, so the two have much in common. In our algorithm, the equivalence classes in the boundary region are independent of one another, so in each of them we keep the majority instance with the largest Mahalanobis distance from the center of the equivalence class. In addition, to further reduce the number of majority instances in the redundant domain, we likewise keep only the majority instance farthest from the center of each equivalence class as its representative and put it into the training set. The Mahalanobis distance is calculated as follows.

Definition 7 (Mahalanobis Distance). Assume that each instance in the data has m features, with mean vector μ = (μ_1, μ_2, ..., μ_m) and covariance matrix Σ. The Mahalanobis distance of the instance x_i = (a_{i1}, a_{i2}, ..., a_{im}) is defined as:

$$d_{MD} = \sqrt{(x_i - \mu)^{T}\, \Sigma^{-1}\, (x_i - \mu)} \tag{4}$$

The pseudo-code of the sequential three-way rules class-overlap under-sampling method based on fuzzy hierarchical subspace (S3RCU) is shown in Algorithm 1.
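Definitions 6 and 7 and the retention rule of Phase 3 can be combined into a single-level sketch. Assumptions: granules are given as lists of row indices (for example, from a granulation step), the minority label is 1, the covariance Σ in Eq. (4) is estimated on the whole training set while μ is taken as the mean of the granule, and a pseudo-inverse stands in for Σ^{-1}. The sequential, layer-by-layer reprocessing of the borderline domain and the exact control flow of Algorithm 1 are not reproduced. All function names are hypothetical.

```python
import numpy as np


def three_way_domain(labels, minority=1):
    """Definition 6: the domain a granule falls in, from its members' class labels."""
    labels = list(labels)
    has_min = any(lab == minority for lab in labels)
    has_maj = any(lab != minority for lab in labels)
    if has_min and has_maj:
        return "borderline"
    return "isolated" if has_min else "redundant"


def mahalanobis(X, center, cov_inv):
    """Eq. (4) for every row of X, measured from `center`."""
    d = np.asarray(X, dtype=float) - center
    return np.sqrt(np.maximum(np.einsum("ij,jk,ik->i", d, cov_inv, d), 0.0))


def s3rcu_style_undersample(X, y, granules, minority=1):
    """Indices kept for training: every minority instance, plus, for each borderline
    or redundant granule, the majority member farthest from the granule center."""
    X, y = np.asarray(X, dtype=float), np.asarray(y)
    cov_inv = np.linalg.pinv(np.cov(X, rowvar=False))   # pinv tolerates a singular covariance
    keep = {int(i) for i in np.flatnonzero(y == minority)}
    for g in granules:
        g = np.asarray(g)
        if three_way_domain(y[g], minority) == "isolated":
            continue                                     # minority only: already kept
        majority = g[y[g] != minority]
        center = X[g].mean(axis=0)                       # center of the equivalence class
        keep.add(int(majority[np.argmax(mahalanobis(X[majority], center, cov_inv))]))
    return sorted(keep)
```

Keeping the farthest rather than the nearest majority member follows the motivation given earlier: the instance least likely to overlap with the minority class survives as the representative of its equivalence class.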
4 Experimental Results and Discussion

4.1 Datasets

We evaluate the under-sampling methods with an SVM classifier on 18 publicly available datasets, collected mainly from the KEEL and HDDT repositories. The four datasets covtyp, estate, ism, and pendigits come from HDDT, and the rest come from KEEL. Basic information about the datasets is shown in Table 1.
Table 1. Basic information of the datasets

Dataset        Features   Minority   Majority   Classes   Imbalance ratio (majority/minority)
covtyp         10         2747       35753      2         13.02
ecoli1         7          77         259        2         3.36
ecoli2         7          52         284        2         5.46
ecoli3         7          35         301        2         8.6
estate         12         636        4686       2         7.37
glass0         9          70         144        2         2.06
glass1         9          76         138        2         1.82
haberman       3          81         225        2         2.78
ism            6          260        10919      2         42
page-blocks0   10         559        4913       2         8.79
pendigits      16         1142       9850       2         8.63
pima           8          268        500        2         1.87
segment0       19         329        1979       2         6.02
wisconsin      9          239        444        2         1.86
yeast3         8          163        1321       2         8.1
yeast4         8          51         1433       2         28.1
yeast5         8          44         1440       2         32.73
yeast6         8          35         1449       2         41.4
4.2 Experimental Settings

Evaluation metrics. We compare the performance of all models using two evaluation metrics, F-measure and Kappa. For training and testing the classifier, every dataset is split with 10-fold cross-validation. In the experiments, the granulation factor of the S3RCU under-sampling algorithm is set to τ = 6.

Base classifier. We call the SVM directly from the sklearn package without changing its default parameters.

Baseline methods. We use five under-sampling methods as baselines for comparative experiments: random under-sampling (RUS), Tomek links (TL), condensed nearest neighbour (CNN), NB-TL, and NB-Comm [31].

4.3 Experimental Results

Due to space limitations, we do not report the experiments on the granulation factor. In the subsequent experiments, we use greedy search to find the best value of the granulation factor. Table 2 presents the F-measure results using SVM, and Table 3 gives the Kappa coefficients. Table 4 gives the Friedman rankings and the adjusted p-values (APVs) of Holm's post-hoc test for F-measure and Kappa.
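The two evaluation metrics can be computed from binary predictions as below. The sketch assumes that the minority class is the positive label 1 and that scores are computed on each test fold of the 10-fold cross-validation; how the per-fold scores were aggregated is not stated in the text. The same quantities are available as `f1_score` and `cohen_kappa_score` in scikit-learn's `sklearn.metrics` module; the function name here is hypothetical.

```python
def f_measure_and_kappa(y_true, y_pred, positive=1):
    """F-measure of the positive (minority) class and Cohen's kappa for binary labels."""
    pairs = list(zip(y_true, y_pred))
    n = len(pairs)
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    tn = n - tp - fp - fn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    p_obs = (tp + tn) / n                                              # observed agreement
    p_exp = ((tp + fp) * (tp + fn) + (fn + tn) * (fp + tn)) / n ** 2   # chance agreement
    kappa = (p_obs - p_exp) / (1 - p_exp) if p_exp < 1 else 0.0
    return f, kappa
```

A classifier that never predicts the minority class scores 0 on both metrics, because its agreement with the true labels is exactly what chance would predict.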
The experimental results show that the proposed S3RCU method obtains the highest F-measure on most datasets: among the 18 datasets, it achieves the best value on 11. Because S3RCU discretizes the dataset before mining equivalence-class instances, it may distort the data on some datasets. In addition, on some datasets with a low imbalance ratio, our algorithm may reduce the recognition accuracy for majority instances.

Table 2. F-measure comparative analysis of S3RCU

Dataset        RUS      TL       CNN      NB-TL    NB-Comm  S3RCU
covtyp         0.9466   0.9738   0.9793   0.9738   0.9478   0.9809
ecoli1         0.7587   0.7746   0.7877   0.7689   0.7168   0.8206
ecoli2         0.8169   0.8568   0.8633   0.7897   0.4749   0.8934
ecoli3         0.6249   0.6428   0.6704   0.6059   0.4981   0.6871
estate         0.7843   0.9368   0.9373   0.9369   0.8213   0.9366
glass0         0.7011   0.7201   0.7388   0.7329   0.6821   0.7365
glass1         0.5882   0.5279   0.5585   0.5424   0.6149   0.6649
haberman       0.3859   0.2228   0.3035   0.4413   0.3508   0.3142
ism            0.9313   0.9908   0.9907   0.9823   0.9896   0.9917
page-blocks0   0.6049   0.6883   0.7017   0.7079   0.5408   0.6945
pendigits      0.9983   0.9991   0.9989   0.9988   0.9984   0.9947
pima           0.6591   0.6178   0.6577   0.6595   0.6436   0.6602
segment0       0.9705   0.9841   0.9846   0.9846   0.9817   0.9849
wisconsin      0.9431   0.9504   0.9216   0.9428   0.9279   0.9449
yeast3         0.6044   0.7673   0.7821   0.7822   0.7730   0.7917
yeast4         0.2936   0        0.1393   0.3216   0        0.2371
yeast5         0.4413   0.3706   0.6050   0.5972   0.5544   0.6952
yeast6         0.3219   0.1411   0.2867   0.2471   0.2379   0.3749
When Kappa is used as the evaluation metric, S3RCU achieves the best performance on 12 datasets. Because SVM must compute the support vectors of the dataset, its computational complexity is high when the dataset is large, and the resulting support vectors may themselves be imbalanced; SVM may therefore perform suboptimally on large datasets. Table 4 shows that our proposed method has the best (lowest) Friedman ranking under both evaluation metrics, which indicates that S3RCU performs better on most datasets, and Holm's post-hoc test supports this result.
4.4 Global Analysis of Results

Combining the results in Tables 2, 3, and 4, we can make the following global analysis.

First, the experimental results show that SVM does not perform well on datasets with many instances. In addition, for datasets with a low imbalance ratio, S3RCU may delete more majority instances, which lowers performance on the majority class as a whole. For some datasets in which the differences between feature values are small, S3RCU tends either to miss potential overlapping instances or to delete too many majority instances.

Second, classical class-overlap under-sampling algorithms search for the nearest neighbors of minority instances with a global metric, so they may miss instances that overlap only on some local features. S3RCU constructs fuzzy subspaces on the dataset and uses sequential three-way decisions to build three-way rules, which gradually uncovers potential overlapping instances during the partitioning process. In addition, we retain part of the majority-instance information in the boundary region, which reduces the risk of information loss caused by under-sampling. Other algorithms ignore this because they treat all instances on the boundary as overlapping by default, which causes majority-class information to be lost.

Finally, Table 4 shows that for the F-measure, S3RCU performs on par with CNN and NB-TL, with no statistically significant difference (APVs of 0.167 and 0.123), whereas for Kappa, S3RCU outperforms the other five algorithms with statistical significance (all APVs below 0.05). The Friedman rankings likewise show that S3RCU has a considerable competitive advantage.

Table 3. Kappa comparative analysis of S3RCU

Dataset        RUS      TL       CNN      NB-TL    NB-Comm  S3RCU
covtyp         0.5472   0.7415   0.7276   0.6736   0.7424   0.6729
ecoli1         0.6988   0.7101   0.7208   0.7265   0.7249   0.7668
ecoli2         0.7818   0.8163   0.8391   0.8269   0.8069   0.8579
ecoli3         0.5653   0.6321   0.6341   0.6317   0.6349   0.6571
estate         0.0939   0.0229   0.0328   0.0618   0.0255   0.0062
glass0         0.5087   0.5940   0.5732   0.4329   0.6047   0.6010
glass1         0.3448   0.3503   0.1514   0.3361   0.3653   0.4357
haberman       0.1128   0.1055   0.1123   0.1064   0.1051   0.2658
ism            0.2191   0.3891   0.5189   0.4742   0.3629   0.5309
page-blocks0   0.5435   0.6607   0.6714   0.4637   0.6628   0.6772
pendigits      0.9840   0.9912   0.7925   0.9903   0.9888   0.9573
pima           0.4391   0.4219   0.4461   0.3995   0.4414   0.4645
segment0       0.9655   0.9809   0.9823   0.9821   0.9820   0.9824
wisconsin      0.9188   0.9226   0.8707   0.9153   0.9126   0.9242
yeast3         0.5056   0.7441   0.7534   0.7286   0.7449   0.7645
yeast4         0.2512   0        0.1329   0.3982   0        0.2235
yeast5         0.4355   0.4384   0.5548   0.5646   0.4532   0.6012
yeast6         0.2709   0.0787   0.2031   0.4704   0.1282   0.4469
Table 4. Friedman ranking and Holm's post-hoc test APVs for F-measure and Kappa

F-measure                                  Kappa
Method    Friedman ranking   APV           Method    Friedman ranking   APV
S3RCU     1.9444             –             S3RCU     1.9444             –
CNN       2.8056             0.167325      CNN       3.3889             0.025761
NB-TL     3.1111             0.122738      NB-TL     3.5556             0.025761
TL        4.0000             0.002940      NB-Comm   3.5833             0.025761
RUS       4.3333             0.000511      TL        4.0278             0.003342
NB-Comm   4.8056             0.000022      RUS       4.5000             0.000208
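The APVs in Table 4 can be reproduced, up to rounding, with the standard comparison of Friedman average ranks against a control method followed by Holm's step-down adjustment. The sketch below is an illustration under assumptions (a two-sided normal approximation with N = 18 datasets and k = 6 methods); the paper does not spell out its exact procedure, and the function names are hypothetical.

```python
import math


def average_ranks(scores):
    """Friedman average ranks. scores[d][a] is the score of algorithm a on dataset d;
    higher is better (rank 1), and tied scores share the mean of their ranks."""
    n_alg = len(scores[0])
    totals = [0.0] * n_alg
    for row in scores:
        order = sorted(range(n_alg), key=lambda a: -row[a])
        pos = 0
        while pos < n_alg:
            end = pos
            while end + 1 < n_alg and row[order[end + 1]] == row[order[pos]]:
                end += 1
            for k in range(pos, end + 1):
                totals[order[k]] += (pos + end) / 2 + 1   # mean of ranks pos+1 .. end+1
            pos = end + 1
    return [t / len(scores) for t in totals]


def holm_vs_control(ranks, control, n_datasets):
    """Holm-adjusted p-values (APVs) of each method versus `control`, using
    z = (R_i - R_0) / sqrt(k (k + 1) / (6 N)) and a two-sided normal p-value."""
    k = len(ranks)
    se = math.sqrt(k * (k + 1) / (6.0 * n_datasets))
    pvals = {name: math.erfc(abs(r - ranks[control]) / se / math.sqrt(2.0))
             for name, r in ranks.items() if name != control}
    m, running, apv = len(pvals), 0.0, {}
    for i, (name, p) in enumerate(sorted(pvals.items(), key=lambda kv: kv[1])):
        running = max(running, min(1.0, (m - i) * p))   # step-down, kept monotone
        apv[name] = running
    return apv
```

Called with the F-measure rankings of Table 4, `holm_vs_control({"S3RCU": 1.9444, "CNN": 2.8056, "NB-TL": 3.1111, "TL": 4.0000, "RUS": 4.3333, "NB-Comm": 4.8056}, "S3RCU", 18)` returns APVs of roughly 0.167 (CNN), 0.123 (NB-TL), 0.0029 (TL), 0.00051 (RUS) and 0.00002 (NB-Comm), in line with the table to within the rounding of the input ranks.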
5 Conclusion and Future Work

The class-imbalance problem has attracted extensive attention from data mining researchers. However, some studies have shown that an imbalanced class distribution is not the main factor degrading classifier performance, and argue instead that overlap between classes is the main cause. Therefore, we apply, for the first time, the ideas of granular computing and sequential three-way decision to class-overlap under-sampling for imbalanced data and propose a new class-overlap under-sampling method (S3RCU). The experimental results show that introducing the granular computing model into the class-imbalance problem is successful and gives S3RCU clear competitive advantages. In the future, we will introduce more granular computing models into data mining. S3RCU performs well on binary classification problems, and we also plan to extend it to a class-overlap under-sampling algorithm for multi-class imbalanced data.
References
1. He, H., Garcia, E.A.: Learning from imbalanced data. IEEE Trans. Knowl. Data Eng. 21(9), 1263–1284 (2009)
2. Japkowicz, N., Stephen, S.: The class imbalance problem: a systematic study. Intell. Data Anal. 6(5), 429–449 (2002)
3. Al, S., Dener, M.: STL-HDL: a new hybrid network intrusion detection system for imbalanced dataset on big data environment. Comput. Secur. 110, 102435 (2021)
4. Pozi, M.S.M., Sulaiman, M.N., Mustapha, N., Perumal, T.: Improving anomalous rare attack detection rate for intrusion detection system using support vector machine and genetic programming. Neural Process. Lett. 44(2), 279–290 (2016)
5. Naderalvojoud, B., Sezer, E.A.: Term evaluation metrics in imbalanced text categorization. Nat. Lang. Eng. 26(1), 31–47 (2020)
6. Wu, Q., Ye, Y., Zhang, H., Ng, M.K., Ho, S.S.: ForesTexter: an efficient random forest algorithm for imbalanced text categorization. Knowl. Based Syst. 67, 105–116 (2014)
7. Krawczyk, B., Galar, M., Jeleń, Ł., Herrera, F.: Evolutionary undersampling boosting for imbalanced classification of breast cancer malignancy. Appl. Soft Comput. 38, 714–726 (2016)
8. Rupapara, V., Rustam, F., Shahzad, H.F., Mehmood, A., Ashraf, I., Choi, G.S.: Impact of SMOTE on imbalanced text features for toxic comments classification using RVVC model. IEEE Access 9, 78621–78634 (2021)
9. Krawczyk, B.: Learning from imbalanced data: open challenges and future directions. Progr. Artif. Intell. 5(4), 221–232 (2016). https://doi.org/10.1007/s13748-016-0094-0
10. Wang, S., Yao, X.: Multiclass imbalance problems: analysis and potential solutions. IEEE Trans. Syst. Man Cybern. Part B (Cybern.) 42(4), 1119–1130 (2012)
11. Yao, J.T., Vasilakos, A.V., Pedrycz, W.: Granular computing: perspectives and challenges. IEEE Trans. Cybern. 43(6), 1977–1989 (2013)
12. Yan, Y.T., Wu, Z.B., Du, X.Q., Chen, J., Zhao, S., Zhang, Y.P.: A three-way decision ensemble method for imbalanced data oversampling. Int. J. Approx. Reason. 107, 1–16 (2019)
13. Dai, Q., Liu, J.W., Liu, Y.: Multi-granularity relabeled under-sampling algorithm for imbalanced data. Appl. Soft Comput. 124, 109083 (2022)
14. Liu, Y., Liu, Y., Yu, B.X.B., Zhong, S., Hu, Z.: Noise-robust oversampling for imbalanced data classification. Pattern Recogn. 133, 109008 (2023). https://doi.org/10.1016/j.patcog.2022.109008
15. Liang, T., Xu, J., Zou, B., Wang, Z., Zeng, J.: LDAMSS: fast and efficient undersampling method for imbalanced learning. Appl. Intell. 52(6), 6794–6811 (2021). https://doi.org/10.1007/s10489-021-02780-x
16. Wang, G., Wong, K.W.: An accuracy-maximization learning framework for supervised and semi-supervised imbalanced data. Knowl. Based Syst. 255, 109678 (2022)
17. Liu, J.: Fuzzy support vector machine for imbalanced data with borderline noise. Fuzzy Sets Syst. 413, 64–73 (2021)
18. Peng, P., Zhang, W., Zhang, Y., Wang, H., Zhang, H.: Non-revisiting genetic cost-sensitive sparse autoencoder for imbalanced fault diagnosis. Appl. Soft Comput. 114, 108138 (2022)
19. Aram, K.Y., Lam, S.S., Khasawneh, M.T.: Linear cost-sensitive max-margin embedded feature selection for SVM. Expert Syst. Appl. 197, 116683 (2022)
20. Ren, J., Wang, Y., Mao, M., Cheung, Y.M.: Equalization ensemble for large scale highly imbalanced data classification. Knowl. Based Syst. 242, 108295 (2022)
21. Gupta, N., Jindal, V., Bedi, P.: CSE-IDS: Using cost-sensitive deep learning and ensemble algorithms to handle class imbalance in network-based intrusion detection systems. Comput. Secur. 112, 102499 (2022)
22. Yan, Y., Zhu, Y., Liu, R., Zhang, Y., Zhang, Y., Zhang, L.: Spatial distribution-based imbalanced undersampling. IEEE Trans. Knowl. Data Eng. (2022). https://doi.org/10.1109/TKDE.2022.3161537
23. Azhar, N.A., Pozi, M.S.M., Din, A.M., Jatowt, A.: An investigation of SMOTE based methods for imbalanced datasets with data complexity analysis. IEEE Trans. Knowl. Data Eng. (2022). https://doi.org/10.1109/TKDE.2022.3179381
24. Koziarski, M., Bellinger, C., Woźniak, M.: RB-CCR: radial-based combined cleaning and resampling algorithm for imbalanced data classification. Mach. Learn. 110(11–12), 3059–3093 (2021). https://doi.org/10.1007/s10994-021-06012-8
25. Tsai, C.F., Lin, W.C., Hu, Y.H., Yao, G.T.: Under-sampling class imbalanced datasets by combining clustering analysis and instance selection. Inf. Sci. 477, 47–54 (2019)
26. Xie, X., Liu, H., Zeng, S., Lin, L., Li, W.: A novel progressively undersampling method based on the density peaks sequence for imbalanced data. Knowl. Based Syst. 213, 106689 (2021)
27. Li, Z., Huang, M., Liu, G., Jiang, C.: A hybrid method with dynamic weighted entropy for handling the problem of class imbalance with overlap in credit card fraud detection. Expert Syst. Appl. 175, 114750 (2021)
28. Yuan, B.W., Zhang, Z.L., Luo, X.G., Yu, Y., Zou, X.H., Zou, X.D.: OIS-RF: a novel overlap and imbalance sensitive random forest. Eng. Appl. Artif. Intell. 104, 104335 (2021)
29. Wang, C., He, Q., Shao, M., Xu, Y., Hu, Q.: A unified information measure for general binary relations. Knowl. Based Syst. 135, 18–28 (2017)
30. Yao, Y.: Three-way decisions with probabilistic rough sets. Inf. Sci. 180(3), 341–353 (2010)
31. Vuttipittayamongkol, P., Elyan, E., Petrovski, A.: On the class overlap problem in imbalanced data classification. Knowl. Based Syst. 212, 106631 (2021)
Two-Stage Multilayer Perceptron Hawkes Process
Xiang Xing, Jian-wei Liu(B), and Zi-hao Cheng
Department of Automation, College of Information Science and Engineering, China University of Petroleum, Beijing, China
Abstract. Many social activities, such as traffic accidents, medical care, financial transactions, social networks, and violent crimes, can be described as asynchronous discrete event sequences. Predicting the probabilities, times, and types of future events is a challenging and highly important problem with broad application prospects in urban management, traffic optimization, and other fields. Hawkes processes are used to model complex sequences of events. Recently, to expand the capacity of the Hawkes process, the neural Hawkes process (NHP) and the transformer Hawkes process (THP) were proposed. We argue that these models are highly complex because they introduce recurrent neural networks or attention mechanisms, and that attention, although it performs well, is not essential. Therefore, in this paper, we propose a Two-stage Multilayer Perceptron Hawkes Process (TMPHP). The model consists of two types of multilayer perceptrons: one applies MLPs independently to each event sequence (learning the features of each sequence to capture long-term dependencies between different events), and the other applies MLPs across different event sequences (capturing long- and short-term dependencies between different events). Our model is simpler than other state-of-the-art models, yet it achieves better predictive performance. In particular, on the MIMIC dataset our model outperforms RMTPP (4.2%), NHP (2.2%), and THP (2.2%) in prediction accuracy.

Keywords: Hawkes process · Multilayer perceptron · Asynchronous discrete event sequences
1 Introduction

In real life, asynchronous event sequences arise in many fields, such as financial transactions [1], earthquake sequences [2], electronic medical records [3], and social networks [4]. For example, stock buying and selling at different times can be regarded as an asynchronous event sequence; analyzing the relationships between events makes it possible to predict the occurrence of future events.

This work was supported by the Science Foundation of China University of Petroleum, Beijing (No. 2462020YXZZ023).
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 28–39, 2023.
https://doi.org/10.1007/978-981-99-1639-9_3

Point process models [5] are often used to model sequential data. The Hawkes process [6] accounts for the influence of historical events on current events, so it can be used to predict event sequence data. However, the classical Hawkes process is overly simplified, and combining deep neural networks with the Hawkes process can meet the requirements of more realistic event-sequence prediction. Du et al. proposed the Recurrent Marked Temporal Point Process model (RMTPP) [7, 8], which uses a recurrent neural network (RNN) [9, 10] and shows that the representations learned by an RNN can model event sequences. However, because RNNs suffer from exploding and vanishing gradients, the model is difficult to train. Mei et al. proposed the Neural Hawkes Process (NHP) [11]. The Transformer architecture [12], introduced in 2017, has shown good performance in both natural language processing and computer vision. Zhang et al. proposed the Self-Attentive Hawkes Process (SAHP) [13], which uses self-attention to predict the occurrence probability of the next event. Zuo et al. proposed the Transformer Hawkes Process (THP) [14], which greatly improved the computational efficiency and performance of the model. Although RNNs and attention mechanisms perform well in predicting event sequences, neither is necessary. We find that attention-based Hawkes processes must compute attention at every position of the sequence, which introduces considerable computational redundancy and wastes computing resources. We therefore build on the MLP-Mixer [15] model to construct a Two-stage Multilayer Perceptron Hawkes Process (TMPHP), which uses two multilayer perceptrons to learn asynchronous event sequences without any attention mechanism. Compared with existing models, our model achieves a clear improvement. The main contributions of this paper are as follows:

1. Our model does not use an attention mechanism and relies only on multilayer perceptrons: two multilayer perceptrons capture the long- and short-term dependencies in event sequences.
2. We first use MLP1 to learn the dependencies between different events, and then feed the output of MLP1 into MLP2 to learn the dependencies between different sequences. Compared with other models, our model is relatively simple, and this is the first work to introduce multilayer perceptrons into the point process domain.
3. We verify on relevant datasets that the proposed TMPHP model outperforms existing models.

The rest of the article is organized as follows: Sect. 2 reviews related work. In Sect. 3, we introduce the two-stage multilayer perceptron Hawkes process. In Sect. 3.2, numerical experiments on relevant datasets are presented. In Sect. 3.3, we draw conclusions and discuss future work.
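The MLP-Mixer [15] building block that TMPHP adapts alternates an MLP applied across sequence positions (token mixing) with an MLP applied across embedding features (channel mixing). A minimal NumPy sketch of one such layer follows. It is a generic Mixer layer, not the TMPHP architecture: how events are embedded, how MLP1 and MLP2 map onto the two mixing steps, and whether a causal mask is applied are not specified in this excerpt. Without a mask, the token-mixing step lets each event see later events, which would leak information in next-event prediction. Parameter shapes and function names are assumptions.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Normalize the last axis to zero mean and unit variance (no learned scale/shift)."""
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def mlp(x, w1, b1, w2, b2):
    """Two-layer perceptron on the last axis, with a tanh-approximated GELU."""
    h = x @ w1 + b1
    h = 0.5 * h * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (h + 0.044715 * h ** 3)))
    return h @ w2 + b2


def mixer_layer(x, token_params, channel_params):
    """One MLP-Mixer-style layer on x of shape (L, d): L events, d embedding features.

    token_params:   w1 (L, h1), b1 (h1,), w2 (h1, L), b2 (L,)  mixes across events
    channel_params: w1 (d, h2), b1 (h2,), w2 (h2, d), b2 (d,)  mixes across features
    """
    x = x + mlp(layer_norm(x).T, *token_params).T   # token mixing, residual connection
    x = x + mlp(layer_norm(x), *channel_params)     # channel mixing, residual connection
    return x
```

Because the token-mixing weights have a fixed size L, this layer assumes sequences padded or truncated to a common length.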
2 Related Work

In this section, we briefly review the Hawkes process [6], the Transformer Hawkes process [14], and MLP-Mixer [15].
X. Xing et al.
Hawkes Process [6]. The Hawkes process is a self-exciting point process, in which the occurrence of historical events stimulates the occurrence of current events. Its conditional intensity function is:

$$\lambda(t) = \mu + \sum_{k:\, t_k < t} \phi(t - t_k) \tag{1}$$

[…]

[…] 0 and k is the number of leading singular values, which usually ranges from c to X_d, where c denotes the number of clusters and X_d denotes the dimension of the samples. Figure 4 shows the clustering performance of UV-LRR for different values of the parameters λ and k. On the Binary Alphadigit dataset, the optimal interval of λ is 15–25 and that of k is 150–200; on the COIL-100 dataset, the optimal interval of λ is 3–5 and that of k is 800–1024; and on the USPS dataset, the optimal interval of λ is 15–20 and that of k is 125–150. This indicates that UV-LRR performs well over a relatively wide range of values of λ and k.
Fig. 4. Parameter sensitivity of UV-LRR for λ and k on the three datasets.
4 Conclusion
In this paper, we propose a dual low-rank constrained representation method for noisy multi-subspace high-dimensional data. The additional low-rank constraint effectively separates the low-rank part of the data from the noisy part, which reduces the effect of noise on the low-rank representation of multi-subspace data and achieves low-rank compression. In addition, to further constrain the noise in the data, the proposed method uses two different sparse noise terms to control the sparsity of sparse global noise and local structured noise, which separates the sparse noise terms of the data more effectively.
Data Representation and Clustering with Double Low-Rank Constraints
Experimental results show that our method achieves good results in subspace clustering and effectively reduces the effect of noise on high-dimensional data. It maintains a reasonable recognition rate, with excellent robustness, on different data heavily contaminated by noise. The proposed method also achieves good denoising results. In conclusion, further low-rank constraints and stronger noise-sparsity control are the main contributions of our method, and this idea is promising for future work on subspace clustering and denoising of high-dimensional data.
RoMA: A Method for Neural Network Robustness Measurement and Assessment
Natan Levy(B) and Guy Katz
The Hebrew University of Jerusalem, Jerusalem, Israel
{natan.levy1,g.katz}@mail.huji.ac.il
Abstract. Neural network models have become the leading solution for a large variety of tasks, such as classification, natural language processing, and others. However, their reliability is heavily plagued by adversarial inputs: inputs generated by adding tiny perturbations to correctly classified inputs, for which the neural network produces erroneous results. In this paper, we present a new method called Robustness Measurement and Assessment (RoMA), which measures the robustness of a neural network model against such adversarial inputs. Specifically, RoMA determines the probability that a random input perturbation might cause misclassification. The method allows us to provide formal guarantees regarding the expected frequency of errors that a trained model will encounter after deployment. The type of robustness assessment afforded by RoMA is inspired by state-of-the-art certification practices, and could constitute an important step toward integrating neural networks in safety-critical systems.

Keywords: Neural networks · Adversarial examples · Robustness · Certification

1 Introduction
In the past decade, deep neural networks (DNNs) have emerged as one of the most exciting developments in computer science, allowing computers to outperform humans in various classification tasks. However, a major issue with DNNs is the existence of adversarial inputs [11]: inputs that are very close (according to some metric) to correctly-classified inputs, but which are themselves misclassified. It has been observed that many state-of-the-art DNNs are highly vulnerable to adversarial inputs [6].

As the impact of the AI revolution becomes evident, regulatory agencies are starting to address the challenge of integrating DNNs into various automotive and aerospace systems by forming workgroups to create the needed guidelines. Notable examples include the SAE G-34 and EUROCAE WG-114 working groups [21,26]; and, in the European Union, the European Union Aviation Safety Agency (EASA), which is responsible for civil aviation safety and has published a roadmap for certifying AI-based systems [9]. These efforts, however, must overcome a significant gap: on the one hand, the superior performance of DNNs makes it highly desirable to incorporate them into various systems, but on the other hand, DNNs' intrinsic susceptibility to adversarial inputs could render them unsafe. This dilemma is particularly acute in safety-critical systems, such as automotive, aerospace and medical devices, where regulators and public opinion set a high bar for reliability. In this work, we seek to begin bridging this gap by devising a framework that could allow engineers to bound and mitigate the risk introduced by a trained DNN, effectively containing the phenomenon of adversarial inputs.

Our approach is inspired by common practices of regulatory agencies, which often need to certify systems with components that might fail due to an unexpected hazard. A widely used example is the certification of jet engines, which are known to occasionally fail. To mitigate this risk, manufacturers compute the engines' mean time between failures (MTBF), and then use this value in a safety analysis that can eventually justify the safety of the jet engine system as a whole [17]. For example, federal aviation guidance requires that the probability of a failure condition classified as extremely improbable not exceed $10^{-9}$ per operational hour [17]. To perform a similar process for DNN-based systems, we first need a technique for accurately bounding the likelihood of a failure, e.g., for measuring the probability of encountering an adversarial input.

In this paper, we address this crucial gap by introducing a straightforward and scalable method for measuring the probability that a DNN classifier misclassifies inputs. The method, which we term Robustness Measurement and Assessment (RoMA), is inspired by modern certification concepts, and operates under the assumption that a DNN's misclassification is due to some internal malfunction, caused by random input perturbations (as opposed to misclassifications triggered by an external cause, such as a malicious adversary).

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 92–105, 2023.
https://doi.org/10.1007/978-981-99-1639-9_8
A random input perturbation can occur naturally as part of the system's operation, e.g., due to scratches on a camera lens or communication disruptions. Under this assumption, RoMA can be used to measure the model's robustness to randomly-produced adversarial inputs.

RoMA is a method for estimating the frequency of rare events in a large population; in our case, adversarial inputs within a space of inputs that are generally classified correctly. When the confidence scores that the network assigns to randomly perturbed inputs are distributed normally, RoMA performs the following steps: it (i) samples a few hundred random input points; (ii) measures the "level of adversariality" of each such point; and (iii) uses the normal distribution function to evaluate the probability of encountering an adversarial input within the input space. Unfortunately, these scores are often not distributed normally. To overcome this difficulty, when RoMA detects this case it first applies a statistical power transformation, called Box-Cox [5], after which the distribution often becomes normal and can be analyzed. The Box-Cox transformation is a widely used method that does not pose any restrictions on the DNN in question (e.g., Lipschitz continuity, certain kinds of activation functions, or specific network topology). Further, the method does not require access to the network's design or weights, and is thus applicable to large, black-box DNNs.
We implemented our method as a proof-of-concept tool, and evaluated it on a VGG16 network trained on the CIFAR10 data set. Using RoMA, we were able to show that, as expected, a higher number of epochs (a higher level of training) leads to a higher robustness score. Additionally, we used RoMA to measure how the model's robustness score changes as the magnitude of allowed input perturbation is increased. Finally, using RoMA we found that the categorial robustness score of a DNN, i.e., the robustness score of inputs labeled as a particular category, varies significantly among the different categories.

To summarize, our main contributions are: (i) introducing RoMA, a new and scalable method for measuring the robustness of a DNN model, which can be applied to black-box DNNs; (ii) using RoMA to measure the effect of additional training on the robustness of a DNN model; (iii) using RoMA to measure how a model's robustness changes as the magnitude of input perturbation increases; and (iv) formally computing categorial robustness scores, and demonstrating that they can differ significantly between labels.

Related Work. The topic of statistically evaluating a model's adversarial robustness has been studied extensively. State-of-the-art approaches [7,14] assume that the confidence scores assigned to perturbed images are normally distributed, and apply random sampling to measure robustness. However, as we later demonstrate, this assumption often does not hold. Other approaches [19,25,27] use importance sampling, where a few bad samples with large weights can drastically throw off the estimator. Further, these approaches typically assume that the network's output is Lipschitz-continuous. Although RoMA is similar in spirit to these approaches, it requires no Lipschitz continuity, does not assume a priori that the adversarial input confidence scores are distributed normally, and provides rigorous robustness guarantees.
Other notable methods for measuring robustness include formal-verification-based approaches [15,16], which are exact but afford very limited scalability; and approaches for computing an estimated bound on the probability that a classifier's margin function exceeds a given value [1,8,28], which focus on worst-case behavior, and may consequently be inadequate for regulatory certification. In contrast, RoMA is a scalable method that focuses on the more realistic average case.
2 Background
Neural Networks. A neural network N is a function $N: \mathbb{R}^n \to \mathbb{R}^m$, which maps a real-valued input vector $x \in \mathbb{R}^n$ to a real-valued output vector $y \in \mathbb{R}^m$. For classification networks, which are our subject matter, x is classified as label l if the l-th entry of y has the highest score, i.e., if $\operatorname{argmax}(N(x)) = l$.

Local Adversarial Robustness. The local adversarial robustness of a DNN is a measure of how resilient that network is against adversarial perturbations to specific inputs. More formally [3]:
Definition 1. A DNN N is ε-locally-robust at input point $x_0$ iff
$$\forall x.\ \|x - x_0\|_\infty \le \varepsilon \Rightarrow \operatorname{argmax}(N(x)) = \operatorname{argmax}(N(x_0))$$

Intuitively, Definition 1 states that any input vector x at distance at most ε from a fixed input $x_0$ is assigned the same label that the network assigns to $x_0$ (for simplicity, we use the $L_\infty$ norm here, but other metrics could also be used). When a network is not ε-locally-robust at point $x_0$, there exists a point x at distance at most ε from $x_0$ that is misclassified; such an x is called an adversarial input. In this context, local refers to the fact that $x_0$ is fixed.

Distinct Adversarial Robustness. Recall that the label assigned by a classification network is selected according to its greatest output value. The final layer in such networks is usually a softmax layer, and its outputs are commonly interpreted as confidence scores assigned to each of the possible labels.¹ We use c(x) to denote the highest confidence score, i.e., $c(x) = \max(N(x))$. We are interested in an adversarial input x only if it is distinctly misclassified [17], i.e., if x's assigned label receives a significantly higher score than that of the label assigned to $x_0$. For example, if $\operatorname{argmax}(N(x_0)) \ne \operatorname{argmax}(N(x))$ but c(x) = 20%, then x is not a distinctly adversarial input: while it is misclassified, the network assigns it an extremely low confidence score. Indeed, in a safety-critical setting, the system is expected to issue a warning to the operator when it has such low confidence in its classification [20]. In contrast, a case where c(x) = 80% is much more distinct: here, the network gives an incorrect answer with high confidence, and no warning to the operator is expected. We refer to inputs that are misclassified with a confidence of at least some threshold δ as distinctly adversarial inputs, and refine Definition 1 to only consider them, as follows:

Definition 2. A DNN N is (ε, δ)-distinctly-locally-robust at input point $x_0$ iff
$$\forall x.\ \|x - x_0\|_\infty \le \varepsilon \Rightarrow \big(\operatorname{argmax}(N(x)) = \operatorname{argmax}(N(x_0))\big) \lor \big(c(x) < \delta\big)$$

Intuitively, if the definition does not hold, then there exists a (distinctly) adversarial input x that is at most ε away from $x_0$, and which is assigned a label different from that of $x_0$ with a confidence score of at least δ.
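As a minimal sketch, assuming the network's softmax outputs N(x) and N(x_0) are available as plain Python lists and that x already lies in the ε-ball around x_0, the per-input conditions of Definitions 1 and 2 can be checked as follows. The helper names (`argmax`, `is_adversarial`, `is_distinctly_adversarial`) are hypothetical and not part of the authors' tool.

```python
from typing import Sequence


def argmax(scores: Sequence[float]) -> int:
    """Index of the largest score; ties go to the lowest index."""
    return max(range(len(scores)), key=lambda i: scores[i])


def is_adversarial(scores_x: Sequence[float], scores_x0: Sequence[float]) -> bool:
    """Violation of Definition 1 at x: x gets a different label than x0."""
    return argmax(scores_x) != argmax(scores_x0)


def is_distinctly_adversarial(
    scores_x: Sequence[float], scores_x0: Sequence[float], delta: float
) -> bool:
    """Violation of Definition 2 at x: different label AND c(x) >= delta."""
    c_x = max(scores_x)  # c(x) = max(N(x))
    return is_adversarial(scores_x, scores_x0) and c_x >= delta


n_x0 = [0.90, 0.05, 0.05]  # N(x0): label 0 with high confidence
print(is_distinctly_adversarial([0.15, 0.20, 0.65], n_x0, delta=0.6))  # True
print(is_distinctly_adversarial([0.35, 0.45, 0.20], n_x0, delta=0.6))  # False
```

The second call is misclassified (label 1 instead of 0), but with a confidence below δ, so it does not count as distinctly adversarial.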
3 The Proposed Method

3.1 Probabilistic Robustness
Definitions 1 and 2 are geared toward an external, malicious adversary: they are concerned with the existence of an adversarial input. Here, we take a different path and follow common certification methodologies that deal with internal malfunctions of the system [10]. Specifically, we focus on "non-malicious adversaries", i.e., we assume that perturbations occur naturally and are not necessarily malicious. This is represented by assuming that those perturbations are randomly drawn from some distribution. We argue that the non-malicious adversary setting is more realistic for widely-deployed systems in, e.g., aerospace, which are expected to operate at a large scale and over a prolonged period of time, and are more likely to encounter randomly-perturbed inputs than inputs crafted by a malicious adversary. Targeting randomly generated adversarial inputs requires extending Definitions 1 and 2 into a probabilistic definition, as follows:

Definition 3. The (δ, ε)-probabilistic-local-robustness score of a DNN N at input point $x_0$, abbreviated $plr_{\delta,\varepsilon}(N, x_0)$, is defined as:
$$plr_{\delta,\varepsilon}(N, x_0) \triangleq \Pr_{x:\ \|x - x_0\|_\infty \le \varepsilon}\Big[\big(\operatorname{argmax}(N(x)) = \operatorname{argmax}(N(x_0))\big) \lor \big(c(x) < \delta\big)\Big]$$

Intuitively, the definition measures the probability that an input x, drawn at random from the ε-ball around $x_0$, will either have the same label as $x_0$ or, if it does not, will receive a confidence score lower than δ for its (incorrect) label. A key point is that probabilistic robustness, as defined in Definition 3, is a scalar value: the closer this value is to 1, the less likely it is that a random perturbation to $x_0$ would produce a distinctly adversarial input. This is in contrast to Definitions 1 and 2, which are Boolean in nature. We also note that the probability in Definition 3 can be computed with respect to values of x drawn according to any input distribution of interest. For simplicity, unless otherwise stated, we assume that x is drawn uniformly at random. In practice, we propose to compute $plr_{\delta,\varepsilon}(N, x_0)$ by first computing the probability that a randomly drawn x is a distinctly adversarial input, and then taking that probability's complement.

¹ The term confidence is sometimes used to represent the reliability of the DNN as a whole; this is not our intention here.
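Definition 3 suggests a direct Monte Carlo estimator: sample x uniformly from the $L_\infty$ ε-ball (each coordinate independently uniform in $[x_{0,i} - \varepsilon, x_{0,i} + \varepsilon]$) and count how often the event inside the probability holds. The sketch below implements this naive estimator for any classifier callable; `toy_network` is a hypothetical three-class linear softmax model used only for illustration, and clipping to a valid pixel range is omitted. It also applies a standard zero-failure binomial bound (not part of RoMA) to show how many samples are needed before observing no adversarial input supports a given upper bound on the adversarial rate, which illustrates the sample-size problem discussed next.

```python
import math
import random
from typing import Callable, List, Sequence

Network = Callable[[Sequence[float]], List[float]]


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_network(x: Sequence[float]) -> List[float]:
    """Hypothetical 2-input, 3-class linear softmax classifier standing in for N."""
    weights = [[2.0, -1.0], [-1.0, 2.0], [0.5, 0.5]]
    biases = [0.0, 0.0, -0.2]
    logits = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, biases)]
    return softmax(logits)


def sample_linf_ball(x0: Sequence[float], eps: float, rng: random.Random) -> List[float]:
    """Uniform sample from the L-infinity ball of radius eps around x0."""
    return [xi + rng.uniform(-eps, eps) for xi in x0]


def monte_carlo_plr(network: Network, x0, eps, delta, n, seed=0) -> float:
    """Fraction of n uniform samples for which the event in Definition 3 holds."""
    rng = random.Random(seed)
    s0 = network(x0)
    label0 = max(range(len(s0)), key=s0.__getitem__)
    robust = 0
    for _ in range(n):
        s = network(sample_linf_ball(x0, eps, rng))
        label = max(range(len(s)), key=s.__getitem__)
        if label == label0 or max(s) < delta:
            robust += 1
    return robust / n


def samples_for_zero_failure_bound(p_max: float, confidence: float = 0.95) -> int:
    """Smallest n with (1 - p_max)**n <= 1 - confidence: if n samples contain no
    adversarial input, the adversarial rate is below p_max at that confidence."""
    return math.ceil(math.log(1.0 - confidence) / math.log1p(-p_max))


print(monte_carlo_plr(toy_network, [0.5, 0.45], eps=0.2, delta=0.6, n=10_000))
print(samples_for_zero_failure_bound(1e-4))  # 29956 samples for a single input point
```

The required sample count grows roughly as $3/p_{\max}$, so bounding very small adversarial rates this way quickly becomes expensive.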
Unfortunately, directly bounding the probability of randomly encountering an adversarial input, e.g., with the Monte Carlo or Bernoulli methods [13], is not feasible due to the typical extreme sparsity of adversarial inputs, and the large number of samples required to achieve reasonable accuracy [27]. Thus, we require a different statistical approach to obtain this measure using only a reasonable number of samples. We next propose such an approach.

3.2 Sampling Method and the Normal Distribution
Our approach is to measure the probability of randomly encountering an adversarial input by examining a finite set of perturbed samples around $x_0$. Each perturbation is selected through simple random sampling [24] (although other sampling methods can be used), while ensuring that the overall perturbation size does not exceed the given ε. Next, each perturbed input x is passed through the DNN to obtain a vector of confidence scores for the possible output labels. From this vector, we extract the highest incorrect confidence (hic) score:
$$hic(x) = \max_{i \ne \operatorname{argmax}(N(x_0))} \{N(x)[i]\}$$
which is the highest confidence score assigned to an incorrect label, i.e., a label different from the one assigned to $x_0$. Provided that δ > 0.5 (as in all of our experiments), input x is distinctly adversarial if and only if hic(x) ≥ δ: softmax scores sum to 1, so a score above 0.5 is necessarily the largest one, and x is then assigned an incorrect label with confidence at least δ. The main remaining question is how to extrapolate from the collected hic values a conclusion regarding the hic values in the general population. The normal distribution is a useful notion in this context: if the hic values are distributed normally (as determined by a statistical test), it is straightforward to obtain such a conclusion, even if adversarial inputs are scarce.

To illustrate this process, we trained a VGG16 DNN model (information about the trained model and the dataset appears in Sect. 4), and examined an arbitrary point $x_0$ from its test set. We randomly generated 10,000 perturbed images around $x_0$ with ε = 0.04, and ran them through the DNN. For each output vector obtained this way, we collected the hic value, and then plotted these values as the blue histogram in Fig. 1. The green curve represents the normal distribution. As the figure shows, the data is normally distributed; this claim is supported by a goodness-of-fit test (explained later).
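The hic extraction, together with the normal-tail estimate that the following paragraph walks through, can be sketched with the Python standard library alone (`statistics.NormalDist`, available since Python 3.8). This sketch assumes the hic samples have already passed a normality check; the function names are illustrative, not the authors' API.

```python
from statistics import NormalDist, fmean, stdev
from typing import Sequence


def hic(scores_x: Sequence[float], label_x0: int) -> float:
    """Highest incorrect confidence: the largest score over labels other than x0's."""
    return max(s for i, s in enumerate(scores_x) if i != label_x0)


def plr_from_normal_hic(hic_samples: Sequence[float], delta: float) -> float:
    """Estimate P[hic(x) < delta], assuming hic(x) ~ N(mu, sigma^2).

    mu and sigma are estimated from the samples; the result is Phi((delta - mu) / sigma).
    """
    mu = fmean(hic_samples)
    sigma = stdev(hic_samples)  # sample standard deviation (n - 1 denominator)
    z = (delta - mu) / sigma  # Z-score of the threshold
    return NormalDist().cdf(z)


print(hic([0.10, 0.70, 0.20], label_x0=0))  # 0.7: x is pushed toward label 1
print(round(NormalDist().cdf(1.741), 3))  # 0.959, the standard normal CDF at z = 1.741
```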
Fig. 1. A histogram depicting the highest incorrect confidence (hic) scores assigned to each of 10,000 perturbed inputs. These scores are normally distributed.
Our goal is to compute the probability that a fresh, randomly-perturbed input is distinctly misclassified, i.e., is assigned a hic score that reaches a given δ, say 60%. For normally distributed data, as in this case, we begin by calculating the statistical standard score (Z-score), which is the number of standard deviations by which the value of a raw score exceeds the mean value. Using the Z-score, we can compute the probability of the event using the Gaussian cumulative distribution function. In our case, we get $hic(x) \sim \mathcal{N}(\mu = 0.499, \Sigma = 0.059^2)$, where μ is the average score and Σ is the variance. The Z-score is $z = (\delta - \mu)/\sigma = 1.741$, where $\sigma = \sqrt{\Sigma}$ is the standard deviation. Recall that our goal is to compute the plr score, which is the probability of the hic value not exceeding δ; and so we obtain:
$$plr_{0.6,0.04}(N, x_0) = \text{NormalDistribution}(z) = \frac{1}{\sqrt{2\pi}} \int_{-\infty}^{1.741} e^{-t^2/2}\, dt \approx 0.9591$$

We thus arrive at a probabilistic local robustness score of 95.91%.

Because our data is obtained empirically, before we can apply the aforementioned approach we need a reliable way to determine whether the data is distributed normally. A goodness-of-fit test is a procedure for determining whether a set of n samples can be considered as drawn from a specified distribution. A common goodness-of-fit test for the normal distribution is the Anderson-Darling test [2], which gives particular weight to samples in the tails of the distribution [4]. In our evaluation, a distribution was considered normal only if the Anderson-Darling test did not reject normality at a significance level of α = 0.15; this is a relatively strict criterion, under which a distribution is accepted as normal only if no major deviation from normality is detected.

3.3 The Box-Cox Transformation
Unfortunately, most often the hic values are not normally distributed. For example, in our experiments we observed that only 1,282 of the 10,000 images in the CIFAR10 test set (fewer than 13%) demonstrated normally distributed hic values. Figure 2(a) illustrates the abnormal distribution of hic values of perturbed inputs around one of the input points. In such cases, we cannot use the normal distribution function to estimate the probability of adversarial inputs in the population.

The strategy that we propose for handling abnormal distributions of data, like the one depicted in Fig. 2(a), is to apply statistical transformations. Such transformations preserve key properties of the data while producing a normally distributed measurement scale [12], effectively converting the given distribution into a normal one. There are two main transformations used to normalize probability distributions: Box-Cox [5] and Yeo-Johnson [29]. Here, we focus on the Box-Cox power transformation, which applies to strictly positive data values (as in our case); Yeo-Johnson extends it to zero and negative values. Box-Cox is a continuous power transform, parameterized by a real-valued λ and defined piecewise, as follows:

Definition 4. The $\text{BoxCox}_\lambda$ power transformation of an input x > 0 is:
$$\text{BoxCox}_\lambda(x) = \begin{cases} \dfrac{x^\lambda - 1}{\lambda} & \text{if } \lambda \ne 0 \\ \ln(x) & \text{if } \lambda = 0 \end{cases}$$

For every λ, this function is strictly increasing in x, so hic(x) ≥ δ holds if and only if $\text{BoxCox}_\lambda(hic(x)) \ge \text{BoxCox}_\lambda(\delta)$; the threshold δ can therefore be transformed together with the data. The selection of the λ value is crucial for the successful normalization of the data. There are multiple automated methods for λ selection, which go beyond our scope here [22]. For our implementation of the technique, we used the common SciPy Python package [23], which implements one of these automated methods.
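To make Definition 4 and the automated choice of λ concrete, the sketch below implements the transformation and selects λ by maximizing the standard Box-Cox profile log-likelihood over a grid. This is an illustrative stand-in, not SciPy's implementation (`scipy.stats.boxcox` maximizes the same likelihood with a numerical optimizer rather than a grid); the grid range [-2, 2] and the step of 0.01 are assumptions.

```python
import math
from typing import List, Sequence, Tuple


def boxcox(x: float, lam: float) -> float:
    """Definition 4: Box-Cox power transform of a strictly positive x."""
    if x <= 0:
        raise ValueError("Box-Cox requires strictly positive data")
    return math.log(x) if lam == 0 else (x**lam - 1.0) / lam


def boxcox_log_likelihood(data: Sequence[float], lam: float) -> float:
    """Profile log-likelihood of lam (up to a constant), assuming the transformed
    data are normal: (lam - 1) * sum(ln x) - (n / 2) * ln(maximum-likelihood variance)."""
    n = len(data)
    y = [boxcox(v, lam) for v in data]
    mean_y = sum(y) / n
    var_y = sum((v - mean_y) ** 2 for v in y) / n
    if var_y == 0.0:
        return -math.inf  # constant data carry no information about lam
    return (lam - 1.0) * sum(math.log(v) for v in data) - 0.5 * n * math.log(var_y)


def select_lambda(data: Sequence[float]) -> float:
    """Grid search over lam in [-2, 2] with step 0.01 for the maximum-likelihood value."""
    grid = [i / 100 for i in range(-200, 201)]
    return max(grid, key=lambda lam: boxcox_log_likelihood(data, lam))


def boxcox_normalize(data: Sequence[float]) -> Tuple[List[float], float]:
    """Transform a sample with its selected lam; returns (transformed data, lam)."""
    lam = select_lambda(data)
    return [boxcox(v, lam) for v in data], lam


# The threshold delta is transformed with the same lam as the data, e.g., the
# 60% threshold under lam = 0.534:
print(round(boxcox(0.6, 0.534), 3))  # -0.447
```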
Fig. 2. (a) Top: a histogram depicting the highest incorrect confidence (hic) scores of each of 10,000 perturbed inputs around one of the test points; these scores are not normally distributed. (b) Bottom: the same scores after applying the Box-Cox power transformation, now normally distributed.
Figure 2(b) depicts the distribution of the data from Fig. 2(a) after applying the Box-Cox transformation with an automatically calculated value of λ = 0.534. As the figure shows, the data is now normally distributed: $\text{BoxCox}_\lambda(hic(x)) \sim \mathcal{N}(\mu = -0.79, \Sigma = 0.092^2)$. The normal distribution was confirmed with the Anderson-Darling test. Following the Box-Cox transformation, and after applying the same transformation to the threshold ($\text{BoxCox}_{0.534}(0.6) \approx -0.447$), we can now calculate the Z-score, which gives 3.71, and the corresponding plr score, which turns out to be 99.98%.

3.4 The RoMA Certification Algorithm
Based on the previous sections, our method for computing plr scores is given as Algorithm 1. The inputs to the algorithm are: (i) δ, the confidence threshold for a distinctly adversarial input; (ii) ε, the maximum amplitude of perturbation that can be added to $x_0$; (iii) $x_0$, the input point whose plr score is being computed; (iv) n, the number of perturbed samples to generate around $x_0$; (v) N, the neural network; and (vi) D, the distribution from which perturbations are drawn. The algorithm starts by generating n perturbed inputs around the provided $x_0$, each drawn according to the provided distribution D and with a perturbation that does not exceed ε (lines 1–2), and storing the hic score of each of these inputs in the hic array (line 3). Next, lines 5–10 confirm that the samples' hic values are distributed normally, applying the Box-Cox transformation if needed. Finally, on lines 11–13, the algorithm calculates the probability of randomly perturbing the input into a distinctly adversarial input using the properties of the normal distribution, and returns the computed $plr_{\delta,\varepsilon}(N, x_0)$ score on line 14.

Algorithm 1. Compute Probabilistic Local Robustness(δ, ε, x0, n, N, D)
1:  for i := 1 to n do
2:      x_i ← CreatePerturbedPoint(x0, ε, D)
3:      hic[i] ← hic(Predict(N, x_i))
4:  end for
5:  if Anderson-Darling(hic) ≠ NORMAL then
6:      hic ← Box-Cox_λ(hic); δ ← Box-Cox_λ(δ)    ▷ the same λ is applied to the threshold
7:      if Anderson-Darling(hic) ≠ NORMAL then
8:          return "Fail"
9:      end if
10: end if
11: avg ← Average(hic)
12: std ← StdDev(hic)
13: z-score ← Z-Score(avg, std, δ)
14: return NormalDistribution(z-score)
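A compact, standard-library-only sketch of Algorithm 1 follows, with comments mapping to its line numbers. Assumptions: perturbations are drawn uniformly from the $L_\infty$ ε-ball (one choice of D); the Anderson-Darling check compares the $A^2$ statistic with the 15% critical value 0.576, using the sample-size adjustment $0.576/(1 + 4/n - 25/n^2)$ that SciPy's `scipy.stats.anderson` applies in the normal case; λ is chosen by a grid search over the Box-Cox log-likelihood instead of SciPy's optimizer; `toy_network` is hypothetical; and "Fail" is returned as `None`. This is not the authors' tool, which is available at [18].

```python
import math
import random
from statistics import NormalDist, fmean, stdev
from typing import Callable, List, Optional, Sequence

STD_NORMAL = NormalDist()
Network = Callable[[Sequence[float]], List[float]]


def anderson_darling_is_normal(sample: Sequence[float]) -> bool:
    """Accept normality iff A^2 < 0.576 / (1 + 4/n - 25/n^2) (15% significance)."""
    n = len(sample)
    mu, sigma = fmean(sample), stdev(sample)
    if sigma == 0.0:
        return False  # a constant sample cannot be treated as normal
    # Standardize, sort, and clamp CDF values away from 0 and 1 to keep logs finite.
    u = [min(max(STD_NORMAL.cdf((v - mu) / sigma), 1e-12), 1.0 - 1e-12) for v in sorted(sample)]
    a2 = -n - sum(
        (2 * i + 1) / n * (math.log(u[i]) + math.log(1.0 - u[n - 1 - i])) for i in range(n)
    )
    return a2 < 0.576 / (1.0 + 4.0 / n - 25.0 / n**2)


def boxcox(x: float, lam: float) -> float:
    """Definition 4 (x must be strictly positive; softmax scores are)."""
    return math.log(x) if lam == 0 else (x**lam - 1.0) / lam


def select_lambda(data: Sequence[float]) -> float:
    """Grid-search maximum-likelihood lambda in [-2, 2] (stand-in for SciPy's optimizer)."""
    n, log_sum = len(data), sum(math.log(v) for v in data)

    def log_likelihood(lam: float) -> float:
        y = [boxcox(v, lam) for v in data]
        m = fmean(y)
        var = sum((v - m) ** 2 for v in y) / n
        return -math.inf if var == 0.0 else (lam - 1.0) * log_sum - 0.5 * n * math.log(var)

    return max((i / 100 for i in range(-200, 201)), key=log_likelihood)


def compute_plr(delta: float, eps: float, x0: Sequence[float], n: int,
                network: Network, seed: int = 0) -> Optional[float]:
    """Algorithm 1 with D = uniform on the L-infinity eps-ball; None means "Fail"."""
    rng = random.Random(seed)
    s0 = network(x0)
    label0 = max(range(len(s0)), key=s0.__getitem__)
    hic = []
    for _ in range(n):  # lines 1-4: sample perturbations, record hic scores
        scores = network([xi + rng.uniform(-eps, eps) for xi in x0])
        hic.append(max(s for i, s in enumerate(scores) if i != label0))
    if not anderson_darling_is_normal(hic):  # line 5
        lam = select_lambda(hic)  # line 6: one lambda for the data and for delta
        hic, delta = [boxcox(v, lam) for v in hic], boxcox(delta, lam)
        if not anderson_darling_is_normal(hic):  # lines 7-9
            return None
    z = (delta - fmean(hic)) / stdev(hic)  # lines 11-13
    return STD_NORMAL.cdf(z)  # line 14: P[hic < delta]


def toy_network(x: Sequence[float]) -> List[float]:
    """Hypothetical 2-input, 3-class linear softmax model standing in for N."""
    logits = [2.0 * x[0] - x[1], -x[0] + 2.0 * x[1], 0.5 * (x[0] + x[1]) - 0.2]
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


print(compute_plr(delta=0.6, eps=0.04, x0=[0.5, 0.45], n=1000, network=toy_network))
```

With ε = 0 every hic score is identical, so both normality checks fail and the sketch returns `None`, mirroring the failure mode discussed below for very small ε.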
Soundness and Completeness. Algorithm 1 depends on the distribution of hic(x) being normal. If this is initially not so, the algorithm attempts to normalize it using the Box-Cox transformation. The Anderson-Darling goodness-of-fit test prevents the algorithm from treating an abnormal distribution as a normal one, and thus guarantees the soundness of the computed plr scores, up to the statistical error inherent in the test. The algorithm's completeness depends on its ability to always obtain a normal distribution. As our evaluation demonstrates, the Box-Cox transformation can indeed lead to a normal distribution very often. However, the transformation might fail to produce a normal distribution; this failure will be identified by the Anderson-Darling test, and our algorithm will stop with a failure notice in such cases. In that sense, Algorithm 1 is incomplete. In practice, failure notices by the algorithm can sometimes be circumvented by increasing the sample size, or by evaluating the robustness of other input points.

In our evaluation, we observed that the success of Box-Cox often depends on the value of ε. Small or large values of ε more often led to failures, whereas mid-range values more often led to success. We speculate that small values of ε, which allow only tiny perturbations to the input, cause the model to assign similar hic values to all points in the ε-ball, resulting in little variety among the hic values of the sampled points; consequently, the distribution of hic values is nearly constant, and so cannot be normalized. We further speculate that for large values of ε, where the corresponding ε-ball contains a significant chunk of the input space, the sampling produces a close-to-uniform distribution of all possible labels, and consequently a close-to-uniform distribution of hic values, which again cannot be normalized. We thus argue that the mid-range values of ε are the more relevant ones. Adding better support for cases where Box-Cox fails, for example by using additional statistical transformations and providing informative output to the user, remains a work in progress.
4 Evaluation
For evaluation purposes, we implemented Algorithm 1 as a proof-of-concept tool written in Python 3.7.10, which uses the TensorFlow 2.5 and Keras 2.4 frameworks. For our DNN, we used a VGG16 network trained for 200 epochs over the CIFAR10 data set. All experiments mentioned below were run using the Google Colab Pro environment, with an NVIDIA GPU (NVIDIA-SMI 470.74) and a single-core Intel(R) Xeon(R) CPU @ 2.20GHz. The code for the tool, the experiments, and the model's training is available online [18].

Experiment 1: Measuring Robustness Sensitivity to Perturbation Size. By our notion of robustness given in Definition 3, the $plr_{\delta,\varepsilon}(N, x_0)$ score is likely to decrease as ε increases. For our first experiment, we set out to measure the rate of this decrease. We repeatedly invoked Algorithm 1 (with δ = 60%, n = 1,000) to compute plr scores for increasing values of ε. Instead of selecting a single $x_0$, which may not be indicative, we ran the algorithm on all 10,000 images in the CIFAR10 test set, and computed the average plr score for each value of ε. The results are depicted in Fig. 3, and indicate a strong negative correlation between ε and the robustness score: the average plr score decreases as ε grows. This result is supported by earlier findings [27].
[Plot: average plr (y-axis, 99.70% to 100.00%) as a function of the value of ε (x-axis)]
Fig. 3. Average plr score of all 10,000 images from the CIFAR10 dataset, computed on our VGG16 model as a function of ε.
Running the experiment took less than 400 min, and the algorithm completed successfully (i.e., did not fail) on 82% of the queries. We note here that Algorithm 1 naturally lends itself to parallelization, as each perturbed input can be evaluated independently of the others; we leave adding these capabilities to our proof-of-concept implementation for future work.
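Because each perturbed input is evaluated independently, the sampling loop (lines 1–4 of Algorithm 1) can be distributed across workers. The following is a minimal sketch of that idea using `concurrent.futures.ThreadPoolExecutor`, assuming a thread-safe `network` callable; `toy_network` and `hic_samples_parallel` are hypothetical. For a real TensorFlow/Keras model, batching all n perturbed inputs into a single prediction call is often a simpler way to exploit the same independence; threads mainly pay off when the underlying inference releases Python's global interpreter lock.

```python
import math
import random
from concurrent.futures import ThreadPoolExecutor
from typing import List, Sequence


def toy_network(x: Sequence[float]) -> List[float]:
    """Hypothetical stand-in for N: a 2-input, 3-class linear softmax model."""
    logits = [2.0 * x[0] - x[1], -x[0] + 2.0 * x[1], 0.5 * (x[0] + x[1]) - 0.2]
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def hic_samples_parallel(network, x0, eps, n, seed=0, workers=4) -> List[float]:
    """Draw n uniform L-infinity perturbations of x0 and score their hic values in parallel."""
    rng = random.Random(seed)
    # Draw all perturbations up front so results do not depend on thread scheduling.
    points = [[xi + rng.uniform(-eps, eps) for xi in x0] for _ in range(n)]
    s0 = network(x0)
    label0 = max(range(len(s0)), key=s0.__getitem__)

    def hic_of(x: Sequence[float]) -> float:
        scores = network(x)
        return max(s for i, s in enumerate(scores) if i != label0)

    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(hic_of, points))  # map preserves input order


hics = hic_samples_parallel(toy_network, [0.5, 0.45], eps=0.04, n=1000)
print(len(hics), min(hics), max(hics))
```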
Experiment 2: Measuring Robustness Sensitivity to Training Epochs. In this experiment, we measured the sensitivity of the model's robustness to the number of epochs in the training process. We ran Algorithm 1 (with δ = 60%, ε = 0.04, n = 1,000) on VGG16 models trained for different numbers of epochs, computing the average plr scores on all 10,000 images from the CIFAR10 test set. The computed plr values are plotted as a function of the number of epochs in Fig. 4. The results indicate that additional training leads to improved probabilistic local robustness. These results are also in line with previous work [27].
[Plot: average plr (y-axis, 99.00% to 100.00%) as a function of the number of epochs (x-axis, 10 to 100)]
Fig. 4. Average plr score of all 10,000 images from the CIFAR10 test set, computed on the VGG16 model as a function of training epochs.
Experiment 3: Categorial Robustness. For our final experiment, we focused on categorial robustness, and specifically on comparing the robustness scores across categories. We ran Algorithm 1 (δ = 60%, ε = 0.04, and n = 1,000) on our VGG16 model for all 10,000 CIFAR10 test set images. The results, divided by category, appear in Fig. 5. For each category we list the average plr score, the standard deviation of the data (which indicates the scattering for each category), and the probability of an adversarial input (the "Adv" column, calculated as 1 − plr). Performing this experiment took 37 min. Algorithm 1 completed successfully on 90.48% of the queries.

The results expose an interesting insight, namely the high variability in robustness between the different categories. For example, the probability of encountering an adversarial input for inputs classified as Cats is four times greater than the probability of encountering an adversarial input for inputs classified as Trucks. We observe that the standard deviation for these two categories is very small, which indicates that they are "far apart": the difference between Cats and Trucks, as determined by the network, is generally greater than the difference between two Cats or between two Trucks. To corroborate this, we applied a t-test and a binomial test, and these tests produced a similarity score of less than 0.1%, indicating that the two categories are indeed distinctly different. The important conclusion that we can draw is that the per-category robustness of models can be far from uniform.
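The per-category statistics summarized in Fig. 5 can be assembled from per-image results along the following lines. This is a sketch under assumptions: each record is a hypothetical `(category, plr)` pair, runs where Algorithm 1 returned "Fail" are represented by `None` and excluded, the standard deviation is the population standard deviation (the paper does not specify which), and Adv is computed as 1 − plr, as in the text. The input values in the usage example are made up for illustration and are not the paper's results.

```python
from collections import defaultdict
from statistics import fmean, pstdev
from typing import Dict, Iterable, List, Optional, Tuple


def categorial_robustness(
    records: Iterable[Tuple[str, Optional[float]]]
) -> Tuple[Dict[str, Tuple[float, float, float]], float]:
    """Return ({category: (mean plr, std of plr, Adv = 1 - mean plr)}, success rate)."""
    by_category: Dict[str, List[float]] = defaultdict(list)
    total = succeeded = 0
    for category, plr in records:
        total += 1
        if plr is None:  # Algorithm 1 failed on this input
            continue
        succeeded += 1
        by_category[category].append(plr)
    stats = {
        c: (fmean(v), pstdev(v), 1.0 - fmean(v)) for c, v in sorted(by_category.items())
    }
    return stats, (succeeded / total if total else 0.0)


# Illustrative, made-up inputs (not results from the paper).
demo = [("cat", 0.9990), ("cat", 0.9994), ("truck", 0.9998), ("truck", 0.9999), ("cat", None)]
stats, success_rate = categorial_robustness(demo)
for category, (mean_plr, std_plr, adv) in stats.items():
    print(f"{category}: plr={mean_plr:.4%} std={std_plr:.4%} Adv={adv:.4%}")
print(f"success rate: {success_rate:.0%}")  # 80%
```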
Fig. 5. An analysis of average, per-category robustness, computed over all 10,000 images from the CIFAR10 dataset.
It is common in certification methodology to assign each sub-system a different robustness objective score, depending on the sub-system’s criticality [10]. Yet, to the best of our knowledge, this is the first time such differences in neural networks’ categorial robustness have been measured and reported. We believe categorial robustness could affect DNN certification efforts, by allowing engineers to require separate robustness thresholds for different categories. For example, for a traffic sign recognition DNN, a user might require a high robustness score for the “stop sign” category, and be willing to settle for a lower robustness score for the “parking sign” category.
5 Summary and Discussion
In this paper, we introduced RoMA, a novel statistical and scalable method for measuring the probabilistic local robustness of a black-box, large-scale DNN model. We demonstrated RoMA's applicability in several aspects. The key advantages of RoMA over existing methods are: (i) it uses a straightforward and intuitive statistical method for measuring DNN robustness; (ii) it is scalable; and (iii) it works on black-box DNN models, without assumptions such as Lipschitz continuity constraints. Our approach's limitations stem from its dependence on the normality of the hic scores of perturbed inputs, and its failure to produce a result when the Box-Cox transformation does not normalize the data. The plr scores computed by RoMA indicate the risk of using a DNN model, and can allow regulatory agencies to conduct risk mitigation procedures: a common practice for integrating sub-systems into safety-critical systems. The ability to perform risk and robustness assessment is an important step towards using DNN models in safety-critical applications, such as medical devices, UAVs, automotive systems, and others. We believe that our work also showcases the potential key role of categorial robustness in this endeavor.
Moving forward, we intend to: (i) evaluate our tool on additional norms, beyond $L_\infty$; (ii) better characterize the cases where the Box-Cox transformation fails, and search for other statistical tools that can succeed in those cases; and (iii) improve the scalability of our tool by adding parallelization capabilities.

Acknowledgments. We thank Dr. Pavel Grabov from Tel Aviv University for his valuable comments and support.
References 1. Anderson, B., Sojoudi, S.: Data-Driven Assessment of Deep Neural Networks with Random Input Uncertainty. Technical report (2020). arxiv:abs/2010.01171 2. Anderson, T.: Anderson-Darling tests of goodness-of-fit. Int. Encycl. Statist. Sci. 1, 52–54 (2011) 3. Bastani, O., Ioannou, Y., Lampropoulos, L., Vytiniotis, D., Nori, A., Criminisi, A.: Measuring neural net robustness with constraints. In: Proceedings of 30th Conference on Neural Information Processing Systems (NIPS) (2016) 4. Berlinger, M., Kolling, S., Schneider, J.: A generalized Anderson-Darling test for the goodness-of-fit evaluation of the fracture strain distribution of acrylic glass. Glass Struct. Eng. 6(2), 195–208 (2021) 5. Box, G., Cox, D.: An analysis of transformations revisited, rebutted. J. Am. Stat. Assoc. 77(377), 209–210 (1982) 6. Carlini, N., Wagner, D.: Towards Evaluating the Robustness of Neural Networks. In: Proceedings of 2017 IEEE Symposium on Security and Privacy (S&P), pp. 39–57 (2017) 7. Cohen, J., Rosenfeld, E., Kolter, Z.: Certified Adversarial Robustness via Randomized Smoothing. In: Proceedings of 36th International Conference on Machine Learning (ICML) (2019) 8. Dvijotham, K., Garnelo, M., Fawzi, A., Kohli, P.: Verification of Deep Probabilistic Models. Technical report (2018). arXiv:abs/1812.02795 9. European Union Aviation Safety Agency: Artificial Intelligence Roadmap: A Human-Centric Approach To AI In Aviation (2020). https://www.easa.europa. eu/newsroom-and-events/news/easa-artificial-intelligence-roadmap-10-published 10. Federal Aviation Administration: RTCA Inc, Document RTCA/DO-178B (1993). https://nla.gov.au/nla.cat-vn4510326 11. Goodfellow, I., Shlens, J., Szegedy, C.: Explaining and Harnessing Adversarial Examples. Technical report (2014). arXiv:abs/1412.6572 12. Griffith, D., Amrhein, C., Huriot, J.M.: Econometric Advances in Spatial Modelling and Methodology: Essays in Honour of Jean Paelinck. ASTA, Springer Science & Business Media, New York (2013). 
https://doi.org/10.1007/978-1-4757-2899-6 13. Hammersley, J.: Monte Carlo Methods. MSAP, Springer Science & Business Media, Dordrecht (2013). https://doi.org/10.1007/978-94-009-5819-7 14. Huang, C., Hu, Z., Huang, X., Pei, K.: Statistical certification of acceptable robustness for neural networks. In: Proceedings International Conference on Artificial Neural Networks (ICANN), pp. 79–90 (2021) 15. Katz, G., Barrett, C., Dill, D., Julian, K., Kochenderfer, M.: Reluplex: an efficient SMT solver for verifying deep neural networks. In: Proceedings of 29th International Conference on Computer Aided Verification (CAV), pp. 97–117 (2017)
RoMA: Robustness Measurement and Assessment
16. Katz, G., Barrett, C., Dill, D., Julian, K., Kochenderfer, M.: Reluplex: a calculus for reasoning about deep neural networks. In: Formal Methods in System Design (FMSD) (2021)
17. Landi, A., Nicholson, M.: ARP4754A/ED-79A-guidelines for development of civil aircraft and systems-enhancements, novelties and key topics. SAE Int. J. Aerosp. 4, 871–879 (2011)
18. Levy, N., Katz, G.: RoMA: Code and Experiments (2022). https://drive.google.com/drive/folders/1hW474gRoNi313G1 bRzaR2cHG5DLCnJl
19. Mangal, R., Nori, A., Orso, A.: Robustness of neural networks: a probabilistic and practical approach. In: Proceedings of 41st IEEE/ACM International Conference on Software Engineering: New Ideas and Emerging Results (ICSE-NIER), pp. 93–96 (2019)
20. Michelmore, R., Kwiatkowska, M., Gal, Y.: Evaluating Uncertainty Quantification in End-to-End Autonomous Driving Control. Technical report (2018). arXiv:abs/1811.06817
21. Pereira, A., Thomas, C.: Challenges of machine learning applied to safety-critical cyber-physical systems. Mach. Learn. Knowl. Extract. 2(4), 579–602 (2020)
22. Rossi, R.: Mathematical Statistics: an Introduction to Likelihood Based Inference. John Wiley & Sons, New York (2018)
23. Scipy: Scipy Python package (2021). https://scipy.org
24. Taherdoost, H.: Sampling methods in research methodology; how to choose a sampling technique for research. Int. J. Acad. Res. Manage. (IJARM) (2016)
25. Tit, K., Furon, T., Rousset, M.: Efficient statistical assessment of neural network corruption robustness. In: Proceedings of 35th Conference on Neural Information Processing Systems (NeurIPS) (2021)
26. Vidot, G., Gabreau, C., Ober, I., Ober, I.: Certification of Embedded Systems Based on Machine Learning: A Survey. Technical report (2021). arXiv:abs/2106.07221
27. Webb, S., Rainforth, T., Teh, Y., Kumar, M.: A Statistical Approach to Assessing Neural Network Robustness. Technical report (2018). arXiv:abs/1811.07209
28. Weng, L., et al.: PROVEN: verifying robustness of neural networks with a probabilistic approach. In: Proceedings of 36th International Conference on Machine Learning (ICML) (2019)
29. Yeo, I.K., Johnson, R.: A new family of power transformations to improve normality or symmetry. Biometrika 87(4), 954–959 (2000)
Independent Relationship Detection for Real-Time Scene Graph Generation
Tianlei Jin (✉), Wen Wang, Shiqiang Zhu, Xiangming Xi, Qiwei Meng, Zonghao Mu, and Wei Song (✉)
Zhejiang Laboratory, Intelligent Robot Research Center, Hangzhou, China
{jtl,wangwen,zhusq,xxm21,mengqw,muzonghao,weisong}@zhejianglab.com
Abstract. The current scene graph generation (SGG) task still follows the approach of first detecting object pairs and then predicting the relationships between them. This paper introduces a parallel SGG paradigm that decouples relationship detection from object detection. Specifically, we propose an independent visual relationship detection method, 'Relationship You Only Look Once' (RYOLO), which computes relationships directly from the input image. For SGG, we present Similar Relationship Suppression and Object Matching rules to match relationships with detected objects. In this way, relationship detection and object detection can be computed in parallel, and detected relationships can easily be combined with detected objects to generate diverse scene graphs. Finally, we verify the feasibility of this paradigm on the public Visual Genome dataset, and our method may be the first to achieve real-time SGG.
Keywords: Relationship Detection · Scene Graph Generation · Relationship You Only Look Once

1 Introduction
Computer vision has recently achieved great success in perceptual tasks such as object detection. However, inferring relationships between the perceived objects remains challenging. Scene graph generation (SGG) builds a graph of the relationships between individual objects in a scene, and scene graphs are often used as a front-end module for high-level visual understanding tasks such as image captioning [1], visual question answering [2], and visual grounding [3]. In SGG, a relationship between two objects is represented as a triple ⟨subject, predicate, object⟩. Traditional SGG works [4–8] rely on a serial structure, as shown in Fig. 1(A): in the first stage, an image is fed into an object detection model to obtain object proposals, and in the second stage, a relationship prediction model predicts relationships from these proposals. Some SGG works [5,8] rely on two-stage object detectors [9,10] to obtain intermediate features of object proposals through RoIAlign, while others [7,11] take the bounding boxes and categories produced by the object detector as input for relationship prediction. In general, the serial structure forces relationship prediction to wait for object detection, which prevents real-time scene graph generation. In addition, every object pair (subject and object) requires a separate relationship prediction, so the cost grows quadratically with the number of objects [4], and a large number of detected objects seriously slows down SGG inference. For agents such as robots, however, the surrounding scene changes constantly, so a real-time SGG method is important for rapid response.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 106–118, 2023. https://doi.org/10.1007/978-981-99-1639-9_9

Fig. 1. (A) Traditional serial SGG paradigm. (B) Our parallel SGG paradigm. (C) Independent object detection results, with the center of each bounding box computed. (D) Our relationship detection results on the grid cells. (E) Object matching for scene graph generation. (F) The direction anchors of our relationship detection method.

In this paper, we propose a parallel SGG paradigm that decouples object detection and relationship detection for real-time SGG, as shown in Fig. 1(B). Object detection and relationship detection run in parallel, and the scene graph is generated by combining the two results. For an image, we use a standard object detector such as YOLOv5 [12] to obtain the positions and categories of objects, as shown in Fig. 1(C). Meanwhile, an independent visual relationship detection method predicts the relationships present in the image. Each predicted relationship consists of a start position, an end position, and a relation type, as shown in Fig. 1(D). We then match the start and end positions to the nearest objects, using the object center positions from the detection results. Once the start and end positions are matched to two different objects, a triple of the scene graph is established, as shown in Fig. 1(E): the object matched by the start position is the subject of the triple, the object matched by the end position is its object, and the relation type is the predicate. Matching the nearest object is a batch minimum operation that replaces the previous per-pair predictions and improves efficiency. In addition, because the visual relationship detection is independent, it can easily cooperate with any object detection model to generate scene graphs.

The key technology of our SGG paradigm is the visual relationship detection method. Inspired by YOLO [13], we propose a novel independent visual relationship detection method, RYOLO, that predicts relationships from the whole image without intermediate steps such as object proposals [5,14] or knowledge embedding [15,16]. As in common visual tasks, an image is fed into a backbone network to obtain a feature map composed of W × H grid cells. As shown in Fig. 1(F), we preset several direction anchors in polar coordinates for each grid cell. Each direction anchor has an initial length $\rho^{anchor}$ and an initial radian direction $\theta^{anchor}$. Each object is assigned to a unique grid cell based on its center position, and a relationship of this object can then be expressed as a direction vector from that grid cell pointing to the center position of another object. The neural network therefore predicts direction offsets $\Delta\rho$, $\Delta\theta$ that move the direction anchors close to the direction vectors. To obtain the relation type, the network also outputs a confidence score and relation type scores for each direction anchor along with the direction offsets. Overall, our contributions can be summarized as follows:
• We propose a new parallel scene graph generation paradigm that decouples object detection and relationship detection. We verify the feasibility of this paradigm on a public dataset and achieve real-time scene graph generation.
• We propose an independent visual relationship detection method, RYOLO, with preset direction anchors in polar coordinates. RYOLO can cooperate with any object detection model to generate scene graphs.
2 Related Work
From the perspective of input information, we group recent SGG research into three categories. First, methods using external knowledge: VRD [17] and UVTransE [15] introduce an external language model to embed word features of objects and relations, and GB-NET [6] introduces an external commonsense knowledge graph into SGG. Second, methods using statistical context information from the dataset: VCTree [18] constructs a tree structure with statistical information, while KERN [19] constructs a knowledge graph. Third, methods using only the visual image: IMP [20], Graph R-CNN [4], and FCSGG [21] take the whole image as input and output the scene graph without additional operations. Our method also falls into this category. From the perspective of the relationship prediction method, MOTIFS [5], VCTree [18], and CogTree [22] build structures of grouped, ordered objects and use RNNs or LSTMs to predict relationships. Graph R-CNN [4], GPS-Net [14], and KERN [19] build graph structures and use graph neural networks to predict relationships. TDE [23] and PUM [24] propose new modules that improve performance on top of previous methods. However, these methods are all based on the serial SGG structure: they use Faster R-CNN to obtain object proposals or features and then iteratively refine object-pair features for relationship prediction. This structure makes relationship prediction heavily dependent on object detection and slows down inference. Pixels2Graphs [25] and FCSGG [21] predict objects and relationships in a single model, which is closer to a parallel structure. In this paper, we propose a parallel SGG paradigm and introduce RYOLO to detect potential relationships, together with their directions, directly from images.

Fig. 2. An example of our relationship detection method RYOLO and our scene graph generation method. We use the yolov5 backbone for feature extraction and generate multi-scale feature maps through the neck structure. Different direction anchors are preset on the feature maps. Each grid tensor contains several clusters, and each cluster contains a confidence, a direction offset, and several relation type scores. The detected relationships generate a scene graph through similar relationship suppression (SRS) and object matching.
3 Method
3.1 Independent Relationship Detection
We redesign the output of the object detection work yolov5 [12] so that it can be used for visual relationship detection, and call the result 'Relationship You Only Look Once' (RYOLO). As shown in Fig. 2, the whole image is fed into the yolov5 backbone to extract features, and the neck structure generates F feature maps at different scales. For each feature map we preset a set of direction anchors, each with a length $\rho^{anchor}$ and a direction $\theta^{anchor}$; these anchors apply to all grid cells. On each grid cell of the feature map, the grid tensor outputs several clusters, each containing the direction offsets $\Delta\rho$, $\Delta\theta$, a confidence $c$ for the relationship, and a score $s$ for each relation type. Below, we describe in detail how the start position, the end position, and the relation type are obtained from the direction anchors.

Relationship Calculation. We use F = 3 feature maps at different scales. Each feature map $f \in F$ has $W_f \times H_f$ grid cells indexed by $(i, j)$, with $i \in \{1, \dots, W_f\}$ and $j \in \{1, \dots, H_f\}$. The scale (stride) between the input image and feature map $f$ is $S_f$. Each grid cell on the feature map can serve as a start position:
$$\mathrm{start\_position}_{fij} = \left(i \times S_f,\; j \times S_f\right). \tag{1}$$
On each feature map $f$, we preset K = 6 direction anchors in the polar coordinate system, i.e., six different directions. Each direction anchor is represented as $d^{anchor}_{fk} = [\rho^{anchor}_{fk}, \theta^{anchor}_{fk}]$, with $f \in F$, $k \in K$, $\rho^{anchor}_{fk} \in [0, 1]$, and $\theta^{anchor}_{fk} \in [0, 2\pi]$, where $\rho^{anchor}_{fk}$ is the normalized length in the polar coordinate system and $\theta^{anchor}_{fk}$ is the radian direction. We divide the radian direction $\theta^{anchor}$ evenly into K parts to obtain diverse directions. Since we aim to predict short-distance relationships on large-scale feature maps and long-distance relationships on small-scale feature maps, the larger the feature map, the smaller $\rho^{anchor}$. On feature map $f$, each grid tensor also contains K predicted clusters. Each cluster contains the direction offset $\Delta d_{fkij} = [\Delta\rho_{fkij}, \Delta\theta_{fkij}]$ for the direction anchor $[\rho^{anchor}_{fk}, \theta^{anchor}_{fk}]$ on grid cell $(i, j)$, which steers the direction anchor toward the end position. The end position is therefore computed as:
$$d^{x}_{fkij} = \rho^{anchor}_{fk}\,\Delta\rho_{fkij}\cos\!\left(\theta^{anchor}_{fk} + \Delta\theta_{fkij}\right), \qquad d^{y}_{fkij} = \rho^{anchor}_{fk}\,\Delta\rho_{fkij}\sin\!\left(\theta^{anchor}_{fk} + \Delta\theta_{fkij}\right), \tag{2}$$
$$\mathrm{end\_position}_{fkij} = \mathrm{start\_position}_{fij} + d_{fkij} = \left(i \times S_f + d^{x}_{fkij},\; j \times S_f + d^{y}_{fkij}\right). \tag{3}$$
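Equations (1)–(3) can be traced with a short sketch. It is a minimal Python illustration, not the authors' implementation: the helper names are hypothetical, the anchors of one feature map share a single normalized length with angles spaced evenly over [0, 2π), and the `scale` factor that converts the normalized length into pixels is our assumption, since the text does not state how normalized lengths are mapped back to image coordinates.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def make_anchors(k: int, rho: float) -> List[Tuple[float, float]]:
    """K direction anchors (rho, theta) with angles spaced evenly over [0, 2*pi)."""
    return [(rho, 2.0 * math.pi * n / k) for n in range(k)]


def start_position(i: int, j: int, stride: float) -> Point:
    """Eq. (1): start position of grid cell (i, j) on a map with stride S_f."""
    return (i * stride, j * stride)


def end_position(i: int, j: int, stride: float,
                 anchor: Tuple[float, float],
                 offset: Tuple[float, float],
                 scale: float = 1.0) -> Point:
    """Eqs. (2)-(3): apply a predicted offset (d_rho, d_theta) to one anchor.

    d_rho rescales the anchor length and d_theta is added to the anchor angle.
    `scale` converts the normalized length to pixels (hypothetical parameter).
    """
    rho_a, theta_a = anchor
    d_rho, d_theta = offset
    length = scale * rho_a * d_rho
    dx = length * math.cos(theta_a + d_theta)
    dy = length * math.sin(theta_a + d_theta)
    sx, sy = start_position(i, j, stride)
    return (sx + dx, sy + dy)


# Usage: K = 6 anchors on a stride-8 map; decode the second anchor at cell (2, 3).
anchors = make_anchors(6, rho=0.25)
print(end_position(2, 3, 8.0, anchors[1], (2.0, 0.0), scale=640.0))
```

Because $\Delta\rho$ multiplies the anchor length rather than adding to it, a single anchor can reach both near and far end positions; the per-map choice of $\rho^{anchor}$ only sets the starting length.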
Although we call $\Delta\rho$ a direction offset, it is actually a scale factor that adjusts the length $\rho^{anchor}$, while $\Delta\theta$ is added to $\theta^{anchor}$ to adjust the radian direction. In this way, each cluster predicts a direction connecting a start position and an end position. The confidence $c$ of each cluster indicates whether the relationship exists, and each cluster also contains independent scores $s$ for each of the R relation types, which predict the relation type between the start position and the end position.

Loss and Training. The loss consists of a direction loss, a confidence loss, and a relation loss. We first pick out each $\mathrm{start\_position}_{fij}$ that has a relationship; each such position generates a direction vector $d^{vector}_{fij} = [\rho^{vector}_{fij}, \theta^{vector}_{fij}]$. For the direction loss $L_{dir}$, the label direction offset $\Delta d^{label}_{fkij} = [\Delta\rho^{label}_{fkij}, \Delta\theta^{label}_{fkij}]$ from the direction anchor $d^{anchor}_{fk}$ to the direction vector $d^{vector}_{fij}$ is simply:
$$\Delta\rho^{label}_{fkij} = \rho^{vector}_{fij} / \rho^{anchor}_{fk}, \qquad \Delta\theta^{label}_{fkij} = \theta^{vector}_{fij} - \theta^{anchor}_{fk}. \tag{4}$$
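Eq. (4) is the inverse of Eq. (2): encoding a direction vector as a label offset and then decoding it again should land exactly on the target object center. The sketch below checks that round trip. It is an illustration under assumptions, not the authors' code: the function names are hypothetical, the vector angle is wrapped into [0, 2π) to match the anchor range, and `scale` (normalized length to pixels) is our own assumption.

```python
import math
from typing import Tuple

Polar = Tuple[float, float]
Point = Tuple[float, float]


def direction_vector(start: Point, end: Point) -> Polar:
    """Polar form (rho, theta) of the vector from a start cell to an object center,
    with theta wrapped into [0, 2*pi) to match the anchor angle range."""
    dx, dy = end[0] - start[0], end[1] - start[1]
    return math.hypot(dx, dy), math.atan2(dy, dx) % (2.0 * math.pi)


def label_offset(vector: Polar, anchor: Polar, scale: float = 1.0) -> Polar:
    """Eq. (4): d_rho = rho_vec / rho_anchor and d_theta = theta_vec - theta_anchor.

    `scale` converts the normalized anchor length to pixels (hypothetical parameter).
    """
    rho_v, theta_v = vector
    rho_a, theta_a = anchor
    return rho_v / (scale * rho_a), theta_v - theta_a


# Usage: encode the target for one anchor, then decode it again with Eq. (2).
anchor = (0.25, math.pi / 3)
vec = direction_vector((16.0, 24.0), (100.0, 200.0))
d_rho, d_theta = label_offset(vec, anchor, scale=640.0)
length = 640.0 * anchor[0] * d_rho
print(16.0 + length * math.cos(anchor[1] + d_theta),
      24.0 + length * math.sin(anchor[1] + d_theta))
```

Eq. (4) subtracts angles directly, so $\Delta\theta^{label}$ can be close to $2\pi$ when the anchor and the vector lie on either side of angle 0; wrapping it into $(-\pi, \pi]$ would be a natural refinement, but the paper does not state that it is used.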
We compute an L2 loss between the predicted offset $\Delta d^{pred}_{fkij}$ and the label offset $\Delta d^{label}_{fkij}$, so that the predicted direction vector matches the labeled direction vector as closely as possible. For the confidence loss $L_{conf}$, when $\mathrm{start\_position}_{fij}$ generates no direction vector, we set $c^{label}_{fij}$ to 0. When there is a direction vector at $\mathrm{start\_position}_{fij}$, we compare it with the direction anchors, keep the anchors close to the direction vector with $c^{label}_{fkij} = 1$, and set $c^{label}_{fkij} = 0$ for the other anchors. For the relation loss $L_{rel}$, note that the relation type between a subject and an object is not unique in the labels: several predicates may be annotated for the same object pair. We therefore set each relation type score $s^{label}_{fkijr}$ to 1 if relation type $r$ exists in the label of the corresponding cluster. Both $L_{rel}$ and $L_{conf}$ use the binary cross-entropy (BCE) loss. In summary, the total loss is:
$$
\begin{aligned}
Loss &= L_{conf} + L_{rel} + L_{dir} \\
&= \frac{1}{N_{conf}} \sum_{f=1}^{F}\sum_{k=1}^{K}\sum_{i=1}^{W_f}\sum_{j=1}^{H_f} \mathrm{BCE}\!\left(c^{pred}_{fkij}, c^{label}_{fkij}\right) \\
&\quad + \frac{1}{N_{rel}} \sum_{f=1}^{F}\sum_{k=1}^{K}\sum_{i=1}^{W_f}\sum_{j=1}^{H_f} c^{label}_{fkij} \sum_{r=1}^{R} \mathrm{BCE}\!\left(s^{pred}_{fkijr}, s^{label}_{fkijr}\right) \\
&\quad + \frac{1}{N_{dir}} \sum_{f=1}^{F}\sum_{k=1}^{K}\sum_{i=1}^{W_f}\sum_{j=1}^{H_f} c^{label}_{fkij} \left\| \Delta d^{pred}_{fkij} - \Delta d^{label}_{fkij} \right\|_2 .
\end{aligned}
\tag{5}
$$
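The masking in Eq. (5) can be illustrated with a small pure-Python sketch over clusters flattened across $(f, k, i, j)$. This is not the authors' training code. We assume $N_{conf}$ is the total number of clusters and $N_{rel} = N_{dir}$ is the number of clusters with $c^{label} = 1$ (the paper only says these terms produce averages), and we use the unsquared L2 norm exactly as written in Eq. (5). The function names are hypothetical.

```python
import math
from typing import Sequence, Tuple


def bce(p: float, y: float, eps: float = 1e-7) -> float:
    """Binary cross-entropy for one predicted probability p and target y in {0, 1}."""
    p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))


def ryolo_loss(c_pred: Sequence[float], c_label: Sequence[int],
               s_pred: Sequence[Sequence[float]], s_label: Sequence[Sequence[int]],
               d_pred: Sequence[Tuple[float, float]],
               d_label: Sequence[Tuple[float, float]]) -> float:
    """Eq. (5) with all clusters flattened over (f, k, i, j).

    Assumed normalizers: N_conf = number of clusters; N_rel = N_dir = number
    of clusters whose confidence label is 1.
    """
    n = len(c_pred)
    l_conf = sum(bce(p, y) for p, y in zip(c_pred, c_label)) / n
    positives = [idx for idx, y in enumerate(c_label) if y == 1]
    if not positives:
        return l_conf  # L_rel and L_dir are fully masked out
    l_rel = sum(bce(p, y) for idx in positives
                for p, y in zip(s_pred[idx], s_label[idx])) / len(positives)
    # ||pred - label||_2 on the (d_rho, d_theta) offsets
    l_dir = sum(math.dist(d_pred[idx], d_label[idx]) for idx in positives) / len(positives)
    return l_conf + l_rel + l_dir


# Usage: one positive and one negative cluster with R = 2 relation types.
print(ryolo_loss([0.9, 0.2], [1, 0],
                 [[0.8, 0.1], [0.5, 0.5]], [[1, 0], [0, 0]],
                 [(1.0, 0.1), (0.0, 0.0)], [(1.0, 0.1), (9.0, 9.0)]))
```

The negative cluster's direction label (9.0, 9.0) has no effect on the result, which is the point of multiplying the relation and direction terms by $c^{label}_{fkij}$.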
In Eq. 5, $N_{conf}$, $N_{rel}$, and $N_{dir}$ turn $L_{conf}$, $L_{rel}$, and $L_{dir}$ into averages, and $L_{rel}$ and $L_{dir}$ are counted only when $c^{label}_{fkij} = 1$. During training, the input images have variable, multi-scale sizes, similar to yolov5, and the initial learning rate is 0.04; after 100k iterations, the learning rate decays to 0.001. The model is trained end-to-end using the SGD optimizer with a batch size of 32.
3.2 Scene Graph Generation
Similar Relationship Suppression. RYOLO outputs many relationships, each with a start position, an end position, and a relation type. Because we design multiple direction anchors across multi-scale feature maps and multiple directions, different direction anchors may predict the same relationship. We therefore propose a similar relationship suppression (SRS) method for SGG. First, we keep only the relationships whose confidence exceeds a threshold $t^c$, i.e., $c^{pred}_{fkij} > t^c$. Then, taking the relationships in descending order of $c^{pred}$, we compare each one with the remaining relationships, and SRS suppresses the similar ones. The suppression condition is:
$$
\begin{cases}
\left\| \mathrm{start\_position}_{fkij} - \mathrm{start\_position}_{f'k'i'j'} \right\| < t^d \\
\left\| \mathrm{end\_position}_{fkij} - \mathrm{end\_position}_{f'k'i'j'} \right\| < t^d \\
\mathrm{relation\_type}_{fkij} = \mathrm{relation\_type}_{f'k'i'j'}
\end{cases}
\tag{6}
$$
Table 1. R@K and ng-R@K evaluation results on the VG-150 dataset. G1, G2, and G3 stand for groups 1, 2, and 3. The SGG works in group 1 draw on external knowledge, and those in group 2 draw on statistical context information. The works in group 3 use only visual images for SGG. Our work falls into group 3.

| Group | Method | mAP@50 | FPS | PredCls R@20/50/100 | SGCls R@20/50/100 | SGDet R@20/50/100 |
|---|---|---|---|---|---|---|
| G1 | GB-Net-β | – | 1.92 | –/66.6/68.2 | –/37.3/38.0 | –/26.3/29.9 |
| G1 | UVTransE | 23.8 | – | –/65.3/67.3 | –/35.9/36.6 | –/30.1/33.6 |
| G2 | KERN | – | 1.27 | –/65.8/67.6 | –/36.7/37.4 | –/27.1/29.8 |
| G2 | GPS-Net | – | – | 67.6/69.7/69.7 | 41.8/42.3/42.3 | 22.3/28.9/33.2 |
| G2 | MOTIFS-TDE | – | 1.15 | 33.6/46.2/51.4 | 21.7/27.7/29.9 | 12.4/16.9/20.3 |
| G2 | VCTree | – | 0.75 | 60.1/66.4/68.1 | 35.2/38.1/38.8 | 22.0/27.9/31.3 |
| G3 | IMP | 20.0 | – | 58.5/65.2/67.1 | 31.7/34.6/35.4 | 14.6/20.7/24.5 |
| G3 | Graph RCNN | 23.0 | 5.26 | –/54.2/59.1 | –/29.6/31.6 | –/11.4/13.7 |
| G3 | FCSGG | 25.0 | 12.50 | 24.2/31.0/34.6 | 13.6/17.1/18.8 | 11.5/15.5/18.4 |
| Ours | RYOLOs + yolov5s | 20.2 | 35.71 | 22.9/30.9/32.6 | 8.8/13.6/15.8 | 7.6/11.9/14.7 |
| Ours | RYOLOs + yolov5l | 26.2 | 25.82 | 22.9/30.9/32.6 | 10.8/16.4/19.0 | 8.7/13.6/16.9 |
| Ours | RYOLOl + yolov5s | 20.2 | 22.42 | 23.4/32.1/33.9 | 8.9/13.8/16.1 | 7.6/12.0/15.0 |
| Ours | RYOLOl + yolov5l | 26.2 | 17.59 | 23.4/32.1/33.9 | 11.0/16.9/19.7 | 8.7/13.7/17.2 |

| Group | Method | mAP@50 | FPS | PredCls ng-R@20/50/100 | SGCls ng-R@20/50/100 | SGDet ng-R@20/50/100 |
|---|---|---|---|---|---|---|
| G1 | GB-Net-β | – | 1.92 | –/83.5/90.3 | –/46.9/50.3 | –/29.3/35.0 |
| G2 | KERN | – | 1.27 | –/81.9/88.9 | –/45.9/49.0 | –/30.9/35.8 |
| G2 | LSBR | – | – | 77.9/82.5/90.2 | 43.6/46.2/50.2 | 26.9/31.4/36.5 |
| G3 | Pixels2Graphs | – | 0.28 | –/68.0/75.2 | –/26.5/30.0 | –/9.7/11.3 |
| G3 | FCSGG | 25.0 | 12.50 | 28.1/40.3/50.0 | 14.2/19.6/24.0 | 12.7/18.3/23.0 |
| Ours | RYOLOs + yolov5s | 20.2 | 35.71 | 29.1/42.1/50.8 | 9.3/13.6/17.2 | 10.0/14.7/18.4 |
| Ours | RYOLOs + yolov5l | 26.2 | 25.82 | 29.1/42.1/50.8 | 11.7/17.4/21.7 | 12.0/17.7/21.9 |
| Ours | RYOLOl + yolov5s | 20.2 | 22.42 | 29.9/43.3/52.1 | 9.5/13.9/17.4 | 10.2/15.0/18.7 |
| Ours | RYOLOl + yolov5l | 26.2 | 17.59 | 29.9/43.3/52.1 | 12.0/17.8/22.0 | 12.2/18.0/22.3 |
In Eq. 6, $f'k'i'j' \neq fkij$ indexes another detected relationship, and $t^d$ is a preset distance threshold. If two relationships satisfy all three suppression conditions at the same time, the lower-confidence one is suppressed, so that only the highest-confidence relationship is retained.

Object Matching. To generate a scene graph, we employ the results of object detection; in other words, we match relationships to the objects in the image. The start position and the end position each query the center coordinates of all detected objects and are matched to the nearest object, which becomes the subject or the object of the triple, as shown in Fig. 1(C). During object matching, a simple distance threshold eliminates failed matches between a start or end position and the object centers. The start position and the end position cannot be matched to the same object.
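The two post-processing steps above can be sketched as follows. This is a minimal Python illustration under stated assumptions, not the authors' code: the distances in Eq. (6) and in matching are taken to be Euclidean, the class and function names and the matching threshold `t_match` are hypothetical, and a relationship whose start and end positions match the same object is simply discarded (the text only says such a match is not allowed). The loop over object centers stands in for the batch minimum over a distance matrix that the paper describes.

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

Point = Tuple[float, float]


@dataclass
class Relationship:
    start: Point    # start position on the image
    end: Point      # end position on the image
    rel_type: int   # index of the predicted relation type
    conf: float     # predicted confidence c^pred


def suppress_similar(rels: Sequence[Relationship], t_c: float, t_d: float) -> List[Relationship]:
    """Similar Relationship Suppression with Euclidean distances in Eq. (6).

    Relationships above t_c are visited in descending confidence; one is dropped
    if it meets all three conditions of Eq. (6) with an already kept relationship.
    """
    kept: List[Relationship] = []
    candidates = sorted((r for r in rels if r.conf > t_c), key=lambda r: r.conf, reverse=True)
    for r in candidates:
        is_similar = any(math.dist(r.start, k.start) < t_d
                         and math.dist(r.end, k.end) < t_d
                         and r.rel_type == k.rel_type for k in kept)
        if not is_similar:
            kept.append(r)
    return kept


def nearest_object(point: Point, centers: Sequence[Point], t_match: float) -> Optional[int]:
    """Index of the closest object center within t_match, or None if matching fails."""
    best, best_dist = None, t_match
    for idx, center in enumerate(centers):
        dist = math.dist(point, center)
        if dist < best_dist:
            best, best_dist = idx, dist
    return best


def match_objects(rels: Sequence[Relationship], centers: Sequence[Point],
                  t_match: float) -> List[Tuple[int, int, int]]:
    """Build (subject_index, relation_type, object_index) triples by nearest-center matching."""
    triples = []
    for r in rels:
        subj = nearest_object(r.start, centers, t_match)
        obj = nearest_object(r.end, centers, t_match)
        # Drop failed matches and relationships whose two ends hit the same object.
        if subj is None or obj is None or subj == obj:
            continue
        triples.append((subj, r.rel_type, obj))
    return triples
```

Visiting candidates in descending confidence and comparing each one only with the relationships already kept gives the same result as the greedy "keep the best, suppress its look-alikes, repeat" formulation familiar from non-maximum suppression.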
Table 2. mR@K, zsR@K, ng-mR@K, and ng-zsR@K evaluation results on the VG-150 dataset.

| Method | PredCls mR@50/100 | PredCls zsR@50/100 | SGCls mR@50/100 | SGCls zsR@50/100 | SGDet mR@50/100 | SGDet zsR@50/100 |
|---|---|---|---|---|---|---|
| GB-NET-β | 22.1/24.0 | – | 12.7/13.4 | – | 7.1/8.5 | – |
| KERN | 17.7/19.4 | – | 9.4/10.0 | – | 6.4/7.3 | – |
| VCTree | 17.9/19.4 | – | 12.7/13.4 | – | 7.1/8.5 | – |
| CogTree | 28.4/31.0 | – | 15.7/16.7 | – | 11.1/12.7 | – |
| MOTIFS-TDE | 25.5/29.1 | 14.4/18.2 | 13.1/14.9 | 3.4/4.5 | 8.2/9.8 | 2.3/2.9 |
| VCTree-TDE | 25.4/28.4 | 14.3/17.6 | 12.2/14.0 | 3.2/4.0 | 9.3/11.1 | 2.6/3.2 |
| FCSGG | 5.2/6.1 | 8.6/10.9 | 2.9/3.4 | 1.7/2.1 | 2.6/3.1 | 1.0/1.4 |
| RYOLOs + yolov5s | 5.3/5.9 | 6.9/7.4 | 2.3/2.9 | 1.2/1.5 | 2.0/2.6 | 0.8/0.9 |
| RYOLOs + yolov5l | 5.3/5.9 | 6.9/7.4 | 2.8/3.5 | 1.7/2.1 | 2.3/2.9 | 1.1/1.4 |
| RYOLOl + yolov5s | 5.7/6.3 | 7.3/7.9 | 2.4/3.1 | 1.2/1.4 | 2.1/2.7 | 0.7/0.9 |
| RYOLOl + yolov5l | 5.7/6.3 | 7.3/7.9 | 2.9/3.7 | 1.6/2.0 | 2.4/3.1 | 1.2/1.5 |

| Method | PredCls ng-mR@50/100 | PredCls ng-zsR@50/100 | SGCls ng-mR@50/100 | SGCls ng-zsR@50/100 | SGDet ng-mR@50/100 | SGDet ng-zsR@50/100 |
|---|---|---|---|---|---|---|
| FCSGG | 9.5/14.7 | 12.8/19.6 | 6.3/9.4 | 2.9/4.4 | 4.7/6.9 | 1.8/2.7 |
| RYOLOs + yolov5s | 9.7/15.4 | 12.2/19.2 | 3.9/6.0 | 1.6/2.7 | 4.5/6.8 | 1.6/2.8 |
| RYOLOs + yolov5l | 9.7/15.4 | 12.2/19.2 | 5.2/7.8 | 2.6/4.8 | 5.4/8.2 | 2.7/4.3 |
| RYOLOl + yolov5s | 10.2/16.1 | 12.5/19.6 | 4.1/6.2 | 1.5/2.5 | 4.6/7.1 | 1.6/2.6 |
| RYOLOl + yolov5l | 10.2/16.1 | 12.5/19.6 | 5.3/8.0 | 2.4/4.1 | 5.5/8.4 | 2.5/4.1 |

4 Experiment
4.1 Dataset, Model and Metrics
Dataset. We train and evaluate our models on the public VG-150 dataset [23]. VG-150 contains the 150 most frequent object categories and the 50 most frequent predicate categories of the Visual Genome dataset [26].
Model. We decouple relationship detection and object detection for SGG. For object detection, we train two independent models, yolov5s and yolov5l, with different backbones [12]. Similarly, for relationship detection, we use our RYOLO method with two backbone networks, denoted RYOLOs and RYOLOl. Both yolov5 and RYOLO are trained on the VG-150 dataset. We show how object detection and relationship detection models of different performance affect SGG.
Metrics. We evaluate our method on three standard SGG tasks: Predicate Classification (PredCls), Scene Graph Classification (SGCls), and Scene Graph Detection (SGDet). PredCls requires only our relationship detection, since object information is taken from the labels, whereas SGCls and SGDet require the results of object detection. The conventional SGG metric is Recall@K (R@K) [17]. Since predicates are not exclusive, Pixels2Graphs [25] proposes No Graph Constraint Recall@K (ng-R@K). Mean Recall@K (mR@K) [18] and No Graph Constraint Mean Recall@K (ng-mR@K) reduce the influence of high-frequency predicates. To verify the generalization of SGG, Zero-Shot Recall@K (zsR@K) [17] and No Graph Constraint Zero-Shot Recall@K (ng-zsR@K) [23] count triplet relationships that do not occur in the training set. In addition, we report the object detection metric mAP@50 [28] for reference.
4.2 Results and Discussion
Results Analysis. The results of our work are shown in Table 1 and Table 2. Because we remove the dependence of relationship detection on object detection found in traditional serial SGG, our relationship detection method RYOLO can handle the PredCls task on its own. Traditional SGG methods extract the features of specific objects based on the ground-truth bounding boxes and categories, so their predicted relationships are more accurate. RYOLO predicts relationships from the whole image regardless of whether the object bounding boxes are known; the ground-truth bounding boxes and categories are used only in the object matching process. We add an independent object detector, yolov5, so that RYOLO can complete the SGCls and SGDet tasks. The results show that our method does not bring accuracy improvements on the SGG tasks. For a fair comparison with previous works, we divide them into three groups: those using external knowledge, those using statistical context information, and those using only visual images. Our method is competitive on the SGDet task compared with the methods in G3. Unlike previous methods, our method achieves similar R@K and mR@K on the SGCls and SGDet tasks. The main reason is that previous methods depend heavily on object detection results, so inaccurate detections degrade their performance from SGCls to SGDet, whereas our method is independent and detects relationships from the whole image; inaccurate detections only slightly affect object matching. In the No Graph Constraint setting, each object pair can predict multiple possible relation types, and, like other methods, RYOLO recalls more relationships in this setting. The zsR@K and ng-zsR@K metrics judge whether an SGG method can predict triple relationships that are unseen in training. They are not common in previous SGG evaluations, so few reference results are available. Since our SGG method is independent, it is not restricted by object pairs. The ng-zsR@K results show that our method can predict more unseen triple relationships than the recent fully convolutional scene graph generation method FCSGG. In addition, SGDet appears to perform better than SGCls under the No Graph Constraint setting. This is because object matching in SGDet uses the detected object positions before Non-Maximum Suppression, rather than the ground-truth object bounding boxes used in SGCls.
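The difference between graph-constrained R@K and ng-R@K discussed above can be made concrete with a simplified sketch. It assumes objects are already identified by index, as in PredCls, and omits the box-overlap matching that SGCls and SGDet evaluation needs; the function name and the toy predicates and scores are hypothetical.

```python
from typing import Dict, List, Set, Tuple

Triple = Tuple[int, str, int]  # (subject index, predicate, object index)


def recall_at_k(scores: Dict[Tuple[int, int], Dict[str, float]],
                gt: Set[Triple], k: int, graph_constraint: bool = True) -> float:
    """Fraction of ground-truth triples found among the top-k predicted triples.

    scores maps each (subject, object) pair to its predicate scores.
    """
    candidates: List[Tuple[float, Triple]] = []
    for (subj, obj), pred_scores in scores.items():
        if graph_constraint:
            # Graph constraint: only the best predicate of each object pair competes.
            best = max(pred_scores, key=pred_scores.get)
            candidates.append((pred_scores[best], (subj, best, obj)))
        else:
            # No graph constraint: every predicate of every pair is a candidate.
            candidates.extend((s, (subj, p, obj)) for p, s in pred_scores.items())
    candidates.sort(key=lambda c: c[0], reverse=True)
    top_k = {triple for _, triple in candidates[:k]}
    return len(top_k & gt) / len(gt)


scores = {(0, 1): {"on": 0.6, "riding": 0.5}, (1, 2): {"near": 0.4}}
gt = {(0, "riding", 1), (1, "near", 2)}
print(recall_at_k(scores, gt, k=2), recall_at_k(scores, gt, k=3, graph_constraint=False))
# prints: 0.5 1.0
```

With the graph constraint, the correct but second-ranked predicate "riding" is never considered; without it, the pair may contribute several predicates, which is why non-exclusive predicates favor the ng-variants.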
Fig. 3. Visualized scene graphs produced by combining OWOD [27] with our RYOLOs. In these examples, the color of each bounding box matches the color of its text label, and each object pair shows only its highest-confidence relation.
Advantages. The advantages of our parallel SGG paradigm lie in its inference speed and flexibility. In terms of inference speed, our method has an obvious advantage. To compare speed with previous works under the same GPU conditions [21], we also run our method on an NVIDIA GeForce GTX 1080 Ti GPU with batch size 1. Using the combination of RYOLOs and yolov5s, our method achieves real-time scene graph generation (FPS > 30). Compared with the fast SGG method FCSGG (HRNetW48-1s) [21], our method improves inference speed by nearly three times with only a limited loss of accuracy (Table 1). This is because we adopt a lightweight, fully convolutional backbone network from yolov5 that requires less computation, and because our method supports the parallel structure, detecting objects and relationships simultaneously in multiple processes. In addition, traditional methods that predict relationships per object pair have quadratic time complexity in the number of objects, whereas RYOLO detects all relationships at once. Object matching is a batch operation that finds the minimum positions in a distance matrix, so it remains fast. However, on a single GPU, the parallel structure cannot fully show its advantages: running yolov5 and RYOLO in parallel is only about 5% faster than running them in series (see the inference speeds in Table 1), whereas it is nearly 30% faster with multiple GPUs in our experiments. In terms of flexibility, thanks to the decoupling of object detection and relationship detection, RYOLO can easily cooperate with any object detection model to generate scene graphs. A better object detection model reduces false and missed detections and improves SGG accuracy. In Table 1, we simply swap different yolov5 models without retraining the relationship detection model. We believe this independent design has broader practical value.
In addition, we replace yolov5 with the open-world object detector OWOD [27] for open-world scene graph generation. As shown in Fig. 3, OWOD can detect unknown objects and label them as unknown. This resembles human cognition: people may not recognize a new object but can still analyze its relationships with other objects. Combining OWOD and RYOLO, we can generate novel relationships that involve unknown objects, as shown in Fig. 3. Based on these relationships, a computer can infer unknown attributes through knowledge; for example, an unknown object on the diningtable may be tableware.
Limitations. Our method trades accuracy for speed and flexibility. As shown in Fig. 3, some erroneous relationships remain. Because our parallel paradigm detects relationships directly from the whole image, details of the objects themselves are ignored. Although relationship detection and object detection are independent, SGG relies on the consistency of the object center positions they produce. The current object matching considers only the location information from object detection and relationship detection, without contextual content, so it cannot avoid some false matches. For example, a man and a shirt may fall on the same grid cell, and nearest-object matching may then produce a wrong triple instead of the correct one. Finally, for long-distance directions predicted in polar coordinates, a slight angular shift causes a large error in the end position.
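A short worked estimate (with illustrative numbers that are not taken from the experiments) shows why the angular error matters more for long relationships. If the predicted length $\rho$ is correct but the predicted angle deviates from the true direction by $\delta$, the end position moves along a chord of the circle of radius $\rho$ around the start position:
$$\left\| \Delta\,\mathrm{end\_position} \right\| = 2\rho \sin\frac{\delta}{2} \approx \rho\,\delta .$$
For instance, with $\delta = 0.05$ rad (about 2.9°), a 400-pixel relationship has its end position displaced by about 20 pixels, while a 40-pixel relationship is displaced by only about 2 pixels. Since object matching picks the nearest object center, the long relationship is far more likely to be matched to the wrong object.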
5 Conclusion
In this paper, we rethink SGG and introduce a parallel SGG paradigm with an independent visual relationship detection method, RYOLO. In RYOLO, we design direction anchors to predict relationships directly from the image without relying on object detection results. For SGG, object detection and relationship detection results are correlated through object matching rules to generate triples and the scene graph. In this way, we decouple object detection and relationship detection and realize real-time SGG. We hope our method can serve as a new baseline for real-time scene graph generation. In the future, we will consider incorporating knowledge and statistical context information to improve the performance of real-time SGG.
Acknowledgement. The research was supported by the National Natural Science Foundation of China (Grant No. U21A20488), the '10000 Talents Plan' of Zhejiang Province (Grant No. 2019R51010), and the Lab-initiated Research Project of Zhejiang Lab (No. G2021NB0AL03).
References
1. Gu, J., Joty, S., Cai, J., Zhao, H., Yang, X., Wang, G.: Unpaired image captioning via scene graph alignments. In: CVPR (2019)
2. Hudson, D.A., Manning, C.D.: Learning by abstraction: the neural state machine. In: NIPS (2019)
3. Wan, H., Luo, Y., Peng, B., Zheng, W.: Representation learning for scene graph completion via jointly structural and visual embedding. In: IJCAI (2018)
4. Yang, J., Lu, J., Lee, S., Batra, D., Parikh, D.: Graph R-CNN for scene graph generation. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11205, pp. 690–706. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01246-5_41
5. Zellers, R., Yatskar, M., Thomson, S., Choi, Y.: Neural motifs: scene graph parsing with global context. In: CVPR (2018)
6. Zareian, A., Karaman, S., Chang, S.-F.: Bridging knowledge graphs to generate scene graphs. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12368, pp. 606–623. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58592-1_36
7. Zhang, J., Kalantidis, Y., Rohrbach, M., Paluri, M., Elgammal, A., Elhoseiny, M.: Large-scale visual relationship understanding. In: AAAI (2019)
8. Lin, X., Li, Y., Liu, C., Ji, Y., Yang, J.: Scene graph generation based on node-relation context module. In: Cheng, L., Leung, A.C.S., Ozawa, S. (eds.) ICONIP 2018. LNCS, vol. 11302, pp. 134–145. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-04179-3_12
9. Ren, S., He, K., Girshick, R., Sun, J.: Faster R-CNN: towards real-time object detection with region proposal networks. In: NIPS (2015)
10. He, K., Gkioxari, G., Dollár, P., Girshick, R.: Mask R-CNN. In: ICCV (2017)
11. Gkanatsios, N., Pitsikalis, V., Koutras, P., Maragos, P.: Attention-translation-relation network for scalable scene graph generation. In: ICCV Workshops (2019)
12. Glenn-Jocher, et al.: yolov5 (2021). https://github.com/ultralytics/yolov5
13. Redmon, J., Divvala, S., Girshick, R., Farhadi, A.: You only look once: unified, real-time object detection. In: CVPR (2016)
14. Lin, X., Ding, C., Zeng, J., Tao, D.: GPS-Net: graph property sensing network for scene graph generation. In: CVPR (2020)
15. Hung, Z., Mallya, A., Lazebnik, S.: Contextual translation embedding for visual relationship detection and scene graph generation. IEEE Trans. Pattern Anal. Mach. Intell. 43(11), 3820–3832 (2020)
16. Gu, J., Zhao, H., Lin, Z., Li, S., Cai, J., Ling, M.: Scene graph generation with external knowledge and image reconstruction. In: CVPR (2019)
17. Lu, C., Krishna, R., Bernstein, M., Fei-Fei, L.: Visual relationship detection with language priors. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9905, pp. 852–869. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46448-0_51
18. Tang, K., Zhang, H., Wu, B., Luo, W., Liu, W.: Learning to compose dynamic tree structures for visual contexts. In: CVPR (2019)
19. Chen, T., Yu, W., Chen, R., Lin, L.: Knowledge-embedded routing network for scene graph generation. In: CVPR (2019)
20. Xu, D., Zhu, Y., Choy, C.B., Fei-Fei, L.: Scene graph generation by iterative message passing. In: CVPR (2017)
21. Liu, H., Yan, N., Mortazavi, M., Bhanu, B.: Fully convolutional scene graph generation. In: CVPR (2021)
22. Yu, J., Chai, Y., Wang, Y., Hu, Y., Wu, Q.: CogTree: cognition tree loss for unbiased scene graph generation. In: IJCAI (2021)
23. Tang, K., Niu, Y., Huang, J., Shi, J., Zhang, H.: Unbiased scene graph generation from biased training. In: CVPR (2020)
24. Yang, G., Zhang, J., Zhang, Y., Wu, B., Yang, Y.: Probabilistic modeling of semantic ambiguity for scene graph generation. In: CVPR (2021)
25. Newell, A., Deng, J.: Pixels to graphs by associative embedding. In: NIPS (2017)
26. Krishna, R., et al.: Visual genome: connecting language and vision using crowdsourced dense image annotations. Int. J. Comput. Vis. 123(1), 32–73 (2017). https://doi.org/10.1007/s11263-016-0981-7
118
T. Jin et al.
27. Joseph, K.J., Khan, S., Khan, F., Balasubramanian, V.: Towards open world object detection. In: CVPR (2021) 28. Everingham, M., Eslami, S.M.A., Van Gool, L., Williams, C.K.I., Winn, J., Zisserman, A.: The Pascal visual object classes challenge: a retrospective. Int. J. Comput. Vis. 111(1), 98–136 (2014). https://doi.org/10.1007/s11263-014-0733-5
A Multi-label Feature Selection Method Based on Feature Graph with Ridge Regression and Eigenvector Centrality
Zhiwei Ye, Haichao Zhang (corresponding author), Mingwei Wang, and Qiyi He
School of Computer Science, Hubei University of Technology, Wuhan, China
{20HaichaoZhang,wmwscola,qiyi.he}@hbut.edu.cn
Abstract. In multi-label learning, instances with multiple semantic labels suffer from the impact of high feature dimensionality. The goal of multi-label feature selection is to process high-dimensional multi-label data while keeping the relevant information in the original data. However, many existing feature ranking-based multi-label feature selection methods do not take the relationships between features into account. Feature graph-based methods can represent and exploit the associations between features, but they still have shortcomings: the correlation between features and labels is not explored sufficiently, and the associations between features are not used efficiently to evaluate features. In this paper, a multi-label feature selection method based on a feature graph with ridge regression and eigenvector centrality is proposed. Ridge regression is used to learn a valid representation of feature-label correlation. The learned correlation representation is mapped to a graph so that feature relationships can be represented and used efficiently. Eigenvector centrality is used to evaluate the nodes in the graph to obtain scores for the features. The effectiveness of the proposed method is verified with three evaluation metrics (Ranking loss, Average precision, and Macro-F1) on four datasets by comparison with seven state-of-the-art multi-label feature selection methods.
Keywords: multi-label learning · feature selection · feature graph · ridge regression · eigenvector centrality

1 Introduction
Multi-label learning has a variety of real-world applications, such as protein sequence function classification [1], video classification [2], and image recognition [3]. However, with the development of big data and information collection technology, the feature space of multi-label data keeps growing: the few relevant features are swamped by large numbers of weakly discriminative ones, which imposes heavy storage requirements on multi-label learning and may degrade learning performance. Feature selection is a dimensionality reduction technique that aims to eliminate redundant and irrelevant features in machine learning and data mining; its goal is to obtain a representative feature subset from the original data [4]. It is therefore necessary to select discriminative features for multi-label data.

This work is supported by the National Natural Science Foundation of China under Grant No. 61502155, the National Natural Science Foundation of China under Grant No. 61772180, and the Fujian Provincial Key Laboratory of Data Intensive Computing and Key Laboratory of Intelligent Computing and Information Processing, Fujian: BD201801.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 119–129, 2023. https://doi.org/10.1007/978-981-99-1639-9_10

Multi-label feature selection is an extension of traditional single-label feature selection, which can be divided into wrapper, embedded, and filter methods [5]. Compared with the other two, filter methods are independent of the training of a learning algorithm and use information about the internal structure of the data to evaluate features, which makes them more flexible and efficient to run [6–8]. Depending on how they manipulate the data, filter methods are further divided into problem transformation and algorithm adaptation methods. Problem transformation converts multi-label data into single-label data and then applies a traditional single-label feature selection method, as in ELA-CHI [9], PPT-CHI [10], and PPT-MI [11]. However, these transformations may introduce noise and lose instance information. Although algorithm adaptation methods based on information theory consider the redundancy between features, they ignore other relationships between features [12,13]. Graphs are a common topology for representing connections and relationships between objects and have been used in algorithm adaptation-based methods [7,8,14,15]. In feature graph-based methods, the nodes of the graph are features, and the whole graph shows the connections between them. An important part of multi-label feature selection is mining the feature-label correlation for subsequent operations.
When dealing with multi-label data, the feature matrix and the label matrix need to be processed globally to fully exploit the potential feature-label correlation between them. However, existing feature graph-based methods slice these two matrices and calculate the correlations with Pearson coefficients or mutual information, so the global information is neglected. To tackle these issues, a multi-label feature selection method based on a feature graph with ridge regression and eigenvector centrality (RRECFS) is proposed. Ridge regression, a linear estimator, efficiently mines the correlation coefficients between explanatory and response variables [16]. Here it is applied to learn the correlation coefficients between the feature and label matrices without slicing them, which preserves the global correlation between features and labels. Eigenvector centrality [17,18] is then employed to evaluate the importance of the features in the graph and obtain a feature ranking. Eigenvector centrality takes into account that the importance of a node is influenced by the importance of its neighbors: not all nodes are equivalent, and a high-centrality neighbor contributes more to a node's score than a low-centrality one. Nodes that have a significant influence on the graph can therefore be selected in this way. The main contributions of this paper are as follows:
1. Ridge regression is used to mine feature-label correlation globally, addressing the insufficient exploitation of this correlation in existing feature graph methods.
2. Eigenvector centrality is used to further exploit the potential relationships between features when evaluating them.
The remainder of this paper is organized as follows. The proposed method is described in Sect. 2. Section 3 presents the experimental results and discussion. Finally, the paper is concluded in Sect. 4.
2 The Proposed Method
In this section, the feature-label correlation is derived from the learning process of ridge regression. The acquired correlation representation is then mapped to a feature graph with a weighted affinity matrix. Finally, an efficient measure for evaluating graph nodes, eigenvector centrality, assigns importance scores to all features in the graph to accomplish the feature ranking.

2.1 Exploring Feature-Label Correlation via Ridge Regression
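For multi-label data with n instances, m features, and q labels, the feature matrix $X \in \mathbb{R}^{n \times m}$ and the label matrix $Y \in \{0, 1\}^{n \times q}$ have the following structure:
$$
X = \begin{bmatrix} X_{11} & X_{12} & \cdots & X_{1m} \\ X_{21} & X_{22} & \cdots & X_{2m} \\ \vdots & \vdots & \ddots & \vdots \\ X_{n1} & X_{n2} & \cdots & X_{nm} \end{bmatrix}, \qquad
Y = \begin{bmatrix} Y_{11} & Y_{12} & \cdots & Y_{1q} \\ Y_{21} & Y_{22} & \cdots & Y_{2q} \\ \vdots & \vdots & \ddots & \vdots \\ Y_{n1} & Y_{n2} & \cdots & Y_{nq} \end{bmatrix} \tag{1}
$$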
Ridge regression is least-squares estimation with an added $\ell_2$ regularization term. Ordinary linear regression can overfit severely in the presence of anomalous data, and multicollinearity among features can make its estimated parameters unreliable. Ridge regression, which mitigates the consequences of overfitting and multicollinearity, is therefore used to capture the feature-label correlation and is defined as:
$$
\hat{W} = \arg\min_{W} \|XW - Y\|_F^2 + \lambda \|W\|_F^2 = \left(X^\top X + \lambda I\right)^{-1} X^\top Y \tag{2}
$$
Fig. 1. An example of learned feature label correlation
In (2), $I \in \mathbb{R}^{m \times m}$ is the identity matrix and $\lambda > 0$ is a regularization parameter. As the simple example in Fig. 1 shows, $W \in \mathbb{R}^{m \times q}$ is the regression coefficient matrix whose element $W_{ij}$ represents the correlation between the i-th feature in X and the j-th label in Y: the larger $W_{ij}$, the stronger the dependence between feature i and label j. The framework is then constructed on the feature-label correlation matrix (FLCM):
$$
\mathrm{FLCM} = \begin{bmatrix} W_{11} & W_{12} & \cdots & W_{1q} \\ W_{21} & W_{22} & \cdots & W_{2q} \\ \vdots & \vdots & \ddots & \vdots \\ W_{m1} & W_{m2} & \cdots & W_{mq} \end{bmatrix} \tag{3}
$$
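A minimal NumPy sketch of Eq. (2) follows. It illustrates the closed form and is not the authors' code; the function name `ridge_flcm` and the toy data are hypothetical, and λ = 10 mirrors the value reported in Sect. 3.2. The sketch solves the linear system $(X^\top X + \lambda I)W = X^\top Y$ instead of forming the inverse explicitly. This yields the same W and is numerically preferable.

```python
import numpy as np


def ridge_flcm(X: np.ndarray, Y: np.ndarray, lam: float = 10.0) -> np.ndarray:
    """Ridge coefficients W = (X^T X + lam * I)^{-1} X^T Y, used as the FLCM.

    X: (n, m) feature matrix; Y: (n, q) binary label matrix; lam: lambda > 0.
    Returns an (m, q) matrix whose entry (i, j) relates feature i to label j.
    """
    m = X.shape[1]
    A = X.T @ X + lam * np.eye(m)          # (m, m), positive definite when lam > 0
    return np.linalg.solve(A, X.T @ Y)     # solve A W = X^T Y without inverting A


# Toy example (hypothetical data): 50 instances, 6 features, 3 labels.
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 6))
Y = (rng.random((50, 3)) < 0.3).astype(float)
FLCM = ridge_flcm(X, Y, lam=10.0)
print(FLCM.shape)  # (6, 3)
```

Setting the gradient of the objective in Eq. (2) to zero gives $X^\top(XW - Y) + \lambda W = 0$, which is exactly the system solved above.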
Fig. 2. Feature graph built by features using FDM (example graph with feature nodes F1 to F6)
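To compare features, a distance metric is needed. The Euclidean distance between feature i and feature j is:
$$
\mathrm{FD}_{ij} = \sqrt{\sum_{k=1}^{q} \left(\mathrm{FLCM}(i,k) - \mathrm{FLCM}(j,k)\right)^2} = \left\| \mathrm{FLCM}(i,:) - \mathrm{FLCM}(j,:) \right\|_2 \tag{4}
$$
Computing the Euclidean distance between every pair of rows of FLCM gives the feature distance matrix (FDM):
$$
\mathrm{FDM} = \begin{bmatrix} \mathrm{FD}_{11} & \mathrm{FD}_{12} & \cdots & \mathrm{FD}_{1m} \\ \mathrm{FD}_{21} & \mathrm{FD}_{22} & \cdots & \mathrm{FD}_{2m} \\ \vdots & \vdots & \ddots & \vdots \\ \mathrm{FD}_{m1} & \mathrm{FD}_{m2} & \cdots & \mathrm{FD}_{mm} \end{bmatrix} \tag{5}
$$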
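Eqs. (4) and (5) can be computed for all feature pairs at once. The NumPy sketch below is an illustration and not the authors' implementation; the function name and the toy FLCM are hypothetical. It uses the identity $\|a - b\|^2 = \|a\|^2 + \|b\|^2 - 2a^\top b$, which avoids materializing an m × m × q array of differences when m is large. `scipy.spatial.distance.cdist(flcm, flcm)` is an equivalent library alternative.

```python
import numpy as np


def feature_distance_matrix(flcm: np.ndarray) -> np.ndarray:
    """Eqs. (4)-(5): FD_ij = Euclidean distance between rows i and j of FLCM (m x q)."""
    sq = np.sum(flcm ** 2, axis=1)                     # squared row norms, shape (m,)
    d2 = sq[:, None] + sq[None, :] - 2.0 * flcm @ flcm.T
    d2 = np.maximum(d2, 0.0)                           # clip tiny negatives from rounding
    np.fill_diagonal(d2, 0.0)                          # a feature is at distance 0 from itself
    return np.sqrt(d2)                                 # (m, m), symmetric


# Hypothetical FLCM with 3 features and 2 labels, chosen so distances are easy to check.
flcm = np.array([[0.0, 0.0],
                 [3.0, 4.0],
                 [6.0, 8.0]])
FDM = feature_distance_matrix(flcm)
# Expected: FD_12 = 5, FD_13 = 10, FD_23 = 5, with zeros on the diagonal.
```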
Algorithm 1. The pseudo-code for RRECFS
Input: training multi-label dataset $D = [X, Y]$
Output: ranked features $F_{rank}$
1: $F_{rank} = \emptyset$
2: Calculate $W = (X^\top X + \lambda I)^{-1} X^\top Y$
3: Construct the feature-label correlation matrix FLCM from $W$
4: Calculate the weighted adjacency matrix FDM from FLCM using the Euclidean distance
5: Build the weighted feature graph based on FDM
6: Assign eigenvector centrality to each node according to $\alpha c = \mathrm{FDM}\, c$, taking the eigenvector of the largest eigenvalue
7: Sort $c$ in descending order
8: $F_{rank}$ = feature indices in the sorted order of $c$
FDM is used to build the graph shown in Fig. 2: features are the nodes, and the elements of FDM are the weights of the edges between nodes. The graph is denoted as G(F, E), where F is the set of feature nodes and E is the set of edges between them.

2.2 Feature Ranking with Eigenvector Centrality
With the weighted graph constructed, a measure is needed to evaluate the importance of its nodes. Eigenvector centrality is an important metric for assessing node importance in social network analysis. It is based on the principle that the high-scoring neighbors of an important node contribute more than other nodes. Suppose G(V, E) is the graph of a social network with node set V and edge set E, and A is its adjacency matrix. The centrality of a node is obtained by summing the centralities of its neighbors, weighted by the corresponding entries of A:
$$
\alpha c_i = \sum_{j=1}^{n} A_{ij} c_j, \quad i = 1, \ldots, n \tag{6}
$$
Equation (6) can be written in matrix form as
$$
\alpha c = A c \tag{7}
$$
where $c_i$ and $c_j$ are the centralities of node i and its neighbor j, respectively, α is an eigenvalue of A, n is the number of nodes, and c is the eigenvector corresponding to α, i.e., the eigenvector centrality being sought. The calculation is thus transformed into finding the eigenvector corresponding to the largest eigenvalue of A. For a connected graph with nonnegative edge weights, the Perron–Frobenius theorem guarantees that this eigenvector can be chosen with all entries positive, so it yields valid scores. In the feature graph, the higher the eigenvector centrality of a feature node, the more important the feature and the higher its discriminative ability. Under this measure, the importance of a feature depends on the importance of its neighboring features. Since the graph is fully connected, with undirected edges between every pair of nodes, the neighbors of a feature node are all the remaining feature nodes in the graph. Eigenvector centrality for the feature graph is defined as:
$$
\alpha c = \mathrm{FDM}\, c \tag{8}
$$
where α is an eigenvalue of FDM. The eigenvector corresponding to the largest eigenvalue $\alpha_{\max}$ is used as the centrality, and the obtained centralities are sorted in descending order to rank the features. The pseudo-code of RRECFS is given in Algorithm 1. Steps 1 to 5 use ridge regression to mine the correlation between features and labels and build the weighted feature graph from it. In steps 6 to 8, eigenvector centrality computes the centrality of each feature in the graph, and this centrality is taken as the feature's score. The feature ranking is returned by sorting the features in descending order of score.

Fig. 3. Results of 20NG: (a) Ranking loss, (b) Average precision, and (c) Macro-F1 versus the number of selected features (10 to 100) for ELA-CHI, PPT-MI, MCLS, MIFS, FIMF, MLACO, MGFS, and RRECFS.
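Steps 6 to 8 of Algorithm 1 can be sketched in NumPy as follows. This is an assumption-laden illustration, not the authors' implementation, and the function name is hypothetical. Because FDM is symmetric, the sketch calls `numpy.linalg.eigh` and takes the eigenvector of the largest eigenvalue directly. Power iteration is another common way to compute eigenvector centrality.

```python
import numpy as np


def eigenvector_centrality_ranking(fdm: np.ndarray):
    """Solve alpha * c = FDM c for the largest alpha and rank features by c.

    fdm: (m, m) symmetric, nonnegative feature distance matrix.
    Returns (c, f_rank): centrality scores and feature indices by descending score.
    """
    eigvals, eigvecs = np.linalg.eigh(fdm)   # eigenvalues in ascending order
    # Perron-Frobenius: for a nonnegative connected graph, the leading eigenvector
    # has entries of one sign, so abs() only removes the arbitrary sign from eigh.
    c = np.abs(eigvecs[:, -1])
    c = c / c.sum()                          # rescaling does not change the ranking
    f_rank = np.argsort(-c, kind="stable")   # descending centrality
    return c, f_rank


# Hypothetical 3-feature FDM.
FDM = np.array([[0.0, 5.0, 10.0],
                [5.0, 0.0, 5.0],
                [10.0, 5.0, 0.0]])
c, F_rank = eigenvector_centrality_ranking(FDM)
# Here features 0 and 2 tie (c ~ 0.366 each) and feature 1 ranks last (c ~ 0.268).
```

Because the edge weights are distances, a feature earns a high score when its correlation profile lies far from those of other high-scoring features.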
3 Experiments
In this section, the effectiveness of the proposed method is tested on four multi-label datasets against seven state-of-the-art methods: ELA-CHI [9], PPT-MI [11], MCLS [19], MIFS [20], FIMF [12], MLACO [7], and MGFS [8]. The experimental data sets and settings are first described briefly, followed by a detailed discussion of the results.
Table 1. Overview of the data sets
Data set  Labels  Instances  Features  Average labels
20NG      20      19300      1006      1.029
Medical   45      978        1449      1.245
Enron     53      1702       1001      3.378
Bibtex    159     7395       1836      2.402

3.1 Experimental Data Sets
The four real-world data sets used in the experiments are from the Mulan library¹. Table 1 gives statistics for each data set: the size of the label set, the number of instances, the number of features, and the average number of labels per instance. All data sets are high-dimensional, with more than 1000 features.

Fig. 4. Results of medical: (a) Ranking loss, (b) Average precision, and (c) Macro-F1 versus the number of selected features (10 to 100) for ELA-CHI, PPT-MI, MCLS, MIFS, FIMF, MLACO, MGFS, and RRECFS.
3.2 Experimental Settings and Compared Methods
ML-KNN [21] (K = 10) is used as the classifier to evaluate the feature subsets selected by each method. For each dataset, 30 independent tests are performed; in each test, 60% of the instances are randomly chosen as the training set and the remaining 40% form the test set. Three multi-label evaluation metrics, Ranking loss, Average precision, and Macro-F1, are used to measure the effect of feature selection for each method. The regularization parameter λ of RRECFS is set to 10. The parameters of the compared methods are set to the defaults given in the corresponding papers.

Fig. 5. Results of enron: (a) Ranking loss, (b) Average precision, and (c) Macro-F1 versus the number of selected features (10 to 100) for ELA-CHI, PPT-MI, MCLS, MIFS, FIMF, MLACO, MGFS, and RRECFS.

¹ http://mulan.sourceforge.net/datasets-mlc.html

3.3 Results and Discussion
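The results in this section are reported in Ranking loss, Average precision, and Macro-F1. For reference, the sketch below implements common textbook definitions of these metrics in NumPy. The function names and the toy data are hypothetical, and several conventions are assumptions: S holds real-valued label scores (for example, ML-KNN posterior probabilities), ties are counted pessimistically, instances lacking either relevant or irrelevant labels are skipped, and per-label F1 is set to 0 when it is undefined. Library implementations such as scikit-learn's `label_ranking_loss`, `label_ranking_average_precision_score`, and `f1_score(..., average="macro")` may treat these edge cases differently.

```python
import numpy as np


def ranking_loss(Y: np.ndarray, S: np.ndarray) -> float:
    """Average fraction of (relevant, irrelevant) label pairs that are misordered.

    Y: (n, q) binary ground truth; S: (n, q) real-valued label scores.
    A pair counts as an error when the irrelevant label scores >= the relevant one.
    """
    losses = []
    for y, s in zip(Y.astype(bool), S):
        pos, neg = s[y], s[~y]
        if pos.size == 0 or neg.size == 0:
            continue  # skip instances with all labels or no labels
        errors = np.count_nonzero(pos[:, None] <= neg[None, :])
        losses.append(errors / (pos.size * neg.size))
    return float(np.mean(losses))


def average_precision(Y: np.ndarray, S: np.ndarray) -> float:
    """Mean over instances of the precision at each relevant label's rank."""
    scores = []
    for y, s in zip(Y.astype(bool), S):
        if not y.any():
            continue  # skip instances without relevant labels
        at_least = s[None, :] >= s[:, None]     # at_least[j, k]: label k scores >= label j
        rank = at_least.sum(axis=1)             # pessimistic rank of each label (1 = top)
        rel_above = at_least[:, y].sum(axis=1)  # relevant labels ranked at or above j
        scores.append(np.mean(rel_above[y] / rank[y]))
    return float(np.mean(scores))


def macro_f1(Y: np.ndarray, P: np.ndarray) -> float:
    """Unweighted mean of per-label F1 on binary predictions P (F1 = 0 if undefined)."""
    Y, P = Y.astype(bool), P.astype(bool)
    tp = np.sum(Y & P, axis=0)
    fp = np.sum(~Y & P, axis=0)
    fn = np.sum(Y & ~P, axis=0)
    denom = 2 * tp + fp + fn
    f1 = np.divide(2 * tp, denom, out=np.zeros(denom.shape), where=denom > 0)
    return float(f1.mean())


# Toy example (hypothetical): 2 instances, 3 labels; predictions threshold scores at 0.5.
Y = np.array([[1, 0, 1], [0, 1, 0]])
S = np.array([[0.9, 0.5, 0.4], [0.3, 0.8, 0.1]])
print(ranking_loss(Y, S), average_precision(Y, S), macro_f1(Y, S >= 0.5))
# Expected values: 0.25, 11/12, and 5/9.
```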
Figures 3, 4, 5, and 6 show the results of RRECFS and the compared methods on the four data sets for the three evaluation metrics. The horizontal axis of each figure indicates the number of top-ranked features in the input feature subset, which increases linearly from 10 to 100. The vertical axis indicates the value of the evaluation metric obtained by the classifier on that subset. As the number of features increases, the performance of RRECFS usually improves sharply at first and then levels off or even declines. RRECFS produces a useful feature subset that lets the classifier perform well even when the subset contains only a few features. In most cases, RRECFS achieves better results on each evaluation metric than the other methods. In particular, RRECFS clearly surpasses the compared methods on medical (Fig. 4) and bibtex (Fig. 6). As shown in Fig. 3(a), although RRECFS is inferior to some methods in Ranking loss on 20NG, the gap is small, and it performs similarly or better in Fig. 3(b) and (c).

Fig. 6. Results of bibtex: (a) Ranking loss, (b) Average precision, and (c) Macro-F1 versus the number of selected features (10 to 100) for ELA-CHI, PPT-MI, MCLS, MIFS, FIMF, MLACO, MGFS, and RRECFS.

In addition, the top 30 features ranked by each method were selected as feature subsets for the four datasets, and the results on all evaluation metrics are presented in Table 2, together with the average rank of each method across data sets and metrics. The best result for each metric is highlighted in the table. In most cases, the features selected by RRECFS rank first in classification effectiveness. The only exception is 20NG, where RRECFS is inferior to MGFS in Ranking loss and Average precision but still outperforms more than half of the compared methods. Overall, RRECFS outperforms the other multi-label feature selection methods in these experiments. Compared with the feature graph-based methods MLACO and MGFS, averaged over the three metrics on all datasets, the proposed method improves performance by 13.2% and 10.5%, respectively. The largest improvements, 31.9% over MLACO and 48.9% over MGFS, are obtained on Macro-F1 for bibtex. These results indicate that RRECFS performs better in multi-label feature selection than these two feature graph-based methods.
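The average-rank row of Table 2 can be recomputed from the table's 12 dataset-metric rows. Within each row, the eight methods are ranked (1 = best; lower is better for Ranking loss, higher is better for Average precision and Macro-F1), and the ranks are then averaged per method. The values below are transcribed from Table 2. The code is an illustrative check rather than the authors' script, and it assumes no ties, which holds for these values.

```python
import numpy as np

methods = ["RRECFS", "ELA-CHI", "PPT-MI", "MCLS", "MIFS", "FIMF", "MLACO", "MGFS"]
# (dataset, metric) -> scores in the column order above, from Table 2 (top 30 features).
table2 = {
    ("20NG", "RL"): [0.2192, 0.2012, 0.2055, 0.3813, 0.3645, 0.2974, 0.2196, 0.1980],
    ("20NG", "AP"): [0.5514, 0.5425, 0.5470, 0.2739, 0.2959, 0.2974, 0.5272, 0.5526],
    ("20NG", "MaF1"): [0.4765, 0.3975, 0.3255, 0.0054, 0.1003, 0.2974, 0.3972, 0.4252],
    ("Medical", "RL"): [0.0380, 0.0462, 0.0466, 0.0670, 0.1294, 0.0505, 0.0479, 0.0440],
    ("Medical", "AP"): [0.8601, 0.7812, 0.8035, 0.6770, 0.4299, 0.7784, 0.7966, 0.8120],
    ("Medical", "MaF1"): [0.4355, 0.3765, 0.3970, 0.3151, 0.2066, 0.3735, 0.3899, 0.3949],
    ("Enron", "RL"): [0.0967, 0.1066, 0.1033, 0.1125, 0.1118, 0.0974, 0.1044, 0.1065],
    ("Enron", "AP"): [0.6473, 0.5575, 0.6260, 0.5789, 0.5442, 0.6350, 0.6052, 0.6245],
    ("Enron", "MaF1"): [0.1389, 0.1087, 0.1260, 0.1129, 0.1005, 0.1335, 0.1159, 0.1264],
    ("Bibtex", "RL"): [0.1980, 0.2134, 0.2107, 0.3002, 0.3115, 0.2813, 0.2203, 0.2056],
    ("Bibtex", "AP"): [0.3552, 0.3161, 0.3374, 0.1702, 0.1595, 0.2154, 0.3217, 0.3076],
    ("Bibtex", "MaF1"): [0.0930, 0.0634, 0.0340, 0.0033, 0.0066, 0.0055, 0.0705, 0.0624],
}


def average_ranks(table: dict) -> np.ndarray:
    """Rank methods within each row (1 = best) and average the ranks per method."""
    ranks = []
    for (_, metric), row in table.items():
        vals = np.asarray(row)
        # Ranking loss: lower is better; Average precision and Macro-F1: higher is better.
        order = np.argsort(vals if metric == "RL" else -vals)
        r = np.empty(len(vals))
        r[order] = np.arange(1, len(vals) + 1)
        ranks.append(r)
    return np.mean(ranks, axis=0)


avg = average_ranks(table2)
for name, a in zip(methods, avg):
    print(f"{name}: {a:.2f}")
# RRECFS 1.33, ELA-CHI 4.42, PPT-MI 3.33, MCLS 7.25, MIFS 7.50, FIMF 5.08, MLACO 4.25, MGFS 2.83
```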
Table 2. Experimental results for all methods on four datasets with the top 30 features (Ranking loss: lower is better; Average precision and Macro-F1: higher is better; * marks the best result in each row)
Data set  Metric             RRECFS   ELA-CHI  PPT-MI   MCLS     MIFS     FIMF     MLACO    MGFS
20NG      Ranking loss       0.2192   0.2012   0.2055   0.3813   0.3645   0.2974   0.2196   0.1980*
20NG      Average precision  0.5514   0.5425   0.5470   0.2739   0.2959   0.2974   0.5272   0.5526*
20NG      Macro-F1           0.4765*  0.3975   0.3255   0.0054   0.1003   0.2974   0.3972   0.4252
Medical   Ranking loss       0.0380*  0.0462   0.0466   0.0670   0.1294   0.0505   0.0479   0.0440
Medical   Average precision  0.8601*  0.7812   0.8035   0.6770   0.4299   0.7784   0.7966   0.8120
Medical   Macro-F1           0.4355*  0.3765   0.3970   0.3151   0.2066   0.3735   0.3899   0.3949
Enron     Ranking loss       0.0967*  0.1066   0.1033   0.1125   0.1118   0.0974   0.1044   0.1065
Enron     Average precision  0.6473*  0.5575   0.6260   0.5789   0.5442   0.6350   0.6052   0.6245
Enron     Macro-F1           0.1389*  0.1087   0.1260   0.1129   0.1005   0.1335   0.1159   0.1264
Bibtex    Ranking loss       0.1980*  0.2134   0.2107   0.3002   0.3115   0.2813   0.2203   0.2056
Bibtex    Average precision  0.3552*  0.3161   0.3374   0.1702   0.1595   0.2154   0.3217   0.3076
Bibtex    Macro-F1           0.0930*  0.0634   0.0340   0.0033   0.0066   0.0055   0.0705   0.0624
Average rank                 1.33*    4.42     3.33     7.25     7.50     5.08     4.25     2.83

4 Conclusion
In this paper, a novel feature graph-based multi-label feature selection method (RRECFS) is proposed. Its main contributions are retaining the global information of the feature-label correlation and evaluating features efficiently through the potential relationships between them. RRECFS is compared with classical multi-label feature selection methods, recent advanced methods, and two feature graph-based methods. On three evaluation metrics and four real-world datasets, the proposed method is competitive with the existing methods and achieves a significant improvement over the feature graph-based methods. The experimental results show the importance of uncovering and retaining the global feature-label correlation when it is used to construct the graph in feature graph-based multi-label feature selection. They also show that using the potential relationships between features to evaluate feature nodes improves feature selection performance. The proposed method does not model label correlation; in future work, we will focus on integrating label correlation into feature graph-based methods.
References
1. Chauhan, V., Tiwari, A., Joshi, N., Khandelwal, S.: Multi-label classifier for protein sequence using heuristic-based deep convolution neural network. Appl. Intell. 52(3), 2820–2837 (2022). https://doi.org/10.1007/s10489-021-02529-6
2. Tirupattur, P., Duarte, K., Rawat, Y.S., Shah, M.: Modeling multi-label action dependencies for temporal action localization. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1460–1470 (2021)
3. Gao, B.B., Zhou, H.Y.: Learning to discover multi-class attentional regions for multi-label image recognition. IEEE Trans. Image Process. 30, 5920–5932 (2021)
4. Cai, J., Luo, J., Wang, S., Yang, S.: Feature selection in machine learning: a new perspective. Neurocomputing 300, 70–79 (2018)
5. Kashef, S., Nezamabadi-pour, H., Nipour, B.: Multilabel feature selection: a comprehensive review and guide experiments. WIREs Data Min. Knowl. Discov. 8(2), e1240 (2018)
6. Hu, L., Gao, L., Li, Y., Zhang, P., Gao, W.: Feature-specific mutual information variation for multi-label feature selection. Inf. Sci. 593, 449–471 (2022)
7. Paniri, M., Dowlatshahi, M.B., Nezamabadi-Pour, H.: MLACO: a multi-label feature selection algorithm based on ant colony optimization. Knowl.-Based Syst. 192, 105285 (2020)
8. Hashemi, A., Dowlatshahi, M.B., Nezamabadi-Pour, H.: MGFS: a multi-label graph-based feature selection algorithm via PageRank centrality. Expert Syst. Appl. 142, 113024 (2020)
9. Chen, W., Yan, J., Zhang, B., Chen, Z., Yang, Q.: Document transformation for multi-label feature selection text categorization. In: 7th IEEE International Conference on Data Mining (ICDM 2007), New York, pp. 451–456. IEEE Press (2007)
10. Read, J.: A pruned problem transformation method for multi-label classification. In: Proceedings of 2008 New Zealand Computer Science Research Student Conference (NZCSRS 2008), Christchurch, pp. 143–150 (2008)
11. Doquire, G., Verleysen, M.: Feature selection for multi-label classification problems. In: Cabestany, J., Rojas, I., Joya, G. (eds.) IWANN 2011. LNCS, vol. 6691, pp. 9–16. Springer, Heidelberg (2011). https://doi.org/10.1007/978-3-642-21501-8_2
12. Lee, J., Kim, D.W.: Fast multi-label feature selection based on information-theoretic feature ranking. Pattern Recogn. 48(9), 2761–2771 (2015)
13. Zhang, P., Liu, G., Gao, W.: Distinguishing two types of labels for multi-label feature selection. Pattern Recogn. 95, 72–82 (2019)
14. Hatami, M., Mahmood, S.R., Moradi, P.: A graph-based multi-label feature selection using ant colony optimization. In: 2020 10th International Symposium on Telecommunications (IST), Tehran, pp. 175–180. IEEE Press (2020)
15. Paniri, M., Dowlatshahi, M.B., Nezamabadi-pour, H.: Ant-TD: Ant colony optimization plus temporal difference reinforcement learning for multi-label feature selection. Swarm Evol. Comput. 64, 100892 (2021)
16. McDonald, G.C.: Tracing ridge regression coefficients. WIREs Comput. Stat. 2(6), 695–703 (2010)
17. Bonacich, P.: Some unique properties of eigenvector centrality. Soc. Netw. 29(4), 555–564 (2007)
18. Roffo, G., Melzi, S.: Ranking to learn: feature ranking and selection via eigenvector centrality. In: Appice, A., Ceci, M., Loglisci, C., Masciari, E., Raś, Z.W. (eds.) NFMCP 2016. LNCS (LNAI), vol. 10312, pp. 19–35. Springer, Cham (2017). https://doi.org/10.1007/978-3-319-61461-8_2
19. Huang, R., Jiang, W., Sun, G.: Manifold-based constraint Laplacian score for multi-label feature selection. Pattern Recogn. Lett. 112, 346–352 (2018)
20. Jian, L., Li, J., Shu, K., Liu, H.: Multi-label informed feature selection. In: 25th International Joint Conference on Artificial Intelligence (IJCAI), New York, pp. 1627–1633 (2016)
21. Zhang, M.L., Zhou, Z.H.: ML-KNN: a lazy learning approach to multi-label learning. Pattern Recogn. 40(7), 2038–2048 (2007)
O³GPT: A Guidance-Oriented Periodic Testing Framework with Online Learning, Online Testing, and Online Feedback
Yimeng Ren, Yuhu Shang (corresponding author), Kun Liang, Xiankun Zhang, and Yiying Zhang
College of Artificial Intelligence, Tianjin University of Science and Technology, Tianjin 300457, China
Abstract. Periodic testing (PT) is an integral part of the instructional process and aims to measure student proficiency at a specific stage. Most previous PT methods follow an inflexible offline policy that can hardly adjust the testing procedure instantly using online feedback. In this paper, we develop a dynamic, executable online periodic testing framework called O³GPT, which selects the most suitable questions step by step based on the student's real-time feedback from the previous timestep. First, we employ a stacked GRU to update the student's state representation instantly; it captures the long-term dynamics of the student's past learning trajectory, enabling the testing agent to perform effective periodic testing. Subsequently, in Stage 2, O³GPT incorporates a flexible testing-specific reward function into the soft actor-critic (SAC) algorithm to ensure the rationality of all selected questions. Finally, to provide online feedback, we evaluate O³GPT in an online simulated environment that models the qualitative development of knowledge proficiency. Experiments on two well-established student response datasets indicate that O³GPT outperforms state-of-the-art baselines on the PT task.
Keywords: Periodic testing · Online learning · Soft actor-critic algorithm

1 Introduction
It is widely recognized that periodic summaries influence people's implicit states and guide their subsequent direction. For instance, periodic shopping history data are essential for identifying customers' preferences at particular times and redirecting their consumption behavior. Similarly, in education, a good periodic test serves several functions, e.g., understanding students' learning process, identifying areas for improvement, and implementing interventions.

This work is supported by the Natural Science Foundation of Tianjin City (Grant No. 19JCYBJC15300), the Tianjin Science and Technology Planning Project, China (Grant No. 21ZYQCSY00050), the Tianjin Higher Education Institute Undergraduate Teaching Quality and Teaching Reform Research Project (Grant No. B201005706), the National Natural Science Foundation of China (Grant No. 61807024), and the Science and Technology Program of Tianjin (Grant No. 22YDTPJC00940).
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 130–141, 2023. https://doi.org/10.1007/978-981-99-1639-9_11
Fig. 1. Examples of students’ learning data.
After an extensive literature review, we find that much of the effort on traditional testing methods deviates somewhat from the basic lessons of the area. There are significant differences between standard testing [1] and periodic testing [2]. First, standard testing is a general assessment type in online learning in which all students take the same number of test items in a specified order at the end of each lesson. In contrast, periodic testing administers well-designed tests over a specified period, which could be weekly, fortnightly, or monthly. Second, in terms of the modeling objective, standard testing aims to consolidate students' learning by simultaneously optimizing many objectives, such as novelty, diversity, accuracy, and recall. Periodic testing, however, aims to measure student ability or proficiency at a specific stage rather than to consolidate the knowledge learned. Finally, from the perspective of utility in real-world intelligent learning environments, a guidance-oriented periodic testing model can provide valuable clues at different learning stages to improve learners' learning gains and enthusiasm. Clearly, periodic testing for each learning stage of an online course is a complex task. Based on this domain analysis, this study identifies three characteristics inherent to periodic testing that previous efforts failed to recognize: (1) a series of questions that helps reduce the uncertainty about a student's knowledge proficiency can provide valuable clues for periodic testing; (2) during the testing process, the student's mastery of all skills must be estimated accurately; and (3) periodic testing is a new dual-objective optimization task. A toy example of the periodic testing task is shown in Fig. 1.
Deep Reinforcement Learning (DRL), a prominent research direction, trains an agent to find an optimal strategy, which can be designed to optimize multiple objectives and achieve a trade-off among conflicting ones. DRL techniques that handle dynamic multiple objectives could therefore contribute significantly to periodic testing. This paper develops a novel periodic testing framework based on advanced reinforcement learning techniques, called O³GPT, which offers three advantages over existing models:
1) We propose deep reinforcement learning with state-enhanced representation for periodic testing. To the best of our knowledge, this work is the first attempt to combine state representation and reinforcement learning techniques for periodic testing.
2) Given the specificity of this task, we design a flexible reward function that quantifies Informativeness and Diversity factors, so that O³GPT can adjust and optimize its testing policy by interactively receiving students' real-time feedback.
3) We conduct experiments on two large-scale datasets, and the results show that O³GPT performs better than the other models.
Moreover, our empirical research can provide a new perspective on accurately measuring students' proficiency at all stages and helping them attain higher achievements.
2 Related Works
Periodic Testing. In terms of theoretical research, Zenisky et al. [3] developed a descriptive analysis method for periodic testing that takes into account the number of stages, the ability level of the student group, and the extent of target information overlap among modules within stages. Item Response Theory (IRT), a typical unidimensional model in psychometrics, is widely used for the periodic testing task. Xu et al. [4] proposed a multidimensional adaptive testing model based on multidimensional item response theory for multistage testing (MST), enabling multidimensional assessments. Exploiting the advantages of nature-inspired metaheuristics, Li et al. [2] attempted to achieve the objectives of MST in cognitive diagnosis via heuristic algorithms. As the literature shows, the periodic testing problem has attracted considerable attention in cognitive psychology research. Nonetheless, few studies have investigated periodic testing in depth with deep learning, which gives the proposed O³GPT framework particular significance for this problem.

Online Learning Models. To the best of our knowledge, current online learning techniques can be divided into two branches. The first branch is online supervised learning models, which derive users' preferences from the feedback on their interactions with the corresponding items [5]. Despite their effectiveness, these models only receive environment feedback without any additional exploration mechanism, which may lead to selecting suboptimal questions for students. The second branch is factorization bandit learning models, an important branch of online learning [6]. Applying these models to online learning environments is challenging to some extent, since their performance depends heavily on the exploration-exploitation trade-off.
Such techniques can be divided into two branches: context-free factorization bandits [7] and contextual bandits [8–10]. The former update the latent factor vectors instantly upon receiving feedback from the environment; the latter associate side information with each arm.
O3 GPT
3 Methods and Implementation
3.1 Problem Definition and Notations
To set up the experiment, the answered learning records are randomly divided into two segments, denoted as $R_\tau = (\Omega_\tau, J_\tau)$, $\tau = 1, 2$. For the $\tau$-th segment, we divide the question responses $R_\tau$ of each student into a 70%/30% partition, using the 70% subset $\Omega_\tau$ for training and the remaining 30% subset $J_\tau$ as the candidate question set for periodic testing.

Definition 1. Historical learning trajectories. The historical learning trajectory of student $i$ is defined as $\Omega_\tau(i) = \{\Omega^1_\tau(i), \Omega^2_\tau(i), \dots, \Omega^N_\tau(i)\}$ with $\Omega^t_\tau(i) = \{q^t_\tau, p^t_\tau\}$, where $q^t_\tau \in Q$ is the question practiced by the student at step $t$ and $p^t_\tau$ is the corresponding performance score. We also denote the association between questions and skills as a binary relation $G \subseteq Q \times K$, where $(q^t_\tau, k^t_i) \in G$ if $q^t_\tau$ is related to $k^t_i$.

Definition 2. Online testing records. At each timestep $t$, the objective of O3 GPT is to select a learning item, i.e., a question $q_t$, based on the feedback from online testing. Specifically, a sequential list $L_\tau(i) = \{L^1_\tau(i), L^2_\tau(i), \dots, L^X_\tau(i)\}$, $L_\tau(i) \subseteq J_\tau$, is built to collect the online testing records as they arrive, where $t$ denotes the latest timestamp.

3.2 The Overall Architecture of the Proposed Model
Figure 2 illustrates the process of O3 GPT at each testing step, which consists of three main stages: Online Learning, Online Testing, and Online Feedback. In Stage 1, O3 GPT first collects the latest testing feedback from the $t$-th timestep and uses it to update the current state. In Stage 2, O3 GPT receives the student who arrives at this timestep and applies a soft actor-critic algorithm to automatically select the most suitable testing item for that student. Given the characteristics of periodic testing, we also design two testing-specific objectives as reward functions. In Stage 3, O3 GPT builds an interactive testing simulator that mimics the online environment and returns real-time rewards calculated from the testing feedback.
Stage 1: Online Learning. This stage has two sub-layers.
Knowledge Embedding Layer. At each testing step $t$, O3 GPT uses the testing feedback from the $t$-th timestep to obtain vectorized embeddings of each question in real time. For skill $k_z$, the embedding $k^*_z$ is obtained by multiplying its one-hot vector $k_z$ by the trainable matrix $K$, i.e., $k^*_z = k_z^{\top} K$. At each testing step $t$, we then obtain the input vector $\tilde{x}_t$ by concatenating the skill embedding $k^*_z$ with the response score $p_z$:

$$\tilde{x}_t = k^*_z \oplus p_z \tag{1}$$
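The embedding step in Eq. (1) amounts to a row lookup followed by a concatenation. The minimal sketch below, in plain Python with toy values, shows that multiplying a one-hot vector $k_z$ by $K$ selects row $z$ of $K$, and that $\tilde{x}_t$ simply appends the score $p_z$. The matrix values, the helper names, and the scalar score are illustrative assumptions; in O3 GPT, $K$ is trained.

```python
from typing import List


def one_hot(index: int, size: int) -> List[float]:
    """One-hot vector k_z for skill index z."""
    vec = [0.0] * size
    vec[index] = 1.0
    return vec


def embed_skill(k_z: List[float], K: List[List[float]]) -> List[float]:
    """k_z^* = k_z^T K: one-hot row vector times the (num_skills x d) matrix K."""
    d = len(K[0])
    return [sum(k_z[r] * K[r][c] for r in range(len(K))) for c in range(d)]


def input_vector(z: int, p_z: float, K: List[List[float]]) -> List[float]:
    """x~_t = k_z^* (+) p_z: append the response score to the skill embedding."""
    k_star = embed_skill(one_hot(z, len(K)), K)
    return k_star + [p_z]


# Toy example: 3 skills, embedding dimension 2 (arbitrary values, not trained).
K = [[0.1, 0.2],
     [0.3, 0.4],
     [0.5, 0.6]]
x_t = input_vector(z=1, p_z=1.0, K=K)
print(x_t)  # [0.3, 0.4, 1.0]
```

In practice the one-hot product is implemented as a table lookup, since it only ever selects a single row.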
Y. Ren et al.
Fig. 2. Graphic representation of the O3 GPT model within one time step.
Next, the joint vector $\tilde{x}_t$ is fed to the state transition layer.
State Transition Layer. With the embedded vector sequence $\tilde{x} = \{\tilde{x}_1, \tilde{x}_2, \dots, \tilde{x}_t\}$ as input, we employ a stacked GRU network to capture the long-term dynamics of students' past learning trajectories and build the state $s_t$. Specifically, the first component of the stacked GRU model generates the hidden representation

$$h^{(1)}_t = \mathrm{GRU}^{(1)}\big(\tilde{x}_t, h^{(1)}_{t-1}; \theta^{(1)}\big) \tag{2}$$

and the second component of the stacked GRU network is

$$h^{(2)}_t = \mathrm{GRU}^{(2)}\big(h^{(1)}_t, h^{(2)}_{t-1}; \theta^{(2)}\big) \tag{3}$$
Stacking multiple layers can also introduce problems such as overfitting and harder training. Inspired by ResNet [11], we introduce a residual connection to alleviate these limitations and obtain the state representation through an activation function:

$$c_t = \sigma\big(W_s (h^{(2)}_t \oplus \tilde{x}_t) + b_s\big) \tag{4}$$

Finally, we take the final output $c_t$ as the student's state, i.e., $s_t = c_t$.
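A minimal sketch of the state transition layer (Eqs. (2) to (4)) follows. It assumes a standard GRU cell without bias terms, untrained random weights, and a concatenation-based shortcut exactly as Eq. (4) writes it. Layer sizes, initialization, and helper names are hypothetical; a real implementation would use a deep learning framework and learn $\theta^{(1)}$, $\theta^{(2)}$, $W_s$, and $b_s$.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, v: Vector) -> Vector:
    """Matrix-vector product."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def add(*vs: Vector) -> Vector:
    """Element-wise sum of equally sized vectors."""
    return [sum(xs) for xs in zip(*vs)]


def sigmoid(v: Vector) -> Vector:
    return [1.0 / (1.0 + math.exp(-x)) for x in v]


def rand_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


class GRUCell:
    """Minimal GRU cell (update gate z, reset gate r, candidate n); biases omitted."""

    def __init__(self, in_dim: int, hid_dim: int, rng: random.Random):
        self.Wz, self.Uz = rand_matrix(hid_dim, in_dim, rng), rand_matrix(hid_dim, hid_dim, rng)
        self.Wr, self.Ur = rand_matrix(hid_dim, in_dim, rng), rand_matrix(hid_dim, hid_dim, rng)
        self.Wn, self.Un = rand_matrix(hid_dim, in_dim, rng), rand_matrix(hid_dim, hid_dim, rng)

    def step(self, x: Vector, h: Vector) -> Vector:
        z = sigmoid(add(matvec(self.Wz, x), matvec(self.Uz, h)))
        r = sigmoid(add(matvec(self.Wr, x), matvec(self.Ur, h)))
        rh = [ri * hi for ri, hi in zip(r, h)]
        n = [math.tanh(a) for a in add(matvec(self.Wn, x), matvec(self.Un, rh))]
        return [(1.0 - zi) * ni + zi * hi for zi, ni, hi in zip(z, n, h)]


def encode_state(xs: List[Vector], hid_dim: int, seed: int = 0) -> Vector:
    """Run a 2-layer stacked GRU over x~_1..x~_t and return s_t = c_t."""
    rng = random.Random(seed)
    in_dim = len(xs[0])
    gru1 = GRUCell(in_dim, hid_dim, rng)   # GRU^(1), Eq. (2)
    gru2 = GRUCell(hid_dim, hid_dim, rng)  # GRU^(2), Eq. (3)
    W_s = rand_matrix(hid_dim, hid_dim + in_dim, rng)
    b_s = [0.0] * hid_dim
    h1, h2 = [0.0] * hid_dim, [0.0] * hid_dim
    for x in xs:
        h1 = gru1.step(x, h1)
        h2 = gru2.step(h1, h2)
    # Eq. (4): concatenate the top-layer state with the latest input (the shortcut),
    # then apply an affine map and a sigmoid.
    return sigmoid(add(matvec(W_s, h2 + xs[-1]), b_s))


s_t = encode_state([[0.3, 0.4, 1.0], [0.5, 0.6, 0.0]], hid_dim=4)
print(len(s_t))  # 4
```

Because the final activation is a sigmoid, every component of $s_t$ lies in (0, 1), whatever the sequence length.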
At each testing step, the agent observes the current state $s_t$, selects the most suitable question according to the testing-specific rewards, and then enters a new state $s_{t+1}$. The state transition process is denoted as

$$s_{t+1} = T(s_t, W_s, \tilde{x}_t) \tag{5}$$
Stage 2: Online Testing. After obtaining the student's current learning state $s_t$, O3 GPT uses a soft actor-critic algorithm to determine which candidate question should be chosen for a given student. Because the design of the reward function directly affects how the agent optimizes its actions, we design a reward function with two testing-specific objectives.
• Informativeness. This factor is motivated by the observation that a selected question can be considered highly informative if the student's knowledge proficiency changes significantly; if proficiency hardly changes, knowing the response to the question brings little information. For this purpose, we set up a general diagnosis model $M$ with parameters $\phi$ that estimates and traces the learner's knowledge proficiency. In $r_1$, $M(z_t; \phi)$ denotes the student's knowledge mastery in the current state $s_t$, and $M(z_t, a_t; \phi)$ denotes the knowledge mastery after performing an action, i.e., adding a new learning record to $M$. The increment of knowledge mastery $\Delta M(a_t)$ can therefore be written as

$$r_1 = \Delta M(a_t) = \big|M(z_t, a_t; \phi) - M(z_t; \phi)\big| \tag{6}$$
where $z_t = \{(q^1_k, p^1_k), (q^2_k, p^2_k), \dots, (q^t_k, p^t_k)\}$ contains all historical interaction records of the student from time step 1 to $t$. Because offline data provide no real-time feedback from students, the agent cannot be trained directly to perform sequential testing; Stage 3 addresses this problem.
• Diversity. During periodic testing, we need to accurately estimate the student's mastery of all knowledge skills. For this purpose, the Diversity function aims to select a question from $J_\tau$ that maximizes skill coverage over the whole testing process. We therefore design a Diversity reward with an incremental property: as the number of questions related to a skill increases, the reward gradually increases from 0 toward 1. Conversely, a penalty is given if the selected skill has already been tested. The penalty coefficient $\beta$ can be set flexibly.

$$r_2 = \begin{cases} \dfrac{\mathrm{knt}(k_t, q(k)_i)}{\mathrm{knt}(k_t, q(k)_i) + 1}, & k_t \setminus \{k_1, k_2, \dots, k_{t-1}\} \neq \emptyset,\\[4pt] \beta, & \text{otherwise,} \end{cases} \qquad \mathrm{knt}(k_t, q(k)_i) = \sum_{q(k)_i \in J_\tau} \mathbb{1}\big[(q(k)_i, k_t) \in G\big] \tag{7}$$
where $\mathrm{knt}(k_t, q(k)_i)$ counts the candidate questions in $J_\tau$ that are associated with the skill $k_t$ tested at the $t$-th step.
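The two rewards can be illustrated with a small sketch. Everything here is hypothetical: the diagnosis model $M$ is replaced by a Laplace-smoothed per-skill accuracy (O3 GPT uses DKT in Stage 3), $|\cdot|$ in Eq. (6) is taken as the L1 norm over skills because $M$ returns a vector, each question is assumed to carry a single skill, and $\beta = -0.5$ is an arbitrary penalty.

```python
from typing import Dict, List, Set, Tuple

Record = Tuple[str, float]  # (skill, score in [0, 1])


def mastery(records: List[Record], skills: List[str]) -> Dict[str, float]:
    """Hypothetical stand-in for the diagnosis model M: smoothed accuracy per skill."""
    stats = {k: [0.0, 0] for k in skills}
    for skill, score in records:
        stats[skill][0] += score
        stats[skill][1] += 1
    return {k: (correct + 1.0) / (n + 2.0) for k, (correct, n) in stats.items()}


def informativeness_reward(history: List[Record], action: Record, skills: List[str]) -> float:
    """r1 = |M(z_t, a_t) - M(z_t)|, summed over skills because M returns a vector here."""
    before = mastery(history, skills)
    after = mastery(history + [action], skills)
    return sum(abs(after[k] - before[k]) for k in skills)


def knt(skill: str, candidates: List[str], G: Set[Tuple[str, str]]) -> int:
    """Number of candidate questions in J_tau linked to `skill` in the relation G."""
    return sum(1 for q in candidates if (q, skill) in G)


def diversity_reward(skill: str, tested: Set[str], candidates: List[str],
                     G: Set[Tuple[str, str]], beta: float) -> float:
    """r2 from Eq. (7): n / (n + 1) for a not-yet-tested skill, otherwise the penalty beta."""
    if skill not in tested:
        n = knt(skill, candidates, G)
        return n / (n + 1)
    return beta


skills = ["add", "sub"]
G = {("q1", "add"), ("q2", "add"), ("q3", "sub")}
candidates = ["q1", "q2", "q3"]
r1 = informativeness_reward([("add", 1.0)], ("sub", 1.0), skills)        # 1/6
r2_new = diversity_reward("add", set(), candidates, G, beta=-0.5)       # 2/3
r2_repeat = diversity_reward("add", {"add"}, candidates, G, beta=-0.5)  # -0.5
```

The value $n/(n+1)$ grows toward 1 as more candidate questions share the skill, which is the incremental property described above.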
Because periodic tests are typically short, traditional model-free DRL algorithms may reduce the Diversity of the selected questions. To encourage extensive exploration, we adopt the actor-critic algorithm based on maximum entropy theory, i.e., soft actor-critic, for online testing [12]. The actor network, also called the policy network, generates the action distribution $\pi(a \mid s; \varphi)$, with policy parameters $\varphi$, from the input student state $s_i$. The critic network $V(\cdot; \theta)$, also called the value network, uses a deep neural network to estimate the expected return $v_i$ from each state:

$$v_i = V(s_i; \gamma, \theta) \tag{8}$$
where $\gamma$ is the temperature coefficient that determines the relative importance of the entropy term versus the reward. The actor network is updated with a policy gradient:

$$\partial \varphi = \nabla_{\varphi} \log \pi(a \mid s_i; \varphi)\,(R_i - v_i) \tag{9}$$
The critic network is trained by minimizing the distance between the estimated value and the actual return, as shown in Eq. (10):

$$l_{\text{critic}} = \lVert v_i - R_i \rVert_2^2 \tag{10}$$
The temperature parameter $\gamma$ can be optimized with stochastic gradients:

$$\partial \gamma = \bar{H} - \log \pi(a \mid s_i; \varphi) \tag{11}$$

where $\bar{H}$ denotes a constant vector.
Stage 3: Online Feedback. Building on the Informativeness factor of the online testing stage, we implement a system simulator with knowledge tracing to simulate the online testing environment, i.e., to simulate the reward $r_1$ from the feedback of the corresponding test. Specifically, Deep Knowledge Tracing (DKT) [13] is applied to capture the implicit knowledge mastery $M(z_t; \phi)$. The input to the network is a representation of the student's historical learning sequence, and the output is a vector of values between 0 and 1 giving the predicted probability of answering correctly for each skill [16].
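For a single transition, the Stage 2 quantities in Eqs. (9) to (11) can be computed as below. This sketch assumes a softmax policy over candidate questions and evaluates the policy-gradient term with respect to the logits; the networks that produce the logits and $v_i$, the optimizer, and any replay mechanism are omitted, and Eq. (11) is reproduced as written in the text rather than in any particular library's form.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def update_signals(logits: List[float], action: int, ret: float, value: float,
                   target_entropy: float) -> Tuple[List[float], float, float]:
    """Per-transition quantities behind Eqs. (9)-(11) for a softmax policy over candidates."""
    probs = softmax(logits)
    log_pi = math.log(probs[action])
    advantage = ret - value  # (R_i - v_i)
    # For a softmax policy, d log pi(a|s) / d logits = one_hot(a) - probs.
    grad_log_pi = [(1.0 if j == action else 0.0) - p for j, p in enumerate(probs)]
    actor_grad = [g * advantage for g in grad_log_pi]  # Eq. (9), w.r.t. the logits
    critic_loss = (value - ret) ** 2                   # Eq. (10) for a scalar value
    temp_grad = target_entropy - log_pi                # Eq. (11) as written in the text
    return actor_grad, critic_loss, temp_grad


actor_grad, critic_loss, temp_grad = update_signals(
    logits=[0.0, 0.0], action=0, ret=1.0, value=0.25, target_entropy=0.0)
print(actor_grad, critic_loss)  # [0.375, -0.375] 0.5625
```

A positive advantage pushes probability mass toward the chosen question, while the critic loss shrinks as $v_i$ approaches the observed return.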
4 Experiment Settings
4.1 Datasets Description
The experiments are carried out on two well-established datasets: ASSISTments2009 and Statics2011. The ASSISTments2009¹ dataset is provided by the tutoring platform ASSISTments. During preprocessing, we filtered out students who practiced fewer than 500 questions, leaving 127,350 records, 178 students, 16,891 questions, and 110 skills. The Statics2011² dataset comes from an engineering statics course. We filtered out students who practiced fewer than 400 questions; after preprocessing, there are 162 students, 80 skills, 300 questions, and 79,459 records.

¹ https://sites.google.com/site/assistmentsdata/home/assistment-2009-2010-data.
² https://pslcdatashop.web.cmu.edu/DatasetInfo?datasetId=507.

4.2 Baselines
To demonstrate the effectiveness of O3 GPT, we select six well-known models as baselines: RAND, Q-learning [14], ε-greedy+ [15], LinUCB [9], hLinUCB [8], and PTS [10]. These models were selected because of their strong performance in either reinforcement learning (Q-learning) or bandit-based online learning (ε-greedy+, LinUCB, hLinUCB, and PTS).
4.3 Evaluation Metrics
Four metrics are used in the performance evaluation. Cumulative Reward: as a general evaluation metric, the cumulative reward measures the reward accumulated over each test. To verify the efficiency of the testing procedure, the student performance prediction task is generally used in periodic testing (PT). Here, Rationality and Coverage are adopted to measure the testing-specific rewards, i.e., Informativeness and Diversity, respectively. Rationality Metric: we predict the student's performance on every question $q_i$ whose ground truth has been recorded. For the $\tau$-th stage, we adopt the common AUC (area under the ROC curve) metric to measure the performance of different models:

$$\mathrm{Rationality@}N = \mathrm{AUC}\big(\{M(q_i \mid \phi(s_X)) \mid q_i \in J_\tau\}\big) \tag{12}$$
where $\phi(s_X)$ denotes the parameters of model $M$ at the last testing step. Coverage Metric: in the evaluation plan, we design the Coverage metric to calculate the proportion of knowledge skills covered in the testing procedure:

$$\mathrm{Coverage}(L_\tau(i)) = \frac{1}{|K|} \sum_{x=1}^{X} \mathbb{1}\big[k \in L^x_\tau(i)\big] \tag{13}$$
where $|K|$ denotes the number of knowledge skills in each dataset. Periodic Rationality Metric: there is no universal standard for verifying the effectiveness of periodic testing. As an alternative, we define a new standard, Rationality(Base), which uses the training data of the $\tau$-th stage to directly predict student performance in the $(\tau+1)$-th stage and reports the AUC:

$$\mathrm{Rationality(Base)} = \mathrm{AUC}\big(\{M(q_i \mid \phi(s_0)) \mid q_i \in R_{\tau+1}\}\big) \tag{14}$$
where $\phi(s_0)$ denotes the parameters of model $M$ trained only on the data of the $\tau$-th stage. Intuitively, if, at the $\tau$-th testing stage, the AUC of predicting the $(\tau+1)$-th stage performance improves for students who take the periodic test compared with
those who have not taken the periodic test, this indicates that O3 GPT is effective. For model comparison, we define the Rationality(Period) indicator, which uses the training data of the $\tau$-th stage together with the question set generated by periodic testing to predict student performance in the $(\tau+1)$-th stage and reports the AUC for each model:

$$\mathrm{Rationality(Period)} = \mathrm{AUC}\big(\{M(q_i \mid \phi(s_X)) \mid q_i \in R_{\tau+1}\}\big) \tag{15}$$

5 Experimental Results and Analysis
Evaluation on Cumulative Reward. The learning curves over 1,000 episodes are shown in Fig. 3, where the x-axis is the number of training episodes and the y-axis is the cumulative reward. On both datasets, O3 GPT consistently outperforms its baselines and converges faster than the other methods. hLinUCB also consistently outperforms PTS on both datasets, because PTS must resort to approximate sampling in its matrix factorization module, which can introduce additional errors into the factorization. In addition, ε-greedy+ has poor testing quality because its constant exploration probability ε results in linear growth of the regret.
Rationality@N Performance Comparison. Table 1 compares the Rationality of the selected models at testing steps 10 and 20. One observation is that the performance of all models tends to rise as the sequence length increases, which suggests that they benefit from longer sequences. Another observation is that, on both datasets, the Rationality of O3 GPT is clearly higher than that of the other online algorithms. The results improve as the online testing component becomes more sophisticated (e.g., SAC rather than Q-learning), which is reasonable because a more sophisticated component gives a more comprehensive knowledge diagnosis. This observation also confirms an advantage of our framework: O3 GPT improves periodic testing while allowing its online testing component (e.g., Q-learning) to be replaced flexibly without redesigning the reward function.
Fig. 3. The results of cumulative rewards after 1,000 episodes.
Table 1. Comparisons of the Rationality@N results on student performance prediction.

| Model | Assistments0910 @10 | Assistments0910 @20 | Statics2011 @10 | Statics2011 @20 |
|---|---|---|---|---|
| RAND | 0.6338 | 0.6289 | 0.6364 | 0.6427 |
| Q-learning | 0.6701 | 0.6739 | 0.6870 | 0.6945 |
| ε-greedy+ | 0.6422 | 0.6519 | 0.6637 | 0.6617 |
| LinUCB | 0.6751 | 0.6786 | 0.6960 | 0.7004 |
| hLinUCB | 0.7013 | 0.7083 | 0.7381 | 0.7431 |
| PTS | 0.6810 | 0.6850 | 0.7039 | 0.7093 |
| O3 GPT | 0.7306 | 0.7428 | 0.7788 | 0.7847 |
Fig. 4. Comparison of skill Coverage during testing.
Skill Coverage Performance Comparison. The results in Fig. 4 show that skill coverage increases rapidly for all methods as the number of testing steps grows. Notably, O3 GPT achieves fairly rapid coverage growth at the beginning by selecting high-Diversity questions. This property is critical for periodic testing because periodic tests are typically short. Although the Statics2011 dataset contains a limited number of questions, the question types are relatively rich, and its skill coverage is slightly better than that on the Assistments0910 dataset.
Effectiveness of the Periodic Outcomes. We further investigate whether periodic outcomes indeed provide valuable clues for the next stage of learning diagnosis. The results are summarized in Table 2, where '↑' indicates a gain and '↓' a loss relative to the Base model. The Base model uses the training data of the $\tau$-th stage to predict students' performance in the $(\tau+1)$-th stage, whereas the other models use the periodic testing sequences $L_\tau(i)$ together with the $\tau$-th stage practice records to predict student performance in the next testing stage. As expected, O3 GPT shows the largest improvements, indicating that our method effectively supports learner diagnosis in the next testing stage. This provides the strongest evidence for the effectiveness of the periodic testing procedure.
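The Rationality metrics in Eqs. (12), (14), and (15) all reduce to an AUC over held-out responses, and Coverage counts the skills reached by the tested questions. The sketch below computes AUC through its rank (Mann–Whitney) interpretation and implements Coverage following the prose definition, i.e., the share of the $|K|$ skills covered by at least one tested question; the toy labels, scores, and question-skill map are made up.

```python
from typing import Dict, List, Sequence, Set


def auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as P(score of a random positive > score of a random negative); ties count 1/2."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative response")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else (0.5 if p == n else 0.0)
    return wins / (len(pos) * len(neg))


def coverage(tested: List[str], question_skills: Dict[str, Set[str]], num_skills: int) -> float:
    """Share of the |K| skills touched by at least one tested question."""
    covered: Set[str] = set()
    for q in tested:
        covered |= question_skills[q]
    return len(covered) / num_skills


# Toy data: correctness of held-out responses and a diagnosis model's predicted probabilities.
print(auc([1, 0, 1, 0], [0.9, 0.1, 0.4, 0.6]))  # 0.75
print(coverage(["q1", "q3"], {"q1": {"a", "b"}, "q2": {"b"}, "q3": {"b", "c"}}, 5))  # 0.6
```

The pairwise loop is quadratic in the number of responses; for large test sets a sort-based rank computation gives the same value.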
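Table 2. Effectiveness evaluation of the periodic outcomes on two datasets.

| Training sequence | Model | Assistments0910 | Statics2011 |
|---|---|---|---|
| $\Omega_\tau$ | Base | 0.6027 | 0.6122 |
| $\Omega_\tau + L_\tau(i)$ | RAND | 0.5910↓ | 0.5972↓ |
| $\Omega_\tau + L_\tau(i)$ | Q-learning | 0.6199↑ | 0.6271↑ |
| $\Omega_\tau + L_\tau(i)$ | ε-greedy+ | 0.6105↑ | 0.6186↑ |
| $\Omega_\tau + L_\tau(i)$ | LinUCB | 0.6210↑ | 0.6314↑ |
| $\Omega_\tau + L_\tau(i)$ | hLinUCB | 0.6368↑ | 0.6357↑ |
| $\Omega_\tau + L_\tau(i)$ | PTS | 0.6299↑ | 0.6314↑ |
| $\Omega_\tau + L_\tau(i)$ | O3 GPT | 0.6512↑ | 0.6485↑ |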
Fig. 5. Evolution of skill proficiency of a certain student during her 50 testing steps.
Case Study on the Application of Periodic Testing. Figure 5 presents a case study that visualizes the predicted proficiency of a certain student on explicit knowledge skills during the testing process. For a group of 50 randomly selected skills, the blue and red lines indicate the learner's ability before and after periodic testing, respectively. The green symbols indicate the learner's actual performance: at the $t$-th testing step, the green dot lies at y = 1 if the skill was answered correctly and at y = 0 otherwise. Comparing the results with and without the student's periodic outcomes shows that the prediction that includes the periodic testing outcomes is closer to the learner's actual responses during exercising.
6 Conclusions
This paper developed a new online periodic testing model, O3 GPT, which provides a general approach to automatically tailoring tests that are both highly informative and highly diverse. Unlike prior studies, we employ a stacked GRU network to capture the dynamic patterns of student learning during interaction with testing systems and thereby obtain a state representation of the student. Furthermore, motivated by reinforcement learning techniques, another key novelty of O3 GPT is that it adopts the SAC algorithm, with more stochastic exploration, to tailor an adaptive periodic testing procedure for each student. Overall, adaptive periodic testing is a very promising area, and we expect to see the deployment
of DRL-driven schemes in real-world educational systems in the coming years. Another interesting open research direction is to better understand the ethical implications of periodic testing and to explore additional beneficial objectives that can accurately diagnose students' ability levels while reducing test length.
References
1. Huang, Y.M., Lin, Y.T., Cheng, S.C.: An adaptive testing system for supporting versatile educational assessment. Comput. Educ. 52(1), 53–67 (2009)
2. Li, G., Cai, Y., Gao, X., Wang, D., Tu, D.: Automated test assembly for multistage testing with cognitive diagnosis. Front. Psychol. 12, 509844 (2021). https://doi.org/10.3389/fpsyg.2021.509844
3. Zenisky, A., Hambleton, R.K., Luecht, R.M.: Multistage testing: issues, designs, and research. In: van der Linden, W., Glas, C. (eds.) Elements of Adaptive Testing. SSBS, pp. 355–372. Springer, New York (2009). https://doi.org/10.1007/978-0-387-85461-8_18
4. Xu, L., Wang, S., Cai, Y., Tu, D.: The automated test assembly and routing rule for multistage adaptive testing with multidimensional item response theory. J. Educ. Meas. 58(4), 538–563 (2021)
5. Li, Y., Li, Z., Feng, W., et al.: Accelerated online learning for collaborative filtering and recommender systems. In: ICDM, pp. 879–885 (2014). https://doi.org/10.1109/ICDMW.2014.95
6. Hoi, S.C., Sahoo, D., Lu, J., Zhao, P.: Online learning: a comprehensive survey. Neurocomputing 459, 249–289 (2021)
7. Wang, Q., Zeng, C., Zhou, W., et al.: Online interactive collaborative filtering using multi-armed bandit with dependent arms. TKDE 31(8), 1569–1580 (2019)
8. Wang, H., Wu, Q., Wang, H.: Learning hidden features for contextual bandits. In: Proceedings of the 25th ACM International on Conference on Information and Knowledge Management, pp. 1633–1642 (2016). https://doi.org/10.1145/2983323.2983847
9. Li, L., Chu, W., Langford, J., Schapire, R.E.: A contextual-bandit approach to personalized news article recommendation. In: WWW, pp. 661–670 (2010). https://doi.org/10.1145/1772690.1772758
10. Kawale, J., Bui, H.H., Kveton, B., Tran-Thanh, L., Chawla, S.: Efficient Thompson sampling for online matrix-factorization recommendation. In: NIPS, vol. 28, pp. 1297–1305 (2015)
11. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR, pp. 770–778 (2016). https://doi.org/10.1109/CVPR.2016.90
12. Haarnoja, T., Zhou, A., Abbeel, P., Levine, S.: Soft actor-critic: off-policy maximum entropy deep reinforcement learning with a stochastic actor. In: ICML, pp. 1861–1870 (2018). https://doi.org/10.48550/arXiv.1801.01290
13. Piech, C., Bassen, J., Huang, J., et al.: Deep knowledge tracing, pp. 505–514 (2015). https://doi.org/10.48550/arXiv.1506.05908
14. Watkins, C.J., Dayan, P.: Q-learning. Mach. Learn. 8(3), 279–292 (1992). https://doi.org/10.1007/BF00992698
15. McInerney, J., Lacker, B., Hansen, S., et al.: Explore, exploit, and explain: personalizing explainable recommendations with bandits. In: RECSYS, pp. 31–39 (2018). https://doi.org/10.1145/3240323.3240354
16. Liu, Q., Shen, S., Huang, Z., Chen, E., Huang, Z.: A survey of knowledge tracing. https://doi.org/10.48550/arXiv.2105.15106
AFFSRN: Attention-Based Feature Fusion Super-Resolution Network
Yeguang Qin¹, Fengxiao Tang¹(B), Ming Zhao¹(B), and Yusen Zhu²
¹ School of Computer Science, Central South University, Changsha, China
{tangfengxiao,meanzhao}@csu.edu.cn
² School of Mathematics, Hunan University, Changsha, China
Abstract. Recently, single image super-resolution (SISR) methods have been based primarily on building deeper and more complex convolutional neural networks (CNNs), which leads to colossal computational overhead. At the same time, some works introduce Transformers to low-level vision tasks, which achieves high performance but at a high computational cost. To address this problem, we propose an attention-based feature fusion super-resolution network (AFFSRN) that reduces network complexity while achieving higher performance. Because CNNs focus on capturing local detail, their global modeling capability is weak; we therefore propose the Swin Transformer block (STB) in place of convolution for global feature modeling. Based on STB, we further propose the self-attention feature distillation block (SFDB) for efficient feature extraction. Furthermore, to increase the depth of the network at a small computational cost and thus improve its performance, we propose a novel deep feature fusion group (DFFG) for feature fusion. Experimental results show that this method achieves a better trade-off between peak signal-to-noise ratio (PSNR) and computational overhead than existing super-resolution algorithms.
Keywords: Efficient super-resolution networks · Transformer · Convolutional neural network

1 Introduction
Single image super-resolution (SISR) is a fundamental computer vision task that aims to reconstruct high-resolution (HR) images from the corresponding low-resolution (LR) images. SISR is a severely ill-posed problem, and a large number of deep neural networks have been proposed to solve it [11–13, 20, 26, 27]. Despite their excellent performance, they are unsuitable for real-world scenarios because of their sizeable computational overhead and high complexity. To meet the needs of real scenarios, building lightweight networks has become a recognized direction for SISR [1, 6, 9, 10, 14]. Lightweight CNNs achieve good performance and are practical for real scenarios, but CNNs are weak at both capturing global information and modeling long-range dependencies.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 142–152, 2023.
https://doi.org/10.1007/978-981-99-1639-9_12
AFFSRN
The Transformer [22], originally a natural language processing model, has been widely adopted in computer vision to better capture global information. Transformer models [3, 24] have achieved good results on SISR, but they also have large numbers of parameters and high computational cost. To solve these problems, we propose an attention-based feature fusion super-resolution network (AFFSRN) that models long-range dependencies while reducing computational complexity. Specifically, we propose the STB, built on a shifted window design, as the basic building block for capturing global information. The shifted window mechanism in STB enables long-range dependency modeling while saving computation. For efficient feature extraction, we propose SFDB, a hybrid architecture that combines CNN and Transformer: the STB is the main feature extraction block in SFDB and is combined with feature distillation connections (FDC) to perform feature distillation. In addition, we propose a deep feature fusion structure consisting of several SFDBs, which fuses features progressively. The main contributions of this article are as follows:
(1) To achieve efficient feature extraction while controlling the number of parameters and the amount of computation, we propose SFDB, which consists of STB for long-range dependency modeling to capture global features and FDC for local features.
(2) To increase the depth of the network while controlling the amount of computation, and thus improve its reconstruction capability, we propose DFFG for efficient feature fusion.
(3) We evaluate the proposed model on several benchmark datasets. Experimental results show that it achieves a better balance between performance and model complexity than other SISR and Transformer models.
The rest of this paper is organized as follows. Section 2 reviews work related to our approach. Section 3 explains the proposed method. Section 4 details the experiments and results. Finally, Sect. 5 concludes the paper.
2 Related Work
Many CNN-based models have been proposed for SISR in recent years. Dong et al. [5] proposed SRCNN, an end-to-end three-layer convolutional neural network and the pioneering work in using deep learning for SISR. SAN [4] proposed second-order attention networks to achieve stronger feature representation and feature correlation learning. Building on IDN [10], a faster and lighter Information Multi-distillation Network (IMDN) [9] was also proposed. However, CNNs mainly extract local features and struggle to learn global information, which hinders the recovery of texture details. The pioneering Vision Transformer (ViT) [7] directly applies the Transformer architecture to
image classification. IPT [3] used a novel Transformer-based network as a pretrained model for low-level image restoration tasks. ACT [24] proposed a two-way structure consisting of Transformer and convolutional branches. However, these Transformer models are very computationally intensive. In this paper, we propose SFDB, which combines CNN and Transformer to achieve efficient SR.
3 Method
3.1 Overview
Fig. 1. Proposed AFFSRN architecture
As shown in Fig. 1, the proposed structure can be divided into three stages: shallow feature extraction, deep feature extraction, and reconstruction. In the shallow feature extraction stage, we first apply a convolution to extract shallow features from the input LR image, which stabilizes training and improves extraction accuracy. We then use the DFFG module for deep feature extraction and fusion; its details are given in Sect. 3.2. In the final reconstruction stage, we use two convolutional (Conv) layers with different filter sizes to extract multi-scale features and multiple sub-pixel layers to up-sample the LR image features.
3.2 Deep Feature Fusion Group
Our proposed DFFG contains n SFDBs; the SFDB is described in Sect. 3.3. To fuse the SFDB outputs effectively, we propose an improved version of binarized feature fusion (BFF) [19], as shown in Fig. 2. BFF only connects adjacent blocks and leaves almost no connections between non-adjacent blocks, which hinders the extraction of richer features. We call the improved BFF I-BFF. I-BFF connects each generated feature block to the next SFDB, uses channel shuffling to mix information between groups and thus improve the feature extraction capability of the network, and finally uses adaptive residual learning to output features. To further increase the performance of the network, we
apply residual connections to more SFDBs, which improves SR performance without introducing additional parameters. The proposed DFFG increases the depth of the network while controlling the amount of computation and effectively prevents important information from vanishing as it flows through the network.
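Channel shuffling, as used in I-BFF, can be illustrated on channel labels alone. The sketch below assumes the ShuffleNet-style operation (view the channels as groups, transpose, flatten) and assumes two groups formed by concatenating two block outputs; the paper does not state the group count, and the adaptive residual learning step is not shown.

```python
from typing import List, TypeVar

T = TypeVar("T")


def channel_shuffle(channels: List[T], groups: int) -> List[T]:
    """ShuffleNet-style shuffle: view channels as [groups][per_group], transpose, flatten."""
    if len(channels) % groups != 0:
        raise ValueError("number of channels must be divisible by groups")
    per_group = len(channels) // groups
    return [channels[g * per_group + i] for i in range(per_group) for g in range(groups)]


# Two feature blocks with 3 channels each, concatenated along the channel axis.
prev_block = ["a0", "a1", "a2"]
next_block = ["b0", "b1", "b2"]
fused = channel_shuffle(prev_block + next_block, groups=2)
print(fused)  # ['a0', 'b0', 'a1', 'b1', 'a2', 'b2']
```

After the shuffle, any subsequent grouped operation sees channels from both blocks, which is how information is mixed between groups.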
Fig. 2. Deep Feature Fusion Group
3.3 Self-attention Feature Distillation Block
Inspired by RFDN [14], we propose a new SFDB module that uses the Swin Transformer block (STB) in place of the 3 × 3 convolution for feature extraction. This makes better use of local features while establishing long-range dependencies. At each step, a channel splitting operation divides the feature into two parts: one part is retained, and the other is passed to the next stage of distillation. As shown in Fig. 3(a), an STB followed by a channel splitting layer can be decoupled into a 1 × 1 convolution and an STB, where the 1 × 1 convolution reduces the channel dimension, which cuts parameters while maintaining high efficiency. Given the input feature $F_{in}$, this process can be described as:

$$F_{\mathrm{distilled}_1}, F_{\mathrm{coarse}_1} = \mathrm{Split}_1\big(L_1(F_{in})\big) \tag{1}$$

$$F_{\mathrm{distilled}_2}, F_{\mathrm{coarse}_2} = \mathrm{Split}_2\big(L_2(F_{\mathrm{coarse}_1})\big) \tag{2}$$

$$F_{\mathrm{distilled}_3}, F_{\mathrm{coarse}_3} = \mathrm{Split}_3\big(L_3(F_{\mathrm{coarse}_2})\big) \tag{3}$$
$$F_{\mathrm{distilled}_4} = \mathrm{Conv}_{3\times3}(F_{\mathrm{coarse}_3}) \tag{4}$$
where $L_n$ denotes the $n$-th STB, $\mathrm{Split}_j$ the $j$-th channel splitting operation, $F_{\mathrm{distilled}_j}$ the $j$-th distilled feature, and $F_{\mathrm{coarse}_j}$ the $j$-th coarse feature, which is further processed by subsequent layers. The extracted features are then concatenated as the output of the progressive refinement module (PRM) (gray background in Fig. 3(a)):

$$F_{out} = H_{\mathrm{linear}}\big(\mathrm{Concat}(F_{\mathrm{distilled}_1}, F_{\mathrm{distilled}_2}, F_{\mathrm{distilled}_3}, F_{\mathrm{distilled}_4})\big) \tag{5}$$
where Concat is concatenation along the channel dimension, $H_{\mathrm{linear}}$ denotes the 1 × 1 convolution, and $F_{out}$ is the aggregated feature.
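The channel bookkeeping of Eqs. (1) to (5) can be traced with the decoupled form described above: a 1 × 1 convolution produces each distilled feature, while the STB carries the coarse feature forward. In this sketch the STB is a hypothetical identity placeholder, the 3 × 3 convolution of Eq. (4) is replaced by a 1 × 1 convolution for brevity, and the distilled width of C/2 is an assumption borrowed from RFDN's default; C = 48 follows the channel number given in Sect. 4.1.

```python
import random
from typing import Callable, List

Feature = List[List[float]]  # [channels][pixels], spatial dimensions flattened


def conv1x1(x: Feature, out_ch: int, rng: random.Random) -> Feature:
    """1x1 convolution: one linear map over channels applied at every pixel (random weights)."""
    W = [[rng.uniform(-0.1, 0.1) for _ in range(len(x))] for _ in range(out_ch)]
    pixels = len(x[0])
    return [[sum(W[o][c] * x[c][p] for c in range(len(x))) for p in range(pixels)]
            for o in range(out_ch)]


def placeholder_stb(x: Feature) -> Feature:
    """Hypothetical stand-in for an STB; like a real STB, it preserves the channel count."""
    return [row[:] for row in x]


def sfdb(f_in: Feature, distill_ch: int, rng: random.Random,
         stb: Callable[[Feature], Feature] = placeholder_stb) -> Feature:
    distilled: List[Feature] = []
    coarse = f_in
    for _ in range(3):  # Eqs. (1)-(3) in decoupled form
        distilled.append(conv1x1(coarse, distill_ch, rng))  # distilled branch (1x1 conv)
        coarse = stb(coarse)                                 # coarse branch (STB)
    # Eq. (4) uses a 3x3 conv; a 1x1 conv stands in here to keep the sketch short.
    distilled.append(conv1x1(coarse, distill_ch, rng))
    concat = [ch for part in distilled for ch in part]  # Concat along channels: 4 * distill_ch
    return conv1x1(concat, len(f_in), rng)              # H_linear back to C channels, Eq. (5)


rng = random.Random(0)
f_in = [[rng.random() for _ in range(4)] for _ in range(48)]  # C = 48 channels, 2x2 pixels
f_out = sfdb(f_in, distill_ch=24, rng=rng)
print(len(f_out), len(f_out[0]))  # 48 4
```

The concatenated tensor has 4 × 24 = 96 channels, and $H_{\mathrm{linear}}$ maps it back to the 48 input channels so that SFDBs can be stacked.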
Fig. 3. The architecture of (a) the self-attention feature distillation block (SFDB) and (b) the Swin Transformer block (STB).
Swin Transformer Block (STB). As shown in Fig. 3(b), the STB is composed of a Swin Transformer layer (STL) [15] and a convolutional layer. The STL replaces the standard multi-head self-attention (MSA) module of the Transformer block with a shifted-window-based MSA, keeping the other layers the same. The shifted window mechanism in STL enables long-range dependency modeling. The STL consists of a shifted-window-based MSA module followed by a 2-layer MLP with a GELU non-linearity in between, which is used for
further feature transformation. A LayerNorm (LN) layer is applied before each MSA module and MLP, and a residual connection is employed for both modules. Let $f_0$ denote the input feature of the STB. The procedure can be expressed as

$$f_{out} = f_{conv}\big(f_{STL}(f_0)\big) \tag{6}$$

where $f_{STL}(\cdot)$ is the STL in the STB and $f_{conv}(\cdot)$ is the convolutional layer added at the end of the STB to enhance the features; $f_{out}$ is the output of the STB.
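The cyclic shift that lets shifted windows connect neighboring windows can be shown on a 4 × 4 grid of token ids. This sketch assumes a window size of 2 and a shift of half the window, as in the Swin Transformer; it omits the attention computation itself and the attention mask that Swin applies to wrapped-around regions.

```python
from typing import List

Grid = List[List[int]]


def cyclic_shift(x: Grid, shift: int) -> Grid:
    """Roll the grid up and left by `shift`, as Swin does before shifted-window attention."""
    h, w = len(x), len(x[0])
    return [[x[(i + shift) % h][(j + shift) % w] for j in range(w)] for i in range(h)]


def window_partition(x: Grid, m: int) -> List[List[int]]:
    """Split an H x W grid into non-overlapping m x m windows (row-major order)."""
    h, w = len(x), len(x[0])
    if h % m or w % m:
        raise ValueError("grid size must be divisible by the window size")
    return [[x[i][j] for i in range(r, r + m) for j in range(c, c + m)]
            for r in range(0, h, m) for c in range(0, w, m)]


tokens = [[4 * i + j for j in range(4)] for i in range(4)]  # token ids 0..15 on a 4x4 map
regular = window_partition(tokens, 2)
shifted = window_partition(cyclic_shift(tokens, 1), 2)
print(regular[0])  # [0, 1, 4, 5]
print(shifted[0])  # [5, 6, 9, 10]
```

Tokens 5, 6, 9, and 10 come from four different regular windows, so attention inside this shifted window exchanges information across the previous window boundaries at the same per-window cost.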
4 Experiments
4.1 Training Details
In the experiments, we use 800 high-resolution images from the DIV2K [21] dataset to train the proposed model. The Set5 [2], Set14 [25], BSD100 [17], Urban100 [8], and Manga109 [18] datasets are used for testing. The evaluation criteria are PSNR and the structural similarity index measure (SSIM). To compare the different SISR methods, PSNR and SSIM are computed on the Y channel of the YCbCr color space of the generated images. The training images are randomly cropped into 64 patches of size 64, 192, and 256 for the ×2, ×3, and ×4 models, respectively. We use the L1 loss as the training loss and the Adam optimizer with β1 = 0.9 and β2 = 0.99 to optimize the network parameters, with an initial learning rate of 5 × 10⁻⁴. During training, we use a cosine annealing strategy to adjust the learning rate. AFFSRN uses 48 channels to achieve better reconstruction quality.
4.2 Comparisons with State-of-the-Art Methods
We compare the proposed AFFSRN with Bicubic interpolation and existing state-of-the-art lightweight SR methods, including VDSR [11], LapSRN [13], IDN [10], CARN [1], IMDN [9], SMSR [23], RFDN [14], and ESRT [16]. PSNR and SSIM are the standard indicators of reconstruction quality: the higher their values, the better the SISR reconstruction. The quantitative results for ×2, ×3, and ×4 are shown in Table 2. AFFSRN obtains the highest PSNR and SSIM on all five benchmark datasets at all three scales (its PSNR ties that of ESRT on BSD100 at ×4), while using fewer parameters than CARN and SMSR. We attribute this mainly to the ability of AFFSRN to extract and fuse features more efficiently. In addition, as shown in Fig. 4, AFFSRN reconstructs sharper images whose high-frequency textures and edge details are closer to the HR image.
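For readers reproducing the Y-channel evaluation described in Sect. 4.1, a minimal sketch follows. It assumes the ITU-R BT.601 luma conversion used by MATLAB's `rgb2ycbcr` (the convention of most SISR benchmark code), 8-bit RGB pixels stored as nested Python lists, and a border crop of `scale` pixels; the crop and the function names are assumptions rather than details stated in this paper, and SSIM is omitted.

```python
import math
from typing import List, Tuple

Pixel = Tuple[float, float, float]
Image = List[List[Pixel]]


def rgb_to_y(pixel: Pixel) -> float:
    """BT.601 luma for 8-bit RGB input; the result lies in [16, 235]."""
    r, g, b = pixel
    return 16.0 + (65.481 * r + 128.553 * g + 24.966 * b) / 255.0


def psnr_y(sr: Image, hr: Image, scale: int) -> float:
    """PSNR on the Y channel, ignoring a border of `scale` pixels on every side (assumed)."""
    h, w = len(hr), len(hr[0])
    squared_error, count = 0.0, 0
    for i in range(scale, h - scale):
        for j in range(scale, w - scale):
            diff = rgb_to_y(sr[i][j]) - rgb_to_y(hr[i][j])
            squared_error += diff * diff
            count += 1
    mse = squared_error / count
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(255.0 ** 2 / mse)
```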
Y. Qin et al.
Fig. 4. Visual comparison of SR performance with state-of-the-art SR methods.
Table 1. Effects of DFFG, STB, and FDC modules. Experiments are performed on Set5 (×4) and Set14 (×4).

| DFFG | STB | FDC | #Params | Set5 (×4) PSNR/SSIM | Set14 (×4) PSNR/SSIM |
|---|---|---|---|---|---|
| × | ✓ | ✓ | 878K | 32.28/0.8957 | 28.70/0.7836 |
| ✓ | × | ✓ | 755K | 32.04/0.8931 | 28.55/0.7802 |
| ✓ | ✓ | × | 892K | 32.27/0.8953 | 28.67/0.7830 |
| ✓ | ✓ | ✓ | 892K | 32.36/0.8967 | 28.78/0.7854 |

4.3
Ablation Studies
To evaluate the effectiveness of DFFG, STB, and FDC, we performed a series of ablation experiments with three variants of AFFSRN: AFFSRN without DFFG, AFFSRN without STB, and AFFSRN without FDC. As shown in Table 1, without STB the PSNR of AFFSRN on Set5 (×4) drops to 32.04 dB, which shows that the proposed STB extracts more useful features. Table 1 also shows that DFFG improves the PSNR by 0.08 dB at the cost of only 14K additional parameters. Similarly, FDC improves the PSNR by 0.09 dB without adding any parameters.
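The per-module figures quoted above follow directly from Table 1 by differencing each ablated variant against the full model. The short script below repeats that arithmetic with the Set5 (×4) values copied from the table; the dictionary layout and the function name are illustrative choices.

```python
from typing import Tuple

# Set5 (x4) results from Table 1: variant -> (parameters in thousands, PSNR in dB)
TABLE_1 = {
    "w/o DFFG": (878, 32.28),
    "w/o STB": (755, 32.04),
    "w/o FDC": (892, 32.27),
    "full": (892, 32.36),
}


def module_contribution(variant: str) -> Tuple[int, float]:
    """Return (extra parameters in K, PSNR gain in dB) of the full model over `variant`."""
    full_params, full_psnr = TABLE_1["full"]
    params, psnr = TABLE_1[variant]
    return full_params - params, round(full_psnr - psnr, 2)


for name in ("w/o DFFG", "w/o STB", "w/o FDC"):
    extra_k, gain_db = module_contribution(name)
    print(f"{name}: +{extra_k}K params, +{gain_db:.2f} dB")
# w/o DFFG: +14K params, +0.08 dB
# w/o STB: +137K params, +0.32 dB
# w/o FDC: +0K params, +0.09 dB
```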
Table 2. The average PSNR and SSIM of different SISR methods on Set5, Set14, BSD100, Urban100, and Manga109.

| Method | Scale | #Params | Set5 PSNR/SSIM | Set14 PSNR/SSIM | BSD100 PSNR/SSIM | Urban100 PSNR/SSIM | Manga109 PSNR/SSIM |
|---|---|---|---|---|---|---|---|
| Bicubic | ×2 | – | 30.66/0.9299 | 30.24/0.8688 | 29.56/0.8431 | 26.88/0.8403 | 30.80/0.9339 |
| VDSR | ×2 | 666K | 37.53/0.9587 | 33.03/0.9124 | 31.90/0.8960 | 30.76/0.9140 | 37.22/0.9750 |
| LapSRN | ×2 | 502K | 37.52/0.9591 | 32.99/0.9124 | 31.80/0.8952 | 30.41/0.9103 | 37.27/0.9740 |
| IDN | ×2 | 553K | 37.83/0.9600 | 33.30/0.9148 | 32.08/0.8985 | 31.27/0.9196 | 38.01/0.9749 |
| CARN | ×2 | 1592K | 37.76/0.9590 | 33.52/0.9166 | 32.09/0.8978 | 31.92/0.9256 | 38.36/0.9765 |
| IMDN | ×2 | 715K | 38.00/0.9605 | 33.63/0.9177 | 32.19/0.8996 | 32.17/0.9283 | 38.88/0.9774 |
| SMSR | ×2 | 985K | 38.00/0.9601 | 33.64/0.9179 | 32.17/0.8990 | 32.19/0.9284 | 38.76/0.9771 |
| RFDN | ×2 | 626K | 38.08/0.9606 | 33.67/0.9190 | 32.18/0.8996 | 32.24/0.9290 | 38.95/0.9773 |
| ESRT | ×2 | 677K | 38.03/0.9600 | 33.75/0.9184 | 32.25/0.9001 | 32.58/0.9318 | 39.12/0.9774 |
| AFFSRN | ×2 | 833K | 38.12/0.9609 | 33.79/0.9199 | 32.30/0.9015 | 32.64/0.9335 | 39.22/0.9784 |
| Bicubic | ×3 | – | 30.39/0.8682 | 27.55/0.7742 | 27.21/0.7385 | 24.46/0.7349 | 26.95/0.8556 |
| VDSR | ×3 | 666K | 33.66/0.9213 | 29.77/0.8314 | 28.82/0.7976 | 27.14/0.8279 | 32.01/0.9340 |
| LapSRN | ×3 | 502K | 33.81/0.9220 | 29.79/0.8325 | 28.82/0.7980 | 27.07/0.8275 | 32.21/0.9350 |
| IDN | ×3 | 553K | 34.11/0.9253 | 29.99/0.8354 | 28.95/0.8013 | 27.42/0.8359 | 32.71/0.9381 |
| CARN | ×3 | 1592K | 34.29/0.9255 | 30.29/0.8407 | 29.06/0.8034 | 28.06/0.8493 | 33.50/0.9440 |
| IMDN | ×3 | 715K | 34.36/0.9270 | 30.32/0.8417 | 29.09/0.8046 | 28.17/0.8519 | 33.61/0.9445 |
| SMSR | ×3 | 993K | 34.40/0.9270 | 30.33/0.8412 | 29.10/0.8050 | 28.25/0.8536 | 33.68/0.9445 |
| RFDN | ×3 | 643K | 34.47/0.9280 | 30.35/0.8421 | 29.11/0.8053 | 28.32/0.8547 | 33.78/0.9458 |
| ESRT | ×3 | 770K | 34.42/0.9268 | 30.43/0.8433 | 29.15/0.8063 | 28.46/0.8574 | 33.95/0.9455 |
| AFFSRN | ×3 | 857K | 34.58/0.9286 | 30.48/0.8446 | 29.21/0.8086 | 28.56/0.8613 | 34.12/0.9478 |
| Bicubic | ×4 | – | 28.42/0.8104 | 26.00/0.7027 | 25.96/0.6675 | 23.14/0.6577 | 24.89/0.7866 |
| VDSR | ×4 | 666K | 31.35/0.8838 | 28.01/0.7674 | 27.29/0.7251 | 25.18/0.7524 | 28.83/0.8870 |
| LapSRN | ×4 | 502K | 31.54/0.8852 | 28.09/0.7700 | 27.32/0.7275 | 25.21/0.7562 | 29.09/0.8900 |
| IDN | ×4 | 553K | 31.82/0.8903 | 28.25/0.7730 | 27.41/0.7297 | 25.41/0.7632 | 29.41/0.8942 |
| CARN | ×4 | 1592K | 32.13/0.8937 | 28.60/0.7806 | 27.58/0.7349 | 26.07/0.7837 | 30.47/0.9084 |
| IMDN | ×4 | 715K | 32.21/0.8948 | 28.58/0.7811 | 27.56/0.7353 | 26.04/0.7838 | 30.45/0.9075 |
| SMSR | ×4 | 1006K | 32.12/0.8932 | 28.55/0.7808 | 27.55/0.7351 | 26.11/0.7868 | 30.54/0.9085 |
| RFDN | ×4 | 643K | 32.28/0.8957 | 28.61/0.7818 | 27.58/0.7363 | 26.20/0.7883 | 30.61/0.9096 |
| ESRT | ×4 | 751K | 32.19/0.8947 | 28.69/0.7833 | 27.69/0.7379 | 26.39/0.7962 | 30.75/0.9100 |
| AFFSRN | ×4 | 892K | 32.36/0.8967 | 28.78/0.7854 | 27.69/0.7404 | 26.47/0.7983 | 31.07/0.9149 |
4.4
Model Complexity Analysis
We also compare PSNR against Mult-Adds on Set5 (×4) in Fig. 5. AFFSRN achieves a higher PSNR than VDSR, LapSRN, IDN, CARN, IMDN, SMSR, RFDN, and ESRT while requiring fewer Mult-Adds. We attribute this to the proposed STB and DFFG, which extract useful features and fuse them efficiently. Overall, the proposed method achieves the best trade-off between reconstruction performance and computational cost among the compared methods.
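Mult-Adds figures are usually obtained by summing the multiply-accumulate operations of every layer; for a k×k convolution that maps C_in to C_out channels on an H×W feature map (stride 1, same padding, bias ignored), this count is H·W·C_in·C_out·k². The sketch below applies that formula to a single layer. The 1280×720 HR target (hence a 320×180 input at ×4) is a convention common in lightweight-SR papers and is an assumption here, as is the 3×3, 48→48-channel layer, chosen only because AFFSRN uses 48 channels.

```python
def conv_mult_adds(h: int, w: int, c_in: int, c_out: int, k: int) -> int:
    """Multiply-accumulate operations of one k x k convolution (stride 1, same padding, no bias)."""
    return h * w * c_in * c_out * k * k


# Hypothetical setting: x4 SR towards a 1280x720 image, so the LR feature map is 320x180.
SCALE = 4
H, W = 720 // SCALE, 1280 // SCALE
ops = conv_mult_adds(H, W, c_in=48, c_out=48, k=3)
print(f"{ops / 1e9:.2f} G Mult-Adds for one 3x3, 48->48 conv on a {W}x{H} map")
# 1.19 G Mult-Adds for one 3x3, 48->48 conv on a 320x180 map
```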
Fig. 5. PSNR vs. Mult-Adds.
5
Conclusion
In this paper, we propose a lightweight network for single-image super-resolution called AFFSRN. The structural design of AFFSRN is inspired by RFDN and the Swin Transformer. We adopt an architecture similar to that of RFDN, but replace the shallow residual blocks (SRB) of RFDN with the proposed STB, which captures global information through a self-attention mechanism. In addition, we propose DFFG for feature fusion, which improves reconstruction performance while keeping the additional computation small. Experiments demonstrate that the proposed method, with a modest number of parameters, achieves a better trade-off between performance and computational complexity.
References
1. Ahn, N., Kang, B., Sohn, K.-A.: Fast, accurate, and lightweight super-resolution with cascading residual network. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11214, pp. 256–272. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01249-6_16
2. Bevilacqua, M., Roumy, A., Guillemot, C., Alberi-Morel, M.L.: Low-complexity single-image super-resolution based on nonnegative neighbor embedding (2012)
3. Chen, H., et al.: Pre-trained image processing transformer. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12299–12310 (2021)
4. Dai, T., Cai, J., Zhang, Y., Xia, S.T., Zhang, L.: Second-order attention network for single image super-resolution. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11065–11074 (2019)
5. Dong, C., Loy, C.C., He, K., Tang, X.: Image super-resolution using deep convolutional networks. IEEE Trans. Pattern Anal. Mach. Intell. 38(2), 295–307 (2015)
6. Dong, C., Loy, C.C., Tang, X.: Accelerating the super-resolution convolutional neural network. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9906, pp. 391–407. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46475-6_25
7. Dosovitskiy, A., et al.: An image is worth 16 × 16 words: transformers for image recognition at scale. arXiv preprint arXiv:2010.11929 (2020)
8. Huang, J.B., Singh, A., Ahuja, N.: Single image super-resolution from transformed self-exemplars. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5197–5206 (2015)
9. Hui, Z., Gao, X., Yang, Y., Wang, X.: Lightweight image super-resolution with information multi-distillation network. In: Proceedings of the 27th ACM International Conference on Multimedia, pp. 2024–2032 (2019)
10. Hui, Z., Wang, X., Gao, X.: Fast and accurate single image super-resolution via information distillation network. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 723–731 (2018)
11. Kim, J., Lee, J.K., Lee, K.M.: Accurate image super-resolution using very deep convolutional networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1646–1654 (2016)
12. Kim, J., Lee, J.K., Lee, K.M.: Deeply-recursive convolutional network for image super-resolution. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1637–1645 (2016)
13. Lai, W.S., Huang, J.B., Ahuja, N., Yang, M.H.: Deep Laplacian pyramid networks for fast and accurate super-resolution. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 624–632 (2017)
14. Liu, J., Tang, J., Wu, G.: Residual feature distillation network for lightweight image super-resolution. In: Bartoli, A., Fusiello, A. (eds.) ECCV 2020. LNCS, vol. 12537, pp. 41–55. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-67070-2_2
15. Liu, Z., et al.: Swin transformer: hierarchical vision transformer using shifted windows. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10012–10022 (2021)
16. Lu, Z., Liu, H., Li, J., Zhang, L.: Efficient transformer for single image super-resolution. arXiv preprint arXiv:2108.11084 (2021)
17. Martin, D., Fowlkes, C., Tal, D., Malik, J.: A database of human segmented natural images and its application to evaluating segmentation algorithms and measuring ecological statistics. In: Proceedings Eighth IEEE International Conference on Computer Vision, ICCV 2001, vol. 2, pp. 416–423. IEEE (2001)
18. Matsui, Y., et al.: Sketch-based manga retrieval using manga109 dataset. Multimed. Tools Appl. 76(20), 21811–21838 (2017). https://doi.org/10.1007/s11042-016-4020-z
19. Ren, H., El-Khamy, M., Lee, J.: Image super resolution based on fusing multiple convolution neural networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops, pp. 54–61 (2017)
20. Tai, Y., Yang, J., Liu, X.: Image super-resolution via deep recursive residual network. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3147–3155 (2017)
21. Timofte, R., Agustsson, E., Van Gool, L., Yang, M.H., Zhang, L.: NTIRE 2017 challenge on single image super-resolution: methods and results. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops, pp. 114–125 (2017)
22. Vaswani, A., et al.: Attention is all you need. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
23. Wang, L., et al.: Exploring sparsity in image super-resolution for efficient inference. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4917–4926 (2021)
24. Yoo, J., Kim, T., Lee, S., Kim, S.H., Lee, H., Kim, T.H.: Rich CNN-transformer feature aggregation networks for super-resolution. arXiv preprint arXiv:2203.07682 (2022)
25. Zeyde, R., Elad, M., Protter, M.: On single image scale-up using sparse-representations. In: Boissonnat, J.-D., et al. (eds.) Curves and Surfaces 2010. LNCS, vol. 6920, pp. 711–730. Springer, Heidelberg (2012). https://doi.org/10.1007/978-3-642-27413-8_47
26. Zhao, M., Liu, X., Liu, H., Wong, K.: Super-resolution of cardiac magnetic resonance images using Laplacian Pyramid based on Generative Adversarial Networks. Comput. Med. Imaging Graph. 80, 101698 (2020)
27. Zhao, M., Liu, X., Yao, X., He, K.: Better visual image super-resolution with Laplacian pyramid of generative adversarial networks. Comput. Mater. Continua 64(3), 1601–1614 (2020)
Temporal-Sequential Learning with Columnar-Structured Spiking Neural Networks
Xiaoling Luo, Hanwen Liu, Yi Chen, Malu Zhang, and Hong Qu(✉)
University of Electronic Science and Technology of China, Chengdu 611731, China
Abstract. Humans can memorize complex temporal sequences, such as music, which indicates that the brain has a mechanism for storing the time intervals between elements. However, most existing sequential memory models can only handle sequences that carry no temporal information between elements, such as sentences. In this paper, we propose a columnar-structured model that can memorize sequences with variable time intervals. Each column is composed of several spiking neurons with dendritic structures and synaptic delays. Dendrites allow a neuron to represent the same element in different contexts, while transmission delays between spiking neurons preserve the time intervals between sequence elements. Moreover, the proposed model can remember sequence information after a single presentation of a new input sample, i.e., it is capable of one-shot learning. Experimental results demonstrate that the proposed model can memorize complex temporal sequences such as musical pieces and recall an entire sequence with high accuracy given an extremely short sub-sequence. Its significance lies not only in its superiority over comparable methods, but also in providing a reference for the development of neuromorphic memory systems.
Keywords: Spiking neural networks · Sequential memory · Synaptic delays · Musical learning
1
Introduction
Much of human daily life and knowledge acquisition involves images, sounds, words, and so on, most of which take the form of ordered sequences. The human brain is adept at learning and storing sequences and at playing them back when necessary. Over the decades, various models have been proposed to simulate the sequential memory process of the brain and to adapt to different real-world tasks. The recurrent neural network (RNN) [4,12], designed for sequence learning, connects nodes (neurons) to form closed loops. RNNs and their derivatives, long short-term memory (LSTM) [6] and the gated recurrent unit (GRU) [2], have been widely used in text generation, machine translation, image captioning, and other scenarios [10,16,19]. However, they are limited in their ability to store information over long periods. Learning their parameters is also time-consuming and completely unlike learning in the brain, which usually needs only a few exposures, or even a single one (one-shot learning).
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 153–164, 2023. https://doi.org/10.1007/978-981-99-1639-9_13
Brain-inspired models were then proposed. The active neuro-associative knowledge graph (ANAKG) [7] is an episodic memory model that needs only one injection of input to store a sequence. However, frequently repeated elements disorder the retrieved sequence. The synaptic delay associative knowledge graph (SDAKG) [15] eliminates this disorder by adding synaptic gates, but at the expense of a huge number of gates. The lumped minicolumn associative knowledge graph (LUMAKG) [1] increases memory capacity by reusing neurons and connections. This idea had in fact already appeared in the hierarchical temporal memory (HTM) model [5], which is suitable for sequence memory, prediction, and anomaly detection [3]. There are also algorithms based on neural ensembles [8] and Bayesian theory [17]. However, many of them [1,7,8,15] can only memorize short sequences, with lengths ranging from a few to tens of elements. More importantly, almost all of these models focus only on the order of sequence elements and treat the time intervals between elements as meaningless. In reality, however, some sequences have meaningful time intervals; in music, for example, both the notes and their durations matter. Little research has been done on this kind of sequential memory. In this article, we propose a novel columnar-structured network of neurons inspired by pyramidal cells that can naturally memorize such sequences. Compared with existing memory models, our model has the following advantages:
1. The proposed model introduces synaptic transmission delays between neurons to represent the time intervals between elements, which allows it to memorize sequences with meaningful time intervals using a single subnetwork.
2. The model can memorize very long sequences, with thousands of elements, and it works well with sequences such as music that contain many repeated elements and repeated subsequences.
3. The model does not need to see an input many times: it learns each sequence from a single injection, i.e., one-shot learning, and then recalls it with high quality.
2
Model Description
In this section, we introduce the components of the model and its sequence learning and retrieval processes.
2.1
Network Architecture
The model consists of a spatial-temporal subnetwork and a goal cluster. The former stores both sequence elements and time intervals, which avoids creating an additional subnetwork dedicated to memorizing time intervals, as in temporal-sequential learning (TSL) [11]. The goal cluster contains a group of goal neurons, analogous to the way the human brain stores title information.
Temporal-Sequential Learning with Columnar-Structured SNNs
Fig. 1. (a) Pyramidal cell (left) and the applied neuron model (right). The green lines represent synapses that are directly responsible for dendrite firing. (b) Dendritic potentials of the three distal dendrites in (a). (c) Neuronal Potential. The cell body of the neuron receives inputs from the proximal (red arrow) and distal dendrites (green arrow). (d) Schematic of network structure. (Color figure online)
Spatial-Temporal Subnetwork. Studies have shown that neurons in different regions of the cerebral cortex share almost the same columnar arrangement, and that neurons in a column share similar response properties [13]. For simplicity, a complex biological column is abstracted into a column of M identical neurons (Fig. 1(a)) that are sensitive to the same element. N such columns form the spatial-temporal subnetwork, shown at the bottom of Fig. 1(d). Sequence memories are then formed by establishing synaptic connections between these neurons: the chain of connected neurons represents the order of the sequence elements, while the delays of the connections represent the time intervals.
Goal Cluster. The goal cluster contains many neurons, each representing the "title" or "label" of a specific sequence. During learning, bidirectional connections are established between neurons in the spatial-temporal subnetwork and a goal neuron. A goal neuron fires when it is excited by several (usually K) consecutively firing neurons of the spatial-temporal subnetwork. During sequence retrieval, a mutual inhibition mechanism ensures that only one goal neuron fires. In turn, the firing goal neuron activates neurons in the spatial-temporal subnetwork to assist sequence retrieval.
2.2
Spiking Neuron Model Inspired by Pyramidal Cell
The spiking neuron model employed in the spatial-temporal subnetwork is inspired by HTM cells [5] and by pyramidal cells, which are widely present in the cerebral cortex [14]. The goal neuron, in contrast, is a simple non-leaky integrate-and-fire neuron whose computation follows [18,20]; here we describe only the neurons in the spatial-temporal subnetwork. As shown in Fig. 1(a), input from the lower layer is received through the proximal dendrite and is usually strong enough to make the neuron fire. Lateral input from the same layer is transferred to the cell body through the integration of distal dendrites. The cell body, which receives input from both proximal and distal dendrites, fires when its potential reaches the threshold, and the neuron is then considered truly activated. An activated neuron sends a spike through its axon, whose horizontal branches go to neurons in the same layer and whose vertical branches go to neurons in the goal cluster.
Distal dendrites, as semi-independent information processing units, are dynamically established during learning. They receive spikes from other neurons in the same layer and transmit a current to the cell body when their potentials cross the dendritic threshold, as shown in Fig. 1(a), (b). The set of presynaptic neurons connected to distal dendrite $k$ is defined as $\Gamma_k = \{j \mid j \text{ is presynaptic to } k\}$, $s_j$ is the spike of the $j$-th presynaptic neuron, and $d_j$ is the delay required for it to reach the dendrite (for ease of representation, the black arrows in Fig. 1(b) show the delayed presynaptic spikes). The potential $\tilde U^k$ of the dendrite is then:

$$
\begin{aligned}
\tilde U^k(t) &= \tilde\gamma \cdot \tilde U^k(t-\Delta t) + \sum_{j\in\Gamma_k} \operatorname{sign}\big(s_j(t-d_j)\big) - \tilde\vartheta \cdot \tilde s^k(t-\Delta t),\\
\tilde s^k(t) &= \begin{cases} 1, & \text{if } \tilde U^k(t) \ge \tilde\vartheta,\\ 0, & \text{otherwise}, \end{cases}
\end{aligned}
\tag{1}
$$
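where $\tilde s^k$ indicates whether the dendrite is activated, $\tilde\gamma$ (= 0.3) is the attenuation coefficient of dendrite potentials, and $\tilde\vartheta$ is the dendritic threshold. The cell body receives input currents from its proximal dendrite ($I$), its distal dendrites ($\tilde I$), and the goal cluster ($\hat I$). Let $D$ be the set of distal dendrites and $G$ the set of goal neurons connected to the neuron. The potential $U$ of the cell body is then:

$$
\begin{aligned}
U(t) &= \gamma U(t-\Delta t) + W_p I(t) + W_d \sum_{k\in D} w_k \tilde I^k(t) + W_g \sum_{k\in G} \hat I^k(t) - \vartheta\, s(t-\Delta t),\\
s(t) &= \begin{cases} U(t)/\vartheta, & \text{if } U(t) \ge \vartheta,\\ 0, & \text{otherwise}, \end{cases}
\end{aligned}
\tag{2}
$$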
where $W_p > W_g > W_d$ reflects the relative contributions of the three inputs to neuron firing, $w_k$ is the weight of the $k$-th distal dendrite, and $\vartheta$ is the firing threshold of the cell body. The output $s$ indicates whether the neuron is activated; unlike a conventional spiking neuron whose output is either 0 or 1, it represents the activation level of the neuron. $\gamma$ (= 0.3) is the potential attenuation coefficient. The current from the $k$-th distal dendrite is positively related to the dendrite's potential at the moment of excitation, i.e., $\tilde I^k(t) \propto \tilde U^k(t)$ when $\tilde s^k(t) = 1$; for convenience of calculation, it is set equal to that potential. The cell body emits at most one spike within a time window T, the maximum time interval between adjacent elements.
2.3
Learning
During learning, the model learns the sequence elements one by one. Within each period T, one neuron fires under the co-excitation of proximal and distal input and becomes the winner through competition. Lateral connections are then established between the winner and the other historical winners, and bidirectional connections with the corresponding goal neuron are also created.
Choice of Winner. When an element is injected as proximal input into all neurons of the column that prefers it, the neuron with the strongest activation becomes the winner and inhibits the other neurons from firing. The key to winning is having a distal dendrite that receives spikes from the historically winning neurons and emits a dendritic spike just as the proximal input arrives. According to Eq. 2, the potential of such a neuron is therefore larger than that of neurons without a dendritic spike, so it fires a larger spike. If more than one neuron fires with the largest spike (including the case where no neuron in the column has distal dendrites, e.g., right after the network is initialized), a random one among those with the fewest distal dendrites wins. W is a dynamic set of historically winning neurons that contains up to K recently fired neurons; when a new winner joins, the neuron with the earliest firing time is removed from the set.
Establishment of Connections. Lateral connections on distal dendrites link sequence elements and are thus the key to forming memories. Initially, the network is empty and has no lateral connections, i.e., each neuron in the spatial-temporal subnetwork has only a proximal dendrite. As learning progresses, the winning neurons form or update distal dendrites, on which they make lateral connections with other neurons.
– When a distal dendrite of the winner contributes directly to its firing, that is, when the timing of the dendritic spike coincides with the timing of the proximal input, the distal dendrite is "selected" (dendrite 1 in Fig. 1(a), (b)).
If the number of its valid synapses (those that contribute directly to its firing) is less than K, new connections are created on it from neurons in W that it is not yet connected to.
– When the neuron has no distal dendrites, or its distal dendrites transmit no valid spikes (an invalid spike is, e.g., the spike of dendrite 3 in Fig. 1), the neuron fires only through the stimulation of its proximal dendrite. A new distal dendrite is then established, on which synaptic connections are created with the historical winners in W (up to K connections).
A synapse on a distal dendrite has two properties: permanence p and delay d. For a newly established synapse, p = 0.5 and $d = t_{post} - t_{pre}$, where $t_{post}$ and $t_{pre}$ are the firing times of the post- and pre-synaptic neurons, respectively. On a selected dendrite, valid synapses (the green synapses on dendrite 1) are strengthened, p ← p + 0.02, and the other synapses (the black synapse on dendrite 1) are weakened, p ← p − 0.02; a synapse disappears when p = 0. In addition, the weight $w_k$ of a dendrite equals $1/n_k$, where $n_k$ is the number of synapses on the dendrite, but $w_k = 1/K$ when $n_k > K$.
Bidirectional connections with a goal neuron consist of feedforward connections to it and feedback connections from it. All feedforward connections have a fixed weight of 1/K and no delay. All feedback connections have a weight of 1, and their delay equals the difference between the firing times of the winner and the goal neuron.
Fig. 2. The process of storing sequence G1: {B, 270 ms, A, 330 ms, D, 230 ms, C}, beginning with an empty network. Lines represent synaptic connections, and the numbers on them represent synaptic delays.
Learning Process of a Sequence. We use a simple sequence to describe the learning process starting from an empty network: G1: {B, 270 ms, A, 330 ms, D, 230 ms, C}, where G1 is the goal, as shown in Fig. 2. During learning, the threshold $\vartheta$ equals $W_p$, which ensures that neurons can fire with proximal input alone.
– Step 1: The neuron G1 in the goal cluster is stimulated first and fires immediately at time 0. At the same time, the neurons in the column preferring B are excited by the proximal input, and one of them wins the competition. This neuron then extends a feedforward connection to the goal neuron G1, which can be regarded as a synapse on the vertical branch of its axon, and G1 extends a feedback connection back to it.
– Step 2: 270 ms later, a neuron in the column preferring A fires upon stimulation by A and becomes the winner. This winner not only connects with G1 but also generates a distal dendrite, on which a synapse with the previous winner is established. As stated above, the new synapse has an initial permanence of 0.5, and its delay equals the difference between the firing times of the postsynaptic and presynaptic neurons, i.e., 270 ms.
– Step 3: Likewise, after another 330 ms, a neuron in the D-preferring column fires at 600 ms and becomes the new winner. Connections with G1 are established, and a distal dendrite is generated, on which synapses to the previous two winners are built.
– Step 4: At 830 ms, a neuron in the C-preferring column fires. The same kinds of connections are generated, but its newly generated distal dendrite carries 3 connections. In general, connections are established with all neurons in W, i.e., the K most recent winners.
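To make the bookkeeping of one-shot learning concrete, the following Python sketch replays the G1 example above on an empty network. It is a simplified illustration, not the authors' implementation: K = 3 is assumed (enough for the three synapses of Step 4), each column holds only m = 2 neurons, the winner in every column is simply neuron 0 (valid here only because no neuron has distal dendrites yet), and the class and function names are hypothetical. Permanence updates on selected dendrites are not modeled.

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, List, Tuple

K = 3  # hypothetical history size |W|; the music experiments use K = 15


@dataclass
class Synapse:
    pre: str                  # id of the presynaptic neuron, e.g. "B0"
    delay: int                # d = t_post - t_pre, in ms
    permanence: float = 0.5   # initial permanence p of a new synapse


@dataclass
class Neuron:
    nid: str
    dendrites: List[List[Synapse]] = field(default_factory=list)  # distal dendrites


def learn_sequence(goal: str, seq: List[Tuple[str, int]], m: int = 2):
    """One-shot storage of seq = [(element, firing time in ms), ...] in an empty network.

    Assumptions of this sketch: the goal neuron fires at t = 0, and neuron 0 of each
    column wins; repeated elements would need the context-dependent winner choice.
    """
    columns: Dict[str, List[Neuron]] = {}
    history: Deque[Tuple[Neuron, int]] = deque(maxlen=K)  # W: the K most recent winners
    feedforward, feedback = [], []  # (source, target, weight, delay)
    for element, t in seq:
        column = columns.setdefault(element, [Neuron(f"{element}{i}") for i in range(m)])
        winner = column[0]
        if history:  # new distal dendrite with one synapse per neuron in W
            winner.dendrites.append([Synapse(pre.nid, t - t_pre) for pre, t_pre in history])
        feedforward.append((winner.nid, goal, 1.0 / K, 0))  # fixed weight 1/K, no delay
        feedback.append((goal, winner.nid, 1.0, t))         # weight 1, delay = t_winner - t_goal
        history.append((winner, t))  # the deque evicts the earliest winner once full
    return columns, feedforward, feedback


# G1 = {B, 270 ms, A, 330 ms, D, 230 ms, C}: firing times 0, 270, 600, 830 ms.
cols, ff, fb = learn_sequence("G1", [("B", 0), ("A", 270), ("D", 600), ("C", 830)])
for name in "ADC":
    print(name, [(s.pre, s.delay) for s in cols[name][0].dendrites[0]])
# A [('B0', 270)]
# D [('B0', 600), ('A0', 330)]
# C [('B0', 830), ('A0', 560), ('D0', 230)]
```

The delays printed for D and C are the 600/330 ms and 830/560/230 ms spans between each winner and its predecessors, which is how a single subnetwork stores both order and timing.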
Storing a new sequence in a non-empty network follows the same procedure; the difference is that some synapses are created on existing dendrites instead of on a newly created dendrite each time.
2.4
Retrieval
When we hear a short piece of music we have heard before, we may be able to recall what follows even if we do not know its name (retrieval without the guidance of the goal). And if we recall the name of the music, we are likely to recall the full piece from beginning to end (retrieval with the guidance of the goal). In TSL, the former is called context-based retrieval, while goal-based retrieval is retrieval based on given target information; both, however, actually depend on the goal. Here, we separately describe retrieval with and without goal information, simulating the way the human brain recalls.
As mentioned earlier, the threshold equals $W_p$ when there is proximal input in the network. Without proximal input, however, neurons cannot reach this threshold, so an adaptive threshold is needed: if goal information is present, the threshold drops below $W_g$ so that the goal information can make the neurons fire; if the network receives neither proximal input nor goal information, the threshold drops below $W_d$ so that neurons can fire when they receive enough distal input.
Three sequences are stored in a network: G1: {A, 120 ms, B, 120 ms, E}, G2: {A, 120 ms, B, 240 ms, C, 120 ms, D, 60 ms, E}, and G3: {B, 120 ms, C, 120 ms, E}, as shown in the upper-left corner of Fig. 3. The two retrieval processes are described below based on this network.
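A minimal discrete-time sketch of Eqs. (1) and (2), together with the adaptive threshold just described, may help make the neuron dynamics concrete. It assumes a fixed time step and hypothetical input weights W_p = 3, W_g = 2, W_d = 1 (the paper only states W_p > W_g > W_d); the threshold offset `eps` and all function names are likewise illustrative assumptions.

```python
from typing import Sequence, Tuple

GAMMA_D = 0.3  # dendritic attenuation coefficient (tilde gamma), value from the text
GAMMA = 0.3    # somatic attenuation coefficient (gamma), value from the text
W_P, W_G, W_D = 3.0, 2.0, 1.0  # hypothetical input weights with W_p > W_g > W_d


def dendrite_step(u_prev: float, s_prev: int, delayed_spikes: Sequence[float],
                  theta_d: float) -> Tuple[float, int]:
    """Eq. (1): one step of a distal dendrite.

    delayed_spikes[j] is s_j(t - d_j) for presynaptic neuron j in Gamma_k; sign() turns any
    graded (non-negative) spike into 1, and s_prev implements the reset term.
    """
    arrivals = sum(1 for s in delayed_spikes if s > 0)
    u = GAMMA_D * u_prev + arrivals - theta_d * s_prev
    return u, (1 if u >= theta_d else 0)


def soma_step(u_prev: float, s_prev: float, i_prox: float, dend_currents: Sequence[float],
              dend_weights: Sequence[float], goal_currents: Sequence[float],
              theta: float) -> Tuple[float, float]:
    """Eq. (2): one step of the cell body; the output s(t) = U(t)/theta is graded."""
    distal = sum(w * i for w, i in zip(dend_weights, dend_currents))
    u = (GAMMA * u_prev + W_P * i_prox + W_D * distal + W_G * sum(goal_currents)
         - theta * s_prev)
    return u, (u / theta if u >= theta else 0.0)


def adaptive_threshold(has_proximal: bool, has_goal: bool, eps: float = 0.1) -> float:
    """W_p with proximal input; otherwise just below W_g (goal present) or W_d."""
    if has_proximal:
        return W_P
    return (W_G if has_goal else W_D) - eps


# A dendrite with 3 synapses (so w_k = 1/3), two of whose delayed spikes coincide.
u_d, s_d = dendrite_step(0.0, 0, [1, 1, 0], theta_d=1.0)
current = u_d if s_d else 0.0  # dendritic current set equal to its potential when it fires
u, s = soma_step(0.0, 0.0, i_prox=1.0, dend_currents=[current], dend_weights=[1 / 3],
                 goal_currents=[], theta=adaptive_threshold(True, False))
print(u_d, s_d, round(u, 3), round(s, 3))
```

With proximal input, a neuron whose dendrite fired ends up above threshold by the weighted dendritic current, so its graded output s exceeds that of a neuron driven by proximal input alone (s = 1); this is the margin that decides the winner.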
Fig. 3. Retrieval (without the goal) with input {B, 120 ms, C}, and retrieval (with the goal) with input {B, 240 ms, C, 120 ms} in a network that stores three sequences (upper left). Some connections are not drawn.
Retrieval Without the Guidance of the Goal. In this case, apart from the first few hints (here, {B, 120 ms, C}), there is no proximal or feedback input during the subsequent retrieval.
– Step 1: When element B is input, all neurons in the column favoring B fire at 0 ms, driven by the proximal input. Unlike during learning, these neurons are all winners, and they transmit their spikes to other neurons (C3, D3, E3, E4).
– Step 2: After 120 ms, C is input. The temporal coincidence of proximal and distal input prompts C3 to fire a large spike at 120 ms. In contrast, D3, E3, and E4 have no proximal input, and the input from their distal dendrites has either not yet arrived or is insufficient to fire them. C3 becomes the winner. Its direct contributor B1 remains a historical winner, whereas the other neurons in column B are no longer winners because none of their postsynaptic neurons fired.
– Step 3: Without proximal input, one of the distal dendrites of E3 receives spikes from B1 and C3 simultaneously and fires at 240 ms, thereby activating E3, which becomes the winner. Neither D3 nor E4 receives enough spikes to fire. Retrieval then terminates because E3 has no postsynaptic neuron.
When there is no proximal input, detection of the next element depends entirely on the history elements. If multiple neurons fire because of distal input, the one with the largest spike wins. If more than one neuron ties for the largest spike, all of them transmit spikes onward, and more than one sequence is detected.
Retrieval with the Guidance of the Goal. Retrieval without goal guidance can only recall the elements that follow the hints. Goal information is necessary to recall the elements that precede the given hints (here, {B, 240 ms, C, 120 ms}).
– Step 1: When these cues are input, context-based retrieval is conducted first: neurons B4, C3, and D3 fire in sequence and transmit their spikes to the goal neurons connected to them.
– Step 2: G2 fires after receiving three spikes from the spatial-temporal subnetwork, while G3 fails to fire owing to insufficient excitation. G2 then inhibits all other goal neurons and dominates the subsequent retrieval.
– Step 3: After G2 fires, A3 fires immediately via the delay-free feedback connection from G2.
– Step 4: B4 fires under the combined excitation of the goal and A3. Similarly, C3, D3, and E4 are excited in turn.
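The goal-guided phase can be sketched in the same spirit. The code below assumes K = 3 (consistent with G2 firing after three spikes), treats goal neurons as non-leaky integrators with threshold 1 fed by feedforward weights 1/K, implements mutual inhibition as keeping only the most excited goal, and then lists the firing schedule imposed by the winning goal's feedback delays. The connectivity table (with G1's neuron ids invented for illustration) and all names are hypothetical, and the distal input that also contributes to each neuron's firing in Step 4 is ignored.

```python
from typing import Dict, List, Optional, Tuple

K = 3  # hypothetical: feedforward weight 1/K, so K coincident spikes fire a goal neuron

# Feedback connections per goal: (neuron, delay in ms). Ids for G2 and G3 follow the
# Fig. 3 walkthrough in the text; the ids for G1 are hypothetical.
GOALS: Dict[str, List[Tuple[str, int]]] = {
    "G1": [("A1", 0), ("B2", 120), ("E1", 240)],
    "G2": [("A3", 0), ("B4", 120), ("C3", 360), ("D3", 480), ("E4", 540)],
    "G3": [("B1", 0), ("C3", 120), ("E3", 240)],
}


def select_goal(fired: List[str]) -> Optional[str]:
    """Each goal sums 1/K per connected neuron that fired; mutual inhibition keeps only
    the most excited goal, and only if it reaches the threshold of 1."""
    potential = {g: sum(1.0 / K for n, _ in conns if n in fired) for g, conns in GOALS.items()}
    best = max(potential, key=potential.get)
    return best if potential[best] >= 1.0 - 1e-9 else None


def replay(goal: str) -> List[Tuple[int, str]]:
    """Firing schedule imposed by the goal's feedback connections (weight 1, delay = onset)."""
    return sorted((delay, neuron) for neuron, delay in GOALS[goal])


# Cues {B, 240 ms, C, 120 ms} make B4, C3 and D3 fire during the context-based phase.
goal = select_goal(["B4", "C3", "D3"])
if goal is not None:
    print(goal, replay(goal))
# G2 [(0, 'A3'), (120, 'B4'), (360, 'C3'), (480, 'D3'), (540, 'E4')]
```

G3 also receives a spike (from C3) but stays at 1/3 of its threshold, mirroring Step 2; the replayed schedule recovers element A, which precedes the cues.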
3
Experiment
We validate the model on a MIDI dataset containing 331 classical piano pieces [9]. Each piece consists of two tracks, each containing hundreds to thousands of notes. We build one spatial-temporal subnetwork for each track, and each subnetwork contains 89 columns (50 neurons per column), corresponding to the 88 standard pitches of piano keys (MIDI numbers 21–108) plus 1 terminator. The extra terminator ensures that the duration of the last note of a piece is remembered; TSL has only 88 columns because it remembers durations with a separate subnetwork. Note durations are mapped to 0–64 ms (i.e., T = 64), corresponding to MIDI tick counts from a demisemiquaver to two semibreves. In both retrieval experiments below, the number of historically winning neurons K is set to 15, and the dendritic firing threshold $\tilde\vartheta$ is set to 1.0.
Fig. 4. The retrieval accuracy of goal-based retrieval (a) and the number of correctly retrieved notes in context-based retrieval (b). The results of TSL are derived from [11].
3.1
Goal-Based Retrieval
For goal-based retrieval we use 50 pieces: 25 selected at random from the dataset and 25 identical to those used in TSL. The task is to retrieve the first 100 notes of a melody given its name. A note counts as correctly detected only when both its pitch and its duration are correct. Figure 4(a) shows the retrieval accuracy: the proposed method retrieves the first 100 notes of every melody completely, outperforming TSL. In fact, when the model was allowed to run freely until it encountered a terminator, instead of stopping after 100 elements, it recalled every element of the test set with 100% accuracy. Figure 5 shows the retrieval results for one sequence.
Fig. 5. The retrieval result of track 1 of Mozart's Sonata No. 16 in C Major, KV 545 (a segment). Vertical bars represent spikes fired by neurons in the corresponding columns.
X. Luo et al.
3.2 Context-Based Retrieval
The same 20 pieces as in TSL are used for context-based retrieval. Given a short episode (10 consecutive notes with their durations), the task is to retrieve the remaining 50 notes and the name of the piece. Both our method and TSL retrieve the names of all the melodies. Figure 4(b) shows the number of correctly retrieved notes for each piece. The proposed model detects all 50 notes of every test piece, significantly outperforming TSL.
3.3 Parameter Influence
Retrieval accuracy naturally decreases as the number of input elements grows, but how different parameters affect this decline needs to be verified. We experimentally analyze the effect of two critical parameters: the dendritic activation threshold ϑ̃ and the number of historically winning neurons K. The former determines when dendrites may be reused, and the latter determines how strongly an element depends on its context. The experiments in this section are conducted under a harder condition, relying solely on context without goal guidance. Tests were performed after every 1000 presented elements, each experimental condition was repeated in 10 independent trials, and the mean and standard deviation are reported.
A larger dendritic threshold makes dendrites less likely to be activated by another sequence and thus reused, so almost every dendrite becomes dedicated to a single sequence. As shown in Fig. 6, this yields extremely high accuracy at the cost of a huge number of dendrites. A smaller threshold increases dendrite reuse but also lowers retrieval accuracy. Increasing K, however, mitigates this decline (Fig. 7), because each element then depends on more contextual information. A larger K does not necessarily mean longer computation time: a smaller K sometimes takes more time because of failed retrievals (Fig. 7(b)). We found that ϑ̃ = 1.0 and K = 15 balance retrieval accuracy and computation time well.
Fig. 6. The retrieval accuracy, the retrieval time (s) per 1000 elements, and the number of dendrites as the number of inputs increases, when ϑ̃ = 2.0.
Fig. 7. The retrieval accuracy, the retrieval time (s) per 1000 elements, and the number of dendrites as the number of inputs increases, when ϑ̃ = 1.0.
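Why a larger K helps can be illustrated with a loose analogy that is not the paper's SNN: if the next element is predicted from the last k elements of a sequence, two sequences that share a short run of elements produce a conflicting context whenever k is too small to reach back past the shared part. The hypothetical function below counts such conflicting contexts for a toy pair of melodies; k here plays a role only loosely comparable to K.

```python
from collections import defaultdict


def ambiguous_contexts(sequences, k):
    """Count length-k contexts that are followed by more than one distinct next element."""
    followers = defaultdict(set)
    for seq in sequences:
        for i in range(k, len(seq)):
            followers[tuple(seq[i - k:i])].add(seq[i])
    return sum(1 for nxt in followers.values() if len(nxt) > 1)


# Two toy melodies that share the fragment "B C" but continue differently.
melodies = ["ABCD", "XBCE"]
for k in (1, 2, 3):
    print(k, ambiguous_contexts(melodies, k))
# 1 1
# 2 1
# 3 0
```

Only a context long enough to include the differing first note (k = 3) removes the conflict; the actual effect of K in the spiking model is what Figs. 6 and 7 measure.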
4 Conclusion
In this article, we propose a novel spiking neural network inspired by the columnar structure of the cerebral cortex. The successive firing of neurons in these columns represents the order of elements in a sequence, and the delays of synaptic connections represent the time intervals between elements. In this way, the model can naturally memorize long sequences with variable time intervals, and it can memorize a sequence after a single presentation. Experimental results also demonstrate the model's superior ability to memorize sequences with meaningful time intervals.
Acknowledgement. This work is supported in part by the National Key Research and Development Program of China under Grant 2018AAA0100202, and in part by the National Science Foundation of China under Grant 61976043.
References
1. Basawaraj, Starzyk, J.A., Horzyk, A.: Episodic memory in minicolumn associative knowledge graphs. IEEE Trans. Neural Netw. Learn. Syst. 30(11), 3505–3516 (2019)
2. Cho, K., et al.: Learning phrase representations using RNN encoder-decoder for statistical machine translation. In: Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), Doha, Qatar, pp. 1724–1734. Association for Computational Linguistics (2014)
3. Cui, Y., Ahmad, S., Hawkins, J.: Continuous online sequence learning with an unsupervised neural network model. Neural Comput. 28(11), 2474–2504 (2016)
4. Elman, J.L.: Finding structure in time. Cogn. Sci. 14(2), 179–211 (1990)
5. Hawkins, J., Ahmad, S., Dubinsky, D.: Cortical learning algorithm and hierarchical temporal memory. Numenta Whitepaper, pp. 1–68 (2011)
6. Hochreiter, S., Schmidhuber, J.: Long short-term memory. Neural Comput. 9(8), 1735–1780 (1997)
7. Horzyk, A., Starzyk, J.A., Graham, J.: Integration of semantic and episodic memories. IEEE Trans. Neural Netw. Learn. Syst. 28(12), 3084–3095 (2017)
8. Hu, J., Tang, H., Tan, K., Li, H.: How the brain formulates memory: a spatiotemporal model research frontier. IEEE Comput. Intell. Mag. 11(2), 56–68 (2016)
9. Krueger, B.: Classical piano midi page. http://www.piano-midi.de/. Accessed 10 Mar 2022
10. Li, N., Chen, Z.: Image captioning with visual-semantic LSTM. In: Proceedings of the Twenty-Seventh International Joint Conference on Artificial Intelligence (IJCAI-2018), pp. 793–799. International Joint Conferences on Artificial Intelligence Organization (2018)
11. Liang, Q., Zeng, Y., Xu, B.: Temporal-sequential learning with a brain-inspired spiking neural network and its application to musical memory. Front. Comput. Neurosci. 14, 51 (2020)
12. Schuster, M., Paliwal, K.: Bidirectional recurrent neural networks. IEEE Trans. Sig. Process. 45(11), 2673–2681 (1997)
13. Shipp, S.: Structure and function of the cerebral cortex. Curr. Biol. 17(12), R443–R449 (2007)
14. Soltesz, I., Losonczy, A.: CA1 pyramidal cell diversity enabling parallel information processing in the hippocampus. Nat. Neurosci. 21(4), 484–493 (2018)
15. Starzyk, J.A., Maciura, L., Horzyk, A.: Associative memories with synaptic delays. IEEE Trans. Neural Netw. Learn. Syst. 31(1), 331–344 (2020)
16. Sutskever, I., Vinyals, O., Le, Q.V.: Sequence to sequence learning with neural networks. In: Proceedings of the 27th International Conference on Neural Information Processing Systems, pp. 3104–3112. MIT Press, Cambridge (2014)
17. Tully, P.J., Lindén, H., Hennig, M.H., Lansner, A.: Spike-based Bayesian-Hebbian learning of temporal sequences. PLoS Comput. Biol. 12(5), e1004954 (2016)
18. Wang, Y., Zhang, M., Chen, Y., Qu, H.: Signed neuron with memory: towards simple, accurate and high-efficient ANN-SNN conversion. In: Proceedings of the Thirty-First International Joint Conference on Artificial Intelligence, pp. 2501–2508 (2022)
19. Zhang, B., Xiong, D., Xie, J., Su, J.: Neural machine translation with GRU-gated attention model. IEEE Trans. Neural Netw. Learn. Syst. 31(11), 4688–4698 (2020)
20. Zhang, M., et al.: Rectified linear postsynaptic potential function for backpropagation in deep spiking neural networks. IEEE Trans. Neural Netw. Learn. Syst. 33(5), 1947–1958 (2022)
Graph Attention Transformer Network for Robust Visual Tracking
Libo Wang1, Si Chen1(B), Zhen Wang2, Da-Han Wang1, and Shunzhi Zhu1
1 Fujian Key Laboratory of Pattern Recognition and Image Understanding, School of Computer and Information Engineering, Xiamen University of Technology, Xiamen 361024, China
{wangdh,szzhu}@xmut.edu.cn
2 School of Computer Science, Faculty of Engineering, The University of Sydney, Darlington, NSW 2008, Australia
Abstract. Visual tracking aims to estimate the state of an arbitrary object throughout a video, given only its bounding box in the first frame. However, existing trackers still struggle to adapt to complex environments because they lack adaptive appearance features. In this paper, we propose a graph attention transformer network, termed GATransT, to improve the robustness of visual tracking. Specifically, we design an adaptive graph attention module that enriches the embedding information extracted by the transformer backbone by establishing part-to-part correspondences between the template and search nodes. Extensive experimental results demonstrate that the proposed tracker outperforms state-of-the-art methods on five challenging datasets: OTB100, UAV123, LaSOT, GOT-10k, and TrackingNet.
Keywords: Visual tracking · Graph attention · Transformer
1 Introduction
Visual tracking plays a pivotal role in computer vision; it aims to estimate the state of an arbitrary object in a video given its initial target box. In recent years, object tracking has found broad applications in intelligent traffic, video surveillance, and other fields. However, the performance of existing trackers is affected by various challenging factors, including illumination variation, deformation, motion blur, and background clutter. Current mainstream trackers include Siamese-based and transformer-based trackers, which have achieved good results in terms of both efficiency and accuracy. Siamese-based trackers [1,13] use cross-correlation to embed information between the template and search branches. Transformer-based trackers [3,21,29] exploit global and dynamic modeling capabilities to establish long-range correlations between the extracted template and search features.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 165–176, 2023.
https://doi.org/10.1007/978-981-99-1639-9_14
L. Wang et al.
For example, STARK [29] proposes an encoder-decoder transformer architecture to model the global spatio-temporal feature dependencies between the target object and the search region. Despite their great success, these trackers still have notable drawbacks. Transformer-based trackers can capture rich global contextual interdependence between the template and the search region; however, the extracted features lack part-level embedding information, which makes it difficult to adapt to complex tracking scenarios. In addition, the template features extracted by conventional trackers may contain too much redundant information, which causes tracking errors to accumulate. To address these two issues, and inspired by the graph attention network and the transformer, we propose GATransT, a novel end-to-end graph attention transformer tracker that introduces graph attention into a transformer-based tracker and establishes local topological correspondences for the extracted features. We first use a transformer as the feature extraction network, which obtains richer semantic information through self-attention and cross-attention between the template and search region features. Next, we use a graph attention mechanism to propagate target information from the template features to the search region features. To reduce interference from redundant template information and obtain more accurate tracking results, we employ an adaptive graph attention module that establishes correspondences among initial template nodes, dynamic template nodes, and search nodes. In addition, we use the FocusedDropout operation to make the network focus on the target object, further improving tracking performance. As shown in Fig. 1, compared with state-of-the-art trackers, our method successfully tracks the target in the presence of similar objects, background clutter, and partial occlusion.
Finally, we evaluate the different trackers on public tracking benchmarks, including OTB100 [26], UAV123 [18], LaSOT [9], GOT-10k [12], and TrackingNet [19]. The experimental results show that the proposed tracker significantly outperforms the competing trackers. The main contributions of this work can be summarized as follows:
– We propose an end-to-end transformer-based graph attention tracking framework. To the best of our knowledge, this is the first work to introduce graph attention into a transformer to extract robust feature embedding information of the target.
– We employ an adaptive graph attention module that establishes part-to-part correspondences by aggregating initial template nodes, dynamic template nodes, and search nodes to obtain robust adaptive features.
– Comprehensive experiments demonstrate the excellent performance of our method compared with state-of-the-art trackers on five challenging benchmarks.
2 Related Work
2.1 Visual Tracking
The currently popular tracking paradigm consists of three main stages: feature extraction, feature fusion, and prediction. Most researchers focus on the first of these, the feature learning phase. Almost all earlier methods use a CNN as the feature extraction network, while recent works [6,22,24] also use a transformer as the backbone. For the critical feature fusion stage, earlier Siamese trackers often use a cross-correlation operation to obtain the fused response map of the template and search branches, and some online trackers learn a target model to distinguish the target foreground from the background. Recently, several works [2,3,6,11,23] have used attention operations for feature fusion and achieved good performance. This paper mainly focuses on how to effectively introduce graph attention into transformer tracking.
Fig. 1. Visual comparison of the tracking results with TransT [3], STARK [29], and SiamGAT [11], where GT denotes the ground truth. From top to bottom, the rows show the Basketball, Bird1, and Bolt sequences from the OTB100 dataset.
2.2 Attention for Tracking
Attention mechanisms are often used in visual tracking methods for feature fusion. On the one hand, the self-attention and cross-attention of the transformer are introduced as a module to improve the learning of long-range dependencies between the template and search branches. For example, TransT [3] applies transformer attention operations in place of the earlier correlation operation to obtain informative feature maps. MixFormer [6] uses a transformer-based mixed attention backbone to extract more discriminative features and to generate extensive interactions between the template and search branches. On the other hand, graph attention has also been applied to object tracking. SiamGAT [11] establishes part-to-part correspondences between the target and the search region with a complete bipartite graph and propagates target information from the template to the search features. GraphFormers [30] captures and integrates textual graph representations by nesting GNNs alongside each transformer layer of a pre-trained language model. Inspired by [30], we combine graph attention and the transformer to obtain more robust adaptive features for visual tracking.
Fig. 2. Overview of the proposed transformer tracking framework based on the adaptive graph attention module.
3 Proposed Method
3.1 Overview
In this section, we present GATransT, an effective graph attention transformer network for visual tracking, as shown in Fig. 2. The GATransT tracking framework contains three main components: a transformer-based backbone, a graph attention-based feature integration module, and a corner-based prediction head. In this framework, the adaptive graph attention module we design enriches the embedding information extracted by the transformer backbone.
3.2 Transformer-Based Feature Extraction
Fig. 3. Architecture of the adaptive graph attention module.
Most previous trackers adopt deep convolutional neural networks such as AlexNet, ResNet, and GoogLeNet as feature extraction networks. Although these achieve effective feature extraction, the extracted features and semantic information are still not compact and rich enough. Following MixFormer [6], we adopt the mixed attention network as the backbone, which establishes long-range associations between the target template and the search region to obtain richer feature representations. Since the transformer does not explicitly process part-level feature information, we design an adaptive graph attention module (see Sect. 3.3 for details). During feature extraction, we first convert the input template and search region into tokens through patch embedding. Next, a convolutional projection produces the query, key, and value of the template and search features. This process can be formulated as
$$\text{template}^{q/k/v} = \mathrm{Flatten}\big(\mathrm{Conv2d}(\mathrm{Reshape2D}(\text{template}), k_{\mathrm{size}})\big), \tag{1}$$
$$\text{search}^{q/k/v} = \mathrm{Flatten}\big(\mathrm{Conv2d}(\mathrm{Reshape2D}(\text{search}), k_{\mathrm{size}})\big), \tag{2}$$
where $\text{template}^{q/k/v}$ and $\text{search}^{q/k/v}$ are the Q/K/V matrices obtained by convolutional projection of the template and search input tokens, respectively; Conv2d is a depth-wise separable convolution, and $k_{\mathrm{size}}$ is the convolution kernel size. We then perform mixed attention on the resulting queries, keys, and values, using $q_t$, $k_t$, and $v_t$ for the target and $q_s$, $k_s$, and $v_s$ for the search region. In this framework, the mixed attention is defined as
$$v_m = \mathrm{Concat}(v_t, v_s), \tag{3}$$
$$\mathrm{Attn}_t = \mathrm{Softmax}\big(q_t k_t^{\top} / \sqrt{d}\big)\, v_t, \tag{4}$$
$$\mathrm{Attn}_s = \mathrm{Softmax}\big(q_s k_m^{\top} / \sqrt{d}\big)\, v_m, \tag{5}$$
where $v_m$ is the concatenation of the template and search-region values (and $k_m = \mathrm{Concat}(k_t, k_s)$ is the corresponding concatenation of keys), $d$ is the key dimension, and $\mathrm{Attn}_t$ and $\mathrm{Attn}_s$ are the attention maps of the target and the search region, respectively. Finally, the target tokens and the search tokens are concatenated and passed through a linear projection to obtain the output tokens.
3.3 Adaptive Graph Attention Module
Most existing trackers use cross-correlation operations or self-attention for feature fusion, which may lose semantic and part-level embedding information. Inspired by graph attention tracking [11], we establish part-to-part correspondences between the template and search region features extracted by the transformer backbone, aggregating initial template nodes, dynamic template nodes, and search nodes to obtain more robust adaptive features. As shown in Fig. 3, given the template and search region patches, we convert their tokens obtained through patch embedding into graphs whose nodes are the h × w positions of an h × w × c feature map, where h, w, and c denote the height, width, and channels of the feature, respectively. To adaptively learn the feature representation between nodes, we compute the correlation score between two nodes as their inner product, which expresses their similarity. To eliminate redundant background information in the template, graph attention is first applied between the initial template nodes and the dynamic template nodes to obtain more accurate template information. When the preset update threshold is reached, adaptive features whose prediction results exceed a particular confidence score are used to update the dynamic template nodes. In addition, we apply softmax normalization when computing the correlation scores $\alpha_{ij}$ to balance the amount of information:
$$\alpha_t = \mathrm{Softmax}\big((W_t p_t)^{\top} (W'_t p'_t)\big), \tag{6}$$
$$\alpha_{ij} = \mathrm{Softmax}\big((W_s p_s^i)^{\top} \alpha_t^j\big), \tag{7}$$
where $W_t$, $W'_t$, and $W_s$ are the linear transformation matrices of the initial template, dynamic template, and search features, respectively, and $p_t$, $p'_t$, and $p_s$ are the node feature vectors of the template, the dynamic template, and the search region, respectively. In Eq. (7), $\alpha_{ij}$ can be viewed as the attention paid to search node $i$ according to the information of template node $j$. All template nodes are then propagated to the $i$-th search node to compute its aggregated feature representation:
$$v_i = \sum_{j \in V_t} \alpha_{ij} W_v p_t^j, \tag{8}$$
where $W_v$ is the linear transformation matrix of the original feature, $V_t$ is the node set of the template features, and $p_t^j$ is the template feature vector of node $j$. Finally, the aggregated features and the original features are combined to obtain a more robust feature representation:
$$\hat{p}_s^i = \mathrm{ReLU}\big(\mathrm{Concat}(v_i, W_v p_s^i)\big), \tag{9}$$
where $p_s^i$ is the search region feature vector of node $i$. In addition, we apply the FocusedDropout [27] operation to the aggregated node features after graph attention, so that the resulting adaptive features focus on more robust target appearance features:
$$P_s^i = \mathrm{FocusedDrop}(\hat{p}_s^i, \mathit{rate}), \tag{10}$$
where $\mathit{rate}$ is the participation rate.
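A minimal NumPy sketch of the computations in Eqs. (3)–(9) follows, assuming single-head attention, plain linear maps (shared between template and search) in place of the depth-wise convolutional projection of Eqs. (1)–(2), random weights, and small toy token counts. Two readings are our own assumptions rather than details confirmed by the text: Eq. (6) is treated as the initial template nodes attending to the dynamic template nodes to produce refined template nodes, and those refined nodes play the role of $\alpha_t^j$ in Eq. (7) and of $p_t^j$ in Eq. (8). FocusedDropout (Eq. 10) and the template-update logic are omitted, and all function and weight names are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def mixed_attention(x_t, x_s, w_q, w_k, w_v):
    """Eqs. (3)-(5), single head; linear maps stand in for the conv projection."""
    q_t, k_t, v_t = x_t @ w_q, x_t @ w_k, x_t @ w_v
    q_s, k_s, v_s = x_s @ w_q, x_s @ w_k, x_s @ w_v
    d = k_t.shape[-1]
    k_m = np.concatenate([k_t, k_s], axis=0)  # template + search keys
    v_m = np.concatenate([v_t, v_s], axis=0)  # Eq. (3)
    attn_t = softmax(q_t @ k_t.T / np.sqrt(d)) @ v_t  # Eq. (4): template attends to itself only
    attn_s = softmax(q_s @ k_m.T / np.sqrt(d)) @ v_m  # Eq. (5): search attends to both
    return attn_t, attn_s


def adaptive_graph_attention(p_t, p_d, p_s, w_t, w_d, w_s, w_v):
    """Eqs. (6)-(9) under our reading; rows are graph nodes, columns are channels."""
    # Eq. (6): initial template nodes attend to dynamic template nodes.
    alpha_t = softmax((p_t @ w_t) @ (p_d @ w_d).T, axis=1)  # (N_t, N_d)
    template = alpha_t @ p_d                                # refined template nodes (N_t, C)
    # Eq. (7): score of search node i against template node j, normalized over j.
    alpha = softmax((p_s @ w_s) @ template.T, axis=1)       # (N_s, N_t)
    # Eq. (8): v_i = sum_j alpha_ij * W_v p_t^j
    v = alpha @ (template @ w_v)                            # (N_s, C)
    # Eq. (9): concatenate with the transformed search feature, then ReLU.
    return np.maximum(np.concatenate([v, p_s @ w_v], axis=1), 0.0)  # (N_s, 2C)


rng = np.random.default_rng(0)
C = 8
w = {name: rng.normal(scale=C ** -0.5, size=(C, C)) for name in ("q", "k", "v", "t", "d", "s", "gv")}
x_t = rng.normal(size=(4 * 4, C))  # 4x4 initial-template tokens
x_d = rng.normal(size=(4 * 4, C))  # 4x4 dynamic-template tokens
x_s = rng.normal(size=(8 * 8, C))  # 8x8 search tokens

t_tokens, s_tokens = mixed_attention(x_t, x_s, w["q"], w["k"], w["v"])
d_tokens, _ = mixed_attention(x_d, x_s, w["q"], w["k"], w["v"])
features = adaptive_graph_attention(t_tokens, d_tokens, s_tokens, w["t"], w["d"], w["s"], w["gv"])
print(features.shape)  # (64, 16)
```

Each row is one graph node (one spatial position), so a 4 × 4 template yields 16 nodes and an 8 × 8 search region yields 64; the concatenation in Eq. (9) doubles the channel count from c to 2c.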
4 Experiments
This section first describes the implementation details of our tracker, then analyzes the influence of the main components of the proposed method, and finally compares our tracker with state-of-the-art trackers on the OTB100 [26], UAV123 [18], LaSOT [9], GOT-10k [12], and TrackingNet [19] datasets.
4.1 Implementation Details
The proposed method is implemented in PyTorch and runs on an Intel i7 CPU (32 GB RAM) and a GeForce RTX TITAN GPU (24 GB) at an average speed of about 11 FPS. We compare our tracker with several state-of-the-art trackers on four public datasets and use one-pass evaluation (OPE) with precision and success plots on the challenging video sequences.
Training. We use the training splits of LaSOT [9], GOT-10k [12], COCO2017 [16], and TrackingNet [19] for offline training, following the training strategy of MixFormer [6]. Training is single-stage, without extensive parameter tuning or post-processing. Starting from the original model, we add the proposed adaptive graph attention module and continue training for 200 epochs. We use the Adam optimizer with a weight decay of 0.0001 and an initial learning rate of 1e−4. The search images and templates are 320 × 320 and 128 × 128 pixels, respectively. For data augmentation, we use horizontal flipping and rotation. The training loss combines the GIoU loss and the L1 loss with weights of 2.0 and 5.0, respectively.
Inference. The initial template, multiple dynamic online templates, and the search region are fed to the tracker to generate the target bounding box and confidence scores. The dynamic template nodes of the adaptive graph attention module are updated only when the preset update interval is reached, using the candidate with the highest confidence score.
4.2 Ablation Study
To verify the effectiveness of each part of the proposed method, i.e., the backbone, the feature fusion, and the head, we conduct a detailed study of the different components on the LaSOT dataset. We use the STARK algorithm with its temporal information removed as the baseline. The details of all competing variants and the ablation results are listed in Table 1. We design five different combinations of the three main components (backbone, feature fusion, and head), from which we draw several key observations. First, setting #1 uses ResNet-50 as the backbone, an encoder-decoder as the feature fusion method, and a corner head as the prediction head. Comparing #1 with #2, where the backbone and feature fusion are replaced with a transformer and graph attention, respectively, the AUC score improves by 0.3 percentage points. Introducing DTN (dynamic template nodes) into the graph attention feature fusion of #2 improves the AUC score by a further 0.3 points (#3). Adding the FD (FocusedDropout) operation on top of #3 increases the AUC score by 0.1 points (#4). Finally, comparing prediction heads shows that the corner head outperforms the query head (#4 vs. #5). Overall, the main components of the proposed tracker are effective and yield excellent performance on the LaSOT dataset.
Table 1. The ablation study of the main components of the proposed method on the LaSOT dataset.
| Setting | Backbone | Feature fusion | Head | AUC score (%) |
|---|---|---|---|---|
| #1 | ResNet-50 | Encoder-Decoder | Corner | 66.8 |
| #2 | Transformer | Graph | Corner | 67.1 |
| #3 | Transformer | Graph+DTN | Corner | 67.4 |
| #4 (ours) | Transformer | Graph+DTN+FD | Corner | 67.5 |
| #5 | Transformer | Graph+DTN+FD | Query | 67.3 |
4.3 Comparisons with State-of-the-Art Trackers
In this section, we compare GATransT with other advanced trackers on five challenging datasets: OTB100, UAV123, LaSOT, GOT-10k, and TrackingNet.
OTB100. The OTB100 [26] dataset consists of 100 video sequences annotated with 11 challenge attributes. Several state-of-the-art trackers are compared, including STARK-S [29], SiamRPN [14], GradNet [15], DeepSRDCF [7], SiamDW [33], and SiamFC [1]. Figure 4 reports the precision and success plots under one-pass evaluation (OPE) on OTB100. The representative precision score in the legend of Fig. 4 (left) is taken at a threshold of 20 pixels. In Fig. 4 (right), tracking in a frame is considered successful when the overlap between the tracking result and the ground truth exceeds 0.5. As Fig. 4 shows, GATransT achieves the highest performance on OTB100, with a precision score of 88.8% and an AUC score of 68.1%. Compared with the transformer-based STARK-S, the precision and AUC scores of the proposed tracker are 0.6 and 0.8 points higher, respectively.
Fig. 4. Precision and success plots on the OTB100 dataset using one-pass evaluation (OPE).
UAV123. The UAV123 [18] dataset contains 123 short-term video sequences, all fully annotated with upright bounding boxes. It has more challenging attributes than OTB, such as aspect ratio change, full occlusion, partial occlusion, and similar objects. Table 2 reports the area under curve (AUC) and precision scores [25] compared with SiamFC [1], SiamRPN++ [13], CGACD [8], SiamGAT [11], SiamRCNN [20], FCOT [4], TREG [5], and STARK-S [29] on UAV123. Among the competing tracking algorithms, our tracker outperforms STARK-S in both AUC and precision, which we attribute to the adaptive graph attention module. Specifically, GATransT achieves AUC and precision scores of 68.2% and 89.2% on UAV123, respectively.
Table 2. Comparisons with state-of-the-art trackers on the UAV123 dataset.
| Metric | CGACD [8] | SiamGAT [11] | SiamRCNN [20] | FCOT [4] | TREG [5] | STARK-S [29] | Ours |
|---|---|---|---|---|---|---|---|
| AUC (%) | 63.3 | 64.6 | 64.9 | 65.6 | 66.9 | 67.2 | 68.2 |
| Prec. (%) | 83.3 | 84.3 | 83.4 | 87.3 | 88.4 | 88.5 | 89.2 |
LaSOT/GOT-10k/TrackingNet. LaSOT [9] is a large-scale long-term tracking dataset whose test set contains 280 videos with an average length of 2448 frames. GOT-10k [12] is a large-scale benchmark with over 10,000 video segments, 180 of which form the test set. TrackingNet [19] is a large-scale short-term dataset containing 511 test sequences without publicly available ground truth. We evaluate GATransT on these three datasets. The compared state-of-the-art trackers include SiamRPN++ [13], SiamFC++ [28], D3S [17], Ocean [34], SiamGAT [11], DTT [31], STMTracker [10], SiamRCNN [20], AutoMatch [32], TrDiMP [21], and STARK-S [29]. As Table 3 shows, our tracker performs strongly on all three large-scale benchmarks: LaSOT, GOT-10k, and TrackingNet.
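The precision and success metrics used in the OTB100 and UAV123 comparisons above can be sketched as follows, assuming boxes in (x, y, w, h) pixel format. Precision is the fraction of frames whose center location error is at most 20 pixels; success at a threshold is the fraction of frames whose IoU with the ground truth exceeds it; and the AUC is taken as the mean success rate over 21 evenly spaced thresholds in [0, 1], following the common OTB convention. The function names are hypothetical, and official toolkits may differ in details such as how box centers are rounded.

```python
import numpy as np


def iou_xywh(pred, gt):
    """Per-frame IoU of (x, y, w, h) boxes; pred and gt have shape (N, 4)."""
    x1 = np.maximum(pred[:, 0], gt[:, 0])
    y1 = np.maximum(pred[:, 1], gt[:, 1])
    x2 = np.minimum(pred[:, 0] + pred[:, 2], gt[:, 0] + gt[:, 2])
    y2 = np.minimum(pred[:, 1] + pred[:, 3], gt[:, 1] + gt[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    union = pred[:, 2] * pred[:, 3] + gt[:, 2] * gt[:, 3] - inter
    return inter / union


def center_error(pred, gt):
    """Euclidean distance between box centers, in pixels."""
    pc = pred[:, :2] + pred[:, 2:] / 2
    gc = gt[:, :2] + gt[:, 2:] / 2
    return np.linalg.norm(pc - gc, axis=1)


def precision_at(pred, gt, threshold=20.0):
    """Fraction of frames whose center error is within `threshold` pixels."""
    return float(np.mean(center_error(pred, gt) <= threshold))


def success_auc(pred, gt, thresholds=np.linspace(0.0, 1.0, 21)):
    """Mean success rate over IoU thresholds (area under the success plot)."""
    ious = iou_xywh(pred, gt)
    return float(np.mean([np.mean(ious > t) for t in thresholds]))


gt = np.array([[0, 0, 10, 10], [0, 0, 10, 10]], dtype=float)
pred = np.array([[0, 0, 10, 10], [5, 0, 10, 10]], dtype=float)
print(precision_at(pred, gt))                     # 1.0
print(float(np.mean(iou_xywh(pred, gt) > 0.5)))   # 0.5 (success rate at IoU 0.5)
print(round(success_auc(pred, gt), 4))            # 0.6429
```

In the toy example the second prediction is shifted by 5 pixels, so it still counts for precision at 20 pixels, but its IoU of 1/3 fails the 0.5 success criterion; this is why the two plots can rank trackers differently.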
Table 3. Comparisons with state-of-the-art trackers on LaSOT, GOT-10k, and TrackingNet.
| Tracker | LaSOT AUC | LaSOT P_Norm | LaSOT P | GOT-10k AO | GOT-10k SR0.5 | GOT-10k SR0.75 | TrackingNet AUC | TrackingNet P_Norm | TrackingNet P |
|---|---|---|---|---|---|---|---|---|---|
| SiamRPN++ [13] | 49.6 | 56.9 | 49.1 | 51.7 | 61.6 | 32.5 | 73.3 | 80.0 | 69.4 |
| SiamFC++ [28] | 54.4 | 62.3 | 54.7 | 59.5 | 69.5 | 47.9 | 75.4 | 80.0 | 70.5 |
| D3S [17] | – | – | – | 59.7 | 67.6 | 46.2 | 72.8 | 76.8 | 66.4 |
| Ocean [34] | 56.0 | 65.1 | 56.6 | 61.1 | 72.1 | 47.3 | – | – | – |
| SiamGAT [11] | 53.9 | 63.3 | 53.0 | 62.7 | 74.3 | 48.8 | – | – | – |
| DTT [31] | 60.1 | – | – | 63.4 | 74.9 | 51.4 | 79.6 | 85.0 | 78.9 |
| STMTracker [10] | 60.6 | 69.3 | 63.3 | 64.2 | 73.7 | 57.5 | 80.3 | 85.1 | 76.7 |
| SiamRCNN [20] | 64.8 | 72.2 | – | 64.9 | 72.8 | 59.7 | 81.2 | 85.4 | 80.0 |
| AutoMatch [32] | 58.2 | – | 59.9 | 65.2 | 76.6 | 54.3 | 76.0 | – | 72.6 |
| TrDiMP [21] | 63.9 | – | 61.4 | 67.1 | 77.7 | 58.3 | 78.4 | 83.3 | 73.1 |
| STARK-S [29] | 66.8 | 76.3 | 71.3 | 67.2 | 76.1 | 61.2 | 80.2 | 85.0 | 77.6 |
| Ours | 67.5 | 76.9 | 72.5 | 67.2 | 76.7 | 62.9 | 80.6 | 85.1 | 77.8 |
5 Conclusion
In this paper, we propose a novel graph attention transformer network for visual object tracking. The network leverages an adaptive graph attention module to enrich the long-range correlation features extracted by the transformer backbone. By establishing part-to-part correspondences among the initial template, dynamic template, and search nodes, this module acquires robust target appearance features and thus adapts to complex tracking scenarios. The experimental results show that the proposed tracker significantly outperforms the competing trackers on five public tracking benchmarks: OTB100, UAV123, LaSOT, GOT-10k, and TrackingNet.
Acknowledgement. This work was supported in part by the Natural Science Foundation of Fujian Province of China (Nos. 2021J011185 and 2021H6035); the Youth Innovation Foundation of Xiamen City of Fujian Province (No. 3502Z20206068); the Joint Funds of 5th Round of Health and Education Research Program of Fujian Province (No. 2019-WJ-41); and the Science and Technology Planning Project of Fujian Province (No. 2020H0023).
References 1. Bertinetto, L., Valmadre, J., Henriques, J.F., Vedaldi, A., Torr, P.H.S.: Fullyconvolutional siamese networks for object tracking. In: Hua, G., J´egou, H. (eds.) ECCV 2016. LNCS, vol. 9914, pp. 850–865. Springer, Cham (2016). https://doi. org/10.1007/978-3-319-48881-3 56
Graph Attention Transformer Network for Robust Visual Tracking
175
2. Chen, S., Wang, L., Wang, Z., Yan, Y., Wang, D.H., Zhu, S.: Learning meta-adversarial features via multi-stage adaptation network for robust visual object tracking. Neurocomputing 491, 365–381 (2022)
3. Chen, X., Yan, B., Zhu, J., Wang, D., Yang, X., Lu, H.: Transformer tracking. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 8126–8135 (2021)
4. Cui, Y., Jiang, C., Wang, L., Wu, G.: Fully convolutional online tracking. arXiv preprint arXiv:2004.07109 (2020)
5. Cui, Y., Jiang, C., Wang, L., Wu, G.: Target transformed regression for accurate tracking. arXiv preprint arXiv:2104.00403 (2021)
6. Cui, Y., Jiang, C., Wang, L., Wu, G.: MixFormer: end-to-end tracking with iterative mixed attention. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2022)
7. Danelljan, M., Hager, G., Shahbaz Khan, F., Felsberg, M.: Convolutional features for correlation filter based visual tracking. In: Proceedings of the IEEE International Conference on Computer Vision Workshops (ICCVW), pp. 58–66 (2015)
8. Du, F., Liu, P., Zhao, W., Tang, X.: Correlation-guided attention for corner detection based visual tracking. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 6835–6844 (2020)
9. Fan, H., et al.: LaSOT: a high-quality benchmark for large-scale single object tracking. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 5374–5383 (2019)
10. Fu, Z., Liu, Q., Fu, Z., Wang, Y.: STMTrack: template-free visual tracking with space-time memory networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 13774–13783 (2021)
11. Guo, D., Shao, Y., Cui, Y., Wang, Z., Zhang, L., Shen, C.: Graph attention tracking. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 9543–9552 (2021)
12. Huang, L., Zhao, X., Huang, K.: GOT-10k: a large high-diversity benchmark for generic object tracking in the wild. IEEE Trans. Pattern Anal. Mach. Intell. 43(5), 1562–1577 (2021)
13. Li, B., Wu, W., Wang, Q., Zhang, F., Xing, J., Yan, J.: SiamRPN++: evolution of siamese visual tracking with very deep networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 4282–4291 (2019)
14. Li, B., Yan, J., Wu, W., Zhu, Z., Hu, X.: High performance visual tracking with siamese region proposal network. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 8971–8980 (2018)
15. Li, P., Chen, B., Ouyang, W., Wang, D., Yang, X., Lu, H.: GradNet: gradient-guided network for visual object tracking. In: Proceedings of the IEEE International Conference on Computer Vision (ICCV), pp. 6162–6171 (2019)
16. Lin, T.-Y., et al.: Microsoft COCO: common objects in context. In: Fleet, D., Pajdla, T., Schiele, B., Tuytelaars, T. (eds.) ECCV 2014. LNCS, vol. 8693, pp. 740–755. Springer, Cham (2014). https://doi.org/10.1007/978-3-319-10602-1_48
17. Lukezic, A., Matas, J., Kristan, M.: D3S - a discriminative single shot segmentation tracker. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 7131–7140 (2020)
18. Mueller, M., Smith, N., Ghanem, B.: A benchmark and simulator for UAV tracking. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9905, pp. 445–461. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46448-0_27
L. Wang et al.
19. Müller, M., Bibi, A., Giancola, S., Alsubaihi, S., Ghanem, B.: TrackingNet: a large-scale dataset and benchmark for object tracking in the wild. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11205, pp. 310–327. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01246-5_19
20. Voigtlaender, P., Luiten, J., Torr, P.H.S., Leibe, B.: Siam R-CNN: visual tracking by re-detection. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 6577–6587 (2020)
21. Wang, N., Zhou, W., Wang, J., Li, H.: Transformer meets tracker: exploiting temporal context for robust visual tracking. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2021)
22. Wang, Z., Liu, L., Duan, Y., Kong, Y., Tao, D.: Continual learning with lifelong vision transformer. In: CVPR, pp. 171–181 (2022)
23. Wang, Z., Liu, L., Duan, Y., Tao, D.: SIN: semantic inference network for few-shot streaming label learning. IEEE Trans. Neural Netw. Learn. Syst. 1–14 (2022)
24. Wang, Z., Liu, L., Kong, Y., Guo, J., Tao, D.: Online continual learning with contrastive vision transformer. In: Avidan, S., Brostow, G., Cissé, M., Farinella, G.M., Hassner, T. (eds.) ECCV 2022. LNCS, vol. 13680, pp. 631–650. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-20044-1_36
25. Wang, Z., Liu, L., Tao, D.: Deep streaming label learning. In: International Conference on Machine Learning (ICML), vol. 119, pp. 9963–9972 (2020)
26. Wu, Y., Lim, J., Yang, M.: Object tracking benchmark. IEEE Trans. Pattern Anal. Mach. Intell. 37(9), 1834–1848 (2015)
27. Xie, T., Liu, M., Deng, J., Cheng, X., Wang, X., Liu, M.: FocusedDropout for convolutional neural network. arXiv preprint arXiv:2103.15425 (2021)
28. Xu, Y., Wang, Z., Li, Z., Ye, Y., Yu, G.: SiamFC++: towards robust and accurate visual tracking with target estimation guidelines. In: Proceedings of the AAAI Conference on Artificial Intelligence (AAAI), pp. 12549–12556. AAAI Press (2020)
29. Yan, B., Peng, H., Fu, J., Wang, D., Lu, H.: Learning spatio-temporal transformer for visual tracking. In: Proceedings of the IEEE International Conference on Computer Vision (ICCV), pp. 10428–10437 (2021)
30. Yang, J., et al.: GraphFormers: GNN-nested transformers for representation learning on textual graph. In: Advances in Neural Information Processing Systems (NeurIPS), pp. 28798–28810 (2021)
31. Yu, B., et al.: High-performance discriminative tracking with transformers. In: Proceedings of the IEEE International Conference on Computer Vision (ICCV), pp. 9836–9845 (2021)
32. Zhang, Z., Liu, Y., Wang, X., Li, B., Hu, W.: Learn to match: automatic matching network design for visual tracking. In: Proceedings of the IEEE International Conference on Computer Vision (ICCV), pp. 13319–13328 (2021)
33. Zhang, Z., Peng, H.: Deeper and wider siamese networks for real-time visual tracking. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 4591–4600 (2019)
34. Zhang, Z., Peng, H., Fu, J., Li, B., Hu, W.: Ocean: object-aware anchor-free tracking. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12366, pp. 771–787. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58589-1_46
GCL-KGE: Graph Contrastive Learning for Knowledge Graph Embedding
Qimeng Guo, Huajuan Duan, Chuanhao Dong, Peiyu Liu(B), and Liancheng Xu(B)
Shandong Normal University, Jinan 250358, China
{lcxu,liupy}@sdnu.edu.cn

Abstract. Knowledge graph embedding models represent the entities and relations of a structured knowledge graph as vectors, which is essential for many downstream tasks. Some studies show that knowledge graph embedding models based on graph neural networks can exploit higher-order neighborhood information and generate meaningful representations. However, most such models suffer from noise introduced by distant neighbors. To address this challenge, we propose a graph contrastive learning knowledge graph embedding (GCL-KGE) model that enhances entity representations. Specifically, we use a graph attention network to aggregate multi-order neighbor information and refine the pre-trained entity representations. To keep redundant information out of the graph attention network, we combine it with contrastive learning, which provides auxiliary supervision signals. We further propose a new method of constructing positive instances for contrastive learning that makes effective use of the entity representations in the hidden layers. We use a triple scoring function to evaluate the representations on link prediction. Experimental results on four datasets show that our model alleviates interaction noise and outperforms the baseline models.
Keywords: Knowledge graph · Contrastive learning · Graph attention networks

1 Introduction
A knowledge graph (KG) stores real-world facts as a graph structure, e.g., in the form of a triple (The Hours, starred actors, Meryl Streep). The facts in a knowledge graph are always incomplete, and manual completion is time-consuming and laborious. One way to complete a knowledge graph is knowledge graph embedding (KGE), which embeds the entities and relations of the knowledge graph into a continuous vector space while preserving its structural and semantic information. KGE models apply a scoring function to measure the plausibility of triples. Earlier KGE models are traditionally divided into distance-based models and tensor decomposition-based models [1]. They offer either high computational efficiency or strong expressive power.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 177–188, 2023.
https://doi.org/10.1007/978-981-99-1639-9_15

With the widespread application of neural networks, some researchers apply graph attention networks (GAT) [2] to enrich entity representations, because GATs can exploit higher-order neighbor information. More recently, a number of studies [3,4] have demonstrated that contrastive learning is effective for graph representation learning on unlabelled graph data, and some studies apply contrastive learning to KGE models to mine semantic similarities between triples. Inspired by these studies, we explore the application of graph contrastive learning to knowledge graph embedding. We identify two challenges in current KGE models: (i) Most models treat each entity independently and ignore the structural relations among neighboring triples, so they model the graph structure of the knowledge graph only marginally. (ii) Entity representations are susceptible to interaction noise, because the graph attention network propagates extra information from distant, weakly relevant nodes. In addition, constructing positive instances without changing the semantics of triples is a key challenge for applying contrastive learning to KGE. In this work, we propose a graph contrastive learning knowledge graph embedding model (GCL-KGE) to address these challenges. Our model uses an encoder-decoder framework combined with contrastive learning, which captures the structural information of the knowledge graph while suppressing interaction noise in the learned representations. Specifically, we first use a GAT, which assigns different attention weights to different neighbors, to refine the pre-trained entity embeddings with information from their neighbors.
Then we use the scoring function of a convolutional neural network-based KGE model for link prediction to evaluate the quality of the embeddings. To reduce the influence of interaction noise, we perform contrastive learning on top of the GAT without data augmentation, since augmentation would change the semantics of triples. The core idea is to take the GAT's hidden representations as positive instances, because they are semantically similar to the final entity representations. The knowledge graph embedding is learned by maximizing the agreement between these different views of the same entity in the hidden space. Experimental studies on four datasets demonstrate the effectiveness of GCL-KGE, which significantly improves link prediction accuracy. In summary, our contributions are as follows:
1. We propose a graph contrastive learning knowledge graph embedding (GCL-KGE) model to improve the accuracy and robustness of existing knowledge graph embedding models.
2. Our contrastive learning architecture provides auxiliary supervision signals for knowledge graph embedding, and we theoretically derive the direction in which contrastive learning updates entity representations.
3. Experimental results show the effectiveness of our model on knowledge graph link prediction.
2 Related Works
Knowledge graph embedding has become a popular research topic that attracts a wide range of researchers. KGE methods determine the plausibility of a triple through a scoring function. Traditional KGE methods fall into two types: distance-based models and tensor decomposition-based models. Distance-based models define the scoring function through the distance between entities. TransE [6], the most widely used of these, regards the relation vector as a translation from the head entity to the tail entity. Building on it, researchers have proposed variants for complex relations, such as TransR [22] and TransD [23]. Tensor decomposition-based models map head entities to tail entities through multiplication with relation matrices. RESCAL [7] represents the latent semantics of entities as vectors and relations as matrices, modeling the interactions between latent factors. To better model asymmetric relations, ComplEx [14] moves the embeddings into complex space. However, these models optimize each triple separately with the scoring function and overlook the relations between triples. Recent studies use neural networks to learn KG representations. ConvKB [5] uses convolutional neural networks (CNN) to extract triple features for link prediction. R-GCN [8] applies graph convolutional networks (GCN) to link prediction and assigns the same weight to all neighboring entities of each entity. To reflect the different importance of different relations for an entity, SACN [9] weights the aggregated neighbor information according to the relation type. Following the idea of GAT, Nathani et al. [1] propose aggregating the information of all neighboring triples to train the representations.
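To make the two families concrete, the sketch below writes out the scoring functions of TransE [6], RESCAL [7], and ComplEx [14] in plain Python. It is an illustrative sketch rather than code from any of these papers: the function names are hypothetical, and it assumes the convention that a higher score means a more plausible triple (so TransE returns a negated distance).

```python
from typing import List, Sequence

Matrix = List[List[float]]


def transe_score(h: Sequence[float], r: Sequence[float], t: Sequence[float],
                 p: int = 1) -> float:
    """TransE: the relation is a translation, so score = -||h + r - t||_p.
    Assumption: higher score = more plausible, hence the negated distance."""
    dist = sum(abs(hi + ri - ti) ** p for hi, ri, ti in zip(h, r, t)) ** (1.0 / p)
    return -dist


def rescal_score(h: Sequence[float], m_r: Matrix, t: Sequence[float]) -> float:
    """RESCAL: bilinear form h^T M_r t with a full matrix per relation."""
    return sum(h[a] * m_r[a][b] * t[b]
               for a in range(len(h)) for b in range(len(t)))


def complex_score(h: Sequence[complex], r: Sequence[complex],
                  t: Sequence[complex]) -> float:
    """ComplEx: Re(sum h * r * conj(t)); conjugating t breaks symmetry."""
    return sum(hi * ri * ti.conjugate() for hi, ri, ti in zip(h, r, t)).real


# A perfect translation h + r = t gets the best possible TransE score (zero distance).
print(transe_score([1.0, 2.0], [0.5, -1.0], [1.5, 1.0]))
# Swapping head and tail changes the RESCAL and ComplEx scores (asymmetry).
m = [[0.0, 2.0], [0.0, 0.0]]
print(rescal_score([1.0, 0.0], m, [0.0, 1.0]), rescal_score([0.0, 1.0], m, [1.0, 0.0]))
print(complex_score([1 + 1j], [1j], [1 + 0j]), complex_score([1 + 0j], [1j], [1 + 1j]))
```

The last two lines illustrate why both RESCAL and ComplEx can represent asymmetric relations: the score of (h, r, t) need not equal the score of (t, r, h), whereas a purely symmetric bilinear form could not distinguish them.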
Nevertheless, as the graph attention network gets deeper, information from more distant entities is aggregated into the entity representation, which introduces more noise. Contrastive learning is a key component of self-supervised learning and can learn good representations from the characteristics of the data itself. Its goal is to pull semantically close pairs together and push negative pairs apart. Many models use data augmentation to construct positive and negative instances, such as image flipping, rotation, and cutout in computer vision [10,11], or sentence cropping, span deletion, and reordering in natural language processing [12,13]. However, triples in a knowledge graph differ from the sentences in other tasks: if random noise is added in the embedding space, the semantics of the original triple change and the incompleteness of the knowledge graph is aggravated. To address these issues, we propose GCL-KGE to learn knowledge graph embeddings. GCL-KGE applies a graph attention network to aggregate neighboring triple information, so that triples are no longer trained in isolation. We also propose a new way to construct positive instances that mitigates noise interference without semantic loss.
Fig. 1. Framework of the GCL-KGE model. We train the contrastive loss as an auxiliary task together with the link prediction loss.
3 Proposed Model

3.1 Overview
In this section, we describe our model, which uses contrastive learning to learn KG embeddings. Fig. 1 shows GCL-KGE, an encoder-decoder model. The encoder learns knowledge graph embeddings by aggregating neighbor information with a graph attention network, and the decoder predicts candidate entities with a triple scoring function. We extend this existing framework with an auxiliary task to cope with the interaction noise encountered in graph attention networks. We denote the directed KG as $G = (\nu, \xi)$, with nodes (entities) $v \in \nu$ and edges (relations) $r \in \xi$. The details of the model are introduced below.

3.2 Encoder
Many neural network-based models encode entities and relations individually, ignoring the connections between the triples in the knowledge graph. To capture triple interaction information and graph structure information, we follow [1] and use a graph attention network to encode entities and relations. First, we obtain the initial embeddings of entities and relations from a widely used pre-trained embedding model. Then we feed these embeddings into the graph attention network to learn new representations. We learn a new embedding $h_i$ of entity $i$ from triples of the form $t_{ij}^k = (h_i, h_k, h_j)$, where $k$ is the relation linking entity $i$ and entity $j$. A single GAT layer can be described as

$$a_{ijk} = \mathrm{softmax}\big(\mathrm{LeakyReLU}(W_2\, b_{ijk})\big) \qquad (1)$$

$$b_{ijk} = W_1 [h_i \,\|\, h_k \,\|\, h_j] \qquad (2)$$
where $a_{ijk}$ is the attention score of neighbor $j$ under relation $k$, normalized by the softmax over all neighbors $j \in \nu_i$ and relations $k \in \xi_{ij}$ of entity $i$. $W_1$ and $W_2$ are linear transformation matrices: $W_1$ maps the concatenated initial embeddings to a higher-dimensional space, and $W_2$ maps $b_{ijk}$ to a scalar attention logit. $b_{ijk}$ is the embedding of the triple $t_{ij}^k$, and $h_i$, $h_j$, and $h_k$ denote the embeddings of entities $i$ and $j$ and relation $k$, respectively. The attention score reflects the importance of neighbor $j$ for entity $i$. To let the network capture richer neighbor information from various aspects, we use a multi-head attention mechanism to learn the entity embeddings. The output of a layer with a single attention head is

$$h_i = \sigma\Big(\sum_{j \in \nu_i} \sum_{k \in \xi_{ij}} a_{ijk}\, b_{ijk}\Big) \qquad (3)$$
where $\nu_i$ denotes the neighbors of entity $i$ and $\xi_{ij}$ denotes the set of relations between entities $i$ and $j$. Concatenating the outputs of $N$ attention heads gives

$$h_i = \big\|_{n=1}^{N}\, \sigma\Big(\sum_{j \in \nu_i} \sum_{k \in \xi_{ij}} a_{ijk}^{n}\, b_{ijk}^{n}\Big) \qquad (4)$$
where $\|$ denotes concatenation, $\sigma$ is a non-linear function, and $a_{ijk}^n$ is the normalized attention coefficient of the neighbor computed by the $n$-th attention head. In the final layer of the GAT, we average the heads instead of concatenating them to obtain the final entity embeddings:

$$h_i = \sigma\Big(\frac{1}{N}\sum_{n=1}^{N}\sum_{j \in \nu_i} \sum_{k \in \xi_{ij}} a_{ijk}^{n}\, b_{ijk}^{n}\Big) \qquad (5)$$
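To make Eqs. (1)–(5) concrete, the following pure-Python sketch computes one GAT layer for a single entity. It is an illustration under stated assumptions, not the authors' implementation: $W_2$ is taken to be a vector so that $W_2 b_{ijk}$ is a scalar logit, $\sigma$ defaults to tanh, the LeakyReLU slope is 0.2, random weights stand in for learned parameters, and all function names are hypothetical.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]
Neighbour = Tuple[Vector, Vector]  # (relation embedding h_k, entity embedding h_j)


def matvec(w: Matrix, x: Sequence[float]) -> Vector:
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def leaky_relu(x: float, slope: float = 0.2) -> float:  # assumption: slope 0.2
    return x if x > 0 else slope * x


def softmax(xs: Sequence[float]) -> Vector:
    m = max(xs)  # shift by the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def gat_head(h_i: Vector, neighbours: List[Neighbour],
             w1: Matrix, w2: Vector) -> Tuple[Vector, Vector]:
    """One attention head; returns (pre-activation aggregate, attention weights)."""
    # Eq. (2): b_ijk = W1 [h_i || h_k || h_j]  (list + list concatenates)
    b = [matvec(w1, h_i + h_k + h_j) for h_k, h_j in neighbours]
    # Eq. (1): softmax over every (j, k) pair of LeakyReLU(W2 . b_ijk);
    # assumption: W2 is a vector, so W2 . b_ijk is a scalar logit
    a = softmax([leaky_relu(sum(w * x for w, x in zip(w2, b_ijk))) for b_ijk in b])
    # Weighted sum to which sigma is applied in Eqs. (3)-(5)
    agg = [sum(a_n * b_n[d] for a_n, b_n in zip(a, b)) for d in range(len(w1))]
    return agg, a


def gat_layer(h_i: Vector, neighbours: List[Neighbour],
              heads: List[Tuple[Matrix, Vector]], final_layer: bool,
              sigma: Callable[[float], float] = math.tanh) -> Vector:
    """Eq. (4): concatenate per-head outputs; Eq. (5): average heads in the final layer."""
    aggs = [gat_head(h_i, neighbours, w1, w2)[0] for w1, w2 in heads]
    if final_layer:
        mean = [sum(vals) / len(aggs) for vals in zip(*aggs)]
        return [sigma(v) for v in mean]
    return [sigma(v) for agg in aggs for v in agg]


def rand_vec(n: int) -> Vector:
    return [random.uniform(-1.0, 1.0) for _ in range(n)]


random.seed(0)
dim, out_dim, n_heads = 4, 6, 2
h_i = rand_vec(dim)
neighbours = [(rand_vec(dim), rand_vec(dim)) for _ in range(3)]
heads = [([rand_vec(3 * dim) for _ in range(out_dim)], rand_vec(out_dim))
         for _ in range(n_heads)]
hidden = gat_layer(h_i, neighbours, heads, final_layer=False)  # n_heads * out_dim values
final = gat_layer(h_i, neighbours, heads, final_layer=True)    # out_dim values
```

Concatenation (Eq. (4)) multiplies the output width by the number of heads, which is why the final layer averages the heads instead (Eq. (5)) and returns embeddings of a fixed size.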
This process yields the final entity embedding $h_i$. As the encoder of the whole model, the graph attention network aggregates information about the surrounding neighbors into each entity's representation; in brief, an $m$-layer graph attention network gathers information from the $m$-hop neighborhood.

3.3 Decoder and Score
We use the link prediction task to evaluate the effectiveness of our embeddings, with ConvKB as the decoder of GCL-KGE. Multiple filters generate different feature maps that capture global relations and transitional characteristics between entities. The scoring function determines whether a triple $(h, r, t)$ is true:

$$f(t_{ij}^k) = \mathrm{concat}\big(g([h_i, h_k, h_j] * \Omega)\big) \cdot W \qquad (6)$$

where $W$ is a linear transformation (weight) matrix that produces the triple score, $g$ is the activation function, $\Omega$ denotes the set of convolutional filters, and $*$ is the convolution operator.
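A minimal sketch of Eq. (6), assuming ConvKB's convention of 1×3 filters that slide over the k×3 matrix $[h_i, h_k, h_j]$ and ReLU as $g$; bias terms are omitted, and the function name is hypothetical rather than taken from any released code.

```python
from typing import List, Sequence


def convkb_score(h: Sequence[float], r: Sequence[float], t: Sequence[float],
                 filters: List[Sequence[float]], w: Sequence[float]) -> float:
    """Sketch of Eq. (6): concat(g([h, r, t] * Omega)) . W.
    Assumptions: 1x3 filters, g = ReLU, no bias terms."""
    k = len(h)
    if len(w) != len(filters) * k:
        raise ValueError("W needs one weight per feature-map entry (|Omega| * k)")
    feature_maps: List[float] = []
    for omega in filters:          # one feature map of length k per filter
        for row in range(k):       # slide the 1x3 filter down the k x 3 input [h, r, t]
            z = omega[0] * h[row] + omega[1] * r[row] + omega[2] * t[row]
            feature_maps.append(max(0.0, z))  # g = ReLU
    # concat(...) . W: the feature maps are already concatenated in filter order
    return sum(wi * vi for wi, vi in zip(w, feature_maps))


# The filter [1, 1, -1] computes h + r - t row by row, i.e. a TransE-like feature.
print(convkb_score([1.0, 2.0], [0.5, 1.0], [1.0, 3.0], [[1.0, 1.0, -1.0]], [1.0, 1.0]))
```

With the filter [1, 1, −1], each feature-map entry is the TransE residual $h + r - t$ for one embedding dimension, which illustrates how the convolution can capture transitional characteristics between entities.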
3.4 Contrastive Learning
The encoder-optimized entity representations are scored by the decoder. We observe that some entities have similar representations, which leads the decoder to incorrect predictions. To make the model more sensitive to entity semantics, we adopt contrastive learning: we treat the pre-trained embedding of entity $i$ and its hidden states in the GAT as positive instances for entity $i$. We use $h_i^+$ to denote a positive instance of entity $i$ for each entity $i$ in the set $\nu$. Other entities in the same batch serve as negative instances and form the set $\mu$. Before computing the contrastive loss, we map the entity representations into the same embedding space through a projection head. We adopt the InfoNCE contrastive loss:

$$\mathcal{L}_c = \sum_{i \in \nu} -\log \frac{\exp\big(\mathrm{sim}(h_i, h_i^+)/\tau\big)}{\sum_{j \in \mu} \exp\big(\mathrm{sim}(h_i, h_j)/\tau\big)} \qquad (7)$$
where $\tau$ is a temperature hyperparameter and $\mathrm{sim}(\cdot,\cdot)$ is the similarity function; we use the dot product. As shown in Fig. 1, our model uses the hidden states of the first $m-1$ layers as positive instances of the entities. They are semantically similar to the final output of the graph attention network but contain less noise from distant neighbors.

3.5 Training Objective
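Given a knowledge graph, we train its embeddings with the proposed model. The loss of our framework is

$$\mathcal{L}(h, r, t) = \mathcal{L}_s + \mathcal{L}_c \qquad (8)$$

where $\mathcal{L}_c$ is the contrastive loss introduced above. We train GCL-KGE with the Adam optimizer to minimize the loss; the link prediction loss $\mathcal{L}_s$ uses L2 regularization:

$$\mathcal{L}_s = \sum_{(h,r,t) \in G \cup G'} \log\big(1 + \exp(l_{(h,r,t)} \cdot f(h, r, t))\big) + \frac{\lambda}{2}\|W\|_2^2, \qquad l_{(h,r,t)} = \begin{cases} 1 & \text{for } (h,r,t) \in G \\ -1 & \text{for } (h,r,t) \in G' \end{cases} \qquad (9)$$

where $G'$ is the set of invalid triples obtained by corrupting the valid triples in $G$, and $\lambda$ is the regularization weight.

3.6 Theoretical Analyses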
Inspired by [3], we analyze Eq. (7) to explain how contrastive learning works in GCL-KGE. Contrastive learning provides meaningful gradients that guide the entity embeddings. The gradient of the contrastive loss with respect to entity $i$ is:
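$$\frac{\partial \mathcal{L}_c(h_i)}{\partial h_i} = -\frac{\partial}{\partial h_i}\big(h_i \cdot h_i^+/\tau\big) + \frac{\partial}{\partial h_i}\log \sum_{j \in \mu}\exp(h_i \cdot h_j/\tau) = \frac{1}{\tau}\left(-h_i^+ + \frac{\sum_{j \in \mu} h_j \exp(h_i \cdot h_j/\tau)}{\sum_{j \in \mu}\exp(h_i \cdot h_j/\tau)}\right) \qquad (10)$$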
where $\mathcal{L}_c(h_i)$ is the contrastive loss term of a single entity $i$. Taking a gradient step of size $\gamma$, we can derive how entity $i$ changes when it is updated:

$$h_i^{n+1} = h_i^n - \gamma \frac{\partial \mathcal{L}_c(h_i)}{\partial h_i} = h_i^n + \frac{\gamma}{\tau} h_i^+ - \frac{\gamma}{\tau}\,\frac{\sum_{j \in \mu} h_j \exp(h_i \cdot h_j/\tau)}{\sum_{j \in \mu}\exp(h_i \cdot h_j/\tau)} \qquad (11)$$

in which $h_i^n$ is the representation of entity $i$ at the current step and $h_i^{n+1}$ is its representation at the next step. Eq. (11) shows that $h_i$ tends to move in the direction of $h_i^+$ with weight $\gamma/\tau$, so entities retain their own content while gaining information about their neighbors. This counteracts the excessive weight that redundant information would otherwise receive when noise is introduced. On the other hand, $h_i$ moves away from each $h_j$ with weight $\frac{\gamma}{\tau}\cdot\frac{\exp(h_i \cdot h_j/\tau)}{\sum_{j' \in \mu}\exp(h_i \cdot h_{j'}/\tau)}$. In the shared vector space, this pushes entities apart from other entities with semantically similar representations, and the more similar a negative is, the stronger the push.
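Eqs. (7), (10), and (11) can be checked numerically. The plain-Python sketch below is an illustration under assumptions, not the authors' code: it uses dot-product similarity, treats the projection head as the identity, holds $h_i^+$ and the negatives fixed while differentiating with respect to $h_i$, and uses a hypothetical step size $\gamma = 0.01$; all function names are hypothetical. It compares the closed-form gradient of Eq. (10) with central finite differences of Eq. (7) and applies one update of Eq. (11).

```python
import math
import random
from typing import List, Sequence

Vector = List[float]


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


def softmax_weights(h_i: Vector, negatives: List[Vector], tau: float) -> Vector:
    """exp(h_i . h_j / tau) / sum_j' exp(h_i . h_j' / tau) over the negative set mu."""
    logits = [dot(h_i, h_j) / tau for h_j in negatives]
    m = max(logits)  # shift for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


def infonce_loss(h_i: Vector, h_pos: Vector, negatives: List[Vector], tau: float) -> float:
    """Per-entity term of Eq. (7); assumption: identity projection head, dot-product sim.
    As in Eq. (7), the denominator sums over the negative set mu only."""
    logits = [dot(h_i, h_j) / tau for h_j in negatives]
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return -dot(h_i, h_pos) / tau + log_sum_exp


def infonce_grad(h_i: Vector, h_pos: Vector, negatives: List[Vector], tau: float) -> Vector:
    """Closed form of Eq. (10): (1/tau) * (-h_pos + softmax-weighted sum of negatives)."""
    wts = softmax_weights(h_i, negatives, tau)
    return [(-h_pos[d] + sum(wt * h_j[d] for wt, h_j in zip(wts, negatives))) / tau
            for d in range(len(h_i))]


def update(h_i: Vector, h_pos: Vector, negatives: List[Vector], tau: float,
           gamma: float) -> Vector:
    """Eq. (11): one gradient step of (hypothetical) size gamma."""
    grad = infonce_grad(h_i, h_pos, negatives, tau)
    return [x - gamma * g for x, g in zip(h_i, grad)]


random.seed(1)
dim, tau, eps = 5, 1.0, 1e-6
h_i = [random.gauss(0.0, 1.0) for _ in range(dim)]
h_pos = [random.gauss(0.0, 1.0) for _ in range(dim)]
negatives = [[random.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(4)]

analytic = infonce_grad(h_i, h_pos, negatives, tau)
numeric = []
for d in range(dim):  # central finite differences of Eq. (7)
    plus, minus = h_i[:], h_i[:]
    plus[d] += eps
    minus[d] -= eps
    numeric.append((infonce_loss(plus, h_pos, negatives, tau)
                    - infonce_loss(minus, h_pos, negatives, tau)) / (2 * eps))
h_next = update(h_i, h_pos, negatives, tau, gamma=0.01)
print(max(abs(a - n) for a, n in zip(analytic, numeric)))
print(infonce_loss(h_i, h_pos, negatives, tau), infonce_loss(h_next, h_pos, negatives, tau))
```

The function `softmax_weights` returns exactly the per-negative weights (before the $\gamma/\tau$ factor) with which Eq. (11) pushes $h_i$ away from each $h_j$, so the negatives most similar to $h_i$ receive the strongest push.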
4 Experiments

To evaluate the validity of our model and the usefulness of contrastive learning, we conduct a series of experiments, described in this section.

4.1 Experimental Setup
Datasets. We evaluate our model on four knowledge graph datasets: WN18RR [15], FB15k-237 [16], NELL-995 [17], and Alyawarra Kinship [18]. Table 1 summarizes their statistics. Note that WN18RR and FB15k-237 are derived from WN18 [6] and FB15k [19], respectively. Previous work [16] has shown that WN18 and FB15k contain inverse relations, so many test triples can be inferred simply by inverting training triples; this test leakage causes models to overfit. WN18RR and FB15k-237 were therefore created as subsets that remove this problem.

Parameter Settings. We use a two-layer graph attention network with 2 attention heads and set the dimension of the final-layer entity and relation embeddings to 200. We use the Adam optimizer with a learning rate of 0.001. After tuning, the best value of the temperature τ is 1.0.
Table 1. Statistics of the datasets.

| Dataset | Entities | Relations | Train | Valid | Test | Total |
|---|---|---|---|---|---|---|
| WN18RR | 40,943 | 11 | 86,835 | 3,034 | 3,134 | 93,003 |
| FB15K-237 | 14,541 | 237 | 272,115 | 17,535 | 20,466 | 310,116 |
| NELL-995 | 75,492 | 200 | 149,678 | 543 | 3,992 | 154,213 |
| Kinship | 104 | 25 | 8,544 | 1,068 | 1,074 | 10,686 |
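The split sizes in Table 1 can be checked with simple arithmetic: for every dataset, the training, validation, and test triples sum to the reported total. The short script below (a sketch; the function name is hypothetical) performs this check and prints each split's share, which shows, for example, that NELL-995 has a much smaller validation split than the other datasets.

```python
from typing import Dict, Tuple

# (train, valid, test, total) transcribed from Table 1.
TABLE_1: Dict[str, Tuple[int, int, int, int]] = {
    "WN18RR": (86_835, 3_034, 3_134, 93_003),
    "FB15K-237": (272_115, 17_535, 20_466, 310_116),
    "NELL-995": (149_678, 543, 3_992, 154_213),
    "Kinship": (8_544, 1_068, 1_074, 10_686),
}


def split_shares(train: int, valid: int, test: int, total: int) -> Tuple[float, float, float]:
    """Check that the splits add up to the total and return each split's share."""
    if train + valid + test != total:
        raise ValueError("train + valid + test does not match the reported total")
    return train / total, valid / total, test / total


for name, row in TABLE_1.items():
    shares = split_shares(*row)
    print(f"{name:<10} " + "  ".join(f"{s:.1%}" for s in shares))
```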
Compared Methods. The baselines we choose are among the most widely used knowledge graph embedding models: DistMult [20], ComplEx [14], ConvE [16], TransE [6], ConvKB [5], R-GCN [8], ATTH [21], and KGE-CL [3].

4.2 Main Results
We use MR, MRR, and Hit@n as evaluation metrics. MR (Mean Rank) is the average rank of all correct entities. MRR (Mean Reciprocal Rank) is the average of the reciprocal ranks of all correct entities. Hit@n is the proportion of correct entities ranked in the top n; we report Hit@1, Hit@3, and Hit@10.

Table 2. Link prediction results on WN18RR and NELL-995 datasets. We bold the best score in the table.

| Methods | WN18RR MRR | WN18RR Hit@1 | WN18RR Hit@3 | WN18RR Hit@10 | NELL-995 MRR | NELL-995 Hit@1 | NELL-995 Hit@3 | NELL-995 Hit@10 |
|---|---|---|---|---|---|---|---|---|
| DistMult | 0.444 | 0.412 | 0.47 | 0.504 | 0.485 | 0.401 | 0.524 | 0.61 |
| ComplEx | 0.449 | 0.409 | 0.469 | 0.53 | 0.482 | 0.399 | 0.528 | 0.606 |
| ConvE | 0.456 | 0.419 | 0.47 | 0.531 | 0.491 | 0.403 | 0.531 | 0.613 |
| TransE | 0.243 | 0.427 | 0.441 | 0.532 | 0.401 | 0.344 | 0.472 | 0.501 |
| ConvKB | 0.265 | 0.582 | 0.445 | 0.558 | 0.43 | 0.37 | 0.47 | 0.545 |
| R-GCN | 0.123 | 0.207 | 0.137 | 0.08 | 0.12 | 0.082 | 0.126 | 0.188 |
| ATTH | 0.466 | 0.419 | 0.484 | 0.551 | – | – | – | – |
| KGE-CL | 0.512 | 0.468 | 0.531 | 0.597 | – | – | – | – |
| Our method | 0.522 | 0.477 | 0.49 | 0.581 | 0.541 | 0.456 | 0.596 | 0.698 |
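The metrics defined above, combined with the filtered ranking protocol used in the evaluation, can be sketched as follows. This is a hypothetical helper, not the authors' evaluation code. It assumes that a higher score means a more plausible triple (for a score such as $f$ in Eq. (9), where valid triples are driven towards lower values, the scores would be negated first) and resolves ties optimistically in favour of the correct entity.

```python
from typing import Dict, List, Sequence, Set


def filtered_rank(scores: Sequence[float], target: int, known_true: Set[int]) -> int:
    """1-based rank of `target` among all candidate entities (higher score = better),
    skipping other candidates that also form known true triples (filtered setting).
    Assumption: ties are resolved in favour of the target."""
    target_score = scores[target]
    better = sum(1 for e, s in enumerate(scores)
                 if e != target and e not in known_true and s > target_score)
    return better + 1


def link_prediction_metrics(ranks: List[int],
                            ns: Sequence[int] = (1, 3, 10)) -> Dict[str, float]:
    """MR, MRR and Hit@n from the ranks of the correct entities."""
    metrics = {
        "MR": sum(ranks) / len(ranks),
        "MRR": sum(1.0 / r for r in ranks) / len(ranks),
    }
    for n in ns:
        metrics[f"Hit@{n}"] = sum(1 for r in ranks if r <= n) / len(ranks)
    return metrics


# Candidate 3 outscores the target but is another known true tail, so it is filtered out.
scores = [0.9, 0.8, 0.3, 0.95]
print(filtered_rank(scores, target=1, known_true={3}))    # filtered rank
print(filtered_rank(scores, target=1, known_true=set()))  # raw rank
print(link_prediction_metrics([1, 2, 4, 20]))
```

Filtering matters because a model should not be penalized for ranking another correct answer above the test entity; without it, the example target drops from rank 2 to rank 3.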
Following previous work, we evaluate our model in the filtered setting, i.e., corrupted triples that already exist in the dataset are removed before ranking. Table 2 and Table 3 present the link prediction results on all datasets. Our method is effective on all four datasets. Specifically, we achieve the best results on the FB15K-237, NELL-995, and Kinship datasets, and results comparable to the best on WN18RR. Compared with the tensor decomposition-based and distance-based models, our model preserves the structural information of the knowledge graph as well as its semantic information: it aggregates the neighborhood information around each entity to generate meaningful entity representations. At the same time, compared with R-GCN, which also uses a graph neural network, our model incorporates contrastive learning as an auxiliary task that keeps noisy information out while improving the effectiveness of the entity embeddings.

Table 3. Link prediction results on FB15K-237 and Kinship datasets. We bold the best score in the table.

| Methods | FB15K-237 MRR | FB15K-237 Hit@1 | FB15K-237 Hit@3 | FB15K-237 Hit@10 | Kinship MRR | Kinship Hit@1 | Kinship Hit@3 | Kinship Hit@10 |
|---|---|---|---|---|---|---|---|---|
| DistMult | 0.281 | 0.199 | 0.301 | 0.446 | 0.516 | 0.367 | 0.581 | 0.867 |
| ComplEx | 0.278 | 0.194 | 0.297 | 0.45 | 0.823 | 0.733 | 0.899 | 0.971 |
| ConvE | 0.312 | 0.225 | 0.341 | 0.497 | 0.833 | 0.738 | 0.917 | 0.981 |
| TransE | 0.279 | 0.198 | 0.376 | 0.441 | 0.309 | 0.9 | 0.643 | 0.841 |
| ConvKB | 0.289 | 0.198 | 0.324 | 0.471 | 0.614 | 0.44 | 0.755 | 0.953 |
| R-GCN | 0.164 | 0.1 | 0.181 | 0.3 | 0.109 | 0.03 | 0.088 | 0.239 |
| ATTH | 0.324 | 0.236 | 0.354 | 0.501 | – | – | – | – |
| KGE-CL | 0.37 | 0.276 | 0.408 | 0.56 | – | – | – | – |
| Our method | 0.513 | 0.435 | 0.551 | 0.657 | 0.907 | 0.878 | 0.941 | 0.98 |

4.3 Ablation Experiments
Effect of Hyperparameter. We conduct ablation experiments to evaluate the effect of parameter variations on the model. The results confirm the importance of the temperature τ in the contrastive loss for model performance: as Table 4 shows, the model achieves the highest accuracy when τ = 1.0. The smaller the temperature, the more attention is paid to the hard negative instances in the same batch; a few negative instances then dominate the gradient optimization, and the other negatives contribute little. We also vary the number of layers in the projection head of the contrastive module; the model works best with 2 layers.

Table 4. MRR, Hit@1, Hit@3, and Hit@10 results for different values of τ and numbers of projection layers on the Kinship dataset.

| Methods | MRR | Hit@1 | Hit@3 | Hit@10 |
|---|---|---|---|---|
| τ = 0.01 | 0.886 | 0.845 | 0.909 | 0.967 |
| τ = 0.1 | 0.889 | 0.849 | 0.915 | 0.969 |
| τ = 0.5 | 0.891 | 0.878 | 0.936 | 0.977 |
| τ = 1.0 | 0.907 | 0.878 | 0.941 | 0.98 |
| Projection-1 layer | 0.863 | 0.814 | 0.892 | 0.958 |
| Projection-2 layer | 0.907 | 0.878 | 0.941 | 0.98 |
| Projection-3 layer | 0.875 | 0.818 | 0.91 | 0.96 |
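The temperature effect described above can be seen directly in the weights of Eq. (10), where each in-batch negative $h_j$ is weighted by a softmax of $h_i \cdot h_j/\tau$. The sketch below uses hypothetical similarity values (not data from the paper) to show how the weight concentrates on the hardest negative as τ decreases, for the τ values examined in Table 4; the function name is hypothetical.

```python
import math
from typing import List, Sequence


def negative_weights(sims: Sequence[float], tau: float) -> List[float]:
    """Softmax of similarity / tau over in-batch negatives (the weights in Eq. (10))."""
    logits = [s / tau for s in sims]
    m = max(logits)  # shift for numerical stability
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]


sims = [0.9, 0.5, 0.1, -0.3]       # hypothetical h_i . h_j values for four negatives
for tau in (0.01, 0.1, 0.5, 1.0):  # the temperatures examined in Table 4
    weights = negative_weights(sims, tau)
    print(f"tau={tau:<5} " + " ".join(f"{w:.3f}" for w in weights))
```

At τ = 0.01 almost all of the weight falls on the single most similar negative, whereas at τ = 1.0 the four negatives share it more evenly, consistent with the explanation that a few hard negatives dominate the gradient at low temperature.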
Effect of Contrastive Loss. We conduct an ablation study on contrastive learning by removing the contrastive loss from the encoder-decoder framework; we call the resulting model G-KGE. Figure 2 compares the five metrics over training epochs on the Kinship dataset. From the five panels, we observe that GCL-KGE clearly outperforms G-KGE, which lacks the auxiliary contrastive task, on four metrics (Hit@10, Hit@3, Hit@1, and MRR). Panel (e) shows that GCL-KGE also achieves a lower (i.e., better) MR than G-KGE. These figures illustrate the improvement that contrastive learning brings to GCL-KGE and the effectiveness of our choice of positive and negative instances.
Fig. 2. Hit@10, Hit@3, Hit@1, MRR, and MR vs. epoch for GCL-KGE and the model without contrastive loss (G-KGE) on the Kinship dataset. GCL-KGE (black) represents the entire model.
5 Conclusion
In this work, we propose a knowledge graph embedding model combined with contrastive learning. We train the representations of entities and relations with graph attention networks, which aggregate graph structure information and multi-order neighbor semantic information. The triple scoring function of ConvKB is then used as the decoder for the link prediction task. In addition, we use contrastive learning as an auxiliary task to mitigate the noise introduced by graph attention networks. We propose a new method to construct positive instances that requires no data augmentation and instead makes use of the entity embeddings in the hidden layers. Experimental results on four datasets demonstrate the effectiveness of our model. In future work, we expect contrastive learning to be applied more widely to knowledge graph embedding, since many studies have shown it to be helpful in representation learning. We hope that advances in self-supervised learning will help address the sparsity of knowledge graphs and improve the generality and transferability of knowledge graph embedding models.

Acknowledgements. This work was supported in part by the Key R&D Project of Shandong Province (2019JZZY010129), and in part by the Shandong Provincial Social Science Planning Project under Award 19BJCJ51, Award 18CXWJ01, and Award 18BJYJ04.
References
1. Nathani, D., Chauhan, J., Sharma, C., Kaul, M.: Learning attention-based embeddings for relation prediction in knowledge graphs. In: Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pp. 4710–4723 (2019)
2. Veličković, P., Cucurull, G., Casanova, A., Romero, A., Liò, P., Bengio, Y.: Graph attention networks. In: International Conference on Learning Representations (2018)
3. Xu, W., Luo, Z., Liu, W., Bian, J., Yin, J., Liu, T.Y.: KGE-CL: contrastive learning of knowledge graph embeddings. arXiv e-prints, arXiv-2112 (2021)
4. Zhu, Y., Xu, Y., Yu, F., Liu, Q., Wu, S., Wang, L.: Graph contrastive learning with adaptive augmentation. In: Proceedings of the Web Conference 2021, pp. 2069–2080 (2021)
5. Nguyen, D.Q., Nguyen, T.D., Nguyen, D.Q., Phung, D.: A novel embedding model for knowledge base completion based on convolutional neural network. In: Proceedings of NAACL-HLT, pp. 327–333 (2018)
6. Bordes, A., Usunier, N., Garcia-Duran, A., Weston, J., Yakhnenko, O.: Translating embeddings for modeling multi-relational data. In: Advances in Neural Information Processing Systems, vol. 26 (2013)
7. Nickel, M., Tresp, V., Kriegel, H.P.: A three-way model for collective learning on multi-relational data. In: ICML (2011)
8. Schlichtkrull, M.S., Kipf, T.N., Bloem, P., van den Berg, R., Titov, I., Welling, M.: Modeling relational data with graph convolutional networks. In: ESWC (2018)
9. Shang, C., Tang, Y., Huang, J., Bi, J., He, X., Zhou, B.: End-to-end structure-aware convolutional networks for knowledge base completion. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 3060–3067 (2019)
10. Chen, M., Wei, F., Li, C., Cai, D.: Frame-wise action representations for long videos via sequence contrastive learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13801–13810 (2022)
11. Chen, X., He, K.: Exploring simple Siamese representation learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 15750–15758 (2021)
12. Meng, Y., Xiong, C., Bajaj, P., Bennett, P., Han, J., Song, X.: COCO-LM: correcting and contrasting text sequences for language model pretraining. In: Advances in Neural Information Processing Systems, vol. 34 (2021)
13. Wu, Z., Wang, S., Gu, J., Khabsa, M., Sun, F., Ma, H.: CLEAR: contrastive learning for sentence representation. arXiv preprint arXiv:2012.15466 (2020)
14. Trouillon, T., Welbl, J., Riedel, S., Gaussier, É., Bouchard, G.: Complex embeddings for simple link prediction. In: International Conference on Machine Learning, pp. 2071–2080. PMLR (2016)
15. Toutanova, K.: Observed versus latent features for knowledge base and text inference. ACL-IJCNLP 2015, 57 (2015)
16. Dettmers, T., Minervini, P., Stenetorp, P., Riedel, S.: Convolutional 2D knowledge graph embeddings. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 32, no. 1 (2018)
17. Xiong, W., Hoang, T., Wang, W.Y.: DeepPath: a reinforcement learning method for knowledge graph reasoning. In: Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing, pp. 564–573 (2017)
18. Lin, X.V., Socher, R., Xiong, C.: Multi-hop knowledge graph reasoning with reward shaping. In: EMNLP (2018)
19. Bordes, A., Weston, J., Collobert, R., Bengio, Y.: Learning structured embeddings of knowledge bases. In: Twenty-fifth AAAI Conference on Artificial Intelligence (2011)
20. Yang, B., Yih, S.W.T., He, X., Gao, J., Deng, L.: Embedding entities and relations for learning and inference in knowledge bases. In: ICLR (2015)
21. Chami, I., Wolf, A., Juan, D.C., Sala, F., Ravi, S., Ré, C.: Low-dimensional hyperbolic knowledge graph embeddings. In: Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pp. 6901–6914 (2020)
22. Lin, Y., Liu, Z., Sun, M., Liu, Y., Zhu, X.: Learning entity and relation embeddings for knowledge graph completion. In: Twenty-ninth AAAI Conference on Artificial Intelligence (2015)
23. Ji, G., He, S., Xu, L., Liu, K., Zhao, J.: Knowledge graph embedding via dynamic mapping matrix. In: ACL (2015)
Towards a Unified Benchmark for Reinforcement Learning in Sparse Reward Environments
Yongxin Kang1,2, Enmin Zhao1,2, Yifan Zang1,2, Kai Li1,2, and Junliang Xing1,3(B)
1 School of Artificial Intelligence, University of Chinese Academy of Sciences, Beijing, China
[email protected]
2 Institute of Automation, Chinese Academy of Sciences, Beijing, China
3 Department of Computer Science and Technology, Tsinghua University, Beijing, China
Abstract. Reinforcement learning in sparse reward environments is challenging and has recently received increasing attention, with dozens of new algorithms proposed every year. Despite the promising results demonstrated in various sparse reward environments, this domain lacks both a unified definition of a sparse reward environment and an experimentally fair way to compare existing algorithms. These issues significantly hinder in-depth analysis of the underlying problem and further study. This paper proposes a benchmark to unify the selection of environments and the comparison of algorithms. We first define sparsity as the proportion of rewarded states in the entire state space and select environments by this sparsity. Inspired by the sparsity concept, we categorize existing algorithms into two classes. To provide a fair comparison of different algorithms, we propose a new metric along with a standard protocol for performance evaluation. Preliminary experimental evaluations of seven algorithms in ten environments serve as a quick-start guide to the proposed benchmark. We hope the proposed benchmark will promote research on reinforcement learning algorithms in sparse reward environments. The source code of this work is published at https://github.com/simayuhe/ICONIP Benchmark.git.
Keywords: Reinforcement learning · Sparse-reward environments · Exploration · Exploitation
This work was supported in part by the National Key Research and Development Program of China under Grant No. 2020AAA0103401, in part by the Natural Science Foundation of China under Grant Nos. 62076238 and 61902402, in part by the CCF-Tencent Open Fund, and in part by the Strategic Priority Research Program of Chinese Academy of Sciences under Grant No. XDA27000000.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 189–201, 2023.
https://doi.org/10.1007/978-981-99-1639-9_16

Sparse reward environments in Reinforcement Learning (RL) are of great value because they represent many real-life situations in which no immediate feedback is available. They are more challenging than ordinary RL environments, because connecting a long sequence of actions to a distant reward is difficult. For example, Montezuma’s Revenge is a notoriously sparse reward game: the agent must navigate many rooms filled with traps and few rewards [4]. Traditional RL algorithms, such as DQN [17] and A3C [16], fail to complete the tasks in these environments. Although many new algorithms have been proposed, this domain lacks a unified benchmark for selecting environments and comparing algorithms.
Generally, algorithms are inconsistent in how they describe their validation environments. Although many algorithms claim to verify their ideas in sparse reward environments, their definitions and descriptions of the environments’ reward sparsity vary greatly. ECR [25] uses “sparse”, “very sparse”, and “dense” to describe different environments in its experiments. The authors of UnifyCBEIM [3] and SIL [19] use “easy” and “hard” to classify Atari games. What exactly is a genuinely sparse reward environment? Which environment is more “sparse”? To answer these questions and carry out fair comparisons, the reward sparsity of environments needs to be defined.
Furthermore, existing algorithms choose different evaluation metrics, which makes a fair comparison between algorithms difficult. ICM [21] measures the distance covered by the agent when executing its policy, while EMI [14] and SIL use mean rewards to illustrate their performance. UnifyCBEIM even uses the increase in the number of rooms the agent has visited to demonstrate its effectiveness. Each of these criteria focuses on only one aspect of the problem and is neither comprehensive nor generic. Thus, a reasonable evaluation metric for these algorithms is needed.
Another limitation to the study of sparse reward problems is the non-uniformity of algorithm implementations. For example, ICM uses 42 × 42 grayscale images as input and an LSTM module as the final layer of its policy network.
RND [5] downsamples the input into an 84 × 84 array and uses a GRU module to build its RNN policy. We cannot determine whether an improvement in the final result is caused by these setting details or by the algorithm itself. A unified implementation will facilitate the comparison of new algorithms.
To address these problems, we propose a benchmark. We first define the sparsity of the reward to help select environments. Sparsity quantitatively describes the proportion of rewarded states in the entire state space. In a non-sparse environment, a reward can be obtained at every state, whereas in an environment with sparse rewards, the average number of rewards obtained per state is close to zero. This definition is independent of the environment type and the reward magnitude, so all environments in our benchmark can be compared regardless of their game types, which in turn guides the choice of proper environments for testing algorithms.
After quantitatively evaluating reward sparsity, we re-classify existing algorithms according to it. Some algorithms are designed for exceptionally sparse environments, in which they usually remain in a no-reward exploration stage. Others are helpful in environments with moderate sparsity, in which they pay more attention to exploiting occasionally obtained reward signals in a local-reward exploitation stage. We classify the
existing algorithms according to the stages they are concerned with and summarize the design principles within each category.
We also propose a general metric, the Average Reward Times per Step (ARTS), to measure the performance of these algorithms. Like the reward sparsity of an environment, ARTS is not affected by the magnitude of each reward an algorithm obtains. It helps compare algorithm performance across environments and is a necessary part of our benchmark. Another essential part of the benchmark is a unified implementation of the algorithms. We implement typical algorithms from different categories with the same environment wrapper and evaluation protocol to enable a fair analysis. These unified implementations will be open-sourced as a toolkit to facilitate further research in sparse reward environments.
To summarize, our benchmark makes three main original contributions:
– We define a new sparsity measure to quantify the core attributes of different sparse-reward environments used for RL evaluation. This measure helps identify and quantify the difficulty of environments and makes it easier to select typical ones for better RL evaluation.
– We provide a new perspective on RL in sparse-reward environments as an iteration of two stages: a no-reward exploration stage and a local-reward exploitation stage. This perspective yields a new criterion for classifying existing algorithms into two separate groups.
– We propose an evaluation metric, ARTS, to report the performance of algorithms more accurately, and a unified evaluation protocol to make fair algorithm comparisons.
Our benchmark includes ten environments covering various sparsity ranges, seven representative methods in different categories, a fairer evaluation criterion, and a unified evaluation protocol for these methods.
Our experimental results verify the correlation between sparsity and the performance of different classes of algorithms under a unified implementation and the new evaluation metric. The benchmark analysis can serve as a reference for subsequent studies.
1 Related Work
Before the introduction of deep RL (DRL) in 2015 [17], researchers had already noticed that reward signals alone are not enough to solve learning problems in sparse reward environments. They approached the problem from different perspectives: reduction of uncertainty [29], behavior cloning [24], Bayesian surprise [13], intrinsic motivation [28], and inverse RL [18]. With the development of deep learning techniques, many algorithms have adapted these original ideas to video input. For example, UnifyCBEIM [3], ICM [21], and ECR [25] use a trained model to estimate uncertainty in 3D video. DQfD [12] and GAIL [10] leverage large amounts of data to extend the ideas of behavior cloning and inverse RL. HashCount [31], PixelCNN [20], and RND [5] employ neural networks to construct intrinsic reward
or curiosity. There is a growing number of algorithms for sparse reward environments, such as EMI [14], SIL [19], and BeBold [33]. However, a classification of these methods and a benchmark for comparing them are urgently needed. There are many works on unified implementations of RL algorithms and on building computational platforms, such as OpenAI Baselines [8] and Rainbow [11], but they are not designed for sparse reward environments. Taiga et al. [30] analyze the exploration process of several algorithms in Montezuma’s Revenge; their results suggest that recent gains may be better attributed to architectural changes than to better exploration schemes. We provide a unified implementation of different algorithms starting from the sparsity of the environment and analyze in depth the similarities and differences of each class of algorithms. Our benchmark also provides recommendations for using algorithms in environments with various degrees of sparsity.
2 Building the Benchmark
2.1 Quantify Sparse Reward Environments with Sparsity
To quantitatively describe whether an environment is sparse, we introduce a formal definition of the reward sparsity of an environment.
Definition 1 (Sparsity). The sparsity of reward function r(s) on state space S is defined as follows:
$$f_{sp}(r(s)) = \frac{m[S_{r(s) \neq 0}]}{m[S]}, \qquad (1)$$
where $s \in S$, $S_{r(s) \neq 0} = \{s \in S : r(s) \neq 0\}$ is the set of rewarded states, and $m[\cdot]$ is a measure on sets. When S is finite and known, $m[\cdot]$ is simply the size of the set. When the state space is infinite or unknown, $m[\cdot]$ can be estimated by sampling. With this definition, we can describe the reward distribution over the state space quantitatively and objectively instead of using ambiguous words like “sparse” or “very sparse”. The definition is independent of the environment type and the reward magnitude, so it measures different environments on the same scale. In some discrete environments, the sparsity relates to how many steps, on average, a random agent needs between rewards. For example, in MsPacman we get a small reward every four steps if we move forward into a new area, so the upper bound of $f_{sp}(r(s))$ in MsPacman is 0.25. In Montezuma’s Revenge, however, agents do not always get a reward even after four hundred steps, so the sparsity is almost zero. We therefore measure each environment by its sparsity, which provides a quantitative basis for comparing them. A sparsity close to one means that almost every step in the environment is rewarded; a sparsity close to zero means that the agent receives little reward from random exploration. In the experimental section, we show how the performance of various algorithms differs across environments with different sparsity.
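When the state space is too large to enumerate, Definition 1 can be estimated by sampling. The sketch below is a minimal illustration, assuming each sampled step stands in for one visited state, so the estimate is the fraction of sampled steps with a nonzero reward (the proportion of rewarded states described in the abstract and the MsPacman example). The function name `estimate_sparsity` and the two reward streams are hypothetical and are not taken from the benchmark code.

```python
from typing import Iterable


def estimate_sparsity(rewards: Iterable[float]) -> float:
    """Monte Carlo estimate of f_sp(r(s)) from rewards sampled along trajectories.

    Each sampled step stands in for one visited state; the estimate is the
    fraction of sampled steps whose reward is nonzero.
    """
    total = 0
    rewarded = 0
    for r in rewards:
        total += 1
        if r != 0:
            rewarded += 1
    if total == 0:
        raise ValueError("need at least one sampled step")
    return rewarded / total


# Hypothetical stream with a reward every fourth step (the MsPacman upper bound).
dense_stream = [10 if t % 4 == 3 else 0 for t in range(400)]
print(estimate_sparsity(dense_stream))   # 0.25

# Hypothetical stream with a single reward in 400 steps.
sparse_stream = [0] * 399 + [100]
print(estimate_sparsity(sparse_stream))  # 0.0025
```

Only whether a reward is nonzero matters, not its magnitude, which is why the measure is independent of the reward scale.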
2.2 Sparsity Inspired RL Algorithm Categorization
Inspired by the definition of sparsity, we propose a new classification that takes into account the sparsity of the environments each algorithm targets. We find that many algorithms are designed for very sparse environments, whereas others are designed for environments that are not as sparse. Learning in sparse reward environments usually involves iterating between two stages: a no-reward exploration stage and a local-reward exploitation stage. As the proportion of rewarded states in the environment decreases, algorithms designed for that type of environment focus more on the no-reward exploration stage. We therefore categorize existing algorithms according to the stage they focus on during learning. The first kind is No-reward Exploration Methods, such as Curiosity [3], Count-based [29], and Intrinsic reward [31]. The second is Local-reward Exploitation Methods, such as SIL [19] and DTSIL [9]. Grouping existing algorithms into these categories helps us understand them better.
No-Reward Exploration Methods. In the no-reward exploration stage, researchers have proposed various reward-unrelated bonuses to encourage the agent to explore a larger part of the state space as quickly as possible. These bonuses usually come from human understanding of the environment, which may concern particular states or specific trajectories. Some common methods are listed in Table 1.
Table 1. No-reward Exploration Methods
| Reward-unrelated bonus about state | Reward-unrelated bonus about trajectory |
|---|---|
| Curiosity [3], Count-based [29], ICM [21], RND [5], RenYi [32], PixelCNN [20], HashCount [31] | ECR [25], Plan2Explore [27], EMI [14], Empowerment [15], RVIC [2], BeBold [33] |
Local-Reward Exploitation Methods. In the no-reward exploration stage, the agent has to make many attempts. In the second stage, by contrast, many methods study how to use these attempts effectively to form skills and go farther, because the reward information obtained through them is precious. We categorize local-reward exploitation methods into two subclasses by the source of the rewarded trajectories (Table 2).
Table 2. Local-reward Exploitation Methods
| Reward-related methods with expert’s data | Reward-related methods with agent’s data |
|---|---|
| Imitation learning [22], DQfD [12], Ape-X DQfD [23], LearnFromYouTube [1] | SIL [19], ESIL [7], DTSIL [9], SILSP [6], GASIL [10] |
Conclusion of Classification. No-reward exploration methods use artificially designed rules to encourage agents to find states that humans consider more conducive to exploring the space. In practice, they construct reward-unrelated bonuses, about either states or trajectories, which serve as a temporary reward signal. At each time step, the agent is trained with the reward $r_t = R + \beta \cdot b_{ru}$, where $R$ is the reward provided by the environment, $\beta$ is a scaling parameter, and $b_{ru}$ is the reward-unrelated bonus constructed by each method. In the no-reward exploration stage, although $R$ is unavailable, the agent can still complete the exploration task as humans intend under the guidance of $b_{ru}$.
In reward-related exploitation methods, the agent’s task is to use rewarded experience more efficiently and to turn a specific successful attempt into a stable policy as soon as possible. These methods additionally maintain a reward-related buffer $D_{rr}$ of successful trajectories and use the data in $D_{rr}$ to update a specific loss $L_{rr}$ that guides the agent’s training. The additional training data and loss function make the agent more likely to learn a rewarding policy when updating the parameters of $Q(s, a; \theta_q)$ or $\pi(a|s; \theta_\pi)$.
2.3 ARTS: A New Evaluation Metric via Sparsity
Existing algorithms use a variety of evaluation criteria, which hinders comparative analysis among them. For example, they typically report the score, remaining lives, exploration range, or time spent. Most of these metrics are environment-dependent, which makes it impossible to compare an algorithm’s results across environments. Even in the same environment, comparing two algorithms that use different metrics is unfair. Since we already measure environments by sparsity, we design a corresponding metric to evaluate algorithms. The score (Sc) is an evaluation criterion used in some papers, but it is not entirely reasonable. First, most methods train with a normalized reward, such as rewards in {1, −1, 0}, but evaluate the results with the real, unnormalized rewards, so training and testing disagree if the score is the only criterion. Second, in the same game,
Fig. 1. In Montezuma’s Revenge, after exploring the first room (b), whether the agent opens the left door (a) or the right door (c), it passes through two rooms of the same difficulty and collects the same number of rewards, but the final score differs.
under the same difficulty, the final results can differ because of different initial actions, and the algorithm cannot distinguish between these cases during training. For example, in Montezuma’s Revenge (Fig. 1), after the agent gets the key, opening the left door is as difficult as opening the right door, but the four rewards obtained after opening the left door are {100, 300, 1000, 1000}, whereas those obtained after opening the right door are {100, 300, 100, 3000}. The two cases look the same during training (both {+1, +1, +1, +1}), but the evaluation results differ (2400 and 3500), so evaluating by scores alone is unfair. For a more comprehensive analysis, we record the following criteria during the experiments:
– Episode Scores Sc. The score is the most commonly used criterion, but its scale varies significantly across games, so it is more suitable for comparisons within the same game.
– Episode Reward Times Tim, the number of times a reward is obtained in an episode. Like sparsity, Tim is a fairer criterion across games, especially since rewards are clipped during training.
– Episode Length Len. An auxiliary indicator of how long an agent survives.
We use the Average Reward Times per Step, $\mathrm{ARTS} = Tim / Len$, as a new evaluation metric. ARTS gives a more comprehensive representation of the experimental results and corresponds well with sparsity: the ability of an algorithm is described as its ability to obtain the maximum number of positive rewards with the minimum number of steps in an environment of a given sparsity. With these evaluation metrics, we can compare the improvement of algorithms across different environments, instead of only in a few selected environments as in previous work.
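The bookkeeping behind Sc, Tim, Len and ARTS can be made concrete with the Montezuma’s Revenge example above. This is a minimal sketch, assuming Tim counts the steps with a positive reward (the steps a sign-clipped training reward turns into +1) and assuming, purely for illustration, that both episodes last 400 steps; the helper names `episode_criteria` and `make_episode` are hypothetical.

```python
from typing import Dict, List, Sequence


def episode_criteria(raw_rewards: Sequence[float]) -> Dict[str, float]:
    """Compute Sc, Tim, Len and ARTS for one episode of unclipped rewards.

    Tim counts the steps whose reward is positive, i.e. the steps that a
    sign-clipped training reward would turn into +1.
    """
    score = float(sum(raw_rewards))
    reward_times = sum(1 for r in raw_rewards if r > 0)
    length = len(raw_rewards)
    return {"Sc": score, "Tim": reward_times, "Len": length, "ARTS": reward_times / length}


def make_episode(rewards_in_order: List[float], length: int) -> List[float]:
    """Hypothetical episode: spread the given rewards evenly over `length` steps."""
    steps = [0.0] * length
    for i, r in enumerate(rewards_in_order):
        steps[(i + 1) * length // (len(rewards_in_order) + 1)] = r
    return steps


# Fig. 1 example; the 400-step episode length is an assumption.
left = episode_criteria(make_episode([100, 300, 1000, 1000], length=400))
right = episode_criteria(make_episode([100, 300, 100, 3000], length=400))
print(left["Sc"], right["Sc"])      # 2400.0 3500.0
print(left["Tim"], right["Tim"])    # 4 4
print(left["ARTS"], right["ARTS"])  # 0.01 0.01
```

The two doors give different Sc but identical Tim and ARTS, which mirrors what the agent actually experiences under clipped training rewards.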
2.4 Benchmark Construction
We construct a benchmark that unifies sparsity, algorithm classification, and the new evaluation metric. Our benchmark uses a unified wrapper to handle environments covering most of the sparsity range and applies the same protocol to implement typical algorithms of each class. It provides a simple and practical toolkit for comparing algorithms. Below we detail the principles for environment selection, the wrapping and preprocessing details, the baseline algorithms, and the evaluation protocol.
Environment Selection. We experiment on ten games: Venture (Ven), Montezuma’s Revenge (MR), Amidar (Amd), Gravitar (Grav), BankHeist (Ban), Breakout (Br), Hero (Hero), Alien (Alien), Pong (Pong), and MsPacMan (MsP). The selection follows two rules. First, each game often appears as a validation environment for algorithms addressing sparse reward problems. Second, the selected games together cover most of the sparsity range. The former enables accurate replication of the original papers, while the latter facilitates the analysis of algorithm properties in different environments.
Environments Wrapper. As noted in many works, working directly with raw game frames, such as Atari frames, which are 210 × 160 pixel images with a 128-color palette, can be demanding in computation and memory. Many methods therefore apply basic preprocessing steps to reduce the input dimensionality and deal with some emulator artifacts. To ensure that all algorithms receive the same input in all environments, we apply the following preprocessing, mostly taken from OpenAI Baselines [8] and DQN [17]: (a) warp frames to 84 × 84 images; (b) stack the last 4 frames as a single observation; (c) clip rewards to {+1, 0, −1} by their sign during training.
Baseline Algorithms. We unify eight existing open-source algorithms in our benchmark for comparison: DQN [17], PPO [26], A2C [16], RND [5], SIL [19], Empowerment [15], ICM [21], and HashCount [31]. DQN, A2C, and PPO are chosen because many existing methods build on them. The others are standard methods for sparse reward problems. We use the benchmark to analyze the relationship between these algorithms and the sparsity of the environment; the analysis also illustrates some limitations of their contributions in environments with different sparsity. There are more excellent algorithms that we did not reproduce because of limited time and resources, such as DTSIL [9], PixelCNN [20], Plan2Explore [27], and BeBold [33]. This part does not cover reward-related methods with expert data, such as DQfD [12] and Imitation Learning [22], because we do not have expert data that can be used for a fair evaluation. This is a comparison framework, and not every algorithm is exhaustively tuned. Our benchmark completes a comparative analysis of existing algorithms and provides a comparison protocol for new algorithms.
Evaluation Protocol. The source codes of different algorithms vary greatly.
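The three preprocessing steps (a) to (c) can be sketched with NumPy. This is a minimal sketch under stated assumptions: grayscale conversion is a plain channel mean and resizing is nearest-neighbour index sampling, chosen only to keep the example dependency-free. OpenAI Baselines’ own wrappers use OpenCV for this step, so the exact pixels will differ. The names `warp_frame`, `FrameStack`, and `clip_reward` are hypothetical.

```python
from collections import deque

import numpy as np


def warp_frame(frame: np.ndarray, size: int = 84) -> np.ndarray:
    """(a) Convert an RGB frame (H, W, 3) to a size x size grayscale uint8 image.

    Uses a channel mean and nearest-neighbour sampling as simple stand-ins.
    """
    gray = frame.astype(np.float32).mean(axis=2)
    rows = np.arange(size) * gray.shape[0] // size
    cols = np.arange(size) * gray.shape[1] // size
    return gray[np.ix_(rows, cols)].astype(np.uint8)


class FrameStack:
    """(b) Keep the last k warped frames and expose them as one observation."""

    def __init__(self, k: int = 4):
        self.frames = deque(maxlen=k)

    def reset(self, first_frame: np.ndarray) -> np.ndarray:
        warped = warp_frame(first_frame)
        for _ in range(self.frames.maxlen):
            self.frames.append(warped)  # pad the history with the first frame
        return self.observation()

    def step(self, frame: np.ndarray) -> np.ndarray:
        self.frames.append(warp_frame(frame))  # oldest frame drops out
        return self.observation()

    def observation(self) -> np.ndarray:
        return np.stack(list(self.frames), axis=-1)  # shape (84, 84, k)


def clip_reward(reward: float) -> float:
    """(c) Clip the training reward to {+1, 0, -1} by its sign."""
    return float(np.sign(reward))


rng = np.random.default_rng(0)
raw = rng.integers(0, 256, size=(210, 160, 3)).astype(np.uint8)  # fake Atari frame
stack = FrameStack(k=4)
obs = stack.reset(raw)
print(obs.shape)                                # (84, 84, 4)
print([clip_reward(r) for r in (300, 0, -50)])  # [1.0, 0.0, -1.0]
```

Clipping is applied only to the training signal; the unclipped reward would still be logged for the Sc criterion.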
To enable a fair comparison in the same setting, we treat their common parts uniformly. First, all agents are trained asynchronously with sixteen workers using stochastic gradient descent, with the Adam optimizer whose parameters are shared across workers. Second, the policy network uses a series of three convolution layers, where the number of filters, kernel size, and stride of each layer are (32, 8, 4), (64, 4, 2), and (64, 3, 1), respectively, each followed by a ReLU nonlinearity. Other algorithm-specific networks are implemented following the respective open-source code. All network structures are implemented in TensorFlow (v1.14). Results are reported after training for 50 million steps.
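As a quick check of the shared convolutional trunk, the spatial size after each layer follows the usual no-padding formula floor((n − kernel) / stride) + 1. This is a minimal sketch, assuming the 84 × 84 input produced by the wrapper and “valid” (no) padding, which the text does not state explicitly; the helper name `conv_output_shapes` is hypothetical.

```python
from typing import List, Tuple

# (filters, kernel size, stride) for the three shared convolution layers.
CONV_LAYERS: List[Tuple[int, int, int]] = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]


def conv_output_shapes(size: int = 84) -> List[Tuple[int, int, int]]:
    """Return (height, width, channels) after each conv layer, assuming 'valid' padding."""
    shapes = []
    for filters, kernel, stride in CONV_LAYERS:
        size = (size - kernel) // stride + 1
        shapes.append((size, size, filters))
    return shapes


shapes = conv_output_shapes()
print(shapes)     # [(20, 20, 32), (9, 9, 64), (7, 7, 64)]
h, w, c = shapes[-1]
print(h * w * c)  # 3136 (flattened feature size)
```

Running the same helper with a 42 × 42 input, as ICM uses, shows how quickly such a small frame is reduced by this trunk, one reason a shared input size matters for fair comparison.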
3 Experimental Evaluations
We first report the sparsity of various environments together with the performance of DQN on them to verify that the sparsity definition is reasonable. Then the experimental results of several algorithms demonstrate the advantages of our benchmark and evaluation metric.
Fig. 2. The orange histogram is the logarithm of sparsity. The blue histogram is DQN’s performance on each game (reward times in 0.1 million steps), which shows a downward trend (blue dotted line) as sparsity increases. (Color figure online)
3.1 Measuring the Sparsity of Environments
The orange histogram in Fig. 2 shows the reward sparsity of some common environments. We calculate the sparsity as the number of rewards divided by the episode length. Several games that are often mentioned as “sparse reward games” can indeed be considered sparse reward environments, but they differ from one another. For example, Pitfall lies far from Venture, Montezuma’s Revenge, and Freeway in the overall ranking. As Fig. 2 shows, under random sampling, Venture, Montezuma’s Revenge, and Freeway get a sparsity of nearly 1.0 (log(Sparsity) = 0.0), Alien and Qbert are in the middle of the ranking, and MsPacMan is at the bottom. Although these games cannot be regarded as dense reward environments, some games we used to think of as sparse reward environments are not so sparse. We run DQN to illustrate the weak point of traditional DRL and report its results on games with different sparsity in the blue histogram of Fig. 2; these results are obtained after 0.1 million steps of exploration. A joint analysis with sparsity shows that DQN’s performance worsens as sparsity grows. These trends indicate that deep reinforcement learning faces a significant challenge in sparse reward environments, and the reason is related to sparsity. When learning in a sparse reward environment, the time the agent spends in the no-reward exploration stage varies with the sparsity, and the agent has little chance of reaching the local-reward exploitation stage. Traditional DRL methods simply assume that the reward is the only criterion for the policy during learning and ignore other information available to the agent. This simplification speeds up learning in dense reward environments while having little impact on the final policy, but in sparse reward environments the drawback is amplified: it is difficult for the agent to understand many states correctly in the no-reward exploration stage.
Table 3. Average Reward Times per Step (ARTS) in each game
| ARTS | Ven | MR | Amd | Grav | Ban | Br | Alien | Pong | Hero | MsP |
|---|---|---|---|---|---|---|---|---|---|---|
| Sparsity | 0 | 0 | 0.0003 | 0.0005 | 0.0008 | 0.0014 | 0.0036 | 0.0062 | 0.0064 | 0.0086 |
| A2C | 0 | 0 | 0.0769 | 0.0069 | 0.0726 | 0.0415 | 0.1961 | 0.0128 | 0.0177 | 0.2068 |
| PPO | 0 | 0 | 0.0801 | 0.0328 | 0.0725 | 0.0582 | 0.1893 | 0.0128 | 0.2154 | 0.2107 |
| EMP | 0.0099 | 0 | 0.0596 | 0.0218 | 0.0592 | 0.0429 | 0.1989 | 0.0098 | 0.1811 | 0.1815 |
| ICM | 0 | 0.0037 | 0.0611 | 0.0143 | 0.0592 | 0.0444 | 0.0826 | 0.0085 | 0.1903 | 0.1889 |
| RND | 0.0097 | 0.0081 | 0.0639 | 0.0169 | 0.0601 | 0.0452 | 0.0781 | 0.0073 | 0.1859 | 0.1864 |
| Hash | 0.0021 | 0.0006 | 0.0545 | 0.0151 | 0.0404 | 0.0352 | 0.0642 | 0.0067 | 0.1836 | 0.1574 |
| SIL | 0 | 0.0078 | 0.0622 | 0.0339 | 0.0722 | 0.0712 | 0.1465 | 0.0131 | 0.2167 | 0.1751 |
Fig. 3. The average scores of the agent, reported every 10K training steps in each environment. Training is truncated at 50 million steps in each game.
3.2 Algorithm Evaluations
After analyzing the environments, we select several representative environments in each sparsity interval to compare the algorithms. The environments, baseline algorithms, and evaluation protocols are detailed
in Sect. 2.4. The benchmark performance of seven algorithms in ten environments is shown in Table 3, and Fig. 3 shows the learning process. In Table 3, we use ARTS (Tim/Len) instead of the traditional score criterion (Sc) to evaluate each algorithm. This criterion ignores the reward scale at different stages of the games and can serve as a complement to Sc. It shows that different algorithms perform differently as the sparsity varies. The environments in Table 3 are sorted by sparsity. The algorithms include the basic RL algorithms and a few typical representatives of the two classes described in Sect. 2.2. The two best-performing algorithms in each environment are shown in bold. The results show that for particularly sparse environments, such as Venture and Montezuma’s Revenge, the no-reward exploration methods (EMP or RND) work well. For relatively sparse environments, such as Gravitar and Breakout, the local-reward exploitation method (SIL) performs well. As the environments become less sparse (toward the right of Table 3), most algorithms designed for sparse rewards perform worse, while the basic RL algorithms (A2C, PPO) do well. The reason is that the sparse-reward algorithms spend part of their experience computing and pursuing a prior bonus in addition to the environment reward, whereas in these environments the environmental reward itself can already guide the agent’s learning well, so the bonus becomes a kind of noise.
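The comparison above can be reproduced mechanically from Table 3. This minimal sketch hard-codes the ARTS values in the table’s column order (Ven, MR, Amd, Grav, Ban, Br, Alien, Pong, Hero, MsP) and lists the two highest-scoring algorithms per environment; ties keep the table’s row order because Python’s sort is stable. The dictionary layout and the function name `top_two` are illustrative only.

```python
from typing import Dict, List

ENVS = ["Ven", "MR", "Amd", "Grav", "Ban", "Br", "Alien", "Pong", "Hero", "MsP"]

# ARTS values per algorithm, in the column order of ENVS (from Table 3).
ARTS: Dict[str, List[float]] = {
    "A2C":  [0, 0, 0.0769, 0.0069, 0.0726, 0.0415, 0.1961, 0.0128, 0.0177, 0.2068],
    "PPO":  [0, 0, 0.0801, 0.0328, 0.0725, 0.0582, 0.1893, 0.0128, 0.2154, 0.2107],
    "EMP":  [0.0099, 0, 0.0596, 0.0218, 0.0592, 0.0429, 0.1989, 0.0098, 0.1811, 0.1815],
    "ICM":  [0, 0.0037, 0.0611, 0.0143, 0.0592, 0.0444, 0.0826, 0.0085, 0.1903, 0.1889],
    "RND":  [0.0097, 0.0081, 0.0639, 0.0169, 0.0601, 0.0452, 0.0781, 0.0073, 0.1859, 0.1864],
    "Hash": [0.0021, 0.0006, 0.0545, 0.0151, 0.0404, 0.0352, 0.0642, 0.0067, 0.1836, 0.1574],
    "SIL":  [0, 0.0078, 0.0622, 0.0339, 0.0722, 0.0712, 0.1465, 0.0131, 0.2167, 0.1751],
}


def top_two(env: str) -> List[str]:
    """Return the two algorithms with the highest ARTS in one environment."""
    col = ENVS.index(env)
    ranked = sorted(ARTS, key=lambda alg: ARTS[alg][col], reverse=True)
    return ranked[:2]


for env in ENVS:
    print(env, top_two(env))
```

The output places EMP and RND first in Venture, RND and SIL in Montezuma’s Revenge, SIL first in Gravitar and Breakout, and PPO and A2C first in MsPacMan, in line with the discussion above.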
4 Concluding Remarks
This work establishes a unified benchmark for RL algorithms in sparse reward environments. The benchmark contains two parts: environment selection and algorithm comparison. First, we define sparsity to quantitatively measure the proportion of rewarded states in an environment and choose environments that cover most of the sparsity range. Second, to compare algorithms, we classify existing algorithms from a sparsity-inspired perspective, propose a fair evaluation metric, and give a unified implementation setting. Analysis of the benchmark reveals the advantages and disadvantages of the classes of algorithms we defined. This unified benchmark will make it easier for researchers to replicate, refine, and identify new ideas.
References
1. Aytar, Y., Pfaff, T., Budden, D., et al.: Playing hard exploration games by watching YouTube. In: NeurIPS, pp. 2935–2945 (2018)
2. Baumli, K., Warde-Farley, D., Hansen, S., et al.: Relative variational intrinsic control. In: AAAI, pp. 6732–6740 (2021)
3. Bellemare, M., Srinivasan, S., Ostrovski, G., et al.: Unifying count-based exploration and intrinsic motivation. In: NeurIPS, pp. 1471–1479 (2016)
4. Bellemare, M.G., Naddaf, Y., Veness, J., et al.: The arcade learning environment: an evaluation platform for general agents. JAIR 47(1), 253–279 (2013)
5. Burda, Y., Edwards, H., Storkey, A., et al.: Exploration by random network distillation. In: ICLR, pp. 1–17 (2018)
6. Chen, Z., Lin, M.: Self-imitation learning in sparse reward settings. arXiv preprint: arXiv:2010.06962 (2020)
7. Dai, T., Liu, H., Anthony Bharath, A.: Episodic self-imitation learning with hindsight. Electronics 9(10), 1742 (2020)
8. Dhariwal, P., Hesse, C., Klimov, O., et al.: OpenAI Baselines (2017). https://github.com/openai/baselines
9. Guo, Y., Choi, J., Moczulski, M., et al.: Memory based trajectory-conditioned policies for learning from sparse rewards. In: NeurIPS, pp. 4333–4345 (2020)
10. Guo, Y., Oh, J., Singh, S., Lee, H.: Generative adversarial self-imitation learning. arXiv preprint: arXiv:1812.00950 (2018)
11. Hessel, M., Modayil, J., Van Hasselt, H., et al.: Rainbow: combining improvements in deep reinforcement learning. In: AAAI, pp. 3215–3222 (2018)
12. Hester, T., Vecerik, M., Pietquin, O., et al.: Deep Q-learning from demonstrations. In: AAAI, pp. 3223–3230 (2018)
13. Itti, L., Baldi, P.: Bayesian surprise attracts human attention. In: NeurIPS, pp. 547–554 (2005)
14. Kim, H., Kim, J., Jeong, Y., et al.: EMI: exploration with mutual information. In: ICML, pp. 3360–3369 (2019)
15. Leibfried, F., Pascual-Díaz, S., Grau-Moya, J.: A unified Bellman optimality principle combining reward maximization and empowerment. In: NeurIPS, pp. 7869–7880 (2019)
16. Mnih, V., Badia, A.P., Mirza, M., et al.: Asynchronous methods for deep reinforcement learning. In: ICML, pp. 1928–1937 (2016)
17. Mnih, V., Kavukcuoglu, K., Silver, D., et al.: Human-level control through deep reinforcement learning. Nature 518(7540), 529–533 (2015)
18. Ng, A.Y., Russell, S.J., et al.: Algorithms for inverse reinforcement learning. In: ICML, pp. 663–670 (2000)
19. Oh, J., Guo, Y., Singh, S., Lee, H.: Self-imitation learning. In: ICML, pp. 3875–3884 (2018)
20. Ostrovski, G., Bellemare, M.G., Oord, A., et al.: Count-based exploration with neural density models. In: ICML, pp. 2721–2730 (2017)
21. Pathak, D., Agrawal, P., Efros, A.A., et al.: Curiosity-driven exploration by self-supervised prediction. In: ICML, pp. 2778–2787 (2017)
22. Peng, X.B., Abbeel, P., Levine, S., et al.: DeepMimic: example-guided deep reinforcement learning of physics-based character skills. TOG 37(4), 1–14 (2018)
23. Pohlen, T., Piot, B., Hester, T., et al.: Observe and look further: achieving consistent performance on Atari. arXiv preprint: arXiv:1805.11593 (2018)
24. Ross, S., Gordon, G.J., Bagnell, J.A.: A reduction of imitation learning and structured prediction to no-regret online learning. AISTATS 1(2), 627–635 (2011)
25. Savinov, N., Raichuk, A., Vincent, D., et al.: Episodic curiosity through reachability. In: ICLR, pp. 1–20 (2019)
26. Schulman, J., Wolski, F., Dhariwal, P., et al.: Proximal policy optimization algorithms. arXiv preprint: arXiv:1707.06347 (2017)
27. Sekar, R., Rybkin, O., Daniilidis, K., et al.: Planning to explore via self-supervised world models. In: ICML, pp. 8583–8592 (2020)
28. Singh, S., Lewis, R.L., Barto, A.G., et al.: Intrinsically motivated reinforcement learning: an evolutionary perspective. TAMD 2(2), 70–82 (2010)
29. Strehl, A.L., Littman, M.L.: An analysis of model-based interval estimation for Markov decision processes. JCSS 74(8), 1309–1331 (2008)
30. Taiga, A.A., Fedus, W., Machado, M.C., et al.: On bonus based exploration methods in the arcade learning environment. In: ICLR, pp. 1–20 (2020)
Towards a Unified Benchmark for Reinforcement Learning
31. Tang, H., Houthooft, R., Foote, D., et al.: #Exploration: a study of count-based exploration for deep reinforcement learning. In: NeurIPS, pp. 2753–2762 (2017)
32. Zhang, C., Cai, Y., Huang, L., Li, J.: Exploration by maximizing Renyi entropy for reward-free RL framework. In: AAAI, pp. 10859–10867 (2021)
33. Zhang, T., Xu, H., Wang, X., et al.: BeBold: exploration beyond the boundary of explored regions. arXiv preprint: arXiv:2012.08621 (2020)
Effect of Logistic Activation Function and Multiplicative Input Noise on DNN-kWTA Model
Wenhao Lu¹, Chi-Sing Leung¹ (✉), and John Sum²
¹ Department of Electronic Engineering, City University of Hong Kong, Kowloon, Hong Kong
[email protected], [email protected]
² Institute of Technology Management, National Chung Hsing University, Taichung, Taiwan
[email protected]
Abstract. The dual neural network-based (DNN) k-winner-take-all (kWTA) model is one of the simplest analog neural network models for the kWTA process. This paper analyzes the behavior of the DNN-kWTA model under two imperfections: (1) the activation function of the IO neurons is a logistic function rather than an ideal step function, and (2) the inputs are corrupted by multiplicative Gaussian noise. Under these two imperfections, the model may not operate correctly, so it is important to estimate the probability that the imperfect model operates correctly. We first derive the equivalent activation function of the IO neurons under the two imperfections. Next, we derive sufficient conditions for the imperfect model to operate correctly. These results enable us to efficiently estimate the probability that the imperfect model generates correct outputs. Additionally, we derive a lower bound on the probability that the imperfect model generates the desired outcomes for uniformly distributed inputs. Finally, we discuss how to generalize our findings to non-Gaussian multiplicative input noise.
Keywords: DNN-kWTA · Logistic activation function · Threshold logic units (TLUs) · Multiplicative input noise

1 Introduction
The goal of the winner-take-all (WTA) process is to identify the largest number from a set of $n$ numbers [1]. The WTA process has many applications, including sorting and statistical filtering [2,3]. An extension of the WTA process is the k-winner-take-all (kWTA) process [4,5], which aims to identify the $k$ largest numbers in the set. Based on the dual neural network (DNN) concept, Hu and Wang [5] proposed a low-complexity kWTA model, namely the DNN-kWTA model. This model contains $n$ input-output (IO) neurons, one recurrent neuron, and only $2n + 1$ connections.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023. M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 202–214, 2023. https://doi.org/10.1007/978-981-99-1639-9_17
Effect of Logistic Activation Function
For an ideal realization, the activation function of the IO neurons should behave like a step function and the inputs should be noise-free. However, in circuit realizations, the activation function often behaves like a logistic function [6,7]. In addition, the operation of the IO neurons may be affected by random drifts and thermal noise [8–10]. These two imperfections can affect the functional correctness of the model. In [11,12], Feng et al. and Sum et al. presented analytical results for noisy DNN-kWTA networks, including their convergence and performance degradation. However, those results are based on the assumption that the noise is additive. In some situations, the noise level is proportional to the signal level; for instance, when a signal is amplified, its noise is amplified too. Hence, a multiplicative noise model is more suitable for describing the behaviour of the input noise [13,14].

This paper analyzes the imperfect DNN-kWTA model with a non-ideal activation function and multiplicative input noise. We first assume that the multiplicative input noise is zero-mean Gaussian distributed and perform the analysis. We derive an equivalent model to describe the behaviour of the imperfect DNN-kWTA model. From the equivalent model, we derive sufficient conditions to check whether the imperfect model can generate the desired results. With these conditions, we can study the probability that the model generates correct outputs without simulating the neural dynamics. For uniformly distributed inputs, we derive a lower bound formula to estimate the probability that the imperfect model generates the correct outputs. Finally, we generalize our results to the non-Gaussian multiplicative input noise case.

This paper is organized as follows. Section 2 presents the background of the DNN-kWTA model. Section 3 studies the properties and performance of the DNN-kWTA model under the two imperfections. Section 4 extends the results to the non-Gaussian input noise case. Experimental results are presented in Sect. 5. Section 6 summarizes our results.
Fig. 1. Structure of a DNN-kWTA network.
2 Basic DNN-kWTA
Figure 1 illustrates the DNN-kWTA structure, which consists of a recurrent neuron and n input-output (IO) neurons. The state of the recurrent neuron
W. Lu et al.
is denoted as $y(t)$. Each IO neuron $i$ receives an external input $u_i$, giving the input set $\{u_1, \cdots, u_n\}$, and is associated with an output $x_i$. All inputs $u_i$ are distinct and range from 0 to 1. In the DNN-kWTA model, the recurrent state $y(t)$ is governed by

$$\epsilon\,\frac{dy(t)}{dt} = \sum_{i=1}^{n} x_i(t) - k, \quad \text{where } x_i(t) = h\big(u_i - y(t)\big) \text{ and } h(\phi) = \begin{cases} 1 & \text{if } \phi \ge 0,\\ 0 & \text{otherwise,} \end{cases} \tag{1}$$

where $\epsilon$ is the characteristic time constant, which depends on the recurrent neuron's capacitance and resistance. In (1), $h(\cdot)$ denotes the activation function of the IO neurons; in the original DNN-kWTA model, $h(\cdot)$ is an ideal step function. A useful property of the DNN-kWTA model is that its state converges to an equilibrium state in finite time. At the equilibrium state, only the IO neurons with the $k$ largest inputs produce outputs of 1, and all other neurons produce outputs of 0.
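As an illustration of (1), the following Python sketch integrates the ideal dynamics with a forward-Euler scheme. The time constant $\epsilon = 1$, the step size $dt = 10^{-3}$, the initial state $y(0) = 0$ and the function names are our own illustrative assumptions, not values or code from the paper. With the inputs of Fig. 2 and $k = 3$, the state should stop inside $(0.52, 0.54]$ with the neurons holding 0.54, 0.61 and 0.55 (indices 0, 1 and 3) as winners.

```python
def step(phi):
    """Ideal step activation h(.) of eq. (1)."""
    return 1.0 if phi >= 0 else 0.0


def ideal_dnn_kwta(u, k, eps=1.0, dt=1e-3, t_max=5.0, y0=0.0):
    """Forward-Euler sketch of the ideal DNN-kWTA dynamics in eq. (1).

    eps (time constant), dt, t_max and y0 are illustrative assumptions.
    """
    y = y0
    for _ in range(int(t_max / dt)):
        dydt = (sum(step(ui - y) for ui in u) - k) / eps
        if dydt == 0.0:  # exactly k outputs equal 1: equilibrium reached
            break
        y += dt * dydt
    return y, [step(ui - y) for ui in u]


u = [0.54, 0.61, 0.52, 0.55, 0.51]  # the inputs used in Fig. 2
y_star, x = ideal_dnn_kwta(u, k=3)
winners = [i for i, xi in enumerate(x) if xi == 1.0]
print(f"y* = {y_star:.3f}, winners = {winners}")
```

Because the step activation makes the right-hand side of (1) piecewise constant, the iteration halts as soon as exactly $k$ outputs equal 1, which mirrors the finite-time convergence of the ideal model.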
3 Logistic DNN-kWTA with Input Noise
3.1 DNN-kWTA Under Imperfection
In circuit realizations, the activation function of the IO neurons often resembles a logistic function [6], and noise is unavoidable in analog circuits. This paper considers the coexistence of these two imperfections in the DNN-kWTA model. The first imperfection is that the activation function is a logistic function, given by

$$h_\alpha(\phi) = \frac{1}{1 + e^{-\alpha\phi}}, \tag{2}$$

where $\alpha$ is the gain factor. The second imperfection is multiplicative noise at the inputs of the IO neurons. That is, the noisy inputs are given by

$$u_i + \varepsilon_i(t)\,u_i, \tag{3}$$

where $\varepsilon_i(t)\,u_i$ is the noise on the $i$-th input. In this model, the noise level depends on both the normalized noise $\varepsilon_i(t)$ and the input $u_i$. In this paper, we assume that the $\varepsilon_i(t)$'s are Gaussian distributed with zero mean and variance $\sigma^2$. With the two imperfections, the behaviour of the DNN-kWTA model can be described as

$$\epsilon\,\frac{dy(t)}{dt} = \sum_{i=1}^{n} \tilde{x}_i(t) - k, \tag{4}$$

$$\tilde{x}_i(t) = h_\alpha\big(u_i + u_i\varepsilon_i(t) - y(t)\big), \tag{5}$$

$$h_\alpha(\phi_i) = \frac{1}{1 + e^{-\alpha\phi_i}}, \quad \text{where } \phi_i = u_i + u_i\varepsilon_i(t) - y(t). \tag{6}$$
In the presence of input noise, the outputs $\tilde{x}_i(t)$ fluctuate with time. Therefore, for the DNN-kWTA model with input noise, the outputs of the IO neurons must be measured multiple times and averaged, and the average values are taken as the neurons' outputs.

Figure 2 illustrates the effect of the non-ideal activation function and the multiplicative Gaussian noise. In the first case, shown in Fig. 2(a), the gain factor is large enough ($\alpha = 100$) and the noise level is small ($\sigma = 0.02$); the recurrent state converges to around 0.5299, and the outputs of the network are correct. When the noise level is increased to $\sigma = 0.2$, the recurrent state converges to 0.5145 and the outputs are incorrect, as shown in Fig. 2(b). When the gain factor is reduced to $\alpha = 15$ with $\sigma = 0.02$, the recurrent state converges to 0.5164 and the outputs are again incorrect, as shown in Fig. 2(c). Clearly, both the gain factor $\alpha$ and the noise level $\sigma$ affect the operational correctness.
[Fig. 2 plots: recurrent state y(t) versus time t (in units of the characteristic time); (a) α = 100, σ = 0.02; (b) α = 100, σ = 0.2; (c) α = 15, σ = 0.02.]
Fig. 2. The dynamics of the recurrent state in a DNN-kWTA network with $n = 5$ and $k = 3$. The inputs are $\{u_1, \cdots, u_5\} = \{0.54, 0.61, 0.52, 0.55, 0.51\}$. (a) Gain $\alpha = 100$ and noise level $\sigma = 0.02$. At the equilibrium, the recurrent state is 0.5299 and the outputs are $\{x_1, \cdots, x_5\} = \{0.744, 0.999, 0.273, 0.867, 0.137\}$. Only $x_1$, $x_2$ and $x_4$ are greater than 0.5, so the outputs are correct. (b) Gain $\alpha = 100$ and noise level $\sigma = 0.2$. At the equilibrium, the recurrent state is 0.5145 and the outputs are $\{0.592, 0.818, 0.528, 0.587, 0.491\}$; since $x_3 > 0.5$, they are incorrect. (c) Gain $\alpha = 15$ and noise level $\sigma = 0.02$. At the equilibrium, the recurrent state is 0.5164 and the outputs are $\{0.591, 0.806, 0.503, 0.632, 0.485\}$; again $x_3 > 0.5$, so they are incorrect.
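A rough sketch of (4)–(6) with time averaging of the outputs follows. It assumes $\epsilon = 1$, forward-Euler steps of $dt = 10^{-3}$, and a fresh noise sample $\varepsilon_i(t)$ per neuron per step; the burn-in and averaging windows, the seed and the function names are arbitrary choices of ours, so the numbers will not match Fig. 2 exactly. For $\alpha = 100$ and $\sigma = 0.02$ (the setting of Fig. 2(a)), the averaged state should settle near 0.53 and the averaged outputs should single out the inputs 0.54, 0.61 and 0.55.

```python
import math
import random


def logistic(phi, gain):
    """Logistic activation of eq. (2), written to avoid overflow in exp."""
    z = gain * phi
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def noisy_logistic_kwta(u, k, alpha, sigma, eps=1.0, dt=1e-3,
                        t_burn=1.0, t_avg=2.0, seed=0):
    """Forward-Euler sketch of eqs. (4)-(6) with multiplicative input noise.

    A fresh eps_i(t) ~ N(0, sigma^2) is drawn for every neuron at every step
    (an assumption about the noise bandwidth). After a burn-in period, y(t)
    and the outputs are averaged over time ("multiple measurements").
    """
    rng = random.Random(seed)
    n_burn, n_avg = int(t_burn / dt), int(t_avg / dt)
    y, y_sum, x_sum = 0.0, 0.0, [0.0] * len(u)
    for step in range(n_burn + n_avg):
        x = [logistic(ui + ui * rng.gauss(0.0, sigma) - y, alpha) for ui in u]
        if step >= n_burn:  # accumulate the measurements after burn-in
            y_sum += y
            x_sum = [s + xi for s, xi in zip(x_sum, x)]
        y += dt * (sum(x) - k) / eps
    return y_sum / n_avg, [s / n_avg for s in x_sum]


u = [0.54, 0.61, 0.52, 0.55, 0.51]
y_bar, x_bar = noisy_logistic_kwta(u, k=3, alpha=100.0, sigma=0.02)
winners = [i for i, xi in enumerate(x_bar) if xi >= 0.5]  # winner if output >= 0.5
print(f"mean y = {y_bar:.4f}, averaged outputs = {[round(v, 3) for v in x_bar]}")
print(f"winners = {winners}")
```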
3.2 Equivalent Model
This subsection derives an equivalent model that describes the dynamic behaviour of the DNN-kWTA model under the two imperfections. We use Haley's approximation [15], which relates the logistic function to the Gaussian distribution function.

Lemma 1 (Haley's approximation). The logistic function $\frac{1}{1+e^{-\rho z}}$ can be approximated by the distribution function of a standard Gaussian random variable:

$$\frac{1}{1+e^{-\rho z}} \approx \int_{-\infty}^{z} \frac{1}{\sqrt{2\pi}}\, e^{-\frac{v^2}{2}}\, dv, \tag{7}$$

where $\rho = 1.702$.

From Lemma 1, we obtain the equivalent dynamics stated in Theorem 1.
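Lemma 1 is easy to check numerically. The sketch below compares both sides of (7) on a dense grid using only the Python standard library (`math.erf` gives the Gaussian distribution function); the grid range and spacing are our own choices. The largest gap it reports is just under 0.01, which is the accuracy usually quoted for Haley's constant $\rho = 1.702$.

```python
import math

RHO = 1.702  # Haley's constant from Lemma 1


def logistic(z, rho=RHO):
    """Left-hand side of eq. (7): 1 / (1 + exp(-rho * z))."""
    return 1.0 / (1.0 + math.exp(-rho * z))


def std_normal_cdf(z):
    """Right-hand side of eq. (7): standard Gaussian distribution function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


# Scan a dense grid on [-8, 8] and record the largest gap between the two sides.
grid = [i / 1000.0 for i in range(-8000, 8001)]
max_err = max(abs(logistic(z) - std_normal_cdf(z)) for z in grid)
print(f"max |logistic - Phi| on [-8, 8]: {max_err:.4f}")
```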
Theorem 1. For the DNN-kWTA model with the two imperfections, its dynamic behaviour can be described by

$$\epsilon\,\frac{dy(t)}{dt} = \sum_{i=1}^{n} \bar{x}_i(t) - k, \tag{8}$$

$$\bar{x}_i(t) = h_{\tilde\alpha_i}\big(u_i - y(t)\big), \tag{9}$$

$$h_{\tilde\alpha_i}(\phi_i) = \frac{1}{1 + e^{-\tilde\alpha_i\phi_i}}, \quad \text{where } \tilde\alpha_i = \frac{1}{\sqrt{\frac{1}{\alpha^2} + \frac{\sigma^2 u_i^2}{\rho^2}}} \text{ and } \phi_i = u_i - y(t). \tag{10}$$
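Proof: From (4), the recurrent state evolves as

$$y(t+\delta) = y(t) + \int_t^{t+\delta} \frac{dy(\tau)}{d\tau}\, d\tau = y(t) + \frac{1}{\epsilon}\left(\sum_{i=1}^{n} \int_t^{t+\delta} \tilde{x}_i(\tau)\, d\tau - k\delta\right), \tag{11}$$

where $\delta$ is a small positive real number. The term $\int_t^{t+\delta} \tilde{x}_i(\tau)\, d\tau$ can be expressed as

$$\int_t^{t+\delta} \tilde{x}_i(\tau)\, d\tau = \lim_{M \to \infty} \zeta M \sum_{j=1}^{M} \frac{\tilde{x}_i(t + j\zeta)}{M}, \quad \text{where } \zeta M = \delta.$$

Treating $y(t)$ as approximately constant over this short interval, the time average equals the expectation over the noise, so we can further rewrite it as

$$\int_t^{t+\delta} \tilde{x}_i(\tau)\, d\tau = \delta \times (\text{mean of } \tilde{x}_i(t)) = \delta\, E[\tilde{x}_i(t)] = \delta\, \bar{x}_i(t) = \delta\, E\big[h_\alpha\big(u_i + \varepsilon_i(t)u_i - y(t)\big)\big].$$

The input $u_i + \varepsilon_i(t)u_i - y(t)$ of the $i$-th IO neuron contains the noise component $\varepsilon_i(t)u_i$. Writing $\phi_i = u_i - y(t)$ for its noise-free part and using the fact that $\varepsilon_i(t)$ is zero-mean Gaussian, we have

$$E[h_\alpha(\phi_i + \varepsilon_i u_i)] = \int_{-\infty}^{\infty} \frac{1}{1 + e^{-\alpha(\phi_i + \varepsilon_i u_i)}} \times \frac{1}{\sqrt{2\pi\sigma^2}}\, e^{-\frac{\varepsilon_i^2}{2\sigma^2}}\, d\varepsilon_i.$$

Furthermore, from Lemma 1,

$$E[h_\alpha(\phi_i + \varepsilon_i u_i)] \approx \int_{-\infty}^{\infty} \int_{-\infty}^{\phi_i} \frac{1}{\sqrt{2\pi\eta}}\, e^{-\frac{(v + u_i\varepsilon_i)^2}{2\eta}} \times \frac{1}{\sqrt{2\pi\sigma^2}}\, e^{-\frac{\varepsilon_i^2}{2\sigma^2}}\, dv\, d\varepsilon_i$$

$$= \int_{-\infty}^{\infty} \int_{-\infty}^{\phi_i} \frac{1}{\sqrt{2\pi\eta}}\, \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{v^2}{2(\eta + \sigma^2 u_i^2)}\right) \exp\left(-\frac{\left(\varepsilon_i + \frac{\sigma^2 u_i v}{\sigma^2 u_i^2 + \eta}\right)^2}{2\eta\sigma^2 / (\sigma^2 u_i^2 + \eta)}\right) dv\, d\varepsilon_i, \tag{12}$$

where $\eta = \rho^2/\alpha^2$. Taking the integration with respect to $\varepsilon_i$ and applying Lemma 1 again, we obtain

$$E[h_\alpha(\phi_i + \varepsilon_i u_i)] \approx \int_{-\infty}^{\phi_i} \frac{1}{\sqrt{2\pi(\eta + \sigma^2 u_i^2)}}\, e^{-\frac{v^2}{2(\eta + \sigma^2 u_i^2)}}\, dv \approx \frac{1}{1 + e^{-\tilde\alpha_i\phi_i}} = h_{\tilde\alpha_i}(\phi_i), \tag{13}$$

where $\phi_i = u_i - y(t)$, $\eta = \rho^2/\alpha^2$ and $\tilde\alpha_i = 1\big/\sqrt{1/\alpha^2 + \sigma^2 u_i^2/\rho^2}$. Hence $\bar{x}_i(t) = h_{\tilde\alpha_i}(u_i - y(t))$, and (11) can be rewritten as

$$y(t+\delta) = y(t) + \frac{\delta}{\epsilon}\left(\sum_{i=1}^{n} \bar{x}_i(t) - k\right) = y(t) + \delta\, \frac{dy(t)}{dt}. \tag{14}$$

With (13) and (14), (4)–(6) can be written as

$$\epsilon\,\frac{dy(t)}{dt} = \sum_{i=1}^{n} \bar{x}_i(t) - k, \tag{15}$$

$$\bar{x}_i(t) = h_{\tilde\alpha_i}\big(u_i - y(t)\big), \tag{16}$$

$$h_{\tilde\alpha_i}(\phi_i) = \frac{1}{1 + e^{-\tilde\alpha_i\phi_i}}, \quad \text{where } \tilde\alpha_i = \frac{1}{\sqrt{\frac{1}{\alpha^2} + \frac{\sigma^2 u_i^2}{\rho^2}}} \text{ and } \phi_i = u_i - y(t). \tag{17}$$

The proof is completed.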
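The key step of Theorem 1 is that averaging a logistic unit of gain $\alpha$ over the multiplicative Gaussian noise gives, approximately, a logistic unit of the smaller gain $\tilde\alpha_i$. A minimal Monte Carlo check of this claim is sketched below for the setting of Fig. 2(b) ($\alpha = 100$, $\sigma = 0.2$) and a single input $u_i = 0.54$; the sample size, seed, probe values of $y$ and helper names are arbitrary assumptions.

```python
import math
import random

RHO = 1.702  # constant of Haley's approximation (Lemma 1)


def logistic(phi, gain):
    """Numerically stable logistic function 1 / (1 + exp(-gain * phi))."""
    z = gain * phi
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def equivalent_gain(alpha, sigma, u_i, rho=RHO):
    """alpha_tilde_i = 1 / sqrt(1/alpha^2 + sigma^2 u_i^2 / rho^2), as in (17)."""
    return 1.0 / math.sqrt(1.0 / alpha**2 + (sigma * u_i / rho) ** 2)


def mc_expected_output(phi, u_i, alpha, sigma, n_samples=100_000, seed=1):
    """Monte Carlo estimate of E[h_alpha(phi + eps_i u_i)], eps_i ~ N(0, sigma^2)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        total += logistic(phi + rng.gauss(0.0, sigma) * u_i, alpha)
    return total / n_samples


alpha, sigma, u_i = 100.0, 0.2, 0.54  # gain and noise level of Fig. 2(b)
a_tilde = equivalent_gain(alpha, sigma, u_i)
print(f"alpha_tilde = {a_tilde:.2f}")
for y in (0.50, 0.52, 0.54, 0.56):
    phi = u_i - y
    mc = mc_expected_output(phi, u_i, alpha, sigma)
    print(f"y = {y:.2f}: Monte Carlo {mc:.4f}, "
          f"equivalent model {logistic(phi, a_tilde):.4f}")
```

Because Lemma 1 is applied twice in the proof, a discrepancy of up to roughly two percentage points between the two columns is plausible.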
Note that we are not proposing a new model: the equivalent model in (15)–(17) is introduced only to help analyze the properties of the imperfect model. From Theorem 1, we obtain a convergence result for the DNN-kWTA network with multiplicative Gaussian noise and a non-ideal activation function, presented in Theorem 2.

Theorem 2. For the imperfect DNN-kWTA network, the recurrent state $y(t)$ converges to a unique equilibrium point.
Proof: Recall that $\epsilon\,\frac{dy}{dt} = \sum_{i=1}^{n} \frac{1}{1 + e^{-\tilde\alpha_i(u_i - y)}} - k$. For very large $y$, $\epsilon\,\frac{dy}{dt} \to -k < 0$. For very small $y$, $\epsilon\,\frac{dy}{dt} \to n - k > 0$. Furthermore, $\frac{dy}{dt}$ is a strictly monotonically decreasing function of $y$, because each $\tilde\alpha_i$ does not depend on $y$. Therefore, there exists a unique equilibrium point $y^*$ such that $\frac{dy}{dt}\big|_{y=y^*} = 0$. Additionally, the model has the following properties: if $y(t) > y^*$, then $\frac{dy}{dt} < 0$, and if $y(t) < y^*$, then $\frac{dy}{dt} > 0$.

Suppose that at time $t_o$ the recurrent state $y(t_o)$ is greater than $y^*$. In this case, $\frac{dy}{dt} < 0$, and $y(t)$ decreases with time until it reaches $y^*$. On the other hand, if at time $t_o$ the recurrent state $y(t_o)$ is less than $y^*$, then $\frac{dy}{dt} > 0$, and $y(t)$ increases with time until it reaches $y^*$. This completes the proof.

Since the outputs $\bar{x}_i(t)$ of the imperfect model are not strictly binary, we need new definitions for "winner" neurons and "loser" neurons.

Definition 1. At the equilibrium, if $\bar{x}_i \ge 0.5$, the $i$-th IO neuron is called a winner. Otherwise, it is called a loser.

The relationships among the equilibrium point $y^*$, the winners, and the losers are summarized in Theorem 3.
Theorem 3. Consider the inputs $\{u_1, \cdots, u_n\}$, and denote the inputs sorted in ascending order by $\{u_{\pi_1}, \cdots, u_{\pi_n}\}$, where $\{\pi_1, \pi_2, \cdots, \pi_n\}$ is the sorted index list. If $u_{\pi_{n-k}} < y^* \le u_{\pi_{n-k+1}}$, then the imperfect model generates correct outputs.

Proof: From (16) and (17), for a given $y$, if $u_i < u_j$, then $h_{\tilde\alpha_i}(u_i - y) < h_{\tilde\alpha_j}(u_j - y)$. Also, $h_{\tilde\alpha_i}(0) = 0.5$, and $h_{\tilde\alpha_i}(u_i - y)$ is an increasing function of $u_i$. Thus, if $u_{\pi_{n-k}} < y^*$, then $\bar{x}_{\pi_1} < \cdots < \bar{x}_{\pi_{n-k}} < 0.5$. Hence IO neurons $\pi_1$ to $\pi_{n-k}$ are losers. Similarly, if $u_{\pi_{n-k+1}} \ge y^*$, then $0.5 \le \bar{x}_{\pi_{n-k+1}} < \cdots < \bar{x}_{\pi_n}$. Hence IO neurons $\pi_{n-k+1}$ to $\pi_n$ are winners. The proof is completed.

A direct way to use Theorem 3 to study the probability of correct operation is to simulate the neural dynamics for many sets of inputs, obtain the equilibrium point $y^*$ for each set, and then check whether the imperfect model produces correct outputs. However, simulating the dynamics is time-consuming, so a more efficient way to estimate the probability of correct operation is desirable. The following theorem, based on the equivalent model, estimates this probability without simulating the neural dynamics.

Theorem 4. Denote the inputs sorted in ascending order by $\{u_{\pi_1}, \cdots, u_{\pi_n}\}$, where $\{\pi_1, \cdots, \pi_n\}$ is the sorted index list. For the imperfect model,

$$\sum_{i=1}^{n} h_{\tilde\alpha_i}(u_i - y)\Big|_{y = u_{\pi_{n-k}}} > k \quad \text{and} \quad \sum_{i=1}^{n} h_{\tilde\alpha_i}(u_i - y)\Big|_{y = u_{\pi_{n-k+1}}} \le k \tag{18}$$

if and only if $u_{\pi_{n-k}} < y^* \le u_{\pi_{n-k+1}}$. In addition, if (18) holds, then the model generates the desired outputs.
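Proof: Denote $H(y) = \sum_{i=1}^{n} h_{\tilde\alpha_i}(u_i - y) - k$. As $y \to \infty$, $H(y) \to -k$, and as $y \to -\infty$, $H(y) \to n - k$. Since $H(y)$ is a strictly monotonically decreasing function of $y$ with $H(y^*) = 0$, we have $H(y) > 0$ for all $y < y^*$ and $H(y) \le 0$ for all $y \ge y^*$. Hence $H(u_{\pi_{n-k}}) > 0$ if and only if $u_{\pi_{n-k}} < y^*$. In addition, $H(u_{\pi_{n-k+1}}) \le 0$ if and only if $u_{\pi_{n-k+1}} \ge y^*$. Furthermore, if

$$\sum_{i=1}^{n} h_{\tilde\alpha_i}(u_i - y)\Big|_{y = u_{\pi_{n-k}}} > k \quad \text{and} \quad \sum_{i=1}^{n} h_{\tilde\alpha_i}(u_i - y)\Big|_{y = u_{\pi_{n-k+1}}} \le k, \tag{19}$$

then $u_{\pi_{n-k}} < y^* \le u_{\pi_{n-k+1}}$.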
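Theorems 2 to 4 translate directly into a few lines of code. The sketch below, with hypothetical helper names, evaluates $H(y)$ of the equivalent model (15)–(17), locates $y^*$ by bisection (valid because $H$ is strictly decreasing, as shown in the proof of Theorem 2), and compares the Theorem 3 test on $y^*$ with condition (18) of Theorem 4, which needs no root finding. It uses the three settings of Fig. 2; the bisection bracket $[-1, 2]$ is our assumption, suited to inputs in $(0, 1)$ and moderately large gains.

```python
import math

RHO = 1.702  # Haley's constant (Lemma 1)


def logistic(phi, gain):
    """Numerically stable logistic function 1 / (1 + exp(-gain * phi))."""
    z = gain * phi
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def equivalent_gains(u, alpha, sigma, rho=RHO):
    """alpha_tilde_i of (17) for every input u_i."""
    return [1.0 / math.sqrt(1.0 / alpha**2 + (sigma * ui / rho) ** 2) for ui in u]


def H(y, u, k, gains):
    """H(y) = sum_i h_{alpha_tilde_i}(u_i - y) - k; strictly decreasing in y."""
    return sum(logistic(ui - y, g) for ui, g in zip(u, gains)) - k


def equilibrium(u, k, gains, lo=-1.0, hi=2.0, iters=60):
    """Bisection for the unique root y* of H (Theorem 2); assumes H(lo) > 0 > H(hi)."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if H(mid, u, k, gains) > 0:
            lo = mid  # H is decreasing, so the root lies to the right
        else:
            hi = mid
    return 0.5 * (lo + hi)


def theorem4_correct(u, k, gains):
    """Condition (18), evaluated at u_{pi_{n-k}} and u_{pi_{n-k+1}}."""
    s, n = sorted(u), len(u)
    return H(s[n - k - 1], u, k, gains) > 0 and H(s[n - k], u, k, gains) <= 0


u, k = [0.54, 0.61, 0.52, 0.55, 0.51], 3
s = sorted(u)
results = []
for alpha, sigma in [(100.0, 0.02), (100.0, 0.2), (15.0, 0.02)]:  # Fig. 2(a)-(c)
    g = equivalent_gains(u, alpha, sigma)
    y_star = equilibrium(u, k, g)
    thm3 = s[len(u) - k - 1] < y_star <= s[len(u) - k]  # Theorem 3 via y*
    results.append((alpha, sigma, y_star, thm3, theorem4_correct(u, k, g)))
    print(f"alpha={alpha:g}, sigma={sigma:g}: y*={y_star:.4f}, "
          f"Theorem 3: {thm3}, Theorem 4: {results[-1][4]}")
```

For these inputs the equivalent model should give $y^*$ close to the 0.5299 reported for Fig. 2(a), and both tests should flag settings (b) and (c) as incorrect.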
According to Theorem 3, the condition $u_{\pi_{n-k}} < y^* \le u_{\pi_{n-k+1}}$ implies that the model correctly identifies the winner and loser neurons for the given $n$ numbers. Theorem 4 therefore provides an efficient way to estimate the probability that the imperfect model correctly identifies the winner and loser neurons without simulating the neural dynamics. To do so, we consider many sets of inputs and sort the data in each input set. Then, for each set, we compute

$$\sum_{i=1}^{n} h_{\tilde\alpha_i}(u_i - y)\Big|_{y = u_{\pi_{n-k}}} - k \quad \text{and} \quad \sum_{i=1}^{n} h_{\tilde\alpha_i}(u_i - y)\Big|_{y = u_{\pi_{n-k+1}}} - k$$

to determine whether the model generates the correct outputs. This procedure is applicable to inputs with any distribution. When the inputs are iid uniform random variables between zero and one, Theorem 5 gives a lower bound on the probability of correct operation.

Theorem 5. If the inputs are iid uniform random variables in the range from zero to one, the probability Prob(correct) that the imperfect model correctly identifies the winner and loser neurons satisfies

$$\text{Prob(correct)} \ge 1 - 2\left[1 - \left(1 - \frac{2}{\tilde\alpha}\right)^{n-1}\left(1 + (n-1)\frac{2}{\tilde\alpha}\right)\right],$$

where $\tilde\alpha = 1\big/\sqrt{1/\alpha^2 + \sigma^2/\rho^2}$, which is the value of $\tilde\alpha_i$ at $u_i = 1$ and hence a lower bound on every $\tilde\alpha_i$ for inputs in $(0, 1)$.

Proof: Since the complete proof is lengthy, we only outline it here. Because the effect of zero-mean Gaussian noise is equivalent to decreasing the gain factor of the logistic function, we can follow the proof of Theorem 4 in [7] to obtain the result.

Probability theory tells us that any continuously distributed data can be mapped into a uniform distribution through histogram equalization, and this mapping does not change the ordering of the original non-uniform inputs. Therefore, Theorem 5 can also be applied to non-uniformly distributed data.
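The bound of Theorem 5 is a closed-form expression and can be evaluated directly. In the sketch below, the function name is hypothetical and clipping the result at zero is our own convention (a negative lower bound on a probability is uninformative, and the formula stops being meaningful once $2/\tilde\alpha \ge 1$). For $n = 21$, $\alpha = 2500$ and $\sigma \approx 0.00995$ it should evaluate to about 0.950, consistent with the noise tolerance quoted for the lower bound in Sect. 5.2.

```python
import math

RHO = 1.702  # Haley's constant (Lemma 1)


def theorem5_lower_bound(n, alpha, sigma, rho=RHO):
    """Lower bound of Theorem 5 on Prob(correct) for n iid U(0, 1) inputs.

    Uses the smallest equivalent gain alpha_tilde (the u_i = 1 case).
    Clipping at 0 is our own convention, not part of the theorem.
    """
    a_tilde = 1.0 / math.sqrt(1.0 / alpha**2 + (sigma / rho) ** 2)
    d = 2.0 / a_tilde
    if d >= 1.0:
        return 0.0
    bound = 1.0 - 2.0 * (1.0 - (1.0 - d) ** (n - 1) * (1.0 + (n - 1) * d))
    return max(bound, 0.0)


# Two of the settings studied in Sect. 5.2.
print(f"{theorem5_lower_bound(6, 500.0, 0.0398107):.4f}")
print(f"{theorem5_lower_bound(21, 2500.0, 0.00995268):.4f}")
```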
4 Non-Gaussian Multiplicative Input Noise
Although we focus on multiplicative input noise with a Gaussian distribution, our analysis can be extended to cases where the multiplicative input noise has a non-Gaussian distribution. The technique is based on the Gaussian mixture model (GMM) [16]: we use a GMM to approximate the density function of the normalized noise component $\varepsilon_i(t)$, namely

$$f(\varepsilon_i(t)) = \sum_{l=1}^{L} \frac{\Xi_l}{\sqrt{2\pi\varsigma_l^2}} \exp\left(-\frac{(\varepsilon_i(t) - \mu_l)^2}{2\varsigma_l^2}\right), \tag{20}$$
where $\mu_l$ is the mean of the $l$-th component, $\varsigma_l^2$ is its variance, and $\Xi_l$ is its weight, with $\sum_{l=1}^{L} \Xi_l = 1$. Following the steps in Sect. 3, we can derive the equivalent dynamics for the case of non-Gaussian multiplicative input noise and a non-ideal activation function. They are stated in the following theorem.

Theorem 6. The equivalent dynamics for non-Gaussian multiplicative input noise and a non-ideal activation function are given by

$$\epsilon\,\frac{dy}{dt} = \sum_{i=1}^{n} \bar{x}_i(t) - k, \tag{21}$$

$$\bar{x}_i(t) = h_i\big(u_i - y(t)\big), \tag{22}$$

$$h_i\big(u_i - y(t)\big) = \sum_{l=1}^{L} \frac{\Xi_l}{1 + e^{-\tilde\alpha_{i,l}\left(u_i - y(t) + u_i\mu_l\right)}}, \tag{23}$$

$$\tilde\alpha_{i,l} = \frac{1}{\sqrt{\frac{1}{\alpha^2} + \frac{u_i^2\varsigma_l^2}{\rho^2}}}. \tag{24}$$
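To illustrate Theorem 6, the sketch below approximates uniform noise $\varepsilon_i(t) \in [-\Delta/2, \Delta/2]$ by an 11-component GMM and compares $h_i(u_i - y)$ from (23)–(24) with a Monte Carlo average of the original logistic unit. The GMM construction (equal weights, means at the centres of equal sub-intervals, each variance equal to width²/12) is a simple hypothetical choice, not the fitting procedure used in the paper; $\alpha = 500$, $\Delta = 0.2$, $u_i = 0.5$ and the helper names are illustrative assumptions.

```python
import math
import random

RHO = 1.702  # Haley's constant (Lemma 1)


def logistic(phi, gain):
    """Numerically stable logistic function 1 / (1 + exp(-gain * phi))."""
    z = gain * phi
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def uniform_as_gmm(delta, n_components=11):
    """Hypothetical GMM (Xi_l, mu_l, varsigma_l^2) for U[-delta/2, delta/2]."""
    width = delta / n_components
    var = width**2 / 12.0  # variance of a uniform sub-interval
    return [(1.0 / n_components, -delta / 2 + (l + 0.5) * width, var)
            for l in range(n_components)]


def gmm_equivalent_output(u_i, y, alpha, gmm, rho=RHO):
    """h_i(u_i - y) from eqs. (23)-(24)."""
    total = 0.0
    for xi_l, mu_l, var_l in gmm:
        gain = 1.0 / math.sqrt(1.0 / alpha**2 + u_i**2 * var_l / rho**2)
        total += xi_l * logistic(u_i - y + u_i * mu_l, gain)
    return total


def mc_expected_output(u_i, y, alpha, delta, n_samples=100_000, seed=3):
    """Monte Carlo E[h_alpha(u_i + noise * u_i - y)], noise ~ U[-delta/2, delta/2]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        noise = rng.uniform(-delta / 2, delta / 2)
        total += logistic(u_i + noise * u_i - y, alpha)
    return total / n_samples


alpha, delta, u_i = 500.0, 0.2, 0.5
gmm = uniform_as_gmm(delta)
for y in (0.47, 0.49, 0.51, 0.53):
    print(f"y = {y:.2f}: GMM model {gmm_equivalent_output(u_i, y, alpha, gmm):.4f}, "
          f"Monte Carlo {mc_expected_output(u_i, y, alpha, delta):.4f}")
```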
Similar to the approach in Sect. 3, we can also develop an efficient method to estimate the probability that the model correctly identifies the winner and loser neurons under non-Gaussian multiplicative input noise and a non-ideal activation function.

Theorem 7. For non-Gaussian multiplicative input noise and a non-ideal activation function, if

$$\sum_{i=1}^{n} h_i(u_i - y)\Big|_{y = u_{\pi_{n-k}}} > k \quad \text{and} \quad \sum_{i=1}^{n} h_i(u_i - y)\Big|_{y = u_{\pi_{n-k+1}}} \le k,$$

then the model operates correctly.

5 Simulation Results
In Theorems 1 and 6, we introduced equivalent models to describe the behaviour of the imperfect DNN-kWTA model. Based on these equivalent models, Theorems 4, 5 and 7 predict the performance of the imperfect model. This section verifies these results.

5.1 Effectiveness of Theorem 4
In this subsection, the inputs follow the Beta distribution $\text{Beta}_{c,d}(x) = \frac{\Gamma(c+d)}{\Gamma(c)\Gamma(d)} x^{c-1}(1-x)^{d-1}$ with $c = d = 2$, where $\Gamma(\cdot)$ denotes the Gamma function. To study the probability that the imperfect model performs correctly, we generate 10,000 sets of inputs. Time-varying multiplicative input noise terms $\varepsilon_i(t)u_i$ are added to the inputs, where the $\varepsilon_i(t)$'s are zero-mean Gaussian with variance $\sigma^2$. We consider three settings: $\{n = 6, k = 2, \alpha = 500\}$, $\{n = 11, k = 2, \alpha = 1000\}$ and $\{n = 21, k = 5, \alpha = 2500\}$.

For non-uniform inputs, there are two methods to measure the probability that the imperfect model correctly identifies the winner and loser neurons. The first is to simulate the original neural dynamics in (4)–(6). The second is to use Theorem 4, in which we only need (18) to determine, for each set of inputs, whether the imperfect model correctly identifies the winner and loser neurons. The results are shown in Fig. 3. The results obtained from Theorem 4 are quite close to those obtained from the original neural dynamics over a wide range of noise levels and settings. For example, for $\{n = 6, k = 2, \alpha = 500, \sigma = 0.06309\}$, the probability values from the two methods are 0.9416 and 0.9442, respectively.
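The Theorem 4 procedure of this subsection can be sketched as follows for Beta(2, 2) inputs: draw input sets, sort them, and test condition (18) on each. The helper names, the number of trials and the seed are our own assumptions; the sketch uses 3,000 input sets rather than the 10,000 used in the paper to keep it fast, so its estimate carries a Monte Carlo error of a few tenths of a percentage point.

```python
import math
import random

RHO = 1.702  # Haley's constant (Lemma 1)


def logistic(phi, gain):
    """Numerically stable logistic function 1 / (1 + exp(-gain * phi))."""
    z = gain * phi
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def theorem4_correct(u, k, alpha, sigma, rho=RHO):
    """Condition (18) for one input set under Gaussian multiplicative noise."""
    gains = [1.0 / math.sqrt(1.0 / alpha**2 + (sigma * ui / rho) ** 2) for ui in u]

    def total(y):  # sum_i h_{alpha_tilde_i}(u_i - y)
        return sum(logistic(ui - y, g) for ui, g in zip(u, gains))

    s, n = sorted(u), len(u)
    return total(s[n - k - 1]) > k and total(s[n - k]) <= k


def estimate_prob_correct(n, k, alpha, sigma, trials=3000, seed=4):
    """Fraction of Beta(2, 2) input sets that satisfy condition (18)."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        u = [rng.betavariate(2.0, 2.0) for _ in range(n)]
        hits += theorem4_correct(u, k, alpha, sigma)
    return hits / trials


p = estimate_prob_correct(n=6, k=2, alpha=500.0, sigma=0.0630957)
print(f"estimated Prob(correct) = {p:.4f}")
```

For $\{n = 6, k = 2, \alpha = 500, \sigma \approx 0.063\}$ the estimate should land in the neighbourhood of the values of about 0.94 reported above.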
[Fig. 3 plots: probability value versus noise level σ (log scale) for Beta-distributed inputs, comparing the original dynamics with the measurement from Theorem 4; (a) n = 6, k = 2, α = 500; (b) n = 11, k = 2, α = 1000; (c) n = 21, k = 5, α = 2500.]
Fig. 3. The inputs follow the Beta distribution on (0, 1), and the multiplicative input noise components are $\varepsilon_i(t)u_i$, where the $\varepsilon_i(t)$'s are zero-mean Gaussian with variance $\sigma^2$.
5.2 Effectiveness of Theorem 5
In this subsection, we study the effectiveness of Theorem 5. The simulation settings are similar to those in Sect. 5.1, except that the inputs are uniformly distributed. For uniform inputs, an additional method is available: the lower bound from Theorem 5 estimates the chance of identifying the correct winners and losers. Hence there are three methods to estimate the probability values. The first is based on the original neural dynamics, which is quite time-consuming. The second uses Theorem 4, which requires input data sets. The last uses the lower bound from Theorem 5, which requires neither the time-consuming simulation of the neural dynamics nor input data sets.
The results are shown in Fig. 4. First, the probability values obtained from Theorem 4 are very close to those obtained by simulating the original neural dynamics. The probability values from Theorem 5 are lower than those from the original neural dynamics and Theorem 4, because Theorem 5 gives lower bounds on the probability values. However, the advantage of Theorem 5 is that no input data sets are needed. We can use these results to determine the noise tolerance level of the model. For example, for $\{n = 21, k = 5, \alpha = 2500\}$ with a target probability value of 0.95, Theorem 4 indicates that the input noise level $\sigma$ should be less than 0.0157, while the lower bound indicates that $\sigma$ should be less than 0.00995.
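For uniform inputs, the Theorem 4 estimate and the Theorem 5 bound can be compared side by side. The sketch below repeats the Theorem 4 check with U(0, 1) inputs and evaluates the Theorem 5 bound for the setting $\{n = 6, k = 2, \alpha = 500\}$; the helper names, trial counts, seeds and the three noise levels are our own illustrative choices.

```python
import math
import random

RHO = 1.702  # Haley's constant (Lemma 1)


def logistic(phi, gain):
    """Numerically stable logistic function 1 / (1 + exp(-gain * phi))."""
    z = gain * phi
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def gain(alpha, sigma, u_i=1.0, rho=RHO):
    """Equivalent gain alpha_tilde for input u_i (u_i = 1 gives Theorem 5's value)."""
    return 1.0 / math.sqrt(1.0 / alpha**2 + (sigma * u_i / rho) ** 2)


def theorem4_correct(u, k, alpha, sigma):
    """Condition (18) for one input set."""
    gains = [gain(alpha, sigma, ui) for ui in u]

    def total(y):
        return sum(logistic(ui - y, g) for ui, g in zip(u, gains))

    s, n = sorted(u), len(u)
    return total(s[n - k - 1]) > k and total(s[n - k]) <= k


def theorem5_lower_bound(n, alpha, sigma):
    """Theorem 5 bound, clipped at 0 (our convention)."""
    d = 2.0 / gain(alpha, sigma)
    if d >= 1.0:
        return 0.0
    return max(0.0, 1.0 - 2.0 * (1.0 - (1.0 - d) ** (n - 1) * (1.0 + (n - 1) * d)))


def estimate_uniform(n, k, alpha, sigma, trials=4000, seed=5):
    """Fraction of U(0, 1) input sets that satisfy condition (18)."""
    rng = random.Random(seed)
    hits = sum(theorem4_correct([rng.random() for _ in range(n)], k, alpha, sigma)
               for _ in range(trials))
    return hits / trials


for sigma in (0.01, 0.0398107, 0.1):
    print(f"sigma = {sigma:g}: Theorem 4 estimate "
          f"{estimate_uniform(6, 2, 500.0, sigma):.4f}, "
          f"Theorem 5 bound {theorem5_lower_bound(6, 500.0, sigma):.4f}")
```

As in Fig. 4, the bound should sit below the Theorem 4 estimate at every noise level, and both should decrease as $\sigma$ grows.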
[Fig. 4 plots: probability value versus noise level σ (log scale) for uniformly distributed inputs, comparing the original dynamics, the measurement from Theorem 4 and the lower bound from Theorem 5; (a) n = 6, k = 2, α = 500; (b) n = 11, k = 2, α = 1000; (c) n = 21, k = 5, α = 2500.]
Fig. 4. The inputs are uniformly distributed in (0, 1), and the multiplicative input noise components are $\varepsilon_i(t)u_i$, where the $\varepsilon_i(t)$'s are zero-mean Gaussian with variance $\sigma^2$.
5.3 Effectiveness of Theorem 7
Theorem 7 can be used to predict the performance of the model under non-Gaussian multiplicative input noise. To study this case, the inputs in this subsection follow a uniform distribution on (0, 1), and 10,000 sets of inputs are generated. The multiplicative input noise terms are $\varepsilon_i(t)u_i$, where the $\varepsilon_i(t)$'s are uniformly distributed in $[-\Delta/2, \Delta/2]$. We chose a uniform distribution to demonstrate that the GMM concept can handle non-bell-shaped distributions: the uniform density has a rectangular shape that differs significantly from the Gaussian one. In the simulation, for each noise level, we build a GMM with 11 components. We consider three settings: $\{n = 6, k = 2, \alpha = 500\}$, $\{n = 11, k = 2, \alpha = 1000\}$ and $\{n = 21, k = 5, \alpha = 2000\}$. To validate the effectiveness of Theorem 7, we also use the original neural dynamics to estimate the probability of correct operation, although this simulation method is quite time-consuming. The results are shown in Fig. 5: the results of Theorem 7 are very close to those obtained by simulating the original neural dynamics.
Again, we can use Theorem 7 to predict the noise tolerance level. For example, for $\{n = 6, k = 2, \alpha = 500\}$ with a target probability value of 0.95, Theorem 7 indicates that the input noise range $\Delta$ should be less than 0.1999.

[Fig. 5 plots: probability value versus noise level Δ (log scale) under non-Gaussian (uniform) multiplicative input noise, comparing the original dynamics with the measurement from Theorem 7; (a) n = 6, k = 2, α = 500; (b) n = 11, k = 2, α = 1000; (c) n = 21, k = 5, α = 2000.]
Fig. 5. The inputs are uniformly distributed in (0, 1), and the multiplicative input noise components are $\varepsilon_i(t)u_i$, where the $\varepsilon_i(t)$'s are zero-mean uniformly distributed noise in the range $[-\Delta/2, \Delta/2]$.
6 Conclusion
This paper presented an analysis of the DNN-kWTA model with two imperfections, namely, multiplicative input noise and a non-ideal activation function in the IO neurons. We first developed an equivalent model to describe the dynamics of the imperfect DNN-kWTA model. Note that the equivalent model is introduced to study the behaviour of the imperfect DNN-kWTA model; it is not a new model. Using the equivalent model, we derived sufficient conditions for checking whether the imperfect model correctly identifies the winner and loser neurons. For uniformly distributed inputs, we provided a formula for a lower bound on the probability of correct operation. Lastly, we extended our results to non-Gaussian multiplicative input noise. Various simulations validated our theoretical results.
References
1. Touretzky, S.: Winner-take-all networks of O(n) complexity. In: Advances in Neural Information Processing Systems, vol. 1, pp. 703–711. Morgan Kaufmann (1989)
2. Kwon, T.M., Zervakis, M.: KWTA networks and their applications. Multidimension. Syst. Signal Process. 6(4), 333–346 (1995)
3. Narkiewicz, J.D., Burleson, W.P.: Rank-order filtering algorithms: a comparison of VLSI implementations. In: The 1993 IEEE International Symposium on Circuits and Systems, pp. 1941–1944. IEEE (1993)
4. Sum, J.P., Leung, C.S., Tam, P.K., Young, G.H., Kan, W.K., Chan, L.W.: Analysis for a class of winner-take-all model. IEEE Trans. Neural Netw. 10(1), 64–71 (1999)
5. Hu, X., Wang, J.: An improved dual neural network for solving a class of quadratic programming problems and its k-winners-take-all application. IEEE Trans. Neural Netw. 19(12), 2022–2031 (2008)
6. Moscovici, A.: High Speed A/D Converters: Understanding Data Converters Through SPICE, vol. 601. Springer Science & Business Media, Berlin (2001)
7. Feng, R., Leung, C.S., Sum, J., Xiao, Y.: Properties and performance of imperfect dual neural network-based kWTA networks. IEEE Trans. Neural Netw. Learn. Syst. 26(9), 2188–2193 (2014)
8. Redouté, J.M., Steyaert, M.: Measurement of EMI induced input offset voltage of an operational amplifier. Electron. Lett. 43(20), 1088–1090 (2007)
9. Kuang, X., Wang, T., Fan, F.: The design of low noise chopper operational amplifier with inverter. In: 2015 IEEE 16th International Conference on Communication Technology (ICCT), pp. 568–571. IEEE (2015)
10. Lee, P.: Low noise amplifier selection guide for optimal noise performance. Analog Devices Application Note, AN-940 (2009)
11. Feng, R., Leung, C.S., Sum, J.: Robustness analysis on dual neural network-based kWTA with input noise. IEEE Trans. Neural Netw. Learn. Syst. 29(4), 1082–1094 (2017)
12. Sum, J., Leung, C.S., Ho, K.I.J.: On Wang kWTA with input noise, output node stochastic, and recurrent state noise. IEEE Trans. Neural Netw. Learn. Syst. 29(9), 4212–4222 (2017)
13. Semenova, N., et al.: Fundamental aspects of noise in analog-hardware neural networks. Chaos: Interdisc. J. Nonlinear Sci. 29(10), 103128 (2019)
14. Kariyappa, S., et al.: Noise-resilient DNN: tolerating noise in PCM-based AI accelerators via noise-aware training. IEEE Trans. Electron Devices 68(9), 4356–4362 (2021)
15. Haley, D.C.: Estimation of the dosage mortality relationship when the dose is subject to error. Technical report, Stanford University, Applied Mathematics and Statistics Laboratories (1952)
16. Radev, S.T., Mertens, U.K., Voss, A., Ardizzone, L., Köthe, U.: BayesFlow: learning complex stochastic models with invertible neural networks. IEEE Trans. Neural Netw. Learn. Syst. 33(4), 1452–1466 (2020)
A High-Speed SSVEP-Based Speller Using Continuous Spelling Method
Bang Xiong¹, Jiayang Huang¹, Bo Wan¹ (✉), Changhua Jiang², Kejia Su¹, and Fei Wang²
¹ Xidian University, Xi'an 710071, Shaanxi, China
{bxiong,jyhuang1,kjsu}@stu.xidian.edu.cn, [email protected]
² National Key Laboratory of Human Factors Engineering, China Astronaut Research and Training Center, Beijing 100094, China
Abstract. Information transfer rate (ITR) of steady-state visual evoked potential (SSVEP)-based brain-computer interface (BCI) spellers were often calculated with fixed gaze shifting time. However, the required gaze shifting time changes with the distance between the two characters. In this study, we propose a continuous spelling method to enhance the ITR of the speller by adaptively adjusting the gaze shifting time. In the continuous spelling procedure, SSVEP signals corresponding to a character sequence were evoked by focusing on different target stimuli continuously. To recognize these target characters, SSVEP segments and their onset time are first obtained by a threshold-based method using continuous wavelet transform (CWT) analysis. Then, we proposed a template reconstruction canonical correlation analysis (trCCA) to extract the feature of the SSVEP segments. Both offline and online experiments were conducted with 11 participants by a 12-target speller. The offline experiments were used to learn the parameters for template reconstruction. In online experiments, the proposed spelling method reached the highest ITRs of 196.41 ± 30.25 bits/min. These results demonstrate the feasibility and efficiency of the proposed method in SSVEP spelling systems. Keywords: brain-computer interface (BCI) · steady-state visual evoked potential (SSVEP) · continuous spelling · continuous wavelet transform (CWT) · template reconstruction
1 Introduction
Brain-computer interfaces (BCIs) build a new connection between humans and computers. EEG-based BCIs have attracted increasing attention owing to their convenience and low cost [1]. Among EEG-based BCIs, SSVEP-based BCIs provide a higher signal-to-noise ratio (SNR) and information transfer rate (ITR) [2]. In recent years, the SSVEP-based speller has become the most popular SSVEP application, enabling people with aphasia to communicate with the outside environment [3]. In a common SSVEP-based speller [4], characters are modulated with distinct frequencies and phases [5]. When users focus on a flickering stimulus, SSVEPs are evoked [6], and the target characters can be identified by detecting the corresponding SSVEP features.

This work is supported by the Foundation of National Key Laboratory of Human Factors Engineering (Grant No. 6142222210101) and the Key Industry Innovation Chain Projects of Shaanxi, China (Grant No. 2021ZDLGY07-04).
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 215–226, 2023. https://doi.org/10.1007/978-981-99-1639-9_18

The information transfer rate is an important metric for evaluating the performance of SSVEP spellers; it is determined by the number of targets, the target classification accuracy, and the target selection time. The classification accuracy mainly depends on the SSVEP detection algorithm, while the target selection time comprises the visual stimulation time and the gaze shifting time. In a multi-target speller, the highest ITR is obtained by a trade-off between classification accuracy and target selection time. Various approaches have been proposed to obtain high classification accuracy with a short stimulation time. For example, Chen et al. proposed an extended canonical correlation analysis (eCCA) [7] by applying filter bank analysis to standard canonical correlation analysis (CCA) [8,9] and incorporating individual calibration data into target identification [10]. It achieved the highest ITR at a selection time of 1 s (0.5 s for visual stimulation and 0.5 s for gaze shifting). To further enhance the SNR of multi-channel signals, Nakanishi et al. [11] proposed ensemble task-related component analysis (eTRCA) for SSVEP recognition, and the highest ITR in online cued spelling was obtained at 0.8 s (0.3 s for visual stimulation and 0.5 s for gaze shifting). However, owing to individual differences and non-stationary brain states, the amount of EEG data required for target identification varies from trial to trial [12,13]. To address this problem, Jiang et al.
[14] applied a dynamic stopping method to reduce the target selection time: it adaptively sets the SSVEP data length used for feature extraction while maintaining high classification accuracy, and it achieved its highest ITR at an average selection time of about 0.7 s. Apart from reducing the visual stimulation time, the gaze shifting time can also be optimized to shorten target selection. Recently, Wan et al. [15] developed a continuous word speller that uses filter bank canonical correlation analysis (FBCCA) with a sliding multi-window strategy to minimize the gaze shifting time. However, this reference-based feature extraction method reached its highest ITR only with a visual stimulation time of at least 1.2 s. To further improve the ITR of spelling systems, the visual stimulation time and the gaze shifting time could be reduced simultaneously by applying a template-based feature extraction method in a continuous spelling system. The main contributions of this paper can be summarized as follows:
1. A continuous spelling method is proposed to further enhance the ITR of the SSVEP-based speller.
2. EEG segmentation is conducted by a threshold-based method using continuous wavelet transform (CWT) analysis [16].
3. The templates are reconstructed by convolving the impulse responses with the periodic impulses [17–19].
4. The target frequency is identified by a combined correlation coefficient between the SSVEP segment and the reconstructed templates.

The proposed method was evaluated with a 12-target BCI system in offline and online experiments. The online ITRs of the continuous spelling method were significantly higher than those of the traditional spelling method. These results demonstrate the feasibility and efficiency of the proposed spelling method. The remainder of this paper is organized as follows. Section 2 introduces the participants and the experimental design. Section 3 describes the methods for SSVEP data processing and performance evaluation. Section 4 presents the experimental results. Finally, the discussion and the conclusion are presented in Sects. 5 and 6, respectively.
2 Materials
2.1 Participants
Offline and online experiments were conducted separately in this study. Eleven healthy subjects (6 males and 5 females, aged 23–27 years, mean age 25 years) with normal or corrected-to-normal vision participated in the experiments. Five subjects had prior experience with an SSVEP-based BCI speller, while the others were naive to EEG-based experiments. Each subject was informed of the experimental procedure before the experiment.
2.2 Visual Stimulus Presentation
This study used the sampled sinusoidal coding method to present visual stimuli modulated by the joint frequency and phase modulation (JFPM) method. All stimuli were displayed on a 24-in. liquid-crystal display (LCD) monitor with a resolution of 1,920 × 1,080 pixels and a refresh rate of 60 Hz. The stimulus sequence s(f_i, φ_i, n) corresponding to stimulus frequency f_i and phase φ_i is modulated by the formula [4]:
s(f_i, \phi_i, n) = \frac{1}{2}\left\{1 + \sin\left[2\pi f_i \frac{n}{\mathrm{RefreshRate}} + \phi_i\right]\right\}, \quad (1)
where sin(·) generates a sine wave, n is the frame index of the sequence, and RefreshRate is the refresh rate of the LCD monitor. In this way, the luminance of the screen is modulated between 0 and 1 (0 represents dark and 1 represents white). The designed BCI system contains 12 targets whose stimulus frequencies range from 9.25 to 14.75 Hz with an interval of 0.5 Hz; the phase interval is 0.5π. As shown in Fig. 1, each stimulus is a 160 × 160-pixel square, with a 50 × 50-pixel character at its center. The vertical and horizontal distances between neighboring stimuli were 90 pixels and 500 pixels, respectively. The stimulus program was developed in MATLAB using Psychophysics Toolbox Version 3 [20].
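To make Eq. (1) concrete, the following Python sketch generates the per-frame luminance sequences for the 12 targets. The frequencies, the 60 Hz refresh rate, and the 0.5π phase interval come from the text; the specific assignment of phases to targets (consecutive targets shifted by 0.5π) is a hypothetical choice for illustration and does not reproduce the values in Fig. 1(b).

```python
import numpy as np

REFRESH_RATE = 60.0                        # monitor refresh rate (Hz), from the text
FREQS = 9.25 + 0.5 * np.arange(12)         # 9.25, 9.75, ..., 14.75 Hz, from the text
# Hypothetical phase assignment: consecutive targets differ by 0.5*pi (wrapped to [0, 2*pi)).
PHASES = np.mod(0.5 * np.pi * np.arange(12), 2.0 * np.pi)

def stimulus_sequence(freq, phase, n_frames, refresh_rate=REFRESH_RATE):
    """Eq. (1): luminance of one target for frames 0..n_frames-1, in [0, 1]."""
    n = np.arange(n_frames)                # frame indices
    return 0.5 * (1.0 + np.sin(2.0 * np.pi * freq * n / refresh_rate + phase))

# One second (60 frames) of stimulation for every target: shape (12, 60).
sequences = np.stack([stimulus_sequence(f, p, 60) for f, p in zip(FREQS, PHASES)])
```

Because the luminance is sampled once per frame, any frequency below the Nyquist limit of the display (30 Hz at 60 Hz) can be approximated, which is why non-integer divisors of the refresh rate such as 9.25 Hz are usable.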
Fig. 1. The user interface of the 12-target BCI speller (a), and the frequency and phase values of all targets modulated by JFPM (b).
2.3 Experiments
EEG data were acquired with a g.USBamp-Research amplifier at a sampling rate of 256 Hz. Nine channels (PO3, PO4, PO7, PO8, POz, Pz, O1, O2, and Oz) were used to record SSVEP signals. During the experiment, subjects were seated in a comfortable chair in a dimly lit, quiet room at a viewing distance of approximately 60 cm from the screen. The offline experiment was designed to derive the spatial filters and impulse responses of each individual and consisted of 10 blocks. In each block, subjects completed 12 trials corresponding to all 12 stimuli, indicated in random order. A visual cue, presented as a red dot below the target stimulus, facilitated visual fixation, and the gaze on each target lasted 5 s. A 0.5-s cue preceded each stimulation onset, within which subjects had to shift their gaze to the indicated target as quickly as possible. After stimulation offset, subjects rested for 0.5 s to prepare for the next trial. To avoid visual fatigue, subjects rested for several minutes between consecutive blocks. The online free-spelling task consisted of 12 blocks. In each block, subjects were asked to input the 12-character sequence '#4256378190*' without visual cues; instead, an auditory cue (a beep every 0.8 s) indicated the next trial. Subjects had to shift their gaze to the next target and focus on it within 0.8 s. They rested for several minutes between consecutive blocks.
3 Methods
3.1 Data Processing
All EEG data from the offline and online experiments were band-pass filtered from 7 to 50 Hz with an IIR filter, using zero-phase forward and reverse filtering implemented with the filtfilt() function in MATLAB. In the offline experiments,
a visual latency of 0.14 s was applied to the data epochs [21]; that is, the offline data epochs were extracted in [0.14 s, 5.14 s] (where 0 indicates the stimulus onset). In the online experiments, by contrast, the SSVEP segments were obtained by a threshold-based method using CWT analysis. As shown in Fig. 2, the multi-channel signals are first weighted by a spatial filter to enhance the SNR of the single-trial EEG data. Then, the threshold-based segmentation method is applied to distinguish SSVEP from non-SSVEP periods. The CWT is computed with the cwt() function in MATLAB using the analytic bump wavelet. The threshold is determined by a grid search on the offline data. Finally, the SSVEP segments and their onset times are obtained for feature extraction.
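A minimal Python sketch of the offline preprocessing, assuming SciPy is available. The 7–50 Hz pass band, zero-phase filtering, 256 Hz sampling rate, and 0.14 s latency come from the text; the 4th-order Butterworth design is an assumption (the text specifies only an IIR filter), and SciPy's `sosfiltfilt` is used here as a counterpart of MATLAB's filtfilt().

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

FS = 256        # EEG sampling rate (Hz), from the text
LATENCY = 0.14  # visual latency (s), from the text

def bandpass(eeg, low=7.0, high=50.0, order=4, fs=FS):
    """Zero-phase band-pass along the last axis (channels x samples).
    Assumption: 4th-order Butterworth; the text only says 'IIR filter'."""
    sos = butter(order, [low, high], btype="bandpass", fs=fs, output="sos")
    return sosfiltfilt(sos, eeg, axis=-1)   # forward-backward pass -> zero phase

def extract_epoch(eeg, onset_sample, duration=5.0, latency=LATENCY, fs=FS):
    """Slice [onset + latency, onset + latency + duration) from filtered EEG."""
    start = onset_sample + int(round(latency * fs))
    return eeg[..., start:start + int(round(duration * fs))]
```

Filtering the continuous recording before epoching, as sketched here, avoids edge transients inside the 5-s analysis window.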
Fig. 2. The EEG segmentation diagram implemented by a threshold-based method using CWT analysis. The bump wavelet is used to obtain high time resolution.
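The threshold-based segmentation can be sketched as follows. This is not the authors' implementation: a complex Morlet wavelet (built directly with NumPy) stands in for MATLAB's analytic bump wavelet, the input is assumed to be the already spatially filtered single-channel signal, averaging the CWT magnitude over the 12 stimulus frequencies is an illustrative choice, and the threshold value must come from an offline grid search as described above.

```python
import numpy as np

FS = 256.0                                  # EEG sampling rate (Hz), from the text
STIM_FREQS = 9.25 + 0.5 * np.arange(12)     # the 12 stimulus frequencies (Hz)

def morlet_cwt_magnitude(x, freqs, fs=FS, n_cycles=3.0):
    """|CWT| of a 1-D signal with complex Morlet wavelets, shape (len(freqs), len(x)).
    Assumption: Morlet replaces the bump wavelet used in the paper."""
    out = np.empty((len(freqs), len(x)))
    for k, f in enumerate(freqs):
        sigma_t = n_cycles / (2.0 * np.pi * f)             # Gaussian width (s)
        t = np.arange(-4.0 * sigma_t, 4.0 * sigma_t, 1.0 / fs)
        wavelet = np.exp(2j * np.pi * f * t) * np.exp(-t**2 / (2.0 * sigma_t**2))
        wavelet /= np.abs(wavelet).sum()                   # unit L1 norm
        out[k] = np.abs(np.convolve(x, wavelet, mode="same"))
    return out

def ssvep_onsets(x, threshold, freqs=STIM_FREQS, fs=FS):
    """Sample indices where the band-averaged CWT magnitude rises above threshold."""
    envelope = morlet_cwt_magnitude(x, freqs, fs).mean(axis=0)
    active = envelope > threshold
    onsets = np.flatnonzero(active[1:] & ~active[:-1]) + 1  # rising edges
    if active[0]:
        onsets = np.r_[0, onsets]
    return onsets, envelope
```

A shorter wavelet (fewer cycles) sharpens onset timing at the cost of frequency resolution, which mirrors the paper's stated reason for choosing a wavelet with high time resolution.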
3.2 Template Reconstruction CCA
Based on the superposition hypothesis, SSVEP templates can be reconstructed by convolving the impulse response with the periodic impulse [17,19]. For a visual stimulus frequency f_i, the corresponding periodic impulse is denoted as h_i = [h(1), h(2), ..., h(L)] (i = 1, 2, ..., N_f, where N_f is the number of stimulus frequencies and L is the length of the stimulation signal), and the response of the subject to frequency f_i can be reconstructed by
x_i = \sum_{\tau=-\infty}^{+\infty} r_i(\tau)\, h_i(t-\tau) = r_i H_i = \left[r_i(1), r_i(2), \ldots, r_i(N)\right] \begin{bmatrix} h(1) & \cdots & h(L) & 0 & 0 & 0 \\ 0 & h(1) & \cdots & h(L) & 0 & 0 \\ \vdots & & \ddots & & \ddots & \vdots \\ 0 & 0 & 0 & h(1) & \cdots & h(L) \end{bmatrix}, \quad (2)
where r_i is the impulse response, H_i is the periodic impulse matrix, and N is the length of r_i. The impulse response r_i can be obtained by solving the following least-squares problem:
\{r_i, w_i\} = \arg\min_{r_i, w_i} \left\| w_i X_i - r_i H_i \right\|^2, \quad (3)
where X_i is the averaged training data obtained in the offline experiments and w_i is the spatial filter corresponding to the i-th frequency. Problem (3) can be solved with the alternating least-squares approach. Finally, the spatial filter w_i and the impulse response r_i of frequency f_i are used to reconstruct the templates.
In the continuous spelling system, the subject shifts their gaze to the next target adaptively. As a result, the focus time differs from the stimulus onset time, which makes template training difficult. Inspired by template reconstruction approaches, we propose a template reconstruction CCA (trCCA) algorithm to identify the SSVEP features in the continuous spelling system. Suppose an SSVEP segment X^j is evoked by the i-th stimulus (frequency f_i and initial phase φ_i), and the onset time t_j of the segment is obtained by the CWT analysis in Sect. 3.1. The phase φ_i^j of the SSVEP segment is calculated by
\phi_i^j = 2\pi f_i t_j + \phi_i, \quad (4)
where t_j represents the focus time at which the subject gazes at stimulus f_i (time 0 represents the stimulation onset). The matrix H_i^j corresponding to φ_i^j is denoted as
H_i^j = H_i\left(\phi_i^j\right) = \begin{bmatrix} h(t_j) & \cdots & h(L) & 0 & 0 & 0 & h(1) & \cdots & h(t_j - 1) \\ h(t_j - 1) & h(t_j) & \cdots & h(L) & 0 & 0 & 0 & h(1) & \cdots \\ \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots & \vdots \\ 0 & 0 & 0 & h(1) & \cdots & h(L) & 0 & 0 & 0 \end{bmatrix}, \quad (5)
The new template corresponding to f_i at focus time t_j is reconstructed by convolving r_i with H_i^j:
x_i^j = r_i H_i^j. \quad (6)
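The following NumPy sketch illustrates Eqs. (2), (3), (5), and (6). `conv_matrix` builds H_i, `shift_matrix` realizes the circular column shift of Eq. (5) for a 0-based onset sample, and `fit_als` is one plausible alternating least-squares scheme for Eq. (3). The unit-norm rescaling of w (to exclude the trivial solution w = r = 0), the random initialization, the iteration count, and the assumption that X_i has N + L − 1 samples are our assumptions, not details given in the text.

```python
import numpy as np

def conv_matrix(h, n):
    """H_i of Eq. (2): an n x (n + L - 1) matrix whose k-th row is h shifted right by k."""
    L = len(h)
    H = np.zeros((n, n + L - 1))
    for k in range(n):
        H[k, k:k + L] = h
    return H

def shift_matrix(H, onset):
    """H_i^j of Eq. (5): columns of H rotated left by a 0-based onset (t_j - 1)."""
    return np.roll(H, -onset, axis=1)

def reconstruct_template(r, h, onset=0):
    """x_i^j = r_i H_i^j (Eq. 6); onset=0 reduces to the plain convolution of Eq. (2)."""
    return r @ shift_matrix(conv_matrix(h, len(r)), onset)

def fit_als(X, H, n_iter=50, seed=0):
    """Alternating least squares for Eq. (3): min ||w X - r H||^2.
    X: channels x T averaged training data; H: N x T with T = N + L - 1 (assumed).
    Assumption: w is rescaled to unit norm each pass to exclude w = r = 0."""
    w = np.random.default_rng(seed).standard_normal(X.shape[0])
    w /= np.linalg.norm(w)
    for _ in range(n_iter):
        r = np.linalg.lstsq(H.T, w @ X, rcond=None)[0]   # fix w, solve for r
        w = np.linalg.lstsq(X.T, r @ H, rcond=None)[0]   # fix r, solve for w
        w /= np.linalg.norm(w)
    r = np.linalg.lstsq(H.T, w @ X, rcond=None)[0]       # r consistent with the final w
    return w, r
```

Because rotating the columns of H before multiplying by r equals rotating the product, a template for any onset could also be obtained by shifting the onset-0 template; the explicit matrix form is kept only to mirror Eq. (5).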
Then, the SSVEP features are efficiently extracted using these reconstructed templates. Furthermore, standard CCA is combined to enhance the performance of feature extraction. The combined correlation coefficient ρ_i is calculated as
\rho_i = \begin{bmatrix} \rho_{i,1} \\ \rho_{i,2} \end{bmatrix} = \begin{bmatrix} \mathrm{CCA}\left(X^j, Y_i\right) \\ \mathrm{cor}\left(w_i X^j, r_i H_i^j\right) \end{bmatrix}, \quad (7)
where Y_i denotes the sine-cosine reference signal of stimulus frequency f_i, cor(a, b) is the correlation coefficient between a and b, and CCA(·) calculates the canonical correlation between the multi-channel SSVEPs and the reference signals. The number of harmonics in standard CCA is set to 3 to include the fundamental and harmonic components of SSVEPs. The flowchart of the feature extraction is shown in Fig. 3.
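A hedged NumPy sketch of the two correlation terms in Eq. (7), together with the signed-square combination the paper uses as the final feature (Eq. (8)). The sine-cosine reference with 3 harmonics and the 256 Hz sampling rate follow the text; computing the largest canonical correlation from QR decompositions and an SVD is a standard numerical route chosen here, not necessarily the authors' one. The spatial filter `w` and the reconstructed `template` are assumed to come from the template reconstruction step.

```python
import numpy as np

FS = 256.0  # EEG sampling rate (Hz), from the text

def sincos_reference(freq, n_samples, n_harmonics=3, fs=FS):
    """Y_i: sine/cosine reference, shape (2 * n_harmonics, n_samples)."""
    t = np.arange(n_samples) / fs
    rows = []
    for k in range(1, n_harmonics + 1):
        rows.append(np.sin(2 * np.pi * k * freq * t))
        rows.append(np.cos(2 * np.pi * k * freq * t))
    return np.array(rows)

def max_canonical_corr(X, Y):
    """Largest canonical correlation between X (C x T) and Y (K x T), via QR + SVD."""
    qx, _ = np.linalg.qr((X - X.mean(axis=1, keepdims=True)).T)  # T x C orthonormal basis
    qy, _ = np.linalg.qr((Y - Y.mean(axis=1, keepdims=True)).T)  # T x K orthonormal basis
    return np.linalg.svd(qx.T @ qy, compute_uv=False)[0]

def combined_score(X, freq, w, template):
    """Signed squares of rho_{i,1} (CCA term) and rho_{i,2} (template term), summed."""
    rho1 = max_canonical_corr(X, sincos_reference(freq, X.shape[1]))
    rho2 = np.corrcoef(w @ X, template)[0, 1]
    return np.sign(rho1) * rho1 ** 2 + np.sign(rho2) * rho2 ** 2
```

In use, `combined_score` would be evaluated for each of the 12 candidate frequencies, and the frequency with the largest score would be selected.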
Fig. 3. Flowchart of the trCCA-based feature extraction method for SSVEP detection in the continuous spelling system.
The weighted correlation coefficients are calculated as the feature for target identification:
\rho_i = \sum_{m=1}^{2} \mathrm{sign}\left(\rho_{i,m}\right) \cdot \rho_{i,m}^2, \quad (8)
where sign(·) retains the discriminative information in negative correlation coefficients between the test SSVEPs and the templates. The frequency with the maximal ρ_i is identified as the SSVEP frequency.
3.3 Performance Evaluation
We implemented a 12-target continuous spelling system to evaluate the performance of the proposed spelling method. During spelling, the subject spells characters by continuously focusing on multiple stimuli while EEG signals over the visual cortex are recorded synchronously. To recognize the SSVEPs in the EEG signals, this study first obtained the onset times of the SSVEP segments by a threshold-based method using CWT coefficients. Second, the phases of the
segments obtained from the CWT analysis are calculated from the corresponding stimulus phases and their onset times. Third, the templates are reconstructed by convolving the impulse responses with the periodic impulses, where the impulse responses represent the individual responses to the flickers and the periodic impulses are determined by the stimulation frequencies and the onset times of the SSVEP segments. Finally, the target frequency is identified by the combined correlation coefficient between the SSVEP segments and the reconstructed templates. The SSVEP signals obtained in the offline experiments were classified using three algorithms: FBCCA, eTRCA, and trCCA. The classification accuracy was estimated by leave-one-out cross-validation: in each fold, 9 blocks were used for training and 1 block for testing. The ITR is used to evaluate system performance:
\mathrm{ITR} = \left[\log_2 N_f + P \log_2 P + (1 - P)\log_2\left(\frac{1 - P}{N_f - 1}\right)\right] \times \frac{60}{T}, \quad (9)
where N_f is the number of frequencies, P is the classification accuracy, and T is the average time (in seconds) for a selection. In the offline experiments, the gaze shifting time is fixed at 0.5 s. In the online experiments, the selection time is calculated as T = (T_v + T_s)/N_t, where T_v is the visual stimulation time, T_s is the gaze shifting time, and N_t is the number of targets. Thus, T_v + T_s is the total time for spelling the N_t targets, and T is the average selection time per target.
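Eq. (9) and the online selection time can be checked numerically with the short Python sketch below. In the online task, 12 characters were spelled per block with one beep every 0.8 s, so we assume T = 0.8 s per selection; under that assumption, subject s1 (136 of 144 trials correct) yields roughly 231 bits/min, close to the value reported in Table 2. The split of the block time into T_v and T_s in the usage lines is hypothetical, since only their sum affects T.

```python
import math

def itr_bits_per_min(n_targets, accuracy, selection_time_s):
    """ITR of Eq. (9) in bits/min for N targets, accuracy P, and selection time T (s)."""
    if not 0.0 <= accuracy <= 1.0:
        raise ValueError("accuracy must lie in [0, 1]")
    bits = math.log2(n_targets)
    if accuracy > 0.0:                     # P*log2(P) -> 0 as P -> 0
        bits += accuracy * math.log2(accuracy)
    if accuracy < 1.0:                     # (1-P)*log2(...) -> 0 as P -> 1
        bits += (1.0 - accuracy) * math.log2((1.0 - accuracy) / (n_targets - 1))
    return bits * 60.0 / selection_time_s

def online_selection_time(t_visual, t_shift, n_spelled):
    """T = (Tv + Ts) / Nt: average time per spelled target in the online task."""
    return (t_visual + t_shift) / n_spelled

# Usage: a 12-character block paced at 0.8 s per character lasts 9.6 s in total;
# the 6.0 s / 3.6 s split between Tv and Ts below is hypothetical.
t_sel = online_selection_time(6.0, 3.6, 12)
itr_s1 = itr_bits_per_min(12, 136 / 144, t_sel)
```

For the offline analysis, T would instead be the data length plus the fixed 0.5 s gaze shifting time, e.g. T = 1.3 s for 0.8 s of data.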
4 Results
4.1 Offline BCI Performance
Offline experiments were conducted to obtain the parameters for template reconstruction. Figure 4 shows the accuracies and ITRs of the three feature extraction methods, averaged across the eleven subjects, at different data lengths. For each subject, the accuracies and ITRs were averaged across the 10 folds of the leave-one-out cross-validation. One-way repeated-measures analysis of variance (ANOVA) was used to investigate the differences among the three methods. As Fig. 4(a) shows, there were significant differences in classification accuracy among the three methods at data lengths from 0.2 s to 0.8 s. As listed in Table 1, trCCA outperformed FBCCA at short data lengths (0.2 s to 0.8 s), with significant differences from 0.2 s to 0.7 s, whereas there was no significant difference between trCCA and eTRCA at any data length. These results show the feasibility of template reconstruction in the SSVEP-based continuous spelling system. Figure 4(b) shows the mean ITRs from 0.2 s to 1.5 s in the offline experiment. eTRCA and trCCA obtained their highest ITRs of 123.7 bits/min and 129.0 bits/min at 0.7 s and 0.8 s, respectively, significantly higher than those of FBCCA (0.7 s: eTRCA vs. FBCCA, 123.7 bits/min vs. 75.54 bits/min; 0.8 s: trCCA vs. FBCCA, 129.0 bits/min vs. 95.87 bits/min). Note that the selection time in the offline experiments was the sum of the visual stimulation time and a fixed gaze shifting time of 0.5 s. In the online experiments, by contrast, subjects adjust the gaze shifting time adaptively, which has the potential to boost the ITR of the spelling system.

[Fig. 4 panels: (a) Accuracy (%) versus data length (s); (b) ITR (bits/min) versus data length (s); curves for trCCA, eTRCA, and FBCCA.]
Fig. 4. Classification accuracies and ITRs using different data lengths (from 0.2 s to 1.5 s with a step of 0.1 s). The error bars represent standard errors. The asterisks indicate significant differences between methods (*p < 0.05, **p < 0.01, ***p < 0.001).

Table 1. Significance (p-values) of the differences in classification accuracy obtained by one-way repeated-measures ANOVA.
Method             0.2 s    0.3 s    0.4 s    0.5 s    0.6 s    0.7 s    0.8 s
trCCA vs. eTRCA    0.1000   0.0573   0.1364   0.3968   0.8376   0.7446   0.5750
trCCA vs. FBCCA    <0.001   <0.001   <0.001   <0.001   <0.001   <0.01    0.0518

4.2 Online BCI Performance
All subjects participated in the online experiment. Table 2 lists the accuracies and ITRs of the continuous spelling system in the free-spelling task. The average accuracy and ITR were 87.25 ± 6.31% and 196.41 ± 30.25 bits/min, respectively. The online accuracy was similar to the offline result (offline vs. online accuracy: 89.92 ± 5.86% vs. 87.25 ± 6.31%), but the visual stimulation time and the gaze shifting time were simultaneously reduced by the template-based feature extraction approach and the proposed continuous spelling method. As a result, the online ITRs were significantly higher than the offline ITRs (offline vs. online: 129.77 ± 19.40 bits/min vs. 196.41 ± 30.25 bits/min). These results indicate that the proposed method outperforms the spelling method with a fixed gaze shifting time in practical use.
Table 2. Results of the online free-spelling experiment. The bold font indicates subjects with experience of using an SSVEP-based BCI speller.
Subject       No. of trials (correct/incorrect)   Accuracy [%]    ITR [bits/min]
s1            144 (136/8)                         94.44           231.22
s2            144 (120/24)                        83.33           176.86
s3            144 (119/25)                        82.64           173.89
s4            144 (121/23)                        84.03           179.91
s5            144 (123/21)                        85.42           186.10
s6            144 (140/4)                         97.22           247.92
s7            144 (130/16)                        90.28           209.14
s8            144 (117/27)                        81.25           168.00
s9            144 (111/33)                        77.08           151.17
s10           144 (131/13)                        90.97           212.64
s11           144 (134/10)                        93.06           223.57
Mean ± STD    –                                   87.25 ± 6.31    196.41 ± 30.25

5 Discussion
Gaze shifting time is an essential component when calculating the ITRs of spelling systems. In previous studies, it was usually set to a fixed value. However, the required gaze shifting time changes with the distance between characters, which influences the ITRs of spelling systems. To solve this problem, this study proposed a continuous spelling method, in which the user spells a word or a phrase by focusing on multiple target stimuli within a spelling block. In this way, the gaze shifting time is adjusted adaptively during spelling. In the online spelling task, the continuous speller achieved an average ITR of 196.41 ± 30.25 bits/min with 12 targets. These results demonstrate the feasibility and efficiency of the proposed spelling method. Although the proposed system performs well in continuous spelling, some problems should be investigated in future work. First, this study designed a 12-target speller to evaluate the spelling method. However, with more than 40 targets, the distance between neighboring characters, and hence the required gaze shifting time, decreases, which challenges the robustness of the CWT-coefficient-based SSVEP segmentation. Voice activity detection is widely used in speech recognition to identify speech onset, so our future work will explore the feasibility of activity detection techniques for SSVEP segmentation. Second, dynamic stopping methods perform well in SSVEP-based BCI systems [14,22,23]. Combining dynamic stopping with adaptive gaze shifting has the potential to minimize the target selection time and further enhance the ITR of SSVEP-based BCI systems.
6 Conclusion
A novel continuous spelling method based on template reconstruction is proposed in this study. The system achieved an average ITR of 196.41 bits/min in the online experiment. These results show the feasibility of the proposed method in SSVEP-based spelling systems. The SSVEP segmentation results demonstrate that the CWT-coefficient-based method can extract SSVEP segments from continuous EEG signals, and the classification results show the effectiveness of the trCCA approach for feature extraction from SSVEP segments. By adaptively adjusting the gaze shifting time, the continuous spelling method boosts the ITR of the system and provides an efficient interaction method for SSVEP-based BCIs.
Acknowledgments. The authors thank all the participants for their help with data acquisition in this study.
References
1. Wolpaw, J.R., Birbaumer, N., McFarland, D.J., Pfurtscheller, G., Vaughan, T.M.: Brain-computer interfaces for communication and control. Clin. Neurophysiol. 113(6), 767–791 (2002). https://doi.org/10.1016/S1388-2457(02)00057-3
2. Allison, B., Dunne, S., Leeb, R., Millan, J.D.R., Nijholt, A.: Towards Practical Brain-Computer Interfaces: Bridging the Gap from Research to Real-World Applications. Biological and Medical Physics, Biomedical Engineering. Springer Science & Business Media, Berlin (2012)
3. Chen, X., Chen, Z., Gao, S., Gao, X.: A high-ITR SSVEP-based BCI speller. Brain-Comput. Interfaces 1(3–4), 181–191 (2014). https://doi.org/10.1080/2326263X.2014.944469
4. Chen, X., Wang, Y., Nakanishi, M., Jung, T.P., Gao, X.: Hybrid frequency and phase coding for a high-speed SSVEP-based BCI speller. In: 2014 36th Annual International Conference of the IEEE Engineering in Medicine and Biology Society, pp. 3993–3996 (2014). https://doi.org/10.1109/EMBC.2014.6944499
5. Regan, D.: Evoked potentials and evoked magnetic fields in science and medicine. Hum. Brain Electrophysiol. 59–61 (1989)
6. Pan, J., Gao, X., Duan, F., Yan, Z., Gao, S.: Enhancing the classification accuracy of steady-state visual evoked potential-based brain-computer interfaces using phase constrained canonical correlation analysis. J. Neural Eng. 8(3), 036027 (2011). https://doi.org/10.1088/1741-2560/8/3/036027
7. Chen, X., Wang, Y., Nakanishi, M., Gao, X., Jung, T.P., Gao, S.: High-speed spelling with a noninvasive brain-computer interface. Proc. Natl. Acad. Sci. 112(44), E6058–E6067 (2015). https://doi.org/10.1073/pnas.1508080112
8. Chen, X., Wang, Y., Gao, S., Jung, T.P., Gao, X.: Filter bank canonical correlation analysis for implementing a high-speed SSVEP-based brain-computer interface. J. Neural Eng. 12(4), 046008 (2015). https://doi.org/10.1088/1741-2560/12/4/046008
9. Lin, Z., Zhang, C., Wu, W., Gao, X.: Frequency recognition based on canonical correlation analysis for SSVEP-based BCIs. IEEE Trans. Biomed. Eng. 53(12), 2610–2614 (2006). https://doi.org/10.1109/TBME.2006.886577
10. Wang, Y., Nakanishi, M., Wang, Y.T., Jung, T.P.: Enhancing detection of steady-state visual evoked potentials using individual training data. In: 2014 36th Annual International Conference of the IEEE Engineering in Medicine and Biology Society, pp. 3037–3040 (2014). https://doi.org/10.1109/EMBC.2014.6944263
11. Nakanishi, M., Wang, Y., Chen, X., Wang, Y.T., Gao, X., Jung, T.P.: Enhancing detection of SSVEPs for a high-speed brain speller using task-related component analysis. IEEE Trans. Biomed. Eng. 65(1), 104–112 (2017). https://doi.org/10.1109/TBME.2017.2694818
12. Ma, B.-Q., Li, H., Zheng, W.-L., Lu, B.-L.: Reducing the subject variability of EEG signals with adversarial domain generalization. In: Gedeon, T., Wong, K.W., Lee, M. (eds.) ICONIP 2019. LNCS, vol. 11953, pp. 30–42. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-36708-4_3
13. Guo, T., et al.: Constrained generative model for EEG signals generation. In: Mantoro, T., Lee, M., Ayu, M.A., Wong, K.W., Hidayanto, A.N. (eds.) ICONIP 2021. LNCS, vol. 13110, pp. 596–607. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-92238-2_49
14. Jiang, J., Yin, E., Wang, C., Xu, M., Ming, D.: Incorporation of dynamic stopping strategy into the high-speed SSVEP-based BCIs. J. Neural Eng. 15(4), 046025 (2018). https://doi.org/10.1088/1741-2552/aac605
15. Wan, B., Xiong, B., Huang, J., Yang, P.: A novel SSVEP-based word speller based on sliding multi-window strategy. In: 2021 9th International Winter Conference on Brain-Computer Interface (BCI), pp. 1–5 (2021). https://doi.org/10.1109/BCI51272.2021.9385343
16. Stockwell, R.G., Mansinha, L., Lowe, R.: Localization of the complex spectrum: the S transform. IEEE Trans. Signal Process. 44(4), 998–1001 (1996). https://doi.org/10.1109/78.492555
17. Thielen, J., van den Broek, P., Farquhar, J., Desain, P.: Broad-band visually evoked potentials: re(con)volution in brain-computer interfacing. PLoS ONE 10(7), e0133797 (2015). https://doi.org/10.1371/journal.pone.0133797
18. Capilla, A., Pazo-Alvarez, P., Darriba, A., Campo, P., Gross, J.: Steady-state visual evoked potentials can be explained by temporal superposition of transient event-related responses. PLoS ONE 6(1), 1–15 (2011). https://doi.org/10.1371/journal.pone.0014543
19. Wong, C.M., et al.: Transferring subject-specific knowledge across stimulus frequencies in SSVEP based BCIs. IEEE Trans. Autom. Sci. Eng. 18(2), 552–563 (2021). https://doi.org/10.1109/TASE.2021.3054741
20. Brainard, D.H.: The psychophysics toolbox. Spat. Vis. 10(4), 433–436 (1997). https://doi.org/10.1163/156856897X00357
21. Di Russo, F., Spinelli, D.: Electrophysiological evidence for an early attentional mechanism in visual processing in humans. Vision Res. 39(18), 2975–2985 (1999). https://doi.org/10.1016/S0042-6989(99)00031-0
22. Cecotti, H.: Adaptive time segment analysis for steady-state visual evoked potential based brain-computer interfaces. IEEE Trans. Neural Syst. Rehabil. Eng. 28(3), 552–560 (2020). https://doi.org/10.1109/TNSRE.2020.2968307
23. Chen, Y., Yang, C., Chen, X., Wang, Y., Gao, X.: A novel training-free recognition method for SSVEP-based BCIs using dynamic window strategy. J. Neural Eng. 18(3), 036007 (2021). https://doi.org/10.1088/1741-2552/ab914e
AAT: Non-local Networks for Sim-to-Real Adversarial Augmentation Transfer
Mengzhu Wang¹, Shanshan Wang²,³,⁴(B), Tianwei Yan¹, and Zhigang Luo¹
¹ National University of Defense Technology, Changsha, China [email protected]
² Institutes of Physical Science and Information Technology, Anhui University, Hefei, China [email protected]
³ Institute of Artificial Intelligence, Hefei Comprehensive National Science Center, Hefei, China
⁴ Chongqing Key Laboratory of Bio-perception and Intelligent Information Processing, Chongqing University, Chongqing, China
Abstract. In sim-to-real tasks, domain adaptation is a fundamental topic, as it can reduce the large performance degradation caused by domain shift. Domain adaptation (DA) can effectively transfer knowledge from a labeled source domain to an unlabeled target domain. Existing DA methods usually match cross-domain local features; however, considering only local features may lead to negative transfer. To alleviate this problem, we propose novel non-local networks for sim-to-real adversarial augmentation transfer (AAT). Our method leverages an attention mechanism and semantic data augmentation to focus on global features and augmented features. Specifically, to focus on global features, we use a non-local attention mechanism to weight the extracted features, which effectively eliminates the influence of untransferable features. Additionally, to enhance the adaptability of the classifier, semantic data augmentation is used to augment source features toward target features. We also give an upper bound on the divergence between the augmented features and the source features. Although our method is simple, it consistently improves generalization performance on popular domain adaptation and sim-to-real benchmarks, i.e., Office-31, Office-Home, ImageCLEF-DA, and VisDA-2017.
Keywords: Domain Adaptation · Non-Local Networks · Semantic Data Augmentation
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 227–238, 2023. https://doi.org/10.1007/978-981-99-1639-9_19
1 Introduction
In the real world, deep networks [26–28] have greatly improved performance on various machine learning problems and application scenarios, such as robotic applications [12, 18]. However, their training relies heavily on a large number of labeled samples for supervised learning. Since manually labeling large amounts of data is often prohibitively expensive, domain adaptation [23, 24, 30, 32] has become a well-researched strategy to solve this problem
by transferring knowledge from a labeled source domain to improve performance on an unlabeled target domain. However, this learning paradigm faces the problem of domain shift, which is a major obstacle to making the source domain beneficial to the target domain. Previous deep domain adaptation methods can be roughly classified into two paradigms: 1) methods based on minimizing statistical discrepancy [10, 32, 34], which use statistical regularization to explicitly reduce the cross-domain distribution discrepancy; and 2) methods based on adversarial learning [2, 9, 21, 22], which learn domain-invariant representations across the two domains through adversarial training. Among these, domain adversarial methods [2, 9, 20] have recently attracted much attention; they introduce adversarial learning into deep networks, an idea derived from generative adversarial networks (GANs) [4, 17]. Specifically, domain adversarial adaptation usually optimizes a domain discriminator through a min-max game, in which the discriminator is trained to distinguish the feature representations of the training (source) data from those of the testing (target) data. Once the min-max optimization reaches equilibrium, the learned feature representations are expected to be domain-invariant. For instance, PSR [20] introduces a pairwise similarity regularization approach to exploit the clustering structure of the target domain; it minimizes the difference between the pairwise similarity of clustering partitions and the pseudo-predicted pairwise similarity. Multi-adversarial domain adaptation [13] extracts multi-modal information and enables fine-grained cross-domain alignment with the help of multiple domain discriminators. MDD [6] minimizes the inter-domain divergence and maximizes the inter-class density to address the balance challenge in adversarial domain adaptation.
Although domain adversarial adaptation methods have achieved remarkable results, they still face some major bottlenecks. On the one hand, not all features generated by a well-trained feature extractor are suitable for domain adaptation, and forcibly matching these untransferable features may lead to negative transfer. On the other hand, most existing adversarial domain adaptation methods adapt feature representations across the two domains under the guidance of a shared, source-supervised classifier, and this source-trained classifier limits generalization to unlabeled target samples. To handle these problems, we propose non-local networks for sim-to-real adversarial augmentation transfer (AAT) based on adversarial learning. We introduce a non-local attention mechanism to weight the extracted features, which effectively eliminates the influence of untransferable features. In addition, to further enhance the adaptability of the classifier, we leverage semantic data augmentation to equip the source features with target semantics and thereby train a more transferable classifier. The source features can then be augmented toward the target domain using the overall semantic difference between the domains and the semantic variations within the target classes. Similar to conditional domain adversarial networks (CDAN), the proposed method performs conditional adversarial adaptation. Unlike CDAN, it adds a non-local attention mechanism and semantic augmentation to overcome the shortcomings of CDAN, as shown in Fig. 1. In addition, AAT is simple and can easily be plugged into various unsupervised domain adaptation (UDA) methods to significantly boost their performance.
AAT: Non-local Networks for Sim-to-Real Adversarial Augmentation Transfer
Fig. 1. Non-local Network for Sim-to-Real Adversarial Augmentation Transfer. Our core module consists of three parts: (a) semantic data augmentation of the source domain for training the source classifier; (b) a non-local attention module that focuses on global features; (c) a conditional adversarial domain adaptation loss used as our backbone.
Overall, our contributions are threefold: (1) We propose a new non-local attention method for unsupervised adversarial domain adaptation. Compared with existing DA methods, the proposed non-local attention mechanism can capture complex global structural information while avoiding the negative transfer caused by forcibly matching untransferable features. (2) To enhance the adaptation ability of the classifier, we leverage an expected transferable cross-entropy loss on the augmented source distribution and derive its upper bound. (3) Extensive experiments on several domain adaptation benchmarks, including Office-31, ImageCLEF-DA, Office-Home and VisDA-2017, demonstrate that our method outperforms most state-of-the-art methods.
2 Method
2.1 Conditional Domain Adversarial Networks [9]
The idea of adversarial domain adaptation is inspired by Generative Adversarial Networks (GAN) [4]. In this paper, F denotes the feature learner and D denotes the domain discriminator. The overall objective can then be written as:

$$\min_{F}\max_{D}\ \mathcal{L}_{adv} = -\mathbb{E}_{(x_s, y_s)}\sum_{c=1}^{C}\mathbb{1}[y_s = c]\log\sigma(F(x_s))_c + \alpha\left(\mathbb{E}_{x_s}\left[\log D(h_s)\right] + \mathbb{E}_{x_t}\left[\log\left(1 - D(h_t)\right)\right]\right) \qquad (1)$$

where α is a hyper-parameter that balances the classification loss and the adversarial loss. For a fair comparison, we fix α = 1 following CDAN [9]. In Eq. (1), the first term is a supervised cross-entropy loss on the source domain, C is the number of categories,
M. Wang et al.
σ is the softmax function, and $\mathbb{1}[\cdot]$ is the indicator function. The second term is the adversarial loss from CDAN [9]. It is worth noting that h = Π(f, p) is the joint variable of the domain-specific feature f and its corresponding classification prediction p, where the prediction is computed by a softmax classifier and Π(·) is the conditioning strategy of CDAN [9]. However, CDAN [9] only shares the source-supervised classifier across domains during training, which may limit its ability to generalize to the unlabeled target domain. We further leverage semantic data augmentation to handle this problem.
2.2 Semantic Data Augmentation Loss
With the help of the conditional adversarial training strategy, we can capture multimodal data structures by optimizing Eq. (1) and achieve significant performance. However, the CDAN algorithm adapts the feature representations across the two domains under the guidance of the shared source-supervised classifier, and, as mentioned above, this classifier limits generalization to the unlabeled target domain. To promote meaningful cross-domain semantic augmentation, it is necessary to discover all possible target-oriented style-transfer directions in feature space. In our method, we leverage a transferable cross-entropy loss to achieve this goal. For each class, we sample random vectors [7, 25] from a multivariate normal distribution whose mean is the class-wise inter-domain difference of feature means and whose covariance matrix is the class-conditional target covariance, so that C sampling distributions are constructed. Each source deep feature $f_s^i$ can undergo various semantic transformations along random directions sampled from $\mathcal{N}(\lambda\Delta\mu_{y_s^i}, \lambda\Sigma^t_{y_s^i})$, generating the augmented feature $\tilde{f}_s^i \sim \mathcal{N}(f_s^i + \lambda\Delta\mu_{y_s^i}, \lambda\Sigma^t_{y_s^i})$. In the simplest setting, we can repeat each $f_s^i$ M times while retaining its label, which forms an augmented feature dataset $\{(\tilde{f}_s^{i,1}, y_s^i), (\tilde{f}_s^{i,2}, y_s^i), \ldots, (\tilde{f}_s^{i,M}, y_s^i)\}_{i=1}^{n_s}$. The source network can then be trained with the conventional cross-entropy loss on the augmented feature dataset:

$$\mathcal{L}_M(W, b, \Theta) = \frac{1}{n_s}\sum_{i=1}^{n_s}\frac{1}{M}\sum_{m=1}^{M} -\log\frac{e^{w_{y_s^i}^{\top}\tilde{f}_s^{i,m} + b_{y_s^i}}}{\sum_{c=1}^{C} e^{w_c^{\top}\tilde{f}_s^{i,m} + b_c}} \qquad (2)$$
where $W = [w_1, w_2, \ldots, w_C]^{\top} \in \mathbb{R}^{C \times K}$ and $b = [b_1, b_2, \ldots, b_C]^{\top} \in \mathbb{R}^{C}$ are the weight matrix and bias vector of the last fully connected layer, respectively. To approximate the expected performance, M usually has to be set to a large value, which incurs considerable computational cost. In semantic data augmentation, we instead aim to implicitly generate an unlimited number of augmented source features rather than fixing an explicit value of M. As M tends to infinity, by the Law of Large Numbers [1], Eq. (2) converges to the expected transferable cross-entropy loss on the augmented source data, defined as:

$$\mathcal{L}_{\infty}(W, b, \Theta \mid \Sigma) = \frac{1}{N}\sum_{i=1}^{N}\mathbb{E}_{\tilde{f}_s^i}\left[-\log\frac{e^{w_{y_s^i}^{\top}\tilde{f}_s^i + b_{y_s^i}}}{\sum_{j=1}^{C} e^{w_j^{\top}\tilde{f}_s^i + b_j}}\right] \qquad (3)$$

where N = n_s is the number of source samples.
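To make Eqs. (2) and (3) concrete, the NumPy sketch below estimates the explicit augmentation loss of Eq. (2) by drawing M copies per source feature, and compares it with the closed-form bound of Proposition 1 stated next. It assumes, as Proposition 1 does, that augmented features follow N(f, λΣ_y) (the mean shift λΔμ is omitted) and that the class-conditional covariances are already available; the function names, toy data and random covariances are illustrative only, not the authors' implementation.

```python
import numpy as np


def logsumexp(a, axis=-1):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def explicit_aug_loss(feats, labels, W, b, cov, lam, M, rng):
    """Monte Carlo estimate of Eq. (2) with M augmented copies per source feature.

    feats: (N, K) source features, labels: (N,) class indices,
    W: (C, K) classifier weights, b: (C,) biases,
    cov: (C, K, K) class-conditional covariances (assumed precomputed).
    Copies are drawn from N(f, lam * cov[y]), the setting of Proposition 1.
    """
    total = 0.0
    for f, y in zip(feats, labels):
        aug = rng.multivariate_normal(f, lam * cov[y], size=M)  # (M, K)
        logits = aug @ W.T + b                                  # (M, C)
        # cross-entropy of every augmented copy, averaged over the M copies
        total += np.mean(logsumexp(logits, axis=1) - logits[:, y])
    return total / len(feats)


def isda_upper_bound(feats, labels, W, b, cov, lam):
    """Closed-form surrogate of Proposition 1 (Eq. (4)): no sampling needed."""
    total = 0.0
    for f, y in zip(feats, labels):
        V = W - W[y]                                   # row j is v_{j,y} = w_j - w_y
        quad = np.einsum("jk,kl,jl->j", V, cov[y], V)   # v^T Sigma_y v for every j
        total += logsumexp(V @ f + (b - b[y]) + 0.5 * lam * quad)
    return total / len(feats)


rng = np.random.default_rng(0)
N, K, C = 8, 5, 3
feats = rng.normal(size=(N, K))
labels = rng.integers(0, C, size=N)
W, b = rng.normal(size=(C, K)), rng.normal(size=C)
A = rng.normal(size=(C, K, K))
cov = A @ A.transpose(0, 2, 1) / K          # toy SPD covariance per class
mc = explicit_aug_loss(feats, labels, W, b, cov, lam=0.5, M=20000, rng=rng)
bound = isda_upper_bound(feats, labels, W, b, cov, lam=0.5)
print(f"sampled loss {mc:.3f} <= closed-form bound {bound:.3f}")
```

The bound needs only one pass over the un-augmented features, so its cost does not grow with M; with λ = 0 it reduces to the ordinary source cross-entropy.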
Notably, if L∞ can be computed efficiently, we can minimize it directly without explicitly sampling augmented features. Since Eq. (3) is difficult to compute exactly, we derive an easy-to-compute upper bound of L∞, given in the following proposition.

Proposition 1. Suppose that $\tilde{f}_s^i \sim \mathcal{N}(f_s^i, \lambda\Sigma_{y_s^i})$. Then L∞ has the following upper bound:

$$\mathcal{L}_{\infty} \le \bar{\mathcal{L}}_{\infty} = \frac{1}{N}\sum_{i=1}^{N} -\log\frac{e^{w_{y_s^i}^{\top}f_s^i + b_{y_s^i}}}{\sum_{j=1}^{C} e^{w_j^{\top}f_s^i + b_j + \frac{\lambda}{2}v_{jy_s^i}^{\top}\Sigma_{y_s^i}v_{jy_s^i}}} \qquad (4)$$

where $v_{jy_s^i} = w_j - w_{y_s^i}$. We now give the proof.

Proof. According to Eq. (3), we have:

$$\begin{aligned}
\mathcal{L}_{\infty} &= \frac{1}{N}\sum_{i=1}^{N}\mathbb{E}_{\tilde{f}_s^i}\left[\log\left(\sum_{j=1}^{C} e^{v_{jy_s^i}^{\top}\tilde{f}_s^i + (b_j - b_{y_s^i})}\right)\right] \\
&\le \frac{1}{N}\sum_{i=1}^{N}\log\left(\sum_{j=1}^{C}\mathbb{E}_{\tilde{f}_s^i}\left[e^{v_{jy_s^i}^{\top}\tilde{f}_s^i + (b_j - b_{y_s^i})}\right]\right) \\
&= \frac{1}{N}\sum_{i=1}^{N}\log\left(\sum_{j=1}^{C} e^{v_{jy_s^i}^{\top}f_s^i + (b_j - b_{y_s^i}) + \frac{\lambda}{2}v_{jy_s^i}^{\top}\Sigma_{y_s^i}v_{jy_s^i}}\right) = \bar{\mathcal{L}}_{\infty}
\end{aligned} \qquad (5)$$

The inequality follows from Jensen's inequality E[log X] ≤ log E[X] [15], since the logarithmic function log(·) is concave. The last equality uses the moment-generating function of a Gaussian: if $X \sim \mathcal{N}(\mu, \sigma^2)$, then $\mathbb{E}[e^{tX}] = e^{t\mu + \frac{1}{2}\sigma^2 t^2}$, and under the assumption of Proposition 1,

$$v_{jy_s^i}^{\top}\tilde{f}_s^i + (b_j - b_{y_s^i}) \sim \mathcal{N}\left(v_{jy_s^i}^{\top}f_s^i + (b_j - b_{y_s^i}),\ \lambda v_{jy_s^i}^{\top}\Sigma_{y_s^i}v_{jy_s^i}\right) \qquad (6)$$

so setting t = 1 gives the exponent in Eq. (5).

Essentially, Proposition 1 provides a surrogate loss for our implicit data augmentation algorithm: instead of minimizing the exact loss L∞, we can optimize its upper bound $\bar{\mathcal{L}}_{\infty}$ much more efficiently. We then replace the cross-entropy loss in Eq. (1) with the semantic data augmentation loss in Eq. (4), which effectively adapts the source classifier to the target domain with the augmented source features.
2.3 Non-local Attention
With the help of the adversarial loss and the semantic data augmentation loss, we can achieve significant performance gains under the hypothesis that all extracted features are transferable for domain adaptation. Unfortunately, this assumption does not always hold. If untransferable features are forcibly matched, negative transfer may occur. The errors introduced by negative transfer propagate to the feature extractor and, worse, are amplified during training. Our ultimate goal is to
Fig. 2. Detailed structure of non-local attention module.
learn transferable representations across domains. Extracting defective features may deceive the domain discriminator, but it does not truly learn transferable features. One way to solve this problem is to introduce a non-local attention module into conditional adversarial learning. The framework of the non-local attention module is shown in Fig. 2. Given a feature map X ∈ R^{C×H×W} from the feature extractor, we first feed it into convolutional layers to distill two abstract feature maps X1 and X2, where X1, X2 ∈ R^{C/2×H×W}. We then reshape both into R^{C/2×HW} and multiply the transpose of X1 by X2. Finally, a softmax operation yields the non-local attention map F_att = [m_ji] ∈ R^{HW×HW}:

$$m_{ji} = \frac{\exp(X_1^i \cdot X_2^j)}{\sum_{i=1}^{H\times W}\exp(X_1^i \cdot X_2^j)} \qquad (7)$$

where m_ji measures the impact of the i-th position on the j-th position: the more similar the features at two positions, the greater the correlation between them. At the same time, X is also fed into a convolutional layer to distill another abstract feature map X3 ∈ R^{C/2×H×W}, which is reshaped into R^{C/2×HW}. X3 is multiplied by the transpose of F_att, and the result is reshaped to R^{C/2×H×W}. Finally, the result is multiplied by a scale parameter β and summed element-wise with X to obtain the refined feature representation X' ∈ R^{C×H×W}, with N = H × W positions:

$$X'_j = \beta\sum_{i=1}^{N}\left(m_{ji}X_3^i\right) + X_j \qquad (8)$$
where β is a hyper-parameter, following [29]. In the proposed AAT, the features generated by the non-local attention module are fed into the domain adaptation network to achieve domain alignment.
2.4 Overall Networks
Considering the conditional adversarial loss, semantic data augmentation and non-local attention module discussed above, the overall loss of our method is:

$$\mathcal{L} = \mathcal{L}_{adv} + \theta\,\bar{\mathcal{L}}_{\infty} \qquad (9)$$

where θ is a hyper-parameter, set to θ = 1 in this paper. In our method, the feature extractor F generates transferable cross-domain feature representations, so the classifier trained on the non-local attention source features can be directly applied to the non-local attention target features.
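As a rough illustration of how the pieces of Sections 2.1 and 2.3 fit together, the NumPy sketch below implements a non-local block in the spirit of Eqs. (7) and (8) on a single feature map, then forms the CDAN joint variable h = Π(f, p) with the multilinear (outer-product) conditioning used by CDAN [9]. The 1×1 convolutions are written as channel-mixing matrices; the output projection W_out (added so the C/2-channel result can be summed with the C-channel input), the global average pooling and the toy classifier are assumptions, and none of the names come from the authors' code.

```python
import numpy as np


def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)


def non_local_attention(X, W1, W2, W3, W_out, beta):
    """Non-local block following Eqs. (7)-(8).

    X: (C, H, W) feature map. W1, W2, W3: (C//2, C) weights of 1x1 convolutions,
    written as channel-mixing matrices. W_out: (C, C//2) is an ASSUMED 1x1
    projection back to C channels so that the residual sum in Eq. (8) is defined.
    """
    C, H, Wd = X.shape
    x = X.reshape(C, H * Wd)                 # columns are the N = H*W positions
    x1, x2, x3 = W1 @ x, W2 @ x, W3 @ x      # each (C//2, N)
    scores = x1.T @ x2                       # scores[i, j] = X1^i . X2^j
    m = softmax(scores, axis=0).T            # m[j, i], normalised over i (Eq. (7))
    attended = x3 @ m.T                      # column j = sum_i m[j, i] * X3^i
    return (beta * (W_out @ attended) + x).reshape(C, H, Wd)   # Eq. (8)


def multilinear_conditioning(f, p):
    """CDAN's multilinear map: h = flatten(f outer p), fed to the discriminator."""
    return np.outer(f, p).ravel()


rng = np.random.default_rng(0)
C, H, Wd, n_cls = 8, 4, 4, 3
X = rng.normal(size=(C, H, Wd))
W1, W2, W3 = (0.1 * rng.normal(size=(C // 2, C)) for _ in range(3))
W_out = 0.1 * rng.normal(size=(C, C // 2))
Y = non_local_attention(X, W1, W2, W3, W_out, beta=0.5)
f = Y.mean(axis=(1, 2))                      # global average pooling -> (C,)
W_cls = rng.normal(size=(n_cls, C))
p = softmax(W_cls @ f)                       # class prediction
h = multilinear_conditioning(f, p)           # joint variable h = Pi(f, p)
print(Y.shape, h.shape)
```

Setting β = 0 makes the block an identity map, which is one way to check that attention only adds a residual correction to the extracted features.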
3 Experiments
3.1 Data Preparation
We use four benchmarks to verify the effectiveness of our method. Office-31 [16] contains 31 common categories from three domains: Amazon (A), Webcam (W) and DSLR (D). The Amazon images come from www.amazon.com, while Webcam and DSLR contain images taken by a web camera and a digital SLR camera, respectively. ImageCLEF-DA¹ contains 12 common classes shared by Caltech-256 (C), ImageNet ILSVRC 2012 (I) and PASCAL VOC 2012 (P). Office-Home [19] consists of four different domains: Artistic images (Ar), Clip Art (Cl), Product (Pr) and Real-World (Rw); each domain contains 65 categories of objects from office and home environments. VisDA-2017 [14] is a large-scale dataset for sim-to-real transfer that contains more than 280,000 images in 12 categories. We use the synthetic domain (Syn) with 152,397 images as the source domain and the real domain (Re) with 55,388 images as the target domain.
3.2 Implementation Details
For a fair comparison, we use ResNet-50 [5] pre-trained on ImageNet as the backbone network for Office-31, ImageCLEF-DA and Office-Home, and ResNet-101 for VisDA-2017. All experiments are implemented in PyTorch. We use the mini-batch SGD optimizer with momentum 0.9 for optimization.
¹ http://imageclef.org/2014/adaptation.
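Table 1. Accuracy (%) on Office-31 and ImageCLEF-DA for unsupervised domain adaptation (ResNet-50). Columns A→W to W→A and the first Avg are Office-31 tasks; columns I→P to P→C and the second Avg are ImageCLEF-DA tasks.

| Method | A→W | D→W | W→D | A→D | D→A | W→A | Avg | I→P | P→I | I→C | C→I | C→P | P→C | Avg |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| ResNet-50 [5] | 68.4 | 96.7 | 99.3 | 68.9 | 62.5 | 60.7 | 76.1 | 74.8 | 83.9 | 91.5 | 78.0 | 65.5 | 91.2 | 80.7 |
| DANN [3] | 82.0 | 96.9 | 99.1 | 79.7 | 68.2 | 67.4 | 82.2 | 75.0 | 86.0 | 96.2 | 87.0 | 74.3 | 91.5 | 85.0 |
| JAN [10] | 85.4 | 97.4 | 99.8 | 84.7 | 68.6 | 70.0 | 84.3 | 76.8 | 88.0 | 94.7 | 89.5 | 74.2 | 91.7 | 85.8 |
| MADA [13] | 90.0 | 97.4 | 99.6 | 87.8 | 70.3 | 66.4 | 85.2 | 75.0 | 87.9 | 96.0 | 88.8 | 75.2 | 92.2 | 85.8 |
| CAN [31] | 81.5 | 98.2 | 99.7 | 85.5 | 65.9 | 63.4 | 82.4 | 78.2 | 87.5 | 94.2 | 89.5 | 75.8 | 89.2 | 85.7 |
| iCAN [31] | 92.5 | 98.8 | 100.0 | 90.1 | 72.1 | 69.9 | 87.2 | 79.5 | 89.7 | 94.7 | 89.9 | 78.5 | 92.0 | 87.4 |
| CDAN [9] | 93.1 | 98.2 | 100.0 | 89.8 | 70.1 | 68.0 | 86.6 | 76.7 | 90.6 | 97.0 | 90.5 | 74.5 | 93.5 | 87.1 |
| CDAN+E | 94.1 | 98.6 | 100.0 | 92.9 | 71.0 | 69.3 | 87.7 | 77.7 | 90.7 | 97.7 | 91.3 | 74.2 | 94.3 | 87.7 |
| AAT | 95.9 | 98.9 | 100.0 | 93.4 | 75.0 | 75.0 | 89.7 | 78.9 | 93.6 | 99.9 | 94.7 | 75.6 | 96.6 | 89.9 |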
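Table 2. Accuracy (%) on Office-Home (ResNet-50) and VisDA-2017 (ResNet-101) for unsupervised domain adaptation. In the Office-Home columns, A, C, P and R denote Ar, Cl, Pr and Rw.

| Method | A→C | A→P | A→R | C→A | C→P | C→R | P→A | P→C | P→R | R→A | R→C | R→P | Avg | Syn→Re |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| ResNet [5] | 34.9 | 50.0 | 58.0 | 37.4 | 41.9 | 46.2 | 38.5 | 31.2 | 60.4 | 53.9 | 41.2 | 59.9 | 46.1 | 49.4 |
| DAN [8] | 43.6 | 57.0 | 67.9 | 45.8 | 56.5 | 60.4 | 44.0 | 43.6 | 67.7 | 63.1 | 51.5 | 74.3 | 56.3 | 62.8 |
| DANN [3] | 45.6 | 59.3 | 70.1 | 47.0 | 58.5 | 60.9 | 46.1 | 43.7 | 68.5 | 63.2 | 51.8 | 76.8 | 57.6 | 57.4 |
| JAN [10] | 45.9 | 61.2 | 68.9 | 50.4 | 59.7 | 61.0 | 45.8 | 43.4 | 70.3 | 63.9 | 52.4 | 76.8 | 58.3 | 65.7 |
| SymNets [33] | 47.7 | 72.9 | 78.5 | 64.2 | 71.3 | 74.2 | 64.2 | 48.8 | 79.5 | 74.5 | 52.6 | 82.7 | 67.6 | 72.9 |
| AAT | 55.5 | 76.3 | 79.5 | 64.6 | 72.9 | 74.6 | 62.6 | 52.0 | 78.3 | 72.7 | 59.0 | 82.2 | 69.2 | 79.7 |

3.3 Results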
The classification results are shown in Table 1 and Table 2. AAT achieves the highest average accuracy among the compared methods on all four benchmarks. These encouraging results show the importance of non-local attention and semantic data augmentation, and suggest that AAT learns more transferable representations. In standard domain adaptation, adversarial domain adaptation methods (e.g., iCAN [31], CDAN [9] and our AAT) outperform earlier discrepancy-based methods such as DAN [8] and JAN [10], which highlights the value of adversarial training for domain adaptation. Moreover, previous methods only consider local features and ignore global features, whereas AAT leverages the non-local attention mechanism to capture global information.
3.4 Analysis
Feature Visualization. We visualize the network activations of task A→W learned by CDAN and AAT using t-SNE embeddings in Fig. 3. Red points are source samples and blue points are target samples. Figure 3 (a) shows the result of CDAN, which
Fig. 3. (a) and (b): Visualization of the learned representations by t-SNE [11]. (c): A-distance. (d): Convergence analysis on task W→A.
Fig. 4. Attention Visualization of the last convolutional layer of different models on the task A→R on Office-Home.
uses an adversarial network-based method. The source and target domains are not well aligned, and some points are difficult to classify. In contrast, Fig. 3 (b) shows the result of our AAT, which uses the non-local attention mechanism and semantic data augmentation: the source and target domains are aligned very well. Not only are the same-category clusters from different domains very close, but clusters of different categories are also well separated. This result shows that AAT captures more fine-grained information for each category than CDAN. Distribution Discrepancy. The A-distance is a measure of distribution discrepancy, defined as d_A = 2(1 − 2ε), where ε is the error rate of a domain classifier trained to distinguish the source domain from the target domain. To measure the effect of AAT on the domain discrepancy, we report the A-distance of the bottleneck features in Fig. 3 (c). The results show that, benefiting from the guidance of the non-local attention mechanism and semantic data augmentation, AAT effectively helps the adversarial model further reduce the domain discrepancy. Convergence Analysis. We compare the convergence of adversarial adaptation models with and without AAT. To simulate a case close to a real-life scenario, we choose W→A, a difficult transfer task with a large domain shift, and report the change in its test error during training. As shown in Fig. 3 (d), we observe
that after the 1500th iteration, the test error of CDAN stops decreasing. By contrast, the curve of AAT stays below that of CDAN, which implies that the representations learned by our method are more effective. Attention Feature Map. To verify the effectiveness of AAT, we visualize the attention maps of the last convolutional layer of different models in Fig. 4. The second row shows the result of CDAN, and the third row shows our AAT method. Our method focuses on the object during training. This result intuitively shows that our method captures the most important regions, which helps it better solve unsupervised domain adaptation problems.
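The A-distance used in the Distribution Discrepancy analysis above is typically estimated by training a domain classifier on features and measuring its held-out error. The NumPy sketch below is one hedged way to compute d_A = 2(1 − 2ε); the logistic-regression classifier (linear SVMs are another common choice), the 50/50 split and the synthetic Gaussian features in the usage lines are assumptions for illustration, not the protocol behind Fig. 3 (c).

```python
import numpy as np


def proxy_a_distance(src, tgt, steps=500, lr=0.1, seed=0):
    """Estimate d_A = 2(1 - 2*eps) between two sets of features.

    A logistic-regression domain classifier (an assumption; linear SVMs are
    another common choice) is fit on a random half of the pooled samples and
    eps is its error rate on the other half.
    """
    rng = np.random.default_rng(seed)
    X = np.vstack([src, tgt])
    y = np.concatenate([np.zeros(len(src)), np.ones(len(tgt))])  # 0 = source, 1 = target
    idx = rng.permutation(len(X))
    tr, te = idx[: len(X) // 2], idx[len(X) // 2 :]
    mu, sd = X[tr].mean(axis=0), X[tr].std(axis=0) + 1e-8
    Xs = (X - mu) / sd                          # standardise with training statistics
    w, b = np.zeros(X.shape[1]), 0.0
    for _ in range(steps):                      # full-batch gradient descent
        z = np.clip(Xs[tr] @ w + b, -30.0, 30.0)
        g = 1.0 / (1.0 + np.exp(-z)) - y[tr]    # d(logistic loss)/d(logit)
        w -= lr * Xs[tr].T @ g / len(tr)
        b -= lr * g.mean()
    eps = np.mean(((Xs[te] @ w + b) > 0) != (y[te] == 1))
    return 2.0 * (1.0 - 2.0 * eps)


rng = np.random.default_rng(1)
same = proxy_a_distance(rng.normal(size=(2000, 4)), rng.normal(size=(2000, 4)))
shifted = proxy_a_distance(rng.normal(size=(2000, 4)), rng.normal(5.0, 1.0, size=(2000, 4)))
print(f"overlapping domains: {same:.2f}, shifted domains: {shifted:.2f}")
```

Values near 0 mean the classifier cannot tell the domains apart, while values near 2 mean they are easily separable, which is why a lower A-distance is read as better alignment.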
4 Conclusion
In this paper, we propose a new conditional adversarial learning method with a non-local attention module, named non-local networks for sim-to-real adversarial augmentation transfer (AAT). The proposed method uses a non-local attention mechanism to weight the extracted features, which can effectively eliminate the influence of untransferable features. In addition, we leverage semantic data augmentation to enhance the transferability of the classifier. Comprehensive experiments on four benchmark datasets demonstrate the feasibility and effectiveness of the proposed method. In future work, we aim to extend our model to other domain adaptation applications, such as robotic garbage classification and robot object recognition.
Acknowledgment. This work is supported by National Natural Science Fund of China (No. 62106003), the University Synergy Innovation Program of Anhui Province (No.GXXT-2021005) and Open Fund of Chongqing Key Laboratory of Bio-perception and Intelligent Information Processing (No.2020CKL-BPIIP001).
References
1. Etemadi, N.: An elementary proof of the strong law of large numbers. Zeitschrift für Wahrscheinlichkeitstheorie und Verwandte Gebiete 55(1), 119–122 (1981). https://doi.org/10.1007/BF01013465
2. Ganin, Y., Lempitsky, V.: Unsupervised domain adaptation by backpropagation. In: International Conference on Machine Learning, pp. 1180–1189. PMLR (2015)
3. Ganin, Y., et al.: Domain-adversarial training of neural networks. J. Mach. Learn. Res. 17(1), 2096 (2016)
4. Goodfellow, I.J., et al.: Generative adversarial networks. arXiv preprint: arXiv:1406.2661 (2014)
5. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
6. Li, J., Chen, E., Ding, Z., Zhu, L., Lu, K., Shen, H.T.: Maximum density divergence for domain adaptation. IEEE Trans. Pattern Anal. Mach. Intell. 43, 3918–3930 (2020)
7. Li, S., Xie, M., Gong, K., Liu, C.H., Wang, Y., Li, W.: Transferable semantic augmentation for domain adaptation. arXiv preprint: arXiv:2103.12562 (2021)
8. Long, M., Cao, Y., Wang, J., Jordan, M.: Learning transferable features with deep adaptation networks. In: International Conference on Machine Learning, pp. 97–105. PMLR (2015)
9. Long, M., Cao, Z., Wang, J., Jordan, M.I.: Conditional adversarial domain adaptation. arXiv preprint: arXiv:1705.10667 (2017)
10. Long, M., Zhu, H., Wang, J., Jordan, M.I.: Deep transfer learning with joint adaptation networks. In: International Conference on Machine Learning, pp. 2208–2217. PMLR (2017)
11. Van der Maaten, L., Hinton, G.: Visualizing data using t-SNE. J. Mach. Learn. Res. 9(11) (2008)
12. Nguyen, A.T., Tran, T., Gal, Y., Baydin, A.G.: Domain invariant representation learning with domain density transformations. arXiv preprint: arXiv:2102.05082 (2021)
13. Pei, Z., Cao, Z., Long, M., Wang, J.: Multi-adversarial domain adaptation. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 32 (2018)
14. Peng, X., Usman, B., Kaushik, N., Hoffman, J., Wang, D., Saenko, K.: VisDA: the visual domain adaptation challenge. arXiv preprint: arXiv:1710.06924 (2017)
15. Román-Flores, H., Flores-Franulič, A., Chalco-Cano, Y.: A Jensen type inequality for fuzzy integrals. Inf. Sci. 177(15), 3192–3201 (2007)
16. Saenko, K., Kulis, B., Fritz, M., Darrell, T.: Adapting visual category models to new domains. In: Daniilidis, K., Maragos, P., Paragios, N. (eds.) ECCV 2010. LNCS, vol. 6314, pp. 213–226. Springer, Heidelberg (2010). https://doi.org/10.1007/978-3-642-15561-1_16
17. Sun, D., et al.: A focally discriminative loss for unsupervised domain adaptation. In: Mantoro, T., Lee, M., Ayu, M.A., Wong, K.W., Hidayanto, A.N. (eds.) ICONIP 2021. LNCS, vol. 13108, pp. 54–64. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-92185-9_5
18. Tanwani, A.K.: Domain-invariant representation learning for sim-to-real transfer. arXiv preprint: arXiv:2011.07589 (2020)
19. Venkateswara, H., Eusebio, J., Chakraborty, S., Panchanathan, S.: Deep hashing network for unsupervised domain adaptation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5018–5027 (2017)
20. Wang, H., Yang, W., Wang, J., Wang, R., Lan, L., Geng, M.: Pairwise similarity regularization for adversarial domain adaptation. In: Proceedings of the 28th ACM International Conference on Multimedia, pp. 2409–2418 (2020)
21. Wang, M., et al.: Informative pairs mining based adaptive metric learning for adversarial domain adaptation. Neural Netw. 151, 238–249 (2022)
22. Wang, M., et al.: InterBN: channel fusion for adversarial unsupervised domain adaptation. In: MM 2021: ACM Multimedia Conference, Virtual Event, 20–24 October 2021, pp. 3691–3700, China (2021)
23. Wang, S., Zhang, L.: Self-adaptive re-weighted adversarial domain adaptation. IJCAI (2020)
24. Wang, S., Zhang, L., Wang, P., Wang, M., Zhang, X.: BP-triplet net for unsupervised domain adaptation: a Bayesian perspective. Pattern Recogn. 133, 108993 (2022)
25. Wang, Y., Huang, G., Song, S., Pan, X., Xia, Y., Wu, C.: Regularizing deep networks with semantic data augmentation. IEEE Trans. Pattern Anal. Mach. Intell. 44, 3733–3748 (2021)
26. Yang, X., Wang, M., Tao, D.: Person re-identification with metric learning using privileged information. IEEE Trans. Image Process. 27(2), 791–805 (2017)
27. Yang, X., Wang, S., Dong, J., Dong, J., Wang, M., Chua, T.S.: Video moment retrieval with cross-modal neural architecture search. TIP 31, 1204–1216 (2022)
28. Yang, X., Zhou, P., Wang, M.: Person reidentification via structural deep metric learning. IEEE Trans. Neural Netw. Learn. Syst. 30(10), 2987–2998 (2018)
29. Zhang, H., Goodfellow, I., Metaxas, D., Odena, A.: Self-attention generative adversarial networks. In: International Conference on Machine Learning, pp. 7354–7363. PMLR (2019)
30. Zhang, J., et al.: Supplement file of VR-goggles for robots: real-to-sim domain adaptation for visual control. Training 853(840), 715 (2018)
31. Zhang, W., Ouyang, W., Li, W., Xu, D.: Collaborative and adversarial network for unsupervised domain adaptation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3801–3809 (2018)
32. Zhang, W., Zhang, X., Liao, Q., Yang, W., Lan, L., Luo, Z.: Robust normalized squares maximization for unsupervised domain adaptation. In: Proceedings of the 29th ACM International Conference on Information & Knowledge Management, pp. 2317–2320 (2020)
33. Zhang, Y., Tang, H., Jia, K., Tan, M.: Domain-symmetric networks for adversarial domain adaptation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 5031–5040 (2019)
34. Zhang, Y., Liu, T., Long, M., Jordan, M.: Bridging theory and algorithm for domain adaptation. In: International Conference on Machine Learning, pp. 7404–7413. PMLR (2019)
Aggregating Intra-class and Inter-class Information for Multi-label Text Classification
Xianze Wu¹(✉), Dongyu Ru¹, Weinan Zhang¹(✉), Yong Yu¹, and Ziming Feng²
¹ Shanghai Jiao Tong University, Shanghai, China
{xzwu,maxru,wnzhang,yyu}@apex.sjtu.edu.cn
² China Merchants Bank Credit Card Center, Shanghai, China
Abstract. This paper is concerned with the multi-label text classification (MLTC) task, whose goal is to assign one or more categorical labels to a document. Two critical kinds of information for this task are intra-class and inter-class information. The former describes the distribution of samples belonging to the same category, and the latter models the relationships between labels, such as label co-occurrence and label hierarchy. However, previous methods focus on only one of them instead of combining both. This paper proposes a novel two-branch architecture to capture both intra-class and inter-class information. Experimental results show that considering both kinds of information improves the performance of the model. Moreover, our model achieves competitive results on two widely used datasets.
1 Introduction
Multi-label text classification (MLTC) focuses on assigning one or more class labels to a document given a candidate label set. It has been applied in many fields, such as tag recommendation [7], sentiment analysis [8], and text tagging on social media [18]. It differs from multi-class text classification, which aims to predict exactly one of a set of mutually exclusive labels for a document [6]. Two types of information should be captured for the MLTC task. One is intra-class information, which concerns the distribution of samples belonging to the same category. The other is inter-class information, which models relationships between classes, such as label co-occurrence and hierarchy. Prior efforts on multi-label text classification mainly focus on learning enhanced text representations [1,13,20,22]. These models feed the text representation into a set of linear classifiers, each of which predicts whether the given document belongs to a certain class. During training, the linear classifiers capture intra-class information by learning the decision boundaries of the corresponding classes. However, these methods neglect inter-class information, since the linear classifiers are trained independently and never interact with each other. Recently, extracting inter-class information has attracted researchers' attention [15,16,19,23]. Some studies construct a label graph from inter-class information and convert the graph into node features via random walk-based node embedding methods [23] or graph neural networks (GNNs) [15,16,19].
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 239–250, 2023.
https://doi.org/10.1007/978-981-99-1639-9_20
X. Wu et al.
The probability that a document belongs to a class is calculated by the dot product of the document features and the corresponding node features. These methods capture inter-class information at the expense of the expressiveness of intra-class information. For node embedding-based methods, node embeddings are optimized in advance with the objective of reconstructing neighbors. The optimized node embeddings, which encode information for reconstruction rather than for text classification, occupy the limited capacity originally used for modeling intra-class information. For GNN-based methods, the message passing process harms the expressiveness of intra-class information because the decision boundaries of classes receive noise from other nodes. In this paper, we propose the Aggregating Intra-class and Inter-class information Framework (AIIF) for MLTC. AIIF consists of a text encoder and a two-branch classification layer. In the classification layer, the linear branch applies multiple linear classifiers to capture intra-class information, while the graph-assisted branch applies graph neural networks to a label graph, where the message passing process captures inter-class information. Each branch takes the text feature as input and makes predictions independently. The two branches' predictions are aggregated by a subsequent fusion module, which is optimized during training. With this divide-and-conquer architecture, AIIF captures both intra- and inter-class information and prevents the modeling of one from interfering with the other. Besides, AIIF supports plug-and-play usage, i.e., existing studies focusing on enhanced text representation or on extracting inter-class information can be coupled with AIIF by serving as the text encoder or the graph-assisted branch. To evaluate the effectiveness of AIIF, we implement an instance of AIIF with BERT [5] and GCN [9], and evaluate it on the widely used RCV1 and AAPD datasets.
Experimental results show that the instance outperforms its variants without the two-branch classifier by a large margin. Besides, the instance achieves state-of-the-art results on the widely used RCV1 dataset and competitive scores on the AAPD dataset. The main contributions of this paper are as follows:
– We propose AIIF, a novel MLTC framework that captures both intra-class and inter-class information.
– We implement an instance of AIIF. Experimental results show that the instance outperforms the baselines and achieves competitive results on two public MLTC datasets.
– To the best of our knowledge, we are the first to analyze MLTC from the perspective of intra- and inter-class information. We hope this work can provide a new perspective to the community.
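For readers who want to see how the two branches interact, the NumPy sketch below mirrors the two-branch classification layer described above and formalized in Eqs. (2) to (5) of Section 3. The text feature is a random stand-in for the BERT [CLS] vector; the row-normalized toy adjacency, the initial label features H0 and the final sigmoid are assumptions, and the function names are hypothetical rather than the authors' code.

```python
import numpy as np


def gcn_layer(A, H, W):
    """One graph convolution as in Eq. (2): ReLU(A H W)."""
    return np.maximum(A @ H @ W, 0.0)


def aiif_logits(h_t, A, H0, gcn_weights, Wg, Wp, Lw, mu):
    """Two-branch classification layer in the spirit of Eqs. (2)-(5).

    h_t: (d_t,) document feature from the text encoder (e.g. BERT's [CLS] vector).
    A: (T, T) label-graph adjacency; any normalisation is assumed done already.
    H0: (T, d_0) initial label node features (their origin is an assumption).
    gcn_weights: list of GCN weight matrices; Wg: (d_t, d_g); Wp: (d_t, d_p);
    Lw: (d_p, T) linear-classifier weights; mu: (T,) per-label fusion ratio.
    """
    H = H0
    for W_l in gcn_weights:             # message passing over the label graph
        H = gcn_layer(A, H, W_l)
    z_g = H @ (h_t @ Wg)                # graph branch: z_i^g = L_G[i] . f^g(h_t)
    z_w = (h_t @ Wp) @ Lw               # linear branch: one linear classifier per label
    return z_g * mu + z_w * (1.0 - mu)  # fusion, Eq. (5), element-wise per label


rng = np.random.default_rng(0)
T, d_t, d_0, d_g, d_p = 4, 6, 3, 5, 7
A = np.eye(T) + (rng.random((T, T)) > 0.5)    # toy label graph with self-loops
A = A / A.sum(axis=1, keepdims=True)          # row-normalised
H0 = rng.normal(size=(T, d_0))
gcn_weights = [rng.normal(size=(d_0, d_g)), rng.normal(size=(d_g, d_g))]
Wg, Wp = rng.normal(size=(d_t, d_g)), rng.normal(size=(d_t, d_p))
Lw = rng.normal(size=(d_p, T))
h_t = rng.normal(size=d_t)
mu = np.full(T, 0.5)
probs = 1.0 / (1.0 + np.exp(-aiif_logits(h_t, A, H0, gcn_weights, Wg, Wp, Lw, mu)))
print(probs.round(3))
```

Setting μ to all ones or all zeros recovers a purely graph-assisted or purely linear classifier, which is how the per-label fusion ratio trades inter-class against intra-class information.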
2 Related Work
As mentioned in Sect. 1, existing MLTC work focuses on two directions: improving text representation and extracting the inter-class information.
Aggregating Intra-class and Inter-class Information
To obtain a good text representation, many neural models have been applied, such as CNN [13], RNN [14,20,22], combinations of CNN and RNN [10], and BERT [1,3]. Some models improve text representations with the interaction between the input document and labels [6,20,22]. In addition, some methods construct a graph of words or documents to capture non-consecutive and long-distance semantics within a document or across the whole corpus [17,21]. We argue that these methods neglect inter-class information. Capturing inter-class information has attracted much attention in recent years. The main idea is to model relationships between labels as graphs and to guide multi-label prediction with the graph representation. For example, Zhang et al. [23] construct a category graph according to label correlations and use a random-walk-based method to encode the graph. Most subsequent work applies neural networks to encode label graphs, such as Tree-LSTM [24] and variants of GCN [15,16,24]. Lu et al. [15] propose aggregating knowledge from multiple label graphs. Ma et al. [16] propose predicting the label graph for each document. However, these methods sometimes perform worse than their variants without label graphs [2,15], possibly because they ignore intra-class information. The above methods focus on either intra-class or inter-class information. In contrast, AIIF applies a two-branch architecture to capture both.
3 Methods
As shown in Fig. 1, AIIF separates the modeling of intra- and inter-class information with a two-branch classification layer. The classification layer takes as input the representation of the input document obtained by the text encoder. The linear branch captures intra-class information with a set of linear binary classifiers. The graph-assisted classifier branch models inter-class information by first encoding the label graph as node features and then calculating the dot product between the node features and the text features. The fusion module combines the predictions of the linear and graph-assisted branches into the probability that the input document belongs to each category. Problem Formulation. Given the label space $L = \{l_1, l_2, \ldots, l_T\}$ and an input document x, MLTC aims at predicting a label set $L_{pred} \subseteq L$ for x. In the remainder of this section, we first describe the workflow of AIIF when BERT serves as the text encoder and GCN serves as the graph encoder. Then we introduce the training of AIIF. Finally, we introduce the method of constructing the label graph.
3.1 AIIF Models
Text Encoder. Given an input document $x = \{x_1, x_2, \ldots, x_n\}$, where $x_i$ is the i-th token in the document, BERT converts x into the text feature

$$H_t = \mathrm{BERT}(x), \qquad (1)$$
Fig. 1. The overall architecture of AIIF.
where $H_t \in \mathbb{R}^{d_t}$ is the text feature. As recommended in [5], we prepend a special “[CLS]” token to $x$ before feeding it into the BERT model, and take the output feature of the “[CLS]” token as the text feature of the input document $x$.

Graph-assisted Classifier. Given the category graph $G$, we obtain its node features via a graph convolutional network (GCN). We choose GCN for two reasons. First, GCN can capture the complex topological structure between labels through message passing. Second, GCN can be trained jointly with the other parts of the model in an end-to-end manner, so that all components are optimized together. We denote each graph convolutional layer as a non-linear function $f(\cdot, \cdot)$, which takes the adjacency matrix $A'$ and the node features $H_g^l \in \mathbb{R}^{T \times d_g^l}$ as input. Here, $T$ is the number of nodes (labels) and $d_g^l$ is the dimension of the node features at layer $l$:
$$H_g^{l+1} = f(H_g^l, A') = \sigma(A' H_g^l W^l). \qquad (2)$$
Here, $W^l \in \mathbb{R}^{d_g^l \times d_g^{l+1}}$ is a learnable weight matrix and $\sigma(\cdot)$ is a non-linear activation function, implemented with ReLU in this paper. We treat the output of the last convolutional layer as the node features, denoted by $L_G$. The graph-assisted classifier produces the prediction for each label
according to the dot product between the text features and the node features:
$$f^g(H_t) = H_t W^g, \qquad z_i^g = L_G^i \cdot f^g(H_t), \qquad (3)$$
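where $f^g(\cdot)$ is a linear transformation with $W^g \in \mathbb{R}^{d_t \times d_g}$, and $L_G^i$ is the node feature of the $i$-th label (the $i$-th row of $L_G$). The transformation is required to match the dimension of the text feature to that of $L_G$.

Linear Classifiers and Fusion Module. The linear classifiers make a prediction for each label as
$$f^p(H_t) = H_t W^p, \qquad z_i^w = f^p(H_t) \cdot L_W^i, \qquad (4)$$
where $f^p(\cdot)$ is a linear transformation with $W^p \in \mathbb{R}^{d_t \times d_p}$, and $L_W$ is the weight of the linear classifiers, $L_W^i$ being the weight vector of the $i$-th binary classifier. The fusion module aggregates the predictions of the linear classifiers $z^w$ and those of the graph-assisted classifier $z^g$ by an element-wise weighted sum ($\odot$ denotes element-wise multiplication):
$$z = z^g \odot \mu + z^w \odot (1 - \mu), \qquad (5)$$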
where $\mu \in \mathbb{R}^T$ is a trainable parameter vector representing, for each category, the proportion of inter-class information used in the prediction.

3.2 AIIF Training
We adopt a two-stage training procedure for AIIF: (1) training the text encoder and (2) training the classification layer.

Training the Text Encoder. In the first stage, we train the text encoder with the supervised signals from the dataset to obtain well-learned text features. Specifically, we add a linear classifier after the text encoder and compare the predictions of this linear classifier with the ground-truth labels using a squared hinge loss:
$$\mathrm{SH}(y, \hat{y}) = \sum_{i=1}^{T} \max(0, 1 - y_i \cdot \hat{y}_i)^2, \qquad (6)$$
where $\hat{y} \in \mathbb{R}^T$ is the prediction of the linear classifier. After training, only the parameters of the text encoder are saved for later use.

Training the Classification Layer. In the second stage, we freeze the parameters of the text encoder. The remaining parts of AIIF are randomly initialized and trained with the binary cross-entropy (BCE) loss. Given the ground-truth label vector $y$ and a vector of predicted probabilities $p$, the BCE loss is calculated as
$$\mathrm{BCE}(p, y) = -\sum_{i=1}^{T} \left( y_i \log p_i + (1 - y_i) \log(1 - p_i) \right). \qquad (7)$$
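To make Eqs. (6) and (7) concrete, here is a minimal NumPy sketch of both losses for a single document. It assumes the labels in the squared hinge loss are encoded as ±1 (the usual hinge-loss convention; the encoding is not stated above), while BCE uses {0, 1} labels and probabilities. The function names are illustrative, not taken from the authors' code.

```python
import numpy as np


def squared_hinge(y, y_hat):
    """Eq. (6): SH(y, y_hat) = sum_i max(0, 1 - y_i * y_hat_i)^2.

    Assumes y is encoded in {-1, +1} (usual hinge convention) and y_hat holds
    the raw scores of the linear classifier.
    """
    y = np.asarray(y, dtype=float)
    y_hat = np.asarray(y_hat, dtype=float)
    return float(np.sum(np.maximum(0.0, 1.0 - y * y_hat) ** 2))


def bce(p, y, eps=1e-12):
    """Eq. (7): BCE(p, y) = -sum_i [y_i log p_i + (1 - y_i) log(1 - p_i)].

    y is a {0, 1} label vector and p holds per-label probabilities; p is
    clipped to [eps, 1 - eps] so that log(0) never occurs.
    """
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)
    y = np.asarray(y, dtype=float)
    return float(-np.sum(y * np.log(p) + (1.0 - y) * np.log(1.0 - p)))


# Toy example with T = 3 labels.
print(squared_hinge([1, -1, 1], [2.0, -0.5, 0.0]))  # 0 + 0.25 + 1 = 1.25
print(bce([0.9, 0.2, 0.5], [1, 0, 1]))              # about 1.0217
```

Squaring the hinge term penalizes large margin violations more strongly than the plain hinge loss, and the clipping in `bce` is only a numerical safeguard.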
We apply the BCE loss to the linear classifiers, the graph-assisted classifier, and the whole model. The final loss $L$ is calculated as
$$L_f = \mathrm{BCE}(\mathrm{sigmoid}(z), y), \quad L_w = \mathrm{BCE}(\mathrm{sigmoid}(z^w), y), \quad L_g = \mathrm{BCE}(\mathrm{sigmoid}(z^g), y),$$
$$L = L_f + \alpha L_w + \beta L_g, \qquad (8)$$
where $\alpha$ and $\beta$ are hyper-parameters used to balance these losses. To initialize the vertex embedding matrix, one option is to mean-pool the word embeddings of the tokens in the text of each label. However, some MLTC datasets do not provide label text, or provide it only as abbreviations, which prevents us from obtaining a good initial vertex embedding matrix. Thus, we initialize the embedding of each category label from the documents that belong to that category. More specifically, if a set of documents $\{x_1, x_2, \dots, x_k\}$ has the ground-truth label $l$, then the initial vertex embedding of $l$ (its row of $H_g^0$) is
$$H_{g,l}^0 = \frac{1}{k} \sum_{j=1}^{k} \mathrm{BERT}(x_j). \qquad (9)$$

3.3 Label Graph Construction
Following [4], we create the label graph from the co-occurrence patterns between labels in the dataset, as follows. First, we count the co-occurrences of label pairs $(l_i, l_j)$ in the training set to obtain the label correlation matrix $M \in \mathbb{R}^{T \times T}$. Then, we calculate the conditional probability $P(l_j \mid l_i)$ that $l_j$ appears when $l_i$ appears as
$$P(l_j \mid l_i) = M_{ij} / N_i, \qquad (10)$$
where $N_i$ denotes the number of occurrences of $l_i$ in the training set. Next, we apply a threshold $\tau$ to filter out noisy, rare co-occurrences:
$$A_{ij} = \begin{cases} 0, & \text{if } P(l_j \mid l_i) < \tau, \\ 1, & \text{otherwise,} \end{cases} \qquad (11)$$
where $A$ is the binary adjacency matrix. However, directly applying $A$ in the GCN may cause the over-smoothing problem [12]. To alleviate this issue, we re-weight $A$ as follows to obtain the final adjacency matrix $A'$:
$$A'_{ij} = \begin{cases} p \, A_{ij} \Big/ \sum_{j'=1,\, j' \neq i}^{T} A_{ij'}, & \text{if } i \neq j, \\ 1 - p, & \text{if } i = j. \end{cases} \qquad (12)$$
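The label-graph construction of Eqs. (10)-(12) and the two-branch classification layer of Eqs. (2)-(5) can be sketched together in NumPy as below. This is an illustrative forward pass only, under stated assumptions: random arrays stand in for BERT features and learned weights, the initial node features follow Eq. (9) as a per-label mean of document features, ReLU is applied after every GCN layer as Eq. (2) is written, and μ is a fixed vector rather than a trained one. All function and variable names are hypothetical.

```python
import numpy as np


def build_adjacency(Y, tau=0.5, p=0.2):
    """Label graph of Sect. 3.3 from a binary label matrix Y (num_docs x T)."""
    Y = np.asarray(Y, dtype=float)
    M = Y.T @ Y                                # co-occurrence counts M_ij
    N = np.diag(M)                             # N_i: occurrences of label l_i
    P = M / np.maximum(N[:, None], 1.0)        # Eq. (10): P(l_j | l_i), row i
    A = (P >= tau).astype(float)               # Eq. (11): binarize with tau
    np.fill_diagonal(A, 0.0)                   # self-loops get weight 1 - p below
    deg = A.sum(axis=1, keepdims=True)         # number of neighbours j != i
    A_prime = p * A / np.maximum(deg, 1e-12)   # Eq. (12), off-diagonal part
    np.fill_diagonal(A_prime, 1.0 - p)         # Eq. (12), diagonal part
    return A_prime


def relu(x):
    return np.maximum(x, 0.0)


def aiif_forward(H_t, H_g0, A_prime, gcn_weights, W_g, W_p, L_W, mu):
    """Two-branch classification layer for one document (Eqs. (2)-(5))."""
    H = H_g0
    for W in gcn_weights:
        H = relu(A_prime @ H @ W)              # Eq. (2), ReLU after every layer
    L_G = H                                    # node features, shape (T, d_g)
    z_g = L_G @ (H_t @ W_g)                    # Eq. (3): z^g_i = L_G^i . f^g(H_t)
    z_w = (H_t @ W_p) @ L_W                    # Eq. (4): z^w_i = f^p(H_t) . L_W^i
    z = mu * z_g + (1.0 - mu) * z_w            # Eq. (5), element-wise fusion
    return z, z_g, z_w


# Toy run: T = 4 labels; random arrays stand in for BERT features and weights.
rng = np.random.default_rng(0)
T, d_t, d_g, d_p = 4, 8, 6, 5
Y = np.array([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 1, 0]])
A_prime = build_adjacency(Y, tau=0.5, p=0.2)

doc_feats = rng.normal(size=(len(Y), d_t))            # stand-in for BERT(x_j)
H_g0 = (Y.T @ doc_feats) / Y.sum(axis=0)[:, None]     # Eq. (9): per-label mean
gcn_weights = [rng.normal(scale=0.1, size=(d_t, d_g)),
               rng.normal(scale=0.1, size=(d_g, d_g))]
W_g = rng.normal(scale=0.1, size=(d_t, d_g))
W_p = rng.normal(scale=0.1, size=(d_t, d_p))
L_W = rng.normal(scale=0.1, size=(d_p, T))
mu = np.full(T, 0.5)

H_t = rng.normal(size=d_t)                            # stand-in for the [CLS] feature
z, z_g, z_w = aiif_forward(H_t, H_g0, A_prime, gcn_weights, W_g, W_p, L_W, mu)
probs = 1.0 / (1.0 + np.exp(-z))                      # per-label probabilities
print(np.round(A_prime, 2))
print(np.round(probs, 3))
```

In this toy label set, label 3 occurs with label 1 in half of its documents and with label 4 in the other half, so both pass the threshold τ = 0.5 and share the neighbour weight p = 0.2 equally, while the self-loop keeps 1 − p = 0.8.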
Table 1. Summary of the datasets. N is the number of samples in the training and validation sets, M is the size of the test set, W denotes the average document length, and C̃ is the average number of labels per sample.

| Dataset | N | M | W | C̃ |
|---|---|---|---|---|
| RCV1 | 23,149 | 781,265 | 223.2 | 3.2 |
| AAPD | 54,840 | 1,000 | 155.9 | 2.4 |

4 Experiments

4.1 Datasets and Evaluations
We perform experiments on two widely used MLTC datasets: RCV1 [11] and AAPD. For a fair comparison, we follow the dataset split used in previous work [20]. The statistics of the datasets are shown in Table 1. Following previous work [16,20], we use two metrics for performance evaluation: precision at top $k$ (P@k) and normalized discounted cumulative gain at top $k$ (nDCG@k). Given the ground-truth binary vector $y \in \{0, 1\}^T$, P@k is defined as
$$P@k = \frac{1}{k} \sum_{l=1}^{k} y_{\mathrm{rank}(l)}, \qquad (13)$$
where $\mathrm{rank}(l)$ is the index of the label with the $l$-th highest predicted score. nDCG@k is defined as
$$DCG@k = \sum_{l=1}^{k} \frac{y_{\mathrm{rank}(l)}}{\log(l + 1)}, \qquad iDCG@k = \sum_{l=1}^{\min(k, \|y\|_0)} \frac{1}{\log(l + 1)}, \qquad nDCG@k = \frac{DCG@k}{iDCG@k}. \qquad (14)$$

4.2 Baselines
We select the following methods as baselines. (1) DXML: Zhang et al. [23] construct a label graph based on the co-occurrence between labels and apply a random-walk-based method to obtain node features. (2) AttentionXML: You et al. [22] adopt a multi-label attention mechanism to perform hierarchical classification. (3) LSAN: Xiao et al. [20] use an attention mechanism to model the relations between document words and labels. (4) LDGN: Ma et al. [16] use a GCN to encode a static label graph and a dynamic, text-specific label graph for the prediction of each text; it achieved state-of-the-art results on the RCV1 and AAPD datasets. Unless otherwise noted, baseline results are taken from the original papers.
4.3 Implementation Details
We use BERT [5] as the text encoder because it produces strong contextualized text representations and has achieved great success in many NLP tasks. The pre-trained BERT model used in this paper is provided by the transformers¹ library; we use the BERT-base checkpoint. We set $d_t = 768$ and $d_g = d_p = 256$. When constructing the label graph, we use $\tau = 0.5$ and $p = 0.2$. For all training stages, we use the Adam optimizer. The initial learning rate $lr$ is $5 \times 10^{-4}$, except for fine-tuning BERT in the first training stage, where $lr$ is $5 \times 10^{-5}$. We use learning-rate warm-up during the first 10% of the training process and apply linear learning-rate decay for the remainder. We use early stopping, and the maximum number of training epochs for each stage is 20. We set $\alpha = 1$ and $\beta = 1$. Each document is truncated to a length of 350 for the RCV1 dataset and 250 for the AAPD dataset.

Table 2. Comparison of AIIF with previous methods. † indicates that the scores are collected from [20]. Best results are shown in bold.
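| Method | RCV1 P@1 | RCV1 P@3 | RCV1 P@5 | RCV1 nDCG@3 | RCV1 nDCG@5 | AAPD P@1 | AAPD P@3 | AAPD P@5 | AAPD nDCG@3 | AAPD nDCG@5 |
|---|---|---|---|---|---|---|---|---|---|---|
| DXML | 94.04† | 78.65† | 54.38† | 89.83† | 90.21† | 80.54† | 56.30† | 39.16† | 77.23† | 80.99† |
| AttentionXML | 96.41† | 80.91† | 56.38† | 91.88† | 92.70† | 83.02† | 58.72† | 40.56† | 78.01† | 82.31† |
| LSAN | 96.81 | 81.89 | 56.92 | 92.83 | 93.43 | 85.28 | 61.12 | 41.84 | 80.84 | 84.78 |
| LDGN | 97.12 | 82.26 | 57.29 | 93.80 | 95.03 | 86.24 | 61.95 | **42.29** | **83.32** | **86.85** |
| BERT (ours) | 97.18 | 83.36 | 57.61 | 94.19 | 94.48 | 86.60 | **62.40** | 41.58 | 81.75 | 85.20 |
| AIIF | **97.50** | **84.07** | **58.22** | **94.85** | **95.20** | **86.9** | **62.40** | 42.06 | 82.34 | 85.81 |

4.4 Main Results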
We compare AIIF with previous methods; the results are shown in Table 2. From the results, we observe the following.

1. AIIF outperforms BERT. Except for a tie in P@3 on AAPD, AIIF outperforms BERT on every metric on both datasets, with improvements between 0.70% and 1.15% on the RCV1 dataset and between 0.23% and 1.44% on the AAPD dataset. We attribute this advantage to AIIF capturing both intra- and inter-class information, rather than intra-class information alone.
2. AIIF achieves competitive results on both datasets. Compared with the previous state-of-the-art method LDGN, AIIF outperforms LDGN on all metrics on the RCV1-v2 dataset, with improvements between 0.27% and 2.35%; on the AAPD dataset, AIIF achieves results close to those of LDGN, with differences between −1.20% and 0.73% across metrics.

¹ https://github.com/huggingface/transformers.
4.5 Ablation Study
Table 3. Comparison of AIIF with its variants on the MLTC datasets. AIIF-L denotes the model consisting of the text encoder and the linear classifiers; AIIF-G denotes the model consisting of the text encoder and the graph-assisted classifier. The best results are shown in bold.

| Method | RCV1 P@1 | RCV1 P@3 | RCV1 P@5 | RCV1 nDCG@3 | RCV1 nDCG@5 | AAPD P@1 | AAPD P@3 | AAPD P@5 | AAPD nDCG@3 | AAPD nDCG@5 |
|---|---|---|---|---|---|---|---|---|---|---|
| AIIF | **97.50** | **84.07** | **58.22** | **94.85** | **95.20** | **86.80** | 62.40 | **42.06** | **82.34** | **85.81** |
| AIIF-L | 97.34 | 83.65 | 57.67 | 94.48 | 94.63 | 86.10 | **62.50** | 41.84 | 82.13 | 85.52 |
| AIIF-G | 96.40 | 83.99 | 58.17 | 94.53 | 94.89 | 85.80 | 62.37 | 41.90 | 82.09 | 85.51 |
To further analyze the effectiveness of the two-branch architecture, we compare AIIF with two variants: (1) AIIF-L, in which we remove the graph-assisted classifier and the fusion module from AIIF and treat the linear classifiers' predictions as the final prediction; and (2) AIIF-G, in which we remove the linear classifiers and the fusion module from AIIF and treat the graph-assisted classifier's predictions as the final prediction. AIIF, AIIF-L, and AIIF-G are trained with the same procedure. The results are shown in Table 3. In most cases, removing either branch of AIIF degrades performance on both datasets. Taking the P@1 score on the AAPD dataset as an example, AIIF-L is inferior to AIIF by 0.81%, and AIIF-G is inferior to AIIF by 1.27%. This performance drop demonstrates the effectiveness of the proposed two-branch classifier.

4.6 Performance on the Tail Labels
As mentioned in Sect. 1, previous work shows that encoding inter-class information achieves promising results on tail labels [19]. We are interested in whether introducing intra-class information further improves the results on tail labels. Thus, we evaluate AIIF and its variants with propensity-scored precision at $k$ (PSP@k), which is calculated as
$$PSP@k = \frac{1}{k} \sum_{l=1}^{k} \frac{y_{\mathrm{rank}(l)}}{P_{\mathrm{rank}(l)}}, \qquad (15)$$
where $P_{\mathrm{rank}(l)}$ is the propensity score of the label ranked at position $l$.
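The three ranking metrics, P@k (Eq. (13)), nDCG@k (Eq. (14)), and PSP@k (Eq. (15)), can be computed per document as in the following NumPy sketch. Propensity scores are taken as given inputs, since how they are estimated is not specified here, and the values in the example are hypothetical. The natural logarithm is used for the discounts; its base cancels in the nDCG ratio.

```python
import numpy as np


def top_k(scores, k):
    # Label indices rank(1), ..., rank(k): the k highest scores, best first.
    return np.argsort(-np.asarray(scores, dtype=float), kind="stable")[:k]


def precision_at_k(y, scores, k):
    # Eq. (13): fraction of the top-k predicted labels that are relevant.
    y = np.asarray(y, dtype=float)
    return float(y[top_k(scores, k)].sum() / k)


def ndcg_at_k(y, scores, k):
    # Eq. (14): DCG@k / iDCG@k, with discounts 1 / log(l + 1) for l = 1..k.
    y = np.asarray(y, dtype=float)
    discounts = 1.0 / np.log(np.arange(2, k + 2))
    idx = top_k(scores, k)
    dcg = float((y[idx] * discounts[:idx.size]).sum())
    n_ideal = min(k, int(np.count_nonzero(y)))        # min(k, ||y||_0)
    idcg = float(discounts[:n_ideal].sum())
    return dcg / idcg if idcg > 0 else 0.0


def psp_at_k(y, scores, propensity, k):
    # Eq. (15): each hit is divided by its label's propensity score, so hits on
    # rare (low-propensity) tail labels count more.
    y = np.asarray(y, dtype=float)
    propensity = np.asarray(propensity, dtype=float)
    idx = top_k(scores, k)
    return float((y[idx] / propensity[idx]).sum() / k)


y = [1, 0, 1, 0, 0]                      # ground truth for T = 5 labels
scores = [0.9, 0.8, 0.3, 0.1, 0.05]      # model scores
propensity = [0.9, 0.5, 0.3, 0.5, 0.5]   # hypothetical propensity scores
print(precision_at_k(y, scores, 3))      # 2 hits in the top 3 -> 0.666...
print(ndcg_at_k(y, scores, 3))
print(psp_at_k(y, scores, propensity, 3))
```

Here the hit on the low-propensity label (index 2) contributes three times as much to PSP@3 as it would to P@3, which is why the metric is sensitive to tail-label performance.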
Results are shown in Fig. 2. AIIF outperforms AIIF-G on both datasets, indicating that even when a model has already captured inter-class information, introducing intra-class information still improves performance on tail labels.
Fig. 2. Performance on tail labels.

Table 4. A case study on the RCV1 dataset. We show the top-5 predictions of BERT and AIIF; correct predictions are colored red.

Input Document: French pension at 55 ups cost 117 bln Reducing the french retirement age to 55 from 60 could cost as much as 117 billion francs over 15 to 20 years, an unpublished study by the pensions branch of the social security system says, according to Daily Les Echos. The newspaper said a preliminary study by the CNAV branch of the social security system believed retirement across the board at 55 would add 28 million people to the pensioner population, raising existing claimers 31 percent from 91 million. The CNAV estimated that the extra burden on the basic social security retirement benefit system in extra payouts would be in the region of 100 billion francs., ...
Ground-truth labels: economics; government/social; expenditure/revenue; government finance; welfare, social services
Top 5 predictions of BERT: government/social; health; domestic politics; expenditure/revenue; welfare, social services
Top 5 predictions of AIIF: government/social; welfare, social services; expenditure/revenue; economics; government finance

4.7 Case Study
Further, Table 4 compares the predictions of AIIF and BERT to analyze why AIIF is superior to BERT. For the example shown in Table 4, BERT correctly predicts the categories “government/social,” “expenditure/revenue,” and “welfare, social services,” but misses the categories “economics” and “government finance”. AIIF predicts all the correct categories. We believe this can be attributed to AIIF extracting the relationships between categories through the category graph. In the training set, the categories “economics” and “government finance” have
a strong co-occurrence with the three categories correctly predicted by BERT. Correspondingly, they are connected by edges in the category graph. Extracting such inter-class information increases the probability of predicting “economics” and “government finance”. Conversely, the two categories that BERT incorrectly predicted, “health” and “domestic politics”, have weaker co-occurrence with the three categories that BERT correctly predicted. Hence, AIIF excludes these two incorrect predictions.
5 Conclusion
This paper studies the multi-label text classification task. Previous methods focus on either intra-class or inter-class information. We propose a novel two-branch architecture that combines both kinds of information. Experimental results show that a model capturing both intra-class and inter-class information outperforms models that capture only one of them.
References
1. Adhikari, A., Ram, A., Tang, R., Lin, J.: DocBERT: BERT for document classification. arXiv preprint arXiv:1904.08398 (2019)
2. Chalkidis, I., Fergadiotis, M., Kotitsas, S., Malakasiotis, P., Aletras, N., Androutsopoulos, I.: An empirical study on large-scale multi-label text classification including few and zero-shot labels. In: Webber, B., Cohn, T., He, Y., Liu, Y. (eds.) Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing, EMNLP 2020, 16–20 November 2020, pp. 7503–7515. Association for Computational Linguistics (2020). https://doi.org/10.18653/v1/2020.emnlp-main.607
3. Chang, W.C., Yu, H.F., Zhong, K., Yang, Y., Dhillon, I.: X-BERT: extreme multi-label text classification using bidirectional encoder representations from transformers. In: Proceedings of NeurIPS Science Meets Engineering of Deep Learning Workshop (2019)
4. Chen, Z.M., Wei, X.S., Wang, P., Guo, Y.: Multi-label image recognition with graph convolutional networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5177–5186 (2019)
5. Devlin, J., Chang, M.W., Lee, K., Toutanova, K.: BERT: pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805 (2018)
6. Du, C., Chen, Z., Feng, F., Zhu, L., Gan, T., Nie, L.: Explicit interaction model towards text classification. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, pp. 6359–6366 (2019)
7. Fürnkranz, J., Hüllermeier, E., Mencía, E.L., Brinker, K.: Multilabel classification via calibrated label ranking. Mach. Learn. 73(2), 133–153 (2008)
8. Gopal, S., Yang, Y.: Multilabel classification with meta-level features. In: Proceedings of the 33rd International ACM SIGIR Conference on Research and Development in Information Retrieval, pp. 315–322 (2010)
9. Kipf, T.N., Welling, M.: Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907 (2016)
10. Lai, S., Xu, L., Liu, K., Zhao, J.: Recurrent convolutional neural networks for text classification. In: Twenty-Ninth AAAI Conference on Artificial Intelligence (2015)
11. Lewis, D.D., Yang, Y., Rose, T.G., Li, F.: RCV1: a new benchmark collection for text categorization research. J. Mach. Learn. Res. 5, 361–397 (2004)
12. Li, Q., Han, Z., Wu, X.M.: Deeper insights into graph convolutional networks for semi-supervised learning. In: Thirty-Second AAAI Conference on Artificial Intelligence (2018)
13. Liu, J., Chang, W.C., Wu, Y., Yang, Y.: Deep learning for extreme multi-label text classification. In: Proceedings of the 40th International ACM SIGIR Conference on Research and Development in Information Retrieval, pp. 115–124 (2017)
14. Liu, P., Qiu, X., Huang, X.: Recurrent neural network for text classification with multi-task learning. arXiv preprint arXiv:1605.05101 (2016)
15. Lu, J., Du, L., Liu, M., Dipnall, J.: Multi-label few/zero-shot learning with knowledge aggregated from multiple label graphs. In: Webber, B., Cohn, T., He, Y., Liu, Y. (eds.) Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing, EMNLP 2020, 16–20 November 2020, pp. 2935–2943. Association for Computational Linguistics (2020). https://doi.org/10.18653/v1/2020.emnlp-main.235
16. Ma, Q., Yuan, C., Zhou, W., Hu, S.: Label-specific dual graph neural network for multi-label text classification. In: Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers), pp. 3855–3864 (2021)
17. Peng, H., Li, J., He, Y., Liu, Y., Bao, M., Wang, L., Song, Y., Yang, Q.: Large-scale hierarchical text classification with recursively regularized deep graph-CNN. In: Proceedings of the 2018 World Wide Web Conference, pp. 1063–1072 (2018)
18. Ren, Z., Peetz, M.H., Liang, S., Van Dolen, W., De Rijke, M.: Hierarchical multi-label classification of social text streams. In: Proceedings of the 37th International ACM SIGIR Conference on Research & Development in Information Retrieval, pp. 213–222 (2014)
19. Rios, A., Kavuluru, R.: Few-shot and zero-shot multi-label learning for structured label spaces. In: Riloff, E., Chiang, D., Hockenmaier, J., Tsujii, J. (eds.) Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, Brussels, Belgium, October 31 – November 4, 2018, pp. 3132–3142. Association for Computational Linguistics (2018). https://doi.org/10.18653/v1/d18-1352
20. Xiao, L., Huang, X., Chen, B., Jing, L.: Label-specific document representation for multi-label text classification. In: Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pp. 466–475 (2019)
21. Yao, L., Mao, C., Luo, Y.: Graph convolutional networks for text classification. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, pp. 7370–7377 (2019)
22. You, R., Zhang, Z., Wang, Z., Dai, S., Mamitsuka, H., Zhu, S.: AttentionXML: label tree-based attention-aware deep model for high-performance extreme multi-label text classification. arXiv preprint arXiv:1811.01727 (2018)
23. Zhang, W., Yan, J., Wang, X., Zha, H.: Deep extreme multi-label learning. In: Proceedings of the 2018 ACM on International Conference on Multimedia Retrieval, pp. 100–107 (2018)
24. Zhou, J., et al.: Hierarchy-aware global model for hierarchical text classification. In: Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pp. 1106–1117 (2020)
Fast Estimation of Multidimensional Regression Functions by the Parzen Kernel-Based Method
Tomasz Galkowski¹ (corresponding author) and Adam Krzyżak²,³
¹ Institute of Computational Intelligence, Częstochowa University of Technology, Armii Krajowej 36, Częstochowa, Poland
² Department of Computer Science and Software Engineering, Concordia University, Quebec, Montreal H3G 1M8, Canada
³ Department of Electrical Engineering, Westpomeranian University of Technology, Sikorskiego 37, 70-313 Szczecin, Poland

Abstract. Various methods for estimating unknown functions from a set of noisy measurements are applicable to a wide variety of problems. Among them, nonparametric algorithms based on the Parzen kernel are commonly used. Our method is developed primarily for the multidimensional case; its two-dimensional version is explained and analysed in detail. The proposed algorithm is an effective and efficient solution that significantly improves computational speed. The computational complexity and the speed of convergence of the algorithm are also studied, and some applications of our algorithms to real problems are presented. Our approach is applicable to multidimensional regression function estimation as well as to the estimation of derivatives of functions. It is worth noting that the presented algorithms have already been used successfully in various image processing applications, achieving significant accelerations of calculations.

Keywords: Parzen kernel algorithms · nonparametric regression estimation · multidimensional functions
1 Introduction
The purpose of modelling is to trace the rules of functioning and organization of the system. In a wide spectrum of problems, the fundamental task is to obtain a reliable mathematical description, or simply a mathematical model, of the process or object. A model of a technical object is based on knowledge of the physical laws and usually consists of a system of nonlinear differential or integral equations. Its structural complexity is sometimes very high, depending, of course, on the system itself. The aim of identification is to design a model of the system that is able to reconstruct or recover the state of the object at any point.

Supported by the program of the Polish Minister of Science and Higher Education under the name “Regional Initiative of Excellence” in the years 2019–2023, project number 020/RID/2018/19, amount of financing 12,000,000 PLN. The work of the second author was performed at Westpomeranian University of Technology, while on sabbatical leave from Concordia University.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 251–262, 2023. https://doi.org/10.1007/978-981-99-1639-9_21

A commonly used methodology is to analyse phenomena by modelling them with probability density functions (in continuous d-dimensional spaces) and distributions (in the case of discrete series of random numbers) representing the multidimensional processes. Regression analysis is a powerful statistical tool for building a functional relationship between the input and output data in a model. Generally, the inputs are multidimensional vectors of random variables and the output is a scalar function corrupted by random noise (see model (1)). This approach was first used by Galton and Pearson more than one hundred years ago, cf. Stanton 2001 [33]. They assumed a linear regression function. Linear regression models were later generalized to various parametric regression models, such as the Box-Cox regression model [3]. Regression analysis builds a functional relationship between a response variable (output) and an explanatory variable (input, or, in the multivariate case, an explanatory variable vector). The problem of object or system identification is more difficult when we are not able to formulate its mathematical description. When a parameter of the process is estimated, the so-called parametric approach is applied. The nonparametric methodology is used when no assumptions are made about the functional form of the regression function or the data distributions; then $R(\cdot)$ in (1) is a completely unknown function. Sometimes the lack of a priori information makes the identification task impossible. Moreover, the presence of random noise in the observations makes this task more complex.
In the 1950s and 1960s, parametric regression models were further extended to newly developed nonparametric models; see Nadaraya (1964) [25], Watson (1964) [37], Parzen (1962) [28], and Rosenblatt (1969) [30]. Our paper uses the Parzen–Rosenblatt approach.
2 Efficient Nonparametric Density Estimation Algorithms
Computational speed is an important concern when designing algorithms for data analysis. Below, we briefly review this aspect of the nonparametric approach. When both the regressor variable and the dependent variable are random observations (the so-called random-design case), the commonly used algorithm is the Nadaraya–Watson algorithm, which has been extensively studied in the literature [25,37]. For instance, the WARP technique (see, e.g., [21]) is an efficient way of computing kernel regression estimates. In popular data-smoothing tools, computing a kernel estimate at a single point requires evaluating and summing the multivariate kernel functions associated with all the individual data points (see, e.g., Scott (1992) [31], Wand and Jones (1995) [35]). In 2000, Holmström in
[22] proposed the binning-of-the-data method to improve the accuracy and reduce the computational complexity of a multivariate density estimator. Holmström's main idea is to discretize the data onto a grid, i.e., to bin the data first, and then to compute a weighted kernel estimator. The author noticed that the estimation error depends on the bin side length (denoted by $\delta_n$) and proved that sufficient consistency conditions for the binned estimator are $\delta_n \to 0$, $h_n \to 0$, and $\delta_n / h_n \to 0$ as $n \to \infty$. However, he then set $\delta_n = \alpha \cdot h_n$ with $\alpha > 0$, which contradicts the previous consistency conditions. Fast evaluation of kernel density estimates based on the Fast Fourier Transform (FFT) was proposed and investigated in, e.g., the works of Fan and Marron (1994) [6], Härdle and Scott (1992) [21], Silverman (1982) [32], and Wand (1994) [36]. Raykar et al. (2010) [29] proposed a novel, computationally efficient approximation algorithm for estimating the derivative of a density function by means of the univariate Gaussian kernel-based density estimate; it reduces the computational complexity from $O(n \cdot m)$ to the linear $O(n + m)$ for $n$ data points and $m$ evaluation points. The algorithm is based on the Taylor series expansion of the Gaussian kernel, in which only the first few terms of the expansion are used, so that the truncation error is below the chosen accuracy. The technique was inspired by the fast multipole methods used in computational physics; see Greengard (1994) [17]. Another technique is based on the Fast Gauss Transform (FGT), introduced by Greengard and Strain (1991) [18]; the Improved Fast Gauss Transform (IFGT) of Yang et al. (2003) [38] can be thought of as its extension. Antoniadis et al. [2] proposed a wavelet estimator based on wavelets on the interval, since the problem is confined to an interval. Wavelets and multiresolution analyses of $L^2([0, 1])$ were introduced and explicitly constructed by Cohen et al. (1993) [4].
There are well-known implementations based on GPU hardware and software that do improve computational efficiency, but it is hard to claim that a new microelectronics technology improves the mathematical algorithm itself. We may quote, for instance, the works of Andrzejewski et al. [1] and Gramacki [16], where NVIDIA GPUs and the CUDA API [26,27] were used with highly parallel, multi-threaded implementations. In the next sections, we do not use the binning method; instead, we use the grid that is part of our algorithm. We propose a more suitable way of choosing the size of the elementary sub-region of the hyper-grid in the multidimensional case.
3 A Synopsis of the Kernel Methods in Regression Estimation
In this work, we investigate the model
$$y_i = R(x_i) + \varepsilon_i, \quad i = 1, \dots, n, \qquad (1)$$
where the $x_i \in \mathbb{R}^d$ are assumed to be d-dimensional deterministic input vectors, $y_i$ is the scalar random output, and $\varepsilon_i$ is measurement noise with zero mean and bounded variance. $R(\cdot)$ is a completely unknown function, in the sense that its mathematical form is not known. Note that this is the so-called fixed-design regression problem; see, e.g., [5]. We start with the estimator $\hat{R}_n(x)$ of the function $R(\cdot)$ at point $x$ based on the set of measurements $y_i$, $i = 1, \dots, n$. We use the Parzen kernel-based algorithm of the integral type:
$$\hat{R}_n(x) = a_n^{-d} \sum_{i=1}^{n} y_i \int_{D_i} K\left(\frac{\|x - u\|}{a_n}\right) du, \qquad (2)$$
where $\|x - u\|$ denotes a norm, i.e., the distance between points $x$ and $u$ in the d-dimensional space, and the sets $D_i$ are defined below. The factor $a_n$, which depends on the number of observations $n$, is called the smoothing factor. Let us mention that in the nonparametric approach we impose no constraints on either the shape of the unknown function or on any mathematical formula dependent on a certain set of parameters to be found. The domain $D$ is partitioned into $n$ disjoint nonempty subsets $D_i$, and the measurement points $x_i$ are chosen from $D_i$, i.e., $x_i \in D_i$. The input values $x_i$ in model (1) are chosen in the process of collecting data, for instance, stock exchange information, equidistant samples of an ECG signal in the time domain, or internet activity on several TCP/IP ports recorded over time in server logs. These data points should provide a balanced representation of the function $R$ in the domain $D$. The standard assumption in theorems on the convergence of (2) is that the maximum diameter of the sets $D_i$ tends to zero as $n$ tends to infinity (see, e.g., [13–15]). In the one-dimensional case, the kernel function $K(\cdot)$ satisfies the following conditions:
$$K(t) = 0 \ \text{for } t \notin (-\tau, \tau),\ \tau > 0; \qquad \int_{-\tau}^{\tau} K(t)\, dt = 1; \qquad \sup_t |K(t)| < \infty. \qquad (3)$$
Note that the adopted kernel has compact support. In the multidimensional case, we use the product kernel
$$K(x, u, a_n) = \prod_{p=1}^{d} K\left(\frac{|x_p - u_p|}{a_n}\right) = K\left(\frac{\|x - u\|}{a_n}\right), \qquad (4)$$
where the last expression is used as shorthand for the product.
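As a concrete instance of conditions (3) and the product kernel (4), the sketch below uses the Epanechnikov kernel with τ = 1; no particular kernel is fixed above, so this choice is an assumption. The check confirms numerically that the kernel integrates to one and that the product kernel vanishes as soon as one coordinate leaves the support.

```python
import numpy as np


def epanechnikov(t):
    """Kernel satisfying Eq. (3) with tau = 1: K(t) = 0.75 (1 - t^2) on (-1, 1),
    K(t) = 0 elsewhere; it integrates to 1 and is bounded by 0.75."""
    t = np.asarray(t, dtype=float)
    return np.where(np.abs(t) < 1.0, 0.75 * (1.0 - t**2), 0.0)


def product_kernel(x, u, a_n, K=epanechnikov):
    """Eq. (4): prod_{p=1}^{d} K(|x_p - u_p| / a_n)."""
    x = np.asarray(x, dtype=float)
    u = np.asarray(u, dtype=float)
    return float(np.prod(K(np.abs(x - u) / a_n)))


# Midpoint-rule check of the normalization condition in Eq. (3).
N = 100_000
t_mid = -1.0 + (np.arange(N) + 0.5) * (2.0 / N)
integral = float(epanechnikov(t_mid).sum() * (2.0 / N))
print(integral)                                          # close to 1.0

print(product_kernel([0.5, 0.5], [0.5, 0.5], a_n=0.2))   # 0.75 * 0.75 = 0.5625
print(product_kernel([0.5, 0.5], [0.9, 0.5], a_n=0.2))   # first factor outside support -> 0.0
```

Because one factor is exactly zero outside the support, any cell farther than $a_n$ from $x$ along some axis contributes nothing to the sum in (2).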
As an illustration, let us introduce a two-dimensional ($d = 2$) model. Assume that the design space is $D = [0, 1] \times [0, 1]$ and, for the sake of clarity, let $n$ be a perfect square, so that $m = \sqrt{n}$ is an integer. Then a typical two-dimensional nonparametric regression model can be written as
$$y_{ij} = R(x_{1i}, x_{2j}) + \varepsilon_{ij}, \quad i = 1, \dots, m, \quad j = 1, \dots, m. \qquad (5)$$
The Parzen kernel estimator in this case is defined as
$$\hat{R}_n(x_1, x_2) = a_n^{-2} \sum_{i=1}^{m} \sum_{j=1}^{m} y_{ij} \iint_{D_{ij}} K\left(\frac{x_1 - u_1}{a_n}\right) K\left(\frac{x_2 - u_2}{a_n}\right) du_1\, du_2, \qquad (6)$$
where $K(\cdot)$ satisfies (3) and the product kernel (4) is used with $d = 2$. Let us stress that we do not investigate the density estimation problem but the regression estimation problem in the so-called fixed-design case, where the variable $x_i$ is controlled by the experimenter. The computational complexity of this algorithm grows linearly with the number of observations $n$: estimation at a single point $x$ requires $O(n)$ kernel evaluations. Our main goal is to apply our algorithm in multidimensional spaces. Unfortunately, high-dimensional problems often suffer from the so-called “curse of dimensionality”; consequently, dimension reduction techniques, e.g., projections, are often used. We show in this article that the number of kernel evaluations in the newly proposed algorithm is significantly reduced, also in the multidimensional case, compared with the original Parzen–Rosenblatt method using the integral-type kernel calculation (2). An important feature of our algorithm is the use of a kernel with compact support; a kernel satisfying (3) is compactly supported. It follows that a large number of the summands in (2) are equal to zero, so they can be omitted from the computation.
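A direct implementation of estimator (6) on the uniform grid illustrates the $O(n)$ cost per point. In this sketch, the following are assumptions rather than choices stated in the text: the Epanechnikov kernel, design points at cell centres, cells $D_{ij} = [(i-1)/m, i/m] \times [(j-1)/m, j/m]$, and a hypothetical test function `R`. The cell integrals are computed in closed form through the kernel's CDF; the factor $a_n^{-2}$ cancels against the $a_n$ produced by each one-dimensional substitution.

```python
import numpy as np


def epan_cdf(t):
    """CDF of the Epanechnikov kernel K(t) = 0.75 (1 - t^2) on (-1, 1)."""
    t = np.clip(np.asarray(t, dtype=float), -1.0, 1.0)
    return 0.5 + 0.75 * (t - t**3 / 3.0)


def axis_weights(x, m, a_n):
    """a_n^{-1} * integral of K((x - u) / a_n) du over each cell [(i-1)/m, i/m].

    Substituting t = (x - u) / a_n gives F((x - lo) / a_n) - F((x - hi) / a_n).
    """
    edges = np.arange(m + 1) / m
    lo, hi = edges[:-1], edges[1:]
    return epan_cdf((x - lo) / a_n) - epan_cdf((x - hi) / a_n)


def parzen_2d_full(x1, x2, Y, a_n):
    """Eq. (6) summed over all n = m^2 cells: O(n) work per estimation point.

    Y[i, j] holds y_ij observed at (x_1i, x_2j); the product kernel makes the
    cell integral factorize into one weight per axis.
    """
    m = Y.shape[0]
    w1 = axis_weights(x1, m, a_n)
    w2 = axis_weights(x2, m, a_n)
    return float(w1 @ Y @ w2)            # sum_ij y_ij * w1_i * w2_j


def R(s, t):
    # Hypothetical test function, used only for this illustration.
    return np.sin(2 * np.pi * s) * np.cos(2 * np.pi * t)


rng = np.random.default_rng(1)
m = 100                                  # n = m^2 = 10,000 observations
centres = (np.arange(m) + 0.5) / m       # design points at cell centres
Y = R(centres[:, None], centres[None, :]) + 0.1 * rng.normal(size=(m, m))

estimate = parzen_2d_full(0.3, 0.6, Y, a_n=0.05)
print(estimate, R(0.3, 0.6))             # estimate vs. true value
```

Most of the $m^2$ weights products here are exactly zero; the thrifty formula of the next section exploits exactly this.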
4 The Novel Algorithm Basics
Let $\sqrt[d]{n} = m$, where $m$ is an integer. For the sake of clarity, we use a uniform grid in the multidimensional space. Let the space be $D = [0, 1]^d$. It is partitioned into hypercubes $D_{\mathbf{i}}$ by dividing the unit interval $[0, 1]$ on the k-th axis into $m$ subintervals $\Delta x_{i_k}$, so that
$$\Delta x_{i_1} \times \Delta x_{i_2} \times \cdots \times \Delta x_{i_k} \times \cdots \times \Delta x_{i_d} = D_{i_1 \cdots i_d} = D_{\mathbf{i}}, \quad k = 1, \dots, d. \qquad (7)$$
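Let us introduce the integer $r = a_n \cdot m$. We start by locating the sub-cube $D_{\mathbf{i}^*}$ containing $x$, i.e., $x \in D_{\mathbf{i}^*}$, where $\mathbf{i}^* = [i^*_1, i^*_2, \dots, i^*_d]$. The estimator can now be calculated by the following thrifty formula:
$$\hat{R}_n(x) = a_n^{-d} \sum_{\mathbf{i} \in A^*} y_{\mathbf{i}} \int_{D_{\mathbf{i}}} K\left(\frac{\|x - u\|}{a_n}\right) du, \qquad (8)$$
where the reduced region of summation $A^*$ is defined by
$$A^* = \left\{ \mathbf{i} : i_k \in \{ i^*_k - r,\ i^*_k - r + 1,\ \dots,\ i^*_k,\ i^*_k + 1,\ \dots,\ i^*_k + r - 1,\ i^*_k + r \},\ k = 1, \dots, d \right\}. \qquad (9)$$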
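The thrifty formula (8) with the reduced index set (9) can be sketched as follows, again with the Epanechnikov kernel (τ = 1) as an assumed example. Because the kernel has compact support, summing only over the $(2r + 1)^2$ cells around the cell containing $x$ gives the same value as the full sum. Two further assumptions: $r$ is rounded up to $\lceil a_n m \rceil$ so that the window still covers the support when $a_n m$ is not an integer, and the window is clipped at the domain boundary.

```python
import math

import numpy as np


def epan_cdf(t):
    """CDF of the Epanechnikov kernel K(t) = 0.75 (1 - t^2) on (-1, 1)."""
    t = np.clip(np.asarray(t, dtype=float), -1.0, 1.0)
    return 0.5 + 0.75 * (t - t**3 / 3.0)


def axis_weights(x, idx, m, a_n):
    """a_n^{-1} * integral of K((x - u) / a_n) du over the cells idx (0-based)."""
    lo, hi = idx / m, (idx + 1) / m
    return epan_cdf((x - lo) / a_n) - epan_cdf((x - hi) / a_n)


def parzen_2d_full(x1, x2, Y, a_n):
    """Reference: Eq. (6) over all n = m^2 cells."""
    m = Y.shape[0]
    idx = np.arange(m)
    return float(axis_weights(x1, idx, m, a_n) @ Y @ axis_weights(x2, idx, m, a_n))


def parzen_2d_thrifty(x1, x2, Y, a_n):
    """Eqs. (8)-(9): sum only over the cells in A* around the cell of x."""
    m = Y.shape[0]
    r = math.ceil(a_n * m - 1e-9)        # r = a_n * m, rounded up (assumption)
    i1 = min(int(x1 * m), m - 1)         # i*_1: cell containing x1 (0-based)
    i2 = min(int(x2 * m), m - 1)         # i*_2: cell containing x2
    idx1 = np.arange(max(i1 - r, 0), min(i1 + r, m - 1) + 1)  # A*, axis 1
    idx2 = np.arange(max(i2 - r, 0), min(i2 + r, m - 1) + 1)  # A*, axis 2
    w1 = axis_weights(x1, idx1, m, a_n)
    w2 = axis_weights(x2, idx2, m, a_n)
    value = float(w1 @ Y[np.ix_(idx1, idx2)] @ w2)
    return value, idx1.size * idx2.size  # estimate and number of cells used


rng = np.random.default_rng(2)
m = 200                                  # n = 40,000 observations
Y = rng.normal(size=(m, m))
thrifty, cells = parzen_2d_thrifty(0.4, 0.7, Y, a_n=0.02)
full = parzen_2d_full(0.4, 0.7, Y, a_n=0.02)
print(thrifty, full, cells, m * m)       # same value, 81 of 40,000 cells
```

With $a_n = 0.02$ and $m = 200$ we get $r = 4$, so only $9 \times 9 = 81$ cells are visited, consistent with the count $(2r + 1)^d$.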
Fig. 1. Construction of region of calculation
Figure 1 shows an example of partitioning the two-dimensional domain $D$ (the space on which the function $R$ is defined) into $n$ disjoint nonempty squares $D_{\mathbf{i}}$, together with the construction of the reduced region around the point $x$ used in the estimation algorithm (8). Our algorithm should not be applied in the boundary region of the domain $D$ because of the well-known boundary problem, which is not studied in this work. Suggestions on how to improve the convergence of kernel-based algorithms in the boundary region can be found in, e.g., [7,12,24,39].
5 The Scaling Factor $a_n$ Selection and Its Impact on the Computational Speed
The choice of the parameter $a_n$ (also called the smoothing parameter, scaling factor or bandwidth) is very important, see, e.g., [20,23,34]. When we estimate functions in the presence of noise, too small a value of $a_n$ generally makes the estimated function insufficiently smooth, i.e., it may have sharp jumps or irregularities. On the other hand, too large a value of $a_n$ may cause the loss of possibly important features of the function, giving an oversmoothed shape. So the parameter $a_n$ should be carefully selected, and ideally the choice should be data dependent. The second important concern is computational speed. It is clearly seen in Fig. 1, and it follows from definitions (8) and (9), that the number of cells $D_i$ taken into the calculation procedure strictly depends on $a_n$. Recalling that $\sqrt[d]{n} = m$ and $r = h_n \cdot m$, the number $\lambda(n)$ of sub-regions $D_i$ taken into the calculations in the
Fast Estimation of Multidimensional Regression Functions
d-dimensional case is equal to:
$$\lambda(n) = (2r+1)^d = (2 h_n \cdot m + 1)^d \cong 2^d n a_n^d = O\left(2^d n a_n^d\right) \qquad (10)$$
In fact, we can conclude that (10) describes the computational rate of the newly proposed algorithm: the number of evaluations in (8) is of order $O(2^d n a_n^d)$. Next, let us choose the smoothing parameter $a_n$ of exponential type, $a_n = m^{-\alpha} = n^{-\alpha/d}$,
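As a rough numerical check of (10), the sketch below evaluates $\lambda(n) = (2r+1)^d$ against its approximation $2^d n a_n^d$ for d = 2 and $a_n = n^{-\alpha/d}$ with α = 0.5 (an arbitrary illustrative value). Rounding $r = a_n m$ to the nearest integer is an assumption made here for the case where $a_n m$ is not already integral.

```python
def cells_used(n, d, alpha):
    """Return (lambda(n), 2^d * n * a_n^d) for a_n = m^(-alpha), m = n^(1/d), r = a_n * m."""
    m = round(n ** (1.0 / d))      # cells per axis, assumed to be an integer
    a_n = m ** (-alpha)            # smoothing parameter a_n = m^(-alpha) = n^(-alpha/d)
    r = round(a_n * m)             # half-width of the reduced region, rounded to an integer
    return (2 * r + 1) ** d, 2 ** d * n * a_n ** d


for n in (10 ** 4, 10 ** 6):
    exact, approx = cells_used(n, d=2, alpha=0.5)
    print(n, exact, round(approx))
# 10000 441 400
# 1000000 4225 4000
```

With these settings only about 4.4% and 0.42% of the n cells enter the sum, and the fraction shrinks as n grows, since $\lambda(n)/n \approx 2^d a_n^d$.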
0 0. We choose to model the covariance structure of this GP using an SQE kernel with a fixed, positive length-scale hyperparameter. We compute the empirical mean $\bar{w}$ of the output values $f(x_1) = w_1, \ldots, f(x_M) = w_M$ in the training data Dtrain and use it as each component of μ. Since the joint distribution of $f(x_1), \ldots, f(x_M)$ is Multivariate Normal with mean μ and covariance matrix $\Sigma = [K(x_i, x_j)]$, the joint distribution of $w_1, \ldots, w_M$ is this Multivariate Normal. But the joint probability of these M values of the output W in the training data is the probability of the data on W. In fact, it is a conditional probability: conditional on the parameters of the mean and covariance, i.e., on the length scale. The probability of the data conditional on the model parameters is the likelihood. Thus, the likelihood is Multivariate Normal with mean μ and covariance
Efficient Uncertainty Quantification for Under-Constraint Prediction
Fig. 2. Results of learning using training data Dmean. Here the 2-stepped prior is imposed on the length scale. From left: (a,b) traces of the learned length-scale values and of the logarithm of its posterior, respectively; (c,d) histograms of the marginal of the length scale and of the posterior values, respectively; (e) 3-window convergence test (see caption of Fig. 1), confirming convergence of the MCMC chain.
matrix Σ that is kernel-parameterised as above with the length-scale hyperparameter, i.e., the likelihood of the length scale given Dtrain is
$$\frac{1}{\sqrt{(2\pi)^M |\Sigma|}} \exp\left(-\frac{1}{2}(\mathbf{w} - \boldsymbol{\mu})^T \Sigma^{-1} (\mathbf{w} - \boldsymbol{\mu})\right),$$
where $\mathbf{w} = (w_1, \ldots, w_M)^T$ and Σ is parameterised by the SQE kernel with this length scale. We impose a Gaussian prior on the length scale; the unscaled posterior of the length scale given the training data is then this likelihood times the Gaussian prior. We perform learning with Metropolis-Hastings, in which the length scale is proposed from a truncated Normal density, using the training data Dmean, which is generated using the means of the sampled α and β learnt using D. We update the length scale in each iteration of the chain and subsequently predict the mean and variance of the output W at test values of the input. However, these predictions are not guaranteed to remain in [0,1]. To remedy this, we impose the constraint that we learn only those length scales that predict outputs lying in [0,1] at any test input. We accomplish this through a prior on the length scale that takes the value $\rho_{hi} > 0$ if the predictions on an arbitrarily chosen set of test temperatures are within [0,1], and the value $\rho_{lo}$ otherwise. In the t-th iteration of the MCMC chain, after proposing a length scale from a Normal density with mean equal to its value in the (t−1)-th iteration and variance $0.03^2$, truncated below at 0, a 2-stepped prior is invoked:
$$f_{stepPrior}(\text{prop}) = \begin{cases} \rho_{hi}\, \pi_{0,Normal}, & \text{if } W \text{ predicted using the proposed length scale is in } [0,1] \\ \rho_{lo}\, \pi_{0,Normal}, & \text{otherwise,} \end{cases} \qquad (2)$$
where $\pi_{0,Normal} = \mathcal{N}(0.9, 0.08^2)$ is the prior placed on the length scale. We choose $\rho_{hi} = 1$ and $\rho_{lo} \le e^{-10^3}$. The MCMC chain that we run comprises 100000 iterations, and the 95% HPD learnt on the length scale using Dmean is [0.79509924, 1.10960904], with a mean of approximately 0.9523. Figure 2 shows the results obtained by running this chain. Figure 5a shows the GP prediction for this set-up using a length scale of ≈ 0.9523.
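A minimal Python sketch of this likelihood, the 2-stepped prior (2) and one Metropolis-Hastings update of the length scale follows. It is not the authors' code: the unit kernel amplitude, the small diagonal jitter `noise=1e-4`, checking only the predictive means at the chosen test inputs, and $\log \rho_{lo} = -1000$ (i.e. $\rho_{lo} = e^{-10^3}$) are assumptions, and all function names are hypothetical. The proposal is truncated below at 0, so the acceptance ratio carries the corresponding Hastings correction.

```python
import math

import numpy as np


def sqe_kernel(xa, xb, length_scale):
    """SQE kernel matrix K[i, j] = exp(-(xa_i - xb_j)^2 / (2 length_scale^2)); unit amplitude assumed."""
    diff = xa[:, None] - xb[None, :]
    return np.exp(-0.5 * (diff / length_scale) ** 2)


def gp_log_likelihood(length_scale, x, w, noise=1e-4):
    """log N(w | mu, Sigma) with every component of mu equal to mean(w) and Sigma = K + noise * I."""
    M = len(x)
    L = np.linalg.cholesky(sqe_kernel(x, x, length_scale) + noise * np.eye(M))
    z = np.linalg.solve(L, w - w.mean())  # z^T z = (w - mu)^T Sigma^{-1} (w - mu)
    return -0.5 * z @ z - np.log(np.diag(L)).sum() - 0.5 * M * math.log(2 * math.pi)


def gp_mean(length_scale, x, w, x_test, noise=1e-4):
    """Closed-form GP predictive mean at x_test."""
    K = sqe_kernel(x, x, length_scale) + noise * np.eye(len(x))
    return w.mean() + sqe_kernel(x_test, x, length_scale) @ np.linalg.solve(K, w - w.mean())


def log_step_prior(length_scale, x, w, x_test, log_rho_lo=-1e3, prior_mean=0.9, prior_sd=0.08):
    """log of the 2-stepped prior (2): rho_hi = 1 if all predicted means lie in [0, 1], else rho_lo."""
    log_normal = (-0.5 * ((length_scale - prior_mean) / prior_sd) ** 2
                  - math.log(prior_sd * math.sqrt(2 * math.pi)))
    pred = gp_mean(length_scale, x, w, x_test)
    inside = bool(np.all((pred >= 0.0) & (pred <= 1.0)))
    return (0.0 if inside else log_rho_lo) + log_normal


def mh_step(current, x, w, x_test, rng, step=0.03):
    """One Metropolis-Hastings update of the length scale, proposal N(current, step^2) truncated at 0."""
    proposal = rng.normal(current, step)
    while proposal <= 0.0:  # rejection sampling from the truncated normal
        proposal = rng.normal(current, step)

    def log_norm_cdf(v):
        # log Phi(v / step): probability mass that N(v, step^2) puts above 0
        return math.log(0.5 * (1.0 + math.erf(v / (step * math.sqrt(2.0)))))

    q = log_norm_cdf(current) - log_norm_cdf(proposal)  # Hastings correction for the truncation

    def log_post(ls):
        return gp_log_likelihood(ls, x, w) + log_step_prior(ls, x, w, x_test)

    if math.log(rng.uniform()) < log_post(proposal) - log_post(current) + q:
        return proposal
    return current


x = np.linspace(-1.5, 1.5, 5)               # illustrative standardised design inputs
w = np.array([0.95, 0.8, 0.5, 0.2, 0.05])   # illustrative training outputs in [0, 1]
x_test = np.linspace(-1.5, 1.5, 9)
rng = np.random.default_rng(0)
ls = 0.9
for _ in range(200):
    ls = mh_step(ls, x, w, x_test, rng)
```

Working on the log scale keeps $\rho_{lo} = e^{-10^3}$ representable; in linear scale it would underflow to zero.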
G. Roy and D. Chakrabarty
3.2 DNN Results
We have also implemented Deep Neural Network (DNN) systems with varying architectures to predict the output W ≡ Pr(Y = 1|X = xtest) at test temperatures. Here the DNN architecture varies in the number of hidden layers (one to four), with 64 neurons in each layer. We make these predictions using TensorFlow, following the code available in the official TensorFlow documentation/tutorial [14]. To allow a fair comparison, we perform predictions at the same test data as were used when output predictions were made following the learning of the length scale using MCMC, as in Subsect. 3.1. We perform these DNN-based predictions using the Adam optimiser with a learning rate of 0.001, the mean absolute error loss function, and 500 epochs. All four predictive plots from the DNNs are shown in Fig. 3. The comparative results of DNN and GP are presented in Table 2; note that the DNN results are sensitive to the choice of architecture, even for this simplistic data set. Such sensitivity is cause for worry when a reliable and robust ML implementation is sought. For more complex real-world data sets with higher noise, such sensitivity is likely to be more pronounced. Also, the DNNs do not produce (probability) predictions within the strict range of 0 and 1, while predictions following GP-based learning could be successfully constrained to lie in this interval.
Fig. 3. DNN predictions with varying hidden layers, from left, DNN with 1, 2, 3 and 4 hidden layers with 64 neurons in each layer.
Fig. 4. Results as in Fig. 2, except obtained using Dright , from left: (a) to (e).
3.3 Learning Under Constraint, with MCMC, Using Dright
Instead of using the training data Dmean, one could potentially use any values of α and β from within the 95% HPDs learnt on these parameters in the chain run with the test flight data D, to compute Pr(Y = 1|X = xi), where xi is the i-th design temperature, i = 1, . . . , M. The resulting training data will not be Dmean, since the output at any design temperature in Dmean is computed at
the mean of the α and β samples generated in the MCMC chain run with the test flight data D. As stated above, when we populate a training set by computing Pr(Y = 1|X = xi), for i = 1, . . . , M, using α = 15.19 and β = −0.216 from the rightmost edge of their respective 95% HPDs learnt in this MCMC chain, we get the training set that we refer to as Dright. Results of the GP-based learning of the length scale undertaken with training data Dright are presented in Fig. 4. From this chain, we learn the 95% HPD on the length scale to be [0.7999134, 1.11657139], with a mean of approximately 0.9582. Figure 5b shows the GP prediction for this set-up using a length scale of 0.9582.

Table 2. Prediction of output W at test inputs with DNNs of varying architectures, and following GP-based learning of the hyperparameter (length scale) of the covariance kernel that parameterises the GP covariance structure. HL: number of Hidden Layers in the DNN, N: number of Neurons.

1 HL, 64N       2 HL, 64N each  3 HL, 64N each  4 HL, 64N each  GP mean
1.0310167       0.9359667       0.938025        0.9486213       0.9323123
0.9939897       0.91862655      0.91935253      0.9286917       0.934071
0.9569628       0.90128636      0.90068007      0.9093453       0.926627
0.9199359       0.8839463       0.8820075       0.8906931       0.911141
0.88290906      0.86660624      0.863335        0.87204087      0.888929
0.8458822       0.84823         0.8443388       0.8533887       0.861319
0.8088551       0.8216592       0.818619        0.82862586      0.82952
0.7718283       0.79057         0.78710705      0.80185705      0.794533
0.7348015       0.7537405       0.75483435      0.77277607      0.757104
0.6977744       0.7143501       0.7184024       0.73662823      0.717731
0.6607475       0.6749598       0.6793759       0.69309384      0.676709
0.62372065      0.6355695       0.63738376      0.6467888       0.634208
0.5862205       0.5937305       0.59471357      0.599609        0.590366
0.5407498       0.5457597       0.54709804      0.55195147      0.545373
0.4952766       0.49778882      0.4992079       0.5042205       0.49954
0.44980314      0.44990715      0.45125672      0.45653468      0.453327
0.40433016      0.4031423       0.40382358      0.40979505      0.407338
0.3588567       0.35748953      0.35848558      0.36351866      0.362288
0.31549928      0.31303966      0.31533292      0.32044703      0.318933
0.27524063      0.2705538       0.27200228      0.2804457       0.277996
0.23979539      0.23742433      0.24069557      0.2418064       0.240096
0.20099688      0.19988464      0.20456474      0.2145113       0.205685
0.16196822      0.16679865      0.17303263      0.18275131      0.175016
0.14389752      0.14620048      0.14776252      0.15101965      0.148136
0.12673585      0.12363774      0.12680508      0.12768927      0.124907
0.10957434      0.10223782      0.1055288       0.10772515      0.105044
0.09241267      0.08653921      0.09068373      0.08748056      0.088173
0.07525112      0.07084078      0.07638919      0.0738593       0.073881
0.06405788      0.05796587      0.06551494      0.06084543      0.06177
0.05543705      0.04907831      0.0565593       0.0475015       0.05149
0.04681628      0.04019085      0.04897907      0.04064081      0.042759
0.03819548      0.03227264      0.04272956      0.03435563      0.035364
0.02957465      0.02511932      0.03683984      0.02806261      0.029157
0.02095388      0.0179661       0.03095011      0.0229567       0.024029
0.0123331       0.01081278      0.02506035      0.01791653      0.0198959

Prediction. Once the length scale is updated in any iteration of the MCMC chain run with training data Dtrain (which could be Dmean or Dright), we predict the output W at a test input. These predicted mean outputs at each of the considered test temperatures are plotted against temperature as green stars in Fig. 5a and Fig. 5b; W from the training data at a design temperature is depicted in these figures as red triangles, while the uncertainty predicted on the output at a test input is depicted as the salmon-pink shaded region. This depicted “uncertainty” is 2.5 times the standard deviation that is predicted, in closed form along with the mean, at any test temperature. This is an advantage of performing learning with GPs: the output mean and variance predictions are closed-form. One concern that we have about this learning is that the choice of the training set employed in undertaking it appears ambiguous.
In fact, the uncertainties in the learning of α and β are not correctly percolated in the formulation of these training data. In particular, when we examine the predicted outputs and the uncertainties predicted on the outputs at any test temperature, we see, firstly, that the result varies depending on which training data we have used in the learning, and secondly, that some probabilities within the uncertainty bands are predicted to be in excess of 1. This happens in spite of the 2-stepped prior that we imposed to ensure that all output values stay restricted to [0,1]. Figure 5a depicts the result of predictions using Dmean, while Fig. 5b shows the same with Dright.
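The closed-form mean and variance predictions referred to above can be sketched as follows. This is a minimal illustration, assuming a constant mean equal to the empirical mean of the training outputs, a unit-amplitude SQE kernel and a small diagonal jitter (`noise=1e-4`); the function names and the illustrative inputs are hypothetical. It also shows how the 2.5-standard-deviation band is formed, and why such a band can leave [0,1] even when a constraint is imposed only through the predicted means.

```python
import numpy as np


def sqe(xa, xb, length_scale):
    """Unit-amplitude squared-exponential kernel matrix between 1-D input arrays."""
    return np.exp(-0.5 * ((xa[:, None] - xb[None, :]) / length_scale) ** 2)


def gp_predict(length_scale, x, w, x_test, noise=1e-4):
    """Closed-form GP predictive mean and standard deviation at x_test.

    mean = mu + k*^T (K + noise I)^{-1} (w - mu),  var = k(x*, x*) - k*^T (K + noise I)^{-1} k*,
    with mu the empirical mean of the training outputs w.
    """
    mu = w.mean()
    L = np.linalg.cholesky(sqe(x, x, length_scale) + noise * np.eye(len(x)))
    Ks = sqe(x, x_test, length_scale)                        # shape (M, n_test)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, w - mu))
    mean = mu + Ks.T @ alpha
    V = np.linalg.solve(L, Ks)
    var = 1.0 - np.sum(V ** 2, axis=0)                       # k(x*, x*) = 1 for unit amplitude
    return mean, np.sqrt(np.maximum(var, 0.0))


x = np.linspace(-1.5, 1.5, 5)                  # illustrative standardised design inputs
w = np.array([0.95, 0.8, 0.5, 0.2, 0.05])      # illustrative training outputs
x_test = np.linspace(-2.0, 2.0, 9)
mean, sd = gp_predict(0.9, x, w, x_test)
lower, upper = mean - 2.5 * sd, mean + 2.5 * sd  # the 2.5-standard-deviation band
print(np.any(upper > 1.0))                     # True if any part of the band lies above 1
```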
4 Pipe-Lined Architecture for Efficient Uncertainty Percolation
Fig. 5. (a,b) GP predictions with uncertainties of 2.5 standard deviations, obtained with different length scales using the non-pipe-lined architecture; (c) the same with the pipe-lined architecture.
As we have seen in the previous section, predictions following GP-based learning can result in unacceptable uncertainties, and in ambiguous results stemming from the generation of training sets based on ambiguously summarised uncertainties from the earlier stages. In this section we illustrate a pipeline architecture for efficient handling of uncertainties, which can also be extended to higher-dimensional data. Figure 6 represents the two approaches of uncertainty modelling and percolation as flow charts. The pipeline architecture consists of three blocks within the MCMC chain that we run. These blocks are executed sequentially within each iteration of the MCMC chain, so that uncertainty is propagated through the stages without incorporating further errors that can creep in via arbitrary summarisation of learning outcomes in previous stages. Broadly, inside an iteration of the MCMC chain, in the first block α and β are updated, as within an iteration of Metropolis-Hastings, using the test flight data D. In the second block of this iteration, the training data Dtrain is computed at the current α and β values, and the GP kernel hyperparameter (the length scale) is then updated using this current training data Dtrain, under the 2-stepped prior, to ensure that output predictions remain within 0 and 1 at all considered temperatures. Then, in the third block of this iteration, closed-form mean and variance predictions of the output are undertaken at each test temperature, using the current value of the length scale. Hence, once the chain has run, we obtain traces of the predictive means at each test temperature. We use the values of the output W ≡ Pr(Y = 1|X = xtest) sampled across the iterations at each test input xtest to compute the 95% HPD credible region.
We compute the mean and the standard deviation of this sample; the mean is taken as the central prediction at test temperature xtest, and 2.5 times the sample standard deviation on either side of the mean as the uncertainty in the prediction. The predicted variance at each test temperature is also recorded, and it is compared to the uncertainty of the predictions obtained from the variation across MCMC samples.
First block: In each iteration of the MCMC chain, α and β are first learnt using the data D with Metropolis-Hastings, in the same configuration as used in Subsect. 2.1. The log of the posterior defined in that subsection is used, with Gaussian priors on both α and β. Once the burn-in phase is over, the next stages are switched on.
Second block: In the current iteration of the MCMC chain, the O-ring failure probability is computed at the design temperatures x1, . . . , xM using the current values of α and β, as in Sect. 2.1. This newly computed failure probability becomes the output at a design input; performed over the whole set of design inputs, this generates the training data Dtrain. This is employed in learning the GP kernel hyperparameter (the length scale), as described in Subsect. 3.1.
Third block: In the third block of this iteration, the predictive mean and variance of the output at each test temperature are computed using the current length scale.
Fig. 6. Approaches of dealing with uncertainty
Figure 7 represents the results of learning the parameters α and β using the test flight data, within the MCMC chain that updates the training set $\{(x_i, w_i)\}_{i=1}^{M}$ and simultaneously learns the hyperparameter (length scale) of the covariance structure of the GP that models the functional relation between the output W and the input X. Here, the output W is the probability of O-ring failure and the input X is the ambient temperature to which such an O-ring is exposed. At each iteration of the MCMC chain, the uncertainty-included outputs are predicted at each test temperature. Thus, at the end of the MCMC chain, we predict values of the mean and standard deviation of the output at each xtest; such predictions are shown in Fig. 5c. Traces of the predicted values of the output W at a few of the considered test temperatures are depicted in Fig. 8, along with uncertainties.
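The summaries computed from each trace (central value, band of 2.5 sample standard deviations, 95% HPD) can be sketched as below. This is a minimal illustration, not the authors' implementation: it uses the common shortest-window estimate of the HPD interval, which assumes the marginal of the sampled output is unimodal, and both function names are hypothetical.

```python
import math

import numpy as np


def hpd_interval(samples, mass=0.95):
    """Shortest interval containing a fraction `mass` of the samples.

    This is the usual empirical HPD estimate for a unimodal marginal; a multimodal
    marginal would need a different (e.g. density-based) HPD region.
    """
    s = np.sort(np.asarray(samples, dtype=float))
    n = len(s)
    k = int(math.ceil(mass * n))          # number of samples the interval must contain
    widths = s[k - 1:] - s[: n - k + 1]   # width of every window of k consecutive samples
    i = int(np.argmin(widths))
    return s[i], s[i + k - 1]


def summarise_trace(trace, n_sd=2.5, mass=0.95):
    """Central prediction, +/- n_sd * sample-SD band, and HPD for one test input's trace."""
    trace = np.asarray(trace, dtype=float)
    centre, sd = trace.mean(), trace.std(ddof=1)
    return centre, (centre - n_sd * sd, centre + n_sd * sd), hpd_interval(trace, mass)
```

Applied column-wise to the stored traces of predictive means, this yields one central prediction, one band and one HPD per test temperature.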
Fig. 7. Traces and plots for learning α and β in the pipe-lined architecture; in (c), C.: convergence test for α (top); β
Algorithm 1 depicts the pipeline with the following inputs: α0, β0 and an initial length scale are the seeds; σα, σβ and a length-scale jump scale are the jump scales for α, β and the length scale, respectively; μα, μβ are the prior means and σαp, σβp the prior standard deviations for α and β, respectively; N is the maximum number of iterations; a separate starting iteration marks when learning of the length scale begins; n is the number of test data points; and a noise term is used in the GP.
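A compact Python sketch of the three-block pipeline is given below, under stated assumptions. The O-ring failure model of Sect. 2.1 is not part of this excerpt, so a logistic model Pr(Y = 1|x) = logistic(α + βx) with N(0, 10²) priors is assumed; the jump scales, the unit-amplitude SQE kernel, the GP jitter, checking the stepped prior on predictive means only, and passing the failure probabilities to the GP without standardisation are further simplifications. A fresh uniform is drawn for each accept/reject decision. All names are hypothetical, and the usage lines run on synthetic data, not the test-flight data D.

```python
import math

import numpy as np


def logistic(z):
    return 1.0 / (1.0 + np.exp(-z))


def log_post_ab(a, b, x, y, prior_sd=10.0):
    """Assumed Block-1 target: Bernoulli likelihood with Pr(Y=1|x) = logistic(a + b x), N(0, prior_sd^2) priors."""
    p = np.clip(logistic(a + b * x), 1e-12, 1.0 - 1e-12)
    loglik = np.sum(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))
    return float(loglik - 0.5 * (a * a + b * b) / prior_sd ** 2)


def sqe(xa, xb, ls):
    return np.exp(-0.5 * ((xa[:, None] - xb[None, :]) / ls) ** 2)


def gp_predict(ls, x, w, xt, noise=1e-4):
    """Closed-form GP predictive mean and SD (constant mean = mean(w), unit-amplitude SQE kernel)."""
    L = np.linalg.cholesky(sqe(x, x, ls) + noise * np.eye(len(x)))
    Ks = sqe(x, xt, ls)
    mean = w.mean() + Ks.T @ np.linalg.solve(L.T, np.linalg.solve(L, w - w.mean()))
    var = 1.0 - np.sum(np.linalg.solve(L, Ks) ** 2, axis=0)
    return mean, np.sqrt(np.maximum(var, 0.0))


def log_post_ls(ls, x, w, xt, noise=1e-4):
    """GP log-likelihood (up to a constant) + log 2-stepped prior with a N(0.9, 0.08^2) base."""
    L = np.linalg.cholesky(sqe(x, x, ls) + noise * np.eye(len(x)))
    z = np.linalg.solve(L, w - w.mean())
    mean, _ = gp_predict(ls, x, w, xt, noise)
    step = 0.0 if np.all((mean >= 0.0) & (mean <= 1.0)) else -1e3  # log rho_hi = 0, log rho_lo = -1000
    return -0.5 * z @ z - np.log(np.diag(L)).sum() + step - 0.5 * ((ls - 0.9) / 0.08) ** 2


def pipeline(x, y, x_design, x_test, n_iter=2000, n_start=500, jump_ab=0.5, jump_ls=0.03, seed=0):
    """Run Blocks 1-3 inside one chain; return mean and SD of the predictive-mean traces per test input."""
    rng = np.random.default_rng(seed)
    a, b, ls = 0.0, 0.0, 0.9

    def log_phi(v):
        # log Phi(v / jump_ls): mass above 0 of N(v, jump_ls^2), for the truncation correction
        return math.log(0.5 * (1.0 + math.erf(v / (jump_ls * math.sqrt(2.0)))))

    traces = []
    for i in range(n_iter):
        # Block 1: component-wise Metropolis-Hastings for alpha (a) and beta (b) on the data (x, y)
        a_p = a + jump_ab * rng.normal()
        if math.log(rng.uniform()) < log_post_ab(a_p, b, x, y) - log_post_ab(a, b, x, y):
            a = a_p
        b_p = b + jump_ab * rng.normal()
        if math.log(rng.uniform()) < log_post_ab(a, b_p, x, y) - log_post_ab(a, b, x, y):
            b = b_p
        if i < n_start:  # burn-in: Blocks 2 and 3 are switched on afterwards
            continue
        # Block 2: training outputs at the design inputs from the current (a, b); MH update of ls
        w = logistic(a + b * x_design)
        ls_p = rng.normal(ls, jump_ls)
        while ls_p <= 0.0:  # proposal truncated below at 0
            ls_p = rng.normal(ls, jump_ls)
        log_ratio = (log_post_ls(ls_p, x_design, w, x_test) - log_post_ls(ls, x_design, w, x_test)
                     + log_phi(ls) - log_phi(ls_p))
        if math.log(rng.uniform()) < log_ratio:
            ls = ls_p
        # Block 3: closed-form GP prediction at the test inputs with the current ls
        mean, _ = gp_predict(ls, x_design, w, x_test)
        traces.append(mean)
    traces = np.array(traces)
    return traces.mean(axis=0), traces.std(axis=0, ddof=1)


# Usage with synthetic data (not the test-flight data D of the paper):
rng = np.random.default_rng(42)
x = rng.normal(size=40)  # standardised inputs
y = (rng.uniform(size=40) < logistic(1.0 - 2.0 * x)).astype(float)
centre, spread = pipeline(x, y, np.linspace(-2, 2, 9), np.linspace(-2, 2, 17), n_iter=600, n_start=200)
```

Because the training outputs are recomputed from the current (α, β) in every iteration, the spread of the stored predictive means carries the uncertainty in α and β through to the predictions, rather than fixing the training set from a single summary of α and β.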
Fig. 8. Variation of the predicted mean of the output W with iteration index, at a few values of xtest; from left, the 0th, 8th, 16th, 24th and 28th test points.
Algorithm 1: Pipeline for Forward Prediction
Set α[0] ← α0, β[0] ← β0, length_scale[0] ← initial length scale; x ← non-standardised input vector from D
x_std ← Standardise(x); x_test ← generate n test data within the range of x_std
/* Block-1: MH for learning α, β */
for i ← 1 to N increment by 1 do
    α_current ← α[i − 1], β_current ← β[i − 1], α_proposed ∼ N(α[i − 1], σα), u ∼ U(0, 1)
    a ← log(π^I(α_proposed, β[i − 1] | D, π0^α, π0^β)) − log(π^I(α[i − 1], β[i − 1] | D, π0^α, π0^β))
    if a > log(u) then α[i] ← α_proposed
    else α[i] ← α[i − 1]
    β_proposed ∼ N(β[i − 1], σβ)
    b ← log(π^I(α[i], β_proposed | D, π0^α, π0^β)) − log(π^I(α[i], β[i − 1] | D, π0^α, π0^β))
    if b > log(u) then β[i] ← β_proposed
    else β[i] ← β[i − 1]
    /* Block-2: Vanilla MCMC for learning the length scale */
    y_p ← ComputeFailureProbability(α_current, β_current, D) as in Subsection 2.1
    y_p^std ← Standardise(y_p)
    if i >= N_start then
        j ← (i − N_start), ls_current ← length_scale[j − 1], π^II_current ← π^II[j − 1], a_ls ← −∞, b_ls ← +∞
        ls_prop ∼ N_T(ls_current, σ_ls, a_ls, b_ls)
        p ← f_stepPrior(x_std, y_p, x_test, ls_prop, noise)
        if p == 1 then
            Q ← log(N_T(ls_current | ls_prop, σ_ls, a_ls, b_ls)) − log(N_T(ls_prop | ls_current, σ_ls, a_ls, b_ls))
            π^II_prop ← log(π^II(ls_prop | y_p^std, x, π0^ls))
            if (π^II_prop − π^II_current + Q) > log(u) then length_scale[j] ← ls_prop, π^II[j] ← π^II_prop
            else length_scale[j] ← ls_current, π^II[j] ← π^II_current
        else
            length_scale[j] ← ls_current, π^II[j] ← π^II_current
        {y_o, σ_gp} ← GP prediction(length_scale[j], x_std, y_p, x_test, noise)   // Block-3: prediction using GP
5 Discussion and Conclusions
In this work we present a robust method of uncertainty percolation in the context of a simple data situation; however, the methodology is generic and applicable to complex real-world datasets. This kind of integrated uncertainty percolation is useful in domains such as healthcare, where correct acknowledgement of uncertainty is crucial. Results from our Bayesian approach have been compared to methods that consider uncertainty differently, and this comparison is the basic premise of the work. Importantly, we have also compared our results with those obtained using DNNs; in fact, the prediction performance in our work is benchmarked against DNN performance. Results obtained from DNNs are sensitive to the architectural parameters of the DNN. Indeed, the predictions made with DNNs could improve with further hyperparameter tuning; however, this very need for such tuning is our exact worry regarding the employment of DNNs for prediction. We are not arguing against the possibility of enhancing prediction performance with DNNs; our quibble is that DNN prediction performance is sensitive to the choice of its architectural parameters, and when the correct answer is not known, as is always the case in real-world problems, we do not know which architectural details are optimal for the considered prediction task, given the data at hand. So the quality of the predictions made with DNNs is rendered questionable in general, and tuning a DNN for a test case in a chosen data context does not directly inform on how well such a tuned DNN will perform in a different data context. Such difficulties with DNN usage are corroborated by [6] and references therein. As stated above, our small example can be generalised to broader real-world data situations. Indeed, depending on the architectural parameters, DNNs can produce diversely inaccurate predictions. Via our simple illustration, we also show how the Bayesian setting allows priors on the unknowns to facilitate the GP-based learning, such that the outputs, whose prediction is the objective of the learning, abide by relevant constraints.
References
1. Begoli, E., Bhattacharya, T., Kusnezov, D.: The need for uncertainty quantification in machine-assisted medical decision making. Nat. Mach. Intell. 1(1), 20–23 (2019)
2. Dalal, S.R., Fowlkes, E.B., Hoadley, B.: Risk analysis of the space shuttle: pre-Challenger prediction of failure. J. Am. Stat. Assoc. 84(408), 945–957 (1989)
3. Der Kiureghian, A., Ditlevsen, O.: Aleatory or epistemic? Does it matter? Struct. Saf. 31(2), 105–112 (2009)
4. Ghahramani, Z.: Probabilistic machine learning and artificial intelligence. Nature 521(7553), 452–459 (2015)
5. Kruschke, J.K.: Doing Bayesian Data Analysis, 2nd edn. Academic Press, Cambridge (2015)
6. Liao, L., Li, H., Shang, W., Ma, L.: An empirical study of the impact of hyperparameter tuning and model optimization on the performance properties of deep neural networks. ACM Trans. Softw. Eng. Methodol. (TOSEM) 31(3), 1–40 (2022)
7. Mallick, T., Balaprakash, P., Macfarlane, J.: Deep-ensemble-based uncertainty quantification in spatiotemporal graph neural networks for traffic forecasting (2022). arXiv:2204.01618
8. Neal, R.M.: Regression and classification using Gaussian process priors (with discussion). In: Bernardo, J.M., et al. (eds.) Bayesian Statistics 6, pp. 475–501. Oxford University Press (1998)
9. O’Hagan, A.: Curve fitting and optimal design for prediction. J. R. Stat. Soc. Ser. B (Methodol.) 40(1), 1–24 (1978)
10. Rasmussen, C.E.: Evaluation of Gaussian processes and other methods for nonlinear regression. PhD thesis, University of Toronto, Toronto, Canada (1997)
11. Rasmussen, C.E., Williams, C.K.I.: Gaussian Processes for Machine Learning. Adaptive Computation and Machine Learning. MIT Press (2006)
12. Robert, C.P., Casella, G.: Monte Carlo Statistical Methods. Springer Texts in Statistics. Springer, New York (2004). https://doi.org/10.1007/978-1-4757-4145-2
13. Smith, H.J., Dinev, T., Xu, H.: Information privacy research: an interdisciplinary review. MIS Q., 989–1015 (2011)
14. TensorFlow: Basic regression: Predict fuel efficiency (2022). https://www.tensorflow.org/tutorials/keras/regression
15. Young, D.S.: Handbook of Regression Methods. Chapman and Hall/CRC (2018)
SMART: A Robustness Evaluation Framework for Neural Networks
Yuanchun Xiong¹ and Baowen Zhang¹,²(B)
¹ Institute of Cyber Science and Technology, Shanghai Jiao Tong University, Shanghai 200240, China
² Shanghai Key Laboratory of Integrated Administration Technologies for Information Security, Shanghai 200240, China
Abstract. Robustness is urgently needed when neural network models are deployed in adversarial environments. Typically, a model learns to separate data points into different classes during training. A more robust model is more resistant to small perturbations within the local microsphere space of a given data point. In this paper, we try to measure a model’s robustness from the perspective of data separability. We propose a modified data separability index, the Mahalanobis Distance-based Separability Index (MDSI), and present a new robustness evaluation framework, Separability in Matrix-form for Adversarial Robustness of neTwork (SMART). Specifically, we use multiple attacks to find adversarial inputs and combine them with clean data points. We use MDSI to evaluate the separability of the new dataset with the correct labels and with the model’s predictions, and then compute a SMART score to show the model’s robustness. Compared with existing robustness measurements, our framework establishes a connection between data separability and the model’s robustness, and shows openness, scalability, and pluggability in its architecture. The effectiveness of our method is verified in experiments.
Keywords: Neural network robustness · Data separability · Adversarial inputs · MDSI · SMART
1 Introduction
Recent work has demonstrated that neural networks (NNs) are vulnerable to adversarial examples: visually imperceptible perturbations that can mislead a well-trained model [1]. Safety is always a relative concept in adversarial environments. [2] suggests that the existence of adversarial examples is an inevitable part of the network architecture and an inherent weakness of network models. It is necessary to measure and improve the robustness of different models to adversarial examples and mitigate the risks caused by the adversary.
This work is supported by the National Key Research and Development Program of China under Grant No. 2020YFB1807504 and No. 2020YFB1807500.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 288–299, 2023. https://doi.org/10.1007/978-981-99-1639-9_24
Currently, many robustness measurements focus on the success rate of attacks and the minimal distortion required to successfully generate adversarial examples; few consider data separability when measuring robustness. Separability is an intrinsic characteristic of a dataset. It describes how data points of different labeled classes are mixed together. Typically, a model learns to classify data points into different categories during training. The separability difference between the original dataset and the adversarial dataset can reflect the performance of models trained on them. [3] created a model-agnostic separability index called the Distance-based Separability Index (DSI). It uses the Euclidean distance as its metric to measure how far data points are from each other. In this work, we modify DSI and apply data separability to robustness evaluation. First, we propose the Mahalanobis Distance-based Separability Index (MDSI), a modification of DSI that uses the Mahalanobis distance as its metric and considers the correlation between different dimensions of a dataset when measuring separability. For a given dataset, we use different attack techniques to generate adversarial examples and mix them with the original clean examples to form a new dataset. This new dataset shows a difference in separability when labeled with the correct labels and with the model’s predicted labels. We use MDSI to reflect this difference and construct our robustness evaluation framework, termed Separability in Matrix-form for Adversarial Robustness of neTwork (SMART). We highlight our main contributions in this paper as follows:
– We propose a data separability metric called MDSI. It shows the overall separability of a given dataset. In practice, we use partitioned matrix operations to optimize the efficiency of computing MDSI.
– We introduce SMART, a robustness evaluation framework for neural network models.
SMART measures a model’s robustness by comparing MDSI results on a new dataset consisting of clean data points and adversarial examples against the model. Our framework is scalable and offers flexibility in the choice of attacks and the proportion of adversarial examples generated.
– We use SMART and some mainstream metrics to evaluate the robustness of several state-of-the-art NN models. The results verify the effectiveness of our SMART framework.
2 Related Work
2.1 Adversarial Examples
A counter-intuitive property of neural networks found by [1] is the existence of adversarial examples: a barely perceptible perturbation to a clean image can cause misclassification. [4] observes that the direction of the perturbation matters most and proposes the Fast Gradient Sign Method (FGSM) to generate adversarial examples. The Basic Iterative Method (BIM) [5] extends FGSM by applying it iteratively with a smaller step size. The Jacobian-based Saliency Map Attack (JSMA) [6]
Y. Xiong and B. Zhang
only modifies a limited number of input pixels to search for adversarial examples. DeepFool [7] uses geometric concepts to guide its search. The C&W attack [8] formulates finding adversarial examples as a distance-minimization problem and can find adversarial examples with significantly smaller perturbations. On the other side of adversarial attacks, many defense techniques have been proposed to identify or reduce adversarial examples, and there is an arms race between attacks and defenses. Adversarial training can improve robustness by retraining the model on adversarial examples [4]; it is by far the strongest empirical defense. No defense technique is effective against all attacks.
2.2 Robustness Evaluation
Adversarial robustness is defined as the performance of a neural network model when facing adversarial examples [9]. Some studies formalize their notion of robustness by giving their own definitions, including point-wise robustness [10], local robustness [11] and categorical robustness [12]. The core idea is that when the input changes within a small range, the output of a robust model should not fluctuate greatly. Robustness can be evaluated from different perspectives.
Accuracy. A model’s accuracy on adversarial examples is the direct indicator of robustness. A model’s accuracy on clean examples also reflects its performance and generalization; for convenience, we refer to the latter as Nature Accuracy (NA).
Minimal Distortion Radius. The minimal distortion radius represents the range of adversarial perturbation needed to generate successful adversarial examples. In general, a model with a larger radius has higher adversarial robustness. An upper bound on the radius is usually computed via some attacks, and its tightness depends on the effectiveness of the attack [8]. We name the upper bound found by the PGD attack the Empirical Radius (ER). A lower bound is usually provided by certified methods; it guarantees that the model is robust to any perturbation smaller than it. [13] proposed an attack-independent robustness metric, CLEVER, to obtain a lower bound. However, [14] pointed out that gradient masking can mislead CLEVER into overestimation. Later discussion demonstrated that CLEVER can handle gradient masking problems [15].
2.3 Separability Index
Separability is an inherent characteristic of a dataset which measures the relationship between classes. [3] created Distance-based Separability Index (DSI) as a novel separability measure. It represents the universal relations between the data points in a dataset. A higher DSI score indicates that the dataset is easier to separate into different classes. When the DSI score is close to zero, it means that different classes of data have nearly identical distribution and are the most difficult to separate.
3 Method
In Sect. 3.1, we discuss the relationship between a model’s robustness and data separability. Building on the previous work on DSI mentioned in Sect. 2.3, we introduce a modified separability measure named MDSI in Sect. 3.2. In Sect. 3.3, we apply data separability to the evaluation of model robustness and present our robustness evaluation framework, SMART.
3.1 Model’s Robustness and Data Separability
Given a NN model and its input space $\mathcal{X}$, let δ be the minimal distortion required to craft an adversarial example x′ from a clean one x. A larger δ indicates that the model is more robust around x. Consider the following two scenarios.
– Scenario I: In $\mathcal{X}$, only a few data points have relatively small δs.
– Scenario II: In $\mathcal{X}$, many data points have relatively moderate δs.
Robustness evaluation metrics that seek a bound on the minimal distortion radius focus on a single data point in the dataset at a time. These metrics often find models in Scenario I less robust because they have a smaller minimal distortion radius. This is questionable, because models in Scenario I generally work well once these few unrobust points are filtered out, while models in Scenario II need to be strengthened at many unrobust locations before deployment. Our method explores a novel approach based on data separability that simultaneously considers all data points in the dataset when evaluating robustness, and can reflect the overall robustness of neural network models.
3.2 The Separability Index MDSI
We propose the Mahalanobis Distance-based Separability Index (MDSI) as a modification of DSI mentioned in Sect. 2.3. As its name suggests, MDSI uses the Mahalanobis distance as its distance metric, which has wide applications in image processing [19] and neurocomputing [18]. The Mahalanobis distance is unitless and scale-invariant, and it takes the correlations of the dataset into account [16], so it can better reflect the overall data separability when applied in MDSI. Computing it requires a pass through all variables in the dataset to obtain the underlying inter-correlation structure, so it is usually computationally more expensive than the Euclidean distance [17]. The steps to compute MDSI are as follows. Given a dataset $\mathcal{X}$ and two points $p, q \in \mathcal{X}$, let S be the covariance matrix of the dataset, and let $d_M$ be the Mahalanobis distance between p and q:
$$d_M = d_M(p, q) = \sqrt{(p - q)^T S^{-1} (p - q)} \qquad (1)$$
First, consider two classes $\mathcal{X}_m, \mathcal{X}_n \subset \mathcal{X}$ that satisfy $\mathcal{X}_m \cap \mathcal{X}_n = \emptyset$, have the same distribution and contain sufficient data points. The Intra-class Mahalanobis Distance set $IMD_m$ contains $d_M$ between any two points in the same class $\mathcal{X}_m$:
$$IMD_m = \{ d_M(x_i, x_j) \mid x_i, x_j \in \mathcal{X}_m;\ x_i \ne x_j \} \qquad (2)$$
292
Y. Xiong and B. Zhang
The Between-class Mahalanobis Distance set $\mathrm{BMD}_{m,n}$ contains $d_M$ between any two points from the different classes $\mathcal{X}_m$ and $\mathcal{X}_n$:

$$\mathrm{BMD}_{m,n} = \{\, d_M(x_i, x_j) \mid x_i \in \mathcal{X}_m;\ x_j \in \mathcal{X}_n \,\} \tag{3}$$
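Equations (1)–(3) can be illustrated with a minimal pure-Python sketch. It assumes the sample covariance (divisor n − 1) of the whole dataset, a Gauss-Jordan inverse instead of the pseudo-inverse, and that each unordered pair of distinct points is counted once in the IMD set; the helper names `covariance`, `invert`, `mahalanobis_sq`, `imd_set`, and `bmd_set` are hypothetical, not taken from the paper. As in Eq. (1), the distance is the quadratic form without a square root.

```python
from itertools import combinations
from typing import List, Sequence

Vector = Sequence[float]
Matrix = List[List[float]]


def covariance(points: Sequence[Vector]) -> Matrix:
    """Sample covariance S of the whole dataset (divisor n - 1, an assumption)."""
    n, f = len(points), len(points[0])
    mean = [sum(p[k] for p in points) / n for k in range(f)]
    return [[sum((p[a] - mean[a]) * (p[b] - mean[b]) for p in points) / (n - 1)
             for b in range(f)] for a in range(f)]


def invert(m: Matrix) -> Matrix:
    """Gauss-Jordan inverse with partial pivoting; requires a full-rank S."""
    n = len(m)
    aug = [list(row) + [float(i == j) for j in range(n)] for i, row in enumerate(m)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(aug[r][col]))
        if abs(aug[pivot][col]) < 1e-12:
            raise ValueError("S is singular; the paper falls back to a pseudo-inverse")
        aug[col], aug[pivot] = aug[pivot], aug[col]
        scale = aug[col][col]
        aug[col] = [v / scale for v in aug[col]]
        for r in range(n):
            if r != col:
                factor = aug[r][col]
                aug[r] = [v - factor * c for v, c in zip(aug[r], aug[col])]
    return [row[n:] for row in aug]


def mahalanobis_sq(p: Vector, q: Vector, s_inv: Matrix) -> float:
    """d_M(p, q) = (p - q)^T S^{-1} (p - q), as written in Eq. (1)."""
    d = [a - b for a, b in zip(p, q)]
    return sum(d[a] * s_inv[a][b] * d[b] for a in range(len(d)) for b in range(len(d)))


def imd_set(cls: Sequence[Vector], s_inv: Matrix) -> List[float]:
    """Eq. (2): distances between distinct points of one class (each pair once)."""
    return [mahalanobis_sq(p, q, s_inv) for p, q in combinations(cls, 2)]


def bmd_set(cls_m: Sequence[Vector], cls_n: Sequence[Vector], s_inv: Matrix) -> List[float]:
    """Eq. (3): distances between every point of X_m and every point of X_n."""
    return [mahalanobis_sq(p, q, s_inv) for p in cls_m for q in cls_n]


if __name__ == "__main__":
    class_a = [(0.0, 0.0), (2.0, 0.0)]
    class_b = [(0.0, 2.0), (2.0, 2.0)]
    s_inv = invert(covariance(class_a + class_b))  # S is diagonal here
    print(imd_set(class_a, s_inv))
    print(bmd_set(class_a, class_b, s_inv))
```

In the toy example the two coordinates are uncorrelated, so S is diagonal and the Mahalanobis distance is a rescaled squared Euclidean distance; correlated features would additionally rotate the metric.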
The Kolmogorov-Smirnov (KS) test quantifies the distance between the empirical distribution functions of two samples. Compared with other distribution measures such as the Kullback-Leibler divergence, the Jensen-Shannon divergence, and the Wasserstein distance, the KS test works when the samples contain different numbers of points and is more sensitive when measuring separability [3]. MDSI uses the KS test to compare the distributions of the IMD and BMD sets. Consider an n-class dataset $\mathcal{X}$; for each class $\mathcal{X}_i$, let $\bar{\mathcal{X}}_i$ denote the union of the other classes, so that $\mathcal{X}_i \cap \bar{\mathcal{X}}_i = \emptyset$ and $\mathcal{X}_i \cup \bar{\mathcal{X}}_i = \mathcal{X}$. The MDSI score of $\mathcal{X}$ is defined as:

$$\mathrm{MDSI}(\mathcal{X}) = \frac{1}{n} \sum_{i=1}^{n} \mathrm{KS}\big(\mathrm{IMD}_i, \mathrm{BMD}_{i\bar{i}}\big) \tag{4}$$
$\mathcal{X}$ has the lowest separability when the distributions of the IMD and BMD sets are nearly identical, in which case its MDSI score is also the lowest. Advantages of MDSI. We use the sklearn.datasets.make_blobs function in Python to create eight two-class datasets and eight five-class datasets and compare their DSI [3] and MDSI scores. Each dataset has 1000 data points and one cluster center per class, and the standard deviation (SD) of the clusters ranges from 1 to 8. The results are shown in Fig. 1.
Fig. 1. DSI and MDSI scores on n-class datasets with different SDs.
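The MDSI definition in Eq. (4) can be sketched in pure Python as follows. To stay short, the sketch is restricted to 2-D points so the covariance can be inverted in closed form, and two Gaussian clusters drawn with `random.gauss` stand in for `make_blobs`; the centers, the 60 points per class, and the seed are arbitrary choices, and all function names are hypothetical. Because the KS statistic is unchanged when the same strictly increasing function is applied to both samples, using the squared form of Eq. (1) or its square root yields the same MDSI. The sketch only reproduces the qualitative trend of Fig. 1 (lower MDSI as SD grows), not its numbers.

```python
import random
from itertools import combinations
from typing import Dict, List, Sequence, Tuple

Point = Tuple[float, float]
Matrix2 = Tuple[Tuple[float, float], Tuple[float, float]]


def ks_statistic(a: Sequence[float], b: Sequence[float]) -> float:
    """Two-sample KS statistic: the largest gap between the two empirical CDFs."""
    a, b = sorted(a), sorted(b)
    i = j = 0
    gap = 0.0
    while i < len(a) and j < len(b):
        x = min(a[i], b[j])
        while i < len(a) and a[i] <= x:  # step past every value <= x (handles ties)
            i += 1
        while j < len(b) and b[j] <= x:
            j += 1
        gap = max(gap, abs(i / len(a) - j / len(b)))
    return gap


def inverse_covariance_2d(points: Sequence[Point]) -> Matrix2:
    """Closed-form inverse of the 2 x 2 sample covariance S of the whole dataset."""
    n = len(points)
    mx = sum(p[0] for p in points) / n
    my = sum(p[1] for p in points) / n
    sxx = sum((p[0] - mx) ** 2 for p in points) / (n - 1)
    syy = sum((p[1] - my) ** 2 for p in points) / (n - 1)
    sxy = sum((p[0] - mx) * (p[1] - my) for p in points) / (n - 1)
    det = sxx * syy - sxy * sxy
    return ((syy / det, -sxy / det), (-sxy / det, sxx / det))


def d_m(p: Point, q: Point, s_inv: Matrix2) -> float:
    """(p - q)^T S^{-1} (p - q), the form used in Eq. (1)."""
    dx, dy = p[0] - q[0], p[1] - q[1]
    return dx * (s_inv[0][0] * dx + s_inv[0][1] * dy) + dy * (s_inv[1][0] * dx + s_inv[1][1] * dy)


def mdsi(points: Sequence[Point], labels: Sequence[int]) -> float:
    """Eq. (4): mean over classes i of KS(IMD_i, BMD between class i and the rest)."""
    s_inv = inverse_covariance_2d(points)
    by_class: Dict[int, List[Point]] = {}
    for p, c in zip(points, labels):
        by_class.setdefault(c, []).append(p)
    scores = []
    for c, members in by_class.items():
        rest = [p for other, ps in by_class.items() if other != c for p in ps]  # X-bar_i
        imd = [d_m(p, q, s_inv) for p, q in combinations(members, 2)]
        bmd = [d_m(p, q, s_inv) for p in members for q in rest]
        scores.append(ks_statistic(imd, bmd))
    return sum(scores) / len(scores)


def two_blobs(sd: float, n_per_class: int = 60, seed: int = 0):
    """Two isotropic Gaussian clusters, a stand-in for make_blobs."""
    rng = random.Random(seed)
    points: List[Point] = []
    labels: List[int] = []
    for label, (cx, cy) in enumerate([(0.0, 0.0), (4.0, 4.0)]):
        for _ in range(n_per_class):
            points.append((rng.gauss(cx, sd), rng.gauss(cy, sd)))
            labels.append(label)
    return points, labels


if __name__ == "__main__":
    for sd in (0.5, 1.0, 2.0, 4.0, 8.0):
        print(sd, round(mdsi(*two_blobs(sd)), 3))
```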
As SD increases, the classes overlap more and more, and DSI and MDSI show the same downward trend, as expected. The two curves in (a) almost coincide, because the dimension is low and the correlation between dimensions is weak. In (b), the MDSI curve lies lower, indicating that as the number of classes increases, the effect of the correlation between dimensions begins to appear and the separability reported by MDSI is more realistic. Comparing the two experiments shows that as the feature dimension of the dataset increases, the correlation between dimensions has a greater impact on MDSI, and the difference between DSI and MDSI becomes more obvious. We consider MDSI to be a better separability metric because it takes into account the
SMART: A Robustness Evaluation Framework for Neural Networks
influence of the correlation between dimensions of the dataset and can therefore reflect the overall data separability more realistically. We made several optimizations to the computation of MDSI. The time cost of computing the IMD and BMD sets grows quadratically with the number of data points. The authors of [3] encountered a similar problem and suggested that random sampling can reduce the time cost without significantly affecting the results. However, we consider this approach inherently limited, because the result is computed from only a sample of the data. Instead, we optimize the computation of IMD, BMD, and MDSI with partitioned matrix operations. Small Batch. First, we borrow the idea of training a neural network in small batches for generating the BMD sets. Choosing an appropriate batch size avoids out-of-memory issues no matter how large the dataset grows. Partitioned Matrix Operation. By arranging the raw data points as matrices, we can use CUDA to accelerate the matrix operations and reduce the time cost. The operation is shown in Fig. 2. $\mathcal{X}_i$ is a class of dataset $\mathcal{X}$, where $\mathcal{X}_i \cap \bar{\mathcal{X}}_i = \emptyset$ and $\mathcal{X}_i \cup \bar{\mathcal{X}}_i = \mathcal{X}$. The numbers of data points in $\mathcal{X}_i$ and $\bar{\mathcal{X}}_i$ are $M$ and $N$, and the feature dimension of each point is $F$. When the covariance matrix $S$ of the dataset is not full rank, $S^{-1}$ is replaced by the pseudo-inverse of $S$. $P = P_{M \times F} = (P_1, \ldots, P_M)^{T}$ and $Q = Q_{N \times F} = (Q_1, \ldots, Q_N)^{T}$ are the two input matrices, in which each row represents a data point.
Fig. 2. The matrix operations in computation of IMD and BMD sets.
Taking any rows $P_i$ and $Q_j$ as an example, their Mahalanobis distance is

$$d_M(P_i, Q_j) = P_i S^{-1} P_i^{T} - Q_j S^{-1} P_i^{T} - P_i S^{-1} Q_j^{T} + Q_j S^{-1} Q_j^{T} \tag{5}$$
The formula above contains four matrix multiplications. Extending it to all pairs of data points yields the distance matrix $D_2 = \mathrm{BMD}_{i\bar{i}}$ of size $M \times N$, in which each element $d_{ij}$ is the distance between $P_i$ and $Q_j$. $D_2$ can still be expressed as a combination of four matrix multiplications:

$$D_2 = M_1 - (Q S^{-1} P^{T})^{T} - P S^{-1} Q^{T} + M_2 \tag{6}$$
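Equation (6) can be checked numerically with a small pure-Python sketch. `bmd_matrix` (a hypothetical name) builds $D_2$ from the diagonal of $P S^{-1} P^T$ replicated along rows ($M_1$), the diagonal of $Q S^{-1} Q^T$ replicated along columns ($M_2$), and the two cross products, processing P in partitions of `batch` rows. Partitioning only P is a simplification of ours, and plain lists stand in for the CUDA matrices of the paper; `direct` is a pair-by-pair reference used only for comparison. Calling `bmd_matrix(P, P, s_inv)` gives $D_1$, whose strictly lower triangle is the IMD set.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def matmul(a: Sequence[Sequence[float]], b: Sequence[Sequence[float]]) -> Matrix:
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(a: Sequence[Sequence[float]]) -> Matrix:
    return [list(col) for col in zip(*a)]


def bmd_matrix(P: Matrix, Q: Matrix, s_inv: Matrix, batch: int = 2) -> Matrix:
    """D2 of Eq. (6), computed partition by partition over the rows of P."""
    qs = matmul(Q, s_inv)                                        # Q S^{-1}, N x F
    q_diag = [sum(x * y for x, y in zip(qs[j], Q[j])) for j in range(len(Q))]  # diag(Q S^{-1} Q^T)
    q_t = transpose(Q)
    rows: Matrix = []
    for start in range(0, len(P), batch):
        pb = P[start:start + batch]                              # one partition of P
        ps = matmul(pb, s_inv)                                   # P_b S^{-1}
        p_diag = [sum(x * y for x, y in zip(ps[i], pb[i])) for i in range(len(pb))]
        term2 = transpose(matmul(qs, transpose(pb)))             # (Q S^{-1} P_b^T)^T
        term3 = matmul(ps, q_t)                                  # P_b S^{-1} Q^T
        for i in range(len(pb)):
            # M1 - (Q S^{-1} P^T)^T - P S^{-1} Q^T + M2, restricted to this partition
            rows.append([p_diag[i] - term2[i][j] - term3[i][j] + q_diag[j]
                         for j in range(len(Q))])
    return rows


def direct(P: Matrix, Q: Matrix, s_inv: Matrix) -> Matrix:
    """Reference: evaluate (P_i - Q_j) S^{-1} (P_i - Q_j)^T one pair at a time."""
    out = []
    for p in P:
        row = []
        for q in Q:
            d = [a - b for a, b in zip(p, q)]
            row.append(sum(d[a] * s_inv[a][b] * d[b]
                           for a in range(len(d)) for b in range(len(d))))
        out.append(row)
    return out
```

The batch size changes only how many rows are held in memory at once, not the result, which is the property the Small Batch step relies on.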
In Eq. (6), $M_1$ and $M_2$ are two $M \times N$ matrices. First, take the diagonal elements of the $M \times M$ matrix $P S^{-1} P^{T}$ to obtain an $M \times 1$ vector, then replicate it into the $M \times N$ matrix $M_1$. Similarly, replicate the $N \times 1$ vector consisting of the diagonal elements of $Q S^{-1} Q^{T}$ into an $N \times M$ matrix, which is the transpose of $M_2$. All elements of $D_2$ form the $\mathrm{BMD}_{i\bar{i}}$ set. The distance matrix $D_1 = \mathrm{IMD}_i = M_3 + M_3^{T} - 2 \cdot P S^{-1} P^{T}$ is symmetric, where the $M \times M$ matrix $M_3$ is obtained from the diagonal of $P S^{-1} P^{T}$ in the same way as $M_1$; only the elements of its strictly lower triangular part are needed to form the $\mathrm{IMD}_i$ set. When the input matrix P or Q is too large, we combine small batches with matrix operations (i.e., partitioned matrix operations); the operations above apply unchanged to partitioned matrices. We verified our optimization on Google Colab, and the results show a significant improvement: computing MDSI on MNIST is almost 200 times faster, with the computation time reduced from 3387.79 s to 17.76 s. With the partition size set to 5000, the computation time is 17.88 s, almost the same. These results indicate that matrix operations are far more efficient and that small batches solve the out-of-memory problem without significantly affecting performance. For convenience, the default partition size is set to 10000, and partitioned matrix operations are applied when the number of samples exceeds it.

3.3 The Robustness Evaluation Framework SMART
In this section, we combine MDSI with neural network models. We evaluate a model's robustness by measuring the separability difference between a dataset under its correct labels and the same dataset under the model's predicted labels. Figure 3 shows the evaluation process of our framework SMART. We combine the standard and adversarial test sets into a new dataset. The score $\mathrm{MDSI}_0$ of the new dataset under the correct labels serves as the separability reference, and the score $\mathrm{MDSI}_1$ of the new dataset under the model-predicted labels serves as the separability measurement. These two scores are used to calculate the final SMART score, which represents the model's robustness.
Fig. 3. The robustness evaluation framework SMART.
Attack DB. A flexible and critical component of our framework is the Attack Database (Attack DB). The idea is to put some typical attacks in the DB and
mix the generated adversarial examples with clean examples in appropriate proportions. In practice, FGSM [4], BIM [5], PGD [2], DeepFool [7], and C&W [8] are included in the Attack DB. The proportion K of adversarial examples generated by the different attacks in the attack database A is essentially an empirical parameter that researchers using SMART can tune. We think attacks with high time complexity should generate fewer samples. In practice, we use the time T that each attack in A takes to generate the same number of adversarial examples on the same dataset as a reference for its time complexity, and use it to determine the proportion K. Since different $a_i, a_j \in A$ may differ widely in $t_i, t_j \in T$, the relationship between $k_i, k_j \in K$ is set by $k_i / k_j = \log t_j / \log t_i$. SMART Formula. We aim for a formula that uses the difference between the separability reference $\mathrm{MDSI}_0$ and the measurement $\mathrm{MDSI}_1$ to reflect robustness. Based on observation, we preset the following three formulas:

$$y_1 = 2 - \frac{\mathrm{MDSI}_1}{\mathrm{MDSI}_0}, \qquad y_2 = \tanh(y_1), \qquad y_3 = \mathrm{Sigmoid}(y_1) = \frac{1}{1 + e^{-y_1}} \tag{7}$$
Intuitively, a higher SMART score represents a more robust model. We experimented with these formulas using the configuration in Sect. 4.1. Their curves follow the expected trend, suggesting that they are valid indicators of robustness. Among them, $y_3$ is more sensitive to changes and its output is normalized to (0, 1). We therefore define the final SMART score as $y = y_3 = \mathrm{Sigmoid}(2 - \mathrm{MDSI}_1 / \mathrm{MDSI}_0)$ and present the results in Sect. 4.1. Algorithm 1 summarizes the process of calculating the SMART score.
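The two scalar rules of SMART, the attack proportions and Eq. (7), can be written as a short Python sketch. `attack_proportions` turns the rule $k_i / k_j = \log t_j / \log t_i$ into $k_i \propto 1 / \log t_i$; rescaling so that the fastest attack gets `k_max` is our own assumption, as is the requirement that every time exceed 1 (otherwise the logarithm is not positive). The timings in the demo are made-up placeholders, and the function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def attack_proportions(times: Sequence[float], k_max: float = 1.0) -> List[float]:
    """k_i proportional to 1 / log t_i, rescaled so that max(k) == k_max (assumed)."""
    if any(t <= 1 for t in times):
        raise ValueError("times must exceed 1 so that log(t) is positive")
    inv_logs = [1.0 / math.log(t) for t in times]
    scale = k_max / max(inv_logs)
    return [scale * v for v in inv_logs]


def smart_formulas(mdsi0: float, mdsi1: float) -> Tuple[float, float, float]:
    """The three candidate formulas y1, y2, y3 of Eq. (7)."""
    y1 = 2.0 - mdsi1 / mdsi0
    return y1, math.tanh(y1), 1.0 / (1.0 + math.exp(-y1))


def smart_score(mdsi0: float, mdsi1: float) -> float:
    """Final SMART score y = Sigmoid(2 - MDSI1 / MDSI0)."""
    return smart_formulas(mdsi0, mdsi1)[2]


if __name__ == "__main__":
    # Hypothetical generation times in seconds; not measurements from the paper.
    times = {"FGSM": 3.0, "BIM": 30.0, "PGD": 60.0, "DeepFool": 120.0, "C&W": 900.0}
    ks = attack_proportions(list(times.values()))
    print({name: round(k, 3) for name, k in zip(times, ks)})
    print(smart_score(0.40, 0.40))  # equal MDSIs give Sigmoid(1)
```

Since MDSI scores are non-negative, the ratio $\mathrm{MDSI}_1/\mathrm{MDSI}_0$ is at least 0, so the score is bounded above by Sigmoid(2) ≈ 0.881, and equal scores give Sigmoid(1) ≈ 0.731.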
Algorithm 1: SMART score
Input: dataset X and its correct labels C, model f, attack database A and proportions K, total number of attacks n, SMART formula y.
Output: SMART score ρ.
ρ = 0;
calculate the model's predicted labels Y = f(X) on the clean dataset X;
for i ← 1 to n do
  use a_i ∈ A to generate the corresponding adversarial examples X_i = a_i(f, X);
  /* |·| denotes the number of elements in a set */
  use k_i ∈ K to randomly choose X′_i ⊆ X_i satisfying |X′_i| = k_i · |X_i|;
  calculate the predictions Y_i = f(X′_i) and the correct labels C_i on X′_i;
  combine X and X′_i and compute their separability MDSI_0 under labels C, C_i;
  combine X and X′_i and compute their separability MDSI_1 under labels Y, Y_i;
  ρ = ρ + y(MDSI_0, MDSI_1);
end
return ρ = ρ / n
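Algorithm 1 can be transcribed into a short Python sketch. The model, the attacks, and the separability index are passed in as callables; we assume that `attack(f, X)` returns one adversarial example per clean input in the same order, so the correct labels $C_i$ can be inherited from the clean source points, and that every $k_i \le 1$. `smart` and `sigmoid_formula` are hypothetical names, and any MDSI implementation can be supplied as `mdsi`.

```python
import math
import random
from typing import Any, Callable, List, Sequence

Model = Callable[[Any], int]
Attack = Callable[[Model, List[Any]], List[Any]]
Separability = Callable[[List[Any], List[int]], float]


def sigmoid_formula(mdsi0: float, mdsi1: float) -> float:
    """SMART formula y = Sigmoid(2 - MDSI1 / MDSI0)."""
    return 1.0 / (1.0 + math.exp(-(2.0 - mdsi1 / mdsi0)))


def smart(X: List[Any], C: List[int], f: Model, attacks: Sequence[Attack],
          K: Sequence[float], mdsi: Separability,
          y: Callable[[float, float], float] = sigmoid_formula, seed: int = 0) -> float:
    rng = random.Random(seed)
    Y = [f(x) for x in X]                      # predictions on the clean set
    rho = 0.0
    for attack, k in zip(attacks, K):
        X_adv = attack(f, X)                   # X_i = a_i(f, X), aligned with X
        chosen = rng.sample(range(len(X_adv)), int(k * len(X_adv)))  # |X'_i| = k_i |X_i|
        X_sel = [X_adv[j] for j in chosen]
        C_sel = [C[j] for j in chosen]         # correct labels of the source points
        Y_sel = [f(x) for x in X_sel]          # model predictions on X'_i
        mdsi0 = mdsi(X + X_sel, C + C_sel)     # separability under correct labels
        mdsi1 = mdsi(X + X_sel, Y + Y_sel)     # separability under predicted labels
        rho += y(mdsi0, mdsi1)
    return rho / len(attacks)
```

Passing the attacks and the index as arguments mirrors the pluggable Attack DB: adding an attack only means appending a callable and its proportion.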
4 Experiments
In this section, we conduct experiments to demonstrate the sensitivity and validity of SMART (Sect. 4.1) and compare SMART with existing robustness evaluation metrics (Sect. 4.2). For evaluation purposes, we implemented Algorithm 1 as a proof-of-concept tool written in Python 3.8 with the PyTorch framework. All experiments in this section were run in the Google Colab environment.

4.1 The Validity of SMART
The upper limit of perturbation ε is varied from 0.1 to 1.0 in increments of Δ = 0.1. We compute the SMART scores of an AlexNet pre-trained on MNIST (natural accuracy 99.19%) under different ε and present the results in Table 1.

Table 1. The SMART scores of a pre-trained AlexNet under different ε

| ε | 0.1 | 0.2 | 0.3 | 0.4 | 0.5 | 0.6 | 0.7 | 0.8 | 0.9 | 1.0 |
|---|---|---|---|---|---|---|---|---|---|---|
| SMART | 0.721 | 0.626 | 0.550 | 0.527 | 0.494 | 0.469 | 0.464 | 0.459 | 0.457 | 0.433 |
A larger ε means that larger perturbations may occur, so the probability of misclassification increases accordingly; the same model therefore appears less robust as the perturbation grows. As shown in Table 1, the SMART score of the AlexNet decreases as ε increases, which supports the validity of SMART. For a fixed ε, a more robust model yields more similar $\mathrm{MDSI}_0$ and $\mathrm{MDSI}_1$ and a SMART score closer to 1, whereas a less robust model yields a SMART score farther from 1. SMART is more sensitive when ε ≤ 0.5: the score drops by 0.227 between ε = 0.1 and ε = 0.5, but only by 0.061 between ε = 0.5 and ε = 1.0.

4.2 SMART and Mainstream Robustness Metrics
We further experiment on the MNIST and CIFAR-10 (CIFAR for short) datasets, comparing SMART with mainstream robustness metrics, including Natural Accuracy (NA), Empirical Radius (ER), and the CLEVER score mentioned in Sect. 2.2.

Table 2. Robustness evaluation results on MNIST.

| Model | NA(%) Std | NA(%) Adv | ER(×10⁻⁴) Std | ER(×10⁻⁴) Adv | CLEVER Std | CLEVER Adv | SMART Std | SMART Adv |
|---|---|---|---|---|---|---|---|---|
| LeNet-5 | 99.22 | 99.00 | 1.202 | 1.467 | 0.185 | 0.237 | 0.583 | 0.728 |
| AlexNet | 99.29 | 98.85 | 1.192 | 1.232 | 0.319 | 0.362 | 0.689 | 0.726 |
ER is the upper bound of the minimal perturbation found by attacks under the ℓ2 norm. CLEVER is configured as in its original paper [13], with the number of batches N_b = 500, the number of samples per batch N_s = 1024, the maximum perturbation R = 2 under the ℓ2 norm, and 100 test-set images for both CIFAR and MNIST. On MNIST, we evaluate these metrics on the relatively small models LeNet-5 and AlexNet. Each model has a standard-trained version (Std) and an adversarially trained version (Adv). The Adv models are strengthened via PGD [2] adversarial training with ε = 0.3, α = 0.1, 40 iterations, and random initialization. The evaluation results are shown in Table 2. The change in natural accuracy shows that adversarial training slightly reduces the generalization ability of the models. ER and CLEVER both assign higher scores to the Adv models, indicating that adversarial training indeed makes the models more robust. The SMART scores in the Std and Adv columns likewise show that the robustness of both models improves after PGD adversarial training.

Table 3. Robustness evaluation results on CIFAR-10.

| Model | NA(%) Std | NA(%) Adv | ER Std | ER Adv | CLEVER Std | CLEVER Adv | SMART Std | SMART Adv |
|---|---|---|---|---|---|---|---|---|
| LeNet-5 | 63.82 | 62.58 | 0.07255 | 0.07279 | 0.0726 | 0.0875 | 0.2560 | 0.2610 |
| ResNet-18 | 78.49 | 72.14 | 0.07278 | 0.07283 | 0.0181 | 0.0466 | 0.2558 | 0.2830 |
| SqueezeNet | 79.8 | 77.25 | 0.07273 | 0.07294 | 0.013 | 0.0603 | 0.2429 | 0.2653 |
| VGG-16 | 81.93 | 74.51 | 0.07266 | 0.07298 | 0.0118 | 0.2271 | 0.2533 | 0.2746 |
| AlexNet | 82.94 | 77.99 | 0.07277 | 0.07281 | 0.0718 | 0.1682 | 0.2552 | 0.2602 |
| DenseNet-121 | 89.42 | 70.87 | 0.07285 | 0.07311 | 0.0409 | 0.1837 | 0.2507 | 0.2855 |
On CIFAR-10, we additionally evaluate four other models, VGG-16, DenseNet-121, ResNet-18, and SqueezeNet; the results are shown in Table 3. Both ER and CLEVER show that adversarial training improves the robustness of the models, although the changes in ER are very slight. The SMART scores of the Adv models are higher than those of the Std models, reflecting the change in robustness of a single model under different training methods. Comparing SMART scores across models under the same training method, the last two columns show that after adversarial training, DenseNet and ResNet are the most robust. These experiments indicate that SMART is a reliable robustness evaluation framework whose results agree well with mainstream robustness metrics such as ER and CLEVER on various models. We now discuss the advantages of SMART over these attack-based and certification metrics. SMART and Attacks. In theory, adversarial attacks that search for anti-robust perturbations around data points can only reach local minima. In this sense, the attack-based method ER can only achieve
partial guarantees. Thus, a holistic robustness evaluation method is needed that reflects the robustness distribution over the input space more comprehensively. Whereas adversarial attacks seek an upper bound on locally minimal perturbations, SMART exploits all the anti-robust perturbations found by the tools in the Attack DB and reflects the overall robustness of neural networks. SMART and Certifications. CLEVER is an attack-agnostic robustness metric that estimates a lower bound of the minimal perturbation; it casts robustness evaluation as a local Lipschitz constant estimation problem and solves it with extreme value theory. While certification methods such as CLEVER and randomized smoothing can provide lower-bound guarantees, SMART measures the overall robustness explored by the Attack DB. Its evaluation results for a single model are as effective as those of mainstream robustness metrics, and it also reflects robustness differences between models well. Moreover, many current and future adversarial methods can be plugged into our attack library, the Attack DB, following the evaluation process in Fig. 3, and the proposed separability index MDSI allows all the generated adversarial data to be integrated into a single score. SMART can therefore serve as an open, pluggable framework for robustness evaluation.
5 Conclusion
In this paper, we propose SMART, a novel robustness evaluation framework for NN models. The main advantages of SMART over mainstream robustness evaluation methods are: (i) we develop a data separability index, MDSI, which allows SMART to evaluate robustness more stably and suitably from the perspective of overall dataset separability; (ii) we use partitioned matrix operations to significantly reduce the computation time of SMART and resolve the out-of-memory issue; (iii) the Attack DB in SMART is open to a wide variety of adversarial methods, which makes the framework extensible. The applicability of SMART has been verified by extensive experiments on the MNIST and CIFAR-10 datasets and on the models LeNet-5, AlexNet, ResNet-18, SqueezeNet, VGG-16, and DenseNet-121. The results show that SMART scores match and outperform mainstream robustness metrics when evaluating both natural and defended models. We plan to extend our work to ImageNet in the future.
References
1. Szegedy, C., Zaremba, W., Sutskever, I., et al.: Intriguing properties of neural networks. arXiv preprint arXiv:1312.6199 (2013)
2. Madry, A., Makelov, A., Schmidt, L., Tsipras, D., Vladu, A.: Towards deep learning models resistant to adversarial attacks. arXiv preprint arXiv:1706.06083 (2017)
3. Guan, S., Loew, M., Ko, H.: Data separability for neural network classifiers and the development of a separability index. arXiv preprint arXiv:2005.13120 (2020)
4. Goodfellow, I.J., Shlens, J., Szegedy, C.: Explaining and harnessing adversarial examples. arXiv preprint arXiv:1412.6572 (2014)
5. Kurakin, A., Goodfellow, I., Bengio, S.: Adversarial examples in the physical world. arXiv preprint arXiv:1607.02533 (2016)
6. Papernot, N., McDaniel, P., Jha, S., Fredrikson, M., Celik, Z.B., Swami, A.: The limitations of deep learning in adversarial settings. In: 2016 IEEE European Symposium on Security and Privacy (EuroS&P), pp. 372–387. IEEE (2016)
7. Moosavi-Dezfooli, S.M., Fawzi, A., Frossard, P.: DeepFool: a simple and accurate method to fool deep neural networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2574–2582 (2016)
8. Carlini, N., Wagner, D.: Towards evaluating the robustness of neural networks. In: 2017 IEEE Symposium on Security and Privacy (SP), pp. 39–57. IEEE (2017)
9. Bai, T., Luo, J., Zhao, J.: Recent advances in understanding adversarial robustness of deep neural networks. arXiv preprint arXiv:2011.01539 (2020)
10. Bastani, O., Ioannou, Y., Lampropoulos, L., et al.: Measuring neural net robustness with constraints. arXiv preprint arXiv:1605.07262 (2016)
11. Huang, X., Kroening, D., Ruan, W., et al.: A survey of safety and trustworthiness of deep neural networks: verification, testing, adversarial attack and defence, and interpretability. Comput. Sci. Rev. 37, 100270 (2020)
12. Levy, N., Katz, G.: ROMA: a method for neural network robustness measurement and assessment. arXiv preprint arXiv:2110.11088 (2021)
13. Weng, T.W., Zhang, H., et al.: Evaluating the robustness of neural networks: an extreme value theory approach. arXiv preprint arXiv:1801.10578 (2018)
14. Goodfellow, I.: Gradient masking causes CLEVER to overestimate adversarial perturbation size. arXiv preprint arXiv:1804.07870 (2018)
15. Weng, T.W., Zhang, H., Chen, P.Y., et al.: On extensions of CLEVER: a neural network robustness evaluation algorithm. In: 2018 IEEE Global Conference on Signal and Information Processing (GlobalSIP), pp. 1159–1163. IEEE (2018)
16. McLachlan, G.J.: Mahalanobis distance. Resonance 4(6), 20–26 (1999)
17. Ghorbani, H.: Mahalanobis distance and its application for detecting multivariate outliers. Facta Univ. Ser. Math. Inf. 34(3), 583–595 (2019)
18. Haldar, N.A.H., Khan, F.A., Ali, A., Abbas, H.: Arrhythmia classification using Mahalanobis distance based improved fuzzy c-means clustering for mobile health monitoring systems. Neurocomputing 220, 221–235 (2017)
19. Zhang, Y., Du, B., Zhang, L., et al.: A low-rank and sparse matrix decomposition-based Mahalanobis distance method for hyperspectral anomaly detection. IEEE Trans. Geosci. Remote Sens. 54(3), 1376–1389 (2015)
Time-aware Quaternion Convolutional Network for Temporal Knowledge Graph Reasoning
Chong Mo1, Ye Wang1,2(B), Yan Jia1, and Cui Luo2
1 School of Computer Science and Technology, Harbin Institute of Technology (Shenzhen), Shenzhen, China {mochong,wangye2020,jiaya2020}@hit.edu.cn
2 Peng Cheng Laboratory, Shenzhen, China
Abstract. Temporal knowledge graphs (TKGs) have been applied in many fields, and reasoning over TKGs to predict future facts is an important task. Recent methods based on Graph Convolution Networks (GCNs) represent entities and relations in Euclidean space. However, Euclidean vectors cannot accurately distinguish entities in similar facts, so it is necessary to represent entities and relations in a hypercomplex space. We propose the Time-aware Quaternion Graph Convolution Network (T-QGCN) based on quaternion vectors, which represents entities and relations more efficiently in quaternion space to distinguish entities in similar facts. T-QGCN also adds a time-aware part to capture the influence of the occurrence frequency of historical facts during reasoning. Specifically, T-QGCN uses QGCN with the frequency of each historical fact to aggregate graph structural information for each timestamp in TKGs, and uses an RNN to dynamically update entity and relation representations. To decode in quaternion space and make better use of historical representations, we design a new decoding module based on a Convolutional Neural Network (CNN) that helps T-QGCN perform better. Extensive experiments show that T-QGCN outperforms baselines on the entity prediction task on four datasets.
Keywords: Temporal Knowledge Graph · Knowledge Reasoning · Graph Convolutional Network · Quaternion

1 Introduction
Knowledge Graphs (KGs) have been widely used in many fields, such as recommendation systems [1], question answering systems [2], and crisis warning [3]. Traditional knowledge graphs can be regarded as multi-relational graphs without time information. However, facts do not always hold and may change over time, so it is necessary to attach a temporal constraint to each fact, which yields temporal knowledge graphs (TKGs). Each fact in a TKG is represented as a quadruple (subject entity, relation, object entity, timestamp), written (s, r, o, t) for short, where the timestamp indicates when the fact occurs.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 300–312, 2023.
https://doi.org/10.1007/978-981-99-1639-9_25

Since the real world contains an unbounded number of facts, KGs are inevitably incomplete, which makes KG completion an important task. Static KG completion predicts the missing subject or object entity of a triple. Completion in TKGs can be viewed as reasoning and is divided into two tasks: interpolation and extrapolation [3]. Like static KG completion, TKG reasoning predicts the missing entity in a quadruple (s, r, ?, t) or (?, r, o, t) for a temporal knowledge graph spanning $t_0$ to $t_T$. Interpolation predicts facts within $[t_0, t_T]$, whereas extrapolation predicts future facts after $t_T$. Extrapolation, which orders all facts chronologically to predict unknown future facts, is more challenging and applies to many fields, such as crisis warning and forecasting the sources of COVID-19 infections. This paper focuses on the extrapolation task in TKG reasoning.

Several researchers have studied the extrapolation setting of TKG reasoning. CyGNet [4] uses linear models to learn temporal features in TKGs, while others, including RE-Net [3] and RE-GCN [5], use Graph Convolution Networks (GCNs) to capture the interaction of concurrent facts at the same timestamp. These works study TKG reasoning in Euclidean space and have made progress, but they suffer from two problems. First, they poorly distinguish similar facts that share the same entities and relations at the same timestamp. For instance, when the facts (Saeb Erekat, make statement, Israel) and (Saeb Erekat, make statement, Knesset) occur at the same timestamp, Euclidean vectors cannot accurately distinguish the object entities Israel and Knesset, because the shared subject entity and relation make their Euclidean vectors very close. Second, existing methods ignore the influence of the frequency of historical facts when reasoning.
For example, the fact (Saeb Erekat, make statement, Israel) has occurred many times, whereas the fact (Saeb Erekat, make statement, Knesset) occurred only once in history; the former can be regarded as a continuous fact and the latter as an accidental fact, and continuous facts are more helpful for reasoning about future facts. Therefore, a TKG reasoning method should be able to distinguish similar facts and account for the frequency of historical facts. Quaternion vectors have been shown to yield more expressive representations than Euclidean vectors [6]. Inspired by this, we propose the Time-aware Quaternion Graph Convolutional Network (T-QGCN), which represents entities and relations in quaternion space and recognizes the frequency of historical facts. In addition, we observe that, for a temporal knowledge graph spanning $t_0$ to $t_T$, existing reasoning models use only the entity representation at timestamp $t_T$ to predict future facts, so repeatedly updating the entity representation from $t_0$ to $t_T$ may lose part of the historical features. To avoid this problem, we design a new decoding method that makes better use of historical entity representations during reasoning. This paper makes the following contributions:
(1) We propose a new model, T-QGCN, with time attention for temporal reasoning in TKGs, which represents entities and relations as quaternion vectors and recognizes the frequency of historical facts.
(2) We design a new decoding module that uses more historical representations to avoid feature loss during reasoning.
(3) Extensive experiments demonstrate that T-QGCN outperforms existing methods on the entity prediction task on four datasets.
2 Related Works
Static KG Methods. TransE [7] is a classic translational model whose basic idea is to make the sum of the subject embedding and the relation embedding as close as possible to the object embedding in a low-dimensional vector space. TransH [8] and TransR [9] extend TransE by introducing a hyperplane and a separate relation space, respectively, to handle one-to-many, many-to-one, and many-to-many relations. RESCAL [10] is a bilinear model based on 3-dimensional tensor decomposition. DistMult [11] restricts the relation matrix to a diagonal matrix, demonstrating the general applicability of bilinear models. RotatE [12] regards the object entity as a rotation of the subject entity in complex space. QuatE [13] introduces quaternion space to rotate entities and relations. ConvE [14] uses 2D convolution to capture the interaction between entities and relations. With the development of Graph Neural Networks (GNNs), many GNN-based static KG completion methods have been proposed. GCN [15] uses graph convolution to better capture node features. RGCN [16] adds relation-specific matrices to handle the effects of different relations on entities in the graph structure. SACN [17] proposes ConvTransE, an efficient convolutional decoder that captures features implicit in entities and relations. QGNN [6] combines quaternion space with GNNs and can be applied to tasks such as node classification. TKG Reasoning Methods. For the interpolation task, TA-DistMult [18] embeds time information into the relation to form a predicate sequence from which time features are obtained, and uses the DistMult scoring function for evaluation. TTransE [19] integrates time information into the relation embedding of each fact in TKGs. HyTE [20] introduces a hyperplane to project entities and relations with timestamps. TeRo [21] uses temporal rotation to embed entities and relations. ATiSE [22] decomposes time series to obtain time information in TKGs.
For the extrapolation task, Know-Evolve [23] and DyRep [1] introduce temporal point processes into TKG reasoning and combine them with an MLP decoder to predict future facts. CyGNet [4] proposes a copy-generation network to capture repeated facts in history, but it ignores the interaction information among entities in the TKG. RE-Net [3] uses RGCN to aggregate concurrent facts at the same timestamp and predicts future facts through the joint distribution of temporal events. RE-GCN [5] uses RGCN and an RNN to learn the evolutional representations of entities and relations at each timestamp, and proposes a static module to learn the static properties of entities for the ICEWS datasets. Compared with these works, T-QGCN models entities and relations in quaternion space, which better distinguishes similar facts when obtaining the graph structure information of each timestamp. Meanwhile, the time-aware part of T-QGCN can obtain the
frequency of each fact in history, and the new decoding module makes better use of the features updated at every timestamp in history.
3 Preliminaries
A quaternion $q \in \mathbb{H}$ can be defined as:

$$q = (q_v, q_w) = i q_x + j q_y + k q_z + q_w = q_v + q_w \tag{1}$$

where $q_v = i q_x + j q_y + k q_z$ is the imaginary part and $q_w$ is the real part; $i, j, k$ are imaginary units with $i^2 = j^2 = k^2 = -1$, $jk = -kj = i$, $ki = -ik = j$, and $ij = -ji = k$. All ordinary vector operations can be applied to the imaginary part $q_v$, for example:
Addition. The sum of two quaternions $q = i q_x + j q_y + k q_z + q_w$ and $r = i r_x + j r_y + k r_z + r_w$ is $q + r = (q_v, q_w) + (r_v, r_w) = (q_v + r_v, q_w + r_w)$.
Conjugation. The conjugate of a quaternion $q$ is $q^* = (q_v, q_w)^* = (-q_v, q_w)$.
Scalar Multiplication. The product of a scalar $\lambda$ and a quaternion $q$ is $\lambda q = \lambda (q_v, q_w) = (\lambda q_v, \lambda q_w)$.
Hamilton Product. The Hamilton product of two quaternions $q$ and $r$ is defined as:

$$\begin{aligned} q \otimes r &= (i q_x + j q_y + k q_z + q_w) \otimes (i r_x + j r_y + k r_z + r_w) \\ &= i(q_y r_z - q_z r_y + r_w q_x + q_w r_x) + j(q_z r_x - q_x r_z + r_w q_y + q_w r_y) \\ &\quad + k(q_x r_y - q_y r_x + r_w q_z + q_w r_z) + q_w r_w - q_x r_x - q_y r_y - q_z r_z \end{aligned} \tag{2}$$

Norm. The norm of a quaternion $q$ is $|q| = \sqrt{q q^*} = \sqrt{q_x^2 + q_y^2 + q_z^2 + q_w^2}$.
Inner Product. The inner product of two quaternions $q$ and $r$ is defined as $q \cdot r = q_w r_w + q_x r_x + q_y r_y + q_z r_z$.
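The operations above can be made concrete with a minimal Python sketch (the class and function names are ours). It stores the real part first and follows Eq. (2) component by component, so the unit rules $ij = k$, $ji = -k$ and the multiplicativity of the norm can be checked directly.

```python
import math
from typing import NamedTuple


class Quaternion(NamedTuple):
    """q = q_w + i q_x + j q_y + k q_z (real part stored first)."""
    w: float
    x: float
    y: float
    z: float


def hamilton(q: Quaternion, r: Quaternion) -> Quaternion:
    """Hamilton product q ⊗ r, component by component as in Eq. (2)."""
    return Quaternion(
        w=q.w * r.w - q.x * r.x - q.y * r.y - q.z * r.z,
        x=q.y * r.z - q.z * r.y + r.w * q.x + q.w * r.x,
        y=q.z * r.x - q.x * r.z + r.w * q.y + q.w * r.y,
        z=q.x * r.y - q.y * r.x + r.w * q.z + q.w * r.z,
    )


def conjugate(q: Quaternion) -> Quaternion:
    return Quaternion(q.w, -q.x, -q.y, -q.z)


def norm(q: Quaternion) -> float:
    return math.sqrt(q.w ** 2 + q.x ** 2 + q.y ** 2 + q.z ** 2)


def inner(q: Quaternion, r: Quaternion) -> float:
    return q.w * r.w + q.x * r.x + q.y * r.y + q.z * r.z


if __name__ == "__main__":
    i, j, k = Quaternion(0, 1, 0, 0), Quaternion(0, 0, 1, 0), Quaternion(0, 0, 0, 1)
    print(hamilton(i, j) == k, hamilton(j, i) == Quaternion(0, 0, 0, -1))  # True True
    q, r = Quaternion(1, 2, 3, 4), Quaternion(5, 6, 7, 8)
    print(hamilton(q, r))  # Quaternion(w=-60, x=12, y=30, z=24)
```

The product is not commutative, which is why the order of the weight and the embedding in later equations matters.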
4 Overview of T-QGCN
Consider a TKG from $t_0$ to $t_T$; it can be regarded as a sequence of knowledge graphs $\{G_{t_0}, G_{t_1}, \ldots, G_{t_T}\}$. Each snapshot $G_{t_i} = \{E, R, S_{t_i}\}$, $0 \le i \le T$, consists of the entity set $E$ with $|E|$ distinct entities, the relation set $R$ with $|R|$ distinct relations, and the fact set $S_{t_i}$, in which every fact is a quadruple $(s, r, o, t_i)$ with $s, o \in E$ and $r \in R$. We add the inverse quadruple $(o, r^{-1}, s, t_i)$ for each fact, so that a missing subject entity in $(?, r, o, t_{T+1})$ and a missing object entity in $(s, r, ?, t_{T+1})$ can both be predicted as object-entity queries. Because snapshots closer to $t_{T+1}$ are more helpful for reasoning, T-QGCN uses the most recent $m$ snapshots $\{G_{t_{T-m+1}}, \ldots, G_{t_T}\}$ to reason about future facts. T-QGCN consists of an encoding module and a decoding module. The encoding module updates the dynamic entity embedding and the dynamic relation embedding at every timestamp. The decoding module uses the historical entity embeddings $\{H_{t_{T-m+1}}, \ldots, H_{t_T}\}$ and the historical relation embeddings $\{R_{t_{T-m+1}}, \ldots, R_{t_T}\}$ to compute entity probability vectors. An overview of T-QGCN is shown in Fig. 1.
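The data preparation described above can be sketched in a few lines of Python. We assume integer ids for entities and relations and encode the inverse relation $r^{-1}$ as $r + |R|$, a common convention that the text does not specify; the function names are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Quad = Tuple[int, int, int, int]  # (s, r, o, t) with integer ids


def add_inverse(facts: List[Quad], num_relations: int) -> List[Quad]:
    """Append (o, r^-1, s, t) for each fact; r^-1 is encoded as r + |R| (assumed)."""
    return facts + [(o, r + num_relations, s, t) for (s, r, o, t) in facts]


def snapshots(facts: List[Quad]) -> Dict[int, List[Quad]]:
    """Group facts into snapshots S_t, keyed by timestamp in ascending order."""
    groups: Dict[int, List[Quad]] = defaultdict(list)
    for fact in facts:
        groups[fact[3]].append(fact)
    return dict(sorted(groups.items()))


def history_window(graphs: Dict[int, List[Quad]], m: int) -> List[List[Quad]]:
    """The m most recent snapshots, i.e. G_{t_{T-m+1}}, ..., G_{t_T}."""
    return [graphs[t] for t in sorted(graphs)[-m:]]
```

With this encoding, a subject query (?, r, o, t) is answered as the object query (o, r + |R|, ?, t), so a single object-prediction decoder serves both directions.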
Fig. 1. The process by which T-QGCN reasons about future facts at timestamp $t_{T+1}$.
4.1 Encoding Module
T-QGCN first randomly initializes the dynamic entity embedding $H_0$ and the dynamic relation embedding $R_0$ for all entities and all relations. To update the relation embedding from $R_{t-1}$ to $R_t$ ($t_{T-m+1} \le t \le t_T$), T-QGCN considers the facts in snapshot $G_t$: for each relation in $G_t$, it averages the dynamic entity embeddings $H_{t-1}$ of the entities involved in that relation to obtain the structural relation embedding $R^*_t$, and then uses a GRU to update the dynamic relation embedding:

$$R_t = \mathrm{GRU}_r(R^*_t, R_{t-1}) \tag{3}$$

To update the entity embedding from $H_{t-1}$ to $H_t$, T-QGCN transforms the embedding $h_{t-1}$ of each entity at timestamp $t-1$ and the embedding $r_t$ of each relation at timestamp $t$ into quaternion embeddings, which serve as the first-layer inputs of QGCN:

$$g(h_{t-1}) = h_{t-1,w} + i\,h_{t-1,x} + j\,h_{t-1,y} + k\,h_{t-1,z} \tag{4}$$

$$g(r_t) = r_{t,w} + i\,r_{t,x} + j\,r_{t,y} + k\,r_{t,z} \tag{5}$$
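The mean pooling that produces $R^*_t$ and the quaternion split of Eqs. (4)–(5) can be sketched in pure Python, omitting the GRU of Eq. (3). Several details are our assumptions rather than statements from the text: both subject and object entities count as "involved" in a relation, each entity is counted once per relation, relations absent from $G_t$ get a zero vector, and a $4d$-dimensional embedding is split into four consecutive blocks of size $d$ for the $w, x, y, z$ parts. The function names are hypothetical.

```python
from typing import Dict, List, Sequence, Set, Tuple

Quad = Tuple[int, int, int, int]


def structural_relation_embedding(facts: Sequence[Quad], H_prev: Sequence[Sequence[float]],
                                  num_relations: int) -> List[List[float]]:
    """R*_t: for each relation, the mean of H_{t-1} over the entities it involves in G_t."""
    dim = len(H_prev[0])
    involved: Dict[int, Set[int]] = {}
    for s, r, o, _ in facts:
        involved.setdefault(r, set()).update((s, o))   # assumption: subjects and objects
    out = []
    for r in range(num_relations):
        ents = sorted(involved.get(r, ()))
        if not ents:
            out.append([0.0] * dim)                    # assumption: zeros if r is absent
            continue
        out.append([sum(H_prev[e][d] for e in ents) / len(ents) for d in range(dim)])
    return out


def to_quaternion(v: Sequence[float]) -> Tuple[List[float], ...]:
    """Eqs. (4)-(5): split a 4d-dim vector into its (w, x, y, z) parts of size d."""
    if len(v) % 4:
        raise ValueError("embedding size must be divisible by 4")
    d = len(v) // 4
    return tuple(list(v[c * d:(c + 1) * d]) for c in range(4))
```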
Each KG sequence is a multi-relational graph and QGCN can powerfully obtain graph structure information in quaternion space, a p−layer QGCN is used to get the structural entity embedding, an object entity o at layer l + 1 can be updated from its corresponding relations and subject entities at layer l ∈ [0, p − 1], as below: W lt ⊗ (Qt (hls,t )) + (Qt (r t ))) (6) hl+1 o,t = σ(at (s,r),∃(s,r,o,t)∈St
Time-aware Quaternion Convolutional Network
where $h^l_{s,t}$, $h^{l+1}_{o,t}$, and $r_t$ are the embeddings of the subject entity at layer $l$, the object entity at layer $l+1$, and the relation, respectively. $Q_t(\cdot)$ denotes the concatenation of the quaternion components, i.e., $Q_t(h^l_{s,t}) = [h^l_{s,t,w}; h^l_{s,t,x}; h^l_{s,t,y}; h^l_{s,t,z}]$ and $Q_t(r_t) = [r_{t,w}; r_{t,x}; r_{t,y}; r_{t,z}]$, where $;$ denotes concatenation. $W^l_t = W^l_{t,w} + i\,W^l_{t,x} + j\,W^l_{t,y} + k\,W^l_{t,z}$ is the weight matrix in quaternion space, and $\otimes$ is the Hamilton product, computed as follows:

$$W^l_t \otimes \big(Q_t(h^l_{s,t}) + Q_t(r_t)\big) = \begin{bmatrix} 1 & i & j & k \end{bmatrix}
\begin{bmatrix}
W^l_{t,w} & -W^l_{t,x} & -W^l_{t,y} & -W^l_{t,z} \\
W^l_{t,x} & W^l_{t,w} & -W^l_{t,z} & W^l_{t,y} \\
W^l_{t,y} & W^l_{t,z} & W^l_{t,w} & -W^l_{t,x} \\
W^l_{t,z} & -W^l_{t,y} & W^l_{t,x} & W^l_{t,w}
\end{bmatrix}
\begin{bmatrix}
h^l_{s,t,w} + r_{t,w} \\
h^l_{s,t,x} + r_{t,x} \\
h^l_{s,t,y} + r_{t,y} \\
h^l_{s,t,z} + r_{t,z}
\end{bmatrix} \tag{7}$$
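To make Eq. (7) concrete, the sketch below treats each quaternion component as a scalar; in T-QGCN each component is a 50-dimensional block of the 200-dimensional embedding and the $W$ components are matrices, but the sign pattern is the same. It builds the 4 × 4 matrix of Eq. (7) and checks it against the component-wise Hamilton product. The function names are illustrative, not taken from the paper.

```python
def hamilton(p, q):
    """Hamilton product of quaternions p = (w, x, y, z) and q = (w, x, y, z)."""
    a, b, c, d = p
    e, f, g, h = q
    return (
        a * e - b * f - c * g - d * h,  # real part (coefficient of 1)
        a * f + b * e + c * h - d * g,  # coefficient of i
        a * g - b * h + c * e + d * f,  # coefficient of j
        a * h + b * g - c * f + d * e,  # coefficient of k
    )


def hamilton_matrix(W):
    """Real 4x4 matrix of Eq. (7): hamilton(W, v) equals hamilton_matrix(W) times v."""
    w, x, y, z = W
    return [
        [w, -x, -y, -z],
        [x, w, -z, y],
        [y, z, w, -x],
        [z, -y, x, w],
    ]


def matvec(M, v):
    return tuple(sum(m * vi for m, vi in zip(row, v)) for row in M)


def quaternion_message(W, h_s, r):
    """W ⊗ (Q(h_s) + Q(r)): add subject and relation components, then apply W."""
    v = tuple(hs + rr for hs, rr in zip(h_s, r))
    return matvec(hamilton_matrix(W), v)


W, h_s, r = (0.5, -1.0, 2.0, 0.25), (1.0, 2.0, -0.5, 3.0), (0.1, 0.0, 0.2, -1.0)
v = tuple(a + b for a, b in zip(h_s, r))
# The matrix form and the direct Hamilton product agree (up to float rounding).
assert all(abs(m - d) < 1e-12 for m, d in zip(quaternion_message(W, h_s, r), hamilton(W, v)))
```

Because every row of the 4 × 4 matrix reuses the same four weight components, a quaternion layer needs about a quarter of the parameters of a real-valued layer of the same width, which is the usual motivation for quaternion networks.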
$a_t$ is the time-aware re-normalized adjacency matrix over all related entities and relations in $G_t$, computed as $D^{-1/2}(A_t + I)D^{-1/2}$. $A_t$ is the adjacency matrix with a time norm: it records the frequency of all facts and assigns the entities involved in these facts a norm of $1/(t - t_l)$, where $t_l$ is the timestamp at which the fact last occurred. $I$ is the identity matrix, $D$ is the diagonal degree matrix, and $\sigma$ is the tanh activation function. The output of the QGCN is denoted $H^p_t$; T-QGCN then uses a GRU to update the dynamic entity embedding:
$$H_t = \mathrm{GRU}_h(H^p_t, H_{t-1}) \tag{8}$$
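The re-normalization $a_t = D^{-1/2}(A_t + I)D^{-1/2}$ used in Eq. (6) can be sketched as follows. The text does not specify exactly how fact frequency and the time norm are merged into one entry of $A_t$, so this toy example simply uses $1/(t - t_l)$ as the edge weight and takes $D$ to be the degree matrix of $A_t + I$, as in a standard GCN [15]; both choices, and the helper `time_weight`, are our assumptions.

```python
import math


def time_weight(t, t_last):
    """Hypothetical time norm 1/(t - t_l) for a fact last seen at t_last (requires t_last < t)."""
    return 1.0 / (t - t_last)


def renormalize(A):
    """Return D^{-1/2} (A + I) D^{-1/2}, where D is the degree matrix of A + I."""
    n = len(A)
    A_hat = [[A[i][j] + (1.0 if i == j else 0.0) for j in range(n)] for i in range(n)]
    inv_sqrt = [1.0 / math.sqrt(sum(row)) for row in A_hat]
    return [[inv_sqrt[i] * A_hat[i][j] * inv_sqrt[j] for j in range(n)] for i in range(n)]


# Toy snapshot at t = 5 with three entities: the 0-1 fact last occurred at t = 4,
# the 1-2 fact at t = 2, so the more recent fact receives the larger weight.
t = 5
A = [[0.0] * 3 for _ in range(3)]
A[0][1] = A[1][0] = time_weight(t, 4)  # 1.0
A[1][2] = A[2][1] = time_weight(t, 2)  # 1/3
a_t = renormalize(A)
```

Recent facts receive larger weights than facts last seen long ago, and the symmetric normalization keeps high-degree entities from dominating the aggregation in Eq. (6).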
Meanwhile, T-QGCN applies a time gate to prevent the loss of entity features during training:

$$U_t = f(W^*_t H_t + b^*_t) \tag{9}$$

$$H_t = U_t \otimes H_t + (1 - U_t) \otimes H_{t-1} \tag{10}$$

where $W^*_t$ and $b^*_t$ are the weight matrix and bias, $f$ is the sigmoid activation function, and $U_t$ is the time gate that controls the balance between the dynamic entity embedding $H_{t-1}$ at timestamp $t-1$ and the structural entity embedding $H_t$ at timestamp $t$.

4.2 Decoding Module
Previous work has shown that convolutional decoders such as ConvTransE [17] perform well in KG reasoning. We design a new decoding module based on CNNs in quaternion space [26]. For each dynamic entity embedding $H_t$, $t_{T-m+1} \le t \le t_T$, T-QGCN first applies the tanh activation and then concatenates the result with the relation embedding $R_t$ to construct a union embedding $Z_t$: $H^{\tanh}_t = \tanh(H_t)$ and $Z_t = [H^{\tanh}_t; R_t]$. T-QGCN then updates $Z_t$ with two convolution layers and a linear layer:

$$Z^1_t = \mathrm{Norm}_{2d}(Z_t) \tag{11}$$

$$Z^2_t = \mathrm{ReLU}\big(\mathrm{Norm}_{2d}(\mathrm{Conv}_{2d}(\omega^1_t, Z^1_t))\big) \tag{12}$$

$$Z^3_t = \mathrm{ReLU}\big(\mathrm{Norm}_{2d}(\mathrm{Conv}_{2d}(\omega^2_t, Z^2_t))\big) \tag{13}$$

$$Z^4_t = \mathrm{ReLU}\big(\mathrm{Linear}(\omega^3_t, Z^3_t)\big) \tag{14}$$
where $\mathrm{Norm}_{2d}(\cdot)$ is batch normalization, $\mathrm{ReLU}(\cdot)$ is the ReLU activation function, $\mathrm{Conv}_{2d}(\cdot)$ is a two-dimensional convolution layer, $\omega^1_t$ and $\omega^2_t$ are the trainable parameters of the convolution kernels, $\mathrm{Linear}(\cdot)$ is a linear layer, and $\omega^3_t$ is its trainable parameter. The probability vector at timestamp $t_{T+1}$ for each snapshot $G_t$ is then computed as:

$$p(o_{t_{T+1}} \mid s_t, r_t, G_t) = \mathrm{sigmoid}\big(Z^4_t \otimes H^{\tanh}_t\big) \tag{15}$$
where $\mathrm{sigmoid}(\cdot)$ is the sigmoid activation function and $p(o_{t_{T+1}} \mid s_t, r_t, G_t)$ is the probability vector over candidate facts at timestamp $t_{T+1}$, computed from the entity and relation embeddings at timestamp $t$. To prevent the loss of historical features during training, T-QGCN combines all probability vectors from $t_{T-m+1}$ to $t_T$; the final score probability vector is defined as:

$$p(o_{t_{T+1}} \mid s_t, r_t, [G_{t_{T-m+1}:t_T}]) = \mathrm{softmax}\big(p_{t_{T-m+1}:t_T}(o_{t_{T+1}} \mid s_t, r_t, G_t)\big) \tag{16}$$

where $\mathrm{softmax}(\cdot)$ is the softmax activation function.

4.3 Parameter Learning
Entity prediction can be treated as a multi-label learning problem. Based on the probability vector $p(o_{t_{T+1}} \mid s_t, r_t, [G_{t_{T-m+1}:t_T}])$, the loss function for entity prediction is defined as:

$$L = -\sum_{t=t_{T-m+1}}^{t_T} \sum_{i=0}^{|E|-1} \chi_{t+1,i} \log p_i\big(o_{t_{T+1}} \mid s_t, r_t, [G_{t_{T-m+1}:t_T}]\big) \tag{17}$$

where $\chi_{t+1,i}$ is the label value for entity $i$ (the index of the entity in $E$): $\chi_{t+1,i} = 1$ if the fact $(s, r, i, t+1)$ holds at timestamp $t+1$ in the training data, and $\chi_{t+1,i} = 0$ otherwise. $p_i(o_{t_{T+1}} \mid s_t, r_t, [G_{t_{T-m+1}:t_T}])$ is the probability assigned to entity $i$.
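Eqs. (16) and (17) can be sketched for a single query and a single target timestamp (the inner sum over entities in Eq. (17)). The text does not spell out how the per-timestamp vectors of Eq. (15) are merged before the softmax, so this sketch assumes an element-wise sum; that choice and the function names are ours, not the paper's.

```python
import math


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def combine_history(per_step_probs):
    """Assumed form of Eq. (16): sum per-timestamp sigmoid vectors, then apply softmax."""
    summed = [sum(col) for col in zip(*per_step_probs)]
    return softmax(summed)


def multilabel_nll(probs, labels, eps=1e-12):
    """Inner sum of Eq. (17): -sum_i chi_i * log p_i over all |E| candidate entities."""
    return -sum(chi * math.log(p + eps) for chi, p in zip(labels, probs))


# Two history steps (m = 2), three candidate entities; entity 0 is the true object.
per_step = [[0.9, 0.1, 0.2], [0.8, 0.3, 0.1]]
p = combine_history(per_step)
loss = multilabel_nll(p, [1, 0, 0])
```

Because the labels $\chi$ may mark several entities as correct, the loss sums the log-probabilities of all true entities rather than picking a single class.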
5 Experiment

5.1 Experimental Setup
Datasets. We use four datasets in the experiments: ICEWS18 [23], ICEWS14 [24], YAGO [25], and WIKI [19]. The two ICEWS datasets come from the Integrated Crisis Early Warning System; YAGO and WIKI are publicly available datasets. Facts in ICEWS18 and ICEWS14 take the form $(s, r, o, t)$, where $t$ is the timestamp of the fact. Facts in YAGO and WIKI take the form $(s, r, o, [t_s, t_e])$, where $t_s$ is the start timestamp and $t_e$ is the end timestamp. Following prior work, all datasets except ICEWS14 are split by timestamp into training, validation, and test sets in the proportion 80%, 10%, and 10%; ICEWS14 is split by timestamp into training and test sets in the proportion 50% and 50%. Table 1 summarizes the statistics of these datasets.

Table 1. Datasets.

Dataset   Entities  Relations  Training  Validation  Testing  Interval
ICEWS18   23,033    256        373,018   45,995      49,545   24 h
ICEWS14   12,498    260        373,895   –           341,409  24 h
WIKI      12,554    24         539,286   67,538      63,110   1 year
YAGO      10,623    10         161,540   19,523      20,026   1 year

Evaluation Metrics. We evaluate T-QGCN with two metrics: Mean Reciprocal Rank (MRR) and the proportion of correct test cases ranked within the top 1/3/10 (Hits@1/3/10). Because the filtered setting can yield misleadingly high ranking scores [5], all results for T-QGCN are reported in the raw setting.

Baselines. Twelve baselines for static KG reasoning and TKG reasoning are compared with T-QGCN. DistMult [11], RGCN [16], ConvE [14], and RotatE [12] are static models. HyTE [20], TTransE [19], and TA-DistMult [18] are temporal models for the interpolation task. CyGNet [4], RE-Net [3], and RE-GCN [5] are temporal models for the extrapolation task. Know-Evolve [23] and DyRep [1] are extended to the temporal reasoning task.

Experimental Settings. All experiments are run on an NVIDIA A100 80 GB PCIe GPU. Adam is used for parameter learning with a learning rate of 0.001. The dimension of the entity and relation embeddings is 200, and each of the four quaternion components has dimension 50. The history lengths for YAGO, WIKI, ICEWS14, and ICEWS18 are 1, 3, 5, and 6, respectively. The QGCN has 2 layers, with a dropout rate of 0.2 for each layer. The decoding module has 2 convolutional layers, each with a 3 × 3 kernel and a dropout rate of 0.2. All experiments are carried out under the same conditions, and T-QGCN uses only historical facts before timestamp $t_T$ to reason about future multi-step facts.

5.2 Experimental Results
Table 2 and Table 3 report all results on the entity prediction task. Table 2 shows that T-QGCN outperforms the baselines on both ICEWS datasets. Compared with static KG completion methods, T-QGCN captures temporal information in TKGs and therefore obtains better results. Compared with TKG reasoning methods designed for the interpolation setting, T-QGCN performs better because it captures the sequential structure of TKGs. T-QGCN also has an advantage over TKG reasoning methods designed for the extrapolation setting: it represents entities and relations in quaternion space to learn hidden features that distinguish similar facts, and it draws on more historical information when reasoning.

Table 2. Results of the entity prediction task on the ICEWS datasets (raw metrics).

               ICEWS18                    ICEWS14
Model          MRR    Hits@3  Hits@10     MRR    Hits@3  Hits@10
DistMult       13.86  15.22   31.26       9.72   10.09   22.53
RGCN           15.05  16.49   29.00       15.03  16.12   31.47
ConvE          22.56  25.41   41.67       21.64  23.16   38.37
RotatE         11.63  12.31   28.03       9.79   9.37    22.24
HyTE           7.41   7.33    16.01       7.72   7.94    20.16
TTransE        8.44   8.95    22.38       10.86  12.72   23.65
TA-DistMult    15.62  17.09   32.21       11.29  11.60   23.71
Know-Evolve    7.41   7.87    14.76       16.81  18.63   29.20
DyRep          7.82   7.73    16.33       17.54  19.87   30.34
CyGNet         24.98  28.58   43.54       22.77  23.54   41.62
RE-Net         26.62  30.27   45.57       23.85  14.63   42.58
RE-GCN         27.51  31.17   46.55       23.17  25.74   40.99
T-QGCN         28.12  32.24   47.33       25.84  29.12   44.33
T-QGCN w.GT    28.98  33.17   48.72       26.24  29.60   45.20

Table 3 shows that T-QGCN also outperforms the baselines on WIKI and YAGO. The time interval in these two datasets is one year, and static KG completion methods and interpolation-setting TKG methods perform poorly because the interval is too long for them to learn accurate entity representations. The long interval also makes reasoning difficult for extrapolation-setting TKG methods, but these methods can capture the sequential structure of TKGs; T-QGCN in particular uses an effective decoding module to exploit historical features during reasoning. We also report results with ground truth (GT, rows "T-QGCN w.GT"), in which the true facts are added to the historical facts after each reasoning step. These results illustrate the portability of T-QGCN: it can incorporate new facts to predict future facts. In summary, T-QGCN improves over the best baseline on WIKI by 25.4% in MRR, 24.6% in Hits@3, and 18.0% in Hits@10, and over the best baseline on ICEWS14 by 7.7% in MRR, 11.6% in Hits@3, and 3.9% in Hits@10. On the other datasets, T-QGCN also outperforms the best baselines. These results show that T-QGCN is more capable than the baselines for TKG reasoning and illustrate the effectiveness of representing the entities and relations of TKGs in quaternion space.
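The raw-setting metrics and the summary percentages above can be checked with a few lines of Python. Ranks are computed without filtering out other valid answers, matching the raw setting. The reported improvements (25.4%, 24.6%, and 18.0% on WIKI; 7.7%, 11.6%, and 3.9% on ICEWS14) are reproduced when the gap to the best baseline is divided by T-QGCN's own score rather than by the baseline's score, so the sketch shows both conventions. Function names are illustrative.

```python
def raw_rank(scores, true_idx):
    """Raw rank: 1 + number of candidates scored strictly higher than the true entity.

    Other valid answers are not filtered out (raw setting); ties favour the true entity.
    """
    return 1 + sum(1 for s in scores if s > scores[true_idx])


def mrr(ranks):
    return 100.0 * sum(1.0 / r for r in ranks) / len(ranks)


def hits_at(ranks, k):
    return 100.0 * sum(1 for r in ranks if r <= k) / len(ranks)


def gain_over_ours(ours, baseline):
    """Gap divided by T-QGCN's own score (matches the percentages reported above)."""
    return 100.0 * (ours - baseline) / ours


def gain_over_baseline(ours, baseline):
    """Conventional relative improvement: gap divided by the baseline's score."""
    return 100.0 * (ours - baseline) / baseline


ranks = [1, 2, 4, 10, 20]
print(round(mrr(ranks), 2), hits_at(ranks, 3), hits_at(ranks, 10))  # 38.0 40.0 80.0
# WIKI MRR from Table 3: T-QGCN 53.44 vs. RE-GCN 39.84.
print(round(gain_over_ours(53.44, 39.84), 1), round(gain_over_baseline(53.44, 39.84), 1))  # 25.4 34.1
```

Under the conventional definition, the WIKI MRR gain over RE-GCN would be about 34.1% rather than 25.4%; the ranking of the models is the same either way.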
Table 3. Results of the entity prediction task on WIKI and YAGO (raw metrics).

               WIKI                       YAGO
Model          MRR    Hits@3  Hits@10     MRR    Hits@3  Hits@10
DistMult       27.96  32.45   39.51       44.05  49.70   59.94
RGCN           13.96  15.75   22.05       20.25  24.01   37.30
ConvE          26.41  30.36   39.41       41.31  47.10   59.67
RotatE         26.08  31.63   38.51       42.08  46.77   59.39
HyTE           25.40  29.16   37.54       14.42  39.73   46.98
TTransE        20.66  23.88   33.04       26.10  36.28   47.73
TA-DistMult    26.44  31.36   38.97       44.98  50.64   61.11
Know-Evolve    10.54  13.08   20.21       5.23   5.63    10.23
DyRep          10.41  12.06   20.93       4.98   5.54    10.19
CyGNet         30.77  33.83   41.19       46.72  52.48   61.52
RE-Net         30.87  33.55   41.27       46.81  52.71   61.93
RE-GCN         39.84  44.43   53.88       58.27  65.62   75.94
T-QGCN         53.44  58.93   65.67       59.17  69.08   77.26
T-QGCN w.GT    62.15  69.38   78.04       61.83  72.17   81.89

5.3 Ablation Studies
Euclidean Vectors. To test whether quaternion vectors improve TKG reasoning, we represent entities and relations as Euclidean vectors under otherwise identical conditions. These results are denoted -Euclid in Table 4 and Table 5. Performance decreases on all four datasets, especially on the ICEWS datasets, which contain many similar historical facts, as illustrated in the introduction. This also indicates that quaternion vectors provide more expressive representations than Euclidean vectors.

Frequency of Historical Facts. To test whether T-QGCN can distinguish accidental facts from recurring facts, we ignore how many times each fact has appeared in history. Specifically, we replace the time-aware re-normalized adjacency matrix $a_t$ of each snapshot with an adjacency matrix that carries no time information. These results are denoted -TimeFre in Table 4 and Table 5. The performance of T-QGCN decreases on all datasets after the time information is removed. Inspecting the datasets, we find that many facts in YAGO and WIKI appear at every timestamp in history, whereas most historical facts in ICEWS18 appear at only a few timestamps. This explains why the performance of T-QGCN on ICEWS18 declines more after the time information is removed.

Decoding with One History Embedding. To test whether the history embeddings contribute to reasoning, we use only the embeddings at timestamp $t_T$ of a TKG spanning $t_0$ to $t_T$ to reason about future facts. These results are denoted -DecOne in Table 4 and Table 5. Results become worse on every dataset except YAGO, because the history length for YAGO is 1, so the historical features have little impact on it (the -DecOne scores on YAGO are identical to those of the full model). This also illustrates the importance of historical information in reasoning.

Table 4. Results of the ablation studies on the ICEWS datasets (raw metrics).

               ICEWS18                    ICEWS14
Model          MRR    Hits@3  Hits@10     MRR    Hits@3  Hits@10
T-QGCN         28.12  32.24   47.33       25.84  29.12   44.33
-Euclid        27.48  31.30   45.76       25.20  28.20   42.94
-TimeFre       27.43  31.23   45.46       25.46  28.44   43.48
-DecOne        26.48  29.83   44.06       24.60  27.44   41.94
Table 5. Results of the ablation studies on WIKI and YAGO (raw metrics).

               WIKI                       YAGO
Model          MRR    Hits@3  Hits@10     MRR    Hits@3  Hits@10
T-QGCN         53.44  58.93   65.67       59.17  69.08   77.26
-Euclid        52.64  58.29   64.96       55.64  66.81   76.62
-TimeFre       52.62  58.18   64.91       58.43  68.41   76.90
-DecOne        50.41  56.67   64.11       59.17  69.08   77.26

6 Conclusion
This paper proposes T-QGCN for the extrapolation setting of TKG reasoning; it represents entities and relations in quaternion space to distinguish similar historical facts. T-QGCN also adds time information to each snapshot of a TKG to capture the frequency of historical facts. In addition, T-QGCN uses a new decoding module that exploits more historical embeddings to avoid feature loss during training. Experimental results on four datasets demonstrate that T-QGCN outperforms the baselines on the entity prediction task, which indicates that T-QGCN is more capable than the baselines in TKG reasoning.

Acknowledgement. This work is supported in part by the Major Key Project of PCL (Grant No. PCL2022A03).
References
1. Trivedi, R., Farajtabar, M., Biswal, P., Zha, H.: DyRep: learning representations over dynamic graphs. In: ICLR (2019)
2. Wang, X., Li, X., Zhu, J., et al.: A local similarity-preserving framework for nonlinear dimensionality reduction with neural networks. In: DASFAA (2021)
3. Jin, W., Qu, M., Jin, X., Ren, X.: Recurrent event network: autoregressive structure inference over temporal knowledge graphs. In: EMNLP (2020)
4. Zhu, C., Chen, M., Fan, C., Cheng, G., Zhan, Y.: Learning from history: modeling temporal knowledge graphs with sequential copy-generation networks. In: ICLR (2021)
5. Li, Z., et al.: Temporal knowledge graph reasoning based on evolutional representation learning. In: SIGIR (2021)
6. Nguyen, D.Q., Nguyen, T.D., Phung, D.: Quaternion graph neural networks. In: ACML (2021)
7. Bordes, A., Usunier, N., Garcia-Duran, A., Weston, J., Yakhnenko, O.: Translating embeddings for modeling multi-relational data. In: Advances in Neural Information Processing Systems (2013)
8. Wang, Z., Zhang, J., Feng, J., Chen, Z.: Knowledge graph embedding by translating on hyperplanes. In: AAAI (2014)
9. Lin, Y., Liu, Z., Sun, M., Liu, Y., Zhu, X.: Learning entity and relation embeddings for knowledge graph completion. In: AAAI (2015)
10. Nickel, M., Tresp, V., Kriegel, H.P.: A three-way model for collective learning on multi-relational data. In: ICML (2011)
11. Yang, B., Yih, W., He, X., Gao, J., Deng, L.: Embedding entities and relations for learning and inference in knowledge bases. In: ICLR (2015)
12. Sun, Z., Deng, Z.H., Nie, J.Y., Tang, J.: RotatE: knowledge graph embedding by relational rotation in complex space. In: ICLR (2019)
13. Zhang, S., Tay, Y., Yao, L., Liu, Q.: Quaternion knowledge graph embedding. In: Advances in Neural Information Processing Systems (2019)
14. Dettmers, T., Minervini, P., Stenetorp, P., Riedel, S.: Convolutional 2D knowledge graph embeddings. In: AAAI (2018)
15. Kipf, T., Welling, M.: Semi-supervised classification with graph convolutional networks. In: Advances in Neural Information Processing Systems (2016)
16. Schlichtkrull, M., Kipf, T., Bloem, P., Berg, R., Titov, I., Welling, M.: Modeling relational data with graph convolutional networks. In: European Semantic Web Conference (2018)
17. Shang, C., Tang, Y., Huang, J., Bi, J., He, X., Zhou, B.: End-to-end structure-aware convolutional networks for knowledge base completion. In: AAAI (2019)
18. Garcia-Duran, A., Dumancic, S., Niepert, M.: Learning sequence encoders for temporal knowledge graph completion. In: EMNLP (2018)
19. Leblay, J., Chekol, M.W.: Deriving validity time in knowledge graph. In: International World Wide Web Conferences Steering Committee (2018)
20. Dasgupta, S., Ray, S.N., Talukdar, P.: HyTE: hyperplane-based temporally aware knowledge graph embedding. In: EMNLP (2018)
21. Xu, C., Nayyeri, M., Alkhoury, F., Yazdi, H., Lehmann, J.: TeRo: a time-aware knowledge graph embedding via temporal rotation. In: International Conference on Computational Linguistics (2020)
22. Xu, C., Nayyeri, M., Alkhoury, F., Yazdi, H., Lehmann, J.: Temporal knowledge graph embedding on additive time series decomposition. In: International Semantic Web Conference (2020)
23. Trivedi, R., Dai, H., Wang, Y., Song, L.: Know-Evolve: deep temporal reasoning for dynamic knowledge graphs. In: ICML (2017)
24. Boschee, E., et al.: ICEWS coded event data. In: Harvard Dataverse (2015)
25. Mahdisoltani, F., Biega, J.A., Suchanek, F.: YAGO3: a knowledge base from multilingual Wikipedia. In: CIDR (2014)
26. Wang, S., Cao, J., Yu, P.: Deep learning for spatio-temporal data mining: a survey. In: TKDE (2022)
SumBART - An Improved BART Model for Abstractive Text Summarization

A. Vivek and V. Susheela Devi

Indian Institute of Science, Bengaluru, Karnataka 560012, India
https://iisc.ac.in
Abstract. In this project we introduce SumBART, an improved version of BART with better performance on the abstractive text summarization task. BART is a denoising autoencoder model used for language modelling tasks. The existing BART model produces summaries with good grammatical accuracy, but they contain a certain amount of factual inconsistency. This factual inconsistency is what makes text summarization models unfit for many real-world applications. We introduce three modifications to the existing model that improve ROUGE scores as well as factual consistency.

Keywords: Abstractive text summarization · SumBART · Improved version of BART

1 Introduction
Text summarization has been one of the most widely explored areas of machine learning in the past few years, with many models producing summaries that closely resemble human-written ones. There are two main types of text summarization: extractive and abstractive. In extractive summarization, the most important sentences of the given text are identified and reproduced verbatim to form the summary. In abstractive summarization, the model tries to understand the gist of the article and then produces a summary in its own words. In this project, we deal with abstractive summarization, which produces summaries that resemble those written by human beings. Generating abstractive summaries is more difficult than generating extractive ones, because the model has to first understand the content of the article and then generate new sentences on its own. This opens up the possibility of sentence-construction errors, grammatical errors, and factual-inconsistency errors. In this work, we modify the BART model [19] to enhance its ability to produce abstractive summaries that are grammatically correct, have good sentence structure, and maintain good factual consistency. BART is a transformer-based model that has produced state-of-the-art results in multiple language modeling tasks.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 313–323, 2023.
https://doi.org/10.1007/978-981-99-1639-9_26
A. Vivek and V. S. Devi

2 Literature Survey
Before the 21st century, summarization approaches were mostly extractive: the most important sentences were identified in the source article and reproduced verbatim in the summary. Banko et al. [1] introduced a statistical-translation-based approach that can create new sentences to form the summary instead of reusing the sentences of the source document. Zajic et al. [2] proposed linguistic and statistical methods for summarization that did not depend fully on extractive approaches. Vu et al. [3] introduced an approach to simplify children's stories with limited data. However, these approaches were not comparable with summarization done by human beings. The turning point in abstractive summarization came with Sutskever et al. [4], who showed how recurrent neural networks can be used for natural language processing tasks. Rush et al. [5] used recurrent neural networks to create abstractive summaries of text with a neural attention model. This was a fully data-driven approach and made significant performance gains over the baseline models of that time. Parikh et al. [6] showed that a self-attention-based model can be used effectively for language inference by decomposing the problem into subproblems that can be handled separately. Nallapati et al. [7] showed that attentional encoder-decoder models can produce state-of-the-art results in abstractive summarization. See et al. [8] introduced pointer-generator networks, in which a hybrid pointer network copies important parts of the source text by pointing, and a generator network generates the abstractive summary. Paulus et al. [9] proposed a model that combines standard supervised word prediction with reinforcement learning to produce good summaries of long paragraphs. Chen and Bansal [10] used a hybrid extractive-abstractive architecture with policy-based reinforcement learning.
The next big breakthrough in natural language processing was the introduction of transformers by Vaswani et al. [11]. The recurrent neural networks used for natural language processing tasks until then processed input sequentially, so parallel processing was not possible: the hidden state output of one step had to be used as input to the next. Transformers made parallel processing of input sentences possible because they use matrix operations to compute self-attention over all positions at once. Moreover, transformers produced better results than recurrent neural networks. Narayan et al. [12] explored the ability of convolutional neural networks to identify the key ideas of a paragraph and perform extreme summarization, in which short, to-the-point one-sentence summaries were generated for the BBC news dataset.

Training language models for each task from scratch was an enormous undertaking, and this paved the way for pre-trained models trained on huge unlabelled datasets, which can then be fine-tuned to suit various language modeling tasks. Devlin et al. [13] introduced BERT, a transformer-based pre-trained model that can be fine-tuned with one extra output layer to suit various tasks. BERT is an encoder-only model. It brought significant improvements to abstractive summarization because it was pre-trained on very large datasets. Transfer learning boosted the effectiveness of abstractive summarization by a good margin, but all these pre-trained models were general-purpose language models fine-tuned for abstractive summarization rather than models tailored to the task. This gave way to PEGASUS, introduced by Zhang et al. [14], which was designed specifically for abstractive summarization and achieved good results even in low-resource summarization. Raffel et al. [15] introduced T5, a unified text-to-text transformer whose idea is to convert every text-based language problem into a text-to-text format. Radford et al. [16] introduced generative pre-training (GPT) on diverse unlabelled data to build language models that perform well on a wide variety of tasks, and proposed task-specific discriminative fine-tuning to make GPT models perform a particular task. These decoder-only models produced state-of-the-art results. GPT-2, the successor of GPT, was introduced by Radford et al. [17], who showed that language models can learn many language modeling tasks without explicit supervision when trained on a dataset of millions of web pages. GPT-3, an autoregressive model, was introduced as the successor of GPT-2 by Brown et al. [18]; it can perform various language modeling tasks without any fine-tuning. Lewis et al. [19] proposed BART, a denoising autoencoder for pre-training language models. It is trained by corrupting text with some noise and then making the model predict the original text.
BART proved more effective than BERT at creating abstractive summaries. BART stands for Bidirectional and Auto-Regressive Transformers. BERT is an encoder-only model and GPT is a decoder-only model; BART combines a bidirectional encoder similar to BERT with an autoregressive decoder similar to GPT. This helps BART perform very well on language modeling tasks, including text summarization. Many evaluation metrics can be used to score the summaries generated by a model and compare them with the summaries generated by other models. Papineni et al. [20] proposed BLEU, a precision-based evaluation metric, while Lin [21] proposed ROUGE, a recall-based evaluation metric. Zhang et al. [22] introduced BERTScore, which uses cosine similarities between the embeddings of tokens in the generated sentence and the reference sentence.

Abstractive summaries produced with pre-trained models showed good grammatical accuracy and sentence structure but still lacked factual consistency. Huang et al. [23] found that existing abstractive summarization models suffer from two major types of factual inconsistency errors: intrinsic errors, where the generated summary contradicts facts in the source document, and extrinsic errors, where facts in the generated summary are neither supported nor contradicted by the source document. Zhou et al. [24] studied the hallucination problem in neural sequence generation and concluded that neural sequence models hallucinate additional content that is not supported by the source document.
3 Methodology
Our model is a modified version of the BART model introduced by Lewis et al. [19]. Our aim was to make modifications that improve the performance of the existing model, and to fine-tune the modified model for text summarization on different datasets. The three modifications we made to BART are as follows.

3.1 Additional Embedding Using Information from Knowledge Graph
One major area we intend to improve is the factual inconsistency issue and the hallucination problem described by Huang et al. [23] and Zhou et al. [24], respectively. To tackle this issue, we wanted the model to understand the real meaning of different words as people use them in conversation. These additional word embeddings are based on the ConceptNet knowledge graph. ConceptNet connects words and phrases with labelled edges, and in this way it holds general information related to language understanding. We extract the knowledge available in this graph to provide the model with extra information about each word it encounters, which BART cannot deduce from embeddings created from the input sentences alone. Our use case does not require further training of the graph, only extraction of information from the existing knowledge graph. Speer et al. [25] showed that applying ConceptNet to word embeddings can enhance the ability of models to perform analogy-based tasks, as it embeds information about the usage of words. We use the Numberbatch embeddings they introduced in our model.

3.2 Keyword Extractor
Next, we had to make sure that the model does not miss any key areas while generating the summary. For this, we use an additional keyword extractor model that identifies keywords in the given article. This set of keywords is input to the model separately, so that the model stays aware of the major points in the article and includes them when creating the summary. The keyword extractor we used is a BERT-based model called KeyBERT, introduced by Grootendorst [26]. It identifies the top few keywords or key phrases of an article by assigning scores to candidates and outputting those with the highest scores.
3.3 CNN Embedding
The power of CNNs in text-based tasks, including summarization, has been highlighted in multiple papers, including Narayan et al. [12], which showed their ability to identify key areas of an article. Following this line of work, we use a CNN to create embeddings that provide the model with information about the key areas of the paragraph. Our first modification helps the model identify the correct usage of words and language rules, while the other two modifications help it identify the key areas of the paragraph.

Figure 1 shows the complete architecture of our model. As the architecture shows, we have separated the token embedding and position embedding layers from the BART encoder. The knowledge-graph-based embedding layer and the KeyBERT model are placed before this layer. The CNN embedding layer is placed after it, because it takes the output of the BART embedding layer as input. The architecture shows two copies of the BART embedding layer (token embedding plus position embedding) for ease of representation, even though the model contains only one. The first copy shows how the keyword-based embedding is created: keywords identified by KeyBERT are converted into a single string, tokenized, and passed through the BART embedding, and the output is one of the four embeddings that are added up to form the final SumBART embedding. The second copy is the normal BART embedding, where the article is token- and position-embedded; its output goes to two destinations, the CNN embedding layer and the final SumBART embedding. The embedding created by the knowledge-graph-based embedding layer does not pass through the BART embedding layer; instead, it is added directly to the final SumBART embedding after passing through a fully connected layer.
The SumBART embedding therefore has four components: the BART embedding, the keyword-based embedding, the knowledge-graph-based embedding, and the CNN-based embedding. These are summed in the SumBART embedding layer, and the result is passed to the encoder layers. During training, the tokenized reference summary is passed directly to the decoder, so the decoder sees all previous target words while generating each word. During testing, the decoder runs in a loop and uses the summary it has generated so far to predict the word at the next time step.
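The four-way sum that forms the SumBART embedding can be sketched per token in plain Python. This toy version uses tiny dimensions instead of 1024 (BART) and 300 (Numberbatch), and assumes that all four components have already been padded to the same sequence length; the function names and values are hypothetical. The CNN branch is a stride-1 one-dimensional convolution whose padding keeps the sequence length unchanged, as with the kernel size 7 and padding 3 used in the implementation.

```python
def linear(x, W, b):
    """Fully connected layer y = W x + b (e.g. 300 -> 1024 for Numberbatch vectors)."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i for row, b_i in zip(W, b)]


def conv1d_same(seq, kernels, padding):
    """Stride-1 1D convolution over a token sequence with zero padding.

    seq: L x C token vectors; kernels: C_out x C x K weights.
    With K = 2 * padding + 1 (e.g. K = 7, padding = 3) the output keeps length L.
    """
    length, channels = len(seq), len(seq[0])
    k_size = len(kernels[0][0])
    zero = [0.0] * channels
    padded = [zero] * padding + seq + [zero] * padding
    out = []
    for pos in range(length + 2 * padding - k_size + 1):
        window = padded[pos:pos + k_size]
        out.append([
            sum(kern[c][k] * window[k][c] for c in range(channels) for k in range(k_size))
            for kern in kernels
        ])
    return out


def sumbart_embedding(bart_emb, keyword_emb, kg_emb, kg_W, kg_b, kernels, padding):
    """Per-token sum of the four SumBART components (all assumed to have length L)."""
    cnn_emb = conv1d_same(bart_emb, kernels, padding)  # CNN consumes the BART embedding
    kg_proj = [linear(v, kg_W, kg_b) for v in kg_emb]  # knowledge-graph dim -> model dim
    return [
        [a + b + c + d for a, b, c, d in zip(t_bart, t_kw, t_kg, t_cnn)]
        for t_bart, t_kw, t_kg, t_cnn in zip(bart_emb, keyword_emb, kg_proj, cnn_emb)
    ]


# Toy example: L = 4 tokens, model dim 2 (instead of 1024), KG dim 3 (instead of 300).
bart = [[1, 0], [0, 1], [1, 1], [0, 0]]
keywords = [[0.5, 0.5]] * 4
kg = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]]
kg_W, kg_b = [[1, 0, 0], [0, 1, 0]], [0, 0]
# Kernels that copy the centre token channel-wise (K = 3, padding = 1).
kernels = [[[0, 1, 0], [0, 0, 0]], [[0, 0, 0], [0, 1, 0]]]
emb = sumbart_embedding(bart, keywords, kg, kg_W, kg_b, kernels, padding=1)
```

Summing the components, rather than concatenating them, keeps the input width at BART's model dimension, so the pre-trained encoder layers can be reused unchanged; this is our reading of the design rather than a rationale stated in the text.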
Fig. 1. SumBART architecture. The tokenized reference summary is used only in the training phase. In testing phase, the decoder uses words generated by itself till the previous time step instead of using tokenized reference summary.
4 Implementation and Results
Our model adds three features to the BART model. The input data consist of pairs of source articles and human-written summaries. We used the CNN/Daily Mail and XSum text summarization datasets. The CNN/Daily Mail dataset has 286,817 training pairs, 13,386 validation pairs, and 11,487 test pairs. The XSum dataset has 203,577 training pairs, 11,305 validation pairs, and 11,301 test pairs. Articles were tokenized with the BART tokenizer and then fed to our model. We used the base model and tokenizer of BART provided by Hugging Face [27].

The first layer of the BART model is the embedding layer. It has a token embedding layer, which creates token embeddings, and a positional embedding layer, which encodes the position of each word in the sentence. On top of the BART embedding, we added the Numberbatch embedding, which uses ConceptNet data to produce embeddings that help the model identify the correct usage of words in the article. This adds language knowledge to the data, which can help reduce inconsistencies in summaries. The Numberbatch embedding for each token has length 300, whereas BART embeddings have length 1024. To match them, we added a fully connected layer that converts embeddings of length 300 into length 1024; this linear layer was also trained during model training.

Our next addition is the keyword identifier KeyBERT, introduced in [26], which identifies keywords in the given article. It takes each article as input and outputs the top five keywords. We converted these keywords into a comma-separated list and passed it as a string to the BART tokenizer. The tokenized string was passed to the embedding layer to create an embedding composed of the top five keywords of the article. This helps the model avoid missing any of the major ideas of the article.
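KeyBERT ranks candidate words or phrases by the cosine similarity between their embeddings and the embedding of the whole document. The sketch below mimics that ranking with hand-made toy vectors (in practice both come from a BERT-style encoder) and then builds the comma-separated top-5 keyword string that is passed to the BART tokenizer. The vectors and helper names are hypothetical; with the actual `keybert` package, the corresponding call is `KeyBERT().extract_keywords(doc, top_n=5)`, which returns (keyword, score) pairs.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def top_keywords(doc_vec, candidate_vecs, top_n=5):
    """Rank candidates by cosine similarity to the document embedding (KeyBERT-style)."""
    scored = [(word, cosine(vec, doc_vec)) for word, vec in candidate_vecs.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_n]


def keyword_string(keywords):
    """Comma-separated keyword string that would be fed to the BART tokenizer."""
    return ", ".join(word for word, _ in keywords)


# Toy 3-d "embeddings" for a news article and seven candidate words.
doc_vec = [0.9, 0.4, 0.1]
candidates = {
    "election": [1.0, 0.5, 0.0],
    "senate": [0.8, 0.3, 0.2],
    "vote": [0.9, 0.6, 0.1],
    "weather": [0.0, 0.1, 1.0],
    "campaign": [0.7, 0.5, 0.0],
    "recipe": [0.1, 0.0, 0.9],
    "ballot": [0.85, 0.45, 0.05],
}
print(keyword_string(top_keywords(doc_vec, candidates)))  # ballot, election, senate, vote, campaign
```

The off-topic candidates ("weather", "recipe") score far lower and are excluded from the top five, which is the behaviour the keyword embedding relies on to keep the summary focused on the article's main points.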
The third addition is an embedding generated by a CNN layer, which takes the embedded article as input and produces CNN embeddings of the same dimensions. We used a one-dimensional convolutional layer with a kernel size of 7, a stride of 1 and a padding of 3, which preserves the sequence length. All of the embeddings above are added together to form the final SumBART embedding, in which each token is represented by a vector of length 1024. Each input sequence consists of 512 such tokens; all sequences were padded to a length of 512. We used a batch size of 8 and the Adam optimizer with a learning rate of 0.005 and betas of 0.9 and 0.09.

The SumBART embedding passes through the 12 encoder layers, each of which passes its hidden states on to the next. The last hidden state of the final encoder layer is passed to the decoder. During training, the decoder receives the right-shifted tokenized summary as input (teacher forcing), so that it can see all previous reference words while generating each word. This also lets the decoder process the whole sentence in a single pass. During testing, the decoder cannot be given the tokenized summary, so it generates the output autoregressively: it produces one word by looking at all preceding words (generated by itself), then moves on to the next word, and so on.

A. Vivek and V. S. Devi

The output of the decoder is passed through a language modelling head, which maps each decoder position to a vector of length 50,265 (the size of the BART vocabulary). Beam search with 4 beams is used to decode the output summary from these vectors.

We observed that training the last few layers of the BART decoder together with all the newly added layers was more beneficial than training all encoder and decoder layers. The best results were obtained when the last 2 decoder layers were trained while all other BART layers were kept frozen. A possible reason is that the data on which BART was pre-trained is far larger than the dataset we use, so training all layers on a much smaller dataset might be inappropriate.

To evaluate the model, we used the ROUGE scores introduced in [21]. ROUGE compares the summary generated by a model with human-written summaries and scores it by the number of overlapping n-grams, word sequences and word pairs. We used ROUGE-1 and ROUGE-L to evaluate our model.

Figures 2, 3 and 4 show an article, its summary created by BART and its summary created by SumBART, respectively. Our model included Forrest’s age and birth name in addition to all the other information that BART included in its summary. Moreover, our model used a different sentence structure, placing the information from the first two sentences of the BART summary in its opening sentence. There is still room for improvement, as illustrated in Figs. 5 and 6. Figure 5 shows an article about Bruce Jenner, a father of six who transitioned to a woman. Figure 6 shows the summary generated by our model, which was inconsistent about whether to refer to Jenner as "he" or "she": it used "he" in the first sentence and "she" in the last. The human-written reference summary uses "she".
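The CNN branch and the final summation described above can be sketched in NumPy as below, under assumptions the text does not pin down: the convolution is applied to the BART token-plus-positional embedding, the keyword embedding has already been brought to the article's (seq_len, d_model) shape, and all weights are random placeholders for trained parameters. What it illustrates is that a kernel of 7 with stride 1 and padding 3 keeps the sequence length unchanged, so all streams can be added element-wise.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv1d_same(x: np.ndarray, weight: np.ndarray, bias: np.ndarray,
                padding: int = 3) -> np.ndarray:
    """Stride-1 1-D convolution along the token axis.

    x: (seq_len, d_in); weight: (d_out, d_in, kernel); bias: (d_out,).
    With kernel 7 and padding 3 the output has the same length as the input.
    """
    kernel = weight.shape[2]
    padded = np.pad(x, ((padding, padding), (0, 0)))       # zero-pad the token axis only
    windows = sliding_window_view(padded, kernel, axis=0)  # (seq_len, d_in, kernel)
    # Cross-correlation, matching how deep-learning conv layers are defined.
    return np.einsum("tck,ock->to", windows, weight) + bias


def sumbart_embedding(token_emb, pos_emb, nb_emb, keyword_emb, cnn_weight, cnn_bias):
    """Element-wise sum of all embedding streams, each of shape (seq_len, d_model)."""
    bart_emb = token_emb + pos_emb                          # standard BART input embedding
    cnn_emb = conv1d_same(bart_emb, cnn_weight, cnn_bias)   # assumed input of the CNN branch
    return bart_emb + nb_emb + keyword_emb + cnn_emb


rng = np.random.default_rng(0)
seq_len, d_model = 16, 8  # small stand-ins for 512 and 1024
parts = [rng.standard_normal((seq_len, d_model)) for _ in range(4)]
w = rng.standard_normal((d_model, d_model, 7)) * 0.1
print(sumbart_embedding(*parts, w, np.zeros(d_model)).shape)  # (16, 8)
```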
Fig. 2. Article 1
Fig. 3. Summary of Article 1 by BART
Fig. 4. Summary of Article 1 by SumBART. Our model was able to include the information about Forrest’s age and birth name in the summary in addition to all the information included in the summary produced by BART.

Table 1. Performance of models

| Dataset | Model | ROUGE-1 | ROUGE-L |
|---|---|---|---|
| CNN/Daily-mail | BART | 34.72 | 32.15 |
| CNN/Daily-mail | SumBART | 37.31 | 34.98 |
| XSum | BART | 31.16 | 27.77 |
| XSum | SumBART | 33.83 | 30.31 |
Fig. 5. Article 2
Fig. 6. Summary of Article 2 by SumBART. Here, the model was inconsistent about whether to use the pronoun "he" or "she" for Jenner, using "he" in the first sentence and "she" in the last.
Table 1 shows the ROUGE-1 and ROUGE-L scores obtained by fine-tuning BART before and after our modifications on the CNN/Daily-mail and XSum datasets. SumBART improves ROUGE-1 by 2.59 points and ROUGE-L by 2.83 points on CNN/Daily-mail, and ROUGE-1 by 2.67 points and ROUGE-L by 2.54 points on XSum. Since the reference summaries are human-written, these gains indicate that the summaries generated by our model overlap more closely with human-written summaries.
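For readers unfamiliar with the metric, the following is a simplified sketch of ROUGE-1 and ROUGE-L F1 scores (Table 1 reports them scaled by 100). It assumes lowercase whitespace tokenization, no stemming and a single reference, so its numbers will not exactly match the official ROUGE package [21] or other toolkits; it only shows how unigram overlap and the longest common subsequence are turned into precision, recall and F1.

```python
from collections import Counter


def _f1(overlap: int, cand_len: int, ref_len: int) -> float:
    """Harmonic mean of precision and recall; 0 when there is no overlap."""
    if overlap == 0:
        return 0.0
    precision, recall = overlap / cand_len, overlap / ref_len
    return 2 * precision * recall / (precision + recall)


def rouge_1(candidate: str, reference: str) -> float:
    """ROUGE-1 F1: clipped unigram overlap between candidate and reference."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    overlap = sum((Counter(cand) & Counter(ref)).values())
    return _f1(overlap, len(cand), len(ref))


def lcs_length(a: list, b: list) -> int:
    """Length of the longest common subsequence, by dynamic programming."""
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, 1):
        for j, y in enumerate(b, 1):
            dp[i][j] = dp[i - 1][j - 1] + 1 if x == y else max(dp[i - 1][j], dp[i][j - 1])
    return dp[-1][-1]


def rouge_l(candidate: str, reference: str) -> float:
    """ROUGE-L F1 based on the longest common subsequence of tokens."""
    cand, ref = candidate.lower().split(), reference.lower().split()
    return _f1(lcs_length(cand, ref), len(cand), len(ref))


cand = "the cat sat on the mat"
ref = "the cat is on the mat"
print(round(100 * rouge_1(cand, ref), 2), round(100 * rouge_l(cand, ref), 2))  # 83.33 83.33
```

ROUGE-L rewards in-order matches only, which is why rephrasing with a different sentence structure can lower ROUGE-L more than ROUGE-1.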
5 Conclusions
We conclude that our modifications to BART make the summaries it generates resemble human-written summaries more closely. This trend was observed on both the CNN/Daily-mail and XSum datasets: SumBART produced summaries with higher ROUGE scores than the BART model fine-tuned without modifications. The model could be made more human-like by improving capabilities such as distinguishing between fact and imagination when creating summaries. Once this is done, the range of applications of abstractive text summarization models could broaden considerably.
References
1. Banko, M., Mittal, V.O., Witbrock, M.J.: Headline generation based on statistical translation. In: Proceedings of the 38th Annual Meeting of the Association for Computational Linguistics (2000)
2. Zajic, D.M., Dorr, B.J., Lin, J.: Single-document and multi-document summarization techniques for email threads using sentence compression. Inf. Process. Manag. 44(4), 1600–1610 (2008)
3. Vu, T.T., Tran, G.B., Pham, S.B.: Learning to simplify children stories with limited data. In: Nguyen, N.T., Attachoo, B., Trawiński, B., Somboonviwat, K. (eds.) ACIIDS 2014. LNCS (LNAI), vol. 8397, pp. 31–41. Springer, Cham (2014). https://doi.org/10.1007/978-3-319-05476-6_4
4. Sutskever, I., Vinyals, O., Le, Q.V.: Sequence to sequence learning with neural networks. In: Advances in Neural Information Processing Systems, vol. 27 (2014)
5. Rush, A.M., Chopra, S., Weston, J.: A neural attention model for abstractive sentence summarization. arXiv preprint arXiv:1509.00685 (2015)
6. Parikh, A.P., et al.: A decomposable attention model for natural language inference. arXiv preprint arXiv:1606.01933 (2016)
7. Nallapati, R., et al.: Abstractive text summarization using sequence-to-sequence RNNs and beyond. arXiv preprint arXiv:1602.06023 (2016)
8. See, A., Liu, P.J., Manning, C.D.: Get to the point: summarization with pointer-generator networks. arXiv preprint arXiv:1704.04368 (2017)
9. Paulus, R., Xiong, C., Socher, R.: A deep reinforced model for abstractive summarization. arXiv preprint arXiv:1705.04304 (2017)
10. Zhang, Y., Chen, E., Xiao, W.: Extractive-abstractive summarization with pointer and coverage mechanism. In: Proceedings of 2018 International Conference on Big Data Technologies (2018)
11. Vaswani, A., et al.: Attention is all you need. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
12. Narayan, S., Cohen, S.B., Lapata, M.: Don’t give me the details, just the summary! topic-aware convolutional neural networks for extreme summarization. arXiv preprint arXiv:1808.08745 (2018)
13. Devlin, J., et al.: Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805 (2018)
14. Zhang, J., et al.: Pegasus: pre-training with extracted gap-sentences for abstractive summarization. In: International Conference on Machine Learning, PMLR (2020)
15. Raffel, C., et al.: Exploring the limits of transfer learning with a unified text-to-text transformer. arXiv preprint arXiv:1910.10683 (2019)
16. Radford, A., et al.: Improving language understanding by generative pre-training (2018)
17. Radford, A., et al.: Language models are unsupervised multitask learners. OpenAI Blog 1(8), 9 (2019)
18. Brown, T., et al.: Language models are few-shot learners. Adv. Neural Inf. Process. Syst. 33, 1877–1901 (2020)
19. Lewis, M., et al.: Bart: denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. arXiv preprint arXiv:1910.13461 (2019)
20. Papineni, K., et al.: Bleu: a method for automatic evaluation of machine translation. In: Proceedings of the 40th Annual meeting of the Association for Computational Linguistics (2002)
21. Lin, C.-Y.: Rouge: a package for automatic evaluation of summaries. In: Text Summarization Branches Out (2004)
22. Zhang, T., et al.: Bertscore: evaluating text generation with bert. arXiv preprint arXiv:1904.09675 (2019)
23. Huang, Y., et al.: The factual inconsistency problem in abstractive text summarization: a survey. arXiv preprint arXiv:2104.14839 (2021)
24. Zhou, C., et al.: Detecting hallucinated content in conditional neural sequence generation. arXiv preprint arXiv:2011.02593 (2020)
25. Speer, R., Chin, J., Havasi, C.: Conceptnet 5.5: an open multilingual graph of general knowledge. In: Thirty-first AAAI Conference on Artificial Intelligence (2017)
26. Grootendorst, M.: Keybert: minimal keyword extraction with bert. https://maartengr.github.io/KeyBERT/index.html (2020)
27. Wolf, T., et al.: Transformers: state-of-the-art natural language processing. In: Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations (2020)
Saliency-Guided Learned Image Compression for Object Detection

Haoxuan Xiong and Yuanyuan Xu (✉)

Key Laboratory of Water Big Data Technology of Ministry of Water Resources, Hohai University
College of Computer and Information, Hohai University, Nanjing, China
{haoxuan_x,yuanyuan_xu}@hhu.edu.cn
Abstract. With recent advances in computer vision techniques, an increasing amount of image and video content is consumed by machines. However, existing image and video compression schemes are mainly designed for human vision and are not optimized for machine vision. In this paper, we propose a saliency-guided learned image compression scheme for machines, taking object detection as an example task. To obtain the salient regions for machine vision, a saliency map is generated for each detected object using an existing black-box explanation method for neural networks, and the maps of multiple objects are merged into one by a dedicated merging scheme. Based on a neural network-based image codec, we design a bitrate allocation scheme that prunes the latent representation of the image according to the saliency map. During end-to-end training of the image codec, both pixel fidelity and machine vision fidelity are included in the training objective, where the degradation in detection accuracy is measured without ground-truth annotations. Experimental results demonstrate that the proposed scheme achieves up to a 14.1% reduction in bitrate at the same detection accuracy compared with the baseline learned image codec.
Keywords: Learned image compression · Machine vision · Saliency analysis · Object detection

1 Introduction
Nowadays, more and more images and videos are analyzed by machines, with the analysis results only occasionally verified by humans. However, most image and video coding schemes are optimized for human vision and may not preserve information vital to machines during compression. Image and video coding schemes targeted at machine vision are necessary to obtain a compact representation of visual data without degrading the performance of machine analysis. The Moving Picture Experts Group (MPEG) has launched the exploration of Video Coding for Machines (VCM) [5], which aims to standardize a bitstream format for the compact representation of both the video stream and the features extracted for various machine vision tasks.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 324–335, 2023.
https://doi.org/10.1007/978-981-99-1639-9_27
To conduct video coding for machines, we first need to understand the decision process of machine analysis. The output of a machine analysis task can be mapped back to the input space to see which parts of an image influence the result most [1]. With class-discriminative saliency maps, machine vision-oriented compression methods can be designed. Cai et al. [3] designed a back-propagation-based method to estimate the influence of pixels on object detection, and proposed an object detection-oriented bit allocation scheme for High Efficiency Video Coding (HEVC) [15]. Huang et al. [7] used the bounding boxes generated before non-maximum suppression to evaluate the importance of each coding tree unit (CTU), and designed a CTU-level bitrate allocation scheme for Versatile Video Coding (VVC). Besides works based on traditional video codecs, learned image compression for machine vision has been explored in [9,18]. Observing channel redundancy in the high-dimensional feature maps of the learned coding framework, Wang et al. [18] designed an inverted bottleneck structure for the learned encoder and incorporated analytics accuracy into the optimization process. Le et al. [9] proposed an inference-time, content-adaptive fine-tuning scheme for learned image compression targeting machine consumption, where the training process is guided by the rate, the mean square error and either the task loss or a feature-based perceptual loss. However, existing learned image compression schemes use the saliency information of machine vision only implicitly: compared with learned image compression for humans, only the loss function is adapted to machine vision. Considerable redundancy remains, since every spatial position is represented with the same number of latent channels. In this paper, we propose a novel learned image compression scheme for machines in which the latent representation of the image is pruned according to its saliency for machines.
Using object detection as an example task, a saliency map for each detected object is obtained with a black-box explanation method for neural networks, and the maps of all detected objects are merged into a single saliency map by a proposed map merging scheme. In addition, we design a saliency-guided bitrate allocation scheme that tailors the number of feature channels at each spatial position according to its importance for the machine task. Moreover, the whole end-to-end framework is trained with a loss function that incorporates the distortion of the detection results without requiring ground-truth annotations. Experimental results show that, compared with relevant schemes, the proposed scheme achieves higher detection accuracy at the same bitrate.
2 Proposed Method

2.1 The Overall Framework
The overall framework of the proposed saliency-guided learned image compression for object detection is shown in Fig. 1, where the state-of-the-art object detector YOLOv3 [13] is adopted. Saliency analysis for object detection is conducted first. Given the original image $x$, YOLOv3 produces the detection result $Detect_{ori}$. Since YOLO is a convolutional neural network (CNN)-based object detector, D-RISE [12], an existing black-box explanation method for neural networks, is used to generate a saliency map $S_i$ for each detected object in $Detect_{ori}$. A proposed scheme, M-Sal, merges the multiple saliency maps into one map, $S$, representing the saliency for all detected objects. Building on the state-of-the-art learned image codec in [6], the saliency map $S$ then guides the bitrate allocation in the learned image encoder, which produces the latent representation $y$. Feeding YOLO with the decoded image $\hat{x}$ yields the detection results $\widehat{Detect}$. With the signal distortion $D_{signal}$ between $x$ and $\hat{x}$, and the task loss $D_{task}$ calculated from $Detect_{ori}$ and $\widehat{Detect}$, the learned image codec is trained with a loss function consisting of $D_{signal}$, $D_{task}$ and the bitrate expense $R$ of $y$.

Fig. 1. Illustration of the proposed framework. YOLO is the object detector, serving as a representative machine vision task. $Detect_{ori}$ is the detection result of YOLO on the original image $x$. D-RISE is a black-box explanation method that generates a saliency map, $S_i$, for each detected object. M-Sal is the proposed saliency map merging scheme that merges multiple maps into one, denoted as $S$. The bitrate allocation in the encoder is guided by $S$. The distortion of signal reconstruction $D_{signal}$, the loss in the analysis task $D_{task}$ and the rate expense $R$ of $y$ are jointly optimized.

2.2 Saliency Analysis for Machine Vision
In the following, we present the details of saliency analysis for object detection, saliency-guided bitrate allocation and coding optimization for machine vision. Saliency analysis for machine vision can be conducted in a white-box [14,19] or a black-box manner [11,12]. The Class Activation Mapping (CAM) [19] method modifies an image classification CNN, replacing its fully-connected layers with convolutional layers and global average pooling, to obtain a class-specific feature map. The Gradient-weighted Class Activation Mapping (Grad-CAM) [14] approach uses the gradients of any target concept flowing into the final convolutional layer, together with guided backpropagation, to produce a class-discriminative visualization. These white-box methods need to either modify the network or access its internal states. Black-box methods are more general and can be applied to any CNN, since the network is treated as a black box and access to its parameters, features or gradients is not assumed. RISE [11] estimates pixel importance for classification by empirically probing the model with randomly masked versions of the input image and observing the corresponding outputs, and D-RISE [12] produces saliency maps in a similar way for any object detector. Therefore, D-RISE is used to obtain the salient regions for the CNN-based YOLO.

For each original image, the detection result of YOLO can be represented by $Detect_{ori} = \{(cls_0, bbox_0, conf_0), \ldots, (cls_n, bbox_n, conf_n)\}$, where $cls_i$, $bbox_i$ and $conf_i$ represent the category label, the bounding box and the confidence score of detected object $i$, respectively, as shown in Fig. 2(a). For a detected object $i$, D-RISE interprets the detection result to produce a saliency map $S_i$ measuring the importance of each pixel in detecting this object. Multiple saliency maps are thus generated for multiple detected objects, as demonstrated in Fig. 2(b). To guide the bitrate allocation, we need a single saliency map describing the influence of pixels on all detected objects. A straightforward approach is to average all the normalized maps. However, averaging makes some previously salient regions less pronounced while parts of irrelevant areas become salient, as shown in Fig. 2(c).
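To make the black-box probing idea concrete, here is a simplified RISE-style sketch in NumPy. It is not D-RISE itself: D-RISE replaces the scalar score with a similarity between the target detection and the detections on each masked image, and RISE additionally uses bilinear upsampling with random shifts, whereas this sketch uses nearest-neighbour upsampling and a hypothetical `score_fn` that stands in for the model.

```python
import numpy as np


def rise_saliency(image, score_fn, n_masks=1000, grid=8, p_keep=0.5, seed=0):
    """RISE-style saliency: weight random masks by the black-box score they produce.

    image:    (H, W) or (H, W, C) array with H and W divisible by `grid`.
    score_fn: callable taking a masked image and returning a scalar score; the
              model is only queried, never inspected.
    """
    rng = np.random.default_rng(seed)
    h, w = image.shape[:2]
    cell = np.ones((h // grid, w // grid))
    saliency = np.zeros((h, w))
    for _ in range(n_masks):
        coarse = (rng.random((grid, grid)) < p_keep).astype(float)
        mask = np.kron(coarse, cell)  # nearest-neighbour upsampling to image size
        masked = image * (mask[..., None] if image.ndim == 3 else mask)
        saliency += score_fn(masked) * mask
    return saliency / (n_masks * p_keep)  # Monte Carlo normalisation


def toy_score(masked):
    """Hypothetical detector score: mean intensity of a fixed 8x8 'object' region."""
    return masked[:8, :8].mean()


img = np.ones((32, 32))
sal = rise_saliency(img, toy_score, n_masks=500, grid=4)
print(sal[:8, :8].mean() > sal[16:, 16:].mean())  # True
```

Pixels whose masking consistently lowers the score accumulate higher saliency, which is the property the M-Sal merging step relies on.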
The averaged saliency map therefore fails to represent what is salient to machine vision when there are multiple objects. To this end, we propose a scheme called M-Sal that merges multiple saliency maps into one. The pixel values of each saliency map $S_k$ are first normalized to the range [0, 1] using min-max normalization. Each saliency map is then pre-processed as follows, keeping the values of pixels at or above the salient threshold of the detected object and setting all remaining pixels to zero to remove the impact of irrelevant areas:

$$S'_k(i,j) = \begin{cases} 0, & S_k(i,j) < \sigma_k \\ S_k(i,j), & S_k(i,j) \geq \sigma_k \end{cases} \qquad (1)$$

where $S'_k(i,j)$ and $S_k(i,j)$ represent the values of the pixel at row $i$ and column $j$ of the pre-processed and original saliency maps for object $k$, respectively, and $\sigma_k$ is the salient threshold for object $k$. The value of $\sigma_k$ can be obtained by segmenting the salient object and taking the average pixel value on its boundary. The pre-processed maps of the $N$ detected objects are then merged into a single saliency map as follows:

$$S(i,j) = \min\Big(\sum_{k=1}^{N} S'_k(i,j),\ 1\Big) \qquad (2)$$
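Equations (1) and (2) translate almost directly into array operations. The NumPy sketch below assumes the thresholds $\sigma_k$ are already known on the normalized scale; the boundary-based segmentation used to obtain them is not reproduced, and the two toy maps are made-up values.

```python
import numpy as np


def min_max_normalize(s: np.ndarray) -> np.ndarray:
    """Rescale a map to [0, 1]; a constant map becomes all zeros."""
    lo, hi = s.min(), s.max()
    return (s - lo) / (hi - lo) if hi > lo else np.zeros_like(s)


def m_sal_merge(saliency_maps, thresholds) -> np.ndarray:
    """Merge per-object saliency maps following Eqs. (1) and (2).

    saliency_maps: list of (H, W) arrays, one per detected object.
    thresholds:    salient thresholds sigma_k on the normalized scale (assumed given).
    """
    merged = np.zeros(np.shape(saliency_maps[0]))
    for s_k, sigma_k in zip(saliency_maps, thresholds):
        s_k = min_max_normalize(np.asarray(s_k, dtype=float))
        merged += np.where(s_k >= sigma_k, s_k, 0.0)  # Eq. (1): drop non-salient pixels
    return np.minimum(merged, 1.0)                      # Eq. (2): sum, then clip at 1


a = np.array([[0.0, 0.2], [0.6, 1.0]])
b = np.array([[1.0, 0.8], [0.0, 0.9]])
print(m_sal_merge([a, b], [0.5, 0.5]).tolist())  # [[1.0, 0.8], [0.6, 1.0]]
```

For comparison, plain averaging of the same two maps gives [[0.5, 0.5], [0.3, 0.95]]: each object's most salient pixel is halved, which is the dilution effect that M-Sal avoids by summing thresholded maps and clipping.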
The value of $S(i,j)$ lies within [0, 1]. An example of a saliency map merged by M-Sal is shown in Fig. 2(d). The final saliency map $S$ is used to guide the bitrate allocation in the learned image codec.

2.3 Saliency-Guided Bitrate Allocation
The proposed image codec is built upon the state-of-the-art end-to-end image compression framework in [6]. In [6], the encoder transforms the input image $x$ into a latent representation and reduces redundancy by introducing a coarse-to-fine hyper-prior model for entropy estimation and signal reconstruction. The decoder performs the reverse transformation, using the latent representation together with its coarse-to-fine hyper-priors to reconstruct the image. The transmitted bitstream contains the quantized latent representation and its coarse- and fine-grained hyper-priors. In an end-to-end codec, the amount of data to be compressed and transmitted depends on the size of the feature map, particularly the number of channels. While the work in [6] uses the same number of channels at every position, the proposed scheme tunes the number of channels in the latent representation under the guidance of the saliency map, so that more coding resources are allocated to information vital to machines.

The proposed saliency-guided bitrate allocation scheme is displayed in Fig. 3. $x$ is an RGB image of height $h$ and width $w$. It is analysis-transformed into a latent representation $z_3$ with 192 channels, height $h'$ and width $w'$. We adopt a spatially varying pruning method that represents salient regions with more channels of the quantized representation and less important areas with fewer channels. The number of channels in the original codec, $C_{max} = 192$, serves as the maximum number of channels. To ensure the fidelity of the reconstructed image, the minimum number of channels $C_{min}$ is set to 128. The number of channels in the latent representation at row $i$ and column $j$, $C(i,j)$, is obtained as follows:

$$C(i,j) = C_{min} + S(i,j)\,(C_{max} - C_{min}) \qquad (3)$$
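A minimal NumPy sketch of Eq. (3) and the resulting spatially varying pruning is given below. It makes assumptions the text leaves open: the saliency map has already been resized to the latent resolution $h' \times w'$, the channel count is rounded to the nearest integer, and the retained channels are the first $C(i,j)$ ones. Pruned channels are simply zeroed here; presumably the bit savings come from not entropy-coding them.

```python
import numpy as np

C_MAX, C_MIN = 192, 128  # channel bounds given in the text


def channels_per_position(saliency: np.ndarray) -> np.ndarray:
    """Eq. (3), rounded to the nearest integer (rounding is an assumption)."""
    return np.rint(C_MIN + saliency * (C_MAX - C_MIN)).astype(int)


def prune_latent(z3: np.ndarray, saliency: np.ndarray):
    """Keep the first C(i, j) of the C_max channels at each latent position.

    z3: (C_max, H', W') latent; saliency: (H', W') map assumed to be already
    resized to the latent resolution. Returns the masked latent and the mask.
    """
    counts = channels_per_position(saliency)             # (H', W')
    channel_idx = np.arange(z3.shape[0])[:, None, None]  # (C_max, 1, 1)
    mask = channel_idx < counts[None, :, :]              # (C_max, H', W')
    return z3 * mask, mask


s = np.array([[0.0, 1.0], [0.5, 0.25]])
y, keep = prune_latent(np.ones((C_MAX, 2, 2)), s)
print(keep.sum(axis=0).tolist())  # [[128, 192], [160, 144]]
```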
Each position in the spatially tuned latent representation $y$ has a varying number of channels depending on its importance to machine vision, so $y$ is more compact than $z_3$ while retaining the information important to machine vision. The coarse-to-fine hyper-prior modelling is performed on $y$, and the final bitstream consists of $y$ and its coarse- and fine-grained hyper-priors $z_2$ and $z_1$. The hyper-prior analysis transform and the decoding process are the same as in [6].

2.4 Coding Optimization for Machine Vision
Image compression for human vision involves rate-distortion optimization (RDO), while compression for machine vision has to consider both rate-accuracy optimization (RAO) and RDO, the latter for occasional signal reconstruction. Using a Lagrangian approach, coding optimization for machine vision can be expressed as minimizing

$$L = R + \lambda_{signal} D_{signal} + \lambda_{task} D_{task} \qquad (4)$$

where $R$ is the bitrate, $\lambda_{signal}$ and $\lambda_{task}$ are the Lagrangian multipliers for the signal distortion and the task loss, and $D_{signal}$ and $D_{task}$ are the signal distortion and the task loss, respectively. $D_{signal}$ is measured as the mean square error (MSE) between the original image $x$ and the reconstructed image $\hat{x}$.

(a) The original image and the corresponding detection results of YOLO
(b) Saliency maps produced by D-RISE
(c) An averaging merged saliency map
(d) An M-Sal merged saliency map
Fig. 2. An example of saliency analysis for YOLO, with different approaches to generating a single saliency map for all the detected objects.

Fig. 3. The pipeline of saliency-guided bitrate allocation in the learned image codec with coarse-to-fine hyper-prior modelling.

The distortion term $D_{task}$ needs to reflect the performance degradation in object detection. In existing works on feature compression and joint texture-feature compression [10,17], models are trained with the help of ground-truth labels, which are not always available. To address this issue, we design a metric for the task loss $D_{task}$ that compares only the detection results of YOLO before and after compression, without ground-truth annotations. Using the decoded image $\hat{x}$, the detection results of YOLO can be represented by $\widehat{Detect} = \{(\widehat{cls}_0, \widehat{bbox}_0, \widehat{conf}_0), \ldots, (\widehat{cls}_m, \widehat{bbox}_m, \widehat{conf}_m)\}$, where $\widehat{cls}_j$, $\widehat{bbox}_j$ and $\widehat{conf}_j$ are the category label, the bounding box and the confidence score of detected object $j$ in the decoded image, respectively. Since multiple objects are detected in both the original and decoded images, the detected objects from the two sets are matched first. For a detected object $i$ in $Detect_{ori}$, we find the corresponding detected object $k$ in $\widehat{Detect}$ as follows:

$$k = \arg\max_j\, [1 - \Delta_{i,j}] \cdot IoU(bbox_i, \widehat{bbox}_j) \qquad (5)$$

where $IoU$ is the Intersection over Union, i.e., the area of overlap of the two bounding boxes divided by the area of their union, and $\Delta_{i,j} = \min(1, |cls_i - \widehat{cls}_j|)$ equals 0 for two objects with the same class label and 1 for objects with different class labels. For a pair of matched objects, the task loss mainly consists of $D_{loc}$ and $D_{conf}$, the differences in the bounding boxes and the confidence scores. For unpaired objects, the task loss is due to the difference in class labels, $D_{cls}$. Therefore, the task loss is defined as follows:

$$D_{task} = \sum_i \left\{ \Delta_{i,k} \cdot \alpha D_{cls} + (1 - \Delta_{i,k}) \cdot (\beta D_{loc} + \gamma D_{conf}) \right\} \qquad (6)$$
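where $\alpha$, $\beta$ and $\gamma$ denote the weights for $D_{cls}$, $D_{loc}$ and $D_{conf}$. $D_{cls}$ is empirically set to 1, and

$$D_{loc} = 1 - IoU(bbox_i, \widehat{bbox}_k) \qquad (7)$$

$$D_{conf} = -conf_i \cdot \log(conf_i) - \left(-\widehat{conf}_k \cdot \log(\widehat{conf}_k)\right) \qquad (8)$$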
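The matching and task-loss computation of Eqs. (4)-(8) can be sketched in plain Python as follows. This is an illustrative reading of the equations, not the authors' code: boxes are assumed to be (x1, y1, x2, y2) tuples, class labels are integer IDs so that $\Delta$ is 0 or 1, the weights default to 1 as reported in Sect. 3.1, and the case of no detections after compression is handled by a hypothetical rule that the paper does not specify.

```python
import math


def iou(a, b):
    """Intersection over Union of two boxes given as (x1, y1, x2, y2)."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def neg_c_log_c(c):
    """-c * log(c), with the limit value 0 at c = 0."""
    return -c * math.log(c) if c > 0 else 0.0


def task_loss(dets_ori, dets_dec, alpha=1.0, beta=1.0, gamma=1.0, d_cls=1.0):
    """D_task of Eq. (6). Each detection is (class_id, (x1, y1, x2, y2), confidence)."""
    total = 0.0
    for cls_i, box_i, conf_i in dets_ori:
        if not dets_dec:
            # Not covered by the equations; treated as a class mismatch (assumption).
            total += alpha * d_cls
            continue
        # Eq. (5): argmax over j of [1 - Delta_ij] * IoU(bbox_i, bbox_hat_j)
        scores = [(1 - min(1, abs(cls_i - c))) * iou(box_i, b) for c, b, _ in dets_dec]
        k = max(range(len(dets_dec)), key=scores.__getitem__)
        cls_k, box_k, conf_k = dets_dec[k]
        delta = min(1, abs(cls_i - cls_k))
        d_loc = 1.0 - iou(box_i, box_k)                     # Eq. (7)
        d_conf = neg_c_log_c(conf_i) - neg_c_log_c(conf_k)  # Eq. (8)
        total += delta * alpha * d_cls + (1 - delta) * (beta * d_loc + gamma * d_conf)
    return total


def total_loss(rate, d_signal, d_task, lam_signal, lam_task):
    """Lagrangian objective of Eq. (4)."""
    return rate + lam_signal * d_signal + lam_task * d_task


ori = [(0, (0, 0, 10, 10), 0.9)]
dec = [(0, (0, 0, 10, 5), 0.9)]  # same class, box shrunk to half its height
print(task_loss(ori, dec))  # 0.5
```

In training, $D_{task}$ would have to be computed from differentiable detector outputs for the gradient to reach the codec; this sketch only evaluates the loss value.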
The proposed learned image codec with saliency-guided feature tuning is trained with the loss function in Eq. (4), which accounts for both signal fidelity and task loss without requiring ground-truth annotations.
3 Experiment and Results

3.1 Implementation Details
The proposed learned image codec is implemented using the released code of [6]. YOLOv3 [16] is used as the object detector and is trained on the training sets of VOC 2007 and VOC 2012 [4]. Part of the VOC 2007 validation set is used to train and evaluate the proposed scheme, with a training set of 1184 images and a validation set of 1180 images. In the first training stage, the learned image codec is trained with a batch size of 8 and a patch size of 256, with the other configurations as in [6]. In the second stage, the proposed codec is trained with the proposed loss function in Eq. (4). We trained models with different values of $\lambda_{signal}$ in the range 0.0005 to 0.016; $\lambda_{task}$ is empirically set to 10 for the five lower-bitrate models and to 1 for the highest-bitrate model. The values of $\alpha$, $\beta$ and $\gamma$ are all set to 1. The proposed scheme is compared with the VVC reference software VTM 16.0 [8] and the baseline learned image codec in [6]. The performance of the codecs is evaluated in terms of mean Average Precision (mAP) at an IoU threshold of 0.5 (mAP@0.5), mAP averaged over IoU thresholds from 0.5 to 0.95 (mAP@0.5:0.95), and peak signal-to-noise ratio (PSNR). Here, mAP is the mean of the average precision (AP) computed for each object category.

3.2 Effectiveness of Saliency-Guided Bitrate Allocation
To verify the effectiveness of the proposed saliency-guided bitrate allocation scheme, we implement the proposed learned image codec with the traditional loss function, which considers only the rate and signal fidelity; this variant is denoted as the “Saliency” scheme. As shown in Fig. 4, the detection performance of all schemes improves as the bitrate increases. The “Saliency” scheme outperforms VVC and the baseline at moderate bitrates, and performs similarly to these schemes at relatively low and high bitrates. At low bitrates, its curve almost overlaps with that of VVC in Fig. 4(a) and is slightly below VVC in Fig. 4(b). At such bitrates, both the signal distortion and the task distortion are large because the available bits are insufficient. Compared with the baseline scheme, the “Saliency” scheme consistently achieves higher detection accuracy at the same bitrate, since it prunes the feature channels according to their saliency for machine vision.
(a) mAP@0.5
(b) mAP@0.5:0.95
Fig. 4. The performance comparison of the proposed algorithm with VVC and the baseline in terms of (a) mAP@0.5 and (b) mAP@0.5:0.95
(a) Original (b) VVC (c) Baseline (d) Proposed

| Example | Scheme | BPP / PSNR (dB) | mAP@0.5 / mAP@0.5:0.95 |
|---|---|---|---|
| 1 | VVC | 0.82 / 25.12 | 99.5 / 69.7 |
| 1 | Baseline | 0.78 / 24.12 | 99.5 / 69.7 |
| 1 | Proposed | 0.65 / 23.28 | 99.5 / 69.7 |
| 2 | VVC | 0.59 / 23.20 | 99.4 / 64.7 |
| 2 | Baseline | 0.57 / 22.09 | 99.5 / 79.6 |
| 2 | Proposed | 0.47 / 21.41 | 99.5 / 84.3 |
Fig. 5. Visual comparison of reconstructed images of VVC, the baseline scheme in [6] and the proposed scheme, with a cropped detected object in red and cropped background in green. (Color figure online)
The performance of the “Saliency” scheme is evaluated in terms of the Bjontegaard delta rate (BD-rate) [2], where the commonly used PSNR is replaced by mAP@0.5 or mAP@0.5:0.95. The BD-rate savings and BD-accuracy improvements of the “Saliency” scheme compared with VVC and the baseline scheme are listed in Table 1. As shown in the table, the proposed bitrate allocation scheme saves 5.62% and 7.24% BD-rate compared with VVC and the baseline on mAP@0.5, while improving mAP@0.5 by 0.39% and 0.51%, respectively. In terms of mAP@0.5:0.95, the bitrate allocation increases the bitrate by 1.33% compared with VVC, while mAP@0.5:0.95 drops by 0.06%. Compared with the baseline, our approach saves 5.58% BD-rate and improves mAP@0.5:0.95 by 0.42%.

Table 1. BD-rate saving (%) and BD-mAP@0.5 and BD-mAP@0.5:0.95 gains of the codec with the proposed saliency-guided bitrate allocation, compared with the VVC and baseline anchors

| Anchor | BD-Rate (mAP@0.5) | BD-mAP@0.5 | BD-Rate (mAP@0.5:0.95) | BD-mAP@0.5:0.95 |
|---|---|---|---|---|
| VVC | −5.62 | 0.39 | 1.33 | −0.06 |
| baseline | −7.24 | 0.51 | −5.58 | 0.42 |
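For reference, the classic polynomial-fit form of the Bjontegaard metric [2], with PSNR replaced by mAP as described above, can be sketched in NumPy. This is a generic sketch rather than the evaluation script used for Table 1; some tools use piecewise cubic interpolation instead of a cubic polynomial fit, which can change results slightly, and the sample curves below are made-up values.

```python
import numpy as np


def bd_rate(rate_anchor, quality_anchor, rate_test, quality_test):
    """Bjontegaard delta rate in percent; negative values mean bitrate savings.

    Each curve needs at least four (rate, quality) points. Quality can be PSNR
    or, as in Table 1, mAP@0.5 or mAP@0.5:0.95.
    """
    # Fit log-rate as a cubic polynomial of quality for each codec.
    p_anchor = np.polyfit(quality_anchor, np.log(rate_anchor), 3)
    p_test = np.polyfit(quality_test, np.log(rate_test), 3)
    # Integrate both fits over the quality range the two curves share.
    lo = max(min(quality_anchor), min(quality_test))
    hi = min(max(quality_anchor), max(quality_test))
    int_anchor, int_test = np.polyint(p_anchor), np.polyint(p_test)
    area_anchor = np.polyval(int_anchor, hi) - np.polyval(int_anchor, lo)
    area_test = np.polyval(int_test, hi) - np.polyval(int_test, lo)
    avg_log_diff = (area_test - area_anchor) / (hi - lo)
    return (np.exp(avg_log_diff) - 1.0) * 100.0


# Hypothetical curves: the test codec needs 10% fewer bits at every quality level.
rates = np.array([0.2, 0.4, 0.6, 0.8])
map50 = np.array([70.0, 75.0, 78.0, 80.0])
print(round(float(bd_rate(rates, map50, 0.9 * rates, map50)), 2))  # -10.0
```

Swapping the roles of rate and quality (fitting quality as a function of log-rate) gives the corresponding BD-accuracy numbers, such as the BD-mAP columns.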
3.3 The Overall Performance Evaluation
In this subsection, both the proposed loss function and the saliency-guided bitrate allocation are implemented to evaluate the overall effectiveness of the proposed scheme, denoted as the “Saliency+RAO” scheme in Fig. 4. The mAP@0.5 and mAP@0.5:0.95 on the uncompressed validation set are 0.83 and 0.62, respectively. With the proposed loss function that considers the task loss, the “Saliency+RAO” scheme achieves better object detection performance than the “Saliency” scheme. The signal fidelity of the different schemes is displayed in Fig. 5, which shows that the improvement for machine consumption comes at the cost of signal fidelity. The BD-rate savings and BD-accuracy improvements of the “Saliency+RAO” scheme compared with VVC and the baseline are listed in Table 2. As shown in the table, the proposed scheme achieves up to 13.39% and 14.12% reductions in bitrate at the same detection accuracy compared with VVC and the baseline learned image codec, respectively.

Table 2. BD-rate saving (%) and BD-mAP@0.5 and BD-mAP@0.5:0.95 gains of the proposed overall scheme, compared with the VVC and baseline anchors

| Anchor | BD-Rate (mAP@0.5) | BD-mAP@0.5 | BD-Rate (mAP@0.5:0.95) | BD-mAP@0.5:0.95 |
|---|---|---|---|---|
| VVC | −13.39 | 0.84 | −6.11 | 0.41 |
| baseline | −14.12 | 0.91 | −12.77 | 0.87 |
A visual comparison of the images reconstructed by VVC, the baseline scheme and the proposed scheme is provided in Fig. 5. Although VVC preserves more image details, with higher signal fidelity at a larger bitrate than the baseline and the proposed scheme, these details are not informative to machine vision and do not improve detection performance. Our method represents non-informative areas with fewer channels during compression, thereby reducing the bitrate while maintaining object detection performance.
4 Conclusion
In this paper, we propose a saliency-guided end-to-end image compression scheme for machines. The proposed scheme provides a compact image representation for object detection while maintaining the capability of signal reconstruction. Saliency analysis is first conducted using a black-box explanation method for neural networks and a proposed map merging approach. A bitrate allocation scheme then adjusts the number of feature channels for different regions according to their saliency for machine vision. The codec network is trained with a proposed loss function of rate, signal distortion and task loss, where the task loss metric requires no ground-truth annotations. Experimental results demonstrate that the proposed scheme maintains detection accuracy at a lower bitrate than VVC and the baseline learned image codec.
References
1. Arrieta, A.B., et al.: Explainable artificial intelligence (XAI): concepts, taxonomies, opportunities and challenges toward responsible AI. Inf. Fusion 58, 82–115 (2020)
2. Bjontegaard, G.: Calculation of average PSNR differences between RD-curves. In: VCEG-M33 (2001)
3. Cai, Q., Chen, Z., Wu, D., Liu, S., Li, X.: A novel video coding strategy in HEVC for object detection. IEEE Trans. Circuits Syst. Video Technol. 31(12), 4924–4937 (2021)
4. Everingham, M., Van Gool, L., Williams, C.K.I., Winn, J., Zisserman, A.: The PASCAL visual object classes challenge. http://host.robots.ox.ac.uk/pascal/VOC/
5. Gao, W., Liu, S., Xu, X., Rafie, M., Zhang, Y., Curcio, I.: Recent standard development activities on video coding for machines (2021). https://arxiv.org/abs/2105.12653
6. Hu, Y., Yang, W., Liu, J.: Coarse-to-fine hyper-prior modeling for learned image compression. In: Thirty-Fourth AAAI Conference on Artificial Intelligence, AAAI 2020, New York, NY, USA, 7–12 February 2020, pp. 11013–11020 (2020)
7. Huang, Z., Jia, C., Wang, S., Ma, S.: Visual analysis motivated rate-distortion model for image coding. In: 2021 IEEE International Conference on Multimedia and Expo (ICME), pp. 1–6. IEEE (2021)
8. VTM reference software for VVC (2021). https://vcgit.hhi.fraunhofer.de/jvet/VVCSoftware_VTM/-/tree/VTM-16.0
Saliency-Guided Learned Image Compression for Object Detection
9. Le, N., Zhang, H., Cricri, F., Ghaznavi-Youvalari, R., Tavakoli, H.R., Rahtu, E.: Learned image coding for machines: a content-adaptive approach. In: IEEE International Conference on Multimedia and Expo (ICME), pp. 1–6 (2021)
10. Li, Y., et al.: Joint rate-distortion optimization for simultaneous texture and deep feature compression of facial images. In: Fourth IEEE International Conference on Multimedia Big Data, BigMM 2018, Xi’an, China, pp. 1–5 (2018)
11. Petsiuk, V., Das, A., Saenko, K.: RISE: randomized input sampling for explanation of black-box models. In: British Machine Vision Conference 2018, BMVC 2018, Newcastle, UK, 3–6 September (2018)
12. Petsiuk, V., et al.: Black-box explanation of object detectors via saliency maps. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 11438–11447 (2021)
13. Redmon, J., Farhadi, A.: YOLOv3: an incremental improvement (2018). https://doi.org/10.48550/arXiv.1804.02767
14. Selvaraju, R.R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., Batra, D.: Grad-CAM: visual explanations from deep networks via gradient-based localization. In: IEEE International Conference on Computer Vision (ICCV), pp. 618–626 (2017)
15. Sullivan, G.J., Ohm, J.R., Han, W.J., Wiegand, T.: Overview of the high efficiency video coding (HEVC) standard. IEEE Trans. Circuits Syst. Video Technol. 22(12), 1649–1668 (2012)
16. Ultralytics: YOLOv3 implementation (2021). https://doi.org/10.5281/zenodo.6222936, https://github.com/ultralytics/yolov3
17. Wang, S., et al.: Teacher-student learning with multi-granularity constraint towards compact facial feature representation. In: IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 8503–8507 (2021)
18. Wang, S., Wang, Z., Wang, S., Ye, Y.: End-to-end compression towards machine vision: network architecture design and optimization. IEEE Open J. Circuits Syst. 2, 675–685 (2021)
19. Zhou, B., Khosla, A., Lapedriza, A., Oliva, A., Torralba, A.: Learning deep features for discriminative localization. In: IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 2921–2929 (2016)
Multi-label Learning with Data Self-augmentation
Yuhang Ge¹, Xuegang Hu¹, Peipei Li¹(B), Haobo Wang², Junbo Zhao², and Junlong Li¹
¹ Hefei University of Technology, Hefei, China {yuhangge,lijl}@mail.hfut.edu.cn, {jsjxhuxg,peipeili}@hfut.edu.cn
² Zhejiang University, Hangzhou, China {wanghaobo,j.zhao}@zju.edu.cn

Abstract. Multi-label learning (MLL) deals with problems in which each training example is associated with multiple labels. Existing MLL approaches focus on manipulating the feature space and on modeling dependencies among labels. However, both require additional burdensome assumptions and cannot easily be embedded into existing algorithms, whereas data augmentation is a more intuitive way to facilitate MLL. Therefore, in this paper, we propose a novel data augmentation method named MLAUG, i.e., Multi-Label learning with data self-Augmentation. Specifically, to achieve data augmentation, we learn a feature correlation matrix and a label correlation matrix in the feature and label spaces simultaneously in an adaptive manner. Guided by the learned correlation matrices, MLAUG refines the original feature and label spaces through linear combinations of other data vectors. In this way, we obtain a semantically richer feature distribution and a smoother label distribution, which improves multi-label predictive performance. Besides, to further improve the model’s performance, we introduce feature and label graph Laplacian regularization to guarantee discriminability and to capture label correlations more adequately, respectively. Extensive experimental results demonstrate the effectiveness of the proposed MLAUG.
Keywords: Multi-Label Learning · Data Augmentation · Label Correlation

1 Introduction
Multi-label learning (MLL) learns from training data in which each instance is associated with a set of labels simultaneously [1,2]. Recently, MLL has been widely applied in various tasks, such as text categorization [3] and video annotation [4]. The key challenges of MLL are twofold: 1) the complex semantic structure of the feature space, which makes it hard to exploit the intrinsic structural information it contains, and 2) the exponentially sized label space ($2^m$ label sets, where $m$ is the number of possible labels), which may overwhelm MLL algorithms. Recent studies show that manipulating the feature space is an effective way to capture the structural information hidden in it, and numerous algorithms have been proposed. For instance, learning label-specific features is a popular strategy for manipulating features, which assumes that the distinct characteristics of each class label can be captured by investigating the underlying properties of the training instances [5]. Besides, feature selection [6] and dimensionality reduction [7] learn more compact representations of the original features. However, these methods all require designing complicated data transformation strategies and cannot easily be embedded into existing algorithms, whereas enriching feature information through feature-level augmentation is a more intuitive way to manipulate the feature space. Meanwhile, to tackle the exponentially sized output label space, modeling the dependencies between label variables effectively reduces the complexity of the label space. For instance, GLOCAL [8] learns a latent label representation and optimizes label manifolds by exploiting global and local label correlations simultaneously. C2AE [9] and TSLE [10] adopt label embedding techniques to project features and labels into a shared space. Indeed, multi-label dependency modeling effectively simplifies the complex semantic space of multi-label outputs. However, modeling label dependencies often requires additional burdensome assumptions (e.g., low rank), which may not always hold in real-world applications. Therefore, in parallel to feature-level augmentation, in this work we further investigate label-level augmentation, which directly aggregates information from similar label vectors and can effectively capture higher-order correlation information. To this end, we propose a simple but efficient data augmentation framework for MLL named MLAUG, i.e., Multi-Label learning with data self-Augmentation. The basic idea of MLAUG is to refine the original feature and label spaces by leveraging their intrinsic structural information.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 336–347, 2023. https://doi.org/10.1007/978-981-99-1639-9_28
More specifically, we learn feature and label correlation matrices simultaneously in an adaptive manner. Guided by the feature correlation matrix, MLAUG refines the original feature space through linear combinations of other data vectors. Similarly, the label correlation matrix guides the propagation of label vectors. Label augmentation can be viewed as a form of label smoothing, which alleviates the model’s overconfidence in its classification results. Besides, we introduce a feature graph Laplacian regularization that exploits feature correlations to guarantee discriminability, and impose a label graph Laplacian regularization on the output label space to capture relationships between labels more adequately. Extensive experiments on eight benchmark multi-label datasets show the effectiveness of MLAUG.
2 Related Works
Multi-Label Learning. General MLL methods follow two types of strategies: manipulating the feature space and modeling label dependencies in the label space [2,11]. Among strategies that manipulate the feature space, label-specific feature learning is the most representative. Label-specific features are generally generated in one of two ways: prototype-based feature transformation [12] or feature selection [6]. In addition, some embedding-based methods [13] align features and labels to obtain better prediction performance. Another line of work in MLL models label dependencies to handle the exponentially sized output space. Leveraging dependencies between label variables can significantly reduce the complexity of the label space [14]. These methods can be roughly divided into three groups, namely first-order [15], second-order [16] and high-order [17] correlation approaches. Besides, since processing the original label space directly is complex or difficult, some methods pre-process it, for example by manifold learning [18], sparse reconstruction [19] or canonical correlation analysis [20]. With an optimal latent label space, label semantics can be captured more accurately, yielding a better multi-label model.

Data Augmentation. Although no previous work explicitly defines data self-augmentation, some approaches similar to MLAUG have been proposed in the past few years. For example, MLFE [19] utilizes structural information in the feature space to enrich the label space. The Label Enhancement (LE) technique converts logical labels into numerical labels, allowing labels to reflect the importance degree of each class. In addition, in multi-dimensional learning, KARM [21] and LEFA [22] adopt k-nearest neighbors (kNN) and label embedding, respectively, to enrich the original feature space. However, these methods augment only one side, either the feature space or the label space. Different from them, MLAUG is the first attempt to consider both feature and label space augmentation by exploiting similarity information within each space itself.
3 Methods

3.1 Problem Formulation
In MLL, let $\mathcal{X} = \mathbb{R}^d$ denote the $d$-dimensional input space and $\mathcal{Y} = \{y_1, y_2, \dots, y_l\}$ the set of $l$ possible labels. We denote the input data as the matrix $X = [x_1, x_2, \dots, x_n]^\top \in \mathbb{R}^{n\times d}$, where $x_i = [x_{i1}, x_{i2}, \dots, x_{id}]$ is the feature vector of the $i$-th instance. The label matrix is denoted $Y = [y_1, y_2, \dots, y_n]^\top \in \mathbb{R}^{n\times l}$, where $y_i = [y_{i1}, y_{i2}, \dots, y_{il}]$ contains the ground-truth labels of $x_i$: $y_{ij} = 1$ if the $i$-th instance is associated with the $j$-th label, and $y_{ij} = 0$ otherwise. The task of MLL is to learn a coefficient matrix $W = [w_1, w_2, \dots, w_l] \in \mathbb{R}^{d\times l}$ from the training data. In this paper, we propose a novel MLL approach named MLAUG, which combines data self-augmentation and predictive model learning in a unified framework. As mentioned above, we investigate both feature and label augmentation by mining augmentation signals from the data itself. Specifically, we learn correlation matrices in the feature and label spaces, respectively, and then apply them to refine the original spaces. In addition, we use graph Laplacian regularization to constrain the classifier coefficient matrix and the output label space to further improve classification performance.
3.2 Data Augmentation for Multi-label Learning
Feature Augmentation. The augmented properties of each instance should be reconstructed from the information carried by the other features. Based on this, we learn a feature correlation matrix $V \in \mathbb{R}^{d\times d}$ from the original features and apply it to the original feature space to achieve feature-level data augmentation. The quality of $V$ clearly has a significant effect on augmentation performance. Several previous methods, such as [23], obtain $V$ by precomputing a similarity matrix over the feature space. However, a matrix computed directly in the feature space easily introduces unnecessary noise into the augmented data. Therefore, instead of precomputing a fixed matrix, we learn the feature correlation matrix in an adaptive manner, so that the correlation matrix and the enhanced feature representation are obtained jointly. The optimization problem is defined as

$$\min_{V}\ \frac{\lambda_1}{2}\|X - XV\|_F^2 \quad \text{s.t.}\ v_{ii} = 0,\ v_{ij} \ge 0 \tag{1}$$
where $v_{ii}$ is the diagonal element of $V$, and $v_{ij}$ represents the similarity between the $i$-th and $j$-th features.

Label Augmentation. Similar to feature augmentation, we learn a label correlation matrix $U \in \mathbb{R}^{l\times l}$ in an adaptive manner and use it to achieve label augmentation. As mentioned earlier, label-level augmentation can be regarded as a kind of label smoothing, which calibrates the classifier’s overconfidence in its outputs by fusing information from the other labels. Thus, we can rewrite the framework as

$$\min_{U,V}\ \frac{\lambda_1}{2}\|X - XV\|_F^2 + \frac{\lambda_2}{2}\|Y - YU\|_F^2 \quad \text{s.t.}\ v_{ii} = 0,\ v_{ij} \ge 0,\ u_{ii} = 0,\ u_{ij} \ge 0 \tag{2}$$
Train Classifier with Augmented Spaces. With the augmented feature and label spaces, we train a classifier with coefficient matrix $W$. In addition, we constrain $W$ with an $\ell_1$-norm regularizer to learn sparse label-specific features, so that features irrelevant to each label are removed. This further improves the performance of the multi-label classifier and reduces the model size. The objective function above can then be rewritten as

$$\min_{W,U,V}\ \frac{1}{2}\|XVW - YU\|_F^2 + \frac{\lambda_1}{2}\|X - XV\|_F^2 + \frac{\lambda_2}{2}\|Y - YU\|_F^2 + \lambda_5\|W\|_1 \quad \text{s.t.}\ v_{ii} = 0,\ v_{ij} \ge 0,\ u_{ii} = 0,\ u_{ij} \ge 0 \tag{3}$$

where $u_{ii}$ is the diagonal element of $U$, and $u_{ij}$ represents the similarity between the $i$-th and $j$-th label vectors.
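As a concrete reading of Eq. (3), the NumPy sketch below evaluates the objective and projects a candidate correlation matrix onto the constraint set (zero diagonal, non-negative entries). The helper names are our own; the code only evaluates the objective and does not optimize it.

```python
import numpy as np

def project_correlation(M):
    """Project onto the constraint set of Eqs. (1)-(3): non-negative entries, zero diagonal."""
    P = np.maximum(M, 0.0)      # new array, so the input is left untouched
    np.fill_diagonal(P, 0.0)
    return P

def objective_eq3(X, Y, V, U, W, lam1, lam2, lam5):
    """Value of Eq. (3); shapes: X (n, d), Y (n, l), V (d, d), U (l, l), W (d, l)."""
    XV = X @ V  # augmented features: each feature rebuilt from the other features
    YU = Y @ U  # augmented labels: each label rebuilt from the other labels
    fit = 0.5 * np.linalg.norm(XV @ W - YU, "fro") ** 2
    feat = 0.5 * lam1 * np.linalg.norm(X - XV, "fro") ** 2
    lab = 0.5 * lam2 * np.linalg.norm(Y - YU, "fro") ** 2
    return fit + feat + lab + lam5 * np.abs(W).sum()

rng = np.random.default_rng(0)
X = rng.standard_normal((6, 4))
Y = (rng.random((6, 3)) > 0.5).astype(float)
V = project_correlation(rng.standard_normal((4, 4)))
U = project_correlation(rng.standard_normal((3, 3)))
W = rng.standard_normal((4, 3))
print(objective_eq3(X, Y, V, U, W, lam1=4.0, lam2=4.0, lam5=0.5))
```

The zero diagonal is what forces each column of $XV$ (and $YU$) to be assembled from the other features (labels) rather than copied from itself.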
3.3 Graph Laplacian Regularization
To further improve multi-label classification performance, we introduce a feature graph Laplacian regularization to guarantee the discriminability of the classifier, and impose a label graph Laplacian regularization on the output label space to capture relationships between labels more adequately. For feature regularization, we assume that if features $f_i$ and $f_j$ are strongly correlated, the corresponding rows of the coefficient matrix $W$ should also be similar. Thus, we further rewrite the objective function as

$$\min_{W,U,V}\ \frac{1}{2}\|XVW - YU\|_F^2 + \frac{\lambda_1}{2}\|X - XV\|_F^2 + \frac{\lambda_2}{2}\|Y - YU\|_F^2 + \frac{\lambda_3}{2}\operatorname{tr}\left(W^\top L_1 W\right) + \lambda_5\|W\|_1 \quad \text{s.t.}\ v_{ii} = 0,\ v_{ij} \ge 0,\ u_{ii} = 0,\ u_{ij} \ge 0 \tag{4}$$

where $L_1 = D - V \in \mathbb{R}^{d\times d}$ is the graph Laplacian matrix of $V$ and $D$ is the diagonal degree matrix of $V$. For label regularization, we adopt the assumption that if two labels $y_i$ and $y_j$ are strongly correlated, their corresponding predictors $XVw_i$ and $XVw_j$ should produce similar outputs. In this case, the complex relationship between features and labels can be explored implicitly with

$$\min\ \sum_{i,j=1}^{l} U_{ij}\,\|XVw_i - XVw_j\|_2^2 \tag{5}$$
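where $L_2$ is the graph Laplacian matrix of $U$, so that this term can be expressed through $\operatorname{tr}\left((XVW)L_2(XVW)^\top\right)$. After adding this regularizer, the final optimization problem can be written as

$$\begin{aligned}\min_{W,U,V}\ & \frac{1}{2}\|XVW - YU\|_F^2 + \frac{\lambda_1}{2}\|X - XV\|_F^2 + \frac{\lambda_2}{2}\|Y - YU\|_F^2 \\ & + \frac{\lambda_3}{2}\operatorname{tr}\left(W^\top L_1 W\right) + \frac{\lambda_4}{2}\operatorname{tr}\left((XVW)L_2(XVW)^\top\right) + \lambda_5\|W\|_1 \\ \text{s.t.}\ & v_{ii} = 0,\ v_{ij} \ge 0,\ u_{ii} = 0,\ u_{ij} \ge 0\end{aligned} \tag{6}$$

3.4 Optimization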
An alternating optimization strategy is adopted to solve the above objective function; the details are as follows.

Updating Feature Correlation V. When the model parameters $W$ and $U$ are fixed, the problem can be rewritten as

$$\min_{V}\ \frac{1}{2}\|XVW - YU\|_F^2 + \frac{\lambda_1}{2}\|X - XV\|_F^2 + \frac{\lambda_4}{2}\operatorname{tr}\left(XVWL_2(XVW)^\top\right) \quad \text{s.t.}\ v_{ii} = 0,\ v_{ij} \ge 0 \tag{7}$$

We use projected gradient descent [24] to optimize this objective, projecting each iterate back onto the feasible set so that the constraints $v_{ii} = 0$ and $v_{ij} \ge 0$ remain satisfied. The gradient of the objective with respect to $V$ is

$$\nabla f(V) = X^\top XV\left(WW^\top + \lambda_1 I + \lambda_4 WL_2W^\top\right) - \lambda_1 X^\top X - X^\top YUW^\top \tag{8}$$
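The sketch below implements one projected-gradient step for the $V$-subproblem (7). It writes the inner factor as $WW^\top + \lambda_1 I + \lambda_4 WL_2W^\top$, the $d\times d$ form that is shape-compatible with $X^\top XV$, and assumes $L_2$ is symmetric (true for the Laplacian of a symmetric $U$), in which case this is the exact gradient of (7). Function names and the step-size rule are our own illustrative choices, not taken from the paper.

```python
import numpy as np

def objective_eq7(X, Y, V, U, W, L2, lam1, lam4):
    """Smooth V-subproblem of Eq. (7)."""
    F = X @ V @ W
    return (0.5 * np.linalg.norm(F - Y @ U, "fro") ** 2
            + 0.5 * lam1 * np.linalg.norm(X - X @ V, "fro") ** 2
            + 0.5 * lam4 * np.trace(F @ L2 @ F.T))

def grad_V(X, Y, V, U, W, L2, lam1, lam4):
    """Gradient of Eq. (7) w.r.t. V, assuming a symmetric L2."""
    XtX = X.T @ X
    inner = W @ W.T + lam1 * np.eye(V.shape[0]) + lam4 * W @ L2 @ W.T  # d x d
    return XtX @ V @ inner - lam1 * XtX - X.T @ Y @ U @ W.T

def projected_step_V(X, Y, V, U, W, L2, lam1, lam4, step):
    """Gradient step followed by projection onto {v_ii = 0, v_ij >= 0}."""
    V_new = V - step * grad_V(X, Y, V, U, W, L2, lam1, lam4)
    V_new = np.maximum(V_new, 0.0)
    np.fill_diagonal(V_new, 0.0)
    return V_new
```

Because (7) is a convex quadratic in $V$, a step of $1/(\|X^\top X\|_2\,\|WW^\top + \lambda_1 I + \lambda_4 WL_2W^\top\|_2)$ guarantees that each projected step does not increase the objective.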
Updating Label Correlation U. When the model parameters $W$ and $V$ are fixed, the problem can be rewritten as

$$\min_{U}\ \frac{1}{2}\|XVW - YU\|_F^2 + \frac{\lambda_2}{2}\|Y - YU\|_F^2 \quad \text{s.t.}\ u_{ii} = 0,\ u_{ij} \ge 0 \tag{9}$$

As for $V$, we use projected gradient descent to optimize this objective while keeping the constraints on $U$ satisfied. The gradient of the objective with respect to $U$ is

$$\nabla f(U) = Y^\top Y\left((1 + \lambda_2)U - \lambda_2 I\right) - Y^\top XVW \tag{10}$$
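Differentiating Eq. (9) directly gives $\nabla f(U) = Y^\top Y\left((1+\lambda_2)U - \lambda_2 I\right) - Y^\top XVW$, with the factor $(1+\lambda_2)$ multiplying $U$; this is also the form implied by the $U_0$ initialization in Algorithm 1. The sketch below codes this gradient and checks it against a central finite difference; function names and test sizes are illustrative assumptions.

```python
import numpy as np

def objective_eq9(X, Y, V, U, W, lam2):
    """Smooth U-subproblem of Eq. (9)."""
    return (0.5 * np.linalg.norm(X @ V @ W - Y @ U, "fro") ** 2
            + 0.5 * lam2 * np.linalg.norm(Y - Y @ U, "fro") ** 2)

def grad_U(X, Y, V, U, W, lam2):
    """Gradient of Eq. (9) w.r.t. U."""
    YtY = Y.T @ Y
    return YtY @ ((1.0 + lam2) * U - lam2 * np.eye(U.shape[0])) - Y.T @ X @ V @ W

def fd_check(X, Y, V, U, W, lam2, eps=1e-6, seed=0):
    """Return (central difference, analytic directional derivative) along a random E."""
    E = np.random.default_rng(seed).standard_normal(U.shape)
    fd = (objective_eq9(X, Y, V, U + eps * E, W, lam2)
          - objective_eq9(X, Y, V, U - eps * E, W, lam2)) / (2 * eps)
    return fd, float(np.sum(grad_U(X, Y, V, U, W, lam2) * E))
```

The projected step itself is the same as for $V$: subtract a multiple of the gradient, clip negatives to zero and reset the diagonal.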
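Updating Classifier Parameter W. When the correlation matrices $V$ and $U$ are fixed, the objective function becomes

$$\min_{W}\ \frac{1}{2}\|XVW - YU\|_F^2 + \frac{\lambda_3}{2}\operatorname{tr}\left(W^\top L_1 W\right) + \frac{\lambda_4}{2}\operatorname{tr}\left(XVWL_2(XVW)^\top\right) + \lambda_5\|W\|_1$$

The gradient of the smooth part with respect to $W$ can then be calculated as

$$\nabla f(W) = V^\top X^\top\left(XVW - YU + \lambda_4 XVWL_2\right) + \lambda_3 L_1 W \tag{11}$$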
The $\ell_1$-norm regularization on $W$ can be handled by the element-wise soft-thresholding operator. Following the proximal gradient descent algorithm [25], $W$ is updated by

$$W_{t+1} = \operatorname{prox}_{\lambda_5/L_f}\left(W^{(t)} - \frac{1}{L_f}\nabla f\left(W^{(t)}\right)\right) \tag{12}$$

where $W^{(t)} = W_t + \frac{b_{t-1} - 1}{b_t}\left(W_t - W_{t-1}\right)$ and $L_f$ is the Lipschitz constant, an upper bound of which is given in Theorem 1. The sequence $b_t$ should satisfy $b_{t+1}^2 - b_{t+1} \le b_t^2$, and $\operatorname{prox}_{\epsilon}(\cdot)$ is the element-wise operator defined as

$$\operatorname{prox}_{\epsilon}(a) = \operatorname{sign}(a)\max\left(|a| - \epsilon,\ 0\right) \tag{13}$$
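A minimal sketch of Eqs. (12) and (13): element-wise soft-thresholding plus the extrapolated (FISTA-style) step, with the momentum update $b_{t+1} = (1 + \sqrt{4b_t^2 + 1})/2$, which meets the condition $b_{t+1}^2 - b_{t+1} \le b_t^2$ with equality. Passing the smooth gradient in as a function is our own design choice for illustration.

```python
import numpy as np

def soft_threshold(A, eps):
    """prox of eps * ||.||_1 (Eq. (13)): sign(a) * max(|a| - eps, 0), element-wise."""
    return np.sign(A) * np.maximum(np.abs(A) - eps, 0.0)

def next_b(b):
    """Momentum sequence; satisfies b_next**2 - b_next == b**2."""
    return (1.0 + np.sqrt(4.0 * b * b + 1.0)) / 2.0

def prox_grad_step(W_t, W_prev, b_t, b_prev, grad_f, L_f, lam5):
    """One update of Eq. (12), given the gradient grad_f(W) of the smooth part."""
    W_bar = W_t + ((b_prev - 1.0) / b_t) * (W_t - W_prev)  # extrapolated point W^(t)
    return soft_threshold(W_bar - grad_f(W_bar) / L_f, lam5 / L_f)
```

With $b_0 = b_1 = 1$ the first extrapolation weight is zero, so the first update reduces to a plain proximal gradient step.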
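Theorem 1 (Lipschitz Continuous Gradient). Given two arbitrary distinct parameters $W_1$ and $W_2$, we have

$$\left\|\nabla f(W_1) - \nabla f(W_2)\right\|_F^2 \le \gamma\,\|\Delta W\|_F^2 \tag{14}$$

where $\gamma = 3\left(\|V^\top X^\top XV\|_2^2 + \|\lambda_3 L_1\|_2^2 + \|\lambda_4 V^\top X^\top XV\|_2^2 \cdot \|L_2\|_2^2\right)$ and $\Delta W = W_1 - W_2$. An approximate Lipschitz constant can then be calculated as

$$L_f = 3\left(\|V^\top X^\top XV\|_2^2 + \|\lambda_3 L_1\|_2^2 + \|\lambda_4 V^\top X^\top XV\|_2^2 \cdot \|L_2\|_2^2\right) \tag{15}$$

Proof. Given $W_1$ and $W_2$, according to Eq. (11), we have

$$\begin{aligned}\left\|\nabla f(W_1) - \nabla f(W_2)\right\|_F^2 &= \left\|V^\top X^\top XV\Delta W + \lambda_3 L_1\Delta W + \lambda_4 V^\top X^\top XV\Delta W L_2\right\|_F^2 \\ &\le 3\left(\|V^\top X^\top XV\|_2^2 + \|\lambda_3 L_1\|_2^2 + \|\lambda_4 V^\top X^\top XV\|_2^2 \cdot \|L_2\|_2^2\right)\|\Delta W\|_F^2\end{aligned} \tag{16}$$

The pseudo code of MLAUG is summarized in Algorithm 1.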
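The bound in Theorem 1 can be checked numerically. Because Eq. (14) bounds squared norms, the Lipschitz constant it implies for $\nabla f$ is $\sqrt{\gamma}$; using $\gamma$ itself, as in Eq. (15), is guaranteed to upper-bound that constant only when $\gamma \ge 1$. The sketch below, with our own helper names and random test matrices, evaluates $\gamma$ and compares it with the gradient difference obtained from Eq. (11).

```python
import numpy as np

def grad_W_smooth(X, Y, V, U, W, L1, L2, lam3, lam4):
    """Gradient of the smooth part of the W-subproblem, Eq. (11)."""
    XV = X @ V
    return XV.T @ (XV @ W - Y @ U + lam4 * XV @ W @ L2) + lam3 * L1 @ W

def gamma_bound(X, V, L1, L2, lam3, lam4):
    """gamma of Theorem 1, using the spectral norm for ||.||_2."""
    a = np.linalg.norm(V.T @ X.T @ X @ V, 2)
    return 3.0 * (a ** 2
                  + (lam3 * np.linalg.norm(L1, 2)) ** 2
                  + (lam4 * a * np.linalg.norm(L2, 2)) ** 2)
```

In an implementation, `np.sqrt(gamma_bound(...))` could then serve as $L_f$ in Eq. (12).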
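Algorithm 1: Optimization of MLAUG
Require: Training data matrix $X \in \mathbb{R}^{n\times d}$, label matrix $Y \in \mathbb{R}^{n\times l}$, and weighting parameters $\lambda_1, \lambda_2, \lambda_3, \lambda_4, \lambda_5, \gamma$
Ensure: Feature correlation matrix $V \in \mathbb{R}^{d\times d}$, label correlation matrix $U \in \mathbb{R}^{l\times l}$, coefficient matrix $W \in \mathbb{R}^{d\times l}$
Initialization: $b_0, b_1 \leftarrow 1$; $W_0, W_1 \leftarrow (X^\top X + \gamma I)^{-1}X^\top Y$; $V_0 \leftarrow (X^\top X + \gamma I)^{-1}(X^\top YW^\top + \lambda_1 X^\top X)(WW^\top + \lambda_1 I + \lambda_4 WL_2W^\top)^{-1}$; $U_0 \leftarrow ((1 + \lambda_2)Y^\top Y + \gamma I)^{-1}(\lambda_2 Y^\top Y + Y^\top XVW)$
1: while stop criterion not reached do
2:   Update $V_t$ with gradient Eq. (8);
3:   Update $U_t$ with gradient Eq. (10);
4:   Calculate Lipschitz constant $L_f$ by Eq. (15);
5:   $W^{(t)} \leftarrow W_t + \frac{b_{t-1} - 1}{b_t}(W_t - W_{t-1})$;
6:   Update $W_{t+1}$ by Eq. (12);
7:   $b_{t+1} \leftarrow \frac{1 + \sqrt{4b_t^2 + 1}}{2}$;
8: end while
9: return $W$, $U$, $V$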
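The loop below is a compact NumPy sketch of Algorithm 1, not the authors' implementation. Several choices are our own assumptions: $V$ and $U$ start from constant off-diagonal matrices instead of the closed-form $V_0$ and $U_0$; each iteration takes a single projected-gradient step on $V$ and $U$ with a $1/L$ step size; both correlation matrices are symmetrized before their Laplacians are formed; $L_f$ is taken as $\sqrt{\gamma}$ from Theorem 1; and $W$ is updated explicitly with Eq. (12) inside the loop. The hyperparameter defaults are placeholders.

```python
import numpy as np

def laplacian(S):
    """Graph Laplacian D - S of the symmetrized matrix (S + S^T) / 2 (our assumption)."""
    S = 0.5 * (S + S.T)
    return np.diag(S.sum(axis=1)) - S

def project(M):
    """Constraint set of Eq. (6): non-negative entries, zero diagonal."""
    M = np.maximum(M, 0.0)
    np.fill_diagonal(M, 0.0)
    return M

def soft_threshold(A, eps):
    return np.sign(A) * np.maximum(np.abs(A) - eps, 0.0)

def mlaug_sketch(X, Y, lam1=4.0, lam2=4.0, lam3=0.5, lam4=0.5, lam5=0.5,
                 gamma=1.0, n_iter=50):
    d, l = X.shape[1], Y.shape[1]
    XtX, YtY = X.T @ X, Y.T @ Y
    # Ridge initialization of W as in Algorithm 1; V and U start from a
    # constant off-diagonal value (a simplification of V0 and U0).
    W = np.linalg.solve(XtX + gamma * np.eye(d), X.T @ Y)
    W_prev = W.copy()
    V = project(np.full((d, d), 1.0 / d))
    U = project(np.full((l, l), 1.0 / l))
    b_prev, b = 1.0, 1.0
    for _ in range(n_iter):
        # V step: projected gradient on Eq. (7) with the gradient of Eq. (8).
        L2 = laplacian(U)
        inner = W @ W.T + lam1 * np.eye(d) + lam4 * W @ L2 @ W.T
        gV = XtX @ V @ inner - lam1 * XtX - X.T @ Y @ U @ W.T
        V = project(V - gV / (np.linalg.norm(XtX, 2) * np.linalg.norm(inner, 2) + 1e-12))
        # U step: projected gradient on Eq. (9) with the gradient of Eq. (10).
        gU = YtY @ ((1.0 + lam2) * U - lam2 * np.eye(l)) - Y.T @ X @ V @ W
        U = project(U - gU / ((1.0 + lam2) * np.linalg.norm(YtY, 2) + 1e-12))
        # W step: accelerated proximal gradient, Eqs. (11)-(13).
        L1, L2 = laplacian(V), laplacian(U)
        XV = X @ V
        a = np.linalg.norm(XV.T @ XV, 2)
        gam = 3.0 * (a ** 2 + (lam3 * np.linalg.norm(L1, 2)) ** 2
                     + (lam4 * a * np.linalg.norm(L2, 2)) ** 2)
        Lf = np.sqrt(gam) + 1e-12  # sqrt of Theorem 1's gamma (our choice)
        W_bar = W + ((b_prev - 1.0) / b) * (W - W_prev)
        gW = XV.T @ (XV @ W_bar - Y @ U + lam4 * XV @ W_bar @ L2) + lam3 * L1 @ W_bar
        W_prev, W = W, soft_threshold(W_bar - gW / Lf, lam5 / Lf)
        b_prev, b = b, (1.0 + np.sqrt(4.0 * b * b + 1.0)) / 2.0
    return W, U, V
```

Label scores for new instances would then be `X_test @ V @ W`, ranked or thresholded as each evaluation metric requires.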
4 Experiments

4.1 Datasets and Experiment Settings
We conduct experiments on eight benchmark MLL datasets, namely Arts, Science, Enron, Health, Corel5k, Bibtex, Rcv1 and Rcv2, which can be found in the Mulan¹ and LAMDA² repositories. Five widely used multi-label metrics are employed: Ranking Loss, One-Error, Coverage, Average Precision and Micro-F1. To highlight the performance of MLAUG, we compare it with six well-established MLL algorithms: MLKNN [1], LLSF [26], GLOCAL [8], LSF-CI [5], LSML [29] and CLML [23]. For all these methods, the hyperparameters are set according to the suggestions in the original literature. The parameters of MLAUG are set to $\lambda_1 = 2^2$ and $\lambda_2 = 2^2$, and $\lambda_3, \lambda_4, \lambda_5$ are chosen from $\{2^{-1}, 2^1, 2^2, 2^3\}$. We perform five-fold cross-validation and report the averaged results.

4.2 Experimental Results
In this section, we report the results of all comparing methods on the five performance measures mentioned above. The results on Ranking Loss, One-Error and Coverage, for which lower is better, are reported in Table 1, while the results on Average Precision and Micro-F1, for which higher is better, are reported in Table 3. From the experimental results, we make the following observations.
¹ http://mulan.sourceforge.net/
² http://www.lamda.nju.edu.cn/CH.Data.ashx
Table 1. Experimental results on eight multi-label datasets in terms of Ranking Loss, One-Error and Coverage, where •/◦ indicates whether MLAUG is superior/inferior to the compared method on each dataset.

Ranking Loss (the lower the better)
Datasets  MLKNN    LLSF     GLOCAL   LSF-CI   LSML     CLML     MLAUG
Arts      0.148 •  0.172 •  0.129 •  0.163 •  0.132 •  0.133 •  0.127
Science   0.114 •  0.163 •  0.121 •  0.149 •  0.123 •  0.117 •  0.111
Enron     0.093 •  0.120 •  0.085 •  0.087 •  0.086 •  0.102 •  0.080
Health    0.060 •  0.076 •  0.068 •  0.081 •  0.058 •  0.069 •  0.051
Corel5k   0.135 •  0.204 •  0.163 •  0.193 •  0.166 •  0.180 •  0.129
Bibtex    0.212 •  0.083 •  0.151 •  0.071 ◦  0.063 ◦  0.068 ◦  0.074
Rcv1      0.097 •  0.063 •  0.057 •  0.053 •  0.050 •  0.052 •  0.045
Rcv2      0.095 •  0.064 •  0.062 •  0.054 •  0.049 •  0.048 •  0.046

One-Error (the lower the better)
Datasets  MLKNN    LLSF     GLOCAL   LSF-CI   LSML     CLML     MLAUG
Arts      0.613 •  0.460 •  0.451 •  0.461 •  0.452 •  0.455 •  0.444
Science   0.557 •  0.495 •  0.497 •  0.489 •  0.483 •  0.478 ◦  0.478
Enron     0.304 •  0.280 •  0.222 •  0.243 •  0.225 •  0.256 •  0.213
Health    0.384 •  0.261 •  0.258 •  0.257 •  0.253 •  0.257 •  0.250
Corel5k   0.737 •  0.648 •  0.671 •  0.647 •  0.640 •  0.650 •  0.635
Bibtex    0.585 •  0.353 •  0.584 •  0.369 •  0.367 •  0.347 ◦  0.351
Rcv1      0.543 •  0.429 •  0.440 •  0.417 ◦  0.413 ◦  0.419 •  0.419
Rcv2      0.566 •  0.411 •  0.434 •  0.424 •  0.406 •  0.423 •  0.402

Coverage (the lower the better)
Datasets  MLKNN    LLSF     GLOCAL   LSF-CI   LSML     CLML     MLAUG
Arts      0.206 •  0.248 •  0.199 •  0.236 •  0.203 •  0.204 •  0.197
Science   0.149 ◦  0.213 •  0.167 •  0.196 •  0.170 •  0.162 •  0.155
Enron     0.249 •  0.312 •  0.241 •  0.230 ◦  0.250 •  0.280 •  0.233
Health    0.101 ◦  0.141 •  0.128 •  0.141 •  0.114 •  0.130 •  0.104
Corel5k   0.307 •  0.462 •  0.384 •  0.401 •  0.388 •  0.411 •  0.287
Bibtex    0.347 •  0.163 •  0.216 •  0.126 ◦  0.122 ◦  0.132 ◦  0.140
Rcv1      0.205 •  0.147 •  0.137 •  0.127 •  0.127 •  0.130 •  0.116
Rcv2      0.194 •  0.144 •  0.141 •  0.123 •  0.122 •  0.115 ◦  0.115
– Across all datasets, MLAUG achieves the best performance in 94.6%, 92.8%, 87.5%, 96.4% and 100% of cases in terms of Ranking Loss, One-Error, Coverage, Average Precision and Micro-F1, respectively. Although CLML performs well in terms of One-Error and Coverage, its performance on the other evaluation metrics is less competitive. These clear advantages indicate that the feature augmentation of MLAUG brings useful discriminative information into the feature space; this information can be regarded as modeling the dependencies among different features. On the other hand, the label augmentation of MLAUG effectively improves the distribution of the label space, yielding a smoother label space that further facilitates the mapping from the feature space to the label space.
– Except on the Bibtex dataset, MLAUG significantly outperforms LLSF, LSF-CI, LSML and CLML on all evaluation metrics. The reason is that all these algorithms, including MLAUG, learn label-specific features during model induction.
However, the other methods learn label-specific associations in the raw feature and label spaces, while MLAUG does so in a semantically richer space produced by data self-augmentation, which can provide more accurate label-specific features. This demonstrates that the proposed data self-augmentation effectively improves the data distribution and can be learned collaboratively with the other parts of the model.

Furthermore, the Friedman test [28] is employed as the statistical test to evaluate the relative performance of the comparing methods. Assume that there are $K$ algorithms and $N$ datasets, and let $r_{ij}$ denote the rank of the $j$-th algorithm on the $i$-th dataset. The average rank of the $j$-th algorithm is $R_j = \frac{1}{N}\sum_i r_{ij}$. Under the null hypothesis, which states that all algorithms have equivalent performance, the Friedman statistic $F_F$ follows the $F$-distribution with $K - 1$ and $(K - 1)(N - 1)$ degrees of freedom:

$$F_F = \frac{(N - 1)\chi_F^2}{N(K - 1) - \chi_F^2}, \qquad \chi_F^2 = \frac{12N}{K(K + 1)}\left(\sum_j R_j^2 - \frac{K(K + 1)^2}{4}\right)$$

The Friedman statistic $F_F$ of each evaluation metric and the critical value are summarized in Table 2. For each evaluation metric, the null hypothesis is rejected at significance level $\alpha = 0.05$, which indicates that there are significant differences among the comparing algorithms.

Table 2. Friedman statistics F_F (K = 7, N = 8) and the critical value of each evaluation metric

Metrics            F_F      Critical value
Ranking Loss       8.602    2.324
One Error          14.553   2.324
Coverage           4.424    2.324
Average Precision  27.273   2.324
Micro F1           33.205   2.324

Table 3. Experimental results on eight multi-label datasets in terms of Average Precision and Micro-F1, where •/◦ indicates whether MLAUG is superior/inferior to the compared method on each dataset.

Average Precision (the higher the better)
Datasets  MLKNN    LLSF     GLOCAL   LSF-CI   LSML     CLML     MLAUG
Arts      0.525 •  0.606 •  0.629 •  0.612 •  0.628 •  0.626 •  0.634
Science   0.550 •  0.579 •  0.597 •  0.592 •  0.607 •  0.612 •  0.614
Enron     0.633 •  0.650 •  0.707 •  0.683 •  0.709 ◦  0.677 •  0.708
Health    0.696 •  0.773 •  0.779 •  0.762 •  0.788 •  0.778 •  0.793
Corel5k   0.244 •  0.295 •  0.261 •  0.266 •  0.309 •  0.301 •  0.311
Bibtex    0.349 •  0.601 •  0.367 •  0.582 •  0.602 •  0.616 ◦  0.607
Rcv1      0.485 •  0.598 •  0.576 •  0.608 •  0.618 •  0.613 •  0.619
Rcv2      0.495 •  0.623 •  0.584 •  0.623 •  0.639 •  0.639 •  0.642

Micro-F1 (the higher the better)
Datasets  MLKNN    LLSF     GLOCAL   LSF-CI   LSML     CLML     MLAUG
Arts      0.145 •  0.427 •  0.415 •  0.441 •  0.432 •  0.417 •  0.465
Science   0.277 •  0.430 •  0.412 •  0.443 •  0.445 •  0.432 •  0.464
Enron     0.466 •  0.556 •  0.609 •  0.546 •  0.616 •  0.586 •  0.617
Health    0.361 •  0.594 •  0.567 •  0.603 •  0.642 •  0.628 •  0.655
Corel5k   0.028 •  0.095 •  0.006 •  0.201 •  0.286 •  0.173 •  0.289
Bibtex    0.218 •  0.389 •  0.165 •  0.402 •  0.490 •  0.500 •  0.500
Rcv1      0.181 •  0.497 •  0.388 •  0.505 •  0.494 •  0.481 •  0.531
Rcv2      0.198 •  0.480 •  0.378 •  0.490 •  0.491 •  0.481 •  0.513
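The Friedman statistic described above can be computed from a score table such as Table 1 or Table 3. The sketch below ranks the $K$ algorithms on each of the $N$ datasets (ties receive average ranks) and returns $\chi_F^2$ and $F_F$. The helper names are hypothetical, and we do not claim it reproduces the exact values in Table 2, since the tie handling used in the original evaluation is not specified.

```python
import numpy as np

def average_ranks(scores, lower_is_better=True):
    """Rank algorithms (columns) on each dataset (row); ties share the average rank."""
    vals = np.asarray(scores, dtype=float)
    if not lower_is_better:
        vals = -vals
    ranks = np.empty_like(vals)
    for i, row in enumerate(vals):
        order = np.argsort(row, kind="mergesort")
        r = np.empty(len(row))
        r[order] = np.arange(1, len(row) + 1)
        for v in np.unique(row):
            tied = row == v
            r[tied] = r[tied].mean()
        ranks[i] = r
    return ranks

def friedman_ff(ranks):
    """Friedman chi^2_F and the F-distributed statistic F_F from an N x K rank matrix."""
    N, K = ranks.shape
    R = ranks.mean(axis=0)  # average rank R_j of each algorithm
    chi2 = 12.0 * N / (K * (K + 1)) * (np.sum(R ** 2) - K * (K + 1) ** 2 / 4.0)
    denom = N * (K - 1) - chi2
    ff = np.inf if denom <= 1e-12 else (N - 1) * chi2 / denom
    return chi2, ff
```

The critical value can be obtained with `scipy.stats.f.ppf(0.95, K - 1, (K - 1) * (N - 1))`, which for K = 7 and N = 8 should be close to the 2.324 listed in Table 2.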
Finally, the Nemenyi test [28] is employed to determine whether MLAUG achieves significantly better performance than the other competing algorithms. Here, MLAUG is regarded as the control method, and its average rank difference against each comparing algorithm is calibrated with the critical difference (CD):

$$CD = q_\alpha\sqrt{\frac{K(K + 1)}{6N}}$$

where the critical value $q_\alpha = 2.948$ at significance level $\alpha = 0.05$. Accordingly, MLAUG is deemed to perform significantly differently from a comparing algorithm if their average ranks differ by at least one CD (CD = 3.1842 with K = 7 and N = 8 in our experiments). The CD diagrams for each evaluation metric are shown in Fig. 1, where the average rank of each comparing algorithm is marked along the axis (lower ranks to the right). In each sub-figure, if two algorithms are not connected, their average ranks differ by at least one CD and their performances differ significantly. It can be observed that MLAUG achieves the best (lowest) average rank among the comparing methods and outperforms all other comparing methods by at least one CD in terms of all evaluation metrics. These results demonstrate that the superiority of MLAUG is statistically significant.
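The critical difference is a one-line computation. This sketch (standard library only, with hypothetical helper names) evaluates it for the setting used here, K = 7 algorithms and N = 8 datasets, and applies the "at least one CD" rule to a pair of average ranks.

```python
import math

def nemenyi_cd(q_alpha, K, N):
    """Critical difference of the Nemenyi test: q_alpha * sqrt(K (K + 1) / (6 N))."""
    return q_alpha * math.sqrt(K * (K + 1) / (6.0 * N))

def significantly_different(rank_a, rank_b, cd):
    """Two average ranks differ significantly if they are at least one CD apart."""
    return abs(rank_a - rank_b) >= cd

cd = nemenyi_cd(2.948, K=7, N=8)
print(f"CD = {cd:.4f}")  # CD = 3.1842
```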
Fig. 1. Comparison of MLAUG against six competing algorithms with the Nemenyi test. Algorithms not connected with MLAUG in the CD diagram perform significantly differently from the control algorithm (CD = 3.1842 at the 0.05 significance level).
4.3 Parameter Sensitivity Analysis
In this section, we study the influence of three trade-off parameters of the proposed approach, λ1, λ2 and λ3, on the Bibtex dataset. We vary one parameter at a time while keeping the other four fixed. Due to the page limit, we only show results for these three parameters, measured by Ranking Loss, One-Error, Coverage and Average Precision. According to Fig. 2, MLAUG is not sensitive to λ1, so it can safely be set within a wide range in practice. For the other parameters, the performance of MLAUG improves as λ2 increases, whereas it improves as λ3 decreases.
Fig. 2. Parameter sensitivity of MLAUG on Bibtex with respect to (a) λ1, (b) λ2 and (c) λ3; each panel reports Ranking Loss, One-Error, Coverage and Average Precision over parameter values on a horizontal axis ranging from −8 to 6.
5 Conclusion
In this paper, we propose MLAUG, a novel MLL algorithm with data self-augmentation. The major contributions are twofold: 1) a simple yet effective augmentation technique that manipulates both the feature and label spaces for MLL, providing an alternative way to learn from multi-label datasets; 2) two efficient graph Laplacian regularizations that are applied simultaneously and integrated with data augmentation and classifier training into a single collaborative framework, whose effectiveness is validated by extensive comparative studies. We hope this work draws the community's attention toward a broader use of data augmentation for MLL.
References
1. Zhang, M., Zhou, Z.: ML-KNN: a lazy learning approach to multi-label learning. Pattern Recognit. 40, 2038–2048 (2007)
2. Liu, W., Wang, H., Shen, X., Tsang, I.: The emerging trends of multi-label learning. In: IEEE TPAMI (2021)
3. Yang, B., Sun, J., Wang, T., Chen, Z.: Effective multi-label active learning for text classification. In: KDD, pp. 917–926 (2009)
4. Ji, W., Wang, R.: A multi-instance multi-label dual learning approach for video captioning. ACM Trans. Multimedia Comput. Commun. Appl. 17, 1–18 (2021)
5. Han, H., Huang, M., Zhang, Y., Yang, X., Feng, W.: Multi-label learning with label specific features using correlation information. IEEE Access 7, 11474–11484 (2019)
6. Fan, Y., Liu, J., Liu, P., Du, Y., Lan, W., Wu, S.: Manifold learning with structured subspace for multi-label feature selection. Pattern Recognit. 120, 108169 (2021)
7. Yu, T., Zhang, W.: Semisupervised multilabel learning with joint dimensionality reduction. IEEE Signal Process. Lett. 23, 795–799 (2016)
8. Zhu, Y., Kwok, J., Zhou, Z.: Multi-label learning with global and local label correlation. IEEE TKDE 30, 1081–1094 (2017)
9. Yeh, C., Wu, W., Ko, W., Wang, Y.: Learning deep latent space for multi-label classification. In: AAAI, pp. 2838–2844 (2017)
10. Chen, C., Wang, H., Liu, W., Zhao, X., Hu, T., Chen, G.: Two-stage label embedding via neural factorization machine for multi-label classification. In: AAAI, pp. 3304–3311 (2019)
11. Hang, J., Zhang, M.: Collaborative learning of label semantics and deep label-specific features for multi-label classification. In: IEEE TPAMI (2021)
12. Zhang, M., Wu, L.: LIFT: multi-label learning with label-specific features. IEEE TPAMI 37, 107–120 (2015)
13. Huang, K., Lin, H.: Cost-sensitive label embedding for multi-label classification. Mach. Learn. 106, 1725–1746 (2017)
14. Zhang, M., Zhang, K.: Multi-label learning by exploiting label dependency. In: KDD, pp. 999–1008 (2010)
15. Zhang, M.-L., Li, Y.-K., Liu, X.-Y., Geng, X.: Binary relevance for multi-label learning: an overview. Front. Comput. Sci. 12(2), 191–202 (2018). https://doi.org/10.1007/s11704-017-7031-7
16. Brinker, C., Mencía, E., Fürnkranz, J.: Graded multilabel classification by pairwise comparisons. In: ICDM, pp. 731–736 (2014)
17. Tsoumakas, G., Katakis, I., Vlahavas, I.: Random k-labelsets for multilabel classification. IEEE TKDE 23, 1079–1089 (2011)
18. Hou, P., Geng, X., Zhang, M.: Multi-label manifold learning. In: AAAI, pp. 1680–1686 (2016)
19. Zhang, Q., Zhong, Y., Zhang, M.: Feature-induced labeling information enrichment for multi-label learning. In: AAAI, vol. 32 (2018)
20. Chen, Y., Lin, H.: Feature-aware label space dimension reduction for multi-label classification. In: NeurIPS, vol. 25 (2012)
21. Jia, B., Zhang, M.: Multi-dimensional classification via kNN feature augmentation. Pattern Recognit. 106, 107423 (2020)
22. Wang, H., Chen, C., Liu, W., Chen, K., Hu, T., Chen, G.: Incorporating label embedding and feature augmentation for multi-dimensional classification. In: AAAI, pp. 6178–6185 (2020)
23. Li, J., Li, P., Hu, X., Yu, K.: Learning common and label-specific features for multi-label classification with correlation information. Pattern Recognit. 121, 108259 (2022)
24. Chen, Y., Wainwright, M.: Fast low-rank estimation by projected gradient descent: general statistical and algorithmic guarantees. arXiv preprint arXiv:1509.03025 (2015)
25. Beck, A., Teboulle, M.: A fast iterative shrinkage-thresholding algorithm for linear inverse problems. SIAM J. Imaging Sci. 2, 183–202 (2009)
26. Huang, J., Li, G., Huang, Q., Wu, X.: Learning label specific features for multilabel classification. In: ICDM, pp. 181–190 (2015)
27. Qiu, Y., Zhang, J., Zhou, J.: Improving gradient-based adversarial training for text classification by contrastive learning and auto-encoder. In: ACL/IJCNLP (Findings), pp. 1698–1707 (2021)
28. Demšar, J.: Statistical comparisons of classifiers over multiple data sets. J. Mach. Learn. Res. 7, 1–30 (2006)
29. Huang, J., Qin, F., Zheng, X., Cheng, Z., Yuan, Z., Zhang, W., Huang, Q.: Improving multi-label classification with missing labels by learning label-specific features. Inf. Sci. 492, 124–146 (2019)
MnRec: A News Recommendation Fusion Model Combining Multi-granularity Information
Laiping Cui¹, Zhenyu Yang¹ (✉), Guojing Liu², Yu Wang¹, and Kaiyang Ma¹
¹ Qilu University of Technology (Shandong Academy of Sciences), Jinan, China
² Ocean University of China, Jinan, China
https://github.com/cuixiaopi/MnRec
Abstract. Personalized news recommendation can alleviate information overload. Most current representation-based news recommendation methods learn user interest representations from users' behavior and then match them against candidate news. These methods do not consider the candidate news during user modeling: the learned user interests are matched with the candidate news only in the last step, which weakens the fine-grained matching signals (word-level relationships) between users and candidate news. Recent research has attempted to address this issue by modeling fine-grained interactions between the candidate news and each news article viewed by the user. Although interaction-based methods better grasp the semantic focus of the news and capture word-level behavioral interactions, they may fail to abstract high-level user interest representations from the news that users browse. It is therefore worthwhile to combine the two architectures effectively, so that the model can both discover detailed cues of user interest from fine-grained behavioral interactions and abstract high-level user interest representations from the browsed news. To this end, we propose MnRec, a framework that fuses multi-granularity information for news recommendation. The model integrates the two matching methods via interactive attention and representation attention. In addition, we design a granularity network module to extract multi-granularity news information and an RTCN module to model users' interests at multiple levels. Extensive experiments on the real-world news dataset MIND verify the validity of the method.
Keywords: Fine-grained · Integrates · MIND
This work was supported in part by the Shandong Province Key R&D Program (Major Science and Technology Innovation Project) under Grant 2020CXGC010102 and the National Key Research and Development Plan under Grant No. 2019YFB1404701.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 348–360, 2023.
https://doi.org/10.1007/978-981-99-1639-9_29
MnRec: A News Recommendation Fusion Model
1 Introduction
In today's era of information explosion, it is difficult for users to find the news they are interested in among the large volume published every day [3]. Personalized news recommendation can reduce information overload, improve users' news reading experience, and increase user stickiness.
User 1, historical browsed news:
d1: Texans safety Justin Reid out with shoulder injury
d2: The best and worst times to travel for the Christmas holiday
d3: Black Cat That Hilariously Interrupted NFL Game Announced as Starter for the Dallas Cowboys
User 1, current click behavior (c): Opinion: Colin Kaepernick is about to get what he deserves: a chance
User 2, historical browsed news:
d1: Early symptoms of dementia: Be aware of subtle signs
d2: The scary anxiety symptom you might be overlooking
d3: Sorry, But Beyond Burgers Are No Better Nutritionally Than A Beef Burger
User 2, current click behavior (c): 'Biggest Loser' Is Back And New Trainer Erica Lugo Has Her Own Weight-Loss Story To Share
Fig. 1. Example of user behavior logs on MSN. c is the news the user actually clicked on later. Yellow bars mark important words in the news headlines. (Color figure online)
In fact, news is multi-granular and user interests are diverse. Many existing approaches are based on representation-based matching strategies. For example, NAML and NRMS apply attention networks to extract informative aspects of the news text and learn user representations [13,15]. LSTUR uses a GRU together with user ID embeddings to model users' long-term and short-term interests [1]. UniRec infers a user embedding for ranking from the user's historical news clicks and uses it as an attention query to improve both recall and ranking [16]. Yet word-level interactions between the news clicked by a user and the candidate news can reveal clues about the user's interests from detailed matching information. Most existing methods match the modeled user interests with the candidate news only in the last step, with no earlier interaction; word-level behavioral interactions are therefore ignored, and user interest modeling may be suboptimal. Figure 1 shows the behavior of two users on MSN. From user 1's reading history, the candidate news mentioning "Colin Kaepernick" matches "Justin Reid" in the first news item user 1 clicked (a fine-grained interaction), which drives the current click. We can also infer from "NFL Game" and "Dallas Cowboys", which refer to a game and a team, that the user is likely an NFL fan.
L. Cui et al.
Unfortunately, the aggregated user vector produced by the representation-based matching strategy mixes all words in d1, d2, and d3, including noise such as "Christmas" and "travel" that is irrelevant to the current click. These fine-grained interests are thus blurred, and the ability to model user interests is reduced. To capture fine-grained matching signals between candidate news and users, researchers have proposed interaction-based matching strategies. For example, FIM extracts multi-level representations for each news item and performs fine-grained matching by convolution [11]. AMM exploits complementary information from different fields (e.g., title, summary, and body) to obtain a multi-field matching representation [19]. Interaction-based methods are good at uncovering semantic cues related to users' interests. However, they ignore the relationships between different news items clicked by the same user, which contain rich, detailed cues for inferring user preferences. From user 2's reading history, the words "dementia," "anxiety symptom," and "Nutritionally" in d1, d2, and d3 suggest that the user may be interested in health-related news, yet an interaction-based matching strategy fails to predict user 2's click on the candidate news. On the one hand, the keyword "Weight-Loss" in the candidate news is far from "dementia," "anxiety symptom," and "Nutritionally" in semantic space. On the other hand, there is no interaction between the different news items clicked by the same user, so the user's interests cannot be abstracted at a high level. The two matching strategies therefore do not subsume each other; they are complementary.
The problem we address is therefore how to build a model that supports both word-level, fine-grained behavioral interactions between the candidate news and the news viewed by the user and high-level abstract representations of the user's interests. In this paper, we propose MnRec, a news recommendation fusion method incorporating multi-granularity information. The method combines the two matching strategies through interactive attention and representation attention, overcoming the limitations of relying on a single strategy. Its advantages rest on two components: multi-granularity news representation and multi-level user interest representation. Because the user interests learned by representation-based methods are matched with the candidate news only in the last step, word-level interaction information is ignored. Our method therefore first applies interactive attention between the candidate news and the historical news browsed by the user, to capture the key semantic information and the semantic focus of the news; weighted news makes it easier to capture the key parts of a text during representation learning. We then build a granular network for feature extraction to obtain multi-granularity news representations, and use a Bi-LSTM [21] with representation attention to learn a representation for each news item. Finally, we propose a hierarchical user interest modeling framework that abstracts user interest information at a high level, where different levels correspond to different granularities of user interest. To preserve rich semantic information, we follow the idea of residual networks and connect the multi-level user interest features with the original information, obtaining the final user interest vector that is matched with the candidate news to compute the click probability.
The main contributions of this work are as follows:
– We propose a new approach that merges a representation-based matching model with an interaction-based matching model through interactive attention and representation attention, overcoming the limitations of a single model.
– We design a granularity network for multi-granularity feature extraction from news texts.
– We design the RTCN module for multi-level modeling of user interests to enhance user interest representation.
2 Our Approach
In this section, we describe MnRec, our news recommendation fusion method combining multi-granularity information, in detail. The model consists of the following main modules: a cross-attention module, a granular network module, a representation attention module, a multi-level user interest module, and a prediction module that computes the probability that a user clicks on a candidate news item. The architecture is shown in Fig. 2, and each part of the model is introduced below.
Fig. 2. The framework of our MnRec news recommendation method.
2.1 Word Embedding
The input layer of MnRec is an embedding lookup function whose embedding matrix is initialized with pretrained GloVe word vectors. A candidate news headline of T words is denoted $D_c = [w_1, w_2, \ldots, w_T]$, and the k-th news headline browsed by the user, of J words, is denoted $D_k = [w_1, w_2, \ldots, w_J]$. Through the word embedding matrix $W_e \in \mathbb{R}^{V \times D}$, the candidate news is mapped to the low-dimensional sequence $\mathbf{c} = [e_1, e_2, \ldots, e_T]$ and the k-th historical news to $\mathbf{d}_k = [e_1, e_2, \ldots, e_J]$, where V and D are the vocabulary size and the word embedding dimension, respectively.

2.2 Cross-Attention
We first apply interactive attention to weight the news, since weighted news makes it easier to grasp the key parts of a text during representation learning. Using the word embeddings of the candidate news and of the historical news viewed by the user, we compute the similarity between each pair of words in the two headlines and obtain the similarity matrix

$$M^{c,k} = \mathrm{Linear}\big(\mathbf{c}\,\mathbf{d}_k^{\top} + b\big) \in \mathbb{R}^{T \times J},$$

where b is the bias vector and $M^{c,k}_{t,j}$ denotes the similarity between the t-th word of the candidate news and the j-th word of the k-th historically viewed news. Let $a_t \in \mathbb{R}^{J}$ denote the attention weights of the t-th candidate word over the words of the viewed news, with $\sum_j a_{tj} = 1$ for all t. The attention weights are

$$a_t = \mathrm{softmax}\big(M^{c,k}_{t:}\big) \in \mathbb{R}^{J},$$

and the weighted candidate news vector is

$$\tilde{\mathbf{c}}^{c,k}_{:t} = \sum_{j} a_{tj}\, \mathbf{d}_{k,:j}.$$

Hence $\tilde{\mathbf{c}}^{c,k}$ is a candidate news representation that incorporates information from the historical news. Since the user has viewed more than one news item, we use an attention mechanism [10] to aggregate the weighted representations (one per viewed news item) into $\hat{\mathbf{c}}$. To enrich the semantic relations, we combine the aggregated vector with the original vector, obtaining a weighted candidate news representation that carries information about the user's browsing history and makes the key content easier to capture.
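A minimal NumPy sketch of the cross-attention step may help. Only the similarity matrix, the row-wise softmax, and the weighted sum follow the equations above; the rest is assumed: the Linear layer is reduced to a hypothetical scalar weight and bias on each similarity score, the attention of [10] that pools the per-click views is taken to be plain dot-product attention with a hypothetical query vector q, and "combining" the aggregated vector with the original is taken to be addition.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attend(c, d_k, w=1.0, b=0.0):
    """Re-express each candidate word through the words of one clicked headline.

    c:   (T, D) candidate-news word embeddings
    d_k: (J, D) word embeddings of the k-th clicked news
    w, b: hypothetical scalar stand-ins for the Linear layer
    """
    M = w * (c @ d_k.T) + b        # (T, J) word-pair similarity matrix M^{c,k}
    a = softmax(M, axis=1)         # row a_t sums to 1 over the J clicked-news words
    return a @ d_k                 # (T, D): row t = sum_j a_tj * d_k[j]


def aggregate(c, history, q):
    """Pool the per-click views with dot-product attention, then add c back."""
    views = np.stack([cross_attend(c, d_k) for d_k in history])  # (K, T, D)
    scores = views.mean(axis=1) @ q                               # (K,) one score per clicked news
    beta = softmax(scores)                                        # attention over clicked news
    c_hat = np.tensordot(beta, views, axes=1)                     # (T, D) weighted sum of views
    return c + c_hat   # assumption: "combine with the original" taken as addition


rng = np.random.default_rng(0)
T, J, D = 5, 7, 8
c = rng.normal(size=(T, D))
history = [rng.normal(size=(J, D)) for _ in range(3)]   # three clicked headlines
q = rng.normal(size=D)                                   # hypothetical attention query
print(aggregate(c, history, q).shape)                    # (5, 8)
```

With a single clicked headline the attention over clicks is trivially 1, so the output reduces to c plus one cross-attended view.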
2.3 Granular Network
The ResNeXt [18] network for image feature extraction not only alleviates the degradation problem of deep neural networks but also aids convergence. We introduce this residual architecture into the recommendation task to extract multi-granularity features from news text. The triple in each layer of the network gives the input feature dimension, the convolution kernel size, and the output feature dimension. Instead of the two-dimensional convolutions of the residual network, MnRec uses one-dimensional convolutions with different windows to extract multi-granularity features from news text; we call this feature extraction network the granular network. It applies grouped convolutions with three different windows to the original text, and each group extracts features at a different granularity, capturing more textual information. The multi-granularity network is defined as follows:

$$\mathbf{c}^{i}_{ori} = \sigma\big(W^{i}_{ori}\,\hat{\mathbf{c}} + b^{i}\big) \qquad (1)$$
$$\mathbf{c}^{i}_{mul} = \sigma\big(W^{i}_{mul}\,\mathbf{c}^{i}_{ori} + b^{i}\big) \qquad (2)$$
$$\mathbf{c}_{fin} = \mathrm{concat}\big(\hat{\mathbf{c}}, \mathbf{c}^{1}_{mul}, \mathbf{c}^{2}_{mul}, \mathbf{c}^{3}_{mul}\big) \qquad (3)$$

where $i \in \{1, 2, 3\}$; $\mathbf{c}^{i}_{ori}$ is the representation of the original text produced by the first layer of the granular network after information extraction and dimension reduction, and $\mathbf{c}^{i}_{mul}$ is the result of granular information extraction and dimension expansion in the second layer. $W^{i}_{mul}$ is the granularity sliding window: as it slides over the text sequence, granularity information is gradually extracted. σ is the ReLU activation function and $b^{i}$ are trainable parameters. The model then combines the text information of all granularities into a multi-granularity news representation. In the final step, the residual concatenation used in ResNeXt joins the extracted multi-granularity features with the original information to obtain the final news representation. Similarly, we obtain the representation $\mathbf{d}_{fin}$ of each historically browsed news item.
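The following NumPy sketch illustrates Eqs. (1)-(3) under stated assumptions: each branch first reduces the dimension with a window-1 convolution and then applies a window-k convolution over the word axis; the window sizes (1, 3, 5), the reduced width 64, and the output width 128 are hypothetical, since the text only says that three windows are used.

```python
import numpy as np


def conv1d_same(x, W, b):
    """1-D convolution over the word axis with zero 'same' padding.

    x: (T, C_in); W: (k, C_in, C_out) with odd window k; b: (C_out,)
    """
    k = W.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (0, 0)))
    out = np.stack([np.einsum("kc,kco->o", xp[t:t + k], W) for t in range(x.shape[0])])
    return out + b


def relu(x):
    return np.maximum(x, 0.0)


def init_branch(rng, d_in, d_red, d_out, window):
    """Random weights for one branch (hypothetical sizes, small Gaussian init)."""
    return (rng.normal(scale=0.1, size=(1, d_in, d_red)), np.zeros(d_red),
            rng.normal(scale=0.1, size=(window, d_red, d_out)), np.zeros(d_out))


def granular_network(c_hat, branches):
    """Eqs. (1)-(3): per branch, reduce (window 1), then a window-k conv that expands."""
    outputs = []
    for W_ori, b_ori, W_mul, b_mul in branches:
        c_ori = relu(conv1d_same(c_hat, W_ori, b_ori))   # Eq. (1): extract and reduce
        c_mul = relu(conv1d_same(c_ori, W_mul, b_mul))   # Eq. (2): granularity window
        outputs.append(c_mul)
    return np.concatenate([c_hat] + outputs, axis=1)     # Eq. (3): concat with c_hat


rng = np.random.default_rng(0)
c_hat = rng.normal(size=(20, 300))                        # 20 words, 300-d embeddings
branches = [init_branch(rng, 300, 64, 128, k) for k in (1, 3, 5)]
c_fin = granular_network(c_hat, branches)
print(c_fin.shape)   # (20, 684), i.e. 300 + 3 * 128 features per word
```

Because the branches are concatenated with $\hat{\mathbf{c}}$ rather than added to it, the feature width grows with the number of branches; this mirrors the concatenation in Eq. (3).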
2.4 Representation Layer and Representation Attention
Important information can appear anywhere in a news headline. We employ an attention-based bidirectional long short-term memory network (Bi-LSTM) to capture the most important semantic information in sentences; the Bi-LSTM learns higher-level, more abstract representations of the extracted features. The text representations are computed as follows:

$$\mathbf{c}_{bi} = \text{Bi-LSTM}(\mathbf{c}_{fin}) \qquad (4)$$
$$\mathbf{d}_{bi} = \text{Bi-LSTM}(\mathbf{d}_{fin}) \qquad (5)$$
The phrase information extracted by the granularity network varies in importance for news matching in the prediction stage, and phrases of low importance may act as noise. MnRec therefore adds a representation attention mechanism after the Bi-LSTM. A fully connected layer computes the importance of the representation at each time step relative to the global information, and this importance is used as the weight of that representation. Weighting the representations in this way lets news matching exploit highly important phrase information and suppress less important phrases. The representation attention is computed as follows:

$$c^{i}_{bi\_att} = \sigma\big(W_i\, c^{i}_{bi} + b_i\big) \qquad (6)$$
$$d^{j}_{bi\_att} = \sigma\big(W_j\, d^{j}_{bi} + b_j\big) \qquad (7)$$

where $i \in \{1, 2, \ldots, T\}$ and $j \in \{1, 2, \ldots, J\}$; T and J are the headline lengths of the candidate news and the historically browsed news, respectively. $c^{i}_{bi}$ and $d^{j}_{bi}$ are the representations at the i-th and j-th time steps of the two input texts, $W_i, W_j, b_i, b_j$ are learnable parameters, and σ is the sigmoid activation function; $c^{i}_{bi\_att}$ and $d^{j}_{bi\_att}$ are the weights of the i-th and j-th time steps of the candidate news and of the user's historically viewed news, respectively. These weights are applied to the Bi-LSTM outputs to obtain the final outputs $\mathbf{c}_{rep}$ and $\mathbf{d}_{rep}$:

$$\mathbf{c}_{rep} = \mathbf{c}_{bi\_att} \cdot \mathbf{c}_{bi} \qquad (8)$$
$$\mathbf{d}_{rep} = \mathbf{d}_{bi\_att} \cdot \mathbf{d}_{bi} \qquad (9)$$

where $\mathbf{c}_{rep}$ is the final candidate news representation and $\mathbf{d}_{rep}$ is the representation of each historical news item browsed by the user.
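A sketch of the gating in Eqs. (6)-(9), assuming the fully connected layer is shared across time steps and maps each Bi-LSTM output to one scalar importance. The Bi-LSTM itself is not implemented; random vectors stand in for its outputs.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def representation_attention(h, w, b):
    """Eqs. (6)-(9): gate every time step of a Bi-LSTM output by a weight in (0, 1).

    h: (T, H) Bi-LSTM outputs; w: (H,) and b: float are assumed to be shared by
    all time steps and to produce one scalar importance per step.
    """
    gate = sigmoid(h @ w + b)          # (T,) importance of each time step
    return gate[:, None] * h, gate     # reweighted representation and the weights


rng = np.random.default_rng(0)
h = rng.normal(size=(20, 256))         # e.g. 20 steps of a 2 x 128 Bi-LSTM
h_rep, gate = representation_attention(h, rng.normal(scale=0.1, size=256), 0.0)
print(h_rep.shape, gate.min() > 0.0, gate.max() < 1.0)
```

Unlike a softmax, the sigmoid weights do not compete with each other: several phrases can be kept at once, and the output keeps one vector per time step.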
2.5 RTCN Hierarchical Interest
User interests are multi-level and multi-granular, and modeling them hierarchically helps improve the final recommendation. We design a multi-level interest modeling framework that captures users' interests at multiple granularities in a hierarchical manner. RTCN uses a three-layer structure that applies dilated convolutions with different dilation rates to the sequence of news browsed by the user. Because the dilation rates differ, the convolution windows of each layer extract news features at different granularities, which enables the model to capture different levels of user interest. The RTCN hierarchical interest is computed as follows:

$$F(\mathbf{d}_t) = \mathrm{ReLU}\Big(F_w \cdot \bigoplus_{k=0}^{w} \mathbf{d}_{rep_{t \pm k\delta}} + b\Big) \qquad (10)$$

where the input sequence is $\mathbf{r} = [\mathbf{d}_{rep_1}, \ldots, \mathbf{d}_{rep_N}]$ (N is the length of the news sequence viewed by the user), δ is the dilation rate, ⊕ denotes vector concatenation, $F_w$ is the convolution kernel with window size $2w + 1$, and b is a trainable parameter. The input of each convolutional layer is the output of the previous layer. In our experiments, we concatenate the news representations obtained with the different dilation rates into $\mathbf{r} = [\mathbf{r}^0, \mathbf{r}^1, \mathbf{r}^2, \mathbf{r}^3]$. To enrich the user interest information, we follow the residual network idea and concatenate these with the original browsed-news vectors. We then apply an attention mechanism to the concatenated vectors to obtain the final multi-granularity user interest representation u:

$$\alpha_i = \mathrm{softmax}\big(\mathbf{r}_i^{\top} \tanh(W_u\, q + b_u)\big) \qquad (11)$$
$$\mathbf{u} = \sum_{i=1}^{N} \alpha_i\, \mathbf{r}_i \qquad (12)$$

where $W_u$ and $b_u$ are projection parameters, q is the attention query vector, and $\mathbf{r}_i$ is the representation of the i-th news item in the user's click history.
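A NumPy sketch of RTCN followed by the attention pooling of Eqs. (11)-(12). The assumptions here go beyond the text: a window half-width w = 1 (window size 3), dilation rates (1, 2, 3) as in the experimental setup, every layer keeping the feature width D, and the original clicked-news sequence used as $\mathbf{r}^0$ so that four levels are concatenated.

```python
import numpy as np


def dilated_layer(x, F, b, w=1, delta=1):
    """Eq. (10): concatenate neighbours at offsets -w*delta..w*delta, project, ReLU.

    x: (N, D) clicked-news sequence; F: ((2w+1)*D, D_out); zero padding at both ends.
    """
    N = x.shape[0]
    pad = w * delta
    xp = np.pad(x, ((pad, pad), (0, 0)))
    windows = [np.concatenate([xp[t + pad + k * delta] for k in range(-w, w + 1)])
               for t in range(N)]
    return np.maximum(np.stack(windows) @ F + b, 0.0)


def rtcn_user(d_rep, layers, W_u, b_u, q):
    """Stack dilated layers, concatenate all levels with the input, attention-pool."""
    levels, h = [d_rep], d_rep
    for F, b, delta in layers:            # each layer consumes the previous output
        h = dilated_layer(h, F, b, w=1, delta=delta)
        levels.append(h)
    r = np.concatenate(levels, axis=1)    # (N, 4D): input plus three levels
    scores = r @ np.tanh(W_u @ q + b_u)   # Eq. (11) logits
    alpha = np.exp(scores - scores.max())
    alpha /= alpha.sum()                  # softmax over the N clicked news
    return alpha @ r                      # Eq. (12): u = sum_i alpha_i r_i


rng = np.random.default_rng(0)
N, D, Dq = 50, 8, 16
d_rep = rng.normal(size=(N, D))
layers = [(rng.normal(scale=0.3, size=(3 * D, D)), np.zeros(D), delta) for delta in (1, 2, 3)]
u = rtcn_user(d_rep, layers, rng.normal(size=(4 * D, Dq)), np.zeros(4 * D), rng.normal(size=Dq))
print(u.shape)   # (32,)
```

Stacking dilations 1, 2, and 3 with window 3 lets the third level see 1 + 2·(1 + 2 + 3) = 13 consecutive clicks, which is one way to read "different levels of user interest".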
2.6 Click Predictor
The click predictor estimates the probability $\hat{y}$ that a user will click on a candidate news item [2,15,20]. The click score is the inner product of the user representation u and the candidate news representation $\mathbf{h}_c$: $\hat{y} = \mathbf{u}^{\top}\mathbf{h}_c$. Following Wu et al. [1,4,14,15,17], we train the model with negative sampling: each clicked (positive) news item is combined with K negative samples, randomly selected from news displayed in the same impression but not clicked by the user. We jointly predict the score of the positive news $\hat{y}^{+}$ and the scores of the K negative news $\hat{y}^{-}_1, \hat{y}^{-}_2, \ldots, \hat{y}^{-}_K$, normalize these scores with softmax to obtain the click probability of the positive sample, and train the model with the cross-entropy loss.
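The training objective described above is the softmax cross-entropy over one positive and K negatives, $-\log \frac{\exp(\hat{y}^{+})}{\exp(\hat{y}^{+}) + \sum_{j=1}^{K}\exp(\hat{y}^{-}_{j})}$. A minimal NumPy sketch for a single training sample (batching and the news and user encoders are omitted):

```python
import numpy as np


def click_loss(u, h_pos, h_negs):
    """Softmax cross-entropy over one clicked and K non-clicked candidates.

    u: (D,) user vector; h_pos: (D,) clicked news; h_negs: (K, D) negatives
    drawn from the same impression.
    """
    scores = np.concatenate([[u @ h_pos], h_negs @ u])   # y_hat = u^T h_c for each
    scores = scores - scores.max()                        # stabilise the exponentials
    log_probs = scores - np.log(np.exp(scores).sum())     # log-softmax
    return -log_probs[0]                                  # the positive sits at index 0


K, D = 4, 3
print(click_loss(np.zeros(D), np.ones(D), np.ones((K, D))))   # equal scores: log(K + 1)
```

When all scores are equal the loss is log(K + 1), i.e. log 5 for the sampling ratio of 4 used in the experiments; it approaches 0 as the positive score dominates.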
3 Experiments

3.1 Experiment Setup
Our experiments are conducted on MIND [17], a large public dataset whose news items contain rich textual information such as headlines, summaries, body text, categories, and entities. Since labels for the MIND-large test set are not released, test performance is obtained by submitting predictions to the MIND news recommendation competition¹. Detailed statistics of the MIND-large training set are shown in Table 1.

Table 1. Statistics of the MIND-large training dataset.
# Users                  711222     # News items          101527
# Impressions            2232748    # Positive samples    3383656
Avg. # words per title   11.52      # Negative samples    80123718
We now describe the experimental setup and hyperparameters of MnRec. The word embedding dimension is 300, and the embedding matrix is initialized with pretrained GloVe embeddings trained on a corpus of 840 billion tokens. The number of CNN filters is 256 and the window size is 3. We apply dropout with a rate of 0.2 to mitigate overfitting. The LSTM has 128 hidden units, and Adam [8] is used as the optimizer. The batch size is 200. Due to GPU memory limits, the maximum number of clicked news items used to learn user representations is set to 50 and the maximum headline length is set to 20. The dilation rates of the dilated convolutional layers are [1, 2, 3], and each dilated convolutional layer has 128 units. We train with negative sampling at a ratio of 4 (K = 4). The evaluation metrics are the mean AUC, MRR, nDCG@5, and nDCG@10 over all impressions [17]. We repeat each experiment independently 10 times and report the average performance.

¹ https://competitions.codalab.org/competitions/24122.
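For reference, the per-impression metrics can be sketched as follows. These are the standard definitions as we understand them; the official MIND scorer may differ in details such as tie handling. Each metric is computed per impression and then averaged, and AUC is undefined for an impression with no clicked or no non-clicked item.

```python
import numpy as np


def auc(y, s):
    """Fraction of (clicked, non-clicked) pairs ranked correctly; ties count 1/2."""
    pos, neg = s[y == 1], s[y == 0]
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0) + 0.5 * (diff == 0)).mean())


def mrr(y, s):
    """Reciprocal ranks of the clicked items, averaged over the clicked items."""
    order = np.argsort(-s)
    ranks = np.arange(1, len(s) + 1)
    return float((y[order] / ranks).sum() / y.sum())


def dcg(y, s, k):
    order = np.argsort(-s)[:k]
    gains = 2.0 ** y[order] - 1.0
    return float((gains / np.log2(np.arange(2, len(order) + 2))).sum())


def ndcg(y, s, k):
    return dcg(y, s, k) / dcg(y, y, k)   # normalise by the ideal ordering


# One toy impression: 1 = clicked, 0 = shown but not clicked.
impressions = [(np.array([1, 0, 0, 1]), np.array([0.9, 0.8, 0.1, 0.3]))]
for name, f in [("AUC", auc), ("MRR", mrr)]:
    print(name, np.mean([f(y, s) for y, s in impressions]))
print("nDCG@5", np.mean([ndcg(y, s, 5) for y, s in impressions]))
```

In the toy impression three of the four (clicked, non-clicked) pairs are ordered correctly, so AUC is 0.75; the clicked items sit at ranks 1 and 3, so MRR is (1 + 1/3)/2.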
3.2 Performance Evaluation
We first evaluate our method against several state-of-the-art baselines: (1) EBNR [9], which uses autoencoders to learn news representations and GRU networks to learn user representations; (2) DeepFM [5], a commonly used neural recommendation method that combines factorization machines with deep neural networks; (3) DKN [12], a knowledge-based news recommendation method; (4) NPA [14], a neural news recommendation method with a personalized attention mechanism that selects important words and news articles according to user preferences to obtain more informative news and user representations; (5) NAML [13], a neural news recommendation method with attentive multi-view learning that incorporates different types of news information into news representations; (6) NRMS [15], which uses multi-head self-attention to learn news representations from the words in news texts and user representations from previously clicked news; (7) FIM [11], which models user interest in candidate news from the semantic correlation between clicked news and candidate news and then applies a 3-D CNN; (8) GnewsRec [6], which models user interests from a user–news graph; and (9) ATRN [7], which proposes an adaptive transformer network to improve recommendation. We implement all baselines on the full MIND² dataset and report their results on the test set; Table 2 lists the average results. For DKN, NPA, NAML, NRMS, and FIM we use the official code and settings³. We re-implement the other baselines and set their parameters according to the experimental setups reported in their papers. Our code is available at https://github.com/cuixiaopi/MnRec. Table 2. The performance of different methods. MnRec significantly outperforms all baselines (p
1 per episode but consider the contribution of each task equally.
In this paper we develop the idea of adaptively using information from multiple tasks simultaneously via the multi-task approach to meta-learning proposed in [2]. Multi-task learning methods for deep neural networks fall into two main families: soft and hard parameter sharing of the hidden layers [21]. Following [2], we use hard parameter sharing for all hidden layers of a convolutional network. Thus a single neural network is used for all tasks, and the presence of several tasks is reflected only in the loss function. For this purpose, [2] adapted the approach of [12], which uses task-dependent (homoscedastic) uncertainty as a basis for weighting the losses in a multi-task learning problem. The corresponding multi-task meta-learning loss function takes the following form:

$$\mathcal{L}^{MT}_{\xi_t}\big(\boldsymbol{\omega}_t, \{Q_{t_i}\}_{i=1}^{M}\big) = \sum_{i=1}^{M} \frac{1}{\big(\omega_t^{(i)}\big)^2}\, \mathcal{L}_{\theta, t_i}(Q_{t_i}) + \sum_{i=1}^{M} \log\big(\omega_t^{(i)}\big)^2, \qquad (3)$$

where the weights $\boldsymbol{\omega}_t = (\omega_t^{(1)}, \ldots, \omega_t^{(M)})$ are hyperparameters.
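A direct transcription of Eq. (3) shows how the weights act: a larger $\omega_t^{(i)}$ shrinks the contribution of task i, while the logarithmic term penalizes making all weights large. The sketch reads the second term as $\log\big((\omega_t^{(i)})^2\big)$, following the log-variance term of the uncertainty weighting in [12]; this reading is an assumption. The task losses are hypothetical numbers.

```python
import numpy as np


def multitask_loss(task_losses, omega):
    """Eq. (3): sum_i L_i / omega_i^2 + sum_i log(omega_i^2).

    The regulariser is read as log(omega_i^2) (an assumption, as in the
    log-variance term of uncertainty weighting).
    """
    L = np.asarray(task_losses, dtype=float)
    w2 = np.asarray(omega, dtype=float) ** 2
    return float(np.sum(L / w2) + np.sum(np.log(w2)))


losses = [0.9, 1.2, 1.5, 2.4]                         # hypothetical query losses, M = 4
print(multitask_loss(losses, [1.0, 1.0, 1.0, 1.0]))   # all omega_i = 1: the plain sum
print(multitask_loss(losses, [1.0, 1.0, 1.0, 2.0]))   # task 4 down-weighted, plus a log penalty
```

With all weights equal to 1 the loss reduces to the unweighted sum of task losses, i.e. the equal-contribution baseline.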
3.3 Multi-task Weights Optimization
As shown in [12], model performance is extremely sensitive to the choice of multi-task weights, and tuning $\boldsymbol{\omega}_t$ is critical for the success of multi-task learning. On the other hand, searching for optimal weights is expensive and becomes increasingly difficult for large models with many tasks. In this work we focus on developing optimization methods for the hyperparameters $\boldsymbol{\omega}_t$ in the loss function (3) and on investigating their performance. Gradient-based optimization is currently the natural choice in many deep learning algorithms [10,12]. We therefore consider an approach that embeds the optimization of the weights in the backpropagation procedure, similar to the one described in [12]. This procedure is described in Algorithm 1. Let $\hat{\boldsymbol{\omega}}_t$ and $\theta_t$ denote the estimates at iteration t of the multi-task weights and the model parameters, respectively.

Algorithm 1. Training for episode $\xi_t$: $(t_1, \ldots, t_M)$
Input: $N_S$, $N_Q$, $N_C$, $\theta_{t-1}$, $\hat{\boldsymbol{\omega}}_{t-1}$
Output: Updated parameter estimates $\theta_t$, $\hat{\boldsymbol{\omega}}_t$
1: for i in {1, ..., M} do
2:   Sample $N_C$ random classes
3:   Sample random elements in $S_{t_i}$ and $Q_{t_i}$
4:   Compute task loss function $\mathcal{L}_{\theta, t_i}(Q_{t_i})$
5: end for
6: Compute $\mathcal{L}^{MT}_{\xi_t}(\hat{\boldsymbol{\omega}}_{t-1}, \{Q_{t_i}\}_{i=1}^{M})$ via (3)
7: Use the multi-task weights optimizer to update $\hat{\boldsymbol{\omega}}_t$
8: Update parameters $\theta_t$ via SGD on $\mathcal{L}^{MT}_{\xi_t}$
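For fixed task losses, the gradient of Eq. (3) with respect to each weight is $\partial \mathcal{L}^{MT}/\partial\omega^{(i)} = -2\mathcal{L}_i/(\omega^{(i)})^3 + 2/\omega^{(i)}$ (under the same $\log((\omega^{(i)})^2)$ reading of the log term). It vanishes at $(\omega^{(i)})^2 = \mathcal{L}_i$, so the effective task weight $1/(\omega^{(i)})^2$ becomes $1/\mathcal{L}_i$. The sketch below runs plain gradient descent on ω with hypothetical fixed losses to illustrate this stationary point. In Algorithm 1 the losses change every episode and ω is updated jointly with θ, so this only illustrates the gradient-based ("Backprop") weights optimizer; it is not the authors' implementation.

```python
import numpy as np


def grad_omega(task_losses, omega):
    """d/d omega_i of Eq. (3) with task losses held fixed: -2 L_i / omega_i^3 + 2 / omega_i."""
    L = np.asarray(task_losses, dtype=float)
    w = np.asarray(omega, dtype=float)
    return -2.0 * L / w**3 + 2.0 / w


def backprop_weights(task_losses, omega0, lr=0.05, steps=2000):
    """Plain gradient descent on omega, a stand-in for the 'Backprop' optimizer."""
    w = np.array(omega0, dtype=float)
    for _ in range(steps):
        w -= lr * grad_omega(task_losses, w)
    return w


losses = np.array([0.5, 1.0, 2.0, 4.0])     # hypothetical fixed task losses
w = backprop_weights(losses, np.ones(4))
print(w, np.sqrt(losses))                   # the stationary point satisfies omega_i^2 = L_i
```

This makes the behaviour of the uncertainty weighting explicit: tasks with a larger loss end up with a smaller weight.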
3.4 SPSA for Tracking Method
When solving a meta-learning problem, algorithms face model uncertainties associated with a critical issue: the limited amount of training samples per task, which may result in overfitting. This phenomenon is especially pronounced in the extreme case of one-shot learning, so meta-learning algorithms must be able to adapt effectively to these uncertainties. We approach this issue from an optimization point of view and propose to use methods from the family of SPSA algorithms as multi-task weights optimizers, owing to their robustness and successful application in various machine learning and control problems with uncertainties [9]. In this subsection, we adapt the SPSA for Tracking algorithm to multi-task weights optimization. To cast the SPSA-based approach as a multi-task weights optimizer, we reformulate the problem of estimating the hyperparameters $\boldsymbol{\omega}_t$ in the multi-task loss function (3) as a non-stationary optimization problem suitable for stochastic approximation [2,8]. To do this, we introduce an observation model for the training episode $\xi_t$ that takes into account the uncertainties arising in meta-learning:

$$L_t(\boldsymbol{\omega}_t) = \mathcal{L}^{MT}_{\xi_t}\big(\boldsymbol{\omega}_t, \{Q_{t_i}\}_{i=1}^{M}\big) + \nu_t, \qquad (4)$$

where $\nu_t$ is additive external noise caused by uncertainties arising from the limited amount of training samples per task. It is therefore necessary to find an algorithm that produces an estimate $\hat{\boldsymbol{\omega}}_t$ of the unknown vector $\boldsymbol{\omega}_t$ that minimizes the mean-risk functional of the objective functions (4), based on observations $L_1, L_2, \ldots, L_t$ from training episodes $\xi_1, \xi_2, \ldots, \xi_t$. To the best of our knowledge, this is the first use of the SPSA for Tracking approach [8] as a multi-task weights optimizer. The algorithm uses observations from both the current and the previous iteration, which makes it more stable under parameter drift. Let $\Delta_t \in \mathbb{R}^d$ be a vector of independent random variables with Bernoulli distribution, $\hat{\boldsymbol{\omega}}_0$ a vector of initial values, and $\{\alpha_t\}$ and $\{\beta_t\}$ sequences of positive numbers. The SPSA for Tracking multi-task weights optimizer (MTM SPSA-Track) then constructs the following estimates of the multi-task weights:

$$\begin{cases} L_{2t} = L_{2t}\big(\hat{\boldsymbol{\omega}}_{2t-2} + \beta_t \Delta_t\big),\\ L_{2t-1} = L_{2t-1}\big(\hat{\boldsymbol{\omega}}_{2t-2} - \beta_t \Delta_t\big),\\ \hat{\boldsymbol{\omega}}_{2t-1} = \hat{\boldsymbol{\omega}}_{2t-2},\\ \hat{\boldsymbol{\omega}}_{2t} = \hat{\boldsymbol{\omega}}_{2t-1} - \alpha_t \Delta_t \dfrac{L_{2t} - L_{2t-1}}{2\beta_t}. \end{cases} \qquad (5)$$
Algorithm (5) has theoretical convergence guarantees expressed in terms of an upper bound of residuals between the estimates and the theoretical optimal solution [8].
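A minimal sketch of update (5) on a toy problem. A noisy quadratic stands in for the episode losses of (4), and each update consumes two consecutive episodes: the loss is observed once at $\hat{\boldsymbol{\omega}}_{2t-2} - \beta_t\Delta_t$ (episode 2t-1) and once at $\hat{\boldsymbol{\omega}}_{2t-2} + \beta_t\Delta_t$ (episode 2t). The perturbation takes the values ±1 with equal probability, the usual reading of the Bernoulli perturbation in SPSA. The toy uses constant gains and a hypothetical target; the schedules $\alpha_n = 0.25/n^{1/6}$, $\beta_n = 15/n^{1/24}$ and the L2-normalization of the weights used in the experiments are omitted.

```python
import numpy as np


def spsa_track(loss_fn, omega0, n_iter, alpha, beta, rng):
    """MTM SPSA-Track, Eq. (5): two episodes per update, one loss observation each.

    loss_fn(omega, n) returns the noisy loss L_n(omega) observed on episode n;
    alpha(t) and beta(t) return the step size and perturbation size at step t.
    """
    w = np.array(omega0, dtype=float)
    for t in range(1, n_iter + 1):
        delta = rng.choice([-1.0, 1.0], size=w.shape)        # random +/-1 perturbation
        l_odd = loss_fn(w - beta(t) * delta, 2 * t - 1)      # L_{2t-1} at w - beta_t * Delta_t
        l_even = loss_fn(w + beta(t) * delta, 2 * t)         # L_{2t} at w + beta_t * Delta_t
        # omega_{2t-1} = omega_{2t-2}, so w is unchanged between the two episodes
        w = w - alpha(t) * delta * (l_even - l_odd) / (2.0 * beta(t))
    return w


rng = np.random.default_rng(0)            # draws the perturbations
noise = np.random.default_rng(1)          # plays the role of nu_t in Eq. (4)
target = np.array([1.0, -2.0])            # hypothetical minimiser of the toy loss


def toy_loss(w, n):
    return float(np.sum((w - target) ** 2) + noise.normal(scale=0.01))


w = spsa_track(toy_loss, np.zeros(2), 2000, alpha=lambda t: 0.05, beta=lambda t: 0.1, rng=rng)
print(w)   # close to the toy target
```

Each update needs only two scalar loss observations, whatever the number of weights, and never differentiates the loss; this is what makes the approach usable when the observed losses are noisy, as in (4).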
4 Experiments
We conducted experiments on four datasets: CIFAR-FS [1], FC100 [18], miniImageNet [29] and tieredImageNet [20]. These datasets were selected because they have become standard benchmarks over the last few years [5,28]. Experiments were performed on the Rolos platform¹ with an A100 vGPU, 48 vCPUs at 2.30 GHz, and 128 GiB RAM. Tables 1, 2 and 3 summarize the results of these experiments. The CIFAR-FS and FC100 datasets are derived from the CIFAR-100 dataset, which has 100 classes, each consisting of 600 images of size 32 × 32. In CIFAR-FS the classes are randomly divided into groups of 64, 16 and 20 for training, validation, and testing, respectively, while in FC100 the 20 superclasses of CIFAR-100 are split into groups of 12, 4 and 4. The miniImageNet and tieredImageNet datasets are derived from the ILSVRC-2012 dataset, which has 1000 categories. The miniImageNet dataset consists of 100 classes randomly chosen from ILSVRC-2012 and split into groups of 64, 16 and 20 as proposed in [19], which is the standard convention. The tieredImageNet dataset consists of 608 categories from ILSVRC-2012 grouped into 34 supercategories, which are then split into groups of 20, 6 and 8. In both cases 600 images of 84 × 84 pixels are sampled for each class. As mentioned in Sect. 2, the experiments with the Multi-Task Modification were performed using MAML and Prototypical Networks.

¹ https://rolos.com/.

A. Boiarov et al.

Table 1. Multi-Task Modification results on CIFAR-FS, FC100, miniImageNet and tieredImageNet (1-shot setting).

                              CIFAR-FS        FC100           miniImageNet    tieredImageNet
Method    Configuration       2-way  5-way    2-way  5-way    2-way  5-way    2-way  5-way
MAML      Reproduced          74.8%  54.5%    66.0%  36.4%    73.2%  45.9%    73.3%  47.9%
          MTM SPSA            76.4%  54.7%    67.0%  36.4%    74.9%  48.0%    74.9%  47.8%
          MTM SPSA-Track      76.8%  54.6%    66.8%  38.3%    75.8%  46.5%    73.8%  48.3%
ProtoNet  Reproduced          77.8%  58.9%    65.0%  35.7%    74.2%  50.0%    72.9%  49.4%
          MTM SPSA            79.1%  59.7%    65.1%  36.0%    74.7%  50.2%    73.6%  49.5%
          MTM SPSA-Track      78.2%  59.8%    65.3%  36.1%    74.8%  50.8%    73.1%  50.0%

For the MAML algorithm we used the neural network $\phi_\theta$ defined in the original paper [6] with the 32-32-32-32 configuration, where a-b-c-d denotes a 4-layer convolutional neural network with a, b, c, d filters in its convolutional layers. Adam was used as the meta-optimizer with learning rate $\beta = 10^{-3}$ and adaptation step size $\alpha = 0.01$ as in (1). We used 5 adaptation steps during training and 10 during testing for all datasets. We selected a meta-batch size of M = 4 and trained the model for 300 epochs, each containing 100 tasks, unless specified otherwise. Following [19] and the original paper, 15 samples per class were taken for evaluation. For the MTM, the baseline model was trained for an extra 40 epochs. For ProtoNet we used the 64-64-64-64 feature extraction backbone $\phi_\theta$ as suggested in the original paper [24]. We followed the standard practice [7] for the meta-learning setup, using SGD with Nesterov momentum of 0.9 and weight decay of 0.0005 as the optimizer. The learning rate was initially set to 0.1 and then decreased according to the strategy from [7]. During meta-training, we used the first 20 epochs to pre-train the model with the original ProtoNet method and applied the Multi-Task Modifications only in the last 40 epochs. We designed several experimental settings to study the relative advantage of using the multi-task loss function (3) and SPSA-based optimization over the original methods (Table 1) and over a gradient-based method (Table 2), in which the multi-task weights in the loss function are optimized jointly with the network parameters θ. In the experiments we mainly used M = 4 tasks per training episode $\xi_t$ (other values of M are indicated explicitly), and the best model for testing was selected on the validation set. For SPSA and SPSA-Track we set $\alpha_n = 0.25/n^{1/6}$ and $\beta_n = 15/n^{1/24}$ in (5), as per the theoretical result from [9]. During the experiments we found that L2-normalization of the multi-task weights in MTM SPSA and MTM SPSA-Track improves training stability.
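As a consistency check (not part of the original paper), the following script recomputes, for each dataset in Table 1, the largest gain of any MTM variant over its reproduced baseline; the accuracies are transcribed from the table. It reports 2.0, 1.9, 2.6 and 1.6 percentage points for CIFAR-FS, FC100, miniImageNet and tieredImageNet, matching the figures quoted in the ablation study (Sect. 5).

```python
# Table 1 accuracies (%), columns in the order:
# CIFAR-FS 2-way, 5-way | FC100 2-way, 5-way | miniImageNet 2-way, 5-way | tieredImageNet 2-way, 5-way
baseline = {
    "MAML":     [74.8, 54.5, 66.0, 36.4, 73.2, 45.9, 73.3, 47.9],
    "ProtoNet": [77.8, 58.9, 65.0, 35.7, 74.2, 50.0, 72.9, 49.4],
}
mtm = {
    ("MAML", "SPSA"):           [76.4, 54.7, 67.0, 36.4, 74.9, 48.0, 74.9, 47.8],
    ("MAML", "SPSA-Track"):     [76.8, 54.6, 66.8, 38.3, 75.8, 46.5, 73.8, 48.3],
    ("ProtoNet", "SPSA"):       [79.1, 59.7, 65.1, 36.0, 74.7, 50.2, 73.6, 49.5],
    ("ProtoNet", "SPSA-Track"): [78.2, 59.8, 65.3, 36.1, 74.8, 50.8, 73.1, 50.0],
}
datasets = ["CIFAR-FS", "FC100", "miniImageNet", "tieredImageNet"]
ways = ["2-way", "5-way"]


def best_gain(dataset):
    """Largest gain (percentage points) of any MTM variant over its reproduced baseline."""
    d = datasets.index(dataset)
    best = None
    for (method, optimizer), accs in mtm.items():
        for w, way in enumerate(ways):
            col = 2 * d + w
            gain = round(accs[col] - baseline[method][col], 1)   # round away float noise
            if best is None or gain > best[0]:
                best = (gain, method, optimizer, way)
    return best


for name in datasets:
    print(name, best_gain(name))
```

The same tables also show where a modification does not help, for example MAML MTM SPSA on tieredImageNet 5-way (47.8% versus 47.9%).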
Table 2. Comparison of gradient-based (Backprop), SPSA and SPSA for Tracking optimizers of the multi-task weights on 1-shot, 5-way experiments.

Algorithm                 CIFAR-FS   FC100    miniImageNet   tieredImageNet
MAML      Backprop        53.1 %     37.6 %   47.4 %         46.7 %
          SPSA            54.7 %     36.4 %   48.0 %         47.8 %
          SPSA-Track      54.6 %     38.3 %   46.5 %         48.3 %
ProtoNet  Backprop        59.4 %     35.5 %   50.4 %         49.2 %
          SPSA            59.7 %     36.0 %   50.2 %         49.5 %
          SPSA-Track      59.8 %     36.1 %   50.8 %         50.0 %

Results of all experiments are reported as the average few-shot classification accuracy over 1000 testing iterations. We used the miniImageNet dataset to compare our method with prior work. Since most recent approaches use more advanced convolutional neural networks with a higher embedding dimension, such as residual networks (ResNet) [10], as feature extraction backbones, we implemented the original ProtoNet with the ResNet-12 backbone provided in [15] to compare against other methods with backbones from the ResNet family. We did not include approaches developed for semi-supervised and transductive learning settings in this comparison, since such approaches use the statistics of query examples or statistics across the one-shot tasks. We also excluded methods that use non-episodic pre-training, as mentioned in Sect. 2. Table 3 shows that, when used with a comparable backbone, our multi-task meta-learning modification of ProtoNet with stochastic approximation raises one-shot classification accuracy to the level of significantly more advanced methods and is competitive with state-of-the-art meta-learning approaches. These results are presented with 95 % confidence intervals.
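Confidence intervals of this kind are commonly computed from per-episode accuracies with the normal approximation, mean ± 1.96·s/√N over N test episodes (N = 1000 here). A minimal sketch, assuming the sample standard deviation s (ddof = 1); the helper name `mean_and_ci95` is hypothetical and not taken from the paper's code:

```python
import numpy as np


def mean_and_ci95(episode_accuracies):
    """Mean accuracy and half-width of a normal-approximation 95% CI."""
    acc = np.asarray(episode_accuracies, dtype=float)
    half_width = 1.96 * acc.std(ddof=1) / np.sqrt(acc.size)
    return acc.mean(), half_width
```

With 1000 episodes the half-width shrinks roughly as 1/√N, which is why the intervals in Table 3 are well below one percentage point for most methods.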
5 Ablation Study
Improvements of MTM with SPSA and SPSA-Track. We conducted experiments on the four datasets most widely used in one-shot learning, as shown in Table 1. On CIFAR-FS, we improved on the original methods by up to 2.0%, with the largest improvement achieved by 2-way MAML MTM SPSA-Track. On the FC100 benchmark, the largest improvement of 1.9% was achieved with the novel MTM SPSA-Track method on MAML in the 5-way scenario. On tieredImageNet, we obtained improvements of up to 1.6% for MAML in the 2-way scenario with SPSA. The top-performing methods include MTM SPSA and MTM SPSA-Track for both MAML and ProtoNet, and MTM SPSA-Track gives the greatest boost in 5-way settings. The last dataset we considered was miniImageNet, the most widely used benchmark for few-shot learning. Here we achieved significant improvements of up to 2.6%, with MAML MTM SPSA-Track leading and MAML MTM SPSA following; similar results are observed for ProtoNet. The results in Table 1 show that the proposed MTM with SPSA and SPSA-Track outperforms the original approaches on all four benchmarks, and the novel SPSA-Track demonstrates the largest improvement over the baseline in most cases. The results in Table 3 suggest that applying our method to ProtoNet with a more modern ResNet-12 backbone improves accuracy by 3.72 percentage points for MTM SPSA and 4.81 percentage points for MTM SPSA-Track, so MTM SPSA-Track performs better in this scenario as well. Such an improvement places our result close to the strongest methods in the comparison. Applying MTM SPSA-Track to ProtoNet with M = 2 tasks gives an improvement of 5.42 percentage points, which is competitive with state-of-the-art methods. This improvement is achieved by modifying the loss function only, and MTM SPSA and MTM SPSA-Track can be applied to almost any of the meta-learning methods in the table, which makes the result more notable. It is worth noting that for meta-learning methods such as MAML, which originally use a meta-batch with several tasks, our Multi-Task Modification requires practically no additional computational cost.

Table 3. Comparison to prior work on the miniImageNet meta-test split. Bold values are accuracies no more than 1 % below the highest one.

Algorithm                                 Backbone    1-shot 5-way
MAML [4, 6]                               ResNet-18   49.61 ± 0.92 %
Chen et al. [4]                           ResNet-18   51.87 ± 0.77 %
Relation Networks [4, 27]                 ResNet-18   52.48 ± 0.86 %
Matching Networks [4, 29]                 ResNet-18   52.91 ± 0.88 %
RAP-ProtoNet [11]                         ResNet-10   53.64 ± 0.60 %
ProtoNet (reproduced) [24]                ResNet-12   56.52 ± 0.45 %
Gidaris et al. [7]                        ResNet-15   55.45 ± 0.89 %
SNAIL [16]                                ResNet-15   55.71 ± 0.99 %
TADAM [18]                                ResNet-12   58.50 ± 0.30 %
Wang et al. [30]                          ResNet-12   59.84 ± 0.22 %
MTL [26]                                  ResNet-12   61.20 ± 1.80 %
vFSL [31]                                 ResNet-12   61.23 ± 0.26 %
MetaOptNet [15]                           ResNet-12   62.64 ± 0.61 %
DSN [23]                                  ResNet-12   62.64 ± 0.66 %
ProtoNet MTM SPSA (Ours)                  ResNet-12   60.24 ± 0.71 %
ProtoNet MTM SPSA-Track (Ours)            ResNet-12   61.33 ± 0.74 %
ProtoNet MTM SPSA-Track (M = 2) (Ours)    ResNet-12   61.94 ± 0.73 %

Gradient-Based vs SPSA-Based Optimization. We compared gradient-based and SPSA-based optimization of the multi-task weights. As can be seen from Table 2, in every setting at least one of the SPSA-based optimizers outperforms the gradient-based one, which suggests that zero-order optimizers are better suited to the proposed multi-task modification in the one-shot setting.
6 Conclusion
In this paper we have developed a new multi-task meta-learning modification with stochastic approximation for one-shot learning. The application of this approach to the optimization-based method MAML and the metric-based method ProtoNet was investigated on the four most widely used one-shot learning benchmarks. In all experiments our algorithm showed significant improvements over the baseline methods. In addition, in most cases the SPSA-based approach was better than the gradient method. We presented the novel SPSA for Tracking algorithm as a multi-task weights optimizer, which demonstrated the largest performance boost on average. For future work, we aim to apply the described approach to state-of-the-art few-shot learning algorithms and to reinforcement learning.
References
1. Bertinetto, L., Henriques, J.F., Torr, P., Vedaldi, A.: Meta-learning with differentiable closed-form solvers. In: International Conference on Learning Representations (2019)
2. Boiarov, A., Granichin, O., Granichina, O.: Simultaneous perturbation stochastic approximation for few-shot learning. In: 2020 European Control Conference (ECC), pp. 350–355 (2020)
3. Boiarov, A., Tyantov, E.: Large scale landmark recognition via deep metric learning. In: CIKM 2019: Proceedings of the 28th ACM International Conference on Information and Knowledge Management, pp. 169–178 (2019)
4. Chen, W.Y., Liu, Y.C., Kira, Z., Wang, Y.C.F., Huang, J.B.: A closer look at few-shot classification. In: International Conference on Learning Representations (2019)
5. Dhillon, G.S., Chaudhari, P., Ravichandran, A., Soatto, S.: A baseline for few-shot image classification. In: International Conference on Learning Representations (2020)
6. Finn, C., Abbeel, P., Levine, S.: Model-agnostic meta-learning for fast adaptation of deep networks. In: Proceedings of the 34th International Conference on Machine Learning. Proceedings of Machine Learning Research, vol. 70, pp. 1126–1135. PMLR (2017)
7. Gidaris, S., Komodakis, N.: Dynamic few-shot visual learning without forgetting. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2018)
8. Granichin, O., Amelina, N.: Simultaneous perturbation stochastic approximation for tracking under unknown but bounded disturbances. IEEE Trans. Autom. Control 60(6), 1653–1658 (2015)
9. Granichin, O., Volkovich, Z., Toledano-Kitai, D.: Randomized Algorithms in Automatic Control and Data Mining, Intelligent Systems Reference Library, vol. 67. Springer, Heidelberg (2015). https://doi.org/10.1007/978-3-642-54786-7
10. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2016)
11. Hong, J., et al.: Reinforced attention for few-shot learning and beyond. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 913–923 (2021)
12. Kendall, A., Gal, Y., Cipolla, R.: Multi-task learning using uncertainty to weigh losses for scene geometry and semantics. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2018)
13. Kiefer, J., Wolfowitz, J.: Stochastic estimation of the maximum of a regression function. Ann. Math. Stat. 23(3), 462–466 (1952)
14. Koch, G., Zemel, R., Salakhutdinov, R.: Siamese neural networks for one-shot image recognition. In: ICML Deep Learning Workshop (2015)
15. Lee, K., Maji, S., Ravichandran, A., Soatto, S.: Meta-learning with differentiable convex optimization. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (2019)
16. Mishra, N., Rohaninejad, M., Chen, X., Abbeel, P.: A simple neural attentive meta-learner. In: International Conference on Learning Representations (2018)
17. Musgrave, K., Belongie, S., Lim, S.-N.: A metric learning reality check. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12370, pp. 681–699. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58595-2_41
18. Oreshkin, B., Rodríguez López, P., Lacoste, A.: TADAM: task dependent adaptive metric for improved few-shot learning. In: Advances in Neural Information Processing Systems, vol. 31. Curran Associates, Inc. (2018)
19. Ravi, S., Larochelle, H.: Optimization as a model for few-shot learning. In: International Conference on Learning Representations (2017)
20. Ren, M., et al.: Meta-learning for semi-supervised few-shot classification. In: International Conference on Learning Representations (2018)
21. Ruder, S.: An overview of multi-task learning in deep neural networks. arXiv preprint arXiv:1706.05098 (2017)
22. Rusu, A.A., et al.: Meta-learning with latent embedding optimization. In: International Conference on Learning Representations (2019)
23. Simon, C., Koniusz, P., Nock, R., Harandi, M.: Adaptive subspaces for few-shot learning. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (2020)
24. Snell, J., Swersky, K., Zemel, R.: Prototypical networks for few-shot learning. In: Advances in Neural Information Processing Systems, vol. 30. Curran Associates, Inc. (2017)
25. Spall, J.C.: Multivariate stochastic approximation using a simultaneous perturbation gradient approximation. IEEE Trans. Autom. Control 37(3), 332–341 (1992)
26. Sun, Q., Liu, Y., Chua, T.S., Schiele, B.: Meta-transfer learning for few-shot learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (2019)
27. Sung, F., Yang, Y., Zhang, L., Xiang, T., Torr, P.H., Hospedales, T.M.: Learning to compare: Relation network for few-shot learning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2018)
28. Tian, Y., Wang, Y., Krishnan, D., Tenenbaum, J.B., Isola, P.: Rethinking few-shot image classification: a good embedding is all you need? In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12359, pp. 266–282. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58568-6_16
29. Vinyals, O., Blundell, C., Lillicrap, T., Kavukcuoglu, K., Wierstra, D.: Matching networks for one shot learning. In: Advances in Neural Information Processing Systems, vol. 29. Curran Associates, Inc. (2016)
30. Wang, H., Zhao, H., Li, B.: Bridging multi-task learning and meta-learning: towards efficient training and effective adaptation. In: Proceedings of the 38th International Conference on Machine Learning. Proceedings of Machine Learning Research, vol. 139, pp. 10991–11002. PMLR (2021)
31. Zhang, J., Zhao, C., Ni, B., Xu, M., Yang, X.: Variational few-shot learning. In: Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) (2019)
Searching for Textual Adversarial Examples with Learned Strategy

Xiangzhe Guo, Ruidan Su, Shikui Tu (corresponding author), and Lei Xu

Department of Computer Science and Engineering, Shanghai Jiao Tong University, Shanghai, China
{gggggxz,suruidan,tushikui,leixu}@sjtu.edu.cn
Abstract. Adversarial attacks can help to reveal the vulnerability of neural networks. In the text classification domain, synonym replacement is an effective way to generate adversarial examples. However, the number of replacement combinations grows exponentially with the text length, making the search difficult. In this work, we propose an attack method that combines a synonym selection network with two search strategies, beam search and Monte Carlo tree search (MCTS). The synonym selection network learns the patterns of synonyms that have a high attack effect. We combine the network with beam search to gain a broader view through multiple search paths, and with MCTS to gain a deeper view through exploration feedback, so as to effectively avoid local optima. We evaluate our method on four datasets in a challenging black-box setting, which requires no access to the victim model's parameters. Experimental results show that our method can generate high-quality adversarial examples with a higher attack success rate and fewer victim model queries, and further experiments show that its adversarial examples transfer better to other victim models. The code and data can be obtained via https://github.com/CMACH508/SearchTextualAdversarialExamples.

Keywords: Textual adversarial attack · Learnable adversarial attack · Beam search · Monte Carlo tree search

1 Introduction
Deep neural networks have been widely applied in real-world natural language processing (NLP) applications, such as spam detection, abusive language detection and fake news detection. However, these systems may be fooled by adversarial examples, which are crafted by maliciously perturbing the input [1,6,12]. Adversarial examples are designed to be imperceptible to humans, yet they can confuse neural network based NLP models, which poses a threat to real-world systems. For example, changing a few words of a spam message, even if the meaning remains the same, can deceive a spam detection system. On the other hand, adversarial examples can help to reveal the vulnerability of neural networks [10]. Therefore, it is essential to study how to generate textual adversarial examples in order to build robust NLP systems.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 385–396, 2023.
https://doi.org/10.1007/978-981-99-1639-9_32
[Figure: greedy attack on the text "I like this film"; panels: Original Input, Synonyms, Word Importance; intermediate result "I like this movie".]

Fig. 1. Greedy search for synonym replacement based attack. Words are processed in descending order of importance; at each step, the synonym with the best attack effect is selected for replacement and the other search branches are discarded. The underlined part of the text indicates that it has been fixed and will not be changed.
Adversarial attacks have been extensively studied in computer vision tasks [10], mainly by perturbing pixels. In the NLP field, attacks are still challenging due to the discrete nature of text. Some previous works proposed to generate adversarial examples by modifying characters (e.g., like→l!ke) [7], but such character-level modifications can be easily detected by spelling checks [22]. Recently, synonym replacement, which replaces some words with their synonyms, has become a common way to generate adversarial examples [1]. Since a word may have multiple synonyms, the number of replacement combinations is exponentially large. Greedy search is a practical solution to this problem [9,13,15,16,26]. As shown in Fig. 1, it first evaluates the importance of each word and then replaces words in descending order of importance. For each word, it chooses the synonym with the best attack effect, until a valid adversarial example is found.

Greedy search based synonym replacement can be further improved in terms of both efficiency and performance. On the one hand, a word may have dozens of synonyms (e.g., TextFooler, a popular greedy search method, uses 50 synonyms per word [13]), resulting in a large combination space that should be effectively reduced. On the other hand, greedy search makes the locally optimal choice at each step and discards all other search branches, which prunes many possible solutions and may lower the attack performance.

In this work, we propose an effective and efficient attack method that incorporates a synonym selection network and new search strategies to address these two problems. For a specific word, some synonyms are not appropriate in the context and some have a poor attack effect [9,15,16]. For example, in “I plan to have a rest”, both ‘break’ and ‘remainder’ are synonyms of ‘rest’, but ‘remainder’ is inappropriate in this context.
Moreover, for a specific victim model, some synonyms are more likely to fool it due to the model's bias or overfitting [10]. Thus, we use a synonym selection network [28] to select the synonyms that both maintain the semantics and effectively fool the victim model. Only the selected “high-quality” synonyms are kept, so the search space is effectively reduced. On the one hand, we use this network to efficiently determine the word replacement order, which yields high attack performance. On the other hand, we combine this network with the search strategies of beam search [24] and Monte Carlo tree search (MCTS) [3] for higher search performance. Beam search maintains multiple search paths to gain a broader search view, and the synonym selection network excludes inappropriate synonyms at each step to reduce victim model calls. In MCTS, the candidate synonyms correspond to actions, and the synonym selection network determines the prior probability of each action and guides the random rollout. The exploration feedback of MCTS provides a deeper search view for higher search performance. We conduct experiments on four datasets in the black-box setting, which requires no internal information of the victim model. Experimental results show that our method not only achieves a higher attack success rate than the baseline methods, but also generates more high-quality adversarial examples and, with the help of the synonym selection network, requires fewer victim model calls. Further experiments also show that our method has higher transferability.
2 Related Work
Synonym replacement based attack can be formalized as a combinatorial optimization problem [29,30]. Previous works proposed population-based algorithms for this problem, such as the genetic algorithm [1,18] and discrete particle swarm optimization [30], but such algorithms are very time-consuming [29]. Recent studies have focused more on greedy search methods, such as TextFooler [13], PWWS [26] and BERT-Attack [16]. Greedy search determines the word replacement order by word ranking; then, at each step, it selects the synonym with the best attack effect for replacement. The word ranking step uses a scoring function to evaluate word importance and replaces words in descending order of importance, so that more important words are replaced first. Scoring functions are usually defined as the decrease in the predicted probability of the true label when a word is deleted [9,13] or replaced with a special token ([MASK], [UNK], etc.) [8,16,26].

Another focus of previous works is the search space, which mainly depends on how synonym sets are constructed. These methods include selecting synonyms from the word embedding space [1,13,18], using WordNet [20,26] or HowNet [5,30], and using pretrained masked language models (BERT-Attack [16], BAE [9] and CLARE [15]). CLARE is similar to our method in using a synonym selection network; the difference is that our network is trained to select synonyms that are not only natural in the context but also have a high attack effect against the victim model.
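To make the greedy baseline concrete, the sketch below combines deletion-based word ranking (in the spirit of [9,13]) with greedy synonym replacement. It is a simplified illustration, not the implementation of TextFooler or PWWS: the victim `f` is any function mapping a list of words to class probabilities, `synonyms` maps each original word to its candidate list, and the semantic similarity constraint is omitted for brevity. All function names are hypothetical.

```python
from typing import Callable, Dict, List, Optional

Victim = Callable[[List[str]], List[float]]  # text -> class probabilities


def argmax(probs: List[float]) -> int:
    return max(range(len(probs)), key=probs.__getitem__)


def deletion_importance(f: Victim, words: List[str], y: int) -> List[float]:
    """Drop in the true-label probability when each word is deleted."""
    base = f(words)[y]
    return [base - f(words[:i] + words[i + 1:])[y] for i in range(len(words))]


def greedy_attack(f: Victim, words: List[str], y: int,
                  synonyms: Dict[str, List[str]]) -> Optional[List[str]]:
    scores = deletion_importance(f, words, y)
    order = sorted(range(len(words)), key=lambda i: scores[i], reverse=True)
    current = list(words)
    for i in order:  # most important word first
        best, best_p = None, f(current)[y]
        for s in synonyms.get(words[i], []):
            cand = current[:i] + [s] + current[i + 1:]
            probs = f(cand)
            if argmax(probs) != y:
                return cand  # prediction flipped: adversarial example found
            if probs[y] < best_p:
                best, best_p = cand, probs[y]
        if best is not None:
            current = best  # keep only the locally best branch
    return None  # every position processed without flipping the label
```

The `current = best` line is exactly where the other branches are discarded; beam search and MCTS below relax this step.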
3 Methodology

3.1 Problem Definition

Given a trained text classification model f as the victim model, for a text $x = (w_1, w_2, \ldots, w_n)$ the predicted probabilities are $f(x) = (p_1, p_2, \ldots, p_C)$, where C is the number of classes. Assume that for each word $w_i$ the synonym set is $S_i$, and that y is the true label, which f predicts correctly. Then a valid adversarial example is a text $x' = (w'_1, w'_2, \ldots, w'_n)$ such that

$$w'_i \in S_i \cup \{w_i\}, \qquad \arg\max_{y'} f(x')_{y'} \neq y, \qquad \mathrm{sim}(x, x') \geq \epsilon,$$

where $\mathrm{sim}(\cdot, \cdot)$ is a semantic similarity function and $\epsilon$ is a similarity threshold.

3.2 Synonym Selection Network
A word may have dozens of synonyms. To effectively reduce the search space, we use a synonym selection network [28] to select the synonyms that fit the context and are more likely to deceive the victim model, thereby reducing replacement attempts. Given a text $x = (w_1, w_2, \ldots, w_n)$, the label y, and the substitute position i ($1 \leq i \leq n$), we first encode $w_i$ with BERT [4]. Then a projection matrix $M_y$ is set for each label $y \in \{1, 2, \ldots, C\}$, and the encoding of $w_i$ is projected into a label-specific space:

$$e_i = \mathrm{BERT}(x, i) \in \mathbb{R}^h, \qquad h_i = M_y e_i \in \mathbb{R}^d.$$

Then we set up a vocabulary $\mathcal{V}$ that contains all possible synonyms, that is, $S_i \subseteq \mathcal{V}$ for any text x and any word $w_i$. We transform $h_i$ into a $|\mathcal{V}|$-dimensional vector using a feedforward network $F_{\mathrm{ffn}}$ and the softmax function, giving the probability of each synonym being selected:

$$p_i = \mathrm{softmax}\big(F_{\mathrm{ffn}}(h_i)\big) \in \mathbb{R}^{|\mathcal{V}|}. \qquad (1)$$
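The computation in (1) can be sketched in NumPy as below. This is an illustrative stand-in under stated assumptions, not the trained network of [28]: the BERT encoding $e_i$ is passed in as a plain vector, $F_{\mathrm{ffn}}$ is assumed to be a single linear layer, the weights are random rather than trained, and the class and method names are hypothetical. The `top_k` method reads the probabilities of a synonym set $S_i \subseteq \mathcal{V}$ off $p_i$ and keeps the K highest.

```python
import numpy as np


def softmax(z):
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


class SynonymSelector:
    """p_i = softmax(F_ffn(M_y e_i)) with one projection matrix M_y per label."""

    def __init__(self, vocab, h, d, num_classes, seed=0):
        rng = np.random.default_rng(seed)
        self.vocab = list(vocab)
        self.index = {w: k for k, w in enumerate(self.vocab)}
        self.M = rng.normal(scale=0.1, size=(num_classes, d, h))   # M_y for each label y
        self.W = rng.normal(scale=0.1, size=(len(self.vocab), d))  # F_ffn (assumed linear)
        self.b = np.zeros(len(self.vocab))

    def probs(self, e_i, y):
        h_i = self.M[y] @ e_i                  # label-specific projection, shape (d,)
        return softmax(self.W @ h_i + self.b)  # distribution over V, shape (|V|,)

    def top_k(self, e_i, y, synonym_set, k):
        """Keep the k synonyms in S_i with the highest selection probability."""
        p = self.probs(e_i, y)
        return sorted(synonym_set, key=lambda s: p[self.index[s]], reverse=True)[:k]
```

Because one distribution over the whole vocabulary is computed per position, scoring all synonyms of a word costs a single forward pass and no victim model queries.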
Since $S_i \subseteq \mathcal{V}$, the probability of each synonym in $S_i$ being selected can be obtained from $p_i$.

3.3 Beam Search
Given a text, we first use the synonym selection network (Sect. 3.2) to rank the words in a way similar to PWWS [26]. Formally, given a text $x = (w_1, w_2, \ldots, w_n)$, we score the synonyms of each word $w_i$ with the synonym selection network and keep the $K_{wr}$ highest-scoring ones, denoted by $S_{i,wr} \subseteq S_i$. Writing $\tilde{x}_{i,s}$ for the text x with $w_i$ replaced by s, we compute the attack effect as

$$\mathbf{E}_i = \max_{s \in S_{i,wr}} \big(f(x)_y - f(\tilde{x}_{i,s})_y\big), \qquad i = 1, 2, \ldots, n,$$

then the saliency of each word as

$$\mathbf{S}_i = f(x)_y - f(\tilde{x}_{i,[\mathrm{UNK}]})_y, \qquad i = 1, 2, \ldots, n,$$

and finally score the words by

$$\mathrm{softmax}(\mathbf{S})_i \cdot \mathbf{E}_i, \qquad i = 1, 2, \ldots, n. \qquad (2)$$
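The word scoring in (2) can be sketched as follows. Assumptions: `f` maps a list of words to class probabilities, `top_synonyms[i]` already holds the $K_{wr}$ synonyms $S_{i,wr}$ kept by the synonym selection network, the attack effect of a position with no kept synonym is taken to be 0, and `[UNK]` is inserted literally as a token. The function name `rank_words` is hypothetical.

```python
import math
from typing import Callable, List, Sequence

Victim = Callable[[List[str]], List[float]]


def rank_words(f: Victim, words: List[str], y: int,
               top_synonyms: Sequence[Sequence[str]]) -> List[int]:
    """Order positions by softmax(S)_i * E_i, highest first (Eq. 2)."""
    def sub(i: int, s: str) -> List[str]:
        return words[:i] + [s] + words[i + 1:]

    p_y = f(words)[y]
    n = len(words)
    # Attack effect E_i: best probability drop over the kept synonyms S_{i,wr}.
    effect = [max((p_y - f(sub(i, s))[y] for s in top_synonyms[i]), default=0.0)
              for i in range(n)]
    # Saliency S_i: probability drop when w_i is replaced by [UNK].
    saliency = [p_y - f(sub(i, "[UNK]"))[y] for i in range(n)]
    m = max(saliency)
    exp_s = [math.exp(v - m) for v in saliency]  # numerically stable softmax
    z = sum(exp_s)
    score = [exp_s[i] / z * effect[i] for i in range(n)]
    return sorted(range(n), key=lambda i: score[i], reverse=True)
```

This costs $1 + n + \sum_i |S_{i,wr}|$ victim queries, so keeping only $K_{wr}$ synonyms per word is what makes the ranking cheaper than scoring every synonym as PWWS does.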
Algorithm 1: Beam Search Based Attack
Input: text x = (w_1, w_2, ..., w_n), true label y, victim model f
Input: synonym selection network f_ss
Input: semantic similarity function sim(·,·) and similarity threshold ε
Input: synonym selection numbers K_wr, K_ss and beam size b
  w_{i_1}, w_{i_2}, ..., w_{i_n} ← rank the words based on f_ss and K_wr
  X ← {x}
  for i ← i_1, i_2, ..., i_n do
    X_new ← ∅
    for x' ∈ X do
      S_{i,ss} ← synonyms in S_i with the top K_ss highest scores in f_ss(x', y, i)
      X_cand ← {x̃'_{i,s} | s ∈ S_{i,ss}, sim(x, x̃'_{i,s}) ≥ ε}
      if a valid adversarial example x* is found in X_cand then return x*
      X_new ← X_new ∪ X_cand
    X_new ← X_new ∪ X    // adding X corresponds to the case where w_i is unchanged
    X ← the b texts in X_new with the lowest predicted probability of the true label
  return None
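Algorithm 1 can be sketched in Python as follows. This is a simplified sketch under stated assumptions: the word order comes from the ranking step, the hypothetical callback `candidates(text, i)` stands in for the synonym selection network and returns the top $K_{ss}$ synonyms for position i, `sim` is any similarity function, and duplicate texts in the pool are merged. Each text's true-label probability is cached with it, so every new candidate costs exactly one victim query.

```python
from typing import Callable, List, Optional, Sequence

Victim = Callable[[List[str]], List[float]]


def beam_search_attack(f: Victim, x: List[str], y: int, order: Sequence[int],
                       candidates: Callable[[List[str], int], List[str]],
                       sim: Callable[[List[str], List[str]], float],
                       eps: float, b: int) -> Optional[List[str]]:
    beam = [(f(x)[y], list(x))]  # (true-label probability, text)
    for i in order:
        pool = {}
        for _, text in beam:
            for s in candidates(text, i):
                cand = text[:i] + [s] + text[i + 1:]
                if tuple(cand) in pool or sim(x, cand) < eps:
                    continue  # skip duplicates and texts violating the constraint
                probs = f(cand)
                if max(range(len(probs)), key=probs.__getitem__) != y:
                    return cand  # valid adversarial example
                pool[tuple(cand)] = (probs[y], cand)
        for p, text in beam:  # w_i left unchanged
            pool.setdefault(tuple(text), (p, text))
        # Keep the b texts with the lowest true-label probability.
        beam = sorted(pool.values(), key=lambda item: item[0])[:b]
    return None
```

With b = 1 this reduces to a greedy search restricted to the selected synonyms; larger b keeps runner-up branches alive at the cost of up to b·K_ss queries per position.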
The word ranking method is similar to PWWS [26], but requires far fewer victim model queries because the synonym selection network reduces the number of synonyms considered when computing the attack effect $\mathbf{E}_i$ in (2). During the beam search, we maintain a text set X. Initially, X contains only the original text; then we process the words in descending order of score. For substitute position i, we enumerate all texts in X and query the synonym selection network for word $w_i$. The synonyms in $S_i$ with the top $K_{ss}$ highest scores are selected for substitution, and the new texts are collected. Then we select the b texts with the lowest predicted probability of the true label from the collected texts as the new text set X, where b is the predefined beam size, and move on to the next substitute position. The procedure is repeated until a valid adversarial example is found, or until all positions have been processed and the attack fails. The complete procedure is summarized in Algorithm 1.

3.4 MCTS
The process of searching for textual adversarial examples can be viewed as a sequential decision-making process, where the state is the pair of a text and a substitute position, and the actions correspond to the candidate synonyms. Through exploration feedback, MCTS can choose actions with a deeper field of view. Given a text, we determine the word substitution order in the same way as in beam search. Then MCTS is applied to select the synonym for each substitute position.

[Figure: the four MCTS steps (Selection, Expansion, Evaluation, Backpropagation) illustrated on the text "A rather bland film", with child texts such as "A pretty bland film", "A very bland film", "A pretty bland movie" and "A very plain movie".]

Fig. 2. MCTS in textual adversarial attack. Each node contains a text and a substitute position, which is highlighted. Selection: starting at the root node, the action with the largest Q + u is selected each time to traverse the tree. Expansion: if selection ends at an unfinished state, a new node is expanded, and the new state's prior probabilities are initialized by the synonym selection network. Evaluation: starting at the leaf node, a random rollout is conducted, guided by the synonym selection network. After the rollout ends, the leaf node is evaluated by the state's value and the rollout reward. Backpropagation: Q and N of the edges from the root to the leaf are updated.

As shown in Fig. 2, in MCTS each tree node represents a state s, which is composed of a text and a substitute position, and each edge represents an action a, which is a synonym for that position. An edge (s, a) stores an action value Q(s, a), a visit count N(s, a) and a prior probability P(s, a). When a new node is added to the tree, all of its downward edges, i.e., the possible synonyms, are determined. For each new edge, Q(·,·) and N(·,·) are initialized as 0, and P(·,·) is initialized by the synonym selection network. At the beginning of the search, a root node is initialized with the current text and substitute position. The search then runs multiple iterations, each of which consists of four steps:
– Selection. This step traverses the tree from the root down to a leaf node. At state s, an edge is selected by $a = \arg\max_{a'} \big(Q(s, a') + u(s, a')\big)$, where $u(s, a') = \alpha \cdot P(s, a') / (1 + N(s, a'))$ and α is a scaling factor. The form of u(·,·) means that an edge is preferred the larger its prior probability and the fewer times it has been visited.
– Expansion. If the selection step ends at an unfinished state (that is, a text that is not an adversarial example and has not exhausted all substitute positions), a new node is expanded. For each downward edge of the new node, the prior probability P(·,·) is initialized using the synonym selection network, and Q(·,·) and N(·,·) are initialized as 0.
– Evaluation. At the leaf node, the state s is evaluated by $e(s) = \lambda \cdot v(s) + (1 - \lambda) \cdot r(s)$, where λ is a mixing parameter, v(s) is the state value, and r(s) is the random rollout reward. The state value is defined as $v(s) = (1 - f(s.\mathrm{text})_y)^2 \cdot \mathrm{sim}(s.\mathrm{text}, x)^2$: the lower the predicted probability of the true label and the higher the semantic similarity to the original text, the more valuable the state. The random rollout starts from the current leaf node; at each step, an action is randomly selected from the $K_{ss}$ actions with the highest prior probabilities given by the synonym selection network. The procedure is repeated until a terminal state is encountered.
Assuming that the terminal state is $s_t$, the reward is $r(s) = -\big(f(s_t.\mathrm{text})_y - 1/C\big)^2$, where C is the number of classes of the victim task; r(s) reflects the expected future benefit.
– Backpropagation. For all edges on the path from the root to the leaf, N(·,·) and Q(·,·) are updated. Assuming that m iterations have been run, for edge (s, a)

$$N(s, a) = \sum_{i=1}^{m} \mathbb{1}(s, a, i), \qquad Q(s, a) = \frac{1}{N(s, a)} \sum_{i=1}^{m} \mathbb{1}(s, a, i) \cdot e(s^{i}_{\mathrm{leaf}}),$$

where $\mathbb{1}(s, a, i)$ indicates whether the i-th iteration traverses (s, a), and $s^{i}_{\mathrm{leaf}}$ is the leaf node visited in the i-th iteration.
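The four steps can be combined into a compact Python sketch. It follows the formulas above for u(s, a), v(s), r(s) and the running-mean update of Q, but it is a simplified illustration rather than the authors' implementation: `actions(text, pos)` and `prior(text, pos, words)` are hypothetical callbacks standing in for the synonym set and the synonym selection network, the original word is assumed to be allowed as one of the actions so that a position can stay unchanged, and the root must not be terminal. Returning the most-visited root action is a common MCTS convention and is also an assumption here.

```python
import random
from typing import Dict, List, Sequence, Tuple

Text = Tuple[str, ...]
State = Tuple[Text, int]  # (text, index into the word substitution order)


class MCTSAttack:
    """Per-position MCTS over synonym choices (a simplified sketch)."""

    def __init__(self, f, x, y, order, actions, prior, sim,
                 alpha=1.0, lam=0.4, k_ss=3, seed=0):
        self.f, self.x, self.y, self.order = f, tuple(x), y, list(order)
        self.actions = actions  # actions(text, pos) -> candidate words
        self.prior = prior      # prior(text, pos, words) -> prior probabilities
        self.sim = sim
        self.alpha, self.lam, self.k_ss = alpha, lam, k_ss
        self.rng = random.Random(seed)
        self.C = len(f(list(x)))  # number of classes
        self.edges: Dict[State, Dict[str, list]] = {}  # action -> [Q, N, P]

    def p_true(self, text: Text) -> float:
        return self.f(list(text))[self.y]

    def terminal(self, state: State) -> bool:
        text, k = state
        probs = self.f(list(text))
        fooled = max(range(len(probs)), key=probs.__getitem__) != self.y
        return fooled or k >= len(self.order)

    def step(self, state: State, a: str) -> State:
        text, k = state
        i = self.order[k]
        return text[:i] + (a,) + text[i + 1:], k + 1

    def candidates(self, state: State):
        text, k = state
        words = self.actions(text, self.order[k])
        return words, self.prior(text, self.order[k], words)

    def expand(self, state: State) -> None:
        words, probs = self.candidates(state)
        self.edges[state] = {a: [0.0, 0, p] for a, p in zip(words, probs)}

    def select(self, state: State) -> str:
        stats = self.edges[state]
        # argmax_a Q(s, a) + alpha * P(s, a) / (1 + N(s, a))
        return max(stats, key=lambda a: stats[a][0]
                   + self.alpha * stats[a][2] / (1 + stats[a][1]))

    def evaluate(self, state: State) -> float:
        text = state[0]
        v = (1 - self.p_true(text)) ** 2 * self.sim(self.x, text) ** 2
        while not self.terminal(state):  # random rollout among the top-K_ss priors
            words, probs = self.candidates(state)
            ranked = sorted(zip(probs, words), reverse=True)[: self.k_ss]
            state = self.step(state, self.rng.choice(ranked)[1])
        r = -(self.p_true(state[0]) - 1 / self.C) ** 2
        return self.lam * v + (1 - self.lam) * r

    def search(self, text: Sequence[str], k: int, iterations: int) -> str:
        """Return the most-visited action at a non-terminal root state."""
        root: State = (tuple(text), k)
        for _ in range(iterations):
            state, path = root, []
            while state in self.edges and not self.terminal(state):  # selection
                a = self.select(state)
                path.append((state, a))
                state = self.step(state, a)
            if not self.terminal(state):  # expansion
                self.expand(state)
            e = self.evaluate(state)  # evaluation
            for s, a in path:  # backpropagation: Q is the running mean of e
                q, n, p = self.edges[s][a]
                self.edges[s][a] = [q + (e - q) / (n + 1), n + 1, p]
        stats = self.edges[root]
        return max(stats, key=lambda a: stats[a][1])
```

In an attack, `search` would be called once per substitute position, committing the returned word before moving on to the next position; the iteration budget (200 in Sect. 4.2) bounds the extra victim queries per position.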
4 Experiments
4.1 Datasets and Victim Models
The experiments are conducted on the following datasets:
– MR [21]. A sentence-level movie review sentiment analysis dataset.
– IMDB [17]. A document-level movie review sentiment analysis dataset.
– Yelp [31]. A restaurant and hotel review sentiment analysis dataset.
– SNLI [2]. The Stanford Natural Language Inference dataset, where the task is to determine whether the relationship between a premise and a hypothesis is entailment, neutral or contradiction.
To comply with the black-box setting, we do not use any training data of the victim model. For evaluation, we randomly select, for each dataset, 1,000 test samples that the victim model predicts correctly (MR has only 898 correctly predicted test samples, and we use all of them). To train the synonym selection network, we construct training data from 20,000 samples of datasets with similar tasks: for MR and Yelp, the training data is constructed from IMDB (all sentiment analysis tasks); for IMDB, from Yelp (both sentiment analysis tasks); and for SNLI, from MNLI [27] (both natural language inference tasks). We use trained BERT [4] and BiLSTM [11] models as the victim models.

4.2 Experiment Setup
Previous works construct synonym sets in different ways, resulting in different search spaces. For a fair comparison, we use the same search space for all compared methods: the synonym set is constructed based on HowNet [5], and only content words are substituted [30]. Within this search space, we compare the following methods:
– TextFooler (2020) [13]. Each word's importance for word ranking is defined as the decrease in the predicted probability when the word is deleted from the text.
– PWWS (2019) [26]. Each word's importance is determined by the decrease in the predicted probability when it is replaced by [UNK], together with the maximum attack effect over all of its synonyms.
– LSH (2021) [19]. The scoring function is similar to PWWS, but uses an attention mechanism and locality-sensitive hashing to reduce victim model queries.

For the semantic similarity constraint, texts are encoded with SentenceTransformers [25], and the semantic similarity is defined as the cosine similarity of the two encoded vectors. The similarity threshold is ε = 0.9. For the synonym selection network, d = 128, θ = 0.95 and p_ppl = 90; the network is trained for 5 epochs with a learning rate of $10^{-5}$. For beam search, K_wr = 5 and K_ss = 15. For MCTS, K_wr = 5, K_ss = 3, α = 1.0, λ = 0.4, and the maximum number of search iterations is 200.

4.3 Evaluation Metrics
In our experiment, we evaluate the methods with the following metrics:
– Attack Success Rate. The percentage of successful attacks.
– Semantic Similarity. The average semantic similarity between the original texts and the corresponding adversarial examples.
– Perplexity. The average perplexity of the adversarial examples as measured by GPT-2 [23].
– Number of Queries. The average number of victim model queries.

4.4 Main Results
Table 1 shows the attack success rates. Both beam search (with beam size 4) and MCTS achieve higher attack success rates than the baseline methods, which suggests that the two search strategies can effectively avoid local optima and achieve higher search performance. Table 2 shows the attack efficiency and adversarial example quality results, and Fig. 3 shows the number of victim model queries when attacking the BERT model on the MR dataset. Taking attack efficiency and adversarial example quality together, beam search performs nearly the best of all methods: it requires few victim model queries and generates high-quality adversarial examples. This may be due to the synonym selection network, which excludes low-quality synonyms. Although PWWS performs well in attack success rate and adversarial example quality, it requires a large number of model queries. Beam search requires a similar number of queries to TextFooler and LSH, but outperforms both in attack success rate and text quality. MCTS has the highest attack success rate in most settings, but it requires more victim model queries. If one pursues the best performance regardless of efficiency, MCTS is the better choice; otherwise, beam search, which offers a better trade-off between performance and efficiency, is sufficient.
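The aggregates in Tables 1 and 2 follow the metric definitions of Sect. 4.3. A minimal sketch of that bookkeeping, assuming each attack is logged as a record with hypothetical fields `success`, `similarity` and `queries`, that similarity is averaged over successful attacks only, and that queries are averaged over all attacked inputs (GPT-2 perplexity is omitted because it needs a language model):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class AttackRecord:
    success: bool      # was a valid adversarial example found?
    similarity: float  # sim(x, x') of the returned example (ignored on failure)
    queries: int       # victim model calls spent on this input


def summarize(records: List[AttackRecord]) -> dict:
    ok = [r for r in records if r.success]
    return {
        "success_rate": 100.0 * len(ok) / len(records),  # percent, as in Table 1
        "similarity": sum(r.similarity for r in ok) / len(ok) if ok else float("nan"),
        "queries": sum(r.queries for r in records) / len(records),
    }
```

Whether failed attacks are included in the query average changes the numbers noticeably, so that choice should match the one used for the baselines.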
Searching for Textual Adversarial Examples with Learned Strategy
393
Table 1. The attack success rate (%) on the four datasets.
Method              | BERT MR | BERT IMDB | BERT Yelp | BERT SNLI | BiLSTM MR | BiLSTM IMDB | BiLSTM Yelp
TextFooler (2020)   | 53.8    | 92.2      | 86.8      | 81.3      | 79.6      | 99.5        | 97.0
PWWS (2019)         | 65.1    | 97.2      | 92.0      | 88.8      | 81.6      | 99.6        | 98.1
LSH (2021)          | 62.7    | 95.6      | 90.1      | 84.9      | 80.9      | 99.6        | 97.8
Beam Search (b = 1) | 63.4    | 96.9      | 91.5      | 88.3      | 81.3      | 99.6        | 98.4
Beam Search (b = 4) | 67.8    | 98.3      | 92.6      | 90.9      | 83.3      | 99.7        | 98.8
MCTS                | 67.8    | 97.5      | 95.2      | 91.1      | 83.7      | 99.8        | 98.7
Table 2. The attack efficiency and adversarial example quality evaluation results against BERT model. “Sim” (↑) means semantic similarity, “PPL” (↓) means perplexity.
Method              | MR #Queries | MR Sim | MR PPL | SNLI #Queries | SNLI Sim | SNLI PPL
TextFooler (2020)   | 80          | 0.935  | 348.3  | 44            | 0.934    | 140.7
PWWS (2019)         | 334         | 0.946  | 324.0  | 301           | 0.940    | 123.5
LSH (2021)          | 52          | 0.940  | 337.5  | 37            | 0.941    | 129.1
Beam Search (b = 1) | 72          | 0.945  | 312.2  | 84            | 0.944    | 120.7
Beam Search (b = 4) | 101         | 0.943  | 317.2  | 96            | 0.944    | 125.0
MCTS                | 202         | 0.938  | 343.1  | 163           | 0.943    | 178.4
4.5
Decomposition Analysis
In this part, we evaluate the attack success rate when the synonym selection network is removed, in which case the synonyms are assigned uniform probabilities. We show the results on the MR dataset in Fig. 4: “WR+Search” denotes the full setting, in which the network is used in both word ranking and search; “WR” means it is used only in word ranking; “Search” means it is used only in beam search or MCTS; and “None” means it is removed from both procedures. Removing the network from either procedure leads to a performance decline. In addition, MCTS is less affected than beam search: it achieves a much higher attack success rate than beam search, especially when the network is used only in the search, which may be because the exploration feedback gives MCTS a deeper view of the search space.
4.6
Transferability
In this part, we evaluate the transferability of our attack method, i.e., the ability of its adversarial examples to fool a machine learning system without any access to the underlying model [14]. We evaluate the classification accuracy of the BiLSTM model on the adversarial examples generated by attacking the BERT model,
X. Guo et al.
Fig. 3. The average number of queries when attacking BERT model on MR.
Fig. 4. Decomposition analysis for the synonym selection network.
Table 3. The classification accuracy (↓) when the adversarial examples are transferred to other victim models. Method
BERT → BiLSTM BiLSTM → BERT MR IMDB Yelp MR IMDB Yelp
TextFooler (2020)
72.4 74.2
79.9 81.9 80.4
89.0
PWWS (2019)
72.3 75.3
78.7 82.5 81.8
90.5
LSH (2021)
72.6 74.9
78.1 82.1 81.2
91.0
Beam Search (b = 4) 71.1 73.9
78.3 81.6 81.7
89.2
MCTS
74.2 82.4 79.3
87.7
70.1 71.6
and vice versa. A lower classification accuracy of the target model indicates higher transferability of the attack method. The classification results are shown in Table 3. Compared to the baseline methods, our method mostly achieves lower classification accuracy, indicating higher transferability, which may be because the synonym selection network excludes low-quality synonyms and thus produces more natural texts.
5
Conclusion
In this paper, we focus on synonym-replacement-based textual adversarial attacks and propose an effective and efficient attack method. Specifically, a synonym selection network is proposed to learn the patterns of synonyms with high attack effect, so that only the most effective synonyms are selected for further substitution. This not only reduces the search space but also helps generate high-quality adversarial examples. Further, we propose a beam search method following the synonym selection network for a broader view of the search space, and an MCTS-based method for a deeper view through exploration feedback, both of which achieve higher attack performance. Experimental results demonstrate that our method outperforms the state-of-the-art methods in attack success rate and number of victim model queries. Moreover, the transferability evaluation verifies that the synonym selection network assists in generating more natural texts, making the proposed method more robust in application.
Acknowledgement. This work was supported by the National Key R&D Program of China (2018AAA0100700), and Shanghai Municipal Science and Technology Major Project (2021SHZDZX0102).
References
1. Alzantot, M., Sharma, Y., Elgohary, A., Ho, B.J., Srivastava, M., Chang, K.W.: Generating natural language adversarial examples. In: Proceedings of EMNLP (2018)
2. Bowman, S.R., Angeli, G., Potts, C., Manning, C.D.: A large annotated corpus for learning natural language inference. In: Proceedings of EMNLP (2015)
3. Browne, C.B., et al.: A survey of Monte Carlo tree search methods. IEEE Trans. Comput. Intell. AI Games (2012)
4. Devlin, J., Chang, M.W., Lee, K., Toutanova, K.: BERT: pre-training of deep bidirectional transformers for language understanding. In: Proceedings of NAACL (2019)
5. Dong, Z., Dong, Q.: Hownet and the Computation of Meaning. World Scientific (2006)
6. Ebrahimi, J., Rao, A., Lowd, D., Dou, D.: HotFlip: white-box adversarial examples for text classification. In: Proceedings of ACL (2018)
7. Eger, S., et al.: Text processing like humans do: visually attacking and shielding NLP systems. In: Proceedings of NAACL (2019)
8. Gao, J., Lanchantin, J., Soffa, M.L., Qi, Y.: Black-box generation of adversarial text sequences to evade deep learning classifiers (2018)
9. Garg, S., Ramakrishnan, G.: BAE: BERT-based adversarial examples for text classification. In: Proceedings of EMNLP (2020)
10. Goodfellow, I.J., Shlens, J., Szegedy, C.: Explaining and harnessing adversarial examples. In: Proceedings of ICLR (2015)
11. Hochreiter, S., Schmidhuber, J.: Long short-term memory. Neural Comput. (1997)
12. Jia, R., Liang, P.: Adversarial examples for evaluating reading comprehension systems. In: Proceedings of EMNLP (2017)
13. Jin, D., Jin, Z., Zhou, J.T., Szolovits, P.: Is BERT really robust? A strong baseline for natural language attack on text classification and entailment. In: Proceedings of AAAI (2020)
14. Kurakin, A., Goodfellow, I., Bengio, S.: Adversarial examples in the physical world. In: ICLR Workshop (2017)
15. Li, D., et al.: Contextualized perturbation for textual adversarial attack. In: Proceedings of NAACL (2021)
16. Li, L., Ma, R., Guo, Q., Xue, X., Qiu, X.: BERT-ATTACK: adversarial attack against BERT using BERT. In: Proceedings of EMNLP (2020)
17. Maas, A.L., Daly, R.E., Pham, P.T., Huang, D., Ng, A.Y., Potts, C.: Learning word vectors for sentiment analysis. In: Proceedings of ACL (2011)
18. Maheshwary, R., Maheshwary, S., Pudi, V.: Generating natural language attacks in a hard label black box setting (2021)
19. Maheshwary, R., Maheshwary, S., Pudi, V.: A strong baseline for query efficient attacks in a black box setting (2021)
20. Miller, G.A.: WordNet: a lexical database for English. In: Speech and Natural Language: Proceedings of a Workshop Held at Harriman, New York, 23–26 February 1992 (1992)
21. Pang, B., Lee, L.: Seeing stars: exploiting class relationships for sentiment categorization with respect to rating scales. In: Proceedings of ACL (2005)
22. Pruthi, D., Dhingra, B., Lipton, Z.C.: Combating adversarial misspellings with robust word recognition. In: Proceedings of ACL (2019)
23. Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., Sutskever, I.: Language models are unsupervised multitask learners (2019)
24. Reddy, R.: Speech understanding systems: summary of results of the five-year research effort. Carnegie Mellon University (1976)
25. Reimers, N., Gurevych, I.: Sentence-BERT: sentence embeddings using Siamese BERT-networks. In: Proceedings of EMNLP (2019)
26. Ren, S., Deng, Y., He, K., Che, W.: Generating natural language adversarial examples through probability weighted word saliency. In: Proceedings of ACL (2019)
27. Williams, A., Nangia, N., Bowman, S.: A broad-coverage challenge corpus for sentence understanding through inference. In: Proceedings of NAACL (2018)
28. Xiangzhe, G., Shikui, T., Lei, X.: Learning to generate textual adversarial examples. In: The 31st International Conference on Artificial Neural Networks (2022)
29. Yoo, J.Y., Morris, J., Lifland, E., Qi, Y.: Searching for a search method: benchmarking search algorithms for generating NLP adversarial examples. In: Proceedings of the Third BlackboxNLP Workshop on Analyzing and Interpreting Neural Networks for NLP (2020)
30. Zang, Y., et al.: Word-level textual adversarial attacking as combinatorial optimization. In: Proceedings of ACL (2020)
31. Zhang, X., Zhao, J.J., LeCun, Y.: Character-level convolutional networks for text classification. In: Proceedings of NeurIPS (2015)
Multivariate Time Series Retrieval with Binary Coding from Transformer
Zehan Tan¹, Mingyu Zhao¹, Yun Wang¹, and Weidong Yang¹,²(✉)
¹ School of Computer Science, Fudan University, Shanghai, China
{18110240062,19110240028,20110240066,wdyang}@fudan.edu.cn
² Zhuhai Fudan Innovation Institute, Hengqin New Area, Zhuhai, Guangdong, China
Abstract. Deep learning to binary coding improves multivariate time series retrieval performance by learning representations and binary codes end-to-end from training data. However, existing deep learning retrieval methods, e.g., encoder-decoders based on recurrent or convolutional neural networks, fail to capture the latent dependencies between pairs of variables in multivariate time series, which results in a substantial loss of retrieval quality. Furthermore, supervised deep learning to binary coding fails to meet practical requirements due to the lack of labeled multivariate time series datasets. To address these issues, this paper presents Unsupervised Transformer-based Binary Coding Networks (UTBCNs), a novel architecture for deep learning to binary coding, which consists of four key components: a Transformer-based encoder (T-encoder), a Transformer-based decoder (T-decoder), an adversarial loss, and a hashing loss. We employ the Transformer encoder-decoder to encode temporal dependencies and inter-sensor correlations within multivariate time series solely by attention mechanisms. Meanwhile, to enhance the generalization capability of the deep network, we add an adversarial loss based upon the improved Wasserstein GAN (WGAN-GP) for real multivariate time series segments. To further improve the quality of the binary codes, a hashing loss based upon a convolutional encoder (C-encoder) is designed for the output of the T-encoder. Extensive experiments demonstrate that the proposed UTBCNs generate high-quality binary codes and outperform state-of-the-art binary coding retrieval models on the Twitter, Air Quality, and Quick Access Recorder (QAR) datasets.
Keywords: Multivariate Time Series Retrieval · Binary Coding · Transformer · Adversarial Loss · Hashing Loss
1
Introduction
In recent years, multivariate time-series data have been collected in increasing volumes by various appliances [1,2]. This flood of multivariate time-series data requires management that enables fast storage and access, where retrieval and similarity search are typical tasks [3,4]. Due to the curse of dimensionality, transformations are applied to time-series data to reduce the number of dimensions in the representation of time-series segments. Over the past several decades, a diverse range of segment representation approaches have been developed, e.g., Discrete Wavelet Transform (DWT) [5], Discrete Fourier Transform (DFT) [6], Karhunen-Loeve Transform (KLT) [7], Singular Value Decomposition (SVD) [8], Piecewise Aggregate Approximation (PAA) [9], etc. Although the representations obtained from these methods are effective, retrieval via pairwise Euclidean distance (EU) or Dynamic Time Warping (DTW) still suffers from high computational complexity [10,11]. Binary coding techniques, e.g., Locality Sensitive Hashing (LSH) [12], have been proposed to enable efficient approximate nearest neighbor search via Hamming ranking of binary codes. Deep learning [13] is an advanced technology in data mining. For example, HashGAN [14] has been employed to learn binary hash codes from images. However, this model was designed for image retrieval and does not encode the temporal dependencies in sequence segments. To address this issue, Deep Unsupervised Binary Coding Networks (DUBCNs) [15] were developed to perform multivariate time-series retrieval. However, DUBCNs cannot capture inter-sensor correlations in multivariate time series. Inspired by the success of the Transformer in sequence tasks, we propose Unsupervised Transformer-based Binary Coding Networks (UTBCNs) for multivariate time series retrieval, in which a Transformer-based encoder-decoder encodes both temporal dynamics and potential dependencies among variables in multivariate time-series segments. For homogeneous, relatively close, or consecutive multivariate time-series segments, the Transformer-based encoder-decoder attempts to generate similar binary codes.
Supported by the National Natural Science Foundation of China, U2033209.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 397–408, 2023.
https://doi.org/10.1007/978-981-99-1639-9_33
To further ensure the consistency of binary codes for similar segments, we employ a convolutional encoder [16] trained with a hashing loss on the output of the Transformer-based encoder. In addition, to enhance the generalizability of the proposed UTBCNs and to support end-to-end learning, we add an adversarial loss during training based upon the improved Wasserstein GAN (WGAN-GP) [17], which enables stable training. The benefits of UTBCNs are experimentally verified on three multivariate time series datasets, on which it significantly outperforms the current state of the art. Our primary contributions are summarized as follows:
– To the best of our knowledge, this is the first study to introduce the Transformer into multivariate time-series retrieval systems.
– To enhance the generalizability of the proposed network, an adversarial loss based on WGAN-GP is employed to assist model training.
– A hashing loss function is introduced to enhance the quality of the binary codes.
Multivariate Time Series Retrieval with Binary Coding from Transformer
2
Related Work
Representation Methods for Time Series: Over the past two decades, various approaches have been developed for time series representation with reduced dimensionality, including Discrete Fourier Transform (DFT) [6], Discrete Wavelet Transform (DWT) [5], Piecewise Aggregate Approximation (PAA) [9], Adaptive Piecewise Constant Approximation (APCA) [18], Symbolic Aggregate Approximation (SAX) [19], Principal Component Analysis (PCA) [20], etc. These representation methods can be divided into two categories: data-adaptive and non-data-adaptive representations.
Deep Learning to Binary Coding: Recently, many deep learning to binary coding methods have been proposed to efficiently approximate nearest neighbors by Hamming ranking of the binary codes [21,22]. Generally, binary coding methods can be divided into supervised and unsupervised methods. Supervised semantics-preserving deep hashing (SSDH) [23] constructs hash functions as a latent layer in a deep convolutional neural network and achieves effective image retrieval performance. Lin et al. proposed DeepBit to learn a compact binary descriptor for efficient visual object matching by optimizing an objective function based on quantization loss, evenly distributed codes, and uncorrelated bits [24]. HashGAN [14] jointly learns a hash function with a GAN and has demonstrated promising effectiveness in unsupervised image retrieval. Song et al. proposed Deep r-th root of Rank Supervised Joint Binary Embedding (Deep r-RSJBE) to perform multivariate time-series retrieval, and extensive experiments demonstrated its effectiveness and efficiency [25]; however, it is a supervised learning technique. In addition, Deep Unsupervised Binary Coding Networks (DUBCNs) [15] employ an LSTM encoder-decoder to encode temporal dynamics within input samples. However, this approach cannot encode the potential dependencies between pairs of variables in multivariate time series.
To address this issue, we propose Unsupervised Transformer-based Binary Coding Networks (UTBCNs) to jointly consider the temporal information and latent relevance among variables in multivariate time series using the self-attention mechanism.
Attention Mechanism in Time Series: Recently, several studies have employed attention mechanisms in time series anomaly detection and prediction to improve the performance of time series data mining. For instance, Pereira et al. incorporate variational self-attention into a recurrent autoencoder for anomaly detection in time-series data [26]. Qin et al. propose a dual-stage attention-based recurrent neural network (DA-RNN) for time series prediction [27]. In this model, an input attention mechanism adaptively selects relevant input features at each time step, while a temporal attention layer extracts relevant encoder hidden states across all time steps. These two attention mechanisms are integrated within an LSTM encoder-decoder and improve the effectiveness of time series prediction. The works mentioned above basically treat the attention mechanism as an additional component of the original network, e.g., an encoder-decoder based upon LSTM or GRU units. In contrast, the Transformer [28] is built solely on multi-head self-attention and has achieved great success in text sequence modeling. Since time series can likewise be considered sequence data, we employ the Transformer for time series modeling to realize multivariate time-series retrieval.
3
Multivariate Time-Series Retrieval Network
3.1
Problem Statement
Before going into the details, we first outline the framework of UTBCNs, which contains four key components: the Transformer-based encoder (T-encoder), the Transformer-based decoder (T-decoder), an adversarial loss, and a hashing loss. In multivariate time-series retrieval, $X_{t,\omega} = (x_t, \ldots, x_{t+\omega-1}) \in \mathbb{R}^{\omega \times n}$ denotes the query segment, in which $x_t = (x_t^1, x_t^2, \ldots, x_t^n) \in \mathbb{R}^n$ is the vector at time step $t$, $t$ is the time index, $n$ is the number of time series (variables) in the multivariate time series, and $\omega$ is the segment length. Given the query segment $X_{t,\omega}$, the task is to find its most similar time-series segments in the given dataset.
Here, we introduce a new multivariate time series retrieval model called UTBCNs, which applies binary coding representations from the Transformer to multivariate time-series retrieval. It is built upon the Transformer layer and the Wasserstein GAN. As shown in Fig. 1, the multivariate time-series segments first serve as the input of the T-encoder, which captures latent temporal features in the multivariate time-series data. Then, we reconstruct the input segments through the T-decoder. Note that the binary codes, which serve as the measurement criterion for multivariate time series similarity search, are produced by the T-encoder. Finally, we combine the mean square error (MSE) of the Transformer encoder-decoder, the adversarial loss of the Wasserstein GAN, and the hashing loss of the convolutional encoder to complete the end-to-end learning process.
3.2
Transformer-Based Encoder
As shown in Fig. 1, the Transformer encoder layer contains two sublayers: a multi-head self-attention sublayer and a feed-forward network. In the multi-head self-attention sublayer, the input multivariate time-series segment is linearly projected h times into corresponding subspaces with different, learned linear projections. The self-attention function is then applied to each projected version in parallel, and the resulting output values are concatenated and projected once again:
$$\mathrm{MultiHead}(X_{t,\omega}) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^{O}, \qquad \mathrm{head}_i = \mathrm{Attention}(X_{t,\omega}W_i^{Q},\; X_{t,\omega}W_i^{K},\; X_{t,\omega}W_i^{V}) \qquad (1)$$
where the projection matrices $W^{O} \in \mathbb{R}^{d \times d}$ and $W_i^{Q}, W_i^{K}, W_i^{V} \in \mathbb{R}^{d \times d/h}$ are learnable parameters. For each head, "Scaled Dot-Product Attention" is employed, whose inputs are the query, key, and value matrices $Q$, $K$, and $V$. The matrix of outputs is computed as follows:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V \qquad (2)$$
where $Q$, $K$, and $V$ are projected from the same time-series segment $X_{t,\omega}$ with different projection matrices, and $d_k$ is the dimension of the keys. To avoid extremely small gradients in the softmax, the dot products are scaled by $1/\sqrt{d_k}$.
Fig. 1. Illustration of overall structure of proposed UTBCNs model.
Position-Wise Feed-Forward Network. Stacking self-attention layers helps capture more complex dependencies in most cases; however, as the network goes deeper, it becomes increasingly difficult to train. To address this problem, a residual connection [29] is employed around each of the two sublayers, followed by layer normalization [30]. In addition, dropout is applied to the output of each sublayer before layer normalization. In summary, the output of each sublayer is obtained as follows:
$$\mathrm{LayerNorm}(x + \mathrm{Dropout}(\mathrm{Sublayer}(x))) \qquad (3)$$
where Sublayer(·) is the function implemented by the sublayer itself, Dropout(·) is the dropout operation, and LayerNorm(·) is layer normalization [30], which normalizes the inputs over the hidden units in the same layer to stabilize and accelerate model training.
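A minimal pure-Python sketch of Eqs. (1) to (3) follows, written for readability rather than speed. It assumes the input segment has already been mapped to the model dimension d (so that $X_{t,\omega}W_i^{Q}$ is well defined), omits the learned gain and bias of layer normalization, and treats dropout as the identity (inference mode). The toy sizes and random weights are illustrative only, not the paper's configuration.

```python
import math
import random


def matmul(a, b):
    """Plain-list matrix product: (r x k) @ (k x c) -> (r x c)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def softmax_rows(a):
    out = []
    for row in a:
        m = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(x - m) for x in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out


def attention(q, k, v):
    """Eq. (2): softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(k[0])
    scores = matmul(q, transpose(k))
    return matmul(softmax_rows([[s / math.sqrt(d_k) for s in row] for row in scores]), v)


def multi_head(x, w_q, w_k, w_v, w_o):
    """Eq. (1): h heads computed independently, concatenated per time step, projected by W^O."""
    heads = [attention(matmul(x, wq), matmul(x, wk), matmul(x, wv))
             for wq, wk, wv in zip(w_q, w_k, w_v)]
    concat = [sum((head[t] for head in heads), []) for t in range(len(x))]
    return matmul(concat, w_o)


def layer_norm(row, eps=1e-5):
    """Normalize one time step over its hidden units (no learned gain/bias in this sketch)."""
    mean = sum(row) / len(row)
    var = sum((r - mean) ** 2 for r in row) / len(row)
    return [(r - mean) / math.sqrt(var + eps) for r in row]


def sublayer_output(x, sublayer):
    """Eq. (3) at inference time, where dropout is the identity: LayerNorm(x + Sublayer(x))."""
    y = sublayer(x)
    return [layer_norm([a + b for a, b in zip(xr, yr)]) for xr, yr in zip(x, y)]


# Toy usage: omega = 4 time steps already projected to d = 8 features, h = 2 heads.
rng = random.Random(0)
omega, d, h = 4, 8, 2


def rand(rows, cols):
    return [[rng.gauss(0.0, 0.3) for _ in range(cols)] for _ in range(rows)]


x = rand(omega, d)
w_q = [rand(d, d // h) for _ in range(h)]
w_k = [rand(d, d // h) for _ in range(h)]
w_v = [rand(d, d // h) for _ in range(h)]
w_o = rand(d, d)
out = sublayer_output(x, lambda z: multi_head(z, w_q, w_k, w_v, w_o))
```

Each head works in a $d/h$-dimensional subspace, so concatenating the h heads restores width d before $W^{O}$ is applied, and the residual sum in Eq. (3) is therefore shape-compatible with the input.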
3.3
Transformer-Based Decoder
For the output of the T-decoder, we employ a fully-connected layer to generate the reconstructed time-series segments as follows:
$$\hat{X} = \sigma(W d + b) \qquad (4)$$
where $W \in \mathbb{R}^{d \times n}$ and $b \in \mathbb{R}^{n}$ are the learnable parameters of the fully-connected layer, $d$ denotes the output of the decoder layers, and σ is the activation function (Tanh in this case). The objective is defined as the reconstruction error over the time-series segments. Equation (5) gives the MSE loss:
$$\mathcal{L}_{MSE} = \frac{1}{N_{batch}} \sum_{i=1}^{m} \sum_{j=1}^{n} \left(X_{i,j} - \hat{X}_{i,j}\right)^2 \qquad (5)$$
where $N_{batch}$ is the number of sample segments in one batch (batch size), and $m$ and $n$ are the length of the input segments and the number of features in the multivariate time-series data, respectively. $X$ and $\hat{X}$ denote the input and output of the entire UTBCNs model.
3.4
Adversarial Loss
To address overfitting, we add an adversarial loss to improve the training process and enhance the generalization ability of UTBCNs. Here, we attach an improved Wasserstein GAN (WGAN) after the T-encoder. To overcome the training difficulties of the traditional GAN, we adopt WGAN with the improved training strategy of [17], which trains the discriminator with the Wasserstein distance (a measure of the distance between two probability distributions) and adds a gradient penalty that softly enforces the 1-Lipschitz constraint:
$$\min_{D} \mathcal{L}_D = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \gamma\, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\!\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right] \qquad (6)$$
where $P_g$ is the model distribution defined by $\tilde{x} = G(z)$, $P_r$ is the real data distribution, $P_{\hat{x}}$ is defined implicitly by sampling uniformly along straight lines between pairs of points sampled from $P_g$ and $P_r$, and γ is the penalty coefficient. The improved WGAN exhibits strong modeling performance and training stability across various architectures with almost no hyper-parameter tuning. We concatenate the real feature matrix g and z, which serve as the inputs to the generator G.
3.5
Hashing Loss
As shown in Fig. 1, the GAN model considered in this study comprises three components: a generator, a discriminator, and a convolutional encoder (C-encoder). The C-encoder shares the same weights as the discriminator except for the parameters of the last layer. The C-encoder is designed to capture the spatial information in query segments by generating the corresponding hashing code representation. Its parameters are trained via a hashing loss on the output of the T-encoder. The hashing objective used to train the C-encoder takes the following form:
$$\min \; -\sum_{i=1}^{m} \left[ t_i \log(t_i) + (1 - t_i)\log(1 - t_i) \right] + \left\lVert W_{\varepsilon}^{\top} W_{\varepsilon} - I \right\rVert_2^2 \qquad (7)$$
where $t_i$ is the i-th hashing code over the T-encoder output and $W_{\varepsilon}$ is the weight matrix of the last layer of the C-encoder. Note that the hashing codes are also obtained via the last layer of the C-encoder network.
3.6
Binary Codes Embedding
For the output of the T-encoder and the subsequent linear layer, we employ another fully-connected layer to obtain the feature vector $V = [v_1, v_2, \ldots, v_m]$, which is compressed into the binary code $B = [b_1, b_2, \ldots, b_m]$ via the binary embedding mapping
$$b_j = \mathrm{sgn}(v_j), \quad j = 1, 2, \ldots, m \qquad (8)$$
where sgn(·) is the sign function, which returns 1 if the input is greater than 0 and −1 otherwise. We use the Hamming distance between the generated binary codes to measure the similarity among multivariate time-series segments.
3.7
Objective and Training Procedure
The final objective of the proposed UTBCNs is defined as follows:
$$\mathcal{L} = \mathcal{L}_{MSE} + \lambda_1 \mathcal{L}_{adv} + \lambda_2 \mathcal{L}_{hash} \qquad (9)$$
where $\lambda_1$ and $\lambda_2$ are hyper-parameters that control the importance of the adversarial loss and the hashing loss, respectively. To optimize this objective, we must solve the following minimax game:
$$G^{*}, D^{*} = \arg\min_{G}\max_{D}\, \mathcal{L}(G, D) \qquad (10)$$
Thus, we optimize the generator G and the discriminator D separately. Specifically, when optimizing the generator, we only update the layers of G so that it synthesizes fake samples that fool the discriminator. When optimizing the discriminator, the goal is to train D to distinguish real segments from synthesized samples. Note that the parameters of the C-encoder are trained via the hashing loss. Meanwhile, we update the parameters of the Transformer encoder-decoder jointly via the MSE loss, adversarial loss, and hashing loss. We use the Adam optimizer [31] to train the whole UTBCNs model.
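To make Eqs. (5), (6), (7), and (9) concrete, the following pure-Python sketch evaluates each loss on toy inputs. It is an illustration under assumptions, not the authors' code: the critic is a hypothetical toy $D(x) = \sum_k w_k \tanh(x_k)$ chosen because its input gradient has a closed form (a real implementation would use automatic differentiation), the penalty coefficient γ = 10 is the default of the original WGAN-GP paper rather than a value reported here, the norm in Eq. (7) is computed as a Frobenius norm, and the generator-side adversarial term is taken as $-\mathbb{E}[D(\tilde{x})]$, the usual WGAN choice, since the text does not spell out $\mathcal{L}_{adv}$.

```python
import math
import random


def mse_loss(batch, batch_hat):
    """Eq. (5): squared error summed over each m x n segment, divided by N_batch."""
    total = 0.0
    for seg, seg_hat in zip(batch, batch_hat):
        for row, row_hat in zip(seg, seg_hat):
            total += sum((a - b) ** 2 for a, b in zip(row, row_hat))
    return total / len(batch)


def critic(w, x):
    """Hypothetical toy critic D(x) = sum_k w_k * tanh(x_k)."""
    return sum(wk * math.tanh(xk) for wk, xk in zip(w, x))


def critic_grad(w, x):
    """Closed-form input gradient of the toy critic: dD/dx_k = w_k * (1 - tanh(x_k)^2)."""
    return [wk * (1.0 - math.tanh(xk) ** 2) for wk, xk in zip(w, x)]


def wgan_gp_critic_loss(w, real, fake, gamma=10.0, rng=None):
    """Eq. (6) estimated on a minibatch of paired real / generated vectors."""
    rng = rng or random.Random(0)
    n = len(real)
    wasserstein = sum(critic(w, f) for f in fake) / n - sum(critic(w, r) for r in real) / n
    penalty = 0.0
    for r, f in zip(real, fake):
        eps = rng.random()  # x_hat lies uniformly on the segment between a real and a fake point
        x_hat = [eps * ri + (1.0 - eps) * fi for ri, fi in zip(r, f)]
        grad_norm = math.sqrt(sum(g * g for g in critic_grad(w, x_hat)))
        penalty += (grad_norm - 1.0) ** 2
    return wasserstein + gamma * penalty / n


def generator_adv_loss(w, fake):
    """Assumed generator-side term: -E[D(G(z))], the usual WGAN choice."""
    return -sum(critic(w, f) for f in fake) / len(fake)


def hashing_loss(t, w_eps, tiny=1e-12):
    """Eq. (7): entropy term pushing each t_i toward 0 or 1, plus ||W^T W - I||^2 (Frobenius)."""
    entropy = -sum(ti * math.log(ti + tiny) + (1.0 - ti) * math.log(1.0 - ti + tiny) for ti in t)
    cols = len(w_eps[0])
    ortho = 0.0
    for a in range(cols):
        for b in range(cols):
            dot = sum(row[a] * row[b] for row in w_eps)  # entry (a, b) of W^T W
            ortho += (dot - (1.0 if a == b else 0.0)) ** 2
    return entropy + ortho


def total_objective(l_mse, l_adv, l_hash, lam1, lam2):
    """Eq. (9): L = L_MSE + lambda_1 * L_adv + lambda_2 * L_hash."""
    return l_mse + lam1 * l_adv + lam2 * l_hash
```

The entropy term in Eq. (7) is zero only when every $t_i$ is exactly 0 or 1, so minimizing it pushes the C-encoder outputs toward binary values, while the orthogonality term discourages redundant code bits.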
4
Experiments
4.1
Datasets
The following datasets were used in our experiments.
Beijing Multi-site Air Quality Dataset¹: This is typical multivariate time series data including multiple variables.
Twitter Dataset²: This dataset contains examples of buzz events from Twitter.
Quick Access Recorder Dataset: This dataset was collected from real flight recorders. It contains 28 variables measuring altitude, pitch attitude, and speed information.
4.2
Parameters Setting and Evaluation Metrics
The proposed UTBCNs model contains seven hyper-parameters. In all experiments, the mini-batch size is 128, the input and output dimension of the Transformer encoder-decoder is 512, and the inner-layer dimension of the Transformer encoder-decoder is 2048. The learning rate is set to $2 \times 10^{-4}$. The hyper-parameters $\lambda_1$ and $\lambda_2$ in the objective are selected by grid search over $\{10^{-4}, 10^{-3}, 10^{-2}, 10^{-1}\}$, while the number of identical layers in the Transformer is varied over $N \in \{2, 4, 6, 8\}$. The experiments with the different multivariate time-series retrieval methods on the above datasets were repeated for five query segments, and the average is reported as the final result. To evaluate retrieval performance, given a query segment, we compute its K nearest neighbors (KNN) under Euclidean distance on the raw data and use this KNN result set as the ground truth. We then retrieve similar time-series segments by Hamming distance between the binary codes generated by each retrieval method. Retrieval effectiveness is measured by three metrics: Mean Average Precision (MAP), precision at top-K position (Precision@K), and recall at top-K position (Recall@K).
4.3
Results
To measure the retrieval effectiveness of the proposed UTBCNs and the baseline algorithms, the MAP values obtained by each model on the three datasets are reported in Table 1, where the number of binary bits m is 64, 128, or 256. As can be seen, the deep learning to binary coding retrieval approaches outperform the shallow method LSH and the deep learning to hash approach HashGAN. We attribute this to the fact that LSH cannot obtain good vector representations for time-series retrieval, while HashGAN is tailored for image retrieval and cannot encode the essential temporal dependencies in multivariate time series. In addition, Transformer-based methods demonstrated better retrieval effectiveness than the RNN-based encoder-decoder in most settings. Notably, UTBCNs consistently achieves the best MAP on all three datasets, which we attribute to its ability to capture the spatial information, temporal features, and latent dependencies among variables in multivariate time series. To further demonstrate the benefit of the hashing loss for binary coding retrieval, we compared UTBCNs with its variant without the hashing loss (UTBCNs-noHash). As shown in Table 1, UTBCNs outperforms UTBCNs-noHash, which indicates that the hashing loss improves the quality of the binary codes. In addition, to investigate the effect of the adversarial loss on unsupervised multivariate time-series retrieval, we compared UTBCNs with a Transformer variant (UTBCNs without the hashing loss and the adversarial loss). We also compared LSTM-ED with its variant LSTM-ED-GAN. As shown in Table 1, UTBCNs and LSTM-ED-GAN display better multivariate time series retrieval effectiveness than Transformer and LSTM-ED, respectively. This indicates that the adversarial loss can indeed enhance the generalization ability of the retrieval model. To further investigate the effectiveness of the proposed model, we compared UTBCNs to the baseline methods in terms of Precision@K and Recall@K, as shown in Fig. 2 and Fig. 3, respectively. In Fig. 2, the proposed UTBCNs obtained higher Precision@K than the baseline approaches on all three datasets. Similarly, in terms of Recall@K (Fig. 3), UTBCNs outperformed the compared baseline methods. These results demonstrate that UTBCNs maintains high precision and recall at the top of the Hamming distance ranking list, which suggests two key findings: 1) UTBCNs can capture temporal information in multivariate time series; 2) the attention mechanism in UTBCNs can encode the potential dependencies between pairs of variables within multivariate time-series data.
¹ https://archive.ics.uci.edu/ml/datasets/Beijing+Multi-Site+Air-Quality+Data.
² https://archive.ics.uci.edu/ml/datasets/Buzz+in+social+media+.
Fig. 2. Precision@K with 64 binary bits on (a) Air Quality (l=10), (b) Twitter (l=20), and (c) QAR dataset (l=20).
Fig. 3. Recall@K with 64 binary bits on (a) Air Quality (l=10), (b) Twitter (l=20), and (c) QAR dataset (l=20).
Table 1. MAP of UTBCNs and the baseline methods on Twitter (l=20), Air Quality (l=10), and QAR (l=20) when m = 64, 128, and 256, N = 6, KNN = 100.
Algorithms    | Twitter 64 bits | Twitter 128 bits | Twitter 256 bits | Air Quality 64 bits | Air Quality 128 bits | Air Quality 256 bits | QAR 64 bits | QAR 128 bits | QAR 256 bits
LSH           | 8.79±1.3        | 11.72±2.4        | 15.13±2.8        | 6.79±1.5            | 10.32±2.1            | 13.46±2.3            | 41.46±4.6   | 43.91±3.9    | 46.18±3.1
HashGAN       | 6.81±1.1        | 10.16±1.4        | 15.76±1.8        | 5.13±1.2            | 9.34±1.4             | 12.74±2.1            | 31.95±2.1   | 34.76±1.8    | 38.62±1.7
LSTM-ED       | 13.44±2.7       | 15.97±2.3        | 19.16±3.1        | 9.83±1.8            | 12.41±2.4            | 16.33±2.7            | 59.39±3.5   | 64.03±2.8    | 68.97±2.9
LSTM-ED-GAN   | 14.51±3.8       | 16.48±2.7        | 20.49±3.6        | 10.15±2.1           | 12.73±2.3            | 16.57±2.7            | 64.09±3.6   | 67.72±4.1    | 70.43±3.9
DUBCNs        | 14.49±2.4       | 17.56±2.5        | 21.38±2.6        | 15.25±2.1           | 19.53±2.4            | 22.03±2.4            | 60.14±2.1   | 64.31±2.5    | 69.03±2.7
Transformer   | 13.67±2.8       | 15.54±2.3        | 19.89±2.7        | 13.72±2.3           | 17.88±2.1            | 20.87±3.1            | 63.16±2.8   | 66.36±3.7    | 69.79±4.1
UTBCNs-noHash | 14.22±1.8       | 17.98±2.3        | 20.76±2.8        | 16.84±2.8           | 20.92±3.1            | 25.01±2.6            | 66.47±3.6   | 70.09±4.1    | 73.89±3.6
UTBCNs        | 17.43±1.3       | 22.16±2.1        | 24.62±1.9        | 21.26±1.5           | 25.51±1.8            | 30.87±2.1            | 68.86±2.2   | 72.54±2.4    | 77.32±2.8
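The evaluation protocol of Sect. 4.2 (Euclidean KNN ground truth on the raw data, Hamming ranking of sign-binarized codes as in Eq. (8), then Precision@K, Recall@K, and average precision) can be sketched as below. The toy data, the tie-breaking rule (stable sort by database index), and the untruncated average-precision definition are assumptions, since the paper does not state how ties or a MAP cutoff are handled; the feature vectors stand in for hypothetical encoder outputs.

```python
import math


def sign_code(v):
    """Eq. (8): b_j = sgn(v_j); as in the text, 0 maps to -1."""
    return [1 if x > 0 else -1 for x in v]


def hamming(a, b):
    return sum(1 for x, y in zip(a, b) if x != y)


def euclidean(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def ground_truth(query_raw, db_raw, k):
    """Indices of the k database segments nearest to the query in Euclidean distance (raw, flattened data)."""
    order = sorted(range(len(db_raw)), key=lambda i: euclidean(query_raw, db_raw[i]))
    return set(order[:k])


def hamming_ranking(query_code, db_codes):
    """Database indices sorted by Hamming distance; Python's stable sort breaks ties by index."""
    return sorted(range(len(db_codes)), key=lambda i: hamming(query_code, db_codes[i]))


def precision_recall_at_k(ranking, relevant, k):
    hits = sum(1 for i in ranking[:k] if i in relevant)
    return hits / k, hits / len(relevant)


def average_precision(ranking, relevant):
    """AP over the full ranking; MAP is the mean of this value over all queries."""
    hits, total = 0, 0.0
    for rank, i in enumerate(ranking, start=1):
        if i in relevant:
            hits += 1
            total += hits / rank
    return total / len(relevant)


# Toy example: four flattened database segments and their (hypothetical) encoder outputs.
db_raw = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0], [0.1, 0.0]]
db_feats = [[0.5, -0.2, 0.3], [-0.4, -0.1, 0.2], [-0.9, 0.8, -0.7], [0.6, 0.1, 0.4]]
query_raw, query_feat = [0.0, 0.1], [0.2, -0.3, 0.9]

relevant = ground_truth(query_raw, db_raw, k=2)
ranking = hamming_ranking(sign_code(query_feat), [sign_code(f) for f in db_feats])
p_at_2, r_at_2 = precision_recall_at_k(ranking, relevant, k=2)
ap = average_precision(ranking, relevant)
```

Because the ground truth comes from Euclidean distance on the raw data while retrieval uses only the m-bit codes, these metrics measure how well the Hamming ranking preserves raw-space neighborhoods; in the toy example, database item 1 ties with the relevant item 3 in Hamming distance and is ranked ahead of it, which lowers Precision@2 to 0.5.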
5 Conclusion
In this paper, we have proposed Unsupervised Transformer-based Binary Coding Networks (UTBCNs), an unsupervised deep learning method for multivariate time-series retrieval. Given input multivariate time-series segments, the self-attention mechanisms in UTBCNs jointly capture the temporal information and the latent dependencies between pairs of variables within the time-series samples, and then generate binary codes for retrieval based on Hamming distance. To improve the quality of the generated binary codes, the network incorporates a hashing loss based on a convolutional network. Furthermore, to enhance the generalization capability of the network, we add an adversarial loss to the end-to-end training process. Extensive experiments show that UTBCNs outperforms state-of-the-art methods for multivariate time-series retrieval. The code we used to train and evaluate our models is available at https://github.com/haha1206/UTBCNs.
Multivariate Time Series Retrieval with Binary Coding from Transformer
Learning TSP Combinatorial Search and Optimization with Heuristic Search
Hua Yang¹ (✉) and Ming Gu²
¹ School of Software, Tsinghua University, Beijing 100084, China
² Tsinghua University, Beijing 100084, China
Abstract. The Traveling Salesman Problem (TSP) and similar combinatorial search and optimization problems have many real-world applications in logistics, transportation, manufacturing, IC design, and other industries. Large-scale TSP instances have long been difficult to solve quickly. When the number of city nodes exceeds 200, model training terminates due to insufficient memory. This paper reduces memory usage by simplifying the network model. However, the simplified model has lower prediction accuracy, so we use heuristic search methods such as greedy search, beam search, and 2-opt search to improve it. Our main contributions are: increasing the number of city nodes that can be solved from 100 to 1000; compensating for the loss of accuracy with various search techniques; and applying various search techniques in the combinatorial search and optimization domain. The novelty of our paper is that the Transformer's model structure is simplified and various heuristic search techniques are used to compensate for the resulting loss in solution accuracy. In the inference stage, although greedy search, beam search, and 2-opt search require quite different search times, all of them improve the model's prediction accuracy to varying degrees. Extensive experiments demonstrate that these heuristic search techniques can greatly improve the prediction accuracy of the model.
Keywords: Combinatorial Optimization · heuristic search · greedy search · beam search · 2-opt search · Deep Learning · TSP

1 Introduction
Combinatorial search and optimization [6,12,19,24] has essential applications across many fields, such as logistics, transportation, IC design, production planning, scheduling, and operations research [1,11,23,24,27]. The Traveling Salesman Problem (TSP) [4,10,16] is a classic combinatorial search and optimization problem, and many other combinatorial search and optimization problems can be reduced to the TSP. Traditional methods for solving combinatorial search and optimization problems fall into three categories. First, exact methods [32]. Exact algorithms are guaranteed to find optimal solutions, but they become increasingly intractable as the problem grows.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023. M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 409–419, 2023. https://doi.org/10.1007/978-981-99-1639-9_34
Therefore, they are limited to small-scale problems. Second, approximation algorithms [1], which are not guaranteed to find optimal solutions. Third, heuristic algorithms [1,11,20], which can find reasonably good solutions in acceptable time but require well-designed heuristics to guide the search. In short, exact algorithms are suitable only for small-scale problems, and approximation algorithms cannot guarantee optimality. Although heuristic algorithms can solve large-scale problems, their rules must be designed by hand, which limits their flexibility, adaptability, and convenience for obtaining fast solutions. In recent years, machine learning approaches to combinatorial search and optimization have attracted much interest, because they offer fast solution times, improved generalization, and flexibility. Because they learn heuristics autonomously from training data, machine learning methods may suit a wide range of optimization problems and require less human effort than solvers [9,23,27] that are optimized for a single task. Despite these advantages, machine learning approaches to combinatorial search and optimization face significant challenges. First, memory: when the number of city nodes is large (more than 200), training runs out of memory and terminates, because both the training time and the required memory grow with the number of TSP city nodes. Second, accuracy: if a simpler network structure is designed to reduce memory usage, the accuracy of the inference results decreases. Third, training time: when the number of nodes is large, training the model takes a very long time.
To address these challenges, we modified the Transformer [28] network structure to reduce memory usage and used several search techniques to improve the accuracy of the inference results. To reduce memory usage, we deleted the first module of the encoder in the Transformer architecture [28], the Multi-Head Attention module (MHA), and the first module of the decoder, the Masked Multi-Head Attention module (masked MHA). However, this simplified architecture degrades the performance of the trained model. Therefore, in the inference phase, we use greedy search, beam search, and 2-opt [14] search to make up for the loss of accuracy. The contributions of this paper are as follows: (1) Increasing the number of city nodes that can be solved from 100 to 1000: existing works can usually solve only 100-node instances with a model actually trained at that size, rather than through generalization. We raise the maximum number of city nodes that can be solved from 100 to 1000 using only models trained on the actual number of nodes. (2) Compensating for the loss of accuracy with various search techniques: during the inference phase, we use various search techniques to
compensate for the loss of accuracy caused by simplifying the network structure during training to reduce memory usage. (3) Using various search techniques in the combinatorial search and optimization domain: in the context of using deep reinforcement learning to solve combinatorial search and optimization problems, we use various search techniques to improve accuracy. (4) Providing a reference on the effectiveness of various search techniques: the different effects of these techniques in combinatorial search and optimization can help practitioners choose among them.
2 Related Works
In addition to well-established solvers such as Concorde [27], Gurobi [23], LKH3 [11], and Google OR-Tools [9], neural networks are increasingly used to solve the TSP. Neural networks can learn better heuristic features from data to replace handcrafted heuristics in TSP combinatorial search and optimization. Typical examples include Hopfield networks [13], the first use of neural networks for small-scale TSPs; Pointer Networks [29], which mainly use the attention mechanism [2] to handle variable-length outputs; Neural Combinatorial Optimization [3], which solves the TSP with reinforcement learning and an RNN and uses active search to refine predictions at inference time; and the work of Nazari et al. [22], which solves the Vehicle Routing Problem (VRP) with reinforcement learning. The work of Kool et al. [18] has been particularly influential. It uses a Transformer model to address routing problems such as the TSP, the VRP, and the Orienteering Problem (OP). Its encoder is essentially a standard Transformer encoder without positional encoding, and its decoder takes the first visited city node and the most recently visited city node as input. Deudon et al. [7] were the first to apply the Transformer [28] to the TSP. Their decoder input consists of the three most recently visited city nodes, their encoder is the standard Transformer encoder without positional encoding, and they use the 2-opt heuristic search [14] to improve the model's prediction accuracy. Xing et al. [33] present a self-learning approach to the TSP that combines deep reinforcement learning with Monte Carlo tree search, which selects the best policy and improves generalization. Other learning-based approaches to combinatorial optimization problems include [3,5,8,15,17,19,21,22,25,29,30].
[Figure: the proposed architecture, an encoder-decoder stack (Nx) built from Input/Output Embedding, Multi-Head Attention, Feed Forward, and Add & Norm blocks, followed by Linear and Softmax layers that produce the output probabilities.]
Fig. 1. The proposed architecture. The Multi-Head Attention module of the encoder and the Masked Multi-Head Attention module of the decoder are deleted, and the encoder's Positional Encoding module is not used because the city node coordinates have no inherent order.
3 The Model Architecture
3.1 Model Designing
In this paper, we consider only the symmetric two-dimensional Euclidean TSP, defined on a complete undirected graph. For a problem instance (TSP graph) $s$, the feature of node $i$ is its position coordinate $x_i$, $i \in \{1, 2, 3, \ldots, n\}$. We seek the permutation $\pi$ of the nodes that visits each city exactly once and minimizes the total tour length $L(\pi \mid s)$. The TSP can be formalized as the following constrained optimization problem:

$$\min_{\pi} \; L(\pi \mid s) = \sum_{i=1}^{n-1} \lVert x_{\pi_i} - x_{\pi_{i+1}} \rVert_2 + \lVert x_{\pi_1} - x_{\pi_n} \rVert_2 \quad \text{s.t.} \quad f(\pi, s) = 0, \; g(\pi, s) \le 0, \tag{1}$$
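Equation (1) can be evaluated directly. The NumPy sketch below computes $L(\pi \mid s)$ for a tour given as a permutation of city indices (0-based here, unlike the 1-based indices in the text). The constraint functions $f$ and $g$ are left abstract in the paper; as an assumption, the sketch reduces them to a check that the tour visits every city exactly once.

```python
import numpy as np


def tour_length(coords: np.ndarray, tour) -> float:
    """L(pi|s): Euclidean length of the closed tour, including the edge back to the start."""
    ordered = coords[list(tour)]
    # np.roll pairs x_{pi_i} with x_{pi_{i+1}} and wraps x_{pi_n} back to x_{pi_1}.
    next_ordered = np.roll(ordered, -1, axis=0)
    return float(np.linalg.norm(ordered - next_ordered, axis=1).sum())


def is_valid_tour(tour, n: int) -> bool:
    """Assumed permutation constraint: each of the n cities appears exactly once."""
    return sorted(int(city) for city in tour) == list(range(n))


square = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
perimeter = tour_length(square, [0, 1, 2, 3])  # 4.0, the optimal tour
crossing = tour_length(square, [0, 2, 1, 3])   # 2 + 2 * sqrt(2), uses both diagonals
```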
where $f(\pi, s)$ and $g(\pi, s)$ are constraint functions. We use reinforcement learning [26] to learn solutions to the problem instances. We therefore define the states, actions, transitions, rewards, and policy, the key elements of the Markov Decision Process in reinforcement learning [26]:
– States: the state $s_t$ is the ordered list of cities visited so far. The initial state $s_0$ is empty (NULL), the first state $s_1$ contains the index of the first city visited, and the final state $s_n$ contains all visited city nodes.
– Actions: the action $a_t$ is one of the unvisited city nodes, i.e., the next node to be selected.
– Rewards: the reward is the negative distance cost, $r(a_i, s_i) = -\lVert x_{\pi_i} - x_{\pi_{i+1}} \rVert_2$.
– Transition: a city node is chosen from the set of unvisited city nodes and added to the set of visited city nodes.
– Policy: the policy $p_\theta(\pi \mid s)$ is a neural network whose trainable weights are $\theta$. Given the set of visited city nodes, it outputs the probability distribution over the next unvisited candidate city. By the chain rule of probability, the policy is factorized as Eq. 2:

$$p_\theta(\pi \mid s) = \prod_{i=1}^{n} p_\theta(\pi_i \mid \pi_{1 \ldots i-1}, s). \tag{2}$$
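The MDP above can be made concrete with a short rollout loop. In the NumPy sketch below, the state is the list of visited cities, actions are restricted to unvisited cities by masking, each step's reward is the negative edge length, and $\log p_\theta(\pi \mid s)$ is accumulated step by step following the chain rule in Eq. 2. The scoring function `nearest_city_scores` is a hypothetical stand-in for the neural network, and two further assumptions are made: the tour starts at a fixed city, and a final reward for the edge back to the start is added so that the rewards sum to $-L(\pi \mid s)$.

```python
import numpy as np


def masked_softmax(logits: np.ndarray, visited: np.ndarray) -> np.ndarray:
    """Distribution over unvisited cities only; visited cities get probability 0."""
    masked = np.where(visited, -np.inf, logits)
    weights = np.exp(masked - masked.max())
    return weights / weights.sum()


def nearest_city_scores(coords: np.ndarray, tour: list) -> np.ndarray:
    """Hypothetical policy stand-in: prefer cities close to the current one."""
    return -np.linalg.norm(coords - coords[tour[-1]], axis=1)


def rollout(coords, score_fn, start=0, rng=None):
    """One episode of the TSP MDP; greedy if rng is None, otherwise sampled."""
    n = len(coords)
    visited = np.zeros(n, dtype=bool)
    tour, rewards, log_prob = [start], [], 0.0   # state s_1 = [start]
    visited[start] = True
    while len(tour) < n:
        probs = masked_softmax(score_fn(coords, tour), visited)
        action = int(np.argmax(probs)) if rng is None else int(rng.choice(n, p=probs))
        log_prob += float(np.log(probs[action]))  # Eq. 2: sum of per-step log-probs
        rewards.append(-float(np.linalg.norm(coords[tour[-1]] - coords[action])))
        tour.append(action)                       # transition: mark the city as visited
        visited[action] = True
    rewards.append(-float(np.linalg.norm(coords[tour[-1]] - coords[start])))  # closing edge
    return tour, rewards, log_prob


square = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
tour, rewards, log_prob = rollout(square, nearest_city_scores)
```

Masking with $-\infty$ before the softmax is one common way to enforce the "each city once" constraint inside the policy itself, so every sampled action sequence is a valid tour.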
We follow the overall structure of the vanilla Transformer [28] but make the following essential changes, as shown in Fig. 1: (1) we delete the first module of the encoder, i.e., the Multi-Head Attention module; (2) we delete the first module of the decoder, i.e., the Masked Multi-Head Attention module; and (3) we remove the encoder's positional encoder, i.e., the Positional Encoding module.
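To make the three modifications concrete, here is a NumPy sketch of one simplified encoder layer and one simplified decoder layer. It assumes ReLU feed-forward networks, layer normalization without learned gain and bias, randomly initialized weights, and a decoder query that is simply a given vector; none of these details come from the paper, and the helper names are hypothetical. Because the encoder has neither self-attention nor positional encoding, it transforms each city embedding independently, so permuting the input cities permutes its output in the same way.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Add & Norm helper: normalize each row (no learned scale/shift in this sketch)."""
    mean = x.mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def feed_forward(x, w1, b1, w2, b2):
    """Position-wise feed-forward network with a ReLU hidden layer."""
    return np.maximum(0.0, x @ w1 + b1) @ w2 + b2


def cross_attention(queries, memory, wq, wk, wv, wo, num_heads):
    """Multi-head attention: queries from the decoder, keys/values from the encoder."""
    d_model = queries.shape[-1]
    d_head = d_model // num_heads

    def split(x):  # (length, d_model) -> (heads, length, d_head)
        return x.reshape(x.shape[0], num_heads, d_head).transpose(1, 0, 2)

    q, k, v = split(queries @ wq), split(memory @ wk), split(memory @ wv)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    heads = (weights @ v).transpose(1, 0, 2).reshape(queries.shape[0], d_model)
    return heads @ wo


def encoder_layer(x, params):
    # Change (1): no Multi-Head self-attention, only Feed Forward + Add & Norm.
    return layer_norm(x + feed_forward(x, *params["ffn"]))


def decoder_layer(y, memory, params, num_heads):
    # Change (2): no Masked Multi-Head self-attention; cross-attention is kept.
    y = layer_norm(y + cross_attention(y, memory, *params["mha"], num_heads))
    return layer_norm(y + feed_forward(y, *params["ffn"]))


def init_params(d_model, d_ff, rng):
    def mat(rows, cols):
        bound = 1.0 / np.sqrt(rows)
        return rng.uniform(-bound, bound, size=(rows, cols))
    ffn = (mat(d_model, d_ff), np.zeros(d_ff), mat(d_ff, d_model), np.zeros(d_model))
    return {"ffn": ffn, "mha": tuple(mat(d_model, d_model) for _ in range(4))}


rng = np.random.default_rng(0)
coords = rng.random((10, 2))
w_embed = rng.standard_normal((2, 16))
enc_params, dec_params = init_params(16, 32, rng), init_params(16, 32, rng)
# Change (3): city embeddings enter the encoder without positional encoding.
memory = encoder_layer(coords @ w_embed, enc_params)
decoded = decoder_layer(memory[:1], memory, dec_params, num_heads=4)
```

One consequence worth noting: with self-attention removed from the encoder, interaction between cities in this sketch happens only through the decoder's cross-attention.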
3.2 Model Training with Reinforcement Learning
Reinforcement learning can provide a suitable technique for training neural networks for combinatorial optimization. We train the network model with model-free, policy-based reinforcement learning. Given an input graph $s$, the loss function is the ATL (Average Tour Length), defined as Eq. 3:

$$C(\theta \mid s) = \mathbb{E}_{p_\theta(\pi \mid s)}\, L(\pi \mid s). \tag{3}$$
Table 1. The accuracy improvement with greedy search. All results are averaged over 100 random instances. The TSP graphs have 20, 100, 200, 500, and 1000 city nodes. ATL: Average Tour Length; OGR: Optimality Gap Ratio. '–' indicates out-of-memory.

| Method | TSP20 ATL | TSP20 OGR | TSP100 ATL | TSP100 OGR | TSP200 ATL | TSP200 OGR | TSP500 ATL | TSP500 OGR | TSP1000 ATL | TSP1000 OGR |
|---|---|---|---|---|---|---|---|---|---|---|
| Concorde | 3.83 | 0.00% | 7.77 | 0.00% | 10.53 | 0.00% | 16.56 | 0.00% | 23.14 | 0.00% |
| Gurobi | 3.83 | 0.00% | 7.77 | 0.00% | 10.53 | 0.00% | 16.57 | 0.00% | – | – |
| OR-Tools | 3.84 | 0.02% | 7.97 | 2.72% | 11.78 | 3.79% | 17.46 | 5.3% | 26.48 | 14.5% |
| Vinyals et al. | 4.02 | 1.71% | 8.59 | 2.67% | – | – | – | – | – | – |
| Bello et al. | 3.97 | 1.65% | 8.56 | 2.58% | 13.63 | 25.82% | – | – | – | – |
| Dai et al. | 3.95 | 1.63% | 8.75 | 4.38% | 13.93 | 26.65% | – | – | – | – |
| Nazari et al. | 4.08 | 1.78% | 8.65 | 3.56% | 13.81 | 25.97% | – | – | – | – |
| Joshi et al. | 3.99 | 1.68% | 8.32 | 2.21% | 12.94 | 17.69% | – | – | – | – |
| Ma et al. | 4.12 | 1.83% | 8.44 | 2.37% | 13.25 | 24.76% | – | – | – | – |
| Deudon et al. | 4.24 | 2.12% | 8.92 | 4.78% | 12.82 | 17.63% | – | – | – | – |
| Kool et al. | 3.99 | 1.68% | 8.55 | 2.57% | 13.64 | 25.86% | – | – | – | – |
| Bresson et al. | 3.98 | 1.67% | 8.58 | 2.59% | 12.42 | 17.51% | – | – | – | – |
| Gasse et al. | 3.97 | 1.65% | 8.43 | 2.35% | 12.73 | 17.58% | – | – | – | – |
| Greedy search (ours) | 3.98 | 1.67% | 8.71 | 4.32% | 13.52 | 25.81% | 18.83 | 10.51% | 29.83 | 29.67% |
To optimize the network parameters, we use stochastic gradient descent with policy gradients. The gradient of Eq. 3 is computed with the well-known REINFORCE algorithm [31]:

$$\nabla_\theta C(\theta \mid s) = \mathbb{E}_{p_\theta(\pi \mid s)}\big[(L(\pi \mid s) - b(s))\, \nabla_\theta \log p_\theta(\pi \mid s)\big], \tag{4}$$

where $b(s)$ is a baseline that does not depend on $\pi$ and estimates the average tour length in order to reduce gradient variance. A good baseline $b(s)$ reduces gradient variance and speeds up learning. We draw a mini-batch of $B$ graphs $s_1, s_2, \ldots, s_B \sim \mathcal{S}$ and sample a single tour per graph, so the gradient in Eq. 4 is approximated with Monte Carlo sampling as follows:

$$\nabla_\theta C(\theta) \approx \frac{1}{B} \sum_{i=1}^{B} (L(\pi_i \mid s_i) - b(s_i))\, \nabla_\theta \log p_\theta(\pi_i \mid s_i). \tag{5}$$

To learn the average tour length obtained by the current policy $p_\theta$ for an input sequence $s$, we use an auxiliary critic network. The critic is trained with stochastic gradient descent on a mean squared error objective between its predictions $b_{\theta_2}(s)$ and the actual tour lengths sampled by the most recent policy:

$$\mathrm{Cri}(\theta_2) = \frac{1}{B} \sum_{i=1}^{B} \lVert L(\pi_i \mid s_i) - b_{\theta_2}(s_i) \rVert_2^2. \tag{6}$$
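The objective in Eq. 3, the estimator in Eq. 5, and the critic objective in Eq. 6 can be checked on a toy problem where $\nabla_\theta \log p_\theta(\pi \mid s)$ has a closed form. In the NumPy sketch below, which is an illustration under stated assumptions rather than the paper's training code, the policy is a masked softmax over one learnable logit per city index (it ignores the coordinates), tours start at city 0, and the critic is a hypothetical linear model $b_{\theta_2}(s) = w^\top \phi(s)$ on two hand-picked features. As in Eq. 5, the baseline is treated as a constant in the policy gradient.

```python
import numpy as np


def sample_tour(theta, rng):
    """Sample a tour from a toy policy (masked softmax over per-city logits theta).

    Returns the tour and grad_theta log p_theta(pi|s); each step contributes
    onehot(action) - probs, the gradient of a log-softmax.
    """
    n = theta.shape[0]
    visited = np.zeros(n, dtype=bool)
    visited[0] = True  # assumption: every tour starts at city 0
    tour, grad_log_p = [0], np.zeros(n)
    for _ in range(n - 1):
        logits = np.where(visited, -np.inf, theta)
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        action = int(rng.choice(n, p=probs))
        grad_log_p += np.eye(n)[action] - probs
        tour.append(action)
        visited[action] = True
    return tour, grad_log_p


def tour_length(coords, tour):
    ordered = coords[tour]
    return float(np.linalg.norm(ordered - np.roll(ordered, -1, axis=0), axis=1).sum())


def critic_features(coords):
    """Hypothetical critic input phi(s): a bias term and the mean distance to the centroid."""
    spread = np.linalg.norm(coords - coords.mean(axis=0), axis=1).mean()
    return np.array([1.0, spread])


def reinforce_step(theta, w_critic, graphs, rng, lr=0.1, lr_critic=0.1):
    """One update of the policy (Eq. 5) and of the linear critic (gradient of Eq. 6)."""
    batch = len(graphs)
    grad_policy = np.zeros_like(theta)
    grad_critic = np.zeros_like(w_critic)
    for coords in graphs:
        tour, grad_log_p = sample_tour(theta, rng)
        length = tour_length(coords, tour)
        phi = critic_features(coords)
        baseline = float(w_critic @ phi)  # b(s), held constant in the policy update
        grad_policy += (length - baseline) * grad_log_p / batch
        grad_critic += -2.0 * (length - baseline) * phi / batch
    # Gradient descent lowers the expected tour length and the critic's squared error.
    return theta - lr * grad_policy, w_critic - lr_critic * grad_critic


rng = np.random.default_rng(0)
graphs = [rng.random((5, 2)) for _ in range(4)]
theta, w_critic = np.zeros(5), np.zeros(2)
for _ in range(10):
    theta, w_critic = reinforce_step(theta, w_critic, graphs, rng)
```

Subtracting the baseline leaves the expected gradient unchanged because $\mathbb{E}[\nabla_\theta \log p_\theta] = 0$; in the sketch this shows up as each step's gradient summing to zero over cities.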
Table 2. The accuracy improvement with beam search (BS). All results are averaged over 100 random instances. The TSP graphs have 20, 100, 200, 500, and 1000 city nodes. ATL: Average Tour Length; OGR: Optimality Gap Ratio. '–' indicates out-of-memory.

| Method | TSP20 ATL | TSP20 OGR | TSP100 ATL | TSP100 OGR | TSP200 ATL | TSP200 OGR | TSP500 ATL | TSP500 OGR | TSP1000 ATL | TSP1000 OGR |
|---|---|---|---|---|---|---|---|---|---|---|
| Concorde | 3.83 | 0.00% | 7.77 | 0.00% | 10.53 | 0.00% | 16.56 | 0.00% | 23.14 | 0.00% |
| Gurobi | 3.83 | 0.00% | 7.77 | 0.00% | 10.53 | 0.00% | 16.57 | 0.00% | – | – |
| OR-Tools | 3.84 | 0.02% | 7.97 | 2.72% | 11.78 | 3.79% | 17.46 | 5.3% | 26.48 | 14.5% |
| Vinyals et al. | 3.96 | 1.15% | 8.23 | 2.56% | – | – | – | – | – | – |
| Bello et al. | 3.92 | 1.09% | 8.38 | 3.24% | 13.39 | 8.76% | – | – | – | – |
| Dai et al. | 3.94 | 1.12% | 8.35 | 3.08% | 13.89 | 8.95% | – | – | – | – |
| Nazari et al. | 3.96 | 1.15% | 8.48 | 3.29% | 13.21 | 8.43% | – | – | – | – |
| Joshi et al. | 3.98 | 1.18% | 8.14 | 2.21% | 12.88 | 7.97% | – | – | – | – |
| Ma et al. | 3.99 | 1.19% | 8.32 | 3.02% | 12.76 | 7.88% | – | – | – | – |
| Deudon et al. | 4.14 | 1.42% | 8.95 | 3.69% | 12.68 | 7.53% | – | – | – | – |
| Kool et al. | 3.96 | 1.15% | 7.98 | 1.29% | 13.34 | 8.66% | – | – | – | – |
| Bresson et al. | 3.89 | 1.07% | 7.89 | 0.39% | 12.42 | 7.21% | – | – | – | – |
| Gasse et al. | 3.94 | 1.12% | 8.11 | 2.18% | 12.22 | 6.35% | – | – | – | – |
| BS (b=20) | 3.92 | 1.09% | 8.42 | 3.24% | 12.16 | 6.27% | 18.25 | 7.86% | 26.55 | 9.85% |
| BS (b=200) | 3.89 | 1.07% | 8.38 | 3.24% | 12.03 | 6.01% | 17.85 | 6.79% | 25.56 | 8.87% |
| BS (b=1000) | 3.86 | 0.39% | 8.14 | 2.21% | 11.27 | 2.26% | 17.41 | 5.65% | 24.77 | 8.28% |
| BS (b=2000) | 3.84 | 0.35% | 7.79 | 0.14% | 10.72 | 2.05% | 16.85 | 3.26% | 23.68 | 3.89% |
4 Experiments
4.1 Datasets and Experimental Details
We use synthetic datasets in which city coordinates are sampled uniformly at random from the unit square [0,1] × [0,1]. All models are trained on a single GeForce RTX 2080 Ti GPU with 11 GB of GPU memory, an E5-2678 v3 CPU, and 32 GB of CPU memory [34]. Because of the GPU memory limit, we use a batch size of 32 for TSP20, TSP50, TSP100, and TSP200, 16 for TSP500, and 8 for TSP1000. Accordingly, in each epoch we generate 100 random mini-batches for TSP20, TSP50, TSP100, and TSP200, and 200 mini-batches for TSP500 and TSP1000. TSP20 and TSP50 are trained for 200 epochs, as convergence is faster for smaller problems, whereas TSP100, TSP200, TSP500, and TSP1000 are trained for 300 epochs. Our experiments use 256 hidden units and embed the two coordinates of each point in a 512-dimensional space. We train our models with the Adam optimizer, using an initial learning rate of $10^{-3}$ that decays by a factor of 0.96 every 5000 steps. We initialize our parameters uniformly at random within $[-1/\sqrt{d}, 1/\sqrt{d}]$ and clip the L2 norm of our gradients to 2.0.
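The training configuration above can be summarized in code. The NumPy sketch below generates uniform random instances and implements the learning-rate schedule, the uniform initialization, and gradient clipping as described. Three points are assumptions rather than details stated in the paper: the decay is applied as a staircase (the rate changes only at multiples of 5000 steps), $d$ is taken to be the input dimension of the weight matrix being initialized, and clipping rescales all gradients jointly by their global L2 norm.

```python
import numpy as np

# Per-problem-size settings taken from the text.
BATCH_SIZE = {20: 32, 50: 32, 100: 32, 200: 32, 500: 16, 1000: 8}
BATCHES_PER_EPOCH = {20: 100, 50: 100, 100: 100, 200: 100, 500: 200, 1000: 200}
EPOCHS = {20: 200, 50: 200, 100: 300, 200: 300, 500: 300, 1000: 300}


def generate_instances(batch_size, num_nodes, rng):
    """Random TSP instances: coordinates drawn uniformly from the unit square."""
    return rng.uniform(0.0, 1.0, size=(batch_size, num_nodes, 2))


def learning_rate(step, base_lr=1e-3, decay_rate=0.96, decay_steps=5000):
    """Initial rate 1e-3, multiplied by 0.96 every 5000 steps (staircase assumed)."""
    return base_lr * decay_rate ** (step // decay_steps)


def init_uniform(shape, d, rng):
    """Uniform initialization within [-1/sqrt(d), 1/sqrt(d)]."""
    bound = 1.0 / np.sqrt(d)
    return rng.uniform(-bound, bound, size=shape)


def clip_by_global_norm(grads, max_norm=2.0):
    """Rescale a list of gradient arrays so that their joint L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float(np.sum(g * g)) for g in grads))
    if total_norm <= max_norm:
        return grads
    return [g * (max_norm / total_norm) for g in grads]


rng = np.random.default_rng(0)
batch = generate_instances(BATCH_SIZE[500], 500, rng)
w_hidden = init_uniform((512, 256), 512, rng)
clipped = clip_by_global_norm([np.array([3.0, 4.0])])
```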
4.2 Results and Analysis
Searching at inference time is necessary to get closer to optimality, but it comes at the cost of longer running times. Table 1 shows the accuracy improvement from greedy search, Table 2 from beam search, and Table 3 from 2-opt search. To evaluate our model against similar baseline works, we report two metrics: Average Tour Length (ATL) and Optimality Gap Ratio (OGR), the average percentage ratio of the predicted tour length relative to the optimal solutions. The solutions obtained by the three solvers Concorde, Gurobi, and OR-Tools serve as the comparison baseline for all other methods.
Table 3. The accuracy improvement with 2-opt search. All results are averaged over 100 random instances. The TSP graphs have 20, 100, 200, 500, and 1000 city nodes. ATL: Average Tour Length; OGR: Optimality Gap Ratio. '–' indicates out-of-memory.
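| Method | TSP20 ATL | TSP20 OGR | TSP100 ATL | TSP100 OGR | TSP200 ATL | TSP200 OGR | TSP500 ATL | TSP500 OGR | TSP1000 ATL | TSP1000 OGR |
|---|---|---|---|---|---|---|---|---|---|---|
| Concorde | 3.83 | 0.00% | 7.77 | 0.00% | 10.53 | 0.00% | 16.56 | 0.00% | 23.14 | 0.00% |
| Gurobi | 3.83 | 0.00% | 7.77 | 0.00% | 10.53 | 0.00% | 16.57 | 0.00% | – | – |
| OR-Tools | 3.84 | 0.02% | 7.97 | 2.72% | 11.78 | 3.79% | 17.46 | 5.3% | 26.48 | 14.5% |
| Vinyals et al. | 3.89 | 1.05% | 8.29 | 2.98% | – | – | – | – | – | – |
| Bello et al. | 3.89 | 1.05% | 8.32 | 3.04% | 13.33 | 5.42% | – | – | – | – |
| Dai et al. | 3.88 | 1.03% | 8.29 | 2.98% | 13.77 | 5.93% | – | – | – | – |
| Nazari et al. | 3.98 | 1.12% | 8.45 | 3.26% | 13.01 | 5.13% | – | – | – | – |
| Joshi et al. | 3.97 | 1.11% | 8.04 | 2.01% | 12.84 | 4.29% | – | – | – | – |
| Ma et al. | 3.98 | 1.12% | 8.32 | 3.04% | 13.33 | 5.42% | – | – | – | – |
| Deudon et al. | 4.02 | 1.32% | 8.83 | 3.58% | 13.45 | 5.76% | – | – | – | – |
| Kool et al. | 3.89 | 1.05% | 7.95 | 1.27% | 13.24 | 5.26% | – | – | – | – |
| Bresson et al. | 3.88 | 1.03% | 7.78 | 0.36% | 12.32 | 4.01% | – | – | – | – |
| Gasse et al. | 3.89 | 1.05% | 8.45 | 3.26% | 13.36 | 5.47% | – | – | – | – |
| 2-opt search (ours) | 3.84 | 0.02% | 7.78 | 0.36% | 11.24 | 3.46% | 17.87 | 4.14% | 24.32 | 6.28% |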
In Table 1, our greedy search is not as good as other authors' greedy search, because simplifying the network structure reduces the predictive ability of the model. However, we can extend the number of solvable nodes to 1000, while other methods can solve at most 200. This shows that we have reduced memory usage and enlarged the size of the instances that can be solved. In Table 2, we also set the search widths of the compared works (the third group of rows) to 20, 200, 1000, and 2000 and report the average of their best results. For our beam search, the inference results differ substantially across search widths of 20, 200, 1000, and 2000, and so does the running time: the larger the search width, the longer the search takes.
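The greedy and beam search decoders compared in Tables 1 and 2 can be sketched as follows. This NumPy sketch is not the authors' implementation: `step_log_probs` is a hypothetical stand-in for the trained model (a log-softmax of negative distances to the current city), tours start at city 0, and the final beam is resolved by returning the shortest complete tour, a common choice that the text does not specify. Greedy search keeps only the single most probable extension at each step, while beam search keeps the `beam_width` partial tours with the highest cumulative log-probability, which is why its running time grows with the width.

```python
import heapq

import numpy as np


def tour_length(coords, tour):
    ordered = coords[tour]
    return float(np.linalg.norm(ordered - np.roll(ordered, -1, axis=0), axis=1).sum())


def step_log_probs(coords, tour, visited, temperature=0.1):
    """Hypothetical model stand-in: log-softmax of -distance/temperature over unvisited cities."""
    dist = np.linalg.norm(coords - coords[tour[-1]], axis=1)
    logits = np.where(visited, -np.inf, -dist / temperature)
    logits = logits - logits.max()
    return logits - np.log(np.exp(logits).sum())


def greedy_decode(coords, start=0):
    """Always append the single most probable unvisited city."""
    n = len(coords)
    visited = np.zeros(n, dtype=bool)
    visited[start] = True
    tour = [start]
    while len(tour) < n:
        city = int(np.argmax(step_log_probs(coords, tour, visited)))
        tour.append(city)
        visited[city] = True
    return tour


def beam_search(coords, beam_width, start=0):
    """Keep the beam_width partial tours with the highest cumulative log-probability."""
    n = len(coords)
    beams = [(0.0, [start])]
    for _ in range(n - 1):
        candidates = []
        for score, tour in beams:
            visited = np.zeros(n, dtype=bool)
            visited[tour] = True
            log_p = step_log_probs(coords, tour, visited)
            for city in np.flatnonzero(~visited):
                candidates.append((score + float(log_p[city]), tour + [int(city)]))
        beams = heapq.nlargest(beam_width, candidates, key=lambda c: c[0])
    # Assumption: pick the shortest complete tour among the surviving beams.
    return min((tour for _, tour in beams), key=lambda t: tour_length(coords, t))


coords = np.random.default_rng(1).random((6, 2))
greedy_tour = greedy_decode(coords)
beam_tour = beam_search(coords, beam_width=5)
```

With `beam_width=1` the beam search reduces to greedy decoding; with a width large enough to keep every partial tour it becomes exhaustive, which is only feasible for very small instances.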
In Table 3, our 2-opt search achieves a better OGR than similar works: although the network structure is simplified, the 2-opt search effectively compensates for the loss of prediction accuracy at the inference stage. This also shows that, among the search algorithms considered, 2-opt compensates well for the loss caused by reducing the number of parameters during training. From Table 1, Table 2, and Table 3, we observe that: (1) our model can find the Average Tour Length (ATL) for 1000 city nodes, while other works are limited to at most 200 nodes because they run out of memory; (2) on the 20-, 100-, and 200-node instances, among the three search methods (greedy search, beam search, and 2-opt search), greedy search has the worst accuracy, 2-opt search has the best, and beam search (with widths of 20, 200, and 1000) lies in between; (3) we train models directly on 500- and 1000-node instances and test with these models, rather than training on 200 nodes and generalizing to 500 and 1000 nodes. The significance of this is that a model trained with 1000 nodes can be used for inference on test instances of up to 10000 nodes. Since training takes a very long time, we use a single trained model to measure the inference time of each heuristic search method, as shown in Table 4. Greedy search takes the least time, so its inference is fast. The time taken by beam search varies greatly with the search width. The 2-opt search is faster than beam search with widths of 20 and 200, but slower than beam search with widths of 1000 and 2000.
Overall, beam search and 2-opt trade running time for higher solution accuracy, whereas greedy search does the opposite.
Table 4. Inference time of the heuristic search methods. Times are the totals over 100 randomly generated test instances, in seconds.
Method Greedy Search
TSP20 0.038 s
TSP100 TSP200 TSP500 1.42 s
4.81 s 89.02 s
TSP1000
13.89 s
129.47 s
120.76 s
821.53 s
Beam Search(b=20)
1.31 s
14.54 s
Beam Search(b=200)
5.14 s
62.43 s 241.34 s
461.13 s 1802.29 s
Beam Search(b=1000) 12.56 s
294.52 s 472.33 s
932.21 s 2503.24 s
Beam Search(b=2000) 27.11 s
375.27 s 745.43 s 1421.76 s 3957.62 s
2-opt Search
13.55 s
53.87 s 257.38 s
732.42 s 2103.27 s
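For reference, a basic 2-opt local search of the kind evaluated in Tables 3 and 4 is sketched below in NumPy. It repeatedly looks for two edges $(a, b)$ and $(c, d)$ whose replacement by $(a, c)$ and $(b, d)$ shortens the tour, which amounts to reversing the segment between them, and it stops when a full pass finds no improvement or after a fixed number of passes. The first-improvement strategy and the `max_passes` limit are assumptions; the paper does not state how its 2-opt search is configured.

```python
import numpy as np


def tour_length(coords, tour):
    ordered = coords[tour]
    return float(np.linalg.norm(ordered - np.roll(ordered, -1, axis=0), axis=1).sum())


def two_opt(coords, tour, max_passes=100, tol=1e-12):
    """First-improvement 2-opt: reverse a segment whenever doing so shortens the tour."""
    tour = list(tour)
    n = len(tour)

    def dist(p, q):
        return float(np.linalg.norm(coords[p] - coords[q]))

    for _ in range(max_passes):
        improved = False
        for i in range(1, n - 1):
            for j in range(i + 1, n):
                a, b = tour[i - 1], tour[i]
                c, d = tour[j], tour[(j + 1) % n]
                # Swapping edges (a,b),(c,d) for (a,c),(b,d) reverses tour[i..j].
                delta = dist(a, c) + dist(b, d) - dist(a, b) - dist(c, d)
                if delta < -tol:
                    tour[i:j + 1] = reversed(tour[i:j + 1])
                    improved = True
        if not improved:
            break
    return tour


square = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
improved = two_opt(square, [0, 2, 1, 3])  # removes the crossing diagonals
pts = np.random.default_rng(3).random((12, 2))
refined = two_opt(pts, list(range(12)))
```

At inference time, such a routine would typically be applied to a tour decoded by the model; each pass costs $O(n^2)$ edge-pair checks, so its running time grows quickly with the number of city nodes.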
5 Conclusion
This paper designed a new Transformer-based network architecture to better solve the TSP and refined the predictions of the inference phase with heuristic searches such as greedy search, beam search, and 2-opt search. Comparison with similar work by other authors shows that our network architecture achieves state-of-the-art results, with the optimality gap ratio closest to the optimal solution. In the future, we will extend the number of city nodes from 1000 to 10000 while keeping the solution accuracy within a certain range. Our most distinctive contribution is that the solution accuracy remains relatively high as the number of city nodes grows. The main limitation for future research is insufficient CPU or GPU memory.
References
1. Arora, S.: The approximability of NP-hard problems. In: Proceedings of the Thirtieth Annual ACM Symposium on Theory of Computing, pp. 337–348 (1998)
2. Bahdanau, D., Cho, K.H., Bengio, Y.: Neural machine translation by jointly learning to align and translate. In: 3rd International Conference on Learning Representations, ICLR 2015 (2015)
3. Bello, I., Pham, H., Le, Q.V., Norouzi, M., Bengio, S.: Neural combinatorial optimization with reinforcement learning. arXiv preprint arXiv:1611.09940 (2016)
4. Boese, K.D.: Cost versus distance in the traveling salesman problem. Citeseer (1995)
5. Bresson, X., Laurent, T.: The transformer network for the traveling salesman problem. arXiv preprint arXiv:2103.03012 (2021)
6. Cook, W., Lovász, L., Seymour, P.D., et al.: Combinatorial optimization: papers from the DIMACS Special Year, vol. 20. American Mathematical Soc. (1995)
7. Deudon, M., Cournut, P., Lacoste, A., Adulyasak, Y., Rousseau, L.-M.: Learning heuristics for the TSP by policy gradient. In: van Hoeve, W.-J. (ed.) CPAIOR 2018. LNCS, vol. 10848, pp. 170–181. Springer, Cham (2018). https://doi.org/10.1007/978-3-319-93031-2_12
8. Gasse, M., Chételat, D., Ferroni, N., Charlin, L., Lodi, A.: Exact combinatorial optimization with graph convolutional neural networks. In: Advances in Neural Information Processing Systems, vol. 32 (2019)
9. Google, I.: Google optimization tools (or-tools) (2018). https://github.com/google/or-tools
10. Gutin, G., Punnen, A.P.: The Traveling Salesman Problem and Its Variations, vol. 12. Springer, New York (2006). https://doi.org/10.1007/b101971
11. Helsgaun, K.: An extension of the Lin-Kernighan-Helsgaun TSP solver for constrained traveling salesman and vehicle routing problems. Technical report (2017)
12. Hochba, D.S.: Approximation algorithms for np-hard problems. ACM SIGACT News 28(2), 40–52 (1997)
13. Hopfield, J.J., Tank, D.W.: Neural computation of decisions in optimization problems. Biol. Cybern. 52(3), 141–152 (1985)
14. Johnson, D.: Local search and the traveling salesman problem. In: Automata Languages and Programming. LNCS, pp. 443–460. Springer, Berlin (1990)
15. Joshi, C.K., Cappart, Q., Rousseau, L.M., Laurent, T., Bresson, X.: Learning TSP requires rethinking generalization. arXiv preprint arXiv:2006.07054 (2020)
16. Jünger, M., Reinelt, G., Rinaldi, G.: The traveling salesman problem. In: Handbooks in Operations Research and Management Science, vol. 7, pp. 225–330 (1995)
17. Khalil, E., Dai, H., Zhang, Y., Dilkina, B., Song, L.: Learning combinatorial optimization algorithms over graphs. In: Advances in Neural Information Processing Systems, pp. 6348–6358 (2017)
18. Kool, W., van Hoof, H., Welling, M.: Attention, learn to solve routing problems! In: International Conference on Learning Representations (2018)
19. Li, W., Ding, Y., Yang, Y., Sherratt, R.S., Park, J.H., Wang, J.: Parameterized algorithms of fundamental np-hard problems: a survey. HCIS 10(1), 1–24 (2020)
20. Lin, S., Kernighan, B.W.: An effective heuristic algorithm for the traveling-salesman problem. Oper. Res. 21(2), 498–516 (1973)
21. Ma, Q., Ge, S., He, D., Thaker, D., Drori, I.: Combinatorial optimization by graph pointer networks and hierarchical reinforcement learning. arXiv preprint arXiv:1911.04936 (2019)
22. Nazari, M., Oroojlooy, A., Snyder, L., Takáč, M.: Reinforcement learning for solving the vehicle routing problem. In: Advances in Neural Information Processing Systems, pp. 9839–9849 (2018)
23. Gurobi Optimization: Gurobi optimizer reference manual (2018). http://www.gurobi.com
24. Papadimitriou, C.H., Steiglitz, K.: Combinatorial optimization: algorithms and complexity. Courier Corporation (1998)
25. Peng, B., Wang, J., Zhang, Z.: A deep reinforcement learning algorithm using dynamic attention model for vehicle routing problems. In: Li, K., Li, W., Wang, H., Liu, Y. (eds.) ISICA 2019. CCIS, vol. 1205, pp. 636–650. Springer, Singapore (2020). https://doi.org/10.1007/978-981-15-5577-0_51
26. Sutton, R.S., Barto, A.G.: Reinforcement Learning: An Introduction. MIT Press, Cambridge (2018)
27.
Chvatal, V., Applegate, D.L., Bixby, R.E., Cook, W.J.: Concorde TSP solver (2006). www.math.uwaterloo.ca/tsp/concorde 28. Vaswani, A., et al.: Attention is all you need. In: Advances in Neural Information Processing Systems, pp. 5998–6008 (2017) 29. Vinyals, O., Fortunato, M., Jaitly, N.: Pointer networks. Comput. Sci. 28 (2015) 30. Welling, M., Kipf, T.N.: Semi-supervised classification with graph convolutional networks. In: International Conference on Learning Representations (ICLR 2017) (2016) 31. Williams, R.J.: Simple statistical gradient-following algorithms for connectionist reinforcement learning. Mach. Learn. 8(3), 229–256 (1992) 32. Woeginger, G.J.: Exact algorithms for NP-hard problems: a survey. In: J¨ unger, M., Reinelt, G., Rinaldi, G. (eds.) Combinatorial Optimization — Eureka, You Shrink! LNCS, vol. 2570, pp. 185–207. Springer, Heidelberg (2003). https://doi. org/10.1007/3-540-36478-1 17 33. Xing, Z., Tu, S., Xu, L.: Solve traveling salesman problem by Monte Carlo tree search and deep neural network. arXiv preprint arXiv:2005.06879 (2020) 34. Yang, H.: Extended attention mechanism for tsp problem. In: 2021 International Joint Conference on Neural Networks (IJCNN), pp. 1–8. IEEE (2021)
A Joint Learning Model for Open Set Recognition with Post-processing
Qinglin Li¹, Guanyu Xing¹, and Yanli Liu²(B)
¹ National Key Laboratory of Fundamental Science on Synthetic Vision, Sichuan University, Chengdu, China
² College of Computer Science, Sichuan University, Chengdu, China

Abstract. In open set recognition (OSR), a model must not only correctly recognize samples of known classes but also effectively reject unknown samples. To address this problem, we propose a joint learning model with post-processing based on the concept of Reciprocal Points. Specifically, to guarantee the accuracy of known-class recognition, we design a two-branch network containing a self-supervised branch and a classification branch; the self-supervised branch helps the model classify known classes more accurately. Then, to avoid misjudging unknown samples as known ones with high confidence, we redesign the open loss to better separate the known and unknown spaces, and we design a post-processing mechanism that penalizes the predictions of potential unknown samples. Experiments and ablations show that our model obtains state-of-the-art results on most datasets for open set recognition and unknown detection tasks.
Keywords: Open set recognition · Self-supervised learning · Post-processing

1 Introduction
In the past few years, deep learning has achieved state-of-the-art performance in computer vision tasks such as detection and recognition, but many challenges remain in solving real-world problems. A typical problem is that the information presented and learned during training is incomplete. Most current deep learning methods rest on the closed-set assumption, i.e., their training and test data come from the same label space. In real-world scenarios, however, samples of unknown classes often appear, and traditional classifiers tend to misrecognize them as one of the known classes, which greatly limits their usefulness in practice. To overcome this limitation, open set recognition was proposed in [12]: training uses only known-class data, while test samples can come from any class.

In the open set scenario, samples included in or excluded from the label space are referred to as knowns and unknowns, respectively. The main task of open set recognition is not only to correctly identify known-class samples, but also to detect unknown-class samples. With the development of deep learning, many OSR methods have emerged. [1] replaces Softmax in CNNs with Openmax, which redistributes the Softmax probabilities to obtain a probability for unknown samples, but the training stage is not optimized. [4,5,9,10,15] use generative models such as GANs or VAEs to generate potential unknown samples or to reconstruct the input, thereby optimizing the training phase for open set recognition. However, [8] questioned the reliability of mainstream generative models for this problem. To address these issues, RPL [3] proposed an open set recognition method based on Reciprocal Points, which represent the out-of-class space, also called open space. Learning the reciprocal points forms a bounded space that contains them, while the known-class samples lie outside it. This bounded space therefore represents potential unknown samples, which separates the known and unknown spaces, as shown in Fig. 1. RPL is simple yet performs well on most datasets. However, our experiments reveal two remaining problems: first, when an unknown sample is input, the model may still classify it as a known class with high confidence; second, some known samples are still misclassified. Other OSR methods share both problems. To overcome these shortcomings, we propose a novel method based on Reciprocal Points: a joint learning model for open set recognition with post-processing.

This research is supported by National Natural Science Foundation of China (Grant No. 61972271) and Sichuan Science and Technology Program (No. 2022YFS0557).
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023. M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 420–432, 2023. https://doi.org/10.1007/978-981-99-1639-9_35
First, in addition to the traditional classification branch, we introduce a self-supervised branch that learns the semantic information of known classes better, and thereby classifies them more accurately. The self-supervised branch and the classification branch share network weights. Second, we redesign the open loss and design a penalty mechanism. The redesigned open loss better separates known and unknown classes and learns the dynamic boundary of the bounded space. We then apply penalties based on the prediction of the self-supervised branch and on the dynamic boundary, which prevents unknown-class samples from being classified as known classes with high confidence. Our contributions are as follows:
• We design a joint learning model with a self-supervised branch that effectively learns semantic information for classification; the self-supervised branch is used in both the training and the test phase;
• We redesign the open loss to separate the known and unknown spaces and to learn the dynamic boundary between them;
• We propose a post-processing penalty mechanism that effectively avoids predicting unknown-class samples as known ones with high confidence;
• Extensive experiments demonstrate that our method achieves the best performance in most cases.
Fig. 1. An overview of classification based on the Reciprocal Points. The out-of-class embedding space (open space) for each known class is restricted to a bounded range. Thus, the known class and the unknown class are separated.
2 Preliminaries
The key to OSR is to simultaneously reduce the empirical classification risk on labeled known data and the open space risk on potential unknown data [12]. Accordingly, the total loss of the OSR problem can be expressed as:

$$\arg\min_{f}\left\{\sum_{k=1}^{N} R_e(f, \mathcal{D}_L) + \alpha \sum_{k=1}^{N} R_o(f, \mathcal{D}_U)\right\} \tag{1}$$
where $f:\mathbb{R}^d \to \mathbb{N}$ is a measurable multiclass recognition function, $\alpha$ is the regularization parameter, $R_e$ is the empirical classification loss on known data, and $R_o$ is the open space loss, which measures the uncertainty of classifying unknown samples as known or unknown classes. $\mathcal{D}_L$ denotes the known-class samples used in the training phase, and $\mathcal{D}_U$ denotes the unknown-class samples. The concept of Reciprocal Points comes from [3]. The reciprocal point $P^k$ of class $k$ can be regarded as the representation of the known classes other than $k$ together with the unknown classes, $\mathcal{D}_L^{\neq k} \cup \mathcal{D}_U$. In the embedding feature space, the probability that sample $x$ belongs to class $k$ is proportional to the distance between the embedding $f_\theta(x)$ and the reciprocal point $P^k$ [3]. Normalizing with the softmax function gives the final probability:

$$p(y = k \mid x, f_\theta, \mathcal{P}) = \frac{e^{d(f_\theta(x), P^k)}}{\sum_{i=1}^{N} e^{d(f_\theta(x), P^i)}} \tag{2}$$
where $\mathcal{P}$ denotes the reciprocal points of all known classes, and $d(f_\theta(x), P^k)$ is computed by combining Euclidean distance and dot product. The parameters $\theta$ are learned by minimizing the reciprocal-point classification loss, i.e., the negative log-probability of the true label $k$:

$$L_{cls}(x; \theta, \mathcal{P}) = -\log p(y = k \mid x, f_\theta, \mathcal{P}) \tag{3}$$
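Eqs. (2) and (3) can be illustrated with a minimal NumPy sketch. The text states only that $d$ combines Euclidean distance and dot product; the specific form $d = \lVert f_\theta(x) - P^k \rVert^2 - f_\theta(x) \cdot P^k$ used here is an assumption for illustration, and the function names are hypothetical.

```python
import numpy as np

def rp_distance(emb, points):
    """d(f_theta(x), P^k) for every sample/class pair.

    emb: (B, D) embeddings f_theta(x); points: (N, D) reciprocal points P^k.
    ASSUMPTION: squared Euclidean distance minus dot product.
    """
    sq = ((emb[:, None, :] - points[None, :, :]) ** 2).sum(axis=-1)  # (B, N)
    dot = emb @ points.T                                               # (B, N)
    return sq - dot

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cls_loss(emb, points, labels):
    # Eq. (2): the farther f_theta(x) lies from P^k, the higher p(y = k | x)
    probs = softmax(rp_distance(emb, points))
    # Eq. (3): negative log-probability of the true label, averaged over the batch
    return -np.log(probs[np.arange(len(labels)), labels]).mean()

# Toy check: a sample lying on P^1 is far from P^0, so it is assigned to class 0.
points = np.array([[0.0, 0.0], [3.0, 0.0]])
emb = np.array([[3.0, 0.0]])
print(softmax(rp_distance(emb, points)).argmax(axis=1))  # [0]
```

Note the sign convention: unlike a prototype classifier, a sample is assigned to the class whose reciprocal point is farthest away, because $P^k$ represents everything that is not class $k$.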
Minimizing Eq. 3 is equivalent to minimizing $R_e(f, \mathcal{D}_L)$ in Eq. 1. The reciprocal-point classifier reduces the empirical classification risk by updating the reciprocal points and the model weights, thereby improving classification accuracy.
3 Proposed Method

3.1 Joint Learning with Self-supervision
Traditional closed-set classifiers learn only the features needed to distinguish the known classes. These features may not fully describe a class, and some class-specific features may even be missing; this drawback is magnified in the open set recognition task. Since self-supervised learning is good at learning semantic information, we introduce a self-supervised branch to form a joint learning model that learns the semantic information of known classes better, so that known-class samples can be classified more accurately. We adopt the self-supervised framework proposed in [6]: a geometric transformation, drawn at random from a finite set of geometric transformations, is applied to the input, and the self-supervised branch of the model predicts which transformation was applied. To determine the applied transformation, the model must learn more semantic information about the image, i.e., structural properties of its content such as shape and orientation. To train the self-supervised branch to predict the rotation correctly, we minimize $L_{ss}(x_i, \theta)$, defined as:

$$L_{ss}(x_i, \theta) = -\frac{1}{K}\sum_{y=1}^{K} \log F^{y}\big(g(x_i \mid y) \mid \theta\big) \tag{4}$$
where $K$ is the number of transformations, $g(\cdot \mid y)$ applies the geometric transformation with label $y$ to the sample, $F^{y}(\cdot \mid \theta)$ is the predicted probability that the transformed input has label $y$, and $\theta$ denotes the learnable parameters of the model $F(\cdot)$, which shares weights with the classification branch. In this way, the semantic information of known-class images is learned better, which makes the classification of known classes more accurate. Previous work [11] also used self-supervision; unlike it, our self-supervised branch not only helps learn semantic information during training, but its predictions also serve as a prior in the test phase for judging whether an input belongs to a known or an unknown class.

3.2 Learning Dynamic Boundary
The open space risk is directly related to the separation of known and unknown spaces. As shown in Fig. 1, we aim to confine unknown samples to a bounded space by learning the reciprocal points. To this end, we rely on two observations: the known space and the unknown space are complementary, and unknown samples are closer to the reciprocal points than known samples.
Fig. 2. (a) Open loss without margin: some known-class samples fall into the unknown space (inside the circle). (b) Open loss with margin, which keeps the known samples outside the boundary as far as possible. (Legend: known sample, unknown sample, reciprocal point; R marks the boundary.)
RPL [3] learns a dynamic boundary $R$ such that the distance from a known sample to the reciprocal point stays close to $R$, which indirectly confines the unknown space within this boundary:

$$L_o(x; \theta, P^k, R^k) = \left\lVert d(f_\theta(x), P^k) - R^k \right\rVert_2^2 \tag{5}$$
where $d(f_\theta(x), P^k)$ denotes the distance between a sample of known class $k$ and the corresponding reciprocal point $P^k$, and $R^k$ is a learnable value representing the minimum of $d(f_\theta(x), P^k)$; each class has its own $R^k$. With Eq. 5, the distance from the center of each known class to the reciprocal points stays close to the dynamic boundary, giving the effect shown in Fig. 2(a). However, as Fig. 2(a) also shows, some known-class samples fall into the unknown space, so the known and unknown spaces are not sufficiently separated. We therefore redesign the open loss as follows:

$$L_o(x; \theta, P^k, R^k) = \left\lVert d(f_\theta(x), P^k) - (R^k + \mathrm{margin}) \right\rVert_2^2 \tag{6}$$
where margin represents the expected distance between $R^k$ and $d(f_\theta(x), P^k)$. Adding the margin pushes the center of each known class farther from the reciprocal points than in Fig. 2(a), so that the known and unknown spaces are separated more thoroughly, as shown in Fig. 2(b). We define the margin as follows:

$$\mathrm{margin} = \sqrt{\frac{1}{N-1}\sum_{i=1}^{N}\left(d(x,i) - \frac{1}{N}\sum_{j=1}^{N} d(x,j)\right)^2} \tag{7}$$

where $d(x,i) = d(f_\theta(x), P^i)$; that is, the margin is the standard deviation of the distances between sample $x$ and all reciprocal points. Minimizing Eq. 6 corresponds to minimizing $R_o$ in Eq. 1, which controls the open loss. $R_e$ drives $d(f_\theta(x), P^k)$ to increase, enlarging the unknown class space, whereas $R_o$ keeps $d(f_\theta(x), P^k)$ close to the dynamic boundary $R^k$ plus the margin, which partly counteracts this increase. The two terms eventually balance, yielding known and unknown spaces separated by a clear boundary.
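A minimal NumPy sketch of the open loss in Eqs. (5)–(7) follows, assuming the distances $d(f_\theta(x), P^i)$ have already been computed. The function names are hypothetical, and gradient handling (for example, whether the margin is treated as a constant during backpropagation) is not specified in the text and is left out here.

```python
import numpy as np

def margin(dists):
    # Eq. (7): sample standard deviation (N - 1 in the denominator) of the
    # distances from x to all N reciprocal points, one value per sample
    return dists.std(axis=1, ddof=1)

def open_loss(dists, labels, R, use_margin=True):
    """Eq. (6), or Eq. (5) when use_margin=False, averaged over the batch.

    dists:  (B, N) distances d(f_theta(x), P^i)
    labels: (B,)   true known-class indices k
    R:      (N,)   per-class dynamic boundaries R^k
    """
    d_k = dists[np.arange(len(labels)), labels]  # distance to the sample's own P^k
    target = R[labels]
    if use_margin:
        target = target + margin(dists)          # push the target beyond R^k
    return np.mean((d_k - target) ** 2)

dists = np.array([[4.0, 1.0, 1.0]])
labels = np.array([0])
R = np.array([2.0, 0.0, 0.0])
print(open_loss(dists, labels, R, use_margin=False))  # 4.0
print(round(open_loss(dists, labels, R), 4))          # 0.0718
```

In this toy case the margin is $\sqrt{3} \approx 1.73$, so the target distance moves from $R^0 = 2$ to about $3.73$, and a sample at distance 4 is penalized far less than under Eq. (5).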
3.3 Post-processing Penalty Mechanism
To prevent the classifier from predicting unknown classes as known classes with high confidence, we design two penalty mechanisms that use the information learned by the network as a prior to penalize potential unknown samples.

Penalize Based on the Dynamic Boundary R Learned in Sect. 3.2. In Sect. 3.2, we learned, for each known class, the minimum distance to the corresponding reciprocal point, which we call the dynamic boundary $R^k$. Let $d(x, P^k)$ be the distance between the current sample $x$ and the reciprocal point of its predicted class $k$, $R^k$ the dynamic boundary of class $k$, and $p(y = k \mid x)$ the probability that the classifier predicts $x$ as class $k$. We apply the following penalty:
$$p(y = k \mid x) = \begin{cases} p(y = k \mid x)^2, & d(x, P^k) < R^k \\ p(y = k \mid x), & \text{otherwise} \end{cases} \tag{8}$$

Suppose the classifier predicts that sample $x$ belongs to class $k$. If the distance between $x$ and the corresponding reciprocal point $P^k$ in the embedding feature space is smaller than the dynamic boundary $R^k$, then $x$ in theory very likely belongs to an unknown class, so we penalize the prediction score to keep the classifier from misjudging $x$ as a known class with high confidence. If the distance is greater than or equal to $R^k$, as expected for a known sample, we regard $x$ as belonging to the known class $k$ and apply no penalty.

Penalize Based on the Self-supervised Learning Branch in Sect. 3.1. Let $x_r$ be the geometrically transformed version of $x$ that is fed to the self-supervised branch, $r$ its transformation label, $\hat{y}$ the label predicted by the self-supervised branch, and $p(y = k \mid x)$ the probability that the classifier predicts $x$ as class $k$. We then apply:

$$p(y = k \mid x) = \begin{cases} p(y = k \mid x)^2, & \hat{y} \neq r \\ p(y = k \mid x), & \text{otherwise} \end{cases} \tag{9}$$

For the known samples, which take part in self-supervised training, the self-supervised classifier can correctly predict the geometric transformation, whereas it is difficult to do so for unknown-class samples that appear only in the test phase. Therefore, when the self-supervised classifier fails to predict the geometric transformation of the input $x$, we treat $x$ as a potential unknown and penalize its prediction score so that it is not predicted as a known class with high confidence. Conversely, if the geometric transformation of $x$ is predicted correctly, $x$ is considered to belong to a known class and no penalty is applied.
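As a worked example with an illustrative score of $p = 0.9$ (not a value from the experiments), each penalty alone squares the score, and a sample flagged by both criteria receives the fourth power, which matches the strongest penalty level used at inference:

$$p = 0.9:\qquad p^2 = 0.81 \ \text{(Eq. (8) or Eq. (9) alone)},\qquad (p^2)^2 = p^4 = 0.6561 \ \text{(both)}.$$

Because $0 \le p \le 1$, squaring never increases the score, and it leaves scores of exactly 0 or 1 unchanged.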
3.4 Learning the Open Set Network
Our network architecture is shown in Fig. 3. The OSR model has two branches. One branch resembles a traditional classifier and mainly learns classification based on reciprocal points. For a fair comparison, our CNN feature extractor is the same as the architecture used in [9], a ten-layer convolutional network. The other branch performs self-supervised learning as an auxiliary branch for open set recognition. Its input is a geometrically transformed sample $x_r$ with self-supervised label $r$. We use four rotations, $\{0^\circ, 90^\circ, 180^\circ, 270^\circ\}$, with labels $\{0, 1, 2, 3\}$, and the self-supervised branch performs four-way classification on the embedding to predict which rotation was applied. The CNN parts of the self-supervised branch and the classification branch share weights.
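The rotation pretext task can be sketched as follows in NumPy. The sketch builds all $K = 4$ rotated copies $g(x_i \mid y)$ of one image, as in Eq. (4), and scores them with the self-supervised loss. The rotation logits would come from the self-supervised head, which is not reproduced here; the counter-clockwise direction of `np.rot90` and the function names are illustrative assumptions, since the paper specifies only the four angles and their labels.

```python
import numpy as np

ANGLES = (0, 90, 180, 270)  # rotation labels 0, 1, 2, 3

def rotate(img, label):
    # Rotate an (H, W, C) image counter-clockwise by 90 * label degrees.
    return np.rot90(img, k=label, axes=(0, 1))

def make_rotation_batch(img):
    # All K = 4 transformed copies g(x | y) of a square image, with labels y.
    xs = np.stack([rotate(img, y) for y in range(len(ANGLES))])
    ys = np.arange(len(ANGLES))
    return xs, ys

def ss_loss(rot_logits, ys):
    # Eq. (4): mean negative log-probability of the correct rotation label,
    # computed with a numerically stable log-softmax over the 4 rotation classes
    z = rot_logits - rot_logits.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(ys)), ys].mean()

xs, ys = make_rotation_batch(np.zeros((32, 32, 3)))
print(xs.shape)  # (4, 32, 32, 3)
```

During training, Sect. 3.1 applies one randomly drawn rotation per input, which corresponds to sampling a single row of this batch rather than using all four.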
Fig. 3. The network architecture of our model. The input image and the rotated image are used as the inputs of the classification branch and the self-supervised branch, respectively. (Diagram labels: input image and randomly rotated image pass through the shared classification network and global average pooling to embeddings; classification branch: distances to the reciprocal points as logits, softmax, "punish" if logits < model.R; self-supervision branch: 1x1 conv, softmax, "punish" if probs != gt; final probs.)
Training. We train the model on the closed dataset (known-class data) with the following total loss:

$$L(x, \theta, \mathcal{P}, R) = \lambda_{cls} L_{cls}(x; \theta, \mathcal{P}) + \lambda_{ss} L_{ss}(x_i, \theta) + \alpha L_o(x; \theta, \mathcal{P}, R) \tag{10}$$
where $\lambda_{cls}$, $\lambda_{ss}$ and $\alpha$ are hyperparameters that weight the classification loss, the self-supervised loss and the open loss, respectively, and $\theta$, $\mathcal{P}$ and $R$ denote the learnable parameters of the classifier, the reciprocal points of all classes and the dynamic boundaries of all classes, respectively. We feed $x$ and $x_r$ to the classification branch and the self-supervised branch, respectively, and obtain the outputs $out_{cls}$ and $out_{ss}$. From the distances between $x$ and all reciprocal points we compute the classification loss $L_{cls}$ and the open loss $L_o$; at the same time, the self-supervised branch uses $r$ as the ground-truth label to compute the self-supervised loss $L_{ss}$, which yields the total loss $L$. The composite loss $L$ is backpropagated to obtain the gradients of all network weights, and network C is then updated by the optimizer.

Inference. To better avoid misjudging unknown classes without hurting the recognition of known classes, we add the post-processing method of Sect. 3.3 to penalize potential unknown samples. The test procedure is shown in Algorithm 1. Again, $x$ and $x_r$ are fed to the classification branch and the self-supervised branch, respectively, yielding the embedding representations $out_{cls}$ and $out_{ss}$.

Algorithm 1: Inference Algorithm
Input: test sample x, transformation set T, transformation label r, model C.
Output: predicted label and score.
  x_r = T(x)                                          ▷ rotate the test sample; its label is r
  z = (x, x_r)
  out_cls, out_ss = C(z)                              ▷ outputs of the two branches
  logits = d(out_cls, P)                              ▷ distances to the reciprocal points
  probs = softmax(logits); pred, cls = max(probs)     ▷ classification branch
  probs_rot, cls_rot = max(softmax(out_ss))           ▷ self-supervised branch
  if logits[cls] ≤ R[cls] and cls_rot ≠ r: probs = probs ** 4    ▷ both penalties
  elif logits[cls] ≤ R[cls]: probs = probs ** 2                  ▷ boundary penalty
  elif cls_rot ≠ r: probs = probs ** 2                           ▷ self-supervised penalty
  output predicted label cls and probability score probs

The classification branch predicts the class cls and probability score pred of $x$ from the distances (logits) between $x$ and the reciprocal points, and the self-supervised branch predicts the geometric transformation label cls_rot. We then apply the post-processing mechanism of Sect. 3.3 to decide whether to penalize the prediction, with the three penalty levels shown in Algorithm 1. The post-processing mechanism markedly reduces the chance of predicting potential unknown samples as known classes with high confidence and increases the reliability of open set recognition.
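The post-processing part of Algorithm 1 can be rendered in NumPy for a single test sample, given the outputs of the two branches; the model C itself is not reproduced, and the function and argument names are hypothetical. Because the whole probability vector is raised to a power, the penalized scores no longer sum to 1: they act as confidence scores, while the predicted label cls is unchanged.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def infer(dists, R, rot_logits, r):
    """Post-processing of Algorithm 1 for one test sample.

    dists:      (N,) distances from out_cls to the N reciprocal points ("logits")
    R:          (N,) learned dynamic boundaries R^k
    rot_logits: (4,) self-supervised branch outputs for x_r
    r:          rotation label that was actually applied to x
    """
    probs = softmax(dists)
    cls = int(np.argmax(dists))                    # predicted known class
    cls_rot = int(np.argmax(softmax(rot_logits)))  # predicted rotation label
    inside = dists[cls] <= R[cls]                  # x lies within the boundary of P^cls
    wrong_rot = cls_rot != r                       # rotation not recognized
    if inside and wrong_rot:
        probs = probs ** 4                         # both criteria flag x
    elif inside or wrong_rot:
        probs = probs ** 2                         # exactly one criterion flags x
    return cls, probs

d = np.array([2.0, 0.0])
rot_ok = np.array([5.0, 0.0, 0.0, 0.0])            # self-supervised head predicts label 0
cls, p = infer(d, np.array([3.0, 3.0]), rot_ok, r=1)
```

In the final call both criteria fire (the distance 2.0 is within the boundary 3.0, and the predicted rotation 0 differs from r = 1), so `p` equals `softmax(d) ** 4`.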
4 Experiments

4.1 Implementation Details
The network architecture is shown in Fig. 3. The learnable reciprocal points $P$ are initialized randomly from a normal distribution, and the dynamic boundaries $R$ are initialized to 0. The classifier is trained with the Adam optimizer for a total of 250 epochs; the learning rate starts at 0.01 and is multiplied by 0.1 at epochs 100, 150, 180 and 210. Following [2], random center cropping and random horizontal flipping are used for data augmentation on all datasets except MNIST. TinyImageNet images are resized to 64 × 64, and images of the other datasets are 32 × 32.
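The step learning-rate schedule above can be written as a small function. This is a sketch of the stated settings only; in PyTorch the same schedule would typically be expressed with `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150, 180, 210], gamma=0.1)`, although which scheduler the authors actually used is not stated.

```python
MILESTONES = (100, 150, 180, 210)  # epochs at which the LR is multiplied by 0.1
BASE_LR = 0.01
TOTAL_EPOCHS = 250

def lr_at(epoch):
    # Step decay: one factor of 0.1 for every milestone already reached.
    drops = sum(epoch >= m for m in MILESTONES)
    return BASE_LR * 0.1 ** drops

schedule = [lr_at(e) for e in range(TOTAL_EPOCHS)]
```

By the last milestone the rate has fallen four orders of magnitude, to about 1e-6, for the final 40 epochs.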
Regarding hyperparameters, $\alpha$, the weight of the open loss, is set to 0.1. $\lambda_{cls}$ and $\lambda_{ss}$, the weights of the classification loss and the self-supervised loss, depend on the number of classes of the two branches. Let $n_{cls}$ be the number of known classes of the classification branch and $n_{ss}$ the number of geometric-transformation classes of the self-supervised branch; then $\lambda_{cls} = \frac{n_{cls}}{n_{cls} + n_{ss}}$ and $\lambda_{ss} = \frac{n_{ss}}{n_{cls} + n_{ss}}$. Because our self-supervised branch uses only four rotation angles, it predicts four classes, i.e., $n_{ss} = 4$.

CIFAR+10, CIFAR+50. For the CIFAR+N experiments, we randomly select 4 classes from CIFAR10 as known classes and N classes from CIFAR100 as unknown classes.

TinyImageNet. TinyImageNet is a 200-class subset of the ImageNet dataset. In the TinyImageNet experiments, 20 classes are used as known and the remaining 180 classes as unknown.

Metrics. One of the most commonly used evaluation metrics in open set recognition is the Area Under the Receiver Operating Characteristic curve (AUROC), which we use to evaluate our method. In addition, like [2], we adopt the Open Set Classification Rate (OSCR), which jointly reflects the separation of known and unknown classes and the accuracy of known-class recognition. For both metrics, higher is better.

Table 1. AUROC results for detecting known and unknown samples, averaged over 5 different splits into known and unknown classes.

Method | MNIST | SVHN | CIFAR10 | CIFAR+10 | CIFAR+50 | TinyImageNet
Softmax | 97.8 ± 0.2 | 88.6 ± 0.6 | 67.7 ± 3.2 | 81.6 | 80.5 | 57.7
Openmax [1] | 98.1 ± 0.2 | 89.4 ± 0.8 | 69.5 ± 3.2 | 81.7 | 79.6 | 57.6
G-OpenMax [4] | 98.4 ± 0.1 | 89.6 ± 0.6 | 67.5 ± 3.5 | 82.7 | 81.9 | 58.0
OSRCI [9] | 98.8 ± 0.1 | 91.0 ± 0.6 | 69.9 ± 2.9 | 83.8 | 82.7 | 58.6
CROSR [15] | 99.1 | 89.9 | 88.3 | 91.2 | 90.5 | 58.9
C2AE [10] | 98.9 ± 0.2 | 92.2 ± 0.9 | 89.5 ± 0.8 | 95.5 ± 0.6 | 93.7 ± 0.4 | 74.8 ± 0.5
RPL [3] | 98.9 ± 0.1 | 93.4 ± 0.5 | 82.7 ± 1.4 | 84.2 ± 1.0 | 83.2 ± 0.7 | 68.8 ± 1.4
GFROSR [11] | – | 95.5 ± 1.8 | 83.1 ± 3.9 | 91.5 ± 0.2 | 91.3 ± 0.2 | 64.7 ± 1.2
PROSER [16] | – | 94.3 | 89.1 | 96.0 | 95.3 | 69.3
CVAECapOSR [5] | 99.2 ± 0.4 | 95.6 ± 1.2 | 83.5 ± 2.3 | 88.8 ± 1.9 | 88.9 ± 1.7 | 71.5 ± 1.8
CGDL [13] | 99.4 ± 0.2 | 93.5 ± 0.3 | 90.3 ± 0.9 | 95.9 ± 0.6 | 95.0 ± 0.6 | 76.2 ± 0.5
ARPL+CS [2] | 99.7 ± 0.1 | 96.7 ± 0.2 | 91.0 ± 0.7 | 97.1 ± 0.3 | 95.1 ± 0.2 | 78.2 ± 1.3
ours | 99.4 ± 0.2 | 97.7 ± 0.5 | 92.7 ± 0.4 | 97.3 ± 0.6 | 96.2 ± 0.4 | 79.1 ± 2.1
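The loss weights defined in the hyperparameter paragraph above, combined into the total loss of Eq. (10), can be sketched as follows. The individual loss values are placeholders and the function names are hypothetical; only $n_{ss} = 4$, $\alpha = 0.1$ and the weighting formula come from the text.

```python
def loss_weights(n_cls, n_ss=4):
    # lambda_cls = n_cls / (n_cls + n_ss), lambda_ss = n_ss / (n_cls + n_ss)
    total = n_cls + n_ss
    return n_cls / total, n_ss / total

def total_loss(l_cls, l_ss, l_o, n_cls, n_ss=4, alpha=0.1):
    # Eq. (10): weighted sum of classification, self-supervised and open losses
    lam_cls, lam_ss = loss_weights(n_cls, n_ss)
    return lam_cls * l_cls + lam_ss * l_ss + alpha * l_o

cifar_plus = loss_weights(4)      # CIFAR+N: 4 known classes
tiny = loss_weights(20)           # TinyImageNet: 20 known classes
```

With 4 known classes (CIFAR+N) the two weights are equal at 0.5 each; with 20 known classes (TinyImageNet) they become 20/24 ≈ 0.83 and 4/24 ≈ 0.17, so the classification term dominates as the number of known classes grows.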
Experiment Results. Table 1 reports the AUROC of different methods on six datasets. Our method achieves state-of-the-art results on five datasets, with an average improvement of 0.8% over the previous state-of-the-art method [2]. In particular, apart from MNIST and CIFAR+10, where performance is already close to saturation, our method improves by 0.9–1.7% on the remaining four, more challenging datasets, demonstrating its advantage on challenging data. On MNIST, where performance is close to saturation, performance drops slightly due to the complexity of the model, but our method still outperforms most methods on this dataset. We also report results under the OSCR metric. As shown in Table 2, our method achieves state-of-the-art results on four datasets. Again, apart from MNIST and CIFAR+10, our method improves on the other four challenging datasets, by 0.8% on average, which further demonstrates the effectiveness of semantic information and the penalty mechanism for open set recognition.

Table 2. Open set classification rate (OSCR) results of open set recognition, averaged over 5 random trials.

Method | MNIST | SVHN | CIFAR10 | CIFAR+10 | CIFAR+50 | TinyImageNet
Softmax | 99.2 ± 0.1 | 92.8 ± 0.4 | 83.8 ± 1.5 | 90.9 ± 1.3 | 88.5 ± 0.7 | 60.8 ± 5.1
GCPL [14] | 99.1 ± 0.2 | 93.4 ± 0.6 | 84.3 ± 1.7 | 91.0 ± 1.7 | 88.3 ± 1.1 | 59.3 ± 5.3
RPL [3] | 99.4 ± 0.1 | 93.6 ± 0.5 | 85.2 ± 1.4 | 91.8 ± 1.2 | 89.6 ± 0.9 | 53.2 ± 4.6
ARPL [2] | 99.4 ± 0.1 | 94.0 ± 0.6 | 86.6 ± 1.4 | 93.5 ± 0.8 | 91.6 ± 0.4 | 62.3 ± 3.3
ARPL+CS [2] | 99.5 ± 0.1 | 94.3 ± 0.3 | 87.9 ± 1.5 | 94.7 ± 0.7 | 92.9 ± 0.3 | 65.9 ± 3.8
ours | 99.2 ± 0.1 | 94.9 ± 0.9 | 89.0 ± 1.0 | 94.4 ± 0.8 | 93.5 ± 0.7 | 66.8 ± 4.3

4.2 Open Set Recognition

Table 3. Open set recognition results on CIFAR10, evaluated by macro-averaged F1-scores over 11 classes (10 known classes and 1 unknown class).

Method | ImageNet-crop | ImageNet-resize | LSUN-crop | LSUN-resize
Softmax | 63.9 | 65.3 | 64.2 | 64.7
Openmax [1] | 66.0 | 68.4 | 65.7 | 66.8
CROSR [15] | 72.1 | 73.5 | 72.0 | 74.9
GFROSR [11] | 75.7 | 79.2 | 75.1 | 80.5
C2AE [10] | 83.7 | 82.6 | 78.3 | 80.1
RPL [3] | 81.1 | 81.0 | 84.6 | 82.0
CGDL [13] | 84.0 | 83.2 | 80.6 | 81.2
PROSER [16] | 84.9 | 82.4 | 86.7 | 85.6
CVAECapOSR [5] | 85.7 | 83.4 | 86.8 | 88.2
ours | 88.0 | 87.4 | 88.7 | 87.0
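The macro-averaged F1-score used in Table 3 treats "unknown" as an additional, eleventh class. A minimal NumPy sketch follows. How the penalized scores are turned into an "unknown" decision is not specified in this section, so the thresholding rule and its `threshold` parameter are hypothetical, as are the function names.

```python
import numpy as np

def predict_open_set(cls, score, threshold, num_known=10):
    # HYPOTHETICAL rule: reject as 'unknown' (label num_known) when the
    # penalized max score falls below a chosen threshold
    return np.where(score < threshold, num_known, cls)

def macro_f1(y_true, y_pred, num_known=10):
    """Macro-averaged F1 over num_known known classes plus one 'unknown' class.

    Labels 0..num_known-1 are known classes; label num_known means 'unknown'.
    """
    f1s = []
    for c in range(num_known + 1):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom > 0 else 0.0)  # F1 = 2TP / (2TP + FP + FN)
    return float(np.mean(f1s))

# Toy example with 2 known classes (labels 0, 1) and 'unknown' = 2.
y_true = np.array([0, 1, 2, 2])
y_pred = predict_open_set(np.array([0, 1, 0, 1]), np.array([0.9, 0.3, 0.2, 0.1]),
                          threshold=0.5, num_known=2)
```

Here the known sample of class 1 is wrongly rejected, so class 1 gets F1 = 0 and the unknown class gets 0.8, giving a macro-F1 of 0.6. Classes that never occur in either labels or predictions are scored 0 in this sketch, which is one of several possible conventions.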
In this section, we validate our method on open set recognition tasks. Following the protocol of [11], we use all samples of the 10 CIFAR10 classes as known data and select samples from ImageNet and LSUN as unknown samples. Following [7], we resize or crop the unknown samples to match the image size of the known samples, obtaining four datasets: ImageNet-crop, ImageNet-resize, LSUN-crop and LSUN-resize. Performance is measured by the macro-averaged F1-score over the 10 known classes and 1 unknown class; the results are shown in Table 3. Our method handles open set classes from different sources and outperforms the previous best method on three of the four datasets, improving the F1-score by 1.75% on average over the four datasets. ImageNet-resize is the most challenging of the four, since previous methods perform poorly on it; on this dataset we improve the F1-score by 4.0%. This suggests that accurate separation of the known and unknown spaces is important on challenging datasets, which our open loss achieves, while our penalty mechanism effectively penalizes unknown samples and compensates for the uncertainty in separating the two spaces.

Table 4. Ablation study on the model architecture, evaluated by AUROC.

Method | MNIST | SVHN | CIFAR10 | CIFAR+10 | CIFAR+50 | TinyImageNet
RPL | 99.3 ± 0.1 | 94.9 ± 0.8 | 87.6 ± 1.2 | 94.0 ± 0.8 | 91.3 ± 0.4 | 75.7 ± 2.0
RPL+L | 99.3 ± 0.1 | 95.0 ± 0.7 | 87.9 ± 1.1 | 94.2 ± 0.6 | 91.5 ± 0.5 | 75.8 ± 2.7
RPL+L+CP | 99.3 ± 0.1 | 95.1 ± 0.7 | 88.1 ± 1.0 | 94.4 ± 0.6 | 91.6 ± 0.4 | 76.5 ± 2.4
RPL+L+S | 99.1 ± 0.2 | 95.0 ± 0.9 | 88.9 ± 1.2 | 94.8 ± 0.7 | 93.1 ± 0.7 | 76.7 ± 2.5
RPL+L+S+CP | 99.1 ± 0.2 | 95.1 ± 0.9 | 89.3 ± 1.1 | 95.1 ± 0.7 | 93.4 ± 0.7 | 77.5 ± 2.3
RPL+L+S+SP | 99.2 ± 0.2 | 96.6 ± 0.7 | 90.6 ± 0.9 | 96.1 ± 0.8 | 94.7 ± 0.5 | 77.7 ± 2.7
ours (RPL+L+S+SP+CP) | 99.4 ± 0.2 | 97.7 ± 0.5 | 92.7 ± 0.4 | 97.3 ± 0.6 | 96.2 ± 0.4 | 79.1 ± 2.1

4.3 Ablation Study
To verify the contribution of each part of our proposed method, we do corresponding ablation experiments on the main contributions of this method. Our method uses RPL [3] as the baseline. The difference is that, as described in Sect. 2, like [2], we combined with Euclidean distance and dot product to calculate the distance. Throughout the ablation experiments, our classifier was the same as [9]. Our ablation experiments are mainly divided into seven parts: RPL [3] method as our baseline; L stands for the use of redesigned open loss; S stands for the addition of self-supervised branch; CP and SP respectively represent penalty based on dynamic boundary of classification branch and penalty for self-supervised branch. The contribution of the main work in this paper is verified by combining different parts. Finally there is the effect of our final method, the baseline plus the redesigned open loss, self-supervised branch and penalty mechanism, the experimental results are shown in Table 4. The experimental results are evaluated by AUROC. It can be seen that the redesigned
A Joint Learning Model for Open Set Recognition with Post-processing
open loss, the proposed penalty mechanism, and the introduced self-supervised branch each contribute directly to the improvement of the model. In particular, adding the self-supervised branch improves the baseline with L by an average of 0.7%, which indicates that semantic information is very helpful for classification tasks, and especially for open set recognition. The redesigned open loss gives an average improvement of 0.15%. Although small, this shows that it separates the known and unknown spaces better than RPL [3]. Our penalty mechanism also improves the baseline to some extent, with a gain of nearly 1% on TinyImageNet, the most difficult dataset and the one on which the baseline performs worst. Finally, our method performs best when the redesigned open loss, the self-supervised branch, and the penalty mechanism are combined. The average improvement is 3.2% per dataset, which is a large gain for the open set recognition problem and demonstrates the effectiveness of each of these components.
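The average gains quoted above can be re-derived from the mean AUROC values in Table 4. The Python sketch below uses only the standard library and ignores the ± standard deviations. Note that `mean_gain` is a hypothetical helper and not part of the authors' code. It averages the per-dataset differences between two rows of the table.

```python
# Mean AUROC values from Table 4 (the +/- standard deviations are ignored).
# Column order: MNIST, SVHN, CIFAR10, CIFAR+10, CIFAR+50, TinyImageNet
TABLE4 = {
    "RPL":        [99.3, 94.9, 87.6, 94.0, 91.3, 75.7],
    "RPL+L":      [99.3, 95.0, 87.9, 94.2, 91.5, 75.8],
    "RPL+L+CP":   [99.3, 95.1, 88.1, 94.4, 91.6, 76.5],
    "RPL+L+S":    [99.1, 95.0, 88.9, 94.8, 93.1, 76.7],
    "RPL+L+S+CP": [99.1, 95.1, 89.3, 95.1, 93.4, 77.5],
    "RPL+L+S+SP": [99.2, 96.6, 90.6, 96.1, 94.7, 77.7],
    "Ours":       [99.4, 97.7, 92.7, 97.3, 96.2, 79.1],
}


def mean_gain(variant, reference):
    """Average per-dataset AUROC difference between two rows of Table 4."""
    diffs = [a - b for a, b in zip(TABLE4[variant], TABLE4[reference])]
    return sum(diffs) / len(diffs)


for variant, reference in [("RPL+L", "RPL"), ("RPL+L+S", "RPL+L"), ("Ours", "RPL")]:
    print(f"{variant} vs {reference}: {mean_gain(variant, reference):+.2f}")
```

The three comparisons give approximately 0.15, 0.65, and 3.27 points. The text reports these as 0.15%, 0.7%, and 3.2%, respectively.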
5 Conclusion
This paper proposes an open set recognition network based on Reciprocal Points that addresses the open set recognition problem from three aspects. First, a self-supervised branch is introduced to better learn the semantic information of the samples, which helps the network learn the distinctive features of known classes. Second, we redesign the open loss to better separate the known and unknown spaces and to learn the dynamic boundary between them. Finally, based on the premise that the knowledge learned by the network is useful, we propose a post-processing penalty mechanism that penalizes potential unknown-class samples. Extensive comparative and ablation experiments show that the proposed method is effective and achieves state-of-the-art performance on most datasets. In future work, a promising direction for open set recognition is to use the information learned by the network to make correct decisions at the inference stage.
References
1. Bendale, A., Boult, T.E.: Towards open set deep networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1563–1572 (2016)
2. Chen, G., Peng, P., Wang, X., Tian, Y.: Adversarial reciprocal points learning for open set recognition. IEEE Trans. Pattern Anal. Mach. Intell. (2021). https://doi.org/10.1109/TPAMI.2021.3106743
3. Chen, G., et al.: Learning open set network with discriminative reciprocal points. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12348, pp. 507–522. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58580-8_30
Q. Li et al.
4. Ge, Z., Demyanov, S., Chen, Z., Garnavi, R.: Generative openmax for multi-class open set classification. In: British Machine Vision Conference 2017. British Machine Vision Association and Society for Pattern Recognition (2017)
5. Guo, Y., Camporese, G., Yang, W., Sperduti, A., Ballan, L.: Conditional variational capsule network for open set recognition. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 103–111 (2021)
6. Komodakis, N., Gidaris, S.: Unsupervised representation learning by predicting image rotations. In: International Conference on Learning Representations (ICLR) (2018)
7. Liang, S., Li, Y., Srikant, R.: Enhancing the reliability of out-of-distribution image detection in neural networks. arXiv preprint arXiv:1706.02690 (2017)
8. Nalisnick, E., Matsukawa, A., Teh, Y.W., Gorur, D., Lakshminarayanan, B.: Do deep generative models know what they don't know? arXiv preprint arXiv:1810.09136 (2018)
9. Neal, L., Olson, M., Fern, X., Wong, W.K., Li, F.: Open set learning with counterfactual images. In: Proceedings of the European Conference on Computer Vision (ECCV), pp. 613–628 (2018)
10. Oza, P., Patel, V.M.: C2ae: class conditioned auto-encoder for open-set recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2307–2316 (2019)
11. Perera, P., et al.: Generative-discriminative feature representations for open-set recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11814–11823 (2020)
12. Scheirer, W.J., de Rezende Rocha, A., Sapkota, A., Boult, T.E.: Toward open set recognition. IEEE Trans. Pattern Anal. Mach. Intell. 35(7), 1757–1772 (2012)
13. Sun, X., Yang, Z., Zhang, C., Ling, K.V., Peng, G.: Conditional gaussian distribution learning for open set recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13480–13489 (2020)
14. Yang, H.M., Zhang, X.Y., Yin, F., Liu, C.L.: Robust classification with convolutional prototype learning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3474–3482 (2018)
15. Yoshihashi, R., Shao, W., Kawakami, R., You, S., Iida, M., Naemura, T.: Classification-reconstruction learning for open-set recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4016–4025 (2019)
16. Zhou, D.W., Ye, H.J., Zhan, D.C.: Learning placeholders for open-set recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4401–4410 (2021)
Cross-Layer Fusion for Feature Distillation

Honglin Zhu¹, Ning Jiang¹ (B), Jialiang Tang¹, Xinlei Huang¹, Haifeng Qing¹, Wenqing Wu², and Peng Zhang²

¹ School of Computer Science and Technology, Southwest University of Science and Technology, Mianyang, Sichuan 621000, China
[email protected]
² School of Mathematics and Physics, Southwest University of Science and Technology, Mianyang, Sichuan 621000, China

Abstract. Knowledge distillation is a model compression technique that can effectively improve the performance of a small student network by transferring knowledge from a large pre-trained teacher network. In most previous feature distillation work, the student still underperforms the teacher because it is supervised only by the teacher's features and the labels. In this paper, we propose Cross-layer Fusion for Knowledge Distillation (CFKD). Instead of using only the features of the teacher network, we aggregate the features of the teacher and student networks through a dynamic feature fusion strategy (DFFS) and a fusion module. The fused features are informative. They contain not only the expressive knowledge of the teacher network but also useful knowledge already learned by the student network. As a result, a student network that learns from the fused features can achieve performance comparable to that of the teacher network. Our experiments demonstrate that a student network trained with our method comes closer to, or even exceeds, the performance of its teacher.

Keywords: Model Compression · Knowledge Distillation · Cross-Layer Fusion

1 Introduction
With the development of deep learning, neural networks have achieved satisfactory performance in various computer vision tasks [4,10,27]. However, high-performing networks generally require a huge number of parameters and computations. In practice, such cumbersome networks are difficult to deploy on resource-limited devices such as mobile phones, which hinders the practical application of neural networks. To address this problem, a variety of methods have been proposed, including network pruning [16,17], quantization [12], lightweight network design [9,18,22], and knowledge distillation [6,21,26,29]. Among these methods, we focus on knowledge distillation (KD), which transfers knowledge from a large teacher network to a compact student network. KD aims to train a small student network whose performance is similar to that of the large teacher network, so that the student can be deployed on resource-limited devices in practical applications.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 433–445, 2023.
https://doi.org/10.1007/978-981-99-1639-9_36

H. Zhu et al.

The original distillation approach, introduced by Hinton et al. [6], makes the student learn the soft logit outputs of the teacher network after its final linear layer. To improve the efficiency of knowledge transfer, Romero et al. [21] proposed a feature distillation framework that uses the teacher's intermediate features together with the ground-truth labels to supervise student training. Several works [3,20,25,26] have further improved feature distillation. In general, features from a particular layer of the teacher are transferred to the student. After filtering, however, part of the useful information extracted by that layer of the teacher may be discarded, so the improvement obtained by transferring the teacher's knowledge alone is limited. Several studies [7,14,15] have found that aggregating the feature maps of two parallel networks tends to yield better predictions than using the feature maps of a single network. This improvement is attributed to the fused features being rich and discriminative [7]. In this study, we propose cross-layer fusion for knowledge distillation (CFKD). Instead of using only teacher features, our method supervises the student with expressive knowledge that combines the features of the teacher and student networks, as shown in Fig. 1. We introduce fused features for two reasons. First, the thin, small student network can provide shallow texture information for the fused features. Second, a student can effectively improve its performance through self-distillation of its deeper-stage features [8,13,28]. We therefore expect that fused features aggregating the student's deeper-stage features will benefit student training.
Furthermore, we explore a dynamic feature fusion strategy that produces valuable fused features by controlling the contribution of the student's features to them. For some teacher/student pairs, such as ResNet50/ResNet20 and WRN-40-2/WRN-16-2, the student outperforms the teacher in our experiments, which validates the effectiveness of our method. The proposed distillation method is detailed in Sect. 3.3, and the fusion module and dynamic feature fusion strategy are described in Sect. 3.4. Extensive experiments in Sect. 4 validate the effectiveness of our method.
2 Related Work

In this section, we review related work. Knowledge distillation and feature distillation are discussed in Sect. 2.1 and Sect. 2.2, respectively, and feature fusion methods are discussed in Sect. 2.3.

2.1 Knowledge Distillation
Reducing model parameters and speeding up network inference are the main goals of model compression. Among model compression methods, knowledge distillation is simple and convenient. It improves the performance of the student network by having it learn the output of a well-trained teacher network. This idea was first introduced by Hinton et al. [6]. In their approach, the student network is supervised not only by ground-truth labels but also by the teacher's predicted probabilities, which are called soft targets. In [2,29], the authors explore online distillation, which improves network performance by training multiple student networks and encouraging each network to learn soft targets from the others.

2.2 Feature-Map Distillation
The goal of feature distillation is to make the student learn the teacher's features. Romero et al. [21] first introduced intermediate-layer feature distillation. In their approach, the teacher network transfers its intermediate-layer features to the student network as knowledge, which further improves the student's predictive ability. Zagoruyko et al. [26] introduced an attention mechanism to extract expressive information from the teacher's intermediate layers for knowledge distillation. Heo et al. [5] proposed a margin ReLU function and moved the feature distillation position to before the ReLU. They also used a partial L2 loss to reduce the transfer of useless information. To the best of our knowledge, previous work has not considered the role of the student's deeper features in the teacher-student learning paradigm. In this study, we propose combining the student's deeper features with the teacher's features as the knowledge for student learning.

2.3 Feature Fusion Method
Feature fusion methods combine different features through fusion operations. Lin et al. [15] used a matrix product to aggregate two feature maps produced by two parallel convolutional networks for image classification, arguing that the fused features capture higher-order local features. Hou et al. [7] fused features with a 'SUM' operation, arguing that the richer and more accurate image representations produced by feature fusion benefit recognition. Kim et al. [14] applied fused features to online distillation to boost each untrained sub-network. Shen et al. [23] aggregated features from multiple teachers through an amalgamation module to guide the learning of the student network.
3 Method

This section introduces cross-layer fusion knowledge distillation (CFKD). The notation is given in Sect. 3.1, and Sect. 3.2 briefly reviews logit-based distillation. Figure 1 gives an overview of our distillation method, whose details are described in Sect. 3.3. Section 3.4 discusses the fusion module and the dynamic feature fusion strategy in detail.
[Fig. 1 diagram: a teacher (features $f_1^t$–$f_4^t$, FC, softmax prediction) and a student (features $f_1^s$–$f_4^s$, FC, softmax prediction) linked by fusion modules; losses $L_{CFKD}$, $L_{KL}$, and $L_{ce}$ (with the label); top layer and bottom layer marked.]
Fig. 1. The overall framework of cross-layer fusion for feature distillation. The process of distillation consists of two stages. In the first stage, the fusion module aggregates the features from the teacher and the student to generate fused features. In the second stage, rich fused features are learned by the student to improve performance.
3.1 Notations

Given a set of input data $X = \{x_1, x_2, \ldots, x_k\}$ containing $k$ examples, the corresponding labels are denoted by $Y = \{y_1, y_2, \ldots, y_k\}$. We denote a pre-trained, fixed teacher network by $N_T$ and a student network by $N_S$. Let $g^s$ and $g^t$ denote the logits of the student and teacher networks, respectively. Throughout this paper, superscripts $s$ and $t$ indicate quantities output by the student and teacher networks, respectively. $f_j^s$ denotes the features of the $j$-th layer of the student network, $j \in \{1, \ldots, n\}$, and $f_i^t$ denotes the features of the $i$-th layer of the teacher network, $i \in \{1, \ldots, n\}$, where $n$ is the maximum number of blocks in the network.

3.2 Logit-Based Knowledge Distillation
The purpose of knowledge distillation is to train a student network to imitate the output of a teacher network. In the original knowledge distillation [6], the student network is expected to imitate the teacher's soft targets. For image classification tasks, the traditional distillation loss combines the cross-entropy loss and the Kullback–Leibler (KL) divergence. The loss function $L_{logit}$ can be expressed as:

$$L_{logit} = (1 - \lambda)\, L_{ce}\big(y, \sigma(g^s)\big) + \lambda \tau^2 L_{KL}\big(\sigma(g^t/\tau), \sigma(g^s/\tau)\big), \tag{1}$$

where $L_{ce}(\cdot,\cdot)$ denotes the cross-entropy loss, $L_{KL}(\cdot,\cdot)$ denotes the KL divergence, and $\sigma(\cdot)$ is the softmax function. $\tau > 0$ is a temperature parameter that controls how smooth the distribution over categories is. $L_{KL}$ is multiplied by $\tau^2$ because the magnitudes of the gradients produced by the soft targets scale as $1/\tau^2$. The factor keeps their relative contribution roughly unchanged when $\tau$ varies. $\lambda$ is a balancing hyperparameter.
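To make Eq. (1) concrete, here is a minimal, dependency-free Python sketch for a single example. It rests on three assumptions. First, σ is the softmax. Second, $L_{KL}(p, q)$ means KL(p ‖ q), with the temperature-softened teacher distribution as the target. Third, the cross-entropy term uses the unsoftened student logits. The defaults `lam=0.9` and `tau=4.0` are illustrative placeholders, not the paper's settings.

```python
import math


def softmax(logits, tau=1.0):
    """Temperature-scaled softmax; subtracting the max keeps exp() from overflowing."""
    scaled = [z / tau for z in logits]
    top = max(scaled)
    exps = [math.exp(z - top) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; terms with p_k = 0 contribute nothing."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)


def logit_kd_loss(student_logits, teacher_logits, label, lam=0.9, tau=4.0):
    """Eq. (1) for one example: (1 - lam) * CE + lam * tau^2 * KL(teacher || student)."""
    ce = -math.log(softmax(student_logits)[label])
    kl = kl_divergence(softmax(teacher_logits, tau), softmax(student_logits, tau))
    return (1 - lam) * ce + lam * tau ** 2 * kl


print(logit_kd_loss([2.0, 0.5, -1.0], [3.0, 0.2, -2.0], label=0))
```

When the student's logits equal the teacher's, the KL term vanishes and only the weighted cross-entropy remains. In a real training loop, the same computation would be batched in a tensor library.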
3.3 Teacher-Student Feature Fusion
The key idea of our method is to enhance the performance of the student model by having it learn fused teacher-student features. As shown in Fig. 1, the overall distillation process is divided into two stages. In the first stage, the same data are fed into the student and teacher networks to obtain the features of each layer of the teacher network and the features of the bottom layer of the student network. The bottom-layer student features are then aggregated with the features of each teacher layer through a fusion module. In the second stage, the fused features are treated as rich knowledge and transferred to different stages of the student network. Specifically, we obtain the fused features $f_i$ by aggregating the features $f_i^t$ of the $i$-th teacher layer with the student features $f_n^s$. The fused features $f_i$ of a single layer are defined as:

$$f_i = \mathcal{F}\big(W \cdot f_i^t + b,\; W \cdot f_n^s + b\big), \tag{2}$$
where $\mathcal{F}(\cdot,\cdot)$ is the fusion operation described in detail in Sect. 3.4, $W(\cdot)$ is a linear function for matrix transformation, and $b$ is the bias matrix. After obtaining the fused features, we compute the CFKD distillation loss with the mean squared error (MSE). The single-layer distillation loss $L_{CFKD}$ can be written as:

$$L_{CFKD} = \mathrm{MSE}\big(f_i, R(f_{n-1}^s)\big), \tag{3}$$

where $R(\cdot)$ is a regression function, implemented as in [21], that matches the dimensions of the fused features. In multi-layer knowledge distillation, the teacher's knowledge is transferred to multiple layers of the student network, which makes knowledge transfer more efficient. We therefore generalize our method to multi-layer knowledge distillation, in which each layer of the student network learns the corresponding fused features. The multi-layer distillation loss $L_{CFKD}$ is computed as:

$$L_{CFKD} = \sum_{j=1,\; i=j+1}^{n-1} \mathrm{MSE}\big(f_i, R(f_j^s)\big). \tag{4}$$
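The index pairing in Eq. (4) matches student layer $j$ with fused feature $f_{j+1}$, and the sketch below illustrates it. Features are flattened Python lists. `identity_regressor` is a hypothetical stand-in for the learned regressor $R(\cdot)$ of [21]. In a real implementation, both the features and $R$ would be tensors and trainable layers.

```python
def mse(a, b):
    """Mean squared error between two flattened feature vectors of equal length."""
    if len(a) != len(b):
        raise ValueError("R(.) must map student features to the fused-feature shape")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def cfkd_loss(fused, student, regress):
    """Multi-layer loss of Eq. (4): sum over j = 1..n-1 of MSE(f_{j+1}, R(f_j^s)).

    fused[i - 1] holds f_i and student[j - 1] holds f_j^s (1-based indices as in
    the text); fused[0] (f_1) and student[n - 1] (f_n^s) are not used by Eq. (4).
    regress(j, x) is a hypothetical stand-in for the regressor R at layer j.
    """
    n = len(student)
    total = 0.0
    for j in range(1, n):       # j = 1, ..., n-1
        i = j + 1               # student layer j is matched with fused feature f_{j+1}
        total += mse(fused[i - 1], regress(j, student[j - 1]))
    return total


def identity_regressor(j, x):
    """Placeholder R that assumes the shapes already match."""
    return x


fused = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]     # f_1, f_2, f_3
student = [[0.0, 0.0], [2.0, 2.0], [9.0, 9.0]]   # f_1^s, f_2^s, f_3^s
print(cfkd_loss(fused, student, identity_regressor))  # prints 1.0
```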
In our method, the original cross-entropy loss and KL divergence are also included in the total loss function. We introduce hyperparameters $\alpha$, $\beta$, and $\gamma$ to balance the loss terms. The overall loss function $L_{total}$ is:

$$L_{total} = \alpha L_{ce} + \beta \tau^2 L_{KL} + \gamma L_{CFKD}. \tag{5}$$
3.4 Fusion Module and Dynamic Feature Fusion Strategy
The details of the fusion module are shown in Fig. 2b. First, the two feature maps $f_n^s$ and $f_i^t$ are input into the fusion module, where their channel dimensions and spatial sizes are compressed to a relatively small size. Second, the two compressed feature maps are aggregated through a 1×1 convolution into a feature map $f \in \mathbb{R}^{b \times 2 \times h \times w}$. Then, the aggregated features $f$ are multiplied by the compressed features to produce two feature maps $\hat{f}_n^s$ and $\hat{f}_i^t$. Finally, the fused features $f_i$ are obtained by adding $\hat{f}_n^s$ and $\hat{f}_i^t$. This structure is inspired by Chen et al. [3]. Unlike their design, we add an operation that compresses the feature maps in the fusion module.

[Fig. 2 diagram: (a) Dynamic Fusion Function S; (b) Fusion Module, taking $f_n^s$ and $f_i^t$, producing $\hat{f}_n^s$ and $\hat{f}_i^t$ and the fused feature $\tilde{f}$.]

Fig. 2. An overview of the proposed dynamic fusion function and fusion module. (a) The dynamic fusion function is a piecewise function whose value changes as training progresses. (b) Both features are compressed, and the compressed features are aggregated into fused features through 1×1 convolution, multiplication, and addition.

The purpose of the dynamic feature fusion strategy (DFFS) is to obtain high-quality fused features by controlling the contribution of the student features. Early in training, the student performs poorly, and its bottom-layer features may harm the fused features. Therefore, only the knowledge transferred from the teacher's features is used at the beginning. As the student's performance improves, its features become more beneficial to the fused features, and the features of the two models are increasingly combined. To this end, we introduce a dynamic fusion function $S$ that controls the effect of the student features on the fused features, as shown in Fig. 2a. During fusion, the student features are multiplied by $S$ and combined with the teacher features. When the student's performance approaches the teacher's, the features of the two models are fused in equal proportion. The dynamic fusion function $S$ can be written as:

$$S(\mathrm{EPOCH}) = \begin{cases} \dfrac{\mathrm{EPOCH}}{300} + 0.83334, & \ldots \\ 1, & \ldots \end{cases}$$
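The fusion steps described above can be sketched for a single sample, but this is an illustrative reading, not the authors' implementation. Several assumptions apply. The compression step is omitted, so the inputs are treated as already compressed to C × H × W nested lists. The 1×1 convolution is written as a per-pixel weighted sum with hypothetical weights `att_weights`. The two attention channels pass through an assumed sigmoid gate, since the section does not state the normalization. Finally, the student features are scaled by the current value `s` of the dynamic fusion function S before fusion.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def fuse(student_feat, teacher_feat, att_weights, s=1.0):
    """Fuse two already-compressed C x H x W feature maps (nested lists).

    att_weights: two rows of 2C weights plus a trailing bias, acting as a 1x1
        convolution from the 2C concatenated channels to 2 gate channels
        (hypothetical parameters; learned in practice).
    s: current value of the dynamic fusion function S (0 means teacher only).
    """
    C, H, W = len(student_feat), len(student_feat[0]), len(student_feat[0][0])
    fused = [[[0.0] * W for _ in range(H)] for _ in range(C)]
    for h in range(H):
        for w in range(W):
            x_s = [s * student_feat[c][h][w] for c in range(C)]  # student scaled by S
            x_t = [teacher_feat[c][h][w] for c in range(C)]
            concat = x_s + x_t  # channel-wise concatenation at this pixel
            # 1x1 conv -> one gate value per input map (sigmoid assumed)
            a_s, a_t = (
                sigmoid(sum(wk * v for wk, v in zip(row[:-1], concat)) + row[-1])
                for row in att_weights
            )
            for c in range(C):
                # multiply each map by its gate, then add the two results
                fused[c][h][w] = a_s * x_s[c] + a_t * x_t[c]
    return fused


student = [[[1.0, 2.0]], [[3.0, 4.0]]]   # C=2, H=1, W=2
teacher = [[[2.0, 2.0]], [[0.0, 8.0]]]
zero_gates = [[0.0] * 5, [0.0] * 5]      # 2C weights + 1 bias per gate row
print(fuse(student, teacher, zero_gates, s=0.0))  # prints [[[1.0, 1.0]], [[0.0, 4.0]]]
```

With `s=0.0`, only the gated teacher features survive, which matches the early-training behavior described for DFFS. Raising `s` toward 1 gradually lets the student features contribute.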
… δ > 0, with probability at least 1 − δ/2, the following holds:

$$\psi(S) \le \mathbb{E}_S[\psi(S)] + \sqrt{\frac{\log(2/\delta)}{2m}}.$$
M. Kimura
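We next bound the expectation of the right-hand side as follows:

$$\begin{aligned}
\mathbb{E}_S[\psi(S)] &= \mathbb{E}_S\Big[\sup_{g \in G}\Big(\mathbb{E}[g] - \frac{1}{m}\sum_{i=1}^m g(z_i)\Big)\Big] \\
&\le \mathbb{E}_{S,S'}\Big[\sup_{g \in G}\frac{1}{m}\sum_{i=1}^m \big(g(z'_i) - g(z_i)\big)\Big] \\
&= \mathbb{E}_{\sigma,S,S'}\Big[\sup_{g \in G}\frac{1}{m}\sum_{i=1}^m \sigma_i\big(g(z'_i) - g(z_i)\big)\Big] \\
&\le \mathbb{E}_{\sigma,S'}\Big[\sup_{g \in G}\frac{1}{m}\sum_{i=1}^m \sigma_i g(z'_i)\Big] + \mathbb{E}_{\sigma,S}\Big[\sup_{g \in G}\frac{1}{m}\sum_{i=1}^m -\sigma_i g(z_i)\Big] \\
&= 2\,\mathbb{E}_{\sigma,S}\Big[\sup_{g \in G}\frac{1}{m}\sum_{i=1}^m \sigma_i g(z_i)\Big] = 2R_m(G).
\end{aligned}$$

Applying McDiarmid's inequality again, with probability at least 1 − δ/2, the following holds:

$$R_m(G) \le \hat{R}_S(G) + \sqrt{\frac{\log(2/\delta)}{2m}}.$$

Finally, the union bound yields, with probability at least 1 − δ,

$$\psi(S) \le 2\hat{R}_S(G) + 3\sqrt{\frac{\log(2/\delta)}{2m}}. \tag{15}$$

Theorem 1 (Margin Bound for Set-to-Set Matching). Let $\mathcal{F}$ be a set of matching score functions. Fix ρ > 0. Then, for any δ > 0, with probability at least 1 − δ over the choice of a sample S of size m, each of the following holds for all $f \in \mathcal{F}$:

$$R(f) \le \hat{R}_\rho(f) + \frac{2}{\rho}\big(R_m^{p_1}(\mathcal{F}) + R_m^{p_2}(\mathcal{F})\big) + \sqrt{\frac{\log(1/\delta)}{2m}}, \tag{16}$$

$$R(f) \le \hat{R}_\rho(f) + \frac{2}{\rho}\big(\hat{R}_{S^1}(\mathcal{F}) + \hat{R}_{S^2}(\mathcal{F})\big) + 3\sqrt{\frac{\log(2/\delta)}{2m}}. \tag{17}$$

Proof. Let $\tilde{\mathcal{F}}$ be the family of functions mapping $(\mathcal{X}\times\mathcal{X})\times\{-1,+1\}$ to $\mathbb{R}$ defined by $\tilde{\mathcal{F}} = \{z = ((Z', Z), a) \mapsto a[f(Z') - f(Z)] \mid f \in \mathcal{F}\}$, where $a \in \{-1, +1\}$. Consider the family of functions $\tilde{\mathcal{F}}' = \{\varphi_\rho \circ g \mid g \in \tilde{\mathcal{F}}\}$ derived from $\tilde{\mathcal{F}}$, which take values in [0, 1]. By Lemma 1, for any δ > 0, with probability at least 1 − δ, for all $f \in \mathcal{F}$,

$$\mathbb{E}\big[\varphi_\rho(a[f(Z^+) - f(Z^-)])\big] \le \hat{R}_\rho(f) + 2R_m(\varphi_\rho \circ \tilde{\mathcal{F}}) + \sqrt{\frac{\log(1/\delta)}{2m}}. \tag{18}$$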
Generalization Bounds for Set-to-Set Matching with Negative Sampling
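Since $1_{u \le 0} \le \varphi_\rho(u)$ for all $u \in \mathbb{R}$, the generalization error $R(f)$ is a lower bound on the left-hand side: $R(f) = \mathbb{E}[1_{a[f(Z') - f(Z)] \le 0}] \le \mathbb{E}[\varphi_\rho(a[f(Z') - f(Z)])]$. We can therefore write

$$R(f) \le \hat{R}_\rho(f) + 2R_m(\varphi_\rho \circ \tilde{\mathcal{F}}) + \sqrt{\frac{\log(1/\delta)}{2m}}. \tag{19}$$

Using the $(1/\rho)$-Lipschitzness of $\varphi_\rho$, we can show that $R_m(\varphi_\rho \circ \tilde{\mathcal{F}}) \le \frac{1}{\rho} R_m(\tilde{\mathcal{F}})$. Then, $R_m(\tilde{\mathcal{F}})$ can be upper bounded as follows (the second equality holds because $\sigma_i a_i$ has the same distribution as $\sigma_i$, and the inequality also uses the symmetry of $\sigma_i$):

$$\begin{aligned}
R_m(\tilde{\mathcal{F}}) &= \mathbb{E}_{S,\sigma}\Big[\sup_{f \in \mathcal{F}}\frac{1}{m}\sum_{i=1}^m \sigma_i a_i\big(f(Z'_i) - f(Z_i)\big)\Big] \\
&= \mathbb{E}_{S,\sigma}\Big[\sup_{f \in \mathcal{F}}\frac{1}{m}\sum_{i=1}^m \sigma_i\big(f(Z'_i) - f(Z_i)\big)\Big] \\
&\le \frac{1}{m}\,\mathbb{E}_{S,\sigma}\Big[\sup_{f \in \mathcal{F}}\sum_{i=1}^m \sigma_i f(Z'_i) + \sup_{f \in \mathcal{F}}\sum_{i=1}^m \sigma_i f(Z_i)\Big] \\
&= \mathbb{E}_S\big[\hat{R}_{S^2}(\mathcal{F}) + \hat{R}_{S^1}(\mathcal{F})\big] = R_m^{p_2}(\mathcal{F}) + R_m^{p_1}(\mathcal{F}).
\end{aligned}$$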
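The bounds in Theorem 1 are stated in terms of empirical Rademacher complexities such as $\hat{R}_{S^1}(\mathcal{F})$. For a finite function class, $\hat{R}_S(G) = \mathbb{E}_\sigma[\sup_{g} \frac{1}{m}\sum_i \sigma_i g(z_i)]$ can be approximated by sampling Rademacher vectors σ. The plain-Python sketch below does this. Restricting to a finite class, given by its values on the sample, is an assumption made for illustration, since the matching-score classes in the theorem are generally infinite.

```python
import random


def empirical_rademacher(values, num_draws=2000, seed=0):
    """Monte Carlo estimate of the empirical Rademacher complexity of a finite class.

    values[k][i] = g_k(z_i): the k-th function evaluated on the i-th sample point.
    Estimates E_sigma[ max_k (1/m) * sum_i sigma_i * g_k(z_i) ] with sigma_i = +/-1.
    """
    rng = random.Random(seed)
    m = len(values[0])
    total = 0.0
    for _ in range(num_draws):
        sigma = [rng.choice((-1, 1)) for _ in range(m)]
        total += max(sum(s * v for s, v in zip(sigma, row)) / m for row in values)
    return total / num_draws


# Two constant functions g = +1 and g = -1 on m = 4 points: the supremum is |sum sigma| / 4.
print(empirical_rademacher([[1.0] * 4, [-1.0] * 4]))
```

For this toy class, the exact value is 3/8. The estimate approaches it as `num_draws` grows.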
4 RKHS Bound for Set-to-Set Matching
In this section, we consider more precise bounds that depend on the size of the negative sample produced by negative sampling. Let $S = ((X_1, Y_1), \ldots, (X_m, Y_m)) \in (\mathcal{X} \times \mathcal{X})^m$ be a finite sample sequence, and let $m^+$ be the positive sample size. If the positive proportion is $m^+/m = \alpha$, the sample sequence $S$ can also be denoted by $S_\alpha$. Let $\mathcal{R}_K$ be the reproducing kernel Hilbert space (RKHS) associated with the kernel $K$, and define

$$\mathcal{F}_r = \{f \in \mathcal{R}_K \mid \|f\|_K \le r\} \tag{20}$$

for $r > 0$.

Theorem 2 (RKHS Bound for Set-to-Set Matching). Let $S_\alpha$ be any sample sequence of size $m$. Then, for any $\epsilon > 0$ and $f \in \mathcal{F}_r$,

$$P_{S_\alpha}\Big(\big|\hat{R}(f; S_\alpha) - R(f)\big| \ge \epsilon\Big) \le 2\exp\Big(-\frac{\epsilon^2 \alpha^2 (1-\alpha)^2 m}{2L^2\kappa^2 r^2}\Big), \tag{21}$$

where $\kappa := \sup_x \sqrt{K(x, x)}$.

Proof. Denote $S = (S^+, S^-) = \{Z_1, \ldots, Z_m\}$ and

$$z_i := z(Z_i) := \begin{cases} +1 & (Z_i \in S^+), \\ -1 & (Z_i \in S^-). \end{cases} \tag{22}$$
Fig. 1. RKHS bound w.r.t. sample size m and positive ratio α.
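First, for each $1 \le k \le m^+$ such that $z_k = +1$, let $(Z_k, +1)$ be replaced by $(Z'_k, +1) \in (\mathcal{X}\times\mathcal{X}) \times \{-1, +1\}$, and denote the resulting sample by $S^k$. Then,

$$\begin{aligned}
|\hat{R}(f; S) - \hat{R}(f; S^k)| &\le \frac{1}{m^+ m^-}\sum_{j=m^++1}^{m^++m^-} \big|\phi(f(Z_k) - f(Z_j)) - \phi(f(Z'_k) - f(Z_j))\big| \\
&\le \frac{1}{m^+ m^-}\sum_{j=m^++1}^{m^++m^-} L \cdot \big|f(Z_k) - f(Z_j) - f(Z'_k) + f(Z_j)\big| \\
&= \frac{1}{m^+ m^-} \cdot m^- \cdot L \cdot |f(Z_k) - f(Z'_k)| \le \frac{2L}{m^+}\|f\|_\infty.
\end{aligned}$$

Next, for each $m^+ + 1 \le k \le m$ such that $z_k = -1$, let $(Z_k, -1)$ be replaced by $(Z'_k, -1) \in (\mathcal{X}\times\mathcal{X}) \times \{-1, +1\}$, and denote the resulting sample by $\bar{S}^k$. Similarly, we have

$$|\hat{R}(f; S) - \hat{R}(f; \bar{S}^k)| \le \frac{2L}{m^+}\|f\|_\infty.$$

Finally, for each $1 \le k \le m^+$ such that $z_k = +1$, let $(Z_k, +1)$ be replaced by $(Z_k, -1) \in (\mathcal{X}\times\mathcal{X}) \times \{-1, +1\}$, and denote by $\tilde{S}^k = \bar{S}^k \cup \{(Z_{m+1}, -1)\}$ the resulting sample. Then

$$|\hat{R}(f; S) - \hat{R}(f; \tilde{S}^k)| \le \Gamma_1 + \Gamma_2,$$

where $\Gamma_1 = |\hat{R}(f; S) - \hat{R}(f; S \cup \{(Z_{m+1}, -1)\})|$ and $\Gamma_2 = |\hat{R}(f; S \cup \{(Z_{m+1}, -1)\}) - \hat{R}(f; \tilde{S}^k)|$. Since $\Gamma_1 \le \frac{2L}{m^- + 1}\|f\|_\infty$ and $\Gamma_2 \le \frac{2L}{m^+}\|f\|_\infty$, we have

$$|\hat{R}(f; S) - \hat{R}(f; \tilde{S}^k)| \le 2L\Big(\frac{1}{m^+} + \frac{1}{m^- + 1}\Big)\|f\|_\infty. \tag{23}$$

Combining these bounds and applying McDiarmid's inequality completes the proof.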
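Remark 1. Given $m$, $\epsilon$, and $L$, the tightest bound is achieved when $\alpha = 1/2$, since $\alpha(1-\alpha)$ is maximized there. This means that it is desirable for the number of positive samples to equal the number of negative samples (see Fig. 1).

Remark 2. For any $\delta > 0$, with probability at least $1 - \delta$, we have

$$\big|\hat{R}(f; S_\alpha) - R(f)\big| \le \frac{L\kappa r}{\alpha(1-\alpha)}\sqrt{\frac{2\log(2/\delta)}{m}}. \tag{24}$$

Remark 3. In Remark 2, let $m = m^+ + m^-$ and fix $m^+ \in \mathbb{N}$. Then $m = m^+/\alpha$, so the bound is proportional to $1/(\sqrt{\alpha}(1-\alpha))$. This quantity is minimized at $\alpha = 1/3$, which gives the optimal negative sample proportion $1 - \alpha = 2/3$.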
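Remarks 1 and 3 can be checked numerically by minimizing the right-hand side of Eq. (24) over a grid of α. In this sketch, the constants L, κ, r, and δ are arbitrary placeholders that do not affect the minimizer. In the fixed-$m^+$ case, the total size $m = m^+/\alpha$ is treated as a real number.

```python
import math


def rkhs_bound(alpha, m, delta=0.05, L=1.0, kappa=1.0, r=1.0):
    """Right-hand side of Eq. (24): L*kappa*r / (alpha*(1-alpha)) * sqrt(2*log(2/delta)/m)."""
    return L * kappa * r / (alpha * (1 - alpha)) * math.sqrt(2 * math.log(2 / delta) / m)


grid = [k / 1000 for k in range(1, 1000)]  # alpha in (0, 1), step 0.001

# Remark 1: the total sample size m is fixed.
best_alpha_fixed_m = min(grid, key=lambda a: rkhs_bound(a, m=1000))

# Remark 3: the number of positives is fixed, so m = m_plus / alpha grows as alpha shrinks.
m_plus = 100
best_alpha_fixed_pos = min(grid, key=lambda a: rkhs_bound(a, m=m_plus / a))

print(best_alpha_fixed_m, 1 - best_alpha_fixed_pos)
```

The first minimizer is α = 0.5. The second is the grid point nearest 1/3, so 1 − α is close to 2/3, which agrees with the two remarks.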
5 Conclusion and Discussion
In this paper, we performed a generalization error analysis of set-to-set matching to reveal how models behave on this task. Our analysis shows which quantities the convergence rate of set-matching algorithms depends on. Future studies may include the following:
– Derivation of tighter bounds. There are many mathematical tools for analyzing the generalization error of machine learning algorithms, and the tightness of the resulting bounds depends on which tool is used. To obtain tighter bounds, it would be useful to apply tools not addressed in this paper [1,2,8–10].
– Development of novel set-matching algorithms. Novel algorithms are expected to be derivable from the generalization error analysis.
– The effect of data augmentation on the generalization error of set-to-set matching. Many data augmentation methods have been proposed to stabilize neural network training, and a theoretical analysis of their use in this setting would be useful [4,5,12,13,17].
References
1. Bartlett, P.L., Bousquet, O., Mendelson, S.: Local Rademacher complexities. Ann. Stat. 33(4), 1497–1537 (2005)
2. Duchi, J.C., Jordan, M.I., Wainwright, M.J.: Local privacy and statistical minimax rates. In: 2013 IEEE 54th Annual Symposium on Foundations of Computer Science, pp. 429–438. IEEE (2013)
3. Iwata, T., Lloyd, J.R., Ghahramani, Z.: Unsupervised many-to-many object matching for relational data. IEEE Trans. Pattern Anal. Mach. Intell. 38(3), 607–617 (2015)
4. Kimura, M.: Understanding test-time augmentation. In: Mantoro, T., Lee, M., Ayu, M.A., Wong, K.W., Hidayanto, A.N. (eds.) ICONIP 2021. LNCS, vol. 13108, pp. 558–569. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-92185-9_46
5. Kimura, M.: Why Mixup improves the model performance. In: Farkaš, I., Masulli, P., Otte, S., Wermter, S. (eds.) ICANN 2021. LNCS, vol. 12892, pp. 275–286. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-86340-1_22
6. Kimura, M., Nakamura, T., Saito, Y.: Shift15m: multiobjective large-scale fashion dataset with distributional shifts. arXiv preprint: arXiv:2108.12992 (2021)
7. Lisanti, G., Martinel, N., Del Bimbo, A., Luca Foresti, G.: Group re-identification via unsupervised transfer of sparse features encoding. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2449–2458 (2017)
8. McAllester, D.A.: PAC-Bayesian model averaging. In: Proceedings of the Twelfth Annual Conference on Computational Learning Theory, pp. 164–170 (1999)
9. McAllester, D.A.: Some PAC-Bayesian theorems. Mach. Learn. 37(3), 355–363 (1999)
10. Millar, P.W.: The minimax principle in asymptotic statistical theory. In: Ecole d'Été de Probabilités de Saint-Flour XI – 1981. LNM, vol. 976, pp. 75–265. Springer, Heidelberg (1983). https://doi.org/10.1007/BFb0067986
11. Saito, Y., Nakamura, T., Hachiya, H., Fukumizu, K.: Exchangeable deep neural networks for set-to-set matching and learning. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12362, pp. 626–646. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58520-4_37
12. Shorten, C., Khoshgoftaar, T.M.: A survey on image data augmentation for deep learning. J. Big Data 6(1), 1–48 (2019)
13. Van Dyk, D.A., Meng, X.L.: The art of data augmentation. J. Comput. Graph. Stat. 10(1), 1–50 (2001)
14. Vapnik, V.: The Nature of Statistical Learning Theory. Springer Science & Business Media, Cham (1999)
15. Vapnik, V.N.: An overview of statistical learning theory. IEEE Trans. Neural Networks 10(5), 988–999 (1999)
16. Xiao, H., et al.: Group re-identification: leveraging and integrating multi-grain information. In: Proceedings of the 26th ACM International Conference on Multimedia, pp. 192–200 (2018)
17. Zhong, Z., Zheng, L., Kang, G., Li, S., Yang, Y.: Random erasing data augmentation. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, pp. 13001–13008 (2020)
18. Zhu, P., Zhang, L., Zuo, W., Zhang, D.: From point to set: extend the learning of distance metrics. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2664–2671 (2013)
ADA: An Attention-Based Data Augmentation Approach to Handle Imbalanced Textual Datasets

Amit Kumar Sah and Muhammad Abulaish (B)

Department of Computer Science, South Asian University, New Delhi, India
[email protected], [email protected]
Abstract. This paper presents an Attention-based Data Augmentation (ADA) approach that works in three steps. It extracts keywords from minority-class data points using a vector similarity-based mechanism. It then uses the extracted keywords to extract significant contextual words from minority-class documents using an attention mechanism. Finally, it uses these significant contextual words to enrich the minority-class dataset. By creating new documents based on the significant contextual words and adding them to the minority-class dataset, we oversample the minority class. We compare the original and oversampled versions of the datasets on the classification task. We also compare ADA on the augmented datasets with two popular state-of-the-art text data augmentation methods. The experimental results show that classification algorithms perform better when applied to datasets augmented by any of the data augmentation techniques than when applied to the original datasets. Additionally, classifiers trained on the datasets augmented by ADA are more effective than those trained on datasets augmented by the state-of-the-art techniques.

Keywords: Data Augmentation · Machine Learning · Deep Learning · Class Imbalance · Attention Mechanism · Information Extraction

1 Introduction
Textual data typically suffers from class imbalance. For instance, fake, hateful, and spam tweets make up only a small proportion of all tweets. Gathering textual training data takes a lot of work because the distribution of the gathered data must match the syntax, semantics, and pragmatics of the original data. One of the most common ways to obtain more data is oversampling, which produces additional documents or samples for the minority class or repeats some existing documents. Text data augmentation mechanisms oversample a textual dataset using a variety of techniques, including copying documents, replacing words with synonyms, and creating new data points with deep learning models. Data augmentation is one of the most popular methods for improving the generalization of deep learning models, as it effectively reduces overfitting when training a neural network.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 477–488, 2023.
https://doi.org/10.1007/978-981-99-1639-9_40

A. K. Sah and M. Abulaish

Data augmentation techniques are used successfully in image processing [6,12]. However, techniques that are effective for image data cannot be applied directly to textual data, because text has syntactic, semantic, and pragmatic properties. Textual data augmentation typically relies on a thesaurus, on synonyms, or on similarities computed by particular algorithms. Data augmentation can help train more reliable models, but the complexity of Natural Language Processing (NLP) makes it difficult to create universal rules for language transformation. This complexity is therefore the main difficulty in proposing a generalized text data augmentation approach. According to [1], keywords and keyphrases are crucial for text data augmentation, since the keywords and keyphrases of a document summarize its main points. Keyword extraction is one of the most extensively researched problems in NLP. Frequently used keyword extraction methods include (i) statistical methods, which primarily rely on term frequency and word distribution; (ii) machine learning and deep learning-based methods, which employ supervised, semi-supervised, or unsupervised learning algorithms; and (iii) graph-based methods, which typically model the document's vocabulary as nodes and connect them according to the relationships between words.

In this paper, we present an Attention-based Data Augmentation (ADA) approach that oversamples the minority-class instances of imbalanced textual datasets to improve the detection efficacy of classification algorithms. The proposed approach uses a vector similarity-based keyword extraction mechanism to identify keywords from the minority-class data points.
Using an attention mechanism, it exploits the identified keywords to extract their corresponding significant contextual words from minority class documents. Finally, it uses those significant contextual words to enrich the minority class dataset: it generates new documents based on the keywords and their significant contextual words and adds them to the minority class dataset. Because the significant words are selected by attention scores, the approach offers a degree of interpretability, and in our experiments it improves the performance of a deep learning classifier on the augmented datasets.
ADA: An Attention-Based Data Augmentation Approach
The remainder of the paper is structured as follows. An overview of the available text data augmentation literature is provided in Sect. 2. The proposed attention-based text data augmentation approach is fully described in Sect. 3. The experimental setup and evaluation results are presented in Sect. 4. Finally, the work is concluded with suggestions for future research in Sect. 5.
2
Related Works
Data augmentation is crucial for short text documents such as reviews and tweets, in which many words occur only rarely. In such circumstances, data augmentation becomes essential for increasing the generalization capacity of deep learning models, and researchers have proposed several text data augmentation strategies. In [11], the authors performed data augmentation using an English thesaurus and evaluated it with deep learning models. In [8], the authors proposed appending to the original training sentences their corresponding predicate-argument triplets generated by a semantic role labeling tagger. In [10], the authors introduced Easy Data Augmentation (EDA), showing that simple operations such as synonym replacement, random insertion, random swap, and random deletion applied to a textual dataset can boost classifier performance on text classification tasks. In [5], the authors proposed contextual augmentation for labeled sentences, which offers a wide range of substitute words predicted by a label-conditional bidirectional language model according to the context. In [7], the authors explained that many data augmentation methods fail to yield gains with large pre-trained language models, because these models are already invariant to various transformations; instead, creating new linguistic patterns could be helpful. In [1], the authors showed that augmenting a minority class document with its n-grams that contain keywords, extracted from the minority class dataset using Latent Dirichlet Allocation (LDA), can improve the performance of a CNN on textual datasets.
3
Proposed Data Augmentation Approach
This section discusses the proposed attention-based text data augmentation mechanism for handling imbalanced textual data. Table 1 gives the statistics of the Amazon review datasets used in our experiments. As Table 1 shows, the ratio of the number of positive reviews to negative reviews, i.e., the imbalance ratio (IR), is high for all the datasets. We therefore treat the positive reviews as the majority class dataset ($X_{maj}$) and the negative reviews as the minority class dataset ($X_{min}$). The main goal is to balance the dataset by augmenting the minority class with non-duplicate documents that add knowledge to the minority class. To this end, we first extract keywords from the minority class based on a simple semantic-similarity concept, as discussed in Sect. 3.1. Next, we create a keyword-based labeled dataset, as discussed in Sect. 3.2. We then train two parallel attention-based BiLSTMs on the keyword-based labeled dataset to learn the significant words of each minority class document that contains the keyword(s), as discussed in Sect. 3.3. After that, we select the documents of the keyword-based dataset that were generated and labeled with respect to keywords and transform them using a language model, as discussed in Sect. 3.4. Finally, we oversample the minority class dataset by adding the transformed documents, as discussed in Sect. 3.5. Figure 1 illustrates the workflow of the proposed approach.
Fig. 1. Workflow of the proposed data augmentation approach
3.1
Vector Similarity-Based Keywords Extraction
In this section, we discuss how we extract keywords from the minority class ($X_{min}$) using Bidirectional Encoder Representations from Transformers (BERT). BERT is a bidirectional transformer model that captures the meaning of words, phrases, and documents by encoding them as vectors. A common assumption is that the embeddings of semantically similar words lie close together in vector space. Based on this assumption, we identify as keywords of the minority class dataset those words whose embeddings are closest to the embedding of the entire minority class dataset. For this, we first generate an embedding for the entire minority class dataset and then generate an embedding for each word in the vocabulary of the minority class dataset. For document-level embeddings, we use SBERT [9], a modification of the pre-trained BERT network. SBERT uses siamese and triplet network structures and has proven to be a successful bi-encoder model for generating semantically meaningful sentence embeddings that can be compared using cosine similarity; by fine-tuning the pre-trained BERT network, it produces more expressive sentence embeddings. We encode each document in $X_{min}$ with SBERT to obtain its sentence-level embedding and then average these embeddings over all the documents in $X_{min}$ to obtain a single minority class embedding vector $V_{min}$. After that, we encode the $i$-th word of the minority class vocabulary ($Vocab_{min}$) into its embedding vector $w_i$ using BERT. Finally, we compute the cosine similarity (CoSimValue) between the embedding $w_i$ of every word in $Vocab_{min}$ and $V_{min}$, as given by Eq. 1, where $\mathrm{CoSimValue}(w_i, V_{min}) \in [-1, 1]$.

$$\mathrm{CoSimValue}(w_i, V_{min}) = \frac{w_i \cdot V_{min}}{\lVert w_i \rVert \, \lVert V_{min} \rVert} \tag{1}$$
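A minimal NumPy sketch of $V_{min}$, Eq. 1, and the top-$k$ ranking described next is given below. It assumes the SBERT document embeddings and BERT word embeddings have already been computed and are passed in as arrays; the function names and the toy two-dimensional vectors are illustrative assumptions, not the authors' code.

```python
import numpy as np


def minority_class_vector(doc_embeddings: np.ndarray) -> np.ndarray:
    """V_min: mean of the per-document SBERT embeddings, shape (n_docs, d) -> (d,)."""
    return doc_embeddings.mean(axis=0)


def cosine_similarity(w: np.ndarray, v: np.ndarray) -> float:
    """Eq. (1): CoSimValue(w_i, V_min), which lies in [-1, 1]."""
    return float(np.dot(w, v) / (np.linalg.norm(w) * np.linalg.norm(v)))


def top_k_keywords(vocab, word_embeddings, v_min, k):
    """Rank words by descending CoSimValue and keep the top k as K_min.

    word_embeddings maps each word to its precomputed (assumed) BERT vector.
    """
    scored = [(w, cosine_similarity(word_embeddings[w], v_min)) for w in vocab]
    scored.sort(key=lambda pair: pair[1], reverse=True)   # stable, descending
    return scored[:k]


# Toy example with made-up 2-d vectors (real BERT/SBERT vectors are much larger).
v_min = minority_class_vector(np.array([[1.0, 0.0], [0.0, 1.0]]))
emb = {"broken": np.array([1.0, 1.0]),
       "great": np.array([-1.0, 0.0]),
       "returned": np.array([1.0, 0.0])}
keywords = top_k_keywords(["broken", "great", "returned"], emb, v_min, k=2)
print(keywords)
```

Averaging document vectors keeps $V_{min}$ in the same space as the word vectors, so a single cosine comparison per vocabulary word suffices; this only works if both embedding models produce vectors of the same dimension.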
We sort the words in descending order of CoSimValue and select only the top $k$ words, where $k$ is determined by the requirement to balance the total number of review documents in the two classes, as discussed in Sect. 3.2. We refer to the top $k$ words of $Vocab_{min}$ as the minority class keywords ($K_{min}$).
3.2
Keywords-Based Labeled Dataset Creation
In this section, we discuss the creation of the binary labeled dataset $D_{swe}$ used to extract significant words from the minority class dataset $X_{min}$. For each keyword $kw \in K_{min}$, taken in order of decreasing CoSimValue, we oversample each review document $r \in X_{min}$ with respect to every word $w \in r$: the oversampled copy of $r$ receives class label 0 if $w = kw$ and class label 1 otherwise. This dataset serves two purposes: it generates the additional minority class documents required to balance the dataset, and it supports the extraction of the significant words of each review document labeled 0, as discussed in Sect. 3.3. We therefore call it the significant words extraction dataset, denoted by $D_{swe}$, and denote its class 0 documents by $D^{k}_{swe}$ and its class 1 documents by $D^{nk}_{swe}$. We continue this process until the number of documents in the minority class dataset and in $D^{k}_{swe}$ combined equals the number of documents in the majority class dataset, i.e., $|X_{min}| + |D^{k}_{swe}| = |X_{maj}|$.
3.3
Attention-Based Significant Words Extraction
In this section, we discuss how significant words are identified in each minority class review document that contains the keyword(s). We call the word with respect to which a review document $r \in D_{swe}$ was generated (Sect. 3.2) its target word $w_t$, where $w_t \in r$. Our aim is to identify the words $w \in r$ that contribute most to predicting the target word $w_t$, using the attention mechanism, which is well known for its ability to rank features; here, attention captures the informative parts of the associated contexts. To achieve this, we pass each review document $r \in D_{swe}$ through a pair of parallel attention-based two-layer stacked BiLSTMs, followed by a dense layer and, finally, a softmax layer. Suppose $r_i \in D_{swe}$ is the $i$-th review document, with $r_i = \{w_1, w_2, \ldots, w_t, \ldots, w_{n-1}, w_n\}$, where $w_t$ is the target word and $n$ is the number of words in the review document. Our model learns the importance of each word $w \in r_i$ while being trained on $r_i$ with emphasis on $w_t$, the target word with respect to which $r_i$ was generated and labeled, as discussed in Sect. 3.2. To this end, one of the two parallel BiLSTMs encodes the document from the beginning up to the target word ($\mathrm{BiLSTM}_b$) and the other from the target word to the end of the document ($\mathrm{BiLSTM}_e$), as given by Eqs. 2 and 3, respectively.

$$h^{b}_{w_t} = \mathrm{BiLSTM}_b\left(w_t, h^{b}_{w_{t-1}}\right) \tag{2}$$

$$h^{e}_{w_t} = \mathrm{BiLSTM}_e\left(w_t, h^{e}_{w_{t-1}}\right) \tag{3}$$
where $\mathrm{BiLSTM}_b$ and $\mathrm{BiLSTM}_e$ model the preceding and the following context of the target word independently. Not every word encoded by $\mathrm{BiLSTM}_b$ and $\mathrm{BiLSTM}_e$ is equally significant. To identify the more significant words, we place an attention layer on top of $\mathrm{BiLSTM}_b$ and $\mathrm{BiLSTM}_e$, which highlights the more informative words by assigning them attention scores. The attention mechanism assigns a variable weight, depending on contextual importance, to all words (i) from the beginning of the review document to the target word (encoded by $\mathrm{BiLSTM}_b$) and (ii) from the target word to the end of the review (encoded by $\mathrm{BiLSTM}_e$). For the encoded vector $V_{r_i}$ of review document $r_i \in D_{swe}$, let $h_{w_t}$ be the hidden state that the BiLSTM produces for a word $w_t \in V_{r_i}$ (in Eqs. 4 to 6, $t$ ranges over word positions). This hidden state is passed to a dense layer to learn its hidden representation $\hat{h}_{w_t}$, as given by Eq. 4, where $W$ and $B$ denote the weight and bias, respectively. The similarity between $\hat{h}_{w_t}$ and a context vector $v_{w_t}$, which represents the importance of $w_t$, is then computed, and the normalized importance score $\alpha_{w_t}$ is obtained using Eq. 5. The feature-level context vector $v_{w_t}$ is randomly initialized and jointly learned during training. Finally, the attention-aware representation $A_{r_i}$ of the review document $r_i$ is computed as the weighted sum of the hidden representations of its words, as given by Eq. 6.
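$$\hat{h}_{w_t} = \tanh\left(W h_{w_t} + B\right) \tag{4}$$

$$\alpha_{w_t} = \frac{\exp\left(\hat{h}_{w_t}^{\top} v_{w_t}\right)}{\sum_{t'} \exp\left(\hat{h}_{w_{t'}}^{\top} v_{w_{t'}}\right)} \tag{5}$$

$$A_{r_i} = \sum_{t} \alpha_{w_t} h_{w_t} \tag{6}$$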
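The per-side attention of Eqs. 4 to 6, the concatenation of the two sides, and the top-15% word selection described below can be sketched in NumPy as follows. This is a forward-pass sketch only, assuming the BiLSTM hidden states are already available; the trained Keras model is not reproduced. As a simplifying assumption, a single learned context vector `v` replaces $v_{w_t}$ (as in common hierarchical attention networks), and the function names are hypothetical.

```python
import numpy as np


def split_at_target(tokens, t):
    """Context for BiLSTM_b (start..w_t) and BiLSTM_e (w_t..end).

    Assumption: the target word w_t is included in both halves.
    """
    return tokens[: t + 1], tokens[t:]


def attention_pool(H, W, B, v):
    """Eqs. (4)-(6) for one side.

    H: (n, d) BiLSTM hidden states h_{w_t}; W: (a, d); B: (a,); v: (a,) context vector.
    Returns the attention weights alpha (n,) and the representation A (d,).
    """
    H_hat = np.tanh(H @ W.T + B)          # Eq. (4)
    scores = H_hat @ v                    # \hat{h}_{w_t}^T v
    e = np.exp(scores - scores.max())     # shift for numerical stability
    alpha = e / e.sum()                   # Eq. (5)
    A = alpha @ H                         # Eq. (6): sum_t alpha_t h_t
    return alpha, A


def top_fraction(alpha, fraction=0.15):
    """Positions of the highest-attention words (at least one), highest first."""
    n_keep = max(1, int(np.ceil(fraction * len(alpha))))
    return [int(i) for i in np.argsort(alpha)[::-1][:n_keep]]


# Usage with random stand-in hidden states (8-d) for a 4-word left and 5-word right side.
rng = np.random.default_rng(0)
H_b, H_e = rng.normal(size=(4, 8)), rng.normal(size=(5, 8))
W, B, v = rng.normal(size=(6, 8)), np.zeros(6), rng.normal(size=6)  # shared here for brevity
alpha_b, A_b = attention_pool(H_b, W, B, v)
alpha_e, A_e = attention_pool(H_e, W, B, v)
doc_repr = np.concatenate([A_b, A_e])   # would feed the 1024-unit dense and softmax layers
```

In the described model each side would have its own learned attention parameters; they are shared above only to keep the example short.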
Both $\mathrm{BiLSTM}_b$ and $\mathrm{BiLSTM}_e$ go through the computations of Eqs. 4, 5, and 6 simultaneously, yielding the attention-based representations $A^{b}_{r_i}$ and $A^{e}_{r_i}$ of review document $r_i$. We concatenate these two vectors to form the final representation of $r_i$, pass it through a dense layer with 1024 neurons, and finally through a softmax layer with 2 neurons. In this way, the model learns to identify the target word from the attention-based weight distribution of its contextual words. We train the parallel attention-based BiLSTM model on the $D_{swe}$ dataset. Once the model is trained, we extract the attention-based vectors $A_b$ and $A_e$, which hold the attention scores of the words on either side of the target word $w_t$, and rank the words on both sides of $w_t$ by these scores. In this work, we select the top 15% of words for each of $\mathrm{BiLSTM}_b$ and $\mathrm{BiLSTM}_e$.
3.4
Language Model-Based Transformation of Review Documents
In this section, we discuss the language model-based transformation of the review documents in $D^{k}_{swe}$. We aim to transform a review document $r \in D^{k}_{swe}$ into $r_t$ such that the transformed review document $r_t$ is a semantically similar but non-duplicate version of $r$. To this end, we ensure that the words replaced in $r$ to obtain $r_t$ are contextually appropriate and preserve the meaning of $r$. We use the Fill-Mask task supported by BERT, also known as masked language modeling, in which some of the words in a sentence are masked and the model predicts which words best replace them. Such models are helpful when a statistical understanding of the language on which the model was trained is needed. Because BERT is a strong and widely used model for this task, we use it in our work. We extract the top $k$ significant words $S_w = \{S_{w_1}, S_{w_2}, \ldots, S_{w_k}\}$ from each review document $r \in D^{k}_{swe}$ based on their attention scores, as discussed in Sect. 3.3. Each significant word $S_{w_i} \in S_w$ is then replaced with the most similar word predicted by BERT after masking it. We follow a hold-and-predict strategy, in which we mask one word and predict it from the remaining words of the document. Words are masked in order of importance, i.e., by attention score, and when a word is masked, the rest of the words remain unchanged. The BERT model then gives the best replacement for $S_{w_i}$, denoted $S^{r}_{w_i}$. We replace $S_{w_i}$ by $S^{r}_{w_i}$ and repeat this process for all the significant words in $r$, in decreasing order of attention score. Finally, we obtain the transformed review document $r_t$, in which all the words $w \in S_w \cap r$ are replaced by the best contextually and semantically similar words given by the BERT model.
3.5
Oversampling Minority Class Dataset
In this section, we discuss the oversampling of the minority class dataset $X_{min}$. We first transform each review document $r$ in the keyword-based dataset $D^{k}_{swe}$ into the transformed review document $r_t$, as discussed in Sect. 3.4. ADA aims to balance the number of review documents in the two classes of the review dataset, and $D^{k}_{swe}$ was created in Sect. 3.2 so that adding it to $X_{min}$ yields a balanced dataset. We therefore add the transformed $D^{k}_{swe}$ to $X_{min}$ to obtain the oversampled minority class dataset $AX_{min}$, such that $|X_{maj}| = |AX_{min}|$. $AX_{min}$ is the final augmented minority class dataset, and we replace $X_{min}$ by $AX_{min}$ to obtain the oversampled balanced dataset.
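Sections 3.2, 3.4, and 3.5 together determine how the new minority class documents are produced. The plain-Python sketch below strings them together under stated assumptions: documents are token lists, the attention-ranked positions would come from the model of Sect. 3.3, and the word predictor is pluggable. `make_hf_predictor` shows one way to back the predictor with the Hugging Face `fill-mask` pipeline and `bert-base-uncased`, choosing the highest-scoring candidate that differs from the original word. That selection rule, keeping earlier replacements when the next word is masked, and all function names are our assumptions rather than details confirmed by the paper; reviews longer than BERT's 512-token input limit are not handled.

```python
from typing import Callable, List, Sequence, Tuple

Doc = List[str]
# Hypothetical predictor interface: (masked document, masked position, original word) -> replacement.
Predictor = Callable[[Doc, int, str], str]


def build_dswe(x_min: Sequence[Doc], keywords: Sequence[str], n_needed: int
               ) -> Tuple[List[Tuple[Doc, int]], List[Tuple[Doc, int]]]:
    """Sect. 3.2: one labeled copy (document, target position) per word of each review.

    Copies whose target word equals the current keyword go to D_swe^k (label 0),
    all others to D_swe^nk (label 1). Stops once |D_swe^k| == n_needed,
    with n_needed = |X_maj| - |X_min|.
    """
    d_k: List[Tuple[Doc, int]] = []
    d_nk: List[Tuple[Doc, int]] = []
    if n_needed <= 0:
        return d_k, d_nk
    for kw in keywords:                      # keywords in decreasing CoSimValue order
        for doc in x_min:
            for t, w in enumerate(doc):
                if w == kw:
                    d_k.append((list(doc), t))
                    if len(d_k) == n_needed:
                        return d_k, d_nk
                else:
                    d_nk.append((list(doc), t))
    return d_k, d_nk


def transform_review(doc: Doc, significant_idx: Sequence[int], predict: Predictor,
                     mask_token: str = "[MASK]") -> Doc:
    """Sect. 3.4: hold-and-predict, masking one significant word at a time.

    significant_idx is assumed sorted by decreasing attention score; earlier
    replacements are kept when the next word is masked (our reading of the text).
    """
    out = list(doc)
    for i in significant_idx:
        masked = out[:i] + [mask_token] + out[i + 1:]
        out[i] = predict(masked, i, out[i])
    return out


def oversample(x_min: Sequence[Doc], transformed: Sequence[Doc]) -> List[Doc]:
    """Sect. 3.5: AX_min = X_min plus the transformed D_swe^k documents."""
    return [list(d) for d in x_min] + [list(d) for d in transformed]


def make_hf_predictor(model_name: str = "bert-base-uncased", top_k: int = 10) -> Predictor:
    """Optional BERT-backed predictor (needs `transformers` and a model download)."""
    from transformers import pipeline
    fill = pipeline("fill-mask", model=model_name)

    def predict(masked: Doc, i: int, original: str) -> str:
        text = " ".join(fill.tokenizer.mask_token if j == i else w
                        for j, w in enumerate(masked))
        for cand in fill(text, top_k=top_k):
            word = cand["token_str"].strip()
            if word.lower() != original.lower() and not word.startswith("##"):
                return word          # best-scoring candidate that changes the word
        return original
    return predict


def shout(masked: Doc, i: int, original: str) -> str:
    """Stand-in predictor for an offline check (not a language model)."""
    return original.upper()


x_min = [["battery", "died", "fast"], ["strap", "broke"]]
n_maj = 5                                                    # toy |X_maj|
d_k, _ = build_dswe(x_min, ["broke", "died", "battery"], n_maj - len(x_min))
# Toy shortcut: reuse the target position as the only "significant" word.
ax_min = oversample(x_min, [transform_review(doc, [t], shout) for doc, t in d_k])
```

Passing `make_hf_predictor()` instead of `shout` would use BERT for the replacements; the stand-in only makes the bookkeeping easy to check, e.g., that $|AX_{min}| = |X_{maj}|$ holds after oversampling.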
4
Experimental Setup and Results
In this section, we present our experimental setup and discuss the evaluation of the proposed approach. Experiments were performed on a machine with a 2.10 GHz Intel(R) Silver(R) processor and 192 GB of RAM. Our attention-based text data augmentation model was implemented in Keras¹, and the pre-trained BERT models were used through the Transformers² library.
4.1
Datasets
We evaluate ADA on three publicly available Amazon review datasets [4], namely musical instruments (DS1), patio, lawn and garden (DS2), and automotive (DS3). We labeled all reviews with star ratings of 1 or 2 as negative and reviews with star ratings of 3, 4, or 5 as positive. Table 1 presents the statistics of the modified datasets, listed in increasing order of the total number of reviews. The IR column in Table 1 gives each dataset's imbalance ratio, i.e., $|X_{maj}| / |X_{min}|$.

Table 1. Statistics of the Amazon review datasets

| Dataset | #Reviews | #$X_{maj}$ | #$X_{min}$ | IR |
|---|---|---|---|---|
| DS1 | 10,261 | 9,794 | 467 | 20.97 |
| DS2 | 13,272 | 12,080 | 1,192 | 10.13 |
| DS3 | 20,473 | 19,325 | 1,148 | 16.83 |

¹ https://keras.io/
² https://huggingface.co/docs/transformers/index
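The IR column and the number of documents ADA must generate follow directly from the class sizes in Table 1: $IR = |X_{maj}| / |X_{min}|$, and the balance condition $|X_{min}| + |D^{k}_{swe}| = |X_{maj}|$ of Sect. 3.2 fixes $|D^{k}_{swe}| = |X_{maj}| - |X_{min}|$. A small worked check (the names are illustrative):

```python
# Class sizes from Table 1: dataset -> (|X_maj|, |X_min|)
TABLE1 = {"DS1": (9_794, 467), "DS2": (12_080, 1_192), "DS3": (19_325, 1_148)}


def imbalance_ratio(n_maj: int, n_min: int) -> float:
    """IR = |X_maj| / |X_min|."""
    return n_maj / n_min


def documents_to_generate(n_maj: int, n_min: int) -> int:
    """|D_swe^k| such that |X_min| + |D_swe^k| = |X_maj| (Sect. 3.2)."""
    return n_maj - n_min


for name, (n_maj, n_min) in TABLE1.items():
    print(name, round(imbalance_ratio(n_maj, n_min), 2), documents_to_generate(n_maj, n_min))
# DS1 20.97 9327
# DS2 10.13 10888
# DS3 16.83 18177
```

The computed ratios reproduce the IR column, and the differences show how many synthetic minority class reviews each dataset needs to reach balance.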
4.2
Data Preprocessing
The main issue with short text documents, especially reviews and tweets, is that they often deviate significantly from standard grammatical structures and contain creative spellings, owing to character limits and informal writing habits. Such data needs more specialized preprocessing than standard techniques provide, since otherwise semantic information may be lost. To avoid this, we performed the following preprocessing tasks: removal of stop-words, URLs, and hashtag symbols; resolution of elongated words; emoticon handling; resolution of contractions; stemming; and lemmatization.
4.3
Classifier Architecture and Training Details
In this section, we present the classifier used to validate the effectiveness of ADA. We used a two-layer stacked BiLSTM architecture with 256 cells per layer, followed by a final softmax layer with 2 neurons, since we formulate the task as binary classification. We used Xavier Glorot initialization for the initial weights and Adam as the optimizer. To reduce overfitting, the model applies dropout at the BiLSTM and fully connected layers with rates of 0.2 and 0.5, respectively, and an L2 regularizer with λ = 0.01 on the binary cross-entropy loss. We used the Rectified Linear Unit (ReLU) activation function throughout the model, except in the output layer, which uses the softmax function. For all classification tasks in this work, we used 300-dimensional GloVe embeddings trained on the Common Crawl dataset with 840B tokens. For BERT-related tasks, we used the BERT-base uncased pre-trained model proposed in [2]. Table 2 gives the total number of keywords extracted from each Amazon review dataset to generate the keyword-based labeled dataset, as discussed in Sect. 3.2.

Table 2. Number of keywords extracted for the different Amazon review datasets

| Dataset | #Keywords |
|---|---|
| DS1 | 2,076 |
| DS2 | 1,160 |
| DS3 | 2,494 |

4.4
Evaluation Metrics
Only a few metrics are appropriate for evaluating a classifier on imbalanced data [3]. When the dataset is skewed, evaluation metrics should be chosen so that the classifier's performance on the majority class does not overshadow its performance on the minority class. Hence, throughout our experiments, we evaluated the classification model on the minority class alone and with macro averaging over both classes. We use precision (PR), recall (DR), F1 measure (F1), macro precision (MacPR), macro recall (MacDR), and macro F1 (MacF1) as evaluation metrics. We chose these metrics to study the classifier's performance on the minority class while checking for any strongly adverse impact on the majority class.

Table 3. Comparative performance evaluation results of ADA on the minority class (values in %)

| Approach | DS1 PR | DS1 DR | DS1 F1 | DS2 PR | DS2 DR | DS2 F1 | DS3 PR | DS3 DR | DS3 F1 |
|---|---|---|---|---|---|---|---|---|---|
| Original Dataset | 45.45 | 10.42 | 16.95 | 37.21 | 13.48 | 19.80 | 35.86 | 15.03 | 21.18 |
| EDA [10] | 95.37 | 98.92 | 97.11 | 90.86 | 97.90 | 94.25 | 90.35 | 98.69 | 94.33 |
| CDA [5] | 95.65 | 97.15 | 96.39 | 95.98 | 94.03 | 95.00 | 94.95 | 97.31 | 96.12 |
| ADA | 96.13 | 99.31 | 97.70 | 92.70 | 98.00 | 95.28 | 96.76 | 98.83 | 97.78 |
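The minority-class and macro-averaged measures of Sect. 4.4 can be computed from predicted labels as in the sketch below. It assumes the minority (negative) class is encoded as label 1 and takes each macro score as the unweighted mean of the per-class scores, the convention used, for example, by scikit-learn's `average="macro"`; the paper does not state its exact averaging formula, and the toy labels are made up.

```python
from typing import Dict, Sequence, Tuple


def precision_recall_f1(y_true: Sequence[int], y_pred: Sequence[int],
                        positive: int) -> Tuple[float, float, float]:
    """PR, DR (recall), and F1 for one class treated as positive."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    pr = tp / (tp + fp) if tp + fp else 0.0
    dr = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * pr * dr / (pr + dr) if pr + dr else 0.0
    return pr, dr, f1


def evaluate(y_true: Sequence[int], y_pred: Sequence[int],
             minority: int = 1, majority: int = 0) -> Dict[str, float]:
    """Minority-class PR/DR/F1 plus unweighted macro averages over both classes."""
    mino = precision_recall_f1(y_true, y_pred, minority)
    majo = precision_recall_f1(y_true, y_pred, majority)
    mac = [(a + b) / 2 for a, b in zip(mino, majo)]
    return {"PR": mino[0], "DR": mino[1], "F1": mino[2],
            "MacPR": mac[0], "MacDR": mac[1], "MacF1": mac[2]}


scores = evaluate([0, 0, 0, 0, 1, 1], [0, 0, 0, 1, 1, 0])
print(scores)
```

Predicting the majority class for every review in this toy example would still give 4/6 accuracy but a minority-class DR of 0, which is why Sect. 4.4 reports minority-class scores alongside the macro averages.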
Table 4. Macro-averaged comparative performance evaluation results of ADA (values in %)

| Approach | DS1 MacPR | DS1 MacDR | DS1 MacF1 | DS2 MacPR | DS2 MacDR | DS2 MacF1 | DS3 MacPR | DS3 MacDR | DS3 MacF1 |
|---|---|---|---|---|---|---|---|---|---|
| Original Dataset | 70.61 | 54.90 | 57.25 | 64.61 | 55.62 | 57.30 | 65.48 | 56.71 | 58.95 |
| EDA [10] | 97.12 | 97.03 | 97.04 | 94.17 | 93.56 | 93.71 | 94.72 | 95.87 | 95.16 |
| CDA [5] | 96.44 | 96.46 | 96.45 | 95.13 | 95.09 | 95.10 | 96.15 | 96.15 | 96.16 |
| ADA | 97.64 | 97.28 | 97.43 | 95.30 | 95.17 | 95.16 | 97.50 | 96.99 | 97.23 |

4.5
Comparison Approaches
To establish the efficacy of the proposed model on imbalanced data, this section compares ADA with two standard text data augmentation techniques: (i) EDA, Easy Data Augmentation Techniques for Boosting Performance on Text Classification Tasks [10], and (ii) CDA, Contextual Augmentation: Data Augmentation by Words with Paradigmatic Relations [5].
4.6
Evaluation Results and Comparative Analysis
We oversampled the minority class of each original (non-augmented) dataset by adding new review documents generated by the proposed attention-based text data augmentation technique, yielding a balanced dataset. We evaluated both the original and the balanced versions of the datasets with the BiLSTM model to study the effectiveness of our text data augmentation mechanism. For the comparative study, the methods of [5,10] were evaluated in the same way, i.e., we oversampled the minority class dataset with new review documents generated by the respective text data augmentation mechanism until the dataset was balanced. We use the evaluation metrics discussed in Sect. 4.4. We trained the BiLSTM model on 56% of the dataset, validated it on 14%, and tested it on the remaining 30% of unseen data. We trained the model for 100 epochs with early stopping as a regularization mechanism against overfitting and recorded the results on the test data. Table 3 lists the classifier's performance on the minority class, and Table 4 lists the evaluation results macro-averaged over the majority and minority classes.

Performance on Minority Class: Table 3 shows that the DR values on the original datasets were extremely poor, ranging from 10.42% for DS1 to 15.03% for DS3. On the oversampled datasets, however, DR improved substantially, ranging from 94.03% for DS2 with CDA to 99.31% for DS1 with our proposed approach. ADA outperformed both EDA and CDA in terms of DR and F1 on all the datasets. Notably, the DR on the datasets oversampled by the proposed approach never fell below 98% (the minimum, 98.00%, was for DS2). Likewise, the F1 value on the datasets oversampled by ADA never fell below 95.28% (DS2) and reached a maximum of 97.78% (DS3). In terms of PR, ADA performed better than EDA and CDA on DS1 and DS3, while CDA surpassed ADA on DS2. EDA surpassed CDA on all the datasets in terms of DR, whereas CDA surpassed EDA in terms of PR on all the datasets and in terms of F1 on DS2 and DS3.
Macro Performance: Table 4 shows that ADA surpasses EDA and CDA in terms of MacPR, MacDR, and MacF1 on all the datasets. In terms of MacF1, ADA's margin over EDA and CDA is wider on DS3, the largest dataset, than on the two smaller datasets, DS1 and DS2. These macro-averaged results suggest that ADA generates a high-quality augmented dataset that substantially improves the classifier's performance on the minority class without hurting its performance on the majority class. CDA performed better than EDA on DS2 and DS3 on all evaluation criteria except minority-class DR (Table 3), in terms of which EDA is better than CDA and comparable to ADA. Overall, the minority-class and macro-averaged analyses above suggest that ADA is the most effective of the three text data augmentation approaches in our experiments.
5
Conclusion and Future Work
This paper presented an Attention-Based Data Augmentation (ADA) method to address class imbalance in textual datasets. In our experiments, ADA outperformed the state-of-the-art methods EDA and CDA on all macro-averaged measures and on nearly all minority-class measures. The oversampled augmented dataset may be useful for deep learning models that learn patterns from data, and, because ADA targets information scarcity, it may be particularly helpful in domains with small, imbalanced datasets. Promising directions for future research include investigating other keyword extraction methods and designing a dedicated model to learn the best phrases for replacing the significant words identified in Sect. 3.3. We are also working to improve the proposed approach so that it produces higher-quality documents for imbalanced textual datasets.
References
1. Abulaish, M., Sah, A.K.: A text data augmentation approach for improving the performance of CNN. In: 11th International Conference on COMSNET, Bangalore, India, pp. 625–630 (2019)
2. Devlin, J., Chang, M.W., Lee, K., Toutanova, K.: BERT: pre-training of deep bidirectional transformers for language understanding. In: Proceedings of the Conference of the North American Chapter of the ACL: Human Language Technologies, Minnesota, pp. 4171–4186 (2019)
3. Ferri, C., Hernández-Orallo, J., Modroiu, R.: An experimental comparison of performance measures for classification. Pattern Recogn. Lett. 30(1), 27–38 (2009)
4. He, R., McAuley, J.: Ups and downs: modeling the visual evolution of fashion trends with one-class collaborative filtering. In: Proceedings of the 25th International Conference on World Wide Web, pp. 507–517 (2016)
5. Kobayashi, S.: Contextual augmentation: data augmentation by words with paradigmatic relations. In: Proceedings of the Conference of the North American Chapter of the ACL-HLT, Louisiana, pp. 452–457 (2018)
6. Krizhevsky, A., Sutskever, I., Hinton, G.E.: ImageNet classification with deep convolutional neural networks. In: Pereira, F., et al. (eds.) Advances in Neural Information Processing Systems, Nevada, USA (2012)
7. McCoy, T., Pavlick, E., Linzen, T.: Right for the wrong reasons: diagnosing syntactic heuristics in natural language inference. In: Proceedings of the 57th Annual Meeting of the ACL, Florence, Italy, pp. 3428–3448 (2019)
8. Min, J., McCoy, R.T., Das, D., Pitler, E., Linzen, T.: Syntactic data augmentation increases robustness to inference heuristics. In: Proceedings of the 58th Annual Meeting of the ACL, pp. 2339–2352 (2020)
9. Reimers, N., Gurevych, I.: Sentence-BERT: sentence embeddings using Siamese BERT-networks. In: Proceedings of the Conference on EMNLP and IJCNLP, Hong Kong, China, pp. 3982–3992 (2019)
10. Wei, J., Zou, K.: EDA: easy data augmentation techniques for boosting performance on text classification tasks. In: Proceedings of the Conference on EMNLP and IJCNLP, Hong Kong, China, pp. 6382–6388 (2019)
11. Zhang, X., Zhao, J.J., LeCun, Y.: Character-level convolutional networks for text classification. In: Advances in Neural Information Processing Systems, Quebec, Canada, pp. 649–657 (2015)
12. Zhong, Z., Zheng, L., Kang, G., Li, S., Yang, Y.: Random erasing data augmentation. In: Proceedings of the AAAI Conference on Artificial Intelligence, New York, vol. 34, pp. 13001–13008 (2020)
Countering the Anti-detection Adversarial Attacks
Anjie Peng1,2, Chenggang Li1, Ping Zhu3(✉), Xiaofang Huang1, Hui Zeng1(✉), and Wenxin Yu1
1 Southwest University of Science and Technology, Mianyang 621010, Sichuan, China
2 Science and Technology on Communication Security Laboratory, Chengdu 621010, Sichuan, China
3 Chengdu University of Information Technology, Chengdu 621010, Sichuan, China
Abstract. The anti-detection adversarial attack is an evolutionary attack: it can both fool a CNN model into producing incorrect classification outputs and evade some detection-based defenses. In this paper, we aim to detect the adversarial images generated by a typical anti-detection attack that can evade some existing noising-based detectors. We find that this anti-detection attack causes shifting effects in the residual domain. Specifically, zero residual elements shift to non-zero values, and shifting also occurs among non-zero residual elements. These shifting effects inevitably change the co-occurrence relationships among neighboring residual elements. Furthermore, the attacker treats the R, G, and B channels as independent when adding the adversarial perturbation, which further disturbs their co-occurrence relationships. We therefore propose the 3rd-order co-occurrence probabilities of the R, G, and B residuals as features and construct a binary ensemble classifier to detect anti-detection adversarial images. Experimental results show that the proposed method achieves detection accuracies of >99% and >96.9% on ImageNet and Cifar-10, respectively, outperforming the state of the art. In addition, the proposed method generalizes well and is difficult to attack again.
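The abstract does not specify the residual filter, the truncation threshold, or the exact neighborhood of the 3rd-order co-occurrence, so the following NumPy sketch is only an illustration under explicit assumptions: a horizontal first-order difference residual per color channel, truncation to $[-T, T]$, and a joint histogram of the (R, G, B) residual triplet at the same position, normalized to probabilities. The function name and $T = 2$ are hypothetical choices, not the authors' settings.

```python
import numpy as np


def rgb_residual_cooccurrence(img: np.ndarray, T: int = 2) -> np.ndarray:
    """3rd-order co-occurrence probabilities of the (R, G, B) residual triplet.

    img: (H, W, 3) integer RGB image. Returns a vector of (2T+1)**3 probabilities.
    """
    x = img.astype(np.int64)                      # avoid uint8 wrap-around
    res = x[:, 1:, :] - x[:, :-1, :]              # assumed filter: horizontal difference
    q = np.clip(res, -T, T) + T                   # truncate to [-T, T], shift to [0, 2T]
    bins = 2 * T + 1
    codes = (q[..., 0] * bins + q[..., 1]) * bins + q[..., 2]  # one code per triplet
    hist = np.bincount(codes.ravel(), minlength=bins ** 3).astype(np.float64)
    return hist / hist.sum()


T, bins = 2, 5
zero_bin = (T * bins + T) * bins + T              # code of the (0, 0, 0) triplet
clean = np.zeros((4, 4, 3), dtype=np.uint8)       # flat image: every residual is 0
perturbed = clean.copy()
perturbed[:, ::2, 0] = 1                          # +1 on alternate columns of R only
features = rgb_residual_cooccurrence(clean, T)
print(features[zero_bin], rgb_residual_cooccurrence(perturbed, T)[zero_bin])
```

In this toy case, a perturbation of one intensity level on a single channel moves all probability mass out of the (0, 0, 0) bin, which is the kind of shift from zero to non-zero residuals that the abstract describes.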
Keywords: Convolution Neural Network · Adversarial Images Detection · Co-occurrence Probability
1
Introduction
Today, digital image classification based on convolution neural networks (CNNs) has become the infrastructure for many computer vision tasks. However, adversarial attacks aimed at fooling CNN models greatly hinder in-depth applications of CNNs in security-sensitive industries, such as self-driving cars [1] and face detection [2]. Attackers generate adversarial examples by carefully adding imperceptible perturbations to clean images, forcing the CNN to produce erroneous outputs, i.e., either a targeted class (targeted attack) or any class other than the original one (untargeted attack). Since the first attack, L-BFGS [3], was proposed in 2013, many attacks have been proposed. According to the perturbation budget, attacks can be categorized into two types [4]: constrained attacks (FGSM [5], BIM [6], PGD [7], BP [8], SI-TI-DI-NID-FGSM [9]) and minimum attacks (C&W [10], DeepFool [11]). These attacks obtain nearly 100% attack success rates (ASR) under white-box settings.
This work was partially supported by NSFC (No. 41865006), Sichuan Science and Technology Program (No. 2022YFG0321, 2022NSFSC0916).
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 489–500, 2023.
https://doi.org/10.1007/978-981-99-1639-9_41
A. Peng et al.
Besides adversarial training [36], detecting adversarial images is an important defense against adversarial attacks. A detection-based defense first detects the adversarial images and refuses to feed them into the CNN classifier. Assuming that the adversarial perturbation is sensitive and easily disturbed, transformation-based methods exploit the changes in the CNN's behavior before and after transformations to detect adversarial examples (e.g., noising [12,34,35], FGSM attack [13], denoising [14], quilting [15], compression [16], feature squeezing [17]). Other methods use the distribution differences between adversarial and clean images to construct the detector, e.g., principal component analysis [18], kernel density estimation [19], and MagNet [20]. The above-mentioned detection defenses can be viewed as static methods, because they do not consider, or do not fully consider, that cunning attackers may design new attacks to evade them. We call such an evolved attack an anti-detection attack¹: it not only evades the given detection-based defense but also continues to fool CNN models. Ten earlier detection methods were bypassed by the anti-detection attacks proposed by Carlini et al. [21]. Recently, defenders [12,13] have begun to evaluate their methods against both static and anti-detection attacks.
However, these defenses are still defeated by strong anti-detection attacks [22]. In this work, we try to detect the typical anti-detection attack [22] designed to evade the detection method proposed by Roth et al. [12]. This attack [22] may also defeat other noising-based detectors [34,35], because it makes adversarial samples behave like clean samples when Gaussian noise is added. We extract the co-occurrence probability as the feature because (1) it can differentiate adversarial examples from clean examples, and (2) it may defend against potential unknown anti-detection attacks, or at least set up obstacles when attackers design new anti-detection adversarial attacks. So far, we have found that detectors using multi-dimensional co-occurrence features for steganalysis [23,24] are not susceptible to the new anti-detection attack. Our contributions are as follows: (1) We find shifting effects in the residual domain caused by the anti-detection attack and propose co-occurrence probability features of the R, G, and B residual channels to detect the attack. To the best of our knowledge, very few works counter the anti-detection adversarial attack [22]; our work adds a tool for the defender in the attack-defense confrontation. (2) We achieve detection accuracies of >99% and >96.9% on ImageNet and Cifar-10, respectively, outperforming existing methods. Besides, our method generalizes well and is difficult to attack again with new anti-detection attacks.

¹ Some works call this attack an adaptive attack, a term that refers to existing or unknown attacks against a specific defense. We focus on detecting the existing adaptive attack and call it the anti-detection attack for clarity.

Countering the Anti-detection Adversarial Attacks
2 Background
To better understand our method, we briefly introduce the initial detection method [12] and the adaptive attack [22].

Initial Detection Method: The initial detection [12] aims to detect the initial attacks PGD [7] and C&W [10], which fool CNN classifiers. Roth et al. [12] observed that an adversarial image $x'$ is less robust to Gaussian noise than a legitimate image $x$ in the logit output layer $z(\cdot)$ of a CNN (before the softmax layer). Specifically, given a scalar $z(x)_i$ as the logit output of the $i$th class and Gaussian noise $\sigma$, it holds that $z(x)_i \approx z(x+\sigma)_i$, whereas $z(x')_i$ and $z(x'+\sigma)_i$ differ significantly. Hence, Roth et al. [12] proposed the feature $f = \mathbb{E}_{\sigma}\{(z(x+\sigma)_t - z(x+\sigma)_y) - (z(x)_t - z(x)_y)\}$, where $\mathbb{E}(\cdot)$ is the expectation over $k$ independent noise samples $\sigma$, $y$ is the predicted class of the CNN, and $t$ ($t \neq y$) is any other class. The test image is rejected as adversarial if $f > T$, where $T$ is a threshold set by a fixed false positive rate (FPR). This method achieved detection rates of 99.1%/71.4% for defense-unaware/defense-aware PGD attacks against WideResNet [7] on Cifar-10. Assuming that adversarial images are not robust against Gaussian noise, some methods [34,35] also employ noising transformations to detect the PGD attack.

Anti-detection Adversarial Attack: Tramèr et al. [22] proposed an anti-detection attack that evades the detection defense [12] while still fooling the CNN classifier. They made $z(x'+\sigma)_t - z(x'+\sigma)_y$ and $z(x')_y - z(x')_t$ small, so that the feature $f$ of $x'$ becomes similar to that of $x$. To this end, they generated an $x'$ that targets the logit outputs $z(x_l)$ of an arbitrary benign example $x_l$ with label $l$ by minimizing the $L_2$ logit-matching loss $\|z(x') - z(x_l)\|_2^2$ with PGD [7]. Logit matching effectively reduces $z(x')_y - z(x')_t$. To reduce $z(x'+\sigma)_t - z(x'+\sigma)_y$, expectation over transformation (EOT) [25] is applied to the logit loss to make the adversarial perturbation robust to noise.
This attack successfully fooled WideResNet [7] and reduced the detection accuracy of the defense [12] from 99.1% to nearly 0%. Since this anti-detection attack makes adversarial images behave like clean images under Gaussian noise, it may also evade the noising transformation-based detectors [34,35].
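To make the noise-response test of [12] concrete, the following is a minimal Python/NumPy sketch of the feature f described above. It is an illustration under stated assumptions, not the implementation of [12]: `logits_fn` is a hypothetical callable returning the logit vector z(·), the expectation is approximated by the mean over the supplied noise samples, and any normalization or aggregation over classes that the original detector may apply is omitted.

```python
import numpy as np


def odds_feature(logits_fn, x, noise_samples, t):
    """Simplified noise-response feature f for one alternative class t.

    logits_fn     -- hypothetical callable mapping an input array to a 1-D logit vector z(.)
    x             -- input image as a float array
    noise_samples -- k noise arrays sigma with the same shape as x
    t             -- alternative class index; must differ from the predicted class y
    """
    z_clean = logits_fn(x)
    y = int(np.argmax(z_clean))                     # predicted class of the CNN
    if t == y:
        raise ValueError("t must differ from the predicted class y")
    base = z_clean[t] - z_clean[y]                  # z(x)_t - z(x)_y
    diffs = []
    for sigma in noise_samples:
        z_noisy = logits_fn(x + sigma)
        diffs.append((z_noisy[t] - z_noisy[y]) - base)
    return float(np.mean(diffs))                    # empirical mean over the k draws


def is_adversarial(f, threshold):
    """Reject the input when f exceeds a threshold T calibrated at a fixed FPR."""
    return f > threshold


# Toy usage with a hypothetical linear "CNN": z(v) = W @ v.
W = np.array([[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
f = odds_feature(lambda v: W @ v, np.array([3.0, 1.0]), [np.array([0.0, 2.0])], t=1)
# f == 2.0 here: the noise moves logit t up relative to the predicted class y = 0.
```

A clean input whose logits barely react to noise gives f close to 0, which is exactly the behavior the anti-detection attack tries to imitate for x'.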
3 The Proposed Method
In this section, we propose a method to detect the anti-detection adversarial attack proposed by Tramèr et al. [22]. This attack can be formulated as the iterative attack in (1):

$$x'^{(0)} = x, \qquad x'^{(i+1)} = \mathrm{clip}_{x,\epsilon}\Big\{x'^{(i)} - \alpha\,\mathrm{sign}\Big(\mathrm{EOT}\big(\nabla_{x}\|z(x'^{(i)}+\sigma) - z(x_l)\|_2^2\big)\Big)\Big\} \qquad (1)$$

where EOT(·) computes the expectation of the gradient of the logit-matching loss $\nabla_{x}\|z(x'^{(i)}+\sigma) - z(x_l)\|_2^2$ at the $i$th iteration, sign(·) is the sign function, α is the step size, and $\mathrm{clip}_{x,\epsilon}(\cdot)$ makes $x'$ satisfy the $L_\infty$ constraint $|x' - x| \le \epsilon$. Finally, the perturbation $\eta \in \{-\epsilon, -\epsilon+1, \ldots, 0, \ldots, \epsilon-1, \epsilon\}$ is added to $x$ to generate the adversarial image $x'$. When adding the perturbation, the attacker treats the R, G, and B channels independently, so it may add different perturbations to neighboring pixels of the R, G, and B channels, which inevitably changes the neighbor relationships among the channels. To measure these changes, we extract the 3rd-order co-occurrence probabilities across the R, G, and B residual channels as features.
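The iteration in (1) can be sketched in Python/NumPy for a toy model. This is a minimal sketch under loud assumptions, not the attack code of [22]: the CNN is replaced by a hypothetical linear logit model z(v) = Wv, so the loss gradient has the closed form 2Wᵀ(W(x'+σ) − z(x_l)); the noise scale `noise_std` is a free parameter; and the result is clipped to [0, 255] and rounded so that η is an integer in [−ε, ε].

```python
import numpy as np


def logit_matching_attack(x, W, z_target, eps, n_steps=100, n_eot=40,
                          noise_std=1.0, seed=0):
    """Toy version of iteration (1) for a hypothetical linear logit model z(v) = W @ v.

    x        -- clean 8-bit image (integer array, any shape)
    W        -- (num_classes, x.size) weight matrix standing in for the CNN
    z_target -- logits z(x_l) of the benign target example x_l
    eps      -- integer L_inf budget
    """
    rng = np.random.default_rng(seed)
    alpha = 2.5 * eps / n_steps                     # step size alpha = 2.5*eps/N
    x = x.astype(np.float64)
    x_adv = x.copy()                                # x'^(0) = x
    for _ in range(n_steps):
        grad = np.zeros(x.size)
        for _ in range(n_eot):                      # EOT: average the gradient over noise draws
            sigma = rng.normal(0.0, noise_std, size=x.shape)
            diff = W @ (x_adv + sigma).ravel() - z_target
            grad += 2.0 * W.T @ diff                # gradient of ||W(x'+sigma) - z(x_l)||^2
        grad /= n_eot
        x_adv = x_adv - alpha * np.sign(grad).reshape(x.shape)
        x_adv = np.clip(x_adv, x - eps, x + eps)    # clip_{x,eps}: stay in the L_inf ball
        x_adv = np.clip(x_adv, 0.0, 255.0)          # stay a valid 8-bit image
    return np.rint(x_adv).astype(np.int64)          # integer perturbation eta in [-eps, eps]
```

Because the sign step is applied to each pixel and channel separately, nothing ties the perturbation of a pixel to that of its neighbor or to the co-located pixel in another channel, which is the property the proposed feature exploits.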
3.1 The Residuals Used for the Proposed Feature
Inspired by blind forensics and steganalysis [26–28], we analyze the statistical properties of adversarial images in the residual domain, which helps suppress image content, thus stressing the added perturbations and highlighting the statistical differences between clean and adversarial images. Let $x = \{R, G, B\} = \{x^{(r)}_{ij}, x^{(g)}_{ij}, x^{(b)}_{ij}\}$, $1 \le i \le m$, $1 \le j \le n$, be an 8-bit RGB image. Considering that the statistical information of an image may be affected by image size and post-processing operations (such as down-sampling), three residuals, $S^{1st}_{ij}$, $S^{EDGE3\times3}_{ij}$, and $S^{EDGE5\times5}_{ij}$, are used to reveal the statistical information as fully as possible. Figure 1 shows the filtering kernel of each residual. The residual of the R, G, or B channel is obtained by convolving the kernel with that channel. We find that $S^{1st}_{ij}$ is preferable for ImageNet images, while $S^{EDGE3\times3}_{ij}$ and $S^{EDGE5\times5}_{ij}$ are preferable for Cifar-10. The reason may be that the 32 × 32 × 3 images in Cifar-10 are generated with a large down-sampling factor, and their statistical information can be enhanced by large filter kernels. To curb the dynamic range and thereby reduce the feature dimension [26], $S^{1st}_{ij}$, $S^{EDGE3\times3}_{ij}$, and $S^{EDGE5\times5}_{ij}$ are quantized with steps 1, 4, and 12, respectively, and the quantized values are then truncated to $[-T, T]$.
Fig. 1. Filtering kernels: left: $S^{1st}_{ij}$; middle: $S^{EDGE3\times3}_{ij}$; right: $S^{EDGE5\times5}_{ij}$.
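A minimal NumPy sketch of residual extraction followed by the quantize-and-truncate step is given below. It assumes $S^{1st}_{ij} = x_{ij+1} - x_{ij}$, as in the definition accompanying (2). The EDGE3×3 and EDGE5×5 kernels of Fig. 1 are not reproduced here, so the hypothetical helper `filter_residual` takes a caller-supplied kernel, and "quantization" is assumed to mean division by the step followed by rounding.

```python
import numpy as np


def first_order_residual(channel):
    """S^1st_ij = x_{i,j+1} - x_ij for one color channel; output shape (m, n-1)."""
    c = channel.astype(np.int64)                 # avoid uint8 wrap-around
    return c[:, 1:] - c[:, :-1]


def filter_residual(channel, kernel):
    """Valid-mode 2-D correlation of one channel with a caller-supplied residual kernel."""
    c = channel.astype(np.float64)
    kh, kw = kernel.shape
    m, n = c.shape
    out = np.zeros((m - kh + 1, n - kw + 1))
    for a in range(kh):
        for b in range(kw):
            out += kernel[a, b] * c[a:a + m - kh + 1, b:b + n - kw + 1]
    return out


def quantize_truncate(residual, q, T=5):
    """Quantize with step q (1, 4, 12 in the text) and truncate to [-T, T]."""
    return np.clip(np.rint(residual / q), -T, T).astype(np.int64)


# Usage: the residual of one channel, e.g. the green one, of an (m, n, 3) image.
img = np.array([[10, 12, 12, 40], [0, 0, 3, 3]], dtype=np.uint8)
s_1st = quantize_truncate(first_order_residual(img), q=1, T=5)
```

With the kernel `[[-1, 1]]`, `filter_residual` reproduces `first_order_residual`; for asymmetric kernels, note that correlation and convolution differ by a kernel flip.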
We find that the anti-detection attack [22] causes shifting effects in the residual domain. Let $S'^{1st}_{ij} = S^{1st}_{ij} + \delta^{1st}_{ij}$, where $S^{1st}_{ij}$/$S'^{1st}_{ij}$ is the 1st-order difference of the clean/adversarial image and $\delta^{1st}_{ij}$ is the variation between $S^{1st}_{ij}$ and $S'^{1st}_{ij}$. Clearly, a non-zero $\delta^{1st}_{ij}$ shifts a zero $S^{1st}_{ij}$ to a non-zero $S'^{1st}_{ij}$. Since natural clean images contain many zero $S^{1st}_{ij}$, many non-zero $\delta^{1st}_{ij}$ cause many shifts (from zero to non-zero residual elements). We empirically find that more than 75% of the elements of $\delta^{1st}_{ij}$ are non-zero when attacking with ε = 5 on ImageNet, as described in Sect. 4. In the following, we explain why most elements of $\delta^{1st}_{ij}$ are non-zero.

Since $S^{1st}_{ij} = x_{ij+1} - x_{ij}$, it follows that $\delta^{1st}_{ij} = \eta_{ij+1} - \eta_{ij}$, where $\eta_{ij}$ and $\eta_{ij+1}$ are the perturbations added to the clean pixels $x_{ij}$ and $x_{ij+1}$. We now explain why $\eta_{ij}$ and $\eta_{ij+1}$ usually differ, which yields many non-zero elements of $\delta^{1st}_{ij}$. From (1), $\eta_{ij}$ is determined by the gradient $\nabla_{x}\|z(x'+\sigma) - z(x_l)\|_2^2$. Consider an n-layer CNN with logit output $z(x) = f_n(f_{n-1}(\ldots f_1(x)))$, where $f_j$ is the function of the $j$th layer, combining convolution, pooling, activation, etc. By the chain rule, $\nabla_{x}\|z(x'+\sigma) - z(x_l)\|_2^2 = 2\{z(x'+\sigma) - z(x_l)\} \odot f'_n \odot f'_{n-1} \odot \cdots \odot f'_1(x')$, where ⊙ is the Hadamard product and $f'_j$ is the derivative of $f_j$. When computing the logit loss $\|z(x'+\sigma) - z(x_l)\|_2^2$, $x_{ij}$ and $x_{ij+1}$ are multiplied by different kernel weights of $f_j$ in forward propagation, so the derivatives $f'_j$ ($1 \le j \le n$) computed in backward propagation differ for $x_{ij}$ and $x_{ij+1}$. Because the gradients for $x_{ij}$ and $x_{ij+1}$ differ, $\eta_{ij}$ and $\eta_{ij+1}$ will probably differ after all iterations, even though the gradient passes through a sign function. Owing to these shifting effects, the histograms in Fig. 2 show that, in the green channel, the distribution of $S'^{1st}_{ij}$ for adversarial images (black dotted line) becomes flat near the zero bin but steep at some non-zero bins. Similar plots are obtained for the red and blue channels. Analogously, shifts also occur between $S^{EDGE3\times3}_{ij}$/$S'^{EDGE3\times3}_{ij}$ and $S^{EDGE5\times5}_{ij}$/$S'^{EDGE5\times5}_{ij}$. These shifts inevitably change the co-occurrence probabilities of adjacent residual elements among the R, G, and B channels.

3.2 The Proposed Feature
Now, we analyze how the anti-detection attack [22] changes the 3rd-order co-occurrence probabilities among the R, G, and B channels in the residual domain. The 3rd-order co-occurrence probability can be viewed as the histogram of the triplet $(S^{(r)}_{ij}, S^{(g)}_{ij}, S^{(b)}_{ij})$. For a clean image, because of the demosaicking effect of the Bayer color filter array used when generating RGB color images, $S^{(r)}_{ij}$, $S^{(g)}_{ij}$, and $S^{(b)}_{ij}$ tend to be nearly equal. We denote such a triplet as the "ALL" type in Table 1. As Table 1 shows, 45.95% and 54.27% of the triplets are of the "ALL" type on Cifar-10 and ImageNet, respectively. However, the attack disturbs this relationship. The shift of a single element from $S_{ij}$ to $S'_{ij}$ in each channel translates into a shift from the triplet $(S^{(r)}_{ij}, S^{(g)}_{ij}, S^{(b)}_{ij})$ of the clean image to the triplet $(S'^{(r)}_{ij}, S'^{(g)}_{ij}, S'^{(b)}_{ij})$ of the adversarial image. Table 1 shows that the attack with budget ε = 1 makes more "ALL" triplets transfer to "PART" and "NONE", while fewer "PART" and "NONE" triplets transfer back to "ALL". For example, the attack on ImageNet changes the average ratios of "ALL", "PART", and "NONE" from 0.5427 to 0.4264, from 0.2182 to 0.3232, and from 0.2391 to 0.2504, respectively. For some "PART" and "NONE" triplets that have relatively small co-occurrence probabilities in clean images, the shifting makes their co-occurrence probabilities in adversarial images increase sharply, enabling these triplets to distinguish adversarial images from clean images, as shown in Fig. 3.

Fig. 2. The normalized distribution (y axis) of $S^{1st}_{ij}$/$S'^{1st}_{ij}$ (x axis) evaluated from 2000 clean (cyan line)/adversarial (black dotted line) images using the green channel on ImageNet. Their averaged values are shown by the red line and the green dotted line, respectively. Due to the shifting effects, $S'^{1st}_{ij}$ has fewer elements than $S^{1st}_{ij}$ at the zero bin. (Color figure online)

Table 1. Average distributions of the three types (ALL, PART, NONE) of triplets for clean images and for adversarial images generated with attack budget ε = 1. ALL: all three elements are equal, e.g., $S^{(r)}_{ij} = S^{(g)}_{ij} = S^{(b)}_{ij}$ for the clean image; PART: exactly two of the three are equal; NONE: no two of the three are equal.

Image type            Cifar-10, 5000 images, $S^{EDGE3\times3}_{ij}$     ImageNet, 2000 images, $S^{1st}_{ij}$
                      ALL      PART     NONE                          ALL      PART     NONE
Clean                 0.4595   0.3884   0.1521                        0.5427   0.2182   0.2391
Adversarial (ε = 1)   0.3709   0.3928   0.2363                        0.4264   0.3232   0.2504
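The ALL/PART/NONE categories of Table 1 can be computed directly from three co-located residual maps. The following NumPy sketch (the function name is hypothetical) returns the three ratios for one image; averaging them over a set of images would give rows of the kind reported in Table 1.

```python
import numpy as np


def triplet_type_ratios(s_r, s_g, s_b):
    """Fractions of co-located (R, G, B) residual triplets of type ALL, PART, NONE."""
    eq_rg = s_r == s_g
    eq_gb = s_g == s_b
    eq_rb = s_r == s_b
    # Number of equal pairs is 3 (ALL), 1 (PART) or 0 (NONE); 2 is impossible by transitivity.
    n_equal_pairs = eq_rg.astype(int) + eq_gb.astype(int) + eq_rb.astype(int)
    total = s_r.size
    return {
        "ALL": np.count_nonzero(n_equal_pairs == 3) / total,
        "PART": np.count_nonzero(n_equal_pairs == 1) / total,
        "NONE": np.count_nonzero(n_equal_pairs == 0) / total,
    }


# Usage: four triplets (0,0,0), (1,1,2), (2,5,2), (3,4,6) -> ALL 0.25, PART 0.5, NONE 0.25.
ratios = triplet_type_ratios(np.array([0, 1, 2, 3]),
                             np.array([0, 1, 5, 4]),
                             np.array([0, 2, 2, 6]))
```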
The 3rd-order co-occurrence probability $C^{1st}_{d_0 d_1 d_2}$ of $S^{1st(r)}_{ij} = x^{(r)}_{ij+1} - x^{(r)}_{ij}$ (and likewise of the G and B channels) is calculated as (2), where [·] is the Iverson bracket, which equals 1 if the conditions are satisfied and 0 otherwise. $C^{EDGE3\times3}_{d_0 d_1 d_2}$ and $C^{EDGE5\times5}_{d_0 d_1 d_2}$ are obtained analogously.

$$C^{1st}_{d_0 d_1 d_2} = \sum_{i,j=1}^{m,\,n-1} \big[S^{1st(r)}_{ij} = d_0,\ S^{1st(g)}_{ij} = d_1,\ S^{1st(b)}_{ij} = d_2\big] \Big/ \{m \times (n-1)\} \qquad (2)$$
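Equation (2) amounts to a 3-D histogram of co-located (R, G, B) residual triplets, normalized by the number of residual positions. A minimal NumPy sketch, assuming the residual maps are already quantized and truncated to [−T, T] (the function name is hypothetical):

```python
import numpy as np


def cooccurrence_rgb(s_r, s_g, s_b, T=5):
    """3rd-order co-occurrence C_{d0 d1 d2} of co-located R, G, B residuals, as in Eq. (2).

    s_r, s_g, s_b -- integer residual maps of equal shape (e.g. m x (n-1)), values in [-T, T]
    Returns a (2T+1)^3 array whose entry [d0+T, d1+T, d2+T] is the probability of (d0, d1, d2).
    """
    N = 2 * T + 1
    C = np.zeros((N, N, N))
    # Iverson-bracket counting: add 1 at each observed triplet (repeated indices accumulate).
    np.add.at(C, (s_r.ravel() + T, s_g.ravel() + T, s_b.ravel() + T), 1.0)
    return C / s_r.size                          # divide by m x (n-1) positions
```

`np.add.at` is used instead of fancy-index assignment so that repeated triplets are counted rather than overwritten.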
$$F = [h(C^{1st}),\ h(C^{EDGE3\times3}),\ h(C^{EDGE5\times5})] \qquad (3)$$

The co-occurrence matrices $C^{1st}$, $C^{EDGE3\times3}$, and $C^{EDGE5\times5}$ each have $(2T+1)^3$ elements. Based on the symmetry properties of co-occurrences [26], we use a dimension-reduction function h(·) and concatenate the three reduced co-occurrences to form the final feature F as in (3), which has $\sum_{k=0}^{3} 3(2T+1)^k/4$ elements. We set T = 5 because the attacks cause statistical changes mainly within [−5, 5], as shown in Fig. 2. With T = 5, F has 3 × (1 + 11 + 121 + 1331)/4 = 1098 dimensions in total. A larger T does not significantly increase detection accuracy but dramatically increases the feature dimension. We use the ensemble classifier with default settings [26] to construct the detector that distinguishes adversarial images from clean images. To visually demonstrate the discriminative ability of the proposed feature, we select the feature elements with the top three Fisher Criterion Scores (FCS) [33] to draw scatter plots; a higher FCS indicates stronger discriminative ability. Figure 3 shows that the top-3 FCS feature elements separate most clean images from adversarial images. Owing to the shifting effects, all top-3 FCS feature elements in Fig. 3 originate from "PART" and "NONE" triplets. We also find that the top-1 elements for Cifar-10 and ImageNet come from $C^{EDGE3\times3}$ and $C^{1st}$, respectively, verifying that the proposed residuals help in detecting images from different sources.
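The text does not spell out h(·). One symmetrization that is consistent with the stated size is to merge each triplet (d0, d1, d2) with its sign-flipped version (−d0, −d1, −d2) and its reversed (R↔B) version (d2, d1, d0). By Burnside's lemma this leaves ((2T+1)³ + (2T+1)² + (2T+1) + 1)/4 classes, i.e. 366 per co-occurrence and 1098 in total for T = 5. The sketch below implements this assumed h(·); whether it matches the authors' exact choice is not confirmed by the text, and all function names are hypothetical.

```python
import itertools

import numpy as np


def symmetry_classes(T=5):
    """Map every triplet in [-T, T]^3 to a class index (assumed sign-flip + reversal symmetry)."""
    index, reps = {}, {}
    for d in itertools.product(range(-T, T + 1), repeat=3):
        d0, d1, d2 = d
        orbit = [(d0, d1, d2), (-d0, -d1, -d2), (d2, d1, d0), (-d2, -d1, -d0)]
        rep = min(orbit)                         # canonical representative of the orbit
        if rep not in reps:
            reps[rep] = len(reps)
        index[d] = reps[rep]
    return index, len(reps)


def h(C, T=5):
    """Sum the co-occurrence probabilities within each symmetry class."""
    index, n_classes = symmetry_classes(T)
    out = np.zeros(n_classes)
    for (d0, d1, d2), k in index.items():
        out[k] += C[d0 + T, d1 + T, d2 + T]
    return out


def final_feature(C_1st, C_edge3, C_edge5, T=5):
    """F = [h(C^1st), h(C^EDGE3x3), h(C^EDGE5x5)], as in Eq. (3)."""
    return np.concatenate([h(C_1st, T), h(C_edge3, T), h(C_edge5, T)])
```

Summing within classes keeps each reduced vector a probability distribution, which is the usual way symmetric co-occurrence bins are merged in rich-model features.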
Fig. 3. Scatter plots of the top-3 FCS elements of the proposed feature for 5000/2000 clean (red) images and adversarial (green, ε = 5) images on Cifar-10 (left)/ImageNet (right). Please refer to Sect. 4 for details of the databases. (Color figure online)
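A common form of the Fisher criterion for a single feature is $(\mu_c-\mu_a)^2/(\sigma_c^2+\sigma_a^2)$, computed over the clean and adversarial classes. The exact formula used via [33] is not quoted in the text, so the NumPy sketch below assumes this form when ranking feature elements and picking the top-k elements used for scatter plots such as Fig. 3; the function names are hypothetical.

```python
import numpy as np


def fisher_criterion_scores(F_clean, F_adv, eps=1e-12):
    """Per-feature Fisher criterion (mu_c - mu_a)^2 / (var_c + var_a).

    F_clean, F_adv -- (num_images, num_features) feature matrices of the two classes
    """
    mu_c, mu_a = F_clean.mean(axis=0), F_adv.mean(axis=0)
    var_c, var_a = F_clean.var(axis=0), F_adv.var(axis=0)
    return (mu_c - mu_a) ** 2 / (var_c + var_a + eps)   # eps guards constant features


def top_k_features(F_clean, F_adv, k=3):
    """Indices of the k feature elements with the highest scores, best first."""
    scores = fisher_criterion_scores(F_clean, F_adv)
    return np.argsort(scores)[::-1][:k]
```

The same ranking can be reused to drop the top-100 elements, as in the robustness experiment of Sect. 4.3.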
4 Experimental Results
4.1 Experiment Settings
For the Cifar-10 test set and the validation set of ImageNet-1000 (ILSVRC2012), a WideResNet [7] and a pretrained ResNet-101² are attacked, respectively. The default settings of the anti-detection attack³ [22], apart from the attack budget ε, are as follows: step number N = 100, step size α = 2.5ε/N, and 40 samples of N(0, 1) Gaussian noise used in the EOT at each step. As 100 × 40 = 4000 gradient computations are needed for each batch, 5000 and 2000 images are randomly selected from Cifar-10 and ImageNet, respectively, to save time. We find that an attack with a large budget ε ≥ 6 does not increase the payoff significantly (the attack with ε = 6 obtained nearly 100% ASR, reducing the detection rate of the defense [12] to 0) but reduces image quality, so we set ε = 1, 2, 3, 4, 5, 6. The following state-of-the-art methods are compared: ESRM [24], SRNet [30], detection filter (DF) [14], feature squeezing (FS) [17], and WISERNet [29], where DF and FS are well-known transformation-based methods and the other three are steganalysis-based methods. The threshold of FS is fixed at FPR = 5%. Because SRNet is designed for gray images, we only apply it to the green channel. In all experiments, half of the adversarial and clean samples are used for training and the remaining half for testing. Under these equal priors, we report the detection accuracy (Acc), i.e., #(correctly predicted samples)/#(all test samples), on the test set.

Table 2. Detection accuracies (Acc, %) of clean images and their anti-detection adversarial images on Cifar-10. Attacked model: WideResNet [7]. Training set: 2500 clean images + 2500 adversarial images; testing set: 2500 clean images + 2500 adversarial images. Our method gives the best result in every row.

L∞ budget ε   ESRM [24]   DF [14]   FS [17]   WISERNet [29]   SRNet [30]   Our method
1             87.96       56.34     55.94     57.2            51.3         96.98
2             94.08       66.00     64.88     56.38           50.78        98.78
3             96.96       71.98     72.28     61.96           51.12        99.38
4             98.10       73.46     76.26     70.2            58.24        99.68
5             98.86       69.56     77.12     77.54           77.92        99.84
6             99.24       62.54     75.82     83.74           76.84        99.86
4.2 Comparative Results
Tables 2 and 3 show that the proposed method achieves Acc ≥ 98.78% in all tests except the Cifar-10 test with attack budget ε = 1 (Acc = 96.98%).

² https://hub.tensorflow.google.cn/google/supcon/resnet_v1_101/imagenet/classification/1.
³ https://github.com/wielandbrendel/adaptive_attacks_paper/tree/master/02_odds.
These results verify that the proposed method can effectively detect adversarial images generated by the anti-detection attack [22]. Moreover, the proposed feature, extracted from multiple residuals of the R, G, and B channels, is stable across attack budgets and image sources and outperforms all compared methods, especially for attacks with a low budget. For example, on the Cifar-10 test with ε = 1, the proposed method achieves about 9% higher Acc than the second-best method, ESRM (96.98% vs. 87.96%). Notice that the ESRM feature has 34671 dimensions, whereas the proposed feature has only 1098. The transformation-based methods DF and FS perform well on ImageNet but deteriorate on Cifar-10. SRNet and WISERNet are designed for steganalysis rather than for detecting adversarial images, and thus generally perform the worst in these tests.

Table 3. Detection accuracies (Acc, %) of clean images and their anti-detection adversarial images on ImageNet. Dataset: ImageNet-1000; attacked model: ResNet-101. Training set: 1000 clean images + 1000 adversarial images; testing set: 1000 clean images + 1000 adversarial images. Our method gives the best result in every row.

L∞ budget ε   ESRM [24]   DF [14]   FS [17]   WISERNet [29]   SRNet [30]   Our method
1             81.60       79.70     83.05     57.55           58.05        99.75
2             83.90       89.55     90.60     64.45           67.55        99.85
3             86.05       85.00     83.90     69.95           74.90        99.70
4             86.75       90.55     91.25     72.9            80.20        99.40
5             89.20       90.55     91.05     77.5            83.95        99.80
6             90.50       90.65     90.80     81.4            85.95        99.85
4.3 Discussions of Generalization Ability and Security
We use unseen clean and adversarial images to test the generalization ability. Specifically, the proposed detector is trained on the first 2500 clean images of the Cifar-10 test set and 900 (ε = 1) + 900 (ε = 3) + 700 (ε = 5) adversarial images generated by attacking WideResNet [7] under the default settings described in Sect. 4.1, but it is tested on the 5001st–7500th legitimate images and their adversarial images generated by attacking VGG⁴ with different settings: ε = 2, 4, 6, N = 150, α = 1.5ε/N (ε = 4, 6) or α = 3.5ε/N (ε = 2), and sampled noise. Because of the larger N and ε, the test adversarial images probably suffer stronger attacks than the training adversarial images. In this generalization test, our method achieves Acc = 97.94%, only about 1.5% below the average Acc of 99.44% for ε = 2, 4, 6 in the baseline test of Table 2. To demonstrate the effectiveness of the proposed feature with a different machine-learning tool, we also train an RBF-SVM detector [31] with hyperparameters c = 128 and g = 0.5 and obtain Acc = 98.40%, nearly the same as with the ensemble classifier. These results show that the proposed detector is not strongly tied to the CNN model or the attack parameters, indicating good generalization ability and potential usefulness in applications.

⁴ https://hub.tensorflow.google.cn/deepmind/ganeval-cifar10-convnet/1.
We believe the proposed detector can resist some unknown adaptive attacks to a certain extent, or at least set obstacles for attackers designing new adaptive attacks. First, it is difficult to make the proposed composite feature fail. To verify this, we discard the feature elements with the top-100 FCS (note that these elements have high discriminative ability), assuming that they are completely breached, and repeat the above generalization test using only the remaining 998-D feature. Surprisingly, we achieve Acc = 96.80%, only 1.14% lower. Therefore, it is probably infeasible to disable our method by attacking some feature elements separately. Second, even if the attacker can design anti-features to attack the proposed detector, converting transformations in the feature domain into perturbations in the pixel domain is difficult, as indicated in the anti-forensics literature [32]. We have found two related attacks, proposed by Chen et al. [37] and Bryniarski et al. [39]. However, Chen et al. [37] only attacked SVM-based forensic techniques using SPAM [38] features, without considering fooling the CNN classifier. Bryniarski et al. [39] evaded a simple steganalysis-based detector that applies a 3-layer fully connected neural network to 1st-order SPAM features. Modifying these attacks to fool the CNN classifier and simultaneously evade our detector, which applies the ensemble classifier to 3rd-order co-occurrences, is difficult. Third, our detector is randomized, because the feature subspaces and bootstrap samples used in the ensemble classifier are controlled by random seeds. This randomness further increases the difficulty of designing a new adaptive attack.
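The randomness mentioned above comes from how an ensemble classifier of this kind is built. The sketch below is a simplified, assumed variant: Fisher linear discriminant base learners trained on random feature subspaces and bootstrap samples, combined by majority vote. It is not the exact tool used with [26]; `FLDEnsemble` and its parameters are hypothetical.

```python
import numpy as np


class FLDEnsemble:
    """Simplified random-subspace ensemble of Fisher linear discriminants (assumed design).

    Each base learner sees a random feature subspace and a bootstrap sample, both drawn
    from a seeded generator, so a different seed yields a different detector.
    """

    def __init__(self, n_learners=51, d_sub=100, reg=1e-6, seed=0):
        self.n_learners, self.d_sub, self.reg = n_learners, d_sub, reg
        self.rng = np.random.default_rng(seed)
        self.learners = []

    def fit(self, X, y):
        """X: (samples, features); y: 0 = clean, 1 = adversarial."""
        n, d = X.shape
        self.learners = []
        for _ in range(self.n_learners):
            feats = self.rng.choice(d, size=min(self.d_sub, d), replace=False)
            boot = self.rng.integers(0, n, size=n)            # bootstrap sample
            Xs, ys = X[boot][:, feats], y[boot]
            mu0, mu1 = Xs[ys == 0].mean(axis=0), Xs[ys == 1].mean(axis=0)
            Sw = np.cov(Xs[ys == 0], rowvar=False) + np.cov(Xs[ys == 1], rowvar=False)
            Sw = np.atleast_2d(Sw) + self.reg * np.eye(len(feats))
            w = np.linalg.solve(Sw, mu1 - mu0)                 # FLD projection direction
            thr = w @ (mu0 + mu1) / 2.0                        # midpoint threshold
            self.learners.append((feats, w, thr))
        return self

    def predict(self, X):
        """Majority vote of the base learners (1 = adversarial)."""
        votes = np.zeros(X.shape[0])
        for feats, w, thr in self.learners:
            votes += (X[:, feats] @ w > thr)
        return (votes > self.n_learners / 2).astype(int)
```

An odd number of learners avoids ties in the vote; refitting with another `seed` changes which subspaces and samples each learner sees.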
5 Conclusion
In this work, we propose an effective method to detect anti-detection adversarial images. Experimental results show that the proposed method performs nearly perfectly on Cifar-10 and ImageNet. Besides, the proposed method can, to some extent, generalize to unknown image sources. Our work indicates that the anti-detection adversarial attack, an evolved version of the earlier defense-unaware attack, can also be defeated. The confrontation between attackers and defenders will continue. We will extend the proposed method to detect other types of anti-detection adversarial images and enhance its security against unknown or strong adaptive attacks.
References
1. Zhang, J., Lou, Y., Wang, J., Wu, K., Lu, K., Jia, X.: Evaluating adversarial attacks on driving safety in vision-based autonomous vehicles. IEEE Internet Things J. 9(5), 3443–3456 (2021)
2. Zhang, K., Zhang, Z., Li, Z., Qiao, Y.: Joint face detection and alignment using multitask cascaded convolutional networks. IEEE Signal Process. Lett. 23(10), 1499–1503 (2016)
3. Szegedy, C., et al.: Intriguing properties of neural networks. arXiv:1312.6199 (2013)
4. Dong, Y., et al.: Benchmarking adversarial robustness on image classification. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 321–331 (2020)
5. Goodfellow, I.J., Shlens, J., Szegedy, C.: Explaining and harnessing adversarial examples. arXiv:1412.6572 (2014)
6. Kurakin, A., Goodfellow, I.J., Bengio, S.: Adversarial examples in the physical world. In: Artificial Intelligence Safety and Security, pp. 99–112. Chapman and Hall/CRC (2018)
7. Madry, A., Makelov, A., Schmidt, L., Tsipras, D., Vladu, A.: Towards deep learning models resistant to adversarial attacks. arXiv:1706.06083 (2017)
8. Zhang, H., Avrithis, Y., Furon, T., Amsaleg, L.: Walking on the edge: fast, low-distortion adversarial examples. IEEE Trans. Inf. Forensics Secur. 16, 701–713 (2020)
9. Wan, C., Ye, B., Huang, F.: PID-based approach to adversarial attacks. In: Proceedings of the AAAI Conference on Artificial Intelligence, pp. 10033–10040 (2021)
10. Carlini, N., Wagner, D.: Towards evaluating the robustness of neural networks. In: 2017 IEEE Symposium on Security and Privacy (SP), pp. 39–57 (2017)
11. Moosavi-Dezfooli, S.M., Fawzi, A., Frossard, P.: DeepFool: a simple and accurate method to fool deep neural networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2574–2582 (2016)
12. Roth, K., Kilcher, Y., Hofmann, T.: The odds are odd: a statistical test for detecting adversarial examples. In: International Conference on Machine Learning, pp. 5498–5507 (2019)
13. Wu, Y., Arora, S.S., Wu, Y., Yang, H.: Beating attackers at their own games: adversarial example detection using adversarial gradient directions. In: Proceedings of the AAAI Conference on Artificial Intelligence, pp. 2969–2977 (2021)
14. Liang, B., Li, H., Su, M., Li, X., Shi, W., Wang, X.: Detecting adversarial image examples in deep neural networks with adaptive noise reduction. IEEE Trans. Dependable Sec. Comput. 18(1), 72–85 (2018)
15. Guo, C., Rana, M., Cisse, M., Van Der Maaten, L.: Countering adversarial images using input transformations. arXiv:1711.00117 (2017)
16. Yin, Z., Wang, H., Wang, J., Tang, J., Wang, W.: Defense against adversarial attacks by low-level image transformations. Int. J. Intell. Syst. 35(10), 1453–1466 (2020)
17. Xu, W., Evans, D., Qi, Y.: Feature squeezing: detecting adversarial examples in deep neural networks. arXiv:1704.01155 (2017)
18. Li, X., Li, F.: Adversarial examples detection in deep networks with convolutional filter statistics. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 5764–5772 (2017)
19. Feinman, R., Curtin, R.R., Shintre, S., Gardner, A.B.: Detecting adversarial samples from artifacts. arXiv:1703.00410 (2017)
20. Meng, D., Chen, H.: Magnet: a two-pronged defense against adversarial examples. In: Proceedings of the 2017 ACM SIGSAC Conference on Computer and Communications Security, pp. 135–147 (2017)
21. Carlini, N., Wagner, D.: Adversarial examples are not easily detected: bypassing ten detection methods. In: Proceedings of the 10th ACM Workshop on Artificial Intelligence and Security, pp. 3–14 (2017)
22. Tramer, F., Carlini, N., Brendel, W., Madry, A.: On adaptive attacks to adversarial example defenses. Adv. Neural. Inf. Process. Syst. 33, 1633–1645 (2020)
23. Fan, W., Sun, G., Su, Y., Liu, Z., Lu, X.: Integration of statistical detector and Gaussian noise injection detector for adversarial example detection in deep neural networks. Multimed. Tools Appl. 78(14), 20409–20429 (2019). https://doi.org/10.1007/s11042-019-7353-6
24. Liu, J., et al.: Detection based defense against adversarial examples from the steganalysis point of view. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4825–4834 (2019)
25. Athalye, A., Engstrom, L., Ilyas, A., Kwok, K.: Synthesizing robust adversarial examples. In: International Conference on Machine Learning, pp. 284–293 (2018)
26. Fridrich, J., Kodovsky, J.: Rich models for steganalysis of digital images. IEEE Trans. Inf. Forensics Secur. 7(3), 868–882 (2012)
27. Chen, J., Kang, X., Liu, Y., Wang, Z.J.: Median filtering forensics based on convolutional neural networks. IEEE Signal Process. Lett. 22(11), 1849–1853 (2015)
28. Goljan, M., Fridrich, J., Cogranne, R.: Rich model for steganalysis of color images. In: 2014 IEEE International Workshop on Information Forensics and Security (WIFS), pp. 185–190 (2014)
29. Goljan, M., Fridrich, J., Cogranne, R.: Rich model for steganalysis of color images. In: 2014 IEEE International Workshop on Information Forensics and Security (WIFS), pp. 185–190 (2014)
30. Boroumand, M., Chen, M., Fridrich, J.: Deep residual network for steganalysis of digital images. IEEE Trans. Inf. Forensics Secur. 14(5), 1181–1193 (2018)
31. Chang, C.C., Lin, C.J.: LIBSVM: a library for support vector machines. ACM Trans. Intell. Syst. Technol. (TIST) 2(3), 1–27 (2011)
32. Stamm, M.C., Wu, M., Liu, K.R.: Information forensics: an overview of the first decade. IEEE Access 1, 167–200 (2013)
33. Bas, P., Filler, T., Pevný, T.: "Break our steganographic system": the ins and outs of organizing BOSS. In: International Workshop on Information Hiding, pp. 59–70 (2011)
34. Hu, S., Yu, T., Guo, C., Chao, W.-L., Weinberger, K.Q.: A new defense against adversarial images: turning a weakness into a strength. In: Advances in Neural Information Processing Systems, pp. 1633–1644 (2019)
35. Hosseini, H., Kannan, S., Poovendran, R.: Are odds really odd? Bypassing statistical detection of adversarial examples. arXiv:1907.12138 (2019)
36. Zhang, H., et al.: Theoretically principled trade-off between robustness and accuracy. In: International Conference on Machine Learning, pp. 7472–7482 (2019)
37. Chen, Z., Tondi, B., Li, X., Ni, R., Zhao, Y., Barni, M.: A gradient-based pixel-domain attack against SVM detection of global image manipulations. In: 2017 IEEE Workshop on Information Forensics and Security (WIFS), pp. 1–6 (2017)
38. Pevny, T., Bas, P., Fridrich, J.: Steganalysis by subtractive pixel adjacency matrix. IEEE Trans. Inf. Forensics Secur. 5(2), 215–224 (2010)
39. Bryniarski, O., Hingun, N., Pachuca, P., et al.: Evading adversarial example detection defenses with orthogonal projected gradient descent. arXiv:2106.15023 (2021)
Evolving Temporal Knowledge Graphs by Iterative Spatio-Temporal Walks
Hao Tang, Donghong Liu, Xinhai Xu(✉), and Feng Zhang(✉)
Academy of Military Science, Beijing, China
Abstract. Predicting facts that will occur in the future is a challenging task for temporal knowledge graphs (TKGs). TKGs represent temporal facts about entities and their relations, where each fact is associated with a timestamp. Inspired by the human inference process, in which predictions are usually made by analyzing relevant historical clues, we propose a model based on temporal evolution and a temporal graph attention mechanism to infer future facts. Specifically, we construct a node pool to keep the importance of all nodes encountered in the historical search. We learn temporal evolution features and sub-graph structures based on temporal random walks and graph attention networks. Here, the sub-graphs are sets of objects with the same subject and relation as the query. Experiments on five temporal datasets demonstrate the effectiveness of the model compared with state-of-the-art methods. Code is available at https://github.com/lendie/SWGAT.
Keywords: Temporal knowledge graph · Spatio-Temporal walks · Future facts
1 Introduction
Knowledge graphs (KGs) are directed graphs that excel at organizing relational facts. They represent entities as nodes and semantic relations as edges. Each fact is represented as a triple (subject, relation, object), e.g., (Donald Trump, PresidentOf, USA). Large-scale knowledge graphs have been used in various artificial intelligence applications, including recommender systems [11] and information retrieval [12]. Knowledge graph reasoning [1] infers missing facts from existing ones, but such methods treat a knowledge graph as static, meaning that its entities and relations do not change over time. In reality, however, many facts hold only at a certain point or period in time. For example, (Donald Trump, PresidentOf, USA) is valid only from January 2017 to January 2021. To this end, temporal knowledge graphs (TKGs) were introduced. A TKG represents a temporal fact as a quadruple (subject, relation, object, timestamp), meaning that the fact is valid at that timestamp. Research on temporal knowledge graph reasoning has recently received increasing attention; it can be categorized into interpolation and extrapolation tasks [27]. The former reasons about facts within a known time range, while the latter predicts future facts and is more challenging. In this paper, we focus on the extrapolation task.

H. Tang and D. Liu contributed equally.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 501–512, 2023. https://doi.org/10.1007/978-981-99-1639-9_42

To solve extrapolation problems, various TKG embedding approaches have been developed. These approaches map entities, relations and time information into a continuous vector space to capture the semantics of a temporal knowledge graph. However, current TKG embedding approaches face two main challenges: first, how to model the evolutionary patterns of historical facts in order to predict future facts accurately; second, how to model the topological structure of a TKG. Recent top-tier works address one of these challenges but not both. Inspired by the process of human reasoning, we propose to tackle extrapolation problems with iterative spatio-temporal random walks followed by a temporal relation graph attention layer (TRGAT). In the spatio-temporal random walk module, we select the one-hop neighbors of the subject that are closest in time to the query and compute their importance scores from the relations between these neighbors and the subject. After assigning an importance score to each one-hop neighbor, we iteratively walk from the neighbors with the top-n importance scores and select the two-hop neighbors of the subject. TRGAT then captures the topological structure by selecting sub-graphs that are sets of objects sharing the subject and relation of the query. Similar to CyGNet [28], this layer mainly captures repetitive facts related to the query, except that we use a graph attention network to capture such repetitive patterns. The key contributions of this paper can be summarized as follows:
– Our reasoning approach for temporal knowledge graphs is derived from the human cognitive process and consists of iterative spatio-temporal walks and a temporal graph attention mechanism.
– We resort to graph attention networks to capture repetitive patterns.
– Our model achieves state-of-the-art performance on five temporal datasets.
2 Related Work
2.1 Static KG Representation Learning
There is growing interest in knowledge graph embedding methods, which can be broadly classified into three paradigms: (1) translational distance-based models [1,25], (2) tensor factorization-based models [14,15] and (3) neural network-based models [4,13]. Translation-based models treat a relation as a translation between entity embeddings, e.g., TransE [1] and TransH [25]. Factorization-based models treat a KG as a third-order tensor and compute triple scores through tensor decomposition, e.g., RESCAL [15] and HolE [14]. Other models use convolutional or graph neural networks to model the scoring function, e.g., ConvE [4] and KBGAT [13]. However, none of these methods can predict future events, because they treat all entities and relations as static.
2.2 Temporal KG Representation Learning
To better capture dynamic changes in information, temporal knowledge graph embedding (TKGE) models encode temporal information into entities or relations. A number of recent works model the changing facts in TKGs. Temporal knowledge graph reasoning can be divided into interpolation [6] and extrapolation [28] problems. The former reasons about facts at known times, while the latter, which is the focus of this paper, predicts future facts. For interpolation, DE-SimplE [6] defines a function that takes entities and relations as input and produces time-specific embeddings, but this approach ignores the dynamic changes of entities and relations. For extrapolation, some models estimate the conditional probability of future facts via a temporal point process that takes all historical events into account. RE-Net [9] captures the evolutionary patterns of fixed-length subgraphs specific to a query. CyGNet [28] models repeated facts with sequential copy-generation networks. xERTE [8] learns to find query-related subgraphs with a fixed number of hops. Glean [3] enriches entity information by introducing time-dependent descriptions. EvoKG [17] performs temporal graph inference by jointly modeling event time and network structure.
3 Our Model
We describe our model top-down: Sect. 3.1 gives an overview, and Sects. 3.2 and 3.3 explain each module separately.
3.1 Model Overview
Our model performs inference using walk sequences obtained by dynamic sampling on the temporal knowledge graph, together with a temporal relation graph attention layer (TRGAT). It therefore contains two parts: spatio-temporal walks and TRGAT (Fig. 1). The first part samples dynamic node sequences whose semantic and temporal information is closely related to the given query. The sampled node sequences are then fed to the temporal inference cell, which models the node sequence information and assigns an importance score to each node related to the query. In the TRGAT module, we sample from the historical information the objects that share the subject and relation of the query. We note that, under the same relation, different objects and subjects should also have different importance.
Fig. 1. Overview of model architecture. [Figure: for a query (e_s, r_q, ?, t_s), a spatio-temporal walk module (iterative walk, position encoding, GRU and MLP, updates to a node pool) and a TRGAT module, both operating on the TKG over timestamps t_s-n, ..., t_s-1.]

Fig. 2. Three 2-step walks (x is the default node id, which we set to -1 when no historical links can be found). [Figure: iterations i and i+1 backtracking over time; candidate nodes, Top-n pruning of candidates, continued walks and new nodes.]
3.2 Spatio-temporal Walking
We assume that the older an event is, the less it affects inference. Instead of using all historical events, we therefore choose the set V of nodes that are closest in time to the query, where V is the subset of entities in historical events that are directly linked to the query subject entity. We sample connected links by backtracking over time to extract the potential time-evolving relations of the temporal knowledge graph. Following our hypothesis that more recent events carry more information, we use time-aware weighted sampling

$P(q_v = (s_i, r_k, o_j, t')) \propto \exp(t' - t)$,

where $t$ is the previously sampled timestamp and $t' \le t$ is the timestamp of a candidate link.

A toy example is shown in Fig. 2. Given a query $(e_s, r_q, ?, t_s)$, let A denote node $e_s$. We first find the most recent timestamp at which node A appears in the historical events, e.g., $t_s - 1$. Because we backtrack over time, the next step searches for nodes that are directly linked to A in facts with timestamps less than or equal to $t_s - 1$. As shown in Fig. 2, this yields three walks, {(A → B), (A → C), (A → D)}, where relations and timestamps are omitted for simplicity. These walks are then fed to the temporal inference cell to compute the relevance of nodes B, C and D to the query. The walk then continues from these nodes, which we call an iterative walk. To reduce path-selection time, we apply Top-n pruning: the walk continues only from the n neighboring nodes with the highest relevance scores.

Sampled Walks. We define the sampled edges $S_{v,t} = \{(e, t') \mid e \in \mathcal{G},\ t' \le t,\ v \in e\}$ as the links involving node $v$ up to time $t$. A walk on the temporal knowledge graph is then expressed as

$E = ((w_0, t_0), (w_1, t_1), \ldots, (w_m, t_m)), \quad t_0 \ge t_1 \ge \cdots \ge t_m, \quad (w_i, t_i) \in S_{v,t}$,   (1)
where $(w_i, t_i)$ denotes the quadruple $(e_i, r_i, o_i, t_i)$.

Position Encoder. Following CAW [23], we make the model inductive by removing node identifiers and encoding relative position information instead. Let $S_e$ be the set of walks sampled from node $e_s$. Each node appearing in $S_e$ is replaced by a vector that encodes how often the node occurs at each position of the walks in $S_e$. For a node $w$, the vector of position frequencies relative to all walks in $S_e$ is defined as

$PE(w, S_e) = \big\{\, |\{ W \mid W \in S_e,\ w = W_i \}| \,\big\}_{i \in \{0, 1, \ldots, m\}}$,   (2)

that is, the $i$-th component of this vector equals the number of walks in $S_e$ whose $i$-th node is $w$. Finally, we encode the relative positions of the nodes in each walk as

$\tilde{E} = (PE(w_0), PE(w_1), \ldots, PE(w_m))$.   (3)
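As a minimal sketch of Eqs. (2) and (3), the code below assumes that each sampled walk is given as a list of node identifiers of equal length starting at $e_s$, with relations and timestamps dropped; the names `position_frequencies` and `anonymize` are hypothetical. It follows the anonymization idea of CAW [23] but is not the authors' implementation.

```python
from typing import Dict, List, Sequence, Tuple


def position_frequencies(walks: Sequence[Sequence[str]]) -> Dict[str, Tuple[int, ...]]:
    """PE(w, S_e): entry i counts the walks in S_e whose i-th node is w (Eq. 2)."""
    length = max(len(walk) for walk in walks)
    counts: Dict[str, List[int]] = {}
    for walk in walks:
        for i, node in enumerate(walk):
            # Create a zero vector the first time a node is seen, then count position i.
            counts.setdefault(node, [0] * length)[i] += 1
    return {node: tuple(vec) for node, vec in counts.items()}


def anonymize(walk: Sequence[str], pe: Dict[str, Tuple[int, ...]]) -> List[Tuple[int, ...]]:
    """Replace the node identifiers of one walk by their position-frequency vectors (Eq. 3)."""
    return [pe[node] for node in walk]


walks = [["A", "B", "E"], ["A", "C", "F"], ["A", "B", "A"]]
pe = position_frequencies(walks)
print(pe["A"], pe["B"])                 # (3, 0, 1) (0, 2, 0)
print(anonymize(["A", "B", "A"], pe))   # [(3, 0, 1), (0, 2, 0), (3, 0, 1)]
```

Because only the counts survive, an entity that never appeared in training still receives a representation from where it occurs in the sampled walks, which is what makes this encoding inductive.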
The representation of each position of each walk, $\tilde{E}$, is passed through a multi-layer perceptron (MLP) to obtain the corresponding position embedding:

$f_0(\tilde{E}) = \mathrm{MLP}(\tilde{E})$.   (4)

Iterative Update. Just as humans update their existing knowledge when they make new observations [24], our model maintains an existing knowledge base: the scores of the nodes discovered so far, which we call the node pool. When new nodes are reached, the spatio-temporal random walk module updates the importance scores in the node pool for both known and new nodes. As shown in Fig. 2, the query node is first located in the historical information, and its nearest spatio-temporal neighbors, called one-hop spatio-temporal neighbors, are selected. The resulting node sequences are fed into a GRU to compute node importance. The spatio-temporal walk is then performed again, starting from the one-hop neighbors. As the walk continues, the model's knowledge of the query subject node is continually updated, and the final predictions are made from the node pool. We encode a walk $E$ as

$\mathrm{Encode}(E) = \mathrm{GRU}\big(\{F_1(h_i, f_1(t_i)) \oplus h_r \oplus f_0(\tilde{E})\}_{i = 0, 1, \ldots, m}\big)$,   (5)
where $F_1$ is the time-aware encoding function, $h_i$ and $h_r$ are the hidden representations of the node and the relation, respectively, $f_1$ is the time embedding function, $f_0(\tilde{E})$ is the relative position embedding, and $\oplus$ denotes concatenation. $F_1$ is computed as

$F_1(h_i, f_1(t_i)) = W_1 (h_i \oplus f_1(t_i))$,   (6)

where $W_1$ is a time-aware trainable parameter matrix. For $f_1$, we adopt the Bochner time embedding [23]:

$f_1(t) = [\cos(\omega_1 t), \sin(\omega_1 t), \ldots, \cos(\omega_d t), \sin(\omega_d t)]$,   (7)
where the $\omega_i$ are learnable parameters. To obtain the relevance of the discovered nodes to the query, we combine the node and relation information of the query and update the scores of the seen entities in the node pool with the query-related attention score

$\mathrm{Att}(E, q) = f\big(W_\lambda (h_s \oplus \mathrm{Encode}(E) \oplus h_r)\big)$,   (8)

where $\mathrm{Att}$ is the attention score of the seen nodes with respect to the query $q = (e_s, r_q, ?, t_s)$, $W_\lambda$ is the weight matrix that aggregates features from the evolving node sequences and the query, $h_s$ and $h_r$ denote the embeddings of the query entity $e_s$ and the query relation $r_q$, respectively, and $f(\cdot)$ is an activation function.
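The walk procedure of this subsection (time-aware sampling, backtracking over time, Top-n pruning and node-pool updates) can be sketched as follows, assuming integer timestamps and treating each link as traversable in both directions. `score_fn` is a hypothetical stand-in for the GRU encoder and the attention of Eqs. (5)–(8), and keeping the maximum score per node in the node pool is our assumption, since the exact update rule is not spelled out above. All names are illustrative.

```python
import math
import random
from collections import defaultdict
from typing import Callable, Dict, List, Tuple

Quad = Tuple[str, str, str, int]  # (subject, relation, object, timestamp)
Walk = List[Tuple[str, int]]      # [(node, timestamp), ...] with non-increasing timestamps


def time_aware_weights(times: List[int], t_prev: int) -> List[float]:
    """Sampling probabilities proportional to exp(t' - t_prev) for candidate times t' <= t_prev.

    Shifting by the largest candidate time leaves the normalized values unchanged
    but avoids underflow when all candidate links are old.
    """
    if any(t > t_prev for t in times):
        raise ValueError("backtracking only allows candidate times t' <= t_prev")
    t_max = max(times)
    raw = [math.exp(t - t_max) for t in times]
    total = sum(raw)
    return [w / total for w in raw]


def iterative_walk(quads: List[Quad], start: str, t_query: int, steps: int,
                   num_samples: int, top_n: int,
                   score_fn: Callable[[Walk], float], seed: int = 0) -> Dict[str, float]:
    """Backtracking walks with time-aware sampling and Top-n pruning.

    Returns a node pool {node: best score seen so far}.
    """
    rng = random.Random(seed)
    adj: Dict[str, List[Tuple[str, int]]] = defaultdict(list)
    for s, _rel, o, t in quads:          # links can be followed in either direction
        adj[s].append((o, t))
        adj[o].append((s, t))

    node_pool: Dict[str, float] = {}
    frontier: List[Walk] = [[(start, t_query)]]
    for step in range(steps):
        best: Dict[str, Tuple[float, Walk]] = {}  # best walk reaching each node at this step
        for walk in frontier:
            node, t_bound = walk[-1]
            # First hop: strictly before the query time; later hops: t' <= previous time.
            cands = [(n, t) for n, t in adj[node]
                     if (t < t_bound if step == 0 else t <= t_bound)]
            if not cands:
                continue  # the paper instead uses a default node id (-1) here
            weights = time_aware_weights([t for _, t in cands], t_bound)
            for nbr, t in rng.choices(cands, weights=weights, k=num_samples):
                new_walk = walk + [(nbr, t)]
                score = score_fn(new_walk)
                node_pool[nbr] = max(score, node_pool.get(nbr, -math.inf))
                if nbr not in best or score > best[nbr][0]:
                    best[nbr] = (score, new_walk)
        # Top-n pruning: only the n highest-scoring newly reached nodes keep walking.
        ranked = sorted(best.values(), key=lambda item: item[0], reverse=True)
        frontier = [walk for _, walk in ranked[:top_n]]
    return node_pool


def recency(walk: Walk) -> float:
    """Hypothetical scorer that favours recent, short walks."""
    return walk[-1][1] / len(walk)


toy = [("A", "r1", "B", 1), ("A", "r2", "C", 2), ("A", "r3", "D", 3),
       ("B", "r1", "E", 0), ("C", "r2", "F", 1), ("A", "r4", "G", 5)]
pool = iterative_walk(toy, "A", t_query=5, steps=2, num_samples=10, top_n=2, score_fn=recency)
```

In this toy graph node G is never reached: its only link (to A, at time 5) is not strictly earlier than the query time at the first hop, and every later hop is bounded by a first-hop timestamp of at most 3.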
3.3 TRGAT
In the TRGAT module, we only consider historical objects that are connected to the given subject under the query relation $r_q$. We define $O_t(e_s, r_q)$ as the set of objects that have relation $r_q$ with the subject entity at a historical timestamp $t$ ($0 \le t \le t_s$). The TRGAT module can be viewed as a special kind of neighborhood aggregation. We assume that different entities under the same relation differ in importance, so we assign each edge a different weight by computing attention:

$a_{e_s, o} = f\big(W_2 (F_1(h_{e_s}, f_1(t)) \oplus h_{r_q} \oplus h_o)\big), \quad o \in O_t$,   (9)

where $a_{e_s, o}$ is the attention between object $o$ and the subject entity, $W_2$ is the relation-aware transformation matrix, and $h_{r_q}$ and $h_o$ are the embeddings of relation $r_q$ and object $o$, respectively. To obtain relative attention values, a softmax function is applied over $a_{e_s, o}$:

$\alpha_{e_s, o} = \mathrm{softmax}(a_{e_s, o}) = \dfrac{\exp(a_{e_s, o})}{\sum_{n \in O_t} \exp(a_{e_s, n})}$.   (10)

We aggregate the representations of the prior neighbors, weighted by the normalized attention scores:

$h'_{e_s} = \sum_{n \in O_t} \alpha_{e_s, n} h_n$.   (11)

Finally, we use the updated subject entity and the object entities in $O_t$ to update the scores in the node pool:

$\mathrm{Att}(e_s, O_t) = f\big(W_\mu (h'_{e_s} \oplus h_{O_t})\big)$.   (12)
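A minimal sketch of the TRGAT computation in Eqs. (9)–(11) follows, including the Bochner time embedding of Eq. (7) and the time-aware map $F_1$ of Eq. (6). Assumptions, none of which come from the text: $W_2$ is a single row so each logit is a scalar, the activation $f$ is LeakyReLU with slope 0.2 (a common graph-attention choice), and historical facts are taken strictly before the query time. All function names are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Set, Tuple

Quad = Tuple[str, str, str, int]  # (subject, relation, object, timestamp)
Vec = List[float]


def bochner_time_embedding(t: float, omegas: Sequence[float]) -> Vec:
    """f1(t) = [cos(w_1 t), sin(w_1 t), ..., cos(w_d t), sin(w_d t)] as in Eq. (7)."""
    out: Vec = []
    for w in omegas:
        out.extend((math.cos(w * t), math.sin(w * t)))
    return out


def matvec(W: Sequence[Sequence[float]], x: Sequence[float]) -> Vec:
    """Dense matrix-vector product; W is given as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def object_set(history: Sequence[Quad], subject: str, relation: str, t_query: int) -> Set[str]:
    """O(e_s, r_q): historical objects of the query subject under the query relation."""
    return {o for s, r, o, t in history if s == subject and r == relation and t < t_query}


def attention_logit(h_s: Sequence[float], t: float, omegas: Sequence[float],
                    W1: Sequence[Sequence[float]], h_r: Sequence[float],
                    h_o: Sequence[float], w2: Sequence[float]) -> float:
    """Eq. (9): a = f(w2 . [F1(h_s, f1(t)) ; h_r ; h_o]) with F1 = W1 [h_s ; f1(t)] (Eq. 6)."""
    x = matvec(W1, list(h_s) + bochner_time_embedding(t, omegas)) + list(h_r) + list(h_o)
    if len(w2) != len(x):
        raise ValueError("w2 must match the concatenated feature length")
    z = sum(a * b for a, b in zip(w2, x))
    return z if z > 0 else 0.2 * z  # assumed LeakyReLU(0.2)


def aggregate(logits: Dict[str, float], emb: Dict[str, Sequence[float]]) -> Vec:
    """Eqs. (10)-(11): softmax the logits over O_t, then h'_{e_s} = sum_n alpha_n h_n."""
    objs = sorted(logits)
    m = max(logits[o] for o in objs)            # shift for numerical stability
    exps = [math.exp(logits[o] - m) for o in objs]
    total = sum(exps)
    dim = len(emb[objs[0]])
    return [sum(e / total * emb[o][k] for e, o in zip(exps, objs)) for k in range(dim)]


history = [("A", "visit", "X", 1), ("A", "visit", "Y", 2), ("A", "meet", "Z", 2),
           ("B", "visit", "W", 1), ("A", "visit", "X", 5)]
objs = object_set(history, "A", "visit", t_query=5)   # equals {"X", "Y"}
h_new = aggregate({"X": 1.0, "Y": 0.0}, {"X": [1.0, 0.0], "Y": [0.0, 1.0]})
```

One reason the Bochner form suits time encoding is that $f_1(t_1)^\top f_1(t_2) = \sum_i \cos(\omega_i (t_1 - t_2))$, so similarity between two time encodings depends only on the time difference.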
4 Experiments
In this section, we demonstrate the effectiveness of our model on five public datasets. We first describe the experimental setup, including datasets, implementation details, benchmark methods and evaluation metrics, and then discuss the experimental results. We also conduct several ablation studies to analyze the impact of the entity/relation embeddings and of the individual components of SWGAT.
4.1 Experimental Setup
Datasets. Five TKG datasets are commonly used in previous studies: ICEWS14, ICEWS18, WIKI, YAGO and GDELT. The Integrated Crisis Early Warning System (ICEWS) dataset contains political events annotated with specific times, e.g., (Barack Obama, visit, Malaysia, 2014-02-19). ICEWS14 and ICEWS18 are subsets of ICEWS containing facts from 2014 and 2018, respectively. Note that all time annotations in ICEWS are time points. WIKI and YAGO are knowledge bases with temporally associated facts. The Global Database of Events, Language, and Tone (GDELT) is an initiative to construct a global dataset of all events, connecting people, organizations and news sources. Baseline Methods. We compare our model with two categories of models: static KG reasoning models and TKG reasoning models. DistMult [26], R-GCN [18], ConvE [4] and RotatE [20] are selected as static models. The temporal methods include TA-DistMult [5], R-GCRN [19], HyTE [2], dyngraph2vecAE [7], EvolveGCN [16], Know-Evolve [21], Know-Evolve+MLP [21], DyRep [22], CyGNet [28], RE-Net [9], xERTE [8] and EvoKG [17]. Both T-GAP [10] and xERTE [8] use subgraph extraction and attention-flow walks, but T-GAP addresses the interpolation problem.
Table 1. Experimental results for the extrapolation task on five temporal datasets. Hits@N values are in percentage. The best score in each column is in bold. The results of all the baseline methods are taken from [17].

| Type | Method | ICEWS14 MRR | ICEWS14 Hits@3 | ICEWS14 Hits@10 | ICEWS18 MRR | ICEWS18 Hits@3 | ICEWS18 Hits@10 | WIKI MRR | WIKI Hits@3 | WIKI Hits@10 | YAGO MRR | YAGO Hits@3 | YAGO Hits@10 | GDELT MRR | GDELT Hits@3 | GDELT Hits@10 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Static | DistMult [26] | 9.72 | 10.09 | 22.53 | 13.86 | 15.22 | 31.26 | 27.96 | 32.45 | 39.51 | 44.05 | 49.70 | 59.94 | 8.61 | 8.27 | 17.04 |
| Static | R-GCN [18] | 15.03 | 16.12 | 31.47 | 15.05 | 16.49 | 29.00 | 13.96 | 15.75 | 22.05 | 27.43 | 31.24 | 44.75 | 12.17 | 12.37 | 20.63 |
| Static | ConvE [4] | 21.64 | 23.16 | 38.37 | 22.56 | 25.41 | 41.67 | 26.41 | 30.36 | 39.41 | 41.31 | 47.10 | 59.67 | 18.43 | 19.57 | 32.25 |
| Static | RotatE [20] | 9.79 | 9.37 | 22.24 | 11.63 | 12.31 | 28.03 | 26.08 | 31.63 | 38.51 | 42.08 | 46.77 | 59.39 | 3.62 | 2.26 | 8.37 |
| Temporal | TA-DistMult [5] | 11.29 | 11.60 | 23.71 | 15.62 | 17.09 | 32.21 | 26.44 | 31.36 | 38.97 | 44.98 | 50.64 | 61.11 | 10.34 | 10.44 | 21.63 |
| Temporal | HyTE [2] | 7.72 | 7.94 | 20.16 | 7.41 | 7.33 | 16.01 | 25.40 | 29.16 | 37.54 | 14.42 | 39.73 | 46.98 | 6.69 | 7.57 | 19.06 |
| Temporal | dyngraph2vecAE [7] | 6.95 | 8.17 | 12.18 | 1.36 | 1.54 | 1.61 | 2.67 | 2.75 | 3.00 | 0.81 | 0.74 | 0.76 | 4.53 | 1.87 | 1.87 |
| Temporal | EvolveGCN [16] | 8.32 | 7.64 | 18.81 | 10.31 | 10.52 | 23.65 | 27.19 | 31.35 | 38.13 | 40.50 | 45.78 | 55.29 | 6.54 | 5.64 | 15.22 |
| Temporal | Know-Evolve [21] | 0.05 | 0.00 | 0.10 | 0.11 | 0.00 | 0.47 | 0.03 | 0.00 | 0.04 | 0.02 | 0.00 | 0.01 | 0.11 | 0.02 | 0.10 |
| Temporal | Know-Evolve+MLP [21] | 16.81 | 18.63 | 29.20 | 7.41 | 7.87 | 14.76 | 10.54 | 13.08 | 20.21 | 5.23 | 5.63 | 10.23 | 15.88 | 15.69 | 22.28 |
| Temporal | DyRep+MLP [22] | 17.54 | 19.87 | 30.34 | 7.82 | 7.73 | 16.33 | 10.41 | 12.06 | 20.93 | 4.98 | 5.54 | 10.19 | 16.25 | 16.45 | 23.86 |
| Temporal | R-GCRN+MLP [19] | 21.39 | 23.60 | 38.96 | 23.46 | 26.62 | 41.96 | 28.68 | 31.44 | 38.58 | 43.71 | 48.53 | 56.98 | 18.63 | 19.80 | 32.42 |
| Temporal | CyGNet [28] | 22.83 | 25.36 | 39.97 | 25.43 | 28.95 | 43.86 | 33.89 | 36.10 | 41.86 | 52.07 | 56.12 | 63.77 | 18.05 | 19.11 | 31.50 |
| Temporal | RE-Net [9] | 23.91 | 26.63 | 42.70 | 26.81 | 30.58 | 45.92 | 31.55 | 34.45 | 42.26 | 46.37 | 51.95 | 61.59 | 19.44 | 20.73 | 33.81 |
| Temporal | xERTE [8] | 23.92 | 27.30 | 39.54 | 27.14 | 31.08 | 43.31 | 70.52 | 75.01 | 76.46 | 84.12 | 88.62 | 89.90 | 21.21 | 22.93 | 26.60 |
| Temporal | EvoKG [17] | 27.18 | 30.84 | **47.67** | 29.28 | 33.94 | **50.09** | 68.03 | 79.60 | 85.91 | 68.59 | 81.13 | **92.73** | 19.28 | 20.55 | 34.44 |
| Ours | SWGAT | **27.81** | **31.75** | 43.84 | **30.15** | **35.12** | 47.54 | **82.76** | **84.44** | **86.83** | **87.63** | **90.52** | 91.25 | **22.10** | **24.54** | **36.49** |

4.2 Results on TKG Reasoning
Table 1 summarizes the time-aware filtered link prediction results on the ICEWS14, ICEWS18, WIKI, YAGO and GDELT datasets; the baseline results are taken from [17]. SWGAT achieves the best MRR and Hits@3 on all five datasets and the best Hits@10 on WIKI and GDELT, while EvoKG [17] remains stronger on Hits@10 for ICEWS14, ICEWS18 and YAGO. Compared with EvoKG [17], our model achieves relative improvements of 2.3% and 2.8% in MRR and Hits@3 on ICEWS14; 3% and 3.5% in MRR and Hits@3 on ICEWS18; 21.6%, 6% and 1.1% in MRR, Hits@3 and Hits@10 on WIKI; 27.7% and 11.6% in MRR and Hits@3 on YAGO; and 14.6%, 19.5% and 5.9% in MRR, Hits@3 and Hits@10 on GDELT. These results suggest that learning temporal sequences and directly related spatio-temporal neighbors helps the model find the correct target nodes. On YAGO in particular, the MRR of SWGAT is 4.2% and 27.7% higher than that of xERTE [8] and EvoKG [17], respectively, probably because YAGO events occur with relative regularity and involve few neighboring entities, which allows SWGAT to find target entities quickly and accurately. Among the baselines, the static methods perform relatively poorly because they do not consider temporal information.
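To make the reported numbers concrete, the sketch below computes time-aware filtered ranks, MRR and Hits@N. It assumes the usual definition of time-aware filtering, in which other objects that are also true for the same subject, relation and timestamp are removed from the candidates before ranking; the function names are hypothetical and ties are broken in the target's favor.

```python
from typing import Dict, Sequence, Set


def filtered_rank(scores: Dict[str, float], target: str, also_true: Set[str]) -> int:
    """Rank of `target` (1 = best), ignoring other objects that are also true for
    the same (subject, relation, timestamp). Only strictly higher scores count."""
    t_score = scores[target]
    higher = sum(1 for e, s in scores.items()
                 if e != target and e not in also_true and s > t_score)
    return higher + 1


def mrr_and_hits(ranks: Sequence[int], ns: Sequence[int] = (1, 3, 10)) -> Dict[str, float]:
    """MRR and Hits@N as percentages, the units used in Tables 1 and 2."""
    result = {"MRR": 100.0 * sum(1.0 / r for r in ranks) / len(ranks)}
    for n in ns:
        result[f"Hits@{n}"] = 100.0 * sum(r <= n for r in ranks) / len(ranks)
    return result


def relative_improvement(new: float, old: float) -> float:
    """Relative gain of `new` over `old`, in percent."""
    return 100.0 * (new - old) / old


print(filtered_rank({"a": 0.9, "b": 0.8, "c": 0.7}, "c", also_true={"a"}))  # 2
print(round(relative_improvement(27.81, 27.18), 1))                        # 2.3
```

Applied to the ICEWS14 MRR values in Table 1 (27.81 for SWGAT and 27.18 for EvoKG), `relative_improvement` gives about 2.3%, which indicates that the improvements quoted above are relative rather than absolute differences.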
4.3 Ablation Study
To evaluate the effectiveness of SWGAT, we conduct ablation studies on the ICEWS18 dataset. Impact of Two Components. To verify the importance of each component of SWGAT, we mask the spatio-temporal random walk and the TRGAT module, respectively. The experimental results are shown in Fig. 3. The spatio-temporal walk component has a considerable impact on performance. The results also show that SWGAT performs better than either component alone, which suggests that SWGAT integrates the properties of both components, i.e., the exploration of temporal evolution and of structural features. Embedding Size. We set the embedding size (for both temporal and structural embeddings) to 100, 200, 300 and 400. As shown in Fig. 4, the best results on ICEWS18 are achieved with an embedding size of 200. A larger embedding size, such as 400, hurts performance, probably because overly large dimensions lead to overfitting. Number of Walks. Fig. 5 shows the three evaluation metrics on ICEWS18 for different numbers of walks. Performance increases with the number of walks but approaches saturation at about 20 walks, beyond which additional walks yield only small improvements. Inductive Link Prediction. As time evolves, new nodes may appear, such as new users or new posts, so a good model should have strong inductive representation capability to cope with unseen entities. For example, the ICEWS14 test set contains the quadruple (Mehdi Hasan, Make an appeal or request, Citizen(India), 2014-11-12). The entity Mehdi Hasan does not appear in the training set, so this quadruple contains an entity that the model never observes during training. Specifically, we divide the test set into three categories, seen entities, unseen entities and mixed entities (containing both seen and unseen entities); the results are shown in Table 2. SWGAT achieves the best score on every metric except Hits@10 on the mixed and seen subsets, where EvoKG [17] is stronger; in particular, it is best on all four metrics for unseen entities, indicating good inductive link prediction ability. Time Cost. A model should achieve strong performance while keeping training time short. To investigate the trade-off between accuracy and efficiency of SWGAT, Fig. 6 reports the total training time to convergence of our model and four baselines. Although RE-Net [9] is one of the strongest baselines, it takes almost 13 times longer to train than CyGNet [28] and xERTE [8]. Our model, in contrast, keeps training time short while maintaining state-of-the-art performance on the extrapolation problem, which indicates a favorable balance between accuracy and efficiency.

Table 2. Experimental results for inductive link prediction on the ICEWS18 dataset. Hits@N values are in percentage. The best score in each column is in bold.

| Model | Mixed MRR | Mixed Hits@1 | Mixed Hits@3 | Mixed Hits@10 | Seen MRR | Seen Hits@1 | Seen Hits@3 | Seen Hits@10 | Unseen MRR | Unseen Hits@1 | Unseen Hits@3 | Unseen Hits@10 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| CyGNet [28] | 25.43 | 16.09 | 28.95 | 43.86 | 26.37 | 16.68 | 30.05 | 45.51 | 2.08 | 1.18 | 1.99 | 3.45 |
| RE-Net [9] | 26.81 | 17.6 | 30.58 | 45.92 | 27.22 | 17.29 | 31.04 | 46.61 | 2.19 | 1.15 | 2.28 | 3.84 |
| xERTE [8] | 27.14 | 19.64 | 31.08 | 43.31 | 28.01 | 19.93 | 32.12 | 44.92 | 6.96 | 4.82 | 6.45 | 8.57 |
| EvoKG [17] | 29.28 | 19.02 | 33.94 | **50.09** | 30.39 | 19.58 | 35.16 | **51.83** | 4.13 | 2.71 | 4.45 | 6.34 |
| SWGAT | **30.15** | **21.52** | **35.12** | 47.54 | **31.24** | **22.14** | **36.33** | 49.67 | **7.52** | **6.01** | **8.52** | **10.68** |

Fig. 3. Time-aware filtered metrics of SWGAT with or without the TRGAT module on ICEWS18.
Fig. 4. Embedding Size.
Fig. 5. Number of Walks.
Fig. 6. Training Time Cost.
5 Conclusions
Representing and reasoning about temporal knowledge is a challenging problem. In this paper, we propose a model for temporal knowledge graph prediction that learns both the evolution patterns of entities and relations over time and the spatio-temporal subgraphs specific to the query entities and relations. Extensive experiments on five benchmark datasets demonstrate the effectiveness of the model on the extrapolation problem compared with state-of-the-art methods. More efficient node/edge sampling strategies remain to be studied, because the model's efficiency is limited by node selection when it is extended to real-world temporal graphs with tens of billions of nodes. Interesting future work includes developing fast and efficient versions of the model and applying it in streaming scenarios. Acknowledgment. This work is supported by Shengze Li's grant from the National Natural Science Foundation of China (No. 11901578).
References
1. Bordes, A., Usunier, N., Garcia-Duran, A., Weston, J., Yakhnenko, O.: Translating embeddings for modeling multi-relational data. In: Advances in Neural Information Processing Systems, vol. 26 (2013)
2. Dasgupta, S.S., Ray, S.N., Talukdar, P.: HyTE: hyperplane-based temporally aware knowledge graph embedding. In: Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing, pp. 2001–2011 (2018)
3. Deng, S., Rangwala, H., Ning, Y.: Dynamic knowledge graph based multi-event forecasting. In: Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 1585–1595 (2020)
4. Dettmers, T., Minervini, P., Stenetorp, P., Riedel, S.: Convolutional 2D knowledge graph embeddings. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 32 (2018)
5. García-Durán, A., Dumančić, S., Niepert, M.: Learning sequence encoders for temporal knowledge graph completion. arXiv preprint arXiv:1809.03202 (2018)
6. Goel, R., Kazemi, S.M., Brubaker, M., Poupart, P.: Diachronic embedding for temporal knowledge graph completion. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, pp. 3988–3995 (2020)
7. Goyal, P., Chhetri, S.R., Canedo, A.: dyngraph2vec: capturing network dynamics using dynamic graph representation learning. Knowl. Based Syst. 187, 104816 (2020)
8. Han, Z., Chen, P., Ma, Y., Tresp, V.: Explainable subgraph reasoning for forecasting on temporal knowledge graphs. In: International Conference on Learning Representations (2020)
9. Jin, W., Qu, M., Jin, X., Ren, X.: Recurrent event network: autoregressive structure inference over temporal knowledge graphs. arXiv preprint arXiv:1904.05530 (2019)
10. Jung, J., Jung, J., Kang, U.: T-GAP: learning to walk across time for temporal knowledge graph completion. arXiv preprint arXiv:2012.10595 (2020)
11. Koren, Y., Bell, R., Volinsky, C.: Matrix factorization techniques for recommender systems. Computer 42(8), 30–37 (2009)
12. Liu, Z., Xiong, C., Sun, M., Liu, Z.: Entity-duet neural ranking: understanding the role of knowledge graph semantics in neural information retrieval. arXiv preprint arXiv:1805.07591 (2018)
13. Nathani, D., Chauhan, J., Sharma, C., Kaul, M.: Learning attention-based embeddings for relation prediction in knowledge graphs. arXiv preprint arXiv:1906.01195 (2019)
14. Nickel, M., Rosasco, L., Poggio, T.: Holographic embeddings of knowledge graphs. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 30 (2016)
15. Nickel, M., Tresp, V., Kriegel, H.P.: A three-way model for collective learning on multi-relational data. In: ICML (2011)
16. Pareja, A., et al.: EvolveGCN: evolving graph convolutional networks for dynamic graphs. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, pp. 5363–5370 (2020)
17. Park, N., Liu, F., Mehta, P., Cristofor, D., Faloutsos, C., Dong, Y.: EvoKG: jointly modeling event time and network structure for reasoning over temporal knowledge graphs. In: Proceedings of the Fifteenth ACM International Conference on Web Search and Data Mining, pp. 794–803 (2022)
18. Schlichtkrull, M., Kipf, T.N., Bloem, P., van den Berg, R., Titov, I., Welling, M.: Modeling relational data with graph convolutional networks. In: Gangemi, A., Navigli, R., Vidal, M.-E., Hitzler, P., Troncy, R., Hollink, L., Tordai, A., Alam, M. (eds.) ESWC 2018. LNCS, vol. 10843, pp. 593–607. Springer, Cham (2018). https://doi.org/10.1007/978-3-319-93417-4_38
19. Seo, Y., Defferrard, M., Vandergheynst, P., Bresson, X.: Structured sequence modeling with graph convolutional recurrent networks. In: Cheng, L., Leung, A.C.S., Ozawa, S. (eds.) ICONIP 2018. LNCS, vol. 11301, pp. 362–373. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-04167-0_33
20. Sun, Z., Deng, Z.H., Nie, J.Y., Tang, J.: RotatE: knowledge graph embedding by relational rotation in complex space. arXiv preprint arXiv:1902.10197 (2019)
21. Trivedi, R., Dai, H., Wang, Y., Song, L.: Know-Evolve: deep temporal reasoning for dynamic knowledge graphs. In: International Conference on Machine Learning, pp. 3462–3471. PMLR (2017)
22. Trivedi, R., Farajtabar, M., Biswal, P., Zha, H.: DyRep: learning representations over dynamic graphs. In: International Conference on Learning Representations (2019)
23. Wang, Y., Chang, Y.Y., Liu, Y., Leskovec, J., Li, P.: Inductive representation learning in temporal networks via causal anonymous walks. arXiv preprint arXiv:2101.05974 (2021)
24. Wang, Y., Chiew, V.: On the cognitive process of human problem solving. Cogn. Syst. Res. 11(1), 81–92 (2010)
25. Wang, Z., Zhang, J., Feng, J., Chen, Z.: Knowledge graph embedding by translating on hyperplanes. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 28 (2014)
26. Yang, B., Yih, W.t., He, X., Gao, J., Deng, L.: Embedding entities and relations for learning and inference in knowledge bases. arXiv preprint arXiv:1412.6575 (2014)
27. Zhao, M., Zhang, L., Kong, Y., Yin, B.: Temporal knowledge graph reasoning triggered by memories. arXiv preprint arXiv:2110.08765 (2021)
28. Zhu, C., Chen, M., Fan, C., Cheng, G., Zhan, Y.: Learning from history: modeling temporal knowledge graphs with sequential copy-generation networks. arXiv preprint arXiv:2012.08492 (2020)
Improving Knowledge Graph Embedding Using Dynamic Aggregation of Neighbor Information
Guangbin Wang¹, Yuxin Ding¹,²(B), Yiqi Su¹, Zihan Zhou¹, Yubin Ma¹, and Wen Qian¹
¹ Harbin Institute of Technology, Shenzhen, China
² Guangdong Provincial Key Laboratory of Novel Security Intelligence Technologies, Guangzhou, China
Abstract. Knowledge graph embedding maps the entities and relations of a knowledge graph into a low-dimensional vector space to accomplish the knowledge graph completion task. Most existing knowledge graph embedding models, such as the translational distance models TransE and RotatE, only consider triple-level information. Considering the rich contextual information of entities in the graph, we propose a new approach named DAN, which dynamically aggregates neighbor information. First, each relation is regarded as a transformation operation on heterogeneous neighbors, which transforms neighborhood information into the same homogeneous space as the central node. This solves the problem of transformation between heterogeneous nodes in the knowledge graph. Then, by dynamically aggregating neighbor information, the same entity obtains different information embeddings in different triples. This method enriches the representation of entity information and, as a general method, can be combined with other embedding models. We combine our approach with the TransE and RotatE models. The experimental results show that DAN improves the accuracy of the original models, and our best result outperforms the existing state-of-the-art models in link prediction. Moreover, our method converges faster and achieves higher accuracy with only a slight increase in the number of parameters.
Keywords: Knowledge Graph · Link Prediction · Graph Learning
1 Introduction
A knowledge graph generally consists of a number of factual triples, and knowledge graphs are now widely used in fields such as question answering [7], intelligent search [9] and personalized recommendation [22]. Knowledge graphs have been studied extensively in both academia and industry; representative knowledge graphs include WordNet [13], Freebase [1], YAGO [18] and DBpedia [11]. A knowledge graph consists of triples (h, r, t), where r represents the relation between the head entity h and the tail entity t. Knowledge graphs contain rich information, but they also have missing nodes and relations. To address this, many studies embed entities and relations in a low-dimensional space to predict the missing information and thereby complete the knowledge graph automatically. In this paper, we aim to improve link prediction accuracy by investigating knowledge graph embedding methods.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 513–525, 2023. https://doi.org/10.1007/978-981-99-1639-9_43

Fig. 1. Illustrative example of the KGs

Knowledge graphs generally contain a large number of nodes and relations. Some classical knowledge graph embedding models, such as TransE [2], RotatE [19], DistMult [24] and ComplEx [21], construct scoring functions based on individual triples. These methods ignore the complete structural information of each node. However, a knowledge graph, as a kind of heterogeneous graph, provides rich contextual and structural information for each entity. Some graph convolutional networks, such as GCN [10], consider the neighbor information of nodes when handling graph structure, but knowledge graphs are heterogeneous graphs with different relations between nodes, so GCN is difficult to apply to them directly. R-GCN [16] accounts for the heterogeneity of knowledge graphs by transforming neighbor information with relation matrices, and it reduces the number of parameters of these matrices through basis decomposition and block-diagonal decomposition. HRAN [26] groups neighbor nodes by relation, dividing the heterogeneous graph into multiple homogeneous graphs; GCN can then be applied to aggregate the neighbor information under each relation while assigning different weights to the neighbor information. However, Zhang et al. [27] show that modeling the graph structure is not important for GCN-based knowledge graph completion (KGC) models: applying GCN methods to knowledge graphs has the same effect as simply performing linear transformations on the entities.
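For reference, R-GCN's basis decomposition writes each relation matrix as $W_r = \sum_{b=1}^{B} a_{rb} V_b$, where the $B$ basis matrices $V_b$ are shared across relations and only the coefficients $a_{rb}$ are relation-specific. The sketch below, with illustrative sizes (200 relations, dimension 100, 10 bases) and hypothetical function names, shows the construction and the resulting parameter saving; it ignores the self-loop weight and biases.

```python
from typing import Dict, List

Matrix = List[List[float]]


def basis_relation_matrix(coeffs: List[float], bases: List[Matrix]) -> Matrix:
    """W_r = sum_b a_rb * V_b: one relation matrix as a combination of shared bases."""
    rows, cols = len(bases[0]), len(bases[0][0])
    return [[sum(a * V[i][j] for a, V in zip(coeffs, bases)) for j in range(cols)]
            for i in range(rows)]


def relation_param_counts(num_relations: int, d: int, num_bases: int) -> Dict[str, int]:
    """Relation parameters per layer: full d x d matrices vs. basis decomposition."""
    return {
        "full": num_relations * d * d,
        "basis": num_bases * d * d + num_relations * num_bases,
    }


eye = [[1.0, 0.0], [0.0, 1.0]]
swap = [[0.0, 1.0], [1.0, 0.0]]
print(basis_relation_matrix([2.0, -1.0], [eye, swap]))  # [[2.0, -1.0], [-1.0, 2.0]]
print(relation_param_counts(200, 100, 10))              # {'full': 2000000, 'basis': 102000}
```

The saving grows with the number of relations, because each extra relation adds only $B$ coefficients instead of a full $d \times d$ matrix.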
Moreover, the GCN-based methods above do not take into account that neighbor information should be weighted differently when predicting different triples. We hypothesize that each neighbor should carry a different weight under different relations. As shown in Fig. 1, when predicting the Classmate relation between PersonA and PersonB, the school from which PersonA graduated should have a higher weight. When predicting the Relative relation between PersonA and PersonD, the Father of PersonA should have a higher weight instead. In this paper, we propose a method that dynamically aggregates the neighbor information of entities. When predicting different triples, the neighbors receive different weights depending on the relation. Before aggregation, we use the relations to transform heterogeneous neighbors into homogeneous ones, in a way determined by the chosen scoring function. Our approach can be combined with existing knowledge graph embedding models such as TransE and RotatE. Experimental results demonstrate the effectiveness of dynamically aggregating neighbor information. Compared with the original models, the dynamic aggregation approach achieves better results by preprocessing the entity information, with only a slight increase in parameters. Our best results also outperform the compared state-of-the-art models on most metrics.

Improving Knowledge Graph Embedding Using Dynamic Aggregation

Notations. In this paper, each fact is represented by a triple (h, r, t) in lower case, where h/t is the head/tail entity and r is the relation between them. Bold $\mathbf{h}$, $\mathbf{r}$, $\mathbf{t}$ denote their vector representations. $\mathbf{v}$ denotes the vector representation of an entity, $\mathbf{v}_i$ the representation of entity node i, and $\mathbf{v}'_i$ its updated representation. $\mathbf{r}_j$ denotes the embedding of relation j used for scoring, and $\mathbf{r}'_j$ the representation of relation j used for transformation. $\mathcal{T}$ denotes the set of triples.
2 Related Work
Knowledge graph embedding models can be broadly classified into three categories according to their scoring functions on triples: translational distance models, semantic matching models and neural network models. Translational distance models measure the plausibility of a triple with a distance-based scoring function. TransE [2] is the most representative translational distance model. It represents entities and relations in real vector space and, as shown in Fig. 2(a), expects $\mathbf{h} + \mathbf{r} \approx \mathbf{t}$ for a fact (h, r, t). TransE is simple and efficient but cannot handle complex 1-to-N, N-to-1 and N-to-N relations. Building on TransE, TransH [23], TransR [12] and TransD [8] map entities into different vector spaces to handle such complex relations. RotatE [19] embeds entities and relations in complex space. As shown in Fig. 2(b), Euler's formula lets each relation act as a rotation, which allows RotatE to model symmetric relations that TransE cannot. TorusE [5] embeds entities and relations on a Lie group (a torus), avoiding the problems that regularization causes in TransE. Rotate3D [6] embeds entities and relations in 3D space, strengthening the representation of non-commutative composition patterns. Rot-Pro [17] builds on RotatE and models transitive relations by projecting entities. Semantic matching models judge the truth of a triple with a similarity-based scoring function. RESCAL [15] represents each relation as a matrix, with $f_r(h, t) = \mathbf{h}^\top \mathbf{M}_r \mathbf{t}$. RESCAL fits the data well but requires a large number of parameters because every relation is a full matrix. DistMult [24] restricts the relation matrix to be diagonal, which reduces model complexity, but it can only handle symmetric relations.
ComplEx [21] extends DistMult to complex-valued entity and relation embeddings. It uses the complex conjugate of the tail entity embedding to model antisymmetric relations.
Fig. 2. Illustrations of TransE and RotatE
QuatE [25] models symmetric, antisymmetric and inverse relations through quaternion operations, which offer more degrees of freedom and greater fitting power. Neural network models apply well-established neural architectures to knowledge graphs. ConvE [4] reshapes and concatenates the head entity and relation embeddings into a matrix and applies 2D convolution to it. ConvKB [14] concatenates the embeddings of a fact (h, r, t) and scores it with a convolutional network. M-DCN [28] implements multi-scale dynamic convolution by using different concatenation schemes and convolution kernels. These models consider only triple-level information, yet each node in a knowledge graph has rich contextual information.

The models closest to ours are those based on graph convolutional networks. R-GCN [16] was the first to apply the GCN [10] framework to relational data. It reduces the complexity of the relation matrices through parameter sharing and sparsity constraints. TransGCN [3] combines GCN with translational distance models. HRAN [26] groups neighbors by relation and introduces a relational attention mechanism when aggregating neighbor information. Zhang et al. [27] argue that graph-structure modeling is not important for GCN-based KGC models. They show that a linear transformation of the entities achieves results comparable to GCN-based methods.

Our method differs from other GCN-based models in that the aggregated neighbor information is transformed according to the scoring function and the relation between each neighbor and the central node. Moreover, neighbor information is aggregated dynamically. When predicting each triple, each neighbor is given a different weight depending on the relation, which enriches the entity representations.
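To make the scoring functions summarized above concrete, here is a minimal plain-Python sketch of TransE, RotatE, DistMult and ComplEx for a single triple. It assumes list-based embeddings and the L1 distance for TransE and RotatE; the L2 norm is an equally common choice. The function names are hypothetical. Practical implementations also add a margin γ and operate on batched tensors.

```python
import cmath
import math
from typing import List, Sequence

def transe_score(h: Sequence[float], r: Sequence[float], t: Sequence[float]) -> float:
    """TransE: -||h + r - t||_1; a perfect translation h + r = t scores 0."""
    return -sum(abs(hi + ri - ti) for hi, ri, ti in zip(h, r, t))

def rotate_relation(phases: Sequence[float]) -> List[complex]:
    """RotatE relation as unit-modulus complex numbers e^{i*theta} (Euler's formula)."""
    return [cmath.exp(1j * theta) for theta in phases]

def rotate_score(h, r, t) -> float:
    """RotatE: -sum_k |h_k * r_k - t_k| (rotate h element-wise, compare with t)."""
    return -sum(abs(hi * ri - ti) for hi, ri, ti in zip(h, r, t))

def distmult_score(h, r, t) -> float:
    """DistMult: <h, r, t> = h^T diag(r) t, i.e. RESCAL with a diagonal M_r."""
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))

def complex_score(h, r, t) -> float:
    """ComplEx: Re(<h, r, conj(t)>); the conjugate makes the score direction-sensitive."""
    return sum(hi * ri * ti.conjugate() for hi, ri, ti in zip(h, r, t)).real

# A rotation by pi is its own inverse, so RotatE can model a symmetric relation:
# if h o r = t, then t o r = h as well.
h = [1 + 0j, 0 + 1j]
r_sym = rotate_relation([math.pi, math.pi])
t = [hi * ri for hi, ri in zip(h, r_sym)]
print(abs(rotate_score(h, r_sym, t)) < 1e-9, abs(rotate_score(t, r_sym, h)) < 1e-9)  # True True
```

The DistMult score is unchanged when h and t are swapped, which is why it handles only symmetric relations. The ComplEx score can change sign under the same swap.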
3 Proposed Model
3.1 Neighborhood Information Transformation
Graphs are generally divided into homogeneous graphs and heterogeneous graphs. Homogeneous graphs have only one type of relation between nodes, whereas heterogeneous graphs have different relations among nodes, as shown in Fig. 1. In a homogeneous graph, neighbor information can be aggregated directly into the central node. This approach cannot be applied directly to heterogeneous graphs, where aggregation must also consider the relation between each neighbor and the central node. It is therefore necessary to transform heterogeneous neighbors into a space homogeneous with the central node. In this paper, we assume that the relation between two nodes can transform a heterogeneous neighbor into this homogeneous space. We perform the transformation according to the selected scoring function. Below, we take TransE and RotatE as scoring functions and show how neighbor information is transformed under each.

When the scoring function is based on a translational distance model, relations are directional. We therefore divide neighbors into incoming (in-degree) and outgoing (out-degree) neighbors. For an incoming neighbor, the central node is the tail entity of the triple, so the forward transformation of the relation is applied. For an outgoing neighbor, the inverse transformation is applied. Suppose neighbor $v_j$ is connected to the central node $v_i$ by relation $r_j$. Let $\mathbf{r}'_j$ be the representation of relation j used for transformation and $\mathbf{v}'_j$ the transformed neighbor information. Under TransE, the transformation is

$$\mathbf{v}'_j = \begin{cases} \mathbf{v}_j + \mathbf{r}'_j, & (v_j, r_j, v_i) \in \mathcal{T} \\ \mathbf{v}_j - \mathbf{r}'_j, & (v_i, r_j, v_j) \in \mathcal{T} \end{cases} \qquad (1)$$

where $\mathbf{v}_j, \mathbf{r}_j, \mathbf{r}'_j \in \mathbb{R}^d$. When the scoring function is RotatE, each relation is regarded as a rotation, and the transformation becomes

$$\mathbf{v}'_j = \begin{cases} \mathbf{v}_j \circ \mathbf{r}'_j, & (v_j, r_j, v_i) \in \mathcal{T} \\ \mathbf{v}_j \circ \bar{\mathbf{r}}'_j, & (v_i, r_j, v_j) \in \mathcal{T} \end{cases} \qquad (2)$$

where $\mathbf{v}_j, \mathbf{r}_j, \mathbf{r}'_j \in \mathbb{C}^d$, $\circ$ denotes the Hadamard (element-wise) product in complex space, and $\bar{\mathbf{r}}'_j$ is the complex conjugate of $\mathbf{r}'_j$.

3.2 Dynamic Aggregation of Neighbor Information
After transforming the neighbor information into the homogeneous space, we need to aggregate it into the central node. Common aggregation methods either average the neighbor information or compute neighbor weights from the central node alone. With these approaches, each neighbor has the same influence on the central node for every triple. However, we believe that neighbor information should be weighted differently under different relations. As shown in Fig. 1, when predicting the Classmate relation between PersonA and PersonB, the school from which PersonA graduated should have a higher weight. When predicting the Relative relation between PersonA and PersonD, the Father of PersonA should have a higher weight instead. Therefore, the same neighbors of the same entity PersonA may have different weights in different triples, because the relations differ.
Fig. 3. Illustration of proposed DAN framework
In this paper, we propose a method for dynamically aggregating neighbor information, named DAN. Depending on the relation in the triple, each neighbor has a different influence on the central node, which realizes dynamic aggregation of neighbor information, as shown in Eq. (3):

$$\mathbf{v}'_i = \alpha \mathbf{v}_i + (1 - \alpha) \sum_{j \in \mathcal{N}_i} \beta_{i,j} \mathbf{v}'_j \qquad (3)$$

Here α is the weight of the central node's own information, a learnable parameter that plays the role of self-attention. $\mathbf{v}'_j$ is the transformed information of neighbor j, $\mathcal{N}_i$ is the set of neighbors of the central node $v_i$, and $\beta_{i,j}$ is the importance of neighbor j to the central node. DAN computes $\beta_{i,j}$ from the relation between the neighbor and the central node, as shown in Eq. (4):

$$\beta_{i,j} = \frac{\exp\left(\mathrm{LeakyReLU}\left(\mathbf{p}^\top [\mathbf{r}_i \,\|\, \mathbf{r}_j]\right)\right)}{\sum_{k \in \mathcal{N}_i} \exp\left(\mathrm{LeakyReLU}\left(\mathbf{p}^\top [\mathbf{r}_i \,\|\, \mathbf{r}_k]\right)\right)} \qquad (4)$$
where $\mathbf{p}^\top$ is a single-layer neural network with LeakyReLU activation, $\mathcal{N}_i$ is the set of neighbors of node $v_i$, and $\|$ denotes vector concatenation. The neighbor weights are then normalized with Softmax.

The overall flow is shown in Fig. 3. The left part shows the structure of part of a knowledge graph, and the color of each vector in the right part matches the color of the corresponding node on the left. For a triple (h, r, t), we update the entities h and t according to their respective structural information, in four steps:

1. Each neighbor is transformed into the homogeneous space of the central node according to the relation between them.
2. The weights are computed from the relation r of the triple and the relation between each neighbor and the central entity.
3. The neighbor information is aggregated with the entity's own information to update the representations of the head and tail entities.
4. The triple is scored with the scoring function.

To reduce complexity, a fixed number k of neighbors can be sampled for each node.
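Putting Eqs. (1)–(4) together, the update of one central entity can be sketched in plain Python. This is a minimal, hypothetical sketch rather than the authors' implementation. The function names, the LeakyReLU slope of 0.2, reading $\mathbf{r}_i$ in Eq. (4) as the relation of the triple being scored (`r_query`), and uniform sampling of k neighbors without replacement are all assumptions. A practical version would use batched tensors with learnable p and α.

```python
import math
import random
from typing import List, Sequence, Tuple

# (neighbor embedding v_j, attention relation r_j, transformation relation r'_j, is_incoming)
Neighbor = Tuple[list, list, list, bool]

def transform_neighbor(v_j, r_t, incoming: bool, model: str = "TransE") -> list:
    """Eqs. (1)/(2): map neighbor v_j into the central node's space with r'_j (here r_t).
    incoming=True means (v_j, r_j, v_i) is in T; False means (v_i, r_j, v_j) is in T."""
    if model == "TransE":
        sign = 1.0 if incoming else -1.0
        return [a + sign * b for a, b in zip(v_j, r_t)]
    if model == "RotatE":  # r_t holds unit-modulus complex numbers
        return [a * (b if incoming else b.conjugate()) for a, b in zip(v_j, r_t)]
    raise ValueError(f"unknown model: {model}")

def leaky_relu(x: float, slope: float = 0.2) -> float:
    return x if x > 0 else slope * x

def attention_weights(r_query: list, neighbor_rels: Sequence[list], p: list) -> List[float]:
    """Eq. (4): softmax over neighbors of LeakyReLU(p^T [r_query || r_j])."""
    logits = [leaky_relu(sum(w * x for w, x in zip(p, r_query + r_j)))  # list + list = concatenation
              for r_j in neighbor_rels]
    m = max(logits)  # subtract the max before exp for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def dan_update(v_i: list, neighbors: List[Neighbor], r_query: list, p: list,
               alpha: float, model: str = "TransE", k=None, rng=random) -> list:
    """Eq. (3): v'_i = alpha * v_i + (1 - alpha) * sum_j beta_ij * v'_j,
    optionally over k neighbors sampled uniformly without replacement."""
    if k is not None and len(neighbors) > k:
        neighbors = rng.sample(neighbors, k)
    if not neighbors:
        return list(v_i)  # isolated node: keep its own embedding
    transformed = [transform_neighbor(v, rt, inc, model) for v, _, rt, inc in neighbors]
    betas = attention_weights(r_query, [rj for _, rj, _, _ in neighbors], p)
    agg = [sum(b * t[d] for b, t in zip(betas, transformed)) for d in range(len(v_i))]
    return [alpha * x + (1 - alpha) * a for x, a in zip(v_i, agg)]
```

Setting α = 1 recovers the base model's embedding unchanged. If every neighbor shares the same relation, Eq. (4) reduces to a uniform average.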
4 Experiments
4.1 Experimental Settings
Benchmark Datasets. We evaluate our approach on two commonly used public datasets, WN18RR [4] and FB15k-237 [20], whose statistics are given in Table 1. WN18RR is a knowledge graph of lexical relations: a subset of WN18 [2] with inverse relations removed. FB15k-237 is derived from Freebase [1], a large-scale knowledge graph of general knowledge, and likewise has inverse relations removed.

Table 1. Basic statistics of WN18RR and FB15k-237

| Dataset | #Entity | #Relation | #Train | #Valid | #Test |
|---|---|---|---|---|---|
| WN18RR | 40943 | 11 | 86835 | 3034 | 3134 |
| FB15k-237 | 14541 | 237 | 272115 | 17535 | 20466 |
Evaluation Metrics. We use three standard metrics to evaluate link prediction: Mean Rank (MR), Mean Reciprocal Rank (MRR) and Hits at N (H@N), with N = 1, 3 and 10. Link prediction is evaluated in the filtered setting. Each test triple is ranked against all candidate triples obtained by replacing its head or tail entity, excluding candidates that appear in the training, validation or test sets.

Training Details. For the loss function, we use the self-adversarial negative sampling method of RotatE [19], and the parameters are optimized with the Adam optimizer. We also train the TransE baseline with the same self-adversarial method. This validates the proposed method while ruling out gains that come merely from improved training techniques.

Table 2. The number of parameters on the WN18RR and FB15k-237 datasets.

| Model | WN18RR | FB15k-237 |
|---|---|---|
| TransE | 78.13M | 28.20M |
| RotatE | 156.22M | 55.94M |
| TransE+DAN | 78.15M | 28.65M |
| RotatE+DAN | 156.25M | 56.39M |
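The filtered ranking protocol and the three metrics described under Evaluation Metrics can be sketched as follows. This is a minimal illustration with hypothetical function names. `scores` maps each candidate entity to its score for one query, and higher scores are better. The tie-breaking rule is an assumption, since the text does not specify one.

```python
from typing import Dict, Hashable, Iterable, List

def filtered_rank(scores: Dict[Hashable, float], target: Hashable,
                  known_true: Iterable[Hashable]) -> int:
    """Filtered rank of `target` among candidate entities (higher score = better).
    Candidates in `known_true` (other correct answers found in the training,
    validation or test sets) are skipped. Ties are resolved in favour of the
    target, which is only one of several conventions in use."""
    filtered = set(known_true) - {target}
    target_score = scores[target]
    better = sum(1 for c, s in scores.items()
                 if c != target and c not in filtered and s > target_score)
    return better + 1

def link_prediction_metrics(ranks: List[int], ns=(1, 3, 10)) -> Dict[str, float]:
    """MR, MRR and Hits@N over the ranks collected from both head and
    tail replacement for every test triple."""
    n = len(ranks)
    metrics = {"MR": sum(ranks) / n, "MRR": sum(1.0 / r for r in ranks) / n}
    for k in ns:
        metrics[f"H@{k}"] = sum(1 for r in ranks if r <= k) / n
    return metrics

# Tail prediction for one test triple: "d" is another known-true tail, so it is filtered out.
scores = {"a": 0.9, "b": 0.8, "c": 0.7, "d": 0.95}
print(filtered_rank(scores, target="c", known_true={"c", "d"}))  # 3
```

Without filtering, the same query would rank "c" fourth, penalizing the model for scoring another correct answer highly.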
Number of Parameters. Table 2 shows the number of parameters on WN18RR and FB15k-237 for the two base models and for their combinations with the dynamic neighbor aggregation approach. The embedding dimension is 500. Compared with the original models, the dynamic aggregation method increases the number of parameters only slightly.
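The self-adversarial negative sampling loss named in the training details can be sketched for a single positive triple. The formula follows the RotatE paper [19]: L = -log σ(γ - d(h, t)) - Σ_i w_i log σ(d(h'_i, t'_i) - γ). The weights w_i are a softmax over the negatives' scores scaled by a temperature α. The helper names and any values of γ and α used with them are illustrative assumptions, not the settings of this paper.

```python
import math
from typing import List

def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def adversarial_weights(neg_dists: List[float], gamma: float, temperature: float) -> List[float]:
    """Softmax over negatives of temperature * (gamma - distance): harder negatives
    (smaller distance, higher score) receive larger weights."""
    logits = [temperature * (gamma - d) for d in neg_dists]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def self_adversarial_loss(pos_dist: float, neg_dists: List[float],
                          gamma: float, temperature: float) -> float:
    """-log sigma(gamma - d_pos) - sum_i w_i * log sigma(d_neg_i - gamma).
    In RotatE the weights w_i are treated as constants (no gradient flows through them)."""
    weights = adversarial_weights(neg_dists, gamma, temperature)
    pos_term = -log_sigmoid(gamma - pos_dist)
    neg_term = -sum(w * log_sigmoid(d - gamma) for w, d in zip(weights, neg_dists))
    return pos_term + neg_term
```

Compared with uniform negative sampling, the softmax weighting concentrates the loss on negatives the model currently finds plausible.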
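4.2 Experimental Results

Table 3. Link prediction results on the WN18RR and FB15k-237 datasets.

| Model | WN18RR MRR | WN18RR MR | WN18RR H@1 | WN18RR H@3 | WN18RR H@10 | FB15k-237 MRR | FB15k-237 MR | FB15k-237 H@1 | FB15k-237 H@3 | FB15k-237 H@10 |
|---|---|---|---|---|---|---|---|---|---|---|
| TransE | .218 | 3408 | .012 | .391 | .523 | .322 | 177 | .224 | .359 | .522 |
| RotatE | .476 | 3340 | .428 | .492 | .571 | .338 | 177 | .241 | .375 | .533 |
| TorusE | .452 | – | .422 | .464 | .512 | .316 | – | .217 | .335 | .484 |
| ConvE | .430 | 4187 | .400 | .440 | .520 | .325 | 244 | .237 | .356 | .501 |
| DistMult | .430 | 5110 | .390 | .440 | .490 | .241 | 254 | .155 | .263 | .419 |
| ComplEx | .440 | 5261 | .410 | .460 | .510 | .247 | 339 | .158 | .275 | .428 |
| M-DCN | .475 | – | .440 | .485 | .540 | .345 | – | .255 | .380 | .528 |
| R-GCN | – | – | – | – | – | .249 | – | .151 | .264 | .417 |
| HRAN | .479 | 2113 | .450 | .494 | .542 | .355 | 156 | .263 | .390 | .541 |
| TransE+DAN | .229 | 3098 | .022 | .405 | .532 | .326 | 176 | .226 | .364 | .524 |
| RotatE+DAN | .481 | 2903 | .438 | .501 | .579 | .361 | 151 | .256 | .385 | .542 |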
Link Prediction. We combine DAN with TransE and RotatE, naming the resulting models TransE+DAN and RotatE+DAN, respectively. We compare them with TransE, RotatE, TorusE, DistMult, ComplEx, ConvE, M-DCN, R-GCN and HRAN, and the results are shown in Table 3. The results of RotatE, TorusE, ConvE, M-DCN, R-GCN and HRAN are taken from their original papers.

The proposed method improves on the original models, increasing link prediction accuracy on both datasets. This indicates that dynamically aggregating neighbor information benefits link prediction. Both the transformation and the dynamic aggregation of neighbor information are important for enriching the representation of the central nodes. With an almost unchanged number of parameters, our results are also better than those of most current advanced methods. HRAN is comparable to our method, but our method is better on both MRR and H@10. Moreover, HRAN is a GCN-based method with higher complexity.

To further investigate the reasons for the improvement, we compare H@10 across relations of different mapping types in the WN18RR dataset. Knowledge graphs contain many complex relations, which can be classified as 1-to-1, 1-to-N, N-to-1 and N-to-N. WN18RR contains two 1-to-1 relations, six 1-to-N relations, one N-to-1 relation and two N-to-N relations. The results of RotatE and RotatE+DAN on the different relation types of WN18RR are shown in Table 4.

Table 4. Comparison of H@10 under different relations on the WN18RR dataset

| Relation category | Relation name | RotatE | RotatE+DAN |
|---|---|---|---|
| 1-to-1 | verb group | 0.974 | 0.974 |
| 1-to-1 | similar to | 1 | 1 |
| 1-to-N | hypernym | 0.257 | 0.280 |
| 1-to-N | member meronym | 0.355 | 0.445 |
| 1-to-N | has part | 0.322 | 0.386 |
| 1-to-N | member of domain region | 0.211 | 0.385 |
| 1-to-N | member of domain usage | 0.292 | 0.396 |
| 1-to-N | instance hypernym | 0.529 | 0.495 |
| N-to-1 | synset domain topic of | 0.403 | 0.486 |
| N-to-N | derivationally related form | 0.967 | 0.971 |
| N-to-N | also see | 0.661 | 0.732 |

The results show that our model outperforms RotatE on the multi-fold relations (1-to-N, N-to-1 and N-to-N), improving H@10 on all of them except instance hypernym. With dynamic aggregation, the head and tail entities can aggregate their neighbor information according to the semantics of the relation. Each entity thus presents different semantic information under different relations, which enriches the entity representations and improves the performance of the model.

Table 5. Ablation experiments on the WN18RR and FB15k-237 datasets

| Model | WN18RR MRR | WN18RR MR | WN18RR H@1 | WN18RR H@3 | WN18RR H@10 | FB15k-237 MRR | FB15k-237 MR | FB15k-237 H@1 | FB15k-237 H@3 | FB15k-237 H@10 |
|---|---|---|---|---|---|---|---|---|---|---|
| TransE | .218 | 3408 | .012 | .390 | .523 | .322 | 177 | .224 | .359 | .522 |
| TransE+N | .221 | 3137 | .015 | .393 | .520 | .316 | 178 | .216 | .355 | .515 |
| TransE+RN | .227 | 3107 | .021 | .402 | .528 | .321 | 174 | .222 | .361 | .518 |
| TransE+DAN | .229 | 3098 | .022 | .405 | .532 | .326 | 176 | .226 | .364 | .524 |
| RotatE | .463 | 3255 | .410 | .484 | .567 | .319 | 183 | .222 | .356 | .515 |
| RotatE+N | .214 | 3298 | .015 | .379 | .514 | .315 | 182 | .214 | .354 | .518 |
| RotatE+RN | .452 | 2969 | .396 | .479 | .573 | .330 | 173 | .234 | .366 | .524 |
| RotatE+DAN | .472 | 2903 | .438 | .488 | .575 | .332 | 173 | .234 | .369 | .527 |
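The text does not state how relations were assigned to the four mapping categories used in Table 4. One widely used rule, introduced with TransH [23], works per relation. It computes the average number of tails per head (tph) and of heads per tail (hpt), then applies a threshold of 1.5. The sketch below implements that rule on a list of triples. Whether the authors used this rule is an assumption. Papers also differ on triple direction, so the labels it produces for a given relation may not match Table 4.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple

def relation_categories(triples: Iterable[Tuple[str, str, str]],
                        threshold: float = 1.5) -> Dict[str, str]:
    """Label each relation 1-to-1, 1-to-N, N-to-1 or N-to-N from its average
    tails per head (tph) and heads per tail (hpt)."""
    tails = defaultdict(lambda: defaultdict(set))  # relation -> head -> set of tails
    heads = defaultdict(lambda: defaultdict(set))  # relation -> tail -> set of heads
    for h, r, t in triples:
        tails[r][h].add(t)
        heads[r][t].add(h)
    categories = {}
    for r in tails:
        tph = sum(len(ts) for ts in tails[r].values()) / len(tails[r])
        hpt = sum(len(hs) for hs in heads[r].values()) / len(heads[r])
        many_tails, many_heads = tph >= threshold, hpt >= threshold
        if many_heads and many_tails:
            categories[r] = "N-to-N"
        elif many_tails:
            categories[r] = "1-to-N"
        elif many_heads:
            categories[r] = "N-to-1"
        else:
            categories[r] = "1-to-1"
    return categories

toy = [("a", "r1", "x"), ("a", "r1", "y"), ("b", "r1", "z"),
       ("a", "r2", "b"), ("a", "r3", "x"), ("b", "r3", "x")]
print(relation_categories(toy))  # {'r1': '1-to-N', 'r2': '1-to-1', 'r3': 'N-to-1'}
```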
Ablation Experiments. To verify the effectiveness of the proposed model, we run ablation experiments on WN18RR and FB15k-237, combining both base models with different variants of neighbor aggregation:

- +N aggregates the neighbor information directly into the central node without considering relation information.
- +RN transforms the neighbor information into the homogeneous space of the central node using the relation information, then averages it so that every neighbor has the same weight.
- +DAN is the dynamic neighbor aggregation model proposed in this paper.

The results are shown in Table 5. The +N variant is generally less effective than the original model, most obviously for RotatE+N. This indicates that aggregating neighbor information directly blurs the representation of the central node. It is therefore important to transform the neighbor information into the homogeneous space of the central node through the relation. The comparison between +DAN and +RN shows that it is important to give neighbors different weights for different triples, which illustrates the effectiveness of dynamic aggregation. Meanwhile, +RN improves on the original model in several settings (e.g., TransE on WN18RR and RotatE on FB15k-237). This further shows the importance of transforming neighbor information through relations.

Model Analysis. The proposed method also converges faster than the original models, because it aggregates the neighbor information of each entity. Fig. 4 compares the TransE and RotatE models with their DAN variants. The DAN variants converge faster: their loss decreases quickly early in training, which is especially evident for RotatE. This shows the efficiency of dynamic neighbor aggregation. We attribute the faster convergence to the aggregation itself, which pulls the representations of neighboring nodes toward one another.
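The ablation finding that direct aggregation blurs the central node can be seen in a toy TransE example. Suppose every neighbor satisfies its triple exactly. Then the relation-transformed average (+RN) reproduces the central embedding, whereas the raw average (+N) does not. The sketch below is a hypothetical illustration of these two uniform-weight variants only. +DAN would replace the uniform mean with the relation-dependent weights of Eq. (4).

```python
from typing import List, Tuple

Vector = List[float]
Neighbor = Tuple[Vector, Vector, bool]  # (embedding, relation vector, is_incoming)

def mean(vectors: List[Vector]) -> Vector:
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def transe_transform(v: Vector, r: Vector, incoming: bool) -> Vector:
    """Eq. (1): v + r for incoming neighbors, v - r for outgoing ones."""
    sign = 1.0 if incoming else -1.0
    return [a + sign * b for a, b in zip(v, r)]

def combine(v_i: Vector, message: Vector, alpha: float) -> Vector:
    return [alpha * x + (1 - alpha) * m for x, m in zip(v_i, message)]

def plus_n(v_i: Vector, neighbors: List[Neighbor], alpha: float) -> Vector:
    """+N: average the raw neighbor embeddings, ignoring relations."""
    return combine(v_i, mean([v for v, _, _ in neighbors]), alpha)

def plus_rn(v_i: Vector, neighbors: List[Neighbor], alpha: float) -> Vector:
    """+RN: transform each neighbor with its relation, then average uniformly."""
    return combine(v_i, mean([transe_transform(v, r, inc) for v, r, inc in neighbors]), alpha)

# Central node at the origin; (A, r1, center) and (center, r2, B) hold exactly under TransE.
center = [0.0, 0.0]
neighbors = [([-1.0, 0.0], [1.0, 0.0], True), ([0.0, 2.0], [0.0, 2.0], False)]
print(plus_rn(center, neighbors, alpha=0.5))  # [0.0, 0.0]
print(plus_n(center, neighbors, alpha=0.5))   # [-0.25, 0.5]
```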
Fig. 4. Comparison of loss curves on the WN18RR and FB15k-237 data sets
5 Conclusion
We propose a method that improves knowledge graph embedding by dynamically aggregating neighbor information. For neighbor information in heterogeneous graphs, it is important to transform each neighbor into the homogeneous space of the central node through the relation. The transformed neighbor information is then aggregated dynamically, so that the same entity aggregates its neighbor information differently for different triples. Combining this approach with the TransE and RotatE models achieves better results on two public datasets. The number of parameters increases only slightly compared with the original models, and the models converge faster. We believe that this approach can also be combined with other embedding models to improve their performance.

Acknowledgement. This work was supported by the National Natural Science Foundation of China (Grant No. 61872107) and the Guangdong Provincial Key Laboratory of Novel Security Intelligence Technologies (2022B1212010005).
References
1. Bollacker, K., Evans, C., Paritosh, P., Sturge, T., Taylor, J.: Freebase: a collaboratively created graph database for structuring human knowledge. In: Proceedings of the 2008 ACM SIGMOD International Conference on Management of Data, pp. 1247–1250 (2008)
2. Bordes, A., Usunier, N., Garcia-Duran, A., Weston, J., Yakhnenko, O.: Translating embeddings for modeling multi-relational data. In: Advances in Neural Information Processing Systems, vol. 26 (2013)
3. Cai, L., Yan, B., Mai, G., Janowicz, K., Zhu, R.: TransGCN: Coupling transformation assumptions with graph convolutional networks for link prediction. In: Proceedings of the 10th International Conference on Knowledge Capture, pp. 131–138 (2019)
4. Dettmers, T., Minervini, P., Stenetorp, P., Riedel, S.: Convolutional 2d knowledge graph embeddings. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 32 (2018)
5. Ebisu, T., Ichise, R.: Generalized translation-based embedding of knowledge graph. IEEE Trans. Knowl. Data Eng. 32(5), 941–951 (2019)
6. Gao, C., Sun, C., Shan, L., Lin, L., Wang, M.: Rotate3d: representing relations as rotations in three-dimensional space for knowledge graph embedding. In: Proceedings of the 29th ACM International Conference on Information & Knowledge Management, pp. 385–394 (2020)
7. Huang, X., Zhang, J., Li, D., Li, P.: Knowledge graph embedding based question answering. In: Proceedings of the twelfth ACM International Conference on Web Search and Data Mining, pp. 105–113 (2019)
8. Ji, G., He, S., Xu, L., Liu, K., Zhao, J.: Knowledge graph embedding via dynamic mapping matrix. In: Proceedings of the 53rd Annual Meeting of the Association for Computational Linguistics and the 7th International Joint Conference on Natural Language Processing (volume 1: Long Papers), pp. 687–696 (2015)
9. Ji, S., Pan, S., Cambria, E., Marttinen, P., Yu, P.S.: A survey on knowledge graphs: representation, acquisition, and applications. IEEE Trans. Neural Netw. Learn. Syst. 33(2), 494–514 (2022). https://doi.org/10.1109/TNNLS.2021.3070843
10. Kipf, T.N., Welling, M.: Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907 (2016)
11. Lehmann, J., et al.: DBpedia-a large-scale, multilingual knowledge base extracted from Wikipedia. Semant. Web 6(2), 167–195 (2015)
12. Lin, Y., Liu, Z., Sun, M., Liu, Y., Zhu, X.: Learning entity and relation embeddings for knowledge graph completion. In: Twenty-Ninth AAAI Conference on Artificial Intelligence (2015)
13. Miller, G.A.: Wordnet: a lexical database for English. Commun. ACM 38(11), 39–41 (1995)
14. Nguyen, D.Q., Nguyen, T.D., Nguyen, D.Q., Phung, D.: A novel embedding model for knowledge base completion based on convolutional neural network. arXiv preprint arXiv:1712.02121 (2017)
15. Nickel, M., Tresp, V., Kriegel, H.P.: A three-way model for collective learning on multi-relational data. In: ICML (2011)
16. Schlichtkrull, M., Kipf, T.N., Bloem, P., van den Berg, R., Titov, I., Welling, M.: Modeling relational data with graph convolutional networks. In: Gangemi, A., Navigli, R., Vidal, M.-E., Hitzler, P., Troncy, R., Hollink, L., Tordai, A., Alam, M. (eds.) ESWC 2018. LNCS, vol. 10843, pp. 593–607. Springer, Cham (2018). https://doi.org/10.1007/978-3-319-93417-4_38
17. Song, T., Luo, J., Huang, L.: Rot-pro: modeling transitivity by projection in knowledge graph embedding. In: Advances in Neural Information Processing Systems, vol. 34 (2021)
18. Suchanek, F.M., Kasneci, G., Weikum, G.: YAGO: a core of semantic knowledge. In: Proceedings of the 16th International Conference on World Wide Web, pp. 697–706 (2007)
19. Sun, Z., Deng, Z.H., Nie, J.Y., Tang, J.: Rotate: knowledge graph embedding by relational rotation in complex space. arXiv preprint arXiv:1902.10197 (2019)
20. Toutanova, K., Chen, D.: Observed versus latent features for knowledge base and text inference. In: Proceedings of the 3rd Workshop on Continuous Vector Space Models and their Compositionality, pp. 57–66 (2015)
21. Trouillon, T., Welbl, J., Riedel, S., Gaussier, É., Bouchard, G.: Complex embeddings for simple link prediction. In: International Conference on Machine Learning, pp. 2071–2080. PMLR (2016)
22. Wang, X., Wang, D., Xu, C., He, X., Cao, Y., Chua, T.S.: Explainable reasoning over knowledge graphs for recommendation. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, pp. 5329–5336 (2019)
23. Wang, Z., Zhang, J., Feng, J., Chen, Z.: Knowledge graph embedding by translating on hyperplanes. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 28 (2014)
24. Yang, B., Yih, W.t., He, X., Gao, J., Deng, L.: Embedding entities and relations for learning and inference in knowledge bases. arXiv preprint arXiv:1412.6575 (2014)
25. Zhang, S., Tay, Y., Yao, L., Liu, Q.: Quaternion knowledge graph embeddings. In: Advances in Neural Information Processing Systems, vol. 32 (2019)
26. Zhang, Z., Cai, J., Zhang, Y., Wang, J.: Learning hierarchy-aware knowledge graph embeddings for link prediction. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, pp. 3065–3072 (2020)
27. Zhang, Z., Wang, J., Ye, J., Wu, F.: Rethinking graph convolutional networks in knowledge graph completion. In: Proceedings of the ACM Web Conference 2022, pp. 798–807 (2022)
28. Zhang, Z., Li, Z., Liu, H., Xiong, N.N.: Multi-scale dynamic convolutional network for knowledge graph embedding. IEEE Trans. Knowl. Data Eng. (2020)
Generative Generalized Zero-Shot Learning Based on Auxiliary-Features

Weimin Sun² and Gang Yang¹,²(✉)

¹ Key Lab of Data Engineering and Knowledge Engineering, Beijing, China
² School of Information, Renmin University of China, Beijing, China
[email protected]
Abstract. Although generalized zero-shot learning (GZSL) has achieved success in recognizing images of unseen classes, most previous studies focused on projecting features from one domain to another and neglected the importance of semantic descriptions. In this paper, we propose auxiliary-features via GAN (Af-GAN) to deal with the semantic loss problem. Auxiliary-features contain both real features of seen classes and instructive-features mapped from attributes. For the seen classes, we use the auxiliary-features to train the generator and regularize the synthesized samples to be close to the auxiliary-features. For the unseen classes, we use the instructive-features mapped from attributes to synthesize unseen-class samples for training the final classifier. We construct a constraint between real visual features and instructive-features, which reduces the dependence on class attributes. Features synthesized from similar attributes overlap each other in visual space, so we combine cosine similarity and Euclidean distance to constrain the distribution of the synthesized features. Our method outperforms state-of-the-art methods on four benchmark datasets and surpasses prior work by a large margin in generalized zero-shot learning.

Keywords: Zero-shot learning · Auxiliary-feature · Instructive-feature · Semantic loss

1 Introduction
In fine-grained recognition, a large number of training images is needed because the differences between classes are subtle. However, collecting and annotating such data requires substantial expert guidance and consumes considerable time and resources, and the availability of images is limited. Zero-shot learning methods have been proposed to recognize images of new, unseen classes without training samples. Concretely, in zero-shot learning (ZSL) [1,4–6], a model is trained on images from seen classes and then tested on images from unseen classes, where the seen and unseen classes are disjoint. ZSL methods use semantic information, such as semantic attributes or word vectors, to bridge the gap between seen and unseen classes [3].

In real-world scenarios, however, the images to be recognized come from both seen and unseen classes. Conventional zero-shot learning, in which all test images come from unseen classes, does not address this situation. Generalized zero-shot learning (GZSL) [7–9,24] was proposed to address this problem: during testing, the label space contains both seen and unseen classes. GZSL, however, suffers from a serious domain shift problem, in which predictions are biased toward seen classes. In GZSL, Generative Adversarial Networks (GANs) are among the most important approaches. They generate unseen-class samples from random noise guided by semantic descriptions [19–21,27].

As the only guidance for generating samples, semantic descriptions play an important role. Using only attributes as guidance, however, causes two problems. The first is that some classes have similar semantic descriptions but completely different visual features. For instance, “black and white stripes”, “four legs” and “hairy” can describe zebras, pandas or Bengal white tigers, yet these animals look quite different in visual space. The second is semantic loss. During training, the model learns only from the seen classes, so attributes that do not help distinguish the seen classes are ignored. At test time, however, these ignored attributes may help distinguish unseen classes. Therefore, semantic information alone is not enough to synthesize high-quality instances.

This research was supported by the Fundamental Research Funds for the Central Universities, and the Research Funds of Renmin University of China (20XNA031). The computer resources were provided by Public Computing Cloud Platform of Renmin University of China.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 526–537, 2023. https://doi.org/10.1007/978-981-99-1639-9_44
In this paper, we propose to supplement semantic attributes with auxiliary-features which consider the defects and generate high-quality target samples that trigger a class label. Auxiliary-features contain real features of seen classes and instructive-features, which are all useful to support the sample generation of unseen classes. In this paper, we map the class semantic attributes to instructive-features to avoid semantic loss. Unlike the class exemplars [10], in addition to constructing the interconnection between real visual features and instructive-features, we also adopt the idea of mutual information to assess the degree of correlation, to reduce the dependence on the class attributes. Furthermore, the generator of the GAN model in our method takes random noises conditioned by semantic attributes and auxiliary-features as input. The real features and instructive-features are to guide the generator in turn rather than simultaneously, so that the synthesized features can achieve a balance between real features and instructive-features. For the seen classes, we deploy auxiliary-features to train the generator and regularize the synthesized samples to resemble auxiliary-features. For the unseen classes, we take the instructive-features to synthesize unseen class samples. Then we obtain the synthesized features of unseen classes, and the problem is transformed into a general supervised machine learning problem. In this paper, our method is firstly using auxiliary-features to assist semantic attributes to synthesize features
W. Sun and G. Yang
of unseen classes. Compared with generating unseen-class features from semantic attributes alone, auxiliary-features complement the semantic information in visual space, so the synthesized features are more reasonable. Moreover, we combine a cosine-similarity loss and a Euclidean-distance loss to constrain the distribution of generated samples in visual space. This provides effective guidance for generating reasonable unseen-class features. Experimental results on four benchmark datasets show that our method outperforms state-of-the-art methods on both ZSL and GZSL.
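The combined cosine-similarity and Euclidean-distance regularizer mentioned above is formalized in Sect. 3.3 as Eq. 6 (L_CE). The pure-Python sketch below evaluates it with one synthesized feature per class. Following the text, the "most similar class" of class k is its nearest neighbour in attribute space, and x̂_k is that class's center feature. The function names, the list-of-floats feature representation, and the convention C = 0 for a zero-length vector are illustrative assumptions, not the authors' implementation.

```python
import math

def cosine_similarity(u, v):
    """C in Eq. 6: depends only on the directions of u and v, not on their lengths."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0  # assumption: C = 0 for zero vectors

def euclidean_distance(u, v):
    """E in Eq. 6: absolute distance between two feature vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def most_similar_class(k, attributes):
    """Nearest neighbour of class k in attribute space, excluding k itself."""
    others = [j for j in range(len(attributes)) if j != k]
    return min(others, key=lambda j: euclidean_distance(attributes[k], attributes[j]))

def l_ce(synth, centers, attributes):
    """Sketch of Eq. 6: (1/N) * sum_k C(x~_k - c_k, x^_k - c_k) * E(x~_k, x^_k).

    synth[k] is a synthesized feature x~_k of class k, centers[k] is c_k, and
    x^_k is the center of the class whose attributes are closest to class k's.
    """
    total = 0.0
    for k, x_tilde in enumerate(synth):
        c_k = centers[k]
        x_hat = centers[most_similar_class(k, attributes)]
        d_synth = [a - b for a, b in zip(x_tilde, c_k)]
        d_hat = [a - b for a, b in zip(x_hat, c_k)]
        total += cosine_similarity(d_synth, d_hat) * euclidean_distance(x_tilde, x_hat)
    return total / len(synth)

centers = [[0.0, 0.0], [1.0, 0.0], [0.0, 5.0]]
attributes = [[0.0], [0.1], [1.0]]
synth = [[0.5, 0.0], [1.0, 1.0], [0.0, 4.0]]
print(l_ce(synth, centers, attributes))
```

The cosine factor captures the directional difference and the distance factor captures the absolute difference, matching the motivation given for combining the two terms.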
Fig. 1. An overview of our Af-GAN method (Auxiliary-features via GAN) for GZSL. We train a mapping function M to obtain instructive-features from attributes and construct a correlation constraint between instructive-features and real features. The generator then alternately takes the instructive-features and the real features of seen classes as inputs and outputs target features. Each numbered arrow denotes one input channel, and the switch indicates that only one channel is selected at a time. Finally, the discriminator D learns to distinguish whether its input is real or fake, and a classifier F is trained on the synthesized unseen-class features and tested on real features.
2 Related Work
Zero-shot learning aims to recognize images of unseen classes with the help of semantic information such as semantic attributes. Because seen and unseen classes are disjoint, semantic attributes are the main bridge between them [3]. Lampert et al. [2] tackle the problem with attribute-based classification. They propose a Direct Attribute Prediction (DAP) model, which maps features into an attribute embedding through multiple kernel-based regressors and then infers the corresponding labels. Kodirov et al. [11] introduce an encoder-decoder paradigm. The encoder projects a visual feature vector into the semantic space, and the decoder reconstructs the original visual feature from the semantic space. Zhu et al. [12] propose a low-dimensional embedding of visual instances that is "visually semantic" to bridge the semantic gap. Mandal et al. [13] introduce an out-of-distribution detector that determines whether video features belong to a seen or an unseen action category.
Generative Generalized Zero-Shot Learning Based on Auxiliary-Features
GZSL is a more challenging task in which samples of both seen and unseen classes are included during testing [15]. Ye et al. [7] propose a progressive ensemble network with multiple projected label embeddings, which naturally alleviates the domain shift problem in visual appearance. Kampffmeyer et al. [8] propose a Dense Graph Propagation (DGP) module with carefully designed direct links among distant nodes. Through these additional connections, the module exploits the hierarchical structure of the knowledge graph. Schonfeld et al. [9] take feature generation one step further. They use modality-specific aligned variational autoencoders to learn a latent space shared by image features and class embeddings. The key idea is that the distributions learned from images and side information yield latent features that contain the essential multi-modal information associated with unseen classes. Huynh et al. [14] propose a dense attribute-based attention mechanism that focuses on the most relevant image regions for each attribute to obtain attribute-based features. However, these methods do not take into account the image differences of attributes or semantic loss, which pose further challenges for GZSL classification.
3 Our Proposed Method
We first present the notation and definitions of the ZSL and GZSL problems. We then introduce our method, first as a whole and then in detail. Let $S = \{(x, y, a) \mid x \in X^s, y \in Y^s, a \in A\}$ denote the set of seen-class samples. Here $x$ is a visual feature, $y$ is the class label of $x$, and $a$ denotes the attributes of class $y$. In the ZSL setting, the unseen-class samples are denoted by $U = \{(x^u, y^u, a^u) \mid x^u \in X^u, y^u \in Y^u, a^u \in A\}$. Unlike in traditional image classification, the seen and unseen classes are disjoint, i.e., $Y^s \cap Y^u = \emptyset$. The goal of ZSL is to predict the label of a sample $x^u$ in $Y^u$. The goal of GZSL is to predict labels in both the seen and unseen classes, $Y^s \cup Y^u$.
3.1 Overall Procedure of Our Method
Unlike previous methods, we propose auxiliary-features via GAN (Af-GAN) to address the problems of image differences of attributes and semantic loss in ZSL and GZSL. We use a conditional GAN to generate fake features for unseen classes. The overall process is shown in Fig. 1. First, we train a mapping function M to obtain instructive-features from attributes. Unlike class prototype features, these instructive-features are tied to the real features by a correlation constraint. The instructive-features and the real features of seen classes together constitute the auxiliary-features. The generator G then takes random noise as input, conditioned on the auxiliary-features and the attributes, to generate target features. Notably, the instructive-features and the real features are fed to G in turn. Finally, the discriminator D learns to distinguish whether its input is real or fake. The classifier F is trained on the synthesized unseen-class features and tested on real features.
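The first step, learning M, uses the center loss that Sect. 3.2 formalizes as Eq. 1. The instructive-feature M(a_k) of each seen class k should lie close to the class center c_k, which is the mean of that class's real features. The pure-Python sketch below computes the class centers and L_CL. The linear form of M (a hypothetical weight matrix W applied to the attribute vector) and all function names are illustrative assumptions; the paper does not specify the architecture of M at this point.

```python
def class_centers(features, labels):
    """c_k = (1/N_k) * sum of the N_k real features x_i that have label k."""
    sums, counts = {}, {}
    for x, y in zip(features, labels):
        acc = sums.setdefault(y, [0.0] * len(x))
        for d, v in enumerate(x):
            acc[d] += v
        counts[y] = counts.get(y, 0) + 1
    return {k: [v / counts[k] for v in s] for k, s in sums.items()}

def linear_mapping(W, a):
    """Hypothetical linear M: z_k = W a_k, where W has one row per feature dimension."""
    return [sum(w * v for w, v in zip(row, a)) for row in W]

def center_loss(centers, attributes, W):
    """Eq. 1: L_CL = (1/N) * sum_k ||c_k - M(a_k)||_2^2 over the N seen classes."""
    total = 0.0
    for k, c in centers.items():
        z = linear_mapping(W, attributes[k])
        total += sum((ci - zi) ** 2 for ci, zi in zip(c, z))
    return total / len(centers)

features = [[1.0, 2.0], [3.0, 4.0], [10.0, 0.0]]
labels = [0, 0, 1]
centers = class_centers(features, labels)   # {0: [2.0, 3.0], 1: [10.0, 0.0]}
attributes = {0: [1.0, 0.0], 1: [0.0, 1.0]}
print(center_loss(centers, attributes, [[1.0, 0.0], [0.0, 1.0]]))  # prints 55.5
```

As Sect. 3.2 notes, a mapping trained on Eq. 1 alone would simply reproduce the class centers. This is why the paper adds the mutual-information constraint of Eq. 2.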
3.2 Learning the Instructive-Features
To build the mapping function that produces instructive-features from attributes, we adopt the idea of class exemplars [10]: semantic representations are predictive of the locations of their corresponding visual exemplars. More specifically, given the samples (x, y, a), we train a function M such that $z_k = M(a_k)$, where $a_k$ is the attribute vector and $z_k$ is the synthesized instructive-feature of class $k$. We want the instructive-feature to be close to the center feature of its class, which is enforced by

$$\mathcal{L}_{CL} = \frac{1}{N}\sum_{k=1}^{N}\left\lVert c_k - M(a_k)\right\rVert_2^2, \tag{1}$$

where $N$ is the number of seen classes, $c_k = \frac{1}{N_k}\sum_{i=1}^{N_k} x_i$ is the center feature of class $k$, and $N_k$ is the number of samples in class $k$. According to Eq. 1, the optimal M would always produce exactly the center features of each class, which is inconsistent with the definition of instructive-features. We want the instructive-features to be highly related to the real features rather than to the center features alone. We therefore adopt mutual information, which measures the degree of dependence between two random variables, so that some information is conveyed from the real features to the instructive-features. The constraint between real features and instructive-features is defined as

$$\mathcal{L}_{MI} = \sum_{x \in X}\sum_{z \in Z} p(x, z)\log\frac{p(x, z)}{p(x)\,p(z)}, \tag{2}$$
where $p(\cdot)$ denotes a probability distribution and $Z$ is the set of instructive-features. Computing mutual information in high dimensions is intractable. We therefore follow the strategy of [15] and use a variational upper bound as a surrogate, estimated with the reparameterization trick. With the constraints of Eqs. 1 and 2, the instructive-features of each class retain information from both the attributes and the real features. They thus provide more informative guidance for subsequent sample generation. The instructive-features can also be used directly for zero-shot classification. To do so, we train a nearest-neighbor classifier that outputs the label of the closest exemplar for each novel data point. In the next section, we introduce Af-GAN, which generates high-quality instances with the help of instructive-features.
3.3 Auxiliary-Features via GAN (Af-GAN)
Previous methods [3,19,20] focus on constraining the generated features and ignore the information lost from the attributes. In this section, we exploit auxiliary-features to synthesize discriminative features. The auxiliary-features consist of the real features of seen classes and the instructive-features mapped by M, such that $X = Z \cup X^s$. Our method is based on the conditional GAN [16]. The generator G takes random noise $\dot{n}$ as input, conditioned on the semantic attributes $a_k$ and the instructive-features $z_k$, to synthesize fake features $\tilde{x}_k$ of class $k$. That is, $\tilde{x}_k = G(a_k, z_k, \dot{n})$, or in general $\tilde{x} = G(a, z, \dot{n})$. Meanwhile, the discriminator D distinguishes the features of real images from the generated features $\tilde{x}$. The GAN loss is formulated as

$$\mathcal{L}_{GD} = \mathbb{E}[\log D(x)] + \mathbb{E}[\log(1 - D(\tilde{x}))], \tag{3}$$

where the first term evaluates the real samples and the second term evaluates the synthesized samples. To make the synthesized features more discriminative, we cluster the features of each class. We require the generated features to be close to both the center features and the instructive-features:

$$\mathcal{L}_{CZ} = \frac{1}{2N}\sum_{k=1}^{N}\left(\left\lVert z_k - \tilde{x}_k\right\rVert_2^2 + \left\lVert c_k - \tilde{x}_k\right\rVert_2^2\right), \tag{4}$$

where $c_k = \frac{1}{N_k}\sum_{i=1}^{N_k} x_i$. According to Eqs. 3 and 4, the synthesized features will revolve around the center features and the instructive-features. To further constrain the synthesized features, we also feed the real features of seen classes to the GAN. This turns the adversarial objective of Eq. 3 into

$$\mathcal{L}'_{GD} = \mathbb{E}[\log(1 - D(G(a, x, \dot{n})))]. \tag{5}$$

Notably, we use instructive-features, rather than real unseen samples, to synthesize the features of unseen classes. Instructive-features and real sample features can supply the neglected semantic attribute information as guidance and help generate high-quality samples. In visual space, the attributes of two classes may differ only subtly, so some of their generated features become interlaced. To eliminate the influence of these features, we aim to keep the generated features away from the conflict area. In other words, the generated features should stay far away from the features of the most similar class. In information retrieval, text similarity is usually computed with cosine similarity. Cosine similarity is independent of vector length [25] and depends only on direction, so we use it to measure the directional difference between two features. The Euclidean distance [26] is the actual distance between two points in n-dimensional space and reflects the absolute difference between individual features. We want the features of different classes to be distinguishable in both direction and distance. Formally, we introduce the following regularization:

$$\mathcal{L}_{CE} = \frac{1}{N}\sum_{k=1}^{N} C(\tilde{x}_k - c_k, \hat{x}_k - c_k)\, E(\tilde{x}_k, \hat{x}_k), \tag{6}$$
where C denotes the cosine similarity and E denotes the Euclidean distance. In Eq. 6, $\tilde{x}_k$ is the synthesized feature and $\hat{x}_k$ is the center feature of its most similar class. The most similar class is selected as the nearest neighbor in attribute space. The regularization $\mathcal{L}_{CE}$ provides a direction for the generated features. As a result, the unseen-class features guided by instructive-features cluster together and separate from the features of their most similar class in visual space. The final objective function of Af-GAN is given by Eq. 7. Here α and β are hyper-parameters that weight the respective losses, and $\mathcal{L}_{GD}$ and $\mathcal{L}'_{GD}$ are the adversarial losses of Eqs. 3 and 5:

$$\mathcal{L}_{loss} = \mathcal{L}_{GD} + \alpha\mathcal{L}_{CZ} + \mathcal{L}'_{GD} + \beta\mathcal{L}_{CE}. \tag{7}$$

Table 1. Average accuracy (%) of different methods on four datasets (AWA, CUB, APY, and SUN) under ZSL.

| Methods        | AWA  | CUB  | APY  | SUN  |
|----------------|------|------|------|------|
| SSE [17]       | 60.1 | 43.9 | 34.0 | 51.5 |
| DAP [18]       | 44.1 | 40.0 | 33.8 | 39.9 |
| LisGAN [19]    | 70.6 | 58.8 | 43.1 | 61.7 |
| SAE [11]       | 68.4 | 33.3 | 8.3  | 40.3 |
| f-CLSWGAN [20] | 68.2 | 57.3 | 40.5 | 60.8 |
| OCD [21]       | –    | 60.3 | –    | 63.5 |
| ESZSL [22]     | 58.2 | 53.9 | 38.3 | 54.5 |
| CONSE [23]     | 45.6 | 34.3 | 26.9 | 38.8 |
| Af-GAN [Ours]  | 71.8 | 62.4 | 47.5 | 61.9 |

4 Experiments
To evaluate our method, we conduct experiments on four benchmark datasets [3]: Caltech-UCSD Birds-200-2011 (CUB), Animals with Attributes (AWA), Attribute Pascal and Yahoo (APY), and Scene UNderstanding (SUN). The hyper-parameters are set empirically to α = 0.001 and β = 40.
4.1 On ZSL
To demonstrate the effectiveness of our method, we compare it with eight other recent methods. Table 1 shows the ZSL results on the four datasets. Our method achieves classification accuracies of 71.8%, 62.4%, 47.5%, and 61.9% on AWA, CUB, APY, and SUN, respectively. On AWA, CUB, and APY, it improves on the best competing method by 1.2, 2.1, and 4.4 percentage points, respectively. These results show that using auxiliary-features as guidance improves ZSL performance.
4.2 On GZSL
Table 2 summarizes the results on the four datasets in the GZSL setting. GZSL is difficult, and methods that perform well under ZSL usually perform poorly under GZSL. As Table 2 shows, our method achieves clear improvements over the other methods in the GZSL setting. The harmonic means (H) on AWA, CUB, SUN, and APY are 63.4%, 54.3%, 40.8%, and 48.0%, respectively. In terms of H, we improve on the best existing methods by 1.1, 2.7, and 2.3 percentage points on AWA, CUB, and APY, respectively. Some methods, such as CONSE, ESZSL, SAE, DAP, and SSE, achieve high accuracy on seen classes but not on unseen classes. In contrast, our method strikes a good balance between seen and unseen classes. It therefore obtains the best harmonic mean on AWA, CUB, and APY.
Fig. 2. Visualization of the synthetic feature distributions of LisGAN [19] and our Af-GAN on APY.
4.3 Ablation Study and Experiment Analysis
To provide a qualitative evaluation, we visualize the features synthesized by Af-GAN on APY. The visualizations of our method and of LisGAN, which achieves the best competing result on APY, are shown in Fig. 2(a) and Fig. 2(b). For each unseen class, we synthesize 300 image features. As shown in Fig. 2(a), some classes overlap to a large degree, such as "cow" and "donkey". This overlap is expected because "cow" and "donkey" have similar attribute descriptions, so their synthesized samples lie close to each other. However, the features of each class are not concentrated, which makes them difficult to classify. In contrast, Fig. 2(b) shows that the features synthesized under the guidance of auxiliary-features are concentrated and easily separable. The visualization of "cow" and "donkey" shows that our method generates visual features that match the real ones well, which eliminates the overlapping areas. Even where the features of different classes lie close to each other, they are easier to classify than those generated by LisGAN. In Af-GAN, the number of synthesized samples affects the subsequent classification. We use the best classification results on AWA, CUB, and APY as the criterion for choosing the number of samples. The details are
shown in Fig. 3 (left). For the same dataset, the number of generated samples affects the classification results. Neither the smallest nor the largest number of samples is the best choice. For AWA, generating 150 samples is best, and for CUB the best choice is 400. The appropriate number of samples therefore differs between datasets. CUB is a fine-grained dataset with small differences between classes, so more generated samples are needed to distinguish them. In AWA, each class is an animal species whose visual features differ clearly from the others, so relatively few samples are needed. In Af-GAN, the mapped instructive-features $Z$ and the synthesized features $\tilde{x}$ are two important factors that affect performance. We study the effect of these two factors with ablation experiments on the four benchmark datasets, and Fig. 3 (right) shows the results. The baseline without the instructive-features $Z$ and the synthesized features $\tilde{x}$ obtains 66.4%, 52.4%, 40.8%, and 56.9% accuracy on AWA, CUB, APY, and SUN, respectively. Adding the instructive-features improves the ZSL accuracies to 68.3%, 56.7%, 43.4%, and 61.6%, respectively. These results show that, compared with attributes, instructive-features contain more class information and provide more guidance for generating samples. Finally, the full Af-GAN with the synthesized features $\tilde{x}$ further improves the accuracies to about 71.8%, 62.2%, 47.5%, and 61.9%. This further verifies the effectiveness of the synthesized features for ZSL.

Table 2. GZSL results on the four datasets (AWA, CUB, SUN, and APY), evaluated by average per-class top-1 accuracy (%). U and S are the top-1 accuracies on unseen and seen classes, respectively. H is the harmonic mean of U and S.

| Methods        | AWA U | AWA S | AWA H | CUB U | CUB S | CUB H | SUN U | SUN S | SUN H | APY U | APY S | APY H |
|----------------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|-------|
| SSE [17]       | 7.0   | 80.5  | 12.9  | 8.5   | 46.9  | 14.4  | 2.1   | 36.4  | 4.0   | 0.2   | 78.9  | 0.4   |
| DAP [18]       | 0.0   | 88.7  | 0.0   | 1.7   | 67.9  | 3.3   | 4.2   | 25.2  | 7.2   | 4.8   | 78.3  | 9.0   |
| LisGAN [19]    | 52.6  | 76.3  | 62.3  | 46.5  | 57.9  | 51.6  | 42.9  | 37.8  | 40.2  | 34.3  | 68.2  | 45.7  |
| OCD [21]       | –     | –     | –     | 59.9  | 44.8  | 51.3  | 42.9  | 44.8  | 43.8  | –     | –     | –     |
| SAE [11]       | 1.1   | 82.2  | 2.2   | 0.4   | 80.9  | 0.9   | 8.8   | 18.0  | 11.8  | 0.4   | 80.9  | 0.9   |
| f-CLSWGAN [20] | 57.9  | 61.4  | 59.6  | 43.7  | 57.7  | 49.7  | 42.6  | 36.6  | 39.4  | 32.9  | 61.7  | 42.9  |
| ESZSL [22]     | 5.9   | 77.8  | 11.0  | 2.4   | 70.1  | 4.6   | 11.0  | 27.9  | 15.8  | 2.4   | 70.1  | 4.6   |
| SJE [28]       | 11.3  | 74.6  | 19.6  | 23.5  | 59.2  | 33.6  | 3.7   | 55.7  | 6.9   | 14.7  | 30.5  | 19.8  |
| ALE [29]       | 16.8  | 76.1  | 27.5  | 23.7  | 62.8  | 34.4  | 4.6   | 73.7  | 8.7   | 21.8  | 33.1  | 26.8  |
| JGM-ZSL [30]   | 62.7  | 60.6  | 61.6  | 42.7  | 45.6  | 44.1  | 31.1  | 43.3  | 36.2  | 44.4  | 30.9  | 36.5  |
| SE-ZSL [31]    | 56.3  | 67.8  | 61.5  | 41.5  | 53.3  | 46.7  | –     | –     | –     | 40.9  | 30.5  | 34.9  |
| RAS+CGAN [32]  | –     | –     | –     | 31.5  | 40.2  | 35.3  | –     | –     | –     | 41.2  | 26.7  | 32.4  |
| CONSE [23]     | 0.4   | 88.6  | 0.88  | 1.6   | 72.2  | 3.1   | 6.8   | 39.9  | 11.6  | 0.0   | 91.2  | 0.0   |
| Af-GAN [Ours]  | 53.5  | 77.6  | 63.4  | 49.3  | 60.4  | 54.3  | 45.2  | 37.1  | 40.8  | 37.1  | 37.1  | 48.0  |
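Table 2 reports H, the harmonic mean of the unseen-class accuracy U and the seen-class accuracy S, defined as H = 2US/(U + S). H is dominated by the smaller of the two accuracies. A method with high S but near-zero U, such as DAP or CONSE, therefore gets a near-zero H. The small helper below is our own illustration, not the authors' evaluation code. It recomputes H from a (U, S) pair. Recomputed values can differ from the table in the last digit because U and S are themselves rounded.

```python
def harmonic_mean(u, s):
    """H = 2*U*S / (U + S), the GZSL summary metric; defined as 0 when U + S = 0."""
    if u + s == 0:
        return 0.0
    return 2 * u * s / (u + s)

# Recompute H for two Af-GAN rows of Table 2 (accuracies in %).
for name, u, s in [("CUB", 49.3, 60.4), ("SUN", 45.2, 37.1)]:
    print(name, round(harmonic_mean(u, s), 1))
# prints "CUB 54.3" and then "SUN 40.8"
```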
Fig. 3. Sampling effect (left) and ablation study (right).
5 Conclusion
In this paper, we propose auxiliary-features via GAN (Af-GAN) for generalized zero-shot recognition. We use a conditional GAN to synthesize fake unseen-class samples from random noise conditioned on semantic attributes and auxiliary-features. The auxiliary-features contain both the real features of seen classes and the instructive-features mapped from attributes. We construct a constraint between real visual features and instructive-features, which reduces the dependence on class attributes. In addition, we combine cosine similarity and Euclidean distance to constrain the distribution of synthesized features and eliminate overlapping areas in visual space. Experiments on four standard datasets show that our method outperforms state-of-the-art methods in ZSL and substantially improves on other methods in GZSL.
References
1. Akata, Z., Perronnin, F., Harchaoui, Z., et al.: Label-embedding for image classification. IEEE Trans. Pattern Anal. Mach. Intell. 38, 1428–1435 (2016)
2. Lampert, C.H., Nickisch, H., Harmeling, S.: Learning to detect unseen object classes by between-class attribute transfer. In: IEEE Conference on Computer Vision and Pattern Recognition (2009)
3. Liu, J., Zhang, Z., Yang, G.: Cross-class generative network for zero-shot learning. Inf. Sci. 555, 147–163 (2021)
4. Liu, L., Zhou, T., Long, G., Jiang, J., Dong, X., Zhang, C.: Isometric propagation network for generalized zero-shot learning. In: International Conference on Learning Representations (2021)
5. Bo, L., Dong, Q., Hu, Z.: Hardness sampling for self-training based transductive zero-shot learning. In: IEEE Conference on Computer Vision and Pattern Recognition, pp. 16499–16508 (2021)
6. Xie, C., Xiang, H., Zeng, T., Yang, Y., Yu, B., Liu, Q.: Cross knowledge-based generative zero-shot learning approach with taxonomy regularization. Neural Netw. 139, 168–178 (2021)
7. Ye, M., Guo, Y.: Progressive ensemble networks for zero-shot recognition. In: IEEE Conference on Computer Vision and Pattern Recognition (2019)
8. Kampffmeyer, M., Chen, Y., Liang, X., et al.: Rethinking knowledge graph propagation for zero-shot learning. In: IEEE Conference on Computer Vision and Pattern Recognition (2019)
9. Schonfeld, E., Ebrahimi, S., Sinha, S., et al.: Generalized zero- and few-shot learning via aligned variational autoencoders. In: IEEE Conference on Computer Vision and Pattern Recognition (2019)
10. Changpinyo, S., Chao, W.L., Sha, F.: Predicting visual exemplars of unseen classes for zero-shot learning. In: IEEE Conference on Computer Vision and Pattern Recognition (2016)
11. Kodirov, E., Xiang, T., Gong, S.: Semantic autoencoder for zero-shot learning. In: IEEE Conference on Computer Vision and Pattern Recognition (2017)
12. Zhu, P., Wang, H., Saligrama, V.: Generalized zero-shot recognition based on visually semantic embedding. In: IEEE Conference on Computer Vision and Pattern Recognition (2020)
13. Mandal, D., Narayan, S., Dwivedi, S., et al.: Out-of-distribution detection for generalized zero-shot action recognition. In: IEEE Conference on Computer Vision and Pattern Recognition (2019)
14. Huynh, D., Elhamifar, E.: Fine-grained generalized zero-shot learning via dense attribute-based attention. In: IEEE Conference on Computer Vision and Pattern Recognition (2019)
15. Han, Z., Fu, Z., Yang, J.: Learning the redundancy-free features for generalized zero-shot object recognition. In: IEEE Conference on Computer Vision and Pattern Recognition (2020)
16. Mirza, M., Osindero, S.: Conditional generative adversarial nets. arXiv preprint arXiv:1411.1784 (2014)
17. Zhang, Z., Saligrama, V.: Zero-shot learning via semantic similarity embedding. In: IEEE Conference on Computer Vision and Pattern Recognition (2015)
18. Lampert, C.H., Nickisch, H., Harmeling, S.: Attribute-based classification for zero-shot visual object categorization. IEEE Trans. Pattern Anal. Mach. Intell. 36, 435–453 (2014)
19. Li, J., Jin, M., Lu, K., et al.: Leveraging the invariant side of generative zero-shot learning. In: IEEE Conference on Computer Vision and Pattern Recognition (2019)
20. Xian, Y., Lorenz, T., Schiele, B., Akata, Z.: Feature generating networks for zero-shot learning. In: IEEE Conference on Computer Vision and Pattern Recognition (2018)
21. Keshari, R., Singh, R., Vatsa, M.: Generalized zero-shot learning via over-complete distribution. In: IEEE Conference on Computer Vision and Pattern Recognition (2020)
22. Romera-Paredes, B., Torr, P.H.S.: An embarrassingly simple approach to zero-shot learning. In: Proceedings of the 32nd International Conference on Machine Learning (2015)
23. Norouzi, M., et al.: Zero-shot learning by convex combination of semantic embeddings. In: International Conference on Learning Representations (2014)
24. Paul, A., Krishnan, N.C., Munjal, P.: Semantically aligned bias reducing zero shot learning. In: IEEE Conference on Computer Vision and Pattern Recognition (2019)
25. Deng, J., Guo, J., Xue, N., Zafeiriou, S.: ArcFace: additive angular margin loss for deep face recognition. In: IEEE Conference on Computer Vision and Pattern Recognition, pp. 4690–4699 (2019)
26. He, G., Li, F., Wang, Q., Bai, Z., Xu, Y.: A hierarchical sampling based triplet network for fine-grained image classification. Pattern Recogn. 115, 107889 (2021)
27. Ji, Z., Sun, Y., Yu, Y., Pang, Y., Han, J.: Attribute-guided network for cross-modal zero-shot hashing. IEEE Trans. Neural Netw. 31(1), 321–330 (2020)
28. Akata, Z., et al.: Evaluation of output embeddings for fine-grained image classification. In: IEEE Conference on Computer Vision and Pattern Recognition (2015)
29. Akata, Z., et al.: Label-embedding for image classification. IEEE Trans. Pattern Anal. Mach. Intell. (2016)
30. Gao, R., Hou, X., Qin, J., Liu, L., Zhu, F., Zhang, Z.: A joint generative model for zero-shot learning. In: Leal-Taixé, L., Roth, S. (eds.) ECCV 2018. LNCS, vol. 11132, pp. 631–646. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-11018-5_50
31. Kumar Verma, V., Arora, G., Mishra, A., Rai, P.: Generalized zero-shot learning via synthesized examples. In: IEEE Conference on Computer Vision and Pattern Recognition (2018)
32. Zhang, H., Long, Y., Liu, L., Shao, L.: Adversarial unseen visual feature synthesis for zero-shot learning. Neurocomputing 329, 12–20 (2019)
Learning Stable Representations with Progressive Autoencoder (PAE)
Zhouzheng Li¹, Dongyan Miao¹, Junfeng Gao², and Kun Feng¹(✉)
¹ Key Lab of Engine Health Monitoring-Control and Networking of Ministry of Education, Beijing University of Chemical Technology, Beijing 100029, China
² PetroChina Refining and Chemicals Branch, Beijing 100120, China
Abstract. Autoencoders, which compress information into latent variables, are widely used in various domains. However, making these latent variables understandable and controllable remains a major challenge. The β-VAE family aims to find disentangled representations and acquire human-interpretable generative factors, much as independent component analysis (ICA) does in the linear domain. In contrast, we propose the Progressive Autoencoder (PAE), a novel autoencoder-based model that serves as a counterpart of principal component analysis (PCA) in the non-linear domain. The main idea is to first train an autoencoder with one latent variable and then progressively add latent variables with decreasing weights to refine the reconstruction. This gives PAE two notable characteristics. First, the latent variables of PAE are ordered by decreasing importance. Second, the acquired latent variables are stable and robust regardless of the initial state of the network. Since our main work concerns gas turbine analysis, we create a toy dataset from a custom-made non-linear system that simulates a gas turbine. We use it to test the model and demonstrate these two key features of PAE. Because both PAE and β-VAE derive from the autoencoder, the structure of β-VAE can easily be added to our model to provide disentanglement. Comparing the result with the original β-VAE also highlights the distinctive properties of PAE. Furthermore, experiments on the MNIST dataset demonstrate how PAE can be applied to more sophisticated tasks.
Keywords: Stable representation · Latent variable model · Autoencoders · Machine learning · Non-linear system analysis
1 Introduction
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 538–549, 2023.
https://doi.org/10.1007/978-981-99-1639-9_45
Autoencoders (AE) have been used for many years to extract features from data automatically and without supervision, and much work has been done to enhance this elegant structure. One particular direction focuses on extracting latent variables with properties that are useful for non-linear system analysis, such as disentanglement or interpretability. Rifai et al. (2011) restrict the
learned representations to a contractive space and obtain better robustness with the contractive autoencoder [1]. The variational autoencoder [2] by Kingma and Welling (2014) is rooted in variational Bayesian methods and graphical models. It maps the input to a distribution instead of individual variables. In the same year, Kingma et al. also introduced the conditional VAE [3], which learns from labeled data and obtains meaningful latent variables through semi-supervised learning. Later, Higgins et al. (2017) proposed β-VAE [4,5] to learn disentangled representations. It strengthens the penalty on the KL term with a hyperparameter β and thereby narrows the information bottleneck. β-TCVAE [6] by Chen et al. (2018) further refined this work, and Esmaeili et al. (2019) unified the methods that modify the objective function in HFVAE [7]. In the GAN [8] family, several methods also try to learn disentangled representations. Milestones include the following:
- conditional GAN (2014) [9], which feeds labels along with the input data;
- BiGAN (2017) [10], whose bidirectional structure projects data back to the latent space;
- InfoGAN (2017) [11], which learns disentangled representations with an information-theoretic extension of GAN;
- InfoGAN-CR (2020) [12], which borrows techniques from the β-VAE branch.
Behind the observable variables of an unknown system, there are usually a few independent source variables whose changes represent the entire system. These variables are usually hidden and physically difficult to measure. Once they are uncovered and their relationship to the observable variables is made clear, however, they can help us fully understand the system. Extracting them from data is extremely hard and sometimes impossible. Fig. 1 illustrates the relationships between these variables and concepts.
Our lab works mainly on gas turbine condition monitoring. In this setting, the observable variables are performance parameters collected by sensors, such as temperatures and pressures. These parameters are entangled with each other and also incomplete. To obtain a clear and complete representation of the state of a gas turbine, we propose a new model. In this paper, we present the Progressive Autoencoder (PAE), which uses a new Progressive Patching Decoder (PPD) structure to learn representations with another distinctive property: stability. The representations learned by PAE are ranked by their importance for reconstruction. For the same dataset, PAE always produces a stable result because the latent variables always learn the same features. The reconstruction error is progressively reduced by adding new latent variables until an additional degree of freedom no longer reduces it. Good robustness is achieved by adding a denoising structure [13] to PAE through a dropout layer. This lets PAE perform a self-supervised learning task [14] to understand the non-linear system better. The highly flexible structure also allows PAE to work in both supervised and unsupervised settings. In brief, PAE has two outstanding features: its latent variables are learned in decreasing order of importance, and they are highly stable. Additionally, combining PAE with β-VAE adds disentanglement ability, which yields β-PAE.
Z. Li et al.
Fig. 1. The relationships between observable variables, source variables, labels, latent variables, and representations of source variables.
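The introduction states that PAE lowers the reconstruction error progressively as latent variables are added. Sect. 2 makes this precise: each reconstruction uses only the first i latent variables. The pure-Python toy below illustrates that objective. It is our own simplification, not the authors' Progressive Patching Decoder. The decoder is replaced by a fixed linear map, so the i-th reconstruction is the sum of the first i latent values times hypothetical basis vectors. The total objective sums the per-level mean squared errors, with optional per-level weights; the weighting scheme is an assumption.

```python
def mse(x, x_hat):
    """Mean squared error between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)

def progressive_reconstructions(z, basis):
    """Toy linear stand-in for a progressive decoder: reconstruction i uses only
    the first i latent variables, x_hat_i = sum_{j < i} z[j] * basis[j]."""
    recon, out = [0.0] * len(basis[0]), []
    for z_j, b_j in zip(z, basis):
        # Patch in one more latent variable (builds a new list, so earlier levels are kept).
        recon = [r + z_j * b for r, b in zip(recon, b_j)]
        out.append(recon)
    return out

def progressive_loss(x, z, basis, weights=None):
    """Sum of (optionally weighted) MSEs over all n progressive reconstructions."""
    recons = progressive_reconstructions(z, basis)
    weights = weights or [1.0] * len(recons)
    return sum(w * mse(x, r) for w, r in zip(weights, recons))

x = [3.0, 1.0]
basis = [[1.0, 0.0], [0.0, 1.0]]
z = [3.0, 1.0]
print(progressive_reconstructions(z, basis))  # prints [[3.0, 0.0], [3.0, 1.0]]
print(progressive_loss(x, z, basis))          # prints 0.5
```

Every prefix of z must reconstruct x on its own. Intuitively, this pushes the first latent variable to carry the largest share of the signal, which gives the PCA-like ordering that PAE aims for.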
2 Framework of Progressive Autoencoder
The structure of PAE is illustrated in Fig. 2. Like a classical autoencoder, PAE consists of two parts: an encoder and a decoder. The encoder of PAE is further divided into two neural networks. Encoder1 generates the unsupervised latent variables $z_i$ ($i = 1, 2, \dots, n$), and Encoder0 is intended for supervised learning. Encoder0 can either provide $z_0$ as the result of unsupervised learning or output $\hat{y}$ under the supervision of labels. The decoder of PAE is a multi-level decoder called the Progressive Patching Decoder (PPD). It takes the latent variables of z one at a time rather than all at once. With each one, it progressively patches the median code $m_i$ with the process variable $p_i$, which is generated from z by a neural network named NN$_i$. A standard decoder then reconstructs the original data x. The main objectives of PAE are:
1. minimizing the distance (usually the mean squared error (MSE)) between the input data vector x and each reconstructed data vector $\hat{x}_i$;
2. minimizing the label loss if labels are provided;
3. minimizing the KL divergence between the prior and posterior distributions, as in VAE.
What is special about PAE is that it produces n reconstructions instead of one. Each reconstruction $\hat{x}_i$ uses the first i latent variables, from $z_0$ to $z_i$. It is as if n individual autoencoders with 1 to n latent variables were trained simultaneously with shared weights.
2.1 Encoder Network
Learning Stable Representations with Progressive Autoencoder (PAE)
Fig. 2. Structure of Progressive Autoencoder.
Encoder1 works like a VAE encoder: it learns to output a distribution over latent representations of the input data x. Its outputs are the parameters μ and σ, and the latent vector z is sampled from the Gaussian N(μ, σ²) using the reparameterization trick proposed in VAE:
$$z = \mu + \sigma \cdot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, 1) \tag{1}$$
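A minimal sketch of Eq. (1) in plain Python, assuming Encoder1 has already produced per-dimension μ and σ (with σ > 0). The function name `reparameterize` and the use of Python's `random` module are illustrative choices, not part of PAE. The noise ε is drawn outside the network, so z stays a differentiable function of μ and σ.

```python
import random
from typing import List, Sequence, Tuple

def reparameterize(mu: Sequence[float], sigma: Sequence[float],
                   rng: random.Random) -> Tuple[List[float], List[float]]:
    """Draw z = mu + sigma * eps with eps ~ N(0, 1), one value per latent dimension.

    Returns (z, eps): since dz/dmu = 1 and dz/dsigma = eps, keeping eps lets a
    training loop back-propagate through the sample.
    """
    eps = [rng.gauss(0.0, 1.0) for _ in mu]             # noise independent of the encoder
    z = [m + s * e for m, s, e in zip(mu, sigma, eps)]  # Eq. (1)
    return z, eps

# Usage: many draws of a single latent with mu = 0.5 and sigma = 2.0.
rng = random.Random(0)
draws = [reparameterize([0.5], [2.0], rng)[0][0] for _ in range(20000)]
mean = sum(draws) / len(draws)
std = (sum((d - mean) ** 2 for d in draws) / len(draws)) ** 0.5
print(round(mean, 2), round(std, 2))  # sample statistics; should approach 0.5 and 2.0
```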
Encoder0 is reserved for semi-supervised learning. Considered on its own, it is an ordinary supervised neural network. Its purpose is to extract labels (human-selected representations) from the data; when no label is available, it can also work without supervision to extract the “principal latent variable” of the data. The input of Encoder0 is the data vector x, and its output is either the predicted label vector ŷ, which has the same dimension as the label, or the principal latent variable z0. When labels are used, Encoder0 minimizes the difference between the actual label y and the output ŷ with its own loss function:
$$L_{\text{Encoder0}} = \mathbb{E}_{x,y \sim P_d}\left[(y - \hat{y})^2\right] \tag{2}$$
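Eq. (2), together with the dropout-based input corruption described in the next paragraph, can be sketched in plain Python as below. This is an illustration under assumptions: the function names are hypothetical, `corrupt_with_dropout` follows the usual inverted-dropout convention (survivors rescaled by 1/(1 − p), as framework dropout layers do in training mode), and the 0.2 drop ratio is the one listed in Table 1.

```python
import random
from typing import List, Sequence

def corrupt_with_dropout(x: Sequence[float], drop_ratio: float,
                         rng: random.Random) -> List[float]:
    """Zero each input with probability drop_ratio (simulating a missing sensor
    reading) and rescale survivors by 1 / (1 - drop_ratio), as inverted dropout does."""
    keep = 1.0 - drop_ratio
    return [0.0 if rng.random() < drop_ratio else v / keep for v in x]

def encoder0_label_loss(y: Sequence[Sequence[float]],
                        y_hat: Sequence[Sequence[float]]) -> float:
    """Eq. (2) estimated on a batch: mean squared error over samples and label dimensions."""
    total, count = 0.0, 0
    for yi, yhi in zip(y, y_hat):
        for a, b in zip(yi, yhi):
            total += (a - b) ** 2
            count += 1
    return total / count

rng = random.Random(0)
x_noisy = corrupt_with_dropout([1.0] * 48, 0.2, rng)  # one 48-sensor sample
print(sum(v == 0.0 for v in x_noisy), "of 48 inputs dropped")
print(encoder0_label_loss([[0.0, 1.0]], [[0.5, 1.0]]))  # (0.25 + 0) / 2 = 0.125
```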
With continuous labels, the mean squared error between the predicted and actual labels is used; for categorical labels, cross-entropy can be used instead. To improve the robustness of the feature extractors Encoder0 and Encoder1, noise is added to the data x before it is fed into the networks. To help handle data with missing values, this noise is injected by a dropout layer, which sets a fraction of the inputs to 0 and forces the encoders to extract representations from fewer observable variables.
2.2 Progressive Patching Decoder
The latent vector z is generated one column (latent dimension) at a time by the encoder network of PAE, and each column contributes a new degree of freedom to the reconstruction. PPD takes in the n-dimensional z and produces n results using 1, 2, ..., n latent variables. PPD is designed as a multi-level structure in which each newly added latent variable is used to “patch” the previous result and refine it with additional information. From a regular feature-fusion point of view, patching would be done either by concatenating the vectors or by adding them together. However, the latent variables cannot be added together directly, as that would narrow the autoencoder bottleneck and provide fewer degrees of freedom. Concatenation does not work either, because the input shape of a neural network layer is normally fixed; structures with adaptive shapes from the continual-learning domain might work, but they are unnecessarily complex for this task. To achieve both utility and simplicity, PPD uses a hierarchical structure: the first latent variable is converted by a neural network NN0 into the base median vector m0, each newly added latent variable is converted by NNi into a patcher vector pi that patches m0 one step at a time, and both the fully and the partially patched median vectors are used for reconstruction. The median vectors should have a dimension larger than n so that they can hold all the information from the latent variables. The patching can be done by simply adding the patcher vector to the median vector:
$$m_i = p_i + m_{i-1} \tag{3}$$
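The patching of Eq. (3), and the two-step variant of Eq. (4) given in the next paragraph, can be sketched as follows. Everything concrete here is hypothetical: `toy_patcher` stands in for a trained NNi as a linear map from the scalar zi to the two patchers, a 3-dimensional median code replaces the 50-dimensional one of Table 1, and the decoder is any shared function. The point is the control flow: one shared decoder turns every partially patched median code into its own reconstruction.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]

def toy_patcher(z: float, a: Sequence[float], b: Sequence[float]) -> Tuple[Vector, Vector]:
    """Hypothetical NN_i: maps the scalar z_i to an adder and a multiplier patcher.
    At z_i = 0 the adder is 0 and the multiplier is 1, so the patch is the identity."""
    return [ai * z for ai in a], [1.0 + bi * z for bi in b]

def patch_progressively(m0: Sequence[float],
                        patchers: Sequence[Tuple[Vector, Vector]],
                        affine: bool = True) -> List[Vector]:
    """Return [m_0, m_1, ..., m_n].
    affine=False: Eq. (3), m_i = p_i + m_{i-1} (only the adder patcher is used).
    affine=True:  Eq. (4), m_i = p_i1 + p_i2 * m_{i-1}, element-wise."""
    codes = [list(m0)]
    for p_add, p_mul in patchers:
        prev = codes[-1]
        if affine:
            codes.append([a + m * p for a, m, p in zip(p_add, p_mul, prev)])
        else:
            codes.append([a + p for a, p in zip(p_add, prev)])
    return codes

def reconstruct_all(codes: Sequence[Vector],
                    decoder: Callable[[Vector], Vector]) -> List[Vector]:
    """Shared decoder: one reconstruction x_hat_i per median code m_i."""
    return [decoder(list(m)) for m in codes]

m0 = [0.2, -0.1, 0.4]                 # hypothetical output of NN_0(z_0)
zs = [1.0, 0.5]                       # z_1, z_2
patchers = [toy_patcher(zs[0], [0.1, 0.0, -0.2], [0.0, 0.1, 0.0]),
            toy_patcher(zs[1], [0.05, 0.05, 0.05], [0.0, 0.0, 0.1])]
codes = patch_progressively(m0, patchers)
x_hats = reconstruct_all(codes, lambda m: [2.0 * v for v in m])  # toy decoder
print(len(x_hats))  # 3 reconstructions: from z_0, from z_0..z_1, from z_0..z_2
```

Because the identity patch corresponds to zi = 0, a latent variable that carries no information leaves the median code, and hence the reconstruction, unchanged.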
As we found later, a two-step linear transformation achieves the same result and converges faster:
$$m_i = p_{i1} + p_{i2} \cdot m_{i-1} \tag{4}$$
where pi1 is called the adder patcher vector, pi2 the multiplier patcher vector, and · denotes element-wise multiplication. Both patchers are generated by NNi and have the same shape as mi. pi1 is initialized to follow a zero-mean Gaussian distribution, and pi2 to have a mean of 1. NNi thus generates the progressive patchers used to compute the refined median codes; the number of NNi modules is decided according to the task. The input of NNi is the latent variable zi generated by the encoder, and its outputs are the adder and multiplier patchers pi1 and pi2.
2.3 Training Method
For an unsupervised task, PPD generates a reconstruction x̂i for each column of z, and the goal is to minimize the difference between each x̂i and the input data x. A hyperparameter ξ is included in the loss function to balance the optimization and put equal effort on each reconstruction, so that the reconstruction errors after weighting are roughly the same. Based on experimental results, we weight the terms with partial sums of a geometric series with common ratio α, normally α = 2/3. The reconstruction loss term is therefore:
$$L_{RL} = \mathbb{E}_{x \sim P_d} \sum_{i=0}^{n} (x - \hat{x}_i)^2 \cdot \xi \cdot \frac{1 - \alpha^{i+1}}{1 - \alpha} \tag{5}$$
Adding the KL divergence term, which encourages the posterior q(z|x) to be close to the Gaussian prior p(z), the complete loss function is:
$$L_{complete} = L_{RL} + \frac{1}{n} \sum_{i=0}^{n} D_{KL}\left[q(z_i \mid x) \,\|\, p(z_i)\right] \tag{6}$$
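The weighting of Eq. (5) and the averaged KL term of Eq. (6) can be sketched for a single sample as follows. Several details are our assumptions rather than the paper's: each squared error is averaged over the data dimensions, each latent posterior is a univariate Gaussian N(μi, σi²) so that its KL divergence to N(0, 1) has a closed form, and the optional `beta` factor on the KL term is our guess at how a beta-VAE-style weighting would enter beta-PAE.

```python
import math
from typing import List, Sequence

def progressive_weights(n: int, xi: float = 1.0, alpha: float = 2 / 3) -> List[float]:
    """Weights xi * (1 - alpha**(i+1)) / (1 - alpha) for i = 0..n, i.e. xi times the
    partial sums 1 + alpha + ... + alpha**i; they grow toward xi / (1 - alpha)."""
    return [xi * (1 - alpha ** (i + 1)) / (1 - alpha) for i in range(n + 1)]

def kl_to_standard_normal(mu: float, sigma: float) -> float:
    """Closed-form KL(N(mu, sigma^2) || N(0, 1))."""
    return 0.5 * (mu ** 2 + sigma ** 2 - 1.0) - math.log(sigma)

def complete_loss(x: Sequence[float], x_hats: Sequence[Sequence[float]],
                  mus: Sequence[float], sigmas: Sequence[float],
                  xi: float = 1.0, alpha: float = 2 / 3, beta: float = 1.0) -> float:
    """Eq. (6) for one sample: weighted reconstruction terms plus the KL term averaged with 1/n."""
    n = len(x_hats) - 1                      # reconstructions are x_hat_0 .. x_hat_n
    if n < 1:
        raise ValueError("need at least two reconstructions (n >= 1)")
    weights = progressive_weights(n, xi, alpha)
    rl = sum(w * sum((a - b) ** 2 for a, b in zip(x, xh)) / len(x)
             for w, xh in zip(weights, x_hats))
    kl = sum(kl_to_standard_normal(m, s) for m, s in zip(mus, sigmas)) / n
    return rl + beta * kl

print([round(w, 3) for w in progressive_weights(3)])  # [1.0, 1.667, 2.111, 2.407]
```

The later reconstructions, which use more latent variables and so have smaller errors, receive larger weights; this is how the weighting evens out the effort spent on each x̂i.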
For a supervised task, the only difference lies in the training of Encoder0, whose loss function is given in Sect. 2.1: LEncoder0 is used to update the parameters of Encoder0, and Lcomplete updates the remaining parameters of the model.
3 Experiments
To assess the quality of the proposed network, we conducted experiments on a toy dataset sampled from a custom-made non-linear system. Results of PAE under different settings are collected and compared with the classical VAE and beta-VAE frameworks. Mutual information [15] is used to measure the dependence between the extracted representations and the generative factors.
3.1 Non-linear Toy Dataset
We want to demonstrate the full potential of the Progressive Autoencoder in a setting that is clear, yet neither too easy nor too challenging. Since our ultimate aim is condition monitoring of gas turbines, we designed a non-linear toy dataset and tested our model on it. The dataset imitates gas turbine performance parameters, which are influenced by both the internal state and the environment. Furthermore, toy datasets are easy to control and allow us to discover some interesting properties of PAE.
Fig. 3. A few system outputs changing as one input varies while the other inputs are fixed at 0. Left: the result without noise. Right: what the dataset actually looks like after adding Gaussian noise with a standard deviation of 0.125.
To build the dataset, we first created a mathematical model of a custom-made non-linear, memoryless, time-invariant system and used the model to generate the data. Figure 3 shows the degree of non-linearity of the mappings used in the custom-made system. The system has 5 inputs (generative factors) and 48 outputs (observable variables); each output is determined by the 5 inputs through a random non-linear transformation with Gaussian noise. The inputs are not equivalent to each other: the “importance” of each input variable differs and is determined by its overall contribution to the outputs. A less important input variable has a smaller impact on the output variables for the same amount of change, and the non-linearity of its impact is also reduced. In our toy dataset, the 1st and 2nd input variables (generative factors) S1 and S2 have equal importance, and S3–S5 have gradually decreasing importance. The dataset is then generated from the model, with the 5 inputs independently sampled from truncated Gaussian distributions with a standard deviation of 1.
3.2 Experiment Settings
The architecture of the Progressive Autoencoder used in the experiments is shown in Table 1. Experiments for the other frameworks are also conducted on this architecture, with different numbers of NNi units. To avoid confusion, when describing the experimental results we use L1, ..., Ln to denote the 1st, ..., nth latent variables rather than zi from the architecture, and S1, ..., S5 to denote the generative factors, i.e., the inputs of the dataset.
Table 1. Progressive Autoencoder architecture for the non-linear toy dataset
Encoder0 and Encoder1: Input 48-element 1-D array → Dropout (drop ratio 0.2) → FC 64, SeLU → FC 32, SeLU → FC 2·n* (μi, σi)
NNi: Input zi ∈ R → FC 32, SeLU → FC 64, SeLU → FC 2·50 (pi1, pi2)
Decoder: Input mi ∈ R^50 → FC 64, SeLU → FC 64, SeLU → FC 48
n* is determined according to the experiment requirements.
In each experiment, 10,000 data samples are drawn from the toy dataset, and 20,000 training iterations are run (whether or not the model has converged) using the Adam optimizer with a batch size of 500 and a learning rate of 0.001. Each experiment in this section is run independently: the experiments have different random initial states and share no trainable parameters. Each experiment is repeated with different initial states to verify the stability of the results. The first two experiments use the toy dataset described in Sect. 3.1, and the last experiment tests the model on MNIST.
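The sampling step can be illustrated with a hypothetical stand-in for the system of Sect. 3.1. The paper does not give its mappings, so the tanh form, the random gains and offsets, the importance values, and the ±2 truncation bound below are all assumptions; only the 5 inputs, the 48 outputs, the unit-standard-deviation truncated Gaussian inputs, and the 0.125 noise standard deviation come from the text. Scaling each input by its importance inside tanh makes a less important input both weaker and closer to linear, mirroring the description in Sect. 3.1.

```python
import math
import random
from typing import List, Tuple

N_FACTORS = 5          # generative factors S1..S5
N_OUTPUTS = 48         # observable variables
NOISE_STD = 0.125      # as in Fig. 3
IMPORTANCE = [1.0, 1.0, 0.6, 0.35, 0.2]  # hypothetical: S1 = S2 > S3 > S4 > S5

def truncated_normal(rng: random.Random, bound: float = 2.0) -> float:
    """Unit-std Gaussian truncated to [-bound, bound] by rejection (bound is assumed)."""
    while True:
        s = rng.gauss(0.0, 1.0)
        if abs(s) <= bound:
            return s

def make_system(rng: random.Random) -> List[List[Tuple[float, float]]]:
    """Fixed random (gain, offset) per output/factor pair: memoryless and time-invariant."""
    return [[(rng.uniform(0.5, 2.0), rng.uniform(-1.0, 1.0)) for _ in range(N_FACTORS)]
            for _ in range(N_OUTPUTS)]

def system_output(params: List[List[Tuple[float, float]]], s: List[float]) -> List[float]:
    # Each output sums one tanh term per factor; the importance scales both the
    # amplitude and the argument, so weak factors act almost linearly.
    return [sum(imp * math.tanh(gain * imp * sj + offset)
                for imp, (gain, offset), sj in zip(IMPORTANCE, row, s))
            for row in params]

def sample_toy_dataset(n_samples: int, seed: int = 0):
    rng = random.Random(seed)
    params = make_system(rng)
    factors, observations = [], []
    for _ in range(n_samples):
        s = [truncated_normal(rng) for _ in range(N_FACTORS)]
        y = [v + rng.gauss(0.0, NOISE_STD) for v in system_output(params, s)]
        factors.append(s)
        observations.append(y)
    return factors, observations

factors, observations = sample_toy_dataset(1000)  # the experiments draw 10,000
print(len(observations), "samples of", len(observations[0]), "observables")
```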
Fig. 4. (a) Distribution of the adder patcher (blue) and the multiplier patcher (orange) described in Sect. 2.2. (b) Reconstruction loss reduced by each latent variable. (c) and (d) Non-linear mappings between L1–L6 and S1–S5 in two separate experiments. (Color figure online)
3.3 Results of Toy Dataset
Figure 4 shows the results of PAE with 6 latent variables. As shown in Fig. 4a, the distributions of all patcher components gradually narrow toward the center, meaning that each new latent variable contributes less to the final result than the previous one. This is also reflected in Fig. 4b: the reconstruction error decreases as latent variables are added, but by gradually smaller amounts. Once L6 is added, the loss practically stops decreasing, because there are only five generative factors and the first five latent variables have already learned almost all of their features; the 6th latent variable is redundant and can hardly refine the result further. Figures 4c and 4d, which come from two independent experiments with different random initializations, present the relationship between S and L. In the title of each subplot, “Si Lj” means that the horizontal axis is Si and the vertical axis is Lj, and MI (mutual information) measures the dependence between Si and Lj. Comparing the two figures, the corresponding subplots are highly similar, which shows that latent variables in the same position consistently learn the same features; this is the stability described above. However, the features learned by PAE are still entangled; to address this, the β-VAE structure is incorporated in a later experiment. Figure 5 shows the results of supervised PAE with 6 latent variables. The first two latent variables learn good features under the supervision of S1 and S2, and the remaining latent variables provide further evidence of the stability of PAE.
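The paper measures the dependence between Si and Lj with the k-nearest-neighbor mutual information estimator of Kraskov et al. [15]. As a simpler and clearly different stand-in, the sketch below computes a plug-in estimate from an equal-width 2-D histogram; it is biased and sensitive to the bin count, so it only illustrates what an MI value in the subplot titles summarizes. The function names and the default of 10 bins are our own choices.

```python
import math
from collections import Counter
from typing import List, Sequence

def equal_width_bins(values: Sequence[float], bins: int) -> List[int]:
    lo, hi = min(values), max(values)
    width = (hi - lo) / bins or 1.0            # constant input: everything lands in bin 0
    return [min(int((v - lo) / width), bins - 1) for v in values]

def binned_mutual_information(xs: Sequence[float], ys: Sequence[float], bins: int = 10) -> float:
    """Plug-in MI in nats: sum over cells of p(i,j) * log(p(i,j) / (p(i) * p(j)))."""
    bx, by = equal_width_bins(xs, bins), equal_width_bins(ys, bins)
    n = len(bx)
    joint = Counter(zip(bx, by))
    px, py = Counter(bx), Counter(by)
    return sum(c / n * math.log(c * n / (px[i] * py[j])) for (i, j), c in joint.items())

xs = list(range(1000))
print(round(binned_mutual_information(xs, xs), 3))  # identical variables: ln(10) = 2.303
a = [i % 10 for i in range(1000)]
b = [i // 10 for i in range(1000)]
print(round(binned_mutual_information(a, b), 3))    # independent by construction: 0.0
```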
Fig. 5. Demonstration of the results of supervised PAE.
Table 2 shows the reconstruction errors of all PAE experiments, including β-PAE and supervised PAE, with different numbers of latent variables available. For a given number of latent variables used (a column of Table 2), the reconstruction errors of the plain PAE models are similar and increase slightly as the model has more latent variables in total. The L2 reconstruction error of supervised PAE is larger than the others because what its latent variables learn is restricted by supervision, while its L6 reconstruction error is the smallest because it can learn better with the “patches” of the other latent variables. The next experiment concerns β-PAE, which sacrifices some reconstruction ability to gain disentanglement ability.
Table 2. Reconstruction errors (R.E) for different architectures with different numbers of latent variables (L.V)
- | R.E L1 | R.E L2 | R.E L3 | R.E L4 | R.E L5 | R.E L6
PAE - 1 L.V | 0.422
PAE - 2 L.V | 0.431 | 0.305
PAE - 3 L.V | 0.455 | 0.308 | 0.235
PAE - 4 L.V | 0.448 | 0.315 | 0.226 | 0.188
PAE - 5 L.V | 0.472 | 0.335 | 0.231 | 0.183 | 0.161
PAE - 6 L.V | 0.484 | 0.348 | 0.235 | 0.179 | 0.149 | 0.147
β-PAE - 6 L.V | 0.565 | 0.426 | 0.331 | 0.253 | 0.186 | 0.183
Supervised-PAE | - | 0.450 | 0.281 | 0.194 | 0.143 | 0.141
3.4 Comparison Between Beta-PAE and Beta-VAE
By combining beta-VAE and the Progressive Autoencoder, we can learn disentangled representations and rank them by importance at the same time, which could be useful in industrial applications. The results of beta-PAE and beta-VAE are shown in Fig. 6.
Fig. 6. The result of beta-VAE and Progressive Autoencoder with 6 latent variables.
Comparing Fig. 6a with Figs. 5c and 5d, we find that beta-PAE gains the ability to disentangle. Comparing Figs. 6a and 6b, although beta-VAE separates the generative factors slightly better, beta-PAE is stable: it always learns the generative factors in a fixed sequence, in decreasing order of importance, whereas the result of beta-VAE differs from run to run.
3.5 Experiments on MNIST
Fig. 7. Results of β-PAE when one latent variable is varied at a time while the others are fixed
The experiments on the toy dataset show the characteristics of β-PAE. In this part, we demonstrate the properties of β-PAE in a more intuitive way. For the MNIST dataset, β-PAE is first trained with supervision to generate the digits, and then the unsupervised approach is applied to learn the other properties of the handwritten digits. We vary each latent variable of the trained model from −1 to 1 while fixing the other latent variables at 0; the result is shown in Fig. 7, where the panels are ordered by the sequence of latent variables. For example, in Fig. 7a, the first and most important latent variable is varied from −1 to 1 while the others remain 0. The first three properties learned are human-interpretable (angle, width, and thickness), while the properties learned later are not so easily described. Judging from experience, these three factors may well be the most important factors for digit reconstruction.
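The traversal behind Fig. 7 only needs a grid of latent codes; the sketch below builds it, assuming the trained β-PAE decoder (not shown) is then applied to each code. The `base` argument is our addition for holding the remaining latent variables, for example a supervised digit code, at chosen values; it defaults to all zeros as in the text, and the number of steps is arbitrary.

```python
from typing import List, Optional, Sequence

def traversal_codes(n_latents: int, index: int, steps: int = 7,
                    low: float = -1.0, high: float = 1.0,
                    base: Optional[Sequence[float]] = None) -> List[List[float]]:
    """Latent vectors that sweep latent `index` linearly from low to high while
    every other latent stays at its `base` value (0 by default)."""
    if steps < 2:
        raise ValueError("steps must be at least 2")
    fixed = list(base) if base is not None else [0.0] * n_latents
    codes = []
    for k in range(steps):
        z = list(fixed)
        z[index] = low + (high - low) * k / (steps - 1)
        codes.append(z)
    return codes

# One row per latent variable, most important first; decoding each code with the
# trained decoder would give an image grid like Fig. 7.
grid = [traversal_codes(6, i) for i in range(6)]
print(grid[0][0], grid[0][-1])
```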
Fig. 8. (a) Result of β-PAE with one latent variable; (b)–(d) results of β-PAE after adding one more latent variable each time; (e) the input, placed here for comparison
β-PAE can also generate digits using only the first few latent variables, yielding “intermediate products” as the reconstruction is refined. Figure 8 shows how each latent variable refines the reconstruction. At first, PAE captures only which digit it is, so every image of the same digit looks the same. As new latent variables are added, the model acquires the angle, width, and thickness of the digit, and the reconstruction becomes increasingly clear.
4 Discussion and Conclusion
In this paper, we proposed a new autoencoder-based model called the Progressive Autoencoder. Thanks to its Progressive Patching Decoder structure, the latent variables of this model are ordered by importance and are highly stable. The Progressive Autoencoder can also gain disentanglement ability by incorporating the idea of β-VAE. We created a toy dataset to demonstrate the properties of PAE; the results indicate that the features learned by PAE are stable and ordered as intended. We also showed how to use PAE for a supervised task and designed experiments to demonstrate its effect. The comparison of β-PAE and β-VAE shows that β-PAE combines stability with disentanglement. Finally, the experiments on MNIST show intuitively how PAE works, with visible variations of the learned features.
Overall, the Progressive Autoencoder is a flexible architecture that can be used in many real-life scenarios, such as condition monitoring of gas turbines, and can help analyze non-linear systems; for reasons of space, this paper explains only its theory and characteristics rather than these applications. Here, we use only multi-layer perceptrons as substructures of PAE. Combined with other structures, such as convolutional neural networks or attention mechanisms, PAE may become stronger and applicable to a variety of fields.
References
1. Rifai, S., Vincent, P., Muller, X., Glorot, X., Bengio, Y.: Contractive auto-encoders: explicit invariance during feature extraction. In: ICML (2011)
2. Kingma, D.P., Welling, M.: Auto-Encoding Variational Bayes. CoRR, abs/1312.6114 (2014)
3. Kingma, D.P., Mohamed, S., Rezende, D.J., Welling, M.: Semi-supervised learning with deep generative models. In: NIPS (2014)
4. Higgins, I., et al.: beta-VAE: learning basic visual concepts with a constrained variational framework. In: ICLR (2017)
5. Burgess, C., et al.: Understanding disentangling in beta-VAE. arXiv:MachineLearning (2018)
6. Chen, T.Q., Li, X., Grosse, R.B., Duvenaud, D.: Isolating Sources of Disentanglement in Variational Autoencoders. In: NeurIPS (2018)
7. Esmaeili, B., et al.: Structured disentangled representations. In: AISTATS (2019)
8. Goodfellow, I.J., et al.: Generative adversarial nets. In: NIPS (2014)
9. Mirza, M., Osindero, S.: Conditional Generative Adversarial Nets. arXiv:abs/1411.1784 (2014)
10. Donahue, J., Krähenbühl, P., Darrell, T.: Adversarial Feature Learning. arXiv:abs/1605.09782 (2017)
11. Chen, X., Duan, Y., Houthooft, R., Schulman, J., Sutskever, I., Abbeel, P.: InfoGAN: interpretable representation learning by information maximizing generative adversarial nets. In: NIPS (2016)
12. Lin, Z., Thekumparampil, K.K., Fanti, G., Oh, S.: InfoGAN-CR and ModelCentrality: self-supervised model training and selection for disentangling GANs. In: ICML (2020)
13. Bengio, Y., Yao, L., Alain, G., Vincent, P.: Generalized denoising auto-encoders as generative models. In: NIPS (2013)
14. LeCun, Y., Misra, I.: Self-supervised learning: the dark matter of intelligence. https://ai.facebook.com/blog/self-supervised-learning-the-dark-matter-of-intelligence/. Accessed 18 Mar 2022
15. Kraskov, A., Stögbauer, H., Grassberger, P.: Estimating mutual information. Phys. Rev. E 69(6), 066138 (2004)
Effect of Image Down-sampling on Detection of Adversarial Examples
Anjie Peng1,2, Chenggang Li1, Ping Zhu3(B), Zhiyuan Wu1, Kun Wang1, Hui Zeng1(B), and Wenxin Yu1
1 Southwest University of Science and Technology, Sichuan, MY 621010, China
2 Science and Technology on Communication Security Laboratory, Sichuan, CD 610041, China
3 Chengdu University of Information Technology, Sichuan, CD 610103, China
Abstract. Detecting adversarial examples and refusing to feed them into a CNN classifier is a crucial defense that prevents the CNN from being fooled by them. Considering that attackers usually use down-sampling to match the input size of a CNN, and that detection methods are commonly evaluated on down-sampled images, we study how the detectability of adversarial examples is affected by the interpolation algorithm when the legitimate image is down-sampled before being attacked. Since down-sampling changes the relationships among neighboring pixels, steganalysis-based detectors that rely on neighborhood dependencies are likely to be strongly affected. Experimental results on ImageNet verify that the detection accuracy varies dramatically across interpolation kernels (the largest difference in accuracy is about 9%), and this phenomenon holds across the tested CNN models and attack algorithms for the steganalysis-based detection method. Our work is of interest to both attackers and defenders, for benchmarking attack algorithms and detection methods, respectively.
Keywords: Convolution Neural Network · Adversarial Image · Detection · Down-sampling
1 Introduction
Adversarial examples have drawn attention to the security of convolutional neural network (CNN) classifiers. Adversarial attacks, such as FGSM [1], BIM [2], DeepFool [3], BP [4], and C&W [5], carefully craft imperceptible perturbations on a legitimate image to generate an adversarial image that effectively forces the CNN to misclassify it, away from its ground-truth label. Such attacks pose security threats to CNN-based applications, especially in security-sensitive fields such as self-driving cars [6] and robots [7]. How to harden CNNs against adversarial attacks [8–13] is therefore an active research topic.
This work was partially supported by NSFC (No. 41865006) and the Sichuan Science and Technology Program (No. 2022YFG0321, 2022NSFSC0916).
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 550–561, 2023. https://doi.org/10.1007/978-981-99-1639-9_46
Fig. 1. The flow chart of a detector defense against the adversarial image. The detector defense filters out the adversarial image, and only feeds the clean image into the CNN for classification. In this work, we focus on how the pre-processing down-sampling in the process of generating an adversarial image (dashed box) affects the detectability of the detector defense.
Besides adversarial training [8,9,28], detecting adversarial images and filtering them out before they are fed into the CNN is another important defense, as illustrated in Fig. 1. Input-transformation and steganalysis-based methods are two typical kinds of detection algorithms. Since adversarial perturbations are not robust to image transformations, input-transformation methods feed both a questioned image and a carefully transformed version of it into the CNN, and flag the questioned image as adversarial if the CNN outputs before and after the transformation are inconsistent; examples include the denoising filter composed of scalar quantization and a smoothing spatial filter [14], feature squeezing (FS for short) [15], image quilting [16], and image resampling [31]. As Goodfellow et al. [1] indicated that an adversarial attack can be treated as a sort of accidental steganography, several steganalysis-based methods have also been proposed [17–21]. In this work, we study how pre-processing down-sampling affects the detectability of adversarial images. Unlike the post-processing resampling used in input-transformation detection [31], the down-sampling in our work is a pre-processing operation. We consider this an important topic for two reasons. (1) Down-sampling is commonly used when generating adversarial images, as shown in Fig. 1. To save computational resources in deep architectures, the input size of a CNN is usually small; for example, ResNet [22] models accept RGB inputs of size 224 × 224 × 3. To match this small input size, the image needs to be down-sampled before being attacked, and different adversarial platforms employ different down-sampling algorithms, for example Cleverhans¹ (bilinear), EvadeML² (nearest), RealSafe³ (bilinear), Foolbox⁴ (bicubic), and Advertorch⁵ (bilinear).
¹ https://github.com/cleverhans-lab/cleverhans.
² https://evadeML.org/zoo.
³ https://github.com/thu-ml/ares.
⁴ https://github.com/bethgelab/foolbox.
⁵ https://github.com/BorealisAI/advertorch.
A. Peng et al.
(2) Down-sampling matters for benchmarking detection methods. As Fig. 1 shows, down-sampling may be a factor affecting the detectability of the detector defense, yet many detection methods [1,14–21] are evaluated on down-sampled adversarial images without considering its effects. To the best of our knowledge, the role of pre-processing down-sampling and its influence on detection methods have not been studied so far. Many works [23–26] have analyzed the impact of pre-processing and post-processing on steganalysis and forensics. Inspired by these works, we consider three typical interpolation algorithms; two typical attacks, BIM [2] and C&W [5]; two CNN models, ResNet-50 [22] and Inception-V3 [29]; and two typical detection methods, ESRM [19] and FS [15]. Experimental results reveal that, for the state-of-the-art steganalysis-based method ESRM [19], the detection accuracy varies quite dramatically across interpolation kernels and attack parameters. These results may have implications for attackers and defenders, and help them develop their own optimal strategies to evade detection or improve defense ability.
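Feature squeezing [15] is one instance of the input-transformation idea described above. In the formulation of its authors, the detector compares the model's softmax output on the original input with its output on each squeezed version, takes the largest L1 distance as the score, and flags the input when the score exceeds a threshold chosen on legitimate images. The sketch below follows that description with a single bit-depth-reduction squeezer; the linear `toy_logits` model, its weights, and the 4-pixel images are hypothetical.

```python
import math
from typing import Callable, List, Sequence

Model = Callable[[Sequence[float]], List[float]]
Squeezer = Callable[[Sequence[float]], List[float]]

def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)                               # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def reduce_bit_depth(image: Sequence[float], bits: int) -> List[float]:
    """Squeezer: quantize pixels in [0, 1] to 2**bits levels."""
    levels = 2 ** bits - 1
    return [round(p * levels) / levels for p in image]

def fs_score(model: Model, image: Sequence[float], squeezers: Sequence[Squeezer]) -> float:
    """Largest L1 distance between predictions on the input and on each squeezed input."""
    p = softmax(model(image))
    return max(sum(abs(a - b) for a, b in zip(p, softmax(model(sq(image)))))
               for sq in squeezers)

def is_adversarial(model: Model, image: Sequence[float],
                   squeezers: Sequence[Squeezer], threshold: float) -> bool:
    return fs_score(model, image, squeezers) > threshold

WEIGHTS = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]  # hypothetical 2-class linear model

def toy_logits(image: Sequence[float]) -> List[float]:
    return [sum(w * p for w, p in zip(row, image)) for row in WEIGHTS]

squeezers = [lambda img: reduce_bit_depth(img, 1)]
clean = [0.0, 1.0, 1.0, 0.0]          # already binary: squeezing changes nothing
perturbed = [0.4, 0.6, 0.6, 0.4]      # squeezing snaps it back to `clean`
print(fs_score(toy_logits, clean, squeezers), fs_score(toy_logits, perturbed, squeezers))
```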
2 Motivation Experiment
Jan Kodovský et al. [23,24] found that down-sampling remarkably affects steganalysis results. As steganography versus steganalysis is analogous to adversarial attack versus detection defense [17], we also study how down-sampling affects the detectability of adversarial images. To motivate our study, we select 1000 images from the validation set of ImageNet-1000 (ILSVRC-2012) as the source image database. Next, we prepare three down-sampled databases generated with three commonly used interpolation kernels, nearest, bilinear, and lanczos, using the resizing function Resize(·) in PyTorch. All source images are down-sampled so that the shorter side is 224 pixels, and finally center-cropped to 224 × 224 pixels. Fig. 2 shows the results of ESRM [19] detecting untargeted BIM adversarial images generated on ResNet-50. For each attack strength budget ε, an ESRM detector is constructed by training the ensemble classifier [27] on ESRM features. Half of the images are used for training and the other half for testing, and performance is evaluated by the detection accuracy (Acc) on equal numbers of legitimate and adversarial images. The results in Fig. 2 show striking discrepancies in Acc across the differently down-sampled databases. For example, at attack strength budget ε = 1, the Acc for the nearest kernel is 83.1%, about 10.0% lower than the Acc for the bilinear kernel. These results indicate that the choice of interpolation kernel significantly affects detectability, so a deeper understanding of this phenomenon is important for fairly benchmarking detection methods.
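The resize-then-crop pre-processing can be checked with a few lines of arithmetic. The sketch below computes only the output geometry, not pixel values. Truncating the longer side to an int and rounding the crop offsets is our reading of how torchvision's `Resize` (with an int size) and `CenterCrop` behave, so treat these details as assumptions and verify them against the version in use. The printed factors also show that the 299-pixel setting for Inception-V3 down-scales less than the 224-pixel setting for ResNet-50.

```python
from typing import Tuple

def resized_size(width: int, height: int, short_side: int) -> Tuple[int, int]:
    """(width, height) after scaling so the shorter side equals short_side;
    the longer side keeps the aspect ratio and is truncated to an int."""
    if width <= height:
        return short_side, int(short_side * height / width)
    return int(short_side * width / height), short_side

def center_crop_box(width: int, height: int, crop: int) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) of a crop x crop window centered in the image."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop

for w, h in [(500, 375), (500, 333), (375, 500)]:
    for side in (224, 299):                     # ResNet-50, Inception-V3
        rw, rh = resized_size(w, h, side)
        factor = min(w, h) / side               # down-scaling factor of the shorter side
        print((w, h), side, (rw, rh), center_crop_box(rw, rh, side), round(factor, 3))
```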
Fig. 2. The accuracies (Accs, %) of ESRM detecting BIM adversarial images against ResNet-50, created from images down-sampled with three different interpolation kernels. The x axis is the attack strength budget ε.
3 Further Investigation
Inspired by the results in Fig. 2, we select two adversarial attacks (BIM [2] and C&W [5]) against two CNN classifiers (ResNet-50 [22] and Inception-V3 [29]), and further investigate how the interpolation kernel affects two detection methods (ESRM [19] and FS [15]) on a larger image database.
3.1 Down-Sampling Algorithm
Image down-sampling is executed before attacking the CNN classifiers, as shown in Fig. 1. The down-sampling process works as follows: (1) determine the positions of the new pixels based on the scaling factor; (2) feed the distances between each new pixel and its neighboring old pixels into the interpolation kernel to compute weights; (3) take the weighted sum of the intensities of the neighboring old pixels as the value of the new pixel. The interpolation kernel and the scaling factor are thus the two primary factors. In this work, we focus on the interpolation kernel and use the PyTorch function Resize(·) with three commonly used interpolation kernels: nearest neighbor ϕn, bilinear ϕb, and lanczos ϕl. As formulas (1)–(3) indicate, the bilinear and lanczos kernels take more neighboring pixels into account than the nearest neighbor kernel. The bilinear and lanczos kernels are therefore expected to create stronger dependencies among neighboring pixels in the down-sampled images than the nearest neighbor kernel does.
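The three steps can be made concrete with a one-dimensional resampler built on kernels (1)–(3) given below. Two choices are our assumptions: pixel centers sit at half-integer positions, and, for down-scaling, the bilinear and lanczos kernels are stretched by the scaling factor (Pillow-style antialiasing) while nearest neighbor is not; the weights are also normalized to sum to 1. Real image libraries work on 2-D arrays and differ in such details, so this is an illustration, not a reimplementation of PyTorch's `Resize`. On i.i.d. noise, consecutive bilinear outputs share input pixels and become correlated, whereas nearest-neighbor outputs are copies of distinct input pixels and stay uncorrelated: a simple instance of the neighboring-pixel dependency discussed above.

```python
import math
import random
from typing import Callable, Dict, List, Sequence, Tuple

def phi_nearest(x: float) -> float:
    """Formula (1): unit-width box."""
    return 1.0 if -0.5 <= x < 0.5 else 0.0

def phi_bilinear(x: float) -> float:
    """Formula (2): triangle (tent) kernel."""
    return 1.0 - abs(x) if abs(x) <= 1.0 else 0.0

def phi_lanczos(x: float) -> float:
    """Formula (3): Lanczos kernel with support 2."""
    if x == 0.0:
        return 1.0
    if abs(x) < 2.0:
        return 2.0 * math.sin(math.pi * x) * math.sin(math.pi * x / 2.0) / (math.pi ** 2 * x ** 2)
    return 0.0

# name -> (kernel, half-width of its support)
KERNELS: Dict[str, Tuple[Callable[[float], float], float]] = {
    "nearest": (phi_nearest, 0.5),
    "bilinear": (phi_bilinear, 1.0),
    "lanczos": (phi_lanczos, 2.0),
}

def resample_1d(signal: Sequence[float], out_len: int, kernel_name: str) -> List[float]:
    kernel, support = KERNELS[kernel_name]
    scale = len(signal) / out_len                  # > 1 means down-sampling
    stretch = 1.0 if kernel_name == "nearest" else max(scale, 1.0)
    radius = support * stretch
    out = []
    for j in range(out_len):
        center = (j + 0.5) * scale                 # step (1): new pixel position
        lo = max(0, math.floor(center - radius - 0.5))
        hi = min(len(signal), math.ceil(center + radius + 0.5))
        # Step (2): distance to each neighboring old pixel -> kernel weight.
        taps = [(i, kernel((i + 0.5 - center) / stretch)) for i in range(lo, hi)]
        total = sum(w for _, w in taps)
        # Step (3): weighted sum of old intensities (normalized, an assumption).
        out.append(sum(w * signal[i] for i, w in taps) / total)
    return out

def lag1_corr(values: Sequence[float]) -> float:
    """Pearson correlation between each output sample and its right neighbor."""
    a, b = values[:-1], values[1:]
    ma, mb = sum(a) / len(a), sum(b) / len(b)
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    var_a = sum((x - ma) ** 2 for x in a)
    var_b = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)

rng = random.Random(0)
noise = [rng.gauss(0.0, 1.0) for _ in range(22000)]   # i.i.d. "image row"
corr_nearest = lag1_corr(resample_1d(noise, 10000, "nearest"))
corr_bilinear = lag1_corr(resample_1d(noise, 10000, "bilinear"))
print(round(corr_nearest, 3), round(corr_bilinear, 3))
```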
The resolutions of the images in the validation set of ImageNet-1000 (ILSVRC-2012), such as 500 × 375 × 3, 500 × 333 × 3, 375 × 500 × 3, and 500 × 500 × 3, are larger than the input sizes of the ResNet-50 [22] and Inception-V3 [29] models we use. To match these input sizes, we first down-sample the shorter side to 224 and 299, respectively, and then crop the center part to form resized images of size 224 × 224 × 3 for ResNet-50 and 299 × 299 × 3 for Inception-V3.
$$\phi_n(x) = \begin{cases} 1, & -\tfrac{1}{2} \le x < \tfrac{1}{2} \\ 0, & \text{otherwise} \end{cases} \tag{1}$$
$$\phi_b(x) = \begin{cases} 1 - |x|, & |x| \le 1 \\ 0, & \text{otherwise} \end{cases} \tag{2}$$
$$\phi_l(x) = \begin{cases} 1, & x = 0 \\ \dfrac{2\sin(\pi x)\sin(\pi x/2)}{\pi^2 x^2}, & 0 < |x| < 2 \\ 0, & \text{otherwise} \end{cases} \tag{3}$$
3.2 Adversarial Attacks
After creating the down-sampled images that match the input size of the CNN, adversarial images are generated on the Advertorch platform⁶. Two typical attack algorithms, BIM [2] and C&W [5], are used to attack the commonly used pre-trained CNN models ResNet-50 [22]⁷ and Inception-V3 [29]⁸. BIM [2] is an iterative gradient-based attack. For an image x with label y^true, initializing x_0^adv = x, it can be formulated as (4), where f(·) is the CNN classifier, ∇(J(·)) is the gradient of the loss function J(·), and clip_{x,ε}(·) keeps the perturbation within ε. We set the number of iterations to 10, α = 1, and ε = 1, 3, 5, so that the attacks succeed while the perturbations added to the legitimate image remain imperceptible. C&W [5] is an optimization-based attack. It solves problem (5) to generate adversarial images, where c is a hyperparameter balancing the L2 distance and the prediction function F(x^adv) = max(Z(x^adv)_{l*} − max{Z(x^adv)_i : i ≠ l*}, −κ) for the untargeted attack, with Z(·) the logits; minimizing F pushes the logit of the true label l* below the largest other logit. We set the confidence κ = 0, 10, 20 and use the default values of the other parameters on the Advertorch platform.
$$x_{i+1}^{adv} = \mathrm{clip}_{x,\epsilon}\left(x_i^{adv} + \alpha \cdot \mathrm{sign}\left(\nabla_{x_i^{adv}} J(f(x_i^{adv}), y^{true})\right)\right) \tag{4}$$
$$\arg\min_{x^{adv}} \|x^{adv} - x\|_2 + c\,F(x^{adv}) \quad \text{s.t.} \quad x_{ij}^{adv} \in [0, 1] \tag{5}$$
⁶ https://github.com/BorealisAI/advertorch.
⁷ https://download.pytorch.org/models/resnet50-0676ba61.pth.
⁸ https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth.
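The BIM update of Eq. (4) can be sketched on a hypothetical stand-in for the CNN: a logistic-regression model on four pixel values, whose cross-entropy gradient with respect to the input is (p − y)·w in closed form. The weights, the image, and working on a 0–255 pixel scale (so that ε = 3 and α = 1 read as intensity levels) are assumptions for illustration; with a real CNN the gradient comes from back-propagation. The clip also keeps pixels inside the valid range, a common companion to the ε-ball clip.

```python
import math
from typing import Callable, List, Sequence

W = [0.02, -0.01, 0.03, -0.02]   # hypothetical model weights
B = -1.0                          # hypothetical bias

def prob_class1(x: Sequence[float]) -> float:
    return 1.0 / (1.0 + math.exp(-(sum(w * v for w, v in zip(W, x)) + B)))

def grad_cross_entropy(x: Sequence[float], y_true: int) -> List[float]:
    """d/dx of the binary cross-entropy J(f(x), y_true) for the logistic model: (p - y) * w."""
    p = prob_class1(x)
    return [(p - y_true) * w for w in W]

def sign(v: float) -> int:
    return (v > 0) - (v < 0)

def bim(x: Sequence[float], y_true: int,
        grad_fn: Callable[[Sequence[float], int], List[float]],
        eps: float, alpha: float = 1.0, iters: int = 10,
        lo: float = 0.0, hi: float = 255.0) -> List[float]:
    """Eq. (4): signed-gradient ascent steps on the loss, each followed by clipping
    to the eps-ball around the clean image x and to the pixel range [lo, hi]."""
    x_adv = list(x)
    for _ in range(iters):
        g = grad_fn(x_adv, y_true)
        x_adv = [min(max(xa + alpha * sign(gi), xo - eps, lo), xo + eps, hi)
                 for xa, gi, xo in zip(x_adv, g, x)]
    return x_adv

x = [100.0, 120.0, 80.0, 90.0]
x_adv = bim(x, y_true=1, grad_fn=grad_cross_entropy, eps=3.0)
print(x_adv, round(prob_class1(x), 3), round(prob_class1(x_adv), 3))
```

With ε = 3 every pixel saturates at the ε-ball boundary after three steps, and the remaining iterations only re-clip; larger ε leaves room for more change, which is why the attack strength matters for detection.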
3.3 Detection Results of ESRM for Different Interpolation Kernels
ESRM [19] is a steganalysis-based detection method. It enhances the steganalysis feature SRM [30] by considering the modification probability of each pixel and assigning larger weights to probably modified pixels when computing co-occurrences. As the ESRM feature is composed of co-occurrences of multiple high-frequency residuals, it has as many as 34,671 dimensions and is sensitive to dependencies among neighboring pixels. The FLD ensemble classifier with default settings [19,27] is used to construct a binary detector that distinguishes adversarial images from legitimate images. To evaluate the detection accuracy of ESRM, we randomly select 5000 images, 5 per class, from the validation set of ImageNet-1000 (ILSVRC-2012) as the source dataset. The ratio of training to testing samples is 1:1. Table 1 reports the results of ESRM detecting BIM and C&W adversarial images that attack ResNet-50. The Acc varies sharply across the differently down-sampled adversarial images. For each attack strength, the Acc for bilinear down-scaled images is the highest, while that for nearest neighbor down-scaled images is the lowest. The minimum and maximum differences between them are 4.15% (BIM, ε = 5) and 8.3% (BIM, ε = 1), respectively. The main reason for these discrepancies is that different interpolation kernels produce different dependencies among neighboring pixels, which in turn lead to different detection accuracies for ESRM. When down-scaling with the nearest interpolation kernel, some original pixels are skipped, and each resized pixel takes the value of the nearest original pixel according to formula (1). With the bilinear and lanczos kernels, however, each pixel of the down-scaled image is interpolated as a linear combination of the original pixel values.
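To see why such features react to neighboring-pixel dependencies, the sketch below computes a much-simplified SRM-style feature: a first-order horizontal residual, quantized with step q and truncated to [−T, T], followed by a co-occurrence histogram of adjacent residual pairs. The single residual type, the pair order, and q = 1, T = 2 are illustrative assumptions; SRM [30] combines many residual types and higher-order co-occurrences, and the probability-based weighting that distinguishes ESRM is not modeled here.

```python
from collections import Counter
from typing import List, Sequence

def horizontal_residual(image: Sequence[Sequence[float]]) -> List[List[float]]:
    """First-order residual r[i][j] = x[i][j+1] - x[i][j] (predict a pixel by its left neighbor)."""
    return [[row[j + 1] - row[j] for j in range(len(row) - 1)] for row in image]

def quantize_truncate(r: float, q: float = 1.0, T: int = 2) -> int:
    return max(-T, min(T, round(r / q)))

def cooccurrence(image: Sequence[Sequence[float]], q: float = 1.0,
                 T: int = 2, order: int = 2) -> Counter:
    """Counts of `order` horizontally adjacent quantized residuals; at most (2T+1)**order bins."""
    counts: Counter = Counter()
    for row in horizontal_residual(image):
        qr = [quantize_truncate(v, q, T) for v in row]
        for j in range(len(qr) - order + 1):
            counts[tuple(qr[j:j + order])] += 1
    return counts

image = [[10, 11, 13, 13, 12],
         [50, 50, 51, 60, 60]]
feature = cooccurrence(image)
print(sorted(feature.items()))
```

A perturbation that changes residual signs or magnitudes moves counts between bins; the more regular the residuals of the clean image, the more visible such changes become in the histogram.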
With all other down-sampling parameters fixed, the bilinear and lanczos kernels therefore produce stronger neighboring-pixel dependencies than the nearest kernel. As indicated in steganalysis [23,24,30], stronger dependencies are disturbed more when adversarial perturbations are added to a legitimate image to generate an adversarial image. Hence, attacking a bilinear- or lanczos-down-scaled image alters the neighboring-pixel dependencies more than attacking a nearest-neighbor down-scaled image. Since the ESRM feature is built on these dependencies, the ESRM detector is expected to detect bilinear- and lanczos-resized adversarial images better than nearest-resized ones, as empirically verified in Table 1. To further verify this analysis, we repeat the experiments with ESRM detecting adversarial images that attack Inception-V3. The results in Table 2 likewise show that the Acc varies across the differently down-sampled adversarial images, increasing in the order nearest, lanczos, bilinear. The down-scaling factor used for Inception-V3 is smaller than that used for ResNet-50, so the changes in neighboring-pixel dependencies are weaker. Consequently, compared with Table 1, the Acc discrepancies among the three kinds of down-scaled adversarial images are smaller, as shown in Table 2.

The results in Tables 1 and 2 indicate that the ESRM feature, being based on neighboring-pixel dependencies, is affected by the interpolation kernel. The ESRM detector therefore performs differently on adversarial images generated from legitimate images that were down-sampled with different kernels. We also find that BIM adversarial images are easier to detect than C&W adversarial images, because BIM adds larger perturbations to the legitimate image and thus disturbs the neighboring-pixel dependencies more. For the same reason, adversarial images generated with a stronger attack (larger ε or κ) are also easier to detect, as shown in Tables 1 and 2.

A. Peng et al.

Table 1. The Acc (%) of ESRM detecting different down-sampled adversarial images attacking ResNet-50.

BIM
Interpolation kernel   ε = 1   ε = 3   ε = 5
nearest                83.78   92.48   94.82
bilinear               92.08   92.08   98.97
lanczos                86.48   86.48   97.83

C&W
Interpolation kernel   κ = 0   κ = 10  κ = 20
nearest                54.31   67.51   74.86
bilinear               59.37   59.37   82.33
lanczos                82.33   70.49   77.87

Table 2. The Acc (%) of ESRM detecting different down-sampled adversarial images attacking Inception-V3.

BIM
Interpolation kernel   ε = 1   ε = 3   ε = 5
nearest                89.28   94.71   95.95
bilinear               91.86   97.83   98.38
lanczos                89.84   96.39   97.50

C&W
Interpolation kernel   κ = 0   κ = 10  κ = 20
nearest                58.44   69.71   79.37
bilinear               63.54   75.89   85.31
lanczos                60.61   72.72   82.28

3.4 Detection Results of FS for Different Interpolation Kernels
FS [15] is a typical transformation-based method. It is built on the assumption that a legitimate image is more robust to image transformations than an adversarial image. For a questioned image, FS proceeds as follows. (1) Apply bit depth reduction, median filtering, and non-local means successively to squeeze the input space. (2) Feed the original image and its squeezed version to the CNN to obtain the corresponding SoftMax outputs. (3) Calculate the L1 distance between the two SoftMax outputs. (4) Compare the distance with a threshold T: an image whose distance is larger than T is predicted as adversarial, otherwise as legitimate. The experiments use the default best joint detection setting of FS, in which T is determined by fixing FPR = 5% (i.e., at most 5% of legitimate images are misclassified as adversarial). We randomly select 11,000 images, 11 per class, from ILSVRC-2012 as the source dataset. Half of the legitimate images are used to determine T. We report the detection accuracy on the other half together with their corresponding adversarial images. The results in Tables 3 and 4 show that, at the same attack strength, the Acc differs for nearest, bilinear, and lanczos resized images. The greatest difference occurs when detecting C&W against ResNet-50 with κ = 0, where the Acc for nearest, bilinear, and lanczos resized images is 61.55%, 72.09%, and 73.47%, respectively. By contrast, for a stronger attack or a more robust CNN classifier, the differences in Acc across the down-sampled images become smaller. Note that, unlike ESRM, FS is a transformation-based detector. Its detectability depends not only on the difference between the legitimate and adversarial images but also on the robustness of the CNN classifier. We therefore attribute the smaller discrepancy, probably, to the robust CNN model and the strong attack smoothing out the differences caused by the interpolation kernels. For example, the classification accuracies of the pre-trained Inception-V3 [29] on nearest, bilinear, and lanczos resized images are 84.53%, 84.63%, and 85.07%, respectively.
These similar values indicate that the CNN models have similar robustness against image squeezing operations across the different down-scaled images. The Accs are therefore expected to be nearly the same for the different down-scaled images, as shown in Table 4. A strong attack likewise yields nearly identical Accs, as shown in Table 3.

Table 3. The Acc (%) of FS detecting different down-sampled adversarial images attacking ResNet-50.

BIM
Interpolation kernel   ε = 1   ε = 3   ε = 5
nearest                78.16   58.55   52.80
bilinear               77.43   56.85   51.32
lanczos                78.58   57.45   51.61

C&W
Interpolation kernel   κ = 0   κ = 10  κ = 20
nearest                61.55   82.38   77.99
bilinear               72.09   83.69   79.99
lanczos                73.47   84.38   79.77
Table 4. The Acc (%) of FS detecting different down-sampled adversarial images attacking Inception-V3.

BIM
Interpolation kernel   ε = 1   ε = 3   ε = 5
nearest                81.90   63.28   56.47
bilinear               81.22   62.86   55.69
lanczos                84.51   64.00   56.79

C&W
Interpolation kernel   κ = 0   κ = 10  κ = 20
nearest                81.95   90.43   83.87
bilinear               82.55   90.03   82.92
lanczos                83.59   90.74   83.30
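The FS procedure described in Sect. 3.4 can be sketched in pure Python as follows. This is a minimal illustration under stated assumptions, not the reference implementation. `predict` stands for any function returning SoftMax probabilities, and the two-class `toy_predict` is hypothetical. Only two squeezers are shown (non-local means is omitted), and their settings are illustrative. The joint score takes the maximum L1 distance over the squeezers, as in the joint detector of the original FS paper [15]. The threshold is chosen from legitimate scores so that at most 5% of them exceed it.

```python
import statistics


def reduce_bit_depth(img, bits):
    """Squeezer: quantise pixels in [0, 1] to 2**bits levels."""
    levels = 2 ** bits - 1
    return [[round(p * levels) / levels for p in row] for row in img]


def median_filter(img, size=2):
    """Squeezer: median over a size x size window starting at each pixel (edge-clipped)."""
    h, w = len(img), len(img[0])
    return [[statistics.median(img[r][c]
                               for r in range(i, min(i + size, h))
                               for c in range(j, min(j + size, w)))
             for j in range(w)]
            for i in range(h)]


def l1(p, q):
    """Step (3): L1 distance between two SoftMax vectors."""
    return sum(abs(a - b) for a, b in zip(p, q))


def joint_score(predict, img, squeezers):
    """Steps (2)-(3): max L1 distance between predict(img) and each squeezed copy."""
    base = predict(img)
    return max(l1(base, predict(squeeze(img))) for squeeze in squeezers)


def choose_threshold(legit_scores, fpr=0.05):
    """Pick T so that at most a fraction fpr of legitimate scores exceed T."""
    ranked = sorted(legit_scores)
    k = int(fpr * len(ranked))          # legitimate images allowed above T
    return ranked[len(ranked) - k - 1]


def is_adversarial(score, threshold):
    """Step (4): flag the image when its score exceeds T."""
    return score > threshold


def toy_predict(img):
    """Hypothetical two-class 'CNN' returning [m, 1 - m], m = mean pixel value."""
    m = sum(map(sum, img)) / (len(img) * len(img[0]))
    return [m, 1.0 - m]


squeezers = [lambda im: reduce_bit_depth(im, 1), median_filter]
score = joint_score(toy_predict, [[0.3, 0.3], [0.3, 0.3]], squeezers)
# score is about 0.6: 1-bit squeezing maps every 0.3 pixel to 0
```

Calibrating T on one half of the legitimate images and evaluating on the other half, as in the experiments, amounts to calling `choose_threshold` on the first half's scores only.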
3.5 Discussion
The results in Tables 1, 2, 3 and 4 verify that pre-processing down-sampling affects the detectability of adversarial examples by detector-based defenses. Different interpolation kernels change the dependencies of neighboring pixels differently. Down-sampling therefore has different impacts on the detectability of the steganalysis-based detector ESRM, which relies on these neighborhood dependencies. The bilinear and lanczos kernels result in stronger neighboring-pixel dependencies than the nearest kernel. Generally, stronger neighborhood dependencies are disturbed more by the adversarial perturbation. As a result, adversarial images derived from bilinear- and lanczos-down-sampled images are easier to detect than those derived from nearest-down-sampled images. Down-sampling also affects the detectability of the transformation-based detector FS, but its influence differs less among the interpolation kernels. These results may help attackers and defenders develop their own optimal strategies in the attack-defense confrontation. ESRM and FS have relatively low detection accuracies on nearest-down-scaled adversarial images. Attackers seeking to evade detection as far as possible would therefore tend to use nearest-neighbor down-sampled images for adversarial attacks. For defenders, as indicated by the detection results on ResNet-50 and Inception-V3, developing more accurate and robust CNN models is one way to improve the defense ability.
4 Conclusion
Down-sampling is usually applied before generating adversarial images. In this paper, we study how pre-processing down-sampling affects the detectability of adversarial images. To the best of our knowledge, this is the first work on the influence of down-sampling on detectability, and it reveals a surprising sensitivity of steganalysis-based detection to the choice of interpolation kernel. For a comprehensive empirical investigation, experiments are conducted with three interpolation kernels, two qualitatively different attack algorithms (BIM and C&W), and two state-of-the-art detection methods (FS and ESRM). Because down-sampling alters the strength of the dependencies among neighboring pixels, the experimental results show that the detectability of the steganalysis-based ESRM feature is strongly affected by the interpolation kernel used for down-sampling. The detectability of FS is also affected by the interpolation kernel, but less so than that of ESRM. The main contribution of this paper is to explain how the detectability of adversarial images varies with the interpolation algorithm and its settings. Our findings may help attackers and defenders benchmark their performance more comprehensively in the attack-defense confrontation. In future work, we will study how other pre-processing factors (such as smoothing and noise addition) affect the detectability of adversarial images.
References
1. Goodfellow, I.J., Shlens, J., Szegedy, C.: Explaining and harnessing adversarial examples. In: International Conference on Machine Learning, pp. 1–10 (2015)
2. Madry, A., Makelov, A., Schmidt, L., Tsipras, D., Vladu, A.: Towards deep learning models resistant to adversarial attacks. In: International Conference on Learning Representations. arXiv:1706.06083 (2018)
3. Moosavi-Dezfooli, S.M., Fawzi, A., Frossard, P.: DeepFool: a simple and accurate method to fool deep neural networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2574–2582 (2016)
4. Zhang, H., Avrithis, Y., Furon, T., Amsaleg, L.: Walking on the edge: fast, low-distortion adversarial examples. IEEE Trans. Inf. Forensics Secur. 16, 701–713 (2020)
5. Carlini, N., Wagner, D.: Towards evaluating the robustness of neural networks. In: 2017 IEEE Symposium on Security and Privacy, pp. 39–57 (2017)
6. Bourzac, K.: Bringing big neural networks to self-driving cars, smartphones, and drones. IEEE Spectrum, 13–29 (2016)
7. Mnih, V., et al.: Human-level control through deep reinforcement learning. Nature 518(7540), 529–533 (2015)
8. Wong, E., Rice, L., Kolter, J.Z.: Fast is better than free: revisiting adversarial training. arXiv:2001.03994 (2020)
9. Zhang, H., Yu, Y., Jiao, J., Xing, E., El Ghaoui, L., Jordan, M.: Theoretically principled trade-off between robustness and accuracy. In: International Conference on Machine Learning, pp. 7472–7482 (2019)
10. Machado, G.R., Silva, E., Goldschmidt, R.R.: Adversarial machine learning in image classification: a survey toward the defender's perspective. ACM Comput. Surv. (CSUR) 55(1), 1–38 (2021)
11. Grosse, K., Manoharan, P., Papernot, N., Backes, M., McDaniel, P.: On the (statistical) detection of adversarial examples. arXiv:1702.06280 (2017)
12. Lu, J., Issaranon, T., Forsyth, D.: SafetyNet: detecting and rejecting adversarial examples robustly. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 446–454 (2017)
13. Li, X., Li, F.: Adversarial examples detection in deep networks with convolutional filter statistics. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 5764–5772 (2017)
14. Liang, B., Li, H., Su, M., Li, X., Shi, W., Wang, X.: Detecting adversarial image examples in deep neural networks with adaptive noise reduction. IEEE Trans. Dependable Secur. Comput. 18(1), 72–85 (2018)
15. Xu, W., Evans, D., Qi, Y.: Feature squeezing: detecting adversarial examples in deep neural networks. In: Network and Distributed System Security Symposium. arXiv:1704.01155 (2017)
16. Guo, C., Rana, M., Cisse, M., Van Der Maaten, L.: Countering adversarial images using input transformations. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. arXiv:1711.00117 (2017)
17. Schöttle, P., Schlögl, A., Pasquini, C., Böhme, R.: Detecting adversarial examples - a lesson from multimedia security. In: 2018 26th European Signal Processing Conference (EUSIPCO), pp. 947–951 (2018)
18. Fan, W., Sun, G., Su, Y., Liu, Z., Lu, X.: Integration of statistical detector and Gaussian noise injection detector for adversarial example detection in deep neural networks. Multimed. Tools Appl. 78(14), 20409–20429 (2019). https://doi.org/10.1007/s11042-019-7353-6
19. Liu, J., et al.: Detection based defense against adversarial examples from the steganalysis point of view. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4825–4834 (2019)
20. Bonnet, B., Furon, T., Bas, P.: Forensics through stega glasses: the case of adversarial images. In: International Conference on Pattern Recognition, pp. 453–469 (2021)
21. Peng, A., Deng, K., Zhang, J., Luo, S., Zeng, H., Yu, W.: Gradient-based adversarial image forensics. In: International Conference on Neural Information Processing, pp. 417–428 (2020)
22. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
23. Kodovský, J., Fridrich, J.: Steganalysis in resized images. In: 2013 IEEE International Conference on Acoustics, Speech and Signal Processing, pp. 2857–2861 (2013)
24. Kodovský, J., Fridrich, J.: Effect of image downsampling on steganographic security. IEEE Trans. Inf. Forensics Secur. 9(5), 752–762 (2014)
25. Stamm, M.C., Wu, M., Liu, K.R.: Information forensics: an overview of the first decade. IEEE Access 1, 167–200 (2013)
26. Kang, X., Stamm, M.C., Peng, A., Liu, K.R.: Robust median filtering forensics using an autoregressive model. IEEE Trans. Inf. Forensics Secur. 8(9), 1456–1468 (2013)
27. Kodovský, J., Fridrich, J., Holub, V.: Ensemble classifiers for steganalysis of digital media. IEEE Trans. Inf. Forensics Secur. 7(2), 432–444 (2011)
28. Dong, Y., et al.: Benchmarking adversarial robustness on image classification. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 321–331 (2020)
29. Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., Wojna, Z.: Rethinking the inception architecture for computer vision. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2818–2826 (2016)
30. Fridrich, J., Kodovský, J.: Rich models for steganalysis of digital images. IEEE Trans. Inf. Forensics Secur. 7(3), 868–882 (2012)
31. Mustafa, A., Khan, S.H., Hayat, M., Shen, J., Shao, L.: Image super-resolution as a defense against adversarial attacks. IEEE Trans. Image Process. 29, 1711–1724 (2020)
Boosting the Robustness of Neural Networks with M-PGD

Chenghai He1,2,3, Li Zhou3, Kai Zhang3, Hailing Li3, Shoufeng Cao3, Gang Xiong1,2(B), and Xiaohang Zhang3(B)

1 Institute of Information Engineering, Chinese Academy of Sciences, Beijing, China
{hechenghai,xionggang}@iie.ac.cn
2 School of Cyber Security, University of Chinese Academy of Sciences, Beijing, China
3 CNCERT/CC, Beijing, China
{hechenghai,zhouli,zhangkai,lihailing,csf,zhangxiaohang}@cert.org.cn
Abstract. Neural networks have achieved state-of-the-art results in many fields. However, researchers have found that neural network models are vulnerable to adversarial examples, which are carefully designed to fool them. Adversarial training creates adversarial samples during training and trains the model directly on them, which can improve the model's robustness to adversarial samples. In adversarial training, the stronger the attack that generates the adversarial samples, the more robust the resulting model. In this paper, we incorporate the idea of momentum into the projected gradient descent (PGD) attack algorithm and propose a novel momentum-PGD attack algorithm (M-PGD) that greatly improves the attack ability of PGD. We then train a neural network model on the adversarial samples generated by M-PGD, which greatly improves the robustness of the adversarially trained model. We compare our adversarially trained model with five other adversarial training models on the CIFAR-10 and CIFAR-100 datasets. Experiments show that our model is considerably more robust to adversarial samples than the other adversarial training models. We hope that our adversarially trained model will serve in the future as a benchmark for testing the attack ability of attack models.
Keywords: Neural networks · Adversarial training · Image classification

1 Introduction
Neural networks have achieved remarkable results in many fields, such as image classification [1,2], speech recognition [3], and natural language processing [4]. However, recent studies have found that neural network models are vulnerable to adversarial samples, which are created by slightly modifying the clean samples. Although these modifications do not affect human judgment, the neural network models fail to recognize the modified samples correctly. Attacks are generally classified in two ways. The first divides attacks into white-box and black-box attacks. If the parameters and structure of the attacked model are known, the attack is a white-box attack; if both are unknown, it is a black-box attack. The second divides attacks into targeted and untargeted attacks. If a target class is given, the attack is targeted; otherwise, it is untargeted. The emergence of adversarial samples has attracted many researchers, who have proposed numerous attack and defense algorithms, such as defensive distillation [6,7], feature compression [8,9], and several other adversarial sample detection methods [10]. These studies provide a reference for defending against adversarial samples, but the methods still cannot withstand carefully designed adversarial samples. Madry et al. [11] studied the adversarial robustness of neural networks from the perspective of robust optimization. They used a natural saddle point (min-max) formulation to improve the robustness of neural network models against adversarial samples. In the same work, Madry et al. [11] proposed the PGD attack algorithm to tackle this non-convex, non-concave saddle point problem, using PGD to solve its inner maximization. Building on PGD, this paper proposes the momentum-PGD attack algorithm and applies it to adversarial training. Experimental results show that our proposed algorithm can greatly improve the robustness of the model to adversarial samples.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 562–573, 2023.
https://doi.org/10.1007/978-981-99-1639-9_47
In this paper, we make the following contributions:
– We incorporate the idea of the momentum method [16] into the PGD attack algorithm and propose a novel untargeted attack algorithm called M-PGD. Compared with PGD, the momentum term lets M-PGD maintain a consistent optimization direction across gradient iterations, which can reduce the effect of randomness and strengthen the attack.
– We apply the M-PGD attack algorithm to adversarial training. Training models on adversarial samples generated by M-PGD increases the robustness of the adversarially trained models against attacks.
– We evaluate our adversarially trained models on the CIFAR-10 and CIFAR-100 datasets. Experiments show that our model achieves better results than the other adversarial training models.
2 Background
In this section, we introduce the terminology and notation used in this paper and review related work on attack and defense algorithms.
C. He et al.

2.1 Terminology and Notation
In this paper, we use the following terminology and notation for adversarial samples:
– $x$: the clean sample, i.e., an unmodified sample from the dataset (train set or test set)
– $x^{adv}$: the (candidate) adversarial sample, generated from $x$ and designed to be misclassified by the neural network model
– $F(x)$: the neural network model
– $\epsilon$: the maximum size of the perturbation
– $J(x, y)$: the cost function
– $\mathrm{Proj}_{x,\epsilon}(A)$: the projection of $A$ onto the set $\{A' : \|x - A'\|_p \le \epsilon\}$

2.2 Attack Algorithms
Since adversarial samples were discovered, researchers have proposed a large number of attack algorithms. Szegedy et al. [5] were the first to discover that neural network models are vulnerable to adversarial examples; they generated adversarial examples with box-constrained L-BFGS. Goodfellow et al. [12] proposed FGSM, which adds a perturbation along the sign of the gradient of the cross-entropy loss, i.e., in the direction in which the loss increases. Moosavi-Dezfooli et al. [14] proposed the DeepFool attack method, which is mainly based on the idea of support vector machines [15]. Kurakin et al. [13] proposed BIM, which executes FGSM T times with a small step size and clips the adversarial sample to the valid range after each step. Dong et al. [21] introduced the momentum method into the I-FGSM attack algorithm and proposed MI-FGSM. Wang et al. [26] proposed Admix, which extends the mixup algorithm but admixes two images in a master-slave manner. Shamsabadi et al. [27] introduced ColorFool, which perturbs colors only in specific semantic regions and within a chosen range so that the samples are still perceived as natural.

2.3 Defensive Algorithms
PGD Adversarial Training. PGD [11] is a strong first-order L∞ attack: if a model can resist PGD attacks, it can resist a broad class of first-order L∞ attacks. Based on this, Madry et al. proposed PGD adversarial training, whose learning objective is

$$\min_{\theta} \rho(\theta), \qquad \rho(\theta) = \mathbb{E}_{(x,y)\sim\mathcal{D}}\Big[\max_{\delta\in S} L(\theta, x+\delta, y)\Big] \qquad (1)$$

where $\mathcal{D}$ is the data distribution and $S$ is the set of allowed perturbations.
FreeAT Adversarial Training. FreeAT [17] produces robust models with little additional cost relative to natural training. Its key idea is to update the model parameters and the image perturbation with one simultaneous backward pass, rather than using separate gradient computations for each update step.
YOPO Adversarial Training. YOPO [18] is motivated by Pontryagin's Maximum Principle (PMP) and decouples the adversary update from its associated back-propagation. Dual Head Adversarial Training (DH-AT). DH-AT [24] modifies the network architecture and training strategy to achieve higher robustness. Specifically, DH-AT first attaches a second network head to one of the intermediate layers of the network and then aggregates the outputs of the two heads using a lightweight CNN.
3 M-PGD Adversarial Training Algorithm
In this section, we introduce the M-PGD adversarial training algorithm, which consists of two parts: the M-PGD attack algorithm and the adversarial training algorithm. The M-PGD attack algorithm aims to generate adversarial samples with the largest loss under a given perturbation budget. The adversarial training algorithm then uses these samples to boost the robustness of the model. The framework of the M-PGD adversarial training algorithm is shown in Fig. 1.
[Fig. 1 components: clean samples; M-PGD attack model (parameters T, γ, ε); adversarial samples; neural network model; adversarial training, which updates the model with momentum stochastic gradient descent.]
Fig. 1. The framework of the M-PGD adversarial training.
3.1 M-PGD Attack
In this section, we propose the momentum projected gradient descent (M-PGD) attack algorithm to generate adversarial samples. When generating adversarial samples, the PGD attack algorithm greedily follows only the current gradient direction in each iteration, which can make its progress very slow. PGD sometimes does not even converge, because the optimization direction fluctuates randomly between iterations.
The momentum method [16] is a technique that stabilizes the direction of a gradient descent algorithm by accumulating a velocity vector in the gradient direction of the loss function across iterations. To improve the attack ability of the PGD attack algorithm, we apply the idea of momentum to PGD and propose the momentum projected gradient descent (M-PGD) attack algorithm. Adversarial samples are generated as follows:

$$v_{t+1} = \gamma \cdot v_t + \nabla_x J(\theta, x^{adv}_t, y), \qquad x^{adv}_{t+1} = \mathrm{Proj}_{x,\epsilon}\big(x^{adv}_t + \alpha \cdot \mathrm{sign}(v_{t+1})\big) \qquad (2)$$
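Unrolling the momentum recursion in Eq. (2) makes the weighting of past gradients explicit. This assumes $v_0 = 0$ as in Algorithm 1 and writes $g_k = \nabla_x J(\theta, x^{adv}_k, y)$:

$$v_{t} = \sum_{k=0}^{t-1} \gamma^{\,t-1-k}\, g_k, \qquad \text{e.g. } v_2 = \gamma\, g_0 + g_1 .$$

For $\gamma < 1$, recent gradients dominate. For $\gamma > 1$, such as the value $\gamma = 1.8$ selected in Sect. 4.1, earlier gradients receive the larger weights. In either case, the sign operation in Eq. (2) fixes the per-pixel step size at $\alpha$, so the growing magnitude of $v_t$ does not enlarge the update.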
In the M-PGD attack algorithm, if the current gradient is similar in direction to the historical gradients, the update along the current gradient is strengthened; if it differs, that update is weakened. The M-PGD attack is summarized in Algorithm 1.

Algorithm 1. M-PGD Attack
Input: A model M with loss function J, a clean sample $x$ and the corresponding label $y$, the maximum perturbation size $\epsilon$, the number of steps $T$, and the momentum factor $\gamma$
Output: A (candidate) adversarial example $x^{adv}$ with $\|x - x^{adv}\|_\infty \le \epsilon$
1: $\alpha = 2.5\,\epsilon / T$
2: $v_0 = 0$; $x^{adv}_0 = x + \text{random}$
3: for $t = 0$ to $T - 1$ do
4:   Calculate the gradient $\nabla_x J(\theta, x^{adv}_t, y)$ with the model M
5:   Update $v_{t+1}$ with the momentum method: $v_{t+1} = \gamma \cdot v_t + \nabla_x J(\theta, x^{adv}_t, y)$
6:   Update $x^{adv}_{t+1}$: $x^{adv}_{t+1} = x^{adv}_t + \alpha \cdot \mathrm{sign}(v_{t+1})$; $x^{adv}_{t+1} = \mathrm{Proj}\big(x^{adv}_{t+1}, \max(0, x - \epsilon), \min(x + \epsilon, 255)\big)$
7: end for
8: return $x^{adv}_T$
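Algorithm 1 can be sketched in pure Python as below. To keep the example self-contained, it uses a hypothetical toy "model", a linear logistic classifier whose input gradient is analytic, in place of a CNN (the authors use TensorFlow). It also assumes that the unspecified random start is uniform noise in [−ε, ε], clipped to the valid box. The defaults ε = 8, T = 10, γ = 1.8 are the values reported in Sect. 4.2, and each step ascends the loss J, as an attack must.

```python
import math
import random


def loss_and_grad(w, b, x, y):
    """Logistic loss J = log(1 + exp(-y * (w.x + b))) and its gradient w.r.t. x.

    Hypothetical toy stand-in for a neural network; the label y is +1 or -1.
    """
    margin = y * (sum(wi * xi for wi, xi in zip(w, x)) + b)
    loss = math.log1p(math.exp(-margin))
    coeff = -y / (1.0 + math.exp(margin))  # dJ/d(w.x + b)
    return loss, [coeff * wi for wi in w]


def sign(v):
    """Return +1, 0 or -1, like sign() in Algorithm 1."""
    return (v > 0) - (v < 0)


def m_pgd(w, b, x, y, eps=8.0, steps=10, gamma=1.8, seed=0):
    """L-infinity M-PGD following Algorithm 1, with pixel values in [0, 255]."""
    rng = random.Random(seed)
    alpha = 2.5 * eps / steps                          # line 1: step size
    low = [max(0.0, xi - eps) for xi in x]             # projection box of line 6
    high = [min(255.0, xi + eps) for xi in x]
    v = [0.0] * len(x)                                 # line 2: v_0 = 0
    # Line 2: random start (assumed uniform in [-eps, eps], clipped to the box).
    x_adv = [min(max(xi + rng.uniform(-eps, eps), lo_i), hi_i)
             for xi, lo_i, hi_i in zip(x, low, high)]
    for _ in range(steps):                             # lines 3-7
        _, g = loss_and_grad(w, b, x_adv, y)           # line 4: input gradient
        v = [gamma * vi + gi for vi, gi in zip(v, g)]  # line 5: momentum update
        x_adv = [min(max(xa + alpha * sign(vi), lo_i), hi_i)  # line 6: step + project
                 for xa, vi, lo_i, hi_i in zip(x_adv, v, low, high)]
    return x_adv
```

Because the toy model is linear, sign(v) never changes, so the result lands on a corner of the ε-box. For w = [0.02, −0.01, 0.03], b = −1, x = [100, 50, 250], y = 1, the call `m_pgd(w, b, x, y)` returns [92.0, 58.0, 242.0] and the logistic loss increases. With a real CNN, the gradient direction varies from step to step, and that variation is what the momentum term smooths.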
3.2 Adversarial Training
In this section, we use the adversarial samples generated by the M-PGD attack algorithm to boost the robustness of the model to adversarial samples. The basic idea is to inject adversarial samples into the training set, continuously generating new adversarial samples from the current model during training [12]. Unlike Goodfellow et al. [12] and Kurakin et al. [22], we do not train on clean samples but train directly on adversarial samples when the data are relatively simple. The main reason is that, for relatively simple datasets such as MNIST and CIFAR-10, adversarial training does not lead to a decrease in accuracy on clean samples. Our approach also directly follows the min-max formulation (Equation (1)), whose effectiveness has been shown in prior work (e.g., [11]). According to Equation (1), we need to find the adversarial samples with the largest loss under a given perturbation budget. Generally speaking, the stronger an attack, the greater the loss of the adversarial samples it generates. In Sect. 4, we show empirically (Fig. 2(d)) that the average loss of the adversarial samples generated by M-PGD is greater than that of the adversarial samples generated by PGD. The adversarial training procedure is summarized in Algorithm 2.
Algorithm 2. Adversarial Training
Input: Clean samples $x$ and the corresponding labels $y$, the minibatch size $m$, and the number of iterations $T$
Output: A robust model M
1: Randomly initialize the neural network model M
2: for $t = 1$ to $T$ do
3:   Read a minibatch $B = \{x_1, x_2, \dots, x_m\}$ from the training set
4:   Generate $m$ adversarial examples $\{x^{adv}_1, \dots, x^{adv}_m\}$ from the corresponding clean examples in $B$ using the model M
5:   Make a new minibatch $B' = \{x^{adv}_1, \dots, x^{adv}_m\}$
6:   Do one training step of the model M using the minibatch $B'$
7: end for
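The loop of Algorithm 2 can be sketched as follows. The setting is a hypothetical toy one: a 1-D logistic model trained by plain SGD, with an exact worst-case L∞ attack for that linear model standing in for M-PGD. Note that `train_step` only ever receives the adversarial minibatch B′ and never the clean minibatch B.

```python
import math
import random


def adversarial_training(data, attack, train_step, params,
                         batch_size, iterations, seed=0):
    """Algorithm 2: every update uses only the adversarial version of a minibatch."""
    rng = random.Random(seed)
    for _ in range(iterations):
        batch = rng.sample(data, batch_size)                       # line 3: read B
        adv_batch = [(attack(params, x, y), y) for x, y in batch]  # lines 4-5: build B'
        params = train_step(params, adv_batch)                     # line 6: one step on B'
    return params


EPS = 0.5  # hypothetical L-infinity budget for the toy problem


def worst_case_attack(params, x, y):
    """Exact worst case for a 1-D linear model; a placeholder for M-PGD."""
    w, _ = params
    return x - y * EPS * ((w > 0) - (w < 0))


def sgd_step(params, batch, lr=0.1):
    """One SGD step on the mean logistic loss log(1 + exp(-y * (w * x + b)))."""
    w, b = params
    gw = gb = 0.0
    for x, y in batch:
        coeff = -y / (1.0 + math.exp(y * (w * x + b)))  # dJ/d(w * x + b)
        gw += coeff * x
        gb += coeff
    n = len(batch)
    return (w - lr * gw / n, b - lr * gb / n)


data = [(1.0, 1), (2.0, 1), (3.0, 1), (-1.0, -1), (-2.0, -1), (-3.0, -1)]
w, b = adversarial_training(data, worst_case_attack, sgd_step, (0.0, 0.0),
                            batch_size=len(data), iterations=200)
```

Here the whole six-point dataset serves as the minibatch. With real data, `batch_size` would be the minibatch size m (128 in Sect. 4.2), and `attack` would be an M-PGD implementation operating on the current model.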
4 Experiments
In this section, we apply our proposed algorithms (Algorithm 1 and Algorithm 2) to train more robust models. We conduct experiments on the CIFAR-10 and CIFAR-100 datasets [19] to validate their effectiveness. Compared with other adversarial training models, our adversarially trained model achieves state-of-the-art robustness on both CIFAR datasets. To train and use DNNs, we use TensorFlow [23] on machines equipped with NVIDIA Tesla V100 GPUs. We run each experiment for around 80k iterations and report the accuracy obtained at that point as the final result.

4.1 Parameter Selection
In this section, we use the CIFAR-10 dataset to study the influence of hyperparameters on adversarial training.
[Figure: panels (a), (b), (c), (d)]
Fig. 2. (a), (b), and (c) show the accuracy on clean samples and on adversarial samples generated with the "adv trained" model. Curves with triangle markers show the accuracy of the model on clean samples; curves with square markers show its accuracy on adversarial samples. (d) shows the cross-entropy values of the adversarial samples generated by the PGD and M-PGD attack algorithms.
The Momentum Factor γ adjusts the influence of previous gradients on the current update. When γ = 0, the momentum term has no effect, and M-PGD adversarial training reduces to PGD adversarial training. We study the classification accuracy of our adversarially trained model on clean and adversarial samples as γ varies from 0.0 to 3.0 in steps of 0.2. The results are shown in Fig. 2(a); the best result is obtained at γ = 1.8.

The Number of Steps T determines the number of gradient computations used to generate each adversarial sample. We study the classification accuracy of the model on clean and adversarial samples as T varies from 5 to 40 in steps of 5. The results are shown in Fig. 2(b); the best result is obtained at T = 10.
The Max Size of Perturbation ε controls the maximum perturbation that can be added to each pixel during training, i.e., how much the samples can be modified. We study the classification accuracy of the model on clean and adversarial samples as ε varies from 4 to 32 in steps of 4. The results are shown in Fig. 2(c): as ε increases, the classification accuracy on both clean and adversarial samples steadily decreases.

4.2 Model Comparison
To validate our M-PGD adversarial training model, we compare it with the 7-PGD model [11], the FreeAT model [17], the YOPO model [18], the DH-AT model [24] and the advGDRO model [25]. For a fair comparison, we train our M-PGD adversarial models on a Wide-Resnet 32-10 [20], as was done for the 7-PGD and FreeAT models. YOPO uses a wider network (Wide-Resnet 34-10) with a larger batch size (256) in its experiments; to allow a direct comparison, we train a YOPO model on a Wide-Resnet 32-10 with a batch size of 128. We also train the DH-AT model [24] with TRADES and the advGDRO model [25] with an initial learning rate of 0.1, momentum of 0.9, λ1 = 6 and weight decay of 2e−4.

We attack all adversarially trained models using PGD attacks with K iterations on either the cross-entropy loss (PGD-K) or the Carlini-Wagner loss (CW-K). All adversarial samples are generated based on the "adv trained" model in Madry's CIFAR-10 challenge repository rather than on the evaluated models, so the comparison is a (transfer-based) black-box test. We evaluate the robustness of all adversarial training models on clean samples and on samples generated by PGD-20, PGD-100, CW-100, and PGD-20 run 10 times with random restarts. The parameters of our model are: maximum perturbation size ε = 8, momentum factor γ = 1.8, batch size m = 128 and number of steps T = 10.

CIFAR-10. We train all models on the CIFAR-10 dataset. All experimental results are shown in Table 1.
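For reference, the two losses that the PGD-K and CW-K attacks maximize, and one common way of scoring the 10-restart evaluation, can be sketched as follows. The function names are hypothetical; in this black-box setting the logits used to craft the samples would come from the source ("adv trained") model. The CW loss is shown in a common untargeted margin form without a confidence constant, and the restart rule (a sample counts as robust only if every restart fails to fool the evaluated model) is an assumption rather than a detail stated in the paper.

```python
import math


def cross_entropy(logits, label):
    """Softmax cross-entropy of one sample, computed stably from raw logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def cw_margin_loss(logits, label):
    """Carlini-Wagner style margin: best wrong-class logit minus true-class logit.

    Positive values mean the sample is misclassified.
    """
    wrong = max(z for j, z in enumerate(logits) if j != label)
    return wrong - logits[label]


def robust_accuracy_with_restarts(predictions_per_restart, labels):
    """Fraction of samples classified correctly under every restart.

    predictions_per_restart: one list of predicted labels per restart.
    """
    n = len(labels)
    robust = sum(
        all(preds[i] == labels[i] for preds in predictions_per_restart)
        for i in range(n)
    )
    return robust / n
```

Scoring restarts per sample in this way makes the multi-restart number a lower bound on the single-run accuracy, which is why the "10 restart PGD-20" column is expected to be no higher than the corresponding single-run PGD-20 figure for the same attack source.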
According to Table 1, our model achieves the highest accuracy among the adversarially trained models on all evaluation sets. On clean samples, it reaches 89.13%, which is 1.09 percentage points higher than the best competing adversarially trained model (advGDRO); only the undefended natural model is more accurate on clean samples (95.01%), but it has no robustness at all. Against PGD-20, PGD-100, CW-100 and 10-restart PGD-20, our model is at least 6.73, 5.71, 6.50 and 6.20 percentage points more accurate, respectively, than the strongest competing model, DH-AT. Therefore, among the compared adversarial training models, ours achieves the best results on both clean and adversarial samples.

Table 1. Results of adversarial samples generated based on the "adv trained" model attacking each model (CIFAR-10)
Training \ Evaluated against | Clean samples | PGD-20 | PGD-100 | CW-100 | 10-restart PGD-20
Natural model | 95.01% | 0.00% | 0.00% | 0.00% | 0.00%
FreeAT (m = 8) | 85.96% | 46.82% | 46.19% | 46.60% | 46.33%
7-PGD trained | 87.25% | 45.84% | 45.29% | 46.52% | 45.53%
YOPO-5-3 | 85.04% | 43.83% | 43.32% | 43.25% | 43.62%
DH-AT | 86.78% | 59.21% | 59.64% | 59.85% | 59.62%
advGDRO | 88.04% | 49.83% | 49.32% | 49.25% | 49.62%
Our model | 89.13% | 65.94% | 65.35% | 66.35% | 65.82%

CIFAR-100. We also train all models on the CIFAR-100 dataset, which is more difficult because it has more classes. As a result, every model is considerably less accurate on CIFAR-100 than on CIFAR-10. All experimental results are shown in Table 2.

Table 2. Results of adversarial samples attacking each model (CIFAR-100)
Training \ Evaluated against | Clean samples | PGD-20 | PGD-100
Natural model | 78.84% | 0.00% | 0.00%
FreeAT (m = 8) | 62.13% | 25.88% | 25.58%
7-PGD trained | 59.87% | 22.76% | 22.52%
YOPO-5-3 | 41.36% | 15.64% | 15.57%
DH-AT | 59.57% | 28.91% | 27.28%
advGDRO | 60.36% | 30.64% | 30.57%
Our model | 62.25% | 37.53% | 37.64%
According to Table 2, among the adversarially trained models, our model is at least 0.12 percentage points more accurate on clean samples (62.25% vs. 62.13% for FreeAT), and at least 6.89 and 7.07 percentage points more accurate on samples generated by the PGD-20 and PGD-100 attacks, respectively (both relative to advGDRO). As on CIFAR-10, only the undefended natural model is more accurate on clean samples (78.84%), at the cost of having no robustness. Therefore, our adversarial training model also achieves the best results on both clean and adversarial samples on CIFAR-100.

4.3 Adversarial Loss Value
To verify that the M-PGD attack algorithm produces adversarial samples with larger loss values under the same perturbation budget, we compare the cross-entropy values of the adversarial samples generated by the M-PGD and PGD attack algorithms for ε from 4 to 40 in steps of 2. All adversarial samples are generated based on the "naturally trained" model from Madry's CIFAR-10 challenge repository. The results, shown in Fig. 2 (d), indicate that the M-PGD attack algorithm generates samples with larger cross-entropy values under the same ε. Training on these adversarial samples generated by the M-PGD attack algorithm is therefore expected to yield models with stronger defensive capabilities, which is consistent with the results in Sect. 4.2.

5 Conclusions
Because the PGD attack algorithm has randomness in the optimization process, it can sometimes prevent the adversarial training model from converging. To overcome this shortcoming of the PGD attack algorithm, we propose the M-PGD attack algorithm. We then train an adversarial training model on the adversarial samples generated by the M-PGD attack algorithm. Experiments on CIFAR-10 and CIFAR-100 show that our adversarial training model is more robust to adversarial samples than the compared adversarial training models. In the future, our adversarial training model could be used to test the attack ability of attack models.

Acknowledgements. This work is supported by the National Key R&D Program of China (No. 2020YFB1006105).
References
1. Krizhevsky, A., Sutskever, I., Hinton, G.E.: Imagenet classification with deep convolutional neural networks. Commun. ACM 60(6), 84–90 (2017). https://doi.org/10.1145/3065386
2. Lecun, Y., Bottou, L., Bengio, Y., Haffner, P.: Gradient-based learning applied to document recognition. Proc. IEEE 86(11), 2278–2324 (1998). https://doi.org/10.1109/5.726791
3. Hinton, G., et al.: Deep neural networks for acoustic modeling in speech recognition: the shared views of four research groups. IEEE Signal Process. Mag. 29(6), 82–97 (2012). https://doi.org/10.1109/MSP.2012.2205597
4. Andor, D., et al.: Globally normalized transition-based neural networks. In: Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), pp. 2442–2452. Association for Computational Linguistics, Berlin, Germany, August 2016. https://doi.org/10.18653/v1/P16-1231
5. Szegedy, C., et al.: Intriguing properties of neural networks. In: Bengio, Y., LeCun, Y. (eds.) 2nd International Conference on Learning Representations, ICLR 2014, Banff, AB, Canada, 14–16 April 2014, Conference Track Proceedings (2014)
6. Carlini, N., Wagner, D.A.: Towards evaluating the robustness of neural networks. In: 2017 IEEE Symposium on Security and Privacy, SP 2017, San Jose, CA, USA, 22–26 May 2017, pp. 39–57. IEEE Computer Society (2017). https://doi.org/10.1109/SP.2017.49
7. Papernot, N., McDaniel, P., Wu, X., Jha, S., Swami, A.: Distillation as a defense to adversarial perturbations against deep neural networks. In: 2016 IEEE Symposium on Security and Privacy (SP), pp. 582–597 (2016). https://doi.org/10.1109/SP.2016.41
8. He, W., Wei, J., Chen, X., Carlini, N., Song, D.: Adversarial example defenses: ensembles of weak defenses are not strong. In: Proceedings of the 11th USENIX Conference on Offensive Technologies, p. 15. WOOT 2017, USENIX Association, USA (2017)
9. Xu, W., Evans, D., Qi, Y.: Feature squeezing: Detecting adversarial examples in deep neural networks. In: 25th Annual Network and Distributed System Security Symposium, NDSS 2018, San Diego, California, USA, 18–21 February 2018, The Internet Society (2018)
10. Carlini, N., Wagner, D.: Adversarial examples are not easily detected: bypassing ten detection methods. In: Proceedings of the 10th ACM Workshop on Artificial Intelligence and Security, pp. 3–14. AISec 2017, Association for Computing Machinery, New York, NY, USA (2017). https://doi.org/10.1145/3128572.3140444
11. Madry, A., Makelov, A., Schmidt, L., Tsipras, D., Vladu, A.: Towards deep learning models resistant to adversarial attacks. In: 6th International Conference on Learning Representations, ICLR 2018, Vancouver, BC, Canada, April 30–May 3, 2018, Conference Track Proceedings. OpenReview.net (2018)
12. Goodfellow, I.J., Shlens, J., Szegedy, C.: Explaining and harnessing adversarial examples. In: Bengio, Y., LeCun, Y. (eds.) 3rd International Conference on Learning Representations, ICLR 2015, San Diego, CA, USA, 7–9 May 2015, Conference Track Proceedings (2015)
13. Kurakin, A., Goodfellow, I.J., Bengio, S.: Adversarial examples in the physical world. In: 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, 24–26 April 2017, Workshop Track Proceedings, OpenReview.net (2017)
14. Moosavi-Dezfooli, S., Fawzi, A., Frossard, P.: Deepfool: a simple and accurate method to fool deep neural networks. In: 2016 IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2016, Las Vegas, NV, USA, 27–30 June 2016, pp. 2574–2582. IEEE Computer Society (2016). https://doi.org/10.1109/CVPR.2016.282
15. Hearst, M., Dumais, S., Osuna, E., Platt, J., Scholkopf, B.: Support vector machines. IEEE Intell. Syst. Appl. 13(4), 18–28 (1998). https://doi.org/10.1109/5254.708428
16. Polyak, B.: Some methods of speeding up the convergence of iteration methods. USSR Comput. Math. Math. Phys. 4(5), 1–17 (1964)
17. Shafahi, A., et al.: Adversarial training for free! In: Advances in Neural Information Processing Systems, vol. 32. Curran Associates, Inc. (2019)
18. Zhang, D., Zhang, T., Lu, Y., Zhu, Z., Dong, B.: You only propagate once: accelerating adversarial training via maximal principle. In: Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019, 8–14 December 2019, Vancouver, BC, Canada, pp. 227–238 (2019)
19. Krizhevsky, A., Hinton, G.: Learning multiple layers of features from tiny images. Master's thesis, Department of Computer Science, University of Toronto, pp. 32–33 (2009)
20. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778. IEEE Computer Society, Los Alamitos, CA, USA, June 2016. https://doi.org/10.1109/CVPR.2016.90
21. Dong, Y., et al.: Boosting adversarial attacks with momentum. In: 2018 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 9185–9193. IEEE Computer Society, Los Alamitos, CA, USA, June 2018. https://doi.org/10.1109/CVPR.2018.00957
22. Kurakin, A., Goodfellow, I.J., Bengio, S.: Adversarial machine learning at scale. In: Conference Track Proceedings of 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, 24–26 April 2017, OpenReview.net (2017)
23. Abadi, M., et al.: TensorFlow: Large-scale machine learning on heterogeneous systems (2015), software available from tensorflow.org
24. Jiang, Y., Ma, X., Erfani, S.M., Bailey, J.: Dual head adversarial training. CoRR abs/2104.10377 (2021). https://arxiv.org/abs/2104.10377
25. Chiu, M.C., Ma, X.: Learning representations robust to group shifts and adversarial examples (2022). https://doi.org/10.48550/ARXIV.2202.09446, https://arxiv.org/abs/2202.09446
26. Wang, X., He, X., Wang, J., He, K.: Admix: enhancing the transferability of adversarial attacks. CoRR abs/2102.00436 (2021). https://arxiv.org/abs/2102.00436
27. Shamsabadi, A.S., Sanchez-Matilla, R., Cavallaro, A.: Colorfool: semantic adversarial colorization. CoRR abs/1911.10891 (2019). http://arxiv.org/abs/1911.10891
StatMix: Data Augmentation Method that Relies on Image Statistics in Federated Learning

Dominik Lewy1, Jacek Mańdziuk1(B), Maria Ganzha1, and Marcin Paprzycki2,3

1 Faculty of Mathematics and Information Science, Warsaw University of Technology, Koszykowa 75, 00-662 Warszawa, Poland
{jacek.mandziuk,maria.ganzha}@pw.edu.pl
2 Systems Research Institute Polish Academy of Sciences, Warszawa, Poland
3 Warsaw Management University, Warsaw, Poland
Abstract. The availability of large amounts of annotated data is one of the pillars of deep learning's success. Although numerous big datasets have been made available for research, this is often not the case in real-life applications (e.g. companies are not able to share data due to GDPR or concerns related to intellectual property rights protection). Federated learning (FL) is a potential solution to this problem, as it enables training a global model on data scattered across multiple nodes, without sharing the local data itself. However, even FL methods pose a threat to data privacy if not handled properly. Therefore, we propose StatMix, an augmentation approach that uses image statistics to improve results in FL scenarios. StatMix is empirically tested on CIFAR-10 and CIFAR-100, using two neural network architectures. In nearly all FL experiments (the exceptions being three of the four 50-node configurations on CIFAR-100), application of StatMix improves the average accuracy compared to baseline training (with no use of StatMix). Some improvement can also be observed in non-FL setups.

Keywords: Federated Learning · Data Augmentation · Mixing Augmentation

1 Introduction
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 574–585, 2023.
https://doi.org/10.1007/978-981-99-1639-9_48

One of the key factors behind the success of deep learning in Computer Vision is the availability of large annotated datasets like ImageNet [2] or COCO [17]. However, even when large datasets exist in principle, there may be restrictions on bringing them to one place for model training. Federated learning (FL) addresses this challenge by keeping the data where it is and sharing only limited information, from which the original content cannot be recreated. At the same time, FL allows training a model that achieves better results than models trained in isolation on separate nodes. This is, for instance, a typical scenario for
StatMix : Data Augmentation Method for Federated Learning
hospitals that gather (possibly annotated) medical images but cannot share them with other hospitals for various reasons (e.g. GDPR regulations or intellectual property rights protection). According to the FL classification proposed in [15], the method presented in this paper addresses a horizontal data partitioning scenario (each individual node collects similar data). The specific focus is on Convolutional Neural Network (CNN) architectures, since the problem considered is image classification. However, the approach is in no way limited to CNNs. The proposed method is based on sharing a limited amount of information between nodes, thus avoiding privacy violations. In this paper we consider a centralized FL setup. Nevertheless, the proposed algorithm (StatMix) is agnostic to the communication architecture and can easily be applied in decentralized settings, with each node sharing information with all the other nodes (or a selected group of them) instead of a server. Since the amount of shared information is minimal by design, communication efficiency is not the focus of this study.

1.1 Motivation
Historically, in most FL research, either the gradients of the training process (e.g. FedSGD [19]) or the weights of the model (e.g. FedAvg [19]) have been shared during model training. Only recently was a paper on sharing averaged images (FedMix [23]) published. However, all these approaches pose a potential threat to data privacy if data sharing is not properly managed (e.g. by using differential privacy, or by ensuring that each averaged image is computed from a sufficiently large number of images). The method proposed in what follows limits the shared information to a bare minimum (just 6 values per image, 2 per color channel), and is still able to boost accuracy.

1.2 Contribution
The main contribution of this work is threefold:
– A simple data augmentation (DA) mechanism (StatMix), dedicated to the FL setup, that limits the amount of communication between participating nodes, is proposed.
– StatMix is evaluated on two different CNNs, with the number of FL nodes ranging from 5 to 50, and shows promising results: on CIFAR-10, it improves the baseline by between 0.3% and 7.5% (relative), depending on the architecture and the number of nodes.
– It is shown that the standard set of simple DAs, typically used for CIFAR datasets, is not well suited for the FL scenario, as it deteriorates performance as the number of samples per FL node decreases.
2 Related Work
Federated Learning. Since an FL system is usually a combination of algorithms, each research contribution can be regarded and analysed from different angles.
D. Lewy et al.
Typical FL aspects include: (1) whether the data is partitioned horizontally or vertically [30], (2) which models are used (some require dedicated algorithms, e.g. trees [27], while others can be addressed with more general methods, like SGD [21]), (3) whether the global model is updated during the training process [19], or only once, when all local models have already been trained [25], (4) what (if any) mechanism guarantees the privacy of the data [3], and/or (5) how efficient the process of sharing information between the parties of the system is [10]. The idea of FL was introduced in [9], which proposed using asynchronous SGD to update a global model in a distributed fashion. Currently, the most common approach is FedAvg [19], which, in each communication round, performs training on a fraction of nodes using their local data and, at the end of the round, averages the model weights on the server. Subsequent works in this area focused either on making the process more effective [6,10] or on addressing particular data-related scenarios (e.g. the non-IID setup [1,16,28]). Since StatMix shares only very limited information between nodes, and due to space limits, privacy guarantees and communication efficiency are not discussed further in this literature review.

Data Augmentation. Another research area relevant to the scope of this paper is DA [14], especially methods dedicated to the FL setup. An interesting line of research adapts Mixup [26] to the FL regime [20,22,23]. However, this requires sending mixed data to the server, which renders these methods expensive in terms of communication. Moreover, in some cases this could lead to privacy violations if a small number of samples is selected for mixing. An alternative approach to DA is the use of GANs for local node DA [7,8].
These approaches require samples from private node data to be shared with the server for the purpose of training a GAN, which is subsequently downloaded to each node to generate additional synthetic samples. Another approach to synthetic data generation uses models trained with, for instance, FedAvg to generate samples from the statistics stored in the batch normalization layers, using zero-shot learning [4]. Yet another stream of research, which to the best of our knowledge has not yet been applied to FL problems and which inspired this work, is MixStyle [29]. MixStyle is dedicated to the problem of Domain Generalization (DG), i.e. the construction of classifiers that are robust to domain shift and able to generalize to unseen domains. To this end, MixStyle, similarly to Mixup-based methods, performs sample mixing; however, instead of mixing pixels, it mixes the instance-level feature statistics of two images, computed inside the neural network.
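MixStyle's statistic mixing can be illustrated with a simplified sketch for one feature channel, given as a flat list of floats. It follows the published MixStyle formulation (normalize with the sample's own mean and standard deviation, de-normalize with a convex combination of its own and another sample's statistics, with the weight λ drawn from Beta(α, α)), but the function names, the `eps` guard and the single-channel setting are illustrative assumptions; MixStyle itself operates on CNN feature maps and pairs samples within a shuffled mini-batch.

```python
import random
import statistics


def mix_statistics(x, x_other, lam, eps=1e-6):
    """MixStyle-style re-styling of one feature channel (flat list of floats).

    The channel is normalized with its own mean/std and de-normalized with a
    convex combination (weight lam) of its own and x_other's statistics.
    """
    mu, sigma = statistics.fmean(x), statistics.pstdev(x)
    mu_o, sigma_o = statistics.fmean(x_other), statistics.pstdev(x_other)
    mu_mix = lam * mu + (1 - lam) * mu_o
    sigma_mix = lam * sigma + (1 - lam) * sigma_o
    return [sigma_mix * (v - mu) / (sigma + eps) + mu_mix for v in x]


def sample_lambda(alpha=0.1, rng=random):
    """Draw the mixing weight from Beta(alpha, alpha), as in MixStyle."""
    return rng.betavariate(alpha, alpha)
```

With λ = 1 the channel is (up to `eps`) unchanged; with λ = 0 it takes on the other sample's mean and standard deviation. The latter has the same form as the per-channel re-styling of equations (3)-(4) in Sect. 3, which StatMix applies to input pixels using statistics shared between nodes.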
3 Proposed Approach
In a typical FL scenario, there are two main components: nodes which contain local data that cannot be shared (e.g., due to privacy reasons), and a server that coordinates the process of information exchange. In certain FL implementations the central server is not used, and participating nodes communicate directly.
Algorithm 1. StatMix
Local part 1:
1: K ← number of images in the node; N ← number of nodes
2: for i = 1, 2, ..., N do
3:   for k = 1, 2, ..., K do
4:     Calculate all the image statistics according to equations (1)-(2)
5:     S_ik = {μ(x_ik)_1, μ(x_ik)_2, μ(x_ik)_3, σ(x_ik)_1, σ(x_ik)_2, σ(x_ik)_3}
6:   end for
7: end for
8: Share statistics with the server
Server part:
9: Distribute statistics to all nodes
Local part 2:
10: for i = 1, 2, ..., N do
11:   for epoch = 1, 2, ..., max_epoch do
12:     for batch = 1, 2, ..., max_batch do
13:       if random(0, 1) < P_StatMix then
14:         Randomly select a set of statistics S_jm, j ∈ {1, ..., N}, m ∈ {1, ..., K}
15:         Normalize images from the batch using equation (3)
16:         Apply augmentation using equation (4)
17:       end if
18:     end for
19:   end for
20: end for
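The data flow of Algorithm 1 (local statistics in lines 1-7, pooling and distribution on the server in lines 8-9, and the per-batch random choice in lines 13-14) can be sketched in plain Python. This is an illustrative sketch, not the authors' implementation: images are represented as lists of channels, each a flat list of pixel values, all function names are hypothetical, the statistics follow equations (1)-(2) as population mean and standard deviation, and the re-styling of lines 15-16 is omitted here.

```python
import random
import statistics


def image_stats(image):
    """S_ik: per-channel means followed by per-channel standard deviations."""
    means = [statistics.fmean(channel) for channel in image]
    stds = [statistics.pstdev(channel) for channel in image]
    return tuple(means + stds)


def local_part_1(node_images):
    """Lines 1-7 for one node: one statistics tuple per local image."""
    return [image_stats(img) for img in node_images]


def server_part(stats_per_node):
    """Lines 8-9: pool all nodes' statistics and send the pool to every node."""
    pooled = [s for node_stats in stats_per_node for s in node_stats]
    return [list(pooled) for _ in stats_per_node]


def pick_stats_for_batch(pooled, p_statmix, rng):
    """Lines 13-14: with probability p_statmix, choose one S_jm uniformly."""
    if rng.random() < p_statmix:
        return rng.choice(pooled)
    return None  # this batch is trained without StatMix
```

Only the pooled tuples ever leave a node, never pixels. For CIFAR-10, for example, each node would end up holding 50 000 × 6 = 300 000 numbers in total, independently of the number of nodes.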
The goal of this work is to increase the accuracy of classifiers trained in individual nodes by using limited statistical information (delivered by all nodes and aggregated on the server), without sending or storing any actual data. Overall, the proposed approach can be characterized as follows (see Fig. 1 and Algorithm 1):
(a) calculation of image statistics (mean and standard deviation per color channel) in individual nodes (Local part 1 in Algorithm 1);
(b) distribution of the calculated statistics to all nodes via the central server (Server part in Algorithm 1);
(c) use of these statistics in individual nodes to perform style-transfer-like augmentation of the images in that node (Local part 2 in Algorithm 1).

Fig. 1. The figure is composed of two components: a central server (storing only image statistics) and nodes (storing subsets of images and image statistics obtained from the server). The flow shows the application of statistics calculated in one node to augment images in another node. For instance, node 1 shares the statistics of a plane image with node N, based on which an augmented image of a dog is created.

Fig. 2. The first column shows original images; the remaining columns show these images augmented with the statistics of various images (each column uses a different set of image statistics).

Local Part 1. This is the first step of the algorithm. In each node i = 1, ..., N, for each locally stored image x_ik, k = 1, ..., K, where x_ik ∈ R^{W×H×C} (W, H and C denote the width, the height and the number of color channels, respectively), the mean and the standard deviation of the image pixels are calculated separately for each color channel c = 1, 2, 3, using the following equations:

$$\mu(x_{ik})_c = \frac{1}{HW} \sum_{w=1}^{W} \sum_{h=1}^{H} x_{ik}[w, h, c] \tag{1}$$

$$\sigma(x_{ik})_c = \sqrt{\frac{1}{HW} \sum_{w=1}^{W} \sum_{h=1}^{H} \left( x_{ik}[w, h, c] - \mu(x_{ik})_c \right)^2} \tag{2}$$
where x_ik[w, h, c] is the value of pixel [w, h] of image x_ik in color channel c. These 6 statistics form the set S_ik, which is used in Local Part 2 for image augmentation.

Server Part. In the second step, all sets S_ik, i = 1, ..., N, k = 1, ..., K, are distributed to all N nodes, i.e. each node now stores, in addition to its K local (private) images, N · K sets of statistics.

Local Part 2. Next, the augmentation takes place. In each node i = 1, ..., N, all images x_ik located in that node (k = 1, ..., K) are randomly divided into max_batch batches. Then, for each batch, an image x_jm is selected uniformly at random from all N · K images (i.e. including those located in the given node), and its set of statistics S_jm is applied to augment all images in the batch using equations (3)-(4). This augmentation procedure is applied independently to each batch with probability P_StatMix.

$$x^{\mathrm{norm}}_{ik,c} = \frac{x_{ik,c} - \mu(x_{ik})_c}{\sigma(x_{ik})_c} \tag{3}$$

$$x^{\mathrm{augment}}_{ik,c} = x^{\mathrm{norm}}_{ik,c} \cdot \sigma(x_{jm})_c + \mu(x_{jm})_c \tag{4}$$
Note that augmentation procedure (3)-(4) is applied independently to all 3 color channels. Example results of StatMix augmentation are depicted in Fig. 2.
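Equations (3)-(4) can be sketched for a single image as follows. The function name, the channel-list image layout and the small `eps` guard against constant channels are assumptions not found in the paper; the sketch recomputes μ and σ of the image being augmented as in equations (1)-(2).

```python
import statistics


def statmix_augment(image, target_stats, eps=1e-8):
    """Apply equations (3)-(4) to one image.

    image: list of C channels, each a flat list of pixel values.
    target_stats: S_jm = (mu_1, ..., mu_C, sigma_1, ..., sigma_C) of another image.
    """
    c = len(image)
    augmented = []
    for ch, values in enumerate(image):
        mu = statistics.fmean(values)
        sigma = statistics.pstdev(values)
        mu_t, sigma_t = target_stats[ch], target_stats[c + ch]
        # Eq. (3): normalize with the image's own statistics (eps avoids division by zero)
        # Eq. (4): de-normalize with the selected image's statistics
        augmented.append([(v - mu) / (sigma + eps) * sigma_t + mu_t for v in values])
    return augmented
```

In a training loop, the same `target_stats` would be applied to every image of a selected batch. Each augmented image takes on the target's per-channel mean and (up to `eps`) standard deviation while keeping its own spatial structure, which is what preserves the label.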
4 Experimental Setup
The experiments were conducted on two popular datasets. CIFAR-10 [11] consists of 50 000 training and 10 000 test color images of size 32 × 32, grouped into 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship and truck). There are 5 000 and 1 000 samples of each class in the training and test datasets, respectively. CIFAR-100 [12] is a more fine-grained version of CIFAR-10, with 100 classes. Each class has 500 representatives in the training dataset and 100 in the test dataset.

In order to simulate the FL scenario, let P denote the set of all training images in a given dataset (CIFAR-10 or CIFAR-100, respectively). P was randomly divided, in a stratified manner, into N disjoint subsets (P_1, ..., P_N) of equal size, using the labels so that each P_i has the same class distribution as the whole set P. Subsequently, each part was transferred (assigned) to a separate FL node that was connected only to the server (i.e. there were no connections between FL nodes). At this point, each of the N nodes calculated the statistics of the images located in it and transferred them to the server. Next, for each node i = 1, ..., N, the server shared the individual statistics of all images not located in node i, i.e. all images from P \ P_i. Based on this, the images located in node i, i.e. those belonging to P_i, could be augmented (with a certain probability) using image statistics from the entire dataset P, according to the approach described in Sect. 3. The augmented sets PA_i, i = 1, ..., N, were used to train the model (one of the 2 deep architectures described in the following paragraph). Afterwards, the trained model was tested on the entire test part of the respective CIFAR dataset.

Two popular architectures were tested during the experiments: PreActResNet18 [5] and DLA [24]. The models belong to different families and offer decent accuracy in non-FL scenarios. The SGD optimizer with an initial learning rate of 0.01 and momentum of 0.9 was used. The learning rate was decreased from its initial value to 0 over the course of training using cosine annealing [18]. In all experiments that use standard DA, random image crop and random horizontal flip were applied [13]. For consistency, all models were trained for 200 epochs with a batch size of 128. The experiments were run 3 times for each N = 1, 5, 10, 50, with the probability of applying statistics-based augmentation set to 0.5.
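The stratified partitioning and the learning-rate schedule described above can be sketched as follows. The helper names are hypothetical. The split assumes that each class count is divisible by N, which holds here because 5 000 and 500 are both divisible by 1, 5, 10 and 50. The schedule assumes per-epoch cosine annealing without restarts from 0.01 to 0; the paper does not state whether the rate is updated per epoch or per step.

```python
import math
import random
from collections import defaultdict


def stratified_split(labels, n_nodes, seed=0):
    """Split sample indices into n_nodes disjoint, equally sized parts that
    share the class distribution of the whole set.

    Assumes every class count is divisible by n_nodes.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    parts = [[] for _ in range(n_nodes)]
    for idxs in by_class.values():
        rng.shuffle(idxs)
        share = len(idxs) // n_nodes
        for node in range(n_nodes):
            parts[node].extend(idxs[node * share:(node + 1) * share])
    return parts


def cosine_lr(epoch, total_epochs, lr0=0.01):
    """Cosine annealing from lr0 (epoch 0) down to 0 (final epoch)."""
    return 0.5 * lr0 * (1 + math.cos(math.pi * epoch / total_epochs))
```

With N = 50 on CIFAR-100, for example, each node receives exactly 10 images per class, which is the small per-class sample size referred to in the discussion of Table 2.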
5 Experimental Results and Analysis
CIFAR-10 results are presented in Table 1. In all FL experiments (N > 1), the application of StatMix boosts the final accuracy compared to the baseline without StatMix. For PreActResNet18, the gain grows with the number of nodes (at least up to 50 nodes, as tested here); for DLA, it is largest at N = 10 and smaller, though still positive, at N = 50. It is worth noting that the augmentation method proposed for the FL setup also works in a non-FL scenario (N = 1): an improvement can be observed in all 4 cases (cf. column Diff [%] in Table 1). Lastly, standard DAs (random crop and horizontal flip), often used with CIFAR data in non-FL scenarios, deteriorate the accuracy of training in the FL scenario (cf. row True vs. row False for a given architecture and N > 1).

For CIFAR-100, results are summarized in Table 2. Similar observations hold for this more fine-grained dataset. In the majority of cases, the application of StatMix improves the results compared to the baseline (i.e. the case without StatMix). However, this conclusion does not extend to 50 nodes: in this setup, adding StatMix lowers the accuracy in three of the four configurations. This is most likely caused by too high a noise-to-image ratio after augmentation, due to the 10 times smaller number of representatives of each class compared to CIFAR-10.
Table 1. Mean and standard deviation results for the CIFAR-10 dataset, averaged over the last 10 epochs and 3 experiment repetitions. Columns denote: number of nodes (N), model architecture, whether or not standard DA was applied, whether StatMix augmentation was used (0.0 – not used, 0.5 – used with probability 0.5), and the relative improvement of applying StatMix compared to not applying it, i.e. [mean(0.5) / mean(0.0) − 1].
N | Architecture | Standard DA | StatMix 0.0 Mean | StatMix 0.0 Std | StatMix 0.5 Mean | StatMix 0.5 Std | Diff [%]
1 | DLA | False | 86.02 | 0.80 | 86.58 | 0.47 | 0.65
1 | DLA | True | 93.26 | 0.28 | 93.83 | 0.19 | 0.61
1 | PreActResNet18 | False | 86.15 | 0.79 | 86.60 | 0.14 | 0.52
1 | PreActResNet18 | True | 93.54 | 0.05 | 93.79 | 0.13 | 0.27
5 | DLA | False | 67.32 | 1.15 | 69.47 | 0.70 | 3.19
5 | DLA | True | 63.39 | 1.03 | 66.24 | 0.89 | 4.50
5 | PreActResNet18 | False | 70.83 | 0.44 | 72.01 | 0.55 | 1.67
5 | PreActResNet18 | True | 68.22 | 0.64 | 69.12 | 0.33 | 1.32
10 | DLA | False | 56.06 | 1.27 | 58.97 | 1.09 | 5.19
10 | DLA | True | 50.72 | 1.45 | 54.54 | 1.59 | 7.53
10 | PreActResNet18 | False | 60.72 | 0.64 | 62.03 | 0.76 | 2.16
10 | PreActResNet18 | True | 56.63 | 0.77 | 58.69 | 0.74 | 3.64
50 | DLA | False | 37.47 | 1.20 | 38.06 | 1.42 | 1.57
50 | DLA | True | 34.06 | 1.11 | 34.65 | 1.39 | 1.73
50 | PreActResNet18 | False | 38.62 | 0.96 | 40.28 | 1.08 | 4.30
50 | PreActResNet18 | True | 35.01 | 1.07 | 36.93 | 1.21 | 5.48
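The Diff [%] column can be reproduced from the two mean columns with a one-line computation (the helper name is hypothetical). Note that the column reports a relative change, not a difference in percentage points.

```python
def relative_gain_percent(mean_with_statmix, mean_without_statmix):
    """Diff [%] = (mean(0.5) / mean(0.0) - 1) * 100, rounded to two decimals."""
    return round((mean_with_statmix / mean_without_statmix - 1) * 100, 2)


# DLA, N = 10, standard DA: 54.54 vs. 50.72 gives 7.53 (an absolute gain of 3.82 points)
print(relative_gain_percent(54.54, 50.72))
```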
The conclusion that StatMix is generally beneficial in non-FL scenarios (N = 1) also holds for CIFAR-100 (cf. the rightmost column of Table 2): in 3 out of 4 cases (including both cases with standard DA), adding StatMix augmentation improves the results. The observation that the standard augmentation methods commonly used in the literature do not help in FL scenarios also holds for CIFAR-100. A possible reason is that augmentation introduces some noise, and the network cannot distill the true patterns from the limited amount of clean data available in each node.

5.1 Ablation Study
In order to check how the probability of applying StatMix impacts final classification accuracy, additional experiments were performed using PreActResNet18 architecture (which is less computationally intensive than DLA) and a setup with 5 nodes (N = 5). All remaining training parameters were adopted from the base experiments. Probabilities ranging from 0 to 1, with a step of 0.1, were tested.
Table 2. Mean and standard deviation results for the CIFAR-100 dataset, averaged over the last 10 epochs and 3 experiment repetitions. Columns denote: number of nodes (N), model architecture, whether or not standard DA was applied, whether StatMix augmentation was used (0.0 – not used, 0.5 – used with probability 0.5), and the relative improvement of applying StatMix compared to not applying it, i.e. [mean(0.5) / mean(0.0) − 1].
N | Architecture | Standard DA | StatMix 0.0 Mean | StatMix 0.0 Std | StatMix 0.5 Mean | StatMix 0.5 Std | Diff [%]
1 | DLA | False | 59.29 | 2.08 | 58.11 | 0.87 | −1.99
1 | DLA | True | 73.40 | 0.26 | 75.25 | 0.46 | 2.52
1 | PreActResNet18 | False | 54.99 | 2.73 | 55.84 | 2.21 | 1.55
1 | PreActResNet18 | True | 71.83 | 0.49 | 73.63 | 0.22 | 2.51
5 | DLA | False | 26.46 | 0.49 | 28.04 | 0.53 | 5.97
5 | DLA | True | 22.84 | 0.71 | 24.84 | 0.60 | 8.76
5 | PreActResNet18 | False | 31.02 | 0.58 | 31.39 | 0.58 | 1.19
5 | PreActResNet18 | True | 27.70 | 0.60 | 28.63 | 0.59 | 3.36
10 | DLA | False | 19.86 | 0.59 | 20.49 | 0.66 | 3.17
10 | DLA | True | 16.48 | 0.57 | 17.80 | 0.92 | 8.01
10 | PreActResNet18 | False | 22.32 | 0.41 | 22.86 | 0.50 | 2.42
10 | PreActResNet18 | True | 19.37 | 0.50 | 20.33 | 0.57 | 4.96
50 | DLA | False | 9.65 | 0.64 | 9.56 | 0.72 | −0.93
50 | DLA | True | 7.83 | 0.69 | 7.77 | 0.74 | −0.77
50 | PreActResNet18 | False | 10.74 | 0.46 | 10.48 | 0.56 | −2.42
50 | PreActResNet18 | True | 9.15 | 0.45 | 9.20 | 0.48 | 0.55
The results for CIFAR-10 and CIFAR-100 are presented in Fig. 3. The charts show that both not applying StatMix at all and applying it to the majority of batches (more than 80% for CIFAR-10 and more than 60% for CIFAR-100) yield the worst results. Interestingly, applying StatMix to all batches (P_StatMix = 1) results in a severe accuracy drop relative to no augmentation (to 63% for CIFAR-10 and to 19% for CIFAR-100). These results have been excluded from the charts so that they do not obscure the other findings. For CIFAR-10, all probabilities between 0.1 and 0.8 have a positive impact, with no clearly best value. Hence, as long as StatMix is applied to only a fraction of the batches, it boosts accuracy. For CIFAR-100, experiments with a lower P_StatMix (between 0.1 and 0.4) achieve better final accuracy. A possible explanation is that CIFAR-100 is a more complex dataset, for which introducing too much noise through StatMix augmentation is no longer beneficial. This suggests that the CIFAR-100 results in Table 2 (obtained with P_StatMix = 0.5) could potentially be improved further by decreasing the probability of StatMix application.
StatMix : Data Augmentation Method for Federated Learning
Fig. 3. CIFAR10 and CIFAR100 test accuracy as a function of the probability of applying StatMix in the FL setup with 5 nodes (N = 5) on the PreActResNet18 architecture. The values are averaged over the last 10 epochs and 3 independent experiment repetitions. For each dataset, the left figure refers to experiments that use standard input DA, and the right one presents results without it.
6 Concluding Remarks
In this work, StatMix, a novel DA method designed for FL, has been introduced. StatMix exchanges high-level image statistics (two values per color channel). As a result, data privacy remains protected. At the same time, it has been empirically validated that using this method improves model accuracy over baseline training (without StatMix) for two standard benchmark datasets and two popular CNN architectures. Furthermore, StatMix improves performance in the classical, non-FL setup, where the method helped in the majority of cases. While StatMix demonstrates very promising results, future work will aim at verifying whether sharing additional statistics (e.g. those related to hidden layers of the trained networks) could be beneficial. However, such an approach would be more computationally expensive, since it would require the local networks (in each node) to be trained at least twice. The first training would be needed to calculate image statistics in selected hidden layers of the network during the inference phase. These hidden-layer statistics could then be distributed to all nodes and used in the training of the final models, similarly to the current StatMix specification. Verification of this approach is planned as the next step in StatMix development. Other directions of future research include adding StatMix to an already existing approach such as FedAvg and testing it on a bigger data set such as ImageNet.

Acknowledgements. Research funded in part by the Centre for Priority Research Area Artificial Intelligence and Robotics of Warsaw University of Technology within the Excellence Initiative: Research University (IDUB) programme.
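To make "two values per color channel" concrete, here is a hypothetical NumPy sketch. It assumes (this section does not say so) that the two shared values are the per-channel mean and standard deviation, and that a local image is augmented by standardizing each of its channels and re-applying statistics received from another node. The exact StatMix procedure is defined earlier in the paper and may differ. The per-batch gate mirrors $P_{StatMix}$ from the experiments, and all function names are illustrative.

```python
import numpy as np


def channel_stats(image: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Per-channel (mean, std) of an H x W x C image: two values per colour channel."""
    mean = image.mean(axis=(0, 1))
    std = image.std(axis=(0, 1)) + eps  # eps avoids division by zero later
    return np.stack([mean, std])  # shape (2, C)


def restyle(image: np.ndarray, target_stats: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Hypothetical restyling: standardise each channel, then apply received statistics."""
    own = channel_stats(image, eps)
    standardised = (image - own[0]) / own[1]
    return standardised * target_stats[1] + target_stats[0]


def maybe_augment_batch(batch: np.ndarray, shared_stats: list, p_statmix: float,
                        rng: np.random.Generator) -> np.ndarray:
    """With probability p_statmix, restyle every image in the batch using statistics
    drawn at random from a pool received from other nodes (assumed protocol)."""
    if rng.random() >= p_statmix:
        return batch  # batch left unchanged
    picks = rng.integers(len(shared_stats), size=len(batch))
    return np.stack([restyle(img, shared_stats[k]) for img, k in zip(batch, picks)])


# Usage with random stand-in data (no real CIFAR images involved).
rng = np.random.default_rng(0)
local_batch = rng.random((4, 32, 32, 3))
foreign_pool = [channel_stats(0.5 * rng.random((32, 32, 3))) for _ in range(10)]
augmented = maybe_augment_batch(local_batch, foreign_pool, p_statmix=0.5, rng=rng)
print(augmented.shape)  # (4, 32, 32, 3)
```

In this design, only the (2, C) statistics arrays would leave a node, never the images themselves. Setting `p_statmix` to 0 or 1 reproduces the two extremes discussed for Fig. 3.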
References
1. Danilenka, A., Ganzha, M., Paprzycki, M., Mańdziuk, J.: Using adversarial images to improve outcomes of federated learning for non-iid data. CoRR abs/2206.08124 (2022)
D. Lewy et al.
2. Deng, J., Dong, W., Socher, R., Li, L.J., Li, K., Fei-Fei, L.: ImageNet: a large-scale Hierarchical Image Database. In: CVPR09 (2009)
3. Geyer, R.C., Klein, T., Nabi, M.: Differentially private federated learning: a client level perspective. CoRR abs/1712.07557 (2017)
4. Hao, W., et al.: Towards fair federated learning with zero-shot data augmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops, pp. 3310–3319, June 2021
5. He, K., Zhang, X., Ren, S., Sun, J.: Identity mappings in deep residual networks. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9908, pp. 630–645. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46493-0_38
6. Hsieh, K., Phanishayee, A., Mutlu, O., Gibbons, P.B.: The non-iid data quagmire of decentralized machine learning. In: Proceedings of the 37th International Conference on Machine Learning, ICML 2020, 13–18 July 2020, Virtual Event. Proceedings of Machine Learning Research, vol. 119, pp. 4387–4398. PMLR (2020)
7. Jeong, E., Oh, S., Kim, H., Park, J., Bennis, M., Kim, S.: Communication-efficient on-device machine learning: Federated distillation and augmentation under non-iid private data. CoRR abs/1811.11479 (2018)
8. Jeong, E., Oh, S., Park, J., Kim, H., Bennis, M., Kim, S.L.: Hiding in the crowd: Federated data augmentation for on-device learning. IEEE Intell. Syst. 36(5), 80–87 (2021). https://doi.org/10.1109/MIS.2020.3028613
9. Konečný, J., McMahan, H.B., Ramage, D., Richtárik, P.: Federated optimization: distributed machine learning for on-device intelligence. CoRR abs/1610.02527 (2016)
10. Konečný, J., McMahan, H.B., Yu, F.X., Richtárik, P., Suresh, A.T., Bacon, D.: Federated learning: strategies for improving communication efficiency. CoRR abs/1610.05492 (2016)
11. Krizhevsky, A., Nair, V., Hinton, G.: CIFAR-10 (Canadian Institute for Advanced Research) (2009). http://www.cs.toronto.edu/~kriz/cifar.html
12. Krizhevsky, A., Nair, V., Hinton, G.: CIFAR-100 (Canadian Institute for Advanced Research) (2009). http://www.cs.toronto.edu/~kriz/cifar.html
13. Krizhevsky, A., Sutskever, I., Hinton, G.E.: Imagenet classification with deep convolutional neural networks. In: Bartlett, P.L., Pereira, F.C.N., Burges, C.J.C., Bottou, L., Weinberger, K.Q. (eds.) Advances in Neural Information Processing Systems 25: 26th Annual Conference on Neural Information Processing Systems 2012. Proceedings of a meeting held December 3–6, 2012, Lake Tahoe, Nevada, United States, pp. 1106–1114 (2012). https://proceedings.neurips.cc/paper/2012/hash/c399862d3b9d6b76c8436e924a68c45b-Abstract.html
14. Lewy, D., Mandziuk, J.: An overview of mixing augmentation methods and augmentation strategies. Artif. Intell. Rev. 56, 2111–2169 (2023). https://doi.org/10.1007/s10462-022-10227-z
15. Li, Q., et al.: A survey on federated learning systems: Vision, hype and reality for data privacy and protection. CoRR abs/1907.09693 (2019)
16. Li, T., Sahu, A.K., Zaheer, M., Sanjabi, M., Talwalkar, A., Smith, V.: Federated optimization in heterogeneous networks. In: Dhillon, I.S., Papailiopoulos, D.S., Sze, V. (eds.) Proceedings of Machine Learning and Systems 2020, MLSys 2020, Austin, TX, USA, 2–4 March 2020. mlsys.org (2020)
17. Lin, T.-Y., et al.: Microsoft coco: common objects in context. In: Fleet, D., Pajdla, T., Schiele, B., Tuytelaars, T. (eds.) ECCV 2014. LNCS, vol. 8693, pp. 740–755. Springer, Cham (2014). https://doi.org/10.1007/978-3-319-10602-1_48
18. Loshchilov, I., Hutter, F.: SGDR: stochastic gradient descent with warm restarts. In: 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, 24–26 April 2017, Conference Track Proceedings. OpenReview.net (2017). https://openreview.net/forum?id=Skq89Scxx
19. McMahan, B., Moore, E., Ramage, D., Hampson, S., y Arcas, B.A.: Communication-efficient learning of deep networks from decentralized data. In: Singh, A., Zhu, X.J. (eds.) Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, AISTATS 2017, 20–22 April 2017, Fort Lauderdale, FL, USA. Proceedings of Machine Learning Research, vol. 54, pp. 1273–1282. PMLR (2017)
20. Oh, S., Park, J., Jeong, E., Kim, H., Bennis, M., Kim, S.: Mix2fld: downlink federated learning after uplink federated distillation with two-way mixup. IEEE Commun. Lett. 24(10), 2211–2215 (2020)
21. Ruder, S.: An overview of gradient descent optimization algorithms. arXiv preprint arXiv:1609.04747 (2016)
22. Shin, M., Hwang, C., Kim, J., Park, J., Bennis, M., Kim, S.: XOR mixup: privacy-preserving data augmentation for one-shot federated learning. CoRR abs/2006.05148 (2020)
23. Yoon, T., Shin, S., Hwang, S.J., Yang, E.: Fedmix: approximation of mixup under mean augmented federated learning. In: 9th International Conference on Learning Representations, ICLR 2021, Virtual Event, Austria, 3–7 May 2021. OpenReview.net (2021)
24. Yu, F., Wang, D., Shelhamer, E., Darrell, T.: Deep layer aggregation. In: 2018 IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2018, Salt Lake City, UT, USA, 18–22 June 2018, pp. 2403–2412. Computer Vision Foundation/IEEE Computer Society (2018). https://doi.org/10.1109/CVPR.2018.00255
25. Yurochkin, M., Agarwal, M., Ghosh, S., Greenewald, K.H., Hoang, T.N., Khazaeni, Y.: Bayesian nonparametric federated learning of neural networks. In: Chaudhuri, K., Salakhutdinov, R. (eds.) Proceedings of the 36th International Conference on Machine Learning, ICML 2019, 9–15 June 2019, Long Beach, California, USA. Proceedings of Machine Learning Research, vol. 97, pp. 7252–7261. PMLR (2019)
26. Zhang, H., Cissé, M., Dauphin, Y.N., Lopez-Paz, D.: mixup: beyond empirical risk minimization. In: 6th International Conference on Learning Representations, ICLR 2018, Vancouver, BC, Canada, April 30–May 3, 2018, Conference Track Proceedings. OpenReview.net (2018)
27. Zhao, L., Ni, L., Hu, S., Chen, Y., Zhou, P., Xiao, F., Wu, L.: Inprivate digging: enabling tree-based distributed data mining with differential privacy. In: 2018 IEEE Conference on Computer Communications, INFOCOM 2018, Honolulu, HI, USA, 16–19 April 2018, pp. 2087–2095. IEEE (2018). https://doi.org/10.1109/INFOCOM.2018.8486352
28. Zhao, Y., Li, M., Lai, L., Suda, N., Civin, D., Chandra, V.: Federated learning with non-iid data. CoRR abs/1806.00582 (2018)
29. Zhou, K., Yang, Y., Qiao, Y., Xiang, T.: Domain generalization with mixstyle. CoRR abs/2104.02008 (2021)
30. Zhuang, Y., Li, G., Feng, J.: A survey on entity alignment of knowledge base. J. Comput. Res. Dev. 53(1), 165 (2016). https://doi.org/10.7544/issn1000-1239.2016.20150661
Classification by Components Including Chow’s Reject Option
Mehrdad Mohannazadeh Bakhtiari(1,2)(B) and Thomas Villmann(1,2)(B)
1 Saxon Institute for Computer Intelligence and Machine Learning, Mittweida, Germany
2 University of Applied Sciences Mittweida, Mittweida, Germany
{mmohanna,thomas.villmann}@hs-mittweida.de.de
Abstract. In this paper, we present an approach for integrating reject options into Classification-by-Components networks. Classification-by-Components networks are relatively accurate classifiers that offer fair interpretability. Yet, their performance can be increased by allowing the rejection of uncertain classifications. We modify the original Classification-by-Components model so that its adaptive parameters are trained while taking reject options into account.
Keywords: Classification-by-Components · Reject Option · Neural Networks · Probabilistic Classifier

1 Introduction
Classification by Components (CbC) networks [1] are robust, interpretable classifiers with high performance. Nonetheless, reject options [2] can further improve the performance of CbC when a wrong classification is costly. Therefore, in this paper, we investigate the details of reject options for CbC networks. Much effort has gone into modifying reject options since the original work of Chow. However, most of the papers in the field are not directly relatable or applicable to CbC networks; here, we mention a few such works. A paper by Musavishavazi et al. [3] has a mathematical formulation comparable to this paper. Nonetheless, it cannot be used immediately for CbC because of the contrastive CbC objective function. In another paper, Y. Geifman and R. El-Yaniv [4] designed a selective classifier which is more suitable for pre-trained deep neural networks. A similar approach is RISAN [5], whose reject-classification formulation is closer to our work in this paper. Reject options are introduced for classification systems in which a wrong decision can lead to catastrophic outcomes [2,6]. If a classifier is not sure about the class of an input, it can refuse to label it at a lower cost. Originally, Chow [6] defined reject options, and the optimal decision strategy for an optimum system [2] was determined. The label of each data point is crisp, and the conditional probabilities are estimated by a Bayesian estimator. Note that Chow's reject option is originally applied to a previously trained model. Later, variations of reject options emerged [7,8] in which the adaptive parameters of the classifier are trained while reject options are taken into account. In this paper, we adopt this approach and use Chow's optimal threshold as our threshold. However, the model that estimates the conditional probabilities for prediction will be slightly different and, consequently, the prediction probabilities will be different. Also, the error-reject curve, as described by Chow [2], will be different for CbC with reject option. Villmann et al. [9] introduced a general loss function that attains Chow's optimal threshold. This loss function, which is a weighted sum of the error and rejection rates, is the inspiration for this paper.

The paper is structured as follows: In Sect. 2, CbC networks are introduced. Section 3 provides background on reject options. In Sect. 4, reject options are investigated for a classifier with contrastive loss. Section 5 contains the gradients with respect to the adaptive parameters, for the sake of comparison with the original CbC. The experiments and simulations are in Sect. 6. Finally, we conclude the paper in Sect. 7.

M.M.B. is supported by ESP.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 586–596, 2023. https://doi.org/10.1007/978-981-99-1639-9_49
2 The Original CbC
A brief description of the CbC model [1] is given in this section. We define components as a set $K = \{K_1, \dots, K_k\}$ existing in the data space $X \subset \mathbb{R}^m$. Also, a set $C = \{1, 2, \dots, c\}$ of data classes is defined. The detection probability function $d_i : X \to [0, 1]$ determines the probability of detecting the component $K_i$ in the data point $x$. Hence, $d_i(x) = 0$ means the full absence of component $K_i$ in $x$, and $d_i(x) = 1$ means its full presence. For a data point $x$, we define the detection vector
$$d(x) = [d_1(x), \dots, d_k(x)]^{T} \tag{1}$$
To study the effect of components, the reasoning quantity $r_{ij}^{+} \in (0, 1)$ is introduced as the probability that the component $K_i$ is important and must be detected to support the class hypothesis $j$, according to Biederman's cognitive model [10]. Additionally, negative reasoning is considered in CbC by a reasoning quantity $r_{ij}^{-} \in (0, 1)$, taken as the probability that the component $K_i$ is important and must not be detected to support the class hypothesis $j$. Finally, an indefinite reasoning is introduced: if the presence of the component $K_i$ reveals no information about the occurrence of class $j$, then component $K_i$ has a non-vanishing neutral (indefinite) reasoning over class $j$, i.e. $r_{ij}^{0}$ is interpreted as the probability that component $K_i$ is not important for class hypothesis $j$. We assume that the three quantities satisfy the restriction $r_{ij}^{+} + r_{ij}^{-} + r_{ij}^{0} = 1$.
Given a classification problem, i.e. a training set $T = \{(x_b, y_b)\}_{b=1}^{\tau}$ with $x_b \in X \subset \mathbb{R}^m$ and $y_b \in C$, we consider a set of trainable components $K_i \in K$ together with adjustable reasoning parameters $r_{ij}^{+}, r_{ij}^{-}, r_{ij}^{0}$. According to [1], CbC minimizes the contrastive loss
$$l(x, y) = \varphi\big(\max\{p_j(x) \mid j \neq y,\ j \in C\} - p_y(x)\big) \tag{2}$$
where $(x, y) \in T$ and $\varphi : [-1, 1] \to \mathbb{R}$ is a monotonically increasing function. The class hypothesis possibilities $\hat{p}_j(x)$ are defined as
$$\hat{p}_j(x) = \frac{\sum_i \big(r_{ij}^{+} \cdot d_i(x) + r_{ij}^{-} \cdot (1 - d_i(x))\big)}{\sum_i \big(1 - r_{ij}^{0}\big)} \tag{3}$$
The possibilistic vector is normalized to obtain the class probabilities
$$p_j(x) = \frac{\hat{p}_j(x)}{\sum_i \hat{p}_i(x)} \tag{4}$$
In the original CbC, optimal components are found using a Convolutional Neural Network (CNN) feature extractor [11], and the reasoning parameters are found using Stochastic Gradient Descent (SGD).
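Equations (2)–(4) can be sketched in NumPy for a single input as follows. This is an illustrative sketch under our own conventions: the reasoning parameters are stored as k × c arrays, and φ is taken to be the sigmoid that is introduced later in (15). It is not the reference implementation of [1].

```python
import numpy as np


def cbc_probabilities(d: np.ndarray, r_pos: np.ndarray, r_neg: np.ndarray,
                      r_ind: np.ndarray) -> np.ndarray:
    """Class probabilities of CbC, Eqs. (3)-(4).

    d     : detection vector d(x), shape (k,)
    r_pos : positive reasoning r_ij^+, shape (k, c)
    r_neg : negative reasoning r_ij^-, shape (k, c)
    r_ind : indefinite reasoning r_ij^0, shape (k, c)
    """
    # Eq. (3): class hypothesis possibilities p_hat_j(x)
    numerator = (r_pos * d[:, None] + r_neg * (1.0 - d[:, None])).sum(axis=0)
    denominator = (1.0 - r_ind).sum(axis=0)
    p_hat = numerator / denominator
    # Eq. (4): normalise to class probabilities
    return p_hat / p_hat.sum()


def contrastive_loss(p: np.ndarray, y: int, lam: float = 0.01) -> float:
    """Eq. (2) with phi chosen as the sigmoid of Eq. (15)."""
    best_wrong = np.delete(p, y).max()  # max_{j != y} p_j(x)
    delta_p = best_wrong - p[y]
    return float(1.0 / (1.0 + np.exp(-delta_p / lam)))


rng = np.random.default_rng(0)
k, c = 4, 2
r = rng.dirichlet(np.ones(3), size=(k, c))  # last axis (r+, r-, r0) sums to 1
r_pos, r_neg, r_ind = r[..., 0], r[..., 1], r[..., 2]
d = rng.random(k)                           # detection probabilities in [0, 1]
p = cbc_probabilities(d, r_pos, r_neg, r_ind)
print(p, contrastive_loss(p, y=0))
```

Drawing each triple $(r_{ij}^{+}, r_{ij}^{-}, r_{ij}^{0})$ from a Dirichlet distribution is just a convenient way to satisfy the sum-to-one restriction in this toy example.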
3 Reject Options for a Probabilistic Classifier
In this section, we introduce the original reject option [2], which is applied to a probabilistic classifier after the training phase. Then, we introduce an empirical loss function whose optimal threshold is the same as the optimal threshold of the original reject option. Optimizing this empirical loss function allows us to train CbC networks while taking reject options into account. The original reject option is as follows. We assume the training set $T = \{(x_b, y_b)\}_{b=1}^{\tau}$ as defined in the previous section. A probabilistic classifier $p_i(x)$ is trained to classify a new input $x$. Note that
$$\sum_i p_i(x) = 1 \quad \forall x \tag{5}$$
Chow [2] defined the decision-rejection rule, given a threshold $t$: $x$ is rejected if
$$m(x) < 1 - t \tag{6}$$
where $m(x) = \max_i \{p_i(x)\}$. Otherwise, $x$ is accepted for classification. The predicted class of the data point $x$ is
$$C(x) = \arg\max_i \{p_i(x)\} \tag{7}$$
Given that $c_r$, $c_e$, and $c_c$ are the costs of rejecting a data point, misclassifying a data point, and correctly classifying a data point, respectively, Chow [2] derived the optimal threshold to be $t^* = \frac{c_r - c_c}{c_e - c_c}$. Note that the order $c_c < c_r < c_e$ is an important assumption. Here, we assume $c_c = 0$, which gives the threshold
$$t^* = \frac{c_r}{c_e} \tag{8}$$
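Rules (6)–(8) translate directly into code. The following is a minimal sketch; the function names are our own.

```python
import numpy as np


def chow_threshold(c_r: float, c_e: float, c_c: float = 0.0) -> float:
    """Chow's optimal threshold t* = (c_r - c_c) / (c_e - c_c); Eq. (8) when c_c = 0."""
    if not (c_c < c_r < c_e):
        raise ValueError("expected c_c < c_r < c_e")
    return (c_r - c_c) / (c_e - c_c)


def chow_decide(p: np.ndarray, t: float):
    """Return the predicted class C(x) of Eq. (7), or None if x is rejected by Eq. (6)."""
    m = p.max()           # m(x)
    if m < 1.0 - t:
        return None       # reject
    return int(p.argmax())


t = chow_threshold(c_r=0.3, c_e=1.0)        # t* = 0.3, so accept only if m(x) >= 0.7
print(chow_decide(np.array([0.6, 0.4]), t))  # None (rejected)
print(chow_decide(np.array([0.2, 0.8]), t))  # 1
```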
Now, we define the general empirical loss function for a probabilistic classifier with reject option. Inspired by [9], the general loss function to be optimized is defined as
$$L(t) = c_e \cdot E(t) + c_r \cdot R(t) \tag{9}$$
where $E(t)$ and $R(t)$ are the error rate and the reject rate, respectively; both are functions of the threshold $t$. It can be proven that, given that the decisions are Bayes optimal, the optimal threshold of (9) is (8). A proof is provided in Appendix A. The empirical version of (9), which is more suitable for learning, is defined as follows.
$$L(t) = \frac{1}{|T|} \sum_{(x,y) \in T} c_e \cdot H\big(m(x) + t - 1\big) \cdot \big(1 - \delta(C(x), y)\big) + \frac{1}{|T|} \sum_{(x,y) \in T} c_r \cdot H\big(-m(x) - t + 1\big) \tag{10}$$
Here, $\delta(i, j)$, for $i, j \in C$, is the Kronecker delta, and $H(x)$, for $x \in \mathbb{R}$, is the Heaviside function. Note that the arguments of the Heaviside functions in (10) are taken from inequality (6). Therefore, $H(m(x) + t - 1)$ equals one when the data point is accepted for classification. Also, $1 - \delta(C(x), y)$ is a misclassification indicator: if the prediction and the true label match for a data point, then $1 - \delta(C(x), y)$ is 0, and 1 otherwise. A similar approach is taken by Shah and Manwani [12] when combining classification and rejection. Finally, we plug the optimal threshold (8) into (10) to obtain
$$L = \frac{1}{|T|} \sum_{(x,y) \in T} c_e \cdot H\big(m(x) + t^* - 1\big) \cdot \big(1 - \delta(C(x), y)\big) + \frac{1}{|T|} \sum_{(x,y) \in T} c_r \cdot H\big(-m(x) - t^* + 1\big) \tag{11}$$
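The empirical loss (11) with a hard Heaviside function can be computed over a batch as follows. This is a minimal NumPy sketch: `empirical_reject_loss` is our own name, and we take the rejection indicator to be the complement of the acceptance indicator, which agrees with $H(-m(x) - t^* + 1)$ everywhere except exactly on the border $m(x) = 1 - t^*$.

```python
import numpy as np


def empirical_reject_loss(P: np.ndarray, y: np.ndarray, c_r: float, c_e: float) -> float:
    """Empirical loss of Eq. (11) with a hard Heaviside function.

    P : class probabilities, shape (n, c); y : true labels, shape (n,).
    """
    t_star = c_r / c_e                     # Eq. (8)
    m = P.max(axis=1)                      # m(x)
    accepted = (m + t_star - 1.0) >= 0.0   # H(m(x) + t* - 1), using H(0) = 1
    wrong = P.argmax(axis=1) != y          # 1 - delta(C(x), y)
    error_term = c_e * np.mean(accepted & wrong)
    reject_term = c_r * np.mean(~accepted)  # complement of acceptance
    return float(error_term + reject_term)


P = np.array([[0.9, 0.1],     # accepted, correct
              [0.55, 0.45],   # rejected
              [0.15, 0.85],   # accepted, wrong
              [0.7, 0.3]])    # rejected
y = np.array([0, 0, 0, 1])
print(empirical_reject_loss(P, y, c_r=0.2, c_e=1.0))  # 1.0 * 1/4 + 0.2 * 2/4 = 0.35
```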
4 Reject Options for CbC Networks
We still need to modify (11) so that the loss function of CbC is involved. CbC uses the contrastive loss function
$$l(x, y) = \varphi\big(\Delta p(x, y)\big) \tag{12}$$
where
$$\Delta p(x, y) = p_s(x, y) - p_y(x) \tag{13}$$
and
$$p_s(x, y) = \max\{p_j(x) \mid j \neq y,\ j \in C\} \tag{14}$$
We would like to define the function $\varphi$ more precisely so that we can use it as the misclassification indicator $1 - \delta(C(x), y)$ that appears in (11). We choose the sigmoid function
$$\varphi(x) \equiv \varphi_\lambda(x) = \frac{1}{1 + \exp(-x/\lambda)} \tag{15}$$
Now, the term $1 - \delta(C(x), y)$ in (11) can be substituted by $\varphi_\lambda(\Delta p(x, y))$. Note that $\varphi_\lambda(\Delta p(x, y))$ goes to 0 for a confident correct classification and to 1 for a misclassified input. In short, the sigmoid function serves as a smooth indicator function as $\lambda$ goes to 0. Next, we approximate the Heaviside function appearing in (11) with another sigmoid function, so that we can derive the required gradients later:
$$H(x) \equiv H_\gamma(x) = \frac{1}{1 + \exp(-x/\gamma)} \tag{16}$$
The approximation of the Heaviside function by (16) improves as $\gamma$ goes to 0. The new loss function for CbC with reject options is obtained by integrating the modifications (15) and (16) into (11):
$$L = \frac{1}{|T|} \sum_{(x,y) \in T} c_e \cdot H_\gamma\big(m(x) + t^* - 1\big) \cdot \varphi_\lambda\big(\Delta p(x, y)\big) + \frac{1}{|T|} \sum_{(x,y) \in T} c_r \cdot H_\gamma\big(-m(x) - t^* + 1\big) \tag{17}$$
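It can be shown that
$$H_\gamma(-x) = 1 - H_\gamma(x) \tag{18}$$
Therefore, the second sum in (17) equals $c_r - \frac{c_r}{|T|} \sum_{(x,y) \in T} H_\gamma(m(x) + t^* - 1)$, and the loss function can be rewritten as
$$L = \frac{c_e}{|T|} \sum_{(x,y) \in T} H_\gamma\big(m(x) + t^* - 1\big) \cdot \Big(\varphi_\lambda\big(\Delta p(x, y)\big) - \frac{c_r}{c_e}\Big) + c_r \tag{19}$$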
Recalling (8), we substitute $t^*$ for $c_r/c_e$ in the above equation. Also, neither the additive constant $c_r$ that arises from (18) nor the positive multiplier $c_e/|T|$ changes the optimum of the function, so both are dropped. We thus obtain the final loss for CbC when reject options are taken into account:
$$L = \sum_{(x,y) \in T} H_\gamma\big(m(x) + t^* - 1\big) \cdot \big(\varphi_\lambda(\Delta p(x, y)) - t^*\big) \tag{20}$$
The loss function (20) contains the term $\varphi_\lambda(\Delta p(x, y)) - t^*$, which is the CbC loss (12) shifted by the constant $t^*$. However, the reject option now acts as a multiplier through the term $H_\gamma(m(x) + t^* - 1)$. The loss function (20) is used to train the CbC network with SGD. Since the optimization method of interest is stochastic gradient descent, we further define the local loss $\bar{L} = \bar{L}(x, y)$, associated with the training pair $(x, y)$, as a single term of the summation (20). In other words, $\bar{L}(x, y) = H_\gamma(m(x) + t^* - 1) \cdot (\varphi_\lambda(\Delta p(x, y)) - t^*)$. Assuming that the arguments of the local loss $\bar{L}(x, y)$ are known, we may drop them and simply denote the local loss as $\bar{L}$.
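The step from (17) to (20) relies on (18) and on dropping the constant $c_r$ and the factor $c_e/|T|$. The following small NumPy sketch makes this bookkeeping explicit. It checks, per sample, that the summand of (17) equals $c_e$ times the summand of (20) plus $c_r$. The function names and the chosen $\lambda$, $\gamma$ values are illustrative.

```python
import numpy as np


def sigmoid(z, temperature):
    """1 / (1 + exp(-z / temperature)): phi_lambda in Eq. (15), H_gamma in Eq. (16)."""
    return 1.0 / (1.0 + np.exp(-z / temperature))


def local_loss(m, delta_p, t_star, lam, gamma):
    """Single summand of Eq. (20): H_gamma(m + t* - 1) * (phi_lambda(delta_p) - t*)."""
    h = sigmoid(m + t_star - 1.0, gamma)
    phi = sigmoid(delta_p, lam)
    return float(h * (phi - t_star))


def eq17_summand(m, delta_p, c_r, c_e, lam, gamma):
    """Single summand of Eq. (17) (without the 1/|T| factor), before simplification."""
    t_star = c_r / c_e  # Eq. (8)
    accept = sigmoid(m + t_star - 1.0, gamma)
    reject = sigmoid(-m - t_star + 1.0, gamma)
    return float(c_e * accept * sigmoid(delta_p, lam) + c_r * reject)


# Per sample, Eq. (17) equals c_e * Eq. (20) + c_r, so both losses share their optimum.
c_r, c_e, lam, gamma = 0.5, 2.0, 0.01, 0.1
for m, delta_p in [(0.7, -0.4), (0.55, 0.1), (0.95, -0.9)]:
    lhs = eq17_summand(m, delta_p, c_r, c_e, lam, gamma)
    rhs = c_e * local_loss(m, delta_p, c_r / c_e, lam, gamma) + c_r
    print(f"m={m:.2f}  Eq.(17) term={lhs:.6f}  c_e*Eq.(20) term + c_r={rhs:.6f}")
```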
5 Derivation and Comparison of Gradients
In this section, we first state the gradients of the original CbC with respect to the reasoning parameters. Then, the gradients of CbC with reject option with respect to the same parameters are given. We would like to highlight the differences between the adaptation rules of the two methods. For CbC, we take $\varphi$ to be (15). Then, the gradient of the CbC loss function (12) with respect to a general adaptive parameter $r$ is
$$\frac{\partial l}{\partial r} = \Psi(x, y, \lambda) \cdot \frac{\partial \Delta p(x, y)}{\partial r} \tag{21}$$
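with
$$\Psi(x, y, \lambda) = \frac{\partial l}{\partial \Delta p(x, y)} = \frac{1}{\lambda} \cdot \varphi_\lambda\big(\Delta p(x, y)\big) \cdot \big(1 - \varphi_\lambda(\Delta p(x, y))\big) \tag{22}$$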
The partial derivative in (21) is
$$\frac{\partial \Delta p(x, y)}{\partial r} = \frac{\partial p_s(x, y)}{\partial r} - \frac{\partial p_y(x)}{\partial r} \tag{23}$$
Based on the above gradients, in each learning step with training pair $(x, y)$, a parameter is updated to raise the probability $p_y(x)$ and lower the probability $p_s(x, y)$. Now, we derive the gradients of the local loss of the main loss function (20) for CbC with reject option:
$$\frac{\partial \bar{L}}{\partial r} = H_\gamma(x, t^*) \cdot \Psi \tag{24}$$
with
$$\Psi = \frac{1}{\gamma} \cdot \big(1 - H_\gamma(x, t^*)\big) \cdot \big(\varphi_\lambda(x, y) - t^*\big) \cdot \frac{\partial m(x)}{\partial r} + \frac{\partial \varphi_\lambda(x, y)}{\partial r} \tag{25}$$
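Here, we have used the following shorthand notations in (24) and (25):
$$H_\gamma(x, t^*) = H_\gamma\big(m(x) + t^* - 1\big) \tag{26}$$
and
$$\varphi_\lambda(x, y) = \varphi_\lambda\big(\Delta p(x, y)\big) \tag{27}$$
Note that the term $\frac{\partial \varphi_\lambda(x, y)}{\partial r} = \frac{\partial l}{\partial r}$ appearing in (25) is the original CbC gradient (21). In (24), if $H_\gamma(x, t^*) \approx 1$, meaning that the data point $x$ is accepted for classification, then the factor $1 - H_\gamma(x, t^*)$ in (25) is close to 0 and the gradient (24) becomes
$$\frac{\partial \bar{L}}{\partial r} \approx \frac{\partial \varphi_\lambda(x, y)}{\partial r} = \frac{\partial l}{\partial r} \tag{28}$$
which is simply the gradient of the original CbC given in (21).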
If $H_\gamma(x, t^*) \approx 0$, then the gradient (24) vanishes. This means that we prefer to keep uncertain data points in the rejection area to keep the overall risk low. However, if $0 < H_\gamma(x, t^*) < 1$ with $H_\gamma(x, t^*)$ away from both limits, which happens when $x$ is close to the rejection border (as defined by Chow [2]), then the term $\varphi_\lambda(x, y) - t^*$ in (25) plays an important role. This term balances the error rate and the rejection rate, since the sign of the coefficient of $\partial m(x)/\partial r$ in (25) depends on it. Suppose that $x$ is correctly classified, which means $m(x) = p_y(x)$. Then $\partial m(x)/\partial r = \partial p_y(x)/\partial r$, and the gradient (24) becomes
$$\frac{\partial \bar{L}}{\partial r} = H_\gamma(x, t^*) \cdot \Big(\beta \cdot \frac{\partial p_y(x)}{\partial r} + \frac{\partial \varphi_\lambda(x, y)}{\partial r}\Big) \tag{29}$$
where $\beta = \frac{1}{\gamma} \cdot \big(1 - H_\gamma(x, t^*)\big) \cdot \big(\varphi_\lambda(x, y) - t^*\big)$.

Since the term $\partial p_y(x)/\partial r$ appears in the derivative $\partial \varphi_\lambda(x, y)/\partial r$ with a negative coefficient (see (21)–(23)), the magnitude of the coefficient of $\partial p_y(x)/\partial r$ in (29) becomes even larger if $\varphi_\lambda(x, y) - t^* < 0$. This happens when $t^*$ is relatively large, meaning that the cost of misclassification is low or comparable to the cost of rejection. In this case, the algorithm is less cautious and alters the parameters at larger rates. On the other hand, if $t^*$ is relatively small (the cost of misclassification is higher than the cost of rejection), then potentially $\varphi_\lambda(x, y) - t^* > 0$ and, consequently, the coefficient of $\partial p_y(x)/\partial r$ in (29) becomes small, since we fear misclassifying a correctly classified input.

This section can be summarized as follows for the gradients of CbC with reject options. In the region where data is accepted for classification, the method roughly applies the original CbC rules to data points. If a data point is firmly rejected, it stays in the rejection region with high probability. Close to the border between rejection and acceptance, the method carefully searches for the best performance based on the term $\varphi_\lambda(x, y) - t^*$, which expresses the compromise between error and rejection.
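The gradient expressions (24)–(25) can be sanity-checked numerically. The sketch below uses a deliberately tiny, hypothetical model with a single scalar parameter $r$: two classes, $p_y = \mathrm{sigmoid}(r)$ and $p_s = 1 - p_y$. This is not the actual CbC parameterization. The sketch compares the analytic gradient assembled from (21)–(27) with a central finite difference. $\lambda$ and $\gamma$ are set to 0.1 only to keep the check numerically well conditioned.

```python
import numpy as np


def sig(z, temperature=1.0):
    """Logistic sigmoid with temperature, as in Eqs. (15) and (16)."""
    return 1.0 / (1.0 + np.exp(-z / temperature))


def toy_probs(r):
    """Hypothetical two-class model with one scalar parameter r."""
    p_y = sig(r)            # probability of the true class
    return p_y, 1.0 - p_y   # (p_y, p_s)


def toy_local_loss(r, t_star, lam, gamma):
    """Local loss of Eq. (20) for the toy model."""
    p_y, p_s = toy_probs(r)
    m = max(p_y, p_s)                    # m(x)
    h = sig(m + t_star - 1.0, gamma)     # H_gamma(x, t*), Eq. (26)
    phi = sig(p_s - p_y, lam)            # phi_lambda(x, y), Eq. (27)
    return h * (phi - t_star)


def analytic_grad(r, t_star, lam, gamma):
    """dL_bar/dr assembled from Eqs. (21)-(25)."""
    p_y, p_s = toy_probs(r)
    dpy = p_y * (1.0 - p_y)              # d p_y / d r
    dps = -dpy                           # d p_s / d r
    m, dm = (p_y, dpy) if p_y >= p_s else (p_s, dps)
    h = sig(m + t_star - 1.0, gamma)
    phi = sig(p_s - p_y, lam)
    dphi = phi * (1.0 - phi) / lam * (dps - dpy)             # Eqs. (21)-(23)
    psi = (1.0 - h) * (phi - t_star) * dm / gamma + dphi     # Eq. (25)
    return h * psi                                           # Eq. (24)


eps = 1e-6
for r in (0.8, 1.2, 2.0):
    fd = (toy_local_loss(r + eps, 0.3, 0.1, 0.1) - toy_local_loss(r - eps, 0.3, 0.1, 0.1)) / (2 * eps)
    print(f"r={r}: analytic={analytic_grad(r, 0.3, 0.1, 0.1):.8f}, finite difference={fd:.8f}")
```

For $r > 0$ the toy input is correctly classified, so $m(x) = p_y(x)$ and the check exercises exactly the case discussed around (29).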
6 Experiments and Simulations
For the experiments, the "Two Moons" data set (TMDS) serves to visualize the difference between the components of the original CbC and those of CbC with reject options. Further, the MNIST data set, which was used for the original CbC [1], is considered for comparison. In particular, a sample of TMDS with added Gaussian noise of standard deviation 0.1 is considered. We assume 4 components $w_t$, $t \in \{1, \dots, 4\}$, in the 2D data space, initialized as random data points from the training set. The detection function for component $t$ is defined as the Gaussian $d_t(x) = \exp(-2 \cdot \|x - w_t\|^2)$. The parameter $\lambda$ in (15) was set to 0.01. In all experiments of this paper, the data is split into training and test data, and a 6-fold cross-validation with $3 \times 10^3$ learning steps is used. After applying the original CbC, we report an accuracy of (95 ± 0.05)% on the test data. The learnt components, as well as the corresponding reasoning parameters, are depicted in the first column of Fig. 1.
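The Gaussian detection function and the random initialization of the four components can be sketched as follows. This is a minimal NumPy sketch. Standard-normal points stand in for the noisy Two Moons sample, which could instead be generated with, for example, scikit-learn's `make_moons(noise=0.1)`. The function names are our own.

```python
import numpy as np


def gaussian_detection(X: np.ndarray, W: np.ndarray) -> np.ndarray:
    """d_t(x) = exp(-2 * ||x - w_t||^2) for every point x and component w_t.

    X: data points, shape (n, 2); W: components, shape (k, 2).
    Returns an (n, k) array with values in [0, 1].
    """
    sq_dist = ((X[:, None, :] - W[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-2.0 * sq_dist)


def init_components(X_train: np.ndarray, k: int, rng: np.random.Generator) -> np.ndarray:
    """Initialise k components as distinct, randomly chosen training points."""
    idx = rng.choice(len(X_train), size=k, replace=False)
    return X_train[idx].copy()


rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 2))  # stand-in for a noisy Two Moons sample
W = init_components(X_train, k=4, rng=rng)
D = gaussian_detection(X_train, W)
print(D.shape)  # (200, 4)
```

Because each component starts on a training point, it is detected with probability 1 at that point. Detection decays with squared distance, at a rate fixed by the constant 2 in the exponent.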
Fig. 1. Visualization of Two Moons. The first, second, and third columns are for normal CbC, CbC with reject option ($t^* = 0.49$), and CbC with reject option ($t^* = 0.495$), respectively. In each column, the top image shows the locations of the learnt components in red. Black points represent misclassified data points; blue points represent rejected samples. The second and third images from the top depict the reasoning parameters of the Purple and Yellow classes, respectively. The last row shows a gray-scale map for reading values. (Color figure online)
Then, we applied CbC with reject options to the same problem. The initialization was the same as for the original CbC. Theoretically, $\gamma$ is required to be close to 0. In practice, however, a small $\gamma$ makes the factor $1/\gamma$, which appears in the coefficient of the gradient (24), large. This reduces the effect of the gradient $\partial \varphi_\lambda(x, y)/\partial r$ in (24), which is responsible for making the predictions accurate. Therefore, we need to reduce $\gamma$ slowly during learning. The simulations showed that, for both the TMDS and MNIST problems, $\gamma = 1/l$ ($l$ is the learning step) gives a suitable starting value, and the final value is desirably small. We report a higher accuracy of (99.9 ± 0.01)% and a rejection rate of (23 ± 0.02)% when $t^* = c_r/c_e = 0.49$. The learnt components, as well as the corresponding reasoning parameters, are depicted in the second column of Fig. 1. A careful balance between the reject and misclassification costs is needed. Therefore, we increased the reject cost to obtain $t^* = \hat{c}_r/\hat{c}_e = 0.495$. The accuracy is (97 ± 0.02)%, while the rejection rate is (21 ± 0.04)%; see the third column of Fig. 1.

The original CbC was applied to MNIST [1] without feature extraction and achieved 89.5% accuracy. We chose $\lambda = 0.01$ in (15) and $\gamma = 1/l$ for $H_\gamma$ in (16). Ten prototypes are initialized at the center of each class in the feature space. The data points, as well as the prototypes, are kept normalized at all times. The detection function is $d_t(x) = \max(\langle x, w_t \rangle + b, 0)$, where $\langle \cdot, \cdot \rangle$ denotes the inner product and $b$ is a margin parameter. We chose $b = 0.3$ by trial and error. The error and rejection rates for $t^* = c_r/c_e = 0.9$ are (23 ± 0.1)% and (0.1 ± 0.11)%, respectively. For $t^* = \hat{c}_r/\hat{c}_e = 0.89$, we have a careful balance between the error and rejection costs: the error and rejection rates are (19 ± 0.09)% and (36 ± 0.1)%, respectively. For the case $t^* = 0.89$, we have depicted the learnt components and the reasoning parameters in Fig. 2. Then, we reduced the number of prototypes to 9 and initialized them at the centers of the first 9 classes (classes 0 to 8). The rest of the setting is the same as in the case with 10 prototypes. The error and rejection rates for $t^* = c_r/c_e = 0.9$ are (31 ± 0.1)% and (0.1 ± 0.1)%, respectively. For $t^* = \hat{c}_r/\hat{c}_e = 0.88$, the error and rejection rates changed considerably to (0.2 ± 0.1)% and (69 ± 0.09)%, respectively. This relatively high rejection rate suggests that using feature extractors [1] for the components might be necessary when there are 9 or fewer components.
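The MNIST setup described above (normalized data, one prototype per class initialized at the class center, the rectified inner-product detection with margin $b = 0.3$, and the schedule $\gamma = 1/l$) can be sketched in NumPy. Random vectors stand in for MNIST images, and the function names are hypothetical. As written, the detection values of unit-norm inputs lie in $[0, 1 + b]$.

```python
import numpy as np


def normalise(v: np.ndarray) -> np.ndarray:
    """Scale each row to unit Euclidean norm."""
    return v / np.linalg.norm(v, axis=-1, keepdims=True)


def class_mean_prototypes(X: np.ndarray, y: np.ndarray, classes) -> np.ndarray:
    """One prototype per listed class, initialised at the normalised class centre."""
    return normalise(np.stack([X[y == c].mean(axis=0) for c in classes]))


def relu_detection(X: np.ndarray, W: np.ndarray, b: float = 0.3) -> np.ndarray:
    """d_t(x) = max(<x, w_t> + b, 0) with unit-norm x and w_t; values lie in [0, 1 + b]."""
    return np.maximum(normalise(X) @ W.T + b, 0.0)


def gamma_schedule(step: int) -> float:
    """gamma = 1 / l for learning step l = 1, 2, ..."""
    return 1.0 / step


rng = np.random.default_rng(0)
X = rng.random((100, 784))         # stand-in for flattened 28 x 28 MNIST images
y = np.arange(100) % 10            # stand-in labels, every class present
W10 = class_mean_prototypes(X, y, classes=range(10))  # 10 prototypes
W9 = class_mean_prototypes(X, y, classes=range(9))    # classes 0..8 only
D = relu_detection(X, W10)
print(D.shape, gamma_schedule(1), gamma_schedule(3000))
```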
Fig. 2. Visualization of CbC with reject options ($t^* = \hat{c}_r/\hat{c}_e = 0.89$) and 10 components, applied to MNIST. The top two rows show the components, denoted by the letters {a, b, c, d, e, f, g, h, i, j}. The bottom two rows show the reasoning parameters for each class {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}. Each class is denoted by its digit.
7 Conclusion
In this paper, we integrated reject options into CbC networks. The new method keeps Chow's optimal threshold but makes the system aware of reject options during training for better adaptation. We have improved the performance of CbC in scenarios where a false classification is costlier than a rejection. The method was examined using a toy data set and the MNIST data set.
A Proof of Optimality of Chow’s Threshold
Here, we prove that $t = c_r/c_e$ is the optimal solution for the function
$$L(x, y, t) = c_e \cdot E(t) + c_r \cdot R(t)$$
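The derivative of the above function with respect to $t$ is
$$\frac{\partial}{\partial t} L(x, y, t) = c_e \cdot \frac{\partial E(t)}{\partial t} + c_r \cdot \frac{\partial R(t)}{\partial t}$$
or
$$\frac{\partial}{\partial t} L(x, y, t) = c_e \cdot \frac{\partial E(t)}{\partial R(t)} \cdot \frac{\partial R(t)}{\partial t} + c_r \cdot \frac{\partial R(t)}{\partial t} \tag{30}$$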
Section IV of [2] contains a relation between $E(t)$ and $R(t)$ for a Bayes-optimal decision system, in the form of a Riemann–Stieltjes integral:
$$E(t) = -\int_{0}^{t} \tau \, dR(\tau)$$
and its differential form is
$$\frac{\partial E(t)}{\partial R(t)} = -t$$
Using this result in (30), we obtain
$$\frac{\partial}{\partial t} L(x, y, t) = -c_e \cdot t \cdot \frac{\partial R(t)}{\partial t} + c_r \cdot \frac{\partial R(t)}{\partial t}$$
Setting this derivative to 0 gives
$$\frac{\partial R(t)}{\partial t} \cdot (-c_e \cdot t + c_r) = 0$$
Assuming $\partial R(t)/\partial t \neq 0$, solving for $t$ yields $t^* = c_r/c_e$.
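The result can also be checked numerically. The following sketch assumes a hypothetical calibrated two-class classifier whose maximum posterior $m(x)$ is uniform on $[0.5, 1]$. Under this assumption, an accepted point contributes an expected error of $1 - m(x)$. Scanning thresholds $t$ shows that no threshold on the grid achieves a lower value of $c_e \cdot E(t) + c_r \cdot R(t)$ than $t^* = c_r/c_e$. The function name and the cost values are illustrative.

```python
import numpy as np


def expected_loss(m: np.ndarray, t: float, c_r: float, c_e: float) -> float:
    """c_e * E(t) + c_r * R(t) for a calibrated (Bayes-optimal) classifier.

    m holds the maximum posterior m(x) of each sample. An accepted sample
    (m(x) >= 1 - t) contributes an expected error of 1 - m(x); a rejected
    sample contributes to the reject rate.
    """
    accepted = m >= 1.0 - t
    error_rate = np.mean(np.where(accepted, 1.0 - m, 0.0))
    reject_rate = np.mean(~accepted)
    return float(c_e * error_rate + c_r * reject_rate)


rng = np.random.default_rng(0)
m = rng.uniform(0.5, 1.0, size=10_000)  # hypothetical two-class posteriors
c_r, c_e = 0.2, 1.0
t_star = c_r / c_e                       # Chow's threshold, Eq. (8)
grid = np.linspace(0.0, 0.5, 501)
losses = [expected_loss(m, t, c_r, c_e) for t in grid]
print(f"loss at t* = {expected_loss(m, t_star, c_r, c_e):.5f}, best on grid = {min(losses):.5f}")
```

The check works sample by sample: at $t^*$, a point is accepted exactly when its expected error cost $c_e (1 - m(x))$ does not exceed the rejection cost $c_r$.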
References
1. Saralajew, S., Holdijk, L., Rees, M., Asan, E., Villmann, T.: Classification-by-components: probabilistic modeling of reasoning over a set of components. Adv. Neural Inf. Process. Syst. 32 (2019)
2. Chow, C.K.: On optimum recognition error and reject tradeoff. IEEE Trans. Inf. Theor. 16(1), 41–46 (1970)
3. Musavishavazi, S., Bakhtiari, M., Villmann, T.: A mathematical model for optimum error-reject trade-off for learning of secure classification models in the presence of label noise during training. In: International Conference on Artificial Intelligence and Soft Computing, pp. 547–554 (2020)
4. Geifman, Y., El-Yaniv, R.: Selective classification for deep neural networks. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
5. Kalra, B., Shah, K., Manwani, N.: RISAN: robust instance specific deep abstention network. In: Uncertainty in Artificial Intelligence, pp. 1525–1534 (2021)
6. Chow, C.K.: An optimum character recognition system using decision functions. IRE Trans. Electron. Comput. 4, 247–254 (1957)
7. Villmann, T., et al.: Self-adjusting reject options in prototype based classification. In: Advances in Self-Organizing Maps and Learning Vector Quantization, pp. 269–279 (2016)
8. Fischer, L., Nebel, D., Villmann, T., Hammer, B., Wersing, H.: Rejection strategies for learning vector quantization-a comparison of probabilistic and deterministic approaches, pp. 109–118 (2014)
9. Villmann, T., Kaden, M., Nebel, D., Biehl, M.: Learning vector quantization with adaptive cost-based outlier-rejection. In: International Conference on Computer Analysis of Images and Patterns, pp. 772–782 (2015)
10. Biederman, I.: Recognition-by-components: a theory of human image understanding. Psychol. Rev. 94(2), 115 (1987)
11. Jogin, M., Madhulika, M., Divya, G., Meghana, R., Apoorva, S., et al.: Feature extraction using convolution neural networks (CNN) and deep learning. In: 2018 3rd IEEE International Conference on Recent Trends in Electronics, Information & Communication Technology (RTEICT), pp. 2319–2323 (2018)
12. Shah, K., Manwani, N.: Sparse reject option classifier using successive linear programming. Proc. AAAI Conf. Artif. Intell. 33(1), 4870–4877 (2019)
Community Discovery Algorithm Based on Improved Deep Sparse Autoencoder
Dianying Chen(1), Xuesong Jiang(1)(B), Jun Chen(1), and Xiumei Wei(2)(B)
1 Qilu University of Technology (Shandong Academy of Sciences), Jinan, Shandong, China [email protected]
2 Qilu University of Technology, Jinan, Shandong, China [email protected]
Abstract. Community structures are ubiquitous, from simple networks to real-world complex networks. Community structure is an important feature of complex networks, and community discovery is of considerable practical value for studying the structure of social networks. When classical clustering algorithms are applied to high-dimensional matrices, the resulting communities are often inaccurate. In this paper, we propose a community discovery algorithm based on an improved deep sparse autoencoder that addresses the community discovery problem with two different network similarity representations. This compensates for the fact that a single network similarity matrix cannot fully describe the similarity relationships between nodes; together, the two representations capture the local information between nodes in the network topology more completely. We then construct a weight-bound (tied-weight) deep sparse autoencoder, trained in an unsupervised manner, to improve the efficiency of feature extraction. Finally, features are extracted from the similarity matrices to obtain low-dimensional feature matrices, which are clustered with k-means to obtain reliable clustering results. In extensive experiments on multiple real networks, the proposed method is more accurate than community discovery algorithms that cluster a single similarity matrix, and it is considerably more efficient.
Keywords: complex network · node similarity · community discovery · deep learning · deep sparse autoencoders
1 Introduction
The wide real-world application of complex networks and their strong performance as models have attracted many researchers to the study of complex networks. In recent years, community detection has become a hot topic
This work was supported in part by National Key R&D Program of China (No. 2019YFB1707000).
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 597–609, 2023.
https://doi.org/10.1007/978-981-99-1639-9_50
D. Chen et al.
in complex network research; it aims to discover the underlying network structure and important information [1], so community discovery is of considerable research value and scientific significance for the analysis of network structure. Real-world networks have complex structures, and it is difficult to obtain accurate community structures with traditional clustering methods. Fully mining the information in the network and constructing the similarity relationships between network nodes is the key to improving the accuracy of community discovery algorithms. The similarity relationships between nodes are usually described by a similarity matrix, obtained by transforming the adjacency matrix of the network with a single function. In recent years, using deep learning for community discovery in large-scale complex networks has become a hot topic. Although researchers have proposed many deep learning-based models, these models often have high complexity and, because of their large number of parameters, low training efficiency. Wang et al. [2] proposed learning a representation of the similarity matrix with a multi-layer autoencoder and an extreme learning machine, which improved accuracy and stability at the cost of long training times. The algorithm proposed by Jia et al. [3] uses an adversarial network to optimize the strength of node–community membership: the generator and the discriminator compete with each other, and alternately iterating the two improves accuracy, but the large number of parameters limits the method's generality. Li et al. [4] proposed an algorithm that learns representations of network edges and uses an edge clustering algorithm to convert them into an overlapping community partition of the nodes, which improves accuracy but is not very stable. The algorithm proposed by Shang et al. [5] uses a multi-layer sparse autoencoder to reduce the dimensionality of similarity matrices and learn representations, and then applies k-means clustering to improve accuracy; however, its parameters are hard to choose and its generality is relatively low. Zhang et al. [6] used multi-layer spectral clustering to partition the network into communities. This algorithm is more accurate than its single-layer counterpart, but the number of layers is an unstable parameter.
All of the above algorithms take a similarity matrix based on a single function as input. First, such a matrix cannot fully reflect the local information of each node: besides directly connected nodes, a network also contains indirectly connected nodes, and these pairs also have different similarity relationships. Second, training deep learning models is inefficient because of the large amount of experimental data and the excessive number of parameters. To address these problems, we propose CoIDSA (Community Discovery Algorithm Based on Improved Deep Sparse Autoencoder). The three main contributions of this paper are as follows.
– We construct new similarity matrices with two different functions, which more fully exploits the similarity relationships between nodes in the network topology.
Community Discovery Algorithm
– We develop a learning method for a deep sparse autoencoder based on weight binding (tied weights), which extracts feature representations from the similarity matrices and yields low-dimensional feature matrices. Binding the weights halves the number of weight parameters, which speeds up training and reduces the risk of overfitting.
– Extensive experiments on multiple real datasets show that the proposed CoIDSA algorithm obtains more accurate network community structures.
2 Related Work
2.1 Community Discovery
Given a network, the subgraph induced by a subset of closely connected nodes is called a community, and the process of finding the community structure is called community discovery. Deep learning-based algorithms are currently the mainstream approach to community discovery, and autoencoders are very common in this setting because they can efficiently represent nonlinear real-world networks. In [7], Cao et al. proposed a method that combines a modularity model and a normalized-cut model via an autoencoder. In [8], the DeCom model uses autoencoders to extract overlapping communities in large-scale networks. Unlike other autoencoder-based schemes, the ComE community embedding framework of Cavallari et al. [9] represents each community as a multivariate Gaussian distribution in the embedding space to enhance community discovery. Xu et al. [10] proposed the CDMEC community discovery approach, a framework that combines transfer learning and stacked autoencoders to generate low-dimensional feature representations of complex networks.
2.2 Deep Sparse Autoencoder
An autoencoder is a neural network trained with backpropagation to make its output equal to its input. It first compresses the input into a latent-space representation and then uses this representation to reconstruct the output. An autoencoder consists of two parts, an encoder and a decoder, and learns efficient representations of the input data through unsupervised learning. Such a representation is called an encoding; its dimension is generally much smaller than that of the input, which makes autoencoders useful for dimensionality reduction. A sparse autoencoder [11] is a variant obtained by adding a sparsity constraint to the hidden-layer neurons of an autoencoder; under this constraint it can still learn the features that best express the samples and effectively reduce their dimensionality. The model is trained by computing the error between the output of the autoencoder and the original input and iteratively adjusting the autoencoder's parameters to reduce it.
3 Method
In this paper, we propose CoIDSA, a community discovery algorithm based on an improved deep sparse autoencoder. It consists of three main steps. First, the adjacency matrix is preprocessed with two different functions to obtain two similarity matrices that enhance node similarity. Second, a weight-bound deep sparse autoencoder is constructed to extract features from each similarity matrix, yielding low-dimensional feature matrices with distinct features while improving the efficiency of model training. Finally, the two low-dimensional feature matrices are combined into a new matrix, which is clustered with k-means to obtain the community structure. The algorithm model diagram is shown in Fig. 1.
(Figure: network → similarity matrices X and B → deep sparse autoencoder with shared parameters, layer widths n > d1 > d2 > d3 → low-dimensional features → clustering algorithm (k-means) → communities)
Fig. 1. Algorithm model diagram
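The final step of this pipeline, combining the two low-dimensional feature matrices and clustering the result with k-means, can be sketched in plain Python. This is an illustrative sketch, not the authors' implementation: it assumes that "combining" means row-wise concatenation (node i keeps its features from X followed by its features from B), and the helper names `concat_features` and `kmeans` as well as the deterministic farthest-point initialization are our own choices.

```python
import math


def concat_features(fx, fb):
    """Row-wise concatenation: node i gets its X-features followed by its B-features."""
    return [row_x + row_b for row_x, row_b in zip(fx, fb)]


def kmeans(points, k, max_iter=100):
    """Plain Lloyd's k-means; returns one cluster label per row of `points`."""
    # Farthest-point seeding from the first row (deterministic; an assumption, not from the paper)
    centers = [list(points[0])]
    while len(centers) < k:
        farthest = max(points, key=lambda p: min(math.dist(p, c) for c in centers))
        centers.append(list(farthest))

    labels = None
    for _ in range(max_iter):
        # Assignment step: each node goes to its nearest center
        new_labels = [min(range(k), key=lambda c: math.dist(p, centers[c])) for p in points]
        if new_labels == labels:
            break  # converged: assignments did not change
        labels = new_labels
        # Update step: each center moves to the mean of its members
        for c in range(k):
            members = [p for p, lab in zip(points, labels) if lab == c]
            if members:
                centers[c] = [sum(col) / len(members) for col in zip(*members)]
    return labels


fx = [[0.0], [0.1], [5.0], [5.1]]  # toy low-dimensional features from X
fb = [[0.0], [0.0], [5.0], [5.0]]  # toy low-dimensional features from B
labels = kmeans(concat_features(fx, fb), k=2)
print(labels)  # [0, 0, 1, 1]
```

Deterministic seeding keeps the toy run reproducible; a practical implementation would more likely use k-means++ seeding with several restarts.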
3.1 Enhance Node Similarity
We first preprocess the adjacency matrix with an s-hop function to obtain the first similarity matrix. This addresses a limitation of the adjacency matrix: similarity between nodes that are not directly connected cannot be represented, so much of the similarity information between nodes is lost and the local information of each node is not fully reflected. Consider a network graph G = (V, E). For nodes v, u ∈ V, if the shortest path from v to u has length s, we say that v can reach u in s hops.
Node Similarity: Given a network graph G = (V, E), for nodes v, u ∈ V, the similarity sim(v, u) between v and u is defined as
$$\mathrm{sim}(v, u) = e^{\sigma(1-s)} \tag{1}$$
where s ≥ 1. As the number of hops s increases, the similarity between the nodes decreases monotonically. σ ∈ (0, 1) is called the attenuation factor and controls how quickly the similarity decays: the larger σ is, the faster the similarity between nodes decays.
Network similarity matrix: Given a network graph G = (V, E), let $X = [x_{ij}]_{n \times n}$ be a matrix associated with G. If each entry is computed with Eq. (1) as $x_{ij} = \mathrm{sim}(v_i, v_j)$ for the corresponding nodes $v_i, v_j \in V$, then X is called the similarity matrix of G. When the number of hops exceeds a certain threshold, two nodes that are not in the same community also receive a noticeable similarity value, which blurs the boundaries of the community structure. Therefore, a hop threshold s is set, and similarity is computed only between nodes that can reach each other within s hops; this enhances the topological information of the graph without blurring community boundaries.
Inspired by the modularity function Q, defined as the difference between the number of edges within communities and the expected number of such edges in a random network with the same node degrees, we preprocess the adjacency matrix with the modularity function and use the modularity matrix B as the second similarity matrix:
$$Q = \frac{1}{4m}\sum_{i,j}\left(A(i,j) - \frac{k_i k_j}{2m}\right)h_i h_j = \frac{1}{4m}h^{T}Bh \tag{2}$$
where $k_i k_j/2m$ is the expected number of edges between nodes i and j, h is the community membership vector ($h_i = \pm 1$ for a split into two communities), $k_i$ is the degree of node i, A is the adjacency matrix, m is the total number of edges in the network, and B is the modularity matrix with entries $B_{ij} = A(i,j) - k_i k_j/2m$.
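Both similarity matrices can be computed directly from an adjacency list. The sketch below is a minimal, hypothetical illustration of Eq. (1) with a hop threshold, and of the modularity matrix $B_{ij} = A(i,j) - k_i k_j/2m$ from Eq. (2); the function names are ours, and the defaults σ = 0.5 and a threshold of 3 hops mirror the values reported in the parameter experiments. Two choices are assumptions because the text does not specify them: pairs beyond the hop threshold get similarity 0, and diagonal entries are set to 1.

```python
import math
from collections import deque


def hop_distances(adj, source, max_hops):
    """BFS from `source`; returns {node: hops} for nodes within `max_hops`."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        if dist[u] == max_hops:
            continue  # do not expand beyond the hop threshold
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def hop_similarity_matrix(adj, sigma=0.5, s_max=3):
    """First similarity matrix X via Eq. (1): x_ij = exp(sigma * (1 - s))."""
    n = len(adj)
    X = [[0.0] * n for _ in range(n)]  # pairs farther than s_max hops stay 0 (assumption)
    for i in range(n):
        for j, s in hop_distances(adj, i, s_max).items():
            # s == 0 only on the diagonal; setting it to 1 is an assumption
            X[i][j] = 1.0 if s == 0 else math.exp(sigma * (1 - s))
    return X


def modularity_matrix(adj):
    """Second similarity matrix B with B_ij = A_ij - k_i * k_j / (2m), as in Eq. (2)."""
    n = len(adj)
    k = [len(adj[i]) for i in range(n)]  # node degrees
    two_m = sum(k)                       # 2m = sum of all degrees
    return [[(1.0 if j in adj[i] else 0.0) - k[i] * k[j] / two_m
             for j in range(n)] for i in range(n)]


# Toy path graph 0-1-2-3 stored as an adjacency list of sets
adj = [{1}, {0, 2}, {1, 3}, {2}]
X = hop_similarity_matrix(adj, sigma=0.5, s_max=2)
B = modularity_matrix(adj)
print([round(v, 3) for v in X[0]])  # [1.0, 1.0, 0.607, 0.0]
```

A breadth-first search bounded at `s_max` hops yields the shortest-path length s for every pair within the threshold, which is all Eq. (1) needs; every row of B sums to zero, a standard property of the modularity matrix.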
3.2 Feature Extraction
This section describes the feature extraction process of the CoIDSA algorithm: we construct a deep sparse autoencoder based on weight binding, extract features from the preprocessed similarity matrices X and B, merge the two resulting low-dimensional feature matrices, and finally obtain the community structure by clustering. Taking the similarity matrix $X = [x_{ij}]_{n \times n}$ of the network graph G as an example, X is used as the input matrix of the autoencoder, and the encoder maps the input data to hidden-layer features. The hidden-layer feature $\xi_i \in \mathbb{R}^d$ is obtained by Eq. (3):
$$\xi_i = s_f(W x_i + b) \tag{3}$$
where $s_f$ is a nonlinear activation function such as the sigmoid function $\mathrm{sigmoid}(x) = 1/(1 + \exp(-x))$, $W \in \mathbb{R}^{d \times n}$ is the weight matrix, and $b \in \mathbb{R}^{d \times 1}$ is the bias vector of the encoding layer. The decoder reconstructs the input from the hidden-layer feature $\xi_i$, and the decoding result $\hat{x}_i \in \mathbb{R}^{n \times 1}$ is obtained by Eq. (4) as the output:
$$\hat{x}_i = s_g(\hat{W}\xi_i + \hat{b}) \tag{4}$$
where $s_g$ is the activation function of the decoder, $\hat{W} = W^{T} \in \mathbb{R}^{n \times d}$ is the decoder weight matrix, and $\hat{b} \in \mathbb{R}^{n \times 1}$ is the bias vector of the decoding layer. During training, an autoencoder normally adjusts four parameters, $\delta = \{W, \hat{W}, b, \hat{b}\}$, namely the weight matrices and the bias vectors. To improve training efficiency, we instead bind the weights: the encoder and the decoder share weights, the decoder weight matrix is the transpose of the encoder weight matrix ($\hat{W} = W^{T}$), and this shared weight matrix is updated during backpropagation. We then minimize the reconstruction error between $x_i$ and $\hat{x}_i$:
$$\min_{W,\hat{W},b,\hat{b}} \sum_{i=1}^{n}\left\| s_g\!\left(\hat{W}s_f(Wx_i + b) + \hat{b}\right) - x_i\right\|_2^2 \tag{5}$$
We use the KL divergence to add a sparsity constraint to the autoencoder:
$$\sum_{j=1}^{d}\mathrm{KL}\!\left(\rho \,\middle\|\, \frac{1}{n}\sum_{i=1}^{n}\xi_{ij}\right) \tag{6}$$
where $\xi_{ij}$ denotes the jth component of $\xi_i$.
The loss function of the resulting sparse autoencoder is then
$$L(\delta) = \sum_{i=1}^{n}\left\| s_g\!\left(\hat{W}s_f(Wx_i + b) + \hat{b}\right) - x_i\right\|_2^2 + \alpha\sum_{j=1}^{d}\mathrm{KL}\!\left(\rho \,\middle\|\, \frac{1}{n}\sum_{i=1}^{n}\left[s_f(Wx_i + b)\right]_j\right) \tag{7}$$
where α is the weight coefficient of the sparsity penalty term and ρ is the sparsity parameter, usually a small value close to 0 (ρ = 0.01).
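To make Eqs. (3) to (7) concrete, the following pure-Python sketch implements a single tied-weight layer: the decoder reuses $W^T$ instead of a separate $\hat{W}$, and the loss adds the KL sparsity penalty to the squared reconstruction error. It is a sketch under stated assumptions rather than the authors' code: both $s_f$ and $s_g$ are taken to be the sigmoid, the KL term uses the usual Bernoulli form $\rho\log(\rho/\hat\rho_j) + (1-\rho)\log((1-\rho)/(1-\hat\rho_j))$ with $\hat\rho_j$ the mean activation of hidden unit j, the default α = 1.0 is a placeholder, and all class and function names are hypothetical. Training itself (backpropagation, layer-by-layer stacking) is omitted.

```python
import math
import random


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def kl_bernoulli(rho, rho_hat):
    """KL(rho || rho_hat) between Bernoulli distributions (the usual sparse-AE penalty)."""
    return rho * math.log(rho / rho_hat) + (1 - rho) * math.log((1 - rho) / (1 - rho_hat))


class TiedSparseAutoencoder:
    """One layer whose decoder weight is W^T (tied weights); s_f = s_g = sigmoid (assumed)."""

    def __init__(self, n_in, n_hidden, seed=0):
        rng = random.Random(seed)
        # W is d x n; there is no separate n x d decoder matrix
        self.W = [[rng.uniform(-0.1, 0.1) for _ in range(n_in)] for _ in range(n_hidden)]
        self.b = [0.0] * n_hidden  # encoder bias b (d entries)
        self.b_hat = [0.0] * n_in  # decoder bias b_hat (n entries)

    def encode(self, x):
        # Eq. (3): xi = s_f(W x + b)
        return [sigmoid(sum(w * v for w, v in zip(row, x)) + bj)
                for row, bj in zip(self.W, self.b)]

    def decode(self, xi):
        # Eq. (4) with W_hat = W^T: x_hat_k = s_g(sum_j W[j][k] * xi_j + b_hat_k)
        return [sigmoid(sum(self.W[j][k] * xi[j] for j in range(len(xi))) + self.b_hat[k])
                for k in range(len(self.b_hat))]

    def loss(self, X, rho=0.01, alpha=1.0):
        # Eq. (7): squared reconstruction error + alpha * sum_j KL(rho || rho_hat_j)
        hidden = [self.encode(x) for x in X]
        recon = sum(sum((xh - xv) ** 2 for xh, xv in zip(self.decode(h), x))
                    for h, x in zip(hidden, X))
        rho_hat = [sum(h[j] for h in hidden) / len(X) for j in range(len(self.b))]
        return recon + alpha * sum(kl_bernoulli(rho, r) for r in rho_hat)

    def num_params(self):
        d, n = len(self.b), len(self.b_hat)
        return d * n + d + n  # one shared weight matrix plus both bias vectors


def untied_num_params(n_in, n_hidden):
    # separate encoder (d x n) and decoder (n x d) matrices plus both bias vectors
    return 2 * n_hidden * n_in + n_hidden + n_in


ae = TiedSparseAutoencoder(n_in=62, n_hidden=32)
print(ae.num_params(), untied_num_params(62, 32))  # 2078 4062
data_rng = random.Random(1)
toy = [[data_rng.random() for _ in range(62)] for _ in range(5)]
print(ae.loss(toy, rho=0.01, alpha=1.0))
```

For the Dolphins input layer (62 nodes) with 32 hidden units, tying the weights reduces the parameter count from 4062 to 2078: the weight matrices are halved, while the two bias vectors are unchanged.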
4 Experimental Results and Analysis
To evaluate the performance of the proposed algorithm, we use real datasets and compare existing algorithms with CoIDSA. This section is divided into three parts: experiment preparation, evaluation criteria, and experimental results.
4.1 Experiment Preparation
Dataset. We use four real datasets, Football [12], Polblogs [13], Polbooks [14], and Dolphins [15], to validate our algorithm. The datasets are described in Table 1.
Table 1. Dataset information
Dataset     Number of nodes   Number of edges   Number of communities
Football    115               613               12
Polblogs    1490              16718             2
Polbooks    105               441               3
Dolphins    62                159               2
To obtain better results, we use neural networks with different layer configurations for the different real-world networks. For small datasets, choosing suitable layer sizes yields better results in the feature extraction process. For larger datasets, the number of nodes in each layer is reduced by powers of 2. In the experiments, we set the learning rate of the deep sparse autoencoder to 0.1 and, to enforce sparsity, the sparsity parameter to 0.01. For each network we use a deep sparse autoencoder with three layers, which provides accurate feature extraction results. With these settings, layer-by-layer greedy training exploits the characteristics of deep sparse autoencoders to ensure accurate feature extraction on datasets of different sizes. The details are shown in Table 2.
Table 2. Structure of the deep neural network
Dataset     Number of nodes per layer
Dolphins    62-32-16
Football    64-32-16
Polbooks    64-32-16
Polblogs    256-128-64
4.2 Evaluation Criteria and Comparison Algorithms
Evaluation Criteria. Accuracy is judged by comparing the communities $RC = \{C_1, C_2, \ldots, C_k\}$ obtained by our algorithm with the ground-truth communities $GT = \{C_1, C_2, \ldots, C_n\}$, where n is the number of ground-truth communities and k is the number of resulting communities. We use two general community evaluation criteria, NMI and modularity Q, to analyze the accuracy of community discovery.
Comparison Algorithms. To demonstrate the effectiveness of the proposed algorithm, we compare it with existing algorithms, namely CoDDA [5], SSCF [16], DNR [17], and DSACD [18]. CoDDA is a community discovery algorithm based on a sparse autoencoder; it extracts features from the similarity matrix of a single function and then obtains the community structure by clustering. SSCF is a sparse subspace community detection method based on sparse linear coding. DNR uses the nonlinear model of a deep neural network for community detection. DSACD constructs similarity matrices to reveal the relationships between nodes and designs deep sparse autoencoders based on unsupervised learning to reduce dimensionality and extract the feature structure of complex networks.
4.3 Experimental Results
The experiments in this section consist of three parts: an ablation experiment, a comparative experiment, and a parameter experiment. We compare our algorithm with other algorithms on two performance evaluation criteria, NMI and modularity Q, and use the parameter experiments to explain our choice of parameters. Ablation Experiment. The experimental results are shown in Fig. 2. Clustering on the hop-based similarity matrix gives results 5.5% higher than clustering on the modularity-based similarity matrix, because the hop-based preprocessing recalculates the similarity matrix of the network nodes and better captures their local information. Clustering on the similarity matrices from both functions gives results that are on average 5.2% higher than clustering on the hop-based similarity matrix alone, so two similarity matrices clearly outperform one. We conclude that using two different similarity matrices improves the accuracy of the community discovery algorithm.
Fig. 2. Experimental results of different functions as input.
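The two evaluation criteria used in these experiments can be computed as follows. The paper does not state which NMI normalization it uses, so this sketch assumes the common form $2I(T;P)/(H(T)+H(P))$, together with the standard partition modularity $Q = \frac{1}{2m}\sum_{ij}\left(A_{ij} - \frac{k_i k_j}{2m}\right)\delta(c_i, c_j)$; the function names are hypothetical.

```python
import math
from collections import Counter


def nmi(truth, pred):
    """NMI = 2 I(T;P) / (H(T) + H(P)) for two labelings of the same nodes."""
    n = len(truth)
    ct, cp = Counter(truth), Counter(pred)
    joint = Counter(zip(truth, pred))
    mi = sum(c / n * math.log(c * n / (ct[t] * cp[p])) for (t, p), c in joint.items())
    h_t = -sum(c / n * math.log(c / n) for c in ct.values())
    h_p = -sum(c / n * math.log(c / n) for c in cp.values())
    if h_t + h_p == 0:
        return 1.0  # both labelings put every node in one community
    return 2 * mi / (h_t + h_p)


def modularity(adj, labels):
    """Q = (1/2m) * sum_ij [A_ij - k_i k_j / 2m] * delta(c_i, c_j)."""
    k = [len(nbrs) for nbrs in adj]
    two_m = sum(k)
    q = 0.0
    for i in range(len(adj)):
        for j in range(len(adj)):
            if labels[i] == labels[j]:
                a_ij = 1.0 if j in adj[i] else 0.0
                q += a_ij - k[i] * k[j] / two_m
    return q / two_m


# Two triangles (0,1,2) and (3,4,5) joined by the edge 2-3
adj = [{1, 2}, {0, 2}, {0, 1, 3}, {2, 4, 5}, {3, 5}, {3, 4}]
found = [0, 0, 0, 1, 1, 1]
print(round(modularity(adj, found), 4))          # 0.3571 (= 5/14)
print(round(nmi([1, 1, 1, 0, 0, 0], found), 4))  # 1.0 (same partition, relabeled)
```

NMI is invariant to relabeling communities, which is why the second call returns 1.0 even though the label values differ.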
Comparative Experiment. To verify the accuracy of our algorithm in community discovery, we compare CoIDSA with the existing algorithms; the results for the community evaluation indicators NMI and Q are shown in Fig. 3.
Fig. 3. Comparison of NMI and Q performance of the CoIDSA algorithm with 4 existing community detection algorithms on 4 real datasets.
Table 3. Analysis of community discovery results
Algorithm   Dolphins      Football      Polbooks      Polblogs
            NMI   Q       NMI   Q       NMI   Q       NMI   Q
SSCF        0.82  0.52    0.83  0.52    0.56  0.35    0.12  0.16
DNR         0.81  0.56    0.85  0.55    0.58  0.54    0.51  0.32
CoDDA       0.79  0.51    0.90  0.58    0.82  0.51    0.71  0.51
DSACD       0.82  0.53    0.91  0.56    0.85  0.50    0.73  0.52
CoIDSA      0.91  0.52    0.95  0.60    0.90  0.52    0.86  0.63
Table 3 and Fig. 3 present the community detection metrics of the different algorithms on the test datasets. Overall, the NMI values of CoIDSA are higher than those of all other algorithms. This is because the algorithm preprocesses the network with the hop-count-based function and the modularity function before clustering and recalculates the similarity matrix of the network nodes; the complexity of the network structure is thus fully considered, and multiple similarity matrices are constructed with different methods. A weight-bound deep sparse autoencoder then extracts features from the similarity matrices to obtain low-dimensional feature matrices with more distinct features, and clustering these yields more accurate communities. The community results of CoIDSA are on average 5% higher than those of the other algorithms. This is because CoIDSA uses multiple similarity matrices as input before feature extraction, which enhances the local information of nodes, and the resulting low-dimensional feature matrix better expresses the structure of the network; this supports the effectiveness of CoIDSA. DSACD outperforms CoDDA in NMI on all four datasets, consistent with the improvements that DSACD introduces over CoDDA [18]. The clustering results on the Polblogs dataset are lower than on the other three datasets because Polblogs is many times larger; nevertheless, CoIDSA still reaches a high level on it, which shows that CoIDSA is also effective on larger datasets. However, owing to the nature of deep learning, training takes more time; the running time of each algorithm on the four datasets is shown in Table 4.
Table 4. Running time on different datasets
Algorithm   Dolphins (s)   Football (s)   Polbooks (s)   Polblogs (min)
SSCF        1.91           2.89           3.65           5.89
DNR         2.23           3.78           4.67           6.35
CoDDA       1.85           2.56           2.43           4.36
DSACD       1.83           2.43           2.25           4.12
CoIDSA      1.36           1.19           1.07           3.54
Table 4 shows that CoIDSA runs faster than the compared mainstream community discovery algorithms on all four datasets. The reason is that we use weight binding as an optimization: the decoder weights of each layer are tied to the encoder weights, which halves the number of weight parameters, speeds up training, and reduces the risk of overfitting. The shared weight matrix is updated during backpropagation, which improves training efficiency. These results show that training with a weight-bound deep sparse autoencoder effectively improves the efficiency of community discovery. Parametric Experiment. The deep sparse autoencoder for community discovery depends on two important parameters of the first similarity matrix: the jump (hop) threshold s and the decay factor σ. These two parameters directly affect the clustering results, so we set up experiments to find their optimal values.
– Jump threshold s
For the Football dataset, with [64–32] nodes in the layers of the deep sparse autoencoder and decay factor σ = 0.5, we analyze the impact of different jump thresholds on NMI. As shown in Fig. 4, for every value of the jump threshold s, the communities obtained by CoIDSA are more accurate than those obtained by clustering the similarity matrix directly with k-means, i.e., CoIDSA substantially improves the accuracy of the community results.
Fig. 4. NMI values of the CoIDSA algorithm and the k-means algorithm under different jump thresholds s on the Football dataset.
Fig. 4 shows that as s increases, the NMI value first increases and then decreases, which matches intuition: in a real network, nodes that are not directly connected but can be reached within a few hops have some degree of similarity, but if the number of hops is too large, distant nodes also acquire some similarity, which blurs the community boundaries. We therefore choose a jump threshold of 3 hops for the smaller datasets Football, Polbooks, and Dolphins, and 5 hops for the larger dataset Polblogs, which gives the best results.
– Decay factor σ
For the Football dataset, with [64–32] nodes in the layers of the deep sparse autoencoder and jump threshold s = 3, we analyze the influence of different attenuation factors on NMI. As shown in Fig. 5, for every value of the attenuation factor, the communities obtained by CoIDSA are more accurate than those obtained by clustering the similarity matrix directly with k-means, i.e., CoIDSA substantially improves the accuracy of the community results.
Fig. 5. NMI values of the CoIDSA algorithm and the k-means algorithm under different attenuation factors on the Football dataset.
Fig. 5 shows that as the attenuation factor increases, the NMI value first increases and then decreases. Setting the attenuation factor to 0.5 best enhances the local characteristics of the nodes and gives the optimal result.
5 Conclusions and Future Work
In this paper, to detect community structure in complex networks with high-dimensional feature representations more effectively, we propose a community discovery algorithm based on an improved deep sparse autoencoder. Experiments on real datasets show that the proposed algorithm achieves higher accuracy, stronger stability, and faster training for community discovery. This paper focuses on static networks; community discovery on dynamic networks is a direction for future research.
Acknowledgment. This work was supported in part by National Key R&D Program of China (No. 2019YFB1707000).
References
1. Ding, Z., Chen, X., Dong, Y., Herrera, F.: Consensus reaching in social network DeGroot model: the roles of the self-confidence and node degree. Inf. Sci. 486, 62–72 (2019)
2. Wang, F., Zhang, B., Chai, S., Xia, Y.: Community detection in complex networks using deep auto-encoded extreme learning machine. Mod. Phys. Lett. B 32(16), 1850180 (2018)
3. Jia, Y., Zhang, Q., Zhang, W., Wang, X.: CommunityGAN: community detection with generative adversarial nets. In: The World Wide Web Conference, pp. 784–794 (2019)
4. Li, S., Zhang, H., Wu, D., Zhang, C., Yuan, D.: Edge representation learning for community detection in large scale information networks. In: Doulkeridis, C., Vouros, G.A., Qu, Q., Wang, S. (eds.) MATES 2017. LNCS, vol. 10731, pp. 54–72. Springer, Cham (2018). https://doi.org/10.1007/978-3-319-73521-4_4
5. Shang, J., Wang, C., Xin, X., Ying, X.: Community detection algorithm based on deep sparse autoencoder. J. Softw. 28(3), 648–662 (2017)
6. Zhang, X., Newman, M.E.: Multiway spectral community detection in networks. Phys. Rev. E 92(5), 052808 (2015)
7. Cao, J., Jin, D., Yang, L., Dang, J.: Incorporating network structure with node contents for community detection on large networks using deep learning. Neurocomputing 297, 71–81 (2018)
8. Bhatia, V., Rani, R.: A distributed overlapping community detection model for large graphs using autoencoder. Future Gener. Comput. Syst. 94, 16–26 (2019)
9. Cavallari, S., Zheng, V.W., Cai, H., Chang, K.C.-C., Cambria, E.: Learning community embedding with community detection and node embedding on graphs. In: Proceedings of the 2017 ACM on Conference on Information and Knowledge Management, pp. 377–386 (2017)
10. Xu, R., Che, Y., Wang, X., Hu, J., Xie, Y.: Stacked autoencoder-based community detection method via an ensemble clustering framework. Inf. Sci. 526, 151–165 (2020)
11. Qiao, S., et al.: A fast parallel community discovery model on complex networks through approximate optimization. IEEE Trans. Knowl. Data Eng. 30(9), 1638–1651 (2018)
12. Khorasgani, R.R., Chen, J., Zaiane, O.R.: Top leaders community detection approach in information networks. In: 4th SNA-KDD Workshop on Social Network Mining and Analysis, Citeseer (2010)
13. Adamic, L.A., Glance, N.: The political blogosphere and the 2004 U.S. election: divided they blog. In: Proceedings of the 3rd International Workshop on Link Discovery, pp. 36–43 (2005)
14. Newman, M.E.: Modularity and community structure in networks. Proc. Natl. Acad. Sci. 103(23), 8577–8582 (2006)
15. Lusseau, D., Newman, M.E.: Identifying the role that animals play in their social networks. In: Proceedings of the Royal Society of London. Series B: Biological Sciences, vol. 271, no. suppl 6, pp. S477–S481 (2004)
16. Mahmood, A., Small, M.: Subspace based network community detection using sparse linear coding. IEEE Trans. Knowl. Data Eng. 28(3), 801–812 (2015)
17. Yang, L., Cao, X., He, D., Wang, C., Wang, X., Zhang, W.: Modularity based community detection with deep learning. In: IJCAI, vol. 16, pp. 2252–2258 (2016)
18. Fei, R., Sha, J., Xu, Q., Hu, B., Wang, K., Li, S.: A new deep sparse autoencoder for community detection in complex networks. EURASIP J. Wireless Commun. Networking 2020(1), 1–25 (2020). https://doi.org/10.1186/s13638-020-01706-4
Fairly Constricted Multi-objective Particle Swarm Optimization
Anwesh Bhattacharya¹, Snehanshu Saha²(B), and Nithin Nagaraj³
¹ Department of Physics and CSIS, BITS-Pilani, Pilani, Rajasthan 333031, India
² APPCAIR and Department of CSIS, BITS-Pilani, Goa, Goa 403726, India
³ Consciousness Studies Programme, National Institute of Advanced Studies, Bangalore 560012, India
Abstract. It is well documented that using exponentially-averaged momentum (EM) in particle swarm optimization (PSO) is advantageous over the vanilla PSO algorithm. In the single-objective setting, it leads to faster convergence and avoidance of local minima. One would naturally expect the same advantages of EM to carry over to the multi-objective setting. Hence, we extend the state-of-the-art multi-objective optimization (MOO) solver, SMPSO, by incorporating EM into it. In doing so, we develop the mathematical formalism of constriction fairness, which is at the core of the extended SMPSO algorithm. The proposed solver matches the performance of SMPSO across the ZDT, DTLZ, and WFG problem suites and even outperforms it in certain instances.
Keywords: Particle Swarm Optimization · Multi-Objective Optimization · Meta-Heuristic
1 Introduction
Particle Swarm Optimization (PSO) was first proposed by Kennedy and Eberhart [1,2] in 1995 as an evolutionary single-objective optimization algorithm. N particles are initialised at random positions and velocities in the search space, and the ith particle updates its trajectory according to
$$v_i^{(t+1)} = w v_i^{(t)} + c_1 r_1\left(\mathrm{pbest}_i^{(t)} - x_i^{(t)}\right) + c_2 r_2\left(\mathrm{gbest}^{(t)} - x_i^{(t)}\right) \tag{1}$$
$$x_i^{(t+1)} = x_i^{(t)} + v_i^{(t+1)} \tag{2}$$
$r_1$ and $r_2$ are random numbers drawn from the uniform distribution U(0, 1). $\mathrm{pbest}_i^{(t)}$ is the best position (in terms of minimizing the objective) that particle i has visited up to time t, and $\mathrm{gbest}^{(t)}$ is the best position achieved among all particles. After sufficient iterations, all particles assume positions $x_i$ near gbest with velocities $v_i \approx 0$; in this state, we say that the swarm has converged. [3] proposes the EMPSO algorithm to speed up convergence and avoid local minima in single-objective problems. It is a vanilla PSO algorithm aided by exponentially-averaged momentum (EM), with the update equations
$$M_i^{(t+1)} = \beta M_i^{(t)} + (1-\beta)v_i^{(t)} \tag{3}$$
$$v_i^{(t+1)} = M_i^{(t+1)} + c_1 r_1\left(\mathrm{pbest}_i^{(t)} - x_i^{(t)}\right) + c_2 r_2\left(\mathrm{gbest}^{(t)} - x_i^{(t)}\right) \tag{4}$$
Supported by DST.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 610–621, 2023.
https://doi.org/10.1007/978-981-99-1639-9_51
Eq. (3) computes the exponentially-averaged velocity of the ith particle up to timestep $t$. The position update equation for EMPSO remains the same as Eq. (2). The momentum factor must obey $0 < \beta < 1$.¹ By recursively expanding Eq. (3), assuming zero initial momentum ($M_i^{(1)} = 0$), a particle's momentum is an exponentially weighted sum of all its previous velocities:

$$M_i^{(t+1)} = (1-\beta)\, v_i^{(t)} + \beta(1-\beta)\, v_i^{(t-1)} + \dots + \beta^{t-2}(1-\beta)\, v_i^{(2)} + \beta^{t-1}(1-\beta)\, v_i^{(1)} \qquad (5)$$
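The recursion of Eq. (3) and its unrolled form, Eq. (5), can be checked numerically. The following is a minimal pure-Python sketch, not the EMPSO authors' code. It assumes zero initial momentum and, as in the pseudocode later in this paper, one pair $r_1, r_2$ per particle per step. The function names are illustrative.

```python
import random


def empso_step(x, v, m, pbest, gbest, beta, c1, c2, rng):
    """One EMPSO update of a single particle, following Eqs. (3), (4) and (2)."""
    r1, r2 = rng.random(), rng.random()
    # Eq. (3): exponentially-averaged momentum of the previous velocity
    m_new = [beta * mj + (1.0 - beta) * vj for mj, vj in zip(m, v)]
    # Eq. (4): the momentum takes the place of the inertia term w * v
    v_new = [mj + c1 * r1 * (pb - xj) + c2 * r2 * (gb - xj)
             for mj, xj, pb, gb in zip(m_new, x, pbest, gbest)]
    # Eq. (2): the position update is unchanged from vanilla PSO
    x_new = [xj + vj for xj, vj in zip(x, v_new)]
    return x_new, v_new, m_new


def momentum_recursive(velocities, beta):
    """Apply Eq. (3) to v^(1), ..., v^(t), starting from M^(1) = 0 (assumed)."""
    m = 0.0
    for vt in velocities:
        m = beta * m + (1.0 - beta) * vt
    return m  # this is M^(t+1)


def momentum_unrolled(velocities, beta):
    """Eq. (5): sum over k = 0..t-1 of beta^k (1 - beta) v^(t-k)."""
    t = len(velocities)
    return sum(beta ** k * (1.0 - beta) * velocities[t - 1 - k] for k in range(t))


if __name__ == "__main__":
    history = [1.0, -2.0, 0.5, 3.0]
    print(momentum_recursive(history, 0.7), momentum_unrolled(history, 0.7))
```

With $\beta = 0$ the momentum equals the previous velocity, so Eq. (4) reduces to vanilla PSO with $w = 1$.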
In certain single-objective problems, [3] report a 50% reduction in the number of iterations to convergence for EMPSO relative to vanilla PSO. Given this superior performance over the vanilla algorithm in the single-objective setting, we hypothesize that similar benefits of EM would be seen in multi-objective problems. The central setting of multi-objective optimization (MOO) is the problem

$$\min_{x \in \mathbb{R}^n} \; f(x) = \left[f_1(x), f_2(x), \dots, f_k(x)\right]$$

i.e., given an input space $\mathbb{R}^n$, we want to optimize $k$ functions $f_1, f_2, \dots, f_k$ simultaneously in the objective space. In practice, MOO solvers find a Pareto front, which corresponds to a non-dominated set of decision variables $x_i \in \mathbb{R}^n$. Informally, no member of this set is at least as good as another member in every objective while being strictly better in at least one; in this sense, each member is as good a solution as any other. A comprehensive introduction to MOO can be found in [4]. SMPSO [5] is the state-of-the-art MOO solver based on vanilla PSO. It uses a constricted vanilla PSO whose update equation is

$$v_i^{(t+1)} = \chi\left[w\,v_i^{(t)} + c_1 r_1 \left(pbest_i^{(t)} - x_i^{(t)}\right) + c_2 r_2 \left(gbest^{(t)} - x_i^{(t)}\right)\right] \qquad (6)$$

where $\chi$ is the constriction factor [6], defined as

$$\chi = \begin{cases} \dfrac{2}{2 - \phi - \sqrt{\phi^2 - 4\phi}} & \phi > 4 \\[2mm] 1 & \phi \le 4 \end{cases} \qquad (7)$$
with $\phi = c_1 + c_2$. Hence $\chi$ is a function of $c_1, c_2$. Since this constriction factor is defined with respect to vanilla PSO, we denote it as $\chi \equiv \chi^{(v)}(\phi)$.² The position update equation for constricted vanilla PSO remains the same as Eq. (2). We describe SMPSO in Algorithm 1.

¹ Note that β = 0 degenerates to vanilla PSO.
² Note that the constriction factor is negative for φ > 4.
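Eq. (7) translates directly into code. The sketch below is a minimal illustration, and the name `chi_vanilla` is ours. It also illustrates the sign remark above: the factor is negative for φ > 4, and its modulus tends to 1 as φ decreases towards 4, so $|\chi|$ is continuous at the branch point.

```python
import math


def chi_vanilla(phi):
    """Constriction factor chi^(v)(phi) of Eq. (7), with phi = c1 + c2."""
    if phi > 4.0:
        return 2.0 / (2.0 - phi - math.sqrt(phi * phi - 4.0 * phi))
    return 1.0  # constriction inactive


if __name__ == "__main__":
    for phi in (3.0, 4.0, 4.1, 4.5, 5.0):
        print(f"phi = {phi:.1f}  chi = {chi_vanilla(phi):+.4f}")
```

For example, φ = 4.1 gives χ ≈ −0.7298, whose modulus is the commonly used constriction value.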
A. Bhattacharya et al.
Algorithm 1. SMPSO Pseudocode
1: initializeSwarm()
2: leaders = initializeArchive()
3: gen ← 0
4: while gen < maxGen do
5:   computeSpeed()
6:   updatePosition()
7:   mutation()
8:   evaluation()
9:   updateArchive(leaders)
10:  updateParticlesMemory()
11:  gen ← gen + 1
12: end while
13: Return leaders
Line 1 initializes the particles' positions in the input space along with random velocities. As per [4], the external archive for storing leaders is initialized in line 2. Line 5 updates the swarm velocities according to the constricted vanilla PSO of Eqs. (6) and (7). Line 6 applies the regular position update of Eq. (2). Line 7 performs a turbulence mutation, which introduces diversity into the swarm so that the particles do not converge to a single point. Finally, lines 8–10 evaluate the particles and update the external archive and the particles' memories. In particular, we focus on line 5 of Algorithm 1 and expand it in Algorithm 2.
Algorithm 2. SMPSO computeSpeed()
1: for i ← 1 to swarmSize do
2:   r1 ← Uniform(0, 1)
3:   r2 ← Uniform(0, 1)
4:   c1 ← Uniform(1.5, 2.5)
5:   c2 ← Uniform(1.5, 2.5)
6:   φ ← c1 + c2
7:   χ ← ConstrictionFactor(φ)
8:   v[i] ← wv[i] + c1 r1 (pbest[i] − x[i]) + c2 r2 (gbest − x[i])
9:   v[i] ← χv[i]
10:  v[i] ← VelocityConstriction(v[i])
11: end for
Lines 2–3 draw $r_1, r_2$ from the uniform distribution $U(0, 1)$, and lines 4–5 draw $c_1, c_2 \sim U(1.5, 2.5)$. Line 6 computes $\phi$, and line 7 computes the constriction factor $\chi^{(v)}(\phi)$. Lines 8–9 update the particle's velocity according to Eq. (6), where x[i] and v[i] are the position and velocity vectors, respectively, of the ith particle. Finally, line 10 performs a velocity constriction based on the boundary of the search space.³ The authors of SMPSO attribute its superiority over other MOO solvers, such as OMOPSO [7] and NSGA-II [8], to the randomized selection of $c_1, c_2$ together with the constriction factor $\chi^{(v)}(\phi)$, which maintains a diversity of solutions in the swarm.
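To make Algorithm 2 concrete, here is a runnable sketch of it applied to a whole swarm. Several details are our assumptions and are not specified in this section. First, a single leader `gbest` is shared by all particles, whereas SMPSO selects leaders from its archive. Second, the inertia weight is w = 0.1. Third, VelocityConstriction is taken to clamp each component to [−δ_j, δ_j] with δ_j = (upper_j − lower_j)/2. The comments map each step to the pseudocode line numbers.

```python
import math
import random


def chi_vanilla(phi):
    """Constriction factor of Eq. (7)."""
    if phi > 4.0:
        return 2.0 / (2.0 - phi - math.sqrt(phi * phi - 4.0 * phi))
    return 1.0


def compute_speed(x, v, pbest, gbest, lower, upper, rng, w=0.1):
    """Sketch of Algorithm 2 for the whole swarm; v is updated in place.

    Assumptions: one shared leader gbest, inertia weight w = 0.1, and a
    velocity clamp to [-delta_j, delta_j] with delta_j = (upper_j - lower_j) / 2.
    """
    delta = [(u - l) / 2.0 for l, u in zip(lower, upper)]
    for i in range(len(x)):
        r1, r2 = rng.uniform(0.0, 1.0), rng.uniform(0.0, 1.0)  # lines 2-3
        c1, c2 = rng.uniform(1.5, 2.5), rng.uniform(1.5, 2.5)  # lines 4-5
        chi = chi_vanilla(c1 + c2)                              # lines 6-7
        for j in range(len(x[i])):
            vij = (w * v[i][j]
                   + c1 * r1 * (pbest[i][j] - x[i][j])
                   + c2 * r2 * (gbest[j] - x[i][j]))            # line 8
            vij *= chi                                          # line 9
            v[i][j] = max(-delta[j], min(delta[j], vij))        # line 10 (assumed clamp)
    return v


if __name__ == "__main__":
    rng = random.Random(1)
    swarm = [[0.1, 0.9], [0.8, 0.2], [0.5, 0.5]]
    velocities = [[0.0, 0.0] for _ in swarm]
    memory = [p[:] for p in swarm]  # pbest starts at the initial positions
    print(compute_speed(swarm, velocities, memory, [0.4, 0.4], [0.0, 0.0], [1.0, 1.0], rng))
```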
2 Motivations
Apart from the external archive, leader selection and mutation, the performance of SMPSO is governed by the dynamics of the swarm, which are dictated solely by the computeSpeed() subroutine (Algorithm 2). Thus, EM must be incorporated into SMPSO within the computeSpeed() function (line 5 of Algorithm 1). As a first attempt, we formulate the desired computeSpeed() in Algorithm 3. We name the resulting EM-aided SMPSO algorithm EM-SMPSO.

Algorithm 3. EM-SMPSO computeSpeed()
1: for i ← 1 to swarmSize do
2:   r1 ← Uniform(0, 1)
3:   r2 ← Uniform(0, 1)
4:   c1 ← Uniform(1.5, 2.5)
5:   c2 ← Uniform(1.5, 2.5)
6:   β ← Uniform(0, 1)
7:   φ ← c1 + c2
8:   χ ← ConstrictionFactor(φ, β)
9:   m[i] ← βm[i] + (1 − β)v[i]
10:  v[i] ← m[i] + c1 r1 (pbest[i] − x[i]) + c2 r2 (gbest − x[i])
11:  v[i] ← χv[i]
12:  v[i] ← VelocityConstriction(v[i])
13: end for

Akin to Algorithm 2, we draw $\beta \sim U(0, 1)$ in line 6. Line 8 computes the appropriate constriction factor for EM-SMPSO. Note that the function ConstrictionFactor() now takes two arguments $(\phi, \beta)$ instead of one. This is because EM directly affects the swarm dynamics, and hence we need a different constriction factor $\chi \equiv \chi^{(m)}(\phi, \beta)$. Lines 9–11 are the update equations of constricted EMPSO. It can be shown that the constriction coefficient is⁴

$$\chi^{(m)}(\phi, \beta) = \begin{cases} \dfrac{2}{2 - \phi - \sqrt{\phi^2 - 4(1-\beta)\phi}} & \phi > 4(1+\beta)^{-1} \\[2mm] 1 & \text{otherwise} \end{cases} \qquad (8)$$

From a theoretical standpoint, adopting a positive or a negative constriction coefficient is equivalent, because only the modulus $|\lambda|$ (of the eigenvalues governing the swarm dynamics) is significant [9]. Moreover, note that $\beta = 0$ implies that the effect of momentum is absent, and it can easily be confirmed that $\chi^{(m)}(\phi, 0) = \chi^{(v)}(\phi)$. Thus, our derivation is consistent with that of vanilla PSO.

³ Details of this step can be found in [5].
⁴ A complete derivation can be found in Sect. 1 of the supplementary file. GitHub: https://github.com/anuwu/FCMPSO.
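The consistency check $\chi^{(m)}(\phi, 0) = \chi^{(v)}(\phi)$ can be reproduced in a few lines. This is a minimal sketch of Eqs. (7) and (8), and the function names are ours. On the first branch of Eq. (8), $\phi > 4/(1+\beta) \ge 4(1-\beta)$, so the square root is always real.

```python
import math


def chi_vanilla(phi):
    """Eq. (7)."""
    if phi > 4.0:
        return 2.0 / (2.0 - phi - math.sqrt(phi * phi - 4.0 * phi))
    return 1.0


def chi_momentum(phi, beta):
    """Eq. (8): constriction factor chi^(m)(phi, beta), for 0 <= beta < 1."""
    if phi > 4.0 / (1.0 + beta):
        # On this branch phi > 4 / (1 + beta) >= 4 (1 - beta), so the root is real.
        return 2.0 / (2.0 - phi - math.sqrt(phi * phi - 4.0 * (1.0 - beta) * phi))
    return 1.0


if __name__ == "__main__":
    for phi in (3.0, 3.5, 4.5):
        print(phi, [round(chi_momentum(phi, b), 4) for b in (0.0, 0.25, 0.5, 0.75)])
```

The threshold $4/(1+\beta)$ falls from 4 at $\beta = 0$ towards 2 as $\beta \to 1$. Once momentum is present, the constricting branch is therefore taken for a wider range of φ.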
Fig. 1. Pareto fronts on ZDT1 and ZDT2. Panels: (a) SMPSO, (b) EM-SMPSO, (c) SMPSO, (d) EM-SMPSO.
In Fig. 1, we present the Pareto fronts of EM-SMPSO (Algorithm 3) on the ZDT [10] bi-objective problems. The fronts are poor compared to those obtained by SMPSO. There are significantly fewer points in the external archive, and the Pareto front is fragmented. The SMPSO Pareto fronts, on the other hand, are smooth and dense. The Pareto fronts were obtained using the jmetalpy [11] framework. In the single-objective realm, a blanket introduction of EM into the swarm dynamics significantly improved performance compared to vanilla PSO across various objective functions. In the multi-objective case, however, this is not so, as the Pareto fronts demonstrate. It is therefore instructive to analyse the component of SMPSO that is pivotal to its superior performance: the constriction factor. Drawing $c_1, c_2 \sim U(1.5, 2.5)$ (Algorithms 2 and 3) entails that $\phi = c_1 + c_2$ takes values in $(3, 5)$ with a distribution that is symmetric about its midpoint $\phi = 4$. Strictly, the sum of two independent uniform variables has a triangular rather than a uniform distribution, but the analysis below models it as $\phi \sim U(3, 5)$. The midpoint $\phi = 4$ is also the value at which the two branches of $\chi^{(v)}(\phi)$ in Eq. (7) meet. We say that the constriction factor is active if the first branch is taken. Hence, throughout the entire evolution⁵ of the swarm, the constriction factor is activated with probability 1/2. It is in this sense that SMPSO is a fairly constricted algorithm: the constriction factor is activated or not with equal chance. EM-SMPSO with $\phi \sim U(3, 5)$ and $\beta \sim U(0, 1)$ is not a fairly constricted algorithm because of the way $\chi^{(m)}(\phi, \beta)$ is defined. We prove this fact in Sect. 3.
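The fairness argument can also be checked empirically. This Monte Carlo sketch samples $c_1, c_2$ (and β) as in Algorithms 2 and 3. It then counts how often the first branch of Eq. (7) or Eq. (8) is taken. Because $c_1$ and $c_2$ are drawn independently, φ here has a triangular distribution on (3, 5). That distribution is symmetric about 4, so the SMPSO rate still comes out near 1/2, while the EM-SMPSO rate lies well above 1/2. The exact EM-SMPSO figure depends on the distribution assumed for φ, so this estimate need not coincide with a value derived under a uniform model of φ.

```python
import random


def activation_rate(n, rng, with_momentum):
    """Monte Carlo estimate of how often the first (constricting) branch is taken."""
    hits = 0
    for _ in range(n):
        phi = rng.uniform(1.5, 2.5) + rng.uniform(1.5, 2.5)  # c1 + c2, drawn independently
        if with_momentum:
            beta = rng.uniform(0.0, 1.0)
            hits += phi > 4.0 / (1.0 + beta)   # first branch of Eq. (8)
        else:
            hits += phi > 4.0                  # first branch of Eq. (7)
    return hits / n


if __name__ == "__main__":
    rng = random.Random(2022)
    print("SMPSO   :", activation_rate(100_000, rng, with_momentum=False))
    print("EM-SMPSO:", activation_rate(100_000, rng, with_momentum=True))
```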
3 Finding a Fairly Constricted Algorithm
We first develop simple mathematical formulae to assess the fairness of any variant of the EM-SMPSO algorithm where $\phi \sim U(\phi_1, \phi_2)$ and $\beta \sim U(\beta_1, \beta_2)$, with probability densities $p_\phi(\phi)$ and $p_\beta(\beta)$, respectively. Let $E$ be the event that $\phi > 4(1+\beta)^{-1}$, corresponding to the first branch of the definition in Eq. (8). We wish to find a formula for $P(E)$:

$$P(E) = \iint_{\phi > 4(1+\beta)^{-1}} p_\phi(\phi)\, p_\beta(\beta)\, d\phi\, d\beta = \iint_{\beta > 4\phi^{-1} - 1} p_\beta(\beta)\, p_\phi(\phi)\, d\beta\, d\phi \qquad (9)$$

⁵ One step in the evolution of the swarm is one iteration of the while loop in Algorithm 1. The complete evolution is iterating through the loop until the stopping criterion is met.

Using simple calculus, Eq. (9) can be simplified to

$$P(E) = \int_{\phi_l}^{\phi_g} \int_{4\phi^{-1}-1}^{\beta_2} p_\beta(\beta)\, p_\phi(\phi)\, d\beta\, d\phi + \int_{\phi_g}^{\phi_2} p_\phi(\phi)\, d\phi$$
where $\phi_l = \max(\phi_1, 4(1+\beta_2)^{-1})$ and $\phi_g = \min(4(1+\beta_1)^{-1}, \phi_2)$. Additionally, we define the unfairness metric $\mu = P(E) - \frac{1}{2}$ to ease mathematical analysis. Note that it satisfies $-0.5 \le \mu \le 0.5$. It measures how far an algorithm is from being fairly constricted. $\mu = 0$ corresponds to a fairly constricted algorithm, whereas $\mu > 0$ is over-constricted and $\mu < 0$ is under-constricted. It can be shown that Algorithm 3, with $\phi \sim U(3, 5)$ and $\beta \sim U(0, 1)$, gives $P(E) = \frac{3}{2} - 2\ln(4/3)$. This corresponds to an unfairness value $\mu = 1 - 2\ln(4/3) \approx 0.42$, so Algorithm 3 is over-constricted compared to SMPSO by a large margin. Thus, the fairness analysis of the constriction factor explains the suboptimal nature of the Pareto fronts. We wish to utilize the full range of the momentum parameter and hence set $\beta_1 = 0, \beta_2 = 1$. In computing the probability integral, we posit $\phi_l = \phi_1$ and $\phi_g = \phi_2$, which amounts to choosing $\phi_1 \ge 2$ and $\phi_2 \le 4$, respectively. Hence

$$P(E) = \int_{\phi_1}^{\phi_2} \int_{4/\phi - 1}^{1} d\beta \, \frac{d\phi}{\phi_2 - \phi_1} = 2 - 4\, \frac{\ln(\phi_2/\phi_1)}{\phi_2 - \phi_1} \qquad (10)$$
With an intelligent choice of $\phi_1$ and by solving a transcendental equation, it can be shown that using $c_1, c_2 \sim U(1, 1.7336)$ and $\beta \sim U(0, 1)$ results in a fairly constricted algorithm. Concretely, taking $\phi_1 = 2$ and setting Eq. (10) equal to 1/2 gives $\phi_2 \approx 3.4672$, i.e., an upper bound of $\phi_2/2 \approx 1.7336$ for $c_1, c_2$. We call the resulting algorithm Fairly Constricted Particle Swarm Optimization (FCPSO).⁶ Note that other parameter sets may also be fairly constricting. In this work, we have derived only one fairly constricting set and subsequently used it for benchmarking.
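Both numbers in this section can be reproduced with a short script. The sketch below assumes uniform φ and β, as in the derivation. It evaluates the probability integral of Eq. (9) with a midpoint rule in φ and checks it against Eq. (10). It then recovers μ for Algorithm 3 and solves the transcendental equation for $\phi_2$ by bisection with $\phi_1 = 2$. Bisection is our choice; the authors' solver is not stated.

```python
import math


def p_event_closed(phi1, phi2):
    """Eq. (10): P(E) for phi ~ U(phi1, phi2), beta ~ U(0, 1), 2 <= phi1 < phi2 <= 4."""
    return 2.0 - 4.0 * math.log(phi2 / phi1) / (phi2 - phi1)


def p_event_numeric(phi1, phi2, beta1, beta2, n=20_000):
    """P(E) = P(beta > 4/phi - 1) for uniform phi and beta (Eq. (9)), midpoint rule in phi."""
    h = (phi2 - phi1) / n
    total = 0.0
    for k in range(n):
        phi = phi1 + (k + 0.5) * h
        b = min(max(4.0 / phi - 1.0, beta1), beta2)  # clip threshold to beta's support
        total += (beta2 - b) / (beta2 - beta1)       # P(beta > b) for a uniform beta
    return total / n                                 # average over the uniform phi


def solve_phi2(phi1=2.0, target=0.5, tol=1e-12):
    """Bisection on (phi1, 4] for the phi2 at which Eq. (10) equals target."""
    lo, hi = phi1 + 1e-9, 4.0
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if p_event_closed(phi1, mid) < target:  # P(E) increases with phi2
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


if __name__ == "__main__":
    mu_alg3 = p_event_numeric(3.0, 5.0, 0.0, 1.0) - 0.5
    print("mu (Algorithm 3):", mu_alg3, "vs 1 - 2 ln(4/3) =", 1 - 2 * math.log(4 / 3))
    phi2 = solve_phi2()
    print("phi2 =", phi2, "-> c1, c2 ~ U(1, %.4f)" % (phi2 / 2))
```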
4 Results
We first present the Pareto fronts of the ZDT1 and ZDT3 problems (Fig. 2). On a first qualitative look, the Pareto fronts of FCPSO match those of SMPSO. The solution points are densely packed and well connected, unlike the fragmented Pareto fronts of the naive EM-SMPSO algorithm (Fig. 1).

⁶ Complete details on computing the probability integral P(E) can be found in Sect. 2 of the supplementary file.

Fig. 2. Pareto fronts of FCPSO: (a) ZDT1, (b) ZDT3.

4.1 Assessment with Quality Indicators

We choose the following quality indicators to assess the performance of FCPSO: Inverted Generational Distance (IGD), Spacing (SP), Hypervolume (HV), and the ε-indicator (EPS). We also report the number of function evaluations (FE). The indicators are computed with the jmetalpy framework, and a thorough description of them can be found in [12]. All indicators were measured after letting the swarm evolve for 25,000 function evaluations. When measuring the number of function evaluations itself, we let the swarm evolve until 95% of the hypervolume (HV hereafter) of the theoretically computed Pareto front is reached. The theoretical fronts were obtained from [13,14] and [15]. All quality indicator values of FCPSO are accompanied by the corresponding values from SMPSO for comparison. Each measurement was repeated 20 times for statistical testing, and the resulting p-values are reported in the tables. Due to space constraints, we do not show the values of all quality indicators for all problems.

Bi-objective ZDT and Tri-objective DTLZ: The ZDT [10] suite is a set of bi-objective problems for the evaluation of MOO algorithms. Together with the tri-objective DTLZ [16] problems, they constitute the basic problems that any MOO algorithm must be able to solve accurately. We show that FCPSO matches the performance of the state-of-the-art SMPSO on these problems. Please refer to Table 1 for the FE and HV measurements. FCPSO matches SMPSO on most problems, occasionally outperforming (and occasionally underperforming) it. A peculiarity of the DTLZ2, DTLZ4 and DTLZ5 problems is that their numbers of FEs are not comparable to those of the other problems, and hence we have not included their statistical p-values. This is probably due to the particular nature of these problems, or to the swarm initialisation implemented in jmetalpy.
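For intuition about two of the indicators, here is a plain-Python sketch of Pareto dominance, the bi-objective hypervolume and IGD, with all objectives minimized. The paper computes its indicators with jmetalpy, and this is not that implementation. The hypervolume sweep covers only two objectives. IGD is taken here as the mean Euclidean distance from each reference-front point to its nearest obtained point, which is one common definition; some libraries use other variants. The 95% stopping rule in the text would correspond to checking whether `hypervolume_2d(front, ref) >= 0.95 * hypervolume_2d(true_front, ref)`.

```python
import math


def dominates(a, b):
    """True if objective vector a Pareto-dominates b (all objectives minimized)."""
    return all(x <= y for x, y in zip(a, b)) and any(x < y for x, y in zip(a, b))


def hypervolume_2d(front, ref):
    """Area dominated by a bi-objective front and bounded by the reference point ref."""
    pts = sorted(p for p in front if p[0] < ref[0] and p[1] < ref[1])
    area, best_f2 = 0.0, ref[1]
    for f1, f2 in pts:                 # sweep in order of increasing f1
        if f2 < best_f2:               # otherwise weakly dominated by an earlier point
            area += (ref[0] - f1) * (best_f2 - f2)
            best_f2 = f2
    return area


def igd(reference_front, approximation):
    """Mean Euclidean distance from each reference point to its nearest approximation point."""
    return sum(min(math.dist(r, a) for a in approximation)
               for r in reference_front) / len(reference_front)


if __name__ == "__main__":
    true_front = [(t / 10, 1 - t / 10) for t in range(11)]  # linear front f1 + f2 = 1
    approx = [(0.0, 1.05), (0.5, 0.55), (1.0, 0.05)]
    ref = (1.1, 1.1)
    hv_ratio = hypervolume_2d(approx, ref) / hypervolume_2d(true_front, ref)
    print("HV ratio:", hv_ratio, "IGD:", igd(true_front, approx))
```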
5-Objective and 10-Objective DTLZ: We test the algorithm on a harder variant of the DTLZ problems with 5 objectives. The quality indicator values
Table 1. ZDT & DTLZ : FE and HV FE
SMPSO
FCPSO
HV
SMPSO
FCPSO
zdt1
72.12
56.45
5.65e−10
zdt1
3.66
3.66
0.00
zdt2
68.95
87.90
1.10e−01
zdt2
3.33
3.33
0.00
zdt3
110.90
102.01
zdt3
4.40
zdt4
32.92
45.35
5.73e−07
zdt4
3.65
zdt6
27.40
29.31
9.20e−01
zdt6
3.17
3.17
1.55e−10
dtlz1
39.49
58.40
1.37e−03
dtlz1
3.28
3.27
2.70e−03
dtlz2
2.16
dtlz2
7.34
7.33
6.02e−13
dtlz3
521.41
dtlz3
6.39
5.97
2.70e−03
dtlz4
9.31
10.52
dtlz4
7.33
7.29
0.00
dtlz5
6.67
8.68
dtlz5
4.26
4.26
0.00
dtlz6
107.27
175.70
1.06e−09
dtlz6
4.26
3.79
dtlz7
102.15
74.30
1.16e−07
dtlz7
11.07
11.01
5.11e−03
2.19 402.91
1.59e−06
4.38
4.11e−11
3.65
0.00
3.18e−04 6.33e−05
are shown in Tables 2 and 3. For 10-objective DTLZ, we do not report HV values, because jmetalpy took too long to compute them. Thus, we only have IGD, EPS and SP values for these problems; they are available in Tables 4 and 5.

Table 2. 5-DTLZ: HV and IGD
Problem   HV SMPSO   HV FCPSO   p          IGD SMPSO   IGD FCPSO   p
dtlz1     1.73       6.30       0.00       4.36        0.07        0.00
dtlz2     22.03      27.56      0.00       0.64        0.42        0.00
dtlz3     0.24       5.34       3.80e−08   50.77       8.16        6.22e−15
dtlz4     30.45      30.85      1.99e−07   0.28        0.18        0.00
dtlz5     6.11       6.15       1.99e−07   0.07        0.06        1.62e−01
dtlz6     0.11       6.02       0.00       2.83        0.11        0.00
dtlz7     16.89      54.24      0.00       1.12        0.50        0.00
With respect to the spacing (SP) quality indicator, FCPSO outperforms SMPSO on all problems except DTLZ2, DTLZ4 and DTLZ5, in both the 5-objective and the 10-objective settings. There is one notable exception in the 10-objective setting, however, where SMPSO dominates with respect to SP. Nevertheless, the gap between SMPSO and FCPSO there is small.
5-Objective and 10-Objective WFG: The WFG test suite was proposed in [17] to overcome the limitations of the ZDT and DTLZ test suites. First, ZDT is limited to 2 objectives. Second, the DTLZ problems are not deceptive (a notion developed in [17]), and none of them features a large flat landscape. Moreover, the authors of [17] state that the nature of the Pareto front of DTLZ5 and DTLZ6 is
Table 3. 5-DTLZ: EPS and SP
Problem   EPS SMPSO   EPS FCPSO   p          SP SMPSO   SP FCPSO   p
dtlz1     3.37        0.12        0.00       3.10       2.67       1.37e−03
dtlz2     0.71        0.55        0.00       0.29       0.34       0.00
dtlz3     36.81       6.19        2.95e−14   49.05      39.85      4.11e−11
dtlz4     0.43        0.32        6.66e−08   0.18       0.21       1.59e−06
dtlz5     0.09        0.08        1.62e−01   0.20       0.25       0.00
dtlz6     2.96        0.11        0.00       1.11       0.56       0.00
dtlz7     2.77        1.05        4.11e−11   0.25       0.23       2.70e−03
Table 4. 10-DTLZ: IGD and EPS
Problem   IGD SMPSO   IGD FCPSO   p          EPS SMPSO   EPS FCPSO   p
dtlz1     8.09        1.93        1.25e−12   5.61        1.57        2.98e−10
dtlz2     0.75        0.57        0.00       0.65        0.56        0.00
dtlz3     43.87       31.48       1.37e−03   33.16       21.56       1.45e−04
dtlz4     0.58        0.43        0.00       0.65        0.52        0.00
dtlz5     0.06        0.08        1.97e−09   0.08        0.10        6.33e−05
dtlz6     0.51        0.15        1.97e−09   0.62        0.15        2.56e−12
dtlz7     1.45        1.32        4.22e−06   1.54        0.95        1.64e−02
unclear beyond 3 objectives. Lastly, the complexity of each of the previously mentioned problems is fixed. Hence, the WFG problems are expressed as a generalised scheme of transformations that map an input vector to a point in the objective space. The WFG test suite is harder and represents a rigorous attempt at creating a robust benchmark for MOO solvers.
Table 5. Spacing : 10-DTLZ and 10-WFG SMPSO SMPSO dtlz1
13.92
10.29
1.21
wfg2
0.12
wfg3
0.09
wfg4
1.93
wfg5
0.75
0.00
wfg6
0.23
0.27
5.73e−07
1.55e−10
0.62
FCPSO
wfg1
FCPSO
1.39 0.14
0.00
3.18e−04
0.12
dtlz2
0.44
dtlz3
90.24
93.83
dtlz4
0.43
0.44
dtlz5
0.22
dtlz6
0.80
0.95
1.33e−15
wfg7
2.51
2.48
1.62e−01
dtlz7
0.92
0.74
6.02e−13
wfg8
0.23
0.27
6.63e−09
wfg9
2.06
0.00
7.19e−02 1.10e−01
0.30
2.27
0.00
2.22e−16
0.91
1.43
0.00
0.00
Tables 6 and 7 contain the results for the 5-objective WFG problems, and Tables 5 and 8 contain those for the 10-objective WFG problems. FCPSO matches SMPSO within a small margin on most problems, and in some cases outperforms it.⁷
Table 6. 5-WFG : HV and IGD HV
SMPSO
wfg1
4450.24
FCPSO
wfg2
1693.23
1987.07
wfg3
14.51
14.59
wfg4
1395.45
1604.41
wfg5
2700.61
2813.04
wfg6
1540.74
wfg7 wfg8
4455.31
IGD SMPSO
FCPSO 1.88
0.00
wfg1
1.86
2.56e−12
wfg2
1.95
wfg3
2.63
2.61
2.14e−08
0.00
wfg4
1.74
1.69
2.56e−12
0.00
wfg5
1.42
1557.22
2.95e−14
wfg6
3.13
3.11
2.22e−16
1909.03
1986.85
1.33e−15
wfg7
2.04
2.00
2.56e−12
1910.25
1886.20
4.22e−06
wfg8
2.93
wfg9 2348.75
2237.12
wfg9
1.52
3.18e−04
9.32e−03
1.36e−13
1.75
1.38
2.92 1.50
0.00
0.00
0.00
2.78e−02
Table 7. 5-WFG: EPS and SP
Problem   EPS SMPSO   EPS FCPSO   p          SP SMPSO   SP FCPSO   p
wfg1      1.55        1.65        0.00       0.43       0.44       2.30e−01
wfg2      7.96        7.38        1.08e−05   0.04       0.06       6.66e−08
wfg3      6.83        6.83        5.49e−01   0.04       0.05       0.00
wfg4      2.19        2.04        1.59e−06   0.58       0.69       2.22e−16
wfg5      5.00        5.01        3.17e−01   0.28       0.31       5.73e−07
wfg6      7.44        7.43        6.89e−01   0.09       0.11       1.97e−09
wfg7      2.45        2.40        9.32e−03   0.63       0.62       8.41e−01
wfg8      8.32        8.33        1.36e−13   0.07       0.08       5.73e−07
wfg9      2.36        3.58        0.00       0.64       0.60       4.22e−06
Table 8. 10-WFG : IGD and EPS IGD SMPSO
7
FCPSO
wfg1
3.29
3.26
wfg2
6.14
5.86
wfg3
5.73
wfg4
7.09
EPS SMPSO
FCPSO
5.49e−01
wfg1
1.48
1.52
5.73e−07
wfg2
15.12
14.22
5.70
1.97e−09
wfg3
13.13
13.12
5.30
6.22e−15
wfg4
3.78
2.98
wfg5
0.17
0.16
wfg6
13.10
13.07
0.00 0.00
2.78e−02 2.67e−05 2.78e−02 5.65e−10
wfg5
0.06
wfg6
14.77
14.76 4.27
0.05
0.00
8.91e−02
wfg7
5.17
4.89
1.05e−11
wfg7
4.24
wfg8
8.59
8.60
5.49e−01
wfg8
16.63
16.64
0.00
wfg9
2.48
wfg9
1.16
0.69
0.00
2.12
0.00
2.30e−01
⁷ Complete experimental results can be found in the "data" folder of the GitHub repository.
5 Discussion, Conclusion and Future Works
At the time of its publication, SMPSO was the state-of-the-art MOO solver compared to other algorithms such as OMOPSO and NSGA-II. Its success is tied to the use of velocity constriction, which we have analysed theoretically and extended to the case of exponentially-averaged momentum. Moreover, there is a dearth of literature on the stochastic analysis of evolutionary algorithms. In the realm of single-objective PSO, [18] analysed the stability of PSO by considering the stochastic nature of $r_1, r_2$ in the PSO update, Eq. (1). We have performed an analysis in a similar vein. The idea proposed in this work is simple, but it could be applied to the stochastic analysis of evolutionary algorithms. In this paper, we have discussed the motivations for introducing exponentially-averaged momentum into the SMPSO framework. Having defined specific notions of constriction fairness, we have successfully incorporated exponentially-averaged momentum into SMPSO and demonstrated its performance on MOO problems. It would be beneficial to develop a larger number of fairly constricting parameter schemes and compare their performance. Finding a parameterization $(\phi_1, \phi_2, \beta_1, \beta_2)$ that ranges smoothly over the entire range of unfairness would help in comprehensively profiling quality indicators. Moreover, the unfairness value does not uniquely identify an EM-SMPSO algorithm, i.e., multiple parameter schemes can result in the same unfairness value. A thorough assessment could enable the creation of selection mechanisms, density estimators, and alternate notions of elitism tailored to the use of EM in swarm-based MOO algorithms.
References
1. Kennedy, J., Eberhart, R.: Particle swarm optimization. In: Proceedings of ICNN 1995 - International Conference on Neural Networks, vol. 4, pp. 1942–1948 (1995)
2. Kennedy, J., Eberhart, R.C.: Swarm Intelligence. Morgan Kaufmann Publishers Inc., San Francisco (2001)
3. Mohapatra, R., Saha, S., Coello, C.A.C., Bhattacharya, A., Dhavala, S.S., Saha, S.: AdaSwarm: augmenting gradient-based optimizers in deep learning with swarm intelligence. IEEE Trans. Emerg. Top. Comput. Intell. 6(2), 329–340 (2021)
4. Coello, C.A.C.: An introduction to multi-objective particle swarm optimizers. In: Gaspar-Cunha, A., Takahashi, R., Schaefer, G., Costa, L. (eds.) Soft Computing in Industrial Applications. Advances in Intelligent and Soft Computing, vol. 96, pp. 3–12. Springer, Berlin (2011). https://doi.org/10.1007/978-3-642-20505-7_1
5. Nebro, A.J., Durillo, J.J., Garcia-Nieto, J., Coello, C.C., Luna, F., Alba, E.: SMPSO: a new PSO-based metaheuristic for multi-objective optimization. In: 2009 IEEE Symposium on Computational Intelligence in Multi-Criteria Decision-Making (MCDM), pp. 66–73 (2009)
6. Clerc, M., Kennedy, J.: The particle swarm - explosion, stability, and convergence in a multidimensional complex space. IEEE Trans. Evol. Comput. 6(1), 58–73 (2002)
7. Sierra, M.R., Coello Coello, C.A.: Improving PSO-based multi-objective optimization using crowding, mutation and ε-dominance. In: Coello Coello, C.A., Hernández Aguirre, A., Zitzler, E. (eds.) EMO 2005. LNCS, vol. 3410, pp. 505–519. Springer, Heidelberg (2005). https://doi.org/10.1007/978-3-540-31880-4_35
8. Deb, K., Pratap, A., Agarwal, S., Meyarivan, T.: A fast and elitist multiobjective genetic algorithm: NSGA-II. IEEE Trans. Evol. Comput. 6(2), 182–197 (2002)
9. Strogatz, S.H.: Nonlinear Dynamics and Chaos: With Applications to Physics, Biology, Chemistry and Engineering. Westview Press (2000)
10. Zitzler, E., Deb, K., Thiele, L.: Comparison of multiobjective evolutionary algorithms: empirical results. Evol. Comput. 8(2), 173–195 (2000). https://doi.org/10.1162/106365600568202
11. Benítez-Hidalgo, A., Nebro, A.J., García-Nieto, J., Oregi, I., Del Ser, J.: jMetalPy: a Python framework for multi-objective optimization with metaheuristics. Swarm Evol. Comput. 51, 100598 (2019). https://www.sciencedirect.com/science/article/pii/S2210650219301397
12. Audet, C., Digabel, S., Cartier, D., Bigeon, J., Salomon, L.: Performance indicators in multiobjective optimization (2018)
13. Coello, C.A.C., Lamont, G.B., Veldhuizen, D.A.V.: Evolutionary Algorithms for Solving Multi-Objective Problems (Genetic and Evolutionary Computation). Springer, Berlin (2006)
14. optproblems. https://pypi.org/project/optproblems. Accessed 13 Feb 2020
15. MOEA Framework. http://moeaframework.org/index.html. Accessed 13 Feb 2020
16. Deb, K., Thiele, L., Laumanns, M., Zitzler, E.: Scalable test problems for evolutionary multiobjective optimization. Springer, London, pp. 105–145 (2005). https://doi.org/10.1007/1-84628-137-7_6
17. Huband, S., Hingston, P., Barone, L., While, L.: A review of multiobjective test problems and a scalable test problem toolkit. IEEE Trans. Evol. Comput. 10(5), 477–506 (2006)
18. Jiang, M., Luo, Y.P., Yang, S.Y.: Particle swarm optimization - stochastic trajectory analysis and parameter selection. In: Chan, F.T., Tiwari, M.K. (eds.) Swarm Intelligence. IntechOpen, Rijeka (2007), ch. 11. https://doi.org/10.5772/5104
Argument Classification with BERT Plus Contextual, Structural and Syntactic Features as Text
Umer Mushtaq¹(B) and Jérémie Cabessa²,³
¹ Laboratory of Mathematical Economics and Applied Microeconomics (LEMMA), University Paris 2 - Panthéon-Assas, 75005 Paris, France
² Laboratory DAVID, UVSQ, Université Paris-Saclay, 78000 Versailles, France
³ Institute of Computer Science of the Czech Academy of Sciences, 18207 Prague 8, Czech Republic
Abstract. In Argument Mining (AM), the integral sub-task of argument component classification refers to the classification of argument components as claims or premises. In this context, the content of the component alone does not suffice to accurately predict its class; additional lexical, contextual, and structural features are needed. Here, we propose a unified model for argument component classification based on BERT and inspired by the new prompting NLP paradigm. Our model incorporates the component itself together with contextual, structural and syntactic features – given as text – instead of in the usual numerical form. This new technique enables BERT to build a customized and enriched representation of the component. We evaluate our model on three datasets that reflect a diversity of written and spoken discourses. We achieve state-of-the-art results on two datasets and 95% of the best results on the third. Our approach shows that BERT is capable of exploiting non-textual information given in a textual form.
Keywords: NLP · Argument Mining · Text Classification · BERT · Features as Text · Prompting

1 Introduction
Argument Mining (AM) is the automated identification and analysis of the underlying argumentative structure in natural texts [3]. Essential sub-tasks in AM include: 1) separating argument components from non-argumentative text, 2) classifying argument components to determine their role in the argumentative process, 3) given two argument components, deciding whether they are linked or not, and 4) given two linked components, deciding whether the link is supporting or attacking [1,13]. AM is utilized for several popular downstream applications such as Stance Recognition and Sentiment Analysis.

This research was supported by Labex MME-DII as well as by the Czech Science Foundation, grant AppNeCo No. GA22-02067S, institutional support RVO: 67985807.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 622–633, 2023. https://doi.org/10.1007/978-981-99-1639-9_52

Argument Classification with BERT and Features as Text

Argumentative discourse happens in many interesting settings. Written discourse such as essays and articles consists of a structured presentation of claims and premises on a certain topic [12,16]. Organized political speeches consist of argumentative dialog between two or more candidates on several issues [8,11]. Social media platforms provide an avenue for users to debate and discuss contentious issues [15]. All three settings are inherently argumentative and are ideal for AM systems.

Text classification automatically assigns general text to pre-defined classes. In AM, it is the task of classifying argument components as either claims or premises. Claims are assertions made or positions taken for or against a particular topic, and premises are evidence, justifications or warrants presented in support of claims. For argument component classification, however, using different embeddings (GloVe, ELMo, FastText, etc.) alone as the sentence representation does not suffice. The role of an argument component depends, among other things, on its context and position in the text, and thus cannot be captured by its content alone. Therefore, additional features like lexical, indicator, discourse, syntactic, contextual and structural features have been used to enrich the sentence representation of the components [4,7,17].

Transformer models have been game changers in NLP [19]. Bidirectional Encoder Representations from Transformers (BERT) is a Transformer-based model that is pre-trained on huge amounts of data in a self-supervised manner [2]. Using a transfer learning process called fine-tuning, this pre-trained BERT model is then utilized for an NLP task on a specific dataset. BERT models have been successfully used for several NLP tasks such as text classification.
In fact, the BERT embedding as a sentence representation outperforms earlier embeddings (GloVe, ELMo, FastText, etc.) on text classification tasks. The 'Pre-train, Prompt, Predict' paradigm has also been a game changer in NLP [9]. In this paradigm, instead of relying on task-specific supervised fine-tuning, the downstream task is reformulated with the help of a textual prompt so that it resembles the task the language model was pre-trained on. For instance, the sentiment of the sentence 'I liked the movie!' is obtained from the output of the language model on the input 'I liked the movie! The movie was [MASK].', which includes the sentence and a task-specific prompt. For argument component classification, however, the straightforward prompting approach would not capture the necessary contextual, structural and syntactic information. Based on these considerations, we propose a novel approach, inspired by prompt engineering, which incorporates – in textual form – the contextual, structural and syntactic features necessary for argument component classification.

U. Mushtaq and J. Cabessa

Specifically, we introduce a novel model for argument component classification which is based on the popular BERT model. Our model incorporates contextual, structural and syntactic features as text to build a customized and enriched BERT-based representation of the argument component. We experiment with our model on three datasets: one based on written essays, one speech-based and one based on written social media posts. We show that: 1) our features-as-text sentence representation improves upon the BERT-based component-only representation, 2) our structural-features-as-text representation outperforms the classical approach of numerically concatenating these features with the BERT embedding, and 3) our model achieves state-of-the-art results on two datasets and 95% of the best results on the third. Overall, we situate our work within the 'better models vs better data' question by developing task-specific and customized data as opposed to designing more complex models. We make the code available on GitHub at: https://github.com/mohammadoumar/features as text.

This paper is structured as follows. Section 2 describes the related literature that informs our work. Section 3 presents the datasets. In Sect. 4, we introduce our novel features-as-text model in detail. Section 5 presents the experimental setting, results and analysis of our work. Section 6 provides concluding remarks and future directions.
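The core idea of giving BERT non-textual features in textual form can be illustrated with a small serializer. The template below is hypothetical. It is not the authors' format, which is defined in their model section, and the field names and the " | " separator are our assumptions. It turns structural features (paragraph and sentence positions), contextual features (topic, neighbouring sentences) and syntactic features (POS tags) into words rather than numeric vectors.

```python
def features_as_text(component, topic, para_idx, n_paras, sent_idx, n_sents,
                     prev_sentence="", next_sentence="", pos_tags=None):
    """Serialize contextual, structural and syntactic features as plain text.

    Hypothetical template: it only illustrates the features-as-text idea and
    is not the exact input format used by the authors.
    """
    parts = [
        f"Topic: {topic}",                             # contextual feature
        f"Paragraph {para_idx + 1} of {n_paras}",      # structural feature
        f"Sentence {sent_idx + 1} of {n_sents}",       # structural feature
    ]
    if prev_sentence:
        parts.append(f"Previous sentence: {prev_sentence}")  # contextual feature
    if next_sentence:
        parts.append(f"Next sentence: {next_sentence}")
    if pos_tags:
        parts.append("POS tags: " + " ".join(pos_tags))      # syntactic feature
    parts.append(f"Component: {component}")
    return " | ".join(parts)


if __name__ == "__main__":
    print(features_as_text(
        "through cooperation, children can learn about interpersonal skills",
        "Should students be taught to compete or to cooperate?",
        para_idx=1, n_paras=5, sent_idx=0, n_sents=4))
```

In a BERT pipeline, a string like this could be passed to the tokenizer, for example as the second segment of a sentence pair with the component as the first segment. This is one option among several and is not necessarily the authors' setup.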
2 Related Works
Stab and Gurevych [17] present a features-based approach for argument component classification in the Persuasive Essays (PE) dataset (see Sect. 3). They use hand-crafted features (lexical, structural, syntactic, etc.) with Support Vector Machines (SVMs) and Conditional Random Fields (CRFs). They show that structural features, which capture the position of the component in the full text, are the most useful for component classification. Haddadan et al. [4] use both features-based and neural network-based approaches for argument component classification in the Yes We Can (YWC) political debates dataset (see Sect. 3). In the features-based approach, they use an SVM with both Bag of Words (BoW) and a custom feature set (POS, syntactic, NER, etc.). In the neural network-based setting, they use both a feed-forward neural network with the custom feature set and an LSTM with FastText word embeddings. Potash et al. [14] present a Joint Neural Model for simultaneously learning argument component classification and link extraction between argument components in the PE and Micro-Text Corpus (MTC) datasets. This model consists of a Bi-LSTM encoder, a fully connected layer for component classification and an LSTM decoder for link identification. They use three methods for textual representation: Bag of Words (BoW), GloVe embeddings and structural features. Kuribayashi et al. [7] introduce an extension to the LSTM-minus-based span representation [20], in which they create separate representations of the argumentative markers ('I think', 'because', etc.) and of the argumentative component present in the argument unit. For textual/span representation, they use GloVe and ELMo embeddings concatenated with Bag of Words (BoW) and structural features. They experiment with the PE and MTC datasets. Mayer et al. [10] use neural network-based architectures for argument mining in a dataset of abstracts of bio-chemical healthcare trials. They combine the boundary detection and component classification tasks into one sequence tagging task. They use several static and dynamic embeddings such as BERT, GloVe, ELMo, fastText, FlairPM, etc. with various combinations of LSTMs, GRUs and CRFs, as well as fine-tuned BERT.

We situate ourselves within the 'better models vs better data' question. We posit that the BERT model is powerful enough to achieve improved performance if provided with task-specific enriched input data. To that end, our work is the first to investigate and implement a features-as-text, BERT-based model for argument component classification.
3 Datasets
In our work, we use three datasets for argument classification: Persuasive Essays (PE) [17], Yes We Can (YWC) [4] and Change My View (CMV) [6]. In this section, we present and describe these datasets.

Persuasive Essays (PE): The PE dataset was introduced by Stab and Gurevych [17]. It consists of 402 essays on diverse topics selected from the online portal essayforum.com. Each essay, which is divided into several paragraphs, consists of arguments (major claims, claims and premises) for or against a position on a controversial topic. A MajorClaim is a direct assertion of the author’s position on the topic of the essay. A Claim is an assertion the author makes in support of his/her position on the topic. A Premise is a piece of evidence or warrant that the author presents to support his/her claim(s). For example, a snippet of the essay on the topic ‘Should students be taught to compete or to cooperate?’ is given below, with each claim and premise enclosed in brackets labeled claim or premise:

First of all, [through cooperation, children can learn about interpersonal skills which are significant in the future life of all students.]claim1 [What we acquired from team work is not only how to achieve the same goal with others but more importantly, how to get along with others.]premise1 [During the process of cooperation, children can learn about how to listen to opinions of others, how to communicate with others, how to think comprehensively, and even how to compromise with other team members when conflicts occurred.]premise2 [All of these skills help them to get on well with other people and will benefit them for the whole life.]premise3

Yes We Can (YWC): The YWC dataset was introduced by Haddadan et al. [4]. It consists of transcripts of the presidential and vice-presidential debates of the quadrennial US presidential elections from 1960 to 2016 (39 debates in total), annotated with the claims and premises made by the candidates.
Change My View (CMV): The CMV dataset comes from Tan et al. [18] and Hidey et al. [6]. It is based on the “r/changemyview” subreddit of the social media platform Reddit.com. It consists of 113 threads containing argumentative conversations, made up of claims and premises, between internet users on 37 controversial topics.

The statistics for all three datasets are given in Table 1.

Table 1. Corpus and component statistics for the PE, YWC and CMV datasets. In the CMV dataset, Major Claims are called Main Claims.

| Dataset | Corpus statistics | Component statistics |
|---|---|---|
| Persuasive Essays (PE) | Tokens: 147,271; Sentences: 7,116; Paragraphs: 1,833; Essays: 402 | Major Claims: 751; Claims: 1,506; Premises: 3,832; Total: 6,089 |
| Yes We Can! (YWC) | Speech Turns: 6,601; Sentences: 34,103; Words: 676,227; Debates: 39 | Claims: 11,964; Premises: 10,316; Other: 7,252; Total: 29,621 |
| Change My View (CMV) | Words: 75,078; Paragraphs: 3,869; Topics: 37; Files: 113 | Main Claims: 116; Claims: 1,589; Premises: 2,059; Total: 3,764 |
4 Model
In this section, we introduce our novel BERT-based model for argument component classification. Our model incorporates contextual, structural and syntactic features represented as text, instead of in the usual numerical form. This approach enables BERT to build an enriched representation of the argument component.

4.1 BERT
The BERT-base architecture consists of 12 Transformer encoder blocks stacked together, each with 12 self-attention heads [2]. The self-attention heads enable BERT to incorporate bidirectional context and to attend to any part of the input sequence. BERT builds a 768-dimensional representation (embedding) of the input text sequence. In this work, as opposed to current approaches, we enrich the input of the BERT model with a textual representation of contextual, structural and syntactic features. These features are described below.
4.2 Features
Contextual Features: Contextual features capture the full meaning of an argument component in its semantic and linguistic space. In our work, we use the full sentence and the topic statement as contextual features. The full sentence feature helps capture the presence of argumentative and/or discourse markers (‘I think’, ‘In my opinion’, etc.). These markers indicate that the component preceding or following them in the sentence is more likely to be a claim than a premise. The topic statement feature helps discriminate between claims and premises because a claim is more likely to address the topic statement directly and, thus, to be more semantically similar to it. For both the Persuasive Essays (PE) and Change My View (CMV) datasets, the contextual features are the topic of the essay/discussion and the full sentence of the argument component. For the Yes We Can (YWC) dataset, in addition to the full sentence, we use the candidate name and election year as topical information. We define the textual representation of contextual features as:

contextual features as text = ‘Topic: t. Sentence: s.’

where t is the topic of the essay/discussion thread, or the speaker and election year of the debate speech, and s is the full sentence that contains the argument component (see Example 1).

Structural Features: Structural features incorporate the idea that argumentation follows a certain (perhaps fluid) pattern which can be used to discriminate between claims and premises. These features capture the location of the argument component in the whole essay and in the paragraph in which it appears. For example, claims are more likely to appear in the introductory and concluding paragraphs, as well as at the beginning and toward the end of a paragraph. Premises, on the other hand, are more likely to follow a claim within the paragraph [17]. We define the textual representation of structural features as:

structural features as text = ‘Paragraph Number: n. Is in introduction: i. Is in conclusion: c. Is first in paragraph: f. Is last in paragraph: l.’

where n is the number of the paragraph in which the argument component appears; i is Yes if the argument component is in the introduction paragraph and No otherwise; c is Yes if the argument component is in the conclusion paragraph and No otherwise; f is Yes if the argument component is the first component in its paragraph and No otherwise; and l is Yes if the argument component is the last component in its paragraph and No otherwise (see Example 1).

Syntactic Features: Part-of-speech (POS) tagging classifies English words into categories according to their linguistic role in a sentence. These categories include noun, verb, adjective, adverb, pronoun, preposition, conjunction, interjection, numeral, article, and determiner [5]. We define the textual representation of syntactic features as:
syntactic features as text = ‘Part Of Speech tags: t1, t2, ..., tn’

where ti represents the POS tag of the i-th word in the argument component.

4.3 Combined Features as Text
We combine the textual representations of the contextual, structural and syntactic features to build an enriched BERT-based representation of the argument component. The combined representation is defined as follows:

combined features as text = contextual features as text + structural features as text + syntactic features as text

where ‘+’ denotes string concatenation. Note that the argument component itself is included in the full sentence.

Example 1: We consider argument component 398 from essay 28 of the Persuasive Essays (PE) dataset:

argument component = ‘advertising cigarettes and alcohol will definitely affect our children in negative way’

The contextual, structural and syntactic features of this argument component are given in Table 2. The combined features as text representation of this argument component is:

‘[Topic: Society should ban all forms of advertising. Sentence: Ads will keep us well informed about new products and services, but we should also bear in mind that advertising cigarettes and alcohol will definitely affect our children in negative way.]contextual [Paragraph Number: Five. Is in introduction: No. Is in conclusion: Yes. Is first in paragraph: No. Is last in paragraph: Yes.]structural [Part of Speech tags: VERB, NOUN, CCONJ, NOUN, VERB, ADV, VERB, DET, NOUN, ADP, ADJ, NOUN]syntactic’

where the labeled brackets delimit the contextual, structural and syntactic parts, and the argument component is the final clause of the sentence, from ‘advertising’ to ‘negative way’. This combination of contextual, structural and syntactic features, which contains the argument component, forms the enriched sentence representation that is input to the BERT model.
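The string construction of Example 1 can be sketched in a few lines of Python. This is a minimal sketch under stated assumptions, not the authors' code: the helper names (contextual_as_text, structural_as_text, syntactic_as_text, combined_as_text), the small number-to-word table for paragraph numbers and the single space between the three parts are our own choices, picked so that the output reproduces Example 1 from the Table 2 values.

```python
# Hypothetical helpers: names, the number-to-word table and the single-space
# separator are illustrative assumptions chosen to reproduce Example 1.
NUMBER_WORDS = {1: "One", 2: "Two", 3: "Three", 4: "Four", 5: "Five",
                6: "Six", 7: "Seven", 8: "Eight", 9: "Nine", 10: "Ten"}


def yes_no(flag: int) -> str:
    """Render a 0/1 structural flag as the word used in the text template."""
    return "Yes" if flag else "No"


def contextual_as_text(topic: str, sentence: str) -> str:
    # 'Topic: t. Sentence: s.' (trailing periods are stripped to avoid '..')
    return f"Topic: {topic.rstrip('.')}. Sentence: {sentence.rstrip('.')}."


def structural_as_text(feats: dict) -> str:
    # Fall back to digits for paragraph numbers outside the small word table.
    number = NUMBER_WORDS.get(feats["para_nr"], str(feats["para_nr"]))
    return (f"Paragraph Number: {number}. "
            f"Is in introduction: {yes_no(feats['is_in_intro'])}. "
            f"Is in conclusion: {yes_no(feats['is_in_conclusion'])}. "
            f"Is first in paragraph: {yes_no(feats['is_first_in_para'])}. "
            f"Is last in paragraph: {yes_no(feats['is_last_in_para'])}.")


def syntactic_as_text(pos_tags: list) -> str:
    return "Part of Speech tags: " + ", ".join(pos_tags)


def combined_as_text(feats: dict) -> str:
    """Concatenate the contextual, structural and syntactic parts."""
    return " ".join([contextual_as_text(feats["essay_topic"], feats["full_sentence"]),
                     structural_as_text(feats),
                     syntactic_as_text(feats["pos_tags"])])


# Feature values of argument component 398 (essay 28), copied from Table 2.
component_398 = {
    "essay_topic": "Society should ban all forms of advertising.",
    "full_sentence": ("Ads will keep us well informed about new products and services, "
                      "but we should also bear in mind that advertising cigarettes and "
                      "alcohol will definitely affect our children in negative way."),
    "para_nr": 5, "is_in_intro": 0, "is_in_conclusion": 1,
    "is_first_in_para": 0, "is_last_in_para": 1,
    "pos_tags": ["VERB", "NOUN", "CCONJ", "NOUN", "VERB", "ADV",
                 "VERB", "DET", "NOUN", "ADP", "ADJ", "NOUN"],
}

text = combined_as_text(component_398)
print(text)
```

The resulting plain string is what a BERT tokenizer would receive; like any BERT input, it is limited to at most 512 tokens after tokenization, so very long sentences or tag lists would need truncation.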
5 Results and Analysis
In this section, we present and analyse our results. We use our model for two tasks:
1) BERT fine-tuning: we fine-tune BERT on the three datasets using our novel combined features as text sentence representation.
2) Textual vs. numerical features comparison: we fine-tune BERT and compare results in two cases: first, with our structural features as text, and second, with structural features numerically concatenated with the BERT sentence embedding.
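For contrast with Task 2, the classical baseline can be sketched as follows. This is a minimal sketch, assuming the baseline appends the five raw structural values of Table 2 to a 768-dimensional sentence embedding and feeds the result to a single linear classification layer; the exact baseline head is not specified in this section. The random vector stands in for a real BERT sentence embedding, the head weights are random, and the function names are hypothetical.

```python
import random

EMBED_DIM = 768  # BERT-base hidden size (Table 3)


def concat_structural(sentence_embedding, structural):
    """Classical baseline: append raw numeric structural features to the embedding."""
    if len(sentence_embedding) != EMBED_DIM:
        raise ValueError("expected a 768-dimensional BERT embedding")
    return list(sentence_embedding) + [float(v) for v in structural]


def linear_head(x, weights, bias):
    """Linear classification layer: one score (logit) per class."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]


rng = random.Random(0)
# Stand-in for the BERT sentence embedding of the argument component.
embedding = [rng.gauss(0.0, 1.0) for _ in range(EMBED_DIM)]
# para_nr, is_in_intro, is_in_conclusion, is_first_in_para, is_last_in_para (Table 2)
structural = [5, 0, 1, 0, 1]
x = concat_structural(embedding, structural)

num_classes = 3  # MajorClaim, Claim, Premise for PE
weights = [[rng.gauss(0.0, 0.02) for _ in range(len(x))] for _ in range(num_classes)]
logits = linear_head(x, weights, [0.0] * num_classes)
print(len(x), len(logits))
```

In the features as text variant, by contrast, no extra input dimensions are added: the same structural information enters BERT through the tokens of the input string.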
Table 2. Features for argument component 398 of the PE dataset. The component itself is the final clause of the full sentence, from ‘advertising’ to ‘negative way’.

| Feature | Value |
|---|---|
| essay_topic | ‘Society should ban all forms of advertising.’ |
| full_sentence | ‘Ads will keep us well informed about new products and services, but we should also bear in mind that advertising cigarettes and alcohol will definitely affect our children in negative way.’ |
| para_nr | 5 |
| is_in_intro | 0 |
| is_in_conclusion | 1 |
| is_first_in_para | 0 |
| is_last_in_para | 1 |
| pos_tags | VERB, NOUN, CCONJ, NOUN, VERB, ADV, VERB, DET, NOUN, ADP, ADJ, NOUN |
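The numeric structural values in Table 2 can be derived from the position of a component in its essay. This is a minimal sketch under assumptions not stated in the paper: the essay is given as a list of paragraphs, each listing its argument components in reading order, and the first and last paragraphs are treated as the introduction and conclusion. The function name and the placeholder components are hypothetical; only the last component string comes from Table 2.

```python
from typing import Dict, List


def structural_features(paragraphs: List[List[str]], component: str) -> Dict[str, int]:
    """Return the numeric structural features of Table 2 for `component`.

    Assumptions: each paragraph lists its argument components in reading
    order; the first paragraph is the introduction, the last the conclusion.
    """
    for p_idx, components in enumerate(paragraphs):
        if component in components:
            c_idx = components.index(component)
            return {
                "para_nr": p_idx + 1,  # 1-based paragraph number
                "is_in_intro": int(p_idx == 0),
                "is_in_conclusion": int(p_idx == len(paragraphs) - 1),
                "is_first_in_para": int(c_idx == 0),
                "is_last_in_para": int(c_idx == len(components) - 1),
            }
    raise ValueError(f"component not found: {component!r}")


# Toy five-paragraph essay: placeholder components except the last one.
essay = [
    ["major claim 1"],
    ["claim 2", "premise 3"],
    ["claim 4", "premise 5"],
    ["claim 6", "premise 7"],
    ["claim 8", "advertising cigarettes and alcohol will definitely affect our children in negative way"],
]
print(structural_features(essay, essay[4][1]))
```

With this toy layout the call returns para_nr = 5, is_in_intro = 0, is_in_conclusion = 1, is_first_in_para = 0 and is_last_in_para = 1, the same values as Table 2.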
5.1 Experimental Setting
For the PE dataset, we use the original split: 322 essays in the train set (4,709 components) and 80 essays in the test set (1,258 components). For the YWC dataset, we also use the original split, with 10,447 components in the train set, 6,567 in the test set and 5,226 in the validation set. For the CMV dataset, we randomly assign 90 threads to the train set (2,720 components) and 23 threads to the test set (763 components). The implementation details of the model and experiments are given in Table 3.

5.2 Task Results
The results of Task 1 and Task 2 are presented in Tables 4 and 5, respectively. State-of-the-art results are also shown in Table 4: an F1 score of 0.86 for the PE [7] and 0.67 for the YWC [4] datasets. The results can be summarized as follows:
• Our novel features as text sentence representation, which incorporates contextual, structural and syntactic features as text, improves upon the BERT-based component only representation.
• Our features as text representation outperforms the classical approach of numerically concatenating these features with the BERT embedding.
• Our model achieves state-of-the-art results on two datasets and 95% of the best result on the third.
5.3 Analysis
The addition of contextual, structural and syntactic features as text enables BERT to relate the argument component to the linguistic and argumentative flow of the whole paragraph and essay.

In Task 1, for the PE dataset, the contextual, structural and syntactic parts of our combined representation improve the results compared to the BERT-based component only representation. The contextual representation improves the F1 score from 0.57 to 0.68. The combined contextual, structural and syntactic representation further improves the F1 score from 0.68 to 0.82, which is 95% of the state-of-the-art result (0.86) [7]. However, the state-of-the-art approach works on paragraphs that are chunked into segmented discourse units and requires an argumentative marker (AM) versus argumentative component (AC) distinction in sentences. In contrast, our model works at the sentence level and requires no AM/AC distinction. Overall, the improvement achieved by structural features emphasizes the importance of the position of the argument component in written argumentative texts such as persuasive essays.

For the CMV dataset, our combined contextual and structural representation improves the F1 score from 0.76 to 0.79. Here, the contextual part alone does not improve the results because the boundaries of the argument component and of the full sentence almost always coincide. By contrast, the structural features do improve the results, but to a lesser extent than on the PE dataset. We attribute this difference to the fact that written text on social media platforms is less structured than written text in academic essays.

In contrast with the other datasets, for YWC the combined contextual, structural and syntactic representation does not show any improvement. Nevertheless, our model outperforms the state-of-the-art result in the literature (0.69 vs. 0.67) [4].
These results suggest that the relatively concrete linguistic and structural flow present in the written PE dataset and, to a lesser extent, in the CMV dataset is lacking in the spoken YWC dataset because of its extemporaneous and fluid nature.

Table 3. Model implementation details. We experimented with several parameter values; for each experiment, the best parameter values are available in the GitHub repository.

| Name | Values |
|---|---|
| Model | ‘bert-base-uncased’ |
| Embedding dimension | 768 |
| Batch size | [16, 24, 32, 48] |
| Epochs | [3, 6, 8, 12] |
| Learning rate | [1e-5, 2e-5, 1e-3, 5e-3, 5e-5] |
| Warmup ratio, Weight decay, Dropout | 0.1, 0.01, 0.1 |
| Loss function | Cross Entropy Loss |
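Table 3 lists candidate values rather than a single configuration. A minimal sketch of enumerating them, assuming an exhaustive grid over the listed batch sizes, epochs and learning rates with the remaining settings fixed; the paper does not say how the search was actually run, and the dictionary keys are our own names.

```python
from itertools import product

# Candidate values from Table 3; single-valued settings are one-element lists.
GRID = {
    "batch_size": [16, 24, 32, 48],
    "epochs": [3, 6, 8, 12],
    "learning_rate": [1e-5, 2e-5, 1e-3, 5e-3, 5e-5],
    "warmup_ratio": [0.1],
    "weight_decay": [0.01],
    "dropout": [0.1],
}


def configurations(grid):
    """Yield every combination of the grid values as a dict (exhaustive search)."""
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))


configs = list(configurations(GRID))
print(len(configs))  # 4 * 4 * 5 = 80 configurations
print(configs[0])
```

Under this assumption, each dataset and sentence representation would require 80 fine-tuning runs, which is why the best values per experiment are reported separately rather than in the table.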
Table 4. Task 1 results. Performance of our features as text BERT-based model on the three datasets. We report results for different combinations of features as text. MC, C and P represent the F1 scores for MajorClaim, Claim and Premise, respectively. F1 represents the macro F1 score. The abbreviations ‘strct’ and ‘synt’ stand for structural features and syntactic features, respectively. The last two rows represent the state-of-the-art results for the PE and YWC datasets.

| Sentence representation | PE MC | PE C | PE P | PE F1 | YWC C | YWC P | YWC F1 | CMV C | CMV P | CMV F1 |
|---|---|---|---|---|---|---|---|---|---|---|
| component only | 0.49 | 0.41 | 0.81 | 0.57 | 0.71 | 0.68 | 0.69 | 0.74 | 0.79 | 0.76 |
| sentence | 0.69 | 0.48 | 0.82 | 0.66 | 0.71 | 0.68 | 0.69 | 0.70 | 0.77 | 0.74 |
| topic + sent | 0.70 | 0.70 | 0.84 | 0.68 | 0.69 | 0.65 | 0.67 | 0.70 | 0.75 | 0.73 |
| sent + strct | 0.85 | 0.68 | 0.91 | 0.81 | 0.70 | 0.68 | 0.69 | 0.75 | 0.84 | 0.79 |
| topic + sent + strct | 0.86 | 0.68 | 0.91 | 0.81 | 0.69 | 0.65 | 0.67 | 0.76 | 0.80 | 0.78 |
| topic + sent + strct + synt | 0.86 | 0.71 | 0.91 | 0.82 | 0.71 | 0.62 | 0.67 | 0.76 | 0.78 | 0.77 |
| LSTM + dist [7] | 0.92 | 0.73 | 0.92 | 0.86 | – | – | – | – | – | – |
| LSTM + word emb [4] | – | – | – | – | 0.70 | 0.68 | 0.67 | – | – | – |
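The F1 column of Tables 4 and 5 is the macro F1 score. A minimal sketch, assuming the common definition in which per-class F1 scores (F1 = 2TP / (2TP + FP + FN)) are averaged without class weighting, the same definition as scikit-learn's f1_score with average='macro'. The label sequences are an illustrative toy example, not data from the paper.

```python
from collections import Counter


def macro_f1(y_true, y_pred, labels):
    """Per-class F1 scores and their unweighted mean (macro F1)."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # predicted class p was wrong
            fn[t] += 1  # true class t was missed
    per_class = []
    for c in labels:
        denom = 2 * tp[c] + fp[c] + fn[c]
        per_class.append(2 * tp[c] / denom if denom else 0.0)
    return sum(per_class) / len(per_class), per_class


labels = ["MajorClaim", "Claim", "Premise"]
y_true = ["Premise", "Premise", "Claim", "MajorClaim", "Premise", "Claim"]
y_pred = ["Premise", "Claim", "Claim", "MajorClaim", "Premise", "Premise"]
macro, per_class = macro_f1(y_true, y_pred, labels)
print(round(macro, 4), [round(f, 4) for f in per_class])

# Averaging the per-class scores of the PE 'component only' row of Table 4.
print(round(sum([0.49, 0.41, 0.81]) / 3, 2))
```

Because macro F1 weights every class equally, the rare MajorClaim class in PE influences the score as much as the frequent Premise class.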
Table 5. Task 2 results. Comparison between structural features numerically concatenated to the BERT embedding (‘Concatenated’) and our features as text sentence representation (‘As text’).

| Dataset | Concatenated MC | Concatenated C | Concatenated P | Concatenated F1 | As text MC | As text C | As text P | As text F1 |
|---|---|---|---|---|---|---|---|---|
| Persuasive Essays (PE) | 0.82 | 0.57 | 0.90 | 0.76 | 0.86 | 0.68 | 0.91 | 0.81 |
| Yes We Can! (YWC) | – | 0.70 | 0.65 | 0.67 | – | 0.69 | 0.65 | 0.69 |
| Change My View (CMV) | – | 0.70 | 0.76 | 0.73 | – | 0.76 | 0.80 | 0.78 |
Overall, we see that our features as text sentence representation, which incorporates contextual, structural and syntactic features as text, improves upon the BERT-based component only representation. In fact, the latter representation cannot capture two important classification clues: the context and the structure. The context includes argumentative markers (‘In my opinion’, ‘I think’, etc.), while the structure captures the position of the argument component in the argumentative text.

The results of Task 2 show that our features as text representation outperforms the classical representation in which structural features are numerically concatenated with the BERT embedding. For the PE and CMV datasets, the improvement in F1 score is substantial: from 0.76 to 0.82 and from 0.73 to 0.78, respectively. For the YWC dataset, on the other hand, the improvement is smaller: from 0.67 to 0.69. These results support our contention that the datasets for which the contextual and structural features provide the most significant information (Task 1) are precisely those where the features as text representation performs best (Task 2). In other words, the more informative the contextual and structural features, the better the features as text representation. Overall, our results indicate that BERT performs better when non-textual information is given to it as text.
6 Conclusion
In this work, we introduce a novel model for argument component classification which is based on the popular BERT model and inspired by the prompting paradigm. Our model incorporates contextual, structural and syntactic features as text to build an enriched BERT-based representation of the argument component. We experiment with our model on three datasets: two written and one spoken. We obtain three main results: 1) our features as text sentence representation improves upon the BERT-based component only representation; 2) our structural features as text representation outperforms the classical approach of numerically concatenating these features with the BERT embedding; and 3) our model achieves state-of-the-art results on two datasets and 95% of the best result on the third. To the best of our knowledge, our work is the first to investigate and implement a model based on a features as text sentence representation.

Based on our results and analysis, we think that a systematic study comparing Argument Mining (AM) dynamics in written and spoken datasets would be of great benefit to the AM community. In terms of future research directions, we plan to merge our features as text technique with the LSTM-minus-based span representation model of Kuribayashi et al. [7]. We also intend to extend our features as text technique to further feature types from [17], such as richer syntactic features and lexical features. We see our work as a first step towards a hybrid BERT-PROMPT end-to-end AM pipeline, thereby combining two dominant NLP paradigms. We think that our features as text approach opens up exciting new possibilities both for Argument Mining and for any NLP task that requires feature engineering. More generally, our approach can be used in other ML settings where the features can be described as text.
References
1. Cabrio, E., Villata, S.: Five years of argument mining: a data-driven analysis. In: Proceedings of the 27th International Joint Conference on Artificial Intelligence, IJCAI 2018, pp. 5427–5433. AAAI Press (2018)
2. Devlin, J., Chang, M., Lee, K., Toutanova, K.: BERT: pre-training of deep bidirectional transformers for language understanding. In: Burstein, J., Doran, C., Solorio, T. (eds.) Proceedings of NAACL-HLT 2019, pp. 4171–4186. ACL (2019)
3. Habernal, I., Wachsmuth, H., Gurevych, I., Stein, B.: The argument reasoning comprehension task: identification and reconstruction of implicit warrants. In: Proceedings of NAACL-HLT 2018, pp. 1930–1940. ACL (2018)
4. Haddadan, S., Cabrio, E., Villata, S.: Yes, we can! Mining arguments in 50 years of US presidential campaign debates. In: Proceedings of ACL 2019, pp. 4684–4690. ACL (2019)
5. Haspelmath, M.: Word classes and parts of speech, pp. 16538–16545, December 2001. https://doi.org/10.1016/B0-08-043076-7/02959-4
6. Hidey, C., Musi, E., Hwang, A., Muresan, S., McKeown, K.: Analyzing the semantic types of claims and premises in an online persuasive forum. In: Proceedings of ArgMining@EMNLP 2017, pp. 11–21. ACL (2017)
7. Kuribayashi, T., et al.: An empirical study of span representation in argumentation structure parsing. In: Proceedings of ACL, pp. 4691–4698. ACL (2019)
8. Lippi, M., Torroni, P.: Argument mining from speech: detecting claims in political debates. In: Schuurmans, D., Wellman, M.P. (eds.) Proceedings of AAAI 2016, pp. 2979–2985. AAAI Press (2016)
9. Liu, P., Yuan, W., Fu, J., Jiang, Z., Hayashi, H., Neubig, G.: Pre-train, prompt, and predict: a systematic survey of prompting methods in natural language processing. CoRR abs/2107.13586 (2021)
10. Mayer, T., Cabrio, E., Villata, S.: Transformer-based argument mining for healthcare applications. In: Proceedings of ECAI 2020, pp. 2108–2115. IOS Press (2020)
11. Menini, S., Cabrio, E., Tonelli, S., Villata, S.: Never retreat, never retract: argumentation analysis for political speeches. In: Proceedings of AAAI 2018, vol. 32, no. 1 (2018)
12. Moens, M.F., Boiy, E., Palau, R.M., Reed, C.: Automatic detection of arguments in legal texts. In: Proceedings of ICAIL 2007, pp. 225–230. ACM (2007)
13. Peldszus, A., Stede, M.: From argument diagrams to argumentation mining in texts: a survey. Int. J. Cogn. Inf. Nat. Intell. 7, 1–31 (2013)
14. Potash, P., Romanov, A., Rumshisky, A.: Here’s my point: joint pointer architecture for argument mining. In: Proceedings of EMNLP 2017, pp. 1364–1373. ACL (2017)
15. Somasundaran, S., Wiebe, J.: Recognizing stances in online debates. In: Proceedings of ACL/IJCNLP 2009, pp. 226–234. ACL (2009)
16. Song, Y., Heilman, M., Beigman Klebanov, B., Deane, P.: Applying argumentation schemes for essay scoring. In: Proceedings of ArgMining@ACL 2014, pp. 69–78. ACL (2014)
17. Stab, C., Gurevych, I.: Parsing argumentation structures in persuasive essays. Comput. Linguist. 43(3), 619–659 (2017)
18. Tan, C., Niculae, V., Danescu-Niculescu-Mizil, C., Lee, L.: Winning arguments: interaction dynamics and persuasion strategies in good-faith online discussions. In: Proceedings of WWW 2016, pp. 613–624. ACM (2016)
19. Vaswani, A., et al.: Attention is all you need. In: Proceedings of NIPS 2017, pp. 6000–6010. Curran Associates Inc. (2017)
20. Wang, W., Chang, B.: Graph-based dependency parsing with bidirectional LSTM. In: Proceedings of ACL 2016, pp. 2306–2315. ACL (2016)
Variance Reduction for Deep Q-Learning Using Stochastic Recursive Gradient

Haonan Jia¹, Xiao Zhang²,³, Jun Xu²,³ (✉), Wei Zeng⁴, Hao Jiang⁵, and Xiaohui Yan⁵

1 School of Information, Renmin University of China, Beijing, China
2 Gaoling School of Artificial Intelligence, Renmin University of China, Beijing, China
{zhangx89,junxu}@ruc.edu.cn
3 Beijing Key Laboratory of Big Data Management and Analysis Methods, Beijing, China
4 Baidu Inc., Beijing, China
5 Huawei Technologies, Shenzhen, China
{jianghao66,yanxiaohui2}@huawei.com

Abstract. Deep Q-learning often suffers from poor gradient estimates with excessive variance, resulting in unstable training and poor sample efficiency. Stochastic variance-reduced gradient methods such as SVRG have been applied to reduce the estimation variance. However, because reinforcement learning generates its training instances online, directly applying SVRG to deep Q-learning faces the problem of inaccurately estimated anchor points, which dramatically limits the potential of SVRG. To address this issue, and inspired by the recursive gradient variance reduction algorithm SARAH, this paper introduces a recursive framework for updating the stochastic gradient estimates in deep Q-learning, yielding a novel algorithm called SRG-DQN. Unlike SVRG-based algorithms, SRG-DQN updates the stochastic gradient estimate recursively. The parameters are updated along an accumulated direction that uses past stochastic gradient information, which removes the need to estimate full gradients as anchors. Additionally, SRG-DQN incorporates the Adam process to further accelerate training. Theoretical analysis and experimental results on well-known reinforcement learning tasks demonstrate the efficiency and effectiveness of the proposed SRG-DQN algorithm.
Keywords: Deep Q-learning · Variance reduction
Supplementary Information: The online version contains supplementary material available at https://doi.org/10.1007/978-981-99-1639-9_53.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 634–646, 2023.
https://doi.org/10.1007/978-981-99-1639-9_53
1 Introduction
Recent years have witnessed dramatic progress of deep reinforcement learning (RL) in a variety of challenging tasks, including computer games, robotics, natural language processing, and information retrieval. Among RL algorithms, deep Q-learning is a simple yet powerful algorithm for solving sequential decision problems [8,9]. Roughly speaking, deep Q-learning uses a neural network (Q-network) to approximate the Q-value function of traditional Q-learning. The system state is given as the input, and the Q-values of all possible actions are generated as the output. Learning the Q-network parameters amounts to solving a sequence of optimization problems based on the agent’s stored experiences. Stochastic gradient descent (SGD) is often employed to solve these optimization problems. That is, at each iteration of the optimization, to calculate the parameter gradients, the agent samples an action according to the current Q-network, issues the action to the environment, gathers the reward, and moves to the next state. The reward is used as the supervision signal for calculating the gradient that updates the Q-network parameters, and the update moves the parameters in the direction that most increases the chance of obtaining high cumulative future rewards.

In real-world applications, SGD in deep Q-learning often suffers from inaccurate gradient estimates. High-variance gradients inevitably hurt the efficiency and effectiveness of deep Q-learning algorithms, and reducing this variance has become one of the key problems in deep RL. Several research efforts have addressed this problem. For example, Averaged-DQN [2] extends the traditional DQN algorithm by averaging previously learned Q-value estimates, achieving a variance-reduced gradient estimate with an approximation error guarantee.
More recently, SVR-DQN [18] was proposed as an optimization strategy that combines the stochastic variance reduced gradient (SVRG) technique [5] with deep Q-learning. It has been shown that reducing the variance leads to a more stable and accurate training procedure. The Adam optimization algorithm [6] is an extension of stochastic gradient descent. Although it was mainly designed for optimizing neural networks, Adam can also be applied directly to improve the training of deep Q-learning. For more variance reduction methods for deep Q-learning, see [12,14]. SVRG has also been applied to policy gradient methods in RL as an effective variance reduction technique for stochastic optimization, for example in off-line control [17], policy evaluation [4], and on-policy control [11]. The convergence rate of variance-reduced policy gradient methods has been proved, showing their advantages over vanilla methods [15]. The authors of [16] applied recursive variance reduction techniques to policy gradient algorithms and proved a state-of-the-art convergence rate for policy gradient methods.
Although preliminary successes have been achieved, current methods are far from optimal because they ignore the essential differences between RL and traditional machine learning. SVRG-based methods need to pre-calculate full gradients as anchors. The anchors are crucial for finding more accurate gradient direction estimates in the downstream parameter updates. On a fixed training set, SVRG-based methods can easily estimate the anchors by scanning all training instances. In deep Q-learning, however, the anchors can no longer be estimated accurately because learning is conducted online: (1) the training instances (i.e., the sampled transitions) are generated gradually as training proceeds, by issuing actions to the environment at each iteration, so the algorithm cannot access the instances that will be generated in future iterations; (2) the selection of actions is guided by the DQN with its current parameters, so instances generated at different iterations are not identically distributed, because the DQN parameters are updated in between. This phenomenon makes the inaccurate estimation of the anchors even more severe. Empirical analyses also show that inaccurate anchor estimates greatly degrade the performance of SVRG-based methods.

In this paper, to address this issue and inspired by the variance reduction algorithm SARAH [10], we propose to adopt the recursive gradient estimation mechanism of SARAH in the training iterations of deep Q-learning, yielding a novel deep Q-learning algorithm called SRG-DQN.
Specifically, SRG-DQN contains an outer loop, which samples N training instances (i.e., N transitions, each consisting of a state, action, reward, and next state) based on the current Q-network and from the experience replay, and an inner loop, which first estimates the stochastic gradients recursively and then updates the Q-network parameters. In addition, the Adam process is executed at the end of the outer loop to further improve training efficiency. Theoretical and experimental analyses demonstrate that the recursive gradient estimation mechanism successfully addresses the problem of inaccurate anchors in SVRG-based methods. It also inherits the advantages of SARAH, including a fast convergence rate and stable, reliable training. We conduct experiments on RL benchmark tasks to evaluate the proposed SRG-DQN algorithm. Experimental results show that SRG-DQN outperforms state-of-the-art baselines, including SGD-based and SVRG-based deep Q-learning algorithms, in terms of reward scores, convergence rate, and training time. Empirical analyses also show that SRG-DQN dramatically reduces the variance of the estimated gradients, revealing how and why SRG-DQN improves on the baseline algorithms.
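To make the recursive estimator concrete, here is a minimal sketch of the SARAH [10] update on a toy one-dimensional least-squares problem rather than on deep Q-learning. Unlike SVRG, which reuses a fixed anchor in every inner step, SARAH updates its estimate v from the difference of two stochastic gradients at consecutive iterates plus the previous estimate. Initialising v from a batch average at the start of each outer loop follows SARAH's original formulation; how SRG-DQN initialises the estimate and where it inserts the Adam step is not reproduced here. The step size, loop lengths, data and function names are illustrative assumptions.

```python
import random


def grad_i(w: float, x: float, y: float) -> float:
    """Gradient of the per-sample loss 0.5 * (w * x - y) ** 2 with respect to w."""
    return (w * x - y) * x


def sarah(data, w0, lr=0.1, outer_iters=40, inner_iters=20, seed=0):
    """SARAH-style recursive gradient estimation on a toy least-squares problem."""
    rng = random.Random(seed)
    w = w0
    for _ in range(outer_iters):
        # v_0: average gradient over the current batch (here the whole toy dataset).
        v = sum(grad_i(w, x, y) for x, y in data) / len(data)
        w_prev, w = w, w - lr * v
        for _ in range(inner_iters):
            x, y = rng.choice(data)
            # Recursive update: v_t = grad_i(w_t) - grad_i(w_{t-1}) + v_{t-1}
            v = grad_i(w, x, y) - grad_i(w_prev, x, y) + v
            w_prev, w = w, w - lr * v
    return w


# Noise-free data y = 3x, so the minimiser is w* = 3.
data = [(i / 10, 3.0 * i / 10) for i in range(-10, 11)]
w = sarah(data, w0=0.0)
print(round(w, 4))
```

On this problem the estimate stays on the correct side of the minimiser and shrinks geometrically, so the result approaches w* = 3 without ever recomputing a full-data gradient inside the inner loop.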
2 Related Work

2.1 SGD for Deep Q-Learning
In Q-learning, the agent is assumed to perform the sequence of actions that eventually generates the maximum total reward (return). The Q-value of a state-action pair is this return, and it is formalized recursively as:
$$Q(s, a) = r(s, a) + \gamma \max_{a' \in A(s')} Q(s', a'), \tag{1}$$
where $\gamma \in [0, 1]$ is a discount factor that controls the contribution of future rewards, $s$ is the current state, $A(s)$ contains all candidate actions in state $s$, $a$ is the selected action, $r(s, a)$ is the reward received after issuing $a$ in $s$, and $s'$ is the next state that the system moves to after issuing $a$. The equation states that the Q-value of being in state $s$ and performing action $a$ is the immediate reward $r(s, a)$ plus the discounted highest Q-value attainable from the next state $s'$. Knowing $Q(s, a)$ thus tells the agent exactly which action to perform in each state: the one with the highest Q-value. Traditionally, the Q-value function is stored as a table in which the agent looks up which action to perform in which state. However, the time and space complexity of this approach becomes prohibitive for large state and action spaces. Deep neural networks have therefore been used to approximate the Q-value function, an approach called deep Q-learning. Learning the parameters of the Q-network amounts to solving a series of optimization problems. Specifically, assume that at time step $t$ the system is in state $S_t$ and the agent issues action $A_t$. At time step $t+1$, it receives reward $R_{t+1}$ and transitions to state $S_{t+1}$, which yields the transition tuple $(S_t, A_t, R_{t+1}, S_{t+1})$. The loss function is defined as the mean squared error between the Q-value predicted by the Q-network and the target Q-value, where the target is derived from the Bellman equation:

$$y(S_t, A_t) = R_{t+1} + \gamma \max_{a \in A(S_{t+1})} Q(S_{t+1}, a; \theta), \tag{2}$$
where $Q(S_{t+1}, a; \theta)$ is the Q-network with parameters $\theta$; that is, the Q-value of the next state is predicted by the same Q-network. Stochastic gradient descent has been employed for the optimization. Given a sampled transition $(S, A, R, S')$, the stochastic gradient can be estimated as:

$$g = \nabla\big(y(S, A) - Q(S, A; \theta)\big)^2 = 2\big(y(S, A) - Q(S, A; \theta)\big)\nabla Q(S, A; \theta), \tag{3}$$
where $\nabla Q(S, A; \theta)$ is the gradient of $Q$ w.r.t. the parameters $\theta$.

2.2 Variance Reduced Deep Q-Learning
Stochastic gradient descent based on a single transition often suffers from high variance in its gradient estimates. Variance reduction techniques such as SAG [13], SAGA [3], and SVRG [5] accelerate SGD convergence, and researchers have combined these techniques from traditional machine learning with deep Q-learning. For example, Zhao et al. [18] proposed Stochastic Variance Reduction for Deep Q-learning (SVR-DQN), which applies the optimization strategy of SVRG during learning. Specifically, at each outer iteration $s = 1, 2, \cdots, S$, the algorithm samples a batch of $N$ training transitions $D = \{(S_i, A_i, R_{i+1}, S_{i+1})\}_{i=1}^N$ and calculates a full gradient on $D$ according to Eq. (3) as the anchor:
H. Jia et al.
$$\tilde{g} = \frac{1}{N}\sum_{i=1}^{N} 2\big(y_i - Q(S_i, A_i; \theta_0^s)\big)\nabla Q(S_i, A_i; \theta_0^s), \tag{4}$$

where $y_i = R_{i+1} + \gamma \max_{a' \in \mathcal{A}(S_{i+1})} Q(S_{i+1}, a'; \theta_0^s)$ and $\theta_0^s$ is the network parameter at the beginning of the $s$-th outer iteration. In the inner iteration indexed by $m$, for each sampled transition $(S, A, R, S') \in D$, the stochastic gradient w.r.t. the up-to-date parameters is calculated:

$$g_m^s = 2\big(y_m^s - Q(S, A; \theta_m^s)\big)\nabla Q(S, A; \theta_m^s), \tag{5}$$

where $y_m^s = R + \gamma \max_{a' \in \mathcal{A}(S')} Q(S', a'; \theta_m^s)$. Similarly, the stochastic gradient w.r.t. the 'old' parameters is calculated:

$$g_0^s = 2\big(y_0 - Q(S, A; \theta_0^s)\big)\nabla Q(S, A; \theta_0^s), \tag{6}$$

where $y_0 = R + \gamma \max_{a' \in \mathcal{A}(S')} Q(S', a'; \theta_0^s)$. Finally, based on Eqs. (4), (5), and (6), the variance-reduced gradient is calculated as

$$\Delta = g_m^s - g_0^s + \tilde{g}. \tag{7}$$
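To make Eqs. (3) to (7) concrete, here is a minimal NumPy sketch of the SVR-DQN gradient estimator. It assumes a hypothetical linear Q-function $Q(s, a; \theta) = \theta_a^\top s$ in place of the Q-network, with states given directly as feature vectors, and it treats each target $y$ as a constant when differentiating (the usual semi-gradient). It writes the loss gradient with the sign obtained by differentiating $(y - Q)^2$, namely $-2(y - Q)\nabla Q$, so that the update $\theta \leftarrow \theta - \eta\Delta$ lowers the squared TD error. The helper names (`td_grad`, `anchor_gradient`, `svrg_direction`) are illustrative and not taken from [18].

```python
import numpy as np

def td_grad(theta, transition, gamma):
    """Semi-gradient of (y - Q(S, A; theta))^2 for a linear Q(s, a) = theta[a] @ s.

    The target y is computed with the same theta, as in Eqs. (5) and (6),
    and is held fixed during differentiation.
    """
    s, a, r, s_next = transition
    y = r + gamma * np.max(theta @ s_next)   # Bellman target
    td_error = y - theta[a] @ s
    grad = np.zeros_like(theta)
    grad[a] = -2.0 * td_error * s            # dQ(s, a)/dtheta is s in row a, zero elsewhere
    return grad

def anchor_gradient(theta_0, batch, gamma):
    """Anchor of Eq. (4): average gradient over the N sampled transitions at theta_0."""
    return np.mean([td_grad(theta_0, t, gamma) for t in batch], axis=0)

def svrg_direction(theta_m, theta_0, anchor, transition, gamma):
    """Variance-reduced gradient of Eq. (7): g_m - g_0 + anchor."""
    g_m = td_grad(theta_m, transition, gamma)   # Eq. (5): up-to-date parameters
    g_0 = td_grad(theta_0, transition, gamma)   # Eq. (6): snapshot parameters
    return g_m - g_0 + anchor

# Usage on synthetic transitions (illustrative data only).
rng = np.random.default_rng(0)
n_actions, dim, gamma, eta = 3, 4, 0.9, 0.01
batch = [(rng.normal(size=dim), int(rng.integers(n_actions)), float(rng.normal()),
          rng.normal(size=dim)) for _ in range(32)]
theta_0 = rng.normal(size=(n_actions, dim))
anchor = anchor_gradient(theta_0, batch, gamma)
theta = theta_0.copy()
for _ in range(10):                              # inner iterations
    t = batch[rng.integers(len(batch))]
    theta = theta - eta * svrg_direction(theta, theta_0, anchor, t, gamma)
```

At $\theta_m^s = \theta_0^s$ the correction $g_m^s - g_0^s$ vanishes and the estimator reduces to the anchor $\tilde{g}$.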
In addition, at each outer iteration SVR-DQN applies an Adam process to obtain a more accurate gradient estimate, which accelerates the training of deep Q-learning and improves performance [18].
3 Our Approach: SRG-DQN

In this section, we analyze the limitations of the variance reduction mechanism in SVR-DQN and propose a novel deep Q-learning algorithm called SRG-DQN.

3.1 Problem Analysis
In SVR-DQN, accurate estimates of the anchors $\tilde{g}$ are crucial because they provide stable baselines for adjusting the stochastic gradient. In traditional machine learning, the SVRG anchors can be estimated accurately from the whole training set (i.e., as full gradients). In deep Q-learning, however, the training instances are not fixed in advance; they must be collected as the parameters change. Therefore, the estimated anchors are based only on the $N$ instances sampled at the previous and current iterations, which inevitably makes them inaccurate for the following two reasons.
First, deep Q-learning algorithms usually run online by nature. At each iteration, the algorithm samples one (or a few) new instance(s) as training data. It is therefore impossible for the algorithm to use full gradients as the anchors, because only part of the training instances (those sampled at past iterations) is accessible. Second, the instances sampled at different iterations are generated by Q-networks with different parameters. That is, at each iteration, the agent first samples an action guided by the Q-network $Q(s, a; \theta)$ (for example, ε-greedy) and then uses the sampled instance to update the parameters. Because the Q-network is continuously updated during training, the instances sampled at different iterations follow different distributions. As a result, the estimated anchors cannot reflect the directions in which the exact gradients should point.

We conduct an experiment to show the impact of inaccurate anchors. Specifically, on the Mountain Car task, we compared deep Q-learning with SVRG (SVR-DQN) against a variant in which the anchor points are estimated exactly. To do this, we first ran the existing SVR-DQN and recorded all sampled transitions. We then re-ran SVR-DQN with identical settings and identical training transitions at each iteration, except that $\tilde{g}$ in line 5 was estimated with all recorded transitions (denoted "SVR-DQN with exact anchors"). In this way, the anchors $\tilde{g}$ are estimated on the whole training data, as in traditional machine learning. Fig. 1 shows the training curves of the two models. Compared with "SVR-DQN", "SVR-DQN with exact anchors" converged faster, to better performance, and with lower variance.

Fig. 1. Training curves of "SVR-DQN", "SVR-DQN with exact anchors", and the DQN baseline. The shaded area represents one standard deviation. We focus on the influence of the anchors and omit the Adam process in these algorithms.

The results clearly indicate that inaccurate anchor estimates weaken SVRG. We conclude that directly applying SVRG to deep Q-learning violates the basic assumptions of SVRG and therefore limits its full potential.

3.2 Recursive Gradient Deep Q-Learning
To address this problem, and inspired by SARAH [10], we propose a novel algorithm called Stochastic Recursive Gradient Deep Q-Network (SRG-DQN). Unlike the SVRG-based methods, SRG-DQN uses recursive gradients, rather than full gradients, as the anchors. In this way, SRG-DQN avoids relying on inaccurately estimated anchors.
As shown in Algorithm 1, SRG-DQN contains an outer loop indexed by $s$. In each outer loop, $N$ training instances $D$ are sampled, and the initial anchor point $\Delta_0^s$ is calculated based on all $N$ sampled instances:

$$\Delta_0^s = \nabla \frac{1}{N}\sum_{i=1}^{N}\big(y_i - Q(S_i, A_i; \theta_0^s)\big)^2 = \frac{1}{N}\sum_{i=1}^{N} 2\big(y_i - Q(S_i, A_i; \theta_0^s)\big)\nabla Q(S_i, A_i; \theta_0^s), \tag{8}$$
where each target $y_i = R_{i+1} + \gamma \max_{a' \in \mathcal{A}(S_{i+1})} Q(S_{i+1}, a'; \theta_0^s)$ is the Q-value derived from the Bellman equation, for $i = 1, \cdots, N$. $\Delta_0^s$ is first used to update the parameters. At each iteration $m$ of the inner loop, the algorithm first randomly samples one training instance $(S, A, R, S') \in D$. The stochastic gradient w.r.t. the current up-to-date parameters, denoted as $\theta_m^s$, is calculated as follows:

$$g_m^s = \nabla\big(y_m^s - Q(S, A; \theta_m^s)\big)^2 = 2\big(y_m^s - Q(S, A; \theta_m^s)\big)\nabla Q(S, A; \theta_m^s), \tag{9}$$

where the target $y_m^s = R + \gamma \max_{a' \in \mathcal{A}(S')} Q(S', a'; \theta_m^s)$. Similarly, the stochastic gradient w.r.t. the previous inner-loop parameters $\theta_{m-1}^s$ is also calculated:

$$g_{m-1}^s = 2\big(y_{m-1}^s - Q(S, A; \theta_{m-1}^s)\big)\nabla Q(S, A; \theta_{m-1}^s),$$

where $y_{m-1}^s = R + \gamma \max_{a' \in \mathcal{A}(S')} Q(S', a'; \theta_{m-1}^s)$. Following the recursive gradient defined in [10], the final gradient of the current loop, $\Delta_m^s$, is defined recursively; that is, the previous loop's gradient $\Delta_{m-1}^s$ serves as the anchor point for estimating the current loop's gradient:

$$\Delta_m^s = g_m^s - g_{m-1}^s + \Delta_{m-1}^s. \tag{10}$$
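Unrolling Eq. (10) shows how the recursive anchor differs from the fixed anchor of SVR-DQN in Eq. (7). Write $g_k(\theta)$ for the stochastic gradient evaluated on the transition sampled at inner step $k$ (a shorthand introduced only for this derivation), so that $g_m^s = g_m(\theta_m^s)$ and $g_{m-1}^s = g_m(\theta_{m-1}^s)$. Then

$$\Delta_m^s = \Delta_{m-1}^s + g_m(\theta_m^s) - g_m(\theta_{m-1}^s) = \Delta_0^s + \sum_{k=1}^{m}\big[g_k(\theta_k^s) - g_k(\theta_{k-1}^s)\big],$$

whereas Eq. (7) always corrects against the outer-loop snapshot:

$$\Delta = g_m(\theta_m^s) - g_m(\theta_0^s) + \tilde{g}.$$

Each correction term in SRG-DQN compares two consecutive iterates, which stay close when $\eta$ is small, while the SVR-DQN correction compares $\theta_m^s$ with $\theta_0^s$, which can drift apart as $m$ grows.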
Note that the anchor for the first loop (i.e., $\Delta_0^s$, $m = 0$) is the full gradient over $D$ calculated in the outer loop. To further improve performance, and in light of the Adam optimizer [6], we also introduce an Adam process into SRG-DQN. Specifically, after each inner loop ends, an Adam process is executed to further update the parameters: it calculates a biased first moment, a biased second raw moment, a bias-corrected first moment, and a bias-corrected second raw moment, and finally updates the parameters. A detailed description of the Adam process in SRG-DQN can be found in the supplementary material. Following the practices in [7,10], SRG-DQN takes $\theta_{M+1}^S$ as its final output. The algorithm also admits a mini-batch version.
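The five Adam steps listed above correspond to the standard Adam update of Kingma and Ba [6]. The NumPy sketch below performs one such step; the default hyper-parameters are the ones recommended in [6]. Which gradient SRG-DQN feeds into this step (for example, the last recursive gradient $\Delta_M^s$) is an assumption here, since those details are in the supplementary material.

```python
import numpy as np

def init_adam_state(theta):
    """Zero-initialised moment estimates and step counter."""
    return {"t": 0, "m": np.zeros_like(theta), "v": np.zeros_like(theta)}

def adam_step(theta, grad, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update following the five steps listed in the text."""
    state["t"] += 1
    t = state["t"]
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad        # biased first moment
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad ** 2   # biased second raw moment
    m_hat = state["m"] / (1 - beta1 ** t)                       # bias-corrected first moment
    v_hat = state["v"] / (1 - beta2 ** t)                       # bias-corrected second raw moment
    return theta - lr * m_hat / (np.sqrt(v_hat) + eps)          # parameter update

# Usage: on the first step each coordinate moves by about lr against the gradient sign.
theta = np.array([1.0, -2.0])
state = init_adam_state(theta)
theta = adam_step(theta, np.array([0.5, -0.5]), state)
```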
3.3 Theoretical Analysis
We analyze the convergence of SRG-DQN as follows; the proof of Theorem 1 can be found in the supplementary material. In the inner loop, the optimization in DQN can be formulated as the empirical risk minimization problem:

$$\min_{\theta \in \Theta} F(\theta) := \frac{1}{M}\sum_{i=1}^{M} f_i(\theta), \tag{11}$$
where $f_i(\theta) = [y_i - Q(S_i, A_i; \theta)]^2$. The optimization problem in DQN is nonconvex due to the composite structure of neural networks. Given an error parameter $\varepsilon > 0$, the goal is to find an $\varepsilon$-optimal point $\theta \in \Theta$ such that

$$\mathbb{E}\big[\|\nabla F(\theta)\|^2\big] \le \varepsilon. \tag{12}$$
First, similar to [1], we define the Incremental First-order Oracle (IFO):

Definition 1. An IFO takes a point $\theta \in \Theta$ and an index $i \in \{1, 2, \ldots, M\}$ as inputs and returns the pair $(f_i(\theta), \nabla f_i(\theta))$.

The convergence rate of an algorithm can then be measured by its oracle complexity, defined as the smallest number of IFO queries needed to reach an $\varepsilon$-optimal point. Further assuming that each function $f_i(\theta)$ is $\beta_i$-smooth and bounded for $i \in [M]$ (that is, $\nabla f_i(\theta)$ is $\beta_i$-Lipschitz continuous and $|f_i(\theta)| \le B_i$ for all $i \in [M]$, $\theta \in \Theta$), we have the following Theorem 1 for SRG-DQN:

Theorem 1. Let $\mu = \sum_{i=1}^{M} \beta_i^2 / M$ and $B_{\max} = \sup_{i \in [M]} \{B_i\}$. For SRG-DQN within a single outer loop (in outer iteration $s \in [S]$), setting $\eta \le 2 / [\sqrt{\mu}(\sqrt{4M + 1} + 1)]$, attaining an $\varepsilon$-optimal point requires $\Omega(\sqrt{M}/\varepsilon)$ queries to the IFO.

Algorithm 1. Stochastic Recursive Gradient for Deep Q-Learning (SRG-DQN)
Require: deep Q-function $Q$, number of epochs $S$, epoch size $M$, discount factor $\gamma$, step size $\eta$
Ensure: model parameters $\theta$
1: Initialize $\theta_{M+1}^0 \leftarrow \theta_0$
2: for $s = 1$ to $S$ do
3:   $\theta_0^s \leftarrow \theta_{M+1}^{s-1}$
4:   sample $N$ transitions $D = \{(S_i, A_i, R_{i+1}, S_{i+1})\}_{i=1}^N$ according to $Q(s, a; \theta_0^s)$
5:   $\Delta_0^s \leftarrow \frac{1}{N}\sum_{i=1}^{N} 2(y_i - Q(S_i, A_i; \theta_0^s))\nabla Q(S_i, A_i; \theta_0^s)$, where $y_i = R_{i+1} + \gamma \max_{a \in \mathcal{A}(S_{i+1})} Q(S_{i+1}, a; \theta_0^s)$
6:   $\theta_1^s \leftarrow \theta_0^s - \eta\Delta_0^s$ {update with full gradient}
7:   for $m = 1$ to $M$ do
8:     randomly select a transition $(S, A, R, S') \in D$
9:     $y_m \leftarrow R + \gamma \max_{a' \in \mathcal{A}(S')} Q(S', a'; \theta_m^s)$
10:    $g_m^s \leftarrow 2(y_m - Q(S, A; \theta_m^s))\nabla Q(S, A; \theta_m^s)$ {gradient w.r.t. up-to-date parameters}
11:    $y_{m-1} \leftarrow R + \gamma \max_{a' \in \mathcal{A}(S')} Q(S', a'; \theta_{m-1}^s)$
12:    $g_{m-1}^s \leftarrow 2(y_{m-1} - Q(S, A; \theta_{m-1}^s))\nabla Q(S, A; \theta_{m-1}^s)$ {gradient w.r.t. previous-iteration parameters}
13:    $\Delta_m^s \leftarrow g_m^s - g_{m-1}^s + \Delta_{m-1}^s$ {recursive gradient using the previous one as the anchor}
14:    $\theta_{m+1}^s \leftarrow \theta_m^s - \eta\Delta_m^s$ {update parameters}
15:  end for
16:  {Adam process}
17: end for
return $\theta_{M+1}^S$
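Algorithm 1 without the Adam process (the "SRG-DQN without Adam Process" variant evaluated in Sect. 4) can be sketched in NumPy as follows. This is a minimal sketch under stated assumptions: a hypothetical linear Q-function $Q(s, a; \theta) = \theta_a^\top s$ stands in for the Q-network, the environment interaction of line 4 is replaced by a hypothetical `sample_transitions` callback, targets are held fixed when differentiating, and the loss-gradient sign $-2(y - Q)\nabla Q$ is used so that $\theta \leftarrow \theta - \eta\Delta$ lowers the squared TD error. The helper `theorem1_step_size` evaluates the step-size bound of Theorem 1 for given smoothness constants $\beta_i$.

```python
import numpy as np

def td_grad(theta, transition, gamma):
    """Semi-gradient of (y - Q(S, A; theta))^2 for a linear Q(s, a) = theta[a] @ s."""
    s, a, r, s_next = transition
    y = r + gamma * np.max(theta @ s_next)    # target uses the same parameters (lines 9, 11)
    grad = np.zeros_like(theta)
    grad[a] = -2.0 * (y - theta[a] @ s) * s
    return grad

def theorem1_step_size(betas):
    """Bound of Theorem 1: eta <= 2 / [sqrt(mu) (sqrt(4M + 1) + 1)], mu = mean(beta_i^2)."""
    betas = np.asarray(betas, dtype=float)
    mu = np.mean(betas ** 2)
    return 2.0 / (np.sqrt(mu) * (np.sqrt(4 * betas.size + 1) + 1))

def srg_dqn(theta, sample_transitions, S, M, gamma, eta, rng):
    """Algorithm 1 without the Adam process (line 16 is skipped)."""
    for _ in range(S):                                        # line 2
        theta_prev = theta.copy()                             # line 3: theta_0^s
        D = sample_transitions(theta_prev)                    # line 4
        delta = np.mean([td_grad(theta_prev, t, gamma) for t in D], axis=0)  # line 5
        theta = theta_prev - eta * delta                      # line 6: theta_1^s
        for _ in range(M):                                    # line 7
            t = D[rng.integers(len(D))]                       # line 8
            g_new = td_grad(theta, t, gamma)                  # lines 9-10: at theta_m^s
            g_old = td_grad(theta_prev, t, gamma)             # lines 11-12: at theta_{m-1}^s
            delta = g_new - g_old + delta                     # line 13: recursive gradient
            theta_prev, theta = theta, theta - eta * delta    # line 14
    return theta                                              # theta_{M+1}^S

# Usage with a fixed synthetic batch standing in for environment interaction.
rng = np.random.default_rng(0)
n_actions, dim = 2, 3
fixed_batch = [(rng.normal(size=dim), int(rng.integers(n_actions)), float(rng.normal()),
                rng.normal(size=dim)) for _ in range(64)]
theta = srg_dqn(np.zeros((n_actions, dim)), lambda th: fixed_batch,
                S=5, M=20, gamma=0.9, eta=0.01, rng=rng)
```

The pair `theta_prev, theta` always holds $(\theta_{m-1}^s, \theta_m^s)$, so only two parameter copies are needed per inner step.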
Remark 1. While the oracle complexity of SVRG for the optimization problem in DQN is $O(M + M^{2/3}/\varepsilon)$ [7], SRG-DQN achieves a lower oracle complexity w.r.t. the epoch size $M$, which indicates that SRG-DQN converges faster than SVRG for DQN.
4 Experiments

4.1 Experimental Settings
Following [16], we conduct experiments on benchmark RL environments, including the Cartpole, Mountain Car, and Pendulum problems. Following the setup in [9], an ε-greedy strategy is used to balance exploitation and exploration, where ε decreases linearly from an initial value of 0.1 to 0.001. The transition instances generated during the interactions between the agent and the environment are stored in an experience replay memory, which uses a first-in-first-out mechanism. When performing gradient descent, the algorithm uniformly samples 64 transition instances from the experience replay as the training batch. The learning frequency is set to 16, which means that a batch is sampled once every 16 steps. In all experiments, the DQN algorithm in [9] is adopted as the main architecture. Our algorithms are called "SRG-DQN" and "SRG-DQN without Adam Process". DQN optimized by SGD (called "DQN with SGD") and DQN optimized by SVRG (called "SVR-DQN") are chosen as the baselines.
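The replay and exploration settings above can be sketched in plain Python as follows. The batch size (64), learning frequency (16), and ε range (0.1 to 0.001) come from the text; the replay capacity and the ε decay horizon `decay_steps` are not stated, so the values used here are hypothetical.

```python
import random
from collections import deque

BATCH_SIZE = 64      # transitions per training batch (from the text)
LEARN_EVERY = 16     # learning frequency: one batch every 16 steps (from the text)

class ReplayMemory:
    """First-in-first-out experience replay with uniform sampling."""

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)   # oldest transition is dropped when full

    def push(self, transition):
        self.buffer.append(transition)

    def sample(self, batch_size):
        return random.sample(list(self.buffer), batch_size)  # uniform, without replacement

    def __len__(self):
        return len(self.buffer)

def linear_epsilon(step, decay_steps, eps_start=0.1, eps_end=0.001):
    """Epsilon for epsilon-greedy, decayed linearly and then held at eps_end."""
    frac = min(step / decay_steps, 1.0)
    return eps_start + frac * (eps_end - eps_start)

# Usage: count how many training batches a 200-step run would draw.
memory = ReplayMemory(capacity=10_000)   # capacity is an assumption
updates = 0
for step in range(1, 201):
    memory.push(("s", "a", "r", "s_next", step))   # placeholder transition
    if len(memory) >= BATCH_SIZE and step % LEARN_EVERY == 0:
        batch = memory.sample(BATCH_SIZE)
        updates += 1
```

A `deque` with `maxlen` gives the first-in-first-out eviction directly, without manual index bookkeeping.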
Fig. 2. Performance curves of DQN with SGD (blue), SVR-DQN (green), SRG-DQN without Adam process (purple), and SRG-DQN (orange) on the three tasks. The shaded area represents the standard deviation; 'Steps' denotes the outer iteration s in Algorithm 1; 'Episode' denotes a complete trajectory; and 'Avg Reward' denotes the average reward per step or, when plotted against episodes, the average return per trajectory. (Color figure online)
4.2 Experimental Results
We conduct average-reward experiments on the three tasks, using the average reward to measure the performance of the agent. For a fair comparison, the DQN structures in all algorithms use the same parameter settings.
The left part of Fig. 2 compares the performance of the four algorithms on the Mountain Car task. To encourage the car to explore, we replace the original discrete reward with a continuous function correlated with the car's position. Without limiting the episode length, all four algorithms run for 100,000 steps, so the faster the car reaches the goal, the higher the average reward per action. We omit the standard deviation of DQN with SGD from this figure, since that baseline performs clearly poorly.

The middle part of Fig. 2 compares the four algorithms on the Pendulum task. Because DQN chooses among discrete actions, we divide the continuous action range into 12 equally spaced parts. None of the four algorithms limits the number of episodes; each runs for 20,000 steps and is repeated for 50 rounds. The results show that our algorithm converges quickly and achieves the best average reward with reduced variance.

The right part of Fig. 2 compares the four algorithms on the Cartpole task, in which the agent must keep the pole upright and the episode terminates once the pole falls. We therefore report the average reward per episode instead of the average reward per step. To accelerate convergence, we replace the discrete 0/1 reward with a continuous function of the observations that gives a higher score the more upright the pole is. All four algorithms run for 800 episodes and are repeated for 10 rounds. The results show that our algorithm achieves an excellent average reward while significantly reducing the variance.
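The Pendulum discretization can be sketched as a lookup table from DQN action indices to torques. This assumes the classic Gym Pendulum task, whose torque range is [-2, 2], and reads "12 parts with equal distances" as 12 equally spaced torques that include both endpoints; both points are assumptions, and `to_env_action` is a hypothetical helper name.

```python
import numpy as np

MAX_TORQUE = 2.0   # torque bound of the classic Gym Pendulum task (assumed environment)
N_ACTIONS = 12     # number of discrete actions stated in the text

# 12 equally spaced torques covering [-MAX_TORQUE, MAX_TORQUE], endpoints included.
TORQUES = np.linspace(-MAX_TORQUE, MAX_TORQUE, N_ACTIONS)

def to_env_action(action_index):
    """Map a discrete DQN action index to the continuous torque the environment expects."""
    return np.array([TORQUES[action_index]], dtype=np.float32)
```

The Q-network then outputs 12 values per state, and the greedy index is converted with `to_env_action` before stepping the environment.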
Fig. 3. Results of the experimental analysis on Mountain Car. Left: average steps w.r.t. episodes; Middle: ℓ2 distances between the exact anchors and the anchors used in SRG-DQN and SVR-DQN; Right: SRG-DQN with different optimizer processes.
4.3 Experimental Analysis
We now experimentally analyze why SRG-DQN is effective. We first conduct an average-episode-length experiment, in which the four algorithms are run for 150 episodes and repeated for 100 rounds under the same model parameter settings. The convergence rate is measured by the average episode length, and stability is evaluated by the standard deviation. The results are shown in the left part of Fig. 3, where the bold line represents the average episode length over multiple runs and the shading represents the standard deviation. SRG-DQN significantly improves the convergence rate and stability compared with the traditional DQN with SGD. Compared with SVR-DQN, SRG-DQN further shortens the average episode length and reduces the standard deviation. The orange line represents SRG-DQN; the results suggest that the Adam process and the variance reduction algorithm play complementary roles, and that their combination further accelerates convergence and improves the stability of the agent.

In addition, we computed the ℓ2 distances between the exact anchors and the anchors used in SVR-DQN and SRG-DQN. The middle part of Fig. 3 shows that the recursive anchors in SRG-DQN are much closer to the exact anchors, which is another reason why SRG-DQN achieves better performance.

We also compared combinations of variance reduction with different optimizer processes. As shown in the right part of Fig. 3, adding an optimizer process further improves the performance of the algorithm, and the Adam process, which combines estimates of the first and second moments of the gradient, proved a better choice than Adagrad.

To examine whether our method really reduces the variance of the gradient estimates, we compared SRG-DQN with SVR-DQN in terms of gradient variance. We calculated the standard deviation of the gradient separately for each parameter dimension and then summed these values. Table 1 shows that, compared with SVR-DQN, SRG-DQN significantly reduces the standard deviation and has the lower value at most of the recorded steps. Thus, SRG-DQN controls the variance of the gradients more effectively than SVR-DQN, which demonstrates the effectiveness of our stochastic recursive gradient for variance reduction in DQN.

Table 1. Comparison between SVR-DQN and SRG-DQN in terms of the standard deviation of the gradients on the Mountain Car task, where the standard deviation is computed by summing the standard deviations of each element of the first-layer network gradient vector.
We recorded the standard deviation once every 1,000 steps for the first 10,000 steps (first three rows) and the last 10,000 steps (last three rows).

# Steps (first)   1,000   2,000   3,000   4,000   5,000   6,000   7,000   8,000   9,000   10,000
SVR-DQN           0.229   0.790   0.430   0.474   0.626   1.176   0.500   0.388   0.638   0.739
SRG-DQN           0.122   0.395   0.389   0.520   0.625   0.632   0.966   0.966   0.394   0.562
# Steps (last)    1,000   2,000   3,000   4,000   5,000   6,000   7,000   8,000   9,000   10,000
SVR-DQN           0.306   0.286   0.293   0.212   0.188   0.318   0.443   0.271   0.193   0.205
SRG-DQN           0.225   0.334   0.254   0.353   0.154   0.247   0.182   0.229   0.188   0.203
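The statistic reported in Table 1 can be sketched as follows: stack $K$ flattened first-layer gradient vectors collected over a recording window, take the standard deviation of each element across the window, and sum over elements. The window length and the use of the population standard deviation (`ddof=0`) are assumptions, since the text does not specify them; `summed_gradient_std` is a hypothetical helper name.

```python
import numpy as np

def summed_gradient_std(grad_samples):
    """Sum over parameters of the per-element standard deviation of the gradient.

    grad_samples: array-like of shape (K, D), K flattened gradient vectors
    (e.g. first-layer gradients) collected over one recording window.
    """
    grad_samples = np.asarray(grad_samples, dtype=float)
    return float(np.std(grad_samples, axis=0).sum())   # population std per element

# Usage: the first element varies (std 1.0), the second is constant (std 0.0).
value = summed_gradient_std([[1.0, 0.0], [3.0, 0.0]])
```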
5 Conclusion
This paper proposes a novel deep Q-learning algorithm that uses stochastic recursive gradients to reduce the variance of gradient estimation. The proposed algorithm introduces a recursive framework for updating the stochastic gradient and computing the anchor points, and it incorporates an Adam process to obtain a more accurate gradient direction. Theoretical analysis and empirical comparisons showed that the proposed algorithm outperformed the state-of-the-art baselines in terms of reward, convergence rate, and stability. The proposed stochastic recursive gradient provides an effective scheme for variance reduction in reinforcement learning.

Acknowledgements. This work was funded by the National Key R&D Program of China (2019YFE0198200), National Natural Science Foundation of China (61872338, 62102420, 61832017), Beijing Outstanding Young Scientist Program NO. BJJWZYJH012019100020098, Intelligent Social Governance Interdisciplinary Platform, Major Innovation & Planning Interdisciplinary Platform for the "Double-First Class" Initiative, Renmin University of China, and Public Policy and Decision-making Research Lab of Renmin University of China.
References
1. Agarwal, A., Bottou, L.: A lower bound for the optimization of finite sums. In: Proceedings of the 32nd International Conference on Machine Learning, pp. 78–86 (2015)
2. Anschel, O., Baram, N., Shimkin, N.: Averaged-DQN: variance reduction and stabilization for deep reinforcement learning. In: Proceedings of the 34th International Conference on Machine Learning, pp. 176–185 (2017)
3. Defazio, A., Bach, F., Lacoste-Julien, S.: SAGA: a fast incremental gradient method with support for non-strongly convex composite objectives. In: Advances in Neural Information Processing Systems, vol. 27, pp. 1646–1654 (2014)
4. Du, S.S., Chen, J., Li, L., Xiao, L., Zhou, D.: Stochastic variance reduction methods for policy evaluation. In: Proceedings of the 34th International Conference on Machine Learning, pp. 1049–1058 (2017)
5. Johnson, R., Zhang, T.: Accelerating stochastic gradient descent using predictive variance reduction. In: Advances in Neural Information Processing Systems, vol. 26, pp. 315–323 (2013)
6. Kingma, D., Ba, J.: Adam: a method for stochastic optimization. arXiv preprint arXiv:1412.6980 (2014)
7. Li, B., Ma, M., Giannakis, G.B.: On the convergence of SARAH and beyond. arXiv preprint arXiv:1906.02351 (2019)
8. Mnih, V., et al.: Playing Atari with deep reinforcement learning. arXiv preprint arXiv:1312.5602 (2013)
9. Mnih, V., et al.: Human-level control through deep reinforcement learning. Nature 518(7540), 529–533 (2015)
10. Nguyen, L.M., Liu, J., Scheinberg, K., Takáč, M.: SARAH: a novel method for machine learning problems using stochastic recursive gradient. In: Proceedings of the 34th International Conference on Machine Learning, pp. 2613–2621 (2017)
11. Papini, M., Binaghi, D., Canonaco, G., Pirotta, M., Restelli, M.: Stochastic variance-reduced policy gradient. In: Proceedings of the 35th International Conference on Machine Learning, pp. 4023–4032 (2018)
12. Romoff, J., Henderson, P., Piché, A., Francois-Lavet, V., Pineau, J.: Reward estimation for variance reduction in deep reinforcement learning. arXiv preprint arXiv:1805.03359 (2018)
13. Roux, N.L., Schmidt, M., Bach, F.R.: A stochastic gradient method with an exponential convergence rate for finite training sets. In: Advances in Neural Information Processing Systems, vol. 25, pp. 2663–2671 (2012)
14. Sabry, M., Khalifa, A.M.A.: On the reduction of variance and overestimation of deep Q-learning. arXiv preprint arXiv:1910.05983 (2019)
15. Xu, P., Gao, F., Gu, Q.: An improved convergence analysis of stochastic variance reduced policy gradient. In: Proceedings of the 35th Conference on Uncertainty in Artificial Intelligence, p. 191 (2019)
16. Xu, P., Gao, F., Gu, Q.: Sample efficient policy gradient methods with recursive variance reduction. In: Proceedings of the 8th International Conference on Learning Representations (2020)
17. Xu, T., Liu, Q., Peng, J.: Stochastic variance reduction for policy gradient estimation. arXiv preprint arXiv:1710.06034 (2017)
18. Zhao, W.Y., Peng, J.: Stochastic variance reduction for deep Q-learning. In: Proceedings of the 18th International Conference on Autonomous Agents and MultiAgent Systems, pp. 2318–2320 (2019)
Optimizing Knowledge Distillation via Shallow Texture Knowledge Transfer

Xinlei Huang¹, Jialiang Tang¹, Haifeng Qing¹, Honglin Zhu¹, Ning Jiang¹ (corresponding author), Wenqing Wu², and Peng Zhang²

¹ School of Computer Science and Technology, Southwest University of Science and Technology, Mianyang 621010, China
² School of Mathematics and Physics, Southwest University of Science and Technology, Mianyang 621000, Sichuan, China
Abstract. Knowledge distillation (KD) is a widely used model compression technique for training a small but high-performing network, called the student network. KD encourages the student network to mimic the knowledge in the middle or deep layers of a large network, called the teacher network. In general, existing knowledge distillation methods neglect the shallow features of neural networks, which contain informative texture knowledge. In this paper, we propose Shallow Texture Knowledge Distillation (SeKD) to distill these informative shallow features. Moreover, we investigate traditional machine learning methods and adopt the Gradient Local Binary Pattern (GLBP) for shallow feature extraction. However, we found that using GLBP to process shallow features introduces an additional computational burden. To reduce computation, we design a texture attention module that optimizes shallow feature extraction for distillation. We conducted extensive experiments to evaluate the effectiveness of our proposed method. When trained on the CIFAR-10 and CIFAR-100 datasets, the student network WideResNet16-2 trained by SeKD achieves 94.35% and 75.90% accuracy, respectively.

Keywords: Knowledge distillation · Shallow features · Attention mechanism · Shallow texture knowledge distillation

1 Introduction
With the continuous development of deep learning, deep neural networks (DNNs) have achieved great success in computer vision tasks such as image classification [11,19], object detection [17], and semantic segmentation [2]. In general, DNNs with excellent performance tend to be cumbersome, requiring massive numbers of parameters and heavy computation. However, most small terminal devices in the real world are resource-limited and cannot deploy these cumbersome DNNs, which dramatically hinders the practical application of DNNs. To solve this problem, various model compression techniques [4,13,16] have been proposed. Among these techniques, knowledge distillation (KD) is widely used due to its effectiveness. KD utilizes a large teacher network to supervise the training of a small student network, and it is also the focus of this paper.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 647–658, 2023.
https://doi.org/10.1007/978-981-99-1639-9_54
X. Huang et al.

Knowledge transfer is the core of knowledge distillation. In the original knowledge distillation framework, Hinton et al. [4] transfer soft targets from the teacher to the student. Subsequently, many researchers have explored transferring knowledge from different locations in neural networks. Zagoruyko et al. [28] extracted expressive attention features from the intermediate layers of the network as transferred knowledge to enhance the student's performance. Zhang et al. [29] transferred knowledge in the deeper layers of the network to increase the ability of student networks to generate high-frequency information. However, these methods ignore the importance of shallow features, which limits the performance of the student networks.

Shallow features contain rich texture information, which is exploited in various image processing fields. In traditional machine learning for image classification in particular, many methods handle shallow features well. Ojala et al. [14] proposed the LBP algorithm to extract rich texture information from images. Further, Jiang et al. [6,7] and Tang et al. [20] proposed GLBP and used it as the first layer of a neural network to optimize the processing of shallow features. Unlike [6,20], which utilize GLBP for image classification, we use GLBP only to describe the texture features extracted in the shallow layers of the teacher network for knowledge transfer.

In this paper, we propose Shallow Texture Knowledge Distillation (SeKD) to perform knowledge distillation in the shallow layers of the network. We use GLBP to extract shallow texture features at three locations in the shallow layers of the network, as shown in Fig. 2, and use them as knowledge to transfer from teacher to student. However, GLBP often introduces a heavy computational burden because it produces a large number of feature maps. In SeKD, we therefore design a texture attention module to reduce the redundant computation caused by GLBP. The contributions of this paper are summarized as follows:
1) We propose SeKD to exploit the informative but widely ignored shallow features of neural networks to transfer knowledge from the teacher network to the student network.
2) We investigate a wide range of traditional machine learning methods and introduce the GLBP algorithm into a knowledge distillation framework for efficient shallow feature extraction.
3) We design a texture attention module to reduce the unacceptable computation caused by GLBP.
2 Related Work

In this section, we first outline the research status of knowledge distillation and then briefly describe the development and computational characteristics of the GLBP algorithm, as well as the attention mechanism used to optimize the algorithm.

2.1 Knowledge Distillation
To make high-performing neural networks deployable on small terminals with limited computing power and storage, many model compression methods have been proposed, such as knowledge distillation [4,18,28], quantization [16], and pruning [13]. Among these techniques, knowledge distillation is an effective method and is the focus of this paper, so we describe knowledge distillation methods below. Existing knowledge distillation research can be roughly divided into three types according to where features are extracted: the intermediate layers, the deep layers, and other locations (e.g., dataset distillation [24]). The original knowledge distillation framework [4,18] encourages the student network to mimic the soft labels output by the teacher network in order to approximate the teacher's inferential ability. Subsequent research [3,12,26,28] found that transferring knowledge in the middle layers of the network can enhance the student's processing of features and improve its performance. In recent years, many methods have sought to extract features from various locations of models to further optimize knowledge transfer. Nonetheless, knowledge in the shallow layers of the network has consistently been ignored. Unlike previous studies, our method uses shallow features as transferred knowledge to further improve the student network's ability.

2.2 Gradient Local Binary Pattern
In image processing, the processing of image texture information has long attracted much attention. Among many image processing algorithms, the Local Binary Pattern (LBP) proposed by Ojala et al. [14] is an effective local texture feature extraction method. Building on LBP, Jiang et al. proposed an improved texture feature extraction algorithm named Gradient Local Binary Pattern (GLBP). After further improvements [8,9], recent works [6,20] introduced GLBP into neural networks as the first network layer. They use the 56 "uniform" patterns of LBP to construct a GLBP layer with gradient information and use this layer as the first layer of the neural network to improve the processing of shallow features. Figure 1 shows the construction process of one GLBP kernel. The constructed GLBP kernel is convolved with the feature map to obtain texture features.

2.3 Attention Mechanism
In recent years, a growing number of studies [2,22,23,25] have shown that attention mechanisms can improve the performance of DNNs. Woo et al. [25] introduce CBAM, a lightweight and general module that infers attention maps in both the spatial and channel dimensions. By multiplying the attention maps with the feature map for adaptive feature refinement, the module achieves considerable performance gains with low computational overhead. Fu et al. [2] propose DANet to aggregate local features with their global dependencies; they model the semantic dependencies of features through two attention modules to further improve the semantic segmentation performance of the network. Inspired by these works, we introduce a texture attention module to refine the texture information of the feature map in the channel dimension. Since the texture attention module compresses the number of channels in the feature map, the redundant computation introduced by the GLBP operator is greatly reduced. Ablation experiments verify the effectiveness of the module.

Fig. 1. Diagram of how a GLBP kernel is constructed.

3 Method
In this section, we describe the proposed SeKD in detail. Subsection 3.1 briefly reviews previous research and provides the notation needed for the subsequent illustration. Subsection 3.2 proposes shallow texture knowledge distillation. Subsection 3.3 introduces our texture attention module in the channel dimension. Finally, Subsect. 3.4 presents the details of the training procedure for the student model.

3.1 Preliminary Research and Notation
Gradient Local Binary Pattern. The GLBP algorithm [6,20] was proposed to describe local texture features. As shown in Fig. 1, for a 3 × 3 local area of an input feature, the value of the center point is taken as the reference: each surrounding position whose value is greater than the center value is encoded as 1, and otherwise as 0. The resulting binary sequence is called a GLBP pattern. Based on this binary pattern, a GLBP kernel is constructed by assigning every position in the "1" area the reciprocal of the length of the "1" area, and every position in the "0" area the negative reciprocal of the length of the "0" area. Following the "uniform" patterns of LBP [14], GLBP uses only 56 binary patterns for calculation.
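The kernel construction can be sketched in NumPy. For an 8-neighbour ring, the uniform patterns with exactly two 0/1 transitions number 8 × 7 = 56 (eight starting positions times run lengths 1 to 7), which matches the count above; in each such pattern the 1s form one contiguous arc, so the length of the "1" area is simply the number of 1s. The clockwise neighbour order starting at the top-left and the zero centre weight are assumptions not fixed by the text, and `glbp_kernel` and `uniform_patterns` are hypothetical helper names.

```python
import numpy as np

# Clockwise order of the 8 neighbours of a 3x3 window, starting at the top-left
# corner (an assumed convention; the text does not fix an order).
NEIGHBOURS = [(0, 0), (0, 1), (0, 2), (1, 2), (2, 2), (2, 1), (2, 0), (1, 0)]

def uniform_patterns():
    """All 8-bit circular patterns with exactly two 0/1 transitions (56 of them)."""
    patterns = []
    for code in range(256):
        bits = [(code >> k) & 1 for k in range(8)]
        transitions = sum(bits[k] != bits[(k + 1) % 8] for k in range(8))
        if transitions == 2:
            patterns.append(bits)
    return patterns

def glbp_kernel(bits):
    """3x3 kernel: 1/len('1' area) on 1 positions, -1/len('0' area) on 0 positions.

    The centre weight is set to 0 here, which is an assumption.
    """
    n_ones = sum(bits)
    n_zeros = 8 - n_ones
    kernel = np.zeros((3, 3))
    for bit, (r, c) in zip(bits, NEIGHBOURS):
        kernel[r, c] = 1.0 / n_ones if bit else -1.0 / n_zeros
    return kernel

KERNELS = np.stack([glbp_kernel(b) for b in uniform_patterns()])  # shape (56, 3, 3)
```

Because each kernel puts total weight +1 on the "1" area and -1 on the "0" area, every kernel sums to zero and responds to a contrast between the two arcs rather than to the absolute intensity level.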
Optimizing Knowledge Distillation via Shallow Texture Knowledge Transfer
We denote the input feature map as $F \in \mathbb{R}^{C \times H \times W}$, where $C$, $H$, and $W$ represent the number of channels, the height, and the width, respectively. Correspondingly, the extracted GLBP feature is expressed as $G \in \mathbb{R}^{C_g \times H \times W}$. GLBP convolves each channel map of the input feature map with the GLBP kernels constructed from the 56 patterns, so $C_g = 56C$. The extraction of texture information with GLBP is defined as:
$$G = \mathrm{GLBP}(F). \tag{1}$$
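The kernel construction and channel expansion described above can be sketched in plain Python. This is a minimal sketch, not the authors' implementation: the clockwise neighbour ordering, the zero weight at the kernel centre, zero padding, computing the "convolution" as cross-correlation (as CNN layers do), and the helper names `glbp_kernels`, `conv2d_same`, and `glbp` are all our assumptions. It enumerates the 56 uniform patterns with exactly two 0/1 transitions (8 possible start positions of the run of 1s times 7 possible run lengths) and filters every channel with every kernel, so a C-channel input yields 56C output channels.

```python
# Eight neighbours of a 3x3 window, clockwise from the top-left corner (assumed order).
NEIGHBOURS = [(-1, -1), (-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1)]


def glbp_kernels():
    """Build the 56 GLBP kernels, one per uniform pattern with exactly two transitions."""
    kernels = []
    for start in range(8):          # where the circular run of 1s begins
        for length in range(1, 8):  # run length 1..7 (all-0 / all-1 patterns excluded)
            bits = [1 if (k - start) % 8 < length else 0 for k in range(8)]
            kernel = [[0.0] * 3 for _ in range(3)]  # centre weight assumed to be 0
            for bit, (di, dj) in zip(bits, NEIGHBOURS):
                # "1" run gets 1/len(1-run); "0" run gets -1/len(0-run).
                kernel[di + 1][dj + 1] = 1.0 / length if bit else -1.0 / (8 - length)
            kernels.append(kernel)
    return kernels


def conv2d_same(channel, kernel):
    """3x3 cross-correlation with zero padding; the output keeps the H x W size."""
    h, w = len(channel), len(channel[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for di in range(-1, 2):
                for dj in range(-1, 2):
                    y, x = i + di, j + dj
                    if 0 <= y < h and 0 <= x < w:
                        acc += kernel[di + 1][dj + 1] * channel[y][x]
            out[i][j] = acc
    return out


def glbp(feature_map):
    """Eq. (1): map a C x H x W feature map (nested lists) to 56*C channel maps."""
    kernels = glbp_kernels()
    return [conv2d_same(ch, k) for ch in feature_map for k in kernels]
```

Because each kernel's weights sum to zero (+1 spread over the "1" run, -1 over the "0" run), a locally constant region produces a zero response, so only textured regions contribute to the GLBP features.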
Knowledge Distillation for Image Classification. Knowledge distillation uses a large pre-trained teacher network to assist the training of a small student network, enabling the student network to achieve performance close to that of the teacher. Given a set of training examples $X = \{x_1, x_2, \ldots, x_n\}$ and their labels $Y = \{y_1, y_2, \ldots, y_n\}$, where $n$ denotes the total number of examples, the teacher and student networks are denoted by $\Psi_t$ and $\Psi_s$, respectively. For image classification tasks, the loss function of the student network can be formulated as:
$$\mathcal{L}_{student} = \alpha \mathcal{L}_{ce} + (1 - \alpha) \mathcal{L}_{KD}, \tag{2}$$
where $\mathcal{L}_{ce}$ is the cross-entropy loss measuring the distance between the model's predictions and the data labels, and $\mathcal{L}_{KD}$ is the distillation loss that makes the soft labels output by the student network similar to those of the teacher network. $\alpha \in (0, 1)$ is the hyper-parameter that balances the two loss components. The knowledge distillation loss $\mathcal{L}_{KD}$ is formulated as:
$$\mathcal{L}_{KD} = \tau^2 \frac{1}{n} \sum_{i=1}^{n} \mathrm{KL}\left(\mathrm{softmax}\left(\frac{\Psi_t(x_i)}{\tau}\right), \mathrm{softmax}\left(\frac{\Psi_s(x_i)}{\tau}\right)\right), \tag{3}$$
where KL denotes the Kullback-Leibler divergence, which measures the distance between two distributions, and $\tau$ is the temperature hyper-parameter that controls the softening of the output labels.

3.2 Shallow Texture Knowledge Distillation
Existing knowledge distillation methods generally ignore the differences between the shallow texture features of the teacher and the student. In our research, we use GLBP to extract shallow texture features for knowledge transfer. In SeKD, we select three locations to extract shallow features, located after the network's first, second, and third convolution layers, as shown in Fig. 2. We use the GLBP operator to describe the texture features extracted by the shallow layers of the neural network. The difference between the shallow
X. Huang et al.
features of teacher and student is represented by the Euclidean distance. The loss function of shallow texture knowledge distillation is defined as:
$$\mathcal{L}_{SeKD} = \frac{1}{n} \sum_{i=1}^{n} \left\lVert \mathrm{GLBP}(\psi_t(x_i)) - \mathrm{GLBP}(\psi_s(x_i)) \right\rVert_2, \tag{4}$$
where $\psi(\cdot)$ represents the shallow layers of the neural network, and the subscripts $t$ and $s$ denote the variables produced by the teacher and student networks, respectively.
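Eqs. (2)-(4) can be sketched in plain Python as follows. This is a minimal sketch under our own assumptions, not the authors' code: logits are plain lists, the helper names are hypothetical, `texture_fn` is a placeholder for the GLBP operator (any callable returning nested lists), and we read the norm in Eq. (4) as the unsquared Euclidean norm, following the text's "Euclidean distance".

```python
import math


def softmax(logits, tau=1.0):
    """Temperature-scaled softmax; subtracting the max keeps exp() from overflowing."""
    scaled = [z / tau for z in logits]
    top = max(scaled)
    exps = [math.exp(z - top) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for discrete distributions; terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kd_loss(teacher_logits, student_logits, tau):
    """Eq. (3): tau^2 times the batch mean of KL(softened teacher || softened student)."""
    n = len(teacher_logits)
    return tau ** 2 * sum(kl(softmax(t, tau), softmax(s, tau))
                          for t, s in zip(teacher_logits, student_logits)) / n


def student_loss(teacher_logits, student_logits, labels, alpha, tau):
    """Eq. (2): alpha * cross-entropy + (1 - alpha) * KD loss."""
    ce = -sum(math.log(softmax(s)[y]) for s, y in zip(student_logits, labels)) / len(labels)
    return alpha * ce + (1 - alpha) * kd_loss(teacher_logits, student_logits, tau)


def flatten(x):
    """Flatten nested lists (e.g. a C x H x W map) into one list of floats."""
    if isinstance(x, (int, float)):
        return [float(x)]
    return [v for item in x for v in flatten(item)]


def sekd_loss(teacher_shallow, student_shallow, texture_fn):
    """Eq. (4): batch mean of the Euclidean distance between texture descriptors.

    texture_fn is a placeholder for GLBP; teacher and student maps are assumed to
    have matching shapes so that the flattened descriptors line up element-wise.
    """
    total = 0.0
    for ft, fs in zip(teacher_shallow, student_shallow):
        gt, gs = flatten(texture_fn(ft)), flatten(texture_fn(fs))
        total += math.sqrt(sum((a - b) ** 2 for a, b in zip(gt, gs)))
    return total / len(teacher_shallow)
```

With identical teacher and student logits `kd_loss` is zero, and with `alpha = 1` the student loss reduces to plain cross-entropy, which makes the role of each term in Eq. (2) easy to check.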
Fig. 2. The exploration of three locations for texture knowledge transmission.
3.3 Texture Attention Module
As mentioned in Subsect. 3.1, GLBP expands the number of channels of the feature map 56-fold, which brings a huge computational burden. To solve this problem, we introduce the texture attention module, which compresses the number of channels while retaining the texture information in the channel maps. Channel maps with a smooth texture generally carry little texture information, so computing their GLBP features is inefficient. It is therefore necessary to weight each channel map according to its texture smoothness, which makes the GLBP algorithm focus on informative texture and reduces redundant calculation. In our research, we use the standard deviation to evaluate texture smoothness. The standard deviation measures the dispersion of a set of data, and images with a smooth texture generally have a small standard deviation, so we represent the texture smoothness of a channel map by its standard deviation. We denote a feature map with $C$ channels extracted by the shallow layers of the neural network by $\psi(x_i) = \{f_c\}_{c=1}^{C}$. As shown in Fig. 3, the texture attention module first calculates the standard deviation $\sigma_c$ of each channel map in the
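spatial dimension, where $c \in \{1, \ldots, C\}$ indicates the channel index. The standard deviation $\sigma_c$ of the $c$-th channel map is calculated as:
$$\sigma_c = \sqrt{\frac{\sum_{i=1}^{H} \sum_{j=1}^{W} \left(f_{ij,c} - \bar{f}_c\right)^2}{HW - 1}}, \tag{5}$$
where $\bar{f}_c$ is the mean value of the $c$-th channel map. Then, the $\sigma_c$ are processed by softmax to obtain the attention weights $w_c$:
$$w_c = \frac{\exp(\sigma_c)}{\sum_{i=1}^{C} \exp(\sigma_i)}. \tag{6}$$
After that, all channel maps are summed with these weights to obtain a single compressed channel feature map $F$:
$$F = \sum_{i=1}^{C} f_i \times w_i, \tag{7}$$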
where $f_i$ is the $i$-th channel map and $F \in \mathbb{R}^{H \times W}$. Finally, GLBP is applied to $F$ to obtain the texture features. Since the number of channels of the input feature map $\psi(x_i)$ is compressed, the computational cost of GLBP is significantly reduced. Denoting the texture attention module as $\Theta(\cdot)$, Eq. (4) can be rewritten as:
$$\mathcal{L}_{SeKD} = \frac{1}{n} \sum_{i=1}^{n} \left\lVert \mathrm{GLBP}(\Theta(\psi_t(x_i))) - \mathrm{GLBP}(\Theta(\psi_s(x_i))) \right\rVert_2. \tag{8}$$
The ablation experiments in Subsect. 4.2 show that the texture attention module drastically reduces computation at the cost of only a slight drop in accuracy.
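Eqs. (5)-(7) can be sketched as follows. This is a minimal pure-Python sketch with hypothetical helper names, assuming the feature map is given as nested lists of shape C × H × W with HW > 1; in a real network these operations would run on framework tensors. The softmax is shifted by the largest σ for numerical stability, which leaves the weights unchanged.

```python
import math


def channel_std(channel):
    """Eq. (5): sample standard deviation of one H x W channel map (divisor HW - 1)."""
    values = [v for row in channel for v in row]
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / (len(values) - 1))


def texture_attention(feature_map):
    """Eqs. (5)-(7): compress a C x H x W map into one H x W map.

    Returns the compressed map and the per-channel attention weights w_c.
    """
    sigmas = [channel_std(ch) for ch in feature_map]
    top = max(sigmas)
    exps = [math.exp(s - top) for s in sigmas]  # softmax over channels, Eq. (6)
    total = sum(exps)
    weights = [e / total for e in exps]
    h, w = len(feature_map[0]), len(feature_map[0][0])
    # Eq. (7): weighted sum of the channel maps, position by position.
    compressed = [[sum(wc * ch[i][j] for wc, ch in zip(weights, feature_map))
                   for j in range(w)] for i in range(h)]
    return compressed, weights
```

Applying GLBP to the single H × W map returned here produces 56 channel maps instead of 56C, which is where the training-time savings reported in Table 4 come from; textured channels (larger σ) receive larger weights than flat ones.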
Fig. 3. The calculation process of texture attention module. fi represents the ith channel map. σi represents the standard deviation of the ith channel map. wi represents the attention weight of the ith channel map.
3.4 Implementation Details
With this channel attention mechanism, the shallow-feature loss can improve student network performance without redundant computation. The total loss function consists of three parts:
$$\mathcal{L}_{Total} = \alpha \mathcal{L}_{ce} + \beta \mathcal{L}_{KD} + \gamma \mathcal{L}_{SeKD}, \tag{9}$$
where $\alpha$, $\beta$, and $\gamma$ are non-negative trade-off parameters. In the experiments, we follow the setting in [28] and set $\alpha$, $\beta$, and $\gamma$ to 1, 1, and 1000, respectively, to balance the loss terms.
4 Experiments
This section has two parts, which together demonstrate the effectiveness of the proposed SeKD. First, we verify that SeKD improves the performance of student networks across different network architectures without bells and whistles. We train ResNet8×4, WideResNet16-2, and WideResNet40-1 on the CIFAR-100 dataset, and WideResNet16-2 and WideResNet16-1 on the CIFAR-10 [10] dataset. Second, we show through ablation experiments how the three feature extraction positions and the texture attention module affect the student network.

4.1 Experiments on the CIFAR Dataset
CIFAR-10 is an image classification dataset with ten categories, containing 50,000 training and 10,000 test RGB images. CIFAR-100 is similar to CIFAR-10 but has 100 categories; it also contains 60,000 32 × 32 color images (50,000 for training and 10,000 for testing). For better results, our experimental settings follow [21]. On CIFAR-100, for all methods except FSP, we trained the student network for 240 epochs using Stochastic Gradient Descent (SGD) [1] as the optimizer, with an initial learning rate of 0.05, momentum of 0.9, and weight decay of 5e-4. The learning rate is divided by 10 at epochs 150, 180, and 210. We summarize the experimental results in Table 1 and Table 2.

Table 1. Experimental results on the CIFAR-10 dataset.

| TeacherNet | WideResNet40-2 | WideResNet16-2 |
|---|---|---|
| StudentNet | WideResNet16-2 | WideResNet16-1 |
| Student (base) | 93.76% | 93.63% |
| KD [4] | 93.90% | 94.20% |
| AT [28] | 94.15% | 93.77% |
| SeKD (ours) | 94.35% | 94.38% |
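The optimization settings described above can be sketched as a step learning-rate schedule plus a momentum SGD update. This is an illustrative sketch, not the training script used in the paper: we assume zero-indexed epochs with each decay applied from the milestone epoch onward, and the update rule mirrors the one documented for PyTorch's `torch.optim.SGD` with default dampening and Nesterov disabled, which may differ from the authors' exact setup.

```python
def learning_rate(epoch, base_lr=0.05, milestones=(150, 180, 210), factor=0.1):
    """Step schedule: multiply the base rate by `factor` once per milestone reached.

    Epochs are assumed to be zero-indexed.
    """
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * factor ** passed


def sgd_step(w, grad, velocity, lr, momentum=0.9, weight_decay=5e-4):
    """One SGD update with momentum and L2 weight decay for a scalar parameter.

    Assumed convention: g <- grad + wd * w, v <- momentum * v + g, w <- w - lr * v.
    """
    g = grad + weight_decay * w
    velocity = momentum * velocity + g
    return w - lr * velocity, velocity
```

Over the 240 training epochs the rate therefore takes the values 0.05, 0.005, 0.0005, and 0.00005, switching at epochs 150, 180, and 210.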
Table 2. Experimental results on the CIFAR-100 dataset. Values marked with * are taken from the results reported by the original authors; all other values come from our reimplementation based on the respective papers.

| TeacherNet | ResNet32×4 | WideResNet40-2 | WideResNet40-2 |
|---|---|---|---|
| StudentNet | ResNet8×4 | WideResNet16-2 | WideResNet40-1 |
| Teacher (base) [21] | 79.42%* | 75.61%* | 75.61%* |
| Student (base) [21] | 72.50%* | 73.26%* | 71.98%* |
| KD [4] | 73.52% | 75.35% | 73.70% |
| AT [28] | 73.82% | 74.30% | 72.65% |
| FitNet [18] | 73.42% | 73.77% | 72.34% |
| FSP [27] | 72.61% | 72.44% | N/A |
| NST [5] | 73.80% | 74.26% | 72.32% |
| PKT [15] | 73.80% | 75.11% | 73.28% |
| CRD [21] | 75.51%* | 75.48%* | 74.14%* |
| SeKD (ours) | 74.88% | 75.90% | 74.41% |
As shown in Table 1, we first conducted small-scale experiments on the CIFAR-10 dataset. Our method improves the baseline models by 0.59% and 0.75%, respectively, and consistently outperforms the alternative methods. We then conducted more extensive experiments on the CIFAR-100 dataset. The classification results of the student networks are reported in Table 2, where SeKD outperforms the baseline models by 2.38%–2.64% in accuracy. Notably, the student network WRN16-2 trained by SeKD obtained 75.90% accuracy, which is 0.29% higher than the teacher network and 0.42% higher than the second-best method, CRD [21]. With the exception of the ResNet8×4 student trained with CRD [21], SeKD consistently outperforms all reference methods.

4.2 Ablation Experiments
In this subsection, we demonstrate the effectiveness of our design choices in two respects. Using the same training strategy as in the CIFAR experiments, we verify the optimal location for GLBP feature extraction and the effectiveness of the texture attention module. The Optimal Location. Transmitting shallow texture knowledge at different network locations leads to different results. As shown in Fig. 2, we tried three different locations for texture knowledge transmission; Table 3 shows the results. As Table 3 indicates: 1) When distilling knowledge at the first, second, and third layers, the classification accuracy of ResNet8×4
is improved by 1.36%, 0.81%, and 0.85%, respectively. 2) The first layer is the optimal location for distilling shallow knowledge, providing 0.55%–1.36% performance improvement across the different models. Therefore, in the CIFAR experiments of Sect. 4.1, we transfer shallow knowledge at the first layer of the network.

Table 3. Experimental results for different knowledge transfer locations on CIFAR-100.

| TeacherNet | ResNet32×4 | WideResNet40-2 | WideResNet40-2 |
|---|---|---|---|
| StudentNet | ResNet8×4 | WideResNet16-2 | WideResNet40-1 |
| KD (base) | 73.52% | 75.35% | 73.70% |
| First layer | 74.88% | 75.90% | 74.41% |
| Second layer | 74.33% | 75.29% | 73.82% |
| Third layer | 74.37% | 75.36% | 74.28% |
Texture Attention Module (TAM). In this experiment, we explored the impact of the TAM on both the performance and the computational burden of SeKD. The computational burden is measured by the training time on an RTX3060 GPU. ResNet8×4, taught by ResNet32×4, is used as the base architecture and trained on CIFAR-100. We take the training time of SeKD without TAM as 100% to show how much TAM accelerates training. The other parameters of the experiment are set to the best choices found in the previous experiments.

Table 4. Experimental results of the texture attention module on CIFAR-100.

| Method | Top-1 acc. | Training time |
|---|---|---|
| SeKD without TAM | 74.94% | 11h 6' 22" (100%) |
| SeKD | 74.88% | 7h 19' 35" (65.97%) |
Table 4 shows that the TAM reduces training time by 34.03% at the cost of a 0.06% drop in accuracy. This indicates that the TAM can control the redundant computation introduced by GLBP with little impact on model performance.
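The relative training times in Table 4 follow from converting both durations to seconds; the short check below (plain Python, helper name ours) reproduces the 65.97% and 34.03% figures.

```python
def to_seconds(hours, minutes, seconds):
    """Convert an h/m/s duration to seconds."""
    return 3600 * hours + 60 * minutes + seconds


without_tam = to_seconds(11, 6, 22)   # 39982 s
with_tam = to_seconds(7, 19, 35)      # 26375 s
relative = 100 * with_tam / without_tam
print(f"{relative:.2f}% of the baseline time, {100 - relative:.2f}% less")
# prints "65.97% of the baseline time, 34.03% less"
```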
5 Conclusion
In this paper, we focus on the shallow features that are usually ignored in knowledge distillation and propose a novel knowledge distillation framework called SeKD. SeKD uses GLBP to describe shallow features and introduces a texture attention module that alleviates redundant computation, optimizing the knowledge transfer. SeKD has two main benefits: 1) The shallow knowledge transfer
enables the student to learn informative shallow features from the teacher. 2) The proposed texture attention module effectively cuts the redundant calculation introduced by the GLBP algorithm. Extensive experiments show that the proposed method outperforms the baseline approaches.

Acknowledgement. This research is supported by the Sichuan Science and Technology Program (No. 2022YFG0324) and the SWUST Doctoral Research Foundation under Grant 19zx7102.
References
1. Bottou, L.: Stochastic gradient descent tricks. In: Montavon, G., Orr, G.B., Müller, K.-R. (eds.) Neural Networks: Tricks of the Trade. LNCS, vol. 7700, pp. 421–436. Springer, Heidelberg (2012). https://doi.org/10.1007/978-3-642-35289-8_25
2. Fu, J., et al.: Dual attention network for scene segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 3146–3154 (2019)
3. Heo, B., Lee, M., Yun, S., Choi, J.Y.: Knowledge transfer via distillation of activation boundaries formed by hidden neurons. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, pp. 3779–3787 (2019)
4. Hinton, G., Vinyals, O., Dean, J., et al.: Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531, vol. 2, no. 7 (2015)
5. Huang, Z., Wang, N.: Like what you like: knowledge distill via neuron selectivity transfer. arXiv preprint arXiv:1707.01219 (2017)
6. Jiang, N., Tang, J., Yu, W., Zhou, J., Mai, L.: Gradient local binary pattern layer to initialize the convolutional neural networks. In: 2021 IEEE International Symposium on Circuits and Systems (ISCAS), pp. 1–5. IEEE (2021)
7. Jiang, N., Xu, J., Goto, S.: Pedestrian detection using gradient local binary patterns. IEICE Trans. Fundam. Electron. Commun. Comput. Sci. 95(8), 1280–1287 (2012)
8. Jiang, N., Xu, J., Yu, W.: An intra-combined feature for pedestrian detection. IIEEJ Trans. Image Electron. Visual Comput. 1(1), 88–96 (2013)
9. Jiang, N., Xu, J., Yu, W., Goto, S.: Gradient local binary patterns for human detection. In: 2013 IEEE International Symposium on Circuits and Systems (ISCAS), pp. 978–981. IEEE (2013)
10. Krizhevsky, A., Hinton, G., et al.: Learning multiple layers of features from tiny images (2009)
11. Krizhevsky, A., Sutskever, I., Hinton, G.E.: Imagenet classification with deep convolutional neural networks. Commun. ACM 60(6), 84–90 (2017)
12. Lee, S.H., Kim, D.H., Song, B.C.: Self-supervised knowledge distillation using singular value decomposition. In: Proceedings of the European Conference on Computer Vision (ECCV), pp. 335–350 (2018)
13. Li, H., Kadav, A., Durdanovic, I., Samet, H., Graf, H.P.: Pruning filters for efficient convnets. arXiv preprint arXiv:1608.08710 (2016)
14. Ojala, T., Pietikainen, M., Maenpaa, T.: Multiresolution gray-scale and rotation invariant texture classification with local binary patterns. IEEE Trans. Pattern Anal. Mach. Intell. 24(7), 971–987 (2002)
15. Passalis, N., Tefas, A.: Probabilistic knowledge transfer for deep representation learning. CoRR, abs/1803.10837, vol. 1, no. 2, p. 5 (2018)
16. Polino, A., Pascanu, R., Alistarh, D.: Model compression via distillation and quantization. arXiv preprint arXiv:1802.05668 (2018)
17. Ren, S., He, K., Girshick, R., Sun, J.: Faster R-CNN: towards real-time object detection with region proposal networks. In: Advances in Neural Information Processing Systems, vol. 28 (2015)
18. Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C., Bengio, Y.: Fitnets: hints for thin deep nets. arXiv preprint arXiv:1412.6550 (2014)
19. Sun, Y., Wang, X., Tang, X.: Deep learning face representation from predicting 10,000 classes. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1891–1898 (2014)
20. Tang, J., Jiang, N., Yu, W.: Gradient local binary pattern for convolutional neural networks. In: 2021 IEEE International Conference on Image Processing (ICIP), pp. 744–748. IEEE (2021)
21. Tian, Y., Krishnan, D., Isola, P.: Contrastive representation distillation. arXiv preprint arXiv:1910.10699 (2019)
22. Vaswani, A., et al.: Attention is all you need. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
23. Wang, F., et al.: Residual attention network for image classification. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3156–3164 (2017)
24. Wang, T., Zhu, J.Y., Torralba, A., Efros, A.A.: Dataset distillation. arXiv preprint arXiv:1811.10959 (2018)
25. Woo, S., Park, J., Lee, J.-Y., Kweon, I.S.: CBAM: convolutional block attention module. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11211, pp. 3–19. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01234-2_1
26. Yang, J., Martinez, B., Bulat, A., Tzimiropoulos, G.: Knowledge distillation via adaptive instance normalization. arXiv preprint arXiv:2003.04289 (2020)
27. Yim, J., Joo, D., Bae, J., Kim, J.: A gift from knowledge distillation: fast optimization, network minimization and transfer learning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4133–4141 (2017)
28. Zagoruyko, S., Komodakis, N.: Paying more attention to attention: Improving the performance of convolutional neural networks via attention transfer. arXiv preprint arXiv:1612.03928 (2016)
29. Zhang, L., Chen, X., Tu, X., Wan, P., Xu, N., Ma, K.: Wavelet knowledge distillation: towards efficient image-to-image translation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12464–12474 (2022)
Unsupervised Domain Adaptation Supplemented with Generated Images
S. Suryavardan, Viswanath Pulabaigari, and Rakesh Kumar Sanodiya (✉)
Indian Institute of Information Technology, Sri City, Chittoor, India
{suryavardan.s19,viswanath.p,rakesh.s}@iiits.in
Abstract. With domain adaptation, we aim to leverage a given source dataset to model a classifier on the target domain. In an unsupervised setting, the goal is to derive class-based features and adapt them to a different domain. Minimizing the discrepancy between these domains is one of the most promising directions for training a common classifier. Several algorithmic and deep modelling approaches have been proposed for this problem, all trying to combat the inherent difficulty of finding an ideal objective function. Recent algorithmic approaches focus on extracting domain-invariant features by simultaneously minimizing the distributional and geometrical divergence between domains. In this work, we present the results of adding generated synthetic images to some renowned shallow algorithmic approaches. Using Generative Adversarial Networks (GANs), we generate images using categorical information from the source domain, thus adding variance and variety to the source data. We present the impact of synthetic data on notable unsupervised domain adaptation algorithms and show an improvement in about 62% of the 80 task results.

Keywords: Unsupervised Domain Adaptation · Transfer Learning · Feature Learning · Generative Adversarial Networks

1 Introduction
Machine Learning (ML) is a growing field that aims to imitate humans' ability to learn and to apply that learning to different tasks. ML algorithms extract features from given inputs and recognise relations that enable them to classify or categorize the input into a given set of classes. Despite the development of ML as a field and the many researchers working towards its growth, one of the biggest challenges it faces is the large number of tasks. The number of applications and the complexity of models are ever-growing, and so is the amount of unstructured data. Cleaning and processing this data for every task and from different domains is very expensive. This reliance of ML on labelled data is a major drawback, as such structured data is difficult to obtain in real-world applications. Domain adaptation helps combat this issue by devising approaches that transfer knowledge from a source domain to train classifiers for a different target domain.

© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 659–670, 2023.
https://doi.org/10.1007/978-981-99-1639-9_55
S. Suryavardan et al.
Unsupervised domain adaptation deals with a target domain whose data carries no label information. Our task is to learn class-based information from a source domain and recognise similar differentiating features in the target domain. For example, in the CMU-PIE (Pose, Illumination and Expression) dataset [1], we have a labelled set of images of human faces facing a direction A. This is the source dataset, and it has label information for 68 different people. The target dataset comes from a different domain, in which the people face a different direction B. Our task is to train a common classifier on both domains such that it can classify the unlabelled target data. This task has received a lot of attention in the computer vision research community, and many different approaches have been presented. Existing approaches assume that the target domain shares the same label space as the source, despite coming from a different domain [2]. Based on this assumption, most approaches focus on extracting common features, or a representation in a common feature space, using label knowledge from the source domain only. These include variance-based algorithms and deep learning methods. Variance-based algorithms may use ideas such as instance re-weighting and feature matching. Feature matching can be further divided into data-centric and subspace-centric algorithms; these approaches are discussed in the next section. The contributions of this paper include experiments that observe the impact of supplementing traditional domain adaptation algorithms with synthetic, generated data. As these shallow algorithms rely on domain-common features and variance, we survey their results on data generated using Generative Adversarial Networks. Our work also introduces a variant of the generative adversarial net that is both conditional and bi-directional, namely a Bi-directional Conditional GAN.
2 Related Work
The number of large unstructured datasets, and of applications using such datasets, has grown steadily in recent years. This has encouraged focused research on unsupervised domain adaptation and its applications. One aspect of this is the introduction of several datasets, ranging from general transfer learning [4] and image translation [5,6] to medical applications [7] and person re-identification [8]. The other aspect is the work on algorithms and models that try to match the given source and target domains and classify their data. Although a lot of work has been done in this field [3], we discuss only the approaches relevant to our current work. The work in this field can be divided into a) traditional approaches and b) deep learning approaches (Fig. 1). Deep learning approaches use neural networks to align the two domains. We briefly discuss some notable deep modelling approaches to unsupervised domain adaptation. The common modelling approaches are discrepancy-based and adversarial-based. Discrepancy-based models use statistically defined loss functions such as Maximum Mean Discrepancy (MMD) [9], CORAL [10], etc.
Fig. 1. General categorization of unsupervised domain adaptation approaches, taken from the survey [2]
This helps the neural model learn common weights that align the two domains, so that a common classifier can be trained. Adversarial methods include two competing models, where one generates domain-invariant features and the other discriminates between features from different domains. These are also based on different loss functions, such as MMD [11]. Traditional approaches work towards extracting features from raw images, identifying common attributes, and using them to adapt the two domains. Feature matching or feature selection algorithms use feature extraction as their primary step. They aim to learn a common latent feature space into which the source and target domains can be transformed, so that a common classifier can be trained. This feature matching can be centered around the data or around the subspace. Subspace-centric methods try to represent data points in a common subspace, for example on the Grassmann manifold [12], or incrementally align features across domains with a geodesic flow kernel [13]. The work on subspace alignment [15] led to an approach that has been used as the base for other subspace-centric algorithms. This has also been extended to algorithms that learn a common manifold by connecting the common features between source and target data. Data-centric methods rely on aligning the two domains while preserving the categorical features of both domains [14]. This is done by reducing the distribution gap between source and target data. Most of these algorithms include one or more of the following variance-based methods: marginal distribution adaptation, conditional distribution adaptation, and joint distribution adaptation. They work by increasing the variance between samples and preserving categorical information while minimizing the distance between the source and target domains. Algorithms such as [18] exploit both of these modelling approaches and preserve domain- and class-specific information while identifying shared
features and maximising variance across categories. Finally, some other traditional approaches use pseudo-labels for the unlabelled target data and iteratively improve those labels [16]. This helps learn class information shared by the source and target domains. The traditional approaches rely on the presence of data with ample categorical information, and this reliance is even stronger in methods built on variance-based algorithms. In the upcoming sections, we present our results and observations on supplementing some of these notable traditional algorithms with generated images.

Table 1. Existing traditional unsupervised domain adaptation algorithms and their objective functions, as described in Sect. 4.2
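| Methods | SDIP | TDIP | TVP | ROS | IR | SDM | MCDA |
|---|---|---|---|---|---|---|---|
| JDA [16] | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | ✓ |
| TJM [17] | ✗ | ✗ | ✗ | ✗ | ✓ | ✗ | ✓ |
| JGSA [18] | ✓ | ✗ | ✓ | ✗ | ✗ | ✓ | ✓ |
| LDAPL [19] | ✓ | ✓ | ✓ | ✓ | ✗ | ✓ | ✓ |

3 Problem Statement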
The initial problem is that of general unsupervised domain adaptation. Let $D_s = \{(x_s^1, y_s^1), (x_s^2, y_s^2), \ldots, (x_s^{n_s}, y_s^{n_s})\}$ be the source domain data, where $x$ and $y$ are image-label pairs, such that $X_s \in \mathbb{R}^{d \times n_s}$ and $Y_s \in \mathbb{R}^{1 \times n_s}$. Similarly, let $D_t = \{x_t^1, x_t^2, \ldots, x_t^{n_t}\}$ be the target domain data, where $x$ is the image, such that $X_t \in \mathbb{R}^{d \times n_t}$. The source and target domains can each be described by a distribution, i.e. $P(X_s)$ for $D_s$ and $P(X_t)$ for $D_t$. As $D_s \neq D_t$, we cannot simply train a common classifier for predictions on the unlabelled target dataset. Thus, we need to learn a new feature representation with a minimal distribution gap. We propose the addition of synthetic data, $D_g = \{(x_g^1, y_g^1), (x_g^2, y_g^2), \ldots, (x_g^{n_g}, y_g^{n_g})\}$, such that $X_g \in \mathbb{R}^{d \times n_g}$ and $Y_g \in \mathbb{R}^{1 \times n_g}$. The generated images use class-based knowledge from the source domain labels, with the aim of generating images in a common domain space. $D_g$ must encode source domain as well as target domain data in order to supplement domain adaptation. This is achieved by using an adversarial network to generate the images.
4 Method
Shallow traditional algorithms use methods built upon eigendecomposition and PCA to find a matrix that optimally transforms the source and target data. They project the data into a space common to both domains while maintaining the inter-relations between samples in the dataset. This refers
to three data components: the variance of the input data, class-specific features, and domain-specific features. An issue with these algorithms is that they are constrained by the input data: when the dataset has many classes but only a few samples per class, the data acts as a bottleneck. To address this limitation, we present our observations and results on supplementing these algorithms with generated data. Ideally, the generated data should add to all three of the data components mentioned above. The following subsections describe our generation model and the objective functions involved.

4.1 Generative Module
Generative Adversarial Networks (GANs) [20] are a class of modern neural networks that use two competing models during training. A generator model tries to produce output images as similar as possible to the given input dataset, while a discriminator tries to classify the generated images as real or fake. The two models are trained jointly using the classification loss in Eq. 1, such that the generator tries to fool the discriminator and, by doing so, learns to generate better images. A Conditional Generative Adversarial Network (CGAN) [21] passes label information to both the generator and the discriminator; this simple change helps the model learn class-based knowledge from the input dataset. This is largely due to the loss function, which conditions both terms on the label $y$, as in Eq. 2.
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))] \tag{1}$$
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x \mid y)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z \mid y)))] \tag{2}$$
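The value function in Eqs. (1) and (2) can be estimated on a minibatch from the discriminator's output probabilities, as in the plain-Python sketch below. The helper names are ours, and conditioning by concatenating a one-hot label to the noise vector is a common CGAN implementation choice that we assume here; the paper does not state how labels are injected. In the usage check, 68 classes correspond to the 68 people in CMU-PIE mentioned in the introduction, and the noise dimension of 100 is an arbitrary placeholder.

```python
import math


def gan_value(d_real, d_fake):
    """Minibatch estimate of V(D, G) in Eq. (1).

    d_real: discriminator outputs D(x) (probabilities) on real images.
    d_fake: discriminator outputs D(G(z)) on generated images.
    For Eq. (2) the formula is identical; only D and G also receive the label y.
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


def conditional_input(noise, label, num_classes):
    """Concatenate noise with a one-hot label (an assumed way to condition G on y)."""
    one_hot = [1.0 if k == label else 0.0 for k in range(num_classes)]
    return list(noise) + one_hot
```

When D outputs 0.5 for every input, the estimate equals -2 log 2 (about -1.386), the value at the theoretical equilibrium of the original GAN formulation [20]; a discriminator that separates real from fake images more confidently yields a larger value, which is what the max over D rewards.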
Using CGANs, we can generate synthetic images conditioned on source domain labels by discriminating between the source images and the generated images. An important modification of the model for this task is an additional encoder model that makes it bi-directional [22]; the resulting model is a Bidirectional Conditional GAN (BiCGAN). We added a pixel-wise loss function between the input noise of the generator and the encoder's output for the fake images. This differs from the BiCoGAN implementation in [24], which tries to encode the categorical information rather than the input noise. The encoder ensures that the diversity added by the noise is captured effectively, which is important especially for a limited dataset such as the PIE dataset, since generating exactly the same images as the input data would be redundant. Moreover, the encoder can be used to capture additional information to pass to the generator, e.g. encoding target domain data when testing. This helps us achieve the goal of supplementing the source data with images that add to its variance and improve the domain distribution matching. The model architecture is shown in Fig. 2. The three models are jointly trained such that $D_g$ gains source label information and data uniqueness from the categorical loss and pixel
Fig. 2. Architecture of the proposed bi-directional conditional generative adversarial model. The generator, encoder and discriminator are jointly trained using classification loss for source and fake data, as well as mean squared loss between input noise and fake data. The input to the generator is noise and labels from source domain, to the encoder it is the fake data and the input to the discriminator is the fake data and source data
loss functions. Figure 4 shows examples of some generated images. The synthetic data preserves class-based knowledge; thus, by increasing the data size and adding variance, it aids the traditional algorithms in transforming the data.

4.2 UDA Algorithm Module
Each traditional algorithm optimizes a set of objective functions when performing domain adaptation. Table 1 lists some traditional algorithms with their objective functions. Objective Functions. The objective functions given in the table are briefly described below.
– SDIP is source discriminative information preservation. The two functions below increase the variance between classes and keep samples of the same class close together. $A$ is the source transformation matrix, and $S_b$ and $S_w$ are the between-class and within-class variance of the source domain.
$$\max_A \mathrm{Tr}(A^T S_b A), \qquad \min_A \mathrm{Tr}(A^T S_w A) \tag{3}$$
– TDIP is target discriminative information preservation. It is the same as SDIP, but algorithms that use TDIP rely on pseudo-labels to incorporate class-based functions; here $B$ is the target transformation matrix.
$$\max_B \mathrm{Tr}(B^T T_b B), \qquad \min_B \mathrm{Tr}(B^T T_w B) \tag{4}$$
Unsupervised Domain Adaptation Supplemented with Generated Images
– TVP is target variance preservation. It preserves the target-domain variance by maximizing the function below while projecting the data into a different space. T_t is the total variance of the target domain.
$$\max_{B} \operatorname{Tr}(B^\top T_t B) \qquad (5)$$
– ROS is retaining original similarity and SDM is subspace divergence minimization. They aim to learn optimal transformation matrices A and B for the source and target domains, respectively, learning feature representations while aligning the two distributions using the Frobenius norm.
$$\min_{A,B} \lVert A - B \rVert_F^2 \qquad (6)$$
– IR, or instance re-weighting, is used to ignore irrelevant samples (instances) or outliers when reducing the distribution gap. It applies a sparsity regularizer to the source transformation matrix to assign weights to samples.
$$\min_{A,B} \lVert A \rVert_{2,1} - \lVert B \rVert_F^2 \qquad (7)$$
– MCDA, marginal and conditional distribution alignment, is the most important function. It reduces the marginal and conditional distribution gaps between the two domains by learning a new feature space while maintaining the original similarity within the datasets.
$$\min_{A,B} \left\lVert \frac{1}{n_s}\sum_{m_i \in M_s} A^\top m_i - \frac{1}{n_t}\sum_{m_j \in M_t} B^\top m_j \right\rVert_F^2 \qquad (8)$$
$$\min_{A,B} \sum_{c=1}^{C} \left\lVert \frac{1}{n_s^c}\sum_{m_i \in M_s^c} A^\top m_i - \frac{1}{n_t^c}\sum_{m_j \in M_t^c} B^\top m_j \right\rVert_F^2 \qquad (9)$$
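To make the variance and alignment terms above concrete, the following is a minimal pure-Python sketch (no linear-algebra library assumed) that evaluates Tr(AᵀS_bA) and Tr(AᵀS_wA) from Eq. (3) and the marginal and class-conditional gaps of Eqs. (8) and (9) for given projection matrices A and B. It only evaluates the objectives; it does not perform the optimization that JDA, TJM, JGSA or LDAPL carry out. The helper names and the unnormalized scatter definitions are assumptions made for illustration.

```python
# Sketch: evaluate the SDIP traces (Eq. 3) and the MCDA gaps (Eqs. 8-9).
# Matrices are lists of rows; each sample is a list of floats.

def mean(vectors):
    """Component-wise mean of a non-empty list of equal-length vectors."""
    n = len(vectors)
    return [sum(v[k] for v in vectors) / n for k in range(len(vectors[0]))]


def project(A, x):
    """Return A^T x for A of shape (d, k) and x of length d."""
    return [sum(A[i][j] * x[i] for i in range(len(x))) for j in range(len(A[0]))]


def sq_norm(v):
    return sum(c * c for c in v)


def sub(u, v):
    return [a - b for a, b in zip(u, v)]


def scatter_traces(A, X, y):
    """(Tr(A^T S_b A), Tr(A^T S_w A)) using Tr(A^T v v^T A) = ||A^T v||^2."""
    overall = mean(X)
    tr_b = tr_w = 0.0
    for c in sorted(set(y)):
        Xc = [x for x, label in zip(X, y) if label == c]
        mu_c = mean(Xc)
        # Between-class term: n_c * (mu_c - mu)(mu_c - mu)^T
        tr_b += len(Xc) * sq_norm(project(A, sub(mu_c, overall)))
        # Within-class term: sum over x in class c of (x - mu_c)(x - mu_c)^T
        tr_w += sum(sq_norm(project(A, sub(x, mu_c))) for x in Xc)
    return tr_b, tr_w


def marginal_gap(A, B, Xs, Xt):
    """Eq. (8): squared distance between projected source and target means."""
    return sq_norm(sub(mean([project(A, x) for x in Xs]),
                       mean([project(B, x) for x in Xt])))


def conditional_gap(A, B, Xs, ys, Xt, yt_pseudo):
    """Eq. (9): per-class version of Eq. (8), using target pseudo-labels."""
    total = 0.0
    for c in set(ys) & set(yt_pseudo):
        Xs_c = [x for x, label in zip(Xs, ys) if label == c]
        Xt_c = [x for x, label in zip(Xt, yt_pseudo) if label == c]
        total += marginal_gap(A, B, Xs_c, Xt_c)
    return total
```

A larger first trace and a smaller second trace indicate better class separation after projection, while smaller gaps indicate better marginal and conditional alignment; generated samples enter these quantities simply as additional rows of the source data.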
Role of Generated Data. As described above, the traditional algorithms optimize several objective functions, and these functions can be aided by generated data. SDIP, TDIP and TVP rely largely on variance, to which the generated data contributes through the input noise and the encoder. We condition our model on the source domain's class (categorical) information while also encoding the target domain's domain-specific information. Target data is absent from Fig. 2 because it is used only at test time: the encoder maps target images to noise vectors, which are then passed as input to the generator. Through the conditional discriminator and the encoder, the generator learns to include domain-specific knowledge. As a result, while the generated data only moderately affects the variance of the individual domains, it substantially assists distribution adaptation. MCDA, the most important function, benefits from both the domain-specific and the class-specific information held in the generated data.
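The way the categorical and pixel-wise losses combine can be sketched without a deep learning framework. This is an illustration, not the TensorFlow implementation used in this work: the discriminator's class probabilities for a fake sample and the encoder's output are passed in as plain lists, the categorical loss is assumed to be cross-entropy, and the weighting factor `lam` between the two terms is a hypothetical parameter not given in the text.

```python
import math


def categorical_cross_entropy(class_probs, label):
    """Categorical loss on a fake sample: -log p(label | G(z, label))."""
    return -math.log(class_probs[label])


def pixel_wise_mse(z, z_encoded):
    """Mean squared error between the generator's input noise z and E(G(z, y))."""
    return sum((a - b) ** 2 for a, b in zip(z, z_encoded)) / len(z)


def generator_encoder_loss(class_probs, label, z, z_encoded, lam=1.0):
    """Hypothetical combined loss for the generator/encoder update.

    The categorical term pushes generated samples toward their source label;
    the pixel-wise term forces the encoder to recover the noise, so different
    noise vectors must yield distinguishable (diverse) images.
    """
    return categorical_cross_entropy(class_probs, label) + lam * pixel_wise_mse(z, z_encoded)


loss = generator_encoder_loss([0.5, 0.5], 0, [1.0, 0.0], [0.0, 0.0])  # ln 2 + 0.5
```

Because the encoder must reconstruct z, two noise vectors that the generator maps to the same image would produce a large pixel-wise loss, which is the mechanism the text relies on for diversity.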
5 Experiments
5.1 Implementation
Fig. 3. t-SNE visualization of the original dataset (left) and the generated data with 100 images per class (right). In all domains, the generated data shows a diverse set of images for each class while preserving categorical similarity, which reflects the aim of the data generation.
The BiCGAN model was implemented using the TensorFlow library. As mentioned earlier, the CGAN was modified as shown in Fig. 2 so that the model is not only conditioned on source-domain classes but also learns to generate a variety of images rather than duplicates of the source data. Using this bidirectional network, we can generate two sets of data: original data, generated by passing random noise to the generator, and target-conditioned data, generated by passing target images to the encoder and then passing the resulting encodings to the generator. We conducted experiments by passing these generated synthetic images to traditional unsupervised domain adaptation algorithms. We chose state-of-the-art algorithms that introduced major changes or objective functions in this field: Joint Distribution Adaptation (JDA) [16], Transfer Joint Matching (TJM) [17], Joint Geometrical and Statistical Alignment (JGSA) [18] and Linear Discriminant Analysis via Pseudo Labels (LDAPL) [19]. These algorithms introduced the aforementioned objective functions to the unsupervised domain adaptation task. Our experiments were conducted on the CMU PIE (Pose, Illumination and Expression) dataset [1]. The dataset contains images of human faces taken from five different camera angles under varying illumination and expression. The five angles form the five domains of the dataset, PIE05, PIE07, PIE09, PIE27 and PIE29, which gives 20 cross-domain tasks of the form PIE05 → PIE07, PIE05 → PIE09, etc. Each domain has 68 classes, i.e., 68 different people, and the five domains correspond to the five directions the people face. Figure 4 shows some examples from the dataset along with images generated for that domain. The results are reported in Table 2, using the same set of parameters for all domain pairs.
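The task count follows from taking every ordered pair of distinct domains (5 × 4 = 20). The short Python sketch below enumerates the tasks and the number of synthetic images added per task; the domain names and counts come from the text, while the variable names are illustrative.

```python
from itertools import permutations

DOMAINS = ["PIE05", "PIE07", "PIE09", "PIE27", "PIE29"]
NUM_CLASSES = 68       # people (classes) per domain
FAKE_PER_CLASS = 100   # generated images per class, as stated in Sect. 5.2

# Every ordered (source, target) pair of distinct domains is one task.
TASKS = list(permutations(DOMAINS, 2))

for source, target in TASKS[:3]:
    print(f"{source} -> {target}")
print(len(TASKS), "tasks;", NUM_CLASSES * FAKE_PER_CLASS, "fake images per task")
```

This prints `PIE05 -> PIE07`, `PIE05 -> PIE09`, `PIE05 -> PIE27` and then `20 tasks; 6800 fake images per task`.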
Fig. 4. Two examples of generated fake images from the PIE dataset. The larger image is a real image from the source domain, shown alongside 6 generated images.
5.2 Results
The results presented in this work are for several notable traditional algorithms supplemented with synthetic images. We generated 100 images for each person (class) in the source domain during prediction using the encoder and the generator. The resulting 6800 fake images for each task were then passed to JDA, TJM, JGSA and LDAPL along with the original data. A k-nearest neighbors (k-NN) classifier, the norm in prior work, is used to obtain the final classification score on the respective target domain. The accuracy scores for all tasks (domain pairs) of the PIE dataset are given in Table 2. Note that all results use the same set of parameters and kernels. Instead of reporting the best result across all possible parameters, we compare the impact of the additional synthetic data under a common set of baseline parameters to better understand its influence. With these baseline parameters, about 62% of the tasks show improvements, and about 50% of them improve the score by more than 2%. JDA performs better with synthetic data on 15 of the 20 PIE tasks, and TJM also improves on 15 of the 20 tasks. JGSA and LDAPL, being more complex algorithms, focus on optimizing individual-domain variance; adding images therefore improves fewer of their tasks: 11 for JGSA and 8 for LDAPL.
Figure 3 provides a t-distributed Stochastic Neighbor Embedding (t-SNE) plot of the source and target data. t-SNE uses the similarity between samples in a dataset to produce a lower-dimensional visualization. The figure shows that, because the generated data is distributed differently from the source data, it adds variance and enriches the marginal distribution. At the same time, the conditional distribution is preserved, since samples of the same class (color) lie close to one another.
Table 2. Classification accuracy (%) on the PIE dataset for selected traditional algorithms and the same algorithms supplemented with images generated by BiCGAN. Columns: Domain pair; JDA; JDA + BiCGAN; TJM; TJM + BiCGAN; JGSA; JGSA + BiCGAN; LDAPL; LDAPL + BiCGAN.
P05→P07 P05→P09 P05→P27 P05→P29 P07→P05 P07→P09 P07→P27 P07→P29 P09→P05 P09→P07 P09→P27 P09→P29 P27→P05 P27→P07 P27→P09 P27→P29 P29→P05 P29→P07 P29→P09 P29→P27
55.92 57.41 85.79 45.77 64.44 64.03 84.05 45.53 62.12 60.90 80.74 48.16 85.20 87.42 89.71 58.95 53.66 49.23 53.00 64.52
65.81 53.80 90.51 45.96 64.89 67.77 86.90 52.27 56.09 60.22 83.60 48.04 86.31 88.77 90.81 59.99 51.29 48.99 51.96 65.97
62.86 55.94 86.57 45.28 68.37 71.38 87.17 55.76 61.16 65.99 86.00 53.68 82.02 88.34 91.42 62.75 53.06 53.84 56.00 66.78
58.32 55.02 89.25 49.26 64.02 66.97 83.66 44.36 58.43 60.04 80.62 44.49 84.63 86.74 88.97 57.11 49.31 50.15 47.00 63.05
78.82 73.22 92.52 56.31 75.72 80.51 85.85 67.52 75.57 76.61 87.41 63.85 91.63 93.62 91.30 73.71 61.85 70.96 69.49 76.90
75.26 67.89 91.17 47.18 81.30 81.13 89.07 58.52 75.93 78.39 90.78 66.42 92.38 93.06 91.05 72.18 62.45 69.61 72.61 77.65
80.45 75.74 95.40 59.31 83.85 83.88 90.09 66.54 76.56 74.77 90.06 68.32 94.33 95.76 92.46 79.04 68.04 74.22 72.79 82.70
79.13 71.63 95.70 56.74 83.22 80.45 90.72 63.42 72.78 76.98 91.77 66.85 94.48 95.03 93.38 77.94 64.62 76.86 76.96 80.47
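The improvement counts discussed in the Results can be computed from paired columns of Table 2 with a helper like the one below. It is a sketch: it reads "an increase of over 2%" as more than 2 percentage points of absolute accuracy, which is an assumption, and the example scores are made up rather than taken from the table.

```python
def improvement_stats(baseline, augmented, margin=2.0):
    """Return (#tasks improved, #tasks improved by more than `margin` points)."""
    if len(baseline) != len(augmented):
        raise ValueError("score lists must be aligned task by task")
    gains = [a - b for b, a in zip(baseline, augmented)]
    improved = sum(1 for g in gains if g > 0)
    large = sum(1 for g in gains if g > margin)
    return improved, large


# Hypothetical accuracies (%) for three tasks, not values from Table 2.
base = [50.0, 60.0, 70.0]
aug = [53.0, 59.0, 71.0]
print(improvement_stats(base, aug))  # (2, 1)
```

Summing the first returned value over the four algorithms and dividing by 80 (4 algorithms × 20 tasks) gives the overall share of improved tasks.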
6 Conclusion and Future Work
In this work, we present results obtained by supplementing notable traditional unsupervised domain adaptation algorithms, namely JDA, TJM, JGSA and LDAPL, with generated fake images. The images were generated using a bidirectional GAN conditioned on labels from the source-domain data. The encoder module, i.e., the bidirectional aspect of the GAN, helps add class-conditioned data while also producing a large variety of images. Why the approach works can be seen in the t-SNE visualization [23] in Fig. 3, which shows that the generated data maintains the class-based distribution while adding variance to the original data. Moreover, encoding the target data incorporates target-domain information into the synthetic data, which further helps the algorithms. The results show that by adding images and aiding the variance-based objective functions of these algorithms, we can improve their performance on the PIE dataset. One drawback, however, is that as the complexity of the base algorithm increases, the impact of the synthetic data decreases; a higher-order objective function could possibly help combat this limitation. The generated data succeeds in simplifying the reduction of the distribution gap, but it does not contribute to the variance-based functions as much as expected. Future work involves 1) testing a similar approach on other domain adaptation datasets, 2) obtaining better results by adapting the domains directly with deep models after fake image generation instead of using traditional algorithms, although our approach is computationally cheaper, and 3) using higher-order objective functions for domain matching.
References
1. Sim, T., Baker, S., Bsat, M.: The CMU Pose, Illumination, and Expression Database. IEEE Trans. Pattern Anal. Mach. Intell. 25(12), 1615–1618 (2003)
2. Zhang, Y.: A Survey of Unsupervised Domain Adaptation for Visual Recognition. arXiv preprint arXiv:2112.06745 (2021)
3. Wilson, G., Cook, D.J.: A survey of unsupervised deep domain adaptation. ACM Trans. Intell. Syst. Technol. (TIST) 11(5), 1–46 (2020)
4. Venkateswara, H., Eusebio, J., Chakraborty, S., Panchanathan, S.: Deep hashing network for unsupervised domain adaptation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5018–5027 (2017)
5. Cordts, M., et al.: The cityscapes dataset for semantic urban scene understanding. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2016)
6. Karras, T., Aila, T., Laine, S., Lehtinen, J.: Progressive growing of GANs for improved quality, stability, and variation. In: International Conference on Learning Representations (2018)
7. Jack, C., et al.: The Alzheimer’s Disease neuroimaging initiative (ADNI): MRI methods. J. Magn. Reson. Imaging JMRI 27, 685–691 (2008). https://doi.org/10.1002/jmri.21049
8. Zheng, L., Shen, L., Tian, L., Wang, S., Wang, J., Tian, Q.: Scalable person re-identification: a benchmark. In: IEEE International Conference on Computer Vision (2015)
9. Gretton, A., Borgwardt, K., Rasch, M., Scholkopf, B., Smola, A.: A kernel method for the two-sample-problem. In: Scholkopf, B., Platt, J., Hoffman, T. (eds.) Advances in Neural Information Processing Systems, vol. 19. MIT Press (2006). https://proceedings.neurips.cc/paper/2006/file/e9fb2eda3d9c55a0d89c98d6c54b5b3e-Paper.pdf
10. Sun, B., Feng, J., Saenko, K.: Return of frustratingly easy domain adaptation. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 30 (2016)
11. Hoffman, J., et al.: Cycada: cycle-consistent adversarial domain adaptation. In: International Conference on Machine Learning, pp. 1989–1998. PMLR (2018)
12. Gopalan, R., Li, R., Chellappa, R.: Domain adaptation for object recognition: an unsupervised approach. In: 2011 International Conference on Computer Vision, pp. 999–1006. IEEE (2011)
13. Gong, B., Shi, Y., Sha, F., Grauman, K.: Geodesic flow kernel for unsupervised domain adaptation. In: 2012 IEEE Conference on Computer Vision and Pattern Recognition, pp. 2066–2073. IEEE (2012)
14. Pan, S.J., Tsang, I.W., Kwok, J.T., Yang, Q.: Domain adaptation via transfer component analysis. IEEE Trans. Neural Networks 22(2), 199–210 (2010)
15. Fernando, B., Habrard, A., Sebban, M., Tuytelaars, T.: Unsupervised visual domain adaptation using subspace alignment. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2960–2967 (2013)
16. Long, M., Wang, J., Ding, G., Sun, J., Yu, P.S.: Transfer feature learning with Joint Distribution Adaptation. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2200–2207 (2013)
17. Long, M., Wang, J., Ding, G., Sun, J., Yu, P.S.: Transfer joint matching for unsupervised domain adaptation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1410–1417 (2014)
18. Zhang, J., Li, W., Ogunbona, P.: Joint geometrical and statistical alignment for visual domain adaptation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1859–1867 (2017)
19. Sanodiya, R.K., Yao, L.: Linear discriminant analysis via pseudo labels: a unified framework for visual domain adaptation. IEEE Access 8, 200073–200090 (2020)
20. Goodfellow, I., et al.: Generative adversarial nets. In: Advances in Neural Information Processing Systems, vol. 27 (2014)
21. Mirza, M., Osindero, S.: Conditional generative adversarial nets. arXiv preprint arXiv:1411.1784 (2014)
22. Donahue, J., Krahenbuhl, P., Darrell, T.: Adversarial feature learning. arXiv preprint arXiv:1605.09782 (2016)
23. Van der Maaten, L., Hinton, G.: Visualizing data using t-SNE. J. Mach. Learn. Res. 9(11) (2008)
24. Jaiswal, A., AbdAlmageed, W., Wu, Y., Natarajan, P.: Bidirectional conditional generative adversarial networks. In: Jawahar, C.V., Li, H., Mori, G., Schindler, K. (eds.) ACCV 2018. LNCS, vol. 11363, pp. 216–232. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-20893-6_14
MAR2MIX: A Novel Model for Dynamic Problem in Multi-agent Reinforcement Learning
Gaoyun Fang1, Yang Liu1, Jing Liu1,2, and Liang Song1,2(B)
1 Academy for Engineering and Technology, Fudan University, Shanghai 200433, China
2 Shanghai East-Bund Research Institute on NSAI, Shanghai 200439, China
Abstract. As a challenging problem in the Multi-Agent Reinforcement Learning (MARL) community, cooperative tasks have received extensive attention in recent years. Most current MARL algorithms use the centralized training and decentralized execution approach, which cannot effectively handle the relationship between local and global information during training. Meanwhile, many algorithms focus on collaborative tasks with a fixed number of agents, without considering how new agents should cooperate with the existing ones when they enter the environment. To address these problems, we propose a Multi-Agent Recurrent Residual Mix model (MAR2MIX). First, we use a dynamic masking technique so that different multi-agent algorithms can operate in dynamic environments. Second, through a recurrent residual mix network, we efficiently extract features in the dynamic environment and achieve task collaboration while ensuring effective information transfer between global and local agents. We evaluate MAR2MIX in both non-dynamic and dynamic environments. The results show that our model learns faster than the benchmark models, trains more stably, generalizes better, and handles agents joining dynamic environments well.
Keywords: Multi-Agent reinforcement learning · Coordination task · Dynamic problem
1 Introduction
With the development of artificial intelligence technology [12,17], many fields apply Multi-Agent Reinforcement Learning (MARL) to their problems, such as edge computing [26,33], traffic flow control [4] and mobile communication [3,7,16]. In these fields, cooperation is the main focus, and dynamic entity structures and effective performance are also considered essential. Recently, Centralized Training and Decentralized Execution (CTDE) has been widely regarded as an excellent paradigm and has addressed many multi-agent cooperation challenges [2,30,32]. Within this paradigm, value decomposition is the main method: VDN [27] and QMIX [22] are built on it, and other algorithms, such as Qtran [25] and QTran++ [24], build on QMIX. Sharing observations, as in MADDPG [18] and COMA [8], is another way to enhance model performance. However, these studies apply their algorithms to scenarios with a fixed number of agents, whereas variation is inevitable in the real world. For instance, edge computing faces ad-hoc connection problems [5]: when new devices join an established network, edge servers need to reallocate computing and energy resources. Similarly, a traffic flow control system must cope with a changing set of vehicles on the road [13], and the controller needs to adjust its scheduling strategy to avoid congestion. In both examples, adapting the strategy when new entities enter incurs computational complexity. In general, some environment features and information vary over time, so extracting features and information along the time dimension is indispensable for improving a strategy's performance [21]. Recent studies use additional models, such as transfer learning [6,23], to learn features across different numbers of agents. These methods reduce computational overhead and training latency, but they only suit a fixed number of agents entering at a fixed time. Another study proposes a coach-player scheme, COPA [15], to address dynamically entering agents. However, COPA assumes a scenario in which agents can communicate with each other, which increases computational complexity and is incompatible with real physical environments that lack communication.
This work is supported in part by the Shanghai Key Research Lab. of NSAI, China, and the Joint Lab. on Networked AI Edge Computing Fudan University-Changan.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023
M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 671–682, 2023. https://doi.org/10.1007/978-981-99-1639-9_56
G. Fang et al.
To address agents entering at random times, we propose a novel algorithm named the Multi-Agent Recurrent Residual Mix model (MAR2MIX). First, inspired by the Transformer [28] and COPA [15], we mask non-activated agents with a dynamic mask until they are activated in the environment. This mechanism ensures that inactivated agents do not affect the existing agents, and it guarantees that, once activated, these agents can interact with the environment and the existing agents, execute their actions, and train their parameters. Second, MAR2MIX can extract dynamic features from the temporal information produced by the environment when new agents are added. At the same time, it extracts individual agents' global and local information and transmits it through the network. Finally, inspired by Soft Actor-Critic (SAC) [10], we introduce an entropy mechanism into our model. When a new agent enters the environment, state and action selection become more complex, so the agent must explore the action space to construct an effective policy. The entropy mechanism uses stochastic methods to encourage agents to explore the action space adequately. We implement our algorithm in the Multi-Agent Particle Environment (MPE) of OpenAI's simulation platform [18] and compare MAR2MIX with MADDPG [18], MATD3 [1,29] and MASAC [9]. We modified MPE and divided the scenarios into dynamic and non-dynamic categories. In all experimental scenarios, MAR2MIX outperforms the existing models. In dynamic scenarios, our model significantly outperforms the other algorithms in terms of training speed, stability, generalization, and final reward score.
MAR2MIX: A Novel Model for Dynamic Problem in MARL
2 Background
2.1 Problem Formulation
In this work, the dynamic cooperation task is modeled as a decentralized Partially Observable Markov Decision Process (POMDP) [15], described by a tuple of the elements defined below. The global state is denoted by S. While the agents act in the environment, each agent observes only the part of the environment around it; o ∈ O denotes the information that an agent observes. a ∈ A denotes the activated agents, which can interact with the environment, whereas â ∈ Â denotes the inactivated agents, which cannot interact with the environment or with other agents. r ∈ R denotes the agents' rewards. T represents the transition, which records the information of the agents and the environment at each time step. E denotes the entropy. Each agent has an independent policy, denoted by π. After observing, the activated agents generate actions and interact with the environment, which returns a reward r ∼ R(o, a). To measure the value of the environment's states and the agents' actions, we introduce Q^π(o, a), the expected discounted accumulation of R, defined as:
$$Q^\pi(o, a) = \mathbb{E}\left[r_{t+1}(o, a) + \gamma Q^\pi(o_{t+1}, a_{t+1}) \mid o_t = o, a_t = a\right] \qquad (1)$$
The algorithm aims to train each agent so that its policy maximizes the Q function.
2.2 Deep Deterministic Policy Gradient
Currently, MARL algorithms for POMDP [20] problems follow two main approaches: value function decomposition [22,27] and policy gradient [18]. Value function decomposition is based on the Deep Q-Network algorithm (DQN) [19], which computes the loss from the temporal-difference error:
$$J_Q(\theta) = \mathbb{E}\left[\left(Q_\theta(o_t, a_t) - \left(r(o_t, a_t) + \gamma Q^*_{\theta'}(o_{t+1}, a_{t+1})\right)\right)^2\right] \qquad (2)$$
where Q^* denotes the target Q network, and θ and θ' denote the parameters of the current network and the target network, respectively. However, when each agent deploys DQN independently, it is difficult for the agents to complete the task collaboratively. To address this problem, a series of algorithms based on value function decomposition has been proposed, such as QMIX [22], Qtran [25] and VDN [27]. Although these methods apply the CTDE framework, highly centralized value function decomposition is unsuitable for distributed deployment. The other approach, policy gradient, is based on the Deep Deterministic Policy Gradient algorithm (DDPG) [14].
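A minimal sketch of the squared TD loss of Eq. (2) for a batch of transitions follows. It assumes the current and target networks have already produced the Q-values, which are passed in as plain lists; the default gamma = 0.9 mirrors the discount factor reported in Sect. 4.2 and is otherwise an assumption.

```python
def td_loss(q_values, rewards, next_target_q, gamma=0.9):
    """Batch estimate of Eq. (2): mean of (Q(o_t,a_t) - (r + gamma * Q*(o_{t+1},a_{t+1})))^2."""
    if not q_values:
        raise ValueError("empty batch")
    errors = [
        (q - (r + gamma * q_next)) ** 2
        for q, r, q_next in zip(q_values, rewards, next_target_q)
    ]
    return sum(errors) / len(errors)


print(td_loss([2.0, 0.0], [1.0, 1.0], [0.0, 0.0]))  # 1.0
```

Keeping the target values in a separate argument reflects the role of the target network θ': gradients would flow only through `q_values` in a real implementation.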
To maximize the score function $J(\theta^\pi) = \sum_{t=0}^{\infty} \gamma^t r_t$, DDPG uses a deep network to adjust the parameters θ. The optimal policy is defined as:
$$\pi^* = \arg\max_\pi \sum_t \mathbb{E}\left[r(o_t, a_t)\right] \qquad (3)$$
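For a finite episode, the score J(θ^π) = Σ γ^t r_t can be computed with a single backward pass over the rewards. The sketch below is illustrative only; the function name is hypothetical.

```python
def discounted_return(rewards, gamma=0.9):
    """J = sum_t gamma^t r_t for a finite episode, accumulated backwards."""
    total = 0.0
    for r in reversed(rewards):
        total = r + gamma * total
    return total


print(discounted_return([1.0, 1.0, 1.0], gamma=0.5))  # 1.75
```

Accumulating backwards (G_t = r_t + γG_{t+1}) avoids computing powers of γ explicitly and yields the same value as the forward sum.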
Based on the Q function, the policy loss function is defined as:
$$J_\pi(\phi) = \mathbb{E}\left[-Q_\theta(o_t, a_t)\right] \qquad (4)$$
3 Method
In this section, we introduce our method, MAR2MIX. When the agents interact with the environment, they generate transitions that are strongly related to time. A Recurrent Neural Network (RNN) [31] can therefore improve the extraction of time-varying features. To deal with overestimation and improve stability, we introduce the entropy mechanism [10] and residual connections [11].
3.1 RNN Augmented SAC
The SAC method [10] is widely used in reinforcement learning and can be seen as an extension of DDPG. Compared with other mainstream reinforcement learning algorithms, SAC incorporates an entropy term into the value function. The entropy mechanism encourages the agent to fully explore the action space, which can effectively prevent the model from falling into a local optimum. Assuming that x follows the distribution P, the entropy of x is defined as:
$$E(P) = \mathbb{E}_{x \sim P}\left[-\log P(x)\right] \qquad (5)$$
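For a discrete action distribution, Eq. (5) reduces to a short sum. The sketch below uses the natural logarithm (an assumption; the base only rescales the value) and shows that a uniform policy has maximal entropy while a deterministic one has zero.

```python
import math


def entropy(probs):
    """E(P) = E_{x~P}[-log P(x)] for a discrete distribution (natural log)."""
    if abs(sum(probs) - 1.0) > 1e-9:
        raise ValueError("probabilities must sum to 1")
    # Terms with p = 0 contribute nothing (convention 0 * log 0 = 0).
    return sum(-p * math.log(p) for p in probs if p > 0)


print(entropy([0.25, 0.25, 0.25, 0.25]))  # log 4, about 1.386: fully exploratory
print(entropy([1.0, 0.0, 0.0, 0.0]))      # 0.0: deterministic policy
```

Adding α times this quantity to the objective is what rewards the agent for keeping its action distribution spread out.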
The value function Q^π(o, a) is modified as:
$$Q^\pi(o, a) = \mathbb{E}\left[\sum_{t=0}^{\infty} \gamma^t r_t(o_t, a_t) + \alpha \sum_{t=0}^{\infty} \gamma^t H(\pi(o_t)) \,\middle|\, o_t = o, a_t = a\right] \qquad (6)$$
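The bracketed quantity in Eq. (6) can be estimated from a single finite trajectory by adding the discounted, α-weighted policy entropy at each step to the discounted reward. The sketch assumes per-step entropies have already been computed; the defaults gamma = 0.9 and alpha = 0.2 mirror Sect. 4.2.

```python
def soft_return(rewards, entropies, gamma=0.9, alpha=0.2):
    """Single-trajectory estimate of the bracket in Eq. (6):
    sum_t gamma^t r_t + alpha * sum_t gamma^t H(pi(o_t))."""
    return sum(
        gamma ** t * (r + alpha * h)
        for t, (r, h) in enumerate(zip(rewards, entropies))
    )


value = soft_return([1.0, 1.0], [0.5, 0.5], gamma=0.5)  # 1.1 + 0.55
```

The two sums in Eq. (6) share the same discount, so they can be merged into one loop over (r_t + αH_t).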
Therefore, the Q value loss function J_Q(θ) is defined as:
$$J_Q(\theta) = \mathbb{E}\left[\frac{1}{2}\left(Q_\theta(o_t, a_t) - \left(r(o_t, a_t) + \gamma\left(Q_\theta(o_{t+1}, a_{t+1}) - \alpha \log \pi(o_{t+1})\right)\right)\right)^2\right] \qquad (7)$$
Besides, the optimal policy π^* is defined as:
$$\pi^* = \arg\max_\pi \sum_t \mathbb{E}\left[r(o_t, a_t) + \alpha H(\pi(o_t))\right] \qquad (8)$$
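The entropy-regularized TD loss of Eq. (7) can be sketched for a single transition as follows. Here `next_log_prob` is taken to be the log-probability of the sampled next action under the current policy, which is our reading of log π(o_{t+1}); the function names and the defaults (gamma = 0.9, alpha = 0.2, from Sect. 4.2) are assumptions.

```python
import math


def soft_q_target(reward, next_q, next_log_prob, gamma=0.9, alpha=0.2):
    """Entropy-regularized TD target inside Eq. (7): r + gamma * (Q(o', a') - alpha * log pi)."""
    return reward + gamma * (next_q - alpha * next_log_prob)


def soft_q_loss(q, reward, next_q, next_log_prob, gamma=0.9, alpha=0.2):
    """Single-transition version of Eq. (7): 0.5 * (Q(o, a) - target)^2."""
    target = soft_q_target(reward, next_q, next_log_prob, gamma, alpha)
    return 0.5 * (q - target) ** 2


# Because log pi < 0, the entropy term raises the target above the plain TD target of 0.9.
print(soft_q_target(0.0, 1.0, math.log(0.5)))  # about 1.0248
```

Setting `alpha=0.0` recovers the ordinary squared TD error (up to the factor 1/2), which makes the role of the entropy bonus easy to compare.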
The policy loss function J_π(φ) is defined as:
$$J_\pi(\phi) = \mathbb{E}\left[\alpha \log \pi_\phi\left(f_\phi(\epsilon_t; o_t)\right) - Q_\theta\left(o_t, f_\phi(\epsilon_t; o_t)\right)\right] \qquad (9)$$
where ε_t is input noise sampled from a Gaussian distribution, and a_t = f_φ(ε_t; o_t) is an action sampler consistent with the current policy.
Fig. 1. The overall structure of MAR2MIX. The left part shows the structure of the policy network, and the right part shows the training and execution flow.
Just as MADDPG is the multi-agent version of DDPG, SAC can be extended to a Multi-Agent Soft Actor-Critic method (MASAC). However, when agents enter the environment dynamically, the environment's state and the other agents' estimates are affected. The perturbation caused by adding agents affects the network's evaluation of the loss function and the value function. A Multi-Layer Perceptron (MLP), the traditional choice, extracts only limited features from the observations when dealing with the dynamic problem. A dynamic problem directly affects the transitions, and changes in the transitions are tightly related to time. An RNN, in contrast, provides a practical way to extract temporal features. Therefore, we introduce an RNN into the MASAC structure to improve its ability to learn hidden information about dynamic processes.
3.2 Dynamic Mask Technology
Before feeding the observations and actions into the critic network, we use a dynamic mask to handle the entry of new agents. Agents that are present in the environment are called activated, and agents that have not yet entered are called inactivated. Before inactivated agents enter the environment, all their values and parameters are set to zero, and the dynamic mask is not removed until they are activated. Inspired by the padding mask of the Transformer [28] and by COPA [15], the dynamic mask not only ensures that inactive agents do not affect the existing agents but also guarantees that newly activated agents can interact with the other agents after entering the environment, without changing the structure of the critic and policy networks. Other multi-agent algorithms can also be deployed in dynamic environments using this technique.
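The masking step can be sketched as zero-filling fixed-size per-agent slots before the critic input is assembled. The exact layout is not specified in the text, so the concatenation order (each agent's observation followed by its action) and the helper names below are assumptions.

```python
def apply_dynamic_mask(observations, actions, active):
    """Zero the observations and actions of inactivated agents.

    Every agent keeps a fixed-size slot, so the critic's input size never
    changes when an agent is activated; only the mask flips.
    """
    masked_obs = [list(o) if on else [0.0] * len(o) for o, on in zip(observations, active)]
    masked_act = [list(a) if on else [0.0] * len(a) for a, on in zip(actions, active)]
    return masked_obs, masked_act


def critic_input(observations, actions):
    """Concatenate all agents' (masked) observations and actions into one vector."""
    flat = []
    for o, a in zip(observations, actions):
        flat.extend(o)
        flat.extend(a)
    return flat


obs = [[0.3, 0.7], [0.1, 0.9], [0.5, 0.5]]
act = [[1.0], [0.0], [1.0]]
active = [True, True, False]  # the third agent has not entered yet
mo, ma = apply_dynamic_mask(obs, act, active)
print(critic_input(mo, ma))  # [0.3, 0.7, 1.0, 0.1, 0.9, 0.0, 0.0, 0.0, 0.0]
```

When the third agent enters, only `active[2]` changes to `True`; the input length stays 9, so the critic and policy networks need no structural change.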
Algorithm 1: Multi-Agent Recurrent Residual Mix Model
Initialize parameters θ and φ of the networks;
for episode = 1 to max-episode-number do
    Initialize the exploration parameter ε and the transition (o, a, r, o');
    for step = 1 to max-time-step do
        Set the non-activated observation o_i, action a_i and reward r_i to zero;
        if i > ε, randomly select an action a_i from the action space A;
        else select action a_i = π_φi(o_i);
        Execute the activated actions a = (a_1, ..., a_N), get reward r_i and new observation o'_i;
        Store (o, a, r, o') in replay buffer D, o_i ← o'_i;
        for agent i = 1 to N do
            Sample a batch d from D;
            if agent i is not activated then
                Set Q_i, θ_i, φ_i to zero;
            else
                Calculate Q value Q_i with the critic network;
                Calculate target Q value Q'_i with the target critic network;
                Calculate Q loss J_Q(θ) and update θ by minimizing the TD loss;
                Calculate entropy H for agent i, calculate policy loss J_π(φ) and update φ by minimizing J_π(φ);
                θ_i ← θ_i − λ_Q ∇θ_i J(θ),  φ_i ← φ_i − λ_π ∇φ_i J(φ);
                θ̂_i ← τ θ_i + (1 − τ) θ̂_i
3.3 Recurrent Residual MIX Network
Value function decomposition models such as QMIX and VDN use a central Q function to gather global information and train every agent's Q network. Different models decompose the value function differently. VDN decomposes it as the sum of the agents' value functions. QMIX treats the agents' value functions as components of a joint value computed by a MIX network and uses hypernetworks to perform the decomposition. Although both convey global information through a centralized structure, they effectively reduce each agent's ability to compute a local Q value from its own observation. On the other hand, decentralized Q functions are another approach deployed within CTDE, as in MADDPG and MATD3. Agents compute their local Q values with independent critic networks, which lets each critic adapt to local observations. In addition, these models feed the other agents' observations into the agent being trained so that it can capture global features. Without further operations, however, a shallow critic network cannot distinguish between global and local information, and the influence of local information gradually diminishes as the data passes through the neurons. As a result, neither type of model addresses the problem of transferring local Q values.
To solve these problems, we propose a novel CTDE model that combines an independent critic network with a modified MIX network. After the dynamic mask is applied, the observations and actions are fed into the MIX network. As Fig. 1 shows, hypernet W maps the observations and actions of all agents to weights, and hypernet B maps the local agent's observation and action to a bias. The Q-net outputs are multiplied by the weights W and added to the bias B. Since the dynamic environment discussed in Sect. 3.1 is closely related to temporal features, we feed the output of the mix into an RNN with a residual module. The entire procedure of MAR2MIX is described in Algorithm 1. With this design, the information of the activated local agent can be effectively transmitted through the network while the global agent information is retained. In addition, through the global features, the modified mix network can ignore the features of newly entering agents.
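The data flow of the mixing step can be illustrated with scalars. In this sketch the weights and bias are passed in directly instead of being produced by hypernets W and B, and a single-unit tanh RNN with hypothetical parameters `w_x` and `w_h` stands in for the RNN with a residual module, whose exact cell type is not specified in this section.

```python
import math


def mix(q_values, weights, bias):
    """Mixing step: sum_i W_i * Q_i + B. In MAR2MIX, W and B would be produced
    by hypernets W and B; here they are passed in directly."""
    return sum(w * q for w, q in zip(weights, q_values)) + bias


def recurrent_residual(mixed_sequence, w_x=0.5, w_h=0.5):
    """Scalar Elman-style RNN over time with a residual (skip) connection:
    h_t = tanh(w_x * x_t + w_h * h_{t-1}),  y_t = x_t + h_t."""
    h, outputs = 0.0, []
    for x in mixed_sequence:
        h = math.tanh(w_x * x + w_h * h)
        outputs.append(x + h)  # residual path keeps the mixed value itself
    return outputs


mixed = [mix([1.0, 2.0], [0.5, 0.25], 0.1), mix([0.0, 2.0], [0.5, 0.25], 0.1)]
print(mixed)  # [1.1, 0.6]
print(recurrent_residual(mixed))
```

The residual path lets the mixed value pass through unchanged even if the recurrent branch contributes little, which matches the stated goal of transmitting local-agent information through the network while the RNN adds temporal context.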
4 Experiments
In this section, we compare different algorithms through various experiments, and the results show that our algorithm effectively improves performance. First, by conducting experiments in dynamic and non-dynamic environments and analyzing the results, we show that our model has significant advantages in both scenarios. Second, through an ablation study, we confirm that both modules of the model improve its performance.
4.1 Environmental Setup
OpenAI’s multi-agent reinforcement learning environment MPE is used as the primary environment for these experiments. Because this paper focuses on cooperative tasks, we mainly evaluate the algorithms in the Spread scenario of MPE. Since all MPE scenarios are non-dynamic, we follow COPA and modify the Spread scenario into a dynamic environment. In the non-dynamic environment, each agent's task is to reach the target point along the shortest path while avoiding collisions with the other agents. The reward obtained by an agent is the distance between the agent and the target point. When an agent collides with another agent, the reward is set to −4. In the dynamic environment, new agents are activated and enter after the existing agents have run for a random period. The newly entering agents cooperate with the agents already in the environment, and the reward setting is the same as in the non-dynamic environment. In addition, to increase the environment's complexity and verify the algorithms' generalization across environments, we modified the dynamic environment by adding moving obstacles that interfere with the agents' actions. When an agent collides with a moving obstacle, the reward is set to −6.
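The reward rules above can be summarized in a small function. This is a hypothetical sketch, not the modified MPE code: it assumes the distance enters as a negative reward (consistent with the negative scores in Table 1), that a collision replaces the distance reward ("is set to"), and that an obstacle hit takes precedence over an agent-agent collision.

```python
import math

AGENT_COLLISION_REWARD = -4.0     # value stated in the text
OBSTACLE_COLLISION_REWARD = -6.0  # value stated in the text (variant with moving obstacles)


def spread_reward(agent_pos, target_pos, hit_agent=False, hit_obstacle=False):
    """Hypothetical per-agent reward for the modified Spread scenario.

    Assumptions: distance is a negative reward, a collision replaces the
    distance reward, and an obstacle hit overrides an agent collision.
    """
    if hit_obstacle:
        return OBSTACLE_COLLISION_REWARD
    if hit_agent:
        return AGENT_COLLISION_REWARD
    return -math.dist(agent_pos, target_pos)


print(spread_reward((0.0, 0.0), (3.0, 4.0)))                 # -5.0
print(spread_reward((0.0, 0.0), (3.0, 4.0), hit_agent=True))  # -4.0
```

`math.dist` requires Python 3.8 or later.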
Fig. 2. Comparison of the performance of MAR2MIX and other algorithms in four different scenarios.
4.2 Parameter Setting
For a fair comparison, the proposed model and the benchmark models use the same environment parameters. During training, all models run for 10K steps, with every 150 steps forming an episode and a test every 300 steps. We evaluate a model's performance by running ten test episodes and calculating the mean reward. The same network hyperparameters are used for MAR2MIX and the benchmark models. The hidden layers have 256 neurons. The initial ε-greedy value and the learning rate are set to 0.1 and 0.3, respectively. The discount factor is set to 0.9. The temperature coefficient α in MASAC and MAR2MIX is set to 0.2.
4.3 Experimental Results
We compare the proposed MAR2MIX model with MADDPG, MATD3, and MASAC. The training curves are illustrated in Fig. 2. The lines in the figure represent the mean test rewards of the different models, and the shading shows the standard deviation of the rewards. There are two environment settings, dynamic and non-dynamic, each containing two types of tasks. One type of task only requires the agents to avoid colliding with each other while reaching the target point along the shortest path. The other type builds on the first and additionally requires avoiding moving obstacles. The results in Fig. 2(a) and 2(b) show that in the non-dynamic scenario, MADDPG and MATD3 learn more slowly than the proposed MAR2MIX. In the later stage of training, their reward values are also lower and their stability is worse. MAR2MIX learns significantly faster than the other benchmark models, and the shaded area of its curves remains small in both the early and late stages of training, which indicates that the model is stable throughout training. In particular, the average reward obtained by the agents in the later stages of training is significantly higher than that of the other models. MAR2MIX still performs well in dynamic environments, as shown in Fig. 2(c) and 2(d). The performance of MASAC degrades significantly in the dynamic environment compared with the non-dynamic environment, and its final convergence score is lower. In contrast, MAR2MIX remains stable in the dynamic environment. Therefore, MAR2MIX can help the agents learn the time-varying features of the dynamic environment and complete the task.
MAR2MIX: A Novel Model for Dynamic Problem in MARL
Table 1. Mean evaluation reward in the non-dynamic and dynamic scenarios.
(a) Number of existing agents in the non-dynamic scenario.
Model | 3 | 4 | 5 | 6 | 7
MADDPG | −46.7 ± 5.9 | −60.1 ± 4.6 | −76.3 ± 6.7 | −80.0 ± 5.8 | −80.7 ± 6.1
MATD3 | −59.0 ± 9.2 | −61.3 ± 5.9 | −71.8 ± 5.8 | −80.1 ± 6.0 | −82.7 ± 3.9
MASAC | −46.6 ± 4.4 | −59.0 ± 6.7 | −67.7 ± 4.9 | −76.9 ± 8.1 | −84.5 ± 6.3
MAR2MIX | −19.5 ± 3.1 | −33.7 ± 4.9 | −48.3 ± 3.7 | −57.7 ± 5.3 | −68.9 ± 5.6
(b) Number of entering agents, with two agents already present, in the dynamic scenario.
Model | 1 | 2 | 3 | 4 | 5
MADDPG | −50.0 ± 5.3 | −57.0 ± 4.2 | −68.2 ± 4.1 | −66.5 ± 5.6 | −75.3 ± 3.4
MATD3 | −52.9 ± 5.1 | −61.2 ± 5.8 | −67.2 ± 6.3 | −68.0 ± 4.8 | −71.3 ± 3.8
MASAC | −44.2 ± 8.6 | −46.0 ± 4.4 | −66.3 ± 6.7 | −69.8 ± 4.4 | −68.8 ± 5.3
MAR2MIX | −21.3 ± 2.6 | −32.8 ± 3.9 | −46.8 ± 4.4 | −50.1 ± 7.4 | −55.1 ± 6.5
4.4 Ablation Study
To verify the generalization ability and robustness of the proposed model to environmental changes, we increase the number of agents and assess the model’s performance by tracking the change in the agents’ reward values over several runs. Table 1(a) and Fig. 3(a) show the evaluation rewards of the models with different numbers of agents in the non-dynamic scenario. Note that increasing the number of agents causes fluctuations in the overall reward values, but the graphs still reflect how each model copes with changes in the number of agents. MAR2MIX remains more stable than the other models even with a large number of agents, as its rewards decrease only slightly in comparison. The small fluctuation in the standard deviation of MAR2MIX further suggests that it is more robust than the benchmark models. Table 1(b) and Fig. 3(b) show the overall performance of the models as the number of entering agents increases. As shown in Fig. 3, MAR2MIX maintains a high reward value, with only a slight decrease in performance as the number of entering agents grows. The experimental results demonstrate that MAR2MIX handles the agent-joining problem well in different scenarios while maintaining the stability and training efficiency of the model.
Fig. 3. (a) Results for different numbers of agents in a non-dynamic environment. (b) Results for different numbers of entering agents when two agents are already present in a dynamic environment.
5 Conclusion
This paper proposes MAR2MIX, a multi-agent reinforcement learning model for cooperation tasks in which agents join dynamically. The model uses MASAC as its backbone and handles dynamic environments with dynamic masking techniques. Within the model, we propose a novel distributed policy network, R2MIX, which extracts the state features of the global agents, while the MIX network ensures that the features of local agents are not lost during network transmission. Meanwhile, to extract highly time-correlated dynamic features, we add recurrent neural networks and residual blocks to the network. The experimental results show that MAR2MIX can efficiently accomplish cooperative tasks in both dynamic and non-dynamic environments, and that the model learns quickly and has good robustness and generalization. In future work, we will explore further ways to improve the model’s performance by extracting transitions more efficiently.
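The conclusion credits dynamic masking for letting the model handle agents that join mid-episode. The sketch below shows one common way such masking can be done and is not the authors' implementation: it assumes per-agent features are zero-padded to a fixed maximum team size and that a binary mask marks which slots hold agents that have already entered. The function names `masked_mean_pool` and `masked_softmax` are hypothetical.

```python
import numpy as np

def masked_mean_pool(features, mask):
    """Average per-agent features over present agents only.

    features: (batch, max_agents, dim); mask: (batch, max_agents), 1 = present.
    """
    m = mask[..., None].astype(features.dtype)
    summed = (features * m).sum(axis=1)
    count = np.maximum(m.sum(axis=1), 1.0)  # avoid division by zero if no agent is present
    return summed / count

def masked_softmax(logits, mask):
    """Softmax over agents that gives zero weight to absent slots."""
    masked = np.where(mask.astype(bool), logits, -1e9)
    masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(masked) * mask                              # zero out absent agents exactly
    return w / np.maximum(w.sum(axis=-1, keepdims=True), 1e-12)
```

Because the padded slots never contribute, the same network weights can be applied whether two or seven agents are currently in the environment.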
Adversarial Training with Knowledge Distillation Considering Intermediate Representations in CNNs
Hikaru Higuchi1, Satoshi Suzuki2, and Hayaru Shouno1(B)
1 The University of Electro-Communications, Chofu, Tokyo, Japan {h.higuchi,shouno}@uec.ac.jp
2 NTT Computer and Data Science Laboratories, Yokosuka, Kanagawa, Japan [email protected]
Abstract. A main challenge in training convolutional neural networks (CNNs) is improving their robustness against adversarial examples, which are images with added artificial perturbations that induce misclassification in a CNN. So far, only adversarial training, which uses adversarial examples rather than natural images for CNN training, has been found to address this challenge. Since its introduction, adversarial training has been continuously refined from various points of view. Some methods constrain the CNN outputs for adversarial examples and natural images to be similar, which resembles knowledge distillation. Knowledge distillation was originally intended to constrain the outputs of teacher and student CNNs to promote generalization of the student CNN. However, recent knowledge distillation methods constrain intermediate representations rather than outputs to improve performance on natural images, because doing so directly helps preserve intraclass cohesiveness. To further investigate adversarial training using this recent knowledge distillation methodology (i.e., constraining intermediate representations), we evaluated this approach and compared it with conventional ones. We first visualized intermediate representations and experimentally found that cohesiveness is essential to properly classify not only natural images but also adversarial examples. Then, we devised knowledge distillation using intermediate representations for adversarial training and demonstrated that it improves accuracy over output constraining for classifying both natural images and adversarial examples.
Keywords: Convolutional neural network · Adversarial training · Knowledge distillation · Intermediate representation · Manifold hypothesis
1 Introduction
Convolutional neural networks (CNNs) play a central role in computer vision for tasks such as image classification [4,6,11]. However, recent studies have demonstrated that adversarial perturbations, which are artificially crafted to induce misclassification in a CNN, can drastically decrease classification accuracy [16]. In general, humans can naturally and correctly classify adversarial examples, which are images with adversarial perturbations. Since CNNs were originally inspired by the human visual system [4], they should be able to treat adversarial examples in the same way as natural images, as humans do. Thus, a main challenge in training CNNs is improving their robustness against adversarial examples.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 683–691, 2023. https://doi.org/10.1007/978-981-99-1639-9_57
H. Higuchi et al.
To correctly classify adversarial examples, Madry et al. [13] introduced adversarial training, which uses adversarial examples instead of natural images for CNN training (Fig. 1(a)). Athalye et al. [1] found that, among the diverse defenses explored, only adversarial training improves classification robustness against adversarial examples. Therefore, subsequent studies have focused on improving adversarial training [9,10,15,17]. For instance, various methods improve robustness by constraining the CNN output: ALP [10] and TRADES [17] force the CNN outputs for adversarial examples and natural images to be similar during adversarial training (Fig. 1(b)). Hence, the corresponding CNNs provide similar outputs regardless of adversarial perturbation. More recent methods such as Smooth Logits [2] and LBGAT [3] employ knowledge distillation, whose constraints bring the outputs of a student (adversarially trained) CNN closer to those of a teacher (pretrained) CNN (Fig. 1(c)). Knowledge distillation is effective for adversarial training because it enables the student CNN to imitate the decision boundary of the teacher CNN, which is sufficiently generalized after pretraining. Notably, knowledge distillation using intermediate representations rather than outputs can further improve performance on ordinary natural image classification [7,14].
This is because intermediate representations readily determine the decision boundaries between classes and preserve intraclass cohesiveness. As the adversarial training methods in [2,3,10,17] focus only on outputs, the resulting CNNs may not properly preserve intraclass cohesiveness. In contrast, if an adversarially trained CNN can use intermediate representations to classify natural images and adversarial examples similarly, its performance may improve (Fig. 1(d)). Thus, the intermediate representations of CNNs during adversarial training warrant further exploration. Accordingly, we analyzed CNNs trained with adversarial training by 1) visualizing their intermediate representations and 2) applying knowledge distillation to intermediate representations to improve the performance of adversarial training. The contributions of this study can be summarized as follows:
– We confirm phenomena observed in the intermediate representations of CNNs trained with adversarial training.
– We visualize intermediate representations and experimentally verify that cohesiveness is essential to correctly classify not only natural images but also adversarial examples.
– We introduce knowledge distillation using intermediate representations and demonstrate that this method is more effective than knowledge distillation at the outputs for improving the classification accuracy of both natural images and adversarial examples.
Adversarial Training with Knowledge Distillation
2 Experimental Analysis in CNNs with Adversarial Training
We begin by formulating adversarial training and then investigate phenomena observed in the corresponding CNNs, which are based on the ResNet-18 architecture [6] (Fig. 2).
Fig. 1. (a) – (c): Conventional methods for comparison. (d): Our proposed method based on adversarial training.
2.1 Adversarial Training
CNN training using natural images can be formulated as follows:
$$\min_{\theta} \mathbb{E}_{p(x,y)}\left[ L_{\mathrm{class}}(x, y; \theta) \right], \tag{1}$$
where (x, y) is a pair of an input image and its true label, and $L_{\mathrm{class}}(x, y; \theta)$ is the classification loss given CNN parameters θ and data (x, y). Cross-entropy is commonly used as the classification loss. Equation (1) can be interpreted as an optimization problem that searches for parameters θ minimizing the classification loss $L_{\mathrm{class}}$.
Meanwhile, adversarial training is formulated as follows:
$$\min_{\theta} \mathbb{E}_{p(x,y)}\left[ \max_{\delta:\|\delta\|_q \le \epsilon} L_{\mathrm{class}}(x + \delta, y; \theta) \right], \tag{2}$$
where δ is an adversarial perturbation bounded by an $\ell_q$-norm ball of radius ε. Equation (2) solves the minimization in Eq. (1) after solving the inner maximization of the loss $L_{\mathrm{class}}$ with respect to δ, given θ, x, and y. This min–max optimization is repeated to obtain parameters θ that are robust to adversarial examples. As the maximization problem in Eq. (2) cannot be solved explicitly, it is often approximated by a strong attack method called projected gradient descent (PGD) [13].
2.2 Visualization of Intermediate Representations in CNNs
We also compare the intermediate representations of vanilla-CNN, trained only with natural images, and adv-CNN, trained with conventional adversarial training [13]. Specifically, we visualize and compare the intermediate representations of the two CNNs using t-SNE [12] for dimensionality reduction. We use ResNet-18 [6] (Fig. 2) as the CNN and PGD as the adversarial attack [13]. PGD performs strong adversarial attacks by repeatedly generating adversarial perturbations with the fast gradient sign method (FGSM) [5]. In this study, we used 10 and 20 PGD iterations during training and testing, respectively, and CIFAR-10 as the image classification dataset.
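PGD as described above can be sketched as follows. This is a minimal NumPy illustration under stated assumptions, not the authors' code: a logistic-regression model stands in for the CNN so that the input gradient is analytic, the perturbation lies in an $\ell_\infty$ ball (q = ∞), and eps = 8/255 and step = 2/255 are common CIFAR-10 choices rather than values reported here. Each iteration is an FGSM step followed by projection back onto the ε-ball and the valid pixel range.

```python
import numpy as np

def loss_and_input_grad(x, y, w, b):
    """Binary cross-entropy of a logistic-regression stand-in for the CNN,
    and its gradient with respect to the input x.
    x: (batch, d) in [0, 1]; y: (batch,) labels in {0, 1}."""
    p = 1.0 / (1.0 + np.exp(-(x @ w + b)))
    loss = -(y * np.log(p + 1e-12) + (1 - y) * np.log(1 - p + 1e-12))
    grad_x = (p - y)[:, None] * w[None, :]  # dL/dx for logistic regression
    return loss, grad_x

def pgd_linf(x, y, w, b, eps=8 / 255, step=2 / 255, iters=10, seed=0):
    """Approximate the inner max of Eq. (2) with l_inf-bounded PGD."""
    rng = np.random.default_rng(seed)
    delta = rng.uniform(-eps, eps, size=x.shape)  # random start inside the ball
    for _ in range(iters):
        _, g = loss_and_input_grad(x + delta, y, w, b)
        delta = delta + step * np.sign(g)          # FGSM-style ascent step
        delta = np.clip(delta, -eps, eps)          # project onto the eps-ball
        delta = np.clip(x + delta, 0.0, 1.0) - x   # keep pixel values valid
    return x + delta
```

Mirroring the setting in the text, `iters=10` would correspond to the training-time attack and `iters=20` to the test-time attack; with a real CNN, the analytic gradient would be replaced by one obtained through automatic differentiation.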
Fig. 2. ResNet-18 architecture (L is described in Eq. (3)).
Fig. 3. Visualization of intermediate representations in vanilla-CNN (first two rows) and adv-CNN (last two rows) for natural images and adversarial examples.
The first two rows of Fig. 3 show the intermediate representations in vanilla-CNN after dimensionality reduction. The upper-left graph of Fig. 3 shows the intermediate representations for natural images, while the upper-right graph shows those for adversarial examples. For natural images, vanilla-CNN clusters the intermediate representations of each class well, and these clusters contribute to its high classification accuracy. When the natural images are affected by adversarial perturbations, the clusters are dispersed in the feature space. Hence, adversarial examples degrade intraclass cohesiveness and cause a drastic decrease in classification accuracy. The last two rows of Fig. 3 show the intermediate representations in adv-CNN. As shown in the figure, adv-CNN yields similar intermediate representations for adversarial examples and natural images. However, its intermediate representations for natural images are inferior to those of vanilla-CNN (lower-left graph of Fig. 3). In fact, adv-CNN achieves lower accuracy than vanilla-CNN on natural images because it cannot establish clear decision boundaries for such images.
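The discussion above judges intraclass cohesiveness visually from t-SNE plots. One simple way to put a number on it, offered here as a hypothetical illustration that the paper does not use, is the ratio of the mean distance from each sample to its class centroid to the mean distance between class centroids; lower values indicate tighter, better-separated clusters. The function name `cohesion_ratio` is an assumption, the features would be taken from a chosen layer before any dimensionality reduction, and at least two classes are required.

```python
import numpy as np

def cohesion_ratio(feats, labels):
    """Mean within-class spread divided by mean between-centroid distance.

    feats: (n, d) intermediate representations; labels: (n,) class ids.
    """
    classes = np.unique(labels)  # sorted unique class ids
    centroids = np.stack([feats[labels == c].mean(axis=0) for c in classes])
    # Mean distance of each sample to its own class centroid.
    idx = np.searchsorted(classes, labels)
    intra = np.linalg.norm(feats - centroids[idx], axis=1).mean()
    # Mean pairwise distance between distinct class centroids.
    d = np.linalg.norm(centroids[:, None, :] - centroids[None, :, :], axis=-1)
    iu = np.triu_indices(len(classes), k=1)
    inter = d[iu].mean()
    return intra / inter
```

Under this measure, the dispersion of clusters seen for vanilla-CNN on adversarial examples would show up as a larger ratio than for natural images.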
Fig. 4. Diagram of proposed method.
3 Proposed Method: Adversarial Training with Knowledge Distillation
As mentioned above, vanilla-CNN acquires effective representations for classifying natural images. Therefore, in this section, we propose a novel method that adversarially trains the CNN while constraining its representations to preserve those of vanilla-CNN for natural images.
3.1 Knowledge Distillation
Knowledge distillation [8] transfers knowledge from a teacher model to a student model by constraining the student’s output to match that of the teacher, which improves the performance of the student model (the training target). Among the many knowledge distillation methods, we employ one based on an intermediate constraint loss, which brings the intermediate representations of the student model closer to those of the teacher model [7,14].
3.2 Adversarial Training with Knowledge Distillation
We propose an adversarial training method with knowledge distillation that employs a CNN trained with natural images as the teacher model. Figure 4 shows a diagram of the proposed method. The student model is the target of adversarial training, and the teacher vanilla-CNN accurately classifies natural images. We aim to make the intermediate representations of the training target similar to those of the teacher vanilla-CNN. Equation (3) formulates the method as an optimization problem:
$$\min_{\theta} \mathbb{E}_{p(x,y)}\left[ \max_{\delta:\|\delta\|_q \le \epsilon} \left( L_{\mathrm{class}}(x + \delta, y; \theta) + \alpha \cdot L_{\mathrm{inter}}\left(f^{L}_{\mathrm{student}}(x + \delta), f^{L}_{\mathrm{teacher}}(x)\right) \right) \right]. \tag{3}$$
The loss in Eq. (3) consists of two terms: the classification loss $L_{\mathrm{class}}$ and the intermediate constraint loss $L_{\mathrm{inter}}$. Here, $f^{L}$ denotes the intermediate representation at layer L, and α is a hyperparameter that determines the contribution of $L_{\mathrm{inter}}$ to training. $L_{\mathrm{class}}$ improves the classification accuracy for adversarial examples x + δ and is the same loss as in conventional adversarial training [13], while $L_{\mathrm{inter}}$ makes the intermediate representations of the student model, $f^{L}_{\mathrm{student}}(x + \delta)$, similar to those of vanilla-CNN, $f^{L}_{\mathrm{teacher}}(x)$.
Table 1. Classification accuracy of evaluated CNNs. The value in boldface indicates the best result, and the underlined value indicates the second best result in each column.
Model | Alpha | Accuracy (natural) | Accuracy (adv)
vanilla-CNN | - | 0.949 | 0.0
adv-CNN | - | 0.847 | 0.483
outKD-CNN | 0.01 | 0.849 | 0.486
outKD-CNN | 0.1 | 0.855 | 0.500
outKD-CNN | 0.5 | 0.857 | 0.502
interKD-CNN | 1 | 0.850 | 0.493
interKD-CNN | 50 | 0.870 | 0.522
interKD-CNN | 100 | 0.866 | 0.521
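The bracketed loss of Eq. (3) can be sketched as below for a single mini-batch. This is a minimal NumPy illustration, not the authors' code: it assumes the adversarial examples x + δ have already been produced (for example, by a PGD routine) and that features are plain arrays; $L_{\mathrm{inter}}$ is the mean squared error, as in the experimental setup, and the function names are hypothetical. Passing output logits as the features corresponds to the outKD-CNN setting (L = 20), while passing an intermediate feature map corresponds to interKD-CNN.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean softmax cross-entropy (L_class) over a batch."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def inter_mse(f_student, f_teacher):
    """L_inter as the mean squared error between feature maps."""
    return np.mean((f_student - f_teacher) ** 2)

def kd_adv_loss(student_logits_adv, labels, f_student_adv, f_teacher_nat, alpha):
    """Loss inside the brackets of Eq. (3) for one mini-batch.

    student_logits_adv: student outputs on x + delta, shape (batch, classes)
    f_student_adv:      student features at layer L on x + delta
    f_teacher_nat:      frozen teacher features at layer L on clean x
    """
    return (cross_entropy(student_logits_adv, labels)
            + alpha * inter_mse(f_student_adv, f_teacher_nat))
```

In an autodiff framework, the teacher features would be treated as constants so that gradients flow only into the student.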
Fig. 5. Visualization of intermediate representations in outKD-CNN (first two rows) and interKD-CNN (last two rows) for natural images and adversarial examples.
4 Experimental Evaluation
We compare the proposed method with output constraining [2,3] and evaluate the effectiveness of each constraint.
4.1 Experimental Setup
We conducted experiments under the same conditions as in the experimental analysis reported in Sect. 2, using the mean squared error as the intermediate constraint loss $L_{\mathrm{inter}}$. We denote the CNN trained using the proposed method (Eq. (3), Fig. 4) with constrained layer L = 20 (i.e., using outputs as constraints) as outKD-CNN, and the CNN with constrained layer L < 20 (i.e., using intermediate representations as constraints) as interKD-CNN.
4.2 Classification Accuracy
Table 1 lists the classification accuracy of vanilla-CNN trained only with natural images, adv-CNN trained with conventional adversarial training [13], outKD-CNN [2,3], and interKD-CNN. We evaluated weights α ∈ {0.01, 0.1, 0.5} for outKD-CNN and α ∈ {1, 50, 100} for interKD-CNN. The CNNs trained with adversarial training and knowledge distillation (outKD-CNN and interKD-CNN) tend to achieve higher accuracy than adv-CNN for both natural images and adversarial examples. InterKD-CNN (α = 50, L = 17) achieves the highest accuracy for adversarial examples and the second highest accuracy for natural images among the evaluated CNNs, even outperforming outKD-CNN. Thus, constraining intermediate representations appears more effective for improving classification accuracy than constraining outputs.
4.3 Visualization of Intermediate Representations
To assess the representations learned with the proposed method, we examined the intermediate representations of the trained CNNs, as in Sect. 2. The first two rows and the last two rows of Fig. 5 show the intermediate representations obtained from interKD-CNN (α = 50, L = 17) and outKD-CNN (α = 0.5, L = 20), respectively. As shown in Fig. 5, interKD-CNN clearly has more cohesive intermediate representations than outKD-CNN, as expected. Hence, knowledge distillation in interKD-CNN effectively acts as an anchor that preserves the class-wise representations that vanilla-CNN provides for natural images, which improves classification accuracy.
5 Conclusions
After evaluating intermediate representations in CNNs, we found that training using only natural images provides effective intermediate representations for classifying natural images, while conventional adversarial training does not. This indicates that intraclass cohesiveness is important for correctly classifying natural images. Accordingly, we proposed a knowledge distillation method that transfers intermediate representations from a teacher CNN trained only on natural images to a student CNN trained adversarially. This method preserves the teacher’s representations for natural images and achieves higher accuracy than CNNs with conventional adversarial training. In future work, we will further explore effective training methods that preserve representations for adversarial examples and achieve higher classification performance. In addition, this study used the mean squared error as the intermediate constraint loss to make the intermediate representations of natural images and adversarial examples similar, but this loss may not be the most appropriate choice. We will therefore explore loss functions better suited to this constraint by considering the characteristics of intermediate representations (e.g., manifolds).
Acknowledgement. This study was partly supported by MEXT KAKENHI, Grant-in-Aid for Scientific Research on Innovative Areas 19H04982 and Grant-in-Aid for Scientific Research (A) 18H04106.
References
1. Athalye, A., Carlini, N., Wagner, D.: Obfuscated gradients give a false sense of security: circumventing defenses to adversarial examples. In: International Conference on Machine Learning (ICML) (2018)
2. Chen, T., Zhang, Z., Liu, S., Chang, S., Wang, Z.: Robust overfitting may be mitigated by properly learned smoothening. In: International Conference on Learning Representations (ICLR) (2020)
3. Cui, J., Liu, S., Wang, L., Jia, J.: Learnable boundary guided adversarial training. In: IEEE International Conference on Computer Vision (ICCV) (2021)
4. Fukushima, K.: Neocognitron: a self-organizing neural network model for a mechanism of pattern recognition unaffected by shift in position. Biol. Cybern. 36(4), 193–202 (1980)
5. Goodfellow, I.J., Shlens, J., Szegedy, C.: Explaining and harnessing adversarial examples. arXiv preprint arXiv:1412.6572 (2014)
6. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2016)
7. Heo, B., Kim, J., Yun, S., Park, H., Kwak, N., Choi, J.Y.: A comprehensive overhaul of feature distillation. In: IEEE International Conference on Computer Vision (ICCV) (2019)
8. Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531 (2015)
9. Ilyas, A., Santurkar, S., Tsipras, D., Engstrom, L., Tran, B., Madry, A.: Adversarial examples are not bugs, they are features. In: Advances in Neural Information Processing Systems (NeurIPS) (2019)
10. Kannan, H., Kurakin, A., Goodfellow, I.: Adversarial logit pairing. arXiv preprint arXiv:1803.06373 (2018)
11. Krizhevsky, A., Sutskever, I., Hinton, G.E.: ImageNet classification with deep convolutional neural networks. In: Advances in Neural Information Processing Systems (NeurIPS) (2012)
12. Van der Maaten, L., Hinton, G.: Visualizing data using t-SNE. J. Mach. Learn. Res. (JMLR) (2008)
13. Madry, A., Makelov, A., Schmidt, L., Tsipras, D., Vladu, A.: Towards deep learning models resistant to adversarial attacks. In: International Conference on Learning Representations (ICLR) (2018)
14. Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C., Bengio, Y.: Fitnets: hints for thin deep nets. arXiv preprint arXiv:1412.6550 (2014)
15. Stutz, D., Hein, M., Schiele, B.: Disentangling adversarial robustness and generalization. In: IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2019)
16. Szegedy, C., et al.: Intriguing properties of neural networks. In: International Conference on Learning Representations (ICLR) (2014)
17. Zhang, H., Yu, Y., Jiao, J., Xing, E., El Ghaoui, L., Jordan, M.: Theoretically principled trade-off between robustness and accuracy. In: International Conference on Machine Learning (ICML) (2019)
Deep Contrastive Multi-view Subspace Clustering
Lei Cheng1, Yongyong Chen1,2, and Zhongyun Hua1,2(B)
1 School of Computer Science and Technology, Harbin Institute of Technology, Shenzhen, Shenzhen 518055, China {cyy2020,huazhongyun}@hit.edu.cn
2 Guangdong Provincial Key Laboratory of Novel Security Intelligence Technologies, Harbin Institute of Technology, Shenzhen, Shenzhen 518055, China
Abstract. Multi-view subspace clustering has become a popular unsupervised learning task because it can effectively fuse complementary information from multiple views. However, most existing methods either fail to incorporate the clustering process into the feature learning process or cannot integrate multi-view relationships well into the data reconstruction process, which degrades the final clustering performance. To overcome these shortcomings, we propose the deep contrastive multi-view subspace clustering method (DCMSC), which is the first attempt to integrate contrastive learning into deep multi-view subspace clustering. Specifically, DCMSC uses multiple autoencoders for self-expression learning to learn a self-representation matrix for each view; these matrices are fused into one unified self-representation matrix to effectively utilize the consistency and complementarity of multiple views. Meanwhile, to further exploit multi-view relations, DCMSC introduces contrastive learning into the multi-autoencoder network and uses the Hilbert Schmidt Independence Criterion (HSIC) to better exploit complementarity. Extensive experiments on several real-world multi-view datasets, in comparison with state-of-the-art multi-view clustering methods, demonstrate the effectiveness of the proposed method.
Keywords: Multi-view subspace clustering · Contrastive learning · Hilbert Schmidt Independence Criterion
1 Introduction
Multi-view clustering, which finds a consensus segmentation of data across multiple views, has become a popular unsupervised learning topic. Unlike single-view clustering, it deals with multiple different descriptions or sources of the same data. How to fully exploit the consistency and complementarity of different views is the most significant challenge for multi-view clustering. Multi-view clustering has made great progress and has played an important role in many practical applications. Most traditional methods first learn one common representation and then apply a single-view clustering method. However, they ignore the high dimensionality of the data, and their performance degrades greatly when the dimensions of the views are extremely unbalanced.
© The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 692–704, 2023. https://doi.org/10.1007/978-981-99-1639-9_58
Subspace clustering aims to find the underlying subspace structures of data under the popular assumption that high-dimensional data can be well described by several low-dimensional subspaces. Recently, self-representation-based subspace clustering models have achieved great success. They assume that each data point can be represented as a linear combination of other data points. Given a single-view data matrix consisting of n column vectors, where each column represents a sample, the self-representation property can be formalized as follows:
$$\min_{C} L(X, C) + R(C) \quad \text{s.t.} \quad X = XC, \tag{1}$$
where $X = [x_1, x_2, \cdots, x_n]$, $\mathcal{L}(X, C)$ denotes the self-representation loss, and $\mathcal{R}(C)$ is the regularization term. Multi-view subspace clustering methods have achieved great success by extending single-view subspace clustering methods. In general, there are two main ways to exploit multi-view information. The first is to learn a common representation and then conduct self-representation on it; the most direct example is to concatenate the features of all views into one combined feature. The second is to conduct self-representation on each view separately and then fuse the results. In recent years, deep learning-based methods have been proposed to address insufficient representation ability and the possible non-linearity of the original data. For example, [1,3] design a unified network architecture composed of multiple autoencoders that integrates feature learning and multi-view relationship exploration into data clustering. Among them, [3] proposes the Multi-view Deep Subspace Clustering Network (MvDSCN), which learns the multi-view self-representation relation in an end-to-end manner. Although these methods achieve good results, they still have the following limitations: (1) data representation learning and clustering are handled independently, so the multi-view relationship cannot play a role in feature extraction; (2) multi-view data reconstruction is performed independently within each view, which ignores the comprehensive information from multiple views; (3) they cannot effectively handle imbalanced dimensions across views. To overcome these shortcomings, we propose the Deep Contrastive Multi-view Subspace Clustering (DCMSC) method, which mainly consists of a base network that conducts self-representation learning and an additional module comprising a Hilbert Schmidt Independence Criterion (HSIC) regularizer and a contrastive penalty.
Specifically, the base network includes $V$ independent autoencoders, each of which extracts the latent features of one original view, and a fully connected layer between each encoder and decoder yields that view's self-representation matrix. The HSIC and contrastive learning parts exploit multi-view relationships to mine the consistency and complementarity of multi-view data: the HSIC module penalizes the dependencies between the representations of different views and thus promotes the diversity of the subspace representations, while the contrastive learning module enforces the consistency of each sample point across views. Finally, the combination of the self-representation matrices obtained by all view-specific autoencoders is used to build the similarity matrix, and the final clustering result is obtained by the spectral clustering algorithm.

L. Cheng et al.

In summary, our contributions include:
1. For the first time, we integrate the general idea of contrastive learning into the multi-view subspace clustering problem and propose the deep contrastive multi-view subspace clustering method.
2. In DCMSC, the base network learns the view-specific self-representation matrices under the constraints of the additional module, which makes good use of the multi-view relationship, so that the fusion of the $V$ learned self-representation matrices has stronger representation ability and achieves better clustering performance.
3. The contrastive learning module regards the different views as data-augmented versions of the same samples and aims to explore the common semantics among views, while the Hilbert Schmidt Independence Criterion is used to discover the diversity of multi-view features.

Extensive experiments on a wide range of datasets demonstrate that DCMSC achieves state-of-the-art clustering performance.
2 Related Works
2.1 Subspace Clustering
Subspace clustering aims to reveal the inherent clustering structure of data drawn from multiple subspaces. Given a set of $N$ samples $X = [x_1, x_2, \cdots, x_N] \in \mathbb{R}^{d \times N}$, where $d$ denotes the data dimension, the basic model of self-representation subspace clustering methods can be described as follows:

$$\min_{C}\; \|C\|_p + \|X - XC\|_F^2, \tag{2}$$
where $\|\cdot\|_p$ is an arbitrary matrix norm. After solving the above problem, the self-representation matrix $C$ describes the subspace relationship between data points and is then fed into the spectral clustering algorithm to obtain the final clustering result. For example, Sparse Subspace Clustering (SSC) [21] enhances the sparsity of the self-representation by imposing $\ell_1$-norm regularization on the self-representation matrix. To discover multi-subspace structures, Low-Rank Representation (LRR) [22] explores the block-diagonal property of the self-representation matrix. Essentially, self-representation-based methods depend on the assumption that each data point can be reconstructed by a linear combination of other points. However, real data may not satisfy this assumption. Many researchers have proposed using the kernel trick [4] to solve this problem, but kernel techniques are heuristic. Exploiting the powerful representation ability of neural networks, a large number of deep subspace clustering networks have been proposed that embed self-representation into a deep autoencoder through fully connected layers, achieving state-of-the-art performance. Deep adversarial subspace clustering [23] leverages the idea of generative adversarial networks and adds a GAN-like model to the self-representation loss to evaluate clustering performance.

2.2 Multi-view Subspace Clustering
Multi-view clustering deals with multiple representations of the same data. Compared with single-view data, multi-view data contain both consensus information and complementary information from multiple views [5,6], and how to effectively fuse the information of each view is the key to the multi-view clustering task. Existing multi-view subspace clustering methods fall into three main categories. The first category performs self-representation learning on each view individually and then fuses the individual self-representations. Diversity-induced Multi-view Subspace Clustering (DiMSC) [5] exploits the complementarity of multi-view data by reducing redundancy. The second category first learns a common latent representation and then performs self-representation learning on it. Latent Multi-view Subspace Clustering (LMSC) [6] explores complementary information from different views while building the latent representation. The third category combines the above two ideas. Reciprocal Multi-layer Subspace Learning (RMSL) [2] simultaneously constructs view-specific subspace representations and a common representation that mutually recover the subspace structure of the data, through Backward Encoding Networks (BEN) and Hierarchical Self-Representative Layers (HSRL). Multi-view Deep Subspace Clustering Network (MvDSCN) [3] simultaneously learns view-specific and common self-representations and leverages HSIC to capture nonlinear and higher-order inter-view relationships. However, due to their complex network structures and objective designs, it is difficult to optimize all objective functions well at the same time, and these methods ignore the role each view plays in exploring the data representation and clustering structure.

2.3 Contrastive Learning
Contrastive learning [7,8] is a popular unsupervised learning paradigm whose main idea is to make positive pairs as similar as possible while making negative pairs as dissimilar as possible. This paradigm has achieved great success in computer vision [9]. For example, [11] proposed a one-stage online clustering method that conducts contrastive learning at both the instance level and the cluster level. [12,13] introduced contrastive learning into multi-view representation learning; for example, a contrastive multiview coding framework [13] was designed to capture latent scene semantics. Multi-level Feature Learning for Contrastive Multi-view Clustering (MFLVC) [10] proposes a flexible multi-view contrastive learning framework that simultaneously achieves the consistency objectives of high-level features and semantic labels. However, to the best of our knowledge, no existing work regards each view as a data-augmented version in contrastive learning and applies contrastive learning to the multi-view subspace clustering task.
3 Proposed Method

Given a dataset with $V$ views $\{X^v \in \mathbb{R}^{d_v \times N}\}_{v=1}^{V}$, where $X^v = [x_1^v, x_2^v, \cdots, x_N^v]$, and $d_v$ and $N$ denote the feature dimension of the $v$-th view and the number of data points, respectively, our goal is to find a common self-representation matrix $C$ that expresses the relationships between data points across the multi-view data. In this section, we describe Deep Contrastive Multi-view Subspace Clustering (DCMSC) in detail.
Fig. 1. Illustration of the proposed Deep Contrastive Multi-view Subspace Clustering (DCMSC) method. DCMSC builds $V$ parallel autoencoders to extract latent features from the view-specific data, and self-representation learning is conducted by a fully connected layer between each encoder and decoder. Specifically, the $v$-th original view $X^v$ is encoded as $Z^v$ and reconstructed as $\hat{X}^v$ through the decoder. The self-representation matrix $C^v$ is obtained by building a fully connected layer without activation function between the encoder and the decoder. A contrastive learning module is introduced into the network to exploit more common semantic information, and the HSIC constraint exploits the complementary information in multi-view data; the $v$-th high-level semantic representation $H^v$ is used to construct the contrastive loss. The final combination of all view-specific self-representation matrices further integrates the complementary and consistent information from multiple views.
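The self-representation layer in Fig. 1 is a bias-free, activation-free fully connected layer whose $N \times N$ weight matrix acts across samples: given latent codes $Z$ (features × samples), it outputs $ZC$. The NumPy sketch below is an illustration under stated assumptions, not the authors' TensorFlow implementation. It holds $Z$ fixed and minimizes a squared-Frobenius variant of Eq. (2) applied to latent codes, $\lambda\|C\|_F^2 + \|Z - ZC\|_F^2$, where the weight `lam` ($\lambda$) is a hypothetical addition ($\lambda = 1$ gives the unweighted form). Setting the gradient to zero yields the closed form $C = (Z^{\top}Z + \lambda I)^{-1} Z^{\top}Z$; in DCMSC itself, $C^v$ is learned by gradient descent jointly with the autoencoders.

```python
import numpy as np

def self_expression(Z, C):
    """Self-expression layer: a linear map with no bias and no activation.
    Column i of Z @ C is a weighted combination of all latent samples in Z."""
    return Z @ C

def closed_form_C(Z, lam=1.0):
    """Minimize lam * ||C||_F^2 + ||Z - Z C||_F^2 for a FIXED latent matrix Z.

    Z has shape (features, N); lam is a hypothetical weight. Zeroing the gradient,
    2*lam*C - 2*Z^T (Z - Z C) = 0, gives (Z^T Z + lam * I) C = Z^T Z.
    """
    G = Z.T @ Z                                   # (N, N) Gram matrix of the samples
    return np.linalg.solve(G + lam * np.eye(G.shape[0]), G)

# Hypothetical toy latent codes: 5 samples in span{e1, e2} and 5 in span{e3, e4} of R^6.
rng = np.random.default_rng(0)
Z = np.zeros((6, 10))
Z[0:2, 0:5] = rng.normal(size=(2, 5))
Z[2:4, 5:10] = rng.normal(size=(2, 5))

lam = 0.1
C = closed_form_C(Z, lam)
cross = max(np.abs(C[:5, 5:]).max(), np.abs(C[5:, :5]).max())
print(f"largest cross-subspace coefficient: {cross:.1e}")  # ~0: C is block-diagonal
print(f"relative self-expression residual: "
      f"{np.linalg.norm(Z - self_expression(Z, C)) / np.linalg.norm(Z):.3f}")
```

Because the two toy groups occupy orthogonal subspaces, $Z^{\top}Z$ is block-diagonal and so is $C$; this block structure is what the later spectral clustering step relies on.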
3.1 The Proposed DCMSC
The network architecture of the proposed DCMSC method is shown in Fig. 1. It consists of two modules: the base net, which learns the view-specific representations $\{C^v\}_{v=1}^{V}$, and the additional module, which contains a contrastive learning part and an HSIC part that further exploit the multi-view relationship. In detail, the base net consists of $V$ autoencoders, each of which conducts self-representation learning on the data of one view. The encoder can be regarded as a function that simultaneously performs dimensionality reduction and nonlinear transformation, and the decoder reconstructs the input features. Self-representation is then conducted by a fully connected layer without activation function and bias, built between the encoder and the decoder. Combining the autoencoder reconstruction loss with the basic self-representation model, i.e., Eq. (2), the loss of the $v$-th autoencoder is summarized as follows:

$$\min_{C^v}\; \|C^v\|_F^2 + \|Z^v - Z^v C^v\|_F^2 + \|X^v - \hat{X}^v\|_F^2, \tag{3}$$
where $Z^v$ is the encoder output for the $v$-th view $X^v$ and $\hat{X}^v$ is the reconstruction of $X^v$. Complementary and consistent information in multiple views can be exploited by summing the self-representation matrices of all views at the end. The studies in [3,5] introduce HSIC, which measures nonlinear and high-order correlations, into multi-view subspace clustering to exploit more complementary information. Here, we adopt the empirical definition of HSIC proposed in [3]:

$$\mathcal{L}_{hsic} = \sum_{m \neq n} \mathrm{HSIC}(C^m, C^n), \tag{4}$$
where $C^m$ and $C^n$ denote the $m$-th and $n$-th self-representation matrices, respectively, $\mathrm{HSIC}(C^m, C^n) = \mathrm{tr}\big((C^m)^{\top} C^m H (C^n)^{\top} C^n H\big)$, and $H = I_N - \frac{1}{N}\mathbf{1}\mathbf{1}^{\top}$ is the $N \times N$ centering matrix, whose diagonal entries are $1 - \frac{1}{N}$ and whose off-diagonal entries are $-\frac{1}{N}$. It is worth noting that data reconstruction within each specific view alone cannot fully exploit multi-view relational information. To alleviate this problem, we introduce contrastive learning into our framework. Specifically, our contrastive learning module consists of a fully connected layer shared by all views. As shown in Fig. 1, let $H^m$ denote the output of the contrastive learning module for the latent representation $Z^m$ of the $m$-th view, i.e., the $m$-th high-level semantic representation. Each high-level feature $h_i^m$ forms $(VN - 1)$ feature pairs $\{h_i^m, h_j^n\}_{j=1,\cdots,N}^{n=1,\cdots,V}$ (excluding the pair with itself), which consist of $(V - 1)$ positive pairs $\{h_i^m, h_i^n\}_{n \neq m}$ and the remaining $V(N - 1)$ negative pairs. Contrastive learning aims to maximize the similarities of positive pairs while minimizing those of negative pairs. Specifically, the contrastive loss between $H^m$ and $H^n$ is defined as [10]:
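$$\ell_{fc}^{(mn)} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{e^{d(h_i^m, h_i^n)/\tau_F}}{\sum_{j=1}^{N} \sum_{v=m,n} e^{d(h_i^m, h_j^v)/\tau_F} - e^{1/\tau_F}}, \tag{5}$$

where $d(x, y)$ measures the similarity between samples $x$ and $y$, and $\tau_F$ denotes the temperature parameter. Inspired by NT-Xent [7], we apply cosine similarity:

$$d(x, y) = \frac{\langle x, y \rangle}{\|x\|\,\|y\|}. \tag{6}$$

With this choice $d(h_i^m, h_i^m) = 1$, so subtracting $e^{1/\tau_F}$ in Eq. (5) removes the self-pair $(h_i^m, h_i^m)$ from the denominator.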
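Eqs. (5) and (6) translate directly into array operations. The NumPy sketch below assumes that the $N$ high-level features of each view are stored as rows of `Hm` and `Hn`, with row $i$ of both forming the positive pair; the default `tau=0.5` is a placeholder, not the value used in the paper. For readability it exponentiates similarities directly, whereas a production version would use a log-sum-exp formulation for numerical stability.

```python
import numpy as np

def cosine_sim(A, B):
    """Eq. (6) for every pair: d(a_i, b_j) = <a_i, b_j> / (||a_i|| ||b_j||)."""
    A = A / np.linalg.norm(A, axis=1, keepdims=True)
    B = B / np.linalg.norm(B, axis=1, keepdims=True)
    return A @ B.T

def pair_contrastive_loss(Hm, Hn, tau=0.5):
    """Eq. (5): l_fc^(mn) for features Hm, Hn of shape (N, k); tau is a placeholder."""
    S_mm = cosine_sim(Hm, Hm) / tau          # anchors vs. samples of view m
    S_mn = cosine_sim(Hm, Hn) / tau          # anchors vs. samples of view n
    pos = np.diag(S_mn)                      # d(h_i^m, h_i^n) / tau
    # Sum over j and over v in {m, n}, then subtract e^{1/tau} to drop the
    # self-pair (h_i^m, h_i^m), whose cosine similarity is exactly 1.
    denom = np.exp(S_mm).sum(axis=1) + np.exp(S_mn).sum(axis=1) - np.exp(1.0 / tau)
    return float(np.mean(np.log(denom) - pos))   # -log(e^pos / denom), averaged

# Hypothetical check: 3 orthonormal features, identical in both views.
H = np.eye(3)
aligned = pair_contrastive_loss(H, H)
shuffled = pair_contrastive_loss(H, H[[1, 2, 0]])   # no positive pair matches
print(aligned, shuffled)
```

For these three orthonormal features the aligned loss equals $\log(1 + 4e^{-2}) \approx 0.43$, while shuffling the second view so that no positive pair matches raises it to $\log(e^{2} + 4) \approx 2.43$.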
The final contrastive loss accumulates the losses among all views:

$$\mathcal{L}_{con} = \sum_{m=1}^{V} \sum_{n \neq m} \ell_{fc}^{(mn)}. \tag{7}$$
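The HSIC regularizer of Eq. (4) is equally compact. In the NumPy sketch below (an illustration with hypothetical random matrices), each $C^v$ is an $N \times N$ self-representation matrix, the kernels are the linear ones $K = C^{\top}C$ used in the definition above, and $\mathcal{L}_{hsic}$ sums over ordered pairs $m \neq n$, following our reading of Eq. (4).

```python
import numpy as np

def centering_matrix(N):
    """H = I_N - (1/N) 1 1^T: diagonal entries 1 - 1/N, off-diagonal entries -1/N."""
    return np.eye(N) - np.ones((N, N)) / N

def hsic(Cm, Cn):
    """Empirical HSIC: trace((C^m)^T C^m H (C^n)^T C^n H) with linear kernels."""
    H = centering_matrix(Cm.shape[1])
    Km, Kn = Cm.T @ Cm, Cn.T @ Cn          # (N, N) kernel matrices over samples
    return float(np.trace(Km @ H @ Kn @ H))

def hsic_regularizer(Cs):
    """L_hsic of Eq. (4): sum of HSIC(C^m, C^n) over all ordered pairs m != n."""
    V = len(Cs)
    return sum(hsic(Cs[m], Cs[n]) for m in range(V) for n in range(V) if m != n)

# Hypothetical self-representation matrices for V = 3 views with N = 6 samples.
rng = np.random.default_rng(1)
C1, C2, C3 = (rng.normal(size=(6, 6)) for _ in range(3))
print(hsic(C1, C2), hsic_regularizer([C1, C2, C3]))
```

Because $H$ is symmetric, the trace equals $\|C^m H (C^n)^{\top}\|_F^2$, so every HSIC term is non-negative and symmetric in $m$ and $n$; summing over ordered pairs therefore just doubles the sum over unordered pairs.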
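Mathematically, the loss function of DCMSC is formulated by combining the above loss functions in Eqs. (3), (4), and (7) as follows:

$$\begin{aligned} \mathcal{L}_{final} &= \mathcal{L}_{ae} + \alpha_1 \mathcal{L}_{self} + \alpha_2 \mathcal{L}_{reg} + \alpha_3 \mathcal{L}_{hsic} + \alpha_4 \mathcal{L}_{con} \\ &= \sum_{v=1}^{V} \|X^v - \hat{X}^v\|_F^2 + \alpha_1 \sum_{v=1}^{V} \|Z^v - Z^v C^v\|_F^2 + \alpha_2 \sum_{v=1}^{V} \|C^v\|_F^2 \\ &\quad + \alpha_3 \sum_{m \neq n} \mathrm{HSIC}(C^m, C^n) + \alpha_4 \sum_{m=1}^{V} \sum_{n \neq m} \ell_{fc}^{(mn)}, \end{aligned} \tag{8}$$

where $\mathcal{L}_{ae} = \sum_{v=1}^{V} \|X^v - \hat{X}^v\|_F^2$, $\mathcal{L}_{self} = \sum_{v=1}^{V} \|Z^v - Z^v C^v\|_F^2$, and $\mathcal{L}_{reg} = \sum_{v=1}^{V} \|C^v\|_F^2$. The parameters $\alpha_1$, $\alpha_2$, $\alpha_3$, and $\alpha_4$ are non-negative and balance the contributions of the different terms.

3.2 Optimization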
The whole process of the proposed DCMSC is summarized in Algorithm 1. We first pre-train the network without the self-representation layer, which makes training in the fine-tuning stage more effective and prevents a possible all-zero solution [3]. After the fine-tuning stage, the final self-representation matrix $C$ is calculated as $C = \sum_{v=1}^{V} C^v$. In general, the affinity matrix for spectral clustering can be constructed simply as $(|C| + |C|^{\top})/2$. Here, we instead adopt the heuristic employed by SSC [21], which has been shown to benefit clustering.
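Steps 7 to 9 of Algorithm 1 (fusion, affinity construction, spectral clustering) can be sketched as below. This NumPy version is a simplification under stated assumptions: it builds the basic affinity $(|C| + |C|^{\top})/2$ mentioned above rather than the SSC heuristic [21] that DCMSC actually uses, and it implements Ng-Jordan-Weiss-style normalized spectral clustering [24] with a small deterministic k-means. The toy matrices and the helper `toy_C` are hypothetical.

```python
import numpy as np

def fused_affinity(Cs):
    """C = sum_v C^v, then the basic symmetric affinity A = (|C| + |C|^T) / 2.
    (Simplified: the paper uses the SSC post-processing heuristic instead.)"""
    C = np.sum(Cs, axis=0)
    return (np.abs(C) + np.abs(C).T) / 2.0

def spectral_clustering(A, K, n_iter=50):
    """Normalized spectral clustering: top-K eigenvectors of D^-1/2 A D^-1/2,
    row normalization, then k-means with farthest-point initialization."""
    d = np.maximum(A.sum(axis=1), 1e-12)
    M = A / np.sqrt(np.outer(d, d))                 # D^{-1/2} A D^{-1/2}
    _, vecs = np.linalg.eigh(M)                     # eigenvalues in ascending order
    U = vecs[:, -K:]
    U = U / np.maximum(np.linalg.norm(U, axis=1, keepdims=True), 1e-12)

    centers = [U[0]]                                # deterministic initialization
    for _ in range(1, K):
        dist = np.min([((U - c) ** 2).sum(axis=1) for c in centers], axis=0)
        centers.append(U[np.argmax(dist)])
    centers = np.array(centers)
    for _ in range(n_iter):                         # Lloyd iterations
        labels = np.argmin(((U[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2), axis=1)
        centers = np.array([U[labels == k].mean(axis=0) if np.any(labels == k)
                            else centers[k] for k in range(K)])
    return labels

# Hypothetical two-view self-representations with two clusters (4 + 3 samples).
rng = np.random.default_rng(0)
def toy_C():
    C = 0.01 * rng.random((7, 7))                   # weak cross-cluster links
    C[:4, :4] += rng.uniform(0.5, 1.0, (4, 4))
    C[4:, 4:] += rng.uniform(0.5, 1.0, (3, 3))
    return C

labels = spectral_clustering(fused_affinity([toy_C(), toy_C()]), K=2)
print(labels)
```

Summing the view-specific matrices before building the affinity means a link between two samples is kept strong if several views agree on it, which is the intuition behind the fusion step.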
4 Experiments
4.1 Experimental Settings
Datasets. We conduct experiments on 6 benchmark multi-view datasets to evaluate the proposed DCMSC, including 4 classical datasets (Yale, ORL, Still DB, and BBCSport) and 2 bigger datasets (Caltech and BDGP). More details are listed in Table 1. Evaluation Metrics. We adopt 4 widely used metrics to evaluate clustering performance: accuracy (ACC), normalized mutual information (NMI), purity (PUR), and F-measure. Higher values indicate better performance for all 4 metrics. For all experiments, the parameters are tuned to achieve the best clustering performance, and the average of each metric over 10 trials on each dataset is reported.
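Of the four metrics, ACC and PUR are the simplest to state precisely. The sketch below (NumPy and the standard library only) is an illustration, not the paper's evaluation code: PUR credits each predicted cluster with its majority class, and ACC searches for the best one-to-one cluster-to-class mapping by brute force, which is only practical for a handful of clusters (larger problems normally use the Hungarian algorithm). NMI and F-measure are omitted.

```python
import numpy as np
from itertools import permutations

def contingency(y_true, y_pred):
    """W[p, t] = number of samples in predicted cluster p with true class t."""
    _, t_idx = np.unique(y_true, return_inverse=True)
    _, p_idx = np.unique(y_pred, return_inverse=True)
    W = np.zeros((p_idx.max() + 1, t_idx.max() + 1), dtype=int)
    np.add.at(W, (p_idx, t_idx), 1)
    return W

def purity(y_true, y_pred):
    """PUR: each cluster is credited with its most frequent true class."""
    W = contingency(y_true, y_pred)
    return W.max(axis=1).sum() / len(y_true)

def clustering_accuracy(y_true, y_pred):
    """ACC under the best one-to-one cluster-to-class mapping.
    Brute force over permutations: fine for a few clusters only."""
    W = contingency(y_true, y_pred)
    n = max(W.shape)
    Wsq = np.zeros((n, n), dtype=int)              # pad to a square matrix
    Wsq[:W.shape[0], :W.shape[1]] = W
    best = max(sum(Wsq[i, perm[i]] for i in range(n)) for perm in permutations(range(n)))
    return best / len(y_true)

y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2])
y_perm = np.array([2, 2, 2, 0, 0, 1, 1, 1])   # same partition, different cluster ids
y_err = np.array([0, 0, 1, 1, 1, 2, 2, 2])    # one sample of class 0 misassigned
print(clustering_accuracy(y_true, y_perm), purity(y_true, y_perm))  # 1.0 1.0
print(clustering_accuracy(y_true, y_err), purity(y_true, y_err))    # 0.875 0.875
```

Relabeling clusters does not change either score, which is why both metrics compare partitions rather than raw label values.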
Algorithm 1. DCMSC
Input: Multi-view data $[X^1, X^2, \cdots, X^V]$; maximum number of iterations $T_{max}$; trade-off parameters $\alpha_1, \alpha_2, \alpha_3, \alpha_4$; number of clusters $K$.
Output: Clustering result $L$.
1: Pre-train the $V$ autoencoders without the self-representation layer;
2: Initialize the self-expression layers and the contrastive learning network; set $t = 1$;
3: while $t \le T_{max}$ do
4:   Do forward propagation and calculate the loss (8);
5:   Back-propagate the gradient of (8), update the network parameters, and set $t = t + 1$;
6: end while
7: Calculate the final self-representation matrix $C = \sum_{v=1}^{V} C^v$;
8: Run the algorithm employed by [21] to obtain the affinity matrix $A$;
9: Run spectral clustering on $A$ to get the clustering result $L$.

Table 1. The details of the datasets.

| Datasets | #Samples | #Views | #Classes | Dimension of features |
|---|---|---|---|---|
| Yale | 165 | 3 | 11 | 4096/3304/6750 |
| ORL | 400 | 3 | 10 | 4096/3304/6750 |
| Still DB | 476 | 3 | 6 | 200/200/200 |
| BBCSport | 544 | 2 | 5 | 3183/3203 |
| BDGP | 2,500 | 2 | 5 | 1750/79 |
| Caltech-3V | 1,400 | 3 | 7 | 40/254/1984 |
| Caltech-5V | 1,400 | 5 | 7 | 40/254/1984/512/928 |
Comparison Methods. The comparison methods include several state-of-the-art methods for both multi-view subspace clustering and deep multi-view clustering: BestSV [24], LRR [22], RMSC [31], DSCN [25], DCSC [26], DC [27], DMF [28], LMSC [6], MSCN [29], MvDSCN [3], RMSL [2], MVC-LFA [15], COMIC [16], IMVTSC [18], CDIMC-net [30], EAMC [17], SiMVC [20], CoMVC [20], and MFLVC [10]. Implementation. We implement DCMSC in Python with TensorFlow 2 and compare its performance with the above baselines. The Adam optimizer is used for gradient descent with a learning rate of $10^{-3}$. We use ReLU as the activation function throughout the network except in the self-expression layer.

4.2 Experimental Results
We mainly compare DCMSC with 8 subspace-based multi-view clustering algorithms on 4 datasets. To evaluate the superiority and robustness of our method, we also conduct experiments on 2 big datasets and compare its performance with 6 state-of-the-art multi-view clustering algorithms. The results are given in Table 2 and Table 3. From Table 2, we can see that the proposed DCMSC significantly outperforms all methods on the first two datasets and achieves comparable performance on the last two datasets. In particular, DCMSC improves the clustering performance over the other methods by a large margin on Yale. The improvements of the proposed DCMSC over the second-best method FMR are 10.1%, 10.2%, and 18.5% with respect to NMI, ACC, and F-measure, respectively. From Table 3, we observe that (1) our method also obtains competitive clustering performance on bigger datasets, and (2) DCMSC greatly improves the clustering performance on Caltech-5V. In addition, although RMSL performs well on some benchmark datasets for multi-view subspace clustering, it does not obtain competitive performance on BDGP and Caltech, whereas our method still maintains decent performance on these datasets.

4.3 Visualization
To intuitively show the superiority of DCMSC, we visualize the affinity matrix $A$ on BBCSport, ORL, and Yale in Fig. 2, where $A_{ij}$ denotes the similarity between samples $x_i$ and $x_j$. The affinity $A$ is obtained from the final self-representation matrix $C$ by the algorithm employed by [21]. Since the data points are sorted by class on these 3 datasets, the affinity matrix $A$ should ideally have a block-diagonal structure. Fig. 2 shows that the affinity $A$ learned by the proposed DCMSC exhibits the block-diagonal property more clearly than that learned by MvDSCN.
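The visual check in Fig. 2 can be complemented by a single number. The NumPy sketch below is a hypothetical diagnostic that the paper does not report: it reorders an affinity matrix by class label (as done for the plots) and measures the fraction of total affinity that falls inside the class-diagonal blocks, which equals 1 for a perfectly block-diagonal $A$. The function names and toy matrices are illustrative only.

```python
import numpy as np

def sort_by_labels(A, labels):
    """Reorder rows/columns so that samples of the same class are contiguous,
    e.g. before displaying A as an image."""
    order = np.argsort(labels, kind="stable")
    return A[np.ix_(order, order)], np.asarray(labels)[order]

def in_block_ratio(A, labels):
    """Hypothetical diagnostic: share of the total affinity mass that lies
    between samples of the same class (1.0 means perfectly block-diagonal)."""
    labels = np.asarray(labels)
    same_class = labels[:, None] == labels[None, :]
    return float(A[same_class].sum() / A.sum())

labels = np.array([1, 0, 1, 0])
A_block = np.array([[1., 0., 1., 0.],
                    [0., 1., 0., 1.],
                    [1., 0., 1., 0.],
                    [0., 1., 0., 1.]])
A_flat = np.ones((4, 4))
print(in_block_ratio(A_block, labels), in_block_ratio(A_flat, labels))  # 1.0 0.5
A_sorted, labels_sorted = sort_by_labels(A_block, labels)
print(labels_sorted)  # [0 0 1 1]
```

After sorting, `A_block` becomes two contiguous all-ones blocks on the diagonal, which is the pattern Fig. 2 looks for.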
Fig. 2. Visualization of the learned affinity matrices on BBCSport, ORL, and Yale.
4.4 Ablation Studies
We conducted ablation studies on $\mathcal{L}_{con}$ on Yale and Caltech-5V to illustrate the effectiveness of our contrastive learning module. Table 4 shows that our method achieves good results even without $\mathcal{L}_{con}$ and performs better with it, which shows that contrastive learning improves performance on the multi-view subspace clustering task owing to its ability to exploit more comprehensive relationships in multi-view data.

Table 2. Results of all methods on four small datasets. Bold indicates the best and underline indicates the second-best.

| Method | Yale NMI | Yale ACC | Yale F-measure | ORL NMI | ORL ACC | ORL F-measure | Still DB NMI | Still DB ACC | Still DB F-measure | BBCSport NMI | BBCSport ACC | BBCSport F-measure |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| BestSV | 0.654 | 0.616 | 0.475 | 0.903 | 0.777 | 0.711 | 0.104 | 0.297 | 0.221 | 0.715 | 0.836 | 0.768 |
| LRR | 0.709 | 0.697 | 0.547 | 0.895 | 0.773 | 0.731 | 0.109 | 0.306 | 0.240 | 0.690 | 0.832 | 0.774 |
| RMSC | 0.684 | 0.642 | 0.517 | 0.872 | 0.723 | 0.654 | 0.106 | 0.285 | 0.232 | 0.608 | 0.737 | 0.655 |
| DSCN | 0.738 | 0.727 | 0.542 | 0.883 | 0.801 | 0.711 | 0.216 | 0.323 | 0.293 | 0.652 | 0.821 | 0.683 |
| DCSC | 0.744 | 0.733 | 0.556 | 0.893 | 0.811 | 0.718 | 0.222 | 0.325 | 0.301 | 0.683 | 0.843 | 0.712 |
| DC | 0.756 | 0.766 | 0.579 | 0.865 | 0.788 | 0.701 | 0.199 | 0.315 | 0.280 | 0.556 | 0.724 | 0.492 |
| LMSC | 0.702 | 0.670 | 0.506 | 0.931 | 0.819 | 0.758 | 0.137 | 0.328 | 0.269 | 0.826 | 0.900 | 0.887 |
| DMF | 0.782 | 0.745 | 0.601 | 0.933 | 0.823 | 0.773 | 0.154 | 0.336 | 0.265 | 0.821 | 0.890 | 0.889 |
| MSCN | 0.769 | 0.772 | 0.582 | 0.928 | 0.833 | 0.787 | 0.168 | 0.312 | 0.261 | 0.813 | 0.888 | 0.854 |
| MvDSCN | 0.797 | 0.824 | 0.626 | 0.943 | 0.870 | 0.834 | 0.245 | 0.377 | 0.320 | 0.835 | 0.931 | 0.860 |
| RMSL | 0.831 | 0.879 | 0.828 | 0.950 | 0.881 | 0.842 | 0.135 | 0.336 | 0.293 | 0.917 | 0.976 | 0.954 |
| DCMSC | 0.944 | 0.955 | 0.907 | 0.970 | 0.931 | 0.911 | 0.156 | 0.388 | 0.284 | 0.864 | 0.953 | 0.907 |

Table 3. Results of all methods on BDGP, Caltech-3V and Caltech-5V. Bold indicates the best and underline indicates the second-best.

| Method | BDGP ACC | BDGP NMI | BDGP PUR | Caltech-3V ACC | Caltech-3V NMI | Caltech-3V PUR | Caltech-5V ACC | Caltech-5V NMI | Caltech-5V PUR |
|---|---|---|---|---|---|---|---|---|---|
| RMSL [2] (2019) | 0.849 | 0.630 | 0.849 | 0.596 | 0.551 | 0.608 | 0.354 | 0.340 | 0.391 |
| MVC-LFA [15] (2019) | 0.564 | 0.395 | 0.612 | 0.551 | 0.423 | 0.578 | 0.741 | 0.601 | 0.747 |
| COMIC [16] (2019) | 0.578 | 0.642 | 0.639 | 0.447 | 0.491 | 0.575 | 0.532 | 0.549 | 0.604 |
| CDIMC-net [19] (2020) | 0.884 | 0.799 | 0.885 | 0.528 | 0.483 | 0.565 | 0.727 | 0.692 | 0.742 |
| EAMC [17] (2020) | 0.681 | 0.480 | 0.697 | 0.389 | 0.214 | 0.398 | 0.318 | 0.173 | 0.342 |
| IMVTSC-MVI [18] (2021) | 0.981 | 0.950 | 0.982 | 0.558 | 0.445 | 0.576 | 0.760 | 0.691 | 0.785 |
| SiMVC [20] (2021) | 0.704 | 0.545 | 0.723 | 0.569 | 0.495 | 0.591 | 0.719 | 0.677 | 0.729 |
| CoMVC [20] (2021) | 0.802 | 0.670 | 0.803 | 0.541 | 0.504 | 0.584 | 0.700 | 0.687 | 0.746 |
| MFLVC [10] (2022) | 0.989 | 0.966 | 0.989 | 0.631 | 0.566 | 0.639 | 0.804 | 0.703 | 0.804 |
| DCMSC | 0.985 | 0.957 | 0.985 | 0.890 | 0.785 | 0.890 | 0.914 | 0.825 | 0.914 |

Table 4. Ablation studies for contrastive learning structures on Yale and Caltech-5V.

| Variant | Yale ACC | Yale NMI | Yale PUR | Caltech-5V ACC | Caltech-5V NMI | Caltech-5V PUR |
|---|---|---|---|---|---|---|
| w/o $\mathcal{L}_{con}$ | 0.912 | 0.826 | 0.912 | 0.874 | 0.782 | 0.874 |
| w/ $\mathcal{L}_{con}$ | 0.955 | 0.944 | 0.955 | 0.914 | 0.825 | 0.914 |
5 Conclusion
In this paper, we proposed a novel method named Deep Contrastive Multi-view Subspace Clustering (DCMSC), which exploits the multi-view relationship by combining multiple self-representation matrices and introducing contrastive learning into the networks to explore more consistent information. DCMSC consists of a base network composed of $V$ autoencoders, by which $V$ view-specific self-representation matrices are learned. In addition, an HSIC regularizer and a contrastive learning module are added to the base network to exploit more comprehensive information. Experiments on both benchmark datasets and two bigger datasets verify the superiority and robustness of our method compared with state-of-the-art methods.

Acknowledgements. This work was supported in part by the National Natural Science Foundation of China under Grants 62071142 and 62106063, by the Guangdong Basic and Applied Basic Research Foundation under Grant 2021A1515011406, by the Shenzhen College Stability Support Plan under Grants GXWD20201230155427003-20200824210638001 and GXWD20201230155427003-20200824113231001, by the Guangdong Natural Science Foundation under Grant 2022A1515010819, and by Guangdong Provincial Key Laboratory of Novel Security Intelligence Technologies under Grant 2022B1212010005.
References 1. Rui, M., Zhiping, Z.: Deep multi-view subspace clustering network with exclusive constraint. In: 2021 40th Chinese Control Conference (CCC), pp. 7062–7067 (2021) 2. Li, R., Zhang, C., Fu, H., Peng, X., Zhou, T., Hu, Q.: Reciprocal multi-layer subspace learning for multi-view clustering. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 8172–8180 (2019) 3. Zhu, P., Hui, B., Zhang, C., Du, D., Wen, L., Hu, Q.: Multi-view deep subspace clustering networks. arXiv Preprint arXiv:1908.01978 (2019) 4. Patel, V., Van Nguyen, H., Vidal, R.: Latent space sparse subspace clustering. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 225–232 (2013) 5. Cao, X., Zhang, C., Fu, H., Liu, S., Zhang, H.: Diversity-induced multi-view subspace clustering. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 586–594 (2015) 6. Zhang, C., Hu, Q., Fu, H., Zhu, P., Cao, X.: Latent multi-view subspace clustering. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4279–4287 (2017) 7. Chen, T., Kornblith, S., Norouzi, M., Hinton, G.: A simple framework for contrastive learning of visual representations. In: International Conference on Machine Learning, pp. 1597–1607 (2020) 8. Wang, T., Isola, P.: Understanding contrastive representation learning through alignment and uniformity on the hypersphere. In: International Conference on Machine Learning, pp. 9929–9939 (2020)
9. Van Gansbeke, W., Vandenhende, S., Georgoulis, S., Proesmans, M., Van Gool, L.: SCAN: learning to classify images without labels. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12355, pp. 268–285. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58607-2_16 10. Xu, J., Tang, H., Ren, Y., Peng, L., Zhu, X., He, L.: Multi-level feature learning for contrastive multi-view clustering. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 16051–16060 (2022) 11. Li, Y., Hu, P., Liu, Z., Peng, D., Zhou, J., Peng, X.: Contrastive clustering. In: Proceedings of the AAAI Conference on Artificial Intelligence, pp. 8547–8555 (2021) 12. Hassani, K., Khasahmadi, A.: Contrastive multi-view representation learning on graphs. In: International Conference on Machine Learning, pp. 4116–4126 (2020) 13. Tian, Y., Krishnan, D., Isola, P.: Contrastive multiview coding. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12356, pp. 776–794. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58621-8_45 14. Ji, P., Zhang, T., Li, H., Salzmann, M., Reid, I.: Deep subspace clustering networks. In: Advances in Neural Information Processing Systems, pp. 24–33 (2017) 15. Wang, S., et al.: Multi-view clustering via late fusion alignment maximization. In: IJCAI, pp. 3778–3784 (2019) 16. Peng, X., Huang, Z., Lv, J., Zhu, H., Zhou, J.: COMIC: multi-view clustering without parameter selection. In: International Conference on Machine Learning, pp. 5092–5101 (2019) 17. Zhou, R., Shen, Y.: End-to-end adversarial-attention network for multi-modal clustering. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 14619–14628 (2020) 18. Wen, J., et al.: Unified tensor framework for incomplete multi-view clustering and missing-view inferring. In: Proceedings of the AAAI Conference on Artificial Intelligence, pp. 10273–10281 (2021)
19. Wen, J., Zhang, Z., Xu, Y., Zhang, B., Fei, L., Xie, G.: CDIMC-net: cognitive deep incomplete multi-view clustering network. In: IJCAI, pp. 3230–3236 (2020) 20. Trosten, D., Løkse, S., Jenssen, R., Kampffmeyer, M.: Reconsidering representation alignment for multi-view clustering. In: CVPR, pp. 1255–1265 (2021) 21. Elhamifar, E., Vidal, R.: Sparse subspace clustering: algorithm, theory, and applications. IEEE Trans. Pattern Anal. Mach. Intell. 2765–2781 (2013) 22. Liu, G., Lin, Z., Yan, S., Sun, J., Yu, Y., Ma, Y.: Robust recovery of subspace structures by low-rank representation. IEEE Trans. Pattern Anal. Mach. Intell. 171–184 (2012) 23. Zhou, P., Hou, Y., Feng, J.: Deep adversarial subspace clustering. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1596–1604 (2018) 24. Ng, A., Jordan, M., Weiss, Y.: On spectral clustering: analysis and an algorithm. In: Advances in Neural Information Processing Systems, vol. 14 (2001) 25. Peng, X., Xiao, S., Feng, J., Yau, W., Yi, Z.: Deep subspace clustering with sparsity prior. In: IJCAI, pp. 1925–1931 (2016) 26. Jiang, Y., Yang, Z., Xu, Q., Cao, X., Huang, Q.: When to learn what: deep cognitive subspace clustering. In: Proceedings of the 26th ACM International Conference on Multimedia, pp. 718–726 (2018) 27. Caron, M., Bojanowski, P., Joulin, A., Douze, M.: Deep clustering for unsupervised learning of visual features. In: Proceedings of the European Conference on Computer Vision (ECCV), pp. 132–149 (2018) 28. Zhao, H., Ding, Z., Fu, Y.: Multi-view clustering via deep matrix factorization. In: Thirty-First AAAI Conference on Artificial Intelligence (2017)
29. Abavisani, M., Patel, V.: Deep multimodal subspace clustering networks. IEEE J. Sel. Top. Signal Process. 12, 1601–1614 (2018) 30. Wen, J., Zhang, Z., Xu, Y., Zhang, B., Fei, L., Xie, G.: CDIMC-net: cognitive deep incomplete multi-view clustering network. In: Proceedings of the Twenty-Ninth International Joint Conference on Artificial Intelligence, IJCAI 2020, pp. 3230–3236 (2020) 31. Xia, R., Pan, Y., Du, L., Yin, J.: Robust multi-view spectral clustering via low-rank and sparse decomposition. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 28 (2014)
Author Index
A Abulaish, Muhammad 477
B Bakhtiari, Mehrdad Mohannazadeh 586 Bhattacharya, Anwesh 610 Boiarov, Andrei 373 Bortkiewicz, Michał 457 C Cabessa, Jérémie 622 Cao, Shoufeng 562 Cao, Ze 53 Chai, Hongfeng 446 Chakrabarty, Dalia 275 Chen, Dianying 597 Chen, Jun 597 Chen, Min 3 Chen, Si 165 Chen, Yi 153 Chen, Yongyong 692 Cheng, Lei 692 Cheng, Zi-hao 28, 53 Cui, Laiping 348 D Dai, Qi 15 Devi, V. Susheela 313 Ding, Chris 79 Ding, Yuxin 513 Dong, Chuanhao 177 Duan, Huajuan 177 Dubiński, Jan 457 F Fang, Gaoyun 671 Feng, Kun 538 Feng, Ziming 239
G Gałkowski, Tomasz 251 Ganzha, Maria 574 Gao, Junfeng 538 Ge, Yuhang 336 Grzechociński, Jakub 263 Gu, Ming 409 Guo, Qimeng 177 Guo, Xiangzhe 385 H He, Chenghai 562 He, Haoming 79 He, Qiyi 119 He, Yangliu 65 Higuchi, Hikaru 683 Hu, Xuegang 336 Hua, Zhongyun 692 Huang, Jiayang 215 Huang, Xiaofang 489 Huang, Xinlei 433, 647 J Jia, Haonan 634 Jia, Yan 300 Jiang, Changhua 215 Jiang, Hao 634 Jiang, Ning 433, 647 Jiang, Xuesong 597 Jin, Tianlei 106 K Kang, Yongxin 189 Katz, Guy 92 Khabarlak, Kostiantyn 373 Kimura, Masanari 468 Krzyżak, Adam 251
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2023 M. Tanveer et al. (Eds.): ICONIP 2022, CCIS 1791, pp. 705–707, 2023. https://doi.org/10.1007/978-981-99-1639-9
L Leung, Chi-Sing 202 Levy, Natan 92 Lewy, Dominik 574 Li, Chenggang 489, 550 Li, Hailing 562 Li, Jun 361 Li, Junlong 336 Li, Kai 189 Li, Peipei 336 Li, Qinglin 420 Li, Zhouzheng 538 Liang, Kun 130 Liao, Jianxin 65 Liu, Donghong 501 Liu, Guojing 348 Liu, Hanwen 153 Liu, Jian-wei 15, 28, 40, 53 Liu, Jing 671 Liu, Yang 671 Liu, Yanli 420 Lu, Wenhao 202 Luo, Cui 300 Luo, Xiaoling 153 Luo, Zhigang 227 M Ma, Kaiyang 348 Ma, Yubin 513 Małkowski, Adam 263 Mańdziuk, Jacek 574 Meng, Junfeng 65 Meng, Qiwei 106 Miao, Dongyan 538 Mo, Chong 300 Mu, Zonghao 106 Mushtaq, Umer 622 N Nagaraj, Nithin 610 P Pan, Yi 3 Pan, Yuchen 361 Paprzycki, Marcin 574 Pawlak, Stanisław 457 Peng, Anjie 489, 550
Pu, Zhiqiang 3 Pulabaigari, Viswanath 659
Q Qian, Wen 513 Qin, Yeguang 142 Qing, Haifeng 433, 647 Qiu, Xuetao 446 Qu, Hong 153 R Ren, Yimeng 130 Roy, Gargi 275 Ru, Dongyu 239 S Sah, Amit Kumar 477 Saha, Snehanshu 610 Sanodiya, Rakesh Kumar 659 Shang, Yuhu 130 She, Xiaohan 446 Shen, Qiwei 65 Shouno, Hayaru 683 Song, Liang 671 Song, Wei 106 Su, Kejia 215 Su, Ruidan 385 Su, Yiqi 513 Sum, John 202 Sun, Weimin 526 Suryavardan, S. 659 Suzuki, Satoshi 683 Szatkowski, Filip 457 T Tan, Zehan 397 Tang, Fengxiao 142 Tang, Hao 501 Tang, Jialiang 433, 647 Trzciński, Tomasz 457 Tu, Shikui 385 V Villmann, Thomas 586 Vivek, A. 313
W Wan, Bo 215 Wang, Da-Han 165
Wang, Fei 215 Wang, Guangbin 513 Wang, Haobo 336 Wang, Kun 550 Wang, Libo 165 Wang, Mengzhu 227 Wang, Ming-hui 40 Wang, Mingwei 119 Wang, Shanshan 227 Wang, Wen 106 Wang, Ye 300 Wang, Yu 348 Wang, Yun 397 Wang, Zhen 165 Wawrzyński, Paweł 263 Wei, Xiumei 597 Wu, Wenqing 433, 647 Wu, Xianze 239 Wu, Zhiyuan 550 Wu, Zongze 79 X Xi, Xiangming 106 Xing, Guanyu 420 Xing, Junliang 189 Xing, Xiang 28 Xiong, Bang 215 Xiong, Gang 562 Xiong, Haoxuan 324 Xiong, Yuanchun 288 Xu, Jianhua 361 Xu, Jun 634 Xu, Lei 385 Xu, Liancheng 177 Xu, Xinhai 501 Xu, Yuanyuan 324 Y Yan, Tianwei 227 Yan, Xiaohui 634
Yang, Gang 526 Yang, Hua 409 Yang, Jia-peng 15 Yang, Weidong 397 Yang, Yanming 446 Yang, Zhenyu 348 Yastrebov, Igor 373 Ye, Zhiwei 119 Yi, Jianqiang 3 Yu, Wenxin 489, 550 Yu, Yong 239 Z Zang, Yifan 189 Zeng, Deyu 79 Zeng, Hui 489, 550 Zeng, Wei 634 Zeng, Zehua 446 Zhang, Baowen 288 Zhang, Feng 501 Zhang, Haichao 119 Zhang, Kai 562 Zhang, Malu 153 Zhang, Peng 433, 647 Zhang, Weinan 239 Zhang, Xiankun 130 Zhang, Xiao 634 Zhang, Xiaohang 562 Zhang, Yiying 130 Zhao, Enmin 189 Zhao, Junbo 336 Zhao, Ming 142 Zhao, Mingyu 397 Zhou, Li 562 Zhou, Zihan 513 Zhu, Honglin 433, 647 Zhu, Ping 489, 550 Zhu, Shiqiang 106 Zhu, Shunzhi 165 Zhu, Yusen 142