English Pages XII, 300 [307] Year 2021
Advances in Intelligent Systems and Computing 1232
M. Arif Wani Taghi M. Khoshgoftaar Vasile Palade Editors
Deep Learning Applications, Volume 2
Advances in Intelligent Systems and Computing Volume 1232
Series Editor Janusz Kacprzyk, Systems Research Institute, Polish Academy of Sciences, Warsaw, Poland Advisory Editors Nikhil R. Pal, Indian Statistical Institute, Kolkata, India Rafael Bello Perez, Faculty of Mathematics, Physics and Computing, Universidad Central de Las Villas, Santa Clara, Cuba Emilio S. Corchado, University of Salamanca, Salamanca, Spain Hani Hagras, School of Computer Science and Electronic Engineering, University of Essex, Colchester, UK László T. Kóczy, Department of Automation, Széchenyi István University, Gyor, Hungary Vladik Kreinovich, Department of Computer Science, University of Texas at El Paso, El Paso, TX, USA Chin-Teng Lin, Department of Electrical Engineering, National Chiao Tung University, Hsinchu, Taiwan Jie Lu, Faculty of Engineering and Information Technology, University of Technology Sydney, Sydney, NSW, Australia Patricia Melin, Graduate Program of Computer Science, Tijuana Institute of Technology, Tijuana, Mexico Nadia Nedjah, Department of Electronics Engineering, University of Rio de Janeiro, Rio de Janeiro, Brazil Ngoc Thanh Nguyen , Faculty of Computer Science and Management, Wrocław University of Technology, Wrocław, Poland Jun Wang, Department of Mechanical and Automation Engineering, The Chinese University of Hong Kong, Shatin, Hong Kong
The series “Advances in Intelligent Systems and Computing” contains publications on theory, applications, and design methods of Intelligent Systems and Intelligent Computing. Virtually all disciplines such as engineering, natural sciences, computer and information science, ICT, economics, business, e-commerce, environment, healthcare, life science are covered. The list of topics spans all the areas of modern intelligent systems and computing such as: computational intelligence, soft computing including neural networks, fuzzy systems, evolutionary computing and the fusion of these paradigms, social intelligence, ambient intelligence, computational neuroscience, artificial life, virtual worlds and society, cognitive science and systems, Perception and Vision, DNA and immune based systems, self-organizing and adaptive systems, e-Learning and teaching, human-centered and human-centric computing, recommender systems, intelligent control, robotics and mechatronics including human-machine teaming, knowledge-based paradigms, learning paradigms, machine ethics, intelligent data analysis, knowledge management, intelligent agents, intelligent decision making and support, intelligent network security, trust management, interactive entertainment, Web intelligence and multimedia. The publications within “Advances in Intelligent Systems and Computing” are primarily proceedings of important conferences, symposia and congresses. They cover significant recent developments in the field, both of a foundational and applicable character. An important characteristic feature of the series is the short publication time and world-wide distribution. This permits a rapid and broad dissemination of research results. ** Indexing: The books of this series are submitted to ISI Proceedings, EI-Compendex, DBLP, SCOPUS, Google Scholar and Springerlink **
More information about this series at http://www.springer.com/series/11156
Editors M. Arif Wani Department of Computer Science University of Kashmir Srinagar, India
Taghi M. Khoshgoftaar Computer and Electrical Engineering Florida Atlantic University Boca Raton, FL, USA
Vasile Palade Faculty of Engineering and Computing Coventry University Coventry, UK
ISSN 2194-5357 ISSN 2194-5365 (electronic) Advances in Intelligent Systems and Computing ISBN 978-981-15-6758-2 ISBN 978-981-15-6759-9 (eBook) https://doi.org/10.1007/978-981-15-6759-9 © The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021 This work is subject to copyright. All rights are solely and exclusively licensed by the Publisher, whether the whole or part of the material is concerned, specifically the rights of translation, reprinting, reuse of illustrations, recitation, broadcasting, reproduction on microfilms or in any other physical way, and transmission or information storage and retrieval, electronic adaptation, computer software, or by similar or dissimilar methodology now known or hereafter developed. The use of general descriptive names, registered names, trademarks, service marks, etc. in this publication does not imply, even in the absence of a specific statement, that such names are exempt from the relevant protective laws and regulations and therefore free for general use. The publisher, the authors and the editors are safe to assume that the advice and information in this book are believed to be true and accurate at the date of publication. Neither the publisher nor the authors or the editors give a warranty, expressed or implied, with respect to the material contained herein or for any errors or omissions that may have been made. The publisher remains neutral with regard to jurisdictional claims in published maps and institutional affiliations. This Springer imprint is published by the registered company Springer Nature Singapore Pte Ltd. The registered company address is: 152 Beach Road, #21-01/04 Gateway East, Singapore 189721, Singapore
Preface
Machine learning algorithms have influenced many aspects of our day-to-day living and transformed major industries around the world. Fueled by an exponential growth of data, improvements in computer hardware, scalable cloud resources, and accessible open-source frameworks, machine learning technology is being used by companies big and small alike for innumerable applications. At home, machine learning models are suggesting TV shows, movies, and music for entertainment, providing personalized e-commerce suggestions, shaping our digital social networks, and improving the efficiency of our appliances. At work, these data-driven methods are filtering our emails, forecasting trends in productivity and sales, targeting customers with advertisements, improving the quality of video conferences, and guiding critical decisions. At the frontier of machine learning innovation are deep learning systems, a class of multi-layered networks capable of automatically learning meaningful hierarchical representations from a variety of structured and unstructured data. Breakthroughs in deep learning allow us to generate new representations, extract knowledge, and draw inferences from raw images, video streams, text and speech, time series, and other complex data types. These powerful deep learning methods are being applied to new and exciting real-world problems in medical diagnostics, factory automation, public safety, environmental sciences, autonomous transportation, military applications, and much more. The family of deep learning architectures continues to grow as new methods and techniques are developed to address a wide variety of problems. A deep learning network is composed of multiple layers that form universal approximators capable of approximating any continuous function. For example, the convolutional layers in Convolutional Neural Networks use shared weights and spatial invariance to efficiently learn hierarchical representations from images, natural language, and temporal data.
Recurrent Neural Networks use backpropagation through time to learn from variable length sequential data. Long Short-Term Memory networks are a type of recurrent network capable of learning order dependence in sequence prediction problems. Deep Belief Networks, Autoencoders, and other unsupervised models generate meaningful latent features for downstream tasks and model the underlying concepts of distributions by reconstructing their inputs. Generative Adversarial Networks simultaneously learn generative models capable of producing new data from a distribution and discriminative models that can distinguish between real and artificial images. Transformer Networks combine encoders and decoders with attention layers for improved sequence-to-sequence learning. Network architecture search automates the design of these deep models by optimizing performance over the hyperparameter space. As a result of these advances, and many others, deep learning is revolutionizing complex problem domains with state-of-the-art results and, in some cases, performance far superior to that of humans.
This book explores some of the latest applications in deep learning and includes a variety of architectures and novel deep learning techniques. Deep models are trained to recommend products, diagnose medical conditions or faults in industrial machines, detect when a human falls, and recognize solar panels in aerial images. Sequence models are used to capture driving behaviors and identify radio transmitters from temporal data. Residual networks are used to detect human targets in indoor environments, an algorithm incorporating a thresholding strategy is used to identify fraud within highly imbalanced data, and hybrid methods are used to locate vehicles during satellite outages. A multi-adversarial variational autoencoder network is used for image synthesis and classification, and finally a parameter continuation method is used for non-convex optimization of deep neural networks. We believe that the recent deep learning methods and applications illustrated in this book capture some of the most exciting advances in deep learning.
M. Arif Wani, Srinagar, India
Taghi M. Khoshgoftaar, Boca Raton, USA
Vasile Palade, Coventry, UK
Contents
Deep Learning-Based Recommender Systems
Meshal Alfarhood and Jianlin Cheng
A Comprehensive Set of Novel Residual Blocks for Deep Learning Architectures for Diagnosis of Retinal Diseases from Optical Coherence Tomography Images
Sharif Amit Kamran, Sourajit Saha, Ali Shihab Sabbir, and Alireza Tavakkoli
Three-Stream Convolutional Neural Network for Human Fall Detection
Guilherme Vieira Leite, Gabriel Pellegrino da Silva, and Helio Pedrini
Diagnosis of Bearing Faults in Electrical Machines Using Long Short-Term Memory (LSTM)
Russell Sabir, Daniele Rosato, Sven Hartmann, and Clemens Gühmann
Automatic Solar Panel Detection from High-Resolution Orthoimagery Using Deep Learning Segmentation Networks
Tahir Mujtaba and M. Arif Wani
Training Deep Learning Sequence Models to Understand Driver Behavior
Shokoufeh Monjezi Kouchak and Ashraf Gaffar
Exploiting Spatio-Temporal Correlation in RF Data Using Deep Learning
Debashri Roy, Tathagata Mukherjee, and Eduardo Pasiliao
Human Target Detection and Localization with Radars Using Deep Learning
Michael Stephan, Avik Santra, and Georg Fischer
Thresholding Strategies for Deep Learning with Highly Imbalanced Big Data
Justin M. Johnson and Taghi M. Khoshgoftaar
Vehicular Localisation at High and Low Estimation Rates During GNSS Outages: A Deep Learning Approach
Uche Onyekpe, Stratis Kanarachos, Vasile Palade, and Stavros-Richard G. Christopoulos
Multi-Adversarial Variational Autoencoder Nets for Simultaneous Image Generation and Classification
Abdullah-Al-Zubaer Imran and Demetri Terzopoulos
Non-convex Optimization Using Parameter Continuation Methods for Deep Neural Networks
Harsh Nilesh Pathak and Randy Clinton Paffenroth
Author Index
Editors and Contributors
About the Editors
Dr. M. Arif Wani is a Professor at the University of Kashmir, having previously served as a Professor at California State University, Bakersfield. He completed his M.Tech. in Computer Technology at the Indian Institute of Technology, Delhi, and his Ph.D. in Computer Vision at Cardiff University, UK. His research interests are in the area of machine learning, with a focus on neural networks, deep learning, inductive learning, and support vector machines, and with application to areas that include computer vision, pattern recognition, classification, prediction, and analysis of gene expression datasets. He has published many papers in reputed journals and conferences in these areas. Dr. Wani has co-authored the book ‘Advances in Deep Learning,’ co-edited the book ‘Deep Learning Applications,’ and co-edited 17 conference proceedings books in the area of machine learning and applications. He is a member of many academic and professional bodies, e.g., the Indian Society for Technical Education, Computer Society of India, and IEEE USA.
Dr. Taghi M. Khoshgoftaar is the Motorola Endowed Chair Professor of the Department of Computer and Electrical Engineering and Computer Science, Florida Atlantic University, and the Director of the NSF Big Data Training and Research Laboratory. His research interests are in big data analytics, data mining and machine learning, health informatics and bioinformatics, social network mining, and software engineering. He has published more than 750 refereed journal and conference papers in these areas. He was the Conference Chair of the IEEE International Conference on Machine Learning and Applications (ICMLA 2019). He is the Co-Editor-in-Chief of the Journal of Big Data. He has served on organizing and technical program committees of various international conferences, symposia, and workshops. He has been a Keynote Speaker at multiple international conferences and has given many invited talks at various venues. Also, he has served as North American Editor of the Software Quality Journal, was on the editorial boards of the journals Multimedia Tools and Applications, Knowledge and Information Systems, and Empirical Software Engineering, and is on the editorial boards of the journals Software Quality, Software Engineering and Knowledge Engineering, and Social Network Analysis and Mining.
Dr. Vasile Palade is currently a Professor of Artificial Intelligence and Data Science at Coventry University, UK. He previously held several academic and research positions at the University of Oxford, UK, the University of Hull, UK, and the University of Galati, Romania. His research interests are in the area of machine learning, with a focus on neural networks and deep learning, and with main application to image processing, social network data analysis and web mining, smart cities, and health, among others. Dr. Palade is author and co-author of more than 170 papers in journals and conference proceedings as well as several books on machine learning and applications. He is an Associate Editor for several reputed journals, such as Knowledge and Information Systems and Neurocomputing. He has delivered keynote talks at international conferences on machine learning and applications. Dr. Vasile Palade is an IEEE Senior Member.
Contributors
Meshal Alfarhood, Department of Electrical Engineering and Computer Science, University of Missouri-Columbia, Columbia, USA
M. Arif Wani, Department of Computer Science, University of Kashmir, Srinagar, India
Jianlin Cheng, Department of Electrical Engineering and Computer Science, University of Missouri-Columbia, Columbia, USA
Stavros-Richard G. Christopoulos, Institute for Future Transport and Cities, Coventry University, Coventry, UK; Faculty of Engineering, Coventry University, Coventry, UK
Randy Clinton Paffenroth, Worcester Polytechnic Institute, Mathematical Sciences Computer Science & Data Science, Worcester, MA, USA
Gabriel Pellegrino da Silva, Institute of Computing, University of Campinas, Campinas, SP, Brazil
Georg Fischer, Friedrich-Alexander-University Erlangen-Nuremberg, Erlangen, Germany
Ashraf Gaffar, Arizona State University, Tempe, USA
Clemens Gühmann, Chair of Electronic Measurement and Diagnostic Technology & Technische Universität Berlin, Berlin, Germany
Sven Hartmann, SEG Automotive Germany GmbH, Stuttgart, Germany
Abdullah-Al-Zubaer Imran, University of California, Los Angeles, CA, USA
Justin M. Johnson, Florida Atlantic University, Boca Raton, FL, USA
Sharif Amit Kamran, University of Nevada, Reno, NV, USA
Stratis Kanarachos, Faculty of Engineering, Coventry University, Coventry, UK
Taghi M. Khoshgoftaar, Florida Atlantic University, Boca Raton, FL, USA
Shokoufeh Monjezi Kouchak, Arizona State University, Tempe, USA
Guilherme Vieira Leite, Institute of Computing, University of Campinas, Campinas, SP, Brazil
Tahir Mujtaba, Department of Computer Science, University of Kashmir, Srinagar, India
Tathagata Mukherjee, Computer Science, University of Alabama, Huntsville, AL, USA
Harsh Nilesh Pathak, Expedia Group, Seattle, WA, USA
Uche Onyekpe, Institute for Future Transport and Cities, Coventry University, Coventry, UK; Research Center for Data Science, Coventry University, Coventry, UK
Vasile Palade, Research Center for Data Science, Coventry University, Coventry, UK
Eduardo Pasiliao, Munitions Directorate, Air Force Research Laboratory, Eglin AFB, Valparaiso, FL, USA
Helio Pedrini, Institute of Computing, University of Campinas, Campinas, SP, Brazil
Daniele Rosato, SEG Automotive Germany GmbH, Stuttgart, Germany
Debashri Roy, Computer Science, University of Central Florida, Orlando, FL, USA
Russell Sabir, SEG Automotive Germany GmbH, Stuttgart, Germany; Chair of Electronic Measurement and Diagnostic Technology & Technische Universität Berlin, Berlin, Germany
Ali Shihab Sabbir, Center for Cognitive Skill Enhancement, Independent University Bangladesh, Dhaka, Bangladesh
Sourajit Saha, Center for Cognitive Skill Enhancement, Independent University Bangladesh, Dhaka, Bangladesh
Avik Santra, Infineon Technologies AG, Neubiberg, Germany
Michael Stephan, Infineon Technologies AG, Neubiberg, Germany; Friedrich-Alexander-University Erlangen-Nuremberg, Erlangen, Germany
Alireza Tavakkoli, University of Nevada, Reno, NV, USA
Demetri Terzopoulos, University of California, Los Angeles, CA, USA
Deep Learning-Based Recommender Systems Meshal Alfarhood and Jianlin Cheng
Abstract The term “information overload” has gained popularity over the last few years. It defines the difficulties people face in finding what they want from a huge volume of available information. Recommender systems have been recognized to be an effective solution to such issues, such that suggestions are made based on users’ preferences. This chapter introduces an application of deep learning techniques in the domain of recommender systems. Generally, collaborative filtering approaches, and Matrix Factorization (MF) techniques in particular, are widely known for their convincing performance in recommender systems. We introduce a Collaborative Attentive Autoencoder (CATA) that improves the matrix factorization performance by leveraging an item’s contextual data. Specifically, CATA learns the proper features from scientific articles through the attention mechanism that can capture the most pertinent parts of information in order to make better recommendations. The learned features are then incorporated into the learning process of MF. Comprehensive experiments on three real-world datasets have shown our method performs better than other state-of-the-art methods according to various evaluation metrics. The source code of our model is available at: https://github.com/jianlin-cheng/CATA.
This chapter is an extended version of our published paper at the IEEE ICMLA conference 2019 [1]. This chapter incorporates new experimental contributions compared to the original conference paper.
M. Alfarhood · J. Cheng, Department of Electrical Engineering and Computer Science, University of Missouri-Columbia, Columbia, USA
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021. M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_1
1 Introduction The era of e-commerce has vastly changed people’s lifestyles during the first part of the twenty-first century. People today tend to do many of their daily routines online, such as shopping, reading the news, and watching movies. Nevertheless, consumers often face difficulties while exploring related items such as new fashion trends because they are not aware of their existence due to the overwhelming amount of information available online. This phenomenon is widely known as “information overload”. Therefore, Recommender Systems (RSs) are a critical solution for helping users make decisions when there are lots of choices. RSs have been integrated into and have become an essential part of every website due to their impact on increasing customer interactions, attracting new customers, and growing businesses’ revenue. Scientific article recommendation is a very common application for RSs. It keeps researchers updated on recent related work in their field. One traditional way to find relevant articles is to go through the references section in other articles. Yet, this approach is biased toward heavily cited articles, such that new relevant articles with higher impact have less chance to be found. Another method is to search for articles using keywords. Although this technique is popular among researchers, they must filter out a tremendous number of articles from the search results to retrieve the most suitable articles. Moreover, all users get the same search results with the same keywords, and these results are not personalized based on the users’ personal interests. Thus, recommendation systems can address this issue and help scientists and researchers find valuable articles while being aware of recent related work. Over the last few decades, a lot of effort has been made by both academia and industry on proposing new ideas and solutions for RSs, which ultimately help service providers in adopting such models in their system architecture. 
The research in RSs has evolved remarkably following the Netflix Prize competition (www.netflixprize.com) in 2006, where the company offered one million dollars to any team that could improve its recommendation accuracy by 10%. Since that time, collaborative filtering models, and matrix factorization techniques in particular, have become the most common models due to their effective performance. Generally, recommendation models are classified into three categories: Collaborative Filtering (CF) models, Content-Based Filtering (CBF) models, and hybrid models. CF models [2–4] focus on users’ histories, such that users with similar past behaviors tend to have similar future tastes. On the other hand, CBF models work by learning the item’s features from its informational description, such that two items are likely to be similar if they share many characteristics. For example, two songs are similar to each other if they share the same artist, genre, tempo, energy, etc. In CF models, by contrast, two items are considered similar when multiple users rate them in the same manner, even if the items have different characteristics.
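To make the two notions of item similarity concrete, the following minimal sketch computes both on toy data. The interaction matrix, the content features, and the use of cosine similarity are illustrative assumptions, not part of the chapter's method.

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine similarity between two 1-D vectors (0.0 if either is all zeros)."""
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(a @ b / denom) if denom > 0 else 0.0

# Hypothetical toy data: 4 users x 3 items, 1 = interacted, 0 = no interaction.
ratings = np.array([
    [1, 1, 0],
    [1, 1, 0],
    [0, 0, 1],
    [1, 1, 1],
], dtype=float)

# Hypothetical item descriptions: 3 items x 4 content features (e.g. genre flags).
features = np.array([
    [1, 0, 0, 1],
    [0, 1, 1, 0],
    [1, 0, 0, 1],
], dtype=float)

# CF view: items are compared through the users who interacted with them
# (columns of the interaction matrix).
cf_sim_01 = cosine_similarity(ratings[:, 0], ratings[:, 1])
cf_sim_02 = cosine_similarity(ratings[:, 0], ratings[:, 2])

# CBF view: items are compared through their own descriptions
# (rows of the feature matrix).
cbf_sim_01 = cosine_similarity(features[0], features[1])
cbf_sim_02 = cosine_similarity(features[0], features[2])
```

In this toy setting, items 0 and 1 share no content features yet look identical to a CF model because the same users interacted with both, while items 0 and 2 have identical descriptions but different audiences.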
Generally, CF models function better than CBF models. However, CF performance drops substantially when users or items have an insufficient amount of feedback data. This problem is known as the data sparsity problem. To tackle data sparsity, hybrid models have been widely proposed in recent works [5–8], in which the content information used in CBF models is incorporated into CF models to improve system performance. Hybrid models are divided into two sub-categories according to how the models are trained: loosely coupled models and tightly coupled models [7]. Loosely coupled models train CF and CBF models separately, like ensembles, and then determine the final score based on the scores of the two separate models. Tightly coupled models, on the other hand, train both CF and CBF models jointly. In joint training, both models cooperate with one another to calculate the final score under the same loss function. Even though traditional recommendation approaches have achieved great success over the last years, they still have shortcomings in accurately modeling complex (e.g., non-linear) relationships between users and items. Deep neural networks, in contrast, are universal function approximators that are capable of modeling any continuous function. Recently, Deep Learning (DL) has become an effective approach for most data mining problems, and in the last few years it has also been applied to recommendation systems. One of the first works that applied DL to CF recommendations used Restricted Boltzmann Machines (RBM) [4]. However, this approach was not deep enough (two layers only) to learn users’ tastes from their histories, and it also did not take contextual information into consideration. Recently, Collaborative Deep Learning (CDL) [7] has become a very popular deep learning technique in RSs due to its promising performance.
CDL can be viewed as an updated version of Collaborative Topic Regression (CTR) [5] that substitutes the Latent Dirichlet Allocation (LDA) topic model with a Stacked Denoising Autoencoder (SDAE) to learn from item contents, and then integrates the learned latent features into Probabilistic Matrix Factorization (PMF). Lately, the Collaborative Variational Autoencoder (CVAE) [8] has been proposed to learn deep item latent features via a variational autoencoder. The authors show that their model learns better item features than CDL because it infers the latent variable distribution in latent space instead of observation space. However, both CDL and CVAE assume that all parts of their models contribute equally to the final predictions. Hence, in this work, we propose a deep learning-based model named Collaborative Attentive Autoencoder (CATA) for recommending scientific articles. In particular, we integrate the attention mechanism into our unsupervised deep learning process to identify an item’s features. We learn the item’s features from the article’s textual information (e.g., the article’s title and abstract) to enhance the recommendation quality. The compressed low-dimensional representation learned by the unsupervised model is then incorporated into the matrix factorization approach for our ultimate recommendation. To demonstrate the capability of our proposed model to generate more relevant recommendations, we conduct comprehensive experiments on three real-world datasets, taken from the CiteULike website (www.citeulike.org), to evaluate CATA against multiple recent works. The experimental results show that our model can extract more useful information from an article’s contextual data than other models. More importantly, CATA performs very well when data sparsity is extremely high. The remainder of this chapter is organized as follows. First, we describe the matrix factorization method in Sect. 2. We introduce our model, CATA, in Sect. 3. The experimental results of our model against the state-of-the-art models are discussed thoroughly in Sect. 4. We then conclude our work in Sect. 5.
2 Background
Our work is designed and evaluated on recommendations with implicit feedback. Thus, in this section, we describe the well-known collaborative filtering approach, Matrix Factorization, for implicit feedback problems.
2.1 Matrix Factorization
Matrix Factorization (MF) [2] is the most popular CF method, mainly due to its simplicity and efficiency. The idea behind MF is to decompose the user-item matrix, $R \in \mathbb{R}^{n \times m}$, into two lower dimensional matrices, $U \in \mathbb{R}^{n \times d}$ and $V \in \mathbb{R}^{m \times d}$, such that the inner product of $U$ and $V$ approximates the original matrix $R$. Here $d$ is the dimension of the latent factors, with $d \ll \min(n, m)$, and $n$ and $m$ are the numbers of users and items in the system, respectively. Figure 1 illustrates the MF process.
$$R \approx U \cdot V^{T} \tag{1}$$
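As a minimal illustration of Eq. 1, the sketch below builds two random factor matrices with NumPy and forms the predicted user-item matrix. The sizes and the random initialization are assumptions made only for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, d = 6, 5, 2                 # users, items, latent dimension (toy sizes)

U = rng.normal(size=(n, d))       # one d-dimensional row per user
V = rng.normal(size=(m, d))       # one d-dimensional row per item

R_hat = U @ V.T                   # Eq. 1: predicted user-item matrix, shape (n, m)

# A single prediction is the inner product of a user vector and an item vector.
i, j = 3, 1
r_ij = U[i] @ V[j]
```

Because every entry of `R_hat` is built from only `d` latent factors, its rank is at most `d`, which is what makes the representation compact when d is much smaller than n and m.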
Fig. 1 Matrix factorization illustration
MF optimizes the values of $U$ and $V$ by minimizing the sum of the squared differences between the actual values and the model predictions, with two added regularization terms, as shown here:
$$L = \sum_{i,j \in R} \frac{I_{ij}}{2} \left(r_{ij} - u_i v_j^{T}\right)^2 + \frac{\lambda_u}{2} \|u_i\|^2 + \frac{\lambda_v}{2} \|v_j\|^2 \tag{2}$$
where $I_{ij}$ is an indicator function that equals 1 if user $i$ has rated item $j$, and 0 otherwise. Also, $\|\cdot\|$ denotes the Euclidean norm, and $\lambda_u$, $\lambda_v$ are two regularization parameters that prevent the values of $U$ and $V$ from becoming too large, which helps avoid overfitting.
Explicit data, such as ratings ($r_{ij}$), are not regularly available. Therefore, Weighted Regularized Matrix Factorization (WRMF) [9] introduces two modifications to the previous objective function to make it work for implicit feedback. The optimization process in this case runs through all user-item pairs with different confidence levels assigned to each pair, as in the following:
$$L = \sum_{i,j \in R} \frac{c_{ij}}{2} \left(p_{ij} - u_i v_j^{T}\right)^2 + \frac{\lambda_u}{2} \|u_i\|^2 + \frac{\lambda_v}{2} \|v_j\|^2 \tag{3}$$
where $p_{ij}$ is the user preference score, with a value of 1 when user $i$ and item $j$ have an interaction, and 0 otherwise. $c_{ij}$ is a confidence variable whose value indicates how confident we are that the user likes the item. In general, $c_{ij} = a$ when $p_{ij} = 1$, and $c_{ij} = b$ when $p_{ij} = 0$, such that $a > b > 0$.
Stochastic Gradient Descent (SGD) [10] and Alternating Least Squares (ALS) [11] are two optimization methods that can be used to minimize the objective function of MF in Eq. 2. The first method, SGD, loops over each single training sample and then computes the prediction error as $e_{ij} = r_{ij} - u_i v_j^{T}$. The gradient of the objective function with respect to $u_i$ and $v_j$ can be computed as follows:
$$\begin{aligned}
\frac{\partial L}{\partial u_i} &= -\sum_{j} I_{ij} \left(r_{ij} - u_i v_j^{T}\right) v_j + \lambda_u u_i \\
\frac{\partial L}{\partial v_j} &= -\sum_{i} I_{ij} \left(r_{ij} - u_i v_j^{T}\right) u_i + \lambda_v v_j
\end{aligned} \tag{4}$$
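The objective in Eq. 2 and the gradients in Eq. 4 can be sketched in a few lines of NumPy. This is an illustrative sketch rather than the authors' implementation: the function names are hypothetical, the toy data is random, and the regularizers are summed over all users and items, an assumption chosen to be consistent with the per-vector gradients in Eq. 4. A central finite difference serves as a sanity check on one gradient entry.

```python
import numpy as np

def mf_loss(R, I, U, V, lam_u, lam_v):
    """Eq. 2, with regularizers summed over all users and items (assumption)."""
    err = I * (R - U @ V.T)                      # I_ij * e_ij, shape (n, m)
    return (0.5 * np.sum(err ** 2)
            + 0.5 * lam_u * np.sum(U ** 2)
            + 0.5 * lam_v * np.sum(V ** 2))

def mf_gradients(R, I, U, V, lam_u, lam_v):
    """Eq. 4 for all users and items at once; row i of grad_U is dL/du_i."""
    err = I * (R - U @ V.T)
    grad_U = -err @ V + lam_u * U                # sums over items j
    grad_V = -err.T @ U + lam_v * V              # sums over users i
    return grad_U, grad_V

# Hypothetical toy problem: 4 users, 5 items, 3 latent factors.
rng = np.random.default_rng(0)
n, m, d = 4, 5, 3
R = rng.integers(1, 6, size=(n, m)).astype(float)   # ratings in 1..5
I = (rng.random((n, m)) < 0.6).astype(float)        # 1 where a rating is observed
U = rng.normal(size=(n, d))
V = rng.normal(size=(m, d))
lam_u = lam_v = 0.1

grad_U, grad_V = mf_gradients(R, I, U, V, lam_u, lam_v)

# Central finite difference on a single entry of U.
eps = 1e-6
U_plus, U_minus = U.copy(), U.copy()
U_plus[0, 0] += eps
U_minus[0, 0] -= eps
numeric = (mf_loss(R, I, U_plus, V, lam_u, lam_v)
           - mf_loss(R, I, U_minus, V, lam_u, lam_v)) / (2 * eps)
```

The analytic entry `grad_U[0, 0]` and the numerical estimate `numeric` should agree closely, which is a quick way to catch sign or index errors before running any optimizer.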
After calculating the gradient, SGD updates the user and item latent factors in the opposite direction of the gradient using the following equations:
6
M. Alfarhood and J. Cheng
⎞ ⎛ ui ← ui + α ⎝ Ii j ei j v j − λu u i ⎠ j
vj ← vj + α
(5)
Ii j ei j u i − λ j v j
i
where α is the learning rate. Even though SGD is easy to implement and can be faster than ALS in some cases, it is not suitable for implicit feedback, since looping over every single training sample is not practical when all user-item pairs are involved. ALS works better in this case. ALS iteratively optimizes U while V is fixed, and vice versa. This optimization process is repeated until the model converges. To determine which user and item vectors minimize the objective function for implicit data (Eq. 3), we first take the derivative of L with respect to $u_i$ and set it to zero:
$$\begin{aligned}
\frac{\partial L}{\partial u_i} &= -\sum_j c_{ij}\left(p_{ij} - u_i v_j^T\right) v_j + \lambda_u u_i \\
0 &= -C_i (P_i - u_i V^T) V + \lambda_u u_i \\
0 &= -C_i V P_i + C_i V u_i V^T + \lambda_u u_i \\
V C_i P_i &= u_i V C_i V^T + \lambda_u u_i \\
V C_i P_i &= u_i (V C_i V^T + \lambda_u I) \\
u_i &= V C_i P_i (V C_i V^T + \lambda_u I)^{-1} \\
u_i &= (V C_i V^T + \lambda_u I)^{-1} V C_i P_i
\end{aligned} \tag{6}$$
where $I$ is the identity matrix. Similarly, taking the derivative of $L$ with respect to $v_j$ leads to
$$v_j = (U C_j U^T + \lambda_v I)^{-1} U C_j P_j \tag{7}$$
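The closed-form user update of Eq. 6 can be sketched in NumPy as follows. This is only an illustrative sketch under stated assumptions: the function name `als_user_update` and the toy data are hypothetical, and item vectors are stored as rows of `V`, so the product the chapter writes as $V C_i V^T$ appears here as `V.T @ C_i @ V`. The confidence values a = 1 and b = 0.01 are the ones the chapter reports in its hyper-parameter settings.

```python
import numpy as np

def als_user_update(V, p_i, c_i, lam_u):
    """Closed-form u_i for fixed item factors (Eq. 6).

    V   : (m, k) array, one item vector per row (layout is an assumption)
    p_i : (m,) binary preferences of user i
    c_i : (m,) confidence weights, i.e. the diagonal of C_i
    """
    k = V.shape[1]
    A = V.T @ (c_i[:, None] * V) + lam_u * np.eye(k)  # V^T C_i V + lam_u I
    b = V.T @ (c_i * p_i)                             # V^T C_i p_i
    return np.linalg.solve(A, b)                      # avoids forming the inverse

rng = np.random.default_rng(0)
V = rng.normal(size=(6, 3))                       # toy item factors
p_i = np.array([1, 0, 0, 1, 0, 0], dtype=float)   # toy preferences
c_i = np.where(p_i == 1, 1.0, 0.01)               # a = 1, b = 0.01
lam_u = 0.1
u_i = als_user_update(V, p_i, c_i, lam_u)

# At the minimizer, the gradient of Eq. 3 with respect to u_i vanishes.
grad = -V.T @ (c_i * (p_i - V @ u_i)) + lam_u * u_i
```

Solving the linear system rather than inverting the matrix is a numerical design choice of this sketch; the item update of Eq. 7 has the same form with the roles of U and V exchanged.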
3 Proposed Model
In this section, we describe our proposed model in depth. The intuition behind our model is to learn the latent factors of items in PMF with the help of the available textual side information. We use an attentive unsupervised model to capture richer information from the available data. The architecture of our model is displayed in Fig. 2. We first define the problem with implicit feedback before going through the details of our model.
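Fig. 2 Collaborative attentive autoencoder architecture (diagram labels: input $X_j$, Encoder, Softmax, Attention, Decoder, reconstruction $\hat{X}_j$, compressed representation $Z_j$, latent factors $U_i$ and $V_j$, interaction $R_{ij}$, hyper-parameters $\lambda_u$ and $\lambda_v$, with $i = 1{:}n$ and $j = 1{:}m$)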
3.1 Problem Definition
User-item interaction data is the primary source for training recommendation engines. This data can be collected either explicitly or implicitly. In explicit data, users directly express their opinion about an item through a rating system to show how much they like that item. Ratings usually range from one star to five stars, with five meaning very interested and one meaning not interested. This type of data is very useful and reliable because it represents users' actual feelings about items. However, users' ratings are sometimes unavailable because explicit opinions are difficult to obtain. In this case, implicit feedback can be obtained indirectly from the user's behavior, such as clicks, bookmarks, or the time spent viewing an item. For instance, if a user has listened to a song 10 times in the last two days, he or she most likely likes this song. Thus, implicit data is more prevalent and easier to collect, but it is generally less reliable than explicit data. Also, all the observed interactions in implicit data constitute positive feedback, while negative feedback is missing. This is known as the one-class problem. Multiple previous works aim to deal with the one-class problem. A simple solution is to treat all missing data as negative feedback. However, this assumption does not hold, because a missing (unobserved) interaction could have been positive had the user been aware of the item's existence. Therefore, using this strategy might result in a misleading model due to faulty assumptions at the outset. On the contrary, if
we treat all missing data as unobserved and include no negative feedback in model training, the resulting model is probably useless since it is trained only on positive data. As a result, sampling negative feedback from the missing data is one practical solution to this problem, as proposed by [12]. In addition, Weighted Regularized Matrix Factorization (WRMF) [9] is another proposed solution; it introduces a confidence variable that works as a weight measuring how likely a user is to like an item. In general, the recommendation problem with implicit data is usually formulated as follows:
$$R_{ij} = \begin{cases} 1, & \text{if there is user-item interaction} \\ 0, & \text{otherwise} \end{cases} \tag{8}$$
where the ones represent all the positive feedback. However, it is important to note that a value of 0 does not always imply negative feedback; users may simply be unaware of the existence of those items. In addition, the user-item interaction matrix (R) is usually highly imbalanced, such that the number of observed interactions is much smaller than the number of unobserved interactions. In other words, matrix R is very sparse, meaning that users interact, explicitly or implicitly, with only a very small number of items compared to the total number of items in this matrix. Sparsity is a frequent problem in RSs, and it makes it genuinely challenging for any proposed model to provide effective personalized recommendations. The following sections explain our methodology, where we aim to reduce the influence of the aforementioned problems.
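To make the formulation concrete, the following minimal sketch builds the binary interaction matrix of Eq. 8 from a list of (user, item) pairs, measures its sparsity, and derives the confidence weights used by WRMF. The helper name `build_interaction_matrix` and the toy pairs are hypothetical; the values a = 1 and b = 0.01 follow the chapter's hyper-parameter settings.

```python
import numpy as np

def build_interaction_matrix(pairs, n_users, n_items):
    """Return R with R[i, j] = 1 for every observed (user i, item j) pair."""
    R = np.zeros((n_users, n_items), dtype=np.int8)
    for user, item in pairs:
        R[user, item] = 1
    return R

# Toy implicit feedback: 4 observed interactions among 3 users and 5 items.
pairs = [(0, 1), (0, 3), (1, 0), (2, 3)]
R = build_interaction_matrix(pairs, n_users=3, n_items=5)

# Sparsity is the fraction of entries without an observed interaction.
sparsity = 1.0 - R.sum() / R.size

# Confidence matrix: a for observed pairs, b for unobserved ones (a > b > 0).
a, b = 1.0, 0.01
C = np.where(R == 1, a, b)
```

A dense array is used only for readability; at the scale of the datasets discussed later, a sparse representation would be the more realistic choice.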
3.2 The Attentive Autoencoder
An autoencoder [13] is an unsupervised neural network that compresses high-dimensional input data into a lower-dimensional representation while preserving the abstract nature of the data. The network is generally composed of two main components: the encoder and the decoder. The encoder takes the input, passes it through multiple hidden layers, and generates a compressed representative vector, $Z_j$. The encoding function can be formulated as $Z_j = f(X_j)$. The decoder then reconstructs an estimate of the original input, $\hat{X}_j$, from the representative vector $Z_j$. The decoding function can be formulated as $\hat{X}_j = f(Z_j)$. The encoder and the decoder usually consist of the same number of hidden layers and neurons. The output of each hidden layer is computed as follows:
$$h^{(\ell)} = \sigma\left(h^{(\ell-1)} W^{(\ell)} + b^{(\ell)}\right) \tag{9}$$
where $(\ell)$ is the layer number, $W$ is the weight matrix, $b$ is the bias vector, and $\sigma$ is a non-linear activation function. We use the Rectified Linear Unit (ReLU) as the activation function. Our model takes input from the article's textual data, $X_j = \{x^1, x^2, \ldots, x^s\}$, where each $x^i$ is a value in [0, 1] and $s$ is the vocabulary size of the articles' titles and abstracts. In other words, the input of our autoencoder network is a normalized bag-of-words histogram over the filtered vocabulary of the articles' titles and abstracts. Batch Normalization (BN) [14] was proposed as a solution to the internal covariate shift problem, where the distribution of each layer's input in a deep neural network changes during training and makes the model harder to train. In addition, BN can act as a regularizer, like Dropout [15], in deep neural networks. Accordingly, we apply a batch normalization layer after each hidden layer in our autoencoder to obtain a stable distribution of each layer's output. Furthermore, we place an attention mechanism between the encoder and the decoder, such that only the relevant parts of the encoder output are selected for the input reconstruction. Attention in deep learning can be described simply as a vector of weights that shows the importance of the input elements. The intuition behind attention is that not all parts of the input are equally significant, i.e., only a few parts are significant for the model. We first calculate the scores as a probability distribution over the encoder's output using the softmax(.) function:
$$f(z_c) = \frac{e^{z_c}}{\sum_d e^{z_d}} \tag{10}$$
The probability distribution and the encoder output are then multiplied element-wise to get $Z_j$. We use the attentive autoencoder to pretrain on the items' contextual information and then integrate the compressed representation, $Z_j$, into the computation of the items' latent factors, $V_j$, in the matrix factorization method. The dimensions of $Z_j$ and $V_j$ are set to be equal. Finally, we adopt the binary cross-entropy (Eq. 11) as the loss function to minimize in our attentive autoencoder model:
$$L = -\sum_k \big[\, y_k \log(p_k) + (1 - y_k)\log(1 - p_k) \,\big] \tag{11}$$
where $y_k$ corresponds to the correct labels and $p_k$ corresponds to the predicted values.
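The attention step (Eq. 10 followed by the element-wise product) and the loss of Eq. 11 can be sketched in NumPy as below. This is a hedged illustration rather than the authors' implementation: the function names, the numerical-stability shift inside the softmax, and the clipping constant `eps` are assumptions introduced here.

```python
import numpy as np

def softmax(z):
    """Eq. 10 applied along the last axis, shifted by the max for stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attend(encoder_out):
    """Weight each encoder output element by its softmax score."""
    weights = softmax(encoder_out)
    return weights * encoder_out  # element-wise product, giving Z_j

def binary_cross_entropy(y, p, eps=1e-12):
    """Eq. 11; predictions are clipped so that log(0) is never evaluated."""
    p = np.clip(p, eps, 1.0 - eps)
    return -np.sum(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

encoder_out = np.array([[2.0, 0.5, 0.1, 1.0]])  # toy encoder output for one article
z_j = attend(encoder_out)

# An uninformative prediction of 0.5 costs log(2) per element.
loss = binary_cross_entropy(np.array([1.0, 0.0]), np.array([0.5, 0.5]))
```

Because the softmax weights sum to one, the attended vector keeps the shape of the encoder output while shrinking the elements with low scores.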
3.3 Probabilistic Matrix Factorization
Probabilistic Matrix Factorization (PMF) [3] is a probabilistic linear model in which the latent factors and the users' preferences are drawn from Gaussian distributions:
$$u_i \sim \mathcal{N}(0, \lambda_u^{-1} I), \qquad v_j \sim \mathcal{N}(0, \lambda_v^{-1} I), \qquad p_{ij} \sim \mathcal{N}(u_i v_j^T, \sigma^2) \tag{12}$$
We integrate the items' contents, trained through the attentive autoencoder, into PMF. Therefore, the objective function in Eq. 3 changes slightly to become
$$L = \sum_{i,j \in R} \frac{c_{ij}}{2}\left(p_{ij} - u_i v_j^T\right)^2 + \frac{\lambda_u}{2}\|u_i\|^2 + \frac{\lambda_v}{2}\|v_j - \theta(X_j)\|^2 \tag{13}$$
where $\theta(X_j) = \mathrm{Encoder}(X_j) = Z_j$. Taking the partial derivatives of this objective function with respect to $u_i$ and $v_j$ and setting them to zero yields the following update equations:
$$u_i = (V C_i V^T + \lambda_u I)^{-1} V C_i P_i, \qquad v_j = (U C_j U^T + \lambda_v I)^{-1} \left(U C_j P_j + \lambda_v \theta(X_j)\right) \tag{14}$$
We optimize the values of u i and v j using the Alternating Least Squares (ALS) optimization method.
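The item update can be checked with a short derivation that mirrors the steps of Eq. 6, using the same notation. It is a sketch of the reasoning, assuming the objective in Eq. 13; the only new term relative to Eq. 7 comes from the regularizer $\|v_j - \theta(X_j)\|^2$.

```latex
\begin{aligned}
\frac{\partial L}{\partial v_j}
  &= -\sum_i c_{ij}\left(p_{ij} - u_i v_j^T\right) u_i
     + \lambda_v \left(v_j - \theta(X_j)\right) = 0 \\
(U C_j U^T + \lambda_v I)\, v_j
  &= U C_j P_j + \lambda_v\, \theta(X_j) \\
v_j
  &= (U C_j U^T + \lambda_v I)^{-1}
     \left(U C_j P_j + \lambda_v\, \theta(X_j)\right)
\end{aligned}
```

The term $\lambda_v \theta(X_j)$ therefore sits inside the product with the inverse. With $\theta(X_j) = 0$ the expression reduces to Eq. 7, and for a large $\lambda_v$ the item vector is pulled toward the autoencoder representation $Z_j$.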
3.4 Prediction
After our model has been trained and the latent factors of users and articles, U and V, are identified, we calculate the prediction scores of user $i$ for every article as the dot product of vector $u_i$ with all vectors in V, i.e., $scores_i = u_i V^T$. Then, we sort all articles by their prediction scores in descending order and recommend the top-K articles to user $i$. We go through all users in U in our evaluation and report the average performance among all users. The overall process of our approach is illustrated in Algorithm 1.
Algorithm 1: CATA algorithm
pretrain autoencoder with input X;
Z ← θ(X);
U, V ← Initialize with random values;
while not converged do
    for each user i do
        u_i ← update using Equation 14;
    end for
    for each item j do
        v_j ← update using Equation 14;
    end for
end while
for each user i do
    scores_i ← u_i V^T;
    sort(scores_i) in descending order;
end for
Evaluate the top-K recommendations;
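The prediction step at the end of the algorithm can be sketched as follows. The function name `recommend_top_k` and the toy factors are hypothetical, and the optional `seen` argument, which excludes articles already in the user's training library, is an assumption of this sketch rather than something the chapter specifies.

```python
import numpy as np

def recommend_top_k(u_i, V, k, seen=()):
    """Rank articles for one user by the score u_i . v_j and keep the best k.

    V    : (m, d) array with one article vector per row
    seen : indices to exclude from the ranking (assumed behavior)
    """
    scores = (V @ u_i).astype(float)
    if seen:
        scores[list(seen)] = -np.inf       # push excluded articles to the end
    order = np.argsort(-scores, kind="stable")  # descending order of score
    return order[:k]

V = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])  # toy item factors
u = np.array([1.0, 2.0])                                        # toy user factor

top2 = recommend_top_k(u, V, k=2)                # scores are 1, 2, 3, 1.5
top2_unseen = recommend_top_k(u, V, k=2, seen=[2])
```

Here `top2` contains articles 2 and 1, and excluding article 2 promotes article 3 into the list.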
4 Experiments
In this section, we conduct extensive experiments aiming to answer the following research questions:
• RQ1: How does our proposed model, CATA, perform against state-of-the-art methods?
• RQ2: Does adding the attention mechanism actually improve our model performance?
• RQ3: How could different values of the regularization parameters (λu and λv) affect CATA performance?
• RQ4: What is the impact of different dimension values of users and items' latent space on CATA performance?
• RQ5: How many training epochs are sufficient for pretraining our autoencoder?
Before answering these research questions, we first describe the datasets used in our evaluations, the evaluation metrics, and the baseline approaches we use to evaluate our model against.
4.1 Datasets
Three scientific article datasets are used to evaluate our model against the state-of-the-art methods. All datasets are collected from the CiteULike website. The first dataset is called Citeulike-a, which is collected by [5]. It has 5,551 users, 16,980 articles, and 204,986 user-article pairs. The sparseness of this dataset is extremely high: only around 0.22% of the user-article matrix has interactions. Each user has at least
ten articles in his or her library. On average, each user has 37 articles in his or her library and each article has been added to 12 users’ libraries. The second dataset is called Citeulike-t, which is collected by [6]. It has 7,947 users, 25,975 articles, and 134,860 user-article pairs. This dataset is actually sparser than the first one with only 0.07% available user-article interactions. Each user has at least three articles in his or her library. On average, each user has 17 articles in his or her library and each article has been added to five users’ libraries. Lastly, Citeulike-2004–2007 is the third dataset, and it is collected by [16]. It is three times bigger than the previous ones with regard to the user-article matrix. It has 3,039 users, 210,137 articles, and 284,960 user-article pairs. This dataset is the sparsest in this experiment, with a sparsity equal to 99.95%. Each user has at least ten articles in his or her library. On average, each user has 94 articles in his or her library and each article has been added only to one user library. Brief statistics of the datasets are shown in Table 1. Title and abstract of each article are given in each dataset. The average number of words per article in both title and abstract after our text preprocessing is 67 words in Citeulike-a, 19 words in Citeulike-t, and 55 words in Citeulike-2004–2007. We follow the same preprocessing techniques as the state-of-the-art models in [5, 7, 8]. A five-stage procedure to preprocess the textual content is displayed in Fig. 3. Each article title and abstract are combined together and then are preprocessed such that stop words are removed. After that, top-N distinct words based on the TF-IDF measurement are picked out. 
8,000 distinct words are selected for the Citeulike-a dataset, 20,000 distinct words are selected for the Citeulike-t dataset, and 19,871 distinct words are selected for the Citeulike-2004–2007 dataset to form the bag-of-words histograms, which are then normalized into values between 0 and 1 based on the vocabularies' occurrences.
Table 1 Descriptions of citeulike datasets
Dataset               #Users   #Articles   #Pairs    Sparsity (%)
Citeulike-a           5,551    16,980      204,986   99.78
Citeulike-t           7,947    25,975      134,860   99.93
Citeulike-2004–2007   3,039    210,137     284,960   99.95
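The sparsity figures and the per-user and per-article averages quoted in this section follow directly from the counts of users, articles, and pairs. A minimal sketch of the arithmetic (the function name `describe` and the dictionary layout are illustrative choices):

```python
# (users, articles, user-article pairs) as reported for each dataset
datasets = {
    "Citeulike-a": (5551, 16980, 204986),
    "Citeulike-t": (7947, 25975, 134860),
    "Citeulike-2004-2007": (3039, 210137, 284960),
}

def describe(users, articles, pairs):
    """Derive sparsity and average library sizes from the raw counts."""
    density = pairs / (users * articles)
    return {
        "sparsity_pct": 100.0 * (1.0 - density),
        "articles_per_user": pairs / users,
        "users_per_article": pairs / articles,
    }

stats = {name: describe(*counts) for name, counts in datasets.items()}
```

For Citeulike-a this gives a sparsity of about 99.78%, roughly 37 articles per user and 12 users per article. For Citeulike-2004–2007 the computed sparsity is about 99.955%, which corresponds to the 99.95 entry when truncated to two decimals.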
Fig. 3 A five-stage procedure for preprocessing articles’ titles and abstracts
Fig. 4 Ratio of articles that have been added to ≤N users’ libraries
Figure 4 shows the ratio of articles that have been added to N or fewer users' libraries. For example, 15, 77, and 99% of the articles in Citeulike-a, Citeulike-t, and Citeulike-2004–2007, respectively, are added to five or fewer users' libraries. Also, only 1% of the articles in Citeulike-a have been added to just one user's library, while the rest have been added to more than one. In contrast, 13 and 77% of the articles in Citeulike-t and Citeulike-2004–2007, respectively, have been added to just one user's library. This illustrates how the sparseness of the data with regard to articles grows from one dataset to the next.
4.2 Evaluation Methodology
We follow the state-of-the-art techniques [6–8] to generate our training and testing sets. For each dataset, we create two versions for the sparse and dense settings. In total, six dataset cases are used in our evaluation. To form the sparse (P = 1) and the dense (P = 10) datasets, P items are randomly selected from each user library to generate the training set, while the remaining items from each user library are used to generate the testing set. As a result, when P = 1, only 2.7, 5.9, and 1.1% of the data entries are used to generate the training set in Citeulike-a, Citeulike-t, and Citeulike-2004–2007, respectively. Similarly, 27.1, 39.6, and 10.7% of the data entries are used to generate the training set when P = 10, as Fig. 5 shows.
Fig. 5 The percentage of the data entries that forms the training and testing sets in all citeulike datasets: (a) Citeulike-a, (b) Citeulike-t, (c) Citeulike-2004–2007
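The per-user split described above can be sketched as follows. The function name `split_user_library`, the fixed seed, and the toy libraries are hypothetical. Capping the training sample at the library size for users with fewer than P articles is an assumption of this sketch; the chapter does not state how such users are handled.

```python
import random

def split_user_library(library, p, rng):
    """Randomly pick p items for training; the remainder forms the test set."""
    items = list(library)
    train = rng.sample(items, min(p, len(items)))  # assumed cap for small libraries
    train_set = set(train)
    test = [item for item in items if item not in train_set]
    return train, test

rng = random.Random(42)  # fixed seed so the toy split is reproducible
libraries = {0: [3, 7, 9, 12], 1: [1, 2, 5], 2: [4, 8, 9, 10, 11]}

# Sparse setting: one training article per user.
sparse = {user: split_user_library(lib, 1, rng) for user, lib in libraries.items()}
# Dense setting applied to a user who owns fewer than 10 articles.
dense_train, dense_test = split_user_library(libraries[1], 10, rng)
```

As a consistency check on the reported percentages, P = 1 on Citeulike-a places 5,551 of the 204,986 pairs in the training set, which is about 2.7%.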
We use recall and Discounted Cumulative Gain (DCG) as our evaluation metrics. Recall is usually used to evaluate recommender systems with implicit feedback. Precision, however, is not well suited to implicit feedback because a zero value in the user-article interaction matrix has two possible meanings: either the user is not interested in the article, or the user is not aware of its existence. The precision metric would treat every zero value as a lack of interest, which is not the case. Recall per user can be measured using the following formula:
$$\text{recall@}K = \frac{|\text{Relevant Articles} \cap K\ \text{Recommended Articles}|}{|\text{Relevant Articles}|} \tag{15}$$
where K is set manually in the experiment and represents the top K articles of each user. We set K = 10, 50, 100, 150, 200, 250, and 300 in our evaluations. The overall recall is calculated as the average recall among all users. If K equals the number of articles in the dataset, recall will have a value of 1. Recall does not take into account the ranking of articles within the top-K recommendations, as long as they are in the top-K list; DCG does. DCG shows the capability of the recommendation engine to place relevant articles at the top of the ranking list, so articles at higher-ranked positions have more value than others. The DCG among all users can be measured using the following equation:
$$\text{DCG@}K = \frac{1}{|U|}\sum_{u=1}^{|U|}\sum_{i=1}^{K}\frac{rel(i)}{\log_2(i+1)} \tag{16}$$
where |U | is the total number of users, i is the rank of the top-K articles recommended by the model, and rel(i) is an indicator function that outputs 1 if the article at rank i is a relevant article, and 0 otherwise.
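Both metrics are short to implement for a single user; averaging the per-user values over all users then gives the overall figures. The sketch below uses hypothetical function names and toy data, and returning 0 for a user without relevant articles is an assumption made here to avoid a division by zero.

```python
import math

def recall_at_k(ranked, relevant, k):
    """Eq. 15 for one user: fraction of relevant articles found in the top k."""
    if not relevant:
        return 0.0  # assumed convention for users with an empty test set
    hits = sum(1 for article in ranked[:k] if article in relevant)
    return hits / len(relevant)

def dcg_at_k(ranked, relevant, k):
    """Inner sum of Eq. 16 for one user, with rel(i) in {0, 1}."""
    return sum(
        1.0 / math.log2(rank + 1)
        for rank, article in enumerate(ranked[:k], start=1)
        if article in relevant
    )

ranked = [5, 2, 9, 1, 7]   # toy recommendation list, best first
relevant = {2, 7, 8}       # toy held-out articles of the user

r3 = recall_at_k(ranked, relevant, 3)   # only article 2 is in the top 3
r5 = recall_at_k(ranked, relevant, 5)   # articles 2 and 7 are in the top 5
d5 = dcg_at_k(ranked, relevant, 5)      # hits at ranks 2 and 5
```

In this toy case the two hits contribute $1/\log_2 3$ and $1/\log_2 6$ to the DCG, which shows how a relevant article loses value as it moves down the list while recall treats both hits equally.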
4.3 Baselines
We evaluate our approach against the following baselines:
• POP: Popular predictor is a non-personalized recommender system. It recommends the most popular articles in the training set, such that all users get identical recommendations. It is widely used as the baseline for personalized recommendation models.
• CDL: Collaborative Deep Learning (CDL) [7] is a deep Bayesian model that jointly models both user-item interaction data and items' content via a Stacked Denoising Autoencoder (SDAE) with a Probabilistic Matrix Factorization (PMF).
• CML+F: Collaborative Metric Learning (CML) [17] is a metric learning model that pulls items liked by a user closer to that user. Recommendations are then generated based on the K-Nearest Neighbor of each user. CML+F additionally uses a neural network with two fully connected layers to train items' features (articles' tags in this chapter) to update items' location. CML+F has been shown to have a better performance than CML.
• CVAE: Collaborative Variational Autoencoder (CVAE) [8] is a probabilistic model that jointly models both user-item interaction data and items' content through a Variational Autoencoder (VAE) with a Probabilistic Matrix Factorization (PMF). It can be considered as the baseline of our proposed approach since CVAE and CATA share the same strategy.
For hyper-parameter settings, we set the confidence variables (i.e., a and b) to a = 1 and b = 0.01. These are the same values used in CDL and CVAE. Also, a four-layer network is used to construct our attentive autoencoder. The four-layer network has the following shape: "#Vocabularies-400-200-100-50-100-200-400-#Vocabularies".
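Of the baselines above, POP is simple enough to sketch in a few lines: count how often each article occurs in the training pairs and recommend the same top-K list to every user. The function name `popular_ranking`, the toy pairs, and the tie-breaking by article id are assumptions of this sketch.

```python
from collections import Counter

def popular_ranking(train_pairs, k):
    """Return the k most frequent articles among (user, article) training pairs."""
    counts = Counter(article for _, article in train_pairs)
    # Sort by descending count; ties are broken by article id for determinism.
    ranked = sorted(counts, key=lambda article: (-counts[article], article))
    return ranked[:k]

# Toy training set: article 3 appears three times, article 1 twice, article 4 once.
train_pairs = [(0, 3), (1, 3), (2, 3), (0, 1), (1, 1), (2, 4)]
top2 = popular_ranking(train_pairs, k=2)
```

Every user receives the list `top2`, which is what makes POP a non-personalized reference point.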
4.4 Experimental Results
For each dataset, we repeat the data splitting four times with different random splits of the training and testing sets, as described in the evaluation methodology section. We use one split as a validation experiment to find the optimal values of λu and λv for our model and for the state-of-the-art models. We search a grid of the values {0.01, 0.1, 1, 10, 100}, and the best values on the validation experiment are reported in Table 2. The other three splits are used to report the average performance of our model against the baselines. In this section, we address the research questions defined at the beginning of this section.
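The grid search over the two regularization parameters can be sketched as an exhaustive loop over the 25 combinations. The function `grid_search` is a hypothetical helper, and `toy_score` is a stand-in with an arbitrary peak used only to exercise the loop; in practice the score would be a validation metric such as recall obtained by training the model with each pair of values.

```python
import math
from itertools import product

GRID = [0.01, 0.1, 1, 10, 100]  # candidate values for both parameters

def grid_search(score_fn, grid=GRID):
    """Return the (lambda_u, lambda_v) pair with the highest score."""
    return max(product(grid, grid), key=lambda pair: score_fn(*pair))

def toy_score(lam_u, lam_v):
    """Placeholder objective (not a real validation result), peaking at (1, 1)."""
    return -math.log10(lam_u) ** 2 - math.log10(lam_v) ** 2

n_combinations = len(list(product(GRID, GRID)))
best = grid_search(toy_score)
```

Keeping the validation split separate from the three reporting splits, as described above, prevents the selected values from being tuned on the data used for the final comparison.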
Table 2 Parameter settings for λu and λv based on the validation experiment
Approach  Citeulike-a                Citeulike-t                Citeulike-2004–2007
          Sparse       Dense         Sparse       Dense         Sparse       Dense
          λu     λv    λu     λv     λu     λv    λu     λv     λu     λv    λu     λv
CDL       0.01   10    0.01   10     0.01   10    0.01   10     0.01   10    0.01   10
CVAE      0.1    10    1      10     0.1    10    0.1    10     0.1    10    0.1    10
CATA      10     0.1   10     0.1    10     0.1   10     0.1    10     0.1   10     0.1
4.4.1 RQ1
To evaluate how our model performs, we conduct quantitative and qualitative comparisons to answer this question. Figures 6, 7, 8, and 9 show the performance of the top-K recommendations under the sparse and dense settings in terms of recall and DCG. First, the sparse cases are very challenging for any proposed model since there is less data for training. In the sparse setting where there is only one article in each user’s library in the training set, our model, CATA, outperforms the baselines in all datasets in terms of recall and DCG, as Figs. 6 and 7 show. More importantly, CATA outperforms the baselines by a wide margin in the Citeulike-2004–2007 dataset, where it is actually sparser and contains a huge number of articles. This validates the robustness of our model against data sparsity.
Fig. 6 Recall performance under the sparse setting, P = 1: (a) Citeulike-a, (b) Citeulike-t, (c) Citeulike-2004–2007
Fig. 7 DCG performance under the sparse setting, P = 1: (a) Citeulike-a, (b) Citeulike-t, (c) Citeulike-2004–2007
Fig. 8 Recall performance under the dense setting, P = 10: (a) Citeulike-a, (b) Citeulike-t, (c) Citeulike-2004–2007
Fig. 9 DCG performance under the dense setting, P = 10: (a) Citeulike-a, (b) Citeulike-t, (c) Citeulike-2004–2007
Second, with the dense setting where there are more articles in each user’s library in the training set, our model performs comparably to other baselines in Citeulike-a and Citeulike-t datasets as Figs. 8 and 9 show. As a matter of fact, many of the existing models actually work well under this setting, but poorly under the sparse setting. For example, CML+F achieves a competitive performance on the dense data; however, it fails on the sparse data since their metric space needs more interactions for users to capture their preferences. On the other hand, CATA outperforms the other baselines under this setting in the Citeulike-2004–2007 dataset. As a result, this experiment demonstrates the capability of our model for making more relevant recommendations under both sparse and dense data conditions. In addition to the previous quantitative comparisons, some qualitative results are reported in Table 3 as well. The table shows the top ten recommendations generated by our model (CATA) and the state-of-the-art model (CVAE) for one randomly selected user under the sparse setting using the Citeulike-2004–2007 dataset. In this example, user 20 has only one article in his training library, entitled “Assessment of Attention Deficit/ Hyperactivity Disorder in Adult Alcoholics”. From this example, this user seems to be interested in the treatment of Attention-Deficit/Hyperactivity Disorder (ADHD) among alcohol- and drug-using populations. Comparing the recommendation results between the two models, our model recommends more relevant articles based on the user’s interests. For instance, most of the recommended articles using the CATA model are related to the same topic, i.e., alcohol- and drug-users with ADHD. However, there are some irrelevant articles recommended by CVAE,
Table 3 The top-10 recommendations for one selected random user under the sparse setting, P = 1, using the citeulike-2004–2007 dataset
User ID: 20
Articles in the training set: Assessment of attention deficit/hyperactivity disorder in adult alcoholics
CATA
In user library?  Article
No   1. A double-blind, placebo-controlled withdrawal trial of dexmethylphenidate hydrochloride in children with ADHD
Yes  2. Double-blind placebo-controlled trial of methylphenidate in the treatment of adult ADHD patients with comorbid cocaine...
Yes  3. Methylphenidate treatment for cocaine abusers with adult attention-deficit/hyperactivity disorder: a pilot study
Yes  4. A controlled trial of methylphenidate in adults with attention deficit/hyperactivity disorder and substance use disorders
Yes  5. Treatment of cocaine dependent treatment seekers with adult ADHD: double-blind comparison of methylphenidate and...
Yes  6. A large, double-blind, randomized clinical trial of methylphenidate in the treatment of adults with ADHS
Yes  7. Patterns of inattentive and hyperactive symptomatology in cocaine-addicted and non-cocaine-addicted smokers diagnosed...
No   8. Frontocortical activity in children with comorbidity of tic disorder and attention-deficit hyperactivity disorder
Yes  9. Gender effects on attention-deficit/hyperactivity disorder in adults, revisited
Yes  10. Association between dopamine transporter (DAT1) genotype, left-sided inattention, and an enhanced response to...
CVAE
In user library?  Article
No   1. Psycho-social correlates of unwed mothers
No   2. A randomized, controlled trial of integrated home-school behavioral treatment for ADHD, predominantly inattentive type
No   3. Age and gender differences in children's and adolescents' adaptation to sexual abuse
No   4. Distress in individuals facing predictive DNA testing for autosomal dominant late-onset disorders: Comparing questionnaire results...
No   5. Combined treatment with sertraline and liothyronine in major depression: a randomized, double-blind, placebo-controlled trial
Yes  6. An open-label pilot study of methylphenidate in the treatment of cocaine dependent patients with adult ADHS
Yes  7. Treatment of cocaine dependent treatment seekers with adult ADHD: double-blind comparison of methylphenidate and...
No   8. A double-blind, placebo-controlled withdrawal trial of dexmethylphenidate hydrochloride in children with ADHS
Yes  9. Methylphenidate treatment for cocaine abusers with adult attention-deficit/hyperactivity disorder: a pilot study
Yes  10. A large, double-blind, randomized clinical trial of methylphenidate in the treatment of adults with ADHS
Table 4 Performance comparisons on sparse data with the attention layer (CATA) and without it (CATA–). CATA attains the best value in every column
Approach  Citeulike-a               Citeulike-t
          Recall@300   DCG@300      Recall@300   DCG@300
CATA–     0.3003       1.6644       0.2260       0.4661
CATA      0.3060       1.7206       0.2425       0.5160
such as recommended articles numbers 1, 3, and 4. From this example and other users’ examples we have examined, we can state that our model detects the major elements of articles’ contents and users’ preferences more accurately.
4.4.2 RQ2
To examine the importance of adding the attention layer into our autoencoder, we create another variant of our model that has the same architecture, but lacks the attention layer, which we call CATA–. We evaluate this model on the sparse cases using Citeulike-a and Citeulike-t datasets. The performance comparisons are reported in Table 4. As the table shows, adding the attention mechanism boosts the performance. Consequently, using the attention mechanism gives more focus to some parts of the encoded vocabularies in each article to better represent the contextual data, eventually leading to increased recommendation quality.
4.4.3 RQ3
There are two regularization parameters (λu and λv) in the objective function of the matrix factorization method. They keep the magnitude of the latent vectors from growing too large, which in turn prevents the model from overfitting the training data. Our previously reported results are obtained by setting λu and λv to the values in Table 2 based on the validation experiment. Here, we perform multiple experiments to show how different values of λu and λv affect our model's performance. We take the parameter values from the set {0.01, 0.1, 1, 10, 100}. Figure 10 visualizes how our model performs under each combination of λu and λv. We find that our model has a lower performance when the value of λv is considerably large under the dense setting, as Figs. 10b, d show. On the other hand, where the data is sparser, in Figs. 10a, c, e, f, a very small value of λu (e.g., 0.01) tends to give the lowest performance among all values. Even though Fig. 10f shows the performance under the dense setting for the Citeulike-2004–2007 dataset, it still exemplifies sparsity with regard to articles, as indicated earlier in Fig. 4, where close to 80% of the articles have been added to only one user's library. Generally, we observe that optimal performance happens in all datasets when λu = 10 and λv = 0.1. We can conclude that when there is sufficient user feedback, items' contextual information is no longer essential to obtain users' preferences, and vice versa.
Fig. 10 The impact of λu and λv on CATA performance for a, b Citeulike-a, c, d citeulike-t, and e, f citeulike-2004–2007 datasets. Panels: (a) Citeulike-a, P = 1; (b) Citeulike-a, P = 10; (c) Citeulike-t, P = 1; (d) Citeulike-t, P = 10; (e) Citeulike-2004–2007, P = 1; (f) Citeulike-2004–2007, P = 10
4.4.4 RQ4
The vectors of the latent features (U and V) represent the characteristics of users and items that a model tries to learn from data. We examine the impact of the size of these vectors on the performance of our model. In other words, we examine how many dimensions in the latent space are needed to represent the user and item features more accurately. It is worth mentioning that our reported results in the RQ1 section use 50 dimensions, the same size used by the state-of-the-art model (CVAE), in order to have fair comparisons. Here, we run our model again using five dimension sizes from the set {25, 50, 100, 200, 400}. Figure 11 shows how our model performs in terms of recall@100 under each dimension size. We observe that increasing the dimension size on dense data always leads to a gradual increase in our model's performance, as shown in Fig. 11b. Larger dimension sizes are recommended for sparse data as well. However, they do not necessarily improve the model's performance every time (e.g., the Citeulike-t dataset in Fig. 11a). Generally, dimension sizes between 100 and 200 are suggested for the latent space.
Fig. 11 The performance of CATA model with respect to different dimension values of the latent space under a sparse data (P = 1) and b dense data (P = 10)
4.4.5 RQ5
We pretrain our autoencoder first until the loss value of the data converges sufficiently. The loss value shows the error computed by the autoencoder’s loss function where it shows how well the model reconstructs outputs from inputs. Figure 12 visualizes the number of needed training epochs to render the loss value sufficiently stable. We find that 200 epochs are sufficient for pretraining our autoencoder.
Fig. 12 The reduction in the loss values versus the number of training epochs
5 Conclusion
In this chapter, we present a Collaborative Attentive Autoencoder (CATA) for recommending scientific articles. We utilize an article's textual data to learn a better compressed representation of the data through the attention mechanism, which can guide the training process to focus on the relevant part of the encoder output in order to improve model predictions. CATA shows superiority over other state-of-the-art methods on three scientific article datasets. The performance improvement of CATA increases consistently as data sparsity increases. The qualitative results also reflect the good quality of our model recommendations. For potential future work, user data can be gathered and then used to update the user latent factors in the same way as we update the item latent factors. Even though user data is often not available due to privacy concerns (e.g., CiteULike datasets do not have user data), we believe that item data, together with user-item interaction data, can be used to infer user information. In addition, other variants of deep autoencoders discussed in [18] could be investigated to replace the attentive autoencoder. Another possible direction for future work is to explore new metric learning algorithms to substitute the Matrix Factorization (MF) technique, because the dot product in MF does not guarantee that items are placed correctly in the latent space with respect to the triangle inequality between items.
References 1. M. Alfarhood, J. Cheng, Collaborative attentive autoencoder for scientific article recommendation. in 2019 18th IEEE International Conference on Machine Learning and Applications (ICMLA) (IEEE, 2019) 2. Y. Koren, R. Bell, C. Volinsky, Matrix factorization techniques for recommender systems. Computer 8, 30–37 (2009) 3. A. Mnih, R. Salakhutdinov, Probabilistic matrix factorization. in Advances in Neural Information Processing Systems (2008), pp. 1257–1264 4. R. Salakhutdinov, A. Mnih, G. Hinton, Restricted Boltzmann machines for collaborative filtering. in Proceedings of the 24th International Conference on Machine Learning (ACM, 2007), pp. 791–798 5. C. Wang, D. Blei, Collaborative topic modeling for recommending scientific articles. in Proceedings of the 17th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (ACM, 2011), pp. 448–456 6. H. Wang, B. Chen, W. Li, Collaborative topic regression with social regularization for tag recommendation. IJCAI (2013) 7. H. Wang, N. Wang, D. Yeung, collaborative deep learning for recommender systems. in Proceedings of the 21th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (ACM, 2015), pp. 1235–1244 8. X. Li, J. She, Collaborative variational autoencoder for recommender systems. in Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (ACM, 2017) 9. Y. Hu, Y. Koren, C. Volinsky, Collaborative filtering for implicit feedback datasets. in Proceedings of the 8th IEEE International Conference on Data Mining (ICDM) (IEEE, 2008)
10. S. Funk, Netflix update: try this at home (2006). https://sifter.org/simon/journal/20061211.html. Accessed 13 Nov 2019
11. Y. Zhou, D. Wilkinson, R. Schreiber, R. Pan, Large-scale parallel collaborative filtering for the Netflix prize. in Proceedings of the International Conference on Algorithmic Applications in Management (Springer, 2008)
12. R. Pan, Y. Zhou, B. Cao, N. Liu, R. Lukose, M. Scholz, Q. Yang, One-class collaborative filtering. in Eighth IEEE International Conference on Data Mining, 2008. ICDM’08 (IEEE, 2008)
13. G. Hinton, R. Salakhutdinov, Reducing the dimensionality of data with neural networks. Science 313(5786), 504–507 (2006)
14. S. Ioffe, C. Szegedy, Batch normalization: accelerating deep network training by reducing internal covariate shift (2015). arXiv preprint arXiv:1502.03167
15. N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, R. Salakhutdinov, Dropout: a simple way to prevent neural networks from overfitting. J. Mach. Learn. Res. (JMLR) 15(1), 1929–1958 (2014)
16. A. Alzogbi, Time-aware collaborative topic regression: towards higher relevance in textual item recommendation. BIRNDL@ SIGIR (2018)
17. C. Hsieh, L. Yang, Y. Cui, T. Lin, S. Belongie, D. Estrin, Collaborative metric learning. in Proceedings of the 26th International Conference on World Wide Web. International World Wide Web Conferences Steering Committee (2017)
18. M. Wani, F. Bhat, S. Afzal, A. Khan, Advances in Deep Learning (Springer, 2020)
A Comprehensive Set of Novel Residual Blocks for Deep Learning Architectures for Diagnosis of Retinal Diseases from Optical Coherence Tomography Images
Sharif Amit Kamran, Sourajit Saha, Ali Shihab Sabbir, and Alireza Tavakkoli
Abstract Spectral Domain Optical Coherence Tomography (SD-OCT) is a demanding imaging technique by which diagnosticians detect retinal diseases. Automating the procedure for early detection and diagnosis of retinal diseases has been proposed in many intricate ways through the use of image processing, machine learning, and deep learning algorithms. Unfortunately, the traditional methods are error-prone and quite expensive, as they require additional participation from human diagnosticians. In this chapter, we propose a comprehensive set of novel blocks for building a deep learning architecture to effectively differentiate between different pathologies causing retinal degeneration. We further show how integrating these novel blocks within a novel network architecture gives better classification accuracy for these diseases and addresses the preexisting problems with gradient explosion in deep residual architectures. The technique proposed in this chapter achieves better accuracy than the state of the art on two separately hosted retinal OCT image data-sets. Furthermore, we illustrate a real-time prediction system that, by exploiting this deep residual architecture comprising one of these novel blocks, outperforms expert ophthalmologists.
S. A. Kamran (B) · A. Tavakkoli: University of Nevada, Reno, NV, USA
S. Saha · A. S. Sabbir: Center for Cognitive Skill Enhancement, Independent University Bangladesh, Dhaka, Bangladesh
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021 M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_2
1 Introduction
Diabetes, one of the most critical health concerns, affects up to 7.2% of the world population, and this number can potentially rise to 600 million by the year 2040 [6, 38]. With the pervasiveness of diabetes, statistically one out of every three diabetic patients develops Diabetic Retinopathy (DR) [33]. Diabetic Retinopathy causes severe damage to human vision, which in turn leads to vision loss that affects nearly 2.8% of the world population [3]. In developed countries there exist efficacious vision tests for DR screening and early treatment, yet a by-product of such systems is erroneous results. Moreover, identifying the false positives and false negatives remains a challenging impediment for diagnosticians. On the other hand, DR is often mistreated in many developing and poorer economies, where access to trained ophthalmologists and eye-care machinery may be insufficient. Therefore, it is imperative to shift DR diagnostic technology to a system that is autonomous, in the pursuit of more accurate and faster test results on DR and other related retinal diseases, and inexpensive, so that more people have access to it. This research proposes a novel architecture based on convolutional neural networks that can identify Diabetic Retinopathy while categorizing multiple retinal diseases with near-perfect accuracy in real time. There are retinal diseases other than Diabetic Retinopathy, such as Macular Degeneration. The macula is a retinal sensor found in the central region of the retina in human eyes. The retinal lens perceives light emitted from outside sources and transforms it into neural signals, a process otherwise known as vision. The macula plays an integral role in human vision, from processing light via photo-receptor nerve cells to aggregating it into encoded neural signals sent directly to the brain through the optic nerves.
Retinal diseases such as Macular Degeneration, Diabetic Retinopathy, and Choroidal Neovascularization are the leading causes of eye disease and vision loss worldwide. In ophthalmology, a technique called Spectral Domain Optical Coherence Tomography (SD-OCT) is used for viewing the morphology of the retinal layers [32]. Furthermore, another way to treat these diseases is to use depth-resolved tissue formation data encoded in the magnitude and delay of the back-scattered light by spectral analysis [1]. While retrieval of the retinal image is performed by the computational process of SD-OCT, differential diagnosis is conducted by human ophthalmologists. Consequently, this leaves room for human error during differential diagnosis. Hence, an autonomous expert system would help ophthalmologists distinguish among different retinal diseases more precisely, with fewer mistakes, and in a more timely manner. One of the predominant causes of misclassification of retinal maladies is the stark similarity between Diabetic Retinopathy and other retinal diseases. They can be grouped into three major categories: (i) Diabetic Macular Edema (DME) and Age-related degeneration of retinal layers (AMD), (ii) Drusen, a condition where lipid or protein build-up occurs in the retinal layer, and (iii) Choroidal Neovascularization (CNV), a growth of new blood vessels in the sub-retinal space. The most common retinal
diseases worldwide are Diabetic Retinopathy and Age-related Macular Degeneration [9]. On the other hand, Drusen acts as an underlying cause that can trigger DR or AMD over a prolonged time-frame. Choroidal Neovascularization (CNV), however, is an advanced stage of age-related macular degeneration that affects about 200,000 people worldwide every year [8, 37]. Despite a decade of improvements to existing algorithms, identification of retinal diseases still produces erroneous results and requires expert intervention. To address this problem, we propose a novel deep neural network architecture that not only identifies retinal diseases in real time but also performs better than human experts on specific tasks. While our emphasis is on training a deep neural network model to minimize the classification error as much as possible, we also grapple with the challenges of over-fitting, gradient explosion, and gradient vanishing that are evident in many deep neural network models. In this work, we propose a newly designed residual Convolutional Neural Network (CNN) block that reduces the memory footprint while allowing us to train a much deeper CNN for better accuracy. Furthermore, we propose a novel building block in our CNN architecture that contains a signal attenuation mechanism and a newly written function to combine previous input layers before passing them on to the next one in the network. We then show how these proposed signal propagation techniques lead to a deeper, high-precision network without succumbing to weight degradation, gradient explosion, or over-fitting. In the following sections, we elaborate on our principal contributions and provide a comparative analysis of different approaches.
2 Literature Review
In this section, we discuss a series of computational approaches that have historically been adopted to diagnose SD-OCT imagery. These approaches date back to the earlier developments in image processing as well as the segmentation algorithms of the pre-deep-learning era of computer vision. While those developments are important and made tremendous strides in their specific domains, we also discuss the learning-based approaches and deep learning models that were created and trained, both with and without transfer learning, by other researchers in the pursuit of classifying retinal diseases from SD-OCT images. We further explain the pros and cons of the other deep learning classification models on SD-OCT images and the developments and contributions we achieved with our proposed deep learning architecture.
2.1 Traditional Image Analysis
The earliest approaches to classifying retinal diseases from images comprise multiple image processing techniques followed by feature extraction and classification [24]. In one such study, retinal diseases are classified by finding abnormalities such as microaneurysms, haemorrhages, exudates, and cotton-wool spots in retinal fundus images [7]. This approach exploits a noise reduction algorithm and blurring to split the four-class problem into two instances of a two-class problem. The authors first perform background subtraction followed by shape estimation as a feature extractor, and then use the extracted features to classify each of the four abnormalities. In parallel, similar techniques with engineered features were adopted to detect Diabetic Macular Edema (DME) and Choroidal Neovascularization (CNV). The images were analyzed on five discrete parameters: retinal thickness, augmentation of retinal thickening, macular volume, retinal morphology, and vitreoretinal relationship [25]. Another effective method combined statistical classification with edge detection algorithms to detect sharp edges [28]. The algorithm of Sanchez et al. [28] achieved a sensitivity of 79.6% when classifying Diabetic Retinopathy. The approach of Ege et al. [7], which incorporates a Mahalanobis classifier, detected microaneurysms, haemorrhages, exudates, and cotton-wool spots with sensitivities of 69, 83, 99, and 80%, respectively. Each of these techniques shows promising improvements over the others; however, they are not on par with human diagnosticians in terms of precision. Higher detection accuracy is therefore still required for these systems to be of assistance to human diagnosticians and ophthalmologists.
2.2 Segmentation-Based Approaches
One of the most pronounced indicators for identifying a patient with Diabetic Macular Edema is enlarged macular density in the retinal layer [1, 5]. Naturally, several systems have been proposed and implemented that comprise retinal layer segmentation. Because the segmentation algorithms reveal evidence of liquid building up in the sub-retinal space, further identification of the factors that engender specific diseases is made possible [17, 22, 23]. In [20, 26], the authors proposed segmenting the intra-retinal layers into ten parts and then extracting texture and depth information from each layer. Subsequently, any aberrant retinal features are detected by classifying the dissimilarity between healthy and diseased retinas. Moreover, Niemeijer et al. [22] introduced a technique for 3D segmentation of fluid-containing regions in OCT images using a graph-based implementation. A graph-cut algorithm is applied to obtain the final predictions from the information initially retrieved from layer-based segmentation of fluid regions. Even though implementations based on a prior segmentation of retinal layers have registered high prediction scores, the initial step is reportedly laborious, time-consuming, and error-prone [10, 13]. As reported in [19], retinal thickness measurements obtained by different systems are starkly dissimilar. Therefore, it is neither efficient nor optimal to compare retinal depth information retrieved by separate machines, despite the improved prediction accuracy over the feature engineering methods with traditional
image analysis discussed earlier. These observations further reinforce the fact that segmentation-based approaches are not effective as a universal retinal disease recognition system.
2.3 Machine Learning and Deep Learning Techniques
Recently, a combination of optimization algorithms from machine learning and a series of deep neural network architectures has become a popular choice among researchers and engineers for achieving state-of-the-art accuracy in recognizing various retinal diseases [18, 21, 34]. Various deep learning architectures, with the necessary modifications described in [36], can be employed to classify retinal diseases from SD-OCT images as well. Awais et al. proposed a combination of VGG16 [31] with KNN and a Random Forest classifier with 100 trees to create a deep classification model that differentiates between normal retina and Diabetic Macular Edema. Lee et al. trained a standalone VGG16 architecture with a binary output to classify Age-related Macular Degeneration (AMD) and cases with no retinal degeneration [18]. While these systems exploit learning-based feature extraction from large-scale image data-sets, the neural network models are computationally inefficient and suffer from a large memory footprint. Transfer learning methods, in turn, depend on weeks of training on millions of images and are not ideal for finding stark differences between retinal diseases. To alleviate all of these challenges, an architecture is needed that is specifically tailored to identifying retinal diseases with high precision, high speed, and low memory usage.
2.4 Our Contributions
In this work, we propose a novel convolutional neural network that specializes in identifying retinal diseases with near-perfect precision. Moreover, through this architecture we propose (a) a new residual unit subsuming Atrous Separable Convolution, (b) a novel building block, and (c) a mechanism to prevent gradient degradation. The proposed network outperforms other architectures with respect to the number of parameters, accuracy, and memory size. Our proposed architecture is trained from scratch and benchmarked on two publicly available data-sets: OCT2017 [16] and Srinivasan2014 [32]. Hence, it does not require any pre-trained weights, which reduces the training and deployment time of the model many-fold. We believe that with the deployment of this model, rapid identification and treatment can be carried out with near-perfect certainty. Additionally, it will help ophthalmologists obtain a second expert opinion in the course of differential diagnosis. This work is an extension of our previous work [14], where we experiment with different novel residual convolutional block architectures and achieve state-of-the-art performance on both the OCT2017 [16] data-set and the Srinivasan2014 [32] data-set. In
this chapter, we illustrate an exploratory analysis of laterally distinguishable novel residual blocks and discuss the methodologies and observations that help us select the optimal block for our CNN architecture. We further show our results on fivefold training, demonstrate the efficacy of our model, and discuss how our architectural design prevents the model from over-fitting the data at hand. Along with that, we illustrate the differences between training different variants of our proposed model and show our deployment pipeline.
3 Proposed Methodology
In this section, we discuss the methodologies and observations adopted in designing the proposed CNN architecture. We first elaborate on how we train both data-sets on different residual units, each with a unique lateral propagation architecture. We then discuss how we select our proposed residual unit based on the observations from training the other variants. Next, we illustrate how we join our proposed residual unit with a signal attenuation mechanism and a newly written signal propagation function to prevent gradient degradation. We then present our proposed CNN architecture and the efficacy and novelty of the model. In Fig. 1 we show the different variants of the residual unit, their attributes, and how we arrive at our proposed variant of the residual block. Figure 2 illustrates the Deep Convolutional Neural Network (CNN) architecture we propose for the classification of retinal diseases from Optical Coherence Tomography (OCT) images. In Fig. 2a we delineate how the proposed Residual Learning Unit improves feature learning capabilities while discussing the techniques we adopt to reduce computational complexity
Fig. 1 Different variants of residual unit and our proposed residual unit with a novel lateral propagation of neurons
Fig. 2 Illustration of the building blocks of our proposed CNN [OpticNet-71]. Only the very first convolution (7 × 7) layer in the CNN and the very last convolution (1 × 1) layer in the Residual Convolutional Unit of Stages 2, 3, and 4 use stride 2, while all other convolution operations use stride 1
for such performance enhancement. Fig. 2b depicts the proposed mechanism to handle gradient degradation, while Fig. 2c shows the entire CNN architecture. We discuss the constituent segments of our CNN architecture, Optic-Net, and the philosophy behind our lateral design schemes in the following subsections.
3.1 Lateral and Operation-Wise Variants of Residual Unit
To design a model with higher efficacy, a lower memory footprint, and low computational complexity, we experimented with different types of convolution operations and various lateral choices in the residual unit of our proposed neural network.
3.1.1 Vanilla Residual Unit
Figure 1a represents the propagation of neurons in a vanilla residual unit with pre-activation [11, 12]. While such a design achieves high prediction accuracy, the first row of Table 1 reports its higher memory and computational footprint. The natural progression was to include a kernel with a wider receptive field, hence the inclusion of dilation in the kernel operation, as shown by Yu et al. [39]. Sect. 3.1.2 describes the mechanism of this block in detail.
Table 1 Comparison between different convolution operations used in the middle portion of the residual unit
Type of convolution used in the middle of residual unit | Approximate number of parameters (a) | Depletion factor for parameters, p (%) | Accuracy (b) (%)
Regular convolution | $f^2 \times D^{[i]} \times D^{[i-1]} = 36{,}864$ | $100$ | 99.30
Atrous convolution | $(f-1)^2 \times D^{[i]} \times D^{[i-1]} = 16{,}384$ | $\left(1 - \frac{1}{f}\right)^2 = 44.9$ | 97.20
Separable convolution | $(f^2 + D^{[i]}) \times D^{[i-1]} = 4{,}672$ | $\frac{1}{f^2} + \frac{1}{D^{[i]}} = 12.5$ | 98.10
Atrous separable convolution | $((f-1)^2 + D^{[i]}) \times D^{[i-1]} = 4{,}352$ | $\frac{1}{f^2} + \left(1 - \frac{1}{f}\right)^2 \times \frac{1}{D^{[i]}} = 11.6$ | 96.70
Atrous convolution and atrous separable convolution branched | $\frac{1}{2}\left((f-1)^2 \left(1 + \frac{1}{2} D^{[i]}\right) + \frac{1}{2} D^{[i]}\right) \times D^{[i-1]} = 5{,}248$ | $\frac{1}{(2f)^2} + \left(1 - \frac{1}{f}\right)^2 \times \left(\frac{1}{4} + \frac{1}{2 D^{[i]}}\right) = 14.4$ | 99.80
(a) Here, kernel size $(f, f) = (3, 3)$. Depth (number of kernels) in the residual unit's middle operation, $D^{[i]} = 64$, and in its first operation, $D^{[i-1]} = 64$.
(b) The test accuracy reported in the table is obtained by training on the OCT2017 [16] data-set, while the backbone network is Optic-Net 71.
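The parameter counts in Table 1 follow from the footnote values f = 3 and D[i] = D[i−1] = 64. The short Python sketch below recomputes them from the expressions in the table. The function name `middle_conv_params` and the dictionary keys are illustrative and are not taken from the authors' code; biases are ignored, as in the table. Ratios computed directly from these counts may differ slightly from the percentages printed in the depletion-factor column.

```python
def middle_conv_params(f=3, d_mid=64, d_in=64):
    """Approximate parameter counts for the middle operation of a residual unit,
    using the expressions listed in Table 1 (biases ignored)."""
    half = d_mid // 2  # each branch of the proposed unit uses half the channels
    return {
        "regular": f ** 2 * d_mid * d_in,
        "atrous": (f - 1) ** 2 * d_mid * d_in,
        "separable": (f ** 2 + d_mid) * d_in,
        "atrous_separable": ((f - 1) ** 2 + d_mid) * d_in,
        # Expression from the last row of Table 1 for the branched unit.
        "branched": ((f - 1) ** 2 * (1 + half) + half) * d_in // 2,
    }


counts = middle_conv_params()
for name, n in counts.items():
    # Depletion factor: parameters relative to the regular convolution, in percent.
    ratio = 100.0 * n / counts["regular"]
    print(f"{name:>17}: {n:6d} parameters, {ratio:5.1f}% of regular")
```

Changing `f`, `d_mid`, or `d_in` shows how the savings from the separable variants grow with channel depth, since their cost is dominated by the point-wise term rather than the spatial kernel.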
3.1.2 Atrous Residual Unit
To reduce the number of parameters, we then replace the middle 3 × 3 convolution block in the residual unit with a 2 × 2 atrous convolution with dilation rate two, as detailed in Fig. 1b. Figure 3 further illustrates how atrous convolution with skipped feature extraction captures minor details in the signal while reducing the number of parameters by a reasonable margin, as detailed in the second row of Table 1. On the other hand, the atrous residual unit registers rather poor performance on our data-set. What is more, because no depth information is incorporated during the convolution operation, the spatial information overflows throughout the architecture, resulting in more error-prone results for borderline diagnoses. To address this problem, depth-wise (separable) convolution is used in Sect. 3.1.3.
3.1.3 Separable Residual Unit
Concurrently, we redesigned the residual unit with a separable convolution block in the middle feature extraction module, as outlined in Fig. 1c. With a depth-wise convolution followed by a point-wise operation, we achieved much lower computational cost and relatively better inference accuracy than with the atrous residual unit, as reported in the third row of Table 1. The reason is that small depth information dominates throughout the architecture. We tried to address this problem by incorporating larger receptive fields using dilation in the depth-wise convolution layer inside the separable residual unit, which is discussed in detail in Sect. 3.1.4.
3.1.4 Atrous Separable Residual Unit
To further reduce the computational load, we then take the separable residual unit and replace the depth-wise 3 × 3 convolution block with a 2 × 2 atrous convolution block with dilation rate two, as shown in Fig. 1d. Figure 3 further demonstrates the mechanism of the atrous separable convolution operation on signals. With this design choice, we cut the parameter count by 87.4%, which is the lowest computational complexity in our experiment (fourth row of Table 1). However, the repetitive use of such a residual unit in our proposed neural network leads to the lowest prediction accuracy on our data-set. Furthermore, using only separable convolution results in depth-wise feature extraction in most of the identity blocks, while not encompassing any spatial information in any of those learnable layers. This explains why its accuracy is the lowest of all the novel residual units. From these observations on the trade-offs between computational complexity and performance, we arrive at our proposed residual unit (Fig. 1e), which we discuss next in Sect. 3.2.
Fig. 3 Atrous separable convolution. Exploiting ( f − 1) × ( f − 1) convolutions instead of f × f yields more fine-grained and coarse features with better depth resolution compared to regular atrous convolution
3.2 Proposed Residual Unit and Learning Mechanism
Historically, the Residual Units [11, 12] used in Deep Residual Convolutional Neural Networks (CNN) process the incoming input through three convolution operations and add the incoming input to the processed output. These three convolution operations are (1 × 1), (3 × 3), and (1 × 1) convolutions. Therefore, replacing the (3 × 3) convolution in the middle with other types of convolution operations can potentially change the learning behavior, the computational complexity, and eventually the prediction performance, as demonstrated in Fig. 1. We experimented with different convolution operations as replacements for the (3 × 3) middle convolution and observed which choice contributes most to reducing the number of parameters, and hence the computational complexity, as shown in Table 1. Furthermore, in Table 1 we use a depletion factor for parameters, p, which is the ratio of the number of parameters in the replacement convolution to that in the regular convolution, expressed in percent. In our proposed residual unit, we replace the middle (3 × 3) convolution operation with two different operations running in parallel, as detailed in Fig. 2a. Whereas a conventional residual unit uses $D^{[i]}$ channels for the middle convolution, we use $\frac{1}{2} D^{[i]}$ channels for each of the two replacement operations to prevent any surge in the number of parameters. In the proposed branching operation, we use a (2 × 2) Atrous convolution ($C_2$) with dilation rate r = 2 to get a (3 × 3) receptive field in the left branch, while in the right branch we use a (2 × 2) Atrous separable convolution ($C_3$) with dilation rate r = 2 to get a (3 × 3) receptive field. The results are then added together. Furthermore, separable convolution [30] disentangles the spatial and depth-wise feature maps separately, while Atrous convolution inspects both spatial and depth channels together.
We hypothesize that adding two such feature maps, which are learned very differently, helps trigger more robust and subtle features.
$$X_{l+1} = X_l + \big((X_l \ast C_1 \ast C_2) + (X_l \ast C_1 \ast C_3)\big) \ast C_4 = X_l + \hat{F}(X_l, W_l) \tag{1}$$
Figure 3 shows how adding Atrous and Atrous separable feature maps helps disentangle the input image space with more depth information instead of activating only the predominant edges. Moreover, the last row of Table 1 confirms that adopting this strategy still reduces the computational complexity by a reasonable margin, while improving inference accuracy. Equation (1) further clarifies how the input signal $X_l$ travels through the proposed residual unit shown in Fig. 2a, where $\ast$ refers to the convolution operation.
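To make the receptive-field claim concrete: a k × k kernel with dilation rate r spans k + (k − 1)(r − 1) input positions per axis, so a 2 × 2 kernel with r = 2 covers a 3 × 3 window while holding only four weights. The pure-Python sketch below is illustrative only. The helper names are hypothetical, it handles a single channel with no padding and stride 1, and the authors' implementation relies on a deep learning framework rather than code like this.

```python
def effective_kernel_size(k, rate):
    """Number of input positions spanned per axis by a k x k kernel with dilation `rate`."""
    return k + (k - 1) * (rate - 1)


def atrous_conv2d(image, kernel, rate):
    """Single-channel dilated cross-correlation, no padding, stride 1.

    `image` and `kernel` are lists of lists; kernel taps are spaced `rate` apart.
    """
    k = len(kernel)
    span = effective_kernel_size(k, rate)
    height, width = len(image), len(image[0])
    output = []
    for y in range(height - span + 1):
        row = []
        for x in range(width - span + 1):
            acc = 0.0
            for i in range(k):
                for j in range(k):
                    acc += kernel[i][j] * image[y + i * rate][x + j * rate]
            row.append(acc)
        output.append(row)
    return output


image = [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]
kernel = [[1, 1],
          [1, 1]]
# With rate 2 the four taps land on the corners of the 3 x 3 window.
result = atrous_conv2d(image, kernel, rate=2)
print(effective_kernel_size(2, 2), result)
```

In the proposed unit, one branch would apply such a kernel across all input channels, and the other would apply it per channel followed by a point-wise (1 × 1) convolution, before the two outputs are added.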
3.3 Proposed Building Block and Signal Propagation
In this section we discuss the proposed building block as a constituent part of Optic-Net. As shown in Fig. 2b, we split the input signal ($X_l$) into two branches: (1) Stack of Residual Units and (2) Signal Exhaustion. Later in this section, we explain how we connect these two branches to propagate signals further in the network.
$$\alpha(X_l) = X_{l+N} = X_l + \sum_{i=l}^{N} \hat{F}(X_i, W_i) \tag{2}$$
3.3.1 Stack of Residual Units
In order to initiate a novel learning chain that propagates signals through a linear stack of several of the proposed residual units, we suggest combining the global residual effects enhanced by pre-activation residual units [12] with our proposed set of convolution operations (Fig. 2a). As shown in (1), $\hat{F}(X_l, W_l)$ denotes the full set of proposed convolution operations inside a residual unit for input $X_l$. We sequentially stack these residual units N times over $X_l$, which is the input to our proposed building block, as shown in Fig. 2b. Equation (2) gives the state of the output signal, denoted by $X_{l+N}$, after it has been processed through a stack of N residual units. For the sake of further demonstration we denote $X_{l+N}$ as $\alpha(X_l)$.
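The unrolled form in (2) can be checked with a toy scalar example. In the sketch below, `residual_fn` is a hypothetical stand-in for the residual function (the real unit is a set of convolutions acting on tensors), and the weights are arbitrary illustrative values. The point is only that stacking identity-plus-residual units yields the block input plus the sum of all residual outputs.

```python
import math


def residual_fn(x, w):
    # Hypothetical stand-in for F_hat(X_i, W_i); not the authors' operation.
    return w * math.tanh(x)


def stack_of_residual_units(x, weights):
    """Apply x_{i+1} = x_i + F_hat(x_i, w_i) once per weight and record each residual."""
    residuals = []
    for w in weights:
        r = residual_fn(x, w)
        residuals.append(r)
        x = x + r
    return x, residuals


x_l = 0.5
weights = [0.1, -0.2, 0.3]  # N = 3 residual units
x_out, residuals = stack_of_residual_units(x_l, weights)

# alpha(X_l) equals the block input plus the accumulated residuals, as in (2).
print(x_out, x_l + sum(residuals))
```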
3.3.2 Signal Exhaustion
In the proposed building block, we propagate the input signal $X_l$ through a max-pooling layer to achieve spatial down-sampling, which we then up-sample through bi-linear interpolation. Since the down-sampling module only forwards the strongest activations, the interpolated reconstruction makes a dense spatial volume from the down-sampled representation, intrinsically exhausting the incoming signal $X_l$. As detailed in Fig. 2b, we then pass the exhausted signal space through a sigmoid activation, $\sigma(X_l) = 1/(1 + e^{-X_l})$. Recent research [35] has shown how auto-encoding with residual skip connections, $[P_{encoder}(\mathrm{input} \mid \mathrm{code}) \rightarrow P_{decoder}(\mathrm{code} \mid \mathrm{input}) + \mathrm{input}]$, improves attention-oriented classification performance. However, unlike auto-encoders, the max-pooling and bi-linear interpolation functions have no learning mechanism. In Optic-Net, we enable the CNN to activate spikes from an exhausted signal space because we use it as a mechanism to avert gradient degradation. For the sake of further demonstration we denote the exhausted signal activation module, $\sigma(X_l)$, as $\beta(X_l)$.
$$\tau(X_l) = \alpha(X_l) + \beta(X_l) + \alpha(X_l) \times \beta(X_l) \tag{3}$$
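The signal exhaustion branch described above (max-pooling, bi-linear up-sampling, sigmoid) can be sketched for a single-channel map in pure Python. This is a rough illustration under assumptions the text does not state: a 2 × 2 pooling window with stride 2 and an align-corners convention for the interpolation. All function names are hypothetical; frameworks differ in their interpolation conventions.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def max_pool_2x2(x):
    """2 x 2 max-pooling with stride 2 (window size is an assumption)."""
    h, w = len(x), len(x[0])
    return [[max(x[i][j], x[i][j + 1], x[i + 1][j], x[i + 1][j + 1])
             for j in range(0, w - 1, 2)]
            for i in range(0, h - 1, 2)]


def bilinear_resize(x, out_h, out_w):
    """Bi-linear interpolation with corner pixels aligned (an assumed convention)."""
    in_h, in_w = len(x), len(x[0])
    out = []
    for oy in range(out_h):
        sy = oy * (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
        y0 = int(math.floor(sy))
        y1 = min(y0 + 1, in_h - 1)
        ty = sy - y0
        row = []
        for ox in range(out_w):
            sx = ox * (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
            x0 = int(math.floor(sx))
            x1 = min(x0 + 1, in_w - 1)
            tx = sx - x0
            top = x[y0][x0] * (1 - tx) + x[y0][x1] * tx
            bottom = x[y1][x0] * (1 - tx) + x[y1][x1] * tx
            row.append(top * (1 - ty) + bottom * ty)
        out.append(row)
    return out


def exhaust_signal(x):
    """beta(X_l): keep the strongest activations, densify them, squash to (0, 1)."""
    pooled = max_pool_2x2(x)
    dense = bilinear_resize(pooled, len(x), len(x[0]))
    return [[sigmoid(v) for v in row] for row in dense]


signal = [[1.0, 0.0, 0.0, 2.0],
          [0.0, 0.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 0.0],
          [3.0, 0.0, 0.0, 4.0]]
beta = exhaust_signal(signal)
for row in beta:
    print([round(v, 3) for v in row])
```

None of the three steps has trainable weights, which matches the remark that this branch, unlike an auto-encoder, has no learning mechanism of its own.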
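$$\begin{aligned}
\frac{\partial \tau(X_l)}{\partial X_l} &= \big(1 + \sigma(X_l)\big) \times \Big(1 + \frac{\partial}{\partial X_l} \sum_{i=l}^{N} \hat{F}(X_i, W_i)\Big) + \Big(1 + X_l + \sum_{i=l}^{N} \hat{F}(X_i, W_i)\Big) \times \sigma(X_l) \times \big(1 - \sigma(X_l)\big) \\
&= \big(1 + \sigma(X_l)\big) \times \Big(1 + \frac{\partial}{\partial X_l} \sum_{i=l}^{N} \hat{F}(X_i, W_i)\Big) + \big(1 + X_{l+N}\big) \times \sigma'(X_l)
\end{aligned} \tag{4}$$
3.3.3 Signal Propagation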
As shown in Fig. 2b, we combine the residual signal, $\alpha(X_l)$, and the exhausted signal, $\beta(X_l)$, following (3), and we denote the output signal propagated from the proposed building block as $\tau(X_l)$. Our hypothesis behind this design is that whenever one of the branches falls prey to gradient degradation from a mini-batch, the other branch manages to propagate signals unaffected by the mini-batch with an amplified absolute gradient. In support of this hypothesis, (3) shows that $\tau(X_l) \approx \alpha(X_l)$ whenever $\beta(X_l) \approx 0$ and $\tau(X_l) \approx \beta(X_l)$ whenever $\alpha(X_l) \approx 0$, illustrating how the unaffected branch survives the degradation in the affected branch. However, when neither branch is affected by gradient amplification, the multiplication $(\alpha(X_l) \times \beta(X_l))$ balances out the increase in signal propagation due to the addition of both branches. Equation (4) gives the gradient of the building block output $\tau(X_l)$ with respect to the building block input $X_l$, calculated during back-propagation for optimization.
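The gradient in (4) can be sanity-checked numerically on a scalar toy model. In the sketch below, `residual_sum` is a hypothetical stand-in for the accumulated residual term of the stack, and β is taken as the sigmoid of the input, as in the derivation. The analytic expression is compared against a central finite difference of τ from (3).

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def residual_sum(x):
    # Hypothetical stand-in for the summed residual functions of the stack.
    return 0.5 * math.tanh(x)


def residual_sum_grad(x):
    return 0.5 * (1.0 - math.tanh(x) ** 2)


def alpha(x):
    return x + residual_sum(x)


def beta(x):
    return sigmoid(x)


def tau(x):
    a, b = alpha(x), beta(x)
    return a + b + a * b  # Eq. (3)


def tau_grad(x):
    # Eq. (4): (1 + sigma) * (1 + dF/dx) + (1 + alpha) * sigma * (1 - sigma)
    s = sigmoid(x)
    return (1.0 + s) * (1.0 + residual_sum_grad(x)) + (1.0 + alpha(x)) * s * (1.0 - s)


x, eps = 0.7, 1e-6
numeric = (tau(x + eps) - tau(x - eps)) / (2.0 * eps)
analytic = tau_grad(x)
print(numeric, analytic)
```

The two printed values should agree to several decimal places, which is consistent with (4) being the product-rule expansion of (3).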
3.4 CNN Architecture and the Optimization Chain
Figure 2c shows the entire CNN architecture with all the building blocks and constituent components joined together. First, the input batch (224 × 224 × 3) is propagated through a 7 × 7 convolution with stride 2, followed by batch normalization and ReLU activation. Then we propagate the signals through a Residual Convolution Unit (the same as the unit used in [12]), which is followed by our proposed building block. We repeat this [Residual Convolution Unit → Building Block] procedure S = 4 times, and we call the repetitions stage 1, 2, 3, and 4, respectively. Global average pooling is then applied to the signals, which pass through two more Fully Connected (FC) layers to the loss function, denoted by ξ. In Table 2, we show the number of feature maps (layer depth) we use for each layer in the network. The output shapes of the input tensor after the four consecutive stages are (112 × 112 × 256), (56 × 56 × 512), (28 × 28 × 1024), and (14 × 14 × 2048),
Table 2 Architectural specifications for OpticNet-71 and layer-wise analysis of the number of feature maps in comparison with ResNet50-v1 [11]
Layer name | ResNet50 V1 [11] | OpticNet71 [Ours]
Conv 7 × 7 | [64] × 1 | [64] × 1
Stage1: Res Conv | [64, 64, 256] × 1 | [64, 64, 256] × 1
Stage1: Res Unit | [64, 64, 256] × 2 | [32, 32, 32, 256] × 4
Stage2: Res Conv | [128, 128, 512] × 1 | [128, 128, 512] × 1
Stage2: Res Unit | [128, 128, 512] × 3 | [64, 64, 64, 512] × 4
Stage3: Res Conv | [256, 256, 1024] × 1 | [256, 256, 1024] × 1
Stage3: Res Unit | [256, 256, 1024] × 5 | [128, 128, 128, 1024] × 3
Stage4: Res Conv | [512, 512, 2048] × 1 | [512, 512, 2048] × 1
Stage4: Res Unit | [512, 512, 2048] × 2 | [256, 256, 256, 2048] × 3
Global Avg Pool | 2048 | 2048
Dense layer 1 | K (Classes) | 256
Dense layer 2 | – | K (Classes)
Parameters | 25.64 Million | 12.50 Million
Required FLOPs | 3.8 × 10^9 | 2.5 × 10^7
CNN Memory | 98.20 MB | 48.80 MB
respectively. Moreover, $FC_1 \in \mathbb{R}^{2048 \times 256}$ and $FC_2 \in \mathbb{R}^{256 \times K}$, where K is the number of classes.
$$\frac{\partial \xi}{\partial X_l} = \frac{\partial \xi}{\partial \tau(X_l^{Stage})} \times \prod_{j=1}^{Stage} \frac{\partial \tau(X_l^{j})}{\partial X_l^{j}} \tag{5}$$
Equation (5) represents the gradient calculated for the entire network chain, distributed over the stages of optimization. As (4) suggests, the term $(1 + \sigma(X_l))$, in comparison with [12], works as an extra layer of protection against possible gradient explosion caused by the stacked residual units, by multiplying nonzero activations with the residual unit's gradients. Moreover, the term $(1 + X_{l+N})$ indicates that the optimization chain still has access to signals from much earlier in the network, and the term $\sigma'(X_l)$ can still mitigate gradient expansion from unwanted spikes in activations, which could otherwise jeopardize learning.
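The stage-wise output shapes quoted in this section can be reproduced from the stride pattern given in the caption of Fig. 2: the first 7 × 7 convolution halves the spatial size, stage 1 keeps it, and the Residual Convolution Unit of stages 2 to 4 halves it again. The sketch below is a bookkeeping aid only. The function name and arguments are illustrative, and the value K = 4 corresponds to the four OCT2017 classes.

```python
def opticnet_stage_shapes(input_size=224, first_stage_depth=256, stages=4):
    """Spatial size and depth after each stage, following the stride pattern of Fig. 2."""
    size = input_size // 2  # 7 x 7 convolution with stride 2
    shapes = []
    for stage in range(1, stages + 1):
        if stage > 1:
            size //= 2  # stride-2 1 x 1 convolution in the Res Conv unit of stages 2-4
        depth = first_stage_depth * 2 ** (stage - 1)
        shapes.append((size, size, depth))
    return shapes


shapes = opticnet_stage_shapes()
num_classes = 4  # K for OCT2017
fc_shapes = [(shapes[-1][2], 256), (256, num_classes)]  # FC1 and FC2 weight shapes

print(shapes)
print(fc_shapes)
```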
4 Experiments
This section describes how the architectures were trained, validated, and tested with different settings and hyper-parameters. It also gives a detailed analysis of how the architectures compare with previous techniques and with expert human diagnosticians. The comparison is drawn in terms of speed, accuracy, sensitivity, specificity, memory usage, and penalty-weighted metrics. All files related to the experimentation and training can be found in the following code repository: https://github.com/SharifAmit/OCT_Classification.
4.1 Specifications of Data-Sets and Preprocessing Techniques
We benchmark our model on two distinct data-sets (different scale, sample space, etc.). The first, OCT2017 [16], requires recognizing and differentiating between four distinct retinal states: normal healthy retina, Drusen, Choroidal Neovascularization (CNV), and Diabetic Macular Edema (DME). The OCT2017 [16] data-set contains 84,484 images (provided in high-quality TIFF format with three non-RGB color channels), which we split into a train-set of 83,484 images and a test-set of 1,000 images. The second data-set, Srinivasan2014 [32], consists of three classes: normal healthy retina, Age-Related Macular Degeneration (AMD), and Diabetic Macular Edema (DME). It contains 3,231 image samples, which we split into a train-set of 2,916 images and a test-set of 315 images. We resize images from both data-sets to 224 × 224 × 3 for both training and testing. For both data-sets, we perform fivefold cross-validation on the training set and select the best models.
4.2 Performance Metrics
We calculate four standard metrics to evaluate our CNN model on both data-sets: Accuracy (6), Sensitivity (7), Specificity (8), and a special Weighted Error (9) from [16]. Here, N is the number of image samples, K is the number of classes, and TP, FP, FN, and TN denote True Positive, False Positive, False Negative, and True Negative, respectively. We report the True Positive Rate (TPR), or Sensitivity (7), and the True Negative Rate (TNR), or Specificity (8), for both data-sets [16, 32]. To do so, we calculate the TPR and TNR for each individual class, sum the values, and divide by the number of classes (K).

$$\text{Accuracy} = \frac{1}{N} \sum TP \qquad (6)$$

$$\text{Sensitivity} = \frac{1}{K} \sum \frac{TP}{TP + FN} \qquad (7)$$

$$\text{Specificity} = \frac{1}{K} \sum \frac{TN}{TN + FP} \qquad (8)$$

$$\text{Weighted Error} = \frac{1}{N} \sum_{i,j \in K} W_{ij} \cdot X_{ij} \qquad (9)$$
Table 3 Penalty weights proposed for OCT2017 [16]
| | Normal | Drusen | CNV¹ | DME² |
|---|---|---|---|---|
| Normal | 0 | 1 | 1 | 1 |
| Drusen | 1 | 0 | 1 | 1 |
| CNV¹ | 4 | 2 | 0 | 1 |
| DME² | 4 | 2 | 1 | 0 |
¹ CNV: Choroidal Neovascularization
² DME: Diabetic Macular Edema
Fig. 4 Confusion matrix generated by OpticNet-71 for OCT2017[16] data-set
As reported in [16], the penalty points for incorrect categorization of a retinal disease can be arbitrary. Table 3 shows the penalty weights for misidentifying a category, as set by [16]; they are specific to the OCT2017 [16] data-set. To calculate the Weighted Error (9), we apply element-wise multiplication between the confusion matrix generated by a specific model (Fig. 4 shows the confusion matrix generated by OpticNet-71 on the OCT2017 [16] data-set) and the weight matrix in Table 3, and then take an average over the number of samples. Here, the penalty weights from Table 3 are denoted by W and the model's predictions (confusion matrix) are denoted by X, where i and j index the rows and columns of the confusion matrix.
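The four metrics can be computed directly from a confusion matrix, as in the minimal sketch below. It is not the authors' implementation. Two assumptions are made for illustration: rows index the true class and columns the predicted class, and the 1,000 test images are spread evenly at 250 per class. The two misclassifications used here are the ones described for OpticNet-71 in Sect. 4.3.1 (one Drusen and one DME sample predicted as CNV).

```python
# Class order assumed throughout: Normal, Drusen, CNV, DME.
PENALTY_WEIGHTS = [
    [0, 1, 1, 1],
    [1, 0, 1, 1],
    [4, 2, 0, 1],
    [4, 2, 1, 0],
]

# Assumed confusion matrix: 250 samples per class (an assumption),
# with one Drusen and one DME sample predicted as CNV.
confusion = [
    [250, 0, 0, 0],
    [0, 249, 1, 0],
    [0, 0, 250, 0],
    [0, 0, 1, 249],
]


def evaluate(matrix, weights):
    """Return accuracy, sensitivity, specificity, and weighted error.

    matrix[i][j] counts samples of true class i predicted as class j.
    Sensitivity and specificity are per-class values averaged over K classes.
    """
    k = len(matrix)
    n = sum(sum(row) for row in matrix)
    accuracy = sum(matrix[c][c] for c in range(k)) / n

    sensitivities, specificities = [], []
    for c in range(k):
        tp = matrix[c][c]
        fn = sum(matrix[c]) - tp
        fp = sum(matrix[r][c] for r in range(k)) - tp
        tn = n - tp - fn - fp
        sensitivities.append(tp / (tp + fn))
        specificities.append(tn / (tn + fp))

    weighted_error = sum(
        weights[i][j] * matrix[i][j] for i in range(k) for j in range(k)
    ) / n
    return {
        "accuracy": accuracy,
        "sensitivity": sum(sensitivities) / k,
        "specificity": sum(specificities) / k,
        "weighted_error": weighted_error,
    }


metrics = evaluate(confusion, PENALTY_WEIGHTS)
for name, value in metrics.items():
    print(f"{name}: {100 * value:.2f}%")
```

Under these assumptions the sketch yields 99.80% accuracy, 99.80% sensitivity, 99.93% specificity, and a 0.20% weighted error, which agrees with the OpticNet-71 row of Table 4.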
4.3 Training OpticNet-71 and Obtained Results
4.3.1 OCT2017 Data-Set
In Table 4, we report a comprehensive study on the OCT2017 [16] data-set, evaluated through Test Accuracy, Sensitivity, Specificity, and Weighted Error. OpticNet-71 scores the highest Test Accuracy (99.80%) among the existing solutions, with a Sensitivity of 99.80% and a Specificity of 99.93%. Furthermore, the Weighted Error is a mere 0.20%, which can be traced in Fig. 4: our architecture misidentifies one Drusen sample and one DME sample as CNV, and the penalty weight for each of these misclassifications is only 1, as reported in Table 3. Consequently, our proposed OpticNet-71 obtains state-of-the-art results on the OCT2017 [16] data-set across all four performance metrics while surpassing the human benchmarks listed in Table 4.

Table 4 Results on OCT2017 [16] data-set
| Architectures | Test accuracy (%) | Sensitivity (%) | Specificity (%) | Weighted error (%) |
|---|---|---|---|---|
| InceptionV3 (limited) | 93.40 | 96.60 | 94.00 | 12.70 |
| Human expert 2 [16] | 92.10 | 99.39 | 94.03 | 10.50 |
| InceptionV3 [16] | 96.60 | 97.80 | 97.40 | 6.60 |
| ResNet50-v1 [11] | 99.30 | 99.30 | 99.76 | 1.00 |
| MobileNet-v2 [29] | 99.40 | 99.40 | 99.80 | 0.60 |
| Human expert 5 [16] | 99.70 | 99.70 | 99.90 | 0.40 |
| Xception [4] | 99.70 | 99.70 | 99.90 | 0.30 |
| OpticNet-71 [Ours] | 99.80 | 99.80 | 99.93 | 0.20 |
4.3.2 Srinivasan2014 Data-Set
We benchmark OpticNet-71 against other methods in Table 5, evaluating on the Srinivasan2014 [32] data-set through three metrics: Accuracy, Sensitivity, and Specificity. Among the solutions listed in Table 5, Lee et al. [18] use a modified VGG-16, Awais et al. [2] use a VGG architecture with KNN in the final layer, and Karri et al. [15] use GoogLeNet; all of them use weights obtained by transfer learning on ImageNet [27]. As shown in Table 5, OpticNet-71 achieves state-of-the-art results by scoring 100% Accuracy, Sensitivity, and Specificity. Furthermore, to compare with our results (Tables 4 and 5), we train ResNet50-v1 [11], ResNet50-v2 [12], MobileNet-v2 [29], and Xception [4] using weights pre-trained on the 3.2 million image ImageNet data-set consisting of 1000 categories [27], whereas we train Optic-Net from scratch with randomly initialized weights.
Table 5 Results on Srinivasan2014 [32] data-set
| Architectures | Test accuracy (%) | Sensitivity (%) | Specificity (%) |
|---|---|---|---|
| Lee et al. [18] | 87.63 | 84.63 | 91.54 |
| Awais et al. [2] | 93.00 | 87.00 | 100.00 |
| ResNet50-v1 [11] | 94.92 | 94.92 | 97.46 |
| Karri et al. [15] | 96.00 | – | – |
| MobileNet-v2 [29] | 97.46 | 97.46 | 98.73 |
| Xception [4] | 99.36 | 99.36 | 99.68 |
| OpticNet-71 [Ours] | 100.00 | 100.00 | 100.00 |

4.4 Hyper-Parameter Tuning and Performance Evaluation
The hyper-parameters used while training OpticNet-47, OpticNet-63, OpticNet-71, MobileNet-v2 [29], XceptionNet [4], ResNet50-v2 [12], and ResNet50-v1 [11] are as follows: batch size, b = 8; epochs = 30; learning rate, $\alpha_{lr}$ = 1e−4; step decay, $\gamma$ = 1e−1. We use an adaptive learning rate and decrease it using $\alpha_{lr}^{new} = \alpha_{lr}^{current} \times \gamma$ if the validation loss does not decrease for six consecutive epochs. Moreover, we set the lowest learning rate to $\alpha_{lr}^{min}$ = 1e−8. Furthermore, we use the Adam optimizer with default parameters of $\beta_1^{adam}$ = 0.90 and $\beta_2^{adam}$ = 0.99 for all training schemes. We train on the OCT2017 [16] data-set for 44 hours and on the Srinivasan2014 [32] data-set for 2 hours on an 8 GB NVIDIA GTX 1070 GPU. Inception-v3 models under-perform compared to both the pre-trained models and OpticNet-71, as seen in Table 4. OpticNet-71 takes 0.03 seconds to make a prediction on an OCT image, which is real time. While accomplishing state-of-the-art results on the OCT2017 [16] and Srinivasan2014 [32] data-sets, our model also surpasses human-level prediction on OCT images, as shown in Table 4. The human experts are real diagnosticians, as reported in [16]: there are six diagnosticians, of whom the highest performing is Human Expert 5 and the lowest performing is Human Expert 2. To validate our CNN architecture's optimization strength, we also train two smaller versions of OpticNet-71 on both data-sets: OpticNet-47 ([N1 N2 N3 N4] = [2 2 2 2]) and OpticNet-63 ([N1 N2 N3 N4] = [3 3 3 3]). Figure 5 shows that all of our OpticNet variants outperform the pre-trained CNNs on the Srinivasan2014 [32] data-set, while OpticNet-71 outperforms all the pre-trained CNNs on the OCT2017 [16] data-set in terms of accuracy as well as performance-memory trade-off.
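The learning-rate rule described above can be sketched in a few lines of plain Python. This is an illustrative reading, not the authors' training code: the class name is hypothetical, and "does not decrease" is interpreted as "does not improve on the best validation loss seen so far", which is an assumption.

```python
class PlateauStepDecay:
    """Multiply the learning rate by gamma after `patience` epochs without improvement."""

    def __init__(self, lr=1e-4, gamma=1e-1, patience=6, min_lr=1e-8):
        self.lr = lr
        self.gamma = gamma
        self.patience = patience
        self.min_lr = min_lr
        self.best_loss = float("inf")
        self.stale_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss and return the next learning rate."""
        if val_loss < self.best_loss:
            # Improvement: remember it and reset the patience counter.
            self.best_loss = val_loss
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
            if self.stale_epochs >= self.patience:
                # alpha_new = alpha_current * gamma, floored at the minimum rate.
                self.lr = max(self.lr * self.gamma, self.min_lr)
                self.stale_epochs = 0
        return self.lr


# Toy run: the loss improves for three epochs and then stalls.
schedule = PlateauStepDecay()
toy_losses = [0.9, 0.6, 0.5] + [0.5] * 7
for epoch, loss in enumerate(toy_losses, start=1):
    print(epoch, schedule.step(loss))
```

Comparing against the best loss rather than the previous epoch's loss mirrors how plateau-based schedulers are commonly implemented; if the intended rule compares consecutive epochs instead, only the condition in `step` changes.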
Fig. 5 Test accuracy (%), CNN memory (megabytes), and model parameters (millions) on the OCT2017 [16] and Srinivasan2014 [32] data-sets

4.5 Analysis of Proposed Residual Interpolated Block
To understand how the Residual Interpolated Block works, we visualize features by passing a test image through our CNN model. Figure 6a illustrates some of the sharp signals propagated by the Residual blocks, while the interpolation reconstruction routine propagates a weak signal activation; yet the resulting signal space is both sharper and more fine-grained than its Residual counterpart. Since the conv layers in the following stage activate the incoming signals first, we do not output an activated signal space from a stage. Instead, we activate only the interpolation counterpart, multiply it with the last residual block's non-activated output space, and add the raw signals to the multiplied signal; the result is what we consider the output of each stage, as narrated in Fig. 6b. Furthermore, Fig. 6b portrays how combining element-wise addition with element-wise multiplication between signals helps the learning propagation of OpticNet-71. Figure 6b also depicts why this optimization chain is particularly significant. With multiplication alone, a zero activation would cancel out a live signal channel from the residual counterpart, and a dead signal channel would cancel out a nonzero activation from the interpolation counterpart. The additive terms in $\tau(X_l) = \alpha(X_l) + \beta(X_l) \times (1 + \alpha(X_l))$, or equivalently $\tau(X_l) = \beta(X_l) + \alpha(X_l) \times (1 + \beta(X_l))$, keep the surviving signal in both cases. This prevents all signals of a stage from dying and causing catastrophic optimization failure due to dead weights or gradient explosion.
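A small numeric illustration of this argument is given below. It treats α and β as scalar values of a single channel from the two branches, which is a simplification for illustration; the function names are hypothetical.

```python
def stage_output(alpha, beta):
    """tau = alpha + beta * (1 + alpha), i.e. alpha + beta + alpha * beta."""
    return alpha + beta * (1.0 + alpha)


def product_only(alpha, beta):
    """Pure element-wise multiplication, shown for comparison."""
    return alpha * beta


cases = {
    "live alpha, zero beta": (0.8, 0.0),
    "zero alpha, live beta": (0.0, 0.5),
    "both zero": (0.0, 0.0),
}
for label, (alpha, beta) in cases.items():
    print(label, stage_output(alpha, beta), product_only(alpha, beta))
```

With multiplication alone, either branch being zero silences the channel, whereas the combined form is zero only when both branches are zero. Both ways of writing τ expand to the same expression, so the roles of α and β are interchangeable in this respect.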
Fig. 6 a Visualizing input images from each class through different layers of Optic-Net 71. As shown, the feature maps at the end of each building block learn more fine-grained features, sometimes focusing on the same shapes in different regions of the image, and thereby learn which features lead the image to the ground truth. b The learning progression shows how exhausting the signal propagated with residual activation learns to detect thinner edges, delving further into the macular region to learn anomalies. With the signal exhaustion mechanism, important features can sometimes be lost during training. Our experiments show that using more of these building blocks reduces this risk of feature loss and improves overall optimization for Optic-Net 71

4.6 Fivefold Training on OCT2017 Data-Set
We ran a fivefold cross-validation on the OCT2017 [16] data-set to mitigate over-fitting and assert our model's generalization potency. We randomly split the 84,484 images into 5 subsets, each containing approximately 16,896 images. We validate our model on one of the subsets after training it on the remaining 4 subsets, and repeat this process 5 times so that every subset is used for validation once. In Fig. 7, we report the best training and validation accuracy among all five folds, together with the arithmetic average of the training and validation accuracy over the five folds. Each accuracy is reported for a particular batch (of size 8), with all 30 epochs registered on the X-axis of Fig. 7. In a similar manner, we report the best training loss, best validation loss, average training loss, and average validation loss for the OCT2017 [16] data-set in Fig. 8. As Fig. 7 illustrates, the best and average accuracy on both the training and validation sets reach a stable maximum after training, and this phenomenon attests to the efficacy of our model. Furthermore, in Fig. 8 we report a discernible discrepancy between the average validation loss and the best validation loss among the five folds, which assures that our model's performance is not a result of over-fitting.
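The fivefold procedure can be sketched with the standard library alone. The function name and the fixed seed are illustrative assumptions; the sketch works on sample indices rather than on the images themselves.

```python
import random


def five_fold_indices(num_samples, num_folds=5, seed=0):
    """Return a list of (training_indices, validation_indices), one pair per fold."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # random split, reproducible via the seed
    folds = [indices[i::num_folds] for i in range(num_folds)]

    splits = []
    for held_out in range(num_folds):
        validation = folds[held_out]
        training = [
            idx
            for fold_id, fold in enumerate(folds)
            if fold_id != held_out
            for idx in fold
        ]
        splits.append((training, validation))
    return splits


splits = five_fold_indices(84484)
for fold_id, (training, validation) in enumerate(splits, start=1):
    print(fold_id, len(training), len(validation))
```

Because 84,484 is not divisible by 5, four of the subsets hold 16,897 images and one holds 16,896, consistent with the "approximately 16,896" quoted above.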
Fig. 7 Average accuracy and best accuracy of the fivefold cross-validation on the OCT2017 [16] validation and training sets

Fig. 8 Average loss and minimum loss of the fivefold cross-validation on the OCT2017 [16] validation and training sets

4.7 Training Different Variants of OpticNet-71 on OCT2017
We experimented with different variants of residual units to find the best version of OpticNet-71 for both data-sets. First, we ran the experiment with vanilla convolution on the OCT2017 [16] data-set, but this variant under-performed against Human Expert 5. Next, we incorporated dilation into the convolution layers, with which we reached a specificity of 99.73% and a weighted error of 0.8%, as shown in Table 6; this was worse than using residual units with vanilla convolution. We then tried separable convolution and its dilated counterpart, with which we achieved test accuracies of 99.30% (without dilation) and 99.40% (with dilation), respectively. Lastly, we tried the proposed novel residual unit, consisting of dilated convolution on one branch and separable convolution with dilation on the other. With this, we reached our desired specificity of 99.93% and a weighted error of 0.2%, beating Human Expert 5 and reaching state-of-the-art accuracy. It is worth mentioning that the hyper-parameters were the same for all architectures and that training was done with the same optimizer (Adam) for 30 epochs. It is therefore evident that OpticNet-71, comprising dilated convolution and dilated separable convolution, was the optimal choice for deployment and prediction in the wild.
Table 6 Results on OCT2017 [16] using different variants of OpticNet-71
| Architectures | Test accuracy (%) | Sensitivity (%) | Specificity (%) | Weighted error (%) |
|---|---|---|---|---|
| Optic-Net 71 (vanilla convolution) | 99.40 | 99.40 | 99.80 | 0.60 |
| Optic-Net 71 (dilated convolution) | 99.20 | 99.20 | 99.73 | 0.80 |
| Optic-Net 71 (separable convolution) | 99.30 | 99.30 | 99.76 | 0.70 |
| Optic-Net 71 (separable convolution with dilation) | 99.40 | 99.40 | 99.80 | 0.60 |
| Optic-Net 71 (dilated convolution + separable convolution with dilation) | 99.80 | 99.80 | 99.93 | 0.20 |
4.8 Deployment Pipeline of Our System
In this section, we describe the application pipeline we use to deploy our model for users to interact with. As depicted in Fig. 9, the pipeline consists of two parts: (a) the front end and (b) the back end. The user interacts with an app in the front end, where an OCT image can be uploaded. The uploaded image is then passed to the back end, where it first goes through a pre-defined set of preprocessing steps and is then forwarded to our CNN model (OpticNet-71). All of these processes take place on the server. The class prediction scores produced by our model are then sent back to the app in response to the user's request for that image. Because our CNN is lightweight and capable of producing a high-precision prediction in real time, it facilitates a smooth user experience. Meanwhile, the uploaded image and its prediction tensor are stored on cloud storage for further fine-tuning and training of our model, which serves the goal of improving system precision and opens new avenues for further research and development.

Fig. 9 Application pipeline of Optic-Net
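The request flow described in this section can be outlined as a single back-end function. Everything in the sketch below is hypothetical: the function names, the stand-in preprocessing and model callables, and the list used in place of cloud storage are illustrative assumptions and do not describe the authors' deployed system.

```python
def handle_upload(image, preprocess, model, storage):
    """Back-end handling of one uploaded OCT image (hypothetical interface).

    preprocess: callable turning the raw upload into a model-ready input
    model:      callable returning one prediction score per class
    storage:    list standing in for cloud storage
    """
    model_input = preprocess(image)  # pre-defined preprocessing steps
    scores = model(model_input)      # class prediction scores from the CNN
    storage.append((image, scores))  # image and prediction kept for later fine-tuning
    return scores                    # sent back to the app


# Stand-ins so the sketch runs without a trained network or an image library.
def fake_preprocess(image):
    return {"source": image, "shape": (224, 224, 3)}


def fake_model(model_input):
    return [0.01, 0.01, 0.01, 0.97]


cloud_storage = []
response = handle_upload("scan_001.tiff", fake_preprocess, fake_model, cloud_storage)
print(response)
print(len(cloud_storage))
```

Passing the preprocessing step, the model, and the storage as arguments keeps the three back-end responsibilities separate, so that any of them can be replaced without touching the request handler.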
5 Conclusion
In this chapter, we introduced a novel set of residual blocks that can be combined to build a convolutional neural network for diagnosing retinal diseases with expert-level precision. Additionally, by exploiting this architecture, we devise a practical solution that can address the problem of vanishing and exploding gradients. This work is an extension of our previous work [14], which presents an exploratory analysis of different novel blocks and how effective they are in the diagnosis of retinal degeneration. In the future, we would like to expand on this research to address other sub-types of retinal degeneration and to isolate the boundaries of the macular subspace in the retina, which in turn will assist expert ophthalmologists in carrying out their differential diagnosis.

Acknowledgments We would like to thank the “UNR Computer Vision Laboratory” (https://www.cse.unr.edu/CVL/) and the “Center for Cognitive Skill Enhancement” (http://ccse.iub.edu.bd/) for providing us with technical support.
References
1. K. Alsaih, G. Lemaitre, M. Rastgoo, J. Massich, D. Sidibé, F. Meriaudeau, Machine learning techniques for diabetic macular edema (dme) classification on sd-oct images. Biomed. Eng. Online 16(1), 68 (2017)
2. M. Awais, H. Müller, T.B. Tang, F. Meriaudeau, Classification of sd-oct images using a deep learning approach, in 2017 IEEE International Conference on Signal and Image Processing Applications (ICSIPA) (IEEE, 2017), pp. 489–492
3. R.R. Bourne, G.A. Stevens, R.A. White, J.L. Smith, S.R. Flaxman, H. Price, J.B. Jonas, J. Keeffe, J. Leasher, K. Naidoo et al., Causes of vision loss worldwide, 1990–2010: a systematic analysis. Lancet Glob. Health 1(6), e339–e349 (2013)
4. F. Chollet, Xception: deep learning with depthwise separable convolutions, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2017), pp. 1251–1258
5. R.A. Costa, M. Skaf, L.A. Melo Jr., D. Calucci, J.A. Cardillo, J.C. Castro, D. Huang, M. Wojtkowski, Retinal assessment using optical coherence tomography. Prog. Retin. Eye Res. 25(3), 325–353 (2006)
6. C. Prevention et al., National diabetes statistics report, 2017 (2017)
7. B.M. Ege, O.K. Hejlesen, O.V. Larsen, K. Møller, B. Jennings, D. Kerr, D.A. Cavan, Screening for diabetic retinopathy using computer based image analysis and statistical classification. Comput. Methods Programs Biomed. 62(3), 165–175 (2000)
8. N. Ferrara, Vascular endothelial growth factor and age-related macular degeneration: from basic science to therapy. Nat. Med. 16(10), 1107 (2010)
9. D.S. Friedman, B.J. O’Colmain, B. Munoz, S.C. Tomany, C. McCarty, P. De Jong, B. Nemesure, P. Mitchell, J. Kempen et al., Prevalence of age-related macular degeneration in the United States. Arch. Ophthalmol. 122(4), 564–572 (2004)
10. I. Ghorbel, F. Rossant, I. Bloch, S. Tick, M. Paques, Automated segmentation of macular layers in OCT images and quantitative evaluation of performances. Pattern Recognit. 44(8), 1590–1603 (2011)
11. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2016), pp. 770–778
12. K. He, X. Zhang, S. Ren, J. Sun, Identity mappings in deep residual networks, in European Conference on Computer Vision (Springer, 2016), pp. 630–645
13. R. Kafieh, H. Rabbani, S. Kermani, A review of algorithms for segmentation of optical coherence tomography from retina. J. Med. Signals Sens. 3(1), 45 (2013)
14. S.A. Kamran, S. Saha, A.S. Sabbir, A. Tavakkoli, Optic-net: a novel convolutional neural network for diagnosis of retinal diseases from optical tomography images, in 2019 18th IEEE International Conference On Machine Learning And Applications (ICMLA) (2019), pp. 964–971
15. S.P.K. Karri, D. Chakraborty, J. Chatterjee, Transfer learning based classification of optical coherence tomography images with diabetic macular edema and dry age-related macular degeneration. Biomed. Opt. Express 8(2), 579–592 (2017)
16. D.S. Kermany, M. Goldbaum, W. Cai, C.C. Valentim, H. Liang, S.L. Baxter, A. McKeown, G. Yang, X. Wu, F. Yan et al., Identifying medical diagnoses and treatable diseases by image-based deep learning. Cell 172(5), 1122–1131 (2018)
17. A. Lang, A. Carass, M. Hauser, E.S. Sotirchos, P.A. Calabresi, H.S. Ying, J.L. Prince, Retinal layer segmentation of macular oct images using boundary classification. Biomed. Opt. Express 4(7), 1133–1152 (2013)
18. C.S. Lee, D.M. Baughman, A.Y. Lee, Deep learning is effective for classifying normal versus age-related macular degeneration oct images. Ophthalmol. Retin. 1(4), 322–327 (2017)
19. J.Y. Lee, S.J. Chiu, P.P. Srinivasan, J.A. Izatt, C.A. Toth, S. Farsiu, G.J. Jaffe, Fully automatic software for retinal thickness in eyes with diabetic macular edema from images acquired by cirrus and spectralis systems. Investig. Ophthalmol. Vis. Sci. 54(12), 7595–7602 (2013)
20. K. Lee, M. Niemeijer, M.K. Garvin, Y.H. Kwon, M. Sonka, M.D. Abramoff, Segmentation of the optic disc in 3-d OCT scans of the optic nerve head. IEEE Trans. Med. Imaging 29(1), 159–168 (2010)
21. G. Lemaître, M. Rastgoo, J. Massich, C.Y. Cheung, T.Y. Wong, E. Lamoureux, D. Milea, F. Mériaudeau, D. Sidibé, Classification of sd-oct volumes using local binary patterns: experimental validation for dme detection. J. Ophthalmol. 2016 (2016)
22. X.C. MeindertNiemeijer, L.Z.K. Lee, M.D. Abràmoff, M. Sonka, 3d segmentation of fluid-associated abnormalities in retinal oct: Probability constrained graph-search-graph-cut. IEEE Trans. Med. Imaging 31(8), 1521–1531 (2012)
23. A. Mishra, A. Wong, K. Bizheva, D.A. Clausi, Intra-retinal layer segmentation in optical coherence tomography images. Opt. Express 17(26), 23719–23728 (2009)
24. H. Nguyen, A. Roychoudhry, A. Shannon, Classification of diabetic retinopathy lesions from stereoscopic fundus images, in Proceedings of the 19th Annual International Conference of the IEEE Engineering in Medicine and Biology Society. ‘Magnificent Milestones and Emerging Opportunities in Medical Engineering’ (Cat. No. 97CH36136), vol. 1 (IEEE, 1997), pp. 426–428
25. G. Panozzo, B. Parolini, E. Gusson, A. Mercanti, S. Pinackatt, G. Bertoldo, S. Pignatto, Diabetic macular edema: an oct-based classification. Semin. Ophthalmol. 19, 13–20 (Taylor & Francis) (2004)
26. G. Quellec, K. Lee, M. Dolejsi, M.K. Garvin, M.D. Abramoff, M. Sonka, Three-dimensional analysis of retinal layer texture: identification of fluid-filled regions in sd-oct of the macula. IEEE Trans. Med. Imaging 29(6), 1321–1330 (2010)
27. O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A. Khosla, M. Bernstein et al., Imagenet large scale visual recognition challenge. Int. J. Comput. Vis. 115(3), 211–252 (2015)
28. C.I. Sánchez, R. Hornero, M.I. Lopez, J. Poza, Retinal image analysis to detect and quantify lesions associated with diabetic retinopathy, in The 26th Annual International Conference of the IEEE Engineering in Medicine and Biology Society, vol. 1 (IEEE, 2004), pp. 1624–1627
29. M. Sandler, A. Howard, M. Zhu, A. Zhmoginov, L.C. Chen, Mobilenetv2: inverted residuals and linear bottlenecks, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2018), pp. 4510–4520
30. L. Sifre, S. Mallat, Rigid-motion scattering for image classification. Ph.D. thesis, vol. 1, no. 3 (2014)
31. K. Simonyan, A. Zisserman, Very deep convolutional networks for large-scale image recognition (2014). arXiv preprint arXiv:1409.1556
32. P.P. Srinivasan, L.A. Kim, P.S. Mettu, S.W. Cousins, G.M. Comer, J.A. Izatt, S. Farsiu, Fully automated detection of diabetic macular edema and dry age-related macular degeneration from optical coherence tomography images. Biomed. Opt. Express 5(10), 3568–3577 (2014)
33. D.S.W. Ting, G.C.M. Cheung, T.Y. Wong, Diabetic retinopathy: global prevalence, major risk factors, screening practices and public health challenges: a review. Clin. Exp. Ophthalmol. 44(4), 260–277 (2016)
34. M. Treder, J.L. Lauermann, N. Eter, Automated detection of exudative age-related macular degeneration in spectral domain optical coherence tomography using deep learning. Graefe’s Arch. Clin. Exp. Ophthalmol. 256(2), 259–265 (2018)
35. F. Wang, M. Jiang, C. Qian, S. Yang, C. Li, H. Zhang, X. Wang, X. Tang, Residual attention network for image classification, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2017), pp. 3156–3164
36. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in Deep Learning (Springer, 2020)
37. W.L. Wong, X. Su, X. Li, C.M.G. Cheung, R. Klein, C.Y. Cheng, T.Y. Wong, Global prevalence of age-related macular degeneration and disease burden projection for 2020 and 2040: a systematic review and meta-analysis. Lancet Glob. Health 2(2), e106–e116 (2014)
38. J.W. Yau, S.L. Rogers, R. Kawasaki, E.L. Lamoureux, J.W. Kowalski, T. Bek, S.J. Chen, J.M. Dekker, A. Fletcher, J. Grauslund et al., Global prevalence and major risk factors of diabetic retinopathy. Diabetes Care 35(3), 556–564 (2012)
39. F. Yu, V. Koltun, Multi-scale context aggregation by dilated convolutions (2015). arXiv preprint arXiv:1511.07122
Three-Stream Convolutional Neural Network for Human Fall Detection Guilherme Vieira Leite, Gabriel Pellegrino da Silva, and Helio Pedrini
Abstract Lower child mortality rates, advances in medicine, and cultural changes have increased life expectancy to above 60 years in developed countries. Some countries expect that, by 2030, 20% of their population will be over 65 years old. The quality of life at this advanced age is largely dictated by the individual’s health, which determines whether the elderly can engage in activities important to their well-being, independence, and personal satisfaction. Old age is accompanied by health problems caused by biological limitations and muscle weakness. This weakening facilitates the occurrence of falls, which are responsible for the deaths of approximately 646,000 people worldwide every year. Even a minor fall can cause fractures, break bones, or damage soft tissues, which will not heal completely. Injuries and damage of this nature, in turn, erode the self-confidence of the individual and diminish their independence. In this work, we propose a method capable of detecting human falls in video sequences using multi-channel convolutional neural networks (CNN). Our method makes use of a 3D CNN fed with features previously extracted from each frame to generate a vector for each channel. Then, the vectors are concatenated, and a support vector machine (SVM) is applied to classify the vectors and indicate whether or not there was a fall. We experiment with four types of features, namely: (i) optical flow, (ii) visual rhythm, (iii) pose estimation, and (iv) saliency map. The benchmarks used, (i) UR Fall Detection Dataset (URFD) [33] and (ii) Fall Detection Dataset (FDD) [12], are publicly available, and our results are compared to those in the literature. The metrics selected for evaluation are balanced accuracy, accuracy, sensitivity, and specificity. Our results are competitive with those obtained by the state of the art on both the URFD and FDD datasets. To the authors’ knowledge, we are the first to perform cross-tests between the datasets in question and to report results for the balanced accuracy metric. The proposed method is able to detect falls in the selected benchmarks. Fall detection, like activity classification in videos, is strongly related to the network’s ability to interpret temporal information and, as expected, optical flow is the most relevant feature for detecting falls.

G. V. Leite · G. P. da Silva · H. Pedrini (B)
Institute of Computing, University of Campinas, Campinas, SP, Brazil
e-mail: [email protected]
G. V. Leite e-mail: [email protected]
G. P. da Silva e-mail: [email protected]
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021 M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_3
1 Introduction
Developed countries have reached a life expectancy of over 60 years of age [76]. Some countries in the European Union, as well as China, expect 20% of their population to be over 65 by 2030 [25]. According to the World Health Organization [25, 78], this recent level is a side effect of scientific advances, medical discoveries, reduced child mortality, and cultural changes. However, although human beings are living longer, the quality of this life span is mainly defined by their health, since it dictates an individual’s independence, satisfaction, and the possibility of engaging in activities that are important to their well-being. The relation between health and quality of life has inspired several research groups to develop assistive technologies focusing on the elderly population.
1.1 Motivation
Naturally, health problems appear along with the aging process, mainly due to the biological limitations of the human body and muscle weakness. As a side effect, this weakening increases the chances of an elderly person suffering a fall. Falls are the second leading cause of domestic death worldwide, killing around 646,000 people every year [25]. Reports indicate that between 28 and 35% of the population over 65 years old falls at least once a year, and this percentage rises to 32–42% for individuals over 70 years old. In an attempt to summarize which events may lead to a fall, Lusardi et al. [46] reported risk factors, including how falls occur, who is falling, and some precautions to avoid them. The effects of a fall on this fragile population can unleash a chain reaction of physiological and psychological damage, which consequently could decrease the elderly person’s self-confidence and independence. Even if the accident is a small fall, it can break or fracture bones and damage soft tissues that, due to old age, may never fully recover. To avoid serious injuries, the elderly population needs constant care and monitoring, especially because the effects of an accident are related to the time between its occurrence and the beginning of adequate care. That being said, qualified caregivers are expensive, especially when added to the already inflated health care budget in advanced age, which might lead families to relocate the elderly. Relocation is a common practice in which the surrounding family moves the elderly person from their dwelling place to the family’s home, causing discomfort and the stress of adapting to a new environment and routine. Based on the scenario described above, a technological solution can be devised in the form of an emergency system that would reliably trigger qualified assistance automatically. Thus, the response time between the accident and care would be reduced, allowing the elderly to dwell in their own homes and keep their independence and self-esteem. The diagram in Fig. 1 illustrates the components of such a system, also known as an assisted daily living (ADL) system. The system is fed by a range of sensors, installed around the home or attached to the subject’s body, which monitor information such as blood pressure, heartbeat, body temperature, blood glucose levels, acceleration, and posture. These devices are connected to a local processing center that, in addition to continuously building technical reports on the health of the elderly person, also alerts a remote operator in an emergency. Upon receiving the alert, the operator performs a situation check and, upon verifying the need, dispatches medical assistance.

Fig. 1 Diagram illustrating the main components of a monitoring system
1.2 Research Questions
Aware of the dangers and consequences that a fall represents for the elderly population, and since a video-based fall detection system can benefit from deep learning techniques, we propose a fall detection module, as part of an ADL system, that can detect falls in video frame sequences. In implementing this proposal, we intend to answer the following research questions:
1. Are the feature channels able to maintain enough temporal information so that a network can learn its patterns?
2. Which feature channel contributes the most to the problem of fall detection?
3. Does the effectiveness of the proposed method remain for other human fall datasets?
4. Are three-dimensional (3D) architectures more discriminative than their two-dimensional (2D) equivalents for this problem?
1.3 Contributions
The main contribution of this work is a multi-channel model to detect human falls [8, 9, 37]. Each channel represents a feature that we judge to be descriptive of a fall event and, when the channels are ensembled, they produce a better estimate of a fall. The model was tested on two publicly available datasets and is open-sourced [38]. In addition, we present an extensive comparison between different channel combinations and the employed architecture, whose best results are comparable to the state of the art. Finally, we discuss the implications of using simulated datasets as benchmarks for fall detection methods.
1.4 Chapter Layout
The remainder of this work is structured as follows. Section 2 reviews the existing approaches in the related literature. Section 3 presents the main concepts used in the implementation of this work. Section 4 describes the proposed methodology and its steps in detail. Section 5 presents the datasets used in the experiments, our evaluation metrics, the experiments carried out, their results, and discussions. Finally, Sect. 6 presents some final considerations and proposals for future work.
2 Related Work
In the following subsections, we review the work related to fall detection and the methods employed. The works were split into two groups, based on the sensors used: (i) methods without videos and (ii) methods with videos.
2.1 Methods Without Videos
The works in this group obtain their data from various non-camera sensors, such as a watch, accelerometer, gyroscope, heart rate sensor, or smartphone. Khin et al. [28] developed a triangulation system that uses several presence sensors installed in the monitored room of a home. The presence of someone in the room causes a disturbance in the sensors, and a fall would trigger a specific activation pattern. Although they did not report results, the authors stated that they tested this configuration and that the sensors detected different patterns for different actions, indicating that the solution could be used to detect falls. Kukharenko and Romanenko [31] obtained their data from a wrist device. A threshold was used to detect impact or a “weightless” state; after this detection, the algorithm waits one second and analyzes the information obtained. The method was tested on some volunteers, who reported complaints such as forgetting to wear the device and the discomfort it caused. Kumar et al. [32] attached a sensor to a belt on the person’s waist and compared four methods to detect falls: threshold, support vector machines (SVM), K-nearest neighbors (KNN) and dynamic time warping (DTW). The authors also commented on the importance of sensors attached to the body, since they monitor the individual constantly and, unlike cameras, have no blind spots. Vallejo et al. [72] developed a deep neural network to classify the sensor’s data. The chosen sensor is a gyroscope worn at the waist, and the network is composed of three hidden layers with five neurons each. The authors carried out experiments with adults aged 19–56 years. Zhao et al. [84] collected data from a gyroscope attached to the individual’s waist and used a decision tree to classify the information. The experiments were performed on five random adults. Zigel et al. [86] used accelerometers and microphones as sensors; however, the sensors were installed in the environment instead of being attached to the subject. The sensors would detect vibrations and feed a quadratic classifier. The tests were executed with a test dummy, which was released from an upright position.
2.2 Methods With Videos
In this section, we group the methods whose main data sources are video sequences from cameras. Despite relying on the same type of sensor, these methods present a variety of solutions, starting with the following works, which used threshold-based activation techniques. To isolate the human silhouette, Lee and Mihailidis [36] performed a background subtraction along with a region extraction. After this, the posture was determined by thresholding the silhouette’s perimeter, the speed of the silhouette’s center, and the Feret’s diameter. The solution was tested on a dataset created by the authors. Nizam et al. [53] used two thresholds: the first verified whether the body speed was high and, if so, the second verified whether the joints’ position was close to the ground. The joints’ position was obtained with a Kinect camera after subtracting the background. The experiments were carried out on a dataset created by the authors. Sase and Bhandari [57] applied a threshold in which a fall was defined as the region of interest becoming smaller than one-third of the individual’s height. The region of interest was obtained by background extraction, and the method was tested on the URFD [33] dataset. Bhandari et al. [5] applied a threshold on the speed and direction of the region of interest. A combination of Shi-Tomasi and Lucas–Kanade was applied to determine the region of interest. The authors tested the approach on the URFD [33] dataset and reported 95% accuracy.
Another widely used classification technique is the SVM, used by Abobakr et al. [2]. The method used depth information to subtract the background of the video frame, applied a random forest algorithm to estimate the posture and classified it with an SVM. Fan et al. [17] also separated the picture into foreground and background and fitted an ellipse to the silhouette of the human body found. From the ellipse, six features were extracted and fed to a slow feature function. An SVM classified the output of the slow feature function, and the experiments were executed on the SDUFall [47] dataset. Harrou et al. [23] used an SVM classifier that received features extracted from video frames. During the tests, the authors compared the SVM with a multivariate exponentially weighted moving average (MEWMA) and tested the solution on the URFD [33] and FDD [12] datasets. Mohd et al. [50] fed an SVM classifier with information on the height, speed, acceleration, and position of the joints and performed tests on three datasets: TST Fall Detection [20], URFD [33] and Fall Detection by Zhang [83]. Panahi and Ghods [56] subtracted the background from the depth information, fitted an ellipse to the shape of the individual, classified the ellipse with an SVM and performed tests on the URFD [33] dataset.
Concerned with privacy, the following works argued that solutions to detect falls should offer anonymity options. Edgcomb and Vahid [16] tested the effectiveness of a binary tree classifier over a time series. The authors compared different means to hide identity, such as blurring, extracting the silhouette, and replacing the individual with an opaque ellipse or an opaque box. They conducted tests on their dataset, with 23 videos recorded. Lin et al. [42] investigated a privacy-focused solution using a silhouette. They applied a K-NN classifier in addition to a timer that checks whether the individual’s pose has returned to normal. The tests were performed by laboratory volunteers.
Some studies used convolutional neural networks, such as Anishchenko [4], who implemented an adaptation of the AlexNet architecture to detect falls in the FDD [12] dataset. Fan et al. [18] used a CNN to monitor and assess the degree of completeness of an event. A stack of video frames was fed to a VGG-16 architecture and its result was associated with the first frame in the stack. The method was tested on two datasets: FDD [12] and Hockey Fights [52]. Their results were reported in terms of completeness of the falls. Huang et al. [26] used the OpenPose algorithm to obtain the coordinates of the body joints. Two classifiers (SVM and VGG-16) were compared to classify the coordinates. The experiments were carried out on the URFD [33] and FDD [12] datasets. Li et al. [40] created a modification of the AlexNet CNN architecture. The solution was tested on the URFD [33] dataset, and the authors reported that it classified between ADLs and falls in real time. Min et al. [49] used an R-CNN (region-based CNN) to analyze a scene, generating spatial relationships between the furniture and the human being in the scene, and then classified these spatial relationships. The authors experimented on three datasets: URFD [33], KTH [58], and a dataset created by them. Núñez-Marcos et al. [54] performed the classification with a VGG-16. The authors calculated the dense optical flow, which served as the feature for the network to classify. They tested the method on the URFD [33] and FDD [12] datasets.
Coincidentally, all works that used recurrent neural networks used the same architecture, long short-term memory (LSTM). Lie et al. [41] applied a recurrent neural network with LSTM cells to classify the individual’s posture. The posture was extracted by a CNN, and the experiments were carried out on a dataset created by the authors. Shojaei-Hashemi et al. [60] used a Microsoft Kinect device to obtain the individual’s posture information and an LSTM as a classifier. The experiments were performed on the NTU RGB+D dataset. Furthermore, the authors reported one advantage of using the Kinect: the posture extraction could be achieved in real time. Lu et al. [43] proposed the application of an LSTM right after a 3D CNN. The authors performed tests on the URFD [33] and FDD [12] datasets.
Other machine learning algorithms, such as K-nearest neighbors, were also used to detect falls. Kwolek and Kepski [34] combined an accelerometer and a Kinect. Once the accelerometer surpassed a threshold, a fall alert was raised, and only then did the Kinect camera start capturing depth frames of the scene, which were used by a second classifier. The authors compared KNN and SVM for classifying the frames and tested on two datasets, URFD [33] and an independent one. Sehairi et al. [59] developed a finite state machine to estimate the position of the human head from the extracted silhouette. The tests were performed on the FDD [12] dataset.
Markov filters were also used to detect falls, as in the work of Anderson et al. [3], in which the individual’s silhouette was extracted so that its features could be classified by the Markov filter. The experiments were carried out on their dataset. Zerrouki and Houacine [81] described the features of the body through curvelet coefficients and the ratio between areas of the body. An SVM classifier performed the posture classification, and the Markov filter discriminated between falls and non-falls. The authors reported experiments on the URFD [33] and FDD [12] datasets. In addition to the methods mentioned above, the following works made use of several other techniques. Yu et al. [80] obtained their features by applying head tracking techniques and shape variation analysis. The features served as input to a Gaussian classifier. The authors created a dataset for the tests. Zerrouki et al. [82] segmented the frames into foreground and background and applied another segmentation to the human body, dividing it into five partitions. The body segments were fed to an AdaBoost classifier, which obtained 96% accuracy on the URFD [33] dataset. Finally, Xu et al. [79] published a survey that evaluates several fall detection systems.
3 Basic Concepts In this section, we describe in detail the concepts necessary to understand the proposed methodology.
3.1 Deep Neural Networks
Deep neural networks (DNNs) are a class of machine learning algorithms in which several layers of processing are used to extract and transform features from the input data, and the backpropagation algorithm enables the network to learn the complex patterns of the data. The input of each layer is the output of the previous one, except for the first layer, which receives the data, and the last layer, from which the outputs are extracted [22]. This structure is not necessarily fixed: some layers can have two other layers as input, or several outputs. Deng and Yu [15] mentioned some reasons for the growing popularity of deep networks in recent years, which include their results in classification problems, improvements in graphics processing units (GPUs), the appearance of tensor processing units (TPUs), and the amount of data available digitally. The layers of a deep network can be organized in different ways to suit the task at hand. The way a DNN is organized is called its “architecture”, and some architectures have become well known because of their performance in image classification competitions, such as AlexNet [30], LeNet [35], VGGNet [62] and ResNet [24].
3.2 Convolutional Neural Networks
Fig. 2 Layout of the VGG-16 architecture (from input to output: Conv 1-1, Conv 1-2, Pool; Conv 2-1, Conv 2-2, Pool; Conv 3-1, Conv 3-2, Conv 3-3, Pool; Conv 4-1, Conv 4-2, Conv 4-3, Pool; Conv 5-1, Conv 5-2, Conv 5-3, Pool; FC, FC, FC)
Convolutional neural networks (CNNs) are a subtype of deep networks. Their structure is similar to that of a DNN, in that information flows from one layer to the next. However, in a CNN the data is also processed by convolutional layers, which apply various convolution operations and resize the data before sending it on to the next layer. These convolution operations allow the network to learn low-level features in the first layers and merge them in the following layers to learn high-level features. Although not mandatory, a convolutional network usually ends with some fully connected layers. Within the scope of this work, two CNN architectures are relevant: (i) VGG-16 [62] and (ii) Inception [66]. VGG-16 won the localization task of the ImageNet Large Scale Visual Recognition Challenge (ILSVRC) 2014 and reached a 7.3% error in the classification task. Its choice of small filters (3 × 3 convolutions with stride 1 and padding 1, and 2 × 2 max-pooling with stride 2) allowed the network to be deeper without being computationally prohibitive. The VGG-16 has 16 layers and 138 million parameters, which is considered small for deep networks. The largest load of computations in this network occurs in the first layers since, after them, the pooling layers considerably reduce the load on the deeper layers. Figure 2 illustrates the VGG-16 architecture. The second architecture, Inception V1 [66], won the classification task of ILSVRC 2014, the same year as VGG-16, with an error of 6.7%. This network was developed to be deeper and, at the same time, more computationally efficient. The Inception architecture has 22 layers and only 5 million parameters. It is built by stacking several modules, called Inception modules, illustrated in Fig. 3a. The modules were designed to create something of a network within the network, in which several convolutions and max-pooling operations are performed in parallel and, at the end of these, the features are concatenated to be sent to the next module. However, if the network were composed of Inception modules as illustrated in Fig. 3a, it would perform 850 million operations in total. To reduce this number, bottlenecks were created. Bottlenecks reduce the number of operations to 358 million. They are 1 × 1 convolutions that preserve the spatial dimensions while decreasing the depth of the features. To do so, they were placed before the 3 × 3 and 5 × 5 convolutions and after the 3 × 3 max-pooling, as illustrated in Fig. 3b.
Fig. 3 Inception modules: (a) module without bottleneck, in which the previous layer feeds 1 × 1, 3 × 3 and 5 × 5 convolutions and a 3 × 3 max-pooling in parallel before the filter concatenation; (b) module with bottleneck, in which 1 × 1 convolutions precede the 3 × 3 and 5 × 5 convolutions and follow the 3 × 3 max-pooling. Adapted from Szegedy et al. [66]
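To make the effect of a bottleneck more concrete, the following sketch counts the multiplications of a single 5 × 5 branch with and without a preceding 1 × 1 convolution. The feature map size (28 × 28 × 192), the 32 output filters and the 16 bottleneck filters are illustrative assumptions rather than values taken from this chapter, so the totals do not correspond to the 850 and 358 million operations quoted above.

```python
def conv_multiplications(height, width, in_depth, out_depth, kernel):
    """Multiplications of a stride-1 convolution that keeps the spatial size."""
    return height * width * out_depth * kernel * kernel * in_depth


# Hypothetical sizes, chosen only for illustration.
HEIGHT, WIDTH, IN_DEPTH = 28, 28, 192
OUT_DEPTH, BOTTLENECK_DEPTH = 32, 16

# 5x5 convolution applied directly to the 192-channel input.
direct = conv_multiplications(HEIGHT, WIDTH, IN_DEPTH, OUT_DEPTH, kernel=5)

# 1x1 bottleneck first (192 -> 16 channels), then the 5x5 convolution.
reduced = (
    conv_multiplications(HEIGHT, WIDTH, IN_DEPTH, BOTTLENECK_DEPTH, kernel=1)
    + conv_multiplications(HEIGHT, WIDTH, BOTTLENECK_DEPTH, OUT_DEPTH, kernel=5)
)

print(f"direct 5x5:      {direct:,} multiplications")
print(f"with bottleneck: {reduced:,} multiplications")
```

Under these assumed sizes the bottleneck branch needs roughly a tenth of the multiplications, because the expensive 5 × 5 kernel operates on 16 channels instead of 192.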
3.2.1 Transfer Learning
Transfer learning consists of reusing a network whose weights were trained in another context, usually a related and more extensive one. In addition to usually improving a network’s performance, the technique also helps decrease the convergence time and helps in scenarios with insufficient training data [85]. Typically, in classification problems, the transfer happens from the ImageNet [14] dataset, which is one of the largest and best-known datasets. The goal is for the network to learn complex patterns that are generic enough to be used in another context.
3.3 Definition of Fall
The literature does not provide a universal definition of a fall [29, 45, 67]; however, some health agencies have created their own definitions, which can be used to describe the general idea of a fall. The Joint Commission [68] states that “[...] a fall may be described as an unintentional change in position coming to rest on the ground, floor, or onto the next lower surface (e.g., onto a bed, chair or bedside mat). [...]”. The World Health Organization [77] defines it as “[...] an event which results in a person coming to rest inadvertently on the ground or floor or other lower level. [...]”. The National Center for Veterans Affairs [70] defines it as a “loss of upright position that results in landing on the floor, ground, or an object or furniture, or a sudden, uncontrolled, unintentional, non-purposeful, downward displacement of the body to the floor/ground or hitting another object such as a chair or stair [...]”. In this work, we describe a fall as an involuntary movement that results in an individual landing on the ground.
3.4 Optical Flow
Optical flow is a technique that estimates the apparent movement of pixels caused by the displacement of an object or of the camera. It is a vector field that represents the movement of each region, extracted from a sequence of frames. It assumes that the pixels do not leave the frame region and, because it is a local method, it is difficult to compute over uniform regions. The optical flow is extracted by comparing two consecutive frames, and it is represented by a vector of direction and magnitude. Consider I a video frame and I(x, y, t) a pixel in this frame. The same pixel in a frame analyzed at a future time t + dt, after a displacement of (dx, dy), is described by Eq. 1. Equation 2 is obtained by expanding the right-hand side of Eq. 1 in a Taylor series, removing the common terms and dividing by dt; its gradients f_x, f_y and f_t, in which f denotes the frame intensity, are given in Eq. 3. The components of the optical flow are the values of u and v (Eq. 4), which can be obtained by several methods such as Lucas–Kanade [44] or Farnebäck [19]. Figure 4 illustrates some optical flow frames.

$$I(x, y, t) = I(x + dx,\; y + dy,\; t + dt) \tag{1}$$
$$f_x u + f_y v + f_t = 0 \tag{2}$$
$$f_x = \frac{\partial f}{\partial x}, \quad f_y = \frac{\partial f}{\partial y}, \quad f_t = \frac{\partial f}{\partial t} \tag{3}$$
$$u = \frac{\partial x}{\partial t}, \quad v = \frac{\partial y}{\partial t} \tag{4}$$

Fig. 4 Examples of extracted optical flow. Each optical flow frame was extracted from the frame above it and the next frame in the sequence. The pixel colors indicate the movement direction, and their brightness relates to the magnitude of the movement
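As a worked illustration of Eq. 2, the sketch below estimates a single (u, v) pair for a whole frame by solving the constraint in the least-squares sense over all pixels, in the spirit of Lucas–Kanade. It is not the dense Farnebäck algorithm, and the helper names `gaussian_blob` and `estimate_flow`, the synthetic blob and the finite-difference gradients are assumptions made only for this example.

```python
import math


def gaussian_blob(size, center_x, center_y, sigma=5.0):
    """Synthetic grayscale frame (list of rows) holding one bright blob."""
    return [
        [math.exp(-((x - center_x) ** 2 + (y - center_y) ** 2) / (2 * sigma ** 2))
         for x in range(size)]
        for y in range(size)
    ]


def estimate_flow(frame1, frame2):
    """Least-squares solution of f_x*u + f_y*v + f_t = 0 over the whole frame."""
    height, width = len(frame1), len(frame1[0])
    sxx = sxy = syy = sxt = syt = 0.0
    for y in range(1, height - 1):
        for x in range(1, width - 1):
            # Central differences, averaged over both frames.
            fx = (frame1[y][x + 1] - frame1[y][x - 1]
                  + frame2[y][x + 1] - frame2[y][x - 1]) / 4.0
            fy = (frame1[y + 1][x] - frame1[y - 1][x]
                  + frame2[y + 1][x] - frame2[y - 1][x]) / 4.0
            ft = frame2[y][x] - frame1[y][x]
            sxx += fx * fx
            sxy += fx * fy
            syy += fy * fy
            sxt += fx * ft
            syt += fy * ft
    det = sxx * syy - sxy * sxy
    if abs(det) < 1e-12:
        raise ValueError("uniform region: the flow is not determined")
    # Solve the 2x2 normal equations [sxx sxy; sxy syy] [u v]^T = -[sxt syt]^T.
    u = (-syy * sxt + sxy * syt) / det
    v = (sxy * sxt - sxx * syt) / det
    return u, v


frame_a = gaussian_blob(64, 32, 32)
frame_b = gaussian_blob(64, 33, 32)  # same blob moved one pixel to the right
u, v = estimate_flow(frame_a, frame_b)
print(f"u = {u:.3f}, v = {v:.3f}")
```

The explicit check on the determinant mirrors the remark above about uniform regions: when the gradients vanish, the system has no unique solution.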
3.5 Visual Rhythm
Visual rhythm is an encoding technique that aims at creating a temporal relation between frames without losing their spatial information. Its representation consists of a single image summarizing the entire video, in such a way that each video frame contributes one column to the final image [51, 69, 71]. The visual rhythm is constructed by traversing each video frame in a zigzag pattern, from its lower left diagonal to its upper right diagonal, as illustrated in Fig. 5a. Each frame processed in zigzag generates a column of pixels, which is concatenated with the other columns to form the visual rhythm (Fig. 5b). The dimensions of the rhythm image are W × H, in which the width W is the number of frames in the video and the height H is the length of the zigzag path. Figure 6 illustrates some extracted visual rhythms.
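The construction can be sketched as follows. The exact zigzag path is not specified in the text, so this example assumes a path that visits every pixel along the diagonals, starting at the lower left corner and ending at the upper right corner, with alternating direction. The function names and the use of nested lists as grayscale frames are likewise hypothetical.

```python
def zigzag_path(height, width):
    """(row, col) coordinates along the diagonals, lower left to upper right."""
    path = []
    for d in range(height + width - 1):
        # Pixels of diagonal d: 'up' rows above the bottom, 'd - up' columns in.
        diagonal = [
            (height - 1 - up, d - up)
            for up in range(height)
            if 0 <= d - up < width
        ]
        if d % 2 == 1:
            diagonal.reverse()  # alternate the direction to obtain a zigzag
        path.extend(diagonal)
    return path


def visual_rhythm(frames):
    """Build an H x W image in which column j is the zigzag scan of frame j."""
    height, width = len(frames[0]), len(frames[0][0])
    path = zigzag_path(height, width)
    columns = [[frame[row][col] for row, col in path] for frame in frames]
    # Transpose so that each frame becomes one column of the rhythm image.
    return [list(row) for row in zip(*columns)]


# Four tiny 2 x 3 frames with distinct pixel values.
frames = [
    [[f * 10 + r * 3 + c for c in range(3)] for r in range(2)]
    for f in range(4)
]
rhythm = visual_rhythm(frames)
print(len(rhythm[0]), "columns (frames) x", len(rhythm), "rows (path length)")
```

With this assumed path, H equals the number of pixels in a frame; a sparser path would give a shorter column per frame.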
Fig. 5 Visual rhythm construction process, from the source frames 1, 2, 3, ..., N to the resulting visual rhythm. On the left, the zigzag manner in which each frame is traversed. On the right, the rhythm construction through the column concatenation
Fig. 6 Visual rhythm examples. Each frame on the first row illustrates a different video and, below, the visual rhythm for the corresponding entire video
3.6 Saliency Map
In the context of image processing, a saliency map is a feature that represents the regions of interest in an image. The map is generally presented in shades of gray, so that regions with low interest are black and regions with high interest are white. In the context of deep learning, the saliency map is the activation map of a classifier, highlighting the regions that contributed the most to the output classification. Originally, the saliency map was extracted as a way of understanding what deep networks were learning, and it is still used in that way, as in the work of Li et al. [39]. The saliency map was used by Zuo et al. [87] as a feature to classify actions from an egocentric point of view, in which the source of information is a camera that corresponds to the subject’s first-person view. Using the saliency map in the egocentric context is rooted in the assumption that important aspects of the action take place in the foreground instead of the background. The saliency map can be obtained in several ways, as shown by Smilkov et al. [64], Sundararajan et al. [65] and Simonyan et al. [63]. Figure 7 illustrates the extraction of the saliency map.
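One of the cited approaches, SmoothGrad [64], averages the gradient of the classifier score with respect to the input over several noisy copies of that input. The sketch below illustrates only this averaging idea on a toy one-neuron "classifier" with an analytic gradient. The model, the parameter values and the function names are assumptions for illustration and are unrelated to the networks used in this chapter.

```python
import math
import random


def score_gradient(weights, pixels):
    """Gradient of the toy score tanh(w . x) with respect to each pixel."""
    z = sum(w * p for w, p in zip(weights, pixels))
    scale = 1.0 - math.tanh(z) ** 2
    return [scale * w for w in weights]


def smoothgrad(weights, pixels, noise_std=0.1, samples=50, seed=0):
    """Average the gradient over noisy copies of the input (SmoothGrad idea)."""
    rng = random.Random(seed)
    total = [0.0] * len(pixels)
    for _ in range(samples):
        noisy = [p + rng.gauss(0.0, noise_std) for p in pixels]
        gradient = score_gradient(weights, noisy)
        total = [t + g for t, g in zip(total, gradient)]
    # The absolute value is a common visualization choice: only the
    # strength of the influence is kept, not its sign.
    return [abs(t / samples) for t in total]


weights = [0.1, -2.0, 0.5, 0.0]  # hypothetical classifier weights
pixels = [0.2, 0.4, 0.6, 0.8]    # a flattened four-pixel "image"
saliency = smoothgrad(weights, pixels)
print([round(s, 3) for s in saliency])
```

In a real network the analytic gradient would be replaced by backpropagation through the classifier, and the resulting map would be rescaled to shades of gray.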
Fig. 7 Saliency map examples. Pixels varying from black to white, according to the region importance to the external classifier
3.7 Posture Estimation
Posture estimation is a technique to derive the posture of one or more human beings. Different sensors are used as input to this technique, such as the depth sensor of a Microsoft Kinect or images from a camera. The algorithm proposed by Cao et al. [7], OpenPose, is notable for its effectiveness in estimating the pose of individuals in video frames. OpenPose operates with a two-stage network in search of 18 body joints. In its first stage, the method creates probability maps of the joints’ positions, and the second stage predicts affinity fields between the limbs found. The affinity is represented by a 2D vector, which encodes the position and orientation of each body limb. Figure 8 shows the posture estimation of some frames.
Fig. 8 Example of the posture estimation. Each circle represents a joint, whereas the edges represent the limbs. Each limb has a specific color assigned to it
4 Proposed Method
In this section, we describe the proposed method, an Inception 3D multi-channel network to detect falls in video frames. Figure 9 shows an overview of the method, which is based on the hypothesis raised by Goodale and Milner [21] that the human visual cortex consists of two parts that focus on processing different aspects of vision. This same hypothesis inspired Simonyan and Zisserman [61] to test neural networks with various channels of information to simulate the visual cortex. In the methodology illustrated in Fig. 9, each stream is a separate neural network trained on a specific feature. For instance, the optical flow stream is a network fed exclusively with the optical flow extracted from the frames, whereas the saliency stream is a different network fed exclusively with the extracted saliency frames, and so on. Since the streams are independent of each other, we explored different architectures and stream combinations in our experiments. Instead of always employing three streams, we also tested whether using only two streams would produce a better classifier.
4.1 Preprocessing
The deep learning literature widely reports the positive influence of preprocessing steps, and some were applied here to better tackle the task in question. In this work, the preprocessing step, represented by the green block in Fig. 9, consists of extracting features that can capture the various aspects of a fall and applying data augmentation techniques.
4.1.1 Feature Extraction
The following features were extracted and later input to the network, each in a specific manner. The posture estimation was extracted using a bottom-up approach, with the OpenPose algorithm by Cao et al. [7]. The extracted frames were fed into the network one at a time, and the inference results were obtained frame by frame (Fig. 8). Regarding the visual rhythm, we implemented an algorithm such that each video has only one visual rhythm. This rhythm frame was fed to the network repeatedly, so that its inference output could be paired with the other features (Fig. 6). The saliency map was obtained using the SmoothGrad technique, proposed and implemented by Smilkov et al. [64], which acts on top of a previously existing technique by Sundararajan et al. [65]. The frames were fed to the network, once again, one by one (Fig. 7). The optical flow extraction was carried out with the algorithm proposed by Farnebäck [19], which describes the dense optical flow (Fig. 4). As a fall event happens throughout several frames and the flow represents only the relationship between two of them, we employed a sliding window approach, suggested by Wang et al. [74]. The sliding window feeds the network with a stack of ten frames of optical flow. The first stack contains frames 1 to 10, the second stack contains frames 2 to 11, and so on with a stride of 1 (Fig. 10), so each video has N − 10 + 1 stacks, where N is the number of frames in the video. If fewer than ten frames remain at the end of a video, they do not contribute to the evaluation. The resulting inference of each stack was associated with the first frame of the stack, so that the optical flow could be paired with the other channels of the network.
Fig. 9 Overview of the proposed method that illustrates the training and test phases, their steps and the information flow throughout the method (dataset, pre-processing, visual rhythm, pose estimation and optical flow streams initialized with ImageNet weights, trained models, SVM, and the final fall or not fall decision)
Fig. 10 Illustration of the sliding window over the optical flow frames and how it moves to create the first, second and third stacks
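The sliding window described above can be sketched as a small helper that lists which frames fall into each stack. The function name and its default arguments are hypothetical; only the stack size of ten and the stride of 1 come from the text.

```python
def optical_flow_stacks(num_frames, stack_size=10, stride=1):
    """Return the 1-based frame indices of every sliding-window stack."""
    last_start = num_frames - stack_size + 1
    return [
        list(range(start, start + stack_size))
        for start in range(1, last_start + 1, stride)
    ]


stacks = optical_flow_stacks(12)
for stack in stacks:
    # Each stack's inference is associated with its first frame.
    print(f"frame {stack[0]} <- frames {stack[0]}..{stack[-1]}")
```

A video with N = 12 frames therefore yields 12 − 10 + 1 = 3 stacks, and a video with fewer than ten frames yields none.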
4.1.2 Data Augmentation
Data augmentation techniques were used, when applicable, in the training phase. The following augmentations were employed: vertical axis mirroring, perspective transform, cropping and adding mirrored borders, adding values between −20 and 20 to pixels, and adding values between −15 and 15 to the hue and saturation. The whole augmentation process was applied only to the RGB channel, as the other channels would be negatively affected by it. For instance, the optical flow information depends strictly on the relationship between frames, and its magnitude is expressed by the brightness of the pixel. Mirroring an optical flow frame would break the continuity between frames, and adding values to pixels would distort the magnitude of the vector.
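Two of the listed augmentations can be sketched in a few lines. The example below assumes a grayscale frame stored as nested lists with values in [0, 255] and a single random offset shared by all pixels; the text does not specify whether the offset is drawn per pixel or per image, so this is only one possible reading, and the function names are hypothetical.

```python
import random


def mirror_vertical_axis(image):
    """Flip the image left to right by reversing every row."""
    return [row[::-1] for row in image]


def add_offset(image, low=-20, high=20, rng=random):
    """Add one random integer offset to all pixels, clipped to [0, 255]."""
    offset = rng.randint(low, high)
    return [[min(255, max(0, pixel + offset)) for pixel in row] for row in image]


image = [[0, 100, 250], [10, 20, 30]]
mirrored = mirror_vertical_axis(image)
shifted = add_offset(image, rng=random.Random(0))
print(mirrored)
print(shifted)
```

Clipping keeps the result a valid 8-bit image, which is why a pixel near 0 or 255 may change by less than the drawn offset.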
4.2 Training
Due to the small amount of data available for experimentation, the training phase of our method requires the application of transfer learning techniques. Thus, the model was first trained on the ImageNet [14] dataset and later on our selected fall dataset. The whole process is illustrated in Fig. 11. Other works might freeze some layers between the transfer learning and the training. This was not the case here: all layers were trained with the fall dataset.
Fig. 11 Transfer learning process. From left to right, initially the model has no trained weights, then it is trained on the ImageNet [14] dataset and, finally, it is trained on the selected fall dataset
The same transfer learning was done for all feature channels, meaning that, independently of the feature that a channel will be trained on, its starting point was the ImageNet training. After this, each channel is trained with its extracted feature frames. Although we selected four features, we did not combine all of them at the same time and kept at most three combined features at a time.
4.3 Test
The features selected in this work were previously used by others in the related literature [7, 69]; however, they were employed in a single-stream manner. These works, along with the work of Simonyan and Zisserman [61], motivated us to propose a multi-stream methodology that joins the different aspects of each feature, so that a more robust description of the event can be achieved. This ensemble can be accomplished in several ways, ranging from a simple average between the channels, through a weighted average, to automatic methods such as the one used in this work, an SVM classifier. The workflow of the test phase, illustrated on the right of Fig. 9, begins with the feature extraction, identical to that of the training phase. After that, the weights obtained in the training phase were loaded, and all layers of the model were frozen. Then, each channel, with its specific trained model, performed inference on its input data. The output vectors were concatenated and sent to the SVM classifier, which in turn classified each frame as fall or not fall.
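The ensemble options mentioned above can be illustrated with two small helpers: one builds the concatenated vector that a classifier such as an SVM would receive, and the other computes a weighted average of the streams. The two-value output layout per stream (fall, not fall), the example scores and the function names are assumptions made for this sketch.

```python
def concatenate_streams(stream_outputs):
    """Join the per-stream output vectors of one frame into a single vector."""
    return [value for output in stream_outputs for value in output]


def weighted_average(stream_outputs, weights):
    """Weighted average of per-stream output vectors of equal length."""
    total = sum(weights)
    length = len(stream_outputs[0])
    return [
        sum(w * output[i] for w, output in zip(weights, stream_outputs)) / total
        for i in range(length)
    ]


# Hypothetical outputs of three streams for one frame: [fall, not fall].
outputs = [[0.9, 0.1], [0.6, 0.4], [0.3, 0.7]]

svm_input = concatenate_streams(outputs)           # six-dimensional vector
simple_mean = weighted_average(outputs, [1, 1, 1])
flow_heavy = weighted_average(outputs, [2, 1, 1])  # first stream counts double
print(svm_input, simple_mean, flow_heavy)
```

Learning the combination with a classifier avoids choosing such weights by hand, which is the motivation given above for the SVM.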
5 Experimental Results In this section, we describe the datasets used in the experiments, the metrics selected to evaluate our method, the executed experiments, and how our method stands against others proposed.
5.1 Datasets
Upon reviewing the related literature, a few datasets were found; however, some were not publicly available, the hyperlinks to the data were inactive, or the authors did not answer our contact attempts. This led us to select the following human fall datasets: (i) URFD [33] and (ii) FDD [12].
5.1.1 URFD
Published by Kwolek and Kepski [33], the URFD dataset (University of Rzeszow Fall Detection Dataset) is made up of 70 video sequences, 30 of which are falls and 40 of which are everyday activities. Each video has 30 frames per second (FPS), a resolution of 640 × 240 pixels, and varying length. The fall sequences were recorded with an accelerometer and two Microsoft Kinect cameras, one with a horizontal view of the scene and one with a top–down view from the ceiling. The activities of daily living were recorded with a single horizontal view camera and an accelerometer. The accelerometer information was excluded from the experiments, as it went beyond the scope of the project. As illustrated in Fig. 12, the dataset has five ADL scenarios but a single fall scenario, in which the camera angle and background are the same and only the actors in the scene change. This lack of variety is further discussed in the experiments. The dataset is annotated with the following information:
• Subject’s posture (not lying, lying on the floor and transition).
• Ratio between height and width of the bounding box.
• Ratio between maximum and minimum axes.
• Ratio of the subject’s occupancy in the bounding box.
• Standard deviation of the pixels to the centroid of the X and Z axes.
• Ratio between the subject’s height in the frame and the subject’s standing height.
• Subject’s height.
• Distance from the subject’s center to the floor.
Fig. 12 URFD’s [33] environments. a Fall scenarios; b ADLs scenarios
5.1.2 FDD

The FDD dataset (Fall Detection Dataset) was published by Charfi et al. [12] and contains 191 video sequences, 143 of which are falls and 48 of which are day-to-day activities. Each video has 25 FPS, a resolution of 320 × 240 pixels, and a varying length. All sequences were recorded with a single camera in four different environments: home, coffee room, office, and classroom, illustrated in Fig. 13. In addition, the dataset defines three experimentation protocols: (i) training and testing both use videos from the home and coffee room environments; (ii) training uses videos from the coffee room and testing uses videos from the office and the classroom; and (iii) training uses videos from the coffee room, the office, and the classroom, and testing uses videos from the office and the classroom. The dataset is annotated with the following information:

• Initial frame of the fall.
• Final frame of the fall.
• Height, width, and coordinates of the center of the bounding box in each frame.
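The three FDD protocols listed above can be written down as a small lookup table. The sketch below is only an illustration: the environment strings, the `FDD_PROTOCOLS` name, and the `select_videos` helper are hypothetical and do not come from the dataset or from the authors' code. How protocol (i) separates training from test videos within the same environments is not specified in the text, so it is left open here.

```python
# Hypothetical encoding of the three FDD experimentation protocols.
# Environment names are illustrative labels, not official dataset identifiers.
FDD_PROTOCOLS = {
    "i": {"train": ("home", "coffee_room"), "test": ("home", "coffee_room")},
    "ii": {"train": ("coffee_room",), "test": ("office", "classroom")},
    "iii": {
        "train": ("coffee_room", "office", "classroom"),
        "test": ("office", "classroom"),
    },
}


def select_videos(videos, environments):
    """Return the ids of the videos recorded in one of the given environments.

    `videos` is a list of (video_id, environment) pairs (an assumed layout).
    """
    return [video_id for video_id, env in videos if env in environments]


# Example with made-up video ids.
videos = [("v1", "home"), ("v2", "coffee_room"), ("v3", "office"), ("v4", "classroom")]
train_ids = select_videos(videos, FDD_PROTOCOLS["ii"]["train"])
test_ids = select_videos(videos, FDD_PROTOCOLS["ii"]["test"])
```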
Fig. 13 FDD’s [12] environments. a Fall scenarios; b ADLs scenarios
5.2 Evaluation Metrics

In this work, we approach fall detection as a binary classification problem, in which a classifier must decide whether a video frame corresponds to a fall or not. To that end, the chosen metrics were: (i) precision (Eq. 5), (ii) sensitivity (Eq. 6), (iii) accuracy (Eq. 7), and (iv) balanced accuracy (Eq. 8). In the following equations, the abbreviations correspond to: TP, true positive; FP, false positive; TN, true negative; and FN, false negative. In Eqs. 8 and 9, $y_i$ corresponds to the true value of the $i$-th sample, $\hat{y}_i$ to its predicted value, $w_i$ to the sample weight, and $1(\cdot)$ to the indicator function.

$$\text{Precision} = \frac{TP}{TP + FP} \tag{5}$$

$$\text{Sensitivity} = \frac{TP}{TP + FN} \tag{6}$$

$$\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN} \tag{7}$$

$$\text{Balanced Accuracy} = \frac{1}{\sum_i \hat{w}_i} \sum_i 1(\hat{y}_i = y_i)\, \hat{w}_i \tag{8}$$

in which

$$\hat{w}_i = \frac{w_i}{\sum_j 1(y_j = y_i)\, w_j} \tag{9}$$
These metrics were chosen because of the need to compare our results with those found in the literature, which, for the most part, reported only precision, sensitivity, and accuracy. Since both datasets are unbalanced, with the negative class having more than twice as many samples as the positive class (falls), we chose to use the balanced accuracy instead of some other balanced metric because, as its name states, it balances the samples and, in doing so, takes the false negatives into account. False negatives are especially important in fall detection, as ignoring a fall incident can lead to the health problems described in Sect. 1.
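As an illustration of Eqs. 5 to 9, the following minimal sketch computes the four metrics with NumPy. It is not the authors' implementation: the function names are hypothetical, label 1 is assumed to mean "fall", and unit sample weights are assumed when none are given. The functions do not guard against empty classes, so a zero denominator raises an error.

```python
import numpy as np


def balanced_accuracy(y_true, y_pred, sample_weight=None):
    """Balanced accuracy as in Eqs. 8 and 9 (unit weights if none are given)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if sample_weight is None:
        w = np.ones(len(y_true), dtype=float)
    else:
        w = np.asarray(sample_weight, dtype=float)
    # Eq. 9: divide each weight by the total weight of the sample's true class.
    w_hat = np.empty_like(w)
    for label in np.unique(y_true):
        mask = y_true == label
        w_hat[mask] = w[mask] / w[mask].sum()
    # Eq. 8: weighted fraction of correctly classified samples.
    return float(np.sum((y_pred == y_true) * w_hat) / np.sum(w_hat))


def binary_metrics(y_true, y_pred):
    """Precision, sensitivity, accuracy (Eqs. 5-7) and balanced accuracy (Eq. 8)."""
    t = np.asarray(y_true).astype(bool)  # assumption: 1 = fall, 0 = not fall
    p = np.asarray(y_pred).astype(bool)
    tp = int(np.sum(t & p))
    fp = int(np.sum(~t & p))
    tn = int(np.sum(~t & ~p))
    fn = int(np.sum(t & ~p))
    return {
        "precision": tp / (tp + fp),
        "sensitivity": tp / (tp + fn),
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
        "balanced_accuracy": balanced_accuracy(y_true, y_pred),
    }


# Toy labels (made up for illustration): 2 falls and 6 non-falls.
metrics = binary_metrics([1, 1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 1])
print(metrics)
```

With unit weights and two classes, Eq. 8 reduces to the mean of the sensitivity and the specificity, which is why a missed fall lowers the balanced accuracy more than it lowers the plain accuracy when falls are the minority class.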
5.3 Computational Resources

This method was implemented in the Python programming language [73], which was chosen due to its wide availability of libraries for image analysis and deep learning applications. The following libraries were used: SciKit [27], NumPy [55], OpenCV [6], and Keras [13], together with the TensorFlow [1] framework. Deep learning algorithms are known to be computationally intensive. Their training and experiments require more computational power than a conventional notebook can provide and were therefore carried out in the cloud on a rented Amazon AWS g2.2xlarge machine, with the following specifications: 1x Nvidia GRID K520 GPU (Kepler), 8x vCPUs, and 15 GB of RAM.
5.4 Experiments

Next, we present the experiments performed, their results, and a discussion of them. The experiments were split into multi-channel tests, cross-tests, and literature comparisons. To the knowledge of the authors, the cross-tests, in which the model is trained on one dataset and tested on the other, are unprecedented among the selected datasets. The data were split in proportions of 65% for training, 15% for validation, and 20% for testing. All our experiments were executed using the following parameters: 500 epochs with early stopping and a patience of 10, a learning rate of $10^{-5}$, mini-batches of 192, 50% dropout, and the Adam optimizer, and we trained to minimize the validation loss function. The results are reported first on the URFD dataset, followed by those on the FDD set. The results obtained on the URFD dataset are shown in Table 1, organized in decreasing order of balanced accuracy. In this first experiment, the combination
Table 1 3D multi-channel results on URFD dataset

| Channels | Precision (%) | Sensitivity (%) | Accuracy (%) | Balanced Accuracy (%) |
| --- | --- | --- | --- | --- |
| OF RGB | 1.00 | 1.00 | 0.97 | 0.98 |
| OF VR | 0.99 | 0.96 | 0.95 | 0.97 |
| RGB VR | 0.99 | 0.90 | 0.96 | 0.97 |
| OF RGB SA | 1.00 | 0.99 | 0.98 | 0.94 |
| OF RGB VR | 0.99 | 0.99 | 0.99 | 0.94 |
| OF RGB PE | 0.99 | 0.96 | 0.96 | 0.91 |
| OF SA | 0.99 | 0.98 | 0.94 | 0.91 |
| SA VR | 0.99 | 0.94 | 0.94 | 0.90 |
| RGB SA | 0.99 | 0.95 | 0.96 | 0.89 |
| SA PE | 0.99 | 1.00 | 0.91 | 0.89 |
| RGB PE | 0.99 | 0.99 | 0.92 | 0.89 |
| VR PE | 0.99 | 0.96 | 0.94 | 0.88 |
| OF PE | 0.99 | 0.97 | 0.92 | 0.87 |

OF: Optical flow, VR: Visual rhythm, SA: Saliency, PE: Posture estimation. The results are in decreasing order of balanced accuracy.
of the optical flow and RGB channels obtained the best result, whereas the pose estimation channels obtained the worst results. Table 2 shows once again the efficacy of the optical flow channel; however, in this instance, the RGB channels suffered a slight drop in performance, in contrast to the pose estimation ones, which rose from the worst results.

The first cross-test was performed with training on the URFD dataset and testing on the FDD set; its results are shown in Table 3. In a cross-test, the model is expected not to perform as well, since it is evaluated on a completely new dataset. The highest balanced accuracy was 68%, a sharp drop from the 98% of Table 2; furthermore, the majority of the channel combinations could not surpass the 50% mark, indicating that the model was barely able to generalize its learning.

We believe that this drop in balanced accuracy is an effect of the quality of the training dataset. As explained previously, the URFD dataset is very limited in the variability of its fall scenarios, presenting the same scenario repeatedly. Returning to the multi-channel test on the URFD dataset (Table 1), it is possible to notice that the channels with access to the background, namely visual rhythm and RGB, were among the best results, whereas in the cross-test the best results were obtained from channels without access to the background information. This could indicate that the channels with access to the background learned some features present in the scenario and were not able to detect falls when the scenario changed. In contrast, the channels without access to the background were still able to detect falls after the scenario changed.

The second cross-test was the opposite experiment, with training on the FDD dataset and testing on the URFD set (Table 4). Although there is, once again, a sharp drop in the balanced accuracy, it performed much better than the previous
Table 2 3D multi-channel results on FDD dataset

| Channels | Precision (%) | Sensitivity (%) | Accuracy (%) | Balanced Accuracy (%) |
| --- | --- | --- | --- | --- |
| OF SA | 0.99 | 0.99 | 0.98 | 0.98 |
| SA VR | 1.00 | 0.98 | 0.98 | 0.96 |
| SA PE | 1.00 | 1.00 | 0.97 | 0.96 |
| RGB SA | 1.00 | 0.99 | 0.99 | 0.95 |
| OF PE | 0.99 | 0.96 | 0.97 | 0.93 |
| OF RGB VR | 0.99 | 0.95 | 0.95 | 0.91 |
| OF VR | 0.99 | 0.93 | 0.94 | 0.91 |
| RGB VR | 0.99 | 0.94 | 0.89 | 0.91 |
| OF RGB | 0.99 | 0.91 | 0.99 | 0.90 |
| VR PE | 0.99 | 0.89 | 0.91 | 0.88 |
| OF RGB PE | 0.99 | 0.84 | 0.93 | 0.87 |
| OF RGB SA | 0.99 | 0.85 | 0.97 | 0.86 |
| RGB PE | 0.99 | 0.80 | 0.91 | 0.84 |

OF: Optical flow, VR: Visual rhythm, SA: Saliency, PE: Posture estimation. The results are in decreasing order of balanced accuracy.

Table 3 3D cross-test results URFD to FDD dataset

| Channels | Precision (%) | Sensitivity (%) | Accuracy (%) | Balanced Accuracy (%) |
| --- | --- | --- | --- | --- |
| OF PE | 0.97 | 1.00 | 0.97 | 0.68 |
| OF SA | 0.96 | 1.00 | 0.96 | 0.57 |
| SA PE | 0.95 | 1.00 | 0.96 | 0.55 |
| OF RGB PE | 0.95 | 1.00 | 0.96 | 0.50 |
| OF RGB VR | 0.95 | 1.00 | 0.96 | 0.50 |
| OF RGB SA | 0.95 | 1.00 | 0.96 | 0.50 |
| OF RGB | 0.95 | 1.00 | 0.96 | 0.50 |
| OF VR | 0.95 | 1.00 | 0.96 | 0.50 |
| RGB PE | 0.95 | 1.00 | 0.96 | 0.50 |
| RGB VR | 0.95 | 1.00 | 0.96 | 0.50 |
| RGB SA | 0.95 | 1.00 | 0.96 | 0.50 |
| SA VR | 0.95 | 1.00 | 0.96 | 0.50 |
| VR PE | 0.95 | 1.00 | 0.96 | 0.50 |

OF: Optical flow, VR: Visual rhythm, SA: Saliency, PE: Posture estimation. The results are in decreasing order of balanced accuracy.
test, reaching a maximum balanced accuracy of 84% with the optical flow and pose estimation channels. Moreover, a similar pattern is observable in both cross-tests: the channels without access to the background were better able to discriminate between fall and non-fall.
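The data split and stopping rule stated at the start of Sect. 5.4 can be sketched without any deep learning framework. The code below is a minimal illustration under assumptions: the names `HYPERPARAMS`, `split_indices`, and `EarlyStopping` are hypothetical, the split is a plain random shuffle (the text does not say whether the authors split by frame or by video, or whether they stratified), and the stopping rule is a simplified stand-in for a framework callback that monitors the validation loss.

```python
import numpy as np

# Values taken from the text of Sect. 5.4; the dictionary itself is illustrative.
HYPERPARAMS = {
    "max_epochs": 500,
    "patience": 10,
    "learning_rate": 1e-5,
    "batch_size": 192,
    "dropout": 0.5,
    "optimizer": "adam",
}


def split_indices(n_samples, fractions=(0.65, 0.15, 0.20), seed=0):
    """Shuffle sample indices and cut them into train/validation/test parts."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_samples)
    n_train = int(round(fractions[0] * n_samples))
    n_val = int(round(fractions[1] * n_samples))
    # Whatever remains after the first two cuts goes to the test part.
    return order[:n_train], order[n_train:n_train + n_val], order[n_train + n_val:]


class EarlyStopping:
    """Stop once the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best = float("inf")
        self.wait = 0

    def should_stop(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


train_idx, val_idx, test_idx = split_indices(100)
stopper = EarlyStopping(patience=HYPERPARAMS["patience"])
```

In a training loop, `stopper.should_stop(val_loss)` would be called once per epoch, and the loop would end either when it returns `True` or when `max_epochs` is reached.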
Table 4 3D cross-test results FDD to URFD dataset

| Channels | Precision (%) | Sensitivity (%) | Accuracy (%) | Balanced Accuracy (%) |
| --- | --- | --- | --- | --- |
| OF PE | 0.97 | 0.90 | 0.89 | 0.84 |
| VR PE | 0.96 | 0.92 | 0.90 | 0.75 |
| OF RGB PE | 0.96 | 0.99 | 0.97 | 0.72 |
| SA PE | 0.96 | 0.93 | 0.95 | 0.60 |
| OF SA | 0.95 | 0.99 | 0.91 | 0.54 |
| RGB PE | 0.95 | 0.99 | 0.91 | 0.54 |
| OF RGB SA | 0.95 | 1.00 | 0.95 | 0.50 |
| OF RGB VR | 0.95 | 1.00 | 0.95 | 0.50 |
| OF RGB | 0.95 | 1.00 | 0.95 | 0.50 |
| OF VR | 0.95 | 1.00 | 0.95 | 0.50 |
| RGB VR | 0.95 | 1.00 | 0.95 | 0.50 |
| RGB SA | 0.95 | 1.00 | 0.95 | 0.50 |
| SA VR | 0.95 | 1.00 | 0.95 | 0.50 |

OF: Optical flow, VR: Visual rhythm, SA: Saliency, PE: Posture estimation. The results are in decreasing order of balanced accuracy.
Given that, in the second cross-test, even the channels with access to the background improved their performance, one might argue that the background had nothing to do with it. However, it is important to reiterate that the FDD dataset is more heterogeneous than URFD, to the point that falls and ADL activities were recorded in the same scenario. This variability, added to the intrinsic nature of the 3D model, which creates internal temporal relations between the input data, probably allowed the channels with access to the background to focus on features of the fall movement itself. This would explain the larger number of channel combinations with more than 50% balanced accuracy.
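Several rows of Tables 3 and 4 pair a plain accuracy near 0.95 with a balanced accuracy of exactly 0.50. One situation that produces this pattern is a classifier that always outputs the majority class. The sketch below illustrates that arithmetic only; the labels are made up (95 samples of one class and 5 of the other), and it is not a claim about what the trained models actually predicted.

```python
import numpy as np


def accuracy(y_true, y_pred):
    """Fraction of samples classified correctly."""
    return float(np.mean(y_true == y_pred))


def balanced_accuracy(y_true, y_pred):
    """Mean of the per-class recalls (Eq. 8 with unit sample weights)."""
    recalls = [np.mean(y_pred[y_true == label] == label) for label in np.unique(y_true)]
    return float(np.mean(recalls))


# Made-up labels: 95 samples of class 1 and 5 samples of class 0.
y_true = np.array([1] * 95 + [0] * 5)
# A degenerate classifier that predicts class 1 for every sample.
y_pred = np.ones_like(y_true)

print(accuracy(y_true, y_pred))           # 0.95
print(balanced_accuracy(y_true, y_pred))  # 0.5
```

The plain accuracy rewards the constant prediction because the majority class dominates the count, while the balanced accuracy averages a recall of 1.0 on one class with a recall of 0.0 on the other. This is why the balanced metric is the more informative column when reading the cross-test tables.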
5.4.1 Our Method Versus the Literature

Finally, we compared our best results and those from our previous work with the ones found in the literature. As stated before, we did not compare our balanced accuracy, since no other work reported theirs. The results on the URFD dataset are reported in Table 5, while those on the FDD set are shown in Table 6; both are sorted in decreasing order of accuracy. Our Inception 3D architecture surpassed or matched the reviewed works on both datasets, as well as our previous method using the VGG-16 architecture. However, the VGG-16 was not able to surpass the work of Lu et al. [43], which, notably, also employs a 3D method together with an RNN with LSTM architecture, which may explain its performance.
Table 5 Ours versus literature on URFD dataset

| Approaches | Precision (%) | Sensitivity (%) | Accuracy (%) |
| --- | --- | --- | --- |
| Our Inception 3D | 0.99 | 0.99 | 0.99 |
| Lu et al. [43] | - | - | 0.99 |
| Previous VGG-16 | 1.00 | 0.98 | 0.98 |
| Panahi and Ghods [56] | 0.97 | 0.97 | 0.97 |
| Zerrouki and Houacine [81] | - | - | 0.96 |
| Harrou et al. [23] | - | - | 0.96 |
| Abobakr et al. [2] | 1.00 | 0.91 | 0.96 |
| Bhandari et al. [5] | 0.96 | - | 0.95 |
| Kwolek and Kepski [34] | 1.00 | 0.92 | 0.95 |
| Núñez-Marcos et al. [54] | 1.00 | 0.92 | 0.95 |
| Sase and Bhandari [57] | 0.81 | - | 0.90 |

Metrics that were not reported by the authors are shown as a hyphen (-). The results are sorted in decreasing order of accuracy.

Table 6 Ours versus literature on FDD dataset

| Approaches | Precision (%) | Sensitivity (%) | Accuracy (%) |
| --- | --- | --- | --- |
| Our Inception 3D | 1.00 | 0.99 | 0.99 |
| Previous VGG-16 | 0.99 | 0.99 | 0.99 |
| Lu et al. [43] | - | - | 0.99 |
| Sehairi et al. [59] | - | - | 0.98 |
| Zerrouki and Houacine [81] | - | - | 0.97 |
| Harrou et al. [23] | - | - | 0.97 |
| Núñez-Marcos et al. [54] | 0.99 | 0.97 | 0.97 |
| Charfi et al. [11] | 0.98 | 0.99 | - |

Metrics that were not reported by the authors are shown as a hyphen (-). The results are sorted in decreasing order of accuracy.
6 Conclusions and Future Work

In this work, we presented and compared our deep neural network method to detect human falls in video sequences, using a 3D architecture, the Inception V1 3D network. The training and evaluation of the method were performed on two public datasets and, regarding its effectiveness, the method outperformed or matched the related literature.
The results pointed out the importance of temporal information in the detection of falls, both because the temporal channels were always among the best results, especially the optical flow, and because of the improvement obtained by the 3D method when compared to our previous work with the VGG-16 architecture. The 3D method was also able to generalize its learning to a never-seen dataset, and this ability indicates that the 3D method can be considered a strong candidate to compose an elderly monitoring system. This is evidenced throughout the results shown in Sect. 5, since the temporal channels are always among the most effective, except for two cases shown in Tables 2 and 3, in which the third-best result was the combination of the spatial channels of saliency and pose. This can be attributed to the fact that the 3D architecture itself provides a temporal relationship between the data.

Our conclusion about the importance of temporal information for fall classification is corroborated by other works found in the literature, such as those by Meng et al. [48] and Carreira and Zisserman [10], who stated the same for the classification of actions in videos. This indicates that other deep learning architectures, such as those described in [75], could also be used for this application. In addition, as our results matched or surpassed the reviewed works, the method proved effective in detecting falls. In one specific instance, our method matched the results of Lu et al. [43], who make use of an LSTM architecture that, like ours, creates temporal relationships between the input data.

Innovatively, cross-tests were performed between the datasets. The results of these tests showed a known, yet interesting, facet of neural networks, in which the minimization finds a local minimum that does not correspond to the initial objective of the solution.
During training, some channels of the network learned aspects of the background of the images to classify the falls, instead of focusing on the aspects of the person’s movement in the video.

The method developed in our work is supported by several factors, such as the evaluation through balanced accuracy, the tests being performed on two different datasets, the heterogeneity of the FDD dataset, the execution of the cross-tests, and the comparisons between the various channel combinations. On the other hand, the work also has some limitations, such as (i) the low variability of the fall videos in the URFD set, (ii) the fact that, in the cross-tests, many combinations of channels obtained only 50% balanced accuracy, and (iii) the use of simple accuracy as a means of comparison with the literature. However, the proposed method overcomes these drawbacks, remaining relevant to the problem at hand.

The effectiveness of the method shows that, if trained on a sufficiently robust dataset, it can extract the temporal patterns necessary to classify scenarios as fall or non-fall. Admittedly, there is an expected drop in balanced accuracy in the cross-tests. Regarding fall detection, this work is one of the most accurate approaches and would be a valuable contribution as a module in an integrated system of assistance to the elderly.

Concerning future work, some points can be mentioned: (i) exploring other datasets that may contain a greater variety of scenarios and actions, (ii) integrating fall detection into a multi-class system, (iii) experimenting with features that are cheaper to extract, (iv) adapting the method to work in real time, either through cheaper channels or a lighter architecture, and (v) dealing with input as a stream of video, instead of clips, because in a real scenario the camera would continuously feed the system, which would further unbalance the classes.

The contributions of this work are presented in the form of a human fall detection method, implemented and publicly available in the repository [38], as well as the experimentation using multiple channels and different datasets, which generated not only a discussion about which metrics are more appropriate for evaluating fall detection solutions, but also a discussion of the quality of the datasets used in these experiments.

Acknowledgements The authors are thankful to FAPESP (grant #2017/12646-3), CNPq (grant #309330/2018-7), and CAPES for their financial support, as well as Semantix Brasil for the infrastructure and support provided during the development of the present work.
References

1. M. Abadi, A. Agarwal, P. Barham, E. Brevdo, Z. Chen, C. Citro, TensorFlow: large-scale machine learning on heterogeneous systems (2015). https://www.tensorflow.org
2. A. Abobakr, M. Hossny, S. Nahavandi, A skeleton-free fall detection system from depth images using random decision forest. IEEE Syst. J. 12(3), 2994–3005 (2017)
3. D.T. Anderson, J.M. Keller, M. Skubic, X. Chen, Z. He, Recognizing falls from silhouettes, in International Conference of the IEEE Engineering in Medicine and Biology Society (2006), pp. 6388–6391
4. L. Anishchenko, Machine learning in video surveillance for fall detection, in Ural Symposium on Biomedical Engineering, Radioelectronics and Information Technology (IEEE, 2018), pp. 99–102
5. S. Bhandari, N. Babar, P. Gupta, N. Shah, S. Pujari, A novel approach for fall detection in home environment, in IEEE 6th Global Conference on Consumer Electronics (IEEE, 2017), pp. 1–5
6. G. Bradski, The OpenCV library. Dr. Dobb’s J. Softw. Tools 120, 122–125 (2000)
7. Z. Cao, T. Simon, S.E. Wei, Y. Sheikh, Realtime multi-person 2D pose estimation using part affinity fields, in IEEE Conference on Computer Vision and Pattern Recognition (2017), pp. 7291–7299
8. S. Carneiro, G. Silva, G. Leite, R. Moreno, S. Guimaraes, H. Pedrini, Deep convolutional multi-stream network detection system applied to fall identification in video sequences, in 15th International Conference on Machine Learning and Data Mining (2019a), pp. 681–695
9. S. Carneiro, G. Silva, G. Leite, R. Moreno, S. Guimaraes, H. Pedrini, Multi-stream deep convolutional network using high-level features applied to fall detection in video sequences, in 26th International Conference on Systems, Signals and Image Processing (2019b), pp. 293–298
10. J. Carreira, A. Zisserman, Quo vadis, action recognition? A new model and the kinetics dataset, in Conference on Computer Vision and Pattern Recognition (IEEE, 2017), pp. 6299–6308
11. I. Charfi, J. Miteran, J. Dubois, M. Atri, R. Tourki, Definition and performance evaluation of a robust SVM based fall detection solution, in International Conference on Signal Image Technology and Internet Based Systems, vol. 12 (2012), pp. 218–224
12. I. Charfi, J. Miteran, J. Dubois, M. Atri, R. Tourki, Optimized spatio-temporal descriptors for real-time fall detection: comparison of support vector machine and adaboost-based classification. J. Electron. Imaging 22(4), 041106 (2013)
13. F. Chollet, Keras (2015). https://keras.io
14. J. Deng, W. Dong, R. Socher, L.J. Li, K. Li, L. Fei-Fei, Imagenet: a large-scale hierarchical image database, in IEEE Conference on Computer Vision and Pattern Recognition (2009), pp. 248–255
15. L. Deng, D. Yu, Deep learning: methods and applications. Found. Trends Signal Process. 7(3–4), 197–387 (2014)
16. A. Edgcomb, F. Vahid, Automated fall detection on privacy-enhanced video, in Annual International Conference of the IEEE Engineering in Medicine and Biology Society (2012), pp. 252–255
17. K. Fan, P. Wang, S. Zhuang, Human fall detection using slow feature analysis. Multimed. Tools Appl. 78(7), 9101–9128 (2018a)
18. Y. Fan, G. Wen, D. Li, S. Qiu, M.D. Levine, Early event detection based on dynamic images of surveillance videos. J. Vis. Commun. Image Represent. 51, 70–75 (2018b)
19. G. Farnebäck, Two-frame motion estimation based on polynomial expansion, in Scandinavian Conference on Image Analysis (2003), pp. 363–370
20. S. Gasparrini, E. Cippitelli, E. Gambi, S. Spinsante, J. Wåhslén, I. Orhan, T. Lindh, Proposal and experimental evaluation of fall detection solution based on wearable and depth data fusion, in International Conference on ICT Innovations (Springer, 2015), pp. 99–108
21. M.A. Goodale, A.D. Milner, Separate visual pathways for perception and action. Trends Neurosci. 15(1), 20–25 (1992)
22. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning (MIT Press, 2016)
23. F. Harrou, N. Zerrouki, Y. Sun, A. Houacine, Vision-based fall detection system for improving safety of elderly people. IEEE Instrum. & Meas. Mag. 20(6), 49–55 (2017)
24. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in IEEE Conference on Computer Vision and Pattern Recognition (2016), pp. 770–778
25. D.L. Heymann, T. Prentice, L.T. Reinders, The World Health Report: a Safer Future: global Public Health Security in the 21st Century (World Health Organization, 2007)
26. Z. Huang, Y. Liu, Y. Fang, B.K. Horn, Video-based fall detection for seniors with human pose estimation, in 4th International Conference on Universal Village (IEEE, 2018), pp. 1–4
27. E. Jones, T. Oliphant, P. Peterson, SciPy: open source scientific tools for python (2001). http://www.scipy.org
28. O.O. Khin, Q.M. Ta, C.C. Cheah, Development of a wireless sensor network for human fall detection, in International Conference on Real-Time Computing and Robotics (IEEE, 2017), pp. 273–278
29. Y. Kong, J. Huang, S. Huang, Z. Wei, S. Wang, Learning spatiotemporal representations for human fall detection in surveillance video. J. Vis. Commun. Image Represent. 59, 215–230 (2019)
30. A. Krizhevsky, I. Sutskever, G.E. Hinton, Imagenet classification with deep convolutional neural networks. Adv. Neural Inf. Process. Syst. 25, 1097–1105 (2012)
31. T. Kukharenko, V. Romanenko, Picking a human fall detection algorithm for wrist-worn electronic device, in IEEE First Ukraine Conference on Electrical and Computer Engineering (2017), pp. 275–277
32. V.S. Kumar, K.G. Acharya, B. Sandeep, T. Jayavignesh, A. Chaturvedi, Wearable sensor-based human fall detection wireless system, in Wireless Communication Networks and Internet of Things (Springer, 2018), pp. 217–234
33. B. Kwolek, M. Kepski, Human fall detection on embedded platform using depth maps and wireless accelerometer. Comput. Methods Programs Biomed. 117(3), 489–501 (2014)
34. B. Kwolek, M. Kepski, Improving fall detection by the use of depth sensor and accelerometer. Neurocomputing 168, 637–645 (2015)
35. Y. LeCun, L. Bottou, Y. Bengio, P. Haffner, Gradient-based learning applied to document recognition. Proc. IEEE 86(11), 2278–2324 (1998)
36. T. Lee, A. Mihailidis, An intelligent emergency response system: preliminary development and testing of automated fall detection. J. Telemed. Telecare 11(4), 194–198 (2005)
37. G. Leite, G. Silva, H. Pedrini, Fall detection in video sequences based on a three-stream convolutional neural network, in 18th IEEE International Conference on Machine Learning and Applications (ICMLA) (Boca Raton-FL, USA, 2019), pp. 191–195
38. G. Leite, G. Silva, H. Pedrini, Fall detection (2020). https://github.com/Lupins/fall_detection
39. H. Li, K. Mueller, X. Chen, Beyond saliency: understanding convolutional neural networks from saliency prediction on layer-wise relevance propagation. Comput. Res. Repos. (2017a)
40. X. Li, T. Pang, W. Liu, T. Wang, Fall detection for elderly person care using convolutional neural networks, in 10th International Congress on Image and Signal Processing, BioMedical Engineering and Informatics (2017b), pp. 1–6
41. W.N. Lie, A.T. Le, G.H. Lin, Human fall-down event detection based on 2D skeletons and deep learning approach, in International Workshop on Advanced Image Technology (2018), pp. 1–4
42. B.S. Lin, J.S. Su, H. Chen, C.Y. Jan, A fall detection system based on human body silhouette, in 9th International Conference on Intelligent Information Hiding and Multimedia Signal Processing (IEEE, 2013), pp. 49–52
43. N. Lu, Y. Wu, L. Feng, J. Song, Deep learning for fall detection: 3D-CNN combined with LSTM on video kinematic data. IEEE J. Biomed. Health Inform. 23(1), 314–323 (2018)
44. B.D. Lucas, T. Kanade, An iterative image registration technique with an application to stereo vision, in International Joint Conference on Artificial Intelligence (1981), pp. 121–130
45. F. Luna-Perejon, J. Civit-Masot, I. Amaya-Rodriguez, L. Duran-Lopez, J.P. Dominguez-Morales, A. Civit-Balcells, A. Linares-Barranco, An automated fall detection system using recurrent neural networks, in Conference on Artificial Intelligence in Medicine in Europe (Springer, 2019), pp. 36–41
46. M.M. Lusardi, S. Fritz, A. Middleton, L. Allison, M. Wingood, E. Phillips, Determining risk of falls in community dwelling older adults: a systematic review and meta-analysis using posttest probability. J. Geriatr. Phys. Ther. 40(1), 1–36 (2017)
47. X. Ma, H. Wang, B. Xue, M. Zhou, B. Ji, Y. Li, Depth-based human fall detection via shape features and improved extreme learning machine. J. Biomed. Health Inform. 18(6), 1915–1922 (2014)
48. L. Meng, B. Zhao, B. Chang, G. Huang, W. Sun, F. Tung, L. Sigal, Interpretable spatio-temporal attention for video action recognition (2018), pp. 1–10. arXiv preprint arXiv:181004511
49. W. Min, H. Cui, H. Rao, Z. Li, L. Yao, Detection of human falls on furniture using scene analysis based on deep learning and activity characteristics. IEEE Access 6, 9324–9335 (2018)
50. M.N.H. Mohd, Y. Nizam, S. Suhaila, M.M.A. Jamil, An optimized low computational algorithm for human fall detection from depth images based on support vector machine classification, in IEEE International Conference on Signal and Image Processing Applications (2017), pp. 407–412
51. T.P. Moreira, D. Menotti, H. Pedrini, First-person action recognition through visual rhythm texture description, in International Conference on Acoustics, Speech and Signal Processing (IEEE, 2017), pp. 2627–2631
52. E.B. Nievas, O.D. Suarez, G.B. García, R. Sukthankar, Violence detection in video using computer vision techniques, in International Conference on Computer Analysis of Images and Patterns (Springer, 2011), pp. 332–339
53. Y. Nizam, M.N.H. Mohd, M.M.A. Jamil, Human fall detection from depth images using position and velocity of subject. Procedia Comput. Sci. 105, 131–137 (2017)
54. A. Núñez-Marcos, G. Azkune, I. Arganda-Carreras, Vision-based fall detection with convolutional neural networks. Wirel. Commun. Mob. Comput. 2017, 1–16 (2017)
55. T.E. Oliphant, Guide to NumPy, 2nd edn. (CreateSpace Independent Publishing Platform, USA, 2015)
56. L. Panahi, V. Ghods, Human fall detection using machine vision techniques on RGB-D images. Biomed. Signal Process. Control 44, 146–153 (2018)
57. P.S. Sase, S.H. Bhandari, Human fall detection using depth videos, in 5th International Conference on Signal Processing and Integrated Networks (IEEE, 2018), pp. 546–549
58. C. Schuldt, I. Laptev, B. Caputo, Recognizing human actions: a local SVM approach, in 17th International Conference on Pattern Recognition, vol. 3 (IEEE, 2004), pp. 32–36
59. K. Sehairi, F. Chouireb, J. Meunier, Elderly fall detection system based on multiple shape features and motion analysis, in International Conference on Intelligent Systems and Computer Vision (IEEE, 2018), pp. 1–8
60. A. Shojaei-Hashemi, P. Nasiopoulos, J.J. Little, M.T. Pourazad, Video-based human fall detection in smart homes using deep learning, in IEEE International Symposium on Circuits and Systems (2018), pp. 1–5
61. K. Simonyan, A. Zisserman, Two-stream convolutional networks for action recognition in videos. Adv. Neural Inf. Process. Syst. 27, 568–576 (2014a)
62. K. Simonyan, A. Zisserman, Very deep convolutional networks for large-scale image recognition (2014b), pp. 1–14. arXiv preprint arXiv:14091556
63. K. Simonyan, A. Vedaldi, A. Zisserman, Deep inside convolutional networks: visualising image classification models and saliency maps. Comput. Res. Repos. (2013)
64. D. Smilkov, N. Thorat, B. Kim, F. Viégas, M. Wattenberg, Smoothgrad: removing noise by adding noise (2017), pp. 1–10. arXiv preprint arXiv:170603825
65. M. Sundararajan, A. Taly, Q. Yan, Axiomatic attribution for deep networks, in 34th International Conference on Machine Learning, vol. 70, pp. 3319–3328 (JMLR.org, 2017)
66. C. Szegedy, W. Liu, Y. Jia, P. Sermanet, S. Reed, D. Anguelov, Going deeper with convolutions, in IEEE Conference on Computer Vision and Pattern Recognition (2015), pp. 1–9
67. S.K. Tasoulis, G.I. Mallis, S.V. Georgakopoulos, A.G. Vrahatis, V.P. Plagianakos, I.G. Maglogiannis, Deep learning and change detection for fall recognition, in Engineering Applications of Neural Networks, ed. by J. Macintyre, L. Iliadis, I. Maglogiannis, C. Jayne (Springer International Publishing, Cham, 2019), pp. 262–273
68. The Joint Commission, Fall reduction program - definition of a fall (2001)
69. B.S. Torres, H. Pedrini, Detection of complex video events through visual rhythm. Vis. Comput. 34(2), 145–165 (2018)
70. US Department of Veterans Affairs, Falls policy overview (2019). http://www.patientsafety.va.gov/docs/fallstoolkit14/05_falls_policy_overview_v5.docx
71. F.B. Valio, H. Pedrini, N.J. Leite, Fast rotation-invariant video caption detection based on visual rhythm, in Iberoamerican Congress on Pattern Recognition (Springer, 2011), pp. 157–164
72. M. Vallejo, C.V. Isaza, J.D. Lopez, Artificial neural networks as an alternative to traditional fall detection methods, in 35th Annual International Conference of the IEEE Engineering in Medicine and Biology Society (2013), pp. 1648–1651
73. G. Van Rossum, F.L. Jr Drake, Python reference manual. Tech. Rep. Report CS-R9525, Centrum voor Wiskunde en Informatica, Amsterdam (1995)
74. L. Wang, Y. Xiong, Z. Wang, Y. Qiao, Towards good practices for very deep two-stream convnets (2015), pp. 1–5. arXiv preprint arXiv:150702159
75. M. Wani, F. Bhat, S. Afzal, A. Khan, Advances in Deep Learning (Springer, 2020)
76. World Health Organization, Global Health and Aging (2011)
77. World Health Organization, Fact sheet falls (2012)
78. World Health Organization, World Report on Ageing and Health (2015)
79. T. Xu, Y. Zhou, J. Zhu, New advances and challenges of fall detection systems: a survey. Appl. Sci. 8(3), 418 (2018)
80. M. Yu, S.M. Naqvi, J. Chambers, A robust fall detection system for the elderly in a smart room, in IEEE International Conference on Acoustics Speech and Signal Processing (2010), pp. 1666–1669
81. N. Zerrouki, A. Houacine, Combined curvelets and hidden Markov models for human fall detection. Multimed. Tools Appl. 77(5), 6405–6424 (2018)
82. N. Zerrouki, F. Harrou, Y. Sun, A. Houacine, Vision-based human action classification using adaptive boosting algorithm. IEEE Sens. J. 18(12), 5115–5121 (2018)
83. Z. Zhang, V. Athitsos, Fall detection by zhong zhang and vassilis athitsos (2020). http://vlm1.uta.edu/~zhangzhong/fall_detection/
84. S. Zhao, W. Li, W. Niu, R. Gravina, G. Fortino, Recognition of human fall events based on single tri-axial gyroscope, in IEEE 15th International Conference on Networking, Sensing and Control (2018), pp. 1–6
85. F. Zhuang, Z. Qi, K. Duan, D. Xi, Y. Zhu, H. Zhu, H. Xiong, Q. He, A comprehensive survey on transfer learning (2019), pp. 1–27. arXiv preprint arXiv:191102685
86. Y. Zigel, D. Litvak, I. Gannot, A method for automatic fall detection of elderly people using floor vibrations and sound-proof of concept on human mimicking doll falls. IEEE Trans. Biomed. Eng. 56(12), 2858–2867 (2009)
87. Z. Zuo, B. Wei, F. Chao, Y. Qu, Y. Peng, L. Yang, Enhanced gradient-based local feature descriptors by saliency map for egocentric action recognition. Appl. Syst. Innov. 2(1), 1–14 (2019)
Diagnosis of Bearing Faults in Electrical Machines Using Long Short-Term Memory (LSTM) Russell Sabir, Daniele Rosato, Sven Hartmann, and Clemens Gühmann
R. Sabir (B) · D. Rosato · S. Hartmann
SEG Automotive Germany GmbH, Lotterbergstraße 30, 70499 Stuttgart, Germany

R. Sabir · C. Gühmann
Chair of Electronic Measurement and Diagnostic Technology, Technische Universität Berlin, Berlin, Germany

© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021 M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_4

Abstract Rolling element bearings are very important components in electrical machines. Almost 50% of the faults that occur in electrical machines occur in the bearings, which makes bearings one of the most critical components in electrical machinery. Bearing fault diagnosis has therefore drawn the attention of many researchers. Generally, vibration signals from the machine's accelerometer are used for the diagnosis of bearing faults. In the literature, the application of Deep Learning algorithms to these vibration signals has resulted in fault detection accuracy close to 100%. Although fault detection using vibration signals is ideal, measuring vibration requires an additional sensor, which is absent in many machines, especially low-voltage machines, because it adds significantly to their cost. Alternatively, bearing fault diagnosis using the stator current, or Motor Current Signal (MCS), is gaining popularity. This paper uses MCS for the diagnosis of bearing inner raceway and outer raceway faults. Diagnosis using MCS is difficult because the fault signatures are buried beneath the noise in the current signal. Hence, signal-processing techniques are employed to extract the fault features. The paper uses the Paderborn University damaged bearing dataset, which contains stator current data from healthy bearings and from bearings with real inner raceway and real outer raceway damage of different fault severity. Fault features are extracted from MCS by first filtering out the redundant frequencies from the signal and then extracting eight features from the filtered signal: three features from the time domain and five features from the time–frequency domain using Wavelet Packet Decomposition (WPD). After the extraction of these eight features, the well-known Deep Learning algorithm Long Short-Term Memory (LSTM) is used for bearing fault classification. LSTM is mostly used in speech recognition due to its time coherence, but in this paper its ability is also demonstrated for fault classification, with an accuracy of 97%. The proposed algorithm is compared with traditional Machine Learning techniques, and it is shown that the proposed methodology outperforms all the traditional algorithms used for the classification of bearing faults using MCS. The method developed is independent of the speed and the loading conditions.

Keywords Ball bearings · AC machines · Fault diagnosis · Fault detection · Discrete wavelet transforms · Wavelet packets · Wavelet coefficients · Learning (artificial intelligence) · Machine learning · Deep learning · LSTM
1 Introduction

Rolling element bearings are a key component in rotating machinery. They support radial and axial loads while reducing rotational friction, and hence carry heavy loads on the machinery. They are critical to the optimal performance of the machinery, and their failure can lead to downtime and significant economic losses [1, 2]. In an electrical machine, due to overloading stress and misalignment, bearings are the components most susceptible to damage. Bearing faults are caused by insufficient lubrication, fatigue, incorrect mounting, corrosion, electrical damage, or foreign particles (contamination), and normally appear as wear, indentation, and spalling [3, 4]. Bearing failures account for approximately 50% of the failures in electrical machines [4]. The structure of the rolling element bearing is shown in Fig. 1. $N_b$ is the number of balls, $D_b$ is the diameter of the balls, $D_c$ is the diameter of the cage, and $\beta$ is the contact angle between the ball and the raceway. Rolling element bearings consist of two rings, an inner ring and an outer ring, as shown in Fig. 1. The inner race is normally mounted on the motor shaft. Balls or rollers are located between the two rings. Continued stress results in the loss of material fragments from the inner and outer races.

Fig. 1 Rolling element bearing [18]

Because bearings are so susceptible to damage, bearing fault diagnosis has attracted the attention of many researchers. Typical bearing fault diagnosis analyzes different signals measured on an electrical machine with damaged bearings by using signal-processing methods. In this paper, only inner and outer raceway faults are analyzed. Each of these faults corresponds to a characteristic frequency $f_c$, which is given by (1) and (2), where $f_r$ is the rotating frequency of the rotor.

$$\text{Outer raceway:}\quad f_o = \frac{N_b}{2}\, f_r \left(1 - \frac{D_b}{D_c}\cos\beta\right) \tag{1}$$

$$\text{Inner raceway:}\quad f_i = \frac{N_b}{2}\, f_r \left(1 + \frac{D_b}{D_c}\cos\beta\right) \tag{2}$$
These frequencies arise when the rotating balls strike the defect, causing an impulsive effect. The characteristic frequencies are a function of the bearing geometry and the rotor frequency. The impulsive effect generates rotor eccentricities at the characteristic frequencies, causing the air gap between the rotor and the stator to vary. Because of the variation of the air-gap length, the machine inductances also vary. These characteristic frequencies and their sidebands are derived in [5]. However, they can only be reliably detected under specific conditions, as they are highly influenced by external noise and operating conditions. Most of the time, these frequencies only appear in the frequency spectrum for faults of high severity.
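As a quick numerical illustration of (1) and (2), the following Python sketch evaluates both characteristic frequencies. The function name and the bearing geometry are hypothetical placeholders chosen only for illustration; they are not the dimensions of the bearings used in this chapter.

```python
import math

def bearing_fault_frequencies(n_balls, d_ball, d_cage, contact_angle_rad, f_rotor):
    """Return (f_outer, f_inner) in Hz following Eqs. (1) and (2)."""
    ratio = (d_ball / d_cage) * math.cos(contact_angle_rad)
    f_outer = 0.5 * n_balls * f_rotor * (1.0 - ratio)
    f_inner = 0.5 * n_balls * f_rotor * (1.0 + ratio)
    return f_outer, f_inner

# Hypothetical geometry (illustration only, not taken from the paper):
# 8 balls, 7 mm ball diameter, 30 mm cage diameter, zero contact angle, 1500 rpm.
f_o, f_i = bearing_fault_frequencies(
    n_balls=8, d_ball=7.0, d_cage=30.0, contact_angle_rad=0.0, f_rotor=1500 / 60.0
)
print(f"outer raceway: {f_o:.2f} Hz, inner raceway: {f_i:.2f} Hz")
```

The two frequencies always sum to $N_b f_r$, which gives a convenient sanity check on any implementation.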
2 Literature Review

Bearing fault diagnosis using vibration signals, acoustic emissions, temperature measurements, chemical monitoring, electric currents, and the shock pulse method has been an important research area during the last few decades. The most widely used methods for bearing fault diagnosis employ vibration signals. Fault diagnosis using vibration signals is the most effective method to date, with fault detection accuracy reaching close to 100%. With conventional Deep Learning methods applied to vibration signals, researchers have achieved accuracies well above 99%, e.g., 99.6% using CNN–LSTM (Convolutional Neural Network–Long Short-Term Memory) [6], 99.83% using SDAE (Stacked Denoising Autoencoder) [7], 99.15% using EDAE (Ensemble Deep Autoencoder) [8], and many more. However, in most industrial applications, the inexpensive electrical machines used have a power rating of 2 kW or less, and fitting an additional sensor such as an accelerometer is not economically appealing because it adds to the machine's cost. Furthermore, [9] argues that the vibration signal can easily be influenced by loose screws and by the resonance frequencies of different parts of the machine, especially the housing, which may lead to incorrect or misleading fault diagnosis. According to [5], the stator current or Motor Current Signal (MCS) can also be used effectively for bearing fault diagnosis; since currents can easily be measured from the existing frequency inverters, no additional sensors are required. As already discussed, damage in the bearings, e.g., a hole, pit, crack, or missing material, creates shock pulses when a ball passes over the defect. The shock pulses result in vibration at the bearing characteristic frequencies. This varies the average air-gap length, resulting in changes in the flux density (or machine inductances) that consequently appear in the stator current at the frequencies $f_{bf}$ given by (3) and modulate the rotating frequency of the machine.

$$f_{bf} = \left| f_s \pm k f_c \right| \tag{3}$$
where $f_s$ is the electrical frequency of the stator and k = 1, 2, 3, … Schoen et al. [5] further state that these characteristic frequency components can be detected by analyzing the MCS in the frequency domain. However, similar frequencies also appear in the frequency spectrum of the MCS as a result of rotor eccentricity, broken rotor bars, or other conditions that vary the air-gap length between the rotor and the stator. Detection of the characteristic fault frequencies can also be hindered when the fault is in its early stages or when the measured signal has a low signal-to-noise ratio. In addition, the location of these characteristic frequencies in the frequency spectrum depends on the machine speed, the bearing geometry, and the location of the fault. Although bearing fault detection using vibration signals may have advantages over diagnosis with MCS, excellent fault diagnosis using MCS can be achieved with the help of Deep Learning algorithms. Normally, feature extraction is not needed for Deep Learning techniques, but in this case the modulating components or features are buried in the noise and have to be detected and extracted using signal-processing techniques. Hence, diagnosis using MCS is not an easy task. In [10], fault signatures are extracted using the Discrete Wavelet Transform (DWT): the signal is denoised and reconstructed in the time domain, and spectral analysis is then used to identify the fault peaks. [11] identifies the fault signatures in MCS using Welch PSD estimation, but the signatures are identified more reliably using the RMS values of the $C_{8,7}$ and $C_{7,6}$ coefficients of the SWPT (Stationary Wavelet Packet Transform). Nevertheless, all these approaches present only fault signature identification, and no algorithm is demonstrated that can automatically identify the faults from the data. [12] evaluates the pros and cons of using MCS for diagnosis instead of the vibration signal, and concludes that diagnosis from MCS is not always reliable because of the low amplitude of the fault signatures. Hence, intelligent methods alone are not enough for correct fault classification. In such cases, a more powerful approach such as Deep Learning should be adopted, because Deep Learning models are able to classify even from weak features.
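To make (3) concrete, the short Python sketch below lists the sideband frequencies at which a bearing fault would be expected in the stator current spectrum. The helper name and the numerical values are hypothetical and serve only as an illustration.

```python
def sideband_frequencies(f_stator, f_char, k_max=3):
    """Return (k, |f_s - k f_c|, |f_s + k f_c|) for k = 1..k_max, following Eq. (3)."""
    return [
        (k, abs(f_stator - k * f_char), abs(f_stator + k * f_char))
        for k in range(1, k_max + 1)
    ]

# Hypothetical stator frequency and characteristic fault frequency, in Hz.
for k, lower, upper in sideband_frequencies(f_stator=50.0, f_char=76.0):
    print(k, lower, upper)
```

Because the sidebands move with both $f_s$ and $f_c$, their positions change with every operating point, which is one reason why fixed-frequency detection is fragile.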
In the literature, several Machine Learning and Deep Learning methods are employed; each method has its own advantages and disadvantages. For example, in [13], an unsupervised classification technique called artificial ant clustering is described, but it requires three voltage and three current sensors, which makes the method economically unappealing. [14, 15] describe the Machine Learning method SVM (Support Vector Machine) for fault diagnosis, but in [14] only outer raceway faults are considered, and the accuracy drops from 100% to 95% when multi-fault classification is done. In [15], the inner raceway fault, outer raceway fault, and cage fault are all considered; the algorithm performs quite well at lower speeds, with an average accuracy of 99%, but the accuracy drops to 92% at higher speeds. In [16], a Convolutional Neural Network (CNN) is trained with the amplitudes of selected frequency components from the motor's current spectrum. The results show an accuracy of about 86%, and only the bearing outer race fault is considered. In [17], Sparse Autoencoders (AE) are used to extract fault patterns from the stator current and the vibration signals, and SVM is used for classification. AE gives 96% accuracy for high-severity bearing faults but 64% for low-severity faults, and only the outer race fault is considered. Finally, in [18], two feature extraction methods are used: a 1D-CNN and WPT (Wavelet Packet Transform). In the analysis, two types of bearing faults are classified: an inner ring fault and a fault due to aluminum powder in the bearing. The method presented covers a wide range of speeds and fault severities with 98.8% accuracy, but it does not include load variation and requires very high-specification training hardware. In this paper, the stator current is used to diagnose two bearing faults, inner and outer raceway, using the well-known Deep Learning algorithm, the LSTM network. The bearing fault data considered take into account different operating conditions such as speed, load variations, and varying fault severity levels. Hence, the method developed is independent of the speed and the loading conditions. In the remainder of the paper, the LSTM network is described first, then feature extraction and LSTM training are discussed, and finally the results from the trained LSTM are analyzed.
3 LSTM Network

RNN (Recurrent Neural Network) is a class of artificial neural networks used to identify patterns in sequential data, e.g., in speech recognition, handwriting recognition, natural language processing, and video analysis. An RNN can take time and sequence into account because it has two forms of input: the present input and the input from the recent past [19]. The key difference between an RNN and a feedforward network is that the RNN has a feedback loop connected to its past decisions. From Fig. 2, it can be seen that at a certain time t the recurrent hidden layer neurons receive input not only from the input layer $x_t$ but also from their own state at instance t − 1, i.e., from $h_{t-1}$. So, the output is a combination of the present and the past. The process is described by (4).
$$h_t = f\left(w_{hx} x_t + w_{hh} h_{t-1} + b_h\right), \qquad y_t = f\left(w_{yh} h_t + b_y\right) \tag{4}$$

where
$x_t$ is the input to the RNN at time t
$h_t$ is the state of the hidden layer at time t
$h_{t-1}$ is the state of the neural network at t − 1
$b_h$ is the bias of the hidden layer
$b_y$ is the bias of the output
$w_{hx}$ are the weights between the hidden and the input layer
$w_{hh}$ are the weights of the hidden layer at t − 1 and the hidden layer at t.

Fig. 2 a Architecture of an RNN. b RNN over a timestep [20, 21]
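The recurrence in (4) can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the activation f is taken to be tanh, the helper names are made up for this example, and the toy weights are arbitrary rather than learned.

```python
import math

def matvec(w, v):
    """Multiply matrix w (list of rows) by vector v."""
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]

def rnn_step(x_t, h_prev, w_hx, w_hh, b_h, w_yh, b_y, f=math.tanh):
    """One recurrent step following Eq. (4); f is assumed to be tanh here."""
    pre_h = [a + b + c for a, b, c in zip(matvec(w_hx, x_t), matvec(w_hh, h_prev), b_h)]
    h_t = [f(p) for p in pre_h]
    y_t = [f(p + b) for p, b in zip(matvec(w_yh, h_t), b_y)]
    return h_t, y_t

# Toy weights (hypothetical): 2 inputs, 2 hidden units, 1 output.
w_hx = [[0.5, -0.2], [0.1, 0.3]]
w_hh = [[0.0, 0.4], [-0.3, 0.2]]
b_h = [0.0, 0.1]
w_yh = [[1.0, -1.0]]
b_y = [0.0]

h = [0.0, 0.0]  # initial hidden state
for x in ([1.0, 0.0], [0.0, 1.0], [1.0, 1.0]):
    h, y = rnn_step(x, h, w_hx, w_hh, b_h, w_yh, b_y)
    print(h, y)
```

The hidden state `h` is the only thing carried from one timestep to the next, which is exactly the feedback loop described above.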
RNN is trained across the timesteps using Backpropagation Through Time (BPTT) [21]. However, because gradients are multiplied across timesteps, the gradient value either keeps shrinking or keeps growing, and as a result the RNN learning process encounters the problem of vanishing or exploding gradients. Because of this problem, RNNs are used in limited applications. The issue is solved by the LSTM (Long Short-Term Memory), which includes a memory cell that replaces the hidden RNN units [22]. The memory cell is composed of four gates: the forget gate (how much to keep from the previous cell state), the input gate (whether to write to the cell), the input modulation gate (how much to write to the cell), and the output gate (how much to reveal from the cell) [23]. LSTMs preserve the error that is backpropagated through time and layers. Figure 3 shows the typical LSTM cell. In Fig. 3, σ (sigmoid) represents the gate activation function, whereas ϕ (tanh) is the input or output node activation. The LSTM model [21] presented in Fig. 3 is described by (5).

$$
\begin{aligned}
g_t &= \phi\left(w_{gx} x_t + w_{gh} h_{t-1} + b_g\right) \\
i_t &= \sigma\left(w_{ix} x_t + w_{ih} h_{t-1} + b_i\right) \\
f_t &= \sigma\left(w_{fx} x_t + w_{fh} h_{t-1} + b_f\right) \\
o_t &= \sigma\left(w_{ox} x_t + w_{oh} h_{t-1} + b_o\right) \\
s_t &= g_t \odot i_t + s_{t-1} \odot f_t \\
h_t &= \phi(s_t) \odot o_t
\end{aligned}
\tag{5}
$$

Fig. 3 LSTM memory cell

where
• $w_{gx}$, $w_{ix}$, $w_{fx}$, and $w_{ox}$ are the weights at time t between the input and hidden layer
• $w_{gh}$, $w_{ih}$, $w_{fh}$, and $w_{oh}$ are the weights at time t and t − 1 between the hidden layers
• $b_g$, $b_i$, $b_f$, and $b_o$ are the biases of the gates
• $h_{t-1}$ is the value of the hidden layer at time t − 1
• $f_t$, $g_t$, $i_t$, and $o_t$ are the output values of the forget gate, input modulation gate, input gate, and output gate, respectively
• $s_t$ and $s_{t-1}$ are the cell states at time t and t − 1, respectively.
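The cell update in (5) can be written out directly. The sketch below is a minimal plain-Python illustration, not the implementation used in this chapter: the helper names, the dictionary layout of the parameters, and the toy weights are all assumptions made for the example.

```python
import math

def sigmoid(z):
    """Logistic gate activation, the sigma of Eq. (5)."""
    return 1.0 / (1.0 + math.exp(-z))

def affine(w_x, x, w_h, h, b):
    """Return w_x x + w_h h + b for weight matrices stored as lists of rows."""
    return [
        sum(wx * xv for wx, xv in zip(row_x, x))
        + sum(wh * hv for wh, hv in zip(row_h, h))
        + bias
        for row_x, row_h, bias in zip(w_x, w_h, b)
    ]

def lstm_step(x_t, h_prev, s_prev, params):
    """One step of Eq. (5); params maps 'g', 'i', 'f', 'o' to (w_x, w_h, b)."""
    def gate(name, activation):
        w_x, w_h, b = params[name]
        return [activation(z) for z in affine(w_x, x_t, w_h, h_prev, b)]

    g_t = gate("g", math.tanh)  # input modulation gate
    i_t = gate("i", sigmoid)    # input gate
    f_t = gate("f", sigmoid)    # forget gate
    o_t = gate("o", sigmoid)    # output gate
    s_t = [g * i + s * f for g, i, s, f in zip(g_t, i_t, s_prev, f_t)]
    h_t = [math.tanh(s) * o for s, o in zip(s_t, o_t)]
    return h_t, s_t

# Hypothetical toy cell: 2 inputs, 1 hidden unit, identical weights for every gate.
toy = {name: ([[0.5, -0.5]], [[0.1]], [0.0]) for name in "gifo"}
h, s = [0.0], [0.0]
for x in ([1.0, 0.0], [0.0, 1.0], [1.0, 1.0]):
    h, s = lstm_step(x, h, s, toy)
    print(h, s)
```

The cell state `s` is updated additively (old state scaled by the forget gate plus the gated new input), which is what lets the error signal survive over many timesteps.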
4 Paderborn University Dataset

In [24, 25], a benchmark dataset for bearing fault diagnosis is developed. The dataset is composed of synchronously measured currents of two phases along with vibration signals. The stator currents are measured with a current transducer, sampled at 64 kHz, and filtered with a 25 kHz low-pass filter. The dataset is composed of tests on six undamaged bearings and 26 damaged bearings. Of the 26 damaged bearings, 12 have artificially induced damage and 14 have real damage. In most of the research done on bearing fault diagnosis, only artificial bearing faults are induced and used to collect data for Machine Learning models, because such faults are easy to generate. Nevertheless, these models fail to deliver the expected diagnosis when used in practical industrial applications. In [24, 25], Machine Learning models were trained with artificially damaged bearings and tested with real damaged bearings. These models were not able to accurately classify the real damaged bearings. The reason is that it is not possible to accurately replicate real bearing damage artificially. Obtaining real damaged bearings from machinery is not easy either. Because of their long lifetimes, bearings are generally replaced before failure. Hence, it is difficult to find large quantities of real damaged bearings. In [24, 25], real damaged bearings are produced on scientific test rigs by way of accelerated lifetime tests. An advantage of this technique is that the bearing damage can be generated under reproducible conditions. The disadvantage is that the process requires a lot of time and effort. The real damage is generated by applying a very high radial force to the bearing on a test rig. This force is far greater than what the bearing can endure, which causes damage to appear much sooner. To accelerate the process further, low-viscosity oil is used, which results in improper lubrication. Although the dataset is highly valuable because it provides test data for both real and artificial bearing damage, this paper focuses only on the real bearing damage for the purpose of demonstrating the diagnosis algorithm; the approach could of course be extended to the bearings with artificial damage. The datasets of the three classes (healthy, outer race, and inner race) are described in Table 1. The damage to the bearings is of varying severity. A detailed description of these datasets and the geometry of the bearing used can be found in [24, 25]. All the damaged bearing datasets are composed of single-point damage resulting from fatigue, except dataset KA30, which is composed of distributed damage due to plastic deformation. Each dataset (e.g., K001) has 80 measurements (20 measurements for each operating condition) of 4 s each.
Different operating conditions, described in Table 2, are used so that the Deep Learning algorithm incorporates variations in speed and load and the model does not depend on particular operating conditions. The operating parameters that are varied are the speed, the radial force, and the load torque. A speed of 1500 rpm is used in three settings and 900 rpm in one setting. Similarly, a load torque of 0.7 Nm is used in three settings and 0.1 Nm in one, and a radial force of 1000 N is used in three settings and 400 N in one. All the datasets have damage smaller than 2 mm in size, except for datasets KI16, KI18, and KA16, which have damage greater than 2 mm. The temperature is kept constant at about 50 °C throughout all experiments.

Table 1 Dataset of healthy bearing and bearings with real damages

Healthy (Class 1)    Inner ring damage (Class 2)    Outer ring damage (Class 3)
K001                 KI04                           KA04
K002                 KI14                           KA15
K003                 KI16                           KA16
K004                 KI18                           KA22
K005                 KI21                           KA30

Table 2 Operating parameters of each dataset

No.    Rotational speed [rpm]    Load torque [Nm]    Radial force [N]
0      1500                      0.7                 1000
1      900                       0.7                 1000
2      1500                      0.1                 1000
3      1500                      0.7                 400
5 Data Processing and Feature Extraction

Before the Machine Learning or Deep Learning algorithms are applied, the data is preprocessed, and the important features are then extracted and used as the input to the algorithm for fault classification. It is therefore necessary to employ techniques that extract the most useful features and discard irrelevant or redundant information. Removing noisy, irrelevant, and misleading features gives a more compact representation and improves the quality of detection and diagnosis [13].

Figure 4a shows the stator current signal of phase 1 of the machine, and Fig. 4b shows its frequency spectrum. Looking closely at Fig. 4a, slight variations in the amplitude of the sine wave can be observed. These small amplitude variations could contain the bearing characteristic frequencies. Hence, they are analyzed further for feature extraction. The spectrum shows two dominant frequencies, ω0 and 5ω0. These frequencies offer no new information and are present in all the current signals (whether belonging to a healthy or a damaged bearing). Therefore, removing these frequencies from the current signal gives more focus to the frequencies that result from the bearing faults. To remove the unwanted frequencies, a signal-processing filter is designed in MATLAB that suppresses ω0 and 5ω0 in the current signal. Figure 4c shows the filtered current signal, and Fig. 4d shows its spectrum. This filtered signal is used for feature extraction. Its spectrum now contains noise and the characteristic bearing frequencies (if they are present).

Fig. 4 a Stator current, b frequency spectrum of stator current, c stator current after filtration of ω0 and 5ω0 component, d the frequency spectrum of the filtered stator current

Huo et al. [26] evaluate the statistical features of the vibration signals from bearings and conclude that the two features most sensitive to bearing faults are kurtosis and impulse factor. This conclusion applies equally well to the current signals, because these two statistical features adequately capture the impulsive behavior of faulty bearings. Another useful time domain feature is the clearance factor, which is also sensitive to bearing faults.

The remaining features are extracted from the time–frequency domain using third-level WPD (Wavelet Packet Decomposition). WPD is a well-known signal-processing technique that is widely used in fault diagnosis. It provides useful information in both the time and frequency domains. WPD decomposes the time domain signal into wavelets of various scales with variable-sized windows and reveals the local structure in the time–frequency domain. The advantages of WPD over other signal-processing techniques are that transient features can be extracted effectively and that features can be extracted from the full spectrum without requiring a specific frequency band. WPD uses mother wavelets, which are basic wavelet functions, to expand, compress, and translate the signal by varying the scale frequency and the time shift of the wavelet.
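The suppression of ω0 and 5ω0 described earlier in this section can be illustrated with a pair of notch filters. The chapter does not specify the design of the MATLAB filter, so the sketch below substitutes a standard second-order IIR notch (biquad) per frequency; the function names, the quality factor, and the sampling rate and supply frequency of the toy signal are all assumptions made for illustration.

```python
import math

def notch_coefficients(f_notch, fs, q):
    """Second-order IIR notch (biquad) coefficients, normalised so that a[0] == 1."""
    w0 = 2.0 * math.pi * f_notch / fs
    alpha = math.sin(w0) / (2.0 * q)
    a0 = 1.0 + alpha
    b = [1.0 / a0, -2.0 * math.cos(w0) / a0, 1.0 / a0]
    a = [1.0, -2.0 * math.cos(w0) / a0, (1.0 - alpha) / a0]
    return b, a

def apply_biquad(b, a, x):
    """Filter the sequence x with the difference equation of one biquad section."""
    y = []
    x1 = x2 = y1 = y2 = 0.0
    for xn in x:
        yn = b[0] * xn + b[1] * x1 + b[2] * x2 - a[1] * y1 - a[2] * y2
        x2, x1 = x1, xn
        y2, y1 = y1, yn
        y.append(yn)
    return y

def remove_supply_harmonics(x, f0, fs, q=5.0, harmonics=(1, 5)):
    """Cascade one notch per harmonic of f0 (here the fundamental and the fifth)."""
    for h in harmonics:
        b, a = notch_coefficients(h * f0, fs, q)
        x = apply_biquad(b, a, x)
    return x

# Toy signal (hypothetical values): strong 50 Hz and 250 Hz components plus a
# weak 120 Hz component standing in for a fault-related frequency.
fs, f0 = 2000.0, 50.0
signal = [
    math.sin(2 * math.pi * f0 * n / fs)
    + 0.2 * math.sin(2 * math.pi * 5 * f0 * n / fs)
    + 0.01 * math.sin(2 * math.pi * 120.0 * n / fs)
    for n in range(4000)
]
filtered = remove_supply_harmonics(signal, f0, fs)
print(max(abs(v) for v in filtered[-500:]))  # residual amplitude after the transient
```

A narrow notch (high quality factor) removes less of the neighbouring spectrum but has a longer start-up transient, so the first part of the filtered record is best discarded before features are computed.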
This enables short windows to be applied at low scale (high frequency) and long windows at high scale (low frequency). For example, short windows at low scale give higher resolution in time for high-frequency components, and long windows at high scale give higher resolution in frequency for low-frequency components. The wavelet function must meet the admissibility requirement in (6).

$$C_\psi = \int_{\mathbb{R}} \frac{|\hat{\psi}(\omega)|^2}{|\omega|}\, d\omega < \infty \tag{6}$$

where $\hat{\psi}(\omega)$ is the Fourier transform of the mother wavelet $\psi(t)$. The Continuous Wavelet Transform (CWT) is described by (7).

$$\psi_{a,b}(t) = \frac{1}{\sqrt{|a|}}\, \psi\!\left(\frac{t - b}{a}\right) \tag{7}$$

where $\psi_{a,b}(t)$ is the continuous mother wavelet scaled by the factor a and translated by the factor b, and $1/\sqrt{|a|}$ is used for energy preservation. $\psi_{a,b}(t)$ acts as a window function whose frequency and time location can be adjusted through a and b. For example, smaller values of a compress the wavelet and shift it toward higher frequencies; this helps when extracting the higher frequency components of the signal. The scaling and translation parameters make wavelet analysis ideal for non-stationary and non-linear signals in both the time and frequency domains. a and b are continuous and can take any value; when both a and b are discretized, the result is the DWT (Discrete Wavelet Transform).

Although the DWT is an excellent method for time–frequency analysis of a signal, it decomposes only the low-frequency part further and neglects the high-frequency part, resulting in very poor resolution at high frequencies. Therefore, in our case Wavelet Packet Decomposition (WPD), which is an extension of the DWT, is used for the time–frequency analysis of the signal. The difference from the DWT is that in WPD the detail signals (high-frequency part) are also decomposed further into two signals, a detail signal and an approximation signal (low-frequency part). Hence, using WPD, a full-scale analysis in the time–frequency domain can be done. The scaling function $\phi(t)$ and the wavelet function $\psi(t)$ of WPD are described by (8) and (9), respectively.

$$\phi(t) = \sqrt{2} \sum_{k} h(k)\, \phi(2t - k) \tag{8}$$

$$\psi(t) = \sqrt{2} \sum_{k} g(k)\, \phi(2t - k) \tag{9}$$

where h(k) and g(k) are low- and high-pass filters, respectively. The WPD from level j to j + 1 at node n of a signal s(j, n) is given by (10) and (11). s(j, n) could be either the cAj approximation (low frequency) or the cDj detail (high frequency) coefficients.

$$s_{j+1,2n}(k) = \sum_{m} h(m - 2k)\, s_{j,n}(m) \tag{10}$$

$$s_{j+1,2n+1}(k) = \sum_{m} g(m - 2k)\, s_{j,n}(m) \tag{11}$$

where m is the wavelet coefficient number and k indexes the coefficients at level j + 1. The process of the jth-level WPD is shown in Fig. 5.

Fig. 5 Schematic of WPD (Wavelet Packet Decomposition) process
With the help of WPD, the signal is decomposed into different frequency bands [27]. When damage occurs in a bearing, the effect creates resonance at different frequencies, causing energy to be distributed over different frequency bands depending on the bearing fault. The energy E(j, n) of the WPD coefficients s(j, n) of the jth level and nth node can be computed as shown in (12), where the sum runs over the coefficients of the node.

$$E(j, n) = \sum_{i} s_i(j, n)^2 \tag{12}$$
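The recursion in (10) and (11) and the node energy in (12) can be sketched in plain Python. For brevity this example swaps in the two-tap Haar filters rather than the db6 wavelet used in this chapter, and it uses no boundary extension, so it assumes the signal length is divisible by 2 raised to the decomposition level. All function names are made up for the illustration.

```python
import math

H = [1 / math.sqrt(2), 1 / math.sqrt(2)]   # Haar low-pass filter h(k)
G = [1 / math.sqrt(2), -1 / math.sqrt(2)]  # Haar high-pass filter g(k)

def split_node(s, filt):
    """Eq. (10)/(11) for a two-tap filter: out(k) = sum_m filt(m - 2k) * s(m)."""
    return [
        sum(filt[m - 2 * k] * s[m] for m in range(2 * k, 2 * k + 2))
        for k in range(len(s) // 2)
    ]

def wpd(signal, level):
    """Return the 2**level nodes at the given level; node n splits into 2n and 2n + 1."""
    nodes = [list(signal)]
    for _ in range(level):
        next_nodes = []
        for node in nodes:
            next_nodes.append(split_node(node, H))  # s_{j+1,2n}: approximation
            next_nodes.append(split_node(node, G))  # s_{j+1,2n+1}: detail
        nodes = next_nodes
    return nodes

def node_energy(coeffs):
    """Energy of one node, Eq. (12)."""
    return sum(c * c for c in coeffs)

# Toy signal of length 64 (hypothetical), decomposed to the third level.
x = [math.sin(0.3 * n) + 0.05 * math.sin(2.5 * n) for n in range(64)]
energies = [node_energy(node) for node in wpd(x, 3)]
print(energies)
```

Because the Haar filters are orthonormal, the eight node energies add up to the energy of the original signal, so each energy feature can be read as the share of signal energy that falls in one band.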
To perform WPD, a mother wavelet must be selected; the choice of wavelet depends on the application, as no wavelet is the absolute best. The wavelet that works best is the one whose properties are most similar to the signal being analyzed. [28] demonstrates the training of bearing fault data with different wavelets and concludes that the mother wavelet that adequately captures the bearing fault signatures is the Daubechies 6 (db6) wavelet. The db6 wavelet also works best in our case and is therefore used for extracting the time–frequency domain features in this paper. In [25], a detailed feature selection method based on maximum separation distance is discussed. Using this method, the relevant features are selected, as not all the features are useful for fault classification. Figure 6 shows the third-level WPD of the signal x. From this decomposition, the $cA_{3,0}$, $cA_{3,1}$, $cA_{3,2}$, and $cD_{3,2}$ coefficients are used, and their energies are calculated. Table 3 shows the detailed list of the eight features that are selected and used in the diagnosis algorithm. These eight features from the filtered current signal describe the bearing fault signatures quite well. Starting from the datasets presented in Table 1, the following steps are used to prepare the data for training and testing of the LSTM network.
Fig. 6 Coefficients of the third-level WPD (Wavelet packet decomposition)
Table 3 Description of the features

Feature no.    Description
1    Kurtosis: $\mathrm{Kurt} = \dfrac{\sum_{i=1}^{n} (x_i - \bar{x})^4}{(n-1)\,\sigma^4}$
2    Impulse factor: $\mathrm{IF} = \dfrac{\max(|x_i|)}{\frac{1}{n}\sum_{i=1}^{n} |x_i|}$
3    Clearance factor: $\mathrm{Clf} = \dfrac{\max(|x_i|)}{\left(\frac{1}{n}\sum_{i=1}^{n} \sqrt{|x_i|}\right)^2}$
4    RMS of third-level WPD approximation coefficient 3, 0: $\mathrm{RMS} = \sqrt{\frac{1}{n_1}\sum_{i=1}^{n_1} cA_{3,0,i}^2}$
5    Energy of third-level WPD approximation coefficient 3, 0: $E_{A3,0} = \sum_{i=1}^{n_1} cA_{3,0,i}^2$
6    Energy of third-level WPD approximation coefficient 3, 1: $E_{A3,1} = \sum_{i=1}^{n_2} cA_{3,1,i}^2$
7    Energy of third-level WPD approximation coefficient 3, 2: $E_{A3,2} = \sum_{i=1}^{n_3} cA_{3,2,i}^2$
8    Energy of third-level WPD detail coefficient 3, 2: $E_{D3,2} = \sum_{i=1}^{n_4} cD_{3,2,i}^2$
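The three time domain features and the RMS from Table 3 are simple enough to compute by hand. The following Python sketch is one possible reading of the formulas; it assumes that σ is the sample standard deviation (normalised by n − 1), which the table does not state explicitly, and the function names are chosen only for this example.

```python
import math

def kurtosis(x):
    """Feature 1 of Table 3, with sigma taken as the sample standard deviation (assumed)."""
    n = len(x)
    mean = sum(x) / n
    variance = sum((v - mean) ** 2 for v in x) / (n - 1)
    return sum((v - mean) ** 4 for v in x) / ((n - 1) * variance ** 2)

def impulse_factor(x):
    """Feature 2: peak magnitude divided by the mean magnitude."""
    return max(abs(v) for v in x) / (sum(abs(v) for v in x) / len(x))

def clearance_factor(x):
    """Feature 3: peak magnitude divided by the squared mean of root magnitudes."""
    return max(abs(v) for v in x) / (sum(math.sqrt(abs(v)) for v in x) / len(x)) ** 2

def rms(x):
    """Root mean square, as used for feature 4 on the WPD coefficients."""
    return math.sqrt(sum(v * v for v in x) / len(x))

# A flat signal versus one with a single spike (both hypothetical).
flat = [1.0, -1.0] * 5
spiky = [0.1] * 9 + [1.0]
print(impulse_factor(flat), impulse_factor(spiky))
print(clearance_factor(flat), clearance_factor(spiky))
```

The spiky signal scores higher on both impulse factor and clearance factor, which is the behaviour that makes these features sensitive to the impacts produced by a raceway defect.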
1. Step I: The stator current of phase 1 is taken from each 4 s measurement (resulting in 1200 signals from all the datasets: 400 signals each for the healthy, inner race fault, and outer race fault classes). The signals are filtered to remove the ω0 and 5ω0 components. Then, the features presented in Table 3 are extracted.
2. Step II: All the features extracted in Step I are scaled between 0 and 1 for better convergence of the algorithm.
3. Step III: The feature data of each class is shuffled within the class itself, so that the training and test data incorporate all the different operating points and conditions, as each dataset has different fault severities.
4. Step IV: 20% of the data points (i.e., 80 points) from each class are kept for testing, and 80% of the data points (i.e., 320 points) from each class are used for training. Therefore, in total, there are 960 data points in the training set and 240 points in the testing set.
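Steps II to IV above can be sketched as follows. This is only an illustration under stated assumptions: the feature values are random placeholders rather than real measurements, the function names are invented for the example, and the scaling is applied to all data before the split, as described in Step II.

```python
import random

def min_max_scale(rows):
    """Scale each feature column to the range [0, 1] (Step II)."""
    columns = list(zip(*rows))
    lows = [min(col) for col in columns]
    highs = [max(col) for col in columns]
    return [
        [(v - lo) / (hi - lo) if hi > lo else 0.0 for v, lo, hi in zip(row, lows, highs)]
        for row in rows
    ]

def split_per_class(rows, labels, test_fraction=0.2, seed=0):
    """Shuffle within each class (Step III), then hold out a test share (Step IV)."""
    rng = random.Random(seed)
    train, test = [], []
    for cls in sorted(set(labels)):
        members = [(row, cls) for row, label in zip(rows, labels) if label == cls]
        rng.shuffle(members)
        n_test = round(test_fraction * len(members))
        test.extend(members[:n_test])
        train.extend(members[n_test:])
    return train, test

# Placeholder data: 1200 signals x 8 features, 400 signals per class.
rng = random.Random(1)
features = [[rng.random() * 10 for _ in range(8)] for _ in range(1200)]
labels = [cls for cls in (1, 2, 3) for _ in range(400)]

scaled = min_max_scale(features)
train, test = split_per_class(scaled, labels)
print(len(train), len(test))
```

Splitting per class keeps the three classes balanced in both sets, so the test accuracy is not dominated by any single class.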
6 LSTM Network Training and Results

The LSTM network architecture is described in Table 4. The LSTM network has eight input nodes that correspond to the eight selected features. Adding more features does not change the accuracy much, but reducing the number of features causes the accuracy to fall. The input nodes are followed by four hidden layers with 32 hidden nodes in each layer. There is no absolute rule for choosing the number of layers and hidden nodes; four layers and 32 nodes were chosen because decreasing the number of layers and nodes reduces the accuracy, while increasing it does not affect the accuracy to any great extent. After each LSTM hidden layer, a 50% dropout is added to prevent the network from memorizing or overfitting the training data. Finally, a dense layer with softmax activation, composed of three output nodes (corresponding to the three output classes: healthy, inner ring damage, and outer ring damage), is added. The Binary Cross-entropy function is used as the loss function, and the ADAM optimizer [29] is used to train the parameters of the LSTM network. In the code implementation, training is done on the computer's GPU to accelerate the training process.

Table 4 LSTM architecture description

Layers                                 Nodes
Input layer                            8 input nodes
LSTM layer 1                           32 hidden nodes
50% dropout
LSTM layer 2                           32 hidden nodes
50% dropout
LSTM layer 3                           32 hidden nodes
50% dropout
LSTM layer 4                           32 hidden nodes
50% dropout
Dense layer with softmax activation    3 output nodes

After the features have been extracted from the datasets and preprocessed, the feature vector is input to the LSTM network for training. The training parameters are listed in Table 5. Using these parameters, the network is trained for 2500 epochs with a batch size of 64.

Table 5 LSTM training parameter details

Parameters          Value
Training samples    960
Testing samples     240
Batch size          64
Epochs              2500
Dropout ratio       0.5

After training, the LSTM network achieved an accuracy of 97% on the testing data and 100% on the training data. Figure 7 shows the confusion matrix of the testing results of the LSTM model, with rows representing the predicted class and columns the true class. False positives are displayed in the left column, and the upper row displays the false negatives. All of the normal or healthy bearing points are correctly classified. However, some points of the inner race fault are misclassified as outer race fault, and some points of the outer race fault are misclassified as inner race fault. Nevertheless, the LSTM network with the proposed methodology provides excellent results in diagnosing the bearing faults, classifying healthy, inner race, and outer race conditions to a high degree of accuracy.
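To give a sense of the size of the network in Table 4, its trainable parameters can be counted from the layer sizes. The sketch below assumes the standard LSTM parameterisation implied by (5), i.e., four gates that each have input weights, recurrent weights, and a bias, and a fully connected output layer; the helper names are invented for this example, and the count is a back-of-the-envelope figure rather than one reported in the chapter.

```python
import math

def lstm_params(n_in, n_hidden):
    """Four gates, each with input weights, recurrent weights, and a bias (Eq. 5)."""
    return 4 * (n_in * n_hidden + n_hidden * n_hidden + n_hidden)

def dense_params(n_in, n_out):
    """Fully connected layer: one weight per connection plus one bias per output."""
    return n_in * n_out + n_out

# (inputs, hidden nodes) for LSTM layers 1 to 4 in Table 4.
lstm_layers = [(8, 32), (32, 32), (32, 32), (32, 32)]
total = sum(lstm_params(n_in, n_h) for n_in, n_h in lstm_layers) + dense_params(32, 3)

# Weight updates implied by Table 5: 960 samples, batch size 64, 2500 epochs.
updates = math.ceil(960 / 64) * 2500

print(total, updates)
```

Dropout layers add no trainable parameters, so they do not appear in the count.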
Diagnosis of Bearing Faults in Electrical Machines …
Fig. 7 Confusion matrix of the results of LSTM testing for class 1 (normal), class 2 (inner race damaged), and class 3 (outer race damaged)
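To make the row and column convention of the confusion matrix concrete, here is a small sketch that computes accuracy, precision, and recall from a matrix whose rows are predicted classes and whose columns are true classes. The counts are hypothetical and chosen only for illustration; they are not the values shown in Fig. 7, and the equal class sizes are also an assumption.

```python
# HYPOTHETICAL counts for illustration only (not the values in Fig. 7).
# Rows = predicted class, columns = true class.
labels = ["normal", "inner race", "outer race"]
cm = [
    [80, 0, 0],   # predicted normal
    [0, 76, 3],   # predicted inner race
    [0, 4, 77],   # predicted outer race
]

def accuracy(cm):
    correct = sum(cm[i][i] for i in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total

def precision(cm, k):
    # Row k holds every sample predicted as class k.
    return cm[k][k] / sum(cm[k])

def recall(cm, k):
    # Column k holds every sample whose true class is k.
    return cm[k][k] / sum(row[k] for row in cm)

for k, name in enumerate(labels):
    print(name, round(precision(cm, k), 3), round(recall(cm, k), 3))
print("accuracy", round(accuracy(cm), 3))
```

With this convention, off-diagonal entries in the first column are healthy bearings flagged as faulty, and off-diagonal entries in the first row are faulty bearings classified as healthy.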
7 Comparison with Traditional Methods

In the previous section, the LSTM network was used to diagnose bearing inner and outer race faults, and a testing accuracy of 97% was achieved, which is higher than the testing accuracy of 93.3% achieved by the ensemble of Machine Learning algorithms in [25]. The Machine Learning methods used in [25], however, were not exposed to bearing measurements of different damage severity, which led to the lower accuracy in that analysis. In conventional Machine Learning techniques, the features are extracted manually from the data so that the data patterns are clearer to the Machine Learning algorithm. In Deep Learning approaches, the algorithm is capable of automatically extracting high-level features from the data. Feeding the data directly to the Deep Learning algorithm works in the majority of cases. However, in the case of bearing fault diagnosis, feeding the raw MCS directly to the proposed LSTM network resulted in very poor classification accuracy. The reason is that the fault characteristic components are comparable to or lower in magnitude than the signal noise. As a result, the LSTM network is unable to extract these characteristic frequency features on its own. Therefore, the features are extracted manually from the MCS, and signal-processing techniques such as the wavelet transform are required for this task. Because eight features are extracted from the filtered current signal, traditional Machine Learning methods can also be applied and compared with the proposed LSTM method. Table 6 shows this comparison. For this test, all the datasets were combined and shuffled, and five-fold cross-validation was performed. Most of the Machine Learning methods showed poor performance; only the MLP (Multilayer Perceptron), kNN, and CART methods showed promising results.

Table 6 Comparison of traditional methods with the proposed LSTM method using five-fold cross-validation
Algorithm | Classification accuracy (%)
Multilayer perceptron (layers 150, 100, 50 with ADAM solver and ReLU activation) | 95.4
SVM (Support Vector Machine) | 66.1
k nearest neighbor | 91.2
Linear regression | 53.0
Linear discriminant analysis | 56.8
CART (Classification And Regression Trees) | 93.6
Gaussian Naive Bayes | 46.0
LSTM (Proposed method) | 97.0

The better performance of the LSTM method is attributed to its memory: the eight features form a sequence, and the algorithm learns the patterns in that sequence. Hence, the proposed LSTM methodology outperforms all the traditional methods tested, which supports the advantage of Deep Learning techniques over the traditional methods for this task. Although the proposed LSTM methodology outperformed the traditional Machine Learning methods, one disadvantage of Deep Learning algorithms is their high computational requirements. The proposed LSTM network was trained on a GPU to accelerate training, whereas the traditional Machine Learning algorithms needed only ordinary CPU processing and less time. Therefore, Deep Learning algorithms should be preferred when sufficient computational resources are available. The results from the LSTM model are quite promising; however, the dataset is not large, and the trained model may not reach the same performance when it is tested on data points from a new environment, because background and noise conditions vary randomly between environments. Hence, more research is needed so that Deep Learning models can adapt to a new environment without sacrificing performance, and a broader analysis with a larger dataset would be helpful.
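The five-fold cross-validation protocol (combine, shuffle, split into five folds, test on each fold in turn) can be sketched in a few lines. This is an illustrative sketch only: the total of 1200 samples is an assumption taken from the 960 + 240 split in Table 5, and the function name and seed are hypothetical.

```python
import random

def k_fold_indices(n_samples, k, seed=0):
    """Yield (train, test) index lists for k-fold cross-validation."""
    idx = list(range(n_samples))
    random.Random(seed).shuffle(idx)          # shuffle the combined dataset
    folds = [idx[i::k] for i in range(k)]     # k disjoint folds
    for i in range(k):
        test = folds[i]
        train = [j for m, fold in enumerate(folds) if m != i for j in fold]
        yield train, test

# Assumed dataset size: 960 training + 240 testing samples (Table 5).
splits = list(k_fold_indices(1200, 5))
for fold_no, (train, test) in enumerate(splits, start=1):
    print(fold_no, len(train), len(test))
```

Each sample is used for testing exactly once, and the reported accuracy of a method is the average over the five folds.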
8 Conclusion and Future Work

This paper focused on the diagnosis of bearing inner and outer race faults. Conventional bearing fault diagnosis uses the vibration signal from the machine's accelerometer. In this paper, the MCS is used instead, and its potential for effective bearing fault diagnosis has been demonstrated. With this methodology, the need for an additional sensor for monitoring the vibration signals is eliminated, reducing the cost of the fault monitoring system. The Paderborn University dataset of real damaged bearings was considered. From the datasets, the stator current of one phase was used and processed by removing the redundant frequencies ω0 and 5ω0. Eight features were then extracted from this filtered signal: three from the time domain and five from the time–frequency domain using third-level WPD. These eight features were scaled and fed to the Deep Learning LSTM network with four hidden layers. The LSTM network achieved a classification accuracy of 97%, even though the dataset contains data recorded at different operating conditions, i.e., different speeds and loads, which indicates that the method is independent of the machine operating conditions. Finally, the proposed LSTM methodology was compared with traditional Machine Learning methods using five-fold cross-validation, and it outperformed them, with an accuracy more than 1.5 percentage points higher than that of the best-performing traditional algorithm. Hence, it has been shown that fault diagnosis with MCS using the proposed LSTM algorithm can give similar, if not better, performance than diagnosis with vibration signals. For future work, the diagnosis of other bearing faults, e.g., ball faults and cage faults, using the stator current will be considered. Secondly, other Deep Learning methods listed in [30], such as CNN-LSTM networks, algorithms that incorporate the randomness of different working conditions, and algorithms that can adapt to different operating environments without compromising performance, will be explored. Furthermore, more research will be done on increasing the classification performance of the network by considering denoising LSTMs, regularization methods, and a larger training dataset.
References

1. R. Sabir, S. Hartmann, C. Gühmann, Open and short circuit fault detection in alternators using the rectified DC output voltage, in 2018 IEEE 4th Southern Power Electronics Conference (SPEC) (Singapore, 2018), pp. 1–7
2. R. Sabir, D. Rosato, S. Hartmann, C. Gühmann, Detection and localization of electrical faults in a three phase synchronous generator with rectifier, in 19th International Conference on Electrical Drives & Power Electronics (EDPE 2019) (Slovakia, 2019)
3. Common causes of bearing failure | applied. Applied (2019). https://www.applied.com/bearingfailure
4. I.Y. Onel, M.E.H. Benbouzid, Induction motors bearing failures detection and diagnosis: park and concordia transform approaches comparative study, in 2007 IEEE International Electric Machines & Drives Conference (Antalya, 2007), pp. 1073–1078
5. R.R. Schoen, T.G. Habetler, F. Kamran, R.G. Bartfield, Motor bearing damage detection using stator current monitoring. IEEE Trans. Ind. Appl. 31(6), 1274–1279 (1995). https://doi.org/10.1109/28.475697
6. H. Pan, X. He, S. Tang, F. Meng, An improved bearing fault diagnosis method using one-dimensional CNN and LSTM. J. Mech. Eng. 64(7–8), 443–452 (2018)
7. X. Guo, C. Shen, L. Chen, Deep fault recognizer: an integrated model to denoise and extract features for fault diagnosis in rotating machinery. Appl. Sci. 7(41), 1–17 (2017)
8. H. Shao, H. Jiang, Y. Lin, X. Li, A novel method for intelligent fault diagnosis of rolling bearings using ensemble deep autoencoders. Knowl.-Based Syst. 119, 200–220 (2018)
9. D. Filbert, C. Guehmann, Fault diagnosis on bearings of electric motors by estimating the current spectrum. IFAC Proc. 27(5), 689–694 (1994)
10. S. Yeolekar, G.N. Mulay, J.B. Helonde, Outer race bearing fault identification of induction motor based on stator current signature by wavelet transform, in 2017 2nd IEEE International Conference on Recent Trends in Electronics, Information & Communication Technology (RTEICT) (Bangalore, 2017), pp. 2011–2015
11. F. Ben Abid, A. Braham, Advanced signal processing techniques for bearing fault detection in induction motors, in 2018 15th International Multi-Conference on Systems, Signals & Devices (SSD) (Hammamet, 2018), pp. 882–887
12. A. Bellini, F. Immovilli, R. Rubini, C. Tassoni, Diagnosis of bearing faults of induction machines by vibration or current signals: a critical comparison, in 2008 IEEE Industry Applications Society Annual Meeting (Edmonton, AB, 2008), pp. 1–8
13. A. Soualhi, G. Clerc, H. Razik, Detection and diagnosis of faults in induction motor using an improved artificial ant clustering technique. IEEE Trans. Ind. Electron. 60(9), 4053–4062 (2013)
14. S. Gunasekaran, S.E. Pandarakone, K. Asano, Y. Mizuno, H. Nakamura, Condition monitoring and diagnosis of outer raceway bearing fault using support vector machine, in 2018 Condition Monitoring and Diagnosis (CMD) (Perth, WA, 2018), pp. 1–6
15. I. Andrijauskas, R. Adaskevicius, SVM based bearing fault diagnosis in induction motors using frequency spectrum features of stator current, in 2018 23rd International Conference on Methods & Models in Automation & Robotics (MMAR) (Miedzyzdroje, 2018), pp. 826–831
16. S.E. Pandarakone, M. Masuko, Y. Mizuno, H. Nakamura, Deep neural network based bearing fault diagnosis of induction motor using fast fourier transform analysis, in 2018 IEEE Energy Conversion Congress and Exposition (ECCE) (Portland, OR, 2018), pp. 3214–3221
17. J.S. Lal Senanayaka, H. Van Khang, K.G. Robbersmyr, Autoencoders and data fusion based hybrid health indicator for detecting bearing and stator winding faults in electric motors, in 2018 21st International Conference on Electrical Machines and Systems (ICEMS) (Jeju, 2018), pp. 531–536
18. I. Kao, W. Wang, Y. Lai, J. Perng, Analysis of permanent magnet synchronous motor fault diagnosis based on learning. IEEE Trans. Instrum. Meas. 68(2), 310–324 (2019)
19. A beginner's guide to LSTMs and recurrent neural networks (Skymind, 2019). https://skymind.ai/wiki/lstm
20. S. Zhang, S. Zhang, B. Wang, T.G. Habetler, Machine learning and deep learning algorithms for bearing fault diagnostics-a comprehensive review (2019). arXiv preprint arXiv:1901.08247
21. Z.C. Lipton, J. Berkowitz, C. Elkan, A critical review of recurrent neural networks for sequence learning (2015). arXiv preprint arXiv:1506.00019
22. S. Hochreiter, J. Schmidhuber, Long short-term memory. Neural Comput. 9, 1735–1780 (1997) (source: Stanford CS231N)
23. F. Immovilli, A. Bellini, R. Rubini et al., Diagnosis of bearing faults of induction machines by vibration or current signals: a critical comparison. IEEE Trans. Ind. Appl. 46(4), 1350–1359 (2010)
24. Konstruktions-und Antriebstechnik (KAt)—Data Sets and Download (Universität Paderborn), Mb.uni-paderborn.de (2019). https://mb.uni-paderborn.de/kat/forschung/datacenter/bearing-datacenter/data-sets-and-download/
25. C. Lessmeier, J.K. Kimotho, D. Zimmer, W. Sextro, Condition monitoring of bearing damage in electromechanical drive systems by using motor current signals of electric motors: a benchmark data set for data-driven classification, in Proceedings of the European Conference of the Prognostics and Health Management Society (2016), pp. 05–08
26. Z. Huo, Y. Zhang, P. Francq, L. Shu, J. Huang, Incipient fault diagnosis of roller bearing using optimized wavelet transform based multi-speed vibration signatures. IEEE Access 5, 19442–19456 (2017)
27. X. Wang, Z. Lu, J. Wei, Y. Zhang, Fault diagnosis for rail vehicle axle-box bearings based on energy feature reconstruction and composite multiscale permutation entropy. Entropy 21(9), 865 (2019)
28. S. Djaballah, K. Meftah, K. Khelil, M. Tedjini, L. Sedira, Detection and diagnosis of fault bearing using wavelet packet transform and neural network. Frattura ed Integrità Strutturale 13(49), 291–301 (2019)
29. D.P. Kingma, J. Ba, Adam: a method for stochastic optimization (2014). arXiv preprint arXiv:1412.6980
30. M.A. Wani, F.A. Bhat, S. Afzal, A.L. Khan, Advances in Deep Learning (Springer, 2020)
Automatic Solar Panel Detection from High-Resolution Orthoimagery Using Deep Learning Segmentation Networks

Tahir Mujtaba and M. Arif Wani
Abstract

Solar panel detection from aerial or satellite imagery is a convenient and economical technique for counting the solar panels on the rooftops in a region or city and for estimating the solar potential of the installed panels. Detecting the accurate shapes and sizes of solar panels is a prerequisite for successfully estimating the capacity and energy generation of solar panels over a region or a city. Such an approach helps the government build policies to integrate solar panels installed at homes, offices, and other buildings with the electric grid. This study explores the use of various deep learning segmentation algorithms for automatic solar panel detection from high-resolution ortho-rectified RGB imagery with a resolution of 0.3 m. We compare and evaluate the performance of six deep learning segmentation networks in the automatic detection of distributed solar panel arrays from satellite imagery. The networks are tested on real data and augmented data. Results indicate that deep learning segmentation networks work well for automatic solar panel detection from high-resolution orthoimagery.
T. Mujtaba (B) · M. A. Wani
Department of Computer Science, University of Kashmir, Srinagar, India
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021. M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_5

1 Introduction

The world is turning to renewable energy sources as non-renewable energy sources are being depleted. One of the most important and abundantly available renewable energy sources is solar energy. To utilize solar energy, solar panels are installed on the ground and on the rooftops of buildings to convert solar energy into electric energy. Throughout the world, there has been a surge in installing solar panels to get the most out of this energy source. One of the challenges is to detect solar panels installed on the ground and on buildings from aerial imagery. Accurate detection of solar panels with accurate shapes and sizes is a prerequisite for determining the capacity of energy generation from these panels. For this problem, semantic segmentation is required to accurately detect the solar panels in the aerial imagery. For the past several years, machine learning approaches have been used for many applications such as classification, clustering, and image recognition [1–7], and some research has been done to analyze and assess objects such as roads, buildings, and vehicles present in satellite imagery. In the recent past, deep learning has been found to be more effective in image recognition problems than traditional machine learning techniques [8]. Image recognition problems include tasks such as image classification, object detection, and semantic segmentation, and deep learning has outperformed the traditional machine learning techniques in these tasks. Among them, semantic segmentation is one of the key research areas. Semantic segmentation is the process of assigning each pixel of the image to one of a set of predefined classes, thereby dividing an image into a set of regions where each region consists of pixels belonging to one common class. A number of traditional machine learning techniques have been used for image segmentation in the past. The general procedure of these techniques is to use well-established feature descriptors to extract features from an image and to apply a classifier, such as a support vector machine or a random forest, to each pixel to determine its likelihood of belonging to one of the predefined classes. Such techniques depend heavily on the procedure used to extract features and usually need skilled feature engineering to design such procedures. It is necessary to determine which features are important for an image. When the number of classes increases, feature extraction becomes a tedious task.
Moreover, such techniques are not end-to-end: feature extraction and classification are done separately. In comparison, deep learning models are end-to-end models in which both feature extraction and classification are done by a single algorithm. Further, the model parameters are learnt automatically during training. For these reasons, this work uses the deep learning approach for the problem of solar panel detection. Deep learning segmentation began with the conversion of a deep learning image classification model into a segmentation model [9]. The segmentation process was improved with the introduction of the UNet [10] model. The model essentially consists of an encoder and a decoder, with skip connections from the lower layers of the encoder to the higher layers of the decoder to transfer the learned features. The encoder extracts features, which are up-sampled in the decoder, and the features in the decoder are concatenated with the features extracted in the encoder. Up-sampling is done using transposed convolutions. An index-based up-sampling technique has been introduced in [11]. The technique remembers the indices of the maximum values during the max-pooling operation and then uses the same indices during up-sampling. Dilated convolution, which uses dilated filters, has been tested in [12–14]. Dilated convolution helps reduce the number of learnable parameters and also helps aggregate the context of the image. The aim of this work is to analyze and compare the performance of various segmentation architectures for the problem of solar panel detection from aerial imagery. This chapter is organized as follows. Section 2 presents an overview of the related work on solar panel detection. Section 3 describes deep learning-based segmentation. Section 4 discusses the deep learning segmentation networks used in this study. Section 5 presents results and discussion. The conclusion is presented in Sect. 6.
2 Related Work

With the availability of high-resolution satellite images and the emergence of deep learning networks, it has become possible to analyze high-resolution images for accurate object detection. The detection of objects such as buildings, roads, and vehicles from aerial or satellite images has been studied more widely in recent years than the detection of solar panels. A few studies [15–17] in the literature explore the detection of solar panels from satellite imagery or orthoimagery using deep learning techniques. VggNet with pretraining has been used in [16] to detect solar panels in the aerial imagery dataset given in [18]. The authors used a very basic VggNet architecture consisting of six convolutional layers and two fully connected layers. The model is applied to every pixel to decide whether or not it belongs to a solar panel. Further processing is performed to connect contiguous pixels and declare these as regions. Although this model has been found effective in detecting solar panels, it does not give the shape and size of the panel arrays explicitly. The authors of [17] proposed a fully convolutional neural network for solar panel detection. It uses seven convolutional layers with different numbers of filters and filter sizes. It first extracts features from the image using several convolution and max-pooling layers; it then up-samples the extracted features and uses skip connections from the shallow layers, which contain fine-grained features, concatenating them with the coarse layers. Although the work has been able to detect solar panels, it does not report the standard metric used for segmentation. A deep learning segmentation architecture called SegNet has been used in [15] for automatic detection of solar panels from the ortho-rectified images given in [18]. It uses very small image patches of size 41 × 41 for training. These small patches may not contain much of the diverse background, such as buildings, roads, trees, and vehicles, which may restrict the algorithm's ability to perform in diverse backgrounds. Further, each image in the dataset, which is of size 5000 × 5000, would generate approximately 14,872 image patches of size 41 × 41 on which training and testing are to be done. Overall, very little work has been done on solar panel detection using deep learning-based segmentation.
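The patch count quoted above follows from simple area arithmetic. The short check below is a sketch of that calculation; the distinction between the area ratio and the number of whole, non-overlapping patches is added here for illustration and is not discussed in the chapter.

```python
image_side = 5000   # pixels per side of one dataset image
patch_side = 41     # pixels per side of one training patch

# Ratio of image area to patch area (the basis of the quoted figure).
area_ratio = image_side**2 / patch_side**2

# Whole, non-overlapping patches that actually fit on a regular grid.
whole_per_side = image_side // patch_side
whole_patches = whole_per_side**2
leftover = image_side - whole_per_side * patch_side  # unused pixels per side

print(round(area_ratio), whole_per_side, whole_patches, leftover)
```

The area ratio rounds to 14,872, matching the text, while a strict non-overlapping grid yields 121 × 121 whole patches with a 39-pixel strip left over on each axis.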
3 Deep Learning-Based Segmentation

Segmentation of an image involves classifying every pixel of the image into one of a given set of classes. Unlike applications where the aim is to classify the whole image into one of the given classes, semantic segmentation makes dense predictions by inferring a label for every pixel. Contemporary classification architectures such as VggNet [19], ResNet [20], and DenseNet [21] can be converted into architectures suitable for segmentation. The segmentation process is shown in Fig. 1, which depicts the general encoder–decoder architecture for deep learning segmentation. A segmentation model generally consists of an encoder and a decoder. The encoder is usually a pretrained classification network with its fully connected layers removed. Its task is to extract features and produce a low-resolution heatmap. The task of the decoder is to up-sample the heatmap successively to the original resolution of the input. Deep learning segmentation algorithms differ in the way their encoders extract features and in the way their decoders perform up-sampling with different skip connection strategies. The encoding, decoding, and up-sampling techniques used in deep learning-based segmentation are discussed below.
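The idea of dense, per-pixel prediction can be illustrated with a toy example. The sketch below takes per-class score maps, assigns each pixel the class with the highest score, and then scores the result with intersection over union (IoU). The score values are made up for illustration, and IoU is used here as an assumed example of a segmentation metric; this passage of the chapter does not name one.

```python
def segment(scores):
    """scores[c][y][x] -> label map where each pixel gets its argmax class."""
    n_classes = len(scores)
    h, w = len(scores[0]), len(scores[0][0])
    return [[max(range(n_classes), key=lambda c: scores[c][y][x])
             for x in range(w)] for y in range(h)]

def iou(pred, truth, cls):
    """Intersection over union of the pixels labelled `cls`."""
    inter = union = 0
    for prow, trow in zip(pred, truth):
        for p, t in zip(prow, trow):
            inter += (p == cls and t == cls)
            union += (p == cls or t == cls)
    return inter / union if union else float("nan")

# Toy 2 x 3 score maps (hypothetical values).
scores = [
    [[0.9, 0.8, 0.2], [0.7, 0.4, 0.1]],  # class 0: background
    [[0.1, 0.2, 0.8], [0.3, 0.6, 0.9]],  # class 1: solar panel
]
pred = segment(scores)
truth = [[0, 0, 1], [0, 0, 1]]
print(pred, iou(pred, truth, cls=1))
```

A real segmentation network produces the score maps; the per-pixel argmax and the overlap-based evaluation work the same way at any resolution.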
3.1 Encoding Techniques

The main purpose of the encoder is to extract features from the images through successive convolution and pooling layers. The encoding part usually comprises networks such as VggNet, ResNet, and DenseNet with their fully connected layers removed.
Fig. 1 General encoder–decoder architecture for deep learning segmentation process
a. VggNet
VggNet [19] has several variants. The most prominent is VggNet-16, which consists of 13 convolution layers, 5 max-pooling layers, and 3 fully connected layers. VggNet-19 is also prominent and consists of 16 convolutional layers, 5 max-pooling layers, and 3 fully connected layers. VggNet uses small 3 × 3 convolution filters throughout, in contrast to AlexNet, whose first convolution layer uses 11 × 11 filters. A stack of three 3 × 3 convolution layers covers the same 7 × 7 receptive field as a single 7 × 7 layer, but it involves fewer parameters and reduces the computational effort.

b. ResNet
Residual Network (ResNet) [20] uses residual connections, i.e., identity shortcut connections that skip one or more layers, which make it possible to train much deeper networks without the loss of accuracy seen when plain networks are made deeper. The network consists of stacked residual blocks with 3 × 3 convolutions. The variants of this network include ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152, consisting of 18, 34, 50, 101, and 152 layers, respectively. Such a network can easily be adopted for segmentation with its fully connected layers removed.

c. DenseNet
The main characteristic of DenseNet [21] is that every layer is connected to every other layer in a feedforward manner, resulting in L(L + 1)/2 direct connections, where L is the number of layers in the network. For each layer, the feature maps of all preceding layers are used as input, and its own feature maps are used as input to all subsequent layers. DenseNet alleviates the vanishing gradient problem, strengthens feature propagation, encourages feature reuse, and reduces the number of parameters. The network consists of several dense blocks, in which a layer is connected to every subsequent layer. Each layer in a dense block consists of batch normalization, ReLU activation, and a 3 × 3 convolution. The layers between two dense blocks are transition layers consisting of batch normalization, a 1 × 1 convolution, and average pooling. This type of network can easily be adopted for semantic segmentation.
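Two of the quantitative claims above can be checked with a few lines of arithmetic: the parameter saving from stacking small filters, and the number of direct connections in a densely connected network. The sketch below assumes stride-1 convolutions with the same number of input and output channels and no bias terms; the channel count of 64 is a hypothetical value.

```python
def conv_weights(kernel, channels, layers=1):
    # Weights of `layers` stacked kernel x kernel convolutions with
    # `channels` input and output channels, ignoring biases.
    return layers * kernel * kernel * channels * channels

def stacked_receptive_field(kernel, layers):
    # With stride 1, each extra layer widens the receptive field by kernel - 1.
    return 1 + layers * (kernel - 1)

def dense_connections(L):
    # Direct connections in a densely connected network of L layers.
    return L * (L + 1) // 2

C = 64  # hypothetical channel count
three_3x3 = conv_weights(3, C, layers=3)   # 27 * C^2 weights
one_7x7 = conv_weights(7, C)               # 49 * C^2 weights

print(stacked_receptive_field(3, 3), three_3x3, one_7x7)
print(dense_connections(5))
```

Three stacked 3 × 3 layers see a 7 × 7 region with about 45% fewer weights than one 7 × 7 layer, and they also apply three non-linearities instead of one.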
3.2 Decoding Techniques

A decoder increases the image resolution by using up-sampling techniques. Some of the decoding techniques commonly used in deep learning segmentation are summarized below.

a. Skip Connection-based Decoders

i. FCN-based decoding
The FCN-based decoder [9] uses transpose convolution (deconvolution) for up-sampling. The decoder has different variants depending on the pooling layer from which a skip connection is added: (a) FCN-32 has no skip connection, (b) FCN-16 uses one skip connection from the fourth pooling layer, and (c) FCN-8 uses two skip connections from the third and fourth pooling layers. After the last convolution, the FCN decoder uses a 1 × 1 convolution layer and a softmax function to classify the image pixels. The number of 1 × 1 convolution filters equals the number of classes into which the image pixels are to be categorized.

ii. UNet-based decoding
The UNet-based decoder [10] uses transpose convolution for up-sampling, followed by convolutional layers and ReLU activation. UNet extends the concept of skip connections used in the FCN decoder. Here every encoder block is connected to its corresponding decoder block through a skip connection. The features learnt in the encoder block are carried over and concatenated with the features of the decoder block.

b. Max-pooling index-based Decoder
The max-pooling index-based decoder [11] uses the indices stored during the max-pooling operations of the encoder to up-sample the feature maps in the corresponding decoder block. Unlike the FCN and UNet decoders, no features are carried to the decoder through skip connections. The up-sampled features are then convolved with trainable filters.

c. Un-pooling-based Decoder
The un-pooling-based decoder [22] uses un-pooling and deconvolutional layers in its decoder blocks. Un-pooling reverses the pooling operation performed during encoding. During encoding, switch variables record the locations of the maximum activations in each max-pooling operation; during decoding, the un-pooling operation uses these switch variables to place each activation back at its original location and so restore the resolution of the activations. After un-pooling, a deconvolutional layer is used to densify the sparse feature maps produced by un-pooling.
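The index-based decoding described above can be sketched without any deep learning library. The following is a minimal illustration, assuming a single-channel feature map whose height and width are divisible by the pooling size; the function names are hypothetical and do not come from any of the cited implementations.

```python
def max_pool_with_indices(x, size=2):
    """Non-overlapping max pooling that also records where each maximum was."""
    h, w = len(x), len(x[0])
    pooled, indices = [], []
    for i in range(0, h, size):
        prow, irow = [], []
        for j in range(0, w, size):
            window = [(x[a][b], (a, b))
                      for a in range(i, i + size)
                      for b in range(j, j + size)]
            value, pos = max(window, key=lambda t: t[0])
            prow.append(value)
            irow.append(pos)          # the stored index ("switch")
        pooled.append(prow)
        indices.append(irow)
    return pooled, indices

def max_unpool(pooled, indices, out_h, out_w):
    """Place each pooled value back at its recorded position; zeros elsewhere."""
    out = [[0] * out_w for _ in range(out_h)]
    for prow, irow in zip(pooled, indices):
        for value, (a, b) in zip(prow, irow):
            out[a][b] = value
    return out

x = [[1, 3, 2, 0],
     [4, 2, 1, 5],
     [0, 1, 7, 2],
     [6, 2, 3, 1]]
pooled, idx = max_pool_with_indices(x)
restored = max_unpool(pooled, idx, 4, 4)
print(pooled)
print(restored)
```

The restored map is sparse: only the positions of the original maxima are filled, which is why a convolution or deconvolution layer follows to densify it.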
3.3 Up-Sampling Techniques

The up-sampling techniques that have been used to increase the resolution of the feature maps in the decoder are summarized below.

a. Nearest Neighbor
The nearest neighbor up-sampling technique simply copies the value of the nearest input pixel to each output pixel.

b. Bed of Nails
Bed of Nails puts the value of each input pixel at a fixed position in the corresponding output block, and the rest of the positions are filled with 0.

c. Bilinear Up-sampling
It calculates a pixel value by interpolating between the nearest known pixels. Unlike the nearest neighbor technique, each nearby pixel contributes with a weight that decreases linearly with its distance from the output position.

d. Max Un-pooling
It remembers the index of the maximum activation during the max-pooling operation and uses the same index to position the pixel value in the output during up-sampling.

e. Transpose Convolution
Transpose convolution is an effective and commonly used technique for image up-sampling in deep learning semantic segmentation because the up-sampling is learnable. It is equivalent to padding the input with zeros, between and around its values, and then applying a convolution.

f. Dilated Convolution
Dilated convolution, also known as atrous convolution, was first developed for the efficient computation of the undecimated wavelet transform. It is a normal convolution with a wider kernel: gaps are inserted between the kernel weights so that the kernel covers a larger area and captures more context of the image without increasing the number of parameters. When the dilation rate is doubled from layer to layer, the receptive field grows exponentially. A normal convolution is a dilated convolution with a dilation rate equal to 1.
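Three of the up-sampling techniques listed above can be compared on a tiny example. The sketch below assumes a single-channel 2 × 2 input and an up-sampling factor of 2; the all-ones kernel used for the transpose convolution is a fixed, hypothetical choice made so that the result is easy to verify, whereas in a network the kernel weights would be learned.

```python
def nearest_neighbor(x, s=2):
    # Each output pixel copies the input pixel it falls inside.
    return [[x[i // s][j // s] for j in range(len(x[0]) * s)]
            for i in range(len(x) * s)]

def bed_of_nails(x, s=2):
    # Each input value goes to the top-left corner of its s x s block.
    out = [[0] * (len(x[0]) * s) for _ in range(len(x) * s)]
    for i, row in enumerate(x):
        for j, v in enumerate(row):
            out[i * s][j * s] = v
    return out

def transpose_conv2d(x, kernel, stride=2):
    # Each input value "stamps" a scaled copy of the kernel onto the output;
    # overlapping stamps are summed.
    k = len(kernel)
    h, w = len(x), len(x[0])
    out = [[0] * ((w - 1) * stride + k) for _ in range((h - 1) * stride + k)]
    for i in range(h):
        for j in range(w):
            for a in range(k):
                for b in range(k):
                    out[i * stride + a][j * stride + b] += x[i][j] * kernel[a][b]
    return out

x = [[1, 2],
     [3, 4]]
ones = [[1, 1], [1, 1]]
print(nearest_neighbor(x))
print(bed_of_nails(x))
print(transpose_conv2d(x, ones))
```

With a 2 × 2 kernel of ones and stride 2, the transpose convolution reproduces nearest neighbor up-sampling, which shows that the fixed schemes are special cases that a learnable transpose convolution can represent.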
4 Deep Learning Segmentation Architectures Used

4.1 UNet

A fully convolutional segmentation architecture for medical image segmentation reported in [10] is explored here for the segmentation of solar panel images. A UNet model essentially consists of an encoder and a decoder. The encoder uses several convolution and pooling operations to extract features from the image. The output of the encoder is a heatmap, which serves as input to the decoder. The purpose of the decoder is to up-sample the heatmap so that its spatial dimensions match the input, densify the segmentation, classify the pixels, and produce a segmentation map. The decoder semantically projects the fine features learnt in the beginning layers into the higher layers to produce a dense segmentation. The encoder and decoder together form a U-shaped structure, which gives UNet its name. The contracting path captures context, and the expanding path facilitates accurate localization. Up-sampling in the decoder is done using transpose convolutions. The architecture of UNet used in this study is shown in Fig. 2.

Fig. 2 UNet architecture used in this study
4.2 SegNet

The segmentation architecture SegNet [11], designed for scene understanding applications and efficient in terms of memory and computational time, is explored here for automatic detection of solar panels from satellite images. The SegNet architecture consists of an encoder and a decoder like UNet but differs in how up-sampling is done in the decoder. The deconvolutional layers used for up-sampling in the decoder of UNet are time and memory consuming because up-sampling is performed using a learnable model, i.e., the filters for up-sampling are learned during training. The SegNet architecture replaces the learnable up-sampling by computing and memorizing the max-pool indices in the encoder and later using these indices to up-sample the features in the corresponding decoder block, which produces sparse feature maps. It then uses normal convolution with trainable filters to densify these sparse feature maps. However, there are no skip connections for feature transfer as in UNet. The use of max-pool indices reduces the number of model parameters, so the model takes less time to train. The architecture of SegNet used in this study is given in Fig. 3. The max-pooling index concept is illustrated in Fig. 4, which also contrasts the index-based up-sampling of SegNet with the learnable, deconvolution-based up-sampling used in FCN and UNet decoders.
Fig. 3 SegNet architecture used in this study
Fig. 4 Decoders used in FCN and SegNet
Fig. 5 Multicontext aggregation architecture (Dilated Net) used in this study
4.3 Dilated Net

The use of dilated convolution for context aggregation in semantic segmentation, discussed in [13], is explored for automatic detection of solar panels in satellite images. The feature context is aggregated without losing resolution or using rescaled versions of the images. The architecture introduces a context module consisting of 7 layers of 3 × 3 convolutions applied with dilation rates of 1, 1, 2, 4, 8, 16, and 1. The last convolution is a 1 × 1 × C convolution that produces the final output of the module. The architecture uses two types of context module: a basic module and a large context module. The basic module contains the same number of channels (C) throughout, whereas the large context module uses an increasing number of channels in its deeper layers. The architecture introduces another module, known as the front module, which is constructed from a VggNet by removing the last two pooling and striding layers and using dilated convolutions in the layers that follow. Finally, a context module is added to the front module for dense semantic prediction and contextualization. The architecture is shown in Fig. 5. A 2D dilated convolution with different dilation rates is shown in Fig. 6.
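The effect of the dilation schedule on the receptive field can be worked out directly. The sketch below assumes stride-1 convolutions and uses the standard relation that a k × k kernel with dilation rate r spans k + (k − 1)(r − 1) pixels; the function names are illustrative.

```python
def effective_kernel(k, rate):
    # Span of a k x k kernel once (rate - 1) gaps are inserted between weights.
    return k + (k - 1) * (rate - 1)

def receptive_field(kernel, rates):
    # Stride-1 layers: each layer widens the field by (effective kernel - 1).
    rf = 1
    for r in rates:
        rf += effective_kernel(kernel, r) - 1
    return rf

rates = [1, 1, 2, 4, 8, 16, 1]          # dilation rates of the 3 x 3 layers
print(effective_kernel(3, 2), effective_kernel(3, 16))
print(receptive_field(3, rates))        # dilated context module
print(receptive_field(3, [1] * 7))      # same depth without dilation
```

Every layer still has only nine weights per filter, yet the seven dilated layers see a 67 × 67 region, compared with 15 × 15 for seven ordinary 3 × 3 layers.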
Fig. 6 Dilated convolution in 2D with different dilation rates. a Dilation rate = 1. b Dilation rate = 2. c Dilation rate = 3
4.4 PSPNet PSPNet [23], which captures global context information through a pyramid pooling module for better classification of small objects, is explored for automatic detection of solar panels in satellite images. Small objects are hard to find but are important for overall scene categorization. The pyramid pooling module gathers global context information along with sub-region context for categorization of different objects. It uses four pooling operations with output sizes of 1 × 1, 2 × 2, 3 × 3, and 6 × 6. These operations generate feature maps for different sub-regions and form pooled representations for different locations, so their outputs have different sizes. Bilinear interpolation is then used to up-sample these features to the input resolution. The number of pyramid levels and the size of each level can be modified. The architecture of PSPNet is shown in Fig. 7.
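The pooling stage of the pyramid module can be sketched as follows. This is a minimal single-channel illustration in plain Python, not the PSPNet implementation: the function name and the 6 × 6 toy feature map are assumptions, and the bilinear up-sampling and concatenation steps that follow the pooling are omitted.

```python
def adaptive_avg_pool(fmap, bins):
    """Average-pool a 2D feature map into a bins x bins grid of sub-regions."""
    rows, cols = len(fmap), len(fmap[0])
    pooled = []
    for bi in range(bins):
        r0 = (bi * rows) // bins
        r1 = ((bi + 1) * rows + bins - 1) // bins   # integer ceiling
        pooled_row = []
        for bj in range(bins):
            c0 = (bj * cols) // bins
            c1 = ((bj + 1) * cols + bins - 1) // bins
            cells = [fmap[i][j] for i in range(r0, r1) for j in range(c0, c1)]
            pooled_row.append(sum(cells) / len(cells))
        pooled.append(pooled_row)
    return pooled


# Hypothetical 6 x 6 single-channel feature map with values 0..35.
fmap = [[float(r * 6 + c) for c in range(6)] for r in range(6)]

# One pooled map per pyramid level: 1 x 1, 2 x 2, 3 x 3, and 6 x 6.
pyramid = {bins: adaptive_avg_pool(fmap, bins) for bins in (1, 2, 3, 6)}
for bins, pooled in pyramid.items():
    print(bins, len(pooled), len(pooled[0]))
```

The 1 × 1 level summarizes the whole map (global context), while the finer levels keep progressively more local sub-region context.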
Fig. 7 PSPNet Architecture used in this study
Fig. 8 DeepLab v3+ architecture used in this study
4.5 DeepLab v3+ DeepLab v3+ [12], which uses an encoder–decoder structure for dense semantic segmentation, is explored for automatic detection of solar panels in satellite images. The encoder–decoder structure has two advantages: (i) it encodes multi-scale contextual information by probing the incoming features with filters or pooling operations at multiple rates and multiple effective fields-of-view, and (ii) it captures sharper object boundaries by gradually recovering spatial information through skip connections. A simple yet effective decoder module refines the segmentation results, especially along object boundaries. Depth-wise separable convolution is applied to both the Atrous Spatial Pyramid Pooling and decoder modules, resulting in a faster and stronger encoder–decoder network. The detailed encoder and decoder structure is given in Fig. 8.
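The saving that depth-wise separable convolution brings can be seen from a simple weight count. The sketch below compares a standard convolution with a depth-wise convolution followed by a 1 × 1 point-wise convolution; the layer sizes (3 × 3 kernel, 256 input and output channels) are illustrative assumptions rather than values taken from this study, and bias terms are ignored.

```python
def standard_conv_params(kernel, c_in, c_out):
    """Weights in a standard convolution: one k x k x c_in filter per output channel."""
    return kernel * kernel * c_in * c_out


def separable_conv_params(kernel, c_in, c_out):
    """Weights in a depth-wise separable convolution.

    Depth-wise step: one k x k filter per input channel.
    Point-wise step: a 1 x 1 convolution mixing c_in channels into c_out.
    """
    depthwise = kernel * kernel * c_in
    pointwise = c_in * c_out
    return depthwise + pointwise


# Hypothetical layer: 3 x 3 kernel, 256 input and 256 output channels.
standard = standard_conv_params(3, 256, 256)
separable = separable_conv_params(3, 256, 256)
print(standard, separable, round(standard / separable, 2))
```

For this hypothetical layer the separable form needs roughly one ninth of the weights, which is consistent with the faster network described above.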
4.6 Dilated Residual Network The Dilated Residual Network [14], which uses dilated convolutions in a residual network for classification and segmentation tasks, is explored for automatic detection of solar panels in satellite images. In convolutional networks, the spatial size of the feature maps is progressively reduced by repeated pooling and striding operations. This loss of spatial structure limits the model’s ability to produce good results in classification and segmentation tasks. The Dilated Residual Network introduces dilated convolutions into a residual network to increase the receptive field of the feature maps without increasing the number of parameters. It also provides an effective dilation strategy for dealing with the gridding pattern that arises when dilation rates increase at successive layers. A residual network is converted into a dilated residual network by removing the striding in the first layer of block 4 and of block 5, using a dilation rate of 2 in the remaining layers of block 4 and in the first layer of block 5, and using a dilation rate of 4 in the remaining layers of block 5. Predictions are produced by a 1 × 1 × C layer, where C is the number of classes. The resulting feature responses have 1/8 of the original image resolution and are bilinearly up-sampled to the resolution of the input image. A dilated residual network is shown in Fig. 9.
Fig. 9 Dilated residual network architecture used in this study
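The 1/8 output resolution follows from the strides that remain after the conversion. The sketch below assumes a standard residual network with five stride-2 stages (an assumption about the base network, not a detail stated in this chapter) and shows the total down-sampling factor before and after the striding in blocks 4 and 5 is removed.

```python
def output_stride(stage_strides):
    """Total down-sampling factor of a network given the stride of each stage."""
    total = 1
    for stride in stage_strides:
        total *= stride
    return total


# Assumption: a standard residual network with five stride-2 stages.
resnet_strides = [2, 2, 2, 2, 2]
# Dilated variant: striding removed in blocks 4 and 5 (dilation is used instead).
drn_strides = [2, 2, 2, 1, 1]

input_size = 224
for name, strides in (("ResNet", resnet_strides), ("Dilated ResNet", drn_strides)):
    factor = output_stride(strides)
    print(name, factor, input_size // factor)
```

Under this assumption a 224 × 224 crop yields a 28 × 28 prediction map in the dilated network, which is then up-sampled bilinearly by a factor of 8.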
5 Results and Discussion
5.1 Dataset Used The purpose of this study is to locate solar panels in satellite images of buildings and the earth’s surface using deep learning segmentation techniques. Training a segmentation model requires a dataset that contains both the images and their corresponding masks, which serve as ground truth. The masks mark the pixels that correspond to solar panels. Because no dataset containing both images and masks is publicly available, the dataset described in [18] was used and masks were prepared for it before training the models. The dataset has 601 TIF orthoimages of four cities in California and contains the geospatial coordinates and vertices of about 19,000 solar panels spread across all images. This work used the images of Fresno. Each image is 5000-by-5000 pixels and covers an area of 2.25 km². The images show urban, suburban, and rural landscapes, so the models are trained on diverse images. The vertices of the solar panels associated with each image were used to draw polygons of white pixels for the solar panels, with the remaining pixels set to black to represent the background. Figure 10 shows a sample image of size 5000-by-5000 pixels and its sub-image of size 224-by-224 pixels. To make the models more robust, data augmentation was performed to reduce overfitting and improve generalization. The augmentation consisted of horizontal and vertical flips of the images. These two augmentations proved useful in training the models and increasing the segmentation accuracy.
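The mask preparation and flip augmentation described above can be sketched in plain Python. The actual tooling used in this study is not described, so everything here is an assumption for illustration: the ray-casting rasterizer, the function names, the pixel-centre convention, and the toy panel outline.

```python
def point_in_polygon(x, y, vertices):
    """Ray-casting test: True if the point (x, y) lies inside the polygon."""
    inside = False
    n = len(vertices)
    for k in range(n):
        x0, y0 = vertices[k]
        x1, y1 = vertices[(k + 1) % n]
        # Toggle for every polygon edge crossed by a horizontal ray from (x, y).
        if (y0 > y) != (y1 > y):
            x_cross = x0 + (y - y0) * (x1 - x0) / (y1 - y0)
            if x < x_cross:
                inside = not inside
    return inside


def polygon_mask(height, width, vertices):
    """White (255) where the pixel centre is inside the polygon, black (0) elsewhere."""
    return [
        [255 if point_in_polygon(c + 0.5, r + 0.5, vertices) else 0
         for c in range(width)]
        for r in range(height)
    ]


def horizontal_flip(image):
    return [row[::-1] for row in image]


def vertical_flip(image):
    return image[::-1]


# Hypothetical panel outline given as (x, y) pixel coordinates.
panel = [(1, 1), (5, 1), (5, 3), (1, 3)]
mask = polygon_mask(4, 8, panel)
flipped = horizontal_flip(mask)
for row in mask:
    print(row)
```

Whatever flip is applied to an image must also be applied to its mask so that the pair stays aligned.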
Fig. 10 First row shows an image and its mask of size 5000 × 5000. Second row shows an image with its mask of size 224 × 224
5.2 Training All the architectures were implemented in Python and trained on a workstation with an Nvidia Tesla K40 (12 GB) GPU and 30 GB of RAM. Because 5000-by-5000 images are too large to train on without a high-end GPU and RAM configuration, the images and their corresponding masks were cropped to 224 × 224. A total of 1118 cropped images were selected for augmentation and training. Testing was done on image crops of the same size. The Adam optimizer was used with a fixed learning rate of 1 × 10⁻⁵ for ResNet-based models and 1 × 10⁻⁴ for VggNet-based models. Training was done from scratch without any pretraining or transfer learning. An early stopping criterion of 15 epochs was used to stop training.
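As an illustration of the cropping step, the sketch below enumerates the top-left corners of 224 × 224 crops on a regular, non-overlapping grid. The grid layout is an assumption: the chapter does not state how the crops were positioned or how the 1118 training crops were selected.

```python
def tile_origins(image_size=5000, tile_size=224):
    """Top-left (row, col) corners of non-overlapping tiles that fit fully in the image."""
    per_axis = image_size // tile_size
    return [
        (r * tile_size, c * tile_size)
        for r in range(per_axis)
        for c in range(per_axis)
    ]


origins = tile_origins()
print(len(origins))    # number of full tiles per 5000-by-5000 image
print(origins[-1])     # corner of the last tile
```

Under this assumed layout each 5000-by-5000 image yields 22 × 22 full crops, with a 72-pixel strip left over along the right and bottom edges.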
5.3 Performance Metric and Loss Function The performance metric used in this study is the dice coefficient (F1 score), one of the most widely used metrics in segmentation. It quantifies how closely the ground-truth segmentation region matches the region predicted by the model. It is defined as twice the size of the intersection (overlap) of the two regions divided by the sum of the sizes of the two regions. Given two sets of pixels denoted by X and Y, the dice coefficient is given by:
$$\mathrm{DC} = \frac{2\,|X \cap Y|}{|X| + |Y|} \tag{1}$$
The dice coefficient ranges from 0 to 1. A value close to 1 means more overlap and similarity between the two regions, and hence a more accurate predicted segmentation. The loss function used in this study is the dice loss (DL), which is derived from the dice coefficient and was used in [24]. It is defined as DL = 1 − DC, where DC is the dice coefficient defined in (1). Minimizing the dice loss maximizes DC during training.
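A minimal sketch of the metric and the loss on flattened masks is given below. It follows the standard dice definition, twice the overlap divided by the sum of the two region sizes; the small smoothing constant `eps` and the toy masks are assumptions added for illustration and are not taken from this study.

```python
def dice_coefficient(pred, truth, eps=1e-7):
    """Dice coefficient of two flattened masks with values in [0, 1].

    eps is a smoothing term (an assumption) that avoids division by zero
    when both masks are empty.
    """
    intersection = sum(p * t for p, t in zip(pred, truth))
    total = sum(pred) + sum(truth)
    return (2.0 * intersection + eps) / (total + eps)


def dice_loss(pred, truth):
    """Dice loss, DL = 1 - DC."""
    return 1.0 - dice_coefficient(pred, truth)


# Hypothetical flattened ground-truth and predicted masks (1 = solar panel).
truth = [1, 1, 1, 0, 0, 0]
pred = [1, 1, 0, 1, 0, 0]
dc = dice_coefficient(pred, truth)
dl = dice_loss(pred, truth)
print(round(dc, 4), round(dl, 4))
```

Because the products and sums also work for values between 0 and 1, the same function can be applied to soft predictions during training.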
5.4 Experimental Results The experimental results have been obtained by using augmented as well as original datasets. The dice coefficients and loss results of training, validation, and testing of augmented and original datasets are reported here.
Table 1 Results on UNet model
| Dataset type | Training DC | Training loss | Validation DC | Validation loss | Testing DC | Testing loss |
|---|---|---|---|---|---|---|
| Augmented | 0.8826 | 0.1173 | 0.8867 | 0.1132 | 0.8963 | 0.1036 |
| Original | 0.8750 | 0.124 | 0.8819 | 0.1181 | 0.8943 | 0.1056 |
Table 1 shows the dice coefficient and loss results of training, validation, and testing on the augmented and original datasets for the UNet model. Augmentation improved the training, validation, and testing results. Figure 11 shows the corresponding dice coefficient bar graphs, in which data augmentation yields better DC values for training, validation, and testing on the UNet model.
Fig. 11 Dice Coefficients of augmented and original datasets on UNet model. a Results of training process. b Results of validation process. c Results of testing process
Table 2 shows the dice coefficient and loss results of training, validation, and testing on the augmented and original datasets for the SegNet model. Augmentation improved the training, validation, and testing results: the dice coefficients on the augmented dataset are about 9–11 percentage points higher than those on the original dataset. Figure 12 shows the corresponding dice coefficient bar graphs, in which data augmentation yields better DC values for training, validation, and testing on the SegNet model.
Table 2 Results on SegNet model
| Dataset type | Training DC | Training loss | Validation DC | Validation loss | Testing DC | Testing loss |
|---|---|---|---|---|---|---|
| Augmented | 0.8167 | 0.1832 | 0.7775 | 0.2224 | 0.7471 | 0.2528 |
| Original | 0.7102 | 0.2897 | 0.6828 | 0.3171 | 0.6425 | 0.3574 |
Fig. 12 Dice Coefficients of augmented and original datasets on SegNet model. a Results of training process. b Results of validation process. c Results of testing process
Table 3 shows the dice coefficient and loss results of training, validation, and testing on the augmented and original datasets for the Dilated Net model. Augmentation improved the training, validation, and testing results: the dice coefficients on the augmented dataset are about 1 percentage point higher than those on the original dataset. Figure 13 shows the corresponding dice coefficient bar graphs, in which data augmentation yields better DC values for training, validation, and testing on the Dilated Net model.
Table 3 Results on Dilated Net model
| Dataset type | Training DC | Training loss | Validation DC | Validation loss | Testing DC | Testing loss |
|---|---|---|---|---|---|---|
| Augmented | 0.6956 | 0.3043 | 0.6591 | 0.3408 | 0.6732 | 0.3267 |
| Original | 0.6862 | 0.3137 | 0.6465 | 0.3534 | 0.6615 | 0.3384 |
Fig. 13 Dice coefficients of augmented and original datasets on dilated net model. a Results of training process. b Results of validation process. c Results of testing process
Table 4 shows the dice coefficient and loss results of training, validation, and testing on the augmented and original datasets for the PSPNet model. Augmentation improved the validation and testing results: the validation and testing dice coefficients on the augmented dataset are about 9–14 percentage points higher than those on the original dataset. The training dice coefficient, however, is slightly lower on the augmented dataset, which suggests that more epochs are required to train PSPNet on larger datasets. Figure 14 shows the corresponding dice coefficient bar graphs, in which data augmentation yields better DC values for validation and testing on the PSPNet model.
Table 4 Results on PSPNet model
| Dataset type | Training DC | Training loss | Validation DC | Validation loss | Testing DC | Testing loss |
|---|---|---|---|---|---|---|
| Augmented | 0.6025 | 0.397 | 0.5527 | 0.4472 | 0.5091 | 0.4908 |
| Original | 0.6122 | 0.3877 | 0.4102 | 0.5897 | 0.4181 | 0.5818 |
Fig. 14 Dice coefficients of augmented and original datasets on PSPNet model. a Results of training process. b Results of validation process. c Results of testing process
Table 5 Results on DeepLab v3+ model
| Dataset type | Training DC | Training loss | Validation DC | Validation loss | Testing DC | Testing loss |
|---|---|---|---|---|---|---|
| Augmented | 0.7572 | 0.2427 | 0.6410 | 0.3589 | 0.6610 | 0.3389 |
| Original | 0.7877 | 0.2122 | 0.5654 | 0.4345 | 0.5713 | 0.4286 |
Table 5 shows the dice coefficient and loss results of training, validation, and testing on the augmented and original datasets for the DeepLab v3+ model. Augmentation improved the validation and testing results: the validation and testing dice coefficients on the augmented dataset are about 7–9 percentage points higher than those on the original dataset. The training dice coefficient, however, is lower on the augmented dataset, which suggests that more epochs are required to train DeepLab v3+ on larger datasets. Figure 15 shows the corresponding dice coefficient bar graphs, in which data augmentation yields better DC values for validation and testing on the DeepLab v3+ model.
Fig. 15 Dice Coefficients of augmented and original datasets on DeepLab v3+ model. a Results of training process. b Results of validation process. c Results of testing process
Table 6 shows the dice coefficient and loss results of training, validation, and testing on the augmented and original datasets for the Dilated ResNet model. Augmentation improved the validation and testing results: the validation and testing dice coefficients on the augmented dataset are about 5–9 percentage points higher than those on the original dataset. The training dice coefficient, however, is lower on the augmented dataset, which suggests that more epochs are required to train the Dilated ResNet on larger datasets. Figure 16 shows the corresponding dice coefficient bar graphs, in which data augmentation yields better DC values for validation and testing on the Dilated ResNet model.
Table 6 Results on Dilated ResNet model
| Dataset type | Training DC | Training loss | Validation DC | Validation loss | Testing DC | Testing loss |
|---|---|---|---|---|---|---|
| Augmented | 0.7164 | 0.2835 | 0.6766 | 0.3233 | 0.6307 | 0.3692 |
| Original | 0.7498 | 0.2501 | 0.6203 | 0.3796 | 0.5434 | 0.4565 |
Fig. 16 Dice Coefficients of augmented and original datasets on Dilated ResNet model. a Results of training process. b Results of validation process. c Results of testing process
The testing dice coefficients on the augmented and original datasets for all six models are summarized as bar graphs in Fig. 17. The bar graphs indicate that the UNet model produces the best DC value and therefore the best segmentation accuracy, followed by the SegNet and Dilated Net models.
Fig. 17 DC values of testing augmented and original datasets on all the six models
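The comparison summarized in Fig. 17 can be reproduced from the testing columns of Tables 1 to 6. The short script below ranks the six models by their testing dice coefficient on the augmented dataset and computes the gain over the original dataset; the values are copied from the tables, while the variable names are illustrative.

```python
# Testing DC per model as (augmented, original), copied from Tables 1 to 6.
testing_dc = {
    "UNet": (0.8963, 0.8943),
    "SegNet": (0.7471, 0.6425),
    "Dilated Net": (0.6732, 0.6615),
    "PSPNet": (0.5091, 0.4181),
    "DeepLab v3+": (0.6610, 0.5713),
    "Dilated ResNet": (0.6307, 0.5434),
}

# Rank models by testing DC on the augmented dataset, best first.
ranking = sorted(testing_dc, key=lambda model: testing_dc[model][0], reverse=True)

# Absolute gain in testing DC attributable to augmentation.
gains = {
    model: round(augmented - original, 4)
    for model, (augmented, original) in testing_dc.items()
}

for model in ranking:
    print(f"{model:15s} DC={testing_dc[model][0]:.4f} gain={gains[model]:+.4f}")
```

The gains differ considerably across models: augmentation changes the UNet testing result only marginally, while SegNet, PSPNet, DeepLab v3+, and Dilated ResNet each gain roughly 0.09 to 0.10.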
6 Conclusion This work described the automatic detection of solar panels in satellite imagery using deep learning segmentation models. The study discussed several state-of-the-art deep learning segmentation architectures and the encoding, decoding, and up-sampling techniques they use. Six architectures were used for automatic detection of solar panels: UNet, SegNet, Dilated Net, PSPNet, DeepLab v3+, and Dilated Residual Net. The source dataset comprises satellite images of four cities in California, of which the images of Fresno were used in this work. Image crops of size 224 × 224 were used for training the models. The results show that the UNet architecture, which uses skip connections between its encoder and decoder modules, produced the best segmentation accuracy. Moreover, dataset augmentation further improved the segmentation accuracy.
References
1. M.A. Wani, Incremental hybrid approach for microarray classification, in 2008 Seventh International Conference on Machine Learning and Applications (IEEE, 2008), pp. 514–520
2. M.A. Wani, R. Riyaz, A new cluster validity index using maximum cluster spread based compactness measure. Int. J. Intell. Comput. Cybern. (2016)
3. M.A. Wani, R. Riyaz, A novel point density based validity index for clustering gene expression datasets. Int. J. Data Mining Bioinf. 17(1), 66–84 (2017)
4. R. Riyaz, M.A. Wani, Local and global data spread based index for determining number of clusters in a dataset, in 2016 15th IEEE International Conference on Machine Learning and Applications (ICMLA) (IEEE, 2016), pp. 651–656
5. F.A. Bhat, M.A. Wani, Performance comparison of major classical face recognition techniques, in 2014 13th International Conference on Machine Learning and Applications (IEEE, 2014), pp. 521–528
6. M.A. Wani, M. Yesilbudak, Recognition of wind speed patterns using multi-scale subspace grids with decision trees. Int. J. Renew. Res. (IJRER) 3(2), 458–462 (2013)
7. M.R. Wani, M.A. Wani, R. Riyaz, Cluster based approach for mining patterns to predict wind speed, in 2016 IEEE International Conference on Renewable Energy Research and Applications (ICRERA) (IEEE, 2016), pp. 1046–1050
8. M.A. Wani, F.A. Bhat, S. Afzal, A.L. Khan, Advances in Deep Learning (Springer, 2020)
9. J. Long, E. Shelhamer, T. Darrell, Fully convolutional networks for semantic segmentation, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2015), pp. 3431–3440
10. O. Ronneberger, P. Fischer, T. Brox, U-net: convolutional networks for biomedical image segmentation, in International Conference on Medical Image Computing and Computer-Assisted Intervention (Springer, Cham, 2015), pp. 234–241
11. V. Badrinarayanan, A. Kendall, R. Cipolla, Segnet: a deep convolutional encoder-decoder architecture for image segmentation. IEEE Trans. Pattern Anal. Mach. Intell. 39(12), 2481–2495 (2017)
12. L.C. Chen, Y. Zhu, G. Papandreou, F. Schroff, H. Adam, Encoder-decoder with atrous separable convolution for semantic image segmentation, in Proceedings of the European Conference on Computer Vision (ECCV) (2018), pp. 801–818
13. F. Yu, V. Koltun, Multi-scale context aggregation by dilated convolutions, in ICLR (2016)
14. F. Yu, V. Koltun, T. Funkhouser, Dilated residual networks, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2017), pp. 472–480
15. J. Camilo, R. Wang, L.M. Collins, K. Bradbury, J.M. Malof, Application of a semantic segmentation convolutional neural network for accurate automatic detection and mapping of solar photovoltaic arrays in aerial imagery (2018). arXiv preprint arXiv:1801.04018
16. J.M. Malof, L.M. Collins, K. Bradbury, A deep convolutional neural network, with pre-training, for solar photovoltaic array detection in aerial imagery, in 2017 IEEE International Geoscience and Remote Sensing Symposium (IGARSS) (IEEE, 2017), pp. 874–877
17. J. Yuan, H.H.L. Yang, O.A. Omitaomu, B.L. Bhaduri, Large-scale solar panel mapping from aerial images using deep convolutional networks, in 2016 IEEE International Conference on Big Data (Big Data) (IEEE, 2016), pp. 2703–2708
18. K. Bradbury, R. Saboo, T.L. Johnson, J.M. Malof, A. Devarajan, W. Zhang, R.G. Newell, Distributed solar photovoltaic array location and extent dataset for remote sensing object identification. Sci. Data 3, 160106 (2016)
19. K. Simonyan, A. Zisserman, Very deep convolutional networks for large-scale image recognition (2014). arXiv preprint arXiv:1409.1556
20. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2016), pp. 770–778
21. G. Huang, Z. Liu, L. Van Der Maaten, K.Q. Weinberger, Densely connected convolutional networks, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2017), pp. 4700–4708
22. H. Noh, S. Hong, B. Han, Learning deconvolution network for semantic segmentation, in Proceedings of the IEEE International Conference on Computer Vision (2015), pp. 1520–1528
23. H. Zhao, J. Shi, X. Qi, X. Wang, J. Jia, Pyramid scene parsing network, in Proceedings of the IEEE International Conference on Computer Vision and Pattern Recognition (Honolulu, HI, USA, 2017), pp. 2881–2890
24. F. Milletari, N. Navab, S.A. Ahmadi, V-net: fully convolutional neural networks for volumetric medical image segmentation, in 2016 Fourth International Conference on 3D Vision (3DV) (IEEE, 2016), pp. 565–571
Training Deep Learning Sequence Models to Understand Driver Behavior
Shokoufeh Monjezi Kouchak and Ashraf Gaffar
Abstract Driver distraction is one of the leading causes of fatal car accidents in the U.S. Analyzing driver behavior using machine learning and deep learning models is an emerging solution to detect abnormal behavior and alert the driver. Models with memory, such as LSTM networks, outperform memoryless models in car safety applications because driving is a continuous task and the information in a sequence of driving data can improve a model’s performance. In this work, we used time-sequenced driving data that we collected in eight driving contexts to measure driver distraction. Our model is also capable of detecting the type of behavior that caused the distraction. We used the driver’s interaction with the car infotainment system as the distracting activity. A multilayer neural network (MLP) was used as the baseline, and two types of LSTM networks, an LSTM model with an attention network and an encoder–decoder model with attention, were built and trained to analyze the effect of memory and attention on the computational expense and performance of the model. We compare the performance of these two more complex networks with that of the MLP in estimating driver behavior. We show that our encoder–decoder model with attention outperforms the LSTM model with attention, and that the LSTM network with attention improves on the training process of the MLP network.
Keywords Bidirectional · LSTM network · Attention network · Driver distraction · Deep learning · Encoder–decoder
S. M. Kouchak · A. Gaffar (B) Arizona State University, Tempe, USA e-mail: [email protected] S. M. Kouchak e-mail: [email protected] © The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021 M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_6
1 Introduction Driver distraction is one of the leading causes of fatal car accidents in the U.S. and, while it is becoming an epidemic, it is preventable [1]. According to an NHTSA report, 3,166 people died in 2017 in car accidents caused by distracted drivers, which constitutes 8.5% of all fatalities [2, 3]. Distracted driving involves any task that diverts the driver’s attention from the primary task of driving, such as texting, reading, talking to other passengers, and using the car infotainment system [4, 5]. Texting is the most distracting task because the driver needs to take their eyes off the road for an estimated five seconds to read or send a text. A driver traveling at 45 mph covers a distance equal to the length of a football field in that time without seeing the road [4]. Driver distraction is considered a failure in driving task prioritization: more attention is put on secondary tasks, such as reading a text or tuning the radio, while the primary task of driving is ignored. There are four types of distraction: visual, manual, cognitive, and audio. Visual distraction is any task that makes the driver take their eyes off the road. Manual distraction happens when the driver takes their hands off the steering wheel [6]. Cognitive distraction is any task that diverts the driver’s attention from the primary task of driving [7]. Audio distraction is any noise that obscures important sounds inside the car, such as alarms, or outside it, such as ambulance sirens [8]. Designing a car-friendly user interface [9], advanced driver assistant systems (ADAS) [10–12], and autonomous vehicles (AV) [13] are some approaches to solve the problem. Driver assistant systems such as lane keeping and adaptive cruise control reduce human error by automating difficult or repetitive tasks. Although they can enhance driver safety, they have limitations because they often work only in predefined situations.
For instance, collision mitigation systems are based on algorithms that system developers have defined to cover specific situations that could lead to a collision in traffic. These algorithms cannot identify the full range of dangerous situations that could lead to an accident. They typically detect and prevent only a limited number of predefined dangerous situations [14]. An autonomous vehicle can sense its environment using a variety of sensors, such as radar and LIDAR, and move the car with limited or no human contribution. Autonomous vehicles face challenges such as dealing with human-driven cars, unsuitable road infrastructure, and cybersecurity. In the long run, when these challenges are solved, autonomous vehicles could boost car safety and decrease commuting time and cost [15]. Monitoring and analyzing driver behavior using machine learning methods to detect abnormal behavior is an emerging solution to enhance car safety [16]. Human error, which can increase due to the mental workload induced by distraction, is one of the leading causes of car crashes. Workloads are hard to observe and quantify, so analyzing driver behavior to distinguish normal from aggressive driving can be used as a suitable approach to detect driver distraction and alert the driver to take control of the car for a short time [17]. Driving data and driver status, which can be collected from internal sources (such as the vehicle’s onboard computer and the data bus) as well as from external devices such as cameras, can be used to train a variety of machine learning methods, such as Markov models, neural networks, and decision trees, to learn driving patterns and detect abnormal driver behavior [18–22]. Deep learning methods such as convolutional neural networks, LSTM networks, and encoder–decoder models outperform other machine learning methods in car safety applications such as pedestrian detection and driver behavior classification [23–26]. Some machine learning and deep learning methods make an instantaneous decision based only on the current inputs of the model. In some real-world applications like driving, each data sample has an effect on several subsequent samples, so using a temporal dimension and adding memory and attention can extract more behavior-representative information from a sequence of data, which could improve the accuracy of these applications. In this work, we use three models, a multilayer neural network (MLP), an LSTM network with an attention layer, and an encoder–decoder model, to predict the driver status using only driving data. No intrusive devices such as cameras or driver-wearable sensors are used. This makes our approach more user-friendly and increases its potential to be adopted in industry more readily. The MLP model was considered as the baseline and memoryless model. The two other models consider the weighted dependency between input and output [27, 28]. This provides instant feedback on the progress of the network model, which increases its prediction accuracy and reduces its training time. We started with a single-input–single-output neural network and compared the accuracy and training process of this model with those of an LSTM model with attention, which is multiple-input–single-output and has both memory and attention. The LSTM model with attention achieved lower train and test error in a smaller number of training epochs.
The average number of training epochs is 400 for the MLP and 100 for the LSTM model with attention. In addition, the LSTM model with attention has fewer layers. We then used the encoder–decoder with attention, a multiple-input–multiple-output model, to estimate a sequence of driver behavior, and compared the results with those of the two other models. The train and test errors of this model are lower than those of the other two models, and it can estimate multiple driver behavior vectors. Section 2 discusses related work. Section 3 describes the experiment. Section 4 explains the features of the data collected in the experiment. Section 5 discusses the methodology. Section 6 gives the details of the three models. Section 7 presents the results, and Sect. 8 is the conclusion.
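The three input/output arrangements mentioned above can be illustrated by the way a time-ordered recording is cut into training pairs. The sketch below is hypothetical: the helper name, the window lengths, and the choice of aligning the targets with the last steps of each window are assumptions made for illustration, not the data preparation used in this work.

```python
def make_windows(samples, labels, window, horizon):
    """Pair `window` consecutive input vectors with the labels of the
    last `horizon` steps of that window (requires horizon <= window)."""
    pairs = []
    for end in range(window, len(samples) + 1):
        inputs = samples[end - window:end]
        targets = labels[end - horizon:end]
        pairs.append((inputs, targets))
    return pairs


# Hypothetical recording: six time steps of driving data and driver status.
samples = [0, 1, 2, 3, 4, 5]                  # stand-ins for driving-data vectors
labels = ["s0", "s1", "s2", "s3", "s4", "s5"]  # stand-ins for driver-behavior vectors

single_in_single_out = make_windows(samples, labels, window=1, horizon=1)  # memoryless, MLP style
multi_in_single_out = make_windows(samples, labels, window=3, horizon=1)   # LSTM with attention style
multi_in_multi_out = make_windows(samples, labels, window=3, horizon=3)    # encoder-decoder style

print(single_in_single_out[0])
print(multi_in_single_out[0])
print(multi_in_multi_out[0])
```

In the first arrangement each prediction sees only the current sample, whereas the two windowed arrangements expose the preceding samples to the model, which is what allows memory and attention to be used.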
2 Related Works Wöllmer et al. [29] discussed a driver distraction detection method whose goal is to model head tracking and the context of driving data using an LSTM network. An experiment was conducted with 30 volunteers who drove an Audi on a straight road in Germany with one lane in each direction and a 100 km/h speed limit. Distracted driving data were collected through an interface to the vehicle’s CAN-Bus and a head tracking system installed in the car cockpit. Eight distracting tasks were chosen: radio, CD, phone book, navigation point of interest, phone, navigation, TV, and navigation sound. The distracting functions were available through eight hard keys located on the left and right sides of the interface. In all, 53 non-distracted and 220 distracted runs were done. The data collected from the CAN-Bus and the head tracking system were fed to an LSTM network to predict the driver’s state continuously. The accuracy of the model reached 96.6%. Although this approach has high accuracy, it needs external devices that are not available in all cars and are often considered an intrusion by drivers. Additionally, it considers only one driving context and a straight road. To generalize it to complex driving contexts, more work is needed to test the model’s accuracy, which might degrade when other types of roads and driving conditions are introduced, because these would produce more complex contextual data and hence a larger number of patterns to differentiate.
Xu et al. [30] introduced an image captioning model with attention mechanisms, inspired by language translation models. The input of the model is a raw image, and the model produces a caption that describes it. The encoder extracts features of the image using the middle layers instead of a fully connected layer. The decoder is an LSTM network that produces the caption as a sequence of words. The model shows good results on three benchmark datasets using the METEOR and BLEU metrics.
Xiao et al. [31] discussed an image classification model that uses two-level attention for fine-grained image classification in deep convolutional neural networks. Fine-grained classification means detecting subordinate-level categories under some basic categories.
This model is based on an intuition that for fine-grained classification, first the object needs to be detected and in the next step the discriminative parts of the object should be detected. For object detection level, the model uses raw images and removes noise to detect the objects and classify them. In the second level, it filters the object using mid-level CNN filters to detect parts of the image and an SVM classifier is used to classify the image’s parts. The model was validated on the subset of the ILSVRC29112 dataset and the CUB200 2011 dataset and it showed good performance under the weakest supervision condition. Huang et al. [32] proposed a machine translation model with attention to translate image captions from English to German. Additional information from the image is considered by the model to solve the problem of ambiguity in languages. A convolutional neural network is used to extract the image features, and the model adds these features to the text features to enhance the performance of the LSTM network that is used to generate the caption in the target language. Regional features of the image are used instead of general features. In sum, the best performance of the model shows a 2% improvement in the BLEU score and 2.3% enhancement in the METEOR dataset compared to models that only consider text information. Lv et al. [33] introduced a deep learning model for traffic flow prediction that considers both temporal and spatial correlations inherently. The model used a stacked autoencoder to detect and learn features of traffic flow. The traffic data was collected from 15000 detectors in the freeway system across California during weekdays in the first three months of 2013. Collected data in the first two months were used
as the training data, and the data collected in the third month formed the test dataset. The model outperformed previous models in medium and high traffic, but it did not perform well in low traffic. The authors compared the model with the Random Walk Forecast method, a Support Vector Machine, a Backpropagation Neural Network, and a Radial Basis Function network. For 15-min traffic flow prediction, the stacked autoencoder model reached more than 90% accuracy on 86% of highways and outperformed the other four shallow models. Saleh et al. [34] discussed a novel method for driver behavior classification using a stacked LSTM network. Nine sensory signals are captured using the internal sensors of a cellphone during realistic driving sessions. Three driving behaviors (normal, drowsy, and aggressive) were defined, and the driver behavior classification problem was modeled as time series classification. A sequence of driving feature vectors was used as the input of a stacked LSTM network, and the network classified the driver behavior accurately. The model also achieved better results than the baseline approach on UAH-DriveSet, a naturalistic driver behavior dataset. It was compared with other common driver behavior classification methods, including decision tree (DT) and multilayer perceptron (MLP). DT and MLP achieved 51% and 75% accuracy, respectively, whereas the proposed method reached 86% accuracy. These works used different types of LSTM networks to detect driver behavior and driving patterns. In this work, we use two types of LSTM networks, a bidirectional LSTM network and an LSTM network with an attention layer, to predict driver behavior using only driving data. We use a simple LSTM network as the baseline against which the results of these two networks are compared.
3 Experiment
We conducted an experiment to observe and model drivers' behavior using a simulated modern car environment and to collect a large body of driving patterns under different conditions. These patterns were used to train three models: an MLP, an LSTM network with attention, and an encoder–decoder with attention. The experiment was conducted using a Drive Safety Research simulator DS-600, a fully integrated, high-performance, high-fidelity driving simulation system that includes a multi-channel audio/visual system, a minimum 180° wraparound display, a full-width automobile cab (Ford Focus) with windshield, driver and passenger seats, center console, dash, and instrumentation, as well as real-time vehicle motion simulation. The simulated immersive view of the road was provided by three large screens in front of and on both sides of the car, as well as three synchronized mirrors (a rear-view mirror and two side mirrors). The simulator provides different types of roads and driving contexts. We designed an urban road with three left turns and a curved highway part. Figure 1 shows the designed road. Arrow A shows the start and end point of this road, and arrow B shows the highway part of the road.
Fig. 1 Designed road
Four contexts of driving were defined for this experiment including Day, Night, Fog, and Fog and Night. We used an android application to simulate the car infotainment system. The application was hosted on an Android v4.4.2 based Samsung Galaxy Tab4 8.0 which was connected to the Hyper Drive simulator. This setup allowed us to control the height and angle of the infotainment system. The removable tablet further allowed the control of screen size and contents. Our minimalist design was used in the interface design of this application. In this design, the main screen of the application shows six groups of car applications and the driver has access to more features under each group. In each step of the navigation, a maximum of six icons were displayed on the screen; which was tested earlier and shown to be a suitable number in infotainment UI design [35]. The minimalist design interaction allowed each application in this interface to be reached within a maximum of four steps of navigation from the main screen; hence conforming with NHTSA guidelines DOT HS 812 108 of acceptable distraction [3]. This helped us to normalize the driver’s interaction tasks and standardize them between different driving contexts.
3.1 Participants
We invited 35 volunteers to participate in our simulator experiment. Each volunteer was asked to take a 45-min simulated drive, which was divided into eight driving scenarios: four distracted and four non-distracted ones. Volunteers were graduate and undergraduate students of our university, 20–35 years old, with at least two years of driving experience. Before starting the experiment, each volunteer was trained for 10–15 min until they became familiar with the car interface and with driving in the simulator. We defined four driving contexts: Day, Fog, Night, and Fog and Night. For each context, we had one distracted and one non-distracted scenario. In the non-distracted scenarios, the volunteers were asked to focus on driving and to try to drive as realistically as possible. In distracted scenarios, we
distracted drivers by asking them to perform some tasks on the simulator's infotainment system. In this experiment, a distracting task is defined as reaching a specific application on the designed car interface. We classified all possible applications in the interface into three groups based on the number of navigation steps that the driver needs to take from the main screen to reach them, and called them two-step, three-step, and four-step tasks. In each distracted scenario, we chose some tasks from each group and asked the driver to do them. We put equal time intervals between tasks: when the driver finished a task, we waited a few seconds before asking for the next task. In distracted scenarios, we observed the driver behavior and collected four features for each task: the number of errors that the driver made during the task; the response time, which is the time from the moment the driver was asked to do a task to the moment the task was completed; the mode of driving, which is the current driving scenario; and the number of navigation steps that the driver needs to take from the main screen to a specific application in the interface to complete the task.
4 Data
The simulator collected an average of 19,000 data vectors per trip, with 53 driving-related features each. In sum, 5.3 million data vectors were collected during 280 trips. We used paired t-tests to select the most significant features. Based on the results of the paired t-tests, we chose 10 driving features: velocity, speed, steering wheel, brake, lateral acceleration, headway distance, headway time, accelerating, longitudinal acceleration, and lane position. We set the master data sampling rate at 60 samples per second. While the high sampling rate was useful in other experiments, it made the model computationally expensive, so we compressed the collected data by averaging every 20 samples into one vector. This gave a dataset of driving-related data vectors with 10 features, the Vc. In distracted scenarios, volunteers did 2025 tasks in sum. For each task, we defined four features for driver behavior while interacting with our designed interface, the Vh:
1. The name of the scenario, which shows whether the driver is distracted and identifies the driving context. Based on our previous experiments and the currently collected data, the adverse (Fog, Night) and double adverse (Fog and Night) driving contexts adversely affect drivers' performance.
2. The number of driver errors during the task. We defined an error as touching a wrong icon or not following driving rules while interacting with the interface.
3. Response time, which is the length of the task from the moment that we ask the driver to start the task to the moment that the task is done.
4. The number of navigation steps that the driver needs to take to reach the application and finish the task. All tasks could be completed within four steps or fewer.
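The compression step in this section (averaging every 20 consecutive samples of the 60-samples-per-second stream into one vector) can be sketched as follows. This is an illustrative sketch rather than the code used in the experiment: the function name `downsample_mean` is hypothetical, and dropping a trailing partial window is an assumption, since the text does not say how leftover samples were handled.

```python
from statistics import fmean


def downsample_mean(samples, factor=20):
    """Average every `factor` consecutive feature vectors into one vector.

    `samples` is a list of equal-length feature vectors (lists of floats).
    Assumption: a trailing window with fewer than `factor` samples is dropped.
    """
    if factor <= 0:
        raise ValueError("factor must be positive")
    compressed = []
    for start in range(0, len(samples) - factor + 1, factor):
        window = samples[start:start + factor]
        # zip(*window) iterates over features (columns) of the window
        compressed.append([fmean(column) for column in zip(*window)])
    return compressed


# One second of a toy 2-feature signal sampled 60 times -> 3 averaged vectors
raw = [[float(i), 2.0 * i] for i in range(60)]
compressed = downsample_mean(raw, factor=20)
print(len(compressed), compressed[0])
```

With a factor of 20, the effective rate drops from 60 to 3 vectors per second, which is what makes the later sequence models cheaper to train.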
For each driver data vector (Vh), several car data vectors (Vc) were collected by the simulator. The number of Vc vectors linked to each task depends on the response time of the task. To map Vc vectors to the corresponding Vh, we divided the Vc vectors of each trip based on the length of tasks’ response time of that trip to Vc blocks with different length and map each block to one Vh vector. In (1) $N$ is the number of tasks that were executed in a trip and $\mathrm{per}(\mathrm{task}_i)$ is the approximated percentage of Vc vectors in the trip that is related to $\mathrm{task}_i$.
$$\mathrm{per}(\mathrm{task}_i) = \frac{\mathrm{ResponseTime}_i}{\sum_{k=1}^{N} \mathrm{ResponseTime}_k} \times 100 \tag{1}$$
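As a hedged illustration of (1), the sketch below computes the percentage of a trip's Vc vectors assigned to each task and cuts the trip into consecutive Vc blocks of proportional length. The function names are hypothetical, and the use of cumulative rounded boundaries (with the last block absorbing any remainder) is an assumption; the text only states that the split is approximate.

```python
def task_percentages(response_times):
    """Equation (1): share of the trip's Vc vectors attributed to each task, in percent."""
    total = sum(response_times)
    if total <= 0:
        raise ValueError("response times must sum to a positive value")
    return [100.0 * rt / total for rt in response_times]


def split_into_blocks(vc_vectors, response_times):
    """Cut one trip's Vc vectors into consecutive blocks, one block per task.

    Assumption: block boundaries are placed at rounded cumulative shares,
    and the final block takes whatever vectors remain.
    """
    total = sum(response_times)
    if total <= 0:
        raise ValueError("response times must sum to a positive value")
    n = len(vc_vectors)
    blocks, start, cumulative = [], 0, 0.0
    for i, rt in enumerate(response_times):
        cumulative += rt
        last = i == len(response_times) - 1
        end = n if last else round(n * cumulative / total)
        blocks.append(vc_vectors[start:end])
        start = end
    return blocks


# Toy trip: 100 Vc vectors and three tasks with response times 2 s, 3 s, 5 s
vc = list(range(100))
times = [2.0, 3.0, 5.0]
shares = task_percentages(times)
blocks = split_into_blocks(vc, times)
print(shares, [len(b) for b in blocks])
```

Each resulting block would then be paired with the Vh vector of the corresponding task.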
We used driving data (Vc) as input of a multi-input–single-output bidirectional LSTM network to predict the correlated driver behavior feature vector (Vh). Then we added an attention layer to the model and analyzed its effect on the performance of the model.
5 Methodology
Feedforward neural networks have been used in many car-related applications such as lane departure and sign detection [22–25]. In feedforward neural networks, data travels in one direction from the input layer to the output, and the training process is based on the assumption that input samples are fully independent of each other. Feedforward networks have no notion of order in sequence or time, and only the current sample is considered when tuning the network's weights [36]. We use a feedforward neural network as the baseline model to estimate the driver behavior using driving data. The model is a single-input–single-output network that uses the mean of the driving data during each task as the input and estimates the corresponding driver behavior vector. In some real-life applications such as speech recognition, driving, and image captioning, input samples are not independent and there is valuable information in the sequence, so models with memory such as recurrent neural networks can outperform memoryless models. The LSTM model uses both the current sample and previously observed data in each step of training. It combines these two sources of data to produce the network's output [37]. The difference between feedforward networks and RNNs is that the feedback loop in an RNN provides the previous steps' information to the current step, adding memory to the network, which is preserved in the hidden states of the network. LSTM with attention is a recurrent model that combines a bidirectional LSTM layer and an attention layer. The LSTM attention model uses the weighted effect of all driving data during each task to estimate the driver behavior vector [38]. This model has both memory and attention and considers all driving data during the task, so our assumption is that it would be more accurate than the feedforward model.
If we want to estimate a sequence of driver behavior vectors with sequence length N, we have three options:
• Running the feedforward model N times.
• Running the LSTM attention model N times.
• Using a sequence-to-sequence model.
With the memoryless feedforward model, we lose the dependency between samples, and each output is calculated independently. The LSTM attention model would be more accurate for each driver behavior vector because of its memory and attention, but we still lose the dependency within the output sequence. Sequence-to-sequence learning considers the dependency and information in both the input and the output sequence. In sequence-to-sequence learning using an encoder–decoder model, the encoder encodes the input sequence into a context vector and the decoder uses the context vector to produce the output sequence. As in the multi-input–single-output LSTM model, attention can be added to the model to learn more accurately the weight of each input in the input sequence on the current output [39]. We decided to use the encoder–decoder attention model with equal input and output sequence lengths. Like the feedforward model, this model takes the mean of the driving data during each task as the input and the correlated driver behavior as the output. Considering all driving data, as the LSTM attention model does, would make the sequence-to-sequence model very computationally expensive, so we used the same input and output sequence length. Recent work [40] has investigated different types of machine learning, including unsupervised learning and convolutional neural networks. As the need to address other kinds of problems grows, different types of machine learning solutions are attempted. One prohibitive challenge is the increasing need for computational power. Hence, new optimization approaches are used to reduce the computational demand in order to make the new solutions technically tractable and scalable.
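The equal-length input and output sequences used for the sequence-to-sequence option in this section can be illustrated with a small sketch. It groups per-task inputs (the mean Vc vector of each task) and per-task targets (the Vh vector) into aligned sequences of N steps. The function name `make_sequences` is hypothetical, and the use of consecutive non-overlapping windows is an assumption, since the text does not specify how tasks were grouped into sequences.

```python
def make_sequences(inputs, targets, steps=3):
    """Group aligned per-task inputs and targets into sequences of `steps` tasks.

    Assumption: windows are consecutive and non-overlapping; tasks left over
    at the end that do not fill a whole window are dropped.
    """
    if len(inputs) != len(targets):
        raise ValueError("inputs and targets must have the same length")
    if steps <= 0:
        raise ValueError("steps must be positive")
    x_seq, y_seq = [], []
    for start in range(0, len(inputs) - steps + 1, steps):
        x_seq.append(inputs[start:start + steps])
        y_seq.append(targets[start:start + steps])
    return x_seq, y_seq


# Seven toy tasks: one mean driving vector (2 features here) and one behavior vector each
mean_vc = [[float(i), float(i) + 0.5] for i in range(7)]
vh = [[i % 4, i, 10.0 + i, 2 + i % 3] for i in range(7)]
x_seq, y_seq = make_sequences(mean_vc, vh, steps=3)
print(len(x_seq), len(x_seq[0]), len(y_seq[0]))
```

Each element of `x_seq` and `y_seq` has the same number of steps, matching the equal input and output lengths chosen for the encoder–decoder model.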
6 Models
6.1 Multilayer Neural Network Model
We built and trained a memoryless multilayer (MLP) neural network using both scaled and unscaled data. This feedforward neural network was considered as the baseline to compare with the LSTM attention network and encoder–decoder attention model that both have attention and memory.
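The text does not state how the scaled data were produced. One common choice, shown here purely as an assumption, is min-max scaling in which the per-feature minimum and maximum are taken from the training split only and then reused for the test split. The function names below are hypothetical.

```python
def fit_min_max(rows):
    """Return per-feature minima and maxima computed from the training rows."""
    columns = list(zip(*rows))
    return [min(c) for c in columns], [max(c) for c in columns]


def apply_min_max(rows, mins, maxs):
    """Scale each feature to [0, 1] relative to the fitted training range.

    Constant features (max == min) are mapped to 0.0 to avoid division by zero.
    """
    scaled = []
    for row in rows:
        scaled.append([
            (value - lo) / (hi - lo) if hi > lo else 0.0
            for value, lo, hi in zip(row, mins, maxs)
        ])
    return scaled


# Toy data with two features on very different ranges
train = [[0.0, 10.0], [5.0, 30.0], [10.0, 20.0]]
test = [[20.0, 10.0]]
mins, maxs = fit_min_max(train)
train_scaled = apply_min_max(train, mins, maxs)
test_scaled = apply_min_max(test, mins, maxs)
print(train_scaled, test_scaled)
```

Fitting the range on the training split alone keeps test information out of training; as the toy test row shows, unseen values can then fall outside [0, 1]. Note that error values computed on scaled targets are not directly comparable with those computed on unscaled targets.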
Fig. 2 Attention mechanism
6.2 Bidirectional LSTM Network with Attention Layer
Attention is one of the fundamental parts of cognition and intelligence in humans [41, 42]. It helps reduce the amount of information processing and complexity [43, 44]. We can loosely define attention as directing some human senses (like vision), and hence the mind, to a specific source or object rather than scanning the entire input space [45]. This is an essential human perception component that helps increase processing power while reducing demand on resources [46]. In the neural network area, attention is primarily used as a memory mechanism that determines which part of the input sequence has more effect on the final output [47, 48]. The attention mechanism considers the weighted effect of each input on the model’s output instead of producing one context vector from all samples of the input sequence (Fig. 2) [49, 50]. We used the attention mechanism to estimate the driver behavior using a sequence of driving data vectors with 10 features and considering the weighted effect of each input driving data on the driver behavior. Our assumption is that using attention mechanism decreases the model’s training and test error and enhances the training process.
6.3 Encoder–Decoder Attention
The encoder–decoder model is a multi-input–multi-output neural network architecture. The encoder compresses the input sequence to a representative context vector and the decoder uses this vector to produce the output sequence. In the encoder–decoder with attention network, the encoder provides one context vector that is
filtered specifically for each output in the output sequence. Equation (2) shows the output of the encoder in the encoder–decoder without attention. In (2), $h$ is the output vector of the encoder, which contains information from all samples in the input sequence. In attention networks, the encoder produces a sequence of vectors, one per input sample, instead of a single vector; (3) shows the encoder output in the attention network [49].
$$h = \mathrm{Encoder}(x_1, x_2, x_3, \ldots, x_T) \tag{2}$$
$$[h_1, h_2, h_3, \ldots, h_T] = \mathrm{Encoder}(x_1, x_2, x_3, \ldots, x_T) \tag{3}$$
The decoder produces one output at a time, and the model scores how well each encoded input matches the current output. Equation (4) shows the score of encoded input $i$ at step $t$. In (4), $S_{t-1}$ is the decoder output from the previous step and $h_i$ is the result of encoding input $x_i$. In the next step, the scores are normalized using (5), which is a softmax function. The context vector for each time step is calculated using (6).
$$e_{ti} = a(S_{t-1}, h_i) \tag{4}$$
$$a_{ti} = \frac{\exp(e_{ti})}{\sum_{j=1}^{T} \exp(e_{tj})} \tag{5}$$
$$C_t = \sum_{j=1}^{T} a_{tj} \, h_j \tag{6}$$
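The three steps in (4)–(6) can be traced with a small numerical sketch. The scoring function $a$ is not specified in the text, so a plain dot product is used here as an assumption; the function names are hypothetical, and the code illustrates the arithmetic only, not the trained layer used in the experiments.

```python
import math


def dot(u, v):
    """Dot-product score, used here as an assumed stand-in for a(S, h) in (4)."""
    return sum(a * b for a, b in zip(u, v))


def attention_context(prev_state, encoder_states, score=dot):
    """Return attention weights (5) and the context vector (6) for one decoder step."""
    scores = [score(prev_state, h) for h in encoder_states]       # equation (4)
    peak = max(scores)                                             # for numerical stability
    exps = [math.exp(e - peak) for e in scores]
    total = sum(exps)
    weights = [x / total for x in exps]                            # equation (5), softmax
    dim = len(encoder_states[0])
    context = [
        sum(w * h[d] for w, h in zip(weights, encoder_states))    # equation (6)
        for d in range(dim)
    ]
    return weights, context


# Three encoded inputs h_1..h_3 and a previous decoder state S_{t-1}
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights, context = attention_context([1.0, 0.0], encoder_states)
print([round(w, 3) for w in weights], [round(c, 3) for c in context])
```

Subtracting the largest score before exponentiating leaves the softmax unchanged but avoids overflow for large scores. The weights always sum to one, so the context vector is a weighted average of the encoder outputs.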
Figure 3 shows the encoder–decoder with attention model. The bidirectional LSTM layer provides access to the samples before and after the current sample in each step of training. The model uses the “SeqSelfAttention” layer for Keras (provided by the keras-self-attention package). This layer calculates the weighted effect of each input sample in the input sequence on each output sample of the output sequence. We considered equal lengths for the input and output sequences, as explained earlier. Three different models were built and trained: a model with three input and three output steps, a model with four input and four output steps, and a model with five input and five output steps, as elaborated below.
7 Results
7.1 MLP Results
We built an MLP and trained it with both scaled and unscaled data. In this model, 80% of data were used for training and 20% of them for testing. We tried different numbers of layers in the range of 2–6 and a variety of hidden neurons in the range
Fig. 3 Encoder–decoder with attention
50–500 for this neural network. Table 1 shows some of the best results achieved with scaled data, and Table 2 shows the results achieved with unscaled data. These results show a large difference between train and test error, which means that the model overfits with both scaled and unscaled data and does not generalize well in most cases.
Table 1 MLP with scaled data
Hidden layers  Neurons per layer  Train MAE  Test MAE
Two            50                 0.11       0.27
Two            150                0.082      0.23
Three          50                 0.096      0.24
Three          150                0.056      0.32
Four           50                 0.089      0.24
Four           150                0.012      0.2
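The overfitting noted for the MLP can be made concrete by computing the gap between test and train MAE for each row of Table 1. The sketch below only restates the table's numbers; the function name is hypothetical, and using the difference of the two errors as the generalization gap is our choice of illustration, not a measure defined in the text.

```python
# (hidden layers, neurons per layer, train MAE, test MAE), copied from Table 1
TABLE_1 = [
    ("Two", 50, 0.11, 0.27),
    ("Two", 150, 0.082, 0.23),
    ("Three", 50, 0.096, 0.24),
    ("Three", 150, 0.056, 0.32),
    ("Four", 50, 0.089, 0.24),
    ("Four", 150, 0.012, 0.2),
]


def generalization_gaps(rows):
    """Return (layers, neurons, test MAE - train MAE) for each configuration."""
    return [
        (layers, neurons, round(test - train, 3))
        for layers, neurons, train, test in rows
    ]


gaps = generalization_gaps(TABLE_1)
worst = max(gaps, key=lambda g: g[2])
for layers, neurons, gap in gaps:
    print(f"{layers:>5} layers, {neurons:>3} neurons: gap = {gap}")
print("largest gap:", worst)
```

In every row the test error is well above the train error, which is the pattern the text describes as overfitting.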
Table 2 MLP with unscaled data
Hidden layers  Neurons per layer  Train MAE  Test MAE
Two            150                0.54       1.41
Two            300                0.39       1.51
Three          150                0.33       1.48
Three          300                0.31       1.34
Four           150                0.4        1.46
Four           300                0.17       1.39
7.2 LSTM Attention Results
In the next step, we trained an LSTM network with attention. We used Adam as the model's optimizer and mean absolute error (MAE) as the error metric of the model. We tried a wide range of LSTM neurons, from 10 to 500. In this model, 80% of the dataset was used as the training set and 20% as the testing set. Table 3 shows some results with unscaled data. For unscaled data, the best result was achieved with 20 neurons: 0.85 training MAE and 0.96 test MAE. This model achieved a lower test error than the MLP model while using a smaller number of hidden neurons. The best MLP results with unscaled data required three or four hidden layers with 300 neurons in each layer (Table 2), so the MLP model is less accurate and more computationally expensive than the LSTM attention model. Besides, one LSTM layer plus an attention layer performed better than a large MLP model with four fully connected layers. In addition, the training process of the MLP model took around 400–600 epochs, while the LSTM converged in around 100–200 epochs. Table 4 shows the best results achieved with scaled data. The model with 40 neurons reached the minimum test error, which is lower than in all MLP cases with scaled data except one, the four-hidden-layer model with 150 neurons. In general, the LSTM attention model with scaled data generalized better than the MLP in all cases and achieved better performance with a smaller number of hidden neurons and a smaller network.
Table 3 Bidirectional LSTM network with attention layer with unscaled data
LSTM neurons  Train MAE  Test MAE
20            0.85       0.96
30            0.85       0.99
40            0.85       1.01
100           0.85       1
200           0.85       0.99
300           0.96       0.99
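The mean absolute error reported in these tables can be sketched as follows for vector-valued outputs such as the driver behavior vector. The function name is hypothetical, and averaging over all elements of all vectors is an assumption about how the per-feature errors were aggregated.

```python
def mean_absolute_error(y_true, y_pred):
    """Average absolute difference over every element of every output vector."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must contain the same number of vectors")
    total, count = 0.0, 0
    for true_vec, pred_vec in zip(y_true, y_pred):
        if len(true_vec) != len(pred_vec):
            raise ValueError("vectors must have the same length")
        for t, p in zip(true_vec, pred_vec):
            total += abs(t - p)
            count += 1
    if count == 0:
        raise ValueError("no values to compare")
    return total / count


# Two toy target vectors and their predictions
y_true = [[1.0, 2.0], [3.0, 4.0]]
y_pred = [[1.5, 2.0], [2.0, 4.5]]
mae = mean_absolute_error(y_true, y_pred)
print(mae)
```

Because MAE is expressed in the units of the targets, values obtained on scaled and unscaled data should be compared only within the same setting.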
Table 4 Bidirectional LSTM network with attention layer with scaled data
LSTM neurons  Train MAE  Test MAE
30            0.23       0.24
40            0.22       0.22
60            0.22       0.23
100           0.24       0.25
150           0.22       0.23
200           0.23       0.23
7.3 Encoder–Decoder Attention Model
As we mentioned earlier, to obtain a sequence of driver behavior data vectors we can use a sequence-to-sequence model instead of running a multi-input–single-output model multiple times. We chose the encoder–decoder model as a suitable sequence-to-sequence model. We built and trained encoder–decoder attention models, including three-step, four-step, and five-step models, with both scaled and unscaled data. In these models, 80% of the data was used as the training dataset and 20% was used for testing the model. Different combinations of batch size, activation function, and number of LSTM neurons were tested. Finally, we chose Adam as the optimizer and 100 as the batch size. We tried a range of LSTM neurons from 20 to 500. Besides, we tried different lengths of input and output sequences, from 2 to 6. Beyond the four-step model, increasing the length of the sequence did not have a positive effect on the model's performance. Table 5 shows some of the best results achieved with unscaled data. Mean absolute error (MAE) was used as the loss function of this model. The three-step model reached the minimum error, which is 1.5 for both train and test mean absolute error. The four-step and the five-step models have almost the same performance. The three-step model with unscaled data and 100 LSTM neurons showed the best performance. Figure 4 shows this model's mean absolute error. Table 6 shows the results achieved by the three models with scaled data. The three-step model with 250 LSTM neurons reached the minimum error (Fig. 5).
Table 5 Encoder–decoder attention with unscaled data
LSTM neurons  Sequence length  Train MAE  Test MAE
100           3                1.5        1.5
250           3                1.51       1.52
100           4                1.5        1.52
150           4                1.52       1.56
50            5                1.52       1.52
250           5                1.53       1.56
Fig. 4 The mean absolute error of a three-step model with 100 LSTM neurons and unscaled data
Table 6 Encoder–decoder attention with scaled data
LSTM neurons  Sequence length  Train MAE  Test MAE
150           3                0.1683     0.16
250           3                0.1641     0.16
100           4                0.1658     0.16
300           4                0.1656     0.16
150           5                0.1655     0.16
400           5                0.1646     0.17
Fig. 5 The mean absolute error of a three-step model with 250 LSTM neurons and scaled data
The test error of the encoder–decoder attention model with unscaled data is close to that of the MLP model with unscaled data, but this model generalizes much better than the MLP model and the network is smaller. Besides, it converges in around 50 epochs on average, which is much faster than the MLP model, which needs around 400 training epochs on average. Moreover, this model estimates multiple driver behavior vectors in one run. With unscaled data, the error of this model is higher than that of the LSTM attention model, but its generalization is better and the number of input samples is much smaller, since this model considers one driving data vector for each driver behavior vector; it is therefore computationally less expensive.
The encoder–decoder attention model with scaled data outperformed both the MLP model and the LSTM attention model. This model has memory and attention, similar to the LSTM attention model. Besides, it considers the dependency between samples of the output sequence, which is not possible if we run a multi-input–single-output model multiple times to obtain a sequence of driver behavior vectors. It is also computationally less expensive than the LSTM attention model, since it considers one input vector corresponding to each output, similar to the MLP model. The minimum test error of this model with scaled data is 0.06 less than the minimum test error of the LSTM attention model and 0.04 less than the minimum test error of the MLP model. The encoder–decoder attention model converges with less error and takes less time than the two other models. The average number of epochs for this model with unscaled and scaled data was between 50 and 100, which is half of the average for the LSTM attention model and one-fourth of that for the MLP. Besides, the model uses a smaller number of layers than the MLP model and a smaller number of input samples than the LSTM attention model, making it computationally less expensive. It also estimates multiple driver behavior data vectors in each run.
8 Conclusion
Using machine learning methods to monitor driver behavior and detect the inattention that can lead to distraction is an emerging solution for detecting and reducing driver distraction. Different deep learning methods such as CNN, LSTM, and RNN networks have been used in several car safety applications. Some of these methods have memory, so they can extract and learn information in a sequence of data. There is some information in the sequence of driving data that cannot be extracted by processing the samples manually. Driving data samples are not independent of each other, so methods that have memory and attention, such as recurrent models, are a better choice for higher intelligence and hence more reliable car safety applications. These methods utilize different mechanisms to perceive the dependency between samples and extract the latent information in the sequence. We chose the MLP network, a simple memoryless deep neural network, as the baseline for our model. Then we trained an LSTM attention model that has memory and attention. This model outperforms the MLP model with both scaled and unscaled data. The model trained at least two times faster than the MLP model and achieved better
performance with fewer hidden neurons, a smaller network, and a smaller number of training epochs. To obtain a sequence of driver behavior vectors, we have two options: running a multi-input–single-output model multiple times or using a sequence-to-sequence model. We built and trained an encoder–decoder attention model with both scaled and unscaled data to obtain a sequence of driver behavior data vectors. This model outperforms the MLP model with both scaled and unscaled data. Besides, this model outperformed the LSTM attention model with scaled data. The encoder–decoder attention model trained at least two times faster than the LSTM attention model and four times faster than the MLP model. Besides, in each run it estimates multiple driver behavior vectors, and it had the best generalization and the minimum difference between train and test error. Our work shows that this would be a viable and scalable option for deep neural network models that work in real-life complex driving contexts without the need to use intrusive devices. It also provides an objective measurement of the added advantages of using attention networks to reliably detect driver behavior.
References 1. B. Darrow, Distracted driving is now an epidemic in the U.S., Fortune (2016). http://fortune. com/2016/09/14/distracted-driving-epidemic/ 2. National Center for Statistics and Analysis, Distracted driving in fatal crashes, 2017, (Traffic Safety Facts Research Note, Report No. DOT HS 812 700) (Washington, DC, National Highway Traffic Safety Administration, 2019) 3. N. Chaudhary, J. Connolly, J. Tison, M. Solomon, K. Elliott, Evaluation of the NHTSA distracted driving high-visibility enforcement demonstration projects in California and Delaware. (Report No. DOT HS 812 108) (Washington, DC, National Highway Traffic Safety Administration, 2015) 4. National Center for Statistics and Analysis, Distracted driving in fatal crashes, 2017, (Traffic safety facts research Note, Report No. DOT HS 812 700), (Washington, DC: National Highway Traffic Safety Administration, 2019) 5. S. Monjezi Kouchak, A. Gaffar, Driver distraction detection using deep neural network, in The Fifth International Conference on Machine Learning, Optimization, and Data Science (Siena, Tuscany, Italy, 2019) 6. J. Lee, Dynamics of driver distraction: the process of engaging and disengaging. Assoc. Adv. Autom. Med. 58, 24–35 (2014) 7. T. Hirayama, K. Mase, K. Takeda, Analysis of temporal relationships between eye gaze and peripheral vehicle behavior for detecting driver distraction. Hindawi Publ. Corp. Int. J. Veh. Technol. 2013, 8 (2013) 8. National Highway Traffic Safety Administration. Blueprint for Ending Distracted Driving. Washington, DC: U.S. Department of Transportation. National Highway Traffic Safety Administration, DOT HS 811 629 (2012) 9. T.B. Sheridan, R. Parasuraman, Human-automation interaction. reviews of human factors and ergonomics, vol. 1, pp. 89–129 (2015). https://doi.org/10.1518/155723405783703082 10. U. Hamid, F. Zakuan, K. Zulkepli, M. ZulfaqarAzmi, H. Zamzuri, M. Rahman, M. 
Zakaria, Autonomous Emergency Braking System with Potential Field Risk Assessment for Frontal Collision Mitigation (IEEE ICSPC, Malaysia, 2017) 11. L. Li, D. Wen, N. Zheng, L. Shen, Cognitive cars: a new frontier for ADAS research. IEEE Trans. Intell. Transp. Syst. 13 (2012)
140
S. M. Kouchak and A. Gaffar
12. S. Monjezi Kouchak, A. Gaffar, Estimating the driver status using long short term memory, in Machine Learning and Knowledge Extraction, Third IFIP TC 5, TC 12, WG 8.4, WG 8.9, WG 12.9 International Cross-Domain Conference, CD-MAKE 2019 (2019). https://doi.org/10. 1007/978-3-030-29726-8_5 13. P. Koopman, M. Wagner, Autonomous vehicle safety: an interdisciplinary challenge. IEEE Intell. Transp. Syst. Mag. 9, 90–96 (2017) 14. M. Benmimoun, A. Pütz, A. Zlocki, L. Eckstein, euroFOT: field operational test and impact assessment of advanced driver assistance systems: final results, in SAE-China, FISITA (eds) Proceedings of the FISITA 2012 World Automotive Congress. Lecture Notes in Electrical Engineering, vol. 197 (Springer, Berlin, Heidelberg, 2013) 15. S. Monjezi Kouchak, A. Gaffar, Determinism in future cars: why autonomous trucks are easier to design, in IEEE Advanced and Trusted Computing (ATC 2017) (San Francisco Bay Area, USA, 2017) 16. S. Kaplan, M.A. Guvensan, A.G. Yavuz, Y. Karalurt, Driver behavior analysis for safe driving: a survey, IEEE Trans. Intell. Transp. Syst. 16, 3017–3032 (2015) 17. A. Aksjonov, P. Nedoma, V. Vodovozov, E. Petlenkov, M. Herrmann, Detection and evaluation of driver distraction using machine learning and fuzzy logic. IEEE Trans. Intell. Transp. Syst. 1–12 (2018). https://doi.org/10.1109/tits.2018.2857222 18. R. Harb, X. Yan, E. Radwan, X. Su, Exploring precrash maneuvers using classification trees and random forests, Accid. Anal. Prev. 41, 98–107 (2009) 19. A. Alvarez, F. Garcia, J. Naranjo, J. Anaya, F. Jimenez, Modeling the driving behavior of electric vehicles using smartphones and neural networks. IEEE Intell. Transp. Syst. Mag. 6, 44–53 (2014) 20. J. Morton, T. Wheeler, M. Kochenderfer, Analysis of recurrent neural networks for probabilistic modeling of driver behavior. IEEE Trans. Intell. Transp. Syst. 18, 1289–1298 (2017) 21. A. Sathyanarayana, P. Boyraz, J. 
Hansen, Driver behavior analysis and route recognition by Hidden Markov models, IEEE International Conference on Vehicular Electronics and Safety (2008) 22. J. Li, X. Mei, D. Prokhorov, D. Tao, Deep neural network for structural prediction and lane detection in traffic scene. IEEE Trans. Neural Netw. Learn. Syst. 28, 14 (2017) 23. S. Monjezi Kouchak, A. Gaffar, Non-intrusive distraction pattern detection using behavior triangulation method, in 4th Annual Conference on Computational Science and Computational Intelligence CSCI-ISAI (USA, 2017) 24. S. Su, B. Nugraha, Fahmizal, Towards self-driving car using convolutional neural network and road lane detector, in 2017 2nd International Conference on Automation, Cognitive Science, Optics, Micro Electro-Mechanical System, and Information Technology (ICACOMIT) (Jakarta, Indonesia, 2017), p. 5 25. S. Hung, I. Choi, Y. Kim, Real-time categorization of driver’s gaze zone using the deep learning techniques, in 2016 International Conference on Big Data and Smart Computing (BigComp) (2016), pp. 143–148 26. A. Koesdwiady, S. Bedavi, C. Ou, F. Karray, End-to-end deep learning for driver distraction recognition. Springer International Publishing AG 2017 (2017), p. 8 27. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning (The MIT press, 2016), ISBN: 9780262035613 28. J. Schmidhuber, Deep learning in neural networks: an overview, vol. 61, (Elsevier, 2015), pp. 85–117 29. M. Wöllmer, C. Blaschke, T. Schindl, B. Schuller, B. Färber, S. Mayer, B. Trefflich, Online driver distraction detection using long short-term memory. IEEE Trans. Intell. Transp. Syst. 2(2), 574–582 (2011) 30. K. Xu, J.L. Bay, R. Kirosy, K. Cho, A. Courville, R. Salakhutdinovy, R.S. Zemely, Y. Bengio, Show, attend and tell: neural image caption generation with visual attention, in 32 nd International Conference on Machine Learning (Lille, France, 2015) 31. T. Xiao, Y. Xu, K. Yang, J. Zhang, Y. Peng, Z. 
Zhang, The application of two-level attention models in deep convolutional neural network for fine-grained image classification, in The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2015), pp. 842–850
32. P. Huang, F. Liu, S. Shiang, J. Oh, C. Dyer, Attention-based multimodal neural machine translation, in Proceedings of the First Conference on Machine Translation, Shared Task Papers, vol. 2, (Berlin, Germany, 2016), pp. 639–645 33. Y. Lv, Y. Duan, W. Kang, Z. Li, F. Wang, Traffic flow prediction with big data: a deep learning approach. IEEE Trans. Intell. Transp. Syst. 16, 865–873 (2015) 34. K. Saleh, M. Hossny, S. Nahavandi, Driving behavior classification based on sensor data fusion using LSTM recurrent neural networks, in IEEE 20th International Conference on Intelligent Transportation Systems (ITSC) (2017) 35. A. Gaffar, S. Monjezi Kouchak, Minimalist design: an optimized solution for intelligent interactive infotainment systems, in IEEE IntelliSys, the International Conference on Intelligent Systems and Artificial Intelligence (London, UK, 2017) 36. C. Bishop, Pattern Recognition and Machine Learning (Springer). ISBN-13: 978-0387310732 37. M. Magic, Action recognition using Python and recurrent neural network, First edn. (2019). ISBN: 978-1798429044 38. D. Mandic, J. Chambers, Recurrent neural networks for prediction: learning algorithms, architectures and stability, First edn. (Wiley, 2001). ISBN: 978-0471495178 39. J. Rogerson, Theory, Concepts and Methods of Recurrent Neural Networks and Soft Computing (2015). ISBN-13: 978-1632404930 40. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in Deep Learning (Springer, 2020) 41. A. Gaffar, E. M. Darwish, A. Tridane, Structuring heterogeneous big data for scalability and accuracy. Int. J. Digit. Inf. Wirel. Commun. 4, 10–23 (2014) 42. A. Gaffar, H. Javahery, A. Seffah, D. Sinnig, A pattern framework for eliciting and delivering UCD knowledge and practices, in Proceedings of the Tenth International Conference on Human-Computer Interaction (2003), pp. 108–112 43. A. 
Gaffar, Enumerating mobile enterprise complexity 21 complexity factors to enhance the design process, in Proceedings of the 2009 Conference of the Center for Advanced Studies on Collaborative Research (2009), pp. 270–282 44. A. Gaffar, The 7C’s: an iterative process for generating pattern components, in 11th International Conference on Human-Computer Interaction (2005) 45. J. Bermudez, Cognitive Science: An Introduction to the Science of the Mind, 2nd edn. (2014). 978-1107653351 46. B. Garrett, G. Hough, Brain & Behavior: an Introduction to Behavioral Neuroscience, 5th edn. (SAGE). ISBN: 978-1506349206 47. Y. Wang, M. Huang, L. Zhao, X. Zhu, Attention-based LSTM for aspect-level sentiment classification, in Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing (2016), pp. 606–615 48. I. Sutskever, O. Vinyals, Q. Le, Sequence to sequence learning with neural networks, in Advances in Neural Information Processing Systems 27 (NIPS 2014) (2014) 49. S. Frintrop, E. Rome, H. Christenson, Computational visual attention systems and their cognitive foundations: a survey, ACM Trans. Appl. Percept. (TAP) 7 (2010). https://doi.org/10.1145/ 1658349.1658355 50. Z. Yang, D. Yang, C. Dyer, X. He, A Smola, E. Hovy, Hierarchical attention networks for document classification, in NAACL-HLT 2016 (San Diego, California, 2016), pp. 1480–1489
Exploiting Spatio-Temporal Correlation in RF Data Using Deep Learning Debashri Roy, Tathagata Mukherjee, and Eduardo Pasiliao
Abstract Wireless services and applications have become an integral part of our lives. We depend on wireless technologies not only for our smartphones but also for applications such as surveillance, navigation, jamming, anti-jamming, and radar, to name a few. These recent advances in wireless technologies in radio frequency (RF) environments have warranted more autonomous deployments of wireless systems. With such large-scale dependence on the RF spectrum, it becomes imperative to understand the ambient signal characteristics for optimal deployment of wireless infrastructure and efficient resource provisioning. To make the best use of such radio resources in both the spatial and time domains, past and current knowledge of the RF signals is important. Although sensing mechanisms can be leveraged to assess the current environment, learning techniques are typically used to analyze past observations and to predict future events in a given RF environment. Machine learning (ML) techniques, having already proven useful in various domains, are also being sought for characterizing and understanding the RF environment. Some of the goals of learning techniques in the RF domain are transmitter or emitter fingerprinting, emitter localization, modulation recognition, feature learning, attention and saliency, autonomous RF sensor configuration, and waveform synthesis. Moreover, in large-scale autonomous deployments of wireless communication networks, the signals received from one component play a crucial role in the decision-making process of other components. In order to efficiently implement such systems, each component of the network should be uniquely identifiable. ML techniques that include recurrent structures have shown promise in creating such autonomous deployments using the idea of radio frequency machine learning (RFML). Deep learning (DL) techniques, with their ability to automatically learn features, can be used for characterization and recognition of different RF properties by automatically exploiting the inherent features in the signal data. In this chapter, we present an application of such deep learning techniques to the task of RF transmitter fingerprinting. The first section concentrates on the application areas in the field of RF where deep learning can be leveraged for futuristic autonomous deployments. Section 2 discusses different deep learning approaches for the task of transmitter fingerprinting as well as the significance of leveraging recurrent structures through the use of recurrent neural network (RNN) models. Once we have established the basic knowledge and motivation, we dive deep into the application of deep learning for transmitter fingerprinting. Hence, a transmitter fingerprinting technique for radio device identification using recurrent structures, which exploits the spatio-temporal properties of the received radio signal, is discussed in Sects. 3 and 4. We present three types of recurrent neural networks (RNNs) using different types of cell models for that task: (i) long short-term memory (LSTM), (ii) gated recurrent unit (GRU), and (iii) convolutional long short-term memory (ConvLSTM). The proposed models are also validated with real data and evaluated using a framework implemented in Python. Section 5 describes the testbed setup and experimental design. The experimental results, computational complexity analysis, and comparison with the state of the art are discussed in Sect. 6. The last section summarizes the chapter.
Keywords RF fingerprinting · Recurrent neural network · Supervised learning · Software-defined radios
D. Roy (B) Computer Science, University of Central Florida, Orlando, FL 32826, USA
T. Mukherjee Computer Science, University of Alabama, Huntsville, AL 35899, USA
E. Pasiliao Munitions Directorate, Air Force Research Laboratory, Eglin AFB, Valparaiso, FL 32542, USA
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021 M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_7
1 Applications of Deep Learning in the RF Domain
We are living in a world where the distances are shrinking every day, thanks to an explosion in the use of connected devices. The ubiquitous usage of wirelessly connected Internet-of-Things (IoT) devices [56], along with the deployment of wireless autonomous systems, has ushered in a new era of industrial-scale deployment of RF devices. This prevalence of large-scale peer-to-peer communication and the nature of the underlying ubiquitous network bring forth the challenge of accurately identifying an RF transmitter. Every device that is part of a large network needs to be able to identify its peers with high confidence in order to set up secure communication channels. One of the ways in which this is done is through the interchange of “keys” [42] for host identification. However, such schemes are prone to breaches by malicious agents [50] because the actual implementations of such systems are often not cryptographically sound. In order to get around the problem of faulty implementations, one can use the transmitter’s intrinsic characteristics to create a “fingerprint” that can be used by a transmitter identification system. Every transmitter, no matter how similar, has intrinsic characteristics because of the imperfections in its underlying components such as amplifiers, filters, and frequency mixers, as well as the physical properties of the transmitting antenna; these characteristics are unique to a specific transmitter. The inaccuracies present in the manufacturing process and the idiosyncrasies of the hardware circuitry also contribute to the spatial and temporal characteristics of the signal transmitted through a particular device.
1.1 Transmitter Identification
This inherent heterogeneity can be exploited to create unique identifiers for the transmitters. One such property is the imbalance between the in-phase (I) and quadrature (Q) components of the signal (I/Q data). However, because of the sheer number of transmitters involved, manually “fingerprinting” each and every transmitter is not a feasible task [6]. Thus, in order to build such a system, there needs to be an “automatic” method of extracting the transmitter characteristics and using the resulting “fingerprint” for the differentiation process. One way of achieving this is by learning the representation of the transmitter in an appropriate “feature space” that has enough discriminating capability to differentiate between “apparently identical” transmitters. Hence, the concept of “transmitter fingerprinting” came to light for the task of transmitter classification or identification. For the rest of the chapter we use the phrases “transmitter classification” and “transmitter recognition” interchangeably.
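To make the idea of I/Q imbalance as a device-specific trait more concrete, the following is a minimal sketch under stated assumptions. It uses a simplified textbook model in which the Q branch has a relative gain error `gain` and a phase error `phase_rad`; the function name `apply_iq_imbalance` and the numeric values for the two hypothetical devices are illustrative only and are not taken from the chapter or from measurements.

```python
import numpy as np

def apply_iq_imbalance(x, gain, phase_rad):
    """Apply a simplified (assumed) transmitter I/Q imbalance model.

    The Q branch is scaled by `gain` and its carrier is offset by `phase_rad`,
    which gives the equivalent baseband components:
        I' = I - gain * Q * sin(phase)
        Q' = gain * Q * cos(phase)
    """
    i, q = x.real, x.imag
    i_out = i - gain * q * np.sin(phase_rad)
    q_out = gain * q * np.cos(phase_rad)
    return i_out + 1j * q_out

rng = np.random.default_rng(0)
# Ideal unit-energy QPSK-like points used as the common baseband input.
ideal = (rng.choice([-1.0, 1.0], 1024) + 1j * rng.choice([-1.0, 1.0], 1024)) / np.sqrt(2)

# Two hypothetical devices with slightly different hardware imperfections.
device_a = apply_iq_imbalance(ideal, gain=1.02, phase_rad=np.deg2rad(1.0))
device_b = apply_iq_imbalance(ideal, gain=0.97, phase_rad=np.deg2rad(-2.0))

# The same data leaves each device with a slightly different constellation.
print(np.abs(device_a - device_b).max())
```

With `gain=1.0` and `phase_rad=0.0` the function returns its input unchanged, which is a quick sanity check on the model. Real transmitters differ in many more ways than this two-parameter model captures.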
1.2 Modulation Recognition
The radio frequency domain consists of a vast range of electromagnetic frequencies between 3 Hz and 3 THz [19]. Different wireless communication types, such as local area networks, broadcast radio, wide area networks, voice radios, radar, and several others, use different bands of radio frequencies depending on their requirements [52]. They also use different modulation schemes depending on the frequency band and application area. Modulation is the technique of encoding the data to be transmitted onto an RF carrier. Modulation recognition is the task of identifying the modulation type of the received signal. Modulation recognition can therefore be leveraged to identify the type of communication the transmitter is using, even without knowledge of the frequency band in use. Deep convolutional networks have shown great promise in successfully classifying different types of modulation schemes, details of which are presented in the next section.
2 Approaches for Implementing Deep Learning in RF Domain
In this section, we present the evolution of research on transmitter identification and modulation recognition from traditional approaches to deep learning-based ones. First, we discuss a few of the traditional methods for the transmitter identification task. These traditional methods use manual feature engineering and leverage different radio attributes, such as transients or spurious modulations, to create discriminating feature sets. A transient signal is transmitted when a transmitter is powered up or powered down. During this short period (typically a few microseconds), capacitive loads charge or discharge. Different classification approaches using transient-based recognition were proposed in [18, 45, 51]. In [51], the authors proposed a genetic algorithm-based solution for transmitter classification based on transients. A multifractal segmentation technique was proposed in [45] using the same concept of transients. Another transient-based transmitter classification was proposed in [18] using a k-nearest neighbor discriminatory classifier. However, these traditional approaches have extra overhead due to the feature extraction step; furthermore, the quality of the solution is constrained by the type of feature selected, and therefore by the knowledge of the expert making that decision. To avoid such overheads, deep learning [55]-based methods can provide an efficient and automatic way of learning and characterizing the feature space within the inherent properties of transmitters. They are able to learn and analyze the inherent properties of large deployments and use them to predict and characterize the associated parameters for the task of automatic feature learning for classification (or regression). Moreover, the task of classification is equivalent to learning the decision boundary, and neural networks are a natural candidate for such a learning algorithm. To that end, neural networks have previously been used for modulation recognition and transmitter identification [31, 33, 40] and are particularly attractive since they can generate accurate models without a priori knowledge of the data distribution. Next, we review existing neural network-based methods for different types of applications in the RF domain. We divide our discussion into three parts: (i) deep neural networks, (ii) convolutional neural networks, and (iii) recurrent neural networks.
2.1 Deep Neural Networks
By deep neural networks (DNNs) we mean multi-layer feedforward networks trained with backpropagation. DNNs have revolutionized the field of artificial intelligence in the last few years. The problems tackled by DNNs range over computer vision, natural language processing, speech processing, and so on. They have been demonstrated to perform better than humans on some of these problems. They have also been shown to be effective for automatically learning discriminating features from data for various tasks [13]. With a proper choice of the neural network architecture and associated parameters, they can compute arbitrarily good function approximations [21]. In general, the use of deep learning in the RF domain has been limited in the past, with only a few applications in recent times [30]. However, DNNs have already proven their applicability for the task of transmitter identification. A DNN-based approach was proposed in [36], where the authors classified among eight different transmitters with ∼95% accuracy. For modulation recognition, a real-life demonstration of modulation classification using a smartphone app was presented in [49]. The authors designed the mobile application DeepRadio™ to distinguish between six different modulation schemes by implementing a four-layer DNN.
2.2 Deep Convolutional Neural Networks
Fully connected DNNs are like standard neural networks but with many more hidden layers. However, these networks can be augmented with convolutional layers for faster training and for enabling the network to learn more compact and meaningful representations. Deep convolutional neural networks (DCNNs) have been shown to be effective for several different tasks in communication. There have been quite a few attempts at using DCNNs for learning inherent RF parameters for both modulation recognition and transmitter identification. In [29], the authors demonstrated the use of convolutional neural networks for modulation detection. Apart from the results, an interesting aspect of the work is the way I/Q values were used as input to the neural network. More precisely, given N I/Q values, the authors used a vector of size 2N as an input to the neural network, effectively using the I and Q components as a tuple representing a point in the complex plane. This representation proves to be useful for using the I/Q data in different learning models. However, the authors considered only the spatial correlations to be exploited by the DCNN implementation, for synthetic RF data [27]. Further research was done in [31] with the same concept using real RF data. Further investigations on modulation recognition using DCNNs were presented in [47]. As for the transmitter identification task, an application of transmitter fingerprinting was presented in [14] for detecting new types of transmitters using an existing trained neural network model. Similarly, in [44], Sankhe et al. presented a DCNN-based large-scale transmitter identification technique with 80–90% accuracy. However, they could only achieve competitive accuracy by factoring in controlled impairments at the transmitter. It is to be noted that all prior works have exploited only the spatial correlation of the signal data, even though a continuous signal can be represented as a time series having both temporal and spatial properties [54].
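The 2N-element input representation described above can be sketched in a few lines. This is only an illustration: the interleaved ordering [I1, Q1, I2, Q2, ...] and the helper names `iq_to_vector` and `vector_to_iq` are assumptions, since the text specifies only that N I/Q values become a vector of size 2N.

```python
import numpy as np

def iq_to_vector(samples):
    """Flatten N complex I/Q samples into a real vector of length 2N.

    Assumed layout (not specified in the text): [I1, Q1, I2, Q2, ...].
    """
    samples = np.asarray(samples, dtype=np.complex64)
    return np.stack([samples.real, samples.imag], axis=-1).reshape(-1)

def vector_to_iq(vector):
    """Inverse of iq_to_vector: rebuild the complex samples from (I, Q) pairs."""
    pairs = np.asarray(vector, dtype=np.float32).reshape(-1, 2)
    return pairs[:, 0] + 1j * pairs[:, 1]

example = np.array([1 + 2j, 3 - 4j, -5 + 6j])
vec = iq_to_vector(example)
print(vec)  # [ 1.  2.  3. -4. -5.  6.]
```

Keeping each (I, Q) pair adjacent preserves the interpretation of every pair as one point in the complex plane; a 2 × N layout would carry the same information.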
2.3 Recurrent Neural Networks
Recurrent neural networks (RNNs) are capable of making predictions with time series data. They have been used extensively for modeling temporal data such as speech [34]. Only a limited amount of work recognizes the potential of using recurrent structures in the RF domain. RNNs have also been shown to be useful for capturing and exploiting the temporal correlations of time series data [13]. There are a few variants of recurrent neural networks: (i) Long Short-Term Memory (LSTM) [16], (ii) Gated Recurrent Unit (GRU) [5], and (iii) Convolutional Long Short-Term Memory (ConvLSTM) [46]. All these variants are designed to learn long-term temporal dependencies and are capable of avoiding the “vanishing” or “exploding” gradient problems [8]. In [26], O’Shea et al. presented an RNN model that extracted high-level protocol information from the low-level physical layer representation for the task of radio traffic sequence classification. A radio anomaly detection technique was presented in [28], where the authors used an LSTM-based RNN as a time series predictor and used the error component to detect anomalies in real signals. Another application of RNNs was proposed in [41], where the authors used a deep recurrent neural network to learn the time-varying probability distribution of received powers on a channel and used it to predict the suitability of sharing that channel with other users. Bai et al. proposed an RF path fingerprinting [2] method using two RNNs in order to learn the spatial or temporal pattern. Both simulated and real-world data were used to improve the positioning accuracy and robustness of moving RF devices. All these approaches show that the temporal property of RF data can be leveraged through RNN models to learn and analyze different RF properties. Moreover, the working principle of RNNs is rather simple compared to many complicated learning algorithms such as multi-stage training (MST). In [58], the authors compared several learning paradigms for the task of transmitter identification. More precisely, they looked at feedforward deep neural nets, convolutional neural nets, support vector machines, and deep neural nets with MST. Though they achieved 99% accuracy for classifying 12 transmitters using MST, MST is a complex procedure and needs an easily parallelizable training environment. On the other hand, RNNs use a first-order update rule (stochastic gradient) and are comparatively simple procedures. However, there has not been much effort on identifying RF transmitters by exploiting the recurrent structure within RF data. As far as modulation recognition is concerned, a method was proposed in [33] for a distributed wireless spectrum sensing network. The authors proposed a recurrent neural network using long short-term memory (LSTM) cells, yielding 90% accuracy on a synthetic dataset [32]. In light of recent developments, one could argue that deep learning is a natural choice for implementing any system that exploits different RF parameters for RF application-oriented tasks. It is also clear that the paradigm of applying deep learning in the RF domain gradually shifted from DNNs to CNNs to RNNs, while in terms of application areas it shifted from modulation recognition to transmitter identification. It is to be noted that none of these methods address the problem of providing an end-to-end solution using raw signal data for transmitter identification using automatically extracted “fingerprints”. Hence, we propose a robust end-to-end “radio fingerprinting” solution using different types of RNN-based models.
3 Highlights of the Proposed RNN Models
As mentioned earlier, we use three variants of RNNs with time series of I/Q data to present an end-to-end solution for transmitter identification using the concept of “transmitter fingerprinting”. The main highlights for the rest of the chapter are as follows:
1. We exploit the temporal properties of I/Q data by using a supervised learning approach for transmitter identification using recurrent neural networks. We use two approaches: first, we exploit only the temporal property, and then we exploit the spatio-temporal property. We use RNNs with LSTM and GRU cells for the first approach, while we use a ConvLSTM model for the latter. Although transmitter fingerprinting has been studied before, to the best of our knowledge this is the first work that leverages the spatio-temporal property of over-the-air signal data for this task.
2. To examine the performance of the proposed networks, we test them on an indoor testbed. We transmit raw signal data from eight universal software radio peripheral (USRP) B210s [10] and collect over-the-air signals using an RTL-SDR [24]. We use the I/Q values from each USRP for fingerprinting the corresponding transmitter.
3. We collect I/Q data from several different types of SDRs, a couple of them made by the same manufacturer and a couple made by different manufacturers. More precisely, we use USRP-B210 [10], USRP-B200 [9], and USRP-X310 [11] from Ettus Research, as well as ADALM-PLUTO [7] and BLADE-RF [25]. We show that the spatio-temporal property is more pronounced (and thus easier to exploit, both explicitly and implicitly) when different types of SDRs (from the same or different manufacturers) are used as transmitters.
4. We also collect three additional datasets of I/Q values from eight USRP B210s with varying signal-to-noise ratio (SNR). We use distance and multi-path as the defining factors for SNR variation during data collection.
5. We train the proposed RNN models and present a competitive analysis of the performance of our models against state-of-the-art techniques for transmitter classification. Results reveal that the proposed methods outperform the existing ones, thus establishing that exploiting the spatio-temporal property of I/Q data can offset the need to pre-process the raw I/Q data, as is done in traditional transmitter fingerprinting approaches.
6. The novelty of the proposed implementations lies in accurately modeling and implementing different types of RNNs to build a robust transmitter fingerprinting system using over-the-air signal data, by exploiting spatio-temporal correlations.
4 Proposed RNN Models for Classification
In order to estimate the noise in an RF channel, the system needs to “listen” to the underlying signal for some time and “remember” it. Previously, neural networks lacked this capability when used in the context of temporal data. Another issue with using neural networks with temporal data was the problem of vanishing gradients when using backpropagation. Both these problems were addressed by the introduction of recurrent neural networks (RNNs) [20]. Moreover, inspired by the success of deep learning systems for the task of characterizing RF environments [30] and the successful use of RNNs for the task of analyzing time series data [34], we propose to use deep recurrent structures for learning transmitter “fingerprints” for the task of transmitter classification or identification. The proposed models are an extended version of the work presented in [39].
Formulation of temporal property of RF data
Given $T$ training samples (for $T$ timestamps), where each training sample is of size $M$ and consists of a vector of tuples of the form $(I, Q) \in \mathbb{C}$ representing a number in the complex plane, we represent a single sample as $x_t = [[(I, Q)_i]_t;\ i = 1, 2, \ldots, M] \in \mathbb{C}^M$ for each timestamp $t = 1, 2, \ldots, T$, and we use it as an input to the neural network. We use a sample size ($M$) of 1024 as a default. We want to find the probability that the input vector for the next time step ($x_{t+1}$) belongs to class $C_k$, where $k \in \{1, 2, \ldots, K\}$, $K$ being the number of classes. The probability $P(C_k \mid x_{t+1})$ can be written as
$$P(C_k \mid x_{t+1}) = \frac{P(x_t \mid C_k)\,P(C_k)}{P(x_t\, x_{t+1})} \tag{1}$$
where $P(x_t \mid C_k)$ is the conditional probability of $x_t$ given the class $C_k$ and $P(x_t\, x_{t+1})$ is the probability of $x_t$ and $x_{t+1}$ occurring in order.
4.1 Long Short-Term Memory (LSTM) Cell Model
Though LSTM cells can be modeled and designed in various ways depending on the need, we use the cells shown in Fig. 1. An LSTM cell has (i) three types of gates: input ($i$), forget ($f$), and output ($o$); and (ii) a state update of the internal cell memory. The most interesting part of the LSTM cell is the “forget” gate, which at time $t$ is denoted by $f_t$. The forget gate decides whether to keep a cell state memory ($c_t$) or not. It is computed as per Eq. (2) from the input value $x_t$ at time $t$ and the output $h_{t-1}$ at time $t-1$.
Fig. 1 LSTM cell architecture used in the proposed RNN model
$$f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f) \tag{2}$$
Note that $W_{xf}$ and $b_f$ represent the associated weight and bias, respectively, between the input ($x$) and the forget gate ($f$), and $\sigma$ denotes the sigmoid activation function. Once $f_t$ determines which memories to forget, the input gate ($i_t$) decides which candidate cell states ($\hat{c}_t$) to use for the update, as per Eqs. (3) and (4).
$$i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) \tag{3}$$
$$\hat{c}_t = \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c) \tag{4}$$
In Eq. (5), the old cell state ($c_{t-1}$) is updated to the new cell state ($c_t$) using the forget gate ($f_t$) and the input gate ($i_t$).
$$c_t = f_t \circ c_{t-1} + i_t \circ \hat{c}_t \tag{5}$$
Here $\circ$ is the Hadamard product. Finally, we filter the output values through the output gate ($o_t$) based on the cell state ($c_t$), as per Eqs. (6) and (7).
$$o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o) \tag{6}$$
$$h_t = o_t \circ \tanh(c_t) \tag{7}$$
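The gate computations of Eqs. (2)–(7) can be traced in a short NumPy sketch of a single LSTM time step. This is an illustration rather than the authors' implementation: the helper names (`lstm_step`, `init_params`), the hidden size of 16, and the random initialization are assumptions. The sketch uses the standard LSTM form, in which every gate reads the previous hidden state $h_{t-1}$ and the candidate state has its own bias $b_c$.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, p):
    """One LSTM time step; p is a dict of weights and biases (hypothetical layout)."""
    f_t = sigmoid(p["W_xf"] @ x_t + p["W_hf"] @ h_prev + p["b_f"])    # forget gate, Eq. (2)
    i_t = sigmoid(p["W_xi"] @ x_t + p["W_hi"] @ h_prev + p["b_i"])    # input gate, Eq. (3)
    c_hat = np.tanh(p["W_xc"] @ x_t + p["W_hc"] @ h_prev + p["b_c"])  # candidate state, Eq. (4)
    c_t = f_t * c_prev + i_t * c_hat                                  # cell update, Eq. (5)
    o_t = sigmoid(p["W_xo"] @ x_t + p["W_ho"] @ h_prev + p["b_o"])    # output gate, Eq. (6)
    h_t = o_t * np.tanh(c_t)                                          # hidden state, Eq. (7)
    return h_t, c_t

def init_params(input_dim, hidden_dim, rng):
    """Small random weights and zero biases for the gates f, i, c, o (illustrative only)."""
    p = {}
    for gate in "fico":
        p[f"W_x{gate}"] = rng.normal(0.0, 0.1, (hidden_dim, input_dim))
        p[f"W_h{gate}"] = rng.normal(0.0, 0.1, (hidden_dim, hidden_dim))
        p[f"b_{gate}"] = np.zeros(hidden_dim)
    return p

rng = np.random.default_rng(0)
input_dim, hidden_dim, T = 2048, 16, 5   # 2048 = 1024 I/Q pairs flattened
params = init_params(input_dim, hidden_dim, rng)
h, c = np.zeros(hidden_dim), np.zeros(hidden_dim)
for x_t in rng.normal(size=(T, input_dim)):   # stand-in for T real I/Q samples
    h, c = lstm_step(x_t, h, c, params)
print(h.shape, c.shape)
```

The element-wise `*` in the code plays the role of the Hadamard product. In a trained classifier, the final hidden state `h` would typically feed a softmax layer over the K transmitter classes.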
4.2 Gated Recurrent Unit (GRU) Model
The main drawback of using LSTM cells is the need for additional memory. GRUs [5] have one fewer gate for the same purpose, and thus a reduced memory and CPU footprint. The GRU cells control the flow of information just like the LSTM cells, but without the need for a memory unit. A GRU simply exposes the full hidden content without any control. It has a “reset gate” ($r_t$), an “update gate” ($z_t$), and a cell state memory ($c_t$), as shown in Fig. 2. The reset gate determines whether to combine the new input with the cell state memory ($c_t$) or not. The update gate decides how much of $c_t$ to retain. Equations (8)–(11), which define the different gates and states, are given below.
Fig. 2 GRU cell architecture used in the proposed RNN model
$$z_t = \sigma(W_{xz} x_t + W_{hz} h_{t-1} + b_z) \tag{8}$$
$$r_t = \sigma(W_{xr} x_t + W_{hr} h_{t-1} + b_r) \tag{9}$$
$$c_t = \tanh(W_{xc} x_t + W_{hc} (r_t \circ h_{t-1})) \tag{10}$$
$$h_t = (1 - z_t) \circ c_t + z_t \circ h_{t-1} \tag{11}$$
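For comparison with the LSTM cell, Eqs. (8)–(11) can be written as a single NumPy step. As before, this is a sketch under stated assumptions: the function name `gru_step`, the parameter dictionary layout, the small dimensions, and the random initialization are illustrative and do not come from the chapter.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, p):
    """One GRU time step; p is a dict of weights and biases (hypothetical layout)."""
    z_t = sigmoid(p["W_xz"] @ x_t + p["W_hz"] @ h_prev + p["b_z"])  # Eq. (8)
    r_t = sigmoid(p["W_xr"] @ x_t + p["W_hr"] @ h_prev + p["b_r"])  # Eq. (9)
    c_t = np.tanh(p["W_xc"] @ x_t + p["W_hc"] @ (r_t * h_prev))     # Eq. (10), no bias term
    h_t = (1.0 - z_t) * c_t + z_t * h_prev                          # Eq. (11)
    return h_t

rng = np.random.default_rng(0)
input_dim, hidden_dim = 8, 4   # kept small for the example
p = {
    "W_xz": rng.normal(0.0, 0.1, (hidden_dim, input_dim)),
    "W_hz": rng.normal(0.0, 0.1, (hidden_dim, hidden_dim)),
    "b_z": np.zeros(hidden_dim),
    "W_xr": rng.normal(0.0, 0.1, (hidden_dim, input_dim)),
    "W_hr": rng.normal(0.0, 0.1, (hidden_dim, hidden_dim)),
    "b_r": np.zeros(hidden_dim),
    "W_xc": rng.normal(0.0, 0.1, (hidden_dim, input_dim)),
    "W_hc": rng.normal(0.0, 0.1, (hidden_dim, hidden_dim)),
}

h = np.zeros(hidden_dim)
for x_t in rng.normal(size=(3, input_dim)):
    h = gru_step(x_t, h, p)
print(h.shape)
```

Unlike the LSTM sketch, only one state vector (`h`) is carried between time steps, which is where the reduced memory footprint comes from.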
4.3 Convolutional LSTM Network Model
The LSTM and GRU models described above exploit only the temporal property of the I/Q data. To also capture the spatial property, we use a convolution within the recurrent structure of the RNN. We first discuss the spatio-temporal property of RF data and then model a convolutional LSTM network to exploit it.
4.3.1 Formulation of Spatio-temporal Property for RF Data
Suppose that a radio signal is represented as a time-varying series over a spatial region using $R$ rows and $C$ columns. Here $R$ represents the time-varying nature of the signal; in our case it is the total number of timestamps at which the signal was sampled ($T$). $C$, on the other hand, represents the total number of features sampled at each timestamp (in our case it is 2048, since 1024 features are sampled, each of dimension 2). Note that each cell corresponding to one value of $R$ and one value of $C$ represents a particular feature (I or Q) at a given point in time. In order to capture the temporal property only, we use a sequence of vectors corresponding to different timestamps $1, 2, \ldots, t$ as $x_1, x_2, \ldots, x_t$. However, to capture both spatial and temporal properties, we introduce a new vector $\chi_{t,t+\gamma}$, which is formulated as $\chi_{t,t+\gamma} = [x_t, x_{t+1}, \ldots, x_{t+\gamma-1}]$. The vector $\chi_{t,t+\gamma}$ thus preserves the spatial properties over an increment of $\gamma$ in time. So, we get a sequence of new vectors $\chi_{1,\gamma}, \chi_{\gamma,2\gamma}, \ldots, \chi_{t,t+\gamma}, \ldots, \chi_{t+(\beta-1)\gamma,t+\beta\gamma}$, where $\beta$ is $R/\gamma$, and the goal is to create a model to classify them into one of the $K$ classes (corresponding to the transmitters). We model the class-conditional densities given by $P(\chi_{t-\gamma,t} \mid C_k)$, where $k \in \{1, \ldots, K\}$. We formulate the probability of the next $\gamma$-length sequence being in class $C_k$ as per Eq. (12). The marginal probability is modeled as $P(\chi_{t,t+\gamma})$.
$$P(C_k \mid \chi_{t,t+\gamma}) = \frac{P(\chi_{t-\gamma,t} \mid C_k)\,P(C_k)}{P(\chi_{t,t+\gamma})} \tag{12}$$
4.3.2 The Model
The cell model, as shown in Fig. 3, is similar to an LSTM cell, but the input transformations and recurrent transformations are both convolutional in nature [46]. We formulate the input values, cell states, and hidden states as 3-dimensional tensors, where the first dimension is the number of measurements, which varies with the time interval $\gamma$, and the last two dimensions contain the spatial information (rows ($R$) and columns ($C$)). We represent these as: (i) the inputs $\chi_{1,\gamma}, \chi_{\gamma,2\gamma}, \ldots, \chi_{t,t+\gamma}, \ldots, \chi_{t+(\beta-1)\gamma,t+\beta\gamma}$ (previously stated); (ii) the cell outputs $C_1, \ldots, C_t$; and (iii) the hidden states $H_1, \ldots, H_t$. We represent the gates in a similar manner as in the LSTM model. The parameters $t$, $i_t$, $f_t$, $o_t$, $W$, and $b$ hold the same meaning as in Sect. 4.1. The key operations are defined in Eqs. (13)–(17). The probability of the next $\gamma$-sequence being in a particular class (from Eq. (12)) is used within the implementation and execution of the model.
$$i_t = \sigma(W_{xi} \chi_{t,t+\gamma} + W_{hi} H_{t-1} + b_i) \tag{13}$$
$$f_t = \sigma(W_{xf} \chi_{t,t+\gamma} + W_{hf} H_{t-1} + b_f) \tag{14}$$
$$C_t = f_t \circ C_{t-1} + i_t \circ \tanh(W_{xc} \chi_{t,t+\gamma} + W_{hc} H_{t-1} + b_c) \tag{15}$$
$$o_t = \sigma(W_{xo} \chi_{t,t+\gamma} + W_{ho} H_{t-1} + b_o) \tag{16}$$
$$H_t = o_t \circ \tanh(C_t) \tag{17}$$
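The grouping of per-timestamp samples into the $\gamma$-length inputs $\chi_{t,t+\gamma}$ used by the ConvLSTM model can be sketched with a single reshape. This is a minimal illustration that assumes non-overlapping windows and discards any leftover rows when $R$ is not a multiple of $\gamma$; the function name `make_windows` and the toy dimensions are hypothetical.

```python
import numpy as np

def make_windows(x, gamma):
    """Group rows x_t of an (R, C) array into non-overlapping gamma-length windows.

    Returns an array of shape (beta, gamma, C) with beta = R // gamma, so that
    windows[b] = [x_{b*gamma}, ..., x_{b*gamma + gamma - 1}] (0-based indices).
    """
    R, C = x.shape
    beta = R // gamma                      # number of complete windows
    return x[: beta * gamma].reshape(beta, gamma, C)

rng = np.random.default_rng(0)
R, C, gamma = 10, 2048, 4                  # 10 timestamps, 2048 I/Q entries each
x = rng.normal(size=(R, C))                # stand-in for collected I/Q rows
windows = make_windows(x, gamma)
print(windows.shape)                       # (2, 4, 2048)
```

Each `windows[b]` is then one input to the recurrent model, so the network sees a block of consecutive samples (time along the first axis, I/Q features along the second) rather than a single row.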
Fig. 3 ConvLSTM cell architecture used in the proposed RNN model (legend: tanh activation, sigmoid activation, sum over all elements, Hadamard product)
5 Testbed Evaluation
In order to validate the proposed models, we collected raw signal data from eight different universal software radio peripheral (USRP) B210s [10]. We collected the data in an indoor lab environment with a signal-to-noise ratio of 30 dB, and used the dataset to discriminate between four and eight transmitters, as mentioned in [40]. We also collected data with varied SNRs (20, 10, and 0 dB) for the same 8 USRP B210 transmitters. Finally, we collected data from five different types of transmitters, each transmitter being implemented using an SDR, from several different manufacturers.
5.1 Signal Generation and Data Collection
In order to evaluate our methods for learning the inherent spatio-temporal features of a transmitter, we used different types of SDRs as transmitters. The signal generation and reception are shown in Fig. 4. We used GNURadio [12] to generate a random signal and modulated it with quadrature phase shift keying (QPSK). We programmed the transmitter SDRs to transmit the modulated signal over the air and sensed it using a DVB-T dongle (RTL-SDR) [24]. We generated the entire dataset from “over-the-air” data as sensed by the RTL-SDR using the rtlsdr python library. We collected I/Q signal data with a sample size of 1024 at each timestamp. Each data sample had 2048 entries consisting of the I and Q values for the 1024 samples. Note that a larger sample size would mean more training examples for the neural network. Our choice of 1024 samples was sufficient to capture the spatio-temporal properties while keeping the training computationally inexpensive. We collected 40,000 training examples from each transmitter so that the classes are balanced, which avoids the data skewness problem observed in machine learning. The configuration parameters that were used are given in Table 1. We collected two different datasets at 30 dB SNR and three datasets having three different SNRs, as discussed below. The different SNR levels were achieved in an indoor lab environment by changing the propagation delay, multi-path, and shadowing effects. We also collected a “heterogeneous” dataset using several different types of SDRs. Note that we intend to make the dataset publicly available upon publication of the chapter.
Fig. 4 Over-the-air signal generation and data collection technique (stages: random signal, QPSK modulation, transmitter SDR, over-the-air transmission, RTL-SDR, data collection, datasets)
Table 1 Transmission configuration parameters
| Parameters | Values |
| Transmitter gain | 45 dB |
| Transmitter frequency | 904 MHz (ISM) |
| Bandwidth | 200 kHz |
| Sample size | 1024 |
| Samples/transmitter | 40,000 |
| # Transmitters | 4 and 8 |
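The sample layout described in this section (1024 complex I/Q samples stored as one row of 2048 interleaved values, in the order used in Sect. 5.2) can be sketched as follows. This is an assumption-laden illustration rather than the collection script used for the dataset: `iq_to_row` is a hypothetical helper, and the complex input array stands in for whatever the receiver library returns (in the pyrtlsdr package, `RtlSdr.read_samples` typically returns such an array, but no hardware call is made here).

```python
import numpy as np

def iq_to_row(iq):
    """Interleave complex samples as [I0, Q0, I1, Q1, ...] (assumed row layout)."""
    iq = np.asarray(iq, dtype=np.complex128)
    row = np.empty(2 * iq.size, dtype=np.float64)
    row[0::2] = iq.real  # in-phase components
    row[1::2] = iq.imag  # quadrature components
    return row

# Stand-in for one capture of 1024 complex samples (random data, not real RF).
rng = np.random.default_rng(0)
capture = rng.standard_normal(1024) + 1j * rng.standard_normal(1024)
row = iq_to_row(capture)  # one dataset row with 2048 entries
```

Stacking 40,000 such rows per transmitter would give the row counts reported for the datasets below.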
5.1.1 Homogeneous Dataset
For the “homogeneous” dataset, we used eight radios from the same manufacturer, namely, the USRP-B210 from Ettus Research [10], as transmitters. We collected two sets of data: (i) using 4 USRP B210 transmitters: 6.8 GB size, 160K rows, and 2048 columns and (ii) using 8 USRP B210 transmitters: 13.45 GB size, 320K rows, and 2048 columns. Note that the SNR was 30 dB in each case.
5.1.2 Heterogeneous Dataset
In order to investigate the spatio-temporal correlation in the I/Q data from different types of SDRs from varied manufacturers, we collected a “heterogeneous” dataset as well. We used three different SDRs from the same manufacturer and two SDRs from two other manufacturers. From Ettus Research, we used the USRP B210 [10], USRP B200 [9], and USRP X310 [11]. The two remaining SDRs were the BLADE RF [25] by Nuand and the PLUTO SDR [7] by Analog Devices.
The signal generation procedure is similar to Fig. 4, with different SDR models as transmitters. The SNR remains 30 dB, as before. The “heterogeneous” datasets were obtained using (i) 5 USRP B210 transmitters: 8.46 GB size, 200K rows, and 2048 columns and (ii) 1 USRP B210, 1 USRP B200, 1 USRP X310, 1 BLADERF, and 1 PLUTO transmitter: 6.42 GB size, 200K rows, and 2048 columns. We include the 5 USRP B210 homogeneous data in this dataset to allow a fair comparison between five heterogeneous radios and five homogeneous ones in Sect. 6.6, since the homogeneous datasets described in Sect. 5.1.1 contain either four or eight homogeneous radios.
5.1.3 Varying SNR Datasets
We collected three more datasets with 8 USRP B210 transmitters with SNRs of 20 dB, 10 dB, and 0 dB, respectively. Each dataset is of size ∼13 GB with 320K rows and 2048 columns.
5.2 Spatial Correlation in the Homogeneous Dataset
Correlation between data samples plays a crucial role in the process of transmitter identification. We represent the I and Q values of each training sample at time $t$ as $[I_0\ Q_0\ I_1\ Q_1\ I_2\ Q_2\ I_3\ Q_3\ I_4\ Q_4\ \ldots\ I_{1023}\ Q_{1023}]_t$. We used QPSK modulation [48], which means that the spatial correlation should be between every fourth value, i.e., between $I_0$ and $I_4$, and between $Q_0$ and $Q_4$. So we calculate the correlation coefficient of $I_0 I_1 I_2 I_3$ and $I_4 I_5 I_6 I_7$, and similarly of $Q_0 Q_1 Q_2 Q_3$ and $Q_4 Q_5 Q_6 Q_7$. We take the average of all the correlation coefficients for each sample. We use numpy.corrcoef for this purpose, which computes Pearson product-moment correlation coefficients, denoted by $r$. Pearson's method for a sample is given by

$$r = \frac{\sum_{i=0}^{M-1}(I_i - \bar{I})(Q_i - \bar{Q})}{\sqrt{\sum_{i=0}^{M-1}(I_i - \bar{I})^2 \,\sum_{i=0}^{M-1}(Q_i - \bar{Q})^2}} \tag{18}$$

where $M$ is the sample size, and $I_i$ and $Q_i$ are the sample values indexed by $i$. The sample mean is $\bar{I} = \frac{1}{M}\sum_{i=0}^{M-1} I_i$, and $\bar{Q}$ is defined analogously.
The spatial correlations of all the samples for the different transmitters are shown in Fig. 5. We observe that for most of the transmitters, the correlation is ∼0.42, with a standard deviation of ∼0.2. However, transmitter 3 exhibits minimal correlation between these samples, which implies that the spatial property of transmitter 3 is different from that of the other transmitters. As a result, transmitter 3 should be easily distinguishable from the others. This claim is validated later in the experimental results section, where we see 0% false positives and false negatives for transmitter 3 for all three proposed models. This observation motivates us to exploit the spatial property as well as the temporal property of the collected time series data.
Fig. 5 Spatial correlation in the datasets
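The averaging procedure described in this section can be sketched with `numpy.corrcoef` as below. The exact pairing of groups is not fully specified in the text, so this sketch assumes that consecutive groups of four I values (and, separately, four Q values) are correlated and that all resulting coefficients are averaged; `spatial_correlation` is a hypothetical helper name.

```python
import numpy as np

def spatial_correlation(sample, block=4):
    """Average Pearson r between consecutive length-`block` groups of I and of Q.

    `sample` is one interleaved row [I0, Q0, I1, Q1, ...]; the grouping is an
    assumed reading of the procedure in the text.
    """
    sample = np.asarray(sample, dtype=float)
    coeffs = []
    for channel in (sample[0::2], sample[1::2]):  # I values, then Q values
        groups = channel.reshape(-1, block)
        for a, b in zip(groups[:-1], groups[1:]):
            if a.std() > 0 and b.std() > 0:  # r is undefined for a constant group
                coeffs.append(np.corrcoef(a, b)[0, 1])
    return float(np.mean(coeffs)) if coeffs else float("nan")

# Toy row whose I and Q streams repeat with period 4, so every group pair
# is perfectly correlated and the average is close to 1.0.
toy = np.empty(2048)
toy[0::2] = np.tile([1.0, 2.0, 3.0, 5.0], 256)
toy[1::2] = np.tile([2.0, 1.0, 4.0, 3.0], 256)
toy_r = spatial_correlation(toy)
```

Applying the function to every row of a transmitter's data and summarizing the results would give per-transmitter statistics of the kind plotted in Fig. 5.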
5.3 Spatial Correlation in the Heterogeneous Dataset
The calculated average spatial correlations of samples for the five different types of transmitters (as mentioned in Sect. 5.1.2 (ii)) are shown in Fig. 6. We calculated the spatial correlation of this data using the same technique as described in the previous section. We observe that data from the USRP-B210 and PLUTO-SDR have better correlations than the other three types. It is also evident from the figure that none of the transmitters exhibits a strong correlation; however, each has correlations in a different range. This supports our claim that the spatio-temporal property in the heterogeneous data is more distinguishable than in the homogeneous data, which is validated later in the experimental results section (Sect. 6.6).
Fig. 6 Spatial correlation in the heterogeneous datasets
5.4 Experimental Setup and Performance Metrics
We conducted the experiments on a Ryzen 8-core system with 64 GB RAM and a GTX 1080 Ti GPU with 11 GB of memory. We use Keras [4] as the frontend and Tensorflow [1] as the backend for our implementations. During the training phase, we use data from each transmitter to train the neural network model. In order to test the resulting trained model, we present test data collected from one of the transmitters to the trained network. In general, “accuracy” is the typical performance metric used to measure the effectiveness of a learning algorithm. However, accuracy can be misleading and incomplete when the data is skewed. For the task of classification, a confusion matrix overcomes this problem by showing how confused the learned model is in its predictions. It provides more insight into the performance by identifying not only the number of errors but, more importantly, the types of errors.
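As a small illustration of why a confusion matrix says more than accuracy alone, the sketch below builds one from toy labels and normalizes each row so that the diagonal holds the per-class recall. The labels are made up for the example, and the helper names are hypothetical, not part of the chapter's code.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes (raw counts)."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

def row_normalize(cm):
    """Divide each row by its total, so the diagonal holds per-class recall."""
    totals = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, totals, out=np.zeros(cm.shape, dtype=float), where=totals > 0)

# Toy labels for three classes (illustrative only).
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
rates = row_normalize(cm)            # off-diagonal entries show which classes are confused
accuracy = np.trace(cm) / cm.sum()   # a single number that hides the error types
```

In this toy case the overall accuracy is 4/6, while the row-normalized matrix shows that class 1 is always recognized and classes 0 and 2 are each missed half of the time.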
6 Model Implementations and Results
In this section we discuss the implementation of each of the proposed recurrent neural networks. We train each network for transmitter classification with K classes. For the sake of robustness and statistical significance, we present the results for each model after averaging over several runs.
6.1 Implementation with LSTM Cells
Fig. 7 RNN implementation with LSTM cells for transmitter classification (blocks, in order: signal processing and data collection, input data, LSTM layer 1, LSTM layer 2, Dense layers, output)
As discussed earlier, the recurrent structure of the neural network can be used to exploit the temporal correlation in the data. To that end, we first implemented a recurrent neural network with LSTM cells and trained it on the collected dataset using the paradigm shown in Fig. 7. We used two LSTM layers with 1024 and 256 units sequentially, with a dropout rate of 0.5 between them. Next we used two fully connected (Dense) layers with 512 and 256 nodes, respectively. We apply a dropout rate of 0.2 and add batch normalization [17] on the output, finally passing it through a Dense layer having eight nodes. We use ReLU [23] as the activation function for the LSTM layers and tanh [3] for the Dense layers. Lastly, we use stochastic gradient descent [3]-based optimization with categorical cross-entropy as the training loss. Note that the neural network architecture was finalized over several iterations of experimentation with the data, and we report only the final architecture here. We achieved 97.17% and 92.00% testing accuracy for four and eight transmitters, respectively. The accuracy plots and confusion matrices are shown in Figs. 8 and 9, respectively. Note that the number of nodes in the last layer is equal to the number of classes in the dataset. During the process of designing the RNN architecture, we also fine-tuned the hyper-parameters based on the generalization ability of the current network (as determined by comparing the training and validation errors). We also limited the number of recurrent layers and fully connected layers for each model for faster training [15], since no significant increase in the validation accuracy was observed after increasing the number of layers. The rows and columns of the confusion matrix correspond to the transmitters (classes), and the cell values show the recall (sensitivity) and the false negative rate for each of the transmitters. Note that recall or sensitivity represents the true positive rate for each of the prediction classes.
Fig. 8 Accuracy plots for transmitter classification using LSTM cells
Fig. 9 Confusion matrices for transmitter classification using LSTM cells
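The categorical cross-entropy objective mentioned above, with one output node per class, can be sketched in NumPy as follows. The chapter does not state the activation of the final layer; a softmax output is assumed here, as is common with this loss, and the function names are hypothetical.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax (assumed output activation; not stated in the chapter)."""
    z = logits - logits.max(axis=1, keepdims=True)  # stabilize the exponentials
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def categorical_cross_entropy(one_hot_labels, probs, eps=1e-12):
    """Mean negative log-probability assigned to the true class."""
    return float(-np.mean(np.sum(one_hot_labels * np.log(probs + eps), axis=1)))

K = 8                                # eight transmitters, so eight output nodes
logits = np.zeros((2, K))            # an untrained network that is indifferent
labels = np.eye(K)[[3, 5]]           # one-hot labels for transmitters 3 and 5
loss = categorical_cross_entropy(labels, softmax(logits))
```

With uniform predictions over K = 8 classes the loss equals ln 8 (about 2.08), which is a useful reference point: a training loss near this value means the network has not yet learned to separate the transmitters.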
6.2 Implementation with GRU Cells
Next we implemented another variation of the RNN model using GRU cells for leveraging temporal correlation. We used the same architecture as the LSTM implementation, as presented in Fig. 10. The only difference is that we use two GRU layers with 1024 and 256 units instead of LSTM layers. The proposed GRU implementation needs fewer parameters than the LSTM model; a quantitative comparison is given in Sect. 6.4. We achieved 97.76% and 95.30% testing accuracy for four and eight transmitters, respectively. The accuracy plots and confusion matrices are given in Figs. 11 and 12. The GRU implementation provided a slight improvement over the accuracy obtained using LSTM, for each run of the models and for both datasets.
Fig. 10 RNN implementation with GRU cells for transmitter classification (blocks, in order: signal processing and data collection, input data, GRU layer 1, GRU layer 2, Dense layers, output)
Fig. 11 Accuracy plots for transmitter classification using GRU cells
Fig. 12 Confusion matrices for transmitter classification using GRU cells
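The claim that the GRU model needs fewer parameters than the LSTM model can be checked with the standard per-layer parameter formulas. The sketch below makes several assumptions that are not stated in the chapter: each time step carries 2048 input features, the GRU uses the classic formulation with one bias per gate (the Keras default `reset_after=True` adds a second bias vector that is not counted here), and the small batch-normalization parameters are ignored.

```python
def lstm_params(input_dim, units):
    """Four gates, each with input weights, recurrent weights, and a bias."""
    return 4 * (units * (input_dim + units) + units)

def gru_params(input_dim, units):
    """Three gates (classic formulation, a single bias per gate)."""
    return 3 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    return input_dim * units + units

def total_params(recurrent_fn, features=2048, num_classes=8):
    """Two recurrent layers (1024, 256 units) followed by Dense 512, 256, num_classes."""
    total = recurrent_fn(features, 1024) + recurrent_fn(1024, 256)
    total += dense_params(256, 512) + dense_params(512, 256)
    total += dense_params(256, num_classes)
    return total

lstm_total = total_params(lstm_params)  # 14,163,720
gru_total = total_params(gru_params)    # 10,689,032
```

Under these assumptions the totals are about 14.2 M and 10.7 M parameters, which agrees with the rounded values listed in Table 2; the GRU saving comes entirely from having three gates instead of four in the two recurrent layers.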
6.3 Implementation with ConvLSTM2D Cells
Finally, in order to exploit the spatio-temporal property of the signal data, we implemented another variation of the LSTM model with convolutional filters (transformations). The implemented architecture is shown in Fig. 13. ConvLSTM2D uses two-dimensional convolutions for both input transformations and recurrent transformations. We first use two ConvLSTM2D layers with 1024 and 256 filters, respectively, and a dropout rate of 0.5 in between. We use a kernel size of (2, 2) and a stride of (2, 2) at each ConvLSTM2D layer. Next we add two fully connected (Dense) layers having 512 and 256 nodes, respectively, after flattening the convolutional output. ReLU [23] and tanh [3] activation functions are used for the ConvLSTM2D and Dense layers, respectively. ADADELTA [59], with a learning rate of $10^{-4}$ and a decay rate of 0.9, is used as the optimizer with categorical cross-entropy as the training loss. We achieved 98.9% and 97.2% testing accuracy for four and eight transmitters, respectively. The accuracy plots and confusion matrices are given in Figs. 14 and 15, respectively. By exploiting the spatio-temporal correlation, the ConvLSTM implementation improves on the accuracies obtained using the LSTM and GRU models for both datasets.
Fig. 13 RNN implementation with ConvLSTM cells for transmitter classification (blocks, in order: signal processing and data collection, input data, ConvLSTM2D layer 1, ConvLSTM2D layer 2, Flatten, Dense layers, output)
Fig. 14 Accuracy plots for transmitter classification using ConvLSTM cells
Fig. 15 Confusion matrices for transmitter classification using ConvLSTM cells
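The effect of the (2, 2) kernel with a (2, 2) stride on the spatial dimensions can be worked out with the usual output-size formula. The sketch below assumes unpadded ("valid") convolutions, which the chapter does not specify, and uses a hypothetical 32 × 64 input grid purely for illustration.

```python
def conv_output_size(n, kernel=2, stride=2):
    """Output length along one axis for an unpadded ('valid') convolution."""
    if n < kernel:
        raise ValueError("input is shorter than the kernel")
    return (n - kernel) // stride + 1

def stacked_output_shape(rows, cols, num_layers=2):
    """Spatial shape after `num_layers` layers with a 2x2 kernel and stride 2."""
    for _ in range(num_layers):
        rows, cols = conv_output_size(rows), conv_output_size(cols)
    return rows, cols

shape = stacked_output_shape(32, 64)  # hypothetical input grid, two ConvLSTM2D layers
```

With this configuration each layer halves both spatial dimensions, so the hypothetical 32 × 64 grid shrinks to 8 × 16 before the output is flattened and passed to the Dense layers.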
6.4 Comparisons of LSTM/GRU/ConvLSTM Implementations
We used 90%, 5%, and 5% of the data to train, validate, and test, respectively. We ran each model for 50 epochs with early stopping on the validation set. One epoch consists of a forward pass and a backward pass through the implemented architecture for the entire dataset. The overall accuracies of the different implementations are shown in Table 2. We find that the implementation of convolutional layers with recurrent structure (ConvLSTM2D) exhibits the best accuracy for transmitter classification, which shows the advantage of using the spatio-temporal correlation present in the collected datasets. Figure 16 illustrates the classification accuracies achieved by the different implemented models.
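The 90%/5%/5% split translates into the following row counts for the datasets of Sect. 5.1, where each transmitter contributes 40,000 rows. The sketch only does the arithmetic; how the rows were shuffled or stratified is not described in the chapter, and `split_counts` is a hypothetical helper.

```python
ROWS_PER_TRANSMITTER = 40_000  # from Table 1

def split_counts(num_rows, train_pct=90, val_pct=5):
    """Row counts for a train/validation/test split given in whole percentages."""
    train = num_rows * train_pct // 100
    val = num_rows * val_pct // 100
    test = num_rows - train - val  # the remainder goes to the test set
    return train, val, test

# Number of transmitters mapped to (train, validation, test) row counts.
splits = {n: split_counts(n * ROWS_PER_TRANSMITTER) for n in (4, 5, 8)}
```

For the four-transmitter dataset (160K rows) this gives 144,000 training, 8,000 validation, and 8,000 test rows; the eight-transmitter dataset doubles each count.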
Table 2 Accuracy for different implementations
| #Trans | Models | #Parameters (M) | Acc (%) |
| 4 | LSTM (6 layers) | 14.2 | 97.17 |
| 4 | GRU (6 layers) | 10.7 | 97.76 |
| 4 | ConvLSTM (6 layers) | 14.2 | 98.90 |
| 8 | LSTM (6 layers) | 14.2 | 92.00 |
| 8 | GRU (6 layers) | 10.7 | 95.30 |
| 8 | ConvLSTM (6 layers) | 14.2 | 97.20 |
Fig. 16 Comparison of testing accuracies of different types of recurrent neural networks
6.5 Computational Complexities
In this section, we concentrate on the computational time complexity of one epoch of the training phase only, as the trained model gives the output within constant time ($O(1)$) during the deployment phase. Understanding the time complexity of training an RNN is still an evolving research area. The proposed RNN models are a combination of two recurrent layers and four fully connected layers. Hence, we analyze the time complexities of both types separately. In [22], the authors proved that a fully connected NN of depth $\delta$ can be learned in $\mathrm{poly}(s^{2^{\delta}})$ time, where $s$ is the dimension of the input (= $T$ in the proposed models), and $\mathrm{poly}(\cdot)$ takes a constant time depending on the configuration of the system. However, the complexity of LSTM and GRU layers depends on the total number of parameters in the network [43]. It can be expressed as $O(P_{lstm} \times T)$ for LSTM layers, where $P_{lstm}$ is the total number of parameters in the LSTM network and $T$ is the number of timesteps, or the total number of training data samples. Similarly, for the GRU layers it is $O(P_{gru} \times T)$, where $P_{gru}$ is the total number of parameters in the GRU network. The computational complexity of a ConvLSTM layer, in contrast, depends on the complexity of the convolution as well as the LSTM layers. In [53], the authors mentioned that the time complexity for training all the convolutional layers is $O\big(\sum_{\tau=1}^{\zeta} \eta_{\tau-1}\,\nu_{\tau}^{2}\,\eta_{\tau}\,\rho_{\tau}^{2}\big)$, where $\zeta$ is the number of convolutional layers, $\tau$ is the index of a convolutional layer, $\eta_{\tau-1}$ is the number of input channels of the $\tau$th layer, $\nu_{\tau}$ is the spatial size of the filters at the $\tau$th layer, $\eta_{\tau}$ is the number of filters at the $\tau$th layer, and $\rho_{\tau}$ is the size of the output features of the $\tau$th layer. In the proposed ConvLSTM model, we have two ConvLSTM layers and four fully connected layers; therefore, we add the additional time complexity for training those convolutional layers. The time complexities of the implemented RNN models for the homogeneous dataset are presented in Table 3, using the aforementioned results on the time complexity of neural network training. The total number of parameters used in each network is shown in Table 2. The numbers within parentheses in the second column represent the total number of layers for a particular model. Note that we have two different datasets with 160K and 320K rows and, as mentioned earlier, we use 95% of the data for training and validation. For example, the complexity for ConvLSTM with six layers using 95% of 160e3 data samples for training and validation is $O\big(\big(\sum_{\tau=1}^{2} \eta_{\tau-1}\cdot\nu_{\tau}^{2}\cdot\eta_{\tau}\cdot\rho_{\tau}^{2}\big) + 0.95 \times 160\mathrm{e}3\big)$ for the 2 ConvLSTM layers and $\mathrm{poly}\big((0.95 \times 160\mathrm{e}3)^{2^{4}}\big)$ for the four fully connected layers. Similarly, the computational complexity of LSTM with six layers using 95% of 160e3 data samples for training and validation is $O(14.2\mathrm{e}6 \times 0.95 \times 160\mathrm{e}3)$ for the 2 LSTM layers (where 14.2e6 is the number of parameters) and $\mathrm{poly}\big((0.95 \times 160\mathrm{e}3)^{2^{4}}\big)$ for the four fully connected layers.
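Table 3 Computational complexities of one training epoch for the proposed implementations
| #Trans | Models | Complexity |
| 4 | LSTM (6) | $O(14.2\mathrm{e}6 \times 0.95 \times 160\mathrm{e}3) + \mathrm{poly}\big((0.95 \times 160\mathrm{e}3)^{2^{4}}\big)$ |
| 4 | GRU (6) | $O(10.7\mathrm{e}6 \times 0.95 \times 160\mathrm{e}3) + \mathrm{poly}\big((0.95 \times 160\mathrm{e}3)^{2^{4}}\big)$ |
| 4 | ConvLSTM (6) | $O\big(\big(\sum_{\tau=1}^{2} \eta_{\tau-1}\cdot\nu_{\tau}^{2}\cdot\eta_{\tau}\cdot\rho_{\tau}^{2}\big) + 0.95 \times 160\mathrm{e}3\big) + \mathrm{poly}\big((0.95 \times 160\mathrm{e}3)^{2^{4}}\big)$ |
| 8 | LSTM (6) | $O(14.2\mathrm{e}6 \times 0.95 \times 320\mathrm{e}3) + \mathrm{poly}\big((0.95 \times 320\mathrm{e}3)^{2^{4}}\big)$ |
| 8 | GRU (6) | $O(10.7\mathrm{e}6 \times 0.95 \times 320\mathrm{e}3) + \mathrm{poly}\big((0.95 \times 320\mathrm{e}3)^{2^{4}}\big)$ |
| 8 | ConvLSTM (6) | $O\big(\big(\sum_{\tau=1}^{2} \eta_{\tau-1}\cdot\nu_{\tau}^{2}\cdot\eta_{\tau}\cdot\rho_{\tau}^{2}\big) + 0.95 \times 320\mathrm{e}3\big) + \mathrm{poly}\big((0.95 \times 320\mathrm{e}3)^{2^{4}}\big)$ |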
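To give a feel for the recurrent-layer terms, the products inside the $O(\cdot)$ expressions for the LSTM and GRU models can be evaluated directly. The numbers below are only the arguments of the asymptotic bounds (parameter count times number of training and validation rows), not measured operation counts or run times; `recurrent_term` is a hypothetical helper.

```python
def recurrent_term(num_params, num_rows, train_val_fraction=0.95):
    """Argument of the O(P x T) bound for the recurrent layers."""
    return num_params * train_val_fraction * num_rows

lstm_4 = recurrent_term(14.2e6, 160e3)  # about 2.16e12
gru_4 = recurrent_term(10.7e6, 160e3)   # about 1.63e12
lstm_8 = recurrent_term(14.2e6, 320e3)  # twice the 4-transmitter value
ratio = gru_4 / lstm_4                  # about 0.75, independent of the dataset size
```

Under this bound the GRU term is roughly three quarters of the LSTM term for either dataset, which mirrors the ratio of the parameter counts in Table 2.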
6.6 Experiments with Heterogeneous Dataset
So far, we have implemented the proposed RNN models for “homogeneous” datasets, where the transmitter SDRs were from the same manufacturer. In reality, however, the transmitters can be of different models, either from the same manufacturer or from several different manufacturers. We now explore how the accuracy of transmitter identification changes if “heterogeneous” data (as discussed in Sect. 5.1.2), obtained from different types of transmitters (manufacturers), is used. From the testing accuracies shown in Table 4, we observe that all the RNNs perform better when the transmitters are of different models, either from the same or from different manufacturers, and hence are fundamentally of different types. The accuracies of LSTM, GRU, and ConvLSTM increase by 5%, 3.37%, and 1.51%, respectively, when classifying heterogeneous radios rather than homogeneous ones. This confirms the intuition that radios manufactured using different processes (from different manufacturers) contain easily exploitable characteristics in their I/Q samples that can be implicitly learned using an RNN. The confusion matrices for all three proposed models are compared in Figs. 17, 18, and 19. The false positives and false negatives are observed to be considerably lower for the 5-HETERO results than for the 5-B210s. Note that we used the same proposed RNN models for the heterogeneous data too, which indicates the robustness of the proposed models.
Table 4 Comparison of testing accuracies for different classification models for homogeneous and heterogeneous datasets
| Models | 5-B210s (Acc) (%) | 5-HETERO (Acc) (%) | Change |
| LSTM | 95.61 | 99.89 | 5% ↑ |
| GRU | 96.72 | 99.97 | 3.37% ↑ |
| ConvLSTM | 98.5 | 99.99 | 1.51% ↑ |
Fig. 17 Confusion matrices for transmitter classification using LSTM cells for heterogeneous dataset
Fig. 18 Confusion matrices for transmitter classification using GRU cells for heterogeneous dataset
Fig. 19 Confusion matrices for transmitter classification using ConvLSTM cells for heterogeneous dataset
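The Change column of Table 4 can be recomputed from the two accuracy columns. The chapter does not define how the change is calculated; the sketch below assumes it is the relative increase over the 5-B210s accuracy, which is an interpretation rather than a stated fact.

```python
def relative_change_pct(homogeneous_acc, heterogeneous_acc):
    """Relative increase, in percent, over the homogeneous accuracy (assumed definition)."""
    return 100.0 * (heterogeneous_acc - homogeneous_acc) / homogeneous_acc

table4 = {"LSTM": (95.61, 99.89), "GRU": (96.72, 99.97), "ConvLSTM": (98.5, 99.99)}
changes = {name: round(relative_change_pct(*accs), 2) for name, accs in table4.items()}
```

Under this reading the recomputed values are about 4.48%, 3.36%, and 1.51%. The ConvLSTM entry matches the table, while the LSTM and GRU entries differ slightly from the tabulated 5% and 3.37%.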
6.7 Comparisons of Proposed and Existing Approaches
Next, we present two comparative studies of our proposed implementations with state-of-the-art techniques. Table 5 gives a differential analysis of different RNN-based implementations in the RF domain. A second comparative study, covering different transmitter classification techniques, is shown in Table 6.
Table 5 Comparison of proposed approach with the existing RNN implementations
| Approaches | Model | SNR (dB) | Acc (%) | Inputs |
| Traffic sequence recognition [26] | LSTM | 20 | 31.20 | Hybrid real-synthetic dataset |
| Automatic modulation classification [33] | LSTM | 20 | 90 | Synthetic dataset [32] |
| Transmitter classification (Ours) | ConvLSTM | 30 | 97.20 | Raw signal |
| Hetero-transmitter classification (Ours) | ConvLSTM | 30 | 99.99 | Raw signal |
Table 6 Comparison of our implementation with the existing transmitter classification approaches
| Approach | #Trans | SNR (dB) | Acc (%) | Inputs |
| Orthogonal component reconstruction (OCR) [57] | 3 | 20 | 62–71 | Spurious modulation |
| Genetic Algorithm [51] | 5 | 25 | 85–98 | Transients |
| Multifractal segmentation [45] | 8 | Not mentioned | 92.50 | Transients |
| k-NN [18] | 8 | 30 | 97.20 | Transients |
| Ours | 8 | 30 | 97.20 | Raw signal |
| Ours-hetero | 5 | 30 | 99.99 | Raw signal |
The “Inputs” column in both tables refers to the type of inputs used by the methods under consideration. Table 5 compares our ConvLSTM-based RNN for transmitter classification with other RNN-based implementations for separate tasks such as modulation recognition and traffic sequence recognition. Table 6 establishes the efficacy of our ConvLSTM-based RNN model for the task of transmitter classification in terms of testing accuracies. Note that all the other methods use expert-crafted features as inputs [18, 45, 51, 57] or work with synthetic datasets [26, 33]. Our method, on the other hand, achieves superior accuracy for both homogeneous (97.20%) and heterogeneous (99.99%) datasets using features automatically learned from the raw signal data, thereby paving the way for real-time deployment of large-scale transmitter identification systems.
6.8 Performance Comparison for Varying SNR
In this section, we present the results of transmitter fingerprinting for varying SNR values. We compare the accuracies of the proposed RNN models for 8 USRP B210s at 30 dB SNR with three other datasets collected at 0, 10, and 20 dB SNR with the same number of transmitters (8 B210s), as shown in Table 7. We achieve better accuracies with all the models for higher SNR values, which is intuitive. Notably, the proposed ConvLSTM RNN model gives more than 93% accuracy even at 0 dB SNR, whereas the GRU model reaches 90.3% and the LSTM model only 84.23%. Moreover, the proposed RNN models can be trained using raw signal data from any type of radio transmitter operating in indoor as well as outdoor environments. We would also like to point out that although our data was collected in a lab environment, we had no control over that environment: other transmissions were in progress, people were moving in and out of the lab, and there was a lot of multi-path due to the location and design of the lab. Furthermore, the power of the transmitters was low, which compounded the problem. Given this, although we say that the data was collected in a lab environment, in reality it was an uncontrolled, daily-use environment reflective of our surroundings. Thus we can reasonably say that these methods will work in real-world deployments of large-scale radio networks.
Table 7 Accuracies for different recurrent neural network models with varying SNRs
| SNR (dB) | LSTM Acc (%) | GRU Acc (%) | ConvLSTM Acc (%) |
| 0 | 84.23 | 90.3 | 93.3 |
| 10 | 90.21 | 92.64 | 95.64 |
| 20 | 91.89 | 94.02 | 97.02 |
| 30 | 92.00 | 95.30 | 97.20 |
In summary,
• Exploiting temporal correlation only, recurrent neural networks with LSTM or GRU cells yield 92.00–97.76% accuracy for transmitter classification (Table 2). The RNN implementation with GRU cells needs fewer parameters than the one with LSTM cells, as shown in Table 2.
• Exploiting spatio-temporal correlation, the implementation of the RNN using ConvLSTM2D cells provides better accuracy (97.20–98.90%) for transmitter classification, thus providing a potential tool for building automatic real-world transmitter identification systems.
• The spatio-temporal correlation is more noticeable (with a 1.5–5% improvement in classification accuracy) in the proposed RNN models for heterogeneous transmitters, whether they are different models from the same manufacturer or different models from different manufacturers.
• The proposed RNN models give better accuracies with increasing SNR of the data collection environment. However, the ConvLSTM model is able to classify with 93% accuracy even at 0 dB SNR, which shows the robustness of exploiting the spatio-temporal property.
• We present a comparative study of the proposed spatio-temporal property-based fingerprinting with existing traditional and neural network-based models. It shows that the proposed model achieves better accuracy than the existing methods for transmitter identification.
7 Summary
With more and more autonomous deployments of wireless networks, accurate knowledge of the RF environment is becoming indispensable. In recent years, there has been a proliferation of autonomous systems that use deep learning algorithms on large-scale historical data. To that end, the inherent recurrent structures within historical RF data can also be leveraged by deep learning algorithms for reliable future prognosis. In this chapter, we addressed some of the fundamental challenges of effectively applying different learning techniques in the RF domain. We presented a robust transmitter identification technique that exploits both the inherent spatial and temporal properties of RF signal data. The testbed implementation and result analysis demonstrate the effectiveness of the proposed deep learning models. A natural next step is to apply these methods to the identification of actual infrastructure transmitters (for example FM, AM, and GSM) in real-world settings.
8 Further Reading
More details on deep learning algorithms can be found in [55]. Advanced applications of deep learning in the RF domain involving adversaries can be found in [35, 36, 40]. Deep learning applications in advanced areas of RF, such as dynamic spectrum access, are discussed in [37, 38].
References
1. M. Abadi et al., Tensorflow: large-scale machine learning on heterogeneous distributed systems. CoRR (2016)
2. S. Bai, M. Yan, Y. Luo, Q. Wan, RFedRNN: an end-to-end recurrent neural network for radio frequency path fingerprinting, in Recent Trends and Future Technology in Applied Intelligence (2018), pp. 560–571
3. C.M. Bishop, Pattern Recognition and Machine Learning (Information Science and Statistics) (Springer, 2006)
4. F. Chollet et al., Keras: the python deep learning library (2015). https://keras.io
5. J. Chung, C. Gülçehre, K. Cho, Y. Bengio, Empirical evaluation of gated recurrent neural networks on sequence modeling. CoRR (2014). arXiv:abs/1412.3555
6. B. Danev, S. Capkun, Transient-based identification of wireless sensor nodes, in International Conference on Information Processing in Sensor Networks (2009), pp. 25–36
7. Analog Devices: ADALM-PLUTO overview (2020). https://wiki.analog.com/university/tools/pluto
8. R. Dey, F.M. Salem, Gate-variants of gated recurrent unit (GRU) neural networks, in 2017 IEEE 60th International Midwest Symposium on Circuits and Systems (MWSCAS) (2017), pp. 1597–1600
9. Ettus Research: USRP B200 (2020). https://www.ettus.com/all-products/ub200-kit/
10. Ettus Research: USRP B210 (2020). https://www.ettus.com/product/details/UB210-KIT/
11. Ettus Research: USRP X310 (2020). https://www.ettus.com/all-products/x310-kit/
12. GNURadio: GNU Radio (2020). https://www.gnuradio.org
13. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning (MIT Press, 2016)
14. A. Gritsenko, Z. Wang, T. Jian, J. Dy, K. Chowdhury, S. Ioannidis, Finding a ‘New’ needle in the haystack: unseen radio detection in large populations using deep learning, in IEEE International Symposium on Dynamic Spectrum Access Networks (DySPAN) (2019), pp. 1–10
15. K. He, J. Sun, Convolutional neural networks at constrained time cost, in IEEE CVPR (2015)
16. S. Hochreiter, J. Schmidhuber, Long short-term memory. Neural Comput. 9(8), 1735–1780 (1997)
17. S. Ioffe, C. Szegedy, Batch normalization: accelerating deep network training by reducing internal covariate shift. CoRR (2015). arXiv:abs/1502.03167
18. I.O. Kennedy, P. Scanlon, F.J. Mullany, M.M. Buddhikot, K.E. Nolan, T.W. Rondeau, Radio transmitter fingerprinting: a steady state frequency domain approach, in IEEE Vehicular Technology Conference (2008), pp. 1–5
19. B.P. Lathi, Modern Digital and Analog Communication Systems, 3rd edn. (Oxford University Press Inc, USA, 1998)
20. Y. LeCun, Y. Bengio, G. Hinton, Deep learning. Nature 521(7553), 436–444 (2015)
21. H.W. Lin, M. Tegmark, D. Rolnick, Why does deep and cheap learning work so well? J. Stat. Phys. 168(6), 1223–1247 (2017)
22. R. Livni, S. Shalev-Shwartz, O. Shamir, On the computational efficiency of training neural networks, in Advances in Neural Information Processing Systems (2014), pp. 855–863
23. V. Nair, G.E. Hinton, Rectified linear units improve restricted boltzmann machines, in Proceedings of International Conference on International Conference on Machine Learning (2010), pp. 807–814
24. NooElec: USRP B210 (2018). http://www.nooelec.com/store/sdr/sdr-receivers/nesdr-minirtl2832-r820t.html
25. Nuand: bladeRF 2.0 micro xA4 (2020). https://www.nuand.com/product/bladeRF-xA4/
26. T.J. O’Shea, S. Hitefield, J. Corgan, End-to-end Radio traffic sequence recognition with recurrent neural networks, in IEEE Global Conference on Signal and Information Processing (GlobalSIP) (2016), pp. 277–281
27. T. O’Shea, N. West, Radio machine learning dataset generation with GNU radio. Proc. GNU Radio Conf. 1(1) (2016)
28. T.J. O’Shea, T.C. Clancy, R.W. McGwier, Recurrent neural radio anomaly detection. CoRR (2016). arXiv:abs/1611.00301
29. T.J. O’Shea, J. Corgan, T.C. Clancy, Convolutional radio modulation recognition networks, in Engineering Applications of Neural Networks (2016), pp. 213–226
30. T. O’Shea, J. Hoydis, An introduction to deep learning for the physical layer. IEEE Trans. Cogn. Commun. Netw. 3(4), 563–575 (2017)
31. T.J. O’Shea, T. Roy, T.C. Clancy, Over-the-air deep learning based radio signal classification. IEEE J. Sel. Top. Signal Process. 12(1), 168–179 (2018)
32. radioML: RFML 2016 (2016). https://github.com/radioML/dataset
33. S. Rajendran et al., Deep learning models for wireless signal classification with distributed low-cost spectrum sensors. IEEE Trans. Cogn. Commun. Netw. 4(3), 433–445 (2018)
34. J.S. Ren, Y. Hu, Y.W. Tai, C. Wang, L. Xu, W. Sun, Q. Yan, Look, listen and learn-a multimodal LSTM for speaker identification, in AAAI (2016), pp. 3581–3587
35. D. Roy, T. Mukherjee, M. Chatterjee, Machine learning in adversarial RF environments. IEEE Commun. Mag. 57(5), 82–87 (2019)
36. D. Roy, T. Mukherjee, M. Chatterjee, E. Blasch, E. Pasiliao, RFAL: adversarial learning for RF transmitter identification and classification. IEEE Trans. Cogn. Commun. Netw. (2019)
37. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, Defense against PUE attacks in DSA networks using GAN based learning, in IEEE Global Communications Conference (2019)
38. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, Primary user activity prediction in DSA networks using recurrent structures, in IEEE International Symposium on Dynamic Spectrum Access Networks (2019), pp. 1–10
39. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, RF transmitter fingerprinting exploiting spatio-temporal properties in raw signal data, in IEEE International Conference on Machine Learning and Applications (2019), pp. 89–96
40. D. Roy, T. Mukherjee, M. Chatterjee, E. Pasiliao, Detection of rogue RF transmitters using generative adversarial nets, in IEEE Wireless Communications and Networking Conference (WCNC) (2019)
41. H. Rutagemwa, A. Ghasemi, S. Liu, Dynamic spectrum assignment for land mobile radio with deep recurrent neural networks, in IEEE International Conference on Communications Workshops (ICC Workshops) (2018), pp. 1–6
42. Y.B. Saied, A. Olivereau, D-HIP: a distributed key exchange scheme for HIP-based internet of things, in World of Wireless, Mobile and Multimedia Networks (WoWMoM) (2012), pp. 1–7
43. H. Sak, A.W. Senior, F. Beaufays, Long short-term memory based recurrent neural network architectures for large vocabulary speech recognition. CoRR (2014). arXiv:abs/1402.1128
44. K. Sankhe, M. Belgiovine, F. Zhou, L. Angioloni, F. Restuccia, S. D’Oro, T. Melodia, S. Ioannidis, K. Chowdhury, No radio left behind: radio fingerprinting through deep learning of physical-layer hardware impairments. IEEE Trans. Cogn. Commun. Netw. 1 (2019)
45. D. Shaw, W. Kinsner, Multifractal modelling of radio transmitter transients for classification, in IEEE WESCANEX (1997), pp. 306–312
46. X. Shi, Z. Chen, H. Wang, D.Y. Yeung, W.K. Wong, W.C. Woo, Convolutional LSTM network: a machine learning approach for precipitation nowcasting, in Proceedings of the 28th International Conference on Neural Information Processing Systems, vol. 1 (2015), pp. 802–810
47. Y. Shi, K. Davaslioglu, Y.E. Sagduyu, W.C. Headley, M. Fowler, G. Green, Deep learning for RF signal classification in unknown and dynamic spectrum environments, in IEEE International Symposium on Dynamic Spectrum Access Networks (DySPAN) (2019), pp. 1–10
48. S.W. Smith, The Scientist and Engineer’s Guide to Digital Signal Processing (California Technical Publishing, San Diego, CA, USA, 1997)
49. S. Soltani, Y.E. Sagduyu, R. Hasan, K. Davaslioglu, H. Deng, T. Erpek, Real-time experimentation of deep learning-based RF signal classifier on FPGA, in IEEE International Symposium on Dynamic Spectrum Access Networks (DySPAN) (2019), pp. 1–2
50. M. Stanislav, T. Beardsley, Hacking IoT: a case study on baby monitor exposures and vulnerabilities. Rapid 7 (2015)
51. J. Toonstra, W. Kinsner, A radio transmitter fingerprinting system ODO-1, in Canadian Conference on Electrical and Computer Engineering, vol. 1 (1996), pp. 60–63
52. D. Tse, P. Viswanath, Fundamentals of Wireless Communication (Oxford University Press Inc, USA, 2005)
53. E. Tsironi, P. Barros, C. Weber, S. Wermter, An analysis of convolutional long short-term memory recurrent neural networks for gesture recognition. Neurocomputing 268, 76–86 (2017)
54. N. Wagle, E. Frew, Spatio-temporal characterization of airborne radio frequency environments, in IEEE GLOBECOM Workshops (GC Wkshps) (2011), pp. 1269–1273
55. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in Deep Learning (Springer, 2020)
56. R.H. Weber, R. Weber, Internet of Things, vol. 12 (Springer, 2010)
172
D. Roy et al.
57. S. Xu, L. Xu, Z. Xu, B. Huang, Individual radio transmitter identification based on spurious modulation characteristics of signal envelop, in IEEE MILCOM (2008), pp. 1–5 58. K. Youssef, L. Bouchard, K. Haigh, J. Silovsky, B. Thapa, C.V. Valk, Machine learning approach to RF transmitter identification. IEEE J. Radio Freq. Identif. 2(4), 197–205 (2018) 59. M.D. Zeiler, ADADELTA: an adaptive learning rate method. CoRR (2012). arXiv:abs/1212.5701
Human Target Detection and Localization with Radars Using Deep Learning

Michael Stephan, Avik Santra, and Georg Fischer

M. Stephan · A. Santra: Infineon Technologies AG, Neubiberg, Germany
M. Stephan · G. Fischer: Friedrich-Alexander-University Erlangen-Nuremberg, Erlangen, Germany

© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021. M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_8

Abstract In this contribution, we present a novel radar pipeline based on deep learning for the detection and localization of human targets in indoor environments. Detecting human targets can help save energy in commercial buildings, public spaces, and smart homes through automatic control of lighting, heating, ventilation, and air conditioning (HVAC). Such smart sensing applications facilitate monitoring, controlling, and thus saving energy. Conventionally, radar targets are detected either in the range-Doppler domain or in the range-angle domain. Depending on the application and the radar sensor, the angle or Doppler is estimated subsequently to finally localize the human target in 2D space. When detection is performed in the range-Doppler domain, the processing pipeline includes moving target indicators (MTI) to remove static targets from range-Doppler images (RDI) and maximal ratio combining (MRC) to integrate data across antennas, followed by constant false alarm rate (CFAR)-based detectors and clustering algorithms to generate the processed RDI detections. In the other case, the pipeline replaces MRC with Capon or minimum variance distortionless response (MVDR) beamforming to transform the raw RDIs from multiple receive channels into raw range-angle images (RAI), followed by CFAR and a clustering algorithm to generate the processed RAI detections. However, in the conventional pipeline, particularly for indoor human target detection, both domains suffer from ghost targets and multipath reflections from static objects such as walls and furniture. Further, conventional parametric clustering algorithms lead to single-target splits and adjacent-target merges in the range-Doppler and range-angle detections. To overcome these issues, we propose a deep learning-based architecture based on the deep residual U-net model and the deep complex U-net model to generate human target detections directly from the raw RDI. We demonstrate that the proposed deep residual U-net and complex U-net models are capable of generating accurate target detections in the range-Doppler and the range-angle domain, respectively. To train these networks, we record RDIs from a variety of indoor scenes with different configurations and multiple humans performing several regular activities. We devise a custom loss function and apply augmentation strategies so that the model generalizes during real-time inference. We demonstrate that the proposed networks can efficiently learn to detect and localize human targets correctly in different indoor environments, in scenarios where the conventional signal processing pipeline fails.

Keywords mm-wave radar sensor · Deep residual U-Net · Deep complex U-Net · Receiver operating characteristics · Detection strategy · DBSCAN · People sensing · Localization · Human target detection · Occupancy detection
1 Introduction

The ever-increasing energy consumption has given rise to a search for energy-efficient smart home technologies that can monitor and save energy, thus enhancing sustainability and reducing the carbon footprint. Depending on baseline and operation, several studies show that energy consumption in residential, commercial, or public spaces can be reduced by 25–75% [1] by monitoring occupancy or counting the number of people and regulating artificial light and HVAC systems accordingly [2]. Frequency modulated continuous wave (FMCW) radars can provide a ubiquitous solution to sense, monitor, and thus control the energy consumption of household appliances.

Radar has evolved from automotive applications such as driver assistance systems, safety and driver alert systems, and autonomous driving systems to low-cost solutions, penetrating industrial and consumer market segments. Radar has been used for perimeter intrusion detection systems [3], gesture recognition [4–6], human-machine interfaces [7], outdoor positioning and localization [8], and indoor people counting [9].

Radar responses from human targets are, in general, spread across Doppler due to the macro-Doppler component of the torso and the associated micro-Doppler components due to hand, shoulder, and leg movements. With higher sweep bandwidths, the radar echoes from targets are not received as point targets but are spread across range and are referred to as range-extended targets. Thus, human targets are perceived as doubly spread targets [10] across range and Doppler. Human target detection using radars with higher sweep bandwidths therefore requires several adaptations of the standard signal processing pipeline before the data is fed into application-specific processing, e.g., target counting or target tracking. Depending on the application and the system used, there are two signal processing pipelines. In the first, the processing is performed in the range-Doppler domain, and the objective is to perform target detection and localization in the same domain.
In such a case, the radar RDI processing pipeline involves MTI to remove static targets and MRC to integrate data across antennas. We refer to the RDI after this stage as the raw RDI. The raw RDI is then fed into a two-dimensional constant false alarm rate (2D-CFAR) detection algorithm, which decides whether a cell under test (CUT) is a valid target by evaluating the statistics of the reference cells. CFAR detection adaptively estimates the noise floor from the statistics around the CUT and guarantees a fixed false alarm rate, which sets the threshold multiplier by which the estimated noise floor is scaled. Further, the detections are grouped using a clustering algorithm such that reflections from the same human target are detected as a single cluster, and likewise for different humans. Alternately, in the case of multiple virtual channels, the radar processing pipeline involves MTI to remove static targets, followed by receiver angle-of-arrival estimation through Capon or minimum variance distortionless response (MVDR) beamforming to transform the raw RDI data into a range-angle image (RAI). Generating the RAI provides a means to map the reflective distribution of the target from multi-frequency, multi-aspect data onto the angular beams, making it possible to localize the target in the spatial domain. Following the RAI, targets are detected through 2D-CFAR in the range-angle domain and eventually clustered using a clustering algorithm. The former processing pipeline is typically applied to single-receive-channel sensor systems, where angular information is missing, or to applications such as activity or fall classification, where the relevant information lies in the range-Doppler domain. Other applications, which require localizing the target in 2D space, require beamforming and detection in the range-angle domain.
However, human target detection in indoor environments poses further challenges, such as ghost targets from static objects like walls, chairs, and furniture, and spurious radar responses due to multi-path reflections from multiple targets [11]. Further, strongly reflecting or closer human targets often occlude weakly reflecting or farther human targets at the CFAR detector output. While the former phenomenon leads to overestimating the count of humans, the latter leads to underestimating the count of humans in the room. This results in increased false alarms and low detection probabilities, leading to poor radar receiver operating characteristics and inaccurate application-specific decisions based on human target counts or target tracking. With the advent of deep learning, a variety of computer vision tasks, namely face recognition [12, 13], object detection [14, 15], and segmentation [16, 17], have seen substantial advances in state-of-the-art performance. The book in [18] gives a good overview of basic and advanced concepts in deep neural networks, especially in computer vision tasks. Image segmentation is a task wherein, given an image, all the pixels belonging to a particular class are extracted and indicated in the output image. In [17], the authors proposed the deep U-Net architecture, which concatenates feature maps from different levels, i.e., from low-level to high-level, to achieve image segmentation. In [19], the authors proposed deep residual connections to facilitate training speed and improve classification accuracy. Recently, deep residual U-net-like network architectures have been proposed for identifying road lines from SAR images [20]. We have re-purposed the deep residual U-Net to process raw RDIs
and generate target detected RDIs or RAIs. The target detected RDIs and RAIs are demonstrated to suppress reflections from ghost targets, reject spurious targets due to multi-path reflections, avoid target occlusions, and achieve accurate target clustering, thus resulting in reliable and accurate human target detections. In [21], the authors proposed a fully convolutional neural network to achieve object detection and estimation of a vehicle target in 3D space by replacing the conventional signal processing pipeline. In [22], the authors proposed a deep neural network to distinguish the presence or absence of a target and demonstrated its improved performance compared to conventional signal processing. In [23], we proposed using a deep residual U-net architecture to process raw RDIs into processed RDIs, achieving reliable target detections in the range-Doppler domain. In this contribution, we use a 60-GHz frequency modulated continuous wave (FMCW) radar sensor to demonstrate the performance of our proposed solution in both the range-Doppler and the range-angle domain. To handle our specific challenges and use case, we define and train our proposed deep residual U-net model using an appropriate loss function for processing in the output range-Doppler domain. We also propose a complex U-Net model to process the raw RDIs from two receive antennas to construct an RAI. For the complex neural network layers, we use the TensorFlow Keras (an open-source neural network library) implementation by Dramsch and contributors [24] of the complex layers described in [25]. We demonstrate the performance of our proposed system on real data with up to four persons in indoor environments for both cases and compare the results to the conventional signal processing approach.

The paper is outlined as follows. Section 2 presents the system design, with details of the hardware chipset in Sect. 2.1, the conventional processing pipeline in Sect. 2.2, and the challenges and contributions in Sect. 2.3. We present the models of our proposed residual U-Net and complex U-Net architectures in Sects. 3 and 4. We describe the dataset in Sect. 5.1, the loss function in Sect. 5.2, and the design considerations for training in Sect. 5.3. We present the results and discussion in Sect. 6 and conclude in Sect. 7, also outlining possible future directions.
2 System Design

2.1 Radar Chipset

The prototype indoor people counting system is based on Infineon's BGT60TR24B FMCW radar chipset, shown in Fig. 1a. The functional block diagram of the bistatic FMCW radar with one transmit and one receive antenna, representing the scenario with people, furniture, and walls in indoor environments, is depicted in Fig. 1b. The BGT60TR24B operates at frequencies from 57 to 64 GHz, and the chirp duration can be configured. The chip features an external phase-locked loop that controls the linear frequency sweep. The tune voltage output that controls the loop is varied from 1 to 4.5 V to enable the voltage-controlled oscillator to generate highly linear frequency chirps over its bandwidth. The radar response from the target in the field of view is mixed with a replica of the transmitted signal, the resulting mixed signal is low-pass filtered, and the result is sampled by the analog-to-digital converter (ADC). The digital intermediate frequency (IF) signal contains target information such as range, Doppler, and angle, which can be estimated by digital signal processing algorithms.

Fig. 1 (a) Infineon's BGT60TR24B 60-GHz radar sensor (chipset). (b) Representational figure of the radar scene and functional block diagram of the FMCW radar RF signal chain depicting 1TX, 1RX channel

The chipset BGT60TR24 is configured with the system parameters given in Table 1. Consecutive sawtooth chirps are transmitted within a frame and processed by generating the RDI. Each chirp has NTS = 256 samples, representing the ADC samples; B = 4 GHz is the frequency sweep bandwidth, and Tc = 261 µs is the chirp duration. PN = 32 and TPRT = 520 µs are the number of chirps transmitted within a frame and the chirp repetition time, respectively. The range resolution is determined as $\delta r = c/(2B) = 3.75$ cm, and the maximum theoretical range is $R_{max} = (\mathrm{NTS}/2) \times \delta r = 4.8$ m; the division by 2 arises since the BGT60TR24 has only an I (in-phase) channel. With $f_c$ denoting the center frequency of the sweep, the maximum unambiguous velocity is

$$v_{max} = \frac{c}{2 f_c T_{PRT}} = 4.8\ \text{m/s} \tag{1}$$

and the velocity resolution is

$$\delta v = \frac{c}{2 f_c (\mathrm{PN}/2)\, T_{PRT}} = 0.3\ \text{m/s} \tag{2}$$
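As a quick numerical check of the values above, the following sketch recomputes the derived quantities from the configured chirp parameters. The center frequency of 60 GHz is an assumption (the midpoint of the 58–62 GHz ramp listed in Table 1), and the variable names are illustrative only.

```python
# Sketch: derived FMCW quantities from the configured chirp parameters.
C = 299_792_458.0      # speed of light in m/s
B = 4e9                # sweep bandwidth in Hz
NTS = 256              # samples per chirp
PN = 32                # chirps per frame
T_PRT = 520e-6         # chirp repetition time in s
F_C = 60e9             # ASSUMED center frequency in Hz (midpoint of 58-62 GHz)

delta_r = C / (2 * B)                        # range resolution in m
r_max = (NTS / 2) * delta_r                  # maximum range in m (I channel only)
v_max = C / (2 * F_C * T_PRT)                # maximum unambiguous velocity, Eq. (1)
delta_v = C / (2 * F_C * (PN / 2) * T_PRT)   # velocity resolution, Eq. (2)

print(f"delta_r = {100 * delta_r:.2f} cm, R_max = {r_max:.2f} m")
print(f"v_max = {v_max:.2f} m/s, delta_v = {delta_v:.2f} m/s")
```

Under these assumptions the results round to the 3.75 cm, 4.8 m, 4.8 m/s, and 0.3 m/s quoted in the text.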
Table 1 Operating parameters

Parameter, symbol                        Value
Ramp start frequency, f_min              58 GHz
Ramp stop frequency, f_max               62 GHz
Bandwidth, B                             4 GHz
Range resolution, δr                     3.75 cm
Number of samples per chirp, NTS         256
Maximum range, R_max                     4.8 m
Sampling frequency, f_s                  2 MHz
Chirp time, T_c                          261 µs
Chirp repetition time, T_PRT             520 µs
Maximum Doppler, v_max                   4.8 m/s
Number of chirps, PN                     32
Doppler resolution, δv                   0.3 m/s
Number of Tx antennas, N_Tx              1
Number of Rx antennas, N_Rx              3
Elevation θ_elev per radar               70°
Azimuth θ_azim per radar                 70°
We enabled 1 Tx and 3 Rx antennas from the L-shape configuration to cover both the elevation and azimuth angle calculations, although the chip contains 2 Tx and 4 Rx antennas. Both the elevation and azimuth 3 dB half-power beamwidths of the BGT60TR24 are 70°.
2.2 Processing Pipeline

Owing to the FMCW waveform and its ambiguity function properties, accurate range and velocity estimates can be obtained by decoupling them during the generation of the RDIs. The IF signal from a chirp with NTS = 256 samples is received, and PN consecutive chirps are collected and arranged in the form of a 2D matrix with dimensions PN × NTS. The RDI is generated in two steps. The first step involves calculating and subtracting the mean along fast time, followed by applying a 1D window function, zero-padding, and then a 1D Fast Fourier Transform (FFT) along fast time for all the PN chirps to obtain the range transformations. The fast time refers to the dimension of NTS, which represents the chirp time. Then, in the second step, the mean along slow time is calculated and subtracted, a 1D window function and zero-padding are applied, followed by a 1D FFT along slow time to obtain the Doppler transformation for all range bins. The slow time refers to the dimension along PN, which represents the inter-chirp time. The mean subtraction across the fast time removes the Tx-Rx leakage, and the subtraction across slow time removes
the reflections from any static object in the field of view of the radar. Since short chirp times are used, the frequency shift along fast time is mainly due to the two-way propagation delay from the target, which is determined by the distance of the target to the sensor. The amplitude and phase shift across slow time is due to the Doppler shift of the target. Based on the application and radar system, the processing pipeline can be either of the following:

• raw absolute RDI → 2D CFAR → DBSCAN → Processed RDI
• raw complex RDI (multiple channels) → MVDR → raw RAI → 2D CFAR → DBSCAN → Processed RAI

In the first pipeline, the target detection and clustering operations are performed in the range-Doppler domain. Detection and clustering are applied to the absolute raw RDI, and the output is the processed RDI. This is typically applied in the case of 1Tx 1Rx radar sensors or when the application is to extract target Doppler signatures for classification, such as human activity classification or vital sensing. In the alternate pipeline, the raw complex RDI from multiple channels is first processed through the MVDR algorithm to generate the RAI, followed by 2D CFAR and the clustering operation in that domain. The latter pipeline is employed where target detection through localization in 2D space is necessary for the application.

In the former pipeline, in the case of multiple receive channels, the raw RDIs are optionally combined through maximal ratio combining (MRC) to gain diversity and improve signal quality. The objective of MRC is to construct a single RDI as a weighted combination of the RDIs across multiple channels. The gains for the weighted averaging are determined by estimating the signal-to-noise ratio (SNR) for each RDI across antennas.
The effective RDI is computed as

$$I = \frac{\sum_{rx=1}^{N_{Rx}} g^{rx}\, |I^{rx}|}{\sum_{rx=1}^{N_{Rx}} g^{rx}} \tag{3}$$

where $I^{rx}$ is the complex RDI of the $rx$-th receive channel, and the gain is adaptively calculated as

$$g^{rx} = \frac{\mathrm{NTS} \times \mathrm{PN}\, \max\{|I^{rx}|^2\}}{\sum_{l=1}^{\mathrm{NTS}} \sum_{m=1}^{\mathrm{PN}} |I^{rx}(m, l)|^2 - \max\{|I^{rx}|^2\}} \tag{4}$$

where $\max\{\cdot\}$ represents the maximum value of the 2D function, and $g^{rx}$ represents the estimated SNR at the $rx$-th receive channel. Thus the (PN × NTS × N_Rx) RDI tensor is transformed into a (PN × NTS) RDI matrix. Alternately, in the latter pipeline, instead of MRC the raw complex RDI is converted to an RAI using the Capon or MVDR algorithm.
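The two-step RDI generation and the MRC combination of Eqs. (3) and (4) can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the Hann window, the omission of zero-padding, keeping only the positive half of the range spectrum (real-valued I-channel input), and the function names are all choices made here for the example.

```python
import numpy as np

def range_doppler_image(frame):
    """Turn one frame of real IF samples, shape (PN, NTS), into a complex RDI."""
    pn, nts = frame.shape
    # Step 1: remove the mean along fast time (Tx-Rx leakage), window, range FFT.
    x = frame - frame.mean(axis=1, keepdims=True)
    x = x * np.hanning(nts)[None, :]              # ASSUMED window type
    rng = np.fft.fft(x, axis=1)[:, : nts // 2]    # real input: keep positive half
    # Step 2: remove the mean along slow time (static objects), window, Doppler FFT.
    rng = rng - rng.mean(axis=0, keepdims=True)
    rng = rng * np.hanning(pn)[:, None]
    # Shift so that zero Doppler sits in the middle of the first axis.
    return np.fft.fftshift(np.fft.fft(rng, axis=0), axes=0)

def mrc_combine(rdis):
    """Combine complex RDIs of shape (NRx, PN, NR) following Eqs. (3) and (4)."""
    power = np.abs(rdis) ** 2
    n_cells = power[0].size                        # number of cells per RDI
    peak = power.max(axis=(1, 2))
    total = power.sum(axis=(1, 2))
    gains = n_cells * peak / (total - peak)        # Eq. (4), one gain per channel
    weighted = (gains[:, None, None] * np.abs(rdis)).sum(axis=0)
    return weighted / gains.sum()                  # Eq. (3)

# Toy usage: one synthetic target at range bin 40 and Doppler bin +5.
pn, nts = 32, 256
n = np.arange(nts)[None, :]
p = np.arange(pn)[:, None]
frame = np.cos(2 * np.pi * (40 * n / nts + 5 * p / pn))
rdi = range_doppler_image(frame)
peak_idx = np.unravel_index(np.argmax(np.abs(rdi)), rdi.shape)
rdis = np.stack([rdi, 2.0 * rdi, 0.5 * rdi])       # three hypothetical channels
combined = mrc_combine(rdis)
```

Because the gain of Eq. (4) is a ratio of powers, it is invariant to a per-channel scale factor, so the three scaled copies in the toy example receive equal weights.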
MVDR: Denote the 3D positional coordinates of the Tx element as $d^{Tx}$ and of the Rx elements as $d_n^{Rx}$, $n = 1, 2$. Assuming far-field conditions, the signal propagation from the Tx element $d^{Tx}$ to a point scatterer $p$ and subsequently the reflection from $p$ to the Rx element $d_n^{Rx}$ can be approximated as $2x + d \sin(\theta)$, where $x$ is the base distance of the scatterer to the center of the virtual linear array, $d$ is the distance between receive elements, and $\theta$ is the incident angle to the center of the array with respect to bore-sight. Under the same far-field assumption, the time delay of the radar return from a scatterer at base distance $x$ from the center of the virtual linear array can be expressed as

$$\tau_n = \frac{2x}{c} + \frac{d \sin(\theta)}{c} \tag{5}$$

The receiving steering vector is

$$a_n^{Rx}(\theta) = \exp\left(-j 2\pi \frac{d_n^{Rx} \sin(\theta)}{\lambda}\right); \quad n = 1, 2 \tag{6}$$

where $\lambda$ is the wavelength of the transmit signal. The two received de-ramped beat signals are used to resolve the relative angle $\theta$ of the scatterer. The azimuth imaging profile for each range bin can be generated using the Capon spectrum from the beamformer. The Capon beamformer is computed by minimizing the variance/power of noise while maintaining a distortionless response toward a desired angle. The corresponding quadratic optimization problem is

$$\min_{w}\; w^H C w \quad \text{s.t.} \quad w^H a^{Rx}(\theta) = 1, \tag{7}$$

where $C$ is the covariance matrix of the noise. The above optimization has the closed-form solution $w_{capon} = \frac{C^{-1} a(\theta)}{a^H(\theta) C^{-1} a(\theta)}$, with $\theta$ being the desired angle. On substituting $w_{capon}$ into the objective function of (7), the spatial spectrum is given as

$$P_l(\theta) = \frac{1}{a^{Rx}(\theta)^H C_l^{-1} a^{Rx}(\theta)} \tag{8}$$

with $l = 0, \ldots, L$, where $L$ is the number of range bins. However, estimating the noise covariance at each range bin $l$ is difficult in practice. Hence $\hat{C}_l$, which contains the signal component as well, is used instead; it can be estimated using the sample matrix inversion (SMI) technique as $\hat{C}_l = \frac{1}{K} \sum_{k=1}^{K} s_l^{IF}(k)\, s_l^{IF}(k)^H$, where $K$ denotes the number of chirps in a frame used for signal-plus-noise covariance estimation and $s_l^{IF}(k)$ is the de-ramped intermediate frequency signal at range bin $l$ [26].

2D CFAR: Following the earlier operation in each pipeline, the raw RDI in the former case and the raw RAI in the latter case is fed into a CFAR detection algorithm to
generate a hit matrix that indicates the detected targets. The target detection problem can be expressed as

$$I_{det}(cut) = \begin{cases} 1 & \text{if } |I(cut)|^2 > \mu \sigma^2 \\ 0 & \text{if } |I(cut)|^2 < \mu \sigma^2 \end{cases} \tag{9}$$

where $I_{det}(cut)$ is the detection output for the cell under test (CUT), depending on the input image $I$. The product of $\mu$ and $\sigma^2$, the estimated noise variance, represents the threshold for detection. The threshold multiplier $\mu$ for the CFAR detection is set by an acceptable probability of false alarm, which for cell-averaging CFAR (CA-CFAR) is given as

$$\mu = N \left(P_{fa}^{-1/N} - 1\right) \tag{10}$$

where $P_{fa}$ is the probability of false alarm, and $N$ is the window size of the so-called “reference” cells used for noise power estimation. The noise variance $\sigma^2$ is estimated from the “reference” cells around the CUT. Thus, the estimated noise variance takes into account the local interference-plus-noise power around the CUT. While CA-CFAR is the most common detector for point targets, for doubly spread targets, such as humans observed with wide radar bandwidth, it leads to poor detections since the spread target is also present in the reference cells, leading to high noise power estimates and missed detections. Additionally, CA-CFAR elevates the noise threshold near a strong target, thus occluding nearby weaker targets. For such doubly spread targets, order-statistics CFAR (OS-CFAR) is used instead to avoid these issues, since the ordered statistic is robust to outliers in the reference cells, in this case the target's spread. Hence, instead of the mean power in the reference cells, the kth ordered sample is selected as the estimated noise variance $\sigma^2$. A detailed description of OS-CFAR can be found in [27].

DBSCAN: Contrary to a point target, in the case of doubly extended targets, the output of the detection algorithm is not a single detection in the RDI for a target but is spread across range and Doppler.
Thus, a clustering algorithm is required to group the detections from a single target, based on its size, into a single cluster. For this, the density-based spatial clustering of applications with noise (DBSCAN) algorithm is used, a widely used unsupervised clustering algorithm in the machine learning community. Given a set of target detections from the same and from multiple targets in the RAI, DBSCAN groups detections that are closely packed together, while removing as outliers detections that lie alone in low-density regions. To do this, DBSCAN classifies each point as either a core point, an edge point, or noise. Two input parameters are needed for the DBSCAN clustering algorithm: the neighborhood radius d and the minimum number of neighbors M. A point is defined as a core point if it has at least M − 1 neighbors, i.e., points within the distance d. An edge point has fewer than M − 1 neighbors, but at least one of its neighbors is a core point. All points that have fewer than M − 1 neighbors and no core point as a neighbor do not belong to any cluster and are classified as noise [28].
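The detection and clustering stages described above can be sketched as follows. This is an illustrative NumPy version under stated assumptions rather than the pipeline actually used: the guard and reference window sizes, the rank fraction `k_frac`, and the multiplier `mu` are hypothetical values, the windows are simply clipped at the image border, and the exact OS-CFAR multiplier for a given false alarm rate differs from the CA-CFAR expression of Eq. (10), so `mu` is left as a free parameter.

```python
import numpy as np

def ca_cfar_multiplier(p_fa, n_ref):
    """Threshold multiplier of Eq. (10) for cell-averaging CFAR."""
    return n_ref * (p_fa ** (-1.0 / n_ref) - 1.0)

def os_cfar_2d(power, guard=1, ref=4, k_frac=0.75, mu=10.0):
    """Ordered-statistic CFAR on a 2D power map; returns a boolean hit matrix."""
    rows, cols = power.shape
    half = guard + ref
    hits = np.zeros((rows, cols), dtype=bool)
    for r in range(rows):
        for c in range(cols):
            r0, r1 = max(r - half, 0), min(r + half + 1, rows)
            c0, c1 = max(c - half, 0), min(c + half + 1, cols)
            window = power[r0:r1, c0:c1]
            is_ref = np.ones(window.shape, dtype=bool)
            # Exclude the CUT and its guard cells from the reference set.
            is_ref[max(r - guard, 0) - r0 : min(r + guard + 1, rows) - r0,
                   max(c - guard, 0) - c0 : min(c + guard + 1, cols) - c0] = False
            refs = np.sort(window[is_ref])
            sigma2 = refs[int(k_frac * (refs.size - 1))]   # k-th ordered sample
            hits[r, c] = power[r, c] > mu * sigma2          # Eq. (9)
    return hits

def dbscan(points, d, m):
    """Label points with cluster ids (-1 = noise) for radius d and M = m."""
    n = len(points)
    dist = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
    neighbors = [np.flatnonzero((dist[i] <= d) & (np.arange(n) != i))
                 for i in range(n)]
    is_core = np.array([len(nb) >= m - 1 for nb in neighbors])
    labels = np.full(n, -1)
    cluster = 0
    for i in range(n):
        if not is_core[i] or labels[i] != -1:
            continue
        labels[i] = cluster
        stack = [i]
        while stack:                       # grow the cluster from core points only
            j = stack.pop()
            for q in neighbors[j]:
                if labels[q] == -1:
                    labels[q] = cluster
                    if is_core[q]:
                        stack.append(q)
        cluster += 1
    return labels

# Toy usage: two extended targets and one isolated false alarm on a flat noise floor.
power = np.ones((32, 16))
power[10:12, 5:7] = 100.0
power[20:22, 10:12] = 100.0
power[30, 2] = 100.0
hits = os_cfar_2d(power)
points = np.argwhere(hits).astype(float)
labels = dbscan(points, d=1.5, m=3)
mu_ca = ca_cfar_multiplier(1e-4, 16)
```

In the toy example the two 2 × 2 blobs form one cluster each, while the isolated detection has no neighbor within the radius and is labeled as noise, which mirrors the outlier-removal role of DBSCAN described above.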
2.3 Challenges and Contributions

Indoor environments typically pose a multitude of challenges to human presence detection, localization, and people counting systems using radar sensors. The primary error sources are multi-path reflections, ghost targets, target occlusion, merging targets, and split targets. Multi-path reflections and ghost targets occur due to reflections from static objects like walls, chairs, or tables to the human target and back to the radar sensor. These ghost targets and spurious reflections appear as false alarms at the output of the CFAR detection algorithm. Figure 2a presents the indoor environment setup wherein a human walks close to the wall. Figure 2b depicts the target detected RDI obtained with a conventional processing pipeline, where the true target is marked in green and ghost targets and multi-path artifacts are marked in red.

Target occlusion may happen if another object is in the line of sight between the target and the radar sensor, if two targets are close together, or if the reflections from the target are weakened for other reasons, such as translational range migration, rotational range migration, or speckle [29]. Adaptive CFAR algorithms fail to detect the true target in such scenarios. The problem of merging targets and split targets partly originates from the parametric clustering algorithms. With DBSCAN clustering specifically, clusters may merge if the neighborhood radius is set too high. When it is set too low, however, the arms, legs, or head of a human target may be recognized as separate targets. Depending on the indoor environment and the activity of the human, the radar response from a target will occupy a varying number of points in the range-Doppler domain. Setting a neighborhood radius that works for all scenarios is thus very difficult, if not impossible.

The traditional processing pipeline with ordered statistics CFAR (OS-CFAR) and DBSCAN is depicted in Fig. 3a, whereas the proposed deep residual U-net architecture to process the target detected RDI is presented in Fig. 3b. Inspired by the deep residual U-Net for image segmentation problems, in this contribution we propose to use the deep residual U-Net architecture to generate detection RDIs while additionally removing ghost targets and multi-path reflections, preventing target occlusion, and achieving accurate target clustering. After the target detected RDIs are computed, the number of targets and their parameters, i.e., range, velocity, and angle, are estimated.

The traditional processing pipeline, which takes raw RDIs from multiple receive antennas as input and processes them through MVDR, OS-CFAR, and DBSCAN, is depicted in Fig. 4a. In contrast, in the proposed deep complex U-net model, the raw RDIs from different channels are processed directly by the neural network to reconstruct the target detected RAI, as presented in Fig. 4b. The objective of the proposed deep complex U-Net model is to reconstruct detection RAIs in which ghost targets and multi-path reflections are removed, target occlusion is prevented, and accurate target clustering is achieved. After the target detected RDIs or RAIs are reconstructed, the number of targets and their parameters, i.e., range, velocity, and angle, are estimated. The list of targets with their parameters is then fed to the application-specific processing or people tracking algorithm.
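For reference, the MVDR step that the traditional range-angle pipeline runs before OS-CFAR and DBSCAN can be sketched as below, following the Capon spectrum of Eq. (8) with a sample covariance estimate. The diagonal loading term, the half-wavelength element spacing, the 60 GHz wavelength, and the synthetic single-target data are assumptions introduced only for this illustration.

```python
import numpy as np

def capon_spectrum(snapshots, positions, wavelength, angles_deg, loading=1e-3):
    """Capon spatial spectrum of Eq. (8) for one range bin.

    snapshots: complex array (n_rx, K), one column per chirp.
    positions: receive element positions along the array axis in m.
    """
    n_rx, k = snapshots.shape
    cov = snapshots @ snapshots.conj().T / k          # sample covariance (SMI)
    # Diagonal loading (an added assumption) keeps the inverse well conditioned.
    cov = cov + loading * np.trace(cov).real / n_rx * np.eye(n_rx)
    cov_inv = np.linalg.inv(cov)
    theta = np.deg2rad(np.asarray(angles_deg, dtype=float))
    # Steering vectors of Eq. (6), one column per candidate angle.
    steering = np.exp(-2j * np.pi * np.outer(positions, np.sin(theta)) / wavelength)
    denom = np.einsum("ia,ij,ja->a", steering.conj(), cov_inv, steering).real
    return 1.0 / denom

# Toy usage: a single scatterer at +20 degrees seen by two receive elements.
wavelength = 3e8 / 60e9                       # ASSUMED 60 GHz carrier
positions = np.array([0.0, wavelength / 2])   # ASSUMED half-wavelength spacing
angles = np.arange(-90, 91)
rng = np.random.default_rng(0)
k = 32
a_true = np.exp(-2j * np.pi * positions * np.sin(np.deg2rad(20.0)) / wavelength)
signal = np.exp(2j * np.pi * 0.1 * np.arange(k))
noise = 0.05 * (rng.standard_normal((2, k)) + 1j * rng.standard_normal((2, k)))
snapshots = a_true[:, None] * signal[None, :] + noise
spectrum = capon_spectrum(snapshots, positions, wavelength, angles)
peak_angle = int(angles[np.argmax(spectrum)])
```

Evaluating `capon_spectrum` for every range bin and stacking the results row by row would give a raw RAI of the kind that the traditional pipeline then passes to 2D CFAR.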
Fig. 2 (a) Indoor room environment with a human walking around in the room. (b) Corresponding processed RDI (range in m versus velocity in m/s) as sensed by the radar with OS-CFAR and DBSCAN. The processed RDI depicts the true target (in green) and ghost targets (in red) due to reflections from walls, etc.
3 Network Architecture—Deep Residual U-Net Figure 5 illustrates the network architecture. It has an encoder and a decoder path, each with three resolution steps. In the encoder path, each block contains two 3 × 3 convolutions, each followed by a rectified linear unit (ReLu), and a 3 × 3 max pooling with strides of two in every dimension. In the decoder path, each block consists of an upconvolution of 2 × 2 with strides of two in each dimension, followed by a ReLu and two 3 × 3 convolutions, each followed by a ReLu. The up-convolutions are implemented as an upsampling layer followed by a convolutional layer. Skip connections from layers of equal resolution in the encoder path provide the highresolution features to the decoder path, like the standard U-net. The bridge connection between the encoder and the decoder network consists of two convolutional layers, each followed by a ReLu, with a Dropout Layer in between. The dropout is set to 0.5 during training to reduce overfitting. In the last layer, a 1 × 1 convolution
184
M. Stephan et al. Raw RDI
0
1
OS-CFAR Detection
2 3
range in m
range in m
Traditioanl
0
DBSCAN Clustering
2
0
2
velocity in m/s
3
4
4
0
2
2
4
velocity in m/s
(a) Traditional Processing Pipeline
Detected Target RDI
Raw RDI
0
1
Deep Residual U-Net
2 3
range in m
0
range in m
2
4
4 4
Proposed
1
1 2 3 4
4 4
2
0
2
velocity in m/s
4
(b) Proposed Processing Pipeline
4
2
0
2
4
velocity in m/s
Fig. 3 a Traditional processing Pipeline with OS-CFAR and DBSCAN to generate target detected RDIs. b Processing Pipeline using the proposed deep residual U-Net to suppress ghost targets, multi-path reflections, mitigate target occlusions, and achieve accurate clustering
[Figure 4 panels: (a) Traditional Processing Pipeline; (b) Proposed Processing Pipeline]
Fig. 4 a Traditional processing pipeline with MVDR, OS-CFAR, and DBSCAN to generate target detected RAIs from input RDIs across channels. b Processing pipeline using the proposed deep complex U-Net to suppress ghost targets, multi-path reflections, mitigate target occlusions, and achieve accurate clustering
Human Target Detection and Localization with Radars …
[Figure 5 legend: Conv + ReLu + BN; Concat; MaxPool; Concat + UpConv + ReLu; Dropout; Concat + Conv + Softmax; Input Image; Output Image]
Fig. 5 Proposed RDI presence detection architecture for a depth 3 network. Each box corresponds to one or more layers
[Figure 6 legend: Cconv + ReLu + BN; Concat; strided Cconv + ReLu; UpCconv + ReLu; Conv + ReLu; Input Image; Output Image]
Fig. 6 Proposed RAI presence detection and localization architecture for a depth 4 network. Each box corresponds to one or more layers
reduces the number of output channels to the number of classes, which is two in our case (target present/absent). The architecture has 40,752 trainable parameters in total. As suggested in the literature, bottlenecks are prevented by doubling the number of channels before max pooling [30]. We adopt the same scheme in the decoder path to avoid bottlenecks. The input to the network is a 128 × 32 × 1 raw RDI. The output is a 128 × 32 × 1 image with pixel values between 0 and 1, representing the probability of target presence in each pixel.
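As a quick sanity check on the resolution steps described above, the following sketch computes the spatial size of the feature maps after each stride-two pooling step. It assumes "same" padding, so that each pooling step yields the ceiling of half the input size; the function name `encoder_shapes` is hypothetical and not part of the original work.

```python
def encoder_shapes(height, width, depth):
    """Spatial size of the feature map after each stride-2 pooling step.

    Assumes 'same' padding, i.e. each step maps n -> ceil(n / 2).
    """
    shapes = [(height, width)]
    for _ in range(depth):
        height = -(-height // 2)  # ceiling division
        width = -(-width // 2)
        shapes.append((height, width))
    return shapes


# 128 x 32 input RDI, three resolution steps as in the depth-3 network
for level, shape in enumerate(encoder_shapes(128, 32, 3)):
    print(f"level {level}: {shape[0]} x {shape[1]}")
```

Under this assumption, a depth-five variant would reach a 4 × 1 spatial size at the bridge, which agrees with the bridge input dimension reported for the deeper network in the results section.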
4 Network Architecture: Deep Complex U-Net
Figure 6 shows the network architecture used for the localization task. While it looks quite similar to Fig. 5, it differs significantly in the details. It is still a U-Net-like architecture, but it uses fully complex operations and 3D convolutions. A brief description of the complex convolutional layer and the complex activation layer is provided below:
Fig. 7 Illustration of complex convolution on 2D CNN from one layer to next
1. Complex Convolutional Layer: Figure 7 illustrates the 2D CNN operation using complex kernels and a complex input feature map at any layer i. The complex convolutional layer generates feature maps from the range-Doppler dimensions of both receive channels. The kth feature map in the ith layer can be expressed as
$$\hat{A}_{i,k} + j\,\hat{B}_{i,k} = (A_{i-1,k} * C_{i,k} - B_{i-1,k} * D_{i,k}) + j\,(A_{i-1,k} * D_{i,k} + B_{i-1,k} * C_{i,k}) \tag{11}$$
where $A_{i-1,k}$ and $B_{i-1,k}$ are the real and imaginary parts of the input feature map from layer $i-1$, and $\hat{A}_{i,k}$ and $\hat{B}_{i,k}$ are the real and imaginary parts of the kth feature map at the ith layer after the convolution operation. $C_{i,k}$ and $D_{i,k}$ are the real and imaginary parts of the kth kernel, respectively. The filter dimensions are $Q_i \times Q_i \times K_i$, where $Q_i$ is applied on the range-Doppler dimensions and $K_i$ along the spatial channels.
2. Complex Activation Function: The complex 2D layers progressively extract deeper feature representations. The activation function introduces non-linearity into the representation. The complex ReLU is implemented as
$$A_{i,k} + j\,B_{i,k} = \mathrm{ReLU}(\hat{A}_{i,k}) + j\,\mathrm{ReLU}(\hat{B}_{i,k}) \tag{12}$$
In effect, the complex ReLU leaves first-quadrant data unchanged and maps second-quadrant data to the positive half of the imaginary axis, third-quadrant data to the origin, and fourth-quadrant data to the positive half of the real axis.
In the proposed architecture, complex max pooling is avoided since it leads to model instabilities; progressive reduction of the range-Doppler dimensions is instead achieved through strided convolutions. In the encoder path, each block contains two 3 × 3 × 1 complex convolutional layers, each followed by a complex ReLU, and a 3 × 3 × 1 strided convolutional layer with a 2 × 2 × 1 stride. In the decoder path, the up-convolutional layers are implemented as a 2 × 2 × 1 upsampling layer followed by a 2 × 2 × 1 convolutional layer. Skip connections between blocks of the same depth in the encoder path and the decoder path provide the feature location information to the decoder. Batch normalization is applied after every convolutional layer in the encoder and the decoder. In the encoder path, the size of the channel dimension is doubled with every block, while the sizes of the image dimensions are halved. The opposite is true for the decoder path, where the number of channels is halved while the image is upsampled by a factor of two in every block. Due to the size of the 3D convolution filters, the network only combines the information provided by the individual antennas in the last complex convolutional layer, where a filter size of 1 × 1 × 2 is used without padding the input. The last convolutional layer, drawn in purple in Fig. 6, is a normal (real-valued) convolutional layer that combines the real and imaginary outputs of the previous layers into a single-channel output. The network with four blocks each in the encoder and the decoder has 68,181 trainable parameters. The inputs to the neural network are two 128 × 32 complex RDIs from two receiving antennas. The output is one 128 × 32 RAI with pixel values between 0 and 1, representing the probability of target presence in each range-angle pixel.
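The complex convolution of Eq. (11) and the complex ReLU of Eq. (12) can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: it uses a single 2D feature map, no padding, and cross-correlation (the usual deep learning convention for "convolution"). The helper names `conv2d_valid`, `complex_conv2d`, and `complex_relu` are hypothetical.

```python
import numpy as np


def conv2d_valid(x, k):
    """2D 'valid' cross-correlation of a feature map x with a kernel k."""
    kh, kw = k.shape
    out_h, out_w = x.shape[0] - kh + 1, x.shape[1] - kw + 1
    out = np.zeros((out_h, out_w), dtype=np.result_type(x, k))
    for u in range(kh):
        for v in range(kw):
            out += k[u, v] * x[u:u + out_h, v:v + out_w]
    return out


def complex_conv2d(a, b, c, d):
    """Eq. (11): (A + jB) * (C + jD) built from four real-valued convolutions."""
    real = conv2d_valid(a, c) - conv2d_valid(b, d)
    imag = conv2d_valid(a, d) + conv2d_valid(b, c)
    return real, imag


def complex_relu(a_hat, b_hat):
    """Eq. (12): ReLU applied separately to the real and imaginary parts."""
    return np.maximum(a_hat, 0.0), np.maximum(b_hat, 0.0)


rng = np.random.default_rng(0)
a, b = rng.standard_normal((6, 5)), rng.standard_normal((6, 5))  # input map
c, d = rng.standard_normal((3, 3)), rng.standard_normal((3, 3))  # kernel
a_hat, b_hat = complex_conv2d(a, b, c, d)
a_out, b_out = complex_relu(a_hat, b_hat)
```

Splitting the operation into four real convolutions is what allows complex layers to be built from standard real-valued building blocks; the result matches a direct convolution on complex arrays.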
5 Implementation Details
5.1 Dataset
To create the labeled dataset, the raw RDIs are processed with the respective traditional processing pipelines, using MVDR beamforming to create the labeled RAIs and MRC to create the labeled RDIs. A camera recorded the scene during the measurements; using its feedback, we removed ghost targets and multi-path reflections and added detections whenever targets were occluded by other humans or by static humans close to the wall. This was done by changing the parameters of the detection and clustering algorithms in the conventional pipeline so that the probability of detection approaches 100%, which also results in a high probability of false alarm. Concretely, the CFAR scaling factor is decreased in case of target occlusion, and the maximum neighbor distance for cluster detection with DBSCAN is reduced or increased in case of merged or separated targets, respectively. All falsely detected clusters are then manually removed using the camera data as a reference. The described process is relatively simple for one-target measurements, as the correct cluster is generally the one closest in range to the radar sensor. The dataset comprises measurements with one to four humans in the room. The RDIs are augmented to enlarge the dataset and achieve a broader generalization of the model. Because labeling becomes sharply more difficult when multiple humans are present, we synthetically computed RDIs with multiple targets by superimposing several raw one-target RDIs after translation and rotation. With this technique, a large number of RDIs, limited only by the number of possible combinations of the one-target measurements, can be synthesized. Some caution is required to have a large enough basis of one-target measurements, as the network may otherwise overfit on the number of possible target positions. To increase the pool of one-target measurements for the RAI presence detection and localization task, the one-target measurements were augmented by multiplying the RDIs from both antennas by the same complex value with an amplitude close to one and a random phase. This operation changes the input values but should not change the estimated ranges and angles, aside from a negligible range error on the order of half the wavelength.
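The two augmentation ideas described above, a common random complex factor applied to both antennas and the superposition of one-target RDIs, can be sketched as follows. This is an illustrative sketch only: the amplitude jitter of ±5% and the function names are assumptions, and the translation and rotation steps mentioned in the text are omitted.

```python
import numpy as np


def random_global_phase(rdis, rng, amp_jitter=0.05):
    """Multiply the RDIs of all antennas by one shared complex factor.

    rdis has shape (antennas, range_bins, doppler_bins). Because the factor
    is shared, the phase difference between antennas (which carries the
    angle information) is preserved. amp_jitter is an assumed value.
    """
    amplitude = 1.0 + rng.uniform(-amp_jitter, amp_jitter)
    phase = rng.uniform(0.0, 2.0 * np.pi)
    return rdis * (amplitude * np.exp(1j * phase))


def superimpose(one_target_rdis):
    """Synthesize a multi-target RDI by summing raw one-target RDIs."""
    return np.sum(one_target_rdis, axis=0)


rng = np.random.default_rng(0)
shape = (2, 128, 32)  # two antennas, 128 x 32 RDI
target_a = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
target_b = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

augmented_a = random_global_phase(target_a, rng)
two_targets = superimpose([augmented_a, random_global_phase(target_b, rng)])
```

Since the raw (complex) RDIs are linear in the received signal, summing one-target measurements is a reasonable approximation of a multi-target scene, although it cannot reproduce mutual occlusion between targets.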
5.2 Loss Function
Given a set of training pairs $(I_i, P_i)$, i.e., raw RDIs $I_i$ and the corresponding ground-truth processed RDIs or RAIs $P_i$, the training procedure estimates the parameters $W$ of the network such that the model generalizes to reconstruct accurate RDIs or RAIs. This is achieved by minimizing the loss between the true RDIs/RAIs $P_i$ and those generated by $h(I_i; W)$. Here, we use a weighted combination of focal loss [31] and hinge loss, as given in Eq. (13), as the loss function to train the proposed models:
$$\begin{aligned} HL(p) &= 1 - y\,(2p - 1) \\ FL(p_t) &= -(1 - p_t)^{\gamma} \log(p_t) \\ L(p_t) &= \alpha \left( FL(p_t) + \eta\, HL(p) \right) \end{aligned} \tag{13}$$
The variables $y \in \{\pm 1\}$, $p \in [0, 1]$, and $p_t \in [0, 1]$ specify the class label, the estimated probability for the class with label $y = 1$, and the probability that a pixel was correctly classified, respectively, with $p_t$ defined in Eq. (14) as
$$p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases} \tag{14}$$
The parameters γ, η, and α influence the shape of the focal loss, the weight of the hinge loss, and the class weighting, respectively. For training the deep residual U-Net model for reconstructing the processed range-Doppler image, these parameters were set to γ = 2 and α = 0.25, and η was increased step-wise to η = 0.15. For training the deep complex U-Net model for reconstructing the processed range-angle image, the same parameters were used except for η = 0. The chosen parameters led to the best learned model in terms of F1 score for both models. The reason for using the focal loss is the class imbalance in the training data due to the nature of the target setup shown in Fig. 2a. In most cases, the frequency of pixels labeled as "target absent" is much higher than the frequency of those labeled as "target present". Additionally, the class frequencies may vary widely between single training samples, especially as the training set contains RDIs with different numbers of targets. The focal loss places a higher weight on the cross-entropy loss for misclassified pixels and a much lower weight on well-classified ones. Thus, pixels belonging to the rarer class are generally weighted higher. The hinge loss is added to the focal loss with a factor η to force the network to make clearer classification decisions. The value for η is chosen in such a way that the focal loss dominates the training for the first few epochs before the hinge loss becomes relevant.
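The combined loss of Eqs. (13) and (14) can be written compactly in NumPy. The sketch below is an illustration rather than the training code used in this work: it assumes the standard focal loss sign convention of [31] (a negative log-likelihood weighted by (1 − p_t)^γ), averages over all pixels, and clips probabilities with a small `eps` to keep the logarithm finite. The function name and the `eps` value are assumptions.

```python
import numpy as np


def combined_loss(p, y, gamma=2.0, alpha=0.25, eta=0.15, eps=1e-7):
    """Mean of alpha * (focal loss + eta * hinge loss) over all pixels.

    p: predicted probability of the class y = +1, values in [0, 1]
    y: labels in {-1, +1}
    eps: assumed clipping constant, only used to avoid log(0)
    """
    p = np.clip(p, eps, 1.0 - eps)
    p_t = np.where(y == 1, p, 1.0 - p)            # Eq. (14)
    focal = -((1.0 - p_t) ** gamma) * np.log(p_t)
    hinge = 1.0 - y * (2.0 * p - 1.0)
    return float(np.mean(alpha * (focal + eta * hinge)))


labels = np.array([1, -1, -1, -1])
confident_correct = np.array([0.9, 0.1, 0.1, 0.1])
confident_wrong = np.array([0.1, 0.9, 0.9, 0.9])
print(combined_loss(confident_correct, labels))
print(combined_loss(confident_wrong, labels))
```

Setting `eta=0.0` reduces the expression to the pure focal loss, which corresponds to the configuration reported for the complex U-Net; the step-wise increase of η during training would be handled by the training loop and is not shown here.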
5.3 Design Consideration
The weights of the network are initialized with a Xavier uniform initializer, which draws samples from a uniform distribution within [−limit, limit], where limit is the square root of the ratio of six to the total number of input and output units in the weight tensor, i.e., $\text{limit} = \sqrt{6 / (n_\text{in} + n_\text{out})}$. The respective biases are initialized to zero. For backpropagation, the Adam [32] optimizer is used with the default learning rate (alpha) of 0.001; the exponential decay rate is set to 0.9 for the first moment (beta1) and to 0.999 for the second moment (beta2). The epsilon that guards against division by zero is set to $10^{-7}$.
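The initialization described above can be reproduced in a few lines of NumPy. The sketch assumes the common convention that, for a convolution kernel, the number of input and output units is the receptive field size times the number of input and output channels, respectively. The kernel shape and the dictionary of optimizer settings are purely illustrative.

```python
import numpy as np


def xavier_uniform(shape, fan_in, fan_out, rng):
    """Draw weights uniformly from [-limit, limit], limit = sqrt(6 / (fan_in + fan_out))."""
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=shape), limit


rng = np.random.default_rng(0)

# Hypothetical 3 x 3 kernel mapping 16 input channels to 32 output channels
kernel_shape = (3, 3, 16, 32)
receptive_field = kernel_shape[0] * kernel_shape[1]
weights, limit = xavier_uniform(
    kernel_shape,
    fan_in=receptive_field * kernel_shape[2],
    fan_out=receptive_field * kernel_shape[3],
    rng=rng,
)
biases = np.zeros(kernel_shape[3])  # biases start at zero

# Adam hyperparameters as stated in the text
adam_config = {"learning_rate": 1e-3, "beta_1": 0.9, "beta_2": 0.999, "epsilon": 1e-7}
print(f"limit = {limit:.4f}")
```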
6 Results and Discussions
The most common way to evaluate detection-clustering performance is through the radar receiver operating characteristics (ROC) and the corresponding area under the curve (AUC). In target detection problems such as the one presented, classification accuracy can be a misleading metric since it does not fully capture the model performance. Precision matters when the cost of false positives, i.e., ghost targets or multi-path reflections, is high. Recall matters when the cost of false negatives, i.e., target occlusions, is substantial. The F1 score is an overall measure of a model's performance that combines precision and recall as their harmonic mean. A good F1 score, close to 1, means few false positives (ghost target detections) and few false negatives (target occlusions), thus indicating correct target detections without being disturbed by false alarms. A total of 2000 test images are used for evaluating the performance of our proposed approach. The test set consists of raw RDIs with one to four human targets from two receive antennas, where the data was collected in different room configurations with humans performing several regular activities. In our experiments, we observed that using the RDIs after MTI as inputs to the neural network (NN) allows the network to generalize better across different target scenes. Without removing the static targets from the input RDIs, the network appears prone to overfitting on the target scene used to create the training data. However, networks trained on RDIs with static targets removed, as presented in this contribution, do not suffer from such generalization issues.
In order to evaluate the probability of detection and the probability of false alarm for the NN-based and the traditional signal processing approach, the respective processed RDI outputs are compared to the labeled data. Due to the difficulties in creating the labeled data, it is likely that only parts of each human are classified as "target present" in the RDI/RAI. Therefore, when defining missed detections and false alarms, small positional errors and variations in the cluster size should be discounted. We did this by computing the center of mass of each cluster in the labeled data and the processed RDIs. A target is counted as correctly detected only if the distance between the cluster centers of mass of the processed and the labeled RDIs is smaller than 20 cm in range and 1.5 m/s in velocity. Additionally, we enforce that each cluster in the network output can only be assigned to exactly one cluster in the labeled data and vice versa.
Table 2 presents the performance of the proposed approach in terms of F1 score in comparison to the traditional processing chain. Results for the proposed approach are shown for a depth-three network (NN_d3), as shown in Fig. 5, and for a deeper network with five residual blocks (NN_d5) in both encoder and decoder. For the depth-five network, two more residual blocks were added in both the encoder and the decoder compared to the structure displayed in Fig. 5. The input to the bridge block between the encoder and decoder path then has the dimension 4 × 1 × 125. The proposed approach gives a much better detection performance, with an F1 score of 0.89, than the traditional processing pipeline, with an F1 score of 0.71. The deeper network (NN_d5), with around 690 thousand trainable parameters, shows some further improvement in detection performance with an F1 score of 0.91. Figure 8 presents the ROC curve of the proposed U-net architecture with depth three and five in comparison with traditional processing. The ROC curves for the depth-three and the depth-five networks are obtained by varying a hard threshold parameter, which describes the minimum pixel value for a point to be classified as a target in the NN output.
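The matching rule described above (cluster centers of mass, a gate of 20 cm in range and 1.5 m/s in velocity, and a one-to-one assignment) and the resulting F1 score can be sketched as follows. The text does not specify how the one-to-one assignment is resolved, so the greedy nearest-first assignment used here is an assumption, as are the function name and the example coordinates.

```python
import numpy as np


def match_and_score(pred_centers, true_centers, max_dr=0.20, max_dv=1.5):
    """Match predicted to labeled cluster centers one-to-one and compute F1.

    Centers are (range in m, velocity in m/s). A pair is admissible only if
    it lies inside both gates; admissible pairs are assigned greedily,
    closest first (an assumed strategy).
    """
    pred = np.asarray(pred_centers, dtype=float).reshape(-1, 2)
    true = np.asarray(true_centers, dtype=float).reshape(-1, 2)
    candidates = []
    for i, (r_p, v_p) in enumerate(pred):
        for j, (r_t, v_t) in enumerate(true):
            dr, dv = abs(r_p - r_t), abs(v_p - v_t)
            if dr < max_dr and dv < max_dv:
                # Gate-normalized distance, used only to order candidates
                candidates.append((float(np.hypot(dr / max_dr, dv / max_dv)), i, j))
    used_pred, used_true, tp = set(), set(), 0
    for _, i, j in sorted(candidates):
        if i not in used_pred and j not in used_true:
            used_pred.add(i)
            used_true.add(j)
            tp += 1
    fp = len(pred) - tp  # unmatched detections: false alarms
    fn = len(true) - tp  # unmatched labels: missed detections
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return tp, fp, fn, f1


labeled = [(1.0, 0.5), (2.5, -1.0)]
detected = [(1.05, 0.4), (2.45, -1.2), (4.0, 0.0)]  # last entry: a ghost target
print(match_and_score(detected, labeled))
```

In this example, both labeled targets are matched and the third detection counts as a false alarm, giving a precision of 2/3, a recall of 1, and an F1 score of 0.8.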
In the ROC curve representing the traditional processing chain, the scaling factor of the CFAR detection method was varied. The curves are extrapolated for detection probabilities close to one. As depicted in the ROC curve, the proposed residual U-net architecture provides a much better AUC than the traditional processing pipeline. Similarly, Fig. 9 shows a better AUC for the complex U-net than for the traditional processing. Figure 10a presents the raw RDI, Fig. 10b the detected RDI using traditional approaches, and Fig. 10c the detected RDI using the proposed deep residual U-Net approach for a synthetic four-target measurement. Originally, the
Table 2 Comparison of the detection performance of the traditional pipeline with the proposed U-net architecture with a depth of 3 and depth of 5 for localization in the range-Doppler image
Approach                  Description            F1-score   Model size
Traditional               OS-CFAR with DBSCAN    0.71       –
Proposed U-net depth 3    Proposed loss          0.89       616 kB
Proposed U-net depth 5    Proposed loss          0.91       2.8 MB
[Figure 8: ROC curve. Axes: pD vs. pFA. Curves: NN_d5, NN_d3, traditional.]
Fig. 8 Radar receiver operating characteristics (ROC) comparison between the proposed residual deep U-net and the traditional signal processing approach
[Figure 9: ROC curve. Axes: probability of detection vs. probability of false alarm.]
Fig. 9 Radar receiver operating characteristics (ROC) comparison between the proposed complex deep U-net and the traditional signal processing approach. Dashed parts indicate extrapolation
NN outputs classification probabilities between zero and one for each pixel. To get the shown output, we used a hard threshold of 0.5 on the NN output, so that any pixels with corresponding values of 0.5 and higher are classified as belonging to a target. With the traditional approach, one target, at around 2 m distance, and −2 m/s in velocity, was mistakenly split into two separate clusters. Additionally, the two targets between 2 and 3 m were completely missed.
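The hard-threshold step, and the threshold sweep that underlies an ROC curve for the network output, can be sketched as below. Note that this is a simplified pixel-level illustration with hypothetical function names; the evaluation in this chapter counts detections and false alarms per cluster rather than per pixel.

```python
import numpy as np


def threshold_map(prob_map, threshold=0.5):
    """Pixels with probability >= threshold are classified as target."""
    return prob_map >= threshold


def pixel_roc(prob_map, label_mask, thresholds):
    """Pixel-level (pFA, pD) pairs for a sweep of hard thresholds."""
    label_mask = label_mask.astype(bool)
    n_target = max(int(label_mask.sum()), 1)
    n_clutter = max(int((~label_mask).sum()), 1)
    p_fa, p_d = [], []
    for t in thresholds:
        detected = threshold_map(prob_map, t)
        p_d.append((detected & label_mask).sum() / n_target)
        p_fa.append((detected & ~label_mask).sum() / n_clutter)
    return np.array(p_fa), np.array(p_d)


# Toy 2 x 2 "network output" and its label mask
prob_map = np.array([[0.9, 0.2], [0.6, 0.1]])
label_mask = np.array([[1, 0], [0, 0]])
p_fa, p_d = pixel_roc(prob_map, label_mask, thresholds=[0.05, 0.5, 0.95])
print(p_fa, p_d)
```

Lowering the threshold moves the operating point toward higher probability of detection at the cost of more false alarms, which is the trade-off traced out by the ROC curves.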
[Figure 10 panels: (a) Raw RDI; (b) Processed RDI, traditional approach; (c) Processed RDI, proposed approach. Axes: range in m vs. velocity in m/s.]
Fig. 10 a Raw RDI image with four human targets, b processed RDI using the traditional approach wherein one target is split and two targets are occluded, c processed RDI using proposed approach wherein all targets are detected accurately
[Figure 11 panels: (a) Raw RDI; (b) Processed RDI, traditional approach; (c) Processed RDI, proposed approach. Axes: range in m vs. velocity in m/s.]
Fig. 11 a Raw RDI image with four human targets, b processed RDI using the traditional approach wherein one target is occluded, c processed RDI using proposed approach wherein all targets are detected accurately
[Figure 12 panels: (a) Raw RDI; (b) Processed RDI, traditional approach; (c) Processed RDI, proposed approach. Axes: range in m vs. velocity in m/s.]
Fig. 12 a Raw RDI image with four human targets, b processed RDI using the traditional approach wherein two targets are merged into one by the DBSCAN clustering algorithm in addition to two ghost targets, c processed RDI using proposed approach wherein all targets are detected accurately and with proper clustering
Figure 11a–c illustrates the target occlusion problem on synthetic data for the conventional processing chain. While the traditional approach missed one target at around 3 m distance, the proposed approach in Fig. 11c reliably detects all the targets. Figures 12a–c and 13a–c show two further scenarios the traditional approach struggles with.
[Figure 13 panels: (a) Raw RDI; (b) Processed RDI, traditional approach; (c) Processed RDI, proposed approach. Axes: range in m vs. velocity in m/s.]
Fig. 13 a Raw RDI image with four human targets, b processed RDI using the traditional approach wherein a ghost target appears, c processed RDI using proposed approach wherein all targets are detected accurately
In Fig. 12b, the two detected targets around the one-meter mark are too close together for the DBSCAN clustering algorithm to separate them into two distinct clusters. Therefore, these two clusters merge, and one target is missed. Fig. 13b showcases the ghost target problem. Here, one target at around 4 m distance from the radar sensor was wrongly detected. In both cases, the U-Net-based approach correctly displays all the distinct targets, as seen in Figs. 12c and 13c. However, while our proposed approach outperforms the traditional processing chain, missed detections and false alarms may still occur in similar scenarios. From our experiments, we have observed that the proposed approach excels in discarding multi-path reflections and ghost targets caused by reflections from static objects in the scene and does well in preventing the splitting or merging of targets, but does not show clear improvements for occluded targets when many humans are in front of each other. The cause lies in how most of the training data was created, where one-target measurements were superimposed in order to synthesize the multi-target RDIs. We also noticed that the loss function plays a crucial role in achieving excellent detection results. While the current loss function deals well with the class imbalance problem and accelerates training for more important pixels in the RDI, it could be improved by penalizing merged or split targets more severely and by allowing small positional errors in the clusters without increasing the loss.
The evaluation of the proposed method for target detection and localization in terms of range and angle is again done via ROC curves and F1 scores. We use the same set of test measurements as described earlier, but with the complex RDIs from two receiving antennas.
For the evaluation of the range-angle output maps of the neural network, a simple clustering algorithm is applied to these output images. In our case, we use DBSCAN with small values for the minimum number of neighbors and the maximum neighbor distance so that every nearly continuous cluster is recognized as one target. We then compute the center of mass of each cluster and compare it to the centers of mass computed from our labeled data, as described earlier. Here, a target counts as detected if the distance between the two cluster centers of mass is below a threshold of 7.8° in angle or 37.5 cm in range. Table 3 shows a comparison of F1-scores for the proposed method and a traditional
Table 3 Comparison of the detection performance of the traditional pipeline with the proposed U-net architecture with a depth of 4 for localization in the range-angle image
Approach                 Description            F1-score   Model size
Traditional              OS-CFAR with DBSCAN    0.61       –
Complex U-net depth 3    Focal loss             0.72       529 kB
Complex U-net depth 4    Focal loss             0.77       1.23 MB
signal processing chain. Compared to the F1-score of 0.62 for the classical method, the proposed method shows a clear improvement with an F1-score of 0.77. The ROC curves in Fig. 9 stop at values smaller than one due to how the detections were evaluated. The dashed lines indicate an extrapolation. If the threshold or the CFAR scaling factor is set close to zero, every pixel is detected as a target, and the DBSCAN algorithm then identifies all of them as one single huge cluster. Therefore, a probability of false alarm of 1 is not achievable. In the evaluation, we saw that the proposed method performs, as expected, better for fewer targets, while its performance worsens mainly with rising target density. However, even if the neural network is trained with only one- and two-target measurements, it is still able to correctly identify the positions of three or four targets in many cases. Comparing the F1-scores in Table 3 to those in Table 2, the network appears to have a harder time with the range-angle localization task. There are several reasons for this. First, only two receiving antennas were used, making accurate angle estimation more difficult in general. The second explanation, which likely has a bigger impact, is that the evaluation methods were not the same for both experiments. Specifically, the thresholds on the center-of-mass distance differ, since a difference in velocity is not comparable to a difference in angle. If we increase the angle threshold from 7.8° to 15.6°, we get an F1-score of about 0.9 and 0.68 for the proposed and the traditional method, respectively. Therefore, the proposed method does well in removing ghost targets and estimating the correct target ranges but is not quite as good at accurately estimating the angles. The traditional method does not gain as much from this change since most of its errors are due to ghost targets or missed detections. Figures 14, 15, 16, and 17 show some examples with the input image, the output image from the traditional processing chain, and the output image from the neural network. The input image here is the RDI from one of the antennas, as the complex two-antenna input is hard to visualize. It is mainly there to illustrate the difficulty of the range-angle presence detection and localization. In Fig. 14, two targets were missed by the classical signal processing chain, whereas all four targets were detected by the neural network with one additional false alarm. In Fig. 15, all four targets were detected by the network, while one was missed by the traditional approach. In Fig. 16, only one target was correctly identified by the classical approach, with three targets missed and one false alarm included. Here, the neural network only missed one target. In the last example, the network again identified all targets at the correct
[Figure 14 panels: (a) Raw RDI (range in m vs. velocity in m/s); (b) Processed RAI, traditional approach; (c) Processed RAI, proposed approach (range in m vs. angle in °).]
Fig. 14 a Raw RDI image with four human targets, b processed RAI using the traditional approach with two missed detections, c processed RAI using proposed complex U-net approach wherein one target split occurs
[Figure 15 panels: (a) Raw RDI (range in m vs. velocity in m/s); (b) Processed RAI, traditional approach; (c) Processed RAI, proposed approach (range in m vs. angle in °).]
Fig. 15 a Raw RDI image with four human targets, b processed RAI using the traditional approach with one missed detection, c processed RAI using proposed complex U-net approach wherein all targets are detected accurately
[Figure 16 panels: (a) Raw RDI (range in m vs. velocity in m/s); (b) Processed RAI, traditional approach; (c) Processed RAI, proposed approach (range in m vs. angle in °).]
Fig. 16 a Raw RDI image with four human targets, b processed RAI using the traditional approach with one ghost target and three missed detections, c processed RAI using proposed complex U-net approach wherein one target was missed
positions, while the traditional approach has two missed detections and one-target split, resulting in one false alarm.
[Figure 17 panels: (a) Raw RDI (range in m vs. velocity in m/s); (b) Processed RAI, traditional approach; (c) Processed RAI, proposed approach (range in m vs. angle in °).]
Fig. 17 a Raw RDI image with three human targets, b processed RAI using the traditional approach with one-target split and two missed detections, c processed RAI using proposed complex U-net approach wherein all targets are detected accurately
7 Conclusion
The traditional radar signal processing pipeline for detecting targets in either the RDI or the RAI is prone to ghost targets, multi-path reflections from static objects, and target occlusion when detecting humans, especially in indoor environments. Further, parametric clustering algorithms suffer from single-target splits and from multiple targets merging into single clusters. To overcome such artifacts and facilitate accurate human detection, localization, and counting, we proposed in this contribution a deep residual U-Net model and a deep complex U-Net model to generate accurate human detections in the RDI and RAI domain, respectively, in indoor scenarios. We trained the models using a custom loss function, the proposed architectural designs, and a training strategy based on data augmentation to achieve accurate processed RDIs and RAIs. We demonstrated superior detection and clustering results in terms of F1 score and ROC characterization compared to the conventional signal processing approach. As future work, a variational autoencoder generative adversarial network (VAE-GAN) architecture can be deployed to minimize the sensitivity of the U-Net models to variations in the data due to sensor noise and interference.
References
1. EPRI, Occupancy sensors: positive on/off lighting control, in Rep. EPRIBR-100323 (1994)
2. V. Garg, N. Bansal, Smart occupancy sensors to reduce energy consumption. Energy Build. 32, 81–87 (2000)
3. W. Butler, P. Poitevin, J. Bjomholt, Benefits of wide area intrusion detection systems using FMCW radar (2007), pp. 176–182
4. J. Lien, N. Gillian, M. Emre Karagozler, P. Amihood, C. Schwesig, E. Olson, H. Raja, I. Poupyrev, Soli: ubiquitous gesture sensing with millimeter wave radar. ACM Trans. Graph. 35, 1–19 (2016)
5. S. Hazra, A. Santra, Robust gesture recognition using millimetric-wave radar system. IEEE Sens. Lett. PP, 1 (2018)
6. S. Hazra, A. Santra, Short-range radar-based gesture recognition system using 3D CNN with triplet loss. IEEE Access 7, 125623–125633 (2019)
7. M. Arsalan, A. Santra, Character recognition in air-writing based on network of radars for human-machine interface. IEEE Sens. J. PP, 1 (2019)
8. C. Will, P. Vaishnav, A. Chakraborty, A. Santra, Human target detection, tracking, and classification using 24 GHz FMCW radar. IEEE Sens. J. PP, 1 (2019)
9. A. Santra, R. Vagarappan Ulaganathan, T. Finke, Short-range millimetric-wave radar system for occupancy sensing application. IEEE Sens. Lett. PP, 1 (2018)
10. H.L.V. Trees, Detection, Estimation, and Modulation Theory, Part I (Wiley, 2004)
11. A. Santra, I. Nasr, J. Kim, Reinventing radar: the power of 4D sensing. Microw. J. 61, 26–38 (2018)
12. F. Schroff, D. Kalenichenko, J. Philbin, Facenet: a unified embedding for face recognition and clustering (2015), pp. 815–823
13. O.M. Parkhi, A. Vedaldi, A. Zisserman, Deep face recognition, vol. 1 (2015), pp. 41.1–41.12
14. S. Ren, K. He, R. Girshick, J. Sun, Faster r-cnn: towards real-time object detection with region proposal networks. IEEE Trans. Pattern Anal. Mach. Intell. 39, 06 (2015)
15. J. Redmon, S. Divvala, R. Girshick, A. Farhadi, You only look once: unified, real-time object detection (2016), pp. 779–788
16. L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, A.L. Yuille, Deeplab: semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected crfs. IEEE Trans. Pattern Anal. Mach. Intell. PP (2016)
17. O. Ronneberger, P. Fischer, T. Brox, U-net: convolutional networks for biomedical image segmentation (2015)
18. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in Deep Learning (Springer, Singapore, 2020)
19. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition (2016), pp. 770–778
20. Z. Zhang, Q. Liu, Y. Wang, Road extraction by deep residual U-net. IEEE Geosci. Remote Sens. Lett. PP (2017)
21. G. Zhang, H. Li, F. Wenger, Object detection and 3D estimation via an FMCW radar using a fully convolutional network (2019). arXiv preprint arXiv:1902.05394
22. L. Wang, J. Tang, Q. Liao, A study on radar target detection based on deep neural networks. IEEE Sens. Lett. 3(3), 1–4 (2019)
23. M. Stephan, A. Santra, Radar-based human target detection using deep residual U-net for smart home applications, in 18th IEEE International Conference on Machine Learning And Applications (ICMLA) (IEEE, 2019), pp. 175–182
24. J.S. Dramsch, Contributors, Complex-valued neural networks in keras with tensorflow (2019)
25. C. Trabelsi, O. Bilaniuk et al., Deep complex networks (2017). arXiv preprint arXiv:1705.09792
26. L. Xu, J. Li, P. Stoica, Adaptive techniques for MIMO radar, in Fourth IEEE Workshop on Sensor Array and Multichannel Processing, vol. 2006 (IEEE, 2006), pp. 258–262
27. H. Rohling, Radar CFAR thresholding in clutter and multiple target situations. IEEE Trans. Aerosp. Electron. Syst. 19, 608–621 (1983)
28. M. Ester, H.-P. Kriegel, J. Sander, X. Xu, A density-based algorithm for discovering clusters in large spatial databases with noise, in KDD (1996)
29. A. Santra, R. Santhanakumar, K. Jadia, R. Srinivasan, SINR performance of matched illumination signals with dynamic target models (2016)
30. C. Szegedy, V. Vanhoucke, S. Ioffe, J. Shlens, Z. Wojna, Rethinking the inception architecture for computer vision (2016)
31. T.-Y. Lin, P. Goyal, R. Girshick, K. He, P. Dollar, Focal loss for dense object detection (2017), pp. 2999–3007
32. D. Kingma, J. Ba, Adam: a method for stochastic optimization, vol. 12 (2014)
Thresholding Strategies for Deep Learning with Highly Imbalanced Big Data Justin M. Johnson and Taghi M. Khoshgoftaar
Abstract A variety of data-level, algorithm-level, and hybrid methods have been used to address the challenges associated with training predictive models with class-imbalanced data. While many of these techniques have been extended to deep neural network (DNN) models, relatively few studies emphasize the significance of output thresholding. In this chapter, we relate DNN outputs to Bayesian a posteriori probabilities and suggest that the Default threshold of 0.5 is almost never optimal when training data is imbalanced. We simulate a wide range of class imbalance levels using three real-world data sets, i.e. positive class sizes of 0.03–90%, and we compare Default threshold results to two alternative thresholding strategies. The Optimal threshold strategy uses validation data or training data to search for the classification threshold that maximizes the geometric mean. The Prior threshold strategy requires no optimization, and instead sets the classification threshold to be the prior probability of the positive class. Multiple deep architectures are explored and all experiments are repeated 30 times to account for random error. Linear models and visualizations show that the Optimal threshold is strongly correlated with the positive class prior. Confidence intervals show that the Default threshold only performs well when training data is balanced and Optimal thresholds perform significantly better when training data is skewed. Surprisingly, statistical results show that the Prior threshold performs consistently as well as the Optimal threshold across all distributions. The contributions of this chapter are twofold: (1) illustrating the side effects of training deep models with highly imbalanced big data and (2) comparing multiple thresholding strategies for maximizing class-wise performance with imbalanced training data.
J. M. Johnson (B) · T. M. Khoshgoftaar Florida Atlantic University, Boca Raton, FL 33431, USA e-mail: [email protected] T. M. Khoshgoftaar e-mail: [email protected] © The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021 M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_9
1 Introduction Class imbalance exists when the total number of samples from one class, or category, is significantly larger than any other category within the data set. This phenomenon arises in many critical industries, e.g. financial [1], biomedical [2], and environmental [3]. In each of these examples, the positive class of interest is the smaller class, i.e. the minority group, and there is an abundance of less-interesting negative samples. In this study, we focus specifically on binary classification problems that contain a positive and negative class. The concepts presented can be extended to the multi-class problem, however, because multi-class problems can be converted into a set of two-class problems through class decomposition [4]. Imbalanced data sets have been shown to degrade the performance of classification models, often causing models to over-predict the majority group. As a result, instances belonging to the minority group are incorrectly classified as negative samples and positive class performance suffers. To make matters worse, popular evaluation metrics like accuracy are liable to mislead analysts with high scores that incorrectly indicate good prediction performance. For example, given a binary data set with a positive class size of just 1%, a simple learner that always outputs the negative class will score 99% accuracy. The extent of this performance degradation depends on problem complexity, data set size, and the level of class imbalance [5]. In this study, we use deep neural network (DNN) models to make predictions with complex data sets that are characterized by both big data and high class imbalance. Generally, DNN model performance degrades as the level of class imbalance increases and the relative size of the positive class decreases [6]. We denote the level of class imbalance in a binary data set using the ratio of negative samples to positive samples, i.e. n_neg : n_pos.
For example, the imbalance of a data set with 400 negative samples and 100 positive instances is denoted by 80:20. Equivalently, we sometimes refer to the positive class’s prior probability, e.g. the 80:20 distribution has a positive class prior of 0.2 or 20%. We classify a data set as highly imbalanced when the positive class prior is ≤ 0.01 [7]. When the total number of positive occurrences becomes even more infrequent, we describe the data set as exhibiting rarity. Weiss et al. [8] distinguish between absolute rarity and relative rarity. Absolute rarity occurs when there are very few samples for a given class, regardless of the size of the majority class. Unlike absolute rarity, a relatively rare class can make up a small percentage of a data set and still have many occurrences when the overall data set is very large. For example, given a data set with 10 million records and a relative positive class size of 1%, there are still 100,000 positive cases to use for training machine learning models. Therefore, we can sometimes achieve good performance with relatively rare classes, especially when working with large volumes of data. The experiments in this chapter explore a wide range of class imbalance levels, e.g. 0.03–90%, and include cases of relative rarity. The challenges of working with class-imbalanced data are often compounded by the challenges of big data [9]. Big data commonly refers to data which exceeds the capabilities of standard data storage and processing. These data sets are also defined
using the four Vs: volume, variety, velocity, and veracity [10, 11]. The large volumes of data being collected require highly scalable hardware and efficient analysis tools, often demanding distributed implementations. In addition to adding architecture and network overhead, distributed systems have been shown to exacerbate the negative effects of class-imbalanced data [12]. The variety of big data corresponds to the mostly unstructured, diverse, and inconsistent representations that arise as data is consumed from multiple sources over extended periods of time. Advanced techniques for quickly processing incoming data streams and maintaining appropriate turnaround times are required to keep up with the rate at which data is being generated, i.e. data velocity. Finally, the veracity of big data, i.e. its accuracy and trustworthiness, must be regularly validated to ensure results do not become corrupted. MapReduce [13] and Apache Spark [14] are two popular frameworks that address these big data challenges by operating on partitioned data in parallel. Neural networks can also be trained in a distributed fashion using either data parallelism or model parallelism techniques [15, 16]. The three data sets used in this study include many of these big data characteristics. Data-level and algorithm-level techniques for addressing class imbalance have been studied extensively. Data methods use sampling to alter the distribution of the training data, effectively reducing the level of class imbalance for model training. Random over-sampling (ROS) and random under-sampling (RUS) are the two simplest data-level methods for addressing class imbalance. ROS increases the size of the minority class by randomly copying minority samples, and RUS decreases the size of the majority class by randomly discarding samples from the majority class. From these fundamental data-level methods, many more advanced variants have been developed [17–20]. 
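The two basic sampling methods described above can be sketched with the Python standard library. This is a minimal illustration rather than the authors' implementation: the function names, the toy 99:1 data, and the target class counts are assumptions chosen only to show how ROS and RUS change the positive class prior.

```python
import random


def positive_prior(labels):
    """Fraction of samples whose label is 1 (the positive class)."""
    return sum(labels) / len(labels)


def random_over_sample(samples, labels, n_pos_target, rng):
    """ROS: add randomly chosen copies of positives until there are n_pos_target."""
    pos = [i for i, y in enumerate(labels) if y == 1]
    extra = [rng.choice(pos) for _ in range(max(0, n_pos_target - len(pos)))]
    keep = list(range(len(labels))) + extra
    return [samples[i] for i in keep], [labels[i] for i in keep]


def random_under_sample(samples, labels, n_neg_target, rng):
    """RUS: keep all positives plus a random subset of n_neg_target negatives."""
    pos = [i for i, y in enumerate(labels) if y == 1]
    neg = [i for i, y in enumerate(labels) if y == 0]
    keep = pos + rng.sample(neg, min(n_neg_target, len(neg)))
    return [samples[i] for i in keep], [labels[i] for i in keep]


rng = random.Random(0)
labels = [1] * 10 + [0] * 990        # toy data: 990 negatives, 10 positives (99:1)
samples = list(range(len(labels)))   # stand-in for feature vectors

_, ros_labels = random_over_sample(samples, labels, n_pos_target=990, rng=rng)
_, rus_labels = random_under_sample(samples, labels, n_neg_target=40, rng=rng)

print(positive_prior(labels))      # 0.01
print(positive_prior(ros_labels))  # 0.5
print(positive_prior(rus_labels))  # 0.2
```

In this toy case ROS reaches a balanced 50:50 distribution by growing the data set to 1980 samples, while RUS reaches 80:20 by shrinking it to 50 samples, which is the trade-off between the two methods.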
Algorithm-level methods modify the training process to increase the impact of the minority class and reduce bias toward the majority class. Direct algorithm-level methods modify a machine learning algorithm’s underlying learner, usually by incorporating class costs or weights. Meta-learner methods use a wrapper to convert non-cost-sensitive learners into cost-sensitive learners. Cost-sensitive learning and output thresholding are examples of direct and meta-learner algorithm-level methods, respectively [21]. Output thresholding is the process of changing the decision threshold that is used to assign class labels to a model’s posterior probabilities [22, 23]. Finally, there are a number of hybrid methods that combine two or more data-level methods and algorithm-level methods [7, 24–26]. In this chapter, we explore output thresholding as it relates to deep learning. Deep learning is a subfield of machine learning that uses artificial neural networks (ANNs) with two or more hidden layers to approximate some function f*, where f* can be used to map input data to new representations or make predictions [27]. The ANN, inspired by the biological neural network, is a set of interconnected neurons, or nodes, where connections are weighted and each neuron transforms its input into a single output by applying a nonlinear activation function to the sum of its weighted inputs. In a feedforward network, input data propagates through the network in a forward pass, each hidden layer receiving its input from the previous layer’s output, producing a final output that is dependent on the input data, the choice of activation function, and the weight parameters [28]. Gradient descent optimization
adjusts the network’s weight parameters in order to minimize the loss function, i.e. the error between expected output and actual output. Composing multiple nonlinear transformations creates hierarchical representations of the input data, increasing the level of abstraction through each transformation. The deep learning architecture, i.e. deep neural network (DNN), achieves its power through this composition of increasingly complex abstract representations [27]. Deep learning methods have proven very successful in solving complex problems related to natural language and vision [29]. These recent successes can be attributed to an increased availability of data, improvements in hardware and software [30–34], and various algorithmic breakthroughs that speed up training and improve generalization to new data [35]. Ideally, we would like to leverage the power of these deep models to improve the classification of highly imbalanced big data. While output thresholding has been used with traditional learners to treat class imbalance in both big and non-big data problems, we found there is little work that properly evaluates its use in DNN models. Most commonly, the Default threshold of 0.5 is used to assign class labels to a classifier’s posterior probability estimates. In this chapter, however, we argue that the Default threshold is rarely optimal when neural network models are trained with class-imbalanced data. This intuition is drawn from the fact that neural networks have been shown to estimate Bayesian a posteriori probabilities when trained with sufficient data [36]. In other words, a well-trained DNN model is expected to output a posterior probability estimate for input x that corresponds to y_c(x) (Eq. 1), where factors p(c) and p(x) are prior probabilities and p(x | c) is the conditional probability of observing instance x given class c.
We do not need to compute each factor individually since neural networks do not estimate these factors directly, but we can use Bayes’ theorem and the positive class prior from our training data to better understand the posterior estimates produced by our models. The estimated positive class prior p(c_pos) is the probability that a random sample drawn from the training data belongs to the positive class, and it is equal to the number of positive training samples divided by the total number of training samples. In highly imbalanced problems, e.g. p(c_pos) ≤ 0.01, small positive class priors can significantly decrease posterior estimates and, in some cases, y_c(x) may never exceed 0.5. If y_c(x) ≤ 0.5 for all x, the Default threshold will incorrectly assign all positive samples to the negative class. We can account for this imbalance by identifying Optimal thresholds that balance positive and negative class performance. We build on this concept empirically by exploring two thresholding strategies with three real-world data sets. The Optimal thresholding strategy identifies the threshold that maximizes performance on training or validation data and then uses the optimal value to make predictions on the test set. The Prior thresholding strategy estimates p(c_pos) from the training data and then uses p(c_pos) as the classification threshold for making predictions on the test set.

$$y_c(x) = p(c \mid x) = \frac{p(c) \cdot p(x \mid c)}{p(x)} \tag{1}$$
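The two thresholding strategies described above can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: the function names, the grid of candidate thresholds, the strict `>` decision rule, and the toy scores are all assumptions, and the geometric mean is taken here as the square root of the product of the true positive rate and the true negative rate.

```python
import math


def tpr_tnr(y_true, y_prob, threshold):
    """True positive and true negative rates when predicting 1 if prob > threshold."""
    tp = sum(1 for y, p in zip(y_true, y_prob) if y == 1 and p > threshold)
    tn = sum(1 for y, p in zip(y_true, y_prob) if y == 0 and p <= threshold)
    n_pos = sum(y_true)
    n_neg = len(y_true) - n_pos
    return tp / n_pos, tn / n_neg


def g_mean(y_true, y_prob, threshold):
    """Geometric mean of TPR and TNR at the given threshold."""
    tpr, tnr = tpr_tnr(y_true, y_prob, threshold)
    return math.sqrt(tpr * tnr)


def optimal_threshold(y_true, y_prob, candidates):
    """Candidate with the highest G-Mean; the first candidate wins ties."""
    return max(candidates, key=lambda t: g_mean(y_true, y_prob, t))


def prior_threshold(y_train):
    """Positive class prior: positive training samples / all training samples."""
    return sum(y_train) / len(y_train)


# Toy labels and model scores (made up): 2 positives, 18 negatives, and no
# score above 0.5. The same toy set stands in for training and validation data.
y = [1, 1] + [0] * 18
p = [0.30, 0.15, 0.115, 0.08] + [0.02] * 16

grid = [i / 100 for i in range(1, 100)]   # assumed search grid: 0.01 ... 0.99
t_opt = optimal_threshold(y, p, grid)
t_prior = prior_threshold(y)

print(g_mean(y, p, 0.5))                  # 0.0, every positive is missed
print(t_opt, g_mean(y, p, t_opt))         # 0.12 1.0
print(t_prior, round(g_mean(y, p, t_prior), 3))  # 0.1 0.972
```

On this toy data the Default threshold scores a G-Mean of zero because no score exceeds 0.5, while both alternative thresholds recover the positive samples. The numbers are illustrative only and say nothing about the chapter's experimental results.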
The first two data sets, Medicare Part B [37] and Part D [38], summarize claims that medical providers have submitted to Medicare and include a class label that
indicates whether or not the provider is known to be fraudulent. Medicare is a United States healthcare program that provides affordable health insurance to individuals 65 years and older, and other select individuals with permanent disabilities [39]. The Part B and Part D data sets have 4.6 and 3.6 million observations, respectively, with positive class sizes < 0.04%. They were made publicly available by the Centers for Medicare & Medicaid Services (CMS) in order to increase transparency and reduce fraud. The third data set was first published by the Evolutionary Computation for Big Data and Big Learning (ECBDL) workshop in 2014 and is now publicly available [40]. Samples within the ECBDL’14 data set contain features that describe pairs of amino acids and labels that indicate whether or not amino acid pairs are spatially close in three-dimensional space. The accurate prediction of these protein contact maps enables further inferences on the three-dimensional shape of proteins. We use a subset of the ECBDL’14 data set consisting of 3.5 million instances and maintain the original positive class size of 2%. The ECBDL’14 data set is not as highly imbalanced as the Medicare data sets, so we use data sampling techniques to simulate high class imbalance. The primary contribution of this study is in providing a unique and thorough analysis of treating class imbalance with DNN models and output thresholding. These thresholding techniques are applied to deep feedforward networks and can be extended to other deep learning architectures described in “Advances in Deep Learning” [41]. ROS, RUS, and a hybrid ROS-RUS are used to create over 30 training distributions with positive class priors ranging from 0.03 to 90%. Optimal thresholds are computed for each distribution by maximizing performance on training and validation sets. As expected, linear models reveal a strong relationship between the positive class size and the Optimal threshold.
In another experiment, we compute Optimal thresholds for each training epoch and visualize their stability over time. Classification experiments are repeated 30 times to account for random error, and multiple deep architectures are used to determine if results generalize to deeper models. Performance results, confidence intervals, and figures are used to show that the Default threshold should not be used when training DNN models with class-imbalanced data. Finally, Tukey’s HSD (honestly significant difference) test [42] results show that the Prior threshold performs as well as the Optimal threshold on average. The remainder of this chapter is outlined as follows. Section 2 discusses several methods for training deep models with class-imbalanced data and other related works that have used the Medicare and ECBDL’14 data sets. In Sects. 3 and 4, we describe the data sets used in this study and present our experiment design, respectively. Results are discussed in Sect. 5, and Sect. 6 closes with suggestions for future work.
2 Related Work This section begins by summarizing other works that address class imbalance with DNN models. We then introduce works related to Medicare fraud prediction and contact map prediction.
2.1 Deep Learning with Class-Imbalanced Data The effects of class imbalance on backpropagation and optimization were studied by Anand et al. [43] using shallow neural networks. During optimization, the network loss is dominated by the majority class, and the minority group has very little influence on the total loss and network weight updates. This tends to reduce the error of the majority group very quickly during early iterations while consequently increasing the error of the minority group. In recent studies, we explored a variety of data-level, algorithm-level, and hybrid methods for addressing this phenomenon with deep neural networks [44, 45]. Several authors explored data sampling methods [6, 46, 47] and found ROS to outperform RUS and baseline models. Others employed cost-sensitive loss functions or proposed new loss functions that reduce the bias toward the majority class [48–51]. Some of the best results were achieved by more complex hybrid methods that leverage deep feature learning and custom loss functions [52–54]. Output thresholding, which is commonly used with traditional learners to maximize class performance, has received very little attention in related deep learning works. Lin et al. [51] proposed the Focal Loss function for addressing the severe class imbalance found in object detection problems. While their study is not specifically about thresholding, they do disclose using a threshold of 0.05 to speed up inference. Dong et al. [54] also present a loss function for addressing class imbalance, i.e. the Class Rectification Loss. They compare their proposed loss function to a number of alternative methods, including thresholding. Results from Dong et al. show that thresholding outperforms ROS, RUS, cost-sensitive learning, and other baseline models on the imbalanced X-Domain [55] image data set. 
These studies were not intended to showcase thresholding, yet their results clearly indicate that thresholding plays an important role in classifying imbalanced data with deep models. To the best of our knowledge, Buda et al. [6] were the only authors to explicitly isolate the thresholding method and study its ability to improve the classification of imbalanced data with deep models. ROS and RUS were used to create training distributions with varying levels of class imbalance from the MNIST [56] and CIFAR10 [57] benchmarks, and the authors evaluated minority class sizes between 0.02 and 50%. Thresholding was achieved by dividing CNN outputs by prior class probabilities, and the accuracy performance metric was used to show how thresholding improves class-wise performance in nearly all cases. In addition to outperforming ROS, RUS, and the baseline CNN, the authors show that combining thresholding with ROS performs exceptionally well and outperforms plain ROS. We expand on the work by Buda et al. by incorporating statistical analysis and complementary performance metrics, e.g. geometric mean (G-Mean), true positive rate (TPR), and true negative rate (TNR). In addition, we provide a unique analysis that compares the Optimal decision threshold to the positive class size of training distributions across a wide range of class imbalance levels.
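The prior-division scheme attributed to Buda et al. above can be related to a plain threshold on the positive-class output. The derivation below is a sketch under two assumptions that the text does not state: the problem is binary, and the two network outputs sum to one (as with a sigmoid or a two-way softmax). The symbols p, π, TP, FN, TN, and FP are introduced here for illustration, and the metric definitions are the standard ones.

```latex
% p = network output for the positive class, \pi = positive class prior, 0 < \pi < 1
\begin{align*}
\frac{p}{\pi} > \frac{1-p}{1-\pi}
  &\iff p\,(1-\pi) > (1-p)\,\pi \\
  &\iff p - p\pi > \pi - p\pi \\
  &\iff p > \pi
\end{align*}
% Complementary performance metrics computed from confusion matrix counts
\begin{align*}
\mathrm{TPR} &= \frac{TP}{TP + FN}, &
\mathrm{TNR} &= \frac{TN}{TN + FP}, &
\text{G-Mean} &= \sqrt{\mathrm{TPR} \cdot \mathrm{TNR}}
\end{align*}
```

Under these assumptions, dividing both outputs by their class priors and picking the larger score gives the same labels as thresholding the positive-class output at the positive class prior.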
2.2 Medicare Fraud Detection The big data, big value, and high class imbalance inherent in Medicare fraud prediction make it an excellent candidate for evaluating methods designed to address class imbalance. Bauder and Khoshgoftaar [58] use a subset of the 2012–2013 Medicare Part B data, i.e. Florida claims only, to model expected amounts paid to providers for services rendered to patients. In another study, Bauder and Khoshgoftaar [59] proposed an outlier detection method that uses Bayesian inference to identify outliers, and successfully validated their model using claims data of a known Florida provider that was under criminal investigation for excessive billing. This experiment used a subset of 2012–2014 Medicare Part B data that included dermatology and optometry claims from Florida office clinics. Another paper by Bauder et al. [60] uses a Naive Bayes classifier to predict provider specialty types, and then flag providers that are practicing outside their expected specialty type as fraudulent. Results show that specialties with unique billing procedures, e.g. audiologist or chiropractic, are able to be classified with high precision and recall. Herland et al. [61] expanded on the work from [60] by incorporating 2014 Medicare Part B data and real-world fraud labels defined by the List of Excluded Individuals and Entities (LEIE) [62] data set. The authors find that grouping similar specialty types, e.g. Ophthalmology and Optometry, improves overall performance. Bauder and Khoshgoftaar [63] merge the 2012–2015 Medicare Part B data sets, label fraudulent providers using the LEIE data set, and compare multiple traditional machine learning classifiers. Class imbalance is addressed with RUS, and various class distributions are generated to identify the optimal imbalance ratio for training. 
The C4.5 decision tree and logistic regression (LR) learners significantly outperform the support vector machine (SVM), and the 80:20 class distribution is shown to outperform 50:50, 65:35, and 75:25 distributions. In summary, these studies show that Medicare Part B and Part D claims data contains sufficient variability to detect fraudulent providers and that the LEIE data set can be reliably used for ground truth fraud labels. The Medicare experiments in our study leverage data sets curated by Herland et al. [64]. In one study, Herland et al. [64] used Medicare Part B, Part D, and DMEPOS claims data from the years 2012–2016. Cross-validation and ROC AUC scores are used to compare LR, Random Forest (RF), and Gradient Boosted Tree (GBT) learners. Results show that Part B data sets score significantly better on ROC AUC than the Part D data set, and the LR learner outperforms the GBT and RF learners with a max AUC of 0.816. In a second paper, Herland et al. [65] used these same Medicare data sets to study the effect of class rarity with LR, RF, and GBT learners. In this study, the authors create an absolutely rare positive class by using subsets of the positive class to form new training sets. They reduced the positive class size to 1000, 400, 200, and 100, and then used RUS methods to treat imbalance and compare AUC scores. Results show that smaller positive class counts degrade model performance, and the LR learner with an RUS distribution of 90:10 performs best. Several other research groups have taken interest in detecting Medicare fraud using the CMS Medicare and LEIE data sets. Feldman and Chawla [66] looked for
anomalies in the relationship between medical school training and the procedures that physicians perform in practice by linking 2012 Medicare Part B data with provider medical school data obtained through the CMS physician compare data set [67]. Significant procedures for each school were used to evaluate school similarities and present a geographical analysis of procedure charges and payment distributions. Ko et al. [68] used the 2012 CMS data to analyze the variability of service utilization and payments. Ko et al. found that the number of patient visits is strongly correlated with Medicare reimbursement, and concluded that there is a possible 9% savings within the field of Urology alone. Chandola et al. [69] used claims data and fraud labels from the Texas Office of Inspector General’s exclusion database to detect anomalies. The authors confirm the importance of including provider specialty types in fraud detection, showing that the inclusion of specialty attributes increases AUC scores from 0.716 to 0.814. Branting et al. [70] propose a graph-based method for estimating healthcare fraud risk within the 2012–2014 CMS and LEIE data sets. The authors leverage the NPPES [71] registry to look up providers that are missing from the LEIE database, increasing their total fraudulent provider count to 12,000. Branting et al. combine these fraudulent providers with a subset of 12,000 nonfraudulent providers and employ a J48 decision tree learner to classify fraud with a mean AUC of 0.96.
2.3 ECBDL’14 Contact Map Prediction The Community-Wide Experiment on the Critical Assessment of Techniques for Protein Structure Prediction (CASP) has been publicly assessing teams’ abilities to predict protein structure since 1994 [72]. A total of 13 CASP experiments have been held thus far, and the CASP14 conference is scheduled for December of 2020. CASP notes that substantial progress has been made in the area of residue–residue contact prediction over recent years and descriptions of results are available in their Proceedings [72]. Triguero et al. [73] won the ECBDL’14 competition by employing random oversampling, evolutionary feature selection, and an RF learner (ROSEFW-RF) within the MapReduce framework. In their paper, the authors explore a range of hyperparameters and achieve their best results with an over-sampling rate of 170% and a subset of 90 features. We use the results from Triguero et al., i.e. a balanced TPR and TNR of 0.730, to evaluate the DNN thresholding results obtained in our ECBDL’14 experiments. Since the competition, other groups have used ECBDL’14 data to evaluate methods for treating class imbalance and big data. Fernández et al. [74] compared the performance of ROS, RUS, and SMOTE using subsets of ECBDL’14 containing 0.6 and 12 million instances and 90 features. Apache Spark and Hadoop frameworks were used to distribute RF and decision tree models across partition sizes of 1, 8, 16, 32, and 64. Results from Fernández et al. show that ROS and RUS perform better than SMOTE, and ROS tends to perform better as the number of partitions increases.
The best performing learner scored a G-Mean of 0.706 using the subset of 12 million instances and the ROS strategy, suggesting that the subset of 0.6 million instances is not representative enough. Río et al. [75] also used the Hadoop framework to explore RF learner performance with ROS and RUS methods for addressing class imbalance. They too found that ROS outperforms RUS, and suggested that this is due to the already underrepresented minority class being split across many partitions. Río et al. achieved their best performance with 64 partitions and an over-sampling rate of 130%. Unfortunately, results show a relatively low TPR (0.705) compared to the reported TNR (0.725) and the winning competition results (0.730). Similar to these related works, we use a subset of ECBDL’14 data (3.5 million instances) to evaluate thresholding strategies for DNN classification. Unlike related works, however, we did not find it necessary to use a distributed training environment. Instead of training individual models in a distributed fashion, we use multiple compute nodes with sufficient resources to train multiple models independently and in parallel. Also, unlike these related works, we do not rely on data sampling techniques to balance TPR and TNR scores. Rather, we show how a simple thresholding technique can be used to optimize class-wise performance regardless of the class imbalance level.
3 Data Sets This section summarizes the data sets that are used to evaluate thresholding techniques for addressing class imbalance with deep neural networks. We begin with two Medicare fraud data sets that were first curated by Herland et al. [64]. We then incorporate a large protein contact map prediction data set that was published by the Evolutionary Computation for Big Data and Big Learning (ECBDL) workshop in 2014 [40]. All three data sets were obtained from publicly available resources and exhibit big data and class imbalance characteristics.
3.1 CMS Medicare Data Two publicly available Medicare fraud data sets are obtained from CMS: (1) Medicare Provider Utilization and Payment Data: Physician and Other Supplier (Part B) [37], and (2) Medicare Provider Utilization and Payment Data: Part D Prescriber (Part D) [38]. The healthcare claims span years 2012–2016 and 2013–2017 for Part B and Part D data sets, respectively. Physicians are identified within each data set by their National Provider Identifier (NPI), a unique 10-digit number that is used to identify healthcare providers [76]. Using the NPI, Herland et al. map fraud labels to the Medicare data from the LEIE repository. The LEIE is maintained by the Office of Inspector General and it lists providers that are prohibited from practicing. Additional attributes of LEIE data include the reason for exclusion and provider reinstatement
Table 1 Description of Part B features [64]
Feature                    Description
Npi                        Unique provider identification number
Provider_type              Medical provider’s specialty (or practice)
Nppes_provider_gender      Provider’s gender
Line_srvc_cnt              Number of procedures/services the provider performed
Bene_unique_cnt            Number of Medicare beneficiaries receiving the service
Bene_day_srvc_cnt          Number of Medicare beneficiaries/per day services
Avg_submitted_chrg_amt     Avg. of the charges that a provider submitted for service
Avg_medicare_payment_amt   Avg. payment made to a provider per claim for service
Exclusion                  Fraud labels from the LEIE data set
dates, where applicable. Providers that have been excluded for fraudulent activity are labeled as fraudulent within the Medicare Part B and Part D data sets. The Part B claims data set describes the services and procedures that healthcare professionals provide to Medicare’s Fee-For-Service beneficiaries. Records within the data set contain various provider-level attributes, e.g. NPI, first and last name, gender, credentials, and provider type. More importantly, records contain specific claims details that describe a provider’s activity within Medicare. Examples of claims data include the procedure performed, the average charge submitted to Medicare, the average amount paid by Medicare, and the place of service. The procedures rendered are encoded using the Healthcare Common Procedures Coding System (HCPCS) [77]. For example, HCPCS codes 99219 and 88346 are used to bill for hospital observation care and antibody evaluation, respectively. Part B data is aggregated by (1) provider NPI, (2) HCPCS code, and (3) place of service. The list of Part B features used for training is provided in Table 1. Similarly, the Part D data set contains a variety of provider-level attributes, e.g. NPI, name, and provider type. More importantly, the Part D data set contains specific details about medications prescribed by Medicare providers. Examples of prescription attributes include drug names, costs, quantities prescribed, and the number of beneficiaries receiving the medication. CMS aggregates the Part D data over (1) prescriber NPI and (2) drug name. Table 2 summarizes all Part D predictors used for classification.
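The NPI-based labeling step described in this section can be pictured as a simple lookup. The sketch below is an assumption-laden illustration, not the curation procedure of Herland et al.: the rows, NPIs, and values are made up, the helper name is hypothetical, and real claims records carry many more fields.

```python
# Hypothetical, heavily simplified claim rows; the NPIs and values are made up.
claims = [
    {"npi": "1000000001", "provider_type": "Cardiology", "line_srvc_cnt": 120},
    {"npi": "1000000002", "provider_type": "Optometry", "line_srvc_cnt": 45},
    {"npi": "1000000003", "provider_type": "Dermatology", "line_srvc_cnt": 300},
]

# Stand-in for the set of NPIs that an exclusion list flags for fraud.
excluded_npis = {"1000000002"}


def add_exclusion_label(rows, excluded):
    """Copy each row and add exclusion=1 if its NPI is in the excluded set, else 0."""
    return [{**row, "exclusion": int(row["npi"] in excluded)} for row in rows]


labeled = add_exclusion_label(claims, excluded_npis)
print([row["exclusion"] for row in labeled])  # [0, 1, 0]
```

Keeping the NPI as a string preserves its full 10 digits, and copying rows instead of editing them in place leaves the source records unchanged.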
3.2 ECBDL’14 Data The ECBDL’14 data set was originally generated to train a predictor for the residue– residue contact prediction track of the 9th Community-Wide Experiment on the Critical Assessment of Techniques for Protein Structure Prediction competition (CASP9) [72]. Protein contact map prediction is a subproblem of protein structure
Table 2 Description of Part D features [64]
Feature                    Description
Npi                        Unique provider identification number
Provider_type              Medical provider’s specialty (or practice)
Bene_count                 Number of distinct beneficiaries receiving a drug
Total_claim_count          Number of drugs administered by a provider
Total_30_day_fill_count    Number of standardized 30-day fills
Total_day_supply           Number of day’s supply
Total_drug_cost            Cost paid for all associated claims
Exclusion                  Fraud labels from the LEIE data set
prediction. This subproblem entails predicting whether any two residues in a protein sequence are spatially close to each other [78]. The three-dimensional structure of a protein can then be inferred from these residue–residue contact map predictions. This is a fundamental problem in medical domains, e.g. drug design, as the three-dimensional structure of a protein determines its function [79]. Each instance of the ECBDL’14 data set is a pair of amino acids represented by 539 continuous attributes, 92 categorical attributes, and a binary label that distinguishes pairs that are in contact. Triguero et al. [73] provide a thorough description of the data set and the methods used to win the competition. Attributes include detailed information and statistics about the protein sequence and the segments connecting the target pair of amino acids. Additional predictors include the length of the protein sequence and a statistical contact propensity between the target pair of amino acid types [73]. The training partition contains 32 million instances and a positive class size of 2%. We refer readers to the original paper by Triguero et al. for a deeper understanding of the amino acid representation. The characteristics of this data set have made it a popular choice for evaluating methods of treating class imbalance and big data.
3.3 Data Preprocessing Both Medicare Part B and Part D data sets were curated and cleaned by Herland et al. [64] and required minimal preprocessing. First, test sets were created by using stratified random sampling to hold out 20% of each Medicare data set. The held-out test sets remain constant throughout all experiments, and when applicable, data sampling is only applied to the training data to create new class imbalance levels. These methods for creating varying levels of class imbalance are described in detail in Sect. 4. Finally, all features were normalized to continuous values in the range [0, 1]. Train and test sets were normalized by fitting a min-max scaler to the fit data and applying the scaler to the train and test sets separately.
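The hold-out and scaling steps described above can be sketched in plain Python. This is a minimal illustration, not the study's implementation: the study used scikit-learn, and the helper names, toy data, and random seeds below are hypothetical. The one property taken from the text is that the scaler statistics come from the fit (training) data only.

```python
import random


def stratified_split(labels, test_fraction=0.2, seed=0):
    """Return (train_idx, test_idx), holding out test_fraction of each class."""
    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idx)
        n_test = round(len(idx) * test_fraction)
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return sorted(train_idx), sorted(test_idx)


def fit_min_max(rows):
    """Learn the per-feature minimum and maximum from the training rows only."""
    columns = list(zip(*rows))
    return [min(c) for c in columns], [max(c) for c in columns]


def apply_min_max(rows, mins, maxs):
    """Scale rows with the training statistics; constant features map to 0."""
    return [
        [(v - lo) / (hi - lo) if hi > lo else 0.0
         for v, lo, hi in zip(row, mins, maxs)]
        for row in rows
    ]


# Toy data (hypothetical): 1000 rows, 2 features, 5% positive class.
rng = random.Random(42)
X = [[rng.uniform(0, 500), rng.uniform(-3, 3)] for _ in range(1000)]
y = [1 if i < 50 else 0 for i in range(1000)]

train_idx, test_idx = stratified_split(y, test_fraction=0.2)
X_train = [X[i] for i in train_idx]
X_test = [X[i] for i in test_idx]

mins, maxs = fit_min_max(X_train)                     # fit on training data only
X_train_scaled = apply_min_max(X_train, mins, maxs)
X_test_scaled = apply_min_max(X_test, mins, maxs)     # reuse training statistics
```

Because the statistics are learned from the training partition, a test value can fall slightly outside [0, 1] when it exceeds the range seen during fitting.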
Table 3 Train and test set sizes
Data set   Partition   Sample count   Feature count   Positive count   Positive (%)
Part B     Train       3,753,896      125             1,206            0.032
Part B     Test        938,474        125             302              0.032
Part D     Train       2,917,961      133             1,028            0.035
Part D     Test        729,491        133             257              0.035
ECBDL’14   Train       2,800,000      200             59,960           2.141
ECBDL’14   Test        700,000        200             15,017           2.145
A subset of 3.5 million records was taken from the ECBDL’14 data with random under-sampling. Categorical features were encoded using one-hot encoding, resulting in a final set of 985 features. Similar to the Medicare data, we normalized all attributes to the range [0, 1] and set aside a 20% test set using stratified random sampling. To improve efficiency, we applied Chi-square feature selection [80] and selected the best 200 features. Preliminary experiments suggested that exceeding 200 features provided limited returns on validation performance. Table 3 lists the sizes of train and test sets, the total number of predictors, and the size of the positive class for each data set. Train sets have between 2.8 and 3.7 million samples and relatively fewer positive instances. The ECBDL’14 data set has a positive class size of 2% and is arguably not considered to be highly imbalanced, i.e. positive class is greater than 1% of the data set. In Sect. 4, we explain how high class imbalance is simulated by under-sampling the positive class. The Medicare data sets, on the other hand, have just between 3 and 4 positive instances for every 10,000 observations and are intrinsically severely imbalanced.
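The encoding and feature-selection steps above can be illustrated with a small standard-library sketch. It assumes the Chi-square score is computed the way scikit-learn's `chi2` scorer does it (observed versus expected per-class feature totals on non-negative features); the text does not say which formulation [80] was used, and the helper names and toy data are hypothetical.

```python
def one_hot(values):
    """Encode a categorical column as 0/1 indicator columns, one per category."""
    categories = sorted(set(values))
    encoded = [[1.0 if v == c else 0.0 for c in categories] for v in values]
    return encoded, categories


def chi2_scores(X, y):
    """Chi-square statistic per feature, for non-negative feature values."""
    n = len(y)
    n_features = len(X[0])
    totals = [sum(row[j] for row in X) for j in range(n_features)]
    scores = [0.0] * n_features
    for cls in sorted(set(y)):
        rows = [row for row, label in zip(X, y) if label == cls]
        class_prob = len(rows) / n
        for j in range(n_features):
            observed = sum(row[j] for row in rows)
            expected = class_prob * totals[j]
            if expected > 0:
                scores[j] += (observed - expected) ** 2 / expected
    return scores


def select_k_best(scores, k):
    """Indices of the k highest-scoring features, in column order."""
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return sorted(ranked[:k])


# Toy example: feature 0 tracks the label, feature 1 is constant.
y = [1, 1, 0, 0, 0, 0]
X = [[1.0, 0.5], [1.0, 0.5], [0.0, 0.5], [0.0, 0.5], [0.0, 0.5], [0.0, 0.5]]
scores = chi2_scores(X, y)
kept = select_k_best(scores, k=1)

encoded, categories = one_hot(["A", "C", "A", "G"])
```

In the study the same idea is applied at scale: 985 encoded features are ranked and the best 200 are kept.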
4 Methods We evaluate thresholding strategies for addressing high class imbalance with deep neural networks across a wide range of class imbalance levels. A stratified random 80–20% split is used to create train and test partitions from each of the three data sets. Training sets are sampled to simulate the various levels of class imbalance, and test sets are held constant for evaluation purposes. All validation and hyperparameter tuning are executed on random partitions of the training data. After configuring hyperparameters, each model is fit to the training data and scored on the test set with 30 repetitions. This repetition accounts for any variation in results that may be caused by random sampling and allows for added statistical analysis. All experiments are performed on a high-performance computing environment running Scientific Linux 7.4 (Nitrogen) [81]. Neural networks are implemented using the Keras [32] open-source deep learning library written in Python with the TensorFlow [30] backend. The specific library implementations used in this study are the
default configurations of Keras 2.1.6-tf and TensorFlow 1.10.0. The scikit-learn package [82] (version 0.21.1) is used for preprocessing data. The remainder of this section describes (1) the DNN architectures used for each data set, (2) the data sampling methods used to vary class imbalance levels, (3) the Optimal and Prior thresholding strategy procedures, and (4) performance evaluation criteria.
4.1 Baseline Models Baseline architectures and hyperparameters were discovered through a random search procedure. For each data set, we set aside a random 10% of the fit data for validation, trained models on the remaining 90%, and then scored them on the validation set. This process was repeated 10 times for each hyperparameter configuration. The number of hidden layers, the number of neurons per layer, and regularization techniques were the primary focus of hyperparameter tuning. Experiments were restricted to deep fully connected models, i.e. neural networks containing two or more hidden layers. We first sought a model with sufficient capacity to learn the training data and then applied regularization techniques to reduce overfitting and improve generalization to validation sets. We used the area under the Receiver Operating Characteristic curve (ROC AUC) [83] performance metric to assess validation results. We prefer the ROC AUC metric for comparing models because it is threshold agnostic. If one model achieves a higher AUC score, then there exists an operating point (threshold) that will also achieve higher class-wise performance. Model validation results led us to select the following hyperparameter configurations. Mini-batch stochastic gradient descent with mini-batch sizes of 256 is used for all three data sets. This is preferred over batch gradient descent because it is computationally expensive to compute the loss over the entire data set, and increasing the number of samples that contribute to the gradient provides less than linear returns [27]. It has also been suggested that smaller batch sizes offer a regularization effect by introducing noise into the learning process [84]. We employ an advanced form of stochastic gradient descent (SGD) that adapts parameter-specific learning rates through training, i.e. the Adam optimizer, as it has been shown to outperform other popular optimizers [85]. 
The default learning rate (lr = 0.001) is used along with default moment estimate decay rates of β1 = 0.9 and β2 = 0.999. The Rectified Linear Unit (ReLU) activation function is used in all hidden layer neurons, and the sigmoid activation function is used at the output layer to estimate posterior probabilities [86]. The non-saturating ReLU activation function has been shown to alleviate the vanishing gradient problem and allow for faster training [35]. Network topologies were defined by first iteratively increasing architecture depth and width while monitoring training and validation performance. For both Medicare data sets, we determined that two hidden layers containing 32 neurons per layer provided sufficient capacity to overfit the model to the training data. For the ECBDL’14 data set, a larger network with four hidden layers containing between 128 and 32
Table 4 Medicare Part B two-layer architecture
Layer type            # of neurons   # of parameters
Input                 125            0
Dense                 32             4032
Batch normalization   32             128
ReLU activation       32             0
Dropout P = 0.5       32             0
Dense                 32             1056
Batch normalization   32             128
ReLU activation       32             0
Dropout P = 0.5       32             0
Dense                 1              33
Sigmoid activation    1              0
neurons each was required to fit the training data. We then explored regularization techniques to eliminate overfitting and improve validation performance. One way to reduce overfitting is to reduce the total number of learnable parameters, i.e. reducing network depth or width. L1 or L2 regularization methods, or weight decay, add parameter penalties to the objective function that constrain the network’s weights to lie within a region that is defined by a coefficient α [27]. Dropout simulates the ensembling of many models by randomly disabling non-output neurons with a probability P ∈ [0, 1] during each iteration, preventing neurons from co-adapting and forcing the model to learn more robust features [87]. Although originally designed to address internal covariate shift and speed up training, batch normalization has also been shown to add regularizing effects to neural networks [88]. Batch normalization is similar to normalizing input data to have a fixed mean and variance, except that it normalizes the inputs to hidden layers across each batch. We found that a combination of dropout and batch normalization gave the best performance for all three data sets. For the Medicare models, we use a dropout rate of P = 0.5, and for the ECBDL’14 data set we use a dropout rate of P = 0.8. Batch normalization is applied before the activation function in each hidden unit. Table 4 describes the two-layer baseline architecture for the Medicare Part B data set. To determine how the number of hidden layers affects performance, we extended this model to four hidden layers following the same pattern, i.e. using 32-neuron layers, batch normalization, ReLU activations, and dropout in each hidden layer. We did not find it necessary to select new hyperparameters for the Medicare Part D data set. Instead, we only changed the size of the input layer to match the total number of features in each respective data set.
The architecture for the ECBDL’14 data set follows this same basic pattern but contains four hidden layers with 128, 128, 64, and 32 neurons in each consecutive layer. With the increased feature count and network width, the ECBDL’14 network contains 54 K tunable parameters and is approximately 10× larger than the two-layer architecture used in Medicare experiments.
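The parameter counts quoted above can be checked with a few lines of arithmetic. The sketch below assumes the counting convention visible in Table 4: each dense layer contributes weights plus biases, and each batch normalization layer contributes four values per unit. The function name is hypothetical.

```python
def count_parameters(n_inputs, hidden_units):
    """Count parameters for a Dense -> BatchNorm -> ReLU -> Dropout stack
    followed by a single sigmoid output unit."""
    total = 0
    fan_in = n_inputs
    for units in hidden_units:
        total += fan_in * units + units   # dense weights + biases
        total += 4 * units                # batch normalization values per unit
        fan_in = units
    total += fan_in + 1                   # output unit: weights + bias
    return total


medicare_part_b = count_parameters(125, [32, 32])        # 5377
ecbdl = count_parameters(200, [128, 128, 64, 32])        # 54017
ratio = ecbdl / medicare_part_b                          # roughly 10
```

The first result matches the sum of the parameter column in Table 4 (4032 + 128 + 1056 + 128 + 33), and the second is consistent with the reported figure of about 54 K parameters.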
4.2 Data Sampling Strategies We use data sampling to alter training distributions and evaluate thresholding strategies across a wide range of class imbalance levels. The ROS method randomly duplicates samples from the minority class until the desired positive class prior is achieved. RUS randomly removes samples from the majority class without replacement until the desired level of imbalance is reached. The hybrid ROS-RUS method first under-samples the majority class without replacement and then over-samples the minority class until classes are balanced. A combination of ROS, RUS, and ROS-RUS is used to create 17 new training distributions from each Medicare data set and 12 new training distributions from the ECBDL’14 data set. Table 5 describes the 34 training distributions that were created from the Medicare Part B and Part D data sets. Of these new distributions, 24 contain low to severe levels of class imbalance and 10 have balanced positive and negative classes. The ROS-RUS-1, ROS-RUS-2, and ROS-RUS-3 distributions use RUS to remove 50, 75, and 90% of the majority class, respectively. They then over-sample the minority class until both classes are balanced 50:50.
Table 5 Description of Medicare distributions
                                   CMS Part B            CMS Part D
Distribution type  Positive prior  n_neg      n_pos      n_neg      n_pos
Baseline           0.0003          3,377,421  1,085      2,916,933  1,028
RUS-1              0.001           773,092    1,085      1,027,052  1,028
RUS-2              0.005           194,202    1,085      204,477    1,028
RUS-3              0.01            107,402    1,085      101,801    1,028
RUS-4              0.20            4,390      1,085      4,084      1,028
RUS-5              0.40            1,620      1,085      1,546      1,028
RUS-6              0.50            1,085      1,085      1,028      1,028
RUS-7              0.60            710        1,085      671        1,028
ROS-1              0.001           3,377,421  3,385      2,916,933  2,920
ROS-2              0.005           3,377,421  16,969     2,916,933  14,659
ROS-3              0.01            3,377,421  33,635     2,916,933  29,401
ROS-4              0.20            3,377,421  844,130    2,916,933  729,263
ROS-5              0.40            3,377,421  2,251,375  2,916,933  1,944,626
ROS-6              0.50            3,377,421  3,377,421  2,916,933  2,916,929
ROS-7              0.60            3,377,421  5,064,780  2,916,933  4,375,404
ROS-RUS-1          0.50            1,688,710  1,688,710  1,458,466  1,458,466
ROS-RUS-2          0.50            844,355    844,355    729,233    729,233
ROS-RUS-3          0.50            337,742    337,742    291,693    291,693
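The three sampling methods used to build the distributions in Table 5 can be sketched on lists of sample indices. This is an illustrative standard-library version with hypothetical function names; the exact sampling rates and implementation used in the study are not given beyond the table, so the toy sizes below are not meant to reproduce its counts.

```python
import random


def rus(neg_idx, pos_idx, target_prior, rng):
    """Random under-sampling: keep a random subset of the majority (negative)
    class, drawn without replacement, so positives make up target_prior."""
    n_neg = round(len(pos_idx) * (1 - target_prior) / target_prior)
    return rng.sample(neg_idx, min(n_neg, len(neg_idx))), list(pos_idx)


def ros(neg_idx, pos_idx, target_prior, rng):
    """Random over-sampling: duplicate minority (positive) samples, drawn with
    replacement, until positives make up target_prior."""
    n_pos = round(len(neg_idx) * target_prior / (1 - target_prior))
    extra = rng.choices(pos_idx, k=max(0, n_pos - len(pos_idx)))
    return list(neg_idx), list(pos_idx) + extra


def ros_rus(neg_idx, pos_idx, remove_fraction, rng):
    """Hybrid: remove a fraction of the majority class, then over-sample the
    minority class until the classes are balanced 50:50."""
    n_keep = round(len(neg_idx) * (1 - remove_fraction))
    kept_neg = rng.sample(neg_idx, n_keep)
    return ros(kept_neg, pos_idx, 0.5, rng)


# Toy population (hypothetical): 10,000 negatives and 10 positives.
rng = random.Random(0)
neg = list(range(10_000))
pos = list(range(10_000, 10_010))

rus_neg, rus_pos = rus(neg, pos, target_prior=0.2, rng=rng)
ros_neg, ros_pos = ros(neg, pos, target_prior=0.2, rng=rng)
hyb_neg, hyb_pos = ros_rus(neg, pos, remove_fraction=0.9, rng=rng)
```

With `remove_fraction=0.9` the hybrid keeps 10% of the negatives, which mirrors how ROS-RUS-3 retains 337,742 of the 3,377,421 Part B negatives.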
input : targets y, probability estimates p
output: optimal threshold
best_thresh ← curr_thresh ← max_gmean ← 0;
delta_thresh ← 0.0005;
while curr_thresh < 1.0 do
    ŷ ← ApplyThreshold(p, curr_thresh);
    tpr, tnr, gmean ← CalcPerformance(y, ŷ);
    if tpr < tnr then
        return best_thresh;
    end
    if gmean > max_gmean then
        max_gmean ← gmean;
        best_thresh ← curr_thresh;
    end
    curr_thresh ← curr_thresh + delta_thresh;
end
return best_thresh;
Algorithm 1: Optimal threshold procedure
The original ECBDL’14 data set has a positive class size of 2% and is not classified as highly imbalanced. Therefore, we first simulate two highly imbalanced distributions by combining the entire majority class with two subsets of the minority class. By randomly under-sampling the minority class, we achieve two new distributions that have positive class sizes of 1% and 0.5%. We create additional distributions with positive class sizes of 5, 10, 20, 30, 40, 50, 60, 70, 80, and 90% by using RUS to reduce the size of the negative class. As a result, we are able to evaluate thresholding strategies on 12 class-imbalanced distributions of ECBDL’14 data and one balanced distribution of ECBDL’14 data. The size of each positive and negative class can be inferred from these strategies and the training data sizes from Table 3.
4.3 Thresholding Strategies The Optimal threshold strategy is used to approximately balance the TPR and TNR. We accomplish this by using a range of decision thresholds to score models on validation or training data and then selecting the threshold which maximizes the G-Mean. We also add a constraint that the Optimal threshold yields a TPR that is greater than the TNR, because we are more interested in detecting positive instances than negative instances. Threshold selection can be modified to optimize any other performance metric, e.g. precision, and the metrics used to optimize thresholds should ultimately be guided by problem requirements. If false positives are very costly, for example, then a threshold that maximizes TNR would be more appropriate. The performance metrics used in this study are explained in Sect. 4.4 and the procedure used to compute these Optimal thresholds is defined in Algorithm 1. Once model training is complete, Algorithm 1 takes ground truth labels and probability estimates from the trained model, iterates over a range of possible threshold values, and returns the threshold that maximizes the G-Mean.
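Algorithm 1 translates almost line for line into Python. The version below is a sketch rather than the authors' code: the function names are hypothetical, scores strictly greater than the threshold are treated as positive (an assumption that matches the Prior threshold convention), and the threshold is computed as a multiple of the step size instead of by repeated addition to avoid accumulating floating-point error.

```python
import math


def apply_threshold(p, threshold):
    """Label a score positive when it is strictly greater than the threshold."""
    return [1 if score > threshold else 0 for score in p]


def calc_performance(y, y_hat):
    """Return (TPR, TNR, G-Mean) for binary labels and predictions."""
    tp = sum(1 for t, h in zip(y, y_hat) if t == 1 and h == 1)
    fn = sum(1 for t, h in zip(y, y_hat) if t == 1 and h == 0)
    tn = sum(1 for t, h in zip(y, y_hat) if t == 0 and h == 0)
    fp = sum(1 for t, h in zip(y, y_hat) if t == 0 and h == 1)
    tpr = tp / (tp + fn) if tp + fn else 0.0
    tnr = tn / (tn + fp) if tn + fp else 0.0
    return tpr, tnr, math.sqrt(tpr * tnr)


def optimal_threshold(y, p, delta_thresh=0.0005):
    """Scan thresholds upward from 0, keep the one with the highest G-Mean,
    and stop as soon as the TPR drops below the TNR."""
    best_thresh = 0.0
    max_gmean = 0.0
    n_steps = round(1.0 / delta_thresh)
    for k in range(n_steps):
        curr_thresh = k * delta_thresh
        tpr, tnr, gmean = calc_performance(y, apply_threshold(p, curr_thresh))
        if tpr < tnr:
            return best_thresh
        if gmean > max_gmean:
            max_gmean, best_thresh = gmean, curr_thresh
    return best_thresh


# Toy scores (hypothetical): 4 positives and 16 negatives.
y = [1] * 4 + [0] * 16
p = [0.0302, 0.0402, 0.0502, 0.0602] + [0.0012] * 12 + [0.0102, 0.0202, 0.0352, 0.0452]
best = optimal_threshold(y, p)
```

On this toy input the search stops at the first threshold that misclassifies a positive sample, because the TPR then falls below the TNR, and it returns the last threshold that improved the G-Mean before that point.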
For Medicare experiments, Optimal decision thresholds are computed during the validation phase. The validation step for each training distribution entails training 10 models and scoring them on random 10% partitions of the fit data, i.e. validation sets. Optimal thresholds are computed on each validation set, averaged, and then the average from each distribution is used to score models on the test set. While this use of validation data should reduce the risk of overfitting, we do not use validation data to compute Optimal thresholds for ECBDL’14 experiments. Instead, ECBDL’14 Optimal thresholds are computed by maximizing performance on fit data. Models are trained for 50 epochs using all fit data, and then Algorithm 1 searches for the threshold that maximizes performance on the training labels and corresponding model predictions. Unlike the Medicare Optimal thresholds, the ECBDL’14 Optimal thresholds can be computed with just one extra pass over the training data and do not require validation partitions. This is beneficial when training data is limited or when the positive class is absolutely rare. The Prior thresholding strategy estimates the positive class prior from the training data, i.e. p(c_pos) from Eq. 1, and uses its value to assign class labels to posterior scores on the test set. Given a training set with p(c_pos) = 0.1, for example, the Prior thresholding strategy will assign all test samples with probability scores > 0.1 to the positive class and those with scores ≤ 0.1 to the negative class. Since the Prior threshold can be calculated from the training data, and no optimization is required, we believe it is a good candidate for preliminary experiments with imbalanced data. Due to time constraints, this method was only explored using ECBDL’14 data.
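The Prior threshold needs only the training labels, as the short sketch below shows using the p(c_pos) = 0.1 example from the text. The test scores and the per-validation-set thresholds are made-up values for illustration, and the function names are hypothetical.

```python
from statistics import mean


def prior_threshold(train_labels):
    """Estimate the positive class prior as the fraction of positive labels."""
    return sum(train_labels) / len(train_labels)


def classify(scores, threshold):
    """Scores strictly above the threshold are positive, all others negative."""
    return [1 if s > threshold else 0 for s in scores]


# Training set with a positive class prior of 0.1.
train_labels = [1] * 10 + [0] * 90
threshold = prior_threshold(train_labels)

# Hypothetical posterior scores produced by a model on test samples.
test_scores = [0.03, 0.10, 0.11, 0.45, 0.95]
predictions = classify(test_scores, threshold)   # [0, 0, 1, 1, 1]

# Medicare-style Optimal threshold: average the thresholds found on each
# validation partition (values here are hypothetical).
validation_thresholds = [0.0102, 0.0098, 0.0110]
averaged_threshold = mean(validation_thresholds)
```

A score exactly equal to the prior is assigned to the negative class, following the "> 0.1" versus "≤ 0.1" rule stated above.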
4.4 Performance Evaluation The confusion matrix (Table 6) is created by comparing predicted labels to ground truth labels, where predicted labels are dependent on model outputs and the decision threshold. From the confusion matrix, we compute the TPR (Eq. 2), TNR (Eq. 3), and G-Mean (Eq. 4) performance metrics. We compare results using the Default, Optimal, and Prior thresholding strategies using 95% confidence intervals. From these confidence intervals, we are able to determine which thresholding strategies perform significantly better than others. We do not consider the ROC AUC metric because it is threshold agnostic, and we do not consider accuracy or error rate because they are misleading when working with imbalanced data.
$$\mathrm{TPR} = \mathrm{Recall} = \frac{TP}{TP + FN} \tag{2}$$
Table 6 Confusion matrix
                     Actual positive       Actual negative
Predicted positive   True positive (TP)    False positive (FP)
Predicted negative   False negative (FN)   True negative (TN)
$$\mathrm{TNR} = \mathrm{Selectivity} = \frac{TN}{TN + FP} \tag{3}$$
$$\text{G-Mean} = \sqrt{\mathrm{TPR} \times \mathrm{TNR}} \tag{4}$$
We also use Tukey’s HSD test (α = 0.05) to estimate the significance of ECBDL’14 results. Tukey’s HSD test is a multiple comparison procedure that determines which method means are statistically different from each other by identifying differences that are greater than the expected standard error. Result sets are assigned to alphabetic groups based on the statistical difference of performance means, e.g. group a performs significantly better than group b.
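The confidence-interval comparison used in this section can be sketched as follows. The text does not state how the 95% intervals were computed, so the normal approximation below is an assumption, as are the function names; treating non-overlapping intervals as a significant difference is one simple reading of the comparison described above.

```python
import math
from statistics import NormalDist, mean, stdev


def confidence_interval(scores, confidence=0.95):
    """Normal-approximation confidence interval for the mean of repeated runs."""
    z = NormalDist().inv_cdf(0.5 + confidence / 2)
    half_width = z * stdev(scores) / math.sqrt(len(scores))
    centre = mean(scores)
    return centre - half_width, centre + half_width


def significantly_greater(interval_a, interval_b):
    """True when interval_a lies entirely above interval_b."""
    return interval_a[0] > interval_b[1]


# Hypothetical G-Mean scores from three repeated runs.
low, high = confidence_interval([0.72, 0.73, 0.74])

# Intervals taken from Table 8 (Medicare Part B, Optimal vs. Default threshold).
baseline_differs = significantly_greater((0.7281, 0.7321), (0.0000, 0.0000))
rus6_differs = significantly_greater((0.7148, 0.7242), (0.7155, 0.7225))
```

For the baseline distribution the Optimal interval sits well above the Default interval, while for the balanced RUS-6 distribution the two intervals overlap and no difference is declared.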
5 Results and Discussion This section presents the DNN thresholding results that were obtained using the Medicare and ECBDL’14 data sets. We begin by illustrating the relationship between Optimal decision thresholds and positive class sizes using confidence intervals (C.I.) and linear models. Next, Default and Optimal thresholds are used to compare G-Mean scores across all Medicare distributions. ECBDL’14 results make similar comparisons and incorporate a third Prior thresholding strategy that proves effective. Finally, a statistical analysis of TPR and TNR scores is used to estimate the significance of each method’s results.
5.1 The Effect of Priors on Optimal Thresholds Classification thresholds are optimized in Medicare Part B and Part D experiments by training models and maximizing performance on a validation set using Algorithm 1. To account for random error and enable statistical analysis, this process is repeated 10 times for each distribution and architecture pair. Results from the two-layer architecture are listed in Table 7. Confidence intervals (α = 0.05) are provided separately for each Medicare data set, and bold-typed intervals indicate those which overlap the Default threshold of 0.5. Medicare results from the two-layer network suggest that the Optimal threshold varies significantly with the positive class prior. For example, optimal thresholds range from 0.0002 to 0.6478 as the positive class prior increases from 0.0003 to 0.6. More specifically, most distributions have Optimal thresholds that are approximately equal to their positive class prior. Following this pattern, we observe that Optimal threshold intervals only overlap the Default threshold when the positive class prior is equal to 0.5. We also observe that the Part B and Part D threshold intervals for respective distributions overlap each other in 17 of 18 cases. This suggests that the
Table 7 Medicare optimal thresholds
                                     Optimal threshold 95% C.I.
Distribution type  Pos. class prior  Medicare Part B     Medicare Part D
Baseline           0.0003            (0.0002, 0.0003)    (0.0002, 0.0004)
RUS-1              0.001             (0.0007, 0.0011)    (0.0005, 0.0009)
RUS-2              0.005             (0.0059, 0.0069)    (0.0049, 0.0056)
RUS-3              0.01              (0.0095, 0.0125)    (0.0107, 0.0130)
RUS-4              0.2               (0.2502, 0.2858)    (0.1998, 0.2516)
RUS-5              0.4               (0.3959, 0.4441)    (0.4030, 0.4665)
RUS-6              0.5               (0.4704, 0.5236)    (0.4690, 0.5326)
RUS-7              0.6               (0.5400, 0.6060)    (0.5495, 0.6478)
ROS-1              0.001             (0.0005, 0.0009)    (0.0005, 0.0008)
ROS-2              0.005             (0.0051, 0.0073)    (0.0051, 0.0064)
ROS-3              0.01              (0.0087, 0.0132)    (0.0112, 0.0139)
ROS-4              0.2               (0.2135, 0.2685)    (0.1958, 0.2613)
ROS-5              0.4               (0.3691, 0.4469)    (0.3409, 0.4197)
ROS-6              0.5               (0.4150, 0.4910)    (0.3795, 0.4925)
ROS-7              0.6               (0.5169, 0.6091)    (0.5189, 0.5707)
ROS-RUS-1          0.5               (0.4554, 0.5146)    (0.3807, 0.4889)
ROS-RUS-2          0.5               (0.4940, 0.5497)    (0.4111, 0.5203)
ROS-RUS-3          0.5               (0.4771, 0.5409)    (0.4119, 0.4774)
Fig. 1 Positive class size versus optimal decision threshold
relationship between the positive class size and the Optimal threshold is both linear and independent of the data set. In Fig. 1, Medicare Optimal threshold results are grouped by architecture type and plotted against the positive class size of the training distribution. Plots are enhanced with horizontal jitter, and linear models are fit to the data using Ordinary Least
Fig. 2 ECBDL training epochs versus optimal thresholds
Squares [89] and 95% confidence bands. For both Medicare data sets and network architectures, there is a strong linear relationship between the positive class size and the Optimal decision threshold. The relationship is strongest for the two-layer networks, with r² ≥ 0.980 and p ≤ 9.73e−145. The four-layer network results share these linear characteristics, albeit weaker with r² ≤ 0.965 and visibly larger confidence bands. We also compute Optimal thresholds after each training epoch using ECBDL’14 data to determine how the Optimal threshold varies during model training. We trained models for 50 epochs and repeated each experiment three times. As illustrated in Fig. 2, the Optimal threshold is relatively stable and consistent throughout training. Similar to Medicare results, the ECBDL’14 Optimal thresholds correspond closely to the positive class prior of the training distribution. This section concludes that the positive class size has a strong linear effect on the Optimal decision threshold. We also found that the Optimal threshold for a given distribution may vary between deep learning architectures. In the next section, we consider the significance of these thresholds by using them to evaluate performance on unseen test data.
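The linear relationship described in this section can be illustrated with an ordinary least squares fit. The sketch below is only an approximation of the analysis: it fits the midpoints of the Medicare Part B intervals reported in Table 7 against the positive class priors listed in Table 5, not the per-run thresholds used in the study, so its statistics will not match the reported r² and p values. The function name is hypothetical.

```python
def ols_fit(x, y):
    """Simple least squares fit y = slope * x + intercept, with r squared."""
    n = len(x)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    sxx = sum((a - mean_x) ** 2 for a in x)
    syy = sum((b - mean_y) ** 2 for b in y)
    sxy = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    r_squared = sxy ** 2 / (sxx * syy)
    return slope, intercept, r_squared


# Positive class priors of the 18 Medicare distributions (Table 5).
priors = [0.0003, 0.001, 0.005, 0.01, 0.2, 0.4, 0.5, 0.6,
          0.001, 0.005, 0.01, 0.2, 0.4, 0.5, 0.6,
          0.5, 0.5, 0.5]

# Medicare Part B Optimal threshold 95% intervals (Table 7).
part_b_intervals = [
    (0.0002, 0.0003), (0.0007, 0.0011), (0.0059, 0.0069), (0.0095, 0.0125),
    (0.2502, 0.2858), (0.3959, 0.4441), (0.4704, 0.5236), (0.5400, 0.6060),
    (0.0005, 0.0009), (0.0051, 0.0073), (0.0087, 0.0132), (0.2135, 0.2685),
    (0.3691, 0.4469), (0.4150, 0.4910), (0.5169, 0.6091),
    (0.4554, 0.5146), (0.4940, 0.5497), (0.4771, 0.5409),
]
midpoints = [(lo + hi) / 2 for lo, hi in part_b_intervals]

slope, intercept, r_squared = ols_fit(priors, midpoints)
```

A slope close to one and an intercept close to zero on these midpoints is what the near-identity between the Optimal threshold and the positive class prior would predict.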
5.2 Medicare Classification Results Optimal classification threshold performance is first evaluated against Default threshold performance using the Medicare Part B and Part D data sets. Threshold results are compared over a range of class imbalance levels, i.e. 0.03–60%, using G-Mean, TPR, and TNR performance metrics. G-Mean results from the two- and four-layer networks are consolidated by aggregating on the distribution for each data set, and TPR and TNR results are averaged across both Medicare data sets.
Table 8 Medicare Part B G-mean scores
                                     G-mean 95% C.I.
Distribution type  Pos. class prior  Optimal threshold   Default threshold
Baseline           0.0003            (0.7281, 0.7321)    (0.0000, 0.0000)
RUS-1              0.001             (0.7206, 0.7280)    (0.0000, 0.0000)
RUS-2              0.005             (0.7379, 0.7425)    (0.0000, 0.0000)
RUS-3              0.01              (0.7351, 0.7415)    (0.0000, 0.0000)
RUS-4              0.2               (0.7322, 0.7353)    (0.0903, 0.1939)
RUS-5              0.4               (0.7171, 0.7253)    (0.7307, 0.7333)
RUS-6              0.5               (0.7148, 0.7242)    (0.7155, 0.7225)
RUS-7              0.6               (0.7109, 0.7199)    (0.6542, 0.6798)
ROS-1              0.001             (0.7151, 0.7235)    (0.0000, 0.0000)
ROS-2              0.005             (0.7459, 0.7543)    (0.0000, 0.0000)
ROS-3              0.01              (0.7197, 0.7479)    (0.0000, 0.0000)
ROS-4              0.2               (0.7449, 0.7649)    (0.6070, 0.6466)
ROS-5              0.4               (0.7435, 0.7729)    (0.7563, 0.7695)
ROS-6              0.5               (0.7665, 0.7719)    (0.7563, 0.7695)
ROS-7              0.6               (0.7673, 0.7729)    (0.7434, 0.7686)
ROS-RUS-1          0.5               (0.7576, 0.7754)    (0.7488, 0.7744)
ROS-RUS-2          0.5               (0.7680, 0.7740)    (0.7472, 0.7726)
ROS-RUS-3          0.5               (0.7506, 0.7744)    (0.7497, 0.7749)
Tables 8 and 9 list the 95% G-Mean confidence intervals for all Medicare Part B and Part D distributions, respectively. Intervals listed in bold indicate those which are significantly greater than the alternative. Our first observation is that the Default classification threshold of 0.5 never performs significantly better than the Optimal threshold. In fact, the Default threshold only yields acceptable G-Mean scores when classes are virtually balanced, e.g. priors of 0.4–0.6. In all other distributions, the performance of the Default threshold degrades as the level of class imbalance increases. The Optimal threshold, however, yields relatively stable G-Mean scores across all distributions. Even the baseline distribution, with a positive class size of just 0.03%, yields acceptable G-Mean scores > 0.72 when using an Optimal classification threshold. Overall, these results discourage using the Default classification threshold when training DNN models with class-imbalanced data. Results also indicate an increase in G-Mean scores among the ROS and ROS-RUS methods. This is not due to the threshold procedure, but rather to the under-sampling procedure used to create the RUS distributions. Our previous work shows that using RUS with these highly imbalanced big data classification tasks tends to underrepresent the majority group and degrade performance [90]. Figure 3 presents the combined TPR and TNR scores for both Medicare data sets and DNN architectures. Optimal classification thresholds (left) produce stable
Table 9 Medicare Part D G-mean scores
                                     G-mean 95% C.I.
Distribution type  Pos. class prior  Optimal threshold   Default threshold
Baseline           0.0003            (0.6986, 0.7022)    (0.0000, 0.0000)
RUS-1              0.001             (0.7058, 0.7128)    (0.0000, 0.0000)
RUS-2              0.005             (0.7262, 0.7300)    (0.0000, 0.0000)
RUS-3              0.01              (0.7305, 0.7345)    (0.0000, 0.0000)
RUS-4              0.2               (0.7052, 0.7120)    (0.1932, 0.2814)
RUS-5              0.4               (0.6815, 0.6849)    (0.6659, 0.6835)
RUS-6              0.5               (0.6870, 0.6928)    (0.6870, 0.6926)
RUS-7              0.6               (0.6785, 0.6843)    (0.6486, 0.6606)
ROS-1              0.001             (0.8058, 0.8128)    (0.0000, 0.0000)
ROS-2              0.005             (0.7262, 0.7300)    (0.0000, 0.0089)
ROS-3              0.01              (0.7305, 0.7345)    (0.0030, 0.0220)
ROS-4              0.2               (0.7378, 0.7478)    (0.6040, 0.6258)
ROS-5              0.4               (0.7424, 0.7494)    (0.7337, 0.7405)
ROS-6              0.5               (0.7262, 0.7402)    (0.7431, 0.7497)
ROS-7              0.6               (0.7450, 0.7516)    (0.7395, 0.7481)
ROS-RUS-1          0.5               (0.7398, 0.7490)    (0.7463, 0.7537)
ROS-RUS-2          0.5               (0.7453, 0.7515)    (0.7487, 0.7547)
ROS-RUS-3          0.5               (0.7456, 0.7532)    (0.7479, 0.7535)
Fig. 3 Medicare class-wise performance
TPR and TNR scores across all positive class sizes. Furthermore, we observe that the TPR is always greater than the TNR when using the Optimal classification threshold. This suggests that our threshold selection procedure (Algorithm 1) is effective, and that we can expect performance trade-offs optimized during validation to generalize to unseen test data. Default threshold results (right), however, are unstable as the positive class size varies. When class imbalance levels are high, for example, the Default threshold assigns all test samples to the negative class. It is only when classes are mostly balanced that the Default threshold achieves high TPR and TNR. Even
when Default threshold performance is reasonably well balanced, e.g. positive class sizes of 40%, we lose the ability to maximize TPR over TNR. In summary, two highly imbalanced Medicare data sets were used to compare DNN prediction performance using Optimal classification thresholds and Default classification thresholds. For each data set, Part B and Part D, 18 distributions were created using ROS and RUS to cover a wide range of class imbalance levels, i.e. 0.03– 60%. For each of these 36 distributions, 30 two-layer networks and 30 four-layer networks were trained and scored on test sets. With evidence from over 2,000 DNN models, statistical results show that the Default threshold is suboptimal whenever models are trained with imbalanced data. Even in the most severely imbalanced distributions, e.g. positive class size of 0.03%, scoring with an Optimal threshold yields consistently favorable G-Mean, TPR, and TNR scores. In the next section, we expand on these results with the ECBDL’14 data set and consider a third thresholding method, the Prior threshold.
5.3 ECBDL’14 Classification Results ECBDL’14 experiments incorporate a third thresholding strategy, the Prior threshold, which uses the prior probability of the positive class from the training distribution as the classification threshold. We first present G-Mean scores from each respective thresholding strategy over a wide range of class imbalance levels, i.e. 0.5–90%. G-Mean, TPR, and TNR results are then averaged across all distributions and summarized using Tukey’s HSD test. Figure 4 illustrates the G-Mean score for each thresholding strategy and distribution. Similar to Medicare results, the Default threshold performance is acceptable when classes are mostly balanced, e.g. positive class sizes of 40–60%, but deteriorates
Fig. 4 ECBDL’14 G-mean results
Table 10 ECBDL’14 class-wise performance and HSD groups
Threshold   Geometric mean           True positive rate       True negative rate
strategy    Mean     Std.   Group    Mean     Std.   Group    Mean     Std.   Group
Default     0.4421   0.28   b        0.4622   0.37   b        0.7823   0.26   a
Optimal     0.7333   0.01   a        0.7276   0.02   a        0.7399   0.03   b
Prior       0.7320   0.02   a        0.7195   0.04   a        0.7470   0.05   b
quickly when the positive class prior is ≥ 0.7 or ≤ 0.3. Unlike the Default threshold, the Optimal threshold yields relatively stable G-Mean scores (≥ 0.7) across all training distributions. Most interestingly, the Prior threshold also yields stable G-Mean scores across all training distributions. In addition to performing as well as the Optimal threshold, the Prior threshold strategy has the advantage of being derived directly from the training data without requiring optimization. Most importantly, we were able to achieve TPR and TNR scores on par with those of the competition winners [73] without the added costs of over-sampling. Table 10 lists ECBDL’14 G-Mean, TPR, and TNR scores averaged across all training distributions. Tukey’s HSD groups are used to identify results with significantly different means, i.e. group a performs significantly better than group b. For the G-Mean metric, the Optimal and Prior threshold methods are placed into group a with mean scores of 0.7333 and 0.7320, respectively. The Default threshold is placed into group b with a mean score of 0.4421. These results indicate that the Optimal and Prior thresholds balance class-wise performance significantly better than the Default threshold. Similarly, TPR results place the Optimal and Prior threshold methods into group a and the Default threshold into group b. Put another way, the non-default threshold strategies are significantly better at capturing the positive class of interest. We expect this behavior from the Optimal threshold, as the threshold selection procedure that we employed (Algorithm 1) explicitly optimizes thresholds by maximizing TPR and G-Mean scores on the training data. The Prior method, however, was derived directly from the training data with zero training or optimization, and surprisingly, performed as well as the Optimal threshold. We believe these qualities make the Prior thresholding strategy a great candidate for preliminary experiments and baseline models.
If requirements call for specific class-wise performance trade-offs, the Prior threshold can still offer an approximate baseline threshold from which to begin optimization. On further inspection, we see that the Default threshold achieves the highest average TNR score. While the negative class is typically not the class of interest, it is still important to minimize false positive predictions. The Default threshold’s TNR score is misleadingly high, however, because more of the training distributions have positive class priors < 0.5 than > 0.5. Recall from Fig. 3 that when the positive class prior is small, models tend to assign all test samples to the negative class. The Default threshold scores highly on TNR because most of the models trained with imbalanced data are assigning all test samples to the
negative class and achieving a 100% TNR. Therefore, we rely on the G-Mean scores to ensure class-wise performance is balanced and conclude that the Default threshold is suboptimal when high class imbalance exists within the training data. The ECBDL’14 results presented in this section align with those from the Medicare experiments. For all imbalanced training distributions, Optimal thresholds consistently outperform the Default threshold. The Prior threshold, although not optimized, performed statistically as well as the Optimal threshold by all performance criteria.
6 Conclusion
This chapter explored the effects of highly imbalanced big data on training and scoring deep neural network classifiers. We trained models on a wide range of class imbalance levels (0.03–90%) and compared the results of two output thresholding strategies to Default threshold results. The Optimal threshold technique uses training or validation data to find thresholds that maximize the G-Mean performance metric. The Prior threshold technique uses the positive class prior as the classification threshold for assigning labels to test instances. As suggested by Bayes’ theorem (Eq. 1), we found that all Optimal thresholds are proportional to the positive class priors of the training data. As a result, the Default threshold of 0.5 only performed well when the training data was relatively balanced, i.e. positive priors between 0.4 and 0.6. For all other distributions, the Optimal and Prior thresholding strategies performed significantly better based on the G-Mean criterion. Furthermore, Tukey’s HSD test results suggest that there is no statistically significant difference between Optimal and Prior threshold results. These Optimal threshold results are dependent on the threshold selection criteria (Algorithm 1). The Optimal threshold procedure should be guided by the classification task requirements, and selecting a new performance criterion may yield Optimal thresholds that are significantly different from the Prior threshold. Future work should evaluate these thresholding strategies across a wider range of domains and network architectures, e.g. natural language processing and computer vision. Additionally, the threshold selection procedure should be modified to optimize alternative performance metrics, and statistical tests should be used to identify significant differences.
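The two thresholding strategies summarised above can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not the chapter's implementation: G-Mean is taken as the geometric mean of TPR and TNR, the grid search is a simplified stand-in for Algorithm 1 (which is not reproduced here), and the function names and toy scores are hypothetical.

```python
import numpy as np

def g_mean(y_true, y_prob, threshold):
    """Geometric mean of TPR and TNR at a given decision threshold."""
    y_true = np.asarray(y_true)
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)
    tpr = np.mean(y_pred[y_true == 1] == 1)  # true positive rate
    tnr = np.mean(y_pred[y_true == 0] == 0)  # true negative rate
    return float(np.sqrt(tpr * tnr))

def prior_threshold(y_train):
    """Prior strategy: use the positive class prior as the threshold."""
    return float(np.mean(y_train))

def grid_search_threshold(y_true, y_prob, grid=None):
    """Assumed stand-in for Algorithm 1: pick the G-Mean-maximising threshold."""
    if grid is None:
        grid = np.linspace(0.0, 1.0, 1001)
    scores = [g_mean(y_true, y_prob, t) for t in grid]
    return float(grid[int(np.argmax(scores))])

# Hypothetical toy data: 5% positive class, scores skewed towards zero.
y = np.array([1] * 5 + [0] * 95)
p = np.concatenate([[0.30, 0.20, 0.10, 0.08, 0.03], np.linspace(0.0, 0.06, 95)])

print("Default :", g_mean(y, p, 0.5))
print("Prior   :", g_mean(y, p, prior_threshold(y)))
print("Searched:", g_mean(y, p, grid_search_threshold(y, p)))
```

On this toy data the Default threshold assigns every sample to the negative class, so its G-Mean collapses to zero, whereas the Prior threshold recovers most of the positive class without any search.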
References 1. W. Wei, J. Li, L. Cao, Y. Ou, J. Chen, Effective detection of sophisticated online banking fraud on extremely imbalanced data. World Wide Web 16, 449–475 (2013) 2. A.N. Richter, T.M. Khoshgoftaar, Sample size determination for biomedical big data with limited labels. Netw. Model. Anal. Health Inf. Bioinf. 9, 1–13 (2020) 3. M. Kubat, R.C. Holte, S. Matwin, Machine learning for the detection of oil spills in satellite radar images. Mach. Learn. 30, 195–215 (1998)
4. S. Wang, X. Yao, Multiclass imbalance problems: analysis and potential solutions. IEEE Trans. Syst. Man Cyb. Part B (Cybern.) 42, 1119–1130 (2012) 5. N. Japkowicz, The class imbalance problem: significance and strategies, in Proceedings of the International Conference on Artificial Intelligence (2000) 6. M. Buda, A. Maki, M.A. Mazurowski, A systematic study of the class imbalance problem in convolutional neural networks. Neural Netw. 106, 249–259 (2018) 7. H. He, E.A. Garcia, Learning from imbalanced data, IEEE Trans. Knowl. Data Eng. 21, 1263– 1284 (2009) 8. G.M. Weiss, Mining with rarity: a unifying framework. SIGKDD Explor. Newsl. 6, 7–19 (2004) 9. R.A. Bauder, T.M. Khoshgoftaar, T. Hasanin, An empirical study on class rarity in big data, in 2018 17th IEEE International Conference on Machine Learning and Applications (ICMLA) (2018), pp. 785–790 10. E. Dumbill, What is big data? an introduction to the big data landscape (2012). http://radar. oreilly.com/2012/01/what-is-big-data.html 11. S.E. Ahmed, Perspectives on Big Data Analysis: methodologies and Applications (Amer Mathematical Society, USA, 2014) 12. J.L. Leevy, T.M. Khoshgoftaar, R.A. Bauder, N. Seliya, A survey on addressing high-class imbalance in big data. J. Big Data 5, 42 (2018) 13. J. Dean, S. Ghemawat, Mapreduce: simplified data processing on large clusters. Commun. ACM 51, 107–113 (2008) 14. M. Zaharia, M. Chowdhury, M. J. Franklin, S. Shenker, I. Stoica, Spark: cluster computing with working sets, in Proceedings of the 2Nd USENIX Conference on Hot Topics in Cloud Computing, HotCloud’10, (Berkeley, CA, USA), USENIX Association (2010), p. 10 15. K. Chahal, M. Grover, K. Dey, R.R. Shah, A hitchhiker’s guide on distributed training of deep neural networks. J. Parallel Distrib. Comput. 10 (2019) 16. R.K.L. Kennedy, T.M. Khoshgoftaar, F. Villanustre, T. Humphrey, A parallel and distributed stochastic gradient descent implementation using commodity clusters. J. Big Data 6(1), 16 (2019) 17. D.L. 
Wilson, Asymptotic properties of nearest neighbor rules using edited data. IEEE Trans. Syst. Man Cybern. SMC-2, 408–421 (1972) 18. N.V. Chawla, K.W. Bowyer, L.O. Hall, W.P. Kegelmeyer, Smote: synthetic minority oversampling technique. J. Artif. Int. Res. 16, 321–357 (2002) 19. H. Han, W.-Y. Wang, B.-H. Mao, Borderline-smote: a new over-sampling method in imbalanced data sets learning, in Advances in Intelligent Computing ed. by D.-S. Huang, X.-P. Zhang, G.-B. Huang (Springer, Berlin, Heidelberg, 2005), pp. 878–887 20. T. Jo, N. Japkowicz, Class imbalances versus small disjuncts. SIGKDD Explor. Newsl. 6, 40–49 (2004) 21. C. Ling, V. Sheng, Cost-sensitive learning and the class imbalance problem, in Encyclopedia of Machine Learning (2010) 22. J.J Chen, C.-A. Tsai, H. Moon, H. Ahn, J.J. Young, C.-H. Chen, Decision threshold adjustment in class prediction, in SAR and QSAR in Environmental Research, vol. 17 (2006), pp. 337–352 23. Q. Zou, S. Xie, Z. Lin, M. Wu, Y. Ju, Finding the best classification threshold in imbalanced classification. Big Data Res. 5, 2–8 (2016) 24. X. Liu, J. Wu, Z. Zhou, Exploratory undersampling for class-imbalance learning. IEEE Trans. Syst. Man Cybern. Part B (Cybern.) 39, 539–550 (2009) 25. N.V. Chawla, A. Lazarevic, L.O. Hall, K.W. Bowyer, Smoteboost: improving prediction of the minority class in boosting, in Knowledge Discovery in Databases: PKDD 2003 ed. by N. Lavraˇc, D. Gamberger, L. Todorovski, H. Blockeel, (Springer, Berlin, Heidelberg, 2003), pp. 107–119 26. Y. Sun, Cost-sensitive Boosting for Classification of Imbalanced Data. Ph.D. thesis, Waterloo, Ont., Canada, Canada, 2007. AAINR34548 27. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning (The MIT Press, Cambridge, MA, 2016)
28. I.H. Witten, E. Frank, M.A. Hall, C.J. Pal, Data Mining, Fourth Edition: practical Machine Learning Tools and Techniques, 4th edn. (San Francisco, CA, USA, Morgan Kaufmann Publishers Inc., 2016) 29. Y. LeCun, Y. Bengio, G. Hinton, Deep learning. Nature. 521, 436 (2015) 30. M. Abadi, A. Agarwal, P. Barham, E. Brevdo, Z. Chen, C. Citro, G.S. Corrado, A. Davis, J. Dean, M. Devin, S. Ghemawat, I. Goodfellow, A. Harp, G. Irving, M. Isard, Y. Jia, R. Jozefowicz, L. Kaiser, M. Kudlur, J. Levenberg, D. Mané, R. Monga, S. Moore, D. Murray, C. Olah, M. Schuster, J. Shlens, B. Steiner, I. Sutskever, K. Talwar, P. Tucker, V. Vanhoucke, V. Vasudevan, F. Viégas, O. Vinyals, P. Warden, M. Wattenberg, M. Wicke, Y. Yu, X. Zheng, TensorFlow: large-scale machine learning on heterogeneous systems (2015) 31. Theano Development Team, Theano: a python framework for fast computation of mathematical expressions (2016). arXiv:abs/1605.02688 32. F. Chollet et al., Keras (2015). https://keras.io 33. A. Paszke, S. Gross, S. Chintala, G. Chanan, E. Yang, Z. DeVito, Z. Lin, A. Desmaison, L. Antiga, A. Lerer, Automatic differentiation in pytorch, in NIPS-W (2017) 34. S. Chetlur, C. Woolley, P. Vandermersch, J. Cohen, J. Tran, B. Catanzaro, E. Shelhamer, cudnn: efficient primitives for deep learning (2014) 35. A. Krizhevsky, I. Sutskever, G.E. Hinton, Imagenet classification with deep convolutional neural networks. Neural Inform. Process. Syst. 25, 01 (2012) 36. M.D. Richard, R.P. Lippmann, Neural network classifiers estimate bayesian a posteriori probabilities. Neural Comput. 3(4), 461–483 (1991) 37. Centers For Medicare & Medicaid Services, Medicare provider utilization and payment data: physician and other supplier (2018) 38. Centers For Medicare & Medicaid Services, Medicare provider utilization and payment data: part D prescriber (2018) 39. U.S. Government, U.S. Centers for Medicare & Medicaid Services, The official U.S. government site for medicare 40. 
Evolutionary Computation for Big Data and Big Learning Workshop, Data mining competition 2014: self-deployment track 41. M. Wani, F. Bhat, S. Afzal, A. Khan, Advances in Deep Learning (Springer, 2020) 42. J.W. Tukey, Comparing individual means in the analysis of variance. Biometrics 5(2), 99–114 (1949) 43. R. Anand, K.G. Mehrotra, C.K. Mohan, S. Ranka, An improved algorithm for neural network classification of imbalanced training sets. IEEE Trans. Neural Netw. 4, 962–969 (1993) 44. J.M. Johnson, T.M. Khoshgoftaar, Survey on deep learning with class imbalance. J. Big Data 6, 27 (2019) 45. J.M. Johnson, T.M. Khoshgoftaar, Medicare fraud detection using neural networks. J. Big Data 6(1), 63 (2019) 46. D. Masko, P. Hensman, The impact of imbalanced training data for convolutional neural networks, in 2015. KTH, School of Computer Science and Communication (CSC) 47. H. Lee, M. Park, J. Kim, Plankton classification on imbalanced large scale database via convolutional neural networks with transfer learning, in 2016 IEEE International Conference on Image Processing (ICIP) (2016), pp. 3713–3717 48. S. Wang, W. Liu, J. Wu, L. Cao, Q. Meng, P. J. Kennedy, Training deep neural networks on imbalanced data sets, in 2016 International Joint Conference on Neural Networks (IJCNN) (2016), pp. 4368–4374 49. H. Wang, Z. Cui, Y. Chen, M. Avidan, A. B. Abdallah, A. Kronzer, Predicting hospital readmission via cost-sensitive deep learning. IEEE/ACM Trans. Comput. Biol. Bioinf. 1 (2018) 50. S.H. Khan, M. Hayat, M. Bennamoun, F.A. Sohel, R. Togneri, Cost-sensitive learning of deep feature representations from imbalanced data. IEEE Trans. Neural Netw. Learn. Syst. 29, 3573– 3587 (2018) 51. T.-Y. Lin, P. Goyal, R. B. Girshick, K. He, P. Dollár, Focal loss for dense object detection, 2017 IEEE International Conference on Computer Vision (ICCV) (2017), pp. 2999–3007
52. C. Huang, Y. Li, C. C. Loy, X. Tang, Learning deep representation for imbalanced classification, in 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2016), pp. 5375–5384 53. S. Ando, C.Y. Huang, Deep over-sampling framework for classifying imbalanced data, in Machine Learning and Knowledge Discovery in Databases, ed. by M. Ceci, J. Hollmén, L. Todorovski, C. Vens, S. Džeroski (Springer International Publishing, Cham, 2017), pp. 770– 785 54. Q. Dong, S. Gong, X. Zhu, Imbalanced deep learning by minority class incremental rectification. IEEE Trans. Pattern Anal. Mach. Intell. 1 (2018) 55. Q. Chen, J. Huang, R. Feris, L.M. Brown, J. Dong, S. Yan, Deep domain adaptation for describing people based on fine-grained clothing attributes, in 2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2015), pp. 5315–5324 56. Y. LeCun, C. Cortes, MNIST handwritten digit database (2010). http://yann.lecun.com/exdb/ mnist/, Accessed 15 Nov 2018 57. A. Krizhevsky, V. Nair, G. Hinton, Cifar-10 (canadian institute for advanced research). http:// www.cs.toronto.edu/kriz/cifar.html 58. R.A. Bauder, T.M. Khoshgoftaar, A novel method for fraudulent medicare claims detection from expected payment deviations (application paper), in 2016 IEEE 17th International Conference on Information Reuse and Integration (IRI) (2016), pp. 11–19 59. R.A. Bauder, T.M. Khoshgoftaar, A probabilistic programming approach for outlier detection in healthcare claims, in 2016 15th IEEE International Conference on Machine Learning and Applications (ICMLA) (2016), pp. 347–354 60. R.A. Bauder, T.M. Khoshgoftaar, A. Richter, M. Herland, Predicting medical provider specialties to detect anomalous insurance claims, in 2016 IEEE 28th International Conference on Tools with Artificial Intelligence (ICTAI) (2016), pp. 784–790 61. M. Herland, R.A. Bauder, T.M. 
Khoshgoftaar, Medical provider specialty predictions for the detection of anomalous medicare insurance claims, in 2017 IEEE International Conference on Information Reuse and Integration (IRI) (2017), pp. 579–588 62. Office of Inspector General, LEIE downloadable databases (2019) 63. R.A. Bauder, T.M. Khoshgoftaar, The detection of medicare fraud using machine learning methods with excluded provider labels, in FLAIRS Conference (2018) 64. M. Herland, T.M. Khoshgoftaar, R.A. Bauder, Big data fraud detection using multiple medicare data sources. J. Big Data 5, 29 (2018) 65. M. Herland, R.A. Bauder, T.M. Khoshgoftaar, The effects of class rarity on the evaluation of supervised healthcare fraud detection models. J. Big Data 6(1), 21 (2019) 66. K. Feldman, N.V. Chawla, Does medical school training relate to practice? evidence from big data. Big Data (2015) 67. Centers for Medicare & Medicaid Services, Physician compare datasets (2019) 68. J. Ko, H. Chalfin, B. Trock, Z. Feng, E. Humphreys, S.-W. Park, B. Carter, K.D. Frick, M. Han, Variability in medicare utilization and payment among urologists. Urology 85, 03 (2015) 69. V. Chandola, S.R. Sukumar, J.C. Schryver, Knowledge discovery from massive healthcare claims data, in KDD (2013) 70. L.K. Branting, F. Reeder, J. Gold, T. Champney, Graph analytics for healthcare fraud risk estimation, in 2016 IEEE/ACM International Conference on Advances in Social Networks Analysis and Mining (ASONAM) (2016), pp. 845–851 71. National Plan & Provider Enumeration System, NPPES NPI registry (2019) 72. P.S.P. Center, 9th community wide experiment on the critical assessment of techniques for protein structure prediction 73. I. Triguero, S. Rí, V. López, J. Bacardit, J. Benítez, F. Herrera, ROSEFW-RF: the winner algorithm for the ecbdl’14 bigdata competition: an extremely imbalanced big data bioinformaticsproblem. Knowl.-Based Syst. 87 (2015) 74. A. Fernández, S. del Río, N.V. Chawla, F. 
Herrera, An insight into imbalanced big data classification: outcomes and challenges. Complex Intell. Syst. 3(2), 105–120 (2017)
75. S. de Río, J.M. Benítez, F. Herrera, Analysis of data preprocessing increasing the oversampling ratio for extremely imbalanced big data classification, in 2015 IEEE Trustcom/BigDataSE/ISPA, vol. 2 (2015), pp. 180–185 76. Centers for Medicare & Medicaid Services, National provider identifier standard (NPI) (2019) 77. Centers For Medicare & Medicaid Services, HCPCS general information (2018) 78. P. Di Lena, K. Nagata, P. Baldi, Deep architectures for protein contact map prediction, Bioinformatics (Oxford, England) 28, 2449–57 (2012) 79. J. Berg, J. Tymoczko, L. Stryer, Chapter 3, protein structure and function, in Biochemistry, 5th edn. (W H Freeman, New York, 2002) 80. Z. Zhao, F. Morstatter, S. Sharma, S. Alelyani, A. Anand, H. Liu, Advancing feature selection research, ASU Feature Selection Repository (2010), pp. 1–28 81. S. Linux, About (2014) 82. F. Pedregosa, G. Varoquaux, A. Gramfort, V. Michel, B. Thirion, O. Grisel, M. Blondel, P. Prettenhofer, R. Weiss, V. Dubourg, J. Vanderplas, A. Passos, D. Cournapeau, M. Brucher, M. Perrot, E. Duchesnay, Scikit-learn: machine learning in python. J. Mach. Learn. Res. 12, 2825–2830 (2011) 83. F. Provost, T. Fawcett, Analysis and visualization of classifier performance: comparison under imprecise class and cost distributions, in Proceedings of the Third International Conference on Knowledge Discovery and Data Mining, vol. 43–48 (1999), p. 12 84. D. Wilson, T. Martinez, The general inefficiency of batch training for gradient descent learning. Neural Netw.: Off. J. Int. Neural Netw. Soc. 16, 1429–51 (2004) 85. D.P. Kingma, J. Ba, Adam: a method for stochastic optimization. CoRR (2015). arXiv:abs/1412.6980 86. R.P. Lippmann, Neural networks, bayesian a posteriori probabilities, and pattern classification, in From Statistics to Neural Networks, ed. by V. Cherkassky, J.H. Friedman, H. Wechsler (Springer, Berlin, Heidelberg, 1994), pp. 83–104 87. N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, R. 
Salakhutdinov, Dropout: a simple way to prevent neural networks from overfitting. J. Mach. Learn. Res. 15, 1929–1958 (2014) 88. S. Ioffe, C. Szegedy, Batch normalization: accelerating deep network training by reducing internal covariate shift, in Proceedings of the 32Nd International Conference on International Conference on Machine Learning, ICML’15, vol. 37 (JMLR.org, 2015), pp. 448–456 89. B. Zdaniuk, Ordinary Least-Squares (OLS) Model (Dordrecht, Springer Netherlands, 2014), pp. 4515–4517 90. J.M. Johnson, T.M. Khoshgoftaar, Deep learning and data sampling with imbalanced big data, 2019 IEEE 20th International Conference on Information Reuse and Integration for Data Science (IRI) (2019), pp. 175–183
Vehicular Localisation at High and Low Estimation Rates During GNSS Outages: A Deep Learning Approach
Uche Onyekpe, Stratis Kanarachos, Vasile Palade, and Stavros-Richard G. Christopoulos
Abstract Road localisation of autonomous vehicles is reliant on consistent accurate GNSS (Global Navigation Satellite System) positioning information. Commercial GNSS receivers usually sample at 1 Hz, which is not sufficient to robustly and accurately track a vehicle in certain scenarios, such as driving on the highway, where the vehicle could travel at medium to high speeds, or in safety-critical scenarios. In addition, the GNSS relies on a number of satellites to perform triangulation and may experience signal loss around tall buildings, bridges, tunnels and trees. An approach to overcoming this problem involves integrating the GNSS with a vehicle-mounted Inertial Navigation Sensor (INS) system to provide a continuous and more reliable high rate positioning solution. INSs are however plagued by unbounded exponential error drifts during the double integration of the acceleration to displacement. Several deep learning algorithms have been employed to learn the error drift for a better positioning prediction. We therefore investigate in this chapter the performance of Long Short-Term Memory (LSTM), Input Delay Neural Network (IDNN), Multi-Layer Neural Network (MLNN) and Kalman Filter (KF) for high data rate positioning. We show that Deep Neural Network-based solutions can exhibit better performances for high data rate positioning of vehicles in comparison to commonly used approaches like the Kalman filter.
Keywords Inertial navigation · INS · INS/GPS-integrated navigation · GPS outage · Autonomous vehicle navigation · Deep learning · Neural networks · High sampling rate
U. Onyekpe (✉) Research Center for Data Science, Institute for Future Transport and Cities, Coventry University, Gulson Road, Coventry, UK
S. Kanarachos Faculty of Engineering, Coventry University, Gulson Road, Coventry, UK
V. Palade Research Center for Data Science, Coventry University, Gulson Road, Coventry, UK
S.-R. G. Christopoulos Institute for Future Transport and Cities, Faculty of Engineering, Coventry University, Gulson Road, Coventry, UK
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021 M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_10
1 Introduction
1.1 Importance of Autonomous Vehicles
It is estimated that the UK’s autonomous vehicle market will be worth an approximate value of £28 billion by 2035 [1]. A major motivation towards the development of these vehicles is the need to improve road safety. According to [2], 75% of traffic-related road accidents in the UK are due to human driving errors, rising to 94% in the United States. The introduction of autonomous vehicles has the potential to reduce these accidents [3]. Even though these vehicles could introduce new kinds of accidents, there is the drive to ensure they are as safe as possible. Sensory systems are key to the performance of autonomous vehicles as they help the vehicle understand its environment [4]. Examples of sensors found on the outside of the vehicle include LIDARs, cameras and ultrasonic systems. Several data processing and analysis systems are also present inside the vehicle, which use the sensor data to make decisions a human driver would normally make. LIDARs and cameras are imaging systems used to identify objects, structures, potential collision hazards and pedestrians in the vehicle’s trajectory [5]. Cameras are furthermore essential to the identification of road signs and markings on structured roads. As proof of the crucial role imaging systems play in the operation of the vehicle, a good number of sophisticated versions of such systems are already employed [6]. Nevertheless, although imaging systems can be used to assess the vehicle’s environment as well as determine the position of objects or markings relative to it, there is the need to continuously and robustly localise a vehicle with reference to a defined co-ordinate system. Information on how the vehicle navigates through its environment is also needed such that real-time follow-up decisions can be made.
1.2 GNSS (Global Navigation Satellite System) Issues
A GNSS receiver performs complex analysis on the signals received from at least four of the many satellites orbiting the earth and is known to be one of the best when it comes to position estimation, as it has no competition in terms of cost or coverage [7]. Despite the wide acceptance of GNSS, it is far from being a perfect positioning system. There can be instances of GNSS failures in outdoor environments, as there has
to be a direct line of sight between the satellites and the GNSS antennae. GNSS can prove difficult to use in metropolitan cities and similar environments characterised by tall buildings, bridges, tunnels or trees, as its line of sight may be blocked during signal transmission [7]. Moreover, GNSS signals can be jammed, leaving the vehicle with no information about its position [8]. As such, a GNSS cannot act as a standalone navigation system. The GNSS is used to localise the autonomous vehicle to a road. To achieve lane-level accuracy, GNSS is combined with high accuracy LIDARs, cameras, RADAR and High Definition (HD) maps. There are however times when the camera and LIDAR could be uninformative or unavailable for use. The accuracy of low-cost LIDARs and cameras could be compromised when there is heavy fog, snow, sleet or rain [9]. This is a well-recognised issue in the field. High accuracy LIDARs are worth several thousands of pounds, which also makes them attractive targets for theft. Hence, the use of LIDARs on autonomous vehicles would make the vehicles more expensive. Camera-based positioning systems could also suffer from low accuracy depending on the external light intensity and the objects in the camera’s scene. In level 4 self-driving applications, as tested by Waymo LLC and Cruise LLC, the LIDAR scan is matched onto an HD map in real time. Based on this, the system is able to precisely position the vehicle within its environment [10]. However, this method is computationally intensive. Furthermore, changes in the driving environment and infrastructure could make an HD map temporarily outdated and as such not useful for navigation.
1.3 Navigation Using Inertial Measurement Sensors
An Inertial Navigation Sensor (INS), unlike other sensors found in an autonomous vehicle, does not need to interact with the external environment to perform localisation. This independence makes it vital for both sensor fusion and safety. An Inertial Measurement Unit (IMU) measures a vehicle’s linear acceleration and rotational rate components along the x, y and z axes and computes velocity, position and orientation by continuous dead reckoning. It functions to provide localisation data, which is needed by the vehicle to position itself within its environment. As production vehicles are already equipped with anywhere from one third to a full INS [11], the IMU can be used to localise the vehicle temporarily in the absence of the GNSS signals. The IMU’s position estimates can also be compared with those of other sensors in order to add confidence to the final localisation output. In the absence of an IMU, it would be difficult to know when the localisation accuracy of the LIDAR may have deteriorated [10]. Through a complex mathematical analysis, the position of the vehicle can be computed using the INS to dead reckon during the GNSS outage. However, the sensors are plagued by exponential error drifts caused by the double integration of the acceleration to displacement. These errors cascade unboundedly over time, leading to very poor position estimates. Commonly, the
GNSS, in what could be described as a mutually symbiotic relationship, calibrates the INS periodically during signal coverage to help improve the positioning estimation accuracy. Traditionally, Kalman filters are used to model the error between the GPS position and the INS position solution. Kalman filters have limitations when modelling highly non-linear dependencies, non-Gaussian noise measurements and stochastic relationships. The use of artificial neural network techniques in place of Kalman filters to model the errors has recently been explored by some researchers, as they are capable of learning non-linear relationships within the data. Compared to Kalman filters, deep learning techniques have been reported to perform better during longer GPS signal losses. Rashad et al. proposed a radial basis function neural network to model the position error between the GPS and the INS position [12]. A Multi-Layer Feed-Forward Neural Network (MLNN) was applied to a DGPS and tactical grade INS-integrated architecture for navigation [13]. An MLNN was also utilised on an integrated single point positioning GPS and IMU architecture [15]. Malleswaran et al. suggested the use of bidirectional and hetero-associative neural networks on an INS/GPS-integrated system [16]. Noureldin et al. proposed the use of an Input Delay Neural Network (IDNN) on the INS/GPS problem by utilising inputs from previous timesteps [17]. Malleswaran et al. investigated the use of a Sigma-Pi neural network on the navigation problem [18]. The performance of these techniques, as demonstrated in published literature, highlights the potential of intelligent algorithms in autonomous vehicle navigation. However, a direct comparison of the performances of these techniques is not possible, as information on the vehicle dynamics studied is not publicly available. Most researchers have employed a prediction frequency of 1 Hz, as commercial GPS receivers mostly update their location information every second [19]. 
However, on the motorway and in other safety-critical applications, predicting at such a frequency is not sufficient to track the vehicle reliably. At a speed of 70 mph, a vehicle covers approximately 31 m in a second with its motion unaccounted for, whereas a lateral displacement of only 2.5 m could place it in a completely different lane. Moreover, assessments of vehicle-related accidents by automotive insurers would require high resolution positioning estimation [19]. A high data rate positioning technique would more accurately monitor the vehicle’s motion between GPS signal updates and during signal losses. We therefore comparatively investigate in this chapter the performance of LSTM, IDNN, MLNN and the Kalman filter for high data rate positioning.
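The motivation for a higher estimation rate can be checked with simple arithmetic. The sketch below converts the 70 mph motorway speed to metres per second and reports the distance travelled between consecutive position updates; the 10 Hz and 100 Hz rates are hypothetical values chosen only for illustration, and the function name is an assumption.

```python
MPH_TO_MPS = 0.44704  # exact conversion: 1 mile = 1609.344 m, 1 h = 3600 s

def distance_between_updates(speed_mph, rate_hz):
    """Distance in metres travelled between two consecutive position updates."""
    return speed_mph * MPH_TO_MPS / rate_hz

# 1 Hz is the typical commercial GNSS rate; the other rates are illustrative.
for rate_hz in (1, 10, 100):
    d = distance_between_updates(70.0, rate_hz)
    print(f"{rate_hz:>3} Hz: {d:.2f} m between updates")
```

At 1 Hz the vehicle moves roughly 31.3 m between updates, which is far more than the 2.5 m lateral displacement that separates adjacent lanes; at 100 Hz the gap falls to about 0.3 m.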
2 INS/GPS Fusion Techniques Used for Vehicle Positioning
2.1 Kalman Filter (KF) Approach
Traditionally, Kalman filters are utilised to perform INS/GPS integration. The Kalman filter is used to estimate the instantaneous state of a linear system affected by Gaussian white noise. It has become a standard technique for use in INS/GPS applications
[20]. Despite the wide popularity of the Kalman filter, it does possess some drawbacks. For an INS/GPS-integrated application, the Kalman filter requires stochastic models to represent the INS errors, but these stochastic models are difficult to determine for most gyroscopes and accelerometers [13]. Furthermore, accurate a priori information on the covariance matrices of the noises associated with the INS is needed. Moreover, the INS/GPS problem is non-linear in nature. As a result, other types of filters have been studied [20]. The Kalman filter functions in two stages: the prediction stage, in which the process model is used to propagate the state estimate and its error covariance forward in time, and the update (innovation) stage, in which the error between the measurement and the prediction is used to correct the prediction. In modelling the error between the INS and GNSS position, we express the system model and the two stages in discrete time as

$$X_t = A_{t|t-1} X_{t-1} + B_{t-1} U_{t-1} + w_{t-1} \tag{1}$$

$$Z_t = H_t X_t + V_t \tag{2}$$

Prediction Stage:

$$\hat{X}_t^- = A_{t|t-1} \hat{X}_{t-1} \tag{3}$$

$$P_t^- = A_{t|t-1} P_{t-1} A_{t|t-1}^T + B_{t-1} Q_{t-1} B_{t-1}^T \tag{4}$$

Innovation Stage: This stage involves the computation of the errors between the predicted states and the measurements.

$$K_t = P_t^- H_t^T \left( H_t P_t^- H_t^T + R_t \right)^{-1} \tag{5}$$

$$\hat{X}_t = \hat{X}_t^- + K_t \left( Z_t - H_t \hat{X}_t^- \right) \tag{6}$$

$$P_t = \left( I - K_t H_t \right) P_t^- \tag{7}$$

where $X_t$ is the error state vector, $U_t$ is the input/control vector, $w_t$ is the system noise (Gaussian), $V_t$ is the measurement noise (Gaussian), $Z_t$ is the measurement vector, $A_t$ is the state transition matrix, $B_t$ is the system noise coefficient matrix, $H_t$ relates the state to the measurement, $Q_t$ is the system noise covariance matrix, $R_t$ is the measurement noise covariance matrix, $\hat{X}_t^-$ is the a priori state estimate (prediction), $\hat{X}_t$ is the a posteriori state estimate (update), $K_t$ is the Kalman gain matrix, $P_t^-$ is the a priori error covariance, $P_t$ is the a posteriori error covariance and $I$ is the identity matrix. More information on the Kalman filter can be found in [13, 21].
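The prediction and innovation stages of Eqs. (3)–(7) can be sketched with NumPy as follows. This is a minimal sketch in the standard textbook form, in which the covariance update is taken as P_t = (I − K_t H_t) P_t^−; the function names and the scalar toy values are assumptions made for illustration and are not taken from the chapter's experiments.

```python
import numpy as np

def kf_predict(x_post, P_post, A, B, Q):
    """Prediction stage, Eqs. (3)-(4): propagate state and covariance."""
    x_prior = A @ x_post
    P_prior = A @ P_post @ A.T + B @ Q @ B.T
    return x_prior, P_prior

def kf_update(x_prior, P_prior, z, H, R):
    """Innovation stage, Eqs. (5)-(7): correct the prediction with z."""
    S = H @ P_prior @ H.T + R                    # innovation covariance
    K = P_prior @ H.T @ np.linalg.inv(S)         # Kalman gain, Eq. (5)
    x_post = x_prior + K @ (z - H @ x_prior)     # state update, Eq. (6)
    P_post = (np.eye(P_prior.shape[0]) - K @ H) @ P_prior  # Eq. (7)
    return x_post, P_post, K

# Hypothetical scalar example: static error state, one noisy measurement.
A = B = H = np.array([[1.0]])
Q = np.array([[0.0]])
R = np.array([[1.0]])
x0, P0 = np.array([0.0]), np.array([[1.0]])

x_prior, P_prior = kf_predict(x0, P0, A, B, Q)
x_post, P_post, K = kf_update(x_prior, P_prior, np.array([2.0]), H, R)
print(x_post, P_post, K)
```

With equal prior and measurement variances the gain is 0.5, so the corrected estimate lies halfway between the prediction and the measurement and the error covariance is halved.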
2.2 Multi-layer Neural Networks (MLNN) and Deep Learning
An MLNN consists of an interconnected system of neurons with the ability to map non-linear relationships between the inputs and the outputs. This capability is of particular interest, as vehicle dynamics are non-linear in nature. The neurons are connected by weights. Each neuron’s input is computed as the product of a weight matrix and the input vector, plus a bias, and its output is obtained by transforming this weighted sum non-linearly through an activation function. The output from a neuron layer becomes the input vector for the neurons in the next layer. Through the continuous backpropagation of error signals, the weights are adjusted in what is referred to as the training phase of the MLNN. An adjustable learning rate and momentum can be used to prevent the MLNN from getting trapped in a local minimum while backpropagating the errors [22]. The feed-forward layer operation is governed by

$$y = \sigma(xw + b) \tag{8}$$
where y is the layer output vector, x is the input vector, w is the weight matrix and b is the bias. Deep learning algorithms employ the use of multiple layers to extract features progressively from the raw input through multiple levels of non-linear abstractions. They try to find good representations of the input–output relationship by exploiting the input distribution structure [23]. The success of deep learning algorithms on several challenging problems has been demonstrated in many published literatures, for example [24–27]. Some popular deep learning architectures used today, especially in computer vision applications, are ResNet [28] and Yolo [29], which show good performances in visual recognition tasks. Aside from the deep learning’s success in image processing tasks, sequential data such as audio and texts are now processed with deep neural networks, which are able to achieve state-of-the-art performances in speech recognition and natural language processing tasks [30]. These successes lend support to the potential of deep learning algorithms to learn complex abstractions present in the noisy sensor signals in the application under study in this chapter.
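The feed-forward operation of Eq. (8) can be sketched in a few lines of plain Python. The sigmoid activation, the helper name `dense_layer` and the numbers used are illustrative assumptions, not the configuration used in this chapter.

```python
import math


def sigmoid(v):
    """Logistic activation, used here as one possible choice of sigma."""
    return 1.0 / (1.0 + math.exp(-v))


def dense_layer(x, w, b, activation=sigmoid):
    """Compute y = activation(x w + b) for an input vector x.

    x has n entries, w is an n-by-m list of rows, and b has m entries.
    """
    return [
        activation(sum(x[i] * w[i][j] for i in range(len(x))) + b[j])
        for j in range(len(b))
    ]


# Hypothetical two-input, two-neuron layer.
x = [1.0, 2.0]
w = [[0.5, -1.0], [0.25, 0.5]]
b = [-1.0, 0.0]
y = dense_layer(x, w, b)
```

Stacking several such calls, with the output of one layer passed as the input of the next, gives the multi-layer structure described above.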
2.3 Input Delay Neural Network (IDNN)

The position errors of INS are accumulative and follow a certain pattern [25]. Therefore, previous positional sequences are required for the model to capture the error trend. It is, however, difficult to utilise a static neural network to model this pattern. A dynamic model can be employed by using an architecture that presents the previous t values of the signal as inputs to the network, thus capturing the error trend present in the previous t timesteps [17]. The model can also be trained to learn time-varying or sequential trends through the introduction of memory and associator units to the input layer. The memory unit of the dynamic model can store previous INS samples and forecast using the associator unit. The use of such a dynamic model has a significant influence on the accuracy of the INS position prediction in the absence of GPS [17]. Figure 1 illustrates an IDNN's general architecture, with p being the tapped delay line memory length, Ui the hidden layer neurons, W the weights, G the activation function, Y the target vector and D the delay operator.

Fig. 1 Illustration of an IDNN's general architecture [17]
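One way to picture the tapped delay line is as a sliding window over the input signal. The sketch below builds such windows from two aligned series; the function name, the choice to include the current sample alongside the p delayed ones, and the sample values are assumptions for illustration.

```python
def tapped_delay_windows(u, y, p):
    """Pair each target y[t] with the current input and its p delayed samples.

    u and y are equally long sequences; p is the delay line memory length.
    """
    windows, targets = [], []
    for t in range(p, len(u)):
        windows.append(u[t - p:t + 1])  # u[t-p], ..., u[t]
        targets.append(y[t])
    return windows, targets


# Hypothetical INS-derived signal and error target.
u = [0.0, 0.1, 0.3, 0.6, 1.0]
y = [0.0, 0.01, 0.03, 0.06, 0.10]
windows, targets = tapped_delay_windows(u, y, p=2)
```

Each window can then be fed to a static feed-forward network, which is what lets the network see the recent error trend.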
2.4 Long Short-Term Memory (LSTM) Neural Networks

Recurrent Neural Networks (RNNs) have been shown to learn useful features on sequential problems. Long Short-Term Memory (LSTM) networks are a variant of the RNN created specifically to solve its long-term dependency problem, enabling them to recall information over long periods. Due to the accumulative and patterned nature of the INS positional errors, the LSTM can be used to learn error patterns from previous sequences to provide a better position estimation. The operation of the LSTM is regulated by three gates (the forget, input and output gates), and it operates as shown below:

$$f_t = \sigma\left(W_f\,[h_{t-1}, x_t] + b_f\right) \tag{9}$$

$$i_t = \sigma\left(W_i\,[h_{t-1}, x_t] + b_i\right) \tag{10}$$

$$\hat{c}_t = \tanh\left(W_c\,[h_{t-1}, x_t] + b_c\right) \tag{11}$$

$$c_t = f_t * c_{t-1} + i_t * \hat{c}_t \tag{12}$$

$$o_t = \sigma\left(W_o\,[h_{t-1}, x_t] + b_o\right) \tag{13}$$

$$h_t = o_t * \tanh(c_t) \tag{14}$$

where ∗ is the Hadamard product, $f_t$ is the forget gate, $i_t$ is the input gate, $o_t$ is the output gate, $\hat{c}_t$ is the candidate cell state, $c_t$ is the cell state, $h_{t-1}$ is the previous state, $x_t$ is the input, W is the weight matrix, b is the bias and σ is the sigmoid activation (non-linearity) function. Figures 2 and 3 show the RNN's architecture and the LSTM's cell structure.

Fig. 2 Unrolled RNN architecture
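The gate equations can be traced with a single LSTM step on scalar quantities, where the Hadamard product reduces to ordinary multiplication. The `params` layout (one weight for the previous state, one for the input and a bias per gate) and the zero-valued weights are assumptions chosen to keep the arithmetic easy to follow.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM step for scalar input and state.

    params[name] = (w_h, w_x, b): weights applied to [h_prev, x_t] and a bias,
    for name in "f" (forget), "i" (input), "c" (candidate) and "o" (output).
    """
    def affine(name):
        w_h, w_x, b = params[name]
        return w_h * h_prev + w_x * x_t + b

    f_t = sigmoid(affine("f"))        # forget gate
    i_t = sigmoid(affine("i"))        # input gate
    c_hat = math.tanh(affine("c"))    # candidate cell state
    c_t = f_t * c_prev + i_t * c_hat  # new cell state
    o_t = sigmoid(affine("o"))        # output gate
    h_t = o_t * math.tanh(c_t)        # new hidden state
    return h_t, c_t


# Hypothetical parameters: all weights and biases zero, so every gate is 0.5.
params = {name: (0.0, 0.0, 0.0) for name in ("f", "i", "c", "o")}
h, c = lstm_step(x_t=1.0, h_prev=0.0, c_prev=2.0, params=params)
```

With all gates at 0.5 and a zero candidate, half of the previous cell state is retained, which shows how the forget gate controls what the cell remembers.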
3 Problem Formulation

In this section, we discuss the formulation of the learning task using the four modelling techniques introduced in Sect. 2. Furthermore, we present the experiments and comparative results on the inertial tracking problem in Sect. 4.
Fig. 3 LSTM cell structure
3.1 Inertial Tracking

Vehicular tracking using inertial sensors is governed by the Newtonian laws of motion. Given an initial pose, the attitude rate from the gyroscope can be integrated to provide continuous orientation information. The acceleration of the vehicle can also be integrated to determine its velocity and, provided an initial velocity, the vehicle's displacement can be estimated through the integration of its velocity. Several co-ordinate systems are used in vehicle tracking, with the positioning estimation expressed relative to a reference. Usually, measurements in the sensors' co-ordinate system (body frame) need to be transformed to the navigation frame [31]. The body frame has its axes coincident with the sensors' input axes. The navigation frame, also known as the local frame, shares its origin with the sensor frame. Its x-axis points towards the geodetic north, with the z-axis orthogonal to the ellipsoidal plane and the y-axis completing the orthogonal frame. The rotation matrix from the body frame to the navigation frame is expressed in Eq. (15) [20].

$$R^{nb} = \begin{bmatrix} \cos\theta\cos\Psi & -\cos\phi\sin\Psi + \sin\phi\sin\theta\cos\Psi & \sin\phi\sin\Psi + \cos\phi\sin\theta\cos\Psi \\ \cos\theta\sin\Psi & \cos\phi\cos\Psi + \sin\phi\sin\theta\sin\Psi & -\sin\phi\cos\Psi + \cos\phi\sin\theta\sin\Psi \\ -\sin\theta & \sin\phi\cos\theta & \cos\phi\cos\theta \end{bmatrix} \tag{15}$$

where θ is the pitch, Ψ is the yaw and φ is the roll. As the problem is considered to be the tracking of a vehicle in a horizontal plane, the roll and pitch of the vehicle can be considered negligible and thus assumed to be zero. Hence, the rotation matrix becomes

$$R^{nb} = \begin{bmatrix} \cos\Psi & -\sin\Psi & 0 \\ \sin\Psi & \cos\Psi & 0 \\ 0 & 0 & 1 \end{bmatrix} \tag{16}$$
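To see how the planar assumption simplifies the transformation, the sketch below builds the body-to-navigation rotation in the standard yaw-pitch-roll form and applies it with zero roll and pitch. The function names and the 30° yaw are assumptions for this example.

```python
import math


def rotation_body_to_nav(roll, pitch, yaw):
    """Body-to-navigation rotation matrix (angles in radians)."""
    cr, sr = math.cos(roll), math.sin(roll)
    cp, sp = math.cos(pitch), math.sin(pitch)
    cy, sy = math.cos(yaw), math.sin(yaw)
    return [
        [cp * cy, -cr * sy + sr * sp * cy, sr * sy + cr * sp * cy],
        [cp * sy, cr * cy + sr * sp * sy, -sr * cy + cr * sp * sy],
        [-sp, sr * cp, cr * cp],
    ]


def rotate(matrix, vector):
    """Multiply a 3x3 matrix by a 3-vector."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


# Hypothetical case: 2 m/s^2 longitudinal acceleration at a yaw of 30 degrees.
yaw = math.radians(30.0)
R = rotation_body_to_nav(0.0, 0.0, yaw)
north, east, down = rotate(R, [2.0, 0.0, 0.0])
```

With roll and pitch set to zero, a purely longitudinal quantity splits into a north component scaled by cos Ψ and an east component scaled by sin Ψ, and the down component vanishes.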
An Inertial Measuring Unit usually consists of a three orthogonal axis accelerometer and a three orthogonal axis gyroscope. The accelerometer measures acceleration in the x, y and z-axis. It measures the specific force f on the sensor in the body frame b [20]. This can be expressed as in Eq. (17), where $R^{bn}$ is the rotation matrix from the navigation frame to the body frame, $g^n$ represents the gravity vector and $a^n$ denotes the linear acceleration of the sensor expressed in the navigation frame [31].

$$f^b = R^{bn}\left(a^n - g^n\right) \tag{17}$$

given:

$$a^n = a_{nb}^n + 2\omega_{ie}^n \times V_n^n + \omega_{ie}^n \times \left(\omega_{ie}^n \times p^n\right) \tag{18}$$

where $a_{nb}^n$ is the acceleration of interest, $2\omega_{ie}^n \times V_n^n$ is the Coriolis acceleration, $\omega_{ie}^n \times (\omega_{ie}^n \times p^n)$ is the centrifugal acceleration and $V_n^n$ is the velocity of the vehicle in the navigation frame [31]. In this application, the Coriolis acceleration is considered negligible due to its small magnitude compared to the accelerometers' measurements, and the centrifugal acceleration is considered to be absorbed in the local gravity vector. Thus, $a^n = a_{nb}^n$.

The gyroscope measures the attitude change in roll, yaw and pitch. It measures the angular velocity of the body frame (the vehicle's frame) with respect to the inertial frame, as expressed in the body frame. Represented by $\omega_{ib}^b$, the attitude rate can be expressed as

$$\omega_{ib}^b = R^{bn}\left(\omega_{ie}^n + \omega_{en}^n\right) + \omega_{nb}^b \tag{19}$$

where $\omega_{ie}^n$ is the angular velocity of the earth frame with respect to the inertial frame, estimated to be approximately $7.29 \times 10^{-5}$ rad/s [31]. The navigation frame is defined as stationary with respect to the earth, thus $\omega_{en}^n = 0$, with the angular velocity of interest $\omega_{nb}^b$ representing the rate of rotation of the vehicle in the navigation frame. If initial conditions are known, $\omega_{nb}^b$ may be integrated over time to determine the vehicle's orientation (yaw), as shown in Eq. (20):

$$\Psi_{INS} = \Psi_0 + \int_{t-1}^{t} \omega_{nb}^b \tag{20}$$

where $\Psi_0$ is the last known yaw of the vehicle.
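In discrete time the yaw integral becomes a running sum of yaw-rate samples. The sketch below uses simple rectangular integration; the function name, the sampling interval and the rate values are illustrative assumptions, and the earth-rate term is neglected.

```python
def integrate_yaw(yaw0, yaw_rates, dt):
    """Accumulate yaw (rad) from yaw-rate samples (rad/s) taken every dt seconds."""
    yaw = yaw0
    history = []
    for rate in yaw_rates:
        yaw += rate * dt  # rectangular approximation of the integral
        history.append(yaw)
    return history


# Hypothetical yaw rates sampled at 10 Hz, starting from a known yaw of 0.5 rad.
headings = integrate_yaw(yaw0=0.5, yaw_rates=[0.1, 0.1, -0.2], dt=0.1)
```

Because each step adds to the previous estimate, any bias in the yaw rate accumulates over time, which is one source of the drift discussed in this chapter.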
3.2 Deep Learning Task Formulation

The accelerometer measurement (specific force) $f^b$ at each time instant t is typically assumed to be corrupted by a bias $\delta_{INS}^b$ and noise $\varepsilon_a^b$. Thus, the corrupted sensor's measurement can be represented as $F_{INS}^b$, where

$$F_{INS}^b = f_{INS}^b + \delta_{INS}^b + \varepsilon_a^b \tag{21}$$

Furthermore, the accelerometers' noise is typically quite Gaussian (see footnote 1) and can be modelled as $\varepsilon_a^b \sim \mathcal{N}(0, \Sigma_a)$. The accelerometers' bias is slowly time varying and as such can be modelled either as a constant parameter or as part of a time-varying state. The specific force measurement can be expanded from Eq. (17) as shown below:

$$a^b = f^b + g^b \tag{22}$$

$$F_{INS}^b = a^b + \delta_{INS,a}^b + \varepsilon_a^b \tag{23}$$

$$a^b = F_{INS}^b - \delta_{INS,a}^b - \varepsilon_a^b \tag{24}$$

$$F_{INS}^b - \delta_{INS,a}^b = a^b + \varepsilon_a^b \tag{25}$$

However,

$$a_{INS}^b = F_{INS}^b - \delta_{INS,a}^b \tag{26}$$

$$a_{INS}^b = a^b + \varepsilon_a^b \tag{27}$$

Through the integration of Eq. (27), the velocity of the vehicle in the body frame can be determined.

$$v_{INS}^b = \int_{t-1}^{t} a^b + \varepsilon_v^b \tag{28}$$

The displacement of the sensor in the body frame at time t from t − 1, $x_{INS}^b$, can also be estimated by the double integration of Eq. (27), provided an initial velocity.

$$x_{INS}^b = \iint_{t-1}^{t} a^b + \varepsilon_x^b \tag{29}$$

where $\delta_{INS,a}^b$ is the bias in the body frame, calculated as a constant parameter by computing the average reading from a stationary accelerometer which ran for 20 min. $F_{INS}^b$ is the corrupted measurement provided directly by the accelerometer sensor at time t (sampling time), g is the gravity vector, and $a^b$, $\int_{t-1}^{t} a^b$ and $\iint_{t-1}^{t} a^b$ are the true (uncorrupted) longitudinal acceleration, velocity and displacement, respectively, of the vehicle. The true displacement of the vehicle is expressed as $x_{GPS}^b \approx \iint_{t-1}^{t} a^b$. Furthermore, $\varepsilon_x^b$ can be derived by

$$\varepsilon_x^b \approx x_{GPS}^b - x_{INS}^b \tag{30}$$

Footnote 1: The vehicles' dynamics is non-linear, especially when cornering or braking hard; thus, a linear or non-accurate noise model would not sufficiently capture the non-linear relationship.
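The chain from bias removal to the displacement error target can be sketched as follows. The integration scheme (constant acceleration within each sample), the function name and all numerical values, including the GPS-derived displacement, are assumptions for illustration rather than the processing used in this work.

```python
def dead_reckon(measurements, bias, v0, dt):
    """Integrate bias-corrected longitudinal accelerations over one window.

    Returns the displacement over the window and the final velocity.
    """
    v = v0
    x = 0.0
    for f in measurements:
        a = f - bias                      # remove the constant bias
        x += v * dt + 0.5 * a * dt * dt   # displacement gained in this sample
        v += a * dt                       # velocity after this sample
    return x, v


# Hypothetical window: two samples at 1 s spacing, bias of 0.1 m/s^2.
x_ins, v_ins = dead_reckon([1.1, 1.1], bias=0.1, v0=10.0, dt=1.0)
x_gps = 21.5              # hypothetical GPS-derived displacement (m)
eps_x = x_gps - x_ins     # displacement error target
```

In this sketch, adding `eps_x` back to `x_ins` recovers the GPS-derived displacement, which is the sense in which a predicted error can correct the INS estimate.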
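The noise $\varepsilon_x^b$, displacement $x_{INS}^b$, velocity $v_{INS}^b$ and acceleration $a_{INS}^b$ of the vehicle in the body frame within the window t − 1 to t can thus be transformed to the navigation frame using the rotation matrix $R^{nb}$ and defined by the North-East-Down (NED) system, as shown in Eqs. (31)–(34). However, the down axis is not considered in this study.

$$R_{INS}^{nb} \cdot a_{INS}^b \rightarrow a_{INS}^n \rightarrow \left[a_{INS}^b \cos\Psi_{INS},\; a_{INS}^b \sin\Psi_{INS}\right] \tag{31}$$

$$R_{INS}^{nb} \cdot v_{INS}^b \rightarrow v_{INS}^n \rightarrow \left[v_{INS}^b \cos\Psi_{INS},\; v_{INS}^b \sin\Psi_{INS}\right] \tag{32}$$

$$R_{INS}^{nb} \cdot x_{INS}^b \rightarrow x_{INS}^n \rightarrow \left[x_{INS}^b \cos\Psi_{INS},\; x_{INS}^b \sin\Psi_{INS}\right] \tag{33}$$

$$R^{nb} \cdot \varepsilon_x^b \rightarrow \varepsilon_x^n \rightarrow \left[\varepsilon_x^b \cos\Psi,\; \varepsilon_x^b \sin\Psi\right] \tag{34}$$

where

$$R_{INS}^{nb} = \begin{bmatrix} \cos\Psi_{INS} & -\sin\Psi_{INS} & 0 \\ \sin\Psi_{INS} & \cos\Psi_{INS} & 0 \\ 0 & 0 & 1 \end{bmatrix} \tag{35}$$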
3.3 Vehicle's True Displacement Estimation

This section presents how to estimate the vehicle's true displacement $x_{GPS}^b$, which is used to determine the target error $\varepsilon_x^b$ as detailed in Eq. (30). When estimating the distance travelled between two points on the earth's surface, it must be taken into account that the earth is neither a perfect sphere nor an ellipse, but is closer to an oblate ellipsoid. Because the earth does not correspond exactly to any simple geometric shape, its analysis is complicated. The Haversine formula applies to the calculation of distances on a sphere, while Vincenty's formula applies to distances on an ellipsoid [32].
3.3.1 Haversine's Formula

The Haversine formula is used to calculate the distance between two points on the earth's surface specified in longitude and latitude. It assumes a spherical earth [33].

$$x_{GPS}^b = 2r\sin^{-1}\left(\sqrt{\sin^2\left(\frac{\phi_t - \phi_{t-1}}{2}\right) + \cos(\phi_{t-1})\cos(\phi_t)\sin^2\left(\frac{\varphi_t - \varphi_{t-1}}{2}\right)}\right) \tag{36}$$

where $x_{GPS}^b$ is the distance travelled within [t − 1, t] with longitude and latitude $(\varphi, \phi)$ as obtained from the GPS, and r is the radius of the earth.
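A direct transcription of the Haversine formula into Python is given below as a sketch. The mean earth radius of 6,371,000 m is an assumed default, since the chapter does not state the value of r, and the function name is chosen for this example.

```python
import math


def haversine_distance(lat1, lon1, lat2, lon2, radius=6371000.0):
    """Great-circle distance in metres between two points given in degrees.

    radius is an assumed mean earth radius; pass another value if required.
    """
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlam = math.radians(lon2 - lon1)
    h = (math.sin(dphi / 2) ** 2
         + math.cos(phi1) * math.cos(phi2) * math.sin(dlam / 2) ** 2)
    return 2 * radius * math.asin(math.sqrt(h))


# One degree of longitude along the equator.
d_equator = haversine_distance(0.0, 0.0, 0.0, 1.0)
```

Along the equator the formula reduces to the arc length, radius times the longitude difference in radians, which gives a quick sanity check.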
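3.3.2 Vincenty's Inverse Formula

Vincenty's formula is used to calculate the distance between two points on the earth's surface specified in longitude and latitude. It assumes an ellipsoidal earth [34]. The distance between two points is calculated as shown in Eqs. (37)–(55). Given:

$$f = \frac{1}{298.257223563} \tag{37}$$

$$b = (1 - f)\,a \tag{38}$$

$$U_1 = \arctan\left((1 - f)\tan\phi_1\right) \tag{39}$$

$$\Delta\varphi = \varphi_2 - \varphi_1 \tag{40}$$

$$U_2 = \arctan\left((1 - f)\tan\phi_2\right) \tag{41}$$

$$\sin\sigma = \sqrt{(\cos U_2 \sin\lambda)^2 + (\cos U_1 \sin U_2 - \sin U_1 \cos U_2 \cos\lambda)^2} \tag{42}$$

$$\cos\sigma = \sin U_1 \sin U_2 + \cos U_1 \cos U_2 \cos\lambda \tag{43}$$

$$\sigma = \operatorname{arctan2}(\sin\sigma, \cos\sigma) \tag{44}$$

$$\sin\alpha = \frac{\cos U_1 \cos U_2 \sin\lambda}{\sin\sigma} \tag{45}$$

$$\cos(2\sigma_m) = \cos\sigma - \frac{2\sin U_1 \sin U_2}{\cos^2\alpha} \tag{46}$$

$$C = \frac{f}{16}\cos^2\alpha\left[4 + f\left(4 - 3\cos^2\alpha\right)\right] \tag{47}$$

$$\lambda = \Delta\varphi + (1 - C)\,f\sin\alpha\left\{\sigma + C\sin\sigma\left[\cos(2\sigma_m) + C\cos\sigma\left(-1 + 2\cos^2(2\sigma_m)\right)\right]\right\} \tag{48}$$

$$u^2 = \cos^2\alpha\,\frac{a^2 - b^2}{b^2} \tag{49}$$

$$k_1 = \frac{\sqrt{1 + u^2} - 1}{\sqrt{1 + u^2} + 1} \tag{50}$$

$$A = \frac{1 + \frac{1}{4}k_1^2}{1 - k_1} \tag{51}$$

$$B = k_1\left(1 - \frac{3}{8}k_1^2\right) \tag{52}$$

$$\Delta\sigma = B\sin\sigma\left\{\cos(2\sigma_m) + \frac{B}{4}\left[\cos\sigma\left(-1 + 2\cos^2(2\sigma_m)\right) - \frac{B}{6}\cos(2\sigma_m)\left(-3 + 4\sin^2\sigma\right)\left(-3 + 4\cos^2(2\sigma_m)\right)\right]\right\} \tag{53}$$

$$\alpha_1 = \operatorname{arctan2}(\cos U_2 \sin\lambda,\; \cos U_1 \sin U_2 - \sin U_1 \cos U_2 \cos\lambda) \tag{54}$$

$$\alpha_2 = \operatorname{arctan2}(\cos U_1 \sin\lambda,\; \sin U_1 \cos U_2 - \cos U_1 \sin U_2 \cos\lambda) \tag{55}$$

where $x_{GPS}^b$ is the distance travelled within t − 1 and t with longitude and latitude $(\varphi, \phi)$, $\Delta\varphi$ is the difference in longitude between the two points, a is the radius of the earth at the equator, f is the flattening of the ellipsoid, b is the length of the ellipsoid semi-minor axis, $U_1$ and $U_2$ are the reduced latitude at t and t − 1, respectively, λ is the change in longitude along the auxiliary spheres, s is the ellipsoidal distance between the position at t − 1 and t, $\sigma_1$ is the angle between the position at t − 1 and t, σ is the angle between the position at t − 1 and t and $\sigma_m$ is the angle between the equator and midpoint of the line. Equations (42)–(48) are iterated, starting from $\lambda = \Delta\varphi$, until λ converges; the ellipsoidal distance then follows as $s = bA(\sigma - \Delta\sigma)$. Vincenty's formula is used in this work, as it provides a more accurate solution compared to Haversine and other great circle formulas [32]. The Python implementation of Vincenty's Inverse Formula is used here [35].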
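For readers who want to trace the iteration, a compact sketch of the inverse solution is given below. It is not the package cited in [35]; the function name, the convergence tolerance, the iteration limit and the WGS-84 semi-major axis of 6,378,137 m are assumptions for this example, and the distance is taken as $s = bA(\sigma - \Delta\sigma)$.

```python
import math

WGS84_A = 6378137.0            # assumed equatorial radius (m)
WGS84_F = 1 / 298.257223563    # flattening


def vincenty_inverse(lat1, lon1, lat2, lon2, a=WGS84_A, f=WGS84_F,
                     tol=1e-12, max_iter=200):
    """Ellipsoidal distance in metres between two points given in degrees."""
    if lat1 == lat2 and lon1 == lon2:
        return 0.0
    b = (1 - f) * a
    U1 = math.atan((1 - f) * math.tan(math.radians(lat1)))
    U2 = math.atan((1 - f) * math.tan(math.radians(lat2)))
    L = math.radians(lon2 - lon1)
    sinU1, cosU1 = math.sin(U1), math.cos(U1)
    sinU2, cosU2 = math.sin(U2), math.cos(U2)
    lam = L
    for _ in range(max_iter):
        sin_lam, cos_lam = math.sin(lam), math.cos(lam)
        sin_sigma = math.hypot(cosU2 * sin_lam,
                               cosU1 * sinU2 - sinU1 * cosU2 * cos_lam)
        if sin_sigma == 0.0:
            return 0.0  # coincident points
        cos_sigma = sinU1 * sinU2 + cosU1 * cosU2 * cos_lam
        sigma = math.atan2(sin_sigma, cos_sigma)
        sin_alpha = cosU1 * cosU2 * sin_lam / sin_sigma
        cos2_alpha = 1 - sin_alpha ** 2
        # For equatorial lines cos2_alpha is zero and the midpoint term is dropped.
        cos_2sm = (cos_sigma - 2 * sinU1 * sinU2 / cos2_alpha
                   if cos2_alpha != 0.0 else 0.0)
        C = f / 16 * cos2_alpha * (4 + f * (4 - 3 * cos2_alpha))
        lam_new = L + (1 - C) * f * sin_alpha * (
            sigma + C * sin_sigma * (
                cos_2sm + C * cos_sigma * (-1 + 2 * cos_2sm ** 2)))
        converged = abs(lam_new - lam) < tol
        lam = lam_new
        if converged:
            break
    else:
        raise ValueError("Vincenty iteration did not converge")
    u2 = cos2_alpha * (a * a - b * b) / (b * b)
    k1 = (math.sqrt(1 + u2) - 1) / (math.sqrt(1 + u2) + 1)
    A = (1 + k1 * k1 / 4) / (1 - k1)
    B = k1 * (1 - 3 * k1 * k1 / 8)
    delta_sigma = B * sin_sigma * (cos_2sm + B / 4 * (
        cos_sigma * (-1 + 2 * cos_2sm ** 2)
        - B / 6 * cos_2sm * (-3 + 4 * sin_sigma ** 2)
        * (-3 + 4 * cos_2sm ** 2)))
    return b * A * (sigma - delta_sigma)


# One degree of longitude along the equator.
d_equator = vincenty_inverse(0.0, 0.0, 0.0, 1.0)
```

Along the equator the result should equal the equatorial radius times the longitude difference in radians, which is a convenient check. Nearly antipodal points may fail to converge, which is why the sketch raises an error rather than returning a value.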
3.4 Learning Scheme for the Vehicle's Displacement Error Prediction

The neural networks introduced in Sect. 2 are exploited to learn the relationship between the input features (displacement $x_{INS}^n$, velocity $v_{INS}^n$ and acceleration $a_{INS}^n$) and the target displacement error $\varepsilon_x^n$ (as presented in Sects. 3.2 and 3.3) in the northwards and eastwards directions, as shown in Fig. 4. The predicted displacement error is used to correct the INS-derived displacement to provide a better positioning solution.

Fig. 4 Learning scheme for the northwards and eastwards displacement error prediction
4 Experiments, Results and Discussion

4.1 Data Set

The data used is the V-Vw12 aggressive driving vehicle benchmark data set, which describes about 107 s of an approximately straight-line trajectory, of which the first 105 s are used for our analysis [36]. The sensors are inbuilt, and the data is captured from the vehicle's ECU using the Vbox video H2 data acquisition sampling at a frequency of 10 Hz. The longitudinal acceleration of the vehicle, its rate of rotation about the z-axis (yaw rate), its heading (yaw) and its GPS co-ordinates (latitude and longitude) are captured at each time instance. Figure 5 shows the vehicle used for the data collection and the location of its sensors.
Fig. 5 Data collection vehicle [36]
Table 1 LSTM, IDNN and MLNN training parameters

| Parameters | LSTM | IDNN | MLNN |
|---|---|---|---|
| Learning rate | 0.09 | 0.09 | 0.09 |
| L1 regulariser | 0.9 | – | – |
| L2 regulariser | 0.99 | – | – |
| Recurrent dropout | 5% | – | – |
| Dropout | – | 5% | 5% |
| Sequence length for sample periods 0.1–0.3 s | 5 | 10 × Sample Period | – |
| Sequence length for sample periods 0.4–1 s | 10 × Sample Period | 10 × Sample Period | – |
| Hidden layers | 2 | 2 | 2 |
| Hidden neurons per layer | 32 | 32 | 32 |
| Batch size | 32 | 32 | 32 |
| Epochs | 500 | 500 | 500 |
4.2 Training

The training is done using the first 75 s of the data, and the model is then tested on the next 10 s as well as on the 30 s following the training data. The Keras–TensorFlow framework was used for training, with a mean absolute error loss function and an Adamax optimiser with a learning rate of 0.09. To prevent the neural network from overfitting, 5% of the units were dropped from the hidden layers of the IDNN and MLNN and from the recurrent layers of the LSTM [37]. Furthermore, all features fed to the models were standardised between 0 and 100 to avoid biased learning. Forty models were trained for each deep learning architecture, and the model providing the least position error was selected. The parameters defining the training of the neural network models are listed in Table 1. The general architecture of the models is as shown in Fig. 4. The objective of the training exercise is to teach the neural network to learn the positioning error between the low-cost INS and GPS.
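The data preparation described above can be sketched as a time-ordered split followed by scaling. The chapter does not specify how the features were standardised between 0 and 100, so the min-max scaling below is only one plausible reading; the function names and the placeholder samples are assumptions.

```python
def split_by_time(samples, sample_rate_hz, train_seconds):
    """Split a time-ordered sequence into a training part and the remainder."""
    n_train = int(round(train_seconds * sample_rate_hz))
    return samples[:n_train], samples[n_train:]


def scale_0_100(values, lo=None, hi=None):
    """Min-max scale values to [0, 100] (assumed scheme).

    lo and hi default to the extremes of values; pass the training-set
    extremes to apply the same scaling to test data.
    """
    lo = min(values) if lo is None else lo
    hi = max(values) if hi is None else hi
    if hi == lo:
        return [0.0 for _ in values]
    return [100.0 * (v - lo) / (hi - lo) for v in values]


# Hypothetical 105 s recording at 10 Hz, with the first 75 s used for training.
samples = list(range(1050))
train, test = split_by_time(samples, sample_rate_hz=10, train_seconds=75)
scaled = scale_0_100([2.0, 4.0, 6.0])
```

Keeping the split in time order, rather than shuffling, mirrors the outage scenario in which the test window directly follows the training data.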
4.3 Testing

GPS outages were assumed on the 10 s as well as the 30 s of data following the training data to help analyse the performance of the prediction models. With information on the vehicle's orientation and velocity at the last known GPS signal before the outage, the northwards and eastwards components of the vehicle's displacement $x_{INS,t}^n$, velocity $v_{INS,t}^n$ and acceleration $a_{INS,t}^n$ are fed into the respective models to predict the north and east components of the displacement error $\varepsilon_{x,t}^n$.
Table 2 Position error (m) after 10 s of GPS outages

| Sampling period (s) | LSTM North | LSTM East | LSTM Total | IDNN North | IDNN East | IDNN Total | MLNN North | MLNN East | MLNN Total | KF North | KF East | KF Total |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1.21 | 0.61 | 1.36 | 1.63 | 0.64 | 1.75 | 25.51 | 6.55 | 26.34 | 1.39 | 0.54 | 1.49 |
| 0.9 | 0.59 | 0.7 | 0.92 | 1.22 | 1.08 | 1.63 | 22.02 | 4.28 | 22.43 | 1.05 | 0.93 | 1.40 |
| 0.8 | 0.39 | 0.47 | 0.61 | 0.37 | 0.93 | 1.00 | 19.38 | 6.69 | 20.51 | 0.32 | 0.81 | 0.87 |
| 0.7 | 0.29 | 0.53 | 0.60 | 0.62 | 0.98 | 1.16 | 16.97 | 6.63 | 18.22 | 0.55 | 0.86 | 1.02 |
| 0.6 | 1.04 | 0.36 | 1.10 | 1.37 | 0.81 | 1.59 | 16.36 | 4.38 | 16.94 | 1.22 | 0.72 | 1.42 |
| 0.5 | 0.62 | 0.54 | 0.82 | 1.67 | 1.05 | 1.97 | 12.09 | 5.88 | 13.44 | 1.50 | 0.95 | 1.78 |
| 0.4 | 1.81 | 0.6 | 1.91 | 2.17 | 1.11 | 2.44 | 9.79 | 5.9 | 11.43 | 1.97 | 1.01 | 2.22 |
| 0.3 | 2.24 | 0.89 | 2.24 | 2.64 | 1.15 | 2.88 | 6.98 | 5.93 | 9.16 | 2.43 | 1.06 | 2.65 |
| 0.2 | 1.45 | 0.71 | 1.61 | 2.83 | 1.05 | 3.02 | 4.73 | 5.08 | 6.94 | 2.80 | 1.04 | 2.99 |
| 0.1 | 0.94 | 0.90 | 1.30 | 1.12 | 1.41 | 2.80 | 2.14 | 4.49 | 4.97 | 1.01 | 1.40 | 1.73 |
4.4 Results and Discussion

To evaluate the performance of the LSTM, IDNN, MLNN and Kalman filter techniques, two GPS outage scenarios are explored: 10 s and 30 s.

4.4.1 10 s Outage Experiment Result

The performance of the LSTM, IDNN, MLNN and Kalman filter solutions during the 10 s outage is studied. From Table 2, it can be seen that at all sampling periods the LSTM algorithm performed best at estimating the positioning error, followed closely by the Kalman filter and IDNN. The MLNN performs worst in comparison. Comparing all sampling periods, the LSTM method produces the best error estimation of 0.60 m, at a sampling period of 0.7 s, over about 233 m of travel within the 10 s studied.

4.4.2 30 s Outage Experiment Result

A study of the LSTM, IDNN, MLNN and Kalman filter performances during the 30 s GPS outage scenario reveals that, unlike in the 10 s experiment, the Kalman filter performs poorly in comparison to the LSTM and IDNN approaches. The LSTM performs best at all sampling frequencies, with the Kalman filter outperforming the MLNN. A comparison of the sampling periods shows that the LSTM approach provides the best error estimation of 4.86 m at 10 Hz (a sampling period of 0.1 s), over about 724 m of travel within the 30 s investigated (Table 3).
Table 3 Position error (m) after 30 s of GPS outages

| Sampling period (s) | LSTM North | LSTM East | LSTM Total | IDNN North | IDNN East | IDNN Total | MLNN North | MLNN East | MLNN Total | KF North | KF East | KF Total |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 4.81 | 3.07 | 5.71 | 4.87 | 6.57 | 8.18 | 30.28 | 10.26 | 31.97 | 22.71 | 7.70 | 23.98 |
| 0.9 | 5.33 | 5.19 | 7.44 | 5.22 | 7.30 | 8.97 | 25.42 | 9.15 | 27.02 | 19.57 | 7.05 | 20.80 |
| 0.8 | 6.25 | 7.79 | 9.99 | 5.93 | 8.30 | 10.20 | 22.64 | 12.08 | 25.66 | 17.88 | 9.54 | 20.27 |
| 0.7 | 6.65 | 7.69 | 10.17 | 6.33 | 8.71 | 10.77 | 20.22 | 12.60 | 23.82 | 16.38 | 10.21 | 19.30 |
| 0.6 | 6.93 | 7.36 | 10.11 | 7.03 | 8.56 | 11.08 | 20.35 | 10.70 | 22.99 | 16.89 | 8.88 | 19.08 |
| 0.5 | 6.43 | 4.45 | 7.82 | 7.3 | 8.96 | 11.56 | 15.38 | 12.36 | 19.73 | 13.07 | 10.51 | 16.77 |
| 0.4 | 7.52 | 7.69 | 10.76 | 7.49 | 9.29 | 11.93 | 12.77 | 12.82 | 18.09 | 11.11 | 11.15 | 15.74 |
| 0.3 | 5.46 | 4.62 | 8.72 | 8.16 | 10.38 | 13.20 | 9.98 | 13.53 | 16.81 | 8.88 | 12.04 | 14.96 |
| 0.2 | 5.46 | 4.46 | 7.05 | 8.29 | 9.87 | 12.89 | 7.64 | 12.52 | 14.67 | 6.95 | 11.39 | 13.35 |
| 0.1 | 3.74 | 3.11 | 4.86 | 8.90 | 10.22 | 13.55 | 5.19 | 12.2 | 13.26 | 4.83 | 11.35 | 12.33 |
5 Conclusions

Effective vehicular services and the safety of autonomous vehicles depend on accurate and reliable positioning of the vehicle. Most commercial GPS receivers, however, operate at a rather low sampling rate (1 Hz) and face reliability problems in urban canyons, tunnels, etc. An INS can fill in for the GPS to provide continuous positioning information between GPS signal receptions. To this end, the LSTM, IDNN, MLNN and Kalman filter techniques were investigated over several sampling scenarios in GPS signal outages of 10 s and 30 s. The results of the study show that during short-term outages (less than 10 s) and longer GPS outages (about 30 s) the LSTM approach provides the best positioning solution. Furthermore, our findings show that sampling at lower rates during long-term GPS outages provides relatively poorer position estimates. There is, however, a need to explore the performance of the LSTM model on more complex driving scenarios, as a means to assess its robustness. This will be the subject of our future research.
References

1. I. Dowd, The future of autonomous vehicles, Open Access Government (2019). https://www.openaccessgovernment.org/future-of-autonomous-vehicles/57772/, Accessed 04 June 2019
2. P. Liu, R. Yang, Z. Xu, How safe is safe enough for self-driving vehicles? Risk Anal. 39(2), 315–325 (2019)
3. A. Papadoulis, M. Quddus, M. Imprialou, Evaluating the safety impact of connected and autonomous vehicles on motorways. Accid. Anal. Prev. 124, 12–22 (2019)
4. S.-J. Babak, S.A. Hussain, B. Karakas, S. Cetin, Control of autonomous ground vehicles: a brief technical review. IOPscience (2017). https://iopscience.iop.org/article/10.1088/1757-899X/224/1/012029, Accessed 22 Mar 2020
5. K. Onda, T. Oishi, Y. Kuroda, Dynamic environment recognition for autonomous navigation with wide FOV 3D-LiDAR. IFAC-PapersOnLine 51(22), 530–535 (2018)
6. S. Ahmed, M.N. Huda, S. Rajbhandari, C. Saha, M. Elshaw, S. Kanarachos, Pedestrian and cyclist detection and intent estimation for autonomous vehicles: a survey. Appl. Sci. 9(11), 2335 (2019)
7. W. Yao et al., GPS signal loss in the wide area monitoring system: prevalence, impact, and solution. Electr. Power Syst. Res. 147(C), 254–262 (2017)
8. G. O'Dwyer, Finland, Norway press Russia on suspected GPS jamming during NATO drill (2018). https://www.defensenews.com/global/europe/2018/11/16/finland-norway-pressrussia-on-suspected-gps-jamming-during-nato-drill/, Accessed 04 June 2019
9. B. Templeton, Cameras or lasers? (2017). http://www.templetons.com/brad/robocars/cameraslasers.html, Accessed 04 June 2019
10. L. Teschler, Inertial measurement units will keep self-driving cars on track (2018). https://www.microcontrollertips.com/inertial-measurement-units-will-keep-self-driving-cars-ontrack-faq/, Accessed 05 June 2019
11. OXTS, Why integrate an INS with imaging systems on an autonomous vehicle (2016). https://www.oxts.com/technical-notes/why-use-ins-with-autonomous-vehicle/, Accessed 04 June 2019
12. R. Sharaf, A. Noureldin, A. Osman, N. El-Sheimy, Online INS/GPS integration with a radial basis function neural network. IEEE Aerosp. Electron. Syst. Mag. 20(3), 8–14 (2005)
13. K.-W. Chiang, N. El-Sheimy, INS/GPS integration using neural networks for land vehicle navigation applications (2002), pp. 535–544
14. K.W. Chiang, A. Noureldin, N. El-Sheimy, Multisensor integration using neuron computing for land-vehicle navigation. GPS Solut. 6(4), 209–218 (2003)
15. K.-W. Chiang, The utilization of single point positioning and multi-layers feed-forward network for INS/GPS integration (2003), pp. 258–266
16. M. Malleswaran, V. Vaidehi, M. Jebarsi, Neural networks review for performance enhancement in GPS/INS integration, in 2012 International Conference on Recent Trends in Information Technology ICRTIT 2012, no. 1 (2012), pp. 34–39
17. A. Noureldin, A. El-Shafie, M. Bayoumi, GPS/INS integration utilizing dynamic neural networks for vehicular navigation. Inf. Fusion 12(1), 48–57 (2011)
18. M. Malleswaran, V. Vaidehi, A. Saravanaselvan, M. Mohankumar, Performance analysis of various artificial intelligent neural networks for GPS/INS integration. Appl. Artif. Intell. 27(5), 367–407 (2013)
19. A.S. El-Wakeel, A. Noureldin, N. Zorba, H.S. Hassanein, A framework for adaptive resolution geo-referencing in intelligent vehicular services, in IEEE Vehicular Technology Conference, vol. 2019 (2019)
20. K. Chiang, INS/GPS integration using neural networks for land vehicular navigation applications. UCGE Reports Number 20209, Department of Geomatics Engineering (2004)
21. T.P. Van, T.N. Van, D.A. Nguyen, T.C. Duc, T.T. Duc, 15-state extended Kalman filter design for INS/GPS navigation system. J. Autom. Control Eng. 3(2), 109–114 (2015)
22. M.W. Gardner, S.R. Dorling, Artificial neural networks (the multilayer perceptron): a review of applications in the atmospheric sciences. Atmos. Environ. 32(14–15), 2627–2636 (1998)
23. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in Deep Learning, vol. 57 (Springer, Singapore, 2020)
24. A. Krizhevsky, I. Sutskever, G.E. Hinton, ImageNet classification with deep convolutional neural networks (2012). https://papers.nips.cc/paper/4824-imagenet-classification-with-deepconvolutional-neural-networks.pdf
25. C. Chen, X. Lu, A. Markham, N. Trigoni, IONet: learning to cure the curse of drift in inertial odometry (2018), pp. 6468–6476
26. P. Kasnesis, C.Z. Patrikakis, I.S. Venieris, PerceptionNet: a deep convolutional neural network for late sensor fusion (2018)
27. W. Fang et al., A LSTM algorithm estimating pseudo measurements for aiding INS during GNSS signal outages. Remote Sens. 12(2), 256 (2020)
28. K. He, X. Zhang, S. Ren, J. Sun, Deep residual learning for image recognition, in Proceedings of the IEEE Computer Society Conference on Computer Vision and Pattern Recognition, vol. 2016, pp. 770–778 (2016)
29. J. Redmon, S. Divvala, R. Girshick, A. Farhadi, You only look once: unified, real-time object detection (2016). arXiv preprint arXiv:1506.02640
30. H. Ismail Fawaz, G. Forestier, J. Weber, L. Idoumghar, P.A. Muller, Deep learning for time series classification: a review. Data Min. Knowl. Discov. 33(4), 917–963 (2019)
31. M. Kok, J.D. Hol, T.B. Schön, Using inertial sensors for position and orientation estimation. Found. Trends Signal Process. 11(2), 1–153 (2017)
32. H. Mahmoud, N. Akkari, Shortest path calculation: a comparative study for location-based recommender system, in Proceedings, 2016 World Symposium on Computer Applications and Research, WSCAR 2016, pp. 1–5 (2016)
33. C.M. Thomas, W.E. Featherstone, Validation of Vincenty's formulas for the geodesic using a new fourth-order extension of Kivioja's formula. J. Surv. Eng. 131(1), 20–26 (2005)
34. T. Vincenty, Direct and inverse solutions of geodesics on the ellipsoid with application of nested equations. Surv. Rev. 23(176), 88–93 (1975)
35. vincenty PyPI. https://pypi.org/project/vincenty/. Accessed 08 May 2020
36. U. Onyekpe, V. Palade, S. Kanarachos, A. Szkolnik, IO-VNBD: inertial and odometry benchmark dataset for ground vehicle positioning (2020). arXiv preprint arXiv:2005.01701
37. Y. Gal, Z. Ghahramani, A theoretically grounded application of dropout in recurrent neural networks (2016). arXiv preprint arXiv:1512.05287
Multi-Adversarial Variational Autoencoder Nets for Simultaneous Image Generation and Classification

Abdullah-Al-Zubaer Imran and Demetri Terzopoulos

A.-A.-Z. Imran (B) · D. Terzopoulos, University of California, Los Angeles, CA 90095, USA

© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021. M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_11

Abstract Discriminative deep-learning models are often reliant on copious labeled training data. By contrast, from relatively small corpora of training data, deep generative models can learn to generate realistic images approximating real-world distributions. In particular, the proper training of Generative Adversarial Networks (GANs) and Variational AutoEncoders (VAEs) enables them to perform semi-supervised image classification. Combining the power of these two models, we introduce Multi-Adversarial Variational autoEncoder Networks (MAVENs), a novel deep generative model that incorporates an ensemble of discriminators in a VAE-GAN network in order to perform simultaneous adversarial learning and variational inference. We apply MAVENs to the generation of synthetic images and propose a new distribution measure to quantify the quality of these images. Our experimental results with only 10% labeled training data from the computer vision and medical imaging domains demonstrate performance competitive to state-of-the-art semi-supervised models in simultaneous image generation and classification tasks.

1 Introduction

Training deep neural networks usually requires copious data, yet obtaining large, accurately labeled datasets for image classification and other tasks remains a fundamental challenge [36]. Although there has been explosive progress in the production of vast quantities of high resolution images, large collections of labeled data required for supervised learning remain scarce. Especially in domains such as medical imaging, datasets are often limited in size due to privacy issues, and annotation by medical experts is expensive, time-consuming, and prone to human subjectivity, inconsistency, and error. Even when large labeled datasets become available, they are often highly imbalanced and non-uniformly distributed. In an imbalanced medical dataset there will be an over-representation of common medical problems and an under-representation of rarer conditions. Such biases make the training of neural networks across multiple classes with consistent effectiveness very challenging.

Fig. 1 Image generation based on the CIFAR-10 dataset [19]: a Relatively good images generated by a GAN. b Blurry images generated by a VAE. Based on the SVHN dataset [24]: c mode collapsed images generated by a GAN

The small-training-data problem is traditionally mitigated through simplistic and cumbersome data augmentation, often by creating new training examples through translation, rotation, flipping, etc. The missing or mismatched label problem may be addressed by evaluating similarity measures over the training examples. This is not always robust and its effectiveness depends largely on the performance of the similarity measuring algorithms.

With the advent of deep generative models such as Variational AutoEncoders (VAEs) [18] and Generative Adversarial Networks (GANs) [9], the ability to learn underlying data distributions from training samples has become practical in common scenarios where there is an abundance of unlabeled data. With minimal annotation, efficient semi-supervised learning could be the preferred approach [16]. More specifically, based on small quantities of annotation, realistic new training images may be generated by models that have learned real-world data distributions (Fig. 1a). Both VAEs and GANs may be employed for this purpose. VAEs can learn dimensionality-reduced representations of training data and, with an explicit density estimation, can generate new samples. Although VAEs can perform fast variational inference, VAE-generated samples are usually blurry (Fig. 1b).
On the other hand, despite their successes in generating images and semi-supervised classifications, GAN frameworks remain difficult to train and there are challenges in using GAN models, such as non-convergence due to unstable training, diminished gradient issues, overfitting, sensitivity to hyper-parameters, and mode collapsed image generation (Fig. 1c). Despite the recent progress in high-quality image generation with GANs and VAEs, accuracy and image quality are usually not ensured by the same model, especially in multiclass image classification tasks. To tackle this shortcoming, we propose
Multi-Adversarial Variational Autoencoder Nets for Simultaneous …
a novel method that can simultaneously learn image generation and multiclass image classification. Specifically, our work makes the following contributions:
1. The Multi-Adversarial Variational autoEncoder Network, or MAVEN, a novel multiclass image classification model incorporating an ensemble of discriminators in a combined VAE-GAN network. An ensemble layer combines the feedback from multiple discriminators at the end of each batch. With the inclusion of ensemble learning at the end of a VAE-GAN, both generated image quality and classification accuracy are improved simultaneously.
2. A simplified version of the Descriptive Distribution Distance (DDD) [14] for evaluating generative models, which better represents the distribution of the generated data and measures its closeness to the real data.
3. Extensive experimental results utilizing two computer vision and two medical imaging datasets.¹ These confirm that our MAVEN model improves upon the simultaneous image generation and classification performance of a GAN and of a VAE-GAN with the same set of hyper-parameters.
2 Related Work
Several techniques have been proposed to stabilize GAN training and avoid mode collapse. Nguyen et al. [26] proposed a model where a single generator is used alongside dual discriminators. Durugkar et al. [7] proposed a model with a single generator and feedback aggregated over several discriminators, considering either the average loss over all discriminators or only the discriminator with the maximum loss in relation to the generator's output. Neyshabur et al. [25] proposed a framework in which a single generator simultaneously trains against an array of discriminators, each of which operates on a different low-dimensional projection of the data. Mordido et al. [23], arguing that all the previous approaches restrict the discriminator's architecture, thereby compromising extensibility, proposed the Dropout-GAN, where a single generator is trained against a dynamically changing ensemble of discriminators. However, there is a risk of dropping out all the discriminators. Feature matching and minibatch discrimination techniques have been proposed [32] for eliminating mode collapse and preventing overfitting in GAN training.
Realistic image generation helps address problems due to the scarcity of labeled data. Various architectures of GANs and their variants have been applied in ongoing efforts to improve the accuracy and effectiveness of image classification. The GAN framework has been utilized as a generic approach to generating realistic training images that synthetically augment datasets in order to combat overfitting; e.g., for synthetic data augmentation in liver lesions [8], retinal fundi [10], histopathology [13], and chest X-rays [16, 31]. Calimeri et al. [3] employed a LAPGAN [6] and Han et al. [11] used a WGAN [1] to generate synthetic brain MR images. Bermudez et al. [2] used a DCGAN [29] to generate 2D brain MR images followed by an autoencoder for image denoising. Chuquicusma et al. [4] utilized a DCGAN to generate lung nodules and then conducted a Turing test to evaluate the quality of the generated samples.
GAN frameworks have also been shown to improve the accuracy of image classification via the generation of new synthetic training images. Frid et al. [8] used a DCGAN and an ACGAN [27] to generate images of three liver lesion classes to synthetically augment the limited dataset and improve the performance of a Convolutional Neural Net (CNN) in liver lesion classification. Similarly, Salehinejad et al. [31] employed a DCGAN to artificially simulate pathology across five classes of chest X-rays in order to augment the original imbalanced dataset and improve the performance of a CNN in chest pathology classification.
The GAN framework has also been utilized in semi-supervised learning architectures to leverage unlabeled data alongside limited labeled data. The following efforts demonstrate how incorporating unlabeled data in the GAN framework has led to significant improvements in the accuracy of image-level classification. Madani et al. [20] used an order of magnitude less labeled data with a DCGAN in semi-supervised learning, yet showed comparable performance to a traditional supervised CNN classifier, and furthermore demonstrated reduced domain overfitting by simply supplying unlabeled test domain images. Springenberg et al. [33] combined a WGAN and CatGAN [35] for unsupervised and semi-supervised learning of feature representation of dermoscopy images.
Despite the aforecited successes, GAN frameworks remain challenging to train, as we discussed above. Our MAVEN framework mitigates the difficulties of training GANs by enabling training on a limited quantity of labeled data, preventing overfitting to a specific data domain source, and preventing mode collapse, while supporting multiclass image classification.
¹ This chapter significantly expands upon our ICMLA 2019 publication [15], which excluded our experiments on medical imaging datasets.
A.-A.-Z. Imran and D. Terzopoulos
3 The MAVEN Architecture
Figure 2 illustrates the models that serve as precursors to our MAVEN architecture. The VAE is an explicit generative model that uses two neural nets, an encoder E and a decoder D. Network E learns an efficient compression of real data x into a lower-dimensional latent representation space z(x); i.e., $q_\lambda(z \mid x)$. With neural network likelihoods, computing the gradient becomes intractable; however, via differentiable, non-centered re-parameterization, sampling is performed from an approximate function $q_\lambda(z \mid x) = \mathcal{N}(z; \mu_\lambda, \sigma_\lambda^2)$, where $z = \mu_\lambda + \sigma_\lambda \hat{\varepsilon}$ with $\hat{\varepsilon} \sim \mathcal{N}(0, 1)$. Encoder E yields μ and σ, and with the re-parameterization trick, z is sampled from a Gaussian distribution. Then, with D, new samples are generated or real data samples are reconstructed; i.e., D provides parameters for the real data distribution $p_\phi(x \mid z)$. Subsequently, a sample drawn from $p_\phi(x \mid z)$ may be used to reconstruct the real data by marginalizing out z.
Fig. 2 Our MAVEN architecture compared to those of the VAE, GAN, and VAE-GAN. In the MAVEN, inputs to D can be real data $X$, or generated data $\hat{X}$ or $\tilde{X}$. An ensemble ensures the combined feedback from the discriminators to the generator
The GAN is an implicit generative model where a generator G and a discriminator D compete in a minimax game over the training data in order to improve their performance. Generator G tries to approximate the underlying distribution of the training data and generates synthetic samples, while discriminator D learns to discriminate synthetic samples from real samples. The GAN model is trained on the following objectives:
$$\max_D V(D) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_g(z)}[\log(1 - D(G(z)))]; \tag{1}$$
$$\min_G V(G) = \mathbb{E}_{z \sim p_g(z)}[\log(1 - D(G(z)))]. \tag{2}$$
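The two objectives can be estimated from a minibatch of discriminator outputs. The following is a minimal NumPy sketch rather than the implementation used in this work; the function names and the example probabilities are hypothetical, and the small constant `eps` is an assumption added only to avoid evaluating log(0).

```python
import numpy as np

def discriminator_value(d_real, d_fake, eps=1e-12):
    """Monte Carlo estimate of V(D) in Eq. (1); D tries to maximize it."""
    d_real = np.asarray(d_real, dtype=float)
    d_fake = np.asarray(d_fake, dtype=float)
    return np.mean(np.log(d_real + eps)) + np.mean(np.log(1.0 - d_fake + eps))

def generator_value(d_fake, eps=1e-12):
    """Monte Carlo estimate of V(G) in Eq. (2); G tries to minimize it."""
    d_fake = np.asarray(d_fake, dtype=float)
    return np.mean(np.log(1.0 - d_fake + eps))

# Hypothetical outputs D(x) on real data and D(G(z)) on generated data.
confident = discriminator_value(d_real=[0.9, 0.8], d_fake=[0.1, 0.2])
confused = discriminator_value(d_real=[0.5, 0.5], d_fake=[0.5, 0.5])
```

A discriminator that cannot tell real from generated samples outputs 0.5 everywhere, which gives V(D) = 2 log 0.5; a more confident discriminator scores higher.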
G takes a noise sample $z \sim p_g(z)$ and learns to map it into image space as if it comes from the original data distribution $p_{\text{data}}(x)$, while D takes as input either real image data or generated image data and provides feedback to G as to whether that input is real or generated. On the one hand, D wants to maximize the likelihood for real samples and minimize the likelihood of generated samples; on the other hand, G wants D to maximize the likelihood of generated samples. A Nash equilibrium results when D can no longer distinguish real and generated samples, meaning that the model distribution matches the data distribution.
Makhzani et al. [21] proposed the adversarial training of VAEs; i.e., VAE-GANs. Although they kept both D and G, one can merge these networks since both can generate data samples from the noise samples of the representation z. In this case, D receives real data samples x and generated samples $\tilde{x}$ or $\hat{x}$ via G. Although G and D compete against each other, the feedback from D eventually becomes predictable for G and it keeps generating samples from the same class, at which point the generated samples lack heterogeneity. Figure 1c shows an example where all the generated images are of the same class. Durugkar et al. [7] proposed that using multiple discriminators in a GAN model helps improve performance, especially for resolving this mode collapse. Moreover, a dynamic ensemble of multiple discriminators has recently been proposed to address the issue [23].
Fig. 3 The three convolutional neural networks, E, G, and D, in the MAVEN
As in a VAE-GAN, our MAVEN has three components, E, G, and D (Fig. 3); all are CNNs with convolutional or transposed convolutional layers. First, E takes real samples x and generates a dimensionality-reduced representation z(x). Second, G can input samples from the noise distribution $z \sim p_g(z)$ or sampled noise $z(x) \sim q_\lambda(z \mid x)$, and it produces generated samples. Third, D takes inputs from distributions of real labeled data, real unlabeled data, and generated data. Fractionally strided convolutions are performed in G to obtain the image dimension from the latent code. The goal of an autoencoder is to maximize the Evidence Lower Bound (ELBO). The intuition here is to show the network more real data. The greater the quantity of real data that it sees, the more evidence is available to it and, as a result, the ELBO can be maximized faster.
In our MAVEN architecture (Fig. 2), the VAE-GAN combination is extended to include multiple discriminators aggregated in an ensemble layer. K discriminators are collected and the combined feedback
$$V(D) = \frac{1}{K} \sum_{k=1}^{K} w_k D_k \tag{3}$$
is passed to G. Alternatively, in order to randomize the feedback from the multiple discriminators, a single discriminator is randomly selected.
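The two feedback modes can be illustrated with a short NumPy sketch. This is only an illustration under stated assumptions: the function name `ensemble_feedback`, the uniform default weights, and the example discriminator outputs are hypothetical and are not taken from our implementation.

```python
import numpy as np

def ensemble_feedback(d_outputs, weights=None, mode="mean", rng=None):
    """Combine the outputs of K discriminators into one feedback signal.

    d_outputs: array of shape (K, batch), one row per discriminator D_k.
    weights:   optional length-K weights w_k (assumed uniform if omitted).
    mode:      "mean" for the weighted mean of Eq. (3), "random" to pick one D_k.
    """
    d_outputs = np.asarray(d_outputs, dtype=float)
    k = d_outputs.shape[0]
    if mode == "mean":
        w = np.ones(k) if weights is None else np.asarray(weights, dtype=float)
        return (w[:, None] * d_outputs).sum(axis=0) / k
    if mode == "random":
        rng = np.random.default_rng() if rng is None else rng
        return d_outputs[rng.integers(k)]
    raise ValueError(f"unknown ensemble mode: {mode!r}")

# Hypothetical outputs of K = 3 discriminators on a batch of two samples.
outputs = np.array([[0.2, 0.6], [0.4, 0.8], [0.6, 0.4]])
mean_fb = ensemble_feedback(outputs, mode="mean")
rand_fb = ensemble_feedback(outputs, mode="random", rng=np.random.default_rng(0))
```

With uniform weights the mean mode reduces to a plain average over the discriminators, whereas the random mode forwards the output of exactly one of them.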
4 Semi-Supervised Learning
Algorithm 1 presents the overall training procedure of our MAVEN model. In the forward pass, the different real samples x fed into E and noise samples z fed into G provide different inputs for each of the multiple discriminators. In the backward pass, the combined feedback from the discriminators is computed and passed to G and E. In the conventional image generator GAN, D works as a binary classifier: it classifies the input image as real or generated. To facilitate the training for an n-class classifier, D assumes the role of an (n + 1)-classifier. For multiple logit generation, the sigmoid function is replaced by a softmax function. Now, it can receive an image x as input and output an (n + 1)-dimensional vector of logits $\{l_1, \ldots, l_n, l_{n+1}\}$, which are finally transformed into class probabilities for the n labels in the real data while class (n + 1) denotes the generated data. The probability that x is real and belongs to class 1 ≤ i ≤ n is
$$p(y = i \mid x) = \frac{\exp(l_i)}{\sum_{j=1}^{n+1} \exp(l_j)}, \tag{4}$$
while the probability that x is generated corresponds to i = n + 1 in (4). As a semi-supervised classifier, the model takes labels only for a small portion of the training data. It is trained via supervised learning from the labeled data, while it learns in an unsupervised manner from the unlabeled data. The advantage comes from generating new samples. The model learns the classifier by generating samples from different classes.
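As a small worked example of Eq. (4), the following NumPy sketch converts n + 1 logits into class probabilities. The logit values are hypothetical, and subtracting the maximum logit is a standard numerical-stability step that does not change the result.

```python
import numpy as np

def class_probabilities(logits):
    """Softmax over n + 1 logits; the last entry is the 'generated' class."""
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Hypothetical logits for n = 3 real classes plus the generated class.
logits = np.array([2.0, 0.5, 0.1, -1.0])
probs = class_probabilities(logits)
p_generated = probs[-1]            # i = n + 1 in Eq. (4)
p_real = 1.0 - p_generated         # total probability mass on real classes
predicted_class = int(np.argmax(probs[:-1])) + 1   # 1-based real class label
```

The probability that the input is real is simply the mass not assigned to class n + 1, which is the quantity that appears in the unsupervised losses.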
4.1 Losses
Three networks, E, G, and D, are trained on different objectives. E is trained to maximize the ELBO, G is trained to generate realistic samples, and D is trained as a classifier that assigns real data samples to their particular classes and generated samples to the additional generated class.
Algorithm 1 MAVEN training procedure. m is the number of samples; B is the minibatch size; and K is the number of discriminators.
steps ← m / B
for each epoch do
    for each step in steps do
        for k = 1 to K do
            Sample minibatch $z^{(1)}, \ldots, z^{(m)}$ from $p_g(z)$
            Sample minibatch $x^{(1)}, \ldots, x^{(m)}$ from $p_{\text{data}}(x)$
            Update $D_k$ by ascending along its gradient: $\nabla_{D_k} \frac{1}{m} \sum_{i=1}^{m} \left[\log D_k(x_i) + \log(1 - D_k(G(z_i)))\right]$
        end for
        Sample minibatch $z_k^{(1)}, \ldots, z_k^{(m)}$ from $p_g(z)$
        if ensemble is 'mean' then
            Assign weights $w_k$ to the $D_k$
            Determine the mean discriminator $D_\mu = \frac{1}{K} \sum_{k} w_k D_k$
        end if
        Update G by descending along its gradient from the ensemble of $D_\mu$: $\nabla_G \frac{1}{m} \sum_{i=1}^{m} \log(1 - D_\mu(G(z_i)))$
        Sample minibatch $x^{(1)}, \ldots, x^{(m)}$ from $p_{\text{data}}(x)$
        Update E along its expectation function: $\nabla_E \, \mathbb{E}_{q_\lambda} \left[\log \frac{p(z)}{q_\lambda(z \mid x)}\right]$
    end for
end for
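The control flow of Algorithm 1 can be sketched without any gradient computation. The generator function below is a hypothetical illustration that only names which network would be updated at each point; it assumes that the number of steps is the integer quotient of m and B.

```python
def maven_update_schedule(m, batch_size, num_discriminators, epochs=1):
    """Yield the order of network updates implied by Algorithm 1.

    No gradients are computed here; each yielded tuple only names which
    network would be updated at that point: (epoch, step, network).
    """
    steps = m // batch_size  # assumption: integer number of steps per epoch
    for epoch in range(epochs):
        for step in range(steps):
            for k in range(1, num_discriminators + 1):
                yield (epoch, step, f"D{k}")   # ascend the gradient of D_k
            yield (epoch, step, "G")           # descend using ensemble feedback
            yield (epoch, step, "E")           # update the encoder last

schedule = list(maven_update_schedule(m=1000, batch_size=100, num_discriminators=3))
first_step = [name for _, _, name in schedule[:5]]
```

Each step therefore performs K discriminator updates followed by one generator update and one encoder update.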
4.1.1 D Loss
Since the model is trained on both labeled and unlabeled training data, the loss function of D includes both supervised and unsupervised losses. When the model receives real labeled data, it is the standard supervised learning loss
$$L_{D_{\text{supervised}}} = -\mathbb{E}_{x, y \sim p_{\text{data}}} \log[p(y = i \mid x)], \quad i < n + 1. \tag{5}$$
When it receives unlabeled data from three different sources, the unsupervised loss contains the original GAN loss for real and generated data from two different sources: synG directly from G and synE from E via G. The three losses,
$$L_{D_{\text{real}}} = -\mathbb{E}_{x \sim p_{\text{data}}} \log[1 - p(y = n + 1 \mid x)], \tag{6}$$
$$L_{D_{\text{synG}}} = -\mathbb{E}_{\hat{x} \sim G} \log[p(y = n + 1 \mid \hat{x})], \tag{7}$$
$$L_{D_{\text{synE}}} = -\mathbb{E}_{\tilde{x} \sim G} \log[p(y = n + 1 \mid \tilde{x})], \tag{8}$$
are combined as the unsupervised loss in D:
$$L_{D_{\text{unsupervised}}} = L_{D_{\text{real}}} + L_{D_{\text{synG}}} + L_{D_{\text{synE}}}. \tag{9}$$
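The discriminator losses of Eqs. (5)-(9) can be evaluated directly from the (n + 1)-class probabilities. The following NumPy sketch is illustrative only: the function names, the 0-based label convention, the constant `EPS`, and the example probabilities are assumptions rather than details of our implementation.

```python
import numpy as np

EPS = 1e-12  # assumption: small constant that guards against log(0)

def d_supervised_loss(probs, labels):
    """Eq. (5): cross-entropy over the real classes for labeled data.

    probs:  (batch, n + 1) class probabilities from D.
    labels: (batch,) integer labels in [0, n), i.e., 0-based real classes.
    """
    picked = probs[np.arange(len(labels)), labels]
    return -np.mean(np.log(picked + EPS))

def d_unsupervised_loss(probs_real, probs_syn_g, probs_syn_e):
    """Eqs. (6)-(9): real data should avoid class n + 1, generated data hit it."""
    loss_real = -np.mean(np.log(1.0 - probs_real[:, -1] + EPS))
    loss_syn_g = -np.mean(np.log(probs_syn_g[:, -1] + EPS))
    loss_syn_e = -np.mean(np.log(probs_syn_e[:, -1] + EPS))
    return loss_real + loss_syn_g + loss_syn_e

# Hypothetical outputs for n = 2 (columns: class 1, class 2, generated).
labeled = np.array([[0.8, 0.1, 0.1], [0.2, 0.7, 0.1]])
real = np.array([[0.6, 0.3, 0.1]])
syn_g = np.array([[0.1, 0.1, 0.8]])
syn_e = np.array([[0.2, 0.1, 0.7]])

sup = d_supervised_loss(labeled, labels=np.array([0, 1]))
unsup = d_unsupervised_loss(real, syn_g, syn_e)
```

Both terms shrink as D places more probability on the correct real class for labeled data and on class n + 1 for generated data.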
4.1.2 G Loss
For G, the feature loss is used along with the original GAN loss. Activation f(x) from an intermediate layer of D is used to match the feature between real and generated samples. Feature matching has shown much potential in semi-supervised learning [32]. The goal of feature matching is to encourage G to generate data that matches real data statistics. It is natural for D to find the most discriminative features in real data relative to data generated by the model:
$$L_{G_{\text{feature}}} = \left\| \mathbb{E}_{x \sim p_{\text{data}}} f(x) - \mathbb{E}_{\hat{x} \sim G} f(\hat{x}) \right\|_2^2. \tag{10}$$
The total G loss becomes the combined feature loss (10) plus the cost of maximizing the log-probability of D making a mistake on the generated data (synG / synE); i.e.,
$$L_G = L_{G_{\text{feature}}} + L_{G_{\text{synG}}} + L_{G_{\text{synE}}}, \tag{11}$$
where
$$L_{G_{\text{synG}}} = -\mathbb{E}_{\hat{x} \sim G} \log[1 - p(y = n + 1 \mid \hat{x})] \tag{12}$$
and
$$L_{G_{\text{synE}}} = -\mathbb{E}_{\tilde{x} \sim G} \log[1 - p(y = n + 1 \mid \tilde{x})]. \tag{13}$$
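To make the generator loss concrete, the sketch below evaluates the feature-matching term and the two adversarial terms with NumPy. It is a minimal illustration under assumptions: the function names, the constant `EPS`, and the toy feature vectors and probabilities are hypothetical.

```python
import numpy as np

EPS = 1e-12  # assumption: small constant that guards against log(0)

def feature_matching_loss(feat_real, feat_syn):
    """Eq. (10): squared L2 distance between batch-mean feature vectors."""
    diff = feat_real.mean(axis=0) - feat_syn.mean(axis=0)
    return float(np.sum(diff ** 2))

def g_adversarial_loss(p_generated):
    """Eqs. (12)-(13): small when D does not assign class n + 1 to generated data."""
    return float(-np.mean(np.log(1.0 - np.asarray(p_generated) + EPS)))

def g_total_loss(feat_real, feat_syn, p_gen_syn_g, p_gen_syn_e):
    """Eq. (11): feature loss plus the synG and synE adversarial terms."""
    return (feature_matching_loss(feat_real, feat_syn)
            + g_adversarial_loss(p_gen_syn_g)
            + g_adversarial_loss(p_gen_syn_e))

# Hypothetical intermediate-layer activations f(x) and class-(n + 1) probabilities.
feat_real = np.array([[1.0, 2.0], [3.0, 4.0]])   # batch mean: [2, 3]
feat_syn = np.array([[1.0, 1.0], [1.0, 3.0]])    # batch mean: [1, 2]
total = g_total_loss(feat_real, feat_syn, p_gen_syn_g=[0.5], p_gen_syn_e=[0.5])
```

In this toy case the feature term equals 2 and each adversarial term equals log 2.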
4.1.3 E Loss
In the encoder E, the maximization of ELBO is equivalent to minimizing the KL-divergence, allowing approximate posterior inferences. Therefore the loss function includes the KL-divergence and also a feature loss to match the features in the synE data with the real data distribution. The loss for the encoder is
$$L_E = L_{E_{\text{KL}}} + L_{E_{\text{feature}}}, \tag{14}$$
where
$$L_{E_{\text{KL}}} = -\mathrm{KL}\left[q_\lambda(z \mid x) \,\|\, p(z)\right] = \mathbb{E}_{q_\lambda(z \mid x)} \left[\log \frac{p(z)}{q_\lambda(z \mid x)}\right] \approx \mathbb{E}_{q_\lambda(z \mid x)} \tag{15}$$
and
$$L_{E_{\text{feature}}} = \left\| \mathbb{E}_{x \sim p_{\text{data}}} f(x) - \mathbb{E}_{\tilde{x} \sim G} f(\tilde{x}) \right\|_2^2. \tag{16}$$
5 Experiments
Applying semi-supervised learning using training data that is only partially labeled, we evaluated our MAVEN model in image generation and classification tasks in a number of experiments. For all our experiments, we used 10% labeled and 90% unlabeled training data.
5.1 Data
We employed the following four image datasets:
1. The Street View House Numbers (SVHN) dataset [24] (Fig. 4a). There are 73,257 digit images for training and 26,032 digit images for testing. Of the two versions of the images, we used the version that has MNIST-like 32 × 32 pixel RGB color images centered around a single digit. Each image is labeled as belonging to one of 10 classes: digits 0–9.
2. The CIFAR-10 dataset [19] (Fig. 4b). It consists of 60,000 32 × 32 pixel RGB color images in 10 classes. There are 50,000 training images and 10,000 test images. Each image is labeled as belonging to one of 10 classes: plane, auto, bird, cat, deer, dog, frog, horse, ship, and truck.
3. The anterior-posterior Chest X-Ray (CXR) dataset [17] (Fig. 4c). The dataset contains 5,216 training and 624 test images. Each image is labeled as belonging to one of three classes: normal, bacterial pneumonia (b-pneumonia), and viral pneumonia (v-pneumonia).
4. The skin lesion classification (SLC) dataset (Fig. 4d). We employed 2,000 RGB skin images from the ISIC 2017 dermoscopy image dataset [5], of which we used 1,600 for training and 400 for testing. Each image is labeled as belonging to one of two classes: non-melanoma and melanoma.
For the SVHN and CIFAR-10 datasets, the images were normalized and provided to the models in their original (32 × 32 × 3) pixel sizes. For the CXR dataset, the images were normalized and resized to 128 × 128 × 1 pixels. For the SLC dataset, the images were resized to 128 × 128 × 3 pixels.
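The 10% labeled split can be sketched as a random partition of the training indices. The function below is hypothetical; it assumes that the labeled subset is drawn uniformly at random and that its size is rounded up, an assumption that happens to reproduce the labeled and unlabeled counts reported for each dataset in Sect. 5.4.

```python
import numpy as np

def split_labeled_unlabeled(num_train, labeled_percent=10, seed=0):
    """Randomly partition training indices into labeled and unlabeled subsets."""
    # Integer ceiling of labeled_percent% of num_train (assumed rounding rule).
    num_labeled = (num_train * labeled_percent + 99) // 100
    order = np.random.default_rng(seed).permutation(num_train)
    return order[:num_labeled], order[num_labeled:]

# Training-set sizes of the four datasets listed above.
train_sizes = {"SVHN": 73257, "CIFAR-10": 50000, "CXR": 5216, "SLC": 1600}
split_sizes = {name: tuple(len(part) for part in split_labeled_unlabeled(n))
               for name, n in train_sizes.items()}
```

The two index arrays are disjoint and together cover the whole training set, so the labels of the second array can simply be withheld during training.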
Fig. 4 Example images of each class in the four datasets
5.2 Implementation Details
To compare the image generation and multiclass classification performance of our MAVEN model, we used two baselines, the Deep Convolutional GAN (DC-GAN) [29] and the VAE-GAN. The same generator and discriminator architectures were used for the DC-GAN and MAVEN models and the same encoder was used for the VAE-GAN and MAVEN models. For our MAVENs, we experimented with 2, 3, and 5 discriminators. In addition to using the mean feedback of the multiple discriminators, we also experimented with feedback from a randomly selected discriminator. The six MAVEN variants are therefore denoted MAVEN-m2D, MAVEN-m3D, MAVEN-m5D, MAVEN-r2D, MAVEN-r3D, and MAVEN-r5D, where "m" indicates mean feedback while "r" indicates random feedback.
All the models were implemented in TensorFlow and run on a single Nvidia Titan GTX (12 GB) GPU. For the discriminator, after every convolutional layer, a dropout layer was added with a dropout rate of 0.4. For all the models, we consistently used the Adam optimizer with a learning rate of $2.0 \times 10^{-4}$ for G and D, and $1.0 \times 10^{-5}$ for E, with a momentum of 0.9. All the convolutional layers were followed by batch normalization. Leaky ReLU activations were used with α = 0.2.
5.3 Evaluation
5.3.1 Image Generation Performance Metrics
There are no perfect performance metrics for measuring the quality of generated samples. However, to assess the quality of the generated images, we employed the widely used Fréchet Inception Distance (FID) [12] and a simplified version of the Descriptive Distribution Distance (DDD) [14]. To measure the Fréchet distance between two multivariate Gaussians, the generated samples and real data samples are compared through their distribution statistics:
$$\text{FID} = \left\| \mu_{\text{data}} - \mu_{\text{syn}} \right\|^2 + \operatorname{Tr}\left( \Sigma_{\text{data}} + \Sigma_{\text{syn}} - 2 \sqrt{\Sigma_{\text{data}} \Sigma_{\text{syn}}} \right). \tag{17}$$
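The Fréchet distance of Eq. (17) can be computed from two sets of feature vectors as sketched below. This is not the evaluation code used in our experiments: the random features stand in for Inception activations, and the trace of the matrix square root is obtained from the eigenvalues of the covariance product, which is one of several possible ways to evaluate that term.

```python
import numpy as np

def frechet_distance(feats_data, feats_syn):
    """Fréchet distance between Gaussians fitted to two sets of feature vectors."""
    mu_d, mu_s = feats_data.mean(axis=0), feats_syn.mean(axis=0)
    cov_d = np.cov(feats_data, rowvar=False)
    cov_s = np.cov(feats_syn, rowvar=False)
    # Tr((cov_d cov_s)^{1/2}) equals the sum of the square roots of the
    # eigenvalues of cov_d @ cov_s, which are real and non-negative for
    # positive semi-definite covariance matrices.
    eigvals = np.linalg.eigvals(cov_d @ cov_s).real
    trace_sqrt = np.sqrt(np.clip(eigvals, 0.0, None)).sum()
    mean_term = np.sum((mu_d - mu_s) ** 2)
    return float(mean_term + np.trace(cov_d) + np.trace(cov_s) - 2.0 * trace_sqrt)

rng = np.random.default_rng(0)
real_feats = rng.standard_normal((500, 4))      # stand-in for Inception features
shifted_feats = real_feats + np.array([1.0, 0.0, 0.0, 0.0])
fid_same = frechet_distance(real_feats, real_feats)
fid_shifted = frechet_distance(real_feats, shifted_feats)
```

Identical feature sets give a distance of zero up to rounding, and shifting every feature vector by a unit vector leaves the covariance unchanged, so only the squared mean difference of 1 remains.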
Two distribution samples, $X_{\text{data}} \sim \mathcal{N}(\mu_{\text{data}}, \Sigma_{\text{data}})$ and $X_{\text{syn}} \sim \mathcal{N}(\mu_{\text{syn}}, \Sigma_{\text{syn}})$, for real and model data, respectively, are calculated from the 2,048-dimensional activations of the pool3 layer of Inception-v3 [32]. DDD measures the closeness of a generated data distribution to a real data distribution by comparing descriptive parameters from the two distributions. We propose a simplified version based on the first four moments of the distributions, computed as the weighted sum of normalized differences of moments, as follows:
$$\text{DDD} = -\sum_{i=1}^{4} \log w_i \left| \mu_{\text{data}_i} - \mu_{\text{syn}_i} \right|, \tag{18}$$
where the $\mu_{\text{data}_i}$ are the moments of the data distribution, the $\mu_{\text{syn}_i}$ are the moments of the model distribution, and the $w_i$ are the corresponding weights found in an exhaustive search. The higher order moments are weighted more in order to emphasize the stability of a distribution. For both the FID and DDD, lower scores are better.
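A possible reading of the simplified DDD is sketched below. Several details are assumptions made only for illustration: the four moments are taken to be the mean, variance, skewness, and kurtosis of the pixel intensities, the weights are hypothetical placeholders in (0, 1) rather than the values found by exhaustive search, and no additional normalization of the moment differences is applied.

```python
import numpy as np

def first_four_moments(samples):
    """Mean, variance, skewness, and kurtosis of a flattened sample."""
    x = np.asarray(samples, dtype=float).ravel()
    mean, var = x.mean(), x.var()
    z = (x - mean) / np.sqrt(var)
    return np.array([mean, var, np.mean(z ** 3), np.mean(z ** 4)])

def simplified_ddd(data, syn, weights=(0.9, 0.7, 0.5, 0.3)):
    """-sum_i log(w_i) * |moment_i(data) - moment_i(syn)|.

    The weights are hypothetical; a smaller w_i gives a larger factor
    -log(w_i), so the higher-order moments count more.
    """
    gaps = np.abs(first_four_moments(data) - first_four_moments(syn))
    return float(-np.sum(np.log(np.asarray(weights)) * gaps))

rng = np.random.default_rng(0)
real_pixels = rng.normal(0.5, 0.1, size=1000)    # stand-in for real intensities
fake_pixels = rng.uniform(0.0, 1.0, size=1000)   # stand-in for generated intensities
ddd_same = simplified_ddd(real_pixels, real_pixels)
ddd_diff = simplified_ddd(real_pixels, fake_pixels)
```

Under this reading the score is zero when the two sets of moments coincide and grows with the moment gaps, consistent with lower scores being better.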
5.3.2 Image Classification Performance Metrics
To evaluate model performance in classification, we used two measures, image-level classification accuracy and class-wise F1 scoring. The F1 score is
$$F_1 = \frac{2 \times \text{precision} \times \text{recall}}{\text{precision} + \text{recall}}, \tag{19}$$
with
$$\text{precision} = \frac{TP}{TP + FP} \quad \text{and} \quad \text{recall} = \frac{TP}{TP + FN}, \tag{20}$$
where TP, FP, and FN are the number of true positives, false positives, and false negatives, respectively.
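The two classification measures can be computed from predicted and true labels as in the following sketch. The labels are hypothetical, and returning 0 when a denominator vanishes is a convention chosen here for illustration.

```python
import numpy as np

def classwise_f1(y_true, y_pred, num_classes):
    """Per-class F1 from TP, FP, and FN counts (Eqs. 19-20)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    scores = []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        scores.append(2 * precision * recall / denom if denom else 0.0)
    return scores

# Hypothetical labels for a three-class problem.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
f1 = classwise_f1(y_true, y_pred, num_classes=3)
accuracy = float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))
```

Here class 0 has one true positive, one false positive, and one false negative, so its precision, recall, and F1 score are all 0.5.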
5.4 Results
We measured the image classification performances of the models with cross-validation and in the following sections report the average scores from running each model 10 times.
5.4.1 SVHN
For the SVHN dataset, we randomly selected 7,326 labeled images and they along with the remaining 65,931 unlabeled images were provided to the network as training data. All the models were trained for 300 epochs and then evaluated. We generated new images equal in number to the training set size. Figure 5 presents a visual comparison of a random selection of images generated by the DC-GAN, VAE-GAN, and MAVEN models and real training images. Figure 6 compares the image intensity histograms of 10K randomly sampled real images and equally many images sampled from among those generated by each of the different models. Generally speaking, our MAVEN models generate images that are more realistic than those generated by the DC-GAN and VAE-GAN models. This was further corroborated by randomly sampling 10K generated images and 10K real images. The generated image quality measurement was performed for the eight different models. Table 1 reports the resulting FID and DDD scores. For the FID score calculation, the score is reported after running the pre-trained Inception-v3 network for 20 epochs for each model. The MAVEN-r3D model achieved the best FID score and the best DDD score was achieved by the MAVEN-m5D model. Table 2 compares the classification performance of all the models for the SVHN dataset. The MAVEN model consistently outperformed the DC-GAN and VAE-GAN classifiers both in classification accuracy and class-wise F1 scores. Among all the models, our MAVEN-m2D and MAVEN-m3D models were the most accurate.
Fig. 5 Visual comparison of image samples from the SVHN dataset against those generated by the different models
Fig. 6 Histograms of the real SVHN training data, and of the data generated by the DC-GAN and VAE-GAN models and by our MAVEN models with mean and random feedback from 2, 3, and 5 discriminators
Table 1 Minimum FID and DDD scores achieved by the DC-GAN, VAE-GAN, and MAVEN models for the CIFAR-10, SVHN, CXR, and SLC datasets
Model        CIFAR-10                 SVHN                     CXR                       SLC
             FID             DDD      FID             DDD      FID              DDD      FID            DDD
DC-GAN       61.293 ± 0.209  0.265    16.789 ± 0.303  0.343    152.511 ± 0.370  0.145    1.828 ± 0.370  0.795
VAE-GAN      15.511 ± 0.125  0.224    13.252 ± 0.001  0.329    141.422 ± 0.580  0.107    1.828 ± 0.580  0.795
MAVEN-m2D    12.743 ± 0.242  0.223    11.675 ± 0.001  0.309    141.339 ± 0.420  0.138    1.874 ± 0.270  0.802
MAVEN-m3D    11.316 ± 0.808  0.190    11.515 ± 0.065  0.300    140.865 ± 0.983  0.018    0.304 ± 0.018  0.249
MAVEN-m5D    12.123 ± 0.140  0.207    10.909 ± 0.001  0.294    147.316 ± 1.169  0.100    1.518 ± 0.190  0.793
MAVEN-r2D    12.820 ± 0.584  0.194    11.384 ± 0.001  0.316    154.501 ± 0.345  0.038    1.505 ± 0.130  0.789
MAVEN-r3D    12.620 ± 0.001  0.202    10.791 ± 0.029  0.357    158.749 ± 0.297  0.179    0.336 ± 0.080  0.783
MAVEN-r5D    18.509 ± 0.001  0.215    11.052 ± 0.751  0.323    152.778 ± 1.254  0.180    1.812 ± 0.014  0.795
DO-GAN [23]  88.60 ± 0.08    –
TTUR [12]    36.9            –
C-GAN [34]   27.300          –
AIQN [28]    49.500          –
SN-GAN [22]  21.700          –
LM [30]      18.9            –
Table 2 Average cross-validation accuracy and class-wise F1 scores in the semi-supervised classification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the SVHN dataset
Model        Accuracy  F1 scores
                       0      1      2      3      4      5      6      7      8      9
DC-GAN       0.876     0.860  0.920  0.890  0.840  0.890  0.870  0.830  0.890  0.820  0.840
VAE-GAN      0.901     0.900  0.940  0.930  0.860  0.920  0.900  0.860  0.910  0.840  0.850
MAVEN-m2D    0.909     0.890  0.930  0.940  0.890  0.930  0.900  0.870  0.910  0.870  0.890
MAVEN-m3D    0.909     0.910  0.940  0.940  0.870  0.920  0.890  0.870  0.920  0.870  0.860
MAVEN-m5D    0.905     0.910  0.930  0.930  0.870  0.930  0.900  0.860  0.910  0.860  0.870
MAVEN-r2D    0.905     0.910  0.930  0.940  0.870  0.930  0.890  0.860  0.920  0.850  0.860
MAVEN-r3D    0.907     0.890  0.910  0.920  0.870  0.900  0.870  0.860  0.900  0.870  0.890
MAVEN-r5D    0.903     0.910  0.930  0.940  0.860  0.910  0.890  0.870  0.920  0.850  0.870
5.4.2 CIFAR-10
For the CIFAR-10 dataset, we used 50K training images, only 5K of them labeled. All the models were trained for 300 epochs and then evaluated. We generated new images equal in number to the training set size. Figure 7 visually compares a random selection of images generated by the DC-GAN, VAE-GAN, and MAVEN models and real training images. Figure 8 compares the image intensity histograms of 10K randomly sampled real images and equally many images sampled from among those generated by each of the different models. Table 1 reports the FID and DDD scores. As the tabulated results suggest, our MAVEN models achieved better FID scores than some of the recently published models. Note that those models were implemented in different settings. As with the visual comparison, the FID and DDD scores confirmed more realistic image generation by our MAVEN models compared to the DC-GAN and VAE-GAN models. The MAVEN models have smaller FID scores than the baselines, except for MAVEN-r5D. MAVEN-m3D has the smallest FID and DDD scores among all the models. Table 3 compares the classification performance of all the models with the CIFAR-10 dataset. All the MAVEN models performed better than the DC-GAN and VAE-GAN models. In particular, MAVEN-m5D achieved the best classification accuracy and F1 scores.
Fig. 7 Visual comparison of image samples from the CIFAR-10 dataset against those generated by the different models
Fig. 8 Histograms of the real CIFAR-10 training data, and of the data generated by the DC-GAN and VAE-GAN models and by our MAVEN models with mean and random feedback from 2, 3, and 5 discriminators
Table 3 Average cross-validation accuracy and class-wise F1 scores in the semi-supervised classification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the CIFAR-10 dataset
Model        Accuracy  F1 scores
                       Plane  Auto   Bird   Cat    Deer   Dog    Frog   Horse  Ship   Truck
DC-GAN       0.713     0.760  0.840  0.560  0.510  0.660  0.590  0.780  0.780  0.810  0.810
VAE-GAN      0.743     0.770  0.850  0.640  0.560  0.690  0.620  0.820  0.770  0.860  0.830
MAVEN-m2D    0.761     0.800  0.860  0.650  0.590  0.750  0.680  0.810  0.780  0.850  0.850
MAVEN-m3D    0.759     0.770  0.860  0.670  0.580  0.700  0.690  0.800  0.810  0.870  0.830
MAVEN-m5D    0.771     0.800  0.860  0.650  0.610  0.710  0.640  0.810  0.790  0.880  0.820
MAVEN-r2D    0.757     0.780  0.860  0.650  0.530  0.720  0.650  0.810  0.800  0.870  0.860
MAVEN-r3D    0.756     0.780  0.860  0.640  0.580  0.720  0.650  0.830  0.800  0.870  0.830
MAVEN-r5D    0.762     0.810  0.850  0.680  0.600  0.720  0.660  0.840  0.800  0.850  0.820
5.4.3 CXR
With the CXR dataset, we used 522 labeled images and 4,694 unlabeled images. All the models were trained for 150 epochs and then evaluated. We generated new images equal in number to the training set size. Figure 9 presents a visual comparison of a random selection of generated and real images. The FID and DDD measurements were performed for the distributions of generated and real training samples, indicating that more realistic images were generated by the MAVEN models than by the GAN and VAE-GAN models. The FID and DDD scores presented in Table 1 show that the mean MAVEN-m3D model has the smallest FID and DDD scores. The classification performance reported in Table 4 suggests that our MAVEN model-based classifiers are more accurate than the baseline GAN and VAE-GAN classifiers. Among all the models, the MAVEN-m3D classifier was the most accurate.
5.4.4 SLC
For the SLC dataset, we used 160 labeled images and 1,440 unlabeled images. All the models were trained for 150 epochs and then evaluated. We generated new images equal in number to the training set size. Figure 10 presents a visual comparison of randomly selected generated and real image samples. The FID and DDD measurements for the distributions of generated and real training samples indicate that more realistic images were generated by the MAVEN models than by the GAN and VAE-GAN models. The FID and DDD scores presented in Table 1 show that the mean MAVEN-m3D model has the smallest FID and DDD scores. The classification performance reported in Table 5 suggests that our MAVEN model-based classifiers are more accurate than the baseline GAN and VAE-GAN classifiers. Among all the models, MAVEN-r3D is the most accurate in discriminating between non-melanoma and melanoma lesion images.
Fig. 9 Visual comparison of image samples from the CXR dataset against those generated by the different models
Table 4 Average cross-validation accuracy and class-wise F1 scores for the semi-supervised classification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the CXR dataset
Model        Accuracy  F1 scores
                       Normal  B-Pneumonia  V-Pneumonia
DC-GAN       0.461     0.300   0.520        0.480
VAE-GAN      0.467     0.220   0.640        0.300
MAVEN-m2D    0.469     0.310   0.620        0.260
MAVEN-m3D    0.525     0.640   0.480        0.480
MAVEN-m5D    0.477     0.380   0.480        0.540
MAVEN-r2D    0.478     0.280   0.630        0.310
MAVEN-r3D    0.506     0.440   0.630        0.220
MAVEN-r5D    0.483     0.170   0.640        0.240
Fig. 10 Visual comparison of image samples from the SLC dataset against those generated by the different models
Table 5 Average cross-validation accuracy and class-wise F1 scores for the semi-supervised classification performance comparison of the DC-GAN, VAE-GAN, and MAVEN models using the SLC dataset
Model        Accuracy  F1 scores
                       Non-melanoma  Melanoma
DC-GAN       0.802     0.890         0.120
VAE-GAN      0.810     0.890         0.012
MAVEN-m2D    0.815     0.900         0.016
MAVEN-m3D    0.814     0.900         0.110
MAVEN-m5D    0.812     0.900         0.140
MAVEN-r2D    0.808     0.890         0.260
MAVEN-r3D    0.821     0.900         0.020
MAVEN-r5D    0.797     0.890         0.040
6 Conclusions
We have introduced a novel generative modeling approach, called Multi-Adversarial Variational autoEncoder Networks, or MAVENs, which demonstrates the advantage of an ensemble of discriminators in the adversarial learning of variational autoencoders. We have shown that training our MAVEN models on small, labeled datasets and allowing them to leverage large numbers of unlabeled training examples enables them to achieve superior performance relative to prior GAN and VAE-GAN-based classifiers, suggesting that MAVENs can be very effective in simultaneously generating high-quality realistic images and improving multiclass image classification performance. Furthermore, unlike conventional GAN-based semi-supervised classification, improvements in the classification of natural and medical images do not compromise the quality of the generated images. Future work with MAVENs should explore more complex image analysis tasks beyond classification and include more extensive experimentation spanning additional domains.
References
1. M. Arjovsky, S. Chintala, L. Bottou, Wasserstein GAN (2017). arXiv preprint arXiv:1701.07875
2. C. Bermudez, A.J. Plassard, L.T. Davis, A.T. Newton, S.M. Resnick, B.A. Landman, Learning implicit brain MRI manifolds with deep learning, in Medical Imaging 2018: Image Processing, vol. 10574 (2018), p. 105741L
3. F. Calimeri, A. Marzullo, C. Stamile, G. Terracina, Biomedical data augmentation using generative adversarial neural networks, in International Conference on Artificial Neural Networks (2017), pp. 626–634
4. M.J. Chuquicusma, S. Hussein, J. Burt, U. Bagci, How to fool radiologists with generative adversarial networks? A visual turing test for lung cancer diagnosis, in IEEE International Symposium on Biomedical Imaging (ISBI) (2018), pp. 240–244
5. N.C. Codella, D. Gutman, M.E. Celebi, B. Helba, M.A. Marchetti, S.W. Dusza, A. Kalloo, K. Liopyris, N. Mishra, H. Kittler et al., Skin lesion analysis toward melanoma detection: a challenge at the 2017 ISBI, hosted by ISIC, in IEEE International Symposium on Biomedical Imaging (ISBI 2018) (2018), pp. 168–172
6. E.L. Denton, S. Chintala, A. Szlam, R. Fergus, Deep generative image models using a Laplacian pyramid of adversarial networks, in Advances in Neural Information Processing Systems (NeurIPS) (2015)
7. I. Durugkar, I. Gemp, S. Mahadevan, Generative multi-adversarial networks (2016). arXiv preprint arXiv:1611.01673
8. M. Frid-Adar, I. Diamant, E. Klang, M. Amitai, J. Goldberger, H. Greenspan, GAN-based synthetic medical image augmentation for increased CNN performance in liver lesion classification (2018). arXiv preprint arXiv:1803.01229
9. I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, Y. Bengio, Generative adversarial nets, in Advances in Neural Information Processing Systems (NeurIPS) (2014), pp. 2672–2680
10. J.T. Guibas, T.S. Virdi, P.S. Li, Synthetic medical images from dual generative adversarial networks (2017). arXiv preprint arXiv:1709.01872
11. C. Han, H. Hayashi, L. Rundo, R. Araki, W. Shimoda, S. Muramatsu, Y. Furukawa, G. Mauri, H. Nakayama, GAN-based synthetic brain MR image generation, in IEEE International Symposium on Biomedical Imaging (ISBI) (2018), pp. 734–738
12. M. Heusel, H. Ramsauer, T. Unterthiner, B. Nessler, S. Hochreiter, GANs trained by a two time-scale update rule converge to a local Nash equilibrium, in Advances in Neural Information Processing Systems (NeurIPS) (2017), pp. 6626–6637
13. L. Hou, A. Agarwal, D. Samaras, T.M. Kurc, R.R. Gupta, J.H. Saltz, Unsupervised histopathology image synthesis (2017). arXiv preprint arXiv:1712.05021
14. A.A.Z. Imran, P.R. Bakic, A.D. Maidment, D.D. Pokrajac, Optimization of the simulation parameters for improving realism in anthropomorphic breast phantoms, in Proceedings of the SPIE, vol. 10132 (2017)
15. A.A.Z. Imran, D. Terzopoulos, Multi-adversarial variational autoencoder networks, in IEEE International Conference on Machine Learning and Applications (ICMLA) (Boca Raton, FL, 2019), pp. 777–782
16. A.A.Z. Imran, D. Terzopoulos, Semi-supervised multi-task learning with chest X-ray images (2019). arXiv preprint arXiv:1908.03693
17. D.S. Kermany, M. Goldbaum, W. Cai, C.C. Valentim, H. Liang, S.L. Baxter, A. McKeown, G. Yang, X. Wu, F. Yan et al., Identifying medical diagnoses and treatable diseases by image-based deep learning. Cell 172(5), 1122–1131 (2018)
18. D.P. Kingma, M. Welling, Auto-encoding variational Bayes (2013). arXiv preprint arXiv:1312.6114
19. A. Krizhevsky, Learning multiple layers of features from tiny images. Master's thesis, University of Toronto, Dept. of Computer Science (2009)
20. A. Madani, M. Moradi, A. Karargyris, T. Syeda-Mahmood, Semi-supervised learning with generative adversarial networks for chest X-ray classification with ability of data domain adaptation, in IEEE International Symposium on Biomedical Imaging (ISBI) (2018), pp. 1038–1042
21. A. Makhzani, J. Shlens, N. Jaitly, I. Goodfellow, B. Frey, Adversarial autoencoders (2015). arXiv preprint arXiv:1511.05644
22. T. Miyato, T. Kataoka, M. Koyama, Y. Yoshida, Spectral normalization for generative adversarial networks (2018). arXiv preprint arXiv:1802.05957
23. G. Mordido, H. Yang, C. Meinel, Dropout-GAN: learning from a dynamic ensemble of discriminators (2018). arXiv preprint arXiv:1807.11346
24. Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, A.Y. Ng, Reading digits in natural images with unsupervised feature learning, in NIPS Workshop on Deep Learning and Unsupervised Feature Learning, vol. 2011 (2011), pp. 1–9
25. B. Neyshabur, S. Bhojanapalli, A. Chakrabarti, Stabilizing GAN training with multiple random projections (2017). arXiv preprint arXiv:1705.07831
Multi-Adversarial Variational Autoencoder Nets for Simultaneous …
271
26. T. Nguyen, T. Le, H. Vu, D. Phung, Dual discriminator generative adversarial nets, in Advances in Neural Information Processing Systems (NeurIPS) (2017), pp. 2670–2680 27. A. dena, C. Olah, J. Shlens, Conditional image synthesis with auxiliary classifier GANs (2017). arXiv preprint arXiv:1610.09585 28. G. Ostrovski, W. Dabney, R. Munos, Autoregressive quantile networks for generative modeling (2018). arXiv preprint arXiv:1806.05575 29. A. Radford, L. Metz, S. Chintala, Unsupervised representation learning with deep convolutional generative adversarial networks (2015). arXiv preprint arXiv:1511.06434 30. S. Ravuri, S. Mohamed, M. Rosca, O. Vinyals, Learning implicit generative models with the method of learned moments (2018). arXiv preprint arXiv:1806.11006 31. H. Salehinejad, S. Valaee, T. Dowdell, E. Colak, J. Barfett, Generalization of deep neural networks for chest pathology classification in X-rays using generative adversarial networks, in IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (2018), pp. 990–994 32. T. Salimans, I. Goodfellow, W. Zaremba, V. Cheung, A. Radford, X. Chen, Improved techniques for training GANs, in Advances in Neural Information Processing Systems (NeurIPS) (2016), pp. 2234–2242 33. J.T. Springenberg, Unsupervised and semi-supervised learning with categorical generative adversarial networks (2015). arXiv preprint arXiv:1511.06390 34. T. Unterthiner, B. Nessler, C. Seward, G. Klambauer, M. Heusel, H. Ramsauer, S. Hochreiter, Coulomb GANs: provably optimal nash equilibria via potential fields (2017). arXiv preprint arXiv:1708.08819 35. S. Wang, L. Zhang, CatGAN: coupled adversarial transfer for domain generation (2017). arXiv preprint arXiv:1711.08904 36. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan (eds.), Advances in Deep Learning (Springer, 2020)
Non-convex Optimization Using Parameter Continuation Methods for Deep Neural Networks

Harsh Nilesh Pathak and Randy Clinton Paffenroth

Abstract Numerical parameter continuation methods are widely used to solve non-convex problems. These methods have had many applications in physics and mathematical analysis, such as the bifurcation study of dynamical systems. However, as far as we know, such efficient methods have seen relatively limited use in the optimization of neural networks. In this chapter, we propose a novel training method for deep neural networks based on ideas from parameter continuation methods and compare it with widely practiced methods such as Stochastic Gradient Descent (SGD), AdaGrad, RMSProp and ADAM. Transfer and curriculum learning have recently shown exceptional performance enhancements in deep learning and are intuitively similar to homotopy or continuation techniques. However, our proposed methods leverage decades of theoretical and computational work and can be viewed as an initial bridge between those techniques and deep neural networks. In particular, we illustrate a method that we call Natural Parameter Adaptive Continuation with Secant approximation (NPACS). Herein we transform commonly used activation functions into their homotopic versions. Such a version allows one to decompose the complex optimization problem into a sequence of problems, each of which is provided with a good initial guess based upon the solution of the previous problem. NPACS uses this formulation together with ADAM to obtain faster convergence. We demonstrate the effectiveness of our method on standard benchmark problems: it computes local minima more rapidly and achieves lower generalization error than contemporary techniques in a majority of cases.

H. Nilesh Pathak, Expedia Group, 1111 Expedia Group Way W, Seattle, WA 98119, USA
R. Clinton Paffenroth, Worcester Polytechnic Institute, Mathematical Sciences, Computer Science & Data Science, 100 Institute Rd, Worcester, MA 01609, USA

© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021 M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9_12
1 Introduction

In many machine learning and deep learning problems, the key task is an optimization problem whose objective is to learn useful properties of the data given a model. The parameters of the model are used to learn linear or non-linear features of the data and can be used to perform inference and prediction. In this chapter, we study a challenging non-convex optimization task, namely the ability of deep learning models to learn non-linear or complex representations from data. This chapter is an extension of the conference paper [52] with additional results, additional context for the theory of the algorithms, a detailed literature survey listing some common limitations of curriculum strategies, and discussions on open research questions.

Often, the objective of a deep learning model is to approximate a function $f_{true}$, the true function that maps inputs $x$ to targets $y$, such that $y = f_{true}(x)$ [15, 50]. More formally, a deep neural network defines a mapping $y = f(x; \theta)$, where $\theta$ is the set of parameters. These parameters are estimated by minimizing an objective function such that the mapping $f$ best approximates $f_{true}$. However, training a deep neural network to find a good solution is a challenging optimization task [3, 17, 44], where by a good solution we mean one that achieves low generalization error in few training steps. Even with state-of-the-art techniques, machines configured with a large number of processing units and high memory can spend days, or even longer, solving such problems [50]. Neural networks use a composition of functions which, generally speaking, is non-convex and difficult to optimize [3, 8]. Even deep linear networks are non-convex in parameter space [17].

Deep learning research has advanced considerably in the past decade, and effective optimization algorithms have been proposed for non-convex problems. Stochastic Gradient Descent (SGD) [15], RMSprop [25], Adagrad [12] and ADAM [35] are widely practiced by the deep learning community and have significantly advanced the state of the art in text mining [43], super-resolution [16, 40, 51], image recognition [38], speech recognition [24] and many other areas. However, their success usually depends on the choice of hyperparameters and the quality of the initialization [60]. Previously, researchers have shown that different initialization methods may lead to dramatically different solution curve geometries on a cost surface [8, 17, 27]. In other words, a small shift in the initialization may lead the model to converge to a different minimum [50].

In this chapter, we attempt to improve the solution of such non-convex optimization problems by rethinking the usual approach. In particular, many of the current state-of-the-art optimization techniques work on a fixed loss surface. In contrast, we propose to transform the loss surface continuously in order to design an effective training strategy.

The contributions of this chapter are the following. We derive a homotopy formulation of common activation functions that implicitly decomposes the deep neural network optimization problem into a sequence of problems, which then enables one to design and apply parameter continuation methods to rapidly converge to a better solution. Further, we develop a set of Natural Parameter Continuation (NPC) [52] techniques, apply them to neural networks, and empirically observe improvement in generalization performance compared to standard optimization techniques. After performing a curated list of experiments, we observed that a naive NPC method requires careful tuning depending on the network hyperparameters, data and activation functions [52]. Accordingly, we designed an enhanced strategy that adaptively adjusts the continuation step during training, which we call Natural Parameter Adaptive Continuation (NPAC). Motivated by the continuation literature, Natural Parameter Adaptive Continuation with Secant approximation (NPACS) is our prime contribution. Herein we improve the quality of initialization and the speed of convergence by employing the secant method. Further, we perform a set of experiments demonstrating that continuation methods achieve lower generalization and training loss, which implies faster learning, compared to conventional optimization methods. In some of our experiments, we observed the standard optimization methods trapped at a local minimum that our proposed methods swiftly surpassed. Finally, we discuss open research problems and provide insights into how continuation methods can serve as an efficient strategy for optimizing and analyzing neural networks. These ideas are demonstrated using benchmarks and extensive empirical evaluations from multiple sources.

The remainder of this chapter is structured as follows. Section 2 provides a preliminary discussion of homotopies and similar ideas. Section 2.2 provides classical theory that connects to our work, involving the implicit function theorem. Section 3 covers related work; in particular, we discuss advantages and limitations of techniques that use activation functions as a medium of continuation or curriculum. Section 4 describes our methodology and approach, including a detailed description of our novel optimization and initialization strategy. In Sect. 5 we discuss our experimental results. Additionally, we present a discussion of future open problems in Sect. 6 that could help open new paths in deep learning research. Finally, Sect. 7 concludes the chapter. In this extended chapter, Sects. 2, 6 and 4.3 are newly added. Moreover, additional insights and results are added in Sects. 3, 4.6 and 5.
2 Background

Non-convex optimization is a challenging task that arises in almost all deep learning models [15, 52]. A historically effective class of algorithms for solving such non-convex problems is numerical continuation methods [2]. Continuation methods can be used to organize the training of a neural network and to improve the quality of the initial guess, which may accelerate convergence [52]. The fundamental idea is to start from an easier version of the desired problem and gradually transform it into the original version. Formally, this may be described as follows: given a loss function $J(\theta)$, define a sequence of loss functions $\{J^{(0)}(\theta), J^{(1)}(\theta), J^{(2)}(\theta), J^{(3)}(\theta), \ldots, J^{(n)}(\theta)\}$ such that $J^{(i)}(\theta)$ is easier to optimize than $J^{(i+1)}(\theta)$ [15]. Here $J^{(i)}(\theta)$ and $J^{(i+1)}(\theta)$ are sufficiently similar that the solution of $J^{(i)}(\theta)$ can be used as an initial estimate for the parameters of $J^{(i+1)}(\theta)$ [2, 11, 50].

Based upon a literature survey [3, 44, 49, 50], there appear to be two main thrusts for improving optimization by continuation:
• How to choose the simplified problem?
• How to transform the simplified problem into the original task?

In this chapter, we provide an approach that addresses both of these challenges. We also observe that the results in this chapter build upon results in several disparate domains; good entryways into the literature include [15, 26] for autoencoders (AE), [2, 34] for continuation methods and [3] for curriculum learning.
2.1 Homotopies

One possible way to incorporate continuation methods in the context of neural networks is to determine homotopy [2] formulations of activation functions [52], defined in the following manner:

$$h(x, \lambda) = (1 - \lambda) \cdot h_1(x) + \lambda \cdot h_2(x) \tag{1}$$

where $\lambda \in [0, 1]$ is known as a homotopy parameter. Herein, we consider $h_1(x)$ to be some simple activation function and $h_2(x)$ to be some more complex activation function. For example, $h_1(x) := x$ and $h_2(x) := \frac{1}{1 + e^{-x}}$, and $h(x, \lambda)$ provides a continuous transformation between them. While a seemingly simple idea, such methods have a long history and have made substantial impacts in other contexts, for example, the solution of boundary value problems in ODEs [2, 11]. As we will detail in the next section, when such ideas are applied to neural networks, one can transform the minimum of a simple neural network into a minimum of a much more complicated neural network [52].
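As a minimal sketch of Eq. (1), the blend between the identity and the logistic sigmoid can be written in a few lines of Python. The function names `sigmoid` and `homotopy` and the sample input are illustrative choices, not taken from the chapter.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic sigmoid 1 / (1 + e^{-x}), the h_2 of the example in the text."""
    return 1.0 / (1.0 + math.exp(-x))


def homotopy(x: float, lam: float, h1=lambda t: t, h2=sigmoid) -> float:
    """Convex combination (1 - lam) * h1(x) + lam * h2(x), as in Eq. (1)."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return (1.0 - lam) * h1(x) + lam * h2(x)


# Evaluate the blend at a few homotopy parameters for one fixed input.
x = 2.0
values = {lam: homotopy(x, lam) for lam in (0.0, 0.5, 1.0)}
```

At `lam = 0.0` the function returns its input unchanged, at `lam = 1.0` it returns the sigmoid, and intermediate values interpolate linearly between the two.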
2.2 The Implicit Function Theorem

The Implicit Function Theorem (IFT) is an important tool for the theoretical understanding of continuation methods [34]. While a full treatment of these ideas is beyond the scope of this chapter, in this section we provide a measure of intuition to help bridge the gap between continuation methods and deep learning. Let us begin with the objective of a simple AE [34]:

$$J(X, g(f(X)); \theta) = 0 \tag{2}$$

$$G(X; \theta) = 0 \tag{3}$$

where $X \in \mathbb{R}^{m \times n}$ is the input data, $f$ is the encoder function, $g$ is the decoder function and $\theta \in \mathbb{R}^N$ are the network parameters. With a mean squared error loss, we want the objective to be zero. We can therefore represent this objective as a function $G(X; \theta)$ that can be used in a continuation method framework [34]. Note that in the above equations $X$ is not a parameter: it is fixed by our choice of training data. Accordingly, when we 'count parameters' below, only $\theta$ serves as degrees of freedom. Now, say we add another parameter $\lambda$, which we call a homotopy parameter, to the system of neural network equations. In our work, this single parameter controls the transformation of the problem from a simple problem to a hard problem, and such homotopy parameters will be the focus of Sect. 4. We rewrite to get

$$G(\theta, \lambda) = 0 \tag{4}$$

Now (4) can be directly related to the system of equations in [34], for which the implicit function theorem can be applied under certain assumptions.¹ The IFT then guarantees that, in the neighbourhood of a known solution at which the derivative is non-singular, a smooth and continuous homotopy path must exist. The interested reader can see [2, 34] for more details.
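To make the role of the non-singular derivative concrete, the following is a sketch of the standard argument from the continuation literature rather than a derivation given in this chapter. The shorthand $G_\theta$ and $G_\lambda$ for the partial derivatives of $G$ is introduced here only for illustration. Suppose a smooth solution path $\theta(\lambda)$ satisfies (4) near a known root $(\theta_0, \lambda_0)$; differentiating along the path with respect to $\lambda$ gives:

```latex
\[
\begin{aligned}
G(\theta(\lambda), \lambda) &= 0, \\
G_\theta \, \frac{d\theta}{d\lambda} + G_\lambda &= 0, \\
\frac{d\theta}{d\lambda} &= -\, G_\theta^{-1} \, G_\lambda .
\end{aligned}
\]
```

The last line is only defined when $G_\theta$ is invertible, which is why the path can be followed locally precisely where the derivative is non-singular.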
¹ Assumptions:
• $G : \mathbb{R}^N \times \mathbb{R} \to \mathbb{R}^N$ is a smooth map
• $\|G(\theta_0, \lambda_0)\| \le c$
• $G^0_\theta$ is non-singular at a known root $(\theta_0, \lambda_0)$
See [34] for the IFT and proofs for local continuation.

3 Related Work

The choice of optimization technique, such as SGD or ADAM, plays a crucial role in the quality of convergence. Finding a better initialization in order to reach a superior local minimum is also an active topic in the research community. Here we draw a few observations from popular research works that have demonstrated good performance.

Unsupervised pre-training [26] is one such strategy for training neural networks for supervised and unsupervised learning tasks. The idea is to use an Autoencoder (AE) to greedily train the hidden layers one layer at a time. This method has two advantages. First, it provides a more stable initialization than random initialization, because the model has already implicitly learned properties of the data. Second, it increases the regularization effect [15, 50]. Fine-tuning may be seen as an evolution of the above-mentioned approach, where the learned layers are trained further, or fine-tuned, for the final task [15, 39]. In transfer learning, one is provided with two distinct tasks such that learning from one distribution can be transferred to another. This technique is broadly used in vision and language tasks such as recommendation systems, image super-resolution, etc. In summary, these ideas suggest that incorporating prior knowledge into the network, or taking a few simple learning steps first, can enhance the final training performance [13, 17].

Some of the standard work closest to the research in this chapter is curriculum learning [3, 23]. Smoothing [22, 44] is a common approach for changing the qualitative behaviour of activation functions over the course of training, and it has been adopted by many researchers in distinct ways [5, 21, 22, 44, 50]. Smoothing can be seen as a convex envelope of the non-convex problems in [45], and it may reveal the global curvature of the loss function [3, 44]. The smoothing can then be continuously decreased to obtain a sequence of loss functions of increasing complexity. The smoothed objective function is minimized first, and the smoothing is then progressively reduced as training proceeds to obtain the terminal solution [3, 44]. Next, Mollifying Networks added normally distributed noise to the weights of each layer of the neural network [22] and thus mollified the objective function. Again, this noise is steadily reduced as training proceeds. Linearization of activation functions is another way of achieving the mollification effect [5, 21, 22]. Similarly, gradual deformation from a linear to a non-linear representation has been implemented with different noise injection techniques [21, 22, 44].

Initialization is one of the main advantages of continuation methods [2, 34]. In many deep learning applications, random initialization is widely practiced. Continuation methods instead suggest starting with a problem whose solution is easily available and then incrementally advancing towards the problem one wants to solve. Naturally, this can save computational cost. While other methods may require many iterations to attain a reasonable solution, in the paradigm of parameter continuation methods the initialization can usually be deterministic, and we then progressively solve simpler problems to obtain the optimal solution. For example, in [44] one performs accurate initialization with almost no computational cost. Further advantages were discussed for diffusion methods [44]. In that work, the authors showed conceptually that popular techniques such as dropout [58], layerwise pre-training [26, 39], annealed learning rates, careful initialization [60] and noise injection (directly into the loss or the weights of the network) [3] can arise naturally from continuation methods [44]. We discuss a few limitations and future research directions in Sect. 6.

In comparison to previous approaches that add noise to activation functions [21, 22] as a proxy for the continuation method, we derive homotopies to obtain linearization, which we discuss in Sect. 4. Our intent is to simplify the original task and start with a good initialization. Despite the above advances, there is a need for systematic empirical evaluation to show how the continuation method may help to avoid early local minima and saddle points [17, 52]. Unlike most of the research, which has focused on classification or prediction tasks, herein we focus on unsupervised learning. However, nothing in the formulation of the homotopy function depends on the particular structure of AEs, so we expect similar results to be possible with arbitrary neural networks.
4 Methodology

In this section, we discuss our method for constructing a sequence of objective functions with increasing complexity. We illustrate how we transform the standard activation functions through homotopy. This lets us solve the deep neural network optimization problem strategically, following the principles of parameter continuation [2]. We discuss three continuation methods, namely NPC, NPAC and NPACS [52].
4.1 Continuation Activation

Continuation Activation (C-Activation) is a homotopy formulation of standard activation functions. Homotopy adds the ability for a network to learn linear, non-linear and any intermediate characteristics of the data. Activation functions can be reformulated according to (5), namely,

$$\phi_{\text{C-Activation}}(v, \lambda) = (1 - \lambda) \cdot v + \lambda \cdot \phi(v), \qquad \lambda \in [0, 1] \tag{5}$$

where $\phi$ can be any activation function such as Sigmoid, ReLU, Tanh, etc., $\lambda$ is the homotopy parameter and $v$ is the value of the output from the current layer. For a standard AE, the loss function can be represented as

$$J(X, g_\theta(f_\theta(X))) = \operatorname*{argmin}_{\theta} \|X - g_\theta(f_\theta(X))\|^2 \tag{6}$$

where $X \in \mathbb{R}^{m \times n}$ is the input data, $f$ is the encoder function, $g$ is the decoder function and $\theta$ are the network parameters. With the addition of a homotopy parameter $\lambda$, the optimization can be rewritten in the following manner:

$$J(X, g_{\theta,\lambda}(f_{\theta,\lambda}(X))) = \operatorname*{argmin}_{\theta,\lambda} \|X - g_{\theta,\lambda}(f_{\theta,\lambda}(X))\|^2 \tag{7}$$

In the above optimization problem, at any continuation step $\lambda_i$ ($\lambda \in [0, 1]$ and $i \in \mathbb{N}_0$) we solve the problem $J(X, g_{\theta_i,\lambda_i}(f_{\theta_i,\lambda_i}(X)))$ and obtain a solution $(\theta_i, \lambda_i)$. Every value of $\lambda$ represents a different optimization problem and a corresponding degree of non-linearity [52].
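The C-Activation of Eq. (5) and the reconstruction objective of Eqs. (6) and (7) can be sketched with NumPy as follows. This is an illustration under stated assumptions rather than the authors' implementation: the one-hidden-layer architecture, the absence of bias terms, the use of Tanh for $\phi$, the mean (rather than summed) squared error, and all names (`c_activation`, `ae_loss`, `W_enc`, `W_dec`) are choices made here for brevity.

```python
import numpy as np


def c_activation(v, lam, phi):
    """C-Activation of Eq. (5): (1 - lam) * v + lam * phi(v)."""
    return (1.0 - lam) * v + lam * phi(v)


def ae_loss(X, W_enc, W_dec, lam, phi=np.tanh):
    """Mean squared reconstruction error of a one-hidden-layer AE.

    Assumptions (not from the chapter): no bias terms, a single hidden
    layer, and the C-Activation applied after both the encoder and decoder.
    """
    code = c_activation(X @ W_enc, lam, phi)      # encoder f
    recon = c_activation(code @ W_dec, lam, phi)  # decoder g
    return float(np.mean((X - recon) ** 2))


rng = np.random.default_rng(0)
X = rng.normal(size=(32, 6))                 # hypothetical data, 32 samples
W_enc = rng.normal(scale=0.1, size=(6, 3))   # hypothetical encoder weights
W_dec = rng.normal(scale=0.1, size=(3, 6))   # hypothetical decoder weights

# At lam = 0 every C-Activation is the identity, so the AE is a linear map.
linear_loss = float(np.mean((X - X @ W_enc @ W_dec) ** 2))
losses = {lam: ae_loss(X, W_enc, W_dec, lam) for lam in (0.0, 0.5, 1.0)}
```

Each value of `lam` defines a different loss over the same weights, which is the sense in which every $\lambda$ represents a different optimization problem.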
4.2 Intuition for Extreme Values of λ

To provide intuition, first let us consider the extremes (i.e. $\lambda = 0$ or $\lambda = 1$). When $\lambda = 1$, the objective functions in (6) and (7) are exactly the same; thus $\lambda = 1$ corresponds to the use of conventional activation functions. Second, in a deep neural network as shown in Fig. 6, suppose all the activation functions are C-Activations with $\lambda_0 = 0$; then the neural network is linear. Here, the network would attempt to find a linear projection that minimizes the cost function. We know such a solution can be found in closed form using Principal Component Analysis (PCA) or the Singular Value Decomposition (SVD) [50]. Hence, the PCA solution and the solution of our optimization problem at $\lambda_0 = 0$ should span the same subspace [15, 50]. We leverage this observation and initialize our network using PCA, which we discuss further in Sect. 4.3. Effectively, $\lambda = 0$ is analogous to solving a simpler version of the optimization problem (a linear projection), and as $\lambda \to 1$ the problem becomes harder (a non-linear and non-convex problem) (Fig. 1).

Thus, the homotopy parameter $\lambda$ defines the scheme for the optimization, which can be referred to as the homotopy path, as shown in Fig. 2. As $\lambda : 0 \to 1$, we solve an increasingly difficult optimization problem, where the degree of non-linearity learned from the data increases. However, we need a technique to traverse this homotopy path and find the solutions along it; deriving such techniques will be our focus in the following sections.

Fig. 1 This figure is a good example of manifold learning [69]. Points in blue are true data points that clearly lie on a non-linear manifold. The green points show an optimized non-linear projection of the data. The red points show the linear manifold onto which PCA projects. Learning a linear manifold can be considered a simplified version of non-linear manifold learning
Fig. 2 This figure provides intuition for NPC with a hypothetical homotopy path (blue curve), which connects all solutions in the multidimensional space ($\theta$) at every $\lambda \in [0, 1]$. Here we show how the solution ($\theta_{\lambda_1}$) is used to initialize the parameters of the network at $\lambda_2$, $\theta^{init}_{\lambda_2} \leftarrow \theta_{\lambda_1}$, using a small $\Delta\lambda_{2,1}$. Further taking some ADAM steps (shown by the orange arrow), we find some minimum ($\theta_{\lambda_2}$)

4.3 Stable Initialization of Autoencoder Through PCA

Principal Component Analysis (PCA) [47] is a classic dimension reduction technique that provides a lower dimensional linear representation of the data. It is well known that PCA projections can be computed using the Singular Value Decomposition (SVD) [47] by computing matrices $U$, $\Sigma$, $V^T$ from the given data, as shown in (8):

$$X = U \Sigma V^T \tag{8}$$

where $X$ is our normalized training data,² $U$ and $V$ are unitary matrices, and $\Sigma$ is a diagonal matrix. In Eq. (9), we observe that $V$ can be seen as a mapping of data from a high dimensional space to a lower dimensional linear subspace (an encoder in the language of AEs). In addition, $V^T$ is a mapping from a lower dimensional linear subspace back to the original space of the provided data (a decoder in the language of AEs). This behaviour of the SVD enables us to initialize the weights of an appropriately defined AE. In particular, when $\lambda = 0$ in (7), $f$ and $g$ are in fact linear functions and (7) can be minimized using the SVD [47]. More precisely, we use the first $n$ columns of $V$ for the encode layer with width $n$, and for the decode layer we use the transpose of the weights used for the encoder, as shown in (12) and as in [50].

$$X = U \Sigma V^T \tag{9}$$

$$X V = U \Sigma V^T V \tag{10}$$

$$X V = U \Sigma \tag{11}$$

$$W^{n}_{encoder} = V_{n\text{-columns}} \tag{12}$$

$$W^{n}_{decoder} = (W^{n}_{encoder})^T \tag{13}$$

where $W^n$ represents the weight matrix of the encoder layer with $n$ as its width.

² For PCA, the data is normalized so that the mean of each column is 0.

There are multiple advantages to initializing an AE using PCA. First, we start the deep learning training from a solution to a linear problem rather than trying to solve a non-linear, and often non-convex, optimization problem. Having such a straightforward initialization procedure allows the optimization of the AE to begin from a known global optimum. Second, in the NPACS method, the C-Activation defines the homotopy as a continuous deformation from a linear to a non-linear network. Accordingly, the global optimum of the linear problem can be transformed into a solution of the non-linear and non-convex problem of interest in a principled fashion. Finally, as PCA provides a deterministic initialization that does not require any sampling of the parameter space [10, 44, 60], our proposed method also has a computational advantage. Of course, the idea of initializing a deep feedforward network with PCA is not novel and has been independently explored in the literature [7, 36, 57]. However, our proposed algorithm is unique in that it leverages powerful techniques for solving homotopy problems using parameter continuation.
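The initialization of Eqs. (12) and (13) can be sketched with NumPy's SVD routine. The data, the layer width `n` and the helper name `pca_init` are hypothetical and chosen only for illustration; the sketch also assumes a single encoder layer and a single tied decoder layer with no biases.

```python
import numpy as np


def pca_init(X, n):
    """Return (W_encoder, W_decoder) built from the top-n right singular vectors.

    X is assumed to be column-centred, as PCA requires.
    """
    # full_matrices=False gives the thin SVD: X = U @ diag(s) @ Vt.
    U, s, Vt = np.linalg.svd(X, full_matrices=False)
    W_encoder = Vt[:n].T      # first n columns of V, shape (d, n), Eq. (12)
    W_decoder = W_encoder.T   # transpose of the encoder weights, Eq. (13)
    return W_encoder, W_decoder


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 8)) @ rng.normal(size=(8, 8))  # hypothetical data
X = X - X.mean(axis=0)                                   # centre each column

n = 3
W_enc, W_dec = pca_init(X, n)
recon = X @ W_enc @ W_dec   # output of the linear AE obtained at lambda = 0

# Best rank-n approximation of X from the truncated SVD, for comparison.
U, s, Vt = np.linalg.svd(X, full_matrices=False)
best_rank_n = (U[:, :n] * s[:n]) @ Vt[:n]
```

With these weights the linear AE reproduces the truncated SVD of `X`, which is the optimal rank-`n` reconstruction in the least-squares sense; this is the sense in which training can begin from a known global optimum of the linear problem.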
4.4 Natural Parameter Continuation (NPC)

Natural Parameter Continuation (NPC) is an adaptation of standard iterative optimization solvers to a homotopy-type problem [52]. Continuing our explanation from Sect. 4.2, the solution at $\lambda_i$ can be used as the initial guess for the solution at $\lambda_{i+1} = \lambda_i + \Delta\lambda_i$. With $\Delta\lambda_i$ sufficiently small, the iteration applied to the initial solution should converge to a local minimum [2]. In Fig. 2, we show the method used to traverse the homotopy path. Next, we are asked to find the solution of every problem $J(x, g(f(x)); \theta, \lambda_i)$ along the way. The solution set of such an $N$-dimensional ($\theta$) problem lies on a one-dimensional manifold (the homotopy path) embedded in an $(N+1)$-dimensional space [48, 52], as we show schematically in the 2D plot of Fig. 2. Note that, using PCA, we know the solution at $\lambda_0 = 0$; if $\Delta\lambda_0$ is made sufficiently small, then our solution at $\lambda_1$, for example, should be very close to the PCA solution. Hence, we initialize the optimization problem at $\lambda_1$ with the solution at $\lambda_0$. We repeat this iteratively until the end of the homotopy to find the solution to the problem of interest. Generally, this may be referred to as Natural Parameter Continuation (NPC) [2, 48], as shown in Fig. 2, where the solution of the optimization problem at $\lambda_i$ is used to initialize the problem at $\lambda_{i+1}$:

$$\theta^{init}_{\lambda_{i+1}} \leftarrow \theta_{\lambda_i} \tag{14}$$

After initializing, we may require a standard optimization technique, such as a few steps of the ADAM algorithm, to find the optimal solution at $\lambda_{i+1}$ (Fig. 2).

The main challenge in a deep learning framework is that there are too many hyperparameters. We have different network architectures, activation functions, batch sizes, etc., and coming up with a rule of thumb is difficult. In particular, a fixed step size has two main disadvantages: (1) we may take many steps where they are not required, and (2) we may take too few steps where cautious updates are required, depending on the nature of the homotopy path [52]. Figure 3 illustrates these disadvantages. To address this, our next method, described in Sect. 4.5, is inspired by the parameter continuation literature: we use an adaptive solution that can reasonably determine the nature of the homotopy path and update $\lambda$ accordingly.
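The warm-start loop of Eq. (14) can be sketched on a toy problem with a single scalar parameter. Everything specific here is an assumption made for illustration: the one-dimensional regression data, the C-Tanh model, the fixed step $\Delta\lambda = 1/20$, the inner step count and learning rate, and above all the use of plain gradient descent in place of ADAM to keep the example dependency-free.

```python
import math

# Toy 1-D regression data (hypothetical): targets generated by tanh(1.5 * x).
xs = [k / 10 - 1.0 for k in range(21)]
ys = [math.tanh(1.5 * x) for x in xs]


def loss_and_grad(theta, lam):
    """Mean squared error of the C-Tanh model (1-lam)*v + lam*tanh(v), v = theta*x."""
    loss, grad = 0.0, 0.0
    for x, y in zip(xs, ys):
        t = math.tanh(theta * x)
        residual = y - ((1.0 - lam) * theta * x + lam * t)
        d_act = x * ((1.0 - lam) + lam * (1.0 - t * t))  # d(model output)/d(theta)
        loss += residual ** 2
        grad += -2.0 * residual * d_act
    return loss / len(xs), grad / len(xs)


def npc(n_stages=20, inner_steps=200, lr=0.5):
    """Natural parameter continuation with a fixed step 1 / n_stages in lambda."""
    # lambda_0 = 0: the model is linear, so least squares gives theta in closed form.
    theta = sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)
    theta0 = theta
    for i in range(1, n_stages + 1):
        lam = i / n_stages            # lambda_i = lambda_{i-1} + delta_lambda
        # Warm start, Eq. (14): theta still holds the solution at lambda_{i-1}.
        for _ in range(inner_steps):  # plain gradient descent stands in for ADAM
            _, g = loss_and_grad(theta, lam)
            theta -= lr * g
    return theta0, theta


theta0, theta_final = npc()
loss_cold, _ = loss_and_grad(theta0, 1.0)         # linear solution judged at lambda = 1
loss_final, _ = loss_and_grad(theta_final, 1.0)   # continuation result at lambda = 1
```

The closed-form linear solution plays the role of the PCA initialization, and each stage only has to correct the small drift of the minimizer caused by one step in $\lambda$. Because the step is fixed, the loop spends the same effort at every stage, which is the limitation the adaptive methods are designed to address.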
Fig. 3 This figure provides intuition for our adaptive methods. We show two possible sets of steps, where green uses an adaptive $\Delta\lambda$ and grey uses a fixed $\Delta\lambda$. We want to take larger steps where the homotopy path is flat and shorter steps where the homotopy path is more oscillatory
4.5 NPC with Adaption for Continuation Steps (NPAC)

In NPC we had a user-specified configuration for a fixed $\Delta\lambda$. In practice, we found that coming up with a precise $\Delta\lambda$ is difficult. In the field of continuation methods and bifurcation analysis, the idea of adaptive step sizes is well known and is used in software packages such as AUTO2000 [48] and AUTO-07P [11]. Accordingly, we borrow this idea in our implementation of neural network training. First, we observe that the behaviour of the activation function plays an important role in the nature of the homotopy path. For example, in Fig. 4, at $\lambda = 0.7$ we observe that C-Sigmoid is mostly linear while C-ReLU is gradually transforming into a ReLU, demonstrating two different possible behaviours along the homotopy path. Accordingly, at $\lambda = 0.7$ the algorithm should make longer $\Delta\lambda$ updates in the case of C-Sigmoid, but shorter ones in the case of C-ReLU. In our experiments, we report similar observations in Table 1. For NPC methods we used fixed updates that we determined empirically, such as $\Delta\lambda = 0.008$ for C-ReLU; however, for C-Sigmoid and C-Tanh we had to be more careful in determining $\Delta\lambda$ values: for $\lambda < 0.8$ we used $\Delta\lambda = 0.02$, and for $\lambda \ge 0.8$ we used $\Delta\lambda = 8\mathrm{e}{-4}$. This is because we
Fig. 4 This figure, from left to right, illustrates how C-Sigmoid, C-ReLU and C-Tanh behave at λ = 0.7 on uniformly distributed points in [−10, 10]
Table 1 This table shows the train and test loss values of different optimization techniques using a specified network. All these experiments were computed for 50,000 backpropagation steps, and we report the averages of the last 100 loss values for both training and testing. Perhaps not surprisingly, SGD does the worst in almost all cases. More interestingly, RMSProp and ADAM both do well in some cases, and quite badly in others. Note that the various parameter continuation methods all have quite stable properties and achieve a low loss in all cases and the lowest loss in most cases

Network         SGD       RMSProp   ADAM      NPC       NPAC      NPACS

FashionMNIST (Train)
AE-8 Sigmoid    0.1133    0.03915   0.03370   0.03402   0.03370   0.03388
AE-8 ReLU       0.11122   0.03582   0.03318   0.03171   0.03188   0.03191
AE-8 Tanh       0.10741   0.03459   0.03515   0.03573   0.03552   0.03559
AE-16 Sigmoid   0.11397   0.06714   0.06714   0.03418   0.04505   0.03461
AE-16 ReLU      0.11394   0.06714   0.03436   0.03474   0.03445   0.03659
AE-16 Tanh      0.10889   0.03419   0.03540   0.03753   0.03722   0.03622

CIFAR-10 (Train)
AE-8 Sigmoid    0.28440   0.03861   0.03352   0.03275   0.03224   0.03238
AE-8 ReLU       0.07689   0.03467   0.03421   0.03459   0.03461   0.03302
AE-8 Tanh       0.27565   0.03355   0.03421   0.03343   0.03392   0.03408
AE-16 Sigmoid   0.28717   0.06223   0.06223   0.03480   0.03310   0.03517
AE-16 ReLU      0.07722   0.03512   0.03419   0.03400   0.03456   0.03463
AE-16 Tanh      0.27884   0.03496   0.03452   0.03637   0.03815   0.03405
Maximum         0.28440   0.06714   0.06714   0.03753   0.04505   0.03659

FashionMNIST (Test)
AE-8 Sigmoid    0.11333   0.08508   0.08525   0.08257   0.08324   0.08170
AE-8 ReLU       0.11123   0.08202   0.08076   0.08225   0.08160   0.08154
AE-8 Tanh       0.10742   0.08571   0.07904   0.07225   0.07306   0.07447
AE-16 Sigmoid   0.11397   0.11396   0.11383   0.07405   0.09050   0.07912
AE-16 ReLU      0.11394   0.11396   0.08035   0.08103   0.07993   0.08441
AE-16 Tanh      0.10891   0.07736   0.07713   0.07899   0.07519   0.08025

CIFAR-10 (Test)
AE-8 Sigmoid    0.28440   0.07835   0.06526   0.06344   0.06417   0.06522
AE-8 ReLU       0.28440   0.03861   0.03352   0.03276   0.03225   0.03239
AE-8 Tanh       0.27568   0.09993   0.08604   0.06801   0.07580   0.07544
AE-16 Sigmoid   0.28718   0.28671   0.28648   0.07735   0.07588   0.10232
AE-16 ReLU      0.07724   0.05507   0.05321   0.05188   0.06009   0.05564
AE-16 Tanh      0.27887   0.08614   0.08939   0.10874   0.12128   0.08714
Maximum         0.28718   0.28671   0.28648   0.10874   0.12128   0.10232
For each row, the largest (worst) loss is shown in red, and the lowest (best) loss is shown in green
observed that C-Sigmoid behaves almost linearly until λ = 0.8, after which it rapidly adapts to the Sigmoid. Therefore, we needed an adaptive method that can reasonably tune the λ update for different activation functions [52] and also adapt to the nature of the homotopy path. We elaborate on these choices of the step size Δλ in Sect. 5. Accordingly, we developed Algorithm 1, an adaptive method that determines the λ update from the gradient information available during backpropagation. NPAC has two benefits. First, it removes the need to hand-pick the value of Δλ. Second, it provides a reasonable way to determine how close the next suitable problem (i.e. the next λ value) should be, so that the current solution (say, at λ = 0.25) is a good initialization for it.

Algorithm 1 Adaptive Δλ
Require: norm_grads, a list of the norms of the gradients of the objective function from the previous t steps; $\Delta\lambda_{i,i-1}$, the step size of the previous step; scale_up and scale_down factors.
 1: avg_norm_p ← mean of first half of norm_grads
 2: avg_norm_c ← mean of second half of norm_grads
 3: condition1 ← (avg_norm_p − avg_norm_c) / avg_norm_p < −tolerance
 4: condition2 ← (avg_norm_p − avg_norm_c) / avg_norm_p > tolerance
 5: if condition1 then
 6:     $\Delta\lambda_{i+1,i} \leftarrow \Delta\lambda_{i,i-1} \cdot$ scale_up
 7: else if condition2 then
 8:     $\Delta\lambda_{i+1,i} \leftarrow \Delta\lambda_{i,i-1} /$ scale_down
 9: else
10:     $\Delta\lambda_{i+1,i} \leftarrow \Delta\lambda_{i,i-1}$
11: end if
12: return $\Delta\lambda_{i+1,i}$
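As a rough sketch, Algorithm 1 translates into Python as follows. The function name `adaptive_step` and the default values of `scale_up`, `scale_down` and `tolerance` are hypothetical placeholders, since the listing does not state the values used in the experiments. The branch directions follow the listing as printed.

```python
from statistics import mean


def adaptive_step(norm_grads, prev_step, scale_up=2.0, scale_down=2.0, tolerance=0.1):
    """One application of Algorithm 1: return the next lambda step size.

    norm_grads holds gradient norms from the previous t solver steps;
    the default factors and tolerance are illustrative assumptions.
    """
    if len(norm_grads) < 2:
        raise ValueError("need at least two gradient norms")
    half = len(norm_grads) // 2
    avg_norm_p = mean(norm_grads[:half])  # earlier half of the window
    avg_norm_c = mean(norm_grads[half:])  # more recent half of the window
    rel_change = (avg_norm_p - avg_norm_c) / avg_norm_p
    if rel_change < -tolerance:   # condition1
        return prev_step * scale_up
    if rel_change > tolerance:    # condition2
        return prev_step / scale_down
    return prev_step              # change within tolerance: keep the step


grown = adaptive_step([1.0, 1.0, 2.0, 2.0], prev_step=0.008)
shrunk = adaptive_step([2.0, 2.0, 1.0, 1.0], prev_step=0.008)
kept = adaptive_step([1.0, 1.0, 1.0, 1.0], prev_step=0.008)
print(grown, shrunk, kept)
```

Because the comparison is relative to `avg_norm_p`, the rule is insensitive to the overall scale of the gradients; only the trend across the window matters.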
4.6 Natural Parameter Adaptive Continuation with Secant Approximation (NPACS)

NPACS is a more advanced version of NPAC, in which we enhance our method with a secant step in the (multidimensional) θ space in addition to the adaptive λ update. A secant line to a curve is one that passes through at least two distinct points of the curve [31], and we use it to find a linear approximation of the homotopy path in a particular neighbourhood [2]. In the previous two methods, we simply assigned the previous solution as the initialization for the current problem, whereas in NPACS we take the previous two solutions and apply a secant approximation to initialize the next step. An important detail is to properly normalize the secant step for the parameters (θ) of a neural network, which commonly have thousands or even millions of dimensions, depending on the network size [52]. In Fig. 5, we illustrate the geometric interpretation of Eq. (15). A clear advantage of NPACS is that a secant update follows the homotopy curve more closely to approximate the derivative of
Fig. 5 This figure shows the secant update in our NPACS method. Unlike other continuation methods, here we utilize the previous two solutions to draw a secant vector in the multidimensional space (θ)
the curve [52] and initializes the subsequent problem accordingly. We also present Algorithm 2, which lists all the steps required to implement NPACS. Here we perform model continuation for an AE (of depth 8 and 16), using ADAM updates to solve the optimization problem at each λ_i. Depending on the problem, any other optimizer may be selected instead.

$$\theta^{\mathrm{init}}_{\lambda_{i+1}} \leftarrow \theta_{\lambda_i} + (\theta_{\lambda_i} - \theta_{\lambda_{i-1}}) \cdot \frac{\Delta\lambda_{i+1,i}}{\Delta\lambda_{i,i-1}} \qquad (15)$$
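The secant initialization of Eq. (15) is a one-line extrapolation once the two previous solutions are available. The sketch below treats θ as a flat list of floats; the function name `secant_init` and the sample numbers are illustrative assumptions, and a real network would apply the same arithmetic to each weight tensor.

```python
def secant_init(theta_curr, theta_prev, step_next, step_prev):
    """Eq. (15): extrapolate along the secant through the last two solutions.

    step_next / step_prev is the ratio of the upcoming lambda step size to the
    previous one, so a longer upcoming step extrapolates further.
    """
    ratio = step_next / step_prev
    return [c + (c - p) * ratio for c, p in zip(theta_curr, theta_prev)]


# Two previous solutions (made-up values) and a step that doubles in size.
theta_prev = [1.0, 0.0]
theta_curr = [1.5, -0.2]
theta_init = secant_init(theta_curr, theta_prev, step_next=0.02, step_prev=0.01)
print(theta_init)
```

When the ratio is zero the expression reduces to the plain warm start of Eq. (14), so the secant step can be read as Eq. (14) plus a first-order correction along the direction the solution has been moving.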
5 Experimental Results

We used two popular datasets for our experiments, namely, CIFAR-10 [37] and Fashion-MNIST [67]. Fashion-MNIST has 55,000 training images and 10,000 test images with ten different classes, and each image has 784 dimensions. Similarly, for CIFAR-10 we used 40,000 images for training and 10,000 as a test/validation set. Each image in this dataset has 3072 (32 × 32 × 3) dimensions, again with ten different classes. CIFAR-10 is a more challenging dataset than Fashion-MNIST [52] and is also widely used by researchers. Next, we use autoencoders (AE), an unsupervised learning method, to test our optimization technique. The employed AE is shown in Fig. 6. In the case of the Fashion-MNIST dataset, the input is a 784-dimensional image (or a 3072-dimensional image for the CIFAR-10 dataset). The AE is then used to reconstruct the image from only a two-dimensional representation, which is encoded in the hidden layer. In particular, two neural networks are evaluated in all our experiments, namely AE-8 and AE-16, of depth 8 and 16, respectively. We compare our parameter continuation methods, NPC (with step size Δλ = 8e−3), NPAC and NPACS, against existing methods such as ADAM, SGD and RMSProp. Task consistency plays a key role in providing conclusive empirical evaluations, and we achieve it by keeping the data, architecture and hyperparameters fixed. As explained in Sect. 4.3, SVD is used for initialization of our network consistently
Algorithm 2 NPACS: model continuation for an AE using ADAM
Require: learning rate $\epsilon$; u_freq, the number of ADAM steps to perform after every continuation step; initial step size $\Delta\lambda$; initial neural network parameters $\theta$ obtained using PCA; initial homotopy parameter $\lambda$; t, the number of steps after which the adaptive behaviour starts.
 1: k ← 1
 2: i ← 1
 3: norm_grads ← [ ]
 4: loss_history ← [ ]
 5: while stopping criteria not met do
 6:     Sample a minibatch from data x
 7:     Compute loss: loss ← $\sum_j J(x, g(f(x)); \theta; \lambda_i)$
 8:     loss_history ← Append loss
 9:     Compute gradient estimate $\hat{g}$ ← ADAM()
10:     norm_grads ← Append $\|\hat{g}\|_2$
11:     Apply gradient $\theta \leftarrow \theta - \epsilon \cdot \hat{g}$
12:     k ← k + 1
13:     if k % u_freq == 0 and u_freq > 0 then
14:         if k % t == 0 and k > 0 then
15:             $\Delta\lambda_{i+1,i}$ ← Compute Adaptive Δλ (Algorithm 1)
16:         end if
17:         $\theta^{\mathrm{init}}_{\lambda_{i+1}} \leftarrow \theta_{\lambda_i} + (\theta_{\lambda_i} - \theta_{\lambda_{i-1}}) \cdot \Delta\lambda_{i+1,i} / \Delta\lambda_{i,i-1}$
18:         i ← i + 1
19:         Compute loss: loss ← $\sum_j J(x, g(f(x)); \theta; \lambda_i)$
20:         loss_history ← Append loss
21:         k ← k + 1
22:     end if
23: end while

Fig. 6 This figure shows the autoencoder (AE-8) we used in our experiments. It is an eight-layer deep network with the widths shown in the block diagram. Additionally, we apply one particular activation function to all the hidden layers of the network, except the code and the output layer. We specify these while reporting results
across all our experiments. This provides us with an exact vector of parameters (θ) at λ = 0. The step size Δλ was selected using a line search between 8e−5 and 2e−2, and we used the best-performing value, Δλ = 8e−3, for the NPC method. For the NPAC and NPACS methods, Δλ was chosen by our adaptive Algorithm 1. Another important hyperparameter is the number of ADAM (or any other solver) steps between two continuation steps. Again, we performed a linear search over all values between 5 and 500, found that 10 ADAM steps work best, and use this value consistently. Also
note that Sigmoid, Tanh and ReLU are used with the SGD, ADAM and RMSProp optimization techniques, whereas the continuation methods use their continuation counterparts, such as C-Sigmoid and C-ReLU from (5). Finally, we are going to open-source our code and all the hyperparameter choices on GitHub (see footnote 3). Three important metrics for testing a new optimizer are the qualitative analysis of training loss, generalization loss and convergence speed [28]. Table 1 reports the training and validation loss of our network (six variants, or tasks) with six different optimization methods on two popular datasets. Table 1 indicates that our continuation methods consistently perform better on both validation and training loss. There are a few more interesting conclusions that can be drawn from Table 1. First, as expected, ADAM, RMSProp and the continuation methods performed better than SGD. This also shows that the optimization tasks in the experiment are not trivial. Second, our network variant AE-16 Sigmoid turns out to be the most challenging task, where ADAM and RMSProp get stuck in a bad local minimum, which shows empirically that such networks are difficult to train. However, our methods were able to avoid these sub-optimal local minima and achieved 49.18% lower training loss with 34.94% better generalization as compared to ADAM, as shown in Table 1. Similar results were observed with the Fashion-MNIST dataset. Another optimization bottleneck was seen in the case of AE-16 ReLU with the Fashion-MNIST data: here RMSProp was unable to avoid a sub-optimal local minimum, whereas our methods had 48.88% lower training loss and 30.08% better generalization (testing) error. Finally, we report the maximum loss attained by the various optimizers at the final step. As Table 1 shows, our parameter continuation methods have a much lower maximum loss across all distinct tasks. Further, to report convergence speed we analyze the convergence plots of all the distinct tasks.
Among the plausible ways to define the speed of convergence, in this chapter we consider an optimizer faster if it obtains a lower training loss at an earlier step than a competing optimizer [52]. Hence, we studied the convergence plots of the various optimizers and found that the continuation methods, i.e. NPC, NPAC and NPACS, converged faster in the majority of the tasks. In Fig. 7, we can see that for the AE-16 and AE-8 Sigmoid networks the continuation methods have a much lower training error from the very beginning of training as compared to both the RMSProp and ADAM optimizers. This is consistent with our results in Table 1. Additionally, we extend our results to three standard activation functions: C-Sigmoid in Fig. 7, C-ReLU in Fig. 8 and C-Tanh in Fig. 9. Through these figures we show not only that our generalization loss is better in most cases but also that the continuation methods converge faster (i.e. achieve a lower training loss in fewer ADAM updates).
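One way to make this notion of convergence speed operational is to record the first step at which each loss curve reaches a chosen target and compare those steps. The helper names below and the two short loss curves are illustrative assumptions (synthetic numbers, not results from the experiments), and the chapter's plots may have been compared differently.

```python
def first_step_below(losses, target):
    """Index of the first step whose loss is <= target, or None if never reached."""
    for step, loss in enumerate(losses):
        if loss <= target:
            return step
    return None


def is_faster(losses_a, losses_b, target):
    """True if curve A reaches the target loss at an earlier step than curve B."""
    step_a = first_step_below(losses_a, target)
    step_b = first_step_below(losses_b, target)
    if step_a is None:
        return False
    return step_b is None or step_a < step_b


# Synthetic training-loss curves, one value per logged step.
curve_a = [0.30, 0.12, 0.06, 0.04, 0.035]
curve_b = [0.30, 0.20, 0.11, 0.07, 0.05]
print(first_step_below(curve_a, 0.06), first_step_below(curve_b, 0.06))
print(is_faster(curve_a, curve_b, 0.06))
```

The verdict depends on the chosen target, so in practice it is worth checking several targets, or inspecting the full curves as the figures do.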
3 https://github.com/harsh306/NPACS.
Fig. 7 The figures above demonstrate the convergence of various optimization methods when C-Sigmoid is used as an activation function. In all cases, the X-axis shows the number of iterations, and the Y-axis shows the log-loss or log-validation-loss. Note that the methods we propose provide some of the lowest loss values, for training and testing, throughout the optimization procedure
Fig. 8 The figures above demonstrate the convergence of various optimization methods when C-ReLU is used as an activation function. In all cases, the X-axis shows the number of iterations, and the Y-axis shows the log-loss or log-validation-loss. Note that the methods we propose provide some of the lowest loss values, for training and testing, throughout the optimization procedure
Fig. 9 The figures above demonstrate the convergence of various optimization methods when C-Tanh is used as an activation function. In all cases, the X-axis shows the number of iterations, and the Y-axis shows the log-loss or log-validation-loss. Note that the methods we propose provide some of the lowest loss values, for training and testing, throughout the optimization procedure
6 Open Problems and Future Directions

Continuation methods have a long history of solving non-convex optimization problems. Recently there has been great progress in enhancing deep learning training using methods that are akin to continuation methods, for example, data curriculum strategies [3, 23, 32, 66]. We believe this trend will continue and have an enormous impact on more fields, such as reinforcement learning and meta-learning. The paper on AI-generating algorithms (AI-GAs) [9] proposed three pillars for the future of general AI, namely, (1) meta-learning algorithms, (2) meta-learning architectures and (3) generating effective learning environments. We believe that continuation methods are a good candidate to assist in these directions in a principled manner. Next, we discuss a few open problems for the deep learning community.
6.1 Model Continuation

In [3, 50], the authors illustrated that continuation strategies can be introduced in deep learning through model and data continuation. The methods illustrated in this chapter are classical examples of model continuation. In Sect. 3, we discussed different model continuation methods [21, 22, 44] and their advantages. These methods compared their training strategy with a target model and reported better generalization error. However, we observe some limitations in these approaches. Ideally, the proposed continuation techniques would be empirically evaluated using different neural architectures and a variety of tasks, although we understand that such detailed evaluations may require substantial effort. First, we observed that many different types of neural networks have not yet been systematically tested. For example, diffusion methods were applied to RNNs only [44], MLPs and LSTMs were tested in [21], and only mollification was tested on CNNs, RNNs and MLPs [22, 54]. In future research, it will be interesting to see how the proposed curriculum affects the convergence of different types of networks. Second, some methods [3, 22] were tested with limited depth, and the networks used may not necessarily be categorized as deep networks. Third, applications of curriculum learning may benefit from more empirical evaluations of the methods in the literature. In particular, it will be interesting to see comparisons with techniques such as ResNets [38], Dropout [58] and different normalizations on more tasks, such as classification, language modelling, regression and reconstruction [22]. As far as this chapter is concerned, there are several natural extensions. First, the C-Activation can be any activation function, and in the case of multiple activation functions in one network, we may add a separate homotopy parameter for each of them. Also, as stated earlier, there is nothing in these techniques that makes them specific to the AE case. In particular, we would also like to train the Convolutional and Recurrent Neural Networks described in [15, 39, 65] using our method. Second, our adaptive method can be improved to conform to the homotopy path more accurately by using pseudo-arclength methods [34, 48].
6.2 Data Continuation

Many researchers have shown that learning from a designed data curriculum leads to better learning and robustness to out-of-distribution error. Some of the popular methods are progressive training [32, 50], Curriculum Learning [3, 23], Curriculum Learning by Transfer Learning [66], C-SMOTE [50], etc. These methods have shown improvements in generalization performance. However, there is usually some limitation imposed by the data type. For example, the design of Progressive GANs [32] is applicable to images but may prove difficult to apply to a text dataset. For future research, it will be interesting to see how data curriculum methods work with different kinds of data, such as images, text, sound and time series. As a first step in this direction, we collect and share a list of datasets (see footnote 4). These datasets are indexed from different sources, and we hope they prove to be a useful resource for fellow researchers. Recent research illustrates how a data curriculum can be learned via continuation [50] and how automation can be introduced by measuring a model's learning [18, 30, 55, 59, 66, 68]. These works use limited benchmarks and could likely be extended in several directions.
6.3 Loss Surface Understanding

The loss surface of a deep neural network is usually non-convex [8, 17, 27], and, depending on the neural architecture, many attempts have been made to theoretically categorize the loss surfaces [17, 33, 41]. Continuation methods or curriculum learning may provide a unique perspective for understanding the loss surfaces of neural networks. In particular, we plan to extend our work to perform bifurcation analysis of neural networks. Bifurcations [2, 11, 49] are the dynamic changes in the behaviour of the system that may occur while we track the solution along the homotopy path. In our case, as we change our activation from linear to non-linear, detected bifurcations [34] of neural networks may help us explain and identify local minima, and they may also help us to better understand the so-called "black box" of neural networks.
6.4 Environment Generation Curricula

In reinforcement learning, an agent takes actions in an environment, which in turn provides rewards or feedback that inform the next action. Similar to data continuation, we can think of generating training environments of varied complexity for reinforcement learning. For example, Reward Shaping [13, 20] is a method where a user defines a curriculum of environments for the agent. However, while a curriculum may be essential for some tasks in principle, in practice it is challenging to know the right curriculum for any given task [63]. In [63], the authors empirically show that sometimes an intuitively sound curriculum may even harm training (i.e. learning harder tasks can lead to better solutions [56]). As a result, many research works have designed self-generative curricula for training environments that introduce progressively difficult and diverse scenarios for an agent and thus increase robust learning [1, 59, 63]. One can easily imagine continuation methods having a role to play in improving such methods.

4 Dataset collection: https://github.com/harsh306/curriculum-datasets.
6.5 Multi-task Learning

In multi-task learning [6], a model is designed to perform multiple tasks through the same set of parameters. Designing such a system is more difficult, as learning from one task may improve or degrade the learning of another task and may also lead to catastrophic forgetting [4]. Synchronizing a neural network to perform multiple tasks is an open research problem, and recently curriculum learning has been used to meet this challenge [53, 64]. These approaches can be broadly categorized in two ways. The first designs a curriculum that dynamically provides a selected batch of data so that it benefits the multi-task objective [61, 64]. Implementing such an automated system usually requires determining the importance of each instance for learning a task, for which Bayesian optimization is a classic choice [64]. The second determines the best order in which the tasks should be learned [53]. This latter approach must be implemented sequentially, following the determined curriculum [53]. We expect that further research in this direction may significantly enhance multi-task learning. Also, of the two approaches mentioned above, the first is similar to data continuation and the second is similar to model continuation. We believe that we can draw on the similarities between the two and make impactful progress in the field of multi-task learning.
6.6 Hyperparameter Optimization

Meta-learning is an active area of research [14] which was recently surveyed in [9]. We understand that there are many aspects of meta-learning [14, 46, 62], but for the scope of our discussion we limit our focus to hyperparameter search. The search space of different hyperparameters can be traced efficiently using continuation methods, or even multi-parameter continuation [2, 11, 34]. One can imagine two hyperparameters λm and λd, where λm corresponds to model continuation and λd enables data continuation. Then, we can apply continuation methods to optimize both strategies simultaneously for a given neural network task. Recently, researchers used implicit differentiation for hyperparameter optimization [42], in which millions of network weights and hyperparameters can be jointly tuned. Multi-parameter continuation optimization is a promising area of research in fields such as mathematical analysis [2, 11, 34] and others [19, 29]. One may
define and optimize most of the components or hyperparameters of a deep learning framework via a continuation scheme. We may derive more efficient methods for hyperparameter search that could be capable of evolving from a generalized network by adding or pruning layers, along with additional similar ideas.
7 Conclusions

In this chapter, we presented a novel training strategy for deep neural networks. As observed in most of the experiments, all three continuation methods achieved faster convergence, lower loss and better generalization than standard optimization techniques. We also showed empirically that the proposed methods work with popular activation functions, deeper networks and distinct datasets. Finally, we examined some possible improvements and provided various future directions for advancing deep learning using continuation methods.
References

1. I. Akkaya, M. Andrychowicz, M. Chociej, M. Litwin, B. McGrew, A. Petron, A. Paino, M. Plappert, G. Powell, R. Ribas et al., Solving rubik’s cube with a robot hand (2019). arXiv preprint arXiv:1910.07113
2. E. Allgower, K. Georg, Introduction to numerical continuation methods. Soc. Ind. Appl. Math. (2003). https://epubs.siam.org/doi/abs/10.1137/1.9780898719154
3. Y. Bengio, J. Louradour, R. Collobert, J. Weston, Curriculum learning (2009)
4. Y. Bengio, M. Mirza, I. Goodfellow, A. Courville, X. Da, An empirical investigation of catastrophic forgetting in gradient-based neural networks (2013)
5. Z. Cao, M. Long, J. Wang, P.S. Yu, Hashnet: deep learning to hash by continuation. CoRR (2017). arXiv:abs/1702.00758
6. R. Caruana, Multitask learning. Mach. Learn. 28(1), 41–75 (1997)
7. T. Chan, K. Jia, S. Gao, J. Lu, Z. Zeng, Y. Ma, Pcanet: a simple deep learning baseline for image classification? IEEE Trans. Image Process. 24(12), 5017–5032 (2015). https://doi.org/10.1109/TIP.2015.2475625
8. A. Choromanska, M. Henaff, M. Mathieu, G.B. Arous, Y. LeCun, The loss surface of multilayer networks. CoRR (2014). arXiv:abs/1412.0233
9. J. Clune, Ai-gas: ai-generating algorithms, an alternate paradigm for producing general artificial intelligence. CoRR (2019). arXiv:abs/1905.10985
10. T. Dick, E. Wong, C. Dann, How many random restarts are enough
11. E.J. Doedel, T.F. Fairgrieve, B. Sandstede, A.R. Champneys, Y.A. Kuznetsov, X. Wang, Auto-07p: continuation and bifurcation software for ordinary differential equations (2007)
12. J. Duchi, E. Hazan, Y. Singer, Adaptive subgradient methods for online learning and stochastic optimization. J. Mach. Learn. Res. 12, 2121–2159 (2011). http://dl.acm.org/citation.cfm?id=1953048.2021068
13. T. Erez, W.D. Smart, What does shaping mean for computational reinforcement learning? in 2008 7th IEEE International Conference on Development and Learning (2008), pp. 215–219. https://doi.org/10.1109/DEVLRN.2008.4640832
14. C. Finn, P. Abbeel, S. Levine, Model-agnostic meta-learning for fast adaptation of deep networks, in Proceedings of the 34th International Conference on Machine Learning, vol. 70 (JMLR.org, 2017), pp. 1126–1135
15. I. Goodfellow, Y. Bengio, A. Courville, Deep Learning (MIT Press, 2016). http://www.deeplearningbook.org
16. I.J. Goodfellow, NIPS 2016 tutorial: generative adversarial networks. NIPS (2017). arXiv:abs/1701.00160
17. I.J. Goodfellow, O. Vinyals, Qualitatively characterizing neural network optimization problems. CoRR (2014). arXiv:abs/1412.6544
18. A. Graves, M.G. Bellemare, J. Menick, R. Munos, K. Kavukcuoglu, Automated curriculum learning for neural networks. CoRR (2017). arXiv:abs/1704.03003
19. C. Grenat, S. Baguet, C.H. Lamarque, R. Dufour, A multi-parametric recursive continuation method for nonlinear dynamical systems. Mech. Syst. Signal Process. 127, 276–289 (2019)
20. M. Grzes, D. Kudenko, Theoretical and empirical analysis of reward shaping in reinforcement learning, in 2009 International Conference on Machine Learning and Applications (2009), pp. 337–344. https://doi.org/10.1109/ICMLA.2009.33
21. C. Gülçehre, M. Moczulski, M. Denil, Y. Bengio, Noisy activation functions. CoRR (2016). arXiv:abs/1603.00391
22. C. Gülçehre, M. Moczulski, F. Visin, Y. Bengio, Mollifying networks. CoRR (2016). arXiv:abs/1608.04980
23. G. Hacohen, D. Weinshall, On the power of curriculum learning in training deep networks. CoRR (2019). arXiv:abs/1904.03626
24. G. Hinton, L. Deng, D. Yu, G.E. Dahl, A. Mohamed, N. Jaitly, A. Senior, V. Vanhoucke, P. Nguyen, T.N. Sainath, B. Kingsbury, Deep neural networks for acoustic modeling in speech recognition: the shared views of four research groups. IEEE Signal Process. Mag. 29(6), 82–97 (2012). https://doi.org/10.1109/MSP.2012.2205597
25. G. Hinton, N. Srivastava, K. Swersky, Rmsprop: divide the gradient by a running average of its recent magnitude. Neural networks for machine learning, Coursera lecture 6e (2012)
26. G.E. Hinton, R.R. Salakhutdinov, Reducing the dimensionality of data with neural networks. Science 313(5786), 504–507 (2006). https://doi.org/10.1126/science.1127647, http://science.sciencemag.org/content/313/5786/504
27. D.J. Im, M. Tao, K. Branson, An empirical analysis of deep network loss surfaces. CoRR (2016). arXiv:abs/1612.04010
28. D. Jakubovitz, R. Giryes, M.R. Rodrigues, Generalization error in deep learning, in Compressed Sensing and Its Applications (Springer, 2019), pp. 153–193
29. F. Jalali, J. Seader, Homotopy continuation method in multi-phase multi-reaction equilibrium systems. Comput. Chem. Eng. 23(9), 1319–1331 (1999)
30. L. Jiang, Z. Zhou, T. Leung, L.J. Li, L. Fei-Fei, Mentornet: learning data-driven curriculum for very deep neural networks on corrupted labels, in ICML (2018)
31. R. Johnson, F. Kiokemeister, Calculus, with Analytic Geometry (Allyn and Bacon, 1964). https://books.google.com/books?id=X4_UAQAACAAJ
32. T. Karras, T. Aila, S. Laine, J. Lehtinen, Progressive growing of gans for improved quality, stability, and variation. CoRR (2017). arXiv:abs/1710.10196
33. K. Kawaguchi, L.P. Kaelbling, Elimination of all bad local minima in deep learning. CoRR (2019)
34. H.B. Keller, Numerical solution of bifurcation and nonlinear eigenvalue problems, in Applications of Bifurcation Theory, ed. by P.H. Rabinowitz (Academic Press, New York, 1977), pp. 359–384
35. D.P. Kingma, J. Ba, Adam: a method for stochastic optimization. CoRR (2014). arXiv:abs/1412.6980
36. P. Krähenbühl, C. Doersch, J. Donahue, T. Darrell, Data-dependent initializations of convolutional neural networks. CoRR (2015). arXiv:abs/1511.06856
37. A. Krizhevsky, V. Nair, G. Hinton, Cifar-10 (canadian institute for advanced research). http://www.cs.toronto.edu/~kriz/cifar.html
38. A. Krizhevsky, I. Sutskever, G.E. Hinton, Imagenet classification with deep convolutional neural networks, in Advances in Neural Information Processing Systems (2012)
39. Y. LeCun, Y. Bengio, G. Hinton, Deep learning. Nature (2015). https://www.nature.com/articles/nature14539
40. C. Ledig, L. Theis, F. Huszár, J. Caballero, A. Cunningham, A. Acosta, A. Aitken, A. Tejani, J. Totz, Z. Wang et al., Photo-realistic single image super-resolution using a generative adversarial network, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2017), pp. 4681–4690
41. S. Liang, R. Sun, J.D. Lee, R. Srikant, Adding one neuron can eliminate all bad local minima, in S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, R. Garnett (eds.), Advances in Neural Information Processing Systems, vol. 31 (Curran Associates, Inc., 2018), pp. 4350–4360. http://papers.nips.cc/paper/7688-adding-one-neuron-can-eliminate-all-bad-local-minima.pdf
42. J. Lorraine, P. Vicol, D. Duvenaud, Optimizing millions of hyperparameters by implicit differentiation (2019). arXiv preprint arXiv:1911.02590
43. T. Mikolov, I. Sutskever, K. Chen, G.S. Corrado, J. Dean, Distributed representations of words and phrases and their compositionality, in C.J.C. Burges, L. Bottou, M. Welling, Z. Ghahramani, K.Q. Weinberger (eds.), Advances in Neural Information Processing Systems, vol. 26 (Curran Associates, Inc., 2013), pp. 3111–3119. http://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf
44. H. Mobahi, Training recurrent neural networks by diffusion. CoRR (2016). arXiv:abs/1601.04114
45. H. Mobahi, J.W. Fisher III, On the link between gaussian homotopy continuation and convex envelopes, in Lecture Notes in Computer Science (EMMCVPR 2015) (Springer, 2015)
46. A. Nagabandi, I. Clavera, S. Liu, R.S. Fearing, P. Abbeel, S. Levine, C. Finn, Learning to adapt in dynamic, real-world environments through meta-reinforcement learning (2018). arXiv preprint arXiv:1803.11347
47. K. Nordhausen, The elements of statistical learning: data mining, inference, and prediction, second edn. T. Hastie, R. Tibshirani, J. Friedman (eds.), Int. Stat. Rev. 77(3), 482–482
48. R. Paffenroth, E. Doedel, D. Dichmann, Continuation of periodic orbits around lagrange points and auto2000, in AAS/AIAA Astrodynamics Specialist Conference (Quebec City, Canada, 2001)
49. R.C. Paffenroth, Mathematical visualization, parameter continuation, and steered computations. Ph.D. thesis, AAI9926816 (College Park, MD, USA, 1999)
50. H.N. Pathak, Parameter continuation with secant approximation for deep neural networks (2018)
51. H.N. Pathak, X. Li, S. Minaee, B. Cowan, Efficient super resolution for large-scale images using attentional gan, in 2018 IEEE International Conference on Big Data (Big Data) (IEEE, 2018), pp. 1777–1786
52. H.N. Pathak, R. Paffenroth, Parameter continuation methods for the optimization of deep neural networks, in 2019 18th IEEE International Conference on Machine Learning And Applications (ICMLA) (IEEE, 2019), pp. 1637–1643
53. A. Pentina, V. Sharmanska, C.H. Lampert, Curriculum learning of multiple tasks, in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2015), pp. 5492–5500
54. J. Rojas-Delgado, R. Trujillo-Rasúa, R. Bello, A continuation approach for training artificial neural networks with meta-heuristics. Pattern Recognit. Lett. 125, 373–380 (2019). https://doi.org/10.1016/j.patrec.2019.05.017, http://www.sciencedirect.com/science/article/pii/S0167865519301667
55. S. Saxena, O. Tuzel, D. DeCoste, Data parameters: a new family of parameters for learning a differentiable curriculum (2019)
56. B. Settles, Active Learning Literature Survey, Tech. rep. (University of Wisconsin-Madison Department of Computer Sciences, 2009)
57. M. Seuret, M. Alberti, R. Ingold, M. Liwicki, Pca-initialized deep neural networks applied to document image analysis. CoRR (2017). arXiv:abs/1702.00177
58. N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, R. Salakhutdinov, Dropout: a simple way to prevent neural networks from overfitting. J. Mach. Learn. Res. 15(1), 1929–1958 (2014)
59. F.P. Such, A. Rawal, J. Lehman, K. Stanley, J. Clune, Generative teaching networks: accelerating neural architecture search by learning to generate synthetic training data (2020)
60. I. Sutskever, J. Martens, G. Dahl, G. Hinton, On the importance of initialization and momentum in deep learning, in International Conference on Machine Learning (2013), pp. 1139–1147
61. Y. Tsvetkov, M. Faruqui, W. Ling, B. MacWhinney, C. Dyer, Learning the curriculum with Bayesian optimization for task-specific word representation learning, in Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics, Long Papers, vol. 1 (Association for Computational Linguistics, Berlin, Germany, 2016), pp. 130–139. https://doi.org/10.18653/v1/P16-1013, https://www.aclweb.org/anthology/P16-1013
62. R. Vilalta, Y. Drissi, A perspective view and survey of meta-learning. Artif. Intell. Rev. 18(2), 77–95 (2002)
63. R. Wang, J. Lehman, J. Clune, K.O. Stanley, Paired open-ended trailblazer (POET): endlessly generating increasingly complex and diverse learning environments and their solutions. CoRR (2019). arXiv:abs/1901.01753
64. W. Wang, Y. Tian, J. Ngiam, Y. Yang, I. Caswell, Z. Parekh, Learning a multitask curriculum for neural machine translation (2019). arXiv preprint arXiv:1908.10940
65. M.A. Wani, F.A. Bhat, S. Afzal, A.I. Khan, Advances in deep learning, in Advances in Deep Learning (Springer, 2020), pp. 1–11
66. D. Weinshall, G. Cohen, Curriculum learning by transfer learning: theory and experiments with deep networks. CoRR (2018). arXiv:abs/1802.03796
67. H. Xiao, K. Rasul, R. Vollgraf, Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms (2017). https://github.com/zalandoresearch/fashion-mnist
68. H. Xuan, A. Stylianou, R. Pless, Improved embeddings with easy positive triplet mining (2019)
69. C. Zhou, R.C. Paffenroth, Anomaly detection with robust deep autoencoders, in Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (ACM, 2017), pp. 665–674
Author Index
A
Alfarhood, Meshal, 1
Arif Wani, M., 101
C
Cheng, Jianlin, 1
Christopoulos, Stavros-Richard G., 229
Clinton Paffenroth, Randy, 273
D
da Silva, Gabriel Pellegrino, 49
F
Fischer, Georg, 173
G
Gaffar, Ashraf, 123
Gühmann, Clemens, 81
H
Hartmann, Sven, 81
I
Imran, Abdullah-Al-Zubaer, 249
J
Johnson, Justin M., 199
K
Kamran, Sharif Amit, 25
Kanarachos, Stratis, 229
Khoshgoftaar, Taghi M., 199
Kouchak, Shokoufeh Monjezi, 123
L
Leite, Guilherme Vieira, 49
M
Mujtaba, Tahir, 101
Mukherjee, Tathagata, 143
N
Nilesh Pathak, Harsh, 273
O
Onyekpe, Uche, 229
P
Palade, Vasile, 229
Pasiliao, Eduardo, 143
Pedrini, Helio, 49
R
Rosato, Daniele, 81
Roy, Debashri, 143
S
Sabbir, Ali Shihab, 25
Sabir, Russell, 81
Saha, Sourajit, 25
Santra, Avik, 173
Stephan, Michael, 173
T
Tavakkoli, Alireza, 25
Terzopoulos, Demetri, 249

© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Singapore Pte Ltd. 2021 M. A. Wani et al. (eds.), Deep Learning Applications, Volume 2, Advances in Intelligent Systems and Computing 1232, https://doi.org/10.1007/978-981-15-6759-9