Unpacking the “black box” to build better AI models
When deep-learning models are deployed in the real world, perhaps to detect financial fraud from credit card activity or identify cancer in medical images, they are often able to outperform humans.
But what exactly are these deep-learning models learning? Does a model trained to spot skin cancer in clinical images, for example, actually learn the colors and textures of cancerous tissue, or is it flagging some other features or patterns?
These powerful machine-learning models are typically based on artificial neural networks that can have millions of nodes that process data to make predictions. Due to their complexity, researchers often call these models “black boxes” because even the scientists who build them don’t understand everything that is going on under the hood.
Stefanie Jegelka isn’t satisfied with that “black box” explanation. A newly tenured associate professor in the MIT Department of Electrical Engineering and Computer Science, Jegelka is digging deep into deep learning to understand what these models can learn and how they behave, and how to build certain prior information into these models.
“At the end of the day, what a deep-learning model will learn depends on so many factors. But building an understanding that is relevant in practice will help us design better models, and also help us understand what is going on inside them so we know when we can deploy a model and when we can’t. That is critically important,” says Jegelka, who is also a member of the Computer Science and Artificial Intelligence Laboratory (CSAIL) and the Institute for Data, Systems, and Society (IDSS).
Jegelka is particularly interested in optimizing machine-learning models when input data are in the form of graphs. Graph data pose specific challenges: For instance, the data carry information about individual nodes and edges as well as information about the structure, that is, what is connected to what. In addition, graphs have mathematical symmetries that need to be respected by the machine-learning model so that, for instance, the same graph always leads to the same prediction, no matter how its nodes happen to be numbered. Building such symmetries into a machine-learning model is usually not easy.
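To make the symmetry requirement concrete, here is a minimal sketch in Python. It is an illustration only, not a model from Jegelka's work: the function names, the toy update rule, and the four-node example graph are all assumptions made for this example. It compares a readout built from neighbor sums, which gives the same answer however the nodes are numbered, with a readout that depends on node order.

```python
# Illustrative sketch only: function names, the update rule, and the toy
# graph are hypothetical, chosen to show what "respecting symmetry" means.

def graph_readout(features, edges, rounds=2):
    """Sum-aggregation message passing followed by a sum over all nodes."""
    n = len(features)
    neighbors = [[] for _ in range(n)]
    for u, v in edges:
        neighbors[u].append(v)
        neighbors[v].append(u)
    h = list(features)
    for _ in range(rounds):
        # Each node combines its own value with the sum of its neighbors' values.
        h = [h[i] ** 2 + sum(h[j] for j in neighbors[i]) for i in range(n)]
    return sum(h)


def index_weighted_readout(features):
    """A readout that depends on node numbering, so it breaks the symmetry."""
    return sum(i * x for i, x in enumerate(features))


def relabel(features, edges, perm):
    """Renumber the nodes: old node i becomes new node perm[i]."""
    new_features = [0] * len(features)
    for old, new in enumerate(perm):
        new_features[new] = features[old]
    new_edges = [(perm[u], perm[v]) for u, v in edges]
    return new_features, new_edges


# A path graph on four nodes, and the same graph with its nodes renumbered.
features = [1, 2, 3, 4]
edges = [(0, 1), (1, 2), (2, 3)]
permuted_features, permuted_edges = relabel(features, edges, [2, 0, 3, 1])

original = graph_readout(features, edges)
permuted = graph_readout(permuted_features, permuted_edges)
print("sum-based readout unchanged:", original == permuted)
print("index-weighted readout unchanged:",
      index_weighted_readout(features) == index_weighted_readout(permuted_features))
```

The sum-based readout is unchanged by renumbering because sums ignore order; the index-weighted readout is not, which is the kind of behavior a graph model has to avoid.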
Take molecules, for instance. Molecules can be represented as graphs, with vertices that correspond to atoms and edges that correspond to chemical bonds between them. Drug companies may want to use deep learning to rapidly predict the properties of many molecules, narrowing down the number they must physically test in the lab.
Jegelka studies methods to build mathematical machine-learning models that can effectively take graph data as an input and output something else, in this case a prediction of a molecule’s chemical properties. This is particularly challenging since a molecule’s properties are determined not only by the atoms within it, but also by the connections between them.
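A small Python sketch can show why the connections matter. Ethanol and dimethyl ether contain exactly the same atoms (both are C2H6O) but are bonded differently, and they are different substances. The dictionary layout and helper names below are assumptions made for this illustration, and hydrogens are omitted to keep the graphs small.

```python
from collections import Counter

# Heavy-atom graphs only (hydrogens omitted). The data layout and helper
# names are hypothetical, chosen for this illustration.
ethanol = {"atoms": ["C", "C", "O"], "bonds": [(0, 1), (1, 2)]}         # C-C-O
dimethyl_ether = {"atoms": ["C", "O", "C"], "bonds": [(0, 1), (1, 2)]}  # C-O-C


def atom_counts(mol):
    """Which atoms are present, ignoring how they are connected."""
    return Counter(mol["atoms"])


def neighbor_signature(mol):
    """For each atom, its element plus the sorted elements it is bonded to."""
    nbrs = {i: [] for i in range(len(mol["atoms"]))}
    for u, v in mol["bonds"]:
        nbrs[u].append(mol["atoms"][v])
        nbrs[v].append(mol["atoms"][u])
    return sorted((mol["atoms"][i], tuple(sorted(ns))) for i, ns in nbrs.items())


same_atoms = atom_counts(ethanol) == atom_counts(dimethyl_ether)
same_structure = neighbor_signature(ethanol) == neighbor_signature(dimethyl_ether)
print("same atoms:", same_atoms)
print("same structure:", same_structure)
```

A model that looked only at the atom counts could not tell these two molecules apart; a model that reads the graph can.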
Other examples of machine learning on graphs include traffic routing, chip design, and recommender systems.
Designing these models is made even more difficult by the fact that data used to train them are often different from data the models see in practice. Perhaps the model was trained using small molecular graphs or traffic networks, but the graphs it sees once deployed are larger or more complex.
In this case, what can researchers expect this model to learn, and will it still work in practice if the real-world data are different?
“Your model is not going to be able to learn everything because of some hardness problems in computer science, but what you can learn and what you can’t learn depends on how you set the model up,” Jegelka says.
She approaches this question by combining her passion for algorithms and discrete mathematics with her excitement for machine learning.
From butterflies to bioinformatics
Jegelka grew up in a small town in Germany and became interested in science when she was a high school student; a supportive teacher encouraged her to participate in an international science competition. She and her teammates from the U.S. and Singapore won an award for a website they created about butterflies, in three languages.
“For our project, we took images of wings with a scanning electron microscope at a local university of applied sciences. I also got the opportunity to use a high-speed camera at Mercedes Benz — this camera usually filmed combustion engines — which I used to capture a slow-motion video of the movement of a butterfly’s wings. That was the first time I really got in touch with science and exploration,” she recalls.
Intrigued by both biology and mathematics, Jegelka decided to study bioinformatics at the University of Tübingen and the University of Texas at Austin. She had a few opportunities to conduct research as an undergraduate, including an internship in computational neuroscience at Georgetown University, but wasn’t sure what career to follow.
When she returned for her final year of college, Jegelka moved in with two roommates who were working as research assistants at the Max Planck Institute in Tübingen.
“They were working on machine learning, and that sounded really cool to me. I had to write my bachelor’s thesis, so I asked at the institute if they had a project for me. I started working on machine learning at the Max Planck Institute and I loved it. I learned so much there, and it was a great place for research,” she says.
She stayed on at the Max Planck Institute to complete a master’s thesis, and then embarked on a PhD in machine learning at the Max Planck Institute and the Swiss Federal Institute of Technology.
During her PhD, she explored how concepts from discrete mathematics can help improve machine-learning techniques.
Teaching models to learn
The more Jegelka learned about machine learning, the more intrigued she became by the challenges of understanding how models behave, and how to steer this behavior.
“You can do so much with machine learning, but only if you have the right model and data. It is not just a black-box thing where you throw it at the data and it works. You actually have to think about it, its properties, and what you want the model to learn and do,” she says.
After completing a postdoc at the University of California at Berkeley, Jegelka was hooked on research and decided to pursue a career in academia. She joined the faculty at MIT in 2015 as an assistant professor.
“What I really loved about MIT, from the very beginning, was that the people really care deeply about research and creativity. That is what I appreciate the most about MIT. The people here really value originality and depth in research,” she says.
That focus on creativity has enabled Jegelka to explore a broad range of topics.
In collaboration with other faculty at MIT, she studies machine-learning applications in biology, imaging, computer vision, and materials science.
But what really drives Jegelka is probing the fundamentals of machine learning, and most recently, the issue of robustness. Often, a model performs well on training data, but its performance deteriorates when it is deployed on slightly different data. Building prior knowledge into a model can make it more reliable, but understanding what information the model needs to be successful and how to build it in is not so simple, she says.
She is also exploring methods to improve the performance of machine-learning models for image classification.
Image classification models are everywhere, from the facial recognition systems on mobile phones to tools that identify fake accounts on social media. These models need massive amounts of data for training, but since it is expensive for humans to hand-label millions of images, researchers often use unlabeled datasets to pretrain models instead.
These models then reuse the representations they have learned when they are fine-tuned later for a specific task.
Ideally, researchers want the model to learn as much as it can during pretraining, so it can apply that knowledge to its downstream task. But in practice, these models often learn only a few simple correlations — like that one image has sunshine and one has shade — and use these “shortcuts” to classify images.
“We showed that this is a problem in ‘contrastive learning,’ which is a standard technique for pre-training, both theoretically and empirically. But we also show that you can influence the kinds of information the model will learn to represent by modifying the types of data you show the model. This is one step toward understanding what models are actually going to do in practice,” she says.
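For readers unfamiliar with contrastive learning, the sketch below shows a generic InfoNCE-style loss of the kind commonly used in that setting. It is not the specific analysis or method from Jegelka's group; the function names, the temperature value, and the toy two-dimensional embeddings are assumptions made for this illustration. The loss is small when an image's embedding is close to that of its "positive" (for example, another view of the same image) and far from the "negatives."

```python
import math


def cosine(a, b):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def info_nce(anchor, positive, negatives, temperature=0.5):
    """InfoNCE-style loss for one anchor: negative log-softmax of the positive."""
    logits = [cosine(anchor, positive) / temperature]
    logits += [cosine(anchor, neg) / temperature for neg in negatives]
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
    return log_sum - logits[0]


# Toy embeddings (hypothetical values).
anchor = [1.0, 0.0]
negatives = [[0.0, 1.0], [-1.0, 0.0]]
aligned_loss = info_nce(anchor, [0.9, 0.1], negatives)     # positive near anchor
misaligned_loss = info_nce(anchor, [0.1, 0.9], negatives)  # positive far from anchor
print("aligned positive gives lower loss:", aligned_loss < misaligned_loss)
```

Which pairs count as positives and negatives is decided by the data shown to the model, which is one way the choice of data can shape what the learned representation captures.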
Researchers still don’t understand everything that goes on inside a deep-learning model, or the details of how they can influence what a model learns and how it behaves, but Jegelka looks forward to continuing to explore these topics.
“Often in machine learning, we see something happen in practice and we try to understand it theoretically. This is a huge challenge. You want to build an understanding that matches what you see in practice, so that you can do better. We are still just at the beginning of understanding this,” she says.
Outside the lab, Jegelka is a fan of music, art, traveling, and cycling. But these days, she enjoys spending most of her free time with her preschool-aged daughter.