LNCS 14220
Hayit Greenspan · Anant Madabhushi · Parvin Mousavi · Septimiu Salcudean · James Duncan · Tanveer Syeda-Mahmood · Russell Taylor (Eds.)
Medical Image Computing and Computer Assisted Intervention – MICCAI 2023
26th International Conference, Vancouver, BC, Canada, October 8–12, 2023
Proceedings, Part I
Lecture Notes in Computer Science 14220
Founding Editors
Gerhard Goos
Juris Hartmanis
Editorial Board Members
Elisa Bertino (Purdue University, West Lafayette, IN, USA)
Wen Gao (Peking University, Beijing, China)
Bernhard Steffen (TU Dortmund University, Dortmund, Germany)
Moti Yung (Columbia University, New York, NY, USA)
The series Lecture Notes in Computer Science (LNCS), including its subseries Lecture Notes in Artificial Intelligence (LNAI) and Lecture Notes in Bioinformatics (LNBI), has established itself as a medium for the publication of new developments in computer science and information technology research, teaching, and education. LNCS enjoys close cooperation with the computer science R & D community, counts many renowned academics among its volume editors and paper authors, and collaborates with prestigious societies. Its mission is to serve this international community by providing an invaluable service, mainly focused on the publication of conference and workshop proceedings and postproceedings. LNCS commenced publication in 1973.
Editors
Hayit Greenspan (Icahn School of Medicine, Mount Sinai, NYC, NY, USA; Tel Aviv University, Tel Aviv, Israel)
Anant Madabhushi (Emory University, Atlanta, GA, USA)
Parvin Mousavi (Queen’s University, Kingston, ON, Canada)
Septimiu Salcudean (The University of British Columbia, Vancouver, BC, Canada)
James Duncan (Yale University, New Haven, CT, USA)
Tanveer Syeda-Mahmood (IBM Research, San Jose, CA, USA)
Russell Taylor (Johns Hopkins University, Baltimore, MD, USA)
ISSN 0302-9743 ISSN 1611-3349 (electronic)
Lecture Notes in Computer Science
ISBN 978-3-031-43906-3 ISBN 978-3-031-43907-0 (eBook)
https://doi.org/10.1007/978-3-031-43907-0
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
Chapter “Unsupervised Domain Transfer with Conditional Invertible Neural Networks” is licensed under the terms of the Creative Commons Attribution 4.0 International License (http://creativecommons.org/licenses/by/4.0/). For further details see license information in the chapter.
This work is subject to copyright. All rights are reserved by the Publisher, whether the whole or part of the material is concerned, specifically the rights of translation, reprinting, reuse of illustrations, recitation, broadcasting, reproduction on microfilms or in any other physical way, and transmission or information storage and retrieval, electronic adaptation, computer software, or by similar or dissimilar methodology now known or hereafter developed.
The use of general descriptive names, registered names, trademarks, service marks, etc. in this publication does not imply, even in the absence of a specific statement, that such names are exempt from the relevant protective laws and regulations and therefore free for general use.
The publisher, the authors, and the editors are safe to assume that the advice and information in this book are believed to be true and accurate at the date of publication. Neither the publisher nor the authors or the editors give a warranty, expressed or implied, with respect to the material contained herein or for any errors or omissions that may have been made. The publisher remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
This Springer imprint is published by the registered company Springer Nature Switzerland AG.
The registered company address is: Gewerbestrasse 11, 6330 Cham, Switzerland.
Paper in this product is recyclable.
Preface
We are pleased to present the proceedings of the 26th International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI). After several difficult years of virtual conferences, this edition was held in a mainly in-person format, with a hybrid component, at the Vancouver Convention Centre in Vancouver, BC, Canada, October 8–12, 2023. The conference featured 33 physical workshops, 15 online workshops, 15 tutorials, and 29 challenges, held on October 8 and October 12. Co-located with the conference was the 3rd Conference on Clinical Translation on Medical Image Computing and Computer-Assisted Intervention (CLINICCAI), held on October 10.
MICCAI 2023 received the largest number of submissions so far, an approximately 30% increase compared to 2022. We received 2365 full submissions, of which 2250 were subjected to full review. To keep the acceptance ratio at around 32%, as in previous years, the number of accepted papers increased correspondingly, to 730, of which 68 were presented orally and the rest as posters.
These papers comprise ten volumes of Lecture Notes in Computer Science (LNCS) proceedings as follows:
• Part I, LNCS Volume 14220: Machine Learning with Limited Supervision and Machine Learning – Transfer Learning
• Part II, LNCS Volume 14221: Machine Learning – Learning Strategies and Machine Learning – Explainability, Bias, and Uncertainty I
• Part III, LNCS Volume 14222: Machine Learning – Explainability, Bias, and Uncertainty II and Image Segmentation I
• Part IV, LNCS Volume 14223: Image Segmentation II
• Part V, LNCS Volume 14224: Computer-Aided Diagnosis I
• Part VI, LNCS Volume 14225: Computer-Aided Diagnosis II and Computational Pathology
• Part VII, LNCS Volume 14226: Clinical Applications – Abdomen, Clinical Applications – Breast, Clinical Applications – Cardiac, Clinical Applications – Dermatology, Clinical Applications – Fetal Imaging, Clinical Applications – Lung, Clinical Applications – Musculoskeletal, Clinical Applications – Oncology, Clinical Applications – Ophthalmology, and Clinical Applications – Vascular
• Part VIII, LNCS Volume 14227: Clinical Applications – Neuroimaging and Microscopy
• Part IX, LNCS Volume 14228: Image-Guided Intervention, Surgical Planning, and Data Science
• Part X, LNCS Volume 14229: Image Reconstruction and Image Registration
The papers for the proceedings were selected after a rigorous double-blind peer-review process. The MICCAI 2023 Program Committee consisted of 133 area chairs and over 1600 reviewers, with representation from several countries across all major continents. It also maintained a gender balance, with 31% of its scientists self-identifying as women. With an increase in the number of area chairs and reviewers, the reviewing load on the experts was reduced this year, to 16–18 papers per area chair and about 4–6 papers per reviewer. Based on the double-blind reviews, area chairs’ recommendations, and program chairs’ global adjustments, 308 papers (14%) were provisionally accepted, 1196 papers (53%) were provisionally rejected, and 746 papers (33%) proceeded to the rebuttal stage.
As in previous years, Microsoft’s Conference Management Toolkit (CMT) was used for paper management and for organizing the overall review process. Similarly, the Toronto Paper Matching System (TPMS) was employed to ensure that knowledgeable experts were assigned to review appropriate papers. Area chairs and reviewers were selected following public calls to the community and were vetted by the program chairs.
Among the new features this year was the emphasis on clinical translation, moving Medical Image Computing (MIC) and Computer-Assisted Interventions (CAI) research from theory to practice by featuring two clinical translational sessions reflecting the real-world impact of the field on clinical workflows and clinical evaluations. For the first time, clinicians were appointed as Clinical Chairs to select papers for the clinical translational sessions. The philosophy behind the dedicated clinical translational sessions was to maintain the high scientific and technical standard of MICCAI papers in terms of methodology development, while at the same time showcasing a strong focus on clinical applications. This was an opportunity to expose the MICCAI community to clinical challenges and to foster ideation of novel solutions to address these unmet needs. Consequently, a new “Clinical Applications” category was introduced at paper submission, alongside MIC and CAI, for authors to self-declare.
MICCAI 2023, for the first time in its history, also featured dual parallel tracks, which allowed the conference to keep the same proportion of oral presentations as in previous years despite the 30% increase in submitted and accepted papers. We also introduced two new sessions this year: one focusing on young and emerging scientists through their Ph.D. thesis presentations, and another in which experienced researchers commented on the state of the field in a fireside chat format.
The organization of the final program, grouping the papers into topics and sessions, was aided by the latest advancements in generative AI models. Specifically, OpenAI’s GPT-4 large language model was used to group the papers into initial topics, which were then manually curated and organized. This resulted in fresh session titles that better reflect the technical advancements of our field.
Although not reflected in the proceedings, the conference also benefited from keynote talks by experts in their respective fields, including Turing Award winner Yann LeCun and leading experts Jocelyne Troccaz and Mihaela van der Schaar.
We extend our sincere gratitude to everyone who contributed to the success of MICCAI 2023 and the quality of its proceedings. In particular, we would like to express our profound thanks to the MICCAI Submission System Manager Kitty Wong, whose meticulous support throughout the paper submission, review, program planning, and proceedings preparation process was invaluable. We are especially appreciative of the effort and dedication of our Satellite Events Chair, Bennett Landman, who tirelessly coordinated the organization of over 90 satellite events consisting of workshops, challenges, and tutorials. Our workshop chairs Hongzhi Wang and Alistair Young, tutorial chairs Islem Rekik and Guoyan Zheng, and challenge chairs Lena Maier-Hein, Jayashree Kalpathy-Kramer, and Alexander Seitel worked hard to assemble a strong program for the satellite events. Special mention this year also goes to our first-time Clinical Chairs, Drs. Curtis Langlotz, Charles Kahn, and Masaru Ishii, who helped us select papers for the clinical sessions and organized those sessions. We acknowledge the contributions of our Keynote Chairs, William Wells and Alejandro Frangi, who secured our keynote speakers. Our publication chairs, Kevin Zhou and Ron Summers, helped in our efforts to get the MICCAI papers indexed in PubMed.
It was a challenging year for fundraising for the conference due to the ongoing economic recovery after the COVID pandemic. Despite this situation, our industrial sponsorship chairs, Mohammad Yaqub, Le Lu, and Yanwu Xu, along with Dekon’s Mehmet Eldegez, worked tirelessly to secure sponsors in innovative ways, for which we are grateful.
An active body of the MICCAI Student Board, led by Camila Gonzalez, and our 2023 student representatives Nathaniel Braman and Vaishnavi Subramanian helped put together student-run networking and social events, including a novel Ph.D. thesis 3-minute madness event to spotlight new graduates for their careers. Similarly, Women in MICCAI chairs Xiaoxiao Li and Jayanthi Sivaswamy and RISE chairs Islem Rekik, Pingkun Yan, and Andrea Lara further strengthened the quality of our technical program through their organized events.
Local arrangement logistics, including the recruiting of University of British Columbia students and invitation letters to attendees, were ably looked after by our local arrangement chairs Purang Abolmaesumi and Mehdi Moradi. They also helped coordinate visits to local sites in Vancouver, both during the selection of the site and during the organization of our local activities at the conference.
Our Young Investigator chairs Marius Linguraru, Archana Venkataraman, and Antonio Porras Perez put together the startup village and helped secure funding from NIH for early career scientist participation in the conference. Our communications chair, Ehsan Adeli, and Diana Cunningham were active in making the conference visible on social media platforms and circulating the newsletters. Niharika D’Souza was our cross-committee liaison, providing note-taking support for all our meetings. We are grateful to all these organization committee members for their active contributions that made the conference successful.
We would like to thank the MICCAI society chair, Caroline Essert, and the MICCAI board for their approvals, support, and feedback, which provided clarity on various aspects of running the conference. Behind the scenes, we acknowledge the contributions of the MICCAI secretariat personnel, Janette Wallace and Johanne Langford, who kept a close eye on logistics and budgets, and of Diana Cunningham and Anna Van Vliet for including our conference announcements in a timely manner in the MICCAI society newsletters.
This year, when the existing virtual platform provider indicated that they would discontinue their service, a new virtual platform provider, Conference Catalysts, was chosen after due diligence by John Baxter. John also handled the setup and coordination with CMT and consultation with program chairs on features, for which we are very grateful.
The physical organization of the conference at the site, budget financials, fundraising, and the smooth running of events would not have been possible without our Professional Conference Organization (PCO) team from Dekon Congress & Tourism, led by Mehmet Eldegez. The model of having a PCO run the conference, which we used at MICCAI, significantly reduces the work of general chairs, for which we are particularly grateful.
Finally, we are especially grateful to all members of the Program Committee for their diligent work in the reviewer assignments and final paper selection, as well as to the reviewers for their support during the entire process. Most importantly, we thank all authors, co-authors, students/postdocs, and supervisors for submitting and presenting their high-quality work, which played a pivotal role in making MICCAI 2023 a resounding success.
With a successful MICCAI 2023 behind us, we now look forward to seeing you next year in Marrakesh, Morocco, when MICCAI 2024 goes to the African continent for the first time.
October 2023
Tanveer Syeda-Mahmood, James Duncan, Russ Taylor
General Chairs
Hayit Greenspan, Anant Madabhushi, Parvin Mousavi, Septimiu Salcudean
Program Chairs
Organization
General Chairs
Tanveer Syeda-Mahmood (IBM Research, USA)
James Duncan (Yale University, USA)
Russ Taylor (Johns Hopkins University, USA)
Program Committee Chairs
Hayit Greenspan (Tel-Aviv University, Israel and Icahn School of Medicine at Mount Sinai, USA)
Anant Madabhushi (Emory University, USA)
Parvin Mousavi (Queen’s University, Canada)
Septimiu Salcudean (University of British Columbia, Canada)
Satellite Events Chair
Bennett Landman (Vanderbilt University, USA)
Workshop Chairs
Hongzhi Wang (IBM Research, USA)
Alistair Young (King’s College, London, UK)
Challenges Chairs
Jayashree Kalpathy-Kramer (Harvard University, USA)
Alexander Seitel (German Cancer Research Center, Germany)
Lena Maier-Hein (German Cancer Research Center, Germany)
Tutorial Chairs
Islem Rekik (Imperial College London, UK)
Guoyan Zheng (Shanghai Jiao Tong University, China)
Clinical Chairs
Curtis Langlotz (Stanford University, USA)
Charles Kahn (University of Pennsylvania, USA)
Masaru Ishii (Johns Hopkins University, USA)
Local Arrangements Chairs
Purang Abolmaesumi (University of British Columbia, Canada)
Mehdi Moradi (McMaster University, Canada)
Keynote Chairs
William Wells (Harvard University, USA)
Alejandro Frangi (University of Manchester, UK)
Industrial Sponsorship Chairs
Mohammad Yaqub (MBZ University of Artificial Intelligence, Abu Dhabi)
Le Lu (DAMO Academy, Alibaba Group, USA)
Yanwu Xu (Baidu, China)
Communication Chair
Ehsan Adeli (Stanford University, USA)
Publication Chairs
Ron Summers (National Institutes of Health, USA)
Kevin Zhou (University of Science and Technology of China, China)
Young Investigator Chairs
Marius Linguraru (Children’s National Institute, USA)
Archana Venkataraman (Boston University, USA)
Antonio Porras (University of Colorado Anschutz Medical Campus, USA)
Student Activities Chairs
Nathaniel Braman (Picture Health, USA)
Vaishnavi Subramanian (EPFL, Switzerland)
Women in MICCAI Chairs
Jayanthi Sivaswamy (IIIT, Hyderabad, India)
Xiaoxiao Li (University of British Columbia, Canada)
RISE Committee Chairs
Islem Rekik (Imperial College London, UK)
Pingkun Yan (Rensselaer Polytechnic Institute, USA)
Andrea Lara (Universidad Galileo, Guatemala)
Submission Platform Manager
Kitty Wong (The MICCAI Society, Canada)
Virtual Platform Manager
John Baxter (INSERM, Université de Rennes 1, France)
Cross-Committee Liaison
Niharika D’Souza (IBM Research, USA)
Program Committee
Sahar Ahmad (University of North Carolina at Chapel Hill, USA)
Shadi Albarqouni (University of Bonn and Helmholtz Munich, Germany)
Angelica Aviles-Rivero (University of Cambridge, UK)
Shekoofeh Azizi (Google, Google Brain, USA)
Ulas Bagci (Northwestern University, USA)
Wenjia Bai (Imperial College London, UK)
Sophia Bano (University College London, UK)
Kayhan Batmanghelich (University of Pittsburgh and Boston University, USA)
Ismail Ben Ayed (ETS Montreal, Canada)
Katharina Breininger (Friedrich-Alexander-Universität Erlangen-Nürnberg, Germany)
Weidong Cai (University of Sydney, Australia)
Geng Chen (Northwestern Polytechnical University, China)
Hao Chen (Hong Kong University of Science and Technology, China)
Jun Cheng (Institute for Infocomm Research, A*STAR, Singapore)
Li Cheng (University of Alberta, Canada)
Albert C. S. Chung (University of Exeter, UK)
Toby Collins (Ircad, France)
Adrian Dalca (Massachusetts Institute of Technology and Harvard Medical School, USA)
Jose Dolz (ETS Montreal, Canada)
Qi Dou (Chinese University of Hong Kong, China)
Nicha Dvornek (Yale University, USA)
Shireen Elhabian (University of Utah, USA)
Sandy Engelhardt (Heidelberg University Hospital, Germany)
Ruogu Fang (University of Florida, USA)
Aasa Feragen (Technical University of Denmark, Denmark)
Moti Freiman (Technion - Israel Institute of Technology, Israel)
Huazhu Fu (IHPC, A*STAR, Singapore)
Adrian Galdran (Universitat Pompeu Fabra, Barcelona, Spain)
Zhifan Gao (Sun Yat-sen University, China)
Zongyuan Ge (Monash University, Australia)
Stamatia Giannarou (Imperial College London, UK)
Yun Gu (Shanghai Jiao Tong University, China)
Hu Han (Institute of Computing Technology, Chinese Academy of Sciences, China)
Daniel Hashimoto (University of Pennsylvania, USA)
Mattias Heinrich (University of Lübeck, Germany)
Heng Huang (University of Pittsburgh, USA)
Yuankai Huo (Vanderbilt University, USA)
Mobarakol Islam (University College London, UK)
Jayender Jagadeesan (Harvard Medical School, USA)
Won-Ki Jeong (Korea University, South Korea)
Xi Jiang (University of Electronic Science and Technology of China, China)
Yueming Jin (National University of Singapore, Singapore)
Anand Joshi (University of Southern California, USA)
Shantanu Joshi (UCLA, USA)
Leo Joskowicz (Hebrew University of Jerusalem, Israel)
Samuel Kadoury (Polytechnique Montreal, Canada)
Bernhard Kainz (Friedrich-Alexander-Universität Erlangen-Nürnberg, Germany and Imperial College London, UK)
Davood Karimi (Harvard University, USA)
Anees Kazi (Massachusetts General Hospital, USA)
Marta Kersten-Oertel (Concordia University, Canada)
Fahmi Khalifa (Mansoura University, Egypt)
Minjeong Kim (University of North Carolina, Greensboro, USA)
Seong Tae Kim (Kyung Hee University, South Korea)
Pavitra Krishnaswamy (Institute for Infocomm Research, Agency for Science Technology and Research (A*STAR), Singapore)
Jin Tae Kwak (Korea University, South Korea)
Baiying Lei (Shenzhen University, China)
Xiang Li (Massachusetts General Hospital, USA)
Xiaoxiao Li (University of British Columbia, Canada)
Yuexiang Li (Tencent Jarvis Lab, China)
Chunfeng Lian (Xi’an Jiaotong University, China)
Jianming Liang (Arizona State University, USA)
Jianfei Liu (National Institutes of Health Clinical Center, USA)
Mingxia Liu (University of North Carolina at Chapel Hill, USA)
Xiaofeng Liu (Harvard Medical School and MGH, USA)
Herve Lombaert (École de technologie supérieure, Canada)
Ismini Lourentzou (Virginia Tech, USA)
Le Lu (Damo Academy USA, Alibaba Group, USA)
Dwarikanath Mahapatra (Inception Institute of Artificial Intelligence, United Arab Emirates)
Saad Nadeem (Memorial Sloan Kettering Cancer Center, USA)
Dong Nie (Alibaba (US), USA)
Yoshito Otake (Nara Institute of Science and Technology, Japan)
Sang Hyun Park (Daegu Gyeongbuk Institute of Science and Technology, South Korea)
Magdalini Paschali (Stanford University, USA)
Tingying Peng (Helmholtz Munich, Germany)
Caroline Petitjean (LITIS Université de Rouen Normandie, France)
Esther Puyol Anton (King’s College London, UK)
Chen Qin (Imperial College London, UK)
Daniel Racoceanu (Sorbonne Université, France)
Hedyeh Rafii-Tari (Auris Health, USA)
Hongliang Ren (Chinese University of Hong Kong, China and National University of Singapore, Singapore)
Tammy Riklin Raviv (Ben-Gurion University, Israel)
Hassan Rivaz (Concordia University, Canada)
Mirabela Rusu (Stanford University, USA)
Thomas Schultz (University of Bonn, Germany)
Feng Shi (Shanghai United Imaging Intelligence, China)
Yang Song (University of New South Wales, Australia)
Aristeidis Sotiras (Washington University in St. Louis, USA)
Rachel Sparks (King’s College London, UK)
Yao Sui (Peking University, China)
Kenji Suzuki (Tokyo Institute of Technology, Japan)
Qian Tao (Delft University of Technology, Netherlands)
Mathias Unberath (Johns Hopkins University, USA)
Martin Urschler (Medical University Graz, Austria)
Maria Vakalopoulou (CentraleSupelec, University Paris Saclay, France)
Erdem Varol (New York University, USA)
Francisco Vasconcelos (University College London, UK)
Harini Veeraraghavan (Memorial Sloan Kettering Cancer Center, USA)
Satish Viswanath (Case Western Reserve University, USA)
Christian Wachinger (Technical University of Munich, Germany)
Hua Wang (Colorado School of Mines, USA)
Qian Wang (ShanghaiTech University, China)
Shanshan Wang (Paul C. Lauterbur Research Center, SIAT, China)
Yalin Wang (Arizona State University, USA)
Bryan Williams (Lancaster University, UK)
Matthias Wilms (University of Calgary, Canada)
Jelmer Wolterink (University of Twente, Netherlands)
Ken C. L. Wong (IBM Research Almaden, USA)
Jonghye Woo (Massachusetts General Hospital and Harvard Medical School, USA)
Shandong Wu (University of Pittsburgh, USA)
Yutong Xie (University of Adelaide, Australia)
Fuyong Xing (University of Colorado, Denver, USA)
Daguang Xu (NVIDIA, USA)
Yan Xu (Beihang University, China)
Yanwu Xu (Baidu, China)
Pingkun Yan (Rensselaer Polytechnic Institute, USA)
Guang Yang (Imperial College London, UK)
Jianhua Yao (Tencent, China)
Chuyang Ye (Beijing Institute of Technology, China)
Lequan Yu (University of Hong Kong, China)
Ghada Zamzmi (National Institutes of Health, USA)
Liang Zhan (University of Pittsburgh, USA)
Fan Zhang (Harvard Medical School, USA)
Ling Zhang (Alibaba Group, China)
Miaomiao Zhang (University of Virginia, USA)
Shu Zhang (Northwestern Polytechnical University, China)
Rongchang Zhao (Central South University, China)
Yitian Zhao (Chinese Academy of Sciences, China)
Tao Zhou (Nanjing University of Science and Technology, China)
Yuyin Zhou (UC Santa Cruz, USA)
Dajiang Zhu (University of Texas at Arlington, USA)
Lei Zhu (ROAS Thrust HKUST (GZ), and ECE HKUST, China)
Xiahai Zhuang (Fudan University, China)
Veronika Zimmer (Technical University of Munich, Germany)
Reviewers Alaa Eldin Abdelaal John Abel Kumar Abhishek Shahira Abousamra Mazdak Abulnaga Burak Acar Abdoljalil Addeh Ehsan Adeli Sukesh Adiga Vasudeva Seyed-Ahmad Ahmadi Euijoon Ahn Faranak Akbarifar Alireza Akhondi-asl Saad Ullah Akram Daniel Alexander Hanan Alghamdi Hassan Alhajj Omar Al-Kadi Max Allan Andre Altmann Pablo Alvarez Charlems Alvarez-Jimenez Jennifer Alvén Lidia Al-Zogbi Kimberly Amador Tamaz Amiranashvili Amine Amyar Wangpeng An Vincent Andrearczyk Manon Ansart Sameer Antani Jacob Antunes Michel Antunes Guilherme Aresta Mohammad Ali Armin Kasra Arnavaz Corey Arnold Janan Arslan Marius Arvinte Muhammad Asad John Ashburner Md Ashikuzzaman Shahab Aslani
Mehdi Astaraki Angélica Atehortúa Benjamin Aubert Marc Aubreville Paolo Avesani Sana Ayromlou Reza Azad Mohammad Farid Azampour Qinle Ba Meritxell Bach Cuadra Hyeon-Min Bae Matheus Baffa Cagla Bahadir Fan Bai Jun Bai Long Bai Pradeep Bajracharya Shafa Balaram Yaël Balbastre Yutong Ban Abhirup Banerjee Soumyanil Banerjee Sreya Banerjee Shunxing Bao Omri Bar Adrian Barbu Joao Barreto Adrian Basarab Berke Basaran Michael Baumgartner Siming Bayer Roza Bayrak Aicha BenTaieb Guy Ben-Yosef Sutanu Bera Cosmin Bercea Jorge Bernal Jose Bernal Gabriel Bernardino Riddhish Bhalodia Jignesh Bhatt Indrani Bhattacharya
Binod Bhattarai Lei Bi Qi Bi Cheng Bian Gui-Bin Bian Carlo Biffi Alexander Bigalke Benjamin Billot Manuel Birlo Ryoma Bise Daniel Blezek Stefano Blumberg Sebastian Bodenstedt Federico Bolelli Bhushan Borotikar Ilaria Boscolo Galazzo Alexandre Bousse Nicolas Boutry Joseph Boyd Behzad Bozorgtabar Nadia Brancati Clara Brémond Martin Stéphanie Bricq Christopher Bridge Coleman Broaddus Rupert Brooks Tom Brosch Mikael Brudfors Ninon Burgos Nikolay Burlutskiy Michal Byra Ryan Cabeen Mariano Cabezas Hongmin Cai Tongan Cai Zongyou Cai Liane Canas Bing Cao Guogang Cao Weiguo Cao Xu Cao Yankun Cao Zhenjie Cao
Jaime Cardoso M. Jorge Cardoso Owen Carmichael Jacob Carse Adrià Casamitjana Alessandro Casella Angela Castillo Kate Cevora Krishna Chaitanya Satrajit Chakrabarty Yi Hao Chan Shekhar Chandra Ming-Ching Chang Peng Chang Qi Chang Yuchou Chang Hanqing Chao Simon Chatelin Soumick Chatterjee Sudhanya Chatterjee Muhammad Faizyab Ali Chaudhary Antong Chen Bingzhi Chen Chen Chen Cheng Chen Chengkuan Chen Eric Chen Fang Chen Haomin Chen Jianan Chen Jianxu Chen Jiazhou Chen Jie Chen Jintai Chen Jun Chen Junxiang Chen Junyu Chen Li Chen Liyun Chen Nenglun Chen Pingjun Chen Pingyi Chen Qi Chen Qiang Chen
Runnan Chen Shengcong Chen Sihao Chen Tingting Chen Wenting Chen Xi Chen Xiang Chen Xiaoran Chen Xin Chen Xiongchao Chen Yanxi Chen Yixiong Chen Yixuan Chen Yuanyuan Chen Yuqian Chen Zhaolin Chen Zhen Chen Zhenghao Chen Zhennong Chen Zhihao Chen Zhineng Chen Zhixiang Chen Chang-Chieh Cheng Jiale Cheng Jianhong Cheng Jun Cheng Xuelian Cheng Yupeng Cheng Mark Chiew Philip Chikontwe Eleni Chiou Jungchan Cho Jang-Hwan Choi Min-Kook Choi Wookjin Choi Jaegul Choo Yu-Cheng Chou Daan Christiaens Argyrios Christodoulidis Stergios Christodoulidis Kai-Cheng Chuang Hyungjin Chung Matthew Clarkson Michaël Clément Dana Cobzas
Jaume Coll-Font Olivier Colliot Runmin Cong Yulai Cong Laura Connolly William Consagra Pierre-Henri Conze Tim Cootes Teresa Correia Baris Coskunuzer Alex Crimi Can Cui Hejie Cui Hui Cui Lei Cui Wenhui Cui Tolga Cukur Tobias Czempiel Javid Dadashkarimi Haixing Dai Tingting Dan Kang Dang Salman Ul Hassan Dar Eleonora D’Arnese Dhritiman Das Neda Davoudi Tareen Dawood Sandro De Zanet Farah Deeba Charles Delahunt Herve Delingette Ugur Demir Liang-Jian Deng Ruining Deng Wenlong Deng Felix Denzinger Adrien Depeursinge Mohammad Mahdi Derakhshani Hrishikesh Deshpande Adrien Desjardins Christian Desrosiers Blake Dewey Neel Dey Rohan Dhamdhere
Maxime Di Folco Songhui Diao Alina Dima Hao Ding Li Ding Ying Ding Zhipeng Ding Nicola Dinsdale Konstantin Dmitriev Ines Domingues Bo Dong Liang Dong Nanqing Dong Siyuan Dong Reuben Dorent Gianfranco Doretto Sven Dorkenwald Haoran Dou Mitchell Doughty Jason Dowling Niharika D’Souza Guodong Du Jie Du Shiyi Du Hongyi Duanmu Benoit Dufumier James Duncan Joshua Durso-Finley Dmitry V. Dylov Oleh Dzyubachyk Mahdi (Elias) Ebnali Philip Edwards Jan Egger Gudmundur Einarsson Mostafa El Habib Daho Ahmed Elazab Idris El-Feghi David Ellis Mohammed Elmogy Amr Elsawy Okyaz Eminaga Ertunc Erdil Lauren Erdman Marius Erdt Maria Escobar
Hooman Esfandiari Nazila Esmaeili Ivan Ezhov Alessio Fagioli Deng-Ping Fan Lei Fan Xin Fan Yubo Fan Huihui Fang Jiansheng Fang Xi Fang Zhenghan Fang Mohammad Farazi Azade Farshad Mohsen Farzi Hamid Fehri Lina Felsner Chaolu Feng Chun-Mei Feng Jianjiang Feng Mengling Feng Ruibin Feng Zishun Feng Alvaro Fernandez-Quilez Ricardo Ferrari Lucas Fidon Lukas Fischer Madalina Fiterau Antonio Foncubierta-Rodríguez Fahimeh Fooladgar Germain Forestier Nils Daniel Forkert Jean-Rassaire Fouefack Kevin François-Bouaou Wolfgang Freysinger Bianca Freytag Guanghui Fu Kexue Fu Lan Fu Yunguan Fu Pedro Furtado Ryo Furukawa Jin Kyu Gahm Mélanie Gaillochet
Francesca Galassi Jiangzhang Gan Yu Gan Yulu Gan Alireza Ganjdanesh Chang Gao Cong Gao Linlin Gao Zeyu Gao Zhongpai Gao Sara Garbarino Alain Garcia Beatriz Garcia Santa Cruz Rongjun Ge Shiv Gehlot Manuela Geiss Salah Ghamizi Negin Ghamsarian Ramtin Gharleghi Ghazal Ghazaei Florin Ghesu Sayan Ghosal Syed Zulqarnain Gilani Mahdi Gilany Yannik Glaser Ben Glocker Bharti Goel Jacob Goldberger Polina Golland Alberto Gomez Catalina Gomez Estibaliz Gómez-de-Mariscal Haifan Gong Kuang Gong Xun Gong Ricardo Gonzales Camila Gonzalez German Gonzalez Vanessa Gonzalez Duque Sharath Gopal Karthik Gopinath Pietro Gori Michael Götz Shuiping Gou
Maged Goubran Sobhan Goudarzi Mark Graham Alejandro Granados Mara Graziani Thomas Grenier Radu Grosu Michal Grzeszczyk Feng Gu Pengfei Gu Qiangqiang Gu Ran Gu Shi Gu Wenhao Gu Xianfeng Gu Yiwen Gu Zaiwang Gu Hao Guan Jayavardhana Gubbi Houssem-Eddine Gueziri Dazhou Guo Hengtao Guo Jixiang Guo Jun Guo Pengfei Guo Wenzhangzhi Guo Xiaoqing Guo Xueqi Guo Yi Guo Vikash Gupta Praveen Gurunath Bharathi Prashnna Gyawali Sung Min Ha Mohamad Habes Ilker Hacihaliloglu Stathis Hadjidemetriou Fatemeh Haghighi Justin Haldar Noura Hamze Liang Han Luyi Han Seungjae Han Tianyu Han Zhongyi Han Jonny Hancox
Lasse Hansen Degan Hao Huaying Hao Jinkui Hao Nazim Haouchine Michael Hardisty Stefan Harrer Jeffry Hartanto Charles Hatt Huiguang He Kelei He Qi He Shenghua He Xinwei He Stefan Heldmann Nicholas Heller Edward Henderson Alessa Hering Monica Hernandez Kilian Hett Amogh Hiremath David Ho Malte Hoffmann Matthew Holden Qingqi Hong Yoonmi Hong Mohammad Reza Hosseinzadeh Taher William Hsu Chuanfei Hu Dan Hu Kai Hu Rongyao Hu Shishuai Hu Xiaoling Hu Xinrong Hu Yan Hu Yang Hu Chaoqin Huang Junzhou Huang Ling Huang Luojie Huang Qinwen Huang Sharon Xiaolei Huang Weijian Huang
Xiaoyang Huang Yi-Jie Huang Yongsong Huang Yongxiang Huang Yuhao Huang Zhe Huang Zhi-An Huang Ziyi Huang Arnaud Huaulmé Henkjan Huisman Alex Hung Jiayu Huo Andreas Husch Mohammad Arafat Hussain Sarfaraz Hussein Jana Hutter Khoi Huynh Ilknur Icke Kay Igwe Abdullah Al Zubaer Imran Muhammad Imran Samra Irshad Nahid Ul Islam Koichi Ito Hayato Itoh Yuji Iwahori Krithika Iyer Mohammad Jafari Srikrishna Jaganathan Hassan Jahanandish Andras Jakab Amir Jamaludin Amoon Jamzad Ananya Jana Se-In Jang Pierre Jannin Vincent Jaouen Uditha Jarayathne Ronnachai Jaroensri Guillaume Jaume Syed Ashar Javed Rachid Jennane Debesh Jha Ge-Peng Ji
Luping Ji Zexuan Ji Zhanghexuan Ji Haozhe Jia Hongchao Jiang Jue Jiang Meirui Jiang Tingting Jiang Xiajun Jiang Zekun Jiang Zhifan Jiang Ziyu Jiang Jianbo Jiao Zhicheng Jiao Chen Jin Dakai Jin Qiangguo Jin Qiuye Jin Weina Jin Baoyu Jing Bin Jing Yaqub Jonmohamadi Lie Ju Yohan Jun Dinkar Juyal Manjunath K N Ali Kafaei Zad Tehrani John Kalafut Niveditha Kalavakonda Megha Kalia Anil Kamat Qingbo Kang Po-Yu Kao Anuradha Kar Neerav Karani Turkay Kart Satyananda Kashyap Alexander Katzmann Lisa Kausch Maxime Kayser Salome Kazeminia Wenchi Ke Youngwook Kee Matthias Keicher Erwan Kerrien
Afifa Khaled Nadieh Khalili Farzad Khalvati Bidur Khanal Bishesh Khanal Pulkit Khandelwal Maksim Kholiavchenko Ron Kikinis Benjamin Killeen Daeseung Kim Heejong Kim Jaeil Kim Jinhee Kim Jinman Kim Junsik Kim Minkyung Kim Namkug Kim Sangwook Kim Tae Soo Kim Younghoon Kim Young-Min Kim Andrew King Miranda Kirby Gabriel Kiss Andreas Kist Yoshiro Kitamura Stefan Klein Tobias Klinder Kazuma Kobayashi Lisa Koch Satoshi Kondo Fanwei Kong Tomasz Konopczynski Ender Konukoglu Aishik Konwer Thijs Kooi Ivica Kopriva Avinash Kori Kivanc Kose Suraj Kothawade Anna Kreshuk AnithaPriya Krishnan Florian Kromp Frithjof Kruggel Thomas Kuestner
Levin Kuhlmann Abhay Kumar Kuldeep Kumar Sayantan Kumar Manuela Kunz Holger Kunze Tahsin Kurc Anvar Kurmukov Yoshihiro Kuroda Yusuke Kurose Hyuksool Kwon Aymen Laadhari Jorma Laaksonen Dmitrii Lachinov Alain Lalande Rodney LaLonde Bennett Landman Daniel Lang Carole Lartizien Shlomi Laufer Max-Heinrich Laves William Le Loic Le Folgoc Christian Ledig Eung-Joo Lee Ho Hin Lee Hyekyoung Lee John Lee Kisuk Lee Kyungsu Lee Soochahn Lee Woonghee Lee Étienne Léger Wen Hui Lei Yiming Lei George Leifman Rogers Jeffrey Leo John Juan Leon Bo Li Caizi Li Chao Li Chen Li Cheng Li Chenxin Li Chnegyin Li
Dawei Li Fuhai Li Gang Li Guang Li Hao Li Haofeng Li Haojia Li Heng Li Hongming Li Hongwei Li Huiqi Li Jian Li Jieyu Li Kang Li Lin Li Mengzhang Li Ming Li Qing Li Quanzheng Li Shaohua Li Shulong Li Tengfei Li Weijian Li Wen Li Xiaomeng Li Xingyu Li Xinhui Li Xuelu Li Xueshen Li Yamin Li Yang Li Yi Li Yuemeng Li Yunxiang Li Zeju Li Zhaoshuo Li Zhe Li Zhen Li Zhenqiang Li Zhiyuan Li Zhjin Li Zi Li Hao Liang Libin Liang Peixian Liang
Yuan Liang Yudong Liang Haofu Liao Hongen Liao Wei Liao Zehui Liao Gilbert Lim Hongxiang Lin Li Lin Manxi Lin Mingquan Lin Tiancheng Lin Yi Lin Zudi Lin Claudia Lindner Simone Lionetti Chi Liu Chuanbin Liu Daochang Liu Dongnan Liu Feihong Liu Fenglin Liu Han Liu Huiye Liu Jiang Liu Jie Liu Jinduo Liu Jing Liu Jingya Liu Jundong Liu Lihao Liu Mengting Liu Mingyuan Liu Peirong Liu Peng Liu Qin Liu Quan Liu Rui Liu Shengfeng Liu Shuangjun Liu Sidong Liu Siyuan Liu Weide Liu Xiao Liu Xiaoyu Liu
Xingtong Liu Xinwen Liu Xinyang Liu Xinyu Liu Yan Liu Yi Liu Yihao Liu Yikang Liu Yilin Liu Yilong Liu Yiqiao Liu Yong Liu Yuhang Liu Zelong Liu Zhe Liu Zhiyuan Liu Zuozhu Liu Lisette Lockhart Andrea Loddo Nicolas Loménie Yonghao Long Daniel Lopes Ange Lou Brian Lovell Nicolas Loy Rodas Charles Lu Chun-Shien Lu Donghuan Lu Guangming Lu Huanxiang Lu Jingpei Lu Yao Lu Oeslle Lucena Jie Luo Luyang Luo Ma Luo Mingyuan Luo Wenhan Luo Xiangde Luo Xinzhe Luo Jinxin Lv Tianxu Lv Fei Lyu Ilwoo Lyu Mengye Lyu
Qing Lyu Yanjun Lyu Yuanyuan Lyu Benteng Ma Chunwei Ma Hehuan Ma Jun Ma Junbo Ma Wenao Ma Yuhui Ma Pedro Macias Gordaliza Anant Madabhushi Derek Magee S. Sara Mahdavi Andreas Maier Klaus H. Maier-Hein Sokratis Makrogiannis Danial Maleki Michail Mamalakis Zhehua Mao Jan Margeta Brett Marinelli Zdravko Marinov Viktoria Markova Carsten Marr Yassine Marrakchi Anne Martel Martin Maška Tejas Sudharshan Mathai Petr Matula Dimitrios Mavroeidis Evangelos Mazomenos Amarachi Mbakwe Adam McCarthy Stephen McKenna Raghav Mehta Xueyan Mei Felix Meissen Felix Meister Afaque Memon Mingyuan Meng Qingjie Meng Xiangzhu Meng Yanda Meng Zhu Meng
Martin Menten Odyssée Merveille Mikhail Milchenko Leo Milecki Fausto Milletari Hyun-Seok Min Zhe Min Song Ming Duy Minh Ho Nguyen Deepak Mishra Suraj Mishra Virendra Mishra Tadashi Miyamoto Sara Moccia Marc Modat Omid Mohareri Tony C. W. Mok Javier Montoya Rodrigo Moreno Stefano Moriconi Lia Morra Ana Mota Lei Mou Dana Moukheiber Lama Moukheiber Daniel Moyer Pritam Mukherjee Anirban Mukhopadhyay Henning Müller Ana Murillo Gowtham Krishnan Murugesan Ahmed Naglah Karthik Nandakumar Venkatesh Narasimhamurthy Raja Narayan Dominik Narnhofer Vishwesh Nath Rodrigo Nava Abdullah Nazib Ahmed Nebli Peter Neher Amin Nejatbakhsh Trong-Thuan Nguyen
Truong Nguyen Dong Ni Haomiao Ni Xiuyan Ni Hannes Nickisch Weizhi Nie Aditya Nigam Lipeng Ning Xia Ning Kazuya Nishimura Chuang Niu Sijie Niu Vincent Noblet Narges Norouzi Alexey Novikov Jorge Novo Gilberto Ochoa-Ruiz Masahiro Oda Benjamin Odry Hugo Oliveira Sara Oliveira Arnau Oliver Jimena Olveres John Onofrey Marcos Ortega Mauricio Alberto Ortega-Ruíz Yusuf Osmanlioglu Chubin Ou Cheng Ouyang Jiahong Ouyang Xi Ouyang Cristina Oyarzun Laura Utku Ozbulak Ece Ozkan Ege Özsoy Batu Ozturkler Harshith Padigela Johannes Paetzold José Blas Pagador Carrasco Daniel Pak Sourabh Palande Chengwei Pan Jiazhen Pan
Jin Pan Yongsheng Pan Egor Panfilov Jiaxuan Pang Joao Papa Constantin Pape Bartlomiej Papiez Nripesh Parajuli Hyunjin Park Akash Parvatikar Tiziano Passerini Diego Patiño Cortés Mayank Patwari Angshuman Paul Rasmus Paulsen Yuchen Pei Yuru Pei Tao Peng Wei Peng Yige Peng Yunsong Peng Matteo Pennisi Antonio Pepe Oscar Perdomo Sérgio Pereira Jose-Antonio Pérez-Carrasco Mehran Pesteie Terry Peters Eike Petersen Jens Petersen Micha Pfeiffer Dzung Pham Hieu Pham Ashish Phophalia Tomasz Pieciak Antonio Pinheiro Pramod Pisharady Theodoros Pissas Szymon Płotka Kilian Pohl Sebastian Pölsterl Alison Pouch Tim Prangemeier Prateek Prasanna
Raphael Prevost Juan Prieto Federica Proietto Salanitri Sergi Pujades Elodie Puybareau Talha Qaiser Buyue Qian Mengyun Qiao Yuchuan Qiao Zhi Qiao Chenchen Qin Fangbo Qin Wenjian Qin Yulei Qin Jie Qiu Jielin Qiu Peijie Qiu Shi Qiu Wu Qiu Liangqiong Qu Linhao Qu Quan Quan Tran Minh Quan Sandro Queirós Prashanth R Febrian Rachmadi Daniel Racoceanu Mehdi Rahim Jagath Rajapakse Kashif Rajpoot Keerthi Ram Dhanesh Ramachandram João Ramalhinho Xuming Ran Aneesh Rangnekar Hatem Rashwan Keerthi Sravan Ravi Daniele Ravì Sadhana Ravikumar Harish Raviprakash Surreerat Reaungamornrat Samuel Remedios Mengwei Ren Sucheng Ren Elton Rexhepaj
Mauricio Reyes Constantino Reyes-Aldasoro Abel Reyes-Angulo Hadrien Reynaud Razieh Rezaei Anne-Marie Rickmann Laurent Risser Dominik Rivoir Emma Robinson Robert Robinson Jessica Rodgers Ranga Rodrigo Rafael Rodrigues Robert Rohling Margherita Rosnati Łukasz Roszkowiak Holger Roth José Rouco Dan Ruan Jiacheng Ruan Daniel Rueckert Danny Ruijters Kanghyun Ryu Ario Sadafi Numan Saeed Monjoy Saha Pramit Saha Farhang Sahba Pranjal Sahu Simone Saitta Md Sirajus Salekin Abbas Samani Pedro Sanchez Luis Sanchez Giraldo Yudi Sang Gerard Sanroma-Guell Rodrigo Santa Cruz Alice Santilli Rachana Sathish Olivier Saut Mattia Savardi Nico Scherf Alexander Schlaefer Jerome Schmid
Adam Schmidt Julia Schnabel Lawrence Schobs Julian Schön Peter Schueffler Andreas Schuh Christina Schwarz-Gsaxner Michaël Sdika Suman Sedai Lalithkumar Seenivasan Matthias Seibold Sourya Sengupta Lama Seoud Ana Sequeira Sharmishtaa Seshamani Ahmed Shaffie Jay Shah Keyur Shah Ahmed Shahin Mohammad Abuzar Shaikh S. Shailja Hongming Shan Wei Shao Mostafa Sharifzadeh Anuja Sharma Gregory Sharp Hailan Shen Li Shen Linlin Shen Mali Shen Mingren Shen Yiqing Shen Zhengyang Shen Jun Shi Xiaoshuang Shi Yiyu Shi Yonggang Shi Hoo-Chang Shin Jitae Shin Keewon Shin Boris Shirokikh Suzanne Shontz Yucheng Shu
Hanna Siebert Alberto Signoroni Wilson Silva Julio Silva-Rodríguez Margarida Silveira Walter Simson Praveer Singh Vivek Singh Nitin Singhal Elena Sizikova Gregory Slabaugh Dane Smith Kevin Smith Tiffany So Rajath Soans Roger Soberanis-Mukul Hessam Sokooti Jingwei Song Weinan Song Xinhang Song Xinrui Song Mazen Soufi Georgia Sovatzidi Bella Specktor Fadida William Speier Ziga Spiclin Dominik Spinczyk Jon Sporring Pradeeba Sridar Chetan L. Srinidhi Abhishek Srivastava Lawrence Staib Marc Stamminger Justin Strait Hai Su Ruisheng Su Zhe Su Vaishnavi Subramanian Gérard Subsol Carole Sudre Dong Sui Heung-Il Suk Shipra Suman He Sun Hongfu Sun
Jian Sun Li Sun Liyan Sun Shanlin Sun Kyung Sung Yannick Suter Swapna T. R. Amir Tahmasebi Pablo Tahoces Sirine Taleb Bingyao Tan Chaowei Tan Wenjun Tan Hao Tang Siyi Tang Xiaoying Tang Yucheng Tang Zihao Tang Michael Tanzer Austin Tapp Elias Tappeiner Mickael Tardy Giacomo Tarroni Athena Taymourtash Kaveri Thakoor Elina Thibeau-Sutre Paul Thienphrapa Sarina Thomas Stephen Thompson Karl Thurnhofer-Hemsi Cristiana Tiago Lin Tian Lixia Tian Yapeng Tian Yu Tian Yun Tian Aleksei Tiulpin Hamid Tizhoosh Minh Nguyen Nhat To Matthew Toews Maryam Toloubidokhti Minh Tran Quoc-Huy Trinh Jocelyne Troccaz Roger Trullo
Chialing Tsai Apostolia Tsirikoglou Puxun Tu Samyakh Tukra Sudhakar Tummala Georgios Tziritas Vladimír Ulman Tamas Ungi Régis Vaillant Jeya Maria Jose Valanarasu Vanya Valindria Juan Miguel Valverde Fons van der Sommen Maureen van Eijnatten Tom van Sonsbeek Gijs van Tulder Yogatheesan Varatharajah Madhurima Vardhan Thomas Varsavsky Hooman Vaseli Serge Vasylechko S. Swaroop Vedula Sanketh Vedula Gonzalo Vegas Sanchez-Ferrero Matthew Velazquez Archana Venkataraman Sulaiman Vesal Mitko Veta Barbara Villarini Athanasios Vlontzos Wolf-Dieter Vogl Ingmar Voigt Sandrine Voros Vibashan VS Trinh Thi Le Vuong An Wang Bo Wang Ce Wang Changmiao Wang Ching-Wei Wang Dadong Wang Dong Wang Fakai Wang Guotai Wang
Haifeng Wang Haoran Wang Hong Wang Hongxiao Wang Hongyu Wang Jiacheng Wang Jing Wang Jue Wang Kang Wang Ke Wang Lei Wang Li Wang Liansheng Wang Lin Wang Ling Wang Linwei Wang Manning Wang Mingliang Wang Puyang Wang Qiuli Wang Renzhen Wang Ruixuan Wang Shaoyu Wang Sheng Wang Shujun Wang Shuo Wang Shuqiang Wang Tao Wang Tianchen Wang Tianyu Wang Wenzhe Wang Xi Wang Xiangdong Wang Xiaoqing Wang Xiaosong Wang Yan Wang Yangang Wang Yaping Wang Yi Wang Yirui Wang Yixin Wang Zeyi Wang Zhao Wang Zichen Wang Ziqin Wang
Ziyi Wang Zuhui Wang Dong Wei Donglai Wei Hao Wei Jia Wei Leihao Wei Ruofeng Wei Shuwen Wei Martin Weigert Wolfgang Wein Michael Wels Cédric Wemmert Thomas Wendler Markus Wenzel Rhydian Windsor Adam Wittek Marek Wodzinski Ivo Wolf Julia Wolleb Ka-Chun Wong Jonghye Woo Chongruo Wu Chunpeng Wu Fuping Wu Huaqian Wu Ji Wu Jiangjie Wu Jiong Wu Junde Wu Linshan Wu Qing Wu Weiwen Wu Wenjun Wu Xiyin Wu Yawen Wu Ye Wu Yicheng Wu Yongfei Wu Zhengwang Wu Pengcheng Xi Chao Xia Siyu Xia Wenjun Xia Lei Xiang
Tiange Xiang Deqiang Xiao Li Xiao Xiaojiao Xiao Yiming Xiao Zeyu Xiao Hongtao Xie Huidong Xie Jianyang Xie Long Xie Weidi Xie Fangxu Xing Shuwei Xing Xiaodan Xing Xiaohan Xing Haoyi Xiong Yujian Xiong Di Xu Feng Xu Haozheng Xu Hongming Xu Jiangchang Xu Jiaqi Xu Junshen Xu Kele Xu Lijian Xu Min Xu Moucheng Xu Rui Xu Xiaowei Xu Xuanang Xu Yanwu Xu Yanyu Xu Yongchao Xu Yunqiu Xu Zhe Xu Zhoubing Xu Ziyue Xu Kai Xuan Cheng Xue Jie Xue Tengfei Xue Wufeng Xue Yuan Xue Zhong Xue
Ts Faridah Yahya Chaochao Yan Jiangpeng Yan Ming Yan Qingsen Yan Xiangyi Yan Yuguang Yan Zengqiang Yan Baoyao Yang Carl Yang Changchun Yang Chen Yang Feng Yang Fengting Yang Ge Yang Guanyu Yang Heran Yang Huijuan Yang Jiancheng Yang Jiewen Yang Peng Yang Qi Yang Qiushi Yang Wei Yang Xin Yang Xuan Yang Yan Yang Yanwu Yang Yifan Yang Yingyu Yang Zhicheng Yang Zhijian Yang Jiangchao Yao Jiawen Yao Lanhong Yao Linlin Yao Qingsong Yao Tianyuan Yao Xiaohui Yao Zhao Yao Dong Hye Ye Menglong Ye Yousef Yeganeh Jirong Yi Xin Yi
Chong Yin Pengshuai Yin Yi Yin Zhaozheng Yin Chunwei Ying Youngjin Yoo Jihun Yoon Chenyu You Hanchao Yu Heng Yu Jinhua Yu Jinze Yu Ke Yu Qi Yu Qian Yu Thomas Yu Weimin Yu Yang Yu Chenxi Yuan Kun Yuan Wu Yuan Yixuan Yuan Paul Yushkevich Fatemeh Zabihollahy Samira Zare Ramy Zeineldin Dong Zeng Qi Zeng Tianyi Zeng Wei Zeng Kilian Zepf Kun Zhan Bokai Zhang Daoqiang Zhang Dong Zhang Fa Zhang Hang Zhang Hanxiao Zhang Hao Zhang Haopeng Zhang Haoyue Zhang Hongrun Zhang Jiadong Zhang Jiajin Zhang Jianpeng Zhang
Jiawei Zhang Jingqing Zhang Jingyang Zhang Jinwei Zhang Jiong Zhang Jiping Zhang Ke Zhang Lefei Zhang Lei Zhang Li Zhang Lichi Zhang Lu Zhang Minghui Zhang Molin Zhang Ning Zhang Rongzhao Zhang Ruipeng Zhang Ruisi Zhang Shichuan Zhang Shihao Zhang Shuai Zhang Tuo Zhang Wei Zhang Weihang Zhang Wen Zhang Wenhua Zhang Wenqiang Zhang Xiaodan Zhang Xiaoran Zhang Xin Zhang Xukun Zhang Xuzhe Zhang Ya Zhang Yanbo Zhang Yanfu Zhang Yao Zhang Yi Zhang Yifan Zhang Yixiao Zhang Yongqin Zhang You Zhang Youshan Zhang
Yu Zhang Yubo Zhang Yue Zhang Yuhan Zhang Yulun Zhang Yundong Zhang Yunlong Zhang Yuyao Zhang Zheng Zhang Zhenxi Zhang Ziqi Zhang Can Zhao Chongyue Zhao Fenqiang Zhao Gangming Zhao He Zhao Jianfeng Zhao Jun Zhao Li Zhao Liang Zhao Lin Zhao Mengliu Zhao Mingbo Zhao Qingyu Zhao Shang Zhao Shijie Zhao Tengda Zhao Tianyi Zhao Wei Zhao Yidong Zhao Yiyuan Zhao Yu Zhao Zhihe Zhao Ziyuan Zhao Haiyong Zheng Hao Zheng Jiannan Zheng Kang Zheng Meng Zheng Sisi Zheng Tianshu Zheng Yalin Zheng
Yefeng Zheng Yinqiang Zheng Yushan Zheng Aoxiao Zhong Jia-Xing Zhong Tao Zhong Zichun Zhong Hong-Yu Zhou Houliang Zhou Huiyu Zhou Kang Zhou Qin Zhou Ran Zhou S. Kevin Zhou Tianfei Zhou Wei Zhou Xiao-Hu Zhou Xiao-Yun Zhou Yi Zhou Youjia Zhou Yukun Zhou Zongwei Zhou Chenglu Zhu Dongxiao Zhu Heqin Zhu Jiayi Zhu Meilu Zhu Wei Zhu Wenhui Zhu Xiaofeng Zhu Xin Zhu Yonghua Zhu Yongpei Zhu Yuemin Zhu Yan Zhuang David Zimmerer Yongshuo Zong Ke Zou Yukai Zou Lianrui Zuo Gerald Zwettler
Outstanding Area Chairs
Mingxia Liu: University of North Carolina at Chapel Hill, USA
Matthias Wilms: University of Calgary, Canada
Veronika Zimmer: Technical University Munich, Germany
Outstanding Reviewers
Kimberly Amador: University of Calgary, Canada
Angela Castillo: Universidad de los Andes, Colombia
Chen Chen: Imperial College London, UK
Laura Connolly: Queen’s University, Canada
Pierre-Henri Conze: IMT Atlantique, France
Niharika D’Souza: IBM Research, USA
Michael Götz: University Hospital Ulm, Germany
Meirui Jiang: Chinese University of Hong Kong, China
Manuela Kunz: National Research Council Canada, Canada
Zdravko Marinov: Karlsruhe Institute of Technology, Germany
Sérgio Pereira: Lunit, South Korea
Lalithkumar Seenivasan: National University of Singapore, Singapore
Honorable Mentions (Reviewers)
Kumar Abhishek: Simon Fraser University, Canada
Guilherme Aresta: Medical University of Vienna, Austria
Shahab Aslani: University College London, UK
Marc Aubreville: Technische Hochschule Ingolstadt, Germany
Yaël Balbastre: Massachusetts General Hospital, USA
Omri Bar: Theator, Israel
Aicha Ben Taieb: Simon Fraser University, Canada
Cosmin Bercea: Technical University Munich and Helmholtz AI and Helmholtz Center Munich, Germany
Benjamin Billot: Massachusetts Institute of Technology, USA
Michal Byra: RIKEN Center for Brain Science, Japan
Mariano Cabezas: University of Sydney, Australia
Alessandro Casella: Italian Institute of Technology and Politecnico di Milano, Italy
Junyu Chen: Johns Hopkins University, USA
Argyrios Christodoulidis: Pfizer, Greece
Olivier Colliot: CNRS, France
Lei Cui: Northwest University, China
Neel Dey: Massachusetts Institute of Technology, USA
Alessio Fagioli: Sapienza University, Italy
Yannik Glaser: University of Hawaii at Manoa, USA
Haifan Gong: Chinese University of Hong Kong, Shenzhen, China
Ricardo Gonzales: University of Oxford, UK
Sobhan Goudarzi: Sunnybrook Research Institute, Canada
Michal Grzeszczyk: Sano Centre for Computational Medicine, Poland
Fatemeh Haghighi: Arizona State University, USA
Edward Henderson: University of Manchester, UK
Qingqi Hong: Xiamen University, China
Mohammad R. H. Taher: Arizona State University, USA
Henkjan Huisman: Radboud University Medical Center, the Netherlands
Ronnachai Jaroensri: Google, USA
Qiangguo Jin: Northwestern Polytechnical University, China
Neerav Karani: Massachusetts Institute of Technology, USA
Benjamin Killeen: Johns Hopkins University, USA
Daniel Lang: Helmholtz Center Munich, Germany
Max-Heinrich Laves: Philips Research and ImFusion GmbH, Germany
Gilbert Lim: SingHealth, Singapore
Mingquan Lin: Weill Cornell Medicine, USA
Charles Lu: Massachusetts Institute of Technology, USA
Yuhui Ma: Chinese Academy of Sciences, China
Tejas Sudharshan Mathai: National Institutes of Health, USA
Felix Meissen: Technische Universität München, Germany
Mingyuan Meng: University of Sydney, Australia
Leo Milecki: CentraleSupelec, France
Marc Modat: King’s College London, UK
Tiziano Passerini: Siemens Healthineers, USA
Tomasz Pieciak: Universidad de Valladolid, Spain
Daniel Rueckert: Imperial College London, UK
Julio Silva-Rodríguez: ETS Montreal, Canada
Bingyao Tan: Nanyang Technological University, Singapore
Elias Tappeiner: UMIT - Private University for Health Sciences, Medical Informatics and Technology, Austria
Jocelyne Troccaz: TIMC Lab, Grenoble Alpes University-CNRS, France
Chialing Tsai: Queens College, City University New York, USA
Juan Miguel Valverde: University of Eastern Finland, Finland
Sulaiman Vesal: Stanford University, USA
Wolf-Dieter Vogl: RetInSight GmbH, Austria
Vibashan VS: Johns Hopkins University, USA
Lin Wang: Harbin Engineering University, China
Yan Wang: Sichuan University, China
Rhydian Windsor: University of Oxford, UK
Ivo Wolf: University of Applied Sciences Mannheim, Germany
Linshan Wu: Hunan University, China
Xin Yang: Chinese University of Hong Kong, China
Contents – Part I
Machine Learning with Limited Supervision
PET-Diffusion: Unsupervised PET Enhancement Based on the Latent Diffusion Model . . . 3
Caiwen Jiang, Yongsheng Pan, Mianxin Liu, Lei Ma, Xiao Zhang, Jiameng Liu, Xiaosong Xiong, and Dinggang Shen
MedIM: Boost Medical Image Representation via Radiology Report-Guided Masking . . . 13
Yutong Xie, Lin Gu, Tatsuya Harada, Jianpeng Zhang, Yong Xia, and Qi Wu
UOD: Universal One-Shot Detection of Anatomical Landmarks . . . 24
Heqin Zhu, Quan Quan, Qingsong Yao, Zaiyi Liu, and S. Kevin Zhou
S²ME: Spatial-Spectral Mutual Teaching and Ensemble Learning for Scribble-Supervised Polyp Segmentation . . . 35
An Wang, Mengya Xu, Yang Zhang, Mobarakol Islam, and Hongliang Ren
Modularity-Constrained Dynamic Representation Learning for Interpretable Brain Disorder Analysis with Functional MRI . . . 46
Qianqian Wang, Mengqi Wu, Yuqi Fang, Wei Wang, Lishan Qiao, and Mingxia Liu
Anatomy-Driven Pathology Detection on Chest X-rays . . . 57
Philip Müller, Felix Meissen, Johannes Brandt, Georgios Kaissis, and Daniel Rueckert
VesselVAE: Recursive Variational Autoencoders for 3D Blood Vessel Synthesis . . . 67
Paula Feldman, Miguel Fainstein, Viviana Siless, Claudio Delrieux, and Emmanuel Iarussi
Dense Transformer based Enhanced Coding Network for Unsupervised Metal Artifact Reduction . . . 77
Wangduo Xie and Matthew B. Blaschko
Multi-scale Cross-restoration Framework for Electrocardiogram Anomaly Detection . . . 87
Aofan Jiang, Chaoqin Huang, Qing Cao, Shuang Wu, Zi Zeng, Kang Chen, Ya Zhang, and Yanfeng Wang
Correlation-Aware Mutual Learning for Semi-supervised Medical Image Segmentation . . . 98
Shengbo Gao, Ziji Zhang, Jiechao Ma, Zihao Li, and Shu Zhang
TPRO: Text-Prompting-Based Weakly Supervised Histopathology Tissue Segmentation . . . 109
Shaoteng Zhang, Jianpeng Zhang, Yutong Xie, and Yong Xia
Additional Positive Enables Better Representation Learning for Medical Images . . . 119
Dewen Zeng, Yawen Wu, Xinrong Hu, Xiaowei Xu, Jingtong Hu, and Yiyu Shi
Multi-modal Semi-supervised Evidential Recycle Framework for Alzheimer’s Disease Classification . . . 130
Yingjie Feng, Wei Chen, Xianfeng Gu, Xiaoyin Xu, and Min Zhang
3D Arterial Segmentation via Single 2D Projections and Depth Supervision in Contrast-Enhanced CT Images . . . 141
Alina F. Dima, Veronika A. Zimmer, Martin J. Menten, Hongwei Bran Li, Markus Graf, Tristan Lemke, Philipp Raffler, Robert Graf, Jan S. Kirschke, Rickmer Braren, and Daniel Rueckert
Automatic Retrieval of Corresponding US Views in Longitudinal Examinations . . . 152
Hamideh Kerdegari, Nhat Phung Tran Huy, Van Hao Nguyen, Thi Phuong Thao Truong, Ngoc Minh Thu Le, Thanh Phuong Le, Thi Mai Thao Le, Luigi Pisani, Linda Denehy, Reza Razavi, Louise Thwaites, Sophie Yacoub, Andrew P. King, and Alberto Gomez
Many Tasks Make Light Work: Learning to Localise Medical Anomalies from Multiple Synthetic Tasks . . . 162
Matthew Baugh, Jeremy Tan, Johanna P. Müller, Mischa Dombrowski, James Batten, and Bernhard Kainz
AME-CAM: Attentive Multiple-Exit CAM for Weakly Supervised Segmentation on MRI Brain Tumor . . . 173
Yu-Jen Chen, Xinrong Hu, Yiyu Shi, and Tsung-Yi Ho
Cross-Adversarial Local Distribution Regularization for Semi-supervised Medical Image Segmentation . . . 183
Thanh Nguyen-Duc, Trung Le, Roland Bammer, He Zhao, Jianfei Cai, and Dinh Phung
AMAE: Adaptation of Pre-trained Masked Autoencoder for Dual-Distribution Anomaly Detection in Chest X-Rays . . . 195
Behzad Bozorgtabar, Dwarikanath Mahapatra, and Jean-Philippe Thiran
Gall Bladder Cancer Detection from US Images with only Image Level Labels . . . 206
Soumen Basu, Ashish Papanai, Mayank Gupta, Pankaj Gupta, and Chetan Arora
Dual Conditioned Diffusion Models for Out-of-Distribution Detection: Application to Fetal Ultrasound Videos . . . 216
Divyanshu Mishra, He Zhao, Pramit Saha, Aris T. Papageorghiou, and J. Alison Noble
Weakly-Supervised Positional Contrastive Learning: Application to Cirrhosis Classification . . . 227
Emma Sarfati, Alexandre Bône, Marc-Michel Rohé, Pietro Gori, and Isabelle Bloch
Inter-slice Consistency for Unpaired Low-Dose CT Denoising Using Boosted Contrastive Learning . . . 238
Jie Jing, Tao Wang, Hui Yu, Zexin Lu, and Yi Zhang
DAS-MIL: Distilling Across Scales for MIL Classification of Histological WSIs . . . 248
Gianpaolo Bontempo, Angelo Porrello, Federico Bolelli, Simone Calderara, and Elisa Ficarra
SLPD: Slide-Level Prototypical Distillation for WSIs . . . 259
Zhimiao Yu, Tiancheng Lin, and Yi Xu
PET Image Denoising with Score-Based Diffusion Probabilistic Models . . . 270
Chenyu Shen, Ziyuan Yang, and Yi Zhang
LSOR: Longitudinally-Consistent Self-Organized Representation Learning . . . 279
Jiahong Ouyang, Qingyu Zhao, Ehsan Adeli, Wei Peng, Greg Zaharchuk, and Kilian M. Pohl
Self-supervised Learning for Physiologically-Based Pharmacokinetic Modeling in Dynamic PET . . . 290
Francesca De Benetti, Walter Simson, Magdalini Paschali, Hasan Sari, Axel Rominger, Kuangyu Shi, Nassir Navab, and Thomas Wendler
Geometry-Invariant Abnormality Detection . . . 300
Ashay Patel, Petru-Daniel Tudosiu, Walter Hugo Lopez Pinaya, Olusola Adeleke, Gary Cook, Vicky Goh, Sebastien Ourselin, and M. Jorge Cardoso
Modeling Alzheimers’ Disease Progression from Multi-task and Self-supervised Learning Perspective with Brain Networks . . . 310
Wei Liang, Kai Zhang, Peng Cao, Pengfei Zhao, Xiaoli Liu, Jinzhu Yang, and Osmar R. Zaiane
Unsupervised Discovery of 3D Hierarchical Structure with Generative Diffusion Features . . . 320
Nurislam Tursynbek and Marc Niethammer
Domain Adaptation for Medical Image Segmentation Using Transformation-Invariant Self-training . . . 331
Negin Ghamsarian, Javier Gamazo Tejero, Pablo Márquez-Neila, Sebastian Wolf, Martin Zinkernagel, Klaus Schoeffmann, and Raphael Sznitman
Multi-IMU with Online Self-consistency for Freehand 3D Ultrasound Reconstruction . . . 342
Mingyuan Luo, Xin Yang, Zhongnuo Yan, Junyu Li, Yuanji Zhang, Jiongquan Chen, Xindi Hu, Jikuan Qian, Jun Cheng, and Dong Ni
Deblurring Masked Autoencoder Is Better Recipe for Ultrasound Image Recognition . . . 352
Qingbo Kang, Jun Gao, Kang Li, and Qicheng Lao
You’ve Got Two Teachers: Co-evolutionary Image and Report Distillation for Semi-supervised Anatomical Abnormality Detection in Chest X-Ray . . . 363
Jinghan Sun, Dong Wei, Zhe Xu, Donghuan Lu, Hong Liu, Liansheng Wang, and Yefeng Zheng
Masked Vision and Language Pre-training with Unimodal and Multimodal Contrastive Losses for Medical Visual Question Answering . . . 374
Pengfei Li, Gang Liu, Jinlong He, Zixu Zhao, and Shenjun Zhong
CL-ADDA: Contrastive Learning with Amplitude-Driven Data Augmentation for fMRI-Based Individualized Predictions . . . 384
Jiangcong Liu, Le Xu, Yun Guan, Hao Ma, and Lixia Tian
An Auto-Encoder to Reconstruct Structure with Cryo-EM Images via Theoretically Guaranteed Isometric Latent Space, and Its Application for Automatically Computing the Conformational Pathway . . . 394
Kimihiro Yamazaki, Yuichiro Wada, Atsushi Tokuhisa, Mutsuyo Wada, Takashi Katoh, Yuhei Umeda, Yasushi Okuno, and Akira Nakagawa
Knowledge Boosting: Rethinking Medical Contrastive Vision-Language Pre-training . . . 405
Xiaofei Chen, Yuting He, Cheng Xue, Rongjun Ge, Shuo Li, and Guanyu Yang
A Small-Sample Method with EEG Signals Based on Abductive Learning for Motor Imagery Decoding . . . 416
Tianyang Zhong, Xiaozheng Wei, Enze Shi, Jiaxing Gao, Chong Ma, Yaonai Wei, Songyao Zhang, Lei Guo, Junwei Han, Tianming Liu, and Tuo Zhang
Multi-modal Variational Autoencoders for Normative Modelling Across Multiple Imaging Modalities . . . 425
Ana Lawry Aguila, James Chapman, and Andre Altmann
LOTUS: Learning to Optimize Task-Based US Representations . . . 435
Yordanka Velikova, Mohammad Farid Azampour, Walter Simson, Vanessa Gonzalez Duque, and Nassir Navab
Unsupervised 3D Out-of-Distribution Detection with Latent Diffusion Models . . . 446
Mark S. Graham, Walter Hugo Lopez Pinaya, Paul Wright, Petru-Daniel Tudosiu, Yee H. Mah, James T. Teo, H. Rolf Jäger, David Werring, Parashkev Nachev, Sebastien Ourselin, and M. Jorge Cardoso
Improved Multi-shot Diffusion-Weighted MRI with Zero-Shot Self-supervised Learning Reconstruction . . . 457
Jaejin Cho, Yohan Jun, Xiaoqing Wang, Caique Kobayashi, and Berkin Bilgic
Infusing Physically Inspired Known Operators in Deep Models of Ultrasound Elastography . . . 467
Ali K. Z. Tehrani and Hassan Rivaz
Weakly Supervised Lesion Localization of Nascent Geographic Atrophy in Age-Related Macular Degeneration . . . 477
Heming Yao, Adam Pely, Zhichao Wu, Simon S. Gao, Robyn H. Guymer, Hao Chen, Mohsen Hejrati, and Miao Zhang
Can Point Cloud Networks Learn Statistical Shape Models of Anatomies? . . . 486
Jadie Adams and Shireen Y. Elhabian
CT-Guided, Unsupervised Super-Resolution Reconstruction of Single 3D Magnetic Resonance Image . . . 497
Jiale Wang, Alexander F. Heimann, Moritz Tannast, and Guoyan Zheng
Image2SSM: Reimagining Statistical Shape Models from Images with Radial Basis Functions . . . 508
Hong Xu and Shireen Y. Elhabian
MDA-SR: Multi-level Domain Adaptation Super-Resolution for Wireless Capsule Endoscopy Images . . . 518
Tianbao Liu, Zefeiyun Chen, Qingyuan Li, Yusi Wang, Ke Zhou, Weijie Xie, Yuxin Fang, Kaiyi Zheng, Zhanpeng Zhao, Side Liu, and Wei Yang
PROnet: Point Refinement Using Shape-Guided Offset Map for Nuclei Instance Segmentation . . . 528
Siwoo Nam, Jaehoon Jeong, Miguel Luna, Philip Chikontwe, and Sang Hyun Park
Self-Supervised Domain Adaptive Segmentation of Breast Cancer via Test-Time Fine-Tuning . . . 539
Kyungsu Lee, Haeyun Lee, Georges El Fakhri, Jonghye Woo, and Jae Youn Hwang
Decoupled Consistency for Semi-supervised Medical Image Segmentation . . . 551
Faquan Chen, Jingjing Fei, Yaqi Chen, and Chenxi Huang
Combating Medical Label Noise via Robust Semi-supervised Contrastive Learning . . . 562
Bingzhi Chen, Zhanhao Ye, Yishu Liu, Zheng Zhang, Jiahui Pan, Biqing Zeng, and Guangming Lu
Multi-scale Self-Supervised Learning for Longitudinal Lesion Tracking with Optional Supervision . . . 573
Anamaria Vizitiu, Antonia T. Mohaiu, Ioan M. Popdan, Abishek Balachandran, Florin C. Ghesu, and Dorin Comaniciu
Tracking Adaptation to Improve SuperPoint for 3D Reconstruction in Endoscopy . . . 583
O. León Barbed, José M. M. Montiel, Pascal Fua, and Ana C. Murillo
Structured State Space Models for Multiple Instance Learning in Digital Pathology . . . 594
Leo Fillioux, Joseph Boyd, Maria Vakalopoulou, Paul-henry Cournède, and Stergios Christodoulidis
vox2vec: A Framework for Self-supervised Contrastive Learning of Voxel-Level Representations in Medical Images . . . 605
Mikhail Goncharov, Vera Soboleva, Anvar Kurmukov, Maxim Pisov, and Mikhail Belyaev
Mesh2SSM: From Surface Meshes to Statistical Shape Models of Anatomy . . . 615
Krithika Iyer and Shireen Y. Elhabian
Graph Convolutional Network with Morphometric Similarity Networks for Schizophrenia Classification . . . 626
Hye Won Park, Seo Yeong Kim, and Won Hee Lee
M-FLAG: Medical Vision-Language Pre-training with Frozen Language Models and Latent Space Geometry Optimization . . . 637
Che Liu, Sibo Cheng, Chen Chen, Mengyun Qiao, Weitong Zhang, Anand Shah, Wenjia Bai, and Rossella Arcucci
Machine Learning - Transfer Learning
Foundation Ark: Accruing and Reusing Knowledge for Superior and Robust Performance . . . 651
DongAo Ma, Jiaxuan Pang, Michael B. Gotway, and Jianming Liang
Masked Frequency Consistency for Domain-Adaptive Semantic Segmentation of Laparoscopic Images . . . 663
Xinkai Zhao, Yuichiro Hayashi, Masahiro Oda, Takayuki Kitasaka, and Kensaku Mori
Pick the Best Pre-trained Model: Towards Transferability Estimation for Medical Image Segmentation . . . 674
Yuncheng Yang, Meng Wei, Junjun He, Jie Yang, Jin Ye, and Yun Gu
Source-Free Domain Adaptive Fundus Image Segmentation with Class-Balanced Mean Teacher . . . 684
Longxiang Tang, Kai Li, Chunming He, Yulun Zhang, and Xiu Li
Unsupervised Domain Adaptation for Anatomical Landmark Detection . . . 695
Haibo Jin, Haoxuan Che, and Hao Chen
MetaLR: Meta-tuning of Learning Rates for Transfer Learning in Medical Imaging . . . 706
Yixiong Chen, Li Liu, Jingxian Li, Hua Jiang, Chris Ding, and Zongwei Zhou
Multi-Target Domain Adaptation with Prompt Learning for Medical Image Segmentation . . . 717
Yili Lin, Dong Nie, Yuting Liu, Ming Yang, Daoqiang Zhang, and Xuyun Wen
Spectral Adversarial MixUp for Few-Shot Unsupervised Domain Adaptation . . . 728
Jiajin Zhang, Hanqing Chao, Amit Dhurandhar, Pin-Yu Chen, Ali Tajer, Yangyang Xu, and Pingkun Yan
Cross-Dataset Adaptation for Instrument Classification in Cataract Surgery Videos . . . 739
Jay N. Paranjape, Shameema Sikder, Vishal M. Patel, and S. Swaroop Vedula
Black-box Domain Adaptative Cell Segmentation via Multi-source Distillation . . . 749
Xingguang Wang, Zhongyu Li, Xiangde Luo, Jing Wan, Jianwei Zhu, Ziqi Yang, Meng Yang, and Cunbao Xu
MedGen3D: A Deep Generative Framework for Paired 3D Image and Mask Generation . . . 759
Kun Han, Yifeng Xiong, Chenyu You, Pooya Khosravi, Shanlin Sun, Xiangyi Yan, James S. Duncan, and Xiaohui Xie
Unsupervised Domain Transfer with Conditional Invertible Neural Networks . . . 770
Kris K. Dreher, Leonardo Ayala, Melanie Schellenberg, Marco Hübner, Jan-Hinrich Nölke, Tim J. Adler, Silvia Seidlitz, Jan Sellner, Alexander Studier-Fischer, Janek Gröhl, Felix Nickel, Ullrich Köthe, Alexander Seitel, and Lena Maier-Hein
Author Index . . . 781
Machine Learning with Limited Supervision
PET-Diffusion: Unsupervised PET Enhancement Based on the Latent Diffusion Model
Caiwen Jiang1, Yongsheng Pan1, Mianxin Liu3, Lei Ma1, Xiao Zhang1, Jiameng Liu1, Xiaosong Xiong1, and Dinggang Shen1,2,4(B)
1 School of Biomedical Engineering, ShanghaiTech University, Shanghai, China
{jiangcw,panysh,dgshen}@shanghaitech.edu.cn
2 Shanghai United Imaging Intelligence Co., Ltd., Shanghai 200230, China
3 Shanghai Artificial Intelligence Laboratory, Shanghai 200232, China
4 Shanghai Clinical Research and Trial Center, Shanghai 201210, China
Abstract. Positron emission tomography (PET) is an advanced nuclear imaging technique with an irreplaceable role in neurology and oncology studies, but its accessibility is often limited by the radiation hazards inherent in imaging. To address this dilemma, PET enhancement methods have been developed that improve low-dose PET (LPET) images to the quality of standard-dose PET (SPET) images. However, previous PET enhancement methods rely heavily on paired LPET and SPET data, which are rare in clinical practice. Thus, in this paper, we propose an unsupervised PET enhancement (uPETe) framework based on the latent diffusion model, which can be trained only on SPET data. Specifically, our SPET-only uPETe consists of an encoder that compresses the input SPET/LPET images into latent representations, a latent diffusion model that learns/estimates the distribution of SPET latent representations, and a decoder that recovers SPET images from the latent representations. Moreover, motivated by the characteristics of actual PET imaging, we improve the latent diffusion model of uPETe by 1) adopting PET image compression to reduce the computational cost of the diffusion model, 2) using Poisson diffusion instead of Gaussian diffusion to make the perturbed samples closer to actual noisy PET, and 3) designing CT-guided cross-attention to incorporate additional CT images into the inverse process and aid the recovery of structural details in PET. In extensive experiments, uPETe achieves superior performance over state-of-the-art methods and shows stronger generalizability to changes in PET imaging dose.
The code of our implementation is available at https://github.com/jiang-cw/PET-diffusion. Keywords: Positron emission tomography (PET) · Enhancement · Latent diffusion model · Poisson diffusion · CT-guided cross-attention
1 Introduction
Positron emission tomography (PET) is a sensitive nuclear imaging technique that plays an essential role in the early diagnosis of diseases such as cancer and Alzheimer's disease [8].
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023. H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 3–12, 2023. https://doi.org/10.1007/978-3-031-43907-0_1
C. Jiang et al.
However, acquiring high-quality PET images requires injecting a sufficient (standard) dose of radionuclides into the human body, which poses unacceptable radiation risks for pregnant women and infants, even when the As Low As Reasonably Achievable (ALARA) principle [19] is followed. Besides upgrading imaging hardware, a promising way to reduce radiation exposure is to design advanced PET enhancement algorithms that improve low-dose PET (LPET) images to the quality of standard-dose PET (SPET) images. In recent years, many enhancement algorithms have been proposed to improve PET image quality. Among the earliest are filtering-based methods such as the non-local means (NLM) filter [1], block-matching 3D filter [4], bilateral filter [7], and guided filter [22], which are quite robust but tend to over-smooth images and suppress high-frequency details. Subsequently, with the development of deep learning, end-to-end PET enhancement networks [9,14,21] were proposed and achieved significant performance improvements. However, these supervised methods rely heavily on paired LPET and SPET data, which are rare in clinical practice because of radiation exposure and involuntary motion (e.g., respiration and muscle relaxation). Consequently, unsupervised PET enhancement methods such as deep image prior [3], Noise2Noise [12,20], and their variants [17] were developed to overcome this limitation. However, these methods still require LPET to train models, which conflicts with the fact that only SPET scans are acquired in clinical routine. Fortunately, the recently popular diffusion model [6] suggests a clinically applicable PET enhancement approach whose training relies only on SPET data. Generally, a diffusion model consists of two mutually inverse processes: the forward diffusion adds noise to a clean image until it becomes pure noise, while the reverse process removes noise, starting from pure noise, until a clean image is recovered.
Combining the mechanics of the diffusion model with the observation that LPET and SPET differ mainly in their noise levels [11], we can view LPET and SPET as results at different stages of a single diffusion process. Therefore, a diffusion model trained only on SPET that can recover SPET from noisy samples can also recover SPET from LPET. However, extending diffusion models developed for 2D photographic images to PET enhancement still faces two problems: a) three-dimensional (3D) PET images dramatically increase the computational cost of the diffusion model; and b) PET images are detail-sensitive, and details may be introduced or lost while noise is added or removed, which would affect downstream diagnosis. Taking these into consideration, we propose an SPET-only unsupervised PET enhancement (uPETe) framework based on the latent diffusion model. Specifically, uPETe has an encoder-decoder structure: the encoder first compresses the input LPET/SPET images into latent representations, the latent diffusion model then learns/estimates the distribution of SPET latent representations, and the decoder finally recovers SPET images from the estimated SPET latent representations. The key components of our uPETe
PET-Diffusion: Unsupervised PET Enhancement
Fig. 1. Overview of the proposed uPETe. (a) and (b) show the framework of uPETe and its operation in the training and testing phases, and (c) illustrates the details of CT-guided cross-attention.
include 1) compressing the 3D PET images into a lower-dimensional space to reduce the computational cost of the diffusion model, 2) replacing the Gaussian noise in the diffusion process with Poisson noise, the dominant noise in PET imaging [20], to avoid introducing details that do not exist in PET images, and 3) designing CT-guided cross-attention to incorporate additional CT images into the inverse process and help recover structural details in PET. Our work has three main contributions: i) a clinically applicable unsupervised PET enhancement framework; ii) three targeted strategies for improving the diffusion model, namely PET image compression, Poisson diffusion, and CT-guided cross-attention; and iii) better performance than state-of-the-art methods on the collected PET datasets.
2 Method
The framework of uPETe is illustrated in Fig. 1. Given an input PET image $x$ (i.e., SPET for training and LPET for testing), $x$ is first compressed into the latent representation $z_0$ by the encoder $E$. Subsequently, $z_0$ is fed into a latent diffusion model followed by the decoder $D$ to output the estimated SPET image $\hat{x}$. In addition, a dedicated encoder $E_{CT}$ compresses the CT image corresponding to the input PET image into the latent representation $z_{CT}$, which is fed into the denoising network at each step for CT-guided cross-attention. In the following, we describe the details of image compression, the latent diffusion model, and the implementation.
2.1 Image Compression
The conventional diffusion model is computationally demanding because of its numerous inverse denoising steps, which severely restricts its application to 3D PET enhancement. To overcome this limitation, we adopt two strategies: 1) compressing the input image and 2) reducing the number of diffusion steps (as described in Sect. 2.3). Similar to [10,18], we adopt an autoencoder ($E$ and $D$) to compress the 3D PET images into a lower-dimensional, more compact space. The crucial aspect of this process is to ensure that the latent representation contains the necessary and representative information of the input image. To achieve this, we train the autoencoder with a combination of a perceptual loss [24] and a patch-based adversarial loss [5], instead of a simple voxel-level loss such as the L2 or L1 loss. The perceptual loss, built on a pre-trained 3D ResNet [2], constrains higher-level information such as texture and semantic content, while the patch-based adversarial loss keeps the reconstructions globally coherent and locally realistic. Let $x \in \mathbb{R}^{H \times W \times Z}$ denote the input image and $z_0 \in \mathbb{R}^{h \times w \times z \times c}$ denote the latent representation. The compression process can be formulated as $\hat{x} = D(z_0) = D(E(x))$. In this way, we compress the input image by a factor of $f = H/h = W/w = Z/z$. The results of SPET estimation under different compression rates $f$ are provided in the supplement.
2.2 Latent Diffusion Model
After compressing the input PET image, its latent representation is fed into the latent diffusion model, which is the key to achieving SPET-only unsupervised PET enhancement. As described above, LPET can be viewed as noisy SPET (even in the compressed space), so the diffusion process from SPET to pure noise also passes through states that correspond to LPET. That is, a diffusion model trained with SPET is capable of estimating SPET from a noisy sample diffused from LPET. However, diffusion models were developed for photographic images, which differ significantly from detail-sensitive PET images. To improve their applicability to PET images, we design two targeted strategies: Poisson diffusion for the diffusion process and CT-guided cross-attention for the inverse process.
Poisson Diffusion. In conventional diffusion models, the forward process typically employs Gaussian noise to gradually perturb input samples. However, in PET images, the dominant source of noise is Poisson noise rather than Gaussian noise. Considering this, uPETe adopts Poisson diffusion to perturb the input samples, which helps the diffusion model achieve better performance on the PET enhancement task. Let $z_t$ denote the perturbed sample at step $t$ of Poisson diffusion, where $t = 0, 1, \dots, T$. The Poisson diffusion can then be formulated as follows:
$$z_t = \mathrm{perturb}(z_{t-1}, \lambda_t), \qquad \lambda_1 < \lambda_2 < \cdots < \lambda_T. \qquad (1)$$
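Eq. (1), together with the schedule for $\lambda_t$ described in the next paragraph, can be sketched as follows. This is a hypothetical, simplified illustration rather than the authors' implementation: it adds centered (zero-mean) Poisson noise with rate $\lambda_t$ directly to latent values, whereas the paper applies Poisson deviates to projected sinograms following [20], which this sketch does not reproduce. The function names and the `noise_scale` parameter are our own.

```python
import numpy as np

def poisson_schedule(T):
    """Rates lambda_1 < ... < lambda_T, linearly spaced over (0, 1]."""
    return np.linspace(0.0, 1.0, T + 1)[1:]

def perturb(z_prev, lam, rng, noise_scale=1.0):
    """Hypothetical image-domain perturb(): add Poisson(lam) noise, centered by
    subtracting its mean lam so the expected intensity does not drift."""
    noise = rng.poisson(lam, size=z_prev.shape) - lam
    return z_prev + noise_scale * noise

def forward_poisson_diffusion(z0, T, seed=0, noise_scale=1.0):
    """Return [z_0, z_1, ..., z_T] with z_t = perturb(z_{t-1}, lambda_t)."""
    rng = np.random.default_rng(seed)
    samples = [z0]
    for lam in poisson_schedule(T):
        samples.append(perturb(samples[-1], lam, rng, noise_scale))
    return samples
```

Because each step adds independent noise with variance $\lambda_t \cdot$ `noise_scale`², later steps inject more noise than earlier ones, consistent with $\lambda_1 < \lambda_2 < \cdots < \lambda_T$.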
At each diffusion step, the perturb function imposes Poisson noise with expectation $\lambda_t$ on the previous perturbed sample $z_{t-1}$, where $\lambda_t$ is linearly interpolated over [0, 1] and increases with $t$. In our implementation, we apply the same Poisson noise imposition operation as in [20], i.e., applying Poisson deviates to the projected sinograms, to generate a sequence of perturbed samples whose Poisson noise intensity increases with the step number $t$.
CT-Guided Cross-Attention. Attenuation correction of PET typically relies on a corresponding anatomical image (CT or MR), so a PET scan is usually accompanied by a CT or MR scan. To fully utilize the extra-modality images (i.e., CT in our work) and to improve the applicability of diffusion models, we design CT-guided cross-attention, which incorporates the CT images into the reverse process to assist the recovery of structural details. As shown in Fig. 1, to obtain a particular SPET estimate, the corresponding CT image is first compressed into the latent representation $z_{CT}$ by the encoder $E_{CT}$. Then $z_{CT}$ is fed into a denoising attention U-Net [16] at each step to compute cross-attention, where the query $Q$ and key $K$ are calculated from $z_{CT}$, while the value $V$ is still calculated from the output of the previous layer, because our final goal is SPET estimation. Denoting the output of the previous layer as $z_{PET}$, the CT-guided cross-attention can be formulated as follows:
$$\mathrm{Output} = \mathrm{softmax}\left(\frac{Q_{CT} K_{CT}^{T}}{\sqrt{d}} + B\right) \cdot V_{PET},$$
$$Q_{CT} = \mathrm{Conv}_Q(z_{CT}), \quad K_{CT} = \mathrm{Conv}_K(z_{CT}), \quad V_{PET} = \mathrm{Conv}_V(z_{PET}), \qquad (2)$$
where $d$ is the number of channels, $B$ is the position bias, and $\mathrm{Conv}(\cdot)$ denotes a 1 × 1 × 1 convolution with a stride of 1.
2.3 Implementation Details
Typically, a trained diffusion model generates target images from random noise, which requires a large number of steps $T$ so that the final perturbed sample $z_T$ is close to pure noise. In our task, however, the target SPET image is generated from a given LPET image during testing, so making $z_T$ as close to pure noise as possible is unnecessary: the PET-related information that remains in $z_T$ also benefits image recovery. Therefore, we can considerably reduce the number of diffusion steps $T$ to accelerate model training; $T$ is set to 400 in our implementation. We evaluate the quantitative results with two metrics: peak signal-to-noise ratio (PSNR) and structural similarity index measure (SSIM).
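To make the CT-guided cross-attention of Eq. (2) in Sect. 2.2 concrete, the following NumPy sketch computes it for a single attention head on a small 3D feature map. It rests on assumptions that are ours, not the paper's: the CT latent and the PET feature map share the same spatial size, voxels are flattened into $N$ tokens, $B$ is supplied as an $N \times N$ bias matrix, and each 1 × 1 × 1 convolution is written as a per-voxel linear map over channels (which is what a 1 × 1 × 1 convolution with stride 1 computes). All function names are hypothetical.

```python
import numpy as np

def conv1x1x1(feat, weight, bias):
    """1x1x1 convolution with stride 1: a per-voxel linear map over channels.
    feat: (C_in, X, Y, Z); weight: (C_out, C_in); bias: (C_out,)."""
    return np.einsum("oc,cxyz->oxyz", weight, feat) + bias[:, None, None, None]

def softmax(scores, axis=-1):
    shifted = scores - scores.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)

def ct_guided_cross_attention(z_ct, z_pet, params, pos_bias):
    """Eq. (2): queries and keys come from the CT latent, values from the PET
    features. Returns attended features (C_v, X, Y, Z) and the (N, N) map."""
    _, X, Y, Z = z_pet.shape
    n_tokens = X * Y * Z
    q = conv1x1x1(z_ct, *params["q"]).reshape(-1, n_tokens).T   # (N, d)
    k = conv1x1x1(z_ct, *params["k"]).reshape(-1, n_tokens).T   # (N, d)
    v = conv1x1x1(z_pet, *params["v"]).reshape(-1, n_tokens).T  # (N, C_v)
    d = q.shape[1]  # number of query/key channels
    attn = softmax(q @ k.T / np.sqrt(d) + pos_bias, axis=-1)     # (N, N)
    out = attn @ v                                                # (N, C_v)
    return out.T.reshape(-1, X, Y, Z), attn
```

Because $Q$ and $K$ depend only on CT, the CT latent decides where each voxel attends, while the values, and hence the content of the output, stay PET-derived; this mirrors the stated rationale that the final goal is SPET estimation.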
Table 1. Quantitative results of the ablation analysis, in terms of PSNR and SSIM.
Method | PSNR [dB] ↑ | SSIM ↑
LDM | 23.732 ± 1.264 | 0.986 ± 0.010
LDM-P | 24.125 ± 1.072 | 0.987 ± 0.009
LDM-CT | 25.348 ± 0.822 | 0.990 ± 0.006
LDM-P-CT | 25.817 ± 0.675 | 0.992 ± 0.004
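The PSNR values in Table 1 are reported in dB. Since image intensities are min-max normalized to [0, 1] (Sect. 3.1), a natural peak value is 1. The sketch below assumes that choice; the paper does not state its exact PSNR convention (e.g., per volume or per patch, or the peak value), so treat this as an illustration of the metric, not the evaluation script.

```python
import numpy as np

def psnr(estimate, reference, data_range=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    estimate = np.asarray(estimate, dtype=np.float64)
    reference = np.asarray(reference, dtype=np.float64)
    mse = np.mean((estimate - reference) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))
```

For example, a uniform error of 0.1 on [0, 1] data gives MSE = 0.01 and PSNR = 20 dB; each additional 10 dB corresponds to a tenfold reduction in MSE.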
Fig. 2. Generalizability to dose changes.
3 Experiments
3.1 Dataset
Our dataset consists of 100 SPET images for training and 30 paired LPET and SPET images for testing. Among them, 50 chest-abdomen SPET images are collected from the (total-body) uEXPLORER PET/CT scanner [25], and 20 paired chest-abdomen images are collected in list mode on the same scanner with a 256 MBq [18F]-FDG injection. Specifically, the SPET images are reconstructed from the 1200 s of data acquired 60–80 min after tracer injection, while the corresponding LPET images are reconstructed from 120 s of data uniformly sampled from the same 1200 s acquisition. As basic preprocessing, all images are resampled to a voxel spacing of 2 × 2 × 2 mm³ and a resolution of 256 × 256 × 160, and their intensities are normalized to [0, 1] by min-max normalization. To increase the number of training samples and reduce GPU memory requirements, we extract overlapping patches of size 96 × 96 × 96 from each whole PET image.
3.2 Ablation Analysis
To verify the effectiveness of our proposed strategies, i.e., the Poisson diffusion process and CT-guided cross-attention, we compare four latent diffusion model (LDM) variants that share the same compression model: 1) LDM: the standard LDM; 2) LDM-P: LDM with the Poisson diffusion process; 3) LDM-CT: LDM with CT-guided cross-attention; and 4) LDM-P-CT: LDM with both the Poisson diffusion process and CT-guided cross-attention. All methods use the same experimental settings, and their quantitative results are given in Table 1, from which we make the following observations. (1) LDM-P achieves better performance than LDM, indicating that Poisson diffusion is more appropriate than Gaussian diffusion for PET enhancement. (2) LDM-CT, which uses the corresponding CT image to assist denoising, achieves better results than LDM. This is reasonable because the CT image provides anatomical information that benefits the recovery of structural details (e.g., organ boundaries) in SPET images. (3) LDM-P-CT achieves better results than all other variants
Table 2. Quantitative comparison of uPETe with several state-of-the-art PET enhancement methods, in terms of PSNR and SSIM, where ∗ denotes an unsupervised method and † a fully supervised method.
Method | PSNR [dB] ↑ | SSIM ↑
DIP∗ [3] | 22.538 ± 2.136 | 0.981 ± 0.015
Noisier2Noise∗ [23] | 22.932 ± 1.983 | 0.983 ± 0.014
LA-GAN† [21] | 23.351 ± 1.725 | 0.984 ± 0.012
MR-GDD∗ [17] | 23.628 ± 1.655 | 0.985 ± 0.011
Trans-GAN† [13] | 23.852 ± 1.522 | 0.985 ± 0.009
Noise2Void∗ [20] | 24.263 ± 1.351 | 0.987 ± 0.009
DF-GAN† [9] | 24.821 ± 0.975 | 0.989 ± 0.007
AR-GAN† [14] | 25.217 ± 0.853 | 0.990 ± 0.006
uPETe∗ | 25.817 ± 0.675 | 0.992 ± 0.004
on both PSNR and SSIM, which shows that both of our proposed strategies contribute to the final performance. Together, these three comparisons verify the effectiveness of the uPETe design: the Poisson diffusion process and CT-guided cross-attention both benefit PET enhancement.
3.3 Comparison with State-of-the-Art Methods
We further compare uPETe with several state-of-the-art PET enhancement methods, which fall into two classes: 1) fully supervised methods, including LA-GAN [21], Transformer-GAN (Trans-GAN) [13], Dual-frequency GAN (DF-GAN) [9], and AR-GAN [14]; and 2) unsupervised methods, including deep image prior (DIP) [3], Noisier2Noise [23], magnetic resonance guided deep decoder (MR-GDD) [17], and Noise2Void [20]. The quantitative and qualitative results are provided in Table 2 and Fig. 3, respectively.
Quantitative Comparison: Table 2 shows that uPETe outperforms all competing methods. Compared with AR-GAN, the fully supervised method with the second-best performance, uPETe still achieves an improvement even though it does not require paired LPET and SPET data. uPETe also clearly outperforms Noise2Void, the best-performing unsupervised competitor, with average improvements in PSNR and SSIM of 1.554 dB and 0.005, respectively. This suggests that uPETe can generate promising results without relying on paired data, demonstrating its potential for clinical applications.
Qualitative Comparison: In Fig. 3, we provide a visual comparison of SPET estimation for two typical cases. First, compared with unsupervised methods such as DIP and Noise2Void, the SPET images estimated by uPETe have less noise and clearer boundaries. Second, uPETe recovers structural details better than the fully supervised methods, which either miss unclear tissue (Trans-GAN) or introduce artifacts that do not exist in the PET image (DF-GAN).
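The improvement figures quoted in the quantitative comparison can be checked directly against the mean values in Table 2. The short script below is a worked check, not part of the paper's code; it computes the margin of uPETe over each competitor, and the margin over Noise2Void reproduces the 1.554 dB and 0.005 quoted in the text.

```python
# Mean PSNR (dB) and SSIM from Table 2.
# Unsupervised: DIP, Noisier2Noise, MR-GDD, Noise2Void.
# Fully supervised: LA-GAN, Trans-GAN, DF-GAN, AR-GAN.
TABLE2 = {
    "DIP": (22.538, 0.981),
    "Noisier2Noise": (22.932, 0.983),
    "LA-GAN": (23.351, 0.984),
    "MR-GDD": (23.628, 0.985),
    "Trans-GAN": (23.852, 0.985),
    "Noise2Void": (24.263, 0.987),
    "DF-GAN": (24.821, 0.989),
    "AR-GAN": (25.217, 0.990),
}
UPETE = (25.817, 0.992)

def margins(table, ours):
    """Return {method: (PSNR gain in dB, SSIM gain)} of `ours` over each method."""
    return {name: (round(ours[0] - p, 3), round(ours[1] - s, 3))
            for name, (p, s) in table.items()}

for name, (dpsnr, dssim) in margins(TABLE2, UPETE).items():
    print(f"{name:15s} +{dpsnr:.3f} dB  +{dssim:.3f} SSIM")
```

The same script shows the margin over AR-GAN, the strongest fully supervised method, is 0.600 dB in PSNR and 0.002 in SSIM.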
[Fig. 3 columns, left to right: LPET, SPET, DIP, Trans-GAN, Noise2Void, DF-GAN, AR-GAN, uPETe]
Fig. 3. Visual comparison of estimated SPET images on two typical cases. In each case, the first and second rows show the axial and coronal views, respectively, and from left to right are the input (LPET), ground truth (SPET), results by five other methods (3rd–7th columns), and the result by our uPETe (last column). Red boxes and arrows show areas for detailed comparison. (Color figure online)
Overall, these results demonstrate the superiority of uPETe over state-of-the-art methods.
3.4 Generalization Evaluation
We further evaluate the generalizability of uPETe to tracer dose changes by simulating Poisson noise on SPET to produce LPET at different doses, a common way to generate noisy PET data [20]. Notably, the models are not retrained; we reuse those trained in Sect. 3.3. The quantitative results of uPETe and five state-of-the-art methods are provided in Fig. 2. As shown in Fig. 2, uPETe outperforms the other five methods at all doses and shows a smaller PSNR drop as the dose decreases (i.e., as λ increases), demonstrating its superior generalizability to dose changes. This is because uPETe is based on a diffusion model, which decomposes the complex distribution-prediction task into a series of simple denoising tasks and therefore generalizes well. Moreover, we find that the unsupervised methods (uPETe, Noise2Void, and DIP) generalize better than the fully supervised methods (AR-GAN, DF-GAN, and Trans-GAN), as their PSNR declines more smoothly. The main reason is that unsupervised learning extracts patterns and features based on the inherent structure and distribution of the data itself [15].
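A common image-domain way to emulate lower doses is sketched below: treat normalized SPET intensities as expected counts, scale them by a dose fraction, draw Poisson samples, and rescale. This is an illustrative assumption, not the Fig. 2 protocol: the function name and the `counts_per_unit` scale are hypothetical, [20] applies Poisson deviates to projected sinograms, and the paper parameterizes dose by λ rather than by a dose fraction. The sketch only illustrates the trend that relative noise grows roughly as 1/√dose.

```python
import numpy as np

def simulate_low_dose(spet, dose_fraction, counts_per_unit=1000.0, seed=0):
    """Hypothetical reduced-dose emulation by Poisson resampling of scaled counts.
    spet: non-negative intensities (e.g., normalized to [0, 1]).
    dose_fraction: in (0, 1]; smaller values give noisier images."""
    if not 0.0 < dose_fraction <= 1.0:
        raise ValueError("dose_fraction must be in (0, 1]")
    rng = np.random.default_rng(seed)
    expected_counts = np.clip(spet, 0.0, None) * counts_per_unit * dose_fraction
    counts = rng.poisson(expected_counts)
    # Rescale so the expected intensity matches the input image.
    return counts / (counts_per_unit * dose_fraction)
```

With `counts_per_unit=1000`, a voxel of intensity 0.5 has 50 expected counts at 10% dose and 250 at 50% dose, so its relative noise drops from about 14% to about 6%.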
4 Conclusion and Limitations
In this paper, we have developed a clinically applicable unsupervised PET enhancement framework based on the latent diffusion model, which uses only clinically available SPET data for training. We adopt three strategies to improve the applicability of diffusion models developed for photographic images to PET enhancement: 1) compressing the size of the input image, 2) using Poisson diffusion instead of Gaussian diffusion, and 3) designing CT-guided cross-attention to enable additional anatomical images (e.g., CT) to aid the recovery of structural details in PET. In extensive experiments, uPETe achieved better performance than state-of-the-art unsupervised and fully supervised PET enhancement methods and showed stronger generalizability to tracer dose changes. Despite these advances, our current work has a few limitations: 1) our Poisson diffusion lacks theoretical support and is so far only an engineering attempt, and 2) the generalizability of uPETe has been validated only on a simulated dataset. In future work, we will develop the design of Poisson diffusion from a theoretical perspective and collect more real PET datasets (e.g., head datasets) to comprehensively validate the generalizability of uPETe.
Acknowledgment. This work was supported in part by National Natural Science Foundation of China (No. 62131015), Science and Technology Commission of Shanghai Municipality (STCSM) (No. 21010502600), The Key R&D Program of Guangdong Province, China (No. 2021B0101420006), and the China Postdoctoral Science Foundation (Nos. BX2021333, 2021M703340).
References
1. Buades, A., Coll, B., Morel, J.: A non-local algorithm for image denoising. In: 2005 IEEE Computer Society Conference on Computer Vision and Pattern Recognition, vol. 2, pp. 60–65 (2005)
2. Chen, S., Ma, K., Zheng, Y.: Med3D: transfer learning for 3D medical image analysis. arXiv preprint arXiv:1904.00625 (2019)
3. Cui, J., et al.: PET image denoising using unsupervised deep learning. Eur. J. Nucl. Med. Mol. Imaging 46(13), 2780–2789 (2019)
4. Dabov, K., Foi, A., Katkovnik, V., Egiazarian, K.: Image denoising with block-matching and 3D filtering. Image Process. Algorithms Syst. Neural Netw. Mach. Learn. 6064, 354–365 (2006)
5. Dosovitskiy, A., Brox, T.: Generating images with perceptual similarity metrics based on deep networks. In: Advances in Neural Information Processing Systems, vol. 29 (2016)
6. Ho, J., Jain, A., Abbeel, P.: Denoising diffusion probabilistic models. In: Advances in Neural Information Processing Systems, vol. 33, pp. 6840–6851 (2020)
7. Hofheinz, F., et al.: Suitability of bilateral filtering for edge-preserving noise reduction in PET. EJNMMI Res. 1(1), 1–9 (2011)
8. Jiang, C., Pan, Y., Cui, Z., Nie, D., Shen, D.: Semi-supervised standard-dose PET image generation via region-adaptive normalization and structural consistency constraint. IEEE Trans. Med. Imaging (2023)
9. Jiang, C., Pan, Y., Cui, Z., Shen, D.: Reconstruction of standard-dose PET from low-dose PET via dual-frequency supervision and global aggregation module. In: 2022 IEEE 19th International Symposium on Biomedical Imaging (ISBI), pp. 1–5 (2022)
10. Khader, F., et al.: Medical diffusion-denoising diffusion probabilistic models for 3D medical image generation. arXiv preprint arXiv:2211.03364 (2022)
11. Lu, W., et al.: An investigation of quantitative accuracy for deep learning based denoising in oncological PET. Phys. Med. Biol. 64(16), 165019 (2019)
12. Lu, Z., Li, Z., Wang, J., Shen, D.: Two-stage self-supervised cycle-consistency network for reconstruction of thin-slice MR images. arXiv preprint arXiv:2106.15395 (2021)
13. Luo, Y., et al.: 3D transformer-GAN for high-quality PET reconstruction. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12906, pp. 276–285. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87231-1_27
14. Luo, Y., et al.: Adaptive rectification based adversarial network with spectrum constraint for high-quality PET image synthesis. Med. Image Anal. 77, 102335 (2022)
15. Noroozi, M., Favaro, P.: Unsupervised learning of visual representations by solving jigsaw puzzles. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9910, pp. 69–84. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46466-4_5
16. Oktay, O., et al.: Attention U-Net: learning where to look for the pancreas. arXiv preprint arXiv:1804.03999 (2018)
17. Onishi, Y., et al.: Anatomical-guided attention enhances unsupervised PET image denoising performance. Med. Image Anal. 74, 102226 (2021)
18. Rombach, R., Blattmann, A., Lorenz, D., Esser, P., Ommer, B.: High-resolution image synthesis with latent diffusion models. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10684–10695 (2022)
19. Slovis, T.L.: The ALARA concept in pediatric CT: myth or reality? Radiology 223(1), 5–6 (2002)
20. Song, T., Yang, F., Dutta, J.: Noise2Void: unsupervised denoising of PET images. Phys. Med. Biol. 66(21), 214002 (2021)
21. Wang, Y., et al.: 3D auto-context-based locality adaptive multi-modality GANs for PET synthesis. IEEE Trans. Med. Imaging 38(6), 1328–1339 (2019)
22. Yan, J., Lim, J., Townsend, D.: MRI-guided brain PET image filtering and partial volume correction. Phys. Med. Biol. 60(3), 961 (2015)
23. Yie, S., Kang, S., Hwang, D., Lee, J.: Self-supervised PET denoising. Nucl. Med. Mol. Imaging 54(6), 299–304 (2020)
24. Zhang, R., Isola, P., Efros, A.A., Shechtman, E., Wang, O.: The unreasonable effectiveness of deep features as a perceptual metric. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 586–595 (2018)
25. Zhang, X., et al.: Total-body dynamic reconstruction and parametric imaging on the uEXPLORER. J. Nucl. Med. 61(2), 285–291 (2020)
MedIM: Boost Medical Image Representation via Radiology Report-Guided Masking
Yutong Xie1, Lin Gu2,3, Tatsuya Harada2,3, Jianpeng Zhang4, Yong Xia4, and Qi Wu1(B)
1 Australian Institute for Machine Learning, The University of Adelaide, Adelaide, Australia
2 RIKEN AIP, Tokyo, Japan
3 RCAST, The University of Tokyo, Tokyo, Japan
4 School of Computer Science and Engineering, Northwestern Polytechnical University, Xi’an, China
Abstract. Masked image modelling (MIM)-based pre-training shows promise in improving image representations with limited annotated data by randomly masking image patches and reconstructing them. However, random masking may not be suitable for medical images due to their unique pathology characteristics. This paper proposes Masked medical Image Modelling (MedIM), to our knowledge the first approach that masks and reconstructs discriminative areas guided by radiology reports, encouraging the network to explore stronger semantic representations from medical images. We introduce two mutually complementary masking strategies, knowledge word-driven masking (KWM) and sentence-driven masking (SDM). KWM uses Medical Subject Headings (MeSH) words unique to radiology reports to identify discriminative cues mapped to MeSH words and guide mask generation. SDM considers that reports usually have multiple sentences, each describing different findings, and therefore integrates sentence-level information to identify discriminative regions for mask generation. MedIM integrates both strategies by simultaneously restoring the images masked by KWM and SDM, yielding a more robust and representative medical visual representation. Our extensive experiments on various downstream tasks, covering multi-label/class image classification, medical image segmentation, and medical image-text analysis, demonstrate that MedIM with report-guided masking achieves competitive performance. Our method substantially outperforms ImageNet pre-training, MIM-based pre-training, and medical image-report pre-training counterparts. Code is available at https://github.com/YtongXie/MedIM.
Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_2.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023. H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 13–23, 2023. https://doi.org/10.1007/978-3-031-43907-0_2
Y. Xie et al.
1 Introduction
Accurate medical representation is crucial for clinical decision-making. Deep learning has shown promising results in medical image analysis, but the accuracy of these models relies heavily on the quality and quantity of data and annotations [21]. Masked image modelling (MIM)-based pre-training approaches [3,8,23], such as masked autoencoders (MAE) [8], have shown promise for improving image representations when annotated data are limited. MIM masks a set of image patches before inputting them into a network and then reconstructs these masked patches by aggregating information from the surrounding context. This ability to aggregate contextual information is essential for vision tasks and for medical image understanding [24]. Recently, MIM has achieved much success in the medical domain [4–6,11,20,24], for example in chest X-ray and CT image analysis. Current MIM-based works commonly use a random masking strategy, which randomly selects a percentage of patches to mask. We argue that such a strategy may not be the most suitable for medical images because of the particularities of the domain. Medical images commonly present relatively fixed anatomical structures, while subtle variations between individuals, such as sporadic lesions that alter the texture and morphology of surrounding tissues or organs, may exist. These pathology characteristics may be minute and challenging to perceive visually but are indispensable for early screening and clinical diagnosis. Representation learning should capture these target characteristics to improve the reliability, interpretability, and generalizability of downstream diagnosis models. Random masking is unlikely to focus deliberately on these important parts. We therefore put forward a straightforward principle, i.e., masking and reconstructing meaningful characteristics, encouraging the network to explore stronger representations from medical images.
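The random masking strategy described above can be sketched as follows, in the style of MAE [8]: a fixed fraction of patch indices is chosen uniformly at random, independent of image content. The function name and the default ratio of 0.75 are illustrative assumptions, not values taken from this paper.

```python
import numpy as np

def random_patch_mask(num_patches, mask_ratio=0.75, seed=None):
    """Return a boolean mask of shape (num_patches,); True marks a masked patch.
    Every patch is equally likely to be masked, regardless of its content."""
    rng = np.random.default_rng(seed)
    num_masked = int(round(num_patches * mask_ratio))
    masked_idx = rng.permutation(num_patches)[:num_masked]
    mask = np.zeros(num_patches, dtype=bool)
    mask[masked_idx] = True
    return mask
```

For example, a 224 × 224 image split into 16 × 16 patches has 196 patches, and a ratio of 0.75 masks 147 of them, chosen without regard to where lesions are.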
We advocate using radiological reports to locate relevant characteristics and guide mask generation. These reports are routinely produced in clinical practice by medical experts such as radiologists, and they provide a valuable source of semantic knowledge at little to no additional cost [9,17]. When medical professionals read a medical image, they focus on the areas relevant to the patient's clinical condition. These areas are then recorded in a report, together with relevant information such as whether they are normal or abnormal, the location and density of abnormal areas, and any other information about the patient's condition. By incorporating reports into medical image representation learning, models can simulate the professionals' gaze and learn to focus on the pathological characteristics of images.
In this paper, we propose a new approach called MedIM (Masked medical Image Modelling). MedIM aligns semantic correspondences between medical images and radiology reports and reconstructs regions masked under the guidance of the learned correspondences. Specifically, we introduce two masking strategies: knowledge word-driven masking (KWM) and sentence-driven masking (SDM). KWM uses Medical Subject Headings (MeSH) words [14] as domain knowledge. MeSH words provide a standardized language for medical concepts and conditions. In radiology reports, MeSH words describe imaging modalities, anatomic locations, and pathologic findings, such as "Heart", "Pulmonary", "Vascular", and "Pneumothorax" in Fig. 1, and are important semantic components. This inspired KWM to identify the regions mapped to MeSH words and generate an attention map, in which highly activated tokens indicate more discriminative cues. We use this attention map to selectively mask and then restore the highly activated regions, encouraging the network to focus more on regions related to MeSH words during modelling. SDM considers the multiple sentences in a report, each of which may provide independent information about a different aspect of the image. It generates an attention map by identifying the regions mapped to one selected sentence, enabling the network to focus on the specific aspect of the image mentioned in that sentence during modelling. KWM and SDM identify different sources of discriminative cues and are therefore complementary. MedIM leverages the strengths of both strategies by simultaneously restoring images masked by KWM and by SDM in each iteration. This integration creates a more challenging and comprehensive modelling task, which encourages the network to learn more robust and representative medical visual representations. Our MedIM approach is pre-trained on a large chest X-ray dataset of image-report pairs. The learned image representations are transferred to several downstream medical image analysis tasks: multi-label and multi-class image classification, and pneumothorax segmentation. In addition, the MedIM pre-trained model can be applied directly to image-text downstream tasks such as image-to-text and text-to-image retrieval.
Our contributions are threefold: (1) we present MedIM, a novel masking approach and the first work to explore the potential of radiology reports for mask generation in medical images, offering a new perspective for improving the accuracy and interpretability of medical image representations; (2) we propose two mutually complementary masking strategies, KWM and SDM, which identify word-level and sentence-level discriminative cues, respectively, to guide mask generation; (3) we conduct extensive experiments on medical image and image-text downstream tasks, in which MedIM outperforms strong competitors such as ImageNet pre-training, MIM-based pre-training, and advanced medical image-report pre-training counterparts.
2
Approach
As shown in Fig. 1, our MedIM framework has dual encoders that map images and reports to latent representations, a report-guided mask generation module, and a decoder that reconstructs the images from the masked representations.
2.1
Image and Text Encoders
Image Encoder. We use the vision Transformer (ViT) [7] as the image encoder $\mathcal{F}(\cdot)$. An input medical image $x$ is first reshaped into a sequence of flattened patches, which are embedded and fed into stacked Transformer layers to obtain the encoded representations of the visual tokens $E_{img} = \mathcal{F}(x) \in \mathbb{R}^{N_{img} \times C}$, where $C$ is the encoding dimension and $N_{img}$ denotes the number of patches.
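As a hedged illustration of how an image becomes the token sequence $E_{img}$, the sketch below splits a 224 × 224 image (the input size used in the experiments) into 16 × 16 patches, as in the ViT-B/16 backbone, and applies a linear patch embedding. The helper names `patchify` and `embed`, the random weights, the 3-channel input, and the 768-dimensional embedding (the ViT-Base width) are assumptions for illustration; position embeddings and the Transformer layers themselves are omitted.

```python
import numpy as np

def patchify(image: np.ndarray, patch: int = 16) -> np.ndarray:
    """Split an (H, W, C) image into N_img = (H/patch)*(W/patch) flattened patches."""
    h, w, c = image.shape
    if h % patch or w % patch:
        raise ValueError("image height and width must be divisible by the patch size")
    gh, gw = h // patch, w // patch
    x = image.reshape(gh, patch, gw, patch, c)    # split rows and columns into blocks
    x = x.transpose(0, 2, 1, 3, 4)                # (gh, gw, patch, patch, c)
    return x.reshape(gh * gw, patch * patch * c)  # one flattened patch per row, row-major

def embed(patches: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Hypothetical linear patch embedding: (N_img, P*P*C) -> (N_img, D)."""
    return patches @ weight + bias

rng = np.random.default_rng(0)
img = rng.random((224, 224, 3)).astype(np.float32)    # assumed 3-channel input
patches = patchify(img)                               # shape (196, 768)
weight = rng.normal(scale=0.02, size=(patches.shape[1], 768)).astype(np.float32)
bias = np.zeros(768, dtype=np.float32)
tokens = embed(patches, weight, bias)                 # shape (196, 768)
side = int(np.sqrt(tokens.shape[0]))                  # H = W = sqrt(N_img) = 14
```

The 14 × 14 grid recovered here is the one onto which token-level attention is reshaped when the report-guided masks are generated.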
Fig. 1. Illustration of our MedIM framework. It includes dual encoders to obtain latent representations. Two report-guided masking strategies, KWM and SDM, are then introduced to generate the masked representations. The decoder is built to reconstruct the original images from the masked representations. Note that the black regions in the generated mask are masked.
Text Encoder. We use the BioClinicalBERT [2] model, pre-trained on the MIMIC-III dataset [13], as our text encoder $\mathcal{T}(\cdot)$. We employ WordPiece [19] to tokenize free-text medical reports. This technique is particularly useful for handling the large and diverse vocabularies common in medical language. For an input medical report $r$ with $N_{text}$ words, the tokenizer segments each word into sub-words and generates word-piece embeddings as the input to the text encoder. The text encoder extracts features for the word pieces, which are aggregated to generate the word representations $E_{text} = \mathcal{T}(r) \in \mathbb{R}^{N_{text} \times C}$.
2.2
Report-Guided Mask Generation
Fig. 2. An image-report pair with the corresponding attention maps and masks generated by KWM and SDM. The black regions in the generated masks are masked.
We introduce two radiology report-guided masking strategies, i.e., KWM and SDM, which identify different cues to guide mask generation.
Knowledge Word-Driven Masking (KWM). MeSH words, shown in Fig. 1, are important for accurately describing medical images, as they provide a standardized vocabulary for the anatomical structures and pathologies observed in the images. Hence, KWM focuses on the MeSH word tokens during mask generation. Given a report $r$ and its text representations $E_{text}$, we first match MeSH words in the report against the MeSH table [14] and extract the representations of the MeSH word tokens:
$$E_{MeSH} = \{E_{text}^{j} \mid r^{j} \in \mathrm{MeSH},\ j \in \{1, \dots, N_{text}\}\} \in \mathbb{R}^{N_{MeSH} \times C}, \tag{1}$$
where $N_{MeSH}$ is the number of MeSH words in the report $r$. Then, we compute an attention map $C_{MeSH}$ to identify the image regions mapped to MeSH words as follows:
$$C_{MeSH} = \mathcal{R}\Big(\sum_{k=1}^{N_{MeSH}} \mathrm{softmax}\big(E_{img} \cdot E_{MeSH}^{T}\big)_{:,k}\Big) \in \mathbb{R}^{H \times W}, \tag{2}$$
where $H = W = \sqrt{N_{img}}$, $T$ and $\mathcal{R}$ denote the transpose and reshape operations, and the softmax normalizes the elements along the image dimension to find the region matched to each MeSH word. The summation is performed over the text dimension to aggregate the attention related to all MeSH words. Subsequently, high-activated masking is applied to remove the discovered attention regions. We define the corresponding binary mask $m \in \{0,1\}^{H \times W}$ as $m^{(i,j)} = \mathbb{I}\big(C_{MeSH}^{(i,j)} < C_{MeSH}^{[\gamma N_{img}]}\big)$, where $C_{MeSH}^{[\gamma N_{img}]}$ is the $(\gamma N_{img})$-th largest activation in $C_{MeSH}$ and $\gamma$ is the masking ratio that determines how many activations are suppressed. With this binary mask, we compute the masked representations produced by KWM as
$$\mathcal{M}(C_{MeSH};\gamma)_{kwm} = \big\{ z^{(i,j)} \,\big|\, z^{(i,j)} = m^{(i,j)} \cdot \mathcal{R}(E_{img})^{(i,j)} + (1 - m^{(i,j)}) \cdot [\mathrm{MASK}] \big\}_{i=1,\,j=1}^{H,\,W}, \tag{3}$$
where [MASK] is a mask placeholder token.
Sentence-Driven Masking (SDM). Medical reports often contain multiple sentences describing different findings related to the image, which inspires SDM to introduce sentence-level information into mask generation. For the report $r$, we randomly select a sentence $s$ and extract its representations as
$$E_{s} = \{E_{text}^{j} \mid r^{j} \in s,\ j \in \{1, \dots, N_{text}\}\} \in \mathbb{R}^{N_s \times C}, \tag{4}$$
where $N_s$ is the length of $s$. Then, an attention map $C_s$ is computed to identify the regions mapped to this sentence:
$$C_s = \mathcal{R}\Big(\sum_{k=1}^{N_s} \mathrm{softmax}\big(E_{img} \cdot E_{s}^{T}\big)_{:,k}\Big) \in \mathbb{R}^{H \times W}. \tag{5}$$
After that, high-activated masking is performed on $C_s$ to compute the masked representations $\mathcal{M}(C_s;\gamma)_{sdm}$. Fig. 2 visualizes, for one image-report pair, the attention maps and masks produced by KWM and SDM to show the advantages of our masking strategies.
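The following NumPy sketch illustrates Eqs. (1)–(5) and the high-activated masking under stated assumptions; it is not the authors' implementation. Random embeddings stand in for the encoder outputs, the MeSH-word positions and sentence boundaries are hypothetical, the [MASK] placeholder is a zero vector, and the top $\gamma N_{img}$ tokens are selected with a sort, which matches the indicator in the mask definition when there are no ties. The function names are illustrative.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)        # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_map(e_img: np.ndarray, e_words: np.ndarray) -> np.ndarray:
    """Eq. (2)/(5): softmax over image tokens, sum over words, reshape to H x W."""
    side = int(round(np.sqrt(e_img.shape[0])))     # H = W = sqrt(N_img)
    attn = softmax(e_img @ e_words.T, axis=0)      # (N_img, N_words), columns sum to 1
    return attn.sum(axis=1).reshape(side, side)

def high_activated_mask(c: np.ndarray, gamma: float) -> np.ndarray:
    """m = 1 keeps a token; m = 0 masks one of the gamma*N_img highest activations."""
    flat = c.ravel()
    n_mask = int(gamma * flat.size)
    m = np.ones(flat.size, dtype=np.int64)
    m[np.argsort(-flat)[:n_mask]] = 0
    return m.reshape(c.shape)

def apply_mask(e_img: np.ndarray, m: np.ndarray, mask_token: np.ndarray) -> np.ndarray:
    """Eq. (3): keep R(E_img) where m = 1, substitute the [MASK] token where m = 0."""
    grid = e_img.reshape(m.shape[0], m.shape[1], -1)
    keep = m[..., None]
    return keep * grid + (1 - keep) * mask_token

rng = np.random.default_rng(0)
dim = 8
e_img = rng.normal(size=(196, dim))               # stand-in for E_img (14 x 14 tokens)
e_text = rng.normal(size=(12, dim))               # stand-in for E_text (12 words)
mesh_idx = [2, 5, 9]                              # hypothetical MeSH-word positions, Eq. (1)
sentences = [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]  # hypothetical sentence spans

c_kwm = attention_map(e_img, e_text[mesh_idx])
s = sentences[rng.integers(len(sentences))]       # SDM: one randomly chosen sentence, Eq. (4)
c_sdm = attention_map(e_img, e_text[s])

gamma = 0.5                                       # 50% mask ratio, as reported in Sect. 3.1
m_kwm = high_activated_mask(c_kwm, gamma)
m_sdm = high_activated_mask(c_sdm, gamma)
z_kwm = apply_mask(e_img, m_kwm, np.zeros(dim))   # M(C_MeSH; gamma)_kwm
z_sdm = apply_mask(e_img, m_sdm, np.zeros(dim))   # M(C_s; gamma)_sdm
```

Because each softmax column sums to one over the image tokens, the entries of $C_{MeSH}$ sum to $N_{MeSH}$ and those of $C_s$ to $N_s$, which is a quick sanity check on the normalization direction.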
2.3
Decoder for Reconstruction
Both masked representations $\mathcal{M}(C_{MeSH};\gamma)_{kwm}$ and $\mathcal{M}(C_s;\gamma)_{sdm}$ are fed to the decoder $\mathcal{D}(\cdot)$, which consists of four conv-bn-relu-upsample blocks. We design two independent reconstruction heads that take the decoded features $\mathcal{D}(\mathcal{M}(C_{MeSH};\gamma)_{kwm})$ and $\mathcal{D}(\mathcal{M}(C_s;\gamma)_{sdm})$, respectively, and generate the final reconstructions $y^{kwm}$ and $y^{sdm}$.
2.4
Objective Function
MedIM creates a more challenging reconstruction objective by removing and then restoring the most discriminative regions, identified under the guidance of radiological reports. We optimize this reconstruction with the mean square error (MSE) loss:
$$\mathcal{L}_{restore} = \lVert y^{kwm} - x \rVert_2^2 + \lVert y^{sdm} - x \rVert_2^2. \tag{6}$$
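A minimal sketch of Eq. (6), assuming the squared error is averaged over pixels (as "mean square error" suggests) and that the reconstructions and the image are arrays of the same shape; `restore_loss` and `medim_objective` are hypothetical names. The second function shows how this term enters the weighted total $\alpha\mathcal{L}_{restore} + \mathcal{L}_{align}$ defined in the next paragraph, treating $\mathcal{L}_{align}$ as a precomputed scalar because its computation follows [17] and is not reproduced here.

```python
import numpy as np

def mse(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean squared error averaged over all pixels (and batch entries)."""
    return float(np.mean((pred - target) ** 2))

def restore_loss(y_kwm: np.ndarray, y_sdm: np.ndarray, x: np.ndarray) -> float:
    """Eq. (6): both reconstructions are compared with the same unmasked image x."""
    return mse(y_kwm, x) + mse(y_sdm, x)

def medim_objective(l_restore: float, l_align: float, alpha: float = 10.0) -> float:
    """Weighted total; alpha = 10 is the value reported in the experimental setup."""
    return alpha * l_restore + l_align

x = np.zeros((2, 1, 4, 4))            # toy batch of two 4x4 single-channel images
y_kwm = np.ones_like(x)               # every pixel off by 1.0 -> MSE 1.0
y_sdm = np.full_like(x, 0.5)          # every pixel off by 0.5 -> MSE 0.25
l_restore = restore_loss(y_kwm, y_sdm, x)            # 1.25
total = medim_objective(l_restore, l_align=0.3)      # 10 * 1.25 + 0.3 = 12.8
```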
MedIM also incorporates a cross-modal alignment constraint, which aligns the visual and semantic aspects of medical images with their corresponding radiological reports and thereby helps identify the report-guided discriminative regions during mask generation. Following [17], we compute the alignment objective $\mathcal{L}_{align}$ by exploiting the fine-grained correspondences between images and reports. The final objective of MedIM combines the reconstruction and alignment objectives as $\mathcal{L}_{MedIM} = \alpha\mathcal{L}_{restore} + \mathcal{L}_{align}$, where $\alpha$ is a weight factor that balances the two.
2.5
Downstream Transfer Learning
After pre-training, the MedIM weights can be transferred to various downstream tasks. For classification, we use the common linear probing protocol, i.e., we freeze the pre-trained image encoder and train only a randomly initialized linear classification head. For segmentation, the encoder and decoder are initialized with the MedIM pre-trained weights, a downstream-specific head is added, and the whole network is fine-tuned end-to-end. For retrieval, we take an image or report as the query and retrieve the target reports or images by computing the similarity between the query and all candidates with the learned image and text encoders.
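The retrieval protocol can be sketched as follows, assuming paired embeddings (query $i$ matches candidate $i$), cosine similarity as the similarity measure, and the R@k definition given in the experimental setup; the function names and toy data are hypothetical. For image-to-text retrieval the queries are image embeddings and the candidates are report embeddings; text-to-image retrieval swaps the roles.

```python
import numpy as np

def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)

def recall_at_k(queries: np.ndarray, candidates: np.ndarray, ks=(1, 5, 10)) -> dict:
    """R@k: fraction of queries whose paired candidate is among the top-k results."""
    sim = l2_normalize(queries) @ l2_normalize(candidates).T   # cosine similarity
    order = np.argsort(-sim, axis=1)                           # best candidate first
    truth = np.arange(len(queries))[:, None]
    rank = np.argmax(order == truth, axis=1)                   # 0-based rank of the match
    return {k: float(np.mean(rank < k)) for k in ks}

rng = np.random.default_rng(0)
img_emb = rng.normal(size=(50, 16))          # stand-in image embeddings
txt_emb = 2.0 * img_emb                      # same directions -> cosine 1 with its pair
i2t = recall_at_k(img_emb, txt_emb)          # image-to-text
t2i = recall_at_k(txt_emb, img_emb)          # text-to-image
noisy = recall_at_k(img_emb, img_emb + rng.normal(scale=2.0, size=img_emb.shape))
```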
3
Experiments and Results
3.1
Experimental Details
Pre-training Setup. We use the MIMIC-CXR-JPG dataset [12] to pre-train our MedIM framework. Following [17], we include only frontal-view chest images and extract the impression and findings sections from the radiological reports, which yields over 210,000 radiograph-report pairs. We manually split the pairs into 80% for pre-training and 20% for downstream validation of in-domain transfer learning. We set the input size to 224 × 224 and adopt the AdamW optimizer [16] with a cosine-decay learning rate schedule [15], a momentum of 0.9, and a weight decay of 0.05. We set the initial learning rate to 0.00002, the batch size to 144, and the maximum number of epochs to 50. Based on the ablation study, we empirically set the mask ratio to 50% and the loss weight α to 10.
Downstream Setup. We validate the transferability of the learned MedIM representations on four X-ray-based downstream tasks: (1) multi-label classification on the CheXpert [10] dataset using its official split, with five binary labels: atelectasis, cardiomegaly, consolidation, edema, and pleural effusion; (2) multi-class classification on the COVIDx [18] dataset of over 30k chest X-ray images, where each radiograph is classified as COVID-19, non-COVID pneumonia, or normal, with an 80%/10%/10% training/validation/test split; (3) pneumothorax segmentation on the SIIM-ACR Pneumothorax Segmentation dataset [1] of over 12k chest radiographs, with a 70%/15%/15% training/validation/test split; and (4) image-to-text and text-to-image retrieval on the MIMIC-CXR validation dataset. We measure segmentation performance with the Dice coefficient (Dice), multi-label classification with the mean area under the receiver operating characteristic curve (mAUC), and multi-class classification with accuracy. Retrieval is measured by the recall of the corresponding image/report among the top-k ranked images/reports (denoted R@k) [9]. Each downstream experiment is run three times and the average performance is reported. More details are in the Appendix.
3.2
Comparisons with Different Pre-training Methods
We compare the downstream performance of MedIM pre-training with five other pre-training methods in Table 1 and Table 2. MedIM achieves state-of-the-art results on all downstream datasets, outperforming ImageNet pre-training [7], the MIM-based pre-training method MAE [8], and three medical image-report pre-training approaches, GLoRIA [9], MRM [22], and MGCA [17], under different labelling ratios. This performance supports the effectiveness of our report-guided masking strategy over other pre-training strategies for learning discriminative information. Moreover, MedIM achieves 88.91% mAUC on CheXpert using only 1% of the downstream labelled data, which is better than every competitor trained with 100% of the labelled data. These results demonstrate the strong potential of MedIM for annotation-limited medical image tasks.
3.3
Discussions
Ablation Study. Ablation studies are performed on each component of MedIM, namely knowledge word-driven masking (KWM) and sentence-driven masking (SDM), as listed in Table 3. Starting from the vanilla baseline, which uses $\mathcal{L}_{align}$ only, we add the components one at a time, and the downstream performance improves at each step. First, reconstructing the masked representations produced by KWM increases the total performance over the three tasks by 3.28 points. This indicates that using MeSH words as knowledge to guide mask generation improves the learned representations and their generalization. Equipped with both KWM and SDM, MedIM surpasses the baseline by a total of 5.12 points on the three tasks, showing the benefit of adding SDM and integrating the two masking strategies.
Masking Strategies. To demonstrate the effectiveness of the high-activated masking strategy, we compare it with three counterparts: no masking, random masking, and low-activated masking. No masking means that reconstruction is performed on the complete image encoder representations instead of the masked ones. Low-activated masking masks the tokens with a low response in both the KWM and SDM strategies. The comparison on the left side of Fig. 3 shows that all masking strategies improve accuracy more than no masking. Because it mines more discriminative information, our high-activated masking performs better than random and low-activated masking. We also compare masking ratios from 25% to 75% on the right side of Fig. 3.
Table 1. Classification and segmentation results of different pre-training methods on three downstream test sets under different ratios of available labelled data (CheXpert: mAUC; COVIDx: accuracy; SIIM: Dice; all in %). All methods were evaluated with the ViT-B/16 backbone. * denotes our implementation on the same pre-training dataset and backbone, due to the lack of available pre-trained weights.
Methods | CheXpert 1% | CheXpert 10% | CheXpert 100% | COVIDx 1% | COVIDx 10% | COVIDx 100% | SIIM 10% | SIIM 100%
Random Init | 68.11 | 71.17 | 71.91 | 67.01 | 79.68 | 82.71 | 19.13 | 60.97
ImageNet [7] | 73.52 | 80.38 | 81.84 | 71.56 | 84.28 | 89.74 | 55.06 | 76.02
MAE* [8] | 82.36 | 85.22 | 86.69 | 73.31 | 87.67 | 91.79 | 57.68 | 77.16
GLoRIA* [9] | 86.50 | 87.53 | 88.24 | 75.79 | 88.68 | 92.11 | 57.67 | 77.23
MRM [22] | 88.50 | 88.50 | 88.70 | 76.11 | 88.92 | 92.21 | 61.21 | 79.45
MGCA* [17] | 88.11 | 88.29 | 88.88 | 76.29 | 89.04 | 92.47 | 60.64 | 79.31
MedIM | 88.91 | 89.25 | 89.65 | 77.22 | 90.34 | 93.57 | 63.50 | 81.32
Table 2. Image-to-text (I2T) and text-to-image (T2I) retrieval results on the MIMIC-CXR test set.
Methods | T2I R@1 | T2I R@5 | T2I R@10 | I2T R@1 | I2T R@5 | I2T R@10
MGCA [17] | 5.74 | 22.91 | 31.90 | 6.22 | 23.61 | 32.51
MedIM | 7.67 | 23.96 | 33.55 | 8.70 | 24.63 | 34.27
Table 3. Ablation study of different components in MedIM.
$\mathcal{L}_{align}$ | KWM | SDM | COVIDx | CheXpert | SIIM
✓ | × | × | 89.04 | 88.29 | 60.64
✓ | ✓ | × | 89.85 | 88.86 | 62.54
✓ | ✓ | ✓ | 90.34 | 89.25 | 63.50
Fig. 3. Left: Results when using different masking strategies. Right: Results when using different masking ratios.
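The point totals quoted in the ablation study are sums of the three task scores in each row of Table 3:

```latex
\begin{aligned}
\mathcal{L}_{align}\ \text{only}: &\quad 89.04 + 88.29 + 60.64 = 237.97\\
+\,\text{KWM}: &\quad 89.85 + 88.86 + 62.54 = 241.25, \qquad 241.25 - 237.97 = 3.28\\
+\,\text{KWM} + \text{SDM}: &\quad 90.34 + 89.25 + 63.50 = 243.09, \qquad 243.09 - 237.97 = 5.12
\end{aligned}
```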
4
Conclusion
We propose a new masking approach, MedIM, that uses radiological reports to guide the mask generation of medical images during pre-training. We introduce two masking strategies, KWM and SDM, which identify different sources of discriminative cues to generate masked inputs. MedIM is pre-trained on a large dataset of image-report pairs to restore the masked regions, and the learned image representations are transferred to three medical image analysis tasks and to image-text retrieval tasks. The results demonstrate that MedIM outperforms strong pre-training competitors and the random masking method. In the future, we will extend MedIM to other modalities, e.g., 3D medical images.
Acknowledgments. Dr. Lin Gu was supported by JST Moonshot R&D Grant Number JPMJMS2011, Japan. Prof. Yong Xia was supported in part by the Key Research and Development Program of Shaanxi Province, China, under Grant 2022GY084, in part by the National Natural Science Foundation of China under Grant 62171377, and in part by the National Key R&D Program of China under Grant 2022YFC2009903/2022YFC2009900.
References
1. SIIM-ACR pneumothorax segmentation. Society for Imaging Informatics in Medicine (2019)
2. Alsentzer, E., et al.: Publicly available clinical BERT embeddings. arXiv preprint arXiv:1904.03323 (2019)
3. Bao, H., Dong, L., Piao, S., Wei, F.: BEiT: BERT pre-training of image transformers. In: International Conference on Learning Representations (ICLR) (2022)
4. Cai, Z., Lin, L., He, H., Tang, X.: Uni4Eye: unified 2D and 3D self-supervised pre-training via masked image modeling transformer for ophthalmic image classification. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13438, pp. 88–98. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16452-1_9
5. Chen, Z., Agarwal, D., Aggarwal, K., Safta, W., Balan, M.M., Brown, K.: Masked image modeling advances 3D medical image analysis. In: Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 1970–1980 (2023)
6. Chen, Z., et al.: Multi-modal masked autoencoders for medical vision-and-language pre-training. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13435, pp. 679–689. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16443-9_65
7. Dosovitskiy, A., et al.: An image is worth 16x16 words: transformers for image recognition at scale. In: International Conference on Learning Representations (ICLR) (2021)
8. He, K., Chen, X., Xie, S., Li, Y., Dollár, P., Girshick, R.: Masked autoencoders are scalable vision learners. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 16000–16009 (2022)
9. Huang, S.C., Shen, L., Lungren, M.P., Yeung, S.: GLoRIA: a multimodal global-local representation learning framework for label-efficient medical image recognition. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 3942–3951 (2021)
10. Irvin, J., et al.: CheXpert: a large chest radiograph dataset with uncertainty labels and expert comparison. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, pp. 590–597 (2019)
11. Jiang, J., Tyagi, N., Tringale, K., Crane, C., Veeraraghavan, H.: Self-supervised 3D anatomy segmentation using self-distilled masked image transformer (SMIT). In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13434, pp. 556–566. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16440-8_53
12. Johnson, A.E., et al.: MIMIC-CXR, a de-identified publicly available database of chest radiographs with free-text reports. Sci. Data 6(1), 1–8 (2019)
13. Johnson, A.E., et al.: MIMIC-III, a freely accessible critical care database. Sci. Data 3(1), 1–9 (2016)
14. Lipscomb, C.E.: Medical subject headings (MeSH). Bull. Med. Libr. Assoc. 88(3), 265 (2000)
15. Loshchilov, I., Hutter, F.: SGDR: stochastic gradient descent with warm restarts. In: ICLR (2017)
16. Loshchilov, I., Hutter, F.: Fixing weight decay regularization in Adam (2018)
17. Wang, F., Zhou, Y., Wang, S., Vardhanabhuti, V., Yu, L.: Multi-granularity cross-modal alignment for generalized medical visual representation learning. In: Advances in Neural Information Processing Systems (2022)
18. Wang, L., Lin, Z.Q., Wong, A.: COVID-Net: a tailored deep convolutional neural network design for detection of COVID-19 cases from chest X-ray images. Sci. Rep. 10(1), 1–12 (2020)
19. Wu, Y., et al.: Google's neural machine translation system: bridging the gap between human and machine translation. arXiv preprint arXiv:1609.08144 (2016)
20. Xiao, J., Bai, Y., Yuille, A., Zhou, Z.: Delving into masked autoencoders for multi-label thorax disease classification. In: Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 3588–3600 (2023)
21. Xie, Y., Zhang, J., Xia, Y., Wu, Q.: UniMISS: universal medical self-supervised learning via breaking dimensionality barrier. In: Avidan, S., Brostow, G., Cissé, M., Farinella, G.M., Hassner, T. (eds.) ECCV 2022. LNCS, vol. 13681, pp. 558–575. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-19803-8_33
22. Zhou, H.Y., Lian, C., Wang, L., Yu, Y.: Advancing radiograph representation learning with masked record modeling. In: International Conference on Learning Representations (ICLR) (2023)
23. Zhou, J., et al.: Image BERT pre-training with online tokenizer. In: International Conference on Learning Representations (ICLR) (2022)
24. Zhou, L., Liu, H., Bae, J., He, J., Samaras, D., Prasanna, P.: Self pre-training with masked autoencoders for medical image analysis. arXiv preprint arXiv:2203.05573 (2022)
UOD: Universal One-Shot Detection of Anatomical Landmarks
Heqin Zhu1,2,3, Quan Quan3, Qingsong Yao3, Zaiyi Liu4,5, and S. Kevin Zhou1,2(B)
1 School of Biomedical Engineering, Division of Life Sciences and Medicine, University of Science and Technology of China, Hefei 230026, Anhui, People's Republic of China
2 Suzhou Institute for Advanced Research, University of Science and Technology of China, Suzhou 215123, Jiangsu, People's Republic of China
3 Key Lab of Intelligent Information Processing of Chinese Academy of Sciences (CAS), Institute of Computing Technology, CAS, Beijing 100190, China
4 Department of Radiology, Guangdong Provincial People's Hospital, Guangdong Academy of Medical Sciences, Guangzhou, China
5 Guangdong Provincial Key Laboratory of Artificial Intelligence in Medical Image Analysis and Application, Guangdong Provincial People's Hospital, Guangdong Academy of Medical Sciences, Guangzhou, China
Abstract. One-shot medical landmark detection has gained much attention and achieved great success thanks to its label-efficient training process. However, existing one-shot learning methods are highly specialized to a single domain and suffer heavily from domain preference when faced with multi-domain unlabeled data. Moreover, one-shot learning is not robust: its performance drops when a sub-optimal image is chosen for annotation. To tackle these issues, we develop a domain-adaptive one-shot landmark detection framework for multi-domain medical images, named Universal One-shot Detection (UOD). UOD consists of two stages with two corresponding universal models, each designed as a combination of domain-specific and domain-shared modules. In the first stage, a domain-adaptive convolution model is learned in a self-supervised manner to generate pseudo landmark labels. In the second stage, we design a domain-adaptive transformer to eliminate domain preference and build the global context for multi-domain data.
Even though only one annotated sample from each domain is available for training, the domain-shared modules help UOD aggregate all one-shot samples to detect more robust and accurate landmarks. We investigated the proposed UOD both qualitatively and quantitatively on three widely used public X-ray datasets from different anatomical domains (i.e., head, hand, and chest) and obtained state-of-the-art performance in each domain. The code is at https://github.com/heqin-zhu/UOD_universal_oneshot_detection.
Keywords: One-shot learning · Domain-adaptive model · Anatomical landmark detection · Transformer network
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 24–34, 2023. https://doi.org/10.1007/978-3-031-43907-0_3
1
Introduction
Robust and accurate detection of anatomical landmarks is an essential task in medical image applications [24,25] and plays a vital role in a variety of clinical treatments, for instance, vertebrae localization [20], orthognathic and orthodontic surgeries [9], and craniofacial anomaly assessment [4]. Moreover, anatomical landmarks support other medical image tasks such as segmentation [3], registration [5], and biometry estimation [1]. In past years, many fully supervised methods [4,8,11,20,21,26,27] have been proposed to detect landmarks accurately and automatically. To relieve the burden on experts and reduce the amount of annotation required, various one-shot and few-shot methods have been proposed. Zhao et al. [23] present a model that learns transformations from images and uses a labeled example to synthesize additional labeled examples, where each transformation is composed of a spatial deformation field and an intensity change. Yao et al. [22] develop a cascaded self-supervised learning framework for one-shot medical landmark detection: they first train a matching network to compute the cosine similarity between features from an image and a template patch, and then refine the pseudo landmark labels from coarse to fine. Browatzki et al. [2] propose a semi-supervised method consisting of two stages: they first employ an adversarial auto-encoder to learn implicit face knowledge from unlabeled images and then fine-tune the decoder to detect landmarks with few-shot labels. However, one-shot methods are not robust enough, because they depend on the choice of the labeled template, and the accuracy of the detected landmarks may drop considerably when a sub-optimal image is chosen for annotation. To address this issue, Quan et al. [12] propose a novel Sample Choosing Policy (SCP) to select the most worthwhile image to annotate. Despite the improved performance, SCP brings an extra computational burden.
Another challenge is the scalability of model building when facing multiple domains (such as different anatomical regions). While conventional wisdom is to train an independent model for each domain, Zhu et al. [26] propose a universal model, YOLO, for detecting landmarks across different anatomies, which achieves better performance than a collection of single models. However, YOLO is fully supervised with a CNN backbone, and it is unknown whether it works in a one-shot scenario or with a modern transformer architecture. Motivated by the above challenges, and aiming to detect landmarks robustly across multiple domains with few labels, we design domain-adaptive models and propose a universal one-shot landmark detection framework called Universal One-shot Detection (UOD), illustrated in Fig. 1. A universal model comprises domain-specific modules and domain-shared modules, which learn the specific features of each domain and the common features of all domains, respectively, to eliminate domain preference and extract representative features from multi-domain data. Moreover, while one-shot learning is not robust enough because it depends on sample selection, multi-domain one-shot learning benefits from the different one-shot samples of the various domains, whose cross-domain features are mined by the domain-shared modules. Our proposed UOD framework consists of two stages: 1) contrastive learning for training a universal model with multi-domain data to generate pseudo landmark labels; and 2) supervised learning for training a domain-adaptive transformer (DATR) to avoid domain preference and detect robust and accurate landmarks.
In summary, our contributions fall into three parts: 1) we design the first universal framework for multi-domain one-shot landmark detection, which improves detection accuracy and relieves domain preference on multi-domain data from various anatomical regions; 2) we design a domain-adaptive transformer block (DATB), which is effective for multi-domain learning and can be used in any other transformer network; 3) we carry out comprehensive experiments demonstrating that UOD obtains state-of-the-art performance on three public X-ray datasets of the head, hand, and chest.
[Figure 1 diagram. Stage I (Train / Infer): multi-domain unlabeled data, augmentation, one-shot sample, universal convnet, multi-scale embedding, cosine similarity (COS), softmax, probability map, dot multiply (MUL), pseudo labels. Stage II: extraction, multi-domain unlabeled data, universal transformer, heatmap, predicted landmarks, supervision.]
Fig. 1. Overview of the UOD framework. In stage I, two universal models are learned via contrastive learning to match similar patches from the original image and the augmented one-shot sample image and to generate pseudo labels. In stage II, DATR is designed to better capture global context information among all domains for detecting more accurate landmarks.
2
Method
As Fig. 1 shows, UOD consists of two stages: 1) contrastive learning and 2) supervised learning. In stage I, to learn the local appearance of each domain, a universal model is trained via self-supervised learning; it contains a domain-specific VGG [15] and a UNet [13] decoder, with each standard convolution replaced by a domain adaptor [7]. In stage II, to capture global constraints and eliminate domain preference, we design a domain-adaptive transformer (DATR).
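The matching step at the core of Stage I (Sect. 2.1) can be sketched as follows. This is a single-scale simplification under stated assumptions, not the authors' code: random arrays stand in for the siamese network's feature maps, the query is the embedding taken at the target point in one view, and the softmax `temperature` is a hypothetical parameter that the text does not mention. All function names are illustrative.

```python
import numpy as np

def cosine_similarity_map(feat: np.ndarray, query: np.ndarray) -> np.ndarray:
    """Cosine similarity between a query vector (D,) and every location of feat (H, W, D)."""
    f = feat / np.maximum(np.linalg.norm(feat, axis=-1, keepdims=True), 1e-12)
    q = query / max(float(np.linalg.norm(query)), 1e-12)
    return f @ q

def probability_map(sim: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Softmax over all spatial positions of the similarity map."""
    z = sim.ravel() / temperature
    z = z - z.max()                          # numerical stability
    p = np.exp(z)
    return (p / p.sum()).reshape(sim.shape)

def matching_loss(prob: np.ndarray, target: tuple) -> float:
    """Cross-entropy against a one-hot ground-truth map at `target` = (row, col)."""
    return float(-np.log(prob[target] + 1e-12))

def pseudo_landmark(prob: np.ndarray) -> tuple:
    """Inference: the arg max of the probability map is taken as the pseudo landmark."""
    i, j = np.unravel_index(np.argmax(prob), prob.shape)
    return int(i), int(j)

rng = np.random.default_rng(0)
feat = rng.normal(size=(32, 32, 64))         # stand-in feature map of one view
query = 3.0 * feat[10, 20]                   # stand-in embedding at the target point
prob = probability_map(cosine_similarity_map(feat, query), temperature=0.1)
loss = matching_loss(prob, (10, 20))
landmark = pseudo_landmark(prob)
```

In the multi-scale setting described in the text, the same computation would be repeated per scale; how the scales are combined is not specified here.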
2.1
Stage I: Contrastive Learning
As Fig. 1 shows, following Yao et al. [22], we employ contrastive learning to train a siamese network that matches similar patches of an original image and an augmented image. Given a multi-domain input image $X^d \in \mathbb{R}^{H^d \times W^d \times C^d}$ belonging to domain $d$, we randomly select a target point $P$ and crop a half-size patch $X_p^d$ that contains $P$. After data augmentation is applied to $X_p^d$, the target point is mapped to $P_p^d$. We then feed $X^d$ and $X_p^d$ into the siamese network and obtain their multi-scale feature embeddings. We compute the cosine similarity of the two feature embeddings at each scale and apply softmax to the cosine similarity map to generate a probability matrix. Finally, we compute the cross-entropy loss between the probability matrix and the ground-truth map, produced by one-hot encoding $P_p^d$, to optimize the siamese network for learning the latent similarities of patches. At inference, we replace the augmented patch $X_p^d$ with the augmented one-shot sample patch $X_s^d$ and use the annotated one-shot landmarks as target points to formulate the ground-truth maps. After obtaining the probability matrices, we apply arg max to extract the strongest response points as pseudo landmarks, which are used in UOD Stage II.
2.2
Stage II: Supervised Learning
In stage II, we design a universal transformer to capture the global relationships in multi-domain data and train it with the pseudo landmarks generated in stage I. The universal transformer has a domain-adaptive transformer encoder and a domain-adaptive convolution decoder. The decoder is based on a U-Net [13] decoder with each standard convolution replaced by a domain adaptor [7]. The encoder is based on Swin Transformer [10], which uses shifted windows and limits self-attention to non-overlapping local windows for computational efficiency. Different from Swin Transformer [10], we design a domain-adaptive transformer block (DATB) and use it to replace the original transformer block.
Domain-Adaptive Transformer Encoder. As Fig. 2(a) shows, the transformer encoder is built from DATBs, making full use of the transformer's capability for modeling global relationships and extracting representative multi-domain features. As in Fig. 2(b), a basic transformer block [17] consists of a multi-head self-attention module (MSA) followed by a two-layer MLP with GELU activation. Layer normalization (LN) is applied before each MSA and MLP, and a residual connection is added after each MSA and MLP. Given a feature map $x^d \in \mathbb{R}^{h \times w \times c}$ from domain $d$ with height $h$, width $w$, and $c$ channels, the outputs of the MSA and MLP, denoted by $\hat{y}^d$ and $y^d$, respectively, are
$$\hat{y}^d = \mathrm{MSA}(\mathrm{LN}(x^d)) + x^d, \qquad y^d = \mathrm{MLP}(\mathrm{LN}(\hat{y}^d)) + \hat{y}^d, \tag{1}$$
H. Zhu et al.
Fig. 2. (a) The architecture of DATR in Stage II, which is composed of the domain-adaptive transformer encoder and convolution adaptors [7]. (b) Basic transformer block. (c) Domain-adaptive transformer block. Each domain-adaptive transformer block is a basic transformer block with the query matrix duplicated and a domain-adaptive diagonal added for each domain. Batch normalization, activation, and patch merging are omitted.
where $\mathrm{MSA} = \mathrm{softmax}(QK^T)V$. As illustrated in Fig. 2(b)(c), DATB is based on Eq. (1). Similar to U2Net [7] and GU2Net [26], we adopt both domain-specific and domain-shared parameters in DATB. Since the attention probabilities depend on the query and key matrices, which play symmetric roles, we duplicate the query matrix for each domain to learn domain-specific query features, and keep the key and value matrices domain-shared to learn common knowledge and reduce parameters. Inspired by LayerScale [16], we further adopt a learnable diagonal matrix [16] after each MSA and MLP module to facilitate the learning of domain-specific features, which costs few parameters ($O(N)$ for an $N \times N$ diagonal). Unlike LayerScale [16], the proposed domain-adaptive diagonals $D_1^d$ and $D_2^d$ are learned for each domain, and $D_2^d$ is applied after the residual connection to generate more representative and direct domain-specific features. The above process can be formulated as:

$$\hat{y}^d = D_1^d \times \mathrm{MSA}_{Q^d}(\mathrm{LN}(x^d)) + x^d, \qquad y^d = D_2^d \times \big(\mathrm{MLP}(\mathrm{LN}(\hat{y}^d)) + \hat{y}^d\big) \tag{2}$$

where $\mathrm{MSA}_{Q^d} = \mathrm{softmax}(Q^d K^T)V$.
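The forward pass of Eq. (2) can be sketched in dependency-free Python as follows. This is an illustrative sketch under stated simplifications, not the authors' implementation: it uses a single attention head instead of multi-head MSA, omits window partitioning and shifting, biases, and learnable LN parameters, and uses random weights. All names (`DomainAdaptiveBlock`, `w_q`, `d1`, `d2`, and the helper functions) are hypothetical. With `num_domains=1` and the diagonals left at 1, the block reduces to the basic block of Eq. (1).

```python
import math
import random


def matmul(a, b):
    """Multiply an (n x k) nested list by a (k x m) nested list."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def layer_norm(x, eps=1e-5):
    """Normalize each token over its channels (no learnable scale or shift)."""
    out = []
    for row in x:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        out.append([(v - mean) / math.sqrt(var + eps) for v in row])
    return out


def softmax_rows(x):
    out = []
    for row in x:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        out.append([v / total for v in exps])
    return out


def gelu(v):
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


class DomainAdaptiveBlock:
    """Single-head sketch of Eq. (2) for tokens of shape (n, dim)."""

    def __init__(self, dim, hidden, num_domains, seed=0):
        rng = random.Random(seed)
        # Domain-specific query projections Q^d; shared key/value projections.
        self.w_q = [random_matrix(dim, dim, rng) for _ in range(num_domains)]
        self.w_k = random_matrix(dim, dim, rng)
        self.w_v = random_matrix(dim, dim, rng)
        # Shared two-layer MLP (biases omitted).
        self.w1 = random_matrix(dim, hidden, rng)
        self.w2 = random_matrix(hidden, dim, rng)
        # Diagonals D_1^d and D_2^d stored as vectors: O(dim) parameters each.
        self.d1 = [[1.0] * dim for _ in range(num_domains)]
        self.d2 = [[1.0] * dim for _ in range(num_domains)]

    def msa(self, x, d):
        """softmax(Q^d K^T) V, unscaled as written in the text."""
        q = matmul(x, self.w_q[d])
        k = matmul(x, self.w_k)
        v = matmul(x, self.w_v)
        k_t = [list(col) for col in zip(*k)]
        return matmul(softmax_rows(matmul(q, k_t)), v)

    def mlp(self, x):
        hidden = [[gelu(v) for v in row] for row in matmul(x, self.w1)]
        return matmul(hidden, self.w2)

    def forward(self, x, d):
        attn = self.msa(layer_norm(x), d)
        # y_hat = D_1^d * MSA_{Q^d}(LN(x)) + x
        y_hat = [[s * a + xv for s, a, xv in zip(self.d1[d], arow, xrow)]
                 for arow, xrow in zip(attn, x)]
        mlp_out = self.mlp(layer_norm(y_hat))
        # y = D_2^d * (MLP(LN(y_hat)) + y_hat): scaling applied after the residual
        return [[s * (m + yv) for s, m, yv in zip(self.d2[d], mrow, yrow)]
                for mrow, yrow in zip(mlp_out, y_hat)]
```

Only `w_q`, `d1`, and `d2` grow with the number of domains; the key, value, and MLP weights are shared, mirroring the domain-specific versus domain-shared split described above.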
UOD: Universal One-Shot Detection of Anatomical Landmarks

Overall Pipeline. Given a random input $X^d \in \mathbb{R}^{H^d \times W^d \times C^d}$ belonging to domain $d$ of the mixed datasets covering various anatomical regions, which contains $N^d$ landmarks with coordinates $\{(i_1^d, j_1^d), (i_2^d, j_2^d), \ldots, (i_{N^d}^d, j_{N^d}^d)\}$, we set the $n$-th ($n \in \{1, 2, \ldots, N^d\}$) initial heatmap $\tilde{Y}_n^d \in \mathbb{R}^{H^d \times W^d \times C^d}$ with a Gaussian function:

$$\tilde{Y}_n^d(i, j) = \frac{1}{\sqrt{2\pi}\,\sigma} \exp\left(-\frac{(i - i_n^d)^2 + (j - j_n^d)^2}{2\sigma^2}\right)$$

if $(i - i_n^d)^2 + (j - j_n^d)^2 \le \sigma$, and 0 otherwise. We further add an exponential weight to the Gaussian distribution to distinguish close heatmap pixels and obtain the ground truth heatmap $Y_n^d(i, j) = \alpha^{\tilde{Y}_n^d(i, j)}$. As illustrated in Fig. 2, the input image from a random batch is first partitioned into non-overlapping patches and linearly embedded. Next, these patches are fed into cascaded transformer blocks at each stage, with patch merging after every stage except the last. Finally, a domain-adaptive convolution decoder makes dense predictions to generate heatmaps, from which landmarks are extracted via thresholding and connected-component filtering.
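A literal reading of the heatmap definition above can be sketched as follows, using σ = 3 and α = 10 from the implementation details. It applies the truncation condition exactly as written (squared distance ≤ σ); the function names are hypothetical, and the sketch produces a single 2-D map rather than an $H^d \times W^d \times C^d$ tensor.

```python
import math


def initial_heatmap(height, width, center, sigma):
    """Truncated Gaussian heatmap for one landmark at center = (i_n, j_n)."""
    ci, cj = center
    norm = 1.0 / (math.sqrt(2.0 * math.pi) * sigma)
    heatmap = [[0.0] * width for _ in range(height)]
    for i in range(height):
        for j in range(width):
            d2 = (i - ci) ** 2 + (j - cj) ** 2
            # Keep the Gaussian only where the squared distance is <= sigma.
            if d2 <= sigma:
                heatmap[i][j] = norm * math.exp(-d2 / (2.0 * sigma ** 2))
    return heatmap


def weighted_heatmap(heatmap, alpha):
    """Exponential weighting Y(i, j) = alpha ** Y~(i, j)."""
    return [[alpha ** v for v in row] for row in heatmap]


gaussian = initial_heatmap(9, 9, (4, 4), sigma=3.0)
target = weighted_heatmap(gaussian, alpha=10.0)
```

Under this literal reading, pixels outside the truncated support are mapped to α⁰ = 1 by the exponential weighting.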
3 Experiment
Datasets. For performance evaluation, we adopt three public X-ray datasets from different domains, covering the head, hand, and chest. (i) The head dataset is a widely used dataset from the IEEE ISBI 2015 challenge [18,19]; it contains 400 cephalometric X-ray images, 150 for training and 250 for testing. Each image is of size 2400 × 1935 with a resolution of 0.1 mm × 0.1 mm and contains 19 landmarks manually labeled by two medical experts; as in Payer et al. [11], we use the average of the two experts' annotations. (ii) The hand dataset, collected by [6], contains 909 X-ray images with 37 landmarks annotated by [11]. We follow [26] to split it into a training set of 609 images and a test set of 300 images. Following [11], we assume that the distance between the two endpoints of the wrist is 50 mm and calculate the physical distance as $\mathrm{distance}_{physical} = \mathrm{distance}_{pixel} \times \frac{50}{\lVert p - q \rVert_2}$, where $p$ and $q$ are the two endpoints of the wrist. (iii) The chest dataset [26] is a popular chest radiography database collected by the Japanese Society of Radiological Technology (JSRT) [14]; it contains 247 images. Each image is of size 2048 × 2048 with a resolution of 0.175 mm × 0.175 mm. We split it into a training set of 197 images and a test set of 50 images, and select 6 landmarks on the lung boundary from the landmark labels as target landmarks. Implementation Details. UOD is implemented in PyTorch and trained on a TITAN RTX GPU with CUDA 11. All encoders are initialized with the corresponding pre-trained weights. We set the batch size to 8, σ to 3, and α to 10. We adopt binary cross-entropy (BCE) as the loss function for both stages. In Stage I, we resize each image to 384 × 384 and train the universal convolution model with the Adam optimizer for 1000 epochs at a learning rate of 0.00001. In Stage II, we resize each image to 576 × 576 and optimize the universal transformer with the Adam optimizer for 300 epochs at a learning rate of 0.0001.
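The pixel-to-millimetre conversions described for the three datasets can be sketched as below. The spacings of 0.1 mm and 0.175 mm and the 50 mm wrist width come from the text; the constant and function names are hypothetical, and the sketch assumes distances are measured at the original image resolution.

```python
import math

# Isotropic pixel spacing stated for the head and chest datasets (mm per pixel).
PIXEL_SPACING_MM = {"head": 0.1, "chest": 0.175}
WRIST_WIDTH_MM = 50.0  # assumed wrist width for the hand dataset, following [11]


def pixel_to_mm(distance_pixel, dataset):
    """Head/chest: scale a pixel distance by the fixed spacing."""
    return distance_pixel * PIXEL_SPACING_MM[dataset]


def hand_pixel_to_mm(distance_pixel, p, q):
    """Hand: distance_physical = distance_pixel * 50 / ||p - q||_2."""
    wrist_pixels = math.dist(p, q)  # Euclidean distance between wrist endpoints
    if wrist_pixels == 0:
        raise ValueError("wrist endpoints must be distinct")
    return distance_pixel * WRIST_WIDTH_MM / wrist_pixels
```

For example, if the wrist endpoints are 100 pixels apart, a 10-pixel error on that hand image corresponds to 5 mm.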
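When calculating metrics, all predicted landmarks are rescaled to the original image size. For evaluation, we choose the model with the minimum validation loss as the inference model and adopt two metrics, the mean radial error (MRE) and the successful detection rate (SDR) within different thresholds $t$:

$$\mathrm{MRE} = \frac{1}{N} \sum_{i=1}^{N} \sqrt{(x_i - \tilde{x}_i)^2 + (y_i - \tilde{y}_i)^2}, \qquad \mathrm{SDR}(t) = \frac{1}{N} \sum_{i=1}^{N} \delta\left(\sqrt{(x_i - \tilde{x}_i)^2 + (y_i - \tilde{y}_i)^2} \le t\right)$$

where $\delta(\cdot)$ equals 1 if its condition holds and 0 otherwise.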
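The two metrics can be computed as in the following sketch. It assumes landmark coordinates have already been converted to millimetres with the pixel-spacing or wrist-width rule of each dataset, and that $N$ counts all evaluated landmarks; SDR is multiplied by 100 so that it matches the percentages reported in Table 1. Function names are hypothetical.

```python
import math


def radial_errors(pred, gt):
    """Euclidean distance between each predicted landmark and its ground truth."""
    if len(pred) != len(gt):
        raise ValueError("pred and gt must list the same landmarks")
    return [math.hypot(px - gx, py - gy) for (px, py), (gx, gy) in zip(pred, gt)]


def mre(pred, gt):
    """Mean radial error, in the same unit as the coordinates."""
    errors = radial_errors(pred, gt)
    return sum(errors) / len(errors)


def sdr(pred, gt, t):
    """Successful detection rate within radius t, as a percentage."""
    errors = radial_errors(pred, gt)
    return 100.0 * sum(e <= t for e in errors) / len(errors)
```

For three landmarks with radial errors of 0, 5, and 0 mm, MRE is 5/3 mm, SDR(2 mm) is about 66.7%, and SDR(5 mm) is 100%.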
Fig. 3. Comparison of single model and universal model on head dataset.

Table 1. Quantitative comparison of UOD with SOTA methods on head, hand, and chest datasets. * denotes the method is trained on every single dataset respectively, while † denotes the method is trained on mixed data. Label is the number of annotated training images; MRE is in mm and SDR in %.

| Method | Label | Head [19] MRE↓ | Head SDR↑ 2 mm | 2.5 mm | 3 mm | 4 mm | Hand [6] MRE↓ | Hand SDR↑ 2 mm | 4 mm | 10 mm | Chest [14] MRE↓ | Chest SDR↑ 2 mm | 4 mm | 10 mm |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| YOLO [26]† | all | 1.32 | 81.14 | 87.85 | 92.12 | 96.80 | 0.85 | 94.93 | 99.14 | 99.67 | 4.65 | 31.00 | 69.00 | 93.67 |
| YOLO [26]† | 25 | 1.96 | 62.05 | 77.68 | 88.21 | 97.11 | 2.88 | 72.71 | 92.32 | 97.65 | 7.03 | 19.33 | 51.67 | 89.33 |
| YOLO [26]† | 10 | 2.69 | 47.58 | 66.47 | 78.42 | 90.89 | 9.70 | 48.66 | 76.69 | 90.52 | 16.07 | 11.67 | 33.67 | 76.33 |
| YOLO [26]† | 5 | 5.40 | 26.16 | 41.32 | 54.42 | 73.74 | 24.35 | 20.59 | 48.91 | 72.94 | 34.81 | 4.33 | 19.00 | 56.67 |
| CC2D [22]* | 1 | 2.76 | 42.36 | 51.82 | 64.02 | 78.96 | 2.65 | 51.19 | 82.56 | 95.62 | 10.25 | 11.37 | 35.73 | 68.14 |
| Ours† | 1 | 2.43 | 51.14 | 62.37 | 74.40 | 86.49 | 2.52 | 53.37 | 84.27 | 97.59 | 8.49 | 14.00 | 39.33 | 76.33 |

3.1 Experimental Results
The Effectiveness of the Universal Model: To demonstrate the effectiveness of the universal model for multi-domain one-shot learning, we use the head and hand datasets for evaluation. In Stage I, the convolution models are trained in two ways: 1) single: trained on each dataset separately, and 2) universal: trained on the mixed datasets together. With a fixed one-shot sample for the hand dataset, we vary the one-shot sample for the head dataset and report the MRE and SDR on the head dataset. As Fig. 3 shows, the universal model performs much better than the single model across one-shot samples and metrics. This suggests that the universal model learns domain-shared knowledge that promotes domain-specific learning. Furthermore, the MRE and SDR of the universal model vary less across one-shot samples, which demonstrates the robustness of the universal model learned on multi-domain data. Comparisons with State-of-the-Art Methods: As Table 1 shows, we compare UOD with two open-source landmark detection methods, i.e., YOLO [26] and CC2D [22]. YOLO is a multi-domain supervised method, while CC2D is a single-domain one-shot method. UOD achieves the best results among one-shot methods on all datasets under all metrics, outperforming CC2D by a large margin. On the head dataset, benefiting from multi-domain learning, UOD achieves an MRE of 2.43 mm and an SDR of 86.49% within 4 mm, which is comparable to the supervised method YOLO trained with at least 10 annotated labels, and much
Table 2. Ablation study of different components of our DATR. Base is the basic transformer block; $\mathrm{MSA}_{Q^d}$ denotes the domain-adaptive self-attention and $D^d$ denotes the domain-adaptive diagonal matrix. In each column, the best results are in bold.
Transformer
Head [19] MRE↓ SDR↑ (%) (mm) 2 mm 2.5 mm 3 mm
Hand [6] MRE↓ SDR↑ (%) (mm) 2 mm 4 mm
Chest [14] MRE↓ SDR↑ (%) 10 mm (mm) 2 mm 4 mm
(a) Base
24.95
2.02
3.17
4.51
5.85
9.83
5.33
16.79
58.64
58.11
0.37
1.96
(b) +Dd
22.75
2.13
3.24
3.85
4.61
6.96
7.52
6.13
20.66
68.43
52.98
0.59
2.17
4.68
(c) +MSAQd
2.51
49.29
60.89
(d) +MSAQd +Dd
2.43
51.14 62.37
72.17
84.36
2.72
48.56
80.44
94.38
9.09
12.00
19.33
74.00
8.49
14.00 39.33 76.33
4 mm
74.40 86.49 2.52
53.37 84.27 97.59
10 mm
Fig. 4. Qualitative comparison of UOD and CC2D [22] on head, hand, and chest datasets. The red points • indicate predicted landmarks while the green points • indicate ground truth landmarks. The MRE value is displayed in the top left corner of the image. (Color figure online)
better than CC2D. On the hand dataset, UOD improves on CC2D in all metrics and achieves a lower MRE than the supervised method YOLO trained with 25 annotated images. On the chest dataset, UOD shows the superiority of DATR, which reduces domain preference and balances the performance across domains. In contrast, the performance of YOLO on the chest dataset drops sharply when the available labels are reduced to 25, 10, and 5. Figure 4 visualizes the landmarks predicted by UOD and CC2D. Ablation Study: We compare the components of the proposed domain-adaptive transformer. The experiments are carried out in UOD Stage II. As presented in Table 2, the domain-adaptive transformer has two key components: the domain-adaptive self-attention $\mathrm{MSA}_{Q^d}$ and the domain-adaptive diagonal matrix $D^d$. Configurations (b) and (c) perform much better than (a), which demonstrates the effectiveness of $D^d$ and $\mathrm{MSA}_{Q^d}$. Further, (d) combines the two components and achieves the best performance, which indicates that the domain-adaptive transformer improves detection accuracy via cross-domain knowledge and global context information. We take (d) as the final transformer block.
4 Conclusion
To improve the robustness and reduce the domain preference of multi-domain one-shot learning, we design a universal framework in which we first train a universal model via contrastive learning to generate pseudo landmarks, and then use these labels to learn a universal transformer for accurate and robust landmark detection. UOD is the first universal framework for one-shot landmark detection on multi-domain data, and it outperforms other one-shot methods on three public datasets from different anatomical regions. We believe UOD can significantly reduce the labeling burden and pave the way for more universal frameworks for multi-domain one-shot learning.

Acknowledgment. Supported by Natural Science Foundation of China under Grant 62271465 and Open Fund Project of Guangdong Academy of Medical Sciences, China (No. YKY-KF202206).
References
1. Avisdris, N., et al.: BiometryNet: landmark-based fetal biometry estimation from standard ultrasound planes. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13434, pp. 279–289. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16440-8_27
2. Browatzki, B., Wallraven, C.: 3fabrec: fast few-shot face alignment by reconstruction. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 6110–6120 (2020)
3. Chen, Z., Qiu, T., Tian, Y., Feng, H., Zhang, Y., Wang, H.: Automated brain structures segmentation from PET/CT images based on landmark-constrained dual-modality atlas registration. Phys. Med. Biol. 66(9), 095003 (2021)
4. Elkhill, C., LeBeau, S., French, B., Porras, A.R.: Graph convolutional network with probabilistic spatial regression: Application to craniofacial landmark detection from 3D photogrammetry. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13433, pp. 574–583. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16437-8_55
5. Espinel, Y., Calvet, L., Botros, K., Buc, E., Tilmant, C., Bartoli, A.: Using multiple images and contours for deformable 3D-2D registration of a preoperative ct in laparoscopic liver surgery. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12904, pp. 657–666. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87202-1_63
6. Gertych, A., Zhang, A., Sayre, J., Pospiech-Kurkowska, S., Huang, H.: Bone age assessment of children using a digital hand atlas. Comput. Med. Imaging Graph. 31(4–5), 322–331 (2007)
7. Huang, C., Han, H., Yao, Q., Zhu, S., Zhou, S.K.: 3D U²-Net: a 3D universal U-net for multi-domain medical image segmentation. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 291–299. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8_33
8. Jiang, Y., Li, Y., Wang, X., Tao, Y., Lin, J., Lin, H.: CephalFormer: incorporating global structure constraint into visual features for general cephalometric landmark detection. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13433, pp. 227–237. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16437-8_22
9. Lang, Y., et al.: DentalPointNet: landmark localization on high-resolution 3d digital dental models. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13432, pp. 444–452. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16434-7_43
10. Liu, Z., et al.: Swin transformer: hierarchical vision transformer using shifted windows. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10012–10022 (2021)
11. Payer, C., Štern, D., Bischof, H., Urschler, M.: Integrating spatial configuration into heatmap regression based CNNs for landmark localization. Med. Image Anal. 54, 207–219 (2019)
12. Quan, Q., Yao, Q., Li, J., Zhou, S.K.: Which images to label for few-shot medical landmark detection? In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 20606–20616 (2022)
13. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
14. Shiraishi, J., et al.: Development of a digital image database for chest radiographs with and without a lung nodule: receiver operating characteristic analysis of radiologists’ detection of pulmonary nodules. Am. J. Roentgenol. 174(1), 71–74 (2000)
15. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. In: International Conference on Learning Representations (ICLR), pp. 1–14 (2015)
16. Touvron, H., Cord, M., Sablayrolles, A., Synnaeve, G., Jégou, H.: Going deeper with image transformers. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 32–42 (2021)
17. Vaswani, A., et al.: Attention is all you need. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
18. Wang, C.W., et al.: Evaluation and comparison of anatomical landmark detection methods for cephalometric x-ray images: a grand challenge. IEEE Trans. Med. Imaging 34(9), 1890–1900 (2015)
19. Wang, C.W., et al.: A benchmark for comparison of dental radiography analysis algorithms. Med. Image Anal. 31, 63–76 (2016)
20. Wang, Z., et al.: Accurate scoliosis vertebral landmark localization on x-ray images via shape-constrained multi-stage cascaded CNNs. Fundam. Res. (2022)
21. Yao, Q., He, Z., Han, H., Zhou, S.K.: Miss the point: targeted adversarial attack on multiple landmark detection. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12264, pp. 692–702. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59719-1_67
22. Yao, Q., Quan, Q., Xiao, L., Kevin Zhou, S.: One-shot medical landmark detection. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12902, pp. 177–188. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87196-3_17
23. Zhao, A., Balakrishnan, G., Durand, F., Guttag, J.V., Dalca, A.V.: Data augmentation using learned transformations for one-shot medical image segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8543–8553 (2019)
24. Zhou, S.K., et al.: A review of deep learning in medical imaging: imaging traits, technology trends, case studies with progress highlights, and future promises. Proc. IEEE (2021)
25. Zhou, S.K., Rueckert, D., Fichtinger, G.: Handbook of Medical Image Computing and Computer Assisted Intervention. Academic Press, Cambridge (2019)
26. Zhu, H., Yao, Q., Xiao, L., Zhou, S.K.: You only learn once: universal anatomical landmark detection. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12905, pp. 85–95. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87240-3_9
27. Zhu, H., Yao, Q., Xiao, L., Zhou, S.K.: Learning to localize cross-anatomy landmarks in x-ray images with a universal model. BME Front. 2022 (2022)
S²ME: Spatial-Spectral Mutual Teaching and Ensemble Learning for Scribble-Supervised Polyp Segmentation

An Wang¹, Mengya Xu², Yang Zhang¹,³, Mobarakol Islam⁴, and Hongliang Ren¹,²(B)

¹ Department of Electronic Engineering, Shun Hing Institute of Advanced Engineering (SHIAE), The Chinese University of Hong Kong, Hong Kong, Hong Kong SAR, China
² Department of Biomedical Engineering, National University of Singapore, Singapore, Singapore
³ School of Mechanical Engineering, Hubei University of Technology, Wuhan, China
⁴ Department of Medical Physics and Biomedical Engineering, Wellcome/EPSRC Centre for Interventional and Surgical Sciences (WEISS), University College London, London, UK

Abstract. Fully-supervised polyp segmentation has achieved significant success over the years in advancing the early diagnosis of colorectal cancer. However, label-efficient solutions based on weak supervision such as scribbles remain rarely explored, even though they are highly meaningful and in demand in medical practice because densely annotated polyp data are expensive and scarce. Besides, various deployment issues, including data shifts and corruption, place further demands on model generalization and robustness. To address these concerns, we design a framework of Spatial-Spectral Dual-branch Mutual Teaching and Entropy-guided Pseudo Label Ensemble Learning (S²ME). Concretely, for the first time in weakly-supervised medical image segmentation, we promote the dual-branch co-teaching framework by leveraging the intrinsic complementarity of features extracted from the spatial and spectral domains and encouraging cross-space consistency through collaborative optimization. Furthermore, to produce reliable mixed pseudo labels, which enhance the effectiveness of ensemble learning, we introduce a novel adaptive pixel-wise fusion technique based on entropy guidance from the spatial and spectral branches. Our strategy efficiently mitigates the deleterious effects of uncertainty and noise in pseudo labels and surpasses previous alternatives in efficacy. Ultimately, we formulate a holistic optimization objective to learn from the hybrid supervision of scribbles and pseudo labels. Extensive experiments and evaluation on four public datasets demonstrate the superiority of our method in in-distribution accuracy, out-of-distribution generalization, and robustness, highlighting its promising clinical significance. Our code is available at https://github.com/lofrienger/S2ME.

Keywords: Polyp Image Segmentation · Weakly-supervised Learning · Spatial-Spectral Dual Branches · Mutual Teaching · Ensemble Learning

Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_4.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 35–45, 2023. https://doi.org/10.1007/978-3-031-43907-0_4

A. Wang et al.
1 Introduction
Colorectal cancer is a leading cause of cancer-related deaths worldwide [1]. Early detection and efficient diagnosis of polyps, which are precursors to colorectal cancer, are crucial for effective treatment. Recently, deep learning has emerged as a powerful tool in medical image analysis, prompting extensive research into its potential for polyp segmentation. The effectiveness of deep learning models in medical applications usually relies on large, well-annotated datasets, which in turn require a time-consuming, expertise-driven annotation process. This has prompted annotation-efficient, weakly-supervised learning approaches in the medical domain that use limited annotations such as points [8], bounding boxes [12], and scribbles [15]. Compared with other sparse labeling methods, scribbles allow the annotator to mark arbitrary shapes, making them more flexible than points or boxes [13]. Besides, scribbles provide a more robust supervision signal than points or boxes, which can be prone to noise and outliers [5]. Hence, this work investigates the feasibility of polyp segmentation with scribble annotations as supervision. The effectiveness of medical applications during on-site deployment depends on their ability to generalize to unseen data and remain robust against data corruption. Improving these factors is crucial to enhance the accuracy and reliability of medical diagnoses in real-world scenarios [22,27,28]. Therefore, we comprehensively evaluate our approach on multiple datasets from various medical sites to show its viability and effectiveness across different contexts. Dual-branch learning has been widely adopted in annotation-efficient learning to encourage mutual consistency through co-teaching.
While existing approaches are typically designed for learning in the spatial domain [21,25,29,30], we introduce a novel spatial-spectral dual-branch structure to efficiently leverage domain-specific complementary knowledge with synergistic mutual teaching. Furthermore, the outputs from the spatial and spectral branches are aggregated to produce mixed pseudo labels as supplementary supervision. Unlike previous methods, which generally adopt handcrafted fusion strategies [15], we aggregate the outputs of the spatial-spectral dual branches with an entropy-guided adaptive mixing ratio for each pixel. Consequently, our pseudo-label fusion assesses the pixel-level ambiguity arising in both the spatial and frequency domains from their entropy maps, thereby assigning more confident categorical labels to individual pixels and facilitating effective pseudo label ensemble learning.
S²ME: Spatial-Spectral Mutual Teaching and Ensemble Learning
Fig. 1. Overview of our Spatial-Spectral Dual-branch Mutual Teaching and Pixel-level Entropy-guided Pseudo Label Ensemble Learning (S²ME) for scribble-supervised polyp segmentation. Spatial-spectral cross-domain consistency is encouraged through mutual teaching. High-quality mixed pseudo labels are generated with pixel-level guidance from the dual-space entropy maps, ensuring more reliable supervision for ensemble learning.
Contributions. Overall, the contributions of this work are threefold. First, we devise a spatial-spectral dual-branch structure to leverage cross-space knowledge and foster collaborative mutual teaching. To the best of our knowledge, this is the first attempt to explore the complementary relations of the spatial-spectral dual branch in boosting weakly-supervised medical image analysis. Second, we introduce a pixel-level entropy-guided fusion strategy to generate mixed pseudo labels with reduced noise and increased confidence, thus enhancing ensemble learning. Lastly, our proposed hybrid loss optimization, comprising a scribble-supervised loss, a mutual teaching loss with domain-specific pseudo labels, and an ensemble learning loss with fused-domain pseudo labels, yields a generalizable and robust model for polyp image segmentation. An extensive assessment of our approach on four publicly accessible datasets establishes its superiority and clinical significance.
2 Methodology
2.1 Preliminaries
Spectral-domain learning [26] has gained increasing popularity in medical image analysis [23] for its ability to identify subtle frequency patterns that may not be well detected by purely spatial-domain networks like UNet [20]. For instance, a recent dual-encoder network, YNet [6], incorporates a spectral encoder with Fast Fourier Convolution (FFC) [4] to disentangle global patterns across varying frequency components and derive a hybrid feature representation. In addition, spectral learning exhibits favorable robustness and generalization against adversarial attacks, data corruption, and distribution shifts [19]. In label-efficient learning, some preliminary works have been proposed to encourage
mutual consistency between outputs from two networks [3], two decoders [25], and teacher-student models [14], yet only in the spatial domain. As far as we know, spatial-spectral cross-domain consistency has never been investigated to promote learning with sparse annotations of medical data. This motivates us to develop a cross-domain cooperative mutual teaching scheme that leverages the favorable properties of learning in the spectral space. Besides consistency constraints, using pseudo labels as supplementary supervision is another principle in label-efficient learning [11,24]. To prevent the model from being misled by noise and inaccuracies in the pseudo labels, numerous studies have endeavored to improve their quality, including averaging model predictions from several iterations [11], filtering out unreliable pixels [24], and mixing dual-branch outputs [15] following

$$p_{mix} = \alpha \times p_1 + (1 - \alpha) \times p_2, \quad \alpha = \mathrm{random}(0, 1), \tag{1}$$
where $\alpha$ is the random mixing ratio, and $p_1$, $p_2$, and $p_{mix}$ denote the probability maps from the two spatial decoders and their mixture. These approaches operate only in the spatial domain, whether with single or dual branches, while we consider both the spatial and spectral domains and propose to adaptively merge dual-branch outputs with their respective pixel-wise entropy guidance.

2.2 S²ME: Spatial-Spectral Mutual Teaching and Ensemble Learning
Spatial-Spectral Cross-domain Mutual Teaching. In contrast to prior weakly-supervised learning methods that have focused only on the spatial domain, our approach designs a dual-branch structure consisting of a spatial branch $f_{spa}(x, \theta_{spa})$ and a spectral branch $f_{spe}(x, \theta_{spe})$, with $x$ and $\theta$ being the input image and the randomly initialized model parameters. As illustrated in Fig. 1, the spatial and spectral branches take the same training image as input and extract domain-specific patterns. The raw model outputs, i.e., the logits $l_{spa}$ and $l_{spe}$, are converted to probability maps $p_{spa}$ and $p_{spe}$ with Softmax normalization, and further to the respective pseudo labels $\hat{y}_{spa}$ and $\hat{y}_{spe}$ by $\hat{y} = \arg\max p$. The spatial and spectral pseudo labels supervise the other branch collaboratively during mutual teaching, which can be expressed as

$$\hat{y}_{spa} \rightarrow f_{spe} \quad \text{and} \quad \hat{y}_{spe} \rightarrow f_{spa}, \tag{2}$$
where "→" denotes supervision¹. Through cross-domain engagement, the two branches complement each other, each providing valuable domain-specific insights and feedback to the other. Consequently, such a scheme can lead to better feature extraction, more meaningful data representation, and domain-specific knowledge transmission, thus boosting model generalization and robustness.

¹ For convenience, we omit the input x and model parameters θ.

Entropy-Guided Pseudo Label Ensemble Learning. In addition to mutual teaching, we aggregate the pseudo labels from the spatial and spectral branches in ensemble learning, aiming to take advantage of the distinctive yet complementary properties of the cross-domain features. A pixel with a higher entropy value has higher uncertainty in its prediction. We can observe from the entropy maps $H_{spa}$ and $H_{spe}$ in Fig. 1 that pixels on the polyp boundary are harder to segment accurately and have higher entropy values (the white contours). Based on this property, we propose a novel adaptive strategy that automatically adjusts the mixing ratio for each pixel according to the entropy of its categorical probability distribution. Hence, the mixed pseudo labels are more reliable and more beneficial for ensemble learning. Concretely, given the spatial and spectral probability maps $p_{spa}$ and $p_{spe}$, the corresponding entropy maps $H_{spa}$ and $H_{spe}$ are computed with

$$H = -\sum_{c=0}^{C-1} p(c) \times \log p(c), \tag{3}$$
where $C$ is the number of classes, which equals 2 in our task. Unlike previous image-level fixed-ratio mixing or random mixing as in Eq. (1), we update the mixing ratio between the two probability maps $p_{spa}$ and $p_{spe}$ with weighted entropy guidance at each pixel location by

$$p_{s^2} = \frac{H_{spe}}{H_{spa} + H_{spe}} \otimes p_{spa} + \frac{H_{spa}}{H_{spa} + H_{spe}} \otimes p_{spe}, \tag{4}$$
where "⊗" denotes pixel-wise multiplication. $p_{s^2}$ is the merged probability map and can be further converted to the pseudo label $\hat{y}_{s^2} = \arg\max p_{s^2}$ to supervise the spatial and spectral branches in the context of ensemble learning following

$$\hat{y}_{s^2} \rightarrow f_{spa} \quad \text{and} \quad \hat{y}_{s^2} \rightarrow f_{spe}. \tag{5}$$
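Eqs. (3) to (5), together with the random mixing of Eq. (1) for contrast, can be sketched on flattened per-pixel class distributions as follows. This is a minimal sketch, not the released S²ME code: the names are hypothetical, and the equal-weight fallback when both entropies are zero is an assumption, since Eq. (4) is undefined in that case.

```python
import math
import random


def entropy(probs):
    """Eq. (3) for one pixel: H = -sum_c p(c) log p(c), with 0 log 0 taken as 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def random_mix(p1, p2, rng):
    """Eq. (1): one image-level random ratio alpha shared by every pixel."""
    alpha = rng.random()
    return [[alpha * a + (1.0 - alpha) * b for a, b in zip(x, y)] for x, y in zip(p1, p2)]


def entropy_guided_fusion(p_spa, p_spe, eps=1e-12):
    """Eq. (4) per pixel, then the arg max pseudo label of Eq. (5)."""
    fused, labels = [], []
    for a, b in zip(p_spa, p_spe):
        h_spa, h_spe = entropy(a), entropy(b)
        denom = h_spa + h_spe
        if denom < eps:
            w_spa = w_spe = 0.5  # both branches fully confident: equal weights (assumption)
        else:
            w_spa = h_spe / denom  # a confident spatial branch (low H_spa) gets more weight
            w_spe = h_spa / denom
        mix = [w_spa * x + w_spe * y for x, y in zip(a, b)]
        fused.append(mix)
        labels.append(max(range(len(mix)), key=mix.__getitem__))
    return fused, labels
```

For a pixel where the spatial branch predicts (0.9, 0.1) and the spectral branch (0.6, 0.4), the spatial branch has the lower entropy, so the fused distribution (about 0.80, 0.20) leans toward it, whereas Eq. (1) would use the same random weight at every pixel.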
By absorbing strengths from the spatial and spectral branches, ensemble learning from the mixed pseudo labels facilitates model optimization with reduced overfitting, increased stability, and improved generalization and robustness.

Hybrid Loss Supervision from Scribbles and Pseudo Labels. Besides the scribble annotations for a subset of pixels, the aforementioned three types of pseudo labels $\hat{y}_{spa}$, $\hat{y}_{spe}$, and $\hat{y}_{s^2}$ offer complementary supervision for every pixel under different learning regimes. Overall, our hybrid loss supervision is based on the Cross Entropy loss $\ell_{CE}$ and the Dice loss $\ell_{Dice}$. Specifically, we employ the partial Cross Entropy loss [13] $\ell_{pCE}$, which computes the loss only on the labeled pixels, for learning from scribbles:

$$L_{scrib} = \ell_{pCE}(l_{spa}, y) + \ell_{pCE}(l_{spe}, y), \tag{6}$$
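where $y$ denotes the scribble annotations. Furthermore, the mutual teaching loss with supervision from domain-specific pseudo labels is

$$L_{mt} = \underbrace{\{\ell_{CE}(l_{spa}, \hat{y}_{spe}) + \ell_{Dice}(p_{spa}, \hat{y}_{spe})\}}_{\hat{y}_{spe} \rightarrow f_{spa}} + \underbrace{\{\ell_{CE}(l_{spe}, \hat{y}_{spa}) + \ell_{Dice}(p_{spe}, \hat{y}_{spa})\}}_{\hat{y}_{spa} \rightarrow f_{spe}}. \tag{7}$$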
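The loss terms of Eqs. (6) and (7) can be sketched in plain Python over flattened pixels as below. Assumptions: pixels without a scribble carry a hypothetical `IGNORE` value, the Dice loss is a soft Dice averaged over both classes (the text does not specify the variant), and because there is no autodiff here the pseudo labels are plain integers; in a training framework they would typically be computed without gradients. All names are hypothetical.

```python
import math

IGNORE = -1  # hypothetical label for pixels not covered by any scribble


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


def cross_entropy(logits, labels, ignore=IGNORE):
    """Mean CE over pixels whose label is not `ignore`.

    With scribble labels this is the partial CE of Eq. (6); with dense
    pseudo labels it is the ordinary CE term of Eqs. (7) and (8).
    """
    losses = [-math.log(softmax(l)[y]) for l, y in zip(logits, labels) if y != ignore]
    return sum(losses) / len(losses) if losses else 0.0


def dice_loss(probs, labels, num_classes=2, eps=1e-6):
    """Soft Dice loss against hard labels, averaged over classes (assumed variant)."""
    total = 0.0
    for c in range(num_classes):
        inter = sum(p[c] for p, y in zip(probs, labels) if y == c)
        pred_sum = sum(p[c] for p in probs)
        gt_sum = sum(1 for y in labels if y == c)
        total += 1.0 - (2.0 * inter + eps) / (pred_sum + gt_sum + eps)
    return total / num_classes


def mutual_teaching_loss(l_spa, l_spe):
    """Eq. (7): each branch is supervised by the other's arg max pseudo labels."""
    p_spa = [softmax(l) for l in l_spa]
    p_spe = [softmax(l) for l in l_spe]
    y_spa = [argmax(p) for p in p_spa]
    y_spe = [argmax(p) for p in p_spe]
    spa_term = cross_entropy(l_spa, y_spe) + dice_loss(p_spa, y_spe)  # y_spe -> f_spa
    spe_term = cross_entropy(l_spe, y_spa) + dice_loss(p_spe, y_spa)  # y_spa -> f_spe
    return spa_term + spe_term
```

The ensemble learning loss of Eq. (8) follows the same pattern, with both branches supervised by the fused label $\hat{y}_{s^2}$ instead of each other's labels.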
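Likewise, the ensemble learning loss with supervision from the enhanced mixed pseudo labels can be formulated as

$$L_{el} = \underbrace{\{\ell_{CE}(l_{spa}, \hat{y}_{s^2}) + \ell_{Dice}(p_{spa}, \hat{y}_{s^2})\}}_{\hat{y}_{s^2} \rightarrow f_{spa}} + \underbrace{\{\ell_{CE}(l_{spe}, \hat{y}_{s^2}) + \ell_{Dice}(p_{spe}, \hat{y}_{s^2})\}}_{\hat{y}_{s^2} \rightarrow f_{spe}}. \tag{8}$$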
Holistically, our hybrid loss supervision can be stated as

$$L_{hybrid} = L_{scrib} + \lambda_{mt} \times L_{mt} + \lambda_{el} \times L_{el}, \tag{9}$$
where $\lambda_{mt}$ and $\lambda_{el}$ are weighting coefficients that regulate the relative importance of the different modes of supervision. The hybrid loss considers all available supervision signals in the spatial-spectral dual-branch network and outperforms partial combinations of its constituent terms, as evidenced in the ablation study.
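Eq. (9) can be combined with the schedules listed in the implementation details of Sect. 3.1 ($\lambda_{mt} = \lambda_{el}$ ramped up to 5 over 25k iterations; learning rate 0.03 with poly decay over 30k iterations), as in the sketch below. The exact ramp-up curve and poly exponent are not given in the text: the sketch assumes the sigmoid-shaped exp(−5(1 − t)²) ramp-up that is common in semi-supervised codebases (it starts near, not exactly at, 0) and a poly power of 0.9. All names are hypothetical.

```python
import math


def exp_rampup(step, rampup_steps):
    """Assumed ramp-up exp(-5 (1 - t)^2), with t = step / rampup_steps clipped to [0, 1]."""
    if rampup_steps <= 0:
        return 1.0
    t = min(max(step / rampup_steps, 0.0), 1.0)
    return math.exp(-5.0 * (1.0 - t) ** 2)


def loss_weight(step, max_weight=5.0, rampup_steps=25_000):
    """lambda_mt = lambda_el: about 0.03 at step 0 and exactly 5 from step 25k on."""
    return max_weight * exp_rampup(step, rampup_steps)


def hybrid_loss(l_scrib, l_mt, l_el, step):
    """Eq. (9) with a shared, ramped-up weight for the two pseudo-label terms."""
    lam = loss_weight(step)
    return l_scrib + lam * l_mt + lam * l_el


def poly_lr(step, base_lr=0.03, max_steps=30_000, power=0.9):
    """Poly learning-rate decay; power=0.9 is an assumed, commonly used value."""
    return base_lr * (1.0 - step / max_steps) ** power
```

Ramping the weights up lets the scribble loss dominate early training, when both branches' pseudo labels are still unreliable.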
3 Experiments
3.1 Experimental Setup
Datasets. We use the SUN-SEG [10] dataset with scribble annotations for training and for assessing in-distribution performance. SUN-SEG is built on the SUN database [16], which contains 100 polyp video cases. To reduce data redundancy and memory consumption, we keep the first of every five consecutive frames in each case. We then randomly split the data into 70, 10, and 20 cases for training, validation, and testing, leaving 6677, 1240, and 1993 frames in the respective splits. For out-of-distribution evaluation, we use three public datasets, namely Kvasir-SEG [9], CVC-ClinicDB [2], and PolypGen [1], with 1000, 612, and 1537 polyp frames, respectively. These datasets were collected from diverse patients at multiple medical centers with various data acquisition systems. Their varying data shifts and corruptions, such as motion blur and specular reflections (exemplary polyp frames are presented in the supplementary materials), pose significant challenges to model generalization and robustness.

Implementation Details. We implement our method with PyTorch [18] and run the experiments on a single NVIDIA RTX 3090 GPU. The model is trained for 30k iterations with the SGD optimizer, using a momentum of 0.9, a weight decay of 0.0001, and a batch size of 16. Each experiment takes approximately 4 h. The initial learning rate is 0.03 and is updated with the poly scheduling policy [15]. The loss weighting coefficients λmt and λel are empirically set to the same value and exponentially ramped up [3] from 0 to 5 over 25k iterations. All images are randomly cropped at the border by at most 7 pixels and resized to 224×224. In addition, random horizontal and vertical flips are each applied with a probability of 0.5. We use UNet [20] and YNet [6] as the segmentation models in the spatial and spectral branches, respectively. The scribble-supervised model trained with the partial Cross Entropy [13] loss (Scrib-pCE) and the fully-supervised
S2 ME: Spatial-Spectral Mutual Teaching and Ensemble Learning
model trained with Cross Entropy loss (Fully-CE) are treated as the lower and upper bounds, respectively. Five classical and relevant methods, EntMin [7], GCRF [17], USTM [14], CPS [3], and DMPLS [15], serve as the comparative baselines. They are implemented with UNet [20] as the segmentation backbone, following the WSL4MIS repository (https://github.com/HiLab-git/WSL4MIS). For a fair comparison, the output of the spatial branch is taken as the final prediction and evaluated without post-processing. In addition, experiments are repeated with multiple random seeds, and the mean and standard deviation of the results are reported.

3.2 Results and Analysis

Table 1. Quantitative comparison of in-distribution segmentation performance on SUN-SEG [10]. Scrib-pCE and Fully-CE are the lower and upper bounds, respectively. Among the scribble-supervised methods, S2ME (Ours) achieves the best result on every metric.

Method          DSC ↑         IoU ↑         Prec ↑        HD ↓
Scrib-pCE [13]  0.633±0.010   0.511±0.012   0.636±0.021   5.587±0.149
EntMin [7]      0.642±0.012   0.519±0.013   0.666±0.016   5.277±0.063
GCRF [17]       0.656±0.019   0.541±0.022   0.690±0.017   4.983±0.089
USTM [14]       0.654±0.008   0.533±0.009   0.663±0.011   5.207±0.138
CPS [3]         0.658±0.004   0.539±0.005   0.676±0.005   5.092±0.063
DMPLS [15]      0.656±0.006   0.539±0.005   0.659±0.011   5.208±0.061
S2ME (Ours)     0.674±0.003   0.565±0.001   0.719±0.003   4.583±0.014
Fully-CE        0.713±0.021   0.617±0.023   0.746±0.027   4.405±0.119

The weakly-supervised methods are evaluated with four metrics: Dice Similarity Coefficient (DSC), Intersection over Union (IoU), Precision (Prec), and the distance-based Hausdorff Distance (HD). As shown in Table 1 and Fig. 2, S2ME outperforms the other baselines on the SUN-SEG [10] dataset, both quantitatively and qualitatively. Regarding generalization and robustness, Table 2 shows that our method outperforms the other weakly-supervised methods by a clear margin on all three unseen datasets. It even exceeds the fully-supervised upper bound on two of them, CVC-ClinicDB and PolypGen (complete results for all four metrics are presented in the supplementary materials). These results suggest that S2ME is effective and reliable for polyp segmentation with only scribble annotations. Notably, the encouraging performance on unseen datasets suggests promising clinical implications for deploying our method in real-world scenarios.
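The four segmentation metrics in Tables 1–3 can be made concrete for a single pair of binary masks. The sketch below is an illustration, not the evaluation code behind the reported numbers. It assumes the following:

- HD is the plain symmetric Hausdorff distance over all foreground pixels, in pixel units. The variant used for the tables (for example a 95th-percentile or boundary-based HD) is not specified in this section.
- The empty-mask conventions are arbitrary choices.
- The helper names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Mask = Sequence[Sequence[int]]  # 2-D binary mask, 1 = polyp pixel


def _confusion(pred: Mask, gt: Mask) -> Tuple[int, int, int]:
    """Count true-positive, false-positive, and false-negative pixels."""
    tp = fp = fn = 0
    for prow, grow in zip(pred, gt):
        for p, g in zip(prow, grow):
            if p and g:
                tp += 1
            elif p:
                fp += 1
            elif g:
                fn += 1
    return tp, fp, fn


def dice(pred: Mask, gt: Mask) -> float:
    tp, fp, fn = _confusion(pred, gt)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 1.0  # both masks empty: treated as perfect


def iou(pred: Mask, gt: Mask) -> float:
    tp, fp, fn = _confusion(pred, gt)
    denom = tp + fp + fn
    return tp / denom if denom else 1.0


def precision(pred: Mask, gt: Mask) -> float:
    tp, fp, _ = _confusion(pred, gt)
    return tp / (tp + fp) if tp + fp else 0.0  # empty prediction: assumed 0


def _points(mask: Mask) -> List[Tuple[int, int]]:
    return [(r, c) for r, row in enumerate(mask) for c, v in enumerate(row) if v]


def hausdorff(pred: Mask, gt: Mask) -> float:
    """Symmetric Hausdorff distance between the foreground pixel sets."""
    a, b = _points(pred), _points(gt)
    if not a or not b:
        return math.inf

    def directed(xs, ys):
        # Largest distance from a point in xs to its nearest point in ys.
        return max(min(math.dist(p, q) for q in ys) for p in xs)

    return max(directed(a, b), directed(b, a))
```

The brute-force Hausdorff computation is quadratic in the number of foreground pixels. It is fine for checking small cases, but practical evaluation code usually relies on distance transforms instead.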
A. Wang et al.
Fig. 2. Qualitative performance comparison on one camouflaged polyp image, with DSC values shown at the top left. The contour of the ground-truth mask is displayed in black, and the contour predicted by each method is shown in a different color.

Table 2. Generalization comparison on three unseen datasets. Results marked † surpass the fully-supervised upper bound (Fully-CE).

Method          Kvasir-SEG [9]              CVC-ClinicDB [2]             PolypGen [1]
                DSC ↑         HD ↓          DSC ↑          HD ↓          DSC ↑          HD ↓
Scrib-pCE [13]  0.679±0.010   6.565±0.173   0.573±0.016    6.497±0.156   0.524±0.012    6.084±0.189
EntMin [7]      0.684±0.004   6.383±0.110   0.578±0.016    6.308±0.254   0.542±0.003    5.887±0.063
GCRF [17]       0.702±0.004   6.024±0.014   0.558±0.008    6.192±0.290   0.530±0.006    5.714±0.133
USTM [14]       0.693±0.005   6.398±0.138   0.587±0.019    5.950±0.107†  0.538±0.007    5.874±0.068
CPS [3]         0.703±0.011   6.323±0.062   0.591±0.017    6.161±0.074   0.546±0.013    5.844±0.065
DMPLS [15]      0.707±0.006   6.297±0.077   0.593±0.013    6.194±0.028   0.547±0.007    5.897±0.045
S2ME (Ours)     0.750±0.003   5.449±0.150   0.632±0.010†   5.633±0.008†  0.571±0.002†   5.247±0.107†
Fully-CE        0.758±0.013   5.414±0.097   0.631±0.026    6.017±0.349   0.569±0.016    5.252±0.128
3.3 Ablation Studies
Network Structures. We first conduct an ablation analysis on the network components. As shown in Table 3, the spatial-spectral configuration of our S2ME outperforms its single-domain counterparts trained with ME, confirming the value of exploiting cross-domain features.

Table 3. Ablation comparison of dual-branch network architectures. Results are from the outputs of Model-1 on the SUN-SEG [10] dataset.

Model-1    Model-2    Method        DSC ↑           IoU ↑           Prec ↑          HD ↓
UNet [20]  UNet [20]  ME (Ours)     0.666 ± 0.002   0.557 ± 0.002   0.715 ± 0.008   4.684 ± 0.034
YNet [6]   YNet [6]   ME (Ours)     0.648 ± 0.004   0.538 ± 0.005   0.695 ± 0.004   4.743 ± 0.006
UNet [20]  YNet [6]   S2ME (Ours)   0.674 ± 0.003   0.565 ± 0.001   0.719 ± 0.003   4.583 ± 0.014
Pseudo Label Fusion Strategies. To ensure the reliability of the mixed pseudo labels for ensemble learning, we use a pixel-level adaptive fusion strategy based on the entropy maps of the dual predictions, balancing the strengths and weaknesses of the spatial and spectral branches. As shown in Table 4, this strategy outperforms two image-level fusion strategies, i.e., random [15] and equal mixing.

Hybrid Loss Supervision. We decompose the proposed hybrid loss Lhybrid in Eq. (9) to demonstrate the effectiveness of holistic supervision from scribbles, mutual teaching, and ensemble learning. As shown in Table 5, the full hybrid loss, combining Lscrib, Lmt, and Lel, achieves the best results.

Table 4. Ablation on the pseudo label fusion strategies on the SUN-SEG [10] dataset.

Fusion Strategy   Level   DSC ↑           HD ↓
Random [15]       Image   0.665 ± 0.008   4.750 ± 0.169
Equal (0.5)       Image   0.667 ± 0.001   4.602 ± 0.013
Entropy (Ours)    Pixel   0.674 ± 0.003   4.583 ± 0.014

Table 5. Ablation study on the loss components on the SUN-SEG [10] dataset.

Lscrib   Lmt   Lel   DSC ↑           HD ↓
✓        ✗     ✗     0.627 ± 0.004   5.580 ± 0.112
✓        ✓     ✗     0.668 ± 0.007   4.782 ± 0.020
✓        ✗     ✓     0.662 ± 0.004   4.797 ± 0.146
✓        ✓     ✓     0.674 ± 0.003   4.583 ± 0.014
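The difference between pixel-level entropy-guided fusion and image-level equal mixing (Table 4) can be illustrated with a short sketch. The exact fusion rule is defined earlier in the paper and is not reproduced in this section, so the weighting used here is an assumption:

- Each branch's per-pixel weight is proportional to exp(−entropy).
- The more confident (lower-entropy) branch therefore dominates the mixed pseudo label at that pixel.

Pixels are flattened into a list of class-probability vectors, and all function names are hypothetical.

```python
import math
from typing import List, Sequence

Dist = Sequence[float]  # one pixel's class probabilities (e.g. a softmax output)


def entropy(dist: Dist, eps: float = 1e-8) -> float:
    """Shannon entropy (in nats) of one pixel's class distribution."""
    return -sum(p * math.log(p + eps) for p in dist)


def entropy_fusion(probs_a: List[Dist], probs_b: List[Dist]) -> List[List[float]]:
    """Pixel-wise convex mix of two branches, each weighted by exp(-entropy)."""
    fused = []
    for pa, pb in zip(probs_a, probs_b):
        wa, wb = math.exp(-entropy(pa)), math.exp(-entropy(pb))
        fused.append([(wa * a + wb * b) / (wa + wb) for a, b in zip(pa, pb)])
    return fused


def equal_fusion(probs_a: List[Dist], probs_b: List[Dist]) -> List[List[float]]:
    """Image-level equal (0.5) mixing, shown for comparison."""
    return [[0.5 * a + 0.5 * b for a, b in zip(pa, pb)] for pa, pb in zip(probs_a, probs_b)]


def to_pseudo_label(fused: List[List[float]]) -> List[int]:
    """Hard pseudo label per pixel: argmax over classes."""
    return [max(range(len(d)), key=d.__getitem__) for d in fused]
```

Because the weights are normalized per pixel, the fused vector remains a valid distribution. When both branches agree, it reduces to their common prediction. When one branch is uncertain, it is down-weighted only at that pixel instead of across the whole image.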
4 Conclusion
To the best of our knowledge, this is the first spatial-spectral dual-branch network for weakly-supervised medical image segmentation. It efficiently leverages cross-domain patterns through collaborative mutual teaching and ensemble learning. Our pixel-level entropy-guided fusion strategy improves the reliability of the aggregated pseudo labels, which provide valuable supplementary supervision signals. Moreover, we optimize the segmentation model holistically with a hybrid loss that combines supervision from scribbles and pseudo labels, which improves performance. In extensive in-distribution and out-of-distribution evaluations on four public datasets, our method shows superior accuracy, generalization, and robustness. This indicates its clinical potential for alleviating data-related issues, such as data shift and corruption, that are common in the medical field. Future work could apply our approach to other annotation-efficient learning settings such as semi-supervised learning, to other sparse annotations such as points, and to more medical applications.

Acknowledgements. This work was supported by Hong Kong Research Grants Council (RGC) Collaborative Research Fund (CRF C4063-18G), the Shun Hing Institute of Advanced Engineering (SHIAE project BME-p1-21) at the Chinese University of Hong Kong (CUHK), General Research Fund (GRF 14203323), Shenzhen-Hong Kong-Macau Technology Research Programme (Type C) STIC Grant SGDX20210823103535014 (202108233000303), and (GRS) #3110167.
References
1. Ali, S., et al.: A multi-centre polyp detection and segmentation dataset for generalisability assessment. Sci. Data 10(1), 75 (2023)
2. Bernal, J., Sánchez, F.J., Fernández-Esparrach, G., Gil, D., Rodríguez, C., Vilariño, F.: WM-DOVA maps for accurate polyp highlighting in colonoscopy: validation vs. saliency maps from physicians. Comput. Med. Imaging Graph. 43, 99–111 (2015)
3. Chen, X., Yuan, Y., Zeng, G., Wang, J.: Semi-supervised semantic segmentation with cross pseudo supervision. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2613–2622 (2021)
4. Chi, L., Jiang, B., Mu, Y.: Fast Fourier convolution. Adv. Neural. Inf. Process. Syst. 33, 4479–4488 (2020)
5. Cinbis, R.G., Verbeek, J., Schmid, C.: Weakly supervised object localization with multi-fold multiple instance learning. IEEE Trans. Pattern Anal. Mach. Intell. 39(1), 189–203 (2016)
6. Farshad, A., Yeganeh, Y., Gehlbach, P., Navab, N.: Y-net: a spatiospectral dual-encoder network for medical image segmentation. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13432, pp. 582–592. Springer, Cham (2022)
7. Grandvalet, Y., Bengio, Y.: Semi-supervised learning by entropy minimization. In: Advances in Neural Information Processing Systems, vol. 17 (2004)
8. He, X., Fang, L., Tan, M., Chen, X.: Intra- and inter-slice contrastive learning for point supervised OCT fluid segmentation. IEEE Trans. Image Process. 31, 1870–1881 (2022)
9. Jha, D., et al.: Kvasir-SEG: a segmented polyp dataset. In: Ro, Y.M., et al. (eds.) MMM 2020. LNCS, vol. 11962, pp. 451–462. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-37734-2_37
10. Ji, G.P., et al.: Video polyp segmentation: a deep learning perspective. Mach. Intell. Res. 1–19 (2022)
11. Lee, H., Jeong, W.-K.: Scribble2Label: scribble-supervised cell segmentation via self-generating pseudo-labels with consistency. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12261, pp. 14–23. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59710-8_2
12. Li, Y., Xue, Y., Li, L., Zhang, X., Qian, X.: Domain adaptive box-supervised instance segmentation network for mitosis detection. IEEE Trans. Med. Imaging 41(9), 2469–2485 (2022)
13. Lin, D., Dai, J., Jia, J., He, K., Sun, J.: ScribbleSup: scribble-supervised convolutional networks for semantic segmentation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3159–3167 (2016)
14. Liu, X., et al.: Weakly supervised segmentation of COVID19 infection with scribble annotation on CT images. Pattern Recogn. 122, 108341 (2022)
15. Luo, X., et al.: Scribble-supervised medical image segmentation via dual-branch network and dynamically mixed pseudo labels supervision. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13431, pp. 528–538. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16431-6_50
16. Misawa, M., et al.: Development of a computer-aided detection system for colonoscopy and a publicly accessible large colonoscopy video database (with video). Gastrointest. Endosc. 93(4), 960–967 (2021)
17. Obukhov, A., Georgoulis, S., Dai, D., Van Gool, L.: Gated CRF loss for weakly supervised semantic image segmentation. arXiv preprint arXiv:1906.04651 (2019)
18. Paszke, A., et al.: Automatic differentiation in PyTorch. In: NIPS-W (2017)
19. Rao, Y., Zhao, W., Zhu, Z., Lu, J., Zhou, J.: Global filter networks for image classification. Adv. Neural. Inf. Process. Syst. 34, 980–993 (2021)
20. Ronneberger, O., Fischer, P., Brox, T.: U-net: convolutional networks for biomedical image segmentation (2015)
21. Valvano, G., Leo, A., Tsaftaris, S.A.: Learning to segment from scribbles using multi-scale adversarial attention gates. IEEE Trans. Med. Imaging 40(8), 1990–2001 (2021)
22. Wang, A., Islam, M., Xu, M., Ren, H.: Rethinking surgical instrument segmentation: a background image can be all you need. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13437, pp. 355–364. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16449-1_34
23. Wang, K.N., et al.: FFCNet: Fourier transform-based frequency learning and complex convolutional network for colon disease classification. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13433, pp. 78–87. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16437-8_8
24. Wang, Y., et al.: FreeMatch: self-adaptive thresholding for semi-supervised learning. arXiv preprint arXiv:2205.07246 (2022)
25. Wu, Y., Xu, M., Ge, Z., Cai, J., Zhang, L.: Semi-supervised left atrium segmentation with mutual consistency training. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12902, pp. 297–306. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87196-3_28
26. Xu, K., Qin, M., Sun, F., Wang, Y., Chen, Y.K., Ren, F.: Learning in the frequency domain. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1740–1749 (2020)
27. Xu, M., Islam, M., Lim, C.M., Ren, H.: Class-incremental domain adaptation with smoothing and calibration for surgical report generation. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12904, pp. 269–278. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87202-1_26
28. Xu, M., Islam, M., Lim, C.M., Ren, H.: Learning domain adaptation with model calibration for surgical report generation in robotic surgery. In: 2021 IEEE International Conference on Robotics and Automation (ICRA), pp. 12350–12356. IEEE (2021)
29. Zhang, K., Zhuang, X.: CycleMix: a holistic strategy for medical image segmentation from scribble supervision. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 11656–11665 (2022)
30. Zhang, K., Zhuang, X.: ShapePU: a new PU learning framework regularized by global consistency for scribble supervised cardiac segmentation. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13438, pp. 162–172. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16452-1_16
Modularity-Constrained Dynamic Representation Learning for Interpretable Brain Disorder Analysis with Functional MRI
Qianqian Wang¹, Mengqi Wu¹, Yuqi Fang¹, Wei Wang², Lishan Qiao³(✉), and Mingxia Liu¹(✉)
¹ Department of Radiology and BRIC, University of North Carolina at Chapel Hill, Chapel Hill, NC 27599, USA
² Department of Radiology, Beijing Youan Hospital, Capital Medical University, Beijing 100069, China
³ School of Mathematics Science, Liaocheng University, Shandong 252000, China
Abstract. Resting-state functional MRI (rs-fMRI) is increasingly used to detect altered functional connectivity patterns caused by brain disorders, thereby facilitating objective quantification of brain pathology. Existing studies typically extract fMRI features using various machine/deep learning methods, but the resulting imaging biomarkers are often difficult to interpret. Besides, the brain operates as a modular system with many cognitive/topological modules. Each module contains a subset of densely inter-connected regions-of-interest (ROIs) that are sparsely connected to ROIs in other modules. However, current methods cannot effectively characterize brain modularity. This paper proposes a modularity-constrained dynamic representation learning (MDRL) framework for interpretable brain disorder analysis with rs-fMRI. The MDRL consists of 3 parts: (1) dynamic graph construction, (2) a modularity-constrained spatiotemporal graph neural network (MSGNN) for dynamic feature learning, and (3) prediction and biomarker detection. In particular, the MSGNN is designed to learn spatiotemporal dynamic representations of fMRI, constrained by 3 functional modules (i.e., central executive network, salience network, and default mode network). To enhance the discriminative ability of the learned features, we encourage the MSGNN to reconstruct the network topology of the input graphs. Experimental results on two public datasets and one private dataset, with a total of 1,155 subjects, validate that our MDRL outperforms several state-of-the-art methods in fMRI-based brain disorder analysis. The detected fMRI biomarkers have good explainability and could potentially be used to improve clinical diagnosis.

Supplementary Information. The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_5.

© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 46–56, 2023.
https://doi.org/10.1007/978-3-031-43907-0_5
Keywords: Functional MRI · Modularity · Biomarker · Brain disorder

1 Introduction
Resting-state functional magnetic resonance imaging (rs-fMRI) has been increasingly used to help understand the pathological mechanisms of neurological disorders by revealing abnormal or dysfunctional brain connectivity patterns [1–4]. Brain regions-of-interest (ROIs) or functional connectivity (FC) involved in these patterns can serve as potential biomarkers for the objective quantification of brain pathology [5]. Previous studies have designed various machine and deep learning models to extract fMRI features and explore disease-related imaging biomarkers [6,7]. However, due to the complexity of brain organization and the black-box nature of many learning-based models, the resulting biomarkers are usually difficult to interpret, which limits their utility in clinical practice [8,9].
Fig. 1. Illustration of our modularity-constrained dynamic representation learning (MDRL) framework, with 3 components: (1) dynamic graph construction via sliding windows, (2) modularity-constrained spatiotemporal graph neural network (MSGNN) for dynamic representation learning, and (3) prediction and biomarker detection. The MSGNN is designed to learn spatiotemporal features via GIN and transformer layers, constrained by 3 neurocognitive modules (i.e., central executive network, salience network, and default mode network) and graph topology reconstruction.
On the other hand, the human brain operates as a modular system, where each module contains a set of ROIs that are densely connected within the module but sparsely connected to ROIs in other modules [10,11]. In particular, the central executive network (CEN), salience network (SN), and default mode network (DMN) are three prominent resting-state neurocognitive modules in the
Q. Wang et al.
brain, supporting efficient cognition [12]. Unfortunately, existing learning-based fMRI studies usually ignore such inherent modular brain structures [13,14]. To this end, we propose a modularity-constrained dynamic representation learning (MDRL) framework for interpretable brain disorder analysis with rs-fMRI. As shown in Fig. 1, the MDRL consists of three components:
(1) dynamic graph construction;
(2) a modularity-constrained spatiotemporal graph neural network (MSGNN) for dynamic graph representation learning;
(3) prediction and biomarker detection.
The MSGNN learns spatiotemporal features via graph isomorphism network and transformer layers, constrained by 3 neurocognitive modules (i.e., central executive network, salience network, and default mode network). To enhance the discriminative ability of the learned fMRI embeddings, we also encourage the MSGNN to reconstruct the topology of the input graphs. To our knowledge, this is among the first attempts to incorporate a modularity prior into graph learning models for fMRI-based brain disorder analysis. Experimental results on two public datasets and one private dataset validate the effectiveness of the MDRL in detecting three brain disorders with rs-fMRI data.
2 Materials and Methodology
2.1 Subjects and Image Preprocessing
Two public datasets (i.e., ABIDE [15] and MDD [16]) and one private HIV-associated neurocognitive disorder (HAND) dataset with rs-fMRI are used. The two largest sites of ABIDE (i.e., NYU and UM) include 79 patients with autism spectrum disorder (ASD) and 105 healthy controls (HCs), and 68 ASDs and 77 HCs, respectively. The two largest sites of MDD (i.e., Site 20 and Site 21) contain 282 patients with major depressive disorder (MDD) and 251 HCs, and 86 MDDs and 70 HCs, respectively. The HAND dataset was collected at Beijing YouAn Hospital and contains 67 asymptomatic neurocognitive impairment patients (ANIs) with HIV and 70 HCs. Demographics of the subjects are reported in the Supplementary Materials. All rs-fMRI data were preprocessed using the Data Processing Assistant for Resting-State fMRI (DPARSF) pipeline [17]. The major steps are:
(1) magnetization equilibrium by trimming the first 10 volumes;
(2) slice timing correction and head motion correction;
(3) regression of nuisance covariates (e.g., white matter signals, ventricle, and head motion parameters);
(4) spatial normalization to the MNI space;
(5) bandpass filtering (0.01–0.10 Hz).
For each subject, the average rs-fMRI time series of each of the 116 ROIs defined by the AAL atlas is extracted.

2.2 Proposed Method
As shown in Fig. 1, the proposed MDRL consists of (1) dynamic graph construction via sliding windows, (2) MSGNN for dynamic graph representation learning, and (3) prediction and biomarker detection, with details introduced below.
Dynamic Graph Construction. Considering that brain functional connectivity (FC) patterns change dynamically over time [18], we first construct dynamic networks/graphs for each subject using a sliding-window strategy. Denote the original fMRI time series as $S \in \mathbb{R}^{N \times M}$, where $N$ is the number of ROIs and $M$ is the number of time points of the blood-oxygen-level-dependent (BOLD) signals in rs-fMRI. We first divide the original time series into $T$ segments along the temporal dimension via overlapping sliding windows, with window size $\Gamma$ and step size $\tau$. Then, for each of the $T$ segments, we construct an FC network by calculating Pearson correlation (PC) coefficients between the time series of each pair of ROIs, denoted as $X_t \in \mathbb{R}^{N \times N}$ ($t = 1, \cdots, T$). The original feature of the $j$-th node is the $j$-th row of $X_t$ for segment $t$. Since a fully connected FC network may include noisy or redundant information, we retain the top 30% strongest edges in each FC network to generate an adjacency matrix $A_t \in \{0, 1\}^{N \times N}$ for segment $t$. Thus, the dynamic graph sequence of each subject can be described as $G_t = \{A_t, X_t\}$ ($t = 1, \cdots, T$).

Modularity-Constrained Spatiotemporal GNN. With the dynamic graph sequence $\{G_t\}_{t=1}^{T}$ as input, we design a modularity-constrained spatiotemporal graph neural network (MSGNN) to learn interpretable and discriminative graph embeddings, with two unique constraints: 1) a modularity constraint, and 2) a graph topology reconstruction constraint. In MSGNN, we first stack two graph isomorphism network (GIN) layers [19] for node-level feature learning. The node-level embedding $H_t$ at segment $t$ learned by the GIN layers is formulated as:

$$H_t = \psi\Big(\big(\epsilon^{(1)} I + A_t\big)\, \psi\big((\epsilon^{(0)} I + A_t)\, X_t W^{(0)}\big)\, W^{(1)}\Big) \tag{1}$$
where $\psi$ is a nonlinear activation, $\epsilon^{(i)}$ is a parameter of the $i$-th GIN layer, $I$ is an identity matrix, and $W^{(i)}$ is the weight of the fully connected layers in GIN.

1) Modularity Constraint. The central executive network (CEN), salience network (SN), and default mode network (DMN) have been shown to be three crucial neurocognitive modules in the brain. They have been consistently observed across different individuals and experimental paradigms [10–12]:
- CEN performs high-level cognitive tasks (e.g., decision-making and rule-based problem-solving);
- SN mainly detects external stimuli and coordinates brain neural resources;
- DMN is responsible for self-related cognitive functions.
The ROIs/nodes within a module are densely inter-connected, resulting in a high degree of clustering between nodes from the same module. Based on this prior knowledge and clinical experience, we assume that the learned embeddings of nodes within the same neurocognitive module tend to be similar. We therefore develop a modularity constraint that encourages similarity between paired node-level embeddings in the same module. Mathematically, the proposed modularity constraint is formulated as:

$$\mathcal{L}_M = -\sum_{t=1}^{T} \sum_{k=1}^{K} \sum_{i,j=1}^{N_k} \frac{h_i^{t,k} \cdot h_j^{t,k}}{\lVert h_i^{t,k} \rVert \, \lVert h_j^{t,k} \rVert} \tag{2}$$
where $h_i^{t,k}$ and $h_j^{t,k}$ are the embeddings of two nodes in the $k$-th module (with $N_k$ ROIs) at segment $t$, and $K$ is the number of modules ($K = 3$ in this work). With Eq. (2), we encourage the MSGNN to focus on modular brain structures during representation learning, thus improving the discriminative ability of fMRI features.

2) Graph Topology Reconstruction Constraint. To further enhance the discriminative ability of the learned embeddings, we propose to preserve graph topology by reconstructing the adjacency matrices. For segment $t$, its adjacency matrix $A_t$ can be reconstructed as $\hat{A}_t = \sigma(H_t H_t^{\top})$, where $\sigma$ is a nonlinear mapping function. The graph topology reconstruction constraint is then formulated as:

$$\mathcal{L}_R = \sum_{t=1}^{T} \Psi(A_t, \hat{A}_t) \tag{3}$$
where $\Psi$ is a cross-entropy loss function. We then apply a SERO operation [20] to generate graph-level embeddings from node-level embeddings, formulated as $h_t = H_t^{\top} \Phi\big(P^{(2)} \sigma(P^{(1)} H_t^{\top} \phi_{\mathrm{mean}})\big)$, where $\Phi$ is a sigmoid function, $P^{(1)}$ and $P^{(2)}$ are learnable weight matrices, and $\phi_{\mathrm{mean}}$ is the averaging operation.

3) Temporal Feature Learning. To further capture temporal information, a single-head transformer fuses the features derived from the $T$ segments, using self-attention to model temporal dynamics across segments. We then sum the learned features $\{h_i\}_{i=1}^{T}$ to obtain the whole-graph embedding.

Prediction and Biomarker Detection. The whole-graph embedding is fed into a fully connected layer with Softmax for prediction, with the final loss defined as:

$$\mathcal{L} = \mathcal{L}_C + \lambda_1 \mathcal{L}_R + \lambda_2 \mathcal{L}_M \tag{4}$$
where $\mathcal{L}_C$ is a cross-entropy loss for prediction, and $\lambda_1$ and $\lambda_2$ are two hyperparameters. To facilitate interpretation of the learned graph embeddings, we calculate PC coefficients between paired node embeddings for each segment and average them across segments to obtain an FC network for each subject. The upper triangle of each FC network is flattened into a vector, and Lasso [21] (with default parameters) is used to select discriminative features. Finally, we map these features back to the original feature space to detect disease-related ROIs and FCs.

Implementation. The MDRL is implemented in PyTorch and trained using the Adam optimizer with a learning rate of 0.001, 30 training epochs, a batch size of 8, and $\tau = 20$. We set the window size $\Gamma = 40$ for NYU and $\Gamma = 70$ for the other sites; results of MDRL with different $\Gamma$ values are shown in the Supplementary Materials. In the modularity constraint, we randomly select $m = 50\%$ of all $N_k(N_k - 1)/2$ paired ROIs in the $k$-th module (with $N_k$ ROIs) to constrain the MDRL.
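The dynamic graph construction and the modularity term can be sketched in a few lines of plain Python. The sketch is illustrative and not the authors' code. Several assumptions are not fixed by the text:

- The "strongest" edges are ranked by absolute correlation.
- Self-loops are excluded, so the 30% ratio applies to the N(N − 1)/2 off-diagonal pairs.
- Windows are taken without padding, so T = ⌊(M − Γ)/τ⌋ + 1.
- The m = 50% node pairs are sampled once per module (with a fixed seed) and shared across segments.

A real implementation would work on PyTorch tensors and backpropagate through the GIN embeddings. All function names are hypothetical.

```python
import math
import random
from itertools import combinations
from typing import List, Sequence

Matrix = List[List[float]]


def pearson(x: Sequence[float], y: Sequence[float]) -> float:
    """Pearson correlation coefficient; returns 0.0 if either signal is constant."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxy = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sxx = sum((a - mx) ** 2 for a in x)
    syy = sum((b - my) ** 2 for b in y)
    if sxx == 0 or syy == 0:
        return 0.0
    return sxy / math.sqrt(sxx * syy)


def sliding_windows(series: Matrix, window: int, step: int) -> List[Matrix]:
    """Split an N x M ROI time-series matrix into T overlapping N x window segments."""
    m = len(series[0])
    return [[row[s:s + window] for row in series]
            for s in range(0, m - window + 1, step)]


def fc_network(segment: Matrix) -> Matrix:
    """N x N Pearson FC matrix X_t; row j serves as the feature of node j."""
    n = len(segment)
    return [[pearson(segment[i], segment[j]) for j in range(n)] for i in range(n)]


def top_ratio_adjacency(fc: Matrix, keep_ratio: float = 0.3) -> List[List[int]]:
    """Binary symmetric A_t keeping the strongest |r| off-diagonal edges."""
    n = len(fc)
    edges = sorted(((abs(fc[i][j]), i, j) for i, j in combinations(range(n), 2)),
                   reverse=True)
    adj = [[0] * n for _ in range(n)]
    for _, i, j in edges[:round(keep_ratio * len(edges))]:
        adj[i][j] = adj[j][i] = 1
    return adj


def dynamic_graphs(series: Matrix, window: int, step: int, keep_ratio: float = 0.3):
    """Return the dynamic graph sequence [(A_t, X_t)] for t = 1..T."""
    graphs = []
    for segment in sliding_windows(series, window, step):
        x_t = fc_network(segment)
        graphs.append((top_ratio_adjacency(x_t, keep_ratio), x_t))
    return graphs


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return sum(a * b for a, b in zip(u, v)) / (nu * nv) if nu and nv else 0.0


def modularity_loss(embeddings: List[Matrix], modules: List[List[int]],
                    ratio: float = 0.5, seed: int = 0) -> float:
    """Negative summed cosine similarity of sampled same-module node pairs (cf. Eq. (2))."""
    rng = random.Random(seed)
    loss = 0.0
    for module in modules:                              # K modules (CEN, SN, DMN)
        pairs = list(combinations(module, 2))           # N_k (N_k - 1) / 2 pairs
        chosen = rng.sample(pairs, round(ratio * len(pairs)))
        for h_t in embeddings:                          # one N x D matrix per segment
            loss -= sum(cosine(h_t[i], h_t[j]) for i, j in chosen)
    return loss
```

With Γ = 40 and τ = 20 (the NYU setting), a series of M time points yields ⌊(M − 40)/20⌋ + 1 segments under the no-padding assumption. Sampling pairs once per module keeps the constraint fixed during training. Resampling every iteration would be an equally plausible reading of "randomly select".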
3 Experiment
Competing Methods. We compare the MDRL with 2 shallow methods and 4 state-of-the-art (SOTA) deep models with default architectures:
1) linear SVM with node-level statistics (i.e., degree centrality, clustering coefficient, betweenness centrality, and eigenvector centrality) of FC networks as fMRI features, with each FC network constructed using PC;
2) XGBoost with the same features as SVM;
3) GCN [22];
4) GAT [23];
5) BrainGNN [9];
6) STGCN [18].

Table 1. Results of seven methods on ABIDE.

ASD vs. HC classification on NYU
Method        AUC (%)      ACC (%)     SEN (%)     SPE (%)     BAC (%)
SVM           56.6(2.9)∗   54.8(3.1)   51.5(4.6)   57.9(4.7)   54.7(3.1)
XGBoost       61.9(0.6)∗   63.0(1.6)   48.0(2.7)   75.9(3.7)   61.9(0.6)
GCN           67.5(3.3)∗   63.6(3.1)   51.0(5.1)   73.5(4.7)   62.3(2.8)
GAT           64.9(2.6)∗   60.1(2.6)   53.0(4.9)   66.1(3.2)   59.5(2.8)
BrainGNN      66.9(2.9)∗   63.2(3.2)   57.1(4.8)   68.5(3.0)   62.8(3.2)
STGCN         66.6(0.8)∗   61.5(1.5)   53.6(2.3)   68.4(1.7)   61.0(1.5)
MDRL (Ours)   72.6(1.7)    65.6(2.1)   57.0(2.8)   74.1(3.1)   65.6(1.9)

ASD vs. HC classification on UM
Method        AUC (%)      ACC (%)     SEN (%)     SPE (%)     BAC (%)
SVM           53.6(4.3)∗   53.3(3.6)   50.3(4.5)   56.6(4.7)   53.5(3.8)
XGBoost       58.8(0.8)∗   58.6(1.9)   47.6(3.6)   70.0(3.9)   58.8(0.8)
GCN           66.7(2.6)    60.0(3.0)   54.6(4.4)   66.5(5.2)   60.6(2.6)
GAT           66.5(3.5)∗   60.4(3.1)   56.1(3.1)   65.4(6.9)   60.7(2.8)
BrainGNN      65.9(2.5)    62.7(2.6)   55.5(3.3)   68.1(5.7)   61.8(2.1)
STGCN         64.0(0.1)∗   63.9(0.1)   55.9(1.1)   72.1(0.5)   64.0(0.1)
MDRL (Ours)   67.1(2.3)    64.5(1.4)   55.6(3.9)   72.7(2.5)   64.1(2.4)

'∗' denotes that the results of MDRL and the competing method are statistically significantly different (p < 0.05).
Table 2. Results of seven methods on MDD.

MDD vs. HC classification on Site 20
Method        AUC (%)      ACC (%)     SEN (%)      SPE (%)     BAC (%)
SVM           53.7(2.1)    53.0(2.6)   54.6(3.1)    51.5(1.8)   53.0(2.4)
XGBoost       55.9(1.9)∗   55.9(2.2)   64.4(4.1)    47.4(1.3)   55.9(1.9)
GCN           55.7(2.7)    54.9(2.1)   59.6(4.4)    50.1(4.4)   54.8(2.0)
GAT           57.8(1.3)∗   55.7(1.3)   61.4(5.7)    49.5(4.4)   55.4(1.1)
BrainGNN      56.3(2.4)∗   52.8(2.1)   51.7(7.6)    55.0(8.5)   53.4(2.0)
STGCN         54.2(4.5)∗   54.6(4.1)   56.6(5.2)    52.2(2.9)   54.4(4.0)
MDRL (Ours)   60.9(2.6)    57.4(1.9)   62.2(5.0)    51.6(3.3)   56.9(1.4)

MDD vs. HC classification on Site 21
Method        AUC (%)      ACC (%)     SEN (%)      SPE (%)     BAC (%)
SVM           52.8(3.0)∗   52.7(2.7)   59.0(3.8)    45.8(4.7)   52.4(2.8)
XGBoost       52.0(2.3)∗   52.5(2.7)   66.2(4.7)    37.8(4.6)   52.0(2.3)
GCN           54.8(3.1)∗   54.0(3.0)   60.9(6.0)    46.6(7.5)   53.8(3.3)
GAT           53.2(3.1)∗   52.8(2.4)   62.0(4.7)    42.7(7.1)   52.3(3.4)
BrainGNN      53.9(3.3)    53.5(8.6)   58.4(12.8)   45.7(2.0)   52.1(5.4)
STGCN         54.9(0.3)∗   53.4(0.1)   61.4(1.8)    44.2(3.7)   52.7(0.9)
MDRL (Ours)   56.6(9.4)    55.2(7.8)   58.1(10.2)   51.3(7.1)   54.6(7.8)

'∗' denotes that the results of MDRL and the competing method are statistically significantly different (p < 0.05).
Experimental Setting. Three classification tasks are performed: 1) ASD vs. HC on ABIDE, 2) MDD vs. HC on MDD, and 3) ANI vs. HC on HAND. A 5-fold cross-validation (CV) strategy is employed, and within each fold an inner 5-fold CV is performed to select optimal parameters. Five evaluation metrics are used: area under the ROC curve (AUC), accuracy (ACC), sensitivity (SEN), specificity (SPE), and balanced accuracy (BAC). A paired-sample t-test is performed to evaluate whether the MDRL is significantly different from each competing method.

Classification Results. Results achieved by different methods in the three classification tasks are reported in Tables 1–2 and Fig. 2. Our MDRL generally outperforms the two shallow methods (i.e., SVM and XGBoost), which rely on handcrafted node features and do not model whole-graph topological information. Compared with the 4 SOTA deep learning methods, our MDRL achieves superior performance on most metrics in the three tasks. For instance, for ASD vs. HC classification on NYU of ABIDE (see Table 1), the AUC of MDRL is 5.7 percentage points higher than that of BrainGNN (a SOTA method designed for brain network analysis). This suggests that the MDRL can learn discriminative graph representations that boost fMRI-based learning performance.

Ablation Study. We compare the proposed MDRL with three degenerated variants: 1) MDRLw/oM without the modularity constraint, 2) MDRLw/oR without the graph topology reconstruction constraint, and 3) MDRLw/oMR without both constraints. The results in Fig. 3 show that MDRL is superior to all three variants, verifying the effectiveness of the two constraints defined in Eqs. (2)–(3). Besides, MDRLw/oM is generally inferior to MDRLw/oR in the three tasks, implying that the modularity constraint may contribute more to MDRL than the graph reconstruction constraint.
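The five evaluation metrics and the paired t statistic used for the significance marks can be written out directly. The sketch assumes the following:

- Labels are binary, with the patient group as the positive class (1).
- AUC uses the classifier's positive-class scores and the pairwise (Mann-Whitney) formulation.
- Which per-run results are paired in the t-test is not specified here, so that choice is left to the caller.

Because BAC is the mean of SEN and SPE, the tables can be sanity-checked by hand. For SVM on NYU, (51.5 + 57.9)/2 = 54.7. Converting t to a p-value needs the Student t distribution with n − 1 degrees of freedom, for example via `scipy.stats.ttest_rel`. The helper names are hypothetical.

```python
import math
import statistics
from typing import Dict, Sequence


def roc_auc(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """AUC as the probability that a random positive outranks a random negative (ties count 0.5)."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def classification_metrics(y_true: Sequence[int], y_pred: Sequence[int],
                           scores: Sequence[float]) -> Dict[str, float]:
    """AUC, ACC, SEN, SPE, and BAC for binary labels (1 = patient, 0 = healthy control)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    sen = tp / (tp + fn)  # sensitivity: recall on patients
    spe = tn / (tn + fp)  # specificity: recall on controls
    return {
        "AUC": roc_auc(y_true, scores),
        "ACC": (tp + tn) / len(pairs),
        "SEN": sen,
        "SPE": spe,
        "BAC": (sen + spe) / 2,
    }


def paired_t_statistic(a: Sequence[float], b: Sequence[float]) -> float:
    """t = mean(d) / (sd(d) / sqrt(n)) for paired differences d = a - b."""
    d = [x - y for x, y in zip(a, b)]
    return statistics.mean(d) / (statistics.stdev(d) / math.sqrt(len(d)))
```

BAC is reported alongside ACC because the classes are imbalanced (e.g., 79 ASDs vs. 105 HCs on NYU). BAC weights both groups equally, whereas ACC is dominated by the larger class.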
Fig. 2. Results of seven methods in ANI vs. HC classification on HAND.
Fig. 3. Performance of the MDRL and its variants in three tasks on three datasets.
Fig. 4. Visualization of the top 10 most discriminative functional connectivities identified by the MDRL in identifying 3 diseases on 3 datasets (with AAL for ROI partition).
Discriminative ROIs and Functional Connectivities. The top 10 discriminative FCs detected by the MDRL in the three tasks are shown in Fig. 4. The thickness of each line represents its discriminative power, which is proportional to the corresponding Lasso coefficient.
- ASD identification (see Fig. 4(a)): FCs involving the thalamus and middle temporal gyrus are frequently identified, which is consistent with previous findings [24,25].
- MDD detection (see Fig. 4(b)): several ROIs (e.g., hippocampus, supplementary motor area, and thalamus) are highly associated with MDD identification, in line with previous studies [26–28].
- ANI identification (see Fig. 4(c)): the detected ROIs, such as the amygdala, right temporal pole: superior temporal gyrus, and parahippocampal gyrus, have also been reported in previous research [29–31].
These findings further support the effectiveness of the MDRL in detecting disease-associated biomarkers.
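Mapping the selected Lasso features back to brain connections is mostly index bookkeeping. The flattened upper triangle of an N × N FC matrix has N(N − 1)/2 entries (6,670 for the 116 AAL ROIs), and each entry corresponds to one ROI pair. The sketch below makes three assumptions:

- The strict upper triangle is flattened row by row.
- The "top 10" connections are ranked by absolute coefficient.
- The coefficients come from a fitted Lasso model, such as the `coef_` attribute of scikit-learn's `Lasso`.

The helper names are hypothetical.

```python
from itertools import combinations
from typing import List, Sequence, Tuple


def upper_triangle(matrix: Sequence[Sequence[float]]) -> List[float]:
    """Flatten the strict upper triangle row by row: (0,1), (0,2), ..., (N-2,N-1)."""
    return [matrix[i][j] for i, j in combinations(range(len(matrix)), 2)]


def top_connections(coefs: Sequence[float], n_rois: int,
                    k: int = 10) -> List[Tuple[int, int, float]]:
    """Return (roi_i, roi_j, coef) for the k largest-|coef| non-zero Lasso weights."""
    pairs = list(combinations(range(n_rois), 2))  # same ordering as upper_triangle
    if len(pairs) != len(coefs):
        raise ValueError("expected n_rois * (n_rois - 1) / 2 coefficients")
    ranked = sorted(zip(coefs, pairs), key=lambda cp: abs(cp[0]), reverse=True)
    return [(i, j, c) for c, (i, j) in ranked[:k] if c != 0]
```

Because Lasso drives most coefficients to exactly zero, fewer than k connections may be returned. The ROI indices then map to AAL region names for visualization.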
Fig. 5. Results of our MDRL with different hyperparameters (i.e., λ1 and λ2 ).
Fig. 6. Results of the proposed MDRL with different modularity ratios.
4 Discussion
Parameter Analysis. To investigate the influence of hyperparameters, we vary the values of the two parameters λ1 and λ2 in Eq. (4) and report the results of MDRL in Fig. 5. Fig. 5 shows that, with λ1 fixed, the performance of MDRL fluctuates only slightly as λ2 increases, implying that MDRL is not very sensitive to λ2 in the three tasks. With λ2 fixed, MDRL with a large λ1 (e.g., λ1 = 10) performs worse. A possible reason is that a strong graph reconstruction constraint makes the model difficult to converge, thus degrading its learning performance.
Q. Wang et al.
Influence of Modularity Ratio. In the main experiments, we randomly select m = 50% of all $N_k(N_k-1)/2$ ROI pairs in the k-th module (with $N_k$ ROIs) to constrain MDRL. We now vary the modularity ratio m within {0%, 25%, ..., 100%} and report the results of MDRL in the three tasks in Fig. 6. Fig. 6 shows that, when m < 75%, the ACC and AUC results generally increase as m increases. However, with a large modularity ratio (e.g., m = 100%), MDRL does not achieve satisfactory results. This may be due to the oversmoothing problem caused by a too-strong modularity constraint.
Influence of Network Construction. We use PC to construct the original FC networks in MDRL. We also use sparse representation (SR) and low-rank representation (LR) for network construction in MDRL and report the results in Table 3. Table 3 shows that MDRL with PC outperforms its two variants. The underlying reason could be that PC can model dependencies among regional BOLD signals without discarding any connection information.
Table 3. Results of the MDRL with different FC network construction strategies.
Method | ASD vs. HC on NYU of ABIDE: AUC (%) / ACC (%) / BAC (%) | MDD vs. HC on Site 20 of MDD: AUC (%) / ACC (%) / BAC (%) | ANI vs. HC on HAND: AUC (%) / ACC (%) / BAC (%)
MDRL_LR | 62.2(3.5) / 59.2(4.6) / 59.4(4.6) | 54.4(5.0) / 53.5(4.9) / 53.2(4.8) | 58.2(3.3) / 60.6(2.0) / 60.7(1.8)
MDRL_SR | 60.9(1.9) / 62.7(1.9) / 60.8(2.7) | 55.5(7.4) / 52.7(3.8) / 53.2(4.6) | 64.6(4.0) / 61.0(2.2) / 61.2(2.0)
MDRL | 72.6(1.7) / 65.6(2.1) / 65.6(2.0) | 60.9(2.6) / 57.4(1.9) / 56.9(1.4) | 66.2(3.4) / 63.1(1.3) / 63.2(1.3)
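The random pair selection behind the modularity ratio m (see "Influence of Modularity Ratio" above) can be sketched as below. It assumes each module is given as a list of ROI indices and that one fixed random subset is drawn per module; whether the subset is redrawn during training, how the count is rounded, and how the selected pairs enter the modularity constraint are not specified here, so the function name, the floor rounding, and the seeding are hypothetical.

```python
import itertools
import random
from typing import List, Sequence, Tuple


def sample_module_pairs(module_rois: Sequence[int], m: float, seed: int = 0) -> List[Tuple[int, int]]:
    """Randomly select a fraction m of the N_k (N_k - 1) / 2 unordered ROI pairs of one module."""
    if not 0.0 <= m <= 1.0:
        raise ValueError("modularity ratio m must lie in [0, 1]")
    pairs = list(itertools.combinations(module_rois, 2))  # all N_k (N_k - 1) / 2 pairs
    n_keep = int(m * len(pairs))  # rounded down (an assumption)
    return random.Random(seed).sample(pairs, n_keep)


# Hypothetical module with N_k = 10 ROIs (indices are illustrative).
module = list(range(10))
selected = sample_module_pairs(module, m=0.5)
print(len(selected))  # 22 of the 45 pairs
```

Using a dedicated `random.Random(seed)` instance keeps the selection reproducible without touching the global random state.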
5 Conclusion and Future Work
In this work, we propose a modularity-constrained dynamic representation learning (MDRL) framework for fMRI-based brain disorder analysis. We first construct dynamic graphs for each subject and then design a modularity-constrained GNN to learn spatiotemporal representations, followed by prediction and biomarker detection. Experimental results on three rs-fMRI datasets validate the superiority of MDRL in brain disease detection. Currently, we only characterize pairwise relationships of ROIs within three prominent neurocognitive modules (i.e., CEN, SN, and DMN) as prior knowledge to design the modularity constraint in MDRL. In future work, we will consider fine-grained modular structures and disease-specific modularity constraints. We also plan to employ advanced harmonization methods [32] to reduce inter-site variance, making full use of multi-site fMRI data for model training.
Acknowledgment. M. Wu, Y. Fang, and M. Liu were supported by an NIH grant RF1AG073297.
References
1. Pagani, M., et al.: mTOR-related synaptic pathology causes autism spectrum disorder-associated functional hyperconnectivity. Nat. Commun. 12(1), 6084 (2021)
2. Sezer, I., Pizzagalli, D.A., Sacchet, M.D.: Resting-state fMRI functional connectivity and mindfulness in clinical and non-clinical contexts: a review and synthesis. Neurosci. Biobehav. Rev. 104583 (2022)
3. Liu, J., et al.: Astrocyte dysfunction drives abnormal resting-state functional connectivity in depression. Sci. Adv. 8(46), eabo2098 (2022)
4. Sahoo, D., Satterthwaite, T.D., Davatzikos, C.: Hierarchical extraction of functional connectivity components in human brain using resting-state fMRI. IEEE Trans. Med. Imaging 40(3), 940–950 (2020)
5. Traut, N., et al.: Insights from an autism imaging biomarker challenge: promises and threats to biomarker discovery. Neuroimage 255, 119171 (2022)
6. Azevedo, T., et al.: A deep graph neural network architecture for modelling spatiotemporal dynamics in resting-state functional MRI data. Med. Image Anal. 79, 102471 (2022)
7. Bessadok, A., Mahjoub, M.A., Rekik, I.: Graph neural networks in network neuroscience. IEEE Trans. Pattern Anal. Mach. Intell. (2022)
8. Zhang, Z., Xie, Y., Xing, F., McGough, M., Yang, L.: MDNet: a semantically and visually interpretable medical image diagnosis network. In: CVPR, pp. 6428–6436 (2017)
9. Li, X., et al.: BrainGNN: interpretable brain graph neural network for fMRI analysis. Med. Image Anal. 74, 102233 (2021)
10. Sporns, O., Betzel, R.F.: Modular brain networks. Annu. Rev. Psychol. 67, 613–640 (2016)
11. Bertolero, M.A., Yeo, B.T., D’Esposito, M.: The modular and integrative functional architecture of the human brain. Proc. Natl. Acad. Sci. 112(49), E6798–E6807 (2015)
12. Goulden, N., et al.: The salience network is responsible for switching between the default mode network and the central executive network: replication from DCM. Neuroimage 99, 180–190 (2014)
13. Geirhos, R., et al.: Shortcut learning in deep neural networks. Nat. Mach. Intell. 2(11), 665–673 (2020)
14. Knyazev, B., Taylor, G.W., Amer, M.: Understanding attention and generalization in graph neural networks. In: Advances in Neural Information Processing Systems, vol. 32 (2019)
15. Di Martino, A., et al.: The autism brain imaging data exchange: towards a large-scale evaluation of the intrinsic brain architecture in autism. Mol. Psychiatry 19(6), 659–667 (2014)
16. Yan, C.G., et al.: Reduced default mode network functional connectivity in patients with recurrent major depressive disorder. Proc. Natl. Acad. Sci. 116(18), 9078–9083 (2019)
17. Yan, C., Zang, Y.: DPARSF: a MATLAB toolbox for “pipeline” data analysis of resting-state fMRI. Front. Syst. Neurosci. 4, 13 (2010)
18. Gadgil, S., Zhao, Q., Pfefferbaum, A., Sullivan, E.V., Adeli, E., Pohl, K.M.: Spatiotemporal graph convolution for resting-state fMRI analysis. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12267, pp. 528–538. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59728-3_52
19. Kim, B.H., Ye, J.C.: Understanding graph isomorphism network for rs-fMRI functional connectivity analysis. Front. Neurosci. 630 (2020)
20. Hu, J., Shen, L., Sun, G.: Squeeze-and-excitation networks. In: CVPR, pp. 7132–7141 (2018)
21. Tibshirani, R.: Regression shrinkage and selection via the Lasso. J. Roy. Stat. Soc.: Ser. B (Methodol.) 58(1), 267–288 (1996)
22. Kipf, T.N., Welling, M.: Semi-supervised classification with graph convolutional networks. arXiv preprint arXiv:1609.02907 (2016)
23. Veličković, P., Cucurull, G., Casanova, A., Romero, A., Lio, P., Bengio, Y.: Graph attention networks. arXiv preprint arXiv:1710.10903 (2017)
24. Ayub, R., et al.: Thalamocortical connectivity is associated with autism symptoms in high-functioning adults with autism and typically developing adults. Transl. Psychiatry 11(1), 93 (2021)
25. Xu, J., et al.: Specific functional connectivity patterns of middle temporal gyrus subregions in children and adults with autism spectrum disorder. Autism Res. 13(3), 410–422 (2020)
26. MacQueen, G., Frodl, T.: The hippocampus in major depression: evidence for the convergence of the bench and bedside in psychiatric research? Mol. Psychiatry 16(3), 252–264 (2011)
27. Sarkheil, P., Odysseos, P., Bee, I., Zvyagintsev, M., Neuner, I., Mathiak, K.: Functional connectivity of supplementary motor area during finger-tapping in major depression. Compr. Psychiatry 99, 152166 (2020)
28. Batail, J.M., Coloigner, J., Soulas, M., Robert, G., Barillot, C., Drapier, D.: Structural abnormalities associated with poor outcome of a major depressive episode: the role of thalamus. Psychiatry Res. Neuroimaging 305, 111158 (2020)
29. Clark, U.S., et al.: Effects of HIV and early life stress on amygdala morphometry and neurocognitive function. J. Int. Neuropsychol. Soc. 18(4), 657–668 (2012)
30. Zhan, Y., et al.: The resting state central auditory network: a potential marker of HIV-related central nervous system alterations. Ear Hear. 43(4), 1222 (2022)
31. Sarma, M.K., et al.: Regional brain gray and white matter changes in perinatally HIV-infected adolescents. NeuroImage Clin. 4, 29–34 (2014)
32. Guan, H., Liu, M.: Domain adaptation for medical image analysis: a survey. IEEE Trans. Biomed. Eng. 69(3), 1173–1185 (2021)
Anatomy-Driven Pathology Detection on Chest X-rays
Philip Müller¹(✉), Felix Meissen¹, Johannes Brandt¹, Georgios Kaissis¹,², and Daniel Rueckert¹,³
¹ Institute for AI in Medicine, Technical University of Munich, Munich, Germany
² Helmholtz Zentrum Munich, Munich, Germany
³ Department of Computing, Imperial College London, London, UK
Abstract. Pathology detection and delineation enable the automatic interpretation of medical scans such as chest X-rays while providing a high level of explainability to support radiologists in making informed decisions. However, annotating pathology bounding boxes is a time-consuming task, so large public datasets for this purpose are scarce. Current approaches thus use weakly supervised object detection to learn the (rough) localization of pathologies from image-level annotations, which, however, is limited in performance due to the lack of bounding box supervision. We therefore propose anatomy-driven pathology detection (ADPD), which uses easy-to-annotate bounding boxes of anatomical regions as proxies for pathologies. We study two training approaches: supervised training using anatomy-level pathology labels and multiple instance learning (MIL) with image-level pathology labels. Our results show that our anatomy-level training approach outperforms weakly supervised methods and fully supervised detection with limited training samples, and that our MIL approach is competitive with both baseline approaches, demonstrating the potential of our approach.
Keywords: Pathology detection · Anatomical regions · Chest X-rays
1 Introduction
Chest radiographs (chest X-rays) are the most widely used medical imaging examination worldwide and are of great importance for detecting prevalent thoracic diseases, including pneumonia and lung cancer, making them a crucial tool in clinical care [10,15]. Pathology detection and localization (for brevity, we use the term pathology detection throughout this work) enables the automatic interpretation of medical scans such as chest X-rays by predicting bounding boxes for detected pathologies. Unlike classification, which only predicts the presence of pathologies, it provides a high level of explainability that supports radiologists in making informed decisions.
Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_6.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 57–66, 2023.
https://doi.org/10.1007/978-3-031-43907-0_6
P. Müller et al.
However, while image classification labels can be extracted automatically from electronic health records or radiology reports [7,20], this is typically not possible for bounding boxes, which limits the availability of large datasets for pathology detection. Additionally, manually annotating pathology bounding boxes is time-consuming, further exacerbating the issue. The resulting scarcity of large, publicly available datasets with pathology bounding boxes limits the use of supervised methods for pathology detection, so current approaches typically rely on weakly supervised object detection, which requires only classification labels for training. However, as these methods are not guided by any form of bounding boxes, their performance is limited.
We therefore propose a novel approach to pathology detection that uses anatomical region bounding boxes, defined solely on anatomical structures, as proxies for pathology bounding boxes. These region boxes are easier to annotate (the physiological shape of a healthy subject’s thorax can be learned relatively easily by medical students) and generalize better than pathology boxes, so large labeled datasets are available [21]. In summary:
– We propose anatomy-driven pathology detection (ADPD), a pathology detection approach for chest X-rays, trained with pathology classification labels together with anatomical region bounding boxes as proxies for pathologies.
– We study two training approaches: using localized (anatomy-level) pathology labels for our model Loc-ADPD and using image-level labels with multiple instance learning (MIL) for our model MIL-ADPD.
– We train our models on the Chest ImaGenome [21] dataset and evaluate on NIH ChestX-ray 8 [20], where our Loc-ADPD model outperforms both weakly supervised methods and fully supervised detection with a small training set, while our MIL-ADPD model is competitive with supervised detection and slightly outperforms weakly supervised approaches.
2 Related Work
Weakly Supervised Pathology Detection. Due to the scarcity of bounding box annotations, pathology detection on chest X-rays is often tackled using weakly supervised object detection with Class Activation Mapping (CAM) [25], which requires only image-level classification labels. After a classification model with global average pooling (GAP) has been trained, an activation heatmap is computed by applying the trained classifier to each individual patch (extracted before pooling), and this heatmap is then thresholded to predict bounding boxes. Inspired by this approach, several methods have been developed for chest X-rays [6,14,20,23]. While CheXNet [14] follows the original approach, the method provided with the NIH ChestX-ray 8 dataset [20] and the STL method [6] use LogSumExp (LSE) pooling [13], and the MultiMap model [23] uses max-min pooling, as first proposed for the WELDON [3] method. Unlike our method, none of these methods use anatomical regions as proxies for predicting pathology bounding boxes, which leads to inferior performance.
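The CAM-style localization step described above can be sketched as follows. This is a minimal illustration, not the implementation of any of the cited baselines: it assumes a linear classifier (weights and bias) applied to each patch feature vector taken before GAP, thresholds the resulting probability, and returns one box enclosing all activated patches; real methods may use connected components, multiple boxes, or other heatmap post-processing. The function name and threshold are hypothetical.

```python
import math
from typing import List, Optional, Sequence, Tuple

Patch = Sequence[float]  # feature vector of one patch (before global average pooling)


def cam_box(features: List[List[Patch]], weights: Sequence[float], bias: float,
            thr: float = 0.5) -> Optional[Tuple[int, int, int, int]]:
    """Classify every patch with the trained linear classifier, threshold the
    probability heatmap, and return the box (x1, y1, x2, y2) enclosing all activated
    patches in patch-grid coordinates (x2, y2 exclusive), or None if none is active."""
    # sigmoid(z) > thr  <=>  z > log(thr / (1 - thr)); comparing logits avoids exp() overflow
    logit_thr = math.log(thr / (1.0 - thr))
    active = [
        (x, y)
        for y, row in enumerate(features)
        for x, f in enumerate(row)
        if sum(w * v for w, v in zip(weights, f)) + bias > logit_thr
    ]
    if not active:
        return None
    xs = [x for x, _ in active]
    ys = [y for _, y in active]
    return min(xs), min(ys), max(xs) + 1, max(ys) + 1


# 3 x 3 patch grid with a single feature channel: two neighbouring patches respond strongly.
grid = [[[-5.0], [-5.0], [-5.0]],
        [[-5.0], [5.0], [5.0]],
        [[-5.0], [-5.0], [-5.0]]]
print(cam_box(grid, weights=[1.0], bias=0.0))  # (1, 1, 3, 2)
```

The box lives on the patch grid; mapping it back to pixels would multiply by the backbone stride, which is omitted here.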
Anatomy-Driven Pathology Detection on Chest X-rays
[Fig. 1 diagram: Backbone (DenseNet121) → Region Detector (DETR decoder) → region tokens → Pathology Classifier (on regions), used in training and inference → Box Prediction (with region fusion), inference only]
Fig. 1. Overview of our method. Anatomical regions are first detected using a CNN backbone and a shallow detector. For each region, observed pathologies are predicted using a shared classifier. Bounding boxes for each pathology are then predicted by considering regions with positive predictions and fusing overlapping boxes.
Localized Pathology Classification. Anatomy-level pathology labels have been used before to train localized pathology classifiers [1,21] or to improve weakly supervised pathology detection [24]. Along with the Chest ImaGenome dataset [21], several localized pathology classification models have been proposed that use a Faster R-CNN [16] to extract anatomical region features before predicting the observed pathologies for each region, using either a linear model or a GCN model based on pathology co-occurrences. This approach has been further extended to use GCNs on anatomical region relationships [1]. While these methods use the same form of supervision as our method, they do not tackle pathology detection. In AGXNet [24], anatomy-level pathology classification labels are used to train a weakly supervised pathology detection model. However, unlike our method and the other methods described, it does not use anatomical region bounding boxes.
3 Method
3.1 Model
Figure 1 provides an overview of our method. Given a chest X-ray, we apply a DenseNet121 [5] backbone and extract patch-wise features from the feature map after the last convolutional layer (before GAP). We then apply a lightweight object detection model consisting of a single DETR [2] decoder layer to detect anatomical regions. Following [2], we use learned query tokens that attend to the patch features in the decoder layer, where each token corresponds to one predicted bounding box. As no anatomical region can occur more than once in a chest X-ray, each query token is assigned to exactly one predefined anatomical region, so the number of tokens equals the number of anatomical regions. This one-to-one assignment of tokens and regions allows us to remove the Hungarian matching used in [2]. As described next, the resulting per-region features output by the decoder layer are used for predictions on each region. To predict whether the associated region is present, we use a binary classifier with a single linear layer; for bounding box prediction, we use a three-layer MLP followed by a sigmoid. We treat the prediction of observed pathologies as a multi-label binary classification task and use a single linear layer (followed by a sigmoid) to predict the probabilities of all pathologies. Each of these predictors is applied independently to each region, with weights shared across regions. We experimented with more complex pathology predictors, such as an MLP or a transformer layer, but did not observe any benefits. We also did not observe improvements when using several decoder layers, and we observed degraded performance when using ROI pooling to compute region features.
[Fig. 2 diagram: per-region pathology probabilities → Box Prediction (e.g., pneumonia: 0.74, pneumonia: 0.71) → Weighted Box Fusion (pneumonia: 0.72)]
Fig. 2. Inference. For each pathology, the regions with a pathology probability above a threshold are predicted as bounding boxes, which are then fused if they overlap.
3.2 Inference
During inference, the trained model predicts anatomical region bounding boxes and per-region pathology probabilities, which are then used to predict pathology bounding boxes in two steps, as shown in Fig. 2. In step (i), the pathology probabilities are thresholded, and for each positive pathology (with probability above the threshold), the bounding box of the corresponding anatomical region is predicted as its pathology box, with the pathology probability as the box score. This means that if a region contains several predicted pathologies, all of them share the same bounding box in step (i). In step (ii), weighted box fusion (WBF) [19] merges bounding boxes of the same pathology whose IoU exceeds 0.03 and computes weighted averages of their box coordinates, using the box scores as weights. Because many anatomical regions overlap at least partially and the IoU threshold is small, this allows the model either to pull the predicted boxes toward relevant subparts of an anatomical region or to predict that a pathology stretches over several regions.
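The two inference steps for a single pathology can be sketched as follows. This is a simplified, single-model variant of weighted box fusion written from the description above, not the reference WBF implementation of [19] (which, e.g., supports ensembles of several models and rescales fused scores accordingly). The IoU threshold of 0.03 is taken from the text; the probability threshold of step (i) is not stated in this section, so `prob_thr = 0.5` is an assumption, as are the function names and the use of the mean member score as the fused score.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2), normalized image coordinates


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two axis-aligned boxes."""
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def _fuse(members: List[Tuple[float, Box]]) -> Tuple[float, Box]:
    """Score-weighted average of the coordinates; the fused score is the mean member score."""
    total = sum(s for s, _ in members)
    coords = tuple(sum(s * b[i] for s, b in members) / total for i in range(4))
    return total / len(members), coords


def predict_pathology_boxes(region_boxes: Sequence[Box], region_probs: Sequence[float],
                            prob_thr: float = 0.5, iou_thr: float = 0.03) -> List[Tuple[float, Box]]:
    # Step (i): each region whose probability for this pathology exceeds prob_thr
    # becomes a candidate pathology box, scored with that probability.
    candidates = sorted(((p, box) for box, p in zip(region_boxes, region_probs) if p > prob_thr),
                        key=lambda c: c[0], reverse=True)
    clusters: List[List[Tuple[float, Box]]] = []
    fused: List[Tuple[float, Box]] = []
    # Step (ii): add each candidate to the first fused box it overlaps (IoU > iou_thr);
    # otherwise it starts a new cluster.
    for score, box in candidates:
        for k, (_, fused_box) in enumerate(fused):
            if iou(box, fused_box) > iou_thr:
                clusters[k].append((score, box))
                fused[k] = _fuse(clusters[k])
                break
        else:
            clusters.append([(score, box)])
            fused.append((score, box))
    return sorted(fused, key=lambda f: f[0], reverse=True)


# Hypothetical regions and one pathology's per-region probabilities.
regions = [(0.0, 0.0, 0.5, 0.5), (0.0, 0.25, 0.5, 1.0), (0.6, 0.0, 1.0, 0.5), (0.6, 0.6, 1.0, 1.0)]
probs = [0.74, 0.71, 0.20, 0.60]
for score, box in predict_pathology_boxes(regions, probs):
    print(round(score, 3), tuple(round(c, 3) for c in box))
```

Here the two overlapping regions are fused into one box pulled between them, the isolated region keeps its own box, and the region below the threshold is dropped; for evaluation, only the highest-scoring box per pathology would be kept.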
3.3 Training
The anatomical region detector is trained using the DETR loss [2] with a fixed one-to-one matching (i.e., without Hungarian matching). For training the pathology classifier, we experiment with two different levels of supervision (Fig. 3).
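A DETR-style detection loss with the fixed assignment described above can be sketched as follows: query i is always compared with anatomical region i, so no Hungarian matching is needed. This is a minimal sketch under stated assumptions: it uses a binary presence loss plus L1 and generalized-IoU box losses on present regions, borrows the loss weights 5 (L1) and 2 (GIoU) from the DETR defaults, and computes the L1 term on corner coordinates for simplicity (DETR uses center-size boxes). The exact terms and weights used by ADPD are not given in this section, and the function names are hypothetical.

```python
import math
from typing import Optional, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2), normalized


def giou(a: Box, b: Box) -> float:
    """Generalized IoU: IoU minus the fraction of the enclosing box not covered by the union."""
    inter = (max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
             * max(0.0, min(a[3], b[3]) - max(a[1], b[1])))
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    hull = (max(a[2], b[2]) - min(a[0], b[0])) * (max(a[3], b[3]) - min(a[1], b[1]))
    if union <= 0 or hull <= 0:
        return 0.0
    return inter / union - (hull - union) / hull


def fixed_matching_loss(pred_boxes: Sequence[Box], pred_presence: Sequence[float],
                        gt_boxes: Sequence[Optional[Box]],
                        w_l1: float = 5.0, w_giou: float = 2.0) -> float:
    """Query i is always compared with anatomical region i (no Hungarian matching).
    gt_boxes[i] is None when region i is not present in the image."""
    eps = 1e-7
    presence_loss, box_loss, n_present = 0.0, 0.0, 0
    for box, p, gt in zip(pred_boxes, pred_presence, gt_boxes):
        target = 0.0 if gt is None else 1.0
        p = min(max(p, eps), 1.0 - eps)  # clip so that log() stays finite
        presence_loss -= target * math.log(p) + (1.0 - target) * math.log(1.0 - p)
        if gt is not None:  # box terms only for regions that are present
            n_present += 1
            l1 = sum(abs(x - y) for x, y in zip(box, gt))
            box_loss += w_l1 * l1 + w_giou * (1.0 - giou(box, gt))
    return presence_loss / len(pred_boxes) + box_loss / max(n_present, 1)


gt = [(0.1, 0.1, 0.5, 0.9), None, (0.5, 0.1, 0.9, 0.9)]  # region 1 is absent
perfect = fixed_matching_loss([(0.1, 0.1, 0.5, 0.9), (0.0, 0.0, 0.1, 0.1), (0.5, 0.1, 0.9, 0.9)],
                              [1.0, 0.0, 1.0], gt)
shifted = fixed_matching_loss([(0.2, 0.1, 0.6, 0.9), (0.0, 0.0, 0.1, 0.1), (0.5, 0.1, 0.9, 0.9)],
                              [0.9, 0.2, 0.8], gt)
print(perfect < shifted)  # True
```

Because each query owns one region, the loss reduces to a plain sum over aligned pairs instead of a search over assignments.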
[Fig. 3 diagram: Loc-ADPD (per-region predicted vs. target pathologies, one loss per region) and MIL-ADPD (per-region predictions aggregated by LSE pooling, then compared with image-level targets)]
Fig. 3. Training. Loc-ADPD: Pathology predictions of regions are directly trained using anatomy-level supervision. MIL-ADPD: Region predictions are first aggregated using LSE pooling and then trained using image-level supervision.
For our Loc-ADPD model, we use anatomy-level pathology classification labels. Here, the target set of observed pathologies is available for each anatomical region individually, so the pathology prediction can be trained directly for each anatomical region. We apply the ASL [17] loss function independently to each region-pathology pair and average the results over all regions and pathologies. The decoder feature dimension is set to 512. For our MIL-ADPD model, we experiment with a weaker form of supervision, where pathology classification labels are only available at the image level. We use multiple instance learning (MIL), where an image is considered a bag of individual instances (i.e., the anatomical regions) and only a single label (per pathology) is provided for the whole bag, which is positive if any of its instances is positive. To train using MIL, we first aggregate the predicted pathology probabilities over all detected regions in the image using LSE pooling [13], which acts as a smooth approximation of max pooling. The resulting per-image probability for each pathology is then trained using the ASL [17] loss. In this model, the decoder feature dimension is set to 256. In both models, the ASL loss is weighted by a factor of 0.01 before being added to the DETR loss. We train using AdamW [12] with a learning rate of 3e−5 (Loc-ADPD) or 1e−4 (MIL-ADPD) and a weight decay of 1e−5 (Loc-ADPD) or 1e−4 (MIL-ADPD), in batches of 128 samples with early stopping (patience of 20 000 steps), for roughly 7 h on a single Nvidia RTX A6000.
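The MIL aggregation can be sketched with the LSE pooling of Pinheiro and Collobert [13], $\frac{1}{r}\log\big(\frac{1}{N}\sum_i e^{r p_i}\big)$, applied to the per-region probabilities $p_i$ of one pathology. The sharpness $r$ used by MIL-ADPD is not stated in this section, so `r = 5.0` is an assumption; the function names are hypothetical, and the ASL loss applied to the pooled value is not shown.

```python
import math
from typing import Sequence


def lse_pool(probs: Sequence[float], r: float = 5.0) -> float:
    """LogSumExp pooling over regions: (1 / r) * log(mean_i exp(r * p_i)).

    Small r approaches mean pooling, large r approaches max pooling.
    """
    m = max(probs)
    # Subtracting the maximum before exponentiating keeps exp() from overflowing.
    s = sum(math.exp(r * (p - m)) for p in probs) / len(probs)
    return m + math.log(s) / r


def image_label(region_labels: Sequence[int]) -> int:
    """MIL bag label: positive if any instance (anatomical region) is positive."""
    return int(any(region_labels))


# Hypothetical per-region probabilities of one pathology in one image.
region_probs = [0.05, 0.10, 0.90, 0.20]
print(lse_pool(region_probs))     # lies between the mean (0.3125) and the max (0.9)
print(image_label([0, 0, 1, 0]))  # 1
```

The smooth maximum lets gradients reach every region instead of only the arg-max region, which is one common motivation for LSE over hard max pooling in MIL.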
3.4 Dataset
Training Dataset. We train on the Chest ImaGenome dataset [4,21,22]¹, consisting of roughly 240 000 frontal chest X-ray images with corresponding scene graphs automatically constructed from free-text radiology reports. It is derived from the MIMIC-CXR dataset [9,10], which is based on imaging studies from 65 079 patients performed at Beth Israel Deaconess Medical Center in Boston, US. Amongst other information, each scene graph contains bounding boxes for 29 unique anatomical regions with annotated attributes; we consider positive anatomical-finding and disease attributes as positive pathology labels, which yields binary anatomy-level annotations for 55 unique pathologies. We consider the image-level label for a pathology to be positive if any region is positively labeled with that pathology. We use the provided JPG images [11]² and follow the official MIMIC-CXR training split, but only keep samples containing a scene graph with at least five valid region bounding boxes, resulting in a total of 234 307 training samples. During training, we use random resized cropping with size 224 × 224, contrast and brightness jittering, random affine augmentations, and Gaussian blurring.
¹ https://physionet.org/content/chest-imagenome/1.0.0 (PhysioNet Credentialed Health Data License 1.5.0).
Evaluation Dataset and Class Mapping. We evaluate our method on the subset of 882 chest X-ray images with pathology bounding boxes, annotated by radiologists, from the NIH ChestX-ray 8 (CXR8) dataset [20]³ of the National Institutes of Health Clinical Center in the US. We use 50% for validation and keep the other 50% as a held-out test set. Note that pathology bounding boxes are required only for evaluation (to compute the metrics), while only anatomical region bounding boxes (without considering pathologies) are required for training. All images are center-cropped and resized to 224 × 224. The dataset contains bounding boxes for 8 unique pathologies. While these partly overlap with the training classes, a one-to-one correspondence is not possible for all classes. For some evaluation classes, we therefore use a many-to-one mapping, where the class probability is computed as the mean over several training classes. We refer to the supp. material for a detailed study on class mappings.
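The many-to-one class mapping can be illustrated as follows. The sketch assumes the mapping is applied to the per-region training-class probabilities before box prediction; the class names and the mapping below are hypothetical placeholders, not the mapping used in the paper (which is described in the supplementary material), and the function name is illustrative.

```python
from typing import Dict, List


def map_to_eval_classes(train_probs: Dict[str, float],
                        mapping: Dict[str, List[str]]) -> Dict[str, float]:
    """Many-to-one class mapping: the probability of each evaluation class is the
    mean of the probabilities of the training classes mapped to it."""
    return {eval_cls: sum(train_probs[c] for c in src) / len(src)
            for eval_cls, src in mapping.items()}


# Hypothetical mapping and per-region training-class probabilities (illustration only).
mapping = {"Cardiomegaly": ["enlarged cardiac silhouette"],
           "Pneumonia": ["pneumonia", "consolidation"]}
region_probs = {"enlarged cardiac silhouette": 0.9, "pneumonia": 0.8, "consolidation": 0.4}
print(map_to_eval_classes(region_probs, mapping))
```

A one-to-one mapping is the special case where every list holds a single training class.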
4 Experiments and Results
4.1 Experimental Setup and Baselines
We compare our method against several weakly supervised object detection methods (CheXNet [14], STL [6], GradCAM [18], CXR [20], WELDON [3], MultiMap Model [23], LSE Model [13]), trained on the CXR8 training set using only image-level pathology labels. Note that some of these methods focus on (image-level) classification and do not report quantitative localization results. Nevertheless, we quantitatively compare their localization approaches with our method. We also compare with AGXNet [24], a weakly supervised method trained using anatomy-level pathology labels but without any bounding box supervision. It was trained on MIMIC-CXR (sharing the images with our method) with labels from RadGraph [8] and fine-tuned on the CXR8 training set with image-level labels. Additionally, we compare with a Faster R-CNN [16] trained on a small subset of roughly 500 samples from the CXR8 training set that have been annotated with pathology bounding boxes by two medical experts, including one board-certified radiologist.
² https://physionet.org/content/mimic-cxr-jpg/2.0.0/ (PhysioNet Credentialed Health Data License 1.5.0).
³ https://www.kaggle.com/datasets/nih-chest-xrays/data (CC0: Public Domain).
Table 1. Results on the NIH ChestX-ray 8 dataset [20]. Our models Loc-ADPD and MIL-ADPD, trained using anatomy (An) bounding boxes, both outperform all weakly supervised methods trained with image-level pathology (Pa) and anatomy-level pathology (An-Pa) labels by a large margin. MIL-ADPD is competitive with the supervised baseline trained with pathology (Pa) bounding boxes, while Loc-ADPD outperforms it by a large margin.
Method | Supervision: Box | Supervision: Class | IoU@10-70 mAP | IoU@10 AP | IoU@10 loc-acc | IoU@30 AP | IoU@30 loc-acc | IoU@50 AP | IoU@50 loc-acc
MIL-ADPD (ours) | An | Pa | 7.84 | 14.01 | 0.68 | 8.85 | 0.65 | 7.03 | 0.65
MIL-ADPD w/o WBF | An | Pa | 5.42 | 11.05 | 0.67 | 7.97 | 0.65 | 3.44 | 0.64
Loc-ADPD (ours) | An | An-Pa | 10.89 | 19.99 | 0.85 | 12.43 | 0.84 | 8.72 | 0.83
Loc-ADPD w/o WBF | An | An-Pa | 8.88 | 17.02 | 0.84 | 9.65 | 0.83 | 7.36 | 0.83
Loc-ADPD w/ MIL | An | An-Pa | 10.29 | 19.16 | 0.84 | 10.95 | 0.83 | 8.00 | 0.82
CheXNet [14] | – | Pa | 5.80 | 12.87 | 0.58 | 8.23 | 0.55 | 3.12 | 0.52
STL [6] | – | Pa | 5.61 | 12.76 | 0.57 | 7.94 | 0.54 | 2.45 | 0.50
GradCAM [18] | – | Pa | 4.43 | 12.53 | 0.58 | 6.67 | 0.54 | 0.13 | 0.51
CXR [20] | – | Pa | 5.61 | 13.91 | 0.59 | 8.01 | 0.55 | 1.24 | 0.51
WELDON [3] | – | Pa | 4.76 | 14.57 | 0.61 | 6.18 | 0.56 | 0.34 | 0.51
MultiMap [23] | – | Pa | 4.91 | 12.36 | 0.61 | 7.13 | 0.57 | 1.35 | 0.53
LSE Model [13] | – | Pa | 3.77 | 14.49 | 0.61 | 2.62 | 0.56 | 0.42 | 0.54
AGXNet [24] | – | An-Pa | 5.30 | 11.39 | 0.59 | 6.58 | 0.56 | 4.14 | 0.54
Faster R-CNN | Pa | – | 7.36 | 9.11 | 0.79 | 7.62 | 0.79 | 7.26 | 0.78
For all models, we only consider the predicted box with the highest box score per pathology, as the CXR8 dataset never contains more than one box per pathology. We report the standard object detection metric average precision (AP) at different IoU thresholds and the mean AP (mAP) over the thresholds (0.1, 0.2, ..., 0.7), which are commonly used on this dataset [20]. Additionally, we report the localization accuracy (loc-acc) [20], a common localization metric on this dataset, using a box score threshold of 0.7 for our method.
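The AP and mAP computation can be sketched as below under stated assumptions: one top-scoring predicted box per image and pathology (as described above), at most one ground-truth box per image, a prediction counted as a true positive if its IoU with the ground truth is at least the threshold, and all-point interpolated AP as in PASCAL VOC. The evaluation code behind Table 1 may differ in details such as the interpolation scheme, and loc-acc is not covered; all names are illustrative.

```python
from typing import Dict, List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)
Prediction = Tuple[str, float, Box]      # (image id, box score, box)


def iou(a: Box, b: Box) -> float:
    ix = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iy = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ix * iy
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def average_precision(preds: List[Prediction], gts: Dict[str, Box], iou_thr: float) -> float:
    """AP for one pathology: preds holds the top-scoring box per image, gts at most one box per image."""
    if not gts:
        return 0.0
    ranked = sorted(preds, key=lambda p: p[1], reverse=True)
    precisions, recalls, tp = [], [], 0
    for k, (image_id, _, box) in enumerate(ranked, start=1):
        gt = gts.get(image_id)
        if gt is not None and iou(box, gt) >= iou_thr:
            tp += 1
        precisions.append(tp / k)
        recalls.append(tp / len(gts))
    # All-point interpolation: precision at each rank becomes the max precision at any later rank.
    for k in range(len(precisions) - 2, -1, -1):
        precisions[k] = max(precisions[k], precisions[k + 1])
    ap, prev_recall = 0.0, 0.0
    for p, r in zip(precisions, recalls):
        ap += (r - prev_recall) * p
        prev_recall = r
    return ap


def mean_ap(preds: List[Prediction], gts: Dict[str, Box],
            thresholds: Sequence[float] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7)) -> float:
    return sum(average_precision(preds, gts, t) for t in thresholds) / len(thresholds)


# Toy example: image "b" is localized only roughly, image "c" has no ground-truth box.
gts = {"a": (0.0, 0.0, 1.0, 1.0), "b": (0.0, 0.0, 1.0, 1.0)}
preds = [("a", 0.9, (0.0, 0.0, 1.0, 1.0)), ("b", 0.8, (0.5, 0.5, 1.5, 1.5)),
         ("c", 0.7, (0.0, 0.0, 1.0, 1.0))]
print(average_precision(preds, gts, 0.1), average_precision(preds, gts, 0.5), mean_ap(preds, gts))
```

In the toy example, the rough box for image "b" (IoU of 1/7) counts as a hit only at IoU@10, which is why AP drops at stricter thresholds, mirroring how the baselines in Table 1 degrade from IoU@10 to IoU@50.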
4.2 Pathology Detection Results
Comparison with Baselines. Table 1 shows the results of our MIL-ADPD and Loc-ADPD models and all baselines on the CXR8 test set. Compared to the best weakly supervised method with image-level supervision (CheXNet), our methods improve by large margins (MIL-ADPD by Δ+35.2%, Loc-ADPD by Δ+87.8% in mAP). Improvements are especially high at larger IoU thresholds, and large improvements in loc-acc are also achieved at all thresholds. Both models also outperform AGXNet (which uses anatomy-level supervision) by large margins (MIL-ADPD by Δ+47.9% and Loc-ADPD by Δ+105.5% in mAP), although the improvements at larger thresholds are smaller here. Even compared to Faster R-CNN trained on a small set of fully supervised samples, MIL-ADPD is competitive (Δ+6.5%), while Loc-ADPD improves by Δ+48.0%. However, at larger thresholds (IoU@50), the supervised baseline slightly outperforms MIL-ADPD, while Loc-ADPD remains superior. This shows that using anatomical regions as proxies is an effective approach to pathology detection. While image-level annotations (MIL-ADPD) already give promising results, the full potential is only achieved with anatomy-level supervision (Loc-ADPD). Unlike Loc-ADPD and MIL-ADPD, all baselines were either trained or fine-tuned on the CXR8 dataset, showing that our method generalizes well to unseen datasets and that our class mapping is effective. For detailed per-pathology results, we refer to the supp. material. We found that the improvements of MIL-ADPD are mainly due to improved performance on Cardiomegaly and Mass detection, while Loc-ADPD consistently outperforms all baselines on all classes except Nodule, often by a large margin.
Fig. 4. Qualitative results of Loc-ADPD, with predicted (solid) and target (dashed) boxes. Cardiomegaly (red) is detected almost perfectly, as it is always localized at exactly one anatomical region. Other pathologies like atelectasis (blue), effusion (green), or pneumonia (cyan) are detected, but often with imperfectly overlapping boxes. Detection also works well for several overlapping pathologies (second from left). (Color figure online)
Ablation Study. In Table 1, we also show the results of different ablation studies. Without WBF, results degrade for both of our models, highlighting the importance of merging region boxes. Combining the training strategies of Loc-ADPD and MIL-ADPD does not improve performance. Different class mappings between the training and evaluation sets are studied in the supp. material.
Qualitative Results. As shown in Fig. 4, Loc-ADPD detects cardiomegaly almost perfectly, as it is always localized at exactly one anatomical region. Other pathologies are detected, but often with boxes that are too large or too small, as they only cover parts of anatomical regions or stretch over several of them, which cannot be completely corrected by WBF. Detection also works well for several overlapping pathologies. For qualitative comparisons between Loc-ADPD and MIL-ADPD, we refer to the supp. material.
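The relative improvements quoted in the comparison with baselines are ratios of the mAP values (IoU@10-70) in Table 1. The short sketch below recomputes them as a consistency check; the helper name is illustrative, and the numbers are copied from Table 1.

```python
def relative_improvement(ours: float, baseline: float) -> float:
    """Relative change in percent: (ours / baseline - 1) * 100."""
    return (ours / baseline - 1.0) * 100.0


# mAP values (IoU@10-70) copied from Table 1.
map_scores = {"MIL-ADPD": 7.84, "Loc-ADPD": 10.89,
              "CheXNet": 5.80, "AGXNet": 5.30, "Faster R-CNN": 7.36}

for ours in ("MIL-ADPD", "Loc-ADPD"):
    for baseline in ("CheXNet", "AGXNet", "Faster R-CNN"):
        delta = relative_improvement(map_scores[ours], map_scores[baseline])
        print(f"{ours} vs. {baseline}: {delta:+.1f}%")
```

This prints +35.2%, +47.9%, and +6.5% for MIL-ADPD and +87.8%, +105.5%, and +48.0% for Loc-ADPD, matching the values in the text. Note that these are relative changes, not absolute mAP differences in percentage points.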
5 Discussion and Conclusion
Limitations. While our proposed ADPD method outperforms all competing models, it is still subject to limitations. First, due to the dependence on region proxies, our models predict the whole region for pathologies that cover only a small part of a region, as highlighted by their inability to detect nodules. We note, however, that in clinical practice, chest X-rays are not used for the final diagnosis of such pathologies, and even rough localization can be beneficial. Additionally, while not requiring pathology bounding boxes, our models still require supervision in the form of anatomical region bounding boxes, and Loc-ADPD requires anatomy-level labels. However, anatomical bounding boxes are easier to annotate and predict than pathology bounding boxes, and the anatomy-level labels used here were extracted automatically from radiology reports [21]. While our work is currently limited to chest X-rays, we see great potential for other modalities where abnormalities can be assigned to meaningful regions.
Conclusion. We proposed a novel approach to pathology detection on chest X-rays using anatomical region bounding boxes. We studied two training approaches, one using anatomy-level pathology labels and one using image-level labels with MIL. Our experiments demonstrate that using anatomical regions as proxies improves results compared to weakly supervised methods and supervised training on little data, thus providing a promising direction for future research.
References
1. Agu, N.N., et al.: AnaXNet: anatomy aware multi-label finding classification in chest X-ray. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12905, pp. 804–813. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87240-3_77
2. Carion, N., Massa, F., Synnaeve, G., Usunier, N., Kirillov, A., Zagoruyko, S.: End-to-end object detection with transformers. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12346, pp. 213–229. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58452-8_13
3. Durand, T., Thome, N., Cord, M.: WELDON: weakly supervised learning of deep convolutional neural networks. In: CVPR, pp. 4743–4752 (2016). https://doi.org/10.1109/CVPR.2016.513
4. Goldberger, A.L., et al.: PhysioBank, PhysioToolkit, and PhysioNet: components of a new research resource for complex physiologic signals. Circulation 101(23), e215–e220 (2000)
5. Huang, G., Liu, Z., Van Der Maaten, L., Weinberger, K.Q.: Densely connected convolutional networks. In: CVPR, pp. 2261–2269 (2017). https://doi.org/10.1109/CVPR.2017.243
6. Hwang, S., Kim, H.-E.: Self-transfer learning for weakly supervised lesion localization. In: Ourselin, S., Joskowicz, L., Sabuncu, M.R., Unal, G., Wells, W. (eds.) MICCAI 2016. LNCS, vol. 9901, pp. 239–246. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46723-8_28
7. Irvin, J., et al.: CheXpert: a large chest radiograph dataset with uncertainty labels and expert comparison. In: AAAI, pp. 590–597 (2019). https://doi.org/10.1609/aaai.v33i01.3301590
P. Müller et al.
8. Jain, S., et al.: RadGraph: extracting clinical entities and relations from radiology reports. In: NeurIPS (2021)
9. Johnson, A.E.W., et al.: MIMIC-CXR, a de-identified publicly available database of chest radiographs with free-text reports. Sci. Data 6(1), 1–8 (2019)
10. Johnson, A.E.W., et al.: MIMIC-CXR database (version 2.0.0). PhysioNet (2019)
11. Johnson, A.E.W., et al.: MIMIC-CXR-JPG, a large publicly available database of labeled chest radiographs. arXiv preprint arXiv:1901.07042 (2019)
12. Loshchilov, I., Hutter, F.: Decoupled weight decay regularization. In: ICLR (2019)
13. Pinheiro, P.O., Collobert, R.: From image-level to pixel-level labeling with convolutional networks. In: CVPR, pp. 1713–1721 (2015). https://doi.org/10.1109/CVPR.2015.7298780
14. Rajpurkar, P., et al.: CheXNet: radiologist-level pneumonia detection on chest X-rays with deep learning. arXiv preprint arXiv:1711.05225 (2017). https://doi.org/10.48550/arXiv.1711.05225
15. Raoof, S., Feigin, D., Sung, A., Raoof, S., Irugulpati, L., Rosenow, E.C., III.: Interpretation of plain chest roentgenogram. Chest 141(2), 545–558 (2012)
16. Ren, S., He, K., Girshick, R., Sun, J.: Faster R-CNN: towards real-time object detection with region proposal networks. In: NIPS, vol. 28 (2015)
17. Ridnik, T., et al.: Asymmetric loss for multi-label classification. In: ICCV, pp. 82–91 (2021). https://doi.org/10.1109/ICCV48922.2021.00015
18. Selvaraju, R.R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., Batra, D.: Grad-CAM: visual explanations from deep networks via gradient-based localization. In: ICCV, pp. 618–626 (2017). https://doi.org/10.1109/ICCV.2017.74
19. Solovyev, R., Wang, W., Gabruseva, T.: Weighted boxes fusion: ensembling boxes from different object detection models. Image Vis. Comput. 107, 104117 (2021). https://doi.org/10.1016/j.imavis.2021.104117
20. Wang, X., Peng, Y., Lu, L., Lu, Z., Bagheri, M., Summers, R.M.: ChestX-ray8: hospital-scale chest X-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases. In: CVPR, pp. 2097–2106 (2017). https://doi.org/10.1109/CVPR.2017.369
21. Wu, J., et al.: Chest ImaGenome dataset for clinical reasoning. In: NIPS (2021)
22. Wu, J.T., et al.: Chest ImaGenome dataset (version 1.0.0). PhysioNet (2021)
23. Yan, C., Yao, J., Li, R., Xu, Z., Huang, J.: Weakly supervised deep learning for thoracic disease classification and localization on chest X-rays. In: ACM BCB, pp. 103–110 (2018)
24. Yu, K., Ghosh, S., Liu, Z., Deible, C., Batmanghelich, K.: Anatomy-guided weakly-supervised abnormality localization in chest X-rays. In: Wang, L., et al. (eds.) MICCAI 2022. LNCS, vol. 13435, pp. 658–668. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16443-9_63
25. Zhou, B., Khosla, A., Lapedriza, A., Oliva, A., Torralba, A.: Learning deep features for discriminative localization. In: CVPR, pp. 2921–2929 (2016). https://doi.org/10.1109/CVPR.2016.319
VesselVAE: Recursive Variational Autoencoders for 3D Blood Vessel Synthesis

Paula Feldman1,3 (corresponding author), Miguel Fainstein3, Viviana Siless3, Claudio Delrieux1,2, and Emmanuel Iarussi1,3

1 Consejo Nacional de Investigaciones Científicas y Técnicas, Buenos Aires, Argentina
2 Universidad Nacional del Sur, Bahía Blanca, Argentina
3 Universidad Torcuato Di Tella, Buenos Aires, Argentina

Abstract. We present a data-driven generative framework for synthesizing blood vessel 3D geometry. This is a challenging task due to the complexity of vascular systems, which vary widely in shape, size, and structure. Existing model-based methods provide some degree of control and variation in the structures produced, but fail to capture the diversity of actual anatomical data. We developed VesselVAE, a recursive variational neural network that fully exploits the hierarchical organization of the vessel and learns a low-dimensional manifold encoding branch connectivity along with geometry features describing the target surface. After training, the VesselVAE latent space can be sampled to generate new vessel geometries. To the best of our knowledge, this work is the first to use this technique for synthesizing blood vessels. We achieve similarities between synthetic and real data for radius (.97), length (.95), and tortuosity (.96). By leveraging deep neural networks, we generate 3D models of blood vessels that are both accurate and diverse, which is crucial for medical and surgical training, hemodynamic simulations, and many other purposes.

Keywords: Vascular 3D model · Generative modeling · Neural Networks

Supplementary Information: The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_7.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023. H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 67–76, 2023. https://doi.org/10.1007/978-3-031-43907-0_7

1 Introduction

Accurate 3D models of blood vessels are increasingly required for several purposes in medicine and science [25]. These meshes are typically generated using either image segmentation or synthetic methods. Despite significant advances in vessel segmentation [26], reconstructing thin features accurately from medical images remains challenging [2]. Manual editing of vessel geometry is a tedious and error-prone task that requires expert medical knowledge, which explains the
P. Feldman et al.
scarcity of curated datasets. As a result, several methods have been developed to synthesize blood vessel geometry [29]. Within the existing literature on generating vascular 3D models, we identified two primary types of algorithms: fractal-based and space-filling algorithms. Fractal-based algorithms use a set of fixed rules that include different branching parameters, such as the ratio of asymmetry in arterial bifurcations and the relationship between the diameter of the vessel and the flow [7,33]. Space-filling algorithms, on the other hand, allow the blood vessels to grow into a specific perfusion volume while respecting hemodynamic laws and constraints on the formation of blood vessels [9,17,21,22,25]. Although these model-based methods provide some degree of control and variation in the structures produced, they often fail to capture the diversity of real anatomical data.

In recent years, deep neural networks have led to the development of powerful generative models [30], such as Generative Adversarial Networks [8,12] and Diffusion Models [11], which have achieved groundbreaking performance in many applications, ranging from image and video synthesis to molecular design. These advances have inspired novel network architectures that model 3D shapes using voxel representations [28], point clouds [31], signed distance functions [19], and polygonal meshes [18]. In particular, and close to our aim, Wolterink et al. [27] propose a GAN model capable of generating coronary artery anatomies. However, this model is limited to generating single-channel (unbranched) blood vessels and thus does not support the generation of more complex, tree-like vessel topologies.

In this work we propose a novel data-driven framework named VesselVAE for synthesizing blood vessel geometry.
Our generative framework is based on a Recursive variational Neural Network (RvNN), an architecture that has been applied in various contexts, including natural language [23,24], shape semantics modeling [14,15], and document layout generation [20]. In contrast to previous data-driven methods, our recursive network fully exploits the hierarchical organization of the vessel and learns a low-dimensional manifold encoding branch connectivity along with geometry features describing the target surface. Once trained, the VesselVAE latent space is sampled to generate new vessel geometries. To the best of our knowledge, this work is the first to synthesize multi-branch blood vessel trees by learning from real data. Experiments show that synthetic and real blood vessel geometries are highly similar as measured by cosine similarity: radius (.97), length (.95), and tortuosity (.96).
2 Methods
Input. The network input is a binary tree representation of the blood vessel 3D geometry. Formally, each tree is defined as a tuple $(T, E)$, where $T$ is the set of nodes, and $E$ is the set of directed edges connecting a pair of nodes $(n, m)$, with $n, m \in T$. In order to encode a 3D model into this representation, vessel segments $V$ are parameterized by a central axis consisting of ordered points in Euclidean space, $V = \{v_1, v_2, \ldots, v_N\}$, and a radius $r$, assuming a piece-wise tubular vessel for simplicity. We then construct the binary tree as a set of nodes $T = \{n_1, n_2, \ldots, n_N\}$,
Recursive Variational Autoencoders for 3D Blood Vessel Synthesis
Fig. 1. Top: Overview of the Recursive variational Neural Network for synthesizing blood vessel structures. The architecture follows an Encoder-Decoder framework which can handle the hierarchical tree representation of the vessels. VesselVAE learns to generate the topology and attributes for each node in the tree, which is then used to synthesize 3D meshes. Bottom: Layers of the Encoder and Decoder networks comprising branches of fully-connected layers followed by leaky ReLU activations. Note that the right/left Enc-MLPs within the Encoder are triggered respectively when the incoming node in the tree is identified as a right or left child. Similarly, the Decoder only uses right/left Dec-MLPs when the Node Classifier predicts bifurcations.
where each node $n_i$ represents a vessel segment $v$ and contains an attribute vector $\mathbf{x}_i = [x_i, y_i, z_i, r_i] \in \mathbb{R}^4$ with the coordinates of the corresponding point and its radius $r_i$. See Sect. 3 for details.

Network Architecture. The proposed generative model is a Recursive variational Neural Network (RvNN) consisting of two main components: the Encoder (Enc) and the Decoder (Dec) networks. The Encoder transforms a tree structure into a hierarchical encoding on the learned manifold. The Decoder network samples from this encoded space to decode tree structures, as depicted in Fig. 1. Encoding and decoding are performed through a depth-first traversal of the tree, where each node is recursively combined with its parent node. The model outputs a hierarchy of vessel branches, where each internal node in the hierarchy is represented by a vector that encodes its own attributes and the information of all of its descendants in the tree.

Within the RvNN Decoder network there are two essential components: the Node Classifier (Cls) and the Features Decoder Multi-Layer Perceptron (Features Dec-MLP). The Node Classifier determines the type of node to be decoded: a leaf node, or an internal node with one or two children. It is implemented as a multi-layer perceptron trained to predict a three-category bifurcation probability from the encoded vector. Complementing the Node Classifier, the Features Dec-MLP reconstructs the attributes of each node, specifically its coordinates and radius. Furthermore, two additional components, the Right and Left Dec-MLPs, recursively decode the next encoded node in the tree hierarchy. These decoder branches are executed according to the classifier prediction for that encoded node. If the Node Classifier predicts a single child for a node, a right child is assumed by default.

In addition to the core architecture, our model is augmented with three auxiliary, shallow, fully-connected neural networks: $f_\mu$, $f_\sigma$, and $g_z$. Positioned before the RvNN bottleneck, the $f_\mu$ and $f_\sigma$ networks shape the distribution of the latent space where encoded tree structures lie. Conversely, the $g_z$ network, situated after the bottleneck, facilitates the decoding of latent variables, aiding the Decoder network in reconstructing tree structures. Collectively, these supplementary networks streamline the data transformation through the model. All activation functions used in our networks are leaky ReLUs. See the Appendix for implementation details.

Objective. Our generative model is trained to learn a probability distribution over the latent space that can be used to generate new blood vessel segments. After encoding, the decoder takes samples from a multivariate Gaussian distribution, $z_s(x) \sim \mathcal{N}(\mu, \sigma)$ with $\mu = f_\mu(\mathrm{Enc}(x))$ and $\sigma = f_\sigma(\mathrm{Enc}(x))$, where Enc is the recursive encoder and $f_\mu$, $f_\sigma$ are two fully-connected neural networks. To recover the feature vectors $x$ for each node along with the tree topology, we simultaneously train the regression network (Features Dec-MLP in Fig. 1) on a reconstruction objective $\mathcal{L}_{recon}$ and the Node Classifier using $\mathcal{L}_{topo}$.
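The tree representation and the three node types that the Node Classifier distinguishes can be sketched in plain Python. This is an illustrative sketch, not the authors' code: the names `Node`, `add_child`, `postorder`, and `depth` are hypothetical. It shows the attribute vector $[x, y, z, r]$ per node, the convention that a single child is stored as the right child, and the depth-first (post-order) order in which a recursive encoder can merge child codes into their parent.

```python
from dataclasses import dataclass
from typing import List, Optional

# Topology classes predicted by the Node Classifier (hypothetical constants).
LEAF, ONE_CHILD, TWO_CHILDREN = 0, 1, 2


@dataclass
class Node:
    """One centerline sample with attribute vector [x, y, z, r]."""
    x: float
    y: float
    z: float
    r: float
    left: Optional["Node"] = None
    right: Optional["Node"] = None

    def attributes(self) -> List[float]:
        return [self.x, self.y, self.z, self.r]

    def topo_class(self) -> int:
        """Number of children, i.e. the three-class topology label."""
        return int(self.left is not None) + int(self.right is not None)


def add_child(parent: Node, child: Node) -> None:
    """Attach a child; a single child is stored as the right child by default."""
    if parent.right is None:
        parent.right = child
    elif parent.left is None:
        parent.left = child
    else:
        raise ValueError("binary tree: a node may have at most two children")


def postorder(node: Optional[Node]) -> List[Node]:
    """Depth-first traversal visiting children before their parent."""
    if node is None:
        return []
    return postorder(node.left) + postorder(node.right) + [node]


def depth(node: Optional[Node]) -> int:
    """Number of nodes on the longest root-to-leaf path."""
    if node is None:
        return 0
    return 1 + max(depth(node.left), depth(node.right))
```

For a root with one child that then bifurcates, `[n.topo_class() for n in postorder(root)]` yields the labels of the two leaves, then the bifurcating node, then the root, i.e. `[0, 0, 2, 1]`. A recursive encoder would consume nodes in this same order, so every child code is available when its parent is encoded.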
Additionally, in line with the general framework proposed by β-VAE [10], we incorporate a Kullback-Leibler (KL) divergence term encouraging the distribution $p(z_s(x))$ over all training samples $x$ to move closer to the prior of the standard normal distribution $p(z)$. We therefore minimize the following objective:

$$\mathcal{L} = \mathcal{L}_{recon} + \alpha\,\mathcal{L}_{topo} + \gamma\,\mathcal{L}_{KL}, \tag{1}$$

where the reconstruction loss is defined as $\mathcal{L}_{recon} = \lVert \mathrm{Dec}(z_s(x)) - x \rVert_2$, the Kullback-Leibler divergence loss is $\mathcal{L}_{KL} = D_{KL}\big(p(z_s(x)) \,\Vert\, p(z)\big)$, and the topology objective is a three-class cross-entropy loss $\mathcal{L}_{topo} = -\sum_{c=1}^{3} x_c \log\big(\mathrm{Cls}(\mathrm{Dec}(x))_c\big)$. Here $x_c$ is a binary indicator (0 or 1) for the true class of the sample $x$: $x_c = 1$ if the sample belongs to class $c$ and 0 otherwise. $\mathrm{Cls}(\mathrm{Dec}(x))_c$ is the probability, as output by the classifier, that the sample $x$ belongs to class $c$ (zero, one, or two children), and $\mathrm{Dec}(x)$ denotes the encoded-decoded node representation of the input sample $x$.

3D Mesh Synthesis. Several algorithms have been proposed in the literature to generate a surface 3D mesh from a tree-structured centerline [29]. For simplicity and efficiency, we chose the approach described in [6], which produces good-quality meshes from centerlines with a low sample rate. The implemented method iterates through the points in the curve, generating a coarse quadrilateral mesh along the segments and joints. The centerline sampling step is crucial for a successful reconstruction. Our re-sampling is therefore not equispaced but varies with curvature and radius along the centerline, sampling more densely near high-curvature regions. This results in a better-quality and more accurate mesh. Finally, the Catmull-Clark subdivision algorithm [5] is used to increase mesh resolution and smooth the surface.

Fig. 2. Dataset and pre-processing overview: The raw meshes from the IntrA 3D collection undergo pre-processing using the VMTK toolkit. This step is crucial for extracting centerlines and cross-sections from the meshes, which are then used to construct their binary tree representations.
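The per-node terms of the training objective in Eq. 1 can be sketched in plain Python as a hedged illustration, not as the authors' PyTorch implementation. Assumptions: $\sigma$ is treated as a standard deviation (many VAE implementations predict a log-variance instead), the KL term uses the closed form for a diagonal Gaussian against $\mathcal{N}(0, I)$, and the way per-node terms are aggregated over a tree (sum or mean) is not specified in the text and is left to the caller. The class and node weights described under Training are omitted; all function names are hypothetical.

```python
import math
import random
from typing import List, Sequence


def reparameterize(mu: Sequence[float], sigma: Sequence[float],
                   rng: random.Random) -> List[float]:
    """Sample z = mu + sigma * eps with eps ~ N(0, I) (reparameterization trick)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


def kl_to_standard_normal(mu: Sequence[float], sigma: Sequence[float]) -> float:
    """Closed-form KL(N(mu, diag(sigma^2)) || N(0, I))."""
    return 0.5 * sum(s * s + m * m - 1.0 - math.log(s * s) for m, s in zip(mu, sigma))


def recon_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """L2 distance between decoded and input attributes [x, y, z, r]."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)))


def topo_loss(probs: Sequence[float], true_class: int) -> float:
    """Three-class cross-entropy with a one-hot target: -log p_true."""
    return -math.log(probs[true_class])


def total_loss(l_recon: float, l_topo: float, l_kl: float,
               alpha: float = 0.3, gamma: float = 0.001) -> float:
    """Eq. 1 with the alpha = .3 and gamma = .001 reported under Training."""
    return l_recon + alpha * l_topo + gamma * l_kl
```

Because the KL term is computed in closed form, no sampling is needed for it; only the decoder input uses `reparameterize`, which keeps the sample differentiable with respect to $\mu$ and $\sigma$ in an autodiff framework.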
3 Experimental Setup
Materials. We trained our networks using a subset of the open-access IntrA dataset¹ published by Yang et al. in 2020 [32]. This subset consisted of 1694 healthy vessel segments reconstructed from 2D MRA images of patients. We converted the 3D meshes into a binary tree representation, using the network extraction script from the VMTK toolkit² to extract the centerline coordinates of each vessel model. The centerline points were determined based on the ratio between the sphere step and the local maximum radius, which was computed using the advancement ratio specified by the user. The radius of the blood vessel conduit at each centerline sample was determined from the computed cross-sections, assuming a maximal circular shape (see Fig. 2). To improve computational efficiency during recursive tree traversal, we implemented an algorithm that balances each tree by identifying a new root. We additionally trimmed trees to a depth of ten in our experiments. This choice balances the computational demands of depth-first tree traversal in each training step against the complexity of the training meshes. We excluded from our study trees that exhibited greater depth, had nodes with more than two children, or contained loops. However, non-binary trees can be converted into binary trees, and it is possible to train with deeper trees at the expense of higher computational costs. Ultimately, we obtained 700 binary trees from the original meshes using this approach.

¹ https://github.com/intra3d2019/IntrA.
² http://www.vmtk.org/vmtkscripts/vmtknetworkextraction.

Implementation Details. For centerline extraction, we set the advancement ratio in the VMTK script to 1.05. The script can sometimes produce multiple cross-sections at centerline bifurcations. In those cases, we selected the sample with the lowest radius, which ensures proper alignment with the principal direction of the centerline. All attributes were normalized to the range [0, 1]. For mesh reconstruction, we used 4 iterations of the Catmull-Clark subdivision algorithm. The data pre-processing pipeline and network code were implemented in Python using the PyTorch framework.

Training. In all stages, we set the batch size to 10 and used the Adam optimizer with β1 = 0.9, β2 = 0.999, and a learning rate of 1 × 10⁻⁴. We set α = .3 and γ = .001 for Eq. 1 in our experiments. To enhance computation speed, we implemented dynamic batching [16], which groups together operations involving input trees of dissimilar shapes and different nodes within a single input graph. Training our models takes approximately 12 h on a workstation equipped with an NVIDIA A100 GPU (80 GB VRAM) and 256 GB RAM. However, the memory footprint during training is very small (≤1 GB) due to the lightweight tree representation. During training, we ensure that the reconstructed tree follows the original structure, rather than relying solely on the classifier's predictions. We train the classifier using a cross-entropy loss that compares its predictions to the actual values from the original tree.
Since the number of nodes in each class is unbalanced, we scale the weight given to each class in the cross-entropy loss by the inverse of its class count. During preliminary experiments, we observed that accurately classifying nodes closer to the tree root is critical, because a misclassification of top nodes has a cascading effect on all subsequent nodes in the tree (i.e., whole branches may be skipped during reconstruction). To account for this, we introduce a weighting scheme that assigns each node a weight in the cross-entropy loss based on its total number of descendant nodes, normalized by the total number of nodes in the tree.

Metrics. We defined a set of metrics to evaluate the performance of our trained network. These metrics determine how well the generated 3D models of blood vessels match the original dataset distribution, as well as the diversity of the generated output. The chosen metrics are widely used in blood vessel 3D modeling and have been shown to provide reliable and accurate quantification of the main characteristics of blood vessels [3,13]. We analyzed the tortuosity per branch, the total centerline length of the vessel, and the average radius of the tree. The tortuosity distance metric [4] is widely used in blood vessel analysis, mainly because of its clinical importance; it measures how twisted each branch of the vessel is. The total length and average radius of the vessel were used in previous work to distinguish healthy vasculature from cancerous malformations. Finally, to measure the distance between distributions for each metric, we compute the cosine similarity.

Fig. 3. (a) Histograms of total length, average radius, and tortuosity per branch for both real and synthetic samples. (b) Visual comparison between our method and two baselines [9, 27].
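The three vessel statistics can be computed directly from centerline samples. The sketch below assumes each branch is an ordered list of (x, y, z) points and uses the distance-metric definition of tortuosity from [4]: arc length divided by the straight-line distance between the branch endpoints, so a straight branch scores 1.0. Summing branch lengths for the total length and averaging all node radii are assumptions about how the per-tree values are aggregated; the function names are hypothetical.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float, float]


def path_length(points: Sequence[Point]) -> float:
    """Sum of Euclidean distances between consecutive centerline samples."""
    return sum(math.dist(a, b) for a, b in zip(points, points[1:]))


def distance_metric_tortuosity(points: Sequence[Point]) -> float:
    """Arc length over chord length for one branch (1.0 = perfectly straight)."""
    chord = math.dist(points[0], points[-1])
    if chord == 0.0:
        raise ValueError("branch endpoints coincide; tortuosity is undefined")
    return path_length(points) / chord


def total_length(branches: Sequence[Sequence[Point]]) -> float:
    """Total centerline length of a tree, summed over its branches."""
    return sum(path_length(b) for b in branches)


def average_radius(radii: Sequence[float]) -> float:
    """Mean radius over all centerline samples of a tree."""
    return sum(radii) / len(radii)
```

For instance, a branch through (0, 0, 0), (1, 1, 0), (2, 0, 0) has arc length $2\sqrt{2}$ and chord 2, giving a tortuosity of $\sqrt{2} \approx 1.414$.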
4 Results
We conducted both quantitative and qualitative analyses to evaluate the model's performance. For the quantitative analyses, we implemented a set of metrics commonly used for characterizing blood vessels. We computed histograms of the radius, total length, and tortuosity for the real blood vessel set and the generated set (700 samples), shown in Fig. 3 (a). The distributions are aligned and consistent. We measured the closeness of the histograms with the cosine similarity by treating each histogram as a vector in n-dimensional space (n is the number of bins in the histogram). Since all histogram entries are non-negative, the results range from 0 to 1. We obtain a radius cosine similarity of .97, a total length cosine similarity of .95, and a tortuosity cosine similarity of .96. These high similarities between histograms indicate that the generated blood vessels are realistic. Given the differences between the topologies generated by the baselines, for a fair comparison we limited our evaluation against them to a visual inspection of the meshes. The qualitative analyses consisted of a visual evaluation of the reconstructed outputs provided by the decoder network. We visually compared them to state-of-the-art methods in Fig. 3 (b). The method described by Wolterink and colleagues [27] is able to generate realistic blood vessels but without branches, and the method described by Hamarneh et al. [9] is capable of generating branches, but with straight shapes that lack realism. In contrast, our method is capable of generating realistic blood vessels containing branches, with smoothly varying radius, lengths, and tortuosity.
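The histogram comparison described above can be sketched as follows. The bin count and value range are assumptions, since the text does not report them; what matters is that the real and synthetic sets are binned with the same edges so that entry k of both vectors refers to the same interval. Function names are hypothetical.

```python
import math
from typing import List, Sequence


def histogram(values: Sequence[float], bins: int, lo: float, hi: float) -> List[int]:
    """Counts per equal-width bin on [lo, hi]; values outside the range are ignored."""
    counts = [0] * bins
    width = (hi - lo) / bins
    for v in values:
        if lo <= v <= hi:
            k = min(int((v - lo) / width), bins - 1)  # v == hi goes into the last bin
            counts[k] += 1
    return counts


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two histograms viewed as n-dimensional vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for an empty histogram")
    return dot / (norm_a * norm_b)
```

Because cosine similarity is invariant to positive rescaling of either vector, raw counts give the same value as normalized histograms, so the real and generated sets need not contain the same number of samples. With non-negative counts the result lies in [0, 1], matching the range stated above.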
5 Conclusions
We have presented a novel approach for synthesizing blood vessel models using a variational recursive autoencoder. Our method enables efficient encoding and decoding of binary tree structures and produces high-quality synthesized models. In the future, we aim to explore combining our approach with surfaces represented by the zero level set of a differentiable implicit neural representation (INR) [1]. This could lead to more accurate and efficient modeling of blood vessels, and potentially of other non-tree-like structures such as capillary networks. Since the presented framework would require significant adaptations to accommodate such complex topologies, this is an interesting direction for future research. Additionally, the generated geometries may show self-intersections; in the future, we would like to incorporate restrictions into the generative model to avoid such artifacts. Overall, we believe that our proposed approach holds great promise for advancing 3D blood vessel geometry synthesis and contributing to the development of new clinical tools for healthcare professionals.

Acknowledgements. This project was supported by grants from Salesforce, USA (Einstein AI 2020), National Scientific and Technical Research Council (CONICET), Argentina (PIP 2021-2023 GI - 11220200102981CO), and Universidad Torcuato Di Tella, Argentina.
References
1. Alblas, D., Brune, C., Yeung, K.K., Wolterink, J.M.: Going off-grid: continuous implicit neural representations for 3D vascular modeling. In: Camara, O., et al. (eds.) STACOM 2022. LNCS, vol. 13593, pp. 79–90. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-23443-9_8
2. Alblas, D., Brune, C., Wolterink, J.M.: Deep learning-based carotid artery vessel wall segmentation in black-blood MRI using anatomical priors. arXiv preprint arXiv:2112.01137 (2021)
3. Bullitt, E., et al.: Vascular attributes and malignant brain tumors. In: Ellis, R.E., Peters, T.M. (eds.) MICCAI 2003. LNCS, vol. 2878, pp. 671–679. Springer, Heidelberg (2003). https://doi.org/10.1007/978-3-540-39899-8_82
4. Bullitt, E., Gerig, G., Pizer, S.M., Lin, W., Aylward, S.R.: Measuring tortuosity of the intracerebral vasculature from MRA images. IEEE Trans. Med. Imaging 22(9), 1163–1171 (2003)
5. Catmull, E., Clark, J.: Recursively generated B-spline surfaces on arbitrary topological meshes. Comput. Aided Des. 10(6), 350–355 (1978)
6. Felkel, P., Wegenkittl, R., Buhler, K.: Surface models of tube trees. In: Proceedings Computer Graphics International, pp. 70–77. IEEE (2004)
7. Galarreta-Valverde, M.A., Macedo, M.M., Mekkaoui, C., Jackowski, M.P.: Three-dimensional synthetic blood vessel generation using stochastic L-systems. In: Medical Imaging 2013: Image Processing, vol. 8669, pp. 414–419. SPIE (2013)
8. Goodfellow, I., et al.: Generative adversarial networks. Commun. ACM 63(11), 139–144 (2020)
9. Hamarneh, G., Jassi, P.: VascuSynth: simulating vascular trees for generating volumetric image data with ground-truth segmentation and tree analysis. Comput. Med. Imaging Graph. 34(8), 605–616 (2010)
10. Higgins, I., et al.: beta-VAE: learning basic visual concepts with a constrained variational framework. In: International Conference on Learning Representations (2017)
11. Ho, J., Jain, A., Abbeel, P.: Denoising diffusion probabilistic models. Adv. Neural. Inf. Process. Syst. 33, 6840–6851 (2020)
12. Kazeminia, S., et al.: GANs for medical image analysis. Artif. Intell. Med. 109, 101938 (2020)
13. Lang, S., et al.: Three-dimensional quantification of capillary networks in healthy and cancerous tissues of two mice. Microvasc. Res. 84(3), 314–322 (2012)
14. Li, J., Xu, K., Chaudhuri, S., Yumer, E., Zhang, H., Guibas, L.: GRASS: generative recursive autoencoders for shape structures. ACM Trans. Graph. (TOG) 36(4), 1–14 (2017)
15. Li, M., et al.: GRAINS: generative recursive autoencoders for indoor scenes. ACM Trans. Graph. (TOG) 38(2), 1–16 (2019)
16. Looks, M., Herreshoff, M., Hutchins, D., Norvig, P.: Deep learning with dynamic computation graphs. arXiv preprint arXiv:1702.02181 (2017)
17. Merrem, A., Bartzsch, S., Laissue, J., Oelfke, U.: Computational modelling of the cerebral cortical microvasculature: effect of x-ray microbeams versus broad beam irradiation. Phys. Med. Biol. 62(10), 3902 (2017)
18. Nash, C., Ganin, Y., Eslami, S.A., Battaglia, P.: PolyGen: an autoregressive generative model of 3D meshes. In: International Conference on Machine Learning, pp. 7220–7229. PMLR (2020)
19. Park, J.J., Florence, P., Straub, J., Newcombe, R., Lovegrove, S.: DeepSDF: learning continuous signed distance functions for shape representation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 165–174 (2019)
20. Patil, A.G., Ben-Eliezer, O., Perel, O., Averbuch-Elor, H.: READ: recursive autoencoders for document layout generation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops, pp. 544–545 (2020)
21. Rauch, N., Harders, M.: Interactive synthesis of 3D geometries of blood vessels. In: Theisel, H., Wimmer, M. (eds.) Eurographics 2021 - Short Papers. The Eurographics Association (2021)
22. Schneider, M., Reichold, J., Weber, B., Székely, G., Hirsch, S.: Tissue metabolism driven arterial tree generation. Med. Image Anal. 16(7), 1397–1414 (2012)
23. Socher, R.: Recursive deep learning for natural language processing and computer vision. Stanford University (2014)
24. Socher, R., Lin, C.C., Manning, C., Ng, A.Y.: Parsing natural scenes and natural language with recursive neural networks. In: Proceedings of the 28th International Conference on Machine Learning (ICML-11), pp. 129–136 (2011)
25. Talou, G.D.M., Safaei, S., Hunter, P.J., Blanco, P.J.: Adaptive constrained constructive optimisation for complex vascularisation processes. Sci. Rep. 11(1), 1–22 (2021)
26. Tetteh, G., et al.: DeepVesselNet: vessel segmentation, centerline prediction, and bifurcation detection in 3-D angiographic volumes. Front. Neurosci. 1285 (2020)
27. Wolterink, J.M., Leiner, T., Isgum, I.: Blood vessel geometry synthesis using generative adversarial networks. arXiv preprint arXiv:1804.04381 (2018)
28. Wu, J., Zhang, C., Xue, T., Freeman, B., Tenenbaum, J.: Learning a probabilistic latent space of object shapes via 3D generative-adversarial modeling. In: Advances in Neural Information Processing Systems, vol. 29 (2016)
29. Wu, J., Hu, Q., Ma, X.: Comparative study of surface modeling methods for vascular structures. Comput. Med. Imaging Graph. 37(1), 4–14 (2013)
30. Xu, M., et al.: Generative AI-empowered simulation for autonomous driving in vehicular mixed reality metaverses. arXiv preprint arXiv:2302.08418 (2023)
31. Yang, G., Huang, X., Hao, Z., Liu, M.Y., Belongie, S., Hariharan, B.: PointFlow: 3D point cloud generation with continuous normalizing flows. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4541–4550 (2019)
32. Yang, X., Xia, D., Kin, T., Igarashi, T.: IntrA: 3D intracranial aneurysm dataset for deep learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2656–2666 (2020)
33. Zamir, M.: Arterial branching within the confines of fractal L-system formalism. J. Gen. Physiol. 118(3), 267–276 (2001)
Dense Transformer based Enhanced Coding Network for Unsupervised Metal Artifact Reduction
Wangduo Xie (corresponding author) and Matthew B. Blaschko
Center for Processing Speech and Images, Department of ESAT, KU Leuven, Leuven, Belgium
{wangduo.xie,matthew.blaschko}@esat.kuleuven.be
Abstract. CT images corrupted by metal artifacts have serious negative effects on clinical diagnosis. Because paired data with ground truth are difficult to collect in clinical settings, unsupervised methods for metal artifact reduction are of high interest. However, previous unsupervised methods struggle to retain the structural information of CT images while handling the non-local characteristics of metal artifacts. To address these challenges, we propose a novel Dense Transformer based Enhanced Coding Network (DTEC-Net) for unsupervised metal artifact reduction. Specifically, we introduce a Hierarchical Disentangling Encoder, supported by a high-order dense process and a transformer, to obtain densely encoded sequences with long-range correspondence. We then present a second-order disentanglement method to improve the decoding of the dense sequence. Extensive experiments and model analyses demonstrate the effectiveness of DTEC-Net, which outperforms previous state-of-the-art methods on a benchmark dataset and greatly reduces metal artifacts while restoring richer texture details.
Keywords: Metal artifact reduction · CT image restoration · Unsupervised learning · Enhanced coding

1 Introduction
CT technology can recover the internal details of the human body non-invasively and is widely used in clinical practice. However, if metal is present in the tissue, metal artifacts (MA) appear in the reconstructed CT image, corrupting the image and affecting medical diagnosis [1,6]. In light of the clinical need for MA reduction, various traditional methods [5,6,12,19] have been proposed that address the problem with interpolation and iterative optimization. As machine learning research increasingly impacts medical imaging, deep learning based methods have also been proposed for MA reduction. These methods can be roughly divided into supervised and unsupervised categories according to the degree of supervision. In the supervised category, methods [9,14,16,18] operating in the dual domain (sinogram and image domains) achieve good MA reduction performance. However, supervised learning methods are hindered by the lack of large-scale real-world data pairs consisting of "images with MA" and "images without MA" of the same region. Without such data, algorithms trained on synthetic data can over-fit the simulated pairs and then generalize poorly to clinical settings [11]. Furthermore, although sinogram data can provide additional information, it is difficult to collect in realistic settings [8,15]. Therefore, unsupervised methods that rely only on the image domain are strongly needed in practice.

For unsupervised methods in the image domain, Liao et al. [8] used Generative Adversarial Networks (GANs) [2] with unpaired data (with and without MA) to disentangle the MA from the underlying clean structure of the artifact-affected image in latent space. Although this method can separate the artifact component in the latent space, its latent features cannot represent the rich low-level information of the original input. Moreover, it is hard for its encoder to represent long-range correspondence across different regions. Consequently, the restored image loses texture details and fails to retain the structure of the CT image. In the same unsupervised setting, Lyu et al. [11] directly separate the MA component and the clean structure in image space using a CycleGAN-based method [25]. Although working in image space makes it possible to construct dual constraints, it also limits the attainable performance of the algorithm, because the image space cannot encode as much low-level information as the feature space.

Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_8.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023. H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 77–86, 2023. https://doi.org/10.1007/978-3-031-43907-0_8
Considering the importance of low-level latent features for generating the artifact-free component, we propose a novel Dense Transformer based Enhanced Coding Network (DTEC-Net) for unsupervised metal artifact reduction, which obtains low-level features with hierarchical information and maps them to a clean image space through adversarial training. DTEC-Net contains our Hierarchical Disentangling Encoder (HDE), which combines long-range correspondences obtained by a lightweight transformer with a high-order dense process to produce an enhanced coded sequence. To ease the burden of decoding this sequence, we also propose a second-order disentanglement method that completes the sequence decomposition. Extensive empirical results show that our method not only greatly reduces MA and generates high-quality images, but also surpasses competing unsupervised approaches.
2 Methodology

We design a Hierarchical Disentangling Encoder (HDE) that captures low-level sequences and enables high-performance restoration. Moreover, to reduce the burden that these complicated sequences place on the decoder group, we propose a second-order disentanglement mechanism. The intuition is shown in Fig. 1.

2.1 Hierarchical Disentangling Encoder (HDE)
As shown in Fig. 2(a), the generator of DTEC-Net consists of three encoders and four decoders. We design the HDE to act as Encoder1 for enhanced coding. Specifically, given an input image $x_a \in \mathbb{R}^{1\times H\times W}$ with MA, HDE first applies a convolution for preliminary feature extraction, producing a high-dimensional tensor $x^l_0$ with $c$ channels. Then, $x^l_0$ is encoded by three Dense Transformers for Disentanglement (DTDs) in a first-order reuse manner [4,23]. The output $x^l_i$ of the $i$-th DTD is

$$x^l_i = \begin{cases} f_{\mathrm{DTD}_i}\big(f_{\text{s-hde}}(\mathrm{cat}(x^l_{i-1}, x^l_{i-2}, \ldots, x^l_0))\big), & i = 2, \ldots, N,\\ f_{\mathrm{DTD}_i}(x^l_{i-1}), & i = 1. \end{cases} \qquad (1)$$

In Eq. (1), $f_{\text{s-hde}}$ compresses the channels of the concatenated outputs of the preceding DTDs (and of $x^l_0$), and $N$ is the total number of DTDs in the HDE. As shown in Fig. 3, HDE thus obtains the hierarchical information sequence $X^l \triangleq \{x^l_0, \ldots, x^l_N\}$ and the high-level semantic features $x^h \triangleq x^l_N$. As shown in Fig. 2(b), Encoder1 of ADN cannot characterize upstream low-level information, which limits its performance. With HDE, the upstream part of DTEC-Net's Encoder1 represents rich low-level information, which is encoded efficiently as described in Eq. (1). After generating the enhanced coding sequence $X^l$ with long-range correspondence and densely reused information, DTEC-Net decodes it back to the clean image domain using the proposed second-order disentanglement for MA reduction, which largely reduces the burden on the decoder group.

Fig. 1. (a) CT image with metal artifacts. (b) Blue/Orange arrows: Reuse of low-level features/Long-range correspondence. (c) Output of our DTEC-Net. (d) Ground truth. (Color figure online)
Fig. 2. (a) Generator of DTEC-Net. $x_a$: input with artifacts. $y_c$: unpaired input without artifacts. $X^l$ and $x^h$ are defined in Sect. 2.1. $x_m$: the MA parts in latent space. $x_s$: the overall "structural part" in latent space. $x_c$ (or $y_a$): the output after removing (or adding) the artifacts. $\hat{x}_a$ (or $\hat{y}_c$): the output of the identity map. (b) Generator of ADN [8]. The data relationship is shown in [8]. In addition to the difference in disentanglement, DTEC-Net and ADN also have different inner structures.
Fig. 3. The architecture of Hierarchical Disentangling Encoder (HDE).
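The first-order dense reuse of Eq. (1) can be sketched in a few lines of framework-free Python. This is a minimal sketch under stated assumptions, not the authors' implementation: a feature map is reduced to a list with one value per channel, each `dtd` callable is a hypothetical stand-in for a Dense Transformer for Disentanglement, and the hypothetical `mean_compress` stands in for the channel compression $f_{\text{s-hde}}$ (in a real network this would typically be a learned convolution).

```python
from typing import Callable, List, Sequence

Feature = List[float]  # toy feature map: one value per channel


def mean_compress(concat: Feature, channels: int) -> Feature:
    """Hypothetical stand-in for f_s-hde: average the k stacked c-channel blocks."""
    k = len(concat) // channels
    return [sum(concat[b * channels + ch] for b in range(k)) / k
            for ch in range(channels)]


def hde_forward(x0: Feature,
                dtds: Sequence[Callable[[Feature], Feature]]) -> List[Feature]:
    """Return the hierarchical sequence X^l = [x_0^l, ..., x_N^l] following Eq. (1)."""
    channels = len(x0)
    seq = [x0]
    for i, dtd in enumerate(dtds, start=1):
        if i == 1:
            block_input = seq[-1]  # i = 1: DTD_1 receives x_0^l directly
        else:
            # i >= 2: cat(x_{i-1}^l, ..., x_0^l) along channels, then compress to c channels
            concat = [v for x in reversed(seq) for v in x]
            block_input = mean_compress(concat, channels)
        seq.append(dtd(block_input))
    return seq


if __name__ == "__main__":
    def add_one(x: Feature) -> Feature:  # hypothetical DTD
        return [v + 1.0 for v in x]

    X_l = hde_forward([1.0, 2.0], [add_one] * 3)  # N = 3 DTDs, as in the text
    x_h = X_l[-1]  # high-level semantic feature x^h = x_N^l
    print(len(X_l), x_h)
```

The input to each later DTD is built from every upstream output, so low-level features from $x^l_0$ remain available deep in the encoder; the compression step keeps the channel count fixed at $c$.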
2.2 Dense Transformer for Disentanglement (DTD)
In addition to the first-order feature multiplexing of Eq. (1), HDE uses the DTD to enable second-order feature reuse. The relationship between HDE and DTD is shown in Fig. 3. Inspired by [7,20], DTD first uses a lightweight transformer based on the Swin transformer [7,10] to represent content-based information with long-range correspondence inside every partition window. It then performs in-depth extraction and second-order reuse. Specifically, the input $x_1 \in \mathbb{R}^{C\times H\times W}$ of the DTD is processed sequentially by the lightweight transformer and by a group of convolutions with second-order dense connections. The output $x_{j+1}$ of the $j$-th convolution with ReLU in this densely connected group is

$$x_{j+1} = \begin{cases} f_c^j\big(\mathrm{cat}(x_1, x_2, \ldots, x_j)\big), & j = 2, 3, \ldots, J,\\ f_c^j\big(f_{\text{transformer-light}}(x_j)\big), & j = 1. \end{cases} \qquad (2)$$
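The connectivity of Eq. (2) can be illustrated in the same framework-free style, again only as a sketch under assumptions: each feature is a single float, and `transformer_light` and the `convs` callables are hypothetical stand-ins for $f_{\text{transformer-light}}$ and $f_c^j$ (convolution plus ReLU). $J = 6$ follows the text.

```python
from typing import Callable, List, Sequence


def dtd_dense_body(x1: float,
                   transformer_light: Callable[[float], float],
                   convs: Sequence[Callable[[List[float]], float]]) -> List[float]:
    """Return [x_1, ..., x_{J+1}] with the second-order dense pattern of Eq. (2)."""
    xs = [x1]
    for j, conv in enumerate(convs, start=1):
        if j == 1:
            # j = 1: the first convolution sees only the transformer output
            xs.append(conv([transformer_light(xs[-1])]))
        else:
            # j >= 2: the j-th convolution sees cat(x_1, ..., x_j)
            xs.append(conv(list(xs)))
    return xs


if __name__ == "__main__":
    def relu_sum(inputs: List[float]) -> float:  # hypothetical conv + ReLU
        return max(0.0, sum(inputs))

    out = dtd_dense_body(1.0, lambda x: 2.0 * x, [relu_sum] * 6)  # J = 6
    print(out[-1])  # x_{J+1}, the tensor that is then filtered by Eq. (3)
```

Unlike Eq. (1), there is no channel compression between steps here, so the input of the $j$-th convolution grows with $j$, which is how DTD densely reuses both the DTD input and the intermediate features.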
In Eq. (2), $f_c^j$ denotes the $j$-th convolution with ReLU after the lightweight transformer, and $J$ is the total number of such convolutions, empirically set to six. Dense connections effectively reuse low-level features [22,23], so a latent space that includes these features helps the decoder restore clean images without metal artifacts. Because low-level information on different channels is of different importance to the final restoration task, we use the channel attention mechanism [3] to filter the output of the final convolution layer:

$$x_{\text{out}} = x_{J+1} \odot f_{\text{MLP}}\big(f_{\text{pooling}}(x_1)\big) + x_1, \qquad (3)$$

where $\odot$ denotes the Hadamard product, $f_{\text{MLP}}$ is a multi-layer perceptron with a single hidden layer, and $f_{\text{pooling}}$ is global pooling.

Because transformers usually require a large amount of training data and CT image datasets are usually smaller than natural-image datasets, we make the Swin transformer lightweight. Specifically, for an input tensor $x \in \mathbb{R}^{C\times H\times W}$ of the lightweight transformer, the number of channels is first reduced from $C$ to $C_{\text{in}}$ to lighten the burden of the attention matrix. Then, a residual block extracts information with low redundancy. After this lightweight handling, the tensor is partitioned into multiple local windows and flattened to $x_{\text{in}} \in \mathbb{R}^{\frac{HW}{P^2}\times P^2\times C_{\text{in}}}$ according to the pre-operation of the Swin transformer [7,10], where $P\times P$ is the window size, as shown in Fig. 1(b). The attention matrix of the $i$-th window is then computed by pairwise multiplication between the converted vectors in $S_i \triangleq \{x_{\text{in}}(i, j, :) \mid j = 0, 1, \ldots, P^2-1\}$. Specifically, linear maps from $\mathbb{R}^{C_{\text{in}}}$ to $\mathbb{R}^{C_a}$ applied to every vector in $S_i$ yield the query, key, and value $Q, K, V \in \mathbb{R}^{P^2\times C_a}$. The attention for each window is then obtained as [10]:

$$\mathrm{Attention}(Q, K, V) = \mathrm{SoftMax}\big(QK^{T}/\sqrt{C_a}\big)V. \qquad (4)$$

In practice, we replace single-head attention with window-based multi-head self-attention (MSA) [7,10,13] because it improves performance [7]. The output of the Swin transformer layer is unflattened and passed through post-processing (POP), which consists of a classic convolution and layer normalization (LN) with flatten and unflatten operations. After POP, the channels of the lightweight tensor are expanded back to $C$, and the result is added to the original input $x$ as a residual.
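The window partitioning and the per-window attention of Eq. (4) can be illustrated with a pure-Python sketch. It is a simplification under stated assumptions, not the authors' code: it uses a single head, stores tokens channel-last ($H\times W\times C_{\text{in}}$), omits the relative position bias and shifted windows of the full Swin transformer [10], and takes the projection matrices `Wq`, `Wk`, `Wv` as hypothetical inputs rather than learned weights.

```python
import math
from typing import List

Matrix = List[List[float]]


def window_partition(x: List[List[List[float]]], P: int) -> List[Matrix]:
    """Split an H x W x C_in map into HW/P^2 windows, each flattened to P^2 x C_in."""
    H, W = len(x), len(x[0])
    assert H % P == 0 and W % P == 0, "H and W must be multiples of the window size P"
    windows = []
    for top in range(0, H, P):
        for left in range(0, W, P):
            windows.append([x[top + a][left + b] for a in range(P) for b in range(P)])
    return windows


def matmul(A: Matrix, B: Matrix) -> Matrix:
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def softmax_rows(M: Matrix) -> Matrix:
    out = []
    for row in M:
        m = max(row)  # subtract the row max for numerical stability
        e = [math.exp(v - m) for v in row]
        s = sum(e)
        out.append([v / s for v in e])
    return out


def window_attention(S: Matrix, Wq: Matrix, Wk: Matrix, Wv: Matrix) -> Matrix:
    """Single-head Eq. (4) for one window S (P^2 x C_in); W* are C_in x C_a maps."""
    Q, K, V = matmul(S, Wq), matmul(S, Wk), matmul(S, Wv)
    scale = math.sqrt(len(Wq[0]))  # sqrt(C_a)
    K_T = [list(col) for col in zip(*K)]
    scores = [[v / scale for v in row] for row in matmul(Q, K_T)]
    return matmul(softmax_rows(scores), V)


if __name__ == "__main__":
    H = W = 4
    P = 2
    x = [[[float(i * W + j)] for j in range(W)] for i in range(H)]  # C_in = 1
    wins = window_partition(x, P)
    print(len(wins), wins[0])  # 4 [[0.0], [1.0], [4.0], [5.0]]
    # With zero query/key maps every token attends uniformly, so each output
    # row is the mean value of the window: (0 + 1 + 4 + 5) / 4 = 2.5.
    print(window_attention(wins[0], [[0.0]], [[0.0]], [[1.0]]))
```

Each window produces only a $P^2\times P^2$ attention matrix, so the total cost grows with $HW\cdot P^2$ rather than $(HW)^2$, which is one reason window attention suits the lightweight design described above.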
2.3 Second-Order Disentanglement for MA Reduction (SOD-MAR)
As mentioned in Sect. 2.1, $X^l$ is the hierarchical sequence that facilitates the generator's representation. However, $X^l$ must be decoded by a high-capacity decoder that matches the encoder. In traditional first-order disentanglement learning [8], Decoder2 does not directly participate in the restoration branch, yet it already carries the complicated artifact part $x_m$. To reduce the burden on the decoder group, we therefore propose and analyze SOD-MAR. Specifically, Decoder2 of DTEC-Net does not decode the sequence $X^l$; it only decodes the combination of the second-order disentangled information $x^h \in X^l$ and the latent feature $x_m$ representing the artifact parts, as shown in Fig. 2(a). To complete this process, Decoder2 uses the structure shown in Fig. 4, which Decoder1 also uses to decode the sequence $X^l$. Note that we do not reduce the burden by feeding only $x^h$ into Decoder1 and Decoder2 and dropping $X^l \setminus \{x^h\}$, because the low-level information in $X^l \setminus \{x^h\}$ is vital for restoring artifact-free images. Furthermore, $x^h$ is disturbed by noise from $x_a$, the target that Decoder2 approaches, whereas the information in $X^l \setminus \{x^h\}$ from upstream of the HDE can counteract this disturbance to a certain extent. The reason is that the upstream parameters receive smaller updates than the downstream parameters.

Fig. 4. The architecture of Decoder1 (Decoder2). ($x_m$*) represents the Decoder2 case.

2.4 Loss Function
Following [8], we use discriminators $D_0$ and $D_1$ to constrain the outputs $x_c$ and $y_a$:

$$\mathcal{L}_{\text{adv}} = \mathbb{E}\big[\log(1 - D_0(x_c)) + \log D_0(y_c)\big] + \mathbb{E}\big[\log D_1(x_a) + \log(1 - D_1(y_a))\big]. \qquad (5)$$

Here $x_a$ and $y_c$ are the inputs shown in Fig. 2(a). Also following [8], we use the reconstruction loss $\mathcal{L}_{\text{rec}}$ to constrain the identity map, together with the artifact consistency loss $\mathcal{L}_{\text{art}}$ and the self-reduction loss $\mathcal{L}_{\text{self}}$ to control the optimization process. The coefficients of these losses are set as in [8].
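Eq. (5) can be evaluated empirically from discriminator outputs. The sketch below assumes, following the usual GAN convention rather than anything stated in the text, that $D_0$ and $D_1$ output probabilities in (0, 1), that each expectation is replaced by a mini-batch mean, and that the discriminators maximize this value while the generator minimizes it. The `eps` clamp is an added numerical safeguard, not part of Eq. (5).

```python
import math
from typing import Sequence


def _mean_log(values: Sequence[float], eps: float = 1e-7) -> float:
    """Mini-batch estimate of E[log v], clamping v into [eps, 1 - eps]."""
    return sum(math.log(min(max(v, eps), 1.0 - eps)) for v in values) / len(values)


def adversarial_loss(d0_xc: Sequence[float], d0_yc: Sequence[float],
                     d1_xa: Sequence[float], d1_ya: Sequence[float]) -> float:
    """L_adv of Eq. (5) from discriminator probabilities on each mini-batch.

    d0_xc: D_0 on generated clean images x_c (fake for D_0)
    d0_yc: D_0 on real clean images y_c
    d1_xa: D_1 on real artifact images x_a
    d1_ya: D_1 on generated artifact images y_a (fake for D_1)
    """
    term0 = _mean_log([1.0 - p for p in d0_xc]) + _mean_log(d0_yc)
    term1 = _mean_log(d1_xa) + _mean_log([1.0 - p for p in d1_ya])
    return term0 + term1


if __name__ == "__main__":
    # Maximally confused discriminators output 0.5 everywhere.
    print(adversarial_loss([0.5, 0.5], [0.5], [0.5], [0.5]))  # 4 * log(0.5)
```

Because the expectation of a sum is the sum of expectations, the four terms are averaged separately, which also allows the real and generated mini-batches to have different sizes.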
3 Empirical Results
Synthesized DeepLesion Dataset. Following [8], we randomly select 4186 images from DeepLesion [17] and 100 metal templates [21] to build a dataset. The simulation is consistent with [8,21]. For training, we randomly select 3986 images from DeepLesion and combine them with 90 metal templates for simulation. After simulation, the 3986 images are divided into two disjoint image sets, one with and one without MA. Random combinations of the two sets then form physically unpaired data with and without MA during training. The remaining 200 images, combined with the remaining 10 metal templates, are used for testing.
Real Clinic Dataset. We randomly combine 6165 artifact-affected images and 20729 artifact-free images from SpineWeb¹ [8] for training, and use 105 artifact-affected images from SpineWeb for testing.

¹ spineweb.digitalimaginggroup.ca.
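The unpaired Synthesized DeepLesion protocol can be sketched as an index-level split. This is illustrative only and rests on assumptions not stated in the text: images and templates are plain integer indices, the 3986 training images are split evenly between the "with MA" and "without MA" sets (the actual ratio is not given), and the simulation itself is reduced to recording which training or testing template is assigned to each image; the function name is hypothetical.

```python
import random
from typing import Dict, List, Tuple


def build_unpaired_split(seed: int = 0) -> Tuple[Dict[int, int], List[int], Dict[int, int]]:
    """Return (train_with_ma, train_without_ma, test_with_ma).

    The two *_with_ma dicts map image index -> metal template index.
    """
    rng = random.Random(seed)
    images = rng.sample(range(4186), 4186)   # shuffled indices of the 4186 images
    templates = rng.sample(range(100), 100)  # shuffled indices of the 100 templates
    train_imgs, test_imgs = images[:3986], images[3986:]
    train_tpls, test_tpls = templates[:90], templates[90:]

    # Assumption: an even split of the training images into two disjoint sets.
    half = len(train_imgs) // 2
    with_ma_ids, without_ma = train_imgs[:half], train_imgs[half:]
    train_with_ma = {i: rng.choice(train_tpls) for i in with_ma_ids}
    test_with_ma = {i: rng.choice(test_tpls) for i in test_imgs}
    return train_with_ma, without_ma, test_with_ma


if __name__ == "__main__":
    tr_ma, tr_clean, te_ma = build_unpaired_split()
    print(len(tr_ma), len(tr_clean), len(te_ma))  # 1993 1993 200
```

Keeping the two training sets disjoint is what makes the data physically unpaired: no image appears both with and without MA, and testing images and templates never occur in training.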
Implementation Details. We use peak signal-to-noise ratio (PSNR) and the structural similarity index (SSIM) to measure performance; mean squared error (MSE) is additionally reported only in the ablation experiments. For both the Synthesized DeepLesion dataset and the Real Clinic dataset, we set the batch size to 2 and trained the network with the Adam optimizer, for 77 and 60 epochs, respectively. DTEC-Net was implemented in PyTorch and trained on an Nvidia Tesla P100.

Table 1. Ablation study on Synthesized DeepLesion under different settings. ↑: higher is better; ↓: lower is better. The full model (last row) achieves the best value in every column.

Model | PSNR↑ | SSIM↑ | MSE↓
HDE with one DTD | 34.46 | 0.937 | 27.12
HDE with two DTDs | 34.71 | 0.938 | 24.96
HDE with three DTDs (only Transformer) | 34.31 | 0.936 | 27.40
HDE with three DTDs (without SOD-MAR) | 34.91 | 0.940 | 24.36
HDE with three DTDs (with SOD-MAR, Ours) | 35.11 | 0.941 | 22.89

3.1 Ablation Study
To verify the effectiveness of the proposed components, we carried out ablation experiments on Synthesized DeepLesion. The results are shown in Table 1.

The Impact of DTD in HDE. In this experiment, we change the encoding ability of HDE by changing the number of DTDs. With only one DTD in the HDE, the PSNR is 0.65 dB lower than that of our DTEC-Net with three DTDs, and the average MSE is much higher. Increasing the number of DTDs to two improves the PSNR by 0.25 dB, which already exceeds the SOTA method [11]. Further increasing the number of DTDs to three raises the PSNR and SSIM by 0.4 dB and 0.003, respectively. We finally set the number of DTDs to three as a trade-off between computation and performance. To match the different encoders and decoders and to facilitate training, we also adjust the input heads of Decoder1 to the sequence length determined by the number of DTDs.

Only Transformer in DTD. Although the transformer captures long-range correspondence better than convolutions, it lacks the multiplexing of low-level information. When we delete the second-order feature reuse pattern from every DTD in DTEC-Net and keep only the lightweight transformer, the PSNR of the degraded version is 0.8 dB lower than that of our DTEC-Net. Moreover, generative adversarial training becomes highly unstable. Thus, using only the transformer does not reduce metal artifacts well.

Removing SOD-MAR. Although SOD-MAR mainly helps by easing the burden of decoding, as discussed in Sect. 2.3, it also yields a performance gain over first-order disentanglement. When we remove SOD-MAR from DTEC-Net and let $x^h$ be the only feature decoded by Decoder1, the PSNR drops by 0.2 dB and the MSE increases by 1.47.
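For reference, PSNR and MSE for a single image are related by $\mathrm{PSNR} = 10\log_{10}(\mathrm{MAX}^2/\mathrm{MSE})$, so a fixed PSNR gap corresponds to a fixed ratio of MSEs. The sketch below computes both for flattened images. The data range `max_val` (255 here) is an assumption, since the text does not state the intensity range or normalization used, and scores averaged over a test set need not satisfy the formula exactly for the averaged MSE.

```python
import math
from typing import Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between two flattened images of equal size."""
    assert len(pred) == len(target) and len(pred) > 0
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def psnr(pred: Sequence[float], target: Sequence[float], max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB; max_val is an assumed data range."""
    err = mse(pred, target)
    if err == 0.0:
        return math.inf
    return 10.0 * math.log10(max_val ** 2 / err)


if __name__ == "__main__":
    target = [0.0, 100.0, 200.0, 255.0]
    pred = [5.0, 95.0, 205.0, 250.0]  # every pixel off by 5, so MSE = 25
    print(mse(pred, target), round(psnr(pred, target), 2))
```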
Fig. 5. Visual comparison (metal implants are colored in red; the values at the bottom are PSNR/SSIM). Our method has sharper edges and richer textures than ADN. (Color figure online)

Table 2. Quantitative comparison of different methods on Synthesized DeepLesion. Bracketed citations next to values mark the source of the reported numbers. DTEC-Net achieves the best PSNR and SSIM.

Category | Method | PSNR↑ | SSIM↑
Conventional | LI [5] | 32.00 [8] | 0.910 [8]
Supervised | CNNMAR [21] | 32.50 [8] | 0.914 [8]
Unsupervised | CycleGAN [25] | 30.80 [8] | 0.729 [8]
Unsupervised | RCN [24] | 32.98 [11] | 0.918 [11]
Unsupervised | ADN [8] | 33.60 [8] | 0.924 [8]
Unsupervised | U-DuDoNet [11] | 34.54 [11] | 0.934 [11]
Unsupervised | DTEC-Net (Ours) | 35.11 | 0.941

3.2 Comparison to State-of-the-Art (SOTA)
For a fair comparison, we mainly compare with SOTA methods under unsupervised settings: ADN [8], U-DuDoNet [11], RCN [24], and CycleGAN [25]. We also compare with the traditional method LI [5] and the classical supervised method CNNMAR [21]. The quantitative results of ADN, CycleGAN, CNNMAR, and LI are taken from [8], and those of U-DuDoNet and RCN from [11]. Because ADN's code is publicly available, we ran it to obtain qualitative results.
Quantitative Results. As shown in Table 2, on the Synthesized DeepLesion dataset our method achieves the highest PSNR and SSIM, outperforming the baseline ADN by 1.51 dB in PSNR and 0.017 in SSIM. It also exceeds the SOTA method U-DuDoNet by 0.57 dB in PSNR. For the Real Clinic dataset, numerical results cannot be calculated because no ground truth exists; we present qualitative results in the appendix. Furthermore, because our method operates only in the image domain, it has the potential to be easily applied in clinical practice.
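The margins quoted above follow directly from Table 2. A few lines of Python reproduce them from the tabulated values, rounded to the table's precision (the dictionary and function names are illustrative only).

```python
from typing import Tuple

# PSNR (dB) / SSIM on Synthesized DeepLesion, copied from Table 2.
RESULTS = {
    "ADN": (33.60, 0.924),
    "U-DuDoNet": (34.54, 0.934),
    "DTEC-Net": (35.11, 0.941),
}


def margin(ours: str, baseline: str) -> Tuple[float, float]:
    """Return (PSNR gain in dB, SSIM gain) of `ours` over `baseline`, rounded."""
    (p1, s1), (p0, s0) = RESULTS[ours], RESULTS[baseline]
    return round(p1 - p0, 2), round(s1 - s0, 3)


if __name__ == "__main__":
    print(margin("DTEC-Net", "ADN"))        # (1.51, 0.017)
    print(margin("DTEC-Net", "U-DuDoNet"))  # (0.57, 0.007)
```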
Qualitative Results. A visual comparison is shown in Fig. 5. Our method not only greatly reduces artifacts, but also produces sharper edges and richer textures than ADN. More results are shown in the appendix.
4 Conclusion
In this paper, we proposed a Dense Transformer based Enhanced Coding Network (DTEC-Net) for unsupervised metal artifact reduction. In DTEC-Net, we developed a Hierarchical Disentangling Encoder (HDE) to represent long-range correspondence and produce an enhanced coding sequence. Using this sequence, DTEC-Net better recovers low-level characteristics. In addition, to decrease the burden of decoding, we designed a Second-order Disentanglement for MA Reduction (SOD-MAR) to complete the sequence decomposition. Extensive quantitative and qualitative experiments demonstrate the effectiveness of DTEC-Net and show that it outperforms other SOTA methods.
Acknowledgements. This research work was undertaken in the context of Horizon 2020 MSCA ETN project "xCTing" (Project ID: 956172).
References
1. Barrett, J.F., Keat, N.: Artifacts in CT: recognition and avoidance. Radiographics 24(6), 1679–1691 (2004)
2. Goodfellow, I., et al.: Generative adversarial networks. Commun. ACM 63(11), 139–144 (2020)
3. Hu, J., Shen, L., Sun, G.: Squeeze-and-excitation networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7132–7141 (2018)
4. Huang, G., Liu, Z., Van Der Maaten, L., Weinberger, K.Q.: Densely connected convolutional networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4700–4708 (2017)
5. Kalender, W.A., Hebel, R., Ebersberger, J.: Reduction of CT artifacts caused by metallic implants. Radiology 164(2), 576–577 (1987)
6. Lemmens, C., Faul, D., Nuyts, J.: Suppression of metal artifacts in CT using a reconstruction procedure that combines map and projection completion. IEEE Trans. Med. Imaging 28(2), 250–260 (2008)
7. Liang, J., Cao, J., Sun, G., Zhang, K., Van Gool, L., Timofte, R.: SwinIR: image restoration using swin transformer. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 1833–1844 (2021)
8. Liao, H., Lin, W.A., Zhou, S.K., Luo, J.: ADN: artifact disentanglement network for unsupervised metal artifact reduction. IEEE Trans. Med. Imaging 39(3), 634–643 (2019)
9. Lin, W.A., et al.: Dudonet: dual domain network for CT metal artifact reduction. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10512–10521 (2019)
10. Liu, Z., et al.: Swin transformer: hierarchical vision transformer using shifted windows. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10012–10022 (2021)
11. Lyu, Y., Fu, J., Peng, C., Zhou, S.K.: U-DuDoNet: unpaired dual-domain network for CT metal artifact reduction. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12906, pp. 296–306. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87231-1_29
12. Meyer, E., Raupach, R., Lell, M., Schmidt, B., Kachelrieß, M.: Normalized metal artifact reduction (NMAR) in computed tomography. Med. Phys. 37(10), 5482–5493 (2010)
13. Vaswani, A., et al.: Attention is all you need. In: Advances in Neural Information Processing Systems, vol. 30 (2017)
14. Wang, H., Li, Y., Zhang, H., Meng, D., Zheng, Y.: Indudonet+: a deep unfolding dual domain network for metal artifact reduction in CT images. Med. Image Anal. 85, 102729 (2023)
15. Wang, H., Xie, Q., Li, Y., Huang, Y., Meng, D., Zheng, Y.: Orientation-shared convolution representation for CT metal artifact learning. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13436, pp. 665–675. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16446-0_63
16. Wang, T., et al.: Dual-domain adaptive-scaling non-local network for CT metal artifact reduction. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12906, pp. 243–253. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87231-1_24
17. Yan, K., Wang, X., Lu, L., Summers, R.M.: Deeplesion: automated mining of large-scale lesion annotations and universal lesion detection with deep learning. J. Med. Imaging 5(3), 036501–036501 (2018)
18. Yu, L., Zhang, Z., Li, X., Xing, L.: Deep sinogram completion with image prior for metal artifact reduction in CT images. IEEE Trans. Med. Imaging 40(1), 228–238 (2020)
19. Zhang, H., Wang, L., Li, L., Cai, A., Hu, G., Yan, B.: Iterative metal artifact reduction for x-ray computed tomography using unmatched projector/backprojector pairs. Med. Phys. 43(6Part1), 3019–3033 (2016)
20. Zhang, J., Zhang, Y., Gu, J., Zhang, Y., Kong, L., Yuan, X.: Accurate image restoration with attention retractable transformer. arXiv preprint arXiv:2210.01427 (2022)
21. Zhang, Y., Yu, H.: Convolutional neural network based metal artifact reduction in x-ray computed tomography. IEEE Trans. Med. Imaging 37(6), 1370–1381 (2018)
22. Zhang, Y., Li, K., Li, K., Wang, L., Zhong, B., Fu, Y.: Image super-resolution using very deep residual channel attention networks. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11211, pp. 294–310. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01234-2_18
23. Zhang, Y., Tian, Y., Kong, Y., Zhong, B., Fu, Y.: Residual dense network for image super-resolution. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2472–2481 (2018)
24. Zhao, B., Li, J., Ren, Q., Zhong, Y.: Unsupervised reused convolutional network for metal artifact reduction. In: Yang, H., Pasupa, K., Leung, A.C.-S., Kwok, J.T., Chan, J.H., King, I. (eds.) ICONIP 2020. CCIS, vol. 1332, pp. 589–596. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-63820-7_67
25. Zhu, J.Y., Park, T., Isola, P., Efros, A.A.: Unpaired image-to-image translation using cycle-consistent adversarial networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2223–2232 (2017)
Multi-scale Cross-restoration Framework for Electrocardiogram Anomaly Detection
Aofan Jiang¹˒², Chaoqin Huang¹˒²˒⁴, Qing Cao³, Shuang Wu³, Zi Zeng³, Kang Chen³, Ya Zhang¹˒², and Yanfeng Wang¹˒² (corresponding author)
¹ Shanghai Jiao Tong University, Shanghai, China {stillunnamed,huangchaoqin,ya_zhang,wangyanfeng}@sjtu.edu.cn
² Shanghai AI Laboratory, Shanghai, China
³ Ruijin Hospital, Shanghai Jiao Tong University School of Medicine, Shanghai, China {cq30553,ck11208}@rjh.com.cn, {shuang-renata,zengzidoct}@sjtu.edu.cn
⁴ National University of Singapore, Singapore, Singapore
Abstract. The electrocardiogram (ECG) is a widely used diagnostic tool for detecting heart conditions. Rare cardiac diseases may be underdiagnosed by traditional ECG analysis, since no training dataset can cover all possible cardiac disorders. This paper proposes using anomaly detection, trained solely on normal ECGs, to identify any unhealthy status. However, detecting anomalies in ECGs is challenging because of significant inter-individual differences and because anomalies appear in both global rhythm and local morphology. To address this challenge, this paper introduces a novel multi-scale cross-restoration framework for ECG anomaly detection and localization that considers both local and global ECG characteristics. The proposed framework employs a two-branch autoencoder to facilitate multi-scale feature learning through a masking and restoration process, with one branch focusing on global features from the entire ECG and the other on local features from heartbeat-level details, mimicking the diagnostic process of cardiologists. Anomalies are identified by their high restoration errors. To evaluate performance on a large number of individuals, this paper introduces a new challenging benchmark with signal point-level ground truths annotated by experienced cardiologists. The proposed method demonstrates state-of-the-art performance on this benchmark and on two other well-known ECG datasets. The benchmark dataset and source code are available at: https://github.com/MediaBrain-SJTU/ECGAD
Keywords: Anomaly Detection · Electrocardiogram
A. Jiang and C. Huang contributed equally.
Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_9.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023. H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 87–97, 2023. https://doi.org/10.1007/978-3-031-43907-0_9

1 Introduction
The electrocardiogram (ECG) is a monitoring tool widely used to evaluate patients' heart status and to provide information on cardiac electrophysiology. Given the importance of ECGs in medical diagnosis and the need to ease clinicians' workload, developing automated systems that can detect and identify abnormal signals is crucial. However, given the diversity and rarity of cardiac diseases, a classifier trained on labeled ECGs of specific diseases may fail to recognize abnormal statuses not encountered during training [8,16,23]. In contrast, anomaly detection, which is trained only on normal healthy data, can identify any potential abnormal status and thus avoid missing rare cardiac diseases [10,17,21]. Current anomaly detection techniques, including one-class discriminative approaches [2,14], reconstruction-based approaches [15,30], and self-supervised learning-based approaches [3,26], all assume that models trained solely on normal data will struggle to process anomalous data, so that a substantial drop in performance indicates an anomaly. While anomaly detection has been widely used in the medical field to analyze medical images [12,24] and time-series data [18,29], detecting anomalies in ECG data is particularly challenging because of substantial inter-individual differences and the presence of anomalies in both global rhythm and local morphology. So far, few studies have investigated anomaly detection in ECG [11,29]. TSL [29] uses expert knowledge-guided amplitude- and frequency-based data transformations to simulate anomalies for different individuals. BeatGAN [11] employs a generative adversarial network to reconstruct normalized heartbeats separately instead of the entire raw ECG signal. While BeatGAN alleviates individual differences, it neglects the important global rhythm information of the ECG.
This paper proposes a novel multi-scale cross-restoration framework for ECG anomaly detection and localization. To the best of our knowledge, this is the first work to integrate both local and global characteristics for ECG anomaly detection. To take multi-scale data into account, the framework adopts a two-branch autoencoder architecture, with one branch focusing on global features from the entire ECG and the other on local features from heartbeat-level details. A multi-scale cross-attention module is introduced, which learns to combine the two feature types for the final prediction. This module imitates the diagnostic process of experienced cardiologists, who carefully examine both the entire ECG and individual heartbeats to detect abnormalities in both the overall rhythm and the specific local morphology of the signal [7]. Each branch employs a masking and restoration strategy, i.e., the model learns to perform temporally dependent signal inpainting from the adjacent unmasked regions of a specific individual. Such context-aware restoration makes the restoration less susceptible to individual differences. During testing, anomalies are identified as samples or regions with high restoration errors. To comprehensively evaluate the proposed method on a large number of individuals, we adopt the public PTB-XL database [22], which has only patient-level diagnoses, and ask experienced cardiologists to provide signal point-level localization annotations. The resulting dataset is introduced as a large-scale, challenging benchmark for ECG anomaly detection and localization. The proposed method is evaluated on this benchmark as well as on two traditional ECG anomaly detection benchmarks [6,13]. The experimental results show that the proposed method outperforms several state-of-the-art methods for both anomaly detection and localization, highlighting its potential for real-world clinical diagnosis.

Fig. 1. The multi-scale cross-restoration framework for ECG anomaly detection.
2 Method

In this paper, we focus on unsupervised anomaly detection and localization on ECGs, training only on normal ECG data. Formally, given a set of $N$ normal ECGs $\{x_i, i = 1, \ldots, N\}$, where $x_i \in \mathbb{R}^D$ is the vectorized representation of the $i$-th ECG consisting of $D$ signal points, the objective is to train a model that can identify whether a new ECG is normal or anomalous and localize the anomalous regions in abnormal ECGs.

2.1 Multi-scale Cross-restoration
In Fig. 1, we present an overview of our two-branch framework for ECG anomaly detection. One branch is responsible for learning global ECG features, while the other focuses on local heartbeat details. Our framework comprises four main components: (i) masking and encoding, (ii) multi-scale cross-attention module, (iii) uncertainty-aware restoration, and (iv) trend generation module. We provide detailed explanations of each of these components in the following sections. Masking and Encoding. Given a pair consisting of a global ECG signal xg ∈ RD and a randomly selected local heartbeat xl ∈ Rd segmented from xg for
90
A. Jiang et al.
training, as shown in Fig. 1, we apply two random masks, $M_g$ and $M_l$, to $x_g$ and $x_l$, respectively. To enable multi-scale feature learning, $M_l$ masks a single consecutive small region to encourage detail restoration, while $M_g$ masks several distinct regions distributed over the whole sequence to encourage global rhythm restoration. The masked signals are processed separately by the global and local encoders, $E_g$ and $E_l$, yielding the global feature $f_g^{in} = E_g(x_g \odot M_g)$ and the local feature $f_l^{in} = E_l(x_l \odot M_l)$, where $\odot$ denotes the element-wise product.

Multi-scale Cross-attention. To capture the relationship between global and local features, we apply the self-attention mechanism [20] to the concatenation of $f_g^{in}$ and $f_l^{in}$. The attention mechanism is expressed as

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V,$$

where $Q$, $K$, $V$ are identical input terms and $\sqrt{d_k}$, the square root of the feature dimension $d_k$, is used as a scaling factor. Self-attention is obtained by setting $Q = K = V = \mathrm{concat}(f_g^{in}, f_l^{in})$. The resulting cross-attention feature $f_{ca}$ dynamically weighs the importance of each element in the combined feature. To obtain the final global and local outputs, $f_g^{out}$ and $f_l^{out}$, which contain cross-scale information, we use residual connections: $f_g^{out} = f_g^{in} + \varphi_g(f_{ca})$ and $f_l^{out} = f_l^{in} + \varphi_l(f_{ca})$, where $\varphi_g(\cdot)$ and $\varphi_l(\cdot)$ are MLPs with two fully connected layers.
Uncertainty-Aware Restoration. Targeting signal restoration, $f_g^{out}$ and $f_l^{out}$ are decoded by two decoders, $D_g$ and $D_l$, to obtain the restored signals $\hat{x}_g$ and $\hat{x}_l$, respectively, along with the corresponding restoration uncertainty maps $\sigma_g$ and $\sigma_l$, which measure how difficult each signal point is to restore: $\hat{x}_g, \sigma_g = D_g(f_g^{out})$ and $\hat{x}_l, \sigma_l = D_l(f_l^{out})$. An uncertainty-aware restoration loss incorporates the restoration uncertainty into the loss functions:

$$L_{global} = \sum_{k=1}^{D}\left\{\frac{(x_g^k - \hat{x}_g^k)^2}{\sigma_g^k} + \log \sigma_g^k\right\}, \quad L_{local} = \sum_{k=1}^{d}\left\{\frac{(x_l^k - \hat{x}_l^k)^2}{\sigma_l^k} + \log \sigma_l^k\right\}, \quad (1)$$
where, in each function, the first term is normalized by the corresponding uncertainty, and the second term, following [12], prevents predicting a large uncertainty for all restoration pixels. The superscript $k$ denotes the $k$-th element of the signal. Note that, unlike [12], the uncertainty-aware loss is used here for restoration rather than reconstruction.

Trend Generation Module. The trend generation module (TGM) illustrated in Fig. 1 generates a smooth time-series trend $x_t \in \mathbb{R}^D$ by removing signal details; the trend is represented as the smooth difference between adjacent time-series signal points. An autoencoder ($E_t$ and $D_t$) encodes the trend information into $E_t(x_t)$, which is concatenated with the global feature $f_g^{out}$ to restore the global ECG, $\hat{x}_t = D_t(\mathrm{concat}(E_t(x_t), f_g^{out}))$. The restoration loss is defined as the squared Euclidean distance between $x_g$ and $\hat{x}_t$, $L_{trend} = \sum_{k=1}^{D}(x_g^k - \hat{x}_t^k)^2$. This process guides global feature learning with time-series trend information, emphasizing rhythm characteristics while de-emphasizing morphological details.
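A minimal sketch of the loss terms in Eq. (1) and of $L_{trend}$, assuming signals and uncertainty maps are plain Python lists and that every $\sigma^k$ is strictly positive. The function names are illustrative, not taken from the paper's code.

```python
import math

def uncertainty_aware_loss(x, x_hat, sigma):
    """Eq. (1): sum_k (x^k - x_hat^k)^2 / sigma^k + log(sigma^k); every sigma^k must be > 0."""
    return sum((a - b) ** 2 / s + math.log(s) for a, b, s in zip(x, x_hat, sigma))

def trend_loss(x_g, x_t_hat):
    """L_trend: squared Euclidean distance between x_g and the trend-guided restoration."""
    return sum((a - b) ** 2 for a, b in zip(x_g, x_t_hat))

x = [0.0, 1.0, 0.5]
x_hat = [0.1, 0.7, 0.5]
print(uncertainty_aware_loss(x, x_hat, [1.0, 1.0, 1.0]))  # about 0.1: sigma = 1 gives plain squared error
print(uncertainty_aware_loss(x, x_hat, [1.0, 2.0, 1.0]))  # error at point 2 halved, plus log(2)
```

With $\sigma^k = 1$ everywhere the log term vanishes and the loss reduces to the squared error. Raising $\sigma^k$ at a hard point shrinks its first term but pays $\log \sigma^k$, which keeps the network from declaring every point uncertain. How positivity of $\sigma$ is enforced (for example, by predicting $\log \sigma$) is not stated in the text, so any such choice is an assumption.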
Multi-scale Cross-restoration Framework for ECG Anomaly Detection
Loss Function. The final loss function for optimizing our model during training is

$$L = L_{global} + \alpha L_{local} + \beta L_{trend}, \quad (2)$$

where $\alpha$ and $\beta$ are trade-off parameters weighting the loss terms. For simplicity, we adopt $\alpha = \beta = 1.0$ by default.

2.2 Anomaly Score Measurement
For each test sample $x$, each local ECG from the segmented heartbeat set $\{x_{l,m}, m = 1, \dots, M\}$ is paired with the global ECG $x_g$, one at a time, as input. The anomaly score $A(x)$ is calculated to estimate the abnormality:

$$A(x) = \sum_{k=1}^{D}\frac{(x_g^k - \hat{x}_g^k)^2}{\sigma_g^k} + \sum_{m=1}^{M}\sum_{k=1}^{d}\frac{(x_{l,m}^k - \hat{x}_{l,m}^k)^2}{\sigma_{l,m}^k} + \sum_{k=1}^{D}(x_g^k - \hat{x}_t^k)^2, \quad (3)$$
where the three terms correspond to global restoration, local restoration, and trend restoration, respectively. For localization, an anomaly score map is generated in the same way as Eq. (3), but without summing over the signal points. Relatively large anomaly scores indicate anomalies, and relatively small scores indicate normal signal.
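The anomaly score map and $A(x)$ of Eq. (3) can be sketched as below. We assume each heartbeat carries its start offset inside $x_g$ (the hypothetical `start` field), so that its local term can be added at the matching signal points. Summing the map then recovers all three sums of Eq. (3), even if beats overlap.

```python
def weighted_sq_err(x, x_hat, sigma):
    """Per-point terms (x^k - x_hat^k)^2 / sigma^k."""
    return [(a - b) ** 2 / s for a, b, s in zip(x, x_hat, sigma)]

def anomaly_score_map(x_g, xg_hat, sigma_g, xt_hat, beats):
    """Unsummed Eq. (3). `beats` holds (start, x_l, xl_hat, sigma_l) tuples, where the
    hypothetical `start` is the heartbeat's offset inside x_g."""
    score_map = [g + (a - t) ** 2
                 for g, a, t in zip(weighted_sq_err(x_g, xg_hat, sigma_g), x_g, xt_hat)]
    for start, x_l, xl_hat, sigma_l in beats:
        for k, v in enumerate(weighted_sq_err(x_l, xl_hat, sigma_l)):
            score_map[start + k] += v  # local term lands on the beat's position in x_g
    return score_map

def anomaly_score(x_g, xg_hat, sigma_g, xt_hat, beats):
    """A(x): the sum of the score map equals the three sums of Eq. (3)."""
    return sum(anomaly_score_map(x_g, xg_hat, sigma_g, xt_hat, beats))

x_g = [0.0, 0.0, 0.0, 0.0]
xg_hat = [0.0, 0.0, 1.0, 0.0]
xt_hat = [0.0, 0.0, 0.0, 2.0]
beats = [(1, [0.0, 0.0], [0.0, 1.0], [1.0, 0.5])]
score_map = anomaly_score_map(x_g, xg_hat, [1.0] * 4, xt_hat, beats)
print(score_map)                                              # [0.0, 0.0, 3.0, 4.0]
print(anomaly_score(x_g, xg_hat, [1.0] * 4, xt_hat, beats))  # 7.0
```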
3 Experiments
Datasets. Three publicly available ECG datasets are used to evaluate the proposed method: PTB-XL [22], MIT-BIH [13], and Keogh ECG [6].
– The PTB-XL database includes clinical 12-lead ECGs, each 10 s long, with only patient-level annotations. To build a new, challenging anomaly detection and localization benchmark, 8167 normal ECGs are used for training, while 912 normal and 1248 abnormal ECGs are used for testing. We provide signal point-level annotations for 400 ECGs, covering 22 different abnormality types, annotated by two experienced cardiologists. To the best of our knowledge, we are the first to explore ECG anomaly detection and localization across various patients on such a complex, large-scale database.
– For the MIT-BIH arrhythmia dataset, the ECGs of 44 patients are divided into independent heartbeats based on the annotated R-peak positions, following [11]. 62436 normal heartbeats are used for training, while 17343 normal and 9764 abnormal heartbeats are used for testing, with heartbeat-level annotations.
– The Keogh ECG dataset includes 7 ECGs from independent patients and is used to evaluate anomaly localization with signal point-level annotations. Each ECG contains one anomalous subsequence corresponding to a premature ventricular contraction, while the remaining sequence is used as normal data to train the model. The ECGs are partitioned into fixed-length sequences of length 320 by a sliding window with a stride of 40 during training and 160 during testing.
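The Keogh ECG windowing can be sketched as follows, assuming one 1-D list per recording and that an incomplete tail window is dropped. Discarding training windows that overlap the annotated anomaly is a hypothetical way to realize "the remaining sequence is used as normal data"; the text does not say how overlaps are handled.

```python
def sliding_windows(signal, length=320, stride=40):
    """Fixed-length windows; an incomplete tail window is dropped (our assumption)."""
    return [signal[s:s + length] for s in range(0, len(signal) - length + 1, stride)]

def normal_training_windows(signal, anomaly, length=320, stride=40):
    """Windows that do not overlap the anomaly span [start, end) (hypothetical handling)."""
    a0, a1 = anomaly
    return [signal[s:s + length]
            for s in range(0, len(signal) - length + 1, stride)
            if s + length <= a0 or s >= a1]

ecg = [0.0] * 1000                                    # placeholder recording of 1000 points
print(len(sliding_windows(ecg)))                      # 18 windows with stride 40
print(len(sliding_windows(ecg, stride=160)))          # 5 windows with stride 160
print(len(normal_training_windows(ecg, (400, 500))))  # 8 windows avoid the anomaly
```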
Table 1. Anomaly detection and anomaly localization results on the PTB-XL database. Results are shown as patient-level AUC for anomaly detection and signal point-level AUC for anomaly localization. The best-performing method is in bold, and the second-best is underlined.

Method | Year | Detection | Localization
DAGMM [30] | 2018 | 0.782 | 0.688
MADGAN [9] | 2019 | 0.775 | 0.708
MSCRED [27] | 2019 | 0.778 | 0.627
USAD [1] | 2020 | 0.785 | 0.683
TranAD [18] | 2022 | 0.788 | 0.685
AnoTran [25] | 2022 | 0.762 | 0.641
TSL [29] | 2022 | 0.757 | 0.509
BeatGAN [11] | 2022 | 0.799 | 0.715
Ours | 2023 | 0.860 | 0.747

Table 2. Anomaly detection results on the MIT-BIH dataset compared with state-of-the-art methods. Results are shown as AUC and F1 score for heartbeat-level classification. The best-performing method is in bold, and the second-best is underlined.

Method | Year | F1 | AUC
DAGMM [30] | 2018 | 0.677 | 0.700
USAD [1] | 2020 | 0.384 | 0.352
TranAD [18] | 2022 | 0.621 | 0.742
AnoTran [25] | 2022 | 0.650 | 0.770
TSL [29] | 2022 | 0.750 | 0.894
BeatGAN [11] | 2022 | 0.816 | 0.945
Ours | 2023 | 0.883 | 0.969
Table 3. Anomaly localization results on the Keogh ECG [6] dataset compared with several state-of-the-art methods. Results are shown as signal point-level AUC. The best-performing method is in bold, and the second-best is underlined.
Methods
Year A
B
C
D
E
F
G
Avg
DAGMM [30] MSCRED [27] MADGAN [9] USAD [1] GDN [4] CAE-M [28] TranAD [18] AnoTran [25] BeatGAN [11]
2018 2019 2019 2020 2021 2021 2022 2022 2022
0.612 0.633 0.702 0.616 0.611 0.618 0.623 0.502 0.623
0.805 0.798 0.833 0.795 0.790 0.802 0.820 0.792 0.783
0.713 0.714 0.664 0.715 0.674 0.715 0.720 0.799 0.747
0.457 0.461 0.463 0.462 0.458 0.457 0.446 0.498 0.506
0.662 0.746 0.692 0.649 0.648 0.708 0.780 0.748 0.757
0.676 0.659 0.678 0.680 0.671 0.671 0.680 0.711 0.852
0.657 0.668 0.674 0.655 0.650 0.661 0.674 0.684 0.724
Ours
2023 0.832 0.641
0.819
0.815 0.543 0.760
0.833
0.749
0.672 0.667 0.688 0.667 0.695 0.657 0.647 0.739 0.803
Evaluation Protocols. Anomaly detection and localization performance is quantified by the area under the Receiver Operating Characteristic curve (AUC); a higher AUC indicates a better method. To match the different annotation levels, we use patient-level, heartbeat-level, and signal point-level AUC for the respective settings. For heartbeat-level classification, the F1 score is also reported, following [11].

Implementation Details. The ECG is pre-processed with a Butterworth filter and a notch filter [19] to remove high-frequency noise and eliminate baseline wander. The R-peaks are detected with an adaptive threshold following [5], which
Fig. 2. Anomaly localization visualization on PTB-XL with different abnormal types. Ground truths are highlighted in red boxes on the ECG data, and anomaly localization results for each case, compared with the state-of-the-art method, are attached below. (Color figure online)
does not require any learnable parameters. The positions of the detected R-peaks are then used to segment the ECG sequence into a set of heartbeats. We use a convolutional autoencoder following the architecture proposed in [11]. The model is trained with the AdamW optimizer, an initial learning rate of 1e-4, and a weight decay of 1e-5 for 50 epochs on a single NVIDIA RTX 3090 GPU, using a single cycle of cosine learning-rate decay. The batch size is 32. During testing, the model requires 2365 MB of GPU memory and runs at 4.2 fps.
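Heartbeat segmentation from the detected R-peaks might look like the sketch below. The text does not specify how beat boundaries are chosen, so the fixed-length window centred on each R-peak and the rule of skipping beats at the recording edges are assumptions, and R-peak detection itself is taken as given. The returned start offsets record where each beat lies in the global ECG.

```python
def segment_heartbeats(signal, r_peaks, d):
    """Cut a length-d window centred on each R-peak index; beats that would run past
    either end of the recording are skipped (a hypothetical policy)."""
    half = d // 2
    beats = []
    for r in r_peaks:
        start = r - half
        if start >= 0 and start + d <= len(signal):
            beats.append((start, signal[start:start + d]))
    return beats

ecg = list(range(100))  # placeholder signal
beats = segment_heartbeats(ecg, r_peaks=[5, 30, 60, 98], d=20)
print([start for start, _ in beats])  # [20, 50]: peaks at 5 and 98 are too close to the edges
```

At test time, each returned beat would be paired with the full recording, one pair per beat, as described in Sect. 2.2.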
3.1 Comparisons with State-of-the-Art Methods
We compare our method with several time-series anomaly detection methods, including the heartbeat-level detection method BeatGAN [11], the patient-level detection method TSL [29], and several signal point-level anomaly localization methods [1,4,9,18,25,27,28,30]. For a fair comparison, we re-trained all methods under the same experimental setup. For methods originally designed only for signal point-level tasks [1,9,18,25,30], we use the mean of their anomaly localization results as their heartbeat-level or patient-level anomaly score.

Anomaly Detection. The anomaly detection performance on PTB-XL is summarized in Table 1. The proposed method achieves 86.0% AUC in patient-level anomaly detection and outperforms all baselines by a large margin (10.3%). Table 2 shows the comparison on MIT-BIH, where the proposed method achieves a heartbeat-level AUC of 96.9%, an improvement of 2.4% over the state-of-the-art BeatGAN (94.5%). Furthermore, the F1 score of the proposed method is 88.3%, 6.7% higher than that of BeatGAN (81.6%).
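The evaluation can be sketched in plain Python. `roc_auc` computes AUC as the probability that an anomalous sample scores higher than a normal one (ties count one half), and `sample_score` averages a point-level score map, as done here for the localization-only baselines. The toy data are made up. For large test sets, a rank-based routine such as scikit-learn's `roc_auc_score` avoids the quadratic pairwise loop; the same `roc_auc` gives point-level AUC if the maps and point labels are flattened.

```python
def roc_auc(scores, labels):
    """AUC = P(anomalous score > normal score), ties counting 1/2
    (the Mann-Whitney U statistic divided by n_pos * n_neg)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def sample_score(score_map):
    """Mean of point-level scores, used as the sample-level anomaly score."""
    return sum(score_map) / len(score_map)

maps = [[0.1, 0.2], [0.9, 0.8], [0.3, 0.1], [0.7, 0.9]]  # toy point-level score maps
labels = [0, 1, 0, 1]                                     # 1 = anomalous sample
print(roc_auc([sample_score(m) for m in maps], labels))   # 1.0
```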
Table 4. Ablation studies on the PTB-XL dataset. Factors under analysis are masking and restoring (MR), multi-scale cross-attention (MC), the uncertainty loss function (UL), and the trend generation module (TGM). Results are patient-level AUC (%) over five runs. The best-performing method is in bold.

MR | MC | UL | TGM | AUC
– | – | – | – | 70.4±0.3
✓ | – | – | – | 80.4±0.7
– | ✓ | – | – | 80.3±0.3
– | – | ✓ | – | 72.8±2.0
– | – | – | ✓ | 71.2±0.5
✓ | ✓ | – | – | 84.8±0.8
✓ | ✓ | ✓ | – | 85.2±0.4
✓ | ✓ | ✓ | ✓ | 86.0±0.1

Table 5. Sensitivity analysis w.r.t. the mask ratio on the PTB-XL dataset. Results are patient-level AUC (%) over five runs. The best-performing setting is in bold, and the second-best is underlined.

Mask ratio | 0% | 10% | 20% | 30% | 40% | 50% | 60% | 70%
AUC | 80.2±0.2 | 85.2±0.2 | 85.5±0.3 | 86.0±0.1 | 84.9±0.3 | 83.8±0.1 | 82.9±0.1 | 75.8±1.0
Anomaly Localization. Table 1 presents the anomaly localization results on our proposed multi-patient benchmark. The proposed method achieves a signal point-level AUC of 74.7%, outperforming all baselines (3.2% higher than BeatGAN). Note that TSL, which is not designed for localization, performs poorly on this task. Table 3 shows the signal point-level anomaly localization results for each individual in Keogh ECG. The proposed method achieves the best or second-best performance on six of the seven subsets and the highest mean AUC (74.9%, 2.5% higher than BeatGAN), indicating its effectiveness. It also shows a lower standard deviation across the seven subsets (±10.5) than TranAD (±11.3) and BeatGAN (±11.0), which indicates good generalizability across subsets.

Anomaly Localization Visualization. Fig. 2 visualizes anomaly localization on several samples from our proposed benchmark, with ground truths annotated by experienced cardiologists. Regions with higher anomaly scores are indicated by darker colors. Our method localizes various types of ECG anomalies more accurately than BeatGAN, including both periodic and episodic anomalies such as incomplete right bundle branch block and premature beats. Although our method produces narrower localization results than the ground truths, because it is highly sensitive to abrupt, unusual changes in signal values, it still highlights the areas that matter for anomaly identification, as confirmed by experienced cardiologists.
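The spread figures can be checked from the Table 3 entries themselves. The sketch below takes the seven per-subset AUCs of TranAD, BeatGAN, and our method as listed in Table 3. Their assignment to subsets A–G is not needed for a mean or a standard deviation. The population standard deviation (dividing by 7) reproduces the reported ±11.3, ±11.0, and ±10.5, together with the Avg column.

```python
from statistics import mean, pstdev

# Per-subset AUCs from Table 3; subset order does not matter for mean and std.
keogh_auc = {
    "TranAD": [0.623, 0.820, 0.720, 0.446, 0.780, 0.680, 0.647],
    "BeatGAN": [0.623, 0.783, 0.747, 0.506, 0.757, 0.852, 0.803],
    "Ours": [0.832, 0.641, 0.819, 0.815, 0.543, 0.760, 0.833],
}
for name, aucs in keogh_auc.items():
    # pstdev divides by n = 7 (population standard deviation)
    print(f"{name}: mean {mean(aucs):.3f}, std ±{100 * pstdev(aucs):.1f}")
# TranAD: mean 0.674, std ±11.3
# BeatGAN: mean 0.724, std ±11.0
# Ours: mean 0.749, std ±10.5
```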
3.2 Ablation Study and Sensitivity Analysis
Ablation studies were conducted on PTB-XL to confirm the effectiveness of individual components of the proposed method. Table 4 shows that each module
contributes positively to the overall performance of the framework. When none of the modules is employed, the method reduces to an ECG reconstruction approach with a naive L2 loss and no cross-attention across scales. Adding the MR, MC, UL, or TGM module individually to this baseline improves the AUC from 70.4% to 80.4%, 80.3%, 72.8%, and 71.2%, respectively, demonstrating the effectiveness of each module. Moreover, as the modules are added in sequence, the AUC improves step by step from 70.4% to 86.0%, highlighting their combined impact on the proposed framework.

We conduct a sensitivity analysis on the mask ratio, as shown in Table 5. Restoration with a 0% mask ratio can be regarded as reconstruction, which takes an entire sample as input and aims to output that same sample. The model's performance first improves and then declines as the mask ratio increases from 0% to 70%. This is because a low mask ratio limits the model's feature learning during restoration, while a high ratio makes the masked regions increasingly difficult to restore. There is therefore a trade-off between maximizing the model's potential and keeping the restoration difficulty reasonable. The optimal mask ratio is 30%, which achieves the highest anomaly detection result (86.0% AUC).
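To make the two masking schemes and the mask ratio concrete, here is a sketch with 1 = kept and 0 = masked. The number of global regions (`n_regions`) and the placement rule (one region per equal-length chunk, so regions are disjoint and spread over the sequence) are our assumptions. The text specifies only one consecutive region for $M_l$ and several distinct regions for $M_g$. A ratio of 0 yields an all-ones mask, that is, plain reconstruction.

```python
import random

def local_mask(d, ratio, rng):
    """M_l: one consecutive masked region covering round(ratio * d) points."""
    n = round(d * ratio)
    start = rng.randrange(0, d - n + 1)
    return [0 if start <= k < start + n else 1 for k in range(d)]

def global_mask(D, ratio, n_regions, rng):
    """M_g: n_regions disjoint masked regions, one placed at random inside each of
    n_regions equal chunks so they spread over the sequence (hypothetical scheme)."""
    seg = round(D * ratio) // n_regions  # length of each masked region
    chunk = D // n_regions
    mask = [1] * D
    for r in range(n_regions):
        start = rng.randrange(r * chunk, r * chunk + chunk - seg + 1)
        for k in range(start, start + seg):
            mask[k] = 0
    return mask

def apply_mask(x, mask):
    """Element-wise product x ⊙ M."""
    return [xi * mi for xi, mi in zip(x, mask)]

rng = random.Random(0)
m_g = global_mask(1000, 0.3, n_regions=5, rng=rng)
m_l = local_mask(200, 0.3, rng)
print(m_g.count(0), m_l.count(0))  # 300 60
```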
4 Conclusion
This paper proposes a novel framework for ECG anomaly detection in which features of the entire ECG and of local heartbeats are combined with a masking-restoration process to detect anomalies, simulating the diagnostic process of cardiologists. A challenging benchmark with signal point-level annotations provided by experienced cardiologists is proposed, facilitating future research in ECG anomaly localization. The proposed method outperforms state-of-the-art methods, highlighting its potential in real-world clinical diagnosis. Acknowledgement. This work is supported by the National Key R&D Program of China (No. 2022ZD0160702), STCSM (No. 22511106101, No. 18DZ2270700, No. 21DZ1100100), 111 plan (No. BP0719010), the Youth Science Fund of National Natural Science Foundation of China (No. 7210040772) and National Facility for Translational Medicine (Shanghai) (No. TMSK-2021-501), and State Key Laboratory of UHD Video and Audio Production and Presentation.
References
1. Audibert, J., Michiardi, P., Guyard, F., Marti, S., Zuluaga, M.A.: USAD: unsupervised anomaly detection on multivariate time series. In: Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining, pp. 3395–3404 (2020)
2. Chalapathy, R., Menon, A.K., Chawla, S.: Robust, deep and inductive anomaly detection. In: Ceci, M., Hollmén, J., Todorovski, L., Vens, C., Džeroski, S. (eds.) ECML PKDD 2017. LNCS (LNAI), vol. 10534, pp. 36–51. Springer, Cham (2017). https://doi.org/10.1007/978-3-319-71249-9_3
3. Chen, L., Bentley, P., Mori, K., Misawa, K., Fujiwara, M., Rueckert, D.: Self-supervised learning for medical image analysis using image context restoration. Med. Image Anal. 58, 101539 (2019)
4. Deng, A., Hooi, B.: Graph neural network-based anomaly detection in multivariate time series. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, pp. 4027–4035 (2021)
5. van Gent, P., Farah, H., van Nes, N., van Arem, B.: Analysing noisy driver physiology real-time using off-the-shelf sensors: Heart rate analysis software from the taking the fast lane project. J. Open Res. Softw. 7(1) (2019)
6. Keogh, E., Lin, J., Fu, A.: HOT SAX: finding the most unusual time series subsequence: algorithms and applications. In: Proceedings of the IEEE International Conference on Data Mining, pp. 440–449. Citeseer (2004)
7. Khan, M.G.: Step-by-step method for accurate electrocardiogram interpretation. In: Rapid ECG Interpretation, pp. 25–80. Humana Press, Totowa, NJ (2008)
8. Kiranyaz, S., Ince, T., Gabbouj, M.: Real-time patient-specific ECG classification by 1-D convolutional neural networks. IEEE Trans. Biomed. Eng. 63(3), 664–675 (2015)
9. Li, D., Chen, D., Jin, B., Shi, L., Goh, J., Ng, S.-K.: MAD-GAN: multivariate anomaly detection for time series data with generative adversarial networks. In: Tetko, I.V., Kůrková, V., Karpov, P., Theis, F. (eds.) ICANN 2019. LNCS, vol. 11730, pp. 703–716. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-30490-4_56
10. Li, H., Boulanger, P.: A survey of heart anomaly detection using ambulatory electrocardiogram (ECG). Sensors 20(5), 1461 (2020)
11. Liu, S., et al.: Time series anomaly detection with adversarial reconstruction networks. IEEE Trans. Knowl. Data Eng. (2022)
12. Mao, Y., Xue, F.-F., Wang, R., Zhang, J., Zheng, W.-S., Liu, H.: Abnormality detection in chest X-ray images using uncertainty prediction autoencoders. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12266, pp. 529–538. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59725-2_51
13. Moody, G.B., Mark, R.G.: The impact of the MIT-BIH arrhythmia database. IEEE Eng. Med. Biol. Mag. 20(3), 45–50 (2001)
14. Ruff, L., et al.: Deep one-class classification. In: International Conference on Machine Learning, pp. 4393–4402. PMLR (2018)
15. Schlegl, T., Seeböck, P., Waldstein, S.M., Langs, G., Schmidt-Erfurth, U.: f-AnoGAN: fast unsupervised anomaly detection with generative adversarial networks. Med. Image Anal. 54, 30–44 (2019)
16. Shaker, A.M., Tantawi, M., Shedeed, H.A., Tolba, M.F.: Generalization of convolutional neural networks for ECG classification using generative adversarial networks. IEEE Access 8, 35592–35605 (2020)
17. Shen, L., Yu, Z., Ma, Q., Kwok, J.T.: Time series anomaly detection with multiresolution ensemble decoding. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, pp. 9567–9575 (2021)
18. Tuli, S., Casale, G., Jennings, N.R.: TranAD: deep transformer networks for anomaly detection in multivariate time series data. In: International Conference on Very Large Databases 15(6), pp. 1201–1214 (2022)
19. Van Gent, P., Farah, H., Van Nes, N., Van Arem, B.: HeartPy: a novel heart rate algorithm for the analysis of noisy signals. Transport. Res. F: Traffic Psychol. Behav. 66, 368–378 (2019)
20. Vaswani, A., et al.: Attention is all you need. In: Advances in Neural Information Processing Systems 30 (2017)
21. Venkatesan, C., Karthigaikumar, P., Paul, A., Satheeskumaran, S., Kumar, R.: ECG signal preprocessing and SVM classifier-based abnormality detection in remote healthcare applications. IEEE Access 6, 9767–9773 (2018)
22. Wagner, P., et al.: PTB-XL, a large publicly available electrocardiography dataset. Sci. Data 7(1), 1–15 (2020)
23. Wang, J., et al.: Automated ECG classification using a non-local convolutional block attention module. Comput. Methods Programs Biomed. 203, 106006 (2021)
24. Wolleb, J., Bieder, F., Sandkühler, R., Cattin, P.C.: Diffusion models for medical anomaly detection. In: Medical Image Computing and Computer Assisted Intervention – MICCAI 2022, pp. 35–45. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16452-1_4
25. Xu, J., Wu, H., Wang, J., Long, M.: Anomaly transformer: time series anomaly detection with association discrepancy. In: International Conference on Learning Representations (2022)
26. Ye, F., Huang, C., Cao, J., Li, M., Zhang, Y., Lu, C.: Attribute restoration framework for anomaly detection. IEEE Trans. Multimedia 24, 116–127 (2022)
27. Zhang, C., et al.: A deep neural network for unsupervised anomaly detection and diagnosis in multivariate time series data. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, pp. 1409–1416 (2019)
28. Zhang, Y., Chen, Y., Wang, J., Pan, Z.: Unsupervised deep anomaly detection for multi-sensor time-series signals. IEEE Trans. Knowl. Data Eng. 35(2), 2118–2132 (2021)
29. Zheng, Y., Liu, Z., Mo, R., Chen, Z., Zheng, W.S., Wang, R.: Task-oriented self-supervised learning for anomaly detection in electroencephalography. In: International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 193–203. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16452-1_19
30. Zong, B., et al.: Deep autoencoding gaussian mixture model for unsupervised anomaly detection. In: International Conference on Learning Representations (2018)
Correlation-Aware Mutual Learning for Semi-supervised Medical Image Segmentation
Shengbo Gao1, Ziji Zhang2, Jiechao Ma1, Zihao Li1, and Shu Zhang1(B)
1 Deepwise AI Lab, Beijing, China
2 School of Artificial Intelligence, Beijing University of Posts and Telecommunications, Beijing, China
Abstract. Semi-supervised learning has become increasingly popular in medical image segmentation due to its ability to leverage large amounts of unlabeled data to extract additional information. However, most existing semi-supervised segmentation methods focus only on extracting information from unlabeled data, disregarding the potential of labeled data to further improve model performance. In this paper, we propose a novel Correlation Aware Mutual Learning (CAML) framework that leverages labeled data to guide the extraction of information from unlabeled data. Our approach is based on a mutual learning strategy that incorporates two modules: the Cross-sample Mutual Attention Module (CMA) and the Omni-Correlation Consistency Module (OCC). The CMA module establishes dense cross-sample correlations among a group of samples, enabling the transfer of label prior knowledge to unlabeled data. The OCC module constructs omni-correlations between the unlabeled and labeled datasets and regularizes the dual models by constraining the omni-correlation matrix of each sub-model to be consistent. Experiments on the Atrial Segmentation Challenge dataset demonstrate that our approach outperforms state-of-the-art methods, highlighting the effectiveness of our framework in medical image segmentation tasks. The code, pre-trained weights, and data are publicly available.

Keywords: Semi-supervised learning · Medical image segmentation · Mutual learning · Cross-sample correlation

1 Introduction
Despite the remarkable advancements achieved through the use of deep learning for automatic medical image segmentation, the scarcity of precisely annotated training data remains a significant obstacle to the widespread adoption of such
https://github.com/Herschel555/CAML
S. Gao and Z. Zhang contributed equally to this work. Z. Zhang: work done as an intern in Deepwise AI Lab.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 98–108, 2023. https://doi.org/10.1007/978-3-031-43907-0_10
techniques in clinical settings. As a solution, semi-supervised segmentation has been proposed to enable models to be trained with fewer annotated samples together with abundant unlabeled data. Recently, methods that adopt the co-teaching [3,11,19] or mutual learning [25] paradigm have emerged as a promising approach for semi-supervised learning. These methods adopt two simultaneously updated models, each trained to predict the outputs of its counterpart, which can be seen as a combination of consistency regularization [1,4,9,14,15] and entropy minimization [2,7,20,22,24]. In semi-supervised medical image segmentation, MC-Net [19] has shown significant improvements in segmentation performance. With the rapid advancement of semi-supervised learning, the importance of unlabeled data has received increasing attention across various disciplines in recent years. However, the role of labeled data has been largely overlooked: most semi-supervised learning techniques treat labeled-data supervision merely as an initial step of the training pipeline or as a means to ensure training convergence [6,15,26]. Recently, methods that leverage labeled data to directly guide information extraction from unlabeled data have attracted the attention of the community [16]. In semi-supervised medical image segmentation, labeled and unlabeled data share characteristics that are intuitive and instructive for the algorithm. Typically, partially labeled clinical datasets exhibit similar foreground features, including comparable texture, shape, and appearance across samples.
It can therefore be hypothesized that building a bridge across the entire training dataset to connect labeled and unlabeled data can effectively transfer prior knowledge from labeled to unlabeled data and facilitate information extraction from unlabeled data, ultimately overcoming the performance bottleneck of semi-supervised learning methods. Based on this idea, we propose a novel Correlation Aware Mutual Learning (CAML) framework that explicitly models the relationship between labeled and unlabeled data to make effective use of the labeled data. Our method incorporates two essential components, the Cross-sample Mutual Attention module (CMA) and the Omni-Correlation Consistency module (OCC), to transfer labeled-data information to unlabeled data. The CMA module establishes mutual attention among a group of samples, leading to a mutually reinforced representation of co-salient features between labeled and unlabeled data. Unlike conventional methods, in which supervised signals from labeled and unlabeled samples are back-propagated separately, the CMA module creates a new information propagation path among all pixels in a group of samples, which synchronously enhances the feature representation of each sample in the group. In addition to the CMA module, we introduce the OCC module to regularize the segmentation model by explicitly modeling the omni-correlation between unlabeled features and a group of labeled features. This is achieved by constructing a memory bank that stores labeled features as a reference set of features, or basis vectors. In each iteration, a portion of the features in the memory bank is used to calculate the omni-correlation with unlabeled features, reflecting
S. Gao et al.
Fig. 1. Overview of our proposed CAML. CAML adopts a co-teaching scheme with cross-pseudo supervision. The CMA module incorporated into the auxiliary network and the OCC module are introduced for advanced cross-sample relationship modeling.
the similarity of an unlabeled pixel with respect to a set of basis vectors of the labeled data. Finally, we constrain the omni-correlation matrix of each sub-model to be consistent, which regularizes the entire framework. With the proposed omni-correlation consistency, the labeled data features serve as anchor groups that guide the representation learning of unlabeled data features and explicitly encourage the model to learn a more unified feature distribution among unlabeled data. In summary, our contributions are threefold: (1) We propose a novel Correlation Aware Mutual Learning (CAML) framework that focuses on the efficient use of labeled data to address the challenge of semi-supervised medical image segmentation. (2) We introduce the Cross-sample Mutual Attention module (CMA) and the Omni-Correlation Consistency module (OCC) to establish cross-sample relationships directly. (3) Experimental results on a benchmark dataset demonstrate significant improvements over previous state-of-the-art methods, especially when only a small number of labeled images are available.
2 Method
2.1 Overview
Figure 1 gives an overview of CAML. We adopt a co-teaching paradigm like MC-Net [19], in which two parallel networks are each trained to predict the output of the other. To achieve efficient cross-sample relationship modeling and enable
information propagation among labeled and unlabeled data in a mini-batch, we incorporate a Cross-sample Mutual Attention module into the auxiliary segmentation network $f_a$, whereas the vanilla segmentation network $f_v$ keeps the original V-Net structure. In addition, we employ Omni-Correlation Consistency regularization to further regularize the representation learning of the unlabeled data. Both modules are detailed in the following sections. The total loss of CAML is

$$L = L_s + \lambda_c l_c + \lambda_o l_o, \quad (1)$$
where $l_o$ is the proposed omni-correlation consistency loss, while $L_s$ and $l_c$ are the supervised loss and the cross-supervised loss implemented in the Cross Pseudo Supervision (CPS) module. $\lambda_c$ and $\lambda_o$ are the weights of $l_c$ and $l_o$, respectively. During training, a batch of mixed labeled and unlabeled samples is fed into the network. The supervised loss is applied only to labeled data, while all samples are used for cross-supervised learning. Please refer to [3] for a detailed description of the CPS module and the design of $L_s$ and $l_c$.
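A minimal sketch of a CPS-style cross-supervised term $l_c$, assuming hard argmax pseudo-labels and pixel-averaged cross-entropy applied in both directions. The text defers the exact formulation to [3], so treat this as illustrative. Probabilities are plain per-pixel lists. Detaching the pseudo-labels from the gradient graph, which a framework implementation would need, is only noted in a comment.

```python
import math

def argmax(p):
    return max(range(len(p)), key=p.__getitem__)

def cps_loss(p_v, p_a):
    """Cross pseudo supervision between f_v and f_a (illustrative sketch).
    p_v, p_a: per-pixel class-probability lists from the two networks.
    Each network is supervised by the other's hard pseudo-labels; in a real
    framework the pseudo-labels would be detached from the gradient graph."""
    y_v = [argmax(p) for p in p_v]
    y_a = [argmax(p) for p in p_a]
    ce_a = sum(-math.log(p[y]) for p, y in zip(p_a, y_v))  # f_v's labels supervise f_a
    ce_v = sum(-math.log(p[y]) for p, y in zip(p_v, y_a))  # f_a's labels supervise f_v
    return (ce_a + ce_v) / len(p_v)

p_v = [[0.9, 0.1], [0.2, 0.8]]  # two pixels, two classes
p_a = [[0.6, 0.4], [0.3, 0.7]]
print(round(cps_loss(p_v, p_a), 4))
```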
2.2 Cross-Sample Mutual Attention Module
To enable information propagation between any positions of any samples in a mini-batch, one could simply treat each pixel's feature vector as a token and perform self-attention over all tokens in the mini-batch. However, this makes the computation cost prohibitively large, as the complexity of self-attention is $O(n^2)$ in the number of tokens. Instead, we adopt two sequentially mounted self-attention modules along different dimensions to enable computationally efficient mutual attention among all pixels. As illustrated in Fig. 1, the proposed CMA module consists of two sequential transformer encoder layers, $E_1$ and $E_2$, each including a multi-head attention block and an MLP block, with a layer normalization after each block. For an input feature map $a_{in} \in \mathbb{R}^{b \times c \times k}$, where $k = h \times w \times d$, $b$ is the batch size, and $c$ is the channel dimension of $a_{in}$, $E_1$ performs intra-sample self-attention over the spatial dimension of each sample. This models the information propagation paths between every pair of pixel positions within each sample. Then, to further enable information propagation among different samples, we perform inter-sample self-attention along the batch dimension. In other words, along the $b$ dimension, the pixels at the same spatial position in different samples are fed into a self-attention module to construct cross-sample relationships. In CAML, we employ the CMA module in the auxiliary segmentation network $f_a$, whereas the vanilla segmentation network $f_v$ keeps the original V-Net structure, for two reasons. From a deployment perspective, the CMA module requires a batch size larger than 1 to model attention among samples within a mini-batch, which is not applicable at inference time (batch size = 1). From a model design perspective, we give the vanilla and the auxiliary branches different
architectures to increase architectural heterogeneity, which improves performance in a mutual learning framework.
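The two-step attention of the CMA module can be sketched as follows: the first step attends over the $k$ positions of each sample, and the second attends over the $b$ samples at each position. For brevity, this sketch uses single-head attention without learned projections, MLP blocks, layer normalization, or residuals, so it shows only the information-flow pattern. `attention_pairs` compares the number of scored query-key pairs against full attention over all $b \cdot k$ tokens; $k = 1000$ in the example is hypothetical.

```python
import math

def attend(tokens):
    """Single-head softmax(Q K^T / sqrt(c)) V with Q = K = V = tokens (no projections)."""
    c = len(tokens[0])
    out = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(c) for k in tokens]
        top = max(scores)
        w = [math.exp(s - top) for s in scores]
        z = sum(w)
        out.append([sum(wi * t[ch] for wi, t in zip(w, tokens)) / z for ch in range(c)])
    return out

def cma(a):
    """a[i][p]: c-dim feature of sample i at flattened position p (k = h*w*d positions)."""
    b, k = len(a), len(a[0])
    intra = [attend(sample) for sample in a]  # step 1: over the k positions of each sample
    out = [[None] * k for _ in range(b)]
    for p in range(k):                        # step 2: over the b samples at position p
        mixed = attend([intra[i][p] for i in range(b)])
        for i in range(b):
            out[i][p] = mixed[i]
    return out

def attention_pairs(b, k):
    """Query-key pairs scored: full attention over b*k tokens vs. the two-step scheme."""
    return (b * k) ** 2, b * k * k + k * b * b

print(attention_pairs(4, 1000))  # (16000000, 4016000)
```

In this simplified sketch, with $b = 1$ the inter-sample step attends over a single token and returns it unchanged, so cross-sample mixing only happens when a mini-batch holds several samples.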
2.3 Omni-Correlation Consistency Regularization
In this section, we introduce Omni-Correlation Consistency (OCC) as an additional model regularization. The core of the OCC module is the omni-correlation. This is a similarity matrix computed between the features of unlabeled pixels and a group of prototype features sampled from labeled instance features. It describes how similar an unlabeled pixel is to each pixel in a set of labeled reference pixels. During training, we explicitly constrain the omni-correlations computed from the heterogeneous unlabeled features of the two separate branches to remain the same. In practice, we use an omni-correlation matrix to formulate the similarity distribution between unlabeled features and prototype features.

Let $g_v$ and $g_a$ denote two projection heads attached to the backbones of $f_v$ and $f_a$, respectively. Let $z_v \in \mathbb{R}^{m \times c}$ and $z_a \in \mathbb{R}^{m \times c}$ be two sets of embeddings sampled from their projected features of unlabeled samples, where $m$ is the number of sampled features and $c$ is the dimension of the projected features. Note that $z_v$ and $z_a$ are sampled at the same set of positions on the unlabeled samples. Suppose $z_p \in \mathbb{R}^{n \times c}$ is a set of prototype embeddings sampled from labeled instances, where $n$ is the number of sampled prototype features. The omni-correlation matrix between $z_v$ and $z_p$ is then computed as:

$$sim_{vp}^{i} = \frac{\exp\big(\cos(z_v, z_p^{i}) \cdot t\big)}{\sum_{j=1}^{n} \exp\big(\cos(z_v, z_p^{j}) \cdot t\big)}, \quad i \in \{1, \dots, n\} \qquad (2)$$
where $\cos$ denotes cosine similarity and $t$ is the temperature hyperparameter. The resulting omni-correlation matrix is $sim_{vp} \in \mathbb{R}^{m \times n}$. Similarly, the similarity distribution $sim_{ap}$ between $z_a$ and $z_p$ is obtained by replacing $z_v$ with $z_a$. To enforce consistency of the omni-correlation between the two branches, the omni-correlation consistency regularization is formulated with the cross-entropy loss $l_{ce}$ as follows:

$$l_o = \frac{1}{m}\, l_{ce}(sim_{vp}, sim_{ap}) \qquad (3)$$

Memory Bank Construction. We use a memory bank $T$ to iteratively update the prototype embeddings for OCC computation. Specifically, $T$ initializes $N$ slots for each labeled training sample. It updates the prototype embeddings with filtered labeled features projected by $g_v$ and $g_a$. To ensure that the features stored in $T$ are reliable, we select embeddings at positions where both $f_v$ and $f_a$ predict correctly. We then update $T$ with the mean of the features projected by $g_v$ and $g_a$. At each training iteration, following [5], $T$ updates the slots of the labeled samples in the current mini-batch in a queue-like manner.
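To make Eqs. (2) and (3) and the memory-bank update concrete, here is a minimal pure-Python sketch on plain lists. It is illustrative only and is not the authors' implementation. The names `omni_correlation`, `occ_loss`, and `MemoryBank` are hypothetical. The temperature multiplies the cosine exactly as written in Eq. (2). Two details the text leaves open are assumptions here: $sim_{vp}$ acts as the cross-entropy target, and each sample's slots form a fixed-length first-in, first-out queue.

```python
import math
from collections import deque


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + 1e-12)


def omni_correlation(z, zp, t):
    """Eq. (2): each row is a softmax over the n prototypes of cos(z_k, z_p^i) * t.

    z: m unlabeled embeddings, zp: n prototype embeddings (lists of c floats).
    Returns an m x n matrix whose rows sum to 1.
    """
    sim = []
    for zk in z:
        logits = [cosine(zk, proto) * t for proto in zp]
        top = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(x - top) for x in logits]
        total = sum(exps)
        sim.append([e / total for e in exps])
    return sim


def occ_loss(sim_vp, sim_ap, eps=1e-12):
    """Eq. (3) sketch: cross-entropy summed over the m rows, then divided by m.

    Assumption: sim_vp is the target distribution and sim_ap the prediction.
    """
    total = 0.0
    for p_row, q_row in zip(sim_vp, sim_ap):
        total -= sum(p * math.log(q + eps) for p, q in zip(p_row, q_row))
    return total / len(sim_vp)


class MemoryBank:
    """Hypothetical memory bank T: `slots` FIFO slots per labeled sample."""

    def __init__(self, num_labeled, slots):
        self.bank = {s: deque(maxlen=slots) for s in range(num_labeled)}

    def update(self, sample_id, feats_v, feats_a, pred_v, pred_a, labels):
        """Store the mean of both projections where both branches predict correctly."""
        for fv, fa, pv, pa, y in zip(feats_v, feats_a, pred_v, pred_a, labels):
            if pv == y and pa == y:
                fused = [(a + b) / 2 for a, b in zip(fv, fa)]  # mean fusion
                self.bank[sample_id].append((y, fused))  # oldest entry evicted when full
```

A real implementation would operate on batched tensors rather than Python loops. For example, it could use one matrix product of L2-normalized embeddings followed by a row-wise softmax. The loops here just keep each term of the equations visible.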
Correlation-Aware Mutual Learning
Embeddings Sampling. For computational efficiency, the omni-correlation is not computed over all labeled and unlabeled pixels. Instead, we use a confidence-based mechanism to sample pixel features from the unlabeled data. Specifically, to sample $z_v$ and $z_a$ from the unlabeled features, we first select the pixels where $f_v$ and $f_a$ make the same prediction. For each class, we sort these pixels by confidence score and keep the features of the top $i$ pixels as the sampled unlabeled features. Thus $m = i \times C$, where $C$ is the number of classes. For the prototype embeddings, we randomly sample $j$ embeddings per class from all embeddings stored in $T$ to increase their diversity, so $n = j \times C$.
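The two sampling rules can be sketched as follows. The function names are hypothetical, and plain lists stand in for per-pixel tensors. The text does not specify how the two branches' confidence scores are combined, so averaging them is an assumption. Assume the settings given later in the implementation details ($i = 12800$, $j = 256$) and a binary LA task ($C = 2$: background and left atrium). This gives $m = 25{,}600$ and $n = 512$ whenever every class has enough candidates.

```python
import random


def sample_unlabeled(conf_v, conf_a, pred_v, pred_a, num_classes, top_i):
    """Indices of the top_i most confident pixels per class where both branches agree.

    Assumption: a pixel's confidence is the mean of the two branches' scores.
    """
    picked = []
    for c in range(num_classes):
        agree = [k for k in range(len(pred_v)) if pred_v[k] == c and pred_a[k] == c]
        agree.sort(key=lambda k: (conf_v[k] + conf_a[k]) / 2, reverse=True)
        picked.extend(agree[:top_i])  # at most top_i per class, so m <= top_i * C
    return picked


def sample_prototypes(bank_items, num_classes, j, rng):
    """Randomly draw up to j stored embeddings per class from (class, embedding) pairs."""
    protos = []
    for c in range(num_classes):
        pool = [emb for cls, emb in bank_items if cls == c]
        protos.extend(rng.sample(pool, min(j, len(pool))))
    return protos
```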
3 Experiments and Results
Dataset. Our method is evaluated on the Left Atrium (LA) dataset [21] from the 2018 Atrial Segmentation Challenge. The dataset comprises 100 gadolinium-enhanced MR imaging scans (GE-MRIs) and their ground-truth masks, with an isotropic resolution of $0.625 \times 0.625 \times 0.625\ \mathrm{mm}^3$. Following [23], we use 80 scans for training and 20 scans for testing. All scans are centered on the heart region and cropped accordingly, then normalized to zero mean and unit variance.

Implementation Details. We implement CAML using PyTorch 1.8.1 and CUDA 10.2 on an NVIDIA TITAN RTX GPU. For training data augmentation, we randomly crop sub-volumes of size 112 × 112 × 80, following [23]. To ensure a fair comparison with existing methods, we use V-Net [13] as the backbone for all models. During training, we use a batch size of 4, with half of the images annotated and the other half unannotated. We train the entire framework for 15000 iterations with the SGD optimizer, using a learning rate of 0.01, momentum of 0.9 and weight decay of $10^{-4}$. To balance the loss terms during training, we use a time-dependent Gaussian warm-up function for $\lambda_U$ and $\lambda_C$, $\lambda(t) = \beta \cdot e^{-5(1 - t/t_{max})^2}$. We set $\beta$ to 1 for $\lambda_U$ and 0.1 for $\lambda_C$. For the OCC module, we set $c$ to 64, $j$ to 256, and $i$ to 12800. During inference, we use the predictions of the vanilla V-Net with a standard sliding-window strategy and no post-processing.

Quantitative Evaluation and Comparison. CAML is evaluated with four metrics: Dice, Jaccard, 95% Hausdorff Distance (95HD), and Average Surface Distance (ASD). Note that previously reported results on LA (Reported Metrics in Table 1) can be inconsistent. Some studies report results from the final training iteration, while others report the best performance obtained during training. The latter practice can overfit to the test set and lead to unreliable model selection.
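The time-dependent Gaussian warm-up $\lambda(t) = \beta e^{-5(1 - t/t_{max})^2}$ from the implementation details can be sketched as below. The text does not state the ramp length $t_{max}$, so `T_MAX = 15000` (the total number of training iterations) is a hypothetical choice. Holding the weight at $\beta$ once $t \ge t_{max}$ is also an assumption.

```python
import math


def gaussian_warmup(t, t_max, beta):
    """lambda(t) = beta * exp(-5 * (1 - t / t_max)^2), held at beta once t >= t_max."""
    progress = min(t, t_max) / t_max
    return beta * math.exp(-5.0 * (1.0 - progress) ** 2)


if __name__ == "__main__":
    T_MAX = 15000  # hypothetical ramp length; not stated in the text
    for step in (0, 7500, 15000):
        lam_u = gaussian_warmup(step, T_MAX, beta=1.0)  # weight lambda_U
        lam_c = gaussian_warmup(step, T_MAX, beta=0.1)  # weight lambda_C
        print(step, round(lam_u, 4), round(lam_c, 5))
```

The schedule starts at $\beta e^{-5} \approx 0.0067\beta$, so the unsupervised and correlation terms contribute almost nothing early in training, when the predictions on unlabeled data are least reliable.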
To ensure a fair comparison, we run every experiment three times on the same machine, using a fixed set of randomly selected seeds. We report the mean and standard deviation of the results from the final iteration. The results on LA are presented in Table 1. The fully supervised V-Net models trained with the different labeled ratios serve as the lower and upper bounds
Table 1. Comparison with state-of-the-art methods on the LA dataset. Metrics: mean ± standard deviation over three random seeds. Reported Metrics: results reported in the original papers ("-" = not reported).

| Method | Labeled | Unlabeled | Dice (%) | Jaccard (%) | 95HD (voxel) | ASD (voxel) | Reported Dice (%) | Reported Jaccard (%) | Reported 95HD (voxel) | Reported ASD (voxel) |
|---|---|---|---|---|---|---|---|---|---|---|
| V-Net | 4 | 0 | 43.32±8.62 | 31.43±6.90 | 40.19±1.11 | 12.13±0.57 | 52.55 | 39.60 | 47.05 | 9.87 |
| V-Net | 8 | 0 | 79.87±1.23 | 67.60±1.88 | 26.65±6.36 | 7.94±2.22 | 78.57 | 66.96 | 21.20 | 6.07 |
| V-Net | 16 | 0 | 85.94±0.48 | 75.99±0.57 | 16.70±1.82 | 4.80±0.62 | 86.96 | 77.31 | 11.85 | 3.22 |
| V-Net | 80 | 0 | 90.98±0.67 | 83.61±1.06 | 8.58±2.34 | 2.10±0.59 | 91.62 | 84.60 | 5.40 | 1.64 |
| UA-MT [23] (MICCAI'19) | 4 (5%) | 76 (95%) | 78.07±0.90 | 65.03±0.96 | 29.17±3.82 | 8.63±0.98 | - | - | - | - |
| SASSNet [8] (MICCAI'20) | 4 (5%) | 76 (95%) | 79.61±0.54 | 67.00±0.59 | 25.54±4.60 | 7.20±1.21 | - | - | - | - |
| DTC [10] (AAAI'21) | 4 (5%) | 76 (95%) | 80.14±1.22 | 67.88±1.82 | 24.08±2.63 | 7.18±0.62 | - | - | - | - |
| MC-Net [19] (MICCAI'21) | 4 (5%) | 76 (95%) | 80.92±3.88 | 68.90±5.09 | 17.25±6.08 | 2.76±0.49 | - | - | - | - |
| URPC [12] (MedIA'22) | 4 (5%) | 76 (95%) | 80.75±0.21 | 68.54±0.34 | 19.81±0.67 | 4.98±0.25 | - | - | - | - |
| SS-Net [18] (MICCAI'22) | 4 (5%) | 76 (95%) | 83.33±1.66 | 71.79±2.36 | 15.70±0.80 | 4.33±0.36 | 86.33 | 76.15 | 9.97 | 2.31 |
| MC-Net+ [17] (MedIA'22) | 4 (5%) | 76 (95%) | 83.23±1.41 | 71.70±1.99 | 14.92±2.56 | 3.43±0.64 | - | - | - | - |
| Ours | 4 (5%) | 76 (95%) | 87.34±0.05 | 77.65±0.08 | 9.76±0.92 | 2.49±0.22 | - | - | - | - |
| UA-MT [23] (MICCAI'19) | 8 (10%) | 72 (90%) | 85.81±0.17 | 75.41±0.22 | 18.25±1.04 | 5.04±0.24 | 84.25 | 73.48 | 13.84 | 3.36 |
| SASSNet [8] (MICCAI'20) | 8 (10%) | 72 (90%) | 85.71±0.87 | 75.35±1.28 | 14.74±3.14 | 4.00±0.86 | 86.81 | 76.92 | 12.54 | 3.94 |
| DTC [10] (AAAI'21) | 8 (10%) | 72 (90%) | 84.55±1.72 | 73.91±2.36 | 13.80±0.16 | 3.69±0.25 | - | - | - | - |
| MC-Net [19] (MICCAI'21) | 8 (10%) | 72 (90%) | 86.87±1.74 | 78.49±1.06 | 11.17±1.40 | 2.18±0.14 | 87.71 | 78.31 | 9.36 | 2.18 |
| URPC [12] (MedIA'22) | 8 (10%) | 72 (90%) | 83.37±0.21 | 71.99±0.31 | 17.91±0.73 | 4.41±0.17 | - | - | - | - |
| SS-Net [18] (MICCAI'22) | 8 (10%) | 72 (90%) | 86.56±0.69 | 76.61±1.03 | 12.76±0.58 | 3.02±0.19 | 88.55 | 79.63 | 7.49 | 1.90 |
| MC-Net+ [17] (MedIA'22) | 8 (10%) | 72 (90%) | 87.68±0.56 | 78.27±0.83 | 10.35±0.77 | 1.85±0.01 | 88.96 | 80.25 | 7.93 | 1.86 |
| Ours | 8 (10%) | 72 (90%) | 89.62±0.20 | 81.28±0.32 | 8.76±1.39 | 2.02±0.17 | - | - | - | - |
| UA-MT [23] (MICCAI'19) | 16 (20%) | 64 (80%) | 88.18±0.69 | 79.09±1.05 | 9.66±2.99 | 2.62±0.59 | 88.88 | 80.21 | 7.32 | 2.26 |
| SASSNet [8] (MICCAI'20) | 16 (20%) | 64 (80%) | 88.11±0.34 | 79.08±0.48 | 12.31±4.14 | 3.27±0.96 | 89.27 | 80.82 | 8.83 | 3.13 |
| DTC [10] (AAAI'21) | 16 (20%) | 64 (80%) | 87.79±0.50 | 78.52±0.73 | 10.29±1.52 | 2.50±0.65 | 89.42 | 80.98 | 7.32 | 2.10 |
| MC-Net [19] (MICCAI'21) | 16 (20%) | 64 (80%) | 90.43±0.52 | 82.69±0.75 | 6.52±0.66 | 1.66±0.14 | 90.34 | 82.48 | 6.00 | 1.77 |
| URPC [12] (MedIA'22) | 16 (20%) | 64 (80%) | 87.68±0.36 | 78.36±0.53 | 14.39±0.54 | 3.52±0.17 | - | - | - | - |
| SS-Net [18] (MICCAI'22) | 16 (20%) | 64 (80%) | 88.19±0.42 | 79.21±0.63 | 8.12±0.34 | 2.20±0.12 | - | - | - | - |
| MC-Net+ [17] (MedIA'22) | 16 (20%) | 64 (80%) | 90.60±0.39 | 82.93±0.64 | 6.27±0.25 | 1.58±0.07 | 91.07 | 83.67 | 5.84 | 1.67 |
| Ours | 16 (20%) | 64 (80%) | 90.78±0.11 | 83.19±0.18 | 6.11±0.39 | 1.68±0.15 | - | - | - | - |
Fig. 2. Visualization of the segmentation results of different methods.
for each ratio setting. We report the reproduced results of state-of-the-art semi-supervised methods, together with their originally reported results where available. Comparing the two, we observe that the performance of current methods generally improves as algorithms develop. However, individual runs can be unstable, and the reported results may not fully reflect true performance.
Table 1 shows that CAML outperforms the other methods in all settings, by a large margin when labels are scarce, without any additional inference or post-processing cost. With only 5% labeled data, CAML achieves a Dice score of 87.34%. This is an absolute improvement of 4.01% over the previous best method (SS-Net, 83.33%). CAML also achieves a Dice score of 89.62% with only 10% labeled data. When the labeled data is increased to 20%, CAML obtains results comparable to those of V-Net trained with 100% labeled data: a Dice score of 90.78% versus the upper bound of 90.98%. These gains come from effectively transferring knowledge between labeled and unlabeled data.

Table 2. Ablation study of the proposed CAML on the LA dataset.
Scans used
Components
Metrics
Labeled Unlabeled Baseline OCC CMA √ 4(5%) 76(95%) √ √ √ √ √ √ √ 8(10%)
72(90%)
16(20%) 64(80%)
√ √ √ √ √ √ √ √
√ √
√ √
√ √
√ √
Dice(%)
Jaccard(%) 95HD(voxel) ASD(voxel)
80.92±3.88 83.12±2.12 86.35±0.26 87.34±0.05
68.90±5.09 71.73±3.04 76.16±0.40 77.65±0.08
17.25±6.08 16.94±7.25 12.36±0.20 9.76±0.92
2.76±0.49 4.51±2.16 2.94±0.21 2.49±0.22
86.87±1.74 88.50±3.25 88.84±0.55 89.62±0.20
78.49±1.06 79.53±0.51 80.05±0.85 81.28±0.32
11.17±1.40 9.89±0.83 8.50±0.66 8.76±1.39
2.18±0.14 2.35±0.21 1.97±0.02 2.02±0.17
90.43±0.52 90.27±0.22 90.25±0.28 90.78±0.11
82.69±0.75 82.42±0.39 82.34±0.43 83.19±0.18
6.52±0.66 6.96±1.03 6.95±0.09 6.11±0.39
1.66±0.14 1.91±0.24 1.79±0.18 1.68±0.15
Table 1 also shows that CAML keeps a low standard deviation across seeds as the labeled-data ratio declines, considerably lower than that of the other state-of-the-art methods. For example, its Dice standard deviation is ±0.05 at 5% labeled data. This suggests that CAML is stable and robust. Furthermore, the margin between our method and the state-of-the-art semi-supervised methods grows as the labeled-data ratio declines. This indicates that our method effectively transfers knowledge from labeled to unlabeled data, enabling the model to extract more universal features from unlabeled data. Figure 2 shows the qualitative comparison: 2D and 3D visualizations of all compared methods and the corresponding ground truth. As indicated by the orange rectangle in the 2D visualizations and the orange circle in the 3D visualizations, CAML achieves the best segmentation results among all methods.

Ablation Study. In this section, we analyze the effectiveness of the proposed CMA and OCC modules. We use MC-Net as our baseline,
which uses different up-sampling operations to introduce architectural heterogeneity. Table 2 presents the results of our ablation study. At the 5% ratio, both CMA and OCC substantially improve the baseline. Combining the two modules, CAML achieves an absolute improvement of 6.42% in Dice. Similar improvements are observed at the 10% ratio. At the 20% ratio, the baseline already reaches a Dice of 90.43%, roughly comparable to the upper bound of a fully supervised model. In this setting, adding CMA or OCC alone does not bring a significant improvement. Nonetheless, combining the two modules in CAML still yields the best performance in this setting, moving it closer to that of a fully supervised model.
4 Conclusion
In this paper, we proposed a novel framework named CAML for semi-supervised medical image segmentation. Our key idea is that semi-supervised learning should take cross-sample correlation into account. To this end, we proposed two novel modules, Cross-sample Mutual Attention (CMA) and Omni-Correlation Consistency (OCC). They encourage efficient and direct transfer of prior knowledge from labeled data to unlabeled data. Extensive experiments on the LA dataset demonstrate that CAML outperforms previous state-of-the-art methods, by a large margin in low-label settings, without extra computational cost at inference.

Acknowledgements. This work is funded by the Scientific and Technological Innovation 2030 New Generation Artificial Intelligence Project of the National Key Research and Development Program of China (No. 2021ZD0113302), Beijing Municipal Science and Technology Planning Project (No. Z201100005620008, Z211100003521009).
References
1. Berthelot, D., Carlini, N., Goodfellow, I., Papernot, N., Oliver, A., Raffel, C.A.: MixMatch: a holistic approach to semi-supervised learning. In: Advances in Neural Information Processing Systems 32 (2019)
2. Cascante-Bonilla, P., Tan, F., Qi, Y., Ordonez, V.: Curriculum labeling: revisiting pseudo-labeling for semi-supervised learning. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, pp. 6912–6920 (2021)
3. Chen, X., Yuan, Y., Zeng, G., Wang, J.: Semi-supervised semantic segmentation with cross pseudo supervision. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2613–2622 (2021)
4. French, G., Laine, S., Aila, T., Mackiewicz, M., Finlayson, G.: Semi-supervised semantic segmentation needs strong, varied perturbations. arXiv preprint arXiv:1906.01916 (2019)
5. He, K., Fan, H., Wu, Y., Xie, S., Girshick, R.: Momentum contrast for unsupervised visual representation learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9729–9738 (2020)
6. Kwon, D., Kwak, S.: Semi-supervised semantic segmentation with error localization network. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9957–9967 (2022)
7. Lee, D.H., et al.: Pseudo-label: the simple and efficient semi-supervised learning method for deep neural networks. In: Workshop on Challenges in Representation Learning, ICML, vol. 3, p. 896 (2013)
8. Li, S., Zhang, C., He, X.: Shape-aware semi-supervised 3D semantic segmentation for medical images. In: Martel, A.L., Abolmaesumi, P., Stoyanov, D., Mateus, D., Zuluaga, M.A., Zhou, S.K., Racoceanu, D., Joskowicz, L. (eds.) MICCAI 2020. LNCS, vol. 12261, pp. 552–561. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59710-8_54
9. Liu, Y., Tian, Y., Chen, Y., Liu, F., Belagiannis, V., Carneiro, G.: Perturbed and strict mean teachers for semi-supervised semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4258–4267 (2022)
10. Luo, X., Chen, J., Song, T., Wang, G.: Semi-supervised medical image segmentation through dual-task consistency. In: AAAI Conference on Artificial Intelligence (2021)
11. Luo, X., Hu, M., Song, T., Wang, G., Zhang, S.: Semi-supervised medical image segmentation via cross teaching between CNN and transformer. In: International Conference on Medical Imaging with Deep Learning, pp. 820–833. PMLR (2022)
12. Luo, X., et al.: Semi-supervised medical image segmentation via uncertainty rectified pyramid consistency. Med. Image Anal. 80, 102517 (2022)
13. Milletari, F., Navab, N., Ahmadi, S.A.: V-Net: fully convolutional neural networks for volumetric medical image segmentation. In: 2016 Fourth International Conference on 3D Vision (3DV), pp. 565–571. IEEE (2016)
14. Mittal, S., Tatarchenko, M., Brox, T.: Semi-supervised semantic segmentation with high- and low-level consistency. IEEE Trans. Pattern Anal. Mach. Intell. 43(4), 1369–1379 (2019)
15. Ouali, Y., Hudelot, C., Tami, M.: Semi-supervised semantic segmentation with cross-consistency training. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12674–12684 (2020)
16. Wu, L., Fang, L., He, X., He, M., Ma, J., Zhong, Z.: Querying labeled for unlabeled: cross-image semantic consistency guided semi-supervised semantic segmentation. IEEE Trans. Pattern Anal. Mach. Intell. (2023)
17. Wu, Y., et al.: Mutual consistency learning for semi-supervised medical image segmentation. Med. Image Anal. 81, 102530 (2022)
18. Wu, Y., Wu, Z., Wu, Q., Ge, Z., Cai, J.: Exploring smoothness and class-separation for semi-supervised medical image segmentation. In: MICCAI 2022, Part V, pp. 34–43. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16443-9_4
19. Wu, Y., Xu, M., Ge, Z., Cai, J., Zhang, L.: Semi-supervised left atrium segmentation with mutual consistency training. In: de Bruijne, M., Cattin, P.C., Cotin, S., Padoy, N., Speidel, S., Zheng, Y., Essert, C. (eds.) MICCAI 2021. LNCS, vol. 12902, pp. 297–306. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87196-3_28
20. Xie, Q., Luong, M.T., Hovy, E., Le, Q.V.: Self-training with noisy student improves ImageNet classification. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10687–10698 (2020)
21. Xiong, Z., et al.: A global benchmark of algorithms for segmenting late gadolinium-enhanced cardiac magnetic resonance imaging. Med. Image Anal. (2020)
22. Yang, L., Zhuo, W., Qi, L., Shi, Y., Gao, Y.: ST++: make self-training work better for semi-supervised semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4268–4277 (2022)
23. Yu, L., Wang, S., Li, X., Fu, C.W., Heng, P.A.: Uncertainty-aware self-ensembling model for semi-supervised 3D left atrium segmentation. In: MICCAI (2019)
24. Yuan, J., Liu, Y., Shen, C., Wang, Z., Li, H.: A simple baseline for semi-supervised semantic segmentation with strong data augmentation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 8229–8238 (2021)
25. Zhang, P., Zhang, B., Zhang, T., Chen, D., Wen, F.: Robust mutual learning for semi-supervised semantic segmentation. arXiv preprint arXiv:2106.00609 (2021)
26. Zou, Y., et al.: PseudoSeg: designing pseudo labels for semantic segmentation. arXiv preprint arXiv:2010.09713 (2020)
TPRO: Text-Prompting-Based Weakly Supervised Histopathology Tissue Segmentation

Shaoteng Zhang¹,²,⁴, Jianpeng Zhang², Yutong Xie³ (corresponding author), and Yong Xia¹,²,⁴ (corresponding author)

1 Ningbo Institute of Northwestern Polytechnical University, Ningbo 315048, China
2 National Engineering Laboratory for Integrated Aero-Space-Ground-Ocean Big Data Application Technology, School of Computer Science and Engineering, Northwestern Polytechnical University, Xi'an 710072, China
3 Australian Institute for Machine Learning, The University of Adelaide, Adelaide, SA, Australia
4 Research and Development Institute of Northwestern Polytechnical University in Shenzhen, Shenzhen 518057, China
Abstract. Most existing weakly supervised segmentation methods rely on class activation maps (CAM) to generate pseudo-labels for training segmentation models. However, CAM has been criticized for highlighting only the most discriminative parts of an object, which leads to poor-quality pseudo-labels. Some recent methods have tried to extend CAM to cover more of the object, but the fundamental problem remains unsolved. We believe this problem stems from the large gap between image-level labels and pixel-level predictions, and that additional information must be introduced to bridge it. We therefore propose a text-prompting-based weakly supervised segmentation method (TPRO), which uses text to introduce additional information. TPRO employs a vision encoder and a label encoder to generate a similarity map for each image, which serves as our localization map. Pathological knowledge gathered from the Internet is embedded as knowledge features, which guide the image features through a knowledge attention module. Additionally, we employ a deep supervision strategy to fully utilize the network's shallow information. Our approach outperforms other weakly supervised segmentation methods on the LUAD-HistoSeg and BCSS-WSSS benchmark datasets, setting a new state of the art. Code is available at: https://github.com/zhangst431/TPRO.
Keywords: Histopathology Tissue Segmentation · Weakly-Supervised Semantic Segmentation · Vision-Language
Supplementary Information. The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_11.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 109–118, 2023.
https://doi.org/10.1007/978-3-031-43907-0_11
S. Zhang et al.

1 Introduction
Automated segmentation of histopathological images is crucial. It can quantify the tumor micro-environment, provide a basis for cancer grading and prognosis, and improve clinicians' diagnostic efficiency [6,13,19]. However, pixel-level annotation of images is time-consuming and labor-intensive, especially for histopathology images, which require specialized knowledge. There is therefore an urgent need for weakly supervised solutions to pixel-wise segmentation. Nonetheless, weakly supervised histopathological image segmentation is challenging because of the low contrast between different tissues, intra-class variations, and inter-class similarities [4,11]. Additionally, the tissue structures in histopathology images can be randomly arranged and dispersed, which makes it difficult to identify complete tissues or regions of interest [7].
[Figure 1 panels: activation maps from "CAM" and "Ours" for tumor epithelial tissue, necrosis tissue, and tumor-associated stroma, with example knowledge descriptions: "Under the microscope, tumor epithelial tissue may appear as solid nests, acinar structures, or papillary formations. The cells may have enlarged and irregular nuclei." / "Necrosis may appear as areas of pink, amorphous material under the microscope, and may be surrounded by viable tumor cells and stroma." / "Tumor-associated stroma tissue is the connective tissue that surrounds and supports the tumor epithelial tissue."]
Fig. 1. Comparison of activation maps extracted from CAM and from our method. From left to right: original image, ground truth, and the activation maps of tumor epithelial (red), necrosis (green), and tumor-associated stroma (orange) tissue. On the right are examples of the language knowledge descriptions used in our method. CAM highlights only a small portion of the target, while our method, which incorporates external language knowledge, covers a wider and more precise target region. (Color figure online)
Recent studies on weakly supervised segmentation primarily follow class activation mapping (CAM) [20], which localizes attention regions and then generates pseudo labels to train the segmentation network. However, a CAM generated from image-level labels highlights only the most discriminative region and fails to locate the complete object, leading to defective pseudo labels, as shown in Fig. 1. Accordingly, many attempts have been made to improve the quality of CAM and thus boost weakly supervised segmentation performance. Han et al. [7] proposed an erasure-based method that progressively expands the attention areas to obtain richer pseudo labels. Li et al. [11] used a confidence-based method to remove noise from the pseudo labels, keeping only confident pixel labels for segmentation training. Zhang et al. [18] leveraged a Transformer to model long-range dependencies across whole histopathological images, improving CAM's ability to find more complete regions. Lee et al. [10] used an advanced saliency detection model to help CAM locate targets more precisely. However, these improved variants still struggle to capture the
TPRO for Weakly Supervised Histopathology Tissue Segmentation
complete tissues. The primary limitation is that an abstract semantic category cannot comprehensively describe the symptoms and manifestations of histopathological subtypes. As a result, image-level label supervision may not be sufficient to pinpoint the complete target area. To remedy the limitations of image-level supervision, we advocate integrating language knowledge into weakly supervised learning to provide reliable guidance for accurately localizing target structures. To this end, we propose a text-prompting-based weakly supervised segmentation method (TPRO) for accurate histopathology tissue segmentation. The text information comes from the task's semantic labels and from external descriptions of subtype manifestations. For each semantic label, a pre-trained medical language model extracts the corresponding text features. These features are matched against each feature point in the image's spatial domain, and a higher similarity indicates a higher likelihood that the location belongs to the corresponding semantic category. In addition, the language model extracts text representations of subtype manifestations, including tissue morphology, color, and relationships to other tissues, as external knowledge. By jointly modeling long-range dependencies between image and text, the model can draw discriminative information from this text knowledge to identify and locate complete tissues accurately. We conduct experiments on two weakly supervised histological segmentation benchmarks, LUAD-HistoSeg and BCSS-WSSS. They demonstrate that TPRO produces higher-quality pseudo labels than other CAM-based methods. Our contributions are summarized as follows: (1) To the best of our knowledge, this is the first work that leverages language knowledge to improve the quality of pseudo labels for weakly supervised histopathology image segmentation.
(2) The proposed text prompting models the correlation between image representations and text knowledge, effectively improving the quality of pseudo labels. (3) We validate the effectiveness of our approach on two benchmarks, on which it sets a new state of the art.

[Figure 2 diagram: the input image passes through a four-stage vision encoder. The label input (tumor epithelial tissue, necrosis tissue, lymphocyte tissue, tumor-associated stroma tissue) is encoded by CLIP followed by an FC adaptive layer. The knowledge input searched from the Internet ("Tumor epithelial tissue is ....", "Necrosis tissue is ....", "Lymphocyte tissue is ....", "Tumor-associated stroma tissue is .....") is encoded by BERT and fed, together with the last-stage image features, into the knowledge attention module. At the supervised stages, pixel-label correlation (SIM) followed by GAP yields image-level class predictions (e.g., 1 1 0 1). Legend: CLIP: MedCLIP; BERT: ClinicalBert; FC: FC+ReLU+FC; SIM: Pixel-label Correlation.]
Fig. 2. The framework of the proposed TPRO.
2 Method
Figure 2 shows the proposed TPRO framework. It is a classification network trained with image-level labels and then used to extract segmentation pseudo-labels. The framework comprises a knowledge attention module and three encoders: one vision encoder and two text encoders (a label encoder and a knowledge encoder).

2.1 Classification with Deep Text Guidance
Vision Encoder. The vision encoder consists of four stages that encode the input image into image features. The image features are denoted as $T_s \in \mathbb{R}^{M_s \times C_s}$, where $2 \le s \le 4$ indicates the stage number.

Label Encoder. The label encoder encodes the text labels of the dataset into $N$ label features, denoted as $L \in \mathbb{R}^{N \times C_l}$. Here $N$ is the number of classes in the dataset and $C_l$ is the dimension of the label features. Since the label features are used to compute similarity with image features, it is important to choose a language model pre-trained on image-text pairs. We use MedCLIP¹ as our label encoder; it is a model fine-tuned from CLIP [14] on the ROCO dataset [12].

Knowledge Encoder. The knowledge encoder embeds the descriptions of subtype manifestations into knowledge features, denoted as $K \in \mathbb{R}^{N \times C_k}$. The knowledge features guide the image features to focus on regions relevant to the target tissue. To encode these descriptions into more general semantic features, we employ ClinicalBert [2] as our knowledge encoder. ClinicalBert is a language model fine-tuned from BioBert [9] on the MIMIC-III [8] dataset.

Adaptive Layer. We freeze the label and knowledge encoders for training efficiency. To better tailor the text features to our dataset, we add an adaptive layer after the text encoders. The adaptive layer is a simple FC-ReLU-FC block that lets the features extracted by the text encoders be fine-tuned.

Label-Pixel Correlation. After the input image and the text labels are embedded, we use the inner product to compute the similarity between image features and label features, denoted as $F_s$. Specifically, we first reshape the image features from token format into feature maps, denoted as $I_s \in \mathbb{R}^{H_s \times W_s \times C_s}$, where $H_s$ and $W_s$ are the height and width of the feature map.
$F_s$ is computed as

$$F_s[i, j, k] = I_s[i, j] \cdot L[k], \qquad F_s \in \mathbb{R}^{H_s \times W_s \times N} \qquad (1)$$
Then, we apply global average pooling to the similarity map to obtain the class prediction, denoted as $P_s \in \mathbb{R}^{1 \times N}$. We then supervise model training with the binary cross-entropy loss between the class label $Y \in \mathbb{R}^{1 \times N}$ and the class prediction $P_s$:

$$L_s = -\frac{1}{N} \sum_{n=1}^{N} \Big[ Y[n] \log \sigma(P_s[n]) + (1 - Y[n]) \log\big(1 - \sigma(P_s[n])\big) \Big] \qquad (2)$$

¹ https://github.com/Kaushalya/medclip.
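For a single stage $s$, Eqs. (1) and (2) reduce to three steps: an inner product per pixel and class, a global average, and a sigmoid binary cross-entropy. The following is a minimal sketch with nested lists in place of tensors; its names are hypothetical. It assumes that the adaptive layer has already projected the label features to the image channel size, so that the inner product is defined. It evaluates Eq. (2) through the standard softplus identity for numerical stability, which is an implementation choice rather than something stated in the text.

```python
import math


def similarity_map(image_feats, label_feats):
    """Eq. (1): F[i][j][k] = <I[i][j], L[k]>.

    image_feats: H x W x C nested lists; label_feats: N x C (assumed already
    mapped to the image channel size C by the adaptive layer).
    """
    return [[[sum(a * b for a, b in zip(pix, lab)) for lab in label_feats]
             for pix in row]
            for row in image_feats]


def global_avg_pool(fmap):
    """Average F over all H x W positions, giving the class prediction P (length N)."""
    h, w, n = len(fmap), len(fmap[0]), len(fmap[0][0])
    return [sum(fmap[i][j][k] for i in range(h) for j in range(w)) / (h * w)
            for k in range(n)]


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_loss(pred, target):
    """Eq. (2) via log(sigmoid(p)) = -softplus(-p) and log(1 - sigmoid(p)) = -softplus(p)."""
    total = sum(y * softplus(-p) + (1 - y) * softplus(p) for p, y in zip(pred, target))
    return total / len(pred)
```

For example, `bce_loss(global_avg_pool(similarity_map(I, L)), Y)` gives $L_s$ for one stage. The deep-supervision loss then weights the per-stage values as in Eq. (3).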
Deep Supervision. To leverage the shallow features of the network, we employ a deep supervision strategy. It computes the similarity between image features from different stages and label features from different adaptive layers. A class prediction is derived from each of these similarity maps, and the loss of the entire network is computed as:

$$L = \lambda_2 L_2 + \lambda_3 L_3 + \lambda_4 L_4 \qquad (3)$$

2.2 Knowledge Attention Module
To enhance the model's understanding of the color, morphology, and relationships of different tissues, we gather text descriptions of the different subtype manifestations from the Internet. The knowledge encoder then encodes them as external knowledge. The knowledge attention module uses this external knowledge to guide the image features toward relevant regions of the target tissues. As shown in Fig. 2, the knowledge attention module consists of two multi-head self-attention modules. The image features $T_4 \in \mathbb{R}^{M_4 \times C_4}$ and the knowledge features after the adaptive layer, $K \in \mathbb{R}^{N \times C_4}$, are concatenated along the token dimension to obtain $T_{fuse} \in \mathbb{R}^{(M_4 + N) \times C_4}$. This concatenated feature is fed into the knowledge attention module for self-attention. The output tokens are then split, and the part corresponding to the image features is kept. Note that the knowledge attention module is added only after the last stage of the vision encoder to save computational resources.

2.3 Pseudo Label Generation
During classification, we compute the similarity between image features and label features to obtain a similarity map $F$. We then directly use the global average pooling of this map as the class prediction. Hence, the value of $F$ at position $(i, j, k)$ reflects how likely pixel $(i, j)$ is to belong to the $k$-th class, and we use $F$ directly as our localization map. We first apply min-max normalization to it:

$$F_{fg}^{c} = \frac{F^{c} - \min(F^{c})}{\max(F^{c}) - \min(F^{c})} \qquad (4)$$
where $1 \le c \le N$ indexes the $c$-th class in the dataset. We then compute the background localization map as:

$$F_{bg}(i, j) = \Big\{1 - \max_{1 \le c \le N} F_{fg}^{c}(i, j)\Big\}^{\alpha} \qquad (5)$$
where $\alpha \ge 1$ is a hyperparameter that adjusts the background confidence scores. Following [1] and our own experiments, we set $\alpha$ to
10. We then stack the foreground and background localization maps and denote the result as $\hat{F}$. To make full use of the network's shallow information, we fuse the localization maps from the different stages with weights:

$$F_{all} = \gamma_2 \hat{F}_2 + \gamma_3 \hat{F}_3 + \gamma_4 \hat{F}_4 \qquad (6)$$

Finally, we apply an argmax to $F_{all}$ to obtain the final pseudo-label.
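Eqs. (4) to (6) and the final argmax can be sketched on flattened maps, with one list of pixel scores per class. The names are hypothetical. The value $\alpha = 10$ follows the text. The following are illustrative assumptions, not details given here: the stage weights $\gamma_s$, placing the background map at index 0, and all stage maps having already been resized to a common resolution.

```python
def min_max(channel):
    """Eq. (4) for one class map, given as a flat list over pixels."""
    lo, hi = min(channel), max(channel)
    span = hi - lo
    return [(v - lo) / span if span > 0 else 0.0 for v in channel]


def stage_localization(F, alpha):
    """Normalize the N class maps (Eq. 4) and prepend the background map (Eq. 5).

    F: N lists of P pixel scores. Returns N + 1 maps, background at index 0
    (this stacking order is an assumption).
    """
    fg = [min_max(ch) for ch in F]
    bg = [(1.0 - max(ch[p] for ch in fg)) ** alpha for p in range(len(F[0]))]
    return [bg] + fg


def pseudo_label(stage_maps, gammas, alpha=10):
    """Eq. (6): weighted sum of per-stage maps, then argmax per pixel (0 = background).

    Assumption: all stage maps were already resized to a common resolution.
    """
    fused = None
    for F, g in zip(stage_maps, gammas):
        loc = stage_localization(F, alpha)
        if fused is None:
            fused = [[g * v for v in m] for m in loc]
        else:
            fused = [[acc + g * v for acc, v in zip(fm, m)] for fm, m in zip(fused, loc)]
    num_maps, num_pixels = len(fused), len(fused[0])
    return [max(range(num_maps), key=lambda c: fused[c][p]) for p in range(num_pixels)]
```

With $\alpha = 10$, a pixel receives substantial background score only when no class map is close to its maximum there. This is why a large $\alpha$ makes the background label conservative.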
3 Experiments

3.1 Dataset
LUAD-HistoSeg² [7] is a weakly supervised histological semantic segmentation dataset for lung adenocarcinoma. It has four tissue classes: tumor epithelial (TE), tumor-associated stroma (TAS), necrosis (NEC), and lymphocyte (LYM). The dataset comprises 17,285 patches of size 224 × 224. According to the official split, it is divided into three sets:
- a training set of 16,678 patches with patch-level annotations;
- a validation set of 300 patches with pixel-level annotations;
- a test set of 307 patches with pixel-level annotations.

BCSS-WSSS³ is a weakly supervised tissue semantic segmentation dataset derived from the fully supervised segmentation dataset BCSS [3], which contains 151 representative H&E-stained breast cancer pathology slides. The slides were randomly cut into 31,826 patches of size 224 × 224. According to the official split, these are divided into three sets:
- a training set of 23,422 patches with patch-level annotations;
- a validation set of 3,418 patches with pixel-level annotations;
- a test set of 4,986 patches with pixel-level annotations.

The dataset has four foreground classes: Tumor (TUM), Stroma (STR), Lymphocytic infiltrate (LYM), and Necrosis (NEC).

3.2 Implementation Details
For the classification part, we adopt MixTransformer [17] pretrained on ImageNet, MedCLIP, and ClinicalBert [2] as our vision encoder, label encoder, and knowledge encoder, respectively. The hyperparameters used during training and evaluation are given in the supplementary materials. All experiments are conducted on two NVIDIA GeForce RTX 2080 Ti GPUs.

Table 1. Comparison of the pseudo-labels generated by our proposed method and by previous methods (IoU, %).

Dataset             | LUAD-HistoSeg                     | BCSS-WSSS
Method              | TE     NEC    LYM    TAS    mIoU  | TUM    STR    LYM    NEC    mIoU
CAM [20]            | 69.66  72.62  72.58  66.88  70.44 | 66.83  58.71  49.41  51.12  56.52
Grad-CAM [15]       | 70.07  66.01  70.18  64.76  67.76 | 65.96  56.71  43.36  30.04  49.02
TransWS (CAM) [18]  | 65.92  60.16  73.34  69.11  67.13 | 64.85  58.17  44.96  50.60  54.64
MLPS [7]            | 71.72  76.27  73.53  67.67  72.30 | 70.76  61.07  50.87  52.94  58.91
TPRO (Ours)         | 74.82  77.55  76.40  70.98  74.94 | 77.18  63.77  54.95  61.43  64.33

² https://drive.google.com/drive/folders/1E3Yei3Or3xJXukHIybZAgochxfn6FJpr
³ https://drive.google.com/drive/folders/1iS2Z0DsbACqGp7m6VDJbAcgzeXNEFr77

TPRO for Weakly Supervised Histopathology Tissue Segmentation

3.3 Comparison with State-of-the-Art Methods
Comparison on Pseudo-Labels. Table 1 compares the quality of our pseudo-labels with those generated by previous methods. CAM [20] and Grad-CAM [15] were evaluated with the same ResNet38 [16] classifier; CAM [20] outperformed Grad-CAM [15], reaching mIoU values of 70.44% and 56.52% on LUAD-HistoSeg and BCSS-WSSS, respectively. TransWS [18] consists of a classification branch and a segmentation branch, and Table 1 reports the pseudo-label scores generated by its classification branch. Although it also uses CAM [20] for pseudo-label extraction, TransWS [18] yielded worse results than CAM [20]. This may be because TransWS [18] was designed for single-label image segmentation, with its segmentation branch simplified to binary segmentation to reduce the difficulty, whereas our datasets consist of multi-label images. Among the compared methods, MLPS [7] was the only one to surpass CAM [20] in pseudo-label quality, as its progressive dropout attention effectively expands the coverage of target regions beyond what CAM [20] achieves. Our proposed method outperformed all previous methods on both LUAD-HistoSeg and BCSS-WSSS, with mIoU improvements of 2.64% and 5.42% over the second-best method, respectively (Table 1).

Table 2. Comparison of the final segmentation results between our method and previous methods (IoU, %).

Dataset             | LUAD-HistoSeg                     | BCSS-WSSS
Method              | TE     NEC    LYM    TAS    mIoU  | TUM    STR    LYM    NEC    mIoU
HistoSegNet [5]     | 45.59  36.30  58.28  50.82  47.75 | 33.14  46.46  29.05   1.91  27.64
TransWS (seg) [18]  | 57.04  49.98  59.46  58.59  56.27 | 44.71  36.49  41.72  38.08  40.25
OEEM [11]           | 73.81  70.49  71.89  69.48  71.42 | 74.86  64.68  48.91  61.03  62.37
MLPS [7]            | 73.90  77.48  73.61  69.53  73.63 | 74.54  64.45  52.54  58.67  62.55
TPRO (Ours)         | 75.80  80.56  78.14  72.69  76.80 | 77.95  65.10  54.55  64.96  65.64
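Tables 1 and 2 report per-class IoU and mIoU. The evaluation code is not described here, so the following is only a sketch of one common way to compute these scores: a confusion matrix is accumulated over label maps and IoU_c = TP_c / (TP_c + FP_c + FN_c) is derived from it. Accumulating over the whole test set, rather than averaging per image, is an assumption. The last lines check that the mIoU column is the plain mean of the four class IoUs, using the TPRO row of Table 1.

```python
import numpy as np


def confusion_matrix(gt, pred, num_classes):
    """Rows = ground-truth class, columns = predicted class, summed over all pixels."""
    gt = np.asarray(gt).ravel().astype(np.int64)
    pred = np.asarray(pred).ravel().astype(np.int64)
    counts = np.bincount(num_classes * gt + pred, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def iou_per_class(cm):
    """IoU_c = TP_c / (TP_c + FP_c + FN_c); classes absent everywhere get 0 here."""
    tp = np.diag(cm).astype(float)
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp  # (TP + FP) + (TP + FN) - TP
    return tp / np.maximum(union, 1)


gt = np.array([[0, 0], [1, 1]])
pred = np.array([[0, 1], [1, 1]])
ious = iou_per_class(confusion_matrix(gt, pred, num_classes=2))
miou = ious.mean()
print(ious, miou)  # IoUs 0.5 and 2/3, so mIoU = 7/12

# The mIoU column is the plain mean of the four class IoUs, e.g. the
# TPRO pseudo-labels on LUAD-HistoSeg in Table 1:
tpro_luad = np.array([74.82, 77.55, 76.40, 70.98])  # TE, NEC, LYM, TAS
print(f"{tpro_luad.mean():.4f}")  # 74.9375, reported as 74.94
```

Many segmentation toolkits mark classes that never occur as NaN and use `np.nanmean` instead of assigning them an IoU of 0; which convention the reported numbers follow is not stated.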
Comparison on Segmentation Results. To further evaluate our proposed method, we trained a segmentation model using the extracted pseudo-labels and compared its performance with previous methods (Table 2). Because it relies heavily on dataset-specific post-processing steps, HistoSegNet [5] failed to produce the desired results on our datasets. As analyzed above, because both datasets consist of multi-label images, the segmentation branch of TransWS [18] struggled to perform well and failed to provide an overall benefit to the model; in fact, the IoU scores of its segmentation branch were even lower than those of the pseudo-labels from its classification branch. By training the segmentation model of OEEM [11] with the pseudo-labels extracted by CAM [20] in Table 1, we observe a significant improvement in the final segmentation results. The final segmentation results of MLPS [7] also improved over its pseudo-labels, indicating the effectiveness of the Multi-layer Pseudo Supervision and Classification Gate Mechanism strategies proposed in MLPS [7]. Our segmentation performance surpassed all previous methods: our mIoU scores exceeded those of the second-best method by 3.17% and 3.09% on LUAD-HistoSeg and BCSS-WSSS, respectively. Notably, we did not use any strategies specifically designed for the segmentation stage.

Table 3. Comparison of the effectiveness of label text (LT), knowledge text (KT), and deep supervision (DS) (IoU, %).

LT  DS  KT  | TE     NEC    LYM    TAS    mIoU
            | 68.11  75.24  64.95  66.57  68.72
✓           | 72.39  72.44  71.37  68.67  71.22
✓   ✓       | 72.41  72.11  74.21  70.07  72.20
✓   ✓   ✓   | 74.82  77.55  76.40  70.98  74.94

Table 4. Comparison of pseudo-labels extracted from single stages and from our fused version (IoU, %).

Stage   | TE     NEC    LYM    TAS    mIoU
stage2  | 67.16  65.28  67.38  55.09  63.73
stage3  | 72.13  70.83  73.47  69.46  71.47
stage4  | 72.69  77.57  76.06  69.81  74.03
fusion  | 74.82  77.55  76.40  70.98  74.94

3.4 Ablation Study
The results of our ablation experiments are presented in Table 3. We set the baseline as the framework shown in Fig. 2 with all text information and the deep supervision strategy removed. Adding textual information increases the pseudo-label mIoU by 2.50%. Further adding the deep supervision strategy and the knowledge attention module improves the pseudo-label mIoU by another 0.98% and 2.74%, respectively. These results show that each proposed module contributes to the overall improvement. To demonstrate the effectiveness of fusing pseudo-labels from the last three stages, Table 4 reports the IoU scores of each stage's pseudo-labels as well as those of the fused pseudo-labels. After fusion, the IoU scores of TE, LYM, and TAS increase over the best single stage (stage 4), NEC remains essentially unchanged (77.55% vs. 77.57%), and the mIoU rises by 0.91%.
4 Conclusion

In this paper, we propose TPRO to address the limitations of weakly supervised semantic segmentation on histopathology images by incorporating text supervision and external knowledge. We argue that image-level labels alone cannot provide sufficient information and that text supervision and knowledge attention can provide additional guidance to the model. The proposed method achieves the best results on two public datasets, LUAD-HistoSeg and BCSS-WSSS, demonstrating the superiority of our method.

Acknowledgment. This work was supported in part by the Natural Science Foundation of Ningbo City, China, under Grant 2021J052, in part by the Ningbo Clinical Research Center for Medical Imaging under Grant 2021L003 (Open Project 2022LYKFZD06), in part by the National Natural Science Foundation of China under Grant 62171377, in part by the Key Technologies Research and Development Program under Grant 2022YFC2009903/2022YFC2009900, in part by the Key Research and Development Program of Shaanxi Province, China, under Grant 2022GY-084, and in part by the Science and Technology Innovation Committee of Shenzhen Municipality, China, under Grant JCYJ20220530161616036.
References

1. Ahn, J., Kwak, S.: Learning pixel-level semantic affinity with image-level supervision for weakly supervised semantic segmentation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4981–4990 (2018)
2. Alsentzer, E., et al.: Publicly available clinical BERT embeddings. arXiv preprint arXiv:1904.03323 (2019)
3. Amgad, M., et al.: Structured crowdsourcing enables convolutional segmentation of histology images. Bioinformatics 35(18), 3461–3467 (2019)
4. Chan, L., Hosseini, M.S., Plataniotis, K.N.: A comprehensive analysis of weakly-supervised semantic segmentation in different image domains. Int. J. Comput. Vision 129, 361–384 (2021)
5. Chan, L., Hosseini, M.S., Rowsell, C., Plataniotis, K.N., Damaskinos, S.: HistoSegNet: semantic segmentation of histological tissue type in whole slide images. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10662–10671 (2019)
6. Chen, H., Qi, X., Yu, L., Heng, P.A.: DCAN: deep contour-aware networks for accurate gland segmentation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2487–2496 (2016)
7. Han, C., et al.: Multi-layer pseudo-supervision for histopathology tissue semantic segmentation using patch-level classification labels. Med. Image Anal. 80, 102487 (2022)
8. Johnson, A.E., et al.: MIMIC-III, a freely accessible critical care database. Sci. Data 3(1), 1–9 (2016)
9. Lee, J., et al.: BioBERT: a pre-trained biomedical language representation model for biomedical text mining. Bioinformatics 36(4), 1234–1240 (2020)
10. Lee, S., Lee, M., Lee, J., Shim, H.: Railroad is not a train: saliency as pseudo-pixel supervision for weakly supervised semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 5495–5505 (2021)
11. Li, Y., Yu, Y., Zou, Y., Xiang, T., Li, X.: Online easy example mining for weakly-supervised gland segmentation from histology images. In: Medical Image Computing and Computer Assisted Intervention – MICCAI 2022: 25th International Conference, Singapore, September 18–22, 2022, Proceedings, Part IV, pp. 578–587. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16440-8_55
12. Pelka, O., Koitka, S., Rückert, J., Nensa, F., Friedrich, C.M.: Radiology Objects in COntext (ROCO): a multimodal image dataset. In: Stoyanov, D., et al. (eds.) LABELS/CVII/STENT 2018. LNCS, vol. 11043, pp. 180–189. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01364-6_20
13. Qaiser, T., et al.: Fast and accurate tumor segmentation of histology images using persistent homology and deep convolutional features. Med. Image Anal. 55, 1–14 (2019)
14. Radford, A., et al.: Learning transferable visual models from natural language supervision. In: International Conference on Machine Learning, pp. 8748–8763. PMLR (2021)
15. Selvaraju, R.R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., Batra, D.: Grad-CAM: visual explanations from deep networks via gradient-based localization. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 618–626 (2017)
16. Wu, Z., Shen, C., Van Den Hengel, A.: Wider or deeper: revisiting the ResNet model for visual recognition. Pattern Recogn. 90, 119–133 (2019)
17. Xie, E., Wang, W., Yu, Z., Anandkumar, A., Alvarez, J.M., Luo, P.: SegFormer: simple and efficient design for semantic segmentation with transformers. Adv. Neural. Inf. Process. Syst. 34, 12077–12090 (2021)
18. Zhang, S., Zhang, J., Xia, Y.: TransWS: Transformer-based weakly supervised histology image segmentation. In: Machine Learning in Medical Imaging: 13th International Workshop, MLMI 2022, Held in Conjunction with MICCAI 2022, Singapore, September 18, 2022, Proceedings, pp. 367–376. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-21014-3_38
19. Zhao, B., et al.: Triple U-Net: hematoxylin-aware nuclei segmentation with progressive dense feature aggregation. Med. Image Anal. 65, 101786 (2020)
20. Zhou, B., Khosla, A., Lapedriza, A., Oliva, A., Torralba, A.: Learning deep features for discriminative localization. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2921–2929 (2016)
Additional Positive Enables Better Representation Learning for Medical Images

Dewen Zeng¹ (✉), Yawen Wu², Xinrong Hu¹, Xiaowei Xu³, Jingtong Hu², and Yiyu Shi¹ (✉)

¹ University of Notre Dame, Notre Dame, IN, USA
{dzeng2,yshi4}@nd.edu
² University of Pittsburgh, Pittsburgh, PA, USA
³ Guangdong Provincial People's Hospital, Guangzhou, China
Abstract. This paper presents a new way to identify additional positive pairs for BYOL, a state-of-the-art (SOTA) self-supervised learning framework, to improve its representation learning ability. Unlike conventional BYOL which relies on only one positive pair generated by two augmented views of the same image, we argue that information from different images with the same label can bring more diversity and variations to the target features, thus benefiting representation learning. To identify such pairs without any label, we investigate TracIn, an instance-based and computationally efficient influence function, for BYOL training. Specifically, TracIn is a gradient-based method that reveals the impact of a training sample on a test sample in supervised learning. We extend it to the self-supervised learning setting and propose an efficient batchwise per-sample gradient computation method to estimate the pairwise TracIn for representing the similarity of samples in the mini-batch during training. For each image, we select the most similar sample from other images as the additional positive and pull their features together with BYOL loss. Experimental results on two public medical datasets (i.e., ISIC 2019 and ChestX-ray) demonstrate that the proposed method can improve the classification performance compared to other competitive baselines in both semi-supervised and transfer learning settings.

Keywords: self-supervised learning · representation learning · medical image classification

1 Introduction
Self-supervised learning (SSL) has been extremely successful in learning good image representations without human annotations for medical image applications such as classification [1,23,29] and segmentation [2,4,16]. Usually, an encoder is pre-trained on a large-scale unlabeled dataset. Then, the pre-trained encoder is used for efficient training on downstream tasks with limited annotation [19,24]. Recently, contrastive learning has become the state-of-the-art (SOTA) SSL method due to its powerful learning ability. A recent contrastive learning method [6] learns by pulling together the representations of different augmented views of the same image (a.k.a. a positive pair) and pushing apart the representations of different images (a.k.a. a negative pair). The main disadvantage of this method is its heavy reliance on negative pairs, which makes a large batch size [6] or a memory bank [15] necessary for effective training. To overcome this challenge, BYOL [12] uses two Siamese neural networks: the online network and the target network. The online network is trained to predict the target network's representation of the same image under a different augmented view, so only one positive pair per sample is required. This makes BYOL more robust to the batch size and the choice of data augmentations. As the positive pair in BYOL is generated from the same image, the diversity of features within the positive pair can be quite limited. For example, one skin disease may manifest differently in different patients or body locations, but such information is overlooked in the current BYOL framework.

Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_12.

© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 119–129, 2023.
https://doi.org/10.1007/978-3-031-43907-0_12

In this paper, we argue that such feature diversity can be increased by adding positive pairs formed with other samples that share the same label (a.k.a. true positives). Identifying such pairs without human annotation is challenging because medical images contain unrelated information, such as the normal skin in the background of dermoscopic images. A straightforward way to detect positive pairs is to use feature similarity: two images are considered positive if their representations are close in the feature space. However, samples with different labels might also be close in the feature space because the learned encoder is not perfect.
Treating such samples as positives would pull them even closer together during learning, leading to degraded performance. To solve this problem, we propose BYOL-TracIn, which improves vanilla BYOL using the TracIn influence function. Instead of quantifying the similarity of two samples by feature similarity, we propose to use TracIn to estimate their similarity by calculating the impact of training on one sample on the other. TracIn [22] is a gradient-based influence function that measures how much training on one sample reduces the loss of another sample. Directly applying TracIn in BYOL is non-trivial, as it requires the gradient of each sample and a careful selection of model checkpoints and data augmentations to accurately estimate sample impacts without labels. To avoid computing per-sample gradients one sample at a time, we introduce an efficient method that computes the pairwise TracIn in a mini-batch with only one forward pass. For each image in the mini-batch, the sample from the other images with the highest TracIn value is selected to form an additional positive pair, and their representation distance is then minimized using the BYOL loss. To improve positive selection accuracy, we propose to use a pre-trained model for pairwise TracIn computation, since it focuses more on task-related features than an on-the-fly model. Light augmentations are applied to the samples used for TracIn computation to ensure stable positive identification. To the best of our knowledge, we are the first to incorporate additional positive pairs from different images in BYOL. Our extensive empirical results show that our proposed method outperforms other competing approaches in both semi-supervised and transfer learning settings for medical image classification tasks.
2 Related Work
Self-supervised Learning. Most SSL methods can be categorized as either generative [10,28] or discriminative [11,21]; in either case, pseudo labels are automatically generated from the inputs. Recently, contrastive learning [6,15,27], a discriminative SSL approach, has dominated this field because of its excellent performance. SimCLR [6] and MoCo [15] are two typical contrastive learning methods that attract positive pairs and repel negative pairs. However, these methods rely on a large number of negative samples to work well. BYOL [12] improves on contrastive learning by directly predicting the representation of another view and achieves SOTA performance; as a result, only positive pairs are needed for training. SimSiam [7] further shows that the stop-gradient operation plays an essential role in the training stability of Siamese neural networks. Since the positive pairs in BYOL come from the same image, the feature diversity across different images with the same label is ignored. Our method introduces a novel way to accurately identify such positive pairs and attract them in the feature space.

Influence Function. The influence function (IF) was first introduced to machine learning models in [20] to study the following question: which training points are most responsible for a given prediction? Intuitively, removing an important sample from the training set leads to a large increase in the test loss. IF can thus be considered an interpretability score that measures the importance of each training sample for a given test sample. Aside from IF, other types of scores and variants have also been proposed in this field [3,5,14]. Since IF is extremely computationally expensive, TracIn [22] was proposed as an efficient alternative that estimates training-sample influence using a first-order approximation. Our method extends TracIn to the SSL setting (i.e., BYOL) with a dedicated positive-pair selection scheme and an efficient batch-wise per-sample gradient computation method, demonstrating that, aside from model interpretation, TracIn can also be used to guide SSL pre-training.
3 Method

3.1 Framework Overview
Our BYOL-TracIn framework is built upon the classical BYOL method [12]. Figure 1 shows an overview of our framework. Here, we use x1 as the anchor sample for illustration; the same logic applies to all samples in the mini-batch. Unlike classical BYOL, where only one positive pair (x1 and x1′) generated from the same image is used, we use the influence function TracIn to find another sample (x3) in the batch that has the largest impact on the anchor sample. During training, the representation distance between x1 and x3 is also minimized. We expect this additional positive pair to increase the variance and diversity of the features of the same label, leading to better clustering in the feature space and improved learning performance. The pairwise TracIn matrix is computed using a first-order gradient approximation, which is discussed in the next subsection. For simplicity, this paper only selects the top-1 additional sample, but our method can easily be extended to include top-k (k > 1) additional samples.

Fig. 1. Overview of the proposed BYOL-TracIn framework. X and X′ represent two augmentations of the mini-batch inputs. BYOL-TracIn minimizes the similarity loss of two views of the same image (e.g., q1 and z1′) as well as the similarity loss of the additional positive (e.g., z3′) identified by our TracIn algorithm. sg means stop-gradient.

3.2 Additional Positive Selection Using TracIn
Idealized TracIn and Its First-order Approximation. Suppose we have a training dataset $D = \{x_1, x_2, \ldots, x_n\}$ with n samples. $f_w(\cdot)$ is a model with parameter vector $w$, and $\ell(w, x_i)$ is the loss when the model parameters are $w$ and the training example is $x_i$. The training process in iteration t can be viewed as minimizing the training loss $\ell(w_t, x_t)$ and updating the parameters from $w_t$ to $w_{t+1}$ using gradient descent (assuming that only one example $x_t \in D$ is used for training in each iteration). The idealized TracIn of one sample $x_i$ on another sample $x_k$ can then be defined as the total reduction in the loss on $x_k$ caused by training on $x_i$ over the whole training process:

$$\mathrm{TracInIdeal}(x_i, x_k) = \sum_{t:\, x_t = x_i}^{T} \big(\ell(w_t, x_k) - \ell(w_{t+1}, x_k)\big), \tag{1}$$

where T is the total number of iterations. If stochastic gradient descent is used as the optimizer, the change in the loss on $x_k$ after iteration t can be approximated as $\ell(w_{t+1}, x_k) - \ell(w_t, x_k) = \nabla\ell(w_t, x_k) \cdot (w_{t+1} - w_t) + O(\|\Delta w_t\|^2)$. The parameter change in iteration t is $\Delta w_t = w_{t+1} - w_t = -\eta_t \nabla\ell(w_t, x_t)$, where $\eta_t$ is the learning rate in iteration t and $x_t$ is the training example. Since $\eta_t$ is usually small during training, we can ignore the higher-order term $O(\|\Delta w_t\|^2)$, and the first-order TracIn can be formulated as:

$$\mathrm{TracIn}(x_i, x_k) = \sum_{t:\, x_t = x_i}^{T} \eta_t \, \nabla\ell(w_t, x_k) \cdot \nabla\ell(w_t, x_i). \tag{2}$$
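To see what the first-order approximation in Eq. (2) drops, the sketch below runs a single SGD step on a toy linear-regression loss, a stand-in chosen purely for illustration (BYOL uses a different loss). It compares the exact loss reduction on x_k (one term of Eq. (1)) with the gradient dot product η_t ∇ℓ(w_t, x_k) · ∇ℓ(w_t, x_i) (one term of Eq. (2)). For this quadratic loss the discarded O(‖Δw_t‖²) term has the closed form −½(Δw_t · x_k)², which the last lines compute.

```python
import numpy as np


def loss(w, x, y):
    """Toy squared-error loss l(w, x) = 0.5 * (w . x - y)^2 of a linear model."""
    return 0.5 * (w @ x - y) ** 2


def grad(w, x, y):
    """Gradient of the toy loss with respect to w."""
    return (w @ x - y) * x


rng = np.random.default_rng(0)
w_t = rng.normal(size=5)
x_i, y_i = rng.normal(size=5), 1.0    # sample used for the update (x_t = x_i)
x_k, y_k = rng.normal(size=5), -0.5   # sample whose loss we monitor
eta = 1e-3                            # small learning rate eta_t

w_next = w_t - eta * grad(w_t, x_i, y_i)                        # one SGD step on x_i
ideal = loss(w_t, x_k, y_k) - loss(w_next, x_k, y_k)            # one term of Eq. (1)
first_order = eta * grad(w_t, x_k, y_k) @ grad(w_t, x_i, y_i)  # one term of Eq. (2)

# For a quadratic loss the neglected O(||dw||^2) term is exactly -0.5 * (dw . x_k)^2.
residual = -0.5 * ((w_next - w_t) @ x_k) ** 2
print(f"exact={ideal:.6e}  first-order={first_order:.6e}  residual={residual:.3e}")
```

Because the residual scales with η_t², it shrinks much faster than the first-order term as the learning rate decreases, which is the justification given above for dropping it.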
Equation 2 shows that we can estimate the influence of $x_i$ on $x_k$ by summing their gradient dot products across all training iterations. In practical BYOL training, optimization is done on mini-batches, and it is impractical to save the gradients of a sample for all iterations. However, we can use the TracIn of the current iteration to represent the similarity of two samples in the mini-batch, because we care about relative pairwise influences rather than exact totals across training. Intuitively, if the TracIn of two samples is large in the current iteration, training on one sample benefits the other considerably because they share common features; therefore, they are similar to each other.

Efficient Batch-wise TracIn Computation. Equation 2 requires the gradient of each sample in the mini-batch for pairwise TracIn computation. However, computing the gradients of samples one by one is prohibitively expensive. Moreover, calculating the dot product of gradients over the entire model is computationally and memory-intensive, especially for large deep learning models that can have millions or trillions of parameters. Therefore, we work with the gradients of the last linear layer in the online predictor. Because a standard backward pass in current deep learning frameworks (e.g., PyTorch and TensorFlow) accumulates a single gradient over the whole mini-batch instead of returning per-sample gradients, we use the following method to efficiently compute the per-sample gradient of the last layer. Suppose the weight matrix of the last linear layer is $W \in \mathbb{R}^{n \times m}$, where m and n are the numbers of input and output units, respectively. $f(q) = 2 - 2 \cdot \langle q, z \rangle / (\|q\|_2 \cdot \|z\|_2)$ is the standard BYOL loss function, where q is the online predictor output (a.k.a. logits) and z is the target encoder output, which can be viewed as a constant during training. We have $q = Wa$, where a is the input to the last linear layer.
According to the chain rule, the gradient of the last linear layer can be computed as $\nabla_W f(q) = \nabla_q f(q)\, a^{\top}$, in which the gradient of the logits is

$$\nabla_q f(q) = 2 \cdot \left( \frac{\langle q, z \rangle \cdot q}{\|q\|_2^3 \cdot \|z\|_2} - \frac{z}{\|q\|_2 \cdot \|z\|_2} \right). \tag{3}$$

Therefore, the TracIn of samples $x_i$ and $x_k$ at iteration t can be computed as:

$$\mathrm{TracIn}(x_i, x_k) \approx \eta_t \nabla_W f(q_i) \cdot \nabla_W f(q_k) = \eta_t \big(\nabla_q f(q_i) \cdot \nabla_q f(q_k)\big)\big(a_i \cdot a_k\big). \tag{4}$$
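Equations (3) and (4) translate directly into a batched computation. Stacking the per-sample logit gradients into G (B × n) and the last-layer inputs into A (B × m), the whole B × B TracIn matrix is η_t (G Gᵀ) ⊙ (A Aᵀ), because the element-wise (Frobenius) inner product of two outer products g_i a_iᵀ and g_k a_kᵀ factorizes as (g_i · g_k)(a_i · a_k). The NumPy sketch below is an illustration under assumptions (random toy shapes, a fixed learning rate of 0.1, plain NumPy instead of a deep learning framework); it checks the batched result against explicitly formed per-sample weight gradients.

```python
import numpy as np


def byol_loss(q, z):
    """Standard BYOL loss f(q) = 2 - 2 <q, z> / (||q|| ||z||) for one sample."""
    return 2.0 - 2.0 * (q @ z) / (np.linalg.norm(q) * np.linalg.norm(z))


def logit_grad(Q, Z):
    """Eq. (3) applied row-wise: gradient of f w.r.t. q for every sample in the batch."""
    qn = np.linalg.norm(Q, axis=1, keepdims=True)
    zn = np.linalg.norm(Z, axis=1, keepdims=True)
    dot = np.sum(Q * Z, axis=1, keepdims=True)
    return 2.0 * (dot * Q / (qn ** 3 * zn) - Z / (qn * zn))


def batch_tracin(A, Q, Z, lr):
    """Eq. (4) for all pairs at once: lr * (G G^T) element-wise times (A A^T)."""
    G = logit_grad(Q, Z)               # (B, n) per-sample logit gradients
    return lr * (G @ G.T) * (A @ A.T)  # (B, B) pairwise TracIn matrix


rng = np.random.default_rng(0)
B, m, n, LR = 4, 6, 3, 0.1
A = rng.normal(size=(B, m))   # inputs a_i to the last linear layer
W = rng.normal(size=(n, m))   # last-layer weights, q = W a
Q = A @ W.T                   # online predictor outputs q_i, one per row
Z = rng.normal(size=(B, n))   # target outputs z_i (treated as constants)
T = batch_tracin(A, Q, Z, LR)

# Reference: explicit per-sample weight gradients grad_W f = grad_q f a^T.
G = logit_grad(Q, Z)
per_sample = [np.outer(G[i], A[i]) for i in range(B)]
T_ref = np.array([[LR * np.sum(gi * gk) for gk in per_sample] for gi in per_sample])
print(np.allclose(T, T_ref))
```

Since η_t is shared by all pairs in an iteration, it does not change which sample ranks highest for a given anchor; it only rescales the matrix.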
Equations 3 and 4 show that the per-sample gradient of the last linear layer can be computed from the inputs of this layer and the gradient of the output logits for each sample, which requires only one forward pass on the mini-batch. This technique greatly speeds up the TracIn computation and makes it practical to use in BYOL.

Using a Pre-trained Model to Increase True Positives. During BYOL pre-training, especially in the early stages, the model can be unstable and may focus on unrelated background features instead of the target features. This can result in wrong positive pairs being selected by TracIn. For example, the model may identify all images with skin diseases on the face as positive pairs, even if they have different diagnoses, because it focuses on the face instead of the disease. To address this issue, we propose using a pre-trained model to select additional positives with TracIn to guide BYOL training. A pre-trained model is more stable and has learned to focus on the target features, which increases the ratio of true positives among the selected samples.
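Once the pairwise TracIn matrix S is available (computed with the current model in BYOL-TracIn, or with a frozen pre-trained model in BYOL-TracIn-pretrained), selecting the top-1 additional positive amounts to a row-wise argmax that excludes the anchor itself. The sketch below assumes that the additional term reuses the BYOL loss between the anchor's online prediction and the selected sample's target projection, as suggested by Fig. 1, and that the two terms are weighted equally; this weighting and the helper names are assumptions of the illustration, not details given in the text. It also omits the symmetrized BYOL loss obtained by swapping the two views.

```python
import numpy as np


def select_additional_positive(S):
    """Index of the highest-TracIn *other* sample for every anchor (top-1)."""
    S = np.array(S, dtype=float)   # copy so the caller's matrix is untouched
    np.fill_diagonal(S, -np.inf)   # an image cannot be its own additional positive
    return S.argmax(axis=1)


def cosine_loss(Q, Z):
    """Row-wise BYOL loss 2 - 2 cos(q_i, z_i)."""
    Qn = Q / np.linalg.norm(Q, axis=1, keepdims=True)
    Zn = Z / np.linalg.norm(Z, axis=1, keepdims=True)
    return 2.0 - 2.0 * np.sum(Qn * Zn, axis=1)


def byol_tracin_loss(Q, Z, S):
    """Batch mean of the usual BYOL term plus one additional-positive term.

    Q: online predictions (B, n); Z: target projections of the other view (B, n),
    treated as constants (stop-gradient); S: pairwise TracIn matrix (B, B).
    """
    j = select_additional_positive(S)
    return float(np.mean(cosine_loss(Q, Z) + cosine_loss(Q, Z[j]))), j


S = np.array([[5.0, 1.0, 3.0],
              [2.0, 9.0, 0.0],
              [4.0, 7.0, 1.0]])
Q = Z = np.eye(3)   # identical views, mutually orthogonal images
loss, partners = byol_tracin_loss(Q, Z, S)
print(partners, loss)   # [2 0 1] 2.0
```

In the example the regular BYOL term is zero (each prediction matches its own target) while every selected partner is orthogonal to its anchor, so each additional term contributes 2. A top-k variant could replace the argmax with, for instance, `np.argsort(-S_masked, axis=1)[:, :k]` and average the k extra terms; that extension is likewise an assumption here.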
4 Experiments and Results

4.1 Experimental Setups
Datasets. We evaluate the performance of the proposed BYOL-TracIn on four publicly available medical image datasets. (1) The ISIC 2019 dataset is a dermatology dataset that contains 25,331 dermoscopic images from nine diagnostic categories [8,9,25]. (2) The ISIC 2016 dataset was hosted at ISBI 2016 [13]. It contains 900 dermoscopic lesion images with two classes: benign and malignant. (3) The ChestX-ray dataset is a chest X-ray database that comprises 108,948 frontal-view X-ray images of 32,717 unique patients with 14 disease labels [26]. Each image may have multiple labels. (4) The Shenzhen dataset is a small chest X-ray dataset with 662 frontal chest X-rays, of which 326 are normal cases and 336 show manifestations of tuberculosis [18].

Training Details. We use ResNet-18 as the backbone. The online projector and predictor follow the classical BYOL [12], and the embedding dimension is set to 256. On both the ISIC 2019 and ChestX-ray datasets, we resize all images to 140 × 140 and then crop them to 128 × 128. Data augmentation used in pre-training includes horizontal flipping, vertical flipping, rotation, color jitter, and cropping. For TracIn computation, we use one view with no augmentation and the other view with horizontal flipping and center cropping, because this setting gave the best empirical results in our experiments. We pre-train the model for 300 epochs using the SGD optimizer with momentum 0.9 and weight decay 1 × 10⁻⁵. The learning rate is set to 0.1 for the first 10 epochs and then decays following a cosine learning rate schedule. The batch size is set to 256. The moving-average decay of the momentum encoder is set to 0.99 at the beginning and then gradually increased to 1 following a cosine schedule. All experiments are performed on one NVIDIA GeForce GTX 1080 GPU.

Baselines. We compare the performance of our method with a random initialization approach without pre-training and with the following SOTA baselines that involve pre-training.
(1) BYOL [12]: vanilla BYOL with one positive pair from the same image. (2) FNC [17]: a false negative identification method designed to improve contrastive SSL frameworks. We adapt it to BYOL to select additional positives, because false negatives are true positives for a particular anchor sample. (3) FT [30]: a feature transformation method used in contrastive learning that creates harder positives and negatives to improve learning ability. We apply it in BYOL to create harder virtual positives. (4) FS: using feature similarity within the current mini-batch to select the top-1 additional positive. (5) FS-pretrained: unlike FS, which uses the current model to compute feature similarity on the fly, this baseline uses a pre-trained model, to test whether a well-trained encoder is more helpful in identifying additional positives. (6) BYOL-Sup: supervised BYOL, in which we randomly select one additional positive from the mini-batch using the label information. This baseline serves as the upper bound of our method because its additional positive is always a true positive. We evaluate two variants of our method, BYOL-TracIn and BYOL-TracIn-pretrained. The former uses the current training model to compute TracIn at each iteration, while the latter uses a pre-trained model. For a fair comparison, all methods use the same pre-training and fine-tuning settings unless otherwise specified. For FS-pretrained and BYOL-TracIn-pretrained, the pre-trained model uses the same setting as BYOL. Note that this pre-trained model is only used for positive selection and is not involved in training.

Table 1. Comparison of all methods on the ISIC 2019 and ChestX-ray datasets in the semi-supervised setting, reported as mean (std). We also report the fine-tuning results on 100% of the datasets. BYOL-Sup is the upper bound of our method. BMA represents the balanced multiclass accuracy.

                        | ISIC 2019 (BMA ↑)                      | ChestX-ray (AUC ↑)
Method                  | 10%          50%          100%         | 10%          50%          100%
Random                  | 0.327(.004)  0.558(.005)  0.650(.004)  | 0.694(.005)  0.736(.001)  0.749(.001)
BYOL [12]               | 0.399(.001)  0.580(.006)  0.692(.005)  | 0.699(.004)  0.738(.003)  0.750(.001)
FNC [17]                | 0.401(.004)  0.584(.004)  0.694(.005)  | 0.706(.001)  0.739(.001)  0.752(.002)
FT [30]                 | 0.405(.005)  0.588(.008)  0.695(.005)  | 0.708(.001)  0.743(.001)  0.751(.002)
FS                      | 0.403(.006)  0.591(.003)  0.694(.004)  | 0.705(.003)  0.738(.001)  0.752(.002)
FS-pretrained           | 0.406(.002)  0.596(.004)  0.697(.005)  | 0.709(.001)  0.744(.002)  0.752(.002)
BYOL-TracIn             | 0.403(.003)  0.594(.004)  0.694(.004)  | 0.705(.001)  0.742(.003)  0.753(.002)
BYOL-TracIn-pretrained  | 0.408(.007)  0.602(.003)  0.700(.006)  | 0.712(.001)  0.746(.002)  0.754(.002)
BYOL-Sup                | 0.438(.006)  0.608(.007)  0.705(.005)  | 0.714(.001)  0.748(.001)  0.756(.003)

4.2 Semi-supervised Learning
In this section, we evaluate our method by finetuning the pre-trained encoder with limited annotations on the same dataset used for pre-training. We sample 10% or 50% of the labeled data from the ISIC 2019 and ChestX-ray training sets and finetune the model for 100 epochs on the sampled datasets. Data augmentation is the same as in pre-training. Table 1 shows the comparisons of all methods. For ISIC 2019, we report the balanced multiclass accuracy (BMA, suggested by the ISIC challenge). For ChestX-ray, we report the average AUC across all diagnoses. We run each finetuning experiment 5 times with different random seeds and report the mean and standard deviation. From Table 1, we make the following observations.

(1) All the other methods achieve better accuracy than Random, which means that pre-training can indeed help downstream tasks.

(2) The other pre-training methods improve on vanilla BYOL on both datasets. This shows that additional positives can increase feature diversity and benefit BYOL learning.

(3) Our BYOL-TracIn-pretrained consistently outperforms all other unsupervised baselines. BYOL-TracIn improves on BYOL, but it can be worse than other baselines such as FT and FS-pretrained (e.g., with 10% of ISIC 2019). This is because some additional positives identified by the on-the-fly model may be false positives, and attracting the representations of such samples degrades the learned features. With a pre-trained model in BYOL-TracIn-pretrained, however, the identification accuracy increases, leading to more true positives and better representations.

(4) BYOL-TracIn-pretrained performs better than FS-pretrained in all settings, with an improvement in BMA of up to 0.006. This suggests that TracIn can be a more reliable metric than feature similarity for assessing the similarity between images when no human label information is available.

(5) Supervised BYOL greatly increases BYOL performance on both datasets. However, our BYOL-TracIn-pretrained shows only a marginal accuracy drop relative to supervised BYOL when enough training samples are available (e.g., 100% of ISIC 2019).

To further demonstrate the advantage of TracIn over Feature Similarity (FS) in selecting additional positive pairs for BYOL, Fig. 2 uses an image from ISIC 2019 as an example. It shows the top-3 most similar images selected by each metric using a BYOL pre-trained model. All three images selected by TracIn have the same label as the anchor image, whereas two of the three images selected by FS have a different label. This discrepancy may arise because the FS of these two images is dominated by unrelated features (e.g., background tissue), which makes it unreliable. More visualization examples can be found in the supplementary material.
126
D. Zeng et al.
[Figure 2: an ISIC 2019 anchor image (label NV) and the top-3 most similar images in its mini-batch. Selected by TracIn: NV (0.023), NV (0.018), NV (0.016). Selected by FS: MEL (0.907), MEL (0.894), NV (0.892).]
Fig. 2. Comparison of TracIn and Feature Similarity (FS) in selecting the additional positive during training on ISIC 2019.
Table 2. Transfer learning comparison of the proposed method with the baselines on ISIC 2016 and Shenzhen datasets.

Method                   | ISIC 2016 Precision ↑ | Shenzhen AUC ↑
Random                   | 0.400 (.005)          | 0.835 (.010)
BYOL [12]                | 0.541 (.008)          | 0.858 (.003)
FNC [17]                 | 0.542 (.007)          | 0.862 (.006)
FT [30]                  | 0.559 (.011)          | 0.876 (.005)
FS                       | 0.551 (.003)          | 0.877 (.004)
FS-pretrained            | 0.556 (.004)          | 0.877 (.006)
BYOL-TracIn              | 0.555 (.012)          | 0.880 (.007)
BYOL-TracIn-pretrained   | 0.565 (.010)          | 0.883 (.001)
BYOL-Sup                 | 0.592 (.008)          | 0.893 (.006)

4.3 Transfer Learning
To evaluate the transfer learning performance of the learned features, we use the encoder learned during pre-training to initialize the model on the downstream datasets (ISIC 2019 transfers to ISIC 2016, and ChestX-ray transfers to Shenzhen). We finetune the model for 50 epochs and report the precision on ISIC 2016 and the AUC on Shenzhen. Table 2 shows the comparison results of all methods. BYOL-TracIn-pretrained outperforms the other unsupervised pre-training baselines on both datasets, indicating that the additional positives can help BYOL learn better transferable features.
Additional Positive Enables Better Representation Learning
127
5 Conclusion
In this paper, we propose a simple yet effective method, named BYOL-TracIn, to boost the representation learning performance of the vanilla BYOL framework. BYOL-TracIn identifies additional positives among the other samples in the mini-batch without using label information, thus introducing more variance into the learned features. Experimental results on multiple public medical image datasets show that our method improves classification performance in both semi-supervised and transfer learning settings. This paper only discusses one additional positive pair for each image, but our method can easily be extended to multiple additional pairs. However, more pairs introduce more computational cost and increase the false-positive rate, which may degrade performance. Another limitation is that BYOL-TracIn needs a pre-trained model to reach its best performance (BYOL-TracIn-pretrained), which requires additional computational resources.
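The selection step at the core of BYOL-TracIn, and the extension to several additional pairs mentioned above, can be illustrated with a small NumPy sketch. It is not the authors' implementation. It assumes that per-sample loss gradients (for example, flattened gradients of a single layer) have already been computed for every image in the mini-batch. It uses the checkpoint form of TracIn from Pruthi et al. [22], a learning-rate-weighted sum of gradient dot products over checkpoints. With a single model (K = 1), as in BYOL-TracIn or BYOL-TracIn-pretrained, this reduces to one scaled dot product. The helper names `tracin_scores`, `feature_similarity` and `select_additional_positives` are hypothetical, and cosine similarity is only an assumed stand-in for the FS baseline.

```python
import numpy as np


def tracin_scores(grads: np.ndarray, lrs: np.ndarray) -> np.ndarray:
    """Pairwise TracIn-style scores inside one mini-batch.

    grads: shape (K, B, D), per-sample loss gradients of B images at K
           checkpoints (K = 1 when a single, e.g. pre-trained, model is used).
    lrs:   shape (K,), learning rate associated with each checkpoint.
    Returns S of shape (B, B) with S[i, j] = sum_k lrs[k] * <grads[k, i], grads[k, j]>.
    """
    return np.einsum("k,kid,kjd->ij", lrs, grads, grads)


def feature_similarity(feats: np.ndarray) -> np.ndarray:
    """Cosine similarity between encoder features (assumed form of the FS baseline)."""
    unit = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    return unit @ unit.T


def select_additional_positives(scores: np.ndarray, k: int = 1) -> np.ndarray:
    """Return, for every anchor, the indices of the k highest-scoring other images."""
    s = scores.astype(float)  # astype returns a copy, so the caller's matrix is untouched
    np.fill_diagonal(s, -np.inf)  # an image is never its own additional positive
    return np.argsort(-s, axis=1)[:, :k]


# Toy mini-batch of B = 4 images with 2-D gradients from a single checkpoint.
grads = np.array([[[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.8]]])
S = tracin_scores(grads, lrs=np.array([0.1]))
print(select_additional_positives(S).ravel())  # one additional positive per anchor
print(select_additional_positives(S, k=2))     # the multi-pair extension
```

With k > 1 the same routine returns several candidates per anchor. Every extra candidate has a lower score than the first, which is one way to see why more pairs raise the false-positive risk noted above.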
References
1. Azizi, S., et al.: Big self-supervised models advance medical image classification. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 3478–3488 (2021)
2. Bai, W., et al.: Self-supervised learning for cardiac MR image segmentation by anatomical position prediction. In: Shen, D., Liu, T., Peters, T.M., Staib, L.H., Essert, C., Zhou, S., Yap, P.-T., Khan, A. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 541–549. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8_60
3. Barshan, E., Brunet, M.E., Dziugaite, G.K.: RelatIF: identifying explanatory training samples via relative influence. In: International Conference on Artificial Intelligence and Statistics, pp. 1899–1909. PMLR (2020)
4. Chaitanya, K., Erdil, E., Karani, N., Konukoglu, E.: Contrastive learning of global and local features for medical image segmentation with limited annotations. Adv. Neural. Inf. Process. Syst. 33, 12546–12558 (2020)
5. Chen, H., et al.: Multi-stage influence function. Adv. Neural. Inf. Process. Syst. 33, 12732–12742 (2020)
6. Chen, T., Kornblith, S., Norouzi, M., Hinton, G.: A simple framework for contrastive learning of visual representations. In: International Conference on Machine Learning, pp. 1597–1607. PMLR (2020)
7. Chen, X., He, K.: Exploring simple siamese representation learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 15750–15758 (2021)
8. Codella, N., et al.: Skin lesion analysis toward melanoma detection 2018: a challenge hosted by the International Skin Imaging Collaboration (ISIC). arXiv preprint arXiv:1902.03368 (2019)
9. Combalia, M., et al.: BCN20000: dermoscopic lesions in the wild. arXiv preprint arXiv:1908.02288 (2019)
10. Doersch, C., Gupta, A., Efros, A.A.: Unsupervised visual representation learning by context prediction. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 1422–1430 (2015)
11. Gidaris, S., Singh, P., Komodakis, N.: Unsupervised representation learning by predicting image rotations. arXiv preprint arXiv:1803.07728 (2018)
12. Grill, J.B., et al.: Bootstrap your own latent - a new approach to self-supervised learning. Adv. Neural. Inf. Process. Syst. 33, 21271–21284 (2020)
13. Gutman, D., et al.: Skin lesion analysis toward melanoma detection: a challenge at the International Symposium on Biomedical Imaging (ISBI) 2016, hosted by the International Skin Imaging Collaboration (ISIC). arXiv preprint arXiv:1605.01397 (2016)
14. Hara, S., Nitanda, A., Maehara, T.: Data cleansing for models trained with SGD. In: Advances in Neural Information Processing Systems 32 (2019)
15. He, K., Fan, H., Wu, Y., Xie, S., Girshick, R.: Momentum contrast for unsupervised visual representation learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 9729–9738 (2020)
16. Hu, X., Zeng, D., Xu, X., Shi, Y.: Semi-supervised contrastive learning for label-efficient medical image segmentation. In: de Bruijne, M., Cattin, P.C., Cotin, S., Padoy, N., Speidel, S., Zheng, Y., Essert, C. (eds.) MICCAI 2021. LNCS, vol. 12902, pp. 481–490. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87196-3_45
17. Huynh, T., Kornblith, S., Walter, M.R., Maire, M., Khademi, M.: Boosting contrastive self-supervised learning with false negative cancellation. In: Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 2785–2795 (2022)
18. Jaeger, S., Candemir, S., Antani, S., Wáng, Y.X.J., Lu, P.X., Thoma, G.: Two public chest x-ray datasets for computer-aided screening of pulmonary diseases. Quant. Imaging Med. Surg. 4(6), 475 (2014)
19. Jaiswal, A., Babu, A.R., Zadeh, M.Z., Banerjee, D., Makedon, F.: A survey on contrastive self-supervised learning. Technologies 9(1), 2 (2020)
20. Koh, P.W., Liang, P.: Understanding black-box predictions via influence functions. In: International Conference on Machine Learning, pp. 1885–1894. PMLR (2017)
21. Noroozi, M., Favaro, P.: Unsupervised learning of visual representations by solving jigsaw puzzles. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9910, pp. 69–84. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46466-4_5
22. Pruthi, G., Liu, F., Kale, S., Sundararajan, M.: Estimating training data influence by tracing gradient descent. Adv. Neural. Inf. Process. Syst. 33, 19920–19930 (2020)
23. Sowrirajan, H., Yang, J., Ng, A.Y., Rajpurkar, P.: MoCo pretraining improves representation and transferability of chest x-ray models. In: Medical Imaging with Deep Learning, pp. 728–744. PMLR (2021)
24. Tajbakhsh, N., Jeyaseelan, L., Li, Q., Chiang, J.N., Wu, Z., Ding, X.: Embracing imperfect datasets: a review of deep learning solutions for medical image segmentation. Med. Image Anal. 63, 101693 (2020)
25. Tschandl, P., Rosendahl, C., Kittler, H.: The HAM10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. Sci. Data 5(1), 1–9 (2018)
26. Wang, X., Peng, Y., Lu, L., Lu, Z., Bagheri, M., Summers, R.M.: ChestX-ray8: hospital-scale chest x-ray database and benchmarks on weakly-supervised classification and localization of common thorax diseases. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 2097–2106 (2017)
27. Zbontar, J., Jing, L., Misra, I., LeCun, Y., Deny, S.: Barlow twins: self-supervised learning via redundancy reduction. In: International Conference on Machine Learning, pp. 12310–12320. PMLR (2021)
28. Zhang, R., Isola, P., Efros, A.A.: Colorful image colorization. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9907, pp. 649–666. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46487-9_40
29. Zhang, Y., Jiang, H., Miura, Y., Manning, C.D., Langlotz, C.P.: Contrastive learning of medical visual representations from paired images and text. In: Machine Learning for Healthcare Conference, pp. 2–25. PMLR (2022)
30. Zhu, R., Zhao, B., Liu, J., Sun, Z., Chen, C.W.: Improving contrastive learning by visualizing feature transformation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10306–10315 (2021)
Multi-modal Semi-supervised Evidential Recycle Framework for Alzheimer’s Disease Classification
Yingjie Feng¹, Wei Chen², Xianfeng Gu³, Xiaoyin Xu⁴, and Min Zhang¹(✉)
¹ Collaborative Innovation Center of Artificial Intelligence, College of Computer Science and Technology, Zhejiang University, Hangzhou, China
min [email protected]
² Zhejiang University Affiliated Sir Run Run Shaw Hospital, Hangzhou, China
³ Department of Computer Science, Stony Brook University, Stony Brook, NY, USA
⁴ Department of Radiology, Brigham and Women’s Hospital, Harvard Medical School, Boston, MA, USA
Abstract. Alzheimer’s disease (AD) is an irreversible neurodegenerative disease, so early identification of AD and of its early-stage disorder, mild cognitive impairment (MCI), is of great significance. However, currently available labeled datasets are still small, so semi-supervised classification algorithms would benefit clinical applications. We propose a novel uncertainty-aware semi-supervised learning framework based on an improved evidential regression. Our framework uses the aleatoric uncertainty (AU) from the data itself and the epistemic uncertainty (EU) from the model to optimize the evidential classifier and the feature extractor step by step. With only a small amount of labeled data, it achieves performance close to that of supervised learning. Experiments on the ADNI-2 dataset demonstrate the effectiveness of our method.
Keywords: Semi-supervised learning · Deep evidential regression · EfficientNet-V2 · Alzheimer’s disease · Multi-modality
1 Introduction
Alzheimer’s disease (AD) is an irreversible neurodegenerative disease that leaves patients with impairments in memory, language, and cognition [7]. Previous works [6,22] have shown that combining image data with other related data improves model performance, but how to efficiently combine statistical non-imaging data with medical image data is still an open question. In addition, patient data are not difficult to obtain and collect, but labeled data remain scarce [11]. Two causes are subjective bias in the AD diagnosis process and the time-consuming, complicated process of labeling diagnostic results. Therefore, models that need only a small amount of labeled data to achieve high accuracy have attracted great attention [14]. Semi-supervised learning (SSL) methods are commonly used in medical image analysis to address the lack of manually annotated data [24]. Hang et al. [10] proposed a contrastive self-ensembling framework for semi-supervised medical image classification that introduces a weighting formula and reliability awareness. Aviles-Rivero et al. [3] proposed a multi-modal hypergraph diffusion network, based on diffusion models and hypergraph learning, for semi-supervised AD classification. In [4], the authors introduced CSEAL, a semi-supervised learning framework for multi-label chest X-ray classification that combines consistency-based SSL with uncertainty-based active learning. Other studies on evidential learning [5,17] have shown the great potential of this theory for fitting low-dimensional manifolds in high-dimensional spaces for classification with uncertainty estimates. This property makes evidential learning models promising for SSL on medical images. Evidential deep learning (EDL) [21] allows a model to better estimate the uncertainty in multi-class classification tasks. In binary classification, evidential regression can be constrained to output a continuous probability value between 0 and 1. This often gives more accurate results than EDL, which uses the Dirichlet distribution to obtain a discrete distribution [19].
M. Zhang was partially supported by NSFC62202426. X. Gu was partially supported by NIH 3R01LM012434-05S1, 1R21EB029733-01A1, NSF FAIN-2115095, NSF CMMI1762287.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 130–140, 2023. https://doi.org/10.1007/978-3-031-43907-0_13
Fig. 1. An iteration of optimization of our model. The AU optimization process does not rely on any label and obtains the best classification of the current model by reducing AU. Then, on the basis of this classification, EU optimization relies on ground truth and pseudo labels to optimize the model to get better prediction results.
The residual between the prediction results of the imperfect model and the true distribution of the data can be decomposed into aleatoric uncertainty (AU) and epistemic uncertainty (EU). Theoretically, the former comes from the noise of the data, which usually does not depend on the sample size. Therefore, by
iteratively reducing this part of the uncertainty, the best classification results can be obtained for a given amount of data. The latter depends on the sample size. As the sample size increases, reducing this part of the uncertainty brings the model closer to the observed distribution or lets it fit more complex conditions, which improves the performance of the model itself. Based on this understanding, we use the ability of evidential regression to handle uncertainty in order to decompose it into its two parts, AU and EU. We then propose a method that adjusts these two parts to achieve semi-supervised classification, as shown in Fig. 1. Our main contributions are: 1) we adjust the loss function of evidential regression so that it obtains more accurate results and better separates AU and EU; 2) we build a multi-layer, multi-step network to implement evidential regression and propose a step-by-step semi-supervised training method; 3) we achieve a new SOTA for semi-supervised learning on the ADNI dataset, with performance close to supervised learning using only a small amount of labeled data.
2 Methods
2.1 Original Deep Evidential Regression (DER)
DER [2] adopts the simplest setting: $y_i \sim \mathcal{N}(\mu, \sigma_i^2)$. In a Bayesian framework, this corresponds to placing a normal inverse-gamma (NIG) distribution $\mathrm{NIG}(\mu, \sigma^2 \mid m)$, $m = (\gamma, \nu, \alpha, \beta)$, as the conjugate prior of a normal distribution with unknown mean $\mu$ and variance $\sigma^2$. Marginalizing over these nuisance parameters, the likelihood of an observation $y_i$ for a given $m$ follows a Student-t distribution with $2\alpha_i$ degrees of freedom:

$$L_i^{NIG} = \mathrm{St}\left(y_i \,\middle|\, \gamma_i, \frac{\beta_i(1+\nu_i)}{\nu_i\alpha_i}, 2\alpha_i\right).$$

For known $m$, Amini et al. [2] defined the prediction of $y_i$ as $\mathbb{E}[\mu_i] = \gamma_i$ and defined AU and EU as $u_a$ and $u_e$:

$$u_a^2 = \mathbb{E}[\sigma_i^2] = \frac{\beta_i}{\alpha_i - 1}, \qquad u_e^2 = \mathrm{Var}[\mu_i] = \frac{\mathbb{E}[\sigma_i^2]}{\nu_i}.$$

The training loss is

$$L_i(w) = -\log L_i^{NIG}(w) + \lambda L_i^{R}(w), \qquad L_i^{R}(w) = |y_i - \gamma_i| \cdot \Phi,$$

where $m = \mathrm{NN}(w)$ is specified by a neural network (NN) with weights $w$, $\lambda$ is a hyperparameter, and $\Phi = 2\nu_i + \alpha_i$ represents the total evidence gained from training.

2.2 Evidential Regression Beyond DER
Although DER has achieved some success in both theory and practice [2,15], Meinert et al. [16] point out that it has some major flaws. First, although the loss function includes the regularization term $L^R$, the constraint on the parameter $\beta_i$ is insufficient. Second, although AU and EU are defined separately, the correlation between them is too high.
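The second flaw can be made concrete with the original definitions from Sect. 2.1. Because $u_e^2 = u_a^2/\nu_i$, the ratio $u_e/u_a$ equals $1/\sqrt{\nu_i}$ whatever the values of $\alpha_i$ and $\beta_i$. For a fixed $\nu_i$, the two uncertainties therefore always move together. The plain-Python sketch below evaluates these formulas. The function name `der_uncertainties` is hypothetical and only illustrates the equations; it is not code from [2] or [16].

```python
from __future__ import annotations

import math


def der_uncertainties(nu: float, alpha: float, beta: float) -> tuple[float, float]:
    """Original DER uncertainties for NIG parameters (nu, alpha, beta), alpha > 1.

    u_a^2 = E[sigma^2] = beta / (alpha - 1)   (aleatoric)
    u_e^2 = Var[mu]    = E[sigma^2] / nu      (epistemic)
    """
    if alpha <= 1.0:
        raise ValueError("alpha must exceed 1 for E[sigma^2] to be finite")
    expected_var = beta / (alpha - 1.0)
    u_a = math.sqrt(expected_var)
    u_e = math.sqrt(expected_var / nu)
    return u_a, u_e


# Vary alpha and beta at a fixed nu: both uncertainties change, but their ratio does not.
for alpha, beta in [(3.0, 2.0), (5.0, 8.0), (1.5, 0.1)]:
    u_a, u_e = der_uncertainties(nu=4.0, alpha=alpha, beta=beta)
    print(f"alpha={alpha}, beta={beta}: u_a={u_a:.3f}, u_e={u_e:.3f}, u_e/u_a={u_e / u_a:.3f}")
```

For every pair in the loop, the printed ratio is $1/\sqrt{4} = 0.5$. Only $\nu_i$ carries information that separates EU from AU.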
Fig. 2. Architecture of our network. The recycle classifier part judges the decrease of the uncertainty of the predicted value, and controls the framework to train the other two parts in a loop.
In practice, disentangling uncertainty information and using it effectively for training remains challenging. Based on empirical and theoretical analysis, Meinert et al. [16] state that the width of the Student-t distribution induced by the NIG distribution, $w_{St}$, better reflects the noise in the data. Correspondingly, we represent EU by the residual factor $1/\sqrt{\nu_i}$ that relates $u_e$ to $u_a$ in the original definition:

$$u_A = w_{St} = \sqrt{\frac{\beta_i(1+\nu_i)}{\alpha_i\nu_i}}, \qquad u_E = \frac{u_e}{u_a} = \frac{1}{\sqrt{\nu_i}},$$

where $\nu_i$, $\alpha_i$, and $\beta_i$ are parameters of the evidential distribution $m = (\gamma, \nu, \alpha, \beta)$. We verify the performance of this new uncertainty estimation method through experiments.

2.3 Model and Workflow
With efficient estimates of AU and EU, our model has a basis for implementation. As shown in Fig. 2, our model is divided into three parts: a multimodal feature extractor, an evidential predictor, and a recycle classifier. The multimodal feature extractor forms a high-dimensional feature space, and the evidential predictor generates the evidential distribution in this feature space. The classification result and the uncertainties (AU and EU) are then calculated from the evidential distribution. The recycle classifier uses them to control a step-by-step recurrent training workflow that reaches the best performance.
AU for Training Classifier. Under the manifold assumption, real data are gathered on a low-dimensional manifold of the high-dimensional space, and noisy data are located at the edge of the manifold for the corresponding
category. When evidential regression (ER) is used to fit the manifold, these noisy samples become marginal data with high AU. Optimizing the classifier to iteratively reduce AU yields the optimal classification result under the current conditions. We use $L^a$ to optimize AU:

$$L_i^{a}(w) = -\log L_i^{NIG}(w) + \lambda_a \frac{\gamma_i - \gamma_i^2}{w_{St}^2}\,\Phi.$$

In the above formula, $\lambda_a \in [0.005, 0.01]$ is a hyperparameter that controls the weight of the regularization term, $w_{St}$ follows the previous definition, and $\Phi$ is the total amount of evidence learned by the model. To better encourage the model to learn, we adopt the work of Liu et al. [15] and use the form $\Phi = \nu_i + 2\alpha_i$. We use the expanded form of $-\log L_i^{NIG}$ with minor adjustments according to the optimization objective; specifically,

$$-\log L_i^{NIG} = \frac{1}{2}\log\left(\frac{\pi}{\nu}\right) - \alpha\log(2\beta + 2\beta\nu) + \left(\alpha + \frac{1}{2}\right)\log\left(\nu(\gamma_i - \gamma_i^2) + 2\beta + 2\beta\nu\right) + \log\frac{\Gamma(\alpha)}{\Gamma(\alpha + \frac{1}{2})}.$$
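The AU objective above can be sketched for a single sample as follows. This is a minimal plain-Python illustration, not the authors' code, and it makes several assumptions. The NIG parameters are given as scalars. The prediction $\gamma_i$ lies in $[0, 1]$ for a binary task. The evidence is taken as $\Phi = \nu_i + 2\alpha_i$, i.e., built from the evidence parameters $\nu_i$ and $\alpha_i$. The default $\lambda_a = 0.01$ is the upper end of the stated range. The function names are hypothetical. One way to read the label-free residual is that $\gamma_i - \gamma_i^2$ is the expected squared error $(y_i - \gamma_i)^2$ if the label were drawn from Bernoulli($\gamma_i$). It vanishes only when the prediction is fully confident.

```python
import math


def nig_nll(sq_residual: float, nu: float, alpha: float, beta: float) -> float:
    """-log L^NIG in its expanded Student-t form, with the squared residual passed in.

    With sq_residual = (y - gamma)^2 this is the usual DER likelihood term; the AU
    loss below passes gamma - gamma^2 instead, so no label is needed.
    """
    omega = 2.0 * beta + 2.0 * beta * nu
    return (0.5 * math.log(math.pi / nu)
            - alpha * math.log(omega)
            + (alpha + 0.5) * math.log(nu * sq_residual + omega)
            + math.lgamma(alpha) - math.lgamma(alpha + 0.5))


def student_t_width(nu: float, alpha: float, beta: float) -> float:
    """w_St = sqrt(beta * (1 + nu) / (alpha * nu)), used as the AU estimate u_A."""
    return math.sqrt(beta * (1.0 + nu) / (alpha * nu))


def au_loss(gamma: float, nu: float, alpha: float, beta: float, lam_a: float = 0.01) -> float:
    """Label-free AU loss L^a for one sample with predicted value gamma in [0, 1]."""
    residual = gamma - gamma ** 2  # gamma * (1 - gamma); zero only at 0 or 1
    w_st = student_t_width(nu, alpha, beta)
    evidence = nu + 2.0 * alpha    # assumed form of the total evidence Phi
    regularizer = lam_a * residual / w_st ** 2 * evidence
    return nig_nll(residual, nu, alpha, beta) + regularizer


# A confident prediction is cheaper than an undecided one under the same NIG parameters.
print(au_loss(0.95, nu=2.0, alpha=3.0, beta=0.5))
print(au_loss(0.50, nu=2.0, alpha=3.0, beta=0.5))
```

Passing the residual into `nig_nll` explicitly keeps one likelihood routine for both cases: the label-free AU term here and the label-based terms, which use $(y_i - \gamma_i)^2$.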
EU for Training Extractor. If only AU is optimized, a gap always remains between the model prediction and the real data. Since EU mainly reflects the bias of the model in its predictions, we use EU mainly to optimize the feature extractor. For data $D_l$ with ground-truth labels, we use

$$L_i^{l}(w) = -\log L_i^{NIG}(w) + \lambda_l \left(\frac{y_i - \gamma_i}{w_{St}}\right)^2 \Phi.$$

To increase the certainty on data with ground truth, we set a smaller $\lambda_l \in [0.005, 0.015]$. For the dataset $D_u$ without real labels, we replace the real labels with the predictions $y_i'$ obtained in the last training iteration, which gives

$$L_i^{u}(w) = -\log L_i^{NIG}(w) + \lambda_u \left(\frac{y_i' - \gamma_i}{w_{St}}\right)^2 \Phi.$$

So that the model can use the results of the previous round of learning, we set a larger $\lambda_u \in [0.015, 0.025]$, which makes the model more conservative about its predictions in the next iteration. This reduces the influence of misleading evidence. It also improves performance because the retained uncertainty leaves the model more room to optimize. To combine labeled and unlabeled data effectively, we weight the two kinds of data:

$$L_i^{e}(w) = \mu_l L_i^{l} + \mu_u L_i^{u} = -\log L_i^{NIG}(w) + \mu_u \lambda_u \left(\frac{y_i' - \gamma_i}{w_{St}}\right)^2 \Phi + \mu_l \lambda_l \left(\frac{y_i - \gamma_i}{w_{St}}\right)^2 \Phi,$$
where $\mu_l + \mu_u = 1$ and $\mu_l, \mu_u \in [0, 1]$ are two weighting factors.
Model. For the feature extractor, we use the recent EfficientNetV2, which achieved good results in combination with EDL in Feng et al. [8]. To avoid overfitting, we used the smallest model of this network and added Dropout at the output. To bridge the differences between the multi-modality data and the model input, we added a fully connected (FC) layer and a convolutional layer (Conv) that adaptively adjust the input channels. We employed three evidential FC layers proposed by Amini et al. [2] to form our
evidential predictor. To optimize AU and EU separately, we froze some parameters in the first two layers and limited the range of parameter adjustment in the last layer.
Workflow of Recycle Training.
First, without fixing any weights, we use a small part of the data for warm-up training.
Second, we freeze the weights of the extractor and of P2 and P3 in the evidential predictor and use $L^a$ to optimize the classifier. We calculate and record the AU score after each update. When the change $|\Delta AU|$ between updates is smaller than the threshold $T_a \in [0.0005, 0.001]$, the AU optimization cycle ends.
Then, we fix the weights of P1 and P3 in the evidential predictor and use $L^e$ to optimize the extractor. Similarly, when the change $|\Delta EU|$ brought by an update is less than the threshold $T_e \in [0.0025, 0.005]$, the cycle ends and outputs $y'$.
Finally, we fix all network parameters except P3 and fine-tune with the loss $L = L^e + L^a$. Fine-tuning stops when the change $|\Delta U| = |\Delta AU| + |\Delta EU|$ brought by an update is less than the threshold $T_u \in [0.002, 0.005]$.
All thresholds are adjusted according to the proportion of labeled and unlabeled data during training.
Table 1. Classification results of all comparison methods on the ADNI-2 dataset (%). The top section of the table shows the results of supervised learning (SL), while the bottom section shows the performance of current SSL SOTA methods.
Method                    | AD vs. NC            | EMCI vs. LMCI        | LMCI vs. NC
                          | ACC    SPE    SEN    | ACC    SPE    SEN    | ACC    SPE    SEN
Baseline [12]             | 80.53  80.10  80.32  | 74.10  73.18  75.85  | 72.05  70.80  71.26
SL SOTA [18,20,23]        | 96.84  98.23  95.76  | 92.40  93.70  89.50  | 92.49  91.08  93.48
Upper bound (UB)          | 94.45  93.80  94.07  | 89.95  90.29  90.81  | 88.56  88.81  86.32
--------------------------------------------------------------------------------------------
Π model [13]              | 90.45  85.82  90.05  | 81.58  80.15  84.50  | 80.65  83.48  78.75
DS3L [9]                  | 90.86  89.03  89.72  | 81.07  83.25  82.81  | 80.79  81.55  81.16
RFS-LDA [1]               | 92.11  89.50  88.40  | 80.90  81.05  83.63  | 81.90  84.72  80.05
Hypergraph diffusion [3]  | 92.11  92.80  91.33  | 85.22  86.40  84.02  | 82.01  84.01  81.80
Ours                      | 93.90  92.95  93.01  | 89.45  88.50  89.47  | 87.27  86.94  85.83

3 Experiments and Results
Data Description. We assess the effectiveness of our multimodal semi-supervised evidential recycle framework on the ADNI-2 dataset¹. It contains multi-center data of various modalities, including imaging and multiple phenotype data. Specifically, the dataset consists of four categories: normal control (NC), early mild cognitive impairment (EMCI), late mild cognitive impairment (LMCI), and Alzheimer’s disease (AD). To train effectively and balance the number of samples per category, we used a sample of 515 patients, with their MRI, PET, demographics, and APOE as inputs. For MRI, we used 3T T1-weighted and FLAIR MR images, preprocessed with the CAT12 and SPM tools. All MRI data were processed using a standard pipeline, including anterior commissure (AC)-posterior commissure (PC) correction, intensity correction, and skull stripping. Affine registration was performed to linearly align each MRI to the Colin27 template, and the result was resampled to 224 × 224 × 91 for subsequent processing. For PET, we used the official pre-processed AV-45 PET images and resampled them in the same way as the MRIs. We chose to include APOE in our analysis, as it is a well-established genetic risk factor for developing AD.
¹ Data used in preparation of this article were obtained from the Alzheimer’s Disease Neuroimaging Initiative (ADNI) database (adni.loni.usc.edu).
Evaluation. We evaluated our model in three ways. First, for comparability, we followed the conventions of most similar studies and selected three comparison tasks: AD vs. NC, LMCI vs. NC, and EMCI vs. LMCI. Second, we compared the results of our model under different numbers of ground-truth labels to verify that its performance improves as the amount of labeled data increases. Third, we conducted ablation experiments, shown in Fig. 3, to demonstrate the validity and rationality of each part of the proposed framework. In the ablation, CNNs denotes the baseline model: EfficientNetV2-S with its original supervised classifier, trained without unlabeled data. AU and EU denote training processes that use only the corresponding part.
DER uses our complete training process but, instead of our improved $u_A$ and $u_E$ estimates, uses the $u_a$ and $u_e$ estimates proposed in the original DER paper [2]. For a fair comparison, we ran all techniques under the same conditions. The results were evaluated by accuracy (ACC), specificity (SPE), and sensitivity (SEN).
Implementation Details. The performance upper bound is the result obtained when the model is trained with all input data labeled. Among current supervised learning algorithms, the performance of each algorithm on
Fig. 3. Ablation experiment results.
Table 2. Classification results (%) with different percentages of labeled and unlabeled data in the training process.

Method       | Labeled | AD vs. NC            | EMCI vs. LMCI        | LMCI vs. NC
             |         | ACC    SPE    SEN    | ACC    SPE    SEN    | ACC    SPE    SEN
Upper Bound  | 100%    | 94.45  93.80  94.07  | 89.95  90.29  90.81  | 88.56  88.81  86.32
Baseline     | 5%      | 75.11  73.92  75.79  | 73.17  69.47  75.14  | 68.95  67.81  68.53
DS3L         | 5%      | 78.92  76.75  78.14  | 74.68  71.65  72.39  | 74.27  70.58  73.71
Ours         | 5%      | 82.45  81.63  84.97  | 73.05  70.84  74.22  | 73.86  74.15  72.28
Baseline     | 10%     | 78.50  78.03  77.48  | 73.67  70.29  74.11  | 69.13  68.57  68.38
DS3L         | 10%     | 84.73  79.67  82.69  | 80.93  80.13  79.54  | 74.57  75.39  72.88
Ours         | 10%     | 90.18  89.73  87.42  | 81.67  83.45  82.29  | 80.01  78.25  80.89
Baseline     | 20%     | 80.53  80.10  80.32  | 74.10  73.18  75.85  | 72.05  70.80  71.26
DS3L         | 20%     | 90.86  89.03  89.72  | 81.07  83.25  82.81  | 80.79  81.55  81.16
Ours         | 20%     | 93.90  92.95  93.01  | 89.45  88.50  89.47  | 87.27  86.94  85.83
each task varies, so for comparison we selected three supervised learning papers [18,20,23], each representing the SOTA performance on one of the three tasks. Our implementation uses PyTorch v1.4.0 and the Adam optimizer with a learning rate of $1 \times 10^{-4}$ and a weight decay of $1 \times 10^{-5}$. We use a linear decay scheduler with a factor of 0.1 based on the loss functions above. The optimizer uses $\beta$ values of [0.9, 0.999] and an $\epsilon$ value of $1 \times 10^{-8}$. Since the SSL method needs to learn from unlabeled data, 100% of the data is used in training, and only part of it has ground-truth labels. At test time, metrics are computed only on the unlabeled data, so the data are not divided into separate training and test sets. To determine the threshold for each uncertainty, however, we randomly selected 10% of the data as a validation set. On this set, we calculated the uncertainty independently of the training process.
Results. We compared our model with the semi-supervised learning methods that currently perform best on the ADNI-2 dataset, as well as with other top models in the semi-supervised learning field. As shown in Table 1, our model achieved SOTA performance among semi-supervised methods on all three tasks. Moreover, our results are closer to the best supervised learning methods than those of the other semi-supervised algorithms. This indicates that our model performs well with less labeled data and that the algorithm may be feasible in clinical settings. Our ablation results are shown in Fig. 3. First, compared with the baseline, our semi-supervised learning algorithm effectively learns classification information from unlabeled data. Second, our uncertainty estimation surpasses the original DER method. Finally, the AU and EU variants demonstrate the importance of optimizing both the AU and EU components in our framework.
Fig. 4. Error rate under different percentages of labeled data; the well-known transductive effect in semi-supervised learning can be observed.
Table 2 compares our model with the representative semi-supervised learning algorithm DS3L [9]. Our model outperforms DS3L on every metric with 10% and 20% labeled data, and on all AD vs. NC metrics with 5%. With 5% labeled data, DS3L remains ahead on some metrics of the two MCI tasks (e.g., EMCI vs. LMCI ACC, 74.68 vs. 73.05). Our model also outperforms the baseline in all settings except EMCI vs. LMCI ACC and SEN at 5%, which indicates the learning efficiency of our framework. With 20% labeled data, the performance of our model is already very close to the upper bound obtained with 100% labeled data. This indicates that our model learns well from a small amount of labeled data. In addition, Fig. 4 plots the error rate of our framework under different amounts of labeled data. Performance improves as the labeled data increase from 5% to 10%, 15%, and 20%. Together with Table 2, this shows the well-known transductive effect in semi-supervised learning: beyond a certain amount of data, increasing the labeled set brings only marginal improvement. The effect is evident when comparing performance with 20%, 40%, 80%, and 100% labeled data.
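As a rough numerical check of this diminishing-returns argument, the sketch below recomputes error rates from the ACC values of our model in Table 2. It assumes that the error rate is 100 − ACC. Fig. 4 may use a different definition, and it also includes 15%, 40%, and 80% points that are not tabulated. The dictionary and helper names are hypothetical, and the script only restates numbers already given in Table 2.

```python
# Accuracy (%) of our model taken from Table 2; the 100% entry is the upper bound.
# Each tuple is (AD vs. NC, EMCI vs. LMCI, LMCI vs. NC).
ACC = {
    5: (82.45, 73.05, 73.86),
    10: (90.18, 81.67, 80.01),
    20: (93.90, 89.45, 87.27),
    100: (94.45, 89.95, 88.56),
}


def error_rates(acc_by_fraction):
    """Error rate per task, assuming error rate = 100 - ACC (rounded to 2 decimals)."""
    return {p: tuple(round(100.0 - a, 2) for a in accs)
            for p, accs in acc_by_fraction.items()}


def error_drop(errors, lo, hi):
    """Reduction in error rate (percentage points) from lo% to hi% labeled data."""
    return tuple(round(e_lo - e_hi, 2) for e_lo, e_hi in zip(errors[lo], errors[hi]))


ERR = error_rates(ACC)
for lo, hi in [(5, 10), (10, 20), (20, 100)]:
    print(f"{lo}% -> {hi}% labeled: error drops by {error_drop(ERR, lo, hi)} points")
```

For AD vs. NC, the error falls by 7.73 points from 5% to 10% labeled data, by 3.72 points from 10% to 20%, and by only 0.55 points from 20% to 100%. For the two MCI tasks, the step from 10% to 20% still brings a sizeable drop (7.78 and 7.26 points). There, the flattening appears between 20% and 100% (0.50 and 1.29 points).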
4 Conclusions
We proposed an evidential regression-based semi-supervised learning framework that uses the characteristics of AU and EU to train the classifier and the extractor, respectively. Our model achieves SOTA performance among semi-supervised methods on the ADNI-2 dataset. Because it is semi-supervised, our model also has advantages in incorporating additional (e.g., private) data, fine-tuning for downstream tasks, and avoiding overfitting, which gives it great potential for clinical applications.
Multi-modal Semi-supervised Evidential Recycle Framework
References
1. Adeli, E., et al.: Semi-supervised discriminative classification robust to sample-outliers and feature-noises. IEEE Trans. Pattern Anal. Mach. Intell. 41(2), 515–522 (2018)
2. Amini, A., Schwarting, W., Soleimany, A., Rus, D.: Deep evidential regression. Adv. Neural. Inf. Process. Syst. 33, 14927–14937 (2020)
3. Aviles-Rivero, A.I., Runkel, C., Papadakis, N., Kourtzi, Z., Schönlieb, C.B.: Multimodal hypergraph diffusion network with dual prior for Alzheimer classification. In: MICCAI 2022, Part III. LNCS, pp. 717–727. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16437-8_69
4. Balaram, S., Nguyen, C.M., Kassim, A., Krishnaswamy, P.: Consistency-based semi-supervised evidential active learning for diagnostic radiograph classification. In: MICCAI 2022, Part I. LNCS, vol. 13431, pp. 675–685. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16431-6_64
5. Bengs, V., Hüllermeier, E., Waegeman, W.: Pitfalls of epistemic uncertainty quantification through loss minimisation. In: Advances in Neural Information Processing Systems (2022)
6. Cobbinah, B.M., et al.: Reducing variations in multi-center Alzheimer’s disease classification with convolutional adversarial autoencoder. Med. Image Anal. 82, 102585 (2022)
7. De Strooper, B., Karran, E.: The cellular phase of Alzheimer’s disease. Cell 164(4), 603–615 (2016)
8. Feng, Y., Wang, J., An, D., Gu, X., Xu, X., Zhang, M.: End-to-end evidential-efficient net for radiomics analysis of brain MRI to predict oncogene expression and overall survival. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022, Part III, vol. 13433, pp. 282–291. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16437-8_27
9. Guo, L.Z., Zhang, Z.Y., Jiang, Y., Li, Y.F., Zhou, Z.H.: Safe deep semi-supervised learning for unseen-class unlabeled data. In: International Conference on Machine Learning, pp. 3897–3906. PMLR (2020)
10. Hang, W., Huang, Y., Liang, S., Lei, B., Choi, K.S., Qin, J.: Reliability-aware contrastive self-ensembling for semi-supervised medical image classification. In: Medical Image Computing and Computer Assisted Intervention-MICCAI 2022: 25th International Conference, Singapore, September 18–22, 2022, Proceedings, Part I, pp. 754–763. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16431-6_71
11. Hett, K., Ta, V.T., Oguz, I., Manjón, J.V., Coupé, P., Initiative, A.D.N., et al.: Multi-scale graph-based grading for Alzheimer’s disease prediction. Med. Image Anal. 67, 101850 (2021)
12. Huang, G., Liu, Z., Van Der Maaten, L., Weinberger, K.Q.: Densely connected convolutional networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4700–4708 (2017)
13. Laine, S., Aila, T.: Temporal ensembling for semi-supervised learning. In: International Conference on Learning Representations (2016)
14. Li, Z., Togo, R., Ogawa, T., Haseyama, M.: Chronic gastritis classification using gastric x-ray images with a semi-supervised learning method based on tri-training. Med. Biol. Eng. Comput. 58, 1239–1250 (2020)
15. Liu, Z., Amini, A., Zhu, S., Karaman, S., Han, S., Rus, D.L.: Efficient and robust lidar-based end-to-end navigation. In: 2021 IEEE International Conference on Robotics and Automation (ICRA), pp. 13247–13254. IEEE (2021)
16. Meinert, N., Gawlikowski, J., Lavin, A.: The unreasonable effectiveness of deep evidential regression. arXiv e-prints pp. arXiv-2205 (2022)
17. Neupane, K.P., Zheng, E., Yu, Q.: MetaEDL: Meta evidential learning for uncertainty-aware cold-start recommendations. In: 2021 IEEE International Conference on Data Mining (ICDM), pp. 1258–1263. IEEE (2021)
18. Ning, Z., Xiao, Q., Feng, Q., Chen, W., Zhang, Y.: Relation-induced multi-modal shared representation learning for Alzheimer’s disease diagnosis. IEEE Trans. Med. Imaging 40(6), 1632–1645 (2021)
19. Oh, D., Shin, B.: Improving evidential deep learning via multi-task learning. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 36, pp. 7895–7903 (2022)
20. Pei, Z., Wan, Z., Zhang, Y., Wang, M., Leng, C., Yang, Y.H.: Multi-scale attention-based pseudo-3D convolution neural network for Alzheimer’s disease diagnosis using structural MRI. Pattern Recogn. 131, 108825 (2022)
21. Sensoy, M., Kaplan, L., Kandemir, M.: Evidential deep learning to quantify classification uncertainty. In: Advances in Neural Information Processing Systems 31 (2018)
22. Song, X., et al.: Graph convolution network with similarity awareness and adaptive calibration for disease-induced deterioration prediction. Med. Image Anal. 69, 101947 (2021)
23. Song, X., et al.: Multi-center and multi-channel pooling GCN for early AD diagnosis based on dual-modality fused brain network. IEEE Trans. Med. Imaging (2022)
24. Yang, X., Song, Z., King, I., Xu, Z.: A survey on deep semi-supervised learning. IEEE Trans. Knowl. Data Eng. (2022)
3D Arterial Segmentation via Single 2D Projections and Depth Supervision in Contrast-Enhanced CT Images
Alina F. Dima (1,2, corresponding author), Veronika A. Zimmer (1,2), Martin J. Menten (1,4), Hongwei Bran Li (1,3), Markus Graf (2), Tristan Lemke (2), Philipp Raffler (2), Robert Graf (1,2), Jan S. Kirschke (2), Rickmer Braren (2), and Daniel Rueckert (1,2,4)
1 School of Computation, Information and Technology, Technical University of Munich, Munich, Germany
2 School of Medicine, Klinikum Rechts der Isar, Technical University of Munich, Munich, Germany
3 Department of Quantitative Biomedicine, University of Zurich, Zurich, Switzerland
4 Department of Computing, Imperial College London, London, UK
Abstract. Automated segmentation of the blood vessels in 3D volumes is an essential step for the quantitative diagnosis and treatment of many vascular diseases. 3D vessel segmentation is being actively investigated, mostly with deep learning approaches. However, training 3D deep networks requires large amounts of manual 3D annotations from experts, which are laborious to obtain. This is especially the case for 3D vessel segmentation: vessels are sparse yet spread out over many slices, and they appear disconnected when visualized in 2D slices. In this work, we propose a novel method to segment the 3D peripancreatic arteries solely from one annotated 2D projection per training image with depth supervision. We perform extensive experiments on the segmentation of peripancreatic arteries on 3D contrast-enhanced CT images and demonstrate how well our method captures the rich depth information in 2D projections. We show that by annotating a single, randomly chosen projection for each training sample, we obtain performance comparable to annotating multiple 2D projections, thereby reducing the annotation effort. Furthermore, by mapping the 2D labels to the 3D space using depth information and incorporating this into training, we almost close the performance gap between 3D supervision and 2D supervision. Our code is available at: https://github.com/alinafdima/3Dseg-mip-depth.
Keywords: vessel segmentation · 3D segmentation · weakly supervised segmentation · curvilinear structures · 2D projections
Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_14.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 141–151, 2023. https://doi.org/10.1007/978-3-031-43907-0_14
A. F. Dima et al.
1 Introduction
Automated segmentation of blood vessels in 3D medical images is a crucial step for the diagnosis and treatment of many diseases. The segmentation can aid in visualization, help with surgery planning, be used to compute biomarkers, and support further downstream tasks. Automatic vessel segmentation has been extensively studied, both with classical computer vision algorithms [16] such as vesselness filters [8] and, more recently, with deep learning [3,5,6,11,19,21], where state-of-the-art performance has been achieved for various vessel structures. Supervised deep learning typically requires large, well-curated training sets, which are often laborious to obtain. This is especially the case for 3D vessel segmentation. Manually delineating 3D vessels typically involves visualizing and annotating a 3D volume through a sequence of 2D cross-sectional slices. This medium is poorly suited to visualizing 3D vessels, because often only the cross-section of a vessel is visible in a 2D slice. To segment a vessel, the annotator has to track its cross-section through several adjacent slices, which is especially tedious for curved or branching vessel trees. Projecting 3D vessels onto a 2D plane makes the entire vessel tree visible within a single 2D image. This gives a more complete view of the tree and can ease the burden of manual annotation. Kozinski et al. [13] propose to annotate up to three maximum intensity projections (MIPs) for the task of centerline segmentation, obtaining results comparable to full 3D supervision. Centerline segmentation disregards the vessel diameter. Training a 3D vessel segmentation model from 2D annotations therefore poses additional segmentation-specific challenges, as 2D projections only capture the outline of the vessels and provide no information about their interior.
Furthermore, the axes of projection are crucial for the model’s success, given the sparsity of information in 2D annotations. To achieve 3D vessel segmentation with only 2D supervision from projections, we first investigate which viewpoints to annotate in order to maximize segmentation performance. We show that it is feasible to segment the full extent of vessels in 3D images with high accuracy by annotating only a single randomly selected 2D projection per training image. This approach substantially reduces the annotation effort, even compared to works training only on 2D projections. Second, by mapping the 2D annotations to the 3D space using the depth of the MIPs, we obtain a partially segmented 3D volume that can be used as an additional supervision signal. We demonstrate the utility of our method on the challenging task of peripancreatic arterial segmentation on contrast-enhanced arterial-phase computed tomography (CT) images, which feature large variance in vessel diameter. Our contribution to 3D vessel segmentation is three-fold:
◦ Our work shows that highly accurate automatic segmentation of 3D vessels can be learned by annotating single MIPs.
◦ Based on extensive experimental results, we determine that the best annotation strategy is to label randomly selected viewpoints, which also substantially reduces the annotation cost.
3D Arterial Segmentation via Single 2D Projections and Depth Supervision
◦ By incorporating additional depth information obtained from 2D annotations at no extra cost to the annotator, we almost close the gap between 3D supervision and 2D supervision.
2 Related Work
Learning from Weak Annotations. Weak annotations have been used in deep learning segmentation to reduce the annotation effort through cheaper, less accurate, or sparser labeling [20]. Bai et al. [1] learn to perform aortic image segmentation by sparsely annotating only a subset of the input slices. Multiple instance learning approaches group pixels into bins and provide labels only at the bin level. Jia et al. [12] successfully use this approach to segment cancer in histopathology images. Annotating 2D projections of 3D data is another way of using weak segmentation labels, and it has recently gained popularity in the medical domain. Bayat et al. [2] propose to learn the spine posture from 2D radiographs, while Zhou et al. [22] use multi-planar MIPs for multi-organ segmentation of the abdomen. Kozinski et al. [13] propose to segment vessel centerlines using as few as 2–3 annotated MIPs. Chen et al. [4] train a vessel segmentation model from unsupervised 2D labels transferred from a publicly available dataset; however, a gap remains between unsupervised and supervised model performance. Our work uses weak annotations in the form of annotated 2D MIPs for the task of peripancreatic vessel segmentation. We attempt to reduce the annotation cost to a minimum by annotating only a single projection per training input without sacrificing performance. Incorporating Depth Information. Depth is an inherent property of the 3D world. Depth information is lost whenever 3D data is projected onto a lower-dimensional space. In natural images, this loss is inherent to image acquisition; therefore, methods to recover or model depth have been developed for 3D natural data. For instance, Fu et al. [9] use neural implicit fields to semantically segment images by transferring labels from 3D primitives to 2D images. Lawin et al. [14] propose to segment 3D point clouds by projecting them onto 2D and training a 2D segmentation network.
At inference time, the predicted 2D segmentation labels are mapped back to the original 3D space using the depth information. In the medical domain, depth information has been used in volume rendering techniques [7] to aid with visualization. It has so far not been used to recover the information lost when working with 2D projections of 3D volumes. We take the conceptually opposite approach to Lawin et al. [14]: we project 3D volumes onto 2D to facilitate and reduce annotation. We use depth information to map the 2D annotations to the original 3D space at annotation time and generate partial 3D segmentation volumes, which we incorporate in training as an additional loss term.
3 Methodology
Overview. The maximum intensity projection (MIP) of a 3D volume $I \in \mathbb{R}^{N_x \times N_y \times N_z}$ is defined as the highest intensity along a given axis:
$$\mathrm{mip}(x, y) = \max_z I(x, y, z), \qquad \mathrm{mip} \in \mathbb{R}^{N_x \times N_y}. \tag{1}$$
For simplicity, we only describe MIPs along the z-axis, but they can be performed on any image axis.
Fig. 1. Method overview. We train a 3D network to segment vessels from 2D annotations. Given an input image $I$, depth-encoded MIPs $p_{fw}$, $p_{bw}$ are generated by projecting the input image to 2D. 2D binary labels $A$ are generated by annotating one 2D projection per image. The 2D annotation is mapped to the 3D space using the depth information, resulting in a partially labeled 3D volume $D$. During training, both 2D annotations and 3D depth maps are used as supervision signals in a combined loss. This loss uses both the predicted 3D segmentation $Y$ and its 2D projection $\mathrm{mip}(Y)$.
Exploiting the fact that arteries are hyperintense in arterial-phase CTs, we propose to annotate MIPs of the input volume for binary segmentation. The hyperintensity of the arteries ensures their visibility in the MIP, while additional processing removes most of the occluding nearby tissue (Sect. 4). Given a binary 2D annotation of a MIP $A \in \{0, 1\}^{N_x \times N_y}$, we map the foreground pixels in $A$ to the original 3D image space. This is achieved by using the first and last z coordinates at which the maximum intensity is observed along each projection ray. The vessels in the abdominal cavity are relatively sparse in 2D projections, and most of the occluding tissue is removed in preprocessing. This step therefore yields a fairly complete surface of the vessel tree. Furthermore, we can partially fill this surface volume, resulting in a 3D depth map $D$, which is a partial segmentation of the vessel tree. We use the 2D annotations as well as the depth map to train a 3D segmentation network in a weakly supervised manner.
An overview of our method is presented in Fig. 1. In the following, we describe its components and how they are combined to train a 3D segmentation network. Depth Information. We can view the MIP as capturing the intensity of the brightest voxel along each ray $r_{xy} \in \mathbb{R}^{N_z}$, where $r_{xy}(z) = I(x, y, z)$. Along each projection ray, we define the forward depth $z_{fw} = \min\{z : I(x, y, z) = \mathrm{mip}(x, y)\}$ and the backward depth $z_{bw} = \max\{z : I(x, y, z) = \mathrm{mip}(x, y)\}$. These are the first and last z coordinates at which the ray attains the MIP intensity. This information can be used for (1) enhancing the MIP visualization and (2) mapping pixels from the 2D MIP back to the 3D space (depth map). The maximum intensity is typically reached multiple times along a ray because our images are clipped, which removes much of the intensity fluctuation.
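The MIP of Eq. (1) and the forward and backward depths can be computed with a few vectorized NumPy operations. This is a minimal sketch that assumes the volume has already been clipped and is stored as an array of shape (N_x, N_y, N_z), projected along the last axis. `mip_with_depth` is a hypothetical helper and is not taken from the authors' repository.

```python
import numpy as np

def mip_with_depth(volume):
    """Return the z-axis MIP (Eq. 1) and the forward/backward depth of each ray.

    volume: array of shape (Nx, Ny, Nz); ray r_xy is volume[x, y, :].
    z_fw / z_bw: first / last z index at which the ray reaches its maximum.
    """
    mip = volume.max(axis=2)                           # mip(x, y) = max_z I(x, y, z)
    at_max = volume == mip[..., np.newaxis]            # True where the ray attains the MIP value
    nz = volume.shape[2]
    z_fw = at_max.argmax(axis=2)                       # argmax returns the first True along z
    z_bw = nz - 1 - at_max[..., ::-1].argmax(axis=2)   # first True from the back = last True
    return mip, z_fw, z_bw

# Usage on a toy volume: one ray with a clipped plateau, one with a single peak.
vol = np.zeros((2, 1, 6))
vol[0, 0, 2:5] = 1.0            # plateau at z = 2..4, as produced by intensity clipping
vol[1, 0, [1, 3]] = [0.5, 0.7]  # single maximum at z = 3
mip, z_fw, z_bw = mip_with_depth(vol)
# mip[:, 0] -> [1.0, 0.7], z_fw[:, 0] -> [2, 3], z_bw[:, 0] -> [4, 3]
```

`np.argmax` returns the first index of the maximum. Applied to the boolean mask, it finds the first voxel at the MIP intensity; applied to the reversed mask, it finds the last one. On unclipped data the mask would usually hold a single voxel per ray, so `z_fw` and `z_bw` would coincide.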
Fig. 2. Example depth-enhanced MIP using (a) forward depth $z_{fw}$ and (b) backward depth $z_{bw}$ visualized in color; (c) binary 2D annotation; a slice view from a 3D volume illustrating (e) the forward depth (green) and backward depth (blue), (f) the depth map, and (g) the 3D ground truth; volume rendering of (h) the depth map and (d) the depth map with only forward and backward depth pixels. The input images are contrast-enhanced. (Color figure online)
Depth-Enhanced MIP. We encode depth information into the MIPs by combining the MIP with the forward and backward depth, respectively, to achieve better depth perception during annotation. $p_{fw} = \sqrt{\mathrm{mip} \cdot z_{fw}}$ defines the forward projection, while $p_{bw} = \sqrt{\mathrm{mip} \cdot z_{bw}}$ defines the backward projection. Figure 2 shows (a) forward and (b) backward depth-encoded MIPs. Depth Map Generation. Foreground pixels from the 2D annotations are mapped to the 3D space by combining a 2D annotation with the forward and backward depth, resulting in a 3D partial vessel segmentation:
1. Create an empty 3D volume $D \in \mathbb{R}^{N_x \times N_y \times N_z}$.
2. For each foreground pixel $(x, y)$ in the annotation $A$, label $(x, y, z_{fw})$ and $(x, y, z_{bw})$ as foreground voxels in $D$.
3. If the intensity fluctuation between $z_{fw}$ and $z_{bw}$ along the ray $r_{xy}$ in the source image $I$ is below a certain threshold, the intermediate voxels are also labeled as foreground in $D$.
Training Loss. We train a 3D segmentation network to predict a binary 3D vessel segmentation from a 3D input volume using 2D annotations. Our training set $\mathcal{D}_{tr}(I, A, D)$ consists of 3D volumes $I$ paired with 2D annotations $A$ and their corresponding 3D depth maps $D$. Given the 3D network output $Y = \theta(I)$, we minimize the following loss during training:
$$\mathcal{L}(Y) = \alpha \cdot \mathrm{CE}(A, \mathrm{mip}(Y)) + (1 - \alpha) \cdot \mathrm{CE}(D, Y) \cdot D, \tag{2}$$
where $\alpha \in [0, 1]$. Our final loss is a convex combination of two terms. Term (a) is the cross-entropy (CE) between the 2D annotation and the network output projected to 2D. Term (b) is the cross-entropy between the network output and the depth map, applied only to positive voxels of the depth map (hence the elementwise multiplication by $D$). Notably, the 2D loss constrains the shape of the vessels, while the depth loss promotes the segmentation of the vessel interior.
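The three depth-map steps and the loss of Eq. (2) can be sketched in NumPy as follows. Several details are assumptions not given in the text:
- the intensity fluctuation in step 3 is measured as the range (max minus min) of the ray between z_fw and z_bw, and its threshold `max_fluctuation` is a free parameter;
- CE is taken as voxel-wise binary cross-entropy on predicted probabilities;
- the depth term is averaged over the voxels where D is positive.

All function names are hypothetical, and a real training loop would compute the loss in an autodiff framework.

```python
import numpy as np

def build_depth_map(annotation, volume, z_fw, z_bw, max_fluctuation):
    """Map a binary 2D MIP annotation A of shape (Nx, Ny) to a partial 3D label volume D.

    z_fw / z_bw: per-ray first and last z index of the maximum (z_fw <= z_bw).
    max_fluctuation: hypothetical threshold on the intensity range between z_fw and z_bw.
    """
    depth = np.zeros(volume.shape, dtype=np.uint8)              # step 1: empty volume
    for x, y in zip(*np.nonzero(annotation)):                   # step 2: each foreground pixel
        zf, zb = int(z_fw[x, y]), int(z_bw[x, y])
        depth[x, y, zf] = 1
        depth[x, y, zb] = 1
        segment = volume[x, y, zf:zb + 1]
        if segment.max() - segment.min() < max_fluctuation:     # step 3: fill nearly flat rays
            depth[x, y, zf:zb + 1] = 1
    return depth

def binary_cross_entropy(target, prob, eps=1e-7):
    """Element-wise binary cross-entropy between labels and predicted probabilities."""
    target = np.asarray(target, dtype=float)
    prob = np.clip(prob, eps, 1 - eps)
    return -(target * np.log(prob) + (1 - target) * np.log(1 - prob))

def combined_loss(pred, annotation, depth, alpha=0.5):
    """Eq. (2): alpha * CE(A, mip(Y)) + (1 - alpha) * CE(D, Y) restricted to voxels with D > 0."""
    loss_2d = binary_cross_entropy(annotation, pred.max(axis=2)).mean()  # projected output vs A
    mask = depth > 0
    loss_depth = binary_cross_entropy(depth, pred)[mask].mean() if mask.any() else 0.0
    return alpha * loss_2d + (1 - alpha) * loss_depth

# Usage on a single ray whose clipped plateau spans z = 1..3.
vol = np.array([[[0.0, 1.0, 1.0, 1.0, 0.0]]])
A = np.array([[1]])
D = build_depth_map(A, vol, np.array([[1]]), np.array([[3]]), max_fluctuation=0.5)
# D[0, 0] -> [0, 1, 1, 1, 0]; with max_fluctuation=0.0 only the two surface voxels are set
loss = combined_loss(np.full(vol.shape, 0.5), A, D)  # equals log(2) for an uninformative prediction
```

The default `alpha=0.5` matches the value reported in the Training and Evaluation paragraph.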
4 Experimental Design
Dataset. We use an in-house dataset of contrast-enhanced abdominal computed tomography images (CTs) in the arterial phase to segment the peripancreatic arteries [6]. The cohort consists of 141 patients with pancreatic ductal adenocarcinoma, with an equal ratio of male to female patients. Given a 3D arterial CT of the abdominal area, we automatically extract the vertebrae [15,18] and semi-automatically extract the ribs. In arterial CTs these have intensities similar to those of arteries and would otherwise occlude the vessels. To remove as much of the cluttering surrounding tissue as possible and increase the visibility of the vessels in the projections, the input is windowed so that the vessels appear hyperintense. Details of the exact preprocessing steps can be found in Table 2 of the supplementary material. The dataset contains binary 3D annotations of the peripancreatic arteries carried out by two radiologists, each of whom annotated half of the dataset. The 2D annotations we use in our experiments are projections of these 3D annotations. For more information about the dataset, see [6]. Image Augmentation and Transformation. As the annotations lie on a 2D plane, 3D spatial augmentation cannot be used due to the information sparsity in the ground truth. Instead, we apply an invertible transformation $T$ to the input volume and apply the inverse transformation $T^{-1}$ to the network output before computing the loss, so that the ground truth need not be altered. A detailed description of the augmentations and transformations used can be found in Table 1 in the supplementary material.
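The invertible-transformation trick can be sketched as follows. The network sees T(I), and its output is mapped back with T^{-1} before the loss, so the 2D annotation A and depth map D stay in their original frame. The concrete transform family, random axis flips plus 90-degree rotations in the x-y plane, is an illustrative assumption; the actual augmentations are listed in the supplementary material. All names are hypothetical.

```python
import numpy as np

def sample_invertible_transform(rng):
    """Sample a random invertible spatial transform T and return (T, T_inv).

    Illustrative family (assumption): optional flips along each of the three axes,
    followed by a rotation by k * 90 degrees in the x-y plane.
    """
    flip_axes = tuple(int(a) for a in np.nonzero(rng.random(3) < 0.5)[0])
    k = int(rng.integers(4))

    def forward(vol):
        out = np.flip(vol, axis=flip_axes) if flip_axes else vol
        return np.rot90(out, k=k, axes=(0, 1))

    def inverse(vol):
        out = np.rot90(vol, k=-k, axes=(0, 1))                       # undo the rotation first
        return np.flip(out, axis=flip_axes) if flip_axes else out    # flips are self-inverse

    return forward, inverse

def predict_in_label_frame(model, volume, rng):
    """Apply the model to T(I) and map the output back with T_inv, so that the loss
    can be computed against the untransformed annotation A and depth map D."""
    forward, inverse = sample_invertible_transform(rng)
    return inverse(model(forward(volume)))

# Usage with a stand-in "model" (element-wise scaling) on a non-cubic volume.
rng = np.random.default_rng(0)
vol = np.arange(60, dtype=float).reshape(3, 4, 5)
out = predict_in_label_frame(lambda v: 2 * v, vol, rng)
```

Any exactly invertible transform works here. Interpolating transforms such as arbitrary-angle rotations would be only approximately invertible on a voxel grid.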
Training and Evaluation. We use a 3D U-Net [17] with four layers as our backbone, together with Xavier initialization [10]. A diagram of the network architecture can be found in Fig. 2 in the supplementary material. The loss weight α is set to 0.5, as this empirically yields the best performance. Our experiments are averaged over 5-fold cross-validation with 80 training samples, 20 validation samples, and a fixed test set of 41 samples. The network initialization differs between folds but is kept consistent across experiments run on the same fold. This way, both data variance and initialization variance are accounted for through cross-validation. To measure the performance of our models, we use the Dice score, precision, recall, and mean surface distance (MSD). We also compute the skeleton recall as the percentage of ground truth skeleton voxels that are present in the prediction.
Table 1. Viewpoint ablation. We compare models trained on single random viewpoints (VPs) with (+D) or without (−D) depth against fixed viewpoint baselines without depth and full 3D supervision. We distinguish between model selection based on 2D annotations vs. 3D annotations on the validation set. The best-performing models for each model selection (2D vs. 3D) are highlighted in bold.
Experiment | Model Selection | Dice ↑ | Precision ↑ | Recall ↑ | Skeleton Recall ↑ | MSD ↓
3D | 3D | 92.18 ± 0.35 | 93.86 ± 0.81 | 90.64 ± 0.64 | 76.04 ± 4.51 | 1.15 ± 0.11
fixed 3VP | 3D | 92.02 ± 0.52 | 93.05 ± 0.61 | 91.13 ± 0.79 | 78.61 ± 1.52 | 1.13 ± 0.11
fixed 2VP | 3D | 91.29 ± 0.78 | 91.46 ± 2.13 | 91.37 ± 1.45 | 78.51 ± 2.78 | 1.13 ± 0.09
fixed 3VP | 2D | 90.78 ± 1.30 | 90.66 ± 1.30 | 91.18 ± 3.08 | 81.77 ± 2.13 | 1.16 ± 0.13
fixed 2VP | 2D | 90.22 ± 1.19 | 88.16 ± 2.86 | 92.74 ± 1.63 | 82.18 ± 2.47 | 1.14 ± 0.09
fixed 1VP | 2D | 60.76 ± 24.14 | 50.47 ± 23.21 | 92.52 ± 3.09 | 81.19 ± 2.39 | 2.96 ± 3.15
random 1VP−D | 2D | 91.29 ± 0.81 | 91.42 ± 0.92 | 91.45 ± 1.00 | 80.16 ± 2.35 | 1.13 ± 0.04
random 1VP+D | 2D | 91.69 ± 0.48 | 90.77 ± 1.76 | 92.79 ± 0.95 | 81.27 ± 2.02 | 1.15 ± 0.11
5 Results
The Effectiveness of 2D Projections and Depth Supervision. We compare training on single random viewpoints, with and without depth information, against baselines that use more supervision. Models trained on full 3D ground truth represent the upper-bound baseline, which is very expensive to annotate. We implement [13] as a baseline on our dataset, training on up to 3 fixed orthogonal projections. We distinguish between models selected according to the 2D performance on the validation set (2D), which is a fair baseline, and models selected according to the 3D performance on the validation set (3D). The latter is an unfair baseline, as it requires 3D annotations on the validation set. The single fixed viewpoint baselines are the exception, as their models tend to diverge towards over- or under-segmentation. For all other models we perform binary hole filling on the output, as producing hollow objects is a common under-segmentation issue. In Table 1 we compare our method against the 3D baseline, as well as baselines trained on multiple viewpoints. Training on a single random viewpoint per sample together with depth information performs almost at the level of models trained on 3D labels, at a very small fraction of the annotation cost. The depth information also reduces model variance compared to the same setup without depth information. Even without depth information, training the model on single randomly chosen viewpoints offers a robust training signal. Its Dice score is on par with training on 2 fixed viewpoints under ideal model selection, at only half the annotation cost. Randomly selecting viewpoints for training acts as powerful data augmentation, which is why we obtain performance comparable to using more fixed viewpoints. Under ideal 3D-based model selection, three views would come even closer to full 3D performance; however, with realistic 2D-based model selection, fixed viewpoints are more prone to diverge. This occurs because 2D-based model selection sometimes favors divergent models that only segment hollow objects, which cannot be fixed in postprocessing. Single fixed viewpoints contain so little information on their own that models trained on such input fail to learn how to segment the vessels. Instead, they generally converge to over-segmenting in the blind spots of the projections. We conclude that using random viewpoints not only reduces annotation cost but also decreases model variance. In terms of other metrics, randomly chosen projection viewpoints, with and without depth, improve both recall and skeleton recall even compared to fully 3D annotations, while generally reducing precision. We hypothesize that this is because the dataset itself contains noisy annotations. Fully supervised models fit more closely to this annotation style, whereas our models converge to following the contrast and segment more vessels, some of which are wrongly labeled as background in the ground truth.
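The NumPy sketch below computes the overlap-based metrics used in this comparison. It assumes binary prediction and ground-truth volumes plus a precomputed binary skeleton of the ground truth, obtained with any 3D skeletonization routine. MSD is omitted because it requires a surface-distance computation. The function names are hypothetical.

```python
import numpy as np

def _percent(num, den):
    """Ratio in percent; NaN when the denominator is zero (e.g., an empty prediction)."""
    return 100.0 * float(num) / float(den) if den > 0 else float("nan")

def segmentation_metrics(pred, gt, gt_skeleton):
    """Voxel-wise Dice, precision, recall and skeleton recall, all in percent."""
    pred, gt, skel = (np.asarray(a, dtype=bool) for a in (pred, gt, gt_skeleton))
    tp = np.logical_and(pred, gt).sum()
    return {
        "dice": _percent(2 * tp, pred.sum() + gt.sum()),
        "precision": _percent(tp, pred.sum()),
        "recall": _percent(tp, gt.sum()),
        # share of ground-truth skeleton voxels that the prediction covers
        "skeleton_recall": _percent(np.logical_and(pred, skel).sum(), skel.sum()),
    }

# Usage on flattened toy masks (shape does not matter, all metrics are voxel-wise).
gt = np.array([1, 1, 1, 1, 0, 0])
pred = np.array([1, 1, 0, 0, 1, 0])
skel = np.array([0, 1, 1, 0, 0, 0])
metrics = segmentation_metrics(pred, gt, skel)
# dice = 400/7 ≈ 57.14, precision ≈ 66.67, recall = 50.0, skeleton_recall = 50.0
```

Skeleton recall counts only centerline voxels. A prediction that traces additional thin branches therefore raises it while extra voxels outside the ground truth lower precision, which is the trade-off discussed above.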
MSD is not very informative on our dataset due to the noisy annotations and the nature of vessels, as an under- or over-segmented vessel branch can quickly translate into a large surface distance. The Effect of Dataset Size. We vary the size of the training set from $|\mathcal{D}_{tr}| = 80$ down to $|\mathcal{D}_{tr}| = 10$ samples, while keeping the sizes of the validation and test sets constant, and train models on single random viewpoints. In Table 2, we compare single random projections trained with and without depth information at varying dataset sizes to illustrate the usefulness of the depth information with different amounts of training data. Our depth loss offers consistent improvement across multiple dataset sizes and reduces the overall performance variance. The performance boost is noticeable across the board, the only exception being precision. The smaller the dataset, the greater the performance boost from the depth information. We perform a Wilcoxon rank-sum statistical test comparing the individual sample predictions of the models trained at various dataset sizes with single random orthogonal viewpoints, with or without depth information. The test shows a statistically significant difference (p < 0.0001). We conclude that the depth information complements the segmentation effectively.
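A Wilcoxon rank-sum test of this kind can be sketched with the large-sample normal approximation below. We assume it is applied to per-sample scores (e.g., per-volume Dice) of the models trained without and with depth; the text does not state which metric was compared. Ties get average ranks and the variance is not tie-corrected. To our understanding, this matches the statistic computed by SciPy's `scipy.stats.ranksums`. The function name is hypothetical.

```python
import math
import numpy as np

def wilcoxon_rank_sum(a, b):
    """Two-sided Wilcoxon rank-sum test (normal approximation). Returns (z, p)."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    n1, n2 = len(a), len(b)
    pooled = np.concatenate([a, b])
    order = np.argsort(pooled, kind="mergesort")
    sorted_vals = pooled[order]
    ranks = np.empty(len(pooled))
    i = 0
    while i < len(pooled):
        j = i
        while j + 1 < len(pooled) and sorted_vals[j + 1] == sorted_vals[i]:
            j += 1                                   # extend the group of tied values
        ranks[order[i:j + 1]] = (i + j) / 2 + 1      # average 1-based rank of the tie group
        i = j + 1
    w = ranks[:n1].sum()                             # rank sum of the first sample
    mean = n1 * (n1 + n2 + 1) / 2
    sd = math.sqrt(n1 * n2 * (n1 + n2 + 1) / 12)
    z = (w - mean) / sd
    p = math.erfc(abs(z) / math.sqrt(2))             # two-sided p-value from the normal tail
    return z, p

# Usage: a sample that is uniformly lower than the other gives a negative z.
z, p = wilcoxon_rank_sum([1, 2, 3], [4, 5, 6])
# z = -4.5 / sqrt(5.25) ≈ -1.964, p ≈ 0.0495
```

The normal approximation is reasonable for samples as large as the 41-volume test set pooled over folds. For very small samples, an exact test is preferable.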
Table 2. Dataset size ablation. We vary the training dataset size $|\mathcal{D}_{tr}|$ and compare models trained on single random viewpoints, with or without depth. Best performing models in each setting are highlighted.
Training size | Depth | Dice ↑ | Precision ↑ | Recall ↑ | Skeleton Recall ↑ | MSD ↓
10 | −D | 86.03 ± 2.94 | 88.23 ± 2.58 | 84.81 ± 6.42 | 78.25 ± 2.20 | 1.92 ± 0.55
10 | +D | 89.06 ± 1.20 | 88.55 ± 1.73 | 89.91 ± 1.29 | 78.95 ± 3.62 | 1.80 ± 0.28
20 | −D | 88.22 ± 3.89 | 90.26 ± 1.64 | 86.74 ± 6.56 | 80.78 ± 1.66 | 1.44 ± 0.20
20 | +D | 90.51 ± 0.38 | 89.84 ± 0.90 | 91.50 ± 1.23 | 80.00 ± 1.95 | 1.33 ± 0.16
40 | −D | 88.07 ± 2.34 | 89.09 ± 2.01 | 87.62 ± 4.43 | 78.38 ± 2.39 | 1.38 ± 0.10
40 | +D | 90.21 ± 0.89 | 89.08 ± 2.89 | 91.82 ± 2.11 | 79.16 ± 2.36 | 1.24 ± 0.14
80 | −D | 91.29 ± 0.81 | 91.42 ± 0.92 | 91.45 ± 1.00 | 80.16 ± 2.35 | 1.13 ± 0.04
80 | +D | 91.69 ± 0.48 | 90.77 ± 1.76 | 92.79 ± 0.95 | 81.27 ± 2.02 | 1.15 ± 0.11
6 Conclusion
In this work, we present an approach for 3D segmentation of peripancreatic arteries using very sparse 2D annotations. We train on a labeled dataset with a single, randomly selected, orthogonal 2D annotation per training sample, plus depth information obtained at no extra cost. With this, we obtain accuracy almost on par with fully supervised models trained on 3D data at a mere fraction of the annotation cost. One limitation of our work is that the depth information relies on the assumption that the vessels exhibit minimal intensity fluctuations within local neighborhoods. This assumption might not hold on other datasets, where more sophisticated ray-tracing methods would be more effective in locating the front and back of projected objects. Furthermore, our approach relies on careful preprocessing to eliminate occluders, which may limit its transferability to datasets with many occluding objects of similar intensity. Further investigation is needed to quantify how manual 2D annotations compare to our 3D-derived annotations, where we expect occluders to affect the annotation process.
References
1. Bai, W., et al.: Recurrent neural networks for aortic image sequence segmentation with sparse annotations. In: Frangi, A.F., Schnabel, J.A., Davatzikos, C., Alberola-López, C., Fichtinger, G. (eds.) MICCAI 2018. LNCS, vol. 11073, pp. 586–594. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-00937-3_67
2. Bayat, A.: Inferring the 3D standing spine posture from 2D radiographs. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12266, pp. 775–784. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59725-2_75
3. Chen, C., Chuah, J.H., Ali, R., Wang, Y.: Retinal vessel segmentation using deep learning: a review. IEEE Access 9, 111985–112004 (2021)
4. Chen, H., Wang, X., Wang, L.: 3D vessel segmentation with limited guidance of 2D structure-agnostic vessel annotations. arXiv preprint arXiv:2302.03299 (2023)
5. Ciecholewski, M., Kassjański, M.: Computational methods for liver vessel segmentation in medical imaging: A review. Sensors 21(6), 2027 (2021)
6. Dima, A., et al.: Segmentation of peripancreatic arteries in multispectral computed tomography imaging. In: Lian, C., Cao, X., Rekik, I., Xu, X., Yan, P. (eds.) MLMI 2021. LNCS, vol. 12966, pp. 596–605. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87589-3_61
7. Drebin, R.A., Carpenter, L., Hanrahan, P.: Volume rendering. ACM Siggraph Comput. Graphics 22(4), 65–74 (1988)
8. Frangi, A.F., Niessen, W.J., Vincken, K.L., Viergever, M.A.: Multiscale vessel enhancement filtering. In: Wells, W.M., Colchester, A., Delp, S. (eds.) MICCAI 1998. LNCS, vol. 1496, pp. 130–137. Springer, Heidelberg (1998). https://doi.org/10.1007/BFb0056195
9. Fu, X., et al.: Panoptic NeRF: 3D-to-2D label transfer for panoptic urban scene segmentation. In: International Conference on 3D Vision, 3DV 2022, Prague, Czech Republic, 12–16 September 2022, pp. 1–11. IEEE (2022)
10. He, K., Zhang, X., Ren, S., Sun, J.: Delving deep into rectifiers: surpassing human-level performance on imagenet classification. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 1026–1034 (2015)
11. Isensee, F., Jaeger, P.F., Kohl, S.A., Petersen, J., Maier-Hein, K.H.: nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nat. Methods 18(2), 203–211 (2021)
12. Jia, Z., Huang, X., Eric, I., Chang, C., Xu, Y.: Constrained deep weak supervision for histopathology image segmentation. IEEE Trans. Med. Imaging 36(11), 2376–2388 (2017)
13. Koziński, M., Mosinska, A., Salzmann, M., Fua, P.: Tracing in 2D to reduce the annotation effort for 3D deep delineation of linear structures. Med. Image Anal. 60, 101590 (2020)
14. Lawin, F.J., Danelljan, M., Tosteberg, P., Bhat, G., Khan, F.S., Felsberg, M.: Deep projective 3D semantic segmentation. In: Felsberg, M., Heyden, A., Krüger, N. (eds.) CAIP 2017. LNCS, vol. 10424, pp. 95–107. Springer, Cham (2017). https://doi.org/10.1007/978-3-319-64689-3_8
15. Löffler, M.T., et al.: A vertebral segmentation dataset with fracture grading. Radiol. Artif. Intell. 2(4), e190138 (2020)
16. Luboz, V., et al.: A segmentation and reconstruction technique for 3D vascular structures. In: Duncan, J.S., Gerig, G. (eds.) MICCAI 2005. LNCS, vol. 3749, pp. 43–50. Springer, Heidelberg (2005). https://doi.org/10.1007/11566465_6
17. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
18. Sekuboyina, A., et al.: VerSe: a vertebrae labelling and segmentation benchmark for multi-detector CT images. Med. Image Anal. 73, 102166 (2021)
19. Shi, F., et al.: Intracranial vessel wall segmentation using convolutional neural networks. IEEE Trans. Biomed. Eng. 66(10), 2840–2847 (2019)
20. Tajbakhsh, N., Jeyaseelan, L., Li, Q., Chiang, J.N., Wu, Z., Ding, X.: Embracing imperfect datasets: A review of deep learning solutions for medical image segmentation. Med. Image Anal. 63, 101693 (2020)
3D Arterial Segmentation via Single 2D Projections and Depth Supervision
21. Tetteh, G., et al.: DeepVesselNet: vessel segmentation, centerline prediction, and bifurcation detection in 3-D angiographic volumes. Front. Neurosci., 1285 (2020)
22. Zhou, Y., et al.: Semi-supervised 3D abdominal multi-organ segmentation via deep multi-planar co-training. In: 2019 IEEE Winter Conference on Applications of Computer Vision (WACV), pp. 121–140. IEEE (2019)
Automatic Retrieval of Corresponding US Views in Longitudinal Examinations

Hamideh Kerdegari1(B), Nhat Phung Tran Huy1,3, Van Hao Nguyen2, Thi Phuong Thao Truong2, Ngoc Minh Thu Le2, Thanh Phuong Le2, Thi Mai Thao Le2, Luigi Pisani4, Linda Denehy5, Reza Razavi1, Louise Thwaites3, Sophie Yacoub3, Andrew P. King1, and Alberto Gomez1

1 School of Biomedical Engineering and Imaging Sciences, King's College London, London, UK
2 Hospital for Tropical Diseases, Ho Chi Minh City, Vietnam
3 Oxford University Clinical Research Unit, Ho Chi Minh City, Vietnam
4 Mahidol Oxford Tropical Medicine Research Unit, Bangkok, Thailand
5 Melbourne School of Health Sciences, The University of Melbourne, Melbourne, Australia
Abstract. Skeletal muscle atrophy is a common occurrence in critically ill patients in the intensive care unit (ICU) who spend long periods in bed. Muscle mass must be recovered through physiotherapy before patient discharge, and ultrasound imaging is frequently used to assess the recovery process by measuring the muscle size over time. However, these manual measurements are subject to large variability, particularly since the scans are typically acquired on different days and potentially by different operators. In this paper, we propose a self-supervised contrastive learning approach to automatically retrieve similar ultrasound muscle views at different scan times. Three different models were compared using data from 67 patients acquired in the ICU. Results indicate that our contrastive model outperformed a supervised baseline model in the task of view retrieval, with an AUC of 73.52%, and, when combined with an automatic segmentation model, achieved a 5.7% ± 0.24% error in cross-sectional area. Furthermore, a user study survey confirmed the efficacy of our model for muscle view retrieval.

Keywords: Muscle atrophy · Ultrasound view retrieval · Self-supervised contrastive learning · Classification
H. Kerdegari: This work was supported by the Wellcome Trust UK (110179/Z/15/Z, 203905/Z/16/Z, WT203148/Z/16/Z). H. Kerdegari, N. Phung, R. Razavi, A. P. King and A. Gomez acknowledge financial support from the Department of Health via the National Institute for Health Research (NIHR) comprehensive Biomedical Research Centre award to Guy's and St Thomas' NHS Foundation Trust in partnership with King's College London and King's College Hospital NHS Foundation Trust.
VITAL Consortium: Membership of the VITAL Consortium is provided in the Acknowledgments.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 152–161, 2023.
https://doi.org/10.1007/978-3-031-43907-0_15
1 Introduction
Muscle wasting, also known as muscle atrophy (see Fig. 1), is a common complication in critically ill patients, especially in those who have been hospitalized in the intensive care unit (ICU) for a long period [17]. Factors contributing to muscle wasting in ICU patients include immobilization, malnutrition, inflammation, and the use of certain medications [13]. Muscle wasting can result in weakness, impaired mobility, and increased morbidity and mortality. Assessing the degree of muscle wasting in ICU patients is essential for monitoring their progress and tailoring their rehabilitation program to recover muscular mass through physiotherapy before patient discharge. Traditional methods of assessing muscle wasting, such as physical examination, bioelectrical impedance analysis, and dual-energy X-ray absorptiometry, may be limited in ICUs due to the critical illness of patients [15]. Instead, ultrasound (US) imaging has emerged as a reliable, non-invasive, portable tool for assessing muscle wasting in the ICU [11].
Fig. 1. Example of the cross-section of the rectus femoris (RF) on one ICU patient showing muscle mass reduction from admission (9 cm², left) to discharge (6 cm², right).
The accuracy and reliability of US imaging in assessing muscle wasting in ICU patients have been demonstrated by Parry et al. [12]. US imaging can provide accurate measurements of muscle size, thickness, and architecture, allowing clinicians to track changes over time. However, these measurements are typically performed manually, which is time-consuming, subject to large variability, and dependent on the expertise of the operator. Furthermore, the operator may differ from day to day and may start scanning from a different position in each scan, which causes further variability. In recent years, self-supervised learning (SSL) has gained popularity for automated diagnosis in medical imaging due to its ability to learn from unlabeled data [1,6,8,16]. Previous studies on SSL for medical imaging have focused on designing pretext tasks [2,9,10,18]. Contrastive learning (CL), a class of SSL, learns feature representations via a contrastive loss function that distinguishes between positive and negative image samples. Relatively few works have applied CL to US imaging, for example to synchronize different cross-sectional views [7] and to perform view classification [4] in echocardiography (cardiac US).
In this paper, we focus on the under-investigated application of view matching for longitudinal RF muscle US examinations to assess muscle wasting. Our method uses a CL approach (see Fig. 2) to learn a discriminative representation from muscle US data, which facilitates the retrieval of similar muscle views from different scans.
Fig. 2. Proposed architecture for US view retrieval. (a): Overview, a shared encoder and a projection head (two dense layers, each 512 nodes). (b): Encoder subnetwork. (c): The classification subnetwork has four dense layers of 2024, 1024, 512 and 2 features.
The novel contributions of this paper are: 1) the first investigation of the problem of muscle US view matching for longitudinal image analysis, and 2) an approach that automatically retrieves similar muscle views between different scans, validated quantitatively and qualitatively through a clinical survey.
2 Method

2.1 Problem Formulation
Muscle wasting assessment requires matching corresponding cross-sectional US views of the RF over subsequent examinations (days to weeks apart). The first acquisition follows a protocol that places the transducer halfway along the thigh and perpendicular to the skin, but small deviations in translation and angulation from this standard view are common. This scan produces the reference view at time T1 (R_T1). The problem is as follows: given R_T1, retrieve the corresponding view (V_T2) at a later time T2 from a sequence of US images captured by the operator with the transducer at approximately the same location and angle as at T1. The main challenges of this problem are that (1) the transducer pose and angle might differ, (2) the machine settings might be slightly different, and (3) parts of the anatomy (specifically the RF) might change in shape and size over time. Our aim is therefore to develop a model that selects, among the frames acquired at T2, the view most similar to the reference view R_T1 acquired at T1.

2.2 Contrastive Learning Framework for Muscle View Matching
Inspired by the SimCLR algorithm [5], our model learns representations by maximizing the similarity between two different augmented views of the same muscle US image via a contrastive loss in the latent space. We randomly sample a minibatch of N images from the video sequences acquired at the three times T1, T2 and T3, and define the contrastive learning task on positive pairs (Xi, Xj) of augmented images derived from the minibatch, resulting in 2N samples. Rather than sampling negative examples explicitly, given a positive pair we treat the other 2(N − 1) augmented images within the minibatch as negatives. The contrastive loss function for a positive pair (Xi, Xj) is defined as:

$$\mathcal{L}_C^{i} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)} \qquad (1)$$

where $\mathbb{1}_{[k \neq i]} \in \{0, 1\}$ is an indicator function equal to 1 if and only if $k \neq i$, $\tau$ is a temperature parameter and sim(·) denotes the pairwise cosine similarity. z is a representation vector, calculated as z = g(f(X)), where f(·) is a shared encoder and g(·) is a projection head. $\mathcal{L}_C^{i}$ is computed across all positive pairs in a minibatch, and f(·) and g(·) are trained to maximize similarity using this contrastive loss.

2.3 The Model Architecture
The model architecture is shown in Fig. 2a. First, we train the contrastive model to identify the similarity between two images, which form a pair of augmentations created by horizontal flipping and random cropping (size 10 × 10) of the same US image (i.e., they represent different versions of the same image). Each image of this pair (Xi, Xj) is fed into an encoder that extracts the representation vectors (hi, hj). The encoder (Fig. 2b) has four convolutional layers (kernel 3 × 3) with ReLU activations and two max-pooling layers. A projection head (a multilayer perceptron with two dense layers of 512 nodes) then maps these representations to the space where the contrastive loss is applied. Second, we use the trained encoder f(·) to train our main (downstream) task: the classification of positive and negative matches (corresponding and non-corresponding views) in our test set. For this, we feed a reference image Xref and a candidate frame Xj to the encoder to obtain the representations hi and hj, and feed these in turn to a classification network (Fig. 2c) that contains four dense layers with ReLU activation and a softmax layer.
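The contrastive loss of Eq. (1) is the normalized temperature-scaled cross-entropy (NT-Xent) loss of SimCLR [5]. The pure-Python sketch below evaluates it for 2N projections z = g(f(X)), averaging the per-sample term over all 2N positions as SimCLR does. It assumes that z[2m] and z[2m + 1] are the two augmented views of image m; the function names and the temperature value are illustrative placeholders, since τ is not reported here, and a real implementation would operate on batched TensorFlow tensors.

```python
import math
from typing import List, Sequence


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    """sim(u, v): cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def nt_xent_loss(z: List[Sequence[float]], tau: float = 0.5) -> float:
    """Mean of Eq. (1) over all 2N projections (sketch, not the authors' code).

    Assumes z[2m] and z[2m + 1] are the projections of the two augmented
    views of image m; every other entry in the minibatch acts as a negative.
    """
    n2 = len(z)
    if n2 < 2 or n2 % 2:
        raise ValueError("expected an even number (2N >= 2) of projections")
    total = 0.0
    for i in range(n2):
        j = i ^ 1  # positive partner: 0<->1, 2<->3, ...
        numerator = math.exp(cosine_similarity(z[i], z[j]) / tau)
        # Denominator: sum over k = 1..2N weighted by the indicator 1[k != i].
        denominator = sum(
            math.exp(cosine_similarity(z[i], z[k]) / tau)
            for k in range(n2)
            if k != i
        )
        total += -math.log(numerator / denominator)
    return total / n2


# Two images with two augmented views each: aligned views give a lower loss.
aligned = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
shuffled = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
print(nt_xent_loss(aligned) < nt_xent_loss(shuffled))  # True
```

In the aligned case each positive partner has similarity 1 and the two negatives similarity 0, so with τ = 0.5 every term equals log(1 + 2e⁻²); swapping partners raises the loss because the numerator now holds an orthogonal pair.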
3 Materials
The muscle US exams were performed by five different doctors using GE Venue Go and GE Vivid IQ machines, both with linear probes (4.2–13.0 MHz). During the examination, patients lay supine with the legs in neutral rotation, the muscles relaxed and in passive extension. Measurements were taken at the point three fifths of the way from the anterior superior iliac spine to the upper pole of the patella. The transducer was placed perpendicular to the skin and to the longitudinal axis of the thigh to obtain the cross-sectional area of the RF. An excess of US gel was used and pressure on the skin was kept minimal to maximise image quality. US measurements were taken at ICU admission (T1), 2–7 days after admission (T2) and at ICU discharge (T3). For this study, 67 Central Nervous System (CNS) and tetanus patients were recruited, and their data were acquired between June 2020 and February 2022. Each patient had an average of six muscle ultrasound examinations (three scans of each leg), totalling 402 examinations. The video resolution was 1080 × 1920 with a frame rate of 30 fps. This study was performed in line with the principles of the Declaration of Helsinki. Approval was granted by the Ethics Committee of the Hospital for Tropical Diseases, Ho Chi Minh City, and the Oxford Tropical Research Ethics Committee.
Fig. 3. An example of positive and negative pair labeling for US videos acquired at T1 and T2. Positive pairs are either the three views acquired consecutively at the same Ti, or a view labeled at T1 and the corresponding view on the same leg at T2 or T3.
The contrastive learning network was trained without any annotations. For the view matching classification task, however, our test data were automatically annotated as positive and negative pairs based on manual frame selection by a team of five doctors (three radiologists and two ultrasound specialists with expertise in muscle ultrasound). Specifically, each frame in an examination was manually labelled as containing a view similar to the reference R_T1 or not. Based on these labels, as shown in Fig. 3, the positive pairs are combinations of similar views within each examination (T1/T2/T3) and between examinations. All remaining pairs are considered negative.
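The pair construction from per-frame labels can be sketched as follows. This is a minimal illustration, assuming frames are grouped per patient and leg (Fig. 3 pairs views on the same leg) and reading "all remaining pairs" literally as every other frame pair; the `label_pairs` function and the (examination, frame index, label) tuple layout are hypothetical, not taken from the authors' code.

```python
from itertools import combinations
from typing import List, Sequence, Tuple

FrameKey = Tuple[str, int]  # hypothetical key: (examination, frame index)
Pair = Tuple[FrameKey, FrameKey]


def label_pairs(
    frames: Sequence[Tuple[str, int, bool]]
) -> Tuple[List[Pair], List[Pair]]:
    """Split all frame pairs of one patient's leg into positive and negative pairs.

    Each entry of `frames` is (examination, frame_index, is_match), where
    is_match is the manual label "contains a view similar to R_T1".
    Pairs of two matching frames (within or across T1/T2/T3) are positive;
    all remaining pairs are negative.
    """
    matched = {(exam, idx) for exam, idx, is_match in frames if is_match}
    keys = [(exam, idx) for exam, idx, _ in frames]
    positives: List[Pair] = []
    negatives: List[Pair] = []
    for a, b in combinations(keys, 2):
        target = positives if (a in matched and b in matched) else negatives
        target.append((a, b))
    return positives, negatives


frames = [("T1", 0, True), ("T1", 1, False), ("T2", 0, True), ("T2", 1, True)]
pos, neg = label_pairs(frames)
print(len(pos), len(neg))  # 3 3
```

Because the number of negative pairs grows much faster than the number of positives as videos get longer, a practical pipeline would likely subsample negatives; the paper does not say whether it did.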
4 Experiments and Results

4.1 Implementation Details
Our model was implemented using TensorFlow 2.7. During training, we experimented with input clip sizes of 256 × 256, 128 × 128 and 64 × 64, and finally resized the videos to 64 × 64, which yielded the best performance. All hyperparameters were chosen using the validation set. For the CL training, the standard Adam optimizer was used with learning rate = 0.00001, kernel size = 3 × 3, batch size = 128, batch normalization, dropout with p = 0.2 and L2 regularization of the model parameters with a weight of 0.00001. The CL model was trained on 80% of the muscle US data for 500 epochs. For the view retrieval model, the standard Adam optimizer was used with learning rate = 0.0001, batch size = 42 and dropout with p = 0.2. The classifier was trained on the remaining 20% of the data (of which 80% were used for training, 10% for validation and 10% for testing) and the network converged after 60 epochs. For the supervised baseline model, the standard Adam optimizer was used with learning rate = 0.00001, kernel size = 3 × 3, batch size = 40 and batch normalization. Here, we used the same data split as for our view retrieval classifier. The code we used to train and evaluate our models is available at https://github.com/hamidehkerdegari/Muscle-view-retrieval.

4.2 Results
Quantitative Results. We carried out two quantitative experiments. First, we evaluated the performance of the view classifier. Second, we evaluated the quality of the resulting cross-sectional areas segmented using a U-Net [14]. For the view retrieval task, classifier performance was measured with the following metrics: Area Under the Curve (AUC), precision, recall and F1-score. Because no existing method addresses this task, we created two baseline models to compare our proposed model to: first, a naive image-space comparison using normalized cross-correlation (NCC) [3], and second, a supervised classifier. The supervised classifier has the same architecture as our CL model, but the outputs of the two networks are concatenated after the representation h and followed by a dense layer with two nodes and a softmax activation function, which produces the probabilities of being a positive or a negative pair. Table 1 shows the classification results on our dataset.

Table 1. AUC, precision, recall and F1 score results on the muscle video dataset.

Model                          AUC      Precision  Recall   F1
Normalized cross-correlation   68.35%   58.65%     63.12%   60.8%
Supervised baseline model      69.87%   65.81%     60.57%   63.08%
Proposed model                 73.52%   67.2%      68.31%   67.74%
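As a consistency check on Table 1, F1 is the harmonic mean of precision and recall, F1 = 2PR/(P + R). The short sketch below recomputes it from the reported precision and recall; the recomputed values agree with the reported F1 scores to within rounding (the largest gap, for the proposed model, is about 0.01 percentage points). The dictionary simply copies the Table 1 values.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall (same units as the inputs)."""
    return 2 * precision * recall / (precision + recall)


# (precision %, recall %, reported F1 %) copied from Table 1.
table1 = {
    "Normalized cross-correlation": (58.65, 63.12, 60.8),
    "Supervised baseline model": (65.81, 60.57, 63.08),
    "Proposed model": (67.2, 68.31, 67.74),
}

for model, (p, r, reported) in table1.items():
    print(f"{model}: recomputed F1 = {f1_score(p, r):.2f}%, reported = {reported}%")
```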
As shown in Table 1, our proposed method outperformed all other models in terms of AUC, precision, recall and F1-score. The NCC method performed worst, as a direct intensity comparison cannot capture the dynamic changes and deformations in US images, which can result in significant structural differences. A representative example of a model-retrieved view for one case is presented in Fig. 4. It shows positive, negative and middle pairs (i.e., images with a probability value between the highest and lowest values predicted by our model) retrieved by our model from a patient's left leg. As a reference, the user's pick (R_T2) is shown on the left.
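The NCC baseline [3] compares raw intensities. The exact variant used is not specified, so the sketch below assumes the common zero-mean, global NCC between two equally sized frames flattened to 1-D; `retrieve_by_ncc` is a hypothetical helper that picks the candidate frame with the highest NCC to the reference.

```python
import math
from typing import Sequence


def ncc(a: Sequence[float], b: Sequence[float]) -> float:
    """Zero-mean normalized cross-correlation of two equally sized frames.

    Frames are flattened intensity sequences. The result lies in [-1, 1] and
    is unchanged by a positive gain or an intensity offset, but nothing in it
    compensates for spatial deformation of the anatomy.
    """
    if len(a) != len(b) or not a:
        raise ValueError("frames must be non-empty and the same size")
    mean_a = sum(a) / len(a)
    mean_b = sum(b) / len(b)
    da = [x - mean_a for x in a]
    db = [y - mean_b for y in b]
    denom = math.sqrt(sum(x * x for x in da) * sum(y * y for y in db))
    if denom == 0:
        return 0.0  # a constant frame carries no correlation information
    return sum(x * y for x, y in zip(da, db)) / denom


def retrieve_by_ncc(reference: Sequence[float], candidates: Sequence[Sequence[float]]) -> int:
    """Baseline retrieval: index of the candidate frame with the highest NCC."""
    if not candidates:
        raise ValueError("no candidate frames")
    return max(range(len(candidates)), key=lambda i: ncc(reference, candidates[i]))
```

The invariance to gain and offset is why NCC is a reasonable first baseline under slightly different machine settings, while its sensitivity to pixel-wise misalignment is consistent with the weak results it obtains in Table 1.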
Fig. 4. Results showing three sample positive, medium and negative predicted pairs by our model when ground truth (GT) from T1 is compared with the T2 video.
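Fig. 4 orders candidate T2 frames by the model's predicted match probability. How those per-frame outputs become the single retrieved view is not spelled out, so the sketch below assumes the frame with the highest positive-match probability is retrieved. `match_prob` is a hypothetical stand-in for the trained encoder f(·) plus the classification subnetwork, and the toy scorer in the usage lines is for illustration only.

```python
from typing import Callable, List, Sequence, Tuple

Frame = Sequence[float]  # hypothetical flattened frame representation


def rank_candidates(
    reference: Frame,
    candidates: Sequence[Frame],
    match_prob: Callable[[Frame, Frame], float],
) -> List[Tuple[int, float]]:
    """Return (frame index, match probability), most to least likely match."""
    scored = [(i, match_prob(reference, c)) for i, c in enumerate(candidates)]
    return sorted(scored, key=lambda item: item[1], reverse=True)


def retrieve_view(
    reference: Frame,
    candidates: Sequence[Frame],
    match_prob: Callable[[Frame, Frame], float],
) -> int:
    """Index of the candidate frame with the highest predicted match probability."""
    if not candidates:
        raise ValueError("no candidate frames")
    return rank_candidates(reference, candidates, match_prob)[0][0]


def toy_prob(a: Frame, b: Frame) -> float:
    """Toy scorer (NOT the paper's classifier): closer frames score higher."""
    return 1.0 / (1.0 + sum(abs(x - y) for x, y in zip(a, b)))


ref = [0.2, 0.4, 0.6]
frames = [[0.9, 0.9, 0.9], [0.2, 0.5, 0.6], [0.0, 0.0, 0.0]]
ranking = rank_candidates(ref, frames, toy_prob)
print(ranking[0][0], ranking[-1][0])  # 1 0
```

The first and last entries of the ranking correspond to the positive and negative examples shown in Fig. 4, and entries in between to the middle pairs.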
To assess the quality of the resulting cross-sections, we calculated, for each examination, the relative absolute area difference d between the cross-sectional area in the ground-truth frame (a_GT) and that in the model-predicted frame (a_pred), and report its mean:

$$d = \frac{\lvert a_{GT} - a_{pred} \rvert}{a_{GT}} \qquad (2)$$
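Eq. (2) and its mean over examinations can be computed directly once the areas are known. This minimal sketch assumes the areas have already been measured from the segmentation masks (the U-Net step is not reproduced), and the area values in the usage line are hypothetical, not data from the paper.

```python
from typing import Iterable, Tuple


def relative_area_error(a_gt: float, a_pred: float) -> float:
    """Eq. (2): relative absolute difference between ground-truth and predicted areas."""
    if a_gt <= 0:
        raise ValueError("ground-truth area must be positive")
    return abs(a_gt - a_pred) / a_gt


def mean_relative_area_error(pairs: Iterable[Tuple[float, float]]) -> float:
    """Mean of d over examinations; each pair is (a_GT, a_pred)."""
    errors = [relative_area_error(gt, pred) for gt, pred in pairs]
    if not errors:
        raise ValueError("no examinations given")
    return sum(errors) / len(errors)


# Hypothetical areas in cm^2: a 6.3 cm^2 prediction for a 6.0 cm^2 ground truth.
print(f"{relative_area_error(6.0, 6.3):.1%}")  # 5.0%
```

Normalising by a_GT makes the error comparable across patients whose RF areas differ substantially, for example the 9 cm² and 6 cm² cross-sections in Fig. 1.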
To obtain the areas, we applied a U-Net model previously trained on 1000 different US muscle images with manual segmentations. The overall mean relative absolute area error was 5.7% ± 0.24% on the test set (full details in Fig. 5, right). To put this number into context, Fig. 5, left visualizes two cases where the relative error is 2.1% and 5.2%.

Qualitative Results. We conducted a user study survey to qualitatively assess our model's performance. The survey was conducted blindly and independently by four clinicians and consisted of thirty questions. In each, clinicians were shown two different series of three views of the RF: (1) R_T1, the GT match from T2 and the model prediction from T2; and (2) R_T1, a random frame from T2 and the model prediction from T2. They were asked to indicate which image (the second or the third) best matched the first. The first series aimed to determine whether the model's performance was on par with clinicians, while the second aimed to determine whether the model's selection was superior to a randomly picked frame. As shown in Fig. 6, left, clinicians chose the model prediction more often than the GT; however, this difference was not significant (paired Student's t-test, p = 0.44, significance level 0.05). These results indicate that our model can retrieve the view as well as clinicians, and significantly better (Fig. 6, right) than randomly chosen frames (paired Student's t-test, p = 0.02, significance level 0.05).
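Both survey comparisons use a paired Student's t-test. How responses were paired (for example per clinician or per question) is not stated, so the sketch below only computes the test statistic for generic matched samples x and y; `paired_t_statistic` is a hypothetical helper. Turning it into a p-value additionally needs the t distribution with n − 1 degrees of freedom, which `scipy.stats.ttest_rel` provides if SciPy is available.

```python
import math
import statistics
from typing import Sequence


def paired_t_statistic(x: Sequence[float], y: Sequence[float]) -> float:
    """Paired Student's t statistic for matched samples (df = n - 1).

    t = mean(d) / (s_d / sqrt(n)), where d are the pairwise differences and
    s_d is their sample standard deviation (n - 1 denominator).
    """
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two equally long samples with at least two pairs")
    diffs = [a - b for a, b in zip(x, y)]
    sd = statistics.stdev(diffs)
    if sd == 0:
        raise ValueError("differences have zero variance")
    return statistics.mean(diffs) / (sd / math.sqrt(len(diffs)))
```

Pairing matters here: each clinician (or question) contributes one difference, so consistent per-rater preferences are not diluted by between-rater variability, as they would be in an unpaired test.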
Fig. 5. Left: cross-sectional area error for T1 and T2 examinations (acquisition times). Right: mean relative absolute area difference (d) for T1–T1, T1–T2 and T1–T3 (reference frame from T1 and corresponding predicted frames from T1, T2 and T3, respectively) and over all acquisition times.
5 Discussion and Conclusion
This paper has presented a self-supervised CL approach for automatic muscle US view retrieval in ICU patients. We trained a classifier to find positive and negative matches. To evaluate model performance, we also computed the cross-sectional area error between the ground-truth frame and the model prediction at each acquisition time. Evaluated on our muscle US video dataset, the model achieved an AUC of 73.52% and a 5.7% ± 0.24% error in cross-sectional area, and it outperformed the supervised baseline approach. This is the first work to identify corresponding ultrasound views over time, addressing an unmet clinical need.

Fig. 6. User study survey results. Left: when T1 GT, T2 GT and the model prediction (from T2) are shown to the users. Right: when T1 GT, a random T2 frame and the model prediction (from T2) are shown to the users.

Acknowledgments. The VITAL Consortium: OUCRU: Dang Phuong Thao, Dang Trung Kien, Doan Bui Xuan Thy, Dong Huu Khanh Trinh, Du Hong Duc, Ronald Geskus, Ho Bich Hai, Ho Quang Chanh, Ho Van Hien, Huynh Trung Trieu, Evelyne Kestelyn, Lam Minh Yen, Le Dinh Van Khoa, Le Thanh Phuong, Le Thuy Thuy Khanh, Luu Hoai Bao Tran, Luu Phuoc An, Nguyen Lam Vuong, Ngan Nguyen Lyle, Nguyen Quang Huy, Nguyen Than Ha Quyen, Nguyen Thanh Ngoc, Nguyen Thi Giang, Nguyen Thi Diem Trinh, Nguyen Thi Kim Anh, Nguyen Thi Le Thanh, Nguyen Thi Phuong Dung, Nguyen Thi Phuong Thao, Ninh Thi Thanh Van, Pham Tieu Kieu, Phan Nguyen Quoc Khanh, Phung Khanh Lam, Phung Tran Huy Nhat, Guy Thwaites, Louise Thwaites, Tran Minh Duc, Trinh Manh Hung, Hugo Turner, Jennifer Ilo Van Nuil, Vo Tan Hoang, Vu Ngo Thanh Huyen, Sophie Yacoub. Hospital for Tropical Diseases, Ho Chi Minh City: Cao Thi Tam, Ha Thi Hai Duong, Ho Dang Trung Nghia, Le Buu Chau, Le Mau Toan, Nguyen Hoan Phu, Nguyen Quoc Viet, Nguyen Thanh Dung, Nguyen Thanh Nguyen, Nguyen Thanh Phong, Nguyen Thi Cam Huong, Nguyen Van Hao, Nguyen Van Thanh Duoc, Pham Kieu Nguyet Oanh, Phan Thi Hong Van, Phan Vinh Tho, Truong Thi Phuong Thao. University of Oxford: Natasha Ali, James Anibal, David Clifton, Mike English, Ping Lu, Jacob McKnight, Chris Paton, Tingting Zhu. Imperial College London: Pantelis Georgiou, Bernard Hernandez Perez, Kerri Hill-Cawthorne, Alison Holmes, Stefan Karolcik, Damien Ming, Nicolas Moser, Jesus Rodriguez Manzano. King's College London: Liane Canas, Alberto Gomez, Hamideh Kerdegari, Andrew King, Marc Modat, Reza Razavi. University of Ulm: Walter Karlen. Melbourne University: Linda Denehy, Thomas Rollinson. Mahidol Oxford Tropical Medicine Research Unit (MORU): Luigi Pisani, Marcus Schultz.
References
1. Azizi, S., et al.: Big self-supervised models advance medical image classification. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 3478–3488 (2021)
2. Bai, W.: Self-supervised learning for cardiac MR image segmentation by anatomical position prediction. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 541–549. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8_60
3. Bourke, P.: Cross correlation. Cross Correlation, Auto Correlation-2D Pattern Identification (1996)
4. Chartsias, A., et al.: Contrastive learning for view classification of echocardiograms. In: Simplifying Medical Ultrasound: Second International Workshop, ASMUS 2021, Held in Conjunction with MICCAI 2021, Strasbourg, France, 27 September 2021, Proceedings 2, pp. 149–158. Springer (2021). https://doi.org/10.1007/978-3-031-16440-8_33
5. Chen, T., Kornblith, S., Norouzi, M., Hinton, G.: A simple framework for contrastive learning of visual representations. In: International Conference on Machine Learning, pp. 1597–1607. PMLR (2020)
6. Chen, Y., et al.: USCL: pretraining deep ultrasound image diagnosis model through video contrastive representation learning. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12908, pp. 627–637. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87237-3_60
7. Dezaki, F.T.: Echo-SyncNet: self-supervised cardiac view synchronization in echocardiography. IEEE Trans. Med. Imaging 40(8), 2092–2104 (2021)
8. Hosseinzadeh Taher, M.R., Haghighi, F., Feng, R., Gotway, M.B., Liang, J.: A systematic benchmarking analysis of transfer learning for medical image analysis. In: Albarqouni, S., et al. (eds.) DART/FAIR 2021. LNCS, vol. 12968, pp. 3–13. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87722-4_1
9. Hu, S.Y., et al.: Self-supervised pretraining with DICOM metadata in ultrasound imaging. In: Machine Learning for Healthcare Conference, pp. 732–749 (2020)
10. Jiao, J., Droste, R., Drukker, L., Papageorghiou, A.T., Noble, J.A.: Self-supervised representation learning for ultrasound video. In: 2020 IEEE 17th International Symposium on Biomedical Imaging (ISBI), pp. 1847–1850. IEEE (2020)
11. Mourtzakis, M., Wischmeyer, P.: Bedside ultrasound measurement of skeletal muscle. Current Opinion Clinical Nutrition Metabolic Care 17(5), 389–395 (2014)
12. Parry, S.M., et al.: Ultrasonography in the intensive care setting can be used to detect changes in the quality and quantity of muscle and is related to muscle strength and function. J. Crit. Care 30(5), 1151-e9 (2015)
13. Puthucheary, Z.A., et al.: Acute skeletal muscle wasting in critical illness. JAMA 310(15), 1591–1600 (2013)
14. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
15. Schefold, J.C., Wollersheim, T., Grunow, J.J., Luedi, M.M., Z'Graggen, W.J., Weber-Carstens, S.: Muscular weakness and muscle wasting in the critically ill. J. Cachexia Sarcopenia Muscle 11(6), 1399–1412 (2020)
16. Sowrirajan, H., Yang, J., Ng, A.Y., Rajpurkar, P.: MoCo pretraining improves representation and transferability of chest X-ray models. In: Medical Imaging with Deep Learning, pp. 728–744. PMLR (2021)
17. Trung, T.N., et al.: Functional outcome and muscle wasting in adults with tetanus. Trans. R. Soc. Trop. Med. Hyg. 113(11), 706–713 (2019)
18. Zhuang, X., Li, Y., Hu, Y., Ma, K., Yang, Y., Zheng, Y.: Self-supervised feature learning for 3D medical images by playing a Rubik's cube. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11767, pp. 420–428. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32251-9_46
Many Tasks Make Light Work: Learning to Localise Medical Anomalies from Multiple Synthetic Tasks

Matthew Baugh1(B), Jeremy Tan2, Johanna P. Müller3, Mischa Dombrowski3, James Batten1, and Bernhard Kainz1,3

1 Imperial College London, London, United Kingdom
2 ETH Zurich, Zurich, Switzerland
3 Friedrich–Alexander University Erlangen–Nürnberg, Erlangen, Germany

Abstract. There is a growing interest in single-class modelling and out-of-distribution detection, as fully supervised machine learning models cannot reliably identify classes not included in their training. The long tail of infinitely many out-of-distribution classes in real-world scenarios, e.g., for screening, triage, and quality control, means that it is often necessary to train single-class models that represent an expected feature distribution, e.g., from only strictly healthy volunteer data. Conventional supervised machine learning would require the collection of datasets that contain enough samples of all possible diseases in every imaging modality, which is not realistic. Self-supervised learning methods with synthetic anomalies are currently amongst the most promising approaches, alongside generative auto-encoders that analyse the residual reconstruction error. However, all methods suffer from a lack of structured validation, which makes calibration for deployment difficult and dataset-dependent. Our method alleviates this by making use of multiple visually distinct synthetic anomaly learning tasks for both training and validation. This enables more robust training and generalisation. With our approach we can readily outperform state-of-the-art methods, which we demonstrate on exemplars in brain MRI and chest X-rays. Code is available at https://github.com/matt-baugh/many-tasks-make-light-work.
1 Introduction
In recent years, the workload of radiologists has grown drastically, quadrupling from 2006 to 2020 in Western Europe [4]. This huge increase in pressure has led to long patient waiting times and to fatigued radiologists who make more mistakes [3]. The most common of these errors is underreading, i.e. missing anomalies (42%), followed by missing additional anomalies after concluding the search at an initial finding (22%) [10]. Interestingly, despite the challenging work environment, only 9% of the errors reviewed in [10] were due to mistakes in the clinicians' reasoning. Therefore, there is a need for automated second-reader capabilities that bring any kind of anomaly to the attention of radiologists. For such a tool to be useful, its ability to detect rare or unusual cases is particularly important. Traditional supervised models would not be appropriate, as acquiring sufficient training data to identify such a broad range of pathologies is not feasible. Unsupervised or self-supervised methods that model an expected feature distribution, e.g., of healthy tissue, are therefore a more natural path, as they are geared towards identifying any deviation from the normal distribution of samples rather than a particular type of pathology.

There has been rising interest in using end-to-end self-supervised methods for anomaly detection. Their success is most evident at the MICCAI Medical Out-of-Distribution Analysis (MOOD) Challenge [31], where all winning methods so far (2020–2022) have followed this paradigm. These methods use the variation within normal samples to generate diverse anomalies through sample mixing [7,23–25]. However, all these methods lack a key component: structured validation. This creates uncertainty around the choice of hyperparameters for training. For example, selecting the right training duration is crucial to avoid overfitting to proxy tasks. Yet, in practice, training time is often chosen arbitrarily, reducing reproducibility and potentially sacrificing generalisation to real anomalies.

Contribution: We propose a cross-validation framework that uses separate self-supervision tasks to minimise overfitting on the synthetic anomalies used for training. To make this work effectively, we introduce a number of non-trivial and seamlessly integrated synthetic tasks, each with a distinct feature set, so that during validation they can be used to approximate generalisation to unseen, real-world anomalies.

Supplementary Information: The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_16.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 162–172, 2023.
https://doi.org/10.1007/978-3-031-43907-0_16
To the best of our knowledge, this is the first work to train models to directly identify anomalies on tasks that are deformation-based, tasks that use Poisson blending with patches extracted from external datasets, and tasks that perform efficient Poisson image blending in 3D volumes, which is in itself a new contribution of our work. We also introduce a synthetic anomaly labelling function which takes into account the natural noise and variation in medical images. Together, our method achieves an average precision score of 76.2 for localising glioma and 78.4 for identifying pathological chest X-rays, thus setting the state of the art in self-supervised anomaly detection.

Related Work: The most prevalent methods for self-supervised anomaly detection are based on generative auto-encoders that analyse the residual error from reconstructing a test sample. This is built on the assumption that a reconstruction model will only be able to correctly reproduce data that is similar to the instances it has been trained on, e.g. only healthy samples. Theoretically, at test time, the residual reconstruction error should be low for healthy tissues but high for anomalous features. This is an active area of research with several recent improvements upon the initial idea [22], e.g., [21] applied a diffusion model to a VQ-VAE [27] to resample the unlikely latent codes, and [30] gradually transitions from a U-Net architecture to an autoencoder over the training process in order to improve the reconstruction of finer details. Several other methods aim to ensure that the model will not reproduce anomalous regions by training it to restore samples altered by augmentations such as masking out regions [32], interpolating heavily augmented textures [29] or adding coarse noise [9]. [5] sought to identify more meaningful errors in image reconstructions by comparing the reconstructions of models trained on only healthy data against those trained on all available data. However, the general assumption that reconstruction error is a good basis for an anomaly scoring function has recently been challenged. Auto-encoders are unable to identify anomalies with extreme textures [16], rely on empirical post-processing to reduce false positives in healthy regions [2] and can be outperformed by trivial approaches like thresholding of FLAIR MRI [15].

Self-supervised methods take a more direct approach, training a model to directly predict an anomaly score using synthetic anomalies. Foreign patch interpolation (FPI) [24] was the first to do this at a pixel level, by linearly interpolating patches extracted from other samples and predicting the interpolation factor as the anomaly score. Similar to CutPaste [11], [7] fully replaces 3D patches with data extracted from elsewhere in the same sample, but then trains the model to segment the patches. Poisson image interpolation (PII) [25] seamlessly integrates sample patches into training images, preventing the models from learning to identify the anomalies by their discontinuous boundaries. Natural synthetic anomalies (NSA) [23] relaxes patch extraction to random locations in other samples and introduces an anomaly labelling function based on the changes introduced by the anomaly. Some approaches combine self-supervised and reconstruction-based methods by training a discriminator to compute more exact segmentations from reconstruction model errors [6,29]. Other approaches have also explored contrastive self-supervised learning for anomaly detection [12,26].
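The FPI idea described above fits in a few lines. The pure-Python sketch below blends a square patch from a second (source) image into the target with factor α and returns α as the pixel-level label inside the patch and 0 elsewhere. It is a simplified illustration under stated assumptions (same-size images, patch taken from the same location in the source, fixed square patch, scalar α, hypothetical function name), not the reference implementation of [24] or of the tasks proposed in this paper.

```python
from typing import List, Tuple

Image = List[List[float]]


def foreign_patch_interpolation(
    target: Image, source: Image, top: int, left: int, size: int, alpha: float
) -> Tuple[Image, Image]:
    """Blend a square patch of `source` into `target` (simplified FPI-style sketch).

    Inside the patch: x' = (1 - alpha) * x + alpha * x_src, label = alpha.
    Outside the patch the image is unchanged and the label is 0.
    `source` must have the same size as `target`.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    h, w = len(target), len(target[0])
    if top < 0 or left < 0 or top + size > h or left + size > w:
        raise ValueError("patch must lie inside the image")
    out = [row[:] for row in target]  # copy so the input stays untouched
    label = [[0.0] * w for _ in range(h)]
    for r in range(top, top + size):
        for c in range(left, left + size):
            out[r][c] = (1.0 - alpha) * target[r][c] + alpha * source[r][c]
            label[r][c] = alpha
    return out, label


healthy = [[0.0] * 4 for _ in range(4)]
other = [[1.0] * 4 for _ in range(4)]
img, lab = foreign_patch_interpolation(healthy, other, top=1, left=1, size=2, alpha=0.25)
```

A model trained to regress `lab` from `img` learns to score how foreign each pixel is; PII and NSA, as described above, change how the patch is blended and labelled rather than this overall training setup.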
2 Method
The core idea of our method is to use synthetic tasks for both training and validation. This allows us to monitor performance and prevent overfitting, all without the need for real anomalous data. Each self-supervised task involves introducing a synthetic anomaly into otherwise normal data whilst also producing the corresponding label. Since the relevant pathologies are unknown a priori, we avoid simulating any specific pathological features. Instead, we use a wide range of subtle and well-integrated anomalies to help the model detect many different kinds of deviations, ideally including real unforeseen anomalies. In our experiments, we use five tasks, but more could be used as long as each one is sufficiently distinct. Distinct tasks are vital because we want to use these validation tasks to estimate the model's generalisation to unseen classes of anomalies. If the training and validation tasks are too similar, the performance on the validation set may be an overly optimistic estimate of how the model would perform on unseen real-world anomalies. When performing cross-validation over all synthetic tasks and data partitions independently, the number of possible train/validation splits increases significantly, requiring us to train $F \cdot \binom{T_N}{T}$ independent models, where $T_N$ is the
Many Tasks Make Light Work
total number of tasks, $T$ is the number of tasks used to train each model and $F$ is the number of data folds, which is computationally expensive. Instead, since in our case $T_N = F = 5$, we associate each task with a single fold of the training data (Fig. 1). We then apply $\binom{5}{T}$-fold cross-validation, with one fold per combination of training tasks. In each iteration, the corresponding data folds are collected and used for training or validation depending on which partition forms the majority: the larger partition is used for training and the smaller one for validation.
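The pairing of tasks and data folds can be sketched in a few lines of Python. The helper below is hypothetical: it indexes tasks and folds from 0 to 4 so that task k is tied to data fold k, enumerates the $\binom{5}{T}$ splits, and, following Fig. 1, uses the larger group of folds for training and the smaller one for validation. It also compares this number of models with the $F \cdot \binom{T_N}{T}$ models needed when tasks and folds are crossed independently.

```python
from itertools import combinations
from math import comb


def task_fold_splits(n_tasks=5, n_train_tasks=1):
    """Enumerate train/val splits when task k is paired with data fold k (hypothetical helper)."""
    splits = []
    for train_tasks in combinations(range(n_tasks), n_train_tasks):
        val_tasks = tuple(t for t in range(n_tasks) if t not in train_tasks)
        # Each task brings its own data fold; the larger group of folds becomes training data.
        if len(train_tasks) >= len(val_tasks):
            train_folds, val_folds = list(train_tasks), list(val_tasks)
        else:
            train_folds, val_folds = list(val_tasks), list(train_tasks)
        splits.append({"train_tasks": train_tasks, "val_tasks": val_tasks,
                       "train_folds": train_folds, "val_folds": val_folds})
    return splits


def n_models(n_tasks=5, n_folds=5, n_train_tasks=1):
    """Models needed when crossing tasks and folds independently vs. pairing them."""
    independent = n_folds * comb(n_tasks, n_train_tasks)
    paired = comb(n_tasks, n_train_tasks)
    return independent, paired


for t in range(1, 5):
    print(f"{t}/{5 - t} split:", n_models(n_train_tasks=t))
```

For a 1/4 split, for example, each model is trained on a single task while the remaining four tasks serve as validation tasks, and pairing reduces the number of models by a factor of $F$.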
Fig. 1. Our pipeline performs cross-validation over the synthetic task and data fold pairs. (Diagram labels: Training Dataset; Synthetic tasks; Data Folds 0-4 paired with Tasks 0-4; Cross-validation training; Task subset a) with Data subset a) and Task subset b) with Data subset b); Compare sizes: larger subset → Training, smaller subset → Validation.)
Fig. 2. Examples of changes introduced by synthetic anomalies, showing the before (grey) and after (green) of a 1D slice across the affected area; a marker indicates the deformation centre for sink/source. (Color figure online)
Synthetic Tasks: Figure 2 shows examples of our self-supervised tasks viewed in both one and two dimensions. Although each task produces visually distinct anomalies, they fall into three overall categories: blending, deformation, or intensity variation. All tasks also share a common recipe: the target anomaly mask $M_h$ is always a randomly sized and rotated ellipse or rectangle (ellipsoids/cuboids in 3D); all anomalies are positioned such that at least 50% of the mask intersects with the foreground of the image; and after one augmentation is applied, the process is randomly repeated (based on a fair coin toss, p = 0.5), up to a maximum of 4 anomalies per image. The Intra-dataset Blending Task. Poisson image blending is the current state-of-the-art for synthetic anomaly tasks [23,25], but it does not scale naturally to more than two dimensions or to non-convex interpolation regions [17]. Therefore, we extend [20] and propose a D-dimensional variant of Poisson image editing following earlier ideas by [17]. Poisson image editing [20] uses the image gradient to seamlessly blend a patch into an image. It does this by combining the target gradient with Dirichlet boundary conditions to define the minimisation problem $\min_{f_{in}} \int_h |\nabla f_{in} - \mathbf{v}|^2$ with $f_{in}|_{\partial h} = f_{out}|_{\partial h}$, where $f_{in}$ represents the intensity values within the patch $h$. The goal is to find intensity values $f_{in}$ that match the surrounding values $f_{out}$ of the destination image $x_i$ along the border of the patch and that follow the image gradient $\mathbf{v} = \nabla x_j = (\partial x_j / \partial x, \partial x_j / \partial y)$ of the source image $x_j$. Its solution is the Poisson equation $\Delta f_{in} = \operatorname{div} \mathbf{v}$ over $h$ with $f_{in}|_{\partial h} = f_{out}|_{\partial h}$. Note that the divergence of $\mathbf{v}$ is equal to the Laplacian of the source image, $\Delta x_j$. Also, by defining $h$ as the axis-aligned bounding box of $M_h$, we can ensure the
boundaries coincide with coordinate lines. This enables us to use the Fourier transform method to solve this partial differential equation [17], which yields a direct relationship between the Fourier coefficients of $\Delta f_{in}$ and $\mathbf{v}$ after padding to a symmetric image. To simplify this for our use case, an image with shape $N_0 \times \dots \times N_{D-1}$, we replace the Fourier transformation with a discrete sine transform (DST),
$\hat{f}_u = \sum_{n} f_n \prod_{d=0}^{D-1} \sin\left(\frac{\pi (n_d + 1)(u_d + 1)}{N_d + 1}\right)$,
where the sum runs over all indices $n = (n_0, \dots, n_{D-1})$ with $0 \le n_d \le N_d - 1$. This follows because a DST is equivalent to a discrete Fourier transform of a real sequence that is odd around the zeroth and middle points, scaled by 0.5, which can be established for our images. With this, the Poisson equation becomes congruent to a relationship of the coefficients,
$\left(\sum_{d=0}^{D-1} \left(\frac{\pi (u_d + 1)}{N_d + 1}\right)^2\right) \hat{f}_u \cong \sum_{d=0}^{D-1} \frac{\pi (u_d + 1)}{N_d + 1}\, \hat{v}_d$,
where $\mathbf{v} = (v_0, \dots, v_{D-1})$ and $\hat{v}_d$ is the DST of each component. The solution for $\hat{f}_u$ can then be computed in DST space by dividing the right side by the term on the left side, and the destination image can be obtained through $x_i = \mathrm{DST}^{-1}(\hat{f}_u)$. Because this approach uses a frequency-transform-based solution, it may slightly alter areas outside of $M_h$ (where image gradients are explicitly edited) in order to ensure the changes are seamlessly integrated. In the following, we refer to this blending process as $\tilde{x} = \mathrm{PoissonBlend}(x_i, x_j, M_h)$. The intra-dataset blending task therefore results from $\tilde{x}_{intra} = \mathrm{PoissonBlend}(x, x', M_h)$ with $x, x' \in \mathcal{D}$, samples from a common dataset $\mathcal{D}$, and is therefore similar to the self-supervision task used in [23] for 2D images.
The inter-dataset blending task follows the same process as intra-dataset blending but uses patches extracted from an external dataset $\mathcal{D}'$, allowing for a greater variety of structures. Samples from this task can therefore be defined as $\tilde{x}_{inter} = \mathrm{PoissonBlend}(x, x', M_h)$ with $x \in \mathcal{D}$, $x' \in \mathcal{D}'$.
The sink/source tasks shift all points in relation to a randomly selected deformation centre $c$. For a given point $p$, we resample intensities from a new location $\tilde{p}$. To create a smooth displacement centred on $c$, we consider the distance $\|p - c\|_2$ in relation to the radius of the mask along this direction, $d$. The extent of this displacement is controlled by the exponential factor $f > 1$. For example, the sink task (Eq. 1) with a factor of $f = 2$ would take the intensity at $0.75d$ and place it at $0.5d$, effectively pulling these intensities closer to the centre. Note that, unlike the sink equation in [24], this formulation cannot sample outside of the boundaries of $P_{M_h}$, meaning it blends seamlessly into the surrounding area. The source task (Eq. 2) performs the reverse, appearing to push the pixels away from the centre by sampling intensities towards it.
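$\tilde{x}_p = x_{\tilde{p}}, \quad \tilde{p} = c + d\,\frac{p - c}{\|p - c\|_2}\left(1 - \left(1 - \frac{\|p - c\|_2}{d}\right)^{f}\right), \quad c \in P_{M_h},\ \forall p \in P_{M_h}$ (1)
$\tilde{x}_p = x_{\tilde{p}}, \quad \tilde{p} = c + d\,\frac{p - c}{\|p - c\|_2}\left(\frac{\|p - c\|_2}{d}\right)^{f}, \quad c \in P_{M_h},\ \forall p \in P_{M_h}$ (2)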
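The DST-based Poisson solve used by the blending tasks can be illustrated with a minimal NumPy sketch for an axis-aligned box in any number of dimensions. It assumes that the box passed in already includes a one-voxel outer ring taken from the destination image $x_i$ (the Dirichlet boundary $f_{out}$), uses the standard $(2D+1)$-point finite-difference Laplacian of the source $x_j$ as $\operatorname{div}\mathbf{v}$, and diagonalises the discrete Dirichlet Laplacian with a type-I DST built from explicit sine matrices. The helper names are hypothetical, this is not the authors' implementation, and it does not restrict the guidance field to $M_h$.

```python
import numpy as np


def dst1_matrix(n):
    """Unnormalised DST-I basis: S[k, m] = sin(pi (k+1)(m+1) / (n+1))."""
    k = np.arange(1, n + 1)
    return np.sin(np.pi * np.outer(k, k) / (n + 1))


def dstn1(a):
    """Apply the unnormalised DST-I along every axis of `a`."""
    for d in range(a.ndim):
        s = dst1_matrix(a.shape[d])
        a = np.moveaxis(np.tensordot(s, a, axes=([1], [d])), 0, d)
    return a


def idstn1(a):
    """Inverse of dstn1, using S @ S = (n + 1) / 2 * I along each axis."""
    scale = np.prod([2.0 / (n + 1) for n in a.shape])
    return dstn1(a) * scale


def laplacian_interior(a):
    """(2D+1)-point discrete Laplacian evaluated on the interior of `a` (border dropped)."""
    inner = tuple(slice(1, -1) for _ in range(a.ndim))
    lap = -2.0 * a.ndim * a[inner]
    for d in range(a.ndim):
        lo, hi = list(inner), list(inner)
        lo[d], hi[d] = slice(0, -2), slice(2, None)
        lap = lap + a[tuple(lo)] + a[tuple(hi)]
    return lap


def poisson_blend_box(dest_box, src_box):
    """Blend src_box into dest_box; the 1-voxel outer ring of dest_box is the Dirichlet boundary."""
    dest_box = np.asarray(dest_box, dtype=float)
    src_box = np.asarray(src_box, dtype=float)
    inner = tuple(slice(1, -1) for _ in range(dest_box.ndim))
    # Right-hand side: Laplacian of the source, i.e. div of the guidance field v = grad x_j
    rhs = laplacian_interior(src_box)
    # Move the known destination boundary values to the right-hand side
    border = dest_box.copy()
    border[inner] = 0.0
    rhs = rhs - laplacian_interior(border)
    # Eigenvalues of the zero-boundary discrete Laplacian, which the DST-I diagonalises
    eig = np.zeros(rhs.shape)
    for d, n in enumerate(rhs.shape):
        lam = 2.0 * np.cos(np.pi * np.arange(1, n + 1) / (n + 1)) - 2.0
        shape = [1] * rhs.ndim
        shape[d] = n
        eig = eig + lam.reshape(shape)
    out = dest_box.copy()
    out[inner] = idstn1(dstn1(rhs) / eig)  # one element-wise division in DST space
    return out
```

Because every DST-I basis vector is an eigenvector of the discrete Laplacian with zero boundary values, the solve reduces to one element-wise division in transform space; this is the discrete counterpart of dividing by the left-hand factor of the coefficient relation (for small $\theta$, $2 - 2\cos\theta \approx \theta^2$). The explicit matrices cost $O(N_d^2)$ per axis; a fast transform such as `scipy.fft.dstn(..., type=1)` could replace them, taking its different normalisation into account.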
The smooth intensity change task aims to either add or subtract an intensity over the entire anomaly mask. To avoid sharp discontinuities at the boundaries,
this intensity change is gradually dampened for pixels within a certain margin of the boundary. This smoothing starts at a random distance $d_s$ from the boundary, and the change at a pixel with distance $d_p$ from the boundary is modulated by $d_p / d_s$.
Anomaly Labelling: In order to train and validate with multiple tasks simultaneously, we use the same anomaly labelling function across all of our tasks. The scaled logistic function used in NSA [23] helps to translate raw intensity changes into more semantic labels. However, it also rounds imperceptible differences up to a minimum score of about 0.1. This sudden and arbitrary jump creates noisy labels and can lead to unstable training. We correct this semantic discontinuity by computing labels as $y = 1 - \frac{p_X(x)}{p_X(0)}$ with $X \sim \mathcal{N}(0, \sigma^2)$, instead of $y = \frac{1}{1 + e^{-k(x - x_0)}}$ [23]. This flipped Gaussian shape is $C^1$ continuous and smoothly approaches zero, providing consistent labels even for smaller changes.
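The non-blending tasks and the labelling function can be sketched in NumPy for a 2D image. For simplicity, the sketch assumes a circular mask, so that the mask radius $d$ is the same in every direction; it uses nearest-neighbour resampling for Eqs. (1) and (2), takes $d_p$ as the distance of a pixel to the mask boundary, and feeds the absolute per-pixel intensity change into the flipped-Gaussian label with a hand-picked $\sigma$. All function names and these choices are illustrative assumptions rather than the authors' implementation.

```python
import numpy as np


def disc(shape, centre, radius):
    """Offsets p - c, distances ||p - c||_2 and a boolean disc mask (circular M_h assumed)."""
    ys, xs = np.indices(shape)
    off = np.stack([ys - centre[0], xs - centre[1]]).astype(float)
    dist = np.sqrt((off ** 2).sum(axis=0))
    return off, dist, dist < radius


def sink_source(img, centre, radius, f=2.0, mode="sink"):
    """Resample x_p from p~ following Eq. (1) (sink) or Eq. (2) (source)."""
    img = np.asarray(img, dtype=float)
    off, dist, inside = disc(img.shape, centre, radius)
    rel = np.where(inside, dist / radius, 0.0)  # ||p - c|| / d, in [0, 1) inside the mask
    if mode == "sink":
        new_rel = 1.0 - (1.0 - rel) ** f
    elif mode == "source":
        new_rel = rel ** f
    else:
        raise ValueError("mode must be 'sink' or 'source'")
    # p~ = c + d (p - c)/||p - c|| * new_rel = c + (p - c) * new_rel / rel
    scale = np.divide(new_rel, rel, out=np.zeros_like(rel), where=rel > 0)
    sy = np.clip(np.rint(centre[0] + off[0] * scale).astype(int), 0, img.shape[0] - 1)
    sx = np.clip(np.rint(centre[1] + off[1] * scale).astype(int), 0, img.shape[1] - 1)
    out = img.copy()
    out[inside] = img[sy[inside], sx[inside]]  # nearest-neighbour resampling
    return out


def smooth_intensity(img, centre, radius, delta, d_s):
    """Add delta inside the mask, modulated by d_p / d_s within d_s of the boundary."""
    img = np.asarray(img, dtype=float)
    _, dist, inside = disc(img.shape, centre, radius)
    d_p = np.where(inside, radius - dist, 0.0)  # distance to the mask boundary
    return img + delta * np.clip(d_p / d_s, 0.0, 1.0)


def flipped_gaussian_label(change, sigma):
    """y = 1 - p_X(x) / p_X(0) for X ~ N(0, sigma^2), i.e. 1 - exp(-x^2 / (2 sigma^2))."""
    change = np.asarray(change, dtype=float)
    return 1.0 - np.exp(-change ** 2 / (2.0 * sigma ** 2))


ramp = np.tile(np.arange(32.0), (32, 1))  # intensity equals the column index
sunk = sink_source(ramp, (16, 16), 8.0, f=2.0, mode="sink")
labels = flipped_gaussian_label(np.abs(sunk - ramp), sigma=1.0)
```

Because `new_rel` stays below 1 inside the mask, every resampled location $\tilde{p}$ lies inside the disc, which mirrors the property that Eqs. (1) and (2) never sample outside $P_{M_h}$. The label is exactly 0 where nothing changed, with no jump for tiny changes.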
3 Experiments and Results
Data: We evaluate our method on T2-weighted brain MRI and chest X-ray datasets to provide direct comparisons to state-of-the-art methods over a wide range of real anomalies. For brain MRI, we train on the Human Connectome Project (HCP) dataset [28], which consists of 1113 MRI scans of healthy young adults acquired as part of a scientific study. For evaluation, we use the Brain Tumor Segmentation Challenge 2017 (BraTS) dataset [1], containing 285 cases with either high- or low-grade glioma, and the Ischemic Stroke Lesion Segmentation Challenge 2015 (ISLES) dataset [13], containing 28 cases with ischemic stroke lesions. The data from both test sets were acquired as part of clinical routine. The HCP dataset was resampled to 1 mm isotropic spacing to match the test datasets. We apply z-score normalisation to each sample and then align the bounding box of each brain before padding it to a size of 160 × 224 × 160. Lastly, samples are downsampled by a factor of two. For chest X-rays, we use the VinDr-CXR dataset [18], which includes 22 different local labels. To be able to compare with the benchmarks reported in [6], we use the same healthy subset of 4000 images for training along with their test set (DDADts) of 1000 healthy and 1000 unhealthy samples, with some minor changes outlined as follows. First, note that [6] derives VinDr-CXR labels using the majority vote of the 3 annotators. Unfortunately, this means there are 52 training samples where 1/3 of radiologists identified an anomaly but the majority label is counted as healthy. The same applies to 10 samples within the healthy testing subset. To avoid this ambiguity, we replace these samples with leftover training data that all radiologists have labelled as healthy. We also evaluate using the true test set (VinDrts), where two senior radiologists have reviewed and consolidated all labels.
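The brain MRI preprocessing steps can be summarised in a short NumPy sketch. It assumes a skull-stripped volume with a boolean brain mask, computes the z-score over brain voxels only, sets the background to zero, centres the cropped brain bounding box in the target array, and implements the factor-two downsampling as 2×2×2 average pooling. These choices and the function name are assumptions, since the exact operations are not specified here.

```python
import numpy as np


def preprocess_brain(vol, mask, target=(160, 224, 160)):
    """z-score, crop to the brain bounding box, centre-pad to `target`, downsample by 2."""
    vol = np.asarray(vol, dtype=float)
    mask = np.asarray(mask, dtype=bool)
    brain = vol[mask]
    vol = np.where(mask, (vol - brain.mean()) / brain.std(), 0.0)  # background set to 0
    # Bounding box of the brain
    box = tuple(slice(i.min(), i.max() + 1) for i in np.nonzero(mask))
    crop = vol[box]
    if any(c > t for c, t in zip(crop.shape, target)):
        raise ValueError("brain bounding box is larger than the target shape")
    out = np.zeros(target)
    start = [(t - c) // 2 for t, c in zip(target, crop.shape)]
    out[tuple(slice(s, s + c) for s, c in zip(start, crop.shape))] = crop
    # Average-pool non-overlapping 2x2x2 blocks (even target sizes assumed)
    d0, d1, d2 = (t // 2 for t in target)
    return out.reshape(d0, 2, d1, 2, d2, 2).mean(axis=(1, 3, 5))
```

With the default target, the network input would be 80 × 112 × 80 voxels.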
For preprocessing, we clip pixel intensities according to the window centre and width attributes in each DICOM file, and apply histogram equalisation, before scaling intensities to the range [−1, 1]. Finally, images are resized to 256 × 256.
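A minimal NumPy version of this chest X-ray preprocessing is sketched below. The window centre and width would come from the DICOM WindowCenter and WindowWidth attributes (for example, read with pydicom), and the sketch uses a simplified linear window. The histogram equalisation here is a plain mapping through the normalised cumulative histogram with 256 bins. The final resize to 256 × 256 is left to an image library such as `skimage.transform.resize`. The function name and bin count are assumptions.

```python
import numpy as np


def preprocess_cxr(pixels, window_center, window_width, n_bins=256):
    """Window, histogram-equalise and rescale a chest X-ray to [-1, 1] (hypothetical helper)."""
    lo = window_center - window_width / 2.0
    hi = window_center + window_width / 2.0
    img = (np.clip(np.asarray(pixels, dtype=float), lo, hi) - lo) / (hi - lo)  # [0, 1]
    # Histogram equalisation: map each intensity through the normalised cumulative histogram
    hist, edges = np.histogram(img, bins=n_bins, range=(0.0, 1.0))
    cdf = np.cumsum(hist).astype(float)
    cdf = (cdf - cdf[0]) / max(cdf[-1] - cdf[0], 1e-12)
    img = np.interp(img, edges[:-1], cdf)
    return 2.0 * img - 1.0  # [-1, 1]
```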
Table 1. Upper left part: Metrics on brain MRI, evaluated on BraTS and ISLES, presented as AP/AUROC. · indicates that the metrics are evaluated over the same region and at the same resolution as CRADL [12]. Upper right part: Metrics on VinDr-CXR, presented as AP/AUROC on the VinDr and DDAD test splits. Random is the baseline performance of a random classifier. Lower part: a sensitivity analysis of the average AP of each individual fold (mean±s.d.) alongside that of the model ensemble, varying how many tasks we use for training versus validation. Best results are highlighted in bold.

Brain MRI (AP/AUROC; Ours uses a 1/4 train/val. task split)
Method | Slice-wise BraTS | Slice-wise ISLES | Pixel-wise BraTS | Pixel-wise ISLES
VAE · | 80.7/83.3 | 51.9/71.7 | 29.8/92.5 | 7.7/87.5
ceVAE [32] · | 85.6/86.5 | 54.1/72.7 | 48.3/94.8 | 14.5/87.9
CRADL [12] · | 81.9/82.6 | 54.9/69.3 | 38.0/94.2 | 18.6/89.8
Random · | 49.0/50.0 | 36.6/50.0 | 2.4/50.0 | 1.1/50.0
Ours · | 87.6/89.4 | 61.3/80.2 | 76.2/98.7 | 46.5/97.1
Random | 40.3/50.0 | 29.4/50.0 | 1.7/50.0 | 0.8/50.0
Ours | 87.6/92.2 | 62.0/84.6 | 76.2/99.1 | 45.9/97.9

Chest X-ray (AP/AUROC; Ours uses a 1/4 train/val. task split)
Method | Sample-wise DDADts | Sample-wise VinDrts | Pixel-wise DDADts | Pixel-wise VinDrts
MemAE [8] | 59.8/55.8 | | |
f-AnoGAN [22] | 74.8/76.3 | | |
AE-U [14] | 72.8/73.8 | | |
FPI [24] | 49.9/48.2 | | |
PII [25] | 65.8/65.9 | | |
NSA [23] | 65.8/64.4 | | |
Random | 50.0/50.0 | 31.6/50.0 | 4.5/50.0 | 2.7/50.0
Ours | 78.4/76.6 | 71.2/81.1 | 21.1/75.6 | 21.4/81.2

Train/val. split ablation (AP; all = mean±s.d. over individual folds, ens. = model ensemble)
Metric | 1/4 all | 1/4 ens. | 2/3 all | 2/3 ens. | 3/2 all | 3/2 ens. | 4/1 all | 4/1 ens.
Slice-wise BraTS17 | 83.4±4.4 | 87.6 | 82.5±3.3 | 85.7 | 81.1±4.3 | 84.0 | 81.5±2.7 | 83.1
Slice-wise ISLES | 59.3±2.2 | 62.0 | 55.9±8.5 | 58.4 | 52.5±4.7 | 55.0 | 53.1±2.3 | 54.7
Pixel-wise BraTS | 46.9±14.9 | 76.2 | 42.8±12.8 | 72.2 | 37.9±11.1 | 63.7 | 36.1±9.0 | 52.5
Pixel-wise ISLES | 23.7±7.7 | 45.9 | 21.2±9.3 | 41.0 | 15.4±3.3 | 26.6 | 16.5±5.0 | 23.7
Sample-wise DDADts | 74.7±4.9 | 78.4 | 78.6±1.4 | 80.7 | 78.7±1.8 | 80.4 | 79.2±1.3 | 80.5
Sample-wise VinDrts | 66.3±4.4 | 71.2 | 71.0±1.4 | 73.8 | 71.1±1.4 | 73.3 | 71.8±1.3 | 73.6
Pixel-wise DDADts | 15.3±3.7 | 21.1 | 19.2±1.7 | 24.0 | 20.3±1.4 | 24.3 | 20.4±0.9 | 23.5
Pixel-wise VinDrts | 15.2±4.5 | 21.4 | 19.5±1.8 | 24.7 | 20.4±1.7 | 24.7 | 21.1±0.9 | 24.5
Comparison to State-of-the-Art Methods: Validating on synthetic tasks is one of our main motivations; as such, we use a 1/4 (train/val.) task split to compare with benchmark methods. For brain MRI, we evaluate results at the slice and voxel level, computing average precision (AP) and area under the receiver operating characteristic curve (AUROC), as implemented in scikit-learn [19]. Note that the distribution shift between training and test data (research vs. clinical scans) adds further difficulty to this task. In spite of this, we substantially improve upon the current state-of-the-art (Table 1, upper left). In particular, we achieve a pixel-wise AP of 76.2 and 45.9 on the BraTS and ISLES datasets, respectively. To make our comparison as faithful as possible, we also re-evaluate after post-processing our predictions to match the region and resolution used by CRADL, where we see a similar improvement. Qualitative examples are shown in Fig. 3. Note that all baseline methods use a validation set consisting of real anomalous samples from BraTS and ISLES to select which anomaly scoring function to use, whereas we only use synthetic validation data. This further supports the claim that our method of using synthetic data to estimate generalisation works well. For both VinDr-CXR test sets, we evaluate at the sample and pixel level, although previous publications have only reported their results at the sample level.
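The slice- and pixel-level evaluation can be sketched with scikit-learn's `average_precision_score` and `roc_auc_score`. How a slice-level score is derived from a voxel-wise anomaly map is not specified in the text; the sketch assumes the maximum score within each slice, and a slice counts as anomalous if it contains any annotated voxel. The helper names are hypothetical, and in practice scores and labels would be pooled over the whole test set before computing the metrics.

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score


def ap_auroc(scores, labels):
    """AP and AUROC in percent, as reported in Table 1."""
    scores = np.ravel(scores)
    labels = np.ravel(labels).astype(int)
    return 100 * average_precision_score(labels, scores), 100 * roc_auc_score(labels, scores)


def slice_and_pixel_metrics(score_vol, label_vol, axis=0):
    """Pixel-wise metrics over all voxels; slice-wise metrics from per-slice max scores (assumed)."""
    score_vol = np.asarray(score_vol, dtype=float)
    label_vol = np.asarray(label_vol, dtype=bool)
    other = tuple(a for a in range(score_vol.ndim) if a != axis)
    slice_scores = score_vol.max(axis=other)
    slice_labels = label_vol.any(axis=other)
    return {"pixel": ap_auroc(score_vol, label_vol),
            "slice": ap_auroc(slice_scores, slice_labels)}
```

AP is the more informative of the two here, since anomalous voxels are rare and AUROC is insensitive to that class imbalance (compare the Random rows in Table 1).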
Fig. 3. Examples of predictions on randomly selected BraTS and ISLES samples after training on HCP. The red contour outlines the ground truth segmentation. (Color figure online)
We again show performance above the current state-of-the-art (Table 1, upper right). Our results are also substantially higher than those of previously proposed self-supervised methods, improving on the current state-of-the-art NSA [23] by 12.6 to achieve 78.4 image-level AP. This shows that our use of synthetic validation data succeeds where their fixed training schedule fails. Ablation and Sensitivity Analysis on Cross-Validation Structure: We also investigate how performance changes as we vary the number of tasks used for training and validation (Table 1, lower part). For VinDr-CXR, the average performance of an individual fold increases as training becomes more diverse (i.e. more tasks); however, the performance of the ensemble plateaus. Having more training tasks can help the model to be sensitive to a wider range of anomalous features. But as the number of training tasks increases, so does the overlap between different models in the ensemble, diminishing the benefit of pooling predictions. This could also explain why the standard deviation (across folds) decreases as the number of training tasks increases, since the models are becoming more similar. Our best configuration is close to being competitive with the state-of-the-art semi-supervised method DDAD-ASR [6]. Even though their method uses twice as much training data, as well as some real anomalous data, our purely synthetic method begins to close the gap (AP of [6] 84.3 vs. ours 80.7 on DDADts). For the brain datasets, all metrics generally decrease as the number of training tasks increases. This could be due to the distribution shift between training and test data. Although more training tasks may increase sensitivity to diverse irregularities, this can become a liability if there are differences between (healthy) training and test data (e.g. acquisition parameters). More sensitive models may then produce more "false" positives.
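The "all" versus "ens." columns of the sensitivity analysis can be illustrated with the sketch below: it computes the AP of each cross-validation model separately (mean and standard deviation) and the AP of the ensemble. Pooling by averaging the models' anomaly scores, the population standard deviation, and the helper name are assumptions.

```python
import numpy as np
from sklearn.metrics import average_precision_score


def fold_and_ensemble_ap(fold_scores, labels):
    """fold_scores: (n_models, n_samples). Returns (mean AP, s.d. AP, ensemble AP) in percent."""
    fold_scores = np.asarray(fold_scores, dtype=float)
    per_model = np.array([average_precision_score(labels, s) for s in fold_scores]) * 100
    ensemble = average_precision_score(labels, fold_scores.mean(axis=0)) * 100
    return per_model.mean(), per_model.std(), ensemble
```

When models make different mistakes, averaging their scores can rank samples better than any single model; as the models become more similar, the ensemble gain shrinks and the fold-wise standard deviation falls, which matches the trend discussed above.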
Discussion: We demonstrate the effectiveness of our method in multiple settings and across different modalities. A unique aspect of the brain data is the domain shift: the HCP training data was acquired at a much higher isotropic resolution than the BraTS and ISLES test data, which are both anisotropic. Here we achieve the best performance using more tasks for validation, which successfully reduces overfitting and hypersensitivity. Incorporating stronger data augmentations, such as simulating anisotropic spacing, could further improve results by training the model to ignore these transformations. We also achieve strong results for the X-ray data, although precise localisation remains challenging. Closing the gap between current performance and clinically useful localisation should therefore be a high priority for future research.
4 Conclusion
In this work, we use multiple synthetic tasks to both train and validate self-supervised anomaly detection models. This enables more robust training without the need for real anomalous training or validation data. To achieve this, we propose multiple diverse tasks, exposing models to a wide range of anomalous features. These include patch blending, image deformations and intensity modulations. As part of this, we extend Poisson image editing to images of arbitrary dimensions, enabling the current state-of-the-art tasks to be applied beyond just 2D images. In order to use all of these tasks in a common framework, we also design a unified labelling function with improved continuity for small intensity changes. We evaluate our method on both brain MRI and chest X-rays and achieve performance at or above the state-of-the-art. We also report pixel-wise results, even for the challenging case of chest X-rays. We hope this encourages others to do the same, as accurate localisation is essential for anomaly detection to have a future in clinical workflows.
Acknowledgements. We thank EPSRC for DTP funding and HPC resources provided by the Erlangen National High Performance Computing Center (NHR@FAU) of the Friedrich-Alexander-Universität Erlangen-Nürnberg (FAU) under the NHR project b143dc. NHR funding is provided by federal and Bavarian state authorities. NHR@FAU hardware is partially funded by the German Research Foundation (DFG) - 440719683. Support was also received from the ERC - project MIA-NORMAL 101083647 and DFG KA 5801/2-1, INST 90/1351-1.
References
1. Bakas, S., et al.: Identifying the best machine learning algorithms for brain tumor segmentation, progression assessment, and overall survival prediction in the BRATS challenge. arXiv:1811.02629 (2018)
2. Baur, C., Denner, S., Wiestler, B., Navab, N., Albarqouni, S.: Autoencoders for unsupervised anomaly segmentation in brain MR images: a comparative study. Med. Image Anal. 69, 101952 (2021)
3. Brady, A.P.: Error and discrepancy in radiology: inevitable or avoidable? Insights Imaging 8(1), 171–182 (2017)
4. Bruls, R., Kwee, R.: Workload for radiologists during on-call hours: dramatic increase in the past 15 years. Insights Imaging 11, 1–7 (2020)
5. Cai, Y., Chen, H., Yang, X., Zhou, Y., Cheng, K.T.: Dual-distribution discrepancy for anomaly detection in chest X-rays. In: MICCAI 2022, Part III, pp. 584–593. Springer (2022). https://doi.org/10.1007/978-3-031-16437-8_56
6. Cai, Y., Chen, H., Yang, X., Zhou, Y., Cheng, K.T.: Dual-distribution discrepancy with self-supervised refinement for anomaly detection in medical images. arXiv:2210.04227 (2022)
7. Cho, J., Kang, I., Park, J.: Self-supervised 3D out-of-distribution detection via pseudoanomaly generation. In: Biomedical Image Registration, Domain Generalisation and Out-of-Distribution Analysis, pp. 95–103 (2022)
8. Gong, D., et al.: Memorizing normality to detect anomaly: memory-augmented deep autoencoder for unsupervised anomaly detection. In: CVPR 2019, pp. 1705–1714 (2019)
9. Kascenas, A., et al.: The role of noise in denoising models for anomaly detection in medical images. arXiv:2301.08330 (2023)
10. Kim, Y.W., Mansfield, L.T.: Fool me twice: delayed diagnoses in radiology with emphasis on perpetuated errors. Am. J. Roentgenol. 202(3), 465–470 (2014)
11. Li, C.L., Sohn, K., Yoon, J., Pfister, T.: CutPaste: self-supervised learning for anomaly detection and localization. In: CVPR 2021, pp. 9664–9674 (2021)
12. Lüth, C.T., et al.: CRADL: contrastive representations for unsupervised anomaly detection and localization. arXiv:2301.02126 (2023)
13. Maier, O., Menze, B.H., von der Gablentz, J., Häni, L., Heinrich, M.P., et al.: ISLES 2015 - a public evaluation benchmark for ischemic stroke lesion segmentation from multispectral MRI. Med. Image Anal. 35, 250–269 (2017). https://doi.org/10.1016/j.media.2016.07.009
14. Mao, Y., Xue, F.F., Wang, R., Zhang, J., Zheng, W.S., Liu, H.: Abnormality detection in chest X-ray images using uncertainty prediction autoencoders. In: MICCAI 2020, pp. 529–538 (2020)
15. Meissen, F., Kaissis, G., Rueckert, D.: Challenging current semi-supervised anomaly segmentation methods for brain MRI. In: BrainLes 2021 at MICCAI 2021, 27 Sept 2021, Part I, pp. 63–74. Springer (2022). https://doi.org/10.1007/978-3-031-08999-2_5
16. Meissen, F., Wiestler, B., Kaissis, G., Rueckert, D.: On the pitfalls of using the residual error as anomaly score. In: Proceedings of The 5th International Conference on Medical Imaging with Deep Learning. Proceedings of Machine Learning Research, vol. 172, pp. 914–928. PMLR (06–08 Jul 2022)
17. Morel, J.M., Petro, A.B., Sbert, C.: Fourier implementation of Poisson image editing. Pattern Recogn. Lett. 33(3), 342–348 (2012)
18. Nguyen, H.Q., et al.: VinDr-CXR: an open dataset of chest X-rays with radiologist's annotations. Scientific Data 9(1), 429 (2022)
19. Pedregosa, F., et al.: Scikit-learn: machine learning in Python. J. Mach. Learn. Res. 12, 2825–2830 (2011)
20. Pérez, P., Gangnet, M., Blake, A.: Poisson image editing. In: ACM SIGGRAPH 2003 Papers, pp. 313–318 (2003)
21. Pinaya, W.H., et al.: Fast unsupervised brain anomaly detection and segmentation with diffusion models. In: MICCAI 2022, Part VIII, pp. 705–714. Springer (2022). https://doi.org/10.1007/978-3-031-16452-1_67
22. Schlegl, T., Seeböck, P., Waldstein, S.M., Langs, G., Schmidt-Erfurth, U.: f-AnoGAN: fast unsupervised anomaly detection with generative adversarial networks. Med. Image Anal. 54, 30–44 (2019). https://doi.org/10.1016/j.media.2019.01.010
23. Schlüter, H.M., Tan, J., Hou, B., Kainz, B.: Natural synthetic anomalies for self-supervised anomaly detection and localization. In: Computer Vision - ECCV 2022, pp. 474–489. Springer Nature Switzerland, Cham (2022). https://doi.org/10.1007/978-3-031-19821-2_27
24. Tan, J., Hou, B., Batten, J., Qiu, H., Kainz, B.: Detecting outliers with foreign patch interpolation. Mach. Learn. Biomed. Imaging 1, 1–27 (2022)
25. Tan, J., Hou, B., Day, T., Simpson, J., Rueckert, D., Kainz, B.: Detecting outliers with Poisson image interpolation. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12905, pp. 581–591. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87240-3_56
26. Tian, Y., et al.: Constrained contrastive distribution learning for unsupervised anomaly detection and localisation in medical images. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12905, pp. 128–140. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87240-3_13
27. Van Den Oord, A., Vinyals, O., et al.: Neural discrete representation learning. In: Advances in Neural Information Processing Systems 30 (2017)
28. Van Essen, D., Ugurbil, K., Auerbach, E., et al.: The human connectome project: a data acquisition perspective. NeuroImage 62(4), 2222–2231 (2012). https://doi.org/10.1016/j.neuroimage.2012.02.018
29. Zavrtanik, V., Kristan, M., Skočaj, D.: DRAEM - a discriminatively trained reconstruction embedding for surface anomaly detection. In: CVPR 2021, pp. 8330–8339 (2021)
30. Zhang, W., et al.: A multi-task network with weight decay skip connection training for anomaly detection in retinal fundus images. In: MICCAI 2022, Part II, pp. 656–666. Springer (2022)
31. Zimmerer, D., et al.: MOOD 2020: a public benchmark for out-of-distribution detection and localization on medical images. IEEE Trans. Med. Imaging 41(10), 2728–2738 (2022)
32. Zimmerer, D., Kohl, S.A., Petersen, J., Isensee, F., Maier-Hein, K.H.: Context-encoding variational autoencoder for unsupervised anomaly detection. arXiv:1812.05941 (2018)
AME-CAM: Attentive Multiple-Exit CAM for Weakly Supervised Segmentation on MRI Brain Tumor
Yu-Jen Chen¹ (✉), Xinrong Hu², Yiyu Shi², and Tsung-Yi Ho³
¹ National Tsing Hua University, Hsinchu, Taiwan
² University of Notre Dame, Notre Dame, IN, USA {xhu7,yshi4}@nd.edu
³ The Chinese University of Hong Kong, Hong Kong, China
Abstract. Magnetic resonance imaging (MRI) is commonly used for brain tumor segmentation, which is critical for patient evaluation and treatment planning. To reduce the labor and expertise required for labeling, weakly-supervised semantic segmentation (WSSS) methods with class activation mapping (CAM) have been proposed. However, existing CAM methods suffer from low resolution due to strided convolution and pooling layers, resulting in inaccurate predictions. In this study, we propose a novel CAM method, Attentive Multiple-Exit CAM (AME-CAM), that extracts activation maps at multiple resolutions and hierarchically aggregates them to improve prediction accuracy. We evaluate our method on the BraTS 2021 dataset and show that it outperforms state-of-the-art methods.
Keywords: Tumor segmentation · Weakly-supervised semantic segmentation
1 Introduction
Deep learning techniques have greatly improved medical image segmentation by automatically extracting the location of specific tissues or substances, which facilitates accurate disease diagnosis and assessment. However, most deep learning approaches for segmentation require fully or partially labeled training datasets, which can be time-consuming and expensive to annotate. To address this issue, recent research has focused on developing segmentation frameworks that require few or no segmentation labels. To meet this need, many researchers have devoted their efforts to Weakly-Supervised Semantic Segmentation (WSSS) [21], which utilizes weak supervision such as image-level classification labels. Recent WSSS methods can be broadly categorized into two types [4]: Class-Activation-Mapping-based (CAM-based) [9,13,16,19,20,22] and Multiple-Instance-Learning-based (MIL-based) [15] methods.
Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_17.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 173–182, 2023.
https://doi.org/10.1007/978-3-031-43907-0_17
Y.-J. Chen et al.
The literature has not adequately addressed the issue of low-resolution Class Activation Maps (CAMs), especially for medical images. Some existing methods, such as dilated residual networks [24] and U-Net segmentation architectures [3,7,17], have attempted to tackle this issue, but they still require many upsampling operations, which blur the results. Meanwhile, LayerCAM [9] proposed a hierarchical solution that extracts activation maps from multiple convolution layers using Grad-CAM [16] and aggregates them with equal weights. Although this approach successfully enhances the resolution of the segmentation mask, its fixed equal weighting lacks flexibility and may not be optimal. In this paper, we propose an Attentive Multiple-Exit CAM (AME-CAM) for brain tumor segmentation in magnetic resonance imaging (MRI). Unlike recent CAM methods, AME-CAM uses a classification model trained with a multiple-exit strategy that optimizes the internal outputs. Activation maps from the outputs of the internal classifiers, which have different resolutions, are then aggregated using an attention model. The model learns the pixel-wise weighted sum of the activation maps through a novel contrastive learning method. Our proposed method makes the following contributions:
– To tackle the issues in existing CAMs, we propose to use multiple-exit classification networks to accurately capture all the internal activation maps of different resolutions.
– We propose an attentive feature aggregation to learn the pixel-wise weighted sum of the internal activation maps.
– We demonstrate the superiority of AME-CAM over state-of-the-art CAM methods in extracting segmentation results from classification networks on the 2021 Brain Tumor Segmentation Challenge (BraTS 2021) [1,2,14].
– For reproducibility, we have released our code at https://github.com/windstormer/AME-CAM
Overall, our proposed method can help overcome the challenges of expensive and time-consuming segmentation labeling in medical imaging, and has the potential to improve the accuracy of disease diagnosis and assessment.
2 Attentive Multiple-Exit CAM (AME-CAM)
The proposed AME-CAM method consists of two training phases, activation extraction and activation aggregation, as shown in Fig. 1. In the activation extraction phase, we use a binary classification network, e.g., ResNet-18, to obtain the class probability $y = f(I)$ of the input image $I$. To enable multiple-exit training, we add one internal classifier after each residual block, which generates the activation map $M_i$ at a different resolution. We use a cross-entropy loss to train the multiple-exit classifier, which is defined as
$\mathrm{loss} = \sum_{i=1}^{4} \mathrm{CE}(\mathrm{GAP}(M_i), L)$ (1)
Conditional Diffusion Models for WSSS
Fig. 1. An overview of the proposed AME-CAM method, which contains an activation extraction phase based on a multiple-exit network and an activation aggregation phase based on attention. The weighted-sum operator and ⊗ denote the pixel-wise weighted sum and the pixel-wise multiplication, respectively.
where $\mathrm{GAP}(\cdot)$ is the global-average-pooling operation, $\mathrm{CE}(\cdot)$ is the cross-entropy loss, and $L$ is the image-wise ground-truth label. In the activation aggregation phase, we create an efficient hierarchical aggregation method that generates the aggregated activation map $M_f$ by calculating the pixel-wise weighted sum of the activation maps $M_i$. We use an attention network $A(\cdot)$ to estimate the importance of each pixel in each activation map. The attention network takes as input the image $I$ masked by each activation map and outputs the pixel-wise importance score $S_{xyi}$ of each activation map. We formulate the operation as follows:
$S_{xyi} = A\left([I \otimes n(M_i)]_{i=1}^{4}\right)$ (2)
where $[\cdot]$ is the concatenation operation, $n(\cdot)$ is the min-max normalization that maps the range to [0, 1], and $\otimes$ is the pixel-wise multiplication, which is known as image masking. The aggregated activation map $M_f$ is then obtained by the pixel-wise weighted sum of the $M_i$, which is $M_f = \sum_{i=1}^{4} (S_{xyi} \otimes M_i)$. We train the attention network with unsupervised contrastive learning, which forces the network to disentangle the foreground and the background of the aggregated activation map $M_f$. We mask the input image by the aggregated activation map $M_f$ and its opposite $(1 - M_f)$ to obtain the foreground feature
and the background feature, respectively. The loss function is defined as follows:
$\mathrm{loss} = \mathrm{SimMin}(v_i^f, v_j^b) + \mathrm{SimMax}(v_i^f, v_j^f) + \mathrm{SimMax}(v_i^b, v_j^b)$ (3)
where $v_i^f$ and $v_i^b$ denote the foreground and the background feature of the $i$-th sample, respectively. SimMin and SimMax are the losses that minimize and maximize the similarity between two features, respectively (see C²AM [22] for details). Finally, we average the activation maps $M_1$ to $M_4$ and the aggregated map $M_f$ to obtain the final CAM result for each image. We apply the Dense Conditional Random Field (DenseCRF) [12] algorithm to generate the final segmentation mask. It is worth noting that the proposed method is flexible and can be applied to any classification network architecture.
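The multiple-exit loss of Eq. (1) and the aggregation steps can be sketched with NumPy for a single image. The sketch assumes that each exit outputs a two-channel (background/tumour) activation map, that the tumour channel (index 1) serves as the CAM, that maps are brought to the input resolution by nearest-neighbour upsampling with integer factors, and that the attention scores are normalised across the four exits with a softmax. The attention network $A(\cdot)$ itself and the contrastive loss of Eq. (3) are omitted (zero placeholder scores are used), all helper names are hypothetical, and in practice these operations would run inside an autodiff framework such as PyTorch.

```python
import numpy as np


def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def multi_exit_loss(maps, label):
    """Eq. (1): sum_i CE(GAP(M_i), L) for one image; maps[i] has shape (C, H_i, W_i)."""
    loss = 0.0
    for m in maps:
        logits = m.mean(axis=(1, 2))  # global average pooling -> (C,)
        loss += -np.log(softmax(logits, axis=0)[label])
    return loss


def upsample(m, size):
    """Nearest-neighbour upsampling of a 2-D map by integer factors."""
    fy, fx = size[0] // m.shape[0], size[1] // m.shape[1]
    return np.repeat(np.repeat(m, fy, axis=0), fx, axis=1)


def minmax(m):
    """n(.): min-max normalisation to [0, 1]."""
    return (m - m.min()) / (m.max() - m.min() + 1e-8)


def attention_input(image, cams):
    """[I (x) n(M_i)] for i = 1..4: the image masked by each normalised map, stacked."""
    return np.stack([image * minmax(c) for c in cams])


def aggregate(cams, scores):
    """M_f = sum_i S_i (x) M_i with pixel-wise scores of shape (4, H, W)."""
    weights = softmax(scores, axis=0)  # normalisation over the exits is an assumption
    return (weights * cams).sum(axis=0)


def final_cam(cams, m_f):
    """Average M_1..M_4 and M_f (DenseCRF post-processing not shown)."""
    return np.concatenate([cams, m_f[None]], axis=0).mean(axis=0)


rng = np.random.default_rng(0)
image = rng.random((64, 64))
maps = [rng.normal(size=(2, 64 // 2 ** k, 64 // 2 ** k)) for k in range(4)]  # one map per exit
loss = multi_exit_loss(maps, label=1)
cams = np.stack([minmax(upsample(m[1], image.shape)) for m in maps])  # tumour channel
x_att = attention_input(image, cams)  # would be fed to the attention network A
m_f = aggregate(cams, scores=np.zeros((4, 64, 64)))  # placeholder attention scores
cam = final_cam(cams, m_f)
```

With uniform scores the aggregation reduces to equal weighting, i.e. a LayerCAM-style average; a trained attention network is what lets the weights vary per pixel and per exit.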
3 Experiments
3.1 Dataset
We evaluate our method on the Brain Tumor Segmentation challenge (BraTS) dataset [1,2,14], which contains 2,000 cases, each of which includes four 3D volumes from four different MRI modalities: T1, post-contrast enhanced T1 (T1CE), T2, and T2 Fluid Attenuated Inversion Recovery (T2-FLAIR), as well as a corresponding segmentation ground-truth mask. The official data split divides these cases by the ratio of 8:1:1 for training, validation, and testing (5,802 positive and 1,073 negative images). To evaluate performance, we use the validation set as our test set and report statistics on it. We preprocess the data by slicing each volume along the z-axis to form a total of 193,905 2D images, following the approach of Kang et al. [10] and Dey and Hong [6]. We use the ground-truth segmentation masks only in the final evaluation, not in the training process.
3.2 Implementation Details and Evaluation Protocol
We implement our method in PyTorch using ResNet-18 as the backbone classifier. We pretrain the classifier using SupCon [11] and then fine-tune it in our experiments, using the entire training set for both pretraining and fine-tuning. In both phases, we set the initial learning rate to 1e-4 and use a cosine annealing scheduler to decrease it to a minimum of 5e-6. We also set the weight decay to 1e-5 in both phases for regularization. We use the Adam optimizer in the multiple-exit phase and the SGD optimizer in the aggregation phase. We train all classifiers until they converge with a test accuracy above 0.9 for all image modalities. Note that only class labels are available in the training set. Following Xu et al. [23], Tang et al. [18], and Qian et al. [15], we use the Dice score and Intersection over Union (IoU) to evaluate the quality of the semantic segmentation. In addition, we report the 95% Hausdorff Distance (HD95) to evaluate the boundary of the predicted mask. Results for other network architectures are provided in the supplementary material.
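The three evaluation metrics can be sketched for 2D binary masks in plain Python. The sketch rests on our own assumptions: 4-connected boundary extraction, Euclidean distances in pixel units, and a linear-interpolation percentile. The text does not specify the exact boundary and percentile conventions behind the reported HD95 values, and all function names are hypothetical.

```python
import math

def dice(pred, gt):
    """Dice = 2|A ∩ B| / (|A| + |B|) for binary masks given as lists of rows."""
    inter = sum(p * g for pr, gr in zip(pred, gt) for p, g in zip(pr, gr))
    total = sum(map(sum, pred)) + sum(map(sum, gt))
    return 2 * inter / total if total else 1.0

def iou(pred, gt):
    """IoU = |A ∩ B| / |A ∪ B|."""
    inter = sum(p * g for pr, gr in zip(pred, gt) for p, g in zip(pr, gr))
    union = sum(1 for pr, gr in zip(pred, gt) for p, g in zip(pr, gr) if p or g)
    return inter / union if union else 1.0

def boundary(mask):
    """Foreground pixels with at least one 4-neighbour outside the mask."""
    h, w = len(mask), len(mask[0])
    pts = []
    for y in range(h):
        for x in range(w):
            if not mask[y][x]:
                continue
            for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                ny, nx = y + dy, x + dx
                if not (0 <= ny < h and 0 <= nx < w) or not mask[ny][nx]:
                    pts.append((y, x))
                    break
    return pts

def percentile(values, q):
    """Percentile with linear interpolation between the closest ranks."""
    s = sorted(values)
    pos = (len(s) - 1) * q / 100
    lo = math.floor(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)

def hd95(pred, gt):
    """95th percentile of the symmetric boundary-to-boundary distances."""
    a, b = boundary(pred), boundary(gt)
    if not a or not b:
        return float("inf")
    d_ab = [min(math.dist(p, q) for q in b) for p in a]
    d_ba = [min(math.dist(q, p) for p in a) for q in b]
    return percentile(d_ab + d_ba, 95)
```

For a single image, IoU = Dice / (2 − Dice). This identity does not hold exactly for the per-image means reported in the tables.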
Conditional Diffusion Models for WSSS
Table 1. Comparison with weakly supervised methods (WSSS), an unsupervised method (UL), and fully supervised methods (FSL) on the BraTS dataset with T1, T1-CE, T2, and T2-FLAIR MRI images. Results are reported as mean ± std. The highest score among WSSS methods is marked in bold.

BraTS T1
| Type | Method | Dice ↑ | IoU ↑ | HD95 ↓ |
|---|---|---|---|---|
| WSSS | Grad-CAM (2016) | 0.107 ± 0.090 | 0.059 ± 0.055 | 121.816 ± 22.963 |
| WSSS | ScoreCAM (2020) | 0.296 ± 0.128 | 0.181 ± 0.089 | 60.302 ± 14.110 |
| WSSS | LFI-CAM (2021) | 0.568 ± 0.167 | 0.414 ± 0.152 | 23.939 ± 25.609 |
| WSSS | LayerCAM (2021) | 0.571 ± 0.170 | 0.419 ± 0.161 | 23.335 ± 27.369 |
| WSSS | Swin-MIL (2022) | 0.477 ± 0.170 | 0.330 ± 0.147 | 46.468 ± 30.408 |
| WSSS | AME-CAM (ours) | **0.631 ± 0.119** | **0.471 ± 0.119** | **21.813 ± 18.219** |
| UL | C&F (2020) | 0.200 ± 0.082 | 0.113 ± 0.051 | 79.187 ± 14.304 |
| FSL | C&F (2020) | 0.572 ± 0.196 | 0.426 ± 0.187 | 29.027 ± 20.881 |
| FSL | Opt. U-net (2021) | 0.836 ± 0.062 | 0.723 ± 0.090 | 11.730 ± 10.345 |

BraTS T1-CE
| Type | Method | Dice ↑ | IoU ↑ | HD95 ↓ |
|---|---|---|---|---|
| WSSS | Grad-CAM (2016) | 0.127 ± 0.088 | 0.071 ± 0.054 | 129.890 ± 27.854 |
| WSSS | ScoreCAM (2020) | 0.397 ± 0.189 | 0.267 ± 0.163 | 46.834 ± 22.093 |
| WSSS | LFI-CAM (2021) | 0.121 ± 0.120 | 0.069 ± 0.076 | 136.246 ± 38.619 |
| WSSS | LayerCAM (2021) | 0.510 ± 0.209 | 0.367 ± 0.180 | 29.850 ± 45.877 |
| WSSS | Swin-MIL (2022) | 0.460 ± 0.169 | 0.314 ± 0.140 | 46.996 ± 22.821 |
| WSSS | AME-CAM (ours) | **0.695 ± 0.095** | **0.540 ± 0.108** | **18.129 ± 12.335** |
| UL | C&F (2020) | 0.179 ± 0.080 | 0.101 ± 0.050 | 77.982 ± 14.042 |
| FSL | C&F (2020) | 0.246 ± 0.104 | 0.144 ± 0.070 | 130.616 ± 9.879 |
| FSL | Opt. U-net (2021) | 0.845 ± 0.058 | 0.736 ± 0.085 | 11.593 ± 11.120 |

BraTS T2
| Type | Method | Dice ↑ | IoU ↑ | HD95 ↓ |
|---|---|---|---|---|
| WSSS | Grad-CAM (2016) | 0.049 ± 0.058 | 0.026 ± 0.034 | 141.025 ± 23.107 |
| WSSS | ScoreCAM (2020) | 0.530 ± 0.184 | 0.382 ± 0.174 | 28.611 ± 11.596 |
| WSSS | LFI-CAM (2021) | 0.673 ± 0.173 | 0.531 ± 0.186 | 18.165 ± 10.475 |
| WSSS | LayerCAM (2021) | 0.624 ± 0.178 | 0.476 ± 0.173 | 23.978 ± 44.323 |
| WSSS | Swin-MIL (2022) | 0.437 ± 0.149 | 0.290 ± 0.117 | 38.006 ± 30.000 |
| WSSS | AME-CAM (ours) | **0.721 ± 0.086** | **0.571 ± 0.101** | **14.940 ± 8.736** |
| UL | C&F (2020) | 0.230 ± 0.089 | 0.133 ± 0.058 | 76.256 ± 13.192 |
| FSL | C&F (2020) | 0.611 ± 0.221 | 0.474 ± 0.217 | 109.817 ± 27.735 |
| FSL | Opt. U-net (2021) | 0.884 ± 0.064 | 0.798 ± 0.098 | 8.349 ± 9.125 |

BraTS T2-FLAIR
| Type | Method | Dice ↑ | IoU ↑ | HD95 ↓ |
|---|---|---|---|---|
| WSSS | Grad-CAM (2016) | 0.150 ± 0.077 | 0.083 ± 0.050 | 110.031 ± 23.307 |
| WSSS | ScoreCAM (2020) | 0.432 ± 0.209 | 0.299 ± 0.178 | 39.385 ± 17.182 |
| WSSS | LFI-CAM (2021) | 0.161 ± 0.192 | 0.102 ± 0.140 | 125.749 ± 45.582 |
| WSSS | LayerCAM (2021) | 0.652 ± 0.206 | 0.515 ± 0.210 | 22.055 ± 33.959 |
| WSSS | Swin-MIL (2022) | 0.272 ± 0.115 | 0.163 ± 0.079 | 41.870 ± 19.231 |
| WSSS | AME-CAM (ours) | **0.862 ± 0.088** | **0.767 ± 0.122** | **8.664 ± 6.440** |
| UL | C&F (2020) | 0.306 ± 0.190 | 0.199 ± 0.167 | 75.651 ± 14.214 |
| FSL | C&F (2020) | 0.578 ± 0.137 | 0.419 ± 0.130 | 138.138 ± 14.283 |
| FSL | Opt. U-net (2021) | 0.914 ± 0.058 | 0.847 ± 0.093 | 8.093 ± 11.879 |
4 Results
Fig. 2. Qualitative results of all methods. (a) Input image. (b) Ground truth. (c) Grad-CAM [16]. (d) ScoreCAM [19]. (e) LFI-CAM [13]. (f) LayerCAM [9]. (g) Swin-MIL [15]. (h) AME-CAM (ours). Rows 1–4 show the T1, T1-CE, T2, and T2-FLAIR modalities of the BraTS dataset, respectively.
4.1 Quantitative and Qualitative Comparison with State-of-the-Art
In this section, we compare the segmentation performance of the proposed AME-CAM with five state-of-the-art weakly supervised segmentation methods: Grad-CAM [16], ScoreCAM [19], LFI-CAM [13], LayerCAM [9], and Swin-MIL [15]. For comparison with non-CAM-based methods, we also include the unsupervised approach C&F [5], the supervised version of C&F, and the supervised Optimized U-net [8]. We acknowledge that the results of fully supervised and unsupervised methods are not directly comparable to those of weakly supervised CAM methods. Nonetheless, they serve as useful references for the potential performance ceiling and floor of the CAM methods. Quantitatively (Table 1), Grad-CAM and ScoreCAM yield low Dice scores, suggesting that they have difficulty producing meaningful activations on medical images. LFI-CAM and LayerCAM improve the Dice score in all modalities, except for LFI-CAM on T1-CE and T2-FLAIR. The proposed AME-CAM achieves the best performance among the weakly supervised methods in all modalities of the BraTS dataset. The unsupervised baseline C&F (UL) cannot separate the tumor from the surrounding tissue because of the low contrast, resulting in low Dice scores in all experiments. With pixel-wise labels, the Dice score of supervised C&F
improves significantly. Without any pixel-wise labels, the proposed AME-CAM outperforms supervised C&F in all modalities. The fully supervised (FSL) Optimized U-net achieves the highest Dice and IoU scores in all experiments. However, a performance gap remains between the weakly supervised CAM methods and the fully supervised state of the art, which indicates room for future improvement of WSSS methods.

Qualitatively, Fig. 2 visualizes the CAM and segmentation results of all six CAM-based approaches on the four modalities of the BraTS dataset. Grad-CAM (Fig. 2(c)) produces large false-activation regions, which render its segmentation mask meaningless. ScoreCAM eliminates the false activation corresponding to air. LFI-CAM focuses on the exact tumor area only in the T1 and T2 MRI (rows 1 and 3). Swin-MIL hardly captures the tumor region, and its activation is noisy. Only LayerCAM and the proposed AME-CAM successfully focus on the exact tumor area, and AME-CAM reduces the under-estimation of the tumor area. We attribute this to the aggregation of activation maps from different resolutions.

4.2 Ablation Study
Table 2. Ablation study for the aggregation phase using T1 MRI images from the BraTS dataset. Avg. ME denotes directly averaging the four activation maps generated by the multiple-exit phase. The Dice score, IoU, and HD95 are reported as mean ± std.

| Method | Dice ↑ | IoU ↑ | HD95 ↓ |
|---|---|---|---|
| Avg. ME | 0.617 ± 0.121 | 0.457 ± 0.121 | 23.603 ± 20.572 |
| Avg. ME+C2AM [22] | 0.484 ± 0.256 | 0.354 ± 0.207 | 69.242 ± 121.163 |
| AME-CAM (ours) | 0.631 ± 0.119 | 0.471 ± 0.119 | 21.813 ± 18.219 |
Effect of Different Aggregation Approaches: Table 2 reports an ablation study of different aggregation approaches applied after extracting activations from the multiple-exit network. The study aims to demonstrate the advantage of the proposed attention-based aggregation for segmenting tumor regions. Here we report results only for T1 MRI of the BraTS dataset; the full set of experiments is in the supplementary material. As a baseline, we first average the four activation maps generated by the multiple-exit phase (Avg. ME). We then apply C2AM [22], a state-of-the-art CAM-based refinement approach, to refine the baseline result, which we call "Avg. ME+C2AM". However, we observe that C2AM tends to segment the brain region instead of the tumor region due
to the larger contrast between the brain tissue and the air than between the tumor region and its surrounding tissue. Any incorrect activation of C2AM also propagates to the result, degrading the average Dice score from 0.617 to 0.484. In contrast, the proposed attention-based approach learns a pixel-wise weighting that yields the best performance on all three metrics.

Table 3. Ablation study using a single exit (M1, M2, M3, or M4 of Fig. 1), multiple exits using M2 and M3, and all exits (AME-CAM). The experiments are done on T1-CE MRI of the BraTS dataset. The Dice score, IoU, and HD95 are reported as mean ± std.
| Type | Selected exit | Dice ↑ | IoU ↑ | HD95 ↓ |
|---|---|---|---|---|
| Single-exit | M1 | 0.144 ± 0.184 | 0.090 ± 0.130 | 74.249 ± 62.669 |
| Single-exit | M2 | 0.500 ± 0.231 | 0.363 ± 0.196 | 43.762 ± 85.703 |
| Single-exit | M3 | 0.520 ± 0.163 | 0.367 ± 0.141 | 43.749 ± 54.907 |
| Single-exit | M4 | 0.154 ± 0.101 | 0.087 ± 0.065 | 120.779 ± 44.548 |
| Multiple-exit | M2 + M3 | 0.566 ± 0.207 | 0.421 ± 0.186 | 27.972 ± 56.591 |
| Multiple-exit | AME-CAM (ours) | 0.695 ± 0.095 | 0.540 ± 0.108 | 18.129 ± 12.335 |
Effect of Single-Exit and Multiple-Exit: Table 3 summarizes the performance on T1-CE MRI of the BraTS dataset for three settings: a single exit (M1, M2, M3, or M4 of Fig. 1), multiple exits using M2 and M3, and all exits (AME-CAM). The activation maps from the shallowest exit M1 and the deepest exit M4 both yield low Dice scores of around 0.15. This is because the shallow layer is not deep enough to have learned the tumor region. The activation map from the deepest layer, in turn, has too low a resolution to delineate a clear tumor boundary. The internal classifiers in the middle of the network (M2 and M3) achieve the highest single-exit scores: Dice scores of around 0.5 (0.500 and 0.520) and IoU of around 0.36. To evaluate whether using all internal classifiers leads to the highest performance, we further apply the proposed method to only the two internal classifiers with the highest Dice scores, i.e., M2 and M3 (denoted M2 + M3). Compared with using all internal classifiers (M1 to M4), M2 + M3 yields 18.6% and 22.1% lower Dice and IoU, respectively. In conclusion, AME-CAM achieves the best performance among all single-exit and multiple-exit configurations. Other ablation studies are presented in the supplementary material due to space limitations.
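The relative decreases quoted for M2 + M3 can be reproduced from the rounded means in Table 3. Computed from the rounded entries, the IoU decrease comes out at about 22.0% rather than 22.1%; we assume the difference stems from rounding of the reported means.

```python
def relative_drop(reference, value):
    """Percentage by which `value` is lower than `reference`."""
    return 100 * (reference - value) / reference

dice_drop = relative_drop(0.695, 0.566)  # AME-CAM vs. M2 + M3, Dice (Table 3)
iou_drop = relative_drop(0.540, 0.421)   # AME-CAM vs. M2 + M3, IoU (Table 3)
print(f"Dice: {dice_drop:.1f}%  IoU: {iou_drop:.1f}%")  # prints "Dice: 18.6%  IoU: 22.0%"
```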
5 Conclusion
In this work, we propose a brain tumor segmentation method for MRI images using only class labels, based on an Attentive Multiple-Exit Class Activation
Mapping (AME-CAM). Our approach extracts activation maps from different exits of the network to capture information at multiple resolutions. It then uses an attention model to hierarchically aggregate these activation maps by learning pixel-wise weighted sums. Experimental results on the four modalities of the 2021 BraTS dataset demonstrate the superiority of our approach over other CAM-based weakly supervised segmentation methods: AME-CAM achieves the highest mean Dice score among them in all four modalities. These results indicate that the proposed approach can accurately segment brain tumors in MRI images using only class labels.
References
1. Bakas, S., et al.: Advancing the cancer genome atlas glioma mri collections with expert segmentation labels and radiomic features. Scientific Data 4(1), 1–13 (2017)
2. Bakas, S., et al.: Identifying the best machine learning algorithms for brain tumor segmentation, progression assessment, and overall survival prediction in the brats challenge. arXiv preprint arXiv:1811.02629 (2018)
3. Belharbi, S., Sarraf, A., Pedersoli, M., Ben Ayed, I., McCaffrey, L., Granger, E.: F-cam: full resolution class activation maps via guided parametric upscaling. In: Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 3490–3499 (2022)
4. Chan, L., Hosseini, M.S., Plataniotis, K.N.: A comprehensive analysis of weakly-supervised semantic segmentation in different image domains. Int. J. Comput. Vision 129, 361–384 (2021)
5. Chen, J., Frey, E.C.: Medical image segmentation via unsupervised convolutional neural network. arXiv preprint arXiv:2001.10155 (2020)
6. Dey, R., Hong, Y.: ASC-Net: adversarial-based selective network for unsupervised anomaly segmentation. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12905, pp. 236–247. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87240-3_23
7. Englebert, A., Cornu, O., De Vleeschouwer, C.: Poly-cam: high resolution class activation map for convolutional neural networks. arXiv preprint arXiv:2204.13359 (2022)
8. Futrega, M., Milesi, A., Marcinkiewicz, M., Ribalta, P.: Optimized u-net for brain tumor segmentation. arXiv preprint arXiv:2110.03352 (2021)
9. Jiang, P.T., Zhang, C.B., Hou, Q., Cheng, M.M., Wei, Y.: Layercam: exploring hierarchical class activation maps for localization. IEEE Trans. Image Process. 30, 5875–5888 (2021)
10. Kang, H., Park, H.m., Ahn, Y., Van Messem, A., De Neve, W.: Towards a quantitative analysis of class activation mapping for deep learning-based computer-aided diagnosis. In: Medical Imaging 2021: Image Perception, Observer Performance, and Technology Assessment, vol. 11599, p. 115990M. International Society for Optics and Photonics (2021)
11. Khosla, P., et al.: Supervised contrastive learning. arXiv preprint arXiv:2004.11362 (2020)
12. Krähenbühl, P., Koltun, V.: Efficient inference in fully connected crfs with gaussian edge potentials. In: Advances in Neural Information Processing Systems 24 (2011)
13. Lee, K.H., Park, C., Oh, J., Kwak, N.: Lfi-cam: learning feature importance for better visual explanation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 1355–1363 (2021)
14. Menze, B.H., et al.: The multimodal brain tumor image segmentation benchmark (brats). IEEE Trans. Med. Imaging 34(10), 1993–2024 (2014)
15. Qian, Z., et al.: Transformer based multiple instance learning for weakly supervised histopathology image segmentation. In: Medical Image Computing and Computer Assisted Intervention – MICCAI 2022: 25th International Conference, Singapore, September 18–22, 2022, Proceedings, Part II, pp. 160–170. Springer (2022). https://doi.org/10.1007/978-3-031-16434-7_16
16. Selvaraju, R.R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., Batra, D.: Grad-cam: Visual explanations from deep networks via gradient-based localization. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 618–626 (2017)
17. Tagaris, T., Sdraka, M., Stafylopatis, A.: High-resolution class activation mapping. In: 2019 IEEE International Conference On Image Processing (ICIP), pp. 4514–4518. IEEE (2019)
18. Tang, W., et al.: M-SEAM-NAM: multi-instance self-supervised equivalent attention mechanism with neighborhood affinity module for double weakly supervised segmentation of COVID-19. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12907, pp. 262–272. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87234-2_25
19. Wang, H., et al.: Score-cam: score-weighted visual explanations for convolutional neural networks. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops, pp. 24–25 (2020)
20. Wang, Y., Zhang, J., Kan, M., Shan, S., Chen, X.: Self-supervised equivariant attention mechanism for weakly supervised semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12275–12284 (2020)
21. Wolleb, J., Bieder, F., Sandkühler, R., Cattin, P.C.: Diffusion models for medical anomaly detection. In: Medical Image Computing and Computer Assisted Intervention – MICCAI 2022: 25th International Conference, Singapore, 18–22 September 2022, Proceedings, Part VIII, pp. 35–45. Springer (2022). https://doi.org/10.1007/978-3-031-16452-1_4
22. Xie, J., Xiang, J., Chen, J., Hou, X., Zhao, X., Shen, L.: C2am: contrastive learning of class-agnostic activation map for weakly supervised object localization and semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 989–998 (2022)
23. Xu, X., et al.: Whole heart and great vessel segmentation in congenital heart disease using deep neural networks and graph matching. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 477–485. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8_53
24. Yu, F., Koltun, V., Funkhouser, T.: Dilated residual networks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 472–480 (2017)
Cross-Adversarial Local Distribution Regularization for Semi-supervised Medical Image Segmentation

Thanh Nguyen-Duc¹ (✉), Trung Le¹, Roland Bammer¹, He Zhao², Jianfei Cai¹, and Dinh Phung¹

¹ Monash University, Melbourne, Australia
{thanh.nguyen4,trunglm,roland.bammer,jianfei.cai,dinh.phung}@monash.edu
² CSIRO's Data61, Melbourne, Australia
Abstract. Medical semi-supervised segmentation trains a model to segment objects of interest in medical images with limited annotated data. Existing semi-supervised segmentation methods are usually based on the smoothness assumption, which encourages the model output distributions of two similar data samples to be invariant. In other words, the smoothness assumption states that similar samples (e.g., an image and a copy with small perturbations added) should have similar outputs. In this paper, we introduce a novel cross-adversarial local distribution (Cross-ALD) regularization to further enhance the smoothness assumption for the semi-supervised medical image segmentation task. Comprehensive experiments show that Cross-ALD achieves state-of-the-art performance against many recent methods on the public LA and ACDC datasets.
Keywords: Semi-supervised segmentation · Adversarial local distribution · Adversarial examples · Cross-adversarial local distribution
1 Introduction
Medical image segmentation is a critical task in computer-aided diagnosis and treatment planning. It involves delineating anatomical structures or pathological regions in medical images, such as magnetic resonance imaging (MRI) or computed tomography (CT) scans. Accurate and efficient segmentation is essential for various medical applications, including tumor detection, surgical planning, and monitoring disease progression. However, manual annotation of medical images is time-consuming and expensive because it requires domain knowledge from medical experts.

Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_18.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 183–194, 2023.
https://doi.org/10.1007/978-3-031-43907-0_18

T. Nguyen-Duc et al.

Therefore, there is growing interest in semi-supervised learning methods that leverage both labeled and unlabeled data to improve the performance of image segmentation models [16,27]. Existing semi-supervised segmentation methods exploit the smoothness assumption, i.e., data samples that are close to each other are more likely to have the same label. In other words, the smoothness assumption encourages the model to generate invariant outputs under small perturbations. Such perturbations have been added to natural input images at the data level [4,9,14,19,21], the feature level [6,17,23,25], and the model level [8,11,12,24,28]. Among them, virtual adversarial training (VAT) [14] is a well-known method that promotes the smoothness of the local output distribution using adversarial examples. These adversarial examples lie near decision boundaries and are generated by adding adversarial perturbations to natural inputs. However, VAT creates only one adversarial example per input in a run, which is often insufficient to explore the space of possible perturbations (see Sect. 2.1). In addition, the adversarial examples of VAT can cluster together and lose diversity, which significantly reduces their quality [15,20].

Mixup regularization [29] is a data augmentation method used in deep learning to improve model generalization. Mixup creates new training examples by linearly interpolating between pairs of existing examples and their corresponding labels; [2,3,19] adopted this idea for semi-supervised learning. Prior work [5] suggests that mixup improves the smoothness of the neural function by bounding the Lipschitz constant of the gradient function of the neural network. However, we show that mixing more informative samples (e.g., adversarial examples near decision boundaries) yields a larger performance gain than mixing natural samples (see Sect. 3.3).
In this paper, we propose a novel cross-adversarial local distribution regularization for semi-supervised medical image segmentation that enhances the smoothness assumption¹. Our contributions are summarized as follows:
1) To overcome the drawback of VAT, we formulate an adversarial local distribution (ALD) with a Dice loss function that covers all possible adversarial examples within a ball constraint.
2) To enhance the smoothness assumption, we propose a novel cross-adversarial local distribution regularization (Cross-ALD), which is a random mixture of two ALDs.
3) We propose an approximation of Cross-ALD via a multiple particle-based search using semantic feature Stein Variational Gradient Descent (SVGDF), an enhancement of the vanilla SVGD [10].
4) We conduct comprehensive experiments on the ACDC [1] and LA [26] datasets, showing that our Cross-ALD regularization achieves state-of-the-art performance against existing solutions [8,11,12,14,21,22,28].
¹ The Cross-ALD implementation is available at https://github.com/PotatoThanh/Crossadversarial-local-distribution-regularization.
Cross-Adversarial Local Distribution Regularization
2 Method
In this section, we begin by reviewing the minimax optimization problem of virtual adversarial training (VAT) [14]. Given an input, we then formulate a novel adversarial local distribution (ALD) with Dice loss, which specifically benefits semi-supervised medical image segmentation. Next, we construct a cross-adversarial local distribution (Cross-ALD) by randomly combining two ALDs. We approximate the ALD with a particle-based method named semantic feature Stein Variational Gradient Descent (SVGDF). Because medical images usually have high resolution, SVGDF extends the vanilla SVGD [10] from the data level to the feature level. Finally, we present our regularization loss for semi-supervised medical image segmentation.

2.1 The Minimax Optimization of VAT
Let $D_l$ and $D_{ul}$ be the labeled and unlabeled datasets, respectively, with $P_{D_l}$ and $P_{D_{ul}}$ the corresponding data distributions. Denote by $x \in \mathbb{R}^d$ our d-dimensional input in a space $\mathcal{X}$. The labeled image $x_l$ and segmentation ground truth $y$ are sampled from the labeled dataset $D_l$, i.e., $(x_l, y) \sim P_{D_l}$, and the unlabeled image sampled from $D_{ul}$ is $x \sim P_{D_{ul}}$. Given an input $x \sim P_{D_{ul}}$ (i.e., from the unlabeled data distribution), we denote the ball constraint around the image $x$ as $\mathcal{C}_\epsilon(x) = \{x' \in \mathcal{X} : \|x' - x\|_p \le \epsilon\}$. Here $\epsilon$ is the radius of the ball with respect to the norm $\|\cdot\|_p$, and $x'$ is an adversarial example². Given a model $f_\theta$ parameterized by $\theta$, VAT [14] trains the model with the loss $\ell_{vat}$, defined by the minimax optimization problem

$$\ell_{vat} := \min_\theta \; \mathbb{E}_{x \sim P_{D_{ul}}} \left[ \max_{x' \in \mathcal{C}_\epsilon(x)} D_{KL}\big(f_\theta(x'), f_\theta(x)\big) \right], \qquad (1)$$

where $D_{KL}$ is the Kullback-Leibler divergence. The inner maximization problem finds an adversarial example near decision boundaries, while the outer minimization problem enforces the local smoothness of the model. However, VAT cannot sufficiently explore the set of all adversarial examples within the constraint $\mathcal{C}_\epsilon$ because it finds only one adversarial example $x'$ for a given natural input $x$. Moreover, [15,20] show that even when the maximization problem is solved with random initialization, its solutions can cluster together and lose diversity, which significantly reduces the quality of the adversarial examples.

2.2 Adversarial Local Distribution
To overcome this drawback of VAT, we introduce the proposed adversarial local distribution (ALD), which uses the Dice loss function instead of the $D_{KL}$ used in [14,15]. ALD forms a set of all adversarial examples $x'$ within the ball constraint given an input $x$, so the distribution helps to explore all possible adversarial examples. The adversarial local distribution $P_\theta(x' \mid x)$ is defined over the ball constraint $\mathcal{C}_\epsilon$ as follows:

$$P_\theta(x' \mid x) := \frac{e^{\ell_{Dice}(x', x; \theta)}}{Z(x; \theta)} = \frac{e^{\ell_{Dice}(x', x; \theta)}}{\int_{\mathcal{C}_\epsilon(x)} e^{\ell_{Dice}(x', x; \theta)} \, dx'}, \qquad (2)$$

where $P_\theta(\cdot \mid x)$ is the conditional local distribution and $Z(x; \theta)$ is a normalization function. $\ell_{Dice}$ is the Dice loss function given in Eq. 3:

$$\ell_{Dice}(x', x; \theta) = \frac{1}{C} \sum_{c=1}^{C} \left[ 1 - \frac{2\,\|p_\theta(\hat{y}_c \mid x) \cap p_\theta(\tilde{y}_c \mid x')\|}{\|p_\theta(\hat{y}_c \mid x) + p_\theta(\tilde{y}_c \mid x')\|} \right], \qquad (3)$$

where $C$ is the number of classes, and $p_\theta(\hat{y}_c \mid x)$ and $p_\theta(\tilde{y}_c \mid x')$ are the predictions for the input image $x$ and the adversarial image $x'$, respectively.

² A sample generated by adding perturbations toward the adversarial direction.

2.3 Cross-Adversarial Distribution Regularization
Given two random samples $x_i, x_j \sim P_D$ ($i \neq j$), we define the cross-adversarial local distribution (Cross-ALD), denoted $\tilde{P}_\theta$, as in Eq. 4:

$$\tilde{P}_\theta(\cdot \mid x_i, x_j) = \gamma P_\theta(\cdot \mid x_i) + (1 - \gamma) P_\theta(\cdot \mid x_j), \qquad (4)$$

where $\gamma \sim \mathrm{Beta}(\alpha, \alpha)$ for $\alpha \in (0, \infty)$, inspired by [29]. $\tilde{P}_\theta$ is the Cross-ALD distribution, a mixture of the two adversarial local distributions. Given Eq. 4, we propose the Cross-ALD regularization at two random input images $x_i, x_j \sim P_D$ ($i \neq j$) as

$$R(\theta, x_i, x_j) := \mathbb{E}_{\tilde{x}' \sim \tilde{P}_\theta(\cdot \mid x_i, x_j)} \left[ \log \tilde{P}_\theta(\tilde{x}' \mid x_i, x_j) \right] = -H\big(\tilde{P}_\theta(\cdot \mid x_i, x_j)\big), \qquad (5)$$

where $H$ denotes the entropy of a given distribution. Minimizing $R(\theta, x_i, x_j)$ w.r.t. $\theta$, or equivalently minimizing $-H(\tilde{P}_\theta(\cdot \mid x_i, x_j))$, pushes $\tilde{P}_\theta(\cdot \mid x_i, x_j)$ closer to a uniform distribution. This implies that the outputs satisfy $f(\tilde{x}) = f(\tilde{x}') = c$ for a constant $c$, where $\tilde{x}, \tilde{x}' \sim \tilde{P}_\theta(\cdot \mid x_i, x_j)$. In other words, we encourage invariant model outputs under small perturbations. Therefore, minimizing the Cross-ALD regularization loss enhances model smoothness. While VAT only enforces local smoothness using one adversarial example, Cross-ALD further encourages smoothness of both the local and the mixed adversarial distributions to improve model generalization.

2.4 Multiple Particle-Based Search to Approximate the Cross-ALD Regularization
The normalization $Z(x; \theta)$ in the denominator of Eq. 2 is intractable. Therefore, we propose a multiple particle-based search method named SVGDF to sample $x'^{(1)}, x'^{(2)}, \ldots, x'^{(N)} \sim P_\theta(\cdot \mid x)$, where $N$ is the number of samples (or adversarial particles). SVGDF is a particle-based Bayesian inference algorithm that seeks a set of points (or particles) to approximate the target distribution $P_\theta(\cdot \mid x)$. It makes no explicit parametric assumptions and relies on iterative gradient-based updates. Specifically, a set of adversarial particles $\{x'^{(n)}\}$ is initialized by adding uniform noise to the input and then projected onto the ball $\mathcal{C}_\epsilon$. These adversarial particles are then iteratively updated using the closed-form update in Eq. 6 until a termination condition is reached (e.g., a maximum number of iterations).
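$$x'^{(n),(l+1)} = \Pi_{\mathcal{C}_\epsilon}\Big( x'^{(n),(l)} + \tau \, \phi\big(x'^{(n),(l)}\big) \Big)$$
$$\text{s.t.}\quad \phi(x') = \frac{1}{N} \sum_{j=1}^{N} \Big[ k\big(\Phi(x'^{(j),(l)}), \Phi(x')\big) \, \nabla_{x'^{(j),(l)}} \log P\big(x'^{(j),(l)} \mid x\big) + \nabla_{x'^{(j),(l)}} k\big(\Phi(x'^{(j),(l)}), \Phi(x')\big) \Big], \qquad (6)$$

The symbols in Eq. 6 are defined as follows:

* $x'^{(n),(l)}$ is the $n$-th adversarial particle at the $l$-th iteration, with $n \in \{1, 2, \ldots, N\}$ and $l \in \{1, 2, \ldots, L\}$, where $L$ is the maximum number of iterations.
* $\Pi_{\mathcal{C}_\epsilon}$ is the projection operator onto the constraint $\mathcal{C}_\epsilon$.
* $\tau$ is the update step size.
* $k$ is the radial basis function (RBF) kernel $k(x', x) = \exp\!\big(-\|x' - x\|^2 / (2\sigma^2)\big)$.
* $\Phi$ is a fixed feature extractor (e.g., the encoder of U-Net/V-Net).

Since $Z(x; \theta)$ does not depend on $x'$, $\nabla_{x'} \log P_\theta(x' \mid x) = \nabla_{x'} \ell_{Dice}(x', x; \theta)$. The intractable normalization therefore does not enter the update. Because vanilla SVGD [10] computes the RBF kernel $k$ directly at the data level, it has difficulty capturing the semantic meaning of high-resolution data. We therefore use the feature extractor $\Phi$ as a semantic transformation to enhance the performance of SVGD for medical imaging. The two terms of $\phi$ in Eq. 6 play different roles. The first encourages the adversarial particles to move towards high-density areas of $P_\theta(\cdot \mid x)$. The second prevents the particles from collapsing into the local modes of $P_\theta(\cdot \mid x)$ and thereby enhances diversity (i.e., it pushes the particles away from each other). Please refer to the Cross-ALD GitHub repository for more details.

SVGDF approximates $P_\theta(\cdot \mid x_i)$ and $P_\theta(\cdot \mid x_j)$ in Eq. 4, where $x_i, x_j \sim P_{D_{ul}}$ ($i \neq j$). We form the sets of adversarial particles $D_{adv} \mid x_i = \{x_i'^{(1)}, x_i'^{(2)}, \ldots, x_i'^{(N)}\}$ and $D_{adv} \mid x_j = \{x_j'^{(1)}, x_j'^{(2)}, \ldots, x_j'^{(N)}\}$. Problem (5) can then be relaxed to

$$R(\theta, x_i, x_j) := \mathbb{E}_{x_i'^{(n)} \sim P_{D_{adv} \mid x_i},\; x_j'^{(m)} \sim P_{D_{adv} \mid x_j}} \Big[ \ell_{Dice}(\tilde{x}', \tilde{x}; \theta) \Big] \quad \text{s.t.}\quad \tilde{x} = \gamma x_i + (1 - \gamma) x_j, \;\; \tilde{x}' = \gamma x_i'^{(n)} + (1 - \gamma) x_j'^{(m)}, \qquad (7)$$

where $\gamma \sim \mathrm{Beta}(\alpha, \alpha)$ for $\alpha \in (0, \infty)$.

2.5 Cross-ALD Regularization Loss in Medical Semi-supervised Image Segmentation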
In this paper, the overall loss function $\ell_{total}$ consists of three loss terms. The first term is the Dice loss, where the labeled image $x_l$ and segmentation ground truth $y$ are sampled from the labeled dataset $D_l$. The second term is a contrastive learning loss for inter-class separation, $\ell_{cs}$, proposed by [21]. The third term is our Cross-ALD regularization, an enhancement of $\ell_{vat}$ that significantly improves model performance:

$$\ell_{total} := \min_\theta \; \mathbb{E}_{(x_l, y) \sim P_{D_l}} \big[\ell_{Dice}(x_l, y; \theta)\big] + \lambda_{cs} \, \mathbb{E}_{x_l \sim P_{D_l},\, x \sim P_{D_{ul}}} \big[\ell_{cs}(x_l, x)\big] + \lambda_{Cross\text{-}ALD} \, \mathbb{E}_{(x_i, x_j) \sim P_{D_{ul}}} \big[R(\theta, x_i, x_j)\big], \qquad (8)$$

where $\lambda_{cs}$ and $\lambda_{Cross\text{-}ALD}$ are weights that balance the losses. Note that our implementation replaces the $\ell_{vat}$ loss in the SS-Net code repository³ [21] with the proposed Cross-ALD regularization to reach state-of-the-art performance.
3 Experiments
In this section, we conduct comprehensive experiments on the ACDC⁴ dataset [1] and the LA⁵ dataset [26] for 2D and 3D image segmentation tasks, respectively. For fair comparison, all experiments use identical settings, following [21]. We evaluate our model in challenging semi-supervised scenarios in which only 5% or 10% of the data are labeled and the remaining training data are treated as unlabeled. Cross-ALD uses the U-Net [18] and V-Net [13] architectures for the ACDC and LA datasets, respectively. In Sect. 3.1, we compare the diversity of the adversarial particles generated by our method against vanilla SVGD and VAT with random initialization. In Sect. 3.2, we show that Cross-ALD outperforms other recent methods on the ACDC and LA datasets, and Sect. 3.3 presents ablation studies. The effect of the number of particles on model performance is studied in the Cross-ALD GitHub repository.

3.1 Diversity of Adversarial Particle Comparison
Settings. We fix all the decoder models (U-Net for ACDC and V-Net for LA). We run VAT with random initialization and SVGD multiple times to produce adversarial examples, which we compare to the adversarial particles generated by SVGDF. SVGDF is the proposed algorithm, which leverages a feature transformation Φ to capture the semantic meaning of inputs: Φ is the decoder of U-Net on the ACDC dataset and the decoder of V-Net on the LA dataset. All methods use the same radius-ball constraint, update step, and other settings. We randomly pick three images from the datasets to generate adversarial particles. To evaluate their diversity, we report the sum of squared errors (SSE) between these particles. A higher SSE indicates more diversity; for each number of particles, we report the average of the mean SSEs.

3 https://github.com/ycwu1997/SS-Net
4 https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html
5 http://atriaseg2018.cardiacatlas.org
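The text does not spell out how the SSEs are aggregated. One plausible reading, assumed here, is: for each image, compute the SSE between every pair of its particles and take the mean over pairs, then average those means over the sampled images. The function names and the random example data are hypothetical.

```python
import itertools

import numpy as np


def mean_pairwise_sse(particles: np.ndarray) -> float:
    """Mean sum of squared errors over all unordered pairs of particles.

    particles: array of shape (N, ...) holding the N adversarial particles of one image.
    """
    n = particles.shape[0]
    if n < 2:
        raise ValueError("need at least two particles to measure diversity")
    flat = particles.reshape(n, -1).astype(np.float64)
    sses = [np.sum((flat[i] - flat[j]) ** 2)
            for i, j in itertools.combinations(range(n), 2)]
    return float(np.mean(sses))


def diversity_score(particles_per_image) -> float:
    """Average over images of the mean pairwise SSE; higher means more diverse particles."""
    return float(np.mean([mean_pairwise_sse(p) for p in particles_per_image]))


rng = np.random.default_rng(0)
# Hypothetical example: three images, each with N = 4 particles of size 8 x 8.
batch = [rng.normal(size=(4, 8, 8)) for _ in range(3)]
print(diversity_score(batch))
```

Using all pairs rather than distances to a single reference keeps the score symmetric in the particles, so it does not depend on particle order.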
Cross-Adversarial Local Distribution Regularization
Fig. 1. Diversity comparison of our SVGDF, SVGD, and VAT with random initialization using the sum of squared errors (SSE) on the ACDC and LA datasets.

Table 1. Performance comparisons with six recent methods on the ACDC dataset. All results of existing methods are taken from [21] for fair comparison.

Method | Labeled scans | Unlabeled scans | Dice (%)↑ | Jaccard (%)↑ | 95HD (voxel)↓ | ASD (voxel)↓ | Para. (M) | MACs (G)
U-Net | 3 (5%) | 0 | 47.83 | 37.01 | 31.16 | 12.62 | 1.81 | 2.99
U-Net | 7 (10%) | 0 | 79.41 | 68.11 | 9.35 | 2.7 | 1.81 | 2.99
U-Net | 70 (All) | 0 | 91.44 | 84.59 | 4.3 | 0.99 | 1.81 | 2.99
UA-MT [28] | 3 (5%) | 67 (95%) | 46.04 | 35.97 | 20.08 | 7.75 | 1.81 | 2.99
SASSNet [8] | 3 (5%) | 67 (95%) | 57.77 | 46.14 | 20.05 | 6.06 | 1.81 | 3.02
DTC [11] | 3 (5%) | 67 (95%) | 56.9 | 45.67 | 23.36 | 7.39 | 1.81 | 3.02
URPC [12] | 3 (5%) | 67 (95%) | 55.87 | 44.64 | 13.6 | 3.74 | 1.83 | 3.02
MC-Net [22] | 3 (5%) | 67 (95%) | 62.85 | 52.29 | 7.62 | 2.33 | 2.58 | 5.39
SS-Net [21] | 3 (5%) | 67 (95%) | 65.82 | 55.38 | 6.67 | 2.28 | 1.83 | 2.99
Cross-ALD (Ours) | 3 (5%) | 67 (95%) | 80.6 | 69.08 | 5.96 | 1.9 | 1.83 | 2.99
UA-MT [28] | 7 (10%) | 63 (90%) | 81.65 | 70.64 | 6.88 | 2.02 | 1.81 | 2.99
SASSNet [8] | 7 (10%) | 63 (90%) | 84.5 | 74.34 | 5.42 | 1.86 | 1.81 | 3.02
DTC [11] | 7 (10%) | 63 (90%) | 84.29 | 73.92 | 12.81 | 4.01 | 1.81 | 3.02
URPC [12] | 7 (10%) | 63 (90%) | 83.1 | 72.41 | 4.84 | 1.53 | 1.83 | 3.02
MC-Net [22] | 7 (10%) | 63 (90%) | 86.44 | 77.04 | 5.5 | 1.84 | 2.58 | 5.39
SS-Net [21] | 7 (10%) | 63 (90%) | 86.78 | 77.67 | 6.07 | 1.4 | 1.83 | 2.99
Cross-ALD (Ours) | 7 (10%) | 63 (90%) | 87.52 | 78.62 | 4.81 | 1.6 | 1.83 | 2.99
Results. The advantage of SVGD over VAT is that SVGD generates diversified adversarial examples, owing to the second term in Eq. 6, whereas VAT creates only one example. However, vanilla SVGD has difficulty capturing the semantic meaning of high-resolution medical images because it computes the kernel k at the image level. As Fig. 1 shows, our SVGDF produces the most diverse particles compared with SVGD and VAT with random initialization.

3.2 Performance Evaluation on the ACDC and LA Datasets
Settings. We use the metrics of Dice, Jaccard, 95% Hausdorff Distance (95HD), and Average Surface Distance (ASD) to evaluate the results. We compare our
T. Nguyen-Duc et al.
Table 2. Performance comparisons with six recent methods on the LA dataset. All results of existing methods are taken from [21] for fair comparison.

Method | Labeled scans | Unlabeled scans | Dice (%)↑ | Jaccard (%)↑ | 95HD (voxel)↓ | ASD (voxel)↓ | Para. (M) | MACs (G)
V-Net | 4 (5%) | 0 | 52.55 | 39.6 | 47.05 | 9.87 | 9.44 | 47.02
V-Net | 8 (10%) | 0 | 82.74 | 71.72 | 13.35 | 3.26 | 9.44 | 47.02
V-Net | 80 (All) | 0 | 91.47 | 84.36 | 5.48 | 1.51 | 9.44 | 47.02
UA-MT [28] | 4 (5%) | 76 (95%) | 82.26 | 70.98 | 13.71 | 3.82 | 9.44 | 47.02
SASSNet [8] | 4 (5%) | 76 (95%) | 81.6 | 69.63 | 16.16 | 3.58 | 9.44 | 47.05
DTC [11] | 4 (5%) | 76 (95%) | 81.25 | 69.33 | 14.9 | 3.99 | 9.44 | 47.05
URPC [12] | 4 (5%) | 76 (95%) | 82.48 | 71.35 | 14.65 | 3.65 | 5.88 | 69.43
MC-Net [22] | 4 (5%) | 76 (95%) | 83.59 | 72.36 | 14.07 | 2.7 | 12.35 | 95.15
SS-Net [21] | 4 (5%) | 76 (95%) | 86.33 | 76.15 | 9.97 | 2.31 | 9.46 | 47.17
Cross-ALD (Ours) | 4 (5%) | 76 (95%) | 88.62 | 79.62 | 7.098 | 1.83 | 9.46 | 47.17
UA-MT [28] | 8 (10%) | 72 (90%) | 87.79 | 78.39 | 8.68 | 2.12 | 9.44 | 47.02
SASSNet [8] | 8 (10%) | 72 (90%) | 87.54 | 78.05 | 9.84 | 2.59 | 9.44 | 47.05
DTC [11] | 8 (10%) | 72 (90%) | 87.51 | 78.17 | 8.23 | 2.36 | 9.44 | 47.05
URPC [12] | 8 (10%) | 72 (90%) | 86.92 | 77.03 | 11.13 | 2.28 | 5.88 | 69.43
MC-Net [22] | 8 (10%) | 72 (90%) | 87.62 | 78.25 | 10.03 | 1.82 | 12.35 | 95.15
SS-Net [21] | 8 (10%) | 72 (90%) | 88.55 | 79.62 | 7.49 | 1.9 | 9.46 | 47.17
Cross-ALD (Ours) | 8 (10%) | 72 (90%) | 89.92 | 81.78 | 7.65 | 1.546 | 9.46 | 47.17
Cross-ALD with six recent methods: UA-MT [28] (MICCAI’19), SASSNet [8] (MICCAI’20), DTC [11] (AAAI’21), URPC [12] (MICCAI’21), MC-Net [22] (MICCAI’21), and SS-Net [21] (MICCAI’22). The loss weights $\lambda_{Cross\text{-}ALD}$ and $\lambda_{cs}$ follow an iteration-dependent warm-up function [7], and the number of particles is N = 2. All experiments use the identical settings of the GitHub repository⁶ [21] for fair comparison. Results. Recall that Cross-ALD uses SVGDF to generate more diversified adversarial particles than vanilla SVGD and VAT, and further enhances the smoothness of cross-adversarial local distributions. As Tables 1 and 2 show, with only 5% or 10% labeled training data, Cross-ALD achieves the best Dice and Jaccard scores in all four settings, and the best 95HD and ASD in all but two cases (ASD on ACDC and 95HD on LA, both with 10% labeled data, where SS-Net performs best). In particular, with 5% labeled data our method improves the Dice score over the state-of-the-art SS-Net by 14.78 and 2.29 percentage points on ACDC and LA, respectively. Moreover, the visualized results in Fig. 2 show that Cross-ALD segments more organ details than the other methods.

3.3 Ablation Study
Settings. We use the same network architectures and parameter settings as in Sect. 3.2, and train the models with 5% labeled training data of ACDC and LA. We show that crossing adversarial particles is more beneficial than random

6 https://github.com/ycwu1997/SS-Net
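Table 3. Ablation study on the ACDC and LA datasets.

Dataset | Method | Labeled scans | Unlabeled scans | Dice (%)↑ | Jaccard (%)↑ | 95HD (voxel)↓ | ASD (voxel)↓
ACDC | U-Net | 3 (5%) | 0 | 47.83 | 37.01 | 31.16 | 12.62
ACDC | RanMixup | 3 (5%) | 67 (95%) | 61.78 | 51.69 | 8.16 | 3.44
ACDC | VAT | 3 (5%) | 67 (95%) | 63.87 | 53.18 | 7.61 | 3.38
ACDC | VAT + Mixup | 3 (5%) | 67 (95%) | 66.23 | 56.37 | 7.18 | 2.53
ACDC | SVGD | 3 (5%) | 67 (95%) | 66.53 | 58.09 | 6.41 | 2.4
ACDC | SVGDF | 3 (5%) | 67 (95%) | 73.15 | 61.71 | 6.32 | 2.12
ACDC | SVGDF + $\ell_{cs}$ | 3 (5%) | 67 (95%) | 74.89 | 62.61 | 6.52 | 2.01
ACDC | Cross-ALD (Ours) | 3 (5%) | 67 (95%) | 80.6 | 69.08 | 5.96 | 1.9
LA | V-Net | 4 (5%) | 0 | 52.55 | 39.6 | 47.05 | 9.87
LA | RanMixup | 4 (5%) | 76 (95%) | 79.82 | 67.44 | 16.52 | 5.19
LA | VAT | 4 (5%) | 76 (95%) | 82.27 | 70.46 | 13.82 | 3.48
LA | VAT + Mixup | 4 (5%) | 76 (95%) | 83.28 | 71.77 | 12.8 | 2.63
LA | SVGD | 4 (5%) | 76 (95%) | 84.62 | 73.6 | 11.68 | 2.94
LA | SVGDF | 4 (5%) | 76 (95%) | 86.3 | 76.17 | 10.01 | 2.11
LA | SVGDF + $\ell_{cs}$ | 4 (5%) | 76 (95%) | 86.55 | 76.51 | 9.41 | 2.24
LA | Cross-ALD (Ours) | 4 (5%) | 76 (95%) | 87.52 | 78.62 | 4.81 | 1.6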
Fig. 2. Visualization results of several semi-supervised segmentation methods with 5% labeled training data and its corresponding ground-truth on ACDC and LA datasets.
mixup between natural inputs (RanMixup [29]), because these particles lie near decision boundaries. Recall that SVGDF outperforms VAT and SVGD by producing more diversified adversarial particles. Using SVGDF's particles together with $\ell_{cs}$ (SVGDF + $\ell_{cs}$) improves model performance in the semi-supervised segmentation task, while Cross-ALD efficiently enhances smoothness to significantly improve generalization. Results. Table 3 shows that mixing adversarial examples from VAT (VAT + Mixup) outperforms RanMixup. While SVGDF + $\ell_{cs}$ is better than SVGD and VAT, the proposed Cross-ALD achieves the best performance among all compared methods. In addition, as shown in Fig. 2, our method produces segmentation masks that match the ground truth more closely than those of the other methods.
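For reference, the RanMixup baseline [29] interpolates two natural inputs with a weight drawn from a Beta distribution. The sketch below shows that standard mixup operation and, as a hypothetical contrast, the same interpolation applied to perturbed inputs. The perturbations here are random stand-ins, not SVGDF particles, and Cross-ALD's actual crossing rule is the one defined in the method section, not this code; `alpha = 1.0` is an assumed default.

```python
from typing import Optional, Tuple

import numpy as np


def mixup(x_i: np.ndarray, x_j: np.ndarray, alpha: float = 1.0,
          rng: Optional[np.random.Generator] = None) -> Tuple[np.ndarray, float]:
    """Standard mixup [29]: convex combination of two inputs with lam ~ Beta(alpha, alpha)."""
    rng = np.random.default_rng() if rng is None else rng
    lam = float(rng.beta(alpha, alpha))
    return lam * x_i + (1.0 - lam) * x_j, lam


rng = np.random.default_rng(0)
x_i, x_j = rng.normal(size=(2, 1, 16, 16))  # two hypothetical unlabeled images

# RanMixup-style baseline: interpolate the natural inputs directly.
mixed_natural, lam = mixup(x_i, x_j, rng=rng)

# Hypothetical adversarial variant: interpolate perturbed copies x + delta instead.
# delta_i / delta_j are random stand-ins here, NOT particles produced by SVGDF.
delta_i, delta_j = 0.01 * rng.normal(size=(2, 1, 16, 16))
mixed_adversarial, lam_adv = mixup(x_i + delta_i, x_j + delta_j, rng=rng)
```

The point of the ablation is where the endpoints come from: interpolating between points that already sit near the decision boundary probes the model's smoothness in the region where it matters most.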
4 Conclusion
In this paper, we have introduced a novel cross-adversarial local distribution (Cross-ALD) regularization that extends VAT and Mixup and overcomes their drawbacks. In our method, SVGDF is proposed to approximate Cross-ALD; it produces more diverse adversarial particles than vanilla SVGD and VAT with random initialization. We adapt Cross-ALD to semi-supervised medical image segmentation and achieve state-of-the-art performance on the ACDC and LA datasets compared with many recent methods such as VAT [14], UA-MT [28], SASSNet [8], DTC [11], URPC [12], MC-Net [22], and SS-Net [21]. Acknowledgements. This work was partially supported by the Australian Defence Science and Technology (DST) Group under the Next Generation Technology Fund (NGTF) scheme. Dinh Phung further gratefully acknowledges the partial support from the Australian Research Council, project ARC DP230101176.
References
1. Bernard, O., et al.: Deep learning techniques for automatic MRI cardiac multi-structures segmentation and diagnosis: is the problem solved? IEEE Trans. Med. Imaging 37(11), 2514–2525 (2018)
2. Berthelot, D., et al.: Remixmatch: Semi-supervised learning with distribution alignment and augmentation anchoring. arXiv preprint arXiv:1911.09785 (2019)
3. Berthelot, D., Carlini, N., Goodfellow, I., Papernot, N., Oliver, A., Raffel, C.A.: Mixmatch: A holistic approach to semi-supervised learning. Adv. Neural Inform. Process. Syst. 32 (2019)
4. French, G., Laine, S., Aila, T., Mackiewicz, M., Finlayson, G.: Semi-supervised semantic segmentation needs strong, varied perturbations. arXiv preprint arXiv:1906.01916 (2019)
5. Gyawali, P., Ghimire, S., Wang, L.: Enhancing mixup-based semi-supervised learning with explicit lipschitz regularization. In: 2020 IEEE International Conference on Data Mining (ICDM), pp. 1046–1051. IEEE (2020)
6. Lai, X., et al.: Semi-supervised semantic segmentation with directional context-aware consistency. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1205–1214 (2021)
7. Laine, S., Aila, T.: Temporal ensembling for semi-supervised learning. arXiv preprint arXiv:1610.02242 (2016)
8. Li, S., Zhang, C., He, X.: Shape-aware semi-supervised 3D semantic segmentation for medical images. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12261, pp. 552–561. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59710-8_54
9. Li, X., Yu, L., Chen, H., Fu, C.W., Xing, L., Heng, P.A.: Transformation-consistent self-ensembling model for semi-supervised medical image segmentation. IEEE Trans. Neural Netw. Learn. Syst. 32(2), 523–534 (2020)
10. Liu, Q., Wang, D.: Stein variational gradient descent: A general purpose bayesian inference algorithm. In: Lee, D., Sugiyama, M., Luxburg, U., Guyon, I., Garnett, R. (eds.) Proceedings of NeurIPS. vol. 29 (2016)
11. Luo, X., Chen, J., Song, T., Wang, G.: Semi-supervised medical image segmentation through dual-task consistency. In: Proceedings of the AAAI Conference on Artificial Intelligence. vol. 35, pp. 8801–8809 (2021)
12. Luo, X., et al.: Efficient semi-supervised gross target volume of nasopharyngeal carcinoma segmentation via uncertainty rectified pyramid consistency. In: Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12902, pp. 318–329. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87196-3_30
13. Milletari, F., Navab, N., Ahmadi, S.A.: V-net: Fully convolutional neural networks for volumetric medical image segmentation. In: 2016 Fourth International Conference on 3D Vision (3DV), pp. 565–571. IEEE (2016)
14. Miyato, T., Maeda, S.i., Koyama, M., Ishii, S.: Virtual adversarial training: a regularization method for supervised and semi-supervised learning. IEEE TPAMI 41(8), 1979–1993 (2018)
15. Nguyen-Duc, T., Le, T., Zhao, H., Cai, J., Phung, D.Q.: Particle-based adversarial local distribution regularization. In: AISTATS, pp. 5212–5224 (2022)
16. Ouali, Y., Hudelot, C., Tami, M.: An overview of deep semi-supervised learning. arXiv preprint arXiv:2006.05278 (2020)
17. Ouali, Y., Hudelot, C., Tami, M.: Semi-supervised semantic segmentation with cross-consistency training. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12674–12684 (2020)
18. Ronneberger, O., Fischer, P., Brox, T.: U-net: Convolutional networks for biomedical image segmentation. In: Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18, pp. 234–241. Springer (2015)
19. Sohn, K., et al.: Fixmatch: simplifying semi-supervised learning with consistency and confidence. Adv. Neural. Inf. Process. Syst. 33, 596–608 (2020)
20. Tashiro, Y., Song, Y., Ermon, S.: Diversity can be transferred: Output diversification for white- and black-box attacks. Proc. NeurIPS 33, 4536–4548 (2020)
21. Wu, Y., Wu, Z., Wu, Q., Ge, Z., Cai, J.: Exploring smoothness and class-separation for semi-supervised medical image segmentation. In: International Conference on Medical Image Computing and Computer-Assisted Intervention, vol. 13435, pp. 34–43. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16443-9_4
22. Wu, Y., Xu, M., Ge, Z., Cai, J., Zhang, L.: Semi-supervised left atrium segmentation with mutual consistency training. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12902, pp. 297–306. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87196-3_28
23. Wu, Z., Shi, X., Lin, G., Cai, J.: Learning meta-class memory for few-shot semantic segmentation. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 517–526 (2021)
24. Xia, Y., et al.: 3D semi-supervised learning with uncertainty-aware multi-view co-training. In: Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 3646–3655 (2020)
25. Xie, Y., Zhang, J., Liao, Z., Verjans, J., Shen, C., Xia, Y.: Intra- and inter-pair consistency for semi-supervised gland segmentation. IEEE Trans. Image Process. 31, 894–905 (2021)
26. Xiong, Z., et al.: A global benchmark of algorithms for segmenting the left atrium from late gadolinium-enhanced cardiac magnetic resonance imaging. Med. Image Anal. 67, 101832 (2021)
27. Yang, X., Song, Z., King, I., Xu, Z.: A survey on deep semi-supervised learning. IEEE Transactions on Knowledge and Data Engineering (2022)
28. Yu, L., Wang, S., Li, X., Fu, C.-W., Heng, P.-A.: Uncertainty-aware self-ensembling model for semi-supervised 3D left atrium segmentation. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 605–613. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8_67
29. Zhang, H., Cisse, M., Dauphin, Y.N., Lopez-Paz, D.: mixup: Beyond empirical risk minimization. arXiv preprint arXiv:1710.09412 (2017)
AMAE: Adaptation of Pre-trained Masked Autoencoder for Dual-Distribution Anomaly Detection in Chest X-Rays

Behzad Bozorgtabar1,3(B), Dwarikanath Mahapatra2, and Jean-Philippe Thiran1,3

1 École Polytechnique Fédérale de Lausanne (EPFL), Lausanne, Switzerland
{behzad.bozorgtabar,jean-philippe.thiran}@epfl.ch
2 Inception Institute of AI (IIAI), Abu Dhabi, United Arab Emirates
3 Lausanne University Hospital (CHUV), Lausanne, Switzerland
Abstract. Unsupervised anomaly detection in medical images such as chest radiographs is stepping into the spotlight as it sidesteps the scarcity of labor-intensive and costly expert annotations of anomaly data. However, nearly all existing methods are formulated as one-class classification, trained only on representations from the normal class, and discard a potentially significant portion of the unlabeled data. This paper focuses on a more practical setting, dual-distribution anomaly detection for chest X-rays, which uses the entire training data, including both normal and unlabeled images. Inspired by a modern self-supervised vision transformer model trained on partial image inputs to reconstruct missing image regions, we propose AMAE, a two-stage algorithm for adaptation of the pre-trained masked autoencoder (MAE). Starting from MAE initialization, AMAE first creates synthetic anomalies from only normal training images and trains a lightweight classifier on frozen transformer features. Subsequently, we propose an adaptation strategy to leverage unlabeled images containing anomalies. The adaptation scheme assigns pseudo-labels to unlabeled images and uses two separate MAE-based modules to model the normative and anomalous distributions of pseudo-labeled images. The effectiveness of the proposed adaptation strategy is evaluated with different anomaly ratios in the unlabeled training set. AMAE leads to consistent performance gains over competing self-supervised and dual-distribution anomaly detection methods, setting a new state of the art on three public chest X-ray benchmarks: RSNA, NIH-CXR, and VinDr-CXR.
Keywords: Anomaly detection · Chest X-ray · Masked autoencoder
Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_19.

© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 195–205, 2023.
https://doi.org/10.1007/978-3-031-43907-0_19
B. Bozorgtabar et al.

1 Introduction
To reduce radiologists’ reading burden and make the diagnostic process more manageable, especially when experts are scarce, computer-aided diagnosis (CAD) systems, particularly deep learning-based anomaly detection [1,2,22], have flourished owing to their capability to detect rare anomalies in different imaging modalities, including chest X-ray (CXR). Nonetheless, unsupervised anomaly detection methods [20,26] are strongly preferred because of the difficulties of highly class-imbalanced learning and the tedious annotation of anomaly data required to develop such systems. Most current anomaly detection methods are formulated as a one-class classification (OCC) problem [18], where the goal is to model the distribution of the normal images used for training and thus detect abnormal cases that deviate from the normal class at test time. On this basis, image reconstruction-based methods, e.g., autoencoders [9] or generative models [20], self-supervised learning (SSL)-based methods, e.g., contrastive learning [26], and embedding-similarity-based methods [7] have been proposed for anomaly detection. Some recent self-supervised methods synthesize anomalies using cut-and-paste data augmentation [12,19] to approximate real sub-image anomalies. Nonetheless, their performance lags due to the lack of real anomaly data. More importantly, these methods often ignore readily available unlabeled images. More recently, similar to our method, DDAD [3] leverages readily available unlabeled images for anomaly detection, but it requires training an ensemble of several reconstruction-based networks. Self-supervised model adaptation on unlabeled data has been widely investigated using convolutional neural networks (CNNs) in many vision tasks via self-training [17], contrastive learning [22,26], and anatomical visual words [10]. Nonetheless, the adaptation of vision transformer (ViT) [8] architectures remains largely unexplored, particularly for anomaly detection.
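The cut-and-paste idea mentioned above turns normal images into labeled "anomalous" examples for a proxy task. Below is a minimal, generic CutPaste-style sketch for a single-channel image; AnatPaste [19], the anatomy-aware variant shown in Fig. 1, is not reproduced here, and the patch-size rule and function name are hypothetical choices.

```python
from typing import Optional

import numpy as np


def cut_paste(image: np.ndarray, patch_frac: float = 0.05,
              rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Copy a random rectangle of a single-channel image and paste it at another random spot.

    patch_frac is the approximate fraction of the image area covered by the patch.
    Returns a new array; the input image is left unchanged.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape
    ph = max(1, int(round(h * np.sqrt(patch_frac))))
    pw = max(1, int(round(w * np.sqrt(patch_frac))))
    # Top-left corners of the source and destination rectangles (upper bound is exclusive).
    sy, sx = rng.integers(0, h - ph + 1), rng.integers(0, w - pw + 1)
    dy, dx = rng.integers(0, h - ph + 1), rng.integers(0, w - pw + 1)
    out = image.copy()
    out[dy:dy + ph, dx:dx + pw] = image[sy:sy + ph, sx:sx + pw]
    return out


# A normal image and its synthetic "anomalous" counterpart form a labeled pair (0 / 1)
# for a proxy classification task trained on normal-only data.
normal = np.random.default_rng(0).random((64, 64))
synthetic_anomaly = cut_paste(normal, rng=np.random.default_rng(1))
```

Because the pasted content comes from the same image, intensity statistics stay plausible and the classifier must rely on local structural irregularities, which is why such augmentations only approximate real lesions.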
Recently, masked autoencoder (MAE) [11] based models have demonstrated great scalability and substantially improved several self-supervised learning benchmarks [27]. In this paper, inspired by the success of the MAE approach, we propose a two-stage algorithm for “Adaptation of pre-trained Masked AutoEncoder” (AMAE) that simultaneously leverages normal and unlabeled images for anomaly detection in chest X-rays. (i) In Stage 1, AMAE creates synthetic anomalies from only normal training images, and the usefulness of the pre-trained MAE [11] is evaluated by training a lightweight classifier on a proxy task of detecting synthetic anomalies. (ii) In Stage 2, AMAE customizes the recipe of MAE adaptation for an unlabeled training set. In particular, we propose an adaptation strategy based on reconstructing the masked-out input images. The rationale behind the proposed adaptation strategy is to assign pseudo-labels to unlabeled images and train two separate modules to measure the distribution discrepancy between normal and pseudo-labeled abnormal images. (iii) We conduct extensive experiments across three chest X-ray datasets and verify the effectiveness of our adaptation strategy in capturing anomalous features from unlabeled images. In addition, we evaluate the model with different
AMAE
Fig. 1. Schematic overview of AMAE training (Stage 1). Top: illustration of AnatPaste augmentation [19] generated from normal training images. Bottom: starting from MAE initialization, only the MLP-based projection head (Proj.) is trained to classify synthetic anomalies.
anomaly ratios (ARs) in an unlabeled training set and show consistent performance improvement with increasing AR.
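MAE [11] hides a large random subset of image patches (75% in the original MAE) and trains the model to reconstruct them. The sketch below shows only that random-masking step on a single-channel image, assuming a hypothetical 224 × 224 input and 16-pixel patches; the encoder, decoder, and AMAE's adaptation stages are not shown, and how AMAE configures masking is not specified in this excerpt.

```python
from typing import Optional, Tuple

import numpy as np


def patchify(image: np.ndarray, patch: int) -> np.ndarray:
    """Split an (H, W) image into non-overlapping patch x patch tiles, shape (num_patches, patch * patch)."""
    h, w = image.shape
    if h % patch or w % patch:
        raise ValueError("image size must be divisible by the patch size")
    tiles = image.reshape(h // patch, patch, w // patch, patch).transpose(0, 2, 1, 3)
    return tiles.reshape(-1, patch * patch)


def random_masking(patches: np.ndarray, mask_ratio: float = 0.75,
                   rng: Optional[np.random.Generator] = None
                   ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """MAE-style random masking: keep a random subset of patches for the encoder.

    Returns (visible patches, indices of kept patches, boolean mask with True = masked).
    """
    rng = np.random.default_rng() if rng is None else rng
    n = patches.shape[0]
    n_keep = int(n * (1.0 - mask_ratio))
    keep = np.sort(rng.permutation(n)[:n_keep])
    mask = np.ones(n, dtype=bool)
    mask[keep] = False
    return patches[keep], keep, mask


image = np.random.default_rng(0).random((224, 224))  # stand-in for a resized CXR
patches = patchify(image, patch=16)                    # 196 patches of 256 pixels
visible, keep, mask = random_masking(patches, rng=np.random.default_rng(1))
```

Only the visible patches go through the encoder, which is what makes MAE pre-training cheap; the reconstruction loss is then computed on the masked positions (`mask == True`).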
2 Method
Notation. We first formally define the problem setting for the proposed dual-distribution anomaly detection. Contrary to previous unsupervised anomaly detection methods, AMAE fully uses unlabeled images, yielding a training set $T_{train} = T_n \cup T_u$ consisting of both a normal set $T_n$ and an unlabeled set $T_u$. We denote the normal training set as $T_n = \{x_i^n\}_{i=1}^{N}$, with $N$ normal images, and the unlabeled training set as $T_u = \{x_i^u\}_{i=1}^{M}$, with $M$ unlabeled images composed of both normal and abnormal images. At test time, given a test set $T_{test} = \{(x_i^t, y_i)\}_{i=1}^{S}$ with $S$ normal or abnormal images, where $y_i \in \{0, 1\}$ is the label of $x_i^t$ (0 for a normal (negative) and 1 for an abnormal (positive) image), the trained anomaly detection model should identify whether each test image is abnormal. Architecture. Our architecture is

$$y_{pred} = \begin{cases} \ldots & \ldots\ \tau \\ 1\ (\text{OOD}) & \text{otherwise} \end{cases} \qquad (5)$$
4 Experiments and Results
Dataset and Implementation. For our experiments, we used a fetal ultrasound dataset of 359 subject videos collected as part of the PULSE project [7]. The in-distribution dataset consisted of 5 standard heart views (3VT, 3VV, LVOT, RVOT, and 4CH), while the out-of-distribution dataset comprised three non-heart anatomies: fetal head, abdomen, and femur. The original images, of size 1008 × 784 pixels, were resized to 224 × 224 pixels. To train the models, we randomly sampled 5000 fetal heart images and used 500 images to evaluate image generation performance. To test our final model and compare it with other methods, we used a held-out dataset of 7471 images, comprising 4309 images of different heart views and 3162 images (about 1000 per anatomy) of out-of-distribution classes. Further details about the dataset are given in Supp. Figs. 2 and 3. All models were trained using PyTorch version 1.12 on a Tesla V100 32 GB GPU. During training, we used T = 1000 steps for noising the input image and a linearly increasing noise schedule from 0.0015 to 0.0195. To generate samples from the trained model, we used DDIM [26] sampling with T = 100 steps. All baseline models were trained and evaluated using their original implementations.

4.1 Results
We evaluated the performance of the dual-conditioned diffusion models (DCDMs) for OOD detection by comparing them with two current state-of-the-art unsupervised reconstruction-based approaches and one likelihood-based approach. The first baseline is Deep-MCDD [15], a likelihood-based OOD detection method that proposes a Gaussian discriminant-based objective to learn class-conditional distributions. The second baseline is ALOCC [23], a GAN-based model that uses the confidence of the discriminator on reconstructed samples to detect OOD samples. The third baseline is the method of Graham et al. [10], which uses DDPM [14] to generate multiple images at varying noise levels
D. Mishra et al.
Table 1. Quantitative comparison of our model (DCDM) with reference methods.

Method | AUC (%) | F1-Score (%) | Accuracy (%) | Precision (%)
Deep-MCDD [15] | 64.58 | 66.23 | 60.41 | 51.82
ALOCC [23] | 57.22 | 59.34 | 52.28 | 45.63
Graham et al. [10] | 63.86 | 63.55 | 60.15 | 50.89
DCDM (Ours) | 74.29 | 77.95 | 73.34 | 77.60
for each input. They then compute the MSE and LPIPS metrics between each generated image and the input, convert them to Z-scores, and finally average them to obtain the OOD score. Quantitative Results. The performance of DCDM and the other approaches is shown in Table 1. The GAN-based method ALOCC [23] has the lowest AUC of 57.22%, which is improved to 63.86% by the method of Graham et al. and further to 64.58% by the likelihood-based Deep-MCDD. DCDM outperforms these reference methods by 20%, 14%, and 13%, respectively, and has an AUC of 77.60%. High precision is essential for OOD detection, as it reduces false positives and increases trust in the model. DCDM exhibits a precision that is 22% higher than the reference methods while also improving the F1-score by 8%. Qualitative Results. Qualitative results are shown in Fig. 2. Visual comparisons show that ALOCC generates images structurally similar to the input images for both in-distribution and OOD samples, which makes it harder for ALOCC to detect OOD samples. Because a DDPM is unconditional and our in-distribution data contain multiple heart views, the model of Graham et al. generates an arbitrary heart view for a given image; for example, given a 4CH view as input, it may generate an entirely different heart view. However, unlike ALOCC, the Graham et al. model generates heart views for OOD samples, which improves OOD detection performance. DCDM generates images that have high spatial similarity to the input and belong to the same heart view for ID samples, while generating structurally different heart views for OOD samples. In Fig. 2(c), for the OOD sample, even though the confidence is high (0.68), the gap between the ID and OOD classes is wide enough to separate the two. Additional qualitative results can be found in Supp. Fig. 4.

4.2 Ablation Study
Ablation experiments were performed to study the impact of the various conditioning mechanisms on model performance, both qualitatively and quantitatively. Quantitatively, as shown in Table 2, the unconditional model has the lowest AUC of 69.61%. Incorporating the IDCC guidance or LIFC
Dual Conditioned Diffusion Models for OOD Detection
Fig. 2. Qualitative comparison of our method with other methods: (a) ALOCC generates images similar to the input for both ID and OOD samples; (b) Graham et al. generates an arbitrary heart view for a given input image; (c) our model generates images that are similar to the input image for ID samples and dissimilar for OOD samples. Classes predicted by CFR and the OOD score (τ = 0.73) are given in brackets. Table 2. Ablation study of different conditioning mechanisms of DCDM.
Method | Accuracy (%) | Precision (%) | AUC (%)
Unconditional | 68.16 | 58.44 | 69.61
In-Distribution Class Conditioning | 74.39 | 66.12 | 75.27
Latent Image Feature Conditioning | 77.02 | 70.02 | 77.40
Dual Conditioning | 73.34 | 77.60 | 77.95
Fig. 3. Qualitative ablation study showing the effect of (a) IDCC, (b) LIFC, and (c) DC on the generative results of the DM. Brackets in IDCC and DC show labels predicted by CFR.
separately improves performance, with AUCs of 75.27% and 77.40%, respectively. The best results are achieved when both mechanisms are used (DCDM), yielding an 11% improvement in AUC relative to the unconditional model. Although the AUC margin between the combined model (DCDM) and the LIFC model is small, precision improves by 3%, showing that the combined model is more precise and hence the best model for OOD detection. As shown in Fig. 3, the unconditional diffusion model generates a random heart view for a given input, for both in-distribution and OOD samples. IDCC guides the model to generate a heart view according to the in-distribution classifier (CFR) prediction, which leads to similar samples for in-distribution inputs and dissimilar samples for OOD inputs. LIFC, on the other hand, generates an image with similar spatial information; however, heart views are still generated for OOD samples because the model was trained only on heart views. When dual conditioning (DC) is used, the model generates images that are more closely aligned with the input for in-distribution inputs, and higher-fidelity heart views for OOD inputs, than a model conditioned on either IDCC or LIFC alone. Supp. Fig. 1 presents further qualitative ablations.
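The implementation details in Sect. 4 specify T = 1000 noising steps with a noise schedule increasing linearly from 0.0015 to 0.0195, and DDIM [26] sampling with 100 steps. A minimal NumPy sketch of that setup follows. It assumes that the variance β_t itself is linear in t (some latent-diffusion code bases interpolate √β_t linearly instead, which the text does not rule out) and that the DDIM steps are evenly spaced; the function names are hypothetical and this is not the authors' implementation.

```python
import numpy as np


def linear_beta_schedule(num_steps: int = 1000, beta_start: float = 0.0015,
                         beta_end: float = 0.0195) -> np.ndarray:
    """Noise variances beta_t increasing linearly over num_steps steps (assumed interpretation)."""
    return np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)


def q_sample(x0: np.ndarray, t: int, alpha_bar: np.ndarray, noise: np.ndarray) -> np.ndarray:
    """DDPM forward noising: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise


def ddim_timesteps(num_train_steps: int = 1000, num_sample_steps: int = 100) -> np.ndarray:
    """Evenly spaced subset of training timesteps, visited from noisiest to cleanest."""
    stride = num_train_steps // num_sample_steps
    return np.arange(0, num_train_steps, stride)[::-1]


betas = linear_beta_schedule()
alpha_bar = np.cumprod(1.0 - betas)   # cumulative signal-retention factor abar_t
steps = ddim_timesteps()              # 100 timesteps: 990, 980, ..., 0

rng = np.random.default_rng(0)
x0 = rng.random((224, 224))           # stand-in for a resized ultrasound frame
x_t = q_sample(x0, t=500, alpha_bar=alpha_bar, noise=rng.standard_normal(x0.shape))
```

DDIM's deterministic update lets the sampler skip most of the 1000 training steps, which is why 100 sampling steps suffice at test time.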
5 Conclusion
We introduce a novel dual-conditioned diffusion model for OOD detection in fetal ultrasound videos and demonstrate how the proposed dual-conditioning mechanisms can manipulate the generative space of a diffusion model. Specifically, we show how our dual-conditioning mechanism can tackle scenarios where the in-distribution data have high inter-class (using IDCC) and intra-class (using LIFC) variations, and guide a diffusion model to generate images similar to the input for in-distribution inputs and dissimilar images for OOD inputs. Our approach does not require labelled data for OOD classes and is especially applicable to challenging scenarios where the in-distribution data comprise more than one class and there is high similarity between the in-distribution and OOD classes. Acknowledgement. This work was supported in part by the InnoHK-funded Hong Kong Centre for Cerebro-cardiovascular Health Engineering (COCHE) Project 2.1 (Cardiovascular risks in early life and fetal echocardiography), the UK EPSRC (Engineering and Physical Sciences Research Council) Programme Grant EP/T028572/1 (VisualAI), and a UK EPSRC Doctoral Training Partnership award.
References
1. Arjovsky, M., Bottou, L.: Towards principled methods for training generative adversarial networks. arXiv preprint arXiv:1701.04862 (2017)
2. Bau, D., et al.: Seeing what a GAN cannot generate. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 4502–4511 (2019)
3. Chen, X., Konukoglu, E.: Unsupervised detection of lesions in brain MRI using constrained adversarial auto-encoders. arXiv preprint arXiv:1806.04972 (2018)
4. Choi, H., Jang, E., Alemi, A.A.: Waic, but why? Generative ensembles for robust anomaly detection. arXiv preprint arXiv:1810.01392 (2018)
5. DeVries, T., Taylor, G.W.: Learning confidence for out-of-distribution detection in neural networks. arXiv preprint arXiv:1802.04865 (2018)
6. Dhariwal, P., Nichol, A.: Diffusion models beat GANs on image synthesis. Adv. Neural. Inf. Process. Syst. 34, 8780–8794 (2021)
7. Drukker, L., et al.: Transforming obstetric ultrasound into data science using eye tracking, voice recording, transducer motion and ultrasound video. Sci. Rep. 11(1), 14109 (2021)
8. Fort, S.: Adversarial vulnerability of powerful near out-of-distribution detection. arXiv preprint arXiv:2201.07012 (2022)
9. Fort, S., Ren, J., Lakshminarayanan, B.: Exploring the limits of out-of-distribution detection. Adv. Neural. Inf. Process. Syst. 34, 7068–7081 (2021)
10. Graham, M.S., Pinaya, W.H., Tudosiu, P.D., Nachev, P., Ourselin, S., Cardoso, M.J.: Denoising diffusion models for out-of-distribution detection. arXiv preprint arXiv:2211.07740 (2022)
11. Guénais, T., Vamvourellis, D., Yacoby, Y., Doshi-Velez, F., Pan, W.: Bacoun: Bayesian classifers with out-of-distribution uncertainty. arXiv preprint arXiv:2007.06096 (2020)
12. Hendrycks, D., Gimpel, K.: A baseline for detecting misclassified and out-of-distribution examples in neural networks. arXiv preprint arXiv:1610.02136 (2016)
13. Hertz, A., Mokady, R., Tenenbaum, J., Aberman, K., Pritch, Y., Cohen-Or, D.: Prompt-to-prompt image editing with cross attention control. arXiv preprint arXiv:2208.01626 (2022)
14. Ho, J., Jain, A., Abbeel, P.: Denoising diffusion probabilistic models. Adv. Neural. Inf. Process. Syst. 33, 6840–6851 (2020)
15. Lee, D., Yu, S., Yu, H.: Multi-class data description for out-of-distribution detection. In: Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 1362–1370 (2020)
16. Liu, L., Ren, Y., Lin, Z., Zhao, Z.: Pseudo numerical methods for diffusion models on manifolds. arXiv preprint arXiv:2202.09778 (2022)
17. Margatina, K., Baziotis, C., Potamianos, A.: Attention-based conditioning methods for external knowledge integration. arXiv preprint arXiv:1906.03674 (2019)
18. Meng, C., et al.: Sdedit: guided image synthesis and editing with stochastic differential equations. In: International Conference on Learning Representations (2021)
19. Nalisnick, E., Matsukawa, A., Teh, Y.W., Gorur, D., Lakshminarayanan, B.: Do deep generative models know what they don’t know? arXiv preprint arXiv:1810.09136 (2018)
20. Rebain, D., Matthews, M.J., Yi, K.M., Sharma, G., Lagun, D., Tagliasacchi, A.: Attention beats concatenation for conditioning neural fields. arXiv preprint arXiv:2209.10684 (2022)
21. Ren, J., et al.: Likelihood ratios for out-of-distribution detection. Adv. Neural Inform. Process. Syst. 32 (2019)
22. Rombach, R., Blattmann, A., Lorenz, D., Esser, P., Ommer, B.: High-resolution image synthesis with latent diffusion models. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10684–10695 (2022)
23. Sabokrou, M., Khalooei, M., Fathy, M., Adeli, E.: Adversarially learned one-class classifier for novelty detection. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3379–3388 (2018)
24. Saharia, C., et al.: Palette: Image-to-image diffusion models. In: ACM SIGGRAPH 2022 Conference Proceedings, pp. 1–10 (2022)
226
D. Mishra et al.
25. Schlegl, T., Seeb¨ ock, P., Waldstein, S.M., Schmidt-Erfurth, U., Langs, G.: Unsupervised anomaly detection with generative adversarial networks to guide marker discovery. In: Neithammer, M., et al. (eds.) IPMI 2017. LNCS, vol. 10265, pp. 146–157. Springer, Cham (2017). https://doi.org/10.1007/978-3-319-59050-9 12 26. Song, J., Meng, C., Ermon, S.: Denoising diffusion implicit models. arXiv preprint arXiv:2010.02502 (2020) 27. Vaswani, A., et al.: Attention is all you need. Adv. Neural Inform. Process. Syst. 30 (2017) 28. Wald, Y., Feder, A., Greenfeld, D., Shalit, U.: On calibration and out-of-domain generalization. Adv. Neural. Inf. Process. Syst. 34, 2215–2227 (2021) 29. Wyatt, J., Leach, A., Schmon, S.M., Willcocks, C.G.: Anoddpm: anomaly detection with denoising diffusion probabilistic models using simplex noise. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 650–656 (2022) 30. Xu, H., et al.: Unsupervised anomaly detection via variational auto-encoder for seasonal KPIs in web applications. In: Proceedings of the 2018 World Wide Web Conference, pp. 187–196 (2018) 31. Yang, J., Zhou, K., Li, Y., Liu, Z.: Generalized out-of-distribution detection: a survey. arXiv preprint arXiv:2110.11334 (2021) 32. Yang, L., et al.: Diffusion models: a comprehensive survey of methods and applications. arXiv preprint arXiv:2209.00796 (2022) 33. Zhou, Y.: Rethinking reconstruction autoencoder-based out-of-distribution detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 7379–7387 (2022) 34. Zhou, Z., Guo, L.Z., Cheng, Z., Li, Y.F., Pu, S.: Step: out-of-distribution detection in the presence of limited in-distribution labeled data. Adv. Neural. Inf. Process. Syst. 34, 29168–29180 (2021)
Weakly-Supervised Positional Contrastive Learning: Application to Cirrhosis Classification
Emma Sarfati1,2(B), Alexandre Bône1, Marc-Michel Rohé1, Pietro Gori2, and Isabelle Bloch2,3
1 Guerbet Research, Villepinte, France
2 LTCI, Télécom Paris, Institut Polytechnique de Paris, Paris, France
[email protected]
3 Sorbonne Université, CNRS, LIP6, Paris, France
Abstract. Large medical imaging datasets can be cheaply and quickly annotated with low-confidence, weak labels (e.g., radiological scores). Access to high-confidence labels, such as histology-based diagnoses, is rare and costly. Pretraining strategies, like contrastive learning (CL) methods, can leverage unlabeled or weakly-annotated datasets. These methods typically require large batch sizes, which are difficult to achieve with large 3D images at full resolution because of limited GPU memory. Nevertheless, volumetric positional information about the spatial context of each 2D slice can be very important for some medical applications. In this work, we propose an efficient weakly-supervised positional (WSP) contrastive learning strategy that integrates both the spatial context of each 2D slice and a weak label via a generic kernel-based loss function. We illustrate our method on cirrhosis prediction using a large volume of weakly-labeled images, namely radiological low-confidence annotations, and small strongly-labeled (i.e., high-confidence) datasets. The proposed model improves the classification AUC by 5% with respect to a baseline model on our internal dataset, and by 26% on the public LIHC dataset from The Cancer Genome Atlas. The code is available at: https://github.com/Guerbet-AI/wsp-contrastive.
Keywords: Weakly-supervised learning · Contrastive learning · CT · Cirrhosis prediction · Liver
1 Introduction
In the medical domain, obtaining a large amount of high-confidence labels, such as histopathological diagnoses, is arduous due to the cost and the technical expertise required. It is, however, possible to obtain lower-confidence assessments for a large number of images, either through clinical questioning or directly from a radiological diagnosis. To take advantage of large volumes of unlabeled or weakly-labeled images, pre-training encoders with self-supervised methods has shown promising results in deep learning for medical imaging [1,4,21,27–29]. In particular, contrastive learning (CL) is a self-supervised method that learns a mapping of the input images to a representation space where similar (positive) samples are moved closer and different (negative) samples are pushed far apart. Weak discrete labels can be integrated into contrastive learning by, for instance, considering as positives only the samples having the same label, as in [13], or by directly weighting unsupervised contrastive and supervised cross-entropy loss functions, as in [19].
Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_22.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 227–237, 2023.
https://doi.org/10.1007/978-3-031-43907-0_22
228
E. Sarfati et al.
In this work, we focus on the scenario where radiological meta-data (thus, low-confidence labels) are available for a large number of images, whereas high-confidence labels, obtained by histological analysis, are scarce.
Naive extensions of contrastive learning methods, such as [5,10,11], from 2D to 3D images may be difficult due to limited GPU memory and, therefore, small batch sizes. A common solution is to use patch-based methods [8,23]. However, these methods pose two difficulties: they reduce the spatial context (limited by the size of the patch), and they require similar spatial resolution across images. This is rarely the case for abdominal CT/MRI acquisitions, which are typically strongly anisotropic and have variable resolutions. Alternatively, the depth position of each 2D slice within its corresponding volume can be integrated into the analysis. For instance, in [4], the authors proposed to integrate depth in the sampling strategy for batch creation. Likewise, in [26], the authors proposed to define as similar only 2D slices that have a small depth difference, using a normalized depth coordinate d ∈ [0, 1]. These works implicitly assume a certain threshold on depth to define positive and negative samples, which may be difficult to define and may differ across applications and datasets.
Differently, inspired by [2,8], we propose to use a degree of “positiveness” between samples by defining a kernel function w on depth positions. This allows us to consider volumetric depth information during pre-training and to use large batch sizes. Furthermore, we propose to simultaneously leverage weak discrete attributes during pre-training through a novel and efficient composite-kernel contrastive loss function; we call our global method Weakly-Supervised Positional (WSP). We apply our method to the classification of histology-proven liver cirrhosis, with a large volume of (weakly) radiologically-annotated CT scans and a small number of histopathologically confirmed cirrhosis diagnoses. We compare the proposed approach to existing self-supervised methods.
2 Method
Let $x_t$ be an input 2D image, usually called anchor, extracted from a 3D volume, $y_t$ a corresponding discrete weak variable and $d_t$ a related continuous variable. In this paper, $y_t$ refers to a weak radiological annotation and $d_t$ corresponds to the normalized depth position of the 2D image within its corresponding 3D volume: if $V_{max}$ corresponds to the maximal depth coordinate of a volume $V$, we compute $d_t = p_t / V_{max}$, with $p_t \in [0, V_{max}]$ being the original depth coordinate.
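As a minimal sketch of the depth normalization $d_t = p_t / V_{max}$, the hypothetical helper below assumes that depth coordinates are integer slice indices, so that $V_{max}$ equals the number of slices minus one; real pipelines may instead use physical z-coordinates.

```python
def normalized_depths(num_slices: int) -> list[float]:
    """Return d_t = p_t / V_max for every slice of one volume.

    Assumption (not stated in the paper): p_t is the integer slice index
    0..num_slices-1, so the maximal depth coordinate is V_max = num_slices - 1.
    """
    if num_slices < 2:
        raise ValueError("need at least two slices to normalize depth")
    v_max = num_slices - 1
    return [p_t / v_max for p_t in range(num_slices)]


print(normalized_depths(5))  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

Whatever the slice count, the first slice maps to 0 and the last to 1, which is what lets depths from volumes of different lengths be compared by a single kernel.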
Weakly-Supervised Positional Contrastive Learning
229
Let $x_j^-$ and $x_i^+$ be two semantically different (negative) and similar (positive) images with respect to $x_t$, respectively. The definition of similarity is crucial in CL and is the main difference between existing methods. For instance, in unsupervised CL, methods such as SimCLR [5,6] choose as positive samples random augmentations of the anchor, $x_i^+ = t(x_t)$, where $t \sim \mathcal{T}$ is a random transformation chosen among a user-selected family $\mathcal{T}$. Negative images $x_j^-$ are all other (transformed) images present in the batch. Once $x_j^-$ and $x_i^+$ are defined, the goal of CL is to compute a mapping function $f_\theta: \mathcal{X} \to \mathbb{S}^d$, where $\mathcal{X}$ is the set of images and $\mathbb{S}^d$ the representation space, so that similar samples are mapped closer in the representation space than dissimilar samples. Mathematically, this can be defined as looking for an $f_\theta$ that satisfies the condition:
$$s^-_{tj} - s^+_{ti} \le 0 \quad \forall t, j, i \qquad (1)$$
where $s^-_{tj} = \mathrm{sim}(f_\theta(x_t), f_\theta(x_j^-))$ and $s^+_{ti} = \mathrm{sim}(f_\theta(x_t), f_\theta(x_i^+))$, with sim a similarity function defined here as $\mathrm{sim}(a, b) = a^\top b / \tau$ with $\tau > 0$.
In the presence of discrete labels y, the definition of negative ($x_j^-$) and positive ($x_i^+$) samples may change. For instance, in SupCon [13], the authors define as positives all images with the same discrete label y. However, when working with continuous labels d, one cannot use the same strategy since all images are somehow positive and negative at the same time. A possible solution [26] would be to define a threshold γ on the distance between labels (e.g., $d_a$, $d_b$) so that, if the distance is smaller than γ (i.e., $\|d_a - d_b\|_2 < \gamma$), the samples (e.g., $x_a$ and $x_b$) are considered as positives. However, this requires a user-defined hyperparameter γ, which could be hard to find in practice. A more efficient solution, as proposed in [8], is to define a degree of “positiveness” between samples using a normalized kernel function $w_\sigma(d, d_i) = K_\sigma(d - d_i)$, where $K_\sigma$ is, for instance, a Gaussian kernel, with user-defined hyperparameter σ and $0 \le w_\sigma \le 1$. Interestingly, for discrete labels, one could also define a kernel as $w_\delta(y, y_i) = \delta(y - y_i)$, δ being the Dirac function, which exactly recovers SupCon [13]. In this work, we propose to leverage both continuous (d) and discrete (y) labels by combining (here, by multiplying) the previously defined kernels, $w_\sigma$ and $w_\delta$, into a composite kernel loss function. In this way, samples are considered as similar (positive) only if their composite degree of “positiveness” is greater than zero, namely if both kernels are non-zero ($w_\sigma > 0$ and $w_\delta \neq 0$). An example of the resulting representation space is shown in Fig. 1. This constraint can be defined by slightly modifying the condition introduced in Eq. 1, as:
$$\underbrace{w_\delta(y_t, y_i) \cdot w_\sigma(d_t, d_i)}_{\text{composite kernel } w_{ti}} (s_{tj} - s_{ti}) \le 0 \quad \forall t, i, j \neq i \qquad (2)$$
where the indices t, i, j traverse all N images in the batch since there are no “hard” positive or negative samples, as in SimCLR or SupCon, but all images are considered as positive and negative at the same time. As commonly done in CL [3], this condition can be transformed into an optimization problem using the max operator and its smooth approximation LogSumExp:
$$\arg\min_{f_\theta} \sum_{t,i} \max\big(0, w_{ti}\{s_{tj} - s_{ti}\}_{j=1, j \neq i}^{N}\big) = \arg\min_{f_\theta} \sum_{t,i} w_{ti} \max\big(0, \{s_{tj} - s_{ti}\}_{j=1, j \neq i}^{N}\big) \approx \arg\min_{f_\theta} \Big(-\sum_{t,i} w_{ti} \log \frac{\exp(s_{ti})}{\sum_{j \neq i}^{N} \exp(s_{tj})}\Big) \qquad (3)$$
By defining $P(t) = \{i : y_i = y_t\}$ as the set of indices of images $x_i$ in the batch with the same discrete label $y_i$ as the anchor $x_t$, we can rewrite our final loss function as:
$$\mathcal{L}_{WSP} = -\sum_{t=1}^{N} \sum_{i \in P(t)} w_\sigma(d_t, d_i) \log \frac{\exp(s_{ti})}{\sum_{j \neq i}^{N} \exp(s_{tj})} \qquad (4)$$
where $w_\sigma(d_t, d_i)$ is normalized over $i \in P(t)$. In practice, it is rather easy to find a good value of σ, as the proposed kernel method is quite robust to its variation. A robustness study is available in the supplementary material. For the experiments, we fix σ = 0.1.
Fig. 1. Example of representation space constructed by our loss function, leveraging both continuous depth coordinate d and discrete label y (i.e., radiological diagnosis $y_{radio}$). Samples from different radiological classes are well separated and, at the same time, samples are ordered within each class based on their depth coordinate d.
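To make Eq. 4 concrete, here is a minimal, framework-free sketch of $\mathcal{L}_{WSP}$ for a single batch. The discrete kernel $w_\delta$ is realized by restricting the inner sum to P(t), and the depth kernel $w_\sigma$ weights each positive. Several details are our assumptions rather than statements from the text: one view per image, L2-normalized embeddings, a temperature τ = 0.1, an unnormalized Gaussian for $K_\sigma$, and the exclusion of the anchor itself both from P(t) and from the denominator (j ≠ t), as is common in SupCon-style implementations. A practical version would use batched tensor operations in a deep learning framework.

```python
import math


def depth_kernel(d_t, d_i, sigma=0.1):
    """w_sigma(d_t, d_i) = K_sigma(d_t - d_i), here an (assumed) unnormalized
    Gaussian with values in (0, 1], equal to 1 when the two depths coincide."""
    return math.exp(-((d_t - d_i) ** 2) / (2.0 * sigma ** 2))


def wsp_loss(embeddings, labels, depths, sigma=0.1, tau=0.1):
    """Sketch of L_WSP (Eq. 4) for one batch of N slices.

    embeddings: N feature vectors f_theta(x_t); labels: weak labels y_t;
    depths: normalized depths d_t in [0, 1].
    """
    n = len(embeddings)
    if n < 3:
        raise ValueError("need at least 3 samples so every denominator is non-empty")
    # L2-normalize so that s_tj = z_t . z_j / tau is a scaled cosine similarity.
    z = []
    for e in embeddings:
        norm = math.sqrt(sum(x * x for x in e))
        z.append([x / norm for x in e])
    s = [[sum(a * b for a, b in zip(z[t], z[j])) / tau for j in range(n)]
         for t in range(n)]

    loss = 0.0
    for t in range(n):
        # P(t): same weak label as the anchor (w_delta != 0), anchor excluded.
        positives = [i for i in range(n) if i != t and labels[i] == labels[t]]
        if not positives:
            continue
        # Depth kernel, normalized over i in P(t) as stated after Eq. 4.
        raw = [depth_kernel(depths[t], depths[i], sigma) for i in positives]
        total_w = sum(raw)
        for i, w in zip(positives, raw):
            # Denominator of Eq. 4: sum over j != i (and j != t, our assumption).
            denom = sum(math.exp(s[t][j]) for j in range(n) if j not in (i, t))
            loss -= (w / total_w) * (s[t][i] - math.log(denom))
    return loss


# Same-label slices close together (aligned) vs. pushed apart (mixed).
aligned = [[1.0, 0.0], [1.0, 0.1], [-1.0, 0.0], [-1.0, 0.1]]
mixed = [[1.0, 0.0], [-1.0, 0.0], [1.0, 0.1], [-1.0, 0.1]]
labels = [0, 0, 1, 1]
depths = [0.1, 0.2, 0.1, 0.2]
print(wsp_loss(aligned, labels, depths) < wsp_loss(mixed, labels, depths))  # True
```

Because the denominator in Eq. 4 excludes j = i, individual terms, and hence the loss, can become negative when positives dominate; only relative values matter for optimization.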
3 Experiments
We compare the proposed method with different contrastive and non-contrastive methods that either use no meta-data (SimCLR [5], BYOL [10]), or leverage only discrete labels (SupCon [13]), or continuous labels (depth-Aware [8]). The proposed method is the only one that simultaneously takes into account both discrete and continuous labels. In all experiments, we work with 2D slices rather than 3D volumes. Abdominal CT scans are anisotropic in the depth direction, and 3D patch-based or downsampling methods respectively limit the spatial context or the resolution, which strongly impacts the cirrhosis diagnosis, notably based on the irregularity of the liver contours. Moreover, the large batch sizes necessary in contrastive learning cannot be handled in 3D due to limited GPU memory.
3.1 Datasets
Three datasets of abdominal CT images are used in this study. One dataset is used for contrastive pretraining, and the other two for evaluation. All images have a 512 × 512 size, and we clip the intensity values between −100 and 400 HU.
$D_{radio}$. First, $D_{radio}$ contains 2,799 CT scans of patients in portal venous phase with a radiological (weak) annotation, i.e., provided by a radiologist, indicating four different stages of cirrhosis ($y_{radio}$): no cirrhosis, mild cirrhosis, moderate cirrhosis and severe cirrhosis. The respective numbers are 1880, 385, 415 and 119. $y_{radio}$ is used as the discrete label y during pre-training.
$D^1_{histo}$. It contains 106 CT scans from different patients in portal venous phase, with an identified histopathological status (METAVIR score) obtained by a histological analysis, designated as $y^1_{histo}$. It corresponds to absent fibrosis (F0), mild fibrosis (F1), significant fibrosis (F2), severe fibrosis (F3) and cirrhosis (F4). This score is then binarized to indicate the absence or presence of advanced fibrosis [14]: F0/F1/F2 (N = 28) vs. F3/F4 (N = 78).
$D^2_{histo}$. This is the public LIHC dataset from The Cancer Genome Atlas [9], which presents a histological score, the Ishak score, designated as $y^2_{histo}$, that differs from the METAVIR score present in $D^1_{histo}$. This score is also distributed through five labels: No Fibrosis, Portal Fibrosis, Fibrous Septa, Nodular Formation and Incomplete Cirrhosis, and Established Cirrhosis. Similarly to the METAVIR score in $D^1_{histo}$, we also binarize the Ishak score, as proposed in [16,20], which results in two cohorts of 34 healthy and 15 pathological patients.
In all datasets, we select the slices based on the liver segmentation of the patients. To gain in precision, we keep the 70% most central slices with respect to liver segmentation maps, obtained manually in $D_{radio}$ and automatically for $D^1_{histo}$ and $D^2_{histo}$ using a U-Net architecture [18] pretrained on $D_{radio}$. The pretraining dataset $D_{radio}$ has an average slice spacing of 3.23 mm, with a standard deviation of 1.29 mm. Along the x and y axes, the pixel spacing is 0.79 mm on average, with a standard deviation of 0.10 mm.
3.2 Architecture and Optimization
Backbones. We work with two different backbones in this paper: TinyNet and ResNet-18 [12]. TinyNet is a small encoder with 1.1M parameters, inspired by [24], with five convolutional layers, a representation space (for downstream tasks) of size 256 and a latent space (after a projection head of two dense layers) of size 64. In comparison, ResNet-18 has 11.2M parameters, a representation space of dimension 512 and a latent space of dimension 128. More details and an illustration of TinyNet are available in the supplementary material, as well as a full illustration of the algorithm flow.
Data Augmentation, Sampling and Optimization. CL methods [5,10,11] require strong data augmentations on input images in order to strengthen the association between positive samples [22]. In our work, we leverage three types of augmentations: rotations, crops and flips. Data augmentations are computed on the GPU using the Kornia library [17]. During inference, we remove the augmentation module to keep only the original input images. For sampling, inspired by [4], we propose a strategy well-adapted to contrastive learning in 2D medical imaging. We first sample N patients, where N is the batch size, in a balanced way with respect to the radiological/histological classes; namely, we have roughly the same number of subjects per class. Then, we randomly select only one slice per subject. In this way, we maximize the slice heterogeneity within each batch. We also use the same sampling strategy for the classification baselines. For $D^2_{histo}$, which has fewer patients than the batch size, we use a balanced sampling strategy with respect to the radiological/histological classes, without requiring one slice per patient in the batch. As we work with 2D slices rather than 3D volumes, we compute the average probability per patient of having the pathology. The evaluation results presented later are based on this patient-level aggregated prediction. Finally, we run our experiments on a Tesla V100 GPU with 16 GB of memory and 6 CPU cores, and we use the PyTorch-Lightning library to implement our models.
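The batch-construction strategy described above (balanced patients, then one random slice per patient) can be sketched as below. The data layout (a mapping from a patient identifier to its class and slice count), the even per-class quota, and the function name are hypothetical choices for illustration, not the authors' implementation; the relaxed variant used for $D^2_{histo}$ is not covered.

```python
import random
from collections import defaultdict


def sample_balanced_batch(patients, batch_size, rng):
    """Draw batch_size distinct patients, balanced across classes, and one
    random slice index per patient.

    patients: dict patient_id -> (class_label, num_slices)   (hypothetical layout)
    Returns a shuffled list of (patient_id, slice_index, class_label).
    """
    by_class = defaultdict(list)
    for pid, (label, _) in patients.items():
        by_class[label].append(pid)
    classes = sorted(by_class)
    # Spread the batch as evenly as possible over the classes.
    quota = {c: batch_size // len(classes) for c in classes}
    for c in classes[: batch_size % len(classes)]:
        quota[c] += 1

    batch = []
    for c in classes:
        if quota[c] > len(by_class[c]):
            raise ValueError(f"class {c!r} has fewer patients than its quota")
        for pid in rng.sample(by_class[c], quota[c]):  # distinct patients
            _, num_slices = patients[pid]
            batch.append((pid, rng.randrange(num_slices), c))  # one slice each
    rng.shuffle(batch)
    return batch


rng = random.Random(0)
patients = {f"p{k}": (k % 4, 30) for k in range(80)}  # 4 classes x 20 patients
batch = sample_balanced_batch(patients, 64, rng)      # 16 patients per class
```

Drawing one slice per distinct patient is what maximizes slice heterogeneity: no two slices in a batch can come from the same volume.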
All models share the same data augmentation module, with a batch size of B = 64 and a fixed number of epochs $n_{epochs} = 200$. For all experiments, we fix a learning rate (LR) of $\alpha = 10^{-4}$ and a weight decay of $\lambda = 10^{-4}$. We add a cosine decay learning rate scheduler [15] to prevent over-fitting. For BYOL, we initialize the moving average decay at 0.996.
Evaluation Protocol. We first pretrain the backbone networks on $D_{radio}$ using all previously listed contrastive and non-contrastive methods. Then, we train a regularized logistic regression on the frozen representations of the datasets $D^1_{histo}$ and $D^2_{histo}$, using a stratified 5-fold cross-validation. As a first baseline (supervised), we train a classification algorithm from scratch for each dataset, $D^1_{histo}$ and $D^2_{histo}$, using both backbone encoders and the same 5-fold cross-validation strategy. As a second baseline (random), we train a regularized logistic regression on representations obtained with a random initialization. Finally, we report the cross-validated results for each model on the aggregated dataset $D^{1+2}_{histo} = D^1_{histo} + D^2_{histo}$.
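Since predictions are made per 2D slice but evaluated per patient, a minimal sketch of the patient-level aggregation (mean slice probability, as described above) followed by a rank-based ROC AUC could look as follows. The function names are hypothetical; in practice a library routine such as scikit-learn's `roc_auc_score` would typically replace the hand-written AUC.

```python
from collections import defaultdict


def patient_probabilities(slice_preds):
    """Average slice-level probabilities per patient.

    slice_preds: iterable of (patient_id, probability) pairs.
    """
    sums, counts = defaultdict(float), defaultdict(int)
    for pid, p in slice_preds:
        sums[pid] += p
        counts[pid] += 1
    return {pid: sums[pid] / counts[pid] for pid in sums}


def roc_auc(labels, scores):
    """Rank-based (Mann-Whitney) AUC: fraction of positive/negative pairs
    ranked correctly, with ties counting 1/2."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs both classes")
    wins = 0.0
    for sp in pos:
        for sn in neg:
            wins += 1.0 if sp > sn else 0.5 if sp == sn else 0.0
    return wins / (len(pos) * len(neg))


slice_preds = [("a", 0.2), ("a", 0.4), ("b", 0.9), ("b", 0.7), ("c", 0.1)]
probs = patient_probabilities(slice_preds)  # a: ~0.3, b: ~0.8, c: 0.1
patient_labels = {"a": 0, "b": 1, "c": 0}
pids = sorted(probs)
print(roc_auc([patient_labels[p] for p in pids], [probs[p] for p in pids]))  # 1.0
```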
Table 1. Resulting 5-fold cross-validation AUCs. For each encoder, best results are in bold, second top results are underlined. * = We use the pretrained weights from ImageNet with ResNet-18 and run a logistic regression on the frozen representations.
Backbone | Pretraining method | Weak labels | Depth pos. | $D^1_{histo}$ (N=106) | $D^2_{histo}$ (N=49) | $D^{1+2}_{histo}$ (N=155)
TinyNet | Supervised | ✗ | ✗ | 0.79 (±0.05) | 0.65 (±0.25) | 0.71 (±0.04)
TinyNet | None (random) | ✗ | ✗ | 0.64 (±0.10) | 0.75 (±0.13) | 0.73 (±0.06)
TinyNet | SimCLR | ✗ | ✗ | 0.75 (±0.08) | 0.88 (±0.16) | 0.76 (±0.11)
TinyNet | BYOL | ✗ | ✗ | 0.75 (±0.09) | 0.95 (±0.07) | 0.77 (±0.08)
TinyNet | SupCon | ✓ | ✗ | 0.76 (±0.09) | 0.93 (±0.07) | 0.72 (±0.06)
TinyNet | depth-Aware | ✗ | ✓ | 0.80 (±0.13) | 0.81 (±0.08) | 0.77 (±0.08)
TinyNet | Ours | ✓ | ✓ | 0.84 (±0.12) | 0.91 (±0.11) | 0.79 (±0.11)
ResNet-18 | Supervised | ✗ | ✗ | 0.77 (±0.10) | 0.56 (±0.29) | 0.72 (±0.08)
ResNet-18 | None (random) | ✗ | ✗ | 0.69 (±0.19) | 0.73 (±0.12) | 0.68 (±0.09)
ResNet-18 | ImageNet* | ✗ | ✗ | 0.72 (±0.17) | 0.76 (±0.04) | 0.66 (±0.10)
ResNet-18 | SimCLR | ✗ | ✗ | 0.79 (±0.09) | 0.82 (±0.14) | 0.79 (±0.08)
ResNet-18 | BYOL | ✗ | ✗ | 0.78 (±0.09) | 0.77 (±0.11) | 0.78 (±0.08)
ResNet-18 | SupCon | ✓ | ✗ | 0.69 (±0.07) | 0.69 (±0.13) | 0.76 (±0.12)
ResNet-18 | depth-Aware | ✗ | ✓ | 0.83 (±0.07) | 0.82 (±0.11) | 0.80 (±0.07)
ResNet-18 | Ours | ✓ | ✓ | 0.84 (±0.07) | 0.85 (±0.10) | 0.84 (±0.07)
Fig. 2. Projections of the ResNet-18 representation vectors of 10 randomly selected subjects of $D^1_{histo}$ onto the first two modes of a PCA. Each dot represents a 2D slice. The color gradient refers to different depth positions. Red = cirrhotic cases. Blue = healthy subjects.
4 Results and Discussion
We present in Table 1 the results of all our experiments. For each of them, we report whether the pretraining method integrates the weak label meta-data, the depth spatial encoding, or both, which is the core of our method. First, our method outperforms all other pretraining methods on $D^1_{histo}$ and $D^{1+2}_{histo}$, which are the two datasets with the most patients. For the latter, the proposed method surpasses the second-best pretraining method, depth-Aware, by 4% (0.84 vs. 0.80 AUC with ResNet-18). For $D^1_{histo}$, WSP (ours) provides the best AUC score whatever the backbone used. For the second dataset, $D^2_{histo}$, our method is on par with BYOL and SupCon when using a small encoder and outperforms the other methods when using a larger backbone.
To illustrate the impact of the proposed method, we report in Fig. 2 the projections of the ResNet-18 representation vectors of 10 randomly selected subjects of $D^1_{histo}$ onto the first two modes of a PCA. The representation space of our method is the only one in which both the diagnostic label (not available during pretraining) and the depth position are correctly integrated. Indeed, there is a clear separation between slices of different classes (healthy at the bottom and cirrhotic cases at the top) and, at the same time, the depth position appears to be encoded along the x-axis, from left to right. SupCon performs well on the training set of $D_{radio}$ (figure available in the supplementary material), as well as on $D^2_{histo}$ with TinyNet, but it generalizes poorly to $D^1_{histo}$ and $D^{1+2}_{histo}$. The depth-Aware method correctly encodes the depth position but not the diagnostic class label. To assess the clinical performance of the pretraining methods, we also compute the balanced accuracy scores (bACC) of the trained classifiers, which are compared in Table 2 to the bACC achieved by radiologists who were asked to visually assess the presence or absence of cirrhosis for the N=106 cases of $D^1_{histo}$.
The reported bACC values correspond to the best scores among those obtained with the TinyNet and ResNet-18 encoders. Radiologists achieved a bACC of 82% with respect to the histological reference. The two best-performing methods surpassed this score: depth-Aware and the proposed WSP approach, improving on the radiologists' score by 2% and 3%, respectively, suggesting that including 3D information (depth) in the pretraining phase was beneficial.
Table 2. Comparison of the pretraining methods with a binary radiological annotation for cirrhosis on $D^1_{histo}$. Best results are in bold, second top results are underlined.
Pretraining method | bACC models
Supervised | 0.78 (±0.04)
None (random) | 0.71 (±0.13)
ImageNet | 0.74 (±0.13)
SimCLR | 0.78 (±0.08)
BYOL | 0.77 (±0.04)
SupCon | 0.77 (±0.10)
depth-Aware | 0.84 (±0.04)
Ours | 0.85 (±0.09)
bACC radiologists (single value, shared by all rows): 0.82
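Balanced accuracy is the mean of the per-class recalls (sensitivity and specificity in the binary case), which matters here because $D^1_{histo}$ is imbalanced (28 vs. 78 patients). A minimal sketch, with a hypothetical function name:

```python
def balanced_accuracy(y_true, y_pred):
    """Mean per-class recall; for binary labels, (sensitivity + specificity) / 2."""
    classes = sorted(set(y_true))
    recalls = []
    for c in classes:
        idx = [k for k, y in enumerate(y_true) if y == c]
        recalls.append(sum(1 for k in idx if y_pred[k] == c) / len(idx))
    return sum(recalls) / len(recalls)


# A trivial classifier predicting "advanced fibrosis" for all 106 patients of
# D1_histo reaches a plain accuracy of 78/106 (about 0.74) but a bACC of only 0.5.
y_true = [0] * 28 + [1] * 78
y_pred = [1] * 106
print(balanced_accuracy(y_true, y_pred))  # 0.5
```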
5 Conclusion
In this work, we proposed a novel kernel-based contrastive learning method that leverages both continuous and discrete meta-data for pretraining. We tested it on a challenging clinical application, cirrhosis prediction, using three different datasets, including the LIHC public dataset. To the best of our knowledge, this is the first time that a pretraining strategy combining different kinds of meta-data has been proposed for such an application. Our results were compared to other state-of-the-art CL methods well-adapted for cirrhosis prediction. The pretraining methods were also compared visually, using a 2D projection of the representation vectors onto the first two PCA modes. Results showed that our method organizes the representation space in line with the proposed theory, which may explain its higher performance in the experiments. As future work, it would be interesting to adapt our kernel method to non-contrastive methods, such as SimSiam [7], BYOL [10] or Barlow Twins [25], which need smaller batch sizes and have shown better performance in computer vision tasks. In terms of application, our method could easily be translated to other medical problems, such as pancreatic cancer prediction using the presence of intrapancreatic fat, diabetes mellitus or obesity as discrete meta-labels.
Acknowledgments. This work was supported by Région Ile-de-France (ChoTherIA project) and ANRT (CIFRE #2021/1735).
Compliance with Ethical Standards. This research study was conducted retrospectively using human data collected from various medical centers, whose Ethics Committees granted their approval. Data were de-identified and processed according to all applicable privacy laws and the Declaration of Helsinki.
References
1. Azizi, S., et al.: Big self-supervised models advance medical image classification. In: 2021 IEEE/CVF International Conference on Computer Vision (ICCV), pp. 3458–3468 (2021)
2. Barbano, C.A., Dufumier, B., Duchesnay, E., Grangetto, M., Gori, P.: Contrastive learning for regression in multi-site brain age prediction. In: IEEE ISBI (2022)
3. Barbano, C.A., Dufumier, B., Tartaglione, E., Grangetto, M., Gori, P.: Unbiased Supervised Contrastive Learning. In: ICLR (2023)
4. Chaitanya, K., Erdil, E., Karani, N., Konukoglu, E.: Contrastive learning of global and local features for medical image segmentation with limited annotations. In: Larochelle, H., Ranzato, M., Hadsell, R., Balcan, M., Lin, H. (eds.) Advances in Neural Information Processing Systems, vol. 33, pp. 12546–12558. Curran Associates, Inc. (2020)
5. Chen, T., Kornblith, S., Norouzi, M., et al.: A simple framework for contrastive learning of visual representations. In: 37th International Conference on Machine Learning (ICML) (2020)
6. Chen, T., Kornblith, S., Swersky, K., et al.: Big self-supervised models are strong semi-supervised learners. In: NeurIPS (2020)
7. Chen, X., He, K.: Exploring simple Siamese representation learning. In: 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 15745–15753 (2020)
8. Dufumier, B., et al.: Contrastive learning with continuous proxy meta-data for 3D MRI classification. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12902, pp. 58–68. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87196-3_6
9. Erickson, B.J., Kirk, S., Lee, et al.: Radiology data from the cancer genome atlas colon adenocarcinoma [TCGA-COAD] collection. (2016)
10. Grill, J.B., Strub, F., Altché, F., et al.: Bootstrap your own latent - a new approach to self-supervised learning. In: Larochelle, H., Ranzato, M., Hadsell, R., Balcan, M., Lin, H. (eds.) Advances in Neural Information Processing Systems, vol. 33, pp. 21271–21284. Curran Associates, Inc. (2020)
11. He, K., Fan, H., Wu, Y., Xie, S., Girshick, R.: Momentum contrast for unsupervised visual representation learning. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 9726–9735 (2020)
12. He, K., Zhang, X., Ren, S., et al.: Deep residual learning for image recognition. In: IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770–778 (2016)
13. Khosla, P., Teterwak, P., Wang, C., et al.: Supervised contrastive learning. Adv. Neural. Inf. Process. Syst. 33, 18661–18673 (2020)
14. Li, Q., Yu, B., Tian, X., Cui, X., Zhang, R., Guo, Q.: Deep residual nets model for staging liver fibrosis on plain CT images. Int. J. Comput. Assist. Radiol. Surg. 15(8), 1399–1406 (2020). https://doi.org/10.1007/s11548-020-02206-y
15. Loshchilov, I., Hutter, F.: SGDR: Stochastic gradient descent with warm restarts. In: International Conference on Learning Representations (2017)
16. Mohamadnejad, M., et al.: Histopathological study of chronic hepatitis B: a comparative study of Ishak and METAVIR scoring systems. Int. J. Organ Transp. Med. 1 (2010)
17. Riba, E., Mishkin, D., Ponsa, D., Rublee, E., Bradski, G.: Kornia: an open source differentiable computer vision library for PyTorch. In: Winter Conference on Applications of Computer Vision (2020)
18. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
19. Sarfati, E., Bone, A., Rohe, M.M., Gori, P., Bloch, I.: Learning to diagnose cirrhosis from radiological and histological labels with joint self and weakly-supervised pretraining strategies. In: IEEE ISBI. Cartagena de Indias, Colombia (Apr 2023)
20. Shiha, G., Zalata, K.: Ishak versus METAVIR: Terminology, convertibility and correlation with laboratory changes in chronic hepatitis C. In: Takahashi, H. (ed.) Liver Biopsy, chap. 10. IntechOpen, Rijeka (2011)
21. Taleb, A., Kirchler, M., Monti, R., Lippert, C.: Contig: Self-supervised multimodal contrastive learning for medical imaging with genetics. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 20908–20921 (June 2022)
22. Wang, X., Qi, G.J.: Contrastive learning with stronger augmentations. CoRR abs/2104.07713 (2021)
23. Wen, J., et al.: Convolutional neural networks for classification of Alzheimer’s disease: overview and reproducible evaluation. Med. Image Anal. 63, 101694 (2020)
24. Yin, Y., Yakar, D., Dierckx, R.A.J.O., Mouridsen, K.B., Kwee, T.C., de Haas, R.J.: Liver fibrosis staging by deep learning: a visual-based explanation of diagnostic decisions of the model. Eur. Radiol. 31(12), 9620–9627 (2021). https://doi.org/10.1007/s00330-021-08046-x
25. Zbontar, J., Jing, L., Misra, I., LeCun, Y., Deny, S.: Barlow twins: Self-supervised learning via redundancy reduction. In: International Conference on Machine Learning (2021)
26. Zeng, D., et al.: Positional contrastive learning for volumetric medical image segmentation. In: MICCAI, pp. 221–230. Springer-Verlag, Berlin, Heidelberg (2021)
27. Zhang, P., Wang, F., Zheng, Y.: Self supervised deep representation learning for fine-grained body part recognition. In: 2017 IEEE 14th International Symposium on Biomedical Imaging (ISBI 2017), pp. 578–582 (2017)
28. Zhou, Z., et al.: Models genesis: generic autodidactic models for 3D medical image analysis. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11767, pp. 384–393. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32251-9_42
29. Zhuang, X., Li, Y., Hu, Y., Ma, K., Yang, Y., Zheng, Y.: Self-supervised feature learning for 3D medical images by playing a Rubik’s cube. In: MICCAI (2019)
Inter-slice Consistency for Unpaired Low-Dose CT Denoising Using Boosted Contrastive Learning
Jie Jing1, Tao Wang1, Hui Yu1, Zexin Lu1, and Yi Zhang2(B)
1 College of Computer Science, Sichuan University, Chengdu 610065, China
2 School of Cyber Science and Engineering, Sichuan University, Chengdu 610065, China
[email protected]
Abstract. The research field of low-dose computed tomography (LDCT) denoising is primarily dominated by supervised learning-based approaches, which require accurately registered pairs of LDCT images and the corresponding normal-dose CT (NDCT) images. However, since obtaining well-paired data is not always feasible in real clinical practice, unsupervised methods have become increasingly popular for LDCT denoising. One commonly used method is CycleGAN, but its training process is memory-intensive and may suffer from mode collapse. To address these limitations, we propose a novel unsupervised method based on boosted contrastive learning (BCL), which requires only a single generator. Furthermore, the constraints of computational power and memory capacity often force most existing approaches to focus solely on individual slices, leading to inconsistent results between consecutive slices. Our proposed BCL-based model integrates inter-slice features while keeping the computational cost at an acceptable level, comparable to that of most slice-based methods. Two modifications are introduced to the original contrastive learning method: weight optimization for positive-negative pairs and constraints on difference invariants. Experiments demonstrate that our method outperforms several existing state-of-the-art supervised and unsupervised methods in both qualitative and quantitative metrics.
Keywords: Low-dose computed tomography · unsupervised learning · image denoising · machine learning
1 Introduction
Computed tomography (CT) is a common tool for medical diagnosis, but its increased usage has raised concerns about the possible risks of excessive radiation exposure. The well-known ALARA (as low as reasonably achievable) principle [20] is widely adopted to reduce exposure through strategies such as sparse sampling and tube flux reduction. However, reducing the radiation dose degrades image quality and can jeopardize subsequent diagnoses. The algorithms developed to address this issue can be roughly categorized into sinogram-domain filtration [16], iterative reconstruction [2,9], and image post-processing [1,8,14]. Recently, deep learning (DL) has been introduced for LDCT image restoration. Convolutional neural networks (CNNs) for image super-resolution [7] outperformed most conventional techniques and were subsequently applied to LDCT [6]. Deep image prior (DIP) [21] is an unsupervised image restoration technique that leverages the inherent ability of untrained networks to capture image statistics. Other methods that do not require clean images have also been used in this field [11,23]. Various network architectures have been proposed, such as RED [5] and MAP-NN [18]. The choice of loss function also significantly affects model performance: a perceptual loss [12] based on a pretrained VGG network [19] was proposed to mitigate the over-smoothing caused by the MSE loss. Most DL techniques for LDCT denoising are supervised, but unsupervised frameworks that do not need paired training data, such as GANs [3,10,26], invertible networks [4], and CUT [25], have also been applied to LDCT [13,15,24].

This study presents a novel unsupervised framework for LDCT denoising that uses contrastive learning (CL) and does not require paired data. Our main contributions are threefold. First, instead of the CycleGAN architecture employed by most unpaired frameworks, we build the training framework on contrastive learning; as a result, training is more stable and computationally lighter. Second, our approach can be combined with almost any end-to-end image translation network, which makes it highly flexible. Third, the proposed inter-slice consistency loss keeps the output quality stable across slices, whereas most slice-based methods exhibit inter-slice instability. In this respect, our model outperforms almost all compared models, which makes it well suited for LDCT denoising; experimental evidence for this point is presented in Sect. 3.3.

This work was supported in part by the National Natural Science Foundation of China under Grant 62271335; in part by the Sichuan Science and Technology Program under Grant 2021JDJQ0024; and in part by the Sichuan University “From 0 to 1” Innovative Research Program under Grant 2022SCUH0016.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 238–247, 2023. https://doi.org/10.1007/978-3-031-43907-0_23

Unpaired Low-Dose CT Denoising
2 Method
LDCT image denoising can be expressed as a noise-reduction problem in the image domain, $\hat{x} = f(x)$, where $\hat{x}$ and $x$ denote the denoised output and the corresponding LDCT image, respectively, and $f$ represents the denoising function. Rather than directly denoising LDCT images, an encoder-decoder model is used to extract important features from the LDCT images and predict the corresponding NDCT images. Most CNN-based LDCT denoising models are based on supervised learning and require each LDCT image and its perfectly paired NDCT image to learn $f$. However, such paired data are often infeasible to obtain in clinical practice. Some unsupervised models, including CUT and CycleGAN, relax this requirement and can instead be trained with unpaired data.
J. Jing et al.

2.1 Contrastive Learning for Unpaired Data
The task of LDCT image denoising can be viewed as an image translation process from LDCT to NDCT. CUT provides a powerful framework for training image-to-image translation models. Its main idea is to use contrastive learning, aided by an adversarial loss, to enhance feature extraction. The key principle of contrastive learning is to construct positive and negative pairs of samples so that the model learns strong feature representations. The contrastive loss can be formulated as:

$$\ell(v, v^+, v^-) = -\log\left[\frac{\exp(v \cdot v^+/\tau)}{\exp(v \cdot v^+/\tau) + \sum_{n=1}^{N} \exp(v \cdot v_n^-/\tau)}\right], \tag{1}$$
where $v$, $v^+$, and $v^-$ denote the anchor, positive, and negative samples, respectively, $N$ is the number of negative pairs, and $\tau$ is the temperature factor, which is set to 0.07 in this paper. Our generator $G$ consists of two parts, an encoder $E$ and a decoder. A simple MLP $H$ is used to project the features extracted by the encoder. The total loss of CUT for image translation is defined as:

$$\mathcal{L} = \mathcal{L}_{GAN}(G, D, X, Y) + \lambda_1 \mathcal{L}_{PatchNCE}(G, H, X) + \lambda_2 \mathcal{L}_{PatchNCE}(G, H, Y), \tag{2}$$

where $D$ denotes the discriminator. $X$ represents the input images, and $\mathcal{L}_{PatchNCE}(G, H, X)$ applies contrastive learning in the source domain (noisy images). $Y$ denotes the images in the target domain, i.e., NDCT images in this paper, and $\mathcal{L}_{PatchNCE}(G, H, Y)$ applies contrastive learning in the target domain. As noted in a previous study [25], this term plays a role similar to that of the identity loss in CycleGAN. In this work, $\lambda_1$ and $\lambda_2$ are both set to 1.

Since CT images are three-dimensional, additional negative pairs can be drawn from different slices. Our strategy for constructing positive and negative pairs is illustrated in Fig. 1: we select two negative patches from the same slice as the anchor, one from a preceding slice, and one from a subsequent slice. Note that these slices are not directly adjacent to the anchor slice, since neighboring slices are nearly identical. As in most contrastive learning methods, we use cosine similarity to compute feature similarity.

2.2 Contrastive Learning for Inter-slice Consistency
Owing to computational and memory constraints, most LDCT denoising methods can only operate on individual slices, which results in loss of detail across slices. While 3D models can mitigate this issue to some degree, they incur significant computational costs and are prone to model collapse during training, which leads to long training times. Additionally, most methods fail to maintain structural consistency for structures (e.g., bronchi and vessels) that appear continuously across several adjacent slices.
Fig. 1. The construction of training sample pairs for contrastive learning. The generator G is composed of the encoder E and the decoder. The anchor is represented by the red box, while the negative patches are indicated by blue boxes. Two of these negative patches come from the same slice but have different locations, while the other two come from different slices but have the same locations. The green box represents the positive patch, which is located in the same position as the anchor shown in the generated image. (Color figure online)
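The pair construction of Fig. 1 and the contrastive loss of Eq. (1) can be illustrated with the minimal NumPy sketch below. It is not the authors' implementation: the feature vectors are random stand-ins for the projected encoder features $H(E(\cdot))$, the helper names `sample_pair_coords` and `info_nce` are hypothetical, and the slice gap of 2 for the cross-slice negatives is an assumed value (the text only states that directly adjacent slices are avoided). Features are L2-normalized so that the dot product equals the cosine similarity.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-8):
    """Scale vectors to unit length so that dot products equal cosine similarity."""
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)


def info_nce(v, v_pos, v_negs, tau=0.07):
    """Eq. (1) for one anchor: v and v_pos have shape (D,), v_negs has shape (N, D)."""
    v, v_pos, v_negs = l2_normalize(v), l2_normalize(v_pos), l2_normalize(v_negs)
    logits = np.concatenate([[v @ v_pos], v_negs @ v]) / tau  # positive at index 0
    m = logits.max()  # log-sum-exp shift for numerical stability
    log_denominator = m + np.log(np.exp(logits - m).sum())
    return float(log_denominator - logits[0])


def sample_pair_coords(num_slices, height, width, patch, gap=2, rng=None):
    """Hypothetical helper returning (slice, row, col) patch corners as in Fig. 1.

    anchor:    patch in input slice t
    positive:  same location in the generated (denoised) slice t
    negatives: two other locations in slice t, plus the anchor location in
               slices t - gap and t + gap (gap >= 2 skips the nearly identical neighbors)
    """
    rng = np.random.default_rng(rng)
    t = int(rng.integers(gap, num_slices - gap))

    def random_corner():
        return (int(rng.integers(0, height - patch + 1)),
                int(rng.integers(0, width - patch + 1)))

    y, x = random_corner()
    negatives = []
    while len(negatives) < 2:  # same slice, different (possibly overlapping) location
        corner = random_corner()
        if corner != (y, x):
            negatives.append((t, *corner))
    negatives += [(t - gap, y, x), (t + gap, y, x)]  # other slices, same location
    return (t, y, x), (t, y, x), negatives


rng = np.random.default_rng(0)
anchor, positive, negatives = sample_pair_coords(10, 64, 64, patch=16, rng=rng)
feats = rng.normal(size=(6, 128))  # stand-ins for H(E(patch)) features
print(anchor, negatives, info_nce(feats[0], feats[1], feats[2:]))
```

The anchor and positive share coordinates because they are read from the input and the generated image, respectively; only the source image differs.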
To address this inter-slice inconsistency, we design an inter-slice consistency loss based on contrastive learning. This loss helps maintain structural consistency between slices and thereby improves the overall denoising performance. As illustrated in Fig. 2, we begin by randomly selecting a patch location and extracting the patches at that location from both the input (LDCT) and the generated denoised result. These patches are passed through the encoder $E$ to obtain the feature representation of each patch. Next, we subtract the features of each inter-slice pair; the result can be interpreted as the feature difference between slices. We assume that the feature difference between the same pair of slices should be similar before and after denoising, which is formulated as follows:

$$H(E(P(X_t))) - H(E(P(X_{t+1}))) = H(E(P(G(X_t)))) - H(E(P(G(X_{t+1})))), \tag{3}$$

where $P$ denotes the patch selection function and $X_t$ denotes the $t$-th slice. A good denoising generator should keep the feature differences of the same slice pair similar while keeping those of different slice pairs distinct. Using contrastive learning, we treat the former as a positive pair and the latter as a negative pair. After the cosine similarity of each pair is computed, a softmax is applied with target 1 for the positive pair and 0 for the negative pairs. Whereas the original contrastive learning operates on pairs of patches, we apply the technique to feature differences, which stabilizes the features and improves the consistency between slices.
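A minimal NumPy sketch of the inter-slice consistency loss follows. The pairing is our assumption rather than the authors' code: for three consecutive slices t, t+1, t+2 (as in Fig. 2), the input feature difference of slices (t, t+1) is the anchor, the output difference of the same slice pair is the positive, and the output difference of the pair (t+1, t+2) is the single negative. The arrays `h_in` and `h_out` are hypothetical stand-ins for $H(E(P(X_t)))$ and $H(E(P(G(X_t))))$ stacked over the three slices.

```python
import numpy as np


def cosine(a, b, eps=1e-8):
    """Cosine similarity; returns ~0 if either vector is all zeros."""
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + eps))


def inter_slice_consistency_loss(h_in, h_out, tau=0.07):
    """h_in, h_out: arrays of shape (3, D) holding the projected features of the
    same patch location in slices t, t+1, t+2 of the input and of the output."""
    d_anchor = h_in[0] - h_in[1]   # input feature difference for slices (t, t+1)
    d_pos = h_out[0] - h_out[1]    # output difference for the same slice pair
    d_neg = h_out[1] - h_out[2]    # output difference for a different slice pair
    logits = np.array([cosine(d_anchor, d_pos), cosine(d_anchor, d_neg)]) / tau
    # softmax cross-entropy with target 1 for the positive and 0 for the negative
    m = logits.max()
    log_probs = logits - (m + np.log(np.exp(logits - m).sum()))
    return float(-log_probs[0])


# Output differences that mirror the input differences give a near-zero loss.
h = np.array([[1.0, 0.0], [0.0, 0.0], [1.0, 0.0]])
print(inter_slice_consistency_loss(h, h))
```

Working on differences rather than raw features means the loss is indifferent to a constant feature offset shared by all slices, which matches the invariance stated in Eq. (3).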
Fig. 2. Contrastive learning utilized for stabilizing inter-slice features. Patches are extracted from the same location in three consecutive slices.
2.3 Boosted Contrastive Learning
Original contrastive learning approaches treat every positive and negative pair equally. However, in CT images, some patches may be very similar to others (e.g., patches from the same organ), while others may be completely different. Therefore, assigning the same weight to all pairs may not be appropriate. The authors of [25] demonstrated that reweighting the pairs can significantly improve the performance of contrastive learning. For our inter-slice consistency loss, only one positive and one negative pair can be generated at a time, so reweighting is unnecessary. However, we include additional negative pairs in the PatchNCE loss for unpaired translation, which makes reweighting between pairs more critical than in the original CUT model. As a result, Eq. (1) is updated as follows:

$$\ell(v, v^+, v^-) = -\log\left[\frac{\exp(v \cdot v^+/\tau)}{\exp(v \cdot v^+/\tau) + \sum_{n=1}^{N} w_n \exp(v \cdot v_n^-/\tau)}\right], \tag{4}$$
where $w_n$ is the weight factor of the $n$-th negative patch. According to [25], “easy weighting”, which assigns higher weights to easy negative samples (i.e., samples that are easy to distinguish from the anchor), is more effective for unpaired tasks. This finding runs counter to the common intuition that hard negatives are the more informative ones; nonetheless, our experiments confirm that it holds in our scenario. The reweighting we employ is defined as follows:

$$w_n = \frac{\exp\left((1 - v \cdot v_n^-)/\tau\right)}{\sum_{j=1, j \neq n}^{N} \exp\left((1 - v \cdot v_j^-)/\tau\right)}. \tag{5}$$
In summary, the less similar a negative patch is to the anchor, the easier it is to distinguish, and the larger the weight its pair receives during training.
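Equations (4) and (5) can be sketched in NumPy as below; this is our illustration, not the authors' code, and the helper names are hypothetical. Inputs are L2-normalized so that dot products are cosine similarities. In a gradient-based framework the weights $w_n$ would typically be treated as constants (detached from the gradient); that is an assumption here, since NumPy computes no gradients. Note that Eq. (5) needs at least two negatives, because each denominator excludes the $n$-th term.

```python
import numpy as np


def unit(x):
    """L2-normalize along the last axis."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def easy_weights(v, v_negs, tau=0.07):
    """Eq. (5): w_n = exp((1 - v.v_n)/tau) / sum_{j != n} exp((1 - v.v_j)/tau).

    Less similar (easier) negatives receive larger weights."""
    e = np.exp((1.0 - v_negs @ v) / tau)
    return e / (e.sum() - e)  # each denominator excludes the n-th term


def boosted_info_nce(v, v_pos, v_negs, tau=0.07):
    """Eq. (4): the n-th negative term in the denominator is scaled by w_n."""
    v, v_pos, v_negs = unit(v), unit(v_pos), unit(v_negs)
    w = easy_weights(v, v_negs, tau)
    pos = np.exp(v @ v_pos / tau)
    neg = np.sum(w * np.exp(v_negs @ v / tau))
    return float(-np.log(pos / (pos + neg)))


e = np.eye(3)
# negatives that are identical to, orthogonal to, and opposite to the anchor
print(easy_weights(e[0], np.stack([e[0], e[1], -e[0]])))
```

With all negatives equally similar to the anchor, every weight reduces to $1/(N-1)$, so the reweighting only changes the loss when the negatives differ in difficulty.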
3 Experiments

3.1 Dataset and Training Details
Although our method only requires unpaired data for training, many of the compared methods rely on paired NDCT images. We therefore used the “NIH-AAPM-Mayo Clinic Low Dose CT Grand Challenge” dataset provided by the Mayo Clinic [17], which offers paired LDCT-NDCT images. The image data of five individuals were used as the training set, and the data of two other individuals were used as the test set. The model parameters were initialized from a zero-mean Gaussian distribution with a standard deviation of $10^{-2}$. The learning rate of the optimizer was set to $10^{-4}$ and halved every 5 epochs, for 20 epochs in total. The experiments were implemented in Python on a server with an RTX 3090 GPU. Two metrics, the peak signal-to-noise ratio (PSNR) and the structural similarity index measure (SSIM) [22], were used to quantitatively evaluate image quality.

3.2 Comparison of Different Methods
To demonstrate the denoising performance of our model, we compared it with various types of denoising methods: unsupervised methods that only use LDCT data, fully supervised methods that use perfectly registered LDCT-NDCT pairs, and semi-supervised methods, including CycleGAN and CUT, that use unpaired data. A representative slice processed by the different methods is shown in Fig. 3, displayed with a window center of 40 HU and a window width of 400 HU. Our framework is flexible and can work with different autoencoder architectures; in our experiments, the well-known residual encoder-decoder network (RED) was adopted as the backbone. The quantitative results and computational costs of the unsupervised methods are presented in Table 1. Our method produces promising denoising results and achieves the highest PSNR among the unsupervised and semi-supervised methods (29.09 dB for Ours(W) vs. 28.82 dB for CycleGAN). As shown in Table 2, our scores are very close to those of the backbone model trained in a fully supervised manner, and our model even achieves a slightly higher PSNR (29.09 dB vs. 29.06 dB for RED(MSE)). Moreover, our framework is lightweight, with a model scale similar to that of RED (1.92 M vs. 1.85 M parameters). Note that adding a perceptual loss to our model decreases the PSNR (28.81 dB vs. 28.88 dB), which is consistent with previous studies showing that a perceptual loss may preserve more details but lowers MSE-based metrics such as PSNR. Finally, the reweighting mechanism clearly improves our model's results (PSNR from 28.88 dB to 29.09 dB, SSIM from 0.90 to 0.91).
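For readers reproducing the evaluation, the display window (center 40, width 400) and PSNR can be sketched as below. This NumPy version rests on our own assumptions: images are in HU, and PSNR is computed on windowed images scaled to [0, 1] (data range 1). The paper does not state which intensity range or normalization was used for its PSNR/SSIM values, so numbers from this sketch are not directly comparable to Tables 1 and 2.

```python
import numpy as np


def apply_window(hu, center=40.0, width=400.0):
    """Clip an HU image to [center - width/2, center + width/2] and map it to [0, 1]."""
    lo, hi = center - width / 2.0, center + width / 2.0
    return (np.clip(hu, lo, hi) - lo) / (hi - lo)


def psnr(reference, estimate, data_range=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range^2 / MSE)."""
    diff = np.asarray(reference, dtype=float) - np.asarray(estimate, dtype=float)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))


ndct = np.array([[-1000.0, 40.0], [240.0, 100.0]])  # toy HU values
denoised = ndct + 20.0
print(psnr(apply_window(ndct), apply_window(denoised)))
```

Because clipping saturates values outside the window, the same pair of images can give a different PSNR when it is computed on raw HU values; reporting the range used is therefore important when comparing methods.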
Fig. 3. Methods comparison. DIP and BM3D are fully unsupervised; RED and WGAN require a paired dataset; CycleGAN (“Cycle” in the figure), CUT, and ours use an unpaired dataset. “(P)” means that perceptual loss is added. “(W)” means that the proposed reweighting mechanism is applied. (Color figure online)

Table 1. Metrics comparison for unsupervised methods. “(P)” means that perceptual loss is added. “(W)” means that the proposed reweighting mechanism is applied.
Metrics      DIP     BM3D   CycleGAN  CUT     Ours    Ours(W)  Ours(P)  LDCT
PSNR (dB)    26.79   26.64  28.82     28.15   28.88   29.09    28.81    22.33
SSIM         0.86    0.82   0.91      0.86    0.90    0.91     0.91     0.63
MACs (G)     75.64   n/a    1576.09   496.87  521.36  521.36   521.36   n/a
Params (M)   2.18    n/a    7.58      3.82    1.92    1.92     1.92     n/a

3.3 Line Plot over Slices
Although our method is only competitive with, rather than superior to, supervised methods in overall metrics, we can demonstrate the effectiveness of the proposed inter-slice consistency loss. The line plot in Fig. 4 shows the pixel values at point (200, 300) across different slices. As Fig. 4 shows, our method effectively preserves the inter-slice consistency of features, which is clinically important for maintaining the structural consistency of the entire volume. Although the supervised model achieves a similar overall score, the results of our model across slices are closer to the ground truth (GT), especially where pixel values change dramatically.
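The line-plot analysis can be reproduced on any reconstructed volume with a few lines of NumPy. This sketch assumes a volume stored as (slices, rows, columns) and reads point (200, 300) as (row, column), which the paper does not specify. The second returned value, the mean absolute difference between the slice-to-slice changes of the prediction and of the GT, is our own illustrative consistency measure and is not reported in the paper.

```python
import numpy as np


def slice_profile(volume, row=200, col=300):
    """Values of one in-plane pixel across all slices of a (slices, rows, cols) volume."""
    return np.asarray(volume)[:, row, col]


def profile_errors(pred_volume, gt_volume, row=200, col=300):
    """Return (mean |pred - GT| along the profile,
               mean |slice-to-slice change of pred - that of GT|)."""
    p = slice_profile(pred_volume, row, col).astype(float)
    g = slice_profile(gt_volume, row, col).astype(float)
    return float(np.mean(np.abs(p - g))), float(np.mean(np.abs(np.diff(p) - np.diff(g))))


gt = np.zeros((5, 512, 512))
pred = gt.copy()
pred[:, 200, 300] = [0.0, 1.0, 0.0, 1.0, 0.0]  # a prediction that flickers across slices
print(profile_errors(pred, gt))
```

The second measure penalizes a prediction that oscillates around the GT even when its average error is small, which is the kind of inter-slice instability Fig. 4 is meant to expose.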
Table 2. Metrics comparison for supervised methods. “(P)” means perceptual loss is added. “(W)” means proposed re-weight mechanism is applied.

Metrics      RED(MSE)  RED(P)  WGAN    Ours(W)  LDCT
PSNR (dB)    29.06     28.74   27.75   29.09    22.33
SSIM         0.92      0.92    0.89    0.91     0.63
MACs (G)     462.53    462.53  626.89  521.36   n/a
Params (M)   1.85      1.85    2.52    1.92     n/a

Fig. 4. Inter-slice HU value line plot.
3.4 Discussion

Our method achieves competitive results and obtains the highest PSNR among all methods trained with unpaired samples. Although it does not surpass supervised methods in terms of some metrics (e.g., SSIM), it produces results across consecutive slices that are more consistent and closer to the GT.
4 Conclusion
In this paper, we introduce a novel low-dose CT denoising model. The primary motivation for this work is that most CNN-based denoising models require paired LD-NDCT images, whereas usually only unpaired CT data are available in clinical practice. Furthermore, many existing methods that use unpaired samples require extensive computational resources, which can be prohibitive for clinical use. In addition, most existing methods process single slices, which results in inconsistent results across consecutive slices. To overcome these limitations, we propose a novel unsupervised method based on contrastive learning that requires only a single generator. We also modify the original contrastive learning method to achieve SOTA denoising results at a relatively low computational cost. Our experiments demonstrate that our method outperforms existing SOTA semi-supervised and unsupervised methods and is competitive with supervised methods in both qualitative and quantitative measures. Importantly, our framework does not require paired training data and is therefore more adaptable for clinical use.
References

1. Aharon, M., Elad, M., Bruckstein, A.: K-SVD: an algorithm for designing overcomplete dictionaries for sparse representation. IEEE Trans. Signal Process. 54(11), 4311–4322 (2006)
2. Beister, M., Kolditz, D., Kalender, W.A.: Iterative reconstruction methods in X-ray CT. Phys. Med. 28(2), 94–108 (2012)
3. Bera, S., Biswas, P.K.: Axial consistent memory GAN with interslice consistency loss for low dose computed tomography image denoising. IEEE Trans. Radiation Plasma Med. Sci. (2023)
4. Bera, S., Biswas, P.K.: Self supervised low dose computed tomography image denoising using invertible network exploiting inter slice congruence. In: Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision, pp. 5614–5623 (2023)
5. Chen, H., et al.: Low-dose CT with a residual encoder-decoder convolutional neural network. IEEE Trans. Med. Imaging 36(12), 2524–2535 (2017)
6. Chen, H., et al.: Low-dose CT denoising with convolutional neural network. In: 2017 IEEE 14th International Symposium on Biomedical Imaging (ISBI 2017), pp. 143–146. IEEE (2017)
7. Dong, C., Loy, C.C., He, K., Tang, X.: Image super-resolution using deep convolutional networks. IEEE Trans. Pattern Anal. Mach. Intell. 38(2), 295–307 (2015)
8. Feruglio, P.F., Vinegoni, C., Gros, J., Sbarbati, A., Weissleder, R.: Block matching 3D random noise filtering for absorption optical projection tomography. Phys. Med. Biol. 55(18), 5401 (2010)
9. Geyer, L.L., et al.: State of the art: iterative CT reconstruction techniques. Radiology 276(2), 339–357 (2015)
10. Goodfellow, I.J., et al.: Generative adversarial networks. arXiv preprint arXiv:1406.2661 (2014)
11. Jing, J., et al.: Training low dose CT denoising network without high quality reference data. Phys. Med. Biol. 67(8), 084002 (2022)
12. Johnson, J., Alahi, A., Fei-Fei, L.: Perceptual losses for real-time style transfer and super-resolution. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9906, pp. 694–711. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46475-6_43
13. Jung, C., Lee, J., You, S., Ye, J.C.: Patch-wise deep metric learning for unsupervised low-dose CT denoising. In: International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 634–643. Springer (2022). https://doi.org/10.1007/978-3-031-16446-0_60
14. Kang, D., et al.: Image denoising of low-radiation dose coronary CT angiography by an adaptive block-matching 3D algorithm. In: Medical Imaging 2013: Image Processing, vol. 8669, p. 86692G. International Society for Optics and Photonics (2013)
15. Li, Z., Huang, J., Yu, L., Chi, Y., Jin, M.: Low-dose CT image denoising using cycle-consistent adversarial networks. In: 2019 IEEE Nuclear Science Symposium and Medical Imaging Conference (NSS/MIC), pp. 1–3 (2019). https://doi.org/10.1109/NSS/MIC42101.2019.9059965
16. Manduca, A., et al.: Projection space denoising with bilateral filtering and CT noise modeling for dose reduction in CT. Med. Phys. 36(11), 4911–4919 (2009)
17. Moen, T.R., et al.: Low-dose CT image and projection dataset. Med. Phys. 48, 902–911 (2021)
18. Shan, H., et al.: Competitive performance of a modularized deep neural network compared to commercial algorithms for low-dose CT image reconstruction. Nature Mach. Intell. 1(6), 269–276 (2019)
19. Simonyan, K., Zisserman, A.: Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556 (2014)
20. Smith-Bindman, R., et al.: Radiation dose associated with common computed tomography examinations and the associated lifetime attributable risk of cancer. Arch. Intern. Med. 169(22), 2078–2086 (2009)
21. Ulyanov, D., Vedaldi, A., Lempitsky, V.S.: Deep image prior. CoRR abs/1711.10925 (2017). https://arxiv.org/abs/1711.10925
22. Wang, Z., Bovik, A., Sheikh, H., Simoncelli, E.: Image quality assessment: from error visibility to structural similarity. IEEE Trans. Image Process. 13(4), 600–612 (2004). https://doi.org/10.1109/TIP.2003.819861
23. Wu, D., Gong, K., Kim, K., Li, X., Li, Q.: Consensus neural network for medical imaging denoising with only noisy training samples. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11767, pp. 741–749. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32251-9_81
24. Yang, Q., et al.: Low-dose CT image denoising using a generative adversarial network with Wasserstein distance and perceptual loss. IEEE Trans. Med. Imaging 37(6), 1348–1357 (2018). https://doi.org/10.1109/TMI.2018.2827462
25. Zhan, F., Zhang, J., Yu, Y., Wu, R., Lu, S.: Modulated contrast for versatile image synthesis. In: 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 18259–18269 (2022). https://doi.org/10.1109/CVPR52688.2022.01774
26. Zhu, J.Y., Park, T., Isola, P., Efros, A.A.: Unpaired image-to-image translation using cycle-consistent adversarial networks. In: Computer Vision (ICCV), 2017 IEEE International Conference on (2017)
DAS-MIL: Distilling Across Scales for MIL Classification of Histological WSIs

Gianpaolo Bontempo¹,², Angelo Porrello¹, Federico Bolelli¹(B), Simone Calderara¹, and Elisa Ficarra¹

¹ University of Modena and Reggio Emilia, Modena, Italy
{gianpaolo.bontempo,angelo.porrello,federico.bolelli,simone.calderara,elisa.ficarra}@unimore.it
² University of Pisa, Pisa, Italy
Abstract. The adoption of Multi-Instance Learning (MIL) for classifying Whole-Slide Images (WSIs) has increased in recent years. Indeed, pixel-level annotation of gigapixel WSIs is mostly infeasible and time-consuming in practice. For this reason, MIL approaches have been profitably integrated with the most recent deep-learning solutions for WSI classification to support clinical practice and diagnosis. Nevertheless, most such approaches overlook the multi-scale nature of WSIs; the few existing hierarchical MIL proposals simply flatten the multi-scale representations by concatenating or summing feature vectors, neglecting the spatial structure of the WSI. Our work aims to unleash the full potential of the pyramidal structure of WSIs; to this end, we propose a graph-based multi-scale MIL approach, termed DAS-MIL, that exploits message passing to let information flow across multiple scales. Through a knowledge distillation scheme, the alignment between the latent space representations at different resolutions is encouraged while preserving the diversity of their informative content. The effectiveness of the proposed framework is demonstrated on two well-known datasets, where we outperform the SOTA on WSI classification, gaining +1.9% AUC and +3.3% accuracy on the popular Camelyon16 benchmark. The source code is available at https://github.com/aimagelab/mil4wsi.
Keywords: Whole-slide images · Multi-instance learning · Knowledge distillation

1 Introduction
Modern microscopes allow conventional glass slides to be digitized into gigapixel Whole-Slide Images (WSIs) [18], facilitating their preservation and retrieval but also introducing multiple challenges. On the one hand, annotating WSIs requires strong medical expertise and is expensive and time-consuming, and labels are usually provided at the slide or patient level. On the other hand, feeding the entire gigapixel image to modern neural networks is not feasible, so the data must be cropped into small patches for training. This process usually considers a single resolution/scale among those provided by the WSI. Recently, Multi-Instance Learning (MIL) has emerged to cope with these limitations. MIL approaches consider the slide as a bag composed of many patches, called instances; to provide a classification score for the entire bag, they weigh the instances through attention mechanisms and aggregate them into a single representation. Notably, these approaches are intrinsically flat and disregard the pyramidal information provided by the WSI [15], although multi-resolution approaches have been proven to be more effective than single-resolution ones [4,13,15,19]. However, to the best of our knowledge, none of the existing proposals leverages the full potential of the WSI pyramidal structure. Indeed, the flat concatenation of features [19] extracted at different resolutions does not consider the substantial difference in the informative content they provide. A proficient learning approach should instead consider the heterogeneity between global structures and local cellular regions, thus allowing information to flow effectively across the image scales.

To profit from the multi-resolution structure of WSIs, we propose a pyramidal Graph Neural Network (GNN) framework combined with (self) Knowledge Distillation (KD), called DAS-MIL (Distilling Across Scales). A visual representation of the proposed approach is depicted in Fig. 1. Distinct GNNs provide contextualized features, which are fed to distinct attention-based MIL modules that compute bag-level predictions. Through knowledge distillation, we encourage agreement across the predictions delivered at different resolutions, while individual scale features are learned in isolation to preserve their diversity in terms of information content. By transferring knowledge across scales, we observe that the classifier self-improves as information flows during training. Our proposal has proven its effectiveness on two well-known histological datasets, Camelyon16 and TCGA lung cancer, obtaining state-of-the-art results on WSI classification.

Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_24.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 248–258, 2023. https://doi.org/10.1007/978-3-031-43907-0_24

DAS-MIL: Distilling Across Scales for MIL Classification of WSIs

Fig. 1. Overview of our proposed framework, DAS-MIL. The features extracted at different scales are connected (8-connectivity) by means of different graphs. The nodes of both graphs are later fused into a third one, respecting the rule “part of”. The contextualized features are then passed to distinct attention-based MIL modules that extract bag labels. Furthermore, a knowledge distillation mechanism encourages agreement between the predictions delivered at different scales.

G. Bontempo et al.
2 Related Work
MIL Approaches for WSI Classification. We summarize the most recent approaches here and refer the reader to [11,26] for a comprehensive overview.

Single-Scale. A classical approach is AB-MIL [16], which employs a side-branch network to calculate the attention scores. In [28], a similar attention mechanism supports a double-tier feature distillation approach, which distills features from pseudo-bags to the original slide. In contrast, DS-MIL [19] applies non-local attention aggregation by considering the distance to the most relevant patch. The authors of [20] and [25] propose variations of AB-MIL that introduce clustering losses and transformers, respectively. In addition, SETMIL [31] uses spatial-encoding transformer layers to update the representation. The authors of [7] leverage DINO [5] as a feature extractor, highlighting its effectiveness for medical image analysis. Beyond classical attention mechanisms, there are also algorithms based on Recurrent Neural Networks (RNNs) [4] and Graph Neural Networks (GNNs) [32].

Multi-Scale. Recently, several authors have focused on multi-resolution approaches. DSMIL-LC [19] merges representations from different resolutions, i.e., low-resolution instance representations are concatenated with those obtained at a higher resolution. MS-RNNMIL [4], instead, feeds an RNN with instances extracted at different scales. In [6], a self-supervised hierarchical transformer is applied at each scale. In MS-DA-MIL [13], multi-scale features are included in the same attention algorithm. The authors of [10] and [15] exploit multi-resolution information through GNN architectures.

Knowledge Distillation. Distilling knowledge from a larger network (teacher) into a smaller one (student) has been widely investigated in recent years [21,24] and applied to different fields, ranging from model compression [3] to WSI analysis [17]. Typically, a tailored learning objective encourages the student to mimic the behavior of its teacher. Recently, self-supervised representation learning approaches have also employed such a scheme: for example, [5,9] exploit KD to obtain agreement between networks fed with different views of the same image. In [28], KD is used to transfer knowledge between MIL tiers applied to different subsampled bags. Taking inspiration from [23] and [30], our model applies (self) knowledge distillation between WSI resolutions.
3 Method
Our approach aims to promote the flow of information across the employed resolutions. While existing works [19,20,25] account for inter-scale interactions mostly through trivial operations (such as concatenating the related feature representations), we instead provide a novel technique that builds upon: i) a GNN module based on message passing, which propagates patch representations according to the natural structure of multi-resolution WSIs; and ii) a regularization term based on (self) knowledge distillation, which uses the most effective resolution as a reference to further guide the training of the other one(s). In the following, we describe our architecture in detail.

Feature Extraction. Our work exploits DINO, the self-supervised learning approach proposed in [5], to provide a relevant representation of each patch. Unlike other proposals [19,20,28], DINO focuses solely on aligning positive pairs during optimization (and hence avoids negative pairs), which has been shown to require a lower memory footprint during training. We hence devise an initial stage with multiple self-supervised feature extractors $f_1(\cdot; \theta_1), \ldots, f_M(\cdot; \theta_M)$, one dedicated to each resolution; in this way, we expect to promote feature diversity across scales. After training, we freeze the weights of these networks and use them as patch-level feature extractors. Although we focus on only two resolutions at a time (i.e., $M = 2$), the approach can be extended to more scales.

Architecture. The representations yielded by DINO provide a detailed description of the local patterns in each patch; however, they retain little knowledge of the surrounding context. To capture the context of the entire slide, we allow patches to exchange local information. We achieve this through a Pyramidal Graph Neural Network (PGNN) in which each node represents an individual WSI patch seen at different scales.
Each node is connected to its neighbors (8-connectivity) in Euclidean space and across scales following the relation “part of”¹. To perform message passing, we adopt Graph ATtention layers (GAT) [27]. In general terms, this module takes as input the multi-scale patch-level representations $X = [X_1 \| X_2]$, where $X_1 \in \mathbb{R}^{N_1 \times F}$ and $X_2 \in \mathbb{R}^{N_2 \times F}$ are the representations of the lower and higher scale, respectively. The input undergoes two graph layers: the former treats the two scales as independent subgraphs $A_1 \in \mathbb{R}^{N_1 \times N_1}$ and $A_2 \in \mathbb{R}^{N_2 \times N_2}$, while the latter processes them jointly by considering the entire graph $A$ (see Fig. 1, left). Formally:

$$H = \mathrm{PGNN}(X; A, A_1, A_2, \theta_{PGNN}) = \mathrm{GAT}\big([\mathrm{GAT}(X_1; A_1, \theta_1) \,\|\, \mathrm{GAT}(X_2; A_2, \theta_2)]; A, \theta_3\big),$$

where $H \equiv [H_1 \| H_2]$ is the output of the PGNN, obtained by concatenating the two scales. These contextualized patch representations are then fed to the attention-based MIL module proposed in [19], which produces bag-level scores $y_1^{BAG}, y_2^{BAG} \in \mathbb{R}^{1 \times C}$, where $C$ is the number of classes. Notably, this module also provides importance scores $z_1 \in \mathbb{R}^{N_1}$ and $z_2 \in \mathbb{R}^{N_2}$, which quantify the importance of each original patch to the overall prediction.

¹ The relation “part of” connects a parent WSI patch (lying in the lower resolution) with its children, i.e., the higher-scale patches it contains.

Aligning Scales with (Self) Knowledge Distillation. We have thus obtained two distinct sets of predictions for the two resolutions: a bag-level score (e.g., whether a tumor is present or not) and a patch-level one (e.g., which instances contribute the most to the target class). However, since these predictions are inferred from different WSI zoom levels, they may disagree: indeed, we have observed (see Table 4) that higher resolutions generally yield better classification performance. In this work, we exploit this disparity to introduce two additional optimization objectives, which use the predictions of the higher scale as a teaching signal for the lower one. Beyond improving the results of the lower scale, we expect the benefits to propagate to the shared message-passing module, and thus to the higher resolution. Formally, the first term aligns the bag predictions of the two scales through (self) knowledge distillation [14,29]:

$$\mathcal{L}_{KD} = \tau^2\, \mathrm{KL}\left(\mathrm{softmax}\left(\frac{y_1^{BAG}}{\tau}\right) \Big\|\, \mathrm{softmax}\left(\frac{y_2^{BAG}}{\tau}\right)\right), \tag{1}$$
where KL denotes the Kullback-Leibler divergence and τ is a temperature that lets secondary information emerge from the teaching signal. The second alignment term concerns the instance scores. It encourages the two resolutions to assign criticality scores consistently: intuitively, if a low-resolution patch has been deemed critical, then the average score assigned to its children patches should be similarly high. We enforce this constraint by minimizing the squared Euclidean distance between the low-resolution criticality map $z_1$ and its subsampled counterpart computed by the high-resolution branch:

$$\mathcal{L}_{CRIT} = \left\| z_1 - \mathrm{GraphPooling}(z_2) \right\|_2^2 \qquad (2)$$
In the equation above, GraphPooling denotes a pooling layer applied over the higher scale: it follows the "part of" relation between scales and averages the child nodes of each parent, hence allowing the comparison at the instance level.

Overall Objective. To sum up, the overall optimization problem combines two kinds of objectives: one requiring a higher conditional likelihood w.r.t. the ground-truth labels $y$, carried out through the Cross-Entropy loss $\mathcal{L}_{CE}(\cdot; y)$, and one based on knowledge distillation:

$$\min_{\theta} \; (1-\lambda)\,\mathcal{L}_{CE}(y_2^{BAG}; y) + \mathcal{L}_{CE}(y_1^{BAG}; y) + \lambda\,\mathcal{L}_{KD} + \beta\,\mathcal{L}_{CRIT} \qquad (3)$$
where λ is a hyperparameter weighting the tradeoff between the teaching signals provided by labels and the higher resolution, while β balances the contributions of the consistency regularization introduced in Eq. (2).
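A framework-free sketch of Eqs. (1)–(3) for a single bag may help clarify how the terms interact. It is written in plain Python for readability; a real implementation would use an autograd library. The function names, the list-of-children encoding of the "part of" relation, and the default values (taken from the settings reported in Sect. 4: τ = 1.5, λ = β = 1) are illustrative assumptions. Eq. (1) is implemented exactly as written, i.e., KL(softmax(y1/τ) ‖ softmax(y2/τ)).

```python
import math


def softmax(logits, tau=1.0):
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [x / tau for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    s = sum(exps)
    return [e / s for e in exps]


def kl_div(p, q):
    """KL(p || q) for discrete distributions with strictly positive q."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kd_loss(y1_bag, y2_bag, tau):
    """Eq. (1): tau^2 * KL(softmax(y1 / tau) || softmax(y2 / tau))."""
    return tau ** 2 * kl_div(softmax(y1_bag, tau), softmax(y2_bag, tau))


def graph_pooling(z2, children):
    """Average the high-res scores of each low-res parent's children ('part of')."""
    return [sum(z2[c] for c in kids) / len(kids) for kids in children]


def crit_loss(z1, z2, children):
    """Eq. (2): squared Euclidean distance between z1 and GraphPooling(z2)."""
    pooled = graph_pooling(z2, children)
    return sum((a - b) ** 2 for a, b in zip(z1, pooled))


def cross_entropy(logits, label):
    """Cross-entropy of one bag's logits w.r.t. an integer class label."""
    return -math.log(softmax(logits)[label])


def total_loss(y1_bag, y2_bag, z1, z2, children, label, lam=1.0, beta=1.0, tau=1.5):
    """Eq. (3) as written: (1 - lam) * CE(y2) + CE(y1) + lam * KD + beta * CRIT."""
    return ((1 - lam) * cross_entropy(y2_bag, label)
            + cross_entropy(y1_bag, label)
            + lam * kd_loss(y1_bag, y2_bag, tau)
            + beta * crit_loss(z1, z2, children))


# Toy bag: 2 low-res patches, each with 2 high-res children (hypothetical values).
children = [[0, 1], [2, 3]]
print(total_loss([0.2, 1.0], [0.1, 2.0], [0.3, 0.7], [0.2, 0.4, 0.6, 0.8], children, label=1))
```

Note that with λ = 1 the first term vanishes, so in this configuration Eq. (3) reduces to the label loss on the lower scale plus the two alignment terms.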
DAS-MIL: Distilling Across Scales for MIL Classification of WSIs
4 Experiments
WSIs Pre-processing. We remove background patches through an approach similar to the one presented in the CLAM framework [20]: after an initial segmentation based on Otsu's method [22] and Connected Component Analysis [2], non-overlapping patches within the foreground regions are retained.

Optimization. We use Adam as optimizer, with a learning rate of $2 \times 10^{-4}$ and a cosine annealing scheduler ($10^{-5}$ decay, without warm restarts). We set τ = 1.5, β = 1, and λ = 1. The DINO feature extractor has been trained on two RTX5000 GPUs, whereas all subsequent experiments have been performed on a single RTX2080 GPU using PyTorch Geometric [12]. To assess the performance of our approach, we adhere to the protocol of [19,28] and use accuracy and AUC as metrics. The classifier of the higher scale is used to make the final overall prediction. Regarding the KD loss, we apply the temperature to both student and teacher outputs for numerical stability.

Table 1. Comparison with state-of-the-art solutions. Results marked with "†" have been computed by us, as the original papers lack the specific settings; all other numbers are taken from [19, 28].

| Method | Camelyon16 Accuracy | Camelyon16 AUC | TCGA Lung Accuracy | TCGA Lung AUC |
|---|---|---|---|---|
| Single scale | | | | |
| Mean-pooling † | 0.723 | 0.672 | 0.823 | 0.905 |
| Max-pooling † | 0.893 | 0.899 | 0.851 | 0.909 |
| MILRNN [4] | 0.806 | 0.806 | 0.862 | 0.911 |
| ABMIL [16] | 0.845 | 0.865 | 0.900 | 0.949 |
| CLAM-SB [20] | 0.865 | 0.885 | 0.875 | 0.944 |
| CLAM-MB [20] | 0.850 | 0.894 | 0.878 | 0.949 |
| Trans-MIL † [25] | 0.883 | 0.942 | 0.881 | 0.948 |
| DTFD (AFS) [28] | 0.908 | 0.946 | 0.891 | 0.951 |
| DTFD (MaxMinS) [28] | 0.899 | 0.941 | 0.894 | 0.961 |
| DSMIL † [19] | 0.915 | 0.952 | 0.888 | 0.951 |
| Multi scale | | | | |
| MS-DA-MIL [13] | 0.876 | 0.887 | 0.900 | 0.955 |
| MS-MILRNN [4] | 0.814 | 0.837 | 0.891 | 0.921 |
| HIPT † [6] | 0.890 | 0.951 | 0.890 | 0.950 |
| DSMIL-LC † [19] | 0.909 | 0.955 | 0.913 | 0.964 |
| H2-MIL † [15] | 0.859 | 0.912 | 0.823 | 0.917 |
| DAS-MIL (ours) | 0.945 | 0.973 | 0.925 | 0.965 |
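The background-removal step described under WSIs Pre-processing can be approximated as follows. This is a simplified, dependency-free sketch, not CLAM's implementation: it computes Otsu's threshold on a low-resolution thumbnail given as a 2D list of 8-bit values, assumes tissue is brighter than background in that channel (true for a saturation channel; with raw grayscale, where glass is white, the comparison must be inverted), skips the connected-component filtering, and keeps non-overlapping tiles whose foreground fraction reaches a hypothetical `min_frac` threshold.

```python
def otsu_threshold(values, levels=256):
    """Otsu's method: the threshold t maximizing between-class variance (class 0: v <= t)."""
    hist = [0] * levels
    for v in values:
        hist[v] += 1
    total = len(values)
    sum_all = sum(i * h for i, h in enumerate(hist))
    best_t, best_var = 0, -1.0
    w0 = sum0 = 0
    for t in range(levels):
        w0 += hist[t]            # pixels with value <= t
        sum0 += t * hist[t]
        w1 = total - w0          # pixels with value > t
        if w0 == 0 or w1 == 0:
            continue
        mu0, mu1 = sum0 / w0, (sum_all - sum0) / w1
        var_between = w0 * w1 * (mu0 - mu1) ** 2  # proportional to between-class variance
        if var_between > best_var:
            best_t, best_var = t, var_between
    return best_t


def foreground_patches(img, patch, min_frac=0.5):
    """Top-left (row, col) corners of non-overlapping patch x patch tiles that are mostly tissue.

    img: 2D list of ints in [0, 255] where tissue is brighter than background (assumption).
    min_frac: hypothetical minimum fraction of foreground pixels per tile.
    """
    t = otsu_threshold([v for row in img for v in row])
    h, w = len(img), len(img[0])
    coords = []
    for y in range(0, h - patch + 1, patch):
        for x in range(0, w - patch + 1, patch):
            fg = sum(img[y + dy][x + dx] > t for dy in range(patch) for dx in range(patch))
            if fg >= min_frac * patch * patch:
                coords.append((y, x))
    return coords


# Toy 4 x 4 thumbnail: tissue (200) on the left half, background (10) on the right.
thumb = [[200, 200, 10, 10] for _ in range(4)]
print(foreground_patches(thumb, 2))  # [(0, 0), (2, 0)]
```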
Camelyon16 [1]. We adhere to the official training/test split. To produce the fairest comparison with the single-scale state-of-the-art solutions, the 270 WSIs of the official training set are split into training and validation sets with a 9:1 ratio.
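For concreteness, a 9:1 split of the 270 training slides yields 243 training and 27 validation WSIs. The sketch below shows one way to draw such a split; it assumes a plain seeded random partition with hypothetical slide identifiers, since the text does not state whether the split is stratified by class.

```python
import random


def split_train_val(slide_ids, val_fraction=0.1, seed=0):
    """Seeded random split into (train, val); val_fraction=0.1 gives the 9:1 proportion."""
    ids = list(slide_ids)
    random.Random(seed).shuffle(ids)
    n_val = round(len(ids) * val_fraction)
    return ids[n_val:], ids[:n_val]


slide_ids = [f"slide_{i:03d}" for i in range(270)]  # hypothetical identifiers
train, val = split_train_val(slide_ids)
print(len(train), len(val))  # 243 27
```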
TCGA Lung Dataset. It is available on the GDC Data Transfer Portal and comprises two lung cancer subtypes: Lung Adenocarcinoma (LUAD) and Lung Squamous Cell Carcinoma (LUSC), with 541 and 513 WSIs, respectively. The aim is to classify LUAD vs. LUSC; we follow the split proposed by DSMIL [19].

4.1 Comparison with the State-of-the-art
Table 1 compares our DAS-MIL approach with the state-of-the-art, including both single- and multi-scale architectures. As can be observed: i) jointly exploiting multiple resolutions is generally more effective; ii) DAS-MIL yields robust and compelling results, especially on Camelyon16, where it reaches 0.945 accuracy and 0.973 AUC (i.e., an improvement of +3.3% accuracy and +1.9% AUC with respect to the SOTA). Finally, we remark that most methods in the literature rely on different feature extractors; however, the next subsections show that the benefits of DAS-MIL are consistent across various backbones.
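The reported gains appear to be relative rather than absolute improvements. Taking the best competing Camelyon16 numbers in Table 1 (0.915 accuracy for DSMIL and 0.955 AUC for DSMIL-LC), the following check reproduces +3.3% and +1.9%; this reading is our interpretation, since the text does not name the reference method. Absolute differences would instead be 3.0 and 1.8 points.

```python
def relative_gain(ours, baseline):
    """Relative improvement of `ours` over `baseline`, in percent."""
    return 100.0 * (ours - baseline) / baseline


# Camelyon16 values copied from Table 1.
acc_gain = relative_gain(0.945, 0.915)  # DAS-MIL vs. DSMIL, accuracy
auc_gain = relative_gain(0.973, 0.955)  # DAS-MIL vs. DSMIL-LC, AUC
print(f"{acc_gain:.1f}% {auc_gain:.1f}%")  # 3.3% 1.9%
```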
Table 2. Impact (AUC, Camelyon16) of Eq. 3 hyperparameters.

| λ | 20× | 10× | β | 20× | 10× |
|---|---|---|---|---|---|
| 1.0 | 0.973 | 0.974 | 1.5 | 0.964 | 0.968 |
| 0.8 | 0.967 | 0.966 | 1.2 | 0.970 | 0.964 |
| 0.5 | 0.968 | 0.932 | 1.0 | 0.973 | 0.974 |
| 0.3 | 0.962 | 0.965 | 0.8 | 0.962 | 0.965 |
| 0.0 | 0.955 | 0.903 | 0.6 | 0.951 | 0.953 |
Table 3. Impact (Camelyon16) of the KD temperature (Eq. 1), λ = β = 1.0.

| τ | Accuracy 10× | Accuracy 20× | AUC 10× | AUC 20× |
|---|---|---|---|---|
| 1.0 | 0.883 | 0.962 | 0.906 | 0.957 |
| 1.3 | 0.898 | 0.958 | 0.891 | 0.959 |
| 1.5 | 0.945 | 0.945 | 0.973 | 0.974 |
| 2.0 | 0.906 | 0.914 | 0.962 | 0.963 |
| 2.5 | 0.922 | 0.914 | 0.951 | 0.952 |

4.2 Model Analysis
On the Impact of Knowledge Distillation. To assess its merits, we conducted several experiments varying the values of the corresponding balancing coefficients (see Table 2). As can be observed, lowering their values (even reaching λ = 0, i.e., no distillation is performed) negatively affects the performance. Such a statement holds not only for the lower resolution (as one could expect), but also for the higher one, thus corroborating the claims we made in Sect. 3 on the bidirectional benefits of knowledge distillation in our multi-scale architecture.
We have also assessed the temperature τ, which controls the smoothing applied to the teacher's predictions (Table 3). We found that the lower the temperature, the better the results, suggesting that the teacher scale is naturally not overconfident about its predictions, but rather well-calibrated.
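To illustrate what the temperature does, the sketch below applies a temperature-scaled softmax to a hypothetical two-class logit vector [2.0, 0.0] for the τ values of Table 3: larger τ moves the distribution toward uniform, exposing more of the secondary class (the top-class probability drops from about 0.88 at τ = 1 to about 0.69 at τ = 2.5). The τ² factor in Eq. (1) follows Hinton et al. [14], who use it to keep gradient magnitudes roughly comparable across temperatures.

```python
import math


def softmax_t(logits, tau):
    """Softmax of logits / tau, shifted by the max logit for numerical stability."""
    m = max(logits)
    exps = [math.exp((x - m) / tau) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


for tau in (1.0, 1.3, 1.5, 2.0, 2.5):
    print(tau, [round(p, 3) for p in softmax_t([2.0, 0.0], tau)])
```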
Table 4. Comparison between scales. The "MIL Target(s)" column indicates the features passed to the two MIL layers; bracketed entries denote features that have been concatenated beforehand.

| Input Scale | MIL Target(s) | Accuracy | AUC |
|---|---|---|---|
| 10× | 10× | 0.818 | 0.816 |
| 20× | 20× | 0.891 | 0.931 |
| 5×, 20× | 5×, 20× | 0.891 | 0.938 |
| 5×, 20× | 5×, [5× 20×] | 0.898 | 0.941 |
| 10×, 20× | 10×, 20× | 0.945 | 0.973 |
| 10×, 20× | 10×, [10× 20×] | 0.922 | 0.953 |
Single-Scale vs Multi-Scale. Table 4 demonstrates the contribution of hierarchical representations. In single-scale experiments, the model is fed only with patches extracted at a single reference scale. In multi-scale experiments, representations can be combined in different ways. Overall, the best results are obtained with 10× and 20× input resolutions; the table also highlights that the 5× magnification is less effective and shows worse discriminative capability. We ascribe this to the specimen-level pixel size relevant for the cancer diagnosis task; different datasets/tasks may benefit from different scale combinations.

Table 5. Comparison between DAS-MIL with and without (✗) the graph contextualization mechanism, and the most recent graph-based multi-scale approach H2-MIL, when using 5× and 20× as input resolutions.

| Feature Extractor | Graph Mechanism | Camelyon16 Acc. | Camelyon16 AUC | TCGA Lung Acc. | TCGA Lung AUC |
|---|---|---|---|---|---|
| SimCLR | ✗ | 0.859 | 0.869 | 0.864 | 0.932 |
| SimCLR | DAS-MIL | 0.906 | 0.928 | 0.883 | 0.9489 |
| SimCLR | H2-MIL | 0.836 | 0.857 | 0.826 | 0.916 |
| DINO | ✗ | 0.852 | 0.905 | 0.906 | 0.956 |
| DINO | DAS-MIL | 0.891 | 0.938 | 0.925 | 0.965 |
| DINO | H2-MIL | 0.859 | 0.912 | 0.823 | 0.917 |
The Impact of the Feature Extractors and GNNs. Table 5 investigates these aspects, considering both SimCLR [8] and DINO, as well as the recently proposed graph mechanism H2-MIL [15]. In doing so, we fix the input resolutions to 5× and 20×. We draw the following conclusions: i) when our DAS-MIL feature propagation layer is used, the choice of the feature extractor (i.e., SimCLR vs. DINO) has less impact on performance, as message passing can compensate for possible shortcomings of the initial representation; ii) DAS-MIL appears to be a better feature propagator than H2-MIL.
H2-MIL exploits a global pooling layer (IHPool) that only follows the spatial structure of the patches: as a consequence, if non-tumor patches surround a tumor patch, its contribution to the final prediction is likely to be outweighed within the IHPool module of H2-MIL. In contrast, our approach is not restricted in this way, as it can dynamically route information across the hierarchical structure (also based on the connections with the critical instance).
5 Conclusion
We proposed a novel way to exploit multiple resolutions in the domain of histological WSIs. We conceived a novel graph-based architecture that learns spatial correlations at different WSI resolutions. Specifically, a cascade of GNN layers extracts context-aware, instance-level features while considering the spatial relationship between scales. During training, this connection is further strengthened by a distillation loss that asks for agreement between the lower and higher scales. Extensive experiments show the effectiveness of the proposed distillation approach.

Acknowledgement. This project has received funding from DECIDER, the European Union's Horizon 2020 research and innovation programme under GA No. 965193, and from the Department of Engineering "Enzo Ferrari" of the University of Modena through the FARD-2022 (Fondo di Ateneo per la Ricerca 2022). We also acknowledge the CINECA award under the ISCRA initiative, for the availability of high performance computing resources and support.
References

1. Bejnordi, B.E., et al.: Diagnostic assessment of deep learning algorithms for detection of lymph node metastases in women with breast cancer. JAMA 318(22), 2199–2210 (2017)
2. Bolelli, F., Allegretti, S., Grana, C.: One DAG to rule them all. IEEE Trans. Pattern Anal. Mach. Intell. 44(7), 3647–3658 (2021)
3. Buciluǎ, C., Caruana, R., Niculescu-Mizil, A.: Model compression. In: Proceedings of the Twelfth ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 535–541 (2006)
4. Campanella, G., et al.: Clinical-grade computational pathology using weakly supervised deep learning on whole slide images. Nat. Med. 25(8), 1301–1309 (2019)
5. Caron, M., et al.: Emerging properties in self-supervised vision transformers. In: IEEE/CVF International Conference on Computer Vision (ICCV), pp. 9650–9660 (2021)
6. Chen, R.J., et al.: Scaling vision transformers to gigapixel images via hierarchical self-supervised learning. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 16144–16155 (2022)
7. Chen, R.J., Krishnan, R.G.: Self-supervised vision transformers learn visual concepts in histopathology. In: Learning Meaningful Representations of Life, NeurIPS (2022)
8. Chen, T., Kornblith, S., Norouzi, M., Hinton, G.: A simple framework for contrastive learning of visual representations. In: International Conference on Machine Learning, pp. 1597–1607. PMLR (2020)
9. Chen, X., He, K.: Exploring simple Siamese representation learning. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 15750–15758 (2021)
10. Chen, Z., Zhang, J., Che, S., Huang, J., Han, X., Yuan, Y.: Diagnose like a pathologist: weakly-supervised pathologist-tree network for slide-level immunohistochemical scoring. In: Proceedings of the AAAI Conference on Artificial Intelligence (2021)
11. Dimitriou, N., Arandjelović, O., Caie, P.D.: Deep learning for whole slide image analysis: an overview. Front. Med. 6, 264 (2019)
12. Fey, M., Lenssen, J.E.: Fast graph representation learning with PyTorch Geometric. In: ICLR Workshop on Representation Learning on Graphs and Manifolds (2019)
13. Hashimoto, N., et al.: Multi-scale domain-adversarial multiple-instance CNN for cancer subtype classification with unannotated histopathological images. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 3852–3861 (2020)
14. Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a neural network. In: NIPS Deep Learning and Representation Learning Workshop (2015)
15. Hou, W., et al.: H2-MIL: exploring hierarchical representation with heterogeneous multiple instance learning for whole slide image analysis. In: Proceedings of the AAAI Conference on Artificial Intelligence (2022)
16. Ilse, M., Tomczak, J., Welling, M.: Attention-based deep multiple instance learning. In: International Conference on Machine Learning, vol. 80, pp. 2127–2136. PMLR (2018)
17. Ilyas, T., Mannan, Z.I., Khan, A., Azam, S., Kim, H., De Boer, F.: TSFD-Net: tissue specific feature distillation network for nuclei segmentation and classification. Neural Netw. 151, 1–15 (2022)
18. Kumar, N., Gupta, R., Gupta, S.: Whole slide imaging (WSI) in pathology: current perspectives and future directions. J. Digit. Imaging 33(4), 1034–1040 (2020)
19. Li, B., Li, Y., Eliceiri, K.W.: Dual-stream multiple instance learning network for whole slide image classification with self-supervised contrastive learning. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 14318–14328 (2021)
20. Lu, M.Y., Williamson, D.F., Chen, T.Y., Chen, R.J., Barbieri, M., Mahmood, F.: Data-efficient and weakly supervised computational pathology on whole-slide images. Nat. Biomed. Eng. 5(6), 555–570 (2021)
21. Monti, A., Porrello, A., Calderara, S., Coscia, P., Ballan, L., Cucchiara, R.: How many observations are enough? Knowledge distillation for trajectory forecasting. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 6543–6552 (2022)
22. Otsu, N.: A threshold selection method from gray-level histograms. IEEE Trans. Syst. Man Cybern. 9(1), 62–66 (1979)
23. Porrello, A., Bergamini, L., Calderara, S.: Robust re-identification by multiple views knowledge distillation. In: Computer Vision – ECCV 2020, pp. 93–110. Springer (2020)
24. Porrello, A., Bergamini, L., Calderara, S.: Robust re-identification by multiple views knowledge distillation. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12355, pp. 93–110. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58607-2_6
25. Shao, Z., Bian, H., Chen, Y., Wang, Y., Zhang, J., Ji, X., et al.: TransMIL: transformer based correlated multiple instance learning for whole slide image classification. Adv. Neural Inf. Process. Syst. (NeurIPS) 34, 2136–2147 (2021)
26. Srinidhi, C.L., Ciga, O., Martel, A.L.: Deep neural network models for computational histopathology: a survey. Med. Image Anal. 67, 101813 (2021)
27. Veličković, P., Cucurull, G., Casanova, A., Romero, A., Liò, P., Bengio, Y.: Graph attention networks. In: International Conference on Learning Representations (2018)
28. Zhang, H., et al.: DTFD-MIL: double-tier feature distillation multiple instance learning for histopathology whole slide image classification. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 18802–18812 (2022)
29. Zhang, L., Bao, C., Ma, K.: Self-distillation: towards efficient and compact neural networks. IEEE Trans. Pattern Anal. Mach. Intell. 44(8), 4388–4403 (2021)
30. Zhang, L., Song, J., Gao, A., Chen, J., Bao, C., Ma, K.: Be your own teacher: improve the performance of convolutional neural networks via self distillation. In: IEEE/CVF International Conference on Computer Vision (ICCV), pp. 3713–3722 (2019)
31. Zhao, Y., et al.: SETMIL: spatial encoding transformer-based multiple instance learning for pathological image analysis. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) Medical Image Computing and Computer Assisted Intervention – MICCAI 2022. LNCS, vol. 13432, pp. 66–76. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16434-7_7
32. Zhao, Y., et al.: Predicting lymph node metastasis using histopathological images based on multiple instance learning with deep graph convolution. In: IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), pp. 4837–4846 (2020)
SLPD: Slide-Level Prototypical Distillation for WSIs

Zhimiao Yu, Tiancheng Lin, and Yi Xu (✉)

MoE Key Lab of Artificial Intelligence, AI Institute, Shanghai Jiao Tong University, Shanghai, China
Abstract. Improving feature representations is the foundation of many whole slide image (WSI) analysis tasks. Recent works have achieved great success in pathology-specific self-supervised learning (SSL). However, most of them focus only on learning patch-level representations, so there is still a gap between the pretext tasks and slide-level downstream tasks, e.g., subtyping, grading and staging. Aiming at slide-level representations, we propose Slide-Level Prototypical Distillation (SLPD) to explore intra- and inter-slide semantic structures for context modeling on WSIs. Specifically, we iteratively perform intra-slide clustering of the regions (4096 × 4096 patches) within each WSI to yield prototypes, and encourage the region representations to be closer to their assigned prototypes. By representing each slide with its prototypes, we further select similar slides by the set-to-set distance of their prototypes and assign regions to cross-slide prototypes for distillation. SLPD achieves state-of-the-art results on multiple slide-level benchmarks and demonstrates that representation learning of the semantic structures of slides can make a suitable proxy task for WSI analysis. Code will be available at https://github.com/Carboxy/SLPD.
Keywords: Computational pathology · Whole slide images (WSIs) · Self-supervised learning

1 Introduction
In computational histopathology, visual representation extraction is a fundamental problem [14], serving as a cornerstone of (downstream) task-specific learning on whole slide pathological images (WSIs). Our community has witnessed the de facto representation learning paradigm progress from supervised ImageNet pre-training to self-supervised learning (SSL) [15,36]. Numerous pathological applications benefit from SSL, including classification of glioma [7], breast carcinoma [1], and non-small-cell lung carcinoma [25], mutation prediction [32], microsatellite instability prediction [31], and survival prediction from WSIs [2,16]. Among them, pioneering works [12,22,27] directly apply SSL algorithms developed for natural images (e.g., SimCLR [10], CPC [30] and MoCo [11]) to WSI analysis tasks, and the improved performance proves the effectiveness of SSL. However, WSIs are quite different from natural images in that they exhibit a hierarchical structure at gigapixel resolution. Subsequent works turn to designing pathology-specific tasks that explore the inherent characteristics of WSIs for representation learning, e.g., resolution-aware tasks [18,34,37] and color-aware tasks [2,38]. Since these pretext tasks encourage mining pathologically relevant patterns, the learned representations are expected to be more suitable for WSI analysis. Nevertheless, these works only learn representations at the patch level, i.e., the cellular organization, while neglecting macro-scale morphological features, e.g., tissue phenotypes and intra-tumoral heterogeneity. As a result, there is still a gap between the pre-trained representations and the downstream tasks, as the latter mainly operate at the slide level, e.g., subtyping, grading and staging.

Z. Yu and T. Lin contributed equally.

Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_25.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 259–269, 2023. https://doi.org/10.1007/978-3-031-43907-0_25

Z. Yu et al.

More recently, some works propose to close this gap by directly learning slide-level representations during pre-training. For instance, HIPT [8], a milestone work, introduces hierarchical pre-training (DINO [6]) at the patch level (256 × 256) and the region level (4096 × 4096) in a two-stage manner, achieving superior performance on slide-level tasks. SS-CAMIL [13] uses EfficientNet-B0 for image compression in the first stage and then performs multi-task learning on the compressed WSIs, assuming that primary-site information (e.g., the organ type) is always available and can be used as pseudo-labels.
SS-MIL [35] also proposes a two-stage pre-training framework for WSIs using contrastive learning (SimCLR [10]), where differently subsampled bags¹ from the same WSI are positive pairs in the second stage. A similar idea can be found in Giga-SSL [20], with delicate patch- and WSI-level augmentations. The aforementioned methods share the same two-stage pre-training paradigm, i.e., patch-to-region/slide, so that broader context information is preserved to close the gap between pretext and downstream tasks. However, they essentially perform instance discrimination, where only the self-invariance of a region/slide is considered, leaving the intra- and inter-slide semantic structures unexplored.

In this paper, we propose to encode the intra- and inter-slide semantic structures by modeling mutual region/slide relations, an approach we call SLPD: Slide-Level Prototypical Distillation for WSIs. Specifically, we perform slide-level clustering of the 4096 × 4096 regions within each WSI to yield prototypes, which characterize medically representative patterns of the tumor (e.g., morphological phenotypes). To learn this intra-slide semantic structure, we encourage the region representations to be closer to their assigned prototypes. By representing each slide with its prototypes, we further select semantically similar slides by the set-to-set distance of their prototypes. Then, we learn the inter-slide semantic structure by building correspondences between region representations and cross-slide prototypes. We conduct experiments on two benchmarks, NSCLC subtyping and BRCA subtyping. SLPD achieves state-of-the-art results on multiple slide-level tasks, demonstrating that representation learning of the semantic structures of slides can make a suitable proxy task for WSI analysis. We also perform extensive ablation studies to verify the effectiveness of the crucial model components.

¹ By formulating WSI tasks as a multi-instance learning problem, the WSI is treated as a bag with the corresponding patches as instances.

SLPD

[Fig. 1 panels: (a) Hierarchical Structure of WSI; (b) Two-Stage Pretraining; (c) Slide-Level Prototypical Distillation; (d) Global (Left) vs. Slide-Level Clustering (Right); (e) Intra-Slide Distillation; (f) Inter-Slide Distillation. Legend: Image, Patch, Region, WSI, Probability distribution, Prototype within slide, Prototypes across slides.]

Fig. 1. (a) A WSI possesses the hierarchical structure WSI-region-patch-image, from coarse to fine. (b) The two-stage pre-training paradigm successively performs image-to-patch and patch-to-region aggregation. (c-e) The proposed SLPD. SLPD explores the semantic structure by slide-level clustering. Besides self-distillation, region representations are associated with the prototypes within and across slides to comprehensively understand WSIs.
2 Method

2.1 Overview
As shown in Fig. 1(a), a WSI exhibits a hierarchical structure at varying resolutions under 20× magnification: 1) the 4096 × 4096 regions describing macro-scale organizations of cells, 2) the 256 × 256 patches capturing local clusters of cells, and 3) the 16 × 16 images characterizing fine-grained features at the cell level. Given $N$ unlabeled WSIs $\{w_1, w_2, \dots, w_N\}$, consisting of numerous regions $\{\{x_n^l\}_{l=1}^{L_n}\}_{n=1}^{N}$, where $L_n$ denotes the number of regions of WSI $w_n$, we aim to learn a powerful encoder that maps each $x_n^l$ to an embedding $z_n^l \in \mathbb{R}^D$. SLPD is built upon the two-stage pre-training paradigm proposed by HIPT, which is described in Sect. 2.2. Fig. 1(c-d) illustrates the pipeline of SLPD. We characterize the semantic structure of slides in Sect. 2.3, which is leveraged to establish the relationships within and across slides, leading to the proposed intra- and inter-slide distillation in Sect. 2.4 and Sect. 2.5.

2.2 Preliminaries
We revisit the Hierarchical Image Pyramid Transformer (HIPT) [8], a cutting-edge method for learning representations of WSIs via self-supervised vision transformers. As shown in Fig. 1(b), HIPT proposes a two-stage pre-training paradigm that considers the hierarchical structure of WSIs. In stage one, a patch-level vision transformer, denoted as ViT256-16, aggregates the non-overlapping 16 × 16 images within 256 × 256 patches to form patch-level representations. In stage two, the pre-trained ViT256-16 is frozen and used to tokenize the patches within 4096 × 4096 regions. Then a region-level vision transformer, ViT4096-256, aggregates these tokens to obtain region-level representations. With this hierarchical aggregation strategy, a WSI can be represented as a bag of region-level representations, which are then aggregated by another vision transformer, ViTWSI-4096, to perform slide-level prediction tasks. HIPT leverages DINO [6] to pre-train ViT256-16 and ViT4096-256, respectively. The learning objective of DINO is self-distillation. Taking stage two as an example, DINO distills knowledge from teacher to student by minimizing the cross-entropy between the probability distributions of two views at the region level:

$$\mathcal{L}_{self} = \mathbb{E}_{x \sim p_d}\, H\big(g_t(\hat{z}), g_s(z)\big) \qquad (1)$$
where $H(a, b) = -a \log b$, and $p_d$ is the data distribution from which all regions are drawn. The teacher and the student share the same architecture, consisting of an encoder (e.g., ViT) and a projection head $g_t$/$g_s$. $\hat{z}$ and $z$ are the region-level embeddings of two views produced by the encoder. The parameters of the teacher are an exponential moving average of the parameters of the student.

2.3 Slide-Level Clustering
Many histopathologic features have been established based on the morphologic phenotypes of the tumor, such as tumor invasion, anaplasia, necrosis and mitoses, which are then used for cancer diagnosis, prognosis and the estimation of response to treatment in patients [3,9]. To obtain meaningful representations of slides, we aim to explore and maintain such histopathologic features in the latent space. Clustering can reveal the representative patterns in the data and has achieved success in unsupervised representation learning [4,5,24,26]. To characterize the histopathologic features underlying the slides, a straightforward practice is global clustering, i.e., clustering the region embeddings from all the WSIs, as shown on the left of Fig. 1(d). However, the obtained cluster centers, i.e., the prototypes, tend to represent visual biases related to the staining or scanning procedure rather than medically relevant features [33]. Moreover, this clustering strategy ignores the hierarchical structure "region→WSI→whole dataset" underlying the data, where the ID of the WSI can serve as an extra learning signal. Therefore, we instead consider slide-level clustering, which clusters the embeddings within each WSI, as shown on the right of Fig. 1(d). Specifically, before the start of each epoch, we run the k-means algorithm over the $L_n$ region embeddings $\{z_n^l\}_{l=1}^{L_n}$ of $w_n$ to obtain $M$ prototypes $\{c_n^m \in \mathbb{R}^D\}_{m=1}^{M}$. The same operation is applied to every other slide, yielding $N$ groups of prototypes $\{\{c_n^m\}_{m=1}^{M}\}_{n=1}^{N}$. Each group of prototypes is expected to encode the semantic structure (e.g., the combination of histopathologic features) of the WSI.

2.4 Intra-Slide Distillation
The self-distillation used by HIPT in stage two encourages the correspondence between two views of a region at the macro scale, because the organizations of cells share mutual information spatially. However, self-distillation, which solely mines spatial correspondences inside the 4096 × 4096 region, cannot comprehensively capture the histopathologic consistency at the slide level. To achieve better representations, the histopathologic connections between the WSI and its regions, which we call intra-slide correspondences, should be modeled and learned. With the slide-level clustering proposed in Sect. 2.3, a slide can be abstracted by a group of prototypes, which capture the semantic structure of the WSI. As shown in Fig. 1(e), we assume that the representation $z$ and its assigned prototype $c$ also share mutual information, and we encourage $z$ to be closer to $c$ with the intra-slide distillation:

$$\mathcal{L}_{intra} = \mathbb{E}_{x \sim p_d}\, H\big(g_t(c), g_s(z)\big) \qquad (2)$$
We omit the super-/sub-scripts of $z$ for brevity. Through Eq. 2, we can leverage more intra-slide correspondences to guide the learning process. Intuitively, a prototype can be viewed as an augmented representation aggregating slide-level information. This distillation objective thus encodes such information into the corresponding region embedding, which makes the learning process aware of the semantic structure at the slide level.

2.5 Inter-Slide Distillation
Tumors of different patients can exhibit morphological similarities in some respects [17,21], so the correspondences across slides should be characterized during learning. Previous self-supervised learning methods applied to histopathologic images only capture such correspondences with positive pairs at the patch level [22,23], which overlooks the semantic structure of the WSI. We rethink this problem from the perspective of how to measure the similarity between two slides accurately. Due to the heterogeneity of slides, comparing them through local crops or averaged global features is susceptible to being one-sided. To address this, we bridge the slides with their semantic structures and define the semantic similarity between two slides $w_i$ and $w_j$ through an optimal bipartite matching between their two sets of prototypes:

$$D(w_i, w_j) = \max\left\{ \frac{1}{M} \sum_{m=1}^{M} \cos\big(c_i^m, c_j^{\sigma(m)}\big) \;\middle|\; \sigma \in S_M \right\}, \quad D(w_i, w_j) \in [-1, 1] \qquad (3)$$
where cos(·, ·) measures the cosine similarity between two vectors, and S_M enumerates the permutations of M elements. The optimal permutation σ* can be computed efficiently with the Hungarian algorithm [19]. With the proposed set-to-set similarity, we can model inter-slide correspondences conveniently and accurately.

The procedure has three steps. First, for a region embedding z that belongs to slide w and is assigned to prototype c, we search the dataset for the top-K nearest neighbors of w under the semantic similarity, denoted {ŵ_k}_{k=1}^K. Second, we obtain the matched prototype pairs {(c, ĉ_k)}_{k=1}^K determined by the optimal permutation, where ĉ_k is the prototype of ŵ_k matched to c. Finally, we encourage z to move closer to ĉ_k with the inter-slide distillation:

$$\mathcal{L}_{inter} = \mathbb{E}_{x \sim p_d}\left[\frac{1}{K} \sum_{k=1}^{K} H\left(g_t(\hat{c}_k), g_s(z)\right)\right]. \tag{4}$$
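The set-to-set similarity of Eq. (3) and the top-K neighbor search can be sketched in Python. This is a minimal illustration, not the authors' code. It assumes each slide is represented by an (M, d) NumPy array of prototypes. It uses SciPy's `linear_sum_assignment` with `maximize=True`, which solves the same linear assignment problem as the Hungarian algorithm. The helper names `cosine_matrix`, `slide_similarity` and `top_k_neighbors` are hypothetical.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def cosine_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of a (M, d) and b (M, d)."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T


def slide_similarity(protos_i: np.ndarray, protos_j: np.ndarray):
    """Eq. (3): mean cosine similarity under the best one-to-one prototype matching.

    Returns (D, sigma), where sigma[m] is the index of the prototype of slide j
    matched to prototype m of slide i.
    """
    sim = cosine_matrix(protos_i, protos_j)
    rows, cols = linear_sum_assignment(sim, maximize=True)
    # For a square matrix, rows is 0..M-1 in order, so cols is the permutation sigma.
    return float(sim[rows, cols].mean()), cols


def top_k_neighbors(query_idx: int, all_protos: list, k: int) -> list:
    """Indices of the k slides most similar to slide `query_idx` under Eq. (3)."""
    scores = []
    for j, protos in enumerate(all_protos):
        if j == query_idx:
            continue
        d, _ = slide_similarity(all_protos[query_idx], protos)
        scores.append((d, j))
    scores.sort(reverse=True)  # highest similarity first
    return [j for _, j in scores[:k]]
```

Because M is small (2 to 4 in Table 2), enumerating all M! permutations would also be feasible. The assignment solver simply avoids the factorial cost.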
The inter-slide distillation encodes slide-level information that complements the intra-slide distillation into the region embeddings. The overall learning objective of the proposed SLPD is defined as:

$$\mathcal{L}_{total} = \mathcal{L}_{self} + \alpha_1 \mathcal{L}_{intra} + \alpha_2 \mathcal{L}_{inter}, \tag{5}$$

where the loss weights are simply set to α_1 = α_2 = 1. We expect that tuning them could further improve performance.
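A minimal sketch of Eqs. (4) and (5) for a single region embedding follows. It makes three assumptions:

- g_t and g_s have already produced logit vectors.
- H is a DINO-style cross-entropy between a temperature-sharpened teacher distribution and the student distribution.
- The temperatures 0.04 and 0.1 are illustrative.

The function names are hypothetical, and details such as teacher centering are omitted.

```python
import numpy as np


def log_softmax(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Temperature-scaled log-softmax over the last axis (numerically stable)."""
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def cross_entropy(teacher_logits, student_logits, t_teacher=0.04, t_student=0.1):
    """H(a, b) = -sum(a * log b) with a sharpened teacher distribution a."""
    p_teacher = np.exp(log_softmax(teacher_logits, t_teacher))
    return -(p_teacher * log_softmax(student_logits, t_student)).sum(axis=-1)


def inter_slide_loss(teacher_proto_logits: np.ndarray, student_region_logits: np.ndarray) -> float:
    """Eq. (4) for one region: average H over the K matched prototypes.

    teacher_proto_logits: (K, P) array of g_t(c_hat_k); student_region_logits: (P,) array g_s(z).
    """
    h = cross_entropy(teacher_proto_logits, student_region_logits[None, :])
    return float(h.mean())


def total_loss(l_self: float, l_intra: float, l_inter: float, alpha1=1.0, alpha2=1.0) -> float:
    """Eq. (5) with the default weights alpha1 = alpha2 = 1."""
    return l_self + alpha1 * l_intra + alpha2 * l_inter
```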
3 Experimental Results
Datasets. We conduct experiments on two public WSI datasets². The TCGA-NSCLC dataset includes two lung cancer subtypes, Lung Squamous Cell Carcinoma and Lung Adenocarcinoma, with a total of 1,054 WSIs. The TCGA-BRCA dataset includes two breast cancer subtypes, Invasive Ductal Carcinoma and Invasive Lobular Carcinoma, with a total of 1,134 WSIs.

Pre-training. We extract 62,852 and 60,153 regions at 20× magnification from TCGA-NSCLC and TCGA-BRCA, respectively, to pre-train ViT4096-256 in stage two. We use the pre-trained ViT256-16 from stage one, provided by HIPT, to tokenize the patches within each region. Following the official HIPT code, ViT4096-256 is optimized for 100 epochs on 4 RTX 3090 GPUs. We use the AdamW optimizer, a base learning rate of 5e-4, and a batch size of 256.

² The data is released under a CC-BY-NC 4.0 international license.
SLPD
Fine-tuning. We use the pre-trained ViT256-16 and ViT4096-256 to extract embeddings at the patch level (256 × 256) and the region level (4096 × 4096) for downstream tasks. With the pre-extracted embeddings, we fine-tune three aggregators (i.e., MIL [28], DS-MIL [22] and ViTWSI-4096 [8]) for 20 epochs and follow the other settings in the official HIPT code.

Evaluation Metrics. We adopt 10-fold cross-validated accuracy (Acc.) and area under the curve (AUC) to evaluate weakly-supervised classification performance. The data splitting scheme is kept consistent with HIPT.

3.1 Weakly-Supervised Classification

Table 1. Slide-level classification. “Mean” uses the averaged pre-extracted embeddings to evaluate KNN performance.

| # | Aggregator | Feature extraction | Pretrain method | NSCLC Acc. | NSCLC AUC | BRCA Acc. | BRCA AUC |
|---|---|---|---|---|---|---|---|
| | Weakly supervised classification | | | | | | |
| 1 | MIL [28] | patch-level | DINO | 0.780±0.126 | 0.864±0.089 | 0.822±0.047 | 0.783±0.056 |
| 2 | MIL [28] | region-level | SLPD | 0.856±0.025 | 0.926±0.017 | 0.879±0.035 | 0.863±0.076 |
| 3 | DS-MIL [22] | patch-level | DINO | 0.825±0.054 | 0.905±0.059 | 0.847±0.032 | 0.848±0.075 |
| 4 | DS-MIL [22] | region-level | DINO | 0.841±0.036 | 0.917±0.035 | 0.854±0.032 | 0.848±0.075 |
| 5 | DS-MIL [22] | region-level | SLPD | 0.858±0.040 | 0.938±0.026 | 0.854±0.039 | 0.876±0.050 |
| 6 | ViTWSI-4096 [8] | region-level | DINO | 0.843±0.044 | 0.926±0.032 | 0.849±0.037 | 0.854±0.069 |
| 7 | ViTWSI-4096 [8] | region-level | DINO+L_intra | 0.850±0.042 | 0.931±0.041 | 0.866±0.030 | 0.881±0.069 |
| 8 | ViTWSI-4096 [8] | region-level | DINO+L_inter | 0.850±0.043 | 0.938±0.028 | 0.860±0.030 | 0.874±0.059 |
| 9 | ViTWSI-4096 [8] | region-level | SLPD | 0.864±0.042 | 0.939±0.022 | 0.869±0.039 | 0.886±0.057 |
| | K-nearest neighbors (KNN) evaluation | | | | | | |
| 10 | Mean | region-level | DINO | 0.770±0.031 | 0.840±0.038 | 0.837±0.014 | 0.724±0.055 |
| 11 | Mean | region-level | DINO+L_intra | 0.776±0.039 | 0.850±0.023 | 0.841±0.012 | 0.731±0.064 |
| 12 | Mean | region-level | DINO+L_inter | 0.782±0.027 | 0.854±0.025 | 0.845±0.014 | 0.738±0.080 |
| 13 | Mean | region-level | SLPD | 0.792±0.035 | 0.863±0.024 | 0.849±0.014 | 0.751±0.079 |
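The KNN protocol behind the “Mean” rows of Table 1 can be illustrated as follows. The sketch assumes that each slide embedding is the average of its pre-extracted region embeddings, and that neighbors are ranked by cosine similarity and combined by majority vote. The value of k, the similarity measure and the function names are assumptions, not settings reported in the text.

```python
import numpy as np


def slide_embedding(region_embeddings) -> np.ndarray:
    """'Mean' aggregation: average the pre-extracted region embeddings of one slide."""
    return np.asarray(region_embeddings, dtype=float).mean(axis=0)


def knn_predict(train_x: np.ndarray, train_y: np.ndarray, query: np.ndarray, k: int = 5):
    """Majority vote over the k training slides with the highest cosine similarity."""
    train_x = train_x / np.linalg.norm(train_x, axis=1, keepdims=True)
    query = query / np.linalg.norm(query)
    nearest = np.argsort(-(train_x @ query))[:k]
    labels, counts = np.unique(train_y[nearest], return_counts=True)
    return labels[np.argmax(counts)]
```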
We conduct experiments on two slide-level classification tasks, NSCLC subtyping and BRCA subtyping, and report the results in Table 1. The region-level embeddings generated by SLPD outperform the patch-level embeddings across two aggregators³ and both tasks (#1∼5). This suggests that learning representations with broader image context is better suited to WSI analysis. The strong baseline is the two-stage pre-training method proposed by HIPT (#6). Compared with it, SLPD improves AUC by 1.3% on NSCLC and 3.2% on BRCA (#9). Nontrivial improvements are also observed under KNN evaluation (#10 vs. #13): +2.3% and +3.1% AUC on NSCLC and BRCA. These results indicate that learning representations with an appropriate slide-level semantic structure can significantly narrow the gap between pre-training and downstream slide-level tasks. Moreover, intra-slide and inter-slide distillation each improve consistently over the baseline (#7 and #8 vs. #6; #11 and #12 vs. #10), corroborating the effectiveness of these critical components of SLPD.

³ Patch-level feature extraction is impractical for the ViT-based aggregator because of its quadratic memory complexity.

3.2 Ablation Study
Different Clustering Methods. As discussed in Sect. 2.3, we can alternatively use global clustering to obtain prototypes and then optimize the network with a distillation objective similar to Eq. 2. For a fair comparison, the two clustering methods use approximately the same total number of prototypes. Table 2 (#1, #2) reports the results. Slide-level clustering surpasses global clustering by 0.6% AUC on NSCLC and 1.8% AUC on BRCA, which verifies its effectiveness. We attribute the inferior performance of global clustering to the visual bias underlying the whole dataset.

Table 2. Ablation studies of SLPD. ViTWSI-4096 is the aggregator with region-level embeddings.

| # | Ablation | Method | NSCLC Acc. | NSCLC AUC | BRCA Acc. | BRCA AUC |
|---|---|---|---|---|---|---|
| 1 | Different clustering methods | global | 0.848±0.045 | 0.925±0.033 | 0.842±0.048 | 0.863±0.060 |
| 2 | Different clustering methods | slide-level | 0.850±0.042 | 0.931±0.041 | 0.866±0.030 | 0.881±0.069 |
| 3 | Different inter-slide distillations | region | 0.828±0.040 | 0.915±0.025 | 0.843±0.024 | 0.849±0.067 |
| 4 | Different inter-slide distillations | prototype | 0.850±0.043 | 0.938±0.028 | 0.860±0.030 | 0.874±0.059 |
| 5 | Number of prototypes | M = 2 | 0.859±0.036 | 0.936±0.021 | 0.869±0.039 | 0.886±0.057 |
| 6 | Number of prototypes | M = 3 | 0.864±0.035 | 0.938±0.022 | 0.861±0.056 | 0.878±0.069 |
| 7 | Number of prototypes | M = 4 | 0.864±0.042 | 0.939±0.022 | 0.860±0.031 | 0.872±0.060 |
| 8 | Number of slide neighbors | K = 1 | 0.864±0.042 | 0.939±0.022 | 0.869±0.039 | 0.886±0.057 |
| 9 | Number of slide neighbors | K = 2 | 0.862±0.039 | 0.938±0.029 | 0.875±0.038 | 0.889±0.057 |
| 10 | Number of slide neighbors | K = 3 | 0.869±0.034 | 0.936±0.024 | 0.873±0.051 | 0.880±0.058 |
Different Inter-slide Distillations. The proposed inter-slide distillation is aware of slide-level semantic structure because it builds the correspondence between a region embedding and its matched prototype (#4 in Table 2). To verify the necessity of this design, we compare it with an alternative in which the inter-slide correspondence is explored through the two nearest region embeddings across slides (#3 in Table 2). The region-level correspondences lead to inferior performance, even worse than the baseline (#6 in Table 1), because the learning process is not guided by slide-level information.

Number of Prototypes. As shown in Table 2 (#5∼7), the performance of SLPD is relatively robust to the number of prototypes on NSCLC, but it is somewhat affected by this number on BRCA. One possible reason is that the heterogeneity of invasive breast carcinoma is low [29], so too many prototypes may not yield medically meaningful clusters. Empirically, we set M = 4 on NSCLC and M = 2 on BRCA as the default configuration. We suggest choosing the number of prototypes with reference to clinical practice, considering tissue types, cell morphology, gene expression and other factors.

Number of Slide Neighbors. As demonstrated in Table 2 (#8∼10), the performance of SLPD is robust to the number of slide neighbors. Since more slide neighbors require more computational resources, we set K = 1 as the default configuration. For more results, please refer to the Supplementary Material.
4 Conclusion
This paper revisits slide-level representation learning from a novel perspective by considering intra- and inter-slide semantic structures. This leads to the proposed Slide-Level Prototypical Distillation (SLPD), a new self-supervised learning approach that achieves a more comprehensive understanding of WSIs. SLPD leverages slide-level clustering to characterize the semantic structures of slides. By representing slides as prototypes, mutual region/slide relations are further established and learned with the proposed intra- and inter-slide distillation. Extensive experiments on multiple WSI benchmarks show that SLPD achieves state-of-the-art results. Although SLPD is distillation-based, we plan to apply our idea to other pre-training methods in the future, e.g., contrastive learning [10,11].

Acknowledgement. This work was supported in part by NSFC 62171282, Shanghai Municipal Science and Technology Major Project (2021SHZDZX0102), 111 project BP0719010, STCSM 22DZ2229005, and SJTU Science and Technology Innovation Special Fund YG2022QN037.
References
1. Abbasi-Sureshjani, S., Yüce, A., Schönenberger, et al.: Molecular subtype prediction for breast cancer using H&E specialized backbone. In: MICCAI, pp. 1–9. PMLR (2021)
2. Abbet, C., Zlobec, I., Bozorgtabar, B., Thiran, J.-P.: Divide-and-rule: self-supervised learning for survival analysis in colorectal cancer. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12265, pp. 480–489. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59722-1_46
3. Amin, M.B., Greene, F.L., Edge, S.B., et al.: The eighth edition AJCC cancer staging manual: continuing to build a bridge from a population-based to a more personalized approach to cancer staging. CA Cancer J. Clin. 67(2), 93–99 (2017)
4. Caron, M., Bojanowski, P., Joulin, A., Douze, M.: Deep clustering for unsupervised learning of visual features. In: ECCV, pp. 132–149 (2018)
5. Caron, M., Misra, I., Mairal, J., Goyal, P., Bojanowski, P., Joulin, A.: Unsupervised learning of visual features by contrasting cluster assignments. Adv. Neural Inf. Process. Syst. 33, 9912–9924 (2020)
6. Caron, M., Touvron, H., Misra, I., et al.: Emerging properties in self-supervised vision transformers. In: ICCV, pp. 9650–9660 (2021)
7. Chen, L., Bentley, P., et al.: Self-supervised learning for media using image context restoration. MedIA 58, 101539 (2019)
8. Chen, R.J., et al.: Scaling vision transformers to gigapixel images via hierarchical self-supervised learning. In: CVPR, pp. 16144–16155 (2022)
9. Chen, R.J., Lu, M.Y., Williamson, D.F.K., et al.: Pan-cancer integrative histology-genomic analysis via multimodal deep learning. Cancer Cell 40(8), 865–878 (2022)
10. Chen, T., Kornblith, S., Norouzi, M., Hinton, G.: A simple framework for contrastive learning of visual representations. In: ICML, pp. 1597–1607. PMLR (2020)
11. Chen, X., Fan, H., Girshick, R., He, K.: Improved baselines with momentum contrastive learning. arXiv preprint arXiv:2003.04297 (2020)
12. Dehaene, O., Camara, A., Moindrot, O., et al.: Self-supervision closes the gap between weak and strong supervision in histology. arXiv preprint arXiv:2012.03583 (2020)
13. Fashi, P.A., Hemati, S., Babaie, M., Gonzalez, R., Tizhoosh, H.: A self-supervised contrastive learning approach for whole slide image representation in digital pathology. J. Pathol. Inform. 13, 100133 (2022)
14. Gurcan, M.N., Boucheron, L.E., Can, A., et al.: Histopathological image analysis: a review. IEEE Rev. Biomed. Eng. 2, 147–171 (2009)
15. He, K., Fan, H., Wu, Y., Xie, S., Girshick, R.: Momentum contrast for unsupervised visual representation learning. In: CVPR, pp. 9729–9738 (2020)
16. Huang, Z., Chai, H., Wang, R., Wang, H., Yang, Y., Wu, H.: Integration of patch features through self-supervised learning and transformer for survival analysis on whole slide images. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12908, pp. 561–570. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87237-3_54
17. Jass, J.R.: HNPCC and sporadic MSI-H colorectal cancer: a review of the morphological similarities and differences. Fam. Cancer 3, 93–100 (2004)
18. Koohbanani, N.A., Unnikrishnan, B., Khurram, S.A., Krishnaswamy, P., Rajpoot, N.: Self-path: self-supervision for classification of pathology images with limited annotations. IEEE Trans. Med. Imaging 40(10), 2845–2856 (2021)
19. Kuhn, H.W.: The Hungarian method for the assignment problem. Nav. Res. Logist. Q. 2(1–2), 83–97 (1955)
20. Lazard, T., Lerousseau, M., Decencière, E., Walter, T.: Self-supervised extreme compression of gigapixel images
21. Levy-Jurgenson, A., et al.: Spatial transcriptomics inferred from pathology whole-slide images links tumor heterogeneity to survival in breast and lung cancer. Sci. Rep. 10(1), 1–11 (2020)
22. Li, B., Li, Y., Eliceiri, K.W.: Dual-stream multiple instance learning network for whole slide image classification with self-supervised contrastive learning. In: CVPR, pp. 14318–14328 (2021)
23. Li, J., Lin, T., Xu, Y.: SSLP: spatial guided self-supervised learning on pathological images. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12902, pp. 3–12. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87196-3_1
24. Li, J., Zhou, P., Xiong, C., Hoi, S.C.: Prototypical contrastive learning of unsupervised representations. arXiv preprint arXiv:2005.04966 (2020)
25. Li, L., Liang, Y., Shao, M., et al.: Self-supervised learning-based multi-scale feature fusion network for survival analysis from whole slide images. Comput. Biol. Med. 153, 106482 (2023)
26. Li, Y., Hu, P., Liu, Z., Peng, D., Zhou, J.T., Peng, X.: Contrastive clustering. In: AAAI, vol. 35, pp. 8547–8555 (2021)
27. Lu, M.Y., Chen, R.J., Mahmood, F.: Semi-supervised breast cancer histology classification using deep multiple instance learning and contrast predictive coding (conference presentation). In: Medical Imaging 2020: Digital Pathology, vol. 11320, p. 113200J. International Society for Optics and Photonics (2020)
28. Lu, M.Y., Williamson, D.F., Chen, T.Y., et al.: Data-efficient and weakly supervised computational pathology on whole-slide images. Nat. Biomed. Eng. 5(6), 555–570 (2021)
29. Öhlschlegel, C., Zahel, K., Kradolfer, D., Hell, M., Jochum, W.: Her2 genetic heterogeneity in breast carcinoma. J. Clin. Pathol. 64(12), 1112–1116 (2011)
30. Oord, A.V.D., Li, Y., Vinyals, O.: Representation learning with contrastive predictive coding. arXiv preprint arXiv:1807.03748 (2018)
31. Saillard, C., Dehaene, et al.: Self supervised learning improves dMMR/MSI detection from histology slides across multiple cancers. arXiv preprint arXiv:2109.05819 (2021)
32. Saldanha, O.L., Loeffler, et al.: Self-supervised deep learning for pan-cancer mutation prediction from histopathology. bioRxiv, pp. 2022–09 (2022)
33. Sharma, Y., Shrivastava, A., Ehsan, L., et al.: Cluster-to-conquer: a framework for end-to-end multi-instance learning for whole slide image classification. In: Medical Imaging with Deep Learning, pp. 682–698. PMLR (2021)
34. Srinidhi, C.L., Kim, S.W., Chen, F.D., Martel, A.L.: Self-supervised driven consistency training for annotation efficient histopathology image analysis. MedIA 75, 102256 (2022)
35. Tavolara, T.E., Gurcan, M.N., Niazi, M.K.K.: Contrastive multiple instance learning: an unsupervised framework for learning slide-level representations of whole slide histopathology images without labels. Cancers 14(23), 5778 (2022)
36. Wu, Z., Xiong, Y., Yu, S.X., Lin, D.: Unsupervised feature learning via non-parametric instance discrimination. In: CVPR, pp. 3733–3742 (2018)
37. Xie, X., Chen, J., Li, Y., Shen, L., Ma, K., Zheng, Y.: Instance-aware self-supervised learning for nuclei segmentation. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12265, pp. 341–350. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59722-1_33
38. Yang, P., et al.: CS-CO: a hybrid self-supervised visual representation learning method for H&E-stained histopathological images. MedIA 81, 102539 (2022)
PET Image Denoising with Score-Based Diffusion Probabilistic Models

Chenyu Shen¹,², Ziyuan Yang², and Yi Zhang¹(B)

¹ School of Cyber Science and Engineering, Sichuan University, Chengdu, China
² College of Computer Science, Sichuan University, Chengdu, China

Abstract. Low-count positron emission tomography (PET) imaging is an effective way to reduce the radiation risk of PET, at the cost of a low signal-to-noise ratio. Mainstream denoising methods usually rely on paired data, which are not always available in clinical practice, so our study aims to denoise low-count PET images in an unsupervised mode. We adopt the diffusion probabilistic model because of its strong generative ability. Our model consists of two stages. In the training stage, we learn a score function network via evidence lower bound (ELBO) optimization. In the sampling stage, the trained score function and the low-count image are used to generate the corresponding high-count image under two handcrafted conditions. One condition is based on restoration in latent space, and the other is based on noise insertion in latent space. We therefore name our model the bidirectional condition diffusion probabilistic model (BC-DPM). Real patient whole-body data are used to evaluate our model. The experiments show that our model achieves better performance, both qualitatively and quantitatively, than several traditional and recently proposed learning-based methods.

Keywords: PET denoising · diffusion probabilistic model · latent space conditions

1 Introduction
Positron emission tomography (PET) is a nuclear medicine imaging modality that has been successfully applied in oncology, neurology, and cardiology. After a radioactive tracer is injected into the human body, molecular-level activity in tissues can be observed. To mitigate the radiation risk, it is essential to reduce the dose or shorten the scan time. Both lead to a low signal-to-noise ratio, which in turn degrades diagnostic accuracy.

Recently, the denoising diffusion probabilistic model (DDPM) [6,9,11] has become a hot topic in the generative modeling community. The original DDPM was designed for generation tasks, and many recent works have extended it to image restoration and image-to-image translation. In supervised mode, Saharia et al. [8] proposed a conditional DDPM for single-image super-resolution, which integrates a low-resolution image into each reverse step. In unsupervised mode, Choi et al. proposed iterative latent variable refinement (ILVR) [1] to handle the stochasticity of the generative process. ILVR enforces the given condition in each transition and thus generates images with the desired semantics.

DDPM has also been applied in medical imaging. To explore its generalization ability, Song et al. [12] proposed a fully unsupervised model for medical inverse problems that combines the measurement process with a prior distribution learned by a score-based generative model. For PET image denoising, Gong et al. [4] proposed two paradigms. The first directly feeds noisy PET images and anatomical priors (if available) into the score function network, which relies on paired high-quality and low-quality PET images. The second feeds only MR images into the score function network and uses noisy PET images in the inference stage. It assumes that PET image noise follows a Gaussian distribution.

In this paper, we propose a conditional diffusion probabilistic model for low-count PET image denoising in an unsupervised manner, without the Gaussian noise assumption or paired datasets. Our model is divided into two stages. In the training stage, we use the standard DDPM to train the score function network to learn a prior distribution of PET images. Once the network is trained, we transfer it to the sampling stage. There, we design two conditions to control the generation of high-count PET images given the corresponding low-count PET images:

- The denoised versions of low-count PET images are similar to high-count PET images.
- High-count PET images degrade to low-count PET images when noise is added to them.

We therefore name our model the bidirectional condition diffusion probabilistic model (BC-DPM). In particular, to simulate the formation of PET noise, we add noise in the sinogram domain. Additionally, the two proposed conditions are implemented in latent space. Notably, our model is ‘one for all’: once the score network is trained, it can be applied to PET images with different count levels.

Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_26.

© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 270–278, 2023.
https://doi.org/10.1007/978-3-031-43907-0_26

Fig. 1. Overview of our proposed BC-DPM model.

Fig. 2. The proposed conditional block in the sampling stage.
2 Method
Let X ⊂ 𝒳 be a high-count PET image dataset and Y ⊂ 𝒴 be a low-count PET image dataset, and let x_0 and y_0 denote instances in X and Y, respectively. Our goal is to estimate a mapping F(𝒴) = 𝒳, and the proposed BC-DPM provides an unsupervised technique to solve this problem. BC-DPM includes two stages. The training stage requires only X, without paired (X, Y). The sampling stage produces the denoised x_0 for a given y_0.

2.1 Training Stage
BC-DPM behaves the same as the original DDPM in the training stage: it consists of a forward process and a reverse process. In the forward process, x_0 is gradually contaminated by fixed Gaussian noise, producing a sequence of latent variables {x_1, x_2, ..., x_T}, where x_T ∼ N(0, I). The forward process can be described formally by a joint distribution q(x_{1:T} | x_0) given x_0. Under the Markov property, it is defined as:

$$q(x_{1:T} \mid x_0) := \prod_{t=1}^{T} q(x_t \mid x_{t-1}), \qquad q(x_t \mid x_{t-1}) := \mathcal{N}\left(x_t; \sqrt{1-\beta_t}\, x_{t-1}, \beta_t I\right), \tag{1}$$
where {β_1, β_2, ..., β_T} is a fixed variance schedule of small positive constants and I denotes the identity matrix. Notably, the forward process allows x_t to be sampled directly from x_0:

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon, \tag{2}$$
where ᾱ_t := ∏_{s=1}^{t} α_s, α_t := 1 − β_t, and ε ∼ N(0, I). The reverse process is defined by a Markov chain starting from p(x_T) = N(x_T; 0, I):

$$p_\theta(x_{0:T}) := p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t), \qquad p_\theta(x_{t-1} \mid x_t) := \mathcal{N}\left(x_{t-1}; \mu_\theta(x_t, t), \sigma_\theta(x_t, t) I\right). \tag{3}$$

Given the reverse process, p_θ(x_0) can be expressed as an integral over the variables x_{1:T}, p_θ(x_0) := ∫ p_θ(x_{0:T}) dx_{1:T}. The parameters θ can be updated by optimizing the following simple loss function:

$$\mathcal{L}_{simple}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\left[\left\| \epsilon - \epsilon_\theta\left(\sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1-\bar{\alpha}_t}\, \epsilon, t\right) \right\|^2\right]. \tag{4}$$
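The forward process of Eqs. (1) and (2) and one Monte-Carlo sample of the loss in Eq. (4) can be sketched with NumPy. The linear schedule from 1e-4 to 0.02 over T = 1000 steps is an assumption taken from the original DDPM work, not from this paper. `eps_model` is a hypothetical stand-in for ε_θ. The squared norm is averaged over pixels rather than summed, which only rescales the loss.

```python
import numpy as np


def linear_beta_schedule(T=1000, beta_start=1e-4, beta_end=0.02):
    """Fixed variance schedule {beta_t}; the linear range is an assumed default."""
    return np.linspace(beta_start, beta_end, T)


def alpha_bar(betas):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s), as used in Eq. (2)."""
    return np.cumprod(1.0 - betas)


def q_sample(x0, t, abar, eps):
    """Eq. (2): draw x_t directly from x_0 (t is 1-indexed as in the text)."""
    a = abar[t - 1]
    return np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps


def l_simple(eps_model, x0, abar, rng):
    """One-sample Monte-Carlo estimate of Eq. (4), averaged over pixels."""
    t = rng.integers(1, len(abar) + 1)      # t ~ Uniform{1, ..., T}
    eps = rng.standard_normal(x0.shape)     # eps ~ N(0, I)
    xt = q_sample(x0, t, abar, eps)
    return np.mean((eps - eps_model(xt, t)) ** 2)
```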
The noise predictor ε_θ(x_t, t) used in this paper largely follows the network proposed by Dhariwal et al. [3]. The pseudocode for the training stage is given in Algorithm 1.

Algorithm 1: Training stage.
repeat
    x_0 ∼ q(x_0)
    t ∼ Uniform({1, 2, ..., T})
    ε ∼ N(0, I)
    Update θ by optimizing E_{t,x_0,ε}[‖ε − ε_θ(√ᾱ_t x_0 + √(1 − ᾱ_t) ε, t)‖²]
until convergence
2.2 Sampling Stage
The main difference between BC-DPM and the original DDPM lies in the sampling stage. Because the reverse process p_θ(x_{0:T}) is stochastic, it is difficult for the original DDPM to generate images that satisfy a desired condition. To overcome this obstacle, the proposed BC-DPM models p_θ(x_0 | c) given a condition c instead of p_θ(x_0):

$$p_\theta(x_0 \mid c) = \int p_\theta(x_{0:T} \mid c)\, dx_{1:T}, \qquad p_\theta(x_{0:T} \mid c) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t, c). \tag{5}$$
The condition c derives from prior knowledge about the high-count PET image x_0 and the low-count PET image y_0. With c, BC-DPM can control the generation of x_0 given y_0. The core problem is then to design a proper condition c. A natural choice is D(y_0) ≈ x_0, that is, the restoration task itself. This does not cause a ‘deadlock’, for two reasons. First, the final form of the condition D(y_0) ≈ x_0 does not involve x_0. Second, we choose a relatively simple denoiser in the condition, so it can be viewed as a ‘coarse to fine’ operation. In practice, we use a Gaussian filter GF(·) as the denoiser in this condition.

However, a Gaussian filter usually produces smoothed images. Because of this property, we observe that the PSNR between GF(y_0) and x_0 is usually lower than that between GF(y_0) and GF(x_0). The condition GF(y_0) ≈ GF(x_0) is therefore more accurate than GF(y_0) ≈ x_0, and we choose GF(y_0) ≈ GF(x_0) in our experiments.

However, if we use only the above condition, the training is unstable and distortion may be observed. To address this problem, a second condition is introduced. The above condition concerns denoising; conversely, we can consider adding noise to x_0, that is, y_0 ≈ A(x_0). Following the characteristics of PET noise, Poisson noise is inserted in the sinogram domain instead of the image domain. We define this condition as P†(Po(P(x_0) + r + s)) ≈ y_0. Here P is the Radon transform, Po is Poisson noise insertion, P† is the inverse Radon transform, r denotes random coincidences, and s denotes scatter coincidences.

We now have two conditions, GF(y_0) ≈ GF(x_0) and P†(Po(P(x_0) + r + s)) ≈ y_0, from the perspectives of denoising and noise insertion, respectively. Since the conditions involve x_0, we have to convert them from the original data space into latent space under certain assumptions to avoid estimating x_0. We denote each transition in the reverse process under the global conditions as:

$$p_\theta(x_{t-1} \mid x_t, c_1, c_2) = p_\theta\left(x_{t-1} \mid x_t,\ \mathrm{GF}(x_0) = \mathrm{GF}(y_0),\ P^\dagger(Po(P(x_0) + r + s)) = y_0\right). \tag{6}$$
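The noise-insertion operator P†(Po(P(x_0) + r + s)) can be sketched as follows. To keep the example self-contained, a small random nonnegative system matrix stands in for the Radon transform P, and its pseudo-inverse stands in for P†. A real implementation would use an actual tomographic projector and reconstruction algorithm. All names here are hypothetical.

```python
import numpy as np


def make_toy_projector(n_pixels: int, n_bins: int, rng) -> np.ndarray:
    """Hypothetical stand-in for the Radon transform P: a random nonnegative system matrix."""
    return rng.uniform(0.0, 1.0, size=(n_bins, n_pixels))


def insert_pet_noise(x0, P, randoms, scatter, rng):
    """Sketch of P_dagger(Po(P(x0) + r + s)): Poisson noise in the sinogram domain.

    randoms and scatter are expected coincidence counts per sinogram bin; the
    Moore-Penrose pseudo-inverse of P plays the role of the inverse transform.
    """
    expected = P @ x0.ravel() + randoms + scatter   # noiseless sinogram with r and s
    counts = rng.poisson(expected).astype(float)    # Po(.): Poisson counting noise
    return (np.linalg.pinv(P) @ counts).reshape(x0.shape)
```

Because Poisson noise has variance equal to its mean, the relative noise level drops as counts increase. This is why low-count images are noisier.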
From Eq. (2), x_t is a linear combination of x_0 and ε, so we can express x_0 with x_t and ε:

$$x_0 \approx f_\theta(x_t, t) = \left(x_t - \sqrt{1-\bar{\alpha}_t}\, \epsilon_\theta(x_t, t)\right) / \sqrt{\bar{\alpha}_t}. \tag{7}$$

Similarly, applying the same diffusion process to y_0 yields {y_1, y_2, ..., y_T}, and y_0 can be expressed with y_t and ε:

$$y_0 \approx f_\theta(y_t, t) = \left(y_t - \sqrt{1-\bar{\alpha}_t}\, \epsilon_\theta(y_t, t)\right) / \sqrt{\bar{\alpha}_t}. \tag{8}$$

Replacing x_0 and y_0 with f_θ(x_t, t) and f_θ(y_t, t) in Eq. (6), respectively, we have:

$$p_\theta(x_{t-1} \mid x_t, c_1, c_2) \approx \mathbb{E}_{q(y_{t-1} \mid y_0)}\left[p_\theta\left(x_{t-1} \mid x_t,\ \mathrm{GF}(x_{t-1}) = \mathrm{GF}(y_{t-1}),\ P^\dagger(Po(P(x_{t-1}) + r + s)) = y_{t-1}\right)\right]. \tag{9}$$
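Eqs. (7) and (8) simply invert Eq. (2). The short check below assumes the same linear schedule as a standard DDPM; that schedule is not specified in the text. When the noise estimate equals the true noise, the inversion recovers x_0 exactly.

```python
import numpy as np


def predict_x0(xt: np.ndarray, t: int, abar: np.ndarray, eps_hat: np.ndarray) -> np.ndarray:
    """Eqs. (7)/(8): invert Eq. (2) given a noise estimate eps_hat = eps_theta(x_t, t).

    t is 1-indexed as in the text, so alpha_bar_t is abar[t - 1].
    """
    a = abar[t - 1]
    return (xt - np.sqrt(1.0 - a) * eps_hat) / np.sqrt(a)


# Usage: with the true noise, the estimate is exact (assumed linear schedule).
abar = np.cumprod(1.0 - np.linspace(1e-4, 0.02, 1000))
rng = np.random.default_rng(0)
x0 = rng.normal(size=(4, 4))
eps = rng.normal(size=(4, 4))
t = 500
xt = np.sqrt(abar[t - 1]) * x0 + np.sqrt(1.0 - abar[t - 1]) * eps
x0_hat = predict_x0(xt, t, abar, eps)
```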
Assume that

$$x_{t-1} = (1-\lambda)\left(\mathrm{GF}(y_{t-1}) + (I - \mathrm{GF})(x'_{t-1})\right) + \lambda\left(y_{t-1} + x'_{t-1} - P^\dagger(Po(P(x'_{t-1}) + r + s))\right), \tag{10}$$

where x'_{t-1} is sampled from p_θ(x_{t-1} | x_t), and λ ∈ [0, 1] is a factor balancing the two conditions. Thus, we have

$$\mathbb{E}_{q(y_{t-1} \mid y_0)}\left[p_\theta\left(x_{t-1} \mid x_t, \mathrm{GF}(x_{t-1}) = \mathrm{GF}(y_{t-1}), P^\dagger(Po(P(x_{t-1}) + r + s)) = y_{t-1}\right)\right] \approx p_\theta\left(x_{t-1} \mid x_t, \mathrm{GF}(x_{t-1}) = \mathrm{GF}(y_{t-1}), P^\dagger(Po(P(x_{t-1}) + r + s)) = y_{t-1}\right). \tag{11}$$
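The correction in Eq. (10) can be written as a single function. This sketch makes two choices the text does not specify. GF uses SciPy's `gaussian_filter` with an illustrative width σ = 1.5. The noise-insertion operator is a caller-supplied function.

```python
import numpy as np
from scipy.ndimage import gaussian_filter


def bidirectional_condition(x_prime, y_prev, degrade_fn, lam=0.5, sigma=1.5):
    """Eq. (10): blend the denoising and noise-insertion conditions.

    x_prime:    intermediate sample x'_{t-1} drawn from p_theta(x_{t-1} | x_t)
    y_prev:     y_{t-1}, the low-count image diffused to step t-1
    degrade_fn: callable implementing P_dagger(Po(P(.) + r + s))
    lam:        balancing factor lambda in [0, 1]
    sigma:      Gaussian filter width for GF (an illustrative value)
    """
    def gf(img):
        return gaussian_filter(img, sigma=sigma)

    denoise_term = gf(y_prev) + (x_prime - gf(x_prime))   # GF(y) + (I - GF)(x')
    noise_term = y_prev + x_prime - degrade_fn(x_prime)   # y + x' - A(x')
    return (1.0 - lam) * denoise_term + lam * noise_term
```

With λ = 0 only the denoising condition acts: the low frequencies of x_{t-1} come from y_{t-1} and the high frequencies from x'_{t-1}. With λ = 1 only the noise-insertion condition acts.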
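Finally, we have

$$p_\theta\left(x_{t-1} \mid x_t, \mathrm{GF}(x_0) = \mathrm{GF}(y_0), P^\dagger(Po(P(x_0) + r + s)) = y_0\right) = p_\theta\left(x_{t-1} \mid x_t, \mathrm{GF}(x_{t-1}) = \mathrm{GF}(y_{t-1}), P^\dagger(Po(P(x_{t-1}) + r + s)) = y_{t-1}\right), \tag{12}$$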
which indicates that, under the assumption of Eq. (10), the global conditions on (x_0, y_0) can be converted into local conditions on (x_{t-1}, y_{t-1}) in each transition from x_t to x_{t-1}. Now, given a low-count PET image y_0, we estimate x_0 by starting from white noise x_T and iterating the following two steps:

1. Generate an intermediate x'_{t-1} from p_θ(x_{t-1} | x_t).
2. Generate x_{t-1} from x'_{t-1} using Eq. (10).

In practice, there is no need to apply the two local conditions in every transition; they are needed only in the last l transitions. Generally speaking, the larger l is, the more blurred the image becomes; as l decreases, the image becomes noisier. The sampling procedure of BC-DPM is given in Algorithm 2.
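The two-step iteration can be sketched as an ancestral DDPM sampler that applies the Eq. (10) correction only in the last l transitions. This is an illustration under stated assumptions, not a reconstruction of Algorithm 2:

- `eps_model` and `condition_fn` are caller-supplied placeholders.
- σ_t² = β_t is one common variance choice.
- y_{t-1} is drawn from q(y_{t-1} | y_0) through Eq. (2).

```python
import numpy as np


def sample_bc_dpm(y0, eps_model, condition_fn, betas, l, rng):
    """Ancestral DDPM sampling with the Eq. (10) correction in the last l transitions.

    eps_model(x, t)          -> noise estimate eps_theta(x, t)
    condition_fn(x_prime, y) -> x_{t-1}, e.g. an implementation of Eq. (10)
    """
    alphas = 1.0 - betas
    abar = np.cumprod(alphas)
    T = len(betas)
    x = rng.standard_normal(y0.shape)                          # x_T ~ N(0, I)
    for t in range(T, 0, -1):
        beta_t, a_t, ab_t = betas[t - 1], alphas[t - 1], abar[t - 1]
        # Step 1: intermediate x'_{t-1} from p_theta(x_{t-1} | x_t), with sigma_t^2 = beta_t.
        mean = (x - beta_t / np.sqrt(1.0 - ab_t) * eps_model(x, t)) / np.sqrt(a_t)
        noise = rng.standard_normal(y0.shape) if t > 1 else np.zeros_like(y0)
        x_prime = mean + np.sqrt(beta_t) * noise
        if t > l:
            x = x_prime                                        # condition not applied yet
            continue
        # Step 2: y_{t-1} ~ q(y_{t-1} | y_0) via Eq. (2); y_0 itself at the final step.
        if t > 1:
            ab_prev = abar[t - 2]
            y_prev = np.sqrt(ab_prev) * y0 + np.sqrt(1.0 - ab_prev) * rng.standard_normal(y0.shape)
        else:
            y_prev = y0
        x = condition_fn(x_prime, y_prev)
    return x
```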
Algorithm 2: Sampling stage.
Input: low-count PET image y_0, parameters θ from the training stage, hyper-parameters λ and l
Output: high-count PET image x_0
x_T ∼ N(0, I)
for t = T, ..., 1 do
    if t …

… 0 is a scaling hyperparameter. Now, we design the following loss L_som so that SOM representations close to g^× on the grid are also close to z^× in the latent space (measured by the Euclidean distance ‖z^× − g_{i,j}‖_2):

$$\mathcal{L}_{som} := \mathbb{E}_{(x^u, x^v) \sim S}\left(\sum_{g_{i,j} \in G} w^u_{i,j} \cdot \|z^u - g_{i,j}\|_2^2 + w^v_{i,j} \cdot \|z^v - g_{i,j}\|_2^2\right). \tag{3}$$
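A sketch of L_som for one sample pair follows. Eq. 2, which defines the weights w_{i,j}, is not reproduced in this section. The code therefore assumes a normalized Gaussian kernel of width τ on the 2D grid, centred at the grid position of the SOM representation nearest to z^×. The function names are hypothetical. In an autodiff framework, z^× would additionally be wrapped in the stop-gradient described in the text; plain NumPy has no gradients, so that step does not appear here.

```python
import numpy as np


def nearest_grid_coord(z: np.ndarray, G: np.ndarray):
    """Grid coordinate (i, j) of the SOM representation closest to z. G has shape (Nr, Nc, d)."""
    d2 = ((G - z) ** 2).sum(axis=-1)
    return np.unravel_index(np.argmin(d2), d2.shape)


def neighborhood_weights(center, shape, tau: float) -> np.ndarray:
    """Assumed form of Eq. 2: normalized Gaussian kernel on the 2D grid around `center`."""
    ii, jj = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
    d2 = (ii - center[0]) ** 2 + (jj - center[1]) ** 2
    w = np.exp(-d2 / (2.0 * tau ** 2))
    return w / w.sum()


def som_loss_pair(z_u: np.ndarray, z_v: np.ndarray, G: np.ndarray, tau: float) -> float:
    """Eq. (3) for one pair (x^u, x^v): weighted squared distances to all g_ij."""
    total = 0.0
    for z in (z_u, z_v):
        w = neighborhood_weights(nearest_grid_coord(z, G), G.shape[:2], tau)
        total += float((w * ((G - z) ** 2).sum(axis=-1)).sum())
    return total
```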
LSOR
To improve robustness, we make two further changes to Eq. 3. First, we account for the SOM representations transitioning from random initialization to meaningful cluster centers that preserve the high-dimensional relationships within the 2D SOM grid. We do so by decreasing τ in Eq. 2 with each iteration, so that the weights gradually concentrate on SOM representations closer to g^× as training proceeds:

$$\tau(t) := N_r \cdot N_c \cdot \tau_{max} \left(\frac{\tau_{min}}{\tau_{max}}\right)^{t/T},$$

where τ_min and τ_max are the minimum and maximum standard deviations of the Gaussian kernel, t is the current iteration, and T is the maximum iteration. The second change to Eq. 3 is to apply the stop-gradient operator sg[·] [16] to z^×, which sets the gradients of z^× to 0 during the backward pass. The stop-gradient operator prevents z^× from being pulled towards a naive solution, i.e., different MRI samples being mapped to the same weighted average of all SOM representations. This risk is especially high in the early stages of training, when the SOM representations are randomly initialized and may not accurately represent the clusters.

Longitudinal Consistency Regularization. We derive a SOM grid related to brain aging by generating an age-stratified latent space. Specifically, the latent space is defined by a smooth trajectory field (Fig. 1, blue box) characterizing the morphological changes associated with brain aging. The smoothness rests on the assumption that MRIs with similar appearances (close representations in the latent space) should have similar trajectories. It is enforced by modeling the similarity between each subject-specific trajectory Δz and a reference trajectory that represents the average trajectory of the cluster. Specifically, Δg_{i,j} is the reference trajectory (Fig. 1, green arrow) associated with g_{i,j}. The reference trajectories of all clusters, G_Δ = {Δg_{i,j}}_{i=1,j=1}^{N_r,N_c}, then represent the average aging of the SOM clusters with respect to the training set.

As all subject-specific trajectories are iteratively updated during training, it is computationally infeasible to keep track of G_Δ over the whole training set. We instead propose to compute an exponential moving average (EMA) (Fig. 1, black box), which iteratively aggregates the batch-wise average trajectory into G_Δ:

$$\Delta g_{i,j} \leftarrow \begin{cases} \Delta h_{i,j} & t = 0 \\ \Delta g_{i,j} & t > 0 \text{ and } |\Omega_{i,j}| = 0 \\ \alpha \cdot \Delta g_{i,j} + (1-\alpha) \cdot \Delta h_{i,j} & t > 0 \text{ and } |\Omega_{i,j}| > 0 \end{cases}$$

with

$$\Delta h_{i,j} := \frac{1}{|\Omega_{i,j}|} \sum_{k=1}^{N_{bs}} \mathbf{1}[u_k = (i,j)] \cdot \Delta z_k \quad \text{and} \quad |\Omega_{i,j}| := \sum_{k=1}^{N_{bs}} \mathbf{1}[u_k = (i,j)].$$
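The annealing schedule τ(t) and the three-case EMA update of the reference trajectories can be sketched with NumPy. Grid coordinates u_k are assumed to be given as integer (i, j) pairs. An empty cell at t = 0 is initialised to zero, a case the equation leaves open. The function names are hypothetical.

```python
import numpy as np


def tau_schedule(t, T, n_rows, n_cols, tau_min, tau_max):
    """tau(t) = Nr * Nc * tau_max * (tau_min / tau_max) ** (t / T)."""
    return n_rows * n_cols * tau_max * (tau_min / tau_max) ** (t / T)


def update_reference_trajectories(G_delta, coords, dz, alpha, step):
    """Three-case EMA update of the reference trajectories Delta g_ij.

    G_delta: (Nr, Nc, d) current reference trajectories
    coords:  (Nbs, 2) integer grid index u_k of each sample pair in the batch
    dz:      (Nbs, d) subject-specific trajectories Delta z_k
    """
    G_new = G_delta.copy()
    n_rows, n_cols, dim = G_delta.shape
    for i in range(n_rows):
        for j in range(n_cols):
            mask = (coords[:, 0] == i) & (coords[:, 1] == j)
            count = int(mask.sum())                                  # |Omega_ij|
            h = dz[mask].mean(axis=0) if count > 0 else np.zeros(dim)  # Delta h_ij
            if step == 0:
                G_new[i, j] = h                                      # initialise (zero if cell empty)
            elif count > 0:
                G_new[i, j] = alpha * G_delta[i, j] + (1.0 - alpha) * h
            # step > 0 and count == 0: keep Delta g_ij unchanged
    return G_new
```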
Here, α is the EMA keep rate, k is the index of a sample pair, N_bs is the batch size, and 1[·] is the indicator function. |Ω_{i,j}| is the number of sample pairs with u = (i, j) within a batch. In each iteration, Δh_{i,j} (Fig. 1, purple arrow) is thus the batch-wise average of the subject-specific trajectories of sample pairs with u = (i, j). Through these iterative updates, G_Δ approximates the average trajectories derived from the entire training set.

J. Ouyang et al.

Lastly, inspired by [11,12], the longitudinal consistency regularization is formulated as

$$\mathcal{L}_{dir} := \mathbb{E}_{(x^u, x^v) \sim S}\left(1 - \cos\left(\theta[\Delta z, sg[\Delta g_u]]\right)\right),$$

where θ[·, ·] denotes the angle between two vectors. Since Δg is optimized by EMA, the stop-gradient operator is again incorporated so that the gradient of L_dir is computed only with respect to Δz.

Objective Function. The complete objective function is the weighted combination of the preceding losses with weighting parameters λ_commit, λ_som, and λ_dir:

$$\mathcal{L} := \mathcal{L}_{recon} + \lambda_{commit} \cdot \mathcal{L}_{commit} + \lambda_{som} \cdot \mathcal{L}_{som} + \lambda_{dir} \cdot \mathcal{L}_{dir}.$$

The objective function encourages a smooth trajectory field of aging in the latent space while maintaining interpretable SOM representations for analyzing brain age in a purely self-supervised fashion.

2.2 SOM Similarity Grid
During inference, a (2D) similarity grid ρ is computed from the closeness between the latent representation z of an MRI sample and the SOM representations:

$$\rho := \mathrm{softmax}\left(-\|z - G\|_2^2 / \gamma\right) \quad \text{with} \quad \gamma := \mathrm{std}\left(\|z - G\|_2^2\right),$$

where std denotes the standard deviation of the distances between z and all SOM representations. The SOM grid is learned to be associated with brain age (e.g., it represents aging from left to right). The similarity grid therefore essentially encodes a “likelihood function” of the brain age in z. Given all MRIs of a longitudinal scan, the change across the corresponding similarity grids over time represents the brain aging process of that individual. Brain aging at the group level is captured in two steps. We first compute the average similarity grid for each age group, then visualize the differences between those average grids across age groups.
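The similarity grid ρ and the group-level difference map can be sketched as follows, assuming G is stored as an (N_r, N_c, d) array of SOM representations. The function names are hypothetical. The code does not guard against γ = 0, which occurs when all distances are equal.

```python
import numpy as np


def similarity_grid(z: np.ndarray, G: np.ndarray) -> np.ndarray:
    """rho = softmax(-||z - g_ij||^2 / gamma), with gamma the std of those distances."""
    d2 = ((G - z) ** 2).sum(axis=-1)       # squared distances, shape (Nr, Nc)
    gamma = d2.std()
    logits = -d2 / gamma
    logits -= logits.max()                 # numerical stability
    rho = np.exp(logits)
    return rho / rho.sum()


def group_difference(grids_a, grids_b) -> np.ndarray:
    """Group-level aging map: average similarity grid of group b minus that of group a."""
    return np.mean(grids_b, axis=0) - np.mean(grids_a, axis=0)
```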
3 Experiments
3.1 Experimental Setting
Dataset. We evaluated the proposed method on the longitudinal T1-weighted MRIs of all 632 ADNI-1 [13] subjects with at least two visits (2389 MRIs in total). The data set consists of 185 NC subjects (age: 75.57 ± 5.06 years), 193 subjects diagnosed with sMCI (age: 75.63 ± 6.62 years), 135 subjects diagnosed with pMCI (age: 75.91 ± 5.35 years), and 119 subjects with AD (age: 75.17 ± 7.57 years). There was no significant age difference between the NC and AD cohorts (p = 0.55, two-sample t-test), nor between the sMCI and pMCI cohorts (p = 0.75). All MRIs were preprocessed by a pipeline including denoising, bias field correction, skull stripping, affine registration to a template, re-scaling to a 64 × 64 × 64 volume, and transforming image intensities to z-scores.
Fig. 2. The color at each SOM representation encodes the average value of (a) chronological age, (b) % of AD and pMCI, and (c) ADAS-Cog score across the training samples of that cluster; (d) restricted to the last row of the grid, the average of the 20 MRIs whose latent representations are closest to the corresponding SOM representation. (Color figure online)
Implementation Details. Let $C_k$ denote a Convolution (kernel size 3 × 3 × 3, $\mathrm{Conv}_k$)-BatchNorm-LeakyReLU (slope 0.2)-MaxPool (kernel size 2) block with $k$ filters, and $CD_k$ a Convolution-BatchNorm-LeakyReLU-Upsample block. The architecture was designed as $C_{16}$-$C_{32}$-$C_{64}$-$C_{16}$-$\mathrm{Conv}_{16}$-$CD_{64}$-$CD_{32}$-$CD_{16}$-$CD_{16}$-$\mathrm{Conv}_1$, which results in a latent space of 1024 dimensions. In practice, training the SOM with random initialization is difficult in this high-dimensional space; we therefore first pre-trained the model with only $\mathcal{L}_{recon}$ for 10 epochs and initialized the SOM representations by running k-means on the latent representations of all training samples obtained with this pre-trained model. The network was then trained for another 40 epochs with regularization weights $\lambda_{recon} = 1.0$, $\lambda_{commit} = 0.5$, $\lambda_{som} = 1.0$, and $\lambda_{dir} = 0.2$, using the Adam optimizer with a learning rate of $5 \times 10^{-4}$ and a weight decay of $10^{-5}$. $\tau_{min}$ and $\tau_{max}$ in $\mathcal{L}_{som}$ were set to 0.1 and 1.0, respectively. An EMA keep rate of $\alpha = 0.99$ was used to update the reference trajectories. The batch size was $N_{bs} = 64$ and the SOM grid size $N_r = 4$, $N_c = 8$.

Evaluation. We performed five-fold cross-validation (folds split by subject), using 10% of the training subjects for validation. The training data were augmented by flipping the brain hemispheres and by random rotation and translation. To quantify the interpretability of the SOM grid, we correlated the coordinates of the SOM grid with quantitative measures related to brain age, namely chronological age, the percentage of subjects with severe cognitive decline, and the Alzheimer's Disease Assessment Scale-Cognitive Subscale (ADAS-Cog) score. We illustrated the interpretability with respect to brain aging by visualizing the changes in the SOM similarity maps over time. We further visualized the trajectory vector field along with the SOM representations by projecting the 1024-dimensional representations onto the first two principal components of the SOM representations.
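The 2D visualization described above (projecting latent vectors onto the first two principal components of the SOM representations) could be sketched with a plain SVD as follows. The helper names and the random data are illustrative assumptions, not the authors' code; trajectory vectors such as $\Delta z$ are projected without mean-centering because they are displacements rather than positions.

```python
import numpy as np

def som_principal_axes(G: np.ndarray, n_components: int = 2):
    """Mean and top principal directions of the SOM representations G, shape (Nr, Nc, d)."""
    X = G.reshape(-1, G.shape[-1])                 # (Nr*Nc, d)
    mean = X.mean(axis=0)
    # Rows of Vt are orthonormal principal directions, sorted by decreasing variance
    _, _, Vt = np.linalg.svd(X - mean, full_matrices=False)
    return mean, Vt[:n_components]

def project_points(Z: np.ndarray, mean: np.ndarray, axes: np.ndarray) -> np.ndarray:
    """2D coordinates of latent vectors Z (n, d)."""
    return (Z - mean) @ axes.T

def project_vectors(dZ: np.ndarray, axes: np.ndarray) -> np.ndarray:
    """2D components of displacement vectors such as trajectories (no centering)."""
    return dZ @ axes.T

# Illustration with random data: a 4 x 8 grid in a 1024-dimensional latent space
rng = np.random.default_rng(1)
G = rng.normal(size=(4, 8, 1024))
mean, axes = som_principal_axes(G)
Z = rng.normal(size=(10, 1024))          # latent representations z
dZ = rng.normal(size=(10, 1024))         # trajectories, e.g. z_v - z_u
P, dP = project_points(Z, mean, axes), project_vectors(dZ, axes)
```

Fitting the axes on the SOM representations alone, rather than on all latent vectors, keeps the grid structure in view, which is what allows the projected grid to be drawn alongside the trajectory arrows.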
Lastly, we quantitatively evaluated the quality of the representations by applying them to the downstream tasks of classifying sMCI vs. pMCI and predicting the ADAS-Cog score. We measured classification accuracy via balanced accuracy (BACC) and the area under the ROC curve (AUC), and prediction accuracy via R² and root-mean-square error (RMSE). The classifier and predictor were multi-layer perceptrons containing two fully connected layers of dimensions 1024 and 64 with a LeakyReLU activation. We compared the accuracy metrics to models using the same architecture with encoders pre-trained by other representation learning methods, including unsupervised methods (AE, VAE [4]), a self-supervised method (SimCLR [1]), a longitudinal self-supervised method (LSSL [17]), and longitudinal neighborhood embedding (LNE [12]). All compared methods used the same experimental setup (e.g., encoder-decoder, learning rate, batch size, epochs), and the method-specific hyperparameters followed [12].

Fig. 3. The average similarity grid ρ over subjects of a specific age and diagnosis (NC vs AD). Each grid encodes the likelihood of the average brain age of the corresponding sub-cohort. Cog denotes the average ADAS-Cog score.

3.2 Results
Interpretability of SOM Embeddings. Fig. 2 shows the stratification of brain age over the SOM grid G. For each grid entry, we show the average value of chronological age (Fig. 2(a)), % of AD & pMCI (Fig. 2(b)), and ADAS-Cog score (Fig. 2(c)) over the samples of that cluster. We observed a trend toward older brain age (yellow) from the upper left to the lower right, corresponding to older chronological age and worse cognitive status. The SOM grid index strongly correlated with these three factors (distance correlation of 0.92, 0.94, and 0.91, respectively). Figure 2(d) shows the average brain over the 20 input images whose representations are closest to each SOM representation of the last row of the grid (see Supplementary Fig. S1 for all rows). From left to right, the ventricles enlarge and the brain atrophies, a hallmark of brain aging.

Interpretability of Similarity Grid. Visualizing the average similarity grid ρ of the NC and AD cohorts at each age range in Fig. 3, we observed that high similarity (yellow) gradually shifts towards the right with age in both NC and AD (see Supplementary Fig. S2 for the sMCI and pMCI cohorts). However, the shift is faster for AD, which aligns with the AD literature reporting that AD is linked to accelerated brain aging [15]. Furthermore, the subject-level aging effects shown in Supplementary Fig. S3 reveal that the proposed visualization can capture subtle morphological changes caused by brain aging.

Interpretability of Trajectory Vector Field. Fig. 4 plots the PCA projections of the latent space in 2D, which show a smooth trajectory field (gray arrows) and reference trajectories $G^{\Delta}$ (blue arrows) representing brain aging.
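The distance correlation reported above for the SOM grid index can be computed with the standard (biased, V-statistic) estimator of Székely et al. The sketch below is a hypothetical NumPy implementation and not necessarily the estimator used for the reported values; encoding each grid entry by its (row, column) coordinates and the per-cluster mean ages in the usage example are assumptions for illustration only.

```python
import numpy as np

def distance_correlation(x, y) -> float:
    """Sample distance correlation between paired samples x (n,) or (n, p) and y (n,) or (n, q)."""
    x = np.asarray(x, dtype=float).reshape(len(x), -1)
    y = np.asarray(y, dtype=float).reshape(len(y), -1)
    # Pairwise Euclidean distance matrices, shape (n, n)
    a = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    b = np.linalg.norm(y[:, None, :] - y[None, :, :], axis=-1)
    # Double centering: subtract row and column means, add back the grand mean
    A = a - a.mean(axis=0, keepdims=True) - a.mean(axis=1, keepdims=True) + a.mean()
    B = b - b.mean(axis=0, keepdims=True) - b.mean(axis=1, keepdims=True) + b.mean()
    dcov2 = (A * B).mean()                            # squared distance covariance
    dvar2 = np.sqrt((A * A).mean() * (B * B).mean())  # product of distance std devs
    return float(np.sqrt(max(dcov2, 0.0) / dvar2)) if dvar2 > 0 else 0.0

# Hypothetical example: (row, col) coordinates of a 4 x 8 grid vs. per-cluster mean age
rows, cols = np.meshgrid(np.arange(4), np.arange(8), indexing="ij")
coords = np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)   # (32, 2)
mean_age = 70.0 + 1.5 * rows.ravel() + 0.8 * cols.ravel()              # made-up values
dcor = distance_correlation(coords, mean_age)
```

Unlike Pearson correlation, distance correlation accepts a 2D grid coordinate directly and is zero only under independence, which suits relating a grid position to a scalar measure.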
Table 1. Supervised downstream tasks using the learned representations z (without fine-tuning the encoder). LSOR achieved comparable or higher accuracy scores than other state-of-the-art self- and unsupervised methods.
Fig. 4. 2D PCA of the LSOR’s latent space. Light gray arrows represent Δz. The orange grid represents the relationships between SOM representations and associated reference trajectory ΔG (blue arrow). (Color figure online)
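| Methods | sMCI/pMCI BACC | sMCI/pMCI AUC | ADAS-Cog R² | ADAS-Cog RMSE |
|---|---|---|---|---|
| AE | 62.6 | 65.4 | 0.26 | 6.98 |
| VAE [4] | 61.3 | 64.8 | 0.23 | 7.17 |
| SimCLR [1] | 63.3 | 66.3 | 0.26 | 6.79 |
| LSSL [17] | 69.4 | 71.8 | 0.29 | 6.49 |
| LNE [12] | 70.6 | 72.1 | 0.30 | 6.46 |
| LSOR | 69.8 | 72.4 | 0.32 | 6.31 |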
This projection also preserved the 2D grid structure (orange) of the SOM representations, suggesting that aging was the most important source of variation in the latent space.

Downstream Tasks. To evaluate the quality of the learned representations, we froze the encoders trained by each method (no fine-tuning) and used their representations for the downstream tasks (Table 1). On sMCI vs. pMCI classification (Table 1, left), the proposed method achieved a BACC of 69.8 and an AUC of 72.4, comparable (p > 0.05, DeLong's test) to LSSL [17] and LNE [12], two state-of-the-art self-supervised methods for this task. On ADAS-Cog score regression, the proposed method obtained the best accuracy, with an R² of 0.32 and an RMSE of 6.31. Accurate prediction of the ADAS-Cog score is very challenging because of its large range (0 to 70) and its subjectivity, which results in large variability across exams [2]; even larger RMSEs have been reported for this task [7]. Furthermore, our representations were learned without supervision, so fine-tuning the encoder could further improve the prediction accuracy.
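For reference, the four metrics in Table 1 follow their standard definitions. The NumPy sketch below spells them out, assuming 0/1 labels (e.g., sMCI = 0, pMCI = 1) and computing AUC from classifier scores via the pairwise (Mann-Whitney) formulation; the function names are hypothetical, and DeLong's test is not included.

```python
import numpy as np

def balanced_accuracy(y_true, y_pred) -> float:
    """Mean of the per-class recalls."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))

def roc_auc(y_true, scores) -> float:
    """Probability that a random positive scores higher than a random negative (ties count 0.5)."""
    y_true, scores = np.asarray(y_true), np.asarray(scores, dtype=float)
    diff = scores[y_true == 1][:, None] - scores[y_true == 0][None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

def r2_score(y_true, y_pred) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true, y_pred = np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

def rmse(y_true, y_pred) -> float:
    """Root-mean-square error, in the units of the target (ADAS-Cog points)."""
    y_true, y_pred = np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))
```

Balanced accuracy averages recall over the two classes, so it is not inflated when one class (here sMCI, 193 vs. 135 subjects) is more frequent.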
4 Conclusion
In this work, we proposed LSOR, the first SOM-based learning framework for longitudinal MRIs that is self-supervised and interpretable. By incorporating a soft SOM regularization, the training of the SOM was stable in the high-dimensional latent space of MRIs. By regularizing the latent space based on longitudinal consistency as defined by longitudinal MRIs, the latent space formed a smooth trajectory field capturing brain aging, as shown by the resulting SOM grid. The interpretability of the representations was confirmed by the correlation between the SOM grid and cognitive measures, and by the SOM similarity map.
When evaluated on the downstream tasks of sMCI vs. pMCI classification and ADAS-Cog prediction, LSOR was comparable to or better than representations learned by other state-of-the-art self- and unsupervised methods. In conclusion, LSOR is able to generate, purely from MRIs, a latent space with high interpretability regarding brain age, along with valuable representations for downstream tasks.

Acknowledgement. This work was partly supported by funding from the National Institutes of Health (MH113406, DA057567, AA017347, AA010723, AA005965, and AA028840), the DGIST R&D program of the Ministry of Science and ICT of KOREA (22-KUJoint-02), Stanford's Department of Psychiatry & Behavioral Sciences Faculty Development & Leadership Award, and by Stanford HAI Google Cloud Credit.
References
1. Chen, T., Kornblith, S., Norouzi, M., Hinton, G.: A simple framework for contrastive learning of visual representations. In: International Conference on Machine Learning, pp. 1597–1607. PMLR (2020)
2. Connor, D.J., Sabbagh, M.N.: Administration and scoring variance on the ADAS-Cog. J. Alzheimers Dis. 15(3), 461–464 (2008)
3. Fortuin, V., Hüser, M., Locatello, F., Strathmann, H., Rätsch, G.: SOM-VAE: interpretable discrete representation learning on time series. In: International Conference on Learning Representations (2019)
4. Kingma, D.P., Welling, M.: Auto-encoding variational bayes. arXiv preprint arXiv:1312.6114 (2013)
5. Kohonen, T.: The self-organizing map. Proc. IEEE 78(9), 1464–1480 (1990)
6. Li, O., Liu, H., Chen, C., Rudin, C.: Deep learning for case-based reasoning through prototypes: a neural network that explains its predictions. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 32 (2018)
7. Ma, D., Pabalan, C., Interian, Y., Raj, A.: Multi-task learning and ensemble approach to predict cognitive scores for patients with Alzheimer’s disease. bioRxiv, pp. 2021–12 (2021)
8. Manduchi, L., Hüser, M., Vogt, J., Rätsch, G., Fortuin, V.: DPSOM: deep probabilistic clustering with self-organizing maps. In: Conference on Neural Information Processing Systems Workshop on Machine Learning for Health (2019)
9. Molnar, C.: Interpretable machine learning (2020)
10. Mulyadi, A.W., Jung, W., Oh, K., Yoon, J.S., Lee, K.H., Suk, H.I.: Estimating explainable Alzheimer’s disease likelihood map via clinically-guided prototype learning. Neuroimage 273, 120073 (2023)
11. Ouyang, J., Zhao, Q., Adeli, E., Zaharchuk, G., Pohl, K.M.: Self-supervised learning of neighborhood embedding for longitudinal MRI. Med. Image Anal. 82, 102571 (2022)
12. Ouyang, J., et al.: Self-supervised longitudinal neighbourhood embedding. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12902, pp. 80–89. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87196-3_8
13. Petersen, R.C., et al.: Alzheimer’s disease neuroimaging initiative (ADNI): clinical characterization. Neurology 74(3), 201–209 (2010)
14. Rudin, C.: Stop explaining black box machine learning models for high stakes decisions and use interpretable models instead. Nat. Mach. Intell. 1(5), 206–215 (2019)
15. Toepper, M.: Dissociating normal aging from Alzheimer’s disease: a view from cognitive neuroscience. J. Alzheimers Dis. 57(2), 331–352 (2017)
16. Van Den Oord, A., Vinyals, O., et al.: Neural discrete representation learning. Adv. Neural Inf. Process. Syst. 30 (2017)
17. Zhao, Q., Liu, Z., Adeli, E., Pohl, K.M.: Longitudinal self-supervised learning. Med. Image Anal. 71, 102051 (2021)
Self-supervised Learning for Physiologically-Based Pharmacokinetic Modeling in Dynamic PET
Francesca De Benetti¹(B), Walter Simson², Magdalini Paschali³, Hasan Sari⁴, Axel Rominger⁵, Kuangyu Shi¹,⁵, Nassir Navab¹, and Thomas Wendler¹
1 Chair for Computer-Aided Medical Procedures and Augmented Reality, Technische Universität München, Garching, Germany
2 Department of Radiology, Stanford University School of Medicine, Stanford, USA
3 Department of Psychiatry and Behavioral Sciences, Stanford University School of Medicine, Stanford, USA
4 Advanced Clinical Imaging Technology, Siemens Healthcare AG, Lausanne, Switzerland
5 Department of Nuclear Medicine, Bern University Hospital, Bern, Switzerland
Abstract. Dynamic Positron Emission Tomography imaging (dPET) provides temporally resolved images of a tracer. Voxel-wise physiologically-based pharmacokinetic modeling of the Time Activity Curves (TACs) extracted from dPET can provide relevant diagnostic information for the clinical workflow. Conventional fitting strategies for TACs are slow and ignore the spatial relation between neighboring voxels. We train a spatio-temporal UNet to estimate the kinetic parameters given the TACs from dPET. This work introduces a self-supervised loss formulation to enforce the similarity between the measured TACs and those generated with the learned kinetic parameters. Our method provides quantitatively comparable results at the organ level to the significantly slower conventional approaches, while generating pixel-wise kinetic parametric images that are consistent with expected physiology. To the best of our knowledge, this is the first self-supervised network that allows voxel-wise computation of kinetic parameters consistent with a non-linear kinetic model.

Keywords: Kinetic modelling · PBPK models · Dynamic PET

1 Introduction
Positron Emission Tomography (PET) is a 3D imaging modality that uses radiopharmaceuticals, such as F-18-fluorodeoxyglucose (FDG), as tracers. Newly introduced long axial field-of-view PET scanners have enabled dynamic PET (dPET) with frame durations < 1 min [18], allowing the observation of dynamic metabolic processes throughout the body. For a given voxel, the radioactivity concentration over time can be described by a characteristic curve known as the Time Activity Curve (TAC), measured in [Bq/ml]. TACs can be described by mathematical functions called physiologically-based pharmacokinetic (PBPK) models or kinetic models (KM) [14]. The parameters of the KM represent physiologically relevant quantities and are often called micro-parameters, whereas their combinations are called macro-parameters [14,20]. While the former can be retrieved only by methods that directly use the KM function, the latter can be computed by simplified linearized methods (such as the Logan and the Patlak-Gjedde plots). The approaches to extract KM parameters can be split into two categories: Volume of Interest (VoI) methods, in which the average TAC in a VoI is used, and voxel-based methods. Although VoI-based methods display less noise and, therefore, lower variance in the kinetic parameters (KPs), they only provide organ-wise information. Voxel-based methods, on the other hand, allow the generation of parametric images (KPIs), in which the KPs are visualized at the voxel level, but they suffer from motion and breathing artifacts and require either more computational power or simplified linearized methods. Parametric images are reported to be superior in lesion detection and delineation compared to the standard-of-care activity- and weight-normalized static PET volumes, known as Standard Uptake Value (SUV) volumes [6,7]. Changes in the KPs during oncological therapy are associated with pathological response to treatment, whereas this is not true for changes in SUV [15]. Despite the advantages of KPIs in diagnosis, the generation of accurate micro-parametric images is not yet possible in clinical practice.

Supplementary Information: The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_28.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 290–299, 2023. https://doi.org/10.1007/978-3-031-43907-0_28

Spatio-Temporal Self-supervised Network for PBPK Modeling
To address the generation of micro-parametric images, we propose a custom 3D UNet [3] that estimates kinetic micro-parameters in a self-supervised setting, drawing inspiration from physics-informed neural networks (PINNs). The main contributions of this work are:
– A self-supervised formulation of the problem of kinetic micro-parameter estimation
– A spatio-temporal deep neural network for parametric image estimation
– A quantitative and qualitative comparison with conventional methods for PBPK modeling

The code is available at: https://github.com/FrancescaDB/self_supervised_PBPK_modelling.

1.1 Related Work
Finding the parameters of a KM is a classical optimization problem [2,19,21], solved by fitting the KM equation to a measured TAC in a least-squares sense [1,14,17]. The non-linearity of the KM functions makes this approach prone to overfitting and local minima, and sensitive to noise [14]. Therefore, non-linear parametric imaging is still too noisy for clinical application [20].
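To make the least-squares formulation concrete, here is a sketch that fits the irreversible 2TC model with blood volume (Eq. 2 in Sect. 2.1) to a single noiseless synthetic TAC with `scipy.optimize.curve_fit`, using the parameter bounds listed in Sect. 2.2. The input function, the uniform 0.1-min time grid, the initial guess, and the "true" parameters are made-up assumptions. How the "step equal to 0.001" setting of Sect. 2.3 maps onto `curve_fit` arguments is not stated, so it is omitted here. Real TACs are noisy and sampled on the scanner's non-uniform frame grid, which is exactly where the sensitivity to noise and local minima noted above becomes an issue.

```python
import numpy as np
from scipy.optimize import curve_fit

DT = 0.1                                   # min, uniform grid (assumption)
t = np.arange(0.0, 60.0, DT)
# Hypothetical image-derived input function A(t) [Bq/ml]: early peak plus slow tail
A = 50.0 * t * np.exp(-t / 0.5) + 2.0 * (1.0 - np.exp(-t / 2.0))

def tac_2tc_irreversible(t, K1, k2, k3, VB):
    """Irreversible 2TC model with blood volume (Eq. 2), evaluated on the grid t."""
    irf = K1 / (k2 + k3) * (k3 + k2 * np.exp(-(k2 + k3) * t))   # impulse response
    tissue = np.convolve(irf, A)[: len(t)] * DT                 # discrete (irf * A)(t)
    return (1.0 - VB) * tissue + VB * A

# Synthetic, noiseless "measured" TAC from known parameters (K1, k2, k3, VB)
true_params = (0.6, 0.9, 0.08, 0.15)
measured = tac_2tc_irreversible(t, *true_params)

# Bounds as listed for the network output, in the order K1, k2, k3, VB
lower = [0.01, 0.01, 0.01, 0.0]
upper = [2.0, 3.0, 1.0, 1.0]
popt, pcov = curve_fit(tac_2tc_irreversible, t, measured,
                       p0=[0.5, 0.5, 0.1, 0.1], bounds=(lower, upper))
```

When bounds are given, `curve_fit` switches to the trust-region reflective solver. Running one such fit per voxel is what makes the conventional approach slow and blind to the spatial relation between neighboring voxels.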
F. De Benetti et al.
To limit the drawbacks of non-linear parameter fitting, KPs are commonly identified using simplified linearized versions of the KM [6,20], such as the Logan and the Patlak-Gjedde plots [5,16], which are often included in clinical software for KM such as PMOD¹. Deep learning for KM parameter estimation in dPET imaging has only recently begun to be explored. Moradi et al. used an auto-encoder along with a Gaussian process regression block to select the best KM to describe simulated kinetic data [13]. A similar approach was presented for the quantification of myocardial blood flow from simulated PET sinograms [11]. Huang et al. used a supervised 3D U-Net to predict a macro-parametric image using an SUV image as input and a Patlak image derived from a dPET acquisition as ground truth [9]. Cui et al. proposed a conditional deep image prior framework to predict a macro-parametric image using a DNN in an unsupervised setting [4]. Finally, a supervised DNN was used to predict Patlak KPIs from dynamic PET sinograms [12]. So far, existing methods have used simulated data [11,13] or static PET [9], have been supervised [9,11–13], or have predicted macro-parameters [4,9,12].
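The Patlak-Gjedde plot mentioned above linearizes the irreversible model: after an equilibration time $t^*$, plotting $C(t)/A(t)$ against $\int_0^t A(\tau)\,d\tau / A(t)$ gives a straight line whose slope is the macro-parameter $K_i = K_1 k_3/(k_2 + k_3)$. The sketch below is a textbook version under these assumptions (trapezoidal integration, an ordinary least-squares line fit, a hypothetical function name `patlak_ki`), not the implementation of any of the cited tools. The usage example uses a constant input function because the tissue curve then has a closed form.

```python
import numpy as np

def patlak_ki(t, tac, cp, t_star):
    """Patlak-Gjedde analysis: slope Ki and intercept of the line fitted to frames t >= t_star.

    x = int_0^t cp / cp(t) and y = tac(t) / cp(t); for an irreversible tracer the
    plot becomes linear after equilibration, with slope Ki = K1 k3 / (k2 + k3).
    """
    t, tac, cp = (np.asarray(v, dtype=float) for v in (t, tac, cp))
    # Cumulative trapezoidal integral of the input function
    cum_cp = np.concatenate([[0.0], np.cumsum(np.diff(t) * (cp[1:] + cp[:-1]) / 2)])
    late = t >= t_star
    x = cum_cp[late] / cp[late]
    y = tac[late] / cp[late]
    slope, intercept = np.polyfit(x, y, 1)
    return slope, intercept

# Usage: constant input A(t) = 1, for which the irreversible 2TC tissue curve is analytic
K1, k2, k3 = 0.6, 0.9, 0.08
kp = k2 + k3
t = np.linspace(0.0, 60.0, 601)
cp = np.ones_like(t)
tac = K1 * k3 / kp * t + K1 * k2 / kp**2 * (1.0 - np.exp(-kp * t))
ki, v0 = patlak_ki(t, tac, cp, t_star=20.0)
```

Only the combination $K_i$ is recovered this way; the individual micro-parameters $K_1$, $k_2$, and $k_3$ are not, which is the limitation of linearized methods that motivates this work.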
2 Methodology and Materials

We propose to compute the kinetic micro-parameters in a self-supervised setting by directly including the KM function in the loss function and comparing the predicted TAC to the measured TAC. For this reason, an understanding of the KM is fundamental to describing our pipeline.

2.1 Kinetic Modelling
The concentration of the tracer $\tilde{C}(t)$ [Bq/ml] in each tissue can be described by a set of ordinary differential equations [20]. It represents the interaction of two compartments, $F(t)$ (free) and $B(t)$ (bound), and takes as input the radioactivity concentration in blood or plasma $A(t)$ [14]:

$$\frac{dF(t)}{dt} = K_1 A(t) - (k_2 + k_3) F(t), \qquad \frac{dB(t)}{dt} = k_3 F(t) - k_4 B(t), \qquad \tilde{C}(t) = F(t) + B(t) \tag{1}$$
where $K_1$ [ml/cm³/min], $k_2$ [1/min], $k_3$ [1/min], and $k_4$ [1/min] are the micro-parameters [6,20]. Equation 1 describes a general two-tissue compartment (2TC) kinetic model. However, an FDG TAC is conventionally described by an irreversible 2TC, in which $k_4$ is set to 0 [20]. Therefore, in the following, we use $k_4 = 0$. Moreover, including the blood fraction volume $V_B$ [·] makes it possible to correctly model the contribution to the radioactivity in a voxel that comes from vessels too small to be resolved by the PET scanner [14].

¹ https://www.pmod.com.
Fig. 1. Proposed Pipeline: a sequence of dPET slices (a) is processed by the proposed DNN (b) to estimate the 4 KPIs (c). During training, the predicted TACs are computed by the KM (d) and compared to the measured TACs using the mean squared error (e).
Together, the TAC of each voxel in an FDG dPET acquisition can be modeled as $C(t) = (1 - V_B)\tilde{C}(t) + V_B A(t)$ and solved using the Laplace transform [20]:

$$C(t) = (1 - V_B)\,\frac{K_1}{k_2 + k_3}\left(k_3 + k_2 e^{-(k_2 + k_3)t}\right) * A(t) + V_B A(t), \tag{2}$$

where $*$ denotes convolution in time.
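Equations 1 and 2 describe the same curve, which can be checked numerically. The sketch below integrates Eq. 1 (with $k_4 = 0$) by forward Euler and evaluates Eq. 2 as a discrete convolution on the same uniform grid; both are first-order approximations, so they agree up to an error that shrinks with the step size. The input function, time grid, and parameter values are illustrative assumptions, not data from this study.

```python
import numpy as np

def tac_ode(t, A, K1, k2, k3, VB):
    """Forward-Euler integration of Eq. 1 with k4 = 0, plus the blood term."""
    F = np.zeros_like(t)   # free compartment
    B = np.zeros_like(t)   # bound compartment
    for i in range(1, len(t)):
        dt = t[i] - t[i - 1]
        F[i] = F[i - 1] + dt * (K1 * A[i - 1] - (k2 + k3) * F[i - 1])
        B[i] = B[i - 1] + dt * k3 * F[i - 1]
    return (1.0 - VB) * (F + B) + VB * A

def tac_conv(t, A, K1, k2, k3, VB):
    """Eq. 2 on a uniform grid starting at t = 0: impulse response convolved with A."""
    dt = t[1] - t[0]
    irf = K1 / (k2 + k3) * (k3 + k2 * np.exp(-(k2 + k3) * t))
    return (1.0 - VB) * np.convolve(irf, A)[: len(t)] * dt + VB * A

t = np.arange(0.0, 20.0, 0.002)                                    # minutes
A = 50.0 * t * np.exp(-t / 0.5) + 2.0 * (1.0 - np.exp(-t / 2.0))   # hypothetical input [Bq/ml]
params = (0.6, 0.9, 0.08, 0.15)                                    # K1, k2, k3, VB
c_ode, c_conv = tac_ode(t, A, *params), tac_conv(t, A, *params)
```

The convolution form needs no sequential loop, which is why it is the more convenient of the two for evaluating many voxels at once.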
2.2 Proposed Pipeline
Our network takes as input a sequence of 2D axial slices and returns a 4-channel output representing the spatial distribution of the KM parameters of a 2TC for FDG metabolization [14]. The network has a depth of four, with long [3] and short skip connections [10]. The kernel size of the max-pooling is [2, 2, 2]. After the last decoder block, two 3D convolutional layers (with kernel sizes [3, 3, 3] and [64, 1, 1]) estimate the KPs per voxel from the output features of the network. Inside the network, the activation function is ELU and, critically, batch normalization is omitted. The network was trained with an initial learning rate of $10^{-4}$, which was halved every 25 epochs, for a maximum of 500 epochs. Following the approach of Küstner et al. for 4D spatio-temporal CINE MRI reconstruction [10], we replaced conventional 3D convolutional layers with (2+1)D spatial and temporal convolutional layers. The spatial convolutional layer is a 3D convolutional layer with kernel size [1, 3, 3] in [t, x, y]; similarly, the temporal convolutional layer has a kernel size of [3, 1, 1]. We required the KPs predicted by the network to satisfy Eq. 2 by including it in the computation of the loss. At the pixel level, we computed the mean squared error between the TAC estimated from the corresponding predicted parameters ($\widetilde{\mathrm{TAC}}_i$) and the measured one ($\mathrm{TAC}_i$), as shown in Fig. 1.
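The self-supervised loss described above can be sketched as follows for a batch of $N$ voxels. The sketch assumes that the network output has already been reshaped to an $(N, 4)$ array of $[K_1, k_2, k_3, V_B]$ and that the TACs and $A(t)$ are sampled on a uniform time grid starting at 0 (the real protocol has non-uniform frames). This NumPy version only computes the forward value; during training the same operations would be written in the deep learning framework so that gradients flow back to the network. All names are hypothetical.

```python
import numpy as np

def predicted_tacs(kp: np.ndarray, A: np.ndarray, dt: float) -> np.ndarray:
    """Model TACs from Eq. 2 for N voxels.

    kp: (N, 4) with columns [K1, k2, k3, VB]; A: (T,) input function on a
    uniform grid starting at t = 0 with spacing dt. Returns an (N, T) array.
    """
    K1, k2, k3, VB = (kp[:, i:i + 1] for i in range(4))          # each (N, 1)
    T = A.shape[0]
    t = np.arange(T) * dt
    irf = K1 / (k2 + k3) * (k3 + k2 * np.exp(-(k2 + k3) * t))    # (N, T)
    # Lower-triangular Toeplitz matrix: (irf @ M)[:, n] = sum_j irf[:, j] * A[n - j]
    lag = np.arange(T)[None, :] - np.arange(T)[:, None]          # lag[j, n] = n - j
    M = np.where(lag >= 0, A[np.clip(lag, 0, None)], 0.0)
    return (1.0 - VB) * (irf @ M) * dt + VB * A[None, :]

def self_supervised_loss(kp, measured, A, dt) -> float:
    """Mean squared error between measured TACs (N, T) and model-generated TACs."""
    return float(np.mean((predicted_tacs(kp, A, dt) - measured) ** 2))

# Usage with synthetic data: two voxels, "measured" TACs generated from known parameters
dt = 0.1
t = np.arange(0.0, 30.0, dt)
A = 50.0 * t * np.exp(-t / 0.5) + 2.0 * (1.0 - np.exp(-t / 2.0))   # hypothetical input
kp_true = np.array([[0.6, 0.9, 0.08, 0.15], [1.0, 0.5, 0.2, 0.05]])
measured = predicted_tacs(kp_true, A, dt)
loss = self_supervised_loss(kp_true * 1.1, measured, A, dt)          # perturbed parameters
```

Writing the convolution as a matrix product lets all voxels of a slice share one input function and be evaluated in a single batched operation, which maps directly onto GPU tensor code.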
We introduced a final activation function to limit the output of the network to the valid parameter domain of the KM function. Using the multi-clamp function, each channel of the logits is restricted to its own parameter range: $K_1 \in [0.01, 2]$, $k_2 \in [0.01, 3]$, $k_3 \in [0.01, 1]$, and $V_B \in [0, 1]$. The limits of the ranges were defined based on the meaning of the parameter (as for $V_B$), mathematical requirements (as for the minimum values of $k_2$ and $k_3$, whose sum cannot be zero) [6], or prior knowledge of the dataset derived from the work of Sari et al. [16] (as for the maximum values of $K_1$, $k_2$, and $k_3$). We evaluated the performance of the network using the Mean Absolute Error (MAE) and the Cosine Similarity (CS) between $\mathrm{TAC}_i$ and $\widetilde{\mathrm{TAC}}_i$.

2.3 Curve Fit
For comparison, parameter optimization via non-linear fitting was implemented in Python using the scipy.optimize.curve_fit function (SciPy version 1.10), with a step equal to 0.001. The bounds were the same as for the DNN.

2.4 Dataset
The dataset is composed of 23 oncological patients with different tumor types. dPET data were acquired on a Biograph Vision Quadra for 65 min, over 62 frames. The frame durations were 2 × 10 s, 30 × 2 s, 4 × 10 s, 8 × 30 s, 4 × 60 s, 5 × 120 s, and 9 × 300 s. The PET volumes were reconstructed with an isotropic voxel size of 1.6 mm. The dataset included the label maps of 7 organs (bones, lungs, heart, liver, kidneys, spleen, aorta) and one image-derived input function A(t) [Bq/ml] from the descending aorta per patient. Further details on the dataset are presented elsewhere [16]. The PET frames and the label maps were resampled to an isotropic voxel size of 2.5 mm. The dataset was then split patient-wise into training, validation, and test sets with 10, 4, and 9 patients, respectively. Details on the dataset split are available in the Supplementary Material (Table 1). The training set consisted of 750 slices and the validation set of 300. In both cases, 75 axial slices per patient were extracted in a predefined patient-specific range from the lungs to the bladder (included) and cropped to 112 × 112 pixels.
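The frame protocol above can be checked, and converted into frame mid-times at which a model TAC could be compared with the measured one, with a few lines of NumPy. Using mid-frame times is an assumption made here for illustration; the text does not state how frame times were assigned.

```python
import numpy as np

# Frame protocol as (number of frames, duration in seconds), in acquisition order
protocol = [(2, 10), (30, 2), (4, 10), (8, 30), (4, 60), (5, 120), (9, 300)]

durations = np.concatenate([np.full(n, d, dtype=float) for n, d in protocol])  # (62,)
ends = np.cumsum(durations)           # frame end times [s]
starts = ends - durations             # frame start times [s]
mid_times_min = (starts + ends) / 2 / 60.0   # frame mid-times [min]
```

This confirms the stated totals (62 frames, 3900 s = 65 min) and makes explicit that the sampling is strongly non-uniform, from 2-s frames early on to 5-min frames at the end.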
3 Results
Table 1 shows the results of the 8 ablation studies we performed to find the best model. We evaluated the impact of the design of the convolutional and max-pooling kernels, as well as the choice of the final activation function. The design of the max-pooling kernel (i.e., kernel size [2, 2, 2] or [1, 2, 2]) had no measurable effect in terms of CS in most of the experiments, with the exception of Exp. 3.2, where max-pooling only in space resulted in a drop of 0.06. In terms of MAE, 3D max-pooling was generally better.
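The two evaluation metrics compare each measured TAC with its model-generated counterpart along the time axis. The sketch below assumes that TACs are stored with time as the last axis. The text does not specify how per-voxel values were aggregated into the mean ± standard deviation reported in Table 1, so aggregation is left to the caller. Function names are hypothetical.

```python
import numpy as np

def cosine_similarity(a, b) -> np.ndarray:
    """Cosine similarity between corresponding TACs along the last (time) axis."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    num = np.sum(a * b, axis=-1)
    den = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
    return num / den   # undefined (nan) for an all-zero TAC

def mean_absolute_error(a, b) -> np.ndarray:
    """MAE between corresponding TACs along the time axis, in the units of the TACs."""
    return np.mean(np.abs(np.asarray(a, dtype=float) - np.asarray(b, dtype=float)), axis=-1)
```

CS is insensitive to a global scaling of the curve and therefore measures agreement in shape, whereas MAE also penalizes amplitude errors; reporting both captures complementary aspects of the fit.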
Table 1. Configurations and metrics of ablation studies for architecture optimization.
| Exp | Convolution | Pooling | Final activation | CS ↑ | MAE ↓ |
|---|---|---|---|---|---|
| 1.1 | 3D | 3D | Absolute | 0.74 ± 0.05 | 3.55 ± 2.12 |
| 1.2 | 3D | space | Absolute | 0.74 ± 0.05 | 3.64 ± 2.21 |
| 2.1 | space + time | 3D | Absolute | 0.75 ± 0.05 | 3.59 ± 2.33 |
| 2.2 | space + time | space | Absolute | 0.75 ± 0.05 | 3.67 ± 2.20 |
| 3.1 | space + time | 3D | Clamp | 0.75 ± 0.05 | 3.48 ± 2.04 |
| 3.2 | space + time | space | Clamp | 0.69 ± 0.05 | 3.55 ± 2.07 |
| 4.1 | space + time | 3D | Multi-clamp | 0.78 ± 0.05 | 3.28 ± 2.03 |
| 4.2 | space + time | space | Multi-clamp | 0.77 ± 0.05 | 3.27 ± 2.01 |
Fig. 2. Comparison between the kinetic parameters obtained with different methods: $KP_{DNN}$ in blue, $KP_{CF}$ in orange, and, as a plausibility check, $KP^{ref}_{CF}$ [16] in green. The exact values are reported in the Supplementary Material (Tables 3 and 4) and in [16]. (Color figure online)
The most important design choice is the selection of the final activation function. Indeed, the multi-clamp final activation function performed best both in terms of CS (Exp. 4.1: CS = 0.78 ± 0.05) and MAE (Exp. 4.2: MAE = 3.27 ± 2.01). Moreover, unlike with the other final activation functions, when the multi-clamp is used the impact of the max-pooling design is negligible in terms of MAE as well. For the rest of the experiments, the selected configuration is the one from Exp. 4.1 (see Table 1). Figure 2 shows the KPs for four selected organs as computed with the proposed DNN ($KP_{DNN}$), with curve fit using only the 9 patients of the test set ($KP_{CF}$), and with curve fit using all 23 patients ($KP^{ref}_{CF}$) [16]. The voxel-wise KPs predicted by the DNN were averaged over the available organ masks. In terms of run-time, the DNN needed ≈1 min to predict the KPs of a whole-body scan (≈400 slices), whereas curve fit took 8.7 min for a single slice; the DNN is therefore expected to be ≈3,500 times faster (400 × 8.7 min ≈ 3,480 min vs. ≈1 min).
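The multi-clamp activation amounts to an element-wise clip with a separate range per output channel. Below is a minimal NumPy sketch using the ranges from Sect. 2.2, assuming the channel order $[K_1, k_2, k_3, V_B]$; the function name and the default channel axis are hypothetical.

```python
import numpy as np

# Per-channel parameter ranges, channel order [K1, k2, k3, VB]
LOWER = np.array([0.01, 0.01, 0.01, 0.0])
UPPER = np.array([2.0, 3.0, 1.0, 1.0])

def multi_clamp(logits: np.ndarray, channel_axis: int = 0) -> np.ndarray:
    """Clip each output channel of `logits` to its own [lower, upper] range."""
    shape = [1] * logits.ndim
    shape[channel_axis] = -1            # broadcast the bounds along the channel axis
    return np.clip(logits, LOWER.reshape(shape), UPPER.reshape(shape))

# Usage: 4 channels x 2 voxels
x = np.array([[-1.0, 5.0], [0.5, 4.0], [2.0, -3.0], [0.3, 1.5]])
y = multi_clamp(x)
```

One design consequence worth noting: clipping has zero gradient outside the range, so a voxel whose logit falls outside its bounds receives no gradient for that parameter until the prediction moves back inside.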
4 Discussion
Even though the choice of the final activation function has a greater impact, the selection of the kernel design is also important. Using spatial and temporal convolution results in an increase in performance (+0.01 in CS) and reduces the number of trainable parameters (from 2.1 M to 8.6 K), as pointed out by [10]. Therefore, convergence is reached faster. Moreover, the use of two separate kernels in time and space is especially meaningful here: pixel counts for a given exposure are affected by the neighboring count measurements due to the limited resolution of the PET scanner [20], whereas the counts in temporally preceding or following frames are independent.

In general, there is good agreement between $KP_{DNN}$, $KP_{CF}$, and $KP^{ref}_{CF}$. The DNN predictions of $K_1$ and $k_2$ in the spleen and of $k_3$ in the lungs lie outside the confidence interval of the results published by Sari et al. [16]. An analysis of the metrics per slice shows that the CS between $\mathrm{TAC}_i$ and $\widetilde{\mathrm{TAC}}_i$ changes substantially depending on the region: $CS_{max} = 0.87$ within the liver boundaries and $CS_{min} = 0.71$ in the region corresponding to the heart and lungs (see Fig. 3a). This can be explained by the fact that $V_B$ is underestimated for the heart and aorta. The proposed network predicts $V_B^{heart} = 0.376 \pm 0.133$ and $V_B^{aorta} = 0.622 \pm 0.238$, while values of nearly 1 are to be expected. This is likely due to breathing and heartbeat motion artifacts, which cannot be modeled properly with a 2TC KM that assumes no motion between frames.

Figure 3b–e shows the central coronal slice of the four KPIs in an exemplary patient. As expected, $K_1$ is high in the heart, liver, and kidney. Similarly, the blood fraction volume $V_B$ is higher in the heart, blood vessels, and lungs. $KP_{DNN}$ are more homogeneous than $KP_{CF}$, as can be seen in the exemplary $K_1$ axial slice shown in Fig. 4. A quantitative evaluation of the smoothness of the images is reported in the Supplementary Material (Fig. 1). Moreover, the distribution in the liver is more realistic in $KP_{DNN}$, where the gallbladder can be seen as an ellipsoid between the right and left liver lobes.

Fig. 3. (a) Cosine similarity (CS) per slice in patient 23 (blue: lungs; red: lungs and heart; green: liver). (b–e) Parametric images of a coronal slice for the same patient. (Color figure online)
High $K_1$ regions are mainly within the liver, spleen, and kidney for $KP_{DNN}$, while they also appear in unexpected areas in $KP_{CF}$ (e.g., next to the spine or in the region of the stomach).

The major limitation of this work is the lack of ground truth and of a canonical method to quantitatively evaluate its performance. This limitation is inherent to PBPK modeling and results in the need for qualitative analyses based on expected physiological processes. A possible way to mitigate this would be to work on simulated data, yet the validity of such evaluations strongly depends on how realistic the underlying simulation models are. As seen in Fig. 3a, motion (gross, respiratory, or cardiac) has a major impact on the estimation quality. Registering different dPET frames has been shown to improve conventional PBPK models [8] and could also benefit our approach.

Fig. 4. Comparison of K1 parametric images for an axial slice of patient 2, with contours of the liver (left), the spleen (center), and the left kidney (right).
5 Conclusion
In this work, inspired by PINNs, we combine a self-supervised spatio-temporal DNN with a new loss formulation considering physiology to perform kinetic modeling of FDG dPET. We compare the best DNN model with the most commonly used conventional PBPK method, curve fit. While no ground truth is available, the proposed method provides results similar to curve fit but qualitatively more physiologically plausible images, with a radically shorter run-time. Further, our approach can be applied to other KMs without significantly increasing the complexity or the need for computational power. In general, Eq. 2 should be modified to represent the desired KM [20], and the number of output channels of the network should equal the number of KPs to be predicted. Overall, this work offers scalability and a new research direction for analysing pharmacokinetics.

Acknowledgements. This work was partially funded by the German Research Foundation (DFG, grant NA 620/51-1).
Spatio-Temporal Self-supervised Network for PBPK Modeling
Geometry-Invariant Abnormality Detection

Ashay Patel(B), Petru-Daniel Tudosiu, Walter Hugo Lopez Pinaya, Olusola Adeleke, Gary Cook, Vicky Goh, Sebastien Ourselin, and M. Jorge Cardoso

King's College London, London WC2R 2LS, UK

Abstract. Cancer is a highly heterogeneous condition that is best visualised in positron emission tomography (PET). Because of this heterogeneity, a general-purpose cancer detection model can be built using unsupervised anomaly detection models. Prior work in this field has showcased the efficacy of abnormality detection methods (e.g., Transformer-based ones), but these methods have shown significant vulnerabilities to differences in data geometry. Changes in image resolution or observed field of view can result in inaccurate predictions, even with significant data pre-processing and augmentation. We propose a new spatial conditioning mechanism that enables models to adapt to and learn from varying data geometries, and we apply it to a state-of-the-art Vector-Quantized Variational Autoencoder + Transformer abnormality detection model. We show that this spatial conditioning mechanism statistically significantly improves model performance on whole-body data compared with the same model without conditioning, while allowing the model to perform inference at varying data geometries.
1 Introduction
The use of machine learning for anomaly detection in medical image analysis has gained considerable traction in recent years. Most recent approaches have focused on improving performance rather than flexibility, which limits them to specific input types; little research has been carried out on models that are unhindered by variations in data geometry. Research often assumes similar data acquisition parameters, from image dimensions to voxel dimensions and fields of view (FOV), and these restrictions are then carried forward during inference [5,25]. This strong assumption is often difficult to maintain in the real world. Although image pre-processing can mitigate some of this difficulty, test error often increases substantially as new data variations arise. These variations include differences in scanner quality and resolution, as well as in the FOV selected during patient scans. Training data, especially when acquired from different sources, usually undergoes significant pre-processing so that all data share the same FOV and the same input dimensions, e.g., by registering the data to a population atlas. Whilst this makes model design simpler, such pre-processing can result in poor generalisation and adds significant pre-processing time [11,13,26]. Given this, building an anomaly detection model that works on inputs with varying resolution, dimensions and FOV is an important problem and the main focus of this research.

Unsupervised methods have become increasingly prominent for automatic anomaly detection because they eliminate the need to acquire accurately labelled data [4,7], thereby relaxing the stringent data requirements of medical imaging. This approach consists of training generative models on healthy data and, during inference, defining anomalies as deviations from the learned model of normality. Until recently, the variational autoencoder (VAE) and its variants held the state of the art for the unsupervised approach. However, novel unsupervised anomaly detectors based on autoregressive Transformers coupled with Vector-Quantized Variational Autoencoders (VQ-VAE) have overcome issues associated with autoencoder-only methods [21,22]. In [22], the authors explore the advantage of tractably maximizing the likelihood of normal data to model long-range dependencies in the training data. The work in [21] takes this method a step further by sampling multiple times from the Transformer to generate a non-parametric Kernel Density Estimation (KDE) anomaly map. Even though these methods are state of the art, they have stringent data requirements, such as a consistent geometry of the input data. For example, in a whole-body imaging scenario it is not possible to crop a region of interest and feed it to the algorithm, because the cropped region would be wrongly detected as an anomaly. This would happen even if a scan's original FOV was restricted [17].

Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_29.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 300–309, 2023. https://doi.org/10.1007/978-3-031-43907-0_29
We therefore propose a geometry-invariant approach to anomaly detection and apply it to cancer detection in whole-body PET through an unsupervised anomaly detection method with minimal spatial labelling. By adapting the VQ-VAE + Transformer approach in [21] and adding spatial conditioning to both the VQ-VAE and the Transformer, we show that our model can be trained on data with varying fields of view, orientations and resolutions. Furthermore, we show that in all testing scenarios the performance of our model with spatial conditioning is at least equivalent to, and sometimes better than, that of a model trained on whole-body data, with the added flexibility of a "one model fits all data" approach. We greatly reduce the pre-processing required to build a model (as visualised in Fig. 1), demonstrating the potential of our model in more flexible environments without compromising performance.
2 Background
The main building blocks behind the proposed method are introduced below. Specifically, a VQ-VAE plus a Transformer are jointly used to learn the probability density function of 3D PET images as explored in prior research [21,22,24].
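The discrete interface between the two building blocks can be illustrated with a short NumPy sketch. Each latent vector of the encoder output Z is replaced by the index of its nearest codebook vector, as described in Sect. 2.1, and the resulting 3D index grid is flattened into a 1D token sequence by a raster scan, as described in Sect. 2.2. The sketch uses the PET codebook size K = 256 and vector length n_z = 128. The random latent and codebook, the helper names `quantize` and `raster_scan`, and the use of NumPy's row-major flattening as the raster order are illustrative assumptions, not the authors' implementation. Indices are 0-based here, whereas the text numbers codebook elements from 1 to K.

```python
import numpy as np

def quantize(z, codebook):
    """Replace each n_z-dim vector of z (h, w, d, n_z) by its nearest codebook entry.

    Returns Z_q (same shape as z) and the index grid Z_iq (h, w, d).
    """
    flat = z.reshape(-1, z.shape[-1])                      # (h*w*d, n_z)
    # squared Euclidean distances ||z - e_k||^2 = ||z||^2 - 2 z.e_k + ||e_k||^2
    d2 = (
        (flat ** 2).sum(axis=1, keepdims=True)
        - 2.0 * flat @ codebook.T
        + (codebook ** 2).sum(axis=1)[None, :]
    )                                                      # (h*w*d, K)
    idx = d2.argmin(axis=1)                                # nearest code per position
    z_q = codebook[idx].reshape(z.shape)
    return z_q, idx.reshape(z.shape[:-1])

def raster_scan(index_grid):
    """Flatten the (h, w, d) index grid Z_iq into a 1D token sequence s (row-major)."""
    return index_grid.reshape(-1)

rng = np.random.default_rng(0)
K, n_z = 256, 128                       # PET codebook size and vector length
codebook = rng.normal(size=(K, n_z))    # stand-in for a trained codebook
z = rng.normal(size=(4, 4, 4, n_z))     # toy encoder output, h = w = d = 4
z_q, z_iq = quantize(z, codebook)
s = raster_scan(z_iq)                   # 64 tokens with values in {0, ..., K - 1}
```

The Transformer is then trained on sequences such as `s`, so the spatial layout of the latent grid is only implicit in each token's position in the sequence.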
A. Patel et al.
Fig. 1. Flowchart showcasing traditional data pipelines for developing machine learning models in medical imaging (top) vs. the reduced pipeline for our approach (bottom)
2.1 Vector-Quantized Variational Autoencoder
The VQ-VAE model provides a data-efficient encoding mechanism, enabling 3D inputs at their original resolution, while generating a discrete latent representation that can readily be learned by a Transformer network [20]. The VQ-VAE is composed of an encoder that maps an image $X \in \mathbb{R}^{H\times W\times D}$ onto a compressed latent representation $Z \in \mathbb{R}^{h\times w\times d\times n_z}$, where $n_z$ is the latent embedding vector dimension. $Z$ is then passed through a quantization block in which each spatial code $Z_{ijl} \in \mathbb{R}^{n_z}$ is replaced by its nearest codebook element $e_k \in \mathbb{R}^{n_z}$, $k \in \{1, \dots, K\}$, where $K$ denotes the codebook vocabulary size, thus obtaining $Z_q$. Given $Z_q$, the VQ-VAE decoder then reconstructs the observations $\hat{X} \in \mathbb{R}^{H\times W\times D}$. The VQ-VAE encoder consists of three downsampling layers, each containing a convolution with stride 2 and kernel size 4 followed by a ReLU activation and 3 residual blocks. Each residual block consists of a convolution with kernel size 3, followed by a ReLU activation, a convolution with kernel size 1 and another ReLU activation. Like the encoder, the decoder has 3 layers of 3 residual blocks, each followed by a transposed convolutional layer with stride 2 and kernel size 4. Finally, a Dropout layer with probability 0.05 is added before the last transposed convolutional layer. The VQ-VAE codebook has 256 atomic elements (vocabulary size), each of length 128. The CT VQ-VAE uses identical hyperparameters, except that each codebook vector has length 64. See Appendix A for implementation details.

2.2 Transformer
After training a VQ-VAE model, the next stage is to learn the probability density function of the discrete latent representations. Using the VQ-VAE, we obtain a discrete representation of the latent space by replacing the codebook elements in $Z_q$ with their respective indices in the codebook, yielding $Z_{iq}$. To model the imaging data, the discretized latent space $Z_{iq}$ must take the form of a 1D sequence $s$, which we obtain via a raster scan of the latent. The Transformer is then trained to maximize the log-likelihood of the latent token sequence in an autoregressive manner. In this way, the Transformer learns the codebook distribution for position $i$ within $s$ with respect to the previous codes:

$$p(s_i) = p(s_i \mid s_{<i})$$

$$E_s(v) = \sum_{i=1}^{L}\sum_{j=1,\,j\neq i}^{L} e_s\left(\lVert v_i - v_j\rVert\right) = \begin{cases} \sum_{i=1}^{L}\sum_{j=1,\,j\neq i}^{L} \lVert v_i - v_j\rVert^{-s}, & s > 0,\\ \sum_{i=1}^{L}\sum_{j=1,\,j\neq i}^{L} \log\left(\lVert v_i - v_j\rVert^{-1}\right), & s = 0, \end{cases} \qquad (3)$$
Here $v$ is the sampled feature of each image, with point-wise embeddings $v_i$, and $L$ is the length of the feature, i.e., the number of sampled voxels. We sample voxels randomly from the whole case so that the features better express the overall representational power of the model. The lower the hyperspherical energy (HSE), the more widely the feature vectors are dispersed on the unit hypersphere [3]. For a dataset with $N$ cases, we choose $s = 1$, and the feature variety $F_v$ is formulated as

$$F_v = \frac{1}{N}\sum_{i=1}^{N} E_s(v)^{-1}, \qquad (4)$$

where the $i$-th term uses the feature $v$ sampled from the $i$-th case.

Overall Estimation. For semantic segmentation problems, the feature pyramid structure is critical to segmentation results [14,29]. Hence, in our framework, the outputs of different decoder layers are upsampled to the output size and used in the sliding-window sampling process. In addition, we decrease the sampling ratio in the decoder layers close to the bottleneck to avoid feature redundancy. The final transferability of pre-trained model $m$ to dataset $t$, $T_{m\to t}$, is

$$T_{m\to t} = \frac{1}{D}\sum_{i=1}^{D} \log\frac{F_v^i}{C_{cons}^i}, \qquad (5)$$

where $D$ is the number of decoder layers used in the estimation.
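Eqs. (3) to (5) can be sketched in NumPy as follows. The sketch assumes that the sampled voxel features are L2-normalized onto the unit hypersphere. The HSE interpretation above suggests this, but it is an assumption here. The sketch also takes the per-layer class-consistency values $C_{cons}^i$ as given inputs, because their computation is not reproduced in this section. The function names are hypothetical.

```python
import numpy as np

def hyperspherical_energy(v, s=1.0):
    """HSE E_s(v) of Eq. (3) for an (L, c) array whose rows are the embeddings v_i.

    Rows are projected onto the unit hypersphere first (an assumption); duplicate
    rows give a zero distance and hence an infinite energy.
    """
    v = v / np.linalg.norm(v, axis=1, keepdims=True)
    dist = np.linalg.norm(v[:, None, :] - v[None, :, :], axis=-1)  # (L, L) pairwise
    d = dist[~np.eye(len(v), dtype=bool)]                          # all pairs with j != i
    if s > 0:
        return float(np.sum(d ** (-s)))
    return float(np.sum(np.log(1.0 / d)))                          # s = 0 case

def feature_variety(features_per_case, s=1.0):
    """F_v of Eq. (4): mean inverse HSE over the N cases of a dataset."""
    return float(np.mean([1.0 / hyperspherical_energy(v, s) for v in features_per_case]))

def transferability(fv_per_layer, ccons_per_layer):
    """T_{m->t} of Eq. (5): mean over D decoder layers of log(F_v^i / C_cons^i)."""
    fv = np.asarray(fv_per_layer, dtype=float)
    cc = np.asarray(ccons_per_layer, dtype=float)
    return float(np.mean(np.log(fv / cc)))

# Two antipodal unit vectors: each ordered pair is at distance 2, so E_1 = 0.5 + 0.5.
spread = np.array([[1.0, 0.0], [-1.0, 0.0]])
clustered = np.array([[1.0, 0.0], [0.99, 0.14]])
print(hyperspherical_energy(spread), hyperspherical_energy(clustered))
```

The pairwise-distance tensor grows as L² times c, so a real implementation over many sampled voxels would compute the distances in chunks.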
Pick the Best Pre-trained Model
3 Experiment

3.1 Experiment on MSD Dataset
The Medical Segmentation Decathlon (MSD) [2] dataset is composed of ten datasets with various challenging characteristics and is widely used in the medical image analysis field. To evaluate the effectiveness of CC-FV, we conduct extensive experiments on five of the MSD tasks: Task03 Liver (liver and tumor segmentation), Task06 Lung (lung nodule segmentation), Task07 Pancreas (pancreas and pancreas tumor segmentation), Task09 Spleen (spleen segmentation) and Task10 Colon (colon cancer segmentation). All of these datasets consist of 3D CT images. We use the public part of the MSD dataset and split each dataset into a training set and a test set at a ratio of 80% to 20%. For each target dataset, we pre-train models on the other four datasets and fine-tune them on the target dataset. We then evaluate the transferability estimates through the correlation between the two resulting rankings of the upstream pre-trained models. We load all parameters of each pre-trained model except those of the last convolutional layer, and no parameters are frozen during fine-tuning. In addition, we follow the self-configuring method of nnUNet [11] to choose the pre-processing, training and post-processing strategies. For fair comparison, we also implement the baseline methods TransRate [9], LogME [27], GBC [17] and LEEP [15]. For these methods, we use the output of the layer before the final convolution as the feature map and sample it with the same sliding window as CC-FV to obtain the features of the different classes used in the calculation. Figure 2 visualizes the average Dice score and the estimated value on Task03 Liver. The TE results are obtained from the training set only. U-Net [20] and UNETR [8] are used in the experiment. Each model is pre-trained for 250k iterations and fine-tuned for 100k iterations with batch size 2 on a single NVIDIA A100 GPU.
We use the model at the end of training for inference and calculate the final DSC on the test set. We use the weighted Kendall's τ [27] and the Pearson correlation coefficient to measure the correlation between the TE results and the fine-tuning performance. Kendall's τ ranges over [−1, 1], and τ = 1 means that the rankings by TE result and by performance are perfectly correlated ($T^i_{s\to t} > T^j_{s\to t}$ if and only if $P^i_{s\to t} > P^j_{s\to t}$). Model selection generally picks the top models and ignores the poor performers, so we assign higher weights to the good models in the calculation; this is known as the weighted Kendall's τ. The Pearson coefficient also ranges over [−1, 1] and measures how well the data can be described by a linear relationship. The higher the Pearson coefficient, the stronger the positive correlation between the variables. Figure 2 shows that the TE results of our method correlate more positively with DSC performance. Table 1 shows that our method achieves the highest average weighted Kendall's τ and the highest average Pearson coefficient among all compared methods. Most existing methods are inferior to ours because they are not designed for segmentation tasks with a severe class imbalance problem. In addition, these methods rely only on single-layer features and do not make good use of the hierarchical structure of the model.
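The two evaluation criteria can be sketched in NumPy. The pair-weighting scheme is an assumption here. Each pair of models gets weight $1/(r_i+1) + 1/(r_j+1)$, where $r$ is a model's 0-based rank by fine-tuning performance. This is one hyperbolic weighting in the spirit of the weighted Kendall's τ, not necessarily the exact variant used by the authors. SciPy's `scipy.stats.weightedtau` is an off-the-shelf alternative, but its default settings may average over the rankings induced by both inputs and can therefore give different values. The sketch assumes no ties, and the scores below are hypothetical.

```python
import numpy as np

def pearson(x, y):
    """Pearson correlation coefficient between two 1-D sequences."""
    return float(np.corrcoef(np.asarray(x, float), np.asarray(y, float))[0, 1])

def weighted_kendall_tau(te, perf):
    """Kendall's tau with hyperbolic pair weights; assumes no ties.

    Pair (i, j) gets weight 1/(r_i + 1) + 1/(r_j + 1), where r is the 0-based rank
    of a model by fine-tuning performance (0 = best), so ranking mistakes among the
    top models cost more than mistakes among the poor performers.
    """
    te = np.asarray(te, float)
    perf = np.asarray(perf, float)
    n = len(perf)
    rank = np.empty(n, dtype=int)
    rank[np.argsort(-perf)] = np.arange(n)        # best model gets rank 0
    num = den = 0.0
    for i in range(n):
        for j in range(i + 1, n):
            w = 1.0 / (rank[i] + 1) + 1.0 / (rank[j] + 1)
            # +1 if the TE ranking agrees with the performance ranking on this pair
            concordance = np.sign(te[i] - te[j]) * np.sign(perf[i] - perf[j])
            num += w * concordance
            den += w
    return float(num / den)

# Hypothetical TE scores and fine-tuned Dice for four pre-trained models.
perf = [0.70, 0.75, 0.60, 0.80]
top_swap = [0.2, 0.9, 0.1, 0.5]       # misorders the two best models
bottom_swap = [0.1, 0.5, 0.2, 0.9]    # misorders the two worst models
print(weighted_kendall_tau(top_swap, perf), weighted_kendall_tau(bottom_swap, perf))
```

Each swap leaves five of the six pairs concordant, so the unweighted τ is 4/6 for both. The weighted version scores the top-pair mistake lower (about 0.52 vs. about 0.81), which matches the motivation given above.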
Y. Yang et al.
Fig. 2. Correlation between the fine-tuning performance and the transferability metrics, using Task03 as an example. The vertical axis represents the average Dice of the model, and the horizontal axis represents the transferability metric results. All metrics have been standardized uniformly, so that higher performance is expected to correspond to higher transferability estimates.

Table 1. Pearson coefficient and weighted Kendall's τ for transferability estimation

Data/Method  Metric      Task03   Task06   Task07   Task09   Task10      Avg
LogME        τ          –0.1628  –0.0988   0.3280   0.2778  –0.2348   0.0218
             Pearson     0.0412   0.5713   0.3236   0.2725  –0.1674   0.2082
TransRate    τ          –0.1843  –0.1028   0.5923   0.4322   0.6069   0.2688
             Pearson    –0.5178  –0.2804   0.7170   0.5573   0.7629   0.2478
LEEP         τ           0.6008   0.1658   0.2691   0.3516   0.5841   0.3943
             Pearson     0.6765  –0.0073   0.7146   0.1633   0.4979   0.4090
GBC          τ           0.1233  –0.1569   0.6637   0.7611   0.6643   0.4111
             Pearson    –0.2634  –0.3733   0.7948   0.7604   0.7404   0.3317
Ours CC-FV   τ           0.6374   0.0735   0.6569   0.5700   0.5550   0.4986
             Pearson     0.8608   0.0903   0.9609   0.7491   0.8406   0.7003
Fig. 3. Visualization of features with the same labels using t-SNE. Points with different colors come from different samples. Pre-trained models tend to have a more consistent within-class distribution than the randomly initialized model, and after fine-tuning they often achieve better Dice performance than randomly initialized models.

Table 2. Ablation on the effectiveness of different parts of our method (weighted Kendall's τ)

Data/Method         Task03   Task06   Task07   Task09   Task10      Avg
Ours CC-FV          0.6374   0.0735   0.6569   0.5700   0.5550   0.4986
Ours w/o C_cons     0.1871  –0.2210  –0.2810  –0.0289  –0.2710  –0.1230
Ours w/o F_v        0.6165   0.3235   0.6054   0.2761   0.5269   0.4697
Single-scale        0.4394   0.0252   0.5336   0.5759   0.6007   0.4341
KL-divergence      –0.5658  –0.0564   0.2319   0.4628  –0.0323   0.0080
Bha-distance        0.1808   0.7866   0.4650   0.0723   0.2295   0.3468

3.2 Ablation Study
In Table 2 we analyze the different parts of our method and compare them with alternative choices. First, we analyze the impact of the class consistency $C_{cons}$ and the feature variety $F_v$. Although $F_v$ cannot contribute to the final Kendall's τ on its own, $C_{cons}$ constrained by $F_v$ improves the overall estimation. Then we compare the performance of our method at a single scale and at multiple scales to demonstrate the effectiveness of our multi-scale strategy. Finally, we change the distance metric used in the class consistency estimation. The KL-divergence and the Bhattacharyya distance are unstable when computed on high-dimensional matrices, and their performance is also inferior to that of the Wasserstein distance. Figure 3 visualizes the distribution of different classes using t-SNE. Models that undergo pre-training show a more compact intra-class distribution and higher fine-tuning performance.
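The ablation attributes the weakness of the KL-divergence and the Bhattacharyya distance to unstable high-dimensional matrix computations. One common way to compare two class-conditional feature distributions with a Wasserstein distance is the closed-form 2-Wasserstein distance between Gaussian fits. This form needs only matrix square roots of the covariances. The Gaussian KL-divergence, by contrast, requires inverting a covariance matrix, which is singular whenever the feature dimension is at least the number of samples. The sketch below assumes this Gaussian approximation. It is not necessarily how $C_{cons}$ is computed in this work, and the function names are hypothetical.

```python
import numpy as np

def _sqrtm_psd(m):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(m)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T

def gaussian_w2(feat_a, feat_b):
    """2-Wasserstein distance between Gaussian fits of two (n, c) feature sets.

    W2^2 = ||mu_a - mu_b||^2 + tr(S_a + S_b - 2 (S_b^{1/2} S_a S_b^{1/2})^{1/2})
    No matrix inverse is needed, so singular covariances (c >= n) are fine.
    """
    mu_a, mu_b = feat_a.mean(axis=0), feat_b.mean(axis=0)
    cov_a = np.cov(feat_a, rowvar=False)
    cov_b = np.cov(feat_b, rowvar=False)
    root_b = _sqrtm_psd(cov_b)
    inner = root_b @ cov_a @ root_b
    cross = _sqrtm_psd((inner + inner.T) / 2.0)       # symmetrize round-off first
    w2_sq = np.sum((mu_a - mu_b) ** 2) + np.trace(cov_a + cov_b - 2.0 * cross)
    return float(np.sqrt(max(w2_sq, 0.0)))            # clip tiny negative round-off

rng = np.random.default_rng(0)
a = rng.normal(size=(200, 3))
shift = np.array([3.0, 4.0, 0.0])
print(gaussian_w2(a, a + shift))                      # approximately 5.0 (pure mean shift)
print(gaussian_w2(rng.normal(size=(10, 50)), rng.normal(size=(10, 50))))  # finite although c > n
```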
4 Conclusion

In this work, we raise the problem of model selection for upstream and downstream transfer in medical image segmentation and analyze its practical implications. In addition, given the ethical and privacy issues inherent in medical care and the computational load of 3D image segmentation, we design a generic framework for the task and propose a transferability estimation method based on class consistency with a feature variety constraint. Extensive experiments show that this method outperforms existing transferability estimation methods.

Acknowledgement. This work was supported by the Open Funding of Zhejiang Laboratory under Grant 2021KH0AB03, NSFC China (No. 62003208); Committee of Science and Technology, Shanghai, China (No. 19510711200); Shanghai Sailing Program (20YF1420800), and Shanghai Municipal of Science and Technology Project (Grant No. 20JC1419500, No. 20DZ2220400).
References
1. Agostinelli, A., Uijlings, J., Mensink, T., Ferrari, V.: Transferability metrics for selecting source model ensembles. In: CVPR, pp. 7936–7946 (2022)
2. Antonelli, M., et al.: The medical segmentation decathlon. Nat. Commun. 13(1), 4128 (2022)
3. Chen, W., et al.: Contrastive syn-to-real generalization. arXiv preprint arXiv:2104.02290 (2021)
4. Chen, X., Wang, S., Fu, B., Long, M., Wang, J.: Catastrophic forgetting meets negative transfer: batch spectral shrinkage for safe transfer learning. In: NIPS 32 (2019)
5. Cui, Q., et al.: Discriminability-transferability trade-off: an information-theoretic perspective. In: ECCV, pp. 20–37. Springer (2022). https://doi.org/10.1007/978-3-031-19809-0_2
6. Dwivedi, K., Huang, J., Cichy, R.M., Roig, G.: Duality diagram similarity: a generic framework for initialization selection in task transfer learning. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12371, pp. 497–513. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58574-7_30
7. Dwivedi, K., Roig, G.: Representation similarity analysis for efficient task taxonomy & transfer learning. In: CVPR, pp. 12387–12396 (2019)
8. Hatamizadeh, A., et al.: UNETR: transformers for 3D medical image segmentation. In: CVPR, pp. 574–584 (2022)
9. Huang, L.K., Huang, J., Rong, Y., Yang, Q., Wei, Y.: Frustratingly easy transferability estimation. In: ICML, pp. 9201–9225. PMLR (2022)
10. Irvin, J., et al.: CheXpert: a large chest radiograph dataset with uncertainty labels and expert comparison. In: AAAI, vol. 33, pp. 590–597 (2019)
11. Isensee, F., Jaeger, P.F., Kohl, S.A., Petersen, J., Maier-Hein, K.H.: nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nat. Methods 18(2), 203–211 (2021)
12. Li, X., et al.: DELTA: deep learning transfer using feature map with attention for convolutional networks. arXiv preprint arXiv:1901.09229 (2019)
13. Li, Y., et al.: Ranking neural checkpoints. In: CVPR, pp. 2663–2673 (2021)
14. Lin, T.Y., Dollár, P., Girshick, R., He, K., Hariharan, B., Belongie, S.: Feature pyramid networks for object detection. In: CVPR, pp. 2117–2125 (2017)
15. Nguyen, C., Hassner, T., Seeger, M., Archambeau, C.: LEEP: a new measure to evaluate transferability of learned representations. In: ICML, pp. 7294–7305. PMLR (2020)
16. Panaretos, V.M., Zemel, Y.: Statistical aspects of Wasserstein distances. Annu. Rev. Stat. Appl. 6, 405–431 (2019)
17. Pándy, M., Agostinelli, A., Uijlings, J., Ferrari, V., Mensink, T.: Transferability estimation using Bhattacharyya class separability. In: CVPR, pp. 9172–9182 (2022)
18. Paszke, A., et al.: PyTorch: an imperative style, high-performance deep learning library. In: Advances in Neural Information Processing Systems 32 (2019)
19. Reiss, T., Cohen, N., Bergman, L., Hoshen, Y.: PANDA: adapting pretrained features for anomaly detection and segmentation. In: CVPR, pp. 2806–2814 (2021)
20. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
21. Tajbakhsh, N., et al.: Convolutional neural networks for medical image analysis: full training or fine tuning? IEEE Trans. Med. Imaging 35(5), 1299–1312 (2016)
22. Tong, X., Xu, X., Huang, S.L., Zheng, L.: A mathematical framework for quantifying transferability in multi-source transfer learning. NIPS 34, 26103–26116 (2021)
23. Wang, T., Isola, P.: Understanding contrastive representation learning through alignment and uniformity on the hypersphere. In: ICML, pp. 9929–9939. PMLR (2020)
24. Wang, Z., Dai, Z., Póczos, B., Carbonell, J.: Characterizing and avoiding negative transfer. In: CVPR, pp. 11293–11302 (2019)
25. Wolf, T., et al.: Transformers: state-of-the-art natural language processing. In: EMNLP, pp. 38–45 (2020)
26. Xuhong, L., Grandvalet, Y., Davoine, F.: Explicit inductive bias for transfer learning with convolutional networks. In: ICML, pp. 2825–2834. PMLR (2018)
27. You, K., Liu, Y., Wang, J., Long, M.: LogME: practical assessment of pre-trained models for transfer learning. In: ICML, pp. 12133–12143. PMLR (2021)
28. Zamir, A.R., Sax, A., Shen, W., Guibas, L.J., Malik, J., Savarese, S.: Taskonomy: disentangling task transfer learning. In: CVPR, pp. 3712–3722 (2018)
29. Zhao, H., Shi, J., Qi, X., Wang, X., Jia, J.: Pyramid scene parsing network. In: CVPR, pp. 2881–2890 (2017)
30. Zhou, Z., Shin, J.Y., Gurudu, S.R., Gotway, M.B., Liang, J.: Active, continual fine tuning of convolutional neural networks for reducing annotation efforts. Med. Image Anal. 71, 101997 (2021)
Source-Free Domain Adaptive Fundus Image Segmentation with Class-Balanced Mean Teacher

Longxiang Tang1, Kai Li2(B), Chunming He1, Yulun Zhang3, and Xiu Li1(B)

1 Tsinghua Shenzhen International Graduate School, Tsinghua University, Shenzhen, China
{lloong.x,chunminghe19990224}@gmail.com
2 NEC Laboratories America, Princeton, USA
3 ETH Zurich, Zürich, Switzerland
Abstract. This paper studies source-free domain adaptive fundus image segmentation, which aims to adapt a pretrained fundus segmentation model to a target domain using only unlabeled images. The task is challenging because adapting a model with unlabeled data alone is highly risky. Most existing methods tackle it mainly by designing techniques that carefully generate pseudo labels from the model's predictions and then use these pseudo labels to train the model. While these methods often achieve positive adaptation effects, they suffer from two major issues. First, they tend to be fairly unstable: incorrect pseudo labels that emerge abruptly may have a catastrophic impact on the model. Second, they fail to consider the severe class imbalance of fundus images, where the foreground (e.g., cup) region is usually very small. This paper addresses these two issues by proposing the Class-Balanced Mean Teacher (CBMT) model. CBMT addresses the instability issue with a weak-strong augmented mean teacher learning scheme. In this scheme, only the teacher model generates pseudo labels, from weakly augmented images, to train a student model that takes strongly augmented images as input. The teacher is updated as the moving average of the student, whose instantaneous weights could be noisy. This prevents the teacher model from being abruptly impacted by incorrect pseudo labels. For the class imbalance issue, CBMT proposes a novel loss calibration approach that highlights foreground classes according to global statistics. Experiments show that CBMT effectively addresses these two issues and outperforms existing methods on multiple benchmarks.
Keywords: Source-free domain adaptation · Fundus image · Mean teacher
Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_65.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 684–694, 2023. https://doi.org/10.1007/978-3-031-43907-0_65
Class-Balanced Mean Teacher
1 Introduction
Medical image segmentation plays an essential role in computer-aided diagnosis systems across many applications and has advanced tremendously in the past few years [6,12,19,22]. Because segmentation models [3,10,11,21] typically require sufficient labeled data, unsupervised domain adaptation (UDA) approaches have been proposed. These approaches learn an adaptive model jointly from unlabeled target-domain images and labeled source-domain images [9], for example via the adversarial training paradigm [2,8,13,14,16]. Although they have achieved impressive performance, these UDA methods may be of limited use in real-world medical image segmentation tasks where labeled source images are not available for adaptation. This scenario is not rare: medical images are usually subject to strict privacy and copyright protection, so labeled source images may not be allowed to be distributed. This motivates the investigation of source-free domain adaptation (SFDA), which adapts a source segmentation model trained on labeled source data (in a privacy-protected way) to the target domain using only unlabeled data. A few SFDA works have been proposed recently. OSUDA [17] utilizes the domain-specific low-order batch statistics and the domain-shareable high-order batch statistics, adapting the former while keeping the latter consistent. SRDA [1] minimizes a label-free entropy loss guided by a domain-invariant class-ratio prior. DPL [4] introduces pixel-level and class-level pseudo-label denoising schemes to reduce noisy pseudo labels and select reliable ones. U-D4R [27] applies an adaptive class-dependent threshold with uncertainty-rectified correction to achieve better denoising. Although these methods have achieved some success in model adaptation, they still suffer from two major issues. First, they tend to be fairly unstable.
Without any supervision signal from labeled data, the model relies heavily on its own predictions. These predictions are often noisy and can easily destabilize training, causing catastrophic error accumulation after several training epochs, as shown in Fig. 1(a). Some works avoid this problem by training the model for only a very limited number of iterations (only 2 epochs in [4,27]) and selecting the best-performing model over the whole training process for testing. However, this does not fully utilize the data, and selecting the best-performing model is non-trivial in this unsupervised learning task. Second, these methods fail to consider the severe foreground-background imbalance of fundus images, where the foreground (e.g., cup) region is usually very small (as shown in Fig. 1(b)). This oversight can also degrade the model, because the background dominates the learning signal. In this paper, we propose the Class-Balanced Mean Teacher (CBMT) method to address these limitations of existing methods. To mitigate the negative impact of incorrect pseudo labels, we propose a weak-strong augmented mean teacher learning scheme that involves a teacher model and a student model, both initialized from the source model. We use the teacher to generate pseudo labels from a weakly augmented image, and we train the student on a strongly augmented version of the same image. We do not train the teacher model directly by back-propagation; instead, we update its weights as the moving average of the student model. This prevents the teacher model from being abruptly impacted by incorrect pseudo labels while still accumulating the new knowledge learned by the student model. To address the imbalance between foreground and background, we propose to calibrate the segmentation loss and highlight the foreground class based on prediction statistics derived from global information. We maintain a prediction bank to capture this global information, which is considered more reliable than the information inside a single image.

Fig. 1. (a) Training curves of vanilla pseudo-labeling, DPL [4] and our approach. (b) A fundus image and its label, with class proportions, from the RIM-ONE-r3 dataset. (c) Illustration of the framework of our proposed CBMT method.

Our contributions can be summarized as follows: (1) We propose the weak-strong augmented mean teacher learning scheme to address the instability of existing methods. (2) We propose a novel global knowledge-guided loss calibration technique to address the foreground-background imbalance problem. (3) Our proposed CBMT achieves state-of-the-art performance on two popular benchmarks for adaptive fundus image segmentation.
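The teacher update described above can be sketched as an exponential moving average (EMA) over parameter arrays. The smoothing factor `alpha` and its example value 0.99, the dictionary-of-arrays parameter representation, and the toy two-parameter model are illustrative assumptions. This section does not give the value that CBMT uses.

```python
import numpy as np

def ema_update(teacher, student, alpha=0.99):
    """Return teacher parameters updated as an exponential moving average of the student.

    theta_t <- alpha * theta_t + (1 - alpha) * theta_s for every named array.
    Only the student is trained by back-propagation; the teacher changes slowly,
    so one batch of wrong pseudo labels cannot move it abruptly.
    """
    return {name: alpha * teacher[name] + (1.0 - alpha) * student[name] for name in teacher}

# Teacher and student both start from the same (hypothetical) source weights.
source = {"w": np.array([0.5, -1.0]), "b": np.array([0.0])}
teacher = {k: v.copy() for k, v in source.items()}
student = {k: v.copy() for k, v in source.items()}

# Suppose one noisy step pushes a student weight far away ...
student["w"] = student["w"] + np.array([10.0, 0.0])
teacher = ema_update(teacher, student, alpha=0.99)
# ... the teacher moves by only 1% of that jump: w[0] goes from 0.5 to about 0.6.
```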
2 Method
Source-Free Domain Adaptive (SFDA) fundus image segmentation aims to adapt S a source model h, trained with NS labeled source images S = {(Xi , Yi )}N i=1 , NT to the target domain using only NT unlabeled target images T = {Xi }i=1 . Yi ∈ {0, 1}H×W ×C is the ground truth, and H, W , and C denote the image height, width, and class number, respectively. A vanilla pseudo-labeling-based method generates pseudo labels yˆ ∈ RC from the sigmoided model prediction p = h(x) for each pixel x ∈ Xi with source model h: yˆk = 1 [pk > γ] ,
(1)
where $\mathbb{1}[\cdot]$ is the indicator function and $\gamma \in [0, 1]$ is the probability threshold for converting soft probabilities into hard labels. $p_k$ and $\hat{y}_k$ denote the $k$-th dimensions of $p$ and $\hat{y}$, i.e., the prediction and pseudo label for class $k$. Then $(x, \hat{y})$ is used to train the source model $h$ with the binary cross-entropy loss:

$$\mathcal{L}_{bce} = -\mathbb{E}_{x \sim X_i}\left[\hat{y} \log(p) + (1 - \hat{y}) \log(1 - p)\right]. \tag{2}$$
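To make Eqs. (1) and (2) concrete, the following is a minimal Python sketch of vanilla pseudo-labeling for a single image, assuming each pixel is represented by a list of $C$ sigmoid probabilities. The function names (`pseudo_label`, `bce`, `vanilla_pl_loss`) and the clamping constant `eps` are illustrative assumptions, not the authors' code; the loss is returned as a negated log-likelihood so that it is minimized.

```python
import math

def pseudo_label(probs, gamma=0.75):
    """Eq. (1): hard pseudo label 1[p_k > gamma] for each class probability of one pixel."""
    return [1 if p > gamma else 0 for p in probs]

def bce(probs, labels, eps=1e-7):
    """Eq. (2) for one pixel, summed over classes; returns the negative log-likelihood.
    Probabilities are clamped to [eps, 1 - eps] to avoid log(0)."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)
        total += y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return -total

def vanilla_pl_loss(pixels, gamma=0.75):
    """Vanilla self-training: label each pixel with the model's own thresholded
    prediction, then average the per-pixel BCE over the image."""
    losses = [bce(p, pseudo_label(p, gamma)) for p in pixels]
    return sum(losses) / len(losses)

# Toy image: four pixels, two classes each (e.g., optic disc and cup).
pixels = [[0.9, 0.8], [0.9, 0.2], [0.1, 0.05], [0.6, 0.1]]
print([pseudo_label(p) for p in pixels])  # [[1, 1], [1, 0], [0, 0], [0, 0]]
print(vanilla_pl_loss(pixels))
```

The fourth pixel, with a disc probability of 0.6, is labeled background because it falls below γ = 0.75; since the model is trained on its own thresholded output, any such mistake feeds straight back into $h$.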
Most existing SFDA works refine this vanilla method by proposing techniques to calibrate $p$ and obtain better pseudo labels $\hat{y}$, or by measuring the uncertainty of $p$ and weighting $\hat{y}$ accordingly when computing the loss [4,27]. Although these methods improve performance, they still suffer from instability: a noisy $\hat{y}$ directly affects $h$, and the errors accumulate because the predictions of $h$ are in turn used for pseudo-labeling. Another problem is that they neglect the imbalance between foreground and background pixels in fundus images, where the foreground region is small. Consequently, the second (background) term in Eq. (2) dominates the loss, which is undesirable. Our proposed CBMT model addresses these two problems with the weak-strong augmented mean teacher learning scheme and the global knowledge-guided loss calibration technique. Figure 1(c) shows the framework of CBMT.
2.1 Weak-Strong Augmented Mean Teacher
To avoid error accumulation and achieve a robust training process, we introduce the weak-strong augmented mean teacher learning scheme, which uses a teacher model $h_t$ and a student model $h_s$, both initialized from the source model $h$. We generate pseudo labels with $h_t$ and use them to train $h_s$. To enhance generalization, we further introduce a weak-strong augmentation mechanism that feeds weakly and strongly augmented images to the teacher model and the student model, respectively. Concretely, for each image $X_i$, we generate a weakly augmented version $X_i^w$ using image flipping and resizing. Meanwhile, we generate a strongly augmented version $X_i^s$; the strong augmentations include random erasing, contrast adjustment, and impulse noise. For each pixel $x^w \in X_i^w$, we generate the pseudo label $\hat{y}^w$ from the teacher prediction $h_t(x^w)$ with Eq. (1). Then, we train the student model $h_s$ with

$$\mathcal{L} = \mathbb{E}_{x^s \sim X_i^s,\, \hat{y}^w}\left[\tilde{\mathcal{L}}_{bce}\right], \tag{3}$$
where $\tilde{\mathcal{L}}_{bce}$ is the refined binary cross-entropy loss introduced in Sect. 2.2; it builds on Eq. (2) but addresses the foreground-background imbalance problem. The weak-strong augmentation mechanism has two main benefits. First, since fundus image datasets are usually small, the model can easily overfit due to insufficient training data; image augmentation increases the diversity of the training set and alleviates this. Second, learning with different random augmentations acts as a consistency regularizer that constrains images with similar semantics to the same class, which yields a more discriminative feature representation.
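A rough sketch of the weak-strong augmentation pair on a toy grayscale image follows. It assumes that the strong view is produced from the weak view with photometric operations only, so that the teacher's pseudo labels and the student's predictions stay pixel-aligned; the patch size, contrast range, and noise ratio are hypothetical choices, and the resizing step of the weak view is omitted.

```python
import random

def hflip(img):
    """Mirror a 2-D grayscale image (list of rows) left to right."""
    return [row[::-1] for row in img]

def weak_aug(img, rng):
    """Weak view: random horizontal flip (the resizing step is omitted in this sketch)."""
    return hflip(img) if rng.random() < 0.5 else [row[:] for row in img]

def random_erase(img, rng, size=2):
    """Random erasing: zero out one randomly placed size x size patch."""
    h, w = len(img), len(img[0])
    top, left = rng.randrange(h - size + 1), rng.randrange(w - size + 1)
    out = [row[:] for row in img]
    for i in range(top, top + size):
        for j in range(left, left + size):
            out[i][j] = 0.0
    return out

def adjust_contrast(img, factor):
    """Scale intensities around the image mean, then clip to [0, 1]."""
    flat = [v for row in img for v in row]
    mean = sum(flat) / len(flat)
    return [[min(1.0, max(0.0, mean + factor * (v - mean))) for v in row] for row in img]

def impulse_noise(img, rng, ratio=0.05):
    """Salt-and-pepper noise: each pixel becomes 0 or 1 with probability `ratio`."""
    return [[float(rng.random() < 0.5) if rng.random() < ratio else v for v in row]
            for row in img]

def strong_aug(weak_img, rng):
    """Strong view derived from the weak view with photometric changes only,
    so every pixel stays aligned with the teacher's pseudo label."""
    out = random_erase(weak_img, rng)
    out = adjust_contrast(out, rng.uniform(0.5, 1.5))
    return impulse_noise(out, rng)

rng = random.Random(0)
img = [[(8 * i + j) / 63 for j in range(8)] for i in range(8)]
x_w = weak_aug(img, rng)    # input to the teacher h_t (pseudo labels)
x_s = strong_aug(x_w, rng)  # input to the student h_s (trained with Eq. (3))
```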
L. Tang et al.
We update the student model by back-propagating the loss defined in Eq. (3). The teacher model, in contrast, is updated as the exponential moving average (EMA) of the student model:

$$\tilde{\theta} \leftarrow \lambda \tilde{\theta} + (1 - \lambda)\theta, \tag{4}$$

where $\tilde{\theta}$ and $\theta$ are the teacher and student model weights, respectively. Instead of updating the teacher with gradients directly, we define it as the exponential moving average of the student, which keeps the teacher more consistent throughout the adaptation process. This allows us to train for a relatively long schedule and safely choose the final model without accuracy validation. From another perspective, the teacher model can be interpreted as a temporal ensemble of the student at different time steps [18], which enhances its robustness.
2.2 Global Knowledge Guided Loss Calibration
For a fundus image, the foreground object (e.g., the cup) is usually quite small and most pixels belong to the background. If we update the student model with Eq. (2), the background class dominates the loss, which dilutes the supervision signal for the foreground class. The proposed global knowledge guided loss calibration technique aims to address this problem. A naive way to handle the foreground-background imbalance is to count the pixels falling into each of the two categories within each individual image and devise a loss weighting function based on these counts. This strategy may work well for standard supervised learning tasks, where the labels are reliable. With pseudo labels, however, it is too risky to rely on statistics computed from a single image. To remedy this, we analyze the class imbalance across the whole dataset and use this global knowledge to calibrate the loss for each individual image. Specifically, we store the predictions of pixels from all images and maintain the mean loss for the foreground and background as

$$\eta_k^{fg} = \frac{\sum_i \mathcal{L}_{i,k} \cdot \mathbb{1}[\hat{y}_{i,k} = 1]}{\sum_i \mathbb{1}[\hat{y}_{i,k} = 1]}, \qquad \eta_k^{bg} = \frac{\sum_i \mathcal{L}_{i,k} \cdot \mathbb{1}[\hat{y}_{i,k} = 0]}{\sum_i \mathbb{1}[\hat{y}_{i,k} = 0]}, \tag{5}$$

where $\mathcal{L}$ is the segmentation loss mentioned above, and “fg” and “bg” denote foreground and background. We use the mean loss rather than the number of pixels because the loss of each pixel indicates its “hardness” with respect to the pseudo ground truth. This gives more weight to the more informative pixels, so more global knowledge is taken into account. With these average losses, the learning scheme can be further calibrated: we use the ratio of $\eta_k^{fg}$ to $\eta_k^{bg}$ to weight the background loss $\mathcal{L}_k^{bg}$:

$$\tilde{\mathcal{L}}_{bce} = -\mathbb{E}_{x \sim X_i,\, k \sim C}\left[\hat{y}_k \log(p_k) + \frac{\eta_k^{fg}}{\eta_k^{bg}} (1 - \hat{y}_k) \log(1 - p_k)\right]. \tag{6}$$
The calibrated loss ensures fair learning among different classes, therefore alleviating model degradation issues caused by class imbalance.
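The sketch below illustrates Eqs. (5) and (6) for a single class $k$, assuming a hypothetical `bank` of (per-pixel loss, pseudo label) pairs accumulated over the whole target set. It omits the α-based pixel selection described next and averages over a list of pixels, so it should be read as an illustration of the weighting rather than as the reference implementation.

```python
import math

def eta(bank):
    """Eq. (5) for one class k. `bank` holds (per-pixel loss, pseudo label) pairs gathered
    from all target images; returns (mean foreground loss, mean background loss)."""
    fg = [loss for loss, y in bank if y == 1]
    bg = [loss for loss, y in bank if y == 0]
    return sum(fg) / len(fg), sum(bg) / len(bg)

def calibrated_bce(probs, labels, bg_weight, eps=1e-7):
    """Eq. (6) for one class over a list of pixels: the background term is scaled by
    bg_weight = eta_fg / eta_bg; returns the mean negative log-likelihood."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)
        total += y * math.log(p) + bg_weight * (1 - y) * math.log(1.0 - p)
    return -total / len(probs)

# Hypothetical bank for the cup class: two foreground and two background pixels.
bank = [(0.6, 1), (0.8, 1), (0.1, 0), (0.3, 0)]
eta_fg, eta_bg = eta(bank)
weight = eta_fg / eta_bg  # about 3.5 here
loss = calibrated_bce([0.9, 0.3], [1, 0], weight)
```

A weight above 1 amplifies the background term, while a weight below 1 (as for the values reported in Table 3 for α ≥ 0.1) attenuates it relative to the foreground term.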
Since most predictions are highly confident (very close to 0 or 1), they are less informative, so we include only pixels with relatively large losses when computing the mean loss. We realize this with a constraint threshold $\alpha$ for selecting pixels:

$$\frac{|f(x_i) - \gamma|}{|\hat{y}_i - \gamma|} > \alpha,$$

where $\alpha$ is set to 0.2 by default. $\alpha$ represents the lower-bound threshold of the normalized prediction, which filters out well-segmented, uninformative pixels.
3 Experiments
Implementation Details¹. We use DeepLabv3+ [5] with a MobileNetV2 [23] backbone as our segmentation model, following previous works [4,26,27] for a fair comparison. For optimization, we use the Adam optimizer with momentum coefficients of 0.9 and 0.99. During source model training, the initial learning rate is set to 1e-3 and decayed by a factor of 0.98 every epoch, and training lasts 200 epochs. At the source-free domain adaptation stage, the teacher and student models are first initialized from the source model, and the EMA update scheme is applied between them for a total of 20 epochs with a learning rate of 5e-4. The loss calibration parameter η is computed every epoch and applied to the cup class. The output probability threshold γ is set to 0.75 following a previous study [26], and the EMA update rate λ is 0.98 by default. We implement our method in PyTorch on one NVIDIA 3090 GPU and use a batch size of 8 during adaptation. Datasets and Metrics. We evaluate our method on widely used fundus optic disc and cup segmentation datasets from different clinical centers. Following previous works, we choose the REFUGE challenge training set [20] as the source domain and adapt the model to two target domains, the RIM-ONE-r3 [7] and Drishti-GS [24] datasets, for evaluation. Specifically, the source domain consists of 320/80 fundus images for training/testing with pixel-wise optic disc and cup segmentation annotations, while the target domains have 99/60 and 50/51 images, respectively. As in [26], the fundus images are cropped to 512 × 512 regions of interest (ROIs). We compare our CBMT model with several state-of-the-art domain adaptation methods, including the UDA methods BEAL [26] and AdvEnt [25] and the SFDA methods SRDA [1], DAE [15], and DPL [4]. More comparisons with U-D4R [27] under other adaptation settings can be found in the supplementary materials.
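For reference, the training settings listed above can be collected into a small configuration object. This is a sketch only: the field names are hypothetical, the two momentum coefficients are assumed to be Adam's (β1, β2), and "decayed by 0.98 every epoch" is read as a multiplicative per-epoch decay.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class CBMTConfig:
    """Settings reported in the text; the field names are hypothetical."""
    source_lr: float = 1e-3
    source_lr_decay: float = 0.98      # multiplicative decay applied once per epoch
    source_epochs: int = 200
    adapt_lr: float = 5e-4
    adapt_epochs: int = 20
    adam_betas: tuple = (0.9, 0.99)    # "0.9 and 0.99 momentum coefficients"
    gamma: float = 0.75                # pseudo-label threshold in Eq. (1)
    ema_lambda: float = 0.98           # teacher EMA rate in Eq. (4)
    alpha: float = 0.2                 # pixel-selection threshold for Eq. (5)
    batch_size: int = 8                # during adaptation
    crop: tuple = (512, 512)           # ROI size

def source_lr_at(cfg, epoch):
    """Learning rate in the (0-indexed) source-training epoch `epoch`."""
    return cfg.source_lr * cfg.source_lr_decay ** epoch

cfg = CBMTConfig()
print(source_lr_at(cfg, 0))  # 0.001
print(source_lr_at(cfg, cfg.source_epochs - 1))
```

Under this reading, the source-training learning rate falls to about 1.8e-5 by the last of the 200 epochs.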
Standard segmentation metrics are used for evaluation: the Dice coefficient and the Average Symmetric Surface Distance (ASSD). The Dice coefficient (higher is better) measures pixel-level overlap, and ASSD (lower is better) indicates the accuracy of the predicted boundary.
3.1 Experimental Results
The quantitative evaluation results are shown in Table 1. As in [4], we include the results without adaptation from [4] as a lower bound and the supervised learning results from [26] as an upper bound. As shown in the table, CBMT outperforms previous state-of-the-art SFDA methods on both metrics and even surpasses traditional UDA methods on some metrics. The improvement is especially large on the RIM-ONE-r3 dataset (a Dice gain of 3.23 on the optic disc over the best previous method, DPL), because the domain shift is more severe there and leaves more room for improvement. Moreover, CBMT alleviates the need for precise hyper-parameter tuning. We can use a relatively long training procedure (our number of epochs is 10 times that of [4,27]) and safely select the last checkpoint as the final result without worrying about model degradation, which is crucial for real-world clinical source-free domain adaptation.

¹ The code can be found at https://github.com/lloongx/SFDA-CBMT.

Table 1. Quantitative comparison with different methods on two datasets; the best score in each column is highlighted. "–" means the result is not reported by that method, and ± denotes the standard deviation across samples in the dataset. S-F means source-free.

Method | S-F | Optic Disc Dice [%] ↑ | Optic Disc ASSD [pixel] ↓ | Optic Cup Dice [%] ↑ | Optic Cup ASSD [pixel] ↓
RIM-ONE-r3
W/o DA [4] | × | 83.18±6.46 | 24.15±15.58 | 74.51±16.40 | 14.44±11.27
Oracle [26] | × | 96.80 | – | 85.60 | –
BEAL [26] | × | 89.80 | – | 81.00 | –
AdvEnt [25] | × | 89.73±3.66 | 9.84±3.86 | 77.99±21.08 | 7.57±4.24
SRDA [1] | ✓ | 89.37±2.70 | 9.91±2.45 | 77.61±13.58 | 10.15±5.75
DAE [15] | ✓ | 89.08±3.32 | 11.63±6.84 | 79.01±12.82 | 10.31±8.45
DPL [4] | ✓ | 90.13±3.06 | 9.43±3.46 | 79.78±11.05 | 9.01±5.59
CBMT (Ours) | ✓ | 93.36±4.07 | 6.20±4.79 | 81.16±14.71 | 8.37±6.99
Drishti-GS
W/o DA [4] | × | 93.84±2.91 | 9.05±7.50 | 83.36±11.95 | 11.39±6.30
Oracle [26] | × | 97.40 | – | 90.10 | –
BEAL [26] | × | 96.10 | – | 86.20 | –
AdvEnt [25] | × | 96.16±1.65 | 4.36±1.83 | 82.75±11.08 | 11.36±7.22
SRDA [1] | ✓ | 96.22±1.30 | 4.88±3.47 | 80.67±11.78 | 13.12±6.48
DAE [15] | ✓ | 94.04±2.85 | 8.79±7.45 | 83.11±11.89 | 11.56±6.32
DPL [4] | ✓ | 96.39±1.33 | 4.08±1.49 | 83.53±17.80 | 11.39±10.18
CBMT (Ours) | ✓ | 96.61±1.45 | 3.85±1.63 | 84.33±11.70 | 10.30±5.88

Table 2. Ablation study results of our proposed modules on the RIM-ONE-r3 dataset. P-L means the vanilla pseudo-labeling method. * indicates that the accuracy is manually selected from the best epoch. The best results are highlighted.

Configuration | Avg. Dice ↑ | Avg. ASSD ↓
P-L | 64.19 (84.68*) | 15.11 (9.67*)
EMA | 83.63 | 8.51
EMA + Aug. | 84.36 | 8.48
EMA + Calib. | 86.04 | 8.26
EMA + Aug. + Calib. | 87.26 | 7.29

Table 3. Loss calibration weight with different thresholds α on the RIM-ONE-r3 dataset. Our method is robust to the hyper-parameter setting.

α | 0 | 0.1 | 0.2 | 0.3 | 0.4 | 0.5 | 0.6 | 0.7 | 0.8 | 0.9
$\eta_k^{fg}/\eta_k^{bg}$ | 2.99 | 0.24 | 0.24 | 0.24 | 0.24 | 0.24 | 0.23 | 0.23 | 0.22 | 0.22

3.2 Further Analyses
Ablation Study. To assess the contribution of each component to the final performance, we conducted an ablation study on the main modules of CBMT, summarized in Table 2. Note that for the vanilla pseudo-labeling experiments we reduced the learning rate by a factor of 20 to obtain comparable performance, because models are prone to degradation without EMA updating. As the quantitative results show, the EMA update strategy prevents the degradation that the vanilla pseudo-labeling paradigm suffers from. Image augmentation and loss calibration also improve accuracy, and the highest performance is achieved with both. Loss calibration yields the larger improvement because it addresses class imbalance, while image augmentation is easy to implement and plug-and-play in various circumstances. Hyper-parameter Sensitivity Analysis. We further investigate the impact of different hyper-parameters. Figure 2(a) presents the accuracy with different EMA update rates λ. Both too low and too high update rates cause a drop in performance, which is intuitive: a higher λ leads to inconsistency between the teacher and the student, so the teacher can hardly learn from the student; on the other hand, a lower λ keeps the teacher and the student close, making the scheme degenerate to vanilla pseudo-labeling. Within a reasonable range, however, the model is not sensitive to the update rate λ. To evaluate how the loss calibration weight $\eta_k^{fg}/\eta_k^{bg}$ varies with the constraint threshold α, we present the results in Table 3. As discussed in Sect. 2.2, most pixels in an image are well classified; if we simply compute the statistics over all pixels (i.e., α = 0), as shown in the first column, the mean background loss is severely underestimated due to the large number of zero-loss pixels.
Moreover, for α ≥ 0.1 the calibration weight varies little (between 0.22 and 0.24), indicating that our calibration technique is robust to the threshold α.
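The effect of the EMA rate λ discussed above can be illustrated with Eq. (4) on a single scalar weight. In this toy sketch (function names hypothetical), the student weight is frozen at 1.0 and the teacher starts at 0.0; after n updates the teacher has moved a fraction 1 − λⁿ of the way toward the student.

```python
def ema_update(teacher, student, lam=0.98):
    """Eq. (4): every teacher weight moves toward the student weight; no gradient step."""
    return {name: lam * teacher[name] + (1.0 - lam) * student[name] for name in teacher}

def track(lam, steps, student_value=1.0):
    """Teacher weight after `steps` EMA updates toward a student frozen at student_value,
    starting from 0.0 (closed form: student_value * (1 - lam ** steps))."""
    teacher = {"w": 0.0}
    student = {"w": student_value}
    for _ in range(steps):
        teacher = ema_update(teacher, student, lam)
    return teacher["w"]

for lam in (0.5, 0.98, 0.999):
    print(lam, round(track(lam, 100), 3))
```

A small λ makes the teacher almost copy the student, so it inherits the student's noise, whereas a λ close to 1 makes the teacher lag far behind; the default of 0.98 used in this work sits between these extremes.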
Fig. 2. (a) Model performance with different EMA update rates λ. (b) Training curves with and without our proposed loss calibration scheme.
Effectiveness of Loss Calibration for Class Balance. The class imbalance problem can misalign the learning processes of different classes, leading to a gradual decrease of the predicted foreground area, which can ultimately result in model degradation. As shown in Fig. 2(b), neglecting class imbalance causes a significant drop in the number of pixels predicted as “cup” during training, which finally leads to a performance drop. Our loss calibration technique alleviates this issue by balancing the loss with global context.
4 Conclusion
In this work, we propose a class-balanced mean teacher framework to realize robust SFDA learning for more realistic clinical applications. Based on the observation that the model suffers from degradation during adaptation training, we introduce a mean teacher strategy that updates the model via an exponential moving average, which alleviates error accumulation. Meanwhile, by investigating the foreground-background imbalance problem, we present a global knowledge guided loss calibration module. Experiments on two fundus image segmentation datasets show that CBMT outperforms previous SFDA methods.
Acknowledgement. This work was partly supported by the Shenzhen Key Laboratory of next generation interactive media innovative technology (No: ZDSYS20210623092001004).
References
1. Bateson, M., Kervadec, H., Dolz, J., Lombaert, H., Ben Ayed, I.: Source-Relaxed domain adaptation for image segmentation. In: MICCAI 2020. LNCS, vol. 12261, pp. 490–499. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59710-8_48
2. Cai, J., Zhang, Z., Cui, L., Zheng, Y., Yang, L.: Towards cross-modal organ translation and segmentation: a cycle- and shape-consistent generative adversarial network. Med. Image Anal. 52, 174–184 (2019)
3. Carion, N., Massa, F., Synnaeve, G., Usunier, N., Kirillov, A., Zagoruyko, S.: End-to-end object detection with transformers. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12346, pp. 213–229. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58452-8_13
4. Chen, C., Liu, Q., Jin, Y., Dou, Q., Heng, P.-A.: Source-free domain adaptive fundus image segmentation with denoised pseudo-labeling. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12905, pp. 225–235. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87240-3_22
5. Chen, L.C., Zhu, Y., Papandreou, G., Schroff, F., Adam, H.: Encoder-decoder with atrous separable convolution for semantic image segmentation. In: Proceedings of the European Conference on Computer Vision (ECCV), pp. 801–818 (2018)
6. Drozdzal, M., Vorontsov, E., Chartrand, G., Kadoury, S., Pal, C.: The importance of skip connections in biomedical image segmentation. In: Carneiro, G., et al. (eds.) LABELS/DLMIA 2016. LNCS, vol. 10008, pp. 179–187. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46976-8_19
7. Fumero, F., Alayón, S., Sanchez, J.L., Sigut, J., Gonzalez-Hernandez, M.: RIM-ONE: an open retinal image database for optic nerve evaluation. In: 2011 24th International Symposium on Computer-based Medical Systems (CBMS), pp. 1–6. IEEE (2011)
8. Gadermayr, M., Gupta, L., Appel, V., Boor, P., Klinkhammer, B.M., Merhof, D.: Generative adversarial networks for facilitating stain-independent supervised and unsupervised segmentation: a study on kidney histology. IEEE Trans. Med. Imaging 38(10), 2293–2302 (2019)
9. Ganin, Y., et al.: Domain-adversarial training of neural networks. J. Mach. Learn. Res. 17(1), 2096–2030 (2016)
10. He, C., et al.: Camouflaged object detection with feature decomposition and edge reconstruction. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 22046–22055 (2023)
11. He, C., et al.: Weakly-supervised concealed object segmentation with SAM-based pseudo labeling and multi-scale feature grouping. arXiv preprint: arXiv:2305.11003 (2023)
12. Ibtehaz, N., Rahman, M.S.: MultiResUNet: rethinking the U-Net architecture for multimodal biomedical image segmentation. Neural Netw. 121, 74–87 (2020)
13. Javanmardi, M., Tasdizen, T.: Domain adaptation for biomedical image segmentation using adversarial training. In: 2018 IEEE 15th International Symposium on Biomedical Imaging (ISBI 2018), pp. 554–558. IEEE (2018)
14. Kamnitsas, K., et al.: Unsupervised domain adaptation in brain lesion segmentation with adversarial networks. In: Niethammer, M., et al. (eds.) IPMI 2017. LNCS, vol. 10265, pp. 597–609. Springer, Cham (2017). https://doi.org/10.1007/978-3-319-59050-9_47
15. Karani, N., Erdil, E., Chaitanya, K., Konukoglu, E.: Test-time adaptable neural networks for robust medical image segmentation. Med. Image Anal. 68, 101907 (2021)
16. Li, K., Zhang, Y., Li, K., Fu, Y.: Adversarial feature hallucination networks for few-shot learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13470–13479 (2020)
17. Liu, X., Xing, F., Yang, C., El Fakhri, G., Woo, J.: Adapting off-the-shelf source segmenter for target medical image segmentation. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12902, pp. 549–559. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87196-3_51
18. Liu, Y.C., et al.: Unbiased teacher for semi-supervised object detection. arXiv preprint: arXiv:2102.09480 (2021)
19. Milletari, F., Navab, N., Ahmadi, S.A.: V-Net: fully convolutional neural networks for volumetric medical image segmentation. In: 2016 Fourth International Conference on 3D Vision (3DV), pp. 565–571. IEEE (2016)
20. Orlando, J.I., Fu, H., Breda, J.B., van Keer, K., Bathula, D.R., Diaz-Pinto, A., Fang, R., Heng, P.A., Kim, J., Lee, J., et al.: REFUGE challenge: a unified framework for evaluating automated methods for glaucoma assessment from fundus photographs. Med. Image Anal. 59, 101570 (2020)
21. Ren, S., He, K., Girshick, R., Sun, J.: Faster R-CNN: towards real-time object detection with region proposal networks. In: Advances in Neural Information Processing Systems, vol. 28 (2015)
22. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
23. Sandler, M., Howard, A., Zhu, M., Zhmoginov, A., Chen, L.C.: MobileNetV2: inverted residuals and linear bottlenecks. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4510–4520 (2018)
24. Sivaswamy, J., Krishnadas, S., Chakravarty, A., Joshi, G., Tabish, A.S., et al.: A comprehensive retinal image dataset for the assessment of glaucoma from the optic nerve head analysis. JSM Biomed. Imaging Data Pap. 2(1), 1004 (2015)
25. Vu, T.H., Jain, H., Bucher, M., Cord, M., Pérez, P.: ADVENT: adversarial entropy minimization for domain adaptation in semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2517–2526 (2019)
26. Wang, S., Yu, L., Li, K., Yang, X., Fu, C.-W., Heng, P.-A.: Boundary and entropy-driven adversarial learning for fundus image segmentation. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11764, pp. 102–110. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32239-7_12
27. Xu, Z., et al.: Denoising for relaxing: unsupervised domain adaptive fundus image segmentation without source data. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13435, pp. 214–224. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16443-9_21
Unsupervised Domain Adaptation for Anatomical Landmark Detection
Haibo Jin¹, Haoxuan Che¹, and Hao Chen¹,²(B)
¹ Department of Computer Science and Engineering, The Hong Kong University of Science and Technology, Kowloon, Hong Kong {hjinag,hche,jhc}@cse.ust.hk
² Department of Chemical and Biological Engineering, The Hong Kong University of Science and Technology, Kowloon, Hong Kong
Abstract. Recently, anatomical landmark detection has made great progress on single-domain data, which usually assumes that the training and test sets come from the same domain. However, this assumption does not always hold in practice, and the resulting domain shift can cause a significant performance drop. To tackle this problem, we propose a novel framework for anatomical landmark detection under the setting of unsupervised domain adaptation (UDA), which aims to transfer knowledge from a labeled source domain to an unlabeled target domain. The framework leverages self-training and domain adversarial learning to address the domain gap during adaptation. Specifically, a self-training strategy is proposed to select reliable landmark-level pseudo-labels of target domain data with dynamic thresholds, which makes the adaptation more effective. Furthermore, a domain adversarial learning module is designed to handle the unaligned data distributions of the two domains by learning domain-invariant features via adversarial training. Our experiments on cephalometric and lung landmark detection show the effectiveness of the method, which reduces the domain gap by a large margin and consistently outperforms other UDA methods.
1 Introduction
Anatomical landmark detection is a fundamental step in many clinical applications such as orthodontic diagnosis [11] and orthognathic treatment planning [6]. However, manually locating the landmarks is tedious and time-consuming, and manual labeling is prone to errors caused by inconsistent landmark identification [5]. There is therefore a strong need to automate landmark detection for efficiency and consistency. In recent years, deep learning based methods have made great progress in anatomical landmark detection. For supervised learning, earlier works [6,20,27] adopted heatmap regression with extra shape constraints. Later, graph networks [16] and self-attention [11] were introduced to model landmark dependencies in an end-to-end manner for better performance.
Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_66.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 695–705, 2023. https://doi.org/10.1007/978-3-031-43907-0_66
H. Jin et al.
Fig. 1. Domain A vs. Domain B. (a) Image histogram. (b)–(c) Visual samples.
Despite the success of recent methods, they mostly focus on single-domain data and assume that the training and test sets follow the same distribution. However, this assumption does not always hold in practice, due to differences in patient populations and imaging devices. Figure 1 shows that cephalogram images from two domains can differ greatly in both histogram and visual appearance. Therefore, a well-trained model may suffer severe performance degradation in practice due to the domain shift of the test data. A straightforward solution is to greatly increase the size and diversity of the training set, but labeling is prohibitively expensive, especially for medical images. On the other hand, unsupervised domain adaptation (UDA) [10] aims to transfer the knowledge learned from the labeled source domain to the unlabeled target domain, which is a potential solution to the domain shift problem because unlabeled data are much easier to collect. The effectiveness of UDA has been demonstrated in many vision tasks, such as image classification [10], object detection [7], and pose estimation [3,19,24]. However, its feasibility in anatomical landmark detection remains unknown. In this paper, we investigate anatomical landmark detection under the UDA setting. Our preliminary experiments show that a well-performing model suffers a significant performance drop on cross-domain data: the mean radial error (MRE) increases from 1.22 mm to 3.32 mm, and the success detection rate (SDR) within 2 mm drops from 83.76% to 50.05%. To address the domain gap, we propose a unified framework that contains a base landmark detection model, a self-training strategy, and a domain adversarial learning module. Specifically, self-training is adopted to effectively leverage the unlabeled data from the target domain via pseudo-labels.
To handle confirmation bias [2], we propose landmark-aware self-training (LAST), which selects pseudo-labels at the landmark level with dynamic thresholds. Furthermore, to address the covariate shift [26] issue (i.e., unaligned data distributions) that may degrade the performance of self-training, a domain adversarial learning (DAL) module is designed to learn domain-invariant features via adversarial training. Our experiments on two anatomical datasets show the effectiveness of the proposed framework. For example, on cephalometric landmark detection, it reduces the target-domain MRE by 47% (3.32 mm → 1.75 mm) and improves the SDR (2 mm) from 50.05% to 69.15%. We summarize our contributions as follows.
Fig. 2. The overall framework. Based on 1) the landmark detection model, it 2) utilizes LAST to leverage the unlabeled target domain data via pseudo-labels, and 3) simultaneously conducts DAL for domain-invariant features.
1. We investigated anatomical landmark detection under the UDA setting for the first time, and showed that domain shift indeed causes a severe performance drop for a well-performing landmark detection model.
2. We proposed a novel framework for the UDA of anatomical landmark detection, which significantly improves cross-domain performance and consistently outperforms other state-of-the-art UDA methods.
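The cephalometric figures quoted in the introduction can be checked with a few lines of arithmetic. The sketch below assumes that 1.22 mm is the in-domain MRE of the same model, 3.32 mm its MRE on the target domain without adaptation, and 1.75 mm the MRE after adaptation. Under that reading, 47% is the relative reduction of the target-domain MRE, while the share of the in-domain/cross-domain gap that is closed is larger.

```python
mre_in_domain = 1.22  # mm, model tested on data from its training domain
mre_no_adapt = 3.32   # mm, same model on the target domain
mre_adapted = 1.75    # mm, after adaptation with the proposed framework

relative_mre_reduction = (mre_no_adapt - mre_adapted) / mre_no_adapt
gap_closed = (mre_no_adapt - mre_adapted) / (mre_no_adapt - mre_in_domain)
print(f"relative MRE reduction: {relative_mre_reduction:.0%}")  # relative MRE reduction: 47%
print(f"share of the domain gap closed: {gap_closed:.0%}")     # share of the domain gap closed: 75%
```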
2 Method
Figure 2 shows the overall framework, which aims to achieve satisfactory performance in the target domain under the UDA setting. During training, it leverages both labeled source domain data $\mathcal{S} = \{x_i^S, y_i^S\}_{i=1}^{N}$ and unlabeled target domain data $\mathcal{T} = \{x_j^T\}_{j=1}^{M}$. For evaluation, it is tested on a held-out test set from the target domain. The landmark detection model predicts landmarks together with confidence scores, as detailed in Sect. 2.1. To reduce the domain gap, we further propose LAST and DAL, which are introduced in Sects. 2.2 and 2.3, respectively.
2.1 Landmark Detection Model
Recently, coordinate regression [11,16] has achieved better performance than heatmap regression [6,27]. However, coordinate-based methods do not output confidence scores, which are necessary for pseudo-label selection in self-training [4,14]. To address this issue, we designed a model that predicts accurate landmarks while providing confidence scores. As shown in Fig. 3(a), the model uses both coordinate and heatmap regression: the former provides coarse but robust predictions via global localization, which are then projected onto the local maps of the latter for prediction refinement and confidence measurement. Global Localization. We adopt a Transformer decoder [12,15] for coarse localization because of its strength in modeling global attention. A convolutional neural network (CNN) is used to extract the feature $f \in \mathbb{R}^{C \times H \times W}$, where $C$, $H$, and $W$ denote the number of channels, the map height, and the map width, respectively. Using $f$ as memory, the decoder takes landmark queries $q \in \mathbb{R}^{L \times C}$ as input and iteratively updates them through multiple decoder layers, where $L$ is the number of landmarks. Finally, a feed-forward network (FFN) converts the updated landmark queries into coordinates $\hat{y}_c \in \mathbb{R}^{L \times 2}$. The loss function $\mathcal{L}_{coord}$ is the L1 loss between the predicted coordinates $\hat{y}_c$ and the label $y_c$.
Fig. 3. (a) Our landmark detection model. (b) Confidence scores of different landmarks for a random target domain image. (c) Statistics of reliable landmark-level pseudo-labels with a fixed threshold τ = 0.4 over 500 images.
Local Refinement. This module outputs a score map $\hat{y}_s \in \mathbb{R}^{L \times H \times W}$ and an offset map $\hat{y}_o \in \mathbb{R}^{2L \times H \times W}$ via 1 × 1 convolutional layers, taking $f$ as input. The score map indicates the likelihood of each grid cell being the target landmark, while the offset map represents the relative offsets from neighbouring grid cells to the target. The ground-truth (GT) landmark on the score map is smoothed by a Gaussian kernel [23], and the L2 loss is used as the loss function $\mathcal{L}_{score}$. Since offset prediction is a regression problem, the L1 loss is used for $\mathcal{L}_{offset}$, applied only to the area where the GT score is larger than zero. During inference, unlike [6,18,23], the optimal local grid cell is selected not by the maximum score of $\hat{y}_s$ but by projecting the coordinates from global localization onto the map. The corresponding offset value is then added to the selected grid cell for refinement, and the confidence of each prediction is read from the score map at the projected location. The loss function of the landmark detection model can be summarized as

$$\mathcal{L}_{base} = \sum_{\mathcal{S}} \left(\lambda_s \mathcal{L}_{score} + \lambda_o \mathcal{L}_{offset} + \mathcal{L}_{coord}\right), \tag{1}$$

where $\mathcal{S}$ is the source domain data, and $\lambda_s$ and $\lambda_o$ are balancing coefficients. Empirically, we set $\lambda_s = 100$ and $\lambda_o = 0.02$ in this paper.
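The inference path of the local refinement branch (project the coarse coordinate, add the offset, read the confidence) and the loss weighting of Eq. (1) can be sketched as follows. The stride between image pixels and map cells, the (dx, dy) layout of the offset map, and anchoring offsets at the top-left corner of a cell are assumptions made for illustration only; the function names are hypothetical.

```python
def refine(coarse_xy, score_map, offset_map, stride):
    """Refine one landmark from its coarse (x, y) prediction in image pixels.
    score_map: H x W list of confidences for this landmark.
    offset_map: H x W list of (dx, dy) offsets in image pixels (assumed layout).
    stride: image pixels per map cell (assumed)."""
    h, w = len(score_map), len(score_map[0])
    # Project the coarse coordinate onto the local map instead of taking the arg-max.
    col = min(w - 1, max(0, int(coarse_xy[0] // stride)))
    row = min(h - 1, max(0, int(coarse_xy[1] // stride)))
    dx, dy = offset_map[row][col]
    refined = (col * stride + dx, row * stride + dy)  # offset anchored at the cell corner (assumption)
    confidence = score_map[row][col]                  # later used for pseudo-label selection
    return refined, confidence

def base_loss(l_score, l_offset, l_coord, lam_s=100.0, lam_o=0.02):
    """Eq. (1) for one source sample, with the coefficients reported in the text."""
    return lam_s * l_score + lam_o * l_offset + l_coord

# Toy 4 x 4 maps for one landmark.
score = [[0.0] * 4 for _ in range(4)]
score[1][2] = 0.9
offset = [[(0.0, 0.0)] * 4 for _ in range(4)]
offset[1][2] = (3.0, 5.0)
print(refine((20.0, 12.0), score, offset, stride=8))  # ((19.0, 13.0), 0.9)
```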
2.2 Landmark-Aware Self-training
Self-training [14] is an effective semi-supervised learning (SSL) method, which iteratively estimates and selects reliable pseudo-labeled samples to expand the labeled set. Its effectiveness has also been verified on several vision tasks under the UDA setting, such as object detection [7]. However, few works have explored self-training for the UDA of landmark detection; most are restricted to the SSL paradigm [9,21]. Since UDA is more challenging than SSL due to domain shift, reliable pseudo-labels should be carefully selected to avoid confirmation bias [2]. Existing works [9,19,21] follow the image classification pipeline and evaluate reliability at the image level, which we believe is not representative because the landmarks within an image may have different reliabilities (see Fig. 3(b)). To avoid potential noisy labels caused by image-level selection, we propose LAST, which selects reliable pseudo-labels at the landmark level. To achieve this, we use a binary mask $m \in \{0, 1\}^L$ to indicate the reliability of each landmark for a given image, where 1 indicates that the label is reliable and 0 the opposite. To decide the reliability of each landmark, a common practice is to use a threshold $\tau$, where the $l$-th landmark is reliable if its confidence score $s_l > \tau$. During loss calculation, each loss term is multiplied by $m$ to mask out the unreliable landmark-level labels. Thus, the loss for LAST is

$$\mathcal{L}_{LAST} = \sum_{\mathcal{S} \cup \tilde{\mathcal{T}}} \mathcal{M}(\mathcal{L}_{base}), \tag{2}$$

where $\mathcal{M}$ represents the mask operation, $\tilde{\mathcal{T}} = \{x_j^T, \tilde{y}_j^T\}_{j=1}^{M}$, and $\tilde{y}^T$ denotes the pseudo-labels estimated in the last self-training round. Note that the masks of the source domain data $\mathcal{S}$ always equal one, as they are ground truths. However, landmark-level selection leads to unbalanced pseudo-labels across landmarks, as shown in Fig. 3(c). This is caused by the fixed threshold $\tau$, which cannot handle different landmarks adaptively. To address this issue, we introduce percentile scores [4] to yield dynamic thresholds (DT) for different landmarks. Specifically, for the $l$-th landmark, when the pseudo-labels are sorted by confidence (high to low), the threshold $\tau_r^l$ is the confidence score at the $r$-th percentile. In this way, the selection ratio of pseudo-labels can be controlled by $r$, and the imbalance can be addressed by using the same $r$ for all landmarks. We set the curriculum to $r = \Delta \cdot t$, where $t$ is the self-training round and $\Delta$ is a hyperparameter that controls the pace. We use $\Delta = 20\%$, which yields five training rounds in total.
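A minimal sketch of landmark-level selection follows, contrasting a fixed threshold with per-landmark dynamic thresholds under the curriculum $r = \Delta \cdot t$ with $\Delta = 20\%$. How a percentile is rounded to a whole number of images (rounded up, at least one) and the use of ≥ at the threshold are assumptions; with tied scores, the dynamic mask may keep slightly more than r% of the images. Function names are hypothetical.

```python
def fixed_masks(conf, tau=0.4):
    """Landmark-level masks with one global threshold; conf[j][l] is the confidence
    of landmark l in target image j."""
    return [[1 if s > tau else 0 for s in row] for row in conf]

def dynamic_thresholds(conf, r_percent):
    """tau_r^l: for each landmark l, sort its confidences over all target images from
    high to low and take the score at the r-th percentile (rounded up to a whole image)."""
    n = len(conf)
    k = max(1, (r_percent * n + 99) // 100)  # integer ceil(r% of n), at least one
    taus = []
    for l in range(len(conf[0])):
        scores = sorted((row[l] for row in conf), reverse=True)
        taus.append(scores[k - 1])
    return taus

def dynamic_masks(conf, r_percent):
    """Keep, per landmark, the pseudo-labels whose confidence reaches tau_r^l."""
    taus = dynamic_thresholds(conf, r_percent)
    return [[1 if s >= taus[l] else 0 for l, s in enumerate(row)] for row in conf]

def masked_loss(per_landmark_loss, mask):
    """Masking in Eq. (2) for one target image: unreliable landmarks contribute nothing."""
    return sum(loss * m for loss, m in zip(per_landmark_loss, mask))

# Five target images, two landmarks: landmark 1 is systematically less confident.
conf = [[0.9, 0.30], [0.8, 0.35], [0.7, 0.20], [0.6, 0.25], [0.5, 0.10]]
print([sum(col) for col in zip(*fixed_masks(conf))])  # [5, 0]
for t in range(1, 6):                                 # curriculum r = 20% * t
    print(t, [sum(col) for col in zip(*dynamic_masks(conf, 20 * t))])
```

On this toy input, the fixed threshold keeps all five pseudo-labels of the first landmark and none of the second, mirroring the imbalance in Fig. 3(c), whereas the dynamic thresholds keep t pseudo-labels per landmark in round t.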
2.3 Domain Adversarial Learning
Although self-training has been shown to be effective, it inevitably contains a bias towards the source domain, because its initial model is trained with source domain data only. This bias stems from the fact that the data distribution of the target domain differs from that of the source domain, a phenomenon known as covariate shift [26]. To mitigate it,
H. Jin et al.
we introduce DAL to align the two distributions through adversarial training between a domain classifier and the feature extractor. Specifically, the feature $f$ further passes through a global average pooling (GAP) layer and a fully connected (FC) layer, and is then fed to a domain classifier $D$ that predicts the domain of the input $x$. The classifier can be trained with a binary cross-entropy loss:

$$\mathcal{L}_D = -d \log D(F(x)) - (1 - d) \log\big(1 - D(F(x))\big), \qquad (3)$$

where $d$ is the domain label, with $d = 0$ and $d = 1$ indicating that the image is from the source and target domain, respectively. The domain classifier is trained to minimize $\mathcal{L}_D$, while the feature extractor $F$ is encouraged to maximize it, so that the learned features are indistinguishable by the domain classifier. Thus, the adversarial objective can be written as $\mathcal{L}_{DAL} = \max_F \min_D \mathcal{L}_D$. To simplify the optimization, we adopt a gradient reversal layer (GRL) [10], placed right after the feature extractor, to mimic the adversarial training. During backpropagation, the GRL negates the gradients that pass back to the feature extractor $F$, so that $\mathcal{L}_D$ is effectively maximized with respect to $F$. In this way, the adversarial training can be done by simply minimizing $\mathcal{L}_D$, i.e., $\mathcal{L}_{DAL} = \mathcal{L}_D$. Finally, the overall loss function is

$$\mathcal{L} = \mathcal{L}_{LAST} + \lambda_D \sum_{S \cup T} \mathcal{L}_{DAL}, \qquad (4)$$

where $\lambda_D$ is a balancing coefficient.
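The effect of the GRL on Eq. (3) can be checked on a toy one-dimensional model. The sketch below is an illustration under simplifying assumptions, not the paper's network: the feature extractor is a single weight `w`, the domain classifier is a single weight `v` followed by a sigmoid, and gradients are derived by hand instead of by autograd. The parameter `reverse_scale` is a hypothetical name for the GRL multiplier (plain negation corresponds to 1.0); it is distinct from the loss weight $\lambda_D$ in Eq. (4).

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def domain_bce(p, d):
    """Eq. (3): L_D = -d log p - (1 - d) log(1 - p), where p = D(F(x)),
    d = 0 for source images and d = 1 for target images."""
    return -d * math.log(p) - (1 - d) * math.log(1 - p)


def grl_sgd_step(w, v, x, d, lr=0.1, reverse_scale=1.0):
    """One SGD step on a toy model with a gradient reversal layer.

    Feature extractor F: f = w * x; domain classifier D: p = sigmoid(v * f).
    The GRL is the identity in the forward pass; in the backward pass it
    multiplies the gradient flowing into F by -reverse_scale.
    """
    f = w * x
    p = sigmoid(v * f)
    dloss_dlogit = p - d                    # dL_D / d(v * f) for sigmoid + BCE
    grad_v = dloss_dlogit * f               # D receives the ordinary gradient
    grad_f = dloss_dlogit * v               # gradient arriving at the GRL
    grad_w = -reverse_scale * grad_f * x    # reversed before it reaches F
    return w - lr * grad_w, v - lr * grad_v


if __name__ == "__main__":
    w1, v1 = grl_sgd_step(1.0, 1.0, 1.0, d=1)
    base = domain_bce(sigmoid(1.0), 1)
    print("L_D before:", base)
    print("L_D after updating D only:", domain_bce(sigmoid(v1), 1))
    print("L_D after updating F only:", domain_bce(sigmoid(w1), 1))
```

A single step lowers $\mathcal{L}_D$ when only the classifier weight is updated but raises it when only the feature weight is updated, which is the min-max behavior that the GRL emulates through one plain minimization.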
3 Experiments
3.1 Experimental Settings
In this section, we present experiments on cephalometric landmark detection; results on lung landmark detection are given in Appendix A.

Source Domain. The ISBI 2015 Challenge provides a public dataset [22], which is widely used as a benchmark for cephalometric landmark detection. It contains 400 images in total: 150 for training, 150 as Test 1 data, and the remaining 100 as Test 2 data. Each image is annotated with 19 landmarks by two experienced doctors, and the mean of the two annotations is used as the GT. In this paper, we only use the training set as the labeled source domain data.

Target Domain. The ISBI 2023 Challenge provides a new dataset [13], which was collected from 7 different imaging devices. At the time of writing, only the training set, containing 700 images, has been released. For the UDA setting, we randomly select 500 images as unlabeled target domain data, and the remaining 200 images are used for evaluation. The dataset provides 29 landmarks, but we only use 19 of them, i.e., the same landmarks as in the source domain [22]. Following previous works [11,16], all images are resized to 640 × 800. As evaluation metrics, we adopt the mean radial error (MRE) and the successful detection rate (SDR) within four radii (2 mm, 2.5 mm, 3 mm, and 4 mm).
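The two evaluation metrics can be sketched as follows. This assumes predicted and GT landmarks are given in pixel coordinates and that a single pixel spacing (`mm_per_pixel`, a hypothetical parameter whose value depends on the dataset and on the resizing to 640 × 800) converts distances to millimeters. A detection is counted as successful when its radial error is at most the radius; this is a common convention, not something stated above.

```python
import math


def radial_errors_mm(preds, gts, mm_per_pixel):
    """Euclidean distance between each predicted and GT landmark, in mm.
    mm_per_pixel is an assumed, dataset-specific pixel spacing."""
    return [math.hypot(px - gx, py - gy) * mm_per_pixel
            for (px, py), (gx, gy) in zip(preds, gts)]


def mean_radial_error(errors_mm):
    """MRE: average radial error over all evaluated landmarks."""
    return sum(errors_mm) / len(errors_mm)


def successful_detection_rate(errors_mm, radius_mm):
    """SDR (%): share of landmarks whose error is within radius_mm."""
    return 100.0 * sum(e <= radius_mm for e in errors_mm) / len(errors_mm)


if __name__ == "__main__":
    errors = radial_errors_mm([(3, 4), (0, 22), (0, 35)], [(0, 0)] * 3, 0.1)
    print(f"MRE = {mean_radial_error(errors):.2f} mm")
    for r in (2.0, 2.5, 3.0, 4.0):
        print(f"SDR@{r} mm = {successful_detection_rate(errors, r):.1f}%")
```

For example, errors of 0.5, 2.2, and 3.5 mm give an MRE of about 2.07 mm and SDRs of 33.3%, 66.7%, 66.7%, and 100% within 2, 2.5, 3, and 4 mm.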
Unsupervised Domain Adaptation for Anatomical Landmark Detection
Table 1. Results on the target domain test set under the UDA setting. MRE is in mm; SDR (%) is reported within 2, 2.5, 3, and 4 mm.
Method                  MRE↓   2 mm    2.5 mm   3 mm    4 mm
Base, Labeled Source    3.32   50.05   56.87    62.63   70.87
FDA [25]                2.16   61.28   69.73    76.34   84.57
UMT [8]                 1.98   63.94   72.52    78.89   87.05
SAC [1]                 1.94   65.68   73.76    79.63   87.81
AT [17]                 1.87   66.82   74.81    80.73   88.47
Ours                    1.75   69.15   76.94    82.92   90.05
Base, Labeled Target    1.22   83.76   89.71    92.79   96.08
Implementation Details. We use an ImageNet-pretrained ResNet-50 as the backbone, followed by three deconvolutional layers for upsampling to stride 4 [23]. For the Transformer decoder, three decoder layers are used, with embedding length C = 256. Our model has 41M parameters and 139 GFLOPs for an input size of 640 × 800. The source domain images are oversampled to match the number of target domain images so that the domain classifier is unbiased. Adam is used as the optimizer, and the model is trained for 720 epochs in each self-training round. The initial learning rate is $2 \times 10^{-4}$ and is decayed by a factor of 10 at the 480th and 640th epochs. The batch size is set to 10 and $\lambda_D$ is set to 0.01. For data augmentation, we use random scaling, translation, rotation, occlusion, and blurring. The code was implemented with PyTorch 1.13, and the model was trained on one RTX 3090 GPU. Training took about 54 h.
3.2 Results
For comparison under the UDA setting, several state-of-the-art UDA methods were implemented, including FDA [25], UMT [8], SAC [1], and AT [17]. Additionally, the base model trained with source domain data only is included as the lower bound, and the model trained with the same amount of labeled target domain data is used as the upper bound. Table 1 shows the results. First, the model trained on the target domain performs much better than the one trained on the source domain in both MRE (1.22 mm vs. 3.32 mm) and SDR (83.76% vs. 50.05% within 2 mm), which indicates that domain shift can cause severe performance degradation. By leveraging both labeled source domain and unlabeled target domain data, our model achieves 1.75 mm in MRE and 69.15% in SDR within 2 mm. It not only reduces the domain gap by a large margin (3.32 mm → 1.75 mm in MRE and 50.05% → 69.15% in 2 mm SDR), but also consistently outperforms the other UDA methods. However, a gap remains between the UDA methods and the supervised model in the target domain.
3.3 Model Analysis
We first conduct an ablation study to show the effectiveness of each module (Table 2). The baseline model simply uses vanilla self-training [14] for domain adaptation and achieves 2.18 mm in MRE. Adding LAST without dynamic thresholds (DT) improves the MRE to 1.98 mm. When the proposed LAST and DAL are applied separately, the MREs are 1.91 mm and 1.96 mm, respectively, which verifies the effectiveness of the two modules. Combining the two, the model obtains the best results in both MRE and SDR. To show the superiority of our base model, we replace it with standard heatmap regression [23] (HM), which yields degraded results in both MRE and SDR. Furthermore, we analyze the subdomain discrepancy, which shows the effectiveness of our method on each subdomain (see Appendix B).

Fig. 4. Qualitative results of three models on target domain test data. Green dots are GTs, and red dots are predictions. Yellow rectangles indicate that our model performs better than the other two, while cyan rectangles indicate that all three fail. (Color figure online)

Table 2. Ablation study of different modules. MRE is in mm; SDR (%) is reported within 2, 2.5, 3, and 4 mm.
Method                 MRE↓   2 mm    2.5 mm   3 mm    4 mm
Self-training [14]     2.18   62.18   69.44    75.47   84.36
LAST w/o DT            1.98   65.34   72.53    78.03   86.11
LAST                   1.91   66.21   74.39    80.23   88.42
DAL                    1.96   65.92   74.18    79.73   87.60
LAST+DAL               1.75   69.15   76.94    82.92   90.05
LAST+DAL w/ HM         1.84   66.45   75.09    81.82   89.55

3.4 Qualitative Results
Figure 4 shows the qualitative results of the source-only base model, AT [17], and our method on target domain test data. The green dots are GTs, and the red dots are predictions. Our model makes better predictions than the other two (see yellow rectangles). We also notice that some landmarks are quite challenging, and all three models fail to predict them accurately (see cyan rectangles).
4 Conclusion
In this paper, we investigated anatomical landmark detection under the UDA setting. To mitigate the performance drop caused by domain shift, we proposed a unified UDA framework consisting of a landmark detection model, a self-training strategy, and a DAL module. Based on the predictions and confidence scores of the landmark model, the self-training strategy performs domain adaptation via landmark-level pseudo-labels with dynamic thresholds. Meanwhile, the model is encouraged to learn domain-invariant features via adversarial training, which addresses the misaligned data distributions. We constructed a UDA setting based on two anatomical datasets, where the experiments showed that our method not only reduces the domain gap by a large margin, but also consistently outperforms other UDA methods. However, a performance gap still exists between current UDA methods and the supervised model in the target domain, indicating that more effective UDA methods are needed to close the gap.

Acknowledgments. This work was supported by the Shenzhen Science and Technology Innovation Committee Fund (Project No. SGDX20210823103201011) and Hong Kong Innovation and Technology Fund (Project No. ITS/028/21FP).
References
1. Araslanov, N., Roth, S.: Self-supervised augmentation consistency for adapting semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 15384–15394 (2021)
2. Arazo, E., Ortego, D., Albert, P., O'Connor, N.E., McGuinness, K.: Pseudo-labeling and confirmation bias in deep semi-supervised learning. In: 2020 International Joint Conference on Neural Networks (IJCNN), pp. 1–8. IEEE (2020)
3. Bigalke, A., Hansen, L., Diesel, J., Heinrich, M.P.: Domain adaptation through anatomical constraints for 3D human pose estimation under the cover. In: International Conference on Medical Imaging with Deep Learning, pp. 173–187. PMLR (2022)
4. Cascante-Bonilla, P., Tan, F., Qi, Y., Ordonez, V.: Curriculum labeling: revisiting pseudo-labeling for semi-supervised learning. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 35, pp. 6912–6920 (2021)
5. Chen, M.H., et al.: Intraobserver reliability of landmark identification in cone-beam computed tomography-synthesized two-dimensional cephalograms versus conventional cephalometric radiography: a preliminary study. J. Dental Sci. 9(1), 56–62 (2014)
6. Chen, R., Ma, Y., Chen, N., Lee, D., Wang, W.: Cephalometric landmark detection by attentive feature pyramid fusion and regression-voting. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11766, pp. 873–881. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32248-9_97
7. Chen, Y., Li, W., Sakaridis, C., Dai, D., Van Gool, L.: Domain adaptive faster R-CNN for object detection in the wild. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3339–3348 (2018)
8. Deng, J., Li, W., Chen, Y., Duan, L.: Unbiased mean teacher for cross-domain object detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4091–4101 (2021)
9. Dong, X., Yang, Y.: Teacher supervises students how to learn from partially labeled images for facial landmark detection. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 783–792 (2019)
10. Ganin, Y., Lempitsky, V.: Unsupervised domain adaptation by backpropagation. In: International Conference on Machine Learning, pp. 1180–1189. PMLR (2015)
11. Jiang, Y., Li, Y., Wang, X., Tao, Y., Lin, J., Lin, H.: CephalFormer: incorporating global structure constraint into visual features for general cephalometric landmark detection. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) MICCAI 2022. LNCS, vol. 13433, pp. 227–237. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-16437-8_22
12. Jin, H., Li, J., Liao, S., Shao, L.: When Liebig's barrel meets facial landmark detection: a practical model. arXiv preprint arXiv:2105.13150 (2021)
13. Khalid, M.A., et al.: Aariz: a benchmark dataset for automatic cephalometric landmark detection and CVM stage classification. arXiv preprint arXiv:2302.07797 (2023)
14. Lee, D.H.: Pseudo-label: the simple and efficient semi-supervised learning method for deep neural networks. In: ICML Workshop on Challenges in Representation Learning (2013)
15. Li, J., Jin, H., Liao, S., Shao, L., Heng, P.A.: RepFormer: refinement pyramid transformer for robust facial landmark detection. In: IJCAI (2022)
16. Li, W., et al.: Structured landmark detection via topology-adapting deep graph learning. In: Vedaldi, A., Bischof, H., Brox, T., Frahm, J.-M. (eds.) ECCV 2020. LNCS, vol. 12354, pp. 266–283. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-58545-7_16
17. Li, Y.J., et al.: Cross-domain adaptive teacher for object detection. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 7581–7590 (2022)
18. Liu, W., Wang, Yu., Jiang, T., Chi, Y., Zhang, L., Hua, X.-S.: Landmarks detection with anatomical constraints for total hip arthroplasty preoperative measurements. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12264, pp. 670–679. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59719-1_65
19. Mu, J., Qiu, W., Hager, G.D., Yuille, A.L.: Learning from synthetic animals. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 12386–12395 (2020)
20. Payer, C., Štern, D., Bischof, H., Urschler, M.: Integrating spatial configuration into heatmap regression based CNNs for landmark localization. Med. Image Anal. (2019)
21. Wang, C., et al.: Pseudo-labeled auto-curriculum learning for semi-supervised keypoint localization. In: ICLR (2022)
22. Wang, C.W., et al.: A benchmark for comparison of dental radiography analysis algorithms. Med. Image Anal. (2016)
23. Xiao, B., Wu, H., Wei, Y.: Simple baselines for human pose estimation and tracking. In: Ferrari, V., Hebert, M., Sminchisescu, C., Weiss, Y. (eds.) ECCV 2018. LNCS, vol. 11210, pp. 472–487. Springer, Cham (2018). https://doi.org/10.1007/978-3-030-01231-1_29
24. Yang, W., Ouyang, W., Wang, X., Ren, J., Li, H., Wang, X.: 3D human pose estimation in the wild by adversarial learning. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 5255–5264 (2018)
25. Yang, Y., Soatto, S.: FDA: Fourier domain adaptation for semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4085–4095 (2020)
26. Zhao, S., et al.: A review of single-source deep unsupervised visual domain adaptation. IEEE Trans. Neural Netw. Learn. Syst. 33(2), 473–493 (2020)
27. Zhong, Z., Li, J., Zhang, Z., Jiao, Z., Gao, X.: An attention-guided deep regression model for landmark detection in cephalograms. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11769, pp. 540–548. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32226-7_60
MetaLR: Meta-tuning of Learning Rates for Transfer Learning in Medical Imaging
Yixiong Chen1,2, Li Liu3(B), Jingxian Li4, Hua Jiang1,2, Chris Ding1, and Zongwei Zhou5
1 The Chinese University of Hong Kong, Shenzhen, China
2 Shenzhen Research Institute of Big Data, Shenzhen, China
3 The Hong Kong University of Science and Technology (Guangzhou), Guangzhou, China
4 Software School, Fudan University, Shanghai, China
5 Johns Hopkins University, Baltimore, USA
Abstract. In medical image analysis, transfer learning is a powerful method for deep neural networks (DNNs) to generalize on limited medical data. Prior efforts have focused on developing pre-training algorithms on domains such as lung ultrasound, chest X-ray, and liver CT to bridge domain gaps. However, we find that model fine-tuning also plays a crucial role in adapting medical knowledge to target tasks. A common fine-tuning method is to manually pick transferable layers (e.g., the last few layers) to update, which is labor-intensive. In this work, we propose a meta-learning-based learning rate (LR) tuner, named MetaLR, to make different layers automatically co-adapt to downstream tasks based on their transferability across domains. MetaLR learns LRs for different layers in an online manner, preventing highly transferable layers from forgetting their medical representation abilities and driving less transferable layers to adapt actively to new domains. Extensive experiments on various medical applications show that MetaLR outperforms previous state-of-the-art (SOTA) fine-tuning strategies. The code is publicly released.
Keywords: Medical image analysis · Meta-learning · Transfer learning
1 Introduction
Transfer learning has become a standard practice in medical image analysis, as collecting and annotating data in clinical scenarios can be costly. Pre-trained parameters give DNNs better generalization than training from scratch [8,23]. A popular approach to enhancing model transferability is pre-training on domains similar to the targets [9,21,27–29]. However, specialized pre-training for every medical application is impractical due to the diversity of domains and tasks and the privacy concerns related to pre-training data. Consequently, recent work [2,6,14,22] has focused on improving the generalization of existing pre-trained DNN backbones through fine-tuning techniques.

Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_67.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 706–716, 2023. https://doi.org/10.1007/978-3-031-43907-0_67

[Fig. 1 diagram: pre-trained CONV/FC layers transferred from the Source Domain to the Target Domain, comparing "Previous: fix transferable layers" with "MetaLR: adaptive layer-wise LR".]
Fig. 1. The motivation of MetaLR. Previous works fix transferable layers in pre-trained models to prevent catastrophic forgetting, but finding the optimal scheme in this way is inflexible and labor-intensive. MetaLR uses meta-learning to automatically optimize layer-wise LRs for fine-tuning.

Previous studies have shown that lower layers are often more transferable than higher layers near the model output [26]. Layer-wise fine-tuning [23] was thus introduced to preserve transferable low-level knowledge by fixing the lower layers. However, a recent study [7] revealed that lower layers may also be sensitive to small domains such as medical images. Given these two findings, transferability for medical tasks becomes more complicated [24,25], and it can even be irregular across layers for medical domains far from the pre-training data [7]. Given the diversity of medical domains and model architectures, there is currently no universal guideline for determining whether a particular layer should be retrained for a given target domain.

To search for the optimal layer combination for fine-tuning, manually selecting transferable layers [2,23] is one solution, but it requires a significant amount of human labor and computational cost. To address this issue and make fine-tuning strategies more flexible, we propose controlling the fine-tuning process with layer-wise learning rates (LRs), rather than simply fixing or updating layers manually (see Fig. 1). Our proposed algorithm, Meta Learning Rate (MetaLR), is based on meta-learning [13] and adaptively adjusts the LR of each layer according to transfer feedback. It treats the layer-wise LRs as meta-knowledge and optimizes them to improve model generalization. Larger LRs indicate that the corresponding layers are less transferable and need more updating,
Y. Chen et al.
while smaller LRs preserve the transferable knowledge in the layers. Inspired by [20], we use an online LR adaptation strategy with a time complexity of O(n), instead of computationally expensive bi-level O(n^2) meta-learning. We also improve the algorithm's performance and stability with a proportional hyper-LR (an LR for the LR) and a validation scheme on training data batches. In summary, this work makes the following three contributions.
1) We introduce MetaLR, a meta-learning-based LR tuner that adaptively adjusts layer-wise LRs based on transfer learning feedback from various medical domains.
2) We enhance MetaLR with a proportional hyper-LR and a validation scheme using batched training data to improve the algorithm's stability and efficacy.
3) Extensive experiments on both lesion detection and liver segmentation tasks demonstrate the superior efficiency and performance of MetaLR compared to current SOTA medical fine-tuning techniques.
2 Method
This section describes the proposed MetaLR in detail. It is a meta-learning-based [13,18] approach that determines the appropriate LR for each layer based on its transfer feedback. Note that fixing transferable layers is a special case of this method, in which the fixed layers always have zero LRs. First, we present the theoretical formulation of MetaLR. Next, we discuss online adaptation for efficiently determining the optimal LRs. Finally, we show how a proportional hyper-LR and a validation scheme with batched training data enhance performance.
2.1 Formulation of Meta Learning Rate
Let $(x, y)$ denote a sample-label pair, and $\{(x_i, y_i) \mid i = 1, \dots, N\}$ be the training data. The validation dataset $\{(x_i^v, y_i^v) \mid i = 1, \dots, M\}$ is assumed to be independent and identically distributed as the training dataset. Let $\hat{y} = \Phi(x, \theta)$ be the prediction for sample $x$ from deep model $\Phi$ with parameters $\theta$. In standard DNN training, the aim is to minimize the empirical risk on the training set, $\frac{1}{N}\sum_{i=1}^{N} L(\hat{y}_i, y_i)$, with fixed training hyper-parameters, where $L(\hat{y}, y)$ is the loss function for the current task. Model generalization can be evaluated by the validation loss $\frac{1}{M}\sum_{i=1}^{M} L(\hat{y}_i^v, y_i^v)$, based on which one can tune the hyper-parameters of the training process to improve the model. The key idea of MetaLR is to treat the layer-wise LRs as self-adaptive hyper-parameters during training and to adjust them automatically to achieve better model generalization. We denote the LR and model parameters for layer $j$ at iteration $t$ as $\alpha_j^t$ and $\theta_j^t$. The LR scheduling scheme $\alpha = \{\alpha_j^t \mid j = 1, \dots, d;\ t = 1, \dots, T\}$ is what MetaLR learns; it determines which local optimum $\theta^*(\alpha)$ the model parameters $\theta^t = \{\theta_j^t \mid j = 1, \dots, d\}$ will converge to. The optimal parameters $\theta^*(\alpha)$ are obtained by optimization on the training data. At the same time, the best LR tuning scheme $\alpha^*$ can be optimized based on the feedback for $\theta^*(\alpha)$ from
MetaLR
the validation loss. This problem can be formulated as the following bi-level optimization problem:

$$\min_{\alpha} \frac{1}{M}\sum_{i=1}^{M} L\big(\Phi(x_i^v, \theta^*(\alpha)), y_i^v\big), \quad \text{s.t.}\ \theta^*(\alpha) = \arg\min_{\theta} \frac{1}{N}\sum_{i=1}^{N} L\big(\Phi(x_i, \theta), y_i\big). \qquad (1)$$

Algorithm 1. Online Meta Learning Rate Algorithm
Input: training data $D$, validation data $D^v$, initial model parameters $\{\theta_1^0, \dots, \theta_d^0\}$, LRs $\{\alpha_1^0, \dots, \alpha_d^0\}$, batch size $n$, max iteration $T$
Output: final model parameters $\theta^T = \{\theta_1^T, \dots, \theta_d^T\}$
1: for $t = 0 : T - 1$ do
2:   $\{(x_i, y_i) \mid i = 1, \dots, n\} \leftarrow$ TrainDataLoader($D$, $n$)
3:   $\{(x_i^v, y_i^v) \mid i = 1, \dots, n\} \leftarrow$ ValidDataLoader($D^v$, $n$)
4:   Step forward for one step to get $\{\hat{\theta}_1^t(\alpha_1^t), \dots, \hat{\theta}_d^t(\alpha_d^t)\}$ with Eq. (2)
5:   Update $\{\alpha_1^t, \dots, \alpha_d^t\}$ to $\{\alpha_1^{t+1}, \dots, \alpha_d^{t+1}\}$ with Eq. (3)
6:   Update $\{\theta_1^t, \dots, \theta_d^t\}$ to $\{\theta_1^{t+1}, \dots, \theta_d^{t+1}\}$ with Eq. (4)
7: end for
MetaLR uses the validation set to optimize $\alpha$ automatically rather than manually. The optimal scheme $\alpha^*$ can be found by nested optimization [13], but this is too computationally expensive in practice, so a faster and more lightweight method is needed.
2.2 Online Learning Rate Adaptation
Inspired by the online approximation [20], we propose to adapt the LRs and model parameters online for efficiency. The motivation of online LR adaptation is to update the model parameters $\theta^t$ and the LRs $\{\alpha_j^t \mid j = 1, 2, \dots, d\}$ within the same loop. We first inspect the descent direction of the parameters $\theta_j^t$ on the training loss landscape and then adjust $\alpha_j^t$ based on the transfer feedback: positive feedback (a lower validation loss) encourages the LRs to increase. We adopt stochastic gradient descent (SGD) as the optimizer for the meta-learning. The whole training process is summarized in Algorithm 1. At iteration $t$, a training data batch $\{(x_i, y_i) \mid i = 1, \dots, n\}$ and a validation data batch $\{(x_i^v, y_i^v) \mid i = 1, \dots, n\}$ are sampled, where $n$ is the batch size. First, the parameters of each layer are updated once with the current LR along the descent direction on the training batch:

$$\hat{\theta}_j^t(\alpha_j^t) = \theta_j^t - \alpha_j^t \nabla_{\theta_j}\Big(\frac{1}{n}\sum_{i=1}^{n} L\big(\Phi(x_i, \theta_j^t), y_i\big)\Big), \quad j = 1, \dots, d. \qquad (2)$$
This trial update provides feedback on the LR of each layer. Taking the derivative of the validation loss w.r.t. $\alpha_j^t$ tells us how the LR of each layer should be adjusted. The second step of MetaLR is therefore to move the LRs along the negative gradient of the meta-objective on the validation data:

$$\alpha_j^{t+1} = \alpha_j^t - \eta \nabla_{\alpha_j}\Big(\frac{1}{n}\sum_{i=1}^{n} L\big(\Phi(x_i^v, \hat{\theta}_j^t(\alpha_j^t)), y_i^v\big)\Big), \qquad (3)$$
where $\eta$ is the hyper-LR. Finally, the updated LRs are used to perform the actual gradient-descent update of the model parameters:

$$\theta_j^{t+1} = \theta_j^t - \alpha_j^{t+1} \nabla_{\theta_j}\Big(\frac{1}{n}\sum_{i=1}^{n} L\big(\Phi(x_i, \theta_j^t), y_i\big)\Big). \qquad (4)$$
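One iteration of Algorithm 1 (Eqs. (2)-(4)) can be sketched in pure Python on a toy linear model in which each scalar weight plays the role of one layer. This is an illustrative sketch, not the released MetaLR code: the model, the MSE loss, the data, and the hand-derived gradients are assumptions made to keep the example self-contained. The key step is the hypergradient of Eq. (3): since $\hat{\theta}_j^t = \theta_j^t - \alpha_j^t g_j$ with $g_j$ the training gradient, the chain rule gives $\partial L_{val} / \partial \alpha_j^t = -\nabla_{\theta_j} L_{val}(\hat{\theta}^t) \cdot g_j$. The clipping range mirrors the constraint $\alpha_j^t \in [10^{-6}, 10^{-2}]$ that MetaLR uses in practice.

```python
def mse_and_grad(w, batch):
    """MSE loss and its gradient for a toy linear model y_hat = sum_j w[j] * x[j].

    Each scalar w[j] stands in for the parameters of 'layer' j.
    """
    n = len(batch)
    loss, grad = 0.0, [0.0] * len(w)
    for x, y in batch:
        err = sum(wj * xj for wj, xj in zip(w, x)) - y
        loss += err * err / n
        for j, xj in enumerate(x):
            grad[j] += 2.0 * err * xj / n
    return loss, grad


def online_metalr_step(w, lrs, train_batch, val_batch, eta,
                       lr_min=1e-6, lr_max=1e-2):
    """One iteration of Algorithm 1 with a constant hyper-LR eta."""
    # Eq. (2): trial step with the current layer-wise LRs.
    _, g_train = mse_and_grad(w, train_batch)
    w_hat = [wj - aj * gj for wj, aj, gj in zip(w, lrs, g_train)]
    # Eq. (3): w_hat_j = w_j - a_j * g_train_j, so by the chain rule
    # dL_val/da_j = (dL_val/dw_hat_j) * (-g_train_j).
    _, g_val = mse_and_grad(w_hat, val_batch)
    new_lrs = []
    for aj, gv, gt in zip(lrs, g_val, g_train):
        hypergrad = -gv * gt
        aj = aj - eta * hypergrad
        new_lrs.append(min(max(aj, lr_min), lr_max))  # keep LR in range
    # Eq. (4): actual update of the parameters with the new LRs.
    new_w = [wj - aj * gj for wj, aj, gj in zip(w, new_lrs, g_train)]
    return new_w, new_lrs


if __name__ == "__main__":
    # Hypothetical data generated by y = 3 * x1 - 2 * x2.
    data = [((1.0, 0.0), 3.0), ((0.0, 1.0), -2.0),
            ((1.0, 1.0), 1.0), ((2.0, 1.0), 4.0)]
    w, lrs = [0.0, 0.0], [1e-3, 1e-3]
    for t in range(300):
        w, lrs = online_metalr_step(w, lrs, data, data, eta=1e-4)
    print("weights:", w, "layer-wise LRs:", lrs)
```

When the training and validation gradients of a layer point in the same direction, the hypergradient is negative and that layer's LR grows; when they disagree, the LR shrinks. For simplicity the demo reuses one batch for both roles. In a real network the product becomes a dot product over all parameters of the layer, and automatic differentiation would replace the hand-written gradients.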
For practical use, we constrain the LR of each layer to $\alpha_j^t \in [10^{-6}, 10^{-2}]$. Online MetaLR optimizes the layer-wise LRs together with the training objective on a single task, which distinguishes it from traditional meta-learning algorithms [12,19] that train models on multiple small tasks.
2.3 Proportional Hyper Learning Rate
In practice, LRs are often tuned on an exponential scale (e.g., 1e−3, 3e−3, 1e−2) and are always positive. However, a constant hyper-LR updates the corresponding LR linearly, regardless of these numerical constraints. This can cause the LR to fluctuate, or even to drop below 0 and be truncated. To address this issue, we propose a proportional hyper-LR $\eta = \beta \times \alpha_j^t$, where $\beta$ is a pre-defined hyper-parameter. This allows us to rewrite Eq. (3) as:

$$\alpha_j^{t+1} = \alpha_j^t \Big(1 - \beta \nabla_{\alpha_j}\Big(\frac{1}{n}\sum_{i=1}^{n} L\big(\Phi(x_i^v, \hat{\theta}_j^t(\alpha_j^t)), y_i^v\big)\Big)\Big). \qquad (5)$$
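The difference between the constant hyper-LR of Eq. (3) and the proportional hyper-LR of Eq. (5) is visible on a single update. In this hedged sketch, `hypergrad` stands for the bracketed gradient $\nabla_{\alpha_j}(\cdot)$ in both equations, and the numbers are arbitrary illustrative values rather than settings from the paper.

```python
def constant_hyper_lr_update(lr, hypergrad, eta):
    """Eq. (3): additive step with a fixed hyper-LR eta."""
    return lr - eta * hypergrad


def proportional_hyper_lr_update(lr, hypergrad, beta):
    """Eq. (5): eta = beta * lr, so the step is a relative change of lr."""
    return lr * (1.0 - beta * hypergrad)


if __name__ == "__main__":
    for lr in (1e-4, 1e-2):
        print(f"lr={lr:g}",
              f"constant: {constant_hyper_lr_update(lr, 2.0, eta=1e-4):.3g}",
              f"proportional: {proportional_hyper_lr_update(lr, 2.0, beta=0.1):.3g}")
```

With the constant hyper-LR, the same hypergradient removes an absolute amount of $2 \times 10^{-4}$ from every LR, which drives an LR of $10^{-4}$ below zero (so it would have to be truncated) but changes an LR of $10^{-2}$ by only 2%. With the proportional hyper-LR, both LRs shrink by the same 20%. The multiplicative form stays positive as long as $\beta$ times the hypergradient is below 1.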
The multiplicative (exponential-style) update makes each change of $\alpha_j^t$ proportional to its current value, which keeps the LR numerically stable.
2.4 Generalizability Validation on Training Data Batch
One limitation of MetaLR is that the LRs are updated using separate validation data, which reduces the amount of data available for training. This can be particularly problematic for medical transfer learning, where downstream data are already limited. In Eq. (2) and Eq. (3), the model parameters $\theta_j^t$ and the LR $\alpha_j^t$ are updated using different datasets, so that the updated $\theta_j^t$ can be evaluated for generalization without being influenced by data it has already seen. As an alternative, albeit weaker, approach, we explore using another batch of training data in Eq. (3) to evaluate generalization. Since this batch was not used in the update of Eq. (2), it may still serve well as validation data for meta-learning. The effect of this approach is verified in Sect. 3.2, and the differences between the two schemes are analyzed in Sect. 3.4.
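The training-batch validation scheme only changes where the batch for Eq. (3) comes from. A minimal sketch, assuming an indexable training set and a seeded `random.Random` generator (the function name is hypothetical): two disjoint batches are drawn per iteration, the first for Eqs. (2) and (4) and the second in place of the validation batch in Eq. (3).

```python
import random


def sample_train_and_meta_batches(dataset, n, rng):
    """Draw two disjoint batches of size n from the training set.

    The first batch drives the trial step (Eq. (2)) and the real update
    (Eq. (4)); the second replaces the validation batch in Eq. (3).
    """
    if 2 * n > len(dataset):
        raise ValueError("need at least 2 * n training samples")
    idx = rng.sample(range(len(dataset)), 2 * n)  # unique indices
    train_batch = [dataset[i] for i in idx[:n]]
    meta_batch = [dataset[i] for i in idx[n:]]
    return train_batch, meta_batch


if __name__ == "__main__":
    rng = random.Random(0)
    train, meta = sample_train_and_meta_batches(list(range(100)), 8, rng)
    print(train, meta)
```

Because the second batch has not been used in the trial step of Eq. (2), it still provides feedback on unseen data, while no samples are permanently withheld from training.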
3 Experiments and Analysis
3.1 Experimental Settings
We extensively evaluate MetaLR on four transfer learning tasks (see Table 1). To ensure reproducibility, all pre-trained models (USCL [9], ImageNet [11], C2L [28], Models Genesis [29]) and target datasets (POCUS [5], BUSI [1], Chest X-ray [17], LiTS [4]) are publicly available. We consider models pre-trained on both natural and medical image datasets, with three target modalities and three target organs, which makes our experimental results more credible. For the lesion detection tasks, we use ResNet-18 [15] with the Adam optimizer. The initial LR and the hyper-LR coefficient $\beta$ are set to $10^{-3}$ and 0.1, respectively, and 25% of the training set is used as the validation set for meta-learning. For the segmentation task, we use 3D U-Net [10] with the SGD optimizer. The initial LR and the hyper-LR coefficient $\beta$ are set to $10^{-2}$ and $3 \times 10^{-3}$, respectively. The validation set for LiTS segmentation comprises 23 samples from the 111-sample training set. All experiments are implemented using PyTorch 1.10 on an Nvidia RTX A6000 GPU. We report the mean and standard deviation of each experiment over five random seeds. For more details on the models and hyper-parameters, please refer to our supplementary material.

Table 1. Pre-training data, algorithms, and target tasks.
Source           Pre-train Method      Target             Organ    Modality   Task                  Size
US-4 [9]         USCL [9]              POCUS [5]          Lung     US         COVID-19 detection    2116 images
ImageNet [11]    supervised            BUSI [1]           Breast   US         Tumor detection       780 images
MIMIC-CXR [16]   C2L [28]              Chest X-ray [17]   Lung     X-ray      Pneumonia detection   5856 images
LIDC-IDRI [3]    Models Genesis [29]   LiTS [4]           Liver    CT         Liver segmentation    131 volumes

3.2 Ablation Study
To evaluate the effectiveness of our proposed method, we conduct an ablation study w.r.t. the basic MetaLR algorithm, the proportional hyper-LR, and batched-training-data validation (Table 2). When applying only the basic MetaLR, we observe only marginal performance improvements on the four downstream tasks. We conjecture that there are two reasons for this. First, the constant hyper-LR makes the training procedure less stable than direct training, as evidenced by the larger standard deviation of performance. Second, part of the training data is split off for validation, which can be detrimental to performance. After applying the proportional hyper-LR, both the performance and its stability improve significantly. Moreover, although generalization validation on the training data batch may introduce bias, providing sufficient training data ultimately benefits performance.
Table 2. Ablation study for MetaLR, hyper-LR, and validation data. The baseline is the direct tuning of all layers with constant LRs. The default setting for MetaLR is a constant hyper-LR of $10^{-3}$ and a separate validation set. Scores are accuracy (%) for POCUS, BUSI, and Chest X-ray and Dice (%) for LiTS.
MetaLR   Prop. hyper-LR   Val. on trainset   POCUS        BUSI         Chest X-ray   LiTS
✗        ✗                ✗                  91.6 ± 0.8   84.4 ± 0.7   94.8 ± 0.3    93.1 ± 0.4
✓        ✗                ✗                  91.9 ± 0.6   84.9 ± 1.3   95.0 ± 0.4    93.2 ± 0.8
✓        *                *                  93.6 ± 0.4   85.2 ± 0.8   95.3 ± 0.2    93.3 ± 0.6
✓        *                *                  93.0 ± 0.3   86.3 ± 0.7   95.5 ± 0.2    93.9 ± 0.5
✓        ✓                ✓                  93.9 ± 0.4   86.7 ± 0.7   95.8 ± 0.3    94.2 ± 0.5
* Each of these two rows enables exactly one of Prop. hyper-LR and Val. on trainset; which one is not recoverable from the source.
• Final MetaLR outperforms the baseline with p-values of 0.0014, 0.0016, 0.0013, and 0.0054.
Table 3. Comparative experiments on lesion detection. We report the sensitivities (%) of the abnormalities, overall accuracy (%), and training time on each task.
POCUS         COVID        Pneu.        Acc          Time
Last Layer    77.9 ± 2.1   84.0 ± 1.3   84.1 ± 0.2   15.8 m
All Layers    85.8 ± 1.7   90.0 ± 1.9   91.6 ± 0.8   16.0 m
Layer-wise    87.5 ± 1.0   92.3 ± 1.3   92.1 ± 0.3   2.4 h
Bi-direc      90.1 ± 1.2   92.5 ± 1.5   93.6 ± 0.2   12.0 h
AutoLR        89.8 ± 1.6   89.7 ± 1.5   90.4 ± 0.8   17.5 m
MetaLR        94.8 ± 1.2   93.1 ± 1.5   93.9 ± 0.4   24.8 m

BUSI          Benign       Malignant    Acc          Time
Last Layer    83.5 ± 0.4   47.6 ± 4.4   66.8 ± 0.5   4.4 m
All Layers    90.4 ± 1.5   77.8 ± 3.5   84.4 ± 0.7   4.3 m
Layer-wise    90.8 ± 1.2   75.7 ± 2.6   85.6 ± 0.4   39.0 m
Bi-direc      92.2 ± 1.0   77.1 ± 3.5   86.5 ± 0.5   3.2 h
AutoLR        90.4 ± 1.8   76.2 ± 3.2   84.9 ± 0.8   4.9 m
MetaLR        92.2 ± 0.7   75.6 ± 3.6   86.7 ± 0.7   6.0 m

Chest X-ray   Pneu.        Acc          Time
Last Layer    99.7 ± 1.3   87.8 ± 0.6   12.7 m
All Layers    98.8 ± 0.2   94.8 ± 0.3   12.9 m
Layer-wise    97.9 ± 0.3   95.2 ± 0.2   1.9 h
Bi-direc      98.4 ± 0.3   95.4 ± 0.1   9.7 h
AutoLR        95.4 ± 0.5   93.0 ± 0.8   13.3 m
MetaLR        97.4 ± 0.4   95.8 ± 0.3   26.3 m

3.3 Comparative Experiments
In our study, we compare MetaLR with several other fine-tuning schemes, including tuning only the last layer or all layers with constant LRs, layer-wise fine-tuning [23], bi-directional fine-tuning [7], and AutoLR [22]. The U-Net fine-tuning scheme proposed by Amiri et al. [2] is also evaluated.

Results on Lesion Detection Tasks. MetaLR consistently shows the best performance on all downstream tasks (Table 3). It improves accuracy by 1%–2.3% over direct training (i.e., tuning all layers) because it takes into account the different transferabilities of different layers. Manual picking methods, such as layer-wise and bi-directional fine-tuning, also achieve higher performance, but they require much more training time (5×–50×) to search for the best tuning scheme. AutoLR is efficient, but its strong assumptions sometimes harm its performance. In contrast, MetaLR makes no assumption about transferability and learns appropriate layer-wise LRs on different domains. Moreover, its performance improvements come with only 1.5×–2.5× the training time of direct training.

Results on Segmentation Task. MetaLR achieves the best Dice score on the LiTS segmentation task (Table 4). Unlike ResNet for lesion detection, the U-Net family has a more complex network topology. Because of the skip connections, there are two interpretations [2] of layer depth: 1) the left-most layers are the shallowest, or 2) the top layers of the "U" are the shallowest. This makes the hand-picking methods even more computationally expensive. In contrast, MetaLR updates the LR of each layer according to its validation gradient, so its training efficiency is not affected by the complex model architecture.

Table 4. Comparative experiments on the LiTS liver segmentation task.
Method              PPV          Sensitivity   Dice         Time
Last Layer          26.1 ± 5.5   71.5 ± 4.2    33.5 ± 3.4   2.5 h
All Layers          94.0 ± 0.6   93.1 ± 0.7    93.1 ± 0.4   2.6 h
Layer-wise          92.1 ± 1.3   96.4 ± 0.4    93.7 ± 0.3   41.6 h
Bi-direc            92.4 ± 1.1   96.1 ± 0.2    93.8 ± 0.1   171.2 h
Amiri et al. [2]    92.7 ± 1.2   93.2 ± 0.5    92.4 ± 0.5   2.6 h
MetaLR              94.4 ± 0.9   93.6 ± 0.4    94.2 ± 0.5   5.8 h

Fig. 2. The LR curves for MetaLR on POCUS detection (a), on LiTS segmentation (b), with a constant hyper-LR (c), and with a separate validation set (d).

3.4 Discussion and Findings
The LRs Learned with MetaLR. For ResNet-18 (Fig. 2(a)), the layer-wise LRs fluctuate drastically during the first 100 iterations. After iteration 100, however, all layers except the first layer "Conv1" stabilize at different levels. The LR of the first layer decreases throughout training (from 2.8 × 10⁻³ to 3 × 10⁻⁴), reflecting its higher transferability. For 3D U-Net (Fig. 2(b)), the middle layers of the encoder, "Down-128" and "Down-256", are the most transferable and have the lowest LRs, which previous fine-tuning schemes would find difficult to discover. As expected, the randomly initialized "FC" and "Out" layers have the largest LRs, since they are not transferable.

The Effectiveness of Proportional Hyper-LR and Training Batches Validation. Figure 2(c) shows the LR curves obtained with a constant hyper-LR instead of a proportional one. The LR curves of "Block 3-1" and "Block 4-2" fluctuate much more, and this instability may be the key reason for the unstable performance observed with a constant hyper-LR. Furthermore, we find, surprisingly, that when a separate validation set is used (Fig. 2(d)), the learned LR curves are similar to those learned when validating on the training set. Given the similar LR curves and the additional training data it leaves available, batched training-set validation is a reasonable and effective alternative to basic MetaLR.

Limitations of MetaLR. Although MetaLR improves fine-tuning for medical image analysis, it has several limitations. First, the gradient descent step of Eq. (3) takes more memory than the usual fine-tuning strategy, which may restrict the batch size available during training. Second, MetaLR sometimes does not reach converged LRs after the parameters have converged, which may harm its performance in some cases. Third, MetaLR is designed for medical fine-tuning rather than for general cases, so the problems it may encounter in other scenarios remain unknown.
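The alternating scheme discussed here (update each layer's LR from a validation gradient, then update the parameters, with a hyper-LR proportional to the current LR) can be illustrated with a toy sketch. It is not the paper's Eq. (3), which is not reproduced in this section; it is a hypergradient-style approximation under our own assumptions: each "layer" is a single scalar with a quadratic loss, validation reuses the training objective (mirroring training-batch validation), and the names `online_layerwise_lr_step`, `CURV`, `TARGET` and the value of `gamma` are hypothetical.

```python
from typing import Callable, List, Tuple

Vector = List[float]
GradFn = Callable[[Vector], Vector]


def online_layerwise_lr_step(theta: Vector, lrs: Vector, grad_train: GradFn,
                             grad_val: GradFn, gamma: float,
                             min_lr: float = 1e-8) -> Tuple[Vector, Vector]:
    """One online step: update each layer's LR by a hypergradient, then the parameters.

    Each entry of `theta` stands in for one layer's parameters (a scalar here).
    """
    g_tr = grad_train(theta)
    # Virtual SGD step with the current per-layer LRs.
    theta_virtual = [t - a * g for t, a, g in zip(theta, lrs, g_tr)]
    g_val = grad_val(theta_virtual)
    new_lrs = []
    for a, gt, gv in zip(lrs, g_tr, g_val):
        hypergrad = -gv * gt      # d L_val(theta - a * g_tr) / d a for this layer
        hyper_lr = gamma * a      # assumption: hyper-LR proportional to the current LR
        new_lrs.append(max(a - hyper_lr * hypergrad, min_lr))  # keep LRs positive
    # Real SGD step with the updated per-layer LRs.
    new_theta = [t - a * g for t, a, g in zip(theta, new_lrs, g_tr)]
    return new_theta, new_lrs


# Hypothetical toy problem: two "layers", each with loss 0.5 * k * (w - c)^2.
CURV = [1.0, 10.0]
TARGET = [3.0, -2.0]


def toy_loss(theta: Vector) -> float:
    return sum(0.5 * k * (w - c) ** 2 for k, w, c in zip(CURV, theta, TARGET))


def toy_grad(theta: Vector) -> Vector:
    return [k * (w - c) for k, w, c in zip(CURV, theta, TARGET)]


theta, lrs = [0.0, 0.0], [0.01, 0.01]
initial_loss = toy_loss(theta)
for _ in range(200):
    # Validate on the training objective itself (training-batch validation).
    theta, lrs = online_layerwise_lr_step(theta, lrs, toy_grad, toy_grad, gamma=1e-3)
final_loss = toy_loss(theta)
```

On this toy problem both LRs grow from their initial 0.01, the stiffer second "layer" ending with the larger LR, and the loss decreases; the clamp to `min_lr` only guards against an LR turning negative.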
4 Conclusion
In this work, we proposed MetaLR, a new fine-tuning scheme for medical transfer learning. It achieves significantly better performance than previous SOTA fine-tuning algorithms. MetaLR alternately optimizes the model parameters and the layer-wise LRs in an online meta-learning fashion with a proportional hyper-LR. It learns to assign lower LRs to layers with higher transferability and higher LRs to less transferable layers. The proposed algorithm is easy to implement and has the potential to replace manual layer-wise fine-tuning schemes. Future work includes adapting MetaLR to a wider variety of clinical tasks.

Acknowledgement. This work was supported by the National Natural Science Foundation of China (No. 62101351) and the GuangDong Basic and Applied Basic Research Foundation (No. 2020A1515110376).
References
1. Al-Dhabyani, W., Gomaa, M., Khaled, H., Fahmy, A.: Dataset of breast ultrasound images. Data Brief 28, 104863 (2020)
2. Amiri, M., Brooks, R., Rivaz, H.: Fine-tuning U-Net for ultrasound image segmentation: different layers, different outcomes. IEEE Trans. Ultrason. Ferroelectr. Freq. Control 67(12), 2510–2518 (2020)
3. Armato, S.G., III, McLennan, G., Bidaut, L., et al.: The lung image database consortium (LIDC) and image database resource initiative (IDRI): a completed reference database of lung nodules on CT scans. Med. Phys. 38(2), 915–931 (2011)
4. Bilic, P., et al.: The liver tumor segmentation benchmark (LiTS). arXiv preprint arXiv:1901.04056 (2019)
5. Born, J., Wiedemann, N., Cossio, M., et al.: Accelerating detection of lung pathologies with explainable ultrasound image analysis. Appl. Sci. 11(2), 672 (2021)
6. Chambon, P., Cook, T.S., Langlotz, C.P.: Improved fine-tuning of in-domain transformer model for inferring COVID-19 presence in multi-institutional radiology reports. J. Digit. Imaging, 1–14 (2022)
7. Chen, Y., Li, J., Ding, C., Liu, L.: Rethinking two consensuses of the transferability in deep learning. arXiv preprint arXiv:2212.00399 (2022)
8. Chen, Y., Zhang, C., Ding, C.H., Liu, L.: Generating and weighting semantically consistent sample pairs for ultrasound contrastive learning. IEEE TMI (2022)
9. Chen, Y., et al.: USCL: pretraining deep ultrasound image diagnosis model through video contrastive representation learning. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12908, pp. 627–637. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87237-3_60
10. Çiçek, Ö., Abdulkadir, A., Lienkamp, S.S., Brox, T., Ronneberger, O.: 3D U-Net: learning dense volumetric segmentation from sparse annotation. In: Ourselin, S., Joskowicz, L., Sabuncu, M.R., Unal, G., Wells, W. (eds.) MICCAI 2016. LNCS, vol. 9901, pp. 424–432. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46723-8_49
11. Deng, J., Dong, W., Socher, R., Li, L.J., Li, K., Fei-Fei, L.: ImageNet: a large-scale hierarchical image database. In: CVPR, pp. 248–255 (2009)
12. Finn, C., Abbeel, P., Levine, S.: Model-agnostic meta-learning for fast adaptation of deep networks. In: ICML, pp. 1126–1135. PMLR (2017)
13. Franceschi, L., Frasconi, P., Salzo, S., Grazzi, R., Pontil, M.: Bilevel programming for hyperparameter optimization and meta-learning. In: ICML, pp. 1568–1577. PMLR (2018)
14. Guo, Y., Shi, H., Kumar, A., Grauman, K., Rosing, T., Feris, R.: SpotTune: transfer learning through adaptive fine-tuning. In: CVPR, pp. 4805–4814 (2019)
15. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: CVPR, pp. 770–778 (2016)
16. Johnson, A.E., et al.: MIMIC-CXR-JPG, a large publicly available database of labeled chest radiographs. arXiv preprint arXiv:1901.07042 (2019)
17. Kermany, D., Zhang, K., Goldbaum, M.: Large dataset of labeled optical coherence tomography (OCT) and chest X-ray images. Mendeley Data 3, 10–17632 (2018)
18. Li, Z., Zhou, F., Chen, F., Li, H.: Meta-SGD: learning to learn quickly for few-shot learning. arXiv preprint arXiv:1707.09835 (2017)
19. Nichol, A., Achiam, J., Schulman, J.: On first-order meta-learning algorithms. arXiv preprint arXiv:1803.02999 (2018)
20. Ren, M., Zeng, W., Yang, B., Urtasun, R.: Learning to reweight examples for robust deep learning. In: ICML, pp. 4334–4343. PMLR (2018)
21. Riasatian, A., et al.: Fine-tuning and training of DenseNet for histopathology image representation using TCGA diagnostic slides. Med. Image Anal. 70, 102032 (2021)
22. Ro, Y., Choi, J.Y.: AutoLR: layer-wise pruning and auto-tuning of learning rates in fine-tuning of deep networks. In: AAAI, vol. 35, pp. 2486–2494 (2021)
23. Tajbakhsh, N., et al.: Convolutional neural networks for medical image analysis: full training or fine tuning? IEEE TMI 35(5), 1299–1312 (2016)
24. Vrbančič, G., Podgorelec, V.: Transfer learning with adaptive fine-tuning. IEEE Access 8, 196197–196211 (2020)
25. Wang, G., et al.: Interactive medical image segmentation using deep learning with image-specific fine tuning. IEEE TMI 37(7), 1562–1573 (2018)
26. Yosinski, J., Clune, J., Bengio, Y., Lipson, H.: How transferable are features in deep neural networks? In: NeurIPS, vol. 27 (2014)
27. Zhang, C., Chen, Y., Liu, L., Liu, Q., Zhou, X.: HiCo: hierarchical contrastive learning for ultrasound video model pretraining. In: Wang, L., Gall, J., Chin, T.J., Sato, I., Chellappa, R. (eds.) ACCV 2022. LNCS, vol. 13846, pp. 229–246. Springer, Cham (2022). https://doi.org/10.1007/978-3-031-26351-4_1
28. Zhou, H.-Y., Yu, S., Bian, C., Hu, Y., Ma, K., Zheng, Y.: Comparing to learn: surpassing ImageNet pretraining on radiographs by comparing image representations. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12261, pp. 398–407. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59710-8_39
29. Zhou, Z., et al.: Models genesis: generic autodidactic models for 3D medical image analysis. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11767, pp. 384–393. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32251-9_42
Multi-Target Domain Adaptation with Prompt Learning for Medical Image Segmentation

Yili Lin¹, Dong Nie³, Yuting Liu², Ming Yang², Daoqiang Zhang¹, and Xuyun Wen¹(✉)

¹ Department of Computer Science and Technology, Nanjing University of Aeronautics and Astronautics, Nanjing, China
² Department of Radiology, Children's Hospital of Nanjing Medical University, Nanjing, China
³ Alibaba Inc, El Monte, USA
Abstract. Domain shift is a major challenge when deploying deep learning models in real-world applications, because data distributions vary. Recent advances in domain adaptation mainly come from explicitly learning domain-invariant features (e.g., by adversarial learning, metric learning and self-training). However, these methods cannot easily be extended to multiple domains because of the diverse domain knowledge involved. In this paper, we present a novel multi-target domain adaptation (MTDA) algorithm, prompt-DA, which performs implicit feature adaptation for medical image segmentation. In particular, we build a feature transfer module by obtaining domain-specific prompts and using them to generate domain-aware image features via a specially designed, simple feature fusion module. Moreover, prompt-DA is compatible with previous DA methods (e.g., adversarial learning based ones), and combining them further improves performance. The proposed method is evaluated on two challenging domain-shift datasets: iSeg2019 (domain shift in infant MRI of different ages) and BraTS2018 (domain shift between high-grade and low-grade gliomas). Experimental results indicate that our method achieves state-of-the-art performance in both cases and demonstrate the effectiveness of prompt-DA. Experiments with adversarial learning DA show that prompt-DA combines well with other DA methods. Our code is available at https://github.com/MurasakiLin/prompt-DA.
Keywords: Domain Adaptation · Prompt Learning · Segmentation
Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_68.

© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 717–727, 2023.
https://doi.org/10.1007/978-3-031-43907-0_68
1 Introduction
Deep learning has brought medical image segmentation into the era of data-driven approaches and has made significant progress in this field, considerably improving segmentation accuracy [1,2]. Despite this success, the deployment of trained segmentation models is often severely impacted by a distribution shift between the training (labeled) and test (unlabeled) data, since segmentation performance deteriorates greatly in such situations. Domain shift is typically caused by various factors, including differences in acquisition protocols (e.g., parameters, imaging methods, modalities) and characteristics of the data (e.g., age, gender, disease severity). Domain adaptation (DA) has been proposed and investigated to combat distribution shift in medical image segmentation. Many researchers have proposed using adversarial learning to tackle distribution shift [3–7]. These methods mainly use the game between a domain classifier and a feature extractor to learn domain-invariant features. However, they often struggle to balance feature alignment against the discriminative ability of the model. More recently, researchers have begun to explore self-training based DA algorithms, which generate pseudo labels for samples from the 'other' domain to perform self-training [8–11]. However, it is very difficult to ensure the quality of pseudo labels in the 'other' domain, and it is also hard to build capable models with noisy labels. Moreover, most of these methods cannot handle very diverse domains well, since it is very challenging to learn domain-invariant features when each domain contains domain-specific knowledge. In addition, the domain information itself is not well utilized by existing DA algorithms. To tackle these issues, we propose using prompt learning to take full advantage of domain information.
Prompt learning [12,13] is a recently emerged strategy for extending the same natural language processing (NLP) model to different tasks without re-training. Prompt learning models can autonomously tune themselves for different tasks by transferring the domain knowledge introduced through prompts, and they usually demonstrate better generalization across many downstream tasks. Only a few works have applied prompt learning to computer vision, but they have achieved promising results. For example, [14] introduced prompt tuning as an efficient and effective alternative to full fine-tuning for large-scale Transformer models, and [15] exploited prompt learning for domain generalization in image classification. The prompts in these models are generated and used at a very early stage of the model, which prevents a smooth combination with other domain adaptation methods.

In this paper, we introduce a domain prompt learning method (prompt-DA) to tackle distribution shift across multiple target domains. Unlike recent prompt learning methods, we generate domain-specific prompts in the encoded feature space instead of the image space. This can improve the quality of the domain prompts and, more importantly, lets us easily combine prompt learning with other DA methods, for instance, adversarial learning based DA. In addition, we propose a specially designed fusion module that reinforces the respective characteristics of the encoder features and the domain-specific prompts, and thus generates domain-aware features. To show that prompt-DA is compatible with other DA methods, we jointly adopt a very simple adversarial learning module in our method to further enhance the model's generalization ability (we denote this model as comb-DA). We evaluate the proposed method on two multi-domain datasets: (1) an infant brain MRI dataset for cross-age segmentation; and (2) the BraTS2018 dataset for cross-grade tumor segmentation. Experiments show that our method outperforms state-of-the-art methods. Moreover, an ablation study demonstrates the effectiveness of the proposed domain prompt learning and of the feature fusion module. The experiments also support our claim that prompt learning can be successfully combined with adversarial learning.
2 Methodology
As depicted in Fig. 1(a), the proposed prompt-DA network consists of three main components: a typical encoder-decoder network (e.g., UNet) serving as the segmentation baseline; a prompt learning network that learns domain-specific prompts; and a fusion module that aggregates the image features and the domain-specific prompts into a domain-aware feature representation, which is fed into the decoder. Notably, our method is compatible with other DA algorithms, so an optional extra DA module can be added to further improve domain generalization. In this paper, we choose adversarial learning based DA as an example, since it is the most widely used DA approach in medical image segmentation (as introduced in Sect. 1). Many encoder-decoder segmentation networks are well known, so we do not describe their details; instead, we choose two typical networks as the segmentation backbone: 3D-UNet [16] (convolution-based, 3D) and TransUNet [17] (transformer-based, 2D). The following subsections focus on domain-specific prompt generation and on domain-aware feature learning with feature fusion.

2.1 Learning Domain-Specific Prompts
In our prompt learning based DA method, learning domain-specific prompts is essential, and the quality of the generated prompts directly determines the quality of the domain-aware features. We therefore designed a prompt generation module that learns domain-specific prompts and consists mainly of two components: a classifier and a prompt generator. Our approach incorporates domain-specific information into the prompts to guide the model in adapting to the target domain. To achieve this, we introduce a classifier h(x) that predicts the domain (denoted as $\hat{d}$) of the input image, as shown in Eq. 1:

$$\hat{d} = h(x), \qquad (1)$$

where x is the image or abstracted features from the encoder.
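Eq. 1 can be made concrete with a small sketch, assuming h ends with a softmax over C domains so that $\hat{d}$ is a probability vector. The logits below are hypothetical stand-ins for the output of h(x), and the block also computes the cross-entropy loss against a one-hot domain label, as used to train the classifier in Eq. 2; the function names are ours.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax: turns classifier logits into domain probabilities d_hat."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def domain_ce_loss(d_hat: List[float], d: List[float], eps: float = 1e-12) -> float:
    """Cross-entropy -sum_i d^(i) * log(d_hat^(i)) between predicted and true domain vectors."""
    return -sum(di * math.log(max(pi, eps)) for di, pi in zip(d, d_hat))


# Hypothetical logits from h(x) for C = 3 domains (e.g., 6-, 12- and 3-month-old scans).
d_hat = softmax([2.0, 0.5, -1.0])
d = [1.0, 0.0, 0.0]  # one-hot ground-truth domain
loss = domain_ce_loss(d_hat, d)
```

With a one-hot label, the loss reduces to the negative log-probability of the true domain, so uniform predictions over three domains cost log 3.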
[Figure 1 panels: (a) overview; (b) prompt generation (prompt-DA module); (c) feature fusion (AFusion module)]
Fig. 1. (a) Overview of the proposed prompt-DA network; (b) the prompt generation module that learns domain-specific prompts; (c) the feature fusion module that learns domain-aware features.
To optimize its parameters, we train the classifier with the cross-entropy loss, as shown in Eq. 2:

$$L_{cls}(\hat{d}, d) = L_{ce}(\hat{d}, d) = -\sum_{i=1}^{C} d^{(i)} \log \hat{d}^{(i)}, \qquad (2)$$

where $\hat{d}$ is the predicted domain information, d is the ground-truth domain information, and C is the number of domains.

Prompt Generation: Instead of directly using $\hat{d}$ as the category information, we feed the second-to-last layer's features of the classifier (i.e., z) to a prompt generator g(z). In particular, g(z) is a three-layer network, as defined in Eq. 3:

$$g(z) = \phi_3(\phi_2(\phi_1(z))), \qquad (3)$$

where each $\phi$ can be a Conv+BN+ReLU sequence. Note that this module does not change the size of the feature map; instead, it transforms the extracted category features into domain-specific prompts.

2.2 Learning Domain-Aware Representation by Fusion
The learned prompt clearly captures information about a certain domain, while the features from the encoder describe the semantic and spatial information of the image. We can combine the two to adapt the image features into domain-aware representations.
Suppose we have an image I whose domain-knowledge prompt encoding is g(e(I)), where e(I) denotes features from a shallow layer, and let E(I) denote the encoder features of this image. The domain-aware features F are then extracted by a fusion module, as in Eq. 4:

$$F = \psi(g(e(I)), E(I)). \qquad (4)$$
Because the learned prompt and the encoder feature capture quite different aspects of the input data, simple addition, multiplication, or concatenation does not work well as the fusion function ψ. Specifically, the encoder feature emphasizes spatial information for image segmentation, whereas the prompt feature highlights inter-channel information about domain-related characteristics. To account for these differences, we propose a simple attention-based fusion module (denoted AFusion) that aggregates the information smoothly. The module computes channel-wise and spatial-wise weights separately to enhance both the channel and the spatial characteristics of the input; Fig. 1(c) illustrates its structure. The spatial branch compresses the encoder feature along the channel dimension with an FC layer to obtain spatial weights, while the channel branch applies global average pooling and two FC layers to the prompt to obtain channel weights. Denoting the FC layers used for compression and rescaling by $f_{cp}$ and $f_{re}$, respectively, the spatial and channel weights are computed as in Eq. 5:

$$W_s = f_{cp}(E(I)), \qquad W_c = f_{re}(f_{cp}(\mathrm{avgpool}(g(e(I))))). \qquad (5)$$
We then add the channel and spatial weights to obtain a joint weight W that draws on both high-level and low-level features from the encoder feature and the prompt, and use it to guide the fusion of the two features:

$$W = \mathrm{sigmoid}(W_c + W_s), \qquad F_{out} = g(e(I)) \ast W + E(I) \ast (1 - W), \qquad (6)$$

where $\ast$ denotes element-wise multiplication, with $W_c$ and $W_s$ broadcast to the shape of the features.
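Eqs. 5 and 6 can be traced on a tiny example with the pure-Python sketch below. It is an illustration under our own assumptions, not the authors' implementation: features are nested lists indexed `[channel][row][col]`, the spatial $f_{cp}$ is applied per pixel across channels, the two $f_{cp}$ layers have separate weights, no activation is placed between $f_{cp}$ and $f_{re}$, and all weights are hand-picked and hypothetical (a real model would learn them in a deep learning framework).

```python
import math
from typing import List

Matrix = List[List[float]]
Tensor3 = List[List[List[float]]]  # indexed as [channel][row][col]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def fc(vec: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """Fully connected layer: out[j] = sum_i weight[j][i] * vec[i] + bias[j]."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]


def afusion(prompt: Tensor3, enc: Tensor3, w_sp: Matrix, b_sp: List[float],
            w_cp: Matrix, b_cp: List[float], w_re: Matrix, b_re: List[float]) -> Tensor3:
    """Fuse a prompt g(e(I)) and an encoder feature E(I) of identical shape [C][H][W]."""
    n_ch, height, width = len(enc), len(enc[0]), len(enc[0][0])
    # W_s (Eq. 5): per pixel, an FC layer compresses the C encoder channels to one value.
    w_s = [[fc([enc[c][i][j] for c in range(n_ch)], w_sp, b_sp)[0]
            for j in range(width)] for i in range(height)]
    # W_c (Eq. 5): global average pooling of the prompt, FC compression, FC rescaling to C.
    pooled = [sum(sum(row) for row in prompt[c]) / (height * width) for c in range(n_ch)]
    w_c = fc(fc(pooled, w_cp, b_cp), w_re, b_re)
    # Eq. 6: broadcast W_c over space and W_s over channels, squash, then blend the inputs.
    fused = []
    for c in range(n_ch):
        plane = []
        for i in range(height):
            row = []
            for j in range(width):
                gate = sigmoid(w_c[c] + w_s[i][j])
                row.append(prompt[c][i][j] * gate + enc[c][i][j] * (1.0 - gate))
            plane.append(row)
        fused.append(plane)
    return fused


# Hypothetical 2-channel, 2x2 inputs; both branches compress C = 2 to a single value.
enc = [[[1.0, 2.0], [3.0, 4.0]], [[0.0, -1.0], [0.5, 2.0]]]
prompt = [[[0.2, 0.2], [0.2, 0.2]], [[1.0, 1.0], [1.0, 1.0]]]
fused = afusion(prompt, enc,
                w_sp=[[0.5, -0.5]], b_sp=[0.0],
                w_cp=[[0.5, 0.5]], b_cp=[0.0],
                w_re=[[1.0], [-1.0]], b_re=[0.0, 0.0])
```

Because W lies in (0, 1), every fused value is a convex combination of the prompt and encoder values at that position, which is what lets the module lean toward either input per channel and per pixel.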
This module introduces only a few parameters, yet it effectively improves the quality of the prompted domain-aware features after fusion. The experiments in Sect. 3.3 verify that this module indeed improves the performance of prompt-DA.

2.3 Adversarial Learning to Enhance the Generalization Ability
As mentioned above, prompt-DA is fully compatible with other DA algorithms. We therefore use adversarial learning, which is widely adopted in medical image DA, as an optional component of our network to further enhance its domain adaptation ability.
Specifically, inspired by the adversarial DA in [18], we adopt the classic GAN loss to train the discriminator and the prompt generator (note that the adversarial loss for the generator, $L_{adv}$, is propagated only to the prompt generator).

2.4 Total Loss
To optimize the segmentation backbone network, we use a combined loss function, $L_{seg}$, that incorporates both Dice loss [19] and cross-entropy loss with a balance factor. Summing the losses introduced above, the total loss for training the segmentation network is defined in Eq. 7:

$$L_{total} = L_{seg} + \lambda_{cls} L_{cls} + \lambda_{adv} L_{adv}, \qquad (7)$$

where $\lambda_{cls}$ and $\lambda_{adv}$ are scaling factors that balance the losses.

2.5 Implementation Details
We use a basic 3D-UNet [16] or TransUNet [17] as the segmentation network. The domain classifier is a fully convolutional network of four convolutional layers with 3 × 3 kernels and a stride of 1, each followed by a ReLU parameterized by 0.2. The prompt generator consists of three convolutional layers with ReLU activations, and the discriminator has a structure similar to that of the classifier. We use Adam as the optimizer, with learning rates of 0.0002 for the segmentation network and 0.002 for the domain classifier. The learning rate is decayed by a factor of 0.1 every quarter of the training process.
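The learning-rate schedule described above can be sketched as a step decay, assuming "decayed by 0.1 every quarter" means the LR is multiplied by 0.1 at 25%, 50% and 75% of training. The function name and the iteration count are hypothetical.

```python
def step_decay_lr(base_lr: float, step: int, total_steps: int, factor: float = 0.1) -> float:
    """Return the LR at `step`, multiplied by `factor` after each full quarter of training.

    Assumption: decays happen at 25%, 50% and 75% of `total_steps` (three decays in total).
    """
    if total_steps < 4:
        raise ValueError("total_steps must be at least 4")
    quarter = total_steps // 4
    n_decays = min(step // quarter, 3)
    return base_lr * factor ** n_decays


TOTAL = 1000  # hypothetical number of training iterations
# Segmentation network (base LR 0.0002) and domain classifier (base LR 0.002).
seg_lrs = [step_decay_lr(2e-4, s, TOTAL) for s in (0, 250, 500, 750)]
cls_lrs = [step_decay_lr(2e-3, s, TOTAL) for s in (0, 250, 500, 750)]
```

In PyTorch, `torch.optim.lr_scheduler.StepLR` with `step_size` set to a quarter of the iterations and `gamma=0.1`, stepped once per iteration, should give the same values within the training run.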
3 Experiments and Results

3.1 Datasets
We evaluated the proposed method on two medical image segmentation DA datasets: a cross-age infant brain dataset [20], used for cross-age infant brain image segmentation, and BraTS2018 [21], used for HGG-to-LGG domain adaptation. The first dataset is for infant brain segmentation (white matter, gray matter and cerebrospinal fluid). To build the cross-age dataset, we use 10 brain MRIs of 6-month-old infants from iSeg2019 [20] and build in-house 3-month-old and 12-month-old datasets, collecting 11 brain MRIs for each of these two age groups. We take the 6-month-old data as the source domain and the 3-month-old and 12-month-old data as the target domains. The second dataset is for brain tumor segmentation (enhancing tumor, peritumoral edema, and necrotic and non-enhancing tumor core) and contains 285 MRI samples (210 HGG and 75 LGG). We take HGG as the source domain and LGG as the target domain.
Table 1. Comparison with SOTA DA methods on the infant brain segmentation task. The evaluation metric shown is DICE.

Method     6 month                      12 month                     3 month
           WM     GM     CSF    avg.    WM     GM     CSF    avg.    WM     GM     CSF    avg.
no-DA      82.47  88.57  93.84  88.29   68.58  72.37  76.45  72.47   69.29  61.92  62.84  64.68
nn-UNet    83.44  88.97  94.76  89.06   65.66  73.23  66.74  68.54   55.54  63.67  67.19  62.13
ADDA       80.88  87.36  92.96  87.07   69.98  75.12  75.78  73.62   70.02  68.13  62.94  67.03
CyCADA     81.12  87.89  93.06  87.36   70.12  75.24  77.13  74.16   70.12  70.54  62.91  67.86
SIFA       81.71  87.87  92.98  87.52   70.37  76.98  77.02  74.79   69.89  71.12  63.01  68.01
ADR        81.69  87.94  93.01  87.55   71.81  77.02  76.65  75.10   70.16  72.04  62.98  68.39
ours       81.77  88.01  93.04  87.61   75.03  80.03  78.74  77.93   70.59  74.51  63.18  69.43

3.2 Comparison with State-of-the-Art (SOTA) Methods
We compared our method with four SOTA methods: ADDA [18], CyCADA [22], SIFA [23] and ADR [24], using the code released with the corresponding papers. For a fair comparison, we replaced the backbones of these models with the one used in our approach. The quantitative results for cross-age infant brain segmentation are presented in Table 1; due to space limitations, the results of the brain tumor segmentation task are given in Table 1 of the Supplementary Material, Sec. 3. As observed, our method shows strong DA ability on the cross-age infant segmentation task, improving the average DICE over no-DA by about 5.46 and 4.75 on the 12-month-old and 3-month-old datasets, respectively. Compared with the four selected SOTA DA methods, our method also achieves superior transfer performance on all target domains; specifically, it outperforms the other SOTA methods by at least 2.83 DICE and 1.04 DICE on the 12-month-old and 3-month-old tasks. When transferring to the single target domain of the brain tumor segmentation task, our DA solution improves DICE by about 3.09 on the target LGG domain. The proposed method also shows considerable improvements over ADDA and CyCADA, but only subtle improvements over SIFA and ADR (ADR even shows a small advantage on the Whole category). We also visualize the segmentation results on a typical test sample of the infant brain dataset in Fig. 2, which again demonstrates the advantage of our method in some detailed regions.
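The DICE gains quoted above can be checked against the "avg." columns of Table 1 with a few lines of arithmetic. The numbers are copied from the table; the dictionary and helper names are ours.

```python
# Average DICE per target domain, copied from the "avg." columns of Table 1.
AVG_DICE = {
    "no-DA": {"12m": 72.47, "3m": 64.68},
    "ADDA": {"12m": 73.62, "3m": 67.03},
    "CyCADA": {"12m": 74.16, "3m": 67.86},
    "SIFA": {"12m": 74.79, "3m": 68.01},
    "ADR": {"12m": 75.10, "3m": 68.39},
    "ours": {"12m": 77.93, "3m": 69.43},
}
SOTA_METHODS = ("ADDA", "CyCADA", "SIFA", "ADR")


def gain_over_no_da(domain: str) -> float:
    """DICE gain of our method over training without any DA."""
    return round(AVG_DICE["ours"][domain] - AVG_DICE["no-DA"][domain], 2)


def margin_over_best_sota(domain: str) -> float:
    """DICE margin of our method over the strongest of the four SOTA DA methods."""
    best = max(AVG_DICE[m][domain] for m in SOTA_METHODS)
    return round(AVG_DICE["ours"][domain] - best, 2)


for dom in ("12m", "3m"):
    print(dom, gain_over_no_da(dom), margin_over_best_sota(dom))
# 12m 5.46 2.83
# 3m 4.75 1.04
```

In both target domains the strongest competitor is ADR, so the "at least" margins are measured against it.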
[Figure 2 columns: T1w (12m/3m), No adaptation, ADDA, CyCADA, SIFA, ADR, ours, ground truth]
Fig. 2. Visualization of segmentation maps (details) for all the comparison methods.
Table 2. Ablation study of prompt-DA, adv-DA and comb-DA (DICE).

Backbone   Experiment   6 month                      12 month                     3 month
                        WM     GM     CSF    avg.    WM     GM     CSF    avg.    WM     GM     CSF    avg.
3D-UNet    no-DA        82.47  88.57  93.84  88.29   68.58  72.37  76.45  72.47   69.29  61.92  62.84  64.68
           adv-DA       80.88  87.36  92.96  87.07   69.98  75.12  75.78  73.62   70.02  68.13  62.94  67.03
           prompt-DA    81.57  87.90  93.06  87.51   71.30  77.82  77.16  75.06   70.69  69.51  62.83  67.68
           comb-DA      81.77  88.01  93.04  87.61   75.03  80.03  78.74  77.93   70.59  74.51  63.18  69.43
TransUNet  no-DA        73.24  81.12  84.19  79.52   66.04  70.12  54.94  63.72   39.70  59.49  59.25  52.81
           adv-DA       72.76  80.72  82.98  78.82   66.72  70.39  55.21  64.11   39.89  60.02  59.89  53.27
           prompt-DA    73.01  80.31  83.21  78.84   67.41  71.01  55.41  64.61   40.17  60.22  60.09  53.49
           comb-DA      72.98  80.59  83.61  79.06   70.25  72.57  57.04  66.62   42.61  61.03  61.57  55.07

3.3 Ablation Study
Prompt-DA vs. adv-DA: Since the performance reported in Table 1 is achieved by combining prompt-DA and adv-DA, we carry out further studies to investigate two questions: (1) Does prompt-DA by itself provide transfer ability? (2) Is prompt-DA compatible with adv-DA? The corresponding experiments are conducted on the infant brain dataset, and the results are shown in Table 2. In the table, no-DA means training only the segmentation network without any DA strategy; adv-DA denotes using only adversarial learning based DA; prompt-DA is the proposed prompt learning based DA; and comb-DA is our final DA algorithm, which combines adv-DA and prompt-DA. As Table 2 shows, both adv-DA and prompt-DA improve the transfer performance on all target domains. In more detail, with the 3D-UNet backbone, prompt-DA improves more than adv-DA on both the 12-month-old and 3-month-old domains (by 1.44 DICE and 0.65 DICE, respectively). Combining the two (i.e., comb-DA) improves performance further by a considerable margin: 2.87 DICE and 1.75 DICE over prompt-DA on the 12-month-old and 3-month-old domains, respectively. A similar pattern holds with the TransUNet backbone. We therefore conclude that (1) prompt-DA by itself improves transfer ability, and (2) prompt-DA is highly compatible with adv-DA.

Fusion Strategy for Learning Domain-Aware Features: A key component of prompt-DA is learning domain-aware features through fusion. We evaluated the effectiveness of the proposed feature fusion strategy in both 3D and 2D models. For comparison, we considered several other fusion strategies: 'add/mul', which directly adds or multiplies the encoder feature and the prompt; 'conv', which applies a single convolutional layer to the concatenated features; and 'rAFusion', a reversed version of the AFusion module that sends the prompt to the spatial branch and the encoder feature to the channel branch. The results of these experiments are presented in Table 3.
Table 3. Ablation study of fusion strategies for learning domain-aware features.

Backbone   Experiment   6 month                      12 month                     3 month
                        WM     GM     CSF    avg.    WM     GM     CSF    avg.    WM     GM     CSF    avg.
3D-UNet    no-DA        82.47  88.57  93.84  88.29   68.58  72.37  76.45  72.47   69.29  61.92  62.84  64.68
           add/mul      81.31  87.62  92.67  87.20   73.88  78.59  76.52  76.33   68.27  74.17  62.90  68.45
           conv         81.40  87.61  93.16  87.39   73.06  77.73  78.73  76.51   69.91  71.93  63.02  68.29
           rAFusion     81.72  88.17  93.33  87.74   74.82  79.75  77.84  77.47   69.42  74.54  62.98  68.98
           AFusion      81.77  88.01  93.04  87.61   75.03  80.03  78.74  77.93   70.59  74.51  63.18  69.43
TransUNet  no-DA        73.24  81.12  84.19  79.52   66.04  70.12  54.94  63.72   39.70  59.49  59.25  52.81
           add/mul      72.66  80.69  83.77  79.04   68.66  70.92  55.52  65.03   40.41  60.87  60.32  53.87
           conv         72.72  80.56  83.71  79.00   69.75  71.01  55.56  65.44   41.32  60.97  60.24  54.18
           rAFusion     72.96  80.72  83.74  79.14   69.80  72.44  56.18  66.14   42.12  61.01  60.50  54.54
           AFusion      72.98  80.59  83.61  79.06   70.25  72.57  57.04  66.62   42.61  61.03  61.57  55.07
Our experimental results demonstrate that the proposed AFusion module improves the model's performance: it achieves the best average DICE on both target domains (12-month-old and 3-month-old) for both the 3D and the 2D backbone.
4 Conclusion
In this paper, we propose a new DA paradigm, namely prompt learning based DA. The proposed prompt-DA uses a classifier and a prompt generator to produce domain-specific information and then employs a fusion module (for encoder features and prompts) to learn domain-aware representations. We show the effectiveness of prompt-DA for transfer, and we show that prompt-DA combines smoothly with other DA algorithms. Experiments on two DA datasets with two different segmentation backbones demonstrate that our method works well on DA problems.

Acknowledgements. This work was supported by the National Natural Science Foundation of China (No. 62001222) and the China Postdoctoral Science Foundation funded project (No. 2021TQ0150 and No. 2021M701699).
References 1. Ronneberger, O., Fischer, P., Brox, T.: U-net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4 28 2. Zhou, S., Nie, D., Adeli, E., Yin, J., Lian, J., Shen, D.: High-resolution encoderdecoder networks for low-contrast medical image segmentation. IEEE Trans. Image Process. 29, 461–475 (2019) 3. Ouyang, C., Kamnitsas, K., Biffi, C., Duan, J., Rueckert, D.: Data efficient unsupervised domain adaptation for cross-modality image segmentation. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11765, pp. 669–677. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32245-8 74
Y. Lin et al.
4. Xie, X., Chen, J., Li, Y., Shen, L., Ma, K., Zheng, Y.: MI²GAN: generative adversarial network for medical image domain adaptation using mutual information constraint. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12262, pp. 516–525. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59713-9_50
5. Dou, Q., et al.: Pnp-adanet: plug-and-play adversarial domain adaptation network at unpaired cross-modality cardiac segmentation. IEEE Access 7, 99065–99076 (2019)
6. Chen, C., Dou, Q., Chen, H., Qin, J., Heng, P.-A.: Synergistic image and feature adaptation: Towards cross-modality domain adaptation for medical image segmentation. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, no. 01, pp. 865–872 (2019)
7. Cui, H., Yuwen, C., Jiang, L., Xia, Y., Zhang, Y.: Bidirectional cross-modality unsupervised domain adaptation using generative adversarial networks for cardiac image segmentation. Comput. Biol. Med. 136, 104726 (2021)
8. Kumar, A., Ma, T., Liang, P.: Understanding self-training for gradual domain adaptation. In: International Conference on Machine Learning, pp. 5468–5479. PMLR (2020)
9. Sheikh, R., Schultz, T.: Unsupervised domain adaptation for medical image segmentation via self-training of early features. In: International Conference on Medical Imaging with Deep Learning, pp. 1096–1107. PMLR (2022)
10. Xie, Q., et al.: Unsupervised domain adaptation for medical image segmentation by disentanglement learning and self-training. IEEE Trans. Med. Imaging (2022)
11. Yang, C., Guo, X., Chen, Z., Yuan, Y.: Source free domain adaptation for medical image segmentation with Fourier style mining. Med. Image Anal. 79, 102457 (2022)
12. Liu, X., Ji, K., Fu, Y., Du, Z., Yang, Z., Tang, J.: P-tuning v2: prompt tuning can be comparable to fine-tuning universally across scales and tasks. arXiv preprint arXiv:2110.07602 (2021)
13. Zhou, K., Yang, J., Loy, C.C., Liu, Z.: Learning to prompt for vision-language models. Int. J. Comput. Vision 130(9), 2337–2348 (2022)
14. Jia, M., et al.: Visual prompt tuning. In: Avidan, S., Brostow, G., Cissé, M., Farinella, G.M., Hassner, T. (eds.) Computer Vision - ECCV 2022. Lecture Notes in Computer Science, vol. 13693, pp. 709–727. Springer, Cham (2022)
15. Zheng, Z., Yue, X., Wang, K., You, Y.: Prompt vision transformer for domain generalization. arXiv preprint arXiv:2208.08914 (2022)
16. Çiçek, Ö., Abdulkadir, A., Lienkamp, S.S., Brox, T., Ronneberger, O.: 3D U-net: learning dense volumetric segmentation from sparse annotation. In: Ourselin, S., Joskowicz, L., Sabuncu, M.R., Unal, G., Wells, W. (eds.) MICCAI 2016. LNCS, vol. 9901, pp. 424–432. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46723-8_49
17. Chen, J., et al.: Transunet: transformers make strong encoders for medical image segmentation. arXiv preprint arXiv:2102.04306 (2021)
18. Tzeng, E., Hoffman, J., Saenko, K., Darrell, T.: Adversarial discriminative domain adaptation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7167–7176 (2017)
19. Sudre, C.H., Li, W., Vercauteren, T., Ourselin, S., Jorge Cardoso, M.: Generalised dice overlap as a deep learning loss function for highly unbalanced segmentations. In: Cardoso, M.J., et al. (eds.) DLMIA/ML-CDS 2017. LNCS, vol. 10553, pp. 240–248. Springer, Cham (2017). https://doi.org/10.1007/978-3-319-67558-9_28
20. Sun, Y., et al.: Multi-site infant brain segmentation algorithms: the ISEG-2019 challenge. IEEE Trans. Med. Imaging 40(5), 1363–1376 (2021)
MTDA with Prompt Learning for Medical Image Segmentation
21. Bakas, S., et al.: Identifying the best machine learning algorithms for brain tumor segmentation, progression assessment, and overall survival prediction in the brats challenge. arXiv preprint arXiv:1811.02629 (2018)
22. Hoffman, J., et al.: Cycada: cycle-consistent adversarial domain adaptation. In: International Conference on Machine Learning, pp. 1989–1998. PMLR (2018)
23. Chen, C., Dou, Q., Chen, H., Qin, J., Heng, P.A.: Unsupervised bidirectional cross-modality adaptation via deeply synergistic image and feature alignment for medical image segmentation. IEEE Trans. Med. Imaging 39(7), 2494–2505 (2020)
24. Zeng, G., et al.: Semantic consistent unsupervised domain adaptation for cross-modality medical image segmentation. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12903, pp. 201–210. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87199-4_19
Spectral Adversarial MixUp for Few-Shot Unsupervised Domain Adaptation

Jiajin Zhang¹, Hanqing Chao¹, Amit Dhurandhar², Pin-Yu Chen², Ali Tajer³, Yangyang Xu⁴, and Pingkun Yan¹ (✉)

¹ Department of Biomedical Engineering and Center for Biotechnology and Interdisciplinary Studies, Rensselaer Polytechnic Institute, Troy, NY, USA
² IBM Thomas J. Watson Research Center, Yorktown Heights, NY, USA
³ Department of Electrical, Computer, and Systems Engineering, Rensselaer Polytechnic Institute, Troy, NY, USA
⁴ Department of Mathematical Sciences, Rensselaer Polytechnic Institute, Troy, NY, USA
Abstract. Domain shift is a common problem in clinical applications, where the training images (source domain) and the test images (target domain) follow different distributions. Unsupervised Domain Adaptation (UDA) techniques have been proposed to adapt models trained in the source domain to the target domain. However, those methods require a large number of images from the target domain for model training. In this paper, we propose a novel method for Few-Shot Unsupervised Domain Adaptation (FSUDA), where only a limited number of unlabeled target domain samples are available for training. To accomplish this challenging task, we first introduce a spectral sensitivity map to characterize the generalization weaknesses of models in the frequency domain. We then develop a Sensitivity-guided Spectral Adversarial MixUp (SAMix) method that generates target-style images to effectively suppress the model sensitivity, which leads to improved model generalizability in the target domain. We rigorously evaluate the proposed method on multiple tasks using several public datasets. The source code is available at https://github.com/RPIDIAL/SAMix.

Keywords: Few-shot UDA · Data Augmentation · Spectral Sensitivity

1 Introduction
A common challenge for deploying deep learning to clinical problems is the discrepancy between data distributions across different clinical sites [6,15,20,28,29]. This discrepancy, which results from vendor or protocol differences, can cause a significant performance drop when models are deployed to a new site [2,21,23]. To solve this problem, many Unsupervised Domain Adaptation (UDA) methods [6] have been developed for adapting a model to a new site with only unlabeled data (target domain) by transferring the knowledge learned from the original dataset (source domain). However, most UDA methods require sufficient target samples, which are scarce in medical imaging due to the limited accessibility of patient data. This motivates a new problem, Few-Shot Unsupervised Domain Adaptation (FSUDA), where only a few unlabeled target samples are available for training.

© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023. H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 728–738, 2023. https://doi.org/10.1007/978-3-031-43907-0_69

Only a few approaches [11,22] have been proposed to tackle FSUDA. Luo et al. [11] introduced Adversarial Style Mining (ASM), which uses a pretrained style-transfer module to generate augmented images via an adversarial process. However, this module requires extra style images [9] for pre-training. Such images are scarce in clinical settings, and style differences across sites are subtle, which hampers the applicability of ASM to medical image analysis. SM-PPM [22] trains a style-mixing model for semantic segmentation by augmenting source domain features to a fictitious domain through random interpolation with target domain features. However, SM-PPM is specifically designed for segmentation tasks and cannot be easily adapted to other tasks. Moreover, with the limited target domain samples in FSUDA, random feature interpolation is ineffective at improving the model's generalizability.

In a different direction, numerous UDA methods have shown high performance in various tasks [4,16–18]. However, applying them directly to FSUDA can result in severe overfitting due to the limited target domain samples [22]. Previous studies [7,10,24,25] have demonstrated that transferring the amplitude spectrum of target domain images to source domain images can effectively convey image style information and diversify the training dataset. To tackle the overfitting issue of existing UDA methods, we propose a novel approach called Sensitivity-guided Spectral Adversarial MixUp (SAMix) to augment training samples.
This approach uses an adversarial mixing scheme and a spectral sensitivity map, which reveals the model's generalizability weaknesses, to efficiently generate hard-to-learn images from limited target samples. SAMix focuses on two key aspects.

1) Model generalizability weaknesses: Spectral sensitivity analysis has been applied in prior work [26] to quantify a model's spectral weaknesses to image amplitude corruptions. Zhang et al. [27] demonstrated that using a spectral sensitivity map to weight the amplitude perturbation is an effective data augmentation. However, existing sensitivity maps only use single-domain labeled data and cannot leverage target domain information. To this end, we introduce a Domain-Distance-modulated Spectral Sensitivity (DoDiSS) map to analyze the model's weaknesses in the target domain and guide our spectral augmentation.

2) Sample hardness: Existing studies [11,19] have shown that mining hard-to-learn samples during model training can enhance the efficiency of data augmentation and improve model generalization performance. Therefore, to maximize the use of the limited target domain data, we incorporate an adversarial approach into the spectral mixing process to generate the most challenging data augmentations.

This paper has three major contributions. 1) We propose SAMix, a novel approach for augmenting target-style samples by using an adversarial spectral mixing scheme. SAMix enables high-performance UDA methods to adapt easily to FSUDA problems. 2) We introduce DoDiSS to characterize a model's generalizability weaknesses in the target domain. 3) We conduct thorough empirical analyses to demonstrate the effectiveness and efficiency of SAMix as a plug-in module for various UDA methods across different tasks.
2 Methods
We denote the labeled source domain as $X_S = \{(x^s_n, y^s_n)\}_{n=1}^{N}$ and the unlabeled K-shot target domain as $X_T = \{x^t_k\}_{k=1}^{K}$, where $x^s_n, x^t_k \in \mathbb{R}^{h \times w}$. Figure 1 depicts the framework of our method as a plug-in module for boosting a UDA method in the FSUDA scenario. It contains two components. First, a Domain-Distance-modulated Spectral Sensitivity (DoDiSS) map is calculated to characterize a source model's weaknesses in generalizing to the target domain. Then, this sensitivity map is used for Sensitivity-guided Spectral Adversarial MixUp (SAMix) to generate target-style images for UDA models. The details of the components are presented in the following sections.
Fig. 1. Illustration of the proposed framework. (a) DoDiSS map characterizes a model’s generalizability weaknesses. (b) SAMix enables UDA methods to solve FSUDA.
2.1 Domain-Distance-Modulated Spectral Sensitivity (DoDiSS)
Prior research [27] found that a spectral sensitivity map obtained from a Fourier-based measurement of model sensitivity can effectively portray the generalizability of that model. However, that spectral sensitivity map is limited to single-domain scenarios and cannot integrate target domain information to assess model weaknesses under specific domain shifts. Thus, we introduce DoDiSS, which extends the previous method by incorporating domain distance to tackle domain adaptation problems. Fig. 1 (a) depicts the DoDiSS pipeline. It begins by computing a domain distance map that identifies, at each frequency, the difference between the amplitude distributions of the source and target domains. This difference map is then used to weight the amplitude perturbations when calculating the DoDiSS map.

Domain Distance Measurement. To compensate for the scarcity of target domain images, we first augment the few-shot target images with random combinations of geometric transformations, including random cropping, rotation, flipping, and JigSaw [13]. These transformations keep the image intensities unchanged, preserving the target domain style information. The Fast Fourier Transform (FFT) is then applied to all the source images and the augmented target domain images to obtain their amplitude spectra, denoted as $A_S$ and $\hat{A}_T$, respectively. We calculate the probability distributions $p^S_{i,j}$ and $p^T_{i,j}$ of $A_S$ and $\hat{A}_T$ at the $(i,j)$-th frequency entry, respectively. The domain distance map at $(i,j)$ is defined as $D_W(i,j) = W_1(p^S_{i,j}, p^T_{i,j})$, where $W_1$ is the 1-Wasserstein distance.

DoDiSS Computation. With the measured domain difference, we can now compute the DoDiSS map of a model. As shown in Fig. 1 (a), a Fourier basis is defined as a Hermitian matrix $H_{i,j} \in \mathbb{R}^{h \times w}$ with only two non-zero elements, at $(i,j)$ and $(-i,-j)$. A Fourier basis image $U_{i,j}$ is obtained as the $\ell_2$-normalized Inverse Fast Fourier Transform (IFFT) of $H_{i,j}$, i.e., $U_{i,j} = \mathrm{IFFT}(H_{i,j}) / \|\mathrm{IFFT}(H_{i,j})\|_2$. To analyze the model's generalization weakness with respect to frequency $(i,j)$, we generate perturbed source domain images $x^s + N_{i,j}$ by adding the Fourier basis noise $N_{i,j} = r \cdot D_W(i,j) \cdot U_{i,j}$ to the original source domain image $x^s$. $D_W(i,j)$ controls the $\ell_2$-norm of $N_{i,j}$, and $r$ is randomly sampled to be either $-1$ or $1$. $N_{i,j}$ only perturbs the frequency components at $(i,j)$ and $(-i,-j)$ of the original images.
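A minimal NumPy sketch of the DoDiSS pipeline follows. It assumes grayscale images stored as arrays of shape (N, h, w) and a black-box `predict` function that returns class labels; both are assumptions for illustration, not details from the paper. It estimates $D_W$ with an empirical 1-Wasserstein distance per frequency entry, builds the Fourier basis image $U_{i,j}$, and evaluates the sensitivity $M_S(i,j)$ that Eq. (1) defines immediately below (one minus accuracy on perturbed source images). All helper names are hypothetical, and the target augmentation is reduced to flips and 90-degree rotations.

```python
import numpy as np


def amplitude_spectra(images):
    """Per-image amplitude spectra |FFT2(x)| for a stack of grayscale images (N, h, w)."""
    return np.abs(np.fft.fft2(images, axes=(-2, -1)))


def wasserstein_1d(u, v):
    """Empirical 1-Wasserstein distance between two 1-D samples,
    computed as the integral of |CDF_u - CDF_v| over the merged support."""
    u, v = np.sort(np.asarray(u, dtype=float)), np.sort(np.asarray(v, dtype=float))
    grid = np.sort(np.concatenate([u, v]))
    widths = np.diff(grid)
    u_cdf = np.searchsorted(u, grid[:-1], side="right") / u.size
    v_cdf = np.searchsorted(v, grid[:-1], side="right") / v.size
    return float(np.sum(np.abs(u_cdf - v_cdf) * widths))


def augment_target(images, n_aug, rng):
    """Intensity-preserving augmentation of the few-shot target images.
    Only random 90-degree rotations and horizontal flips are used (square images
    assumed); random cropping and JigSaw are omitted in this sketch."""
    out = []
    for _ in range(n_aug):
        x = images[rng.integers(len(images))]
        x = np.rot90(x, k=int(rng.integers(4)))
        if rng.random() < 0.5:
            x = np.fliplr(x)
        out.append(x)
    return np.stack(out)


def domain_distance_map(src_images, tgt_images):
    """D_W(i, j) = W1(p^S_ij, p^T_ij): one Wasserstein distance per frequency entry."""
    a_s, a_t = amplitude_spectra(src_images), amplitude_spectra(tgt_images)
    h, w = a_s.shape[1:]
    d_w = np.empty((h, w))
    for i in range(h):
        for j in range(w):
            d_w[i, j] = wasserstein_1d(a_s[:, i, j], a_t[:, i, j])
    return d_w


def fourier_basis_image(i, j, h, w):
    """U_ij: l2-normalized IFFT of a Hermitian matrix with ones at (i, j) and (-i, -j)."""
    basis = np.zeros((h, w))
    basis[i, j] = 1.0
    basis[-i % h, -j % w] = 1.0
    u = np.real(np.fft.ifft2(basis))  # imaginary part is only round-off
    return u / np.linalg.norm(u)


def dodiss_map(predict, src_images, src_labels, d_w, rng):
    """Eq. (1): M_S(i, j) = 1 - accuracy on x^s + r * D_W(i, j) * U_ij, r in {-1, +1}."""
    h, w = d_w.shape
    m_s = np.empty((h, w))
    for i in range(h):
        for j in range(w):
            u = fourier_basis_image(i, j, h, w)
            r = rng.choice([-1.0, 1.0], size=(len(src_images), 1, 1))  # one sign per image
            perturbed = src_images + r * d_w[i, j] * u
            m_s[i, j] = 1.0 - np.mean(predict(perturbed) == src_labels)
    return m_s
```

The double loop costs h·w passes of the model over the source set, so for 256 × 256 inputs one would likely batch frequencies or evaluate a subset; that trade-off is a property of this sketch, not something reported in the paper.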
$D_W(i,j)$ ensures that the perturbation applied at each frequency component follows the real domain shift. For RGB images, we add $N_{i,j}$ to each channel independently, following [27]. The sensitivity at frequency $(i,j)$ of a model $F$ trained on the source domain is defined as the prediction error rate over the whole dataset $X_S$, as in (1), where Acc denotes the prediction accuracy:

$$M_S(i,j) = 1 - \mathop{\mathrm{Acc}}_{(x^s, y^s) \in X_S} \big( F(x^s + r \cdot D_W(i,j) \cdot U_{i,j}),\; y^s \big). \tag{1}$$

2.2 Sensitivity-Guided Spectral Adversarial Mixup (SAMix)
Using the DoDiSS map $M_S$ and an adversarially learned parameter $\lambda^*$ as weighting factors, SAMix mixes the amplitude spectrum of each source image with the spectrum of a target image. DoDiSS indicates the spectral regions where the model is sensitive to the domain difference. The parameter $\lambda^*$ mines hard-to-learn samples by maximizing the task loss, which efficiently enriches the target domain samples. Further, by retaining the phase of the source image, SAMix preserves the semantic meaning of the original source image in the generated target-style sample. Specifically, as shown in Fig. 1 (b), given a source image $x^s$ and a target image $x^t$, we compute their amplitude and phase spectra, denoted as $(A^s, \Phi^s)$ and $(A^t, \Phi^t)$, respectively. SAMix mixes the amplitude spectra by

$$A^{st}_{\lambda^*} = \lambda^* \cdot M_S \cdot A^t + (1 - \lambda^*) \cdot (1 - M_S) \cdot A^s. \tag{2}$$
The target-style image is reconstructed by $x^{st}_{\lambda^*} = \mathrm{IFFT}(A^{st}_{\lambda^*}, \Phi^s)$. The adversarially learned parameter $\lambda^*$ is optimized by maximizing the task loss $L_T$ using projected gradient descent with $T$ iterations and step size $\delta$:

$$\lambda^* = \arg\max_{\lambda} L_T\big(F(x^{st}_{\lambda}; \theta), y\big), \quad \text{s.t. } \lambda \in [0, 1]. \tag{3}$$
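The mixing rule of Eq. (2), the phase-preserving reconstruction, and the search for $\lambda^*$ in Eq. (3) can be sketched as below. This is an illustrative NumPy version under stated assumptions: `loss_fn` is a hypothetical callable returning the task loss of the model on a mixed image, the gradient with respect to $\lambda$ is approximated by central finite differences instead of backpropagation, and the starting value `lam0 = 0.5` is our own choice. The defaults `steps = 10` and `step_size = 0.1` mirror the $T$ and $\delta$ reported in the implementation details. Because the goal is to maximize $L_T$, each step moves $\lambda$ along the gradient and then projects it back onto $[0, 1]$.

```python
import numpy as np


def samix_mix(x_s, x_t, m_s, lam):
    """Eq. (2) plus reconstruction: weight the target amplitude by lam * M_S and the
    source amplitude by (1 - lam) * (1 - M_S), then invert with the source phase."""
    f_s = np.fft.fft2(x_s)
    a_s, phi_s = np.abs(f_s), np.angle(f_s)
    a_t = np.abs(np.fft.fft2(x_t))
    a_mix = lam * m_s * a_t + (1.0 - lam) * (1.0 - m_s) * a_s
    # Keep the real part: with a symmetric M_S the mixed spectrum stays Hermitian,
    # so the imaginary part is only numerical round-off.
    return np.real(np.fft.ifft2(a_mix * np.exp(1j * phi_s)))


def adversarial_lambda(loss_fn, x_s, x_t, m_s, steps=10, step_size=0.1,
                       lam0=0.5, eps=1e-3):
    """Eq. (3): gradient ascent on lambda, projected onto [0, 1] after every step.
    Assumption: d(loss)/d(lambda) is estimated by central finite differences."""
    lam = lam0
    for _ in range(steps):
        hi, lo = min(lam + eps, 1.0), max(lam - eps, 0.0)
        grad = (loss_fn(samix_mix(x_s, x_t, m_s, hi))
                - loss_fn(samix_mix(x_s, x_t, m_s, lo))) / (hi - lo)
        lam = float(np.clip(lam + step_size * grad, 0.0, 1.0))  # projection step
    return lam
```

In a deep-learning framework the same loop would more naturally differentiate $L_T$ with respect to $\lambda$ through autograd; the finite-difference version only keeps the example dependency-free.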
In the training phase, as shown in Fig. 1 (b), the SAMix module generates a batch of augmented images, which are combined with the few-shot target domain images to train the UDA model. The overall training objective is to minimize

$$L_{tot}(\theta) = L_T\big(F(x^s; \theta), y\big) + \mu \cdot \mathrm{JS}\big(F(x^s; \theta), F(x^{st}_{\lambda^*}; \theta)\big) + L_{UDA}, \tag{4}$$

where $L_T$ is the supervised task loss in the source domain; JS is the Jensen-Shannon divergence [27], which regularizes the consistency of model predictions between the source images $x^s$ and their augmented versions $x^{st}_{\lambda^*}$; $L_{UDA}$ is the training loss of the original UDA method; and $\mu$ is a weighting parameter.
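The overall objective of Eq. (4) can be sketched as follows, assuming the model outputs class logits for a source batch and for its SAMix-augmented counterpart, and treating $L_T$ and $L_{UDA}$ as precomputed scalars. The JS term here is the plain two-distribution Jensen-Shannon divergence between softmax outputs; the exact formulation used in [27] may differ, so this is an assumption of the sketch. The default $\mu = 0.01$ follows the implementation details; the function names are hypothetical.

```python
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def js_divergence(p, q, eps=1e-12):
    """Mean Jensen-Shannon divergence (natural log) between rows of p and q."""
    m = 0.5 * (p + q)

    def kl(a, b):
        return np.sum(a * (np.log(a + eps) - np.log(b + eps)), axis=-1)

    return float(np.mean(0.5 * kl(p, m) + 0.5 * kl(q, m)))


def total_loss(task_loss, logits_src, logits_aug, uda_loss, mu=0.01):
    """Eq. (4): L_T + mu * JS(F(x^s), F(x^st)) + L_UDA."""
    consistency = js_divergence(softmax(logits_src), softmax(logits_aug))
    return task_loss + mu * consistency + uda_loss
```

Because the JS term is bounded by ln 2 and scaled by a small $\mu$, it acts as a soft consistency regularizer rather than competing with the task and UDA losses.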
3 Experiments and Results
We evaluated SAMix on two medical image datasets. Fundus [5,14] is an optic disc and cup segmentation task. Following [21], we consider images collected from different scanners as distinct domains. The source domain contains the 400 images of the REFUGE [14] training set. We took 400 images from the REFUGE validation set and 159 images from RIM-One [5] to form target domains 1 and 2. We center-crop and resize the disc region to 256 × 256 as network input. Camelyon [1] is a binary tumor tissue classification task across 5 hospitals. We use the training set of Camelyon as the source domain (302,436 images from hospitals 1–3) and consider the validation set (34,904 images from hospital 4) and the test set (85,054 images from hospital 5) as target domains 1 and 2, respectively. All images are resized to 256 × 256 as network input. For all experiments, the source domain images are split into training and validation sets at a ratio of 4:1. We randomly selected K-shot target domain images for training, while the remaining target domain images were reserved for testing.

3.1 Implementation Details
SAMix is evaluated as a plug-in module for four UDA models: AdaptSeg [17] and Advent [18] for Fundus, and SRDC [16] and DALN [4] for Camelyon. For a fair comparison, we adopted the same network architecture for all methods on each task. For Fundus, we use DeepLabV2-Res101 [3] as the backbone with the SGD optimizer for 80 epochs. The task loss $L_T$ is the Dice loss. The initial learning rate is 0.001 and decays by a factor of 0.1 every 20 epochs. The batch size is 16. For Camelyon, we use ResNet-50 [8] with the SGD optimizer for 20 epochs. $L_T$ is the binary cross-entropy loss. The initial learning rate is 0.0001 and decays by a factor of 0.1 every 5 epochs. The batch size is 128. We use a fixed weighting factor $\mu = 0.01$, $T = 10$ iterations, and step size $\delta = 0.1$ in all experiments.

Table 1. 10-run average DSC (%) and ASD of models on REFUGE. The best performance is in bold and the second best is indicated with underline.

Source Domain → Target Domain 1
Method          | DSC cup (↑) | DSC disc (↑) | DSC avg (↑) | ASD cup (↓) | ASD disc (↓) | ASD avg (↓)
Source Only     | 61.16∗      | 66.54∗       | 63.85∗      | 14.37∗      | 11.69∗       | 13.03∗
AdaptSeg        | 61.45∗      | 66.61∗       | 64.03∗      | 13.79∗      | 11.47∗       | 12.64∗
Advent          | 62.03∗      | 66.82∗       | 64.43∗      | 12.82∗      | 11.54∗       | 12.18∗
ASM             | 69.18∗      | 71.91∗       | 70.05∗      | 8.92∗       | 8.35∗        | 8.64∗
SM-PPM          | 74.55∗      | 77.62∗       | 76.09∗      | 6.09∗       | 5.66∗        | 5.88∗
AdaptSeg+SAMix  | 76.56       | 80.57        | 78.57       | 4.97        | 4.12         | 4.55
Advent+SAMix    | 76.32       | 80.64        | 78.48       | 4.90        | 3.98         | 4.44

Source Domain → Target Domain 2
Method          | DSC cup (↑) | DSC disc (↑) | DSC avg (↑) | ASD cup (↓) | ASD disc (↓) | ASD avg (↓)
Source Only     | 55.77∗      | 58.62∗       | 57.20∗      | 20.95∗      | 17.63∗       | 19.30∗
AdaptSeg        | 56.67∗      | 60.50∗       | 58.59∗      | 20.44∗      | 17.97∗       | 19.21∗
Advent          | 56.43∗      | 60.56∗       | 58.50∗      | 20.31∗      | 17.86∗       | 19.09∗
ASM             | 57.79∗      | 61.86∗       | 59.83∗      | 19.26∗      | 16.94∗       | 18.10∗
SM-PPM          | 59.62∗      | 64.17∗       | 61.90∗      | 14.52∗      | 12.22∗       | 13.37∗
AdaptSeg+SAMix  | 61.75       | 66.20        | 63.98       | 12.75       | 11.09        | 11.92
Advent+SAMix    | 62.02       | 66.35        | 64.19       | 11.97       | 10.85        | 11.41

∗ p < 0.05 in the one-tailed paired t-test with Advent+SAMix.

3.2 Method Effectiveness
We demonstrate the effectiveness of SAMix by comparing it with two sets of baselines. First, we compare the performance of UDA models with and without SAMix. Second, we compare SAMix against other FSUDA methods [9,11].

Fundus. Table 1 shows the 10-run average Dice coefficient (DSC) and Average Surface Distance (ASD) of all methods trained with the source domain and a 1-shot target domain image, evaluated on both target domains. Compared to the model trained solely on the source domain (Source Only), the performance gain achieved by the UDA methods (AdaptSeg and Advent) is limited. However, incorporating SAMix as a plug-in for the UDA methods (AdaptSeg+SAMix and Advent+SAMix) enhances the original UDA performance significantly (p < 0.05). Moreover, Advent+SAMix significantly surpasses the two FSUDA methods (ASM and SM-PPM). We attribute this improvement primarily to the spectrally augmented target-style samples generated by SAMix.

To assess how well the target-aware spectral sensitivity map measures the model's generalization performance on the target domain, we computed the DoDiSS maps of four models (AdaptSeg, ASM, SM-PPM, and AdaptSeg+SAMix). The results are presented in Fig. 2(a). The DoDiSS map of AdaptSeg+SAMix demonstrates a clear suppression of sensitivity, leading to improved model performance. To better visualize the results, the model generalizability (average DSC) versus the averaged ℓ1-norm of the DoDiSS map is plotted in Fig. 2(b). The figure shows a clear trend of improved model performance as the averaged DoDiSS decreases. To assess the effectiveness of SAMix-augmented target-style images in bridging the domain gap, the feature distributions of Fundus images before and after adaptation are visualized in Fig. 2(c) by t-SNE [12]. Figure 2(c1) shows the domain shift between the source and target domain features. The augmented samples from SAMix bridge the two domains using only a single example image from the target domain. Note that, except for the 1-shot sample, all other target domain samples are used here for visualization only and are never seen during training or validation. Incorporating these augmented samples in AdaptSeg merges the source and target distributions, as shown in Fig. 2(c2).

Fig. 2. Method effectiveness analysis. (a) Visualization of the DoDiSS maps; (b) Scatter plot of model generalizability vs. sensitivity; (c) Feature space visualization.

Table 2. 10-run average Acc (%) and AUC (%) of models on Camelyon. The best performance is in bold and the second best is indicated with underline.
Method
Source Domain → Target Domain 1 Source Domain → Target Domain 2 Acc (↑) AUC (↑) Acc (↑) AUC (↑)
Source Only
75.42∗
71.67∗
65.55∗
60.18∗
∗
∗
∗
56.44∗
∗
73.47∗ 74.62 75.90
DALN
78.63
∗
ASM 83.66 SRDC+SAMix 84.28 DALN+SAMix 86.41
74.74
80.43 80.05 82.58
62.57
77.75 78.64 80.84
∗ p < 0.05 in the one-tailed paired t-test with DALN+SAMix.
Camelyon. Table 2 presents the 10-run average accuracy (Acc) and Area Under the Receiver Operating Characteristic Curve (AUC) of all methods trained with a 1-shot target domain image. The clustering-based SRDC is not included in the table, as the model crashed in this few-shot scenario. SM-PPM is also excluded because it is specifically designed for segmentation tasks. The results suggest that combining SAMix with UDA not only enhances the original UDA performance but also significantly outperforms other FSUDA methods.
Fig. 3. Data efficiency of FSUDA methods on (a) Fundus and (b) Camelyon.
3.3 Data Efficiency
As target domain images are scarce, data efficiency plays a crucial role in determining the performance of data augmentation. Therefore, we evaluated the model's performance with varying numbers of target domain images in the training process. Figure 3 (a) and (b) illustrate the domain adaptation results on Fundus and Camelyon (both in target domain 1), respectively. Our method consistently outperforms the other baselines, even with just a 1-shot target image for training. Furthermore, we qualitatively showcase the data efficiency of SAMix. Figure 4 (a) displays images generated by SAMix given the target domain image. While maintaining the retinal structure of the source image, the augmented images exhibit a style closer to that of the target image, indicating that SAMix can effectively transfer the target domain style. Figure 4 (b) shows an example case of the segmentation results. Compared with the other baselines, the SAMix segmentation shows considerably fewer prediction errors, especially in the cup region.
Fig. 4. (a) SAMix generated samples. (b) Case study of the Fundus segmentation.
Fig. 5. Ablation study. (a) Average DSC on Fundus. (b) AUC on Camelyon.
3.4 Ablation Study
To assess the efficacy of the components of SAMix, we conducted an ablation study with AdaptSeg+SAMix and DALN+SAMix (Full model) on the Fundus and Camelyon datasets. This was done by 1) replacing our proposed DoDiSS map with the original one in [27] (Original map); 2) replacing the SAMix module with the random spectral swapping (FDA, β = 0.01, 0.09) in [25]; and 3) removing the three major components (No $L_{UDA}$, No SAMix, No JS) in a leave-one-out manner. Figure 5 shows that, compared with the Full model, performance degrades when any of the proposed components is removed or replaced by a previous method, which indicates the efficacy of the SAMix components.
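To contrast the FDA baseline [25] used in this ablation with the SAMix mixing rule, FDA-style spectral swapping can be sketched as a low-frequency amplitude replacement that keeps the source phase. This follows our reading of FDA, in which $\beta$ sets the size of a centered low-frequency window; the window convention (half-width $\lfloor \min(h,w)\,\beta \rfloor$ around the zero frequency after `fftshift`) and the function name `fda_swap` are assumptions of this sketch, not details given in this paper.

```python
import numpy as np


def fda_swap(x_s, x_t, beta=0.01):
    """FDA-style swap: replace the low-frequency amplitude of x_s with that of x_t,
    keep the source phase. Half-width b = floor(min(h, w) * beta) (assumed convention)."""
    h, w = x_s.shape
    f_s = np.fft.fftshift(np.fft.fft2(x_s))  # zero frequency moved to (h // 2, w // 2)
    f_t = np.fft.fftshift(np.fft.fft2(x_t))
    a_s, phi_s = np.abs(f_s), np.angle(f_s)
    a_t = np.abs(f_t)
    b = int(np.floor(min(h, w) * beta))
    ch, cw = h // 2, w // 2
    h1, h2 = max(ch - b, 0), min(ch + b + 1, h)  # clip the window to the spectrum
    w1, w2 = max(cw - b, 0), min(cw + b + 1, w)
    a_s[h1:h2, w1:w2] = a_t[h1:h2, w1:w2]  # wholesale replacement, no weighting
    f_mix = np.fft.ifftshift(a_s * np.exp(1j * phi_s))
    return np.real(np.fft.ifft2(f_mix))
```

Unlike Eq. (2), this swap uses neither the DoDiSS map nor an adversarial weight: every entry inside the window is replaced regardless of how sensitive the model is at that frequency, which is the difference the ablation isolates.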
4 Discussion and Conclusion
This paper introduces a novel approach, Sensitivity-guided Spectral Adversarial MixUp (SAMix), which utilizes an adversarial mixing scheme and a spectral sensitivity map to generate target-style samples effectively. The proposed method facilitates the adaptation of existing UDA methods in the few-shot scenario. Thorough empirical analyses demonstrate the effectiveness and efficiency of SAMix as a plug-in module for various UDA methods across multiple tasks. Acknowledgments. This research was partially supported by the National Science Foundation (NSF) under the CAREER award OAC 2046708, the National Institutes of Health (NIH) under award R21EB028001, and the Rensselaer-IBM AI Research Collaboration of the IBM AI Horizons Network.
References
1. Bandi, P., et al.: From detection of individual metastases to classification of lymph node status at the patient level: the camelyon17 challenge. IEEE Trans. Med. Imaging (2018)
2. Chen, C., Dou, Q., Chen, H., Qin, J., Heng, P.A.: Synergistic image and feature adaptation: Towards cross-modality domain adaptation for medical image segmentation. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 33, pp. 865–872 (2019)
3. Chen, L.C., Papandreou, G., Kokkinos, I., Murphy, K., Yuille, A.L.: DeepLab: semantic image segmentation with deep convolutional nets, atrous convolution, and fully connected CRFs. IEEE Trans. Pattern Anal. Mach. Intell. 40(4), 834–848 (2017)
4. Chen, L., et al.: Reusing the task-specific classifier as a discriminator: discriminator-free adversarial domain adaptation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 7181–7190 (2022)
5. Fumero, F., Alayón, S., Sanchez, J.L., Sigut, J., Gonzalez-Hernandez, M.: Rim-one: an open retinal image database for optic nerve evaluation. In: 2011 24th International Symposium on Computer-Based Medical Systems (CBMS), pp. 1–6. IEEE (2011)
6. Guan, H., Liu, M.: Domain adaptation for medical image analysis: a survey. IEEE Trans. Biomed. Eng. 69(3), 1173–1185 (2021)
7. Guyader, N., Chauvin, A., Peyrin, C., Hérault, J., Marendaz, C.: Image phase or amplitude? Rapid scene categorization is an amplitude-based process. C.R. Biol. 327(4), 313–318 (2004)
8. He, K., Zhang, X., Ren, S., Sun, J.: Deep residual learning for image recognition. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778 (2016)
9. Huang, X., Belongie, S.: Arbitrary style transfer in real-time with adaptive instance normalization. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 1501–1510 (2017)
10. Liu, Q., Chen, C., Qin, J., Dou, Q., Heng, P.A.: Feddg: federated domain generalization on medical image segmentation via episodic learning in continuous frequency space. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 1013–1023 (2021)
11. Luo, Y., Liu, P., Guan, T., Yu, J., Yang, Y.: Adversarial style mining for one-shot unsupervised domain adaptation. Adv. Neural. Inf. Process. Syst. 33, 20612–20623 (2020)
12. Van der Maaten, L., Hinton, G.: Visualizing data using t-SNE. J. Mach. Learn. Res. 9(11) (2008)
13. Noroozi, M., Favaro, P.: Unsupervised learning of visual representations by solving jigsaw puzzles. In: Leibe, B., Matas, J., Sebe, N., Welling, M. (eds.) ECCV 2016. LNCS, vol. 9910, pp. 69–84. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46466-4_5
14. Orlando, J.I., et al.: Refuge challenge: a unified framework for evaluating automated methods for glaucoma assessment from fundus photographs. Med. Image Anal. 59, 101570 (2020)
15. Pan, S.J., Yang, Q.: A survey on transfer learning. IEEE Trans. Knowl. Data Eng. 22(10), 1345–1359 (2009)
16. Tang, H., Chen, K., Jia, K.: Unsupervised domain adaptation via structurally regularized deep clustering. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8725–8735 (2020)
17. Tsai, Y.H., Hung, W.C., Schulter, S., Sohn, K., Yang, M.H., Chandraker, M.: Learning to adapt structured output space for semantic segmentation. In: Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7472–7481 (2018)
18. Vu, T.H., Jain, H., Bucher, M., Cord, M., Pérez, P.: Advent: adversarial entropy minimization for domain adaptation in semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 2517–2526 (2019)
19. Wang, H., Xiao, C., Kossaifi, J., Yu, Z., Anandkumar, A., Wang, Z.: AugMax: adversarial composition of random augmentations for robust training. Adv. Neural. Inf. Process. Syst. 34, 237–250 (2021)
20. Wang, J., et al.: Generalizing to unseen domains: a survey on domain generalization. IEEE Trans. Knowl. Data Eng. (2022)
21. Wang, S., Yu, L., Li, K., Yang, X., Fu, C.W., Heng, P.A.: Dofe: domain-oriented feature embedding for generalizable fundus image segmentation on unseen datasets. IEEE Trans. Med. Imaging (2020)
22. Wu, X., Wu, Z., Lu, Y., Ju, L., Wang, S.: Style mixing and patchwise prototypical matching for one-shot unsupervised domain adaptive semantic segmentation. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 36, pp. 2740–2749 (2022)
23. Xie, Q., et al.: Unsupervised domain adaptation for medical image segmentation by disentanglement learning and self-training. IEEE Trans. Med. Imaging, 1 (2022). https://doi.org/10.1109/TMI.2022.3192303
24. Xu, Q., Zhang, R., Zhang, Y., Wang, Y., Tian, Q.: A Fourier-based framework for domain generalization. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 14383–14392 (2021)
25. Yang, Y., Soatto, S.: FDA: Fourier domain adaptation for semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 4085–4095 (2020)
26. Yin, D., Gontijo Lopes, R., Shlens, J., Cubuk, E.D., Gilmer, J.: A Fourier perspective on model robustness in computer vision. In: Advances in Neural Information Processing Systems, vol. 32 (2019)
27. Zhang, J., et al.: When neural networks fail to generalize? A model sensitivity perspective. In: Proceedings of the AAAI Conference on Artificial Intelligence (2023)
28. Zhang, J., Chao, H., Xu, X., Niu, C., Wang, G., Yan, P.: Task-oriented low-dose CT image denoising. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12906, pp. 441–450. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87231-1_43
29. Zhang, J., Chao, H., Yan, P.: Toward adversarial robustness in unlabeled target domains. IEEE Trans. Image Process. 32, 1272–1284 (2023). https://doi.org/10.1109/TIP.2023.3242141
Cross-Dataset Adaptation for Instrument Classification in Cataract Surgery Videos

Jay N. Paranjape¹ (✉), Shameema Sikder²,³, Vishal M. Patel¹, and S. Swaroop Vedula³

¹ Department of Electrical and Computer Engineering, The Johns Hopkins University, Baltimore, USA
² Wilmer Eye Institute, The Johns Hopkins University, Baltimore, USA
³ Malone Center for Engineering in Healthcare, The Johns Hopkins University, Baltimore, USA
Abstract. Surgical tool presence detection is an important part of the intra-operative and post-operative analysis of a surgery. State-of-the-art models that perform this task well on a particular dataset, however, perform poorly when tested on another dataset. This occurs due to a significant domain shift between the datasets resulting from the use of different tools, sensors, data resolution, etc. In this paper, we highlight this domain shift in the commonly performed cataract surgery and propose a novel end-to-end Unsupervised Domain Adaptation (UDA) method called the Barlow Adaptor that addresses the problem of distribution shift without requiring any labels from the target domain. In addition, we introduce a novel loss called the Barlow Feature Alignment Loss (BFAL), which aligns features across different domains while reducing redundancy and the need for large batch sizes, thus improving cross-dataset performance. The use of BFAL is a novel approach to address the challenge of domain shift in cataract surgery data. Extensive experiments are conducted on two cataract surgery datasets, and it is shown that the proposed method outperforms the state-of-the-art UDA methods by 6%. The code can be found at https://github.com/JayParanjape/Barlow-Adaptor.
Keywords: Surgical Tool Classification · Unsupervised Domain Adaptation · Cataract Surgery · Surgical Data Science
1 Introduction
Surgical instrument identification and classification are critical to delivering several priorities in surgical data science [21]. Various deep learning methods have been developed to classify instruments in surgical videos using data routinely generated in institutions [2]. However, differences in image capture systems and protocols lead to nontrivial dataset shifts, causing a significant drop in the performance of deep learning methods when they are tested on new datasets [13]. Using cataract surgery as an example, Fig. 1 illustrates the drop in accuracy of existing methods for instrument classification when trained on one dataset and tested on another dataset [19,28]. Cataract surgery is one of the most commonly performed surgical procedures [18], and methods to develop generalizable networks will enable clinically useful applications.
Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_70.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023. H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 739–748, 2023. https://doi.org/10.1007/978-3-031-43907-0_70
Fig. 1. Dataset shift between the CATARACTS dataset (CAT) [6] and D99 [7, 9] dataset. Results for models trained on one dataset and tested on another show a significant drop in performance.
Domain adaptation methods aim to mitigate the drop in algorithm performance across domains [13]. Unsupervised Domain Adaptation (UDA) methods are particularly useful when the source dataset is labeled and the target dataset is unlabeled. In this paper, we describe a novel end-to-end UDA method, which we call the Barlow Adaptor, and its application to instrument classification in video images from cataract surgery. We define a novel loss function called the Barlow Feature Alignment Loss (BFAL) that aligns the features learnt by the model between the source and target domains without requiring any labeled target data. It encourages the model to learn non-redundant, domain-agnostic features and thus tackles the problem of UDA. BFAL can be added to existing methods with minimal code changes. The contributions of our paper are threefold:
1. We define a novel loss for feature alignment called BFAL that does not require large batch sizes and encourages learning non-redundant, domain-agnostic features.
2. We use BFAL to build an end-to-end system called the Barlow Adaptor that performs UDA. We evaluate the effectiveness of this method and compare it with existing UDA methods for instrument classification in cataract surgery images.
3. We motivate new research on methods for generalizable deep learning models for surgical instrument classification, using cataract surgery as the test-bed. Our work proposes a solution to the lack of generalizability of deep learning models that was identified in previous literature on cataract surgery instrument classification.
2 Related Work
Instrument Identification in Cataract Surgery Video Images. The motivation for instrument identification is its utility in downstream tasks such as activity localization and skill assessment [3,8,22]. The current state-of-the-art instrument identification method, called DeepPhase [28], uses a ResNet architecture to identify instruments and then to identify steps in the procedure. However, a recent study has shown that while these methods work well on one dataset, there is a significant drop in performance when they are tested on a different dataset [16]. Our analyses reiterate this drop in performance (Fig. 1) and highlight the effect of domain shift between data from different institutions, even for the same procedure.
Unsupervised Domain Adaptation. UDA is a special case of domain adaptation, where a model has access to annotated training data from a source domain and unannotated data from a target domain [13]. Various methods have been proposed in the literature to perform UDA. One line of research involves aligning the feature distributions of the source and target domains. Maximum Mean Discrepancy (MMD) is commonly used as a distance metric between the source and target distributions [15]. Other UDA methods use a convolutional neural network (CNN) to generate features and then use MMD as an additional loss to align distributions [1,11,12,20,25,27]. While MMD is a first-order statistic, Deep CORAL [17] penalizes the difference in the second-order covariance between the source and target distributions. Our method also performs feature alignment, but enforces a stricter loss function during training. Another line of research for UDA involves adversarial training. The Domain-Adversarial Neural Network (DANN) [5] involves a minimax game, in which one network minimizes the cross entropy loss for classification in the source domain, while the other maximizes the cross entropy loss for domain classification.
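The "first-order" characterization of MMD holds for its linear-kernel form, in which the statistic reduces to the squared distance between the source and target feature means. The NumPy sketch below is our own illustration, not code from [15] or the other cited methods; the function name `linear_mmd2` and the synthetic data are assumptions. It shows that linear-kernel MMD is near zero for two feature sets that share a mean but differ in spread, which is exactly the kind of second-order difference that Deep CORAL targets. Characteristic kernels such as the Gaussian kernel, used by many MMD-based UDA methods, can also detect higher-order differences.

```python
import numpy as np


def linear_mmd2(f_s: np.ndarray, f_t: np.ndarray) -> float:
    """Squared MMD with a linear kernel k(x, y) = x . y.

    With this kernel the mean embeddings are simply the feature means, so the
    statistic equals the squared Euclidean distance between the two means.
    """
    delta = f_s.mean(axis=0) - f_t.mean(axis=0)
    return float(delta @ delta)


rng = np.random.default_rng(0)
source = rng.normal(0.0, 1.0, size=(2000, 8))
# Same mean as the source, but each feature has 3x the standard deviation.
target_wide = 3.0 * rng.normal(0.0, 1.0, size=(2000, 8))
# Same spread as the source, but every feature mean is shifted by 1.
target_shifted = rng.normal(1.0, 1.0, size=(2000, 8))

print(linear_mmd2(source, target_wide))     # small: the means agree
print(linear_mmd2(source, target_shifted))  # roughly 8: sum of squared mean shifts
```

The covariance mismatch in `target_wide` is invisible to this statistic, which is why second-order alignment (as in CORAL) or a richer kernel is needed to capture it.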
A few recent methods generate pseudo labels on the target domain and then train the network on them. One such method is Source Hypothesis Transfer (SHOT) [10], which performs source-free domain adaptation by performing information maximization on the target domain predictions. While CNN-based methods are widely popular for UDA, there are also methods that use the recently proposed Vision Transformer (ViT) [4] along with an ensemble of the UDA losses described above. A recent approach called Cross Domain Transformer (CDTrans) uses cross-domain attention to produce pseudo labels for training and was evaluated on various datasets [24]. Our proposed loss function is effective for both CNN- and ViT-based backbones.
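To make "information maximization on the target domain predictions" concrete, the sketch below computes an information-maximization objective of the kind SHOT [10] describes: the mean per-sample prediction entropy (low when each prediction is confident) plus the negative entropy of the batch-mean prediction (low when predictions are spread across classes). It is a NumPy illustration written by us; the function names are hypothetical, it omits SHOT's other components such as self-supervised pseudo-labeling, and a real implementation would use an autograd framework so the loss can be backpropagated.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax with max-subtraction for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def information_maximization_loss(target_logits: np.ndarray, eps: float = 1e-8) -> float:
    """Entropy term plus diversity term on unlabeled target predictions."""
    p = softmax(target_logits)                                 # (N, K) class probabilities
    l_ent = -np.mean(np.sum(p * np.log(p + eps), axis=1))      # mean per-sample entropy
    p_mean = p.mean(axis=0)                                    # batch-mean prediction
    l_div = np.sum(p_mean * np.log(p_mean + eps))              # negative entropy of the mean
    return float(l_ent + l_div)


# Confident and diverse: each sample strongly predicts a different class.
diverse = 20.0 * np.eye(4)
# Confident but collapsed: every sample predicts class 0.
collapsed = np.tile([20.0, 0.0, 0.0, 0.0], (4, 1))

print(information_maximization_loss(diverse))    # close to -log(4)
print(information_maximization_loss(collapsed))  # close to 0
```

The diversity term is what keeps the objective from rewarding the degenerate solution in which every target image is assigned to the same class.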
3 Proposed Method
In the UDA task, we are given $n_s$ observations from the source domain $\mathcal{D}_S$. Each of these observations is a tuple $(x_s, y_s)$, where $x_s$ denotes an image from the source training data and $y_s$ denotes the corresponding label, i.e., the index of the instrument present in the image. In addition, we are given $n_t$ observations from the target domain $\mathcal{D}_T$. Each of these is an image $x_t$ from the target training data; no labels are available for the target domain during training. The goal of UDA is to predict the labels $y_t$ for the target domain data.
Barlow Feature Alignment Loss (BFAL). We introduce a novel loss that encourages the source and target features to be similar to each other while reducing the redundancy between the learnt features. BFAL works on pairs of feature projections of the source and target. More specifically, let $f_s \in \mathbb{R}^{B \times D}$ and $f_t \in \mathbb{R}^{B \times D}$ be the features corresponding to the source and target domains, respectively, where $B$ is the batch size and $D$ is the feature dimension. Similar to [26], we project these features into a $P$-dimensional space using a fully connected layer called the Projector, followed by batch normalization to whiten the projections. Let the resulting projections be denoted by $p_s \in \mathbb{R}^{B \times P}$ for the source and $p_t \in \mathbb{R}^{B \times P}$ for the target domain. Next, we compute the correlation matrix $C^1 \in \mathbb{R}^{P \times P}$, each element of which is computed as

$$C^1_{ij} = \frac{\sum_{b=1}^{B} p_s^{b,i}\, p_t^{b,j}}{\sqrt{\sum_{b=1}^{B} \left(p_s^{b,i}\right)^2}\,\sqrt{\sum_{b=1}^{B} \left(p_t^{b,j}\right)^2}}. \qquad (1)$$

Finally, the BFAL is computed using the L2 loss between the elements of $C^1$ and the identity matrix $I$ as follows:

$$\mathcal{L}_{BFA} = \underbrace{\sum_{i=1}^{P} \left(1 - C^1_{ii}\right)^2}_{\text{feature alignment}} + \mu \underbrace{\sum_{i=1}^{P} \sum_{j \neq i} \left(C^1_{ij}\right)^2}_{\text{redundancy reduction}} \qquad (2)$$
where $\mu$ is a constant that weights the second term. Intuitively, the first term of the loss function can be thought of as a feature alignment term, since it pushes the diagonal elements of the correlation matrix $C^1$ towards 1. In other words, we encourage the feature projections of the source and target to be perfectly correlated. On the other hand, by pushing the off-diagonal elements towards 0, we decorrelate different components of the projections. Hence, the second term can be considered a redundancy reduction term, since it pushes the components of each feature vector to be decorrelated from one another. BFAL is inspired by a recent self-supervised learning technique called Barlow Twins [26], whose authors show the effectiveness of such a formulation at lower batch sizes. In our experiments, we observe that even a batch size of 16 gives good results compared with other existing methods. Furthermore, BFAL does not require large amounts of data to converge.
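Eqs. (1) and (2) can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' released implementation: the batch normalization after the Projector is approximated by per-dimension standardization without learnable scale and shift, the small `eps` constants are our own additions for numerical safety, the function names are hypothetical, and a random linear map stands in for the learned Projector. The default $\mu = 0.0039$ and the batch size of 16 are the values reported in the Experimental Setup. In training, the same computation would be written in an autograd framework so that gradients reach the feature extractor.

```python
import numpy as np


def batch_standardize(z: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Per-dimension standardization over the batch (BatchNorm-like, no affine)."""
    return (z - z.mean(axis=0)) / np.sqrt(z.var(axis=0) + eps)


def correlation_matrix(p_s: np.ndarray, p_t: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Eq. (1): C1[i, j] = sum_b p_s[b, i] p_t[b, j] / (||p_s[:, i]|| * ||p_t[:, j]||)."""
    num = p_s.T @ p_t                              # (P, P) cross products over the batch
    norm_s = np.sqrt(np.sum(p_s ** 2, axis=0))     # (P,) column norms of the source
    norm_t = np.sqrt(np.sum(p_t ** 2, axis=0))     # (P,) column norms of the target
    return num / (np.outer(norm_s, norm_t) + eps)


def bfal(z_s: np.ndarray, z_t: np.ndarray, mu: float = 0.0039) -> float:
    """Eq. (2): diagonal alignment term plus mu-weighted off-diagonal redundancy term."""
    c = correlation_matrix(batch_standardize(z_s), batch_standardize(z_t))
    diag = np.diag(c)
    alignment = np.sum((1.0 - diag) ** 2)
    redundancy = np.sum(c ** 2) - np.sum(diag ** 2)   # sum over i != j of C1[i, j]^2
    return float(alignment + mu * redundancy)


rng = np.random.default_rng(0)
B, D, P = 16, 128, 64
W = rng.normal(size=(D, P)) / np.sqrt(D)    # stand-in for the learned Projector
f_s = rng.normal(size=(B, D))               # source features
f_t = f_s + 0.1 * rng.normal(size=(B, D))   # target features close to the source
f_far = rng.normal(size=(B, D))             # unrelated target features
print(bfal(f_s @ W, f_t @ W), bfal(f_s @ W, f_far @ W))  # first value is much smaller
```

Because the projections are standardized first, $C^1$ here is the Pearson cross-correlation between projection dimensions, so the diagonal term rewards source and target projections that move together across the batch.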
Fig. 2. Architecture corresponding to the Barlow Adaptor. Training occurs using pairs of images from the source and target domain. They are fed into the feature extractor, which generates features used for the CORAL loss. Further, a projector network P projects the features into a P dimensional space. These are then used to calculate the Barlow Feature Alignment Loss. One branch from the source features goes into the source classifier network that is used to compute the cross entropy loss with the labeled source data. [Backprop = backpropagation; src = source dataset, tgt = target dataset]
Barlow Adaptor. We propose an end-to-end method that utilizes data from the labeled source domain and the unlabeled target domain. The architecture of our method is shown in Fig. 2. It has two main sub-parts: the Feature Extractor $F$ and the Source Classifier $C$. First, we randomly divide the training images into batches of pairs $\{x_s, x_t\}$ and apply $F$ to them, which gives us the features extracted from these sets of images. For the Feature Extractor, we show the effectiveness of our novel loss using ViT and ResNet50, both of which have been pre-trained on ImageNet. The features obtained are denoted as $f_s$ and $f_t$ for the source and target domains, respectively. Next, we apply $C$ to these features to get logits for the classification task. The source classifier is a feed-forward neural network that is initialized from scratch. These logits are used, along with the source labels $y_s$, to compute the source cross entropy loss

$$\mathcal{L}_{CE} = -\frac{1}{B} \sum_{b=1}^{B} \sum_{m=1}^{M} y_s^{bm} \log\left(\hat{y}_s^{bm}\right),$$

where $M$ is the number of classes, $B$ is the batch size, $b$ and $m$ index the samples and classes, respectively, and $\hat{y}_s^{bm}$ is the predicted (softmax) probability of class $m$ for sample $b$. The features $f_s$ and $f_t$ are further used to compute the Correlation Alignment (CORAL) loss and the BFAL, which force the feature extractor to learn features that are domain agnostic as well as non-redundant. The BFAL is calculated as described in the previous subsection. The CORAL loss is computed as in Eq. (4), following the UDA method Deep CORAL [17]. While the BFAL focuses on reducing redundancy, CORAL works by aligning the distributions of the source and target domain data. This is achieved by taking the difference between the covariance matrices of the source and target features, $f_s$ and $f_t$, respectively. The final loss is the weighted sum of the three individual losses:

$$\mathcal{L}_{final} = \mathcal{L}_{CE} + \lambda\left(\mathcal{L}_{CORAL} + \mathcal{L}_{BFA}\right), \qquad (3)$$

where

$$\mathcal{L}_{CORAL} = \frac{1}{4D^2}\left\lVert C_s - C_t \right\rVert_F^2, \qquad C_s = \frac{1}{B-1}\left(f_s^{\top} f_s - \frac{1}{B}\left(\mathbf{1}^{\top} f_s\right)^{\top}\left(\mathbf{1}^{\top} f_s\right)\right), \qquad (4)$$

$$C_t = \frac{1}{B-1}\left(f_t^{\top} f_t - \frac{1}{B}\left(\mathbf{1}^{\top} f_t\right)^{\top}\left(\mathbf{1}^{\top} f_t\right)\right). \qquad (5)$$

Each of these three losses plays a different role in the UDA task. The cross entropy loss encourages the model to learn features that discriminate between images with different instruments. The CORAL loss pushes the source and target features towards having similar distributions. Finally, the BFAL tries to make the source and target features non-redundant and identical. BFAL is a stricter loss than CORAL, as it forces features not only to have the same distribution but also to be equal. Further, it differs from CORAL in learning decorrelated features, as it explicitly penalizes non-zero off-diagonal entries in the correlation matrix. While using BFAL alone gives good results, using it in addition to CORAL gives slightly better results empirically, as noted in our ablation studies. Between the cross entropy loss and the BFAL, an adversarial game is played, where the former makes the features more discriminative and the latter tries to make them equal. The optimal features thus learnt differ in the aspects required to identify instruments but are equal for any domain-related aspect. This property of the Barlow Adaptor is especially useful for surgical domains, where the background has similar characteristics for most of the images within a domain. For example, in cataract surgery images, the position of the pupil or the presence of blood during the use of certain instruments might be used by the model for classification along with the instrument features. These features depend highly on the surgical procedure and the skill of the surgeon, which makes them highly domain-specific and possibly unavailable in the target domain. Using BFAL during training attempts to prevent the model from learning such features.
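The remaining pieces of Eq. (3) can be sketched the same way. The NumPy code below implements the covariance of Eqs. (4) and (5), the CORAL loss with its $1/(4D^2)$ scaling, a softmax cross entropy that uses integer labels (equivalent to the one-hot $y_s^{bm}$ above), and the weighted sum with $\lambda = 0.001$ as reported in the Experimental Setup. The function names are our own, the BFAL value is passed in as a number (it would come from a BFAL computation such as Eq. (2)), and an actual training loop would compute these terms with an autograd framework.

```python
import numpy as np


def covariance(f: np.ndarray) -> np.ndarray:
    """Eqs. (4)-(5): (f^T f - (1^T f)^T (1^T f) / B) / (B - 1) for f of shape (B, D)."""
    b = f.shape[0]
    col_sum = np.ones((1, b)) @ f                 # (1, D) row vector 1^T f
    return (f.T @ f - col_sum.T @ col_sum / b) / (b - 1)


def coral_loss(f_s: np.ndarray, f_t: np.ndarray) -> float:
    """Squared Frobenius distance between covariances, scaled by 1 / (4 D^2)."""
    d = f_s.shape[1]
    diff = covariance(f_s) - covariance(f_t)
    return float(np.sum(diff ** 2) / (4 * d ** 2))


def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean negative log-likelihood of integer labels under a softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-np.mean(log_probs[np.arange(len(labels)), labels]))


def total_loss(ce: float, coral: float, bfa: float, lam: float = 0.001) -> float:
    """Eq. (3): L_final = L_CE + lambda * (L_CORAL + L_BFA)."""
    return ce + lam * (coral + bfa)


rng = np.random.default_rng(1)
f_s = rng.normal(size=(16, 8))                    # source features, batch of 16
f_t = 2.0 * rng.normal(size=(16, 8))              # target features with a larger spread
logits = rng.normal(size=(16, 14))                # 14 instrument classes
labels = rng.integers(0, 14, size=16)
print(total_loss(cross_entropy(logits, labels), coral_loss(f_s, f_t), bfa=0.5))
```

Note that `covariance` matches NumPy's own unbiased estimate `np.cov(f, rowvar=False)`, which is a convenient check when porting the loss to another framework.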
4 Experiments and Results
We evaluate the proposed UDA method for the task of instrument classification using two cataract surgery image datasets. In our experiments, one dataset is used as the source domain and the other is used as the target domain. We use micro and macro accuracies as our evaluation metrics. Micro accuracy denotes the number of correctly classified observations divided by the total number of observations. In contrast, macro accuracy denotes the average of the class-wise accuracies and is effective in evaluating classes with fewer samples.
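The difference between the two metrics is easiest to see on an imbalanced toy example. The plain-Python sketch below follows the definitions above; the function names and the data are illustrative. Class-wise accuracy is taken as the fraction of each true class that is predicted correctly, averaged over the classes present in the ground truth.

```python
from collections import defaultdict


def micro_accuracy(y_true, y_pred):
    """Correctly classified observations divided by the total number of observations."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def macro_accuracy(y_true, y_pred):
    """Average of per-class accuracies over the classes present in y_true."""
    total = defaultdict(int)
    correct = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        correct[t] += int(t == p)
    return sum(correct[c] / total[c] for c in total) / len(total)


y_true = [0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 0, 0]  # always predicts the majority class
print(micro_accuracy(y_true, y_pred))  # 0.75
print(macro_accuracy(y_true, y_pred))  # 0.5
```

A classifier that ignores the rare class still scores 0.75 micro accuracy here, while macro accuracy exposes the complete failure on class 1.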
Table 1. Mapping of surgical tools between CATARACTS (left) and D99 (right)

CATARACTS | D99
Secondary Incision Knife | Paracentesis Blade
Bonn Forceps | 0.12 Forceps
Charleux Cannula | Anterior Chamber Cannula
Irrigation | Irrigation
Capsulorhexis Forceps | Utrata Forceps
Cotton | Weckcell Sponge
Hydrodissection Cannula | Hydrodissection Cannula
Implant Injector | IOL Injector
Phacoemulsifier Handpiece | Phaco Handpiece
Suture Needle | Suture
Capsulorhexis Cystotome | Cystotome
Needle Holder | Needle Driver
Primary Incision Knife | Keratome
Micromanipulator | Chopper
Datasets. The first dataset we use is CATARACTS [6], which consists of 50 videos with framewise annotations available for 21 surgical instruments. The dataset is divided into 25 training videos and 25 testing videos. We separate 5 videos from the training set and use them as the validation set for our experiments. The second dataset, called D99 in this work [7,9], consists of 105 videos of cataract surgery with annotations for 25 surgical instruments. Of the 105 videos, we use 65 for training, 10 for validation, and 30 for testing. We observe a significant distribution shift between the two datasets, as seen in Fig. 1. This is caused by several factors, such as lighting, camera resolution, and differences in the instruments used for the same steps. For our experiments, we use the 14 instrument classes that are common to both datasets. Table 1 shows the mapping of instruments between the two datasets. For each dataset, we normalize the images using the means and standard deviations calculated from the respective training images. In addition, we resize all images to 224 × 224 and apply random horizontal flipping with a probability of 0.5 before passing them to the model.
Experimental Setup. We train the Barlow Adaptor for multi-class classification with the above-mentioned 14 classes in PyTorch. For the ResNet50 backbone, we use weights pretrained on ImageNet [14] for initialization. For the ViT backbone, we use the base-224 class of weights from the TIMM library [23]. The Source Classifier C and the Projector P are randomly initialized. We use the validation sets to select the hyperparameters for the models. Based on these empirical results, we choose λ in Eq. 3 to be 0.001 and μ in Eq. 2 to be 0.0039. We use SGD as the optimizer with a momentum of 0.9 and a batch size of 16. We start the training with a learning rate of 0.001 and multiply it by 0.33 every 20 epochs.
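The preprocessing and learning-rate schedule described above can be sketched with NumPy as follows. The function names and the random stand-in images are assumptions for illustration; resizing to 224 × 224 is omitted because it needs an image library. The key detail is that normalization statistics are computed separately for each dataset from its own training images.

```python
import numpy as np


def channel_stats(images: np.ndarray):
    """Per-channel mean and std over a dataset's training images, shape (N, H, W, C)."""
    return images.mean(axis=(0, 1, 2)), images.std(axis=(0, 1, 2))


def normalize(image: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Standardize an image with its own dataset's channel statistics."""
    return (image - mean) / std


def random_hflip(image: np.ndarray, rng: np.random.Generator, p: float = 0.5) -> np.ndarray:
    """Flip an (H, W, C) image left-right with probability p."""
    return image[:, ::-1, :] if rng.random() < p else image


def step_lr(epoch: int, base_lr: float = 1e-3, gamma: float = 0.33, step_size: int = 20) -> float:
    """Learning rate multiplied by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


rng = np.random.default_rng(0)
train_a = rng.uniform(0, 255, size=(10, 8, 8, 3))  # stand-in for one dataset's training images
mean_a, std_a = channel_stats(train_a)              # statistics are computed per dataset
x = random_hflip(normalize(train_a[0], mean_a, std_a), rng)
print(x.shape, [step_lr(e) for e in (0, 19, 20, 40)])
```

In PyTorch, the same setup would typically be expressed with `torchvision.transforms` (`Resize`, `RandomHorizontalFlip(p=0.5)`, `Normalize`), `torch.optim.SGD(params, lr=0.001, momentum=0.9)`, and `torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.33)`; the paper does not state which classes were used, so this mapping is our assumption.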
The entire setup is trained on a single NVIDIA Quadro RTX 8000 GPU. We use the same set of hyperparameters for the CNN and ViT backbones on both datasets.
Results. Table 2 compares the performance of the Barlow Adaptor with recent UDA methods. We highlight the effect of domain shift by comparing the source-only models and the target-only models, where we observe a significant drop of 27% and 43% in macro accuracy for the CATARACTS dataset and the D99 dataset, respectively. Using the Barlow Adaptor, we observe an increase in macro accuracy of 7.2% over the source-only model when the source is D99 and the target is CATARACTS. Similarly, we observe an increase in macro accuracy of 9% with the Barlow Adaptor when the source is CATARACTS and the target is D99, compared with the source-only model. Furthermore, estimates of macro and micro accuracy are larger with the Barlow Adaptor than those with other existing methods. Finally, improved accuracy with the Barlow Adaptor is seen with both ResNet and ViT backbones.

Table 2. Macro and micro accuracies for cross-domain tool classification. Here, source-only denotes models that have only been trained on one domain and tested on the other. Similarly, target-only denotes models that have been trained on the test domain and thus act as an upper bound. Deep CORAL [17] is similar to using CORAL with a ResNet backbone, so we do not list the latter separately. Here, CAT represents the CATARACTS dataset.

Method | D99 → CAT Macro Acc | D99 → CAT Micro Acc | CAT → D99 Macro Acc | CAT → D99 Micro Acc
Source Only (ResNet50 backbone) | 27.9% | 14.9% | 14.25% | 16.9%
MMD with ResNet50 backbone [15] | 32.2% | 15.9% | 20.6% | 24.3%
Source Only (ViT backbone) | 30.43% | 14.14% | 13.99% | 17.11%
MMD with ViT backbone [15] | 31.32% | 13.81% | 16.42% | 20%
CORAL with ViT backbone [17] | 28.7% | 16.5% | 15.38% | 18.5%
DANN [5] | 22.4% | 11.6% | 16.7% | 19.5%
Deep CORAL [17] | 32.8% | 14% | 18.6% | 22%
CDTrans [24] | 29.1% | 14.7% | 20.9% | 24.7%
Barlow Adaptor with ResNet50 (Ours) | 35.1% | 17.1% | 24.62% | 28.13%
Barlow Adaptor with ViT (Ours) | 31.91% | 12.81% | 17.35% | 20.8%
Target Only (ResNet50) | 67.2% | 57% | 62.2% | 55%
Target Only (ViT) | 66.33% | 56.43% | 60.46% | 49.80%

Table 3. Findings from ablation studies to evaluate the Barlow Adaptor. Here, Source Only is the case where neither CORAL nor BFAL is used. We use macro accuracy for comparison. Here, CAT represents the CATARACTS dataset.

Method | ViT: D99 → CAT | ViT: CAT → D99 | ResNet50: D99 → CAT | ResNet50: CAT → D99
Source Only ($\mathcal{L}_{CE}$) | 30.43% | 16.7% | 27.9% | 14.9%
Only CORAL ($\mathcal{L}_{CORAL}$) | 28.7% | 15.38% | 32.8% | 18.6%
Only BFAL ($\mathcal{L}_{BFA}$) | 29.8% | 17.01% | 32.3% | 24.46%
Barlow Adaptor (Eq. 3) | 32.1% | 17.35% | 35.1% | 24.62%
Ablation Study. We evaluated the performance gain due to each part of the Barlow Adaptor. Specifically, the Barlow Adaptor has the CORAL loss and the BFAL as its two major feature alignment losses. We remove one component at a time and observe a decrease in performance with both ResNet and ViT backbones (Table 3). This shows that each loss contributes to domain adaptation. Further ablations are included in the supplementary material.
5 Conclusion
Domain shift between datasets of cataract surgery images limits the generalizability of deep learning methods for surgical instrument classification. We address this limitation using an end-to-end UDA method called the Barlow Adaptor. As part of this method, we introduce a novel loss function for feature alignment called the BFAL. Our evaluation of the method shows larger improvements in classification performance compared with other state-of-the-art UDA methods. BFAL is an independent module and can be readily integrated into other methods as well. It can also be easily extended to other network layers and architectures, as it only takes pairs of features as inputs.
Acknowledgement. This research was supported by a grant from the National Institutes of Health, USA; R01EY033065. The content is solely the responsibility of the authors and does not necessarily represent the official views of the National Institutes of Health.
References
1. Baktashmotlagh, M., Harandi, M., Salzmann, M.: Distribution-matching embedding for visual domain adaptation. J. Mach. Learn. Res. 17(1), 3760–3789 (2016)
2. Bouget, D., Allan, M., Stoyanov, D., Jannin, P.: Vision-based and marker-less surgical tool detection and tracking: a review of the literature. Med. Image Anal. 35, 633–654 (2017)
3. Demir, K., Schieber, H., Weise, T., Roth, D., Maier, A., Yang, S.: Deep learning in surgical workflow analysis: a review (2022)
4. Dosovitskiy, A., et al.: An image is worth 16×16 words: transformers for image recognition at scale. In: International Conference on Learning Representations (2021)
5. Ganin, Y., Lempitsky, V.: Unsupervised domain adaptation by backpropagation. In: Proceedings of the 32nd International Conference on International Conference on Machine Learning, ICML 2015, vol. 37, pp. 1180–1189. JMLR.org (2015)
6. Hajj, H., et al.: CATARACTS: challenge on automatic tool annotation for cataract surgery. Med. Image Anal. 52, 24–41 (2018)
7. Hira, S.: Video-based assessment of intraoperative surgical skill. Comput.-Assist. Radiol. Surg. 17(10), 1801–1811 (2022)
8. Josef, L., James, W., Michael, S.: Evolution and applications of artificial intelligence to cataract surgery. Ophthalmol. Sci. 2, 100164 (2022)
9. Kim, T., O’Brien, M., Zafar, S., Hager, G., Sikder, S., Vedula, S.: Objective assessment of intraoperative technical skill in capsulorhexis using videos of cataract surgery. Comput.-Assist. Radiol. Surg. 14(6), 1097–1105 (2019)
10. Liang, J., Hu, D., Feng, J.: Do we really need to access the source data? Source hypothesis transfer for unsupervised domain adaptation. In: Daumé III, H., Singh, A. (eds.) Proceedings of the 37th International Conference on Machine Learning. Proceedings of Machine Learning Research, vol. 119, pp. 6028–6039. PMLR (2020)
11. Long, M., Wang, J., Ding, G., Sun, J., Yu, P.S.: Transfer feature learning with joint distribution adaptation. In: 2013 IEEE International Conference on Computer Vision, pp. 2200–2207 (2013)
12. Pan, S.J., Tsang, I.W., Kwok, J.T., Yang, Q.: Domain adaptation via transfer component analysis. IEEE Trans. Neural Netw. 22(2), 199–210 (2011)
13. Patel, V.M., Gopalan, R., Li, R., Chellappa, R.: Visual domain adaptation: a survey of recent advances. IEEE Signal Process. Mag. 32(3), 53–69 (2015)
14. Russakovsky, O., et al.: ImageNet large scale visual recognition challenge. Int. J. Comput. Vision 115(3), 211–252 (2015). https://doi.org/10.1007/s11263-015-0816-y
15. Schölkopf, B., Platt, J., Hofmann, T.: A kernel method for the two-sample-problem, pp. 513–520 (2007)
16. Sokolova, N., Schoeffmann, K., Taschwer, M., Putzgruber-Adamitsch, D., El-Shabrawi, Y.: Evaluating the generalization performance of instrument classification in cataract surgery videos. In: Ro, Y.M., et al. (eds.) MMM 2020. LNCS, vol. 11962, pp. 626–636. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-37734-2_51
17. Sun, B., Saenko, K.: Deep CORAL: correlation alignment for deep domain adaptation, pp. 443–450 (2016)
18. Trikha, S., Turnbull, A., Morris, R., Anderson, D., Hossain, P.: The journey to femtosecond laser-assisted cataract surgery: new beginnings or a false dawn? Eye (London, England) 27 (2013)
19. Twinanda, A.P., Shehata, S., Mutter, D., Marescaux, J., de Mathelin, M., Padoy, N.: EndoNet: a deep architecture for recognition tasks on laparoscopic videos. IEEE Trans. Med. Imaging 36(1), 86–97 (2017)
20. Tzeng, E., Hoffman, J., Zhang, N., Saenko, K., Darrell, T.: Deep domain confusion: maximizing for domain invariance (2014)
21. Vedula, S.S., et al.: Artificial intelligence methods and artificial intelligence-enabled metrics for surgical education: a multidisciplinary consensus. J. Am. Coll. Surg. 234(6), 1181–1192 (2022)
22. Ward, T.M., et al.: Computer vision in surgery. Surgery 169(5), 1253–1256 (2021)
23. Wightman, R.: PyTorch image models. https://github.com/rwightman/pytorch-image-models (2019)
24. Xu, T., Chen, W., Wang, P., Wang, F., Li, H., Jin, R.: CDTrans: cross-domain transformer for unsupervised domain adaptation (2021)
25. Yan, H., Ding, Y., Li, P., Wang, Q., Xu, Y., Zuo, W.: Mind the class weight bias: weighted maximum mean discrepancy for unsupervised domain adaptation. In: 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 945–954 (2017)
26. Zbontar, J., Jing, L., Misra, I., LeCun, Y., Deny, S.: Barlow twins: self-supervised learning via redundancy reduction (2021)
27. Zhong, E., et al.: Cross domain distribution adaptation via kernel mapping. In: Proceedings of the 15th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, KDD 2009, pp. 1027–1036. Association for Computing Machinery, New York (2009)
28. Zisimopoulos, O., et al.: DeepPhase: surgical phase recognition in CATARACTS videos (2018)
Black-box Domain Adaptative Cell Segmentation via Multi-source Distillation
Xingguang Wang1, Zhongyu Li1(B), Xiangde Luo2, Jing Wan1, Jianwei Zhu1, Ziqi Yang1, Meng Yang3, and Cunbao Xu4(B)
1 School of Software Engineering, Xi’an Jiaotong University, Xi’an, China
2 School of Mechanical and Electrical Engineering, University of Electronic Science and Technology of China, Chengdu, China
3 Hunan Frontline Medical Technology Co., Ltd., Changsha, China
4 Department of Pathology, Quanzhou First Hospital Affiliated to Fujian Medical University, Quanzhou, China
Abstract. Cell segmentation plays a critical role in diagnosing various cancers. Although deep learning techniques have been widely investigated, the enormous number of cell types and the diverse appearances of histopathological cells still pose significant challenges for clinical applications. Moreover, data protection policies in different clinical centers and hospitals limit the training of data-dependent deep models. In this paper, we present a novel framework for cross-tissue domain adaptative cell segmentation without access to either the source domain data or the model parameters, namely Multi-source Black-box Domain Adaptation (MBDA). Given the target domain data, our framework achieves cell segmentation via knowledge distillation, using only the outputs of models trained on multiple source domains. Considering the domain shift across different pathological tissues, predictions from the source models may not be reliable, and the resulting noisy labels can limit the training of the target model. To address this issue, we propose two practical approaches for weighting knowledge from the multi-source model predictions and filtering out noisy predictions. First, we assign pixel-level weights to the outputs of source models to reduce uncertainty during knowledge distillation. Second, we design a pseudo-label cutout and selection strategy for these predictions to facilitate knowledge distillation from local cells to global pathological images. Experimental results on four types of pathological tissues demonstrate that our proposed black-box domain adaptation approach achieves comparable or even better performance than state-of-the-art white-box approaches. The code and dataset are released at: https://github.com/NeuronXJTU/MBDA-CellSeg.
Keywords: Multi-source domain adaptation · Black-box model · Cell segmentation · Knowledge distillation
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023. H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 749–758, 2023. https://doi.org/10.1007/978-3-031-43907-0_71
1 Introduction
Semantic segmentation plays a vital role in pathological image analysis. It supports cell counting, cell morphology analysis, and tissue analysis, which reduces human labor [19]. However, data acquisition for medical images poses unique challenges due to privacy concerns and the high cost of manual annotation. Moreover, pathological images from different tissues or cancer types often show significant domain shifts, which hamper the generalization of models trained on one dataset to others. To address these challenges, researchers have proposed various white-box domain adaptation methods. Recently, [8,16] proposed using generative adversarial networks to align the distributions of the source and target domains and to generate source-domain lookalike outputs for target images. Source-free domain adaptation methods have also been widely investigated for privacy protection: [3,5,14] explore how to implicitly align target domain data with the model trained on the source domain without accessing the source domain data. There are also many studies on multi-source white-box domain adaptation. Ahmed et al. [1] propose a novel algorithm that automatically identifies the optimal blend of source models to generate the target model by optimizing a specifically designed unsupervised loss. Li et al. [13] extend this work to semantic segmentation and propose a method named model-invariant feature learning, which takes full advantage of the diverse characteristics of the source-domain models. Nonetheless, several recent investigations have demonstrated that source-free domain adaptation methods built on white-box models still present a privacy risk due to the potential leakage of model parameters [4]. Such privacy breaches may be detrimental to the privacy protection policies of hospitals.
Moreover, in these methods the target domain uses the same neural network as the source domain, which is not desirable for low-resource target users such as hospitals [15]. We thus present a more challenging task that relies solely on black-box models from vendors to avoid parameter leakage. In clinical applications, various vendors can offer output interfaces for different pathological images. While black-box models are proficient in specific domains, their performance degrades greatly when the target domain is updated with new pathology slices. Therefore, how to leverage the existing knowledge of black-box models to effectively train new models for the target domain without accessing the source domain data remains a critical challenge. In this paper, we present a novel source-free domain adaptation framework for cross-tissue cell segmentation without accessing either the source domain data or the model parameters, which can seamlessly integrate heterogeneous models from different source domains into any cell segmentation network with high generality. To the best of our knowledge, this is the first study exploring multi-source black-box domain adaptation for cross-tissue cell segmentation. In this setting, conventional multi-source ensemble methods are not applicable due to the unavailability of model parameters, and simply aggregating the black-box outputs would introduce a considerable amount of noise, which can be detrimental to the training of the target domain model. Therefore, we develop two strategies within this new framework to address this issue. First, we propose a pixel-level multi-source domain weighting method, which reduces source domain noise by knowledge weighting. This method effectively addresses two significant challenges encountered in the analysis of cellular images, namely the uncertainty in source domain outputs and the ambiguity of cell boundary semantics. Second, we take into account the structured information from cells to images, which may be overlooked during distillation, and design an adaptive knowledge voting strategy. This strategy enables us to ignore low-confidence regions, similar to Cutout [6], but with selective masking of pixels, which effectively balances the trade-off between exploiting similarities and preserving differences across domains. Accordingly, we refer to the labels generated through the voting strategy as pseudo-cutout labels.

[Figure 1: framework overview. Multi-source black-box models produce logits maps; pixel-level weights derived from prediction uncertainty and boundary ambiguity yield a weighted logits map; an adaptive vote with confidence thresholds produces a pseudo-cutout label; a teacher model (updated by EMA) and a student model are trained on labeled and unlabeled target-domain data.]
Fig. 1. Overview of our proposed framework, where logits maps denote the raw predictions from source models and ω denotes the pixel-level weight for each prediction. The semi-supervised loss, denoted as $\mathcal{L}_{ssl}$, encompasses the supervised loss, the consistency loss, and the mutual information maximization loss.
2 Method
Overview: Figure 1 shows a binary cell segmentation task with three source models trained on different tissues and a target model, i.e., the student model in Fig. 1. We use only the source models' predictions on the target data for knowledge transfer, without accessing the source data or parameters. The symbols η and η′ denote different perturbations added to the target images. We feed the perturbed images into the source-domain predictors to generate the corresponding raw segmentation outputs. These outputs are then processed by the two main components of our framework: a pixel-level weighting method that accounts for prediction uncertainty and cell boundary ambiguity, and an adaptive knowledge voter that uses confidence gates and a dynamic ensemble strategy. Both components are designed to extract reliable knowledge from the source-domain predictions and to reduce noise during distillation. Finally, we obtain a weighted logits map for pixel-level knowledge distillation and a high-confidence pseudo-cutout label for further structured distillation, from individual cells to the global pathological image.

X. Wang et al.

Knowledge Distillation by Weighted Logits Map: We denote by $D_S^N = \{X_s, Y_s\}^N$ a collection of N source domains and by $D_T = \{X_t^i, Y_t^j\}$ a single target domain, where the number of labeled instances is much smaller than the number of images, $Y_t^j \ll X_t^i$. We are provided only with black-box models $\{f_s^n\}_{n=1}^N$ trained on the multiple source domains $\{x_s^i, y_s^i\}_{n=1}^N$ for knowledge transfer. Under the privacy policy, the parameters $\{\theta_s^n\}_{n=1}^N$ of these source-domain predictors cannot participate in gradient backpropagation. Our objective is thus to derive a new student model $f_t: X_t \to Y_t$ for the same task as the source domains. Direct knowledge transfer using the outputs of the source-domain predictors may bias the features of the student model because of the unavoidable covariate shift [20] between the target and source domains. Inspired by [21], we combine prediction uncertainty and cell boundary impurity to establish pixel-level weights for the multi-source outputs. We treat the k-square neighborhood of a pixel as a cell region; i.e., for a logits map with height H and width W, we define the region as follows:

$$N_k(i,j) = \{(u,v) \mid |u-i| \le k,\ |v-j| \le k\}, \quad (i,j) \in (H,W) \tag{1}$$
where (i, j) denotes the centre of the region and k controls its size (a (2k+1) × (2k+1) window). First, we develop a pixel-level predictive uncertainty measure to help assess the correlation between the multiple source domains and the target domain. For a given target image $x_t \in X_t^i$, we first feed it into the source predictors $\{f_s^n\}_{n=1}^N$ to obtain their respective predictions $\{p_s^n\}_{n=1}^N$. To leverage the rich semantic information in these predictions, we use the predictive entropy of the softmax outputs to measure prediction uncertainty. For C-class semantic segmentation, we define the pixel-level uncertainty score $U_n^{(i,j)}$ as follows:

$$U_n^{(i,j)} = -\sum_{c=1}^{C} O_s^{n(i,j,c)} \log O_s^{n(i,j,c)} \tag{2}$$
where $O_s^n = \mathrm{softmax}(p_s^n)$ denotes the softmax output of the n-th source predictor. Because of the unique characteristics of cell morphology, uncertainty information alone is insufficient to produce a high-quality ensemble logits map that accurately captures the relevance between the source and target domains. The target pseudo-label for the n-th predictor $f_s^n$ is obtained by applying the softmax function to its output and selecting the category with the highest probability, i.e., $\hat Y_t = \arg\max_{c \in \{1,\dots,C\}} \mathrm{softmax}(p_s^n)$. For a C-class task, we then divide the cell region into C subsets, $N_k^c(i,j) = \{(u,v) \in N_k(i,j) \mid \hat Y_t^{(u,v)} = c\}$. Next, we measure the impurity of each region from the statistics of its labels, which reflects the level of semantic ambiguity near cell boundaries: the more mixed the classes within a region, the higher its impurity. The boundary impurity $P_n^{(i,j)}$ is calculated as:

$$P_n^{(i,j)} = -\sum_{c=1}^{C} \frac{|N_k^c(i,j)|}{|N_k(i,j)|} \log \frac{|N_k^c(i,j)|}{|N_k(i,j)|} \tag{3}$$
where |·| denotes the number of pixels in the area. By assigning lower weights to pixels with high uncertainty and boundary ambiguity, we obtain a pixel-level weight map $W^n$ for each $p_s^n$:

$$W^n = -\log \frac{\exp(U_n \odot P_n)}{\sum_{n'=1}^{N} \exp(U_{n'} \odot P_{n'})} \tag{4}$$

where ⊙ denotes element-wise matrix multiplication. Using these pixel-level weights, we obtain the ensemble logits map $M = \sum_{n=1}^{N} W^n \cdot p_s^n$. The knowledge distillation objective is a classical regularization term [9]:

$$L_{kd}(f_t; X_t, M) = \mathbb{E}_{x_t \in X_t}\, D_{kl}\big(M \,\|\, f_t(x_t)\big) \tag{5}$$
where $D_{kl}$ denotes the Kullback-Leibler (KL) divergence loss.

Adaptive Pseudo-Cutout Label: Although the outputs of the source-domain black-box predictors have been adjusted by the pixel-level weights, they remain noisy, and only pixel-level information is considered while structured information is ignored during knowledge distillation. We therefore use the outputs of the black-box predictors on the target domain to produce an adaptive pseudo-cutout label, which further regularizes the knowledge distillation process. We modify the method in [7] to generate high-quality pseudo-labels that resemble the Cutout augmentation technique. For the softmax outputs $\{O_s^n\}_{n=1}^N$ of the N source predictors, we first set a threshold α to filter out low-confidence pixels: pixels whose normalized probability falls below the threshold are discarded in a Cutout-like operation. We then apply an adaptive voting strategy to the N source-domain outputs. Early in the training of the target model, a pixel is kept as a positive or negative sample if at least one source-domain output exceeds the threshold, which helps the model acquire knowledge rapidly. As training progresses, we gradually tighten the voting strategy and retain only the pixels that receive enough votes. The strategy can be summarized as follows:

$$V_n^{(i,j)} = \begin{cases} 1, & O_s^n(i,j) > \alpha, \\ 0, & \text{otherwise}, \end{cases} \qquad (i,j) \in (H,W) \tag{6}$$

where α is empirically set to 0.9.
We then aggregate the votes, $V^{(i,j)} = \sum_{n=1}^{N} V_n^{(i,j)}$, and decide whether to retain each pixel using an adaptive vote gate $G \in \{1, 2, 3, \dots\}$ (a pixel is retained when $V^{(i,j)} \ge G$). By combining confidence thresholding with the voting strategy, we generate high-confidence pseudo-labels that remain effective even under covariate shift between the source and target domains. Finally, we define the ensemble result as the pseudo-cutout label $\hat P_s$ and apply consistency regularization:

$$L_{pcl}(f_t; X_t, \hat P_s) = \mathbb{E}_{x_t \in X_t}\, l_{ce}\big(\hat P_s \,\|\, f_t(x_t)\big) \tag{7}$$
where $l_{ce}$ denotes the cross-entropy loss.

Loss Functions: Finally, we incorporate global structural information about the target-domain predictions into both distillation and semi-supervised learning. To mitigate the noise of the source-domain predictors, we introduce a mutual-information maximization objective that encourages the network to learn discrete representations. Defining $\mathcal{E}(p) = -\sum_i p_i \log p_i$ as the entropy, the objective is:

$$L_{mmi}(f_t; X_t) = H(Y_t) - H(Y_t \mid X_t) = \mathcal{E}\big(\mathbb{E}_{x_t \in X_t} f_t(x_t)\big) - \mathbb{E}_{x_t \in X_t}\, \mathcal{E}\big(f_t(x_t)\big) \tag{8}$$

where increasing $H(Y_t)$ and decreasing $H(Y_t \mid X_t)$ help balance class separation and classifier complexity [15]. We adopt the classical and effective mean-teacher framework as the semi-supervised baseline and update the teacher parameters by exponential moving average (EMA). We apply two different perturbations (η, η′) to the target-domain data and feed the perturbed images into the student model (parameters $\theta_t$) and the mean-teacher model (parameters $\theta'_t$), respectively. The unsupervised consistency loss is defined as:

$$L_{cons}(\theta_t, \theta'_t) = \mathbb{E}_{x_t \in X_t} \big\| f_t(x_t; \theta'_t, \eta') - f_t(x_t; \theta_t, \eta) \big\|^2 \tag{9}$$

The overall objective is:

$$L_{all} = L_{kd} + L_{pcl} + L_{cons} - L_{mmi} + L_{sup} \tag{10}$$

where $L_{sup}$ denotes the standard cross-entropy loss for supervised learning; each loss term is given a weight of 1 during training.
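The pixel-level weighting and voting steps can be made concrete with a small sketch. The Python below, written on plain nested lists rather than tensors, computes the uncertainty map of Eq. (2), the boundary impurity of Eqs. (1) and (3), the weights of Eq. (4), the weighted logits map M, and a pseudo-cutout label from Eq. (6) with a vote gate G. It is an illustrative reading of the equations, not the authors' code: the C × H × W layout, clipping windows at the image border, resolving confident votes by majority, and the linear `vote_gate` schedule are all our assumptions, and every function name is hypothetical.

```python
import math

# Sketch of Eqs. (1)-(4), the weighted logits map M, and the vote of Eq. (6)
# on plain nested lists. Assumed layout: logits[c][i][j] is a C x H x W map
# from one black-box source; all_logits holds N such maps.


def softmax_at(logits, i, j):
    """Class probabilities at pixel (i, j) (numerically stabilised softmax)."""
    vals = [logits[c][i][j] for c in range(len(logits))]
    top = max(vals)
    exps = [math.exp(v - top) for v in vals]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0)


def uncertainty_map(logits):
    """Eq. (2): predictive entropy U_n at every pixel."""
    H, W = len(logits[0]), len(logits[0][0])
    return [[entropy(softmax_at(logits, i, j)) for j in range(W)] for i in range(H)]


def impurity_map(logits, k):
    """Eqs. (1) and (3): entropy of the pseudo-label histogram inside the
    (2k+1) x (2k+1) window around each pixel. Windows are clipped at the
    image border, which is our assumption."""
    C, H, W = len(logits), len(logits[0]), len(logits[0][0])
    labels = [[max(range(C), key=lambda c: logits[c][i][j]) for j in range(W)]
              for i in range(H)]
    impurity = [[0.0] * W for _ in range(H)]
    for i in range(H):
        for j in range(W):
            counts = [0] * C
            for u in range(max(0, i - k), min(H, i + k + 1)):
                for v in range(max(0, j - k), min(W, j + k + 1)):
                    counts[labels[u][v]] += 1
            size = sum(counts)
            impurity[i][j] = entropy([c / size for c in counts])
    return impurity


def pixel_weights(uncertainties, impurities):
    """Eq. (4): W^n = -log softmax over sources of U_n * P_n, per pixel."""
    N, H, W = len(uncertainties), len(uncertainties[0]), len(uncertainties[0][0])
    weights = [[[0.0] * W for _ in range(H)] for _ in range(N)]
    for i in range(H):
        for j in range(W):
            s = [uncertainties[n][i][j] * impurities[n][i][j] for n in range(N)]
            top = max(s)
            log_z = top + math.log(sum(math.exp(x - top) for x in s))
            for n in range(N):
                weights[n][i][j] = log_z - s[n]  # larger U*P -> smaller weight
    return weights


def weighted_logits_map(all_logits, weights):
    """M = sum_n W^n * p_s^n, with W^n shared across the class axis."""
    N, C = len(all_logits), len(all_logits[0])
    H, W = len(all_logits[0][0]), len(all_logits[0][0][0])
    return [[[sum(weights[n][i][j] * all_logits[n][c][i][j] for n in range(N))
              for j in range(W)] for i in range(H)] for c in range(C)]


def pseudo_cutout_label(all_logits, alpha=0.9, gate=1, ignore=-1):
    """Eq. (6) plus the vote gate G. A source votes at (i, j) when its top-class
    probability exceeds alpha; pixels with fewer than `gate` votes are cut out
    (set to `ignore`). Resolving votes by majority is our assumption
    (ties go to the lowest class index)."""
    C = len(all_logits[0])
    H, W = len(all_logits[0][0]), len(all_logits[0][0][0])
    label = [[ignore] * W for _ in range(H)]
    for i in range(H):
        for j in range(W):
            votes = []
            for logits in all_logits:
                probs = softmax_at(logits, i, j)
                best = max(range(C), key=lambda c: probs[c])
                if probs[best] > alpha:
                    votes.append(best)
            if votes and len(votes) >= gate:
                label[i][j] = max(range(C), key=votes.count)
    return label


def vote_gate(epoch, total_epochs, num_sources):
    """Hypothetical linear schedule tightening G from 1 to num_sources."""
    return 1 + (num_sources - 1) * epoch // max(1, total_epochs - 1)
```

The loops only keep each equation visible; a practical implementation would vectorize the same computations over batches, and the `ignore` value would be passed to the cross-entropy loss of Eq. (7) so that cut-out pixels contribute nothing.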
3 Experiments
Dataset and Setting: We collect four pathology image datasets to validate the proposed approach. First, we acquire 50 images from a cohort of patients with Triple Negative Breast Cancer (TNBC), released by Naylor et al. [18]. Hou et al. [10] published a nucleus segmentation dataset containing 5,060 segmented slides from 10 TCGA cancer types; in this work, we use 98 images of invasive carcinoma of the breast (BRCA) from it. We also include 463 images of Kidney Renal Clear cell Carcinoma (KIRC), made publicly available by Irshad et al. [11]. Awan et al. [2] released a dataset of tissue slide images and associated clinical data on colorectal cancer (CRC), from which we randomly select 200 patches. In our experiments, we transfer knowledge from three black-box models trained on different source domains to a new target-domain model (e.g., from CRC, TNBC, and KIRC to BRCA). Both the student model and the source-domain black-box predictors use the widely adopted residual U-Net [12] as the backbone, which is commonly used for medical image segmentation. Each source-domain network is trained with full supervision on its source-domain data and evaluated directly on the target-domain data; these upper-bound results (Source-only upper) are shown in Table 1. To ensure reliable comparisons, we use the same data split for training, validation, and testing, accounting for 80%, 10%, and 10% of the original data, respectively. For the target-domain network, we consider both unsupervised and semi-supervised settings; in semi-supervised domain adaptation, only 10% of the target-domain data is labeled.

Experimental Results: We compare our method with the following approaches: (1) CellSegSSDA [8], an adversarial-learning-based semi-supervised domain adaptation approach; (2) US-MSMA [13], a multi-source model domain aggregation network; (3) SFDA-DPL [5], a source-free unsupervised domain adaptation approach; and (4) BBUDA [17], an unsupervised black-box model domain adaptation framework. Notably, most of the compared methods are white-box methods, which means they can access more information from the source domains than our method can.

Fig. 2. Visualized segmentation on the BRCA and KIRC target domains, respectively.
For the single-source domain adaptation approaches, CellSegSSDA and SFDA-DPL, we employ two strategies to ensure a fair comparison: (1) single-source, i.e., adaptation is performed on each source separately and the best result is reported in Table 1; and (2) source-combined, i.e., all source domains are merged into one traditional single source. Table 1 and Fig. 2 show that our method achieves superior performance even compared with these white-box methods, surpassing them on several evaluation metrics and in the visualizations. In addition, the results show that simply combining multiple source datasets into a traditional single source degrades performance in some cases, which underlines the importance of studying multi-source domain adaptation methods.

Ablation Study: To evaluate the contribution of the weighted logits (WL), pseudo-cutout label (PCL), and mutual-information maximization (MMI) components, we conduct an ablation study that compares the baseline model with variants that add these components. We choose CRC, KIRC, and BRCA as the source domains and TNBC as the target domain. The results in Table 2 show that the proposed modules are beneficial: as modules are added, Dice increases steadily and ASSD decreases.

Table 1. Quantitative comparison with unsupervised and semi-supervised domain adaptation methods under 3 segmentation metrics. Target BRCA uses sources CRC&TNBC&KIRC; target KIRC uses sources CRC&TNBC&BRCA.

| Standards | Methods | BRCA Dice | BRCA HD95 | BRCA ASSD | KIRC Dice | KIRC HD95 | KIRC ASSD |
|---|---|---|---|---|---|---|---|
| Source-only | Source (upper) | 0.6991 | 41.9604 | 10.8780 | 0.7001 | 34.5575 | 6.7822 |
| Single-source (upper) | SFDA-DPL [5] | 0.6327 | 43.8113 | 11.6313 | 0.6383 | 26.3252 | 4.6023 |
| Single-source (upper) | BBUDA [17] | 0.6620 | 39.6950 | 11.3911 | 0.6836 | 42.9875 | 7.4398 |
| Source-combined | SFDA-DPL [5] | 0.6828 | 46.5393 | 12.1484 | 0.6446 | 25.4274 | 4.2998 |
| Source-combined | BBUDA [17] | 0.6729 | 41.8879 | 11.5375 | 0.6895 | 46.7358 | 8.7463 |
| Multi-source | US-MSMA [13] | 0.7334 | 37.1309 | 8.7817 | 0.7161 | 18.7093 | 3.0187 |
| Multi-source | Our (UDA) | 0.7351 | 39.4103 | 8.7014 | 0.7281 | 30.9221 | 6.2080 |
| Single-source (upper) | CellSegSSDA [8] | 0.6852 | 45.2595 | 9.9133 | 0.6937 | 58.7221 | 12.5176 |
| Source-combined | CellSegSSDA [8] | 0.7202 | 43.9251 | 8.0944 | 0.6699 | 55.1768 | 10.3623 |
| Multi-source | Our (SSDA) | 0.7565 | 39.0552 | 9.3346 | 0.7443 | 31.7582 | 6.0873 |
| Fully-supervised | Upper bound | 0.7721 | 35.1449 | 7.2848 | 0.7540 | 23.53767 | 4.1882 |

Table 2. Ablation study of the three modules (WL, PCL, MMI) of our method; sources CRC&KIRC&BRCA, target TNBC.

| Modules enabled | Dice | HD95 | ASSD |
|---|---|---|---|
| None (baseline) | 0.6708 | 56.9111 | 16.3837 |
| One of three | 0.6822 | 54.3386 | 14.9817 |
| Two of three | 0.6890 | 57.0889 | 12.9512 |
| All three (WL, PCL, MMI) | 0.7075 | 58.8798 | 10.7247 |
4 Conclusion
Our multi-source black-box domain adaptation method achieves competitive performance by relying solely on the source-domain outputs, without access to the source-domain data or models, thus avoiding information leakage from the source domains. In addition, the method does not assume the same architecture across domains, which allows lightweight target models to be learned from large source models and improves learning efficiency. We demonstrate the effectiveness of our method on multiple public datasets and believe it can be readily applied to other domains and adaptation scenarios. In future work, we plan to integrate our approach with active learning to improve annotation efficiency in the semi-supervised setting. By leveraging multi-source domain knowledge, we aim to improve the reliability of the target model and enable more efficient annotation.

Acknowledgements. This work is partially supported by the National Natural Science Foundation of China under grant No. 61902310 and the Natural Science Basic Research Program of Shaanxi, China under grant 2020JQ030.
References
1. Ahmed, S.M., Raychaudhuri, D.S., Paul, S., Oymak, S., Roy-Chowdhury, A.K.: Unsupervised multi-source domain adaptation without access to source data. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 10103–10112 (2021)
2. Awan, R., et al.: Glandular morphometrics for objective grading of colorectal adenocarcinoma histology images. Sci. Rep. 7(1), 16852 (2017)
3. Bateson, M., Kervadec, H., Dolz, J., Lombaert, H., Ben Ayed, I.: Source-relaxed domain adaptation for image segmentation. In: Martel, A.L., et al. (eds.) MICCAI 2020. LNCS, vol. 12261, pp. 490–499. Springer, Cham (2020). https://doi.org/10.1007/978-3-030-59710-8_48
4. Carlini, N., et al.: Extracting training data from large language models. In: USENIX Security Symposium, vol. 6 (2021)
5. Chen, C., Liu, Q., Jin, Y., Dou, Q., Heng, P.-A.: Source-free domain adaptive fundus image segmentation with denoised pseudo-labeling. In: de Bruijne, M., et al. (eds.) MICCAI 2021. LNCS, vol. 12905, pp. 225–235. Springer, Cham (2021). https://doi.org/10.1007/978-3-030-87240-3_22
6. DeVries, T., Taylor, G.W.: Improved regularization of convolutional neural networks with cutout. arXiv preprint arXiv:1708.04552 (2017)
7. Feng, H., et al.: KD3A: unsupervised multi-source decentralized domain adaptation via knowledge distillation. In: ICML, pp. 3274–3283 (2021)
8. Haq, M.M., Huang, J.: Adversarial domain adaptation for cell segmentation. In: Medical Imaging with Deep Learning, pp. 277–287. PMLR (2020)
9. Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531 (2015)
10. Hou, L., et al.: Dataset of segmented nuclei in hematoxylin and eosin stained histopathology images of ten cancer types. Sci. Data 7(1), 185 (2020)
11. Irshad, H., et al.: Crowdsourcing image annotation for nucleus detection and segmentation in computational pathology: evaluating experts, automated methods, and the crowd. In: Pacific Symposium on Biocomputing Co-chairs, pp. 294–305. World Scientific (2014)
12. Kerfoot, E., Clough, J., Oksuz, I., Lee, J., King, A.P., Schnabel, J.A.: Left-ventricle quantification using residual U-Net. In: Pop, M., et al. (eds.) STACOM 2018. LNCS, vol. 11395, pp. 371–380. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-12029-0_40
13. Li, Z., Togo, R., Ogawa, T., Haseyama, M.: Union-set multi-source model adaptation for semantic segmentation. In: Computer Vision – ECCV 2022: 17th European Conference, Tel Aviv, Israel, 23–27 October 2022, Proceedings, Part XXIX, pp. 579–595. Springer, Heidelberg (2022). https://doi.org/10.1007/978-3-031-19818-2_33
14. Liang, J., Hu, D., Feng, J.: Do we really need to access the source data? Source hypothesis transfer for unsupervised domain adaptation. In: International Conference on Machine Learning, pp. 6028–6039. PMLR (2020)
15. Liang, J., Hu, D., Feng, J., He, R.: DINE: domain adaptation from single and multiple black-box predictors. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8003–8013 (2022)
16. Liu, X., et al.: Adversarial unsupervised domain adaptation with conditional and label shift: infer, align and iterate. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp. 10367–10376 (2021)
17. Liu, X., et al.: Unsupervised black-box model domain adaptation for brain tumor segmentation. Front. Neurosci. 16, 837646 (2022)
18. Naylor, P., Laé, M., Reyal, F., Walter, T.: Segmentation of nuclei in histopathology images by deep regression of the distance map. IEEE Trans. Med. Imaging 38(2), 448–459 (2018)
19. Scherr, T., Löffler, K., Böhland, M., Mikut, R.: Cell segmentation and tracking using CNN-based distance predictions and a graph-based matching strategy. PLoS One 15(12), e0243219 (2020)
20. Shimodaira, H.: Improving predictive inference under covariate shift by weighting the log-likelihood function. J. Stat. Plan. Infer. 90(2), 227–244 (2000)
21. Xie, B., Yuan, L., Li, S., Liu, C.H., Cheng, X.: Towards fewer annotations: active learning via region impurity and prediction uncertainty for domain adaptive semantic segmentation. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 8068–8078 (2022)
MedGen3D: A Deep Generative Framework for Paired 3D Image and Mask Generation

Kun Han¹ (✉), Yifeng Xiong¹, Chenyu You², Pooya Khosravi¹, Shanlin Sun¹, Xiangyi Yan¹, James S. Duncan², and Xiaohui Xie¹

¹ University of California, Irvine, USA
² Yale University, New Haven, USA
Abstract. Acquiring and annotating sufficient labeled data is crucial in developing accurate and robust learning-based models, but obtaining such data can be challenging in many medical image segmentation tasks. One promising solution is to synthesize realistic data with ground-truth mask annotations. However, no prior studies have explored generating complete 3D volumetric images with masks. In this paper, we present MedGen3D, a deep generative framework that can generate paired 3D medical images and masks. First, we represent the 3D medical data as 2D sequences and propose the Multi-Condition Diffusion Probabilistic Model (MC-DPM) to generate multi-label mask sequences adhering to anatomical geometry. Then, we use an image sequence generator and semantic diffusion refiner conditioned on the generated mask sequences to produce realistic 3D medical images that align with the generated masks. Our proposed framework guarantees accurate alignment between synthetic images and segmentation maps. Experiments on 3D thoracic CT and brain MRI datasets show that our synthetic data is both diverse and faithful to the original data, and demonstrate the benefits for downstream segmentation tasks. We anticipate that MedGen3D’s ability to synthesize paired 3D medical images and masks will prove valuable in training deep learning models for medical imaging tasks. Keywords: Deep Generative Framework · 3D Volumetric Images with Masks · Fidelity and Diversity · Segmentation
1 Introduction
In medical image analysis, a substantial quantity of accurately annotated 3D data is a prerequisite for achieving high performance in tasks such as segmentation and detection [7,15,23,26,28–36], which in turn leads to more precise diagnoses and treatment plans. However, obtaining and annotating such data presents many challenges, including the complexity of medical images, the need for specialized expertise, and privacy concerns. Generating realistic synthetic data is a promising solution to these challenges, as it eliminates the need for manual annotation and alleviates privacy risks. However, most prior studies [4,5,14,25] have focused on 2D image synthesis, and only a few generate corresponding segmentation masks. For instance, [13] uses dual generative adversarial networks (GANs) [12,34] to synthesize 2D labeled retina fundus images, while [10] combines a label generator [22] with an image generator [21] to generate 2D brain MRI data. More recently, [24] uses WGAN [3] to generate small 3D patches and the corresponding vessel segmentations. However, no prior research has generated whole 3D volumetric images together with their segmentation masks.

Generating 3D volumetric images with corresponding segmentation masks faces two major obstacles. First, feeding entire 3D volumes directly to neural networks is impractical due to GPU memory constraints, and downsizing the resolution may compromise the quality of the synthetic data. Second, treating the entire 3D volume as a single data point during training is suboptimal because annotated 3D data are scarce. New methods are therefore required to generate high-quality synthetic 3D volumetric data with corresponding segmentation masks. We propose MedGen3D, a novel diffusion-based deep generative framework that generates paired 3D volumetric medical images and multi-label masks. Our approach treats 3D medical data as sequences of slices and employs an autoregressive process to sequentially generate 3D masks and images.

Supplementary Information The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_72.
© The Author(s), under exclusive license to Springer Nature Switzerland AG 2023. H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 759–769, 2023. https://doi.org/10.1007/978-3-031-43907-0_72

K. Han et al.
In the first stage, a Multi-Condition Diffusion Probabilistic Model (MC-DPM) generates mask sequences by combining conditional and unconditional generation processes. Specifically, the MC-DPM generates mask subsequences (i.e., several consecutive slices) at any position, either directly from random noise or by conditioning on existing slices to extend the sequence forward or backward. Given that medical images share similar anatomical structures, slice indices serve as additional conditions for mask subsequence generation. In the second stage, we introduce a conditional image generator consisting of a seq-to-seq model from [27] and a semantic diffusion refiner. Conditioned on the mask sequences generated in the first stage, the image generator synthesizes realistic medical images aligned with the masks while preserving spatial consistency across adjacent slices. The main contributions of our work are as follows: 1) our framework is the first to address the challenge of synthesizing complete 3D volumetric medical images together with their masks; 2) we introduce a multi-condition diffusion probabilistic model for generating 3D anatomical masks with high fidelity and diversity; 3) we use the generated masks to condition an image sequence generator and a semantic diffusion refiner, producing realistic medical images that align accurately with the masks; and 4) we present experimental results that demonstrate the fidelity and diversity of the generated 3D multi-label medical images and highlight their potential benefits for downstream segmentation tasks.
MedGen3D: Paired 3D Image and Mask Generation
2 Preliminary

2.1 Diffusion Probabilistic Model
A diffusion probabilistic model (DPM) [16] is a parameterized Markov chain of length T designed to learn the data distribution p(X). The DPM defines a Forward Diffusion Process (FDP) that produces the diffused data point $X_t$ at any time step t via $q(X_t \mid X_{t-1}) = \mathcal{N}\big(X_t; \sqrt{1-\beta_t}\,X_{t-1}, \beta_t I\big)$, with $X_0 \sim q(X_0)$ and $p(X_T) = \mathcal{N}(X_T; 0, I)$. Let $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{s=1}^{t}(1-\beta_s)$. The Reverse Diffusion Process (RDP) is trained to predict the noise added in the FDP by minimizing:

$$\mathrm{Loss}(\theta) = \mathbb{E}_{X_0 \sim q(X_0),\, \epsilon \sim \mathcal{N}(0,I),\, t} \big\| \epsilon - \epsilon_\theta\big(\sqrt{\bar\alpha_t}\,X_0 + \sqrt{1-\bar\alpha_t}\,\epsilon,\ t\big) \big\|^2 \tag{1}$$

where $\epsilon_\theta$ is the predicted noise and θ denotes the model parameters.

2.2 Classifier-Free Guidance
Samples from conditional diffusion models can be improved with classifier-free guidance [17], in which the condition c is replaced by ∅ with probability p during training. During sampling, the model output is extrapolated further in the direction of $\epsilon_\theta(X_t \mid c)$ and away from $\epsilon_\theta(X_t \mid \emptyset)$:

$$\hat\epsilon_\theta(X_t \mid c) = \epsilon_\theta(X_t \mid \emptyset) + s \cdot \big(\epsilon_\theta(X_t \mid c) - \epsilon_\theta(X_t \mid \emptyset)\big) \tag{2}$$

where ∅ represents a null condition and s ≥ 1 is the guidance scale.
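The two preliminaries can be illustrated numerically. The sketch below draws the closed-form forward sample $X_t = \sqrt{\bar\alpha_t}X_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ that appears inside Eq. (1), computes the squared noise-prediction error, and applies the guidance rule of Eq. (2) together with training-time condition dropout. The linear β schedule (1e-4 to 0.02 over T steps, the usual DDPM default) is an assumption rather than a setting stated here, the helper names are hypothetical, and the network ε_θ itself is not modelled.

```python
import math
import random

# Sketch of Eq. (1) (closed-form forward noising and its noise-prediction
# target) and Eq. (2) (classifier-free guidance) on flat lists of floats.
# The linear beta schedule is the common DDPM default, assumed here.


def linear_betas(T, beta_1=1e-4, beta_T=0.02):
    """beta_1 ... beta_T spaced linearly (assumed schedule, T > 1)."""
    return [beta_1 + (beta_T - beta_1) * t / (T - 1) for t in range(T)]


def alpha_bar(betas, t):
    """abar_t = prod_{s=1..t} (1 - beta_s); t = 0 means no noise has been added."""
    out = 1.0
    for s in range(t):
        out *= 1.0 - betas[s]
    return out


def q_sample(x0, t, betas, eps):
    """X_t = sqrt(abar_t) * X_0 + sqrt(1 - abar_t) * eps, the input to eps_theta."""
    ab = alpha_bar(betas, t)
    return [math.sqrt(ab) * x + math.sqrt(1.0 - ab) * e for x, e in zip(x0, eps)]


def noise_loss(eps, eps_pred):
    """Squared L2 error between the true and the predicted noise, as in Eq. (1)."""
    return sum((e - p) ** 2 for e, p in zip(eps, eps_pred))


def drop_condition(cond, p_uncond, rng):
    """Training: replace the condition by the null condition (None) w.p. p_uncond."""
    return None if rng.random() < p_uncond else cond


def guided_eps(eps_null, eps_cond, s):
    """Eq. (2): eps(X_t | null) + s * (eps(X_t | c) - eps(X_t | null))."""
    return [u + s * (c - u) for u, c in zip(eps_null, eps_cond)]


rng = random.Random(0)
betas = linear_betas(1000)
x0 = [0.5, -1.0, 0.25]
eps = [rng.gauss(0.0, 1.0) for _ in x0]
x_t = q_sample(x0, 500, betas, eps)    # a noised training input at t = 500
print(noise_loss(eps, eps))            # a perfect noise predictor: prints 0.0
```

With s = 1 the guidance rule reduces to the purely conditional prediction; larger s pushes samples further toward the condition, which is why the guidance scale is restricted to s ≥ 1.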
3 Methodology
We propose a sequential process to generate complex 3D volumetric images with masks, as illustrated in Fig. 1. The first stage generates multi-label segmentation masks, and the second stage performs conditional medical image generation. The details are presented in the following sections.

3.1 3D Mask Generator
Fig. 1. Overview of the proposed MedGen3D, including a 3D mask generator that autoregressively generates the mask sequences starting from a random position z, and a conditional image generator that generates 3D images conditioned on the generated masks.

Fig. 2. Proposed 3D mask generator. Given a target position z, MC-DPM generates mask subsequences of length m for a specific region, either unconditionally or conditioned on the first or last n slices, according to the pre-defined probability $p_C \in \{p_F, p_B, p_U\}$. Binary indicators are assigned to slices to mark the conditional slices. For clarity, the binary indicators are omitted from the inference illustration, where red outlines denote conditional slices and green outlines denote generated slices.

Due to the limited annotated real data and GPU memory constraints, directly feeding the entire 3D volume to the network is impractical. Instead, we treat 3D medical data as a series of subsequences. To generate an entire mask sequence, an initial subsequence of m consecutive slices is unconditionally generated from random noise. The subsequence is then expanded forward and backward in an autoregressive manner, conditioned on existing slices. Inspired by classifier-free guidance (Sect. 2.2), we propose a general Multi-Condition Diffusion Probabilistic Model (MC-DPM) that unifies all three generation modes (unconditional, forward, and backward). As shown in Fig. 2, MC-DPM can generate mask sequences either directly from random noise or conditioned on existing slices. Furthermore, because 3D medical data typically share similar anatomical structures, slices at the same relative position roughly correspond to the same anatomical regions. We can therefore use the relative position of slices as a condition that guides MC-DPM to generate subsequences of the target region and controls the length of the generated sequences.

Train: For a given 3D multi-label mask $M \in \mathbb{R}^{D\times H\times W}$, subsequences of m consecutive slices $\{M_z, M_{z+1}, \dots, M_{z+(m-1)}\}$ are selected, with z a randomly selected starting index. For each subsequence, we determine the conditional slices $X^C \in \{\mathbb{R}^{n\times H\times W}, \emptyset\}$ by selecting the first n slices, the last n slices, or no slice, according to the probability $p_C \in \{p_{Forward}, p_{Backward}, p_{Uncondition}\}$. MC-DPM is trained to generate the remaining slices, denoted as $X^P \in \mathbb{R}^{(m-\mathrm{len}(X^C))\times H\times W}$.
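The subsequence bookkeeping described above can be sketched on slice indices alone. The hypothetical helpers below draw one training example (start index z, a conditioning mode chosen according to $p_C$, the per-slice binary indicators, and the relative position z/D) and list the windows visited at inference when an unconditional window is expanded forward and backward. How the last, clamped window is handled is our assumption; the paper defers the detailed generation process to its supplementary material.

```python
import random

# Sketch of the subsequence bookkeeping behind MC-DPM: how one training example
# is drawn (start index z, conditioning mode chosen by p_C, binary indicators,
# relative position z/D) and which windows inference visits when it expands an
# initial unconditional window forward and backward. Only slice indices are
# handled; the diffusion model itself is out of scope.


def sample_training_example(D, m, n, p_forward, p_backward, rng):
    """Return (z, z_rel, cond_idx, pred_idx, indicator) for a volume of D slices,
    subsequence length m and n conditional slices (assumes 0 < n < m <= D)."""
    z = rng.randrange(0, D - m + 1)            # window M_z ... M_{z+m-1}
    idx = list(range(z, z + m))
    r = rng.random()
    if r < p_forward:                           # condition on the first n slices
        cond = idx[:n]
    elif r < p_forward + p_backward:            # condition on the last n slices
        cond = idx[m - n:]
    else:                                       # unconditional (X^C is empty)
        cond = []
    pred = [i for i in idx if i not in cond]    # slices X^P to be generated
    indicator = [1 if i in cond else 0 for i in idx]
    return z, z / D, cond, pred, indicator


def inference_plan(D, m, n, z0):
    """Windows (mode, start) visited at inference, starting from an
    unconditional window at z0. Forward windows reuse the last n generated
    slices and backward windows the first n. The final forward window is
    clamped to end at slice D, so it may overlap more than n existing slices;
    keeping only its new slices is our assumption."""
    plan = [("uncond", z0)]
    lo, hi = z0, z0 + m                          # generated slices: [lo, hi)
    while hi < D:
        start = min(hi - n, D - m)
        plan.append(("forward", start))
        hi = start + m
    while lo > 0:
        start = max(lo - (m - n), 0)
        plan.append(("backward", start))
        lo = start
    return plan
```

For example, `inference_plan(10, 4, 2, 3)` starts with an unconditional window at slice 3, adds two forward windows and two backward windows, and together they cover all ten slices.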
To incorporate the position condition, we use the relative position of the subsequence, $\tilde z = z/D$, where z is the index of the subsequence's starting slice. We embed the position condition and concatenate it with the time embedding to aid the generation process. We also attach a binary indicator to each slice in the subsequence to signify whether it is a conditional slice. The joint distribution of the reverse diffusion process (RDP) given the conditional slices $X^C$ can be written as:

$$p_\theta(X^P_{0:T} \mid X^C, \tilde z) = p(X^P_T) \prod_{t=1}^{T} p_\theta(X^P_{t-1} \mid X^P_t, X^C, \tilde z) \tag{3}$$

where $p(X^P_T) = \mathcal{N}(X^P_T; 0, I)$, $\tilde z = z/D$, and $p_\theta$ is the distribution parameterized by the model. Overall, with $X^P_t = \sqrt{\bar\alpha_t}\,X^P_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, the model is trained by minimizing the following loss:

$$\mathrm{Loss}(\theta) = \mathbb{E}_{X_0 \sim q(X_0),\, \epsilon \sim \mathcal{N}(0,I),\, p_C,\, z,\, t} \big\| \epsilon - \epsilon_\theta(X^P_t, X^C, \tilde z, t) \big\|^2 \tag{4}$$

Inference: During inference, MC-DPM first generates a subsequence of m slices from random noise at a random location z. The entire mask sequence is then generated autoregressively by expanding in both directions, conditioned on the existing slices, as shown in Fig. 2. Please refer to the Supplementary Material for the detailed generation process and network structure.

3.2 Conditional Image Generator
In the second step, we employ a sequence-to-sequence method to generate medical images conditioned on masks, as shown in Fig. 3.

Image Sequence Generator: In the sequence-to-sequence generation task, each new slice is a combination of the warped previous slice and newly generated texture, weighted by a continuous mask [27]. We use Vid2Vid [27] as our image sequence generator and train it with its original loss, which includes GAN losses from multi-scale image and video discriminators, a flow estimation loss, and a feature matching loss.

Semantic Diffusion Refiner: Despite the high cross-slice consistency and spatial continuity achieved by Vid2Vid, issues such as blocking artifacts, blurriness, and suboptimal texture persist. Given that diffusion models have been shown to generate superior images [9], we propose a semantic diffusion refiner that uses a diffusion probabilistic model to refine the previously generated images. For each of the 3 views, we train a semantic diffusion model (SDM), which takes 2D masks and noisy images as inputs and generates images aligned with the input masks. During inference, we apply only a small number of noising steps (10 steps) to the generated images so that the overall anatomical structure and spatial continuity are preserved. We then refine the images using the pre-trained semantic diffusion models. The final refined 3D image is the mean of the results from the 3 views. Experimental results show a clear improvement in the quality of the generated images with the help of the semantic diffusion refiner.

K. Han et al.

Fig. 3. Image Sequence Generator. Given the generated 3D mask, the initial image is generated sequentially by the Vid2Vid model. To refine this initial result with the semantic diffusion models (SDMs), we first add a small amount of noise (10 steps) and then refine the result with three SDMs. The final result is the mean of the 3D images from the 3 views (axial, coronal, and sagittal), yielding significant improvements over the initially generated image.
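The refiner's noise-then-denoise step can be sketched as below. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes a DDPM-style linear β schedule (1e-4 to 0.02 over T = 1000 steps, the settings of Ho et al. [16]), treats a volume as a flat list of voxel intensities, and stands in for each trained SDM with a hypothetical callable `denoise_view(noisy, t)`. It noises the volume once with the closed form X_t = √ᾱ_t X_0 + √(1 − ᾱ_t) ε (the same forward process that defines X_t^P in Eq. (4)) at t = 10, then averages the three per-view outputs voxel by voxel. Sharing one noise draw across the three views is a further assumption.

```python
import math
import random

def linear_alpha_bar(T=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products alpha_bar_t for an assumed linear beta schedule."""
    alpha_bar, prod = [], 1.0
    for i in range(T):
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        prod *= 1.0 - beta
        alpha_bar.append(prod)
    return alpha_bar  # alpha_bar[t - 1] belongs to step t

def q_sample(x0, t, alpha_bar, rng):
    """Forward process x_t = sqrt(ab_t) * x0 + sqrt(1 - ab_t) * eps, element-wise."""
    ab = alpha_bar[t - 1]
    return [math.sqrt(ab) * v + math.sqrt(1.0 - ab) * rng.gauss(0.0, 1.0) for v in x0]

def refine(volume, denoisers, t=10, seed=0):
    """Lightly noise the volume, refine it once per view, and average voxel-wise.

    `denoisers` are hypothetical per-view models (axial, coronal, sagittal);
    a real SDM would process the volume slice by slice along its own axis.
    """
    rng = random.Random(seed)
    alpha_bar = linear_alpha_bar()
    noisy = q_sample(volume, t, alpha_bar, rng)  # one shared noise draw (assumption)
    outputs = [denoise_view(noisy, t) for denoise_view in denoisers]
    return [sum(vals) / len(vals) for vals in zip(*outputs)]
```

Under this schedule ᾱ_10 is about 0.998, so the injected noise is small relative to the signal, which is consistent with the stated goal of preserving anatomy while letting the SDMs correct texture.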
4 Experiments and Results
4.1 Datasets and Setups
Datasets: We conducted experiments on the thoracic site using three thoracic CT datasets and on the brain site using two brain MRI datasets. For both the generative models and the downstream segmentation tasks, we used the following datasets:
– SegTHOR [19]: 3D thorax CT scans (25 training, 5 validation, 10 testing);
– OASIS [20]: 3D brain MRI T1 scans (40 training, 10 validation, 10 testing).
For the downstream segmentation task and transfer learning only, we used 10 fine-tuning, 5 validation, and 10 testing scans from each of the 3D thorax CT datasets StructSeg-Thorax [2] and Public-Thor [7], as well as from the 3D brain MRI T1 dataset from ADNI [1].

Implementation: For the thoracic datasets, we crop and pad CT scans to 96 × 320 × 320. The annotations of six organs (left lung, right lung, spinal cord, esophagus, heart, and trachea) were reviewed by an experienced radiation oncologist. We also include a body mask to aid the image generation of body regions. For the brain MRI datasets, we use FreeSurfer [11] to obtain segmentations of four regions (cortex, subcortical gray matter, white matter, and CSF) and then crop the volumes to 192 × 160 × 160. For both thoracic and brain datasets, we assign a distinct discrete value to the mask of each region or organ and combine them into one 3D volume. When synthesizing mask sequences, we resize the width and height of the masks to 128 × 128 and set the subsequence length m to 6. We use the official segmentation models provided by MONAI [6] along with standard data augmentations, including spatial and color transformations.

Setup: We compare the synthetic image quality with DDPM [16], 3D-α-WGAN [18], and Vid2Vid [27], and use four segmentation models with different training strategies to demonstrate the benefit for the downstream task.

MedGen3D: Paired 3D Image and Mask Generation

4.2 Evaluate the Quality of Synthetic Images
Synthetic Dataset: To address the limited availability of annotated 3D medical data, we used only 30 CT scans from SegTHOR (25 for training and 5 for validation) and 50 MRI scans from OASIS (40 for training and 10 for validation) to generate 110 3D thoracic CT scans and 110 3D brain MRI scans, respectively (Fig. 4).
Fig. 4. Our proposed method produces more anatomically accurate images than 3D-α-WGAN and Vid2Vid, as demonstrated by the clearer organ boundaries and more realistic textures. Left: Qualitative comparison between different generative models. Right: Visualization of synthetic 3D brain MRI slices at different relative positions.
We compare the fidelity and diversity of our synthetic data with DDPM [16] (trained separately for each of the 3 views), 3D-α-WGAN [18], and Vid2Vid [27] by calculating the mean Fréchet Inception Distance (FID) and Learned Perceptual Image Patch Similarity (LPIPS) over the 3 views. According to Table 1, our proposed method has a slightly worse (higher) FID score but a similar LPIPS score compared to DDPM, which directly generates 2D images from noise. We speculate that this is because DDPM is trained on 2D images without explicit anatomical constraints and only generates 2D images. On the other hand, 3D-α-WGAN [18], which uses much more 3D training data (146 scans for thorax and 414 for brain), has significantly worse FID and LPIPS scores than our method. Moreover, our proposed method outperforms Vid2Vid, showing the effectiveness of our semantic diffusion refiner.
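FID compares Gaussian fits to feature embeddings of real and generated images: FID = ‖μ_r − μ_g‖² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). The sketch below is a simplification for illustration only: it assumes diagonal covariances (standard FID uses full covariances of Inception-v3 features) and takes feature vectors as given plain lists instead of extracting them with a network. With diagonal covariances the trace term reduces to Σ_d (σ_r,d − σ_g,d)². The helper `mean_fid_over_views` mirrors the averaging over the three views described above.

```python
import math

def gaussian_stats(features):
    """Per-dimension mean and (population) variance of equal-length feature vectors."""
    n, dims = len(features), len(features[0])
    mu = [sum(f[d] for f in features) / n for d in range(dims)]
    var = [sum((f[d] - mu[d]) ** 2 for f in features) / n for d in range(dims)]
    return mu, var

def frechet_distance_diag(feats_a, feats_b):
    """Frechet distance between two Gaussians with diagonal covariances."""
    mu_a, var_a = gaussian_stats(feats_a)
    mu_b, var_b = gaussian_stats(feats_b)
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_a, mu_b))
    # Tr(S_a + S_b - 2 (S_a S_b)^(1/2)) for diagonal S_a, S_b
    cov_term = sum(va + vb - 2.0 * math.sqrt(va * vb) for va, vb in zip(var_a, var_b))
    return mean_term + cov_term

def mean_fid_over_views(view_pairs):
    """Average the per-view distance, e.g. over axial, coronal, and sagittal slices."""
    scores = [frechet_distance_diag(real, fake) for real, fake in view_pairs]
    return sum(scores) / len(scores)
```

Lower values mean the generated feature distribution is closer to the real one, which is why a higher FID for our method than for DDPM counts as slightly worse fidelity.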
Table 1. Synthetic image quality comparison between baselines and ours.

| Method | Thoracic CT FID ↓ | Thoracic CT LPIPS ↑ | Brain MRI FID ↓ | Brain MRI LPIPS ↑ |
|---|---|---|---|---|
| DDPM [16] | 35.2 | 0.316 | 34.9 | 0.298 |
| 3D-α-WGAN [18] | 136.2 | 0.286 | 136.4 | 0.289 |
| Vid2Vid [27] | 47.3 | 0.300 | 48.2 | 0.324 |
| Ours | 40.3 | 0.305 | 39.6 | 0.326 |
Table 2. Experiment 2: DSC of different thoracic segmentation models. There are 5 training strategies, namely: E2-1: training with real SegTHOR training data; E2-2: training with synthetic data; E2-3: training with both synthetic and real data; E2-4: finetuning the model from E2-2 using real training data; and E2-5: finetuning the model from E2-3 using real training data. (* denotes the training data source.)

| Test dataset | Model | E2-1 | E2-2 | E2-3 | E2-4 | E2-5 |
|---|---|---|---|---|---|---|
| SegTHOR* | Unet 2D | 0.817 | 0.815 | 0.845 | 0.855 | 0.847 |
| SegTHOR* | Unet 3D | 0.873 | 0.846 | 0.881 | 0.887 | 0.891 |
| SegTHOR* | UNETR | 0.867 | 0.845 | 0.886 | 0.894 | 0.890 |
| SegTHOR* | Swin UNETR | 0.878 | 0.854 | 0.886 | 0.899 | 0.897 |
| StructSeg-Thorax | Unet 2D | 0.722 | 0.736 | 0.772 | 0.775 | 0.783 |
| StructSeg-Thorax | Unet 3D | 0.793 | 0.788 | 0.827 | 0.833 | 0.833 |
| StructSeg-Thorax | UNETR | 0.789 | 0.788 | 0.824 | 0.825 | 0.823 |
| StructSeg-Thorax | Swin UNETR | 0.810 | 0.803 | 0.827 | 0.833 | 0.835 |
| Public-Thor | Unet 2D | 0.822 | 0.786 | 0.812 | 0.824 | 0.818 |
| Public-Thor | Unet 3D | 0.837 | 0.838 | 0.856 | 0.861 | 0.864 |
| Public-Thor | UNETR | 0.836 | 0.814 | 0.853 | 0.852 | 0.858 |
| Public-Thor | Swin UNETR | 0.847 | 0.842 | 0.856 | 0.867 | 0.867 |

4.3 Evaluate the Benefits for the Segmentation Task
We explore the benefits of synthetic data for downstream segmentation tasks by comparing the Sørensen-Dice coefficient (DSC) of 4 segmentation models, namely Unet2D [23], UNet3D [8], UNETR [15], and Swin-UNETR [26]. In Tables 2 and 3, we use real training data (from SegTHOR and OASIS) and synthetic data to train the segmentation models with 5 different strategies, and test on all 3 thoracic CT datasets and 2 brain MRI datasets. In Table 4, we examine whether the synthetic data can aid transfer learning with limited real finetuning data from each of the testing datasets (StructSeg-Thorax, Public-Thor, and ADNI), using four training strategies. According to Tables 2 and 3, the significant DSC difference between 2D and 3D segmentation models underlines the crucial role of 3D annotated data. While purely synthetic data (E2-2) fails to achieve the same performance as real training data (E2-1), the combination of real and synthetic data (E2-3) improves model performance in most cases, except for Unet2D on the Public-Thor dataset. Furthermore, fine-tuning the pre-trained models with real data (E2-4 and E2-5) outperforms the model trained only with real data in all cases but one (E2-5 with Unet2D on Public-Thor). Please refer to the Supplementary Material for organ-level DSC comparisons of the Swin-UNETR model. According to Table 4, for transfer learning, using the pre-trained model (E3-2) leads to better performance than training from scratch (E3-1).
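The DSC reported in Tables 2–4 is DSC = 2|P ∩ G| / (|P| + |G|) for a predicted region P and a ground-truth region G. The sketch below computes it per label on flattened label maps, matching the implementation detail that each organ or region is encoded as a distinct discrete value. Averaging over labels and returning 1.0 when a label is absent from both maps are assumptions of this sketch, since the text does not state how organ-level scores are aggregated; in practice a library routine such as MONAI's `DiceMetric` would be used instead.

```python
def dice(pred, target, label):
    """Sorensen-Dice coefficient for one label on flattened label maps."""
    p = [v == label for v in pred]
    g = [v == label for v in target]
    inter = sum(1 for a, b in zip(p, g) if a and b)
    denom = sum(p) + sum(g)
    # Convention (assumption): a label absent from both maps counts as a perfect match.
    return 1.0 if denom == 0 else 2.0 * inter / denom

def mean_dice(pred, target, labels):
    """Unweighted mean of the per-label DSC values."""
    return sum(dice(pred, target, label) for label in labels) / len(labels)
```

For example, with prediction [0, 1, 1, 2] and ground truth [0, 1, 2, 2], labels 1 and 2 each overlap in one voxel out of three labelled voxels, giving a DSC of 2/3 for each.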
Table 3. Experiment 2: DSC of brain segmentation models. Please refer to Table 2 for detailed training strategies. (* denotes the training data source.)

| Test dataset | Model | E2-1 | E2-2 | E2-3 | E2-4 | E2-5 |
|---|---|---|---|---|---|---|
| OASIS* | Unet 2D | 0.930 | 0.905 | 0.938 | 0.940 | 0.940 |
| OASIS* | Unet 3D | 0.951 | 0.936 | 0.953 | 0.955 | 0.954 |
| OASIS* | UNETR | 0.952 | 0.935 | 0.953 | 0.954 | 0.954 |
| OASIS* | Swin UNETR | 0.954 | 0.934 | 0.955 | 0.956 | 0.956 |
| ADNI | Unet 2D | 0.815 | 0.759 | 0.818 | 0.819 | 0.819 |
| ADNI | Unet 3D | 0.826 | 0.825 | 0.888 | 0.891 | 0.894 |
| ADNI | UNETR | 0.880 | 0.828 | 0.898 | 0.903 | 0.902 |
| ADNI | Swin UNETR | 0.894 | 0.854 | 0.906 | 0.903 | 0.906 |
Table 4. Experiment 3: DSC of Swin-UNETR finetuned with real data. There are 4 training strategies: E3-1: training from scratch for each dataset using the limited finetuning data; E3-2: finetuning model E2-1 from Experiment 2; E3-3: finetuning model E2-4 from Experiment 2; and E3-4: finetuning model E2-5 from Experiment 2. (* denotes the finetuning data source.)

| Strategy | Thoracic CT: StructSeg-Thorax* | Thoracic CT: Public-Thor* | Brain MRI: ADNI* |
|---|---|---|---|
| E3-1 | 0.845 | 0.897 | 0.946 |
| E3-2 | 0.865 | 0.901 | 0.948 |
| E3-3 | 0.878 | 0.913 | 0.949 |
| E3-4 | 0.882 | 0.914 | 0.949 |
Additionally, pretraining the model with synthetic data (E3-3 and E3-4) can facilitate transfer learning to a new dataset with limited annotated data. We have included video demonstrations of the generated 3D volumetric images in the supplementary material, which offer a more comprehensive representation of the generated image’s quality.
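The transfer-learning gains can be read directly off Table 4. The short script below only transcribes the reported DSC values and computes the absolute gain of a strategy over training from scratch (E3-1); it illustrates the arithmetic and is not part of the original experiments. The function names are hypothetical.

```python
# DSC values transcribed from Table 4 (Swin-UNETR finetuned on each target dataset).
TABLE4 = {
    "StructSeg-Thorax": {"E3-1": 0.845, "E3-2": 0.865, "E3-3": 0.878, "E3-4": 0.882},
    "Public-Thor": {"E3-1": 0.897, "E3-2": 0.901, "E3-3": 0.913, "E3-4": 0.914},
    "ADNI": {"E3-1": 0.946, "E3-2": 0.948, "E3-3": 0.949, "E3-4": 0.949},
}

def gains_over_scratch(table, strategy):
    """Absolute DSC gain of `strategy` over training from scratch (E3-1), rounded to 3 decimals."""
    return {name: round(row[strategy] - row["E3-1"], 3) for name, row in table.items()}

def synthetic_pretraining_helps(table):
    """True if both synthetic-pretrained strategies (E3-3, E3-4) match or beat E3-2 everywhere."""
    return all(row[s] >= row["E3-2"] for row in table.values() for s in ("E3-3", "E3-4"))

if __name__ == "__main__":
    print(gains_over_scratch(TABLE4, "E3-4"))
    print(synthetic_pretraining_helps(TABLE4))
```

For E3-4 the gains are 0.037 (StructSeg-Thorax), 0.017 (Public-Thor), and 0.003 (ADNI): in this table, the benefit is largest where the from-scratch baseline is weakest.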
5 Conclusion
This paper introduces MedGen3D, a new framework for synthesizing 3D medical mask-image pairs. Our experiments demonstrate its potential for realistic data generation and for downstream segmentation tasks with limited annotated data. Future work includes merging the image sequence generator and the semantic diffusion refiner for end-to-end training and extending the framework to synthesize 3D medical images across modalities. Overall, we believe that our work opens up new possibilities for generating high-quality 3D medical images paired with masks, and we look forward to future developments in this field.
References
1. https://adni.loni.usc.edu/
2. https://structseg2019.grand-challenge.org/dataset/
3. Arjovsky, M., Chintala, S., Bottou, L.: Wasserstein GAN. arXiv preprint arXiv:1701.07875 (2017)
4. Baur, C., Albarqouni, S., Navab, N.: MelanoGANs: high resolution skin lesion synthesis with GANs. arXiv preprint arXiv:1804.04338 (2018)
5. Bermudez, C., Plassard, A.J., Davis, L.T., Newton, A.T., Resnick, S.M., Landman, B.A.: Learning implicit brain MRI manifolds with deep learning. In: Medical Imaging: Image Processing. SPIE (2018)
6. Cardoso, M.J., et al.: MONAI: an open-source framework for deep learning in healthcare. arXiv preprint arXiv:2211.02701 (2022)
7. Chen, X., et al.: A deep learning-based auto-segmentation system for organs-at-risk on whole-body computed tomography images for radiation therapy. Radiother. Oncol. 160, 175–184 (2021)
8. Çiçek, Ö., Abdulkadir, A., Lienkamp, S.S., Brox, T., Ronneberger, O.: 3D U-Net: learning dense volumetric segmentation from sparse annotation. In: Ourselin, S., Joskowicz, L., Sabuncu, M.R., Unal, G., Wells, W. (eds.) MICCAI 2016. LNCS, vol. 9901, pp. 424–432. Springer, Cham (2016). https://doi.org/10.1007/978-3-319-46723-8_49
9. Dhariwal, P., Nichol, A.: Diffusion models beat GANs on image synthesis. In: NeurIPS (2021)
10. Fernandez, V., et al.: Can segmentation models be trained with fully synthetically generated data? In: Zhao, C., Svoboda, D., Wolterink, J.M., Escobar, M. (eds.) MICCAI Workshop. SASHIMI 2022, vol. 13570, pp. 79–90. Springer, Heidelberg (2022). https://doi.org/10.1007/978-3-031-16980-9_8
11. Fischl, B.: FreeSurfer. In: Neuroimage (2012)
12. Goodfellow, I., et al.: Generative adversarial networks. Commun. ACM 63, 139–144 (2020)
13. Guibas, J.T., Virdi, T.S., Li, P.S.: Synthetic medical images from dual generative adversarial networks. arXiv preprint arXiv:1709.01872 (2017)
14. Han, C., et al.: GAN-based synthetic brain MR image generation. In: ISBI. IEEE (2018)
15. Hatamizadeh, A., et al.: UNETR: transformers for 3D medical image segmentation. In: WACV (2022)
16. Ho, J., Jain, A., Abbeel, P.: Denoising diffusion probabilistic models. In: NeurIPS (2020)
17. Ho, J., Salimans, T.: Classifier-free diffusion guidance. arXiv preprint arXiv:2207.12598 (2022)
18. Kwon, G., Han, C., Kim, D.: Generation of 3D brain MRI using auto-encoding generative adversarial networks. In: Shen, D., et al. (eds.) MICCAI 2019. LNCS, vol. 11766, pp. 118–126. Springer, Cham (2019). https://doi.org/10.1007/978-3-030-32248-9_14
19. Lambert, Z., Petitjean, C., Dubray, B., Kuan, S.: SegTHOR: segmentation of thoracic organs at risk in CT images. In: IPTA. IEEE (2020)
20. Marcus, D.S., Wang, T.H., Parker, J., Csernansky, J.G., Morris, J.C., Buckner, R.L.: Open access series of imaging studies (OASIS): cross-sectional MRI data in young, middle aged, nondemented, and demented older adults. J. Cogn. Neurosci. 19, 1498–1507 (2007)
21. Park, T., Liu, M.Y., Wang, T.C., Zhu, J.Y.: Semantic image synthesis with spatially-adaptive normalization. In: CVPR (2019)
22. Rombach, R., Blattmann, A., Lorenz, D., Esser, P., Ommer, B.: High-resolution image synthesis with latent diffusion models. In: CVPR (2022)
23. Ronneberger, O., Fischer, P., Brox, T.: U-Net: convolutional networks for biomedical image segmentation. In: Navab, N., Hornegger, J., Wells, W.M., Frangi, A.F. (eds.) MICCAI 2015. LNCS, vol. 9351, pp. 234–241. Springer, Cham (2015). https://doi.org/10.1007/978-3-319-24574-4_28
24. Subramaniam, P., Kossen, T., et al.: Generating 3D TOF-MRA volumes and segmentation labels using generative adversarial networks. Med. Image Anal. 78, 102396 (2022)
25. Sun, L., Chen, J., Xu, Y., Gong, M., Yu, K., Batmanghelich, K.: Hierarchical amortized GAN for 3D high resolution medical image synthesis. IEEE J. Biomed. Health Inf. 28, 3966–3975 (2022)
26. Tang, Y., et al.: Self-supervised pre-training of Swin transformers for 3D medical image analysis. In: CVPR (2022)
27. Wang, T.C., et al.: Video-to-video synthesis. arXiv preprint arXiv:1808.06601 (2018)
28. Yan, X., Tang, H., Sun, S., Ma, H., Kong, D., Xie, X.: AFTer-UNet: axial fusion transformer UNet for medical image segmentation. In: WACV (2022)
29. You, C., et al.: Mine your own anatomy: revisiting medical image segmentation with extremely limited labels. arXiv preprint arXiv:2209.13476 (2022)
30. You, C., et al.: Rethinking semi-supervised medical image segmentation: a variance-reduction perspective. arXiv preprint arXiv:2302.01735 (2023)
31. You, C., Dai, W., Min, Y., Staib, L., Duncan, J.S.: Implicit anatomical rendering for medical image segmentation with stochastic experts. arXiv preprint arXiv:2304.03209 (2023)
32. You, C., Dai, W., Min, Y., Staib, L., Sekhon, J., Duncan, J.S.: ACTION++: improving semi-supervised medical image segmentation with adaptive anatomical contrast. arXiv preprint arXiv:2304.02689 (2023)
33. You, C., Dai, W., Staib, L., Duncan, J.S.: Bootstrapping semi-supervised medical image segmentation with anatomical-aware contrastive distillation. arXiv preprint arXiv:2206.02307 (2022)
34. You, C., et al.: Class-aware adversarial transformers for medical image segmentation. In: NeurIPS (2022)
35. You, C., Zhao, R., Staib, L.H., Duncan, J.S.: Momentum contrastive voxel-wise representation learning for semi-supervised volumetric medical image segmentation. In: Wang, L., Dou, Q., Fletcher, P.T., Speidel, S., Li, S. (eds.) Medical Image Computing and Computer Assisted Intervention – MICCAI, vol. 13434, pp. 639–652. Springer, Heidelberg (2022). https://doi.org/10.1007/978-3-031-16440-8_61
36. You, C., Zhou, Y., Zhao, R., Staib, L., Duncan, J.S.: SimCVD: simple contrastive voxel-wise representation distillation for semi-supervised medical image segmentation. IEEE Trans. Med. Imaging 41, 2228–2237 (2022)
Unsupervised Domain Transfer with Conditional Invertible Neural Networks

Kris K. Dreher (1,2) (✉), Leonardo Ayala (1,3), Melanie Schellenberg (1,4), Marco Hübner (1,4), Jan-Hinrich Nölke (1,4), Tim J. Adler (1), Silvia Seidlitz (1,4,5,6), Jan Sellner (1,4,5), Alexander Studier-Fischer (7), Janek Gröhl (8,9), Felix Nickel (7), Ullrich Köthe (4), Alexander Seitel (1,6), and Lena Maier-Hein (1,3,4,5,6)

1 Intelligent Medical Systems, German Cancer Research Center (DKFZ), Heidelberg, Germany {k.dreher,l.maier-hein}@dkfz-heidelberg.de
2 Faculty of Physics and Astronomy, Heidelberg University, Heidelberg, Germany
3 Medical Faculty, Heidelberg University, Heidelberg, Germany
4 Faculty of Mathematics and Computer Science, Heidelberg University, Heidelberg, Germany
5 Helmholtz Information and Data Science School for Health, Karlsruhe, Heidelberg, Germany
6 National Center for Tumor Diseases (NCT) Heidelberg, a Partnership Between DKFZ and Heidelberg University Hospital, Heidelberg, Germany
7 Department of General, Visceral, and Transplantation Surgery, Heidelberg University Hospital, Heidelberg, Germany
8 Cancer Research UK Cambridge Institute, University of Cambridge, Cambridge, UK
9 Department of Physics, University of Cambridge, Cambridge, UK
Abstract. Synthetic medical image generation has become a key technique for neural network training and validation. A core challenge, however, remains the domain gap between simulations and real data. While deep learning-based domain transfer using Cycle Generative Adversarial Networks and similar architectures has led to substantial progress in the field, there are use cases in which state-of-the-art approaches still fail to generate training images that produce convincing results on relevant downstream tasks. Here, we address this issue with a domain transfer approach based on conditional invertible neural networks (cINNs). As a particular advantage, our method inherently guarantees cycle consistency through its invertible architecture, and network training can be conducted efficiently with maximum likelihood training. To showcase our method's generic applicability, we apply it to two spectral imaging modalities at different scales, namely hyperspectral imaging (pixel-level) and photoacoustic tomography (image-level). According to comprehensive experiments, our method enables the generation of realistic spectral data and outperforms the state of the art on two downstream classification tasks (binary and multi-class). cINN-based domain transfer could thus evolve into an important method for realistic synthetic data generation in the field of spectral imaging and beyond. The code is available at https://github.com/IMSY-DKFZ/UDT-cINN.

Keywords: Domain transfer · invertible neural networks · medical imaging · photoacoustic tomography · hyperspectral imaging · deep learning

Supplementary Information: The online version contains supplementary material available at https://doi.org/10.1007/978-3-031-43907-0_73.

© The Author(s) 2023
H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 770–780, 2023.
https://doi.org/10.1007/978-3-031-43907-0_73
1 Introduction
The success of supervised learning methods in the medical domain has led to countless breakthroughs that might be translated into clinical routine and have the potential to revolutionize healthcare [6,13]. For many applications, however, labeled reference data (ground truth) may not be available for training and validating a neural network in a supervised manner. One such application is spectral imaging, which comprises various non-interventional, non-ionizing imaging techniques that can resolve functional tissue properties such as blood oxygenation in real time [1,3,23]. While simulations have the potential to overcome the lack of ground truth, synthetic data is not yet sufficiently realistic [9]. Cycle Generative Adversarial Network (CycleGAN)-based architectures are widely used for domain transfer [12,24] but may suffer from issues such as unstable training, hallucinations, or mode collapse [15]. Furthermore, they have predominantly been used for conventional RGB imaging and single-channel cross-modality domain adaptation, and may not be suitable for other imaging modalities with more channels. We address these challenges with the following contributions:

Domain Transfer Method: We present an entirely new sim-to-real transfer approach based on conditional invertible neural networks (cINNs) (cf. Fig. 1), specifically designed for data with many spectral channels. This approach inherently addresses weaknesses of the state of the art with respect to the preservation of spectral consistency and, importantly, does not require paired images.

Instantiation to Spectral Imaging: We show that our method can be applied generically to two complementary modalities: photoacoustic tomography (PAT; image-level) and hyperspectral imaging (HSI; pixel-level).
Comprehensive Validation: In comprehensive validation studies based on more than 2,000 PAT images (real: ∼1,000) and more than 6 million spectra for HSI (real: ∼6 million), we investigate and subsequently confirm our two main hypotheses: (H1) Our cINN-based models can close the domain gap between simulated and real spectral data better than current state-of-the-art methods regarding spectral plausibility. (H2) Training models on data transferred by our cINN-based approach can improve their performance on the corresponding (clinical) downstream task without them having seen labeled real data.

K. K. Dreher et al.

Fig. 1. Pipeline for data-driven spectral image analysis in the absence of labeled reference data. A physics-based simulation framework generates simulated spectral images with corresponding reference labels (e.g. tissue type or oxygenation (sO2)). Our domain transfer method based on cINNs leverages unlabeled real data to increase their realism. The domain-transferred data can then be used for supervised training of a downstream task (e.g. classification).
2 Materials and Methods
2.1 Domain Transfer with Conditional Invertible Neural Networks
Concept Overview. Our domain transfer approach (cf. Fig. 2) is based on the assumption that data samples from both domains carry domain-invariant information (e.g. on optical tissue properties) and domain-variant information (e.g. modality-specific artifacts). The invertible architecture, which inherently guarantees cycle consistency, transfers both simulated and real data into a shared latent space. While the domain-invariant features are captured in the latent space, the domain-variant features can either be filtered out (during encoding) or added (during decoding) by using a domain label D. To achieve spectral consistency, we leverage the fact that different tissue types feature characteristic spectral signatures and condition the model on the tissue label Y if available. For unlabeled (real) data, we use randomly generated proxy labels instead. To achieve high visual quality beyond spectral consistency, we include two discriminators, Dis_sim and Dis_real, one for each domain. Finally, as a key theoretical advantage, we avoid mode collapse through maximum likelihood optimization. Implementation details are provided in the following.

Fig. 2. Proposed architecture based on cINNs. The invertible architecture transfers both simulated and real data into a shared latent space (right). By conditioning on the domain D (bottom), a latent vector can be transferred to either the simulated or the real domain (left), for which the discriminators Dis_sim and Dis_real compute the losses for adversarial training.

cINN Model Design. The core of our architecture is a cINN [2] (cf. Fig. 2) comprising multiple scales i, each consisting of N_i chained affine conditional coupling (CC) blocks [7]. These scales are necessary to increase the receptive field of the network and are achieved by Haar wavelet downsampling [11]. A CC block consists of subnetworks that can be freely chosen depending on the data dimensionality (e.g. fully connected or convolutional networks), as they are only evaluated in the forward direction. The CC blocks receive a two-part condition, consisting of the domain label and the tissue label, which is concatenated to the input along the channel dimension. In the case of PAT, the tissue label is a full semantic segmentation map for the simulated data and a random segmentation map for the real data. In the case of HSI, the tissue label is a one-hot encoded vector of organ labels.

Model Training. In the following, the proposed cINN with parameters θ will be referred to as f(x, DY, θ) and its inverse as f^{-1}, for any input x ∼ p_D from domain D ∈ {D_sim, D_real} with prior density p_D and corresponding latent variable z. The condition DY is the combination of the domain label D and the tissue label Y ∈ {Y_sim, Y_real}. The maximum likelihood loss ML_D for a training sample x_i is then given by

$$
\mathrm{ML}_D = \mathbb{E}_i\left[\frac{\lVert f(x_i, DY, \theta)\rVert_2^2}{2} - \log\lvert J_i\rvert\right] \quad \text{with} \quad J_i = \det\left.\frac{\partial f}{\partial x}\right|_{x_i}. \tag{1}
$$

For the adversarial training, we employ the least squares training scheme [18] for the generator $\mathrm{Gen}_D = f_D^{-1} \circ f_{D'}$ (where $f_D$ denotes $f$ conditioned on domain $D$) and the discriminator $\mathrm{Dis}_D$ of each domain, with $x_{D'}$ as input from the source domain $D'$ and $x_D$ as input from the target domain $D$:

$$
\mathcal{L}_{\mathrm{Gen}_D} = \mathbb{E}_{x_{D'} \sim p_{D'}}\left[\left(\mathrm{Dis}_D(\mathrm{Gen}_D(x_{D'})) - 1\right)^2\right], \tag{2}
$$

$$
\mathcal{L}_{\mathrm{Dis}_D} = \mathbb{E}_{x_D \sim p_D}\left[\left(\mathrm{Dis}_D(x_D) - 1\right)^2\right] + \mathbb{E}_{x_{D'} \sim p_{D'}}\left[\mathrm{Dis}_D(\mathrm{Gen}_D(x_{D'}))^2\right]. \tag{3}
$$
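To make Eq. (1) and the invertibility argument concrete, here is a toy conditional affine coupling network in plain Python. It is a sketch under loudly stated assumptions, not the authors' architecture (their code is at https://github.com/IMSY-DKFZ/UDT-cINN): the "subnetwork" is a hypothetical fixed map with hand-picked weights `WEIGHTS`, the condition is a short list [sim, real, tissue] standing in for DY, inputs must have even length, and there is no multi-scale structure. Each block keeps one half of the vector and rescales and shifts the other half with s and t computed from the kept half and the condition, so the block is exactly invertible and its log-Jacobian-determinant is simply the sum of s. The same functions also illustrate the sim-to-real inference path: encode with the simulated-domain condition, decode with the real-domain one.

```python
import math

# Hypothetical per-block weights: [weight on passive half, then one per condition entry].
WEIGHTS = [
    [0.7, 0.5, -0.5, 0.3],
    [-0.4, -0.6, 0.4, 0.2],
    [0.5, 0.3, -0.2, -0.1],
]

def subnet(x_half, cond, w):
    """Toy conditioning subnetwork; only evaluated forward, so it need not be invertible."""
    h = w[0] * sum(x_half) + sum(wj * c for wj, c in zip(w[1:], cond))
    s = [0.5 * math.tanh(h + 0.1 * i) for i in range(len(x_half))]  # bounded log-scales
    t = [0.3 * h - 0.2 * i for i in range(len(x_half))]             # shifts
    return s, t

def coupling_forward(x, cond, w, swap):
    if len(x) % 2:
        raise ValueError("this toy coupling needs an even-length input")
    k = len(x) // 2
    a, b = (x[k:], x[:k]) if swap else (x[:k], x[k:])  # a passes through unchanged
    s, t = subnet(a, cond, w)
    b_out = [bi * math.exp(si) + ti for bi, si, ti in zip(b, s, t)]
    return (b_out + a if swap else a + b_out), sum(s)  # sum(s) = log|det J| of this block

def coupling_inverse(y, cond, w, swap):
    k = len(y) // 2
    a, b_out = (y[k:], y[:k]) if swap else (y[:k], y[k:])
    s, t = subnet(a, cond, w)  # recomputable because a was not modified
    b = [(bo - ti) * math.exp(-si) for bo, si, ti in zip(b_out, s, t)]
    return b + a if swap else a + b

def f(x, cond):
    """Forward pass x -> z, returning z and the total log|det J|."""
    log_det = 0.0
    for i, w in enumerate(WEIGHTS):
        x, ld = coupling_forward(x, cond, w, swap=(i % 2 == 1))  # alternate updated half
        log_det += ld
    return x, log_det

def f_inv(z, cond):
    for i in reversed(range(len(WEIGHTS))):
        z = coupling_inverse(z, cond, WEIGHTS[i], swap=(i % 2 == 1))
    return z

def ml_loss(batch, cond):
    """Eq. (1) for one domain: mean of ||f(x)||^2 / 2 - log|det J| over the batch."""
    total = 0.0
    for x in batch:
        z, log_det = f(x, cond)
        total += 0.5 * sum(v * v for v in z) - log_det
    return total / len(batch)

def condition(domain, tissue):
    """Hypothetical encoding of DY: domain one-hot [sim, real] followed by a scalar tissue label."""
    return ([1.0, 0.0] if domain == "sim" else [0.0, 1.0]) + [tissue]

def transfer_sim_to_real(x_sim, tissue):
    """Encode with (D_sim, Y_sim), then decode with (D_real, Y_sim)."""
    z, _ = f(x_sim, condition("sim", tissue))
    return f_inv(z, condition("real", tissue))
```

Encoding the transferred sample with the real-domain condition and decoding it with the simulated one recovers the original input up to floating-point error, which is the cycle consistency that the invertible design provides without an explicit cycle loss.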
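Finally, the full loss for the proposed model comprises the following:

$$
\mathcal{L}_{\mathrm{TotalGen}} = \mathrm{ML}_{\mathrm{real}} + \mathrm{ML}_{\mathrm{sim}} + \mathcal{L}_{\mathrm{Gen}_{\mathrm{real}}} + \mathcal{L}_{\mathrm{Gen}_{\mathrm{sim}}} \quad \text{and} \quad \mathcal{L}_{\mathrm{TotalDis}} = \mathcal{L}_{\mathrm{Dis}_{\mathrm{real}}} + \mathcal{L}_{\mathrm{Dis}_{\mathrm{sim}}}. \tag{4}
$$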
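The least-squares objectives of Eqs. (2) to (4) reduce to simple arithmetic on discriminator outputs. The sketch below is illustrative only: it takes discriminator scores as given lists (in the model they come from Dis_sim and Dis_real applied to real and domain-transferred samples), and the dictionary layout of `dis_outputs` is a hypothetical convenience, not an interface from the paper's code.

```python
def lsgan_generator_loss(dis_on_fake):
    """Eq. (2): mean of (Dis_D(Gen_D(x_D')) - 1)^2 over a batch of discriminator scores."""
    return sum((d - 1.0) ** 2 for d in dis_on_fake) / len(dis_on_fake)

def lsgan_discriminator_loss(dis_on_real, dis_on_fake):
    """Eq. (3): scores on real samples are pushed towards 1, on transferred samples towards 0."""
    real_term = sum((d - 1.0) ** 2 for d in dis_on_real) / len(dis_on_real)
    fake_term = sum(d ** 2 for d in dis_on_fake) / len(dis_on_fake)
    return real_term + fake_term

def total_losses(ml_real, ml_sim, dis_outputs):
    """Eq. (4). `dis_outputs` maps a domain name ("real", "sim") to a pair
    (scores on real samples of that domain, scores on samples transferred into it)."""
    gen = sum(lsgan_generator_loss(fake) for _, fake in dis_outputs.values())
    dis = sum(lsgan_discriminator_loss(real, fake) for real, fake in dis_outputs.values())
    return ml_real + ml_sim + gen, dis
```

Note the opposing targets: the discriminator term drives scores on transferred samples towards 0, while the generator term rewards scores near 1. In the proposed model, both the generator terms and the maximum likelihood terms update the same invertible network.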
Model Inference. The domain transfer is done in two steps: 1) a simulated image is encoded into its latent representation z with the conditions D_sim and Y_sim; 2) z is decoded to the real domain via D_real with the simulated tissue label Y_sim:

$$
x_{\mathrm{sim}\to\mathrm{real}} = f^{-1}(\cdot, D_{\mathrm{real}} Y_{\mathrm{sim}}, \theta) \circ f(\cdot, D_{\mathrm{sim}} Y_{\mathrm{sim}}, \theta)\,(x_{\mathrm{sim}}).
$$

2.2 Spectral Imaging Data
Photoacoustic Tomography Data. PAT is a non-ionizing imaging modality that enables the imaging of functional tissue properties such as tissue oxygenation [22]. The real PAT data (cf. Fig. 3) used in this work are images of human forearms recorded from 30 healthy volunteers using the MSOT Acuity Echo (iThera Medical GmbH, Munich, Germany) (all regulations followed under study ID S-451/2020; the study is registered with the German Clinical Trials Register under reference number DRKS00023205). In this study, 16 wavelengths from 700 nm to 850 nm in steps of 10 nm were recorded for each image. The resulting 180 images were semantically segmented into the structures shown in Fig. 3 according to the annotation protocol provided in [20]. Additionally, a full sweep of each forearm was performed to generate more unlabeled images, amounting to a total of 955 real images. The simulated PAT data (cf. Fig. 3) used in this work comprise 1,572 simulated images of human forearms. They were generated with the toolkit for Simulation and Image Processing for Photonics and Acoustics (SIMPA) [8], based on a forearm literature model [21] and a digital device twin of the MSOT Acuity Echo.

Hyperspectral Imaging Data. HSI is an emerging modality with high potential for surgery [4]. In this work, we performed pixel-wise analysis of HSI images. The real HSI data was acquired with the Tivita® Tissue (Diaspective Vision GmbH, Am Salzhaff, Germany) camera, featuring a spectral resolution of approximately 5 nm in the spectral range between 500 nm and 1000 nm. In total, 458 images, corresponding to 20 different pigs, were acquired (all regulations followed under study IDs 35-9185.81/G-161/18 and 35-9185.81/G-262/19) and annotated with ten structures: bladder, colon, fat, liver, omentum, peritoneum, skin, small bowel, spleen, and stomach (cf. Fig. 3). This amounts to 6,410,983 real spectra in total. The simulated HSI data was generated with a Monte Carlo method (cf. the algorithm provided in the supplementary material). This procedure resulted in 213,541 simulated spectra with annotated organ labels.

Fig. 3. Training data used for the validation experiments. For PAT, 960 real images from 30 volunteers were acquired. For HSI, more than six million spectra corresponding to 460 images and 20 individuals were used. The tissue labels for PAT correspond to 2D semantic segmentations, whereas the tissue labels for HSI represent 10 different organs. For PAT, ∼1600 images were simulated, whereas around 210,000 spectra were simulated for HSI.
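As a quick check on the PAT acquisition protocol described above, sampling 700 nm to 850 nm in 10 nm steps with both endpoints included yields exactly 16 wavelengths, so each PAT image is a 16-channel spectral stack. The helper name `wavelength_grid` is hypothetical.

```python
def wavelength_grid(start_nm, stop_nm, step_nm):
    """Evenly spaced wavelengths in nm, including both endpoints."""
    if (stop_nm - start_nm) % step_nm != 0:
        raise ValueError("stop_nm must be reachable from start_nm in whole steps")
    return list(range(start_nm, stop_nm + 1, step_nm))

pat_wavelengths = wavelength_grid(700, 850, 10)
print(len(pat_wavelengths), pat_wavelengths[0], pat_wavelengths[-1])  # 16 700 850
```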
3 Experiments and Results
The purpose of the experiments was to investigate hypotheses H1 and H2 (cf. Sect. 1). As comparison methods, a CycleGAN [24] and an unsupervised image-to-image translation (UNIT) network [16] were implemented fully convolutionally for PAT and in an adapted version for the one-dimensional HSI data. To make the comparison fair, the tissue label conditions were concatenated with the input, and we put significant effort into optimizing the UNIT on our data. Realism of Synthetic Data (H1): According to qualitative analyses (Fig. 4), our domain transfer approach improves simulated PAT images with respect to key properties, including the realism of skin and background and the sharpness of vessels.
Fig. 4. Qualitative results. In comparison to simulated PAT images (left), images generated by the cINN (middle) resemble real PAT images (right) more closely. All images show a human forearm at 800 nm.
Fig. 5. Our domain transfer approach yields realistic spectra (here: of veins). The PCA plots in a) represent a kernel density estimation of the first and second components of a PCA embedding of the real data, which represent about 67% and 6% of the variance in the real data, respectively. The distributions on top and on the right of the PCA plot correspond to the marginal distributions of each dataset’s first two components. b) Violin plots show that the cINN yields spectra that feature a smaller difference to the real data compared to the simulations and the UNIT-generated data. The dashed lines represent the mean difference value, and each dot represents the difference for one wavelength.
A principal component analysis (PCA) performed on all artery and vein spectra of the real and synthetic datasets demonstrates that the distribution of the synthetic data is much closer to that of the real data after applying our domain transfer approach (cf. Fig. 5a). The same holds for the absolute difference, as shown in Fig. 5b. The cINN performed slightly better than the UNIT. Similarly, our approach improves the realism of HSI spectra, as illustrated in Fig. 6 for five exemplary organs (colon, stomach, omentum, spleen, and fat). The cINN-transferred spectra generally match the real data very closely. Fig. 6 also includes failure cases, in which the real data exhibit a high variance (translucent bands).
Benefit of Domain-Transferred Data for Downstream Tasks (H2): We examined two classification tasks for which reference data generation was feasible: classification of veins/arteries in PAT and organ classification in HSI. For both modalities, we used the completely untouched real test sets, comprising 162 images in the case of PAT and ∼920,000 spectra in the case of HSI. For both tasks, a calibrated random forest classifier (sklearn [19] with default parameters) was trained on the simulated, the domain-transferred (by UNIT and cINN), and the real spectra. Following [17], we selected the balanced accuracy (BA), the area under the receiver operating characteristic curve (AUROC), and the F1-score as metrics. As shown in Table 1, our domain transfer approach dramatically increases the classification performance for both downstream tasks. Compared to physics-based simulation, the proposed cINN (cINN_DY) obtained a relative improvement of 37% (BA), 25% (AUROC), and 22% (F1-score) for PAT, whereas the UNIT (UNIT_Y) only achieved a relative improvement of 20-27% (depending on the metric). For HSI, the cINN achieved a relative improvement of 21% (BA), 1% (AUROC), and 33% (F1-score), and it outperformed the UNIT in all metrics except the F1-score. For all metrics, training on real data still yields better results.
Fig. 6. The cINN-transferred spectra are in closer agreement with the real spectra than the simulations and the UNIT-transferred spectra. Spectra for five exemplary organs are shown from 500 nm to 1000 nm. For each subplot, a zoom-in for the near-infrared region (>900 nm) is shown. The translucent bands represent the standard deviation across spectra for each organ.
Table 1. Classification scores for different training data. The training data refers to real data, physics-based simulated data, data generated by a CycleGAN, by a UNIT without and with tissue labels (UNIT_Y), and by a cINN without (cINN_D) and with (proposed cINN_DY) tissue labels as condition. Additionally, cINN_DY without GAN refers to a cINN_DY without the adversarial training. The best-performing methods, except if trained on real data, are printed in bold.

| Classifier training data | PAT BA | PAT AUROC | PAT F1-Score | HSI BA | HSI AUROC | HSI F1-Score |
|---|---|---|---|---|---|---|
| Real | 0.75 | 0.84 | 0.82 | 0.40 | 0.81 | 0.44 |
| Simulated | 0.52 | 0.64 | 0.64 | 0.24 | 0.75 | 0.18 |
| CycleGAN | 0.39 | 0.20 | 0.16 | 0.11 | 0.57 | 0.06 |
| UNIT | 0.50 | 0.44 | 0.65 | 0.20 | 0.72 | 0.20 |
| UNIT_Y | 0.64 | **0.81** | 0.77 | 0.24 | 0.74 | 0.25 |
| cINN_D | 0.66 | 0.73 | 0.72 | 0.25 | 0.72 | 0.20 |
| cINN_DY without GAN | 0.65 | 0.78 | 0.76 | 0.28 | 0.75 | **0.26** |
| cINN_DY (proposed) | **0.71** | 0.80 | **0.78** | **0.29** | **0.76** | 0.24 |
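To make the metric definitions and the quoted percentages concrete, the following stdlib-only Python sketch implements balanced accuracy, F1-score and AUROC for a binary task such as artery/vein classification in PAT, and recomputes the relative improvements over simulation from the Table 1 scores. It is not the authors' evaluation code: the paper used scikit-learn [19] on the outputs of a calibrated random forest, and every function name here (`balanced_accuracy`, `f1_binary`, `auroc_binary`, `improvements_over_simulation`) is hypothetical. How F1-score and AUROC were averaged for the multi-class HSI organ task is not stated in this section, so only binary versions are shown. The recomputation suggests that "the cINN" and "the UNIT" in the text refer to the label-conditioned cINN_DY and UNIT_Y variants.

```python
# Stdlib-only sketch of the Table 1 metrics and the relative improvements quoted in the text.
# Binary definitions are shown (e.g. artery vs. vein); multi-class averaging for HSI is not covered.


def balanced_accuracy(y_true, y_pred):
    """Mean per-class recall; valid for binary and multi-class labels."""
    recalls = []
    for c in sorted(set(y_true)):
        idx = [i for i, y in enumerate(y_true) if y == c]
        hits = sum(1 for i in idx if y_pred[i] == c)
        recalls.append(hits / len(idx))
    return sum(recalls) / len(recalls)


def f1_binary(y_true, y_pred, positive=1):
    """Harmonic mean of precision and recall for the positive class."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    return 0.0 if tp == 0 else 2 * tp / (2 * tp + fp + fn)


def auroc_binary(y_true, scores, positive=1):
    """Probability that a random positive outscores a random negative (ties count 1/2)."""
    pos = [s for t, s in zip(y_true, scores) if t == positive]
    neg = [s for t, s in zip(y_true, scores) if t != positive]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def relative_improvement(baseline, value):
    """Relative change with respect to the baseline (e.g. 0.25 means +25%)."""
    return (value - baseline) / baseline


# Scores copied from Table 1 as (BA, AUROC, F1) per modality.
TABLE1 = {
    "Simulated": {"PAT": (0.52, 0.64, 0.64), "HSI": (0.24, 0.75, 0.18)},
    "UNIT_Y": {"PAT": (0.64, 0.81, 0.77), "HSI": (0.24, 0.74, 0.25)},
    "cINN_DY": {"PAT": (0.71, 0.80, 0.78), "HSI": (0.29, 0.76, 0.24)},
}


def improvements_over_simulation(method, modality):
    """Relative improvement over simulated training data, in rounded percent."""
    base = TABLE1["Simulated"][modality]
    new = TABLE1[method][modality]
    return [round(100 * relative_improvement(b, v)) for b, v in zip(base, new)]


if __name__ == "__main__":
    print(improvements_over_simulation("cINN_DY", "PAT"))  # [37, 25, 22]
    print(improvements_over_simulation("UNIT_Y", "PAT"))   # [23, 27, 20]
    print(improvements_over_simulation("cINN_DY", "HSI"))  # [21, 1, 33]
```

Because improvements are measured relative to the simulated baseline, small absolute gains on a low baseline (e.g. the HSI F1-score, 0.18 to 0.24) translate into large percentages, whereas the HSI AUROC gain (0.75 to 0.76) amounts to only about 1%.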
4 Discussion
With this paper, we presented the first domain transfer approach that combines the benefits of cINNs (exact maximum likelihood estimation) with those of GANs (high image quality). A comprehensive validation involving qualitative and quantitative measures of the remaining domain gap and of downstream task performance suggests that the approach is well suited for sim-to-real transfer in spectral imaging. For both PAT and HSI, the domain gap between simulations and real data was substantially reduced, and a dramatic increase in downstream task performance was obtained, also when compared to the popular UNIT approach.
The only similar work on domain transfer in PAT [14] used a CycleGAN-based architecture on a single wavelength, with photon propagation alone as the PAT image simulator instead of full acoustic wave simulation and image reconstruction. This potentially leads to spectral inconsistency, in the sense that the spectral information is either lost during translation or remains unchanged from the source domain instead of adapting to the target domain. Outside the spectral/medical imaging community, Liu et al. [16] and Grover et al. [10] used variational autoencoders and invertible neural networks, respectively, one per domain, to create a shared encoding. Both combined this approach with adversarial training to achieve high-quality image generation. Das et al. [5] built upon this approach by using labels from the source domain to condition the domain transfer task. In contrast to previous work, which used separate encoders/decoders for each domain, we train a single network (cf. Fig. 2) with a twofold condition consisting of a domain label (D) and a tissue label (Y) from the source domain, which has the advantage of explicitly aiding the spectral domain transfer.
The main limitation of our approach is the high dimensionality of the cINN's parameter space: because INNs are information- and volume-preserving, the dimensionality of the data cannot be reduced. This implies that the method is not suitable for arbitrarily high dimensions.
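The discussion attributes both the exact maximum likelihood estimation and the inability to reduce dimensionality to the invertible architecture. The toy sketch below illustrates these two properties with a conditional INN built from RealNVP-style affine coupling blocks [7] whose subnetworks additionally receive the condition, in the spirit of [2]. The condition is a one-hot domain label D concatenated with a one-hot tissue label Y, and transfer is performed by encoding a spectrum under the source-domain condition and decoding the latent code under the target-domain condition. This encode/decode scheme, the layer sizes, the single random linear layer standing in for each subnetwork, the reversal permutation, and all names (`ConditionalAffineCoupling`, `TinyCINN`, `transfer`) are assumptions for illustration, not the authors' architecture; training, and in particular the adversarial (GAN) component, is omitted. The toy blocks are dimension-preserving but not volume-preserving (their log-determinant is generally nonzero).

```python
import math
import random


def one_hot(index, size):
    """One-hot encoding used for both the domain label D and the tissue label Y."""
    return [1.0 if i == index else 0.0 for i in range(size)]


def _affine(weights, vec):
    """Matrix-vector product with a list-of-rows weight matrix."""
    return [sum(w * v for w, v in zip(row, vec)) for row in weights]


class ConditionalAffineCoupling:
    """Affine coupling block: the first half of x sets scale/shift of the second half.

    The condition is concatenated to the subnetwork input, so the mapping changes with
    the condition while remaining exactly invertible for any fixed condition.
    """

    def __init__(self, dim, cond_dim, seed):
        if dim % 2:
            raise ValueError("dim must be even in this toy sketch")
        rng = random.Random(seed)
        self.half = dim // 2
        in_dim = self.half + cond_dim
        # Hypothetical subnetwork: one random linear layer each for log-scale and shift.
        self.w_s = [[rng.gauss(0.0, 0.3) for _ in range(in_dim)] for _ in range(self.half)]
        self.w_t = [[rng.gauss(0.0, 0.3) for _ in range(in_dim)] for _ in range(self.half)]

    def _scale_shift(self, x1, cond):
        h = x1 + cond  # list concatenation: [first half of x, condition]
        log_s = [math.tanh(v) for v in _affine(self.w_s, h)]  # bounded log-scale
        return log_s, _affine(self.w_t, h)

    def forward(self, x, cond):
        x1, x2 = x[: self.half], x[self.half :]
        log_s, t = self._scale_shift(x1, cond)
        y2 = [xi * math.exp(si) + ti for xi, si, ti in zip(x2, log_s, t)]
        return x1 + y2, sum(log_s)  # output and log|det Jacobian|

    def inverse(self, y, cond):
        y1, y2 = y[: self.half], y[self.half :]
        log_s, t = self._scale_shift(y1, cond)  # y1 == x1, so scale/shift are recovered
        x2 = [(yi - ti) * math.exp(-si) for yi, si, ti in zip(y2, log_s, t)]
        return y1 + x2


class TinyCINN:
    """Stack of coupling blocks with a fixed reversal permutation between them."""

    def __init__(self, dim, cond_dim, n_blocks=4, seed=0):
        self.blocks = [ConditionalAffineCoupling(dim, cond_dim, seed + k) for k in range(n_blocks)]

    def forward(self, x, cond):
        log_det = 0.0
        for block in self.blocks:
            x, ld = block.forward(x, cond)
            log_det += ld
            x = x[::-1]  # permutation (log-det 0) so the other half is transformed next
        return x, log_det

    def inverse(self, z, cond):
        for block in reversed(self.blocks):
            z = block.inverse(z[::-1], cond)  # undo the reversal, then invert the block
        return z

    def log_likelihood(self, x, cond):
        """Exact log p(x | cond) via change of variables with a standard normal latent."""
        z, log_det = self.forward(x, cond)
        log_pz = -0.5 * sum(v * v for v in z) - 0.5 * len(z) * math.log(2.0 * math.pi)
        return log_pz + log_det


def transfer(model, x, src_domain, dst_domain, tissue, n_domains, n_tissues):
    """Hypothetical transfer: encode under (D=src, Y), decode under (D=dst, Y)."""
    tissue_code = one_hot(tissue, n_tissues)
    z, _ = model.forward(x, one_hot(src_domain, n_domains) + tissue_code)
    return model.inverse(z, one_hot(dst_domain, n_domains) + tissue_code)


if __name__ == "__main__":
    SIM, REAL = 0, 1
    model = TinyCINN(dim=6, cond_dim=2 + 3)
    spectrum = [0.8, 0.6, 0.55, 0.5, 0.45, 0.4]  # made-up 6-band spectrum
    fake_real = transfer(model, spectrum, SIM, REAL, tissue=1, n_domains=2, n_tissues=3)
    print(len(fake_real), model.log_likelihood(spectrum, one_hot(SIM, 2) + one_hot(1, 3)))
```

Because the latent code has exactly as many entries as the input spectrum, the network offers no bottleneck in which the data could be compressed; with real spectra or images, every block must operate at full input dimensionality, which is the scaling limitation named in the discussion.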
Future work will include a rigorous validation of our method with tissue-mimicking phantoms, for which reference data are available. In conclusion, our cINN-based domain transfer approach enables the generation of realistic spectral data. As it is not limited to spectral data, it could develop into a powerful method for domain transfer in the absence of labeled real data for a wide range of imaging modalities in the medical domain and beyond.
Acknowledgements. This project was supported by the European Research Council (ERC) under the European Union’s Horizon 2020 research and innovation programme (NEURAL SPICING, 101002198) and the Surgical Oncology Program of the National Center for Tumor Diseases (NCT) Heidelberg.
References
1. Adler, T.J., et al.: Uncertainty-aware performance assessment of optical imaging modalities with invertible neural networks. Int. J. Comput. Assist. Radiol. Surg. 14(6), 997–1007 (2019). https://doi.org/10.1007/s11548-019-01939-9
2. Ardizzone, L., Lüth, C., Kruse, J., Rother, C., Köthe, U.: Conditional invertible neural networks for guided image generation (2020)
3. Ayala, L., et al.: Spectral imaging enables contrast agent-free real-time ischemia monitoring in laparoscopic surgery. Sci. Adv. (2023). https://doi.org/10.1126/sciadv.add6778
4. Clancy, N.T., Jones, G., Maier-Hein, L., Elson, D.S., Stoyanov, D.: Surgical spectral imaging. Med. Image Anal. 63, 101699 (2020)
5. Das, H.P., Tran, R., Singh, J., Lin, Y.W., Spanos, C.J.: CDCGen: cross-domain conditional generation via normalizing flows and adversarial training. arXiv preprint arXiv:2108.11368 (2021)
6. De Fauw, J., Ledsam, J.R., Romera-Paredes, B., Nikolov, S., Tomasev, N., Blackwell, S., et al.: Clinically applicable deep learning for diagnosis and referral in retinal disease. Nat. Med. 24(9), 1342–1350 (2018)
7. Dinh, L., Sohl-Dickstein, J., Bengio, S.: Density estimation using Real NVP. arXiv preprint arXiv:1605.08803 (2016)
8. Gröhl, J., et al.: SIMPA: an open-source toolkit for simulation and image processing for photonics and acoustics. J. Biomed. Opt. 27(8), 083010 (2022)
9. Gröhl, J., Schellenberg, M., Dreher, K., Maier-Hein, L.: Deep learning for biomedical photoacoustic imaging: a review. Photoacoustics 22, 100241 (2021)
10. Grover, A., Chute, C., Shu, R., Cao, Z., Ermon, S.: AlignFlow: cycle consistent learning from multiple domains via normalizing flows. In: Proceedings of the AAAI Conference on Artificial Intelligence, vol. 34, pp. 4028–4035 (2020)
11. Haar, A.: Zur Theorie der orthogonalen Funktionensysteme. Mathematische Annalen 71(1), 38–53 (1911)
12. Hoffman, J., Tzeng, E., et al.: CyCADA: cycle-consistent adversarial domain adaptation. In: International Conference on Machine Learning, pp. 1989–1998 (2018)
13. Isensee, F., Jaeger, P.F., Kohl, S.A., Petersen, J., Maier-Hein, K.H.: nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nat. Methods 18(2), 203–211 (2021)
14. Li, J., et al.: Deep learning-based quantitative optoacoustic tomography of deep tissues in the absence of labeled experimental data. Optica 9(1), 32–41 (2022)
15. Li, K., Zhang, Y., Li, K., Fu, Y.: Adversarial feature hallucination networks for few-shot learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 13470–13479 (2020)
16. Liu, M.Y., Breuel, T., Kautz, J.: Unsupervised image-to-image translation networks. Adv. Neural Inf. Process. Syst. 30 (2017)
17. Maier-Hein, L., Reinke, A., Godau, P., Tizabi, M.D., Büttner, F., Christodoulou, E., et al.: Metrics reloaded: pitfalls and recommendations for image analysis validation (2022). https://doi.org/10.48550/ARXIV.2206.01653
18. Mao, X., Li, Q., Xie, H., Lau, R.Y., Wang, Z., Paul Smolley, S.: Least squares generative adversarial networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2794–2802 (2017)
19. Pedregosa, F., et al.: Scikit-learn: machine learning in Python. J. Mach. Learn. Res. 12, 2825–2830 (2011)
20. Schellenberg, M., et al.: Semantic segmentation of multispectral photoacoustic images using deep learning. Photoacoustics 26, 100341 (2022). https://doi.org/10.1016/j.pacs.2022.100341
21. Schellenberg, M., et al.: Photoacoustic image synthesis with generative adversarial networks. Photoacoustics 28, 100402 (2022)
22. Wang, X., Xie, X., Ku, G., Wang, L.V., Stoica, G.: Noninvasive imaging of hemoglobin concentration and oxygenation in the rat brain using high-resolution photoacoustic tomography. J. Biomed. Opt. 11(2), 024015 (2006)
23. Wirkert, S.J., et al.: Physiological parameter estimation from multispectral images unleashed. In: Descoteaux, M., Maier-Hein, L., Franz, A., Jannin, P., Collins, D.L., Duchesne, S. (eds.) MICCAI 2017. LNCS, vol. 10435, pp. 134–141. Springer, Cham (2017). https://doi.org/10.1007/978-3-319-66179-7_16
24. Zhu, J.Y., Park, T., Isola, P., Efros, A.A.: Unpaired image-to-image translation using cycle-consistent adversarial networks. In: Proceedings of the IEEE International Conference on Computer Vision, pp. 2223–2232 (2017)
Open Access This chapter is licensed under the terms of the Creative Commons Attribution 4.0 International License (http://creativecommons.org/licenses/by/4.0/), which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons license and indicate if changes were made. The images or other third party material in this chapter are included in the chapter’s Creative Commons license, unless indicated otherwise in a credit line to the material. If material is not included in the chapter’s Creative Commons license and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder.
Author Index
A Adams, Jadie 486 Adeleke, Olusola 300 Adeli, Ehsan 279 Adler, Tim J. 770 Altmann, Andre 425 Arcucci, Rossella 637 Arora, Chetan 206 Ayala, Leonardo 770 Azampour, Mohammad Farid 435 B Bai, Wenjia 637 Balachandran, Abishek 573 Bammer, Roland 183 Barbed, O. León 583 Basu, Soumen 206 Batten, James 162 Baugh, Matthew 162 Belyaev, Mikhail 605 Bilgic, Berkin 457 Blaschko, Matthew B. 77 Bloch, Isabelle 227 Bolelli, Federico 248 Bône, Alexandre 227 Bontempo, Gianpaolo 248 Boyd, Joseph 594 Bozorgtabar, Behzad 195 Brandt, Johannes 57 Braren, Rickmer 141 C Cai, Jianfei 183 Calderara, Simone 248 Cao, Peng 310 Cao, Qing 87 Cardoso, M. Jorge 300, 446 Chao, Hanqing 728 Chapman, James 425 Che, Haoxuan 695 Chen, Bingzhi 562
Chen, Chen 637 Chen, Faquan 551 Chen, Hao 477, 695 Chen, Jiongquan 342 Chen, Kang 87 Chen, Pin-Yu 728 Chen, Wei 130 Chen, Xiaofei 405 Chen, Yaqi 551 Chen, Yixiong 706 Chen, Yu-Jen 173 Chen, Zefeiyun 518 Cheng, Jun 342 Cheng, Sibo 637 Chikontwe, Philip 528 Cho, Jaejin 457 Christodoulidis, Stergios 594 Comaniciu, Dorin 573 Cook, Gary 300 Cournède, Paul-henry 594 D De Benetti, Francesca 290 Delrieux, Claudio 67 Denehy, Linda 152 Dhurandhar, Amit 728 Dima, Alina F. 141 Ding, Chris 706 Dombrowski, Mischa 162 Dreher, Kris K. 770 Duncan, James S. 759 E El Fakhri, Georges 539 Elhabian, Shireen Y. 486, 508, 615 F Fainstein, Miguel 67 Fang, Yuqi 46 Fang, Yuxin 518 Fei, Jingjing 551
© The Editor(s) (if applicable) and The Author(s), under exclusive license to Springer Nature Switzerland AG 2023 H. Greenspan et al. (Eds.): MICCAI 2023, LNCS 14220, pp. 781–785, 2023. https://doi.org/10.1007/978-3-031-43907-0
Feldman, Paula 67 Feng, Yingjie 130 Ficarra, Elisa 248 Fillioux, Leo 594 Fua, Pascal 583 G Gamazo Tejero, Javier 331 Gao, Jiaxing 416 Gao, Jun 352 Gao, Shengbo 98 Gao, Simon S. 477 Ge, Rongjun 405 Ghamsarian, Negin 331 Ghesu, Florin C. 573 Goh, Vicky 300 Gomez, Alberto 152 Goncharov, Mikhail 605 Gonzalez Duque, Vanessa 435 Gori, Pietro 227 Gotway, Michael B. 651 Graf, Markus 141 Graf, Robert 141 Graham, Mark S. 446 Gröhl, Janek 770 Gu, Lin 13 Gu, Xianfeng 130 Gu, Yun 674 Guan, Yun 384 Guo, Lei 416 Gupta, Mayank 206 Gupta, Pankaj 206 Guymer, Robyn H. 477 H Han, Junwei 416 Han, Kun 759 Harada, Tatsuya 13 Hayashi, Yuichiro 663 He, Chunming 684 He, Jinlong 374 He, Junjun 674 He, Yuting 405 Heimann, Alexander F. 497 Hejrati, Mohsen 477 Ho, Tsung-Yi 173 Hu, Jingtong 119 Hu, Xindi 342 Hu, Xinrong 119, 173 Huang, Chaoqin 87
Huang, Chenxi 551 Hübner, Marco 770 Hwang, Jae Youn 539 I Iarussi, Emmanuel 67 Islam, Mobarakol 35 Iyer, Krithika 615 J Jäger, H. Rolf 446 Jeong, Jaehoon 528 Jiang, Aofan 87 Jiang, Caiwen 3 Jiang, Hua 706 Jin, Haibo 695 Jing, Jie 238 Jun, Yohan 457 K K. Z. Tehrani, Ali 467 Kainz, Bernhard 162 Kaissis, Georgios 57 Kang, Qingbo 352 Katoh, Takashi 394 Kerdegari, Hamideh 152 Khosravi, Pooya 759 Kim, Seo Yeong 626 King, Andrew P. 152 Kirschke, Jan S. 141 Kitasaka, Takayuki 663 Kobayashi, Caique 457 Köthe, Ullrich 770 Kurmukov, Anvar 605 L Lao, Qicheng 352 Lawry Aguila, Ana 425 Le, Ngoc Minh Thu 152 Le, Thanh Phuong 152 Le, Thi Mai Thao 152 Le, Trung 183 Lee, Haeyun 539 Lee, Kyungsu 539 Lee, Won Hee 626 Lemke, Tristan 141 Li, Hongwei Bran 141 Li, Jingxian 706 Li, Junyu 342 Li, Kai 684
Li, Kang 352 Li, Pengfei 374 Li, Qingyuan 518 Li, Shuo 405 Li, Xiu 684 Li, Zhongyu 749 Li, Zihao 98 Liang, Jianming 651 Liang, Wei 310 Lin, Tiancheng 259 Lin, Yili 717 Liu, Che 637 Liu, Gang 374 Liu, Hong 363 Liu, Jiameng 3 Liu, Jiangcong 384 Liu, Li 706 Liu, Mianxin 3 Liu, Mingxia 46 Liu, Side 518 Liu, Tianbao 518 Liu, Tianming 416 Liu, Xiaoli 310 Liu, Yishu 562 Liu, Yuting 717 Liu, Zaiyi 24 Lu, Donghuan 363 Lu, Guangming 562 Lu, Zexin 238 Luna, Miguel 528 Luo, Mingyuan 342 Luo, Xiangde 749
M Ma, Chong 416 Ma, DongAo 651 Ma, Hao 384 Ma, Jiechao 98 Ma, Lei 3 Mah, Yee H. 446 Mahapatra, Dwarikanath 195 Maier-Hein, Lena 770 Márquez-Neila, Pablo 331 Meissen, Felix 57 Menten, Martin J. 141 Mishra, Divyanshu 216 Mohaiu, Antonia T. 573 Montiel, José M. M. 583
Mori, Kensaku 663 Müller, Johanna P. 162 Müller, Philip 57 Murillo, Ana C. 583 N Nachev, Parashkev 446 Nakagawa, Akira 394 Nam, Siwoo 528 Navab, Nassir 290, 435 Nguyen, Van Hao 152 Nguyen-Duc, Thanh 183 Ni, Dong 342 Nickel, Felix 770 Nie, Dong 717 Niethammer, Marc 320 Noble, J. Alison 216 Nölke, Jan-Hinrich 770 O Oda, Masahiro 663 Okuno, Yasushi 394 Ourselin, Sebastien 300, 446 Ouyang, Jiahong 279 P Pan, Jiahui 562 Pan, Yongsheng 3 Pang, Jiaxuan 651 Papageorghiou, Aris T. 216 Papanai, Ashish 206 Paranjape, Jay N. 739 Park, Hye Won 626 Park, Sang Hyun 528 Paschali, Magdalini 290 Patel, Ashay 300 Patel, Vishal M. 739 Pely, Adam 477 Peng, Wei 279 Phung Tran Huy, Nhat 152 Phung, Dinh 183 Pinaya, Walter Hugo Lopez 300, 446 Pisani, Luigi 152 Pisov, Maxim 605 Pohl, Kilian M. 279 Popdan, Ioan M. 573 Porrello, Angelo 248
Q Qian, Jikuan 342 Qiao, Lishan 46 Qiao, Mengyun 637 Quan, Quan 24 R Raffler, Philipp 141 Razavi, Reza 152 Ren, Hongliang 35 Rivaz, Hassan 467 Rohé, Marc-Michel 227 Rominger, Axel 290 Rueckert, Daniel 57, 141 S Saha, Pramit 216 Sarfati, Emma 227 Sari, Hasan 290 Schellenberg, Melanie 770 Schoeffmann, Klaus 331 Seidlitz, Silvia 770 Seitel, Alexander 770 Sellner, Jan 770 Shah, Anand 637 Shen, Chenyu 270 Shen, Dinggang 3 Shi, Enze 416 Shi, Kuangyu 290 Shi, Yiyu 119, 173 Sikder, Shameema 739 Siless, Viviana 67 Simson, Walter 290, 435 Soboleva, Vera 605 Studier-Fischer, Alexander 770 Sun, Jinghan 363 Sun, Shanlin 759 Sznitman, Raphael 331 T Tajer, Ali 728 Tan, Jeremy 162 Tang, Longxiang 684 Tannast, Moritz 497 Teo, James T. 446 Thiran, Jean-Philippe 195 Thwaites, Louise 152 Tian, Lixia 384 Tokuhisa, Atsushi 394 Truong, Thi Phuong Thao 152
Tudosiu, Petru-Daniel 300, 446 Tursynbek, Nurislam 320 U Umeda, Yuhei 394
V Vakalopoulou, Maria 594 Vedula, S. Swaroop 739 Velikova, Yordanka 435 Vizitiu, Anamaria 573 W Wada, Mutsuyo 394 Wada, Yuichiro 394 Wan, Jing 749 Wang, An 35 Wang, Jiale 497 Wang, Liansheng 363 Wang, Qianqian 46 Wang, Tao 238 Wang, Wei 46 Wang, Xiaoqing 457 Wang, Xingguang 749 Wang, Yanfeng 87 Wang, Yusi 518 Wei, Dong 363 Wei, Meng 674 Wei, Xiaozheng 416 Wei, Yaonai 416 Wen, Xuyun 717 Wendler, Thomas 290 Werring, David 446 Wolf, Sebastian 331 Woo, Jonghye 539 Wright, Paul 446 Wu, Mengqi 46 Wu, Qi 13 Wu, Shuang 87 Wu, Yawen 119 Wu, Zhichao 477 X Xia, Yong 13, 109 Xie, Wangduo 77 Xie, Weijie 518 Xie, Xiaohui 759 Xie, Yutong 13, 109 Xiong, Xiaosong 3 Xiong, Yifeng 759
Xu, Cunbao 749 Xu, Hong 508 Xu, Le 384 Xu, Mengya 35 Xu, Xiaowei 119 Xu, Xiaoyin 130 Xu, Yangyang 728 Xu, Yi 259 Xu, Zhe 363 Xue, Cheng 405 Y Yacoub, Sophie 152 Yamazaki, Kimihiro 394 Yan, Pingkun 728 Yan, Xiangyi 759 Yan, Zhongnuo 342 Yang, Guanyu 405 Yang, Jie 674 Yang, Jinzhu 310 Yang, Meng 749 Yang, Ming 717 Yang, Wei 518 Yang, Xin 342 Yang, Yuncheng 674 Yang, Ziqi 749 Yang, Ziyuan 270 Yao, Heming 477 Yao, Qingsong 24 Ye, Jin 674 Ye, Zhanhao 562 You, Chenyu 759 Yu, Hui 238 Yu, Zhimiao 259 Z Zaharchuk, Greg 279 Zaiane, Osmar R. 310 Zeng, Biqing 562
Zeng, Dewen 119 Zeng, Zi 87 Zhang, Daoqiang 717 Zhang, Jiajin 728 Zhang, Jianpeng 13, 109 Zhang, Kai 310 Zhang, Miao 477 Zhang, Min 130 Zhang, Shaoteng 109 Zhang, Shu 98 Zhang, Songyao 416 Zhang, Tuo 416 Zhang, Weitong 637 Zhang, Xiao 3 Zhang, Ya 87 Zhang, Yang 35 Zhang, Yi 238, 270 Zhang, Yuanji 342 Zhang, Yulun 684 Zhang, Zheng 562 Zhang, Ziji 98 Zhao, He 183, 216 Zhao, Pengfei 310 Zhao, Qingyu 279 Zhao, Xinkai 663 Zhao, Zhanpeng 518 Zhao, Zixu 374 Zheng, Guoyan 497 Zheng, Kaiyi 518 Zheng, Yefeng 363 Zhong, Shenjun 374 Zhong, Tianyang 416 Zhou, Ke 518 Zhou, S. Kevin 24 Zhou, Zongwei 706 Zhu, Heqin 24 Zhu, Jianwei 749 Zimmer, Veronika A. 141 Zinkernagel, Martin 331