Studies in Computational Intelligence 1060
Studies in Computational Intelligence Volume 1060
Series Editor Janusz Kacprzyk, Polish Academy of Sciences, Warsaw, Poland
The series “Studies in Computational Intelligence” (SCI) publishes new developments and advances in the various areas of computational intelligence—quickly and with a high quality. The intent is to cover the theory, applications, and design methods of computational intelligence, as embedded in the fields of engineering, computer science, physics and life sciences, as well as the methodologies behind them. The series contains monographs, lecture notes and edited volumes in computational intelligence spanning the areas of neural networks, connectionist systems, genetic algorithms, evolutionary computation, artificial intelligence, cellular automata, self-organizing systems, soft computing, fuzzy systems, and hybrid intelligent systems. Of particular value to both the contributors and the readership are the short publication timeframe and the world-wide distribution, which enable both wide and rapid dissemination of research output. This series also publishes Open Access books. A recent example is the book Swan, Nivel, Kant, Hedges, Atkinson, Steunebrink: The Road to General Intelligence https://link.springer.com/book/10.1007/978-3-031-08020-3 Indexed by SCOPUS, DBLP, WTI Frankfurt eG, zbMATH, SCImago. All books published in the series are submitted for consideration in Web of Science.
Arash Shaban-Nejad · Martin Michalowski · Simone Bianco Editors
Multimodal AI in Healthcare A Paradigm Shift in Health Intelligence
Editors Arash Shaban-Nejad Oak-Ridge National Laboratory (ORNL) Center for Biomedical Informatics The University of Tennessee Health Science Center (UTHSC) Memphis, TN, USA
Martin Michalowski School of Nursing University of Minnesota Minneapolis, MN, USA
Simone Bianco Research Center IBM Almaden San Jose, CA, USA
ISSN 1860-949X ISSN 1860-9503 (electronic) Studies in Computational Intelligence ISBN 978-3-031-14770-8 ISBN 978-3-031-14771-5 (eBook) https://doi.org/10.1007/978-3-031-14771-5 © The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 This work is subject to copyright. All rights are solely and exclusively licensed by the Publisher, whether the whole or part of the material is concerned, specifically the rights of translation, reprinting, reuse of illustrations, recitation, broadcasting, reproduction on microfilms or in any other physical way, and transmission or information storage and retrieval, electronic adaptation, computer software, or by similar or dissimilar methodology now known or hereafter developed. The use of general descriptive names, registered names, trademarks, service marks, etc. in this publication does not imply, even in the absence of a specific statement, that such names are exempt from the relevant protective laws and regulations and therefore free for general use. The publisher, the authors, and the editors are safe to assume that the advice and information in this book are believed to be true and accurate at the date of publication. Neither the publisher nor the authors or the editors give a warranty, expressed or implied, with respect to the material contained herein or for any errors or omissions that may have been made. The publisher remains neutral with regard to jurisdictional claims in published maps and institutional affiliations. This Springer imprint is published by the registered company Springer Nature Switzerland AG The registered company address is: Gewerbestrasse 11, 6330 Cham, Switzerland
Preface
Multimodal Artificial Intelligence is a relatively new concept in high-performance computational sciences that aims at integrating multiple data streams in different formats (e.g., text, image, video, audio, and numerical data) to improve the accuracy of information extraction and inference, reduce bias, and generate an overall better representation of the physical, medical, or societal processes described by the data. Incorporating multimodal AI to process multidimensional and multimodal data sets in mission-critical domains such as health and medicine can advance health analytics and improve case finding/prediction, diagnosis, risk stratification, referrals, follow-up, and decision-making by health professionals and policymakers. This book aims to highlight the latest achievements in the use of AI, and of multimodal artificial intelligence in particular, in biomedicine and healthcare. The edited volume contains selected papers presented at the 2022 Health Intelligence workshop and the associated Data Hackathon/Challenge, co-located with the 36th Association for the Advancement of Artificial Intelligence (AAAI) conference, and presents an overview of the issues, challenges, and potentials in the field, along with new research results. This book provides information for researchers, students, industry professionals, clinicians, and public health agencies interested in the applications of AI in public health and medicine. Memphis, USA Minneapolis, USA San Jose, USA
Arash Shaban-Nejad Martin Michalowski Simone Bianco
Contents
Multimodal Artificial Intelligence: Next Wave of Innovation in Healthcare and Medicine . . . . 1
Arash Shaban-Nejad, Martin Michalowski, and Simone Bianco
Unsupervised Numerical Reasoning to Extract Phenotypes from Clinical Text by Leveraging External Knowledge . . . . 11
Ashwani Tanwar, Jingqing Zhang, Julia Ive, Vibhor Gupta, and Yike Guo
Domain-specific Language Pre-training for Dialogue Comprehension on Clinical Inquiry-Answering Conversations . . . . 29
Zhengyuan Liu, Pavitra Krishnaswamy, and Nancy F. Chen
Clinical Dialogue Transcription Error Correction Using Seq2Seq Models . . . . 41
Gayani Nanayakkara, Nirmalie Wiratunga, David Corsar, Kyle Martin, and Anjana Wijekoon
Customized Training of Pretrained Language Models to Detect Post Intents in Online Health Support Groups . . . . 59
Tootiya Giyahchi, Sameer Singh, Ian Harris, and Cornelia Pechmann
EXPECT-NLP: An Integrated Pipeline and User Interface for Exploring Patient Preferences Directly from Patient-Generated Text . . . . 77
David Johnson, Nick Dragojlovic, Nicola Kopac, Yifu Chen, Marilyn Lenzen, Sarah Le Huray, Samantha Pollard, Dean Regier, Mark Harrison, Amy George, Giuseppe Carenini, Raymond Ng, and Larry Lynd
Medication Error Detection Using Contextual Language Models . . . . 91
Yu Jiang and Christian Poellabauer
Latent Representation Weights Learning of the Indefinite Length of Views for Conception Diagnosis . . . . 101
Bo Li, Mengze Sun, Yuan Yu, Yuanyuan Zhao, Zhongliang Xiang, and Zhiyong An
Phenotyping with Positive Unlabelled Learning for Genome-Wide Association Studies . . . . 117
Andre Vauvelle, Hamish Tomlinson, Aaron Sim, and Spiros Denaxas
Out-of-Distribution Detection for Medical Applications: Guidelines for Practical Evaluation . . . . 137
Karina Zadorozhny, Patrick Thoral, Paul Elbers, and Giovanni Cinà
A Robust System to Detect and Explain Public Mask Wearing Behavior . . . . 155
Akshay Gupta and Biplav Srivastava
A Federated Cox Model with Non-proportional Hazards . . . . 171
D. Kai Zhang, Francesca Toni, and Matthew Williams
A Step Towards Automated Functional Assessment of Activities of Daily Living . . . . 187
Bappaditya Debnath, Mary O'brien, Swagat Kumar, and Ardhendu Behera
The Interpretation of Deep Learning Based Analysis of Medical Images—An Examination of Methodological and Practical Challenges Using Chest X-ray Data . . . . 203
Steinar Valsson and Ognjen Arandjelović
Predicting Drug Functions from Adverse Drug Reactions by Multi-label Deep Neural Network . . . . 215
Pranab Das and Dilwar Hussain Mazumder
Pattern Discovery in Physiological Data with Byte Pair Encoding . . . . 227
Nazgol Tavabi and Kristina Lerman
Predicting ICU Admissions for Hospitalized COVID-19 Patients with a Factor Graph-based Model . . . . 245
Yurui Cao, Phuong Cao, Haotian Chen, Karl M. Kochendorfer, Andrew B. Trotter, William L. Galanter, Paul M. Arnold, and Ravishankar K. Iyer
Semantic Network Analysis of COVID-19 Vaccine Related Text from Reddit . . . . 257
Chad A. Melton, Jintae Bae, Olufunto A. Olusanya, Jon Hael Brenas, Eun Kyong Shin, and Arash Shaban-Nejad
Towards Providing Clinical Insights on Long Covid from Twitter Data . . . . 267
Rohan Bhambhoria, Jad Saab, Sara Uppal, Xin Li, Artur Yakimovich, Junaid Bhatti, Nirma Khatri Valdamudi, Diana Moyano, Michael Bales, Elham Dolatabadi, and Sedef Akinli Kocak
Predicting Infections in the Covid-19 Pandemic—Lessons Learned . . . . 279
Sharare Zehtabian, Siavash Khodadadeh, Damla Turgut, and Ladislau Bölöni
Improving Radiology Report Generation with Adaptive Attention . . . . 293
Lin Wang and Jie Chen
Instantaneous Physiological Estimation Using Video Transformers . . . . 307
Ambareesh Revanur, Ananyananda Dasari, Conrad S. Tucker, and László A. Jeni
Automated Vision-Based Wellness Analysis for Elderly Care Centers . . . . 321
Xijie Huang, Jeffry Wicaksana, Shichao Li, and Kwang-Ting Cheng
Efficient Extraction of Pathologies from C-Spine Radiology Reports Using Multi-task Learning . . . . 335
Arijit Sehanobish, Nathaniel Brown, Ishita Daga, Jayashri Pawar, Danielle Torres, Anasuya Das, Murray Becker, Richard Herzog, Benjamin Odry, and Ron Vianu
Benchmarking Uncertainty Quantification on Biosignal Classification Tasks Under Dataset Shift . . . . 347
Tong Xia, Jing Han, and Cecilia Mascolo
Mining Adverse Drug Reactions from Unstructured Mediums at Scale . . . . 361
Hasham Ul Haq, Veysel Kocaman, and David Talby
A Graph-based Imputation Method for Sparse Medical Records . . . . 377
Ramon Viñas, Xu Zheng, and Jer Hayes
Using Nursing Notes to Predict Length of Stay in ICU for Critically Ill Patients . . . . 387
Sudeshna Jana, Tirthankar Dasgupta, and Lipika Dey
Automatic Classification of Dementia Using Text and Speech Data . . . . 399
Hee Jeong Han, Suhas B. N., Ling Qiu, and Saeed Abdullah
Unified Tensor Network for Multimodal Dementia Detection . . . . 409
Truong Hoang, Thuy-Trinh Nguyen, and Hoang D. Nguyen
Contributors
Abdullah Saeed College of Information Sciences and Technology, Pennsylvania State University, University Park, PA, USA An Zhiyong School of Computer Science and Technology, Shandong Technology and Business University, Yantai, PR, China; School of Statistics, Shandong Technology and Business University, Yantai, PR, China Arandjelovi´c Ognjen University of St Andrews, St Andrews, Scotland, UK Arnold Paul M. Carle Foundation Hospital, Urbana, IL, USA B. N. Suhas College of Information Sciences and Technology, Pennsylvania State University, University Park, PA, USA Bae Jintae Korea University, Seoul, South Korea Bales Michael Hoffmann-La Roche Ltd., Mississauga, ON, Canada Becker Murray Covera Health, NYC, New York, USA Behera Ardhendu Edge Hill University, Ormskirk, UK Bhambhoria Rohan Queen’s University, Kingston, ON, Canada Bhatti Junaid Manulife, Toronto, ON, Canada Bianco Simone Altos Labs—Bay Area Institute of Science, BAI Computational Innovation Hub, Redwood City, CA, USA Brenas Jon Hael Sanger Institute, Cambridge, UK Brown Nathaniel Covera Health, NYC, New York, USA Bölöni Ladislau Department of Computer Science, University of Central Florida, Orlando, FL, USA Cao Phuong University of Illinois Urbana-Champaign, Champaign, IL, USA
Cao Yurui University of Illinois Urbana-Champaign, Champaign, IL, USA Carenini Giuseppe University of British Columbia, Endowment Lands, BC, Canada Chen Haotian University of Illinois Urbana-Champaign, Champaign, IL, USA Chen Jie Peking University, Beijing, China Chen Nancy F. Institute for Infocomm Research (I2R) A*STAR, Singapore, Singapore Chen Yifu University of British Columbia, Endowment Lands, BC, Canada Cheng Kwang-Ting Department of Computer Science and Engineering, Hong Kong University of Science and Technology, Clear Water Bay, Hong Kong Cinà Giovanni Pacmed BV, Amsterdam, The Netherlands Corsar David Robert Gordon University, Aberdeen, Scotland Daga Ishita Covera Health, NYC, New York, USA Das Anasuya Covera Health, NYC, New York, USA Dasari Ananyananda Department of Mechanical Engineering, Carnegie Mellon University, Pittsburgh, USA Das Pranab Department of Computer Science and Engineering, National Institute of Technology Nagaland, Dimapur, Nagaland, India Dasgupta Tirthankar TCS Research, Kolkata, India Debnath Bappaditya Edge Hill University, Ormskirk, UK Denaxas Spiros Institute of Health Informatics, University College London, London, UK Dey Lipika TCS Research, Kolkata, India Dolatabadi Elham Vector Institute, Toronto, ON, Canada Dragojlovic Nick University of British Columbia, Endowment Lands, BC, Canada Elbers Paul Department of Intensive Care Medicine, Laboratory for Critical Care Computational Intelligence, Amsterdam Medical Data Science, Amsterdam UMC, Vrije Universiteit, Amsterdam, The Netherlands Galanter William L. University of Illinois Hospital & Health Sciences System, Chicago, IL, USA George Amy University of British Columbia, Endowment Lands, BC, Canada Giyahchi Tootiya University of California, Irvine, CA, USA
Guo Yike Pangaea Data Limited, London, UK; Data Science Institute, Imperial College London, London, UK; Hong Kong Baptist University, Hong Kong SAR, Hong Kong, China Gupta Akshay Indian Institute of Technology, Kanpur, India Gupta Vibhor Pangaea Data Limited, London, UK Han Hee Jeong College of Information Sciences and Technology, Pennsylvania State University, University Park, PA, USA Han Jing Department of Computer Science and Technology, University of Cambridge, Cambridge, UK Haq Hasham Ul John Snow Labs Inc., Lewes, DE, USA Harris Ian University of California, Irvine, CA, USA Harrison Mark University of British Columbia, Endowment Lands, BC, Canada Hayes Jer Accenture Labs Dublin, Dublin, Ireland Herzog Richard Covera Health, NYC, New York, USA Hoang Truong University of Science of HCMC, Ho Chi Minh City, Vietnam Huang Xijie Center for Aging Science, Department of Computer Science and Engineering, Hong Kong University of Science and Technology, Clear Water Bay, Hong Kong Ive Julia Department of Computing, Imperial College London, London, UK; Queen Mary University of London, London, UK Iyer Ravishankar K. University of Illinois Urbana-Champaign, Champaign, IL, USA Jana Sudeshna TCS Research, Kolkata, India Jeni László A. Robotics Institute, Carnegie Mellon University, Pittsburgh, USA Jiang Yu University of Notre Dame, Notre Dame, IN, USA Johnson David University of British Columbia, Endowment Lands, BC, Canada Khodadadeh Siavash Department of Computer Science, University of Central Florida, Orlando, FL, USA Kocak Sedef Akinli Vector Institute, Toronto, ON, Canada Kocaman Veysel John Snow Labs Inc., Lewes, DE, USA Kochendorfer Karl M. University of Illinois Hospital & Health Sciences System, Chicago, IL, USA Kopac Nicola University of British Columbia, Endowment Lands, BC, Canada
Krishnaswamy Pavitra Institute for Infocomm Research (I2R) A*STAR, Singapore, Singapore Kumar Swagat Edge Hill University, Ormskirk, UK Le Huray Sarah Multiple Sclerosis Society of Canada, Toronto, ON, Canada Lenzen Marilyn Multiple Sclerosis Society of Canada, Toronto, ON, Canada Lerman Kristina Information Science Institute, USC, Marina del Rey, CA, USA Li Bo School of Computer Science and Technology, Shandong Technology and Business University, Yantai, PR, China; School of Statistics, Shandong Technology and Business University, Yantai, PR, China Li Shichao Department of Computer Science and Engineering, Hong Kong University of Science and Technology, Clear Water Bay, Hong Kong Li Xin University of Toronto, Toronto, ON, Canada; Vector Institute, Toronto, ON, Canada Liu Zhengyuan Institute for Infocomm Research (I2R) A*STAR, Singapore, Singapore Lynd Larry University of British Columbia, Endowment Lands, BC, Canada Martin Kyle Robert Gordon University, Aberdeen, Scotland Mascolo Cecilia Department of Computer Science and Technology, University of Cambridge, Cambridge, UK Mazumder Dilwar Hussain Department of Computer Science and Engineering, National Institute of Technology Nagaland, Dimapur, Nagaland, India Melton Chad A. The Bredesen Center for Interdisciplinary Research and Graduate Education, University of Tennessee, Knoxville, USA; Center for Biomedical Informatics, Department of Pediatrics, College of Medicine, University of Tennessee Health Science Center, Memphis, TN, USA Michalowski Martin School of Nursing, University of Minnesota, Minneapolis, MN, USA Moyano Diana Vector Institute, Toronto, ON, Canada Nanayakkara Gayani Robert Gordon University, Aberdeen, Scotland Ng Raymond University of British Columbia, Endowment Lands, BC, Canada Nguyen Hoang D. School of Computer Science and Information Technology, University College Cork, Cork, Ireland Nguyen Thuy-Trinh School of Computing Science, University of Glasgow, Glasgow, UK
O’brien Mary Edge Hill University, Ormskirk, UK Odry Benjamin Covera Health, NYC, New York, USA Olusanya Olufunto A. Center for Biomedical Informatics, Department of Pediatrics, College of Medicine, University of Tennessee Health Science Center, Memphis, TN, USA Pawar Jayashri Covera Health, NYC, New York, USA Pechmann Cornelia University of California, Irvine, CA, USA Poellabauer Christian Florida International University, Miami, FL, USA Pollard Samantha BC Cancer Agency, Vancouver, BC, Canada Qiu Ling College of Information Sciences and Technology, Pennsylvania State University, University Park, PA, USA Regier Dean University of British Columbia, Endowment Lands, BC, Canada; BC Cancer Agency, Vancouver, BC, Canada Revanur Ambareesh Robotics Institute, Carnegie Mellon University, Pittsburgh, USA Saab Jad Telus Communications Inc., Vancouver, BC, Canada Sehanobish Arijit Covera Health, NYC, New York, USA Shaban-Nejad Arash Center for Biomedical Informatics, Department of Pediatrics, College of Medicine, The University of Tennessee Health Science Center— Oak-Ridge National Lab (UTHSC-ORNL), Memphis, TN, USA; The Bredesen Center for Interdisciplinary Research and Graduate Education, University of Tennessee, Knoxville, USA Shin Eun Kyong Korea University, Seoul, South Korea Sim Aaron Benevolent AI, London, UK Singh Sameer University of California, Irvine, CA, USA Srivastava Biplav AI Institute, University of South Carolina, Columbia, USA Sun Mengze School of Statistics, Shandong Technology and Business University, Yantai, PR, China Talby David John Snow Labs Inc., Lewes, DE, USA Tanwar Ashwani Pangaea Data Limited, London, UK Tavabi Nazgol BCH, Harvard Medical School, Boston, MA, USA Thoral Patrick Department of Intensive Care Medicine, Laboratory for Critical Care Computational Intelligence, Amsterdam Medical Data Science, Amsterdam UMC, Vrije Universiteit, Amsterdam, The Netherlands
Tomlinson Hamish Benevolent AI, London, UK Toni Francesca Imperial College London, London, UK Torres Danielle Covera Health, NYC, New York, USA Trotter Andrew B. University of Illinois Hospital & Health Sciences System, Chicago, IL, USA Tucker Conrad S. Department of Mechanical Engineering, Carnegie Mellon University, Pittsburgh, USA Turgut Damla Department of Computer Science, University of Central Florida, Orlando, FL, USA Uppal Sara Telus Communications Inc., Vancouver, BC, Canada Valdamudi Nirma Khatri University of British Columbia, Vancouver, BC, Canada Valsson Steinar University of St Andrews, St Andrews, Scotland, UK Vauvelle Andre Institute of Health Informatics, University College London, London, UK Vianu Ron Covera Health, NYC, New York, USA Viñas Ramon University of Cambridge, Cambridge, UK Wang Lin School of Electronic and Computer Engineering, Peking University, Shenzhen, China Wicaksana Jeffry Department of Electronic and Computer Engineering, Hong Kong University of Science and Technology, Clear Water Bay, Hong Kong Wijekoon Anjana Robert Gordon University, Aberdeen, Scotland Williams Matthew Imperial College London, London, UK Wiratunga Nirmalie Robert Gordon University, Aberdeen, Scotland Xia Tong Department of Computer Science and Technology, University of Cambridge, Cambridge, UK Xiang Zhongliang School of Computer Science and Technology, Shandong Technology and Business University, Yantai, PR, China Yakimovich Artur Roche Products Ltd., Welwyn Garden City, UK Yu Yuan School of Computer Science and Technology, Shandong Technology and Business University, Yantai, PR, China Zadorozhny Karina Pacmed BV, Amsterdam, The Netherlands Zehtabian Sharare Department of Computer Science, University of Central Florida, Orlando, FL, USA
Zhang D. Kai Imperial College London, London, UK Zhang Jingqing Pangaea Data Limited, London, UK; Data Science Institute, Imperial College London, London, UK Zhao Yuanyuan School of Clinical Medicine, Qilu Medical University, Zibo, PR, China Zheng Xu Accenture Labs Dublin, Dublin, Ireland
Abbreviations
ABSA  Aspect-Based Sentiment Analysis
ACOUSTICS  AutomatiC classificatiOn of sUbjectS with demenTIa and healthy Controls using text transcriptions and Speech data
AD  Alzheimer's Diseases
ACTS  Attention Crossing Time Series
ADE  Adverse Drug Events
ADL  Activities of Daily Living
ADM  Alternating Direction Minimizing
ADR  Adverse Drug Reactions
AE  Autoencoder
AGR  Agreeableness
ALM  Augmented Lagrange Multiplier
ASR  Automatic Speech Recognition
AUC  Area Under Curve
AUROC  Area Under the Receiver-Operator Curve
AUPRC  Area Under the Precision-Recall Curve
BART  Bidirectional and Auto-Regressive Transformer
BERT  Bidirectional Encoder Representations from Transformers
BI-DAF  Bi-Directional Attention Flow network
BPE  Byte Pair Encoding
CLM  Contextual Language Model
CNN  Convolutional Neural Networks
CON  Conscientiousness
COTE  Collection Of Transformation Ensembles
COVID-19  Coronavirus Disease 2019
CRF  Conditional Random Fields
CT-BIRT  COVID-Twitter-BERT
CUI  Concept Unique Identifier
CVVH  Continuous Veno-Venous Hemofiltration
DAE  Denoising Autoencoder
DCES  Discrete Choice Experiment Surveys
DFT  Discrete Fourier Transform
DL  Deep Learning
DOR  Diagnostic Odds Ratio
DTW  Dynamic Time Warping
DUE  Deterministic Uncertainty Estimation
ED  Emergency Department
EEG  Electroencephalogram
EHR  Electronic Health Records
EXPECT-NLP  EXploration of Preferences and Experiences in Collected Texts using Natural Language Processing
EXT  Extraversion
FCNN  Fully Connected Neural Network
FG  Factor Graph
FL  Federated learning
GBSG  German Breast Cancer Study Group
GCD  Gastrointestinal Clinical Dialogue
GMM  Gaussian Mixture Model
GNN  Graph Neural Network
GP  Gaussian process
GWAS  Genome-Wide Association Studies
HPO  Human Phenotype Ontology
HRNN  Hierarchical Recurrent Network
HRV  Heart-Rate Variability
ICD10  International Classification of Diseases, Tenth revision
ICU  Intensive Care Unit
IGBT  Initial Ground Truth Battery
IPPG  Imaging PPG
ITP  Individual Task Proficiency
IQR  Interquartile Range
LD  Linkage disequilibrium
LDA  Latent Dirichlet Allocation
LIME  Local Interpretable Model-Agnostic Explanations
LOF  Local Outlier Factor
LRWL  Latent Representation Weight Learning
LSTM  Long Short-Term Memory
MAET  Mask Adherence Estimation Tool
MAP  Mean Average Precision
MedDRA  Medical Dictionary for Regulatory Activities
METABRIC  Molecular Taxonomy of Breast Cancer International Consortium
MGT  Mobile Ground Truth
MHAA  Multi-Head Adaptive Attention
MICE  Multivariate Imputation by Chained Equations
MIMIC  Medical Information Mart for Intensive Care
ML  Machine Learning
MLM  Masked Language Model
MMML  Multimodal Machine Learning
MMSE  Mini-Mental State Examination
MSE  Mean Squared Error
MSTamps  Multi-scale Spatio-Temporal Maps
MTCNN  Multi-task Cascaded Convolutional Network
MTL  Multitask learning
N2C2  National Clinical NLP Challenges
NCR  Neural Concept Recognizer
NER  Name Entity Recognition
NEU  Neuroticism
NHS  National Health Service
NLL  Negative Log Likelihood
NLP  Natural Language Processing
NMS  Non-Maximum Suppression
NPI  Non-Pharmaceutical Intervention
NR  Numerical Reasoning
NRT  Nicotine Replacement Therapy
NSP  Next Sentence Prediction
OOD  Out-of-Distribution
OPD  Openness
OxCGRT  Oxford COVID-19 Government Response Tracker
PACS  Post-Acute Sequelae Of SARS-Cov-2
PCF  Preweighted Convolution Filtering
PD-BPE  Pattern Discovery with Byte Pair Encoding
PET  Positron Emission Tomography
PGM  Probabilistic Graphical Model
POS  Part-of-Speech
PPCA  Probabilistic Principal Component Analysis
PPG  Photoplethysmography
PTT  Pulse-Transit Time
RR  Respiration Rate
RI  Random Initialization
RISE  Randomized Input Sampling for Explanation
RNN  Recurrent Neural Networks
ROC  Receiver Operating Characteristic
ROCKET  RandOm Convolutional KErnel Transform
RPPG  Remote-Photoplethysmography
SAMD  Software as a Medical Device
SAX  Symbolic Aggregate Approximation
SHAP  Shapley Additive Explanations
SMM4H  Social Media Mining for Healthcare
SNR  Signal-To-Noise Ratio
SOTA  State-Of-The-Art
SVM  Support Vector Machine
TF-IDF  Term Frequency-Inverse Document Frequency
TILES  Tracking Individual Performance with Sensors
UMAP  Uniform Manifold Approximation and Projection
UMLS  Unified Medical Language System
V4V  Vision-for-Vitals
VAE  Variational Autoencoders
VE  Vaccine Ego
VP  View Position
WER  Word Error Rate
WLS  Wisconsin Longitudinal Study
Multimodal Artificial Intelligence: Next Wave of Innovation in Healthcare and Medicine Arash Shaban-Nejad, Martin Michalowski, and Simone Bianco
Abstract Multimodality refers to the utilization of different data types with different representational modes. Medical and health data are becoming increasingly multimodal. Emerging multimodal technologies enable users to access, integrate, and process multimodal data and to interact with a system in different modalities at the same time. Multimodal artificial intelligence (AI), in particular, attempts to process, manage, and understand these multimodal data by making multimodal inferences. In biology, medicine, and health, multimodal AI can assist in analyzing complex associations and relationships between various biological processes, health indicators, risk factors, and health outcomes, and in developing exploratory and explanatory models. This chapter aims to introduce the concept of multimodal AI and discuss some of its applications in health and biomedicine.
Keywords Multimodality · Multimodal artificial intelligence · Digital health · Health intelligence
1 Introduction
Multimodal AI aims at integrating two or more data streams to increase the accuracy of information extraction and inference, reduce bias, and generate an overall better representation of the physical, medical, or societal processes described by the
data. Types of data can include images and videos, text and language, specialized data types like various -omics (genomics, transcriptomics) and laboratory data, as well as machine-readable metadata associated with medical instruments or specific algorithms. Independently, AI has achieved much progress in the analysis and understanding of all these data streams and more in healthcare and life science contexts, solving problems like classification, recognition, and clustering of various complex objects, semantic classification, differential diagnosis and prognosis, and even protein classification and de novo design. However, their integration is still a challenge and requires a new generation of algorithms. Data can be integrated in various ways during the training phase of a deep neural architecture. When data, or features extracted from the data, are combined prior to learning, this is referred to as the early fusion strategy; early fusion includes the combination of features obtained both through manual extraction and through automatic methods (e.g., deep feature extraction), and it is known to improve consistency over a single modality. When intermediate layers from different modalities are combined, one talks of joint or intermediate fusion. Joint fusion is also known to improve performance on a combination of text and images, for example, although interdependence across the data types may be required for this technique to work. Late fusion refers to combining the decision layers of the neural networks (e.g., ensemble predictions); this strategy works optimally when the data sources are not interdependent, whereas early and joint fusion remain viable when that assumption fails [1]. A minimal sketch contrasting the three strategies is given at the end of this section. In this chapter, we outline recent progress in multimodal AI and present applications in healthcare.
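To make the three fusion strategies concrete, the sketch below wires the same two-modality classifier for early, joint, and late fusion in PyTorch. It is an illustrative sketch rather than an implementation from this book; the class name, feature dimensions, and encoder sizes are arbitrary assumptions.

```python
# Illustrative sketch (assumed architecture): one two-modality classifier
# configured for early, joint, or late fusion.
import torch
import torch.nn as nn

class TwoModalityFusion(nn.Module):
    def __init__(self, dim_img=512, dim_txt=256, n_classes=2, mode="joint"):
        super().__init__()
        self.mode = mode
        # per-modality encoders, used by joint and late fusion
        self.img_enc = nn.Sequential(nn.Linear(dim_img, 128), nn.ReLU())
        self.txt_enc = nn.Sequential(nn.Linear(dim_txt, 128), nn.ReLU())
        if mode == "early":
            # features are concatenated before any learned layers
            self.head = nn.Sequential(nn.Linear(dim_img + dim_txt, 128),
                                      nn.ReLU(), nn.Linear(128, n_classes))
        elif mode == "joint":
            # intermediate representations are concatenated, then share one head
            self.head = nn.Linear(128 + 128, n_classes)
        else:  # late fusion
            # each modality keeps its own decision layer; predictions are averaged
            self.head_img = nn.Linear(128, n_classes)
            self.head_txt = nn.Linear(128, n_classes)

    def forward(self, x_img, x_txt):
        if self.mode == "early":
            return self.head(torch.cat([x_img, x_txt], dim=-1))
        h_img, h_txt = self.img_enc(x_img), self.txt_enc(x_txt)
        if self.mode == "joint":
            return self.head(torch.cat([h_img, h_txt], dim=-1))
        return (self.head_img(h_img) + self.head_txt(h_txt)) / 2

for mode in ("early", "joint", "late"):
    logits = TwoModalityFusion(mode=mode)(torch.randn(4, 512), torch.randn(4, 256))
    print(mode, logits.shape)  # each mode yields logits of shape [4, 2]
```

In practice, the choice among the three modes often comes down to how strongly the modalities depend on one another and how much paired training data is available.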
2 Clinical and Biomedical Applications of Multimodal AI and Data Science
Precision digital health and medicine [2, 3] paved the way for incorporating multidimensional and multimodal data sets into health analytics. AI and data-driven approaches can enable data/knowledge fusion and analytics by integrating various distributed and heterogeneous datasets consisting of different data types, such as text, image, and sensor data, to improve case finding/prediction, diagnosis, risk stratification, referrals, follow-up, and decision-making by health professionals and policymakers. They can also improve interoperability, interpretability, and explainability [4, 5] to minimize biases and errors in decision-making and reduce the ethical risks [6] of data-driven recommendations and discovery. The recent COVID-19 pandemic has increased the need for multimodal tools for the prediction, prevention, and management of the disease at multiple scales, from the single patient to the population [7–10]. From the analysis of the epidemiological data and environmental co-morbidities to the diagnosis and prognosis of cases based on electronic health records and X-rays/CT scans, multimodal AI has played a major role in the pandemic and has provided researchers and practitioners with a wealth
of important solutions to scientific and clinical research problems. Moreover, the extensive literature [11–15] that has emerged has certainly increased our awareness of the possibilities around the use of these techniques closer to the clinic, creating a playbook that will certainly be useful in the management of future crises, beyond COVID-19. Multimodal approaches have also been used in different domains of biomedicine and health such as chronic disease surveillance [16], screening and assessing child mental health [17, 18], oncology [19, 20], emotion detection [21], ophthalmology [22], personal health libraries [23] and detecting dementia [24]. Baltrusaitis et al. [25] provide a survey of recent advances in multimodal machine learning.
3 Advances in AI Technologies and Data Analytics in Healthcare
Some other examples of applications of AI in biomedicine and healthcare that are discussed in different chapters of this book are the following. Current state-of-the-art phenotyping models can detect general phenotypes but perform poorly when detecting phenotypes that require numerical reasoning. Tanwar et al. [26] presented an unsupervised methodology leveraging external knowledge and contextualized word embeddings from ClinicalBERT for numerical reasoning in a variety of phenotypic contexts. Clinical dialogue is a conversation between health practitioners and their patients, with the explicit goal of obtaining and sharing medical information. This information contributes to medical decision-making regarding the patient and plays a crucial role in their healthcare journey. Nanayakkara et al. [27] presented a seq2seq learning approach for Automatic Speech Recognition (ASR) transcription error correction of clinical dialogues. Liu et al. [28] also proposed domain-specific language pre-training to improve performance on downstream tasks like clinical dialogue comprehension, with a focus on the low-resource training scenario. To encourage engagement and provide evidence-based responses appropriate to participants' needs, Giyahchi et al. [29] proposed an intent detection model for online support groups for health and mental health support, based on natural language processing methods that enable a chatbot to increase interactions and improve discussion quality. Johnson et al. [30] presented a method that allows users (e.g., health services researchers, drug developers, or market researchers) to assess patient preferences for drug therapies semi-automatically from online patient-generated text using a weakly supervised aspect-based sentiment analysis pipeline. Medication errors most commonly occur at the ordering or prescribing stage. Jiang and Poellabauer [31] demonstrated how to use BERT-based contextual language models to detect anomalies in written or spoken text based on a data set extracted from real-world medical data of thousands of patient records. The proposed models can learn patterns of text
dependency and predict erroneous output based on contextual information such as patient data. Li et al. [32] proposed Latent Representation Weight Learning (LRWL) to learn the latent representative weight of each image or view for conception diagnosis and then integrate the views with the weights and the diagnostic indexes as part of the input data for deep learning to predict successful conception. Vauvelle et al. [33] presented a phenotyping method, which uses anchor learning and transformers to generate continuous phenotypes that allow for the detection of significant genomic associations with smaller cohorts. Zadorozhny et al. [34] proposed a series of practical considerations and tests to choose the best out-of-distribution (OOD) detection method for a specific medical dataset. Gupta and Srivastava [35] proposed a Mask Adherence Estimation Tool (MAET) based on the pre-trained YOLOv5 object detection model and combined it with explanation methods to help the user understand mask adherence at an individual and aggregate level. Zhang et al. [36] presented a federated Cox model that accommodates healthcare data sets and also relaxes the proportional hazards assumption, allowing time-varying covariate effects. Debnath et al. [37] developed Functional-ADL, a novel dataset to improve action recognition. Functional-ADL facilitates multi-label and impairment-specific executions of different Activities of Daily Living (ADL) to contribute toward vision-based automated assessment and rehabilitation of physically impaired persons. Using the ChestX-ray14 dataset, Valsson and Arandjelovic [38] presented a thorough analysis of a state-of-the-art learning approach and discussed the weaknesses of some of the existing algorithms in the literature. Das and Mazumder [39] employed a multi-label prediction-supported deep neural network methodology to classify drug functions. They also addressed the issue of class imbalance using the Multilabel Synthetic Minority Oversampling Technique (MLSMOTE). To improve pattern discovery in physiological data, Tavabi and Lerman [40] proposed an unsupervised method for learning representations of time series based on common patterns identified within them. This method captures both long-term and short-term dependencies present in the data and applies to both univariate and multivariate time series. Cao et al. [41] presented a factor graph-based model that takes comorbidities and clinical measurements as inputs and predicts intensive care unit (ICU) admissions three days and seven days in advance for hospitalized COVID-19 patients. This model explains relationships between different clinical features and provides interpretations for ICU admissions. Vaccine-related hesitancy, mis/disinformation, and anti-vaccination discourse are hindering the rapid uptake of the COVID-19 vaccine. To facilitate and promote efficient messaging strategies/campaigns to improve the vaccination rate, Melton et al. [42] investigated the COVID-19 vaccine hesitancy diffusion networks in an online Reddit community within the initial phase of the COVID-19 pandemic. By leveraging Twitter data, Bhambhoria et al. [43] proposed a method using entity-extraction methods for providing clinical insights into post-acute sequelae of SARS-CoV-2 (PACS), or Long COVID, before defining subsequent downstream tasks.
Zehtabian et al. [44] started with prediction algorithms proposed for the XPrize Pandemic Response Challenge and considered several directions that might allow their improvement. Then, they investigated their performance over medium-term predictions extending over several months. Wang and Chen [45] proposed an architecture with a Contrastive Language-Image Pre-Training (CLIP)-based visual extractor and a Multi-Head Adaptive Attention (MHAA) module to improve radiology report generation. Revanur et al. [46] proposed a video Transformer for estimating instantaneous heart rate and respiration rate from face videos. Physiological signals are typically confounded by alignment errors in space and time; to overcome this, they formulated the loss in the frequency domain. Huang et al. [47] developed an automatic, vision-based system for monitoring and analyzing the physical and mental well-being of senior citizens. Utilizing a radiologist's report dataset on the cervical spine, Sehanobish et al. [48] showed that a multi-task model can beat or match the performance of multiple BERT-based models fine-tuned on various tasks and various task-specific adapter-augmented BERT-based models. Xia et al. [49] used three classification tasks based on respiratory sounds and electrocardiography signals to benchmark five representative uncertainty quantification methods. Ul Haq et al. [50] proposed an NLP solution that detects adverse drug reactions (ADR) in unstructured free-text conversations. The high degree of missingness of data in Electronic Health Records (EHRs) can be attributed to many factors, including device failure, privacy concerns, or other unexpected reasons. Viñas et al. [51] proposed a graph-based imputation method that is robust to both sparsity and unreliable unmeasured events and can facilitate the diagnosis of novel diseases based on the clinical history of past events. Jana et al. [52] proposed the use of convolutional neural network (CNN)- and Long Short-Term Memory (LSTM)-based prediction networks, along with transformer-based language models, for representing the data from nursing notes to predict the length of stay in the ICU for critically ill patients. Han et al. [53] described the development of ACOUSTICS (AutomatiC classificatiOn of sUbjectS with demenTIa and healthy Controls using text transcriptions and Speech data), an ensemble model with two deep learning-based architectures for text and speech analysis to detect dementia. Hoang et al. [54] proposed an end-to-end multimodal analysis pipeline for Alzheimer's dementia detection.
References 1. Huang, S. C., Pareek, A., & Seyyedi, S., et al. (2020). Fusion of medical imaging and electronic health records using deep learning: a systematic review and implementation guidelines. npj Digital Medicine, 3, 136. 2. Shaban-Nejad, A., & Michalowski, M. (2020). Precision health and medicine—A digital revolution in healthcare. Studies in Computational Intelligence, 843, Springer, ISBN 978-3-03024408-8.
3. Shaban-Nejad, A., Michalowski, M., Peek, N., Brownstein, J. S., & Buckeridge, D. L. (2020). Seven pillars of precision digital health and medicine. Artificial Intelligence in Medicine, 103, 101793. 4. Shaban-Nejad, A., Michalowski, M., Brownstein, J. S., & Buckeridge, D. L. (2021). Guest editorial explainable AI: Towards fairness, accountability, transparency and trust in healthcare. IEEE Journal of Biomedical Health Informatics, 25(7), 2374–2375. 5. Shaban-Nejad, A., Michalowski, M., Buckeridge, D. L. (2021). Explainability and interpretability: Keys to deep medicine. In A. Shaban-Nejad, M. Michalowski, D. L. Buckeridge (Eds.), Explainable AI in healthcare and medicine (Vol. 914). Studies in computational intelligence. Springer, Cham. https://doi.org/10.1007/978-3-030-53352-6_1 6. Mamiya, H., Shaban-Nejad, A., Buckeridge, D. L. (2017). Online public health intelligence: Ethical considerations at the big data era. In A. Shaban-Nejad, J. Brownstein, D. Buckeridge (Eds.), Public health intelligence and the internet. Lecture notes in social networks. Springer, Cham. https://doi.org/10.1007/978-3-319-68604-2_8 7. Santosh, K. C. (2020). AI-driven tools for coronavirus outbreak: Need of active learning and cross-population train/test models on multitudinal/multimodal data. Journal of Medical Systems, 44(5), 1–5. 8. Chen, J., & See, K. C. (2020). Artificial intelligence for COVID-19: Rapid review. Journal of Medical Internet Research, 22(10), e21476. 9. Brakefield, W. S., Ammar, N., & Shaban-Nejad, A. (2021). UPHO: Leveraging an explainable multimodal big data analytics framework for COVID-19 surveillance and research. In 2021 IEEE International Conference on Big Data (Big Data) (pp. 5854–5858). https://doi.org/10. 1109/BigData52589.2021.9671429 10. Brakefield, W. S., Ammar, N., Olusanya, O. A., & Shaban-Nejad, A. (2021). An urban population health observatory system to support COVID-19 pandemic preparedness, response, and management: Design and development study. JMIR Public Health and Surveillance, 7(6), e28269. https://doi.org/10.2196/28269 11. Mason, A. E., Hecht, F. M., Davis, S. K., et al. (2022). Detection of COVID-19 using multimodal data from a wearable device: Results from the first TemPredict Study. Science and Reports, 12(1), 3463. https://doi.org/10.1038/s41598-022-07314-0.Erratum.In:SciRep. 2022Mar16;12(1):4568 12. Domingo-Fernández, D., Baksi, S., Schultz, B., Gadiya, Y., Karki, R., Raschka, T., Ebeling, C., Hofmann-Apitius, M., & Kodamullil, A. T. (2021). COVID-19 Knowledge Graph: A computable, multi-modal, cause-and-effect knowledge model of COVID-19 pathophysiology. Bioinformatics, 37(9), 1332–1334. https://doi.org/10.1093/bioinformatics/btaa834 13. Naumov, V., Putin, E., Pushkov, S., et al. (2021). COVIDomic: A multi-modal cloud-based platform for identification of risk factors associated with COVID-19 severity. PLoS Computational Biology, 17(7), e1009183. https://doi.org/10.1371/journal.pcbi.1009183 14. Tan, T., Das, B., Soni, R., et al. (2022). Multi-modal trained artificial intelligence solution to triage chest X-ray for COVID-19 using pristine ground-truth, versus radiologists. Neurocomputing, 485, 36–46. 7 May 2022. https://doi.org/10.1016/j.neucom.2022.02.040 15. Chen, Y., Ouyang, L., Bao, F. S., Li, Q., Han, L., Zhang, H., Zhu, B., Ge, Y., Robinson, P., Xu, M., Liu, J., & Chen, S. (2021). A Multimodality machine learning approach to differentiate severe and nonsevere COVID-19: Model development and validation. Journal of Medical Internet Research, 23(4), e23948. 
https://doi.org/10.2196/23948 16. Brakefield, W. S., Ammar, N., & Shaban-Nejad, A. (2022). An urban population health observatory for disease causal pathway analysis and decision support: Underlying explainable artificial intelligence model. JMIR Formative Research, 6(7), e36055. https://doi.org/10.2196/36055 17. Ammar, N., & Shaban-Nejad, A. (2020). Explainable artificial intelligence recommendation system by leveraging the semantics of adverse childhood experiences: Proof-of-concept prototype development. JMIR Medical Informatics, 4;8(11), e18752. 18. Ammar, N., Zareie, P., Hare, M.E., Rogers, L., Madubuonwu, S., Yaun, J., & Shaban-Nejad, A. (2021). SPACES: Explainable multimodal ai for active surveillance, diagnosis, and management of adverse childhood experiences (ACEs). In 2021 IEEE International Conference on Big Data (Big Data) (pp. 5843–5847)
19. Boehm, K. M., Khosravi, P., Vanguri, R., Gao, J., & Shah, S. P. (2022). Harnessing multimodal data integration to advance precision oncology. Nature Reviews Cancer, 22(2), 114–126. https:// doi.org/10.1038/s41568-021-00408-3.(2021) 20. Skrede, O. J., De Raedt, S., Kleppe, A., et al. (2020). Deep learning for prediction of colorectal cancer outcome: A discovery and validation study. Lancet, 395(10221), 350–360. https://doi. org/10.1016/S0140-6736(19)32998-8 21. Marechal, C., Mikołajewski, D., Tyburek, K., Prokopowicz, P., Bougueroua, L., Ancourt, C., & W˛egrzyn-Wolska, K. (2019). Survey on AI-based multimodal methods for emotion detection. In J. Kołodziej, H. González-Vélez (Eds.), High-Performance modelling and simulation for big data applications. (Vol. 11400, pp. 307–324). Lecture notes in computer science. Cham: Springer. https://doi.org/10.1007/978-3-030-16272-6_11 22. Xiong, J., Li, F., Song, D., Tang, G., He, J., et al. (2022). Multimodal machine learning using visual fields and peripapillary circular OCT scans in detection of glaucomatous optic neuropathy. Ophthalmology, 129(2), 171–180. https://doi.org/10.1016/j.ophtha.2021.07.032 23. Ammar, N., Bailey, J. E., Davis, R. L., & Shaban-Nejad, A. (2021). Using a personal health library-enabled mHealth recommender system for self-management of diabetes among underserved populations: Use case for knowledge graphs and linked data. JMIR Formative Research, 16;5(3), e24738. https://doi.org/10.2196/24738 24. Ilias, L., & Askounis, D. (2022). Multimodal deep learning models for detecting dementia from speech and transcripts. Frontiers in Aging Neuroscience, 17(14), 830943. https://doi.org/10. 3389/fnagi.2022.830943 25. Baltrusaitis, T., Ahuja, C., & Morency, L. P. (2019). Multimodal machine learning: A survey and taxonomy. IEEE Transactions on Pattern Analysis and Machine Intelligence, 41(2), 423–443. 26. Tanwar, A., Zhang, J., Ive, J., Gupta, V., & Guo, Y. (2022). Unsupervised numerical reasoning to extract phenotypes from clinical text by leveraging external knowledge. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 27. Nanayakkara, G., Wiratunga, N., Corsar, D., Martin, K., & Wijekoon, A. (2022). Clinical dialogue transcription error correction using Seq2Seq models. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 28. Liu, Z., Krishnaswamy, P., & Chen, N. F. (2022). Domain-specific language pre-training for dialogue comprehension on clinical inquiry-answering conversations. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 29. Giyahchi, T., Singh, S., Harris, I., & Pechmann, C. (2022). Customized training of pretrained language models to detect post intents in online health support groups. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 30. Johnson, D., Dragojlovic, N., Kopac, N., Chen, Y., Lenzen, M., Le Huray, S., Pollard, S., Regier, D., Harrison, M., George, A., Carenini, G., Ng, R., & Lynd, L. (2022). EXPECTNLP: An integrated pipeline and user interface for exploring patient preferences directly from patient-generated text. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 31. Jiang, Y., & Poellabauer, C. (2022). Medication error detection using contextual language models. 
In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 32. Li, B., Sun, M., Yu, Y., Zhao, Y., Xiang, Z., & An, Z. (2022). Latent representation weights learning of the indefinite length of views for conception diagnosis. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 33. Vauvelle, A., Tomlinson, H., Sim, A., & Denaxas, S. (2022). Phenotyping with positive unlabelled learning for genome-wide association studies. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer.
34. Zadorozhny, K., Thoral, P., Elbers, P., & Cinà, G. (2022). Out-of-distribution detection for medical applications: Guidelines for practical evaluation. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer 35. Gupta, A., & Srivastava, B. (2022). A robust system to detect and explain public mask wearing behavior. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 36. Zhang, D. K., Toni, F., & Williams, M. (2022). A federated cox model with non-proportional hazards. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 37. Debnath, B., O’brien, M., Kumar, S., & Behera, A. (2022). A step towards automated functional assessment of activities of daily living. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 38. Valsson, S., & Arandjelovic, O. (2022). The interpretation of deep learning based analysis of medical images—An examination of methodological and practical challenges using chest X-ray data. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 39. Das, P., & Mazumder, D. H. (2022). Predicting drug functions from adverse drug reactions by multi-label deep neural network. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 40. Tavabi, N., & Lerman, K. (2022). Pattern discovery in physiological data with byte pair encoding. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 41. Cao, Y., Cao, P., Chen, H., Kochendorfer, K. M., Trotter, A. B., Galanter, W. L., Arnold, P. M., & Iyer, R. K. (2022). Predicting ICU admissions for hospitalized COVID-19 patients with a factor graph-based model. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 42. Melton, C. A., Bae, J., Olusanya, O. A., Brenas, J. H., Shin, E. K., & Shaban-Nejad, A. (2022). Semantic network analysis of COVID-19 vaccine related text from reddit. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 43. Bhambhoria, R., Saab, J., Uppa, S., Li, X., Yakimovich, A., Bhatti, J., Valdamudi, N. K., Moyano, D., Bales, M., Dolatabadi, E., & Kocak, S. A. (2022). Towards providing clinical insights on long Covid from twitter data. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 44. Zehtabian, S., Khodadadeh, S., Turgut, D., & Bölöni, L. (2022). Predicting infections in the Covid-19 pandemic—Lessons learned. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 45. Wang, L., & Chen, J. (2022). Improving radiology report generation with adaptive attention. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 46. Revanur, A., Dasari, A., Tucker, C. S., & Jeni, L. A. (2022). Instantaneous physiological estimation using video transformers. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 47. Huang, X., Wicaksana, J., Li, S., & Cheng, K. T. (2022). 
Automated vision-based wellness analysis for elderly care centers. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 48. Sehanobish, A., Brown, N., Daga, I., Pawar, J., Torres, D., Das, A., Becker, M., Herzog, R., Odry, B., & Vianu, R. (2022). Efficient extraction of pathologies from C-Spine radiology reports using multi-task learning. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 49. Xia, T., Han, J., & Mascolo, C. (2022). Benchmarking uncertainty quantification on biosignal classification tasks under dataset shift. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer.
50. Ul Haq, H., Kocaman, V., & Talby, D. (2022). Mining adverse drug reactions from unstructured mediums at scale. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 51. Viñas, R., Zheng, X., & Hayes, J. (2022). A graph-based imputation method for sparse medical records. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 52. Jana, S., Dasgupta, T., & Dey, L. (2022). Using nursing notes to predict length of stay in ICU for critically ill patient. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 53. Han, H. J., BN, S., Qiu, L., & Abdullah, S. (2022). Automatic classification of dementia using text and speech data. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer. 54. Hoang, T., Nguyen, T. T., & Nguyen, H. D. (2022). Unified tensor network for multimodal dementia detection. In Multimodal AI in healthcare: A paradigm shift in health intelligence. Studies in computational intelligence. Springer.
Unsupervised Numerical Reasoning to Extract Phenotypes from Clinical Text by Leveraging External Knowledge Ashwani Tanwar, Jingqing Zhang, Julia Ive, Vibhor Gupta, and Yike Guo
Abstract Extracting phenotypes from clinical text has been shown to be useful for a variety of clinical use cases such as identifying patients with rare diseases. However, reasoning with numerical values remains challenging for phenotyping in clinical text, for example, temperature 102F representing Fever. Current state-of-the-art phenotyping models are able to detect general phenotypes, but perform poorly when they detect phenotypes requiring numerical reasoning. We present a novel unsupervised methodology leveraging external knowledge and contextualized word embeddings from ClinicalBERT for numerical reasoning in a variety of phenotypic contexts. Comparing against unsupervised benchmarks, it shows a substantial performance improvement with absolute gains on generalized Recall and F1 scores up to 79% and 71%, respectively. In the supervised setting, it also surpasses the performance of alternative approaches with absolute gains on generalized Recall and F1 scores up to 70% and 44%, respectively.
Equal contribution: Ashwani Tanwar and Jingqing Zhang.
Keywords Numerical reasoning · Phenotyping · Contextualized word embeddings · Unsupervised learning · Natural language processing · Deep learning
1 Introduction

Extracting phenotypes1 from clinical text has been shown to be crucial for many clinical use cases [37] such as ICU in-hospital mortality prediction, remaining length of stay prediction, decompensation prediction, and identifying patients with rare diseases. There are several challenges in extracting phenotypes, such as handling a wide variety of phenotypic contexts, ambiguities, and long-term dependencies between phenotypes. Numerical reasoning is one of the key challenges, as many phenotypes rely on bedside measurements such as temperature, blood pressure, heart rate, breathing rate, serum creatinine, hematocrit, and glucose levels. We call these terms numeric entities. As these phenotypes require deep reasoning with the numbers, they are often missed or incorrectly predicted by existing phenotype extraction methods [2, 3, 6, 15, 17, 32, 35]. Existing phenotype extraction methods such as the Neural Concept Recognizer (NCR) [2], which are mostly based on state-of-the-art (SOTA) machine learning (ML) and natural language processing (NLP) technologies, exploit non-contextualized word embeddings. These methods cannot detect contextual synonyms of the phenotypes, which can be mentioned in various ways by clinicians. For example, previous SOTA phenotyping models like NCR and NCBO [15] can capture the phenotype Fever from the sentence "patient is detected with fever" but fail to capture the same from the sentence "patient is reported to have high temperature". Similarly, in the sentence "patient is reported to have high temperature in the room with low temperature", only the former instance of temperature is a phenotype, while the latter is not. A recent study [38] demonstrates the capability of contextualized embeddings (BERT-based [7]) to differentiate the two instances by context. However, none of the methods above is specifically designed to reason with numbers in clinical text, for example, "temperature 102F" representing Fever. While contextualized embeddings are useful for reasoning in different contexts, they are not sufficient to address numerical reasoning. In practice, numerical reasoning in the clinical context poses specific challenges. First, clinical text may contain an accumulation of multiple numeric examples in a condensed context such as "Physical examination: temperature 97.5, blood pressure 124/55, pulse 79, respirations 18, O2 saturation 99% on room air.". In addition, the numeric examples can be mentioned in a variety of different contexts such as "temperature of 102F", "temperature is 102F", "temperature is recorded as 102F", 1
In the medical text, the word “phenotype” refers to deviations from normal morphology, physiology, or behaviour, such as skin rash, hypoxemia, neoplasm, etc. [30]. Please note the difference of the phenotypic information to the diagnosis information expressed in ICD-10 codes [25] as the former contributes to the latter.
“temperature is found to be 102F”, which make it more challenging to identify the (numeric entity, number) pair namely (temperature, 102F) in this case. Moreover, numbers in clinical text may not be necessarily connected with phenotypes. For example, the number in “patient required 4 days of hospitalization” is not relevant to any phenotype. To the best of our knowledge, previous studies have not addressed these challenges and we propose the first deep learning based (BERT-based) unsupervised methodology in this paper to accurately extract phenotypes by numerical reasoning from various contexts using external knowledge for clinical natural language processing. In summary, our main contributions are as follows: 1. We propose a new approach to accurately detect phenotypes requiring numerical reasoning using natural language processing and deep learning techniques. 2. The approach is unsupervised and does not require manual data labelling. 3. Our approach can detect phenotypes from a variety of different contexts as it uses contextualized word embeddings.
2 Related Work Phenotyping: Extraction of phenotypes from text has been addressed using several strategies. Shallow matching using linguistic patterns was used extensively by cTAKES [32], MetaMap [3], and Clinphen [6]. Then, the shallow matching was extended to semantic analysis by leveraging non-contextualized word embeddings by the works of [2, 17, 35]. For example, Neural Concept Recognizer (NCR) [2] uses a convolutional neural network (CNN) to build non-contextualized embeddings by leveraging hierarchical medical concepts from biomedical ontologies like Human Phenotype Ontology (HPO) [16]. Finally, [38] showed that using contextualized embeddings from ClinicalBERT [1] helps to detect contextual synonyms of the phenotypes from the text. Similarly, some other works [10, 20, 22, 36] also exploited ClinicalBERT or BioBERT [19] for phenotype detection, but all of these works focus on a limited set or group of phenotypes. None of these methods addresses phenotyping requiring numerical reasoning, so we extend the work and leverage external knowledge with ClinicalBERT to extract phenotypes requiring deep numerical reasoning in different textual contexts. Numerical Reasoning: Recent works publish new datasets for numerical reasoning [39] and utilise deep learning based models to develop the numerical reasoning skills [9, 13, 31, 34] in respective domains other than the clinical domain. For example, [11] shows gains by using artificially created data on various tasks involving numeracy such as math word problems and reading comprehension. Other works [8, 12, 28] designed special modules for numerical reasoning in text which were then integrated with neural networks. Overall, these models have shown advancements in the respective domains for specialized problems but they did not incorporate clinical knowledge with specific extensive reasoning for clinical applications [33].
Fig. 1 The workflow of the proposed Numerical Reasoning (NR) model to extract phenotypes from clinical text by leveraging external knowledge. The external knowledge is first created at one time as shown in Tables 1 and 2. Second, the numbers and lexical candidates are extracted from the input sentence. Third, the corresponding contextualized embeddings of lexical candidates and numeric entities are computed. As the final step, the contextualized embeddings are compared with similarity to find the closest numeric entity and then the phenotype (i.e., HPO concept) is assigned and annotated deterministically given the chosen numeric entity and the extracted number. All steps are further elaborated in Sect. 3
3 Methodology This section presents our unsupervised method for numerical reasoning (NR) in clinical textual notes to extract phenotypes. Figure 1 shows the architecture of the proposed method which includes four steps (1) one-time creation of external knowledge connecting numeric entities and phenotypes, (2) extraction of numbers and lexical candidates for numeric entities from text, (3) creation of contextualized embeddings for numeric entities and lexical candidates and (4) linking candidates to numeric entities by embedding similarity and then determining corresponding phenotypes. This section elaborates all of the above steps.
3.1 External Knowledge Phenotypes can often be inferred from ranges of numerical values together with numeric entities which are mentioned in clinical text. For example, the clinicians may mention the numeric entity temperature and the value “102 F” in clinical notes to suggest a patient is suffering from the phenotype Fever (HP:0001945). Therefore, we create an external knowledge base to formalise such connections between phenotypes, numeric entities and numerical values.
Table 1 Examples of numeric entities that are used in the study with normal reference range and units. For example, the normal range of body temperature is 36.4–37.3 in Celsius or 97.5–99.1 in Fahrenheit. The ID column corresponds to that in Table 2

ID | Numeric entity | Abbreviation | Unit | Lower bound | Upper bound
0 | Temperature | Temp | Celsius | 36.4 | 37.3
0 | Temperature | Temp | Fahrenheit | 97.5 | 99.1
1 | Heart rate | Heart rate | Beats per minute (bpm) | 60 | 80
2 | Breathing rate | Breathing rate | Breaths per minute | 12 | 20
3 | Serum creatinine | Serum creatinine | mg/dL | 0.6 | 1.2
3 | Serum creatinine | Serum creatinine | Micromoles/L | 53 | 106.1
4 | Hematocrit | hct | % | 41 | 48
5 | Blood oxygen | o2 | % | 95 | 100
We first manually collect a list of 33 most frequent numeric entities such as temperature, heart rate, breathing rate, serum anion gap and platelet and their corresponding normal reference ranges (values and units) from the website of National Health Service of UK2 and MIMIC-III database [14]. Table 1 shows examples of numeric entities and their corresponding lower/upper bounds with units. For example, the normal body temperature has a lower bound 97.5 F (36.4 C) and an upper bound 99.1 F (37.3 C). Those numeric entities are then manually mapped with phenotypes which are defined and standardised by Human Phenotype Ontology (HPO) [18]. In most cases, a numeric entity corresponds to three phenotypes depending on whether the actual measurement is lower than, higher than or within the normal reference range. If the actual measurement is lower (higher) than the lower (upper) bound, it means the relevant phenotype is affirmed. For example, the phenotype Hypothermia (HP:0002045) or Fever (HP:0001945) is affirmed if the body temperature is lower or higher than the normal range, respectively. Otherwise, if the body temperature is inside the normal range, the general phenotype Abnormality of Temperature Regulation (HP:0004370), which is the parent phenotype of Hypothermia (HP:0002045) or Fever (HP:0001945), is negated. Table 2 demonstrates examples of the connections between numeric entities and phenotypes. Both Tables 1 and 2 are validated by three expert clinicians with consensus for authenticity and consistency. The external knowledge is created at one time and prior to other steps, which makes the external knowledge reusable. 2
Accessed in November 2021: https://www.nhs.uk.
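To make the organisation of this knowledge base concrete, the following is a minimal sketch in Python (our illustration, not the authors' released code) of how entries like those in Tables 1 and 2 can be stored, together with the deterministic range check described above; the dictionary layout, the entity keys, and the helper name assign_hpo are assumptions for illustration only.

```python
# Illustrative external knowledge: per-unit normal reference ranges plus the three
# HPO concepts attached to each numeric entity (values taken from Tables 1 and 2).
EXTERNAL_KNOWLEDGE = {
    "temperature": {
        "ranges": {"celsius": (36.4, 37.3), "fahrenheit": (97.5, 99.1)},
        "below": ("HP:0002045", "Hypothermia"),
        "above": ("HP:0001945", "Fever"),
        "within": ("HP:0004370", "Abnormality of temperature regulation"),  # negated
    },
    "heart rate": {
        "ranges": {"bpm": (60.0, 80.0)},
        "below": ("HP:0001662", "Bradycardia"),
        "above": ("HP:0001649", "Tachycardia"),
        "within": ("HP:0011675", "Arrhythmia"),  # negated
    },
}

def assign_hpo(entity: str, value: float, unit: str):
    """Deterministically map a (numeric entity, value, unit) triple to an HPO concept."""
    entry = EXTERNAL_KNOWLEDGE[entity]
    lower, upper = entry["ranges"][unit]
    if value < lower:
        return entry["below"] + ("affirmed",)
    if value > upper:
        return entry["above"] + ("affirmed",)
    return entry["within"] + ("negated",)

print(assign_hpo("temperature", 102.0, "fahrenheit"))  # ('HP:0001945', 'Fever', 'affirmed')
```

Because the knowledge base is built once and validated by clinicians, a lookup of this kind can be reused unchanged by every later step of the pipeline.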
Table 2 Examples of numeric entities that are used in the study with phenotype labels (including HPO ID and HPO name). Each numeric entity is connected with three phenotype concepts. For example, a patient has Hypothermia (Fever) if their body temperature is lower (higher) than the lower (upper) normal limit. If the body temperature of the patient is inside the normal range, it is negation of the general phenotype Abnormality of temperature regulation. The ID column corresponds to Table 1

ID | Numeric entity | Abb. | Number lower than the lower bound (Affirmed) | Number higher than the upper bound (Affirmed) | Number inside normal range (Negated)
0 | Temperature | Temp | HP:0002045 Hypothermia | HP:0001945 Fever | HP:0004370 Abnormality of temperature regulation
1 | Heart rate | Heart rate | HP:0001662 Bradycardia | HP:0001649 Tachycardia | HP:0011675 Arrhythmia
2 | Breathing rate | Breathing rate | HP:0046507 Bradypnea | HP:0002789 Tachypnea | HP:0002793 Abnormal pattern of respiration
3 | Serum creatinine | Serum creatinine | HP:0012101 Decreased serum creatinine | HP:0003259 Elevated serum creatinine | HP:0012100 Abnormal circulating creatinine concentration
4 | Hematocrit | hct | HP:0031851 Reduced hematocrit | HP:0001899 Increased hematocrit | HP:0031850 Abnormal hematocrit
5 | Blood oxygen | o2 | HP:0012418 Hypoxemia | HP:0012419 Hyperoxemia | HP:0500165 Abnormal blood oxygen level
Fig. 2 An example of syntactic analysis to extract lexical candidates. In this example, we extract pyrexia, increased, begun as the lexical candidates from the sentence by using Part-of-Speech (POS) tagging and dependency parsing
3.2 Number and Lexical Candidates Extraction We then extract numbers and their corresponding lexical candidates which are likely to be numeric entities from clinical text. For example, in the input sentence “her pyrexia increased to 102F and she was begun on levofloxacin”, the number “102” is connected with “pyrexia” which is contextually similar to the numeric entity “temperature”. The regular expression patterns are created to extract the numbers which typically appear in alpha-numeric format such as “pyrexia increased to 102F” and “heart rate in 90s”. The numbers which are dates or part of specific clinical concepts such as “vitamin B12”, “O2 saturation” are excluded by using a pre-defined dictionary of alpha-numeric words [24] as the numbers are not relevant to phenotypes. After the extraction of numbers, the lexical candidates connected to these numbers are extracted using syntactic analysis. As shown in Fig. 2, we focus on (proper) nouns, adjectives and verbs that are connected with syntactic connections to the extracted numbers (heads or children in the syntactic tree). As a special case, we allow one additional hop from the extracted number via the dependency relation ‘obl’ which stands for oblique nominal. For example, in Fig. 2, the words “pyrexia”, “increased”, and “begun” are extracted as lexical candidates because they are connected to the extracted number “102F” and therefore are likely to represent numeric entities. The list of extracted lexical candidates is passed to the following steps to decide the corresponding numeric entities based on context. As the extraction method of lexical candidates is designed to encourage more extraction to increase recall so that no important word is missed, not all of the lexical candidates will eventually correspond to a numeric entity.
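As a rough illustration of this step, the snippet below combines a regular expression for alphanumeric values with Stanza's POS tags and dependency heads to collect lexical candidates around each number; it uses the generic English pipeline rather than the biomedical Stanza package cited in Sect. 4.2, omits the extra hop via the 'obl' relation, and all patterns and names are simplifications of the description above.

```python
import re
import stanza

stanza.download("en")  # one-time model download
nlp = stanza.Pipeline("en", processors="tokenize,pos,lemma,depparse")

NUM_PATTERN = re.compile(r"^\d+(\.\d+)?[A-Za-z%]*$")   # e.g. "102F", "97.5", "99%"
CANDIDATE_POS = {"NOUN", "PROPN", "ADJ", "VERB"}

def extract_number_candidates(sentence: str):
    """Return (numeric value, lexical candidates) pairs for every number in a sentence."""
    results = []
    for sent in nlp(sentence).sentences:
        words = sent.words                               # Stanza heads are 1-indexed
        for w in words:
            if not NUM_PATTERN.match(w.text):
                continue
            value = float(re.sub(r"[^\d.]", "", w.text))
            # candidates: the syntactic head of the number plus its direct children,
            # restricted to content words
            neighbours = ([words[w.head - 1]] if w.head > 0 else [])
            neighbours += [c for c in words if c.head == w.id]
            candidates = [n.text for n in neighbours if n.upos in CANDIDATE_POS]
            results.append((value, candidates))
    return results

print(extract_number_candidates(
    "her pyrexia increased to 102F and she was begun on levofloxacin"))
```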
3.3 Contextualized Embeddings for Numeric Entities and Lexical Candidates We use contextualized embeddings (ClinicalBERT [1]) of numeric entities and lexical candidates to measure their similarity and decide which numeric entity should be assigned to the extracted lexical candidates from the input sentence. The objective of the model is to learn a semantic space where all possible expressions (including names and synonyms) of one numeric entity are clustered while the expressions of different numeric entities are differentiated. To achieve this, we use
ClinicalBERT finetuned with a Semantic Textual Similarity (STS) objective defined as follows:

\[
\mathcal{L} = \frac{1}{|E|}\,\frac{1}{|S|}\sum_{i=1}^{|E|}\sum_{j=1}^{|S|}\bigl(\cos(h_{e_i}, h_{s_j}) - y_{e_i,s_j}\bigr)^{2},
\qquad
y_{e_i,s_j} =
\begin{cases}
1, & \text{if } s_j \text{ is a synonym of } e_i\\
0, & \text{otherwise}
\end{cases}
\tag{1}
\]
where $h_{e_i}$ represents the contextualized embedding of the $i$th numeric entity $e_i$ in $E$. Similarly, $h_{s_j}$ represents the contextualized embedding of the $j$th synonym $s_j$ in $S$. The ground-truth label $y_{e_i,s_j}$ is 1 if $s_j$ is one of the synonyms of the numeric entity $e_i$ and 0 otherwise. The loss function aims to maximise the cosine similarity between numeric entities and their corresponding synonyms and minimise the similarity between numeric entities and irrelevant synonyms. The collection $\{h_{e_i} \mid e_i \in E\}$ is used as the reference contextualized embeddings of the numeric entities and is created once. As the training data, we collect all synonyms $S = \{s_1, s_2, \ldots, s_{|S|}\}$ of all the numeric entities $E = \{e_1, e_2, \ldots, e_{|E|}\}$ (listed in Table 2) by connecting the HPO IDs with the Unified Medical Language System (UMLS) [5]. During inference, lexical candidates extracted from input sentences are fed into the finetuned ClinicalBERT-based model to produce their contextualized embeddings.
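As an illustration only, the snippet below shows how such STS-style finetuning can be set up with the Sentence Transformers library mentioned in Sect. 4.2; the checkpoint name (emilyalsentzer/Bio_ClinicalBERT), the toy synonym pairs, and the hyperparameters are placeholders rather than the exact configuration used in this chapter.

```python
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses, models

# Wrap a ClinicalBERT checkpoint with mean pooling to obtain term embeddings.
word_model = models.Transformer("emilyalsentzer/Bio_ClinicalBERT", max_seq_length=64)
pooling = models.Pooling(word_model.get_word_embedding_dimension())  # mean pooling by default
model = SentenceTransformer(modules=[word_model, pooling])

# Toy (numeric entity, candidate expression) pairs with 0/1 labels, standing in for
# the UMLS-derived synonym sets of the 33 numeric entities.
train_examples = [
    InputExample(texts=["temperature", "pyrexia"], label=1.0),
    InputExample(texts=["temperature", "body temperature"], label=1.0),
    InputExample(texts=["temperature", "heart rate"], label=0.0),
    InputExample(texts=["heart rate", "pulse"], label=1.0),
]
train_loader = DataLoader(train_examples, shuffle=True, batch_size=2)

# CosineSimilarityLoss regresses cos(h_e, h_s) onto the binary label, mirroring Eq. 1.
train_loss = losses.CosineSimilarityLoss(model)
model.fit(train_objectives=[(train_loader, train_loss)], epochs=1, warmup_steps=10)

# Reference embeddings of numeric entities, computed once after finetuning.
reference_embeddings = {e: model.encode(e) for e in ["temperature", "heart rate"]}
```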
3.4 Embedding Similarity and Deterministic HPO Assignment

Embedding pairs are formed by the Cartesian product of the contextualized embeddings of lexical candidates and the reference contextualized embeddings of numeric entities. Then cosine similarity is computed between all the pairs. The pair with the maximum cosine score above a pre-set threshold gives the selected lexical candidate, which in turn gives the corresponding numeric entity. A sentence may have multiple numbers connected with their corresponding numeric entities. We simply consider the lexical candidates (corresponding to each number) as an independent case for the above Cartesian product, which helps extract multiple candidate numeric entities from a single sentence. After measuring the similarity of embeddings and determining the numeric entities, we deterministically assign the phenotype depending on whether the corresponding number is lower than the lower bound, inside the normal range, or higher than the upper bound. For example, in Fig. 1, the lexical candidate "pyrexia" is extracted and the numeric entity "temperature" is assigned based on contextualized embeddings. As the number "102F" is higher than the upper bound "99.1", the phenotype Fever (HP:0001945) is eventually assigned.
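The following small sketch restates this matching and assignment step in code; the similarity threshold of 0.9 is the one reported in Sect. 4.2, while the function names and the assign_fn hook (e.g. the assign_hpo helper sketched in Sect. 3.1) are illustrative assumptions.

```python
import numpy as np

def cosine(a, b):
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def link_and_assign(candidates, number, unit, reference_embeddings, encode, assign_fn,
                    threshold=0.9):
    """Link lexical candidates to the closest numeric entity, then assign the phenotype.

    candidates:           lexical candidate strings extracted for one number
    reference_embeddings: {numeric entity: vector}, built once from the finetuned model
    encode:               text -> vector, e.g. the finetuned ClinicalBERT encoder
    assign_fn:            deterministic range check, e.g. assign_hpo from Sect. 3.1
    """
    best_entity, best_score = None, threshold
    for cand in candidates:                                  # Cartesian product of pairs
        cand_vec = encode(cand)
        for entity, ref_vec in reference_embeddings.items():
            score = cosine(cand_vec, ref_vec)
            if score > best_score:                           # keep the best pair above threshold
                best_entity, best_score = entity, score
    if best_entity is None:
        return None                                          # number not linked to any entity
    return assign_fn(best_entity, number, unit)
```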
Table 3 A list of granular phenotypes under primary phenotypes. For example, the reduced ejection fraction can be further divided into three sub-phenotypes by severity based on the actual percentage mentioned in clinical text

Primary phenotype (HPO ID, HPO name) | Unit | Granular range (Lower) | Granular range (Upper) | Granular phenotype (HPO ID, HPO name)
HP:0012664 Reduced ejection fraction | % | 0 | 29.9 | HP:0012666 Severely reduced ejection fraction
HP:0012664 Reduced ejection fraction | % | 30 | 39.9 | HP:0012665 Moderately reduced ejection fraction
HP:0012664 Reduced ejection fraction | % | 40 | 49.9 | HP:0012663 Mildly reduced ejection fraction
HP:0001945 Fever | Celsius | 37.4 | 38 | HP:0011134 Low-grade fever
HP:0001945 Fever | Fahrenheit | 99.2 | 100.4 | HP:0011134 Low-grade fever
We also enhance the HPO assignment process by handling different units of numbers (e.g., Fahrenheit and Celsius), because sometimes the units are not explicitly mentioned in the text. Therefore, we decide the unit by comparing the ratio of the number to the extreme ends of the normal reference ranges in different units. For example, the normal range for temperature is (36.4, 37.3) in Celsius and (97.5, 99.1) in Fahrenheit. If a given number is 92, we take the following ratios:

• number / upper_bound_celsius = 92 / 37.3 ≈ 2.5
• lower_bound_fahrenheit / number = 97.5 / 92 ≈ 1.1

The unit giving the smaller ratio (Fahrenheit in this case) is then used to determine the HPO assignment.
Moreover, we consider granular phenotypes based on granular sub-ranges as shown in Table 3.
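The toy function below restates the unit-disambiguation heuristic under our reading of the description above (the ratio is taken against the violated bound and the unit with the ratio closest to 1 wins); the names and the exact tie-breaking are illustrative assumptions.

```python
def unit_ratio(value, lower, upper):
    """Ratio of the value to the violated end of the normal range (1.0 if inside it)."""
    if value > upper:
        return value / upper
    if value < lower:
        return lower / value
    return 1.0

def guess_unit(value, ranges):
    """ranges: {unit: (lower, upper)}, e.g. the temperature entry of Table 1."""
    return min(ranges, key=lambda unit: unit_ratio(value, *ranges[unit]))

print(guess_unit(92, {"celsius": (36.4, 37.3), "fahrenheit": (97.5, 99.1)}))  # fahrenheit
```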
4 Experiment Design 4.1 Datasets We use clinical textual notes from the publicly available MIMIC-III database [14]. In the unsupervised setting, we collected 705 EHR textual notes with 20,926 gold phenotype annotations as shown in Table 4. The gold phenotype annotations were
Table 4 Statistics (counts) of the test sets in the unsupervised and supervised setting, respectively. The test set in the unsupervised setting includes all manually annotated EHRs. The test set in the supervised setting is a subset of that in the unsupervised setting because some annotated EHRs are used to finetune the baseline models. Please note only Numerical Reasoning (NR) specific phenotypes are used for evaluation as the other phenotypes are not related with numbers in clinical narratives

Test set | EHRs | All phenotypes | NR-specific phenotypes
Unsupervised setting | 705 | 20,926 | 1,121
Supervised setting | 170 | 5,047 | 322
created by three expert clinicians with consensus, and the clinicians were specifically asked to identify contextual synonyms of phenotypes such as "drop in blood pressure" and "BP of 79/48" for Hypotension (HP:0002615). Out of these phenotype annotations, we select a subset with 1,121 phenotype annotations (i.e., NR-specific phenotypes) which require numerical reasoning, based on two criteria: (1) the annotated phenotypes are among the HPO IDs that require numerical reasoning as mentioned in Tables 2 and 3, and (2) the corresponding textual spans of phenotypes contain numbers. The test set in the unsupervised setting is used to compare the proposed NR model with previous unsupervised baseline methods. In the supervised setting, as 535 out of 705 manually annotated EHRs are used to finetune the baseline methods (like ClinicalBERT), the remaining 170 EHRs are used for testing. In other words, the test set in the supervised setting is a subset of that in the unsupervised setting. Though the proposed NR model is strictly unsupervised, we compare it with supervised baselines to rigorously assess its performance.
4.2 Implementation Details We use the Stanford Stanza [27, 40] library to extract the lexical candidates for numeric entities using syntactic analysis. In syntactic analysis, we only focus on nouns, adjectives and verbs that are “NOUN”, “PROPN”, “ADJ”, and “VERB” as marked by the Part of speech (POS) tagger and we also optimise the process by adding words with the dependency relation ‘compound’ to capture multi-word phrases like “heart rate” and “blood pressure”. Then, we use Semantic Textual Similarity (STS) model from Sentence Transformers [29] library to finetune the ClinicalBERT embeddings with cosine similarity up to 4 epochs using their default hyperparameters3 along with a train and validation batch size of 16 and 1000 evaluation steps. Mean pooling is used to get embeddings of multi-word UMLS synonyms. The threshold for embedding similarity is set as 0.9 empirically. The implementation of the proposed method also uses some other third-party libraries including PyTorch [26] and spaCy. 3
Accessed in November 2021: https://www.sbert.net/docs/training/overview.html.
4.3 Baselines and Evaluation Methods We compare the proposed NR model with previous state-of-the-art phenotyping models. In the unsupervised setting, the proposed NR model is compared with unsupervised baselines including NCBO [15], NCR [2] and the unsupervised model by [38]. In the supervised setting, the proposed NR model is compared with the finetuned ClinicalBERT [1] (which is finetuned for phenotyping) and the supervised model by [38]. The NCBO, NCR and finetuned ClinicalBERT are selected as they show better performance than other baseline phenotyping methods (including cTAKES [32], MetaMap [3], Clinphen [6], MedCAT [17], BERT [7], BioBERT [19], SciBERT [4]) in corresponding settings as demonstrated by [38]. Please note the work by [38] publishes one unsupervised and one supervised model hence we compare the proposed NR model with both. We decide not to compare with recent numerical reasoning models (such as [9, 13, 31, 34]) as none of them incorporates clinical knowledge and we find it costly to adapt them to the clinical domain. We first evaluate the proposed NR model against the baselines by using microaveraged Precision, Recall and F1-score at the document level. To ensure comparison with previous studies, we follow the practice by [21] and compute the metrics by the following two strategies. (1) Exact Matches: Only the exact same HPO annotations against the gold standard annotations are counted as correct; (2) Generalized Matches: the gold standard annotations as well as predicted HPO annotations are extended to include all ancestor HPO concepts until the root concept Phenotypic Abnormality (HP:0000118) (exclusive) in the HPO hierarchy. All the extended HPO annotations are then de-duplicated and added to the list of gold standard and predicted HPO annotations respectively for evaluation. By the generalized matches, the prediction of HPO concepts which are children, neighbours or ancestors of the target HPO concepts also receives credits.
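To make the generalized-match computation concrete, here is a small, hedged sketch of the document-level micro-averaged evaluation described above; the parent map would in practice be read from the HPO hierarchy, and all function names are our own.

```python
ROOT = "HP:0000118"   # Phenotypic Abnormality, excluded from the extension

def extend_with_ancestors(hpo_ids, parents):
    """Close a set of HPO IDs under ancestors (parents: {hpo_id: [parent_ids]})."""
    extended, stack = set(), list(hpo_ids)
    while stack:
        hpo = stack.pop()
        if hpo == ROOT or hpo in extended:
            continue
        extended.add(hpo)
        stack.extend(parents.get(hpo, []))
    return extended

def micro_prf(gold_sets, pred_sets, parents, generalized=True):
    """Micro-averaged Precision/Recall/F1 over documents (exact or generalized matches)."""
    tp = fp = fn = 0
    for gold, pred in zip(gold_sets, pred_sets):             # one pair of sets per document
        g = extend_with_ancestors(gold, parents) if generalized else set(gold)
        p = extend_with_ancestors(pred, parents) if generalized else set(pred)
        tp += len(g & p); fp += len(p - g); fn += len(g - p)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1
```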
5 Results and Discussion

5.1 Quantitative Analysis

We report our quantitative results in Table 5, where we evaluate the NR model in the unsupervised setting. We also compare the NR model with the baselines – NCBO, NCR, and the unsupervised model by [38] – but they perform poorly on the unsupervised test set with straight 0 on all the metrics. This is expected as they are not designed to handle numbers. The NR model performs significantly better than all of them, achieving 69% recall and 59% F1 using exact metrics, and 79% recall and 71% F1 using generalized metrics. Precision is relatively lower because we focus on recall to extract more phenotypes, which is motivated by the preference for a model that is sensitive enough to capture more phenotypic features of patients rather than missing them,
Table 5 In the unsupervised setting, the comparison of baselines NCBO, NCR, and [38] (unsupervised) with proposed Numerical Reasoning (NR) model shows the superior performance of NR model. Interestingly but not surprisingly, the baseline methods produce zero accuracy as they are not designed to reason by numbers

Model | Exact Precision | Exact Recall | Exact F1 | Generalized Precision | Generalized Recall | Generalized F1
NCBO | 0 | 0 | 0 | 0 | 0 | 0
NCR | 0 | 0 | 0 | 0 | 0 | 0
[38] (unsupervised) | 0 | 0 | 0 | 0 | 0 | 0
Numerical Reasoning (NR) | 0.5176 | 0.6879 | 0.5907 | 0.6479 | 0.7907 | 0.7122
Table 6 The comparison of supervised baselines with the proposed Numerical Reasoning (NR) model in the supervised setting shows that the NR model increases recall significantly by finding more phenotypes even without supervision. Please note supervised setting refers to a subset of unsupervised setting test set which is created to compare unsupervised NR with the supervised baselines

Model | Exact Precision | Exact Recall | Exact F1 | Generalized Precision | Generalized Recall | Generalized F1
Finetuned ClinicalBERT | 0.8235 | 0.181 | 0.2968 | 1.000 | 0.2229 | 0.3646
[38] (supervised) | 0.6791 | 0.6293 | 0.6532 | 0.8245 | 0.7762 | 0.7996
Numerical Reasoning (NR) | 0.5952 | 0.7543 | 0.6654 | 0.7290 | 0.8339 | 0.7780
[38] (supervised) + NR | 0.5921 | 0.8448 | 0.6963 | 0.7175 | 0.9201 | 0.8062
for better accuracy in downstream clinical use cases [37]. Overall, the NR model shows huge gains which is useful in the absence of costly annotated data. We also compare the unsupervised NR model with the previous state-of-the-art supervised baseline methods. First, we compare the NR model with the supervised model by [38] which is finetuned with annotated data. This comparison is shown in Table 6 on the supervised test set. Though the supervised model by [38] outperforms its unsupervised version, the proposed unsupervised NR model performs better than the supervised baseline with gains of 12.5 and 5.7% on exact and generalized recall, respectively. However, there is a drop in precision which results in the comparable F1 scores. Moreover, using a combination of both the models achieves the best performance improving score by 21.5 and 14.3% on exact and generalized recall, respectively, and 4.3 and 0.7% gains on exact and generalized F1 scores, respectively. Then, the NR model is compared against the finetuned ClinicalBERT [1] which is finetuned to detect phenotypes. The combination of NR model and supervised model by [38] surpasses the performance of the baseline with gains of 66.4 and 69.7% on exact and generalized recall, respectively and 40 and 44.2% gains on exact and
generalized F1 scores, respectively, as shown in Table 6. These results highlight the impact of the NR model, which performs better than the supervised models while eliminating the need for costly human annotations of phenotypes.
5.2 Qualitative Analysis

We investigate the numerical reasoning capabilities of the proposed NR model and other baseline methods by manually inspecting example sentences with different contexts. In the sentence "patient has a temperature of 102F.", NCR, NCBO, and [38] (unsupervised) do not detect any phenotype. But after adding the word high, i.e., "patient has a high temperature of 102F.", [38] (unsupervised) correctly detects the phenotype Fever (HP:0001945). However, the predicted textual span is only "high temperature", ignoring the number 102F. This indicates that [38] (unsupervised) relies on context without considering numbers, while NCR and NCBO still do not detect any phenotype. When the word "temperature" is changed to "fever" and the whole sentence becomes "patient has a high fever of 102F.", all three unsupervised baseline methods can correctly detect the phenotype Fever (HP:0001945), though the number is still ignored in the predicted textual span. Overall, we observe that all the unsupervised baseline methods rely solely on the textual content and ignore the numbers, though [38] (unsupervised) can find contextual synonyms of phenotypes. In contrast, the proposed NR model correctly detects the phenotype from all three variants of the original sentence with the correct textual spans, which include the numbers. More precisely, the target textual spans are "temperature of 102F", "temperature of 102F", and "fever of 102F" with the phenotype Fever (HP:0001945) for the three sentences above, respectively. We observe similar behavior given the sentence "patient has a breathing rate of 27." with the phenotype Tachypnea (HP:0002789), as well as "patient has a serum creatinine of 1.7." with the phenotype Elevated serum creatinine (HP:0003259). The model [38] (unsupervised) detects the phenotype (still ignoring the numbers) when an indicative word like "high" is added, while NCBO and NCR miss the annotations, with the exception of the latter sentence where NCR detects the phenotype after "high" is added to the sentence. In short, the results suggest that the proposed NR model reasons with the numbers effectively in different contexts without supervision. The supervised model by [38] overall performs much better, with reasonable accuracies, than the unsupervised baselines, which give straight 0 scores. However, it still lacks the capability to reason with numbers. For instance, though [38] (supervised) correctly predicts the phenotype Fever (HP:0001945) from the sentence "patient has a temperature of 102F.", if the number in the sentence is changed from 102F to 92F and the target phenotype is therefore changed to Hypothermia (HP:0002045), [38] (supervised) still mistakenly predicts Fever. Similar incorrect predictions are observed when the target phenotype is changed from Tachypnea (HP:0002789) to Bradypnea (HP:0046507) and from Elevated serum creatinine (HP:0003259) to Decreased serum creatinine (HP:0012101). We hypothesize that Fever
is far more common than Hypothermia in the training data, so the model is finetuned with a bias towards the highly frequent phenotypes. This may inflate the scores in Table 6 for [38] (supervised) and overestimate its numerical reasoning capabilities. Based on this observation, we conclude that supervision without additional tailored learning objectives is not sufficient to obtain numerical reasoning capabilities. However, there are some cases where the NR model fails to produce accurate predictions. For example, in the text "Pt still with scant bibasilar crackles. Sat @ 97% on 2L NG. Continuing with oral HTN meds and Dig.", the model predicts Abnormal blood oxygen level (HP:0500165) to negate the phenotype for "Sat @ 97%", as 97% is within the normal reference range for blood oxygen, i.e., 95–100%. However, the correct phenotype is Hypoxemia (HP:0012418), as the patient achieved this normal range using external oxygen, which is implied by the phrase "2L NG".
5.3 Ablation Studies

We conduct two ablation studies to probe the benefit of contextualized embeddings and of the learning objective for finetuning in Eq. 1. To evaluate the usage of contextualized embeddings with cosine similarity to connect lexical candidates with numeric entities, as described in Sect. 3.4, we ablate the contextualized embeddings and instead use keyword-based shallow matching to connect lexical candidates with numeric entities. Table 7 shows that the ablated method results in a significant performance drop, more precisely, in exact Recall from 68.8 to 26.4% and F1 from 59.1 to 38.1% on the unsupervised test set.
Table 7 Ablation studies on the unsupervised test set. Comparison of Numerical Reasoning (NR) model variants using keyword based shallow matching of lexical candidates with numeric entities, pretrained contextualized embeddings and finetuned contextualized embeddings. The finetuned contextualized embeddings substantially outperform the other two methods and are incorporated into the final NR model

NR model with | Exact Precision | Exact Recall | Exact F1 | Generalized Precision | Generalized Recall | Generalized F1
Keyword based shallow matching | 0.6854 | 0.2641 | 0.3813 | 0.7745 | 0.3449 | 0.4773
Pretrained contextualized embeddings | 0.5065 | 0.3758 | 0.4314 | 0.6006 | 0.465 | 0.5241
Finetuned contextualized embeddings (used by the final NR model) | 0.5176 | 0.6879 | 0.5907 | 0.6479 | 0.7907 | 0.7122
(a) Pretrained contextualized embeddings (b) Finetuned contextualized embeddings (c) Color codes for numeric entities
Fig. 3 UMAP visualization of pretrained and finetuned contextualized embeddings of numeric entities and their UMLS synonyms by pretrained and finetuned ClinicalBERT, respectively. Finetuning leads to better differentiation of numeric entities in the semantic space which helps the NR model to identify them with higher accuracy
Therefore, contextualized embeddings are beneficial for capturing the semantics of lexical candidates (corresponding to numeric entities) appearing in different contexts. In Table 7, we also compare pretrained and finetuned contextualized embeddings. The pretrained embeddings are generated by the pretrained ClinicalBERT model without finetuning, and the finetuned embeddings are generated after finetuning ClinicalBERT using the Semantic Textual Similarity (STS) objective in Eq. 1, as mentioned in Sect. 3.3. As shown in Table 7, the pretrained contextualized embeddings perform poorly, with a drop in exact Recall from 68.8 to 37.6% and in F1 from 59.1 to 43.1% on the unsupervised test set. For better interpretation, we visualize the pretrained and finetuned contextualized embeddings of numeric entities and their corresponding UMLS synonyms in Fig. 3 by using Uniform Manifold Approximation and Projection (UMAP) dimensionality reduction [23]. We find that, with the pretrained contextualized embeddings, most of the numeric entities are spread out unevenly in the space. For example, the data points for (general) cholesterol, low-density lipoprotein cholesterol, and high-density lipoprotein cholesterol are intermixed. In contrast, the finetuned contextualized embeddings form well-segregated clusters, which makes it easier to predict the corresponding numeric entity of a lexical candidate (connected with a number) using cosine similarity without collisions. Overall, this confirms that pretrained contextualized embeddings are not sufficient to connect lexical candidates with numeric entities effectively without the proposed learning objective for finetuning in Eq. 1.
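For readers who want to produce a projection of this kind, the snippet below is an illustrative sketch using the umap-learn and matplotlib packages; the random arrays merely stand in for the real entity/synonym embeddings and their entity labels.

```python
import numpy as np
import umap
import matplotlib.pyplot as plt

embeddings = np.random.rand(200, 768)            # stand-in for ClinicalBERT embeddings
entity_ids = np.random.randint(0, 33, size=200)  # stand-in numeric-entity labels

coords = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42).fit_transform(embeddings)
plt.scatter(coords[:, 0], coords[:, 1], c=entity_ids, cmap="tab20", s=8)
plt.title("UMAP projection of numeric-entity synonym embeddings")
plt.show()
```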
6 Conclusions and Future Works

Numerical reasoning is critical for capturing important phenotypes, such as bedside measurements, from clinical text. Current state-of-the-art phenotyping models are not designed to reason with numbers, and thus all of them perform poorly in detecting the phenotypes that require numerical reasoning. The proposed unsupervised model shows substantial gains over these models due to its explicit design to reason with numbers by leveraging external knowledge. The proposed model can potentially be generalized to other biomedical NLP tasks that require numerical reasoning from text. The model can be further extended to consider document-level context and a dynamic external knowledge base.

Acknowledgements We would like to thank Dr. Garima Gupta, Dr. Deepa (M.R.S.H) and Dr. Ashok (M.S.) for helping us create gold-standard phenotype annotation data and validate the external knowledge for numerical reasoning.
References 1. Alsentzer, E., et al. (2019) Publicly available clinical BERT embeddings. In: Proceedings of the 2nd Clinical Natural Language Processing Workshop (pp. 72–78). Minneapolis, Association for Computational Linguistics: Minnesota, USA. 2. Arbabi, A., et al. (2019). Identifying clinical terms in medical text using Ontology-Guided machine learning. JMIR Medical Informatics, 7(2), e12596 (2019). 3. Aronson, A. R., et al. (2010). An overview of MetaMap: Historical perspective and recent advances. Journal of the American Medical Informatics Association : JAMIA, 17(3), pp. 229— 236 (2010). ISSN: 1527-974X (Electronic). 4. Beltagy, I., et al. (2019). SciBERT: A pretrained language model for scientific text. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP) (pp. 3615–3620). Association for Computational Linguistics: Hong Kong, China. 5. Bodenreider, O. (2004). The unified medical language system (UMLS): Integrating biomedical terminology. Nucleic Acids Research 32, pp. 267–270. Database-Issue. 6. Deisseroth, C. A., et al. (2019). ClinPhen extracts and prioritizes patient phenotypes directly from medical records to expedite genetic disease diagnosis. Genetics in Medicine, 21(7), 1585– 1593. 7. Devlin, J., et al. (2019). BERT: Pre-training of deep bidirectional transformers for language understanding. In J. Burstein, et al. (Eds.), Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2–7, 2019 (Vol. 1, pp. 4171–4186). Association for Computational Linguistics (Long and Short Papers). 8. Dua, D., et al. (2019). DROP: A reading comprehension benchmark requiring discrete reasoning over paragraphs. In J. Burstein, et al. (Eds.), Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, NAACL-HLT 2019, Minneapolis, MN, USA, June 2–7, 2019 (Vol. 1, pp. 2368–2378). Association for Computational Linguistics (Long and Short Papers).
9. Duan, H., et al. (2021). Learning numeracy: A simple yet effective number embedding approach using knowledge graph. In Findings of the Association for Computational Linguistics: EMNLP 2021 (pp. 2597–2602). Association for Computational Linguistics: Punta Cana, Dominican Republic. 10. Franz, L., et al. (2020). A deep learning pipeline for patient diagnosis prediction using electronic health records. arXiv:2006.16926 11. Geva, M., et al. (2020). Injecting numerical reasoning skills into language models. In D. Jurafsky, et al. (Eds.), Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, ACL 2020, Online, July 5–10, 2020, pp. 946–958. Association for Computational Linguistics. 12. Hu, M., et al. (2019). A multi-type multi-span network for reading comprehension that requires discrete reasoning. In K. Inui, et al. (Eds.), Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing, EMNLP-IJCNLP 2019, Hong Kong, China, November 3–7, 2019, pp. 1596–1606. Association for Computational Linguistics. 13. Jin, Z., et al. (2021). NumGPT: Improving numeracy ability of generative pre-trained models. arXiv:abs/2109.03137 14. Johnson, A. E. W., et al. (2016). MIMIC-III, a freely accessible critical care database. Scientific Data 3(1), 1–9. 15. Jonquet, C., et al. (2009). NCBO annotator: Semantic annotation of biomedical data. In International Semantic Web Conference, Poster and Demo Session (Vol. 110). 16. Köhler, S., et al. (2021). The human phenotype ontology in 2021. Nucleic Acids Research, 49, pp. D1207–D1217. Database-Issue. 17. Kraljevic, Z., et al. (2019). MedCAT—Medical concept annotation tool. 18. Köhler, S., et al. (2016). The human phenotype ontology in 2017. Nucleic Acids Research, 45(D1), D865–D876. ISSN: 0305-1048. 19. Lee, J., et al. (2019). BioBERT: A pre-trained biomedical language representation model for biomedical text mining. Bioinformatics. ISSN: 1367-4803. 20. Li, Y., et al. (2020). BEHRT: Transformer for electronic health records. Scientific Reports, 10(1), 1–12. 21. Liu, C., et al. (2019). Ensembles of natural language processing systems for portable phenotyping solutions. Journal of Biomedical Informatics, 100, 103318. ISSN: 1532-0464. 22. Liu, D., et al. (2019). Two-stage federated phenotyping and patient representation learning. In Proceedings of the 18th BioNLP Workshop and Shared Task (pp. 283–291). Association for Computational Linguistics: Florence, Italy. 23. McInnes, L., et al. (2018). UMAP: Uniform manifold approximation and projection for dimension reduction. 24. Moon, S., et al. (2014). A sense inventory for clinical abbreviations and acronyms created using clinical notes and medical dictionary resources. Journal of the American Medical Informatics Association, 21(2), 299–307. 25. World Health Organization. (2004). ICD-10: International statistical classification of diseases and related health problems: Tenth revision. 26. Paszke, A., et al. (2019). PyTorch: An imperative style, high-performance deep learning library. In H. Wallach, et al. (Eds.), Advances in Neural Information Processing Systems (Vol. 32). Curran Associates, Inc. 27. Qi, P., et al. (2020). Stanza: A python natural language processing toolkit for many human languages. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations. 28.
Ran, Q., et al. (2019). NumNet: Machine reading comprehension with numerical reasoning. In K. Inui et al. (Eds.), Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing, EMNLP-IJCNLP 2019, Hong Kong, China, November 3–7, 2019 (pp. 2474–2484). Association for Computational Linguistics.
29. Reimers, N., et al. (2019). Sentence-BERT: Sentence embeddings using siamese BERTnetworks. In K. Inui, et al. (Eds.), Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing, EMNLP-IJCNLP 2019, Hong Kong, China, November 3–7, 2019 (pp. 3980–3990). Association for Computational Linguistics. 30. Robinson, P. N. (2012). Deep phenotyping for precision medicine. Human Mutation 33(5), 777–780. 31. Saha, A., et al. (2021). Weakly supervised neuro-symbolic module networks for numerical reasoning. arxiv:abs/2101.11802 32. Savova, G. K., et al. (2010). Mayo clinical text analysis and knowledge extraction system (cTAKES): Architecture, component evaluation and applications. Journal of the American Medical Informatics Association: JAMIA, 17(5), 507–513. ISSN: 1067-5027. 33. Sushil, M., et al. (2021). Are we there yet? Exploring clinical domain knowledge of BERT models. In Proceedings of the 20th Workshop on Biomedical Language Processing (pp. 41–53). Association for Computational Linguistics. 34. Thawani, A., et al. (2021). Numeracy enhances the literacy of language models. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (pp. 6960–6967). Association for Computational Linguistics: Online and Punta Cana, Dominican Republic. 35. Tiwari, P., et al. (2020). TermInformer: unsupervised term mining and analysis in biomedical literature. Neural Computing and Applications. ISSN: 1433-3058. 36. Yang, Z., et al. (2020). Combining deep learning with token selection for patient phenotyping from electronic health records. Scientific Reports, 10(1), 1432. ISSN: 2045-2322. 37. Zhang, J., et al. (2021) Clinical utility of the automatic phenotype annotation in unstructured clinical notes: ICU use cases. arXiv:2107.11665 38. Zhang, J., et al. (2021). Self-supervised detection of contextual synonyms in a multi-class setting: Phenotype annotation use case. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing (pp. 8754–8769). Association for Computational Linguistics: Online and Punta Cana, Dominican Republic. 39. Zhang, Q., et al. (2021). NOAHQA: Numerical reasoning with interpretable graph question answering dataset. In Findings of the Association for Computational Linguistics: EMNLP 2021 (pp. 4147–4161). Association for Computational Linguistics: Punta Cana, Dominican Republic. 40. Zhang, Y., et al. (2021). Biomedical and clinical English model packages for the Stanza Python NLP library. Journal of the American Medical Informatics Association, 28(9), 1892–1899. ISSN: 1527-974X.
Domain-specific Language Pre-training for Dialogue Comprehension on Clinical Inquiry-Answering Conversations Zhengyuan Liu, Pavitra Krishnaswamy, and Nancy F. Chen
Abstract There is growing interest in the automated extraction of relevant information from clinical dialogues. However, it is difficult to collect and construct large annotated resources for clinical dialogue tasks. Recent developments in natural language processing suggest that large-scale pre-trained language backbones could be leveraged for such machine comprehension and information extraction tasks. Yet, due to the gap between pre-training and downstream clinical domains, it remains challenging to exploit the generic backbones for domain-specific applications. Therefore, in this work, we propose a domain-specific language pre-training, to improve performance on downstream tasks like dialogue comprehension. Aside from the common token-level masking pre-training method, according to the nature of human conversations and interactive flow of multi-topic inquiry-answering dialogues, we further propose sample generation strategies with speaker and utterance manipulation. The conversational pre-training guides the language backbone to reconstruct the utterances coherently based on the remaining context, thus bridging the gap between general and specific domains. Experiments are conducted on a clinical conversation dataset for symptom checking, where nurses inquire and discuss symptom information with patients. We empirically show that the neural model with our proposed approach brings improvement in the dialogue comprehension task, and can achieve favorable results in the low resource training scenario. Keywords Machine comprehension · Clinical conversation · Language pre-training
Z. Liu (B) · P. Krishnaswamy · N. F. Chen: Institute for Infocomm Research (I2R), A*STAR, Singapore, Singapore
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_3
1 Introduction

As one of the fundamental tasks in natural language processing, machine reading comprehension has been fueled by avid neural modeling investigations in recent years. Given a certain textual content, the goal is to answer a series of questions based on semantic understanding. Many studies have focused on monological documents like Wikipedia [20] and news articles [7], and some recent works focus on dialogue comprehension [16, 17, 24]. Different from passages, human-to-human dialogues are a dynamic and interactive flow of information exchange [21], which are often informal, verbose, and repetitive, and this brings unique challenges for adopting document-oriented approaches on conversational samples. Recently, there has been growing interest in the automated extraction of relevant information from clinical dialogues [2, 14, 16], in the form of dialogue comprehension. However, neural models trained via supervised learning usually require a certain amount of training data, and it is difficult to collect and construct large annotated resources for clinical-related tasks. Fine-tuning large-scale pre-trained language models has become a data-efficient learning paradigm and achieves substantial improvement on various downstream tasks and applications [1, 15]. However, these general-purpose language backbones are trained on online crawled data with universal objectives. Although this can provide feature-rich contextualized representations, it also limits their capability in specific domains. On the other hand, while some recent studies have proposed methods for pre-training on dialogue samples [3, 28, 30–32], they focus more on the generation of daily conversations (e.g., open-domain social chat). When adopting such generic language backbones in more targeted scenarios such as clinical-related tasks, their performance becomes sub-optimal due to the significant domain difference [9]. To address such challenges, in this work, we propose a domain-specific language pre-training and adopt it to improve spoken dialogue comprehension performance on clinical inquiry-answering conversations. While the common methods for constructing pre-training samples (e.g., the token masking used in BERT [1], the infilling and deletion used in BART [11]) have proved effective for contextualized modeling, text manipulation designed for specific resources/tasks can bring further improvement [6]. Therefore, considering the nature of human conversations [21] and the interactive flow of inquiry-answering dialogues, we propose a set of sample manipulations for conversational pre-training. At the token level, we introduce a masking strategy especially on speaker tokens, and we propose a permuting operation for better speaker role modeling. At the utterance level, we propose an utterance masking and intra-topic permutation scheme. More specifically, given a dialogue, we randomly select one utterance, and mask it or exchange it with another span in the same topic, and the language model is guided to reconstruct the coherent conversation flow based on the context. Moreover, we add an additional token at the beginning of each utterance to explicitly present the utterance boundary information.
Our experiments are conducted on a multi-topic symptom checking conversation dataset, where nurses inquire and discuss symptom information with patients. We empirically show that the proposed approach brings significant improvement on the dialogue comprehension task, especially in the low resource scenario.
2 Domain-Specific Language Pre-training In this section, we elaborate on the six types of pre-training sample generation. We then describe the reconstruction-based learning process of language modeling to fuse domain-specific conversational features.
2.1 Conversation-based Sample Construction

Human-to-human spoken conversations are an interactive process of information exchange. Compared with monological passages, a language backbone for multiparty dialogues is required to infuse conversational linguistic features [5], such as speaker roles and utterance boundary information [14], as well as the underlying dialogue discourse structures [21]. To guide language backbones to model the characteristics of conversations, we adopt the following six types of pre-training samples, as shown in Table 1.

Token Masking: Following the common pre-training scheme proposed in [1], 5% of the tokens in the sequence are randomly sampled and replaced with a mask token.

Token Infilling: Following the denoising learning scheme proposed in [11], 5% of input tokens are randomly sampled and replaced with random tokens extracted from the vocabulary.

Speaker Masking: To encourage the model to grasp the speaker information [5] in the interactive flow, we randomly sample 10% of the utterances and replace their speaker tokens with a mask token.

Speaker Permutation: Previous work shows that the text infilling scheme is helpful for tackling the training-testing inconsistency issue of masking-based methods [11]. Thus, aside from speaker masking, we randomly sample 10% of the utterances and exchange their speakers with another speaker in the conversation.

Utterance Masking: To encourage the model to grasp more contextual information at the utterance level, we further adopt a span masking scheme. More specifically, we randomly sample 10% of the utterances and mask the whole span, and the model tries to recover it based on its understanding of the context.

Intra-Topic Utterance Permutation: While human conversations are often less structured than well-organized documents, they are inherently organized around the dialogue topics in a coarse-grained structure [21]. Therefore, we propose an intra-topic utterance permutation strategy, in which we randomly sample 5% of the utterances and exchange each of them with another utterance from the same topic.
Table 1 Pre-training data processing for dialogue samples. Here the raw utterances are extracted from the synthetic clinical inquiry-answering conversations. In the processed utterances, masked positions are replaced with the mask token, infilled positions are replaced with a random token from a vocabulary, and the permutation operations change the order of two randomly selected tokens/utterances

Raw utterances (identical for every manipulation):
Nurse: Do you have any headache at night?
Patient: No no headache, just a bit cough..
Nurse: Cough? you mean cough at every night?

Token masking (processed):
Nurse: Do you have any headache at ?
Patient: No no headache, just a cough..
Nurse: Cough? you cough at every night?

Token infilling (processed):
Nurse: Do you have any at night?
Patient: No headache, just a bit cough..
Nurse: Cough? you mean cough at every night?

Speaker masking (processed):
Nurse: Do you have any headache at night?
: No no headache, just a bit cough..
: Cough? you mean cough at every night?

Speaker permutation (processed):
Do you have any headache at night?
Nurse: No no headache, just a bit cough..
Nurse: Cough? you mean cough at every night?

Utterance masking (processed):
Nurse: Do you have any headache at night?
Patient: ..
Nurse: Cough? you mean cough at every night?

Intra-topic utterance permutation (processed):
Nurse: Do you have any headache at night?
Nurse: Cough? you mean cough at every night?
Patient: No no headache, just a bit cough..
This operation injects more noise into the conversation flow, and the model can only restore the original order by leveraging the underlying dialogue discourse information. Moreover, we add a special token at the start position of each utterance, which can convey the utterance-level boundary information [30], and is similar to the sentence-level token used in document-based approaches [15].
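As an informal illustration of two of these utterance-level manipulations, the sketch below corrupts a list of (speaker, text) utterances; the sampling rates follow the description above, while the placeholder special tokens <mask> and <utt> and the function names are our own choices, not necessarily the tokens used by the authors.

```python
import random

def speaker_mask(utterances, rate=0.1):
    """utterances: list of (speaker, text); mask the speaker of roughly `rate` of them."""
    return [("<mask>" if random.random() < rate else speaker, text)
            for speaker, text in utterances]

def intra_topic_permute(utterances, topic_ids, rate=0.05):
    """Swap a sampled utterance with another utterance from the same topic."""
    out = list(utterances)
    for i in range(len(out)):
        if random.random() >= rate:
            continue
        same_topic = [j for j, t in enumerate(topic_ids) if t == topic_ids[i] and j != i]
        if same_topic:
            j = random.choice(same_topic)
            out[i], out[j] = out[j], out[i]
    return out

def serialize(utterances):
    """Prefix every utterance with a boundary token and join into one input sequence."""
    return " ".join(f"<utt> {speaker}: {text}" for speaker, text in utterances)

dialogue = [("Nurse", "Do you have any headache at night?"),
            ("Patient", "No no headache, just a bit cough.."),
            ("Nurse", "Cough? you mean cough at every night?")]
print(serialize(intra_topic_permute(speaker_mask(dialogue), topic_ids=[0, 0, 0])))
```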
Fig. 1 Overview of the language model pre-training. Here we adopt the bi-directional mask-recovering training as in [15]. Superscript values denote utterance ID
2.2 Experiment Setup of Pre-training With the conversation samples built on the aforementioned strategies, we conduct the language backbone pre-training, and a Transformer-based neural architecture is used [25]. To leverage the generic prior language knowledge, we select ‘RoBERTa-base’ model to initialize the Transformer model, and conduct the reconstruction-based learning process [15]. More specifically, as shown in Fig. 1, the input sequence is the dialogue content after text manipulation, and the target sequence consists of the tokens from the original utterances. In our experiment, the pre-training data was the combination of the SAMSum corpus [4], a subset of OpenSubtitles [12], and a set of synthetic clinical symptom checking conversations (40k samples). SAMSum is a social chat corpus consisting of 16k dialogues. OpenSubtitles is compiled from a large collection of TV and movie scripts across many languages, and we randomly selected 50k samples from the English part. The mixed conversational data contains a certain amount of dialogues with multiple participants, speaker role information, and conversational structures. During training, we fixed the max input length to 512 tokens. When constructing the pre-training samples, we first conducted text manipulations on tokens and speaker entities. Then the utterance-level operations were performed. We trained the language backbone with 5,000 warm-up steps. Batch size was set to 64 via applying gradient accumulation, and the initial learning rate was set at 1e-5. Cross-entropy was used as the loss function, and we selected the checkpoints at the knee point of loss decrease.
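The snippet below sketches a single reconstruction-style training step under simplifying assumptions (HuggingFace transformers, the generic roberta-base checkpoint, and a corrupted/original pair padded to the same length); it is meant only to make the objective concrete, not to replicate the authors' training setup.

```python
import torch
from transformers import RobertaTokenizerFast, RobertaForMaskedLM

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
model = RobertaForMaskedLM.from_pretrained("roberta-base")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

corrupted = "Nurse: Do you have any <mask> at night?"   # manipulated input
original = "Nurse: Do you have any headache at night?"  # reconstruction target

inputs = tokenizer(corrupted, return_tensors="pt", padding="max_length",
                   max_length=32, truncation=True)
labels = tokenizer(original, return_tensors="pt", padding="max_length",
                   max_length=32, truncation=True)["input_ids"]
labels[labels == tokenizer.pad_token_id] = -100          # ignore padding in the loss

loss = model(**inputs, labels=labels).loss               # cross-entropy on original tokens
loss.backward()
optimizer.step()
```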
3 Dialogue Comprehension on Clinical Inquiry-Answering Conversations

3.1 Task Definition

One example of the dialogue comprehension task on clinical inquiry-answering conversations is shown in Table 2. The input consists of a multi-turn symptom-checking dialogue D and a question Q specifying a symptom together with one of its attributes; the output is the answer A extracted from the given dialogue. A training or test sample is defined as S = {D, Q, A}. Five attributes, specifying certain details of clinical significance, are defined to characterize the answer types of A: (1) the time the patient has been experiencing the symptom, (2) activities that trigger the symptom (to occur or worsen), (3) the extent of seriousness, (4) the frequency of occurrence of the symptom, and (5) the location of the symptom. Each symptom or attribute can take on different linguistic expressions, which we refer to as entities.
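The following illustrative sketch shows one way a sample S = {D, Q, A} could be represented in code; the field names and the example indices are assumptions for demonstration only, not the authors' data schema.

```python
# A hypothetical container for one comprehension sample S = {D, Q, A}.
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class ComprehensionSample:
    dialogue: List[Tuple[str, str]]   # D: ordered (speaker, utterance) pairs
    question: str                     # Q: queries one attribute of a symptom
    answer_span: Tuple[int, int]      # A: start/end token indices in the dialogue
    no_answer: bool = False           # True when the queried attribute is absent

sample = ComprehensionSample(
    dialogue=[("Nurse", "Is your left leg still swollen?"),
              ("Patient", "Yes, only a bit when I drink too much water ...")],
    question="What is the extent of the swollen?",
    answer_span=(5, 7),  # hypothetical indices covering "only a bit"
)
```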
3.2 Clinical Dialogue Corpus

The reading comprehension task is conducted on data from nurse-to-patient symptom monitoring conversations. The corpus was inspired by real dialogues in the clinical setting where nurses inquire about patients' symptoms [16]. Linguistic structures at the semantic, syntactic, discourse, and pragmatic levels were abstracted from these conversations to construct templates for simulating multi-turn dialogues (40k samples in our setting). The informal styles of expression, including incomplete sentences, incorrect grammar, and a diffuse flow of topics, were preserved. A team of linguistically trained personnel refined, substantiated, and corrected the automatically simulated dialogues by enriching verbal expressions across different English-speaking populations in Asia, Europe, and the U.S.; validating logical correctness by checking that the conversations were natural, reasonable, and did not defy common sense; and verifying the clinical content by consulting certified and registered nurses.
Table 2 One example of the reading comprehension task on clinical inquiry-answering conversations. The synthetic dialogue is used for demonstration

Conversation example (truncated):
Nurse: Hi Mr. [Name], you were discharged on [date]. There are some questions I'd like to check with you
Patient: Ok, Ok ... I think I feel better ...
Nurse: Is your left leg still swollen? You said so the last time I call you?
Patient: Yes, only a bit when I drink too much water ...
Question: What is the extent of the swollen?
Reference Answer Span: only a bit
These conversations cover 9 topics/symptoms (e.g., headache, cough). For each conversation, the average number of words is 255 and the average number of turns is 15.5. For the comprehension task, questions were raised to query different attributes of a specified symptom, e.g., How frequently did you experience headaches? Answer spans in the dialogues were labeled with start and end indices, following the annotation scheme in [20]. Note that if the queried symptom or attribute is not mentioned in the dialogue, the ground-truth output is "No Answer", following the same definition as in [16].
3.3 Baseline Models

We further fine-tuned the Transformer-based model on the dialogue comprehension task and compared it with several baselines, including a Pointer LSTM [26], the Bi-Directional Attention Flow network (Bi-DAF) [22], and R-Net [27]. To evaluate the effectiveness of our domain-specific language pre-training, we use the Vanilla Transformer and the original RoBERTa-base model as controls; our proposed model is RoBERTa-base w/ Domain-specific Pre-training. As shown in Fig. 2, we formulate the comprehension task as an answer extraction process. With the feature-rich contextualized representation, the answer span is generated by predicting its start/end positions in the sequence, by adding a linear layer on the last-layer hidden states of the language model [1, 15].
Fig. 2 Overview of the machine comprehension model on a question-answering task. Q denotes the question sequence (in green), and it is concatenated with the conversation sequence as input. Following [1, 16], the answer span is extracted from the conversation content by predicting the start and end positions (in purple)
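A minimal sketch of the answer-extraction head described above is given below; it mirrors standard extractive QA heads on top of a RoBERTa encoder and is an illustration under that assumption, not the authors' exact implementation.

```python
# Span-extraction head: a single linear layer over the last-layer hidden
# states produces start and end logits for every token.
import torch.nn as nn
from transformers import RobertaModel

class SpanExtractor(nn.Module):
    def __init__(self, name="roberta-base"):
        super().__init__()
        self.encoder = RobertaModel.from_pretrained(name)
        self.span_head = nn.Linear(self.encoder.config.hidden_size, 2)

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids=input_ids,
                              attention_mask=attention_mask).last_hidden_state
        logits = self.span_head(hidden)              # (batch, seq_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        return start_logits.squeeze(-1), end_logits.squeeze(-1)
```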
3.4 Training Configuration

All models were implemented with PyTorch and Hugging Face Transformers [18]. For models without pre-trained language backbones (e.g., Bi-DAF, R-Net), GloVe embeddings [19] were utilized, and out-of-vocabulary words were replaced with the unknown-word token. Their hidden size and embedding dimension were 300, while those of the Transformer-based models were 768. We used Adam [8] with batch size 32, and gradient accumulation was applied. The initial learning rates were set at 2e-5, and the dropout rate [23] was set to 0.2. During training, a validation-based early-stopping strategy was applied. During prediction, we selected answer spans using the maximum product of p_start and p_end.
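The span-selection rule described above can be sketched as follows; the length constraint `max_answer_len` is an assumption added for illustration and is not stated in the chapter.

```python
# Select the answer span maximizing p_start * p_end over valid (start, end) pairs.
import torch

def best_span(start_logits, end_logits, max_answer_len=30):
    p_start = torch.softmax(start_logits, dim=-1)
    p_end = torch.softmax(end_logits, dim=-1)
    best, best_score = (0, 0), -1.0
    for s in range(len(p_start)):
        for e in range(s, min(s + max_answer_len, len(p_end))):
            score = (p_start[s] * p_end[e]).item()
            if score > best_score:
                best, best_score = (s, e), score
    return best  # token indices of the predicted answer span
```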
3.5 Evaluation: Comparison with Baselines

We conduct the evaluation on the synthetic clinical dialogue corpus, where the training, validation, and test sizes were 40k, 3k, and 3k, respectively. We adopted Exact Match (EM) and F1 score as metrics, as in the SQuAD benchmark [20]. As shown in Table 3, the vanilla Transformer model obtains slightly lower performance than the strong non-Transformer baselines (i.e., Bi-DAF and R-Net), and 'RoBERTa-base' is on par with them. This demonstrates that the prior knowledge from general language pre-training is beneficial for downstream tasks. With the conversational pre-training, our proposed model obtains substantial gains and achieves the best EM and F1 scores, showing that the domain-specific feature fusion is effective.
Table 3 Evaluation result of the baseline models and our approach on the test set. Domain PT denotes the proposed domain-specific pre-training

Model                        | EM Score | F1 Score
Pointer LSTM                 | 77.81    | 82.71
Bi-Attention Flow (Bi-DAF)   | 87.31    | 88.57
R-Net (Our Implementation)   | 88.22    | 90.13
Vanilla Transformer          | 85.38    | 85.92
RoBERTa-base w/o Domain PT   | 88.37    | 90.15
RoBERTa-base w/ Domain PT    | 92.31    | 93.69
Fig. 3 Experimental results on low-resource training. The x-axis is the training sample size; the y-axis shows the evaluation metrics, exact match (EM) and F1 score
3.6 Evaluation in Low-Resource Scenarios

The limited amount of training data is a major pain point for clinical language tasks, as it is time-consuming and labor-intensive to collect and annotate corpora at a large scale. Following observations in previous work [6], we expect that domain-specific language modeling can result in more efficient learning on downstream tasks. To simulate the low-resource training scenario, we conducted experiments on a range of smaller training sizes (from 3k to 40k) with a fixed-size test set (3k samples). As shown in Fig. 3, the proposed approach outperforms all other models significantly, especially when the training size is smaller than 20k.
3.7 Evaluation: Pre-training Scheme Comparison

To evaluate the effectiveness of the aforementioned strategies for pre-training sample construction, we conduct an experiment in which the different text manipulations are added incrementally when training the general-purpose language backbone. As shown in Table 4, we observed that token-level infilling brings certain improvements, and that introducing the conversation-related manipulations (i.e., speaker and utterance masking and permutation) is helpful for the final dialogue comprehension performance.
Table 4 Performance comparison on pre-training schemes. The text manipulations are added from token level to utterance level

Model (pre-training on the RoBERTa-base backbone) | EM score | F1 score
+ Token-level Masking                              | 88.79    | 90.30
+ Token-level Infilling                            | 90.73    | 92.12
+ Speaker Mask & Permutation                       | 91.01    | 92.41
+ Utterance Mask & Permutation                     | 92.31    | 93.69
4 Conclusions

In this paper, we introduced a domain-specific language pre-training approach and adopted it to improve performance on downstream tasks such as question answering. Based on the linguistic characteristics of spoken dialogues, we proposed a combination of six strategies to build samples for conversational language pre-training, and conducted reading comprehension experiments on multi-topic inquiry-answering conversation data. The experimental results showed that the proposed approach can boost performance and achieve more efficient learning outcomes. Future work includes extending conversational pre-training to other clinical tasks [10, 13] and resources.

Acknowledgements Research efforts were supported by funding and infrastructure from A*STAR, Singapore (Grant No. IAF H19/01/a0/023). We gratefully acknowledge valuable inputs from Angela Ng, Hong Choon Oh, Sharon Ong, Sheldon Lee, Weiliang Huang, and Ying Zi Oh at the Department of Cardiology, Health Management Unit, and Department of Health Services Research, Changi General Hospital, Singapore. We thank the anonymous reviewers for their precious feedback to help improve and extend this piece of work.
References
1. Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2019). BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of NAACL-HLT 2019 (Vol. 1, pp. 4171–4186).
2. Du, N., Chen, K., Kannan, A., Tran, L., Chen, Y., & Shafran, I. (2019). Extracting symptoms and their status from clinical conversations. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (pp. 915–925).
3. Gao, X., Zhang, Y., Galley, M., Brockett, C., & Dolan, W. B. (2020). Dialogue response ranking training with large-scale human feedback data. In Proceedings of EMNLP (Vol. 2020, pp. 386–395).
4. Gliwa, B., Mochol, I., Biesek, M., & Wawer, A. (2019). SAMSum corpus: A human-annotated dialogue dataset for abstractive summarization. EMNLP-IJCNLP, 2019, 70.
5. Gu, J.-C., Li, T., Liu, Q., Ling, Z.-H., Su, Z., Wei, S., et al. (2020). Speaker-aware BERT for multi-turn response selection in retrieval-based chatbots. In Proceedings of the 29th ACM International Conference on Information & Knowledge Management (pp. 2041–2044).
6. Gururangan, S., Marasović, A., Swayamdipta, S., Lo, K., Beltagy, I., Downey, D., et al. (2020). Don't stop pretraining: Adapt language models to domains and tasks. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics (pp. 8342–8360).
7. Hermann, K. M., Kočiský, T., Grefenstette, E., Espeholt, L., Kay, W., Suleyman, M., et al. (2015). Teaching machines to read and comprehend. In Proceedings of the 28th International Conference on Neural Information Processing Systems - Volume 1, NIPS'15 (pp. 1693–1701). Cambridge, MA, USA: MIT Press.
8. Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimization. In Proceedings of the 3rd International Conference for Learning Representations.
9. Krishna, K., Khosla, S., Bigham, J. P., & Lipton, Z. C. (2021). Generating SOAP notes from doctor-patient conversations using modular summarization techniques. In Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (pp. 4958–4972).
10. Kurisinkel, L. J., Aw, A. T., & Chen, N. F. (2021). Coherent and concise radiology report generation via context specific image representations and orthogonal sentence states. NAACL-HLT, 2021, 246.
11. Lewis, M., Liu, Y., Goyal, N., Ghazvininejad, M., Mohamed, A., Levy, O., et al. (2020). BART: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics (pp. 7871–7880).
12. Lison, P., & Tiedemann, J. (2016). OpenSubtitles2016: Extracting large parallel corpora from movie and TV subtitles. In Proceedings of the Tenth International Conference on Language Resources and Evaluation (LREC'16) (pp. 923–929).
13. Liu, Z., Ng, A., Lee, S., Aw, A. T., & Chen, N. F. (2019). Topic-aware pointer-generator networks for summarizing spoken conversations. In 2019 IEEE Automatic Speech Recognition and Understanding Workshop (ASRU) (pp. 814–821). IEEE.
14. Liu, Z., & Chen, N. (2019). Reading turn by turn: Hierarchical attention architecture for spoken dialogue comprehension. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (pp. 5460–5466).
15. Liu, Y., Ott, M., Goyal, N., Du, J., Joshi, M., Chen, D., et al. (2019). RoBERTa: A robustly optimized BERT pre-training approach. arXiv:1907.11692.
16. Liu, Z., Lim, H., Suhaimi, N. F. A. B., Tong, S. C., Ong, S., Ng, A., et al. (2019). Fast prototyping a dialogue comprehension system for nurse-patient conversations on symptom monitoring. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies. Association for Computational Linguistics.
17. Ma, K., Jurczyk, T., & Choi, J. D. (2018). Challenging reading comprehension on daily conversation: Passage completion on multiparty dialog. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Long Papers) (Vol. 1, pp. 2039–2048). Association for Computational Linguistics.
18. Paszke, A., Gross, S., Massa, F., Lerer, A., Chanan, G., Bradbury, J., et al. (2019). PyTorch: An imperative style, high-performance deep learning library. Advances in Neural Information Processing Systems, 32, 8026–8037.
19. Pennington, J., Socher, R., & Manning, C. (2014). GloVe: Global vectors for word representation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (pp. 1532–1543). Association for Computational Linguistics.
20. Rajpurkar, P., Zhang, J., Lopyrev, K., & Liang, P. (2016). SQuAD: 100,000+ questions for machine comprehension of text. In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing (pp. 2383–2392). Association for Computational Linguistics.
21. Sacks, H., Schegloff, E. A., & Jefferson, G. (1978). A simplest systematics for the organization of turn taking for conversation. In Studies in the organization of conversational interaction (pp. 7–55). Elsevier.
22. Seo, M., Kembhavi, A., Farhadi, A., & Hajishirzi, H. (2017). Bidirectional attention flow for machine comprehension. In Proceedings of the 5th International Conference for Learning Representations.
23. Srivastava, N., Hinton, G., Krizhevsky, A., Sutskever, I., & Salakhutdinov, R. (2014). Dropout: A simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research, 15(1), 1929–1958.
24. Sun, K., Yu, D., Chen, J., Yu, D., Choi, Y., & Cardie, C. (2019). DREAM: A challenge data set and models for dialogue-based reading comprehension. Transactions of the Association for Computational Linguistics, 7, 217–231.
25. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., et al. (2017). Attention is all you need. In Advances in neural information processing systems (pp. 5998–6008).
26. Wang, S., & Jiang, J. (2017). Machine comprehension using match-LSTM and answer pointer. In ICLR 2017: International Conference on Learning Representations, Toulon, France, April 24–26: Proceedings (pp. 1–15).
27. Wang, W., Yang, N., Wei, F., Chang, B., & Zhou, M. (2017). Gated self-matching networks for reading comprehension and question answering. In Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers) (pp. 189–198). Association for Computational Linguistics.
28. Wu, C.-S., Hoi, S. C., Socher, R., & Xiong, C. (2020). TOD-BERT: Pre-trained natural language understanding for task-oriented dialogue. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP) (pp. 917–929).
29. Yang, Z., & Choi, J. D. (2019). FriendsQA: Open-domain question answering on TV show transcripts. In Proceedings of the 20th Annual SIGdial Meeting on Discourse and Dialogue (pp. 188–197).
30. Zhang, Y., Sun, S., Galley, M., Chen, Y.-C., Brockett, C., Gao, X., et al. (2020). DIALOGPT: Large-scale generative pre-training for conversational response generation. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations (pp. 270–278).
31. Zhong, M., Liu, Y., Xu, Y., Zhu, C., & Zeng, M. (2022). DialogLM: Pre-trained model for long dialogue understanding and summarization. In Proceedings of AAAI 2022.
32. Zou, Y., Zhang, X., Lu, W., Wei, F., & Zhou, M. (2020). Pre-training for abstractive document summarization by reinstating source text. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP) (pp. 3646–3660).
Clinical Dialogue Transcription Error Correction Using Seq2Seq Models Gayani Nanayakkara, Nirmalie Wiratunga, David Corsar, Kyle Martin, and Anjana Wijekoon
Abstract Good communication is critical to good healthcare. Clinical dialogue is a conversation between health practitioners and their patients, with the explicit goal of obtaining and sharing medical information. This information contributes to medical decision-making regarding the patient and plays a crucial role in their healthcare journey. The reliance on note-taking and manual scribing processes is extremely inefficient and leads to transcription errors when digitising notes. Automatic Speech Recognition (ASR) plays a significant role in speech-to-text applications and can be directly used as a text generator in conversational applications. However, recording clinical dialogue presents a number of general and domain-specific challenges. In this paper, we present a seq2seq learning approach for ASR transcription error correction of clinical dialogues. We introduce a new Gastrointestinal Clinical Dialogue (GCD) Dataset, gathered by healthcare professionals from an NHS Inflammatory Bowel Disease clinic, and use it in a comparative study with four commercial ASR systems. Using self-supervision strategies, we fine-tune a seq2seq model on a mask-filling task using a domain-specific PubMed dataset which we have shared publicly for future research. The BART model fine-tuned for mask-filling was able to correct transcription errors and achieve lower word error rates for three out of four commercial ASR outputs.
G. Nanayakkara (B) · N. Wiratunga · D. Corsar · K. Martin · A. Wijekoon Robert Gordon University, Aberdeen, Scotland e-mail: [email protected] N. Wiratunga e-mail: [email protected] D. Corsar e-mail: [email protected] K. Martin e-mail: [email protected] A. Wijekoon e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_4
Keywords Clinical dialogue transcription · Automatic speech recognition · Error correction
1 Introduction

Traditional approaches to record keeping in a health service setting have relied on pen and paper, with every clinical professional asking the same questions of the same patient. Drawbacks of this approach include the time burden of recording clinical communications, the potential for error and, most importantly, the fact that the patient is more often than not required to repeat and share the same details asked in a different way. The absence of accurate clinical dialogue capture is a contributory factor to poor communication in medical practice [18]. To a patient this promotes mistrust and a feeling of fragmented care in a system that seems not to be linked up. Clinical documentation is time-consuming and is associated with clinician burnout, increased cognitive load, information loss, and distractions [19]. One of the most promising avenues for automating clinical documentation with digital scribes is to use an Automatic Speech Recognition (ASR) [21] system, whereby, in a process called digital transcription, audio data received as input is converted to textual data as output. Recent advances in Natural Language Processing (NLP) and the adoption of cloud-based technologies have created a significant market for ASR systems. Due to the critical nature of the domain, ASR systems for clinical applications are expected to demonstrate high levels of performance. However, variations in language, speech and environmental contexts have made it hard to achieve an acceptable level of transcription accuracy [6]. Thus, it is important to examine strategies to mitigate or reduce the likelihood of error in a transcription. There are two approaches to correcting ASR errors: redesign and retrain the core ASR architecture; or alternatively perform post-ASR error correction on the transcribed ASR output. In this paper we focus on the second approach and use a fine-tuned seq2seq neural model to map an ASR-transcribed piece of text to its error-corrected form. We select T5 [22] and BART [14] as our seq2seq models due to their dominant performance across domains. A self-supervised training strategy with fine-tuning tasks is used with a novel domain-specific dataset scraped from PubMed¹ abstracts. We identified a lack of specific clinical dialogue datasets in the related literature and address this deficit by introducing a novel clinical dialogue dataset which is used to test the effectiveness of our error correction models. Results from a comparative study of seq2seq models show that our proposed approach can reduce transcription errors that are introduced by several commercial ASR systems. Accordingly, our contributions are:
• We demonstrate clinical dialogue error correction using the Gastrointestinal Clinical Dialogue (GCD) Dataset, which we gathered in partnership with National Health Service (NHS) Scotland;
https://www.ncbi.nlm.nih.gov/pubmed/.
• A self-supervision methodology to fine-tune language models for clinical dialogue error correction using a novel PubMed dataset; and
• A comparative evaluation of fine-tuned language models for clinical dialogue transcription error correction.
The rest of the paper is organised as follows. Section 2 presents related literature on ASR error correction methods. The Gastrointestinal Clinical Dialogue (GCD) dataset is presented in Sect. 3, followed by Sect. 4, which presents the language models considered for error correction and our approach to fine-tuning them using the self-supervised PubMed datasets. Section 5 presents the comparative evaluation of language models and fine-tuned models for error correction using the GCD dataset. In Sect. 6 we further investigate the performance improvements observed in the previous section to draw insights for future work. Finally, we present our conclusions in Sect. 7.
2 Related Work

ASR techniques are used to capture real-time speech in audio format and convert it into textual output. In clinical settings, ASR can be used as the initial step to gather conversational data and to produce meaningful insights from the generated ASR transcriptions. However, ASR performance depends mainly on three factors: speaker variabilities (changes in voice due to ageing, illness, emotions and tiredness); spoken language variabilities (variants in speech due to accents and dialects); and other mismatch factors (communication channels and devices) [6]. These factors degrade the performance of ASR systems and lead to erroneous results, from which it is challenging to extract meaningful insights. The types of errors found in speech recognition are threefold: insertion, deletion, and substitution [6]. Word Error Rate (WER) is the common evaluation metric used to assess ASR outputs, considering the three error types mentioned above [6, 7]. There are two alternative approaches to ASR error correction: implementing an error correction algorithm within the ASR model; or applying a post-processing step in which the ASR outputs are analysed for error correction. Hidden Markov Models [8, 11] and, more recently, deep neural architectures [9] have been explored for ASR models that include error correction. The alternative (and increasingly more common) approach involving post-ASR error correction has in the past adopted unsupervised approaches. Early methods include lexical co-occurrence analysis on large ASR transcription corpora [23] and statistical error correction methods [3]. FastCorrect is a more recent Transformer-based architecture which integrates the edit distance metric within a deep neural architecture to guide error correction [13]. Alternatively, Transformer-based architectures have been fine-tuned for error correction using part of the domain-specific dataset (a train set) [17]. Increasingly, for post-ASR error correction there is potential to exploit recent advances in language modelling.
Accordingly, in this study, we also focus on post-ASR error correction using a Transformer-based architecture. However, instead of implementing a customised architecture [13] or fine-tuning with clinical dialogue data (of which only a very limited amount is available) [17], we explore how to effectively fine-tune a pre-trained model using publicly available clinical domain data.
3 Clinical Dialogue Transcription

Clinical dialogue is a conversation, typically between a clinician and a patient, with the explicit goal of obtaining medical information regarding the patient. This information contributes to medical decision-making regarding the patient and plays a crucial role in their healthcare journey. The clinician will ask questions about the patient's medical state, including personal details, medical history, symptoms, prescribed medicines and clinical tests. Recording these conversations often relies on the clinician maintaining paper-based notes, which in turn requires effort and time when converting to patient Electronic Health Records (EHRs). Once created, EHRs are maintained by a centralised system and can be shared among all of the different medical specialists who may take part in a patient's care journey. To clarify, consider the following fictitious example. A patient makes an appointment with their local general practice due to complaints of a stomach ache. The general practice clinician then uses the information in the patient's EHR to check for allergies before prescribing new medication. Information regarding the prescription is then appended to the patient's EHR and is visible during a follow-on appointment with a gastrointestinal specialist, allowing them to try a new treatment without repetition. Accurate and timely EHRs are therefore crucial to the success of patient treatment. Accordingly, there is a demand for algorithms that efficiently create accurate electronic health records from clinical dialogues. Manual note-taking is an inefficient process compared to digitisation of clinical dialogues. Clinicians lose valuable time on administrative tasks which could be put to better use elsewhere. Furthermore, there is the possibility of errors or misunderstandings being introduced at the note-taking and digitisation stages. Speech-to-text transformation presents an opportunity to create (or update) EHRs from clinical dialogues. However, speech-to-text conversion in a clinical setting presents a number of challenges, both general (i.e. accurate transcription in the presence of background noise, different speaker accents and dialects, interruptions and repetitions) and domain-specific (i.e. recognising expert vocabulary and non-overlapping expertise between conversation participants). Once corrected, the goal is to extract clinical data and insights to formulate summaries of conversations that can then be captured as part of the patient's electronic health record. To understand the performance of speech-to-text transcription of clinical dialogues we require data specifically in the form of audio and its reference transcript. To the best of our knowledge, there is no available clinical conversational dataset in the English language. Here we describe an applied machine learning task using the
Gastrointestinal Clinical Dialogue (GCD) dataset, which was collected in partnership with the National Health Service (NHS) Scotland. In this work we limit our scope to gastrointestinal disease-related clinical conversations, which mainly took place in the Inflammatory Bowel Disease (IBD) clinic.
3.1 Gastrointestinal Clinical Dialogue Dataset

The clinical conversations in the GCD dataset were generated using role-playing conversations initiated in the NHS IBD Clinic. These conversations contain clinical dialogues that often take place between an IBD clinician and a patient. The data collection included 7 participants with Scottish accents. The accent can be viewed as a form of noise in addition to common noise factors such as background noise, interruptions and repetitions. Overall, we collected 7 audio files, each with 4-5 minutes of conversation. Each audio file contains a mean of 47 utterances in which two persons engage in a clinical conversation. A summary of the audio data statistics is presented in Table 1. The GCD dataset consists of 329 (∼47*7) data instances, where each data instance has three components: the audio file, the reference transcript (i.e. gold standard) and multiple ASR transcripts. The reference transcript was created by listening to the audio and manually transcribing it. To create the ASR transcriptions, we used four commercially available ASR systems: AWS Transcribe, Microsoft speech-to-text, IBM Watson and Google speech-to-text. These ASR services were selected from a large number of commercial and open-source ASR services based on their support for the British English accent and their popularity. Table 2 presents some examples of reference transcriptions and their ASR outputs (AWS Transcribe in this example) from the GCD dataset. In Fig. 1 we plot the transcription error rate, measured by Word Error Rate (WER), for each ASR system. WER scores are calculated against the reference transcript of each audio file and the transcribed output from each ASR system. Accordingly, we find that the Microsoft speech-to-text service generates the most accurate transcriptions from the GCD Dataset and the Google speech-to-text service generates the least accurate. Although these are commercial ASR systems, the lack of knowledge of medical domain terms and the presence of background noise may contribute to the performance differences
Table 1 Summary of the GCD dataset

Feature                            | Value
No. of audio files                 | 7
Mean length of an audio file       | 4 min 49 s
Mean no. of utterances in a file   | 47
Mean no. of words in an utterance  | 93
Table 2 Examples from the GCD dataset

Gold reference | AWS Transcribe output
So do you have any ideas as to what might be the cause of your symptoms at the moment? | So do you have any ideas as to what might be the cause of your symptoms at the moment?
Have you noticed any changes in your weight? | Do you noticed any changes in your wit?
Okay have you noticed any mucus in your bowel motions? | Okay have you noticed any mucus in your bible Moshe?
Fig. 1 Comparison of ASR performance
seen in Fig. 1. In the next section we present our post-ASR error correction approach using seq2seq models.
4 Methods

We view error correction as a seq2seq task performed using an Encoder-Decoder (ED) architecture-based language model. Here the text generated by the ASR forms the input to the encoder part of the ED architecture, and the decoder part is trained to generate the reference text. As illustrated in Fig. 2, a pre-trained language model also needs to be fine-tuned. The reasons for this are twofold: the ED models are general-purpose and not adapted to terminology in the medical domain (i.e. a vocabulary gap); and they are not fine-tuned to perform error correction (i.e. an objective gap). Essentially we need to integrate the domain vocabulary (e.g. gastrointestinal terminology) into the pre-trained language model that we want to use for error correction. As discussed in Sect. 3, there is only a limited amount of data available for our clinical dialogue error correction task. Accordingly, it is challenging to use this dataset for both training (fine-tuning) and testing as we would in traditional machine learning. Instead we curate a dataset of gastrointestinal text extracted from PubMed
Fig. 2 Clinical dialogue transcription error correction using seq2seq model
to fine-tune the general-purpose language models on three distinct self-supervision tasks which are closer in nature to the error correction task. In this manner we aim to address the challenges posed by the vocabulary gap and the difference in training objectives, respectively.
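To illustrate the encoder-decoder error-correction step sketched in Fig. 2, the following hedged example shows how a (fine-tuned) BART checkpoint would generate a corrected hypothesis from an ASR transcription; the checkpoint name here is the generic base model, a placeholder for the authors' fine-tuned one.

```python
# The ASR transcription is encoded and the decoder generates a corrected hypothesis.
from transformers import BartForConditionalGeneration, BartTokenizerFast

tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

asr_output = "Okay have you noticed any mucus in your bible Moshe?"
inputs = tokenizer(asr_output, return_tensors="pt", truncation=True)

generated = model.generate(**inputs, max_length=64, num_beams=4)
corrected = tokenizer.decode(generated[0], skip_special_tokens=True)
print(corrected)  # with a fine-tuned checkpoint, ideally close to the reference transcript
```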
4.1 General Purpose Base Language Models

Seq2seq learning has been performed using ED architectures built on recurrence [1, 2] and on Transformers [14, 22]. Transformer-based ED architectures are the state of the art, and different variants exist that have been trained for different language modelling tasks. In this paper we consider the Bidirectional and Auto-Regressive Transformer (BART [14]) and the Text-to-Text Transfer Transformer (T5 [22]) models, both of which have been pre-trained using large language corpora.
4.1.1 T5 Model
T5 is a Transformer-based architecture that is pre-trained on a mixture of seq2seq tasks [22]. The T5 architecture resembles the original Transformer architecture [24], with 12 Transformer blocks in both the encoder and the decoder. The T5 model is trained on the Colossal Clean Crawled Corpus (C4), a collection of clean and natural English text crawled from the internet in April 2019 (750 GB in size). Multiple T5 model variants, differing in the number of attention blocks, have been introduced [22]. In this work, we fine-tune the T5-Small, T5-Base and T5-Large models with 6, 12 and 24 attention blocks respectively.
4.1.2 BART
The Bidirectional and Auto-Regressive Transformer (BART) also uses the standard Transformer architecture, pre-trained for multiple seq2seq tasks. BART can be viewed as a Transformer-based architecture in which the encoder is a generalised BERT (i.e. a bidirectional encoder) [5] and the decoder is a GPT (i.e. a left-to-right decoder) [20]. The BART model is trained on a combination of book and Wikipedia data, consisting of news, books, stories, and web text (160 GB in size). The model is trained by corrupting text with an arbitrary noising function and learning to reconstruct the original text.
4.2 PubMed Gastrointestinal Dataset

PubMed is a collection of ∼33 million citations of biomedical literature collated from sources such as MEDLINE, life science journals and online books, created by the US National Library of Medicine. It provides a search and retrieval engine for the public to extract biomedical articles. These articles contain full text and/or abstract text and are annotated with unique record identifiers called PMIDs. PubMed has been used as a resource for content classification [4], biomedical question answering [10] and biomedical entity recognition [12]. We also extract data from PubMed, to create the PubMed Gastrointestinal Dataset. The main goal of extracting data from PubMed is to create a dataset that introduces medical terminology to the base language models. When choosing a dataset, there are differences between written and spoken English within specialist domains. The lack of availability of a larger spoken corpus has led us to use a written corpus, which is currently the most viable alternative. To this end, we adopted the following protocol to extract data from PubMed: 1. crawl PubMed to extract paper titles and abstracts. In this work we limit the search to articles related to gastrointestinal research; accordingly, the search terms are gastrointestinal symptoms, diagnosis, clinical, examination, and patient;
Table 3 Summary of the PubMed dataset

Feature                          | Value
Number of title, abstract pairs  | 11,772
Mean no. of words in title       | 102
Mean no. of words in abstract    | 1,533
2. clean the titles by removing Unicode characters; and 3. clean the abstracts by removing Unicode characters, equations, figures, tables and URLs. After pre-processing we obtain a dataset of title and abstract pairs (see Table 3 for corpus statistics). The methods presented in this paper can be generalised to any medical domain by using domain-specific search queries in the PubMed database search engine.
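The crawling step of this protocol could, for example, be implemented against NCBI's E-utilities. The following hedged sketch uses the Biopython Entrez wrapper; the query string, email and retmax value are illustrative assumptions, not the authors' exact configuration.

```python
# Fetch (title, abstract) pairs for gastrointestinal-related PubMed articles.
from Bio import Entrez

Entrez.email = "researcher@example.org"  # required by NCBI; placeholder address

# Step 1: search for related articles and collect PMIDs
handle = Entrez.esearch(db="pubmed",
                        term="gastrointestinal symptoms diagnosis clinical examination patient",
                        retmax=200)
pmids = Entrez.read(handle)["IdList"]
handle.close()

# Step 2: fetch records and keep title/abstract pairs
handle = Entrez.efetch(db="pubmed", id=",".join(pmids), rettype="xml")
records = Entrez.read(handle)
handle.close()

pairs = []
for article in records["PubmedArticle"]:
    art = article["MedlineCitation"]["Article"]
    title = str(art.get("ArticleTitle", ""))
    abstract = " ".join(str(p) for p in art.get("Abstract", {}).get("AbstractText", []))
    if title and abstract:
        pairs.append((title, abstract))  # cleaning (Unicode, URLs, ...) would follow
```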
4.3 Fine-Tune Using Self-Supervision

The approach of using the same unsupervised data to create multiple training objectives is known as self-supervision. When fine-tuning the base language models we need to create self-supervision tasks that have the general structure of an input-output text pair (i.e. seq2seq). When creating self-supervised datasets from the PubMed data, we are keen to empirically identify which fine-tuning task is best suited to introducing medical terminology as well as to performing error correction once fine-tuned. We created three variants of the PubMed dataset with a view to performing three different fine-tuning tasks on the base language models². In Table 4 we present an example for each fine-tuning task. The summarisation task generates a summary for a given text input. In the PubMed dataset, the abstract is considered as the input and the title is considered as the gold standard for the expected summary; no changes are required to the original PubMed dataset described in the previous section. The first row of Table 4 presents an example abstract and title pair for summarisation from the PubMed dataset. The paraphrasing task generates re-phrased text for a given text input. The goal of paraphrasing is to represent the meaning of a given text using different or re-arranged words. From the PubMed dataset, we used the titles as input to a T5 model fine-tuned for paraphrasing on the Google PAWS dataset [25] to generate paraphrased versions of the titles. The resulting dataset has title and paraphrased-title pairs, as shown in the second row of Table 4. For fine-tuning using this dataset, we use the paraphrased title as the input and the title as the reference text.
Accessible from the Hugging Face dataset repositories https://huggingface.co/gayanin.
Table 4 Examples from the PubMed gastrointestinal dataset

Task | Input | Output
Summarisation | Helicobacter pylori is a worldwide infection. It is estimated that approximately 50% of the general population is affected, but this percentage varies considerably between countries. … This study confirms relatively high prevalence of H. pylori seropositivity among Italian healthy adults and points to sex, age, BMI and sociocultural class as persisting determinant features of H. pylori infection. | Determinants of Helicobacter pylori seroprevalence among Italian blood donors
Paraphrasing | Determinants of seroprevalence of Helicobacter pylori among Italian blood donors | Determinants of Helicobacter pylori seroprevalence among Italian blood donors
Mask-filling | Determinants < mask > < mask > pylori < mask > among Italian blood donors | Determinants of Helicobacter pylori seroprevalence among Italian blood donors
Mask-filling is the task of predicting matching text for masked tokens in a text. To prepare the PubMed dataset for mask-filling, we augment the titles such that 25% of the words in each title are replaced with the token < mask >. The resulting dataset has title and masked-title pairs, as shown in the third row of Table 4. For the mask-filling fine-tuning, we use the masked title as the input and the title as the reference text.
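The masking step described above could be implemented as in the following hedged sketch; uniform random sampling of the masked positions is an assumption for illustration, as the chapter only states the 25% ratio.

```python
# Replace roughly 25% of the words in a title with "<mask>" to build a
# (masked title -> title) fine-tuning pair.
import random

def mask_title(title: str, ratio: float = 0.25, seed: int = 0) -> str:
    rng = random.Random(seed)
    words = title.split()
    n_mask = max(1, int(len(words) * ratio))
    for idx in rng.sample(range(len(words)), n_mask):
        words[idx] = "<mask>"
    return " ".join(words)

title = "Determinants of Helicobacter pylori seroprevalence among Italian blood donors"
masked = mask_title(title)
# fine-tuning pair: input = masked, target = title
```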
5 Evaluation

The aim of this evaluation is twofold. Firstly, we measure the efficacy of base language models at performing error correction on clinical dialogue in Sect. 5.2. Secondly, we identify which fine-tuning task is best for the clinical dialogue error correction task in Sect. 5.3.
5.1 Performance Metric

The metric selected to measure error correction is Word Error Rate (WER). WER is derived from the Levenshtein distance, which measures the differences between two strings. WER has been used as a performance metric for ASR systems [6] and in post-ASR error correction [13, 16]. Given a language model output and a reference text,
Table 5 Comparison of base language models

                       | WER (%)
Model | Version        | AWS Transcribe | Microsoft | IBM Watson | Google
T5    | T5-Small       | 55.41          | 54.87     | 61.74      | 64.20
T5    | T5-Base        | 214.08         | 205.56    | 209.84     | 162.63
T5    | T5-Large       | 163.96         | 163.54    | 153.64     | 137.19
BART  | BART-Base      | 38.29          | 30.95     | 42.63      | 44.47
BART  | BART-Large     | 66.95          | 55.40     | 61.50      | 55.12
WER is calculated using Eq. (1). Here S, D and I refer to the numbers of substitution, deletion, and insertion operations needed to transform the reference text into the language model output, C refers to the words that are equal in both the reference and the output, and N refers to the number of words in the reference text, with N = S + D + C. Lower WER scores are desirable.

WER = (S + D + I) / N = (S + D + I) / (S + D + C)    (1)
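For completeness, a minimal sketch of computing WER with the usual edit-distance dynamic programme is shown below; this is a generic implementation for illustration, not the authors' evaluation code.

```python
# Word Error Rate = (S + D + I) / N, computed via Levenshtein distance over words.
def wer(reference: str, hypothesis: str) -> float:
    ref, hyp = reference.split(), hypothesis.split()
    # d[i][j] = minimum edits to turn the first i reference words into the first j hypothesis words
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i          # deletions
    for j in range(len(hyp) + 1):
        d[0][j] = j          # insertions
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            sub = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,        # deletion
                          d[i][j - 1] + 1,        # insertion
                          d[i - 1][j - 1] + sub)  # substitution or match
    return d[len(ref)][len(hyp)] / max(len(ref), 1)

print(wer("have you noticed any changes in your weight",
          "do you noticed any changes in your wit"))
```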
5.2 Comparison of Base Language Models

We compare the base language models detailed in Sect. 4.1 on the task of clinical dialogue error correction. These models were pre-trained using large public-domain language corpora for multiple language modelling tasks (i.e. Wikipedia [14], Book Corpus [14] and web-extracted text [22]). However, they are not fine-tuned to the medical domain or to the error-correction task. We implement these models using the Python Hugging Face³ and PyTorch⁴ libraries while maintaining all default hyper-parameters. Models are evaluated using the four ASR outputs in the GCD dataset from Sect. 3, on which mean WER is reported as a percentage.
5.2.1 Results
Table 5 presents the WER scores for the T5 and BART model variants evaluated on each ASR output. Overall, smaller models (i.e. with fewer Transformer blocks) consistently achieve lower WER than larger models across all four datasets. For T5, both the base and large variants have a WER score of more than 100%. Models with a higher number of Transformer blocks tend to generate longer sentences.
https://huggingface.co/. https://pytorch.org/.
Accordingly, the number of insert operations (see Eq. 1) is higher when the output sentence is longer, which results in a WER score higher than 100%. Similar behaviour is observed with the two BART variants, where the larger model produces higher WER due to output length. Between T5 and BART, the winner is BART-Base, although both fail to surpass the ASR WER scores (see Fig. 1). Accordingly, we study the impact of fine-tuning only the T5-Small and BART-Base models. Table 7 presents outputs generated for a sample input using the T5 and BART base model variants. The sample input and its reference text were randomly selected from the AWS Transcribe outputs and reference texts in the GCD dataset. The outputs from the T5 and BART models with an increased number of Transformer layers have evidently generated longer outputs, which contributed to higher WER scores. Also, none of the models was able to correct the medical term present in the sample input.
5.3 Comparison of Fine-Tuned Language Models

In this section, we compare the fine-tuned language models detailed in Sect. 4.3 on the task of clinical dialogue error correction. These models are fine-tuned using the three self-supervised PubMed datasets we presented in Sect. 4.3. The hyper-parameters used in the fine-tuning were: the optimiser is AdamW [15]; the loss is cross-entropy; the learning rate is 2e-5; and the batch size is 16. For each fine-tuning task, the dataset was split 90/10 for training and evaluation. For each model the number of fine-tuning epochs varied between 10 and 40, as fine-tuning was stopped when minimal evaluation loss was observed. For the summarisation task, the encoder input and decoder output sequence lengths were set to 1024 and 128 respectively; for the paraphrasing and mask-filling tasks both the encoder input and decoder output sequence lengths were set to 512. Models are tested for error correction using the four ASR outputs on the GCD dataset from Sect. 3, and mean WER is reported as a percentage.
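The following is a hedged sketch of a single mask-filling fine-tuning step matching the stated hyper-parameters (AdamW, learning rate 2e-5, cross-entropy loss); the example pair is illustrative and the exact training loop, batching and early stopping are not shown. The `text_target` tokenizer argument assumes a reasonably recent version of the Transformers library.

```python
# One fine-tuning step for BART on a (masked title -> title) pair.
import torch
from transformers import BartForConditionalGeneration, BartTokenizerFast

tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

masked = "Determinants <mask> <mask> pylori <mask> among Italian blood donors"
title = "Determinants of Helicobacter pylori seroprevalence among Italian blood donors"

batch = tokenizer([masked], text_target=[title], max_length=512,
                  truncation=True, padding=True, return_tensors="pt")

outputs = model(**batch)      # the labels in `batch` yield a cross-entropy loss
outputs.loss.backward()
optimizer.step()
optimizer.zero_grad()
```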
5.3.1 Results
Table 6 presents the WER scores for the fine-tuned model variants on the summarisation, paraphrasing and mask-filling tasks, for each ASR system output. Overall, mask-filling is the best-performing fine-tuning task for both the BART and T5 models across all four datasets. The summarisation task resulted in the highest WER scores, which makes it an unsuitable fine-tuning task for error correction. The high WER is caused by the difference between the input and output sequence lengths used in summarisation, whereas error correction expects similar lengths for both input and output. Of paraphrasing and mask-filling, the two tasks where the input and output sequence lengths match, mask-filling achieved the best performance. In addition to generating an output of the expected size, the model has learned to correct errors when fine-tuned for mask-filling. In fact, intuitively, mask-filling is the most similar of the three fine-tuning tasks to error correction, since it teaches the model to
Table 6 Comparison of fine-tuned language models

                              | WER (%)
Model     | Fine-tune task    | AWS Transcribe | Microsoft | IBM Watson | Google
T5-Small  | Summarisation     | 63.39          | 66.89     | 69.44      | 73.80
T5-Small  | Paraphrasing      | 48.87          | 47.24     | 54.52      | 57.97
T5-Small  | Mask-filling      | 38.83          | 35.86     | 45.16      | 46.87
BART-base | Summarisation     | 76.61          | 77.03     | 78.10      | 75.56
BART-base | Paraphrasing      | 43.31          | 37.46     | 47.51      | 49.48
BART-base | Mask-filling      | 32.38          | 26.38     | 38.92      | 40.43
find missing words. Importantly, mask-filling has improved the BART models to outperform ASR performance (see Fig. 1) on the Microsoft, IBM Watson and Google datasets, and to perform comparably on the AWS Transcribe dataset. Table 7 presents outputs generated for a sample input using the fine-tuned models. The models fine-tuned for summarisation generate shorter text, which resulted in increased WER scores. In comparison, models fine-tuned for paraphrasing and mask-filling generate text that is comparable to the input in length. Moreover, we observe several improvements to the output text, such as T5-Paraphrasing replacing the word "fortnightly" with "two-weekly" and BART Mask-filling accurately improving "Andi I know" to "And I know". However, even after fine-tuning for medical terminology, the models struggle to find the phonetic similarity between the ASR output and medical terminology. For example, the models fail to correct "Adelaida map" to "Adalimumab", which will be a focus area for us to improve in future work.
6 Discussion

The aim of this exploratory evaluation is to find out the type of utterances that contributed to the performance improvements with BART fine-tuned for mask-filling. To this end, we first split each ASR output in the GCD dataset into data instances that the ASR correctly transcribed (ASR output equal to the reference text) and those it transcribed incorrectly. We then evaluate each subset of data using the best-performing models from Sect. 5.3, namely the T5 and BART models fine-tuned for mask-filling. In Table 8 we present the results; we report the number of utterances instead of WER for easy interpretation. In the ASR row we present the baselines, where the Equal row refers to the number of utterances correctly transcribed by the ASR and the Different row refers to the number of utterances incorrectly transcribed by the ASR. For example, the Google speech-to-text ASR system transcribed only one utterance correctly, and 224 utterances had some differences when compared to the reference text. There is a difference between the
Table 7 Examples from language model variants

                   | Output
Reference text     | And I know that youve been on fortnightly Adalimumab
Input              | Andi I know that youve been on fortnightly Adelaida map?
Baseline models
T5-Small           | Andi I know that youve been on the fortnightly Adelaida map
T5-Base            | Been on fortnightly Adelaida map. Andi I know that youve been tagged as andi Im sure that where youre on, but im not sure if its just me or am I right?
T5-Large           | Youve been on the fortnightly Adelaida map been on fortnightly Adelaida map. Andi I know that youve been tagged as andi Im sure thats where youre on, but im not sure if its just me or am I right?
BART-Base          | Andi I know that youve been on fortnightly Adelaida maps
BART-Large         | AndAll of the time.Andi I know that youve been on fortnightly
Fine-tuned models
T5-Summarisation   | I know that youve been on fortnightly Adelaida map
T5-Paraphrase      | Andi I know that youve been on the two-weekly Adelaida map
T5-Mask-filling    | Andi I know that youve been on fortnightly Adelaida map
BART-Summarisation | Adelaida map on fortnightly basis
BART-Paraphrase    | I know youve been on the fortnightly Adelaida map
BART-Mask-filling  | And I know that youve been on fortnightly Adelaida map
total number of audio utterances, 329 (from Table 1), and the number of Google ASR outputs. For a given audio file, the ASR system generates a different number of utterances compared to the reference text. This is because the ASR system skips some audio content that it cannot transcribe with a sufficient level of confidence. Evidently, the ASR systems with higher error have a smaller total number of utterances in Table 8. At first glance, we find that the number of equal utterances is reduced with the T5 and BART models and the number of different utterances is increased. But a closer examination reveals that the performance improvements we observed in Sect. 5.3 over the ASR systems are due to making fewer errors within the utterances that are different.
Table 8 Comparison of ASR and language model outputs for error correction

                      |           | Number of instances
Model                 | Data      | AWS Transcribe | Microsoft | IBM Watson | Google
ASR                   | Equal     | 30             | 27        | 7          | 1
                      | Different | 269            | 273       | 277        | 224
T5 for Mask-filling   | Equal     | 7              | 6         | 2          | 1
                      | Different | 292            | 294       | 282        | 224
BART for Mask-filling | Equal     | 1              | 1         | 0          | 1
                      | Different | 298            | 299       | 284        | 224
This is clearly seen when comparing the number of utterances for Google, where the numbers are no different from the ASR's, yet performance improvements of 9.35% (49.78 - 40.43) and 2.91% (49.78 - 46.87) were achieved with the BART and T5 models respectively. However, we find that our models introduce errors into utterances that the ASR systems have correctly transcribed. A methodology to mitigate this will be explored in future work.
7 Conclusion

In this paper, we presented a seq2seq learning approach for clinical dialogue transcription error correction. Given the lack of clinical dialogue data, we presented an approach which uses public medical-domain data to fine-tune a language model to introduce domain-specific clinical terms. We found that PubMed data from a specific domain can be used in a self-supervised manner to create data to fine-tune a general-purpose seq2seq model. Importantly, our results suggest that the choice of fine-tuning task has a significant impact on the post-ASR error correction task. Specifically, we found that mask-filling was more closely aligned with the target transcription error correction task than alternative fine-tuning tasks such as summarisation or paraphrasing. With this method, we were able to surpass the performance of three out of four commercial ASR systems in a comparative study with fine-tuned T5 and BART seq2seq models. In future work, we plan to introduce new fine-tuning tasks with more self-supervised data to improve model knowledge of phonetic relationships between medical and regular phrases and improve error correction performance.

Acknowledgements We would like to thank the National Health Service (NHS) Grampian, including Amy Kousha, John Thomson and Jill Ferbrache, who helped to curate the IBD clinic role-playing dialogues.
References 1. Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv:1409.0473. 2. Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H., et al. (2014). Learning phrase representations using rnn encoder-decoder for statistical machine translation. arXiv:1406.1078. 3. Cucu, H., Buzo, A., Besacier, L., & Burileanu, C. (2013). Statistical error correction methods for domain-specific ASR systems. In A. H. Dediu, C. Martín-Vide, R. Mitkov, & B. Truthe (Eds.), Statistical language and speech processing (pp. 83–92). Berlin, Heidelberg: Springer. 4. Dernoncourt, F., & Lee, J. Y. (2017). Pubmed 200k rct: A dataset for sequential sentence classification in medical abstracts. arXiv:1710.06071. 5. Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2019). Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv:1810.04805. 6. Errattahi, R., El Hannani, A., & Ouahmane, H. (2018) Automatic speech recognition errors detection and correction: A review. Procedia Computer Science 128, 32–37. 7. Filippidou, F., & Moussiades, L. (2020). A benchmarking of IBM, google and wit automatic speech recognition systems. In I. Maglogiannis, L. Iliadis, & E. Pimenidis (Eds.), Artificial intelligence applications and innovations (pp. 73–82). Cham: Springer International Publishing. 8. Humphries, J. J., Woodland, P. C., & Pearce, D. J. B. (1996). Using accent-specific pronunciation modelling for robust speech recognition. In Proceeding of Fourth International Conference on Spoken Language Processing. ICSLP ’96 (Vol. 4, pp. 2324–2327). 9. Jain, A., Upreti, M., & Jyothi, P. (2018). Improved accented speech recognition using accent embeddings and multi-task learning. In INTERSPEECH. 10. Jin, Q., Dhingra, B., Liu, Z., Cohen, W., & Lu, X. (2019). Pubmedqa: A dataset for biomedical research question answering. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP) (pp. 2567–2577). 11. Kamper, H., & Niesler, T. (2011). Multi-accent speech recognition of Afrikaans, black and white varieties of South African English. In Proceedings of the Annual Conference of the International Speech Communication Association, INTERSPEECH (pp. 3189–3192). 12. Lee, J., Yoon, W., Kim, S., Kim, D., Kim, S., So, C. H., & Kang, J. (2020). Biobert: A pretrained biomedical language representation model for biomedical text mining. Bioinformatics, 36(4), 1234–1240. 13. Leng, Y., Tan, X., Zhu, L., Xu, J., Luo, R., Liu, L., et al. (2021). Fastcorrect: Fast error correction with edit alignment for automatic speech recognition. arXiv:2105.03842. 14. Lewis, M., Liu, Y., Goyal, N., Ghazvininejad, M., Mohamed, A., Levy, O. et al. (2019). Bart: Denoising sequence-to-sequence pre-training for natural language generation, translation, and comprehension. arXiv:1910.13461. 15. Loshchilov, I., & Hutter, F. (2017). Decoupled weight decay regularization. 16. Mani, A., Palaskar, S., & Konam, S. (2020) Towards understanding ASR error correction for medical conversations. In NLPMC. 17. Mani, A., Palaskar, S., Meripo, N. V., Konam, S., & Metze, F. (2020) ASR error correction and domain adaptation using machine translation. In ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (pp. 6344–6348). IEEE. 18. McDonald, A., & Sherlock, J. (2016). 
A long and winding road - improving communication with patients in the NHS. 19. Quiroz, J., Laranjo, L., Kocaballi, A. B., Berkovsky, S., Rezazadegan, D., & Coiera, E. (2019). Challenges of developing a digital scribe to reduce clinical documentation burden. npj Digital Medicine 2, 114. https://doi.org/10.1038/s41746-019-0190-1 20. Radford, A., & Narasimhan, K. (2018). Improving language understanding by generative pretraining.
21. Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. 22. Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M. et al. (2020). Exploring the limits of transfer learning with a unified text-to-text transformer. arXiv:1910.10683. 23. Sarma, A., & Palmer, D. D. (2004). Context-based speech recognition error detection and correction. In NAACL. 24. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N. et al. (2017). Attention is all you need. In Proceedings of the 31st International Conference on Neural Information Processing Systems NIPS’17 (pp. 6000-6010). 25. Zhang, Y., Baldridge, J., & He, L. (2019). Paws: Paraphrase adversaries from word scrambling. arXiv:1904.01130.
Customized Training of Pretrained Language Models to Detect Post Intents in Online Health Support Groups Tootiya Giyahchi, Sameer Singh, Ian Harris, and Cornelia Pechmann
Abstract Online support groups offer low-cost and accessible health and mental health support, but low engagement challenges their effectiveness. To encourage engagement and provide evidence-based responses appropriate to participants' needs, we propose an intent detection model for online support groups based on state-of-the-art natural language processing methods. Such a model enables a chatbot that can increase interactions and improve discussion quality. Posts in social media are often short and noisy, especially in group chat. Furthermore, many intents lack data, overlap and/or have specific priorities. We create a human-annotated dataset of posts with intent labels from 45 three-month online support groups for quitting smoking. We then train and examine models to predict the intent behind each post. To reduce the effect of noisy and sparse data, we fine-tune a massive pretrained language model. Also, to represent the unique relationships between intents, we design customized loss functions for training. Empirical evaluations show significant performance improvements with the proposed method; our best model obtains 95.5% accuracy. We also use a fine-grained set of intents and obtain higher accuracy compared to prior models on online health forums and communities. Accurate detection of fine-grained intents opens up new opportunities to improve online self-help support groups. Keywords Text classification · Online support groups · Chatbot · Smoking cessation
T. Giyahchi (B) · S. Singh · I. Harris · C. Pechmann University of California, Irvine, CA, USA e-mail: [email protected] S. Singh e-mail: [email protected] I. Harris e-mail: [email protected] C. Pechmann e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_5
1 Introduction Social media-based health interventions are increasing in the medical field to provide low-cost and accessible support at scale. However, a significant concern about such support groups is low engagement, i.e., infrequent interactions [1]. Many studies have shown a correlation between engagement and better outcomes in online health support groups [2–4]. Research [5] also suggests that if participants in an online group do not receive responses to their posts promptly enough, they are likely to drop out. Using novel methods or tools to increase interactivity in such groups can positively affect user engagement and outcomes [3, 4, 6]. An accurate intent detection model facilitates solutions to increase engagement and interactivity in the group while requiring minimal human effort. For example, an effective intent detection model enables a chatbot to actively respond to posts with relevant content to address participants' concerns and provide evidence-based health information and support. Previous intent detection models in the health domain have either focused on 1:1 conversations or a limited set of intents and have generally concentrated only on the original thread of posts. However, text classification for online support groups requires detection of the main topic discussed by the group at each point in time. Furthermore, the discussion topics and desired categories for detection are particular to the group's purpose; hence, there is often a lack of data for the application-specific task. Further, the available datasets are often extremely imbalanced, with many important labels being rare. Additionally, online group chats are extra noisy due to parallel conversations and talk-turn interruptions. Lastly, there are application-specific relationships between the labels: labels may be symmetrically or asymmetrically related, and instances of them may be ambiguous. This paper introduces customized training of pretrained language models for automated prediction of user intents in online support groups. We create an annotated dataset of 45 online support groups for quitting smoking from the Tweet2Quit study [7] as our empirical task. The dataset consists of 82,000 posts labeled with 24 expert-identified intent codes. We experiment with fine-tuning a BERT language model to address the noisy dataset's problems and reduce the need for structured and high-incidence labeled data for the domain-specific task. To improve the accuracy of detecting important infrequent intents, we balance class weights for training. Finally, we take the unique relationships between the labels into account and propose an adaptive modeling approach by designing customized loss functions and adjusting the evaluation metrics. Overall, we show that these techniques provide considerable improvements in recognition of specific intents in an online health support group with an accuracy of 95.5%.
2 Background and Related Work Chatbots in Health Using chatbots in health domains is increasing because they can provide online assistance through interactive conversations at users' convenience and for a very low cost. Many chatbots have been proposed for affordable and scalable promotion of well-being or helping users cope with mental or physical problems [8, 9]. Understanding the user's intent is essential for such chatbots to reply with helpful content. Studies [10] suggest that using a chatbot that can accurately predict users' mood and respond appropriately improves engagement. Researchers [11, 12] also report that chatbots with greater accuracy are more effective at comforting users and improving their mood. However, all the mentioned chatbots interact 1:1 with users. They also give users several fixed options to choose from, and reply based on users' preformatted responses. It is rare for them to detect a user's mood from sensor data [10]. For example, Woebot [9] is a self-help chatbot for mental health that relies mainly on pre-written questions and answers. It only uses a natural language algorithm to interpret freely written text in limited contexts: to detect self-harm and crisis language. It then asks the user for confirmation and, if the user confirms, it offers resources. For chatbots to expand into online self-help support groups, they must be able to detect intents based on freely written text. They cannot rely on pre-set questions and answers to identify intents since this would be unnatural and disrupt the ongoing and dynamic group discussions that are the hallmark of online communities. Moreover, most of the important intents are specific to the health topic of the group and infrequent in other corpora, and so using only pretrained models is not an option [13]. Classification of Medical/Mental Health Conversations Intent detection of posts or conversations in medical and mental health contexts has been a topic of emerging interest in recent years. One important application has been to identify intents in health forums. Zhang et al. [14] try to understand 3 user intents in an online health forum to help users find useful information within the unstructured datasets. Using a support vector machine (SVM) classifier, their best model achieves 52.47 F1. McRoy et al. [15] study online health forums for breast cancer patients and survivors to detect posts expressing information needs that could be used to improve forum resources and materials. They develop Naïve Bayes, SVM, and Random Forest classifiers to detect expressed information need in 8 categories and their best model obtains 63 F1. Huh et al. [16] use a Naïve Bayes classifier in an online health community (WebMD) to detect posts that require an expert moderator's intervention. Their proposed model detects 4 intents for classification with the best model performing at 54 F1. Another emerging health application is to detect intents in face-to-face medical sessions between patient and provider. Park et al. [17] experiment with different classifiers to detect 27 patient-provider conversation topics in primary care office talk-turns to more efficiently understand the patient's most significant complaint among all those expressed. Their best model attains 61% accuracy. Xiao et al. [18]
classify patients’ utterances in psychotherapy sessions based on domain-specific behavioral codes (8 therapist codes and 3 client codes) using a Recurrent Neural Network (RNN) model to provide guidance for the therapists, and they achieve 75% accuracy in therapist code prediction. Intent detection in online support communities and groups involves unique challenges as the posts often consist of various group- or subgroup-level discussions that transpire concurrently but may be on different topics, leading to exceptionally unstructured and noisy data. Intent codes may overlap even within a single post. Moreover, domain-specific data may be sparse, and the distribution of important intents may be especially imbalanced with some rare intents being very important to predict accurately such as self-harm or crisis events. Multi Party Dialogues While virtually all prior work has focused on detecting intents in one-on-one conversations or question-answer discussion forums, our context is a group chat setup where approximately 20 peers talk to each other as part of an online self-help group. Intent classification in such groups is crucial to design a helpful chatbot but it involves numerous challenges. Intent detection must not be focused on one person but rather on the group or subgroup discussions. Also, in most cases, the chatbot should only reply to posts that align with the group’s health or mental health goals, and not to inappropriate tangential posts that could be detrimental to group member satisfaction and outcomes. The chatbot cannot prompt with clarifying questions to understand the intents, as any irrelevant input could disturb the group’s ongoing conversations and damage engagement. Moreover, the model must be able to accurately identify the most dominant and relevant intent being discussed in the group at any one time, so the chatbot contributes to the ongoing conversation meaningfully and appropriately, without being disruptive or derailing the dialogue which would be counterproductive. Currently, there are very few chatbots designed to engage in an online group discussion. Savage et al. [19] and Kim et al. [20] both use a chatbot to improve collaboration among group members, but their goal is to ensure equal participation by members, which is just one of the many goals we have for improving engagement. Seering et al. [21] report improvement in an online community’s engagement using a chatbot but they examine a gaming community, not a self-help group for health. Seering et al. [22] suggest chatbots have great potential to serve online groups and communities and can help to address challenges regarding maintaining and moderating group members and they urge more research in the area.
3 Tweet2Quit Dataset To explore the problem of modeling a post intent predictor in an online health support group, we use data from Tweet2Quit [3]. Tweet2Quit is a social media intervention for quitting smoking, where in addition to receiving free nicotine replacement therapy (NRT), smokers seeking to quit are assigned to a private 20-person Twitter-based
support group to interact with peers and exchange information. A previous study on Tweet2Quit [4] has shown that engagement is a challenge because it is often low, but the more a participant posts, the more likely they are to maintain abstinence from smoking (p < 0.001). Other main challenges with smoking cessation support groups are participants' hesitancy to use NRT and their other struggles with medical regimen compliance [23]. While sending daily auto-messages with relevant questions to encourage peer-to-peer discussions has been effective in Tweet2Quit [3], the abstinence rate is still below 40% [4]. We believe that an automated intent detection system can be used to design a chatbot that can create a better interactive environment and encourage medical regimen compliance by contributing to the group discussion based on the immediately preceding post intent to improve engagement and, ultimately, successful smoking abstinence. The next section explains data collection and annotation and discusses how we identified the intent labels.
3.1 Data Collection We collected a dataset of posts from 45 groups in two prior Tweet2Quit studies. Eight groups came from the first study conducted from 2012–2014 [3] and 36 groups plus one pilot came from the second study conducted from 2016-2019. Each group ran for a three month period with 20 members, and the mean number of posts per group was 1822. Overall we collected more than 82000 posts by Tweet2Quit participants.
3.2 Identification of the Intents to Annotate Tweet2Quit researchers identified intents that reflect desirable or useful posts, and arranged for these intents to be annotated, so that a chatbot will be able to accurately detect and respond to such intents. The ultimate goals are to increase the number of desirable posts, enhance engagement and improve abstinence. Intents are considered desirable for the online groups if they are relevant to quitting smoking or important to the proper functioning of the support groups, and they have meaningful frequency. First, Tweet2Quit researchers identified a set of important intents related to clinical practice guidelines about the medical regimen for quitting smoking [24]. Medical regimen intents included use of Nicotine Replacement Therapy (NRT), covering both its efficacy and side effects; setting a quit date; and e-cigarettes. Studies on Tweet2Quit [3] affirm that posts about medical regimen compliance, e.g., setting a quit date or use of nicotine patches, relate significantly to smoking abstinence. Facts about quitting smoking were also included as intents, including health benefits, money saved, weight gain and second-hand cigarette smell, because clinical practice guidelines recommend that doctors discuss these topics with smokers.
Table 1 Tweet2Quit Intent Labels by Category

Category                    | Intent labels
Medical regimen             | nrt_dontwork, nrt_dreams, nrt_howtouse, nrt_itworks, nrt_misuseissues, nrt_od, nrt_skinirritation, nrt_stickissue, quitdate, ecigs
Empathy for negative events | fail, scared, stress, tiredness, cravings
Empathy for positive events | smokefree, smokingless, support
Facts                       | cigsmell, savingmoney, health, weightgain
Greetings                   | greetings
Nonresponse                 | nonresponse
In addition, intents were identified that relate to online community building, emotional bonding and self-disclosure among the group members based on the literature on online community functioning [6]. Community-building intents included greetings to group members, and empathy for both positive and negative events related to members' quit-smoking attempts (e.g., successes and failures) which were self-disclosed. Tweet2Quit research has shown that bonding through self-disclosure increases the strength of social ties within the support group which significantly enhances smoking abstinence [25]. The final intent category included all posts that were tangential to the support group goal of quitting smoking, and were labeled nonresponse because the chatbot should not respond to such posts. Our recent analysis shows that, as anticipated, such posts are statistically unrelated to smoking abstinence. Overall we identified 24 intent codes that were annotated as shown in Table 1 with their corresponding categories. Twenty-three of the intents are designed to be triggering intents that will cause the chatbot to contribute to the group chat when recognized, while one ("nonresponse") represents all other intents that are designed to be non-triggering, meaning the chatbot will not respond. Table 2 provides descriptions and examples of some of the more important intent labels.
3.3 Reliability The overall reliability of the annotation process was determined based on Cohen’s Kappa measure as 0.93, which is considered high agreement between the annotators [26]. The reliability of each label was calculated considering the number of times two research assistants agreed on the intent compared to the total number of posts with that intent based on the final annotation. Per label reliability scores ranged from 87.3% for “nrt_itworks” to 97.2% for “nonresponse”.
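The reliability figures above can be computed directly from the adjudicated annotations. The sketch below is an illustration only; the column names ("coder1", "coder2", "final") are assumed, not the study's actual data schema.

```python
# Sketch: overall Cohen's kappa between the two primary annotators and per-label
# agreement rates, following the description above. Column names are hypothetical.
import pandas as pd
from sklearn.metrics import cohen_kappa_score

def annotation_reliability(df: pd.DataFrame):
    """df columns: 'coder1', 'coder2' (independent annotations), 'final' (adjudicated intent)."""
    kappa = cohen_kappa_score(df["coder1"], df["coder2"])    # chance-corrected overall agreement
    agreed = df["coder1"] == df["coder2"]
    per_label = agreed.groupby(df["final"]).mean()           # share of agreed posts per final intent
    return kappa, per_label
```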
Table 2 Description and examples of important post intents for Tweet2Quit online support groups

Intent code      | Description | Examples
nrt_howtouse     | Asks question or gives instructions about how to use NRT products | "How often should I use the lozenges?" "Chew it for a little bit and put between cheek and gum"
nrt_misuseissues | States NRT gum/lozenge has bad taste, irritates throat, causes sense of burning or spicy, makes nauseous or gag | "The gum has a strong nicotine taste." "The gum burns my mouth. It hurts my throat. I want to throw up"
nrt_od           | States NRT is too strong and causes overdose | "The 21 mg patches are making me sweat and feel bad. I am doing better with the 14 MG"
nrt_itworks      | States NRT works | "The patches work if you're determined." "Lozenges are good. I still use the gum for cravings"
tiredness        | States feeling tired | "I feel like I am about to fall asleep. I am out of energy"
smokefree        | States success in being smoke free | "It has been 13 days since my last smoke. I've had some severe cravings today but worked through them"
3.4 Annotation Process The annotation of intents took place from October 2019 through February 2021 with 25 research assistants working on over 82,000 posts. The majority of the research assistants were undergraduates, with two being high school seniors. The students' majors included Public Health, Nursing, Biology, Biomedical Engineering, Urban Studies, Cognitive Science, Education, Sociology, Business Administration, Economics and Informatics, and the students were from the United States primarily but also Brazil, Germany and China. All posts were annotated. During the training of the research assistants, the project manager discussed each intent in detail and provided formal definitions, key words and examples. After this initial training, the research assistants were given a practice set of posts and were required to achieve an 80% or higher accuracy in terms of selecting the correct intent. For continued learning and training, a database of posts for each intent was maintained, shared and referred to as needed. The training meetings were conducted in person in the beginning, then transitioned to Microsoft Teams once the team expanded to include international researchers and due to COVID-19. Meetings were held weekly until formal training was completed and then became biweekly. At later meetings, posts that were difficult to annotate were discussed by the team to determine the most suitable codes to use and why. During
Fig. 1 The distribution of intent labels in the combined training and validation dataset grouped by domain. 56.2% fell into the “nonresponse” (support group irrelevant) class
the post-training or annotation phase, each post was reviewed by two trained research assistants working independently to determine what intent fit the post best, including whether the post fell into the “nonresponse” (irrelevant) category. If the two reviewers disagreed, a third more highly expert research assistant was brought in to review the post. This third reviewer worked independently as well, i.e., without seeing the intents assigned by the first two. Whenever two research assistants agreed on the intent, that was the final intent that was annotated; otherwise the third research assistant who had the greatest domain-specific expertise determined the final annotated intent. The annotation process resulted in an extremely unbalanced dataset with more than 56% of the posts labeled as “nonresponse”, i.e., support group irrelevant. Figure 1 shows the distribution of intents in the training and validation dataset aggregated by their categories.
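The adjudication rule just described can be summarized in a few lines; this is an illustrative sketch, not the project's tooling.

```python
# Sketch (assumption): final label assignment for a single post under the protocol above.
from typing import Optional

def final_intent(coder1: str, coder2: str, expert: Optional[str] = None) -> str:
    """Two independent annotations; a third, more expert judgment is used only on disagreement."""
    if coder1 == coder2:
        return coder1
    if expert is None:
        raise ValueError("Annotators disagree: a third expert annotation is required.")
    return expert
```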
4 Models We use the annotated dataset to solve a supervised text classification problem to develop an intent detection model for our online support groups. We explore and compare two different NLP model strategies: (1) Bag-of-Words: Tf-Idf (Term Frequency-Inverse Document Frequency) vectorization with a Random Forest model; (2) Language Models: transfer learning with a pretrained BERT model.
4.1 Random Forest (Baseline) Random Forest (RF) [27] is one of the best performing classifiers for text classification and is specifically suitable for noisy and high-dimensional data such as social media posts [28, 29]. RF creates a set of decision trees that each decides for a random collection of features. We evaluate the performance of an RF model for our text classification task as our baseline.
4.2 Pretrained Language Models In recent years, pretrained large neural language models [30–32] have demonstrated state-of-the-art performance for all types of downstream tasks. Being trained on a massive unlabeled corpus of text, these models offer powerful universal language representations that can significantly improve the results with proper fine-tuning on a target task. A major advantage of transfer learning with pretrained large neural language models is their ability to tackle dataset sparsity; more than 56% of our labeled dataset consists of "nonresponse" posts unimportant for the intervention. Among pretrained language models, BERT (Bidirectional Encoder Representations from Transformers) [30] has obtained some of the best results on popular NLP benchmarks such as GLUE [33]. BERT is based on a deep bidirectional transformer and is trained for masked word prediction (MLM) and next sentence prediction. Thus BERT generates token-level and sentence-level representations for both left and right contexts. We investigate the effect of fine-tuning BERT [30] on task performance and study if it can address issues regarding noise and sparsity in the dataset.
5 Adapting to the Labels’ Relationships Although we first experiment assuming our task is a standard multi-class classification, intents are usually not completely independent in an online support group. Also, different mispredictions have different levels of impact on the group, depending on actual and predicted classes. For example, suppose the model mispredicts a “fail” labeled post (stating a failure to quit smoking) as “smokefree” (meaning success in quitting), and the chatbot responds wrongly with praise. Not only would that be an irrelevant and disruptive message, but it could also have a counter-productive effect on the group by conveying that smoking is ok. But if the model mispredicts a “nrt_skinirritation” post (comment that the NRT patch causes skin irritation) as “nrt_howtouse”, the chatbot’s response would not be perfect but still would be helpful and relevant to the topic to some degree. These relationships may be asymmetric, i.e., misprediction of intent X as intent Y may be tolerable, but misprediction of intent Y as intent X may be unacceptable and harmful to the support group.
Fig. 2 Acceptable mispredictions: For each true label (y-axis), its acceptable predictions are marked (x-axis). The asymmetry in the labels’ relationships implies that the classes should not be combined for training or evaluation
To understand the special relationships between the intents in our dataset, we asked domain-expert Tweet2Quit researchers to answer if every possible misprediction would be acceptable or not. Figure 2 shows the final tolerable mispredictions and demonstrates the asymmetries in the intents’ relationships. For instance, when users post seeking empathy for a positive or negative experience, a general empathetic supportive response should be acceptable if the model cannot recall the specific label with high confidence. In contrast, when users seek help with using NRT, the bot should respond with suitable NRT use guidance.
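One straightforward way to encode these expert judgments is as a binary acceptability matrix; the label names and the single off-diagonal entry below are illustrative assumptions, not the full matrix of Fig. 2.

```python
# Sketch: Z[y, c] = 1 when predicting class c for a post whose true label is y is tolerable.
import torch

labels = ["nrt_howtouse", "nrt_skinirritation", "fail", "smokefree", "nonresponse"]  # assumed subset
idx = {name: i for i, name in enumerate(labels)}

Z = torch.eye(len(labels))                               # the true label is always acceptable
Z[idx["nrt_skinirritation"], idx["nrt_howtouse"]] = 1.0  # tolerable misprediction (example)
# Note the asymmetry: Z[idx["nrt_howtouse"], idx["nrt_skinirritation"]] stays 0 (unacceptable).
```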
5.1 Customized Loss Functions To represent the unique relationships between the labels in our model and adjust the training process based on that, we define multiple customized Negative Log
Likelihood (NLL) loss functions and use them for fine-tuning the pretrained model. A standard NLL loss for a neural network is formulated as:

l_n = -w_{y_n} x_{n,y_n}, \qquad \ell(x, y) = \frac{\sum_{n=1}^{N} l_n}{\sum_{n=1}^{N} w_{y_n}}    (1)
Where N is the batch size, x is the predicted vector, y is the target vector, and w is the class weight. To adjust our loss function based on the label's relationship, for each label y, we consider a vector z of size C that represents all the acceptable labels for y. Then we use z to mask out all y-acceptable indices of x and sum over the rest to calculate the loss. We experiment with different versions of adjusted loss functions, as shown here:

• Customized Non-balanced Loss:

v_n = z_{y_n} \circ x_{n,y_n}, \qquad l_n = -\sum_{i=1}^{C} v_{n_i}, \qquad \ell(x, y) = \frac{\sum_{n=1}^{N} l_n}{N}    (2)
While this loss function focuses on non-acceptable mispredictions to penalize the model during training, it doesn't consider the class weights. We introduce the following loss that uses the original true label's class weight for balancing the loss function.

• Customized Weighted Loss:

v_n = z_{y_n} \circ x_{n,y_n}, \qquad l_n = -\sum_{i=1}^{C} (v_{n_i}) \, w_{y_n}, \qquad \ell(x, y) = \frac{\sum_{n=1}^{N} l_n}{\sum_{n=1}^{N} w_{y_n}}    (3)
Since we mask out the tolerable predictions to calculate the loss, using only the true label's class weight for balancing may not be the best option. For the following loss function, we introduce u_n that aggregates (uses the mean of) the weights of all acceptable classes for prediction to compute a balanced loss.

• Customized Balanced Loss:

W = \{w_{y_1}, \ldots, w_{y_C}\}, \qquad t_n = z_{y_n} \circ W, \qquad u_n = \frac{\sum_{i=1}^{C} t_{n_i}}{|t_n|}, \qquad l_n = -\sum_{i=1}^{C} (v_{n_i}) \, u_n, \qquad \ell(x, y) = \frac{\sum_{n=1}^{N} l_n}{\sum_{n=1}^{N} u_n}    (4)
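The sketch below gives one plausible PyTorch reading of Eqs. (2)-(4), assuming x holds per-class log-probabilities, y the true labels, Z the binary acceptability matrix of Fig. 2, and w the class-weight vector; it is an illustration, not the authors' implementation.

```python
# Minimal sketch of the customized losses under the stated assumptions.
# x: (N, C) log-probabilities, y: (N,) true labels, Z: (C, C) binary matrix, w: (C,) weights.
import torch

def custom_nonbalanced_loss(x, y, Z):
    v = Z[y] * x                                   # keep log-probs of acceptable classes (Eq. 2)
    return (-v.sum(dim=1)).mean()

def custom_weighted_loss(x, y, Z, w):
    l = -(Z[y] * x).sum(dim=1) * w[y]              # weight by the true label's class weight (Eq. 3)
    return l.sum() / w[y].sum()

def custom_balanced_loss(x, y, Z, w):
    u = (Z[y] * w).sum(dim=1) / Z[y].sum(dim=1)    # mean weight over acceptable classes (Eq. 4)
    l = -(Z[y] * x).sum(dim=1) * u
    return l.sum() / u.sum()
```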
6 Experiments In this section, we present experiments to evaluate the proposed methods in Sects. 4 and 5. We compare the results of different models and discuss our observations.
6.1 Experimental Setup From the 45 labeled groups, the seven chronologically latest groups are set aside for testing, and the remaining 38 groups' posts are randomly split into training and validation (development) sets using a 75–25% ratio, respectively. We use stratified sampling to make sure the split is inclusive for the imbalanced dataset. We perform training using the training dataset, find the best version of the model using the validation set, and then report the model's performance using the test dataset. Random Forest (Baseline) To train an RF model, we first use common preprocessing techniques to clean the text data before extracting features for vectorization; we eliminate uninformative noisy data from the text such as mentions and links, convert the contracted form of verbs, and decode emojis. We also use Tf-Idf to vectorize text data for 1- to 3-grams, and remove English stop word tokens along with setting a threshold to identify and remove corpus-specific stop words. Table 3 contains the performance results after training an RF classifier on our training dataset. As a result of our imbalanced dataset, aggregated recall and F1-score are very low for our baseline model, and the model does not detect many important labels properly. High precision scores and poor recall and F1 scores for the infrequent intents indicate the model's poor recognition of the labels. The RF model does not recognize labels like "nrt_dontwork", "nrt_howtouse" and "smokingless".
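The bag-of-words baseline can be sketched with scikit-learn as below; the particular hyperparameters (number of trees, the max_df threshold standing in for the corpus-specific stop word filter) are illustrative assumptions, not the study's exact configuration.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline

rf_baseline = make_pipeline(
    TfidfVectorizer(ngram_range=(1, 3),     # 1- to 3-grams
                    stop_words="english",   # drop English stop words
                    max_df=0.5),            # assumed threshold for corpus-specific stop words
    RandomForestClassifier(n_estimators=300, random_state=0),
)
# rf_baseline.fit(train_texts, train_labels); rf_baseline.predict(test_texts)
```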
6.2 Pretrained Language Models To fine-tune a BERT model on our training dataset, we use the pretrained parameters as our starting point and fine-tune all parameters while appending a dense layer and a softmax layer specific to our task. As our dataset is imbalanced, we calculate a compensating balanced set of class weights to improve the model’s prediction for the rare labels. We fine-tune the model for 15 epochs and evaluate the model after each training epoch on the validation dataset to pick the best performing model. In this stage, we use validation accuracy (weighted average recall) to choose the best epoch. Table 3 compares the results for training the RF and the BERT. As we expect, using pretrained language models causes significant improvement in every aggregated metric compared to the RF model. In addition, performance scores of the infrequent labels dramatically improve compared to RF. RF scores are better solely for “nonresponse” recall and for precision with some less frequent labels, but overall show low recall and F1. These results indicate the ability of pretrained BERT to address dataset sparsity.
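A condensed sketch of this fine-tuning setup using Hugging Face Transformers follows; the checkpoint name, learning rate, and data handling are assumptions, and train_labels/train_loader are placeholders for the tokenized Tweet2Quit data.

```python
import numpy as np
import torch
from sklearn.utils.class_weight import compute_class_weight
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=24)
weights = torch.tensor(
    compute_class_weight("balanced", classes=np.arange(24), y=train_labels), dtype=torch.float
)
loss_fn = torch.nn.CrossEntropyLoss(weight=weights)      # compensating balanced class weights
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for epoch in range(15):                                   # validate after each epoch, keep the best
    for batch in train_loader:                            # assumed to yield tokenized batches
        logits = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
        loss = loss_fn(logits, batch["labels"])
        loss.backward(); optimizer.step(); optimizer.zero_grad()
```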
Table 3 BERT versus RF performance with the original metrics

Intent label       | Precision (RF / BERT) | Recall (RF / BERT) | F1 (RF / BERT) | Support
nrt_dontwork       | 0.0 / 56.4   | 0.0 / 43.7   | 0.0 / 49.2   | 71
nrt_dreams         | 82.6 / 57.8  | 35.8 / 69.8  | 50.0 / 63.2  | 53
nrt_howtouse       | 0.0 / 46.3   | 0.0 / 42.3   | 0.0 / 44.2   | 104
nrt_itworks        | 50.0 / 56.0  | 1.5 / 59.8   | 2.9 / 57.9   | 132
nrt_misuseissues   | 50.0 / 53.1  | 2.4 / 63.4   | 4.7 / 57.8   | 41
nrt_od             | 0.0 / 30.4   | 0.0 / 70.0   | 0.0 / 42.4   | 10
nrt_skinirritation | 100.0 / 63.9 | 2.9 / 65.7   | 5.6 / 64.8   | 35
nrt_stickissue     | 100.0 / 84.2 | 2.4 / 75.3   | 4.6 / 79.5   | 85
quitdate           | 76.1 / 80.9  | 54.7 / 75.3  | 63.6 / 78.0  | 320
ecigs              | 100.0 / 78.2 | 17.8 / 93.2  | 30.2 / 85.0  | 73
fail               | 83.3 / 69.1  | 3.3 / 61.4   | 6.3 / 65.1   | 153
scared             | 100.0 / 77.6 | 7.2 / 62.7   | 13.5 / 69.3  | 83
stress             | 87.5 / 85.1  | 23.6 / 77.0  | 37.2 / 80.9  | 148
tiredness          | 71.4 / 56.3  | 11.6 / 62.8  | 20.0 / 59.3  | 43
cravings           | 55.1 / 53.0  | 16.0 / 65.8  | 24.8 / 58.7  | 269
smokefree          | 66.4 / 74.2  | 35.4 / 77.2  | 46.2 / 75.7  | 821
smokingless        | 0.0 / 56.5   | 0.0 / 78.8   | 0.0 / 65.8   | 33
support            | 85.6 / 78.1  | 53.9 / 78.3  | 66.1 / 78.2  | 2167
cigsmell           | 65.8 / 70.2  | 32.9 / 86.8  | 43.9 / 77.6  | 76
savingmoney        | 64.7 / 55.0  | 24.4 / 78.9  | 35.5 / 64.8  | 90
health             | 74.1 / 54.5  | 15.0 / 77.2  | 24.9 / 63.9  | 267
weightgain         | 70.0 / 56.7  | 29.5 / 74.1  | 41.5 / 64.2  | 166
greetings          | 82.3 / 85.2  | 69.9 / 78.7  | 75.6 / 81.8  | 634
nonresponse        | 76.4 / 91.2  | 97.6 / 89.1  | 85.7 / 90.1  | 10518
Macro avg          | 64.2 / 65.4  | 22.4 / 71.1  | 28.5 / 67.4  | 16392
Weighted avg       | 76 / 84.7    | 77 / 84      | 72.7 / 84.3  | 16392
6.3 Loss Functions and Adjusted Metrics The conventional performance metrics for multi-class classification assume all classes are completely independent and treat all mispredictions the same. However, as we explained earlier in Sect. 5, this is not the case in our problem. Here we describe how we adjust our metrics and evaluate customized loss functions which consider the special relationships between classes in our application. To adjust the evaluation metrics we move the acceptable mispredictions from the false negative and false positive counts to the corresponding true positive counts. True negative counts of a label continue to indicate that the misprediction is unacceptable. To examine the proposed customized loss functions, we re-evaluate the
performance of our RF and BERT (original loss) models with the new adjusted metrics for comparison. In addition, besides adjusted accuracy, we determine the macro average F1-score as the conclusive metric to pick the best model during the validation phases. We pick the macro average since the majority class (nonresponse) is least important in our problem, and we do not want it to have more weight in our calculation. For customized training of BERT with the non-balanced (Eq. 2), weighted (Eq. 3), and balanced (Eq. 4) loss functions, we follow the same configuration and process explained in Sect. 6.2. We train the model for 15 epochs for each of the proposed customized losses, and pick the best model based on the macro-F1 in the validation phase. Table 4 summarizes the results for fine-tuning BERT with the customized loss functions. For all the macro averaged metrics, every customized training performs better than the original loss function (+4–15%), demonstrating the effectiveness of the proposed loss functions. Furthermore, using the suggested loss functions, the accuracy increases to 94.3–95.5% (from 87.5% using NLL loss). Although the original loss displays better weighted precision and F1-scores, given our highly imbalanced dataset, its significantly lower macro recall demonstrates how customized training improves the model's recognition of the infrequent labels towards the desired performance. Overall the non-balanced loss method scores the best, excelling on macro-precision, macro-F1 and accuracy. Weighting or balancing on class weight may not work as well in our application since we accept predictions from classes with different weights, e.g., a minority class may be acceptably predicted as an instance of a majority class.
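A small sketch (an assumption, not the authors' code) of the adjusted accuracy described above: a prediction is counted as correct whenever it is an acceptable prediction for the post's true label according to the acceptability matrix.

```python
import numpy as np

def adjusted_accuracy(y_true, y_pred, Z):
    """Z[t, p] == 1 when predicting class p for true label t is acceptable (Fig. 2)."""
    ok = Z[np.asarray(y_true), np.asarray(y_pred)]
    return float(ok.mean())
```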
7 Conclusion and Discussion This paper seeks to address the low engagement and interactivity issues in online health support groups via a post intent detection model. Subtle detection of post intents is the first step to effectively intervening with engaging strategies such as a Chatbot to improve the quality of discussions and the health outcomes of the groups.
Table 4 Evaluating loss functions using adjusted metrics for BERT

Loss function | Macro avg P | Macro avg R | Macro avg F1 | Weighted avg P | Weighted avg R | Weighted avg F1
RF (baseline) | 66.3 | 23.8 | 30.5 | 83.3 | 78.5 | 80.8
NLL           | 68.1 | 73.5 | 70.0 | 87.7 | 87.5 | 87.6
Non-balanced  | 80.4 | 78.9 | 77.6 | 70.0 | 95.5 | 80.8
Weighted      | 71.6 | 82.1 | 73.6 | 59.6 | 94.5 | 73.1
Balanced      | 71.8 | 88.8 | 77.1 | 58.9 | 94.3 | 72.5
While previous work focuses on 1:1 conversations or question-answer forums, our context is a live group chat where approximately 20 peers post. Accurate detection of intents is more critical in a group chat setup where irrelevant interruptions are likely to disrupt group conversations and functioning and discourage participation. However, intent detection in such an environment is exceptionally challenging due to extremely noisy turn-taking and sparsity of certain important intents in the dataset. Furthermore, intent labels are often overlapping with asymmetric relationships. We present an expert-annotated dataset with a fine-grained set of 24 intents from support groups for quitting smoking and use it to explore the problem of intent detection. We propose fine-tuning a pretrained language model (BERT) with a customized loss function representing the relationship between the labels as a promising solution that obtains 95.5% accuracy in our application. To our knowledge, no prior intent detection model in online health communities has performed this well. Although our experiments are limited to samples from online support groups for smoking cessation, given the large and fine-grained set of intents that we recognize compared to other related works, our method may well have a bearing on different online support groups. As this paper aims to identify the most dominant intent of each post to respond appropriately, future research could usefully explore multi-label intent detection in online support groups where a single post contains multiple unrelated intents. A further study could investigate ways to utilize contextual information from the whole group discussions to improve intent recognition and to distinguish between labels involving similar words. Recognition of intents expressed as jokes, memes, or particular references (e.g., inter-group incidents, movies, books, etc.) is another interesting topic for future work. Finally, more research is needed to test and evaluate the effectiveness of our work for increasing engagement in online support groups.
References 1. Arguello, J., Butler, B., Joyce, E., Kraut, R., Ling, K. S., Rosé, C., & Wang, X. (2006). Talk to me: Foundations for successful individual-group interactions in online communities. Proceedings of the SIGCHI Conference on Human Factors in Computing Systems. 2. Richardson, A., Graham, A. L., Cobb, N., Xiao, H., Mushro, A., Abrams, D., & Vallone, D. (2013). Engagement promotes abstinence in a web-based cessation intervention: Cohort study. Journal of Medical Internet Research, 15(1), e14. 3. Pechmann, C., Pan, L., Delucchi, K., Lakon, C. M., & Prochaska, J. J. (2015). Development of a twitter-based intervention for smoking cessation that encourages high-quality social media interactions via automessages. Journal of Medical Internet Research, 17(2), e50. 4. Pechmann, C., Delucchi, K., Lakon, C. M., & Prochaska, J. J. (2017). Randomised controlled trial evaluation of tweet2quit: A social network quit-smoking intervention. Tobacco Control, 26(2), 188–194. 5. Joyce, E., & Kraut, R. E. (2006). Predicting continued participation in newsgroups. Journal of Computer-Mediated Communication, 11(3), 723–747. 6. Gruzd, A., & Haythornthwaite, C. (2013). Enabling community through social media. Journal of Medical Internet Research, 15(10), e248.
7. Pechmann, C., Calder, D., Phillips, C., Delucchi, K., & Prochaska, J. (2020). The use of web-based support groups versus usual quit-smoking care for men and women aged 21–59 years: Protocol for a randomized controlled trial. JMIR Research Protocols, 9. 8. Ly, K. H., Ly, A.-M., & Andersson, G. (2017). A fully automated conversational agent for promoting mental well-being: A pilot RCT using mixed methods. Internet Interventions, 10, 39–46. 9. Prochaska, J. J., Vogel, E. A., Chieng, A., Kendra, M., Baiocchi, M., Pajarito, S., & Robinson, A. (2021). A therapeutic relational agent for reducing problematic substance use (woebot): Development and usability study. Journal of Medical Internet Research, 23(3), e24850. 10. Ghandeharioun, A., McDuff, D., Czerwinski, M., & Rowan, K. (2019). Emma: An emotion-aware wellbeing chatbot. 2019 8th International Conference on Affective Computing and Intelligent Interaction (ACII) (pp. 1–7). 11. de Gennaro, M., Krumhuber, E. G., & Lucas, G. M. (2019). Effectiveness of an empathic chatbot in combating adverse effects of social exclusion on mood. Frontiers in Psychology, 10. 12. Bickmore, T., & Schulman, D. (2007). Practical approaches to comforting users with relational agents. In CHI Extended Abstracts. 13. Sarker, A., Belousov, M., Friedrichs, J., Hakala, K., Kiritchenko, S., Mehryary, F., Han, S., Tran, T., Rios, A., Kavuluru, R., de Bruijn, B., Ginter, F., Mahata, D., Mohammad, S. M., Nenadic, G., & Gonzalez-Hernandez, G. (2018). Data and systems for medication-related text classification and concept normalization from Twitter: Insights from the Social Media Mining for Health (SMM4H)-2017 shared task. Journal of the American Medical Informatics Association, 25(10), 1274–1283. 14. Zhang, T., Cho, J. H. D., & Zhai, C. X. (2015). Understanding user intents in online health forums. IEEE Journal of Biomedical and Health Informatics, 19, 1392–1398. 15. McRoy, S., Rastegar-Mojarad, M., Wang, Y., Ruddy, K. J., Haddad, T. C., & Liu, H. (2018). Assessing unmet information needs of breast cancer survivors: Exploratory study of online health forums using text classification and retrieval. JMIR Cancer, 4(1), e10. 16. Huh, J., Yetisgen-Yildiz, M., & Pratt, W. (2013). Text classification for assisting moderators in online health communities. Journal of Biomedical Informatics, 46(6), 998–1005. Special Section: Social Media Environments. 17. Park, J., Kotzias, D., Kuo, P. B., Logan, R. L., IV, Merced, K., Singh, S., Tanana, M. J., Taniskidou, E. K., Lafata, J., Atkins, D. C., Tai-Seale, M., Imel, Z. E., & Smyth, P. (2019). Detecting conversation topics in primary care office visits from transcripts of patient-provider interactions. Journal of the American Medical Informatics Association: JAMIA, 26, 1493–1504. 18. Xiao, B., Can, D., Gibson, J., Imel, Z. E., Atkins, D. C., Georgiou, P., & Narayanan, S. S. (2016). Behavioral coding of therapist language in addiction counseling using recurrent neural networks. In INTERSPEECH. 19. Savage, S., Monroy-Hernandez, A., & Höllerer, T. (2016). Botivist: Calling volunteers to action using online bots. In Proceedings of the 19th ACM Conference on Computer-Supported Cooperative Work & Social Computing, CSCW '16, New York, NY, USA (pp. 813–822). Association for Computing Machinery. 20. Kim, S., Eun, J., Changhoon, O., Suh, B., & Lee, J. (2020). Bot in the Bunch: Facilitating Group Chat Discussion by Improving Efficiency and Participation with a Chatbot (pp. 1–13). New York, NY, USA: Association for Computing Machinery. 21.
Seering, J., Luria, M., Ye, C., Kaufman, G., & Hammer, J. (2020). It Takes a Village: Integrating an Adaptive Chatbot into an Online Gaming Community (pp. 1–13). New York, NY, USA: Association for Computing Machinery. 22. Seering, J., Luria, M., Kaufman, G., & Hammer, J. (2019). Beyond Dyadic Interactions: Considering Chatbots as Community Members (pp. 1–13). New York, NY, USA: Association for Computing Machinery. 23. Kerr, A. N., Schillo, B. A., Keller, P. A., Lachter, R. B., Lien, R. K., & Zook, H. G. (2019). Impact and effectiveness of a stand-alone NRT starter kit in a statewide tobacco cessation program. American Journal of Health Promotion, 33(2), 183–190 PMID: 29747516.
24. Anderson, J. E., Jorenby, D. E., Scott, W. J., & Fiore, M. C. (2002). Treating tobacco use and dependence: An evidence-based clinical practice guideline for tobacco cessation. Chest, 121(3), 932–941 March. 25. Pechmann, C., Yoon, K., Trapido, D., Prochaska, J. (2020). Perceived costs versus actual benefits of demographic self-disclosure in online support groups. Journal of Consumer Psychology, 10. Forthcoming. 26. McHugh, M. L. (2012). Interrater reliability: The kappa statistic. Biochemia Medica, 22(3), 276–282 October. 27. Breiman, L. (2001). Random forests. Machine Learning, 45(1), 5–32 October. 28. Salles, T., Gonçalves, M. A., Rodrigues, V., & Rocha, L. (2018). Improving random forests by neighborhood projection for effective text classification. Information Systems, 77, 1–21. 29. Islam, M. Z., Liu, J., Li, J., Liu, L., & Kang, W. (2019). A semantics aware random forest for text classification. In Proceedings of the 28th ACM International Conference on Information and Knowledge Management, CIKM’19, New York, NY, USA (pp. 1061–1070). Association for Computing Machinery. 30. Devlin, J., Chang, M.-W., Lee, K., & Toutanova, K. (2019). Bert: Pre-training of deep bidirectional transformers for language understanding. In NAACL. 31. Liu, Z., Lin, W., Shi, Y., & Zhao, J. (2021). A robustly optimized bert pre-training approach with post-training. In S. Li, M. Sun, Y. Liu, H. Wu, L. Kang, W. Che, S. He, & G. Rao (Eds.), Chinese Computational Linguistics, Cham (pp. 471–484). Springer International Publishing. 32. Yang, Z., Dai, Z., Yang, Y., Carbonell, J., Salakhutdinov, R. R., & Le, Q. V. (2019). Xlnet: Generalized autoregressive pretraining for language understanding. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché-Buc, E. Fox, & R. Garnett (Eds.), Advances in Neural Information Processing Systems (Vol. 32). Curran Associates, Inc. 33. Wang, A., Singh, A., Michael, J., Hill, F., Levy, O., & Bowman, S. (2018). GLUE: A multi-task benchmark and analysis platform for natural language understanding. In Proceedings of the 2018 EMNLP Workshop BlackboxNLP: Analyzing and Interpreting Neural Networks for NLP, Brussels, Belgium (pp. 353–355). Association for Computational Linguistics.
EXPECT-NLP: An Integrated Pipeline and User Interface for Exploring Patient Preferences Directly from Patient-Generated Text
David Johnson, Nick Dragojlovic, Nicola Kopac, Yifu Chen, Marilyn Lenzen, Sarah Le Huray, Samantha Pollard, Dean Regier, Mark Harrison, Amy George, Giuseppe Carenini, Raymond Ng, and Larry Lynd
D. Johnson (B) · N. Dragojlovic · N. Kopac · Y. Chen · D. Regier · M. Harrison · A. George · G. Carenini · R. Ng · L. Lynd University of British Columbia, Endowment Lands, BC, Canada e-mail: [email protected] N. Dragojlovic e-mail: [email protected] N. Kopac e-mail: [email protected] Y. Chen e-mail: [email protected] D. Regier e-mail: [email protected] G. Carenini e-mail: [email protected] R. Ng e-mail: [email protected] L. Lynd e-mail: [email protected] M. Lenzen · S. Le Huray Multiple Sclerosis Society of Canada, Toronto, ON, Canada S. Pollard · D. Regier BC Cancer Agency, Vancouver, BC, Canada e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_6
Abstract Understanding patient preferences for drug therapies can help to inform drug development, reimbursement decisions by insurers, and shared decision-making by patients and clinicians. Typically, preferences are assessed using lengthy and costly studies involving qualitative analysis of interviews and focus groups along with complex stated preference surveys. We have developed a more efficient method that allows users (e.g. health services researchers, drug developers, or market researchers) to assess patient preferences semi-automatically from online patient-generated text using a weakly supervised aspect-based sentiment analysis pipeline. To facilitate
users' ability to generate new insights on patient preferences using our system, we have developed a dedicated user interface, which we present here. Our hope is that this will enable users to conduct frequent and adaptable assessments of patient preferences for drug therapies over time. Keywords Natural language processing · Sentiment analysis · Information visualization · Human-computer interaction
1 Introduction When undergoing treatments for conditions such as multiple sclerosis or cancer, patients may be presented with multiple options for treatment, such as a choice between two drugs with injectable versus oral modes of administration [19]. Patients must weigh the potential advantages and disadvantages of each treatment—e.g. one treatment option may be infrequent and inexpensive, but quite painful, whereas another is painless but frequent and expensive. Understanding how patients tend to prioritize different considerations when deciding on which drug to take is valuable for a range of stakeholders, including public and private insurers (who often take patient preferences into account when deciding on which drugs to fund), drug developers (who can use such knowledge to inform which drug candidates to prioritize for development), and patients themselves (for whom other patients' experiences can be a valuable input into their own treatment decisions). Unfortunately, common methods for determining patient preferences, such as discrete choice experiment surveys (DCEs) [1, 20], have at least three significant limitations: (i) length constraints for DCEs mean that only a limited number of drug characteristics can be evaluated in a single survey (typically 6 or fewer are selected from a larger set of potential attributes identified in preliminary focus groups); (ii) they often use convenience samples of participants who may not represent the full diversity of the patient population under study [6]; and (iii) this type of study can be resource-intensive and time-consuming. Qualitative approaches to preference exploration suffer from the same limitations in terms of sampling and cost, with a focus group (the recommended first step of a DCE [4]) requiring an average of 28.8 personnel hours [5], and 5 focus groups being necessary to reach a point of 90% thematic saturation (i.e. where additional data does not add new relevant concepts) [21]. The main goal of the research presented in this paper was to develop a system which allows (semi)automated discovery of patient preferences from unstructured patient-generated text (e.g. online discussion forums). We achieve this with EXPECT-NLP (EXploration of Preferences and Experiences in Collected Texts using Natural Language Processing), which uses weakly-supervised aspect based sentiment analysis [23], enhanced with a novel lexicon refinement protocol, to annotate the source text with relevant concepts, themes, and sentiment, and provides an interface for visual analysis of the extracted data that allows a human analyst to efficiently generate insights from this data. The interface allows the exploration of the extracted
data through visualizations of broad summary statistics juxtaposed with trends over time, co-occurrence of terms, filters, and details on demand. By automating major elements of qualitative patient preference research and text analysis, our system is intended to facilitate more efficient and agile assessment of patient preferences for drug therapies. This would be valuable for both patients considering their own treatment decisions and domain experts in pharmaceutical development and policy. Beyond patient preferences, our system can also be used for other tasks such as the discovery of subjective patient experience, evaluation of trends in treatment mentions and sentiment over time in support of post-marketing surveillance and product acceptance monitoring, or public health concerns such as vaccine hesitancy. In summary, our hypothesis is that patient-generated text can be used in EXPECT-NLP through its aspect based sentiment analysis pipeline and visualized with its interface to (semi)automatically evaluate patient preferences for drug therapies, discover subjective patient experience, and track these features over time, complementing or even avoiding more resource-intensive methods. Overall, this paper makes the following key contributions: (1) while applying the aspect-based sentiment analysis ABSApp system to a large corpus of Reddit posts, we developed a new protocol for human-in-the-loop refinement of ABSApp generated lexicons; (2) by following an iterative design process, informed by multiple use case studies and informal discussions with potential stakeholders, we have developed a visual interface that supports the exploration of extracted aspects, sentiments, and other user data.
2 Aspect Based Sentiment Analysis Sentiment analysis is the task of extracting from text a polarity value measuring whether associated text is being expressed with positive, negative, or neutral emotion. Though this can be done at various levels of granularity (sentence, paragraph, etc.), aspect based sentiment analysis (ABSA) predicts finer-grained sentiment for aspects of entities referred to in the text (e.g., a drug's side-effects). In order to explore the experiences and preferences expressed by patients in online text, we developed a pipeline to extract aspects related to drug therapies and associated opinions. The key component of our pipeline is ABSApp [23], a recent system for extracting aspect-opinion pairs from unannotated text. We chose ABSApp because it offers a reasonable weakly-supervised compromise for our purposes between noisy ABSA unsupervised methods [24] and labeled data-hungry fully supervised techniques [16, 27]. ABSApp provides an underlying ABSA model. The model uses relational rules [24] based on the output of a neural dependency parser [15] for extracting aspect-opinion pairs. Then, it stores the extracted aspects and opinions in saved lexicons. After extraction on pilot data, the ABSApp interface outputs the lexicons as simple Excel sheets that the user can manually edit as necessary by adding, removing, or
changing any aspects or opinions. The edited lexicons can then be used for inference, the next stage of the pipeline. During inference, aspect-opinion pairs from the lexicon can be extracted from new unseen data.
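ABSApp itself relies on dependency-parse rules, but the inference step can be pictured as lexicon matching; the toy lexicons and pairing rule below are assumptions for illustration only, not the system's actual logic.

```python
# Highly simplified stand-in for lexicon-based inference on new text.
ASPECT_LEXICON = {"tecfidera": "Tecfidera", "side effects": "side effects"}          # assumed entries
OPINION_LEXICON = {"disappointing": "NEG", "painful": "NEG", "worked well": "POS"}   # assumed entries

def extract_pairs(sentence: str):
    s = sentence.lower()
    aspects = [canon for term, canon in ASPECT_LEXICON.items() if term in s]
    opinions = [(term, pol) for term, pol in OPINION_LEXICON.items() if term in s]
    # Naively pair every aspect found in the sentence with every opinion term found in it.
    return [(a, o, pol) for a in aspects for (o, pol) in opinions]

# extract_pairs("Tecfidera was disappointing")  ->  [("Tecfidera", "disappointing", "NEG")]
```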
3 Data Model Our dataset is composed of user-generated text collected from online forums. The web platform we used for this analysis is Reddit, a social media site that separates various topics into “subreddits”, including health-focused forums like r/multiplesclerosis, where patients (and potentially non-patients) post. We recognize that the Reddit dataset may be biased and may not be a representative sample of the general population, which is a possible limitation of this dataset (as is typically the case with social media data [25]). However, a growing body of literature indicates Reddit plays an important role as a source of health information and in health-related discussions [2, 3, 8, 22]. While we originally chose to use Reddit because of its convenient API and large volume of targeted posts in relevant subreddits, our method is not restricted to Reddit data, and can be applied to any other social media that allows export of textual data. Using the Reddit API, we downloaded 50,042 posts (including both original posts and comment replies) from the following subreddits: r/multiplesclerosis, r/rheumatoid, r/cancer, r/testicularcancer, and r/breastcancer. Attached to each user post are the following attributes: post title, author’s user name, time of post, and score of post (users can “upvote” posts they like and “downvote” posts they do not like, generating a net score of votes). Lexicon Refinement Protocol. Since automatic aspect based sentiment analysis is still a very noisy process, ABSApp supports a human-in-the-loop step where the user can manually improve the system output. However, the details of such a refinement strategy are still unclear and could vary across domains and applications. To address this challenge, after using ABSApp to generate lexicons during training, we developed a protocol to narrow down our lexicons to domain-specific terms useful for exploring patient preferences for drug therapies. The aspect lexicon was exported from ABSApp as an Excel spreadsheet with a term ID, the aspect term, its aliases (variations on the term), and up to 20 excerpts of text from which the aspect was extracted. Aspects were then examined by three reviewers, all of whom are domain experts in health research, to determine the relevance of the aspect terms to treatment preferences. For our purposes, treatments were defined to include disease-modifying therapies, medication for symptom management, surgery, lifestyle interventions, and any other treatment modality that was characterized by posters as having the potential to improve the course of the disease or improve symptoms. Aspects were included if they unambiguously referred to a treatment in the context of the text excerpt. Drug names and classes were considered relevant by definition. If an aspect was clearly irrelevant, it was deleted at first pass. If it was not clearly
irrelevant, the reviewers examined the associated excerpts and retained the aspect if at least 50% of excerpts were relevant. If a retained aspect appeared to refer to the same concept as a previously reviewed aspect, it was included as an alias for the previously reviewed aspect. Aspects with between 50 and 80% of relevant excerpts were also reviewed to see if any more specific multi-word alternatives (e.g., “long term” and “short term” instead of the extracted aspect “term”) existed. If they did, a new aspect entry was created for each alternative. If the multi-word aspect captured the full scope of relevance for the original aspect, the multi-word term replaced the original term; if it potentially captured additional concepts, both aspects were retained. This process was complete once multi-word terms were added, aspects consolidated, and all aspects marked for retention or deletion. The reviewers next compared their coding, with disagreements resolved by discussion until consensus was achieved. Agreement prior to the consensus process was assessed using Fleiss’ multirater kappa [7]. As a final step, each retained aspect was assigned to one or more thematic categories. An initial list of categories was created from the literature on drug preferences (e.g., effectiveness, side effects, mode of administration), as well as inductive additions by reviewers. The results were compared, and disagreements or uncertainty resolved by discussion until there was consensus. Once this process had been refined, one reviewer went through the remaining 800 aspects using the above method. Aspects which were neither strongly included nor excluded, scoring between 40 and 60% relevant excerpts, were reviewed by a second reviewer, and then reviewers compared coding, with disagreements resolved by discussion. Our lexicon refinement protocol has shown effectiveness in other contexts, such as when extracting aspects and sentiment for predicting suicide risk in youth from clinical notes [9].
4 Aspect Extraction Results After developing our lexicon refinement protocol and using the protocol to create a refined domain-specific lexicon, we used the lexicon in the inference pipeline step to extract aspects from our target data. We began by applying our ABSA pipeline to a small set of documents obtained from transcriptions of patient interviews and focus groups about treatment preferences, some of which have been previously published [11, 18]. The transcribed documents contained approximately 1200 sentences, and the system extracted 45 aspects. We then applied our pipeline to approximately 12,000 Reddit posts, extracting 714 aspects and discovering that our extracted lexicon was a superset of the lexicon extracted from the transcribed interviews. With our intention for a human-in-the-loop to refine the extracted lexicons, we wanted to determine whether the number of aspects saturates at a level tenable for manual refinement. To find a saturation point we applied our pipeline to increasingly
Table 1 Posting volume, frequency, and intensity, by subreddit

Forum | Unique users | Posts | Coverage (days) | Posts/user (sd)
r/rheumatoid | 1,802 | 13,550 | 203 | 7.5 (23.7)
r/multiplesclerosis | 2,108 | 12,384 | 65 | 5.9 (14.0)
r/testicularcancer | 964 | 10,008 | 337 | 10.4 (23.3)
r/cancer | 1,946 | 8,680 | 969 | 4.5 (8.8)
r/breastcancer | 1,095 | 5,420 | 405 | 5.0 (9.6)
Total (pooled) | 7,762 | 50,042 | 969 | 6.5 (17.0)
greater amounts of data, first trying approximately 26,000 documents and extracting 945 aspects, then 51,000 documents and extracting 1016 aspects. Because the lexicon grew by only about 7% when the amount of data was roughly doubled, adding increasing amounts of data beyond that point is unlikely to significantly increase the number of aspects in the lexicon. This saturation will allow for manual lexicon refinement even on very large datasets without having the refinement time commitment become prohibitively large (Table 1). Lexicon Refinement Results. A development sample of 200 aspects from the aspect lexicon was reviewed by the three reviewers. Fleiss' kappa for aspect retention prior to consensus was 0.731 (95% CI, 0.729–0.734), which constitutes substantial agreement. Of these 200 aspects, 32 were retained, and 22 multi-word aspects were added, yielding a refined aspect lexicon of 54 aspects for these 200 aspects. After review of the full set of 1016 aspects, 146 aspects were retained, and 44 multi-word aspects were added. In total, 190 aspects were included in the final refined lexicon. Aspects were coded into the following (non-exclusive) thematic categories: treatment (n = 107), pharmaceutical (a subset of treatment) (n = 55), effectiveness (n = 24), side effects (n = 30), cost (n = 3), route of administration (n = 9), dosage (n = 8), frequency of administration (n = 7), healthcare professional (n = 16), healthcare professional preference (a subset of healthcare professional) (n = 5), evidence (n = 4), uncertainty (n = 3), alternatives (n = 14), and other (n = 7). Extracted Aspects. ABSApp extracted a total of 30,075 aspect-opinion pairs from the five subreddits, ranging from 9,027 for r/rheumatoid down to 4,127 for r/breastcancer. 18,124 aspect-opinion pairs were categorized as referring directly to treatments, with 9,688 of those referring specifically to pharmaceutical treatments. Aspect-opinion pairs were extracted from 13,627 posts (27.2% of all posts in our data). Of those posts with at least one aspect-opinion pair extracted, 49.5% (6,751) contained only one pair, 23.3% (3,180) contained two, 11.4% (1,557) contained three, and 15.7% (2,139) contained four or more. On average, each of these posts had 2.2 aspect-opinion pairs extracted (SD = 2.0, median = 2, min = 1, max = 32). Co-occurrences in Topic Segments. To support analytic user tasks within the EXPECT-NLP interface, we also extract instances of co-occurring aspect terms (e.g. the user may be comparing two aspects, as in "Ocrevus worked well for me, but Tecfidera was disappointing"). Our goal in extracting co-occurrences is to identify instances in which users are discussing terms within the same topic of discussion.
To facilitate this extraction, we use TextTiling [12]. TextTiling assumes that vocabulary shifts when a speaker changes topics; by discovering the boundaries of maximal vocabulary change over blocks of text, it predicts changes in the speaker's topic. We refer to blocks of text predicted to be about the same topic as a "topic segment". For discussion of the user tasks requiring co-occurrence and the way we support visualization of this data in the interface, see both the Task Model and Visualization and Design sections.
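As a rough illustration of this step, the sketch below segments a post with NLTK's TextTilingTokenizer and records aspect terms that fall in the same topic segment. The post text and the three aspect terms are invented for illustration; TextTiling needs reasonably long, paragraph-broken input (and the NLTK stopwords corpus) to produce useful boundaries.

```python
# Sketch: finding aspect co-occurrences within TextTiling topic segments.
# Requires the NLTK stopwords corpus (nltk.download("stopwords")); the post and
# aspect lexicon below are toy examples.
from itertools import combinations
from nltk.tokenize import TextTilingTokenizer

post = ("Ocrevus worked well for me, but Tecfidera was disappointing. ...\n\n"
        "On a different note, my neurologist suggested physiotherapy. ...")
aspects = ["ocrevus", "tecfidera", "physiotherapy"]

tt = TextTilingTokenizer()
try:
    segments = tt.tokenize(post)   # blocks of text predicted to share one topic
except ValueError:                 # very short posts cannot be tiled
    segments = [post]

cooccurrences = []
for seg in segments:
    found = [a for a in aspects if a in seg.lower()]
    cooccurrences.extend(combinations(sorted(found), 2))
print(cooccurrences)               # e.g. [("ocrevus", "tecfidera")]
```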
5 EXPECT-NLP Interface Given that the data analysis portion of the EXPECT-NLP pipeline generates a large and multidimensional dataset, we are developing a graphical user interface to facilitate visualization and exploration of data by EXPECT-NLP's users. This is an essential component of the system, since it allows domain expert analysts who need not be familiar with NLP or machine learning to efficiently conduct high-level thematic and interpretive analysis of the underlying text data. The current prototype of the EXPECT-NLP interface allows users to explore extracted aspects and categories, sentiment, user data, and derived aspect statistics. It is targeted primarily at experts in drug/treatment development, reimbursement and market access of pharmaceuticals, health services research, and health policy, or those with similar expertise. The development of the interface has followed an iterative design process, and its design has been influenced by multiple use case studies, informal discussions with potential stakeholders, and our own team's expertise in these domains. We are continuing with further planned use case studies and iterative design development. Task Model. The current design for the EXPECT-NLP interface prototype is based on an early task model developed by following established protocols for designing user interfaces as outlined in [26]. To keep an open mind about potential target user populations and to collect a large and diverse set of information needs, we began development of our interface by interviewing two patient partner members of our project team, both of whom have extensive lived experience in searching for information on treatment options. The interviews took place over a one-hour block with each of the two team members. We discussed a series of questions which were designed to identify task examples that demonstrate which information-seeking tasks are performed by patients when finding information on treatments. Notably, they reported using social media to find general opinions on treatment options and descriptions of other patients' experiences, and pointed to the ability to easily view summary statistics of opinions/experiences as a potentially valuable feature of a user interface. This discussion revealed the following potential tasks: • (T1) Search by keyword: E.g. user searches social media for "Ocrevus" and a set of posts, all containing the keyword "Ocrevus", is returned
• (T2) Sub-search functionality: E.g. user searches for “Ocrevus” and from the resulting set of posts, searches for “symptoms” and the resulting subset of posts containing both “Ocrevus” and “symptoms” is returned • (T3) View summary statistics of keywords: E.g. user views that 100 people express positive sentiment about “Ocrevus” and 50 express negative sentiment about “Ocrevus” – Subtask (T3a) View change over time in summary statistics E.g. user views that “Ocrevus” opinions were primarily negative in 2015, and primarily positive in 2020 • (T4) Comparison of the opinions on multiple drugs E.g. user searches for “Lemtrada” and “Ocrevus” and compares the general opinion of others on the two drugs Following these patient-focused case studies, we hypothesized that academic or scientific users of EXPECT-NLP might have more complex analysis requirements. To facilitate construction of a research professional user task model, we completed user case studies with three members of our research team, all of whom are experts in drug reimbursement and policy and/or health services research. These case studies identified the following additional user tasks: • (T5) For either a drug (e.g. “Ocrevus”) or attribute (e.g. “Side Effects”) show frequency and sentiment broken down by co-occurring drug or attribute aspects E.g. User selects “Side Effects” and views all mentions of drug aspects which co-occurred with “Side Effects” • (T6) View neutral mentions of aspects in addition to positive or negative mentions E.g. User views that “Ocrevus” was mentioned 100 times with positive sentiment, 50 times with negative sentiment, and 75 times with neutral sentiment • (T7) View sentences from which co-occurring aspects were extracted E.g. User sees that “Ocrevus” and “Lemtrada” have co-occurred together 10 times, and views the sentences in which they co-occur • (T8) View aspect mentions broken down by aliases E.g. The user selects “Ocrevus” and sees that of 100 total mentions, 75 were mentions specifically of “Ocrevus” and 25 were mentions of the alias “ocrelizumab” (the generic drug name for Ocrevus), which were counted as mentions of “Ocrevus” Visualization and Design. The design for the visual encodings of the current interface prototype is shown in Fig. 1. This preliminary design, intended to support the whole task model (including the expert tasks), comprises the following elements: • View Terms Menu (Supporting T1, T2) The interface presents a “View Terms” menu at the top left, allowing users to view extracted aspects (e.g. click “View Terms” and scroll through aspects, clicking “Ocrevus”). Terms selected are then added to the “Total Mentions” visualization panel • Total Mentions Panel (Supporting T3, T6) The Total Mentions Panel displays the main broad overview of the data through a bar chart visualization. As shown in Fig. 1, Tecfidera has 444 positive mentions and 371 negative mentions. Multiple
Fig. 1 Prototype of the EXPECT-NLP interface. The interface provides bar chart visualizations as a broad overview of the dataset, as well as a pivot table which supports cross-tabulation (and other capabilities such as custom chart visualizations) to allow for details-on-demand. Elements of the interface are explained in the EXPECT-NLP interface section
term selections can be added here to allow a simple visual comparison of the number of positive and negative mentions between terms. Neutral mentions can optionally be viewed here in addition to positive/negative if extracted from the dataset
• Change Over Time Panel (Supporting T3a) This panel, in the bottom left, shows the breakdown over the dataset time period of the selected term by positive and negative mentions
• Filter Panel (Supporting T4) The left side of the interface includes a filter panel, allowing the patients to filter on categories such as Time and/or Location (depending on the user's dataset) which would show the aspects collected from data during certain time periods or from certain user locations. This filter panel can also incorporate user demographics if the dataset includes demographics as a dimension
• Pivot Table (Supporting T5, T6) The top right panel of the interface includes a pivot table which allows users to cross-tabulate aspects, co-occurring aspects and aspect categories, sentiment, year, and location. We expand on the functionality of the pivot table below
• Sentence Panel (Supporting T7) The bottom right panels show all the sentences in the subset selected by the user (Fig. 1 shows a selection of the positive subset of opinions on Tecfidera)
At present, user task T8 is not supported by the EXPECT-NLP interface; however, development is continuing and we intend to support task T8 in the near future. Pivot Table. The Pivot Table provides the primary way of accessing the details-on-demand of the extracted data, and complements the simple overview provided by the Total Mentions Panel. The Pivot Table is based on a widely-used
form of visualization construction called shelf configuration in which users typically use drag-and-drop interaction methods to create data transformations in a table or visualization format [10, 14]. In following this, our Pivot Table enables the user to cross-tab data (e.g. cross-tab aspects with sentiment, or sentiment by location, or co-occurring aspects by time, etc) using a drag-and-drop interaction method. Going beyond cross-tabs, the pivot table allows users to search for terms, and display terms in different ways such as heat map tables, or visualizations like bar charts, line charts, or area charts. The table also allows sorting, filtering, and viewing values by total counts or proportion of row/column values. In addition, the table allows for contextual information about the cross-tab values: by clicking any of the cross table values the sentence from which the cross-tabbed value was extracted is shown in a pop-up window. Ultimately, the Pivot Table is the deep dive companion to the broad overview that the rest of the interface presents.
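The cross-tabulations the Pivot Table exposes are conceptually simple; the sketch below shows the same operation with pandas on a toy set of extracted aspect-opinion pairs. The column names and rows are illustrative, not the EXPECT-NLP schema.

```python
# Sketch: cross-tabulating extracted aspect-opinion pairs, as the Pivot Table
# panel does interactively. The data frame below is a toy example.
import pandas as pd

pairs = pd.DataFrame([
    {"aspect": "Ocrevus",   "sentiment": "positive", "year": 2020},
    {"aspect": "Ocrevus",   "sentiment": "negative", "year": 2015},
    {"aspect": "Tecfidera", "sentiment": "negative", "year": 2020},
    {"aspect": "Tecfidera", "sentiment": "positive", "year": 2020},
])

# aspects as rows, sentiment as columns, counts in the cells
print(pd.crosstab(pairs["aspect"], pairs["sentiment"]))

# the same idea extended to a second column dimension (year)
print(pd.crosstab(pairs["aspect"], [pairs["sentiment"], pairs["year"]]))
```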
6 Real-World Use Cases We have also had discussions with industry professionals on the potential use of EXPECT-NLP to facilitate their own exploration of user-generated data valuable for public health decisions. Our research team met with professionals from the British Columbia Centre for Disease Control to discuss their potential use of EXPECT-NLP for discovering sources of COVID-19 vaccine hesitancy from social media data sources (in this case, Reddit sources such as r/coronavirusUS and r/canadacoronavirus). After viewing a demo of the interface, the BCCDC team commented that EXPECT-NLP was very interesting and had much potential, and in particular they felt it would be useful for examining vaccine hesitancy and uptake over time. They also felt there is substantial value in being able to see the response to a communication campaign and whether their messaging is having an effect, which EXPECT-NLP could provide. The experts did suggest that from a public health perspective it is ideal to have near real-time data, which EXPECT-NLP does not currently support. While EXPECT-NLP does not currently have the capability to stream real-time data (as it requires offline processing of a dataset using the extraction pipeline before visualization in the interface), it can support analysis over short time periods such as days if data is obtained and processed over that period of time. We intend to continue development of EXPECT-NLP in a way that may allow increasing responsiveness to real-time data through improvements to web scraping, data processing, and aspect extraction.
7 Conclusion and Future Work We have shown that by using EXPECT-NLP's aspect based sentiment analysis pipeline it is possible to (semi)automatically identify key factors relevant to patient drug therapy decision-making directly from patient-generated social media text, potentially greatly reducing the time and cost of patient preference and experience studies. Additionally, EXPECT-NLP provides an interface to the extracted aspect sentiment data which allows visualization and analysis of the extracted data in an efficient and convenient way. In the near future, we intend to investigate how our approach's agility may facilitate the tracking of patient and public perceptions of drug therapies in rapidly-evolving situations like the COVID-19 pandemic. For example, it could be used to track, in real time, changes in the way the public views potential therapeutics such as hydroxychloroquine, dexamethasone, and remdesivir, for which evidence on safety and effectiveness is uncertain, rapidly evolving, and (in one case) politically contested, as well as public preferences for COVID-19 vaccines under development, both of which could help to inform public health communication efforts. We believe EXPECT-NLP will be an exceptionally valuable tool for analyzing the large amount of extracted data generated by the pipeline, through its ability to support the user task models for expert interpretation and exploration of patient preferences. Finally, we intend to continue development of our pipeline by exploring potential methods for generating preference weights for attributes from extracted aspects and sentiment, and we are also testing incorporating discourse parsing trained on sentiment [13] into our NLP pipeline in the hope of strengthening aspect prediction. We are also looking beyond the use of ABSApp to apply other engines to drive the aspect extraction, such as prompting transformer-based language models [17] to more accurately predict sentiment. We also intend to run studies directly evaluating the performance of EXPECT-NLP against traditional DCE-style methods for extracting preferences.
References 1. de Bekker-Grob, E. W., Ryan, M., & Gerard, K. (2012). Discrete choice experiments in health economics: A review of the literature. Health Economics, 21(2), 145–172. 2. Brochu, F., Robins, S., Miner, S. A., Grunberg, P. H., Chan, P., Lo, K., et al. (2019). Searching the internet for infertility information: A survey of patient needs and preferences. Journal of Medical Internet Research, 21(12), e15132. 3. Chew, C., Rebić, N., Baldwin, C., Amiri, N., Proulx, L., & De Vera, M. A. (2019). "r/thritis", pregnancy, and parenting: A qualitative descriptive study of reddit forums to explore information needs and concerns of women with rheumatoid arthritis. ACR Open Rheumatology, 1(8), 485–492. 4. Coast, J., Al-Janabi, H., Sutton, E. J., Horrocks, S. A., Vosper, A. J., Swancutt, D. R., & Flynn, T. N. (2012). Using qualitative methods for attribute development for discrete choice experiments: Issues and recommendations. Health Economics, 21(6), 730–741.
5. Coenen, M., Stamm, T. A., Stucki, G., & Cieza, A. (2012). Individual interviews and focus groups in patients with rheumatoid arthritis: A comparison of two qualitative methods. Quality of Life Research, 21(2), 359–370. 6. Engel, L., Bansback, N., Bryan, S., Doyle-Waters, M. M., & Whitehurst, D. G. (2016). Exclusion criteria in national health state valuation studies: A systematic review. Medical Decision Making, 36(7), 798–810. 7. Fleiss, J. L. (1971). Measuring nominal scale agreement among many raters. Psychological Bulletin, 76(5), 378. 8. Garg, R., Rebi´c, N., & De Vera, M. A. (2020). Information needs about cancer treatment, fertility, and pregnancy: Qualitative descriptive study of reddit threads. JMIR Cancer, 6(2), e17771. 9. George, A., Johnson, D., Carenini, G., Eslami, A., Ng, R., & Portales-Casamar, E. (2021). Applications of aspect-based sentiment analysis on psychiatric clinical notes to study suicide in youth. In AMIA Annual Symposium Proceedings (Vol. 2021, p. 229). American Medical Informatics Association. 10. Grammel, L., Bennett, C., Tory, M., & Storey, M. A. D. (2013). A survey of visualization construction user interfaces. In EuroVis (Short Papers) (pp. 019–023). 11. Harrison, M., Marra, C., Shojania, K., & Bansback, N. (2015). Societal preferences for rheumatoid arthritis treatments: Evidence from a discrete choice experiment. Rheumatology, 54(10), 1816–1825. 12. Hearst, M. A. (1997). Text tiling: Segmenting text into multi-paragraph subtopic passages. Computational Linguistics, 23(1), 33–64. 13. Huber, P., & Carenini, G. (2019). Predicting discourse structure using distant supervision from sentiment. In K. Inui, J. Jiang, V. Ng, & X. Wan (Eds.) Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing, EMNLP-IJCNLP 2019, Hong Kong, China, November 3–7, 2019 (pp. 2306–2316). Association for Computational Linguistics. https://doi.org/10.18653/ v1/D19-1235 14. Jo, J., L’Yi, S., Lee, B., & Seo, J. (2017). Touchpivot: Blending wimp & post-wimp interfaces for data exploration on tablet devices. In Proceedings of the 2017 CHI Conference on Human Factors in Computing Systems (pp. 2660–2671). 15. Kiperwasser, E., & Goldberg, Y. (2016). Simple and accurate dependency parsing using bidirectional lstm feature representations. Transactions of the Association for Computational Linguistics, 4, 313–327. 16. Liu, P., Joty, S., & Meng, H. (2015). Fine-grained opinion mining with recurrent neural networks and word embeddings. In Proceedings of the 2015 Conference on Empirical Methods in Natural Language Processing (pp. 1433–1443). 17. Liu, P., Yuan, W., Fu, J., Jiang, Z., Hayashi, H., & Neubig, G. (2021). Pre-train, prompt, and predict: A systematic survey of prompting methods in natural language processing. arXiv:2107.13586. 18. Lynd, L. D., Henrich, N. J., Hategeka, C., Marra, C. A., Mittmann, N., Evans, C., & Traboulsee, A. L. (2018). Perspectives of patients with multiple sclerosis on drug treatment: A qualitative study. International Journal of MS Care, 20(6), 269–277. 19. Lynd, L. D., Traboulsee, A., Marra, C. A., Mittmann, N., Evans, C., Li, K. H., et al. (2016). Quantitative analysis of multiple sclerosis patients’ preferences for drug treatment: A bestworst scaling study. Therapeutic Advances in Neurological Disorders, 9(4), 287–296. 20. Mulhern, B., Norman, R., Street, D. J., & Viney, R. (2019). 
One method, many methodological choices: A structured review of discrete-choice experiments for health state valuation. PharmacoEconomics, 37(1), 29–43. 21. Namey, E., Guest, G., McKenna, K., & Chen, M. (2016). Evaluating bang for the buck: A cost-effectiveness comparison between individual interviews and focus groups based on thematic saturation levels. American Journal of Evaluation, 37(3), 425–440. 22. Park, J. Y., Howren, A. M., Davidson, E., & De Vera, M. A. (2020). Insights on mental health when living with rheumatoid arthritis: A descriptive qualitative study of threads on the reddit website. BMC Rheumatology, 4(1), 1–9.
23. Pereg, O., Korat, D., Wasserblat, M., Mamou, J., & Dagan, I. (2019). ABSApp: A portable weakly-supervised aspect-based sentiment extraction system. arXiv:1909.05608. 24. Qiu, G., Liu, B., Bu, J., & Chen, C. (2011). Opinion word expansion and target extraction through double propagation. Computational Linguistics, 37(1), 9–27. 25. Sadah, S. A., Shahbazi, M., Wiley, M. T., & Hristidis, V. (2015). A study of the demographics of web-based health-related social media users. Journal of Medical Internet Research, 17(8), e194. 26. Shneiderman, B., Plaisant, C., Cohen, M., Jacobs, S., Elmqvist, N., & Diakopoulos, N. (2016). Designing the user interface: Strategies for effective human-computer interaction. Pearson. 27. Wang, W., & Pan, S. J. (2018). Recursive neural structural correspondence network for cross-domain aspect and opinion co-extraction. In Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Vol. 1: Long Papers) (pp. 2171–2181).
Medication Error Detection Using Contextual Language Models Yu Jiang and Christian Poellabauer
Abstract Medication errors most commonly occur at the ordering or prescribing stage, potentially leading to medical complications and poor health outcomes. While it is possible to catch these errors using different techniques, the focus of this work is on textual and contextual analysis of prescription information to detect and prevent potential medication errors. In this paper, we demonstrate how to use BERT-based contextual language models to detect anomalies in written or spoken text based on a data set extracted from real-world medical data of thousands of patient records. The proposed models are able to learn patterns of text dependency and predict erroneous output based on contextual information such as patient data. The experimental results yield accuracy up to 96.63% for text input and up to 79.55% for speech input, which is satisfactory for most real-world applications. Keywords Medication error detection · Contextual language model · Precision medicine and health
1 Introduction Over 6,800 prescription medication types are available in the U.S. alone; each year, hundreds of thousands of patients experience adverse reactions or complications, and between 7,000 and 9,000 people die due to medication errors (Tariq et al. [11]). The leading type of error is dosing error, followed by omissions and wrong drug types (Mulac et al. [6]), and most of these errors occur at the ordering or prescribing stages. Once a prescription has been recorded, e.g., in the electronic health records (EHR) of a patient, textual analysis can be used to detect such errors and prevent their consequences. Physicians also increasingly rely on automatic speech recognition
Table 1 Examples of prescriptions from EHR and ASR Type Content EHR EHR ASR
Diagnosis: Cholecystitis Prescription: Aspirin eighty one milligrams daily Diagnosis: Bacteremia endocarditis Prescription: Lisinopril five milligrams oral administration once a day Diagnosis: Peripheral vascular disease Prescription: Blood pressure (lopressor) seventy five milligrams orally administration
Valid? Yes No No
(ASR) systems to communicate with medical or computing equipment, including to order drugs (Latif et al. [2]). Given that ASRs cannot provide perfect transcriptions at all times, they become another source of prescription errors. These erroneous prescriptions could have severe consequences for patient treatment (Oura [7]), e.g., by administering the wrong type or amount of drug (Rodziewicz et al. [10]). In this paper, we propose to detect invalid prescriptions by verifying that a prescription is an appropriate choice given the status (or context) of the patient receiving the medication. Toward this end, we already have a lot of relevant information available, e.g., in a hospital's digital systems that contain the diagnoses, health history, medication history, etc., of patients. In this paper, we leverage such contextual knowledge to detect potential prescription errors (from textual information provided by a physician) or transcription errors (from an ASR). This is demonstrated by the examples shown in Table 1. Human error could lead to an invalid prescription which will be stored in an EHR together with other patient-specific data. The first row of the table shows a valid prescription, i.e., the drug type, usage, etc., is typical for the patient's diagnosis. In the second row of the table, Lisinopril is not an appropriate drug for Bacteremia Endocarditis and this should be flagged as invalid. Finally, the third row shows the transcription output of an ASR where the drug name is mis-interpreted by the system ("blood pressure" instead of "Lopressor"), making this entry again invalid. In order to detect erroneous entries, we propose to use a contextual language model (CLM) that utilizes context information about a patient to determine whether a prescription is valid or not. The basic idea is to analyze the correlation between the prescription and the context. If they are highly correlated, our system accepts the prescription; otherwise, the system triggers an alert to the user, allowing the user to review and correct the prescription. In this work, the proposed CLM is a neural network based on two pre-trained word representations: BERT (Devlin et al. [1]) and BioBERT (Lee et al. [3]), which are described in the remainder of this paper.
2 Proposed Methodology The goal of our work is to prevent medical errors from happening with the help of contextual information, i.e., detecting incorrect prescriptions using the proposed contextual language model. We can consider two types of sources for the prescription. First, a medical professional types a prescription into a system and any errors at this point are true human errors, e.g., the physician misjudging the patient’s situation or not having access to all required information (e.g., patient’s health history or allergies, etc.). Second, a physician dictates the drug information and an ASR system translates the recording into text. In the latter case, errors can also be due to transcription errors by the ASR. These two scenarios are demonstrated in Fig. 1.
2.1 Problem Formalization We define the set of medical inputs as H and the set of contextual knowledge as C. Each element h_i in H has its corresponding contextual knowledge c_i, where h_i ∈ H and c_i ∈ C. Both h_i and c_i are text, consisting of a sequence of words. The problem we intend to solve here is to determine whether h_i is correct or not based on its corresponding context c_i. The proposed method is to calculate the likelihood of h_i under the condition of c_i, P(h_i | c_i), and then make a decision according to that value. The likelihood is calculated by the CLM, where the CLM is trained using labeled pairs (h_i, c_i). A label represents the correlation between h_i and c_i; if they are correlated, the label would be positive, and negative if uncorrelated. Therefore, the trained CLM could output the likelihood P(h_i | c_i), measuring the correlation
Fig. 1 Detecting prescription errors using a contextual language model: the error detection module receives text or speech input from the physician and decides the validity of the prescription. The proposed CLM model is the key to error detection, which is based on BERT and Bio-BERT, and trained with in-domain corpora
between h_i and c_i, which approaches "1" when correlation is high and "0" when correlation is low. When an ASR system is used, the spoken form of h_i is the input to the ASR system, which produces the transcription ĥ_i. Then, the pair of text (ĥ_i, c_i) is the input to the CLM. Note that ĥ_i could be erroneous compared to h_i. In this situation, the performance of the CLM will decrease, because the CLM is trained using (h_i, c_i), and ĥ_i is likely to follow different patterns from h_i.
2.2 Contextual Language Models The CLM takes text as input, i.e., we need to convert the sequence of words into vectors. There are several methods of learning word representations from a large amount of un-annotated text, such as Word2vec (Mikolov et al. [5]), ELMo (Peters et al. [8]), and CoVe (McCann et al. [4]). In the proposed approach, we build the CLM using BERT (Devlin et al. [1]) and BioBERT (Lee et al. [3]), because BERT achieves state-of-the-art performance for most NLP tasks and BioBERT is a model (based on BERT) specifically for the biomedical field. They share the same structure of bidirectional Transformer encoder (Vaswani et al. [12]), but are trained using different corpora. The model architecture is a stack of 6 identical layers and each layer has two sub-layers. One is a multi-head self-attention mechanism, the other is a simple fully connected feed-forward network. There are also residual connection and layer normalization steps after each sub-layer. The BERT model is pre-trained using two unsupervised tasks, i.e., Masked Language Model (MLM) and Next Sentence Prediction (NSP), using English Wikipedia of 2.5 billion words and BooksCorpus of 0.8 billion words. BioBERT is further trained on biomedical text, i.e., PubMed abstracts of 4.5 billion words and PMC full-text articles of 13.5 billion words, both of which contain a large number of domain-specific terms. The BERT and BioBERT models are initialized using the weights after pre-training. The input format to the CLM is consistent with BERT and BioBERT. The WordPiece embedding (Wu et al. [13]) is used to deal with the out-of-vocabulary problem. Two special symbols are added to each word sequence, i.e., the CLS and SEP token. The CLS is added to the head, and the final corresponding hidden state is the whole sequence representation. The SEP token is added to the end of each sentence, in order to separate the sentences. We directly fine-tune the BERT model as the baseline. A linear head is used to map the sentence representation to classification label. The shortcoming here is that it does not take advantage of all the information output from BERT. We could otherwise make use of all the output word representations. Therefore, we propose the model shown in Fig. 2. Besides using sentence embedding, we also take the word embedding after max-pooling and average-pooling as feature outputs from BERT. Furthermore, a Bi-LSTM layer could be added in order to further extract features from the word embedding.
Fig. 2 BERT-based contextual language model: Bi-LSTM can be added depending on the configuration
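To make the model concrete, the sketch below shows one way the described architecture could be assembled with the Hugging Face transformers library: the [CLS] sentence embedding is concatenated with max- and average-pooled token embeddings (optionally passed through a Bi-LSTM first) and fed to a linear classifier. The checkpoint names, dimensions, and example inputs are assumptions for illustration, not the authors' exact configuration.

```python
# Sketch of a BERT-based CLM head: [CLS] embedding + max/average pooled token
# embeddings (optionally after a Bi-LSTM) -> linear classifier (valid/invalid).
# Checkpoint names and toy inputs are assumptions; pooling ignores padding for brevity.
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class ContextualLanguageModel(nn.Module):
    def __init__(self, encoder_name="bert-base-uncased", use_bilstm=False):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden = self.encoder.config.hidden_size
        self.use_bilstm = use_bilstm
        if use_bilstm:
            self.bilstm = nn.LSTM(hidden, hidden // 2, batch_first=True,
                                  bidirectional=True)
        self.classifier = nn.Linear(hidden * 3, 2)

    def forward(self, input_ids, attention_mask):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        tokens = out.last_hidden_state                 # (batch, seq, hidden)
        if self.use_bilstm:
            tokens, _ = self.bilstm(tokens)
        cls = tokens[:, 0]                             # sentence representation
        max_pool, _ = tokens.max(dim=1)
        avg_pool = tokens.mean(dim=1)
        return self.classifier(torch.cat([cls, max_pool, avg_pool], dim=-1))

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tok("Aspirin eighty one milligrams daily", "Diagnosis: Cholecystitis",
          return_tensors="pt", truncation=True)
logits = ContextualLanguageModel()(enc["input_ids"], enc["attention_mask"])
prob_valid = torch.softmax(logits, dim=-1)[0, 1]       # accept if above 0.5
```

A BioBERT-based variant would simply swap the encoder checkpoint; the classification head stays the same.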
3 Experimentation We perform experiments using two types of data: speech and text. When the data type is text, the medical prescription is input directly to the CLM. When the input is speech, the experimental framework is shown in Fig. 3. First, we input the utterance to the ASR (which is constructed using a pre-trained model for normal language such as ASPIRE, and medical corpora, which provide domain-specific information, such as the medical dictionary and the medical language model), and obtain the decoding results. Then the CLM will determine whether the medical input is correct or not based on the contextual information.
3.1 Dataset Generation We use the dataset of the 2019 National NLP Clinical Challenges (n2c2) on Medication, provided by the i2b2 Center. The original set includes about 2,000 patients' hospitalization records, which provide information such as diagnosis, illness history, laboratory data, hospital course, and discharge medications. We extract the
Fig. 3 Experimental framework for speech input
contents of diagnosis and discharge medications in each record, and divide the whole paragraph of medications into different items. Each item of medication has its corresponding diagnosis, and we believe they are highly correlated. The number of this kind of correlated pairs is 8,621. Then we further clean the pairs by removing serial numbers and punctuation, deleting out-of-domain items, handling special symbols, and replacing abbreviations with full names. The final number of pairs is then 6,901. For a binary classification task, we not only need correlated pairs, but also uncorrelated pairs. Therefore, we create a dataset of uncorrelated pairs by randomly combining diagnoses and medications into different pairs. Many medications can be used for different disorders; therefore, we design an algorithm to calculate a distance between diagnoses. If the distance is greater than a threshold, we just discard this pair. The number of generated uncorrelated pairs is 70,000. To balance the data with correlated and uncorrelated items, we further duplicate the correlated pairs 10 times to obtain 69,010 pairs. The data for the CLM is prepared by splitting the entire set into training, validation, and testing sets in the proportion of 6:2:2. The correlated pairs have positive labels and uncorrelated pairs have negative labels. The dataset only includes text data; the speech input is produced using Google's Text-to-Speech (TTS) API. The medical inputs H in the test set are converted from text to speech and then used as input to the ASR system for decoding.
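A rough sketch of this pair-construction and splitting procedure is given below. The diagnosis-distance function, its threshold, the direction of the filter, and the toy data are illustrative assumptions; the paper's own distance algorithm is not reproduced here.

```python
# Sketch: building uncorrelated (diagnosis, medication) pairs by random
# recombination with a diagnosis-distance filter, plus a 6:2:2 split.
# The distance metric and threshold are placeholders for illustration.
import random

def diagnosis_distance(d1, d2):
    """Toy distance: 1 minus the Jaccard overlap of the diagnosis words."""
    a, b = set(d1.lower().split()), set(d2.lower().split())
    return 1.0 - len(a & b) / max(len(a | b), 1)

def make_uncorrelated_pairs(correlated, n_pairs, min_distance=0.8):
    """correlated: list of (diagnosis, medication) pairs taken from the records."""
    pairs = []
    while len(pairs) < n_pairs:
        (diag_a, _), (diag_b, med_b) = random.sample(correlated, 2)
        # only pair a medication with a clearly different diagnosis
        if diagnosis_distance(diag_a, diag_b) >= min_distance:
            pairs.append((diag_a, med_b))
    return pairs

def split_622(items, seed=0):
    items = list(items)
    random.Random(seed).shuffle(items)
    n = len(items)
    return items[:int(0.6 * n)], items[int(0.6 * n):int(0.8 * n)], items[int(0.8 * n):]

correlated = [("Cholecystitis", "Aspirin eighty one milligrams daily"),
              ("Peripheral vascular disease", "Lopressor seventy five milligrams orally"),
              ("Hypertension", "Lisinopril five milligrams once a day")]
print(make_uncorrelated_pairs(correlated, n_pairs=2))
```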
3.2 ASR Implementation The ASR system is implemented using the Kaldi Speech Recognition Toolkit (Povey et al. [9]). It is hard to train a good ASR model from scratch, because large amounts of speech and text data for the specific domain would be required. Instead, we adapt the ASPIRE model to the medical domain by simply combining the ASPIRE language model and the medical language model trained from our i2b2 dataset. The decoding utility in Kaldi used in this work is "online2-wav-nnet3-latgen-faster", which provides fast and accurate decoding.
Table 2 Performance of different models for text input

Model | Accuracy | Precision | Recall | F1
BERT | 0.9625 | 0.9354 | 0.9928 | 0.9633
BERT_mlp | 0.9647 | 0.9404 | 0.9916 | 0.9653
CLM | 0.9663 | 0.9425 | 0.9924 | 0.9668
CLM_lstm | 0.9645 | 0.9395 | 0.9922 | 0.9651
CLM_bio | 0.9633 | 0.9374 | 0.9921 | 0.9640
CLM_biolstm | 0.9653 | 0.9397 | 0.9935 | 0.9659
3.3 Results and Analysis We conduct experiments on the following models: BERT, BERT with a multilayer perceptron (MLP) head (BERT_mlp), BERT-based CLM (CLM), BERT-based CLM with Bi-LSTM (CLM_lstm), BioBERT-based CLM (CLM_bio), and BioBERT-based CLM with Bi-LSTM (CLM_biolstm). All models are initialized the same way and are trained to their best states. For a binary classification task, the metrics are accuracy, precision, recall, and F1 score. We take 0.5 as the threshold. For text input, the results on the test set are shown in Table 2. These results show that the BERT-based CLM achieves the best performance with 96.63% accuracy of the test set, and 0.4% absolute increment compared to the baseline. We have a large test set of nearly 30 thousand, so about a thousand wrong predictions made by BERT become correct using our best model. We further explore the performance of different models for speech input. The input to the CLM becomes the hypotheses output from the ASR. In this case, we need to note that the ASR sometimes makes mistakes, and in our experiments, the word error rate (WER) is 28.69%. Therefore, the medical inputs, which at first are correlated with the context, might become uncorrelated after ASR decoding, and the labels of the test set have to be modified. We determine the labels by comparing the text before and after ASR. If some essential entities, such as the medication name, dosage, or usage, are incorrect, we change the label from "1" to "0". In the test set, we have 13,755 correlated pairs at the beginning, and 10,092 correlated pairs left after processing. The threshold is still 0.5. The results are shown in Table 3. From the results we see that when there is an ASR system involved, the BioBERT-based CLM performs the best with an accuracy of 79.55% and a 2.96% absolute increment compared to the baseline. The erroneous transcriptions by ASR negatively impact the models, but applying BioBERT can minimize these impacts. Moreover, adding the Bi-LSTM layer and making the network deeper does not improve the performance.
Table 3 Performance of different models for speech input

Model | Accuracy | Precision | Recall | F1
BERT | 0.7659 | 0.7749 | 0.5004 | 0.6081
BERT_mlp | 0.7724 | 0.7743 | 0.5267 | 0.6269
CLM | 0.7759 | 0.7759 | 0.5380 | 0.6355
CLM_lstm | 0.7579 | 0.7937 | 0.4503 | 0.5746
CLM_bio | 0.7955 | 0.7879 | 0.5976 | 0.6797
CLM_biolstm | 0.7508 | 0.7645 | 0.4530 | 0.5689
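The relabelling step described above can be approximated as follows; the essential-entity check is simplified to string containment and the WER computation uses the jiwer package, so this is only an illustration of the procedure, not the authors' exact rule.

```python
# Sketch: flip a pair's label to 0 when an essential entity (medication name,
# dosage, usage) is lost in the ASR hypothesis; WER is computed with jiwer.
# The reference/hypothesis strings and entity list are toy examples.
from jiwer import wer

def relabel(hypothesis, label, essential_entities):
    if label == 1 and any(e.lower() not in hypothesis.lower()
                          for e in essential_entities):
        label = 0          # the transcription no longer matches the context
    return label

ref = "aspirin eighty one milligrams daily"
hyp = "aspirin eighty one milligrams delay"
print("WER:", wer(ref, hyp))
print("new label:", relabel(hyp, 1, ["aspirin", "eighty one milligrams", "daily"]))
```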
4 Conclusions This paper proposes a solution for detecting prescription errors with the help of contextual language models (CLM). These models are effective in dealing with text data, and are able to make validity predictions based on the correlations. We achieve the best accuracy of 96.63% using the BERT-based CLM when the physicians type prescription data directly into the EHR, and 79.55% using the BioBERT-based CLM when the physicians dictate the prescriptions using an ASR. Future work might include integrating more diverse contextual knowledge, e.g., in addition to patient data, we can also consider physician preferences, databases that describe adverse reactions, clinical workflows and recommender outputs, etc. Acknowledgements We acknowledge Dr. Yuan Gong for giving suggestions for this project.
References 1. Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2019). Bert: Pre-training of deep bidirectional transformers for language understanding. In NAACL. 2. Latif, S., Qadir, J., Qayyum, A., Usama, M., & Younis, S. (2021). Speech technology for healthcare: Opportunities, challenges, and state of the art. IEEE Reviews in Biomedical Engineering, 14, 342–356. 3. Lee, J., Yoon, W., Kim, S., Kim, D., Kim, S., So, C. H., & Kang, J. (2020). Biobert: A pretrained biomedical language representation model for biomedical text mining. Bioinformatics, 36(4), 1234–1240. 4. McCann, B., Bradbury, J., Xiong, C., & Socher, R. (2017). Learned in translation: Contextualized word vectors. In Proceedings of the 31st International Conference on Neural Information Processing Systems, NIPS'17 (pp. 6297–6308). 5. Mikolov, T., Sutskever, I., Chen, K., Corrado, G. S., & Dean, J. (2013). Distributed representations of words and phrases and their compositionality. In Advances in Neural Information Processing Systems (Vol. 26). 6. Mulac, A., Taxis, K., Hagesaether, E., & Gerd Granas, A. (2021). Severe and fatal medication errors in hospitals: Findings from the Norwegian incident reporting system. European Journal of Hospital Pharmacy, 28(Suppl 2), s56–s61. https://doi.org/10.1136/ejhpharm-2020-002298. 7. Oura, P. (2021). Medical adverse events in the us 2018 mortality data. Preventive Medicine Reports, 24, 101574.
8. Peters, M. E., Neumann, M., Iyyer, M., Gardner, M., Clark, C., Lee, K., & Zettlemoyer, L. (2018). Deep contextualized word representations. In Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Vol. 1 (Long Papers), pp. 2227–2237), New Orleans, Louisiana. https://doi.org/ 10.18653/v1/N18-1202. 9. Povey, D., Ghoshal, A., Boulianne, G., Burget, L., Glembek, O., Goel, N., Hannemann, M., Motlicek, P., Qian, Y., Schwarz, P., Silovsky, J., Stemmer, G., & Vesely, K. (2011). The kaldi speech recognition toolkit. 10. Rodziewicz, T. L., Houseman, B., & Hipskind, J. E. (2020). Medical error prevention. Treasure Island (FL): StatPearls Publishing. 11. Tariq, R. A., Vashisht, R., Sinha, A., & Scherbak, Y. (2021). Medication dispensing errors and prevention. Treasure Island (FL): StatPearls Publishing. 12. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). Attention is all you need. In Advances in Neural Information Processing Systems (Vol. 30). 13. Wu, Y., Schuster, M., Chen, Z., et al. (2016). Google’s neural machine translation system: Bridging the gap between human and machine translation. arXiv:1609.08144.
Latent Representation Weights Learning of the Indefinite Length of Views for Conception Diagnosis Bo Li, Mengze Sun, Yuan Yu, Yuanyuan Zhao, Zhongliang Xiang, and Zhiyong An
Abstract Deep learning has great prevalence in various medical diagnosis tasks. Existing methods can tackle the issue of multiviews very well. However, these methods cannot process indefinite lengths of multiviews, especially with a "dimension gap" between them, such as blood flow ultrasound images. In this work, we propose Latent Representation Weight Learning (LRWL) to learn the latent representative weight of each image or view and then integrate the views with the weights and the diagnostic indexes as part of the input data to DL to predict successful conception. This method can describe the role of each view accurately. We perform thorough experiments on a real reproduction dataset to evaluate LRWL. The results show that our proposed method achieves top performance, with higher accuracy and good convergence. Keywords Multiviews · Dimension gap · Latent representation
1 Introduction Artificial intelligence (AI), especially machine learning (ML), has grown rapidly in recent years and has been widely and successfully used in clinical practice. The past few years have witnessed AI-enabled medical imaging in the spotlight. Applications of deep learning (DL) to medical image analysis first started to appear at workshops
and conferences and then in journals. The number of papers has grown rapidly in recent years [9, 17, 24]. Researchers have extended DL frameworks, such as Convolutional Neural Networks (CNN) [1], ResNet [8], U-Net [42], and GANs [6, 26, 30], and have widely applied them in the medical image processing field. Furthermore, most researchers concentrate on medical image classification [28, 29], segmentation [33, 40, 41], reconstruction [14, 35], synthesis [7, 27], computer-assisted interventions, and computed tomography [4, 16]. However, there are few studies on applying machine learning to infertility ultrasonography. Doppler ultrasound analysis has been used by gynecologists primarily to determine blood flow and in obstetrics to examine the relationship of the blood flow in the uterine and umbilical arteries with adverse fetal outcomes [11, 13, 17]. This process represents a continuum from preovulation to development of the placental circulation, and it is subject to many of the same physiological and hormonal influences. Hence, Doppler ultrasound analysis is very important for successful conception. However, using uterine blood flow ultrasonography to predict conception with machine learning methods has been less extensively studied. The task of multi-view learning is how to effectively explore the consistency and complementary information of different views. Plenty of research works focus on multi-view learning and have achieved great progress. Some of the most representative methods are uncertainty-aware multi-view representation learning [12], trusted multi-view classification [15], deep partial multi-view learning [37] (Zhang et al. [38]), and their related works [10, 22, 36, 38]. Nevertheless, these works do not consider indefinite-length multiviews with a large dimension gap caused by the doctors. In treatment, doctors often make a decision according to the physical status or physical examination indicators and blood flow ultrasound images (BFUIs). There is a significant difference between BFUIs and other medical images, namely the indefinite number of images or views created by the doctors. However, there are few studies addressing this issue. When doctors perform a blood flow ultrasound examination of patients, they may produce different numbers of Doppler ultrasound images for two reasons: (1) some images are not clear, and (2) the doctors want to collect more views to provide more information. For example, case A may have 1 BFUI with a dimension of 1 × 3 × 220 × 340 while case B has 10 BFUIs with a dimension of 10 × 3 × 220 × 340. Therefore, we cannot directly apply DL to analyze BFUIs because of this gap. Moreover, we cannot neglect the correlations of the images of the same case. The key issue is how to extract the correlation of BFUIs, combine the images into the same number of channels, and then exploit high-performance DL. Furthermore, we also need to integrate the examination indicators and combined images (shown in Fig. 1) to predict the future conception rate. In this work, the main contributions are summarized as follows: • We propose Latent Representation Weight Learning (LRWL), which learns the latent representative weight of each image or view and can track a doctor's psychological changes while capturing images.
• We propose Preweighted Convolution Filtering (PCF), which uses convolution kernels (32, 64, and 128 kernels) generated from a set of distributions to make the number of channels the same. It assumes that the views follow some distributions. • We devise a collaborative learning framework which integrates indefinite-length multiviews of Doppler ultrasound images with the diagnostic indexes as the input to DL for prediction. We also conduct experiments on a real dataset from a hospital to validate the effectiveness of the proposed method. In addition, insightful discussion is provided to explain the scarcity of such datasets. To the best of our knowledge, we are the first to concentrate on and manipulate indefinite-length multiviews with a large dimension gap created by doctors. We explore the applications of these methods to Doppler ultrasound analysis. The experimental results show that our proposed method is able to achieve the top performance.
2 Integration and Methods In this section, we introduce our framework and its detailed implementation. First, we provide an overview of the architecture and the input representations. Then, we introduce the proposed method—LRWL—and its important innovation. Finally, for comparison, we introduce another proposed method—PCF—which can also synthesize the indefinite length channels of Doppler ultrasounds. In general, the target of this work is to integrate indefinite length multiviews for Doppler ultrasound images with the diagnostic indexes as the input data to DL to predict successful conception. The architecture of this work is shown in Fig. 1. Note that the Doppler ultrasound images need to be clipped automatically by the code since there are some background areas.
Fig. 1 The architecture of this work. DU: Doppler ultrasound, DI: Diagnostic indexes, RC: Region clipping, VC: Vectorization, OP: Optimization, Dist: Distribution, CO: Convolution operation, UC: Unique channels, and DL: Deep learning
2.1 Latent Representation Weight Learning We first propose the LRWL method for indefinite length multiviews. Given n cases, each with a different number of views $x_i^1, x_i^2, \ldots, x_i^{V_i}$, $i = 1, 2, \ldots, n$, LRWL aims to assign different weights to each view in the cases. The views can be linearly recovered with their respective methods, as shown in Fig. 2 and discussed in Han et al. [15], Zhang et al. [36]. However, those works often utilize methods to interpolate missing views. They do not consider the indefinite length multiviews caused by human operations, which is not a missing data problem. First, in this work, the information from different indefinite length views should be encoded into the latent representation. Then, the learned latent representation should meet the target task (conception rate prediction) and the weight of each view should be learned. Specifically, considering the vth view $X_i^v$ in the ith case, which can be represented by latent representation $H^v$ and noise $e_i^v$, we have

$$\langle X_i^v, X_i^v \rangle = \langle P^v, H^v \rangle + e_i^v, \tag{1}$$

where $\langle \cdot, \cdot \rangle$ is the tensor inner product, $e_i^v$ is the noise of the ith case in the vth view, $X_i^v \in R^{c \times k \times d}$, $P^v \in R^{c \times k \times s}$, and $H^v \in R^{c \times k \times s}$. For all the views in the ith case, we have
Fig. 2 Doppler ultrasound. The cases have indefinite length views. For example, case 1 has five views, case 2 has two views, case 3 and case 4 have three views, and case 5 has four views
$$\begin{bmatrix} \langle X_i^1, X_i^1 \rangle \\ \langle X_i^2, X_i^2 \rangle \\ \vdots \\ \langle X_i^{V_i}, X_i^{V_i} \rangle \end{bmatrix} = \begin{bmatrix} \langle P^1, H^1 \rangle \\ \langle P^2, H^2 \rangle \\ \vdots \\ \langle P^{V_i}, H^{V_i} \rangle \end{bmatrix} + \begin{bmatrix} e_i^1 \\ e_i^2 \\ \vdots \\ e_i^{V_i} \end{bmatrix}. \tag{2}$$
For all the n cases, the objective function is

$$\min_{P,H} L_X(X, P, H) = \min_{P,H}\Big( \sum \big(\langle X, X \rangle - \langle P, H \rangle\big)^2 + \mu_1 \|P\|_* \Big), \tag{3}$$
where X are the views, and P is the reconstruction matrix variable aligned to the views of the corresponding case. H is the latent representation of X. Note that the views in the same position in different cases share common P and H, so each case is not treated individually. $L_X(\cdot, \cdot)$ denotes the reconstruction of X, and $\mu_1 \|P\|_*$ is the regularization term. The reconstruction matrix P is constrained to avoid all P and X being pushed arbitrarily close to X. Second, based on the representation of X, we need to consider the correlation of different channels in a tensor before using the latent representation H. Therefore, we exploit a mapping function $A: R^{c \times k \times s} \to R^{c \times k}$ to reduce the dimensions of H and maintain the correlation. For all cases, the objective function is

$$\min_{Q,M} L_H(H, Q, M) = \min_{Q,M}\Big( \sum \big(\langle H^A, H^A \rangle - \langle Q, M \rangle\big)^2 + \mu_2 \|Q\|_* \Big). \tag{4}$$

As discussed in Eq. (4), we do not use the case index i, since all the cases share H. M is the latent representation of H. Third, for the task-oriented goal, the objective function is

$$\min_{U,\alpha} L_M(Y, \alpha, U, M) = COST_{softmax}(Y, \alpha U M) + \phi_1(U), \quad \text{s.t. } \sum_i \alpha_i = 1, \tag{5}$$

where $\phi_1(U) = \mu_3 \|UU^T - E\|_{2,1}$. Therefore, we have the objective function of LRWL

$$\min_{P,H,Q,M,U,\alpha} L_X(X, P, H) + \lambda_1 L_H(H, Q, M) + \lambda_2 L_M(Y, \alpha, U, M) + \mu_1 \|P\|_* + \mu_2 \|Q\|_* + \phi_1(U). \tag{6}$$
In this model, parameters $\lambda_1 > 0$, $\lambda_2 > 0$ and $\lambda_3 > 0$ balance the reconstruction errors and regularizations on latent representations. In summary, LRWL can yield the learned weight $\alpha_v$ for the vth view. Then, we can use the weights to integrate the original indefinite length views in the ith case, $X_{F,i}$, into fixed length channels as described in Eq. (7), which can be used as training or testing samples.

$$X_{F,i} = \sum_{v=1}^{V_i} \alpha_v X_i^v. \tag{7}$$
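As a concrete illustration of Eq. (7), the sketch below collapses a case's variable number of views into a single fixed-size tensor using view weights α. The weights and view count are placeholders; in LRWL they come from the optimization described next.

```python
# Sketch of Eq. (7): collapsing a case's indefinite number of views into one
# fixed-size tensor using the learned view weights alpha. The weights and view
# count below are illustrative placeholders.
import numpy as np

def integrate_views(views, alpha):
    """views: array of shape (V_i, 3, 220, 340); alpha: length-V_i weights."""
    alpha = np.asarray(alpha, dtype=np.float64)
    assert len(alpha) == len(views)
    return np.tensordot(alpha, views, axes=1)   # sum_v alpha_v * X_i^v

case_views = np.random.rand(5, 3, 220, 340)     # a case with five views
alpha = np.array([0.3, 0.25, 0.2, 0.15, 0.1])   # learned weights, summing to 1
x_fixed = integrate_views(case_views, alpha)    # shape (3, 220, 340)
```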
2.2 Optimization Since the objective function in Eq. (6) is not jointly convex with respect to the variables P, H, Q, M, U and α, Eq. (6) gradually learns the latent representations from the original indefinite length multiviews, finds the meaningful matrix variables, and meets the task-oriented goal. Thus, Eqs. (3), (5) and (6) need to be optimized [21, 34]. Considering that the variables can be solved by fixing the others in each equation, the augmented Lagrange multiplier (ALM) with alternating direction minimizing (ADM) strategy [26] is an effective solver for this optimization problem. To adopt the ADM strategy to our problem, we need to make our objective function separable. To adopt the strategy in Eq. (3), an auxiliary variable L is introduced. Then, Eq. (3) becomes

$$\min_{P,H} L_X(X, P, H, L_X) = \min_{P,H}\Big( \sum \big(\langle X, X \rangle - \langle P, H \rangle\big)^2 + \mu \|L_X\|_* \Big), \quad \text{s.t. } P = L_X, \tag{8}$$

and let $L_X(X, P, H, L_X, W)$ denote the augmented Lagrange function [23]:

$$L_X(X, P, H, L_X, W) = \tfrac{1}{2}\big(\langle X, X \rangle - \langle P, H \rangle\big)^2 + \mu \|L_X\|_* - \langle W, P - L_X \rangle + \tfrac{\beta_1}{2}\|P - L_X\|_F^2, \tag{9}$$

where $W \in R^{c \times k \times s}$ is the Lagrange multiplier, and $\beta_1$ is the penalty parameter. Hence, in the ALM method with the ADM strategy, we have the following solvers [26]:

$$L_X^{k+1} = \min_{L} L_X(X, P^k, H, L_X^k, W_X^k), \tag{10}$$

$$P^{k+1} = \min_{P} L_X(X, P, H, L_X^{k+1}, W_X^k), \tag{11}$$

$$W_X^{k+1} = W_X^k - \tau^{k+1}\beta_1\big(P^{k+1} - L_X^{k+1}\big). \tag{12}$$

The optimization of L and P can be separated according to the ADM strategy. For the optimization subproblem L,
$$\begin{aligned} L_X^{k+1} &= \arg\min_{L}\Big\{ \|L_X\|_* - \langle W_X^k, P^k - L_X^k \rangle + \tfrac{\beta_1}{2}\|P^k - L_X^k\|_F^2 \Big\} \\ &= \arg\min_{L}\Big\{ \|L_X\|_* + \tfrac{\beta_1}{2}\Big\|P^k - L_X - \tfrac{W_X^k}{\beta_1}\Big\|_F^2 \Big\}. \end{aligned} \tag{13}$$
Then, we have

$$L_X^{k+1} = U\,\mathrm{Diag}\big(\mathrm{shrink}_{2,1}(\sigma, \tfrac{1}{\beta_1})\big)\,V^T, \tag{14}$$

where $\mathrm{shrink}_{2,1}(\sigma, \tfrac{1}{\beta_1}) = \mathrm{sign}(\sigma)\max\{0, |\sigma| - \tfrac{1}{\beta_1}\}$, and $P^k - \tfrac{W_X^k}{\beta_1} = U \Sigma V^T$ with $\Sigma = \mathrm{Diag}(\sigma_1, \ldots, \sigma_m)$.

For the optimization subproblem P, let $\mathrm{vec}(\cdot)$ denote the vectorization of a matrix, $X_{all} = (\langle X_1, X_1 \rangle, \ldots, \langle X_n, X_n \rangle)^T$ and $H_{all} = (\mathrm{vec}(H), \ldots, \mathrm{vec}(H))^T$. Then, Eq. (11) becomes

$$P^{k+1} = \min_{P}\; \tfrac{1}{2}\big\|X_{all} - H_{all}\,\mathrm{vec}(P^{k+1})^T\big\|_2^2 + \tfrac{\beta_1}{2}\Big\|P^{k+1} - L_X^{k+1} - \tfrac{W_X^k}{\beta_1}\Big\|_F^2. \tag{15}$$

Then, we have

$$\begin{aligned} 0 &= (H_{all})^T\big(H_{all}\,\mathrm{vec}(P^{k+1})^T - X_{all}\big) + \beta_1\,\mathrm{vec}(P^{k+1}) - \beta_1\,\mathrm{vec}(L_X^{k+1}) - \mathrm{vec}(W_X^k) \\ &= H_{all}^T H_{all}\,\mathrm{vec}(P^{k+1})^T - H_{all}^T X_{all} + \beta_1\,\mathrm{vec}(P^{k+1})^T - \beta_1\,\mathrm{vec}(L_X^{k+1})^T - \mathrm{vec}(W_X^k)^T \\ &= (H_{all}^T H_{all} + I)\,\mathrm{vec}(P^{k+1})^T - H_{all}^T X_{all} - \beta_1\,\mathrm{vec}(L_X^{k+1})^T - \mathrm{vec}(W_X^k)^T. \end{aligned} \tag{16}$$

Hence

$$\mathrm{vec}(P^{k+1}) = (H_{all}^T H_{all} + I)^{-1}\big[-H_{all}^T X_{all} - \beta_1\,\mathrm{vec}(L_X^{k+1})^T - \mathrm{vec}(W_X^k)^T\big]. \tag{17}$$
If $(H_{all}^T H_{all} + I)$ is nonsingular, we can solve the equation with Cholesky decomposition [4] or the conjugate gradient method [19]. Then, we can obtain the latent representation H with L, P, and W fixed:

$$H_{all}^{k+1} = H_{all}^k + \beta_1\big(\mathrm{tr}(X^T X)_{all} - \mathrm{tr}(P^T H^k)_{all}\big)P_{all}. \tag{18}$$

In a similar way, in Eq. (10), we can have

$$L_H^{k+1} = U\,\mathrm{Diag}\big(\mathrm{shrink}_{2,1}(\sigma, \tfrac{1}{\beta_1})\big)V^T, \quad \text{s.t. } Q^k - \tfrac{W_H^k}{\beta_1} = U \Sigma V^T, \tag{19}$$
$$\mathrm{vec}(Q^{k+1}) = (M_{all}^T M_{all} + I)^{-1}\big[-M_{all}^T H^A - \beta_1\,\mathrm{vec}(L_H^{k+1})^T - \mathrm{vec}(W_H^k)^T\big], \tag{20}$$

$$W_H^{k+1} = W_H^k - \tau^{k+1}\beta_1\big(Q^{k+1} - L_H^{k+1}\big), \tag{21}$$

$$M_{all}^{k+1} = M_{all}^k + \beta_1\big(\mathrm{tr}(H^{A\,T} H^A)_{all} - \mathrm{tr}(Q^T M^k)_{all}\big)Q_{all}. \tag{22}$$
For the task-oriented goal, in Eq. (6), we cannot obtain similar results due to the different forms of Eqs. (4) and (5). The SQP [3] algorithm is employed to solve the problem. Let $u = [\alpha, \mathrm{vec}(U)]$. Then, the optimization of Eq. (6) can be solved as

$$u^{k+1} = \min_{u \in R^n}\; \tfrac{1}{2} d^T L_{uu}(u^k, \lambda, \mu)\, d + \nabla f(u^k)^T d, \quad \text{s.t. } h(u) + h'(u)d = 0,\; g(u) + g'(u)d \le 0, \tag{23}$$

$$L_{uu}(u^k, \lambda, \mu) = L_M(Y, u, M) + \lambda^T h(u) + \mu^T g(u), \tag{24}$$

where $L_{uu}(u^k, \lambda, \mu)$ is the Lagrange function of $L_M(Y, u, M)$, $d = u - u^k$, $h(u) = \|UU^T - E\|_{2,1}$ and $g(u) = \sum_i |\alpha_i| - 1$. Then, this can be solved as a quadratic optimization problem in SQP. The optimization of latent representation weight learning is shown in Algorithm 1. The proof of the convergence of the optimization method is in the supplementary material.

Algorithm 1 Optimization algorithm for LRWL
Input: indefinite length multiviews X_1, X_2, ..., X_n
Hyperparameters: λ_1, λ_2, μ_1, μ_2 and β_1
Initialize: L_X, L_H, W_X, W_H, P, H, Q, M, α, U
Output: α
while not converged do
  Update variable L_X according to Eq. (14)
  Update variable P according to Eq. (17)
  Update variable W_X according to Eq. (12)
  Update variable H according to Eq. (18)
end while
Input: H
while not converged do
  Update variable L_H according to Eq. (19)
  Update variable Q according to Eq. (20)
  Update variable W_H according to Eq. (21)
  Update variable M according to Eq. (22)
end while
Input: M
while not converged do
  Update variables α and U according to Eq. (23) and SQP
end while
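One building block of Algorithm 1 is the singular-value shrinkage update of Eqs. (14) and (19). A minimal numpy sketch of that proximal step is given below; the matrices and the value of β_1 are placeholders.

```python
# Sketch of the singular-value shrinkage step used in Eq. (14) / Algorithm 1:
# the nuclear-norm proximal operator applied to P^k - W_X^k / beta_1.
# The matrices below are random placeholders.
import numpy as np

def sv_shrink(A, tau):
    """Soft-threshold the singular values of A by tau."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    s_shrunk = np.maximum(s - tau, 0.0)
    return U @ np.diag(s_shrunk) @ Vt

beta1 = 10.0
P_k = np.random.rand(8, 6)
W_k = np.random.rand(8, 6)
L_next = sv_shrink(P_k - W_k / beta1, 1.0 / beta1)   # update of L_X
```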
Table 1 The distributions used to generate convolution kernels. The kernel size is 3 × 3

Distribution   Density function
Gaussian       $\frac{1}{\sqrt{2\pi}\sigma}\, e^{-\frac{(x-\mu)^2}{2\sigma^2}}$
Beta           $\frac{1}{B(\alpha,\beta)}\, x^{\alpha-1}(1-x)^{\beta-1}$
Dirichlet      $\frac{\Gamma(\sum_k \alpha_k)}{\prod_{k=1}^{m}\Gamma(\alpha_k)}\prod_k \theta_k^{\alpha_k-1}$
Exponential    $\lambda e^{-\lambda x}$
F              $\frac{(n_1/n_2)^{n_1/2}}{B(n_1/2,\, n_2/2)}\, x^{\frac{n_1}{2}-1}\big(1+\frac{n_1}{n_2}x\big)^{-\frac{n_1+n_2}{2}}$
Gamma          $\frac{\lambda^{\alpha} x^{\alpha-1} e^{-\lambda x}}{\Gamma(\alpha)}$
Laplace        $\frac{1}{2b}\, e^{-\frac{|x-\mu|}{b}}$
Logistic       $\frac{e^{-x}}{(1+e^{-x})^2}$
Lognormal      $\frac{1}{\sqrt{2\pi}\, x\sigma}\, e^{-\frac{(\ln x-\mu)^2}{2\sigma^2}}$
Uniform        $\frac{1}{b-a}$
Von Mises      $\frac{e^{K\cos(x-\mu)}}{2\pi I_0(K)}$
Wald           $\sqrt{\frac{\lambda}{2\pi x^3}}\, e^{-\frac{\lambda(x-\mu)^2}{2\mu^2 x}}$
Cauchy         $\frac{1}{\pi}\,\frac{\gamma}{(x-x_0)^2+\gamma^2}$

2.2.1 Preweighted Convolution Filtering
The key idea in this work is to integrate the original indefinite-length views into fixed-length channels. We also propose another approach, preweighted convolution filtering (PCF). This approach applies multiple convolution kernels to the images and then produces training or testing samples. The kernels are generated from the following distributions: Gaussian, beta, Dirichlet, exponential, F, gamma, Laplace, logistic, lognormal, uniform, Von Mises, Wald, and Cauchy [32]. Details are listed in Table 1. These distributions yield many kernels that are convolved with the Doppler ultrasound images to produce a unified number of image channels, which are then used as the input to the DL model.
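To make PCF concrete, the sketch below (not the authors' code; the image, kernel counts, and distribution parameters are dummy stand-ins) draws 3 × 3 kernels from a few of the distributions in Table 1 and convolves them with a single view so that every case produces the same fixed number of channels.

```python
import numpy as np
from scipy.signal import convolve2d

rng = np.random.default_rng(0)

def sample_kernels(n_kernels, sampler):
    # each kernel is a 3x3 array of i.i.d. draws from the chosen distribution
    return [sampler(size=(3, 3)) for _ in range(n_kernels)]

samplers = {
    "gaussian": lambda size: rng.normal(0.0, 1.0, size),
    "beta": lambda size: rng.beta(2.0, 2.0, size),
    "exponential": lambda size: rng.exponential(1.0, size),
    "laplace": lambda size: rng.laplace(0.0, 1.0, size),
}

image = rng.random((220, 340))  # one grey-level plane of a Doppler ultrasound view
channels = []
for name, sampler in samplers.items():
    for kernel in sample_kernels(8, sampler):        # 8 kernels per distribution here
        channels.append(convolve2d(image, kernel, mode="same"))

fixed_input = np.stack(channels)  # shape (32, 220, 340): fixed-length channels for the DL model
print(fixed_input.shape)
```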
3 Results and Discussion We perform experiments on 2D uterine Doppler ultrasound images. The task is to evaluate our proposed methods to predict the probability of conception. We first propose an evaluation method used in our experiments. Considering that previous
works have not addressed the issue of the indefinite length medical images created by doctors, we also propose baseline methods. We compare the results in terms of effectiveness and demonstrate how our method can achieve high performance.
3.1 Data Collection

The real dataset in this work is available from the Reproduction Medical Center at Yantai Yuhuangding Hospital of Qingdao University. It comprises 100 cases, and each case has several Doppler ultrasound images, or views. For example, five cases are shown in Fig. 2. The dimension of each view is 3 × 220 × 340. An overview of the number of views across all cases is given in Fig. 3. The figure shows that there is a large gap in dimensions between different instances: the minimum length is one and the maximum length is ten, so the dimensions vary greatly across cases. If a case has only one view, its dimension is 3 × 220 × 340, whereas a case with ten views has dimension 10 × 3 × 220 × 340. Therefore, we cannot directly apply the dominant DL models for indefinite-length input data, such as LSTM and BERT [43], to these data. The diagnostic indexes are listed in Table 2. These indicators are collected for each patient in addition to the Doppler ultrasound examination. For example, endometrial thickness is closely related to conception. Because the female uterus is where the fertilized egg and fetus grow and develop, the endometrium is equivalent to the soil for fetal development. If the soil is too thin, it is insufficiently fertile, which leads to malnutrition of the fertilized egg implanted in it. It is also possible that the fertilized egg cannot be implanted into such a uterine cavity, which
Fig. 3 The overview of lengths across all the cases. The length represents the number of Doppler ultrasound images for each case - from one to ten images
Table 2 The diagnostic indexes for assisted conception. Each case comprises Doppler ultrasound images and these diagnostic indexes

Diagnostic index           Values
Conception period types    Artificial, natural, down-regulation
Endometrial thickness      Real: 1.0–20.0
Blood flow                 No, yes, grade-1, and grade-2
Days of embryo culture     3–6 days
Estradiol                  Real: 60–2300
Luteinizing hormone        Real: 0.1–95.74
Progesterone               Real: 0.05–60
easily causes an abortion. Note that these indexes need to be normalized because their value ranges differ.
3.2 Experimental Setup

We use ResNet, a deep network that can achieve high performance, as the backbone of the framework for conception prediction. We also modify the input of ResNet so that, in addition to the Doppler ultrasound images, the diagnostic indexes are integrated into the fully connected layer. In contrast to other medical settings, it is more difficult to collect data from infertile patients. Considering the scarcity of cases due to the limited number of patients (the Obstetrics and Gynecology Department and the Reproductive Medicine Department are completely separate departments), we divide the samples into two parts: 70 samples for training and validation and 30 samples for testing [5]. Prediction accuracy is employed as the evaluation metric. A 10-fold cross-validation procedure is then applied to the training and validation set: we train and evaluate each method 10 times, and each time 7 of the 70 cases are held out for validation while the other 63 cases are used for training. In our experiments, 32, 64, and 128 convolution kernels are used in PCF. It is impossible to consider all possible numbers of channels; although 32, 64, 128, 256, and even 512 kernels are commonly used in convolution layers, using more kernels for preconvolution, rather than in backpropagation training, does not significantly enhance the performance.
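A minimal PyTorch sketch of this setup is given below (not the authors' implementation): a ResNet-18 backbone consumes the fused image channels, and its pooled features are concatenated with the normalized diagnostic indexes before the final fully connected classifier. The channel count, index count, and backbone depth are illustrative assumptions.

```python
import torch
import torch.nn as nn
from torchvision import models

class ConceptionNet(nn.Module):
    def __init__(self, n_channels=32, n_indexes=7, n_classes=2):
        super().__init__()
        backbone = models.resnet18(weights=None)
        # adapt the first convolution to the number of fused input channels
        backbone.conv1 = nn.Conv2d(n_channels, 64, kernel_size=7, stride=2,
                                   padding=3, bias=False)
        feat_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()                # keep the pooled image features
        self.backbone = backbone
        # classifier over image features concatenated with diagnostic indexes
        self.classifier = nn.Linear(feat_dim + n_indexes, n_classes)

    def forward(self, images, indexes):
        feats = self.backbone(images)
        return self.classifier(torch.cat([feats, indexes], dim=1))

model = ConceptionNet()
logits = model(torch.randn(4, 32, 220, 340), torch.randn(4, 7))
print(logits.shape)  # torch.Size([4, 2])
```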
3.3 Comparison with the Baseline

We compare our LRWL method with the baseline, PCF, on our dataset. Table 3 lists the prediction accuracy of these methods. The results show that LRWL achieves
Table 3 Comparison of the accuracy between LRWL and the baseline method. In PCF, we employ 13 different distributions to yield 32, 64, and 128 kernels

Method        Number of kernels   Accuracy
Gaussian      32 / 64 / 128       0.40 / 0.43 / 0.53
Beta          32 / 64 / 128       0.43 / 0.57 / 0.53
Dirichlet     32 / 64 / 128       0.53 / 0.57 / 0.57
Exponential   32 / 64 / 128       0.53 / 0.57 / 0.57
F             32 / 64 / 128       0.53 / 0.50 / 0.57
Gamma         32 / 64 / 128       0.60 / 0.57 / 0.53
Laplace       32 / 64 / 128       0.43 / 0.57 / 0.57
Logistic      32 / 64 / 128       0.40 / 0.57 / 0.47
Lognormal     32 / 64 / 128       0.57 / 0.57 / 0.57
Uniform       32 / 64 / 128       0.37 / 0.57 / 0.37
Von Mises     32 / 64 / 128       0.37 / 0.60 / 0.43
Wald          32 / 64 / 128       0.43 / 0.57 / 0.57
Cauchy        32 / 64 / 128       0.57 / 0.77 / 0.57
LRWL          –                   0.83
better performance than the PCF methods. Within PCF, there is no significant difference between the distributions, indicating that the Doppler ultrasound images do not strictly follow the distributions listed in Table 3; therefore, increasing the number of kernel channels does not significantly improve performance. Why does LRWL achieve the top performance? Because the proposed method is based on our in-depth communication with doctors. When doctors perform Doppler ultrasound examinations, they want to capture the images from a hemodynamic map with a color Doppler ultrasound diagnostic instrument. Hence, whether the images satisfy the requirement depends on two factors: the moments at which the images are captured and the doctor's concentration. The process can be described as follows. When a doctor captures an image for the first time, he is most focused, which is why many cases comprise only one image. However, this image may not satisfy the requirement because the hemodynamic map is dynamic. The doctor will then casually capture several more images, which are satisfactory in most cases. In addition, if the doctor finds that the quality of the images is not high enough, he becomes nervous and concentrates on the process, which is why some cases comprise as many as 10 images. As discussed above, the proposed method, LRWL, can track a doctor's psychological changes while capturing images and hence achieves the top performance. We use one real dataset for validation in this work because it satisfies two conditions: (1) the length of views is indefinite, and (2) there is a large dimension gap between different cases. In particular, the views can reflect the doctor's psychological changes during the examination. We investigated not only public medical datasets [2, 18, 20, 25] but also many non-medical public multi-view datasets from recent references [12, 15, 36, 37], such as ADNI (774 subjects from ADNI-1, including 226 normal controls (NC), 362 MCI, and 186 AD subjects), Animal (10,158 images from 50 classes with two types of deep features), Caltech101 (images of 101 object categories), COIL20MV (1440 images from 20 object categories), CUB (different categories of birds), Football (a collection of 248 English Premier League football players and clubs active on Twitter), Hand-written (10 categories from digits '0' to '9', with 200 images in each category and six types of image features), HMDB (a large human motion database collected from various sources), MSRCV1 (30 images for each of 7 classes), ORL (10 images of each of 40 distinct subjects under different conditions), PIE (more than 750,000 images of 337 people recorded in up to four sessions over the span of five months), Politics (a collection of Irish politicians and political organisations), Scene15 (fifteen natural scene categories), UCI-MF (handwritten numerals from a collection of Dutch utility maps), and 3Sources-complete (collected from three online news sources: BBC, Reuters, and Guardian). However, those datasets do not satisfy these conditions. We will seek additional real datasets from other hospitals in the future.
4 Conclusion

In this work, we propose the LRWL method to predict successful conception. As stated, prior multi-view methods do not have an effective way to process indefinite-length multiviews with a large dimension gap, which limits their performance. To address these issues, our method learns the latent representation weight of each image or view effectively, and we then integrate the weights and the diagnostic indexes into a DL model. Experiments on a real reproduction dataset from a hospital show that our proposed method achieves the top performance with higher accuracy.

Acknowledgements We would like to acknowledge the financial support in part by the Shandong Natural Science Foundation (ZR2021MF068, ZR2021MF015, ZR2021MF107, ZR2021QF134), the Shandong Computer Society Provincial Key Laboratory Joint Open Fund (SKLCN-2020-06), and the Wealth Management Characteristic Construction Project of Shandong Technology and Business University (2019ZBKY032).
References 1. Agarwal, A., Goel, A., Singh, R., Vatsa, M., & Ratha, N. K. (2020). Dndnet: Reconfiguring CNN for adversarial robustness. In CVPR Workshop on Fair, Data Efficient and Trusted Computer Vision, 2020 (pp. 103–110). 2. AIMI, S. (2020). A large new cardiac motion video data resource for medical machine learning. https://stanfordaimi.azurewebsites.net/datasets/834e1cd1-92f7-4268-9daa-d359198b310a. 3. Angalaeswari, S., Sanjeevikumar, P., Jamuna, K., & Leonowicz, Z. (2020). Hybrid pipsosqp algorithm for real power loss minimization in radial distribution systems with optimal placement of distributed generation. Sustainability, 12, 5787. 4. Antonsanti, P. L., Benseghir, T., Jugnon, V., Glaunés, J. (2020). Database annotation with few examples: An atlas-based framework using diffeomorphic registration of 3d trees (pp. 160–170). 5. Bennin, K. E., Keung, J., Monden, A., Kamei, Y., & Ubayashi, N. (2016). Investigating the effects of balanced training and testing datasets on effort-aware fault prediction models. In: Computer Software & Applications Conference (pp. 154–163) 6. Cao, B., Zhang, H., Wang, N., Gao, X., & Shen, D. (2020). Auto-gan: Self-supervised collaborative learning for medical image synthesis. Proceedings of the AAAI Conference on Artificial Intelligence, 34(7), 10486–10493. 7. Chang, Q., Qu, H., Zhang, Y., Sabuncu, M., Chen, C., Zhang, T., et al. (2020). Synthetic learning: Learn from distributed asynchronized discriminator GAN without sharing medical image data (pp. 13853–13863). IEEE. 8. Chen, Z., Lin, Z., Wang, P., & Ding, M. (2021). Negative-resnet: Noisy ambulatory electrocardiogram signal classification scheme. Neural Computing and Applications, 10, 1–13. 9. Cinaroglu, I., & Bastanlar, Y. (2021). Training semantic descriptors for image-based localization. In: ECCV 2020 Workshop on Perception for Autonomous Driving. 10. Deschaintre, V., Aittala, M., Durand, F., Drettakis, G., & Bousseau, A. (2019). Flexible svbrdf capture with a multi-image deep network. 11. Dwivedi, , Ganesh, V., Shukla, R. C., Jain, M., & Kumar, I. (2020). Colour doppler evaluation of uterine and ovarian blood flow in patients of polycystic ovarian disease and post-treatment changes. Clinical Radiology, 75(10). 12. Geng, Y., Z H, Zhang, C., & Q H (2021). Uncertainty-aware multi-view representation learning. In: Proceedings of AAAI Conference on Artificial Intelligence (pp. 7545–7553).
13. Gilboy, K. M., Wu, Y., Wood, B. J., Boctor, E. M., & Taylor, R. H. (2020). Dual-robotic ultrasound system for in vivo prostate tomography. In: International Conference on Medical Image Computing and Computer Assisted Intervention (pp. 161–170) 14. Guo, Y., Bi, L., Ahn, E., Feng, D., Wang, Q., & Kim, J. (2020). A spatiotemporal volumetric interpolation network for 4d dynamic medical image. In: IEEE Conference on Computer Vision and Pattern Recognition (pp. 4725–4734). 15. Han, Z., Zhang, C., Fu, H., & Zhou, J. T. (2021). Trusted multi-view classification. 16. He, J., Pan, C., Yang, C., Zhang, M., & Yu, Y. (2020). Learning hybrid representations for automatic 3d vessel centerline extraction. In: International Conference on Medical Image Computing and Computer Assisted Intervention (pp. 24–34). 17. He, X., Wang, S., Chu, X., Shi, S., Tang, J., Liu, X., et al. (2021). Automated model design and benchmarking of 3d deep learning models for covid-19 detection with chest CT scans. In: AAAI Conference on Artificial Intelligence (pp. 4821–4829) 18. ISBI. (2021). Grand-challenges-all challenges. https://link.zhihu.com/?target=http. 19. Jia, C., Zhao, J., Liu, Q., Ma, Y., & Hu, C. (2020). Analysis of influence of wind speed correlation in transmission congestion based on LHS-Cholesky decomposition. In: 2020 12th IEEE PES Asia-Pacific Power and Energy Engineering Conference (APPEEC). 20. Kopans, D., & Moore, R. (2021). University of South Florida digital mammography home page. http://www.eng.usf.edu/cvprg/Mammography/Database.html. 21. Li, Y., Boi, A., Zhang, T., Ji, Y., Harada, T., & Niener, M. (2020). Learning to optimize non-rigid tracking. In: IEEE Conference on Computer Vision and Pattern Recognition (pp. 4909–4917). 22. Liu, Y., Jain, A., Eng, C., Way, D. H., Lee, K., Bui, P., et al. (2019). A deep learning system for differential diagnosis of skin diseases. 23. Liu, Y. N. (2019). The augmented lagrange multiplier method for nonconvex regular matrix regression. M.S. diss.: Beijing Jiaotong University. 24. Manna, S., Bhattacharya, S., & Pal, U. (2021). SSLM: Self-supervised learning for medical diagnosis from MR video. arXiv:2104.10481v1. 25. OASIS. (2020). Classified skin lesions. https://www.isic-archive.com/. 26. Parimala, K., & Channappayya, S. (2019). Quality aware generative adversarial networks. In: IEEE Conference on Neural Information Processing Systems. 27. Peng, C., Lin, W. A., Liao, H., Chellappa, R., & Zhou, S. K. (2020). Saint: Spatially aware interpolation network for medical slice synthesis. In: IEEE Conference on Computer Vision and Pattern Recognition (pp. 7747–7756). 28. Schirmer, M., Venkataraman, A., Rekik, I., Kim, M., & Ai, W. C. (2021). Neuropsychiatric disease classification using functional connectomics - results of the connectomics in neuroimaging transfer learning challenge. Medical Image Analysis, 11, 101972. 29. St Ieler, F., Rabe, F., & Bauer, B. (2021). Towards domain-specific explainable AI: Model interpretation of a skin image classifier using a human approach. In: IEEE Conference on Computer Vision and Pattern Recognition. 30. Wang, C. R., Zhang, F., Yu, Y., & Wang, Y. (2020). Br-GAN: Bilateral residual generating adversarial network for mammogram classification. 31. Wang, W., Yan, S., Mao, L., & Guo, X. (2021). Robust minimum variance beamforming with sidelobe level control using the alternating direction method of multipliers. IEEE Transactions on Aerospace and Electronic Systems, 99, 2514–2518. 32. Woo, S. W. (2021). 
Probability and its distribution in statistics. Design of mechanical systems based on statistics. 33. Xie, X., Chen, J., Li, Y., Shen, L., & Zheng, Y. (2020). Instance-aware self-supervised learning for nuclei segmentation. arXiv:2007.11186. 34. Yang, H., Zhang, Z., Fan, W., & Xiao, F. (2021). Optimal design for demand responsive connector service considering elastic demand. IEEE Transactions on Intelligent Transportation Systems, PP(99), 1–11. 35. Zamir, S. W., Arora, A., Khan, S., Hayat, M., Khan, F. S., Yang, M. H., et al. (2021). Multistage progressive image restoration. In: IEEE Conference on Computer Vision and Pattern Recognition.
36. Zhang, C., Fu, H., Hu, Q., Cao, X., Xie, Y., Tao, D., & Xu, D. (2018). Generalized latent multiview subspace clustering. IEEE Transactions on Pattern Analysis and Machine Intelligence, 42(1), 86–99. 37. Zhang, C., Cui, Y., Han, Z., Zhou, J. T., Fu, H., & Hu, Q. (2020). Deep partial multi-view learning. arXiv:2011.06170. 38. Zhang, C., Fu, H., Wang, J., Li, W., & Hu, Q. (2020). Tensorized multi-view subspace representation learning. International Journal of Computer Vision, 128(8), 2344–2361. 39. Zhang, C., Fu, H., Wang, J., Li, W., & Hu, Q. (2020). Tensorized multi-view subspace representation learning. International Journal of Computer Vision, 9, 2344–2361. 40. Zhao, A., Balakrishnan, G., Durand, F., Guttag, J. V., & Dalca, A. V. (2020). Data augmentation using learned transformations for one-shot medical image segmentation. In: IEEE Conference on Computer Vision and Pattern Recognition (pp. 8543–8553). 41. Zheng, H., Zhang, Y., Yang, L., Wang, C., & Chen, D. Z. (2020). An annotation sparsification strategy for 3d medical image segmentation via representative selection and self-training. Proceedings of the AAAI Conference on Artificial Intelligence, 34(4), 6925–6932. 42. Zheng, J., Liu, X. Y., & Wang, X. (2020). Single image cloud removal using u-net and generative adversarial networks. IEEE Transactions on Geoscience and Remote Sensing, 99, 1–15. 43. Zhu, L., & Yang, Y. (2020). Actbert: Learning global-local video-text representations. In: IEEE Conference on Computer Vision and Pattern Recognition (pp. 8743–8752).
Phenotyping with Positive Unlabelled Learning for Genome-Wide Association Studies Andre Vauvelle, Hamish Tomlinson, Aaron Sim, and Spiros Denaxas
Abstract Identifying phenotypes plays an important role in furthering our understanding of disease biology through practical applications within healthcare and the life sciences. The challenge of dealing with the complexities and noise within electronic health records (EHRs) has motivated applications of machine learning in phenotypic discovery. While recent research has focused on finding predictive subtypes for clinical decision support, here we instead focus on the noise that results in phenotypic misclassification, which can reduce a phenotype's ability to detect associations in genome-wide association studies (GWAS). We show that by combining anchor learning and transformer architectures into our proposed model, AnchorBERT, we are able to detect genomic associations previously found only in large consortium studies with 5× more cases. When reducing the number of controls available by 50%, we find our model is able to maintain 40% more significant genomic associations from the GWAS catalog compared to standard phenotype definitions. Keywords Phenotyping · Machine learning · Semi-supervised · Genetic association studies · Biological discovery
A. Vauvelle (B) · S. Denaxas Institute of Health Informatics, University College London, London, UK e-mail: [email protected] S. Denaxas e-mail: [email protected] H. Tomlinson · A. Sim Benevolent AI, 4-8 Maple St, London, UK e-mail: [email protected] A. Sim e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_9
1 Introduction

As the collection of healthcare data has expanded, traditional definitions of disease have been challenged by large differences in outcomes between patients. Phenotyping refers to the process of defining a clinically relevant set of characteristics, such as exposures and outcomes, for the purpose of patient identification. These characteristics include simple traits, such as eye colour, but also extend to definitions of disease ranging from categories as broad as diseases of the circulatory system down to specific disease subtypes. Identifying these phenotypes plays an important role in furthering our understanding of disease for applications within epidemiological research and drug development (Fig. 1). In the first approach, panels of experts define phenotypes by a series of rules [6]. While these definitions are based on a consensus of domain experts, the scalability of rule-based methods is limited because they are laborious, iterative, and time-consuming processes [2]. As electronic health records grow, there is an opportunity to conduct large-scale analyses to drive our understanding of disease biology; however, this will make expert review infeasible. Thus, we turn to machine learning. Machine learning has previously been used to identify phenotypes from electronic health records, primarily in the context of clinical decision support. For instance, in [20, 31], phenotypes are pragmatically identified to enable particular tasks, such as characterising patients who are likely to require more specialised care or who are at risk of deterioration and/or death. Evaluation of machine-learning-generated phenotypes has focused on predictive measures of patient outcomes and response to treatment. This makes sense in a clinical setting, where phenotypes need to be predictive of the future health state of a patient [17]. In this paper, we focus on the phenotyping task of diagnosis classification in electronic health records for biological discovery. Traditionally, phenotype labels are assigned to patients according to the presence or absence of International Classification of Diseases (ICD) codes. However, this method is likely to mislabel patients due to the large amount of noise and complexity within EHR data. Previous studies have shown that identifying phenotypes using diagnostic codes as a proxy can result in
Fig. 1 Overview of AnchorBERT phenotyping for GWAS. Unlabelled patients (grey dots) are given predicted probabilities of having the anchor variable (yellow plusses) by AnchorBERT. These patients are then used as a continuous trait in linear regression GWAS
poor positive and negative predictive values across diseases and healthcare systems [27]. Machine learning approaches may enable us to make fewer assumptions about the fidelity of individual codes while learning from longitudinal patient histories. Concretely, our goal is to learn to classify patients according to whether or not they have a disease. Within diagnosis classification, we identify two distinct issues: (1) Heterogeneity. Current definitions of disease are too broad, such that multiple distinct phenotypes exist that better describe the presenting patient. In the context of biological discovery, conflicting influences from multiple subtypes can reduce p-values and distort effect sizes in association studies. (2) Phenotypic misclassification. This concerns the possibility that diseases have been either incorrectly identified or missed during clinical observation and data collection. This noise can result in mislabelling of cases and controls, reducing power in association studies. Incorrectly assigning a patient as a case occurs when a patient is misdiagnosed as having the desired phenotype. Conversely, identifying controls by the absence of a diagnosis does not necessarily mean the diagnosis was ruled out; instead, the diagnosis may have been missed. While the first issue has previously been addressed with subtyping [18], we focus on the second issue of phenotypic misclassification, which has received comparatively less attention. We propose a robust metric for analysing phenotyping algorithms in the context of biological discovery. Our approach is to report replicated associations from previous studies in the GWAS catalog [3]. Importantly, genome-wide association studies present a unique means of identifying phenotypes with distinct causal disease biology [5]. GWAS often require high sample sizes, particularly when effect sizes are small, and indeed many GWAS catalog associations have been contributed by large consortia. Alternatively, we attempt to find associations that meet significance thresholds by reducing the noise due to misclassification and thus improving the effective effect size. Evaluating phenotypes in this manner presents a metric more relevant to biological discovery and an independent, robust alternative to outcome-based evaluation. Problem Specification The main aim of this work is to create robust phenotypes for genomic discovery by addressing the issue of phenotypic misclassification. Our methodological contributions combine two areas of research: transformer architectures and anchor variable models. We reduce phenotypic misclassification within EHRs by treating the data as positive-only and unlabelled data. Specifically, we use an anchor variable model to predict the probability of a patient being a case, employing a transformer model to improve the approximation of this probability in a model we call AnchorBERT. Finally, we validate our models against current diagnosis classification methods by reproducing known associations for five different diseases using a repository of validated studies, the GWAS catalog. Overall, our contributions are summarised as:
– We introduce the distinct issue of phenotypic misclassification and present the first model using noisy label learning and state-of-the-art deep learning architectures to improve genomic associations for biological discovery.
– We present a robust, independent validation metric, better suited to biological discovery, based on replicating genomic associations found in the GWAS catalog.
– Our proposed AnchorBERT model outperforms standard phenotype definitions by maintaining more known associations as the number of samples used in GWAS is reduced. We are able to reproduce genetic associations previously found only in large consortium studies.
2 Related Work

Previous works have looked to improve GWAS power for diseases with poor positive predictive value by setting an increased threshold on the number of total diagnosis codes required to classify a patient as a case [9]. Reference [23] shows that setting a threshold to define a case can be avoided entirely by instead modelling the probability of a phenotype with unsupervised clustering (Pheprob). This continuous probability is then used directly in the association study in place of a binary classification, which [23] shows improves the power to detect associations. Pheprob includes the total number of codes as an additional variable as part of a parametric binomial mixture model. The method provides some ability to assign non-zero probabilities to controls but largely ignores the possibility of false-negative labels. Mislabelling controls has a smaller effect on power [11], yet determining which error rate is larger is difficult and largely unknown [32]. Previous studies have considered treating cases as positive-only and control data as unlabelled data [1, 14]. Rather than updating the probability of controls for use in association studies, these studies focus on semi-automated methods for improving phenotype definitions by iteratively updating anchor variables. Previous work on machine learning for structured EHRs [22] suggests that models able to capture non-linear and sequential relationships between past, present, and future events produce better results on predictive tasks. Although our task is ultimately aimed at finding genetic associations, we hypothesize that more predictive representations will allow greater identification of noisy samples, reducing their negative influence on the control cohort and, in turn, providing greater overall associative power. Reference [16] applied autoencoders with anchor learning but did not evaluate the ability to detect associations. Reference [30] compares unsupervised clustering using mixture models and anchor learning methods but does not reproduce genomic associations and relies heavily on clinical notes.
3 Methods

3.1 Problem Formulation

Let $x = \{x_t\}_{t=1}^{T}$ describe the sequential collection of disease codes from a total of $T$ visits in a patient's health record. Each visit $x_t \in \mathcal{X}$ is a multi-hot encoding of the $d$ total diseases, such that $x_{t,j}$ is marked as 1 if the $j$th disease was observed during visit $t$ and 0 otherwise. We let $y \in \mathbb{R}^d$ describe the latent disease state; this describes all diseases occurring during a patient's life, including those that are not observed in $x$. Finally, let $g$ denote the genetic variable of interest measured for a patient, such as a SNP taking a value in $\{0, 1, 2\}$. For each patient, we consider the collection $\{x, y; g\}$. Ideally, we would measure the association between the true, latent disease state $y$ and the genetic variables $g$ by defining a case-control cohort on the presence of the disease state; however, we only observe the diseases $x$. For GWAS, research typically proceeds by first labelling disease cases ($y_j = 1$) if any $x_{t,j} = 1$ for $t \in \{1, \ldots, T\}$, and controls ($y_j = 0$) if $x_{t,j} = 0$ for all $t \in \{1, \ldots, T\}$. From now on, we assume that we are only interested in the latent disease $j = a$ and drop the index from $y_j$, such that $y \in \{0, 1\}$. We propose learning the function $p(y = 1|x) = f(x)$ and using this as a continuous trait in a regression to detect associations with $g$.
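As a concrete illustration of the last step, the sketch below (all data are simulated stand-ins) regresses a continuous phenotype $f(x)$ on a SNP dosage $g$ together with simple covariates using statsmodels; the coefficient and p-value of the SNP term play the role of the association test.

```python
import numpy as np
import statsmodels.api as sm

rng = np.random.default_rng(0)
n = 5000
g = rng.integers(0, 3, size=n).astype(float)      # SNP genotype coded as {0, 1, 2}
age = rng.normal(55, 8, size=n)
sex = rng.integers(0, 2, size=n).astype(float)
# simulated continuous phenotype f(x) = p(y=1|x) with a small genetic effect
pheno = 0.02 * g + 0.001 * age + rng.normal(0, 0.1, size=n)

X = sm.add_constant(np.column_stack([g, age, sex]))  # intercept + SNP + covariates
fit = sm.OLS(pheno, X).fit()
print(fit.params[1], fit.pvalues[1])                 # effect size and p-value of the SNP term
```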
3.2 Anchor Learning

Anchor learning is a previously well-studied method that can be applied to positive-only and unlabelled data [12, 14]. Here, we reproduce the fundamental formalisation with reference to our task. We frame our problem as having positive-only and unlabelled data by assuming that observing a specific disease code only positively identifies the latent class it is supposed to measure. For any patient without this code, we cannot say whether they have the latent disease $y$ or not. More generally, let $s$ indicate whether the patient's sequence $x$ is labelled, such that we know the value of $y$. $x$ is positively labelled with $y = 1$ if $s = 1$. If $s = 0$, then the label of $x$ is unknown; it could be either $y = 0$ or $y = 1$. We assume that only positive examples are labelled, which can be stated as $p(s = 1|x, y = 0) = 0$.
(1)
Our goal is to model $p(y = 1|x)$. However, as stated above, we assume we only have positively labelled patients. Reference [12] shows it is possible to progress if we assume our positive examples are chosen randomly from the set of all positive examples. This can also be stated by saying that $s$ and $x$ are conditionally independent given $y$, or
p(s = 1|x, y = 1) = p(s = 1|y = 1).
(2)
Using this assumption, we can move closer to our goal of approximating $p(y = 1|x)$ by learning a classifier to predict our anchor variable $p(s = 1|x)$, as

$$p(s = 1 \wedge y = 1|x) = p(s = 1|x), \qquad (3)$$

$$p(y = 1|x) = \frac{p(s = 1|x)}{p(s = 1|y = 1, x)}. \qquad (4)$$
More specifically, in our case $s$ is determined by the presence of a diagnosis code, $x_{t,a} = 1$ for some $t$. If disease code $a$ is present, we also strongly believe the patient has the latent disease it describes, $y = 1$. We say $a$ is an anchor variable, where
$$p(s = 1 \wedge y = 1|x) = 1, \quad \text{if } a \in x.$$
(5)
Considering Eqs. (5) and (2), our goal can be updated to

$$p(y = 1|x) = \begin{cases} 1, & \text{if } a \in x \\ \dfrac{p(s = 1|x)}{c}, & \text{if } a \notin x, \end{cases} \qquad (6)$$
where $c = p(s = 1|y = 1)$ is a constant. This means we can learn an anchor classifier, $h(x) = p(s = 1|x)$, in place of $f(x)$ to rank our instances when $a \notin x$. When the anchor variable is present in a patient's data, the probability of the latent disease is 1; otherwise, the probability is given by the anchor classifier. Comparing Eq. (6) to the traditional definition of cases and controls, it is possible to view our task as learning from noisy labels: we assume that the standard definition of a control is noisy and instead learn a model that assigns controls higher scores when $x$ is predictive of being a case.
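A minimal sketch of this recipe (Elkan & Noto style, with simulated code counts standing in for real EHR data) is shown below: a classifier is trained to predict the anchor label $s$ from the remaining codes, and its score is then used as the phenotype probability for patients without the anchor.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
n_patients, n_codes = 2000, 100
X = rng.poisson(0.1, size=(n_patients, n_codes))   # aggregated code counts (dummy data)
s = (X[:, 0] > 0).astype(int)                      # anchor: code 0 observed at least once
X_no_anchor = np.delete(X, 0, axis=1)              # anchor code removed from the inputs

clf = LogisticRegression(max_iter=1000).fit(X_no_anchor, s)
scores = clf.predict_proba(X_no_anchor)[:, 1]      # h(x) = p(s=1 | x)

# Eq. (6) with c = 1: anchored patients get probability 1, the rest keep h(x)
phenotype = np.where(s == 1, 1.0, scores)
```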
3.3 Using BERT as Anchor Classifier

Previously, [15] used a logistic regression model, $h(x; \theta) = \sigma\big(\sum_j (\sum_t x_{t,j})\,\theta_j\big)$, within anchor learning to update the terms of phenotype definitions. In addition to using the output of the anchor classifier directly in downstream GWAS, we propose AnchorBERT, which combines NLP-like embeddings with the encoder of a transformer model to model $h(x)$. This model is inspired by the original BERT model from the NLP domain [8].
Multi-head Attention
The main component of our model relies entirely on (self-)attention mechanisms [26]. The primary advantage is that we can model global dependencies between codes $x_t$ regardless of their distance from each other in the sequence $x$. Self-attention modules first associate each input value $x_t$ with a query $q$ and a set of key-value pairs $(k, v)$, where the queries, keys, and values are themselves linear projections of the input:

$$q = \theta^q x_t, \quad k = \theta^k x_t, \quad v = \theta^v x_t, \qquad (7)$$
where $\theta^q$, $\theta^k$, and $\theta^v$ are to be learnt. A self-attention score then determines how much focus to place on each output value given $x_t$, computed with the dot-product of the queries and keys:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\Big(\frac{QK^T}{\sqrt{k}}\Big)V, \qquad (8)$$

where the rows of $Q$ (resp. $K$) are the queries (resp. keys) associated with each input, and $V$ is the matrix whose rows are the values associated with each input. Our model uses multi-head self-attention, where the input is independently processed by $n$ self-attention modules. This leads to $n$ outputs, which are concatenated back together to form the final attention output vector. For further details and clarification, see [26].
Embeddings
Since self-attention does not use any recurrent or convolutional mechanisms, we need to include an additional encoding to provide information on the absolute and relative position of inputs. We follow [19] by including a unique predetermined positional embedding for each visit. Segmentation embeddings are also included; these are two trainable vectors that alternate between subsequent visits and provide additional flexibility to encode differences between visits. Each term $x_{t,j}$ uses a learnt embedding to convert from a one-hot vector of dimension $d$ to $d_{model}$. Since our model can attend to each term within a visit equally, this is equivalent to the unordered multi-hot representation in our formulation. An attention mask, $A$, is used on individual anchor disease terms in the patient sequence. This is effectively equivalent to removing them entirely and negates any influence during training and evaluation. Equation (8) becomes

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\Big(\frac{QK^T}{\sqrt{k}} + A\Big)V, \qquad (9)$$
where the elements of the attention mask $A$ are zero, except at positions corresponding to $x_t = a$, where a large negative value is used. The disease codes are tokenized before being fed as input to our model. Unique tokens are assigned to each term with a total count greater than 0.01% of all terms; otherwise, a [UNK] token is used. [SEP] tokens are added between each
patient stay (episode). [PAD] tokens are appended to each sequence to maintain a maximum length of 256. [CLS] tokens are added at the start of each patient record for the BERT prediction pooling scheme.
Training and Prediction
Our model is updated during training by minimizing the binary cross-entropy loss. The binary labels are positive if any anchor variables are present in a patient's data. We make anchor variable predictions for an entire sequence by simply taking the final attention output vector corresponding to the [CLS] token and applying a linear layer. The sigmoid function is applied to the logits from the linear layer to produce the predicted anchor probabilities. Once training and evaluation are complete, we can produce the final phenotype probabilities $p(y = 1|x)$. As shown in Eq. (6), we replace the predicted anchor probability with 1 if the anchor label is positive. We set $c = 1$ as we only need to rank our examples according to the chance that they belong to class $y = 1$ [12].
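The masking step of Eq. (9) can be illustrated with a few lines of PyTorch. The sketch below (shapes and values are dummies, not the model's actual configuration) adds a large negative value to the attention scores at key positions holding the anchor code, so those positions receive essentially zero attention weight.

```python
import torch
import torch.nn.functional as F

batch, seq_len, d_k = 2, 6, 16
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)

anchor_positions = torch.zeros(batch, seq_len, dtype=torch.bool)
anchor_positions[:, 3] = True                            # pretend position 3 holds the anchor code

scores = Q @ K.transpose(-2, -1) / d_k ** 0.5            # (batch, seq_len, seq_len)
A = torch.zeros_like(scores)
A = A.masked_fill(anchor_positions.unsqueeze(1), -1e9)   # large negative value at anchor keys
out = F.softmax(scores + A, dim=-1) @ V                  # attention output ignoring anchor terms
```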
3.4 Baselines We first compare AnchorBERT to the previously studied anchor logistic regression (Anchor LR). We apply logistic regression by aggregating our sequence of disease tokens into total counts and scaling to a unit normal distribution. Scaled anchor counts are removed from the input features. We use the sklearn implementation of logistic regression with the L-BFGS-B solver. We also compare our anchor variable models against the commonly used thresholding method and Pheprob baseline [23]. The thresholding models produce a binary phenotype, classifying a patient as a case if the total number of anchor phecodes equals or exceeds a threshold. We use the original implementation of Pheprob, a parametric binomial mixture model that includes the total count of all a patient’s phecodes and anchor phecodes as input features. Pheprob outputs a continuous phenotype probability for each patient.
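The two simpler baselines can be sketched as follows on dummy count data; the column positions, thresholds, and scaling choices are illustrative, and Pheprob is omitted because it relies on the original binomial mixture implementation.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
counts = rng.poisson(0.2, size=(1000, 50))    # per-patient Phecode counts (dummy data)
anchor_counts = counts[:, 0]

# Threshold-k phenotype: case if the anchor Phecode appears at least k times
threshold_pheno = {k: (anchor_counts >= k).astype(int) for k in (1, 2, 3)}

# Anchor LR baseline: total code counts scaled to unit variance, anchor counts removed
X = StandardScaler().fit_transform(np.delete(counts, 0, axis=1))
s = (anchor_counts > 0).astype(int)
anchor_lr = LogisticRegression(solver="lbfgs", max_iter=1000).fit(X, s)
anchor_scores = anchor_lr.predict_proba(X)[:, 1]
```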
3.5 Anchor Performance Metrics

Anchor models are evaluated using the area under the receiver-operator curve (AUROC) and the area under the precision-recall curve (AUPRC). Average precision is used for AUPRC. While both metrics assess a model's ability to predict the anchor variable, AUPRC removes the influence of true-negative controls, allowing us to better assess the positive predictive power of the models.
3.6 Hyperparameters

Hyperparameters for all experiments are detailed in the Appendix. Anchor classifiers are trained using a training:validation:test split of 6:2:2. Test data is unseen until the final evaluation, and hyperparameters are tuned via grid search optimising for validation AUPRC. Hyperparameters are fitted for each disease and then refitted with a new seed for each evaluation. The best-performing hyperparameters for each disease are shown in Table 5. Dropout is applied to the multi-head attention layers and within the final linear classification head. The model's binary cross-entropy loss is optimised with BertAdam [8]. We set the size of the embeddings and linear classification heads to be determined by a shared hidden-size hyperparameter. See our full PyTorch Lightning implementation based on the Hugging Face [29] BERT model at https://github.com/andre-vauvelle/AnchorBERT for clarification.
4 Experiments and Results In this section, we provide the details and results of our experiments on the UK Biobank data. We first introduce the UK Biobank data, the diseases studied, and any data preprocessing requirements. Second, we report the performance of our anchor variable models and describe an experiment to investigate their robustness to control noise and compatibility with our PU data assumption. Finally, we can compare AnchorBERT to our baseline phenotype models on their ability to reproduce known genetic associations.
4.1 UK Biobank data

The UK Biobank [25] is a national population-based study comprising 502,629 individuals. We extract all available diagnosis terms from both primary (Read V2 and V3) and secondary (ICD-10) care settings for every patient. We then map the raw terms to Phecodes using the previously validated Phecode map 1.2b from [28]. We use Phecodes as our disease terms since the extracted raw terms are often distinguished by billing-specific information; Phecodes provide higher-level, clinically meaningful traits that are more suitable for genetic association studies [7]. All unmappable terms are dropped. Finally, only patients with ≥ 5 terms are retained, resulting in a total of 321,837 patients. We study five different diseases, identified by their Phecode and detailed in Table 1. Genomic data extraction and quality control processing follow the same methodology as [13]. These data are linked to all possible phenotyped patients, resulting in
Table 1 Diseases studied

Acronym   Disease                  Phecode
MI        Myocardial infarction    411.1
T2D       Type 2 diabetes          250.2
HF        Heart failure            428.2
DM        Dementia                 290.1
RA        Rheumatoid arthritis     714.0|714.1
312,010 patients with both genetic and phenotype data. More detailed statistics of the data are available in Table 3.
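A hypothetical pandas sketch of the preprocessing described above is given below; the file names and column names are assumptions rather than the actual UK Biobank extract format.

```python
import pandas as pd

diagnoses = pd.read_csv("diagnoses.csv")        # assumed columns: patient_id, visit_id, code
phecode_map = pd.read_csv("phecode_map.csv")    # assumed columns: code, phecode

# map raw terms to Phecodes; unmappable terms are dropped by the inner join
mapped = diagnoses.merge(phecode_map, on="code", how="inner")

# keep only patients with at least five mapped terms
term_counts = mapped.groupby("patient_id")["phecode"].size()
keep = term_counts[term_counts >= 5].index
cohort = mapped[mapped["patient_id"].isin(keep)]
print(cohort["patient_id"].nunique())
```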
4.2 Anchor Classifier Performance and Robustness to Control Noise

After finding the optimal hyperparameters, we retrained and evaluated each anchor classifier across ten runs to report the mean and standard deviation of each performance metric. To investigate whether our anchor classifiers are robust to noisy controls, we test whether performance remains high when training under corrupted anchor labels. Although we cannot identify which patients in our dataset have been incorrectly assigned without an expensive chart review, we can artificially add noise. From the perspective of anchor learning, we consider the negative set as unlabelled. If we randomly switch a sub-sample of the positive anchor labels to unlabelled during training and then evaluate on a validation set free from noise, we can effectively compare the performance of our anchor classifiers, h(x). As noise increases, models that are sensitive to label noise should perform worse.
Results of Anchor Classifiers
As seen in Table 2, the BERT classifier outperforms logistic regression across all disease areas. We observe considerable variation in performance between diseases, possibly indicating that some diseases have more comorbidity interactions that help identify noisy labels. For example, T2D has almost double the number of cases of MI, yet performance is worse across both metrics and models. In Fig. 2, we show the results of our investigation into control-noise robustness for MI and T2D; additional figures for the remaining diseases are in the Appendix. Overall, BERT outperforms LR across all diseases and noise proportions, with the minor exception of DM at the highest noise level.
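The noise-injection procedure can be sketched as follows (dummy labels; the real experiment corrupts only the training split and evaluates on clean validation labels).

```python
import numpy as np

rng = np.random.default_rng(0)

def corrupt_anchor_labels(s_train, noise_proportion):
    """Flip a proportion of positive anchor labels to unlabelled (0) for training."""
    s_noisy = s_train.copy()
    pos_idx = np.flatnonzero(s_train == 1)
    n_flip = int(noise_proportion * len(pos_idx))
    flip_idx = rng.choice(pos_idx, size=n_flip, replace=False)
    s_noisy[flip_idx] = 0
    return s_noisy

s_train = rng.binomial(1, 0.05, size=10_000)     # simulated anchor labels
for p in (0.0, 0.25, 0.5, 0.75):
    s_noisy = corrupt_anchor_labels(s_train, p)
    # ...train h(x) on s_noisy, then report AUPRC on a noise-free validation split...
```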
Table 2 Mean and standard deviation of anchor variable classifier performance on the test set for each investigated disease. AUPRC: Area under precision recall curve, AUROC: Area under receiver operator curve, LR: Logistic Regression

Disease   # Cases (ratio)    AUPRC (LR)        AUPRC (BERT)      AUROC (LR)        AUROC (BERT)
MI        18,007 (0.0577)    0.5547 ± 0.0000   0.6680 ± 0.0021   0.9039 ± 0.0000   0.9538 ± 0.0006
T2D       31,801 (0.1019)    0.4374 ± 0.0000   0.5071 ± 0.0012   0.7980 ± 0.0001   0.8386 ± 0.0009
HF        9,179 (0.0294)     0.3806 ± 0.0000   0.4453 ± 0.0043   0.9208 ± 0.0000   0.9405 ± 0.0013
DM        4,582 (0.0147)     0.2380 ± 0.0000   0.3286 ± 0.0101   0.8507 ± 0.0000   0.8842 ± 0.0024
RA        7,956 (0.0255)     0.1029 ± 0.0000   0.1375 ± 0.0031   0.7603 ± 0.0000   0.8037 ± 0.0016

Table 3 Basic statistics of UK Biobank EHR data after preprocessing

# of unique patients              321,837
# of Phecodes                     16,655,024
# of visits                       7,698,687
Avg. # of visits per patient      23.9
Avg. # of Phecodes per patient    51.7
Avg. # of Phecodes per visit      2.16
Maximum sequence length           256
# of unique Phecode tokens        837
Fig. 2 Area under precision recall curve for anchor variable classifiers as an increasing proportion of cases are added to the unlabelled set during training
4.3 Evaluating Phenotypes with GWAS

We use two experimental setups aimed at validating an improved ability to detect genetic associations. (1) Full data: we run GWAS using phenotypes generated for all available patients and compare the models' ability to reproduce known associations from the GWAS catalog. (2) Data ablation: we reduce the number of cases available in the GWAS and report which phenotyping methods are able to retain statistical significance for known associations. Associations are tested for using plink v2.0's generalised linear model, regressing SNPs against phenotype status [4]. Linear regression is used for the anchor variable and binomial mixture model continuous phenotypes, while logistic regression is used for the binary threshold phenotypes. All regressions use the following covariates: sex, age, and the first ten population structure principal components. All reported significant SNPs, from both the GWAS catalog and our analysis, are filtered such that only those with a p-value lower than 5 × 10−8 remain.
Comparison to GWAS Catalog
In order to assess whether our phenotyping methods are able to increase the power of GWAS studies, we compare their ability to reproduce known significant associations from the GWAS catalog. It is difficult to report directly on study power, as a set of true phenotype-gene associations is not possible to obtain. We follow and expand upon prior work [23] by replicating previously found associations. Rather than replicating a small number of hand-selected associations, we compare against all associations found for a disease in the GWAS catalog [3]. The GWAS catalog is a widely used and freely available database of SNP-trait associations, including those from consortium studies with cohort sizes orders of magnitude larger than the UK Biobank. The catalog contains studies from vastly different populations and experimental procedures. Populations in the catalog may have been measured with a different sequencing array, meaning that some SNPs may not be present in our data. In addition, reported loci in the GWAS catalog are often the result of fine-mapping, which keeps only the most likely causal SNP, discarding highly correlated nearby SNPs in linkage disequilibrium (LD) [24]. To partly address these issues, we expand the significant SNPs from the GWAS catalog and the UK Biobank genomic data to include SNPs within LD. We do this using an LD reference panel with a threshold of R² > 0.5 from the 1000 Genomes Project [10].
Data Ablation Study
We also conduct a data ablation study to examine the influence of the anchor variable on the control cohort. Here, we are not reducing the amount of data available to train the anchor variable model; instead, we randomly remove a proportion of the patients after training but before finding associations. If we reduce the case population defined by Threshold-1, study power should fall. For the anchor variable models, however, the influence of updating the probability of the controls should remain. At the extreme, with zero cases, only associations due to noisy samples should remain.
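The catalog comparison itself reduces to set operations over rsIDs once the GWAS summary statistics and the LD-expanded catalog list are available. The sketch below is a hypothetical illustration of that step; the file names and column names are assumptions, not the actual pipeline outputs.

```python
import pandas as pd

# GWAS summary statistics for one phenotype definition (assumed columns: rsid, p)
results = pd.read_csv("gwas_results.tsv", sep="\t")
# GWAS catalog associations for the disease, already expanded to LD proxies (assumed column: rsid)
catalog = set(pd.read_csv("catalog_ld_expanded.csv")["rsid"])

significant = set(results.loc[results["p"] < 5e-8, "rsid"])     # genome-wide significance filter
reproduced = significant & catalog
print(len(reproduced), len(reproduced) / max(len(catalog), 1))  # count and proportion reproduced
```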
Table 4 Performance of each phenotyping method when reproducing known GWAS Catalog associations. MI—Myocardial Infarction, T2D—Type 2 Diabetes, HF—Heart Failure, DM—Dementia, RA—Rheumatoid Arthritis

Phenotype model                    Total reproduced catalog genomic associations (Proportion)
                                   MI             T2D             HF           DM        RA
AnchorBERT                         44 (0.3438)    266 (0.1513)    3 (0.0714)   1 (1.0)   32 (0.0814)
Anchor LR                          39 (0.3047)    254 (0.1445)    2 (0.0476)   1 (1.0)   32 (0.0814)
Pheprob                            28 (0.2188)    247 (0.1405)    0 (0.0)      1 (1.0)   32 (0.0814)
Threshold-1                        29 (0.2266)    280 (0.1593)    0 (0.0)      1 (1.0)   32 (0.0814)
Threshold-2                        25 (0.1953)    237 (0.1348)    0 (0.0)      1 (1.0)   32 (0.0814)
Threshold-3                        18 (0.1406)    200 (0.1138)    0 (0.0)      1 (1.0)   32 (0.0840)
Total Catalog Significant rsIDs    128            1,758           42           1         393
Since the variation in p-values can change significantly as some instances are included or excluded between ablation thresholds, we repeat the association studies ten times and report confidence intervals. In order to reduce computational power requirements, we only test the phenotypes associations with significant SNPs found from all methods with full data and report the proportion of reproduced significant SNPs. GWAS Results Table 4 contains the results of our comparison against the known associations from the GWAS catalog. Our proposed anchor variable approach to creating continuous phenotype traits for GWAS reproduces more associations than other models for both MI and HF. Threshold-1 outperforms others on T2D, while all models find an equal number of significant associations in the catalog for DM and RA. Where the anchor variable models do not find a greater number of associations, the proportion of reproduced associations is still the same or slightly lower than Threshold-1, the current standard method used for GWAS studies. Figure 3 shows the results of the ablation study for MI. Both proposed anchor variable models are able to reproduce the associations of the thresholding methods and Pheprob at all ablation thresholds, with AnchorBERT also outperforming the logistic regression anchor model. When removing all cases as defined by Threshold-1, nonanchor variable methods can no longer detect any associations, while AnchorBERT retains 20% and Anchor LR retains 13% of associations. In Appendix we show more complete results for each disease. Figure 5 shows the ablation study results for the other four diseases. With the exception of RA, when all Threshold-1 defined cases are removed, anchor models are still able to replicate catalog associations.
130
A. Vauvelle et al.
Fig. 3 Ablation study of patients for Myocardial Infarction and each phenotyping method. Shaded areas indicate one standard deviation of results across 10 trails
5 Discussion In this work, we present a novel phenotyping method, AnchorBERT, which uses anchor learning and transformers to generate continuous phenotypes that allow for the detection of significant genomic associations with smaller cohorts. In seeking a more representative phenotype, we argue that EHR diagnoses can be treated as positive only and unlabelled data. PU data allows the application of anchor learning, where we introduce BERT as a novel modification to the anchor classifier. BERT allows greater performance when modelling the anchor variable and is more robust to label noise, modelled here by introducing additional cases into the unlabelled set. Using data from the UK Biobank and GWAS catalog, we validate our proposed phenotyping methods, Anchor LR and AnchorBERT, together with baselines used by GWAS practitioners to detect genomic associations in HF previously only found in studies with 5× the number of available cases. Using anchor phenotypes could enable the discovery of genomic associations otherwise inaccessible to studies with small cohorts. For example from Table 4, our proposed AnchorBERT replicates three significant SNPs (rs17042102, rs55730499, rs1556516). Which were previously only found in the largest HF GWAS meta-study to date, with approximately 5× and 3× as many cases and controls of European ancestry [21]. From Figs. 3b and 5, we are able detect known associations for patients without an anchor variable (0% cases), suggesting that we are potentially identifying missed diagnoses. It is notable that, AnchorBERT reproduces at least as many or more genomic associations than Anchor LR, showing that the sequential and non-linear relationships between codes that result in higher anchor classification performance translate into improved ability to reproduce genomic associations. This trend is consistent across all disease areas considered, even for T2D where Anchor learning generally performs worse than Threshold-1.
Phenotyping with Positive Unlabelled Learning for Genome-Wide Association Studies
131
We note that the additional HF SNPs identified also have significant associations with related upstream comorbidities and traits, including: coronary artery disease, atrial fibrillation, and low-density lipoprotein cholesterol. This could be due to a genuine shared disease aetiology underlying these risk factors. Alternatively, these associations could be confounded and independently related to HF comorbidities. Performance across all diseases is not guaranteed. We find anchor phenotypes are able to identify genetic associations with 0% cases in all but RA. Considering Threshold-2 and 3 phenotypes outperform other phenotyping methods on RA, we suggest that noise in the case definition could be responsible for poor performance, as this violates our assumption in Eq. 5. We hope that these findings will help further efforts to discover new diseasegenomic associations. Ultimately leading to a greater understanding of disease and a better, more efficient process for discovering new medicines. Acknowledgements Andre Vauvelle is supported by a Benevolent AI Ph.D. studentship. We thank Prof. Christopher Yau, Dr. Eda Ozyigit, Albert Henry and Joe Farrington for their insightful guidance and discussion during this work.
Appendix See Figs. 4 and 5 and Table 5.
Fig. 4 Area under precision recall curve for anchor variable classifiers as an increasing proportion of negative examples are flipped
132
A. Vauvelle et al.
Fig. 5 Data ablation study for the remaining diseases for each phenotyping method. Shaded areas indicate one standard deviation of results across 10 trails
Phenotyping with Positive Unlabelled Learning for Genome-Wide Association Studies
133
Table 5 Model Hyperparameters. Ranges for tuning indicated by square braces. Final values from tuning under each disease acronym BERT optimizer
Tuning
MI
Learning rate
[1 × 10−5 , 1× 10−4 , 1× 10−3 ]
1 × 10−4 1 × 10−4 1 × 10−4 1 × 10−4 1 × 10−4
T2D
HF
DM
RA
Warm-up proportion
0.1
Weight decay
0.001
BERT model hyperparameters Batch size
256
Hidden layer size
[120, 360 240, 360]
360
360
360
360
Number of hidden layers
[6, 10, 12]
2
6
10
6
6
Hidden dropout probability
0.2 12
12
12
12
12
512
256
256
256
Number of multi-head attention layers [6, 10, 12]
Intermediate layer size in transformer [128, 512 256, 512] Number of attention heads
12
Multi-head attention dropout rate
0.22
Parameter weight initializer range
0.02
Non-linear activation (Encoder and Pooler)
GELU
References 1. Agarwal, V., Podchiyska, T., Banda, J. M., Goel, V., Leung, T. I., Minty, E. P. et al. (2016). Learning statistical models of phenotypes using noisy labeled training data. Journal of the American Medical Informatics Association : JAMIA, 23(6), 1166–1173. 2. Banda, Juan M., Seneviratne, Martin, Hernandez-Boussard, Tina, & Shah, Nigam H. (2018). Advances in electronic phenotyping: From rule-based definitions to machine learning models. Annual Review of Biomedical Data Science, 1(1), 53–68. 3. Buniello, A., & Helen et al. (omitted for brevity) Parkinson (2019). The NHGRI-EBI GWAS Catalog of published genome-wide association studies, targeted arrays and summary statistics 2019. Nucleic Acids Research, 47(D1), D1005–D1012. 4. Chang, C. C., Chow, C. C., CAM Tellier, L., Vattikuti, S., Purcell, S. M., & Lee, J. J. (2015). Second-generation PLINK: Rising to the challenge of larger and richer datasets. GigaScience, 4(1). 5. Dahl, Andy, & Zaitlen, Noah. (2020). Genetic influences on disease subtypes. Annual Review of Genomics and Human Genetics, 21(1), 413–435. 6. Spiros et al. (omitted for brevity) Denaxas. (2019). UK phenomics platform for developing and validating electronic health record phenotypes: CALIBER. Journal of the American Medical Informatics Association, 26(12), 1545–1559.
134
Out-of-Distribution Detection for Medical Applications: Guidelines for Practical Evaluation

Karina Zadorozhny, Patrick Thoral, Paul Elbers, and Giovanni Cinà
Abstract Detection of Out-of-Distribution (OOD) samples in real time is a crucial safety check for deployment of machine learning models in the medical field. Despite a growing number of uncertainty quantification techniques, there is a lack of evaluation guidelines on how to select OOD detection methods in practice. This gap impedes implementation of OOD detection methods for real-world applications. Here, we propose a series of practical considerations and tests to choose the best OOD detector for a specific medical dataset. These guidelines are illustrated on a real-life use case of Electronic Health Records (EHR). Our results can serve as a guide for implementation of OOD detection methods in clinical practice, mitigating risks associated with the use of machine learning models in healthcare.

Keywords Out-of-Distribution detection · Density estimation · Evaluation guidelines · Electronic health records
1 Introduction

When deploying a machine learning model in a high-risk environment such as healthcare, it is crucial to reduce the risk of giving an incorrect prediction. One of the most common causes of failure of a well-performing model is dissimilarity between training
data (in-distribution) and data that the model is used on after deployment [36]. This problem is further exacerbated in a non-stationary medical environment: the performance of predictive models can drop drastically when evaluated on medical data collected just a couple of years removed from the in-distribution data [29]. The reason for this can be, for example, a change in the demographics of the target population, changes in treatment protocols, or changes in data collection procedures [8]. These examples constitute a covariate shift, where feature distributions change compared to in-distribution data [24, 37]. For reliable use of prediction models in healthcare it is necessary to be able to detect these changes in real time.

In the past five years, there has been a surge in methods for OOD detection and uncertainty quantification. Despite this progress, OOD detection methods often fail at flagging atypical inputs. While most studies have focused on image data [27, 30], it has recently been shown that many OOD detection methods fail to distinguish OOD inputs on medical tabular data [44]. While predictive models are tailored and tested on specific datasets, there is no universal way to evaluate OOD detection methods in practice. This gap impedes widespread implementation of OOD detection methods.

In this paper, we provide a practical set of guidelines on how to select an OOD detector in a given medical AI application. Our specific contributions are the following:

• We describe how to take into account dataset-specific variables when evaluating OOD detector models in practice.
• We show how to test OOD detectors on distinct families of OOD groups that can be created from available data using inclusion-exclusion criteria and withholding samples during training.
• We describe how to use interpretability tools to assess whether influential features that drive novelty predictions are clinically relevant.
• We provide an open-source repository that can be used to evaluate OOD detection methods for any mixed-type tabular dataset (https://github.com/Giovannicina/selecting_OOD_detector).

We illustrate these general principles on real-world EHR data from the Intensive Care Unit (ICU) of Amsterdam University Medical Centers (AmsterdamUMC) [40]. We show that, using the guidelines proposed in this paper, we can detect underperforming OOD detectors and select the best methods to be deployed in practice.
2 Related Work

Despite the impressive performance of machine learning models, recent studies have highlighted the potential risks of undetected OOD data points in critical areas such as healthcare. Besides deliberate threats such as data poisoning and adversarial attacks [7, 25, 32], deployed models can suffer from severe malfunction due to dataset shifts.
Finlayson et al. described different scenarios that lead to dataset shifts in practice and highlighted the need for more robust and reliable models in healthcare [8]. To mitigate risks of failure due to dataset shift caused by covariate shifts [2], recent studies proposed a plethora of uncertainty quantification methods.

These methods can be broadly divided into two groups based on the type of uncertainty they express. The first group uses predictive uncertainty (uncertainty about predictions), which is model-specific and indicates how confident a model is in its prediction [30]; such uncertainty can be measured with a variety of metrics. Models in this category include Bayesian Neural Networks [3], Ensemble Neural Networks [19], or Monte Carlo Dropout [11]. The second group of models, density estimators, expresses uncertainty about samples by learning the distribution of training data. Models such as the Variational Autoencoder (VAE) [17, 34] or Normalizing Flows [35] can be used to flag samples with low likelihood under an estimated in-distribution density.

There has been only limited work published to date that focuses on the practical evaluation of OOD detection methods in real-world scenarios. The authors of [14] proposed a benchmark for OOD detection in the imaging domain that could replace small-scale datasets containing only a few classes with a more realistic large-scale multi-class anomaly segmentation task. The work in [39] described practical tests of OOD detection in the imaging domain to aid the evaluation of models in real-world applications. However, no guidelines on how to test OOD detectors are available for medical tabular data. In this paper, we aim to close this gap between research studies and the practical implementation of OOD detection methods in a medical setting.
3 Methods

We first describe considerations that should be taken into account when choosing an OOD detector for a specific application. We then put forward guidelines for creating tests to assess the performance of OOD detection methods. The proposed OOD tests are illustrated on a real-world use case of EHR data of AmsterdamUMC.
3.1 Considerations

Model Type. It has been shown that models expressing predictive uncertainty in a classification setting can be overconfident in their faulty predictions [12, 30]. Specifically, a group of discriminator models assign high predictive confidence to regions with no training data [13, 43]. Given these findings, we advise using density estimators instead of models that express predictive uncertainty. While there is a growing interest in the limitations of density estimators for OOD detection [18, 27, 48], density models outperformed discriminator models on several OOD detection tasks on tabular medical data [44].
Data Type. Many density estimators can perform well on continuous data but are not directly suitable for discrete distributions. Research efforts on extending models to discrete distributions most commonly focus on text inputs [22, 42]. For some applications, such as learning the density of images, dequantization can be applied to map discrete pixel values to continuous distributions [15]. However, this approach is not applicable to categorical data with no intrinsic order. Mixed-type data that consist of categorical and continuous features pose additional challenges to density modeling [28]. Therefore, the proportion of categorical features should be taken into account when choosing an OOD detector.

Dimensionality and Size of Data. The dimensionality of data is a crucial factor that influences the performance of density estimators. For example, standard kernel density estimators can achieve very good results on low-dimensional inputs, but their performance degrades exponentially with an increasing number of dimensions [26, 46]. Other models, such as the Autoencoder (AE) and Probabilistic Principal Component Analysis (PPCA), which perform dimensionality reduction by default, are less affected by the number of features. Having a large number of samples from which to estimate the data distribution can affect different density estimators unequally. While most models benefit from having access to more data, algorithms such as the Local Outlier Factor (LOF) become very time- and memory-inefficient with an increasing number of samples [5].

Class Balance. In the classification setting, having unbalanced classes, which is very common in medical diagnostics, can negatively influence the usefulness of the learned data density. If there is not a sufficient number of samples of one class present in the training dataset, new samples of this minority class can be mistakenly flagged as OOD despite being an essential component of the dataset. While the densities of two classes could be modeled separately or using a class-conditional density estimator [38], this approach is not feasible as we would have to first predict a sample's class and only then test whether the sample is in-distribution. Therefore, likelihood scores for underrepresented classes should be monitored when comparing different density models.
3.2 Designing OOD Tests

While any data point outside the training set can be considered OOD, it is not feasible to test against all possible inputs. Work by [48] argues that it is not theoretically possible to design a method that could detect all possible out-of-distribution inputs. The authors of [47] distinguished far-OOD and near-OOD inputs depending on how similar they are to in-distribution data. While we agree that far-OOD inputs provide a good sanity check on the usefulness of OOD detectors, we ultimately believe that the latter group is more relevant and more difficult to detect. We thus focus on near-OOD scenarios, specifying families of OOD groups that could be encountered in a medical setting in real life.
Table 1 Clinically relevant OOD groups

Type of OOD group | Examples | How to create? | AmsterdamUMC Example
Outside inclusion-exclusion criteria (static) | – | Separate excluded groups according to demographic-related characteristics | Age, race
Outside inclusion-exclusion criteria (dynamic) | Treatment, length of stay | Split on one or more features related to a clinical status that is subject to change | Ventilation, medications, patients far from discharge
New disease | New infectious disease, any diagnosis underrepresented in training data | Artificially create an OOD group by withholding them during training | COVID-19 patients and suspects
Fig. 1 OOD groups can be created from data that was excluded under inclusion-exclusion criteria (panel A). If there are not enough samples in excluded data, groups can be withheld during training (panel B)
We suggest creating clinically relevant OOD tests that cover the following categories (Table 1):

• Detecting patients outside static inclusion-exclusion criteria (e.g. demographics).
• Detecting patients outside dynamic inclusion-exclusion criteria (e.g. a treatment protocol).
• Detecting patients with a new or underrepresented disease.

The OOD groups can be created either by using data eliminated under the inclusion-exclusion principle described below (Fig. 1a) or by withholding specific groups from the in-distribution data (Fig. 1b).
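As a rough illustration, the following Python sketch shows how the three families of OOD groups could be carved out of a patient-level table; the column names (age, days_to_discharge, covid_suspected) are hypothetical placeholders and not the actual AmsterdamUMC variables.

import pandas as pd

def make_ood_groups(df: pd.DataFrame):
    # Static inclusion-exclusion: e.g. only adults are part of the in-distribution cohort.
    in_dist = df[df["age"] >= 18]
    ood_static = df[df["age"] < 18]

    # Dynamic inclusion-exclusion: e.g. only the last day before discharge is in-distribution.
    ood_dynamic = in_dist[in_dist["days_to_discharge"] > 1]
    in_dist = in_dist[in_dist["days_to_discharge"] == 1]

    # New or underrepresented disease: withhold a disease group from training entirely.
    ood_disease = in_dist[in_dist["covid_suspected"] == 1]
    in_dist = in_dist[in_dist["covid_suspected"] == 0]

    return in_dist, {"static": ood_static, "dynamic": ood_dynamic, "new_disease": ood_disease}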
3.2.1 Patients Outside Static Inclusion-Exclusion Criteria
Processing medical data involves defining inclusion and exclusion criteria that specify which samples are going to be part of the cohort to be analyzed. Inclusion criteria, which are typically clinical or demographic characteristics, must be present in order to include a sample. Exclusion criteria eliminate samples that passed inclusion criteria but have other characteristics that could invalidate downstream analysis [33]. Inclusion-exclusion criteria can be used to evaluate the performance of OOD detection methods. Provided that models are trained only on samples that follow inclusion-exclusion criteria, we can treat ineligible samples as OOD and compare how well these groups are flagged by OOD detectors. Such groups can be made using demographic or clinical characteristics (Fig. 1a).
3.2.2 Patients Outside Dynamic Inclusion-Exclusion Criteria
Another example of a clinically relevant OOD group is a set of patients that should currently be excluded from the data, but whose status can change over time. In a medical setting, this is a common occurrence and can be caused by a change in the treatment protocol. If data about individual patients are available over a period of time, such groups can be created by finding a criterion that is not static and is subject to change. Alternatively, OOD groups can be created by splitting data on a received treatment.
3.2.3 Patients with New Diseases
A clinically relevant scenario that highlights the necessity of OOD detection methods is the occurrence of novel diseases. COVID-19 provides an example of the importance of being able to detect atypical symptoms in real time instead of giving a prediction and only assessing data retrospectively. A common scenario occurs when a model is used for patients with a disease that was not sufficiently represented in the in-distribution data. We advise creating OOD groups of patients with an underrepresented disease by either applying inclusion-exclusion criteria or withholding samples of patients with a specific disease from the in-distribution training data (Fig. 1).
3.2.4 Interpretability of OOD Detection Methods
Another issue when deploying OOD detection methods is how transparent novelty predictions are. In high-stakes areas such as healthcare, it is not only useful to be able to flag a sample as OOD but also to see what set of features makes it different from in-distribution data. Some density estimators are inherently interpretable. For example, for AE-based models, it is possible to inspect the features that received the highest reconstruction error.
Fig. 2 Assessing interpretability of OOD detectors. a In the dataset-level approach, data are artificially split into in-distribution and OOD based on a value of one selected feature. b In the qualitative approach, individual feature importance scores of individual patients are inspected manually with the input from clinicians
To provide a comparative evaluation of interpretability for all models, we suggest using Shapley Additive Explanations (SHAP; [20]). We describe two approaches for assessing the interpretability of OOD detection methods (Fig. 2). The first, dataset-level approach splits the data on one feature to divide in-distribution and OOD data. After scoring OOD samples, SHAP is used to determine feature importance for the assigned novelty scores. Models are then scored on the importance of the feature that was used for splitting (Fig. 2a). In the second, qualitative approach, OOD detectors are trained on in-distribution data and are used to score test data (Fig. 2b). Samples that receive the highest novelty scores (outliers) are then inspected individually. Ranking the most influential features using SHAP and having medical professionals compare the values of these features can help validate whether models are influenced by clinically relevant features and verify whether outliers are in fact meant to be in-distribution.
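A minimal sketch of the qualitative approach using the shap package is shown below; it treats the fitted density estimator as a black box through a generic novelty_score function, so the function and feature names are placeholders rather than the chapter's actual implementation.

import numpy as np
import shap

def explain_outlier(novelty_score, X_train: np.ndarray, x_outlier: np.ndarray, feature_names):
    # novelty_score: maps an (n_samples, n_features) array to a vector of novelty scores.
    # KernelSHAP explains any black-box scoring function; a small background sample keeps it tractable.
    background = shap.sample(X_train, 100)
    explainer = shap.KernelExplainer(novelty_score, background)
    shap_values = explainer.shap_values(x_outlier.reshape(1, -1))
    # Rank features by the magnitude of their contribution to the novelty score for clinical review.
    ranking = sorted(zip(feature_names, shap_values[0]), key=lambda t: abs(t[1]), reverse=True)
    return ranking[:10]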
3.2.5 Inference Time for Real-Time Prediction
The ability to flag OOD samples in real time, as compared to detecting batches of atypical samples retrospectively, is an important consideration for the deployment of OOD detection methods. While inference time for individual predictions tends not to be a limiting factor, coupling it with explainability tools such as SHAP can significantly slow down the prediction process.
3.3 Dataset

To illustrate the above-described principles in practice, we used a dataset from Amsterdam University Medical Center (AmsterdamUMC). This dataset reflects a real-life medical use case of EHR data, which typically has many features, including categorical ones, and an unbalanced outcome. AmsterdamUMC has released a freely accessible de-identified version of the data from 2003–2016. Access to the data and/or the Dutch Data Warehouse against COVID-19 can be requested from Amsterdam Medical Data Science (https://amsterdammedicaldatascience.nl).

The AmsterdamUMC dataset contains retrospective data of patients admitted to the Intensive Care Unit (ICU) for each day of their admission stay. This results in 58,142 rows of data corresponding to 15,753 patients. The total number of features in the dataset is 5,269. Features were created by analyzing time series and dividing data into windows of 24 h. We selected 56 features, of which 49 are continuous and 7 categorical. The downstream prediction task is a binary classification of readmission to the ICU after discharge or mortality. Any predictive model can be used for this task, while our aim is to compare different OOD detection methods on this dataset.
3.4 Models

Given that discriminator models were shown to underperform on OOD detection for medical tabular data [44], we limited our analysis to density estimators. We also included models that do not model an explicit density function, such as the Autoencoder and the Local Outlier Factor (LOF). Specific model hyperparameters can be found in the Appendix.

Autoencoder. The AE was used in combination with a reconstruction error metric. This metric was chosen because the model is expected to learn to encode the training distribution faithfully, whereas samples that are different from the training data will not be reconstructed properly and will receive a high reconstruction error.

Variational Autoencoder. We implemented a Variational Autoencoder (VAE) [17], one of the most popular density estimation techniques. Similar to the AE, the novelty metric was chosen to be the reconstruction error.

Probabilistic Principal Component Analysis. Probabilistic Principal Component Analysis (PPCA) [41] was used as a latent variable model with a log-likelihood metric.

Deep Neural Gaussian Process. We implemented a deep neural Gaussian Process model called Deterministic Uncertainty Estimation (DUE) according to [45]. This approach builds on the idea of feature-distance awareness, which denotes the ability of a model to quantify the distance of new samples to the training data. DUE uses a distance-preserving neural network for feature extraction and models uncertainties using a Gaussian Process in the last layer. The standard deviation was used as the novelty metric.
Normalizing Flows. Normalizing flow models are tractable explicit density estimators that learn the training data distribution through a series of transformations into a normal distribution. We used the Masked Autoregressive Flow as described in [31], from an open-source implementation [6].

Local Outlier Factor. LOF is an unsupervised algorithm for finding outliers. While LOF does not model the density function of the training data, it expresses the local density of a sample's features and compares this density to that of its closest neighbors [5].
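For the simpler scorers, the fit-then-score pattern can be sketched with scikit-learn as below; this only illustrates the pattern (the chapter's own implementations, and the AE, VAE, DUE, and Flow models, are separate), using the component and neighbor counts reported in the Appendix.

import numpy as np
from sklearn.decomposition import PCA
from sklearn.neighbors import LocalOutlierFactor

def fit_novelty_scorers(X_train: np.ndarray):
    # PCA.score_samples returns the log-likelihood under a probabilistic PCA model.
    ppca = PCA(n_components=19).fit(X_train)
    # With novelty=True, LOF can score previously unseen samples.
    lof = LocalOutlierFactor(n_neighbors=5, novelty=True).fit(X_train)
    return ppca, lof

def novelty_scores(ppca, lof, X: np.ndarray):
    # Likelihood-style outputs are negated so that higher scores mean "more novel".
    return {"PPCA": -ppca.score_samples(X), "LOF": -lof.score_samples(X)}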
3.5 AUC Score of OOD Detection

We use AUC-ROC scores to report the performance of OOD detection methods. First, models are trained on in-distribution training data. Then, they are used to predict novelty scores for in-distribution test data and for the samples in each OOD group. The scores of the in-distribution test data and an OOD group are then used to calculate an AUC-ROC score by assigning a label of 0 to all in-distribution samples and a label of 1 to all samples in the OOD group.
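Assuming the novelty scores are available as arrays, this evaluation reduces to a few lines; the sketch below uses scikit-learn's roc_auc_score.

import numpy as np
from sklearn.metrics import roc_auc_score

def ood_auc(scores_in_dist_test: np.ndarray, scores_ood_group: np.ndarray) -> float:
    # Label 0 for in-distribution test samples, label 1 for the OOD group.
    labels = np.concatenate([np.zeros(len(scores_in_dist_test)), np.ones(len(scores_ood_group))])
    scores = np.concatenate([scores_in_dist_test, scores_ood_group])
    return roc_auc_score(labels, scores)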
4 Experimental Results

Dynamic Inclusion-Exclusion Criteria: Length of Stay. We followed the guidelines described above and selected patients outside dynamic inclusion-exclusion criteria that are subject to change over time. We used the fact that the AmsterdamUMC dataset contains retrospective time series data for each patient. The in-distribution data for this experiment are patients in the last day of their ICU admission, and the OOD groups are separated based on how far patients are from being discharged (Fig. 3). The further patients are from discharge, the more dissimilar their features are compared to in-distribution data. While all models except DUE showed the desired gradient of increasing AUC-ROC scores, Flow achieved a near-perfect score already for patients 4–5 days before discharge.

Dynamic Inclusion-Exclusion Criteria: Discharge Destination. Given that the main predictive task is the classification of readmission probability, patients that died at the ICU or were transferred to another department or a different hospital are excluded from the in-distribution data. Most models detected patients with a discharge destination of mortuary more easily than patients that were transferred to an ICU of a different hospital (Fig. 4). This result is reassuring as, arguably, the features of transferred patients more closely resemble in-distribution data.

Dynamic Inclusion-Exclusion Criteria: Received Treatment. Next, we selected patients outside dynamic inclusion-exclusion criteria based on the type of received treatment. We compared the OOD detection methods on two categories of treatment-
Fig. 3 Mean AUC scores (n = 5) of detecting patients far from discharge. Row labels indicate the number of days before discharge. In-distribution data in this experiment are patients 1 day before discharge
Fig. 4 AUC-ROC scores of detecting patients that died at the ICU (Mortuary with or without autopsy) and patients that were transferred to another hospital (ICU or other department). This experiment was performed on the AmsterdamUMC-DIS-II dataset
Fig. 5 Top panel: Mean AUC scores and standard deviations (n = 5) of detecting ventilated patients and patients with kidney failure receiving CVVH. Bottom panel: Novelty score distributions. Plotted values are clipped to 5–95% of test novelty scores to prevent outlier scores from skewing the graphs
related exclusion criteria: patients connected to a specific device and patients receiving medication interventions. We first tested whether novelty detectors are able to flag patients connected to a ventilating machine or a continuous veno-venous hemofiltration (CVVH) device (Fig. 5). Note that in our experiment, there are no features that would directly indicate whether patients are ventilated or on CVVH; therefore, OOD detectors must infer this information from the rest of the features. While most models assign greater novelty scores to the OOD groups, the Flow model has the least overlapping distributions for these groups and the in-distribution test data (Fig. 5, bottom).

Observing Patients with New Diseases. We use data from the same ICU containing confirmed and suspected COVID-19 patients, collected as part of the Dutch Data Warehouse [9], to test OOD detectors in their ability to flag patients with this disease (Fig. 6). While AE, PPCA, and VAE achieved an AUC-ROC above 70% for COVID-19 patients, Flow was able to achieve over 90% for both groups.
5 Discussion

While it is standard practice to evaluate predictive models on each dataset, there is no established way to compare different OOD detection methods. In this paper, we described a set of practical guidelines on how to create and test OOD detection methods on medical tabular data.

A similar effort in bridging the gap between research studies and practical implementation was shown for other critical areas such as AI explainability.
Fig. 6 Top panel: AUC scores and standard deviations (n = 5) of detecting COVID-19 patients and suspects. Bottom panel: Novelty score distributions. Plotted values are clipped to 5–95% of test novelty scores to prevent outlier scores from skewing the graphs
The authors in [23] provided a categorization of evaluation strategies for explainable AI, while recent work [10] presented practical guidelines for using explainability techniques, illustrated with different case studies. Others have proposed general considerations for deploying machine learning models into clinical practice [4].

There are many outstanding questions related to the design and selection of OOD detection methods. First, to our knowledge, there is no direct solution to class imbalance of in-distribution data for OOD detection, and current methods have to rely on the availability of a sufficiently large sample size. Second, current models are not well suited for dealing with mixed-type data containing categorical and continuous distributions, as they require different likelihood functions. Extending density estimators to mixed-type data [21, 28] could greatly enhance OOD detection results. Third, more efficient ways to deal with feature correlation could improve the evaluation of the interpretability of OOD detectors by preventing feature-importance spread [1, 20].

The setup described here assumes some prior knowledge about the possible OOD groups that should be detected as near-OOD [47]. It is in general good practice to involve clinical experts in designing near-OOD scenarios. We acknowledge that, apart from the tests described in this paper, other clinically relevant scenarios include detecting samples from different hospitals and testing for corrupted features [44]. Other open-ended and application-specific questions are: (i) how to choose the most reliable OOD detector by aggregating the tests' results, and (ii) how to select a threshold for flagging inputs.

In our experiments on real-life mixed-type EHR data, Flow and VAE performed consistently well across different OOD tests. Given that the performance of density estimators depends on the type of data and the dimensionality of the dataset, the superior performance of these two models can be explained by a relatively low number of categorical features (less than 10%), potential non-linear interactions of different features (which could be more problematic for PPCA), and
the large data size (which can make LOF less efficient). Recently, failures of normalizing flows, which can assign higher likelihood to OOD inputs, were discussed in several papers [18, 27, 48]. We hypothesize that structured data such as images, which give rise to local pixel correlations, are more prone to such failures than tabular data.

Finally, the guidelines for evaluating OOD detectors for medical data described in this paper can help OOD detectors bridge the gap from theoretical possibility to deployed application, enhancing the safety of AI tools and facilitating the uptake of this new technology in the medical field.

Acknowledgements We would like to thank our colleagues at Pacmed and Amsterdam UMC for providing us with the data and insights. We also thank Dennis Ulmer for providing valuable feedback on the manuscript.
Appendix

Model Hyperparameters

Autoencoder
For the AmsterdamUMC dataset, we used the following hyperparameters: both the encoder and the decoder had 1 layer with 75 units, the latent space contained 20 dimensions, and the learning rate was 0.007.
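As a hedged sketch of what such a model might look like in PyTorch, the snippet below instantiates an autoencoder with these hyperparameters; the activation functions and training details are assumptions rather than details taken from the authors' code.

import torch
import torch.nn as nn

class Autoencoder(nn.Module):
    def __init__(self, n_features: int, hidden: int = 75, latent: int = 20):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(n_features, hidden), nn.ReLU(),
                                     nn.Linear(hidden, latent))
        self.decoder = nn.Sequential(nn.Linear(latent, hidden), nn.ReLU(),
                                     nn.Linear(hidden, n_features))

    def forward(self, x):
        return self.decoder(self.encoder(x))

    def novelty_score(self, x):
        # Per-sample reconstruction error, used as the novelty metric.
        return ((self.forward(x) - x) ** 2).mean(dim=1)

model = Autoencoder(n_features=56)   # 56 selected features, as described in Sect. 3.3
optimizer = torch.optim.Adam(model.parameters(), lr=0.007)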
Variational Autoencoder
For experiments on AmsterdamUMC, we used the following hyperparameters: the encoder and the decoder contained 3 layers each with 25 units, the latent space contained 10 dimensions, and the learning rate was 0.001. Additionally, we used beta-annealing (a deterministic warm-up) [16].
Probabilistic Principal Component Analysis
The only hyperparameter for the PPCA model was the number of components, which was set to 19 for AmsterdamUMC.
Deep Neural Gaussian Process
The hyperparameters used for the AmsterdamUMC data were the following: we used the Matérn 1/2 kernel function with 50 inducing points, the feature-extracting neural network consisted of 4 layers with 256 units, we set the Lipschitz coefficient to 0.5, and the learning rate to 0.002.
Normalizing Flows
The hyperparameters were selected to be the following: 20 layers of composite transformations and reverse permutation, the number of hidden units was 256, we used batch normalization between layers, and the learning rate was set to 0.001.
Local Outlier Factor
The only tunable hyperparameter is the number of closest neighbors, which was set to 5.
References

1. Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent: More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502. 2. Bickel, S., Brückner, M., & Scheffer, T. (2009). Discriminative learning under covariate shift. Journal of Machine Learning Research, 10(75), 2137–2155. 3. Blundell, C., Cornebise, J., Kavukcuoglu, K., & Wierstra, D. (2015). Weight uncertainty in neural networks. In F. Bach & D. Blei (Eds.), Proceedings of the 32nd International Conference on Machine Learning (Vol. 37, pp. 1613–1622). Proceedings of Machine Learning Research. 4. Chen, P.-H. C., Liu, Y., & Peng, L. (2019). How to develop machine learning models for healthcare. Nature Materials, 18(5), 410–414. 5. de Vries, T., Chawla, S., & Houle, M. E. (2010). Finding local anomalies in very high dimensional space. In 2010 IEEE International Conference on Data Mining (pp. 128–137). 6. Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G. (2020). Nflows: Normalizing flows in PyTorch. 7. Finlayson, S. G., Bowers, J. D., Ito, J., Zittrain, J. L., Beam, A. L., & Kohane, I. S. (2019). Adversarial attacks on medical machine learning. Science, 363(6433), 1287–1289.
8. Finlayson, S. G., Subbaswamy, A., Singh, K., Bowers, J., Kupke, A., Zittrain, J., Kohane, I. S., Saria, S. (2021). The clinician and dataset shift in artificial intelligence. New England Journal of Medicine, 385(3), 283–286. Publisher: Massachusetts Medical Society. https://doi.org/10. 1056/NEJMc2104626 9. Fleuren, L. M., Dam, T. A., Tonutti, M., de Bruin, D. P., Lalisang, R. C. A., Gommers, D., Cremer, O. L., Bosman, R. J., Rigter, S., Wils, E. -J., Frenzel, T., Dongelmans, D. A., de Jong, R., Peters, M., Kamps, M. J. A., Ramnarain, D., Nowitzky, R., Nooteboom, F. G. C. A., de Ruijter, W., & Elbers, P. W. G. (2021). The dutch data warehouse, a multicenter and full-admission electronic health records database for critically ill COVID-19 patients. Critical Care, 25(1), 304. 10. Gade, K., Geyik, S. C., Kenthapadi, K., Mithal, V., & Taly, A. (2019). Explainable AI in industry. In Proceedings of the 25th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, KDD ’19 (pp. 3203–3204), New York, NY, USA: Association for Computing Machinery. 11. Gal, Y., & Ghahramani, Z. (2016). Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning (pp. 1050–1059). 12. Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. (2017). On calibration of modern neural networks. In D. Precup, & Y. W. Teh (Eds.), Proceedings of the 34th International Conference on Machine Learning (Vol. 70, pp. 1321–1330). Proceedings of Machine Learning Research. PMLR. 13. Hein, M., Andriushchenko, M., & Bitterwolf, J. (2019). Why ReLU networks yield highconfidence predictions far away from the training data and how to mitigate the problem. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 41–50). 14. Hendrycks, D., Basart, S., Mazeika, M., Mostajabi, M., Steinhardt, J., & Song, D. (2020). Scaling out-of-distribution detection for real-world settings. 15. Hoogeboom, E., Cohen, T. S., & Tomczak, J. M. (2020). Learning discrete distributions by dequantization. arXiv:2001.11235 [cs, stat] 16. Huang, C. -W., Tan, S., Lacoste, A., & Courville, A. C. (2018). Improving explorability in variational inference with annealed variational objectives. In Advances in neural information processing systems (Vol. 31). Curran Associates, Inc. 17. Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. 18. Kirichenko, P., Izmailov, P., & Wilson, A. G. (2020). Why normalizing flows fail to detect out-of-distribution data. arXiv:2006.08545 [cs, stat] 19. Lakshminarayanan, B., Pritzel, A., & Blundell, C. (2017). Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in neural information processing systems (pp. 6402–6413). 20. Lundberg, S., & Lee, S. -I. (2017). A unified approach to interpreting model predictions. 21. Ma, C., Tschiatschek, S., Hernàndez-Lobato, J. M., Turner, R., & Zhang, C. (2020). Vaem: A deep generative model for heterogeneous mixed type data. 22. Miao, Y., Yu, L., & Blunsom, P. (2016). Neural variational inference for text processing. arXiv:1511.06038 [cs, stat] 23. S. Mohseni, N. Zarei, and E. D. Ragan. A multidisciplinary survey and framework for design and evaluation of explainable ai systems, 2020. 24. Moreno-Torres, J. G., Raeder, T., Alaiz-Rodríguez, R., Chawla, N. V., & Herrera, F. (2012). A unifying view on dataset shift in classification. Pattern Recognition, 45(1), 521–530. 25. Mozaffari-Kermani, M., Sur-Kolay, S., Raghunathan, A., & Jha, N. K. (2015). 
Systematic poisoning attacks on and defenses for machine learning in healthcare. IEEE Journal of Biomedical and Health Informatics, 19(6), 1893–1905. 26. Nagler, T., & Czado, C. (2016). Evading the curse of dimensionality in nonparametric density estimation with simplified vine copulas. Journal of Multivariate Analysis, 151. arXiv: 1503.03305 27. Nalisnick, E. T., Matsukawa, A., Teh, Y. W., Görür, D., & Lakshminarayanan, B. (2019). Do deep generative models know what they don’t know? In 7th International Conference on Learning Representations, ICLR 2019, New Orleans, LA, USA, May 6–9, 2019. OpenReview.net.
28. Nazabal, A., Olmos, P. M., Ghahramani, Z., & Valera, I. (2020). Handling incomplete heterogeneous data using VAEs. arXiv:1807.03653 [cs, stat] 29. Nestor, B., McDermott, M. B. A., Chauhan, G., Naumann, T., Hughes, M. C., Goldenberg, A., & Ghassemi, M. (2018). Rethinking clinical prediction: Why machine learning must consider year of care and feature aggregation. 30. Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J., Lakshminarayanan, B., & Snoek, J. (2019). Can you trust your model’s uncertainty? evaluating predictive uncertainty under dataset shift. Advances in Neural Information Processing Systems (Vol. 32). 31. Papamakarios, G., Pavlakou, T., & Murray, I. (2018). Masked autoregressive flow for density estimation. 32. Papangelou, K., Sechidis, K., Weatherall, J., & Brown, G. (2019). Toward an understanding of adversarial examples in clinical trials. In M. Berlingerio, F. Bonchi, T. Gärtner, N. Hurley, & G. Ifrim (Eds.), Machine learning and knowledge discovery in databases (pp. 35–51). Lecture notes in computer science. Cham: Springer International Publishing. 33. Patino, C. M., & Ferreira, J. C. (2018). Inclusion and exclusion criteria in research studies: Definitions and why they matter. Jornal Brasileiro de Pneumologia, 44(2), 84. 34. Ran, X., Xu, M., Mei, L., Xu, Q., & Liu, Q. (2020). Detecting out-of-distribution samples via variational auto-encoder with reliable uncertainty estimation. 35. Rezende, D., & Mohamed, S. (2015). Variational inference with normalizing flows. In F. Bach and D. Blei (Eds.), Proceedings of the 32nd International Conference on Machine Learning (Vol. 37, pp. 1530–1538). Proceedings of machine learning research. Lille, France, 07–09 Jul 2015, PMLR. 36. Saria, S., & Subbaswamy, A. (2019). Tutorial: Safe and reliable machine learning. 37. Shimodaira, H. (2000). Improving predictive inference under covariate shift by weighting the log-likelihood function. Journal of Statistical Planning and Inference, 90(2), 227–244. 38. Sohn, K., Lee, H., & Yan, X. (2015). Learning structured output representation using deep conditional generative models. In C. Cortes, N. Lawrence, D. Lee, M. Sugiyama & R. Garnett (Eds.), Advances in neural information processing systems (Vol. 28). Curran Associates, Inc. 39. Techapanurak, E., & Okatani, T. (2021). Practical evaluation of out-of-distribution detection methods for image classification. 40. Thoral, P. J., Peppink, J. M., Driessen, R. H., Sijbrands, E. J. G., Kompanje, E. J. O., Kaplan, L., Bailey, H., Kesecioglu, J., Cecconi, M., Churpek, M., Clermont, G., van der Schaar, M., Ercole, A., Girbes, A. R. J., & Elbers, P. W. G. (2021). Amsterdam university medical centers database (AmsterdamUMCdb) Collaborators and the SCCM/ESICM joint data science task force. Sharing ICU patient data responsibly under the society of critical care medicine/European society of intensive care medicine joint data science collaboration: The Amsterdam university medical centers database (AmsterdamUMCdb) example. Critical Care Medicine, 49(6), e563–e577. 41. Tipping, M. E., & Bishop, C. M. (1999). Probabilistic principal component analysis. Journal of the Royal Statistical Society. Series B (Statistical Methodology), 61(3), 611–622. 42. Tran, D., Vafa, K., Agrawal, K. K., Dinh, L., & Poole, B. (2019). Discrete flows: Invertible generative models of discrete data. arXiv:1905.10347 [cs, stat]. 43. Ulmer, D., & Ciná, G. (2021). Know your limits: Uncertainty estimation with relu classifiers fails at reliable ood detection. 44. 
Ulmer, D., Meijerink, L., & Cinà, G. (2020). Trust issues: Uncertainty estimation does not enable reliable OOD detection on medical tabular data. In Proceedings of the Machine Learning for Health NeurIPS Workshop (Vol. 136, pp. 341–354). 45. van Amersfoort, J., Smith, L., Jesson, A., Key, O., & Gal, Y. (2021). Improving deterministic uncertainty estimation in deep learning for classification and regression. 46. Wang, Z., & Scott, D. W. (2019). Nonparametric density estimation for high-dimensional data—Algorithms and applications. Wiley Interdisciplinary Reviews: Computational Statistics, 11(4), e1461. arXiv: 1904.00176
47. Winkens, J., Bunel, R., Roy, A. G., Stanforth, R., Natarajan, V., Ledsam, J. R., MacWilliams, P., Kohli, P., Karthikesalingam, A., Kohl, S., Cemgil, T., Eslami, S. M. A., & Ronneberger, O. (2020). Contrastive training for improved out-of-distribution detection. 48. Zhang, L. H., Goldstein, M., & Ranganath, R. (2021). Understanding failures in out-of-distribution detection with deep generative models. arXiv:2107.06908 [cs]
A Robust System to Detect and Explain Public Mask Wearing Behavior

Akshay Gupta and Biplav Srivastava
Abstract COVID-19 is a global health crisis during which mask-wearing has emerged as an effective tool to combat the spread of disease. During this time, non-technical users like health officials and school administrators need tools to know how widely people are wearing masks in public. We present a robust and efficient Mask Adherence Estimation Tool (MAET) based on the pre-trained YOLOv5 object detection model and combine it with explanation methods to help the user understand mask adherence at an individual and aggregate level. We include two novel explanation methods that compute a high-fidelity importance map based on two black-box explanation methods. For our work, we experimented with one-stage and two-stage object detector architectures. Experiment results show that MAET achieves state-of-the-art results on a public face mask dataset, with improvements of 2.3% in precision and 0.4% in recall for face detection and 2.0% in precision and 1.7% in recall for mask detection. We used three different evaluation metrics for explanation and found that no method dominates all metrics; therefore, we support multiple explanation methods.

Keywords Face mask detection · Data augmentation · Explainable AI
1 Introduction

According to the WHO, coronavirus disease 2019 (COVID-19) had globally infected over 239 million people and caused over 4 million deaths by October 14, 2021 (https://covid19.who.int/). This number is increasing day by day. This circumstance pushes the global community to look for preventive measures such as face masks and social distancing to halt the
spread of this infectious virus. Multiple studies have found that wearing face masks can minimize the risk of viral transmission and provide a sense of protection. Coronavirus vaccinations have begun to appear on the market, although they are not available in all parts of the world. Therefore, wearing face masks on a daily basis is a must until the infection is totally eradicated. However, enforcing such a policy and keeping track of violations manually on a broad scale is impossible. As a result, face mask detection has become a critical computer vision problem for assisting the global society, but research on the topic of face mask detection is limited.

Face mask detection is a technique for determining a person's facial position and whether or not they are wearing a mask. The problem is linked to object detection, which is used to recognize various types of objects. The technique of recognizing a certain collection of items, especially faces, is known as face identification. It can be used for a variety of applications, including autonomous driving, education, and monitoring. Computer vision can be used to automate the monitoring of mask-related public policy. Deep learning-based object detectors have recently proven to be extremely effective, and they have dominated modern object detector development. Face detection, like general object detection, employs a similar architecture but adds more face-specific features, such as facial landmarks, as in [9, 34].

For numerous vision tasks, such as image classification, object detection, and semantic segmentation, deep vision models are rapidly surpassing human-level performance. However, they hardly offer any explanation of their decisions. It is critical that models in domains where a decision can have serious consequences (e.g., autonomous driving, medical diagnosis, public-health-related policies) are transparent. Methods like [6, 22, 33, 35] were proposed to address this issue. To compute importance for a given base model and an output category, they need access to the internals of the base model, such as the gradients of the output with respect to the input. In contrast, methods like [18, 20] offer a black-box approach by drawing random samples around the instance to be explained and fitting an approximate linear decision model. However, these techniques have primarily focused on the image classification task.

In this paper we present a robust face-mask detection system, MAET (Mask Adherence Estimation Tool), that can also be used by non-technical users. The proposed system showed promising results on data collected from different sources. For our work, we experimented with different object detector architectures, namely the one-stage YOLOv5 [10] and the two-stage MTCNN [34] with Xception [3]. We also experimented with the possibility of using ResNet [8] as a backbone with YOLOv5. The proposed methods are tested on a face mask dataset [4]. Based on the test set performance, we chose the best-performing model. Experiment results show that MAET achieves state-of-the-art results on a public face mask dataset, with a performance improvement of 2.3% precision and 0.4% recall in face detection and 2.0% precision and 1.7% recall in mask detection. To further check its effectiveness, we evaluated it on other publicly available datasets and scraped images, considering them as the Wild dataset.

We also solved a few practical challenges. The authors of [7] consider the unavailability of sizable datasets containing both masked and unmasked faces to be a significant challenge in face mask detection.
Fig. 1 Screenshots of MAET for LIME explanation; with one (a) and multiple people (b) in the scenes. Example (a) shows that our face mask detector model can differentiate between the covering of the face by mask and by hand
To reduce this shortage of data, we applied transfer learning. To further improve the model performance and make it more robust, we use data augmentation. We used the bounding box augmentation introduced in [36] and experimented with the proposed policies. We also address the problem of providing an explanation for our tool's decisions. We provide two explanation methods that can visually explain the detector's predictions. Both of our explanation methods use the input masking techniques introduced in LIME and RISE, respectively, in a complete black-box manner. We evaluate our explanation results using classification saliency metrics like deletion and insertion [18] and the Pointing Game [33], with suitable adaptation for detection purposes.

Our contributions can be summarized as:

• An integrated system: An end-to-end mask detection and explanation-based system. The entire approach is shown in Fig. 2.
• Explanation for mask detectors: We explored the topic of mask detector explanation, which is currently unexplored, and suggest two black-box explanation approaches that can provide an accurate visual explanation of the proposed mask detection methods while preserving map sanity.
• Interactive individual explanation and aggregate estimation: The tool gives an individual-level explanation with the interactivity of choosing faces and evaluating the outcome. It also includes a composite assessment of how good the mask adherence is in the given image based on customizable parameters, making it useful for non-technical users as well. See Fig. 1 for screenshots.
• Extensive experimentation and evaluation: We extensively experiment with and evaluate our proposed tool. We used transfer learning and data augmentation to deal with data shortage. We tested different detection architectures and evaluated them using precision and recall. We also evaluated our explanation methods utilizing different automated classification saliency measures adapted for mask detectors.

The rest of this paper is arranged in the following manner. We examine related works on face mask detection and explanation algorithms in Sect. 2. In Sect. 3, the
Fig. 2 MAET workflow. Our tool takes an image of arbitrary size and outputs (class + bbox coordinates + objectness score) for each face, where the class represents whether a person is wearing a mask or not. In addition, our tool gives an aggregate result for the image in terms of safety level based on a pre-defined heuristic. Next, the user selects a face target for an individual explanation. For a LIME-based explanation, the set of superpixels contributing most is shown; for a RISE-based explanation, a saliency map is superimposed over the image. The uploaded image is auto-deleted after 20 min
recommended methodology is provided. Datasets, augmentation, and experiment settings are discussed in Sect. 4. The result is analysed qualitatively and quantitatively in Sect. 5. Finally, Sect. 6 brings the paper to a conclusion.
2 Background and Related Work

2.1 Face Mask Detection

With the onset of the COVID pandemic, the face mask's importance was realized as an essential preventive measure. Subsequently, implementations of face mask detection systems emerged but were limited in number. Moreover, they were targeted towards technical AI users and not non-technical users like health officials. Face detection is difficult because faces vary in size, shape, color, etc., and are not immutable.

Reference [19] introduced a method to identify face mask-wearing conditions. Their system takes an image, detects and crops faces, then uses SRCNet [5] to perform image super-resolution and classify them. Reference [4] introduced a new dataset and model architecture. The dataset is composed of WIDER FACE [29] and MAFA [7]. The authors used the structure of SSD [14]. The model has 24 layers in total, counting the location and classification layers. Reference [9] introduced a new architecture, RetinaFaceMask, which uses ResNet [8] as the standard backbone, FPN [11] as the neck, and context saliency modules as the heads. The authors adopt a multi-scale detection strategy similar to SSD to predict with multiple FPN
feature maps. Reference [15] proposed a two-stage model: the first stage uses ResNet50 as a feature extractor, and the second stage is a face mask classifier based on an ensemble of classical machine learning algorithms. The authors evaluated their system and estimated that deep transfer learning approaches would achieve better results.
2.2 Explanation

Backpropagation-based methods. Backpropagation is a technique for tracing information from a network's output back to its input or an intermediate layer. Many attribution techniques leverage this technique for explanation purposes. Reference [32] visualizes the internal representation learned by CNNs using deconvolutional networks. Other approaches [17, 23, 30] have tried to synthesize an input that highly activates a neuron. Many methods combine the layer activation maps with the gradient information to investigate saliency computation. Such approaches include CAM [35], which computes a weighted sum of the feature activation values at each position of an image across all channels to achieve class-specific significance. Grad-CAM [22] uses the average gradient of the class score for each feature map channel to weight the feature activation values at each location. The above approaches require access to the underlying model's internals to acquire feature activation values, gradients, or weights. Several papers have demonstrated that some backpropagation-based methods yield the same saliency map regardless of the output neuron being studied [1] or even the network parameters [16]. As a result, such approaches may be able to characterize average network characteristics but not individual outputs, intermediate activations, or model parameters.

Perturbation-based methods. Another family of approaches perturbs the inputs to a model and observes the resultant changes to the outputs. Occlusion [32] occludes an image using regular occlusion patterns and weights the changes in the output by the occlusion pattern. Meaningful perturbations [6] optimize a spatial perturbation mask that maximally affects a model's output. Other works have leveraged perturbations at the input [24] to perform weakly and fully supervised localization. LIME [20] utilizes a linear classifier to approximate the deep model, trains it around the input point using samples with occluded superpixels, and uses the acquired weights as a measure of superpixel significance. RISE [18] creates a collection of random masks, applies the classifier to masked copies of the input, uses the predicted class probabilities as weights, and computes the saliency map as a weighted sum of the masks.

However, all the methods described above can only explain scalar values, such as class probabilities, as in the case of image classification. An object detector also predicts an object's bounding box location, which requires explanations in addition to class probabilities. Furthermore, it generates several detection proposals per image, which might vary considerably for perturbed image copies. Because of these discrepancies, it is impossible to apply existing classification saliency algorithms
directly. In our work, we address these issues, making it possible to produce saliency or segmentation maps for mask detectors. We use the last two masking approaches, with modifications, to feasibly compute explanations for object detection.
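To make the RISE idea concrete, the sketch below shows a minimal classifier-only version (random masks weighted by the predicted class probability); the adaptation to mask detectors, which replaces the class probability with the similarity score of Sect. 3.2, is not shown, and predict_prob is an assumed black-box function.

import numpy as np

def rise_saliency(image, predict_prob, class_idx, n_masks=1000, p_keep=0.5, cell=8):
    # image: (H, W, 3) float array; predict_prob: masked image -> vector of class probabilities.
    h, w = image.shape[:2]
    saliency = np.zeros((h, w), dtype=np.float32)
    for _ in range(n_masks):
        # Coarse random grid upsampled to image size (RISE itself uses smooth bilinear upsampling).
        grid = (np.random.rand(h // cell + 1, w // cell + 1) < p_keep).astype(np.float32)
        mask = np.kron(grid, np.ones((cell, cell)))[:h, :w]
        prob = predict_prob(image * mask[..., None])
        saliency += prob[class_idx] * mask
    return saliency / (n_masks * p_keep)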
3 Methodology

We propose MAET, the Mask Adherence Estimation Tool, which consists of masked or unmasked face detection together with individual- and aggregate-level explanations. The tool takes an image of arbitrary size as input and outputs bounding box coordinates, an objectness score, and a class prediction for each face, along with an aggregate mask adherence score. Then, based on the chosen explanation method, the tool provides a high-fidelity visual explanation of the predicted bounding box for the selected target, either as a segmentation of regions or as a saliency map superimposed over the image, making it interactive. The visual explanation takes into consideration both the localization and classification aspects. The complete workflow is depicted in Fig. 2.
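The aggregate score is produced by a configurable heuristic whose exact rule is not specified here, so the following is only a hypothetical sketch of what such a heuristic could look like, with made-up thresholds and field names.

def adherence_level(detections, conf_threshold=0.5):
    # detections: list of dicts with 'score' (objectness) and 'cls' ('mask' or 'no_mask').
    faces = [d for d in detections if d["score"] >= conf_threshold]
    if not faces:
        return "no faces detected"
    masked_fraction = sum(d["cls"] == "mask" for d in faces) / len(faces)
    if masked_fraction >= 0.9:
        return "safe"
    if masked_fraction >= 0.5:
        return "moderately safe"
    return "unsafe"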
3.1 Detection Architecture

We used object detector frameworks to create an effective network for face mask detection. An object detector extracts features from input images, which are then sent to a prediction system that predicts bounding boxes and a class for each object in the image. We experimented with the one-stage YOLOv5 [10] and the two-stage MTCNN [34] with Xception [3] object detector architectures.

YOLOv5. YOLOv5 is a single-stage detector. The YOLOv5 network consists of a backbone, a neck, and a head. The backbone is a general feature extractor made up of convolutional neural networks that turn the information in images into feature maps. In YOLOv5, the first layer of the network performs pixel un-shuffling to enhance speed. In MAET, we chose CSPNet (Cross Stage Partial Network) [28] as the standard backbone. CSPNet addresses duplicate gradient problems in other larger ConvNet backbones, resulting in fewer parameters and fewer FLOPS. For comparison, we also used ResNet [8] as the backbone. The model neck consists of a series of layers that mix and combine image features before passing them forward to prediction; it is used to generate feature pyramids, which help models generalize well to object scaling. We used PANet [13] as the neck. It enhances the entire feature hierarchy with accurate localization signals in lower layers through bottom-up path augmentation, which shortens the information path between lower layers and topmost features. PANet uses FPN [11] as its backbone. Bottom-up path augmentation is used to reduce the information path and enhance the feature pyramid with accurate localization signals found at low levels. Then, a primary component called adaptive feature pooling is utilized to aggregate features from all feature levels for each proposal. Finally, the model head
Finally, the model head performs the final detection: it consumes features from the neck, applies anchor boxes to them, and predicts the final output with a class, an objectness score, and bounding box coordinates.

MTCNN with Xception. In this method, we implement a two-stage detector. We use the Multi-task Cascaded Convolutional Network (MTCNN) [34] for face detection and then pass each bounding box to an Xception [3] model for classification. MTCNN is a deep cascaded multi-task framework that exploits the inherent correlation between face detection and alignment to boost performance. It adopts a cascaded structure with three deep convolutional networks that predict face and landmark locations in a coarse-to-fine manner. P-Net is a fully convolutional network that obtains candidate windows and their bounding box regression vectors; the estimated regression vectors are used to calibrate the candidates, followed by non-maximum suppression (NMS). The resized candidate boxes from P-Net are fed to R-Net, which rejects a significant number of erroneous candidates. Finally, the remaining bounding boxes pass to O-Net; this stage is identical to the second, but its goal is to additionally output five facial landmark positions. The Xception model is based on the Inception [25] model, with the Inception modules replaced by depthwise separable convolutions; in the modified depthwise separable convolution, a depthwise convolution follows the pointwise convolution. From now on, we refer to MTCNN+Xception as Approach-1, YOLOv5+ResNet as Approach-2, and YOLOv5+CSPNet as Approach-3.
3.2 Explanation Architecture

We propose two explanation methods that provide explanations in a black-box manner. We use the input perturbation-based methods LIME and RISE to measure the effect of masking the input on the detector's output. These methods were designed for classification and hence cannot be used directly on a detection model: in contrast to a classifier, a face detector additionally outputs bounding box coordinates. Our detection model takes an image (I) of arbitrary size and outputs a detection vector (Di) for each face, consisting of the object localization Li = {xlow, ylow, xhigh, yhigh}, a face confidence Oi ∈ [0, 1] representing the confidence of the detected face, and a class Ci. A particular bounding box is then selected, and the explanation method is run for it. To this end, we define a similarity score (ss) between the chosen bounding box detection vector (DTB) and the predicted bounding box detection vectors on the masked input (PB), in terms of both localization and classification. For localization we use Intersection over Union (IoU) as the measure, and for classification we use the object confidence. The implementation is shown in Algorithm 1.

Local Interpretable Model-Agnostic Explanations (LIME). LIME focuses on training local surrogate models to explain individual predictions. It generates a new dataset consisting of perturbed samples and the corresponding predictions of the black-box model. This new dataset trains an interpretable model weighted by the proximity
Algorithm 1 Similarity Score
Require: DTB, PB; IoU(Di, Dj) ← IoU(Li, Lj)
1: Dj = argmax_{Dk ∈ PB} IoU(Dk, DTB)    \\ Localization aspect
2: if Cj == CTB then
3:     ss ← IoU(DTB, Dj) × Oj             \\ Classification aspect
4: else
5:     ss ← 0                             \\ Taking only positive contribution
6: end if
7: return ss
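A minimal Python sketch of Algorithm 1 is given below. The detection layout (a dict with `box`, `conf`, and `cls` fields) is our own illustrative assumption, not the structure used in the MAET implementation.

```python
def iou(box_a, box_b):
    # Intersection-over-Union of two boxes given as (x_low, y_low, x_high, y_high)
    xa, ya = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    xb, yb = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, xb - xa) * max(0.0, yb - ya)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def similarity_score(target, proposals):
    """Similarity score between the chosen detection and the detections on a masked input.

    `target` and each proposal are dicts with keys 'box', 'conf', 'cls' (hypothetical layout).
    """
    if not proposals:
        return 0.0
    # Localization aspect: proposal with the highest IoU against the chosen target box
    best = max(proposals, key=lambda d: iou(d["box"], target["box"]))
    # Classification aspect: only a same-class match contributes (positive contribution only)
    if best["cls"] == target["cls"]:
        return iou(best["box"], target["box"]) * best["conf"]
    return 0.0
```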
of the sampled instances to the instance of interest. Variations of the image are made by segmenting it into superpixels and turning superpixels on or off, i.e., by replacing each pixel in a superpixel with a user-defined color. For LIME, we employ the Quickshift [26] technique, which returns a segmentation mask based on a kernelized mean-shift approximation. New samples are created by drawing random numbers from a normal distribution with user-specified mean and standard deviation, followed by binary thresholding that turns superpixels on or off based on a user-defined threshold. The similarity scores serve as labels for a linear regressor, whose learned weights give an importance weight for each superpixel. We used the top-3 weighted superpixels for explanation purposes.

Randomized Input Sampling for Explanation of Black-Box Models (RISE). RISE is a black-box explanation method in which the base model is probed by randomly masking the input and recording its response to each masked image. The final saliency map is created as a weighted sum of the random binary masks, with the combination weights derived from the base model's outputs on the masked images. In our case, N masks are generated and a similarity score is obtained for each; the final saliency map is the weighted average of the generated masks, weighted by their similarity scores. Masks are generated as described in [18] (a code sketch of this procedure follows the list):
• Sample N binary masks of size h × w (smaller than the image size H × W) by setting each element independently to 1 with probability p and to 0 with probability 1 − p.
• Upsample all masks to size (h + 1)CH × (w + 1)CW using bilinear interpolation, where CH × CW = H/h × W/w is the size of a cell in the up-sampled mask.
• Crop areas of size H × W with uniformly random indents ranging from (0, 0) up to (CH, CW).
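A rough NumPy/OpenCV sketch of this mask-generation and aggregation step is shown below; the function names, the use of OpenCV for bilinear up-sampling, and the normalisation choice are our own assumptions rather than details taken from the RISE implementation.

```python
import numpy as np
import cv2  # assumed available for bilinear up-sampling


def generate_rise_masks(n_masks, H, W, h=7, w=7, p=0.5, seed=0):
    """Random masks in the spirit of RISE [18]; all names are illustrative."""
    rng = np.random.default_rng(seed)
    cell_h, cell_w = int(np.ceil(H / h)), int(np.ceil(W / w))
    up_h, up_w = (h + 1) * cell_h, (w + 1) * cell_w
    masks = np.empty((n_masks, H, W), dtype=np.float32)
    for i in range(n_masks):
        grid = (rng.random((h, w)) < p).astype(np.float32)          # small binary grid
        big = cv2.resize(grid, (up_w, up_h), interpolation=cv2.INTER_LINEAR)
        dy, dx = rng.integers(0, cell_h), rng.integers(0, cell_w)    # random indent
        masks[i] = big[dy:dy + H, dx:dx + W]                         # crop back to H x W
    return masks


def rise_saliency(masks, scores):
    # Weighted sum of masks; weights are the similarity scores of the masked inputs.
    # Normalising by the total weight is one of several reasonable choices.
    weights = np.asarray(scores, dtype=np.float32)
    return np.tensordot(weights, masks, axes=1) / (weights.sum() + 1e-8)
```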
3.3 Transfer Learning

Due to the limited availability of face mask data, it is difficult for learning algorithms to learn good features from scratch; deep learning-based methods are data-hungry.
According to [31], transfer learning can help such learning in a significant way. In our work, CSPNet and PANet are pre-trained on the MS COCO dataset [12], Xception on the ImageNet dataset [21], and MTCNN on the WIDER FACE dataset [29].
3.4 Individual and Aggregate Explanation

Each detected person can now be explained. The user can select a person to obtain a specific explanation. For the LIME-based explanation, users can vary the top-3 explanation features; for each feature, a segment mask boundary is generated representing the part of the image contributing towards the prediction. For the RISE-based explanation, a saliency map is generated. We also provide rules to convert mask-wearing in the scene into an aggregate safety estimate for the whole image, based on the ratio of people wearing masks to the total number of people detected: < 50% for low, 50–75% for medium, and above 75% for high.
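A small helper implementing this rule, written as an illustrative sketch (the function name and the handling of the empty case are our assumptions):

```python
def aggregate_safety(num_masked, num_detected):
    """Rule-based aggregate mask-adherence estimate described above (illustrative helper)."""
    if num_detected == 0:
        return "undefined"  # assumption: no people detected in the frame
    ratio = 100.0 * num_masked / num_detected
    if ratio < 50:
        return "low"
    elif ratio <= 75:
        return "medium"
    return "high"
```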
4 Experimentation

4.1 Dataset

The Face Mask Dataset [4] contains 7959 images with either a mask or no mask annotated on the faces. It is a composite of the Wider Face [29] and MAsked FAces (MAFA) [7] datasets. In this dataset, some faces are hidden by hands or other objects rather than actual masks, which gives it an advantage over other face mask datasets. We also used other publicly available datasets to test our tool's robustness, considering them as the Wild dataset: it includes the Face Mask Detection and Medical Mask [27] datasets, which contain 853 and 678 images, respectively. We further added 400 new annotated images containing diverse background situations to our wild testing dataset. For classification purposes we chose images with a single person per frame, along with some images cropped from multi-person frames. The classification dataset contains a total of 4102 images. Examples from the datasets are shown in Fig. 3.
4.2 Augmentation

Data augmentation is used to improve the diversity of data without having to collect additional data. For each training image, four augmented images are generated. Reference [36] states that bbox operations, i.e., augmentations that act only on the contents within each bounding box, are the most effective for object detectors.
Fig. 3 The first row represents images from the dataset [4]. The second row has examples of wild-dataset images that we have taken from different domains with diverse scenarios. Notice the diversity in types of masks worn by people
We therefore experimented with the given policies. We have five policies, each consisting of 2 operations applied in sequence to a single image. Additionally, each operation is associated with two hyperparameters specifying the probability of applying it and its magnitude. In our inspection, Rotate, which rotates the entire picture together with the bounding boxes, is the operation most often employed in effective policies. After the rotation, the bounding boxes grow larger to encompass the entire rotated object. Equalize and BBox-Only-TranslateY are two more operations that are often utilized. Equalize flattens the pixel-value histogram. With BBox-Only-TranslateY, only the objects in bounding boxes are translated vertically, up or down, with equal probability. Augmentations are generated with the help of Albumentations [2].
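For illustration, a minimal Albumentations pipeline in the spirit of the operations above might look as follows; the specific operations, probabilities, and magnitudes here are placeholders, not the exact policies used in our experiments.

```python
import albumentations as A

# Illustrative augmentation pipeline (Sect. 4.2); bbox_params keeps boxes consistent
# with the transformed image. Values are arbitrary examples, not tuned policies.
transform = A.Compose(
    [
        A.Rotate(limit=15, p=0.5),   # rotates the image and its bounding boxes together
        A.Equalize(p=0.5),           # flattens the pixel-value histogram
    ],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["class_labels"]),
)

# Usage (assumed variable names):
# augmented = transform(image=image, bboxes=bboxes, class_labels=labels)
```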
4.3 Experiment Setup

For the experiments, we employed Stochastic Gradient Descent (SGD) with an input size of 640 × 640 and a batch size of 4 for 100 epochs. The dataset was split into train, validation, and test sets with 4590, 1530, and 1839 images. Two augmented images are generated for each training image, so the final training set consists of 9180 images. The pre-trained Xception network was fine-tuned on the classification dataset for 10 epochs, achieving an accuracy of 99.12% on the whole dataset. For LIME, the size of the neighborhood was set to 1000 samples. For RISE, 3000 random masks were generated. Training was done on an NVIDIA Tesla K80 GPU. These numbers were chosen to balance computational cost against metric scores.
Fig. 4 Examples of detection and explanation output (panels a–f: Target 0; panel g: Target 1). The first row is the output of Approach-1. Each bounding box is enumerated for explanation reference. The second row represents the LIME explanation with the top-3 weighted superpixels. The third row represents the RISE explanation, where the red intensity represents the importance in the saliency map
5 Evaluation

5.1 Qualitative Results

Qualitative results for mask detection with Approach-3, together with explanations for both LIME and RISE, are shown in Fig. 4. From the explanations in Fig. 4, we can see that the model has learned the correct semantics for the mask-wearing task. Figure 4c shows that our detection model cannot be fooled by merely covering the face, and the explanation methods were able to pick up the correct features from the image to justify this. From the explanations in Fig. 4b, we can infer that coverage of the facial region surrounding the nose is a decisive feature for the model. This observation can also help in regulating public health policies such as mask-wearing. Supporting images are provided in Fig. 5. Our technique maintains its performance across a variety of scenarios, including clear faces (Fig. 4a), blurred faces (Fig. 4e) and tilted faces (Fig. 4e). Our explanation methods expose the rich, semantically decisive features the model learned during training, and our tool provides precise bounding boxes and explanation maps for both small and large faces in the images. Due to its dependency on superpixels, the LIME explanation tends to cover a slightly larger image area than RISE. These are highly promising results for face mask detection and for the explainability of modern deep mask detectors.
Fig. 5 Not wearing a mask properly results in assignment to the unmasked class. Hence, the approach can help in regulating mask-wearing more efficiently. Such a model can be used to verify adherence to health policies. Specifically, one can estimate adherence in video frames and, if it reaches above certain thresholds, the situation can be brought to the attention of the concerned authorities for requisite action
5.2 Quantitative Results

5.2.1 Face Mask Detection

We use Precision (P) and Recall (R) as evaluation metrics:

P = True Positives / (True Positives + False Positives)
R = True Positives / (True Positives + False Negatives)
The performance of our methods is compared with the public baseline result published by the creator of the dataset [4] and with RetinaFaceMask [9], using the annotations included in the dataset. The evaluation results are shown in Table 1, which shows that all three approaches outperform the baseline. Approach-3 achieves the state of the art compared to RetinaFaceMask: it achieves 2.3% higher precision and 0.4% higher recall than RetinaFaceMask in face detection, and 2.0% higher precision and 1.7% higher recall in mask detection. We also evaluate our methods on the Wild dataset; the results are shown in Table 2, which shows that the performance of the methods does not deteriorate much on the wild data. We use Approach-3 in our tool because of its superior performance.
Table 1 Evaluation of mask detectors in terms of precision and recall on the test dataset

Model               | Face Precision (%) | Face Recall (%) | Mask Precision (%) | Mask Recall (%)
Baseline [4]        | 89.6               | 85.3            | 91.9               | 88.6
RetinaFaceMask [9]  | 91.9               | 96.3            | 93.4               | 94.5
Approach-1          | 90.5               | 86.9            | 90.8               | 88.8
Approach-2          | 93.1               | 94.3            | 94.2               | 93.7
Approach-3          | 94.2               | 96.7            | 95.4               | 96.2
Table 2 Evaluation of mask detectors in terms of precision and recall on the Wild dataset

Model       | Face Precision (%) | Face Recall (%) | Mask Precision (%) | Mask Recall (%)
Approach-1  | 86.9               | 84.3            | 87.1               | 83.1
Approach-2  | 90.1               | 91.0            | 92.6               | 92.9
Approach-3  | 93.4               | 94.2            | 94.7               | 95.5
Table 3 Evaluation of explanation maps in terms of pointing game (higher is better), deletion (lower is better) and insertion (higher is better) scores on the test dataset

Method | Pointing game (↑) | Insertion (↑) | Deletion (↓)
LIME   | 0.979             | 0.671         | 0.133
RISE   | 0.961             | 0.713         | 0.118
5.2.2 Explanation

For the explanations, we used two automatic evaluation metrics: Deletion and Insertion [18]. The deletion metric is based on the assumption that, as more and more significant pixels are removed (with significance taken from the importance map), the probability of the predicted class should drop; a small Area Under the Curve (AUC) indicates a good explanation. The insertion metric adopts the complementary approach: it measures the increase in probability as additional pixels are added, with a higher AUC indicating a better explanation. We adapted these metrics using the similarity score. We also report the Pointing Game score [33]: a hit is scored if the point of maximum saliency lies within the ground-truth object bounding box; otherwise, a miss is counted. The pointing game measures the accuracy of saliency maps as the number of hits over the total number of hits and misses. For LIME, we use the most heavily weighted superpixel and count a hit if any of its pixels lies within the ground-truth box; similarly, for insertion and deletion, superpixels are turned on or off according to their weights. The evaluation results on the test dataset are shown in Table 3. LIME performs better than RISE on the pointing game by a margin of 0.018, while RISE performs better on the other two scores by margins of 0.042 and 0.015, respectively. As no method dominates on all three metrics, we support multiple explanation methods.
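A short sketch of the pointing game check, under the assumption that the saliency map is a 2-D array and the ground-truth box is given as (x_low, y_low, x_high, y_high):

```python
import numpy as np


def pointing_game_hit(saliency_map, gt_box):
    """Pointing game [33]: hit if the most salient pixel falls inside the ground-truth box.

    gt_box = (x_low, y_low, x_high, y_high); names are illustrative.
    """
    y, x = np.unravel_index(np.argmax(saliency_map), saliency_map.shape)
    x_low, y_low, x_high, y_high = gt_box
    return x_low <= x <= x_high and y_low <= y <= y_high
```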
5.3 Sanity Check Recently, a concern regarding the validity of saliency methods has been raised [1] demonstrated that when the weights of the model that each saliency technique claims to describe are randomized, the outputs of specific commonly recognized saliency
methods do not change appreciably. When relying solely on visual judgment, such confirmation bias may result in an erroneous human assessment. We performed a model parameter randomization test to address this issue. Our findings show that weight randomization produces incomprehensible saliency maps, implying that our approach relies on the trained model's information.
6 Conclusion

In this paper, we presented a practical tool for detecting face masks and explaining the detections. The detector uses a pre-trained YOLOv5 model together with two novel sanity-preserving explanation methods, based on LIME and RISE, for explaining the detector's results. We experimented with several architectures, explanation methods and data augmentation techniques to make the tool as robust as possible. The tool will be open-sourced so that anyone can access and extend it. Authorities around the world can use the tool by uploading an image and viewing the results together with the explanation. In times of the COVID-19 pandemic, with the world looking to return to normalcy and people resuming in-person work, such monitoring of face masks in public places will make them safer. Beyond COVID, this tool can easily be adapted to other public policy applications such as helmet wearing and smoking.
References 1. Adebayo, J., Gilmer, J., Muelly, M., Goodfellow, I. J., Hardt, M., & Kim, B. (2018). Sanity checks for saliency maps. 2. Buslaev, A., Iglovikov, V. I., Khvedchenya, E., Parinov, A., Druzhinin, M., & Kalinin, A. A. (2020). Albumentations: Fast and flexible image augmentations. Information, 11(2). https:// doi.org/10.3390/info11020125, https://www.mdpi.com/2078-2489/11/2/125 3. Chollet, F. (2017). Xception: Deep learning with depthwise separable convolutions. 4. Chiang, D. (2020). Detect faces and determine whether people are wearing mask. https:// github.com/AIZOOTech/FaceMaskDetection 5. Dong, C., Loy, C. C., & Tang, X. (2016). Accelerating the super-resolution convolutional neural network 6. Fong, R. C., & Vedaldi, A. (2017). Interpretable explanations of black boxes by meaningful perturbation. In 2017 IEEE ICCV. 7. Ge, S., et al. (2017). Detecting masked faces in the wild with LLE-CNNs. In 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 8. He, K., et al. (2015). Deep residual learning for image recognition. 9. Jiang, M., et al. (2020). Retinamask: A face mask detector. 10. Jocher, G., et al. (2021). ultralytics/yolov5: v4.0—nn.SiLU() activations, Weights & Biases logging, PyTorch Hub integration. 11. Lin, T. Y., et al. (2017). Feature pyramid networks for object detection. 12. Lin, T. Y., et al. (2015). Microsoft coco: Common objects in context. 13. Liu, S., et al. (2018). Path aggregation network for instance segmentation. 14. Liu, W., et al. (2016). SSD: Single shot multibox detector. ECCV.
15. Loey, M., et al. (2020). A hybrid deep transfer learning model with machine learning methods for face mask detection in the era of the covid-19 pandemic. 16. Mahendran, A., & Vedaldi, A. (2016). Salient deconvolutional networks. 17. Nguyen, A. M., Dosovitskiy, A., & Yosinski, J., et al. (2016). Synthesizing the preferred inputs for neurons in neural networks via deep generator networks 18. Petsiuk, V., Das, A., & Saenko, K. (2018). Rise: Randomized input sampling for explanation of black-box models. 19. Qin, B., et al. (2020). Identifying facemask-wearing condition using image super-resolution with classification network to prevent covid-19. https://doi.org/10.21203/rs.3.rs-28668/v1 20. Ribeiro, M. T., Singh, S., & Guestrin, C. (2016). Why should i trust you?: Explaining the predictions of any classifier. 21. Russakovsky, O., Deng, J., Su, H., & Krause, J., et al. (2009). Imagenet large scale visual recognition challenge 22. Selvaraju, R. R., & Das, A., et al. (2019) Grad-cam: Visual explanations from deep networks via gradient-based localization. International Journal of Computer Vision. 23. Simonyan, K., et al. (2014). Deep inside convolutional networks: Visualising image classification models and saliency maps. 24. Singh, K. K., Lee, Y. J. (2017). Hide-and-seek: Forcing a network to be meticulous for weakly-supervised object and action localization. 25. Szegedy, C., et al. (2014). Going deeper with convolutions. 26. Vedaldi, A., & Soatto, S. (2008). Quick shift and kernel methods for mode seeking. In ECCV. 27. Waghe, S. (2020). Medical mask detection. 28. Wang, C., et al. (2019). CSPNet: A new backbone that can enhance learning capability of CNN. CoRR. arxiv:abs/1911.11929 29. Yang, S., et al. (2015). Wider face: A face detection benchmark. 30. Yosinski, J., Clune, J., & Nguyen, A. M., et al. (2015). Understanding neural networks through deep visualization. 31. Zamir, A., et al. (2018). Taskonomy: Disentangling task transfer learning. 32. Zeiler, M. D., & Fergus, R. (2014). Visualizing and understanding convolutional networks. 33. Zhang, J., Lin, Z. L., & Brandt, J., et al. (2017). Top-down neural attention by excitation backprop. 34. Zhang, K., et al. (2016). Joint face detection and alignment using multitask cascaded convolutional networks. IEEE Signal Processing Letters. 35. Zhou, B., Khosla, A.,& Lapedriza, À., et al. (2016). Learning deep features for discriminative localization. 36. Zoph, B., Cubuk, E. D., & Ghiasi, G., et al. (2020). Learning data augmentation strategies for object detection.
A Federated Cox Model with Non-proportional Hazards D. Kai Zhang , Francesca Toni , and Matthew Williams
Abstract Recent research has shown the potential for neural networks to improve upon classical survival models such as the Cox model, which is widely used in clinical practice. Neural networks, however, typically rely on data that are centrally available, whereas healthcare data are frequently held in secure silos. We present a federated Cox model that accommodates this data setting and also relaxes the proportional hazards assumption, allowing time-varying covariate effects. In this latter respect, our model does not require explicit specification of the time-varying effects, reducing upfront organisational costs compared to previous works. We experiment with publicly available clinical datasets and demonstrate that the federated model is able to perform as well as a standard model. Keywords Survival analysis · Federated learning · Non-proportional hazards
1 Introduction Estimating how long patients might live for is a key task in clinical medicine, and is a common question from patients. Survival analysis is the statistical branch used to perform these estimates, which can range in its application from predicting death following diagnosis to loan defaults or machine part failures. Amongst survival models, the Cox model [6] is one of the most widely used. Machine learning techniques have received attention for their potential to improve upon the performance of the Cox model. Many recent efforts [13, 16, 20] have D. K. Zhang (B) Imperial College London, SW7 2BX, London, UK e-mail: [email protected] F. Toni Imperial College London, SW7 2BX, London, UK e-mail: [email protected] M. Williams Imperial College London, SW7 2BX, London, UK e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_12
exploited neural networks (NNs) to model more complex relationships as well as enable typically unsupported input data types such as images [4, 17, 27, 28]. Notwithstanding, the Cox model has remained the standard in survival analysis [24]. Indeed, the adoption of machine learning has progressed haltingly in many areas of healthcare [14]. One challenge lies in the distributed nature of healthcare data [25]. In much of machine learning, data are centralised, whereas privacy concerns often result in secure data “silos” in healthcare. Federated learning (FL) accommodates this decentralised data environment and has shown promise in clinical contexts [23]. Despite a fast emerging literature in FL, there has been scant work on federated survival analysis. Reference [1] propose a federated Cox model that is closest to this paper. The standard Cox model is, however, limited in that it can only correctly model proportional hazards. We take an alternative approach allowing us to embed time-varying covariate effects (non-proportional hazards) directly in the architecture, potentially reducing organisational setup costs for federations. Such effects are relevant to adapt models for patients such as those with breast cancer where the proportional hazards assumption has been shown to be violated [3, 5, 11]. In the following, we briefly discuss relevant background and highlight related work (Sect. 2), before defining our model (Sect. 3), instantiating it with different hazards assumptions and presenting our experiments with real-world clinical datasets (Sect. 4). Section 5 concludes with potential directions for future work.
2 Background and Related Work

The promise of greater control over data ownership and enhanced privacy that FL affords has generated interest in the healthcare community. Few works, however, have investigated the intersection between survival analysis and FL. We present background on each of these areas separately before discussing their intersection relevant for this work.

Background on Federated Learning (FL). FL is a framework for decentralised data that cannot be shared due to their sensitive content or prohibitive communication costs [21]. In the context of healthcare, patient data may be kept in this way by the clinical unit (e.g., the hospital) at which the patient was treated. In the following, we will simply refer to these data-keeping units as federation members or centres. Typically (and in this paper), the federated objective is to minimise L_F(X, φ) with respect to φ, with:

L_F(X, \phi) = \sum_{k \in K} w_k \mathcal{L}_k(X_k, \phi)    (1)
where L F represents the global loss: an average of the local losses Lk computed by the federation members in K on their own data X k weighted by wk , where φ
represents the model parameters. Typically, each member customises φ for a number of local optimisation rounds before aggregating the customised φ for a new global consensus model.

Background on Survival Analysis. Survival analysis estimates the time to an event for a population N with data D = {(x_i, t_i, s_i)}_{i∈N}, where each person i has covariates x_i = (x_{i1}, . . . , x_{ip})^T, a time of observation t_i and an indicator s_i ∈ {0, 1} which equals 1 if i has experienced the event or 0 if not, i.e., if i is censored. The Cox model [6] is one of the most widely used survival models. It defines a hazard function h, which expresses the rate of failure at time t subject to survival until then as follows:

h(t \mid x_i) = P(T = t \mid T > t - 1) = h_0(t) \exp[g(x_i)], \quad \text{with } g(x_i) = \beta^\top x_i    (2)
where h_0(t) is some baseline hazard and where β = (β_1, . . . , β_p)^T is a coefficient vector. Later works replace the linear predictor β^T x_i with NNs g_φ(x_i), demonstrating competitive performance [9, 13]. The coefficients are estimated by minimising the negative partial log-likelihood given by:

-\sum_{i \in N} s_i \Big[ g(x_i) - \log\Big( \sum_{j \in R_i} \exp[g(x_j)] \Big) \Big]    (3)
where R_i = {j ∈ N : t_j ≥ t_i} denotes the individuals who are still at risk when i experiences the event. In a federated setting, this loss generally cannot be decomposed into local losses due to the logarithmic term, as the risk set R_i can contain individuals from centres other than the one of i. This therefore does not match the formulation of Eq. 1. The hazard function also assumes proportional hazards (PH)—differences in covariates result in constant proportional differences in hazards. Over long time horizons, this can be restrictive [2, 3, 16].

State-of-the-Art in Federated Survival Analysis. Our work is situated in the intersection of federated learning and survival analysis and proposes a novel framework. A handful of works have already proposed such frameworks of which we provide a brief overview here. The works of [8, 19] embody one approach which relies on substantial sharing of summary statistics over the local datasets in every training iteration. This differs in spirit from FL where more abstract parameters are shared and, often, infrequently so. Moreover, their models are based on linear predictors and do not address integration with NNs. Recent work by [1] is closest to our approach. Their model exploits a discretisation of the Cox model (also by [6]) with an NN-based predictor:

\frac{h(t \mid x)}{1 - h(t \mid x)} = \frac{h_0(t)}{1 - h_0(t)} \exp[g_\phi(x_i)]    (4)
which can be rewritten in a sigmoid form:

h(t \mid x_i) = \frac{1}{1 + \exp[-(\alpha_t + g_\phi(x_i))]}    (5)

where \alpha_t = \log \frac{h_0(t)}{1 - h_0(t)}. They follow [7] in estimating this function like a logistic regression with negative log-likelihood:

-\sum_{i \in N} \sum_{k=1}^{t_i} \big[ y_{ik} \log[h(k \mid x_i)] + (1 - y_{ik}) \log[1 - h(k \mid x_i)] \big]    (6)
where yi j = 1{t j = ti , si = 1}. Importantly, this loss does not depend on risk sets and is therefore separable—each centre’s loss only depends on local data—recovering the federated objective (Eq. 1). Reference [1] demonstrate that a federation of this model can draw even in performance with a model trained on pooled data. This is, however, only shown with aggregation after every local optimisation round—a setup that may need to differ in practice [23]—and assuming PH, as their predictor gφ (xi ) is time-invariant. We note that non-PH can be admitted to their model by including time interactions (giving gφ (xi , f (t)))—an approach sometimes taken in standard Cox models—as demonstrated on pooled data by [7]. This could, however, introduce a dependency on the specification of f (t) and its interactions. Crucially, this may add to the organisational setup costs of a federation: even though interactions could be learned, f (t) needs to be fully specified and agreed upon in advance. In contrast, we follow [10] by making the choice between PH and non-PH a binary decision over the architecture of the output layer.
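Because this loss is a sum of per-individual binary cross-entropy terms, each centre can evaluate it on its local data alone. A minimal PyTorch sketch of such a per-centre loss is given below; the tensor layout and helper name are illustrative assumptions, not code from [1] or from our implementation.

```python
import torch


def local_discrete_nll(hazards, event_times, events):
    """Per-centre negative log-likelihood in the spirit of Eq. (6)/(7) (minimal sketch).

    hazards:     (n, T) predicted h(j|x_i) for discrete intervals j = 1..T
    event_times: (n,) index of the interval containing t_i (0-based)
    events:      (n,) 1 if the event was observed, 0 if censored
    """
    n, T = hazards.shape
    steps = torch.arange(T, device=hazards.device).expand(n, T)
    # y_ij = 1 only in the interval of an observed event
    y = (steps == event_times.unsqueeze(1)).float() * events.unsqueeze(1).float()
    # individuals contribute only up to (and including) their observation interval
    observed = (steps <= event_times.unsqueeze(1)).float()
    bce = y * torch.log(hazards.clamp_min(1e-7)) + \
          (1 - y) * torch.log((1 - hazards).clamp_min(1e-7))
    return -(bce * observed).sum()
```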
3 Model

We build upon the discretised Cox model and detail how the PH assumption is relaxed and formulate the federated objective. We describe a discretisation procedure and an optional interpolation scheme for smooth predictions. Lastly, we outline two complementary performance metrics.

Non-proportional Hazards. We use a discretised Cox model (Eq. 5) but parameterised with a time-varying, NN-based predictor g_{φ,t}(x). Following [10], we allow for non-PH by fully connecting the output layer to the previous layer. The output layer thus encodes time-varying covariate effects in time-specific weights. A sigmoid is used to retrieve the hazard rates. For PH, the output component is split into a first layer with a single neuron and no bias. The output of this neuron is passed into a second layer with as many neurons as
Fig. 1 Output components for 3 time steps
time steps. This captures time-varying baseline hazards in the second layer and time-invariant covariate effects in the first. The difference in components is illustrated in Fig. 1.

Federated Objective. To conform to a federated formulation (Eq. 1), we split the objective (Eq. 6):

L_F = \sum_{k \in K} w_k \mathcal{L}_k(X_k, \phi) = -\sum_{k \in K} \frac{|N_k|}{|N|} \sum_{i \in N_k} \sum_{j=1}^{t_i} \big[ y_{ij} \log[h(j \mid x_i)] + (1 - y_{ij}) \log[1 - h(j \mid x_i)] \big]    (7)
where y_{ij} = 1{t_j = t_i, s_i = 1}. Each centre calculates \mathcal{L}_k(X_k, φ) on its own subset of the population N_k. We adapt the FedAvg algorithm of [21] to minimise this loss (Algorithm 1).

Discretisation. The model operates on discretised time, so that t indexes into a set of intervals [τ_{t−1}, τ_t). Following [15] we discretise on Kaplan-Meier quantiles. Defining the survival curve S(τ) = S(τ − 1)(1 − h(τ)), the quantiles {τ_1, τ_2, . . . , τ_m} can be obtained as:

S(T = \tau_j) - S(T = \tau_{j+1}) = \frac{1 - S(T = \tau_{max})}{m}    (8)
for j = {0, 1, . . . , m − 1}. This discretisation procedure yields a set of steps {τ1 , τ2 , . . . , τm } where each step results in the same decrease in survival (an illustration is provided in Appendix A). Interpolation. To smooth step-wise predictions, we use constant density interpolation [15]. Letting S(τ ) denote the interpolation of the survival curve S(τ ), we then have:
Algorithm 1 Procedure to optimise the federated objective
1: Initialise global model with φ_0
2: for each round t = 1, . . . , T do
3:     for each centre k = 1, . . . , K in parallel do
4:         Send φ_{t−1} to centre k
5:         for each local round b = 1, . . . , B do
6:             Local update φ_k^t ← φ_k^t − λ∇L_k(X_k, φ_k^t)
7:         end for
8:         Receive φ_k^t from centre k
9:     end for
10:    Aggregate φ_t ← \sum_{k∈K} (|N_k|/|N|) φ_k^t
11: end for
12: return φ_T

\tilde{S}(\tau) = S(\tau_{j-1}) + \big[ S(\tau_j) - S(\tau_{j-1}) \big] \frac{\tau - \tau_{j-1}}{\tau_j - \tau_{j-1}}    (9)
for a given time τ ∈ (τ_{j−1}, τ_j]. Intuitively, the step survival curve is linearly interpolated between any adjacent steps, resulting in constant densities in the corresponding interval (an illustration is provided in Appendix A).

Performance Metrics—Concordance. We use the time-dependent concordance index [2], or c-index, which is a discriminative measure for how well the model ranks the relative survival between patient pairs, expressed as:

P\big( S(t_i \mid x_i) < S(t_i \mid x_j) \;\&\; t_i < t_j \;\&\; s_i = 1 \big)    (10)

which is estimated as follows:

c = \frac{\sum_{i \in N} \sum_{j \in N, j \ne i} conc_{ij}}{\sum_{i \in N} \sum_{j \in N, j \ne i} comp_{ij}}    (11)

comp_{ij} = 1\{t_i < t_j \;\&\; s_i = 1\} + 1\{t_i = t_j \;\&\; s_i = 1 \;\&\; s_j = 0\}    (12)

conc_{ij} = 1\{S(t_i \mid x_i) < S(t_i \mid x_j)\} \cdot comp_{ij}    (13)
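The following is a naive O(n²) NumPy sketch of this estimator; the argument layout (a matrix of survival predictions evaluated at each individual's event time) is our assumption for illustration.

```python
import numpy as np


def time_dependent_c_index(surv_at_event, times, events):
    """Estimator of Eqs. (11)-(13) (naive O(n^2) sketch).

    surv_at_event[i, j]: predicted S(t_i | x_j), i.e. survival of individual j
    evaluated at individual i's observation time t_i.
    """
    n = len(times)
    conc, comp = 0.0, 0.0
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            comparable = (times[i] < times[j] and events[i] == 1) or \
                         (times[i] == times[j] and events[i] == 1 and events[j] == 0)
            if comparable:
                comp += 1
                conc += float(surv_at_event[i, i] < surv_at_event[i, j])
    return conc / comp if comp > 0 else float("nan")
```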
Performance Metrics—Calibration. While the c-index measures the discriminative performance of the model, it does not measure how well calibrated these estimates are (an illustration is provided in Appendix A). As a measure of calibration, we follow [12] who propose a Brier score for use with censored data, defined as follows:

BS(t) = \frac{1}{|N|} \sum_{i \in N} w_i(t) \big( y_i(t) - h(t \mid x_i) \big)^2    (14)
w_i(t) = \begin{cases} s_i / G(t), & \text{if } t_i \le t \\ 1 / G(t), & \text{if } t_i > t \end{cases}    (15)
where yi (t) = 1{ti = t} and G(t) is the Kaplan-Meier estimate of the censoring distribution (i.e., estimated on {(xi , ti , 1 − si )}i∈N ). To measure calibration across the entire time horizon, we numerically integrate the Brier score using 100 time points [16].
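A compact NumPy sketch of this censored Brier score at a single time point, following Eqs. (14)–(15); the callable `km_censor`, standing in for the Kaplan-Meier censoring estimate G(t), is an assumed interface.

```python
import numpy as np


def censored_brier_score(hazard_at_t, times, events, t, km_censor):
    """Brier score of Eqs. (14)-(15) at a single time t (illustrative sketch).

    hazard_at_t: (n,) model outputs h(t|x_i) at time t
    km_censor:   callable returning the Kaplan-Meier estimate G(t) of censoring
    """
    y = (times == t).astype(float)                     # y_i(t) = 1{t_i = t}
    w = np.where(times <= t, events / km_censor(t),    # s_i / G(t) if t_i <= t
                 1.0 / km_censor(t))                   # 1 / G(t)   if t_i >  t
    return np.mean(w * (y - hazard_at_t) ** 2)
```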
4 Experiments

We introduce the datasets we experiment on, describe the setup of a simulated federation, instantiate our model with three different linearity and hazards assumptions, and present our results.¹
4.1 Datasets

We experiment on three clinical datasets (Table 1; for Kaplan-Meier curves see Appendix A) made available by [13], namely the Molecular Taxonomy of Breast Cancer International Consortium (METABRIC), the Study to Understand Prognoses Preferences Outcomes and Risks of Treatment (SUPPORT), and the Rotterdam tumour bank and German Breast Cancer Study Group (GBSG). METABRIC and GBSG both relate to breast cancer patients, a group for whom non-PH have been noted [3, 5], while SUPPORT presents serious hospitalisations for a second application area.
Table 1 Overview of datasets

Dataset   | Size  | Features | Prop. censored (%) | Last event
METABRIC  | 1,904 | 9        | 42                 | 355 days
SUPPORT   | 8,873 | 14       | 32                 | 1,944 days
GBSG      | 2,232 | 7        | 43                 | 83 days

¹ For source code, see https://github.com/dkaizhang/federated-survival.
4.2 Setup

We simulate two federated data cases. In the first, data are randomly distributed (“IID”), simulating the case of each centre seeing a similar sample of the patient population. In the second, data are stratified on the time to event (“Non-IID”), simulating the case that each centre sees a non-overlapping quantile of the population—from centre 1 seeing only the shortest survivals through to centre 4 seeing the longest survivals (Fig. 2). For comparability, we maintain the total number of local training rounds at 100. A pooled data baseline is provided (“Pooled”—no distinction between local and global rounds). In all cases, 80% of the overall data are split, if federated, and used for training, while 20% are held out for evaluation. We instantiate the model with different choices for g(x)—with a linear predictor or with an NN, with and without PH (Table 2). For baselines, we considered the works of [1] and [7]. The former assumes PH, however, while the latter is a pooled data model. Both require upfront agreement on a specification of f(t) to include non-PH, adding to the setup costs of a federation. We further note that no implementations of these models are available. We therefore provide the NN PH model to approximate the model of [1]—a federated NN-based Cox model with PH—and the Linear PH model as a standard baseline. Architectures (Fig. 1) are implemented in PyTorch 1.8.0 [22] with two hidden layers of 32 neurons for NN models and none for the linear model.
Fig. 2 Event time distribution for IID and non-IID data using stratification by event time on METABRIC (panels: a IID, b Non-IID)

Table 2 Model choices

Model     | Predictor
Linear PH | β^⊤ x
NN PH     | g_φ(x)
NN nonPH  | g_{φ,t}(x)
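To make the difference between the two NN output components concrete, the sketch below shows one way the PH and non-PH heads of Fig. 1 could be written in PyTorch; it is an illustrative reading of the architecture description, not the released implementation.

```python
import torch
import torch.nn as nn


class DiscreteHazardHead(nn.Module):
    """Sketch of the output components in Fig. 1 (illustrative, not the authors' code).

    Non-PH: every time step gets its own weights (fully connected output layer).
    PH: a single, time-invariant linear predictor combined with time-specific terms only.
    """

    def __init__(self, in_features, n_time_steps, proportional_hazards=False):
        super().__init__()
        self.ph = proportional_hazards
        if self.ph:
            self.predictor = nn.Linear(in_features, 1, bias=False)  # time-invariant g_phi(x)
            self.baseline = nn.Linear(1, n_time_steps)               # per-time-step alpha_t
        else:
            self.predictor = nn.Linear(in_features, n_time_steps)    # time-varying g_{phi,t}(x)

    def forward(self, features):
        logits = self.baseline(self.predictor(features)) if self.ph else self.predictor(features)
        return torch.sigmoid(logits)  # discrete hazards h(t|x) for t = 1..n_time_steps
```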
Optimisation uses Adam with grid-sought learning rates (10^-1 to 10^-5, selected on 20% of the training data) and a batch size of 256. The base-case discretisation uses 10 time steps.
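For concreteness, one global round of the FedAvg-style training loop (Algorithm 1) could be sketched as follows; the centre interface and the reuse of the `local_discrete_nll` helper from the earlier sketch are assumptions made for illustration, and SGD is used here simply to mirror the gradient step of Algorithm 1.

```python
import copy
import torch


def fedavg_round(global_model, centres, local_rounds, lr):
    """One global round of Algorithm 1 (sketch).

    `centres` is an assumed list of (dataloader-like iterable, n_samples) pairs,
    not an interface from the paper's code.
    """
    n_total = sum(n_k for _, n_k in centres)
    new_state = {k: torch.zeros_like(v) for k, v in global_model.state_dict().items()}
    for data_k, n_k in centres:
        local = copy.deepcopy(global_model)                # send phi_{t-1} to centre k
        opt = torch.optim.SGD(local.parameters(), lr=lr)
        for _ in range(local_rounds):                      # B local rounds
            for features, event_times, events in data_k:
                opt.zero_grad()
                loss = local_discrete_nll(local(features), event_times, events)
                loss.backward()
                opt.step()
        for name, v in local.state_dict().items():         # weighted aggregation
            new_state[name] += (n_k / n_total) * v
    global_model.load_state_dict(new_state)
    return global_model
```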
4.3 Results

In this section, we first compare the performance of the three models trained in a centralised fashion on pooled data against their federated performance on decentralised data. We next provide additional experiments exploring the impact of the discretisation grid chosen for the base case. We report averaged 5-fold cross-validation performance throughout.

Federated Performance. On pooled data, the NN nonPH model outperforms or ties in concordance (Table 3) and ties with the best in calibration (Table 4), indicating a gain from the relaxation of the PH assumption. For METABRIC and GBSG this aligns with [3, 5, 11] who find non-PH amongst this patient group. Comparing this to the federated setting with IID data, all three models maintain their performance (within one standard deviation) in concordance and calibration when aggregation is frequent. First hints of performance degradation amongst the NN-based models occur as aggregation becomes very infrequent (rightmost columns)
Table 3 C-index (rebased to 100)—mean and standard deviation. Higher values are better

Pooled:
Model     | METABRIC | SUPPORT  | GBSG
Linear PH | 63.5±1.4 | 57.2±1.0 | 66.5±2.1
NN PH     | 64.0±0.6 | 60.8±0.6 | 66.2±2.6
NN nonPH  | 66.7±2.1 | 61.5±1.2 | 66.6±1.9

IID — global/local rounds:
Model     | METABRIC 100/1 | 20/5     | 1/100    | SUPPORT 100/1 | 20/5     | 1/100    | GBSG 100/1 | 20/5     | 1/100
Linear PH | 63.9±0.8       | 64.0±2.2 | 63.7±1.7 | 57.2±0.8      | 57.2±0.8 | 57.2±0.4 | 66.5±1.5   | 66.3±0.6 | 66.3±1.4
NN PH     | 63.8±1.6       | 62.5±1.9 | 63.3±1.2 | 60.6±1.0      | 60.9±0.9 | 58.3±1.4 | 67.4±1.7   | 67.1±1.2 | 63.5±2.6
NN nonPH  | 65.4±1.9       | 65.7±1.4 | 61.5±2.1 | 62.1±0.7      | 62.4±0.4 | 56.7±2.3 | 66.5±1.0   | 66.4±0.9 | 62.8±1.9

Non-IID — global/local rounds:
Model     | METABRIC 100/1 | 20/5     | 1/100    | SUPPORT 100/1 | 20/5     | 1/100    | GBSG 100/1 | 20/5     | 1/100
Linear PH | 59.2±3.0       | 59.8±2.1 | 61.0±1.3 | 55.4±1.5      | 56.1±0.9 | 56.2±0.5 | 61.1±0.9   | 61.9±3.0 | 56.5±6.9
NN PH     | 60.9±1.1       | 59.4±2.5 | 57.3±4.2 | 57.0±1.4      | 56.7±1.1 | 52.9±0.8 | 61.3±4.3   | 61.8±2.9 | 53.0±6.2
NN nonPH  | 57.9±2.9       | 59.6±3.9 | 54.6±5.3 | 50.8±0.6      | 50.9±0.9 | 50.8±1.2 | 57.6±1.9   | 55.8±2.0 | 52.0±2.5
D. K. Zhang et al.
Non-IID
IID
Pooled
Table 4 Integrated Brier scores (rebased to 100)—mean and standard deviation. Lower values are better METABRIC SUPPORT GBSG Global/local rounds Global/local rounds Global/local rounds Data Model 100/1 20/5 1/100 100/1 20/5 1/100 100/1 20/5 1/100 Linear 16.4± PH 0.6 NN PH 16.8± 0.7 NN 16.4± nonPH 0.8 Linear 16.3± PH 1.1 NN PH 17.2± 1.1 NN 16.3± nonPH 1.3 Linear 18.3± PH 0.8 NN PH 18.3± 0.5 NN 20.9± nonPH 0.3
16.7± 0.9 18.1± 0.5 16.5± 0.6 18.0± 0.6 18.1± 0.2 20.1± 0.5
16.5± 0.7 18.1± 1.1 19.1± 0.8 19.3± 0.7 19.7± 1.0 21.1± 1.0
20.9± 0.5 19.6± 0.4 19.6± 0.4 20.9± 0.6 19.7± 0.7 19.5± 0.7 22.9± 0.7 22.6± 0.5 26.2± 0.4
20.9± 0.3 19.8± 0.2 19.4± 0.4 22.1± 0.3 21.9± 0.2 25.3± 0.5
20.9± 0.6 22.6± 1.3 21.1± 0.6 21.4± 0.3 24.0± 1.8 25.3± 0.7
18.0± 0.5 18.2± 0.8 18.0± 0.4 18.1± 0.2 17.7± 0.3 18.3± 0.5 20.4± 0.5 20.5± 0.6 23.6± 0.9
18.1± 0.2 17.8± 0.4 18.2± 0.6 20.1± 0.5 20.6± 0.4 22.4± 0.8
18.1± 0.3 19.8± 1.9 21.7± 1.5 22.1± 0.3 22.3± 0.5 22.7± 1.0
while the Linear PH model appears largely unaffected. This observation is noteworthy, as infrequent aggregation will be a likely feature in practice given communication costs. This indicates a potential trade-off between model complexity and the achievable aggregation frequency needed to support its training. In practice, data are likely to be non-IID across centres. The results show that the performance of all three models suffers when this is the case. Generally, the NN-based models experience the most severe losses in performance and are largely outperformed by the Linear PH model. When aggregation is infrequent, the NN-based models on SUPPORT and GBSG effectively approach a no-skill predictor in concordance (average c-index of 0.5). Further, performance losses under infrequent aggregation of the NN-based models are, as would be expected, worse than under IID data. On SUPPORT, the NN-based models exhibit much greater performance differences than on the other datasets. In this respect, we note that SUPPORT has much longer survival times than METABRIC or GBSG (Table 1), so that stratification by event time likely results in a more significantly different partition of the data for the former than for the latter two.

Impact of Discretisation Fineness. We re-train models on finer time grids using 100 global and 1 local rounds. A finer grid on METABRIC (Fig. 3 upper panel) and GBSG (Fig. 3 lower panel) did not improve performance and, in fact, appears to
Fig. 3 Model performance with increasing discretisation fineness. Federated models were trained with 100 global and 1 local rounds. Performance decreases on smaller METABRIC and GBSG datasets with mixed results on the larger SUPPORT dataset
degrade performance. Notably, the Linear PH model becomes a no-skill predictor in terms of concordance on the non-IID GBSG case. The results are less conclusive on SUPPORT (Fig. 3 middle panel), as a finer time grid appears to result in a minor to no increase in concordance at the expense of a loss in calibration. A finer time grid can be expected to result in a trade-off between closer approximation of true (smooth) survival and a reduction in data available in any given time step. An increase from 10 to 20 time steps, for instance, halves the number of available data points to estimate a given step. The latter effect appears to dominate on the smaller METABRIC and GBSG datasets, and less so for the approximately 4-times larger SUPPORT dataset.
5 Conclusion

We present a federated Cox model that relaxes the proportional hazards (PH) assumption and demonstrate its ability to maintain concordance and calibration relative to a pooled baseline under various linearity and PH assumptions. Compared to prior work, this federation scheme encodes the decision between PH and non-PH in a binary choice over the output layer, rather than requiring upfront agreement on a specification of f(t). We note that our model is not restricted to a particular data type or network architecture excepting the output component. Future work could adapt the model for image-based federated survival predictions. The decrease in performance on non-IID data (even if pathologically derived in this paper) represents a challenge to the application of federated learning in practice. Extensions could include exploring methods accounting for statistical heterogeneity [18, 26] or other federation topologies which maintain locally specialised models trained in a peer-to-peer fashion [23]. While the heterogeneity in this paper was derived from label stratification, other types of heterogeneity, such as covariate shifts, could be explored: for image-based survival predictions, differences in acquisition protocols could provide one such avenue.

Acknowledgements This work was supported by the UKRI CDT in AI for Healthcare http://ai4health.io (Grant No. EP/S023283/1).
Appendix A Additional Figures See Figs. 4, 5, 6.
Fig. 4 Discretisation and interpolation
Fig. 5 Two sets of survival estimates with correct ranking (green above blue) but poor calibration given under-/overestimation of true survival curves
Fig. 6 Kaplan-Meier estimates with 95% confidence interval
References 1. Andreux, M., Manoel, A., Menuet, R., Saillard, C., & Simpson, C. (2020). Federated survival analysis with discrete-time cox models (pp. 1–21). 2. Antolini, L., Boracchi, P., & Biganzoli, E. (2005). A time-dependent discrimination index for survival data. Statistics in Medicine, 24(24), 3927–3944. https://doi.org/10.1002/sim.2427 3. Bellera, C. A., MacGrogan, G., Debled, M., De Lara, C. T., Brouste, V., & Mathoulin-Pélissier, S. (2010). Variables with time-varying effects and the Cox model: Some statistical concepts illustrated with a prognostic factor study in breast cancer. BMC Medical Research Methodology, 10. https://doi.org/10.1186/1471-2288-10-20 4. Bello, G. A., Dawes, T. J., Duan, J., Biffi, C., de Marvao, A., Howard, L. S., Gibbs, J. S. R., Wilkins, M. R., Cook, S. A., Rueckert, D., & O’Regan, D. P. (2019). Deep-learning cardiac motion analysis for human survival prediction. Nature Machine Intelligence, 1(2), 95–104. https://doi.org/10.1038/s42256-019-0019-2 5. Coradini, D., Daidone, M. G., Boracchi, P., Biganzoli, E., Oriana, S., Bresciani, G., Pellizzaro, C., Tomasic, G., Di Fronzo, G., & Marubini, E. (2000). Time-dependent relevance of steroid receptors in breast cancer. Journal of Clinical Oncology, 18(14), 2702–2709. https://doi.org/ 10.1200/JCO.2000.18.14.2702 6. Cox, D. R. (1972). Regression models and life-tables. Journal of the Royal Statistical Society. Series B (Methodological), 34(2), 187–202. https://doi.org/10.1111/j.2517-6161.1972. tb00899.x, https://www.jstor.org/stable/2985181 7. Craig, E., Zhong, C., & Tibshirani, R. (2021). Survival stacking: Casting survival analysis as a classification problem (pp. 1–17). arXiv:abs/2107.13480 8. Dai, W., Jiang, X., Bonomi, L., Li, Y., Xiong, H., & Ohno-Machado, L. (2020). VERTICOX: Vertically distributed cox proportional hazards model using the alternating direction method of multipliers. IEEE Transactions on Knowledge and Data Engineering, 4347(c), 1. https:// doi.org/10.1109/tkde.2020.2989301 9. Faraggi, D., & Simon, R. (1995). A neural network model for survival data. Statistics in Medicine, 14(1), 73–82. https://doi.org/10.1002/sim.4780140108 10. Gensheimer, M. F., Narasimhan, B. (2019). A scalable discrete-time survival model for neural networks. PeerJ, 1–19. https://doi.org/10.7717/peerj.6257 11. Gore, S. M., Pocock, S. J., & Kerr, G. R. (1984). Regression models and non-proportional hazards in the analysis of breast cancer survival author. Journal of the Royal Statistical Society. Series C (Applied Statistics), 33(2), 176–195. 12. Graf, E., Schmoor, C., Sauerbrei, W., & Schumacher, M. (1999). Assessment and comparison of prognostic classification schemes for survival data. Statistics in Medicine, 18(17–18), 2529– 2545. https://doi.org/10.1002/(sici)1097-0258(19990915/30)18:17/183. 0.co;2-5 13. Katzman, J. L., Shaham, U., Cloninger, A., Bates, J., Jiang, T., & Kluger, Y. (2018). DeepSurv: Personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC Medical Research Methodology, 18(1), 1–11. https://doi.org/10.1186/s12874018-0482-1 14. Kelly, C. J., Karthikesalingam, A., Suleyman, M., Corrado, G., & King, D. (2019). Key challenges for delivering clinical impact with artificial intelligence. BMC Medicine, 17(1), 1–9. https://doi.org/10.1186/s12916-019-1426-2 15. Kvamme, H., & Borgan, O. (2019). Continuous and discrete-time survival prediction with neural networks. 16. Kvamme, H., Borgan, O., & Scheel, I. (2019). 
Time-to-event prediction with neural networks and cox regression. Journal of Machine Learning Research, 20, 1–30. 17. Li, H., Boimel, P., Janopaul-Naylor, J., Zhong, H., Xiao, Y., Ben-Josef, E., & Fan, Y. (2019). Deep convolutional neural networks for imaging data based survival analysis of rectal cancer (pp. 1–4). 18. Li, T., Sahu, A. K., Zaheer, M., Sanjabi, M., Talwalkar, A., & Smith, V. (2018). Federated optimization in heterogeneous networks. arxiv:abs/1812.06127
19. Lu, C. L., Wang, S., Ji, Z., Wu, Y., Xiong, L., Jiang, X., & Ohno-Machado, L. (2015). WebDISCO: A web service for distributed cox model learning without patient-level data sharing. Journal of the American Medical Informatics Association, 22(6), 1212–1219. https://doi.org/ 10.1093/jamia/ocv083 20. Luck, M., Sylvain, T., Cardinal, H., Lodi, A., & Bengio, Y. (2017). Deep learning for patientspecific kidney graft survival analysis (Nips 2017). arxiv:abs/1705.10245 21. McMahan, H. B., Moore, E., Ramage, D., Hampson, S., & Agüera y Arcas, B. (2017). Communication-efficient learning of deep networks from decentralized data. In Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, AISTATS 2017 (Vol. 54). 22. Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., Desmaison, A., Köpf, A., Yang, E., DeVito, Z., Raison, M., Tejani, A., Chilamkurthy, S., Steiner, B., Fang, L., Bai, J., & Chintala, S. (2019). PyTorch: An imperative style, high-performance deep learning library. Advances in Neural Information Processing Systems, 32(NeurIPS). 23. Rieke, N., Hancox, J., Li, W., Milletarì, F., Roth, H. R., Albarqouni, S., Bakas, S., Galtier, M. N., Landman, B. A., Maier-Hein, K., Ourselin, S., Sheller, M., Summers, R. M., Trask, A., Xu, D., Baust, M., & Cardoso, M. J. (2020). The future of digital health with federated learning. NPJ Digital Medicine, 3(1), 1–7. https://doi.org/10.1038/s41746-020-00323-1 24. Wang, P., Li, Y., & Reddy, C. K. (2019). Machine learning for survival analysis: A survey. ACM Computing Surveys, 51(6). https://doi.org/10.1145/3214306 25. Wiens, J., Saria, S., Sendak, M., Ghassemi, M., Liu, V. X., Doshi-Velez, F., Jung, K., Heller, K., Kale, D., Saeed, M., Ossorio, P. N., Thadaney-Israni, S., & Goldenberg, A. (2019). Do no harm: A roadmap for responsible machine learning for health care. Nature Medicine, 25(September). https://doi.org/10.1038/s41591-019-0548-6 26. Yang, L., Beliard, C., & Rossi, D.: Heterogeneous data-aware federated learning (1). 27. Zhu, X., Yao, J., & Huang, J. (2017). Deep convolutional neural network for survival analysis with pathological images. In Proceedings—2016 IEEE International Conference on Bioinformatics and Biomedicine, BIBM 2016 (Vol. 2, pp. 544–547). https://doi.org/10.1109/BIBM. 2016.7822579 28. Zhu, X., Yao, J., Zhu, F., & Huang, J. (2017). WSISA: Making survival prediction from whole slide histopathological images. In Proceedings—30th IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2017 (pp. 6855–6863). https://doi.org/10.1109/CVPR.2017. 725
A Step Towards Automated Functional Assessment of Activities of Daily Living Bappaditya Debnath, Mary O’brien, Swagat Kumar, and Ardhendu Behera
Abstract Current activity recognition approaches have achieved a great success due to the advancement in deep learning and the availability of huge public benchmark datasets. These datasets focus on highly distinctive actions involving discriminative body movements, body-object and/or human-human interactions. However, in real-world scenarios, e.g., functional assessment of a rehabilitation task, which requires the capability of differentiating the execution of same activities performed by individuals with different impairments, their recognition accuracy is far from being satisfactory. To address this, we develop Functional-ADL, a challenging novel dataset to take action recognition to a new level. Compared to the existing datasets, Functional-ADL is distinguished in multi-label and impaired-specific executions of different Activities of Daily Living (ADL) to contribute towards vision-based automated assessment and rehabilitation of physically impaired persons. We also propose a novel pose-based two-stream multi-label activity recognition model consisting of a spatial and a temporal stream. The proposed approach significantly outperforms the state-of-the-art by a considerable margin. This new Functional-ADL dataset presents significant challenges for human activity recognition, and we hope this could advance research towards activity understanding and monitoring. Keywords Physical rehabilitation · Functional activity recognition · Computer vision · Deep learning · Body-pose sequence · Fisher vectors
B. Debnath (B) · M. O’brien · S. Kumar · A. Behera Edge Hill University, St Helen’s Road, Ormskirk L394QP, UK e-mail: [email protected] M. O’brien e-mail: [email protected] S. Kumar e-mail: [email protected] A. Behera e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_13
1 Introduction

Activity recognition is an important and challenging problem in computer vision, with many applications linking assistive and rehabilitative robotics to health and social care services. This research aims to contribute to this area, where there has been increased interest in using vision-based human motion understanding for the rehabilitation and assessment of physically impaired patients [27]. Physically impaired people (e.g., those affected by stroke, spinal cord injury, etc.) often experience problems with physical movement and balance. As a result, they face difficulties in performing day-to-day tasks, known as Activities of Daily Living (ADL) [9]. To recover, improve, or avoid further loss of physical functionality, such patients undergo physical rehabilitation programs [9] involving repetitive therapeutic exercises or ADL. These activities are usually guided by health professionals (clinicians, occupational therapists and physiotherapists) at home or in a clinic [9]. The assessment part of this rehabilitation process is often carried out via direct observation, which requires the observer to note down the detailed movements of the patient performing a given ADL. The process is time-consuming, laborious and often requires significant attention from the observer. It could be automated by using vision-based autonomous systems that can recognise and evaluate the difference between normal and impaired physical activities, with the potential to lower the cognitive load on the observers as well as the time and overall cost (Fig. 1). Functional assessment through ADL is widely used for assessing a patient's condition and progress. To measure it, there exist various methods [12] that rely on complex manual assessment. The current study is only the first step towards automating the functional assessment of various ADL. The main aim of this study is to recognise different ADL (e.g., eating, drinking, etc.) as well as their impairment-specific variations executed by individuals with physical impairments (e.g., ataxia, elbow rigidity, etc.). This will help in monitoring the extent of their progress and
(a) Reaching Above: Normal, Elbow Rigidity, and Shoulder Weakness
(b) Walking: Normal, Knee Rigidity, and Wider Gait
Fig. 1 We introduce a novel multi-label functional ADL dataset consisting of 10 activities and four different physical impairment-specific executions of each ADL for fine-grained activity recognition in rehabilitation videos. a “Reaching Above” activity is executed by a normal person, an individual with ‘Elbow Rigidity’, and a person with “Shoulder Weakness” impairment (left to right). b Similarly, “Walking” activity is performed ‘Normally’, with ‘Knee Rigidity’ and ‘Wider Gait’
achievement while undergoing impairment-specific rehabilitation. In recent times, computer vision researchers have focused on Deep Learning (DL) [1, 16, 28] to improve human activity recognition accuracy. This is possible due to the availability of large-scale datasets. However, for vision-based ADL assessment and rehabilitation, authors have mainly used their own small in-house datasets [4, 25, 27, 37]. The lack of suitable publicly available datasets is one of the reasons why there has been less participation from the vision community in developing models for solving such problems. To address this, we present a novel multi-label functional ADL dataset that is targeted towards automated functional assessment of physically impaired persons through ADL. Physically impaired persons perform an ADL differently from healthy individuals, resulting in a different spatio-temporal pattern which depends on the type of impairment they are suffering from. For example, a person with tremors would shake his/her hand while drinking water, and the spatio-temporal trajectory would differ from that of a drinking action without tremors. The existing human activity datasets (Table 1) are not appropriate for developing and validating solutions targeting this issue. Thus, this study presents a dataset that consists of a normal and various physical impairment-specific executions of the same ADL. The proposed dataset contains 5685 samples of 10 common ADL performed by 10 subjects, captured in video, depth and human body-pose sequence format. For each ADL, the dataset presents one normal and four different physical impairment-specific executions. Thus, each sample has two labels, one for the 'Activity' (e.g., drinking, walking, etc.) and the other for the 'Impairment' (e.g., normal, ataxic, etc.), hence the name multi-label functional ADL dataset. Furthermore, we present a
Table 1 Comparison of the proposed dataset with other popular activity recognition datasets. The proposed multi-label functional ADL recognition dataset represents a normal and four different physical impairment-specific executions for each ADL. R: RGB, D: Depth, P: Pose

Datasets | #Videos | #Activities | #Impairments | #Subjects | Data modalities
MSRDailyActivity3D [35] | 320 | 16 | 0 | 10 | R, D, P
UTKinect [39] | 200 | 10 | 0 | 10 | R, D, P
MSR-Action3D [20] | 567 | 20 | 0 | 10 | R, D, P
CAD-60 [32] | 60 | 12 | 0 | 4 | R, D, P
CAD-120 [17] | 120 | 20 | 0 | 4 | R, D, P
Northwestern-UCLA [36] | 1475 | 10 | 0 | 10 | R, D, P
NTU-RGBD [28] | 60 K | 60 | 0 | 40 | R, D, P
Charades [30] | 10 K | 157 | 0 | 267 | R
NTU-RGBD 120 [21] | 120 K | 120 | 0 | 106 | R, D, P
Toyota Smart Home [5] | 16 K | 51 | 0 | 18 | R, D, P
UA-Concurrent [38] | 201 | 35 | 0 | NA | R, D, P
Ours | 5865 | 10 | 8 | 10 | R, D, P
novel pose-based two-stream multi-label activity recognition model based on TCN-ResNet [16] that comprehensively outperforms the TCN-ResNet on our dataset and on the well-known NTU-RGBD ADL recognition dataset [28]. The two-stream architecture, inspired by [6], consists of a spatial and a temporal stream. The spatial stream contains a Spatial Encoding Unit (SEU), which provides an enriched representation that learns to capture the structural relationships between various body joints in a given video frame. Similarly, the temporal stream includes a Temporal Encoding Unit (TEU) that learns to encode the temporal relationship of each body joint over the duration of a given sequence. The performance of the network is further enhanced by a Fisher Vector (FV) based, activity-aware, learnable pooling mechanism introduced at the end of each stream to replace the Global Average Pooling (GAP) in the TCN-ResNet [16]. Our novel contributions are: (1) A novel functional ADL recognition dataset that presents a normal and four different physical impairment-specific versions of each ADL. (2) A pose-based (skeleton) two-stream functional ADL recognition model that integrates a spatial-temporal body-pose encoding mechanism with FV-based pooling in a novel manner.
2 Related Works
Datasets: Major advances have been made in human activity recognition, influenced by the availability of large-scale datasets. Well-known datasets in this domain are shown in Table 1. However, the existing datasets are largely targeted towards normal human activity recognition and are not suitable for the functional assessment of physically impaired patients through ADL. These datasets consist of different normal ADL but do not capture various physical impairment-specific versions of the same ADL (Table 1, Column 4). Most of the datasets in Table 1 are single-label datasets, whereas we present a multi-label dataset in which there are two labels ('Activity' and 'Impairment') for each sample. The Charades [30] and UA-Concurrent [38] datasets present multiple labels for a single activity sample, but these are multi-label normal activities. The NTU-RGBD 120 dataset [21] presents 12 medical conditions including neck pain, fall, etc. However, it is a single-label dataset which does not demonstrate the difference between impairment-specific executions of the same ADL. To the best of our knowledge, ours is the first dataset that illustrates the difference between various physical impairment-specific versions of the same ADL. Pose-based activity recognition: The availability of cheap depth sensors (e.g., Microsoft Kinect) has significantly influenced pose-based activity recognition. Processing 3D pose (body skeleton) information is computationally much less expensive than RGB video processing and thus researchers have increasingly relied on pose-based methods for human activity recognition. Most works in this area have explored recurrent networks such as RNNs, LSTMs and GRUs, which are specially designed for processing sequential information such as the trajectories of human body joints in a given activity. Liu et al. [22] advanced the human tree-structure to model spatio-temporal features learned from a modified gating mechanism of the LSTM. Song
et al. [31] introduced LSTM-based spatial and temporal networks with attention mechanisms. Temporal Convolutional Networks (TCNs), which are stacks of 1D convolutional layers, have been explored as an alternative to recurrent mechanisms. Lea et al. [18] proposed a TCN with an encoder-decoder and a dilated convolution model for activity recognition. An LSTM cell processes each time-step sequentially, whereas no such constraint exists within a TCN, which makes TCNs inherently faster than LSTMs. Kim et al. [16] present a pose-based TCN-ResNet model which combines residual connections with a TCN and is shown to be computationally inexpensive without compromising recognition accuracy on the NTU-RGBD dataset [28]. The proposed model is inspired by this lightweight architecture, which is necessary for home-based or in-clinic assessment of patients where high-performance computational facilities (e.g., servers, GPUs, etc.) are not available. However, pose-based models do not benefit from contextual cues such as hand-object interactions or background information beyond the body pose. Thus, authors have focused on enriching the pose information with physics-based measurements such as velocities and accelerations [7, 42], different normalisation techniques [42], relative body joint positions [15], etc. Instead of handcrafting such features, [6] uses an SEU and a TEU to automatically learn enhanced representations that can capture structural information and various inter-joint dependencies of the human body joints. Inspired by this, we adapt the TCN-ResNet model [16] to use a spatial-temporal architecture involving SEU and TEU layers to advance light-weight human activity recognition approaches. Learnable-Pooling: To further enhance the performance of our model, an FV-based pooling mechanism is used that replaces the GAP layer typically present towards the end of many standard convolutional architectures, including TCN-ResNet [16]. Similar to GAP, the literature presents other statistical pooling methods such as average or max-pooling [13, 14], rank-pooling [8], context-gating [24] and high-dimensional feature encoding [40]. Pooling using statistical methods does not consider spatial-temporal and other semantic information in the feature maps produced by a CNN, TCN or LSTM. Thus, learnable pooling methods have been explored by researchers to pool the most relevant features based on learned representations. In [10], the authors present a second-order attentional pooling, in which the output map from a CNN is multiplied with a weighted version of itself. A well-known technique called VLAD for image feature representation is integrated by Girdhar et al. [11] for learnable pooling-based activity recognition. In [24], the authors introduce a learnable FV (NetFV) to semantically cluster and pool audio and video features by integrating it into a deep model. In this study, NetFV is adapted for semantically clustering the information present in TCN-ResNet maps, which further enhances the model performance.
3 Dataset
Human motion manifests in a wide variety of forms and so do its abnormalities. It is not feasible to capture the whole range of ADL and their corresponding impairments. The idea is to prepare a dataset that meets the following constraints: (1) The
dataset should contain enough samples, uniformly spread across subjects, activities and impairments, to suffice for the needs of DL-based models. (2) The dataset should contain enough activities to collectively cover a wide range of body movements and capture a few common abnormalities. To assess a patient's condition and to determine their functional independence, clinicians often require them to perform ADL [12]. The initial idea was to capture patients performing these activities and annotate each action with an 'Activity' and an 'Impairment'. To create a dataset, one needs multiple samples for each annotation, ideally uniformly spread across subjects. It is very difficult to ask patients to perform multiple repetitions of each of the activities owing to their physical constraints; for example, a patient with a 'bent knee' would face difficulty in performing sit-to-stand multiple times and would not be able to provide a sample of a normal sit-to-stand sequence. To address this, the workaround was to film the activities with healthy subjects acting like patients. To make sure that the activities performed by healthy subjects accurately reflect the performance of real patients, help was sought from an occupational therapist, under whose guidance the ADL were chosen so as to collectively cover a wide range of body movements and test various parts of the musculo-skeletal system. Each ADL filmed for this dataset is captured in one healthy and four different physical impairment-specific executions. The activities "Sitting", "Standing" and "Walking" cover lower torso and leg movements. "Drinking", "Brushing Hair" and "Wearing Glasses" test the functionality of the upper limbs. "Brushing Floor", "Answering Phone" and "Clapping" are performed while standing and thus require close co-ordination between the upper and lower halves of the body. The impairments 'Weakness to One Side', 'Knee Rigidity' and 'Wider Gait' are represented in all the lower-limb activities. The impairments 'Shoulder Weakness', 'Tremors' and 'Elbow Rigidity' are present for all the upper-limb activities. All the activities exhibit the 'Normal' and the 'Ataxic' versions. The dataset presents seven impairments in total, while each of the 10 activities is represented through four different impairments in addition to a regular healthy execution. Altogether there are 5685 samples, each annotated with an activity and an impairment, performed by 10 (5 female and 5 male) subjects.
4 Proposed Approach
There are many aspects to designing a pose-based model and the proposed model aims to address the following: (1) effectively capture the spatio-temporal information contained within a human body-pose sequence; (2) semantically cluster the meaningful information represented by the body-pose network. The proposed model is based on the TCN-ResNet architecture [16], which is essentially a combination of 1D convolutions with residual connections. We use two TCN-ResNet models, one for a spatial stream and the other for a temporal stream. In the spatial stream, Block-A of the TCN-ResNet is replaced with the SEU and, similarly, the TEU substitutes Block-A in the temporal stream (Fig. 2). As in [6], for each frame the SEU captures
Fig. 2 The proposed model consists of a spatial and a temporal stream where each stream uses a TCN-ResNet [16]. Block-A of the spatial stream is replaced with the SEU [6] while the same block in the temporal stream is exchanged with the TEU [6]. The GAP layer of the TCN-ResNet [16] is replaced by an FV-based pooling mechanism [24]. The Soft-max outputs of both streams are multiplied (indicated by ×) and normalised for the final output. The model is trained through multi-hot encoded labels, where each label vector contains two '1's indicating the 'Activity' and 'Impairment' labels
the structural relationship between various body joints and passes an enriched/augmented representation of the human body-pose sequence to the rest of the network. On the other hand, the goal of the TEU is to encode the frame-wise positions of body joints and present a temporally rich representation for each joint individually. Furthermore, we introduce a novel FV-based learnable pooling mechanism in each stream, replacing the GAP layer in the original TCN-ResNet. This is mainly because learnable pooling approaches have been shown to be more effective at pooling the most relevant features than statistical pooling (e.g., average or max-pooling), as discussed in the related works. This learnable pooling method integrates an FV-based clustering mechanism which semantically clusters the spatial and temporal structures contained within the respective streams, and it significantly improves the recognition accuracy, as shown in the experimental evaluation section. TCN-ResNet: The TCN-ResNet model is essentially a stack of 1D convolutional layers followed by the standard GAP + FC layers. As shown in Fig. 2, the network is composed of three 1D convolutional blocks (Block-A, Block-B and Block-C) and each of the three blocks is composed of three layers of 1D convolutions. Each convolutional operation is followed by BN and a ReLU activation function. The convolutional operation at the start of Block-B and Block-C has stride 2, which means the input is halved along the first dimension (normally the time dimension) as it passes from Block-A to Block-B and then from Block-B to Block-C. There are two paths between any two layers: (1) the first is through a 1D convolutional operation followed by BN and a ReLU activation function; (2) the second is through a residual or skip-connection (omitted in Fig. 2 for simplicity). Let T be the number of frames in a sequence, J the number of body joints, D the dimension of each joint (3 for 3D pose) and F the total number of filters in a layer. Then, with input pose map $V \in \mathbb{R}^{T \times JD}$, the 1D convolution operations in each block perform the following transformations:
Block A: $V_{T,\,J\times D} \rightarrow M^{a}_{T,\,F_a}$   (1)
Block B: $M^{a}_{T,\,F_a} \rightarrow M^{b}_{T/2,\,F_b}$   (2)
Block C: $M^{b}_{T/2,\,F_b} \rightarrow M^{c}_{T/4,\,F_c}$   (3)
Here, $F_a = 64$, $F_b = 128$ and $F_c = 256$ indicate the numbers of filters, and $M^a$, $M^b$, $M^c$ denote the output maps of Block-A, Block-B and Block-C (Fig. 2), respectively. The output of Block-C is passed through a standard GAP layer followed by an FC layer with a Soft-max activation function. Spatial Stream: The spatial stream is a TCN-ResNet [16] model; the difference is that it adapts Block-A (Fig. 2) to the SEU introduced by [6]. Normal 1D convolutions process all the frames in a body-pose sequence together, as shown in Eq. 1. In contrast, the SEU processes each frame through separate and independent convolutional operations and then concatenates the outputs. This enables the network to learn relationships and dependencies between various body joints for each point in time. In other words, the network learns the body structure spatially for each time-step of a given sequence. This structural learning is absent in the TCN-ResNet, which involves normal 1D convolutions. Formally, the SEU in Block-A performs the following transformations [6]:
$\hat{M}^{a_t}_{J,\,F_a} = \hat{U}^{t}(F_a,\, V_{J,\,D})$   (4)
$\hat{M}^{a_t}_{J,\,F_a} \rightarrow \hat{M}^{a}_{T,\,J\times F_a}$   (5)
where $\hat{U}$ is the convolution operation parameterised by filters $F_a$, and $\hat{M}$ indicates a map of the spatial stream, corresponding to $M$ in the TCN-ResNet (Eqs. 1–3). Normally, a body-pose sequence is represented by a map in which, for each frame, the body pose is a vector of size $J \times D$. In the TCN-ResNet [16], the convolution operation in Block-A transforms this vector into a vector of length $F_a$ (Eq. 1), which represents the body pose of each frame transformed through 1D convolutions. In contrast, the transformation by the SEU produces a body pose that is represented by a vector of size $J \times F_a$ for each frame (Eq. 4). The SEU transforms the joints in each frame separately to produce a map $\hat{M}^{a_t}_{J,\,F_a}$ for each individual frame (Eq. 4). These maps are aggregated to form the final SEU output (Eq. 5). Thus, instead of $F_a$ features for all the joints, as in the TCN-ResNet (Eq. 1), the SEU represents each joint by a vector of size $F_a$ (Eq. 5). Through this enhanced representation, from $D$ coordinates (normally 3 for 3D pose) at the input to $F_a$ (= 64) at the output of Block-A (Fig. 2), the SEU encodes the learnt relationships and dependencies between the various body joints in each frame [6]. The SEU increases the number of parameters by a factor of the number of joints (Eq. 5). Including more blocks (Block-B, Block-C) in the SEU would increase the number of parameters further, because in the TCN-ResNet the filter count is doubled each time a block is traversed (Eqs. 1, 2, 3). Empirically, it was observed that including more blocks in the SEU made the model slower while having no positive impact on performance. The output of Block-C in the TCN-ResNet
has its temporal dimension reduced to a quarter, as shown in Eq. 3. Thus, following Eq. 3, the spatial stream (Fig. 2) performs the following overall transformation:
$V_{T,\,J\times D} \rightarrow \hat{M}^{c}_{T/4,\,F_c}$   (6)
Note that, because of the SEU, Block-A in the spatial stream (Eq. 5) produces a different output from Block-A in the TCN-ResNet (Eq. 1). However, this makes no difference to the overall output of the spatial stream at the end of Block-C, which is determined by the number of filters at the end of Block-C and the number of frames (time-steps) in the body-pose sequence. Effectively, the output of Block-C in the spatial stream (Eq. 6) is the same as the output of Block-C in the TCN-ResNet (Eq. 3). Temporal Stream: Similar to the spatial stream, a TCN-ResNet [16] is used for the temporal stream. The first block (Block-A, Fig. 2) is used for the TEU as in [6]. Formally, the TEU performs the following transformations:
$V_{T,\,J\times D} \rightarrow \bar{V}_{J\times D,\,T}$   (7)
$\bar{M}^{a}_{J\times D,\,F_a} = \bar{U}(F_a,\, \bar{V}_{J\times D,\,T})$   (8)
$\bar{M}^{a}_{J\times D,\,F_a} \rightarrow \bar{M}^{a}_{F_a,\,J\times D}$   (9)
Here, $\bar{U}$ is the convolution operation and $\bar{M}$ indicates a map of the temporal stream, corresponding to $M$ in the TCN-ResNet (Eqs. 1–3). Normally, the input consists of a map $V_{T,\,J\times D}$, which is transposed to $\bar{V}_{J\times D,\,T}$ in the case of the TEU. This means that in the TEU each input data point (row) represents the temporal variation of a body joint from $t = 1$ to $t = T$, hence the name TEU. This is in contrast to a normal convolution operation, where each data point consists of body joints $j = 1$ to $j = J$ for a single frame. Similar to the spatial stream, the output map $\bar{M}^{a}_{F_a,\,J\times D}$ of the TEU is passed to Block-B and Block-C (Fig. 2) of the temporal stream, which perform the following transformation (Eqs. 2, 3):
$\bar{M}^{a}_{F_a,\,J\times D} \rightarrow \bar{M}^{c}_{F_a/4,\,F_c}$   (10)
Streams fusion: The potential points for fusion of the two streams are at the end of each block. The SEU and the TEU produce maps of different dimensions (Eqs. 5, 10) at the end of Block-A (Fig. 2). Moreover, the TCN-ResNet [16] reduces the temporal dimension through Block-B and Block-C (Eqs. 1, 2, 3). Thus, the spatial and temporal streams produce maps of different dimensions at the end of each block. For example, Block-C of the spatial stream produces the map $\hat{M}^{c}_{T/4,\,F_c}$ (Eq. 6), whereas the temporal stream produces the map $\bar{M}^{c}_{F_a/4,\,F_c}$ (Eq. 10). The different sizes of the maps do not allow them to be fused with either concatenation or addition in a semantic manner. At the end of Block-C, the two streams could be fused by flattening and concatenating; however, flattening disturbs the spatial and temporal structural organisation of the maps. The FV-based clustering mechanism relies on such meaningful representations for semantic clustering [26]. Empirically, it was observed
that flattening the two streams at this stage for fusion leads to poor performance. To preserve the structural organisation of the maps and to cluster them semantically, each stream uses its own learnable FV pooling. FVs [26] are computed as the aggregation of cluster weights, means and co-variances computed from a Gaussian Mixture Model (GMM). Instead of calculating the FVs from a GMM, NetFV (FV with a neural network) learns these parameters [24]. Let $M_{R,S}$ be the input to the FV-based pooling. In the spatial stream, $M_{R,S}$ corresponds to $\hat{M}^{c}_{T/4,\,F_c}$ (Eq. 6), where $R = T/4$ and $S = F_c$. Similarly, in the temporal stream, $M_{R,S}$ corresponds to $\bar{M}^{c}_{F_a/4,\,F_c}$ (Eq. 10), where $R = F_a/4$ and $S = F_c$. Let $r \in (1, R)$. The idea is to assign each $S$-dimensional data point $M_r$ of $M$ to a cluster via a soft-assignment [24]:
$\alpha_k(M_r) = \dfrac{e^{W_k^{T} M_r + b_k}}{\sum_{j=1}^{K} e^{W_j^{T} M_r + b_j}}$   (11)
Here, the matrix $W_j$ and bias vector $b_j$ are learnable parameters. The soft-assignment $\alpha_k(M_r)$ to the $k$th cluster indicates how close $M_r$ is to cluster $k$, with $j \in (1, K)$ where $K$ is the total number of clusters. Using the above soft-assignment, the Fisher vector is computed using the NetFV representation of [24]:
$FV_1(j, k) = \sum_{r=1}^{R} \alpha_k(M_r)\, \dfrac{M_r(j) - c_k(j)}{\sigma_k(j)}$
$FV_2(j, k) = \sum_{r=1}^{R} \alpha_k(M_r) \left[ \left( \dfrac{M_r(j) - c_k(j)}{\sigma_k(j)} \right)^2 - 1 \right]$   (12)
$FV_1$ and $FV_2$ are respectively the first-order and second-order FV statistics. $c_k$ and $\sigma_k$ are the learned cluster centre and the diagonal co-variance of the $k$th cluster, where $k \in (1, K)$. Here, $c_k$ and $\sigma_k$ are randomly initialised and are learned independently from the parameters of the soft-assignment $\alpha_k$ in Eq. 11. As in [24], the FVs are then L2-normalised and concatenated to get the final $FV = [FV_1, FV_2]$. In [24], the learned FV from the video stream is concatenated with the FV from the audio stream to form an FC layer, which is further processed with context gating and a mixture-of-experts classifier. Our implementation differs from the approach of [24]: a weighted pooling mechanism is used to output the final class maps:
$\mathrm{Pooling}(FV) = \mathrm{softmax}(W_p\, FV + b_p)$
(13)
Here, the matrix $W_p \in \mathbb{R}^{|FV| \times C}$ and bias vector $b_p$ are learnable parameters, and $C$ is the number of human activity classes. The class maps from both streams are multiplied and normalised to get the final output. By avoiding concatenation of FVs we preserve the intelligently pooled features from the semantic FV-based clusters to form the class maps. The multiplication and normalisation (Fig. 2) operations ensure that the
network automatically learns the contribution weight of each stream without the need for a further FC layer. Thus, in contrast to NetFV [24], we are able to produce class maps without any further processing.
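To make the pooling concrete, the following is a minimal PyTorch-style sketch of the learnable FV pooling described by Eqs. 11–13. It is an illustrative reading of the equations rather than the authors' released code; the module name, tensor layout and random initialisation are our own assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LearnableFVPooling(nn.Module):
    """Sketch of FV-based learnable pooling: soft-assign each row of the input
    map to K learned clusters (Eq. 11), accumulate first- and second-order
    statistics (Eq. 12), then map the L2-normalised FV to class scores (Eq. 13)."""

    def __init__(self, feat_dim, num_clusters, num_classes):
        super().__init__()
        self.assign = nn.Linear(feat_dim, num_clusters)                    # W_k, b_k
        self.centres = nn.Parameter(torch.randn(num_clusters, feat_dim))   # c_k
        self.sigmas = nn.Parameter(torch.ones(num_clusters, feat_dim))     # sigma_k
        self.classifier = nn.Linear(2 * num_clusters * feat_dim, num_classes)  # W_p, b_p

    def forward(self, m):
        # m: (batch, R, S), e.g. the Block-C output of one stream
        alpha = F.softmax(self.assign(m), dim=-1)                  # (batch, R, K), Eq. 11
        diff = (m.unsqueeze(2) - self.centres) / self.sigmas       # (batch, R, K, S)
        fv1 = (alpha.unsqueeze(-1) * diff).sum(dim=1)              # first-order statistics
        fv2 = (alpha.unsqueeze(-1) * (diff ** 2 - 1)).sum(dim=1)   # second-order statistics
        fv = torch.cat([F.normalize(fv1.flatten(1), dim=-1),
                        F.normalize(fv2.flatten(1), dim=-1)], dim=-1)
        return F.softmax(self.classifier(fv), dim=-1)              # weighted pooling, Eq. 13

One such module would replace the GAP + FC head of each stream, and the two resulting class maps would then be multiplied and re-normalised as in Fig. 2.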
5 Training and Evaluation
Apart from the ground truth and the evaluation method, the model is trained in a similar manner to standard single-label classification models. The ground truth is presented as multi-hot encoded labels to train the multi-label model. Two separate one-hot encoded labels, prepared as the 'activity' label and the 'impairment' label, are concatenated to form the final ground truth label. Let there be $A$ activity classes and $I$ impairment classes. For the $a$th activity class, where $a \in \{1 \ldots A\}$, and the $i$th impairment class, where $i \in \{1 \ldots I\}$, the one-hot encoded labels for activity and impairment are respectively:
$AL_{m \in A} = \begin{cases} 1, & \text{if } m = a,\\ 0, & \text{if } m \neq a, \end{cases} \qquad IL_{n \in I} = \begin{cases} 1, & \text{if } n = i,\\ 0, & \text{if } n \neq i, \end{cases}$   (14)
$GT = AL_{m \in A} \oplus IL_{n \in I}$
(15)
To create the final ground truth label $GT$, the two labels are simply concatenated (Eq. 15). Thus, each ground truth label vector $GT$ has two '1' values, indicating the activity and the impairment. In $GT$, the 'activity' label comes from the first $A$ elements, whereas the 'impairment' label is determined from the final $I$ elements. Thus, to evaluate the model, the prediction probability vector (i.e., the model output) is split into two parts, where the first part contains the first $A$ elements indicating the 'activity' class probabilities and the remaining $I$ elements indicate the 'impairment' probabilities. Then, the accuracies for the 'activities' and 'impairments' are calculated separately. Finally, a prediction by the model is considered to be true only if both the 'activity' and 'impairment' predictions are correct.
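As an illustration of this labelling and evaluation scheme, the short sketch below builds the multi-hot ground truth of Eqs. 14–15 and applies the "both must be correct" rule; the function names and NumPy-based representation are illustrative assumptions, not the authors' implementation.

import numpy as np

def make_ground_truth(activity, impairment, num_activities, num_impairments):
    # Concatenate the one-hot activity and impairment labels (Eqs. 14-15)
    gt = np.zeros(num_activities + num_impairments)
    gt[activity] = 1.0
    gt[num_activities + impairment] = 1.0
    return gt

def is_correct(prediction, gt, num_activities):
    # Split the output vector: the first A entries are activity probabilities,
    # the remaining I entries are impairment probabilities; a prediction counts
    # as correct only if both parts are correct
    pred_act = prediction[:num_activities].argmax()
    pred_imp = prediction[num_activities:].argmax()
    true_act = gt[:num_activities].argmax()
    true_imp = gt[num_activities:].argmax()
    return pred_act == true_act and pred_imp == true_imp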
6 Experiments and Results
We evaluate the proposed pose-based model in both single-label and multi-label modes. The well-known and challenging NTU-RGBD dataset [28], which contains around 60 K samples distributed over 60 action classes, is used for evaluation in single-label mode. We use the authors' [28] cross-subject (CS) evaluation protocol, which is harder than the cross-view evaluation. Table 2 compares the proposed model to existing state-of-the-art approaches and shows that the model achieves competitive performance under the constraints of data modality (pose-based vs. RGB video-based), end-to-end trainability and random initialisation (i.e., not pre-trained). The proposed
Table 2 The proposed model achieves competitive accuracy when compared with other pose-based state-of-the-art models given the constraints of data mode (P: Pose, R: RGB-video), being end-to-end trainable (E2E) and random initialisation (RI). Given these constraints ST-GCN achieves the best performance and we achieve performance close to ST-GCN

Model | Mode | E2E | RI | CS (%)
TCN-ResNet [16] | P | ✓ | ✓ | 74.3
Synth-CNN [23] | P | ✗ | ✗ | 80.0
ST-GCN [41] | P | ✓ | ✓ | 81.5
DPRL+GCNN [33] | P | ✗ | ✗ | 83.5
3Scale-ResNet152 [19] | P | ✓ | ✗ | 85.0
Glimpse Clouds [2] | R | ✓ | ✗ | 86.6
Learned-Encoding [6] | RP | ✓ | ✗ | 87.7
DGNN [29] | P | ✗ | ✓ | 89.9
Ours | P | ✓ | ✓ | 80.2
model has the advantage of being end-to-end trainable, in contrast to [23, 29, 33]. Also, in contrast to [2, 19, 23, 33], we do not pre-train the proposed model, which reflects the true capacity of the model to learn without prior information. Given these constraints, the best performance is achieved by ST-GCN [41], and the proposed model achieves comparable performance while being a very light-weight model: ST-GCN [41] requires 8 Nvidia Titan X GPUs for training, while we use only one Titan X GPU. Next, we evaluate our model on the proposed multi-label functional ADL recognition dataset using the following protocol. A cross-validation approach is used where the dataset is split into two subject-wise folds for good generalisation. The first fold uses subjects 1, 3, 5, 7, 9 for training while subjects 2, 4, 6, 8, 10 are used for validation, and vice versa for the second fold. Thus, the 5685 samples are split into two approximately equal groups of 2869 and 2816 samples, which constitutes a demanding generalisation protocol. We do not use any data augmentation or transfer learning, in order to assess the true capacity of the model to learn. The results in Table 3 are the weighted averages over the two-fold cross-validation described above. For each sample, the models in Table 3 predict the 'activity' (A) and 'impairment' (I) classes, and a model's 'Final' prediction is considered to be true only if both the 'activity' and 'impairment' predictions are true.
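For clarity, the subject-wise two-fold protocol can be sketched as follows; the sample representation (a dict with a 'subject' field) is an assumption made purely for illustration.

def subject_wise_folds(samples):
    # Fold 1: odd-numbered subjects (1, 3, 5, 7, 9) train, even-numbered validate;
    # Fold 2 simply swaps the two groups
    odd = [s for s in samples if s["subject"] % 2 == 1]
    even = [s for s in samples if s["subject"] % 2 == 0]
    return [(odd, even), (even, odd)]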
6.1 Ablation Study
In Table 4, we demonstrate the evolution of the model from the base TCN-ResNet to the final model through the step-by-step inclusion of the spatial-temporal architecture and the FV-based pooling. First, we experiment with the original TCN-ResNet (Row 1),
Table 3 Evaluation of the proposed dataset using different models that predict 'Activity' (A) and 'Impairment' (I); a model's prediction is considered correct if both the 'Activity' and 'Impairment' predictions are true. Mode: Pose (P), RGB (R)

Model | Mode | A | I | Final
I3D [3] | R | 87.2 | 65.9 | 55.9
C3D [34] | R | 90.1 | 73.2 | 63.3
TCN-ResNet [16] | P | 91.2 | 69.0 | 63.4
Ours | P | 97.1 | 80.7 | 78.7
Table 4 Ablation study demonstrating the effectiveness of the two-stream architecture and FVs

Model | Split 1 (A / I / Final) | Split 2 (A / I / Final) | Weighted average (A / I / Final)
TCN-ResNet [16] | 90.4 / 72.0 / 65.3 | 90.0 / 69.0 / 61.9 | 90.2 / 70.5 / 63.6
Two-Stream (TEU) | 92.9 / 73.1 / 69.3 | 96.2 / 75.7 / 73.3 | 94.6 / 74.4 / 71.3
Two-Stream (TEU + SEU) | 93.1 / 74.4 / 69.8 | 95.0 / 79.1 / 75.6 | 94.0 / 76.8 / 72.7
Two-Stream (SEU + TEU) + FV | 96.9 / 77.7 / 75.3 | 97.3 / 83.7 / 82.2 | 97.1 / 80.7 / 78.7
Fig. 3 Grid search for appropriate cluster sizes shows that several parameter choices provide close to peak performance. This indicates that the TCN maps can be semantically clustered in multiple ways. Search range: 2^n, where n = 2, 3, 4, 5, 6, 7
and then we experiment with the two-stream architecture consisting of two parallel TCN-ResNets (Row 2). Here, the TEU is introduced into Block-A (Fig. 2) of one of the streams, making that stream temporal in nature while otherwise identical to the other stream. This greatly improves the accuracy, which is further enhanced by the introduction of the SEU to the spatial stream (Row 3). The final model accuracy is then greatly enhanced by the introduction of FV-based pooling to both streams (Row 4). The number of clusters in the FV is a tunable hyper-parameter, and a grid search was performed within the search space 2^n, where n = 2, 3, 4, 5, 6, 7. The search results are illustrated in Fig. 3, which shows several peaks indicating higher performance with multiple cluster-size settings. The best performance (78.8%) is obtained with a cluster size (CS) of 2^3 for both streams. Similar results (78.0%) are obtained with the CS set to 8 (spatial) and 16 (temporal). A CS of 16 (spatial) and 64 (temporal) gave 78.6%, while a CS of 64 (spatial) and 32 (temporal) gave 77.2% accuracy. The results suggest that the TCN maps can be semantically clustered in more than one way. However, with increased CS the number of model parameters increases, and thus the aim should be to keep the CS to a minimum.
7 Conclusion
This paper is a step towards robot- or automated system-based assessment of physically impaired persons through ADL. To this end, we propose a dataset that consists of 10 different ADL which collectively test the functional capacity of different body parts. Further, we present a normal and four different physical impairment-specific executions of each ADL. To our knowledge, the dataset is the first to explore a robot's or an automated system's perception of functional assessment through ADL. The paper also presents a novel multi-label functional ADL recognition model that integrates a spatial-temporal body-pose encoding method with FV-based pooling in a novel manner. The dataset and the model will be made publicly available along with the publication of this paper. Acknowledgements We are immensely thankful to Dr Helen Carey for guiding us through the data collection process and to Nvidia for providing the GPU.
References 1. Baradel, F., Wolf, C., & Mille, J. (2018). Human activity recognition with pose-driven attention to rgb. In BMVC. 2. Baradel, F., Wolf, C., Mille, J., & Taylor, G. W. (2018). Glimpse clouds: Human activity recognition from unstructured feature points. In CVPR. 3. Carreira, J., & Zisserman, A. (2017). Quo vadis, action recognition? A new model and the kinetics dataset. In CVPR.
4. Da Gama, A., Fallavollita, P., Teichrieb, V., & Navab, N. (2015). Motor rehabilitation using kinect: A systematic review. Games for Health Journal, 4(2), 123–135. 5. Das, S., Dai, R., Koperski, M., Minciullo, L., Garattoni, L., Bremond, F., & Francesca, G. (2019). Toyota smarthome: Real-world activities of daily living. In Proceedings of the ICCV (pp. 833–842). 6. Debnath, B., O’Brient, M., Kumar, S., & Behera, A. (2021). Attention-driven body pose encoding for human activity recognition. In ICPR. 7. Demisse, G. G., Papadopoulos, K., Aouada, D., & Ottersten, B. (2018). Pose encoding for robust skeleton-based action recognition. In Proceedings of the CVPR. 8. Fernando, B., Gavves, E., Oramas, J., Ghodrati, A., & Tuytelaars, T. (2016). Rank pooling for action recognition. In PAMI. 9. Ferrucci, L., Koh, C., Bandinelli, S., & Guralnik, J. M. (2010). Disability, functional status, & activities of daily living. In Encyclopedia of gerontology (pp. 427–436). Elsevier Inc. 10. Girdhar, R., & Ramanan, D. (2017). Attentional pooling for action recognition. In Proceedings of the NIPS, (pp. 34–45). 11. Girdhar, R., Ramanan, D., Gupta, A., Sivic, J., & Russell, B. (2017). Actionvlad: Learning spatio-temporal aggregation for action classification. In Proceedings of the CVPR (pp. 971– 980). 12. Green, J., & Young, J. (2001). A test-retest reliability study of the barthel index, the rivermead mobility index, the nottingham extended activities of daily living scale and the frenchay activities index in stroke patients. Disability and Rehabilitation, 23(15), 670–676. 13. Habibian, A., Mensink, T., & Snoek, C. G. M. (2016). Video2vec embeddings recognize events when examples are scarce. In PAMI. 14. Hussein, N., Gavves, E., & Smeulders, A. W. M. (2017). Unified embedding and metric learning for zero-exemplar event detection. In Proceedings of the CVPR (pp. 1096–1105). 15. Ke, Q., Bennamoun, M., An, S., Sohel, F., & Boussaid, F. (2017). A new representation of skeleton sequences for 3d action recognition. In Proceedings of the CVPR (pp. 3288–3297). 16. Kim, T. S., & Reiter, A. (2017). Interpretable 3d human action analysis with temporal convolutional networks. In Proceedings of the CVPRW (pp. 1623–1631). IEEE. 17. Koppula, H. S., Gupta, R., & Saxena, A. (2013). Learning human activities and object affordances from rgb-d videos. The International Journal of Robotics Research, 32(8), 951–970. 18. Lea, C., Flynn, M. D., Vidal, R., Reiter, A., & Hager, G. D. (2017). Temporal convolutional networks for action segmentation and detection. In Proceedings of the CVPR (pp. 156–165). 19. Li, B., Dai, Y., Cheng, X., Chen, H., Lin, Y., & He, M. (2017). Skeleton based action recognition using translation-scale invariant image mapping and multi-scale deep cnn. In ICMEW. IEEE. 20. Li, W., Zhang, Z., & Liu, Z. (2010). Action recognition based on a bag of 3d points. In CVPR Workshops (CVPRW) (pp. 9–14). IEEE. 21. Liu, J., Shahroudy, A., Perez, M., Wang, G., Duan, L.-Y., & Kot, A. C. (2019). Ntu rgb+ d 120: A large-scale benchmark for 3d human activity understanding. IEEE Transactions on PAMI, 42(10), 2684–2701. 22. Liu, J., Shahroudy, A., Xu, D., & Wang, G. (2016). Spatio-temporal LSTM with trust gates for 3D human action recognition. In ECCV. 23. Liu, M., Liu, H., & Chen, C. (2017). Enhanced skeleton visualization for view invariant human action recognition. Pattern Recognition, 68, 346–362. 24. Miech, A., Laptev, I., & Sivic, J. (2017). Learnable pooling with context gating for video classification. arXiv:1706.06905. 25. 
Mousavi Hondori, H., & Khademi, M. (2014). A review on technical and clinical impact of microsoft kinect on physical therapy and rehabilitation. Journal of Medical Engineering. 26. Perronnin, F., & Dance, C. (2007). Fisher kernels on visual vocabularies for image categorization. In CVPR. IEEE. 27. Sathyanarayana, S., Satzoda, R. K., Sathyanarayana, S., & Thambipillai, S. (2018). Visionbased patient monitoring: A comprehensive review of algorithms and technologies. Journal of Ambient Intelligence and Humanized Computing, 9(2), 225–251.
28. Shahroudy, A., Liu, J., Ng, T.-T., & Wang, G. (2016). Ntu rgb+ d: A large scale dataset for 3d human activity analysis. In CVPR. 29. Shi, L., Zhang, Y., Cheng, J., & Lu, H. (2019). Skeleton-based action recognition with directed graph neural networks. In CVPR. 30. Sigurdsson, G. A., Varol, G., Wang, X., Farhadi, A., Laptev, I., & Gupta, A. (2016). Hollywood in homes: Crowdsourcing data collection for activity understanding. In ECCV, (pp. 510–526). Springer. 31. Song, S., Lan, C., Xing, J., Zeng, W., & Liu, J. (2017). An end-to-end spatio-temporal attention model for human action recognition from skeleton data. In Proceedings of the AAAI. 32. Sung, J., Ponce, C., Selman, B., & Saxena, A. (2011). Human activity detection from rgbd images. In AAAI Workshop. 33. Tang, Y., Tian, Y., Lu, J., Li, P., & Zhou, J. (2018). Deep progressive reinforcement learning for skeleton-based action recognition. In CVPR. 34. Tran, D., Bourdev, L., Fergus, R., Torresani, L., & Paluri, M. (2015). Learning spatiotemporal features with 3d convolutional networks. In Proceedings of the ICCV, (pp. 4489–4497). 35. Wang, J., Liu, Z., Ying, W., & Yuan, J. (2012). Mining actionlet ensemble for action recognition with depth cameras. In CVPR. IEEE. 36. Wang, J., Nie, X., Xia, Y., Wu, Y., & Zhu, S.-C. (2014). Cross-view action modeling, learning and recognition. In CVPR. 37. Webster, D., & Celik, O. (2014). Systematic review of kinect applications in elderly care and stroke rehabilitation. Journal of Neuroengineering and Rehabilitation, 11(1), 108. 38. Wei, Y., Li, W., Fan, Y., Xu, L., Chang, M.-C., & Lyu, S. (2020). 3d single-person concurrent activity detection using stacked relation network. In AAAI. 39. Xia, L., Chen, C. C., & Aggarwal, J. K. (2012). View invariant human action recognition using histograms of 3D joints. In CVPR Workshops. IEEE. 40. Xu, Z., Yang, Y., & Hauptmann, A. G. (2015). A discriminative cnn video representation for event detection. In Proceedings of the CVPR. 41. Yan, S., Xiong, Y., & Lin, D. (2018). Spatial temporal graph convolutional networks for skeleton-based action recognition. In Proceedings of the AAAI. 42. Zanfir, M., Leordeanu, M., & Sminchisescu, C. (2013). The moving pose: An efficient 3d kinematics descriptor for low-latency action recognition and detection. In Proceedings of the ICCV.
The Interpretation of Deep Learning Based Analysis of Medical Images—An Examination of Methodological and Practical Challenges Using Chest X-ray Data Steinar Valsson and Ognjen Arandjelovi´c Abstract With the increase in availability of annotated X-ray image data, there has been an accompanying and consequent increase in research on machine learning based, and particularly deep learning based, X-ray image analysis. A major problem with this body of work lies in how newly proposed algorithms are evaluated. Usually, comparative analysis is reduced to the presentation of a single metric, often the area under the receiver operating characteristic (AUROC), which does not provide much clinical value or insight, and thus fails to communicate the applicability of proposed models. In the present paper we address this limitation of previous work by presenting a thorough analysis of a state of the art learning approach, and hence illuminate various weaknesses of similar algorithms in the literature, which have not yet been fully acknowledged and appreciated. Our analysis is performed on the ChestX-ray14 dataset which has 14 lung disease labels and metainfo such as patient age, gender, and the relative X-ray direction. We examine the diagnostic significance of different metrics used in the literature including those proposed by the International Medical Device Regulators Forum, and present qualitative assessment of spatial information learnt by the model. We show that models that have very similar AUROCs can exhibit widely differing clinical applicability. As a result, our work demonstrates the importance of detailed reporting and analysis of performance of machine learning approaches in this field, which is crucial both for progress in the field and the adoption of such models in practice. Keywords Thorax · Occlusion · Interpretability · Explainability
S. Valsson · O. Arandjelovi´c (B) University of St Andrews, St Andrews KY16 9SX, Scotland, UK e-mail: [email protected] S. Valsson e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_14
1 Introduction
Chest X-ray is one of the most widely available and easy-to-use medical imaging tools in the diagnostics of lung disease. It is relatively inexpensive compared to other imaging techniques [11, 18]. The quality of the acquisition process and of the subsequent analysis is of crucial importance, as more extensive tests are often only done for acute cases due to cost or lack of availability. A wrongly interpreted X-ray image can lead to a misdiagnosis with severe consequences. Advances in the field of Machine Learning (ML) have made it possible, in principle, to automate the interpretation of X-ray images or at least to assist in the process. Interpreting X-ray images accurately can be quite challenging. Junior doctors generally perform rather poorly on the task [4] and even specialists exhibit significant variability between readings (intra-personal variability) or between one another (inter-personal variability) [15]. The difference in contrast between an anomaly and normal tissue can often be minimal, and it is often virtually or literally impossible to distinguish between two conditions from an X-ray alone, so further investigation may be needed. The goal here is to emphasise the importance of interpreting model results by training a model to diagnose and localise 14 disease labels and thoroughly evaluating its diagnostic capabilities.
1.1 Previous Work
As noted earlier, the focus of the present work is not on the technical approach itself, but rather on the issues related to the interpretation of the output of machine learning models trained to analyse X-ray imagery. Hence, since, almost without exception, previous work suffers from much the same weaknesses (while differing in 'under the bonnet' technicalities), we illustrate this with a representative example, namely the work of Wang et al. [20], without seeking to survey different learning methodologies in detail. The authors describe a data gathering and labelling process using Natural Language Processing (NLP) on radiology reports gathered from institutional Picture Archiving and Communication Systems (PACS), and train a deep CNN model to predict the label corresponding to an input X-ray image. Their experimental corpus includes labelled X-ray images and metadata such as patient ID, age, sex, and the X-ray view position (VP) (antero-posterior or postero-anterior). A total of 14 disease labels are considered: Atelectasis, Cardiomegaly, Consolidation, Edema, Effusion, Emphysema, Fibrosis, Hernia, Infiltration, Mass, Nodule, Pleural Thickening, Pneumonia and Pneumothorax, with the meaning of each being clear from the label itself. Furthermore, for approximately 1000 images the information on the locality of the label (or indeed, the disease) is provided in the form of a bounding box. The promising results reported by the authors have made this work influential, with a number of follow-up methods having been put forward by others, all bearing
conceptual and methodological similarity, such as those by Baltruschat et al. [1], Rajpurkar et al. [14], Yao et al. [21], Li et al. [10], and Gündel et al. [8]. In none of the aforementioned work, except for that of Baltruschat et al. [1], is there any substantive discussion of the shortcomings. The scores, usually the Area Under Curve (AUC) of the Receiver Operating Characteristic (ROC) [7] or the F1-score, are adopted without any consideration of their clinical significance or insight into what is failing in the proposed method when it does fail (and failure certainly occurs often enough that it ought to have been discussed). Quantifying performance using a single numerical measure is certainly an attractive proposition: it is usually easily interpretable, quickly absorbed, and provides an unambiguous rank ordering of different approaches. While this can be appropriate in some problem contexts, it certainly is not in the case of X-ray image analysis, where nuances in what a model is learning or basing its decisions on can lead to significant clinical differences, yet leave a simple all-encompassing performance measure unaltered (or virtually so). The present paper sheds additional light on this issue and furthers the understanding of how the effectiveness of Software as a Medical Device (SaMD) may be measured.
2 Performance Quantification
The Food and Drug Administration (FDA), as a part of the IMDRF, has issued guidelines for the clinical evaluation of SaMDs in which it lists a number of evaluation metrics it would like to see reported for clinical validation of future SaMDs. These are specificity, sensitivity, accuracy, and the odds ratio [3]. These metrics can all be computed from the values comprising the confusion matrix: a 2 × 2 matrix containing the empirical True Positive (TP), True Negative (TN), False Positive (FP), and False Negative (FN) counts measured by applying a model on test data. Sensitivity (or recall), specificity, accuracy, and the F1-score are thus defined as follows:
$\mathrm{Sensitivity} = \dfrac{TP}{TP + FN}$   (1)
$\mathrm{Specificity} = \dfrac{TN}{TN + FP}$   (2)
$\mathrm{Accuracy} = \dfrac{TP + TN}{TP + TN + FP + FN}$   (3)
$F1 = \dfrac{TP}{TP + \frac{1}{2}(FP + FN)}$   (4)
A high sensitivity entails that there are very few false negatives, while high specificity means that there are few false positives. Accuracy describes the proportion of correct diagnoses but has the downside of not accounting for imbalanced data as it is possible to always predict a class with very few samples as another class with
more numerous samples and still have high accuracy. Having both sensitivity and specificity included can therefore indicate how well the SaMD performs in a relatively straightforward way. Accuracy can then be looked at with respect to the other metrics. The Diagnostic Odds Ratio (DOR) is also often used as a single indicator of diagnostic performance. Its value can range from 0 to infinity, with higher values corresponding to better performance. A value of 1 means that a positive result is as likely to be the result of a true positive as of a true negative, and a score below 1 means that there are more negative results for positive examples of a given class. The DOR is independent of sample prevalence, as opposed to accuracy, and a 95% confidence interval can be calculated as
$\ln(DOR) \pm 1.96 \times SE(\ln(DOR))$   (5)
where
$SE(\ln(DOR)) = \sqrt{\dfrac{1}{TP} + \dfrac{1}{TN} + \dfrac{1}{FP} + \dfrac{1}{FN}}$   (6)
A drawback of the DOR is that it is undefined when the confusion matrix contains zero entries (i.e. in practice, if there are no false positives or false negatives). A commonly used ad hoc adjustment applied in such cases is to add 0.5 to all values in the matrix.
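All of the above can be computed directly from the confusion matrix; the following is a small, self-contained sketch (our own helper function, not part of any cited work), including the ad hoc 0.5 adjustment for zero cells.

import math

def diagnostic_metrics(tp, tn, fp, fn):
    # Metrics recommended for SaMD evaluation (Eqs. 1-4)
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    f1 = tp / (tp + 0.5 * (fp + fn))
    # DOR with the common 0.5 adjustment when any cell is zero (Eqs. 5-6)
    if min(tp, tn, fp, fn) == 0:
        tp, tn, fp, fn = tp + 0.5, tn + 0.5, fp + 0.5, fn + 0.5
    dor = (tp * tn) / (fp * fn)
    se = math.sqrt(1 / tp + 1 / tn + 1 / fp + 1 / fn)
    ci_95 = (math.exp(math.log(dor) - 1.96 * se),
             math.exp(math.log(dor) + 1.96 * se))
    return {"sensitivity": sensitivity, "specificity": specificity,
            "accuracy": accuracy, "f1": f1, "dor": dor, "dor_95ci": ci_95}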
3 Model Training
As we noted earlier, the method described by Wang et al. [20] is influential and representative of a whole body of work in the area, and hence we adopt it herein as our baseline. We take a pre-trained network and re-train it on the task-specific data set, that of X-ray images. A key feature of this process is that the entire network is retrained and not just the classification layer (which is more common in the literature). In particular, we adopt the 121-layer Dense Convolutional Network (DenseNet) [9] pre-trained on the ImageNet corpus [6], and re-train it on the data made available by Wang et al., using the same training-validation-test split as the original authors and the Binary Cross-Entropy loss function:
$\ell(x, y) = \dfrac{1}{N} \sum_{i=1}^{N} l_i$   (7)
where
$l_n = -\left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n) \right]$   (8)
where x and y are respectively the input and the output vectors, and N is the batch size.
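A minimal sketch of this training setup is given below, assuming the torchvision DenseNet-121 implementation; the optimiser, learning rate and use of BCEWithLogitsLoss (logits rather than probabilities) are our assumptions and not details reported by the authors.

import torch
import torch.nn as nn
from torchvision import models

# ImageNet-pretrained DenseNet-121 with a 14-way multi-label head; the whole
# network (not just the classification layer) is left trainable
model = models.densenet121(weights="IMAGENET1K_V1")
model.classifier = nn.Linear(model.classifier.in_features, 14)
criterion = nn.BCEWithLogitsLoss()          # binary cross-entropy, Eqs. 7-8
optimiser = torch.optim.Adam(model.parameters(), lr=1e-4)

def train_step(images, labels):
    # images: (N, 3, H, W); labels: (N, 14) multi-hot disease annotations
    optimiser.zero_grad()
    loss = criterion(model(images), labels.float())
    loss.backward()
    optimiser.step()
    return loss.item()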
For the localisation of the salient image region corresponding to a label, we used Gradient-weighted Class Activation Mapping (Grad-CAM), based on the work of Zhou et al. [22] and further improved by Selvaraju et al. [17], typical of a variety of saliency detection algorithms [5]. Herein we summarise the process for the reader's benefit. Firstly, an input image is run through the model and the activations from the forward pass on the last convolutional layer are saved. Then, back-propagation with respect to a given label is performed and the output gradients from the backward pass on the same convolutional layer are also saved. Next, the gradients are pooled together into a single layer and multiplied by the activations saved earlier. An average pooling is applied to the activations, per feature, leaving an H × W matrix. A ReLU function is then applied to the matrix, removing all negative feature outputs, and the remaining values are normalised by the maximum entry in the array. At this point the Grad-CAM heatmap has been generated and can be overlaid on top of the original image. In the end, we compared two models: one that simply follows the method described above, and another in which the network was modified to use metadata by virtue of two additional binary nodes, corresponding to a patient's gender and the X-ray VP, added to the last prediction layer. We will refer to the first model as the standard model and to the second one as the modified model.
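The Grad-CAM procedure just described can be sketched as follows. The hook-based capture of activations and gradients is one possible implementation choice, and conv_layer would be, for example, the final feature block of the DenseNet; these are our assumptions rather than the authors' exact code.

import torch
import torch.nn.functional as F

def grad_cam(model, image, target_class, conv_layer):
    # Capture activations (forward pass) and gradients (backward pass) at conv_layer
    acts, grads = {}, {}
    h1 = conv_layer.register_forward_hook(lambda m, i, o: acts.update(a=o))
    h2 = conv_layer.register_full_backward_hook(lambda m, gi, go: grads.update(g=go[0]))
    score = model(image.unsqueeze(0))[0, target_class]
    model.zero_grad()
    score.backward()                                            # back-propagate the chosen label
    h1.remove(); h2.remove()
    weights = grads["g"].mean(dim=(2, 3), keepdim=True)         # pool gradients per channel
    cam = F.relu((weights * acts["a"]).sum(dim=1)).squeeze(0)   # weighted sum + ReLU -> H x W
    return cam / (cam.max() + 1e-8)                             # normalise; upsample and overlay afterwards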
4 Analysis
In line with the primary aims of this work, we started by assessing the different methods' performance using the most widely used metric in the literature, namely the AUROC. Under this metric, the standard and the modified models stand on par with one another, the former achieving an AUROC of 0.800 and the latter the marginally higher value of 0.806. We note that this is consistent with previous reports in the literature, with the reported AUROC ranging from 0.745 (see Wang et al. [20]) to 0.806 using the method proposed by Baltruschat et al. [1]. The picture painted by comparing the per-label AUROC values, shown in Table 1, is similar: on some labels one model performs somewhat better, on others the other. Weighted by the frequencies of the labels, as we saw earlier, the difference all but disappears. Both the standard and the modified model achieve nearly identical empirical AUROC scores which, as we noted already, are normally used as the metric for ranking different methods in the field. Thus, superficially, this result suggests that the two methods are performing on par. Yet, in clinical terms, which is really what is of ultimate interest, this is far from the case: a closer look shows that the models actually perform rather differently. Consider a slightly more nuanced comparison of the methods' performances summarized in Table 2. In terms of specificity and accuracy, the standard model can be seen to be superior. This is significant. For example, the difference of 0.023 in specificity means that out of 1000 patients, 23 can be (correctly) not subjected to further investigation and tests, thereby reducing unnecessary discomfort caused and
Table 1 Comparison of the standard and modified models using the standard AUROC score, per label and overall

Label | Modified model | Standard model
Atelectasis | 0.763 | 0.768
Cardiomegaly | 0.875 | 0.887
Consolidation | 0.749 | 0.749
Edema | 0.846 | 0.835
Effusion | 0.822 | 0.830
Emphysema | 0.895 | 0.873
Fibrosis | 0.816 | 0.818
Hernia | 0.937 | 0.896
Infiltration | 0.694 | 0.697
Mass | 0.820 | 0.814
Nodule | 0.747 | 0.739
Pleural thickening | 0.763 | 0.762
Pneumonia | 0.714 | 0.708
Pneumothorax | 0.840 | 0.829
Average | 0.806 | 0.800

Table 2 Coarse model comparison

Model | Specificity | Sensitivity | Accuracy | DOR
Standard | 0.741 | 0.726 | 0.741 | 9.56
Modified | 0.718 | 0.751 | 0.720 | 10.63
reducing the financial burden on the health care system. On the other hand, the modified model has a higher recall, so it is more likely to detect disease present in patients who have it. The difference in recall of 0.025 means that the modified model correctly diagnoses 25 more patients in 1000 than the standard model. To contextualise this, patients and healthcare professionals were willing to exchange 2250 FP diagnoses of colorectal cancer for one additional TP diagnosis [2]. Similarly, 63% of women found >500 FPs per one life saved reasonable, and 37% would tolerate 10,000 or more [16]. Reflecting on these observations, it is neither correct to say that the methods perform comparably, nor that one is superior to the other. Rather, there are significant differences between the two, and the question of which is to be preferred in a specific context is one that demands a collaborative consultative effort between teams of clinicians who understand the particular operative environment of interest and, no less importantly, medical ethicists, whose role in the process is still inadequately appreciated.
4.1 Understanding Data and Findings Interpretation
A major issue of relevance to the development of medical applications of machine learning concerns the data used for training and testing algorithms. Notable problems include quality control (both of the data itself and of its labelling), the clinical relevance and appropriateness of any associated annotations, data balance, and numerous others. Concerns regarding the ChestX-ray14 corpus have been raised too, and their nature mirrors the aforementioned pervasive ones: labelling accuracy (quality control), confounding information (quality control), clinical meaning of labels (quality control and clinical significance), and the usefulness of the labels (clinical significance and appropriateness) [13]. Consider the following quality control concern: since some pneumothorax images are of patients that have already been treated and who hence have a chest drain, a machine learning algorithm can learn to detect the presence of a drain and thus to correctly label the image, rather than learning to detect the condition itself directly (a similar issue in an anatomically different context was noted by Tun et al. [19]). This is illustrated in Fig. 1, which shows on the left the original image, with the drain tube indicated, and on the right the learnt class (pneumothorax) activation map. Another important observation is that an image can have more than one class label associated with it (e.g., both 'Pneumonia' and 'Infiltration' labels can be associated with the same X-ray image). Using the same loss function used to train the network, we can compute the mean model loss as a function of the number of labels, N, associated with an image (n.b. N ranges from 0 for healthy lungs up to 8, which is the maximum number of labels in this corpus). The loss increases at a linear rate with each additional label (see Table 3), suggesting that the number of labels does not affect the per-label accuracy.
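The per-image analysis behind Table 3 amounts to grouping the loss by label count; a minimal sketch follows, with assumed per-image loss values and multi-hot label vectors as inputs.

import collections

def mean_loss_by_label_count(per_image_losses, multi_hot_labels):
    # Group the per-image losses by the number of positive labels and average
    groups = collections.defaultdict(list)
    for loss, labels in zip(per_image_losses, multi_hot_labels):
        groups[int(sum(labels))].append(float(loss))
    return {n: sum(v) / len(v) for n, v in sorted(groups.items())}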
Fig. 1 Image labelled as ‘Pneumothorax’ after treatment with a drain tube
Table 3 Mean model loss dependency on the number of labels per image

N | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8
Loss | 0.055 | 0.206 | 0.346 | 0.491 | 0.647 | 0.827 | 0.956 | 1.134 | 1.353
Count | 9861 | 7992 | 5021 | 1958 | 572 | 152 | 31 | 8 | 1
Table 4 Mean activation of 'Consolidation' for single label images, across different ground truth target labels

Class | Mean activation
Atelectasis | 0.134
Cardiomegaly | 0.023
Consolidation | 0.084
Edema | 0.075
Effusion | 0.244
Emphysema | 0.011
Fibrosis | 0.006
Hernia | –
Infiltration | –
Mass | –
Nodule | –
Pleural thickening | –
Pneumonia | –
Pneumothorax | –

> 5.8 K/uL ≤4.6 M/uL > 19.8 mg/L > 6.0 K/uL ≤11.1% ≤4.6 M/uL > 129.6 mg/L ≤11.6 g/d > 6.6 K/uL ≤ 3.4 g/dL ≤8.3% ≤11.6 g/dL ≤3.9 M/uL > 7.1 ≤3.8 g/dL ≤17.5% > 16.1 mg/dL
107
40
4
32
19
5
6
CRP ALB AST %LYMPH WBC RBC – – – CRP WBC ALB WBC AST ALB – – – – – –
ICU admission > 32.4 mg/L ≤ 3.5 g/dL > 23.0 U/L ≤34.0% > 3.3 K/uL ≤ 4.6 M/uL – – – > 55.8 mg/L > 5.0 K/uL ≤3.5 g/dL > 4.1 K/uL > 16.0 U/L ≤3.5 g/dL – – – – – –
∗Abbreviations: N/A, no comorbidity; COPD, chronic obstructive pulmonary disease; ALB, albumin; AST, aspartate aminotransferase; BUN, blood urea nitrogen; CRP, C-reactive protein; HGB, hemoglobin; %LYMPH, lymphocyte percentage; %MONO, monocytes percentage; RBC, red blood cell; WBC, white blood cell
model without Transition FF, demonstrating that a combination of past and current observations improves prediction performance. Our model also demonstrated better robustness in prediction for data affected by the Delta variant; it had a relatively small decrease (9.9% to 13.5%) in performance on the post-Delta data relative to pre-Delta data, while most of the other methods (listed in Table 4) showed a large decline. Nevertheless, the contribution of the post-Delta data versus the pre-Delta data requires a deeper level of analysis than this paper provides. Specifically, we believe a molecular-level analysis is necessary [15, 16]. Summary. The success of our model is due to its ability to integrate past state/observations and current observations to make predictions. The decrease in prediction performance of the proposed model without the Transition FF demonstrates the importance of temporal information in prediction. The existing state-of-the-art methods do not take into account temporal variables and have lower prediction accuracy.
Table 4 Prediction performance. Changes in performance w.r.t. the pre-Delta period are provided in parentheses

Method           | Pre-Delta d=3    | Pre-Delta d=7    | Post-Delta d=3               | Post-Delta d=7
                 | AUC     ACC      | AUC     ACC      | AUC            ACC           | AUC            ACC
Proposed         | 0.81    0.74     | 0.82    0.75     | 0.73 (-9.9%)   0.64 (-13.5%) | 0.87 (+6.1%)   0.66 (-12.0%)
Proposed w/o Tr. | 0.76    0.68     | 0.69    0.75     | 0.67 (-11.8%)  0.62 (-8.8%)  | 0.67 (-2.9%)   0.62 (-17.3%)
LR               | 0.85    0.77     | 0.83    0.74     | 0.62 (-27.1%)  0.57 (-26.0%) | 0.57 (-31.3%)  0.57 (-23.0%)
SVM              | 0.86    0.80     | 0.82    0.72     | 0.66 (-23.3%)  0.58 (-27.5%) | 0.57 (-30.5%)  0.50 (-30.6%)
Random Forest    | 0.88    0.79     | 0.69    0.70     | 0.70 (-20.5%)  0.50 (-36.7%) | 0.73 (+5.8%)   0.57 (-18.6%)
Decision Tree    | 0.63    0.63     | 0.54    0.65     | 0.60 (-4.8%)   0.67 (+3.1%)  | 0.53 (-15.9%)  0.71 (+9.2%)

∗The best AUC/ACC in each column is in bold. Abbreviations: Tr., Transition factor function; AUC, area under the ROC curve; ACC, accuracy; LR, logistic regression; SVM, support vector machine
5 Limitations and Future Work

One limitation of our work is that we constructed the factor functions only from joint probabilities supported by statistical analysis and did not experiment with more sophisticated factor functions. For example, we could include domain knowledge from clinical experts to construct multivariate factor functions that better explain the relationships between the variables. Second, the proposed model only considers the current lab/vital measurements, not patterns/trends in past measurements, when making predictions. Future work will extend the model to capture temporal trends, i.e., the rate of change of the lab/vital measurements. In addition, performing hyperparameter tuning to emphasize different weights among the factor functions may improve the prediction performance and demonstrate the contributions of the different factors. The model will be implemented as a toolset to provide advice to triaging physicians. Subsequent work will incorporate the impact of vaccines and emerging mutations automatically as a learning paradigm.
6 Conclusion

We proposed a factor graph-based framework that predicts ICU admissions of hospitalized COVID-19 patients d days in advance. Our model demonstrates performance comparable to and better than the state-of-the-art machine learning methods on 3-day
and 7-day predictions, respectively. The relationships between comorbidities and labs/vitals captured by the model shed light on ICU admissions for COVID-19: greater severity of comorbidities introduces a higher risk of ICU admission. Most importantly, the model's prediction performance is robust on the post-Delta data.

Acknowledgements This work was partly supported by the Center for Computational Biotechnology and Genomic Medicine, Carle Foundation Hospital, and the Jump ARCHES endowment fund. We thank Yufu Zhang, Jorge Rodriguez Fernandez, and Jai Nebhrajani for providing and interpreting the data. We also thank our colleagues in the DEPEND group, particularly Krishnakant Saboo, Chang Hu, Anirudh Choudhary, Mosbah Aouad, Yixin Chen, Kathleen Atchley, and Jenny Applequist for their valuable feedback. The project was also supported by the National Center for Advancing Translational Sciences (NCATS), National Institutes of Health, through Grant Award Number UL1TR002003. The content is solely the responsibility of the authors and does not necessarily represent the official views of the NIH.
References 1. Vitiello, A., Ferrara, F., Troiano, V., & La Porta, R. (2021). COVID-19 vaccines and decreased transmission of SARS-CoV-2. Inflammopharmacology. https://doi.org/10.1007/s10787-02100847-2 2. Dougherty, K., Mannell, M., Naqvi, O., Matson, D., & Stone, J. (2021). SARS-CoV-2 B.1.617.2 (Delta) variant COVID-19 outbreak associated with a gymnastics facility. MMWR Morb Mortal Wkly Rep. https://doi.org/10.15585/mmwr.mm7028e2external. 3. Mlcochova, P., Kemp, S., Dhar, M., et al. (2020). SARS-CoV-2 B. 1.617.2 delta variant replication and immune evasion. Nature, 599, 114–119. 4. Galanter, W., Rodríguez-Fernández, J., Chow, K., et al. (2021). Predicting clinical outcomes among hospitalized COVID-19 patients using both local and published models. BMC Medical Informatics and Decision Making. https://doi.org/10.1186/s12911-021-01576-w 5. Zhang, J., Jun, T., Frank, J., et al. (2021). Prediction of individual COVID-19 diagnosis using baseline demographics and lab data. Scientific Reports. https://doi.org/10.1038/s41598-02193126-7 6. Jiang, X., Coffee, M., Bari, A., et al. (2020). Towards an artificial intelligence framework for data-driven prediction of coronavirus clinical severity. Computers, Materials and Continua. https://doi.org/10.32604/cmc.2020.010691 7. Iwendi, C., Bashir, A., Peshkar, A., et al. (2020). COVID-19 patient health prediction using boosted random forest algorithm. Frontiers in Public Health. https://doi.org/10.3389/fpubh. 2020.00357 8. Loeliger, H. (2004). An introduction to factor graphs. IEEE Signal Processing Magazine. https://doi.org/10.1109/MSP.2004.1267047 9. Varatharajah, Y., Chong, M., Saboo, K., et al. (2017). EEG-GRAPH: A factor-graph-based model for capturing spatial, temporal, and observational relationships in electroencephalograms. In: NeurIPS (pp. 5372–5381). 10. Yang, Y., Walter, L., Lu, L., et al. (2014). Forecasting potential diabetes complications. In: Proceedings of 28th AAAI Conference on Artificial Intelligence (pp. 313–319). 11. Cao, P., Badger, E., Kalbarczyk, Z., Iyer, R., & Slagell, A. (2015). Preemptive intrusion detection: Theoretical framework and real-world measurements. In: Proceedings of the 2015 Symposium and Bootcamp on the Science of Security. https://doi.org/10.1145/2746194.2746199. 12. Cao, P. (2019). On preempting advanced persistent threats Using probabilistic graphical models. arXiv:1903.08826 [cs.CR].
13. Yedidia, J., Freeman, W., & Weiss, Y. (2005). Constructing free-energy approximations and generalized belief propagation algorithms. IEEE Transactions on Information Theory. https:// doi.org/10.1109/TIT.2005.850085 14. Harrison, S., Fazio-Eynullayeva, E., Lane, D., et al. (2020). Comorbidities associated with mortality in 31,461 adults with COVID-19 in the United States: A federated electronic medical record analysis. PLOS Medicine. https://doi.org/10.1371/journal.pmed.1003321 15. McCallum, M., Walls, A., Sprouse, K., et al. (2021). Molecular basis of immune evasion by the Delta and Kappa SARS-CoV-2 variants. Science. https://doi.org/10.1126/science.abl8506 16. Ghosh, A., Kaiser, M., Molla, M., et al. (2021). Molecular and serological characterization of the SARS-CoV-2 Delta variant in Bangladesh in 2021. Viruses. https://doi.org/10.3390/ v13112310
Semantic Network Analysis of COVID-19 Vaccine Related Text from Reddit Chad A. Melton, Jintae Bae, Olufunto A. Olusanya, Jon Hael Brenas, Eun Kyong Shin, and Arash Shaban-Nejad
Abstract Vaccinations are critical and effective in resolving the current pandemic. With the highly transmissible and deadly SARS-CoV-2 virus, which causes COVID-19, delay in acceptance or refusal of vaccines despite the availability of vaccine services poses a significant public health threat. Moreover, vaccine-related hesitancy, mis/disinformation, and anti-vaccination discourse are hindering the rapid uptake of the COVID-19 vaccine. It is urgent to examine how anti-vaccine sentiment and behavior spread online and influence vaccine acceptance. Therefore, this study aimed to investigate COVID-19 vaccine hesitancy diffusion networks in online Reddit communities during the initial phase of the COVID-19 pandemic. We also sought to assess how the anti-vaccine discourse evolved in language content and style. Overall, our study findings could help facilitate and promote efficient messaging strategies/campaigns to improve vaccination rates.

Keywords Semantic network analysis · COVID-19 vaccines · Misinformation · Online social media · Reddit
C. A. Melton · A. Shaban-Nejad (B) The Bredesen Center for Interdisciplinary Research and Graduate Education, University of Tennessee, Knoxville, USA e-mail: [email protected] J. Bae · E. K. Shin Korea University, Seoul, South Korea e-mail: [email protected] E. K. Shin e-mail: [email protected] C. A. Melton · O. A. Olusanya · A. Shaban-Nejad Center for Biomedical Informatics, Department of Pediatrics, College of Medicine, University of Tennessee Health Science Center, Memphis, TN, USA e-mail: [email protected] J. H. Brenas Sanger Institute, Cambridge, UK e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_18
1 Introduction

The severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2) is responsible for the Coronavirus Disease 2019 (COVID-19), which has profoundly impacted the globe, causing significant morbidity and mortality [1]. An unprecedented and accelerated effort was made to develop the COVID-19 vaccine, since long-term pandemic containment and recovery were contingent upon the collective acceptance and uptake of the vaccine. Moreover, the COVID-19 vaccine is proven to be safe and effective at preventing life-threatening COVID-19 infections, hospitalizations, and deaths [2]. However, despite the many benefits, vaccine mis/disinformation, myths, false conspiracy theories, rumors, skepticism, and mistrust propagated through online digital platforms have driven vaccine hesitancy and compromised the effectiveness of science-based evidence on vaccines. As a result, vaccine hesitancy is classified by the World Health Organization as one of the ten threats to global health [3].

Digital platforms provide a heuristic locus for fully comprehending public discourse and sentiments regarding the COVID-19 pandemic [4–12]. Specifically, online digital data offers access to real-time information, which facilitates the collection, management, retrieval, mining, and interpretation of information to trace trends as well as to gauge public sentiments, networks, and behaviors influencing vaccine acceptance. Therefore, in order to develop efficient messaging strategies and intervention campaigns that address the information crisis and improve vaccination rates, it is essential to examine how multiple underlying factors (personal, religious, political, social, etc.) within the online interactive digital space influence vaccination decision making. However, only a few studies have sought to understand the impacts of social networks and interactions on epidemic-related discourse within online digital platforms. Herein, our study objective was to examine how the public's anti-vaccine sentiment and behavior influenced vaccine acceptance via an online social media platform. This work is part of an ongoing study to explore vaccine-related content in social media with a focus on identifying, characterizing, and combating misinformation/disinformation.
2 Data

We collected approximately 300,000 posts and comments from 12 online communities, or subreddits, on the social media platform Reddit. Subreddit members have the option to post links, images, videos, and text, which community users can upvote or downvote based on their opinion of the post and/or comment on. The upvote/downvote system within Reddit is intended to increase the quality of the posts and to minimize non-relevant material in a community. These communities often have rules that members must follow or risk the deletion of a post or a community ban. Though these community rules have the potential to create echo chambers, subreddits are typically less segregated than Facebook groups [13]. In
our dataset, a few subreddits actively remove posts containing misinformation (e.g., r/CovidVaccine) but must rely on other users to report the occurrence to a moderator, while other communities are more open to discussions involving misinformation and do not remove posts unless they are considered dangerous and are detected/reported by Reddit officials. In some cases, entire subreddits have been removed due to rampant misinformation/disinformation, harassment, and threats of violence (r/NoNewNormal). To avoid this echo chamber effect, we attempted to choose an ideologically heterogeneous collection of subreddits (r/CovidVaccinated, r/Coronavirus, r/CovidVaccine, r/conspiracy_commons, r/COVID19_support, r/COVID19, r/AntiVaxxers, r/VACCINES, r/conspiracytheories, r/China_Flu, r/COVID19positive, r/conspiracy) with a total of 5.1 million. Data were cleaned and queried for posts related to COVID-19 vaccines/vaccination. Our final amalgamated dataset consisted of 31,432 posts/comments authored by 20,429 users between January 1, 2020, and October 31, 2021. These posts/comments received a total of approximately 1.26 million votes, indicating a high degree of community interaction. The majority of subreddits were similar in posting frequency for the first months of our timeframe. Posting in several subreddits rapidly increased over time, with the greatest increase occurring in April 2021 as a result of more widespread vaccine availability (Fig. 1).
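As a sketch of the kind of vaccine-related query applied to the cleaned data, the snippet below keeps posts/comments matching vaccine-related keywords; the keyword list and the "body" field name are illustrative assumptions rather than the exact query used in the study.

    import re

    # Hypothetical keyword pattern; the actual query terms used in the study may differ
    VACCINE_PATTERN = re.compile(r"\b(vaccin\w*|vax\w*|pfizer|moderna|astrazeneca)\b", re.IGNORECASE)

    def filter_vaccine_related(posts):
        # posts: iterable of dicts with a "body" field holding the post/comment text
        return [p for p in posts if VACCINE_PATTERN.search(p.get("body", ""))]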
3 Methods

Our approach unveils the unbounded network structure of online vaccine hesitancy discourse. To accomplish this task, we conducted a temporal semantic network analysis to observe graph evolution over time. Network analysis [14, 15] uses a graphical representation of nodes and edges to provide insight into data that may not be observable on the surface. In computational semantic network analysis, nodes represent tokenized words, while edges represent the connectedness between nodes. Our multistep analysis was conducted with the Python library NetworkX [16, 17]. After removing stop words (e.g., determiners, conjunctions, and prepositions) and lemmatizing and normalizing the corpus (i.e., converting each word to its base form), a word co-occurrence network was created with nx.Graph(). Due to the hairball effect that occurred when including every word in a monthly corpus, the networks included in this document were limited to at most 160 nodes by setting an appropriate weight threshold, determined by iterating through the corpus. After removing isolates, betweenness centrality (BC) networks were created for each month in our dataset. Essentially, BC quantifies the importance of a node (i.e., word) by calculating the number of times the node is included in the shortest route between other nodes [18]. Ego networks were centered around the word "vaccine" and built for each month. To detect the tendency for clustering to occur within our data, transitivity was calculated for the complete corpus as well as for the scaled-down versions with limited node counts. Graph density and network/subreddit statistics were collected. Lastly,
Fig. 1 Reddit posting-frequency over time. Subreddits are symbolized by lines of varying color
networks were also visually inspected to verify coherence in our interpretation of the results.
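A minimal sketch of this pipeline using NetworkX is shown below; the toy corpus, the co-occurrence definition (within a single post), and the weight threshold are simplifying assumptions rather than the exact settings used in the study.

    import itertools
    import networkx as nx

    def build_cooccurrence_network(documents, weight_threshold=2):
        # documents: list of token lists, already stop-word-filtered and lemmatized
        G = nx.Graph()
        for tokens in documents:
            for u, v in itertools.combinations(sorted(set(tokens)), 2):
                w = G[u][v]["weight"] + 1 if G.has_edge(u, v) else 1
                G.add_edge(u, v, weight=w)
        # Apply the weight threshold and drop isolates to avoid the "hairball" effect
        weak = [(u, v) for u, v, d in G.edges(data=True) if d["weight"] < weight_threshold]
        G.remove_edges_from(weak)
        G.remove_nodes_from(list(nx.isolates(G)))
        return G

    monthly_corpus = [["vaccine", "fatigue", "arm"], ["vaccine", "flu", "shot"],
                      ["vaccine", "fatigue", "shot"]]          # toy example for one month
    G = build_cooccurrence_network(monthly_corpus, weight_threshold=1)
    bc = nx.betweenness_centrality(G)                          # importance of each word
    vaccine_ego = nx.ego_graph(G, "vaccine")                   # Vaccine Ego (VE) network
    print(nx.transitivity(G), nx.density(G))                   # clustering tendency and density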
4 Results

Overall, the Giant network node and edge counts generally increased with time, exhibiting fluctuations in certain months. The Vaccine Ego (VE) network behaved similarly, although its node and edge counts were much lower due to the nature of an ego network. Transitivity values for the complete corpus, or Giant network, decreased over time and ranged between 0.13 (December 2020) and 0.21 (January 2020), with an average value of 0.15. Greater node clustering occurred within this downtrend during August 2020, October 2020, April 2021, and July 2021. The VE network exhibited similar characteristics, with increased clustering in June 2020, October 2020, April 2021, and July 2021. Density for both the Giant and VE networks tended to ebb and flow from month to month but also decreased over time, due to an increasing post quantity (Figs. 2 and 3).
Fig. 2 Transitivity over time. The orange line represents the Vaccine Ego network and the blue line represents the Giant network transitivity
Fig. 3 Network density over time. The blue line represents the Vaccine Ego network density and the orange line represents the Giant network density
Fig. 4 Betweenness centrality network for September 2020. Nodes are indicated by orange circles and edges by lines. Node size reflects node weight and edge thickness indicates the interconnectedness between nodes. Note that the vaccine-centered cluster and the vitamin- and d-centered clusters are mainly connected by the Covid node in between the two clusters
For the betweenness centrality networks, centrality ranged from 0.21 for the node immunity to 0.892 for the node vaccine. Variations of the word "vaccine" (i.e., vaccine, vaccines) appeared among the top 10 highest BC values in every month of our dataset. Furthermore, we observe changes in the centrality values of nodes that represent terms common in COVID-19 misinformation. For example, in September 2020 (Fig. 4), the nodes vitamin and d are connected to a few terms related to COVID-19. Centrality for these nodes was calculated to be 0.16 and 0.20, respectively, and was amongst the top five for the month. As the occurrence of these two nodes diffused over time, the centrality values diminished substantially, to vitamin (0.0004725, rank 513/5894) and d (0.00055, rank 461/5894) in October 2020. However, d (0.001193, rank 253/9301) and vitamin (0.00168, rank 177/9301) both increased subsequently, and increased again in December 2020. Nodes indicating vaccine hesitancy were also observable in our networks, represented by example keywords such as scared or worried. However, visual inspection of some comments revealed the intent to vaccinate even though the user experienced anxiety related to the vaccine. For example, a comment from April 2021 said, "I am going to get my shot today! Half excited, half scared. Not scared from like conspiracy theory stuff lol, but I have had systemic allergic reactions before, so yeah a little nervous there."

The ego network focused on the node vaccine provides interesting insight into vaccine discussion directly related to COVID-19 related terms. In February 2020, the vaccine node is significantly connected to the nodes flu, people, and g (i.e., G-protein). The node coronavirus is present in this network; however, the connection to the
central node is so infrequent that edges are not visible (see Fig. 5a). Moving six months forward to August 2020, the flu node is still present but is far less prevalent than before, as terms related to COVID-19 (e.g., sars, covid, virus) have begun to dominate online discussion (see Fig. 5b). The network for October 2020 continues to display connectedness with terms related to COVID-19, along with other terms now typical of this pandemic (e.g., masks and deaths). This month also begins to display terms related to vaccine testing and manufacturers (e.g., Moderna, placebo, and efficacy) (see Fig. 5c). The greatest increase in node/edge quantity and the second largest spike in transitivity occurred in April 2021. With the successful rollout of COVID-19 vaccines, a multitude of terms appears related to vaccination, types of vaccines, most
Fig. 5 a (top left), b (top right), c (bottom left), d (bottom right): Vaccine Ego Network for February 2020, August 2020, October 2020, and April 2021. Nodes are indicated by gold circles and edges by lines. Node size reflects node weight and edge thickness indicates the interconnectedness between nodes
vaccine manufacturers, as well as a wide variety of side-effects ranging in severity (see Fig. 5d).
5 Discussion

As we expected, changes in our networks reflect the dynamic conditions and events that have occurred since the first COVID-19 cases were detected. Semantic as well as structural network changes are observable in the Giant and VE networks in several significant shifts as COVID-19 spread throughout the world. For example, in early January 2020, a small number of nodes representing words associated with COVID-19 are visible, but the Giant network does not display interconnectedness between the node Vaccine and other nodes representing COVID-19. These connections rapidly increase as infection rates climb and online discussion shifts towards vaccines for COVID-19. Nodes tend to reflect conversation regarding side effects (e.g., fever, sore arm, body aches) as vaccines become more readily available. In our dataset, the large increase in posts from r/CovidVaccinated in April 2021 contributes to the appearance and interconnectedness of vaccine side-effect nodes as well. Moreover, this occurrence was also expected based on previous topic modeling studies (Melton et al. 2021). Unfortunately, nodes representing misinformation keywords (e.g., Vitamin D, autism, Bill Gates, Big Pharma) become more apparent in several months as their interconnectedness with the node "Vaccines" increases in conjunction with COVID-19 keywords. Visual analysis of the raw text data suggested a wide range of vaccine hesitancy behaviors as well. These behaviors included hesitancy due to fear of vaccine side effects, feelings of "threatened freedoms", false expertise, ignorance of how vaccines work, "big pharma"-motivated pandemic conspiracy, anti-vaccination beliefs, and many others (see https://github.com/Cheltone/W3PHIAI2022 for supplementary materials including plots, data, and tables).

Our study has some limitations. Though great care was taken to create an unbiased dataset, the possibility of some potential biases still exists, including selection bias from our choice of subreddits. Obtaining purely unbiased data is a challenge in many scientific domains, and it is especially important for sentiment analysis, topic modeling, and semantic network modeling, because unsupervised learning methods may cluster topics in non-coherent ways. Results of unsupervised classification methods are often challenging to evaluate for similar reasons. Conducting a manual semantic analysis [19] of a sample of our dataset could offer further insights into the discourse occurring in social media. Moreover, utilizing graph mining algorithms or a comparison with topic modeling (Latent Dirichlet Allocation) techniques could bolster our results. Lastly, it is conceivable that the significant increase in posts in r/CovidVaccinated that occurred in April 2021 could have overwhelmed the network structure with nodes concerning vaccine side effects in the graphs from April 2021 through October 2021. A comparison with the dataset excluding r/CovidVaccinated could reveal discussions involving misinformation/disinformation in the other subreddits.
6 Conclusion

We conducted a betweenness centrality analysis of the Vaccine ego networks using approximately 31,000 comments/posts harvested from 12 subreddits. Our analysis found frequent mentions of COVID-19 and COVID-19 vaccine misinformation/disinformation, along with other vaccine-hesitant content. Ongoing work by our team is exploring other measures (e.g., semantic centrality, degree centrality, eigenvector centrality, PageRank) as well as tracing the diffusion of misinformation-specific nodes throughout our dataset. Future work will also explore the evolution of the semantic networks along with user activities. Because users from different thought communities contribute differently to semantic networks, it is crucial to understand both users and their activities. By tracing the activity logs of Reddit users along with their posted content, we expect to detect the diffusion pathways and their associated focal actors. We also plan to work on a user-subreddit bipartite network to examine the dynamics of vaccine discourse at a wider community level. These next steps will ultimately guide the development of a precision digital intervention tool that can target misinformation.
References 1. Ortiz-Prado, E., Simbaña-Rivera, K., Gómez-Barreno, L., Rubio-Neira, M., Guaman, L. P., Kyriakidis, N. C., & López-Cortés, A. (2020). Clinical, molecular, and epidemiological characterization of the SARS-CoV-2 virus and the Coronavirus Disease 2019 (COVID-19), a comprehensive literature review. Diagnostic Microbiology and Infectious Disease, 98(1), 115094. 2. Rosenberg, E. S., Holtgrave, D. R., Dorabawila, V., Conroy, M., Greene, D., Lutterloh, E., Backenson, B., Hoefer, D., Morne, J., Bauer, U., & Zucker, H. A. (2021). New COVID-19 cases and hospitalizations among adults, by vaccination status—New York, May 3–July 25, 2021. Morbidity and Mortality Weekly Report, 70(37), 1306. 3. World Health Organization (WHO). (2019). Ten threats to global health in 2019. https://www. who.int/news-room/spotlight/ten-threats-to-global-health-in-2019 4. Broniatowski, D. A., Paul, M. J., & Dredze, M. (2013). National and local influenza surveillance through Twitter: An analysis of the 2012–2013 influenza epidemic. PLoS ONE, 8(12), e83672. 5. Brownstein, J., & Freifeld, C. (2007). HealthMap: The development of automated real-time internet surveillance for epidemic intelligence. Eurosurveillance Weekly, 12(11), E071129. 6. Culotta, A. (2010). Towards detecting influenza epidemics by analyzing Twitter messages. Paper presented at the Proceedings of the first workshop on social media analytics. 7. Fung, I.C.-H., Tse, Z. T. H., & Fu, K.-W. (2015). The use of social media in public health surveillance. Western Pacific Surveillance and Response Journal, 6(2), 3–6. 8. Gu, H., Chen, B., Zhu, H., Jiang, T., Wang, X., Chen, L., & Jiang, J. (2014). Importance of Internet surveillance in public health emergency control and prevention: Evidence from a digital epidemiologic study during avian influenza A H7N9 outbreaks. Journal of medical Internet Research, 16(1), e20. 9. Mollema, L., Harmsen, I. A., Broekhuizen, E., Clijnk, R., De Melker, H., Paulussen, T., Das, E. (2015). Disease detection or public opinion reflection? Content analysis of tweets, other social media, and online newspapers during the measles outbreak in The Netherlands in 2013. Journal of medical Internet research, 17(5).
10. Salathé, M., Freifeld, C. C., Mekaru, S. R., Tomasulo, A. F., & Brownstein, J. S. (2013). Influenza A (H7N9) and the importance of digital epidemiology. The New England Journal of Medicine, 369(5), 401. 11. Shin, E. K., & Shaban-Nejad, A. (2017). Public Health Intelligence and the Internet: Current State of the Art. In A. Shaban-Nejad, J. S. Brownstein, & D. L. Buckeridge (Eds.), Public Health Intelligence and the Internet (pp. 1–17). Springer International Publishing. 12. Zhang, E. X., Yang, Y., Di Shang, R., Simons, J. J. P., Quek, B. K., Yin, X. F., & Ling, V. R. Y. (2015). Leveraging social networking sites for disease surveillance and public sensing: The case of the 2013 avian influenza A (H7N9) outbreak in China. Western Pacific Surveillance and Response Journal, 6(2), 66–72. 13. Cinelli, M., Morales, G. D. F., Galeazzi, A., Quattrociocchi, W., & Starnini, M. (2021). The echo chamber effect on social media. Proceedings of the National Academy of Sciences, 118(9). 14. Melton, C., Olusanya, O. A., & Shaban-Nejad, A. (2021). Network analysis of COVID19 vaccine misinformation on social media. Studies in Health Technology and Informatics, 18(287), 165–166. https://doi.org/10.3233/SHTI210839 PMID: 34795104. 15. Shin, E. K., & Shaban-Nejad, A. (2019). Applied network science for relational chronic disease surveillance. Studies in Health Technology Informatics, 4(262), 336–339. https://doi.org/10. 3233/SHTI190087 PMID: 31349336. 16. Developers, NetworkX. NetworkX documentation. (2012). 17. Srinivasa-Desikan, B. (2018). Natural language processing and computational linguistics: A practical guide to text analysis with Python, Gensim, spaCy, and Keras. Packt Publishing Ltd. 18. Linton, C. (1977). Freeman: A set of measures of centrality based on betweenness. Sociometry, 40(1), 35–41. 19. Brien, S, Naderi, N, Shaban-Nejad, A, Mondor, L, Buckeridge, D. L. (2013). Vaccine attitude surveillance using semantic analysis: Constructing a semantically annotated corpus. In WWW (Companion Volume) 2013, 13–17 May 2013 (pp. 683–686). Rio de Janeiro, Brazil: ACM Press. https://doi.org/10.1145/2487788.2488023
Towards Providing Clinical Insights on Long Covid from Twitter Data Rohan Bhambhoria, Jad Saab, Sara Uppal, Xin Li, Artur Yakimovich, Junaid Bhatti, Nirma Khatri Valdamudi, Diana Moyano, Michael Bales, Elham Dolatabadi, and Sedef Akinli Kocak
Abstract From the outset of the COVID-19 pandemic, social media has provided a platform for sharing and discussing experiences in real time. This rich source of information may also prove useful to researchers for uncovering evolving insights into post-acute sequelae of SARS-CoV-2 (PASC), commonly referred to as Long COVID. In order to leverage social media data, we propose using entity-extraction methods for providing clinical insights prior to defining subsequent downstream tasks. In this work, we address the gap between state-of-the-art entity recognition models and the extraction of clinically relevant entities which may be useful to provide explanations for gaining relevant insights from Twitter data. We then propose an approach to bridge the gap by utilizing existing configurable tools and datasets to enhance the capabilities of these models. Code for this work is available at: https://github.com/VectorInstitute/ProjectLongCovid-NER.

Keywords Interpretability · Entity extraction · Healthcare

R. Bhambhoria (B) Queen's University, Kingston, ON, Canada e-mail: [email protected] J. Saab Telus Communications Inc., Vancouver, BC, Canada e-mail: [email protected] S. Uppal Telus Communications Inc., Vancouver, BC, Canada e-mail: [email protected] X. Li University of Toronto, Toronto, ON, Canada e-mail: [email protected] Vector Institute, Toronto, ON, Canada A. Yakimovich Roche Products Ltd., Welwyn Garden City, UK e-mail: [email protected] J. Bhatti Manulife, Toronto, ON, Canada e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_19
1 Introduction

Since the emergence of the SARS-CoV-2 virus in late 2019, the resulting COVID-19 pandemic has brought many challenges for patients, healthcare professionals, and society. At the onset, one particular challenge was a lack of knowledge about both acute and long-term symptoms of infection, and about appropriate means of treating or managing them. Before research publications emerged, clinicians had to rely on anecdotal information to guide patient care decisions [27]. Post-acute sequelae of SARS-CoV-2, or Long COVID, describes a condition in which patients have symptoms persisting or recurring for weeks or months following acute COVID-19 infection [7]. Currently, the risk factors leading to the persistence, late onset, or recurrence of symptoms in previously infected COVID-19 patients are unclear [23]. Given that the SARS-CoV-2 virus is now widely believed to be endemic to the global population [15], it is important to understand the broader impact of the disease beyond acute infection, and the appropriate therapeutic approaches to optimally manage the various patient sub-populations and their disease profiles.

Social media is often turned to by patients as an outlet to express their experience with illness [21]. As such, the authors propose that social media may be a useful aid for researchers and clinicians in gaining novel insights into clinical characteristics arising due to emergent illnesses. In this work, we explore the utility of publicly available, self-reported, user-generated conversations on social media for capturing terms related to Long COVID symptoms, recoveries, and experiences. These terms can then guide investigation into clinical databases, thereby accelerating the discovery, study, description, and hopefully treatment of unexpected consequences of COVID-19. With a few adjustments, the same data processing pipeline may be extensible to other applications, such as emerging infectious diseases or rare and neglected diseases. Adapting and extending natural language processing (NLP) techniques to extract patient experiences, including symptom evolution and treatment approaches, from user-generated data can open doors to better research and characterize the Long COVID phenomenon (Fig. 1).

N. K. Valdamudi University of British Columbia, Vancouver, BC, Canada e-mail: [email protected]
D. Moyano Vector Institute, Toronto, ON, Canada e-mail: [email protected]
M. Bales Hoffmann-La Roche Ltd., Mississauga, ON, Canada e-mail: [email protected]
E. Dolatabadi Vector Institute, Toronto, ON, Canada e-mail: [email protected]
S. A. Kocak Vector Institute, Toronto, ON, Canada e-mail: [email protected]
Fig. 1 Our proposed framework for providing insights through entity extraction for Long COVID-related tasks
However, health language processing with social media is challenging due to the nature of user-generated data, such as informal language, short context, and content that is noisy and sparse [13]. In response, this study utilizes deep-learning-based entity extraction in order to facilitate decision-making of crucial importance to clinicians, whilst also being interpretable. In this study we also incorporate the Metathesaurus, a repository of inter-related biomedical concepts from the Unified Medical Language System (UMLS) [3], and our own rule-based matching systems built from publicly available datasets into our entity extraction pipeline to enhance its capability. In this work, our contributions are three-fold:
• First, we propose utilizing entity-extraction methods to provide insights into Long COVID experiences as expressed by patients.
• Secondly, we empirically show that the performance of state-of-the-art models trained on existing datasets falls short of human evaluation for relevant Long COVID-related terms. We then propose using MetaMapLite and existing datasets as a way of bridging this gap.
• Finally, we release a dataset and discuss methods which may be utilized in future experiments for augmentation, to enable improved extraction of entities from social media text using configurable tools and public datasets. We also release a human-annotated test set to serve as a benchmark for this task.
2 Related Work

2.1 Social Media Platform for COVID-19

Analysis of social media data rapidly became popular among researchers and epidemiologists for identifying and detecting outbreaks of infectious diseases and for interpreting public attitudes, behaviours, and perceptions. A systematic review by Tsao and colleagues [25] identified various categories of Twitter use for health research, such as surveillance and monitoring [9, 17] and disease control [28]. Among the selected studies in the review, Twitter was the leading social media platform used to explore
multiple facets of public health research. Mackey and colleagues [11] employed machine learning approaches to investigate COVID-19-related symptoms, experiences with access to testing, and mentions of disease recovery in 4.5 million tweets related to COVID-19. Another study, conducted by Guo and colleagues [8], was able to extract all the symptoms suggested by the Centers for Disease Control and Prevention (CDC) for COVID-19 screening in March, April, and May from 30,732 unique tweets. The findings of this study revealed that mining social media is a promising approach for identifying COVID-19-related symptoms earlier than their announcement by the CDC. Motivated by the crucial role of social media in public health surveillance, this study explores mining self-reported experiences from Twitter to identify symptoms related to "long-haulers," in response to the existing scientific and clinical knowledge gap and the many unknowns surrounding Long COVID [1].
2.2 Clinical Information Extraction

To date, various NLP models have been developed to automate the extraction of clinical information from both healthcare data and user-generated content. In this regard, several challenges and shared tasks have been organized to establish state-of-the-art benchmarks and advance exploration of important information extraction problems in NLP, such as entity extraction and normalization from clinical notes and social media posts. In recent years, the top performing models in these competitions used transformer architectures for downstream NLP tasks [2, 16, 24, 26]. For instance, in the SMM4H 2021 [12] shared tasks, which were organized to promote the use of social media such as Twitter data (tweets) for health applications, the percentage of teams that used either BERT or one of its variants, such as RoBERTa, was nearly 100% for most of the subtasks. The dramatic success of transformer models for rich contextual representations led to the creation of various domain-specific BERT models, such as COVID-Twitter-BERT (CT-BERT) [14], which was trained on a large corpus of tweets about the coronavirus. The model was shown to outperform the original BERT-LARGE model on five different Twitter-based datasets. Another variant of the BERT model, called UmlsBERT [10], incorporates the UMLS Metathesaurus in the pre-training phase in order to build 'semantically enriched' contextual representations that benefit from both contextual learning (the BERT architecture) and domain knowledge (the UMLS Metathesaurus). Given the promising performance of UmlsBERT on clinical NLP tasks, we extend its use to extract clinical terms from social media.
2.3 Interpretability

Several works explored interpretability using attention towards high-stakes decision making [4]. However, recent research sparked a debate on whether attention provides explanations that are faithful to the predictions of the
model [20, 29]. From [18], we have seen that other methods, such as post-hoc interpretability, do not provide explanations with high fidelity either. Towards this end, in this work we explore utilizing extracted entities, which can be regarded as providing explanations to clinicians for gaining insights into Long COVID. We are the first, to the best of our knowledge, to present an approach for enhancing state-of-the-art Named Entity Recognition (NER) models with the UMLS Metathesaurus and a benchmark social media dataset to improve entity extraction related to Long COVID from social media.
3 Dataset

In this section, we describe dataset acquisition, de-identification, and preparation, as well as the filtering of Long COVID self-reports.
3.1 Data Acquisition and De-identification

The dataset was acquired using academic access to the Twitter API v2. Using the API, we acquired 466,651 data points between August 2019 and June 2021. Each data point contains a tweet, a timestamp, geographical coordinates if available, the user's (user-defined) location, and the user's profile description. All tweets were retrieved using a list of hashtags (e.g., longcovid, covidlong, longhauler) and words included in tweets (e.g., long-hauler, chronic symptoms, long-term effects), among others. Some exclusion words related to collaborators' products were also part of the search criteria. Retweets, replies, quotes, and nullcast were also excluded from the dataset, as well as any tweets that were not in English. Moreover, apart from applying the inclusion and exclusion criteria, all usernames and mentions were hashed to pseudonymize the data. Following the de-identification step, all URLs and special characters were removed. We also transformed all text to lowercase and expanded contractions. The average character length in our clean dataset is 133, and the average word count is 23. The top 5 unigrams are long, haul, covid, term, and effects. The top 5 bigrams are long haul, long term, term effects, long covid, and covid 19. The top 5 trigrams are the long haul, for the long, long term effects, the long term, and a long haul.
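A simplified sketch of these de-identification and cleaning steps is shown below; the salt value, the small contraction table, and the exact regular expressions are illustrative assumptions rather than the pipeline's actual settings.

    import hashlib
    import re

    CONTRACTIONS = {"can't": "cannot", "i'm": "i am", "won't": "will not"}  # small illustrative subset

    def clean_tweet(text, salt="example-salt"):
        # Hash usernames/mentions to pseudonymize the data
        text = re.sub(r"@\w+",
                      lambda m: hashlib.sha256((salt + m.group()).encode()).hexdigest()[:10],
                      text)
        # Remove URLs, lowercase, expand contractions, strip special characters
        text = re.sub(r"https?://\S+", " ", text).lower()
        for short, full in CONTRACTIONS.items():
            text = text.replace(short, full)
        text = re.sub(r"[^a-z0-9#\s]", " ", text)
        return re.sub(r"\s+", " ", text).strip()

    print(clean_tweet("@someone I can't shake this #longcovid fatigue https://t.co/xyz"))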
3.2 Filtering Long COVID Self-Reports

We utilize regex filters based on personal pronouns (e.g., I, me, my, mine, myself) and expressions of feeling (e.g., feel, experience, symptoms) to capture self-reports and to exclude tweets from news outlets or other irrelevant discussions. Using this simple regex-based filtering method, we were able to identify self-reports with 78% accuracy on a manually curated dataset of 1000 identified tweets. Although this method cannot capture all self-reports, it is effective enough to collect a large number of relevant tweets for our analysis. As future work, the method used to identify self-reports could be replaced by a text classifier, given enough annotated data. Such a method could both improve the filtering accuracy and reduce the number of irrelevant tweets.
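The sketch below shows one way such a filter could be written; requiring both a first-person pronoun and a feeling expression, and the exact word lists, are assumptions rather than the study's precise rules.

    import re

    PRONOUNS = re.compile(r"\b(i|me|my|mine|myself)\b", re.IGNORECASE)
    FEELINGS = re.compile(r"\b(feel\w*|experienc\w*|symptom\w*)\b", re.IGNORECASE)

    def is_self_report(tweet: str) -> bool:
        # Heuristic: a self-report mentions a first-person pronoun and an expression of feeling
        return bool(PRONOUNS.search(tweet)) and bool(FEELINGS.search(tweet))

    assert is_self_report("I still feel exhausted three months after covid")
    assert not is_self_report("Health agency opens new long covid clinic")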
4 Methodology

We made use of the following transformer-based models to extract entities from the tweets, utilizing the n2c2 (2010) dataset [26] for fine-tuning:
• COVID-Twitter BERT v2 [14]: A BERT-large uncased model pretrained on 97M unique tweets. We utilize this model in our experiments because several COVID-related terms are contained within its training corpus, which may make it a strong base model for subsequent fine-tuning on entity-extraction tasks.
• UmlsBERT [10]: This model considers the connections between medical words through the use of Concept Unique Identifiers (CUIs) from the Unified Medical Language System (UMLS), which brings together a number of health and biomedical vocabularies to support interoperability [3].
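As a rough sketch of how entities can be extracted with such a fine-tuned token-classification model using the Hugging Face transformers library; the checkpoint path below is a placeholder, not an actual released model.

    from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

    # Placeholder path: a model fine-tuned for entity recognition on the n2c2 (2010) dataset
    model_name = "path/to/finetuned-umlsbert-n2c2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)

    ner = pipeline("token-classification", model=model, tokenizer=tokenizer,
                   aggregation_strategy="simple")  # merge word pieces into entity spans

    tweet = "i am fatigued from long covid and my heart rate is going crazy"
    for entity in ner(tweet):
        print(entity["entity_group"], entity["word"], round(entity["score"], 3))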
Fig. 2 Pipeline used to label entities with CUI codes. Original tweets are cleaned and filtered for self-reports; UMLS MetaMapLite is applied to the preprocessed tweets, and the AMIA Task 3 data (MedDRA codes) are mapped to CUI codes via the UMLS Metathesaurus; tweet entities are filtered by semantic type (e.g. 'sosy'), and the tweets are searched for AMIA terms not found by MetaMapLite, yielding tweets with labelled entities (CUI codes). Example: "I am fatigued from long covid and..." → CUI: C0015672, Preferred Name: Fatigue, Semantic Type: 'sosy'
In addition, we passed our corpus of tweets through UMLS MetaMapLite [6] to augment the entities extracted by the transformer models, as shown in Fig. 2. MetaMapLite is a tool that uses NLP and computational linguistic techniques to extract entities and associate them with UMLS Concept Unique Identifiers (CUIs). CUIs capture a wide range of clinical concepts that fall under various categories, or semantic types, such as signs and symptoms, or sosy (e.g., coughing), and disease or syndrome, or dsyn (e.g., Influenza). To compensate for MetaMapLite's limited coverage of colloquial expressions (e.g., brain fog), we introduced an additional approach, also shown in Fig. 2. We started with the AMIA Task 3 dataset, which consists of clinical concepts from tweets (e.g., symptoms, adverse drug reactions) and their corresponding, human-assigned Medical Dictionary for Regulatory Activities (MedDRA) codes [19]. MedDRA is a medical terminology dictionary that is frequently used by regulatory authorities and the biopharmaceutical industry. Using the UMLS Metathesaurus, a biomedical thesaurus that links synonymous names from over 200 source vocabularies, we first mapped these MedDRA codes to CUIs. Then, for each concept in the AMIA dataset, we searched our tweets for matches wherein (i) the AMIA concept appears in the text and (ii) it was not already captured by MetaMapLite.

For N tweets, the entities extracted by human analysis are given by {e1, e2, ..., en}, ∀ei ∈ E, and those extracted by the model by {e′1, e′2, ..., e′n}, ∀e′i ∈ E′. We calculate MatchCount and MatchCount′ as defined by Eqs. (1) and (2), where E and E′ are the sets of entities extracted by humans and by the model, respectively, with duplicate extracted entities removed. Results for 200 tweets provided to annotators through this process are shown in Table 3; 4 annotators were trained on this entity-extraction task.

MatchCount = (Σ_{n=1}^{N} |E|) / (Σ_{n=1}^{N} |E ∪ E′|)    (1)

MatchCount′ = (Σ_{n=1}^{N} |E′|) / (Σ_{n=1}^{N} |E ∪ E′|)    (2)
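A small sketch of this computation is given below; the interpretation of E and E′ as per-tweet, de-duplicated entity sets is an assumption.

    def match_counts(human_entities, model_entities):
        # human_entities, model_entities: lists (one element per tweet) of entity collections
        human_sets = [set(h) for h in human_entities]   # duplicates removed
        model_sets = [set(m) for m in model_entities]
        union_total = sum(len(h | m) for h, m in zip(human_sets, model_sets))
        match_count = sum(len(h) for h in human_sets) / union_total        # Eq. (1)
        match_count_prime = sum(len(m) for m in model_sets) / union_total  # Eq. (2)
        return match_count, match_count_prime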
5 Results and Discussion

From Table 1, we observe that the transformer-based models have picked up most of the same terms as the human annotators. Occasionally, these models extract non-clinical descriptions of symptoms, e.g., my heart rate instead of heart rate. At other times, the models fail to retrieve relevant terms; in example tweet 2, the model did not identify cannot taste or smell. State-of-the-art model architectures such as UmlsBERT, trained solely on datasets that capture "Problem", "Treatment", or "Test" entities, achieve strong performance on a test set from the same distribution (Table 2), but are unable to generalize well to unseen samples, as can be observed from the "Model Output" column in Table 1. Furthermore, model architectures such as CT-BERT show poor performance due to the domain shift between the pre-training and fine-tuning datasets.
Table 1 Human Evaluation results are randomly sampled examples from tweets annotated by 4 trained annotators. Model Output results are observed by using a fine-tuned UmlsBERT model on the n2c2 (2010) dataset. The rightmost column represents strings that match MetaMapLite's extractions and terms in the AMIA Task 3 dataset

Tweet 1: "Still recovering from #Covid, but fatigue and pots are persistent—I feel my heart rate going crazy. Tried resting, antivirals and some home remedies. Next week I'll get my second vaccine. #longcovid, wish me luck!"
  Human evaluation: #Covid, fatigue, pots, heart rate, resting, antivirals, vaccine, #longcovid
  Model output: Covid, fatigue, pots, my heart rate, resting, antivirals, home remedies, my second vaccine, longcovid
  MetaMap + AMIA output: pots, fatigue, heart rate, crazy

Tweet 2: "I cannot taste or smell my food, my energy levels are low, I can barely sleep and getting out of bed is hard. COVID sucks! #longhauler"
  Human evaluation: cannot taste or smell, low energy levels, barely sleep, getting out of bed is hard, COVID, #longhauler
  Model output: my energy levels
  MetaMap + AMIA output: bed, low, energy levels, sleep, taste

Tweet 3: "I got covid 3 months ago. Hard to concentrate the same way as before getting this virus... I'm having constant headaches and dealing with brain fog is exhausting. Anybody with the same symptoms? #LongCovid"
  Human evaluation: covid, hard to concentrate, constant headaches, brain fog, exhausting, #LongCovid
  Model output: this virus, constant headaches, brain fog, the same symptoms
  MetaMap + AMIA output: headaches, brain, exhausting, fog, brain fog

Tweet 4: "I am seeing an occupational therapist next week to get rid of my #longcovid symptoms. I hope I will climb the stairs without taking a breath halfway or using a steroid inhaler"
  Human evaluation: occupational therapist, #longcovid, climb the stairs without taking a breath halfway, using a steroid inhaler
  Model output: my #longcovid symptoms, a steroid inhaler
  MetaMap + AMIA output: –
Table 2 Results of transformer-based models after fine-tuning on the n2c2 (2010) dataset for entity recognition

Model    | P      | R      | F1
UmlsBERT | 0.8742 | 0.8947 | 0.8843
CT-BERT  | 0.7014 | 0.7375 | 0.7190
Table 3 Result of exact matches of the human evaluation on 200 tweets, identified by trained annotators, with the model output. Annotators #3 and #4 have medical backgrounds

Annotator | MatchCount | MatchCount′
1         | 0.4529     | 0.5823
2         | 0.6206     | 0.6400
3         | 0.4228     | 0.4176
4         | 0.5760     | 0.4440
The terms extracted by MetaMapLite, along with the terms in AMIA Task 3, provide a promising dataset for fine-tuning the transformer models. We hypothesize that this could produce a model that better captures the idiosyncrasies of social media text; however, in a preliminary experiment, the model performed poorly when trained only on MetaMapLite results of the "sosy" semantic type combined with the AMIA Task 3 data. This is likely a reflection of (i) sparse labelling, as MetaMapLite often captures fewer entities than are actually present, and (ii) incorrect labels, as there are frequent examples of non-clinical terms being captured. To improve performance, we would like to both expand the dataset for fine-tuning and improve its quality. This would benefit greatly from the input of clinical SMEs capable of informing the inclusion of additional semantic types and of providing guidelines on limiting inaccurate or erroneous MetaMapLite results.

Conclusion and Future Work

In this work, we propose utilizing entity extraction methods as a means to provide insights into patient self-reported experiences with post-acute sequelae of SARS-CoV-2 (PASC), or Long COVID. We evaluate the performance of (1) state-of-the-art transformer-based models fine-tuned on the n2c2 (2010) dataset and (2) data augmentation using entities extracted by MetaMapLite, alongside terms from the AMIA Task 3 dataset. We observe that, although this generally produces sensible results, it still falls short of human assessments of clinical entities, and it frequently misses key terms while embellishing others with superfluous text. Future work would further explore the use of MetaMapLite, and of datasets such as the one from AMIA Task 3, to annotate a dataset of tweets for fine-tuning the NER models while capturing the idiosyncrasies of social media text.
Ethical Considerations

A number of considerations are important, for example, whether the present work can be extended more broadly to monitor patient discussions of their experience with other illnesses. Obviously, the first challenge is whether the use of such data meets the ethical requirements [30]. In our study, we obtained an ethics opinion suggesting that de-identified, publicly available data can be analyzed for generating trends and insights if no individual data are shared and the risk of re-identification is extremely low. Based on this opinion, we implemented an anonymization step early in the dataset development process and restricted updating of the dataset made available for analyses. We expect that future studies will have to carve their own paths to secure ethics approval and implement the required actions. This aligns with the recommendations of Chiauzzi and Wicks [5] that similar data science projects need to assure the participants in their studies who have not formally consented that their anonymity will not be compromised and that the risk of harmful outcomes is rare. Nevertheless, Staccini and Lau highlight that patient acceptance of social media use for clinical trial surveillance could be favorable [22]. Their findings underline the need to seek social media users' opinions and perceptions regarding the use of their social media posts for studying disease characteristics and for surveillance purposes in public health.

Another important consideration is applying these approaches to other conditions. We understand that this may be an ambitious recommendation given that there is sparse labelling in the social media datasets. We overcame this challenge by casting a wide net of subject matter experts from academia and industry, with expertise spanning clinical medicine, clinical research, epidemiology, and pharmaceutical research. Combined, these individuals provided insights about data curation strategies including, but not limited to, the selection of filters, data pre-processing, labelling of limited data, and scaling of this labelling to large datasets. These experts also participated in evaluating the quality of the labelled data. Given our experience, involving a multidisciplinary group to extend these approaches to other conditions is feasible, given that patients are expressing their views on social media (e.g., regarding another emerging infection, a high-incidence cancer, or an unknown illness). Moreover, we suggest involving a group of patients to include patient-reported outcomes in real time and to further inform the keyword search strategy.

Acknowledgements The authors would like to thank the Vector Institute for making this collaboration possible and providing academic infrastructure and computing support during all phases of this work. We would also like to thank Antoaneta Vladimirova and Celine Leng from Roche, and Esmat Sahak from the University of Toronto, for their support throughout this project, as well as Dr. Angela Cheung from the University Health Network for her expertise and guidance.
References 1. Aucott, J. N., & Rebman, A. W. (2021). Long-haul covid: Heed the lessons from other infectiontriggered illnesses. The Lancet, 397(10278), 967–968. 2. Bhambhoria, R., et al. (2020). A smart system to generate and validate question answer pairs for covid-19 literature. In: Proceedings of the First Workshop on Scholarly Document Processing (pp. 20–30) 3. Bodenreider, O.: The unified medical language system (umls): Integrating biomedical terminology. Nucleic Acids Research, 32(Database issue), 267–270. https://doi.org/10.1093/nar/ gkh061. 4. C., E., et al.: RETAIN: interpretable predictive model in healthcare using reverse time attention mechanism. CoRR abs/1608.05745 (2016), http://arxiv.org/abs/1608.05745 5. Chiauzzi, E., & Wicks, P. (2019). Digital trespass: Ethical and terms-of-use violations by researchers accessing data from an online patient community. Journal of Medical Internet Research, 21(2), e11985. 6. Demner-Fushman, D., et al.: MetaMap Lite: An evaluation of a new Java implementation of MetaMap. Journal of the American Medical Informatics Association, 24(4), 841–844 (01 2017). https://doi.org/10.1093/jamia/ocw177 7. Domingo, F., et al. (2021) Prevalence of long-term effects in individuals diagnosed with covid19: A living systematic review. 8. Guo, J., et al. (2020). Mining twitter to explore the emergence of covid-19 symptoms. Public Health Nursing, 37(6), 934–940. 9. Jelodar, H., et al. (2020). Deep sentiment classification and topic discovery on novel coronavirus or covid-19 online discussions: Nlp using lstm recurrent neural network approach. IEEE Journal of Biomedical and Health Informatics, 24(10), 2733–2742. 10. M., G., et al.: Umlsbert: Clinical domain knowledge augmentation of contextual embeddings using the unified medical language system metathesaurus (2020) 11. Mackey, T., et al. (2020). Machine learning to detect self-reporting of symptoms, testing access, and recovery associated with covid-19 on twitter: retrospective big data infoveillance study. JMIR Public Health and Surveillance, 6(2), e19509. 12. Magge, A., Klein, A., Miranda-Escalada, A., Al-garadi, M.A., Alimova, I., Miftahutdinov, Z., Farre-Maduell, E., Lopez, S.L., Flores, I., O’Connor, K., Weissenbacher, D., Tutubalina, E., Sarker, A., Banda, J.M., Krallinger, M., & Gonzalez-Hernandez, G. (Eds.). (2021). In Proceedings of the Sixth Social Media Mining for Health (#SMM4H) Workshop and Shared Task. Association for Computational Linguistics, Mexico City, Mexico, June 2021. https:// aclanthology.org/2021.smm4h-1.0 13. Morgan, M., et al. (2014). Information extraction for social media. In: Proceedings of the Third Workshop on Semantic Web and Information Extraction (pp. 9–16). 14. Müller, M., et al. (2020). Covid-twitter-bert: A natural language processing model to analyse covid-19 content on twitter. arXiv:2005.07503 (2020) 15. Phillips, N. (2021). The coronavirus is here to stay-here’s what that means. Nature, 590(7846), 382–384. 16. Pradhan, S., et al. (2014). Semeval-2014 task 7: Analysis of clinical text. In: Proceedings of the 8th International Workshop on Semantic Evaluation (SemEval 2014). Citeseer 17. Qin, L., et al. (2020). Prediction of number of cases of 2019 novel coronavirus (covid-19) using social media search index. International Journal of Environmental Research and Public Health, 17(7), 2365. 18. Rudin, C.: Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead (2019) 19. 
Sarker, A., Gonzalez-Hernandez, G.: Overview of the second social media mining for health (smm4h) shared tasks at AMIA 2017. In: Proceedings of the 2nd Social Media Mining for Health Research and Applications Workshop co-located with the American Medical Informatics Association Annual Symposium (AMIA 2017). http://ceur-ws.org/Vol-1996/.
20. Sha, Y., Wang, M.D. (2017). Interpretable predictions of clinical outcomes with an attentionbased recurrent neural network. In: Proceedings of the 8th ACM International Conference on Bioinformatics, Computational Biology, and Health Informatics (pp. 233–240). 21. Smailhodzic, E., et al. (2016). Social media use in healthcare: A systematic review of effects on patients and on their relationship with healthcare professionals. BMC Health Services Research, 16(1), 1–14. 22. Staccini, P., Lau, A. Y., et al. (2020). Social media, research, and ethics: Does participant willingness matter? Yearbook of Medical Informatics, 29(01), 176–183. 23. Sudre, C., et al. (2021). Attributes and predictors of long covid. Nature Medicine, 27(4), 626– 631. 24. Suominen, H., et al. (2013). Overview of the share/clef ehealth evaluation lab 2013. In: International Conference of the Cross-Language Evaluation Forum for European Languages (pp. 212–231). Springer. 25. Tsao, S., et al. (2021). What social media told us in the time of covid-19: A scoping review. The Lancet Digital Health. 26. Uzuner, Ö., et al. (2011). 2010 i2b2/va challenge on concepts, assertions, and relations in clinical text. Journal of the American Medical Informatics Association, 18(5), 552–556. 27. Vijayan, T., et al. (2020). Trusting evidence over anecdote: Clinical decision making in the era of covid-19. BMJ. https://blogs.bmj.com/bmj/2020/07/23/trusting-evidence-over-anecdoteclinical-decision-making-in-the-era-of-covid-19/. 28. Wang, Y., et al. (2021). Examining risk and crisis communications of government agencies and stakeholders during early-stages of covid-19 on twitter. Computers in Human Behavior, 114, 106568. 29. Wiegreffe, S., Pinter, Y. (2019). Attention is not not explanation. arxiv:1908.04626. 30. Williams, M. L., Burnap, P., & Sloan, L. (2017). Towards an ethical framework for publishing twitter data in social research: Taking into account users’ views, online context and algorithmic estimation. Sociology, 51(6), 1149–1168.
Predicting Infections in the Covid-19 Pandemic—Lessons Learned Sharare Zehtabian, Siavash Khodadadeh, Damla Turgut, and Ladislau Bölöni
Abstract Throughout the Covid-19 pandemic, a significant amount of effort has been put into developing techniques that predict the number of infections under various assumptions about public policy and non-pharmaceutical interventions. While both the available data and the sophistication of the AI models and available computing power exceed what was available in previous years, the overall success of prediction approaches was very limited. In this paper, we start from prediction algorithms proposed for the XPrize Pandemic Response Challenge and consider several directions that might allow their improvement. Then, we investigate their performance over medium-term predictions extending over several months. We find that while augmenting the algorithms with additional information about the culture of the modeled region, incorporating traditional compartmental models, and using up-to-date deep learning architectures can improve the performance of short-term predictions, the accuracy of medium-term predictions is still very low, and a significant amount of future research is needed to make such models a reliable component of a public policy toolbox.
Keywords Prediction models · Covid-19 pandemic · Artificial intelligence · Deep learning
S. Zehtabian (B) · S. Khodadadeh · D. Turgut · L. Bölöni Department of Computer Science, University of Central Florida, Orlando, FL 32816, USA e-mail: [email protected] S. Khodadadeh e-mail: [email protected] D. Turgut e-mail: [email protected] L. Bölöni e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_20
1 Introduction There is no epidemiological event in the history of humanity that received the level of computational modeling and prediction effort as the Covid-19 pandemic. In contrast to previous outbreaks, humanity faced the pandemic not only with the scientific models honed over many previous outbreaks but also with new computational tools and a relatively well functioning international reporting system. The latter, at least in the developed world, allowed the accurate tracking of the infections, hospitalizations, and deaths on a day-by-day, region-by-region basis. Gene sequencing tools allowed the tracking of virus variants in each area. Finally, the pandemic outbreak occurred after the deep learning revolution in AI, opening the possibility to learn predictive models from the data as it becomes available. Yet, with all these apparent advantages, humanity’s ability to predict the spread of infections was little better than random. Repeatedly, public policy measures were based on wildly incorrect forecasts. Predicted waves did not materialize, while significant outbreaks happened in locations and at times that nobody predicted. As different variants of the virus had different levels of infectiousness, authorities were effectively dealing with several overlapping pandemics, connected by the partial immunity conferred by previous infections. The improvement of treatment procedures, the emergence of vaccines, and more recently, antiviral medications also changed the nature of the pandemic. The strong age-dependency of the medical outcomes meant that the pandemic waves were dependent on the age pyramid of the geographic area. Finally, countries and regions instituted different non-pharmaceutical interventions (NPIs) such as public health measures, lockdowns and masking recommendations. The nature and stringency of these measures depended not only on the evolution of the pandemic but also on local political events, popular opinion, and personalities with strong influence over the public narrative. Despite the weak results of the pandemic prediction approaches, continued research into the prediction models remains highly important for the remainder of the Covid-19 pandemic, as well as for future epidemics. Even just a better understanding of what is predictable and what is not would be helpful for policymakers of the future. The impetus for this paper was the Pandemic Response Challenge sponsored by Cognizant Inc. at the XPrize Foundation, launched on October 30, 2020. The two phases of the competition focused on prediction and prescription models respectively. For the prediction phase, the competition organizers implemented a framework in which the prediction models submitted by the contestants have access to the Oxford Covid-19 Government Response Tracker (OxCGRT) data [6]. For the first phase of the competition, a leaderboard was dynamically evaluated every day on up-to-date real-world data from Dec 22, 2020, to January 12, 2021. The final phase of the competition focused on choosing an optimal set of NPIs, which, while interesting from an algorithmic point of view, cannot be evaluated with real-world data. We participated in this challenge as the PawP (Pandemic Wave Predictor) team with an earlier version of the model referred to as LSTM-CultD-SIR in this paper. Our
highest ranking in the leaderboard was on January 14, 2021, when our team was 5th out of 336 registered teams. Since the closing of the competition, the pandemic went through several new phases. For instance, in the United States, the availability of vaccines lowered the number of cases in June 2021 to a point where the public perception considered the pandemic to be over. Then, the spread of the Delta variant led to a new wave of infections in Fall 2021, while the Omicron variant created an even higher wave in Winter 2021. However, the availability of treatments such as monoclonal antibodies lowered the mortality rate compared to previous phases of the pandemic. On the other hand, "pandemic fatigue", economic pressures and cultural backlash affected the way NPIs were chosen as well as the level of population compliance. The objective of this paper is to revisit, 11 months later, prediction models that were among some of the most accurate ones in January 2021, and investigate whether any useful lessons can be learned from the evolution of the pandemic since then. The overall questions we want to answer in this paper are the following: • How do prediction models that were successful in the early stage of the pandemic hold up at later stages, with or without retraining on new data? • The public press assigned significant importance to cultural factors in explaining the different evolution of the pandemic in specific countries or regions. Can we gain additional accuracy in short- and medium-term predictions by adding quantifiable descriptors of the culture of various countries and regions to the inputs of a model? • Would the most recent advances in AI models, such as the multi-head attention used in transformer architectures [16], improve the accuracy of the predictions?
2 Related Work The Covid-19 pandemic led to the initiation of a significant number of modeling and prediction projects in the academic community. Flaxman et al. [5] studied the effects of major non pharmaceutical interventions (NPIs) across 11 European countries and introduced a Bayesian model to estimate the epidemic. Dehning et al. [4] focused on short-term infection forecasting based on NPIs and studied how the interventions affect the epidemiological parameters. They combined a SIR model with Bayesian parameter inference to analyse the time-dependent spreading rate. Arık et al. [1] proposed an approach for modeling Covid-19 forecasts by integrating covariate encoding to compartmental models. They used the standard SEIR model but modeled more compartments such as: undocumented infected and recovered compartments, hospitalized, ICU and ventilator compartments, and partial immunity. Jin et al. [8] focused on a direct data-driven prediction model for predicting Covid-19 without using compartmental models. They developed a neural forecasting model called Attention Crossing Time Series (ACTS) that predicts cases by comparing patterns of time series from different regions. Xiao et al. [18] proposed a data-driven framework called C-Watcher to screen all the neighborhoods in a city and detect
the neighborhoods with the highest risk of Covid-19 spread before they contaminate other neighborhoods. Liao et al. [10] proposed a time-window based SIR prediction model and used a machine learning method to predict the basic reproduction number R0 and the exponential growth rate of the epidemic. Mehta et al. [12] focused on country level prediction of Covid-19 for the near future based on a combination of health statistics, demographics, and geographical features of counties. They used US Census data to obtain county-level population statistics for age, gender, and density. Watson et al. [17] proposed a Bayesian time series model that fits a location-specific curve to the velocity (first derivative) of the log transformed cumulative case count. Then, they use a random forest algorithm that learns the relationship between Covid-19 cases and population characteristics to predict deaths. Finally, they embed these models to a compartmental model which can provide projections for active cases and confirmed recoveries. Zou et al. [20] introduced the SuEIR model, a variant of the SEIR model, to predict confirmed and fatality cases, the peak date of active cases, and estimate the basic reproduction number (R0 ) in the United States. Their model considers additional information such as the untested/unreported cases of Covid-19. Qian and Alaa [14] focused on developing a model to learn the policies that affect the fatality rate of the Covid-19 in a global context. They used a Bayesian model with a two-layer Gaussian process (GP) prior. Sharma, Mindermann, Brauner et al. [15] investigated the robustness of the estimated effects of NPIs against Covid-19. Mastakouri and Schölkopf [11] studied a causal time series analysis of the Covid-19 spread in Germany to understand the causal role of the applied non pharmaceutical interventions (NPIs) in containing the spread among German regions. They used a causal feature selection method for time series with latent confounders called SyPI to analyse and detect the restriction policies that have a causal impact on the daily number of Covid-19 cases. Yeung et al. [19] combined NPIs and Hofstede cultural dimensions in predicting the infection rate for 114 countries. Particularly, they predict confirmed infection growth (CIG), which is defined as the 14-day growth in the cumulative number of reported infection cases. They used OxCGRT data of the NPIs and trained different non-time series models such as ridge regression, decision tree, random forest, AdaBoost, and support vector regression using mean squared error (MSE), and performed a grid search on the combination of these models. Johndrow et al. [9] built a model for Covid-19 transmission only by using the number of daily deaths, timing of containment measures, and information on the clinical progression of the disease. Bengio et al. [2] proposed a proactive contact tracing method. They embedded two neural networks namely Deep Sets and Set Transformers and evaluated the resulting models via the COVIsim testbed. Their methods are able to leverage weak signals and patterns in noisy, heterogeneous data to better estimate infectiousness compared to binary contact tracing and rule-based methods. Davis et al. [3] designed a stochastic age-structured transmission model to study different intervention scenarios and their impacts on the transmission of Covid-19 cases in the UK.
3 Learning-based Models for Predicting the Number of Infections

Traditional epidemiological models, such as the SIR compartmental model, aim to predict the evolution of the epidemic based on first principles and a relatively small number of human-understandable parameters. The input to these models is usually the current state of the pandemic, and they are calibrated by human experts based on the infectiousness of the disease. In contrast, learning-based models use significantly more complex models with a very large number of parameters, such as deep neural networks. These parameters are not individually human-interpretable, and the only realistic way to acquire them is through learning. Often, these models look at the pandemic as a function unfolding in time; thus they either take as input a sliding window of the recent history of the pandemic at every timestep, or maintain an internal memory of it. The official examples and the majority of entries to the Pandemic Response Challenge were such learning-based systems (although it is difficult to know how much human-expert knowledge was individually incorporated). In the remainder of this section, we will discuss four possible models (a model presented as the "official" one in the competition and three models designed by our team), presenting their architecture and design rationale in a comparable way (see Fig. 1). The models assume that we have access to two streams of daily data, both of them extractable from the Oxford Covid-19 Government Response Tracker (OxCGRT) data [6]. The context stream describes the basic context in which the prediction needs to be made, for instance, the total number of people already infected as well as specific identifiers for the country and geographical region. The context columns include, for each region, information such as country name, region name, geo ID, date, confirmed cases, confirmed deaths, and population. The action stream, in contrast, describes the specific actions, typically non-pharmaceutical interventions, that authorities enacted in a given area. The action stream includes 12 Non-pharmaceutical Intervention (NPI) columns: school closing, workplace closing, cancel public events, restrictions on gatherings, close public transport, stay at home requirements, restrictions on internal movement, international travel controls, public information campaigns, testing policy, contact tracing, and facial coverings. We ignored countries or regions for which no numbers of cases or deaths are available and filled empty or missing values in the NPI columns with 0 for each region or country. Finally, by having access to the country and geographical region, it is possible to extend the context and action streams with other information that can be looked up from other databases or web services. For instance, should we want to investigate the hypothesis that the weather affects the spread of an infection, this information could be brought in by correlating the geographical identifier with an external weather service. Having access to the stream of information contained in the context and the action stream and whatever auxiliary data the system might choose to look up, the objective of the predictive model is to predict the number of infections in the next
Fig. 1 The architecture of the compared models: LSTM-Baseline (top-left), LSTM-UT-Cogn (bottom-left), LSTM-CultD-SIR (top-right), and TRANSENC-CultD-SIR (bottom-right)
day, and through extrapolation, for a larger period into the future. To describe the fields of the data streams, we will use the following notation: for a region r with population P^r, we refer to the values of the NPI columns for day t by NPI_t^r, which is a vector of length 12. Each element i is an integer between 0 and NPI_MAX_i. We denote the number of Covid-19 cases at day t by nC_t^r.

LSTM-UT-Cogn: The first model we are describing [13] was developed by the organizers of the Pandemic Response Challenge and provided to the teams that qualified for the finals of the competition to serve as a metric for prescriptive measures. Although this model did not directly compete in the challenge, it was clearly seen as a state-of-the-art model at that point in the pandemic. The prediction architecture, shown in Fig. 1-bottom-left, is unusual in that it uses two separate branches for the context and the action data, with the predicted value being the proportion of newly infected people among the population that is currently not infected (naturally, the absolute number of infected people can be calculated from this value). The input of the first branch is the infection ratio from the context stream, processed by an LSTM layer followed by a dense layer with one node and softplus activation. In the second branch, the model takes as input the NPIs from the action stream, which are processed by an LSTM followed by a dense layer with a single node and sigmoid activation. The outputs of the context branch h and the action branch g are combined using a lambda layer implementing the formula (1 − g) × h to produce the output of the model. The model was trained on sequences of length 21 days.

LSTM-Baseline: The remaining three models we are considering were developed by our team. The simplest, baseline model also uses an LSTM network, which in recent years had become the most popular way in the deep learning community to process time series data that is presented to the model one step at a time. In this simplest model (Fig. 1-top-left), we investigate the hypothesis that the LSTM network can learn how to select the important information from the combined context and action stream without any further input from the modeler. Before feeding our data to the LSTM, we have to preprocess it so that the values are normalized. To achieve this, we use the "Infection Ratio" column, which is computed as follows. First, we compute the infected proportion for region r and day t by dividing the number of cases by the population: a_t^r = nC_t^r / P^r. Then, the smoothed version of a_t^r is computed for each day as the average of these values over a 7-day time window. Next, we compute the percent change ∇_{t+1}^r = (a_{t+1}^r − a_t^r) / a_t^r. The model processes streams of data of width 13, with the first column being the value of the infection ratio and the remaining columns representing the NPIs. This input goes into an LSTM component with 64 nodes, whose output is processed by a dense layer with a single neuron that outputs the "Infection Ratio" for the next day. We use the L1 loss between the output of the model and the real value of the "Infection Ratio" to train the network. During inference, we only have access to the NPI columns for each day in the future and use the prediction of the network as the "Infection Ratio" input for the following days. We also clip the network's output between 0 and 2 during inference to make sure that the outputs do not diverge; this is especially important when we use the model for longer predictions.
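For illustration, a minimal sketch of this "Infection Ratio" preprocessing is given below, assuming a pandas DataFrame with per-region cumulative case and population columns; the column names are illustrative, not the exact OxCGRT field names.

```python
import pandas as pd

def add_infection_ratio(df: pd.DataFrame) -> pd.DataFrame:
    """Compute the normalized 'Infection Ratio' used as the LSTM input.

    Assumes df is sorted by date within each region and has 'GeoID',
    'ConfirmedCases' and 'Population' columns (illustrative names).
    """
    out = df.copy()
    # Daily new cases nC_t^r from the cumulative count, per region.
    out["NewCases"] = (
        out.groupby("GeoID")["ConfirmedCases"].diff().fillna(0).clip(lower=0)
    )
    # Infected proportion a_t^r = nC_t^r / P^r.
    out["CaseRatio"] = out["NewCases"] / out["Population"]
    # 7-day moving average to smooth reporting artifacts.
    out["SmoothedRatio"] = out.groupby("GeoID")["CaseRatio"].transform(
        lambda s: s.rolling(7, min_periods=1).mean()
    )
    # Percent change between consecutive smoothed values (the "Infection Ratio").
    out["InfectionRatio"] = (
        out.groupby("GeoID")["SmoothedRatio"].pct_change().fillna(0)
    )
    return out
```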
Taking into account culture: It had been an important part of the public narrative of the pandemic that various aspects of the interventions, such as mask wearing, refraining from large gatherings, adherence to social distancing rules, and vaccinations, are culture-dependent. Unfortunately, quantifying the aspects of culture relevant to the pandemic is not easy. Furthermore, similar cultures can accommodate very different public policies, as illustrated by the Scandinavian countries, where culturally similar countries like Sweden and Norway chose to adopt different policy approaches. Nevertheless, the hypothesis that taking cultural aspects into consideration can improve the prediction accuracy is definitely worth considering. A problem in implementing such a system is that in the social sciences, culture is often discussed in qualitative, narrative form. There are relatively few examples of quantitative models of human interactions. One of the efforts that assigned numerical values to aspects of the culture of various nations is the cultural dimensions model [7], which attempted to quantify national culture along six numerical dimensions, with public databases available with approximations of these values at a nation-state level. Admittedly, this model received significant criticism over the years, among other things for the choice of the nation as the resolution of the model. For instance, the model does not differentiate between California and Alabama in the United States. However, to the best of our knowledge, this is the only culture quantification model for which public databases exist for the majority of regions. We note that we do not make any assumptions about the impact of the cultural dimension values on our prediction; we add these values to the system and allow it to learn their possible impact.

Adding compartmental models: Another direction for improving learning-based models of pandemic prediction is to incorporate in their architecture the foundations of established models of epidemiology, for instance the SIR family of compartmental models. These have a comparatively small number of parameters compared to neural networks; furthermore, these parameters are often human-interpretable. The models encapsulate pre-existing knowledge about the dynamics of epidemics, which, thus, does not need to be learned from scratch. The SIR model assumes that a fixed-size population P contains members that can be in one of three states: Susceptible (S), Infected (I), and Recovered (R). We add these columns to the data for each country using the following equations:

S_t = S_{t−1} − newCases_t,   (1)

where S_0 = Population,

I_t = I_{t−1} − (1/d) × I_{t−1} + newCases_t − newDeaths_t,   (2)

where 1/d is the daily recovery rate and d is the average number of days required for recovery, and

R_t = Population − S_t − I_t.   (3)

The transition between these states can be modeled by parameters α and β as described below:

S′ = −α × S × I,   (4)
I′ = α × S × I − β × I,   (5)
R′ = β × I,   (6)

where S′, I′, and R′ are the rates of change of S, I, and R, respectively.
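For illustration, the discretization in Eqs. (1)–(3) can be sketched as follows for a single region; the variable names and per-region data layout are assumptions, not the exact implementation.

```python
import numpy as np

def sir_columns(new_cases, new_deaths, population, d=14):
    """Build the S_t, I_t, R_t columns per Eqs. (1)-(3), returned as fractions.

    new_cases, new_deaths: 1-D arrays of daily counts for one region.
    d: average number of days needed to recover (1/d is the daily recovery rate).
    """
    T = len(new_cases)
    S = np.empty(T)
    I = np.empty(T)
    R = np.empty(T)
    S_prev, I_prev = float(population), 0.0   # S_0 = Population, I_0 = R_0 = 0
    for t in range(T):
        S[t] = S_prev - new_cases[t]                                   # Eq. (1)
        I[t] = I_prev - I_prev / d + new_cases[t] - new_deaths[t]      # Eq. (2)
        R[t] = population - S[t] - I[t]                                # Eq. (3)
        S_prev, I_prev = S[t], I[t]
    # Fractions S_p, I_p, R_p fed to the network alongside the infection ratio.
    return S / population, I / population, R / population
```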
We train the networks to take as input the data from the last T days and predict the value of S_p, I_p, and R_p, which are the susceptible, infected, and recovered fractions of the population for the next day(s). The model looks at the NPIs and all other data and outputs S_p, I_p, and R_p.

LSTM-based predictor using cultural dimensions and the SIR model (LSTM-CultD-SIR): In this model, shown in Fig. 1-top-right, we use a compartmental model to create new columns for the susceptible fraction of the population (S_p), the infected fraction of the population (I_p), and the recovered fraction of the population (R_p). We initialize S_p with 1, and I_p and R_p with 0. Then, we calculate these values for the next rows based on Eqs. (1), (2), and (3), normalized by population for each country or region. We use 14 as the average number of days required for recovery to compute the recovery rate in Eq. (2). We use these columns alongside the infection ratio column as the context input. We concatenate the context input and the action input (NPI columns) of the 21 previous days and feed it to an LSTM layer. Then, we add the cultural features of the Hofstede dimensions as constant features to the output of the LSTM layer and feed them to a dense layer with 4 nodes. The model is trained using the Adam optimizer with the mean absolute error as the loss, and outputs the infection ratio, S_p, I_p, and R_p.

Transformer encoder based predictor using cultural dimensions and the SIR model (TRANSENC-CultD-SIR): This model is similar to LSTM-CultD-SIR, but we use a transformer encoder layer instead of the LSTM layer (see Fig. 1-bottom-right). The difference is that this model can read the whole sequence at once, while the LSTM-CultD-SIR model reads the sequence one step at a time. Transformers were first introduced in [16] and are an architecture that, instead of using recurrent networks, uses an attention mechanism to learn relations between input and output. The main advantage of transformers is their ability to see the sequence of data in parallel and to learn very long-term interactions. We propose using the multi-head attention module from the transformer architecture to train the predictor model. To the best of our knowledge, we are the first to use attention models on NPI features and to combine the attention model's output with the cultural dimensions and the SIR model for the prediction of new Covid-19 cases.
The transformer layer includes an attention-and-normalization part and a feed-forward part. The attention-and-normalization part includes a multi-head attention layer, a dropout layer, and a normalization layer. The feed-forward part is a sequence of two dense layers, one with ReLU activation and the other with linear activation, followed by a dropout layer and a normalization layer.
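A minimal PyTorch sketch of such a layer is shown below; it adds residual connections as in the standard Transformer, and the layer sizes are illustrative placeholders rather than the exact values used in our experiments.

```python
import torch.nn as nn

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=64, n_heads=4, d_ff=128, dropout=0.1):
        super().__init__()
        # Attention-and-normalization part: multi-head attention, dropout, layer norm.
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.drop1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        # Feed-forward part: dense + ReLU, dense with linear activation, dropout, norm.
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )
        self.drop2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: (batch, sequence_length, d_model) -- e.g. 21 days of context/action features.
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop1(attn_out))
        x = self.norm2(x + self.drop2(self.ff(x)))
        return x
```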
4 Experiments

To evaluate the predictive accuracy of the model, we need a metric that smooths out daily variations and is comparable across regions with different populations. Thus, for a region r with a population P_r, we will use the average number of cases per 100k people over a span of 7 days:

Cumul-7DMA-MAE-per-100K(r) = \sum_{d \in D} |\bar{y}_d - \bar{\hat{y}}_d| \times 100000 / P_r,   (7)

where \bar{x} is the 7-day moving average of x and P_r denotes the population of region r.
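A minimal sketch of this metric for a single region, assuming arrays of true and predicted daily new cases over the evaluation period:

```python
import numpy as np

def cumul_7dma_mae_per_100k(y_true, y_pred, population):
    """Cumulative 7-day moving-average MAE per 100k people, as in Eq. (7)."""
    def moving_avg(x, w=7):
        return np.convolve(np.asarray(x, dtype=float), np.ones(w) / w, mode="valid")
    y_bar = moving_avg(y_true)        # 7-day moving average of the true cases
    yhat_bar = moving_avg(y_pred)     # 7-day moving average of the predictions
    return np.sum(np.abs(y_bar - yhat_bar)) * 100_000 / population
```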
We trained the four models separately on a few months of data from all the countries and regions. For every model, we use 1000 epochs for training with early stopping with patience 20, restoring the best weights. We split the training data into training and validation sets with a 90/10% ratio, with a batch size of 32. We used a learning rate of 0.001. We ran two different experiments, with the training data and test data chosen from different calendar months. Experiment E2020 used data from January 2020 to July 2020 for training and data from August 2020 to the end of December 2020 for evaluation. Experiment E2021 used data from the whole year of 2020 and made predictions for January to April 2021. See Fig. 2.
Fig. 2 Average of 7-day predicted daily new cases over all countries using our predictors, LSTM-CultD-SIR and TRANSENC-CultD-SIR, and two baselines, LSTM-Baseline and LSTM-UT-Cogn. Left: E2020, Right: E2021
Fig. 3 Cumulative 7-day mean absolute error per 100 k for each prediction approach. Left: E2020. Right: E2021
Fig. 4 Color scaled cumulative 7-day mean absolute error per 100 k per country or region based on each prediction approach. Green color shows zero to 2 k and red color shows 8 k or more cumulative 7-day mean absolute error per 100 k
Based on our experiments, our TRANSENC-CultD-SIR approach had the lowest cumulative mean absolute error per 100k over 7 days. See Fig. 3 for the overall cumulative 7-day moving average mean absolute error per 100k for both experiments. Also, Fig. 4 shows the color-scaled version of this metric for all the countries and regions around the world for the E2020 experiment. The green nodes show the regions or countries with the lowest error (0 to 2k per 100k), and the red nodes show an error of 8k or more per 100k. We find that the TRANSENC-CultD-SIR approach has the lowest number of red-orange nodes, which means that it performs better than the other approaches.
5 Conclusion

In this paper, we described the design of several pandemic prediction models and compared them with each other. As a comparison point, we used the model that was used as the official predictor for the finals of the XPrize Pandemic Response Challenge. The models introduced in this paper build on and improve our submission to this competition. By testing the models on data that extends several months after the
competition, we can make several observations that can serve as lessons for modeling approaches in the future. First, models that are finely tuned to predict accurately over spans of days and weeks can diverge significantly over the span of months. Second, sophisticated machine learning models, such as transformer-style multi-head attention replacing LSTMs, can produce an incremental improvement when everything else is equal, but they do not make a decisive difference in prediction accuracy. Third, while the canonical models of prediction, such as the SIR compartmental model, cannot by themselves provide an accurate prediction, they can serve a useful role in preventing runaway errors in the models. Finally, while cultural factors are clearly influencing the evolution of the pandemic, we do not yet have a method to incorporate this information in a rigorous manner.
Acknowledgements This work was partially supported by the National Science Foundation under Award No. 1800961.
References 1. Arik, S. O., Li, C. L., Yoon, J., Sinha, R., Epshteyn, A., Le, L. ., Menon, V., Singh, S., Zhang, L., Yoder, N., Nikoltchev, M., Sonthalia, Y., Nakhost, H., Kanal, E., & Pfister, T. (2020). Interpretable sequence learning for COVID-19 forecasting. In: Advances in Neural Information Processing Systems (NeurIPS) 2. Bengio, Y., Gupta, P., Maharaj, T., Rahaman, N., Weiss, M., Deleu, T., Muller, E.B., Qu, M., St-charles, P. l., Bilaniuk, O., et al.: Predicting infectiousness for proactive contact tracing. In: Proceedings of International Conference on Learning Representations (ICLR-2020) 3. Davies, N. G., Kucharski, A. J., Eggo, R. M., Gimma, A., Edmunds, W. J., Jombart, T., et al. (2020). Effects of non-pharmaceutical interventions on COVID-19 cases, deaths, and demand for hospital services in the UK: A modelling study. The Lancet Public Health, 5(7), e375–e385. 4. Dehning, J., Zierenberg, J., Spitzner, F.P., Wibral, M., Neto, J.P., Wilczek, M., & Priesemann, V. (2020). Inferring COVID-19 spreading rates and potential change points for case number forecasts. https://doi.org/10.1101/2020.04.02.20050922 5. Flaxman, S., Mishra, S., Gandy, A., Unwin, H. J. T., Mellan, T. A., Coupland, H., Whittaker, C., Zhu, H., Berah, T., Eaton, J. W., Monod, M., Imperial College COVID-19 Response Team, Ghani, A. C., Donnelly, C. A., Riley, S., Vollmer, M. A. C., Ferguson, N. M., Okell, L. C., Bhatt, S. (2020). Estimating the effects of non-pharmaceutical interventions on COVID-19 in Europe. Nature, 584(7820), 257–261. 6. Hale, T., Angrist, N., Goldszmidt, R., Kira, B., Petherick, A., Phillips, T., Webster, S., CameronBlake, E., Hallas, L., Majumdar, S., & Tatlow, H. (2021). A global panel database of pandemic policies (Oxford COVID-19 government response tracker). Nature Human Behaviour, 5(4), 529–538. 7. Hofstede, G. (1984). Culture’s consequences: International differences in work-related values. Journal of Service Research, 5. 8. Jin, X., Wang, Y. X., Yan, X. (2021). Inter-series attention model for COVID-19 forecasting. In: Proceedings of the 2021 SIAM International Conference on Data Mining (SDM-2021). 9. Johndrow, J., Ball, P., Gargiulo, M., & Lum, K. (2020). Estimating the number of SARS-CoV-2 infections and the impact of mitigation policies in the United States. Harvard Data Science Review 10. Liao, Z., Lan, P., Liao, Z., Zhang, Y., & Liu, S. (2020). TW-SIR: Time-window based SIR for COVID-19 forecasts. Scientific Reports, 10(1), 1–15.
11. Mastakouri, A. A., Schölkopf, B. (2020). Causal analysis of Covid-19 spread in Germany. In: Advances in Neural Information Processing Systems (NeurIPS). 12. Mehta, M., Julaiti, J., Griffin, P., & Kumara, S. (2020). Early stage machine learning-based prediction of US county vulnerability to the COVID-19 pandemic: Machine learning approach. JMIR Public Health and Surveillance, 6(3), e19446. 13. Miikkulainen, R., Francon, O., Meyerson, E., Qiu, X., Sargent, D., Canzani, E., & Hodjat, B. (2021). From prediction to prescription: Evolutionary optimization of nonpharmaceutical interventions in the COVID-19 pandemic. IEEE Transactions on Evolutionary Computation, 25(2), 386–401. 14. Qian, Z., Alaa, A. M., van der Schaar, M. (2020). When and how to lift the lockdown? Global COVID-19 scenario analysis and policy assessment using compartmental Gaussian Processes. In: Advances in Neural Information Processing Systems (NeurIPS). 15. Sharma, M., Mindermann, S., Brauner, J. M., Leech, G., Stephenson, A. B., Gavenˇciak, T., Kulveit, J., Teh, Y. W., Chindelevitch, L., & Gal, Y. (2020). How robust are the estimated effects of nonpharmaceutical interventions against COVID-19? In: Advances in Neural Information Processing Systems (NeurIPS) 16. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. In: Advances in Neural Information Processing Systems (NeurIPS) 17. Watson, G. L., Xiong, D., Zhang, L., Zoller, J. A., Shamshoian, J., Sundin, P., et al. (2021). Pandemic velocity: Forecasting COVID-19 in the US with a machine learning & Bayesian time series compartmental model. PLoS Computational Biology, 17(3), e1008837. 18. Xiao, C., Zhou, J., Huang, J., Zhuo, A., Liu, J., Xiong, H., & Dou, D. (2020). C-Watcher: A framework for early detection of high-risk neighborhoods ahead of COVID-19 outbreak. arXiv:2012.12169 19. Yeung, A. Y., Roewer-Despres, F., Rosella, L., & Rudzicz, F. (2021). Machine learning-based prediction of growth in confirmed COVID-19 infection cases in 114 countries using metrics of nonpharmaceutical interventions and cultural dimensions: Model development and validation. Journal of Medical Internet Research, 23(4), e26628. 20. Zou, D., Wang, L., Xu, P., Chen, J., Zhang, W., & Gu, Q. (2020). Epidemic model guided machine learning for COVID-19 forecasts in the United States. https://doi.org/10.1101/2020. 05.24.20111989
Improving Radiology Report Generation with Adaptive Attention Lin Wang and Jie Chen
Abstract To avoid the tedious and laborious writing of radiology reports, automatic radiology report generation has drawn great attention in recent years. As a vision-to-language task, visual features and language features are equally important for radiology report generation. However, previous methods mainly pay attention to generating fluent reports, neglecting the crucial question of how to better extract and utilize visual information. Keeping this in mind, we propose a novel architecture with a CLIP-based visual extractor and a Multi-Head Adaptive Attention (MHAA) module to address these two issues: through the vision-language pretrained encoder, richer visual information is extracted, and during report generation, MHAA controls how much visual information participates in the generation of each word. Experiments conducted on two public datasets demonstrate that our method outperforms state-of-the-art methods on all the metrics.
Keywords Radiology report generation · Adaptive attention mechanism · Visual encoder
1 Introduction With the development of medical imaging technology, a large number of radiology reports are generated and used for medical diagnosis every day [6]. In order to ensure the accuracy of diagnosis, radiology reports must be written by experienced radiologists. In addition, it is laborious, time-consuming and tedious for radiologists to write these reports [9]. Therefore, automatic radiology report generation has emerged as a critical task in clinical practice [2, 6] and many methods have been proposed recently [5, 20]. In particular, an automatically generated report should be fluent L. Wang School of Electronic and Computer Engineering, Peking University, Shenzhen, China e-mail: [email protected] J. Chen (B) Peking University, Beijing, China e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_21
Ground truth (example 1): heart size is normal . the mediastinal and hilar contours are normal . the pulmonary vasculature is normal . lungs are clear . no pleural effusion or pneumothorax is seen . there are no acute osseous abnormalities .
Ground truth (example 2): the heart size is mildly enlarged . the aorta is tortuous and there are mild aortic knob calcifications . the pulmonary vascularity is not engorged . streaky bibasilar airspace opacities likely reflect atelectasis . small bilateral pleural effusions are present new in the interval . there is no pneumothorax . no acute osseous abnormality is detected .
Fig. 1 Two examples of chest radiographs and the corresponding ground-truth reports from the MIMIC-CXR dataset. The first example does not contain any abnormal symptoms; the second has multiple abnormal symptoms. Sentences describing abnormalities are marked in red in the original figure
and readable, consisting of multiple sentences describing normal and abnormal area information. For abnormal findings, information such as location, size and severity should be described in detail. Figure 1 shows two examples of chest radiographs and the corresponding reports. In recent years, automatic radiology report generation has attracted extensive research interest [12–14, 20]. Most existing methods follow standard image captioning approaches due to the strong similarity between the two tasks. Most works employ the encoder-decoder framework, e.g., CNN-HRNN [13, 29] or CNN-Transformer [4, 5]. In addition, some advanced techniques are employed to boost the generation performance, such as memory-driven generation structures [5], consistency-based constraints [22, 23], and well-designed curriculum learning strategies [19]. Radiology report generation is a kind of vision-language problem. Therefore, it should be considered from both the vision and the language aspects. Most previous methods pay more attention to the language part, generating fluent sentences, which neglects the importance of the vision part. First of all, previous methods treat each word equally. In this way, when generating each word, the network provides the same visual features and language features. However, when different words are generated, they need different information. Some words, like 'is' or 'and', need no visual information at all and can be inferred from the language context alone; other words, like
'heart' or 'lung', need more spatial information. Besides, few methods consider the importance of the pretrained visual extractor. In this work, we first propose a Multi-Head Adaptive Attention module for radiology report generation, which can be applied to any Transformer-based model. To begin with, a scaled dot-product and a linear operation are used to obtain the visual and language signals, respectively. Then, we split the language signals by time step and combine them with the visual signals to obtain the gate weights. Finally, a gate mechanism is utilized to control the visual and language information adaptively. Besides, we adopt the visual encoder of CLIP as our visual extractor. In summary, this work makes the following contributions: • We propose a Multi-Head Adaptive Attention (MHAA) module to control the visual information used in generating each word. • We adopt the visual extractor pretrained by CLIP and discuss the importance of vision-language pretrained models for radiology report generation. • We conduct extensive experiments on two widely-used public datasets, which demonstrate the effectiveness of the proposed MHAA module as well as the more powerful pretrained visual extractor.
2 Related Work

Image Captioning and Paragraph Generation. Image captioning denotes the process of generating the description of a given image [3], which has received extensive research interest. Most of the approaches adopt the encoder-decoder framework to translate images into a single descriptive sentence. Generally, while image captioning aims to generate one single sentence, radiology report generation needs to generate a long and complete report, which is more similar to image paragraph generation. Image paragraph generation produces a semantically coherent and long paragraph to describe the input image, and a hierarchical recurrent network (HRNN) [16] is commonly adopted. In particular, the HRNN generates the paragraph based on the image features by a two-level RNN structure, which includes a paragraph RNN and a sentence RNN. The paragraph RNN is used to generate topic vectors, and each topic vector is used by the sentence RNN to produce a sentence describing the image.

Radiology Report Generation. Most previous studies simply transfer image captioning methods to the task of radiology report generation. Treating the radiograph and the corresponding report as an image-text pair, as in the captioning task, [13, 30, 32] employ a CNN-RNN structure to generate radiology descriptions. Considering the structured nature of radiology reports, [5, 23] utilize memory-driven Transformers to generate professional radiology descriptions. Some studies notice the data imbalance problem and attempt to resolve the issue. References [20, 31] draw on posterior and prior knowledge from the dataset to identify potential abnormalities and train networks sensitive to abnormal regions; [21] compare the target image with normal images to enhance the signal of abnormal regions; [19, 29] introduce a more powerful learning
strategy to encourage the models to pay more attention to abnormal samples. However, neglecting how to better extract and utilize visual information limits the performance of all the approaches described above.
3 Method

In order to make full use of visual information, we propose a framework that consists of a Multi-Head Adaptive Attention (MHAA) module and a more powerful visual extractor, as shown in Fig. 2. In this section, we introduce three aspects of the proposed method: (1) the overall encoder-decoder framework for medical report generation; (2) the proposed Multi-Head Adaptive Attention; (3) the reason for choosing CLIP and a discussion of future work.
3.1 Encoder-Decoder Framework

We adopt the encoder-decoder framework for medical report generation, which contains a visual extractor as the encoder and a Transformer as the decoder. It is worth noting that the decoder can be replaced by any Transformer-based network, such as R2GEN [5] and CMN [4].
Fig. 2 Overview of the proposed framework. The left part is the visual extractor; the mid part is a standard encoder block of Transformer; the right part is a modified decoder with Multi-Head Adaptive Attention module
Visual Extractor. The visual extractor is used to extract the visual representations of the input radiology image; it can be any feature extraction network, such as VGG [27], ResNet [10], DenseNet [11], or even the image encoder of CLIP [25]. Because the output of the fully connected layer lacks spatial information, we utilize the output of the last convolutional layer as the visual features, which has shape H × W × C, where H, W, and C denote the height, width, and number of channels. We regard the visual features as a set of vectors {v_1, v_2, ..., v_{H×W}}, v_i ∈ R^C, which is fed to a subsequent Transformer. In all, given a radiology image X, its visual representations are extracted as follows: {v_1, v_2, ..., v_{H×W}} = f_V(X),
(1)
where f_V denotes the visual extractor. In order to match the dimensions between the visual extractor and the Transformer, we add a linear layer and a ReLU activation after the CNN, which maps each v_i from dimension C to dimension d_model, where d_model is the dimension of the word embeddings and of the intermediate states of the Transformer.

Transformer. The Transformer consists of a standard encoder and a modified decoder. The encoder part is a stack of standard Transformer encoder blocks [28] that transfer the visual representations into hidden states, which can be formulated as: {h_1, h_2, ..., h_{H×W}} = f_E({v_1, v_2, ..., v_{H×W}}),
(2)
where f_E denotes the standard encoder of the Transformer and {h_1, h_2, ..., h_{H×W}} denotes the sequence of hidden states. The modified decoder part is stacked from similar blocks; unlike the standard Transformer, we replace the Multi-Head Attention module with our proposed Multi-Head Adaptive Attention (MHAA) module: y_t = f_D({h_1, h_2, ..., h_{H×W}; y_1, y_2, ..., y_{t−1}}),
(3)
where f_D denotes the decoder and y_i denotes the i-th word of the target report. The entire generation process can be formalized as a recursive application of the chain rule:

p(Y|X) = ∏_{t=1}^{T} p(y_t | y_1, ..., y_{t−1}, X),   (4)

where Y = {y_1, y_2, ..., y_T} is the target report and T is its length. The model is optimized by minimizing the cross-entropy loss:

θ* = arg min_θ Σ_{t=1}^{T} −log p(y_t | y_1, ..., y_{t−1}, X; θ),   (5)
where θ are the parameters of the whole model.
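A condensed, illustrative PyTorch sketch of this pipeline is given below; the visual_extractor, encoder, and decoder modules stand in for the components described above and are placeholders rather than the exact implementation.

```python
import torch.nn as nn

class ReportGenerator(nn.Module):
    def __init__(self, visual_extractor, encoder, decoder, feat_dim, d_model, vocab_size):
        super().__init__()
        self.visual_extractor = visual_extractor          # CNN / CLIP image encoder f_V
        self.proj = nn.Sequential(nn.Linear(feat_dim, d_model), nn.ReLU())
        self.encoder = encoder                            # standard Transformer encoder f_E
        self.decoder = decoder                            # modified decoder f_D with MHAA
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, image, report_tokens):
        # Grid features from the last convolutional layer: (B, C, H, W) -> (B, H*W, C).
        feats = self.visual_extractor(image)
        v = feats.flatten(2).transpose(1, 2)
        v = self.proj(v)                                  # map C -> d_model
        h = self.encoder(v)                               # hidden states {h_1 .. h_HW}
        # Teacher forcing: predict word t from words < t and the hidden states (Eq. 4).
        dec_out = self.decoder(report_tokens[:, :-1], h)
        logits = self.lm_head(dec_out)
        # Cross-entropy loss over the shifted target report (Eq. 5).
        loss = nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)), report_tokens[:, 1:].reshape(-1)
        )
        return logits, loss
```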
Fig. 3 Illustration of our Multi-Head Adaptive Attention module. This module ensures that the model considers the weights of the visual and language signals at each time step
3.2 Multi-Head Adaptive Attention

The Multi-Head Adaptive Attention (MHAA) module works in the decoder of the Transformer and introduces a gate mechanism to control the spatial and language features used when generating a certain word. MHAA consists of three steps: (1) prepare the language and visual signals; (2) calculate the weight at each time step; (3) adaptively select the features by a gating mechanism, as shown in Fig. 3. The inputs of MHAA consist of the query, key, and value, the same as in the general Multi-Head Attention module. We denote them as Q, K, and V, respectively. In order to obtain the visual and language signals, the scaled dot-product operation and an external linear layer are adopted, respectively. Then, in order to separately calculate the weight of each word, we split the language features along the time dimension and merge them with the visual features to obtain the information vector corresponding to each word. We calculate the adaptive gate from this vector with a softmax activation function. Afterwards, we keep the last value of the adaptive gate, which is the weight of the language information, and stitch it together over time steps. We denote it as α. Finally, MHAA selects the features by

V_adap = α ⊗ Linear(Q) + (1 − α) ⊗ MHA(Q, K, V),   (6)

where V_adap is the output of MHAA.
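A simplified, single-head sketch of this gating computation is given below; the two scoring layers are our own simplification of the concatenation-and-softmax step, so this should be read as an illustration of Eq. (6) rather than a verbatim reproduction of the implementation.

```python
import torch
import torch.nn as nn

class AdaptiveGate(nn.Module):
    """Simplified single-head illustration of the MHAA gate (Eq. 6)."""
    def __init__(self, d_model):
        super().__init__()
        self.lang_proj = nn.Linear(d_model, d_model)   # language signal: Linear(Q)
        self.score_vis = nn.Linear(d_model, 1)         # simplification: one score per signal
        self.score_lang = nn.Linear(d_model, 1)

    def forward(self, q, k, v):
        # Visual signal: scaled dot-product attention over the encoder hidden states.
        attn = torch.softmax(q @ k.transpose(-2, -1) / k.size(-1) ** 0.5, dim=-1)
        visual = attn @ v                              # (B, T, d_model)
        language = self.lang_proj(q)                   # (B, T, d_model)
        # Per time step: concatenate the two scores, softmax, and keep the last value
        # as alpha, the weight of the language information.
        scores = torch.cat([self.score_vis(visual), self.score_lang(language)], dim=-1)
        alpha = torch.softmax(scores, dim=-1)[..., -1:]
        # Eq. (6): V_adap = alpha * Linear(Q) + (1 - alpha) * MHA(Q, K, V).
        return alpha * language + (1 - alpha) * visual
```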
3.3 Choice of Pretrained Visual Extractor

Another key to generating an accurate report is obtaining sufficient visual features from the image; however, few methods pay attention to the importance of the visual extractor. In medical report generation, almost all methods use a universal classification network, e.g., ResNet or DenseNet pretrained on ImageNet [8], to extract visual representations. However, it is hard for the decoder to describe detailed information of the image, e.g., the size of the heart or the location of supporting devices. Inspired by CLIP-ViL [26], we adopt the CLIP (Contrastive Language-Image Pre-training) [25] pretrained visual encoder to extract visual information. CLIP is a model trained on a massive amount of image-caption pairs, which has shown strong zero-shot capability on various vision tasks. To the best of our knowledge, this is the first attempt to use CLIP for radiology report generation. In future work, we will explore the effectiveness of CLIP fine-tuned on chest radiography datasets. Besides, combining CLIP with vision-language pre-training on radiology datasets may provide further improvement.
4 Experiments and Results

4.1 Datasets

In order to validate the effectiveness of the proposed method, we conduct experiments on two public datasets, as described below (a sketch of the patient-level split used for IU X-Ray is given after this list): • IU X-ray [7] is a widely-used public radiology dataset collected by Indiana University. It consists of 7,470 chest X-ray images along with 3,955 radiology reports from 3,955 patients. Following previous works [5, 12], we randomly split the IU X-Ray dataset into training, validation and testing sets with a proportion of 7:1:2 at the patient level; hence the patients in the training, validation and testing sets do not overlap. • MIMIC-CXR [14] is a recently released dataset with 473,057 chest X-ray images and 206,563 radiology reports of 63,478 patients, provided by the Beth Israel Deaconess Medical Center. Following the official splits [5], the dataset is divided into training, validation and testing sets with 368,960, 2,991 and 5,159 samples, respectively.
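The following is a minimal sketch of such a patient-level split; the patient_id column name is hypothetical and not the actual field name in the dataset files.

```python
import numpy as np

def patient_level_split(df, patient_col="patient_id", seed=0):
    """Split a report DataFrame 7:1:2 by patient so no patient crosses splits."""
    rng = np.random.default_rng(seed)
    patients = df[patient_col].unique()
    rng.shuffle(patients)
    n = len(patients)
    train_p = set(patients[: int(0.7 * n)])
    val_p = set(patients[int(0.7 * n): int(0.8 * n)])
    return {
        "train": df[df[patient_col].isin(train_p)],
        "val": df[df[patient_col].isin(val_p)],
        "test": df[~df[patient_col].isin(train_p | val_p)],
    }
```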
4.2 Model Evaluation

Three widely-used metrics are adopted in our experiments to evaluate the performance of the proposed framework.
• BLEU [24] metrics are determined by comparing a candidate sentence with reference sentences in terms of n-grams. Specifically, to determine BLEU-1, the candidate sentence is compared with the reference sentences in unigrams, while for BLEU-2, bigrams are used for matching. For BLEU metrics, the unigram scores account for adequacy, while higher n-gram scores account for fluency. In general, BLEU-1 to BLEU-4 are used in our experiments.
• METEOR [1] first performs generalized unigram matching between a candidate sentence and a reference sentence, then computes a score based on the matching results. The computation involves precision, recall and alignments of the matched words. This metric was introduced to address a weakness of the BLEU metric, which is derived only from the precision of matched n-grams.
• ROUGE-L [18] is designed to evaluate the adequacy and fluency of machine translation. This metric employs the longest common subsequence between a candidate sentence and a set of reference sentences to measure their similarity.
For a fair comparison with previous methods [20, 21], these metrics are calculated with a standard toolkit [3], which is widely used in the image captioning task.
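For illustration, with the commonly used pycocoevalcap implementation of this toolkit (the exact package layout may differ across versions, and METEOR additionally requires a Java runtime), the metrics can be computed roughly as follows:

```python
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge

def evaluate(gts, res):
    """gts/res: dicts mapping an image id to a list with the reference/generated report."""
    scores = {}
    bleu, _ = Bleu(4).compute_score(gts, res)
    scores.update({f"BLEU-{i + 1}": s for i, s in enumerate(bleu)})
    scores["METEOR"], _ = Meteor().compute_score(gts, res)
    scores["ROUGE-L"], _ = Rouge().compute_score(gts, res)
    return scores

# Example (toy data):
# evaluate({"img1": ["the lungs are clear ."]}, {"img1": ["lungs are clear ."]})
```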
4.3 Model Development and Hyper Parameter Tuning

As stated in the method section, a modified RN50x4 trained with CLIP [25] is adopted as the visual extractor to capture visual representations from the radiographs. The other parts of the proposed network, e.g., the Transformer, are randomly initialized. Both the encoder and decoder of the Transformer have 3 layers with 8 attention heads and 512 dimensions for the hidden states. We adopt the ADAM optimizer [15] for parameter optimization. The learning rates of the visual extractor and the other parameters are initially set to 5 × 10−5 and 1 × 10−4, respectively, and decayed by a factor of 0.8 per epoch. In order to balance generation effectiveness and efficiency, the beam size is set to 3 for both IU X-Ray and MIMIC-CXR. It is worth noting that, in order to ensure consistency with the experimental settings of previous work [5, 21], we utilize paired images of the same patient as the input for IU X-Ray and a single image as the input for MIMIC-CXR.
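In PyTorch terms, this optimization setup corresponds roughly to the following sketch; the visual_extractor attribute name is a placeholder.

```python
import torch

def build_optimizer(model):
    # Separate learning rates for the visual extractor and the rest of the model.
    optimizer = torch.optim.Adam([
        {"params": model.visual_extractor.parameters(), "lr": 5e-5},
        {"params": [p for n, p in model.named_parameters()
                    if not n.startswith("visual_extractor")], "lr": 1e-4},
    ])
    # Decay both learning rates by a factor of 0.8 every epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.8)
    return optimizer, scheduler
```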
4.4 Ablation Study

To verify the effectiveness of our method, we perform a series of ablation studies. First, we apply our adaptive attention to three Transformer-based methods as our baselines: (1) 'Transformer': the original Transformer [28]; (2) 'R2GEN': the R2GEN method [5]; (3) 'CMN': the Cross-modal Memory Network [4]. We replace the Multi-Head Attention in the decoder part of each baseline with our MHAA to verify its effectiveness. The results on the test sets of the IU X-Ray and MIMIC-CXR datasets
Table 1 Ablation study of the MHAA module, carried out on the IU X-Ray and MIMIC-CXR datasets

Dataset   | Methods          | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L
IU X-Ray  | Transformer      | 0.463  | 0.295  | 0.207  | 0.155  | 0.161  | 0.366
          | Transformer+MHAA | 0.475  | 0.307  | 0.220  | 0.168  | 0.185  | 0.383
          | R2GEN [5]        | 0.470  | 0.304  | 0.219  | 0.165  | 0.187  | 0.371
          | R2GEN+MHAA       | 0.487  | 0.324  | 0.233  | 0.180  | 0.210  | 0.393
          | CMN [4]          | 0.475  | 0.309  | 0.222  | 0.169  | 0.193  | 0.380
          | CMN+MHAA         | 0.489  | 0.322  | 0.229  | 0.175  | 0.195  | 0.383
MIMIC-CXR | Transformer      | 0.331  | 0.209  | 0.137  | 0.096  | 0.134  | 0.267
          | Transformer+MHAA | 0.348  | 0.217  | 0.140  | 0.097  | 0.148  | 0.270
          | R2GEN [5]        | 0.353  | 0.218  | 0.145  | 0.103  | 0.142  | 0.277
          | R2GEN+MHAA       | 0.367  | 0.230  | 0.153  | 0.107  | 0.160  | 0.286
          | CMN [4]          | 0.353  | 0.218  | 0.148  | 0.106  | 0.142  | 0.278
          | CMN+MHAA         | 0.359  | 0.223  | 0.144  | 0.100  | 0.151  | 0.285
are reported in Table 1. As we can see, our MHAA improves all three baselines on almost all the metrics on both the IU X-Ray and MIMIC-CXR datasets. These results demonstrate the validity of our analysis of the problem and the effectiveness of our strategies. Due to its good performance on both IU X-Ray and MIMIC-CXR, we choose R2GEN as the base model to evaluate the importance of the pretrained visual extractor, as shown in Table 2. On top of that, the 'CLIP-Res50x4' extractor further brings large improvements of 0.016 on IU X-Ray and 0.010 on MIMIC-CXR in the BLEU-1 metric. Similar trends of improvement can be observed for the METEOR metric on both datasets, which shows the importance of a more powerful pretrained visual extractor.
Table 2 Ablation study of the pretrained visual extractor, carried out on the IU X-Ray and MIMIC-CXR datasets. The model used in this experiment is R2GEN+MHAA

Dataset   | Visual Encoder  | BLEU-1 | METEOR
IU X-Ray  | ImageNet-Res101 | 0.487  | 0.200
          | CLIP-Res50x4    | 0.503  | 0.212
MIMIC-CXR | ImageNet-Res101 | 0.367  | 0.160
          | CLIP-Res50x4    | 0.377  | 0.166
Table 3 Performance of our approach and other state-of-the-art methods on the IU X-Ray and MIMIC-CXR datasets

Dataset   | Methods      | BLEU-1 | BLEU-2 | BLEU-3 | BLEU-4 | METEOR | ROUGE-L
IU X-Ray  | HRGR [17]    | 0.438  | 0.298  | 0.208  | 0.151  | –      | 0.322
          | CMAS-RL [12] | 0.464  | 0.301  | 0.210  | 0.154  | –      | 0.362
          | R2GEN [5]    | 0.470  | 0.304  | 0.219  | 0.165  | 0.187  | 0.371
          | CMCL [19]    | 0.473  | 0.305  | 0.217  | 0.162  | 0.186  | 0.378
          | CMN [4]      | 0.475  | 0.309  | 0.222  | 0.170  | 0.191  | 0.375
          | CA [21]      | 0.492  | 0.314  | 0.222  | 0.169  | 0.193  | 0.380
          | PPKED [20]   | 0.483  | 0.315  | 0.224  | 0.168  | 0.190  | 0.376
          | Ours         | 0.503  | 0.328  | 0.232  | 0.172  | 0.212  | 0.395
MIMIC-CXR | R2GEN [5]    | 0.353  | 0.218  | 0.145  | 0.103  | 0.142  | 0.277
          | CMCL [19]    | 0.344  | 0.217  | 0.140  | 0.097  | 0.133  | 0.281
          | CMN [4]      | 0.353  | 0.218  | 0.148  | 0.106  | 0.142  | 0.278
          | CA [21]      | 0.350  | 0.219  | 0.152  | 0.109  | 0.151  | 0.283
          | PPKED [20]   | 0.360  | 0.224  | 0.149  | 0.106  | 0.149  | 0.284
          | Ours         | 0.377  | 0.237  | 0.156  | 0.111  | 0.166  | 0.287
4.5 Quantitative Evaluation

We choose 'R2GEN+MHAA+CLIP-Res50x4' as our proposed model. To further demonstrate its effectiveness, we compare our proposed method with several state-of-the-art radiology report generation methods on both the IU X-Ray and MIMIC-CXR datasets, with the results presented in Table 3. The proposed method outperforms other approaches by large margins on all the metrics, demonstrating the effectiveness of the adaptive attention mechanism and the CLIP-based visual extractor. This also reflects the importance of visual feature utilization in medical report generation.
4.6 Qualitative Analysis

In addition to the quantitative evaluation, we further perform a qualitative analysis to perceptually understand the improvements. Two representative samples are displayed in Fig. 4, where we use different colors to differentiate descriptions of different medical terms and underlined words for wrong descriptions. For the first sample, the baseline produces a trivial solution, i.e., a normal heart, wrongly describes the posterior pleural sinuses, and omits many descriptions. In contrast, our model successfully generates a more precise report. For the second sample, the baseline again generates a trivial solution for the lung volume and omits many medical terms, whereas our
Example 1
Ground-truth: cardiac silhouette is markedly enlarged but stable in size with indwelling right atrial and right ventricular pacing leads unchanged in position. the lungs are well expanded and grossly clear except for a small calcified granuloma at the left lung apex. there are no pleural effusions or acute skeletal findings.
R2GEN: pa and lateral chest views were obtained with patient in upright position. the heart size remains unchanged and is within normal limits. Thoracic aorta mildly widened and elongated but no local contour abnormallities are seen. the pulmonary vasculature is not congested. posterior pleural sinuses are free. no pneumothorax in the apical area. skeletal structures of the thorax grossly within normal limits.
Ours: pa and lateral views of the chest provided. the lungs are clear. there are no acute osseous abnormalities. there is no pleural effusion or pneumothorax. moderate cardiomegaly is stable. no free air below the right hemidiaphragm is seen. the pulmonary vasculature is normal. there is a left-sided pacemaker with leads terminating in the right atrium and right ventricle.

Example 2
Ground-truth: there are low lung volumes . the lungs are clear . there is no pleural effusion or pneumothorax . the cardiomediastinal silhouette is unremarkable . left central line terminates in the right atrium . median sternotomy wires and mediastinal clips are noted . a calcified lymph node is noted in the ap window .
R2GEN: lung volumes are low . heart size is mildly enlarged . the aorta is calcified and tortuous . mediastinal and hilar contours are unremarkable . crowding of bronchovascular structures is present without overt pulmonary edema . patchy opacities in the lung bases likely reflect areas of atelectasis . no focal consolidation pleural effusion or pneumothorax is present . there are no acute osseous abnormalities .
Ours: ap portable upright view of the chest . midline sternotomy wires and mediastinal clips are again noted . there is a right chest wall port-a-cath with its tip in the mid svc . a calcific density in the region of the ap window corresponds with a calcified lymph node on prior ct . lung volumes are low limiting evaluation . no large effusion or pneumothorax .

Fig. 4 Illustrations of reports from the ground truth, R2GEN, and R2GEN+MHAA (ours) for two chest X-ray images. For a clear comparison, different structures are marked with different colors in the original figure
model describes almost all the medical terms mentioned in the ground truth. We thus conclude that our approach can improve radiology report generation.
5 Conclusion

In this study, we proposed a novel method for radiology report generation. Noting the importance of visual information, we discussed how to effectively utilize visual information and how to obtain richer visual features. We first apply a Multi-Head Adaptive Attention module to control the visual features used in generating each word. Then, a more powerful pretrained visual extractor is adopted. In future work, a vision-language model pretrained on radiology report data could also be explored. Finally, extensive experiments on two public datasets demonstrate the superiority of the proposed method compared with state-of-the-art methods.

Acknowledgements This work is supported by the Natural Science Foundation of China (No. 61972217, No. 62081360152) and the Natural Science Foundation of Guangdong Province in China (No. 2019B1515120049, No. 2020B1111340056).
References 1. Banerjee, S., & Lavie, A. (2005). Meteor: An automatic metric for MT evaluation with improved correlation with human judgments. In: Proceedings of the ACL Workshop on Intrinsic and Extrinsic Evaluation Measures for Machine Translation and/or Summarization (pp. 65–72). 2. Brady, A., Laoide, R. Ó., McCarthy, P., & McDermott, R. (2012). Discrepancy and error in radiology: Concepts, causes and consequences. The Ulster Medical Journal, 81(1), 3. 3. Chen, X., Fang, H., Lin, T.Y., Vedantam, R., Gupta, S., Dollár, P., & Zitnick, C. L. (2015). Microsoft coco captions: Data collection and evaluation server. arXiv:1504.00325. 4. Chen, Z., Shen, Y., Song, Y., & Wan, X. (2021). Cross-modal Memory Networks for Radiology Report Generation, 1. 5. Chen, Z., Song, Y., Chang, T.H., & Wan, X. (2020). Generating radiology reports via memorydriven transformer. arXiv:2010.16056. 6. Delrue, L., Gosselin, R., Ilsen, B., Van Landeghem, A., de Mey, J., & Duyck, P. (2011). Difficulties in the interpretation of chest radiography. In: Comparative Interpretation of CT and Standard Radiography of the Chest (pp. 27–49). Springer. 7. Demner-Fushman, D., Kohli, M. D., Rosenman, M. B., Shooshan, S. E., Rodriguez, L., Antani, S., Thoma, G. R., & McDonald, C. J. (2016). Preparing a collection of radiology examinations for distribution and retrieval. Journal of the American Medical Informatics Association, 23(2), 304–310. 8. Deng, J., Dong, W., Socher, R., Li, L.J., Li, K., & Fei-Fei, L. (2009). Imagenet: A largescale hierarchical image database. In: 2009 IEEE Conference on Computer Vision and Pattern Recognition (pp. 248–255). IEEE. 9. Goergen, S. K., Pool, F. J., Turner, T. J., Grimm, J. E., Appleyard, M. N., Crock, C., Fahey, M. C., Fay, M. F., Ferris, N. J., Liew, S. M., et al. (2013). Evidence-based guideline for the written radiology report: Methods, recommendations and implementation challenges. Journal of Medical Imaging and Radiation Oncology, 57(1), 1–7. 10. He, K., Zhang, X., Ren, S., & Sun, J. (2016). Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 770–778). 11. Huang, G., Liu, Z., Van Der Maaten, L., & Weinberger, K. Q. Densely connected convolutional networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 4700–4708). 12. Jing, B., Wang, Z., & Xing, E. (2020). Show, describe and conclude: On exploiting the structure information of chest x-ray reports. arXiv:2004.12274. 13. Jing, B., Xie, P., & Xing, E. (2017). On the automatic generation of medical imaging reports. arXiv:1711.08195. 14. Johnson, A. E., Pollard, T. J., Greenbaum, N. R., Lungren, M. P., Deng, C. Y., Peng, Y., Lu, Z., Mark, R. G., Berkowitz, S. J., & Horng, S. (2019). Mimic-cxr-jpg, a large publicly available database of labeled chest radiographs. arXiv:1901.07042. 15. Kingma, D.P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv:1412.6980. 16. Krause, J., Johnson, J., Krishna, R., Fei-Fei, L.: A hierarchical approach for generating descriptive image paragraphs. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 317–325). 17. Li, C.Y., Liang, X., Hu, Z., & Xing, E. P. (2018). Hybrid retrieval-generation reinforced agent for medical image report generation. arXiv:1805.08298. 18. Lin, C. Y. (2004). Rouge: a package for automatic evaluation of summaries. In: Text summarization branches out (pp. 74–81). 19. Liu, F., Ge, S., & Wu, X. (2021). 
Competence-based multimodal curriculum learning for medical report generation. In: ACL (Vol. 1, p. 3). 20. Liu, F., Wu, X., Ge, S., Fan, W., & Zou, Y. (2021). Exploring and distilling posterior and prior knowledge for radiology report generation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 13753–13762).
21. Liu, F., Yin, C., Wu, X., Ge, S., Zhang, P., & Sun, X. (2021). Contrastive attention for automatic chest x-ray report generation. arXiv:2106.06965. 22. Lovelace, J., & Mortazavi, B. (2020). Learning to generate clinically coherent chest x-ray reports. In: Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: Findings (pp. 1235–1243). 23. Miura, Y., Zhang, Y., Tsai, E. B., Langlotz, C. P., & Jurafsky, D.: Improving factual completeness and consistency of image-to-text radiology report generation. arXiv:2010.10042. 24. Papineni, K., Roukos, S., Ward, T., & Zhu, W. J. (2002). Bleu: A method for automatic evaluation of machine translation. In: Proceedings of the 40th Annual Meeting of the Association for Computational Linguistics (pp. 311–318). 25. Radford, A., Kim, J. W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., Sastry, G., Askell, A., Mishkin, P., & Clark, J., et al. (2021). Learning transferable visual models from natural language supervision. arXiv:2103.00020. 26. Shen, S., Li, L.H., Tan, H., Bansal, M., Rohrbach, A., Chang, K. W., Yao, Z., & Keutzer, K. (2021). How much can clip benefit vision-and-language tasks? arXiv:2107.06383. 27. Simonyan, K., Zisserman, A. (2014). Very deep convolutional networks for large-scale image recognition. arXiv:1409.1556. 28. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. In: Advances in neural information processing systems (pp. 5998–6008). 29. Wang, Z., Zhou, L., Wang, L., & Li, X. (2021). A self-boosting framework for automated radiographic report generation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 2433–2442). 30. Xue, Y., Xu, T., Long, L. R., Xue, Z., Antani, S., Thoma, G. R., & Huang, X. (2018). Multimodal recurrent model with attention for automated radiology report generation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 457–466). Springer. 31. Yang, X., Ye, M., You, Q., & Ma, F. (2021). Writing by memorizing: Hierarchical retrievalbased medical report generation. arXiv:2106.06471. 32. Yuan, J., Liao, H., Luo, R., & Luo, J. (2019). Automatic radiology report generation based on multi-view image fusion and medical concept enrichment. In: International Conference on Medical Image Computing and Computer-Assisted Intervention (pp. 721–729). Springer.
Instantaneous Physiological Estimation Using Video Transformers Ambareesh Revanur, Ananyananda Dasari, Conrad S. Tucker, and László A. Jeni
Abstract Video-based physiological signal estimation has been limited primarily to predicting episodic scores in windowed intervals. While these intermittent values are useful, they provide an incomplete picture of patients’ physiological status and may lead to late detection of critical conditions. We propose a video Transformer for estimating instantaneous heart rate and respiration rate from face videos. Physiological signals are typically confounded by alignment errors in space and time. To overcome this, we formulated the loss in the frequency domain. We evaluated the method on the large scale Vision-for-Vitals (V4V) benchmark. It outperformed both shallow and deep learning based methods for instantaneous respiration rate estimation. In the case of heart-rate estimation, it achieved an instantaneous-MAE of 13.0 beats-per-minute. Keywords Transformer architecture · Physiological estimation · Machine learning
Github link: https://github.com/revanurambareesh/instantaneous_transformer

1 Introduction

Contact-based devices (e.g. pulse-oximeter) are prevalent among healthcare professionals for assessing and monitoring the vital signs of patients in hospital settings. These vital sign monitoring devices require physical contact and can cause discomfort
to patients. As remote diagnosis is becoming increasingly common, partly due to the recent COVID-19 pandemic, there is a pressing demand for non-contact physiological estimation methods. Over the years, conventional photoplethysmography (PPG), the contact-based optical estimation of microvascular blood volume changes, has evolved into contactless imaging PPG (iPPG). These methods utilize digital cameras and computer vision techniques for estimating heart-generated pulse waves and their respiratory modulation by means of peripheral blood perfusion measurements. Past research has demonstrated that these bio-signals can be extracted with high fidelity in a strictly controlled environment, i.e. with a prediction error below 3 beats-per-minute for heart rate extraction [3]. Intuitively, the camera captures subtle periodic color variations that result from the blood volume changes in the underlying skin tissues. A careful analysis of these subtle changes in the video reveals the physiological state.

Previous video-based physiological extraction techniques have been limited to predicting episodic heart rate values over large, non-overlapping windows (e.g. 30 s), which influenced the choice of performance metrics used to evaluate these methods. The window-based evaluation protocol does not provide complete insight into Heart Rate Variability (HRV), which plays an important role in understanding the physical and mental condition of an individual [15]. This evaluation gap has been addressed in a recent physiological challenge organized at the ICCV conference called "Vision-for-Vitals" (V4V). Instead of using non-overlapping windows to measure the error, the challenge organizers [18] proposed an instantaneous (a.k.a. continuous) evaluation protocol that measures the performance of a method at a per-frame level. Hence, in our work, we benchmark different methods, including our proposed method, using this continuous evaluation protocol (Fig. 1).

Recently, deep learning based solutions have been proposed for the task of human physiology estimation [3, 14]. One of the most popular methods in this direction is DeepPhys [3]. This method utilizes the optical principles of PPG to predict the blood volume pulse (and the respiratory wave for respiration rate estimation) from a facial video. Even though DeepPhys has made significant progress, it is still limited to episodic evaluation and is unable to fully exploit the temporal and periodic nature of the blood volume pulse. To remedy this, we draw motivation from recent work on video analysis [13] and utilize a Transformer based architecture for frame-level prediction. There is a body of literature showing that the Transformer is effective in modelling temporal sequences [2, 21]. The paper advances two main novelties:
• We use a Transformer based architecture for instantaneous prediction and evaluation of human physiological signals from facial video. We formulate the loss function in the frequency domain to minimize confounds introduced by the temporal misalignment of the video and the PPG signals.
• We evaluate the method on the challenging Vision-for-Vitals (V4V) dataset. We show that our approach reliably estimates Heart Rate (HR) and Respiration Rate (RR), outperforming shallow and deep-learning methods trained on the same data (V4V training set).
Fig. 1 Traditional evaluation v/s instantaneous evaluation (Ours) for heart rate and respiration rate estimation
2 Related Work

2.1 Video Based Physiology Extraction

Based on the principles of remote photoplethysmography (rPPG), several methods have been advanced [4, 5, 16, 17, 20, 22–24] for the extraction of physiological signals from facial videos. In [22], the authors highlighted that the green channel of an RGB video can be used to compute the heart rate, since the green channel has the strongest photoplethysmographic signature. Wang et al. [23] proposed a mathematical model for the reflection properties of skin and developed a novel rPPG method based on it. Further, researchers have utilized face detection and tracking methods such as the Bounded Kalman Filter for extracting facial regions of interest [4, 17]. More recently, deep learning approaches [3, 10, 11] have been proposed for the task of physiological estimation. One of the main aspects of the DeepPhys [3] and MTTS-CAN [10] architectures is the spatial attention mechanism, which is used to determine the right regions of interest, thereby enabling end-to-end trainability of the network. However, all of these methods are evaluated on episodic scores over
windowed intervals. To tackle this limitation, the Vision-for-Vitals workshop [18] held at ICCV introduced multiple metrics to promote instantaneous prediction of HR and RR. In our work, we evaluate all methods, including ours, on these metrics. We also incorporate spatial attention masks and utilize a Transformer for temporal learning.
2.2 Transformers

A Transformer [21] consists of a Transformer-encoder and a Transformer-decoder, which in turn are composed of several multi-headed self-attention layers. In [21], the authors demonstrated the high accuracy of a Transformer based architecture for multiple language translation tasks. With minor modifications, the Transformer architecture has been successfully adapted to a wide range of research problems in computer vision and natural language processing. The computer vision community has explored Transformer based architectures in two forms. In the first form, the architecture includes a convolutional neural network (CNN) backbone that is used as a feature extractor [2, 9, 13]. In [2] the authors employed an ImageNet pretrained ResNet backbone as a spatial feature extractor for end-to-end object detection with Transformers. In other related work, video Transformers [13] have been employed for temporal modeling of videos: the convolutional backbone extracts features and the Transformer is used for temporal modelling. In our work, we use a DeepPhys [3] based convolutional backbone network and a Transformer for temporal modeling. In the second form, the architecture is built purely from Transformer layers. In [6] the authors trained a pure transformer for the task of image classification by dividing the image into multiple patches. In a related work, [25] aimed at detecting fake 3D printed face masks with a Transformer based architecture, drawing motivation from the principles of photoplethysmography. This is achieved by feeding Multi-scale Spatio-Temporal maps (MSTmaps) of the facial and background regions, along with a positional embedding, into a Transformer network. In this work, we focus on developing a method for instantaneous evaluation of physiological signals using a video Transformer. We train the network in an end-to-end manner by relying on the spatial attention module of DeepPhys.
3 Methods

The goal of remote PPG extraction is to effectively extract a bio-signal that contains HR (or RR) from a facial video. To this end, we propose a Transformer-based architecture inspired by the principles of remote PPG. In this section, we first introduce the optical basis of our method by relying on the skin reflection model [3, 23]. Next,
we propose the architecture for the HR/RR estimation and finally explain our loss formulation that we used for training the model.
3.1 Optical Basis of Video-Based Bio-Signal Extraction

The changes in the volume of blood flowing under the facial skin result in subtle color changes. In order to extract this bio-signal, we use the popular skin reflection model that is based on Shafer's dichromatic reflection [3, 23]. At a given time instance t in the video, the reflection of the light back to the camera can be considered as a function that varies in the RGB color space according to Eq. 1.

Ck(t) = i(t) · (vs(t) + vd(t)) + vn(t)    (1)
Here, Ck(t) is the color intensity of the RGB pixel k, and i(t) is the luminance intensity level, which is modulated by the specular reflectance vs and the diffuse reflectance vd. The term vn is the camera quantization noise. The specular reflection is a mirror-like reflection that bounces the light right off the facial skin, while the diffuse component contains the useful pulsatile signal. The components vd(t) and vs(t) can be expressed further in terms of a stationary reflection strength, the underlying physiological bio-signal p(t), and motion induced changes m(t) (e.g. facial movements, expressions).

vd(t) = ud · d0 + up · p(t)    (2)
vs(t) = us · (s0 + Φ(m(t), p(t)))    (3)
Here ud is the stationary skin reflection strength and up is the pulsatile signal strength, which varies according to the volume of hemoglobin. Notice how Eq. 2 does not depend on m(t), while Eq. 3 depends on both m(t) and p(t). Next, i(t) can be further decomposed into stationary and varying components according to

i(t) = i0 · (1 + Ψ(m(t), p(t)))    (4)
where us is the unit norm vector indicating the color of the light spectrum, s0 is the stationary component of the specular reflection, and Φ is a function of the motion m(t) and the physiological signal p(t). Substituting Eqs. 2, 3 and 4 into Eq. 1, we obtain Eq. 5.

Ck(t) ≈ uc · i0 · c0 + uc · i0 · c0 · Ψ(m(t), p(t)) + us · i0 · Φ(m(t), p(t)) + up · i0 · p(t) + vn(t)    (5)
As a next step, we remove the dominant stationary signal by computing the first order derivative of Ck(t), in line with [3]. Additionally, we reduce the image size to 36 × 36 to suppress the quantization noise, i.e. vn ≈ 0.

C′k(t) ≈ uc · i0 · c0 · (∂Ψ/∂m · m′(t) + ∂Ψ/∂p · p′(t)) + us · i0 · (∂Φ/∂m · m′(t) + ∂Φ/∂p · p′(t)) + up · i0 · p′(t)    (6)
As an additional normalization step, we also divide C′k(t) by uc · i0 · c0 to remove the remaining stationary luminance intensity across all pixels. In practice, we approximate C′k(t) by computing the pixel-wise difference for consecutive pairs of frames and normalizing it by the mean frame. We remove the constant factor of 2 from the resulting expression as it is just a scaling term. Therefore, we obtain Mk(t), which we use for training our model (Eq. 7).

Mk(t) = (Ck(t + 1) − Ck(t)) ⊘ (Ck(t + 1) + Ck(t))    (7)

where ⊘ is the Hadamard (element-wise) division operation.
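The normalized frame difference of Eq. 7 can be sketched in a few lines of NumPy. This is a minimal illustration rather than the authors' implementation; the small epsilon in the denominator is an assumption added to avoid division by zero.

```python
import numpy as np

def normalized_frame_difference(frames, eps=1e-8):
    """Compute M_k(t) of Eq. 7 for a clip of RGB frames.

    frames: float array of shape (T, 36, 36, 3), already resized to 36 x 36.
    Returns an array of shape (T - 1, 36, 36, 3) used as the motion-branch input.
    """
    frames = frames.astype(np.float32)
    num = frames[1:] - frames[:-1]        # C_k(t+1) - C_k(t)
    den = frames[1:] + frames[:-1] + eps  # C_k(t+1) + C_k(t)
    return num / den                      # Hadamard (element-wise) division
```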
3.2 Video Transformer for Physiological Estimation

In this section we propose a video Transformer for the task of human physiological estimation. As shown in Fig. 2, the video Transformer consists of a spatial backbone P and a frame-level temporal aggregation module T to assist with learning temporal dependencies of the bio-signal waveform. For the spatial backbone we use the popular DeepPhys-based architecture, and for the temporal module we use the Transformer-encoder architecture. Therefore, our video Transformer is an end-to-end trainable framework, unlike approaches which use facial landmark detection for selecting pixels in the region of interest [11].
We begin by describing the DeepPhys-based spatial backbone network depicted in Fig. 3. The network consists of two branches for modeling the motion representation and spatial attention. The motivation for using the spatial attention branch is to learn the regions of the face (via the attention masks) that could assist the motion branch for physiological estimation. The spatial attention branch is trained on a video frame Ck(t) of dimension 36 × 36, and the motion branch is trained on the normalized frame differences Mk(t) of dimension 36 × 36 (Eq. 7). The attention mask q is given by

q = (hz · wz · za) / (2 ‖za‖1)    (8)
Fig. 2 Video transformer for human physiological signal extraction
where hz, wz are the dimensions of the input feature map and za are the sigmoid activations of the features from the spatial attention branch (Fig. 3). These attention masks are multiplied element-wise with the features of the motion branch. Finally, the FC-layer of P encodes the spatial features into a vector of dimension d (Fig. 3).
We achieve frame-level temporal aggregation by utilizing a Transformer encoder T [21]. The encoder network is trained on the FC-features of P using N frames of the video, where N is the number of frames used for temporal aggregation. We pass all the features of P through a linear layer that reduces the d-dimensional vector into a dT = 32-dimensional embedding. Along with these embeddings, we additionally include a [CLS] token for training the encoder network. Each encoder layer consists of a multi-headed self-attention block with 8 self-attention modules and an MLP layer. Further, at each input step of the Transformer, we also include a positional encoding (PE) to inject temporal information into the model. To this end, we utilize the position-encoding layer proposed by [21], which is defined according to Eqs. 9 and 10.

PE(pos, 2i) = sin(pos / 10000^(2i/dT))    (9)

PE(pos, 2i+1) = cos(pos / 10000^(2i/dT))    (10)
Fig. 3 DeepPhys-based encoder P as a spatial backbone network for video Transformer. Here, σ denotes the sigmoid activation and q is the attention mask
where pos is the position and i is the dimension index in the position embedding. The output features of the encoder layers are then fed into a single MLP layer to obtain ŷ.
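A minimal sketch of the sinusoidal positional encoding of Eqs. 9 and 10 is shown below, with dT = 32 as used in the chapter; the NumPy implementation is illustrative only.

```python
import numpy as np

def positional_encoding(num_frames, d_t=32):
    """Sinusoidal positional encoding (Eqs. 9 and 10); one row per frame."""
    pe = np.zeros((num_frames, d_t), dtype=np.float32)
    pos = np.arange(num_frames)[:, None]          # frame position
    i = np.arange(0, d_t, 2)[None, :]             # even embedding dimensions
    angle = pos / np.power(10000.0, i / d_t)
    pe[:, 0::2] = np.sin(angle)                   # PE(pos, 2i)
    pe[:, 1::2] = np.cos(angle)                   # PE(pos, 2i+1)
    return pe
```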
3.3 Loss Formulation

We train our model end-to-end using a ground truth signal such as the blood pressure waveform (for HR) or the respiratory wave (for RR). One straightforward way to train the model is with a Mean-Squared Error (MSE) loss. However, MSE loss assumes that the ground truth is accurately synchronized with the bio-signal in the facial video. Unfortunately, it is challenging to perfectly synchronize the bio-signal with the ground truth for two reasons. First, the devices used for video capture and for ground-truth physiological signal recording are different, so one has to manually align the video frames with the ground-truth signal. Second, the ground truth is often collected at a peripheral site such as the finger, so there is an additional delay resulting from the Pulse-Transit Time (PTT). A related work [1] shows that the time delay for pulse transit between the ear and the finger is close to 150 ms. Another limitation of the MSE loss is that it trains the model to learn both the amplitude and the frequency of a wave. However, for the task of HR/RR estimation, we are interested only in the frequency of the underlying pulsatile signal and not its amplitude. Therefore, we make use of a Maximum Cross-Correlation loss l(y, ŷ) [7] and perform the cross-correlation computation in the frequency domain instead of the time domain.
l(y, ŷ) = −c · Max( F⁻¹{Ω(F(y) · F(ŷ)*)} / (σy σŷ) )    (11)
where y is the ground truth computed from the signal differences p′ and ŷ is the predicted waveform. Further, Ω is a bandpass operator which retains only the frequencies of interest, c is the ratio of the power inside the heart rate frequency range to the total power, F is the Fourier-transform operator and (·)* is the complex conjugate operator.
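The sketch below shows one way to realize the frequency-domain maximum cross-correlation loss of Eq. 11 in PyTorch. It assumes 1-D, equal-length signals; the sampling rate fs and the pass band are hypothetical parameters (the chapter uses [0.7, 2.5] Hz for HR), and the exact normalization may differ from the authors' code.

```python
import torch

def max_cross_corr_loss(y, y_hat, fs=25.0, band=(0.7, 2.5)):
    """Negative maximum cross-correlation between y and y_hat (Eq. 11)."""
    y = y - y.mean()
    y_hat = y_hat - y_hat.mean()
    n = y.shape[-1]
    freqs = torch.fft.rfftfreq(n, d=1.0 / fs)
    fy, fyh = torch.fft.rfft(y), torch.fft.rfft(y_hat)
    mask = ((freqs >= band[0]) & (freqs <= band[1])).to(torch.float32)
    # c: ratio of in-band power of the ground truth to its total power
    power = fy.abs() ** 2
    c = (power * mask).sum() / power.sum().clamp(min=1e-8)
    # Omega: zero out frequencies outside the band before the inverse FFT
    cross = fy * torch.conj(fyh) * mask
    xcorr = torch.fft.irfft(cross, n=n)
    denom = (y.std() * y_hat.std() * n).clamp(min=1e-8)
    return -c * (xcorr / denom).max()
```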
4 Results

4.1 Implementation Details

We reduced the input image size to 36 × 36, in line with [3], and computed the normalized input frame difference for the motion branch. For training the video Transformer, we fixed the number of input frames to N (where N = 100 for HR estimation and N = 1000 for RR estimation) and trained our network end-to-end. During inference, we used the predictions from T and computed the cumulative sum to obtain the final waveform prediction. After that, we calculated the Fourier Transform of the waveform and applied a bandpass filter to limit the frequencies within the range of HR/RR, obtaining ŷ. For the HR model we used d = 128 and for the RR model we used d = 32.
4.2 Datasets and Evaluation Protocol

We use the Vision-for-Vitals (V4V) dataset, which consists of 179 subjects and 1358 videos in total. The V4V dataset contains the continuous blood pressure waveform recorded at 1 kHz, frame-aligned HR and frame-aligned RR. We use the V4V training set for training the model and report the performance of our model on both the V4V validation set and the V4V test set. We follow the evaluation protocol set forth in the V4V challenge [18] and report continuous MAE (cMAE) and continuous RMSE (cRMSE).

cMAE = (Σi |ĤRi − HRi|) / N    (12)

cRMSE = √( (Σi |ĤRi − HRi|²) / N )    (13)
where ĤRi and HRi are the predicted HR and the ground-truth HR for frame i, respectively, and N is the total number of frames in the test set. We use the same evaluation
Table 1 Comparison of our method against previous works for HR estimation on the V4V validation set and V4V test set. Note that lower cMAE and lower cRMSE are better (↓)

Method          cMAE Val. (↓)   cRMSE Val. (↓)   cMAE Test (↓)   cRMSE Test (↓)
Green [22]      16.5            21.4             15.5            21.9
POS [23]        17.3            21.2             15.3            21.8
ICA [16]        13.9            20.0             15.1            20.6
DeepPhys [3]    13.6            18.1             14.7            19.7
TS-CAN [10]     11.7            17.8             13.9            19.2
Ours            10.3            16.1             13.0            18.8
protocol for benchmarking the RR results. Further, in order to evaluate all methods on the continuous evaluation protocol, we use a short moving window over the predicted blood volume pulse for HR (and the predicted respiratory wave for RR) and employ an FFT to predict the continuous HR.
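The sketch below illustrates how a per-frame (continuous) HR can be obtained from the predicted pulse waveform with a short moving window and an FFT, and how cMAE and cRMSE (Eqs. 12 and 13) are then computed. The sampling rate, the 10 s window, and the boundary handling are assumptions rather than values stated in the chapter.

```python
import numpy as np

def continuous_hr(pulse, fs=25.0, win_s=10.0, band=(0.7, 2.5)):
    """Per-frame HR (beats-per-minute) from a predicted pulse waveform."""
    win = int(win_s * fs)
    hr = np.zeros(len(pulse))
    for i in range(len(pulse)):
        lo, hi = max(0, i - win // 2), min(len(pulse), i + win // 2)
        seg = pulse[lo:hi] - np.mean(pulse[lo:hi])
        freqs = np.fft.rfftfreq(len(seg), d=1.0 / fs)
        spec = np.abs(np.fft.rfft(seg))
        valid = (freqs >= band[0]) & (freqs <= band[1])
        hr[i] = 60.0 * freqs[valid][np.argmax(spec[valid])]   # dominant in-band frequency
    return hr

def cmae(hr_pred, hr_gt):
    return np.mean(np.abs(np.asarray(hr_pred) - np.asarray(hr_gt)))          # Eq. 12

def crmse(hr_pred, hr_gt):
    return np.sqrt(np.mean((np.asarray(hr_pred) - np.asarray(hr_gt)) ** 2))  # Eq. 13
```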
4.3 Heart Rate Estimation Results

Table 1 shows the comparison of our method against traditional non-deep-learning methods (implemented in [12]), Green [22], POS [23] and ICA [16], and recent deep-learning methods, DeepPhys [3] and TS-CAN [10]. It is important to note that, for a fair comparison, we excluded studies that either utilize external training data [8] or access the test set for domain adaptation [19] on this benchmark. For computing the HR, we used a bandpass filter with a range of [0.7, 2.5] Hz.
4.4 Spatial Attention Mask

The spatial attention mask offers visibility into where the model is extracting the HR and RR in a given frame. As shown in Fig. 4, the base encoder is able to focus on the regions corresponding to facial skin for extraction of physiological signals. Notice how the model excludes facial accessories such as eye-glasses (top-left subject in Fig. 4a) and facial hair (bottom-right subject in Fig. 4a).
Fig. 4 Spatial attention masks obtained for the HR model (left) and the RR model (right) on the V4V test set

Table 2 Comparison of our method against previous works for RR estimation on the V4V validation set and V4V test set. Note that lower cMAE and lower cRMSE are better (↓)

Method          cMAE Val. (↓)   cRMSE Val. (↓)   cMAE Test (↓)   cRMSE Test (↓)
Green [22]      5.9             6.8              7.0             7.5
POS [23]        6.1             6.9              6.5             6.9
ICA [16]        6.4             7.2              5.8             6.2
DeepPhys [3]    5.0             6.1              5.5             5.9
Ours            4.8             5.6              5.4             6.0
4.5 Respiration Rate Estimation Results

We train our proposed model on the continuous respiration waveform of the V4V training set and report the results on the V4V validation set and the V4V test set in Table 2. We compare our results with traditional and deep-learning based approaches. For computing the RR, we extracted the bio-signal [20] and used a bandpass filter with a range of [0.13, 0.34] Hz. The results indicate that our method performs better than the other approaches on the V4V dataset.
5 Conclusion

In this paper, we take a step towards instantaneous prediction of physiological signals by utilizing a Transformer based architecture for extracting the heart rate and respiration rate from facial videos. We train the video Transformer model in an end-to-end manner using a cross-correlation loss in the frequency domain. The results of our approach under the continuous evaluation metrics on the Vision-for-Vitals (V4V) dataset show that the model is able to outperform both shallow and deep learning
methods on the task of heart rate and respiration rate estimation. As part of future work, video Transformers can be used to tackle the domain shift problem (laboratory to real-world) and to extract other physiological signals such as Oxygen Saturation (SpO2).

Acknowledgements This project is funded by the Bill & Melinda Gates Foundation (BMGF). Any opinions, findings, or conclusions are those of the authors and do not necessarily reflect the views of the sponsors.
References 1. Block, R. C., Yavarimanesh, M., Natarajan, K., Carek, A., Mousavi, A., Chandrasekhar, A., Kim, C. S., Zhu, J., Schifitto, G., & Mestha, L.K., et al. (2020). Conventional pulse transit times as markers of blood pressure changes in humans. Scientific Reports, 10(1). 2. Carion, N., Massa, F., Synnaeve, G., Usunier, N., Kirillov, A., & Zagoruyko, S. (2020). Endto-end object detection with transformers. In: European Conference on Computer Vision. 3. Chen, W., & McDuff, D. (2018). Deepphys: Video-based physiological measurement using convolutional attention networks. In: Proceedings of the European Conference on Computer Vision (ECCV). 4. Dasari, A., Prakash, S. K. A., Jeni, L. A., & Tucker, C. (2021). Evaluation of biases in remote photoplethysmography methods. NPJ Digital Medicene. 5. De Haan, G., & Jeanne, V. (2013). Robust pulse rate from chrominance-based rppg. IEEE Transactions on Biomedical Engineering, 60(10). 6. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., & Gelly, S., et al. (2020). An image is worth 16 × 16 words: Transformers for image recognition at scale. arXiv:2010.11929. 7. Gideon, J., & Stent, S. (2021). The way to my heart is through contrastive learning: Remote photoplethysmography from unlabelled video. In: Proceedings of the IEEE/CVF International Conference on Computer Vision. 8. Hill, B., Liu, X., & McDuff, D. (2021). Beat-to-beat cardiac pulse rate measurement from video. In: Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops. 9. Lin, K., Wang, L., & Liu, Z. (2021). End-to-end human pose and mesh reconstruction with transformers. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 1954–1963), June 2021. 10. Liu, X., Fromm, J., Patel, S., & McDuff, D. (2020). Multi-task temporal shift attention networks for on-device contactless vitals measurement. arXiv:2006.03790. 11. Lu, H., Han, H., & Zhou, S. K. (2021). Dual-gan: Joint bvp and noise modeling for remote physiological measurement. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2021) 12. McDuff, D., & Blackford, E. (2019). iphys: An open non-contact imaging-based physiological measurement toolbox. In: 2019 41st Annual International Conference of the IEEE Engineering in Medicine and Biology Society (EMBC). IEEE. 13. Neimark, D., Bar, O., Zohar, M., & Asselmann, D. (2021). Video transformer network. arXiv:2102.00719. 14. Niu, X., Yu, Z., Han, H., Li, X., Shan, S., & Zhao, G. (2020). Video-based remote physiological measurement via cross-verified feature disentangling. In: European Conference on Computer Vision. 15. Pereira, T., Tran, N., Gadhoumi, K., M. Pelter, M., Do, D.H., Lee, R.J., Colorado, R., Meisel, K., & Hu, X. (2020). Photoplethysmography based atrial fibrillation detection: a review. NPJ Digital Medicene.
16. Poh, M. Z., McDuff, D. J., & Picard, R. W. (2010). Non-contact, automated cardiac pulse measurements using video imaging and blind source separation. Optics Express, 18(10). 17. Prakash, S. K. A., & Tucker, C. S. (2018). Bounded kalman filter method for motion-robust, non-contact heart rate estimation. Biomedical Optics Express, 9(2). 18. Revanur, A., Li, Z., Ciftci, U. A., Yin, L., & Jeni, L. A. (2021). The first vision for vitals (v4v) challenge for non-contact video-based physiological estimation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops. 19. Stent, S., & Gideon, J. (2021). Estimating heart rate from unlabelled video. In: Proceedings of the IEEE/CVF International Conference on Computer Vision Workshops. 20. Tarassenko, L., Villarroel, M., Guazzi, A., Jorge, J., Clifton, D., & Pugh, C. (2014). Noncontact video-based vital sign monitoring using ambient light and auto-regressive models. Physiological Measurement, 35(5). 21. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. In: Advances in neural information processing systems (pp. 5998–6008). 22. Verkruysse, W., Svaasand, L. O., & Nelson, J. S. (2008). Remote plethysmographic imaging using ambient light. Optics Express, 16(26). 23. Wang, W., den Brinker, A. C., Stuijk, S., De Haan, G. (2016). Algorithmic principles of remote PPG. IEEE Transactions on Biomedical Engineering, 64(7). 24. Wu, H. Y., Rubinstein, M., Shih, E., Guttag, J., Durand, F., Freeman, W. T. (2012). Eulerian video magnification for revealing subtle changes in the world. ACM Transactions on Graphics (Proceedings of the SIGGRAPH 2012), 31(4). 25. Yu, Z., Li, X., Wang, P., & Zhao, G. (2021). Transrppg: Remote photoplethysmography transformer for 3d mask face presentation attack detection. IEEE Signal Processing Letters.
Automated Vision-Based Wellness Analysis for Elderly Care Centers Xijie Huang, Jeffry Wicaksana, Shichao Li, and Kwang-Ting Cheng
Abstract The growth of the aging population requires caregivers to improve both the efficiency and the quality of healthcare. In this study, we develop an automatic, vision-based system for monitoring and analyzing the physical and mental well-being of senior citizens. Through collaboration with the Haven of Hope Christian Service, we collect video recordings in the care center with surveillance cameras. We then process and extract personalized facial, activity, and interaction features from the video data using deep neural networks. This integrated health information system can assist caregivers to gain better insights into the seniors they are taking care of. We report the findings of our analysis and evaluate the system quantitatively to demonstrate its effectiveness.
1 Introduction

The need to rethink the care quality for elderly citizens is getting more important as the growth of the older population outpaces the growth of available caregivers. During the COVID-19 pandemic, there have been several outbreaks in elderly care centers because caregivers need to serve multiple care centers. To address the
Fig. 1 Overview of our framework: a human recognition module detects each individual elderly person in the video clips and then feeds the individual videos into three modules: facial analysis, activity detection, and scene understanding. The final goal is to perform analysis, both daily and long-term, and provide immediate assistance for any detected anomaly
manpower issue, we aim to improve the efficiency of each caregiver by providing them with a vision-based tool that automatically provides assistive insights about each senior citizen.
The representational power brought by the development of deep learning approaches and the increased amount of annotated data has revolutionized the way visual data are processed. Such visual understanding models lead to new opportunities for human-centric healthcare. For example, human activity classification from video data can be extended to classify clinically relevant activities and to monitor the occurrence of abnormal events such as falls [1–4]. To improve elderly healthcare quality through a computer vision-based solution, we partner with the Haven of Hope Christian Service elderly care centers in Hong Kong to develop such a system. Cameras were installed to capture the daily activities of senior citizens, which are automatically analyzed with the goal of providing meaningful insights into senior citizens' wellness and its trends. Results are then provided to caregivers so that they can adjust their care strategy regarding who, when, and where they should pay greater attention to. We collected about one month of video data from one daycare center for elders with dementia to drive the development and evaluation of our automated vision-based system.
The overall framework of our system is illustrated in Fig. 1. After automatically identifying each individual, we analyze three major aspects of each individual senior citizen: the person's facial information, physical activities, and social interaction. Facial activities include, but are not limited to, facial expression, yawning, blinking, and talking. The information derived from these three components is then summarized and visualized to help caregivers better understand senior citizens' daily behaviors and patterns over time. Additionally, any detected anomaly that requires immediate assistance is also reported to the caregivers.
In the care center, daily guided physical exercises are designed for senior citizens to keep them physically active. In order to evaluate and improve the efficacy of these daily exercises and better tailor them for different senior citizens, we automatically analyze the physical movements of each person during the exercises. We define a metric, the exercise intensity score, to summarize how active each senior citizen is when performing the physical exercises. We track the activeness level of each individual over time, and when there is a significant deviation from their usual activeness level, caregivers can be alerted automatically.
Our current system produces two types of insightful visualizations of the analysis results: activity temporal heatmaps and scene graphs. A temporal heatmap records the various activities and their duration engaged in by each person throughout a day. A scene graph focuses on capturing social interaction among individuals and the interaction between the elderly and inanimate objects such as the television. By tracking this information over a month, we gain significant insights into the behavior patterns of a senior citizen. Based on these visualizations, caregivers can better update their caregiving strategy accordingly. Instead of monitoring the status of a single elderly person, caregivers are able to monitor and compare the wellness of a group of elderly people with our system. In the long term, caregivers will be notified when there is an abnormal change in an individual's temporal health metrics, and personalized care can be provided.
To quantitatively evaluate the detection and recognition accuracy of our system, we label a number of video clips for various activities. We designed experiments to show that our system can achieve sufficient accuracy and efficiently detect various activities. In summary, this paper contributes the following:
• We report an automated vision-based wellness analysis system which can effectively and accurately perform facial analysis and activity analysis at the same time under realistic elderly care scenarios.
• We propose visualization of the patterns via scene graphs and temporal activity heatmaps to provide an easy-to-understand summary for caregivers.
• We cooperate with the Haven of Hope Christian Service elderly care center, collect and label a practical video dataset, and evaluate the key components of our system quantitatively.
2 Related Work

Smart healthcare systems have drawn increasing research attention with the flourishing of artificial intelligence and smart devices. However, most existing healthcare systems require wearable devices to monitor activity, which is intrusive and hence not widely accepted. Computer vision based solutions bring the opportunity of a non-intrusive system based upon stationary cameras that allow passive detection of important activities. Most vision-based work focuses on fall detection [1–4]. While these models are effective for identifying critical acute conditions, they do not
detect and analyze daily behaviors and their long-term patterns, which are of great importance for understanding elderly citizens. Prior work on long-term health monitoring [5, 6] uses depth [5] and thermal [6] sensors instead of vision sensors. To the best of our knowledge, an automated non-intrusive vision-based system that is capable of activity detection and long-term health monitoring has not been proposed prior to our work.
3 Proposed Wellness Analysis System

In this section, we present the components of the proposed system: the facial analysis, activity analysis, and interaction analysis modules. These three modules all target the extraction of meaningful features. We utilize visualization tools to present these features as straightforward wellness metrics.
3.1 Facial Analysis

To help caregivers gain a better understanding of the senior citizens they care for, we propose to detect their facial activities, including yawning, dozing off, talking, and blinking. These facial activities can be detected by analyzing the facial landmark dynamics between frames. We first briefly introduce how we extract the facial landmarks using a pre-trained deep neural network. Then we introduce how we utilize the facial landmarks to detect specific activities using a set of rules.
For our application, the subjects of interest are senior citizens and caregivers. We built a human face library for all the subjects and apply the YOLOFace [7] model to detect the faces. When a face is detected, we pair each detected face with the corresponding detected human by maximizing the Intersection over Union (IoU). The next step for face feature extraction is detecting facial landmarks on the facial image obtained from the face detection model. We use the 68 semantic landmarks defined in iBUG [8]. These landmarks are chosen as they include the relevant facial features (eyes, nose, mouth, etc.). We use a pre-trained ResNet-34 from the Dlib library for facial landmark detection. The model is highly robust and is able to handle non-frontal faces, as well as occluded landmarks, relatively well.
From the detected facial landmarks over a series of frames, we can directly identify facial activities such as yawning, napping, blinking, and talking. For yawning, we define the activity based on the distance among the lip landmarks. For a given frame sequence [fm, fn], the facial landmarks of the upper lip and lower lip are pul(fi) and pll(fi), respectively; if the condition DW(pul(fi), pll(fi)) > Dthrl is met for all frames fi ∈ [fm, fn], a yawning activity is detected for this frame sequence. DW(p1(fi), p2(fi)) denotes the average distance between points p1(fi) and p2(fi) over a sliding window of W frames.
Similarly, the facial landmarks of the upper eyelid and lower eyelid are pue(fi) and ple(fi). If the condition DW(pue(fi), ple(fi)) < Dthre is met for frames fi ∈ [fm, fn], an eye-closing status is detected in this sequence. The thresholds for yawning Dthrl and eye closing Dthre are defined manually, taking individual differences into account. Based on this, the facial activity is classified as napping when the time span is relatively long, fn − fm > fthr; otherwise the activity is classified as blinking.
Talking is different from the other activities since it involves both a speaker and a listener. We adopt a hierarchical two-stage classification paradigm. First, the talking behavior of an individual is detected based on the facial landmarks; then all human subjects are analyzed and in turn categorized as a speaker or a non-speaker. We further classify the non-speakers into two fine-grained types, active listener or inactive bystander, using gaze and head pose detection to identify such talking interactions. The active listeners are those whom the speaker is talking to and can be detected if their gaze directions pass through the head bounding box of the detected speaker. The inactive bystanders include people who are listening but not interacting with the speaker and people who are not listening at all.
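As a concrete illustration of the rules above, the sketch below flags yawning from the smoothed upper/lower-lip landmark distance. The pixel threshold and minimum duration are hypothetical values chosen for illustration; in practice Dthrl, Dthre, and fthr are tuned per person as described above.

```python
import numpy as np

def smoothed_distance(pts_a, pts_b, w=5):
    """D_W: moving-average Euclidean distance between two landmark tracks.
    pts_a, pts_b: arrays of shape (T, 2) with one (x, y) point per frame."""
    d = np.linalg.norm(pts_a - pts_b, axis=1)
    return np.convolve(d, np.ones(w) / w, mode="same")

def detect_yawns(upper_lip, lower_lip, thr_lip=18.0, min_frames=15, w=5):
    """Return (start, end) frame ranges where the smoothed lip distance
    exceeds the yawning threshold for a sustained run."""
    open_mouth = smoothed_distance(upper_lip, lower_lip, w) > thr_lip
    runs, start = [], None
    for i, flag in enumerate(open_mouth):
        if flag and start is None:
            start = i
        elif not flag and start is not None:
            if i - start >= min_frames:
                runs.append((start, i))
            start = None
    if start is not None and len(open_mouth) - start >= min_frames:
        runs.append((start, len(open_mouth)))
    return runs
```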
3.2 Activity Analysis

Based on suggestions from the caregivers, we automatically analyze a set of day-to-day activities of the senior citizens. In particular, the care center routinely conducts guided group physical exercises, detailed analysis of which can give good insights into the exercise intensity and fitness of the elderly. In light of this, we propose an action detection module consisting of two parts: exercise intensity detection, which aims to monitor the activeness level during group exercises, and activity detection, which focuses on the other non-trivial daily activities, including interaction with other objects in the care center.
3.2.1 Exercise Intensity Detection
In the care center, senior citizens must follow guided group exercises daily to maintain their physical fitness, as illustrated in Fig. 2. We attempt to automatically analyze each individual's exercise intensity in such group exercises, which should reflect their physical health condition. We define a metric, the exercise intensity score, which measures the activeness level in the guided group exercise. We first extract the key points of each person's 2D body pose using OpenPose [9]. To compute the exercise intensity score from these extracted keypoints and their dynamics, which entail various aspects of the body movements, we derive three meta features: the angle, range, and speed of body movements (Fig. 2). Based on the shoulder, elbow, and wrist keypoints obtained from pose estimation over a sequence of frames, we compute the elbow angle via arccos(lse · lew / (|lse| |lew|)), where lse and lew represent the keypoint vector from the shoulder to the elbow, and from the elbow
Fig. 2 Exercise intensity scores for different senior citizens within 35 s. We can see from the bars that elderly #2 and #4 were actively doing the exercise with higher intensity
to the wrist, respectively. We then normalize the angle to [0, 1] to get the angle score f_angle. The angle encodes the pose and can guide the inference of body movement during exercise. Similarly, we compute the movement range of the wrist keypoint and normalize it to [0, 1] to derive f_range, and compute a normalized f_speed considering the time taken to reach the maximum body movement. We combine these three meta features with weighting coefficients λa, λr, λs to balance the different components. The exercise intensity score f_eis can be written as: f_eis = λa · f_angle + λr · f_range + λs · f_speed.
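A minimal sketch of the exercise intensity score is given below, assuming (T, 2) keypoint tracks from a pose estimator such as OpenPose. The way the per-frame angles are aggregated, the pixel scales used for normalization, and the weighting coefficients are assumptions for illustration; the chapter does not specify them.

```python
import numpy as np

def elbow_angle(shoulder, elbow, wrist):
    """Elbow angle from three 2D keypoints via arccos(l_se . l_ew / (|l_se||l_ew|))."""
    v_se, v_ew = elbow - shoulder, wrist - elbow
    cos = np.dot(v_se, v_ew) / (np.linalg.norm(v_se) * np.linalg.norm(v_ew) + 1e-8)
    return np.arccos(np.clip(cos, -1.0, 1.0))

def exercise_intensity(shoulders, elbows, wrists, lam=(0.4, 0.4, 0.2), fps=25.0):
    """f_eis = la*f_angle + lr*f_range + ls*f_speed over one exercise clip."""
    angles = np.array([elbow_angle(s, e, w)
                       for s, e, w in zip(shoulders, elbows, wrists)])
    f_angle = np.mean(angles) / np.pi                       # angle score in [0, 1]
    extent = np.linalg.norm(wrists.max(axis=0) - wrists.min(axis=0))
    f_range = np.clip(extent / 200.0, 0.0, 1.0)             # 200 px scale assumed
    t_peak = (np.argmax(np.linalg.norm(wrists - wrists[0], axis=1)) + 1) / fps
    f_speed = np.clip(extent / t_peak / 300.0, 0.0, 1.0)    # 300 px/s scale assumed
    la, lr, ls = lam
    return la * f_angle + lr * f_range + ls * f_speed
```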
3.2.2 Temporal Activity Detection
In addition to detecting the activity level during the guided group exercises, we detect various other activities, including sitting, watching TV, eating, drinking, and doing individual exercise. We define our temporal activity classification as a K-way classification (K = 5 for our system) with a training set Dtrain = {(xi, yi)}, where xi ∈ R^(T×H×W×C) is the video input and yi ∈ [0, 1]^K is a vector representing whether each of the K defined activities is performed in this video. Our goal is to train a model learning a mapping f : xi → yi from the video input to the activity vector label. This can be done by minimizing the loss over the training set:

f* = arg min_{f ∈ F} (1/|Dtrain|) Σ_{(xi, yi) ∈ Dtrain} E[l(f(xi), yi)],

where F is a class of functions, f(xi) is the softmax score representing the probability of each category, and l is the softmax cross entropy loss function. Once we have a pre-trained model f, we also want to fully utilize the temporal information for better accuracy. Given a continuous
video V, we segment it into L overlapping clips of T frames, X = {xi}, i = 1, ..., L, where xi ∈ R^(T×H×W×C). Each video clip is input into our model f, and we use a sliding window of length W to smooth the predictions: f̄(xi) = (1/W) Σ_{j = i−W/2}^{i+W/2} f(xj). Additionally, the best sliding window length W differs across activity categories, so instead of using a fixed window length we apply a grid search per activity to find the best sliding window length. In the experiments, we show that the precision improvement brought by this sliding-window smoothing is significant.
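The sliding-window smoothing of clip-level predictions can be sketched as follows; the per-activity window lengths would come from the grid search mentioned above, and the edge padding at clip boundaries is an assumption.

```python
import numpy as np

def smooth_scores(scores, window):
    """Average clip-level scores f(x_i) over a centered window of length W.
    scores: (L, K) array of per-clip softmax scores; window: odd integer W."""
    half = window // 2
    padded = np.pad(scores, ((half, half), (0, 0)), mode="edge")
    return np.stack([padded[i:i + window].mean(axis=0)
                     for i in range(len(scores))])

def smooth_per_activity(scores, windows):
    """Use a different (grid-searched) window length for each activity column."""
    out = np.zeros_like(scores, dtype=float)
    for k, w in enumerate(windows):
        out[:, k] = smooth_scores(scores[:, k:k + 1], w)[:, 0]
    return out
```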
3.3 Interaction Analysis

Some activities involve the interaction between a person and a certain object. These human-object interactions include sitting on a chair, watching TV, eating at a table, and drinking from a bottle or cup. While our temporal activity detection module is capable of predicting which activity is performed in a given frame sequence (e.g. predicting "drinking"), we also want to locate the positions of the human and the object and assign the detected interaction to a pair of detected human and object instances (e.g. detecting "Senior #1 is drinking with a bottle"). The task of detecting these activities can be formulated as detecting a triplet ⟨human, verb, object⟩. Similar to previous human-object interaction detection models such as InteractNet [10], we build a network consisting of three branches: human, object, and interaction. The object detection branch detects candidate human/object boxes bh, bo and predicts the class scores sh, so. The cropped human instance is then input to the human branch, where we extract features with RoIAlign [11] and predict a score sh^a for each action a. The interaction branch receives as input both the cropped human and object boxes bh, bo, and predicts probabilities sh,o^a for each action a. The output layers of the human and interaction branches consist of four binary sigmoid classifiers for multi-label classification. The final prediction for each interaction a is Sa = sh · so · sh^a · sh,o^a.
Besides the interaction of humans and objects, human-human interaction is another important analysis target of a healthcare system. As we have built the facial analysis and temporal activity detection modules, we can get clues about social interaction from the detected human-human interaction, namely the talking activity. Our intuition is that the more talkative an individual is, the better the mental health condition he or she is in. This association has been demonstrated in previous research, especially for dementia patients [12, 13].
We adopt a scene graph to visualize and describe the activities in a given scene in a straightforward manner. Formally, a scene graph G = (O, E) is a directed graph, where O = {o1, ..., on} is the set of detected instances in the image or video frame. Each detected instance oi = (ci, Ai) has a category ci and attributes Ai (e.g. a human is doing exercise). E is a set of directed edges that represent the relationships between objects. To understand the mental wellness of senior citizens, we define a feature for each elderly node in the graph. The feature can be computed by aggregating the
information of neighboring nodes and their edges. The most important is the talking activity between human nodes. We give this activity a high weight, as we believe a talkative person has a better mental health condition. There are also other activities, such as watching TV, that contribute to the mental wellness feature, but less significantly. Based on the scene graph, we define a close relationship between two persons: if the average talking time between the pair is more than a given threshold tth, we say that the pair has a close relationship. Note that the pair can be two senior citizens, or one caregiver and one senior citizen. The mental wellness system can provide guidance for possible psychotherapy accordingly.
3.4 Analysis of Long-Term Pattern and Trend

Figure 3 gives an example of an activity temporal heatmap visualizing the activities of a specific senior citizen during a whole day. Compared to the raw probability predictions of the various activities, the heatmap provides a more straightforward illustration of what activity is performed and when. In the example figure, we can see that the elderly citizen was sitting on the chair all day, except for some time slots. The system can alert the caregivers when the elderly person is not on the chair. Additionally, we can see at which time of day a senior citizen is more likely to be sleepy and change the group exercise time accordingly.
Some actions are more likely to be performed together, while others are not. To mine the association rules in our detection results, we use correlation to measure the co-occurrence. The correlation η is defined as η = cov(X, Y) / (σX σY) = E[(X − μX)(Y − μY)] / (σX σY), where X and Y are the probability distribution outputs of two activities, E(X) is the expected value of distribution X and σX is the standard deviation of distribution X. We compute the correlation between each pair of activities and show them in the co-occurrence matrix in Fig. 4.
[Fig. 3 heatmap rows: Sit on Chair, Absent from Chair, Watch TV, Eat at Table, Napping, Yawning, Drinking, Talking, Doing Exercise; horizontal axis: Time (h)]
Fig. 3 Activity temporal heatmap of eight activities in a given day. Darker blue at a given time indicates a higher probability that an action was being performed. The red box denotes a period when the "sit on chair" action is missing, for which our system will alert caregivers accordingly
Fig. 4 Correlation Matrix of different activities. Positive correlation implies the two actions are more likely to be performed together
Both intuitive and some counter-intuitive results can be observed. For example, some pairs of activities, such as yawning and napping, or eating and drinking, are more likely to be performed together, while watching TV and napping are highly exclusive. This matches common sense. There are also interesting findings that may provide insights for the caregivers. For instance, if someone watches TV more, he or she is more likely to have close relationships with others. These correlations can help improve the accuracy of prediction via the following prediction correction: p_{a_i} ← p_{a_i} + Σ_{a_j ∈ A, a_j ≠ a_i} η_{ij} · p_{a_j}. For each activity a_i in the action set A, we use the correlation η_{ij} to correct the prediction p_{a_i} based on the predictions p_{a_j} of the other activities. This feedback paradigm based on correlation analysis effectively boosts the performance of our model, and we show its effectiveness in the experiment section. Besides the correlation analysis, we also inspect the activities and features over the long term. Figure 5 gives an example of the activity timeline of napping and the average exercise intensity score during group exercise over 7 days. It allows caregivers to observe changes in specific features and can alert them when there is an anomalous increase or decrease of some metric, e.g. overly long napping or lower exercise intensity than normal.
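A minimal sketch of the correlation-based prediction correction described above is given below; the clipping of corrected scores to [0, 1] is our own assumption rather than a detail stated in the text.

```python
import numpy as np

def correct_predictions(probs, eta):
    """Correlation-based prediction correction.

    probs: array of shape (num_activities,), raw probabilities p_{a_i}.
    eta:   array of shape (num_activities, num_activities), correlations.
    Returns corrected probabilities p_{a_i} + sum_j eta_ij * p_{a_j} (j != i).
    """
    eta_off_diag = eta - np.diag(np.diag(eta))   # drop the self-correlation term
    corrected = probs + eta_off_diag @ probs
    return np.clip(corrected, 0.0, 1.0)          # assumption: keep scores in [0, 1]
```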
Fig. 5 Average napping time and exercise intensity score over seven consecutive days
4 Evaluation

In this section we describe how we collect and process the video recording dataset. To extensively evaluate the system, we need corresponding labels to check whether our predictions are correct. One of the most representative and important modules in our system is the activity detection module, so we choose the activity detection task for evaluation.
4.1 Data Collection

Our dataset consists of numerous videos, each of which contains elderly individuals doing their daily activities at an elderly care center. The dataset was obtained from the Haven of Hope Elderly Homes, a care center providing holistic care for frail elders and chronically ill patients. To gather the dataset, we obtained permission from the Haven of Hope Elderly Homes to record their citizens' daily activities during working hours. These activities include individual activities, such as watching television, individual physical exercise, and eating, as well as group activities, such as interaction with other citizens, interaction with the caregivers, and aided group physical exercise. Overall, we have 168 h of video recording from four cameras at different angles, totaling 60M frames at 4K spatial resolution. The recordings span a whole month, which enables research on long-term feature analysis. We only use our data for research purposes and blur all faces when showing results.
4.2 Results

We provide healthcare assistance for 21 senior citizens in the care center. Monitoring the health condition of all 21 senior citizens during the daytime is not feasible for caregivers, and long-term features and trends also cannot be easily observed by caregivers. With the help of our system, these challenges are tackled. To evaluate the performance of the activity detection module, we select some representative clips from the whole dataset and label the instance-level activities. There are more than 4000 instances in the test set, covering seven activities excluding exercising (because exercising time is fixed every day in our data, and we manually set a window to filter out all false positive detections). We apply several improvements, including a dynamic sliding window length and correcting predictions with activity co-occurrence. Table 1 shows the Average Precision (AP) of the instance-level activity classification. We define a detection as correct when the K ground truth activities are the same as the top-K predicted activity categories. "Fixed sliding" indicates a sliding window with fixed length, "sliding" represents sliding windows with dynamic lengths for different activities, and "co-occurrence" denotes the prediction correction based on the correlation. We can see that both the sliding window and the co-occurrence correction boost the performance of our model on the test set. Compared to the baseline model, the sliding windows and correlation correction improve the mean average precision (mAP) by 19.7% and 11.2%, respectively. The combination of all tricks boosts mAP by 26.4%.
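The top-K correctness criterion above can be implemented as follows; this is a sketch under our assumptions about how per-instance predictions and labels are stored.

```python
def topk_correct(pred_scores, true_labels):
    """A detection is correct when the K ground-truth activities match
    the top-K predicted categories (K = number of ground-truth labels).

    pred_scores: dict mapping activity name -> predicted score for one instance.
    true_labels: set of ground-truth activity names for that instance.
    """
    k = len(true_labels)
    top_k = sorted(pred_scores, key=pred_scores.get, reverse=True)[:k]
    return set(top_k) == set(true_labels)
```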
Table 1 Activity classification result of Average Precision (AP) comparison on our test set. VPN [14] and I3D [15] are trained on Kinetics [17] and Toyota Smarthome [18] action classification dataset. Active Speakers in Context [16] is trained on AVA-Active Speaker dataset [19]

| Method | Sit on chair | Watch TV | Eating | Drinking | Napping | Yawning | Talking | mAP |
|---|---|---|---|---|---|---|---|---|
| VPN [14] | 0.844 | 1.000 | 0.769 | 0.896 | – | – | – | – |
| I3D [15] | 0.792 | 0.920 | 0.722 | 0.858 | – | – | – | – |
| Active Speakers Context [16] | – | – | – | – | – | – | 0.689 | – |
| Baseline | 0.604 | 0.893 | 0.735 | 0.829 | 0.374 | 0.712 | 0.546 | 0.670 |
| Baseline + fixed sliding | 0.773 | 0.893 | 0.687 | 0.844 | 0.412 | 0.890 | 0.727 | 0.747 |
| Baseline + sliding | 0.858 | 0.893 | 0.714 | 0.844 | 0.581 | 1.000 | 0.727 | 0.802 |
| Baseline + co-occurrence | 0.693 | 0.946 | 0.782 | 0.851 | 0.443 | 0.756 | 0.746 | 0.745 |
| Baseline + co-occurrence + sliding | 0.862 | 1.000 | 0.769 | 0.902 | 0.650 | 1.000 | 0.746 | 0.847 |
Compared to the state-of-the-art activity classification models VPN [14] and I3D [15], our full model achieves better accuracy on four activity categories. Since the baseline model fails to outperform VPN and I3D, this gain highlights the significant contribution of temporal information and activity correlation. For talking detection, we adopt a facial landmark based solution, and it surprisingly outperforms the state-of-the-art multi-modal speaker detection model Active Speakers in Context [16]. This is mainly because our video recordings have poor audio quality, so multi-modal methods fail to utilize the acoustic features.
5 Conclusion

We build an automated vision-based wellness analysis system and demonstrate its effectiveness for the healthcare of elderly care center citizens. Additionally, we perform extensive long-term feature analysis which can serve as a reference for the caregivers. We are still working on the project to incorporate more components into our model, including, but not limited to, more long-term analysis, personalized healthcare profiles, and more activities including abnormal ones. Our final goal is to build a privacy-preserving, robust, and efficient healthcare system that provides comprehensive assistance to both the caregivers and the elderly without high computation cost.
References 1. Rougier, C., Meunier, J., St-Arnaud, A., & Rousseau, J. (2011). Robust video surveillance for fall detection based on human shape deformation. IEEE Transactions on Circuits and Systems for Video Technology, 21(5), 611–622. 2. Miao, Y., Rhuma, A., Mohsen Naqvi, S., Wang, L., & Chambers, J. (2012). A posture recognition-based fall detection system for monitoring an elderly person in a smart home environment. IEEE Transactions on Information Technology in Biomedicine, 16(6), 1274–1286. 3. Mastorakis, G., & Makris, D. (2014). Fall detection system using kinect’s infrared sensor. Journal of Real-Time Image Processing, 9(4), 635–646. 4. Zhang, Z., Conly, C., & Athitsos, V. (2015). A survey on vision-based fall detection. In Proceedings of the 8th ACM International Conference on PErvasive Technologies Related to Assistive Environments (pp. 1–7). 5. Parajuli, M., Tran, D., Ma, W., & Sharma, D. (2012) Senior health monitoring using kinect. In 2012 Fourth International Conference on Communications and Electronics (ICCE) (pp. 309–312). IEEE. 6. Luo, Z., Hsieh, J-T., Balachandar, N., Yeung, S., Pusiol, G., Luxenberg, J. et al. (2018). Computer vision-based descriptive analytics of seniors’ daily activities for long-term health monitoring. Machine Learning for Healthcare (MLHC), 2,. 7. Li, C., Wang, R., Li, J., & Fei, L. (2020). Face detection based on yolov3. In Recent Trends in Intelligent Computing, Communication and Devices (pp. 277–284). Springer. 8. Sagonas, C., Antonakos, E., Tzimiropoulos, G., Zafeiriou, S., & Pantic, M. (2016). 300 faces in-the-wild challenge: Database and results. Image and Vision Computing, 47, 3–18.
9. Cao, Z., Hidalgo Martinez, G., Simon, T., Wei, S., & Sheikh, Y. A. (2019). Openpose: Realtime multi-person 2d pose estimation using part affinity fields. IEEE Transactions on Pattern Analysis and Machine Intelligence,. 10. Gkioxari, G., Girshick, R., Dollár, P., & He, K. (2018). Detecting and recognizing humanobject interactions. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 8359–8367). 11. He, K., Gkioxari, G., Dollár, P., & Girshick, R. (2017). Mask r-cnn. In Proceedings of the IEEE International Conference on Computer Vision (pp. 2961–2969). 12. Rousseaux, M., Sève, A., Vallet, M., Pasquier, F., & Anne Mackowiak-Cordoliani, M. (2010). An analysis of communication in conversation in patients with dementia. Neuropsychologia, 48(13), 3884–3890. 13. Adams, T., & Gardiner, P. (2005). Communication and interaction within dementia care triads: Developing a theory for relationship-centred care. Dementia, 4(2), 185–205. 14. Das, S., Sharma, S., Dai, R., Bremond, F., & Thonnat, M. (2020). Vpn: Learning video-pose embedding for activities of daily living. In European Conference on Computer Vision (pp. 72–90). Springer. 15. Carreira, J., & Zisserman, A. (2017). Quo vadis, action recognition? a new model and the kinetics dataset. In proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, (pp. 6299–6308). 16. León Alcázar, J., Caba, F., Mai, L., Perazzi, F., Lee, J-Y., Arbeláez, P., et al. (2020). Active speakers in context. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 12465–12474). 17. Kay, W., Carreira, J., Simonyan, K., Zhang, B., Hillier, C., Vijayanarasimhan, S., et al. (2017). The kinetics human action video dataset. arXiv:1705.06950. 18. Dai, R., Das, S., Sharma, S., Minciullo, L., Garattoni, L., Bremond, F., et al. (2020). Toyota smarthome untrimmed: Real-world untrimmed videos for activity detection. 19. Roth, J., Chaudhuri, S., Klejch, O., Marvin, R., Gallagher, A., Kaver, L., et al. (2020). Ava active speaker: An audio-visual dataset for active speaker detection. In ICASSP 2020-2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (pp. 4492–4496). IEEE.
Efficient Extraction of Pathologies from C-Spine Radiology Reports Using Multi-task Learning Arijit Sehanobish, Nathaniel Brown, Ishita Daga, Jayashri Pawar, Danielle Torres, Anasuya Das, Murray Becker, Richard Herzog, Benjamin Odry, and Ron Vianu
Abstract Pretrained Transformer based models finetuned on domain specific corpora have changed the landscape of NLP. Generally, if one has multiple tasks on a given dataset, one may finetune different models or use task specific adapters. In this work, we show that a single multi-task model can match or exceed the performance of multiple BERT-based models finetuned on individual tasks and of various task specific adapter augmented BERT-based models. We validate our method on our internal dataset of radiologists' reports on the cervical spine. We hypothesize that the tasks are semantically close
A. Sehanobish (B) · N. Brown · I. Daga · J. Pawar · D. Torres · A. Das · M. Becker · R. Herzog · B. Odry · R. Vianu Covera Health, NYC, New York, USA e-mail: [email protected] N. Brown e-mail: [email protected] I. Daga e-mail: [email protected] J. Pawar e-mail: [email protected] D. Torres e-mail: [email protected] A. Das e-mail: [email protected] M. Becker e-mail: [email protected] R. Herzog e-mail: [email protected] B. Odry e-mail: [email protected] R. Vianu e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_24
and related, and thus multitask learners are powerful classifiers. Our work opens up the possibility of applying our method to radiologists' reports on various body parts. Keywords Transformers · Medical NLP · BERT · Multitask learning
1 Introduction Since the seminal work by [1], Transformers have become the de-facto architecture for most Natural Language Processing (NLP) tasks. Self-supervised pretraining of massive language models like BERT [2] and GPT [3] has allowed practitioners to use these large language models with little or no finetuning to various downstream tasks. Multitask learning (MTL) in NLP has been a very promising approach and has shown to lead to performance gains even over task specific fine-tuned models [4]. However, applying these large pre-trained Transformer models to downstream medical NLP tasks is still quite challenging. Medical NLP has its unique challenges ranging from domain specific corpora, noisy annotation labels and scarcity of high quality labeled data. In spite of these challenges, a number of different groups have successfully finetuned these large language models for various medical NLP tasks. However, there is not much literature that uses multi-task learning in medical NLP to classify and extract diagnoses in clinical text [5]. Moreover, there is almost no work in predicting spine pathologies from radiologists’ notes [6]. In this article, we are interested in extracting information from radiologists’ notes on the cervical spine. In a given note, the radiologist discusses the specific, often multiple pathologies present on medical images and usually grade their severity. Extracting pathology information from a cervical spine report can facilitate the creation of structured databases that can be used for a number of downstream use-cases, such as cohort creation, quality assessment and outcome tracking. We focus on four of the most common pathologies in the cervical spine—central canal and foraminal stenosis, disc herniation and cord compression. Next, we create multiple tasks on a given report, where each task is to predict the severity of a pathology for each motion segment—the smallest physiological motion unit of the spinal cord [7]. Breaking information down to the motion segment level enables any pathological findings to be correlated with clinical exam findings and could feasibly inform future treatment interventions. Given the semantic similarities between pathologies and the co-occurrence of multiple pathologies in a given sentence, we believe that these tasks are similar. Thus it is tempting to ask whether MTL approach can match the performance of task specific models, since it will cut down on the hardware requirements for training and will be faster during inference time. A number of different approaches have looked at task similarity and semantics to understand what tasks need to be grouped together or the conditions required for MTL to succeed [8–10]. Inspired by the theoretical results in [11], we hypothesize that if the Wasserstein distance between tasks is small, a single multitasking model can match the performance of task specific models.
In this work, we (a) design a novel pipeline to extract and predict the severity of various pathologies at each motion segment, (b) quantify the notion of task similarity by computing the Wasserstein distance between tasks, and (c) show how to leverage that information into a simple MTL framework allowing us to achieve significant model compression during deployment and also speed up our inference without sacrificing the accuracy of our predictions.
2 Datasets

We use our internal dataset consisting of radiologists' MRI reports on the cervical spine. The data consists of 1578 reports coming from 97 different radiology practices, detailing various pathologies of the cervical spine. Our dataset is heterogeneous and is diversely sampled from a large number of different radiology practices and medical institutions. We annotate the data with the following 4 pathologies: spinal stenosis, disc herniation, cord compression and neural foraminal stenosis. Each of these pathologies is bucketed into severity categories. For central canal stenosis, the 3 categories are based on gradation: none/mild is not clinically significant, while the moderate and severe definitions involve cord compression or flattening; the moderate versus severe gradation refers to the varying degrees of cord involvement. For disc herniation, like central canal stenosis, the categories are based on a continuous spectrum, and it is standard practice in radiology for any continuous spectrum to be bucketed into mild, moderate and severe discrete categories. For cord compression, it is a binary classification problem: compression/signal change versus none, because either cord compression or signal change can cause symptoms and is therefore clinically relevant. For foraminal stenosis, we are also only interested in binary classification: severe versus non-severe, as severe foraminal stenosis may indicate nerve impingement, which is clinically significant. The splits and the details of each category can be found in Table 1. The data distribution is highly imbalanced, and about 20% of these reports are OCR-ed, which leads to additional challenges stemming from OCR errors. An example of a cervical report can be found in Fig. 1.
Table 1 Statistics of our dataset

| Split | Stenosis | Disc | Cord | Foraminal |
|---|---|---|---|---|
| Train | None/Mild: 5488, Moderate: 561, Severe: 178 | None/Mild: 2731, Moderate: 2699, Severe: 797 | None: 5702, Mild/Severe: 525 | None: 5262, Severe: 965 |
| Test | None/Mild: 793, Moderate: 68, Severe: 19 | None/Mild: 401, Moderate: 378, Severe: 101 | None: 806, Mild/Severe: 74 | None: 789, Severe: 91 |
Fig. 1 Example of our dataset
3 Description of the Workflow

In this section, we briefly describe our novel pipeline. The reports are first de-identified according to HIPAA regulations, and the Spacy [12] parser is used to break each report into sentences. Then, each sentence is tagged by annotators and given labels for the various pathologies and severities it mentions. For example, in the above report, the sentence "C1-C2: No significant neuroforaminal or spinal canal narrowing" will be given the normal (0) class for each of the 4 pathologies. A BERT based NER model, which we call the report segmenter, is then used to identify the motion segment(s) present in a particular sentence, and all the sentences containing a particular motion segment are concatenated together. The BERT based NER model achieves an F1 score of 0.9. More details about the NER model and the hyperparameters used to train it can be found in Appendix A. All pathologies are predicted using the concatenated text for a particular motion segment. Finally, the severities for each pathology are modeled as a multi-label classification problem, and a pre-trained transformer is finetuned using the text for each motion segment. For more details about our pipeline and data processing, please see Appendix B. Figure 2 breaks down how a report looks as it goes through our pipeline.
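A minimal sketch of the sentence-splitting and motion-segment grouping steps is shown below. Note that the actual pipeline uses a BERT-based NER model plus rules to tag segments; the regular expression here merely stands in for that step, and the patterns, normalisation rules, and function names are our own assumptions, not the authors' released code.

```python
import re
import spacy
from collections import defaultdict

nlp = spacy.load("en_core_web_sm")          # assumption: a standard English pipeline

# Canonical motion segments of interest (see Appendix B).
SEGMENTS = ["C2-C3", "C3-C4", "C4-C5", "C5-C6", "C6-C7", "C7-T1"]

def _pattern(segment):
    # "C3-C4" should also match variants such as "C34", "C3C4", "C3_C4".
    a, b = segment.split("-")
    return re.compile(rf"\b{a}[\s_\-]?{b[0]}?{b[1:]}\b", re.IGNORECASE)

PATTERNS = {seg: _pattern(seg) for seg in SEGMENTS}

def group_by_motion_segment(report_text):
    """Split a report into sentences and group them by motion segment."""
    grouped = defaultdict(list)
    for sent in nlp(report_text).sents:
        hits = [seg for seg, pat in PATTERNS.items() if pat.search(sent.text)]
        if hits:
            for seg in hits:
                grouped[seg].append(sent.text.strip())
        else:
            grouped["No motion segment found"].append(sent.text.strip())
    # Concatenate all sentences per segment before classification.
    return {seg: " ".join(sents) for seg, sents in grouped.items()}
```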
Fig. 2 Figure showing how a report looks as it goes through our pipeline
4 Methods

Two types of multitasking models are assessed: (a) a multitasking BERT model and (b) a multitasking adapter augmented BERT model. We consider Adapters in our experiments as they provide a simple way to quickly train these transformer models on a limited computational budget and on smaller datasets. For our classification tasks, the Clinical BERT model [14] is used as a backbone. The Clinical BERT model is finetuned on the above tasks, resulting in 4 task-specific BERT sequence classifier models which provide our baseline results. Now, instead of finetuning 4 BERT-based models, 4 classifier heads (i.e. 4 linear layers) are applied to a single Clinical BERT model to create an output layer of shape [3, 3, 2, 2], where the first 3 outputs correspond to the logits for the stenosis severity prediction, the next 3 to the disc severity, the next 2 to the cord severity, and the final 2 logits to the foraminal severity. A dropout of 0.5 is applied to the BERT vectors before passing them to the classifier layers. Each of these classifier heads is trained with a cross entropy loss between the predicted logits and the ground truth targets. All the losses are added as in the equation below, which allows the gradients to back-propagate through the whole model and train these classifier heads jointly:

L = l_stenosis + l_disc + l_cord + l_foraminal

where l_pathology is the cross-entropy between the predicted logits for the given pathology severity and the ground truth labels. Finetuning these large transformer models is expensive, and sometimes they do not show much improvement when there is a lack of training data. To alleviate these problems, [13, 15] introduce a novel parameter efficient transfer and multitask learning technique by adding small networks called Adapters in between various Transformer blocks. Adapter modules perform more general architectural modifications to re-purpose a pretrained network for a downstream task. In standard fine-tuning, the new top layer and the original weights are updated. In contrast, in Adapter tuning, the parameters of the original network are frozen and therefore may be shared by many tasks. Given the success of adapters for MTL, we experiment by adding Adapters as described in [16] in between the BERT output layers. When training with Adapters, the BERT weights are frozen. For multitasking, [16] proposed splitting and fusion adapters to prevent catastrophic forgetting across tasks. In our work, however, the same adapters are shared across all tasks as in [17], and unlike the above work, our results match those obtained with task-specific adapters or with fusion adapters for multitask learning. Our implementation follows the one outlined in [13], and thus our work can be seen as a simplification of the training strategies proposed above without sacrificing accuracy or speed. We conjecture that the tasks are similar and the sentences across these tasks have similar structure and semantic meaning, which allows these multitask models to perform well without any need for task specific architectures. These hypotheses are validated by computing Wasserstein distances between various tasks in Sect. 6. PyTorch and the Hugging Face Library [18]
are used to train our models on an NVIDIA V100 16 GB GPU, and the POT library [19] is used to compute the Wasserstein distances. More training details can be found in Appendix A.
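A minimal sketch of the multitask classifier described above is given below. The class name, the use of the [CLS] pooled vector, and the Hugging Face model identifier are our own assumptions for illustration; the essential points taken from the text are the four linear heads on one Clinical BERT encoder, the 0.5 dropout, and the summed cross-entropy losses.

```python
import torch
import torch.nn as nn
from transformers import AutoModel

class MultiTaskClinicalBERT(nn.Module):
    """One shared encoder, four severity heads (stenosis, disc, cord, foraminal)."""

    def __init__(self, backbone="emilyalsentzer/Bio_ClinicalBERT"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone)
        self.dropout = nn.Dropout(0.5)
        hidden = self.encoder.config.hidden_size
        # Output layer of shape [3, 3, 2, 2]: one linear head per pathology.
        self.heads = nn.ModuleList([nn.Linear(hidden, n) for n in (3, 3, 2, 2)])

    def forward(self, input_ids, attention_mask):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = self.dropout(out.last_hidden_state[:, 0])   # [CLS] token vector
        return [head(cls) for head in self.heads]

def multitask_loss(logits_per_task, labels_per_task):
    """L = l_stenosis + l_disc + l_cord + l_foraminal (summed cross-entropies)."""
    ce = nn.CrossEntropyLoss()
    return sum(ce(logits, labels)
               for logits, labels in zip(logits_per_task, labels_per_task))
```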
5 Results In this section, we validate our multitasking models on our cervical dataset. For detailed comparison, we also experiment with our multitasking and Adapter based models starting with the weights of BERT-base. Table 2 shows the results of our multitasking models over our baseline models and the adapter augmented models. Models which are initialized with the weights of Clinical BERT show an improvement over the corresponding models initialized with the weights of BERT base. Moreover, our results show that the multitasking models perform as well as the task specific models. In fact, our results with the Fusion adapter modules show that mixing information from various tasks can actually improve model performance. Finally, Table 3 shows significant improvements in inference speeds on our test set of the multitasking models over the baseline single taskers.
Table 2 Table showing the macro F1 scores over 5 trials of our baseline and multitasking models

| Backbone | Model | Stenosis | Disc | Cord | Foraminal |
|---|---|---|---|---|---|
| BERT BASE | Baseline (single tasker) | 0.62 ± 0.03 | 0.64 ± 0.03 | 0.70 ± 0.03 | 0.79 ± 0.03 |
| BERT BASE | MultiTasking | 0.62 ± 0.02 | 0.65 ± 0.03 | 0.72 ± 0.02 | 0.78 ± 0.01 |
| BERT BASE | Task specific adapters | 0.63 ± 0.01 | 0.64 ± 0.03 | 0.68 ± 0.02 | 0.79 ± 0.02 |
| BERT BASE | Multitasking adapter | 0.65 ± 0.02 | 0.66 ± 0.03 | 0.70 ± 0.04 | 0.80 ± 0.03 |
| BERT BASE | Fusion adapters [13] | 0.64 ± 0.03 | 0.66 ± 0.01 | 0.70 ± 0.03 | 0.79 ± 0.03 |
| CLINICAL BERT | Baseline (single tasker) | 0.64 ± 0.05 | 0.66 ± 0.02 | 0.71 ± 0.02 | 0.82 ± 0.01 |
| CLINICAL BERT | MultiTasking | 0.63 ± 0.02 | 0.67 ± 0.01 | 0.75 ± 0.01 | 0.79 ± 0.03 |
| CLINICAL BERT | Task specific adapters | 0.66 ± 0.01 | 0.65 ± 0.03 | 0.69 ± 0.03 | 0.81 ± 0.01 |
| CLINICAL BERT | Multitasking adapter | 0.66 ± 0.01 | 0.67 ± 0.03 | 0.71 ± 0.03 | 0.81 ± 0.02 |
| CLINICAL BERT | Fusion adapters [13] | 0.65 ± 0.03 | 0.67 ± 0.01 | 0.72 ± 0.02 | 0.81 ± 0.02 |
Table 3 Table showing faster inference speed of our multitasking models over the baseline models

| | Baseline clinical BERT (single tasker) | Multitasking clinical BERT | Baseline clinical BERT with task specific adapters | Multitasking adapter augmented clinical BERT |
|---|---|---|---|---|
| Walltime (seconds) | 259.64 | 56.93 | 281.56 | 60.16 |
6 Empirical Evidence Behind MultiTasking Models

In this section, we provide some evidence behind the performance of our MultiTasking models. There is an explicit relationship between the Wasserstein distances and the generalization error in MTL, as proposed in Theorems 1 and 2 in [11]. Moreover, motivated by the work of [20] on computing distances between labeled datasets, we define a task T_y(X) := P(X|Y = y) as a conditional distribution. We then define the distance between tasks to be:

d((X, y), (X', y')) := W_2(T_y(X), T_{y'}(X'))    (1)

where W_2 is the 2-Wasserstein distance. This conditional distribution also appears in the work of [21]. Computing Wasserstein distances is extremely computationally expensive. Thus, various authors approximate W_2 by the Wasserstein-Bures metric or by various entropic regularized Sinkhorn divergences [20, 22]. Here, the W_2 metric is approximated by the sliced Wasserstein distance [23] with 60 random projections of dimensions in logspace between 1 and 4. To apply the sliced Wasserstein distance, the embeddings from the final BERT layer of our pretrained task specific BERT models are extracted, i.e. X is the 768-dimensional vector representation of the [CLS] token coming from the appropriate BERT model. The sliced Wasserstein distances are computed between these BERT embeddings. Table 4 shows
Table 4 Table showing the Sliced Wasserstein Distance between Tasks on the Training Set. Some of the Wasserstein distances are not shown as they cannot be computed owing to the sample size of some of our minority classes

| Task | Mild stenosis | Severe disc | Mild disc | Mild/Severe foraminal | Mild/Severe cord |
|---|---|---|---|---|---|
| Mild stenosis | 0 | 0.7 ± 0.4 | 0.3 ± 0.2 | 1.2 ± 0.3 | 0.7 ± 0.6 |
| Severe disc | 0.7 ± 0.4 | 0 | 0.2 ± 0.1 | 0.8 ± 0.6 | 0.6 ± 0.5 |
| Mild disc | 0.3 ± 0.2 | 0.2 ± 0.1 | 0 | 0.7 ± 0.5 | 1.1 ± 0.7 |
| Mild/Severe foraminal | 1.2 ± 0.3 | 0.8 ± 0.6 | 0.7 ± 0.5 | 0 | 0.8 ± 0.7 |
| Mild/Severe cord | 0.7 ± 0.6 | 0.6 ± 0.5 | 1.1 ± 0.7 | 0.8 ± 0.7 | 0 |
the sliced Wasserstein distances between various conditional distributions. An upper bound for the Wasserstein distance between two probability measures is given by

W_2(T_y(X), T_{y'}(X')) ≤ diam(A) · TV(T_y(X), T_{y'}(X'))

where diam(A) is the diameter of the support of the measures, which in our case can be bounded by 59.4, similar to the values reported in [24], and TV(T_y(X), T_{y'}(X')) is the total variation, which can be trivially bounded by 1. The relatively small distances between tasks (which are also considerably lower than the upper bound) are most likely why a multitask model is able to replicate the performance of task specific models. The bound given by the above equation is not tight, so we provide some empirical analysis of the range of the Wasserstein distances by using the methods of [20] on some public text classification datasets. This analysis can be found in Appendix C.
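A sketch of the task-distance computation with the POT library is shown below. It assumes the per-task [CLS] embeddings have already been extracted as NumPy arrays; the function name and projection settings mirror the description in the text (sliced Wasserstein with random projections), but the exact call pattern is our assumption rather than the authors' code.

```python
import numpy as np
import ot  # POT: Python Optimal Transport

def task_distance(emb_a, emb_b, n_projections=60, seed=0):
    """Approximate W_2 between two tasks' conditional distributions.

    emb_a, emb_b: arrays of shape (n_samples, 768), [CLS] embeddings of the
    samples belonging to class y of one task and class y' of another.
    """
    return ot.sliced_wasserstein_distance(
        emb_a, emb_b, n_projections=n_projections, seed=seed
    )

# Example usage (hypothetical variable names):
# d = task_distance(cls_embeddings["mild_stenosis"], cls_embeddings["mild_disc"])
```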
7 Conclusion In this work, a simple multitasking model is presented that is competitive with task specific models and is 4 times faster during inference time. Instead of training and deploying 4 models, only one model is trained and deployed, thus achieving significant model compression. This work opens the possibility of using multitasking models to extract information over various different body parts and thus allows users to leverage these large transformer models using limited resources. Our novel pipeline is one of the very few works that attempts to extract pathologies and their severities from a heterogeneous source of radiologists’ notes on cervical spine MRIs at the level of motion segments. In these ways, our findings suggest that our approach may not only be more widely generalizable and applicable but also more clinically actionable. Finally, we also shed light on how closely related or semantically similar these tasks are. In our future work, we will expand on our observations to the radiologist’s reports for other body parts and other pathologies for the cervical spine. We will also focus on further characterizing medical NLP datasets and tasks using our definition of task similarity so we can define when learning can be cooperative and when learning is competitive and whether our definition of task similarity has any clinical significance.
Appendix A Training Details

We create a validation set using 10% of the samples of the training set, where the samples are drawn via stratified sampling so the data distribution is maintained across splits. For finetuning the BERT models (BERT-Base and Clinical BERT) or the multitasking model, we finetune the whole model with a batch size of 16 for 3−5
epochs with the BERT Adam optimizer, which is basically a weight-decoupled Adam optimizer [25]. The learning rate used is 1e−5 with a linear learning rate decay scheduler, and the weight decay is 1e−4. For training the Adapter augmented models, we use a higher learning rate of 1e−4. We use the Adapter architecture as defined in [13]. Adapters are 2-layer feedforward networks with a bottleneck dimension of 48. We initialize the adapter weights such that initially the whole Adapter layer is almost an identity function; thus one can think of these adapter architectures as autoencoder architectures. We experimented with both GELU and ReLU as non-linearities between the feedforward layers and the results were similar. We train our adapter augmented models for about 10−12 epochs with early stopping on the validation loss. For Adapter augmented models, the BERT weights were frozen and only the adapter and classifier weights were updated. All the baseline models were finetuned for 15 epochs and the multitasking model was finetuned for 12 epochs. The Adapter based models were trained for 20 epochs and the Multitasking Adapter model was trained for 23 epochs. The sequence length used for all the baseline and multitasking classifier models is 512. The NER model is a BERT-based binary classifier (Location Tag versus Other Tag). It is our in-house model, trained on both lumbar and cervical MRI reports (about 6000 reports), that predicts the location tags in those reports. The model is trained for 5 epochs with a batch size of 16 and sequence length 256. We used the AdamW optimizer with a weight decay of 1e−4. The learning rate used is 1e−5 with a linear learning rate decay scheduler. Our NER model achieves an F1 score of 0.9.
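The optimizer and scheduler configuration described above corresponds to the standard Hugging Face setup sketched below; the number of training steps and the zero-warmup setting are placeholders we chose for illustration, not values stated in the text.

```python
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

def build_optimizer(model, lr=1e-5, weight_decay=1e-4, num_training_steps=1000):
    """AdamW (decoupled weight decay) with a linear learning-rate decay schedule."""
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
    )
    return optimizer, scheduler
```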
Appendix B Detailed Description of Our Workflow

In this section, we give a more detailed description of our workflow. Our main goal is to detect pathologies at the motion segment level from radiologists' MRI reports on the cervical spine. The motion segments we care about in our work are C2-C3, C3-C4, C4-C5, C5-C6, C6-C7 and C7-T1. We first make sure that the reports are de-identified and then use a Spacy [12] parser to break each report into sentences. Each sentence is then tagged by annotators and given labels for the various pathologies and their severities if the sentence mentions that pathology. To detect pathologies at a motion segment level, we use our BERT based NER system to tag the locations present in each sentence. Our tag of interest for the NER model is the motion segment tag, and we achieve an F1 score of 0.9 for that tag. We then use a rule based system to group all sentences to the correct motion segment. If a sentence does not explicitly mention a motion segment, we use a rule based method to assign the sentence to one of the above mentioned motion segments or to a generic category "No motion segments found". Given the disparate sources of our data, for example, C34, C3-C4 and C3_C4 all refer to the motion segment C3-C4, and thus our systems are mindful of this diversity in the clinical notes. Finally, to use our BERT based models for pathology detection at the level of motion segments for a given
Fig. 3 Workflow of our approach
report, we concatenate all sentences for a given motion segment and use the [CLS] token for the segment that is used for the downstream classification task. Figure 3 shows our workflow. Since we are interested in predictions at the motion segment level, we do not use the sentences that are grouped under “No segment found” to train the classifier models, nor do we evaluate our classifier models on those sentences.
Appendix C Bounds on Wasserstein Distances Between Text Datasets

The upper bound for the Wasserstein distance derived in our paper is not tight. We are unable to provide a tighter bound without additional hypotheses on the empirical distributions considered in our work. Instead, we provide clarity on how large these numbers can be between various publicly available datasets. The following comparisons are inspired by the results described in Fig. 8 in [20]. The distance between labeled datasets as described in the above paper has two components: (1) the Euclidean distance between the feature vectors and (2) the Wasserstein distance between the conditional distributions (also described in [21]). Disentangling the contribution of the feature vectors following the method outlined in their appendix, we find the Wasserstein distance between positive labels in the Yelp binary polarity dataset and the "Educational Institution" class in DbPedia-14 to be as large as 37.67, and distances as low as 0.64 between the positive classes of the Amazon reviews binary polarity and Yelp binary polarity datasets. We hope that these numbers shed some light on the possible range of the Wasserstein distances between some benchmark text classification datasets. However, given that the SOTA performances of BERT on these datasets are almost perfect, and given the possibility of DbPedia leaking into BERT's training data (DbPedia is crawled from Wikipedia), an MTL experiment on these datasets may not justify our hypothesis.
References 1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. In Proceedings of the 31st International Conference on Neural Information Processing Systems, NIPS’17, pp. 6000–6010. Red Hook, NY, USA: Curran Associates Inc. 2. Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2019). BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pp. 4171–4186. Minneapolis, Minnesota: Association for Computational Linguistics. 3. Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J. D., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D. M., Wu, J., Winter, C., Hesse, C., Chen, M., Sigler, E., Litwin, M., Gray, S., Chess, B., Clark, J., Berner, C., McCandlish, S., Radford, A., Sutskever, I., & Dario Amodei. Language Models are Few-Shot Learners. 4. Worsham, Joseph, & Kalita, Jugal. (2020). Multi-task learning for natural language processing in the 2020s: Where are we going? Pattern Recognition Letters, 136, 120–126 Aug. 5. Peng, Y., Chen, Q., & Lu, Z. (2020). An empirical study of multi-task learning on BERT for biomedical text mining. 6. Azimi, P., Yazdanian, T., Benzel, E.C., Aghaei, H.N., Azhari, S., Sadeghi, S. and Montazeri, A. (2020). A review on the use of artificial intelligence in spinal diseases. Asian Spine J, 14(4), 543–571. 7. Swartz, E. E., Floyd, R. T., & Cendoma, M. (2005). Cervical spine functional anatomy and the biomechanics of injury due to compressive loading. Journal of Athletic Training, 40(3), 155–161. 8. Bingel, J., & Søgaard, A. (2017). Identifying beneficial task relations for multi-task learning in deep neural networks. In Proceedings of the 15th Conference of the European Chapter of the Association for Computational Linguistics: Volume 2, Short Papers (pp. 164–169), Valencia, Spain: Association for Computational Linguistics. 9. Zamir, A. R., Sax, A., Cheerla, N., Suri, R., Cao, Z., Malik, J., & Guibas, L. J. (2020). Robust learning through cross-task consistency. 10. Standley, T., Zamir, A., Chen, D., Guibas, L., Malik, J., & Savarese, S. (2020). Which tasks should be learned together in multi-task learning?. 11. Shui, C., Abbasi, M., Robitaille, L.-É, Wang, B., & Gagné, C. (2019). A principled approach for learning task similarity in multitask learning. In Proceedings of the 28th International Joint Conference on Artificial Intelligence, IJCAI’19 (pp. 3446–3452). AAAI Press. 12. Honnibal, M., Montani, I., Van Landeghem, S., & Boyd, A. (2020). spaCy: Industrial-strength natural language processing in python. 13. Pfeiffer, J., Rücklé, A., Poth, C., Kamath, A., Vuli´c, I., Ruder, S., Cho, K., & Gurevych, I. (2020). Adapterhub: A framework for adapting transformers. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations (pp. 46–54). 14. Alsentzer, E., Murphy, J. R., Boag, W., Weng, W. H., Jin, D., Naumann, T., & McDermott, M. (2019). Publicly available clinical BERT embeddings 15. Houlsby, N., Giurgiu, A., Jastrzebski, S., Morrone, B., De Laroussilhe, Q., Gesmundo, A., Attariyan, M., & Gelly, S. (2019). Parameter-efficient transfer learning for NLP. 16. 
Pfeiffer, J., Kamath, A., Rücklé, A., Cho, K., & Gurevych, I. (2021). AdapterFusion: Nondestructive task composition for transfer learning. 17. Stickland, A. C., & Murray, I. (2019). BERT and PALs: Projected attention layers for efficient adaptation in multi-task learning 18. Wolf, T., Debut, L., Sanh, V., Chaumond, J., Delangue, C., Moi, A., Cistac, P., Rault, T., Louf, R., Funtowicz, M., Davison, J., Shleifer, S., von Platen, P., Ma, C., Jernite, Y., Plu, J.,
Xu, C., Le Scao, T., Gugger, S., Drame, M., Lhoest, Q., & Rush, A. M. Transformers: State-of-the-art natural language processing. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations (pp. 38–45). Association for Computational Linguistics.
19. Flamary, R., Courty, N., Gramfort, A., Alaya, M. Z., Boisbunon, A., Chambon, S., Chapel, L., Corenflos, A., Fatras, K., Fournier, N., Gautheron, L., Gayraud, N. T. H., Janati, H., Rakotomamonjy, A., Redko, I., Rolet, A., Schutz, A., Seguy, V., Sutherland, D. J., Tavenard, R., Tong, A., & Vayer, T. (2021). POT: Python optimal transport. Journal of Machine Learning Research, 22(78), 1–8.
20. Alvarez-Melis, D., & Fusi, N. (2020). Geometric dataset distances via optimal transport.
21. Courty, N., Flamary, R., Habrard, A., & Rakotomamonjy, A. (2017). Joint distribution optimal transportation for domain adaptation.
22. Chizat, L., Roussillon, P., Léger, F., Vialard, F. X., & Peyré, G. (2020). Faster Wasserstein distance estimation with the Sinkhorn divergence.
23. Kolouri, S., Nadjahi, K., Simsekli, U., Badeau, R., & Rohde, G. (2019). Generalized sliced Wasserstein distances.
24. Kobayashi, G., Kuribayashi, T., Yokoi, S., & Inui, K. (2020). Attention is not only a weight: Analyzing transformers with vector norms.
25. Loshchilov, I., & Hutter, F. (2019). Decoupled weight decay regularization.
Benchmarking Uncertainty Quantification on Biosignal Classification Tasks Under Dataset Shift Tong Xia, Jing Han, and Cecilia Mascolo
Abstract A biosignal is a signal that can be continuously measured from human bodies, such as respiratory sounds, heart activity (ECG), brain waves (EEG), etc., based on which, machine learning models have been developed with very promising performance for automatic disease detection and health status monitoring. However, dataset shift, i.e., data distribution of inference varies from the distribution of the training, is not uncommon for real biosignal-based applications. To improve the robustness, probabilistic models with uncertainty quantification are adapted to capture how reliable a prediction is. Yet, assessing the quality of the estimated uncertainty remains a challenge. In this work, we propose a framework to evaluate the capability of the estimated uncertainty in capturing different types of biosignal dataset shifts with various degrees. In particular, we use three classification tasks based on respiratory sounds and electrocardiography signals to benchmark five representative uncertainty quantification methods. Extensive experiments show that, although Ensemble and Bayesian models could provide relatively better uncertainty estimations under dataset shifts, all tested models fail to meet the promise in trustworthy prediction and model calibration. Our work serves as a benchmark evaluation for any future biosignal classifiers without requiring additional datasets. Keywords Machine learning for health · Biosignal classification · Uncertainty quantification · Dataset shift
T. Xia (B) · J. Han · C. Mascolo Department of Computer Science and Technology, University of Cambridge, Cambridge, UK e-mail: [email protected] J. Han e-mail: [email protected] C. Mascolo e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_25
1 Introduction

The growth of commercial wearables and the ubiquity of smartphones with numerous sensors have enabled multi-modal, affordable, non-invasive, round-the-clock biosignal collection, based on which deep learning facilitates a wide spectrum of mental and physical health applications. Despite the impressive performance achieved by deep learning, the underlying premise is that training and test data come from the same distribution. Yet, this assumption often does not hold in practice, because dataset shift caused by user variability, device discrepancy, artifacts, and other factors is ineluctable during in-the-wild biosignal acquisition [18]. Most existing deep learning models cannot flag this distributional shift and tend to be over-confident during inference, which may, in the long run, undermine people's trust in applying deep learning for healthcare. In the field of machine learning, uncertainty has been used as a measurement of how much a deep neural network can be trusted [7]. Different types and sources of uncertainties have been identified and a variety of approaches to quantify uncertainty in neural networks have been proposed. Given the importance of risk management for safety-critical health applications, uncertainty is recognised to be helpful, because it can not only inform the confidence of the prediction but also provide the opportunity to keep doctors in the loop to correct potentially wrong automatic predictions. Leibig et al. found that in the task of diagnosing diabetic retinopathy from fundus images of the eye, incorrect predictions usually yielded higher uncertainty than correct ones: by excluding the most uncertain predictions, the automatic diagnosis accuracy could be improved from 87 to 96% [14]. In spite of the efforts to estimate and understand the uncertainty of health-related deep learning models, key aspects are still under-explored. First, most of the previous work focuses on clinical images [7], while how uncertainty behaves on other physical health signals, e.g., biosignals from wearables, remains unclear. Moreover, the quality of uncertainty is mainly assessed on independent and identically distributed (i.i.d.) testing sets, so it is not explicit how trustworthy the existing uncertainty estimation methods are in the distributionally shifted regime. To answer these questions, in this paper, we conduct a comprehensive evaluation across five uncertainty quantification methods on three representative biosignal classification tasks under controlled dataset shift. The implemented uncertainty quantification methods cover Bayesian neural network, approximate Bayesian neural network, calibration and deep ensemble approaches, and the tasks include audio-based COVID-19 prediction, breathing-based respiratory abnormality prediction, and ECG-based heart arrhythmia detection applications. To assess the quantified uncertainty, we propose a framework for analysing the above methods under dataset shift without requiring the collection of new datasets. Specifically, the key mechanism is to empirically synthesise signal-specific distributional shift according to real signal data collection scenarios, so that both the shift type and the degree can be controlled and the evaluation framework can be generalised to any biosignal task.
Consequently, we find that the widely used uncertainty estimation approaches fail to yield well-calibrated uncertainty under dataset shift, and we draw attention to better methods for safety-critical health applications.
2 Related Work Uncertainty is recognised as a measurement of model’s trust and its importance has been widely discussed in the literature, particularly for deep learning-enabled health applications [2, 16]. The quantified uncertainty can be used for selective prediction: keeping low-uncertain outputs but referring high uncertain (unsafe) predictions to doctors, which allows clinicians in the loop and improves the system robustness. Diagnosing diabetic retinopathy from fundus images of the eye is a commonly used task to assess the estimated uncertainty [14, 19, 21, 22], and uncertainty-aware lung disease detection from X-rays also gains massive attention [8]. To the best of our knowledge, only a few works explored uncertainty for biosignal [23]. Moreover, all those works demonstrate the quality of uncertainty in the testing set which was selected from the same distribution with training data, while the effectiveness remains unclear under dataset shift. As dataset shift is now being highlighted frequently by the research community, there have been some attempts [1, 17, 20] to assess the estimated uncertainty (none of them is for biosignals): In [17], a comparison of multiple uncertainty estimation methods was conducted on image, context, and ad-click datasets, discovering that along with accuracy, the quality of uncertainty consistently degrades with increasing dataset shift regardless of method. Reference [20] proved uncertainty estimation does not enable reliable out-of-distribution detection on medical tabular data. Recently, [1] proposed a benchmark to evaluate Bayesian deep learning on diabetic retinopathy detection task, where two human retina image datasets with country shift and severity shift are constructed for use. Inspired by but different from the aforementioned works, in this paper, we aim to close the gap between uncertainty quantification and biosignal dataset shift, towards more realistic performance evaluation and more reliable healthcare deployment.
3 Uncertainty Quantification Approaches

For a model θ, the output probability p_θ from the Softmax layer may indicate the confidence of the prediction to some degree. However, it tends to overestimate the confidence and requires further calibration [17]. Only well-calibrated uncertainty is useful for telling to what degree the model is certain about its predictions. Two distinct types of uncertainty exist: aleatoric uncertainty stems from stochastic variability inherent in the data generating process (also known as data uncertainty), while epistemic uncertainty arises due to our lack of knowledge about the data
350
T. Xia et al.
generating mechanism. More specifically, epistemic uncertainty is associated with the model structure and parameters (also known as model uncertainty), and with the systematic discrepancy between the training and testing data (distributional shift) [15]. Existing uncertainty quantification methods mainly include the following categories.
Bayesian Methods. Bayesian methods explicitly define a combination of infinitely many plausible models according to a learnable prior distribution estimated from the observed data. With the posterior distribution p(θ|D), for a new test data point x', the predictive posterior distribution can be obtained in the continuous space.
Monte Carlo Dropout. Dropout is a widely used technique during training to tackle overfitting. Gal et al. [5] leveraged a variational distribution to view dropout in the forward pass as approximate Bayesian inference. With dropout kept active during inference, the predictive probability can be computed through the randomly sampled models and treated as an estimation of the uncertainty.
Ensemble Methods. Ensembles, also known as frequentist methods, do not specify a prior distribution over parameters. Instead of learning infinitely many models, ensemble approaches only require a limited number of models, which is computationally tractable [6]. Herein, the final predictive probability for each instance x' is estimated by simply averaging the outputs over all M models.
Model Calibration. The probability that a model outputs should reflect the true correctness likelihood, yet most modern neural networks are poorly calibrated. Post-hoc probability calibration from [11], temperature scaling, is effective: a single parameter is used to re-scale the logits before passing them into the softmax.
To sum up, we select five methods for uncertainty estimation considering their prevalence, scalability, and practical applicability [17]. They are:
• Vanilla: Maximum softmax probability from the deterministic backbone model.
• Scaling: Post-hoc calibration of vanilla probabilities by temperature scaling parameterized by a value T [11].
• MCDropout: Monte-Carlo Dropout with a dropout rate of p during inference [5]. A sample is fed into the model M times to quantify the model uncertainty.
• Bayesian: Stochastic variational Bayesian inference [10] with Gaussian priors.
• Ensemble: Ensembles of several networks with identical structures, trained with random initialisation [13].
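As a concrete illustration of the MCDropout variant listed above, the sketch below keeps dropout active at inference, draws M stochastic forward passes, and averages the softmax outputs; the model interface and the value of M are assumptions for illustration, not the authors' implementation.

```python
import torch
import torch.nn.functional as F

def enable_dropout(model):
    """Keep only dropout layers stochastic while the rest of the model stays in eval mode."""
    model.eval()
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.train()

@torch.no_grad()
def mc_dropout_predict(model, x, m_samples=20):
    """Monte-Carlo Dropout: average softmax over M stochastic forward passes."""
    enable_dropout(model)
    probs = torch.stack(
        [F.softmax(model(x), dim=-1) for _ in range(m_samples)]
    )                                   # (M, batch, num_classes)
    mean_probs = probs.mean(dim=0)
    # Predictive entropy of the averaged distribution (aleatoric + epistemic).
    entropy = -(mean_probs * torch.log(mean_probs + 1e-12)).sum(dim=-1)
    return mean_probs, entropy
```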
4 Benchmark Tasks and Experiments

4.1 Biosignal Classification Tasks

From the literature, we select three representative tasks with different biosignals to investigate whether the estimated uncertainty works when dataset shift occurs. Table 1 presents a summary of these tasks.
Table 1 A summary of the tasks, datasets, models, and baseline performance

| Task | Dataset (size) | Classes | Backbone | Accuracy |
|---|---|---|---|---|
| COVID-19 | Audio dataset (1,000 users) | COVID-19 positive/negative | VGGish | 0.68 |
| Respiratory | Breathing recordings (1,990 clips) | Normal/abnormal | ResNet | 0.75 |
| Arrhythmia | ECG (8,224 recordings) | Normal/AF (atrial fibrillation)/others | Transformer | 0.83 |
COVID-19 prediction. Sound-based COVID-19 prediction has shown great promise. We choose to implement the VGGish model proposed in [12] for its available code and dataset. This is a binary classification task, where cough, breathing, and voice sounds are transferred into spectrograms to distinguish COVID-19 positive from negative participants. Respiratory abnormality detection. Auscultation of the lung is a part of the clinical examination. It is vital to distinguish abnormal respiratory sounds from normal ones to enable the correct treatment. The ICBHI 1 provided a breathing sound database collected from heterogeneous auscultation equipment. A binary classification task is formulated to detect whether a breathing sound segment contains abnormalities, including crackle and wheeze. We use the backbone of a deep convolutional model (ResNet) proposed in [4] for its favourable performance. Heart arrhythmia detection. ECG, another type of widely used biosignal, involves the recording of electrical impulses generated by the heart muscle during beating activity. Through ECG, arrhythmia (irregular beat) can be identified. For the evaluation, we use the dataset from [9] and the developed transformer-based neural network in [3].2 The task is to predict, from a single short ECG lead recording (between 30 and 60 s in length), whether the recording shows normal sinus rhythm, atrial fibrillation (AF), or an alternative rhythm.
4.2 Dataset Shift and Evaluation Protocol To assess the quality of the predictive uncertainty yielded by different methods, we propose and apply the following perturbations covering major potential shifts in practice: • Mixing Gaussian noise. Gaussian noise is statistical noise having a probability density function equal to the normal distribution. This type of noise is common in various types of signals and can arise from acquisition (e.g., sensor noise) or 1 2
https://bhichallenge.med.auth.gr/. https://physionet.org/content/challenge-2017/1.0.0/sources/files-panel.
352
•
•
•
•
T. Xia et al.
transmission (e.g., electronic circuit noise). We add the generated Gaussian noise to raw biosignals by controlling the SNR (signal-to-noise ratio). Mixing background noise. Random background or environmental noise is another type of prevalent noise when collecting biosignals, particularly in audio signals. To simulate this, we mix pre-recorded TV news with raw audio signals according to different SNRs. Signal amplitude distortion. The amplitude distortion, also called clipping, is the result of “over-driving” the input of the amplifier, which is a part of the analogy signal acquisition circuit. For example, when a user speaks closed to the microphone loudly, the audio signal can be distorted with some peak value flattened in the waveform view. We synthesise the distortion by replacing the amplitude over a pre-defined threshold to the threshold value. Signal segment missing When signal acquisition or transmission is unstable, some segments can be missing, which leads to incomplete signals in the time domain. According to this, we manually mask a portion of the signal by setting the value to 0 within several masking blocks. Sampling rate mismatching. Physical biosignals are continuous (analogy), which need to be discretised and stored as digital signals by a given sampling rate for further utilisation. A very low sampling rate can lead to information loss. To simulate this, we randomly down-sample some frames in the raw biosignals for testing.
Since there is no ground truth for uncertainty, it is not straightforward to evaluate the quality of the uncertainty. Our proposed evaluation protocol is to add controllable perturbations to the original biosignals in the testing sets, and then compare the yielded predictive uncertainty under different shifting degrees. Overall, we define a shifting degree from 0 to 5 with 0 denoting the original testing set. For mixing Gaussian noise and background noise, shift degrees of 1, 2, 3, 4, and 5 indicate an SNR of 50, 40, 30, 20, 10, respectively. For amplitude distortion, threshold of 80, 60, 50, 20, 10% of the maximum amplitude are applied. For signal segment missing, we mask 20, 35, 50, 65, 80% of the raw signals, and for sampling rate mismatching, every 1/80, 1/50, 1/30, 1/20, 1/10 of the data points are evenly dropped. Examples for various types and degrees of dataset shift are given in Fig. 1.3
4.3 Metrics To assess a model’s performance, classification accuracy↑ (arrows indicating which direction is better) is often used, which only concerns its categorical output yˆ = argmax( p(y|x )) as a hard prediction. In addition, in this study, we explore the following metrics, where the predictive uncertainty is taken into account by exploiting the predictive probabilities instead of the hard predictions: 3
Fig. 1 Illustrations of synthesising different specific dataset shifts on biosignals. Blue signals show randomly selected original signals for the three tasks, respectively, and the corresponding yellow signals are the synthesised signals under different controlled perturbations with a given degree
Brier Score↓. The Brier score measures the distance between the one-hot labels and the predictive probabilities. It is computed as $\frac{1}{|D|} \sum_i (1 - p(y = y_i \mid x_i, \theta))^2$.
ECE↓. The expected calibration error (ECE) measures the correspondence between the predicted probabilities and the empirical accuracy. It is computed as the weighted average gap between within-bucket accuracy and probability. A bucket is $B_s = \{ p(y = y_i \mid x_i, \theta) \in (\rho_s, \rho_{s+1}) \}$, where $\rho_s$ denotes the quantiles of the predictive probabilities. Hence, $ECE = \sum_{s=1}^{S} \frac{|B_s|}{|D|} \, |acc(B_s) - conf(B_s)|$, with $acc(B_s)$ and $conf(B_s)$ denoting the accuracy and the average predictive probability within $B_s$.
Predictive entropy. For each testing sample, entropy describes the average level of information in a random variable. In this context, this variable is the predictive probability, and hence $H(p \mid \theta) = -\sum_{k=1}^{K} p(y = k \mid x_i, \theta) \log p(y = k \mid x_i, \theta)$ can capture the data uncertainty. For methods including MCDropout, Bayesian, and Ensemble, a testing sample is passed into the model M times, and the predictive entropy is formulated as $H(p) = -\sum_{k=1}^{K} \left( \frac{1}{M}\sum_{m=1}^{M} p(y = k \mid x_i, \theta_m) \right) \log \left( \frac{1}{M}\sum_{m=1}^{M} p(y = k \mid x_i, \theta_m) \right)$.
This captures both aleatoric and epistemic uncertainty. For ease of illustration, we will use the notation Uncertainty, which denotes the average predictive entropy across the whole testing set D.
What do we expect to see? Intuitively, on increasingly shifted data, a model's performance might degrade, reflected by a decrease in Accuracy and a rise in Brier. Moreover, ideally, this decrease in performance should coincide with an increase in Uncertainty. In particular, an increased uncertainty implies that the model becomes less and less confident in its predictions, and this would be a good indicator of potential dataset shifts during inference for real-world health applications. Meanwhile, we would expect a good model to remain well-calibrated under different dataset shifts, which is represented by a small and stable ECE.
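A minimal NumPy sketch of these metrics is given below. It assumes predicted class probabilities of shape (N, K), or (M, N, K) for the M stochastic forward passes, and it uses equal-width confidence buckets for the ECE, whereas the definition above buckets by quantiles.

```python
import numpy as np

def brier_score(probs, labels):
    """Mean squared gap between 1 and the predicted probability of the true class."""
    true_class_prob = probs[np.arange(len(labels)), labels]
    return np.mean((1.0 - true_class_prob) ** 2)

def expected_calibration_error(probs, labels, n_buckets=10):
    """Weighted average gap between per-bucket accuracy and mean confidence."""
    confidences = probs.max(axis=1)
    predictions = probs.argmax(axis=1)
    edges = np.linspace(0.0, 1.0, n_buckets + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bucket = (confidences > lo) & (confidences <= hi)
        if in_bucket.any():
            acc = np.mean(predictions[in_bucket] == labels[in_bucket])
            conf = np.mean(confidences[in_bucket])
            ece += in_bucket.mean() * abs(acc - conf)
    return ece

def predictive_entropy(probs_per_pass):
    """Entropy of the mean predictive distribution over M stochastic passes (shape M x N x K)."""
    mean_probs = probs_per_pass.mean(axis=0)
    return -np.sum(mean_probs * np.log(mean_probs + 1e-12), axis=1)
```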
5 Results and Analysis
Results for the COVID-19 prediction task. Experimental results for COVID-19 prediction are shown in Fig. 2. First, from Fig. 2a, b, and c, it is clear that all methods achieve worse performance as the shift degree increases, as Accuracy decays significantly while Brier and ECE show an upward trend. Yet, Uncertainty does not behave as expected in all cases: in Fig. 2a at degree 5 and in Fig. 2b for all degrees, most methods yield declining uncertainties. While an increasing Brier implies that a model is becoming more and more uncertain, the corresponding Uncertainty decreases, indicating that the model produces over-confident incorrect predictions. Moreover, we observe that a deterministic model might lead to biased predictions under severe dataset shift. In Fig. 3, we inspect the true positive rate (TPR) and true negative rate (TNR) in this binary classification task. Figure 3a shows that with more severe Gaussian noise, all methods except Bayesian tend to classify more testing samples into the negative group, while the opposite direction can be observed in
Fig. 2 Accuracy and uncertainty under various corruptions for the COVID-19 detection task. Note that the Vanilla and scaling methods yield the same Accuracy, so their lines overlap (and only the scaling one is visible)
Fig. 3 True positive rate (TPR) and true negative rate (TNR) on the increasing shift for COVID-19 task, with the same colour legend used in Fig. 2
Fig. 3b with TV news noise. It is worth highlighting that a balanced testing set is used in this task; otherwise, Accuracy may not be a good metric to evaluate a model's generalisation performance. For example, if the negative class were the majority in the testing set, a more severe Gaussian noise shift would result in higher Accuracy. This is consistent with the finding that uncertainty-unaware dataset shift evaluation can be misleading, as suggested by [1]. Comparing the five methods, Ensemble is relatively the best regarding both Accuracy and Uncertainty. The post-hoc calibration method (Scaling) cannot keep ECE low, as the temperature scale factor T was optimised on non-shifted data and thus the model is probably not tolerant of various types and degrees of dataset shift. In contrast, although Bayesian achieves the lowest Accuracy, its ECE shows only a small fluctuation. This might be due to the fact that the training set for this task is very small, while parameter estimation for the Bayesian model is computationally harder and needs more training data; thus the model is still under-fitted.
Results for the respiratory task. In this task, as the dataset shift gets more severe, for all methods, Accuracy declines, Brier becomes larger, and in most cases Uncertainty goes up, indicating that the models are becoming more uncertain, as expected. It is also worth noting that although the Accuracy of these methods is relatively close, Brier gives a clearer and more fine-grained picture, showing that the output probabilities are quite different and that Ensemble achieves the minimum error. However, an increase in ECE can be observed in Fig. 4a, b, and c, which shows that all methods yield over-confident predictions on the increasingly shifted testing set. Since Brier and Uncertainty show an upward trend in this task, we conduct a further in-depth analysis to inspect whether the quantified predictive uncertainty, particularly from the Bayesian and Ensemble approaches, is sufficient to secure the predictions under dataset shift. First, as uncertainty is usually used to select low-confidence predictions and pass them to doctors [14], we compare the Accuracy on the retained data in Fig. 5a. Despite the performance improvement from selective prediction, the gap between the Accuracy on the original testing set (solid lines) and on the data mixed with shifts (dashed lines) is notable. This implies that the estimated distributional
Fig. 4 Accuracy and uncertainty under corruptions for respiratory abnormality detection task
Fig. 5 A detailed comparison of the quality of uncertainty in the respiratory abnormality detection task. a Accuracy on the retained data, where samples with uncertainty higher than the given threshold are referred. Solid lines denote the original testing set, while dashed lines present the average accuracy on the mixed original and shifted sets, with shading showing the variance among shift degrees and types. b Accuracy for shift detection on the mixed original and shifted testing set: samples with predictive entropy > threshold are detected as shifted inputs. c Uncertainty distribution from the Ensemble method on the Gaussian shift with degree = 5
uncertainty from the Bayesian and Ensemble models cannot help maintain the impressive accuracy achieved on the non-shifted testing set. Figure 5b further confirms this limitation, where we investigate whether Uncertainty can be exploited to detect an input shifted away from the training distribution. Yet, the performance is only slightly better than random guessing: accuracy < 0.6 for Bayesian and Ensemble versus 0.5 at chance level. This indicates that the shifted data is not distinguishable from the original data by the present uncertainty estimation methods. Further, in Fig. 5c, we display an example of the most severe Gaussian shift, with blue bars denoting the uncertainty distribution on the original testing set and orange bars that on the shifted set. It is good to see that the predictive uncertainties on the shifted set are generally higher than those on the non-shifted set. However, the two distributions are still close to each other, which undermines the capacity of current uncertainty estimation approaches to tackle dataset shift in real applications.
Results for the heart arrhythmia task. Results for the heart arrhythmia detection task on the Gaussian noise shift, segment missing shift, and sampling rate mismatch shift are presented in Fig. 6. The primary findings are consistent with the observations
Fig. 6 Accuracy and uncertainty under various corruptions for heart arrhythmia detection
Fig. 7 Comparison of output probability and predictive uncertainty from two methods on original and shifted samples. Gaussian noise is added to the signals with SNR = 10. a shows a normal ECG recording, and b is a non-AF abnormal rhythm recording
from the prior two tasks: Accuracy degrades and the models become increasingly over-confident for all methods. Although, as shown in Fig. 6b, with a large proportion of the signal missing, the Bayesian and Ensemble methods keep relatively good Accuracy compared to the other methods, the slight increase in ECE and the small reduction in Uncertainty suggest that the uncertainty might not be fully reliable. We also present a visual comparison in Fig. 7. Figure 7a shows an example where Bayesian and Ensemble both achieve correct predictions for the original and shifted inputs. This normal rhythm ECG signal is predicted as normal with a probability over 0.5, but with the Gaussian noise shift, Bayesian and Ensemble yield lower probabilities for the normal class: p_N = 0.37 and p_N = 0.34, respectively. Meanwhile, the Uncertainty of these two approaches rises from H = 1.0 to H = 1.1. This is a case where the distributional shift has been captured by the predictive uncertainty. In contrast, Fig. 7b demonstrates a failed prediction on the shifted input, as the shifted non-AF abnormal rhythm sample is predicted as a normal and an AF signal by Bayesian and Ensemble, respectively. Worse, the Bayesian model yields a lower uncertainty value: from H = 1.09 to H = 0.98, which indicates that it is less uncertain although it makes a wrong prediction. Such a case could have a negative impact in a real automatic diagnosis system, as the estimated uncertainty falls short of its promise to reflect the reliability of the prediction.
Summary and takeaways. Combining the observations from all three tasks, we draw the following conclusions and recommendations:
• With increasing dataset shift in biosignals, all uncertainty estimation approaches we evaluated fail to report a reasonably increasing uncertainty score to signal the change in data distribution, while performance in terms of accuracy degrades sharply.
• Ensemble achieves a slightly better uncertainty estimation than the other methods, although it incurs a relatively higher computing cost and memory consumption. The Bayesian method can obtain similar performance when training data is sufficient.
• Classifiers trained on non-shifted data might be biased under a specific dataset shift during inference. Thus, the measure of prediction uncertainty is as important as the prediction itself, particularly in safety-critical healthcare applications.
• Models may become more and more over-confident as the shift gets more severe. None of the existing methods is perfect in capturing distributional shifts and calibrating deep neural networks. New approaches are needed.
6 Conclusions
In this paper, we conduct extensive experiments and analysis to assess the estimated predictive uncertainty under dataset shift on biosignals. We implement five uncertainty quantification methods on three representative biosignal classification tasks under controlled dataset shift. To enable a comparison, we propose a protocol to analyse all methods without requiring new data: we synthesise signal-specific distributional shifts according to real signal data collection scenarios. Our work establishes a benchmark for future evaluation of uncertainty quantification methods.
Acknowledgements This work was supported by ERC Project 833296 (EAR).
References 1. Band, N., Rudner, T. G., Feng, Q., Filos, A., Nado, Z., Dusenberry, M. W., Jerfel, G., Tran, D., & Gal, Y. (2021). Benchmarking bayesian deep learning on diabetic retinopathy detection tasks. In Thirty-fifth Conference on Neural Information Processing Systems Datasets and Benchmarks Track (Round 2). 2. Bhatt, U., Antorán, J., Zhang, Y., Liao, Q. V., Sattigeri, P., Fogliato, R., Melançon, G., Krishnan, R., Stanley, J., & Tickoo, O., et al. (2021). Uncertainty as a form of transparency: Measuring, communicating, and using uncertainty. In Proceedings of the 2021 AAAI/ACM Conference on AI, Ethics, and Society (AIES) (pp. 401–413). 3. Clifford, G. D., Liu, C., Moody, B., Li-wei, H. L., Silva, I., Li, Q., Johnson, A., & Mark, R. G. (2017). AF classification from a short single lead ECG recording: The physionet/computing in cardiology challenge 2017. In Proceedings of Computing in Cardiology (CinC) (pp. 1–4). IEEE. 4. Gairola, S., Tom, F., Kwatra, N., & Jain, M. (2020). RespireNet: A deep neural network for accurately detecting abnormal lung sounds in limited data setting. arXiv:2011.00196
5. Gal, Y., & Ghahramani, Z. (2016). Dropout as a bayesian approximation: Representing model uncertainty in deep learning. In Proceedings of the International Conference on Machine Learning (ICML) (pp. 1050–1059). 6. Ganaie, M., Hu, M., Tanveer, M., & Suganthan, P. (2021). Ensemble deep learning: A review. arXiv:2104.02395 7. Gawlikowski, J., Tassi, C. R. N., Ali, M., Lee, J., Humt, M., Feng, J., Kruspe, A., Triebel, R., Jung, P., & Roscher, R., et al. (2021). A survey of uncertainty in deep neural networks. arXiv:2107.03342 8. Ghoshal, B., & Tucker, A. (2020). Estimating uncertainty and interpretability in deep learning for coronavirus (COVID-19) detection. arXiv:2003.10769 9. Goldberger, A. L., Amaral, L. A., Glass, L., Hausdorff, J. M., Ivanov, P. C., Mark, R. G., Mietus, J. E., Moody, G. B., Peng, C. K., & Stanley, H. E. (2000). PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex physiologic signals. Circulation, 101(23), e215–e220. 10. Graves, A. (2011). Practical variational inference for neural networks. Advances in Neural Information Processing Systems, 24, 9. 11. Guo, C., Pleiss, G., Sun, Y., & Weinberger, K. Q. (2017). On calibration of modern neural networks. In Proceedings of International Conference on Machine Learning (ICML) (pp. 1321– 1330). PMLR 12. Han, J., Xia, T., Spathis, D., Bondareva, E., Brown, C., Chauhan, J., Dang, T., Grammenos, A., Hasthanasombat, A., & Floto, A., et al. (2021). Sounds of COVID-19: Exploring realistic performance of audio-based digital testing. arXiv:2106.15523 13. Lakshminarayanan, B., Pritzel, A., & Blundell, C. (2017). Simple and scalable predictive uncertainty estimation using deep ensembles. Advances in Neural Information Processing Systems, 30. 14. Leibig, C., Allken, V., Ayhan, M. S., Berens, P., & Wahl, S. (2017). Leveraging uncertainty information from deep neural networks for disease detection. Scientific Reports, 7(1), 1–14. 15. Liu, J., Paisley, J., Kioumourtzoglou, M. A., & Coull, B. (2019). Accurate uncertainty estimation and decomposition in ensemble learning. Advances in Neural Information Processing Systems, 32, 8952–8963. 16. Moon, J., Kim, J., Shin, Y., & Hwang, S. (2020). Confidence-aware learning for deep neural networks. In Proceedings of International Conference on Machine Learning (ICML) (pp. 7034– 7044). PMLR. 17. Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J., Lakshminarayanan, B., & Snoek, J. (2019). Can you trust your model’s uncertainty? Evaluating predictive uncertainty under dataset shift. Advances in Neural Information Processing Systems, 32, 13991– 14002. 18. Pooch, E. H., Ballester, P. L., & Barros, R. C. (2019). Can we trust deep learning models diagnosis? The impact of domain shift in chest radiograph classification. arXiv:1909.01940 19. Raghu, M., Blumer, K., Sayres, R., Obermeyer, Z., Kleinberg, B., Mullainathan, S., & Kleinberg, J. (2019). Direct uncertainty prediction for medical second opinions. In Proceedings of International Conference on Machine Learning (ICML) (pp. 5281–5290). PMLR. 20. Ulmer, D., Meijerink, L., & Cinà, G. (2020). Trust issues: Uncertainty estimation does not enable reliable OOD detection on medical tabular data. In Proceedings of Machine Learning for Health (ML4H) (pp. 341–354). PMLR. 21. Van Amersfoort, J., Smith, L., Teh, Y. W., & Gal, Y.: Uncertainty estimation using a single deep deterministic neural network. In Proceedings of International Conference on Machine Learning (ICML) (pp. 9690–9700). PMLR. 22. 
Wang, C., Sun, S., & Grosse, R. (2021). Beyond marginal uncertainty: How accurately can bayesian regression models estimate posterior predictive correlations? In Proceedings of International Conference on Artificial Intelligence and Statistics (pp. 2476–2484). PMLR. 23. Xia, T., Han, J., Qendro, L., Dang, T., & Mascolo, C. (2021). Uncertainty-aware COVID19 detection from imbalanced sound data. In Proceedings of the Annual Conference of the International Speech Communication Association (INTERSPEECH) (Vol. 2021, pp. 2951– 2955).
Mining Adverse Drug Reactions from Unstructured Mediums at Scale Hasham Ul Haq, Veysel Kocaman, and David Talby
Abstract Adverse drug reactions/events (ADR/ADE) have a major impact on patient health and health care costs. While most ADR's are not reported via formal channels, they are often documented in a variety of unstructured conversations such as social media posts or customer support call transcripts. In this paper, we propose a natural language processing (NLP) solution that detects ADR's in such unstructured free-text conversations, which improves on previous work in three ways. First, a new Named Entity Recognition (NER) model obtains state-of-the-art accuracy for ADR and Drug entity extraction on the ADE, CADEC, and SMM4H benchmark datasets (91.75, 78.76, and 83.41% F1 scores respectively). Second, two new Relation Extraction (RE) models are introduced—one based on BioBERT and the other utilizing crafted features over a Fully Connected Neural Network (FCNN)—which perform on par with existing state-of-the-art models and outperform them when trained with a supplementary clinician-annotated RE dataset. Third, a new text classification model obtains new state-of-the-art accuracy on the CADEC dataset (86.69% F1 score). The complete solution is implemented as a unified NLP pipeline in a production-grade library built on top of Apache Spark, making it natively scalable for processing millions of records on commodity clusters. Keywords NLP · NER · Relation Extraction · Pharmacovigilance · Spark NLP
H. U. Haq (B) · V. Kocaman · D. Talby John Snow Labs Inc., 16192 Coastal Highway Lewes, Lewes, DE 19958, USA e-mail: [email protected] V. Kocaman e-mail: [email protected] D. Talby e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_26
1 Introduction
Adverse drug events are harmful side effects of drugs, comprising allergic reactions, overdose responses, and general unpleasant side effects. Approximately 2 million patients in the United States are affected each year by serious ADR's, resulting in roughly 100,000 fatalities [1] and making ADR's the fourth leading cause of death in the United States [2]. Treatment related to ADR's has been estimated to cost $136 billion each year in the United States alone [3]. Finding all ADR's of a drug before it is marketed is not practical for several reasons. First, the number of human subjects going through clinical trials is often too small to detect rare ADR's. Second, many clinical trials are short-lasting while some ADR's take time to manifest. Third, some ADR's only show when a drug is taken together with other drugs, and not all drug-drug combinations can be tested during clinical trials. Fourth, drug repurposing or off-label usage can lead to unforeseen ADR's. As a result, detecting ADR's in drugs which are already being marketed is critical—a discipline known as postmarketing pharmacovigilance [4]. Schemes which allow hospitals, clinicians, and patients to report ADR's have existed for many years, but only a fraction of events get reported through them. A meta-analysis of 37 studies from 12 countries found that the median rate of underreporting was 94% [5]. This led to work on mining ADR's from alternative sources, such as social media posts by patients or healthcare providers [6], as such platforms have emerged as the largest source of information sharing among the masses. The outbreak of the COVID-19 pandemic has accelerated this trend of sharing such information [7]; the size, variety, and instantaneous nature of social media provide opportunities for real-time monitoring of ADRs [8]. Compared to traditional data sources like research publications, this data is more challenging to process, as it is unstructured and contains noise in the form of jargon, abbreviations, misspellings, and complex sentence structures. Recent advancements in Natural Language Processing (NLP), in the form of Transformer-based [9] architectures like BERT [10], have significantly pushed the boundaries of NLP capabilities. There is an increasing trend of training large models on domain-specific data, like BioBERT [11], and these methods have proven to achieve state-of-the-art (SOTA) results for document understanding and named entity recognition (NER). However, since these methods require significant computational resources during both training and inference, it becomes impractical to apply them over large quantities of records in compute-restricted production environments. Despite the growing interest and opportunities to process large quantities of data, models and software frameworks that can scale to leverage compute clusters are scarce. This restricts the ability to utilize available data from social media and other mediums—such as transcripts of customer service calls with patients, or CRM notes about sales and support discussions with clinicians—to their true potential. The availability of high volume, variety, and velocity of data presents the opportunity to develop NLP solutions that outperform the existing SOTA accuracy, while also being easily scalable and computationally efficient.
The purpose of this study is to illustrate how an end-to-end system, based on the Apache Spark ecosystem and comprising novel NLP techniques, can be used to process large quantities of unstructured text to mine ADRs. This system has been implemented in a production-ready, widely deployed, and natively scalable library, and is thus capable of processing millions of records in either batch or streaming mode. The unified NLP pipeline includes new models for the three required sub-tasks: classifying text to decide if it is an indication of an ADR, recognizing named entities for reactions and drugs, and linking adverse events with drugs. Following are the novel contributions of this paper:
• The first scalable end-to-end system for mining ADR's from unstructured text, including Document Classification, Named Entity Recognition, and Relation Extraction models within a unified NLP pipeline.
• A new NER model for extracting reactions and drugs, whose accuracy outperforms previous SOTA models on public datasets for this task.
• New Relation Extraction models for linking reactions and drugs, which outperform previous SOTA models when trained with additional data that was annotated by clinicians as part of this effort.
• A new text classification model for deciding if a piece of text reports an ADR, whose accuracy outperforms previous SOTA models.
• A study of the utility of using non-contextual lightweight embeddings [12] like GloVe [13] instead of memory-intensive contextual embeddings like BioBERT for these tasks, by comparing training times and accuracy improvements.
• A detailed analysis of all the solution components and datasets, explaining how the modular structure can be customized to different data sources and runtimes.
2 Related Work
The extraction of ADRs from unstructured text has received growing attention in the past few years due to the widespread adoption of Electronic Medical Records (EMR) and the ever-increasing number of users on social media who share their experiences. Existing work comprises significant contributions in both the novelty of information extraction methodologies and the availability of relevant pre-annotated datasets containing annotations for a variety of subtasks. The problem of ADR extraction gained visibility with the introduction of challenges like Social Media Mining for Health (SMM4H) [14] and the National Clinical NLP Challenges (n2c2) [15], which provide pre-annotated datasets for researchers to compete on. Other significant contributions for data collection include [16], which used the Pubmed corpus to develop the ADE corpus benchmark dataset, covering Classification, NER, and RE annotations for extracting and relating ADRs and drugs respectively. Another work [17] produced an NER dataset (CADEC) by collecting and annotating reviews and comments from forums.
Identification of text containing ADRs is formulated as a text classification problem for which different techniques have been applied. Huynh et al. [18] used different variations of Convolutional Neural Network (CNN) (e.g., CNN, CRNN, CNNA) to identify tweets containing ADRs on the Twitter dataset. More elaborate techniques like fine-tuning of BERT models have been applied for text classification as well [19]. A standard method of formulating the extraction of drug and ADR mentions is NER, for which a number of architectures have been proposed. One classical approach is to use a BiLSTM [20] architecture with Conditional Random Fields (CRF), as used by [21]. This method is a shallow network that relies on word embeddings and part-of-speech tags to classify each token to extract ADR mentions. Ge et al. [22] also added character-level embeddings to the same architecture to incorporate spelling features, and enriched the training dataset by annotating additional data from DBpedia to achieve SOTA results on the CADEC dataset, demonstrating the benefits of using additional data. Similar to our approach, they also built an extensive training framework over multiple nodes. Relating ADR mentions with the drugs is formulated as a relation extraction (RE) task, which comprises the creation and classification of relations between entities [23]. Classical RE methods like [24] use lexical rules based on the dependency parse tree of the document. The introduction of transformers allowed for more context-aware solutions, like feeding entity spans and the document to transformers to predict relations [25]. Recently, more elaborate approaches like joint learning of both NER and RE have proved to be more beneficial. For example, [26] used a single base network to generate joint features, while using separate BiRNN layers for both NER and RE, and creating skip connections between the NER and RE BiRNN layers to achieve SOTA performance on RE. While existing work has focused on pushing the boundaries for accuracy, little work has been done to build a framework that can process large quantities of social media data accurately. To achieve this, we develop separate architectures for all three tasks, and place them in a single pipeline, allowing us to maintain a modular structure to develop and test each component separately, while sharing common components (e.g., tokenization and embedding generation) for scalability.
3 Approach
We divide the problem into three main tasks, Document Classification, Named Entity Recognition, and Relation Extraction, and draw distinct solutions for each one of them for scalability. Since NER plays the most important role of identifying entity spans, we place all components in a single pipeline for an end-to-end solution. Figure 1 illustrates the complete pipeline using the Apache Spark framework. As shown in the system diagram in Fig. 1, Relation Extraction is heavily dependent on the NER model, as the latter provides relevant entity chunks which form the basic inputs of the RE model. Since NER requires token-level embeddings, we
Fig. 1 Overview of the complete architecture. All the components are sequentially placed in a single pipeline. Arrows represent output of one stage as input to the next stage
test with different types of embeddings, namely GLoVe [13] and BERT [10] based embeddings. This modular approach helps us keep the NER and RE architecture static while experimenting with different embedding types to analyse accuracy and performance differences. Given the nature of the data, we trained 200-dimensional GLoVe embeddings on the Pubmed and MIMIC datasets. For BERT embeddings we utilize the work by [11], namely BioBERT. In general, BERT embeddings provide more useful information, as they are context-aware and handle out-of-vocabulary (OOV) tokens better.
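To make the modular structure concrete, the sketch below assembles a tokenisation–embedding–NER pipeline with open-source Spark NLP components. The pretrained model names ("glove_100d", "ner_dl") are generic placeholders rather than the clinical models used in this chapter, and the ADR-specific NER and relation extraction stages described later live in a licensed healthcare extension, so this only illustrates how stages share inputs and outputs within one Spark pipeline.

```python
import sparknlp
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import Tokenizer, WordEmbeddingsModel, NerDLModel, NerConverter
from pyspark.ml import Pipeline

spark = sparknlp.start()

document = DocumentAssembler().setInputCol("text").setOutputCol("document")
tokenizer = Tokenizer().setInputCols(["document"]).setOutputCol("token")

# A single embedding stage feeds every downstream model, so embeddings are computed once.
embeddings = (WordEmbeddingsModel.pretrained("glove_100d")            # placeholder model name
              .setInputCols(["document", "token"]).setOutputCol("embeddings"))

ner = (NerDLModel.pretrained("ner_dl")                                # placeholder model name
       .setInputCols(["document", "token", "embeddings"]).setOutputCol("ner"))
ner_chunks = NerConverter().setInputCols(["document", "token", "ner"]).setOutputCol("ner_chunk")

pipeline = Pipeline(stages=[document, tokenizer, embeddings, ner, ner_chunks])

data = spark.createDataFrame([["I feel drowsy after taking insulin"]], ["text"])
result = pipeline.fit(data).transform(data)   # runs in batch or streaming mode on a cluster
```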
3.1 Classification
To be able to process large volumes of data, the text classification model needs to be scalable and accurate, as it is used to filter out documents, reviews, and tweets that do not contain any indication of an adverse event. To achieve this, we use an FCNN model that does not require hand-crafted features and relies on a single embedding vector for classification. Given the conversational nature of social media text, we can utilise the entire document to get efficient embeddings (with little text clipping in the case of BioBERT embeddings) that are directly fed to the classifier model. Since there is only a single feature vector as input to the model, we test multiple embedding techniques to analyse performance.
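As a rough illustration of this design, the sketch below shows a small fully connected classifier over a single document-embedding vector. The layer sizes, dropout value, and the 768-dimensional embedding are our assumptions for illustration rather than the chapter's exact configuration (Appendix A lists the actual hyperparameters).

```python
import torch
import torch.nn as nn

class ADEClassifier(nn.Module):
    """FCNN head over one sentence-embedding vector (e.g. a Sentence-BioBERT embedding)."""
    def __init__(self, emb_dim=768, hidden=128, n_classes=2, dropout=0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, n_classes),
        )

    def forward(self, sentence_embedding):
        return self.net(sentence_embedding)   # logits; softmax is applied inside the loss

# Example: classify a batch of 4 documents represented by their document embeddings.
logits = ADEClassifier()(torch.randn(4, 768))
```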
3.2 Named Entity Recognition
To extract ADR and other entities from text, we use our class-leading NER architecture, called BiLSTM-CNN-Char. We build our NER model by taking the work of [27] as the base model and make a few changes in the architecture according to our testing: removing lexical features like POS tags, and introducing new character-level features. We use a 1D convolution layer comprising 25 filters with kernel size 3 to generate token feature maps that encapsulate information like spelling and
Fig. 2 Proposed NER architecture, inspired by [27]
casing. These additional features proved highly useful while dealing with spelling mistakes, as well as out-of-vocabulary tokens. We also updated the architecture by using BlockFused-LSTM cells in our implementation for increased speed. Figure 2 explains the architecture of our NER model.
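The sketch below illustrates the BiLSTM-CNN-Char idea in PyTorch: a character-level CNN with 25 filters of kernel size 3 produces per-token spelling features, which are concatenated with pre-computed word embeddings and fed to a BiLSTM. The plain linear tag head, embedding sizes, and tag count are illustrative assumptions; the chapter's implementation uses BlockFused-LSTM cells and runs inside Spark NLP.

```python
import torch
import torch.nn as nn

class CharCNNWordBiLSTM(nn.Module):
    """BiLSTM-CNN-Char tagger: char-level CNN features concatenated with word embeddings."""
    def __init__(self, n_chars, char_dim=16, n_filters=25, kernel=3,
                 word_dim=200, lstm_hidden=200, n_tags=5):
        super().__init__()
        self.char_emb = nn.Embedding(n_chars, char_dim, padding_idx=0)
        self.char_cnn = nn.Conv1d(char_dim, n_filters, kernel_size=kernel, padding=1)
        self.bilstm = nn.LSTM(word_dim + n_filters, lstm_hidden,
                              batch_first=True, bidirectional=True)
        self.tag_head = nn.Linear(2 * lstm_hidden, n_tags)

    def forward(self, word_embs, char_ids):
        # word_embs: (batch, seq, word_dim); char_ids: (batch, seq, max_word_len)
        b, s, w = char_ids.shape
        chars = self.char_emb(char_ids.view(b * s, w)).transpose(1, 2)      # (b*s, char_dim, w)
        char_feats = self.char_cnn(chars).max(dim=2).values.view(b, s, -1)  # max-pool per token
        hidden, _ = self.bilstm(torch.cat([word_embs, char_feats], dim=-1))
        return self.tag_head(hidden)   # per-token tag logits (IOB scheme)

# Example: 2 sentences of 20 tokens with 200-d word embeddings and char ids of length 12.
logits = CharCNNWordBiLSTM(n_chars=100)(torch.randn(2, 20, 200),
                                         torch.randint(0, 100, (2, 20, 12)))
```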
3.3 Relation Extraction
We treat Relation Extraction (RE) as a binary classification problem where each example is a pair of drug and ADR mentions in a given context, and develop two novel solutions: the first comprises a simpler FCNN architecture for speed, and the second is based on the BioBERT architecture for accuracy. We experiment with both approaches and compare their results. For our first RE solution we rely on entity spans and types identified by the NER model to develop distinct features to feed to an FCNN for classification. First, we generate pairs of adverse event and drug entities, and then generate custom features for each pair. These features include the semantic similarity of the entities, the syntactic distance of the two entities, the dependency structure of the entire document, embedding vectors of the entity spans, as well as embedding vectors for 100 tokens within the vicinity of each entity. Figure 3 explains our model architecture in detail. We then concatenate these features and feed them to fully connected layers with leaky ReLU activation. We also use batch normalisation after each affine transformation before feeding to the final softmax layer with a cross-entropy loss function. We use softmax cross-entropy instead of binary cross-entropy loss to keep the architecture flexible for scaling on datasets having multiple relation types.
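A minimal sketch of this first RE solution follows. It keeps only a subset of the described features (entity-span embeddings, a context embedding, and a syntactic-distance scalar) and uses hypothetical layer sizes; the full model also uses dependency-structure features and embeddings of up to 100 neighbouring tokens per entity.

```python
import torch
import torch.nn as nn

class RelationFCNN(nn.Module):
    """FCNN relation classifier over concatenated hand-crafted pair features."""
    def __init__(self, feature_dim, hidden=256, n_relations=2, dropout=0.5):
        super().__init__()
        self.block1 = nn.Sequential(nn.Linear(feature_dim, hidden),
                                    nn.BatchNorm1d(hidden), nn.LeakyReLU())
        self.block2 = nn.Sequential(nn.Linear(hidden, hidden),
                                    nn.BatchNorm1d(hidden), nn.LeakyReLU(), nn.Dropout(dropout))
        self.out = nn.Linear(hidden, n_relations)   # softmax cross-entropy applied in the loss

    def forward(self, pair_features):
        return self.out(self.block2(self.block1(pair_features)))

def build_pair_features(drug_emb, ade_emb, context_emb, syntactic_distance):
    """Stack entity-span embeddings, a context embedding, and a scalar distance into one vector."""
    return torch.cat([drug_emb, ade_emb, context_emb,
                      torch.tensor([float(syntactic_distance)])], dim=-1)
```

Using softmax over two (or more) relation labels, as in the chapter, means the same head can later be extended to multi-relation datasets without changing the loss.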
Fig. 3 Overview of the first Relation Extraction model. All the features are vertically stacked in a single feature vector. The feature vector is kept dynamic with additional padding for compatibility across different embedding sizes, and complex dependency structures
Our second solution focuses on a higher accuracy, as well as exploration of relations across long documents, and is based on [25]. In our experiment we take checkpoints from the BioBERT model and train an end-to-end model for relation extraction. Similar to our first solution, we rely on entity spans and use the entire document as context string while training the model. The original paper used sequence length of 128 tokens for the context string, which we keep constant, and instead experiment with the context string, additional data, and fine-tuning techniques.
4 Experimental Setup
4.1 Datasets
We test our models on three benchmark datasets: the SMM4H NER challenge [14], ADE Corpus [16], and CADEC [17]. The SMM4H NER challenge is a yearly challenge based on annotated Twitter data. As this dataset is entirely based on tweets, it forms an ideal test bed for evaluating our model's performance on real-world data. The ADE Corpus dataset is a benchmark dataset for classification, NER, and RE tasks, while the CADEC dataset is primarily used for classification and NER benchmarks only. Keeping consistency with existing work, as well as aligning with our primary goal
Table 1 Statistics of the benchmark NER datasets

Dataset    | # Sentences | # Tokens | # Entity tags
ADE corpus | 4272        | 86865    | ADE: 12264, Drug: 5544
CADEC      | 7597        | 121656   | ADE: 15903, Drug: 2032
SMM4H      | 2253        | 42175    | ADE: 3575, Drug: 1586
Table 2 Statistics of the RE datasets used to train and validate the models. ADE Corpus is the standard dataset, which is then enriched with n2c2 data for more robust performance

Dataset       | # Positive relations | # Negative relations
ADE Corpus    | 6821                 | 183
ADE with n2c2 | 12929                | 8935
of extracting ADRs and related drugs, we keep two entities in all datasets: ADE and Drug. Details of the NER datasets can be found in Table 1. Since we treat the RE problem as binary classification, we need positive relations as well as negative relations to train the model. Positive relations are defined if the drug and reaction entities are related in the context, while negative relations comprise drugs that are not responsible for a particular reaction. This relation can be formulated as P(Drug | ADE). From the ADE dataset, we can sample negative relations by subtracting annotated drug–reaction pairs from all drug–reaction pairs in the same document. Table 2 shows the data distribution of the standard and enriched RE datasets. The standard ADE Corpus does not have sufficient negative relations, raising the issue of class imbalance. To address this, we sampled and annotated 2000 notes from the 2018 n2c2 shared task on ADE and medication extraction in EHRs dataset [15] to create a supplementary dataset for relations. We keep the same entities (i.e., Drug and ADE) while annotating to align with our benchmark datasets. Also, to keep human bias at a minimum, we don't annotate entity spans; rather, we use existing NER annotations to generate a dataset comprising Drug and ADE pairs, and only classify each relation based on its context. Following previous work, we evaluate the models using 10-fold cross-validation, and report macro- and micro-averaged precision, recall, and F1 scores. Exact experimental and evaluation settings for each stage are described below.
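The negative-sampling rule described above can be written compactly as a set difference over co-occurring entity pairs. The dictionary keys (drug_spans, ade_spans, gold_relations) are hypothetical names for the per-document annotations, not identifiers from the datasets themselves.

```python
from itertools import product

def candidate_relations(doc):
    """All (drug, ADE) pairs co-occurring in a document, given its NER annotations."""
    return set(product(doc["drug_spans"], doc["ade_spans"]))

def sample_relations(doc):
    """Positives come from the gold annotations; negatives are the remaining co-occurring pairs."""
    all_pairs = candidate_relations(doc)
    positives = set(doc["gold_relations"])    # annotated drug-ADE links
    negatives = all_pairs - positives         # drugs not responsible for the reaction
    return positives, negatives
```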
4.2 Experiments
• Keeping the Classification model architecture constant, we test two methods of embedding generation. For the first experiment we generated token-level GLoVe and BioBERT embeddings for each token and averaged them to generate a document embedding. This method did not produce accurate embeddings, resulting in lower F1 scores. Our second experiment utilised Sentence BERT embeddings [28] trained on the Pubmed dataset, namely Sentence BioBERT, which we further pretrained on the MedNLI dataset [29] for better performance. We thoroughly test the performance of each embedding generation approach and report the findings in Table 3.
• For the NER model we align all the datasets in the standard CoNLL format and use the IOB (Inside, Outside, Beginning) tagging scheme. We tested other tagging schemes as well, like BIOES, and found IOB to be the simplest and best-performing scheme. Since our NER architecture requires word embeddings, we experiment with two types of embeddings: our 200-dimensional GLoVe embeddings, and BioBERT embeddings for contextual embeddings. For a thorough analysis of the NER model, we evaluate the results using both strict and relaxed approaches. Under strict evaluation a label is considered correct if the starting and ending tags exactly match the gold labels, while under relaxed evaluation only an overlap between annotations is considered. Consequently, the 'O' tag is not included in the calculation. Hyperparameter values and training code are given in Appendices A and B.
• For training RE models, we use standard NER spans and binary labels. For our base RE model we use 200-dimensional token-level GLoVe embeddings—the same embeddings we use for our base NER model. For our BERT-based RE model, we don't use any explicit embeddings as the BERT model is trained in an end-to-end fashion. We do specify details of entity spans like starting and ending indices, entity types, and the context in between. The context is generally the entire document, but since the model architecture has a 128-token limit, we create the context text by taking the text in between the entities, and found this method to be more accurate.
Table 3 Classification metrics on benchmark datasets. For each dataset, macro- and micro-averaged scores are displayed on the first and second row, respectively. SOTA metrics for the ADE and CADEC datasets are obtained from [18] and [30], respectively

Dataset       | GLoVe (Avg.) Emb. (Prec. / Recall / F1) | BERT (Avg.) Emb. (Prec. / Recall / F1) | BERT Sent. Emb. (Prec. / Recall / F1) | SOTA F1
ADE (macro)   | 75.96 / 79.53 / 76.86 | 76.91 / 84.96 / 79.37 | 87.41 / 84.72 / 85.96 | 87.0
ADE (micro)   | 86.84 / 81.22 / 83.43 | 88.13 / 84.38 / 85.38 | 90.97 / 91.20 / 91.03 | –
CADEC (macro) | 85.29 / 84.24 / 84.71 | 86.50 / 86.11 / 86.30 | 87.13 / 86.32 / 86.69 | 81.5
CADEC (micro) | 85.99 / 86.10 / 86.0  | 87.38 / 87.43 / 87.40 | 87.78 / 87.86 / 87.79 | –
We also test the hypothesis that fine-tuning the BioBERT model on similar relation extraction tasks would increase the overall performance on the benchmark datasets. To test this hypothesis, we train an end-to-end RE model on disease and drug datasets like the 2010 i2b2 challenge [32] and save it. We then use the same weights while discarding the final layers, and retrain the model on the benchmark dataset. Since the base model is trained on a similar taxonomy, convergence is much faster, and the model is less prone to over-fitting. For hyperparameter tuning we utilize the development set and use random search. Exact hyperparameter values and the search space for all the models can be found in Appendix A.
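A hedged sketch of this two-stage fine-tuning is shown below using the Hugging Face transformers API. The model identifier and the i2b2 label count are illustrative assumptions, and the chapter's actual RE model follows [25] and encodes entity spans explicitly rather than treating the input as plain sequence classification.

```python
import torch
from transformers import AutoModelForSequenceClassification

# Stage 1: train a relation classifier on a related corpus (e.g. the 2010 i2b2 relations).
model = AutoModelForSequenceClassification.from_pretrained(
    "dmis-lab/biobert-v1.1", num_labels=8)       # hypothetical label count for the i2b2 task
# ... training loop on the auxiliary dataset, then:
torch.save(model.state_dict(), "biobert_re_i2b2.pt")

# Stage 2: reuse the encoder, discard the task head, and retrain on the ADE benchmark.
ade_model = AutoModelForSequenceClassification.from_pretrained(
    "dmis-lab/biobert-v1.1", num_labels=2)
state = torch.load("biobert_re_i2b2.pt")
encoder_state = {k: v for k, v in state.items() if not k.startswith("classifier.")}
ade_model.load_state_dict(encoder_state, strict=False)   # keep encoder weights, fresh head
```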
4.3 Results
Despite using a shallow architecture for classification, we achieved metrics that are on par with SOTA metrics by using more accurate Sentence BERT embeddings, as shown in Table 3. While the performance difference between BioBERT and GLoVe embeddings is minor on the CADEC dataset, the difference is more prominent on the ADE dataset. This is primarily because of the complex intrinsic nature of biomedical text, where averaging token (GLoVe) embeddings does not efficiently capture the context of complex sentence structures. Our NER architecture achieves new SOTA metrics on the SMM4H, ADE, and CADEC NER datasets using contextual BioBERT embeddings, as shown in Table 4. Since the NER model was kept constant during the experiment, and we tuned the hyperparameters for each experiment, the performance difference between embedding types can be attributed to the word embeddings alone. Being able to incorporate contextual information with an attention mechanism, BioBERT embeddings performed better than non-contextual GLoVe embeddings. However, it is worth noting that the performance difference between the two is within a margin of 1–2%, showing that domain-
Table 4 NER metrics on benchmark datasets. For each dataset, macro- and micro-averaged scores are displayed on the first and second row, respectively. Each cell lists strict / relax values. SOTA metrics for ADE, CADEC, and SMM4H are obtained from [21, 22, 31], respectively, and are macro-averaged

Dataset       | GLoVe Precision | GLoVe Recall  | GLoVe F1      | BERT Precision | BERT Recall   | BERT F1       | SOTA F1
ADE (macro)   | 88.32 / 93.77   | 89.26 / 94.80 | 88.78 / 94.27 | 90.0 / 94.47   | 93.56 / 98.22 | 91.75 / 96.31 | 91.3
ADE (micro)   | 87.81 / 93.59   | 88.81 / 94.66 | 88.30 / 94.12 | 89.6 / 94.37   | 93.18 / 98.13 | 91.36 / 96.21 | –
CADEC (macro) | 78.14 / 89.04   | 77.14 / 88.01 | 77.62 / 88.50 | 78.53 / 88.63  | 79.03 / 89.32 | 78.76 / 88.95 | 71.9
CADEC (micro) | 71.87 / 86.36   | 71.67 / 86.13 | 71.75 / 86.23 | 72.38 / 86.14  | 73.64 / 87.66 | 72.99 / 86.76 | –
SMM4H (macro) | 81.43 / 90.33   | 72.17 / 78.51 | 76.01 / 83.41 | 78.5 / 86.88   | 75.23 / 82.42 | 76.73 / 84.41 | 67.81
SMM4H (micro) | 83.66 / 91.34   | 71.31 / 77.86 | 76.99 / 84.06 | 79.13 / 87.09  | 74.33 / 81.81 | 76.65 / 84.36 | –
Table 5 Training and inference time (in seconds) taken by the NER model on each dataset using different token embeddings, with respect to overall performance on the test set. The epoch count was kept constant for all datasets while training. The experiment was performed on an 8-core machine with 64 GB of memory

Dataset | BERT Train (s) | BERT Infer (s) | BERT F1 | GLoVe Train (s) | GLoVe Infer (s) | GLoVe F1
ADE     | 1980           | 329            | 91.75   | 1417            | 22              | 88.78
CADEC   | 2475           | 351            | 78.76   | 1929            | 26              | 77.62
SMM4H   | 860            | 136            | 76.73   | 635             | 11              | 76.01
Table 6 Relation Extraction performance on the ADE benchmark dataset. The test set was kept standard for a fair comparison, and all scores are macro-averaged due to high class imbalance. SOTA metrics for RE on the ADE corpus as reported by [26]

Dataset                | Base (FCNN) RE (Prec. / Recall / F1) | BERT RE (Prec. / Recall / F1) | SOTA F1
ADE Corpus             | 69.11 / 86.22 / 74.70                | 81.31 / 79.03 / 80.10         | 83.74
ADE enriched with n2c2 | 89.01 / 89.44 / 89.22                | 89.19 / 90.93 / 90.02         | –
Table 7 Full pipeline results on sample texts. Documents having an indication for ADR are classified as ADE, while positive relations represent causality between the two entities (Drug and ADE). The last example is classified as negative—meaning it does not contain any ADE indication, so we don't process it further

Document | Class | ADE entity | Drug entity | Relation
I feel a bit drowsy & have a little blurred vision after taking insulin | ADE | drowsy; blurred vision | insulin; insulin | positive; positive
@yho fluvastatin gave me cramps, but lipitor suits me! | ADE | cramps; cramps | fluvastatin; lipitor | positive; negative
I just took advil and haven't had any gastric problems so far | NEG | – | – | –
specific GLoVe embeddings can provide comparable performance while requiring significantly less memory and fewer computational resources. Table 5 provides a side-by-side comparison of the time and accuracy differences when using the different embedding types. On average, the GLoVe embeddings are 30% faster than BioBERT embeddings during training, and more than 5x faster during inference, while being on par in terms of F1 score.
Fig. 4 Visualization of Entity Recognition and Relation Extraction results on a sample text. 1 denotes positive relation and 0 denotes negative relation (not related)
Our RE solutions perform on par with existing SOTA systems, while being scalable and requiring less memory to train and test. The introduction of the extra data greatly improved results, enabling us to achieve SOTA on the benchmark datasets, as shown in Table 6. While the heavier BioBERT model outperformed our proposed RE model on the limited and imbalanced ADE dataset, the performance difference becomes negligible when more data is added to the training set. Sample output and visualization of the NER and RE results can be seen in Table 7 and Fig. 4.
5 Conclusion
Despite the growing need and explosion of useful data for pharmacovigilance, there is a severe deficiency of production-ready NLP systems that can process millions of records while being accurate and versatile. In this study we address the problem by introducing novel solutions for Classification, NER, and RE while leveraging the Spark ecosystem and focusing on accuracy, scalability, and versatility. We explain how we build a modular structure comprising different embedding types, a classification and an NER model, and two approaches for RE. We trained a custom GLoVe embeddings model on a domain-specific dataset, and compare its performance to SOTA BioBERT embeddings. We show through extensive testing that our text classification model, for deciding if a conversation includes an ADR, obtains new state-of-the-art accuracy on the CADEC dataset (86.69% F1 score). Our proposed NER architecture achieves SOTA results on multiple benchmark datasets. Namely, our proposed NER models obtain new state-of-the-art accuracy for ADR and Drug entity extraction on the ADE, CADEC, and SMM4H benchmark datasets (91.75%, 78.76%, and 83.41% F1 scores respectively). We then explain two different architectures for RE, one based on BioBERT and the other utilizing crafted features over an FCNN, test them individually, and show that a simpler RE architecture with bespoke features performs on par with the more sophisticated BERT solution. To improve our RE model, we built a new dataset by manual annotation, and achieved higher metrics on the RE test datasets. Furthermore, we performed speed benchmarks to compare the efficiency of two distinct embedding generation models to determine the ideal choice for deploying such solutions to process large quantities of data. In general, most pharmaceutical
companies run on-premise servers which are geared towards general computation and do not utilise hardware acceleration like GPUs for running heavy models. In such cases, where the infrastructure is not mature enough to handle heavy models, lightweight GloVe-based models are a compelling alternative to BERT-based models, as they offer comparable performance while being memory and CPU efficient. Finally, we implement all these algorithms in the Apache Spark ecosystem for scalability, shipped in a production-grade NLP library: Spark NLP.
Appendix A. Hyperparameter Settings
The following parameters provided best results on the classification development set (values within the parentheses represent the parameter ranges tested):
• Dropout rate: 0.2 (0.2, 0.7)
• Batch size: 8 (4, 256)
• Learning rate: 0.0003 (0.01, 0.0003)
• Epoch: 25–30 (10, 100)
• Optimizer: Adam
• Learning rate decay coefficient (po) (real learning rate = lr / (1 + po * epoch)): 0.005 (0.001, 0.01)
The following parameters provided best results on the NER development set (values within the parentheses represent the parameter ranges tested):
• LSTM state size: 200 (200, 250)
• Dropout rate: 0.5 (0.3, 0.7)
• Batch size: 8 (4, 256)
• Learning rate: 0.001 (0.01, 0.0003)
• Epoch: 25–35 (10, 100)
• Optimizer: Adam
• Learning rate decay coefficient (po) (real learning rate = lr / (1 + po * epoch), Smith [2018]): 0.005 (0.001, 0.01)
The following parameters provided best results on the RE development set (values within the parentheses represent the parameter ranges tested):
• Dropout rate: 0.5 (0.3, 0.7)
• Batch size: 8 (4, 256)
• Learning rate: 0.0001 (0.01, 0.0003)
• Epoch: 4 for BERT (1–10), 50 for FCNN (10–100)
B. Training Code Code for training an RE model is provided as a google colab notebook [33].
References 1. Leaman, R., Wojtulewicz, L., Sullivan, R., Skariah, A., Yang, J., & Gonzalez, G.(2010). Towards internet-age pharmacovigilance: Extracting adverse drug reactions from user posts in Health-Related social networks. In Proceedings of the 2010 Workshop on Biomedical Natural Language Processing (pp. 117–125). 2. Giacomini, K. M., Krauss, R. M., Roden, D. M., Eichelbaum, M., Hayden, M. R., & Nakamura, Y. (2007). When good drugs go bad. Nature, 446(7139), 975–977. 3. van der Hooft, C. S., Sturkenboom, M. C. J. M., van Grootheest, K., Kingma, H. J., & Stricker, B. H. C. H. (2006). Adverse drug reaction-related hospitalisations. Drug Safety, 29(2), 161–168. 4. Mammi, M., Citraro, R., Torcasio, G., Cusato, G., Palleria, C., & , di Paola, E. D. (2013). Pharmacovigilance in pharmaceutical companies: An overview. Journal of Pharmacology & Pharmacotherapeutics, 4(Suppl. 1), S 33. 5. Hazell, L., & Shakir, S. A. W. (2006). Under-Reporting of adverse drug reactions. Drug Safety, 29(5), 385–396. 6. Bollegala, D., Maskell, S., Sloane, R., Hajne, J., Pirmohamed, M., et al. (2018). Causality patterns for detecting adverse drug reactions from social media: Text mining approach. JMIR Public Health and Surveillance, 4(2), e8214. 7. Cinelli, M., Quattrociocchi, W., Galeazzi, A., Valensise, C. M., Brugnoli, E., Schmidt, A. L., Zola, P., Zollo, F., & Scala, A. (2020). The covid-19 social media infodemic. Scientific Reports, 10(1), 1–10. 8. Sloane, R., Osanlou, O., Lewis, D., Bollegala, D., Maskell, S., & Pirmohamed, M. (2015). Social media and pharmacovigilance: A review of the opportunities and challenges. British Journal of Clinical Pharmacology, 80(4), 910–920. 9. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). Attention is all you need. arxiv:1706.03762. 10. Devlin, J., Chang, M. W., Lee, K., & Toutanova, K. (2018). BERT: Pre-training of deep bidirectional transformers for language understanding. arxiv:1810.04805. 11. Lee, J., Yoon, W., Kim, S., Kim, D., Kim, S., So, C. H., & Kang, J. (2019). Biobert: A pre-trained biomedical language representation model for biomedical text mining. arxiv:1901.08746. 12. Mikolov, T., Chen, K., Corrado, G., & Dean, J. (2013). Efficient estimation of word representations in vector space. In Y. Bengio & Y. LeCun (Eds.), 1st International Conference on Learning Representations, ICLR 2013, Scottsdale, Arizona, USA, May 2–4, 2013, Workshop Track Proceedings. 13. Pennington, J., Socher, R., & Manning, C. (2014). GloVe: Global vectors for word representation. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP) (pp. 1532–1543). Doha, Qatar: Association for Computational Linguistics. 14. Weissenbacher, D., & Gonzalez-Hernandez, G. (Eds.). (2019). Proceedings of the Fourth Social Media Mining for Health Applications (#SMM4H) Workshop & Shared Task. Florence, Italy: Association for Computational Linguistics. 15. Henry, S., Buchan, K., Filannino, M., Stubbs, A., & Uzuner, O. (2020). 2018 n2c2 shared task on adverse drug events and medication extraction in electronic health records. Journal of the American Medical Informatics Association, 27(1), 3–12. 16. Gurulingappa, H., Rajput, A. M., Roberts, A., Fluck, J., Hofmann-Apitius, M., & Toldo, L. (2012). Development of a benchmark corpus to support the automatic extraction of drug-related
adverse effects from medical case reports. Journal of Biomedical Informatics, 45(5), 885–892. Text Mining and Natural Language Processing in Pharmacogenomics. 17. Karimi, S., Metke-Jimenez, A., Kemp, M., & Wang, C. (2015). Cadec: A corpus of adverse drug event annotations. Journal of Biomedical Informatics, 55, 73–81. 18. Huynh, T., He, Y., Willis, A., & Rueger, S. (2016). Adverse drug reaction classification with deep neural networks. In Proceedings of COLING 2016, the 26th International Conference on Computational Linguistics: Technical Papers (pp. 877–887). Osaka, Japan: The COLING 2016 Organizing Committee. 19. Kayastha, T., Gupta, P., & Bhattacharyya, P. (2021). BERT based adverse drug effect tweet classification. In Proceedings of the Sixth Social Media Mining for Health (#SMM4H) Workshop and Shared Task (pp. 88–90). Mexico City, Mexico: Association for Computational Linguistics. 20. Graves, A., & Schmidhuber, J. (2005). Framewise phoneme classification with bidirectional LSTM and other neural network architectures. Neural Networks, 18(5), 602–610, IJCNN 2005. 21. Stanovsky, G., Gruhl, D., & Mendes, P. (2017). Recognizing mentions of adverse drug reaction in social media using knowledge-infused recurrent models. In Proceedings of the 15th Conference of the European Chapter of the Association for Computational Linguistics: Volume 1, Long Papers (pp. 142–151). Valencia, Spain: Association for Computational Linguistics. 22. Ge, S., Wu, F., Wu, C., Qi, T., Huang, Y., & Xie, X. (2020). FedNER: Privacy-preserving medical named entity recognition with federated learning. arxiv:2003.09288. 23. Haq, H. U., Kocaman, V., & Talby, D. (2021). Deeper clinical document understanding using relation extraction. 24. Fundel, K., Kuffner, R., & Zimmer, R. (2006). RelEx-Relation extraction using dependency parse trees. Bioinformatics, 23(3), 365–371. 25. Soares, L. B., FitzGerald, N., Ling, J., & Kwiatkowski, T. (2019). Matching the blanks: Distributional similarity for relation learning. arxiv:1906.03158. 26. Crone, P. (2020). Deeper task-specificity improves joint entity and relation extraction. arxiv:2002.06424. 27. Chiu, J. P. C., & Nichols, E. (2015). Named entity recognition with bidirectional LSTM-CNNs. arxiv:1511.08308. 28. Reimers, N., & Gurevych, I. (2019). Sentence-BERT: Sentence embeddings using Siamese BERT-networks. arxiv:1908.10084. 29. Shivade, C. (2019). MedNLI-a natural language inference dataset for the clinical domain. 30. Alimova, I., & Tutubalina, E. (2019). Entity-level classification of adverse drug reaction: A comparative analysis of neural network models. Programming and Computer Software, 45, 439–447. 31. Yan, Z., Zhang, C., Fu, J., Zhang, Q., & Wei, Z. (2021). A partition filter network for joint entity and relation extraction. 32. Uzuner, O., South, B. R., Shen, S., & DuVall, S. L. (2011). 2010 i2b2/VA challenge on concepts, assertions, and relations in clinical text. Journal of the American Medical Informatics Association, 18(5), 552–556. 33. JSL. (2021). Training code for RE. https://github.com/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials. Retrieved December 23, 2021, from Notebook: Certification_Trainings/Healthcare/10.3.Clinical_RE_SparkNLP_Paper_Reproduce.ipynb.
A Graph-based Imputation Method for Sparse Medical Records Ramon Viñas, Xu Zheng, and Jer Hayes
Abstract Electronic Health Records (EHR) are extremely sparse. Only a small proportion of events (symptoms, diagnoses, and treatments) are observed in the lifetime of an individual. The high degree of missingness of EHR can be attributed to a large number of factors, including device failure, privacy concerns, or other unexpected reasons. Unfortunately, many traditional imputation methods are not well suited for highly sparse data and scale poorly to high dimensional datasets. In this paper, we propose a graph-based imputation method that is both robust to sparsity and to unreliable unmeasured events. Our approach compares favourably to several standard and state-of-the-art imputation methods in terms of performance and runtime. Moreover, results indicate that the model learns to embed different event types in a clinically meaningful way. Our work can facilitate the diagnosis of novel diseases based on the clinical history of past events, with the potential to increase our understanding of the landscape of comorbidities. Keywords Graph neural network · Data imputation · Data sparsity
R. Viñas University of Cambridge, Cambridge, UK e-mail: [email protected] X. Zheng (B) · J. Hayes Accenture Labs Dublin, Dublin, Ireland e-mail: [email protected] J. Hayes e-mail: [email protected] © The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 A. Shaban-Nejad et al. (eds.), Multimodal AI in Healthcare, Studies in Computational Intelligence 1060, https://doi.org/10.1007/978-3-031-14771-5_27
1 Introduction
Missing data is a pervasive challenge in medical domains: it can reduce the statistical power of a study and produce biased estimates, leading to invalid conclusions [3]. Previous research has demonstrated success in data imputation with both statistical and generative-model-based approaches. These traditional data imputation methods include univariate methods such as mean or median
multi-variate methods like k-NN imputation [7] and Multivariate Imputation by Chained Equations (MICE) [8], and deep learning methods such as autoencoders [9] and GAIN [10]. Although these methods show good performance on traditional datasets, they do not explicitly deal with the sparsity and imbalance problems that characterise many medical datasets. From a machine learning perspective, medical data can be very sparse: there are over 69,000 International Classification of Diseases (ICD) diagnosis codes, and even a patient with many conditions will only present a minuscule fraction of them. Together with high dimensionality, this sparsity often makes it infeasible to apply traditional imputation methods to medical datasets.

In recent years, Graph Neural Networks (GNNs) have gained increasing attention for modelling graph-structured datasets, including social networks, citation networks, and molecules. In this paper, we build on GRAPE [11] to develop a GNN-based imputation method for highly imbalanced medical datasets. Essentially, GRAPE represents data as a bipartite graph in which sample and feature nodes are connected according to the missingness pattern of the dataset; data imputation then corresponds to edge-level prediction for the missing links. To deal with the highly sparse and unary nature of EHRs (i.e. missing values represent either unmeasured events or negative outcomes), we employ a training scheme that balances positive and unmeasured edges in the graph. We perform experiments on EHRs from a subset of more than 72,000 diabetes patients from the IBM Explorys dataset. Our results show improved performance with respect to traditional and state-of-the-art methods that do not explicitly deal with the imbalanced and unary nature of the data. In addition, we demonstrate that the latent embeddings of medical concepts retrieved from our model cluster in a clinically meaningful way.
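To make the bipartite formulation concrete before the method is described, the following is a minimal, illustrative sketch (the toy matrix and all variable names are ours, not the authors'): only the observed entries of the patient-by-event matrix become edges, which is what keeps the representation compact for sparse records.

```python
import numpy as np
import torch
from scipy.sparse import coo_matrix

# Toy binary patient-by-event matrix: 4 patients, 5 event types.
# A 1 at (i, j) means event j was observed for patient i.
X = np.array([
    [1, 0, 0, 1, 0],
    [0, 0, 1, 0, 0],
    [1, 1, 0, 0, 0],
    [0, 0, 0, 0, 1],
])

# Only the observed (non-zero) entries are stored as edges of the
# bipartite patient-event graph; everything else is "unmeasured".
coo = coo_matrix(X)
edge_index = torch.tensor(np.vstack([coo.row, coo.col]), dtype=torch.long)

print(edge_index)
# tensor([[0, 0, 1, 2, 2, 3],
#         [0, 3, 2, 0, 1, 4]])
```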
2 Method

We propose a graph-based imputation method based on GRAPE [11] that is robust both to sparsity and to unreliable unmeasured events, and evaluate it on the IBM Explorys diabetes patient dataset.

Problem Definition
Let $X \in \{0, 1\}^{m \times n}$ be a medical dataset with $m$ patients and $n$ event types (e.g. diagnoses). For each patient $i$ and event type $j$, the entry $x_{ij}$ is one if event $j$ (e.g. familial Mediterranean fever) has been observed for patient $i$ and zero otherwise (i.e. unmeasured events with unknown outcomes). Here we assume that the entries $x_{ij}$ are binary (measured or unmeasured), but our approach can be extended to account for more complex variable types (e.g. sequences of event times). Our goal is to impute the unmeasured values of dataset $X$.

Bipartite graph representation
Following [11], we represent the data as a bipartite graph $G = \{V_p \cup V_e, E\}$, where $V_p = \{v^p_1, \ldots, v^p_m\}$ is the patient partition, $V_e = \{v^e_1, \ldots, v^e_n\}$ is the event partition, and $E = \{(v^p_i, v^e_j) \mid v^p_i \in V_p,\ v^e_j \in V_e,\ x_{ij} = 1\}$ is the set of edges connecting patients from $V_p$ to events from $V_e$ according to the measured entries of the dataset. This framework also allows attributed edges (e.g. with event times) and potential edge repetitions (i.e. a certain event can occur several times for the same patient).

Model
We employ a graph neural network to perform link prediction on the bipartite graph. This procedure can be divided into three steps. First, we initialise the node features of the patient nodes $V_p$ and the event nodes $V_e$. For event nodes, we use $d$-dimensional learnable embeddings as initial node values; the idea is that these weights, learnt through gradient descent, should summarise relevant properties of each event. For patient nodes, we initialise the node features with the available demographic information (e.g. age and sex) and project them to the $d$-dimensional space with a multi-layer perceptron. Importantly, this formulation allows transfer learning between sets of distinct patients.

Second, we perform message passing to compute latent node embeddings. Let $H_p = \{h^p_1, \ldots, h^p_m\}$ and $H_e = \{h^e_1, \ldots, h^e_n\}$ be the initial $d$-dimensional patient and event node embeddings, respectively. Let $N_p(i; E) = \{j \mid (v^p_i, v^e_j) \in E\}$ and $N_e(i; E) = \{j \mid (v^p_j, v^e_i) \in E\}$ be the sets of neighbours of nodes $v^p_i$ and $v^e_i$, respectively. We compute latent node embeddings $\hat{H}_p = \{\hat{h}^p_1, \ldots, \hat{h}^p_m\}$ and $\hat{H}_e = \{\hat{h}^e_1, \ldots, \hat{h}^e_n\}$ with separate GraphSAGE layers [2] as follows:

$$\hat{h}^p_i = W^p_1 h^p_i + W^p_2 \frac{1}{|N_p(i; E)|} \sum_{j \in N_p(i; E)} h^e_j \qquad (1)$$

$$\hat{h}^e_i = W^e_1 h^e_i + W^e_2 \frac{1}{|N_e(i; E)|} \sum_{j \in N_e(i; E)} h^p_j \qquad (2)$$

where $W^p_1$, $W^p_2$, $W^e_1$, $W^e_2$ are learnable weights. Optionally, several such layers can be stacked, interleaving non-linearities.

Finally, unmeasured values are imputed via link prediction. We compute the probability of an edge between nodes $v^p_i$ and $v^e_j$ as $p(x_{ij} = 1) = \mathrm{MLP}(\hat{h}^p_i, \hat{h}^e_j)$, where MLP is a multi-layer perceptron with a sigmoid output activation.

Optimisation
To deal with the highly unbalanced data, we employ a training scheme that balances positive and unmeasured edges. At each training iteration, we sample three sets of edges:
• Invisible edges. The set $E_{inv}$ contains $k$ positive edges from $E$ that are hidden from the model, where $k$ is a hyperparameter. The goal is to correctly predict the presence of these edges via link prediction.
Fig. 1 Distribution of per-event recalls and specificities (y-axis) by event frequencies (x-axis), computed from the graph-based model trained under two undersampling schemes. Version 1 (red): the number of events in the invisible and negative sets match (i.e. $|E_{inv}| = |E_{neg}|$). Version 2 (blue, ours): the event and patient frequencies are additionally preserved (i.e. $|N_p(i; E_{neg})| = |N_p(i; E_{inv})|$ and $|N_e(j; E_{neg})| = |N_e(j; E_{inv})|$ for any patient $i$ and event type $j$). Version 1 (red) introduces a sampling bias: common events are rarely imputed as negative, while rare events are rarely imputed as active
• Visible edges. The set $E_{vis}$ contains the remaining $|E| - k$ positive edges, which are seen during message passing. The visible and invisible subsets are disjoint and $E = E_{vis} \cup E_{inv}$.
• Negative edges. The set $E_{neg}$ contains $k$ edges that do not belong to $E$. The goal of the model is to correctly predict the absence of these edges. Importantly, the cardinalities of $E_{neg}$ and $E_{inv}$ match. Moreover, the set $E_{neg}$ preserves the patient and event frequencies, i.e. $|N_p(i; E_{neg})| = |N_p(i; E_{inv})|$ and $|N_e(j; E_{neg})| = |N_e(j; E_{inv})|$ for any patient $i$ and event type $j$. This effectively balances the model's exposure to positive and unmeasured edges for each patient and event type, preventing sampling biases (see Fig. 1).

We then optimise the model's parameters via gradient descent by minimising the binary cross-entropy:

$$L(E_{inv}, E_{neg}) = -\frac{1}{2k} \sum_{(v^p_i, v^e_j) \in E_{inv}} \log p(x_{ij} = 1) \qquad (3)$$

$$\qquad\qquad\qquad\; -\frac{1}{2k} \sum_{(v^p_i, v^e_j) \in E_{neg}} \log p(x_{ij} = 0) \qquad (4)$$
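To make the model and training scheme above concrete, here is a minimal, self-contained PyTorch sketch of the approach as we read it. It is not the authors' implementation: it uses plain tensor operations instead of PyTorch Geometric, the class and function names are ours, the decoder returns logits (trained with `binary_cross_entropy_with_logits`, which is equivalent for optimisation to a sigmoid output with binary cross-entropy), and negatives are simply truncated to match $|E_{inv}|$ rather than being degree-matched per patient and event type as in the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BipartiteSAGELayer(nn.Module):
    """One GraphSAGE-style layer on the patient-event bipartite graph (Eqs. 1-2)."""
    def __init__(self, dim):
        super().__init__()
        self.w1_p = nn.Linear(dim, dim, bias=False)
        self.w2_p = nn.Linear(dim, dim, bias=False)
        self.w1_e = nn.Linear(dim, dim, bias=False)
        self.w2_e = nn.Linear(dim, dim, bias=False)

    def forward(self, h_p, h_e, edge_index):
        # edge_index: (2, |E|); row 0 = patient indices, row 1 = event indices.
        pat, ev = edge_index
        ones = torch.ones(pat.size(0), dtype=h_p.dtype, device=h_p.device)
        # Mean-aggregate event embeddings into patients ...
        agg_p = torch.zeros_like(h_p).index_add_(0, pat, h_e[ev])
        deg_p = torch.zeros(h_p.size(0), dtype=h_p.dtype, device=h_p.device).index_add_(0, pat, ones).clamp(min=1)
        # ... and patient embeddings into events.
        agg_e = torch.zeros_like(h_e).index_add_(0, ev, h_p[pat])
        deg_e = torch.zeros(h_e.size(0), dtype=h_e.dtype, device=h_e.device).index_add_(0, ev, ones).clamp(min=1)
        new_p = self.w1_p(h_p) + self.w2_p(agg_p / deg_p.unsqueeze(-1))
        new_e = self.w1_e(h_e) + self.w2_e(agg_e / deg_e.unsqueeze(-1))
        return new_p, new_e

class GraphImputer(nn.Module):
    def __init__(self, n_events, demo_dim, dim=95, n_layers=3):
        super().__init__()
        self.event_emb = nn.Embedding(n_events, dim)   # learnable event-node features
        self.patient_proj = nn.Linear(demo_dim, dim)   # project demographics (age, sex, ...)
        self.layers = nn.ModuleList([BipartiteSAGELayer(dim) for _ in range(n_layers)])
        self.decoder = nn.Sequential(nn.Linear(2 * dim, 32), nn.ReLU(), nn.Linear(32, 1))

    def forward(self, demo, edge_index, query_edges):
        h_p = torch.relu(self.patient_proj(demo))
        h_e = self.event_emb.weight
        for layer in self.layers:
            h_p, h_e = layer(h_p, h_e, edge_index)
            h_p, h_e = torch.relu(h_p), torch.relu(h_e)
        # Link prediction: score p(x_ij = 1) for the queried (patient, event) pairs.
        pairs = torch.cat([h_p[query_edges[0]], h_e[query_edges[1]]], dim=-1)
        return self.decoder(pairs).squeeze(-1)          # logits; sigmoid is folded into the loss

def training_step(model, demo, pos_edges, neg_edges, mask_prob=0.2):
    """One balanced step: hide some positive edges and score them against an
    equally sized set of unmeasured (negative) edges."""
    k = pos_edges.size(1)
    invisible = torch.rand(k) < mask_prob               # E_inv, hidden from message passing
    e_vis, e_inv = pos_edges[:, ~invisible], pos_edges[:, invisible]
    e_neg = neg_edges[:, : e_inv.size(1)]               # |E_neg| = |E_inv|; the paper also matches degrees
    queries = torch.cat([e_inv, e_neg], dim=1)
    labels = torch.cat([torch.ones(e_inv.size(1)), torch.zeros(e_neg.size(1))])
    logits = model(demo, e_vis, queries)
    return F.binary_cross_entropy_with_logits(logits, labels)
```

In a full reproduction, the negative edges would additionally be drawn so that each patient and each event type contributes as many negatives as it contributes invisible positives, which is exactly the degree-preserving constraint that Fig. 1 compares against plain random undersampling.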
3 Results

Evaluation metrics
Medical records are extremely sparse: only a very small proportion of events (e.g. symptoms, diagnoses, treatments) are observed during the lifetime of an individual. At the same time, we cannot always be certain that unobserved events have not occurred, because in practice we can only measure a small fraction of them (e.g. it is infeasible to test someone for all known diseases); this is precisely why we want to impute missing values. Conversely, observed events have happened in reality with high confidence (e.g. chemotherapy for lung cancer). In this paper, we therefore evaluate imputation performance with sensitivity, specificity, and balanced accuracy. In contrast to accuracy (uninformative in sparse scenarios) and precision (sensitive to unreliable false positives), these metrics are robust both to sparsity and to unreliable unmeasured events.

Dataset
Patients with diabetes are sampled from the IBM Explorys database. We create a bipartite graph of patients and events using diagnosis-related events. We filter out events that appear in fewer than 0.1% of the records, resulting in 3284 unique events. We split patients into disjoint train (70%) and test (30%) sets, yielding 72,801 and 30,334 unique train and test patients, respectively. For the test patients, we mask out 30% of the observed values and use them to evaluate the performance of all models. We provide the age and sex of the patients as demographic input to the models. When represented as a matrix, the dataset is highly sparse, with 98.3% of entries being zero.

Baseline models
We compare our model to several baseline methods, including k-NN imputation [7], Generative Adversarial Imputation Networks (GAIN) [10], and Denoising Autoencoders (DAE) [9]. As these baseline models can only handle tabular data, we represent patient records as a binary matrix whose rows correspond to patients and whose columns correspond to unique diagnosis codes; entry (i, j) is one if the j-th diagnosis has been observed for the i-th patient and zero otherwise. The DAE and GAIN are both optimised via the reconstruction error on the observed values (plus an adversarial term for the missing values in GAIN). Because the dataset is highly imbalanced, both models are by default biased towards the majority class, i.e. zero for each feature. Additionally, GAIN cannot readily deal with the unary nature of the data: missing positive values cannot be distinguished from actual zeros (i.e. events with a negative outcome), and both are treated identically in the mask vector. To address these issues, we adopt an undersampling mechanism that closely mimics the training scheme of the GNN model. At each training iteration, we randomly sample k negative values (i.e. unmeasured events), where k is the total number of positive values, and treat them as observable. The remaining entries are masked out (and form the mask vector for GAIN) and the methods are optimised by minimising their respective loss functions.
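As a rough illustration of the per-event-type evaluation described above, the following sketch (our own, run on synthetic inputs) computes sensitivity, specificity, and balanced accuracy for each event column given imputed probabilities and a cutoff:

```python
import numpy as np

def per_event_metrics(y_true, y_prob, cutoff=0.5):
    """y_true, y_prob: (patients, events) arrays of {0,1} labels and imputed probabilities."""
    y_pred = (y_prob > cutoff).astype(int)
    tp = ((y_pred == 1) & (y_true == 1)).sum(axis=0)
    fn = ((y_pred == 0) & (y_true == 1)).sum(axis=0)
    tn = ((y_pred == 0) & (y_true == 0)).sum(axis=0)
    fp = ((y_pred == 1) & (y_true == 0)).sum(axis=0)
    sensitivity = tp / np.maximum(tp + fn, 1)   # recall on measured events
    specificity = tn / np.maximum(tn + fp, 1)   # recall on unmeasured events
    balanced_acc = (sensitivity + specificity) / 2
    return sensitivity, specificity, balanced_acc

rng = np.random.default_rng(0)
y_true = (rng.random((1000, 50)) < 0.02).astype(int)              # sparse synthetic ground truth
y_prob = np.clip(y_true * 0.7 + rng.random((1000, 50)) * 0.4, 0, 1)
sens, spec, bacc = per_event_metrics(y_true, y_prob)
print(f"sensitivity {sens.mean():.2f} ± {sens.std():.2f}, "
      f"specificity {spec.mean():.2f} ± {spec.std():.2f}, "
      f"balanced acc. {bacc.mean():.2f} ± {bacc.std():.2f}")
```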
Hyperparameters
We use node embeddings of dimension d = 95. We initialise them with the right singular vectors of the train dataset computed via singular value decomposition (SVD), which yields higher validation scores in our experiments. The graph neural network architecture consists of 3 GraphSAGE layers with node embeddings of dimension d = 95 and rectified linear unit (ReLU) activations. The final multi-layer perceptron comprises one hidden layer with 32 units followed by ReLU. We optimise the model with the Adam optimiser [4] and a learning rate of 0.0066. At each training iteration, we randomly sample the invisible set of edges $E_{inv}$ from a binomial distribution B(1, p) with probability p = 0.2, that is, on average 20% of the training edges are masked out. The set $E_{neg}$ of unmeasured edges is then sampled as described in the optimisation section, preserving the cardinality of $E_{inv}$. We implement the model in PyTorch [6] and PyTorch Geometric [1].
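A minimal sketch of the SVD-based embedding initialisation just described (our own illustration; it assumes the training data fits in memory as a dense matrix, and variable names are ours):

```python
import numpy as np

def init_event_embeddings(X_train, d=95):
    """Initialise event-node embeddings with the top-d right singular vectors of the train matrix."""
    # X_train: (n_patients, n_events) binary matrix; Vt has shape (min(m, n), n_events).
    _, _, Vt = np.linalg.svd(X_train, full_matrices=False)
    return Vt[:d].T                       # (n_events, d), one embedding per event type

X_train = (np.random.default_rng(0).random((200, 120)) < 0.02).astype(float)
emb = init_event_embeddings(X_train, d=95)
print(emb.shape)                          # (120, 95)
```

For a matrix as large and sparse as the full patient-event table, a truncated decomposition (e.g. scipy.sparse.linalg.svds) would be the more practical route.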
Results
Table 1 shows the test imputation scores. We compute the sensitivity, specificity, and balanced accuracy for each event type and report the mean and standard deviation of these metrics. Importantly, conventional imputation methods such as k-NN imputation have no built-in mechanism to deal with the inherent characteristics of medical records (i.e. sparsity and unreliability of unmeasured events), so their calibration is unclear. For these methods, we analyse performance under two different thresholds: a 0.5 cutoff (i.e. for patient $i$, event $j$ is imputed as measured if $p(x_{ij} = 1) > 0.5$) and the per-event-type frequencies of measured values in the train set (i.e. event $j$ is imputed as measured if $p(x_{ij} = 1) > \frac{1}{m}\sum_{i=1}^{m} x_{ij}$, where $m$ is the number of train patients). Despite the data sparsity, the GNN-based method attains highly balanced predictions (sensitivity = 0.78 ± 0.12, specificity = 0.79 ± 0.09) with the default 0.5 cutoff, outperforming the other baselines by a large margin in terms of balanced accuracy (the arithmetic mean of sensitivity and specificity). We attribute this to the training scheme, which effectively balances the model's exposure to positive and unmeasured edges for each event type, yielding a well-calibrated model (see Fig. 1 for a comparison with a random undersampling method).

Table 1 Imputation results. We report the mean and standard deviation of the per-event-type evaluation metrics for different cutoff values. For the threshold Avg., we use the per-event-type frequencies of observed values in the train set as cutoffs. The best scores are highlighted in bold.

Method | Cutoff | Sensitivity | Specificity | Balanced Acc. | Runtime
10-NN  | 0.5    | 0.08 ± 0.20 | 0.98 ± 0.10 | 0.53 ± 0.07   | 43.29 h
10-NN  | Avg.   | 0.63 ± 0.29 | 0.63 ± 0.26 | 0.63 ± 0.09   |
DAE    | 0.5    | 0.39 ± 0.36 | 0.85 ± 0.26 | 0.62 ± 0.12   | 0.25 h
DAE    | Avg.   | 0.98 ± 0.07 | 0.17 ± 0.25 | 0.57 ± 0.12   |
GAIN   | 0.5    | 0.36 ± 0.29 | 0.83 ± 0.25 | 0.59 ± 0.11   | 0.25 h
GAIN   | Avg.   | 0.45 ± 0.31 | 0.73 ± 0.29 | 0.59 ± 0.11   |
Ours   | 0.5    | 0.78 ± 0.12 | 0.79 ± 0.09 | 0.79 ± 0.09   | 1.09 h
Fig. 2 Low-dimensional representation of the latent event embeddings. Each point corresponds to a different event type (e.g. malnutrition-related diabetes mellitus). We use spectral clustering to cluster events into 10 clusters and analyse their composition in Fig. 3. Overall, we find that the model is able to identify clusters of clinically-related events
The proposed approach is also highly scalable and significantly faster than traditional methods.

We further study the behaviour of our model by inspecting the latent event embeddings obtained after message passing. We employ UMAP [5] to project these embeddings into a two-dimensional space and apply spectral clustering (N = 10 clusters) to group events (see Fig. 2). Figure 3 depicts the event-type composition of the clusters. Cluster 0 mostly consists of mental and behavioural disorders (65%). Events in cluster 2 involve injuries and poisoning (69%) and external causes of morbidity and mortality (31%). Cluster 3 is composed exclusively of diseases of the ear and mastoid process. Clusters 4 and 6 are both related to pregnancy, childbirth, and the puerperium. The majority of events in cluster 5 are diseases of the skin and subcutaneous tissue (55%), while most events in cluster 9 are diseases of the genitourinary system. The remaining clusters are more heterogeneous, and understanding their semantics would likely require a finer-grained analysis (e.g. cluster 1 consists of events related to contact with health services, 23%, and diseases of the digestive system, 21%, among others). Overall, this analysis shows that the model groups clinically related event types in the latent space, with high cluster purity.
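The embedding analysis above can be reproduced, in spirit, with a short sketch along these lines (our own; it assumes the learned event embeddings are available as a NumPy array and uses the umap-learn and scikit-learn packages, whose parameters here are illustrative rather than the authors' exact settings):

```python
import numpy as np
import umap
from sklearn.cluster import SpectralClustering

# event_emb: (n_events, d) latent event embeddings taken from the trained model,
# e.g. model.event_emb.weight.detach().cpu().numpy() in the sketch above.
event_emb = np.random.default_rng(0).normal(size=(3284, 95))   # placeholder

# Project to 2D for visualisation ...
coords = umap.UMAP(n_components=2, random_state=0).fit_transform(event_emb)

# ... and group events into 10 clusters for the composition analysis (Figs. 2-3).
labels = SpectralClustering(n_clusters=10, affinity="nearest_neighbors",
                            random_state=0).fit_predict(event_emb)
print(coords.shape, np.bincount(labels))
```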
Fig. 3 Composition of low-dimensional event clusters. We inspect the event-type composition of clusters identified via spectral clustering on the latent event embeddings of our imputation model (see Fig. 2). We observe high cluster purity, that is, several clusters consist mostly of events of the same type. For example, 65% of events in cluster 0 are mental and behavioural disorders, while most events in clusters 4 and 6 are related to pregnancy and childbirth. Other clusters are more heterogeneous, and understanding their semantics would likely require a finer-grained analysis
4 Conclusion

In this paper, we have studied the problem of imputing missing data in medical records. These datasets are highly sparse, and unmeasured events are unreliable (i.e. the fact that a specific event has not been observed for a certain patient does not entail that it has not occurred in reality). Unfortunately, traditional imputation methods are not well suited to this scenario. To address this challenge, we have proposed a graph-based deep learning model that is both scalable and effective at imputing missing values in sparse regimes. The proposed model is easy to use and well-calibrated by default. Furthermore, our approach compares favourably to existing methods in terms of performance and runtime. This work can facilitate the diagnosis of new events and shed light on the landscape of comorbidities.
References
1. Fey, M., & Lenssen, J. E. (2019). Fast graph representation learning with PyTorch Geometric. In ICLR Workshop on Representation Learning on Graphs and Manifolds.
2. Hamilton, W. L., Ying, R., & Leskovec, J. (2018). Inductive representation learning on large graphs.
3. Hyun, K. (2013). The prevention and handling of the missing data. Korean Journal of Anesthesiology, 64(5), 402–406.
4. Kingma, D. P., & Ba, J. (2014). Adam: A method for stochastic optimization. arXiv:1412.6980.
5. McInnes, L., Healy, J., & Melville, J. (2018). UMAP: Uniform manifold approximation and projection for dimension reduction. arXiv:1802.03426.
6. Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., Desmaison, A., Kopf, A., Yang, E., DeVito, Z., Raison, M., Tejani, A., Chilamkurthy, S., Steiner, B., Fang, L., Bai, J., & Chintala, S. (2019). PyTorch: An imperative style, high-performance deep learning library. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox & R. Garnett (Eds.), Advances in Neural Information Processing Systems (Vol. 32, pp. 8024–8035). Curran Associates, Inc.
7. Troyanskaya, O., Cantor, M., Sherlock, G., Brown, P., Hastie, T., Tibshirani, R., Botstein, D., & Altman, R. B. (2001). Missing value estimation methods for DNA microarrays. Bioinformatics, 17(6), 520–525.
8. Van Buuren, S., & Groothuis-Oudshoorn, K. (2011). Mice: Multivariate imputation by chained equations in R. Journal of Statistical Software, 45, 1–67.
9. Vincent, P., Larochelle, H., Bengio, Y., & Manzagol, P.-A. (2008). Extracting and composing robust features with denoising autoencoders. In Proceedings of the 25th International Conference on Machine Learning (pp. 1096–1103).
10. Yoon, J., Jordon, J., & Schaar, M. (2018). GAIN: Missing data imputation using generative adversarial nets. In International Conference on Machine Learning (pp. 5689–5698). PMLR.
11. You, J., Ma, X., Ding, D., Kochenderfer, M., & Leskovec, J. (2020). Handling missing data with graph representation learning. NeurIPS.
Using Nursing Notes to Predict Length of Stay in ICU for Critically Ill Patients Sudeshna Jana, Tirthankar Dasgupta, and Lipika Dey
Abstract Managing resource-critical facilities like Intensive Care Units (ICUs) is an important task for hospital management. Predicting how long a patient is going to stay in the ICU is an important problem for managers, and several attempts have been made to solve it using different types of historical clinical data. While a number of studies have deployed classification models that use structured clinical variables, recent advances in Natural Language Processing have opened up the possibility of using unstructured text data such as nursing notes and discharge summaries for prediction. In this work, we propose CNN- and LSTM-based prediction networks along with transformer-based language models for representing the notes. The proposed model predicts length of stay with considerably higher accuracy than existing models. The dataset used for the experiments is MIMIC, an anonymized dataset that contains detailed records of around 40,000 patients, most of whom were critically ill. We use the first day's nursing notes for prediction, since these provide the most relevant and valuable input for planning.

Keywords ICU Length of stay · Nursing note · Clinical bioBERT · Severity of illness score
1 Introduction

Intensive Care Units (ICUs) in hospitals are specially equipped to continuously monitor critically ill patients and provide them with care and support. Given the cost of setting up these facilities, every hospital has only a limited number of ICU beds. Consequently, hospital management needs to do careful planning to ensure that
while the resources are well-utilized, non-availability of the facilities does not lead to fatalities. Predictive planning based on available data has proved to be of great help in this regard. Admission records of past patients containing details about their illness, results of diagnostic tests, treatments, and nursing notes, along with the number of days spent in the ICU and other wards, can help in building these predictive models. In general, this kind of data is difficult to obtain due to privacy regulations and security reasons. In this work, we have utilized the publicly available Medical Information Mart for Intensive Care (MIMIC) database [11] to build a predictive model for determining the length of ICU stay (henceforth referred to as ICU LOS) for a patient. This data has been anonymized and made available specifically for research purposes.

In the past, researchers working on the MIMIC dataset have used a range of predictor variables from the available health data to predict the length of ICU stay [6, 7]. These include both numeric data like body temperature, blood pressure, and heart rate, and categorical data like diagnosis category and admission type. The model presented in this paper additionally utilizes the nursing notes, which are completely unstructured in nature. During hospitalization, a nursing note records a patient's condition, both physical and psychological, as assessed by a nurse, and can thus provide information about a patient beyond the physiological parameters measured by instruments or radiology reports. Nursing notes can also provide critical information about a patient's response to treatment based on behavioral descriptions documented by the caregiver. The notes are therefore a rich source of information for predicting the status of a patient, and consequently the need for critical care, if any. Linguistic expressions like "extensive cardiac hx", "slightly tachypneic", or "severe scrotal infection" add a dimension of human assessment that cannot be captured through numbers alone, but that can be important when distinguishing between two otherwise similar patients who are responding differently to treatment.

Our objective is to develop a model that can use the first nursing note, prepared at the time of admission to the ICU, to predict whether the patient's ICU stay will be short or long, where a stay shorter than the median value of the dataset is termed short and is otherwise termed long; for this dataset, the median value is 4 days. The nursing notes, which are written as free unstructured text, are represented using transformer-based language models; we have worked with BlueBERT [15] and Clinical BioBERT [3] to build the representation. Additionally, we have observed that using a Term Frequency-Inverse Document Frequency (TF-IDF) feature vector can substantially improve prediction performance, owing to the inherent capability of the TF-IDF vector to capture common and distinctive features simultaneously. We have experimented with Convolutional Neural Network (CNN) [9] and Long Short-Term Memory (LSTM) [8] based architectures to build the predictive model.

The rest of the paper is organized as follows. In Sect. 2, we present a brief overview of related work in this area. Before presenting the proposed predictive architecture, details of the MIMIC dataset are provided in Sect. 3.
Detailed discussions of the architecture and results are presented in Sects. 4 and 5, respectively. Compared with past prediction results obtained on this dataset, the performance of the proposed model is significantly higher on all evaluation indices, including accuracy, AUC-ROC, and kappa score [4].
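As a rough sketch of how a first-day nursing note can be turned into the two representations discussed above (transformer embeddings plus TF-IDF features), one might write something like the following; the Hugging Face checkpoint name, the [CLS]-pooling choice, and the TF-IDF settings are illustrative assumptions rather than the authors' exact configuration:

```python
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoModel, AutoTokenizer

notes = ["Pt slightly tachypneic overnight, extensive cardiac hx, remains on pressors.",
         "Afebrile, tolerating diet, ambulating with assistance, plan for floor transfer."]

# Contextual embedding of each note with a clinical BERT variant (assumed checkpoint).
tok = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
bert = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
with torch.no_grad():
    batch = tok(notes, padding=True, truncation=True, max_length=512, return_tensors="pt")
    cls_emb = bert(**batch).last_hidden_state[:, 0, :]      # (n_notes, 768) [CLS] vectors

# Sparse TF-IDF features over the same notes.
tfidf = TfidfVectorizer(max_features=5000, stop_words="english").fit_transform(notes)

# The two views can then be concatenated and fed to a CNN/LSTM classifier head.
features = torch.cat([cls_emb, torch.tensor(tfidf.toarray(), dtype=torch.float32)], dim=1)
print(features.shape)
```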
2 Related Works

In this section, we present a review of related work in the area. It may be noted that while most of these studies predict the length of ICU stay, the assumed time at which the prediction is made varies across papers. Although most of the work reported here has used the MIMIC dataset, a few studies, like [17], have worked on other datasets. One major aspect that distinguishes the models from each other is the set of predictor variables used, and the choice of prediction method is often guided by the choice of predictor features.

In [6], a neural network based model is used to predict the length of the remaining hospital stay for a patient at the time of exit from the ICU. The study used several medical attributes, such as patient demographics, CPT events, services, procedures, and diagnoses, of 31,018 patients from the MIMIC database. In [7], the authors proposed a channel-wise LSTM model with multitask training for predicting mortality along with a forecast of the length of stay in the ICU. In that work, the forecast was of the remaining time to be spent in the ICU, made at each hour of stay, with the remaining time assigned to one of 10 classes. Predictions were generated from 17 clinical variables, such as capillary refill rate, diastolic blood pressure, fraction of inspired oxygen, Glasgow Coma Scale, glucose, heart rate, and blood pressure, taken from the MIMIC database. In [17], a deep learning architecture combining temporal convolution and pointwise convolution was proposed to predict the length of ICU stay. This work used a different dataset, the eICU critical care dataset [16], which contains records of 118,534 unique patients. Patient details in this dataset include features like gender, age, hour of admission, height, weight, ethnicity, unit admit source, unit visit number, unit stay, number of beds, and physician speciality, all of which are structured in nature. In [2], a study reported prediction of length of ICU stay and mortality using patients' vital signs such as heart rate, blood pressure, temperature, and respiratory rate, together with age, gender, height, and weight. This work was done on the MIMIC database and presented results obtained with traditional classification models like SVM, decision trees, and random forests for predicting length of stay as a binary-valued variable. In [18], models for predicting mortality, severity, and length of stay were proposed for a set of 2224 sepsis patients admitted to the ICU of Peking Union Medical College Hospital over a period of three years. Three machine learning classification models, namely logistic regression, random forest, and XGBoost, were tested to classify patients into two classes indicating a long or short ICU
stay. The models used clinical parameters such as age, P(v-a)CO2/C(a-v)O2, SO2, oxygenation index, white blood cell count, oxygen concentration, heart rate (bpm), and temperature from the first 6 h in the ICU. It is worth mentioning that none of the above works used textual data for prediction.

Recently, in [1], the authors used discharge summary notes to predict multiple clinical outcomes along with the length of ICU stay. In the MIMIC database, the discharge summary aggregates all clinical notes generated during a patient's hospital stay. In that work, a separate note was generated from the discharge summary after some filtering and used for prediction. The authors developed a pretrained model, called CORe, on top of BioBERT [14] weights and classified patients into four categories based on the actual length of ICU stay: under 3 days, 3 to 7 days, 1 week to 2 weeks, and more than 2 weeks.

Other than this last study, most prior work has used only clinical parameters for predicting the length of stay in ICU; the richness of nursing notes has not been exploited for prediction. In this work, we report higher prediction accuracies using the nursing notes.
3 Dataset

As our primary data source, we have used the MIMIC-III (v1.4) database [11], which contains the details of over forty thousand patients who stayed in critical care units of the Beth Israel Deaconess Medical Center between 2001 and 2012. This publicly available dataset covers 58,976 distinct hospital admissions. An overview of the dataset is given in Table 1.

In the present study, for each ICU admission we consider all the nursing notes from the first 24 h of the ICU stay and combine them into a single note. Within a single hospital admission, a patient could have stayed in the ICU more than once; we therefore keep only admissions with a unique ICU stay in which the patient was admitted to the ICU within 2 days of hospital admission, and we exclude admissions whose total length of ICU stay is less than 1 day. This results in a total of 22,789 unique admission entries. Table 2 gives a detailed description of the resulting dataset.

Table 1 MIMIC data statistics

No. of patients | 46,520
No. of hospital admissions | 58,976
No. of diseases | 15,692
No. of ICU admissions | 61,532
No. of patients admitted to the ICU | 46,476
No. of nursing notes | 1,046,053
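To make the cohort construction above concrete, a preprocessing sketch along the following lines could be used. The file and column names follow the usual MIMIC-III CSV layout (NOTEEVENTS, ICUSTAYS, ADMISSIONS) but should be treated as assumptions, and the note-merging logic is a simplification of the authors' pipeline:

```python
import pandas as pd

# Assumed MIMIC-III tables/columns; adjust to your local schema.
notes = pd.read_csv("NOTEEVENTS.csv", parse_dates=["CHARTTIME"])
icu = pd.read_csv("ICUSTAYS.csv", parse_dates=["INTIME"])
adm = pd.read_csv("ADMISSIONS.csv", parse_dates=["ADMITTIME"])

# Keep admissions with exactly one ICU stay, ICU entry within 2 days of hospital admission,
# and a total ICU length of stay of at least 1 day.
icu = icu[icu.groupby("HADM_ID")["ICUSTAY_ID"].transform("count") == 1]
icu = icu.merge(adm[["HADM_ID", "ADMITTIME"]], on="HADM_ID")
icu = icu[(icu["INTIME"] - icu["ADMITTIME"]).dt.days <= 2]
icu = icu[icu["LOS"] >= 1]

# Concatenate the nursing notes charted in the first 24 h of each ICU stay into a single note.
nursing = notes[notes["CATEGORY"].isin(["Nursing", "Nursing/other"])]
merged = nursing.merge(icu[["HADM_ID", "INTIME", "LOS"]], on="HADM_ID")
first_day = merged[(merged["CHARTTIME"] >= merged["INTIME"]) &
                   (merged["CHARTTIME"] <= merged["INTIME"] + pd.Timedelta(hours=24))]
day1_notes = first_day.groupby("HADM_ID")["TEXT"].apply(" ".join).reset_index()

# Binary target: "Short" stay if LOS is at most the median (4 days in this cohort), else "Long".
labels = icu.set_index("HADM_ID")["LOS"].le(icu["LOS"].median()).astype(int)
print(len(day1_notes), labels.mean())
```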
Table 2 Summary of the dataset used for ICU stay classification

No. of ICU admissions | 22,789
No. of notes | 22,789
Max no. of tokens in a note | 4986
Avg. no. of tokens in a note | 1937
No. of classes | 2 ("Short"; ICU LOS
No. of data in class 'Short' |
No. of data in class 'Long' |